diff --git a/.bazelrc b/.bazelrc index fc2995dc838c5..3656a86eb364c 100644 --- a/.bazelrc +++ b/.bazelrc @@ -2,7 +2,11 @@ build --cxxopt=--std=c++17 build --copt=-I. # Bazel does not support including its cc_library targets as system # headers. We work around this for generated code +<<<<<<< HEAD # (e.g. torch/headeronly/macros/cmake_macros.h) by making the generated directory a +======= +# (e.g. c10/macros/cmake_macros.h) by making the generated directory a +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # system include path. build --copt=-isystem --copt bazel-out/k8-fastbuild/bin build --copt=-isystem --copt bazel-out/darwin-fastbuild/bin diff --git a/.ci/aarch64_linux/aarch64_ci_build.sh b/.ci/aarch64_linux/aarch64_ci_build.sh index 41cabc3bf5113..771c10aecfde2 100644 --- a/.ci/aarch64_linux/aarch64_ci_build.sh +++ b/.ci/aarch64_linux/aarch64_ci_build.sh @@ -3,6 +3,7 @@ set -eux -o pipefail GPU_ARCH_VERSION=${GPU_ARCH_VERSION:-} +<<<<<<< HEAD # Set CUDA architecture lists to match x86 build_cuda.sh if [[ "$GPU_ARCH_VERSION" == *"12.6"* ]]; then export TORCH_CUDA_ARCH_LIST="8.0;9.0" @@ -17,6 +18,10 @@ if [[ "$DESIRED_CUDA" == *"13"* ]]; then export TORCH_NVCC_FLAGS="-compress-mode=size" # Bundle ptxas into the cu13 wheel, see https://github.com/pytorch/pytorch/issues/163801 export BUILD_BUNDLE_PTXAS=1 +======= +if [[ "$GPU_ARCH_VERSION" == *"12.9"* ]]; then + export TORCH_CUDA_ARCH_LIST="8.0;9.0;10.0;12.0" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fi SCRIPTPATH="$( cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )" @@ -30,7 +35,11 @@ cd / # on the mounted pytorch repo git config --global --add safe.directory /pytorch pip install -r /pytorch/requirements.txt +<<<<<<< HEAD pip install auditwheel==6.2.0 wheel +======= +pip install auditwheel==6.2.0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if [ "$DESIRED_CUDA" = "cpu" ]; then echo "BASE_CUDA_VERSION is not set. Building cpu wheel." #USE_PRIORITIZED_TEXT_FOR_LD for enable linker script optimization https://github.com/pytorch/pytorch/pull/121975/files @@ -38,6 +47,7 @@ if [ "$DESIRED_CUDA" = "cpu" ]; then else echo "BASE_CUDA_VERSION is set to: $DESIRED_CUDA" export USE_SYSTEM_NCCL=1 +<<<<<<< HEAD # Check if we should use NVIDIA libs from PyPI (similar to x86 build_cuda.sh logic) if [[ -z "$PYTORCH_EXTRA_INSTALL_REQUIREMENTS" ]]; then @@ -48,6 +58,8 @@ else export USE_NVIDIA_PYPI_LIBS=1 fi +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #USE_PRIORITIZED_TEXT_FOR_LD for enable linker script optimization https://github.com/pytorch/pytorch/pull/121975/files USE_PRIORITIZED_TEXT_FOR_LD=1 python /pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py --enable-mkldnn --enable-cuda fi diff --git a/.ci/aarch64_linux/aarch64_wheel_ci_build.py b/.ci/aarch64_linux/aarch64_wheel_ci_build.py index 1b6429fa8c06e..cc3a41d4d9227 100755 --- a/.ci/aarch64_linux/aarch64_wheel_ci_build.py +++ b/.ci/aarch64_linux/aarch64_wheel_ci_build.py @@ -69,6 +69,7 @@ def replace_tag(filename) -> None: f.writelines(lines) +<<<<<<< HEAD def patch_library_rpath( folder: str, lib_name: str, @@ -131,11 +132,14 @@ def copy_and_patch_library( patch_library_rpath(folder, lib_name, use_nvidia_pypi_libs, desired_cuda) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def package_cuda_wheel(wheel_path, desired_cuda) -> None: """ Package the cuda wheel libraries """ folder = os.path.dirname(wheel_path) +<<<<<<< HEAD os.mkdir(f"{folder}/tmp") os.system(f"unzip {wheel_path} -d {folder}/tmp") # Delete original wheel since it will be repackaged @@ -249,6 +253,57 @@ def package_cuda_wheel(wheel_path, desired_cuda) -> None: # Copy libraries to unzipped_folder/torch/lib for lib_path in libs_to_copy: copy_and_patch_library(lib_path, folder, use_nvidia_pypi_libs, desired_cuda) +======= + wheelname = os.path.basename(wheel_path) + os.mkdir(f"{folder}/tmp") + os.system(f"unzip {wheel_path} -d {folder}/tmp") + libs_to_copy = [ + "/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.12", + "/usr/local/cuda/lib64/libcudnn.so.9", + "/usr/local/cuda/lib64/libcublas.so.12", + "/usr/local/cuda/lib64/libcublasLt.so.12", + "/usr/local/cuda/lib64/libcudart.so.12", + "/usr/local/cuda/lib64/libcufft.so.11", + "/usr/local/cuda/lib64/libcusparse.so.12", + "/usr/local/cuda/lib64/libcusparseLt.so.0", + "/usr/local/cuda/lib64/libcusolver.so.11", + "/usr/local/cuda/lib64/libcurand.so.10", + "/usr/local/cuda/lib64/libnccl.so.2", + "/usr/local/cuda/lib64/libnvJitLink.so.12", + "/usr/local/cuda/lib64/libnvrtc.so.12", + "/usr/local/cuda/lib64/libcudnn_adv.so.9", + "/usr/local/cuda/lib64/libcudnn_cnn.so.9", + "/usr/local/cuda/lib64/libcudnn_graph.so.9", + "/usr/local/cuda/lib64/libcudnn_ops.so.9", + "/usr/local/cuda/lib64/libcudnn_engines_runtime_compiled.so.9", + "/usr/local/cuda/lib64/libcudnn_engines_precompiled.so.9", + "/usr/local/cuda/lib64/libcudnn_heuristic.so.9", + "/lib64/libgomp.so.1", + "/usr/lib64/libgfortran.so.5", + "/acl/build/libarm_compute.so", + "/acl/build/libarm_compute_graph.so", + "/usr/local/lib/libnvpl_lapack_lp64_gomp.so.0", + "/usr/local/lib/libnvpl_blas_lp64_gomp.so.0", + "/usr/local/lib/libnvpl_lapack_core.so.0", + "/usr/local/lib/libnvpl_blas_core.so.0", + ] + + if "129" in desired_cuda: + libs_to_copy += [ + "/usr/local/cuda/lib64/libnvrtc-builtins.so.12.9", + "/usr/local/cuda/lib64/libcufile.so.0", + "/usr/local/cuda/lib64/libcufile_rdma.so.1", + ] + + # Copy libraries to unzipped_folder/a/lib + for lib_path in libs_to_copy: + lib_name = os.path.basename(lib_path) + shutil.copy2(lib_path, f"{folder}/tmp/torch/lib/{lib_name}") + os.system( + f"cd {folder}/tmp/torch/lib/; " + f"patchelf --set-rpath '$ORIGIN' --force-rpath {folder}/tmp/torch/lib/{lib_name}" + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Make sure the wheel is tagged with manylinux_2_28 for f in os.scandir(f"{folder}/tmp/"): @@ -256,8 +311,19 @@ def package_cuda_wheel(wheel_path, desired_cuda) -> None: replace_tag(f"{f.path}/WHEEL") break +<<<<<<< HEAD os.system(f"wheel pack {folder}/tmp/ -d {folder}") os.system(f"rm -rf {folder}/tmp/") +======= + os.mkdir(f"{folder}/cuda_wheel") + os.system(f"cd {folder}/tmp/; zip -r {folder}/cuda_wheel/{wheelname} *") + shutil.move( + f"{folder}/cuda_wheel/{wheelname}", + f"{folder}/{wheelname}", + copy_function=shutil.copy2, + ) + os.system(f"rm -rf {folder}/tmp/ {folder}/cuda_wheel/") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def complete_wheel(folder: str) -> str: @@ -280,7 +346,18 @@ def complete_wheel(folder: str) -> str: f"/{folder}/dist/{repaired_wheel_name}", ) else: +<<<<<<< HEAD repaired_wheel_name = list_dir(f"/{folder}/dist")[0] +======= + repaired_wheel_name = wheel_name.replace( + "linux_aarch64", "manylinux_2_28_aarch64" + ) + print(f"Renaming {wheel_name} wheel to {repaired_wheel_name}") + os.rename( + f"/{folder}/dist/{wheel_name}", + f"/{folder}/dist/{repaired_wheel_name}", + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) print(f"Copying {repaired_wheel_name} to artifacts") shutil.copy2( @@ -320,6 +397,7 @@ def parse_arguments(): build_vars = "CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000 " # MAX_JOB=5 is not required for CPU backend (see commit 465d98b) if enable_cuda: +<<<<<<< HEAD build_vars += "MAX_JOBS=5 " # Handle PyPI NVIDIA libraries vs bundled libraries @@ -331,6 +409,9 @@ def parse_arguments(): else: print("Configuring build for bundled NVIDIA libraries") # Keep existing static linking approach - already configured above +======= + build_vars = "MAX_JOBS=5 " + build_vars +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) override_package_version = os.getenv("OVERRIDE_PACKAGE_VERSION") desired_cuda = os.getenv("DESIRED_CUDA") diff --git a/.ci/aarch64_linux/build_aarch64_wheel.py b/.ci/aarch64_linux/build_aarch64_wheel.py index 7a4715d330060..ea74c9152d2d0 100755 --- a/.ci/aarch64_linux/build_aarch64_wheel.py +++ b/.ci/aarch64_linux/build_aarch64_wheel.py @@ -438,7 +438,13 @@ def build_torchvision( ) build_vars += f"BUILD_VERSION={version}.dev{build_date}" elif build_version is not None: +<<<<<<< HEAD build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}" +======= + build_vars += ( + f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-')[0]}" + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if host.using_docker(): build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000" @@ -493,7 +499,13 @@ def build_torchdata( ) build_vars += f"BUILD_VERSION={version}.dev{build_date}" elif build_version is not None: +<<<<<<< HEAD build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}" +======= + build_vars += ( + f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-')[0]}" + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if host.using_docker(): build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000" @@ -549,7 +561,13 @@ def build_torchtext( ) build_vars += f"BUILD_VERSION={version}.dev{build_date}" elif build_version is not None: +<<<<<<< HEAD build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}" +======= + build_vars += ( + f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-')[0]}" + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if host.using_docker(): build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000" @@ -607,7 +625,13 @@ def build_torchaudio( ) build_vars += f"BUILD_VERSION={version}.dev{build_date}" elif build_version is not None: +<<<<<<< HEAD build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}" +======= + build_vars += ( + f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-')[0]}" + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if host.using_docker(): build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000" diff --git a/.ci/docker/README.md b/.ci/docker/README.md index 5a97a0a3c2d46..a795edf2c0b9b 100644 --- a/.ci/docker/README.md +++ b/.ci/docker/README.md @@ -36,6 +36,7 @@ See `build.sh` for valid build environments (it's the giant switch). # Set flags (see build.sh) and build image sudo bash -c 'TRITON=1 ./build.sh pytorch-linux-bionic-py3.8-gcc9 -t myimage:latest ``` +<<<<<<< HEAD ## [Guidance] Adding a New Base Docker Image @@ -137,3 +138,5 @@ If your new Docker image needs a library installed from a specific pinned commit The `docker-builds.yml` workflow pre-builds the Docker images whenever changes occur in the `.ci/docker/` directory. This includes the pinned commit updates. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.ci/docker/almalinux/Dockerfile b/.ci/docker/almalinux/Dockerfile index 481d21b96cfe9..2843fe9e1bd97 100644 --- a/.ci/docker/almalinux/Dockerfile +++ b/.ci/docker/almalinux/Dockerfile @@ -64,10 +64,13 @@ FROM cuda as cuda12.9 RUN bash ./install_cuda.sh 12.9 ENV DESIRED_CUDA=12.9 +<<<<<<< HEAD FROM cuda as cuda13.0 RUN bash ./install_cuda.sh 13.0 ENV DESIRED_CUDA=13.0 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) FROM ${ROCM_IMAGE} as rocm ENV PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201" ADD ./common/install_mkl.sh install_mkl.sh @@ -80,10 +83,17 @@ ADD ./common/install_mnist.sh install_mnist.sh RUN bash ./install_mnist.sh FROM base as all_cuda +<<<<<<< HEAD COPY --from=cuda12.6 /usr/local/cuda-12.6 /usr/local/cuda-12.6 COPY --from=cuda12.8 /usr/local/cuda-12.8 /usr/local/cuda-12.8 COPY --from=cuda12.9 /usr/local/cuda-12.9 /usr/local/cuda-12.9 COPY --from=cuda13.0 /usr/local/cuda-13.0 /usr/local/cuda-13.0 +======= +COPY --from=cuda11.8 /usr/local/cuda-11.8 /usr/local/cuda-11.8 +COPY --from=cuda12.6 /usr/local/cuda-12.6 /usr/local/cuda-12.6 +COPY --from=cuda12.8 /usr/local/cuda-12.8 /usr/local/cuda-12.8 +COPY --from=cuda12.9 /usr/local/cuda-12.9 /usr/local/cuda-12.9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Final step FROM ${BASE_TARGET} as final diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index 8672fae2bbdd1..1deb6a395ac47 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -76,6 +76,7 @@ elif [[ "$image" == *cuda*linter* ]]; then elif [[ "$image" == *linter* ]]; then # Use a separate Dockerfile for linter to keep a small image size DOCKERFILE="linter/Dockerfile" +<<<<<<< HEAD elif [[ "$image" == *riscv* ]]; then # Use RISC-V specific Dockerfile DOCKERFILE="ubuntu-cross-riscv/Dockerfile" @@ -83,6 +84,12 @@ fi _UCX_COMMIT=7836b165abdbe468a2f607e7254011c07d788152 _UCC_COMMIT=430e241bf5d38cbc73fc7a6b89155397232e3f96 +======= +fi + +_UCX_COMMIT=7bb2722ff2187a0cad557ae4a6afa090569f83fb +_UCC_COMMIT=20eae37090a4ce1b32bcce6144ccad0b49943e0b +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if [[ "$image" == *rocm* ]]; then _UCX_COMMIT=cc312eaa4655c0cc5c2bcd796db938f90563bcf6 _UCC_COMMIT=0c0fc21559835044ab107199e334f7157d6a0d3d @@ -94,6 +101,7 @@ tag=$(echo $image | awk -F':' '{print $2}') # configuration, so we hardcode everything here rather than do it # from scratch case "$tag" in +<<<<<<< HEAD pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11) CUDA_VERSION=12.4 ANACONDA_PYTHON_VERSION=3.10 @@ -116,6 +124,11 @@ case "$tag" in ;; pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11) CUDA_VERSION=13.0.0 +======= + pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11) + CUDA_VERSION=12.8.1 + CUDNN_VERSION=9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=11 VISION=yes @@ -126,6 +139,10 @@ case "$tag" in ;; pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks) CUDA_VERSION=12.8.1 +<<<<<<< HEAD +======= + CUDNN_VERSION=9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 VISION=yes @@ -135,18 +152,92 @@ case "$tag" in TRITON=yes INDUCTOR_BENCHMARKS=yes ;; +<<<<<<< HEAD pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc11-vllm) CUDA_VERSION=12.8.1 ANACONDA_PYTHON_VERSION=3.12 GCC_VERSION=11 +======= + pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc9-inductor-benchmarks) + CUDA_VERSION=12.8.1 + CUDNN_VERSION=9 + ANACONDA_PYTHON_VERSION=3.12 + GCC_VERSION=9 + VISION=yes + KATEX=yes + UCX_COMMIT=${_UCX_COMMIT} + UCC_COMMIT=${_UCC_COMMIT} + TRITON=yes + INDUCTOR_BENCHMARKS=yes + ;; + pytorch-linux-jammy-cuda12.8-cudnn9-py3.13-gcc9-inductor-benchmarks) + CUDA_VERSION=12.8.1 + CUDNN_VERSION=9 + ANACONDA_PYTHON_VERSION=3.13 + GCC_VERSION=9 + VISION=yes + KATEX=yes + UCX_COMMIT=${_UCX_COMMIT} + UCC_COMMIT=${_UCC_COMMIT} + TRITON=yes + INDUCTOR_BENCHMARKS=yes + ;; + pytorch-linux-jammy-cuda12.6-cudnn9-py3-gcc9) + CUDA_VERSION=12.6.3 + CUDNN_VERSION=9 + ANACONDA_PYTHON_VERSION=3.10 + GCC_VERSION=9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) VISION=yes KATEX=yes UCX_COMMIT=${_UCX_COMMIT} UCC_COMMIT=${_UCC_COMMIT} TRITON=yes ;; +<<<<<<< HEAD pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9) CUDA_VERSION=12.8.1 +======= + pytorch-linux-jammy-cuda12.6-cudnn9-py3-gcc9-inductor-benchmarks) + CUDA_VERSION=12.6 + CUDNN_VERSION=9 + ANACONDA_PYTHON_VERSION=3.10 + GCC_VERSION=9 + VISION=yes + KATEX=yes + UCX_COMMIT=${_UCX_COMMIT} + UCC_COMMIT=${_UCC_COMMIT} + TRITON=yes + INDUCTOR_BENCHMARKS=yes + ;; + pytorch-linux-jammy-cuda12.6-cudnn9-py3.12-gcc9-inductor-benchmarks) + CUDA_VERSION=12.6 + CUDNN_VERSION=9 + ANACONDA_PYTHON_VERSION=3.12 + GCC_VERSION=9 + VISION=yes + KATEX=yes + UCX_COMMIT=${_UCX_COMMIT} + UCC_COMMIT=${_UCC_COMMIT} + TRITON=yes + INDUCTOR_BENCHMARKS=yes + ;; + pytorch-linux-jammy-cuda12.6-cudnn9-py3.13-gcc9-inductor-benchmarks) + CUDA_VERSION=12.6 + CUDNN_VERSION=9 + ANACONDA_PYTHON_VERSION=3.13 + GCC_VERSION=9 + VISION=yes + KATEX=yes + UCX_COMMIT=${_UCX_COMMIT} + UCC_COMMIT=${_UCC_COMMIT} + TRITON=yes + INDUCTOR_BENCHMARKS=yes + ;; + pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9) + CUDA_VERSION=12.8.1 + CUDNN_VERSION=9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 VISION=yes @@ -156,23 +247,61 @@ case "$tag" in TRITON=yes ;; pytorch-linux-jammy-py3-clang12-onnx) +<<<<<<< HEAD ANACONDA_PYTHON_VERSION=3.10 +======= + ANACONDA_PYTHON_VERSION=3.9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CLANG_VERSION=12 VISION=yes ONNX=yes ;; +<<<<<<< HEAD pytorch-linux-jammy-py3.10-clang12) ANACONDA_PYTHON_VERSION=3.10 +======= + pytorch-linux-jammy-py3.9-clang12) + ANACONDA_PYTHON_VERSION=3.9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CLANG_VERSION=12 VISION=yes TRITON=yes ;; +<<<<<<< HEAD pytorch-linux-jammy-rocm-n-py3 | pytorch-linux-jammy-rocm-n-py3-benchmarks | pytorch-linux-noble-rocm-n-py3) if [[ $tag =~ "jammy" ]]; then ANACONDA_PYTHON_VERSION=3.10 else ANACONDA_PYTHON_VERSION=3.12 fi +======= + pytorch-linux-jammy-py3.11-clang12) + ANACONDA_PYTHON_VERSION=3.11 + CLANG_VERSION=12 + VISION=yes + TRITON=yes + ;; + pytorch-linux-jammy-py3.9-gcc9) + ANACONDA_PYTHON_VERSION=3.9 + GCC_VERSION=9 + VISION=yes + TRITON=yes + ;; + pytorch-linux-jammy-rocm-n-1-py3) + ANACONDA_PYTHON_VERSION=3.10 + GCC_VERSION=11 + VISION=yes + ROCM_VERSION=6.3 + NINJA_VERSION=1.9.0 + TRITON=yes + KATEX=yes + UCX_COMMIT=${_UCX_COMMIT} + UCC_COMMIT=${_UCC_COMMIT} + INDUCTOR_BENCHMARKS=yes + ;; + pytorch-linux-jammy-rocm-n-py3) + ANACONDA_PYTHON_VERSION=3.10 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GCC_VERSION=11 VISION=yes ROCM_VERSION=6.4 @@ -181,6 +310,7 @@ case "$tag" in KATEX=yes UCX_COMMIT=${_UCX_COMMIT} UCC_COMMIT=${_UCC_COMMIT} +<<<<<<< HEAD if [[ $tag =~ "benchmarks" ]]; then INDUCTOR_BENCHMARKS=yes fi @@ -199,12 +329,27 @@ case "$tag" in ;; pytorch-linux-jammy-xpu-n-1-py3) ANACONDA_PYTHON_VERSION=3.10 +======= + INDUCTOR_BENCHMARKS=yes + ;; + pytorch-linux-jammy-xpu-2025.0-py3) + ANACONDA_PYTHON_VERSION=3.9 + GCC_VERSION=11 + VISION=yes + XPU_VERSION=2025.0 + NINJA_VERSION=1.9.0 + TRITON=yes + ;; + pytorch-linux-jammy-xpu-2025.1-py3) + ANACONDA_PYTHON_VERSION=3.9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GCC_VERSION=11 VISION=yes XPU_VERSION=2025.1 NINJA_VERSION=1.9.0 TRITON=yes ;; +<<<<<<< HEAD pytorch-linux-jammy-xpu-n-py3) ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=11 @@ -215,6 +360,10 @@ case "$tag" in ;; pytorch-linux-jammy-py3-gcc11-inductor-benchmarks) ANACONDA_PYTHON_VERSION=3.10 +======= + pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks) + ANACONDA_PYTHON_VERSION=3.9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GCC_VERSION=11 VISION=yes KATEX=yes @@ -222,20 +371,46 @@ case "$tag" in DOCS=yes INDUCTOR_BENCHMARKS=yes ;; +<<<<<<< HEAD pytorch-linux-jammy-cuda12.8-cudnn9-py3.10-clang12) ANACONDA_PYTHON_VERSION=3.10 CUDA_VERSION=12.8.1 +======= + pytorch-linux-jammy-cuda12.8-cudnn9-py3.9-clang12) + ANACONDA_PYTHON_VERSION=3.9 + CUDA_VERSION=12.8.1 + CUDNN_VERSION=9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CLANG_VERSION=12 VISION=yes TRITON=yes ;; +<<<<<<< HEAD +======= + pytorch-linux-jammy-py3-clang12-asan) + ANACONDA_PYTHON_VERSION=3.9 + CLANG_VERSION=12 + VISION=yes + TRITON=yes + ;; + pytorch-linux-jammy-py3-clang15-asan) + ANACONDA_PYTHON_VERSION=3.10 + CLANG_VERSION=15 + VISION=yes + ;; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pytorch-linux-jammy-py3-clang18-asan) ANACONDA_PYTHON_VERSION=3.10 CLANG_VERSION=18 VISION=yes ;; +<<<<<<< HEAD pytorch-linux-jammy-py3.10-gcc11) ANACONDA_PYTHON_VERSION=3.10 +======= + pytorch-linux-jammy-py3.9-gcc11) + ANACONDA_PYTHON_VERSION=3.9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GCC_VERSION=11 VISION=yes KATEX=yes @@ -262,10 +437,20 @@ case "$tag" in TRITON_CPU=yes ;; pytorch-linux-jammy-linter) +<<<<<<< HEAD PYTHON_VERSION=3.10 ;; pytorch-linux-jammy-cuda12.8-cudnn9-py3.10-linter) PYTHON_VERSION=3.10 +======= + # TODO: Use 3.9 here because of this issue https://github.com/python/mypy/issues/13627. + # We will need to update mypy version eventually, but that's for another day. The task + # would be to upgrade mypy to 1.0.0 with Python 3.11 + PYTHON_VERSION=3.9 + ;; + pytorch-linux-jammy-cuda12.8-cudnn9-py3.9-linter) + PYTHON_VERSION=3.9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CUDA_VERSION=12.8.1 ;; pytorch-linux-jammy-aarch64-py3.10-gcc11) @@ -273,6 +458,10 @@ case "$tag" in GCC_VERSION=11 ACL=yes VISION=yes +<<<<<<< HEAD +======= + CONDA_CMAKE=yes +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) OPENBLAS=yes # snadampal: skipping llvm src build install because the current version # from pytorch/llvm:9.0.1 is x86 specific @@ -283,15 +472,22 @@ case "$tag" in GCC_VERSION=11 ACL=yes VISION=yes +<<<<<<< HEAD +======= + CONDA_CMAKE=yes +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) OPENBLAS=yes # snadampal: skipping llvm src build install because the current version # from pytorch/llvm:9.0.1 is x86 specific SKIP_LLVM_SRC_BUILD_INSTALL=yes INDUCTOR_BENCHMARKS=yes ;; +<<<<<<< HEAD pytorch-linux-noble-riscv64-py3.12-gcc14) GCC_VERSION=14 ;; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) *) # Catch-all for builds that are not hardcoded. VISION=yes @@ -301,6 +497,10 @@ case "$tag" in fi if [[ "$image" == *cuda* ]]; then extract_version_from_image_name cuda CUDA_VERSION +<<<<<<< HEAD +======= + extract_version_from_image_name cudnn CUDNN_VERSION +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fi if [[ "$image" == *rocm* ]]; then extract_version_from_image_name rocm ROCM_VERSION @@ -352,6 +552,12 @@ docker build \ --build-arg "PYTHON_VERSION=${PYTHON_VERSION}" \ --build-arg "GCC_VERSION=${GCC_VERSION}" \ --build-arg "CUDA_VERSION=${CUDA_VERSION}" \ +<<<<<<< HEAD +======= + --build-arg "CUDNN_VERSION=${CUDNN_VERSION}" \ + --build-arg "TENSORRT_VERSION=${TENSORRT_VERSION}" \ + --build-arg "GRADLE_VERSION=${GRADLE_VERSION}" \ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) --build-arg "NINJA_VERSION=${NINJA_VERSION:-}" \ --build-arg "KATEX=${KATEX:-}" \ --build-arg "ROCM_VERSION=${ROCM_VERSION:-}" \ @@ -412,6 +618,7 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then fi if [ -n "$GCC_VERSION" ]; then +<<<<<<< HEAD if [[ "$image" == *riscv* ]]; then # Check RISC-V cross-compilation toolchain version if !(drun riscv64-linux-gnu-gcc-${GCC_VERSION} --version 2>&1 | grep -q " $GCC_VERSION\\W"); then @@ -420,6 +627,9 @@ if [ -n "$GCC_VERSION" ]; then exit 1 fi elif !(drun gcc --version 2>&1 | grep -q " $GCC_VERSION\\W"); then +======= + if !(drun gcc --version 2>&1 | grep -q " $GCC_VERSION\\W"); then +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) echo "GCC_VERSION=$GCC_VERSION, but:" drun gcc --version exit 1 diff --git a/.ci/docker/ci_commit_pins/huggingface.txt b/.ci/docker/ci_commit_pins/huggingface.txt new file mode 100644 index 0000000000000..f00d6ca4f9ca7 --- /dev/null +++ b/.ci/docker/ci_commit_pins/huggingface.txt @@ -0,0 +1 @@ +243e186efbf7fb93328dd6b34927a4e8c8f24395 diff --git a/.ci/docker/ci_commit_pins/nccl-cu12.txt b/.ci/docker/ci_commit_pins/nccl-cu12.txt index d099a6b91b76a..57a4f51b2dd1e 100644 --- a/.ci/docker/ci_commit_pins/nccl-cu12.txt +++ b/.ci/docker/ci_commit_pins/nccl-cu12.txt @@ -1 +1,5 @@ +<<<<<<< HEAD v2.27.5-1 +======= +v2.27.3-1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.ci/docker/ci_commit_pins/triton-xpu.txt b/.ci/docker/ci_commit_pins/triton-xpu.txt index b03606f6defc1..6abd4a388f1c2 100644 --- a/.ci/docker/ci_commit_pins/triton-xpu.txt +++ b/.ci/docker/ci_commit_pins/triton-xpu.txt @@ -1 +1,5 @@ +<<<<<<< HEAD 1b0418a9a454b2b93ab8d71f40e59d2297157fae +======= +ae324eeac8e102a2b40370e341460f3791353398 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt index a1e9df4725c57..8c295fa297546 100644 --- a/.ci/docker/ci_commit_pins/triton.txt +++ b/.ci/docker/ci_commit_pins/triton.txt @@ -1 +1,5 @@ +<<<<<<< HEAD d704bc6e69c1a588c8edd3cbb67505d554ed65f6 +======= +21876a4bbaf371bcb83df8e6ee4f43a92f524dfe +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.ci/docker/common/install_conda.sh b/.ci/docker/common/install_conda.sh index 481de54a50f2c..ceebbc9d8bb65 100755 --- a/.ci/docker/common/install_conda.sh +++ b/.ci/docker/common/install_conda.sh @@ -65,10 +65,17 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then fi # Install PyTorch conda deps, as per https://github.com/pytorch/pytorch README +<<<<<<< HEAD if [[ $(uname -m) != "aarch64" ]]; then pip_install mkl==2024.2.0 pip_install mkl-static==2024.2.0 pip_install mkl-include==2024.2.0 +======= + if [[ $(uname -m) == "aarch64" ]]; then + conda_install "openblas==0.3.29=*openmp*" + else + conda_install "mkl=2021.4.0 mkl-include=2021.4.0" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fi # Install llvm-8 as it is required to compile llvmlite-0.30.0 from source diff --git a/.ci/docker/common/install_cpython.sh b/.ci/docker/common/install_cpython.sh index 692edd0b898f1..92441b50102e0 100755 --- a/.ci/docker/common/install_cpython.sh +++ b/.ci/docker/common/install_cpython.sh @@ -3,10 +3,18 @@ set -uex -o pipefail PYTHON_DOWNLOAD_URL=https://www.python.org/ftp/python +<<<<<<< HEAD GET_PIP_URL=https://bootstrap.pypa.io/get-pip.py # Python versions to be installed in /opt/$VERSION_NO CPYTHON_VERSIONS=${CPYTHON_VERSIONS:-"3.9.0 3.10.1 3.11.0 3.12.0 3.13.0 3.13.0t 3.14.0 3.14.0t"} +======= +PYTHON_DOWNLOAD_GITHUB_BRANCH=https://github.com/python/cpython/archive/refs/heads # @lint-ignore +GET_PIP_URL=https://bootstrap.pypa.io/get-pip.py + +# Python versions to be installed in /opt/$VERSION_NO +CPYTHON_VERSIONS=${CPYTHON_VERSIONS:-"3.9.0 3.10.1 3.11.0 3.12.0 3.13.0 3.13.0t"} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) function check_var { if [ -z "$1" ]; then @@ -23,8 +31,14 @@ function do_cpython_build { tar -xzf Python-$py_ver.tgz local additional_flags="" +<<<<<<< HEAD if [[ "$py_ver" == *"t" ]]; then additional_flags=" --disable-gil" +======= + if [ "$py_ver" == "3.13.0t" ]; then + additional_flags=" --disable-gil" + mv cpython-3.13/ cpython-3.13t/ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fi pushd $py_folder @@ -66,15 +80,21 @@ function do_cpython_build { ln -s pip3 ${prefix}/bin/pip fi # install setuptools since python 3.12 is required to use distutils +<<<<<<< HEAD # packaging is needed to create symlink since wheel no longer provides needed information ${prefix}/bin/pip install packaging==25.0 wheel==0.45.1 setuptools==80.9.0 local abi_tag=$(${prefix}/bin/python -c "from packaging.tags import interpreter_name, interpreter_version; import sysconfig ; from sysconfig import get_config_var; print('{0}{1}-{0}{1}{2}'.format(interpreter_name(), interpreter_version(), 't' if sysconfig.get_config_var('Py_GIL_DISABLED') else ''))") +======= + ${prefix}/bin/pip install wheel==0.34.2 setuptools==68.2.2 + local abi_tag=$(${prefix}/bin/python -c "from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag; print('{0}{1}-{2}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag()))") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ln -sf ${prefix} /opt/python/${abi_tag} } function build_cpython { local py_ver=$1 check_var $py_ver +<<<<<<< HEAD local py_suffix=$py_ver local py_folder=$py_ver @@ -89,6 +109,26 @@ function build_cpython { fi wget -q $PYTHON_DOWNLOAD_URL/$py_folder/Python-$py_suffix.tgz -O Python-$py_ver.tgz do_cpython_build $py_ver Python-$py_suffix +======= + check_var $PYTHON_DOWNLOAD_URL + local py_ver_folder=$py_ver + + if [ "$py_ver" = "3.13.0t" ]; then + PY_VER_SHORT="3.13" + PYT_VER_SHORT="3.13t" + check_var $PYTHON_DOWNLOAD_GITHUB_BRANCH + wget $PYTHON_DOWNLOAD_GITHUB_BRANCH/$PY_VER_SHORT.tar.gz -O Python-$py_ver.tgz + do_cpython_build $py_ver cpython-$PYT_VER_SHORT + elif [ "$py_ver" = "3.13.0" ]; then + PY_VER_SHORT="3.13" + check_var $PYTHON_DOWNLOAD_GITHUB_BRANCH + wget $PYTHON_DOWNLOAD_GITHUB_BRANCH/$PY_VER_SHORT.tar.gz -O Python-$py_ver.tgz + do_cpython_build $py_ver cpython-$PY_VER_SHORT + else + wget -q $PYTHON_DOWNLOAD_URL/$py_ver_folder/Python-$py_ver.tgz + do_cpython_build $py_ver Python-$py_ver + fi +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) rm -f Python-$py_ver.tgz } diff --git a/.ci/docker/common/install_cuda.sh b/.ci/docker/common/install_cuda.sh index c6808ea4a7a26..9aaa7459116ab 100644 --- a/.ci/docker/common/install_cuda.sh +++ b/.ci/docker/common/install_cuda.sh @@ -10,8 +10,11 @@ else arch_path='sbsa' fi +<<<<<<< HEAD NVSHMEM_VERSION=3.3.24 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) function install_cuda { version=$1 runfile=$2 @@ -42,6 +45,7 @@ function install_cudnn { rm -rf tmp_cudnn } +<<<<<<< HEAD function install_nvshmem { cuda_major_version=$1 # e.g. "12" nvshmem_version=$2 # e.g. "3.3.9" @@ -97,12 +101,20 @@ function install_124 { function install_126 { CUDNN_VERSION=9.10.2.21 echo "Installing CUDA 12.6.3 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1" +======= +function install_126 { + CUDNN_VERSION=9.10.2.21 + echo "Installing CUDA 12.6.3 and cuDNN ${CUDNN_VERSION} and NCCL and cuSparseLt-0.7.1" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) install_cuda 12.6.3 cuda_12.6.3_560.35.05_linux install_cudnn 12 $CUDNN_VERSION +<<<<<<< HEAD install_nvshmem 12 $NVSHMEM_VERSION +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CUDA_VERSION=12.6 bash install_nccl.sh CUDA_VERSION=12.6 bash install_cusparselt.sh @@ -112,15 +124,22 @@ function install_126 { function install_129 { CUDNN_VERSION=9.10.2.21 +<<<<<<< HEAD echo "Installing CUDA 12.9.1 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1" +======= + echo "Installing CUDA 12.9.1 and cuDNN ${CUDNN_VERSION} and NCCL and cuSparseLt-0.7.1" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # install CUDA 12.9.1 in the same container install_cuda 12.9.1 cuda_12.9.1_575.57.08_linux # cuDNN license: https://developer.nvidia.com/cudnn/license_agreement install_cudnn 12 $CUDNN_VERSION +<<<<<<< HEAD install_nvshmem 12 $NVSHMEM_VERSION +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CUDA_VERSION=12.9 bash install_nccl.sh CUDA_VERSION=12.9 bash install_cusparselt.sh @@ -128,17 +147,60 @@ function install_129 { ldconfig } +<<<<<<< HEAD function install_128 { CUDNN_VERSION=9.8.0.87 echo "Installing CUDA 12.8.1 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1" +======= +function prune_126 { + echo "Pruning CUDA 12.6" + ##################################################################################### + # CUDA 12.6 prune static libs + ##################################################################################### + export NVPRUNE="/usr/local/cuda-12.6/bin/nvprune" + export CUDA_LIB_DIR="/usr/local/cuda-12.6/lib64" + + export GENCODE="-gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -gencode arch=compute_90,code=sm_90" + export GENCODE_CUDNN="-gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -gencode arch=compute_90,code=sm_90" + + if [[ -n "$OVERRIDE_GENCODE" ]]; then + export GENCODE=$OVERRIDE_GENCODE + fi + if [[ -n "$OVERRIDE_GENCODE_CUDNN" ]]; then + export GENCODE_CUDNN=$OVERRIDE_GENCODE_CUDNN + fi + + # all CUDA libs except CuDNN and CuBLAS + ls $CUDA_LIB_DIR/ | grep "\.a" | grep -v "culibos" | grep -v "cudart" | grep -v "cudnn" | grep -v "cublas" | grep -v "metis" \ + | xargs -I {} bash -c \ + "echo {} && $NVPRUNE $GENCODE $CUDA_LIB_DIR/{} -o $CUDA_LIB_DIR/{}" + + # prune CuDNN and CuBLAS + $NVPRUNE $GENCODE_CUDNN $CUDA_LIB_DIR/libcublas_static.a -o $CUDA_LIB_DIR/libcublas_static.a + $NVPRUNE $GENCODE_CUDNN $CUDA_LIB_DIR/libcublasLt_static.a -o $CUDA_LIB_DIR/libcublasLt_static.a + + ##################################################################################### + # CUDA 12.6 prune visual tools + ##################################################################################### + export CUDA_BASE="/usr/local/cuda-12.6/" + rm -rf $CUDA_BASE/libnvvp $CUDA_BASE/nsightee_plugins $CUDA_BASE/nsight-compute-2024.3.2 $CUDA_BASE/nsight-systems-2024.5.1/ +} + +function install_128 { + CUDNN_VERSION=9.8.0.87 + echo "Installing CUDA 12.8.1 and cuDNN ${CUDNN_VERSION} and NCCL and cuSparseLt-0.7.1" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # install CUDA 12.8.1 in the same container install_cuda 12.8.1 cuda_12.8.1_570.124.06_linux # cuDNN license: https://developer.nvidia.com/cudnn/license_agreement install_cudnn 12 $CUDNN_VERSION +<<<<<<< HEAD install_nvshmem 12 $NVSHMEM_VERSION +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CUDA_VERSION=12.8 bash install_nccl.sh CUDA_VERSION=12.8 bash install_cusparselt.sh @@ -146,6 +208,7 @@ function install_128 { ldconfig } +<<<<<<< HEAD function install_130 { CUDNN_VERSION=9.13.0.50 echo "Installing CUDA 13.0 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1" @@ -164,20 +227,29 @@ function install_130 { ldconfig } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # idiomatic parameter and option handling in sh while test $# -gt 0 do case "$1" in +<<<<<<< HEAD 12.4) install_124; ;; 12.6|12.6.*) install_126; +======= + 12.6|12.6.*) install_126; prune_126 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ;; 12.8|12.8.*) install_128; ;; 12.9|12.9.*) install_129; ;; +<<<<<<< HEAD 13.0|13.0.*) install_130; ;; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) *) echo "bad argument $1"; exit 1 ;; esac diff --git a/.ci/docker/common/install_cudnn.sh b/.ci/docker/common/install_cudnn.sh new file mode 100644 index 0000000000000..7ee5e73226cb6 --- /dev/null +++ b/.ci/docker/common/install_cudnn.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +if [[ -n "${CUDNN_VERSION}" ]]; then + # cuDNN license: https://developer.nvidia.com/cudnn/license_agreement + mkdir tmp_cudnn + pushd tmp_cudnn + if [[ ${CUDA_VERSION:0:4} == "12.9" || ${CUDA_VERSION:0:4} == "12.8" ]]; then + CUDNN_NAME="cudnn-linux-x86_64-9.10.2.21_cuda12-archive" + elif [[ ${CUDA_VERSION:0:4} == "12.6" ]]; then + CUDNN_NAME="cudnn-linux-x86_64-9.10.2.21_cuda12-archive" + elif [[ ${CUDA_VERSION:0:2} == "11" ]]; then + CUDNN_NAME="cudnn-linux-x86_64-9.1.0.70_cuda11-archive" + else + print "Unsupported CUDA version ${CUDA_VERSION}" + exit 1 + fi + curl --retry 3 -OLs https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64/${CUDNN_NAME}.tar.xz + tar xf ${CUDNN_NAME}.tar.xz + cp -a ${CUDNN_NAME}/include/* /usr/local/cuda/include/ + cp -a ${CUDNN_NAME}/lib/* /usr/local/cuda/lib64/ + popd + rm -rf tmp_cudnn + ldconfig +fi diff --git a/.ci/docker/common/install_cusparselt.sh b/.ci/docker/common/install_cusparselt.sh index b532c086371f1..3443da6482a1e 100644 --- a/.ci/docker/common/install_cusparselt.sh +++ b/.ci/docker/common/install_cusparselt.sh @@ -5,6 +5,7 @@ set -ex # cuSPARSELt license: https://docs.nvidia.com/cuda/cusparselt/license.html mkdir tmp_cusparselt && cd tmp_cusparselt +<<<<<<< HEAD if [[ ${CUDA_VERSION:0:4} =~ "13" ]]; then arch_path='sbsa' export TARGETARCH=${TARGETARCH:-$(uname -m)} @@ -14,6 +15,9 @@ if [[ ${CUDA_VERSION:0:4} =~ "13" ]]; then CUSPARSELT_NAME="libcusparse_lt-linux-${arch_path}-0.8.0.4_cuda13-archive" curl --retry 3 -OLs https://developer.download.nvidia.com/compute/cusparselt/redist/libcusparse_lt/linux-${arch_path}/${CUSPARSELT_NAME}.tar.xz elif [[ ${CUDA_VERSION:0:4} =~ ^12\.[5-9]$ ]]; then +======= +if [[ ${CUDA_VERSION:0:4} =~ ^12\.[5-9]$ ]]; then +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) arch_path='sbsa' export TARGETARCH=${TARGETARCH:-$(uname -m)} if [ ${TARGETARCH} = 'amd64' ] || [ "${TARGETARCH}" = 'x86_64' ]; then @@ -21,6 +25,7 @@ elif [[ ${CUDA_VERSION:0:4} =~ ^12\.[5-9]$ ]]; then fi CUSPARSELT_NAME="libcusparse_lt-linux-${arch_path}-0.7.1.0-archive" curl --retry 3 -OLs https://developer.download.nvidia.com/compute/cusparselt/redist/libcusparse_lt/linux-${arch_path}/${CUSPARSELT_NAME}.tar.xz +<<<<<<< HEAD elif [[ ${CUDA_VERSION:0:4} == "12.4" ]]; then arch_path='sbsa' export TARGETARCH=${TARGETARCH:-$(uname -m)} @@ -29,6 +34,8 @@ elif [[ ${CUDA_VERSION:0:4} == "12.4" ]]; then fi CUSPARSELT_NAME="libcusparse_lt-linux-${arch_path}-0.6.2.3-archive" curl --retry 3 -OLs https://developer.download.nvidia.com/compute/cusparselt/redist/libcusparse_lt/linux-${arch_path}/${CUSPARSELT_NAME}.tar.xz +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else echo "Not sure which libcusparselt version to install for this ${CUDA_VERSION}" fi diff --git a/.ci/docker/common/install_inductor_benchmark_deps.sh b/.ci/docker/common/install_inductor_benchmark_deps.sh index 81467d87f5140..c8ac925d402ad 100644 --- a/.ci/docker/common/install_inductor_benchmark_deps.sh +++ b/.ci/docker/common/install_inductor_benchmark_deps.sh @@ -5,7 +5,13 @@ set -ex source "$(dirname "${BASH_SOURCE[0]}")/common_utils.sh" function install_huggingface() { +<<<<<<< HEAD pip_install -r huggingface-requirements.txt +======= + local version + commit=$(get_pinned_commit huggingface) + pip_install "git+https://github.com/huggingface/transformers@${commit}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } function install_timm() { @@ -13,6 +19,7 @@ function install_timm() { commit=$(get_pinned_commit timm) pip_install "git+https://github.com/huggingface/pytorch-image-models@${commit}" +<<<<<<< HEAD } function install_torchbench() { @@ -30,10 +37,15 @@ function install_torchbench() { chown -R jenkins torchbench chown -R jenkins /opt/conda +======= + # Clean up + conda_run pip uninstall -y torch torchvision triton +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } # Pango is needed for weasyprint which is needed for doctr conda_install pango +<<<<<<< HEAD # Stable packages are ok here, just to satisfy TorchBench check pip_install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128 @@ -44,3 +56,7 @@ install_timm # Clean up conda_run pip uninstall -y torch torchvision torchaudio triton torchao +======= +install_huggingface +install_timm +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.ci/docker/common/install_nccl.sh b/.ci/docker/common/install_nccl.sh index 58a8e0b4e49c1..ea0cdfc2bf703 100644 --- a/.ci/docker/common/install_nccl.sh +++ b/.ci/docker/common/install_nccl.sh @@ -7,8 +7,11 @@ if [[ ${CUDA_VERSION:0:2} == "11" ]]; then NCCL_VERSION=$(cat ci_commit_pins/nccl-cu11.txt) elif [[ ${CUDA_VERSION:0:2} == "12" ]]; then NCCL_VERSION=$(cat ci_commit_pins/nccl-cu12.txt) +<<<<<<< HEAD elif [[ ${CUDA_VERSION:0:2} == "13" ]]; then NCCL_VERSION=$(cat ci_commit_pins/nccl-cu13.txt) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else echo "Unexpected CUDA_VERSION ${CUDA_VERSION}" exit 1 diff --git a/.ci/docker/common/install_onnx.sh b/.ci/docker/common/install_onnx.sh index 9f23feb5adfaf..d82012d0db1f1 100755 --- a/.ci/docker/common/install_onnx.sh +++ b/.ci/docker/common/install_onnx.sh @@ -19,8 +19,13 @@ pip_install \ transformers==4.36.2 pip_install coloredlogs packaging +<<<<<<< HEAD pip_install onnxruntime==1.22.1 pip_install onnxscript==0.4.0 +======= +pip_install onnxruntime==1.18.1 +pip_install onnxscript==0.3.1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Cache the transformers model to be used later by ONNX tests. We need to run the transformers # package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/ diff --git a/.ci/docker/common/install_rocm.sh b/.ci/docker/common/install_rocm.sh index 5d355276def7c..6162418bb10d3 100644 --- a/.ci/docker/common/install_rocm.sh +++ b/.ci/docker/common/install_rocm.sh @@ -30,6 +30,7 @@ EOF # we want the patch version of 6.4 instead if [[ $(ver $ROCM_VERSION) -eq $(ver 6.4) ]]; then +<<<<<<< HEAD ROCM_VERSION="${ROCM_VERSION}.2" fi @@ -41,14 +42,25 @@ EOF if [[ $(ver "$ROCM_VERSION") -eq $(ver 7.0) ]]; then rocm_baseurl="https://repo.radeon.com/rocm/apt/7.0_alpha2" amdgpu_baseurl="https://repo.radeon.com/amdgpu/30.10_alpha2/ubuntu" +======= + ROCM_VERSION="${ROCM_VERSION}.1" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fi # Add amdgpu repository UBUNTU_VERSION_NAME=`cat /etc/os-release | grep UBUNTU_CODENAME | awk -F= '{print $2}'` +<<<<<<< HEAD echo "deb [arch=amd64] ${amdgpu_baseurl} ${UBUNTU_VERSION_NAME} main" > /etc/apt/sources.list.d/amdgpu.list # Add rocm repository wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - +======= + echo "deb [arch=amd64] https://repo.radeon.com/amdgpu/${ROCM_VERSION}/ubuntu ${UBUNTU_VERSION_NAME} main" > /etc/apt/sources.list.d/amdgpu.list + + # Add rocm repository + wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - + local rocm_baseurl="http://repo.radeon.com/rocm/apt/${ROCM_VERSION}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) echo "deb [arch=amd64] ${rocm_baseurl} ${UBUNTU_VERSION_NAME} main" > /etc/apt/sources.list.d/rocm.list apt-get update --allow-insecure-repositories @@ -82,6 +94,7 @@ EOF done # ROCm 6.3 had a regression where initializing static code objects had significant overhead +<<<<<<< HEAD # CI no longer builds for ROCm 6.3, but # ROCm 6.4 did not yet fix the regression, also HIP branch names are different if [[ $(ver $ROCM_VERSION) -ge $(ver 6.4) ]] && [[ $(ver $ROCM_VERSION) -lt $(ver 7.0) ]]; then @@ -109,6 +122,31 @@ EOF cmake .. -DPython3_EXECUTABLE=/opt/conda/envs/py_${ANACONDA_PYTHON_VERSION}/bin/python3 -DCLR_BUILD_HIP=ON -DHIP_COMMON_DIR=$HIP_COMMON_DIR make -j cp hipamd/lib/libamdhip64.so.6.4.* /opt/rocm/lib/libamdhip64.so.6.4.* +======= + # ROCm 6.4 did not yet fix the regression, also HIP branch names are different + if [[ $(ver $ROCM_VERSION) -ge $(ver 6.3) ]] && [[ $(ver $ROCM_VERSION) -lt $(ver 7.0) ]]; then + if [[ $(ver $ROCM_VERSION) -eq $(ver 6.4.1) ]]; then + HIP_BRANCH=release/rocm-rel-6.4 + VER_STR=6.4 + VER_PATCH=.1 + elif [[ $(ver $ROCM_VERSION) -eq $(ver 6.4) ]]; then + HIP_BRANCH=release/rocm-rel-6.4 + VER_STR=6.4 + elif [[ $(ver $ROCM_VERSION) -eq $(ver 6.3) ]]; then + HIP_BRANCH=rocm-6.3.x + VER_STR=6.3 + fi + # clr build needs CppHeaderParser but can only find it using conda's python + /opt/conda/bin/python -m pip install CppHeaderParser + git clone https://github.com/ROCm/HIP -b $HIP_BRANCH + HIP_COMMON_DIR=$(readlink -f HIP) + git clone https://github.com/jeffdaily/clr -b release/rocm-rel-${VER_STR}${VER_PATCH}-statco-hotfix + mkdir -p clr/build + pushd clr/build + cmake .. -DCLR_BUILD_HIP=ON -DHIP_COMMON_DIR=$HIP_COMMON_DIR + make -j + cp hipamd/lib/libamdhip64.so.${VER_STR}.* /opt/rocm/lib/libamdhip64.so.${VER_STR}.* +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) popd rm -rf HIP clr fi diff --git a/.ci/docker/common/install_triton.sh b/.ci/docker/common/install_triton.sh index 8e714bcb6cd32..0368047cbb415 100755 --- a/.ci/docker/common/install_triton.sh +++ b/.ci/docker/common/install_triton.sh @@ -57,7 +57,11 @@ if [ ! -f setup.py ]; then cd python fi +<<<<<<< HEAD pip_install pybind11==3.0.1 +======= +pip_install pybind11==2.13.6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO: remove patch setup.py once we have a proper fix for https://github.com/triton-lang/triton/issues/4527 as_jenkins sed -i -e 's/https:\/\/tritonlang.blob.core.windows.net\/llvm-builds/https:\/\/oaitriton.blob.core.windows.net\/public\/llvm-builds/g' setup.py @@ -98,10 +102,15 @@ fi if [ -n "${NUMPY_VERSION}" ]; then pip_install "numpy==${NUMPY_VERSION}" fi +<<<<<<< HEAD # IMPORTANT: helion needs to be installed without dependencies. # It depends on torch and triton. We don't want to install # triton and torch from production on Docker CI images if [[ "$ANACONDA_PYTHON_VERSION" != 3.9* ]]; then pip_install helion --no-deps +======= +if [[ "$ANACONDA_PYTHON_VERSION" != 3.9* ]]; then + pip_install helion +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fi diff --git a/.ci/docker/common/install_ucc.sh b/.ci/docker/common/install_ucc.sh index 04f15a52e88e3..10048ebc19efc 100755 --- a/.ci/docker/common/install_ucc.sh +++ b/.ci/docker/common/install_ucc.sh @@ -44,12 +44,17 @@ function install_ucc() { ./autogen.sh +<<<<<<< HEAD if [[ -n "$CUDA_VERSION" && $CUDA_VERSION == 13* ]]; then NVCC_GENCODE="-gencode=arch=compute_86,code=compute_86" else # We only run distributed tests on Tesla M60 and A10G NVCC_GENCODE="-gencode=arch=compute_52,code=sm_52 -gencode=arch=compute_86,code=compute_86" fi +======= + # We only run distributed tests on Tesla M60 and A10G + NVCC_GENCODE="-gencode=arch=compute_52,code=sm_52 -gencode=arch=compute_86,code=compute_86" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if [[ -n "$ROCM_VERSION" ]]; then if [[ -n "$PYTORCH_ROCM_ARCH" ]]; then diff --git a/.ci/docker/common/install_xpu.sh b/.ci/docker/common/install_xpu.sh index 0b150872f93ce..f77c9bb6d2f95 100644 --- a/.ci/docker/common/install_xpu.sh +++ b/.ci/docker/common/install_xpu.sh @@ -34,6 +34,7 @@ function install_ubuntu() { # The xpu-smi packages apt-get install -y flex bison xpu-smi +<<<<<<< HEAD if [[ "${XPU_DRIVER_TYPE,,}" == "lts" ]]; then # Compute and Media Runtimes @@ -55,6 +56,20 @@ function install_ubuntu() { apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev libze-dev fi +======= + # Compute and Media Runtimes + apt-get install -y \ + intel-opencl-icd intel-level-zero-gpu level-zero \ + intel-media-va-driver-non-free libmfx1 libmfxgen1 libvpl2 \ + libegl-mesa0 libegl1-mesa libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \ + libglapi-mesa libgles2-mesa-dev libglx-mesa0 libigdgmm12 libxatracker2 mesa-va-drivers \ + mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo + if [[ "${XPU_DRIVER_TYPE,,}" == "rolling" ]]; then + apt-get install -y intel-ocloc + fi + # Development Packages + apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev level-zero-dev +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Install Intel Support Packages apt-get install -y ${XPU_PACKAGES} @@ -143,6 +158,7 @@ function install_sles() { } +<<<<<<< HEAD # Default use GPU driver rolling releases XPU_DRIVER_VERSION="" if [[ "${XPU_DRIVER_TYPE,,}" == "lts" ]]; then @@ -155,6 +171,20 @@ if [[ "$XPU_VERSION" == "2025.2" ]]; then XPU_PACKAGES="intel-deep-learning-essentials-2025.2" else XPU_PACKAGES="intel-deep-learning-essentials-2025.1" +======= +# Default use GPU driver LTS releases +XPU_DRIVER_VERSION="/lts/2350" +if [[ "${XPU_DRIVER_TYPE,,}" == "rolling" ]]; then + # Use GPU driver rolling releases + XPU_DRIVER_VERSION="" +fi + +# Default use IntelĀ® oneAPI Deep Learning Essentials 2025.0 +if [[ "$XPU_VERSION" == "2025.1" ]]; then + XPU_PACKAGES="intel-deep-learning-essentials-2025.1" +else + XPU_PACKAGES="intel-deep-learning-essentials-2025.0" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fi # The installation depends on the base OS diff --git a/.ci/docker/libtorch/Dockerfile b/.ci/docker/libtorch/Dockerfile index d19431ad8b541..d465977831f80 100644 --- a/.ci/docker/libtorch/Dockerfile +++ b/.ci/docker/libtorch/Dockerfile @@ -69,6 +69,7 @@ RUN bash ./install_cuda.sh 12.9 RUN bash ./install_magma.sh 12.9 RUN ln -sf /usr/local/cuda-12.9 /usr/local/cuda +<<<<<<< HEAD FROM cuda as cuda13.0 RUN bash ./install_cuda.sh 13.0 RUN bash ./install_magma.sh 13.0 @@ -82,6 +83,8 @@ RUN apt-get update -y && \ cp /usr/lib/x86_64-linux-gnu/libibverbs.so* /usr/local/cuda/lib64/ && \ cp /usr/lib/x86_64-linux-gnu/libnl* /usr/local/cuda/lib64/ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) FROM cpu as rocm ARG ROCM_VERSION ARG PYTORCH_ROCM_ARCH diff --git a/.ci/docker/libtorch/build.sh b/.ci/docker/libtorch/build.sh index 7caedf1f44d43..5e21b952097c2 100755 --- a/.ci/docker/libtorch/build.sh +++ b/.ci/docker/libtorch/build.sh @@ -39,10 +39,13 @@ case ${DOCKER_TAG_PREFIX} in DOCKER_GPU_BUILD_ARG="" ;; rocm*) +<<<<<<< HEAD # we want the patch version of 6.4 instead if [[ $(ver $GPU_ARCH_VERSION) -eq $(ver 6.4) ]]; then GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2" fi +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) BASE_TARGET=rocm GPU_IMAGE=rocm/dev-ubuntu-22.04:${GPU_ARCH_VERSION}-complete PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201" diff --git a/.ci/docker/linter/Dockerfile b/.ci/docker/linter/Dockerfile index 95d08ffea051d..658ad4a91709e 100644 --- a/.ci/docker/linter/Dockerfile +++ b/.ci/docker/linter/Dockerfile @@ -27,7 +27,10 @@ COPY ./common/install_linter.sh install_linter.sh RUN bash ./install_linter.sh RUN rm install_linter.sh +<<<<<<< HEAD RUN chown -R jenkins:jenkins /var/lib/jenkins/ci_env +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) USER jenkins CMD ["bash"] diff --git a/.ci/docker/manywheel/Dockerfile_2_28 b/.ci/docker/manywheel/Dockerfile_2_28 index 4803cb778c905..ebbce2f360f93 100644 --- a/.ci/docker/manywheel/Dockerfile_2_28 +++ b/.ci/docker/manywheel/Dockerfile_2_28 @@ -130,8 +130,12 @@ ENV LD_LIBRARY_PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/lib64:/op RUN for cpython_version in "cp312-cp312" "cp313-cp313" "cp313-cp313t"; do \ /opt/python/${cpython_version}/bin/python -m pip install setuptools wheel; \ done; +<<<<<<< HEAD ADD ./common/patch_libstdc.sh patch_libstdc.sh RUN bash ./patch_libstdc.sh && rm patch_libstdc.sh +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # cmake-3.18.4 from pip; force in case cmake3 already exists RUN yum install -y python3-pip && \ @@ -176,6 +180,10 @@ ENV XPU_DRIVER_TYPE ROLLING RUN python3 -m pip install --upgrade pip && \ python3 -mpip install cmake==3.28.4 ADD ./common/install_xpu.sh install_xpu.sh +<<<<<<< HEAD ENV XPU_VERSION 2025.2 +======= +ENV XPU_VERSION 2025.1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) RUN bash ./install_xpu.sh && rm install_xpu.sh RUN pushd /opt/_internal && tar -xJf static-libs-for-embedding-only.tar.xz && popd diff --git a/.ci/docker/manywheel/Dockerfile_2_28_aarch64 b/.ci/docker/manywheel/Dockerfile_2_28_aarch64 index 6cfab77941fc8..083b5ba3d66ab 100644 --- a/.ci/docker/manywheel/Dockerfile_2_28_aarch64 +++ b/.ci/docker/manywheel/Dockerfile_2_28_aarch64 @@ -71,5 +71,8 @@ RUN rm -rf /opt/python/cp33-cp33m /opt/_internal/cpython-3.3.6 RUN rm -rf /opt/python/cp34-cp34m /opt/_internal/cpython-3.4.6 COPY --from=openblas /opt/OpenBLAS/ /opt/OpenBLAS/ ENV LD_LIBRARY_PATH=/opt/OpenBLAS/lib:$LD_LIBRARY_PATH +<<<<<<< HEAD ADD ./common/patch_libstdc.sh patch_libstdc.sh RUN bash ./patch_libstdc.sh && rm patch_libstdc.sh +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.ci/docker/manywheel/Dockerfile_cuda_aarch64 b/.ci/docker/manywheel/Dockerfile_cuda_aarch64 index 4d2596fea8214..6604681c8cc03 100644 --- a/.ci/docker/manywheel/Dockerfile_cuda_aarch64 +++ b/.ci/docker/manywheel/Dockerfile_cuda_aarch64 @@ -95,5 +95,8 @@ COPY --from=nvpl /opt/nvpl/lib/ /usr/local/lib/ COPY --from=nvpl /opt/nvpl/include/ /usr/local/include/ RUN ln -sf /usr/local/cuda-${BASE_CUDA_VERSION} /usr/local/cuda ENV PATH=/usr/local/cuda/bin:$PATH +<<<<<<< HEAD ADD ./common/patch_libstdc.sh patch_libstdc.sh RUN bash ./patch_libstdc.sh && rm patch_libstdc.sh +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.ci/docker/manywheel/Dockerfile_s390x b/.ci/docker/manywheel/Dockerfile_s390x index 46ec7f77ae8ba..f4aad02f34587 100644 --- a/.ci/docker/manywheel/Dockerfile_s390x +++ b/.ci/docker/manywheel/Dockerfile_s390x @@ -131,8 +131,11 @@ RUN pip3 install flatbuffers && \ git clone https://github.com/microsoft/onnxruntime && \ cd onnxruntime && git checkout v1.21.0 && \ git submodule update --init --recursive && \ +<<<<<<< HEAD wget https://github.com/microsoft/onnxruntime/commit/f57db79743c4d1a3553aa05cf95bcd10966030e6.patch && \ patch -p1 < f57db79743c4d1a3553aa05cf95bcd10966030e6.patch && \ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ./build.sh --config Release --parallel 0 --enable_pybind \ --build_wheel --enable_training --enable_training_apis \ --enable_training_ops --skip_tests --allow_running_as_root \ diff --git a/.ci/docker/manywheel/build.sh b/.ci/docker/manywheel/build.sh index 5dee4325857fb..3c33e540e8f72 100755 --- a/.ci/docker/manywheel/build.sh +++ b/.ci/docker/manywheel/build.sh @@ -67,12 +67,15 @@ case ${image} in DOCKER_GPU_BUILD_ARG="--build-arg BASE_CUDA_VERSION=${GPU_ARCH_VERSION} --build-arg DEVTOOLSET_VERSION=13" MANY_LINUX_VERSION="2_28" ;; +<<<<<<< HEAD manylinux2_28-builder:cuda13*) TARGET=cuda_final GPU_IMAGE=amd64/almalinux:8 DOCKER_GPU_BUILD_ARG="--build-arg BASE_CUDA_VERSION=${GPU_ARCH_VERSION} --build-arg DEVTOOLSET_VERSION=13" MANY_LINUX_VERSION="2_28" ;; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) manylinuxaarch64-builder:cuda*) TARGET=cuda_final GPU_IMAGE=amd64/almalinux:8 @@ -81,10 +84,13 @@ case ${image} in DOCKERFILE_SUFFIX="_cuda_aarch64" ;; manylinux2_28-builder:rocm*) +<<<<<<< HEAD # we want the patch version of 6.4 instead if [[ $(ver $GPU_ARCH_VERSION) -eq $(ver 6.4) ]]; then GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2" fi +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TARGET=rocm_final MANY_LINUX_VERSION="2_28" DEVTOOLSET_VERSION="11" diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index d3da8e93c639a..dae3230360a84 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -10,7 +10,11 @@ boto3==1.35.42 #Pinned versions: 1.19.12, 1.16.34 #test that import: +<<<<<<< HEAD click==8.3.0 +======= +click +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #Description: Command Line Interface Creation Kit #Pinned versions: #test that import: @@ -50,7 +54,11 @@ flatbuffers==24.12.23 hypothesis==5.35.1 # Pin hypothesis to avoid flakiness: https://github.com/pytorch/pytorch/issues/31136 #Description: advanced library for generating parametrized tests +<<<<<<< HEAD #Pinned versions: 5.35.1 +======= +#Pinned versions: 3.44.6, 4.53.2 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #test that import: test_xnnpack_integration.py, test_pruning_op.py, test_nn.py junitparser==2.1.1 @@ -63,12 +71,20 @@ lark==0.12.0 #Pinned versions: 0.12.0 #test that import: +<<<<<<< HEAD librosa==0.11.0 ; python_version < "3.11" and platform_machine != "s390x" librosa==0.10.2 ; python_version == "3.12" and platform_machine != "s390x" #Description: A python package for music and audio analysis #Pinned versions: >=0.6.2 #test that import: test_spectral_ops.py #librosa depends on numba; disable it for s390x while numba is disabled too +======= +librosa>=0.6.2 ; python_version < "3.11" +librosa==0.10.2 ; python_version == "3.12" +#Description: A python package for music and audio analysis +#Pinned versions: >=0.6.2 +#test that import: test_spectral_ops.py +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #mkl #this breaks linux-bionic-rocm4.5-py3.7 #Description: Intel oneAPI Math Kernel Library @@ -93,9 +109,14 @@ librosa==0.10.2 ; python_version == "3.12" and platform_machine != "s390x" #Pinned versions: #test that import: +<<<<<<< HEAD mypy==1.16.0 ; platform_system != "Windows" # Pin MyPy version because new errors are likely to appear with each release # Skip on Windows as lots of type annotations are POSIX specific +======= +mypy==1.16.0 +# Pin MyPy version because new errors are likely to appear with each release +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #Description: linter #Pinned versions: 1.16.0 #test that import: test_typing.py, test_type_hints.py @@ -112,14 +133,22 @@ ninja==1.11.1.3 #Pinned versions: 1.11.1.3 #test that import: run_test.py, test_cpp_extensions_aot.py,test_determination.py +<<<<<<< HEAD numba==0.49.0 ; python_version < "3.9" and platform_machine != "s390x" numba==0.60.0 ; python_version == "3.9" and platform_machine != "s390x" numba==0.61.2 ; python_version > "3.9" and platform_machine != "s390x" +======= +numba==0.60.0 ; python_version == "3.9" +numba==0.61.2 ; python_version > "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #Description: Just-In-Time Compiler for Numerical Functions #Pinned versions: 0.54.1, 0.49.0, <=0.49.1 #test that import: test_numba_integration.py #For numba issue see https://github.com/pytorch/pytorch/issues/51511 +<<<<<<< HEAD #Need release > 0.61.2 for s390x due to https://github.com/numba/numba/pull/10073 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #numpy #Description: Provides N-dimensional arrays and linear algebra @@ -192,7 +221,11 @@ pytest-flakefinder==1.1.0 #Pinned versions: 1.1.0 #test that import: +<<<<<<< HEAD pytest-rerunfailures==14.0 +======= +pytest-rerunfailures>=10.3 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #Description: plugin for rerunning failure tests in pytest #Pinned versions: #test that import: @@ -222,9 +255,15 @@ pygments==2.15.0 #Pinned versions: 2.12.0 #test that import: the doctests +<<<<<<< HEAD #pyyaml #Description: data serialization format #Pinned versions: 6.0.2 +======= +#PyYAML +#Description: data serialization format +#Pinned versions: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #test that import: #requests @@ -234,7 +273,11 @@ pygments==2.15.0 #rich #Description: rich text and beautiful formatting in the terminal +<<<<<<< HEAD #Pinned versions: 14.1.0 +======= +#Pinned versions: 10.9.0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #test that import: scikit-image==0.19.3 ; python_version < "3.10" @@ -263,7 +306,11 @@ scipy==1.14.1 ; python_version > "3.9" #test that import: # needed by torchgen utils +<<<<<<< HEAD typing_extensions==4.15.0 +======= +typing-extensions>=4.10.0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #Description: type hints for python #Pinned versions: #test that import: @@ -273,7 +320,11 @@ typing_extensions==4.15.0 #Pinned versions: #test that import: +<<<<<<< HEAD unittest-xml-reporting==3.2.0 +======= +unittest-xml-reporting<=3.2.0,>=2.0.0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #Description: saves unit test results to xml #Pinned versions: #test that import: @@ -284,7 +335,11 @@ lintrunner==0.12.7 #Pinned versions: 0.12.7 #test that import: +<<<<<<< HEAD redis==6.4.0 +======= +redis>=4.0.0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #Description: redis database #test that import: anything that tests OSS caching/mocking (inductor/test_codecache.py, inductor/test_max_autotune.py) @@ -303,7 +358,11 @@ pytest-cpp==2.3.0 #Pinned versions: 2.3.0 #test that import: +<<<<<<< HEAD z3-solver==4.15.1.0 ; platform_machine != "s390x" +======= +z3-solver==4.12.6.0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #Description: The Z3 Theorem Prover Project #Pinned versions: #test that import: @@ -339,7 +398,11 @@ onnx==1.18.0 ; python_version == "3.13" #Pinned versions: #test that import: +<<<<<<< HEAD onnxscript==0.4.0 +======= +onnxscript==0.3.1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #Description: Required by mypy and test_public_bindings.py when checking torch.onnx._internal #Pinned versions: #test that import: @@ -358,11 +421,20 @@ pwlf==2.2.1 #Pinned versions: 2.2.1 #test that import: test_sac_estimator.py +<<<<<<< HEAD # To build PyTorch itself PyYAML==6.0.3 pyzstd==0.18.0 setuptools==79.0.1 six==1.17.0 +======= + +# To build PyTorch itself +astunparse +PyYAML +pyzstd +setuptools +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) scons==4.5.2 ; platform_machine == "aarch64" @@ -379,6 +451,7 @@ dataclasses_json==0.6.7 cmake==4.0.0 #Description: required for building +<<<<<<< HEAD tlparse==0.4.0 #Description: required for log parsing @@ -391,3 +464,11 @@ scikit-build==0.18.1 pyre-extensions==0.0.32 tabulate==0.9.0 #Description: These package are needed to build FBGEMM and torchrec on PyTorch CI +======= +tlparse==0.3.30 +#Description: required for log parsing + +cuda-bindings>=12.0,<13.0 +#Description: required for testing CUDAGraph::raw_cuda_graph(). See https://nvidia.github.io/cuda-python/cuda-bindings/latest/support.html for how this version was chosen. Note "Any fix in the latest bindings would be backported to the prior major version" means that only the newest version of cuda-bindings will get fixes. Depending on the latest version of 12.x is okay because all 12.y versions will be supported via "CUDA minor version compatibility". Pytorch builds against 13.z versions of cuda toolkit work with 12.x versions of cuda-bindings as well because newer drivers work with old toolkits. +#test that import: test_cuda.py +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.ci/docker/requirements-docs.txt b/.ci/docker/requirements-docs.txt index c5ad8e969fb9e..6de9047c0957a 100644 --- a/.ci/docker/requirements-docs.txt +++ b/.ci/docker/requirements-docs.txt @@ -1,7 +1,11 @@ sphinx==5.3.0 #Description: This is used to generate PyTorch docs #Pinned versions: 5.3.0 +<<<<<<< HEAD -e git+https://github.com/pytorch/pytorch_sphinx_theme.git@71e55749be14ceb56e7f8211a9fb649866b87ad4#egg=pytorch_sphinx_theme2 +======= +-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@pytorch_sphinx_theme2#egg=pytorch_sphinx_theme2 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering # but it doesn't seem to work and hangs around idly. The initial thought that it is probably @@ -19,10 +23,16 @@ sphinx_sitemap==2.6.0 #Description: This is used to generate sitemap for PyTorch docs #Pinned versions: 2.6.0 +<<<<<<< HEAD matplotlib==3.5.3 ; python_version < "3.13" matplotlib==3.6.3 ; python_version >= "3.13" #Description: This is used to generate PyTorch docs #Pinned versions: 3.6.3 if python > 3.12. Otherwise 3.5.3. +======= +matplotlib==3.5.3 +#Description: This is used to generate PyTorch docs +#Pinned versions: 3.5.3 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tensorboard==2.13.0 ; python_version < "3.13" tensorboard==2.18.0 ; python_version >= "3.13" @@ -50,8 +60,13 @@ IPython==8.12.0 #Pinned versions: 8.12.0 myst-nb==0.17.2 +<<<<<<< HEAD #Description: This is used to generate PyTorch functorch and torch.compile docs. #Pinned versions: 0.17.2 +======= +#Description: This is used to generate PyTorch functorch docs +#Pinned versions: 0.13.2 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # The following are required to build torch.distributed.elastic.rendezvous.etcd* docs python-etcd==0.4.5 diff --git a/.ci/docker/triton_version.txt b/.ci/docker/triton_version.txt index 1545d966571dc..561eb4a3cc51e 100644 --- a/.ci/docker/triton_version.txt +++ b/.ci/docker/triton_version.txt @@ -1 +1,5 @@ +<<<<<<< HEAD 3.5.0 +======= +3.4.0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.ci/docker/triton_xpu_version.txt b/.ci/docker/triton_xpu_version.txt index 1545d966571dc..561eb4a3cc51e 100644 --- a/.ci/docker/triton_xpu_version.txt +++ b/.ci/docker/triton_xpu_version.txt @@ -1 +1,5 @@ +<<<<<<< HEAD 3.5.0 +======= +3.4.0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.ci/docker/ubuntu-rocm/Dockerfile b/.ci/docker/ubuntu-rocm/Dockerfile index 681f6fe750510..01456bc58ec9c 100644 --- a/.ci/docker/ubuntu-rocm/Dockerfile +++ b/.ci/docker/ubuntu-rocm/Dockerfile @@ -96,11 +96,18 @@ ARG ANACONDA_PYTHON_VERSION ENV ANACONDA_PYTHON_VERSION=$ANACONDA_PYTHON_VERSION COPY ./common/install_inductor_benchmark_deps.sh install_inductor_benchmark_deps.sh COPY ./common/common_utils.sh common_utils.sh +<<<<<<< HEAD COPY ci_commit_pins/huggingface-requirements.txt huggingface-requirements.txt COPY ci_commit_pins/timm.txt timm.txt COPY ci_commit_pins/torchbench.txt torchbench.txt RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt torchbench.txt +======= +COPY ci_commit_pins/huggingface.txt huggingface.txt +COPY ci_commit_pins/timm.txt timm.txt +RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi +RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # (optional) Install non-default Ninja version ARG NINJA_VERSION diff --git a/.ci/docker/ubuntu-xpu/Dockerfile b/.ci/docker/ubuntu-xpu/Dockerfile index 8765249688ce5..8ab05c37b9ec5 100644 --- a/.ci/docker/ubuntu-xpu/Dockerfile +++ b/.ci/docker/ubuntu-xpu/Dockerfile @@ -56,10 +56,17 @@ RUN rm install_openssl.sh ARG INDUCTOR_BENCHMARKS COPY ./common/install_inductor_benchmark_deps.sh install_inductor_benchmark_deps.sh COPY ./common/common_utils.sh common_utils.sh +<<<<<<< HEAD COPY ci_commit_pins/huggingface-requirements.txt huggingface-requirements.txt COPY ci_commit_pins/timm.txt timm.txt RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt +======= +COPY ci_commit_pins/huggingface.txt huggingface.txt +COPY ci_commit_pins/timm.txt timm.txt +RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi +RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Install XPU Dependencies ARG XPU_VERSION diff --git a/.ci/docker/ubuntu/Dockerfile b/.ci/docker/ubuntu/Dockerfile index 1edc8c60c2f07..ecd8600c476f5 100644 --- a/.ci/docker/ubuntu/Dockerfile +++ b/.ci/docker/ubuntu/Dockerfile @@ -66,7 +66,10 @@ ENV NCCL_LIB_DIR="/usr/local/cuda/lib64/" # (optional) Install UCC ARG UCX_COMMIT ARG UCC_COMMIT +<<<<<<< HEAD ARG CUDA_VERSION +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ENV UCX_COMMIT $UCX_COMMIT ENV UCC_COMMIT $UCC_COMMIT ENV UCX_HOME /usr @@ -97,11 +100,18 @@ RUN rm install_openssl.sh ARG INDUCTOR_BENCHMARKS COPY ./common/install_inductor_benchmark_deps.sh install_inductor_benchmark_deps.sh COPY ./common/common_utils.sh common_utils.sh +<<<<<<< HEAD COPY ci_commit_pins/huggingface-requirements.txt huggingface-requirements.txt COPY ci_commit_pins/timm.txt timm.txt COPY ci_commit_pins/torchbench.txt torchbench.txt RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt torchbench.txt +======= +COPY ci_commit_pins/huggingface.txt huggingface.txt +COPY ci_commit_pins/timm.txt timm.txt +RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi +RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ARG TRITON ARG TRITON_CPU @@ -182,6 +192,10 @@ COPY --from=pytorch/llvm:9.0.1 /opt/llvm /opt/llvm RUN if [ -n "${SKIP_LLVM_SRC_BUILD_INSTALL}" ]; then set -eu; rm -rf /opt/llvm; fi # AWS specific CUDA build guidance +<<<<<<< HEAD +======= +ENV TORCH_CUDA_ARCH_LIST Maxwell +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ENV TORCH_NVCC_FLAGS "-Xfatbin -compress-all" ENV CUDA_PATH /usr/local/cuda diff --git a/.ci/libtorch/build.sh b/.ci/libtorch/build.sh index c2d67f8b1bb29..7c668ca81e714 100644 --- a/.ci/libtorch/build.sh +++ b/.ci/libtorch/build.sh @@ -7,4 +7,8 @@ set -ex SCRIPTPATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +<<<<<<< HEAD USE_NVSHMEM=0 USE_CUSPARSELT=0 BUILD_PYTHONLESS=1 DESIRED_PYTHON="3.10" ${SCRIPTPATH}/../manywheel/build.sh +======= +USE_CUSPARSELT=0 BUILD_PYTHONLESS=1 DESIRED_PYTHON="3.9" ${SCRIPTPATH}/../manywheel/build.sh +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.ci/magma/Makefile b/.ci/magma/Makefile index 4169aedd03fa5..233925d95eb67 100644 --- a/.ci/magma/Makefile +++ b/.ci/magma/Makefile @@ -16,7 +16,10 @@ DOCKER_RUN = set -eou pipefail; ${DOCKER_CMD} run --rm -i \ magma/build_magma.sh .PHONY: all +<<<<<<< HEAD all: magma-cuda130 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) all: magma-cuda129 all: magma-cuda128 all: magma-cuda126 @@ -26,12 +29,15 @@ clean: $(RM) -r magma-* $(RM) -r output +<<<<<<< HEAD .PHONY: magma-cuda130 magma-cuda130: DESIRED_CUDA := 13.0 magma-cuda130: CUDA_ARCH_LIST := -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -gencode arch=compute_90,code=sm_90 -gencode arch=compute_100,code=sm_100 -gencode arch=compute_120,code=sm_120 magma-cuda130: $(DOCKER_RUN) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .PHONY: magma-cuda129 magma-cuda129: DESIRED_CUDA := 12.9 magma-cuda129: CUDA_ARCH_LIST += -gencode arch=compute_100,code=sm_100 -gencode arch=compute_120,code=sm_120 diff --git a/.ci/magma/build_magma.sh b/.ci/magma/build_magma.sh index 6f1924fa45965..c88109ab01765 100755 --- a/.ci/magma/build_magma.sh +++ b/.ci/magma/build_magma.sh @@ -28,7 +28,10 @@ pushd ${PACKAGE_DIR}/magma-${MAGMA_VERSION} patch < ${PACKAGE_FILES}/CMake.patch patch < ${PACKAGE_FILES}/cmakelists.patch patch -p0 < ${PACKAGE_FILES}/thread_queue.patch +<<<<<<< HEAD patch -p1 < ${PACKAGE_FILES}/cuda13.patch +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) patch -p1 < ${PACKAGE_FILES}/getrf_shfl.patch patch -p1 < ${PACKAGE_FILES}/getrf_nbparam.patch # The build.sh script expects to be executed from the sources root folder @@ -38,7 +41,10 @@ popd # Package recipe, license and tarball # Folder and package name are backward compatible for the build workflow cp ${PACKAGE_FILES}/build.sh ${PACKAGE_RECIPE}/build.sh +<<<<<<< HEAD cp ${PACKAGE_FILES}/cuda13.patch ${PACKAGE_RECIPE}/cuda13.patch +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cp ${PACKAGE_FILES}/thread_queue.patch ${PACKAGE_RECIPE}/thread_queue.patch cp ${PACKAGE_FILES}/cmakelists.patch ${PACKAGE_RECIPE}/cmakelists.patch cp ${PACKAGE_FILES}/getrf_shfl.patch ${PACKAGE_RECIPE}/getrf_shfl.patch diff --git a/.ci/manywheel/build.sh b/.ci/manywheel/build.sh index 6b2a60bc5ca28..82339921b69dd 100755 --- a/.ci/manywheel/build.sh +++ b/.ci/manywheel/build.sh @@ -5,6 +5,13 @@ set -ex SCRIPTPATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" case "${GPU_ARCH_TYPE:-BLANK}" in +<<<<<<< HEAD +======= + BLANK) + # Legacy behavior for CircleCI + bash "${SCRIPTPATH}/build_cuda.sh" + ;; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cuda) bash "${SCRIPTPATH}/build_cuda.sh" ;; diff --git a/.ci/manywheel/build_common.sh b/.ci/manywheel/build_common.sh index 4c268befb30e5..080bb20a3b0b6 100644 --- a/.ci/manywheel/build_common.sh +++ b/.ci/manywheel/build_common.sh @@ -97,7 +97,11 @@ if [[ -z "$PYTORCH_ROOT" ]]; then exit 1 fi pushd "$PYTORCH_ROOT" +<<<<<<< HEAD retry pip install -qUr requirements-build.txt +======= +retry pip install -q cmake +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) python setup.py clean retry pip install -qr requirements.txt case ${DESIRED_PYTHON} in @@ -138,11 +142,36 @@ fi echo "Calling setup.py bdist at $(date)" +<<<<<<< HEAD time CMAKE_ARGS=${CMAKE_ARGS[@]} \ EXTRA_CAFFE2_CMAKE_FLAGS=${EXTRA_CAFFE2_CMAKE_FLAGS[@]} \ BUILD_LIBTORCH_CPU_WITH_DEBUG=$BUILD_DEBUG_INFO \ USE_NCCL=${USE_NCCL} USE_RCCL=${USE_RCCL} USE_KINETO=${USE_KINETO} \ python setup.py bdist_wheel -d /tmp/$WHEELHOUSE_DIR +======= +if [[ "$USE_SPLIT_BUILD" == "true" ]]; then + echo "Calling setup.py bdist_wheel for split build (BUILD_LIBTORCH_WHL)" + time EXTRA_CAFFE2_CMAKE_FLAGS=${EXTRA_CAFFE2_CMAKE_FLAGS[@]} \ + BUILD_LIBTORCH_WHL=1 BUILD_PYTHON_ONLY=0 \ + BUILD_LIBTORCH_CPU_WITH_DEBUG=$BUILD_DEBUG_INFO \ + USE_NCCL=${USE_NCCL} USE_RCCL=${USE_RCCL} USE_KINETO=${USE_KINETO} \ + python setup.py bdist_wheel -d /tmp/$WHEELHOUSE_DIR + echo "Finished setup.py bdist_wheel for split build (BUILD_LIBTORCH_WHL)" + echo "Calling setup.py bdist_wheel for split build (BUILD_PYTHON_ONLY)" + time EXTRA_CAFFE2_CMAKE_FLAGS=${EXTRA_CAFFE2_CMAKE_FLAGS[@]} \ + BUILD_LIBTORCH_WHL=0 BUILD_PYTHON_ONLY=1 \ + BUILD_LIBTORCH_CPU_WITH_DEBUG=$BUILD_DEBUG_INFO \ + USE_NCCL=${USE_NCCL} USE_RCCL=${USE_RCCL} USE_KINETO=${USE_KINETO} \ + CMAKE_FRESH=1 python setup.py bdist_wheel -d /tmp/$WHEELHOUSE_DIR + echo "Finished setup.py bdist_wheel for split build (BUILD_PYTHON_ONLY)" +else + time CMAKE_ARGS=${CMAKE_ARGS[@]} \ + EXTRA_CAFFE2_CMAKE_FLAGS=${EXTRA_CAFFE2_CMAKE_FLAGS[@]} \ + BUILD_LIBTORCH_CPU_WITH_DEBUG=$BUILD_DEBUG_INFO \ + USE_NCCL=${USE_NCCL} USE_RCCL=${USE_RCCL} USE_KINETO=${USE_KINETO} \ + python setup.py bdist_wheel -d /tmp/$WHEELHOUSE_DIR +fi +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) echo "Finished setup.py bdist at $(date)" # Build libtorch packages @@ -255,6 +284,13 @@ ls /tmp/$WHEELHOUSE_DIR mkdir -p "/$WHEELHOUSE_DIR" mv /tmp/$WHEELHOUSE_DIR/torch*linux*.whl /$WHEELHOUSE_DIR/ +<<<<<<< HEAD +======= +if [[ "$USE_SPLIT_BUILD" == "true" ]]; then + mv /tmp/$WHEELHOUSE_DIR/torch_no_python*.whl /$WHEELHOUSE_DIR/ || true +fi + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if [[ -n "$BUILD_PYTHONLESS" ]]; then mkdir -p /$LIBTORCH_HOUSE_DIR mv /tmp/$LIBTORCH_HOUSE_DIR/*.zip /$LIBTORCH_HOUSE_DIR @@ -431,8 +467,21 @@ if [[ -z "$BUILD_PYTHONLESS" ]]; then pushd $PYTORCH_ROOT/test # Install the wheel for this Python version +<<<<<<< HEAD pip uninstall -y "$TORCH_PACKAGE_NAME" +======= + if [[ "$USE_SPLIT_BUILD" == "true" ]]; then + pip uninstall -y "$TORCH_NO_PYTHON_PACKAGE_NAME" || true + fi + + pip uninstall -y "$TORCH_PACKAGE_NAME" + + if [[ "$USE_SPLIT_BUILD" == "true" ]]; then + pip install "$TORCH_NO_PYTHON_PACKAGE_NAME" --no-index -f /$WHEELHOUSE_DIR --no-dependencies -v + fi + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pip install "$TORCH_PACKAGE_NAME" --no-index -f /$WHEELHOUSE_DIR --no-dependencies -v # Print info on the libraries installed in this wheel diff --git a/.ci/manywheel/build_cuda.sh b/.ci/manywheel/build_cuda.sh index 6ed38f8b25c62..76464615ce30c 100644 --- a/.ci/manywheel/build_cuda.sh +++ b/.ci/manywheel/build_cuda.sh @@ -66,9 +66,12 @@ case ${CUDA_VERSION} in TORCH_CUDA_ARCH_LIST="7.5;8.0;9.0;10.0;12.0+PTX" fi ;; +<<<<<<< HEAD 13.0) TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6;9.0;10.0;12.0+PTX" ;; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) 12.6) TORCH_CUDA_ARCH_LIST="5.0;6.0;7.0;7.5;8.0;8.6;9.0" ;; @@ -113,6 +116,7 @@ DEPS_SONAME=( ) +<<<<<<< HEAD # CUDA_VERSION 12.*, 13.* if [[ $CUDA_VERSION == 12* || $CUDA_VERSION == 13* ]]; then export USE_STATIC_CUDNN=0 @@ -125,6 +129,15 @@ if [[ $CUDA_VERSION == 12* || $CUDA_VERSION == 13* ]]; then if [[ -z "$PYTORCH_EXTRA_INSTALL_REQUIREMENTS" ]]; then echo "Bundling with cudnn and cublas." +======= +# CUDA_VERSION 12.6, 12.8, 12.9 +if [[ $CUDA_VERSION == 12* ]]; then + export USE_STATIC_CUDNN=0 + # Try parallelizing nvcc as well + export TORCH_NVCC_FLAGS="-Xfatbin -compress-all --threads 2" + if [[ -z "$PYTORCH_EXTRA_INSTALL_REQUIREMENTS" ]]; then + echo "Bundling with cudnn and cublas." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DEPS_LIST+=( "/usr/local/cuda/lib64/libcudnn_adv.so.9" "/usr/local/cuda/lib64/libcudnn_cnn.so.9" @@ -134,12 +147,23 @@ if [[ $CUDA_VERSION == 12* || $CUDA_VERSION == 13* ]]; then "/usr/local/cuda/lib64/libcudnn_engines_precompiled.so.9" "/usr/local/cuda/lib64/libcudnn_heuristic.so.9" "/usr/local/cuda/lib64/libcudnn.so.9" +<<<<<<< HEAD "/usr/local/cuda/lib64/libcusparseLt.so.0" "/usr/local/cuda/lib64/libnvrtc-builtins.so" "/usr/local/cuda/lib64/libcufile.so.0" "/usr/local/cuda/lib64/libcufile_rdma.so.1" "/usr/local/cuda/lib64/libnvshmem_host.so.3" "/usr/local/cuda/extras/CUPTI/lib64/libnvperf_host.so" +======= + "/usr/local/cuda/lib64/libcublas.so.12" + "/usr/local/cuda/lib64/libcublasLt.so.12" + "/usr/local/cuda/lib64/libcusparseLt.so.0" + "/usr/local/cuda/lib64/libcudart.so.12" + "/usr/local/cuda/lib64/libnvrtc.so.12" + "/usr/local/cuda/lib64/libnvrtc-builtins.so" + "/usr/local/cuda/lib64/libcufile.so.0" + "/usr/local/cuda/lib64/libcufile_rdma.so.1" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) DEPS_SONAME+=( "libcudnn_adv.so.9" @@ -150,6 +174,7 @@ if [[ $CUDA_VERSION == 12* || $CUDA_VERSION == 13* ]]; then "libcudnn_engines_precompiled.so.9" "libcudnn_heuristic.so.9" "libcudnn.so.9" +<<<<<<< HEAD "libcusparseLt.so.0" "libnvrtc-builtins.so" "libnvshmem_host.so.3" @@ -227,6 +252,35 @@ if [[ $CUDA_VERSION == 12* || $CUDA_VERSION == 13* ]]; then ) fi +======= + "libcublas.so.12" + "libcublasLt.so.12" + "libcusparseLt.so.0" + "libcudart.so.12" + "libnvrtc.so.12" + "libnvrtc-builtins.so" + "libcufile.so.0" + "libcufile_rdma.so.1" + ) + else + echo "Using nvidia libs from pypi." + CUDA_RPATHS=( + '$ORIGIN/../../nvidia/cublas/lib' + '$ORIGIN/../../nvidia/cuda_cupti/lib' + '$ORIGIN/../../nvidia/cuda_nvrtc/lib' + '$ORIGIN/../../nvidia/cuda_runtime/lib' + '$ORIGIN/../../nvidia/cudnn/lib' + '$ORIGIN/../../nvidia/cufft/lib' + '$ORIGIN/../../nvidia/curand/lib' + '$ORIGIN/../../nvidia/cusolver/lib' + '$ORIGIN/../../nvidia/cusparse/lib' + '$ORIGIN/../../nvidia/cusparselt/lib' + '$ORIGIN/../../cusparselt/lib' + '$ORIGIN/../../nvidia/nccl/lib' + '$ORIGIN/../../nvidia/nvtx/lib' + '$ORIGIN/../../nvidia/cufile/lib' + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CUDA_RPATHS=$(IFS=: ; echo "${CUDA_RPATHS[*]}") export C_SO_RPATH=$CUDA_RPATHS':$ORIGIN:$ORIGIN/lib' export LIB_SO_RPATH=$CUDA_RPATHS':$ORIGIN' diff --git a/.ci/manywheel/build_libtorch.sh b/.ci/manywheel/build_libtorch.sh index 4de775b1823ca..16286cbad4c9e 100644 --- a/.ci/manywheel/build_libtorch.sh +++ b/.ci/manywheel/build_libtorch.sh @@ -92,7 +92,11 @@ if [[ -z "$PYTORCH_ROOT" ]]; then exit 1 fi pushd "$PYTORCH_ROOT" +<<<<<<< HEAD retry pip install -qUr requirements-build.txt +======= +retry pip install -q cmake +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) python setup.py clean retry pip install -qr requirements.txt retry pip install -q numpy==2.0.1 @@ -104,7 +108,11 @@ if [[ "$DESIRED_CUDA" == *"rocm"* ]]; then export ROCclr_DIR=/opt/rocm/rocclr/lib/cmake/rocclr fi +<<<<<<< HEAD echo "Calling 'python -m pip install .' at $(date)" +======= +echo "Calling setup.py install at $(date)" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if [[ $LIBTORCH_VARIANT = *"static"* ]]; then STATIC_CMAKE_FLAG="-DTORCH_STATIC=1" @@ -120,7 +128,11 @@ fi # TODO: Remove this flag once https://github.com/pytorch/pytorch/issues/55952 is closed CFLAGS='-Wno-deprecated-declarations' \ BUILD_LIBTORCH_CPU_WITH_DEBUG=1 \ +<<<<<<< HEAD python -m pip install --no-build-isolation -v . +======= + python setup.py install +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mkdir -p libtorch/{lib,bin,include,share} diff --git a/.ci/manywheel/build_rocm.sh b/.ci/manywheel/build_rocm.sh index ffc15bcdc5fad..1291c7d4b8985 100755 --- a/.ci/manywheel/build_rocm.sh +++ b/.ci/manywheel/build_rocm.sh @@ -194,7 +194,11 @@ ROCBLAS_LIB_SRC=$ROCM_HOME/lib/rocblas/library ROCBLAS_LIB_DST=lib/rocblas/library ROCBLAS_ARCH_SPECIFIC_FILES=$(ls $ROCBLAS_LIB_SRC | grep -E $ARCH) ROCBLAS_OTHER_FILES=$(ls $ROCBLAS_LIB_SRC | grep -v gfx) +<<<<<<< HEAD ROCBLAS_LIB_FILES=($ROCBLAS_ARCH_SPECIFIC_FILES $ROCBLAS_OTHER_FILES) +======= +ROCBLAS_LIB_FILES=($ROCBLAS_ARCH_SPECIFIC_FILES $OTHER_FILES) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # hipblaslt library files HIPBLASLT_LIB_SRC=$ROCM_HOME/lib/hipblaslt/library diff --git a/.ci/manywheel/build_xpu.sh b/.ci/manywheel/build_xpu.sh index bd7b168be336c..034ef7cf08fc9 100755 --- a/.ci/manywheel/build_xpu.sh +++ b/.ci/manywheel/build_xpu.sh @@ -25,7 +25,10 @@ source /opt/intel/oneapi/mpi/latest/env/vars.sh export USE_STATIC_MKL=1 export USE_ONEMKL=1 export USE_XCCL=1 +<<<<<<< HEAD export USE_MPI=0 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) WHEELHOUSE_DIR="wheelhousexpu" LIBTORCH_HOUSE_DIR="libtorch_housexpu" diff --git a/.ci/pytorch/build-mobile.sh b/.ci/pytorch/build-mobile.sh new file mode 100755 index 0000000000000..1f253ff58c03d --- /dev/null +++ b/.ci/pytorch/build-mobile.sh @@ -0,0 +1,34 @@ +#!/usr/bin/env bash +# DO NOT ADD 'set -x' not to reveal CircleCI secret context environment variables +set -eu -o pipefail + +# This script uses linux host toolchain + mobile build options in order to +# build & test mobile libtorch without having to setup Android/iOS +# toolchain/simulator. + +# shellcheck source=./common.sh +source "$(dirname "${BASH_SOURCE[0]}")/common.sh" +# shellcheck source=./common-build.sh +source "$(dirname "${BASH_SOURCE[0]}")/common-build.sh" + +# Install torch & torchvision - used to download & trace test model. +# Ideally we should use the libtorch built on the PR so that backward +# incompatible changes won't break this script - but it will significantly slow +# down mobile CI jobs. +# Here we install nightly instead of stable so that we have an option to +# temporarily skip mobile CI jobs on BC-breaking PRs until they are in nightly. +retry pip install --pre torch torchvision \ + -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html \ + --progress-bar off + +# Run end-to-end process of building mobile library, linking into the predictor +# binary, and running forward pass with a real model. +if [[ "$BUILD_ENVIRONMENT" == *-mobile-custom-build-static* ]]; then + TEST_CUSTOM_BUILD_STATIC=1 test/mobile/custom_build/build.sh +elif [[ "$BUILD_ENVIRONMENT" == *-mobile-lightweight-dispatch* ]]; then + test/mobile/lightweight_dispatch/build.sh +else + TEST_DEFAULT_BUILD=1 test/mobile/custom_build/build.sh +fi + +print_sccache_stats diff --git a/.ci/pytorch/build.sh b/.ci/pytorch/build.sh index 1c88554c2af96..c239e3c7817b8 100755 --- a/.ci/pytorch/build.sh +++ b/.ci/pytorch/build.sh @@ -11,6 +11,13 @@ source "$(dirname "${BASH_SOURCE[0]}")/common.sh" # shellcheck source=./common-build.sh source "$(dirname "${BASH_SOURCE[0]}")/common-build.sh" +<<<<<<< HEAD +======= +if [[ "$BUILD_ENVIRONMENT" == *-mobile-*build* ]]; then + exec "$(dirname "${BASH_SOURCE[0]}")/build-mobile.sh" "$@" +fi + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) echo "Python version:" python --version @@ -50,6 +57,12 @@ if [[ ${BUILD_ENVIRONMENT} == *"parallelnative"* ]]; then export ATEN_THREADING=NATIVE fi +<<<<<<< HEAD +======= +# Enable LLVM dependency for TensorExpr testing +export USE_LLVM=/opt/llvm +export LLVM_DIR=/opt/llvm/lib/cmake/llvm +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if ! which conda; then # In ROCm CIs, we are doing cross compilation on build machines with @@ -92,6 +105,7 @@ if [[ "$BUILD_ENVIRONMENT" == *aarch64* ]]; then export ACL_ROOT_DIR=/ComputeLibrary fi +<<<<<<< HEAD if [[ "$BUILD_ENVIRONMENT" == *riscv64* ]]; then if [[ -f /opt/riscv-cross-env/bin/activate ]]; then # shellcheck disable=SC1091 @@ -113,6 +127,8 @@ if [[ "$BUILD_ENVIRONMENT" == *riscv64* ]]; then fi +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if [[ "$BUILD_ENVIRONMENT" == *libtorch* ]]; then POSSIBLE_JAVA_HOMES=() POSSIBLE_JAVA_HOMES+=(/usr/local) @@ -138,8 +154,31 @@ if [[ "$BUILD_ENVIRONMENT" == *libtorch* ]]; then fi # Use special scripts for Android builds +<<<<<<< HEAD if [[ "$BUILD_ENVIRONMENT" == *vulkan* ]]; then +======= +if [[ "${BUILD_ENVIRONMENT}" == *-android* ]]; then + export ANDROID_NDK=/opt/ndk + build_args=() + if [[ "${BUILD_ENVIRONMENT}" == *-arm-v7a* ]]; then + build_args+=("-DANDROID_ABI=armeabi-v7a") + elif [[ "${BUILD_ENVIRONMENT}" == *-arm-v8a* ]]; then + build_args+=("-DANDROID_ABI=arm64-v8a") + elif [[ "${BUILD_ENVIRONMENT}" == *-x86_32* ]]; then + build_args+=("-DANDROID_ABI=x86") + elif [[ "${BUILD_ENVIRONMENT}" == *-x86_64* ]]; then + build_args+=("-DANDROID_ABI=x86_64") + fi + if [[ "${BUILD_ENVIRONMENT}" == *vulkan* ]]; then + build_args+=("-DUSE_VULKAN=ON") + fi + build_args+=("-DUSE_LITE_INTERPRETER_PROFILER=OFF") + exec ./scripts/build_android.sh "${build_args[@]}" "$@" +fi + +if [[ "$BUILD_ENVIRONMENT" != *android* && "$BUILD_ENVIRONMENT" == *vulkan* ]]; then +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) export USE_VULKAN=1 # shellcheck disable=SC1091 source /var/lib/jenkins/vulkansdk/setup-env.sh @@ -173,7 +212,10 @@ if [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then source /opt/intel/oneapi/mpi/latest/env/vars.sh # Enable XCCL build export USE_XCCL=1 +<<<<<<< HEAD export USE_MPI=0 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # XPU kineto feature dependencies are not fully ready, disable kineto build as temp WA export USE_KINETO=0 export TORCH_XPU_ARCH_LIST=pvc @@ -195,6 +237,7 @@ fi # We only build FlashAttention files for CUDA 8.0+, and they require large amounts of # memory to build and will OOM +<<<<<<< HEAD if [[ "$BUILD_ENVIRONMENT" == *cuda* ]] && echo "${TORCH_CUDA_ARCH_LIST}" | tr ' ' '\n' | sed 's/$/>= 8.0/' | bc | grep -q 1; then J=2 # default to 2 jobs @@ -205,6 +248,12 @@ if [[ "$BUILD_ENVIRONMENT" == *cuda* ]] && echo "${TORCH_CUDA_ARCH_LIST}" | tr ' esac echo "Building FlashAttention with job limit $J" export BUILD_CUSTOM_STEP="ninja -C build flash_attention -j ${J}" +======= +if [[ "$BUILD_ENVIRONMENT" == *cuda* ]] && [[ 1 -eq $(echo "${TORCH_CUDA_ARCH_LIST} >= 8.0" | bc) ]] && [ -z "$MAX_JOBS_OVERRIDE" ]; then + echo "WARNING: FlashAttention files require large amounts of memory to build and will OOM" + echo "Setting MAX_JOBS=(nproc-2)/3 to reduce memory usage" + export MAX_JOBS="$(( $(nproc --ignore=2) / 3 ))" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fi if [[ "${BUILD_ENVIRONMENT}" == *clang* ]]; then @@ -219,6 +268,10 @@ if [[ "$BUILD_ENVIRONMENT" == *-clang*-asan* ]]; then export USE_ASAN=1 export REL_WITH_DEB_INFO=1 export UBSAN_FLAGS="-fno-sanitize-recover=all" +<<<<<<< HEAD +======= + unset USE_LLVM +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fi if [[ "${BUILD_ENVIRONMENT}" == *no-ops* ]]; then @@ -229,7 +282,11 @@ if [[ "${BUILD_ENVIRONMENT}" == *-pch* ]]; then export USE_PRECOMPILED_HEADERS=1 fi +<<<<<<< HEAD if [[ "${BUILD_ENVIRONMENT}" != *cuda* ]]; then +======= +if [[ "${BUILD_ENVIRONMENT}" != *android* && "${BUILD_ENVIRONMENT}" != *cuda* ]]; then +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) export BUILD_STATIC_RUNTIME_BENCHMARK=ON fi @@ -239,7 +296,11 @@ fi # Do not change workspace permissions for ROCm and s390x CI jobs # as it can leave workspace with bad permissions for cancelled jobs +<<<<<<< HEAD if [[ "$BUILD_ENVIRONMENT" != *rocm* && "$BUILD_ENVIRONMENT" != *s390x* && "$BUILD_ENVIRONMENT" != *riscv64* && -d /var/lib/jenkins/workspace ]]; then +======= +if [[ "$BUILD_ENVIRONMENT" != *rocm* && "$BUILD_ENVIRONMENT" != *s390x* && -d /var/lib/jenkins/workspace ]]; then +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Workaround for dind-rootless userid mapping (https://github.com/pytorch/ci-infra/issues/96) WORKSPACE_ORIGINAL_OWNER_ID=$(stat -c '%u' "/var/lib/jenkins/workspace") cleanup_workspace() { @@ -284,18 +345,32 @@ else # XLA test build fails when WERROR=1 # set only when building other architectures # or building non-XLA tests. +<<<<<<< HEAD if [[ "$BUILD_ENVIRONMENT" != *rocm* && "$BUILD_ENVIRONMENT" != *xla* && "$BUILD_ENVIRONMENT" != *riscv64* ]]; then +======= + if [[ "$BUILD_ENVIRONMENT" != *rocm* && + "$BUILD_ENVIRONMENT" != *xla* ]]; then +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Install numpy-2.0.2 for builds which are backward compatible with 1.X python -mpip install numpy==2.0.2 WERROR=1 python setup.py clean +<<<<<<< HEAD WERROR=1 python setup.py bdist_wheel +======= + if [[ "$USE_SPLIT_BUILD" == "true" ]]; then + python3 tools/packaging/split_wheel.py bdist_wheel + else + WERROR=1 python setup.py bdist_wheel + fi +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else python setup.py clean if [[ "$BUILD_ENVIRONMENT" == *xla* ]]; then source .ci/pytorch/install_cache_xla.sh fi +<<<<<<< HEAD python setup.py bdist_wheel fi pip_install_whl "$(echo dist/*.whl)" @@ -316,6 +391,17 @@ else install_torchao fi +======= + if [[ "$USE_SPLIT_BUILD" == "true" ]]; then + echo "USE_SPLIT_BUILD cannot be used with xla or rocm" + exit 1 + else + python setup.py bdist_wheel + fi + fi + pip_install_whl "$(echo dist/*.whl)" + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then echo "Checking that xpu is compiled" pushd dist/ @@ -403,8 +489,15 @@ else # This is an attempt to mitigate flaky libtorch build OOM error. By default, the build parallelization # is set to be the number of CPU minus 2. So, let's try a more conservative value here. A 4xlarge has # 16 CPUs +<<<<<<< HEAD MAX_JOBS=$(nproc --ignore=4) export MAX_JOBS +======= + if [ -z "$MAX_JOBS_OVERRIDE" ]; then + MAX_JOBS=$(nproc --ignore=4) + export MAX_JOBS + fi +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # NB: Install outside of source directory (at the same level as the root # pytorch folder) so that it doesn't get cleaned away prior to docker push. @@ -421,7 +514,12 @@ if [[ "$BUILD_ENVIRONMENT" != *libtorch* && "$BUILD_ENVIRONMENT" != *bazel* ]]; # don't do this for libtorch as libtorch is C++ only and thus won't have python tests run on its build python tools/stats/export_test_times.py fi +<<<<<<< HEAD # don't do this for bazel or s390x or riscv64 as they don't use sccache if [[ "$BUILD_ENVIRONMENT" != *s390x* && "$BUILD_ENVIRONMENT" != *riscv64* && "$BUILD_ENVIRONMENT" != *-bazel-* ]]; then +======= +# don't do this for bazel or s390x as they don't use sccache +if [[ "$BUILD_ENVIRONMENT" != *s390x* && "$BUILD_ENVIRONMENT" != *-bazel-* ]]; then +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) print_sccache_stats fi diff --git a/.ci/pytorch/check_binary.sh b/.ci/pytorch/check_binary.sh index cca289ac146b8..94ee47c9539d6 100755 --- a/.ci/pytorch/check_binary.sh +++ b/.ci/pytorch/check_binary.sh @@ -67,7 +67,11 @@ fi # wheels with cxx11-abi echo "Checking that the gcc ABI is what we expect" +<<<<<<< HEAD if [[ "$(uname)" != 'Darwin' && "$(uname -m)" != "s390x" ]]; then +======= +if [[ "$(uname)" != 'Darwin' ]]; then +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # We also check that there are cxx11 symbols in libtorch # echo "Checking that symbols in libtorch.so have the right gcc abi" @@ -300,3 +304,27 @@ except RuntimeError as e: exit 1 fi fi +<<<<<<< HEAD +======= + +############################################################################### +# Check for C++ ABI compatibility to GCC-11 - GCC 13 +############################################################################### +if [[ "$(uname)" == 'Linux' && "$PACKAGE_TYPE" == 'manywheel' ]]; then + pushd /tmp + # Per https://gcc.gnu.org/onlinedocs/gcc/C_002b_002b-Dialect-Options.html + # gcc-11 is ABI16, gcc-13 is ABI18, gcc-14 is ABI19 + # gcc 11 - CUDA 11.8, xpu, rocm + # gcc 13 - CUDA 12.6, 12.8 and cpu + # Please see issue for reference: https://github.com/pytorch/pytorch/issues/152426 + if [[ "$(uname -m)" == "s390x" ]]; then + cxx_abi="19" + elif [[ "$DESIRED_CUDA" != 'xpu' && "$DESIRED_CUDA" != 'rocm'* ]]; then + cxx_abi="18" + else + cxx_abi="16" + fi + python -c "import torch; exit(0 if torch._C._PYBIND11_BUILD_ABI == '_cxxabi10${cxx_abi}' else 1)" + popd +fi +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.ci/pytorch/common-build.sh b/.ci/pytorch/common-build.sh index 8ca9fdb34c77a..23dca9287491e 100644 --- a/.ci/pytorch/common-build.sh +++ b/.ci/pytorch/common-build.sh @@ -13,6 +13,7 @@ if [[ "$BUILD_ENVIRONMENT" != *win-* ]]; then fi if which sccache > /dev/null; then +<<<<<<< HEAD # Clear SCCACHE_BUCKET and SCCACHE_REGION if they are empty, otherwise # sccache will complain about invalid bucket configuration if [[ -z "${SCCACHE_BUCKET:-}" ]]; then @@ -20,6 +21,8 @@ if [[ "$BUILD_ENVIRONMENT" != *win-* ]]; then unset SCCACHE_REGION fi +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Save sccache logs to file sccache --stop-server > /dev/null 2>&1 || true rm -f ~/sccache_error.log || true diff --git a/.ci/pytorch/common_utils.sh b/.ci/pytorch/common_utils.sh index bf03e132d30bb..583bb774f1a84 100644 --- a/.ci/pytorch/common_utils.sh +++ b/.ci/pytorch/common_utils.sh @@ -78,6 +78,7 @@ function pip_install_whl() { fi } +<<<<<<< HEAD function pip_build_and_install() { local build_target=$1 local wheel_dir=$2 @@ -106,6 +107,8 @@ function pip_build_and_install() { pip_install_whl "${file}" done } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) function pip_install() { # retry 3 times @@ -149,6 +152,7 @@ function get_pinned_commit() { cat .github/ci_commit_pins/"${1}".txt } +<<<<<<< HEAD function detect_cuda_arch() { if [[ "${BUILD_ENVIRONMENT}" == *cuda* ]]; then if command -v nvidia-smi; then @@ -166,6 +170,19 @@ function install_torchaudio() { local commit commit=$(get_pinned_commit audio) pip_build_and_install "git+https://github.com/pytorch/audio.git@${commit}" dist/audio +======= +function install_torchaudio() { + local commit + commit=$(get_pinned_commit audio) + if [[ "$1" == "cuda" ]]; then + # TODO: This is better to be passed as a parameter from _linux-test workflow + # so that it can be consistent with what is set in build + TORCH_CUDA_ARCH_LIST="8.0;8.6" pip_install --no-use-pep517 "git+https://github.com/pytorch/audio.git@${commit}" + else + pip_install --no-use-pep517 "git+https://github.com/pytorch/audio.git@${commit}" + fi + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } function install_torchtext() { @@ -173,8 +190,13 @@ function install_torchtext() { local text_commit data_commit=$(get_pinned_commit data) text_commit=$(get_pinned_commit text) +<<<<<<< HEAD pip_build_and_install "git+https://github.com/pytorch/data.git@${data_commit}" dist/data pip_build_and_install "git+https://github.com/pytorch/text.git@${text_commit}" dist/text +======= + pip_install --no-use-pep517 "git+https://github.com/pytorch/data.git@${data_commit}" + pip_install --no-use-pep517 "git+https://github.com/pytorch/text.git@${text_commit}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } function install_torchvision() { @@ -187,6 +209,7 @@ function install_torchvision() { echo 'char* dlerror(void) { return "";}'|gcc -fpic -shared -o "${HOME}/dlerror.so" -x c - LD_PRELOAD=${orig_preload}:${HOME}/dlerror.so fi +<<<<<<< HEAD if [[ "${BUILD_ENVIRONMENT}" == *cuda* ]]; then # Not sure if both are needed, but why not @@ -195,6 +218,9 @@ function install_torchvision() { fi pip_build_and_install "git+https://github.com/pytorch/vision.git@${commit}" dist/vision +======= + pip_install --no-use-pep517 "git+https://github.com/pytorch/vision.git@${commit}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if [ -n "${LD_PRELOAD}" ]; then LD_PRELOAD=${orig_preload} fi @@ -214,6 +240,7 @@ function install_torchrec_and_fbgemm() { if [[ "$BUILD_ENVIRONMENT" == *rocm* ]] ; then # install torchrec first because it installs fbgemm nightly on top of rocm fbgemm +<<<<<<< HEAD pip_build_and_install "git+https://github.com/pytorch/torchrec.git@${torchrec_commit}" dist/torchrec pip_uninstall fbgemm-gpu-nightly @@ -279,12 +306,37 @@ function install_torchrec_and_fbgemm() { else pip_build_and_install "git+https://github.com/pytorch/torchrec.git@${torchrec_commit}" dist/torchrec pip_build_and_install "git+https://github.com/pytorch/FBGEMM.git@${fbgemm_commit}#subdirectory=fbgemm_gpu" dist/fbgemm_gpu +======= + pip_install --no-use-pep517 "git+https://github.com/pytorch/torchrec.git@${torchrec_commit}" + pip_uninstall fbgemm-gpu-nightly + + pip_install tabulate # needed for newer fbgemm + pip_install patchelf # needed for rocm fbgemm + git clone --recursive https://github.com/pytorch/fbgemm + pushd fbgemm/fbgemm_gpu + git checkout "${fbgemm_commit}" + python setup.py install \ + --package_variant=rocm \ + -DHIP_ROOT_DIR="${ROCM_PATH}" \ + -DCMAKE_C_FLAGS="-DTORCH_USE_HIP_DSA" \ + -DCMAKE_CXX_FLAGS="-DTORCH_USE_HIP_DSA" + popd + rm -rf fbgemm + else + # See https://github.com/pytorch/pytorch/issues/106971 + CUDA_PATH=/usr/local/cuda-12.1 pip_install --no-use-pep517 "git+https://github.com/pytorch/FBGEMM.git@${fbgemm_commit}#egg=fbgemm-gpu&subdirectory=fbgemm_gpu" + pip_install --no-use-pep517 "git+https://github.com/pytorch/torchrec.git@${torchrec_commit}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fi } function clone_pytorch_xla() { if [[ ! -d ./xla ]]; then +<<<<<<< HEAD git clone --recursive -b r2.9 https://github.com/pytorch/xla.git +======= + git clone --recursive -b r2.8 https://github.com/pytorch/xla.git +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pushd xla # pin the xla hash so that we don't get broken by changes to xla git checkout "$(cat ../.github/ci_commit_pins/xla.txt)" @@ -294,10 +346,41 @@ function clone_pytorch_xla() { fi } +<<<<<<< HEAD function install_torchao() { local commit commit=$(get_pinned_commit torchao) pip_build_and_install "git+https://github.com/pytorch/ao.git@${commit}" dist/ao +======= +function checkout_install_torchbench() { + local commit + commit=$(get_pinned_commit torchbench) + git clone https://github.com/pytorch/benchmark torchbench + pushd torchbench + git checkout "$commit" + + if [ "$1" ]; then + python install.py --continue_on_fail models "$@" + else + # Occasionally the installation may fail on one model but it is ok to continue + # to install and test other models + python install.py --continue_on_fail + fi + + # TODO (huydhn): transformers-4.44.2 added by https://github.com/pytorch/benchmark/pull/2488 + # is regressing speedup metric. This needs to be investigated further + pip install transformers==4.38.1 + + echo "Print all dependencies after TorchBench is installed" + python -mpip freeze + popd +} + +function install_torchao() { + local commit + commit=$(get_pinned_commit torchao) + pip_install --no-use-pep517 "git+https://github.com/pytorch/ao.git@${commit}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } function print_sccache_stats() { diff --git a/.ci/pytorch/cpp_doc_push_script.sh b/.ci/pytorch/cpp_doc_push_script.sh index f085fa78bebe9..536966a992503 100755 --- a/.ci/pytorch/cpp_doc_push_script.sh +++ b/.ci/pytorch/cpp_doc_push_script.sh @@ -58,7 +58,11 @@ time python tools/setup_helpers/generate_code.py \ # Build the docs pushd docs/cpp +<<<<<<< HEAD time make VERBOSE=1 html +======= +time make VERBOSE=1 html -j +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) popd popd diff --git a/.ci/pytorch/create_test_cert.py b/.ci/pytorch/create_test_cert.py new file mode 100644 index 0000000000000..f2be0c13227d1 --- /dev/null +++ b/.ci/pytorch/create_test_cert.py @@ -0,0 +1,123 @@ +from datetime import datetime, timedelta, timezone +from tempfile import mkdtemp + +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.x509.oid import NameOID + + +temp_dir = mkdtemp() +print(temp_dir) + + +def genrsa(path): + key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + ) + with open(path, "wb") as f: + f.write( + key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + ) + return key + + +def create_cert(path, C, ST, L, O, key): + subject = issuer = x509.Name( + [ + x509.NameAttribute(NameOID.COUNTRY_NAME, C), + x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, ST), + x509.NameAttribute(NameOID.LOCALITY_NAME, L), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, O), + ] + ) + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after( + # Our certificate will be valid for 10 days + datetime.now(timezone.utc) + timedelta(days=10) + ) + .add_extension( + x509.BasicConstraints(ca=True, path_length=None), + critical=True, + ) + .sign(key, hashes.SHA256()) + ) + # Write our certificate out to disk. + with open(path, "wb") as f: + f.write(cert.public_bytes(serialization.Encoding.PEM)) + return cert + + +def create_req(path, C, ST, L, O, key): + csr = ( + x509.CertificateSigningRequestBuilder() + .subject_name( + x509.Name( + [ + # Provide various details about who we are. + x509.NameAttribute(NameOID.COUNTRY_NAME, C), + x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, ST), + x509.NameAttribute(NameOID.LOCALITY_NAME, L), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, O), + ] + ) + ) + .sign(key, hashes.SHA256()) + ) + with open(path, "wb") as f: + f.write(csr.public_bytes(serialization.Encoding.PEM)) + return csr + + +def sign_certificate_request(path, csr_cert, ca_cert, private_ca_key): + cert = ( + x509.CertificateBuilder() + .subject_name(csr_cert.subject) + .issuer_name(ca_cert.subject) + .public_key(csr_cert.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after( + # Our certificate will be valid for 10 days + datetime.now(timezone.utc) + timedelta(days=10) + # Sign our certificate with our private key + ) + .sign(private_ca_key, hashes.SHA256()) + ) + with open(path, "wb") as f: + f.write(cert.public_bytes(serialization.Encoding.PEM)) + return cert + + +ca_key = genrsa(temp_dir + "/ca.key") +ca_cert = create_cert( + temp_dir + "/ca.pem", + "US", + "New York", + "New York", + "Gloo Certificate Authority", + ca_key, +) + +pkey = genrsa(temp_dir + "/pkey.key") +csr = create_req( + temp_dir + "/csr.csr", + "US", + "California", + "San Francisco", + "Gloo Testing Company", + pkey, +) + +cert = sign_certificate_request(temp_dir + "/cert.pem", csr, ca_cert, ca_key) diff --git a/.ci/pytorch/macos-test.sh b/.ci/pytorch/macos-test.sh index a859901191e03..ef1926cf8aaf6 100755 --- a/.ci/pytorch/macos-test.sh +++ b/.ci/pytorch/macos-test.sh @@ -157,6 +157,7 @@ test_jit_hooks() { assert_git_not_dirty } +<<<<<<< HEAD # Shellcheck doesn't like it when you pass no arguments to a function # that can take args. See https://www.shellcheck.net/wiki/SC2120 # shellcheck disable=SC2120 @@ -185,6 +186,8 @@ checkout_install_torchbench() { python -mpip freeze } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torchbench_setup_macos() { git clone --recursive https://github.com/pytorch/vision torchvision git clone --recursive https://github.com/pytorch/audio torchaudio @@ -195,7 +198,11 @@ torchbench_setup_macos() { git checkout "$(cat ../.github/ci_commit_pins/vision.txt)" git submodule update --init --recursive python setup.py clean +<<<<<<< HEAD python -m pip install -e . -v --no-build-isolation +======= + python setup.py develop +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) popd pushd torchaudio @@ -204,14 +211,26 @@ torchbench_setup_macos() { git submodule update --init --recursive python setup.py clean #TODO: Remove me, when figure out how to make TorchAudio find brew installed openmp +<<<<<<< HEAD USE_OPENMP=0 python -m pip install -e . -v --no-build-isolation popd +======= + USE_OPENMP=0 python setup.py develop + popd + + # Shellcheck doesn't like it when you pass no arguments to a function that can take args. See https://www.shellcheck.net/wiki/SC2120 + # shellcheck disable=SC2119,SC2120 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) checkout_install_torchbench } pip_benchmark_deps() { +<<<<<<< HEAD python -mpip install --no-input requests cython scikit-learn six +======= + python -mpip install --no-input astunparse requests cython scikit-learn +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } @@ -302,6 +321,7 @@ test_torchbench_smoketest() { fi done +<<<<<<< HEAD echo "Pytorch benchmark on mps device completed" } @@ -343,6 +363,8 @@ test_aoti_torchbench_smoketest() { PYTHONPATH="$(pwd)"/torchbench python benchmarks/dynamo/huggingface.py \ --accuracy --export-aot-inductor --inference --devices "$device" "$dtype_arg" \ --output "$TEST_REPORTS_DIR/aot_inductor_huggingface_${dtype}_inference_${device}_accuracy.csv" || true +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) echo "Pytorch benchmark on mps device completed" } @@ -391,8 +413,11 @@ elif [[ $TEST_CONFIG == *"perf_timm"* ]]; then test_timm_perf elif [[ $TEST_CONFIG == *"perf_smoketest"* ]]; then test_torchbench_smoketest "${SHARD_NUMBER}" +<<<<<<< HEAD elif [[ $TEST_CONFIG == *"aot_inductor_perf_smoketest"* ]]; then test_aoti_torchbench_smoketest "${SHARD_NUMBER}" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif [[ $TEST_CONFIG == *"mps"* ]]; then test_python_mps elif [[ $NUM_TEST_SHARDS -gt 1 ]]; then diff --git a/.ci/pytorch/multigpu-test.sh b/.ci/pytorch/multigpu-test.sh index 219463f318dbd..0d14202644ab2 100755 --- a/.ci/pytorch/multigpu-test.sh +++ b/.ci/pytorch/multigpu-test.sh @@ -45,7 +45,10 @@ if [[ "${SHARD_NUMBER:-2}" == "2" ]]; then # DTensor tests time python test/run_test.py --verbose -i distributed/tensor/test_random_ops time python test/run_test.py --verbose -i distributed/tensor/test_dtensor_compile +<<<<<<< HEAD time python test/run_test.py --verbose -i distributed/tensor/test_utils.py +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # DeviceMesh test time python test/run_test.py --verbose -i distributed/test_device_mesh diff --git a/.ci/pytorch/run_glootls_test.sh b/.ci/pytorch/run_glootls_test.sh new file mode 100755 index 0000000000000..cd17b269fe6a9 --- /dev/null +++ b/.ci/pytorch/run_glootls_test.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +CREATE_TEST_CERT="$(dirname "${BASH_SOURCE[0]}")/create_test_cert.py" +TMP_CERT_DIR=$(python "$CREATE_TEST_CERT") + +openssl verify -CAfile "${TMP_CERT_DIR}/ca.pem" "${TMP_CERT_DIR}/cert.pem" + +export GLOO_DEVICE_TRANSPORT=TCP_TLS +export GLOO_DEVICE_TRANSPORT_TCP_TLS_PKEY=${TMP_CERT_DIR}/pkey.key +export GLOO_DEVICE_TRANSPORT_TCP_TLS_CERT=${TMP_CERT_DIR}/cert.pem +export GLOO_DEVICE_TRANSPORT_TCP_TLS_CA_FILE=${TMP_CERT_DIR}/ca.pem + +time python test/run_test.py --include distributed/test_c10d_gloo --verbose -- ProcessGroupGlooTest + +unset GLOO_DEVICE_TRANSPORT +unset GLOO_DEVICE_TRANSPORT_TCP_TLS_PKEY +unset GLOO_DEVICE_TRANSPORT_TCP_TLS_CERT +unset GLOO_DEVICE_TRANSPORT_TCP_TLS_CA_FILE diff --git a/.ci/pytorch/run_tests.sh b/.ci/pytorch/run_tests.sh index f5ed90deef249..97ae8d22c7917 100755 --- a/.ci/pytorch/run_tests.sh +++ b/.ci/pytorch/run_tests.sh @@ -74,6 +74,7 @@ else fi # Environment initialization +<<<<<<< HEAD retry pip install -qUr requirements-build.txt if [[ "$(uname)" == Darwin ]]; then # Install the testing dependencies @@ -81,6 +82,14 @@ if [[ "$(uname)" == Darwin ]]; then else retry pip install -qr requirements.txt || true retry pip install -q hypothesis protobuf pytest || true +======= +if [[ "$(uname)" == Darwin ]]; then + # Install the testing dependencies + retry pip install -q future hypothesis ${NUMPY_PACKAGE} ${PROTOBUF_PACKAGE} pytest setuptools six typing_extensions pyyaml +else + retry pip install -qr requirements.txt || true + retry pip install -q hypothesis protobuf pytest setuptools || true +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) numpy_ver=1.15 case "$(python --version 2>&1)" in *2* | *3.5* | *3.6*) diff --git a/.ci/pytorch/smoke_test/check_binary_symbols.py b/.ci/pytorch/smoke_test/check_binary_symbols.py index b0c607659c72d..1dd56236e2619 100755 --- a/.ci/pytorch/smoke_test/check_binary_symbols.py +++ b/.ci/pytorch/smoke_test/check_binary_symbols.py @@ -32,9 +32,12 @@ "torch::", ) +<<<<<<< HEAD # Patterns for detecting statically linked libstdc++ symbols STATICALLY_LINKED_CXX11_ABI = [re.compile(r".*recursive_directory_iterator.*")] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _apply_libtorch_symbols(symbols): return [ @@ -56,17 +59,24 @@ def get_symbols(lib: str) -> list[tuple[str, str, str]]: return [x.split(" ", 2) for x in lines.decode("latin1").split("\n")[:-1]] +<<<<<<< HEAD def grep_symbols( lib: str, patterns: list[Any], symbol_type: str | None = None ) -> list[str]: +======= +def grep_symbols(lib: str, patterns: list[Any]) -> list[str]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _grep_symbols( symbols: list[tuple[str, str, str]], patterns: list[Any] ) -> list[str]: rc = [] for _s_addr, _s_type, s_name in symbols: +<<<<<<< HEAD # Filter by symbol type if specified if symbol_type and _s_type != symbol_type: continue +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for pattern in patterns: if pattern.match(s_name): rc.append(s_name) @@ -88,6 +98,7 @@ def _get_symbols_chunk(i): return functools.reduce(list.__add__, (x.result() for x in tasks), []) +<<<<<<< HEAD def check_lib_statically_linked_libstdc_cxx_abi_symbols(lib: str) -> None: cxx11_statically_linked_symbols = grep_symbols( lib, STATICALLY_LINKED_CXX11_ABI, symbol_type="T" @@ -100,6 +111,8 @@ def check_lib_statically_linked_libstdc_cxx_abi_symbols(lib: str) -> None: ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def check_lib_symbols_for_abi_correctness(lib: str) -> None: print(f"lib: {lib}") cxx11_symbols = grep_symbols(lib, LIBTORCH_CXX11_PATTERNS) @@ -127,7 +140,10 @@ def main() -> None: libtorch_cpu_path = str(install_root / "lib" / "libtorch_cpu.so") check_lib_symbols_for_abi_correctness(libtorch_cpu_path) +<<<<<<< HEAD check_lib_statically_linked_libstdc_cxx_abi_symbols(libtorch_cpu_path) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": diff --git a/.ci/pytorch/smoke_test/smoke_test.py b/.ci/pytorch/smoke_test/smoke_test.py index 305ad15d98e7e..8539d262389ac 100644 --- a/.ci/pytorch/smoke_test/smoke_test.py +++ b/.ci/pytorch/smoke_test/smoke_test.py @@ -385,6 +385,7 @@ def foo(x: torch.Tensor) -> torch.Tensor: x_pt2 = torch.compile(model, mode="max-autotune")(x) +<<<<<<< HEAD def smoke_test_nvshmem() -> None: if not torch.cuda.is_available(): print("CUDA is not available, skipping NVSHMEM test") @@ -408,6 +409,8 @@ def smoke_test_nvshmem() -> None: print(f"NVSHMEM available at run time: {_is_nvshmem_available()}") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def smoke_test_modules(): cwd = os.getcwd() for module in MODULES: @@ -502,8 +505,11 @@ def main() -> None: options.pypi_pkg_check, ) +<<<<<<< HEAD smoke_test_nvshmem() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": main() diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index e8c5b3fc56af2..33182e0fe02bd 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -11,8 +11,11 @@ export TERM=vt100 # shellcheck source=./common.sh source "$(dirname "${BASH_SOURCE[0]}")/common.sh" +<<<<<<< HEAD # shellcheck source=./common-build.sh source "$(dirname "${BASH_SOURCE[0]}")/common-build.sh" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Do not change workspace permissions for ROCm and s390x CI jobs # as it can leave workspace with bad permissions for cancelled jobs @@ -32,6 +35,7 @@ if [[ "$BUILD_ENVIRONMENT" != *rocm* && "$BUILD_ENVIRONMENT" != *s390x* && -d /v git config --global --add safe.directory /var/lib/jenkins/workspace fi +<<<<<<< HEAD # Patch numba to avoid CUDA-13 crash, see https://github.com/pytorch/pytorch/issues/162878 NUMBA_CUDA_DIR=$(python -c "import os;import numba.cuda; print(os.path.dirname(numba.cuda.__file__))" 2>/dev/null || true) @@ -42,6 +46,8 @@ if [ -n "$NUMBA_CUDA_DIR" ]; then popd fi +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) echo "Environment variables:" env @@ -101,7 +107,10 @@ if [[ "$BUILD_ENVIRONMENT" == *clang9* || "$BUILD_ENVIRONMENT" == *xpu* ]]; then export VALGRIND=OFF fi +<<<<<<< HEAD detect_cuda_arch +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if [[ "$BUILD_ENVIRONMENT" == *s390x* ]]; then # There are additional warnings on s390x, maybe due to newer gcc. @@ -176,6 +185,11 @@ elif [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then export PYTORCH_TESTING_DEVICE_ONLY_FOR="xpu" # setting PYTHON_TEST_EXTRA_OPTION export PYTHON_TEST_EXTRA_OPTION="--xpu" +<<<<<<< HEAD +======= + # Disable sccache for xpu test due to flaky issue https://github.com/pytorch/pytorch/issues/143585 + sudo rm -rf /opt/cache +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fi if [[ "$TEST_CONFIG" == *crossref* ]]; then @@ -300,12 +314,15 @@ elif [[ $TEST_CONFIG == 'nogpu_AVX512' ]]; then export ATEN_CPU_CAPABILITY=avx2 fi +<<<<<<< HEAD if [[ "${TEST_CONFIG}" == "legacy_nvidia_driver" ]]; then # Make sure that CUDA can be initialized (cd test && python -c "import torch; torch.rand(2, 2, device='cuda')") export USE_LEGACY_DRIVER=1 fi +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) test_python_legacy_jit() { time python test/run_test.py --include test_jit_legacy test_jit_fuser_legacy --verbose assert_git_not_dirty @@ -344,6 +361,7 @@ test_h100_distributed() { time python test/run_test.py --include distributed/_composable/test_composability/test_pp_composability.py $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running # This test requires multicast support time python test/run_test.py --include distributed/_composable/fsdp/test_fully_shard_comm.py -k TestFullyShardAllocFromPG $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running +<<<<<<< HEAD assert_git_not_dirty } @@ -362,6 +380,14 @@ test_h100_cutlass_backend() { TORCHINDUCTOR_CUTLASS_DIR=$(realpath "./third_party/cutlass") python test/run_test.py --include inductor/test_cutlass_evt $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running } +======= + # symmetric memory test + time python test/run_test.py --include distributed/test_symmetric_memory.py $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running + time python test/run_test.py --include distributed/test_nvshmem.py $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running + assert_git_not_dirty +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) test_lazy_tensor_meta_reference_disabled() { export TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE=1 echo "Testing lazy tensor operations without meta reference" @@ -404,10 +430,16 @@ test_einops() { test_inductor_distributed() { # Smuggle a few multi-gpu tests here so that we don't have to request another large node echo "Testing multi_gpu tests in test_torchinductor" +<<<<<<< HEAD python test/run_test.py -i inductor/test_aot_inductor.py -k test_replicate_on_devices --verbose python test/run_test.py -i inductor/test_aot_inductor.py -k test_on_gpu_device1 --verbose python test/run_test.py -i inductor/test_aot_inductor.py -k test_non_default_gpu_device --verbose python test/run_test.py -i inductor/test_aot_inductor.py -k test_load_package_multiple_gpus --verbose +======= + python test/run_test.py -i inductor/test_torchinductor.py -k test_multi_gpu --verbose + python test/run_test.py -i inductor/test_aot_inductor.py -k test_non_default_cuda_device --verbose + python test/run_test.py -i inductor/test_aot_inductor.py -k test_replicate_on_devices --verbose +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) python test/run_test.py -i distributed/test_c10d_functional_native.py --verbose python test/run_test.py -i distributed/tensor/test_dtensor_compile.py --verbose python test/run_test.py -i distributed/tensor/parallel/test_micro_pipeline_tp.py --verbose @@ -459,6 +491,7 @@ test_inductor_aoti() { python3 tools/amd_build/build_amd.py fi if [[ "$BUILD_ENVIRONMENT" == *sm86* ]]; then +<<<<<<< HEAD BUILD_COMMAND=(TORCH_CUDA_ARCH_LIST=8.6 USE_FLASH_ATTENTION=OFF python -m pip install --no-build-isolation -v -e .) # TODO: Replace me completely, as one should not use conda libstdc++, nor need special path to TORCH_LIB TEST_ENVS=(CPP_TESTS_DIR="${BUILD_BIN_DIR}" LD_LIBRARY_PATH="/opt/conda/envs/py_3.10/lib:${TORCH_LIB_DIR}:${LD_LIBRARY_PATH}") @@ -474,6 +507,16 @@ test_inductor_aoti() { /usr/bin/env CMAKE_FRESH=1 BUILD_AOT_INDUCTOR_TEST=1 "${BUILD_COMMAND[@]}" /usr/bin/env "${TEST_ENVS[@]}" python test/run_test.py --cpp --verbose -i cpp/test_aoti_abi_check cpp/test_aoti_inference cpp/test_vec_half_AVX2 -dist=loadfile +======= + BUILD_AOT_INDUCTOR_TEST=1 TORCH_CUDA_ARCH_LIST=8.6 USE_FLASH_ATTENTION=OFF python setup.py develop + # TODO: Replace me completely, as one should not use conda libstdc++, nor need special path to TORCH_LIB + LD_LIBRARY_PATH=/opt/conda/envs/py_3.10/lib/:${TORCH_LIB_DIR}:$LD_LIBRARY_PATH + CPP_TESTS_DIR="${BUILD_BIN_DIR}" python test/run_test.py --cpp --verbose -i cpp/test_aoti_abi_check cpp/test_aoti_inference -dist=loadfile + else + BUILD_AOT_INDUCTOR_TEST=1 python setup.py develop + CPP_TESTS_DIR="${BUILD_BIN_DIR}" LD_LIBRARY_PATH="${TORCH_LIB_DIR}" python test/run_test.py --cpp --verbose -i cpp/test_aoti_abi_check cpp/test_aoti_inference -dist=loadfile + fi +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } test_inductor_cpp_wrapper_shard() { @@ -486,6 +529,7 @@ test_inductor_cpp_wrapper_shard() { TEST_REPORTS_DIR=$(pwd)/test/test-reports mkdir -p "$TEST_REPORTS_DIR" +<<<<<<< HEAD # Run certain inductor unit tests with cpp wrapper. In the end state, we # should be able to run all the inductor unit tests with cpp_wrapper. # @@ -513,6 +557,48 @@ test_inductor_cpp_wrapper_shard() { -k 'xpu' \ --shard "$1" "$NUM_TEST_SHARDS" \ --verbose +======= + if [[ "$1" -eq "2" ]]; then + # For now, manually put the opinfo tests in shard 2, and all other tests in + # shard 1. Run all CPU tests, as well as specific GPU tests triggering past + # bugs, for now. + python test/run_test.py \ + --include inductor/test_torchinductor_opinfo \ + -k 'linalg or to_sparse or TestInductorOpInfoCPU' \ + --verbose + exit + fi + + # Run certain inductor unit tests with cpp wrapper. In the end state, we + # should be able to run all the inductor unit tests with cpp_wrapper. + python test/run_test.py \ + --include inductor/test_torchinductor inductor/test_max_autotune inductor/test_cpu_repro \ + --verbose + python test/run_test.py --inductor --include test_torch -k 'take' --verbose + + # Run inductor benchmark tests with cpp wrapper. + # Skip benchmark tests if it's in rerun-disabled-mode. + if [[ "${PYTORCH_TEST_RERUN_DISABLED_TESTS}" == "1" ]]; then + echo "skip dynamo benchmark tests for rerun-disabled-test" + else + echo "run dynamo benchmark tests with cpp wrapper" + python benchmarks/dynamo/timm_models.py --device cuda --accuracy --amp \ + --training --inductor --disable-cudagraphs --only vit_base_patch16_224 \ + --output "$TEST_REPORTS_DIR/inductor_cpp_wrapper_training.csv" + python benchmarks/dynamo/check_accuracy.py \ + --actual "$TEST_REPORTS_DIR/inductor_cpp_wrapper_training.csv" \ + --expected "benchmarks/dynamo/ci_expected_accuracy/${MAYBE_ROCM}inductor_timm_training.csv" + + python benchmarks/dynamo/torchbench.py --device cuda --accuracy \ + --bfloat16 --inference --inductor --only hf_T5 --output "$TEST_REPORTS_DIR/inductor_cpp_wrapper_inference.csv" + python benchmarks/dynamo/torchbench.py --device cuda --accuracy \ + --bfloat16 --inference --inductor --only llama --output "$TEST_REPORTS_DIR/inductor_cpp_wrapper_inference.csv" + python benchmarks/dynamo/torchbench.py --device cuda --accuracy \ + --bfloat16 --inference --inductor --only moco --output "$TEST_REPORTS_DIR/inductor_cpp_wrapper_inference.csv" + python benchmarks/dynamo/check_accuracy.py \ + --actual "$TEST_REPORTS_DIR/inductor_cpp_wrapper_inference.csv" \ + --expected "benchmarks/dynamo/ci_expected_accuracy/${MAYBE_ROCM}inductor_torchbench_inference.csv" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fi } @@ -634,8 +720,13 @@ test_perf_for_dashboard() { local device=cuda if [[ "${TEST_CONFIG}" == *cpu* ]]; then +<<<<<<< HEAD if [[ "${TEST_CONFIG}" == *cpu_x86_zen* ]]; then device=cpu_x86_zen +======= + if [[ "${TEST_CONFIG}" == *zen_cpu_x86* ]]; then + device=zen_cpu_x86 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif [[ "${TEST_CONFIG}" == *cpu_x86* ]]; then device=cpu_x86 elif [[ "${TEST_CONFIG}" == *cpu_aarch64* ]]; then @@ -646,19 +737,26 @@ test_perf_for_dashboard() { device=cuda_a10g elif [[ "${TEST_CONFIG}" == *h100* ]]; then device=cuda_h100 +<<<<<<< HEAD elif [[ "${TEST_CONFIG}" == *b200* ]]; then device=cuda_b200 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif [[ "${TEST_CONFIG}" == *rocm* ]]; then device=rocm fi for mode in "${modes[@]}"; do if [[ "$mode" == "inference" ]]; then +<<<<<<< HEAD if [[ "$device" == "cpu_x86" ]]; then dtype=amp else dtype=bfloat16 fi +======= + dtype=bfloat16 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif [[ "$mode" == "training" ]]; then dtype=amp fi @@ -670,10 +768,13 @@ test_perf_for_dashboard() { target_flag+=( --no-translation-validation) fi +<<<<<<< HEAD if [[ "$DASHBOARD_TAG" == *freezing-true* ]]; then target_flag+=( --freezing) fi +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if [[ "$DASHBOARD_TAG" == *default-true* ]]; then $TASKSET python "benchmarks/dynamo/$suite.py" \ "${target_flag[@]}" --"$mode" --"$dtype" --backend "$backend" --disable-cudagraphs "$@" \ @@ -822,6 +923,7 @@ test_dynamo_benchmark() { if [[ "${TEST_CONFIG}" == *perf_compare* ]]; then test_single_dynamo_benchmark "training" "$suite" "$shard_id" --training --amp "$@" elif [[ "${TEST_CONFIG}" == *perf* ]]; then +<<<<<<< HEAD # TODO (huydhn): Just smoke test some sample models if [[ "${TEST_CONFIG}" == *b200* ]]; then if [[ "${suite}" == "huggingface" ]]; then @@ -832,6 +934,8 @@ test_dynamo_benchmark() { export TORCHBENCH_ONLY_MODELS="hf_Bert" fi fi +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) test_single_dynamo_benchmark "dashboard" "$suite" "$shard_id" "$@" else if [[ "${TEST_CONFIG}" == *cpu* ]]; then @@ -959,6 +1063,15 @@ test_torchbench_gcp_smoketest(){ popd } +<<<<<<< HEAD +======= +test_python_gloo_with_tls() { + source "$(dirname "${BASH_SOURCE[0]}")/run_glootls_test.sh" + assert_git_not_dirty +} + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) test_aten() { # Test ATen # The following test(s) of ATen have already been skipped by caffe2 in rocm environment: @@ -1005,8 +1118,11 @@ test_without_numpy() { if [[ "${TEST_CONFIG}" == *dynamo_wrapped* ]]; then python -c "import sys;sys.path.insert(0, 'fake_numpy');import torch;torch.compile(lambda x:print(x))('Hello World')" fi +<<<<<<< HEAD # Regression test for https://github.com/pytorch/pytorch/pull/157734 (torch.onnx should be importable without numpy) python -c "import sys;sys.path.insert(0, 'fake_numpy');import torch; import torch.onnx" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) popd } @@ -1070,10 +1186,26 @@ test_libtorch_api() { mkdir -p $TEST_REPORTS_DIR OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="${MNIST_DIR}" "$TORCH_BIN_DIR"/test_api --gtest_filter='-IMethodTest.*' --gtest_output=xml:$TEST_REPORTS_DIR/test_api.xml +<<<<<<< HEAD +======= + "$TORCH_BIN_DIR"/test_tensorexpr --gtest_output=xml:$TEST_REPORTS_DIR/test_tensorexpr.xml +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else # Exclude IMethodTest that relies on torch::deploy, which will instead be ran in test_deploy OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="${MNIST_DIR}" python test/run_test.py --cpp --verbose -i cpp/test_api -k "not IMethodTest" +<<<<<<< HEAD +======= + # On s390x, pytorch is built without llvm. + # Even if it would be built with llvm, llvm currently doesn't support used features on s390x and + # test fails with errors like: + # JIT session error: Unsupported target machine architecture in ELF object pytorch-jitted-objectbuffer + # unknown file: Failure + # C++ exception with description "valOrErr INTERNAL ASSERT FAILED at "/var/lib/jenkins/workspace/torch/csrc/jit/tensorexpr/llvm_jit.h":34, please report a bug to PyTorch. Unexpected failure in LLVM JIT: Failed to materialize symbols: { (main, { func }) } + if [[ "${BUILD_ENVIRONMENT}" != *s390x* ]]; then + python test/run_test.py --cpp --verbose -i cpp/test_tensorexpr + fi +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fi # quantization is not fully supported on s390x yet @@ -1341,6 +1473,7 @@ EOF # Step 2. Make sure that the public API test "test_correct_module_names" fails when an existing # file is modified to introduce an invalid public API function. +<<<<<<< HEAD # The filepath here must not have __all__ defined in it, otherwise the test will pass. # If your PR introduces __all__ to torch/cuda/streams.py please point this to another file # that does not have __all__ defined. @@ -1348,6 +1481,12 @@ EOF cp -v "${EXISTING_FILEPATH}" "${EXISTING_FILEPATH}.orig" echo "${BAD_PUBLIC_FUNC}" >> "${EXISTING_FILEPATH}" invalid_api="torch.cuda.streams.new_public_func" +======= + EXISTING_FILEPATH="${TORCH_INSTALL_DIR}/nn/parameter.py" + cp -v "${EXISTING_FILEPATH}" "${EXISTING_FILEPATH}.orig" + echo "${BAD_PUBLIC_FUNC}" >> "${EXISTING_FILEPATH}" + invalid_api="torch.nn.parameter.new_public_func" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) echo "Appended an invalid public API function to existing file ${EXISTING_FILEPATH}..." check_public_api_test_fails \ @@ -1581,7 +1720,11 @@ test_executorch() { test_linux_aarch64() { python test/run_test.py --include test_modules test_mkldnn test_mkldnn_fusion test_openmp test_torch test_dynamic_shapes \ test_transformers test_multiprocessing test_numpy_interop test_autograd test_binary_ufuncs test_complex test_spectral_ops \ +<<<<<<< HEAD test_foreach test_reductions test_unary_ufuncs test_tensor_creation_ops test_ops \ +======= + test_foreach test_reductions test_unary_ufuncs test_tensor_creation_ops test_ops test_cpp_extensions_open_device_registration \ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) --shard "$SHARD_NUMBER" "$NUM_TEST_SHARDS" --verbose # Dynamo tests @@ -1611,7 +1754,11 @@ test_operator_benchmark() { test_inductor_set_cpu_affinity cd benchmarks/operator_benchmark/pt_extension +<<<<<<< HEAD python -m pip install . +======= + python setup.py install +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cd "${TEST_DIR}"/benchmarks/operator_benchmark $TASKSET python -m benchmark_all_test --device "$1" --tag-filter "$2" \ @@ -1624,6 +1771,7 @@ test_operator_benchmark() { --expected "expected_ci_operator_benchmark_eager_float32_cpu.csv" } +<<<<<<< HEAD test_operator_microbenchmark() { TEST_REPORTS_DIR=$(pwd)/test/test-reports mkdir -p "$TEST_REPORTS_DIR" @@ -1643,6 +1791,8 @@ test_operator_microbenchmark() { --benchmark-name "PyTorch operator microbenchmark" done } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if ! [[ "${BUILD_ENVIRONMENT}" == *libtorch* || "${BUILD_ENVIRONMENT}" == *-bazel-* ]]; then (cd test && python -c "import torch; print(torch.__config__.show())") @@ -1650,6 +1800,7 @@ if ! [[ "${BUILD_ENVIRONMENT}" == *libtorch* || "${BUILD_ENVIRONMENT}" == *-baze fi if [[ "${TEST_CONFIG}" == *numpy_2* ]]; then # Install numpy-2.0.2 and compatible scipy & numba versions +<<<<<<< HEAD # Force re-install of pandas to avoid error where pandas checks numpy version from initial install and fails upon import TMP_PANDAS_VERSION=$(python -c "import pandas; print(pandas.__version__)" 2>/dev/null) if [ -n "$TMP_PANDAS_VERSION" ]; then @@ -1657,6 +1808,9 @@ if [[ "${TEST_CONFIG}" == *numpy_2* ]]; then else python -m pip install --pre numpy==2.0.2 scipy==1.13.1 numba==0.60.0 fi +======= + python -mpip install --pre numpy==2.0.2 scipy==1.13.1 numba==0.60.0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) python test/run_test.py --include dynamo/test_functions.py dynamo/test_unspec.py test_binary_ufuncs.py test_fake_tensor.py test_linalg.py test_numpy_interop.py test_tensor_creation_ops.py test_torch.py torch_np/test_basic.py elif [[ "${BUILD_ENVIRONMENT}" == *aarch64* && "${TEST_CONFIG}" != *perf_cpu_aarch64* ]]; then test_linux_aarch64 @@ -1667,10 +1821,13 @@ elif [[ "${TEST_CONFIG}" == *xla* ]]; then install_torchvision build_xla test_xla +<<<<<<< HEAD elif [[ "$TEST_CONFIG" == *vllm* ]]; then echo "vLLM CI uses TORCH_CUDA_ARCH_LIST: $TORCH_CUDA_ARCH_LIST" (cd .ci/lumen_cli && python -m pip install -e .) python -m cli.run test external vllm --test-plan "$TEST_CONFIG" --shard-id "$SHARD_NUMBER" --num-shards "$NUM_TEST_SHARDS" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif [[ "${TEST_CONFIG}" == *executorch* ]]; then test_executorch elif [[ "$TEST_CONFIG" == 'jit_legacy' ]]; then @@ -1697,8 +1854,11 @@ elif [[ "${TEST_CONFIG}" == *operator_benchmark* ]]; then test_operator_benchmark cpu ${TEST_MODE} fi +<<<<<<< HEAD elif [[ "${TEST_CONFIG}" == *operator_microbenchmark* ]]; then test_operator_microbenchmark +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then test_inductor_distributed elif [[ "${TEST_CONFIG}" == *inductor-halide* ]]; then @@ -1716,6 +1876,7 @@ elif [[ "${TEST_CONFIG}" == *timm* ]]; then id=$((SHARD_NUMBER-1)) test_dynamo_benchmark timm_models "$id" elif [[ "${TEST_CONFIG}" == cachebench ]]; then +<<<<<<< HEAD install_torchaudio install_torchvision PYTHONPATH=/torchbench test_cachebench @@ -1726,21 +1887,56 @@ elif [[ "${TEST_CONFIG}" == verify_cachebench ]]; then elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then install_torchaudio install_torchvision +======= + install_torchaudio cuda + install_torchvision + checkout_install_torchbench nanogpt BERT_pytorch resnet50 hf_T5 llama moco + PYTHONPATH=$(pwd)/torchbench test_cachebench +elif [[ "${TEST_CONFIG}" == verify_cachebench ]]; then + install_torchaudio cpu + install_torchvision + checkout_install_torchbench nanogpt + PYTHONPATH=$(pwd)/torchbench test_verify_cachebench +elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then + if [[ "${TEST_CONFIG}" == *cpu* ]]; then + install_torchaudio cpu + else + install_torchaudio cuda + fi + install_torchvision + TORCH_CUDA_ARCH_LIST="8.0;8.6" install_torchao +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) id=$((SHARD_NUMBER-1)) # https://github.com/opencv/opencv-python/issues/885 pip_install opencv-python==4.8.0.74 if [[ "${TEST_CONFIG}" == *inductor_torchbench_smoketest_perf* ]]; then +<<<<<<< HEAD PYTHONPATH=/torchbench test_inductor_torchbench_smoketest_perf elif [[ "${TEST_CONFIG}" == *inductor_torchbench_cpu_smoketest_perf* ]]; then PYTHONPATH=/torchbench test_inductor_torchbench_cpu_smoketest_perf elif [[ "${TEST_CONFIG}" == *torchbench_gcp_smoketest* ]]; then TORCHBENCHPATH=/torchbench test_torchbench_gcp_smoketest else +======= + checkout_install_torchbench hf_Bert hf_Albert timm_vision_transformer + PYTHONPATH=$(pwd)/torchbench test_inductor_torchbench_smoketest_perf + elif [[ "${TEST_CONFIG}" == *inductor_torchbench_cpu_smoketest_perf* ]]; then + checkout_install_torchbench timm_vision_transformer phlippe_densenet basic_gnn_edgecnn \ + llama_v2_7b_16h resnet50 timm_efficientnet mobilenet_v3_large timm_resnest \ + functorch_maml_omniglot yolov3 mobilenet_v2 resnext50_32x4d densenet121 mnasnet1_0 + PYTHONPATH=$(pwd)/torchbench test_inductor_torchbench_cpu_smoketest_perf + elif [[ "${TEST_CONFIG}" == *torchbench_gcp_smoketest* ]]; then + checkout_install_torchbench + TORCHBENCHPATH=$(pwd)/torchbench test_torchbench_gcp_smoketest + else + checkout_install_torchbench +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Do this after checkout_install_torchbench to ensure we clobber any # nightlies that torchbench may pull in if [[ "${TEST_CONFIG}" != *cpu* ]]; then install_torchrec_and_fbgemm fi +<<<<<<< HEAD PYTHONPATH=/torchbench test_dynamo_benchmark torchbench "$id" fi elif [[ "${TEST_CONFIG}" == *inductor_cpp_wrapper* ]]; then @@ -1752,6 +1948,24 @@ elif [[ "${TEST_CONFIG}" == *inductor_cpp_wrapper* ]]; then elif [[ "${TEST_CONFIG}" == *inductor* ]]; then install_torchvision test_inductor_shard "${SHARD_NUMBER}" +======= + PYTHONPATH=$(pwd)/torchbench test_dynamo_benchmark torchbench "$id" + fi +elif [[ "${TEST_CONFIG}" == *inductor_cpp_wrapper* ]]; then + install_torchaudio cuda + install_torchvision + checkout_install_torchbench hf_T5 llama moco + PYTHONPATH=$(pwd)/torchbench test_inductor_cpp_wrapper_shard "$SHARD_NUMBER" + test_inductor_aoti +elif [[ "${TEST_CONFIG}" == *inductor* ]]; then + install_torchvision + test_inductor_shard "${SHARD_NUMBER}" + if [[ "${SHARD_NUMBER}" == 1 ]]; then + if [[ "${BUILD_ENVIRONMENT}" != linux-jammy-py3.9-gcc11-build ]]; then + test_inductor_distributed + fi + fi +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif [[ "${TEST_CONFIG}" == *einops* ]]; then test_einops elif [[ "${TEST_CONFIG}" == *dynamo_wrapped* ]]; then @@ -1803,10 +2017,13 @@ elif [[ "${TEST_CONFIG}" == smoke ]]; then test_python_smoke elif [[ "${TEST_CONFIG}" == h100_distributed ]]; then test_h100_distributed +<<<<<<< HEAD elif [[ "${TEST_CONFIG}" == "h100-symm-mem" ]]; then test_h100_symm_mem elif [[ "${TEST_CONFIG}" == h100_cutlass_backend ]]; then test_h100_cutlass_backend +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else install_torchvision install_monkeytype diff --git a/.ci/pytorch/test_example_code/CMakeLists.txt b/.ci/pytorch/test_example_code/CMakeLists.txt index e87f37ae61fb4..688395d1615d9 100644 --- a/.ci/pytorch/test_example_code/CMakeLists.txt +++ b/.ci/pytorch/test_example_code/CMakeLists.txt @@ -16,7 +16,11 @@ target_link_libraries(simple-torch-test CUDA::cudart CUDA::cufft CUDA::cusparse find_library(CUDNN_LIBRARY NAMES cudnn) target_link_libraries(simple-torch-test ${CUDNN_LIBRARY} ) if(MSVC) +<<<<<<< HEAD file(GLOB TORCH_DLLS "$ENV{CUDA_PATH}/bin/cudnn64_8.dll" "$ENV{NVTOOLSEXT_PATH}/bin/x64/*.dll") +======= + file(GLOB TORCH_DLLS "$ENV{CUDA_PATH}/bin/cudnn64_8.dll") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) message("dlls to copy " ${TORCH_DLLS}) add_custom_command(TARGET simple-torch-test POST_BUILD diff --git a/.ci/pytorch/win-test-helpers/build_pytorch.bat b/.ci/pytorch/win-test-helpers/build_pytorch.bat index 67d1569221924..00f5f898e96f1 100644 --- a/.ci/pytorch/win-test-helpers/build_pytorch.bat +++ b/.ci/pytorch/win-test-helpers/build_pytorch.bat @@ -42,7 +42,11 @@ call choco upgrade -y cmake --no-progress --installargs 'ADD_CMAKE_TO_PATH=Syste if errorlevel 1 goto fail if not errorlevel 0 goto fail +<<<<<<< HEAD call pip install mkl==2024.2.0 mkl-static==2024.2.0 mkl-include==2024.2.0 +======= +call pip install mkl-include==2021.4.0 mkl-devel==2021.4.0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if errorlevel 1 goto fail if not errorlevel 0 goto fail @@ -61,10 +65,16 @@ if "%USE_XPU%"=="1" ( call "C:\Program Files (x86)\Intel\oneAPI\compiler\latest\env\vars.bat" call "C:\Program Files (x86)\Intel\oneAPI\ocloc\latest\env\vars.bat" if errorlevel 1 exit /b 1 +<<<<<<< HEAD :: Reduce build time SET TORCH_XPU_ARCH_LIST=bmg :: Re-setup python env for build call pip install -r requirements.txt +======= + :: Reduce build time. Only have MTL self-hosted runner now + SET TORCH_XPU_ARCH_LIST=xe-lpg + SET USE_KINETO=0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @echo on @@ -137,7 +147,11 @@ sccache --show-stats python -c "import os, glob; os.system('python -mpip install --no-index --no-deps ' + glob.glob('dist/*.whl')[0])" ( if "%BUILD_ENVIRONMENT%"=="" ( +<<<<<<< HEAD echo NOTE: To run `import torch`, please make sure to activate the conda environment by running `call %CONDA_ROOT_DIR%\Scripts\activate.bat %CONDA_ROOT_DIR%\envs\py_tmp` in Command Prompt before running Git Bash. +======= + echo NOTE: To run `import torch`, please make sure to activate the conda environment by running `call %CONDA_PARENT_DIR%\Miniconda3\Scripts\activate.bat %CONDA_PARENT_DIR%\Miniconda3` in Command Prompt before running Git Bash. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) else ( copy /Y "dist\*.whl" "%PYTORCH_FINAL_PACKAGE_DIR%" diff --git a/.ci/pytorch/win-test-helpers/installation-helpers/activate_miniconda3.bat b/.ci/pytorch/win-test-helpers/installation-helpers/activate_miniconda3.bat index abd2c8722b11d..09c66282f04d2 100644 --- a/.ci/pytorch/win-test-helpers/installation-helpers/activate_miniconda3.bat +++ b/.ci/pytorch/win-test-helpers/installation-helpers/activate_miniconda3.bat @@ -3,12 +3,20 @@ if "%BUILD_ENVIRONMENT%"=="" ( ) else ( set CONDA_PARENT_DIR=C:\Jenkins ) +<<<<<<< HEAD set CONDA_ROOT_DIR=%CONDA_PARENT_DIR%\Miniconda3 +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) :: Be conservative here when rolling out the new AMI with conda. This will try :: to install conda as before if it couldn't find the conda installation. This :: can be removed eventually after we gain enough confidence in the AMI +<<<<<<< HEAD if not exist %CONDA_ROOT_DIR% ( +======= +if not exist %CONDA_PARENT_DIR%\Miniconda3 ( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) set INSTALL_FRESH_CONDA=1 ) @@ -17,14 +25,22 @@ if "%INSTALL_FRESH_CONDA%"=="1" ( if errorlevel 1 exit /b if not errorlevel 0 exit /b +<<<<<<< HEAD %TMP_DIR_WIN%\Miniconda3-latest-Windows-x86_64.exe /InstallationType=JustMe /RegisterPython=0 /S /AddToPath=0 /D=%CONDA_ROOT_DIR% +======= + %TMP_DIR_WIN%\Miniconda3-latest-Windows-x86_64.exe /InstallationType=JustMe /RegisterPython=0 /S /AddToPath=0 /D=%CONDA_PARENT_DIR%\Miniconda3 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if errorlevel 1 exit /b if not errorlevel 0 exit /b ) :: Activate conda so that we can use its commands, i.e. conda, python, pip +<<<<<<< HEAD call %CONDA_ROOT_DIR%\Scripts\activate.bat %CONDA_ROOT_DIR% :: Activate conda so that we can use its commands, i.e. conda, python, pip call conda activate py_tmp call pip install -r .ci/docker/requirements-ci.txt +======= +call %CONDA_PARENT_DIR%\Miniconda3\Scripts\activate.bat %CONDA_PARENT_DIR%\Miniconda3 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.ci/pytorch/win-test-helpers/setup_pytorch_env.bat b/.ci/pytorch/win-test-helpers/setup_pytorch_env.bat index 3173582b06f45..928fc58113ca6 100644 --- a/.ci/pytorch/win-test-helpers/setup_pytorch_env.bat +++ b/.ci/pytorch/win-test-helpers/setup_pytorch_env.bat @@ -14,7 +14,11 @@ if not errorlevel 0 exit /b :: build\torch. Rather than changing all these references, making a copy of torch folder :: from conda to the current workspace is easier. The workspace will be cleaned up after :: the job anyway +<<<<<<< HEAD xcopy /s %CONDA_ROOT_DIR%\envs\py_tmp\Lib\site-packages\torch %TMP_DIR_WIN%\build\torch\ +======= +xcopy /s %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\torch %TMP_DIR_WIN%\build\torch\ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pushd . if "%VC_VERSION%" == "" ( diff --git a/.ci/pytorch/win-test.sh b/.ci/pytorch/win-test.sh index c96d5c331c9f8..e5d84b549e4f6 100755 --- a/.ci/pytorch/win-test.sh +++ b/.ci/pytorch/win-test.sh @@ -38,6 +38,7 @@ if [[ "$BUILD_ENVIRONMENT" == *cuda* ]]; then fi # TODO: Move both of them to Windows AMI +<<<<<<< HEAD python -m pip install tensorboard==2.13.0 protobuf==5.29.4 pytest-subtests==0.13.1 # Copied from https://github.com/pytorch/test-infra/blob/be01a40157c36cd5a48391fdf44a7bc3ebd4c7e3/aws/ami/windows/scripts/Installers/Install-Pip-Dependencies.ps1#L16 with some adjustments @@ -52,6 +53,15 @@ python -m pip install z3-solver==4.15.1.0 # Install tlparse for test\dynamo\test_structured_trace.py UTs. python -m pip install tlparse==0.4.0 +======= +python -m pip install pytest-rerunfailures==10.3 pytest-cpp==2.3.0 tensorboard==2.13.0 protobuf==5.29.4 pytest-subtests==0.13.1 + +# Install Z3 optional dependency for Windows builds. +python -m pip install z3-solver==4.12.2.0 + +# Install tlparse for test\dynamo\test_structured_trace.py UTs. +python -m pip install tlparse==0.3.30 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Install parameterized python -m pip install parameterized==0.8.1 diff --git a/.ci/pytorch/windows/cuda126.bat b/.ci/pytorch/windows/cuda126.bat index efb8cfec63e7e..2db616810ecb6 100644 --- a/.ci/pytorch/windows/cuda126.bat +++ b/.ci/pytorch/windows/cuda126.bat @@ -18,6 +18,7 @@ REM Check for optional components set USE_CUDA= set CMAKE_GENERATOR=Visual Studio 15 2017 Win64 +<<<<<<< HEAD IF "%NVTOOLSEXT_PATH%"=="" ( IF EXIST "C:\Program Files\NVIDIA Corporation\NvToolsExt\lib\x64\nvToolsExt64_1.lib" ( set NVTOOLSEXT_PATH=C:\Program Files\NVIDIA Corporation\NvToolsExt @@ -27,6 +28,8 @@ IF "%NVTOOLSEXT_PATH%"=="" ( ) ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) IF "%CUDA_PATH_V126%"=="" ( IF EXIST "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.6\bin\nvcc.exe" ( set "CUDA_PATH_V126=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.6" @@ -37,7 +40,11 @@ IF "%CUDA_PATH_V126%"=="" ( ) IF "%BUILD_VISION%" == "" ( +<<<<<<< HEAD set TORCH_CUDA_ARCH_LIST=5.0;6.0;6.1;7.0;7.5;8.0;8.6;9.0 +======= + set TORCH_CUDA_ARCH_LIST=6.1;7.0;7.5;8.0;8.6;9.0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) set TORCH_NVCC_FLAGS=-Xfatbin -compress-all ) ELSE ( set NVCC_FLAGS=-D__CUDA_NO_HALF_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_90,code=compute_90 diff --git a/.ci/pytorch/windows/cuda128.bat b/.ci/pytorch/windows/cuda128.bat index bbd349e2efb4b..0234ec324c039 100644 --- a/.ci/pytorch/windows/cuda128.bat +++ b/.ci/pytorch/windows/cuda128.bat @@ -18,6 +18,7 @@ REM Check for optional components set USE_CUDA= set CMAKE_GENERATOR=Visual Studio 15 2017 Win64 +<<<<<<< HEAD IF "%NVTOOLSEXT_PATH%"=="" ( IF EXIST "C:\Program Files\NVIDIA Corporation\NvToolsExt\lib\x64\nvToolsExt64_1.lib" ( set NVTOOLSEXT_PATH=C:\Program Files\NVIDIA Corporation\NvToolsExt @@ -27,6 +28,8 @@ IF "%NVTOOLSEXT_PATH%"=="" ( ) ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) IF "%CUDA_PATH_V128%"=="" ( IF EXIST "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8\bin\nvcc.exe" ( set "CUDA_PATH_V128=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8" @@ -37,10 +40,17 @@ IF "%CUDA_PATH_V128%"=="" ( ) IF "%BUILD_VISION%" == "" ( +<<<<<<< HEAD set TORCH_CUDA_ARCH_LIST=7.0;7.5;8.0;8.6;9.0;10.0;12.0 set TORCH_NVCC_FLAGS=-Xfatbin -compress-all ) ELSE ( set NVCC_FLAGS=-D__CUDA_NO_HALF_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_90,code=compute_90 -gencode=arch=compute_100,code=compute_100 -gencode=arch=compute_120,code=compute_120 +======= + set TORCH_CUDA_ARCH_LIST=6.1;7.0;7.5;8.0;8.6;9.0;10.0;12.0 + set TORCH_NVCC_FLAGS=-Xfatbin -compress-all +) ELSE ( + set NVCC_FLAGS=-D__CUDA_NO_HALF_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_90,code=compute_90 -gencode=arch=compute_100,code=compute_100 -gencode=arch=compute_120,code=compute_120 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) set "CUDA_PATH=%CUDA_PATH_V128%" diff --git a/.ci/pytorch/windows/cuda129.bat b/.ci/pytorch/windows/cuda129.bat index b17e6113c63e2..ad19af5363c3c 100644 --- a/.ci/pytorch/windows/cuda129.bat +++ b/.ci/pytorch/windows/cuda129.bat @@ -18,6 +18,7 @@ REM Check for optional components set USE_CUDA= set CMAKE_GENERATOR=Visual Studio 15 2017 Win64 +<<<<<<< HEAD IF "%NVTOOLSEXT_PATH%"=="" ( IF EXIST "C:\Program Files\NVIDIA Corporation\NvToolsExt\lib\x64\nvToolsExt64_1.lib" ( set NVTOOLSEXT_PATH=C:\Program Files\NVIDIA Corporation\NvToolsExt @@ -27,6 +28,8 @@ IF "%NVTOOLSEXT_PATH%"=="" ( ) ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) IF "%CUDA_PATH_V129%"=="" ( IF EXIST "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.9\bin\nvcc.exe" ( set "CUDA_PATH_V129=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.9" diff --git a/.ci/pytorch/windows/internal/copy.bat b/.ci/pytorch/windows/internal/copy.bat index e0281c0d78a44..993f11e1e0142 100644 --- a/.ci/pytorch/windows/internal/copy.bat +++ b/.ci/pytorch/windows/internal/copy.bat @@ -1,3 +1,4 @@ +<<<<<<< HEAD if %CUDA_VERSION% geq 130 ( set "dll_path=bin\x64" @@ -19,6 +20,19 @@ copy "%CUDA_PATH%\extras\CUPTI\lib64\cupti64_*.dll*" pytorch\torch\lib copy "%CUDA_PATH%\extras\CUPTI\lib64\nvperf_host*.dll*" pytorch\torch\lib copy "C:\Program Files\NVIDIA Corporation\NvToolsExt\bin\x64\nvToolsExt64_1.dll*" pytorch\torch\lib +======= +copy "%CUDA_PATH%\bin\cusparse*64_*.dll*" pytorch\torch\lib +copy "%CUDA_PATH%\bin\cublas*64_*.dll*" pytorch\torch\lib +copy "%CUDA_PATH%\bin\cudart*64_*.dll*" pytorch\torch\lib +copy "%CUDA_PATH%\bin\curand*64_*.dll*" pytorch\torch\lib +copy "%CUDA_PATH%\bin\cufft*64_*.dll*" pytorch\torch\lib +copy "%CUDA_PATH%\bin\cusolver*64_*.dll*" pytorch\torch\lib + +copy "%CUDA_PATH%\bin\cudnn*64_*.dll*" pytorch\torch\lib +copy "%CUDA_PATH%\bin\nvrtc*64_*.dll*" pytorch\torch\lib +copy "%CUDA_PATH%\extras\CUPTI\lib64\cupti64_*.dll*" pytorch\torch\lib + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) copy "%PYTHON_LIB_PATH%\libiomp*5md.dll" pytorch\torch\lib :: Should be set in build_pytorch.bat @@ -28,3 +42,11 @@ copy "%libuv_ROOT%\bin\uv.dll" pytorch\torch\lib if exist "C:\Windows\System32\zlibwapi.dll" ( copy "C:\Windows\System32\zlibwapi.dll" pytorch\torch\lib ) +<<<<<<< HEAD +======= + +::copy nvJitLink dll is requires for cuda 12+ +if exist "%CUDA_PATH%\bin\nvJitLink_*.dll*" ( + copy "%CUDA_PATH%\bin\nvJitLink_*.dll*" pytorch\torch\lib +) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.ci/pytorch/windows/internal/cuda_install.bat b/.ci/pytorch/windows/internal/cuda_install.bat index 1349d3e661f55..b17eda7de7815 100644 --- a/.ci/pytorch/windows/internal/cuda_install.bat +++ b/.ci/pytorch/windows/internal/cuda_install.bat @@ -26,7 +26,10 @@ if exist "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR% if %CUDA_VER% EQU 126 goto cuda126 if %CUDA_VER% EQU 128 goto cuda128 if %CUDA_VER% EQU 129 goto cuda129 +<<<<<<< HEAD if %CUDA_VER% EQU 130 goto cuda130 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) echo CUDA %CUDA_VERSION_STR% is not supported exit /b 1 @@ -114,6 +117,7 @@ xcopy /Y "%SRC_DIR%\temp_build\zlib\dll_x64\*.dll" "C:\Windows\System32" goto cuda_common +<<<<<<< HEAD :cuda130 set CUDA_INSTALL_EXE=cuda_13.0.0_windows.exe @@ -141,17 +145,22 @@ xcopy /Y "%SRC_DIR%\temp_build\zlib\dll_x64\*.dll" "C:\Windows\System32" goto cuda_common +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) :cuda_common :: NOTE: We only install CUDA if we don't have it installed already. :: With GHA runners these should be pre-installed as part of our AMI process :: If you cannot find the CUDA version you want to build for here then please :: add it @ https://github.com/pytorch/test-infra/tree/main/aws/ami/windows if not exist "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\bin\nvcc.exe" ( +<<<<<<< HEAD if not exist "%SRC_DIR%\temp_build\NvToolsExt.7z" ( curl -k -L https://ossci-windows.s3.us-east-1.amazonaws.com/builder/NvToolsExt.7z --output "%SRC_DIR%\temp_build\NvToolsExt.7z" if errorlevel 1 exit /b 1 ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not exist "%SRC_DIR%\temp_build\gpu_driver_dlls.zip" ( curl -k -L "https://ossci-windows.s3.us-east-1.amazonaws.com/builder/additional_dlls.zip" --output "%SRC_DIR%\temp_build\gpu_driver_dlls.zip" if errorlevel 1 exit /b 1 @@ -178,6 +187,7 @@ if not exist "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_ xcopy /Y "%SRC_DIR%\temp_build\cuda\CUDAVisualStudioIntegration\extras\visual_studio_integration\MSBuildExtensions\*.*" "C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\MSBuild\Microsoft\VC\v170\BuildCustomizations" ) +<<<<<<< HEAD echo Installing NvToolsExt... 7z x %SRC_DIR%\temp_build\NvToolsExt.7z -o"%SRC_DIR%\temp_build\NvToolsExt" mkdir "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\bin\x64" @@ -187,6 +197,8 @@ if not exist "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_ xcopy /Y "%SRC_DIR%\temp_build\NvToolsExt\include\*.*" "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\include" xcopy /Y "%SRC_DIR%\temp_build\NvToolsExt\lib\x64\*.*" "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\lib\x64" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) echo Installing cuDNN... 7z x %CUDNN_SETUP_FILE% -o"%SRC_DIR%\temp_build\cudnn" xcopy /Y "%SRC_DIR%\temp_build\cudnn\%CUDNN_FOLDER%\bin\*.*" "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\bin" @@ -217,4 +229,7 @@ echo Setting up environment... set "PATH=%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\bin;%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\libnvvp;%PATH%" set "CUDA_PATH=%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%" set "CUDA_PATH_V%CUDA_VER_MAJOR%_%CUDA_VER_MINOR%=%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%" +<<<<<<< HEAD set "NVTOOLSEXT_PATH=%ProgramFiles%\NVIDIA Corporation\NvToolsExt" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.ci/pytorch/windows/internal/driver_update.bat b/.ci/pytorch/windows/internal/driver_update.bat index 2c173aed818b4..f9ffb6de2fd29 100644 --- a/.ci/pytorch/windows/internal/driver_update.bat +++ b/.ci/pytorch/windows/internal/driver_update.bat @@ -1,3 +1,4 @@ +<<<<<<< HEAD set WIN_DRIVER_VN=580.88 set "DRIVER_DOWNLOAD_LINK=https://ossci-windows.s3.amazonaws.com/%WIN_DRIVER_VN%-data-center-tesla-desktop-win10-win11-64bit-dch-international.exe" & REM @lint-ignore curl --retry 3 -kL %DRIVER_DOWNLOAD_LINK% --output %WIN_DRIVER_VN%-data-center-tesla-desktop-win10-win11-64bit-dch-international.exe @@ -7,3 +8,14 @@ start /wait %WIN_DRIVER_VN%-data-center-tesla-desktop-win10-win11-64bit-dch-inte if errorlevel 1 exit /b 1 del %WIN_DRIVER_VN%-data-center-tesla-desktop-win10-win11-64bit-dch-international.exe || ver > NUL +======= +set WIN_DRIVER_VN=528.89 +set "DRIVER_DOWNLOAD_LINK=https://ossci-windows.s3.amazonaws.com/%WIN_DRIVER_VN%-data-center-tesla-desktop-winserver-2016-2019-2022-dch-international.exe" & REM @lint-ignore +curl --retry 3 -kL %DRIVER_DOWNLOAD_LINK% --output %WIN_DRIVER_VN%-data-center-tesla-desktop-winserver-2016-2019-2022-dch-international.exe +if errorlevel 1 exit /b 1 + +start /wait %WIN_DRIVER_VN%-data-center-tesla-desktop-winserver-2016-2019-2022-dch-international.exe -s -noreboot +if errorlevel 1 exit /b 1 + +del %WIN_DRIVER_VN%-data-center-tesla-desktop-winserver-2016-2019-2022-dch-international.exe || ver > NUL +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.ci/pytorch/windows/internal/install_python.bat b/.ci/pytorch/windows/internal/install_python.bat index 84d0f9caccefb..6cd069f4098de 100644 --- a/.ci/pytorch/windows/internal/install_python.bat +++ b/.ci/pytorch/windows/internal/install_python.bat @@ -1,12 +1,16 @@ set ADDITIONAL_OPTIONS="" set PYTHON_EXEC="python" +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if "%DESIRED_PYTHON%" == "3.13t" ( echo Python version is set to 3.13t set "PYTHON_INSTALLER_URL=https://www.python.org/ftp/python/3.13.0/python-3.13.0-amd64.exe" set ADDITIONAL_OPTIONS="Include_freethreaded=1" set PYTHON_EXEC="python3.13t" +<<<<<<< HEAD ) else if "%DESIRED_PYTHON%"=="3.14" ( echo Python version is set to 3.14 or 3.14t set "PYTHON_INSTALLER_URL=https://www.python.org/ftp/python/3.14.0/python-3.14.0rc1-amd64.exe" @@ -17,6 +21,10 @@ if "%DESIRED_PYTHON%" == "3.13t" ( set PYTHON_EXEC="python3.14t" ) else ( echo Python version is set to %DESIRED_PYTHON% +======= +) else ( + echo DESIRED_PYTHON not defined, Python version is set to %DESIRED_PYTHON% +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) set "PYTHON_INSTALLER_URL=https://www.python.org/ftp/python/%DESIRED_PYTHON%.0/python-%DESIRED_PYTHON%.0-amd64.exe" %= @lint-ignore =% ) @@ -28,5 +36,8 @@ start /wait "" python-amd64.exe /quiet InstallAllUsers=1 PrependPath=0 Include_t if errorlevel 1 exit /b 1 set "PATH=%CD%\Python\Scripts;%CD%\Python;%PATH%" +<<<<<<< HEAD %PYTHON_EXEC% -m pip install --upgrade pip setuptools packaging wheel if errorlevel 1 exit /b 1 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.ci/pytorch/windows/internal/smoke_test.bat b/.ci/pytorch/windows/internal/smoke_test.bat index f671a9d0e0abb..eb803a058cae0 100644 --- a/.ci/pytorch/windows/internal/smoke_test.bat +++ b/.ci/pytorch/windows/internal/smoke_test.bat @@ -148,7 +148,18 @@ if "%NVIDIA_GPU_EXISTS%" == "0" ( goto end ) +<<<<<<< HEAD cl %PYTORCH_ROOT%\.ci\pytorch\test_example_code\check-torch-cuda.cpp torch_cpu.lib c10.lib torch_cuda.lib /EHsc /std:c++17 /link /INCLUDE:?warp_size@cuda@at@@YAHXZ +======= +set BUILD_SPLIT_CUDA= +if exist "%install_root%\lib\torch_cuda_cu.lib" if exist "%install_root%\lib\torch_cuda_cpp.lib" set BUILD_SPLIT_CUDA=ON + +if "%BUILD_SPLIT_CUDA%" == "ON" ( + cl %PYTORCH_ROOT%\.ci\pytorch\test_example_code\check-torch-cuda.cpp torch_cpu.lib c10.lib torch_cuda_cu.lib torch_cuda_cpp.lib /EHsc /std:c++17 /link /INCLUDE:?warp_size@cuda@at@@YAHXZ /INCLUDE:?_torch_cuda_cu_linker_symbol_op_cuda@native@at@@YA?AVTensor@2@AEBV32@@Z +) else ( + cl %PYTORCH_ROOT%\.ci\pytorch\test_example_code\check-torch-cuda.cpp torch_cpu.lib c10.lib torch_cuda.lib /EHsc /std:c++17 /link /INCLUDE:?warp_size@cuda@at@@YAHXZ +) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .\check-torch-cuda.exe if ERRORLEVEL 1 exit /b 1 diff --git a/.ci/pytorch/windows/internal/xpu_install.bat b/.ci/pytorch/windows/internal/xpu_install.bat index f143571a56922..85b72caccaba0 100644 --- a/.ci/pytorch/windows/internal/xpu_install.bat +++ b/.ci/pytorch/windows/internal/xpu_install.bat @@ -13,9 +13,15 @@ if not exist "%SRC_DIR%\temp_build" mkdir "%SRC_DIR%\temp_build" :xpu_bundle_install_start set XPU_BUNDLE_PARENT_DIR=C:\Program Files (x86)\Intel\oneAPI +<<<<<<< HEAD set XPU_BUNDLE_URL=https://registrationcenter-download.intel.com/akdlm/IRC_NAS/75d4eb97-914a-4a95-852c-7b9733d80f74/intel-deep-learning-essentials-2025.1.3.8_offline.exe set XPU_BUNDLE_PRODUCT_NAME=intel.oneapi.win.deep-learning-essentials.product set XPU_BUNDLE_VERSION=2025.1.3+5 +======= +set XPU_BUNDLE_URL=https://registrationcenter-download.intel.com/akdlm/IRC_NAS/9d6d6c17-ca2d-4735-9331-99447e4a1280/intel-deep-learning-essentials-2025.0.1.28_offline.exe +set XPU_BUNDLE_PRODUCT_NAME=intel.oneapi.win.deep-learning-essentials.product +set XPU_BUNDLE_VERSION=2025.0.1+20 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) set XPU_BUNDLE_INSTALLED=0 set XPU_BUNDLE_UNINSTALL=0 set XPU_EXTRA_URL=NULL @@ -24,9 +30,15 @@ set XPU_EXTRA_VERSION=2025.0.1+1226 set XPU_EXTRA_INSTALLED=0 set XPU_EXTRA_UNINSTALL=0 +<<<<<<< HEAD if not [%XPU_VERSION%]==[] if [%XPU_VERSION%]==[2025.2] ( set XPU_BUNDLE_URL=https://registrationcenter-download.intel.com/akdlm/IRC_NAS/24751ead-ddc5-4479-b9e6-f9fe2ff8b9f2/intel-deep-learning-essentials-2025.2.1.25_offline.exe set XPU_BUNDLE_VERSION=2025.2.1+20 +======= +if not [%XPU_VERSION%]==[] if [%XPU_VERSION%]==[2025.1] ( + set XPU_BUNDLE_URL=https://registrationcenter-download.intel.com/akdlm/IRC_NAS/75d4eb97-914a-4a95-852c-7b9733d80f74/intel-deep-learning-essentials-2025.1.3.8_offline.exe + set XPU_BUNDLE_VERSION=2025.1.3+5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) :: Check if XPU bundle is target version or already installed @@ -90,3 +102,17 @@ if errorlevel 1 exit /b 1 del xpu_extra.exe :xpu_install_end +<<<<<<< HEAD +======= + +if not "%XPU_ENABLE_KINETO%"=="1" goto install_end +:: Install Level Zero SDK +set XPU_EXTRA_LZ_URL=https://github.com/oneapi-src/level-zero/releases/download/v1.14.0/level-zero-sdk_1.14.0.zip +curl -k -L %XPU_EXTRA_LZ_URL% --output "%SRC_DIR%\temp_build\level_zero_sdk.zip" +echo "Installing level zero SDK..." +7z x "%SRC_DIR%\temp_build\level_zero_sdk.zip" -o"%SRC_DIR%\temp_build\level_zero" +set "INCLUDE=%SRC_DIR%\temp_build\level_zero\include;%INCLUDE%" +del "%SRC_DIR%\temp_build\level_zero_sdk.zip" + +:install_end +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.ci/pytorch/windows/setup_build.bat b/.ci/pytorch/windows/setup_build.bat index dbdc9891324cc..d05abf6dc1c31 100644 --- a/.ci/pytorch/windows/setup_build.bat +++ b/.ci/pytorch/windows/setup_build.bat @@ -7,8 +7,11 @@ call "internal\install_python.bat" %PYTHON_EXEC% --version set "PATH=%CD%\Python\Lib\site-packages\cmake\data\bin;%CD%\Python\Scripts;%CD%\Python;%PATH%" +<<<<<<< HEAD if "%DESIRED_PYTHON%" == "3.14t" %PYTHON_EXEC% -m pip install numpy==2.3.2 cmake if "%DESIRED_PYTHON%" == "3.14" %PYTHON_EXEC% -m pip install numpy==2.3.2 cmake +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if "%DESIRED_PYTHON%" == "3.13t" %PYTHON_EXEC% -m pip install numpy==2.2.1 cmake if "%DESIRED_PYTHON%" == "3.13" %PYTHON_EXEC% -m pip install numpy==2.1.2 cmake if "%DESIRED_PYTHON%" == "3.12" %PYTHON_EXEC% -m pip install numpy==2.0.2 cmake diff --git a/.ci/wheel/build_wheel.sh b/.ci/wheel/build_wheel.sh index e63a68e4f1934..94fb3234813d7 100755 --- a/.ci/wheel/build_wheel.sh +++ b/.ci/wheel/build_wheel.sh @@ -124,13 +124,22 @@ popd export TH_BINARY_BUILD=1 export INSTALL_TEST=0 # dont install test binaries into site-packages +<<<<<<< HEAD export MACOSX_DEPLOYMENT_TARGET=11.0 export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} +======= +export MACOSX_DEPLOYMENT_TARGET=10.15 +export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} + +SETUPTOOLS_PINNED_VERSION="=46.0.0" +PYYAML_PINNED_VERSION="=5.3" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) EXTRA_CONDA_INSTALL_FLAGS="" CONDA_ENV_CREATE_FLAGS="" RENAME_WHEEL=true case $desired_python in +<<<<<<< HEAD 3.14t) echo "Using 3.14 deps" NUMPY_PINNED_VERSION="==2.1.0" @@ -149,6 +158,13 @@ case $desired_python in 3.13t) echo "Using 3.13 deps" NUMPY_PINNED_VERSION="==2.1.0" +======= + 3.13t) + echo "Using 3.13 deps" + SETUPTOOLS_PINNED_VERSION=">=68.0.0" + PYYAML_PINNED_VERSION=">=6.0.1" + NUMPY_PINNED_VERSION="=2.1.0" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CONDA_ENV_CREATE_FLAGS="python-freethreading" EXTRA_CONDA_INSTALL_FLAGS="-c conda-forge" desired_python="3.13" @@ -156,6 +172,7 @@ case $desired_python in ;; 3.13) echo "Using 3.13 deps" +<<<<<<< HEAD NUMPY_PINNED_VERSION="==2.1.0" ;; 3.12) @@ -173,6 +190,39 @@ case $desired_python in *) echo "Unsupported version $desired_python" exit 1 +======= + SETUPTOOLS_PINNED_VERSION=">=68.0.0" + PYYAML_PINNED_VERSION=">=6.0.1" + NUMPY_PINNED_VERSION="=2.1.0" + ;; + 3.12) + echo "Using 3.12 deps" + SETUPTOOLS_PINNED_VERSION=">=68.0.0" + PYYAML_PINNED_VERSION=">=6.0.1" + NUMPY_PINNED_VERSION="=2.0.2" + ;; + 3.11) + echo "Using 3.11 deps" + SETUPTOOLS_PINNED_VERSION=">=46.0.0" + PYYAML_PINNED_VERSION=">=5.3" + NUMPY_PINNED_VERSION="=2.0.2" + ;; + 3.10) + echo "Using 3.10 deps" + SETUPTOOLS_PINNED_VERSION=">=46.0.0" + PYYAML_PINNED_VERSION=">=5.3" + NUMPY_PINNED_VERSION="=2.0.2" + ;; + 3.9) + echo "Using 3.9 deps" + SETUPTOOLS_PINNED_VERSION=">=46.0.0" + PYYAML_PINNED_VERSION=">=5.3" + NUMPY_PINNED_VERSION="=2.0.2" + ;; + *) + echo "Using default deps" + NUMPY_PINNED_VERSION="=1.11.3" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ;; esac @@ -181,17 +231,27 @@ tmp_env_name="wheel_py$python_nodot" conda create ${EXTRA_CONDA_INSTALL_FLAGS} -yn "$tmp_env_name" python="$desired_python" ${CONDA_ENV_CREATE_FLAGS} source activate "$tmp_env_name" +<<<<<<< HEAD PINNED_PACKAGES=( "numpy${NUMPY_PINNED_VERSION}" ) retry pip install "${PINNED_PACKAGES[@]}" -r "${pytorch_rootdir}/requirements-build.txt" pip install requests ninja typing-extensions +======= +pip install "numpy=${NUMPY_PINNED_VERSION}" "pyyaml${PYYAML_PINNED_VERSION}" requests ninja "setuptools${SETUPTOOLS_PINNED_VERSION}" typing_extensions +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) retry pip install -r "${pytorch_rootdir}/requirements.txt" || true retry brew install libomp # For USE_DISTRIBUTED=1 on macOS, need libuv, which is build as part of tensorpipe submodule export USE_DISTRIBUTED=1 +<<<<<<< HEAD +======= +if [[ -n "$CROSS_COMPILE_ARM64" ]]; then + export CMAKE_OSX_ARCHITECTURES=arm64 +fi +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) export USE_MKLDNN=OFF export USE_QNNPACK=OFF export BUILD_TEST=OFF @@ -199,7 +259,20 @@ export BUILD_TEST=OFF pushd "$pytorch_rootdir" echo "Calling setup.py bdist_wheel at $(date)" +<<<<<<< HEAD python setup.py bdist_wheel -d "$whl_tmp_dir" --plat-name ${mac_version} +======= +if [[ "$USE_SPLIT_BUILD" == "true" ]]; then + echo "Calling setup.py bdist_wheel for split build (BUILD_LIBTORCH_WHL)" + BUILD_LIBTORCH_WHL=1 BUILD_PYTHON_ONLY=0 python setup.py bdist_wheel -d "$whl_tmp_dir" + echo "Finished setup.py bdist_wheel for split build (BUILD_LIBTORCH_WHL)" + echo "Calling setup.py bdist_wheel for split build (BUILD_PYTHON_ONLY)" + BUILD_LIBTORCH_WHL=0 BUILD_PYTHON_ONLY=1 CMAKE_FRESH=1 python setup.py bdist_wheel -d "$whl_tmp_dir" + echo "Finished setup.py bdist_wheel for split build (BUILD_PYTHON_ONLY)" +else + python setup.py bdist_wheel -d "$whl_tmp_dir" +fi +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) echo "Finished setup.py bdist_wheel at $(date)" diff --git a/.circleci/scripts/binary_linux_test.sh b/.circleci/scripts/binary_linux_test.sh index c24a50b8b17ed..0af272b341abf 100755 --- a/.circleci/scripts/binary_linux_test.sh +++ b/.circleci/scripts/binary_linux_test.sh @@ -65,8 +65,21 @@ fi if [[ "$PACKAGE_TYPE" != libtorch ]]; then if [[ "\$BUILD_ENVIRONMENT" != *s390x* ]]; then +<<<<<<< HEAD pip install "\$pkg" --index-url "https://download.pytorch.org/whl/\${CHANNEL}/${DESIRED_CUDA}" retry pip install -q numpy protobuf typing-extensions +======= + if [[ "$USE_SPLIT_BUILD" == "true" ]]; then + pkg_no_python="$(ls -1 /final_pkgs/torch_no_python* | sort |tail -1)" + pkg_torch="$(ls -1 /final_pkgs/torch-* | sort |tail -1)" + # todo: after folder is populated use the pypi_pkg channel instead + pip install "\$pkg_no_python" "\$pkg_torch" --index-url "https://download.pytorch.org/whl/\${CHANNEL}/${DESIRED_CUDA}_pypi_pkg" + retry pip install -q numpy protobuf typing-extensions + else + pip install "\$pkg" --index-url "https://download.pytorch.org/whl/\${CHANNEL}/${DESIRED_CUDA}" + retry pip install -q numpy protobuf typing-extensions + fi +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else pip install "\$pkg" retry pip install -q numpy protobuf typing-extensions diff --git a/.circleci/scripts/binary_populate_env.sh b/.circleci/scripts/binary_populate_env.sh index aa82d36aa7ce6..11cb76268476b 100755 --- a/.circleci/scripts/binary_populate_env.sh +++ b/.circleci/scripts/binary_populate_env.sh @@ -75,7 +75,18 @@ export PYTORCH_BUILD_NUMBER=1 : <<'BLOCK_COMMENT' # Set triton version as part of PYTORCH_EXTRA_INSTALL_REQUIREMENTS TRITON_VERSION=$(cat $PYTORCH_ROOT/.ci/docker/triton_version.txt) +<<<<<<< HEAD TRITON_CONSTRAINT="platform_system == 'Linux'" +======= + +# Here PYTORCH_EXTRA_INSTALL_REQUIREMENTS is already set for the all the wheel builds hence append TRITON_CONSTRAINT +TRITON_CONSTRAINT="platform_system == 'Linux' and platform_machine == 'x86_64'" + +# CUDA 12.9 builds have triton for Linux and Linux aarch64 binaries. +if [[ "$DESIRED_CUDA" == "cu129" ]]; then + TRITON_CONSTRAINT="platform_system == 'Linux'" +fi +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "${PYTORCH_EXTRA_INSTALL_REQUIREMENTS:-}" && ! "$PYTORCH_BUILD_VERSION" =~ .*xpu.* ]]; then TRITON_REQUIREMENT="triton==${TRITON_VERSION}; ${TRITON_CONSTRAINT}" @@ -132,6 +143,10 @@ export DESIRED_PYTHON="${DESIRED_PYTHON:-}" export DESIRED_CUDA="$DESIRED_CUDA" export LIBTORCH_VARIANT="${LIBTORCH_VARIANT:-}" export BUILD_PYTHONLESS="${BUILD_PYTHONLESS:-}" +<<<<<<< HEAD +======= +export USE_SPLIT_BUILD="${USE_SPLIT_BUILD:-}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if [[ "${OSTYPE}" == "msys" ]]; then export LIBTORCH_CONFIG="${LIBTORCH_CONFIG:-}" if [[ "${LIBTORCH_CONFIG:-}" == 'debug' ]]; then diff --git a/.circleci/scripts/binary_upload.sh b/.circleci/scripts/binary_upload.sh index d48077e112455..81b8a06778c42 100755 --- a/.circleci/scripts/binary_upload.sh +++ b/.circleci/scripts/binary_upload.sh @@ -23,6 +23,13 @@ if [[ "${DRY_RUN}" = "disabled" ]]; then AWS_S3_CP="aws s3 cp" fi +<<<<<<< HEAD +======= +if [[ "${USE_SPLIT_BUILD:-false}" == "true" ]]; then + UPLOAD_SUBFOLDER="${UPLOAD_SUBFOLDER}_pypi_pkg" +fi + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # this is special build with all dependencies packaged if [[ ${BUILD_NAME} == *-full* ]]; then UPLOAD_SUBFOLDER="${UPLOAD_SUBFOLDER}_full" @@ -51,12 +58,23 @@ s3_upload() { s3_upload_dir="${s3_root_dir}/${UPLOAD_SUBFOLDER}/" fi ( +<<<<<<< HEAD +======= + cache_control_flag="" + if [[ "${UPLOAD_CHANNEL}" = "test" ]]; then + cache_control_flag="--cache-control='no-cache,no-store,must-revalidate'" + fi +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for pkg in ${PKG_DIR}/*.${extension}; do ( set -x shm_id=$(sha256sum "${pkg}" | awk '{print $1}') ${AWS_S3_CP} --no-progress --acl public-read "${pkg}" "${s3_upload_dir}" \ +<<<<<<< HEAD --metadata "checksum-sha256=${shm_id}" +======= + --metadata "checksum-sha256=${shm_id}" ${cache_control_flag} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) done ) diff --git a/.circleci/scripts/binary_windows_build.sh b/.circleci/scripts/binary_windows_build.sh index 18dcde50e2b65..60ffb1e15a817 100644 --- a/.circleci/scripts/binary_windows_build.sh +++ b/.circleci/scripts/binary_windows_build.sh @@ -15,7 +15,12 @@ fi if [[ "$DESIRED_CUDA" == 'xpu' ]]; then export VC_YEAR=2022 export USE_SCCACHE=0 +<<<<<<< HEAD export XPU_VERSION=2025.2 +======= + export XPU_VERSION=2025.1 + export XPU_ENABLE_KINETO=1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fi echo "Free space on filesystem before build:" diff --git a/.circleci/scripts/binary_windows_test.sh b/.circleci/scripts/binary_windows_test.sh index 9326d9037e8b3..eb5b15b762cd1 100644 --- a/.circleci/scripts/binary_windows_test.sh +++ b/.circleci/scripts/binary_windows_test.sh @@ -8,7 +8,11 @@ export VC_YEAR=2022 if [[ "$DESIRED_CUDA" == 'xpu' ]]; then export VC_YEAR=2022 +<<<<<<< HEAD export XPU_VERSION=2025.2 +======= + export XPU_VERSION=2025.1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fi pushd "$PYTORCH_ROOT/.ci/pytorch/" diff --git a/.clang-format b/.clang-format index 67b722d967c7e..448aa5d0f343d 100644 --- a/.clang-format +++ b/.clang-format @@ -120,7 +120,10 @@ UseTab: Never Language: ObjC ColumnLimit: 120 AlignAfterOpenBracket: Align +<<<<<<< HEAD IndentWidth: 2 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ObjCBlockIndentWidth: 2 ObjCSpaceAfterProperty: false ObjCSpaceBeforeProtocolList: false diff --git a/.devcontainer/README.md b/.devcontainer/README.md index 7ef8da027ad9e..c7e65eeecedd1 100644 --- a/.devcontainer/README.md +++ b/.devcontainer/README.md @@ -61,8 +61,13 @@ You are now all set to start developing with PyTorch in a DevContainer environme ## Step 8: Build PyTorch To build pytorch from source, simply run: +<<<<<<< HEAD ```bash python -m pip install --no-build-isolation -v -e . +======= + ``` + python setup.py develop +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ``` The process involves compiling thousands of files, and would take a long time. Fortunately, the compiled objects can be useful for your next build. When you modify some files, you only need to compile the changed files the next time. diff --git a/.editorconfig b/.editorconfig index e9581612a050e..0456b5cd51a07 100644 --- a/.editorconfig +++ b/.editorconfig @@ -1,11 +1,15 @@ root = true [*] +<<<<<<< HEAD charset = utf-8 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) end_of_line = lf insert_final_newline = true # Python +<<<<<<< HEAD [*.{py,pyi,py.in,pyi.in}] indent_style = space indent_size = 4 @@ -34,3 +38,12 @@ indent_style = tab indent_style = space indent_size = 2 end_of_line = crlf +======= +[*.py] +indent_style = space +indent_size = 4 + +# Make +[Makefile] +indent_style = tab +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.flake8 b/.flake8 index fc9ab167fbeef..ae1fc9c9fe246 100644 --- a/.flake8 +++ b/.flake8 @@ -7,12 +7,20 @@ max-line-length = 120 # C408 ignored because we like the dict keyword argument syntax # E501 is not flexible enough, we're using B950 instead ignore = +<<<<<<< HEAD E203,E305,E402,E501,E704,E721,E741,F405,F841,F999,W503,W504,C408,E302,W291,E303,F824, +======= + E203,E305,E402,E501,E704,E721,E741,F405,F841,F999,W503,W504,C408,E302,W291,E303, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # shebang has extra meaning in fbcode lints, so I think it's not worth trying # to line this up with executable bit EXE001, # these ignores are from flake8-bugbear; please fix! +<<<<<<< HEAD B007,B008,B017,B019,B023,B028,B903,B904,B905,B906,B907,B908,B910 +======= + B007,B008,B017,B019,B023,B028,B903,B904,B905,B906,B907 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # these ignores are from flake8-comprehensions; please fix! C407, # these ignores are from flake8-logging-format; please fix! @@ -48,7 +56,10 @@ per-file-ignores = torch/__init__.py: F401,TOR901 torch/_custom_op/impl.py: TOR901 torch/_export/serde/upgrade.py: TOR901 +<<<<<<< HEAD torch/_functorch/predispatch.py: TOR901 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch/_functorch/vmap.py: TOR901 torch/_inductor/test_operators.py: TOR901 torch/_library/abstract_impl.py: TOR901 diff --git a/.github/actionlint.yaml b/.github/actionlint.yaml index 798dee312306d..c6976b9c5f14c 100644 --- a/.github/actionlint.yaml +++ b/.github/actionlint.yaml @@ -12,9 +12,13 @@ self-hosted-runner: - linux.9xlarge.ephemeral - am2.linux.9xlarge.ephemeral - linux.12xlarge +<<<<<<< HEAD - linux.12xlarge.memory - linux.24xlarge - linux.24xlarge.memory +======= + - linux.24xlarge +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - linux.24xlarge.ephemeral - linux.24xlarge.amd - linux.arm64.2xlarge @@ -55,6 +59,7 @@ self-hosted-runner: - linux.rocm.gpu.mi250 - linux.rocm.gpu.2 - linux.rocm.gpu.4 +<<<<<<< HEAD # gfx942 runners - linux.rocm.gpu.gfx942.1 - linux.rocm.gpu.gfx942.2 @@ -62,6 +67,18 @@ self-hosted-runner: - rocm-docker # Org wise AWS `mac2.metal` runners (2020 Mac mini hardware powered by Apple silicon M1 processors) - macos-m1-stable +======= + # MI300 runners + - linux.rocm.gpu.mi300.2 + - linux.rocm.gpu.mi300.4 + - rocm-docker + # Repo-specific Apple hosted runners + - macos-m1-ultra + - macos-m2-14 + # Org wise AWS `mac2.metal` runners (2020 Mac mini hardware powered by Apple silicon M1 processors) + - macos-m1-stable + - macos-m1-13 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - macos-m1-14 # GitHub-hosted MacOS runners - macos-latest-xlarge diff --git a/.github/actions/build-android/action.yml b/.github/actions/build-android/action.yml new file mode 100644 index 0000000000000..bccd42aa42f2c --- /dev/null +++ b/.github/actions/build-android/action.yml @@ -0,0 +1,78 @@ +name: build android + +description: build android for a specific arch + +inputs: + arch: + description: arch to build + required: true + arch-for-build-env: + description: | + arch to pass to build environment. + This is currently different than the arch name we use elsewhere, which + should be fixed. + required: true + github-secret: + description: github token + required: true + build-environment: + required: true + description: Top-level label for what's being built/tested. + docker-image: + required: true + description: Name of the base docker image to build with. + branch: + required: true + description: What branch we are building on. +outputs: + container_id: + description: Docker container identifier used to build the artifacts + value: ${{ steps.build.outputs.container_id }} + +runs: + using: composite + steps: + - name: Build-${{ inputs.arch }} + id: build + shell: bash + env: + BRANCH: ${{ inputs.branch }} + BUILD_ENVIRONMENT: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-${{ inputs.arch-for-build-env }}-build" + AWS_DEFAULT_REGION: us-east-1 + PR_NUMBER: ${{ github.event.pull_request.number }} + SHA1: ${{ github.event.pull_request.head.sha || github.sha }} + SCCACHE_BUCKET: ossci-compiler-cache-circleci-v2 + SCCACHE_REGION: us-east-1 + DOCKER_IMAGE: ${{ inputs.docker-image }} + MATRIX_ARCH: ${{ inputs.arch }} + run: | + # detached container should get cleaned up by teardown_ec2_linux + set -exo pipefail + export container_name + container_name=$(docker run \ + -e BUILD_ENVIRONMENT \ + -e MAX_JOBS="$(nproc --ignore=2)" \ + -e AWS_DEFAULT_REGION \ + -e PR_NUMBER \ + -e SHA1 \ + -e BRANCH \ + -e SCCACHE_BUCKET \ + -e SCCACHE_REGION \ + -e SKIP_SCCACHE_INITIALIZATION=1 \ + --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ + --security-opt seccomp=unconfined \ + --cap-add=SYS_PTRACE \ + --tty \ + --detach \ + --user jenkins \ + -w /var/lib/jenkins/workspace \ + "${DOCKER_IMAGE}" + ) + git submodule sync && git submodule update -q --init --recursive --depth 1 + docker cp "${GITHUB_WORKSPACE}/." "${container_name}:/var/lib/jenkins/workspace" + (echo "sudo chown -R jenkins . && .ci/pytorch/build.sh && find ${BUILD_ROOT} -type f -name "*.a" -or -name "*.o" -delete" | docker exec -u jenkins -i "${container_name}" bash) 2>&1 + + # Copy install binaries back + mkdir -p "${GITHUB_WORKSPACE}/build_android_install_${MATRIX_ARCH}" + docker cp "${container_name}:/var/lib/jenkins/workspace/build_android/install" "${GITHUB_WORKSPACE}/build_android_install_${MATRIX_ARCH}" + echo "container_id=${container_name}" >> "${GITHUB_OUTPUT}" diff --git a/.github/actions/checkout-pytorch/action.yml b/.github/actions/checkout-pytorch/action.yml index 15f193ef3a5dc..b64267fdf45c3 100644 --- a/.github/actions/checkout-pytorch/action.yml +++ b/.github/actions/checkout-pytorch/action.yml @@ -57,6 +57,7 @@ runs: submodules: ${{ inputs.submodules }} show-progress: false +<<<<<<< HEAD - name: Clean submodules post checkout id: clean-submodules if: ${{ steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'false' }} @@ -72,6 +73,8 @@ runs: git submodule foreach --recursive git clean -ffdx fi +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: Clean workspace (try again) if: ${{ steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'false' && (steps.first-clean.outcome != 'success' || steps.first-checkout-attempt.outcome != 'success') }} diff --git a/.github/actions/filter-test-configs/action.yml b/.github/actions/filter-test-configs/action.yml index 338fc0c2a844c..0fc3a4ac53048 100644 --- a/.github/actions/filter-test-configs/action.yml +++ b/.github/actions/filter-test-configs/action.yml @@ -70,7 +70,11 @@ runs: set -eux # PyYAML 6.0 doesn't work with MacOS x86 anymore # This must run on Python-3.7 (AmazonLinux2) so can't use request=3.32.2 +<<<<<<< HEAD python3 -m pip install requests==2.27.1 pyyaml==6.0.2 +======= + python3 -m pip install requests==2.27.1 pyyaml==6.0.1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: Parse ref id: parse-ref @@ -125,7 +129,11 @@ runs: TAG: ${{ steps.parse-ref.outputs.tag }} EVENT_NAME: ${{ github.event_name }} SCHEDULE: ${{ github.event.schedule }} +<<<<<<< HEAD HEAD_BRANCH: ${{ steps.parse-ref.outputs.branch }} +======= + HEAD_BRANCH: ${{ github.event.workflow_run.head_branch }} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) id: filter run: | echo "Workflow: ${GITHUB_WORKFLOW}" diff --git a/.github/actions/linux-test/action.yml b/.github/actions/linux-test/action.yml index 32fe1d7385b18..df7e978491741 100644 --- a/.github/actions/linux-test/action.yml +++ b/.github/actions/linux-test/action.yml @@ -126,7 +126,11 @@ runs: shell: bash continue-on-error: true run: | +<<<<<<< HEAD python3 -m pip install psutil==5.9.8 nvidia-ml-py==11.525.84 +======= + python3 -m pip install psutil==5.9.1 nvidia-ml-py==11.525.84 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) python3 -m tools.stats.monitor > usage_log.txt 2>&1 & echo "monitor-script-pid=${!}" >> "${GITHUB_OUTPUT}" diff --git a/.github/actions/reuse-old-whl/reuse_old_whl.py b/.github/actions/reuse-old-whl/reuse_old_whl.py index def0276a9c8a3..2cd832d2568f6 100644 --- a/.github/actions/reuse-old-whl/reuse_old_whl.py +++ b/.github/actions/reuse-old-whl/reuse_old_whl.py @@ -304,7 +304,12 @@ def change_content_to_new_version(file: Union[str, Path]) -> None: def set_output() -> None: +<<<<<<< HEAD print("Setting output reuse=true") +======= + # Disable for now so we can monitor first + # pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if os.getenv("GITHUB_OUTPUT"): with open(str(os.getenv("GITHUB_OUTPUT")), "a") as env: print("reuse=true", file=env) diff --git a/.github/actions/setup-rocm/action.yml b/.github/actions/setup-rocm/action.yml index a58db801b1cf8..053e25dba2d7a 100644 --- a/.github/actions/setup-rocm/action.yml +++ b/.github/actions/setup-rocm/action.yml @@ -59,6 +59,14 @@ runs: echo "$msg" exit 1 fi +<<<<<<< HEAD +======= + if [[ $ngpu -eq 1 ]]; then + echo "Error: only 1 GPU detected, at least 2 GPUs are needed for distributed jobs" + echo "$msg" + exit 1 + fi +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: Runner diskspace health check uses: pytorch/pytorch/.github/actions/diskspace-cleanup@main diff --git a/.github/actions/setup-win/action.yml b/.github/actions/setup-win/action.yml index 2ea330f93b490..90850b5551edb 100644 --- a/.github/actions/setup-win/action.yml +++ b/.github/actions/setup-win/action.yml @@ -6,12 +6,15 @@ inputs: cuda-version: description: which cuda version to install, 'cpu' for none required: true +<<<<<<< HEAD python-version: required: false type: string default: "3.10" description: | The python version to be used. Will be 3.10 by default +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) runs: using: composite @@ -44,24 +47,34 @@ runs: CONDA="C:\Jenkins\Miniconda3\condabin\conda.bat" { +<<<<<<< HEAD echo "CONDA=${CONDA}"; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) echo "CONDA_RUN=${CONDA} run --no-capture-output"; echo "CONDA_BUILD=${CONDA} run conda-build"; echo "CONDA_INSTALL=${CONDA} install"; } >> "${GITHUB_ENV}" - name: Setup Python3 +<<<<<<< HEAD env: PYTHON_VERSION: ${{ inputs.python-version }} +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shell: bash run: | set +e set -x +<<<<<<< HEAD # Create new py_tmp env with python-version ${CONDA} create -y -n py_tmp python=${PYTHON_VERSION} intel-openmp libuv PYTHON3=$(${CONDA_RUN} -n py_tmp which python3) +======= + PYTHON3=$(${CONDA_RUN} which python3) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) EXIT_CODE=$? if [[ "${EXIT_CODE}" == "0" ]]; then @@ -74,7 +87,11 @@ runs: # installation, which is Python 3 based. Its Python is default to Python 3. Further, there # is also the Miniconda installation that is Python 2 based, and both can be installed if # needed. In both cases, Python binary is just called python +<<<<<<< HEAD PYTHON=$(${CONDA_RUN} -n py_tmp which python) +======= + PYTHON=$(${CONDA_RUN} which python) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) EXIT_CODE=$? if [[ "${EXIT_CODE}" == "0" ]]; then diff --git a/.github/actions/test-pytorch-binary/action.yml b/.github/actions/test-pytorch-binary/action.yml index d4b8be8b609a0..70415bcaad20c 100644 --- a/.github/actions/test-pytorch-binary/action.yml +++ b/.github/actions/test-pytorch-binary/action.yml @@ -24,6 +24,10 @@ runs: -e PYTORCH_FINAL_PACKAGE_DIR \ -e PYTORCH_ROOT \ -e SKIP_ALL_TESTS \ +<<<<<<< HEAD +======= + -e USE_SPLIT_BUILD \ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) --tty \ --detach \ -v "${GITHUB_WORKSPACE}/pytorch:/pytorch" \ diff --git a/.github/ci_commit_pins/audio.txt b/.github/ci_commit_pins/audio.txt index b0255e764c594..12b16c097bbbd 100644 --- a/.github/ci_commit_pins/audio.txt +++ b/.github/ci_commit_pins/audio.txt @@ -1 +1,5 @@ +<<<<<<< HEAD 27fc2493d383354a008106f22f3be232badee9a1 +======= +4e94321c54617dd738a05bfedfc28bc0fa635b5c +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.github/ci_commit_pins/fbgemm_rocm.txt b/.github/ci_commit_pins/fbgemm_rocm.txt index db140a31f3fa4..f743b45aa493d 100644 --- a/.github/ci_commit_pins/fbgemm_rocm.txt +++ b/.github/ci_commit_pins/fbgemm_rocm.txt @@ -1 +1,5 @@ +<<<<<<< HEAD 7f1de94a4c2d14f59ad4ca84538c36084ea6b2c8 +======= +5fb5024118e9bb9decf96c2b0b1a8f0010bf56be +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.github/ci_commit_pins/torchbench.txt b/.github/ci_commit_pins/torchbench.txt new file mode 100644 index 0000000000000..efbc3ceeb2afe --- /dev/null +++ b/.github/ci_commit_pins/torchbench.txt @@ -0,0 +1 @@ +e03a63be43e33596f7f0a43b0f530353785e4a59 diff --git a/.github/ci_commit_pins/xla.txt b/.github/ci_commit_pins/xla.txt index ee530f8c8b210..c714bf1cf1d2d 100644 --- a/.github/ci_commit_pins/xla.txt +++ b/.github/ci_commit_pins/xla.txt @@ -1 +1,5 @@ +<<<<<<< HEAD r2.9 +======= +r2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.github/label_to_label.yml b/.github/label_to_label.yml index 0cd56143535fe..4b0df90713fdd 100644 --- a/.github/label_to_label.yml +++ b/.github/label_to_label.yml @@ -48,6 +48,7 @@ - "module: dynamic shapes" then: - "oncall: pt2" +<<<<<<< HEAD - any: - "release notes: distributed (c10d)" - "release notes: distributed (symm_mem)" @@ -57,3 +58,5 @@ - "oncall: distributed" then: - "ciflow/h100-distributed" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.github/merge_rules.yaml b/.github/merge_rules.yaml index 354381755ce5b..8cb3e2192024b 100644 --- a/.github/merge_rules.yaml +++ b/.github/merge_rules.yaml @@ -76,7 +76,10 @@ - .github/ci_commit_pins/audio.txt - .github/ci_commit_pins/vision.txt - .github/ci_commit_pins/torchdynamo.txt +<<<<<<< HEAD - .github/ci_commit_pins/vllm.txt +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - .ci/docker/ci_commit_pins/triton.txt approved_by: - pytorchbot @@ -131,6 +134,24 @@ - Lint - pull +<<<<<<< HEAD +======= +- name: Mobile + patterns: + - ios/** + - android/** + - test/mobile/** + approved_by: + - linbinyu + - IvanKobzarev + - dreiss + - raziel + mandatory_checks_name: + - EasyCLA + - Lint + - pull + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: PrimTorch patterns: - torch/_meta_registrations.py @@ -370,7 +391,10 @@ - leslie-fang-intel - jgong5 - EikanWang +<<<<<<< HEAD - CaoE +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mandatory_checks_name: - EasyCLA - Lint @@ -422,7 +446,10 @@ approved_by: - leslie-fang-intel - jgong5 +<<<<<<< HEAD - CaoE +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mandatory_checks_name: - EasyCLA - Lint @@ -477,6 +504,7 @@ - srossross - chillee - zou3519 +<<<<<<< HEAD - guilhermeleobas mandatory_checks_name: - EasyCLA @@ -494,6 +522,8 @@ - test/inductor_skips/** approved_by: - guilhermeleobas +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mandatory_checks_name: - EasyCLA - Lint diff --git a/.github/pytorch-probot.yml b/.github/pytorch-probot.yml index a0aa6921b92ba..d5c53c1ab0da6 100644 --- a/.github/pytorch-probot.yml +++ b/.github/pytorch-probot.yml @@ -4,7 +4,10 @@ ciflow_push_tags: - ciflow/binaries - ciflow/binaries_libtorch - ciflow/binaries_wheel +<<<<<<< HEAD - ciflow/triton_binaries +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - ciflow/inductor - ciflow/inductor-periodic - ciflow/inductor-rocm @@ -22,20 +25,29 @@ ciflow_push_tags: - ciflow/rocm - ciflow/rocm-mi300 - ciflow/s390 +<<<<<<< HEAD - ciflow/riscv64 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - ciflow/slow - ciflow/trunk - ciflow/unstable - ciflow/xpu +<<<<<<< HEAD - ciflow/vllm +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - ciflow/torchbench - ciflow/op-benchmark - ciflow/pull - ciflow/h100 - ciflow/h100-distributed +<<<<<<< HEAD - ciflow/win-arm64 - ciflow/h100-symm-mem - ciflow/h100-cutlass-backend +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) retryable_workflows: - pull - trunk diff --git a/.github/requirements-gha-cache.txt b/.github/requirements-gha-cache.txt index c274ca1e5914d..e4085fcc6adbb 100644 --- a/.github/requirements-gha-cache.txt +++ b/.github/requirements-gha-cache.txt @@ -1,15 +1,28 @@ # This file is to cache other dependencies not specified elsewhere in: +<<<<<<< HEAD # requirements.txt # requirements-build.txt +======= +# requirement.txt +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # docs/requirements.txt # docs/cpp/requirements.txt # functorch/docs/requirements.txt # .ci/docker/requirements-ci.txt boto3==1.35.42 jinja2==3.1.6 +<<<<<<< HEAD lintrunner==0.12.7 ninja==1.10.0.post1 nvidia-ml-py==11.525.84 pyyaml==6.0.2 requests==2.32.4 rich==14.1.0 +======= +lintrunner==0.10.7 +ninja==1.10.0.post1 +nvidia-ml-py==11.525.84 +pyyaml==6.0 +requests==2.32.4 +rich==10.9.0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.github/requirements/conda-env-macOS-ARM64 b/.github/requirements/conda-env-macOS-ARM64 new file mode 100644 index 0000000000000..b6e9a6ce9f3e5 --- /dev/null +++ b/.github/requirements/conda-env-macOS-ARM64 @@ -0,0 +1,5 @@ +# Not pinning certifi so that we can always get the latest certificates +certifi +pip=23.2.1 +pkg-config=0.29.2 +wheel=0.37.1 diff --git a/.github/requirements/pip-requirements-macOS.txt b/.github/requirements/pip-requirements-macOS.txt index 3a27cac46f71f..ee8d9a5bb4e3e 100644 --- a/.github/requirements/pip-requirements-macOS.txt +++ b/.github/requirements/pip-requirements-macOS.txt @@ -2,7 +2,11 @@ boto3==1.35.42 cmake==3.27.* expecttest==0.3.0 fbscribelogger==0.1.7 +<<<<<<< HEAD filelock==3.18.0 +======= +filelock==3.6.0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) hypothesis==6.56.4 librosa>=0.6.2 mpmath==1.3.0 @@ -16,7 +20,11 @@ packaging==23.1 parameterized==0.8.1 pillow==10.3.0 protobuf==5.29.4 +<<<<<<< HEAD psutil==5.9.8 +======= +psutil==5.9.1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pygments==2.15.0 pytest-cpp==2.3.0 pytest-flakefinder==1.1.0 @@ -28,9 +36,17 @@ pyyaml==6.0.2 scipy==1.12.0 setuptools==72.1.0 sympy==1.13.3 +<<<<<<< HEAD tlparse==0.4.0 +======= +tlparse==0.3.30 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tensorboard==2.13.0 typing-extensions==4.12.2 unittest-xml-reporting<=3.2.0,>=2.0.0 xdoctest==1.1.0 +<<<<<<< HEAD z3-solver==4.15.1.0 +======= +z3-solver==4.12.2.0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.github/scripts/build_triton_wheel.py b/.github/scripts/build_triton_wheel.py index f2851e3317256..6109ee8585f26 100644 --- a/.github/scripts/build_triton_wheel.py +++ b/.github/scripts/build_triton_wheel.py @@ -119,7 +119,10 @@ def build_triton( ["git", "checkout", f"release/{ver}.{rev}.x"], cwd=triton_basedir ) else: +<<<<<<< HEAD check_call(["git", "fetch", "origin", commit_hash], cwd=triton_basedir) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) check_call(["git", "checkout", commit_hash], cwd=triton_basedir) # change built wheel name and version diff --git a/.github/scripts/delete_old_branches.py b/.github/scripts/delete_old_branches.py index 8032008edf122..63a82ca1b3dd5 100644 --- a/.github/scripts/delete_old_branches.py +++ b/.github/scripts/delete_old_branches.py @@ -275,7 +275,11 @@ def delete_branches() -> None: delete_branch(git_repo, branch) +<<<<<<< HEAD def delete_old_tags() -> None: +======= +def delete_old_ciflow_tags() -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Deletes ciflow tags if they are associated with a closed PR or a specific # commit. Lightweight tags don't have information about the date they were # created, so we can't check how old they are. The script just assumes that @@ -288,14 +292,20 @@ def delete_tag(tag: str) -> None: delete_branch(git_repo, f"refs/tags/{tag}") tags = git_repo._run_git("tag").splitlines() +<<<<<<< HEAD CIFLOW_TAG_REGEX = re.compile(r"^ciflow\/.*\/(\d{5,6}|[0-9a-f]{40})$") AUTO_REVERT_TAG_REGEX = re.compile(r"^trunk\/[0-9a-f]{40}$") +======= + open_pr_numbers = [x["number"] for x in get_open_prs()] + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for tag in tags: try: if ESTIMATED_TOKENS[0] > 400: print("Estimated tokens exceeded, exiting") break +<<<<<<< HEAD if not CIFLOW_TAG_REGEX.match(tag) and not AUTO_REVERT_TAG_REGEX.match(tag): continue @@ -311,6 +321,18 @@ def delete_tag(tag: str) -> None: if tag_age_days > 7: print(f"[{tag}] Tag is older than 7 days, deleting") +======= + if not tag.startswith("ciflow/"): + continue + re_match_pr = re.match(r"^ciflow\/.*\/(\d{5,6})$", tag) + re_match_sha = re.match(r"^ciflow\/.*\/([0-9a-f]{40})$", tag) + if re_match_pr: + pr_number = int(re_match_pr.group(1)) + if pr_number in open_pr_numbers: + continue + delete_tag(tag) + elif re_match_sha: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) delete_tag(tag) except Exception as e: print(f"Failed to check tag {tag}: {e}") @@ -318,4 +340,8 @@ def delete_tag(tag: str) -> None: if __name__ == "__main__": delete_branches() +<<<<<<< HEAD delete_old_tags() +======= + delete_old_ciflow_tags() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.github/scripts/filter_test_configs.py b/.github/scripts/filter_test_configs.py index dd16dbc18db25..99c8e6ff3d946 100755 --- a/.github/scripts/filter_test_configs.py +++ b/.github/scripts/filter_test_configs.py @@ -18,7 +18,10 @@ REENABLE_TEST_REGEX = "(?i)(Close(d|s)?|Resolve(d|s)?|Fix(ed|es)?) (#|https://github.com/pytorch/pytorch/issues/)([0-9]+)" +<<<<<<< HEAD MAIN_BRANCH = "main" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) PREFIX = "test-config/" @@ -41,9 +44,15 @@ def is_cuda_or_rocm_job(job_name: Optional[str]) -> bool: } # The link to the published list of disabled jobs +<<<<<<< HEAD DISABLED_JOBS_URL = "https://ossci-metrics.s3.amazonaws.com/disabled-jobs.json?versionId=hjktHz2WOejHpxKpkqpDknTt5rMTM9KK" # and unstable jobs UNSTABLE_JOBS_URL = "https://ossci-metrics.s3.amazonaws.com/unstable-jobs.json?versionId=wrjdvvQTJxgvMO.rGw5MEuMsj6XbjuV7" +======= +DISABLED_JOBS_URL = "https://ossci-metrics.s3.amazonaws.com/disabled-jobs.json?versionId=HnkH0xQWnnsoeMsSIVf9291NE5c4jWSa" +# and unstable jobs +UNSTABLE_JOBS_URL = "https://ossci-metrics.s3.amazonaws.com/unstable-jobs.json?versionId=iP_F8gBs60PfOMAJ8gnn1paVrzM1WYsK" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Some constants used to handle disabled and unstable jobs JOB_NAME_SEP = "/" @@ -98,7 +107,11 @@ def parse_args() -> Any: parser.add_argument( "--branch", type=str, +<<<<<<< HEAD default=MAIN_BRANCH, +======= + default="main", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) help="the branch name", ) return parser.parse_args() @@ -457,7 +470,10 @@ def download_json(url: str, headers: dict[str, str], num_retries: int = 3) -> An def set_output(name: str, val: Any) -> None: +<<<<<<< HEAD print(f"Setting output {name}={val}") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if os.getenv("GITHUB_OUTPUT"): with open(str(os.getenv("GITHUB_OUTPUT")), "a") as env: print(f"{name}={val}", file=env) @@ -497,20 +513,28 @@ def check_for_setting(labels: set[str], body: str, setting: str) -> bool: def perform_misc_tasks( +<<<<<<< HEAD labels: set[str], test_matrix: dict[str, list[Any]], job_name: str, pr_body: str, branch: Optional[str] = None, +======= + labels: set[str], test_matrix: dict[str, list[Any]], job_name: str, pr_body: str +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: """ In addition to apply the filter logic, the script also does the following misc tasks to set keep-going and is-unstable variables """ +<<<<<<< HEAD set_output( "keep-going", branch == MAIN_BRANCH or check_for_setting(labels, pr_body, "keep-going"), ) +======= + set_output("keep-going", check_for_setting(labels, pr_body, "keep-going")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) set_output( "ci-verbose-test-logs", check_for_setting(labels, pr_body, "ci-verbose-test-logs"), @@ -633,7 +657,10 @@ def main() -> None: test_matrix=filtered_test_matrix, job_name=args.job_name, pr_body=pr_body if pr_body else "", +<<<<<<< HEAD branch=args.branch, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # Set the filtered test matrix as the output diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py index 4dc97ee6a284b..0bd76fc0d01df 100644 --- a/.github/scripts/generate_binary_build_matrix.py +++ b/.github/scripts/generate_binary_build_matrix.py @@ -16,17 +16,29 @@ # NOTE: Please also update the CUDA sources in `PIP_SOURCES` in tools/nightly.py when changing this +<<<<<<< HEAD CUDA_ARCHES = ["12.6", "12.8", "13.0"] +======= +CUDA_ARCHES = ["12.6", "12.8", "12.9"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CUDA_STABLE = "12.8" CUDA_ARCHES_FULL_VERSION = { "12.6": "12.6.3", "12.8": "12.8.1", +<<<<<<< HEAD "13.0": "13.0.0", +======= + "12.9": "12.9.1", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } CUDA_ARCHES_CUDNN_VERSION = { "12.6": "9", "12.8": "9", +<<<<<<< HEAD "13.0": "9", +======= + "12.9": "9", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } # NOTE: Please also update the ROCm sources in `PIP_SOURCES` in tools/nightly.py when changing this @@ -38,11 +50,16 @@ CPU_S390X_ARCH = ["cpu-s390x"] +<<<<<<< HEAD CUDA_AARCH64_ARCHES = ["12.6-aarch64", "12.8-aarch64", "13.0-aarch64"] +======= +CUDA_AARCH64_ARCHES = ["12.9-aarch64"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) PYTORCH_EXTRA_INSTALL_REQUIREMENTS = { "12.6": ( +<<<<<<< HEAD "nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | " "nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | " "nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | " @@ -114,6 +131,76 @@ "tcmlib==1.4.0 | " "umf==0.11.0 | " "intel-pti==0.13.1" +======= + "nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' and platform_machine == 'x86_64'" + ), + "12.8": ( + "nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64'" + ), + "12.9": ( + "nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'" + ), + "xpu": ( + "intel-cmplr-lib-rt==2025.1.1 | " + "intel-cmplr-lib-ur==2025.1.1 | " + "intel-cmplr-lic-rt==2025.1.1 | " + "intel-sycl-rt==2025.1.1 | " + "oneccl-devel==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "oneccl==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "onemkl-sycl-blas==2025.1.0 | " + "onemkl-sycl-dft==2025.1.0 | " + "onemkl-sycl-lapack==2025.1.0 | " + "onemkl-sycl-rng==2025.1.0 | " + "onemkl-sycl-sparse==2025.1.0 | " + "dpcpp-cpp-rt==2025.1.1 | " + "intel-opencl-rt==2025.1.1 | " + "mkl==2025.1.0 | " + "intel-openmp==2025.1.1 | " + "tbb==2022.1.0 | " + "tcmlib==1.3.0 | " + "umf==0.10.0 | " + "intel-pti==0.12.3" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), } @@ -124,7 +211,13 @@ def get_nccl_wheel_version(arch_version: str) -> str: requirements = map( str.strip, re.split("[;|]", PYTORCH_EXTRA_INSTALL_REQUIREMENTS[arch_version]) ) +<<<<<<< HEAD return next(x for x in requirements if x.startswith("nvidia-nccl")).split("==")[1] +======= + return next(x for x in requirements if x.startswith("nvidia-nccl-cu")).split("==")[ + 1 + ] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def read_nccl_pin(arch_version: str) -> str: @@ -191,7 +284,11 @@ def arch_type(arch_version: str) -> str: "cpu": "libtorch-cxx11-builder:cpu", } +<<<<<<< HEAD FULL_PYTHON_VERSIONS = ["3.10", "3.11", "3.12", "3.13", "3.13t", "3.14", "3.14t"] +======= +FULL_PYTHON_VERSIONS = ["3.9", "3.10", "3.11", "3.12", "3.13", "3.13t"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def translate_desired_cuda(gpu_arch_type: str, gpu_arch_version: str) -> str: @@ -271,6 +368,10 @@ def generate_wheels_matrix( os: str, arches: Optional[list[str]] = None, python_versions: Optional[list[str]] = None, +<<<<<<< HEAD +======= + use_split_build: bool = False, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> list[dict[str, str]]: package_type = "wheel" if os == "linux" or os == "linux-aarch64" or os == "linux-s390x": @@ -309,6 +410,7 @@ def generate_wheels_matrix( else arch_version ) +<<<<<<< HEAD # TODO: Enable python 3.14 for rest if os not in [ "linux", @@ -323,6 +425,25 @@ def generate_wheels_matrix( if ( arch_version in ["13.0", "12.8", "12.6"] +======= + # TODO: Enable python 3.13t on cpu-s390x + if gpu_arch_type == "cpu-s390x" and python_version == "3.13t": + continue + + if use_split_build and ( + arch_version not in ["12.6", "12.8", "12.9", "cpu"] or os != "linux" + ): + raise RuntimeError( + "Split build is only supported on linux with cuda 12* and cpu.\n" + f"Currently attempting to build on arch version {arch_version} and os {os}.\n" + "Please modify the matrix generation to exclude this combination." + ) + + # cuda linux wheels require PYTORCH_EXTRA_INSTALL_REQUIREMENTS to install + + if ( + arch_version in ["12.9", "12.8", "12.6"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) and os == "linux" or arch_version in CUDA_AARCH64_ARCHES ): @@ -333,6 +454,10 @@ def generate_wheels_matrix( "gpu_arch_type": gpu_arch_type, "gpu_arch_version": gpu_arch_version, "desired_cuda": desired_cuda, +<<<<<<< HEAD +======= + "use_split_build": "True" if use_split_build else "False", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "container_image": WHEEL_CONTAINER_IMAGES[arch_version].split( ":" )[0], @@ -355,6 +480,33 @@ def generate_wheels_matrix( ), # include special case for aarch64 build, remove the -aarch64 postfix } ) +<<<<<<< HEAD +======= + # Special build building to use on Colab. Python 3.11 for 12.6 CUDA + if python_version == "3.11" and arch_version == CUDA_STABLE: + ret.append( + { + "python_version": python_version, + "gpu_arch_type": gpu_arch_type, + "gpu_arch_version": gpu_arch_version, + "desired_cuda": translate_desired_cuda( + gpu_arch_type, gpu_arch_version + ), + "use_split_build": "True" if use_split_build else "False", + "container_image": WHEEL_CONTAINER_IMAGES[ + arch_version + ].split(":")[0], + "container_image_tag_prefix": WHEEL_CONTAINER_IMAGES[ + arch_version + ].split(":")[1], + "package_type": package_type, + "pytorch_extra_install_requirements": "", + "build_name": f"{package_type}-py{python_version}-{gpu_arch_type}{gpu_arch_version}-full".replace( # noqa: B950 + ".", "_" + ), + } + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: ret.append( { @@ -364,6 +516,10 @@ def generate_wheels_matrix( "desired_cuda": translate_desired_cuda( gpu_arch_type, gpu_arch_version ), +<<<<<<< HEAD +======= + "use_split_build": "True" if use_split_build else "False", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "container_image": WHEEL_CONTAINER_IMAGES[arch_version].split( ":" )[0], @@ -385,6 +541,10 @@ def generate_wheels_matrix( return ret +<<<<<<< HEAD validate_nccl_dep_consistency("13.0") +======= +validate_nccl_dep_consistency("12.9") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) validate_nccl_dep_consistency("12.8") validate_nccl_dep_consistency("12.6") diff --git a/.github/scripts/generate_ci_workflows.py b/.github/scripts/generate_ci_workflows.py index 0396c405ad0a7..542e6a70e4a9f 100755 --- a/.github/scripts/generate_ci_workflows.py +++ b/.github/scripts/generate_ci_workflows.py @@ -22,7 +22,10 @@ LABEL_CIFLOW_PERIODIC = "ciflow/periodic" LABEL_CIFLOW_BINARIES_LIBTORCH = "ciflow/binaries_libtorch" LABEL_CIFLOW_BINARIES_WHEEL = "ciflow/binaries_wheel" +<<<<<<< HEAD LABEL_CIFLOW_ROCM = "ciflow/rocm" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclass @@ -59,7 +62,13 @@ class BinaryBuildWorkflow: is_scheduled: str = "" branches: str = "nightly" # Mainly for macos +<<<<<<< HEAD macos_runner: str = "macos-14-xlarge" +======= + cross_compile_arm64: bool = False + macos_runner: str = "macos-14-xlarge" + use_split_build: bool = False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Mainly used for libtorch builds build_variant: str = "" @@ -70,6 +79,12 @@ def __post_init__(self) -> None: for item in [self.os, "binary", self.package_type, self.build_variant] if item != "" ) +<<<<<<< HEAD +======= + if self.use_split_build: + # added to distinguish concurrency groups + self.build_environment += "-split" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def generate_workflow_file(self, workflow_template: jinja2.Template) -> None: output_file_path = ( @@ -112,6 +127,24 @@ class OperatingSystem: isolated_workflow=True, ), ), +<<<<<<< HEAD +======= + # See https://github.com/pytorch/pytorch/issues/138750 + # BinaryBuildWorkflow( + # os=OperatingSystem.LINUX, + # package_type="manywheel", + # build_configs=generate_binary_build_matrix.generate_wheels_matrix( + # OperatingSystem.LINUX, + # use_split_build=True, + # arches=["11.8", "12.1", "12.4", "cpu"], + # ), + # ciflow_config=CIFlowConfig( + # labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_WHEEL}, + # isolated_workflow=True, + # ), + # use_split_build=True, + # ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) BinaryBuildWorkflow( os=OperatingSystem.LINUX, package_type="libtorch", @@ -127,6 +160,7 @@ class OperatingSystem: ), ] +<<<<<<< HEAD ROCM_SMOKE_WORKFLOWS = [ BinaryBuildWorkflow( os=OperatingSystem.LINUX, @@ -149,17 +183,43 @@ class OperatingSystem: ), ] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) LINUX_BINARY_SMOKE_WORKFLOWS = [ BinaryBuildWorkflow( os=OperatingSystem.LINUX, package_type="manywheel", build_configs=generate_binary_build_matrix.generate_wheels_matrix( OperatingSystem.LINUX, +<<<<<<< HEAD arches=["12.8"], python_versions=["3.12"], ), branches="main", ), +======= + arches=["12.6", "12.8", "12.9", "6.4"], + python_versions=["3.9"], + ), + branches="main", + ), + # See https://github.com/pytorch/pytorch/issues/138750 + # BinaryBuildWorkflow( + # os=OperatingSystem.LINUX, + # package_type="manywheel", + # build_configs=generate_binary_build_matrix.generate_wheels_matrix( + # OperatingSystem.LINUX, + # arches=["11.8", "12.1", "12.4"], + # python_versions=["3.9"], + # use_split_build=True, + # ), + # ciflow_config=CIFlowConfig( + # labels={LABEL_CIFLOW_PERIODIC}, + # ), + # branches="main", + # use_split_build=True, + # ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) BinaryBuildWorkflow( os=OperatingSystem.LINUX, package_type="libtorch", @@ -302,6 +362,10 @@ class OperatingSystem: generate_binary_build_matrix.RELEASE, libtorch_variants=["shared-with-deps"], ), +<<<<<<< HEAD +======= + cross_compile_arm64=False, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) macos_runner="macos-14-xlarge", ciflow_config=CIFlowConfig( labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_LIBTORCH}, @@ -314,6 +378,10 @@ class OperatingSystem: build_configs=generate_binary_build_matrix.generate_wheels_matrix( OperatingSystem.MACOS_ARM64 ), +<<<<<<< HEAD +======= + cross_compile_arm64=False, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) macos_runner="macos-14-xlarge", ciflow_config=CIFlowConfig( labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_WHEEL}, @@ -373,11 +441,14 @@ def main() -> None: S390X_BINARY_BUILD_WORKFLOWS, ), ( +<<<<<<< HEAD # Give rocm it's own workflow file jinja_env.get_template("linux_binary_build_workflow.yml.j2"), ROCM_SMOKE_WORKFLOWS, ), ( +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) jinja_env.get_template("linux_binary_build_workflow.yml.j2"), LINUX_BINARY_SMOKE_WORKFLOWS, ), diff --git a/.github/scripts/get_workflow_job_id.py b/.github/scripts/get_workflow_job_id.py index b04cbed76e955..bf8e669531096 100644 --- a/.github/scripts/get_workflow_job_id.py +++ b/.github/scripts/get_workflow_job_id.py @@ -136,10 +136,17 @@ def find_job_id_name(args: Any) -> tuple[str, str]: def set_output(name: str, val: Any) -> None: +<<<<<<< HEAD print(f"Setting output {name}={val}") if os.getenv("GITHUB_OUTPUT"): with open(str(os.getenv("GITHUB_OUTPUT")), "a") as env: print(f"{name}={val}", file=env) +======= + if os.getenv("GITHUB_OUTPUT"): + with open(str(os.getenv("GITHUB_OUTPUT")), "a") as env: + print(f"{name}={val}", file=env) + print(f"setting {name}={val}") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: print(f"::set-output name={name}::{val}") diff --git a/.github/scripts/lintrunner.sh b/.github/scripts/lintrunner.sh index b353617a45b2b..cd04147193c63 100755 --- a/.github/scripts/lintrunner.sh +++ b/.github/scripts/lintrunner.sh @@ -2,7 +2,11 @@ set -ex # Use uv to speed up lintrunner init +<<<<<<< HEAD python3 -m pip install -U uv==0.8.* setuptools +======= +python3 -m pip install uv==0.1.45 setuptools +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CACHE_DIRECTORY="/tmp/.lintbin" # Try to recover the cached binaries diff --git a/.github/scripts/parse_ref.py b/.github/scripts/parse_ref.py index e821750a49e10..05433caa11efa 100755 --- a/.github/scripts/parse_ref.py +++ b/.github/scripts/parse_ref.py @@ -5,7 +5,10 @@ def set_output(name: str, val: str) -> None: +<<<<<<< HEAD print(f"Setting output {name}={val}") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if os.getenv("GITHUB_OUTPUT"): with open(str(os.getenv("GITHUB_OUTPUT")), "a") as env: print(f"{name}={val}", file=env) diff --git a/.github/scripts/runner_determinator.py b/.github/scripts/runner_determinator.py index baf560234549b..9af3be41dd65b 100644 --- a/.github/scripts/runner_determinator.py +++ b/.github/scripts/runner_determinator.py @@ -262,12 +262,16 @@ def is_exception_branch(branch: str) -> bool: """ Branches that get opted out of experiments by default, until they're explicitly enabled. """ +<<<<<<< HEAD return branch.split("/", maxsplit=1)[0] in { "main", "nightly", "release", "landchecks", } +======= + return branch.split("/")[0] in {"main", "nightly", "release", "landchecks"} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def load_yaml(yaml_text: str) -> Any: diff --git a/.github/scripts/tag_docker_images_for_release.py b/.github/scripts/tag_docker_images_for_release.py new file mode 100644 index 0000000000000..b2bf474575f6f --- /dev/null +++ b/.github/scripts/tag_docker_images_for_release.py @@ -0,0 +1,64 @@ +import argparse +import subprocess + +import generate_binary_build_matrix + + +def tag_image( + image: str, + default_tag: str, + release_version: str, + dry_run: str, + tagged_images: dict[str, bool], +) -> None: + if image in tagged_images: + return + release_image = image.replace(f"-{default_tag}", f"-{release_version}") + print(f"Tagging {image} to {release_image} , dry_run: {dry_run}") + + if dry_run == "disabled": + subprocess.check_call(["docker", "pull", image]) + subprocess.check_call(["docker", "tag", image, release_image]) + subprocess.check_call(["docker", "push", release_image]) + tagged_images[image] = True + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument( + "--version", + help="Version to tag", + type=str, + default="2.2", + ) + parser.add_argument( + "--dry-run", + help="No Runtime Error check", + type=str, + choices=["enabled", "disabled"], + default="enabled", + ) + + options = parser.parse_args() + tagged_images: dict[str, bool] = {} + platform_images = [ + generate_binary_build_matrix.WHEEL_CONTAINER_IMAGES, + generate_binary_build_matrix.LIBTORCH_CONTAINER_IMAGES, + ] + default_tag = generate_binary_build_matrix.DEFAULT_TAG + + for platform_image in platform_images: # type: ignore[attr-defined] + for arch in platform_image.keys(): # type: ignore[attr-defined] + if arch == "cpu-s390x": + continue + tag_image( + platform_image[arch], # type: ignore[index] + default_tag, + options.version, + options.dry_run, + tagged_images, + ) + + +if __name__ == "__main__": + main() diff --git a/.github/scripts/td_llm_indexer.sh b/.github/scripts/td_llm_indexer.sh index cc8f363659ba6..834664fc00d24 100644 --- a/.github/scripts/td_llm_indexer.sh +++ b/.github/scripts/td_llm_indexer.sh @@ -6,7 +6,11 @@ set -euxo pipefail cd llm-target-determinator pip install -q -r requirements.txt cd ../codellama +<<<<<<< HEAD pip install --no-build-isolation -v -e . +======= +pip install -e . +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pip install numpy==1.26.0 # Run indexer diff --git a/.github/scripts/test_trymerge.py b/.github/scripts/test_trymerge.py index ac3a1cc12921c..ff88390297c96 100755 --- a/.github/scripts/test_trymerge.py +++ b/.github/scripts/test_trymerge.py @@ -27,7 +27,10 @@ get_drci_classifications, gh_get_team_members, GitHubPR, +<<<<<<< HEAD iter_issue_timeline_until_comment, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) JobCheckState, main as trymerge_main, MandatoryChecksMissingError, @@ -35,8 +38,11 @@ RE_GHSTACK_DESC, read_merge_rules, remove_job_name_suffix, +<<<<<<< HEAD sha_from_committed_event, sha_from_force_push_after, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) validate_revert, ) @@ -73,9 +79,12 @@ def save_mocked_queries(obj: Any) -> None: if key in mocked_queries: return mocked_queries[key] +<<<<<<< HEAD # TODO: Remove me once https://github.com/pytorch/pytorch/issues/160489 is resolved raise ValueError(f"Key {key} could not be found in gql_mocks") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: rc = fallback_function(*args) except HTTPError as err: @@ -127,7 +136,11 @@ def __init__(self) -> None: self.force = force self.pr_num = 76123 self.dry_run = True +<<<<<<< HEAD self.comment_id = 12345 # Set to non-zero value +======= + self.comment_id = 0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.reason = "this is for testing" self.ignore_current = False self.check_mergeability = False @@ -155,9 +168,15 @@ def mock_revert( def mock_merge( pr: GitHubPR, repo: GitRepo, +<<<<<<< HEAD comment_id: int, dry_run: bool = False, skip_mandatory_checks: bool = False, +======= + dry_run: bool = False, + skip_mandatory_checks: bool = False, + comment_id: Optional[int] = None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) timeout_minutes: int = 400, stale_pr_days: int = 3, ignore_current: bool = False, @@ -473,9 +492,15 @@ def test_main_force( mock_merge.assert_called_once_with( mock.ANY, mock.ANY, +<<<<<<< HEAD comment_id=mock.ANY, dry_run=mock.ANY, skip_mandatory_checks=True, +======= + dry_run=mock.ANY, + skip_mandatory_checks=True, + comment_id=mock.ANY, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ignore_current=False, ) @@ -488,9 +513,15 @@ def test_main_merge(self, mock_merge: Any, *args: Any) -> None: mock_merge.assert_called_once_with( mock.ANY, mock.ANY, +<<<<<<< HEAD comment_id=mock.ANY, dry_run=mock.ANY, skip_mandatory_checks=False, +======= + dry_run=mock.ANY, + skip_mandatory_checks=False, + comment_id=mock.ANY, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ignore_current=False, ) @@ -1141,6 +1172,7 @@ def test__revlist_to_prs_two_prs( ) +<<<<<<< HEAD @mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql) @mock.patch("trymerge.gh_fetch_merge_base", return_value="") @mock.patch( @@ -1312,5 +1344,7 @@ def test_get_commit_sha_at_comment_exception( self.assertIsNone(sha) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": main() diff --git a/.github/scripts/trymerge.py b/.github/scripts/trymerge.py index 00b66869dcf2a..4e6e4c835f615 100755 --- a/.github/scripts/trymerge.py +++ b/.github/scripts/trymerge.py @@ -108,6 +108,13 @@ def __init__(self, name: str, url: str, run_id: int, status: Optional[str]): fragment PRCheckSuites on CheckSuiteConnection { edges { node { +<<<<<<< HEAD +======= + app { + name + databaseId + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) workflowRun { workflow { name @@ -450,6 +457,7 @@ def __init__(self, name: str, url: str, run_id: int, status: Optional[str]): IGNORABLE_FAILED_CHECKS_THESHOLD = 10 +<<<<<<< HEAD def iter_issue_timeline_until_comment( org: str, repo: str, issue_number: int, target_comment_id: int, max_pages: int = 200 ) -> Any: @@ -507,6 +515,8 @@ def sha_from_force_push_after(ev: dict[str, Any]) -> Optional[str]: return ev.get("after_sha") or ev.get("head_sha") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def gh_get_pr_info(org: str, proj: str, pr_no: int) -> Any: rc = gh_graphql(GH_GET_PR_INFO_QUERY, name=proj, owner=org, number=pr_no) return rc["data"]["repository"]["pullRequest"] @@ -794,6 +804,7 @@ def get_changed_files_count(self) -> int: def last_commit(self) -> Any: return self.info["commits"]["nodes"][-1]["commit"] +<<<<<<< HEAD def last_commit_sha(self, default: Optional[str] = None) -> str: # for commits, the oid is the sha @@ -802,16 +813,26 @@ def last_commit_sha(self, default: Optional[str] = None) -> str: return str(self.last_commit().get("oid", default)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_merge_base(self) -> str: if self.merge_base: return self.merge_base +<<<<<<< HEAD last_commit_sha = self.last_commit_sha() +======= + last_commit_oid = self.last_commit()["oid"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # NB: We could use self.base_ref() here for regular PR, however, that doesn't # work for ghstack where the base is the custom branch, i.e. gh/USER/ID/base, # so let's just use main instead self.merge_base = gh_fetch_merge_base( +<<<<<<< HEAD self.org, self.project, last_commit_sha, self.default_branch() +======= + self.org, self.project, last_commit_oid, self.default_branch() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # Fallback to baseRefOid if the API call fails, i.e. rate limit. Note that baseRefOid @@ -900,6 +921,7 @@ def get_approved_by(self) -> list[str]: def get_commit_count(self) -> int: return int(self.info["commits_with_authors"]["totalCount"]) +<<<<<<< HEAD def get_commit_sha_at_comment(self, comment_id: int) -> Optional[str]: """ Get the PR head commit SHA that was present when a specific comment was posted. @@ -938,6 +960,8 @@ def get_commit_sha_at_comment(self, comment_id: int) -> Optional[str]: print(f"Did not find comment with id {comment_id} in the PR timeline") return None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_pr_creator_login(self) -> str: return cast(str, self.info["author"]["login"]) @@ -1254,7 +1278,11 @@ def merge_into( *, skip_mandatory_checks: bool = False, dry_run: bool = False, +<<<<<<< HEAD comment_id: int, +======= + comment_id: Optional[int] = None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ignore_current_checks: Optional[list[str]] = None, ) -> None: # Raises exception if matching rule is not found @@ -1270,7 +1298,11 @@ def merge_into( skip_internal_checks=can_skip_internal_checks(self, comment_id), ignore_current_checks=ignore_current_checks, ) +<<<<<<< HEAD additional_merged_prs = self.merge_changes_locally( +======= + additional_merged_prs = self.merge_changes( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) repo, skip_mandatory_checks, comment_id ) @@ -1299,7 +1331,11 @@ def merge_into( broken_trunk_checks=ignorable_checks.get("BROKEN_TRUNK", []), flaky_checks=ignorable_checks.get("FLAKY", []), unstable_checks=ignorable_checks.get("UNSTABLE", []), +<<<<<<< HEAD last_commit_sha=self.last_commit_sha(default=""), +======= + last_commit_sha=self.last_commit().get("oid", ""), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) merge_base_sha=self.get_merge_base(), merge_commit_sha=merge_commit_sha, is_failed=False, @@ -1320,7 +1356,11 @@ def merge_into( dry_run=dry_run, ) +<<<<<<< HEAD def merge_changes_locally( +======= + def merge_changes( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self, repo: GitRepo, skip_mandatory_checks: bool = False, @@ -1329,15 +1369,38 @@ def merge_changes_locally( skip_all_rule_checks: bool = False, ) -> list["GitHubPR"]: """ +<<<<<<< HEAD :param skip_all_rule_checks: If true, skips all rule checks on ghstack PRs, useful for dry-running merge locally +======= + :param skip_all_rule_checks: If true, skips all rule checks, useful for dry-running merge locally +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ branch_to_merge_into = self.default_branch() if branch is None else branch if repo.current_branch() != branch_to_merge_into: repo.checkout(branch_to_merge_into) +<<<<<<< HEAD # It's okay to skip the commit SHA check for ghstack PRs since # authoring requires write access to the repo. if self.is_ghstack_pr(): +======= + if not self.is_ghstack_pr(): + msg = self.gen_commit_message() + pr_branch_name = f"__pull-request-{self.pr_num}__init__" + repo.fetch(self.last_commit()["oid"], pr_branch_name) + repo._run_git("merge", "--squash", pr_branch_name) + repo._run_git("commit", f'--author="{self.get_author()}"', "-m", msg) + + # Did the PR change since we started the merge? + pulled_sha = repo.show_ref(pr_branch_name) + latest_pr_status = GitHubPR(self.org, self.project, self.pr_num) + if pulled_sha != latest_pr_status.last_commit()["oid"]: + raise RuntimeError( + "PR has been updated since CI checks last passed. Please rerun the merge command." + ) + return [] + else: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.merge_ghstack_into( repo, skip_mandatory_checks, @@ -1345,6 +1408,7 @@ def merge_changes_locally( skip_all_rule_checks=skip_all_rule_checks, ) +<<<<<<< HEAD msg = self.gen_commit_message() pr_branch_name = f"__pull-request-{self.pr_num}__init__" @@ -1387,6 +1451,8 @@ def merge_changes_locally( ) return [] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class MergeRuleFailedError(RuntimeError): def __init__(self, message: str, rule: Optional["MergeRule"] = None) -> None: @@ -1591,7 +1657,11 @@ def find_matching_merge_rule( pending_checks = [] failed_checks = [] +<<<<<<< HEAD hud_link = f"https://hud.pytorch.org/{pr.org}/{pr.project}/commit/{pr.last_commit_sha()}" +======= + hud_link = f"https://hud.pytorch.org/{pr.org}/{pr.project}/commit/{pr.last_commit()['oid']}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if len(failed_checks) > 0: if reject_reason_score < 30000: reject_reason_score = 30000 @@ -2020,9 +2090,13 @@ def validate_revert( else pr.get_comment_by_id(comment_id) ) if comment.editor_login is not None: +<<<<<<< HEAD raise PostCommentError( "Halting the revert as the revert comment has been edited." ) +======= + raise PostCommentError("Don't want to revert based on edited command") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) author_association = comment.author_association author_login = comment.author_login allowed_reverters = ["COLLABORATOR", "MEMBER", "OWNER"] @@ -2289,14 +2363,24 @@ def categorize_checks( def merge( pr: GitHubPR, repo: GitRepo, +<<<<<<< HEAD comment_id: int, dry_run: bool = False, skip_mandatory_checks: bool = False, +======= + dry_run: bool = False, + skip_mandatory_checks: bool = False, + comment_id: Optional[int] = None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) timeout_minutes: int = 400, stale_pr_days: int = 3, ignore_current: bool = False, ) -> None: +<<<<<<< HEAD initial_commit_sha = pr.last_commit_sha() +======= + initial_commit_sha = pr.last_commit()["oid"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pr_link = f"https://github.com/{pr.org}/{pr.project}/pull/{pr.pr_num}" print(f"Attempting merge of {initial_commit_sha} ({pr_link})") @@ -2367,7 +2451,11 @@ def merge( f"Attempting merge of https://github.com/{pr.org}/{pr.project}/pull/{pr.pr_num} ({elapsed_time / 60} minutes elapsed)" ) pr = GitHubPR(pr.org, pr.project, pr.pr_num) +<<<<<<< HEAD if initial_commit_sha != pr.last_commit_sha(): +======= + if initial_commit_sha != pr.last_commit()["oid"]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise RuntimeError( "New commits were pushed while merging. Please rerun the merge command." ) @@ -2534,7 +2622,11 @@ def handle_exception(e: Exception, title: str = "Merge failed") -> None: if args.check_mergeability: if pr.is_ghstack_pr(): get_ghstack_prs(repo, pr) # raises error if out of sync +<<<<<<< HEAD pr.merge_changes_locally( +======= + pr.merge_changes( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) repo, skip_mandatory_checks=True, skip_all_rule_checks=True, @@ -2549,6 +2641,7 @@ def handle_exception(e: Exception, title: str = "Merge failed") -> None: gh_post_pr_comment(org, project, args.pr_num, message, dry_run=args.dry_run) return try: +<<<<<<< HEAD # Ensure comment id is set, else fail if not args.comment_id: raise ValueError( @@ -2561,6 +2654,14 @@ def handle_exception(e: Exception, title: str = "Merge failed") -> None: comment_id=args.comment_id, dry_run=args.dry_run, skip_mandatory_checks=args.force, +======= + merge( + pr, + repo, + dry_run=args.dry_run, + skip_mandatory_checks=args.force, + comment_id=args.comment_id, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ignore_current=args.ignore_current, ) except Exception as e: @@ -2582,7 +2683,11 @@ def handle_exception(e: Exception, title: str = "Merge failed") -> None: broken_trunk_checks=[], flaky_checks=[], unstable_checks=[], +<<<<<<< HEAD last_commit_sha=pr.last_commit_sha(default=""), +======= + last_commit_sha=pr.last_commit().get("oid", ""), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) merge_base_sha=pr.get_merge_base(), is_failed=True, skip_mandatory_checks=args.force, diff --git a/.github/scripts/windows/build_magma.bat b/.github/scripts/windows/build_magma.bat index 75c916ecdbef7..28977ee042ffc 100644 --- a/.github/scripts/windows/build_magma.bat +++ b/.github/scripts/windows/build_magma.bat @@ -17,7 +17,10 @@ if errorlevel 1 exit /b 1 set "PATH=C:\Tools;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v%CUVER%\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v%CUVER%\libnvvp;%PATH%" set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v%CUVER% +<<<<<<< HEAD set NVTOOLSEXT_PATH=C:\Program Files\NVIDIA Corporation\NvToolsExt +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mkdir magma_cuda%CUVER_NODOT% cd magma_cuda%CUVER_NODOT% @@ -35,9 +38,12 @@ cd magma mkdir build && cd build set GPU_TARGET=All +<<<<<<< HEAD if "%CUVER_NODOT%" == "130" ( set CUDA_ARCH_LIST=-gencode=arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -gencode arch=compute_90,code=sm_90 -gencode arch=compute_100,code=sm_100 -gencode arch=compute_120,code=sm_120 ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if "%CUVER_NODOT%" == "129" ( set CUDA_ARCH_LIST=-gencode=arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -gencode arch=compute_90,code=sm_90 -gencode arch=compute_100,code=sm_100 -gencode arch=compute_120,code=sm_120 ) diff --git a/.github/scripts/windows/build_triton.bat b/.github/scripts/windows/build_triton.bat index d26dc8bf3b198..761d5cfbc962f 100644 --- a/.github/scripts/windows/build_triton.bat +++ b/.github/scripts/windows/build_triton.bat @@ -1,12 +1,30 @@ @echo on +<<<<<<< HEAD set DESIRED_PYTHON=%PY_VERS% call .ci/pytorch/windows/internal/install_python.bat :: Fix cmake version for issue https://github.com/pytorch/pytorch/issues/150480 %PYTHON_EXEC% -m pip install wheel pybind11 certifi cython cmake==3.31.6 setuptools==72.1.0 ninja==1.11.1.4 +======= +set PYTHON_PREFIX=%PY_VERS:.=% +set PYTHON_PREFIX=py%PYTHON_PREFIX:;=;py% +call .ci/pytorch/win-test-helpers/installation-helpers/activate_miniconda3.bat +:: Create a new conda environment +if "%PY_VERS%" == "3.13t" ( + call conda create -n %PYTHON_PREFIX% -y -c=conda-forge python-freethreading python=3.13 +) else ( + call conda create -n %PYTHON_PREFIX% -y -c=conda-forge python=%PY_VERS% +) +:: Fix cmake version for issue https://github.com/pytorch/pytorch/issues/150480 +call conda run -n %PYTHON_PREFIX% pip install wheel pybind11 certifi cython cmake==3.31.6 setuptools==72.1.0 ninja +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dir "%VC_INSTALL_PATH%" call "%VC_INSTALL_PATH%\VC\Auxiliary\Build\vcvarsall.bat" x64 +<<<<<<< HEAD %PYTHON_EXEC% .github/scripts/build_triton_wheel.py --device=%BUILD_DEVICE% %RELEASE% +======= +call conda run -n %PYTHON_PREFIX% python .github/scripts/build_triton_wheel.py --device=%BUILD_DEVICE% %RELEASE% +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.github/templates/common.yml.j2 b/.github/templates/common.yml.j2 index 7c93fdf522a47..a4646179cdca7 100644 --- a/.github/templates/common.yml.j2 +++ b/.github/templates/common.yml.j2 @@ -4,7 +4,11 @@ {%- set download_artifact_action = "actions/download-artifact@v4.1.7" -%} {%- set timeout_minutes = 240 -%} +<<<<<<< HEAD {%- set timeout_minutes_windows_binary = 360 -%} +======= +{%- set timeout_minutes_windows_binary = 300 -%} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {%- macro concurrency(build_environment) -%} concurrency: @@ -32,7 +36,11 @@ concurrency: {%- macro setup_ec2_windows() -%} !{{ display_ec2_information() }} - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/templates/linux_binary_build_workflow.yml.j2 b/.github/templates/linux_binary_build_workflow.yml.j2 index bf7db5866e783..e36c19e6cc477 100644 --- a/.github/templates/linux_binary_build_workflow.yml.j2 +++ b/.github/templates/linux_binary_build_workflow.yml.j2 @@ -56,7 +56,11 @@ jobs: get-label-type: if: github.repository_owner == 'pytorch' name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -77,9 +81,12 @@ jobs: runs_on: linux.s390x ALPINE_IMAGE: "docker.io/s390x/alpine" timeout-minutes: 420 +<<<<<<< HEAD {%- elif config["gpu_arch_type"] == "rocm" %} runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" timeout-minutes: 300 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {%- elif "conda" in build_environment and config["gpu_arch_type"] == "cuda" %} runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.24xlarge.ephemeral @@ -117,12 +124,21 @@ jobs: ALPINE_IMAGE: "docker.io/s390x/alpine" {%- elif config["gpu_arch_type"] == "rocm" %} runs_on: linux.rocm.gpu +<<<<<<< HEAD {%- elif config["gpu_arch_type"] == "cuda" and config["gpu_arch_version"] in ["12.6"] %} runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu # 12.6 build can use maxwell (sm_50) runner {%- elif config["gpu_arch_type"] == "cuda" %} runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8+ builds need sm_70+ runner +======= + {%- elif config["gpu_arch_type"] == "cuda" and config["gpu_arch_version"] in ["12.8", "12.9"] %} + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner + {%- elif config["gpu_arch_type"] == "cuda" %} + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu # for other cuda versions, we use 4xlarge runner +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {%- else %} runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge @@ -138,7 +154,11 @@ jobs: contents: read steps: - name: Setup XPU +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/setup-xpu@release/2.9 +======= + uses: ./.github/actions/setup-xpu +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: configure aws credentials id: aws_creds uses: aws-actions/configure-aws-credentials@v4 @@ -156,7 +176,11 @@ jobs: !{{ common.checkout(deep_clone=False, directory="pytorch", checkout_pr_head=False) }} - name: Calculate docker image id: calculate-docker-image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: !{{ config["container_image"] }} @@ -164,7 +188,11 @@ jobs: docker-build-dir: .ci/docker working-directory: pytorch - name: Pull Docker image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Test Pytorch binary @@ -199,7 +227,11 @@ jobs: role-duration-seconds: 18000 - name: Calculate docker image id: calculate-docker-image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: !{{ config["container_image"] }} @@ -207,7 +239,11 @@ jobs: docker-build-dir: .ci/docker working-directory: pytorch - name: Pull Docker image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Test Pytorch binary diff --git a/.github/templates/macos_binary_build_workflow.yml.j2 b/.github/templates/macos_binary_build_workflow.yml.j2 index 662060bb13075..e7c09cc0b4204 100644 --- a/.github/templates/macos_binary_build_workflow.yml.j2 +++ b/.github/templates/macos_binary_build_workflow.yml.j2 @@ -47,6 +47,12 @@ env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} PR_NUMBER: ${{ github.event.pull_request.number }} SKIP_ALL_TESTS: 0 +<<<<<<< HEAD +======= +{%- if cross_compile_arm64 %} + CROSS_COMPILE_ARM64: 1 +{% endif %} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) !{{ common.concurrency(build_environment) }} jobs: @@ -68,6 +74,14 @@ jobs: chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" +<<<<<<< HEAD +======= + if [ -d "/Applications/Xcode_14.3.1.app" ]; then + echo "DEVELOPER_DIR=/Applications/Xcode_14.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" + elif [ -d "/Applications/Xcode_13.3.1.app" ]; then + echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" + fi +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) !{{ common.checkout(deep_clone=False, directory="pytorch", checkout_pr_head=False) }} - name: Populate binary env run: | @@ -105,6 +119,7 @@ jobs: # Create new "clean" conda environment for testing SMOKE_TEST_PARAMS="" +<<<<<<< HEAD EXTRA_CONDA_INSTALL_FLAGS="" CONDA_ENV_CREATE_FLAGS="" @@ -132,6 +147,14 @@ jobs: # shellcheck disable=SC2086 conda create -yn "test_conda_env" python="$desired_python" ${CONDA_ENV_CREATE_FLAGS} ${EXTRA_CONDA_INSTALL_FLAGS} +======= + if [[ $DESIRED_PYTHON == "3.13t" ]]; then + conda create -yn "test_conda_env" python="3.13" python-freethreading -c conda-forge + SMOKE_TEST_PARAMS="--torch-compile-check disabled" + else + conda create -yn "test_conda_env" python="$DESIRED_PYTHON" + fi +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) conda activate test_conda_env pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v diff --git a/.github/templates/upload.yml.j2 b/.github/templates/upload.yml.j2 index 5e3798f8e2377..ae519cc9a7330 100644 --- a/.github/templates/upload.yml.j2 +++ b/.github/templates/upload.yml.j2 @@ -15,7 +15,11 @@ # favor of GPU_ARCH_VERSION DESIRED_CUDA: !{{ config["desired_cuda"] }} {%- if config["gpu_arch_version"] %} +<<<<<<< HEAD GPU_ARCH_VERSION: "!{{ config["gpu_arch_version"] }}" +======= + GPU_ARCH_VERSION: !{{ config["gpu_arch_version"] }} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {%- endif %} GPU_ARCH_TYPE: !{{ config["gpu_arch_type"] }} {%- if include_skip_tests %} @@ -25,6 +29,14 @@ DOCKER_IMAGE: !{{ config["container_image"] }} DOCKER_IMAGE_TAG_PREFIX: !{{ config["container_image_tag_prefix"] }} {%- endif %} +<<<<<<< HEAD +======= +{%- if config["package_type"] == "manywheel" %} + {%- if config.use_split_build is defined %} + use_split_build: !{{ config["use_split_build"] }} + {%- endif %} +{%- endif %} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {%- if config["package_type"] == "libtorch" %} {%- if config["libtorch_config"] %} LIBTORCH_CONFIG: !{{ config["libtorch_config"] }} @@ -33,7 +45,11 @@ {%- if is_windows %} # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" +======= + DESIRED_PYTHON: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {%- endif %} {%- else %} diff --git a/.github/templates/windows_binary_build_workflow.yml.j2 b/.github/templates/windows_binary_build_workflow.yml.j2 index c61686f8df273..5ea957ef98538 100644 --- a/.github/templates/windows_binary_build_workflow.yml.j2 +++ b/.github/templates/windows_binary_build_workflow.yml.j2 @@ -64,7 +64,11 @@ jobs: get-label-type: if: github.repository_owner == 'pytorch' name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} diff --git a/.github/workflows/_bazel-build-test.yml b/.github/workflows/_bazel-build-test.yml index d9e5e29576d4c..783b8166eaa14 100644 --- a/.github/workflows/_bazel-build-test.yml +++ b/.github/workflows/_bazel-build-test.yml @@ -47,7 +47,11 @@ jobs: reenabled-issues: ${{ steps.filter.outputs.reenabled-issues }} steps: - name: Checkout PyTorch +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.9 +======= + uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: fetch-depth: 1 submodules: false @@ -69,25 +73,41 @@ jobs: runs-on: ${{ matrix.runner }} steps: - name: Setup SSH (Click me for login details) +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: github-secret: ${{ secrets.GITHUB_TOKEN }} # [see note: pytorch repo ref] - name: Checkout PyTorch +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.9 +======= + uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: Setup Linux uses: ./.github/actions/setup-linux - name: Calculate docker image id: calculate-docker-image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image-name: ${{ inputs.docker-image-name }} - name: Pull docker image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} @@ -97,7 +117,11 @@ jobs: run: echo "IN_CONTAINER_RUNNER=$(if [ -f /.inarc ] || [ -f /.incontainer ]; then echo true ; else echo false; fi)" >> "$GITHUB_OUTPUT" - name: Install nvidia driver, nvidia-docker runtime, set GPU_FLAG +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-nvidia@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-nvidia@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ inputs.cuda-version != 'cpu' && steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'false' }} - name: Output disk space left @@ -209,5 +233,9 @@ jobs: file-suffix: bazel-${{ github.job }}_${{ steps.get-job-id.outputs.job-id }} - name: Teardown Linux +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: always() diff --git a/.github/workflows/_binary-build-linux.yml b/.github/workflows/_binary-build-linux.yml index e81e4b6a8b26e..2cb5cecabe62c 100644 --- a/.github/workflows/_binary-build-linux.yml +++ b/.github/workflows/_binary-build-linux.yml @@ -26,6 +26,16 @@ on: default: 240 type: number description: timeout for the job +<<<<<<< HEAD +======= + use_split_build: + description: | + [Experimental] Build a libtorch only wheel and build pytorch such that + are built from the libtorch wheel. + required: false + type: boolean + default: false +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ALPINE_IMAGE: required: false type: string @@ -110,6 +120,10 @@ jobs: PR_NUMBER: ${{ github.event.pull_request.number }} PYTORCH_FINAL_PACKAGE_DIR: /artifacts SHA1: ${{ github.event.pull_request.head.sha || github.sha }} +<<<<<<< HEAD +======= + USE_SPLIT_BUILD: ${{ inputs.use_split_build }} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) steps: - name: Make the env permanent during this workflow (but not the secrets) shell: bash @@ -134,6 +148,10 @@ jobs: echo "PR_NUMBER=${{ env.PR_NUMBER }}" echo "PYTORCH_FINAL_PACKAGE_DIR=${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" echo "SHA1=${{ env.SHA1 }}" +<<<<<<< HEAD +======= + echo "USE_SPLIT_BUILD=${{ env.use_split_build }}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } >> "${GITHUB_ENV} }}" - name: List the env @@ -142,13 +160,21 @@ jobs: - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" if: inputs.build_environment != 'linux-s390x-binary-manywheel' +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.github-token }} - name: Checkout PyTorch +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.9 +======= + uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: no-sudo: ${{ inputs.build_environment == 'linux-aarch64-binary-manywheel' || inputs.build_environment == 'linux-s390x-binary-manywheel' }} @@ -212,9 +238,15 @@ jobs: - name: Calculate docker image id: calculate-docker-image if: ${{ steps.filter.outputs.is-test-matrix-empty == 'False' && inputs.build_environment != 'linux-s390x-binary-manywheel' }} +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.9 with: # If doing this in release/2.9 or release branch, use docker.io. Otherwise +======= + uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.8 + with: + # If doing this in release/2.8 or release branch, use docker.io. Otherwise +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # use ECR docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: ${{ inputs.DOCKER_IMAGE }} @@ -226,7 +258,11 @@ jobs: - name: Pull Docker image if: ${{ steps.filter.outputs.is-test-matrix-empty == 'False' && inputs.build_environment != 'linux-s390x-binary-manywheel' }} +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} @@ -251,6 +287,10 @@ jobs: -e PYTORCH_ROOT \ -e SKIP_ALL_TESTS \ -e PYTORCH_EXTRA_INSTALL_REQUIREMENTS \ +<<<<<<< HEAD +======= + -e USE_SPLIT_BUILD \ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) --tty \ --detach \ -v "${GITHUB_WORKSPACE}/pytorch:/pytorch" \ @@ -282,7 +322,11 @@ jobs: - name: Teardown Linux if: always() && inputs.build_environment != 'linux-s390x-binary-manywheel' +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: Chown workspace if: always() && inputs.build_environment != 'linux-s390x-binary-manywheel' diff --git a/.github/workflows/_binary-test-linux.yml b/.github/workflows/_binary-test-linux.yml index 887ab908b2d8a..8b3f0e5ff46e1 100644 --- a/.github/workflows/_binary-test-linux.yml +++ b/.github/workflows/_binary-test-linux.yml @@ -64,6 +64,16 @@ on: required: true type: string description: Hardware to run this job on. Valid values are linux.4xlarge, linux.4xlarge.nvidia.gpu, linux.arm64.2xlarge, and linux.rocm.gpu +<<<<<<< HEAD +======= + use_split_build: + description: | + [Experimental] Build a libtorch only wheel and build pytorch such that + are built from the libtorch wheel. + required: false + type: boolean + default: false +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: required: true @@ -97,6 +107,10 @@ jobs: PR_NUMBER: ${{ github.event.pull_request.number }} PYTORCH_FINAL_PACKAGE_DIR: /artifacts SHA1: ${{ github.event.pull_request.head.sha || github.sha }} +<<<<<<< HEAD +======= + USE_SPLIT_BUILD: ${{ inputs.use_split_build }} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) steps: - name: Make the env permanent during this workflow (but not the secrets) shell: bash @@ -121,18 +135,30 @@ jobs: echo "PR_NUMBER=${{ env.PR_NUMBER }}" echo "PYTORCH_FINAL_PACKAGE_DIR=${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" echo "SHA1=${{ env.SHA1 }}" +<<<<<<< HEAD +======= + echo "USE_SPLIT_BUILD=${{ env.USE_SPLIT_BUILD }}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } >> "${GITHUB_ENV} }}" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" if: inputs.build_environment != 'linux-s390x-binary-manywheel' +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.github-token }} # Setup the environment - name: Checkout PyTorch +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.9 +======= + uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: no-sudo: ${{ inputs.build_environment == 'linux-aarch64-binary-manywheel' || inputs.build_environment == 'linux-s390x-binary-manywheel' }} @@ -185,7 +211,11 @@ jobs: path: "${{ runner.temp }}/artifacts/" - name: Install nvidia driver, nvidia-docker runtime, set GPU_FLAG +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-nvidia@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-nvidia@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ inputs.GPU_ARCH_TYPE == 'cuda' && steps.filter.outputs.is-test-matrix-empty == 'False' }} - name: configure aws credentials @@ -200,7 +230,11 @@ jobs: - name: Calculate docker image id: calculate-docker-image if: ${{ steps.filter.outputs.is-test-matrix-empty == 'False' && inputs.build_environment != 'linux-s390x-binary-manywheel' }} +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: ${{ inputs.DOCKER_IMAGE }} @@ -210,7 +244,11 @@ jobs: - name: Pull Docker image if: ${{ steps.filter.outputs.is-test-matrix-empty == 'False' && inputs.build_environment != 'linux-s390x-binary-manywheel' }} +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} @@ -222,7 +260,11 @@ jobs: - name: Teardown Linux if: always() && inputs.build_environment != 'linux-s390x-binary-manywheel' +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: Chown workspace if: always() && inputs.build_environment != 'linux-s390x-binary-manywheel' diff --git a/.github/workflows/_binary-upload.yml b/.github/workflows/_binary-upload.yml index 61896f52bbed5..15078dea4e360 100644 --- a/.github/workflows/_binary-upload.yml +++ b/.github/workflows/_binary-upload.yml @@ -51,6 +51,16 @@ on: required: false type: string description: Desired python version +<<<<<<< HEAD +======= + use_split_build: + description: | + [Experimental] Build a libtorch only wheel and build pytorch such that + are built from the libtorch wheel. + required: false + type: boolean + default: false +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: required: true @@ -79,9 +89,16 @@ jobs: PR_NUMBER: ${{ github.event.pull_request.number }} PYTORCH_FINAL_PACKAGE_DIR: /artifacts SHA1: ${{ github.event.pull_request.head.sha || github.sha }} +<<<<<<< HEAD steps: - name: Checkout PyTorch uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.9 +======= + USE_SPLIT_BUILD: ${{ inputs.use_split_build }} + steps: + - name: Checkout PyTorch + uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: no-sudo: true diff --git a/.github/workflows/_docs.yml b/.github/workflows/_docs.yml index 5980ad849fa7e..bf20afcabe448 100644 --- a/.github/workflows/_docs.yml +++ b/.github/workflows/_docs.yml @@ -67,10 +67,17 @@ jobs: # an OOM issue when running the job, so this upgrades the runner from 4xlarge # to the next available tier of 12xlarge. So much memory just to generate cpp # doc +<<<<<<< HEAD runner: ${{ inputs.runner_prefix }}linux.12xlarge.memory # TODO: Nightly cpp docs take longer and longer to finish (more than 3h now) # Let's try to figure out how this can be improved timeout-minutes: 360 +======= + runner: ${{ inputs.runner_prefix }}linux.12xlarge + # TODO: Nightly cpp docs take longer and longer to finish (more than 3h now) + # Let's try to figure out how this can be improved + timeout-minutes: 240 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - docs_type: python runner: ${{ inputs.runner_prefix }}linux.2xlarge # It takes less than 30m to finish python docs unless there are issues @@ -84,7 +91,11 @@ jobs: name: build-docs-${{ matrix.docs_type }}-${{ inputs.push }} steps: - name: Setup SSH (Click me for login details) +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: github-secret: ${{ secrets.GITHUB_TOKEN }} instructions: | @@ -95,7 +106,11 @@ jobs: # [see note: pytorch repo ref] - name: Checkout PyTorch +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.9 +======= + uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: Setup Linux uses: ./.github/actions/setup-linux @@ -110,12 +125,20 @@ jobs: - name: Calculate docker image id: calculate-docker-image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image-name: ${{ inputs.docker-image }} - name: Pull docker image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} @@ -222,5 +245,9 @@ jobs: s3-prefix: pytorch/pytorch/${{ github.event.pull_request.number }}/functorchdocs - name: Teardown Linux +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: always() diff --git a/.github/workflows/_link_check.yml b/.github/workflows/_link_check.yml index 4c46ad28cf6bc..d3578dcba9945 100644 --- a/.github/workflows/_link_check.yml +++ b/.github/workflows/_link_check.yml @@ -11,9 +11,14 @@ on: jobs: lint-urls: if: ${{ github.event_name != 'pull_request' || !contains(github.event.pull_request.labels.*.name, 'skip-url-lint') }} +<<<<<<< HEAD uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@release/2.9 with: job-name: lint-urls +======= + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@release/2.8 + with: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) timeout: 120 runner: ${{ inputs.runner }}linux.2xlarge docker-image: ci-image:pytorch-linux-jammy-linter @@ -37,9 +42,14 @@ jobs: lint-xrefs: if: ${{ github.event_name != 'pull_request' || !contains(github.event.pull_request.labels.*.name, 'skip-xref-lint') }} +<<<<<<< HEAD uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@release/2.9 with: job-name: lint-xrefs +======= + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@release/2.8 + with: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) timeout: 60 runner: ${{ inputs.runner }}linux.2xlarge docker-image: ci-image:pytorch-linux-jammy-linter diff --git a/.github/workflows/_linux-build.yml b/.github/workflows/_linux-build.yml index f909488850d0b..132ba8df3db83 100644 --- a/.github/workflows/_linux-build.yml +++ b/.github/workflows/_linux-build.yml @@ -16,6 +16,14 @@ on: type: boolean default: true description: If set, upload generated build artifacts. +<<<<<<< HEAD +======= + build-with-debug: + required: false + type: boolean + default: false + description: If set, build in debug mode. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sync-tag: required: false type: string @@ -64,6 +72,14 @@ on: required: false type: string default: "" +<<<<<<< HEAD +======= + max-jobs: + description: | + Overwrite the number of jobs to use for the build + required: false + type: string +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) disable-monitor: description: | Disable utilization monitoring for build job @@ -82,6 +98,10 @@ on: required: false type: number default: 1 +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) allow-reuse-old-whl: description: | If set, the build try to pull an old wheel from s3 that was built on a @@ -89,6 +109,7 @@ on: required: false type: boolean default: true +<<<<<<< HEAD build-additional-packages: description: | If set, the build job will also builds these packages and saves their @@ -103,6 +124,8 @@ on: required: false type: string default: "" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: HUGGING_FACE_HUB_TOKEN: @@ -114,6 +137,10 @@ on: description: | FB app token to write to scribe endpoint +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) outputs: docker-image: value: ${{ jobs.build.outputs.docker-image }} @@ -128,12 +155,17 @@ jobs: # Don't run on forked repos if: github.repository_owner == 'pytorch' runs-on: ${{ inputs.runner_prefix}}${{ inputs.runner }} +<<<<<<< HEAD timeout-minutes: 480 +======= + timeout-minutes: 240 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) outputs: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} test-matrix: ${{ steps.filter.outputs.test-matrix }} steps: - name: Setup SSH (Click me for login details) +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 if: inputs.build-environment != 'linux-s390x-binary-manywheel' with: @@ -141,13 +173,23 @@ jobs: instructions: | Build is done inside the container, to start an interactive session run: docker exec -it $(docker container ps --format '{{.ID}}') bash +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 + if: inputs.build-environment != 'linux-s390x-binary-manywheel' + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # [pytorch repo ref] # Use a pytorch/pytorch reference instead of a reference to the local # checkout because when we run this action we don't *have* a local # checkout. In other cases you should prefer a local checkout. - name: Checkout PyTorch +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.9 +======= + uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: no-sudo: true @@ -183,7 +225,11 @@ jobs: - name: Calculate docker image id: calculate-docker-image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: inputs.build-environment != 'linux-s390x-binary-manywheel' with: docker-image-name: ${{ inputs.docker-image-name }} @@ -199,7 +245,11 @@ jobs: echo "docker pull ghcr.io/pytorch/ci-image:${tag/:/-}" - name: Pull docker image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: inputs.build-environment != 'linux-s390x-binary-manywheel' && steps.use-old-whl.outputs.reuse != 'true' with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} @@ -232,7 +282,11 @@ jobs: MONITOR_DATA_COLLECT_INTERVAL: ${{ inputs.monitor-data-collect-interval }} run: | mkdir -p ../../usage_logs +<<<<<<< HEAD python3 -m pip install psutil==5.9.8 dataclasses_json==0.6.7 +======= + python3 -m pip install psutil==5.9.1 dataclasses_json==0.6.7 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) python3 -m tools.stats.monitor \ --log-interval "$MONITOR_LOG_INTERVAL" \ --data-collect-interval "$MONITOR_DATA_COLLECT_INTERVAL" \ @@ -254,6 +308,11 @@ jobs: env: BUILD_ENVIRONMENT: ${{ inputs.build-environment }} BRANCH: ${{ steps.parse-ref.outputs.branch }} +<<<<<<< HEAD +======= + # TODO duplicated + AWS_DEFAULT_REGION: us-east-1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) PR_NUMBER: ${{ github.event.pull_request.number }} SHA1: ${{ github.event.pull_request.head.sha || github.sha }} # Do not set SCCACHE_S3_KEY_PREFIX to share the cache between all build jobs @@ -265,11 +324,19 @@ jobs: DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} DOCKER_IMAGE_S390X: ${{ inputs.docker-image-name }} XLA_CUDA: ${{ contains(inputs.build-environment, 'xla') && '0' || '' }} +<<<<<<< HEAD OUR_GITHUB_JOB_ID: ${{ steps.get-job-id.outputs.job-id }} HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} SCRIBE_GRAPHQL_ACCESS_TOKEN: ${{ secrets.SCRIBE_GRAPHQL_ACCESS_TOKEN }} BUILD_ADDITIONAL_PACKAGES: ${{ inputs.build-additional-packages }} RUNNER: ${{ inputs.runner }} +======= + DEBUG: ${{ inputs.build-with-debug && '1' || '0' }} + OUR_GITHUB_JOB_ID: ${{ steps.get-job-id.outputs.job-id }} + HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + SCRIBE_GRAPHQL_ACCESS_TOKEN: ${{ secrets.SCRIBE_GRAPHQL_ACCESS_TOKEN }} + MAX_JOBS_OVERRIDE: ${{ inputs.max-jobs }} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) run: | START_TIME=$(date +%s) if [[ ${BUILD_ENVIRONMENT} == *"s390x"* ]]; then @@ -289,12 +356,22 @@ jobs: DOCKER_SHELL_CMD= fi +<<<<<<< HEAD +======= + if [[ ${MAX_JOBS_OVERRIDE} == "" ]]; then + MAX_JOBS="$(nproc --ignore=2)" + else + MAX_JOBS="${MAX_JOBS_OVERRIDE}" + fi + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Leaving 1GB for the runner and other things TOTAL_AVAILABLE_MEMORY_IN_GB=$(awk '/MemTotal/ { printf "%.3f \n", $2/1024/1024 - 1 }' /proc/meminfo) # https://docs.docker.com/engine/containers/resource_constraints/#--memory-swap-details, the 3GB swap # comes from https://github.com/pytorch/test-infra/pull/6058 TOTAL_MEMORY_WITH_SWAP=$(("${TOTAL_AVAILABLE_MEMORY_IN_GB%.*}" + 3)) +<<<<<<< HEAD if [[ ${BUILD_ENVIRONMENT} == *"riscv64"* ]]; then # EC2 specific setup for RISC-V emulation # Ensure binfmt_misc is available @@ -320,13 +397,22 @@ jobs: RISCV_DOCKER_ARGS= fi +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # detached container should get cleaned up by teardown_ec2_linux # Used for JENKINS_USER and DOCKER_SHELL_CMD, which can be empty # shellcheck disable=SC2086 container_name=$(docker run \ +<<<<<<< HEAD ${RISCV_DOCKER_ARGS} \ -e BUILD_ENVIRONMENT \ -e MAX_JOBS="$(nproc --ignore=2)" \ +======= + -e BUILD_ENVIRONMENT \ + -e MAX_JOBS=${MAX_JOBS} \ + -e MAX_JOBS_OVERRIDE \ + -e AWS_DEFAULT_REGION \ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -e PR_NUMBER \ -e SHA1 \ -e BRANCH \ @@ -340,8 +426,12 @@ jobs: -e OUR_GITHUB_JOB_ID \ -e HUGGING_FACE_HUB_TOKEN \ -e SCRIBE_GRAPHQL_ACCESS_TOKEN \ +<<<<<<< HEAD -e BUILD_ADDITIONAL_PACKAGES \ -e RUNNER \ +======= + -e USE_SPLIT_BUILD \ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) --memory="${TOTAL_AVAILABLE_MEMORY_IN_GB%.*}g" \ --memory-swap="${TOTAL_MEMORY_WITH_SWAP}g" \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ @@ -355,16 +445,20 @@ jobs: "${USED_IMAGE}" \ ${DOCKER_SHELL_CMD} ) +<<<<<<< HEAD if [[ ${BUILD_ENVIRONMENT} == *"s390x"* ]]; then docker exec -t "${container_name}" sh -c "python3 -m pip install -r requirements.txt" fi +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) docker exec -t "${container_name}" sh -c '.ci/pytorch/build.sh' END_TIME=$(date +%s) echo "build_time=$((END_TIME - START_TIME))" >> "$GITHUB_OUTPUT" +<<<<<<< HEAD - name: Build external packages id: build-external-packages if: inputs.build-external-packages != '' && steps.build.outcome != 'skipped' @@ -385,6 +479,8 @@ jobs: mv "$src" "dist/$(dirname "$src")/" fi +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: Stop monitoring script if: ${{ always() && steps.monitor-script.outputs.monitor-script-pid }} shell: bash @@ -457,7 +553,11 @@ jobs: artifact_prefix: usage_log_build_${{ steps.get-job-id.outputs.job-id }} - name: Teardown Linux +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: always() && inputs.build-environment != 'linux-s390x-binary-manywheel' - name: Cleanup docker diff --git a/.github/workflows/_linux-test.yml b/.github/workflows/_linux-test.yml index f413f497d79e8..dafeaa07aab37 100644 --- a/.github/workflows/_linux-test.yml +++ b/.github/workflows/_linux-test.yml @@ -72,10 +72,13 @@ on: required: false description: | HF Auth token to avoid rate limits when downloading models or datasets from hub +<<<<<<< HEAD VLLM_TEST_HUGGING_FACE_TOKEN: required: false description: | HF Auth token to test vllm +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SCRIBE_GRAPHQL_ACCESS_TOKEN: required: false description: | @@ -94,6 +97,7 @@ jobs: environment: ${{ github.ref == 'refs/heads/main' && 'scribe-protected' || startsWith(github.ref, 'refs/heads/release/') && 'scribe-protected' || contains(github.event.pull_request.labels.*.name, 'ci-scribe') && 'scribe-pr' || '' }} runs-on: ${{ matrix.runner }} timeout-minutes: ${{ matrix.mem_leak_check == 'mem_leak_check' && 600 || inputs.timeout-minutes }} +<<<<<<< HEAD permissions: id-token: write contents: read @@ -101,6 +105,12 @@ jobs: - name: Setup SSH (Click me for login details) uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 if: ${{ !contains(matrix.runner, 'b200') && inputs.build-environment != 'linux-s390x-binary-manywheel' }} +======= + steps: + - name: Setup SSH (Click me for login details) + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 + if: ${{ !contains(matrix.runner, 'gcp.a100') && inputs.build-environment != 'linux-s390x-binary-manywheel' }} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: github-secret: ${{ secrets.GITHUB_TOKEN }} instructions: | @@ -108,6 +118,7 @@ jobs: docker exec -it $(docker container ps --format '{{.ID}}') bash - name: Checkout PyTorch +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.9 with: no-sudo: true @@ -125,12 +136,25 @@ jobs: - name: configure aws credentials if: ${{ inputs.aws-role-to-assume != '' && inputs.build-environment != 'linux-s390x-binary-manywheel' }} +======= + uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.8 + with: + no-sudo: true + + - name: Setup Linux + uses: ./.github/actions/setup-linux + if: inputs.build-environment != 'linux-s390x-binary-manywheel' + + - name: configure aws credentials + if : ${{ inputs.aws-role-to-assume != '' && inputs.build-environment != 'linux-s390x-binary-manywheel' }} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0 with: role-to-assume: ${{ inputs.aws-role-to-assume }} role-session-name: gha-linux-test aws-region: us-east-1 +<<<<<<< HEAD - name: Login to Amazon ECR if: ${{ inputs.aws-role-to-assume != '' && contains(matrix.runner, 'b200') }} id: login-ecr @@ -140,6 +164,11 @@ jobs: - name: Calculate docker image id: calculate-docker-image uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.9 +======= + - name: Calculate docker image + id: calculate-docker-image + uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: inputs.build-environment != 'linux-s390x-binary-manywheel' with: docker-image-name: ${{ inputs.docker-image }} @@ -155,7 +184,11 @@ jobs: echo "docker pull ghcr.io/pytorch/ci-image:${tag/:/-}" - name: Pull docker image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: inputs.build-environment != 'linux-s390x-binary-manywheel' with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} @@ -167,20 +200,33 @@ jobs: - name: Install nvidia driver, nvidia-docker runtime, set GPU_FLAG id: install-nvidia-driver +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-nvidia@release/2.9 with: driver-version: ${{ matrix.config == 'legacy_nvidia_driver' && '525.105.17' || '580.82.07' }} if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'false' && !contains(matrix.runner, 'b200') }} +======= + uses: pytorch/test-infra/.github/actions/setup-nvidia@release/2.8 + if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'false' }} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: Setup GPU_FLAG for docker run id: setup-gpu-flag run: echo "GPU_FLAG=--gpus all -e NVIDIA_DRIVER_CAPABILITIES=all" >> "${GITHUB_ENV}" +<<<<<<< HEAD if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && (steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' || contains(matrix.runner, 'b200')) }} +======= + if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' }} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: Setup SCCACHE_SERVER_PORT environment for docker run when on container id: setup-sscache-port-flag run: echo "SCCACHE_SERVER_PORT_DOCKER_FLAG=-e SCCACHE_SERVER_PORT=$((RUNNER_UID + 4226))" >> "${GITHUB_ENV}" +<<<<<<< HEAD if: ${{ steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' && !contains(matrix.runner, 'b200') }} +======= + if: ${{ steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' }} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: Lock NVIDIA A100 40GB Frequency run: | @@ -209,7 +255,11 @@ jobs: MONITOR_LOG_INTERVAL: ${{ inputs.monitor-log-interval }} MONITOR_DATA_COLLECT_INTERVAL: ${{ inputs.monitor-data-collect-interval }} run: | +<<<<<<< HEAD python3 -m pip install psutil==5.9.8 dataclasses_json==0.6.7 nvidia-ml-py==11.525.84 +======= + python3 -m pip install psutil==5.9.1 dataclasses_json==0.6.7 nvidia-ml-py==11.525.84 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) python3 -m tools.stats.monitor --log-interval "$MONITOR_LOG_INTERVAL" --data-collect-interval "$MONITOR_DATA_COLLECT_INTERVAL" > usage_log.txt 2>&1 & echo "monitor-script-pid=${!}" >> "${GITHUB_OUTPUT}" @@ -247,12 +297,15 @@ jobs: run: | echo "timeout=$((JOB_TIMEOUT-30))" >> "${GITHUB_OUTPUT}" +<<<<<<< HEAD - name: Preserve github env variables for use in docker shell: bash run: | env | grep '^GITHUB' >> "/tmp/github_env_${GITHUB_RUN_ID}" env | grep '^CI' >> "/tmp/github_env_${GITHUB_RUN_ID}" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: Test id: test timeout-minutes: ${{ fromJson(steps.test-timeout.outputs.timeout) }} @@ -273,8 +326,11 @@ jobs: TEST_CONFIG: ${{ matrix.config }} SHARD_NUMBER: ${{ matrix.shard }} NUM_TEST_SHARDS: ${{ matrix.num_shards }} +<<<<<<< HEAD EXTRA_FLAGS: ${{ matrix.extra_flags || '' }} OP_BENCHMARK_TESTS: ${{ matrix.op_benchmark_tests }} +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) REENABLED_ISSUES: ${{ steps.keep-going.outputs.reenabled-issues }} CONTINUE_THROUGH_ERROR: ${{ steps.keep-going.outputs.keep-going }} VERBOSE_TEST_LOGS: ${{ steps.keep-going.outputs.ci-verbose-test-logs }} @@ -283,8 +339,13 @@ jobs: NO_TD: ${{ steps.keep-going.outputs.ci-no-td }} TD_DISTRIBUTED: ${{ steps.keep-going.outputs.ci-td-distributed }} # Do not set SCCACHE_S3_KEY_PREFIX to share the cache between all build jobs +<<<<<<< HEAD SCCACHE_BUCKET: ${{ !contains(matrix.runner, 'b200') && 'ossci-compiler-cache-circleci-v2' || '' }} SCCACHE_REGION: ${{ !contains(matrix.runner, 'b200') && 'us-east-1' || '' }} +======= + SCCACHE_BUCKET: ossci-compiler-cache-circleci-v2 + SCCACHE_REGION: us-east-1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SHM_SIZE: ${{ contains(inputs.build-environment, 'cuda') && '2g' || '1g' }} DOCKER_IMAGE: ${{ inputs.docker-image }} XLA_CUDA: ${{ contains(inputs.build-environment, 'xla') && '0' || '' }} @@ -292,9 +353,15 @@ jobs: PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0' }} PYTORCH_TEST_RERUN_DISABLED_TESTS: ${{ matrix.rerun_disabled_tests && '1' || '0' }} DASHBOARD_TAG: ${{ inputs.dashboard-tag }} +<<<<<<< HEAD VLLM_TEST_HUGGING_FACE_TOKEN: ${{ secrets.VLLM_TEST_HUGGING_FACE_TOKEN }} HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} SCRIBE_GRAPHQL_ACCESS_TOKEN: ${{ secrets.SCRIBE_GRAPHQL_ACCESS_TOKEN }} +======= + HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + SCRIBE_GRAPHQL_ACCESS_TOKEN: ${{ secrets.SCRIBE_GRAPHQL_ACCESS_TOKEN }} + IS_A100_RUNNER: ${{ contains(matrix.runner, 'a100') && '1' || '0' }} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ARTIFACTS_FILE_SUFFIX: ${{ github.job }}-${{ matrix.config }}-${{ matrix.shard }}-${{ matrix.num_shards }}-${{ matrix.runner }}_${{ steps.get-job-id.outputs.job-id }} run: | set -x @@ -320,6 +387,13 @@ jobs: # if for some reason cleanup action doesn't stop container # when job is cancelled DOCKER_SHELL_CMD="sleep 12h" +<<<<<<< HEAD +======= + + # since some steps are skipped on s390x, if they are necessary, run them here + env | grep '^GITHUB' >> "/tmp/github_env_${GITHUB_RUN_ID}" + env | grep '^CI' >> "/tmp/github_env_${GITHUB_RUN_ID}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else SHM_OPTS="--shm-size=${SHM_SIZE}" JENKINS_USER="--user jenkins" @@ -369,9 +443,15 @@ jobs: -e PYTORCH_TEST_RERUN_DISABLED_TESTS \ -e SKIP_SCCACHE_INITIALIZATION=1 \ -e HUGGING_FACE_HUB_TOKEN \ +<<<<<<< HEAD -e VLLM_TEST_HUGGING_FACE_TOKEN \ -e SCRIBE_GRAPHQL_ACCESS_TOKEN \ -e DASHBOARD_TAG \ +======= + -e SCRIBE_GRAPHQL_ACCESS_TOKEN \ + -e DASHBOARD_TAG \ + -e IS_A100_RUNNER \ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -e ARTIFACTS_FILE_SUFFIX \ --memory="${TOTAL_AVAILABLE_MEMORY_IN_GB%.*}g" \ --memory-swap="${TOTAL_MEMORY_WITH_SWAP}g" \ @@ -410,6 +490,7 @@ jobs: test_config: ${{ matrix.config }} job_identifier: ${{ github.workflow }}_${{ inputs.build-environment }} +<<<<<<< HEAD - name: Authenticate with AWS if: ${{ always() && contains(matrix.runner, 'b200') }} uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0 @@ -421,6 +502,10 @@ jobs: - name: Upload the benchmark results uses: pytorch/test-infra/.github/actions/upload-benchmark-results@release/2.9 +======= + - name: Upload the benchmark results + uses: pytorch/test-infra/.github/actions/upload-benchmark-results@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: inputs.build-environment != 'linux-s390x-binary-manywheel' with: benchmark-results-dir: test/test-reports @@ -478,7 +563,11 @@ jobs: workflow_attempt: ${{github.run_attempt}} - name: Teardown Linux +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: always() && steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'false' # NB: We are currently having an intermittent GPU-related issue on G5 runners with diff --git a/.github/workflows/_mac-build.yml b/.github/workflows/_mac-build.yml index 9561dcc8b8959..7a2c9a0f88d1a 100644 --- a/.github/workflows/_mac-build.yml +++ b/.github/workflows/_mac-build.yml @@ -67,11 +67,19 @@ jobs: test-matrix: ${{ steps.filter.outputs.test-matrix }} steps: - name: Clean up disk space before running MacOS workflow +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/check-disk-space@release/2.9 # [see note: pytorch repo ref] - name: Checkout PyTorch uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/check-disk-space@release/2.8 + + # [see note: pytorch repo ref] + - name: Checkout PyTorch + uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: Set xcode version env: @@ -82,7 +90,11 @@ jobs: fi - name: Setup Python +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-python@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-python@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: python-version: ${{ inputs.python-version }} pip-requirements-file: .github/requirements/pip-requirements-macOS.txt @@ -123,7 +135,11 @@ jobs: else # The runner has access to the S3 bucket via IAM profile without the need # for any credential +<<<<<<< HEAD echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" +======= + echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}"0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) echo "SCCACHE_S3_KEY_PREFIX=${GITHUB_WORKFLOW}" >> "${GITHUB_ENV}" fi @@ -152,14 +168,27 @@ jobs: env: OUR_GITHUB_JOB_ID: ${{ steps.get-job-id.outputs.job-id }} run: | +<<<<<<< HEAD # TODO: Remove me later, and properly activate venv PATH="$VENV_PATH/bin:$PATH" export PATH +======= + echo "CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname "$(which conda)")/../"}" >> "${GITHUB_ENV}" + + if [[ -n "$CONDA_ENV" ]]; then + # Use binaries under conda environment + export PATH="$CONDA_ENV/bin":$PATH + fi +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # NB: Same trick as Linux, there is no need to initialize sccache with the risk of getting # it hangs or timeout at initialization. The cache will be started automatically export SKIP_SCCACHE_INITIALIZATION=1 +<<<<<<< HEAD .ci/pytorch/macos-build.sh +======= + ${CONDA_RUN} .ci/pytorch/macos-build.sh +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: Archive artifacts into zip if: inputs.build-generates-artifacts && steps.build.outcome != 'skipped' @@ -188,4 +217,8 @@ jobs: - name: Clean up disk space if: always() continue-on-error: true +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/check-disk-space@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/check-disk-space@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.github/workflows/_mac-test.yml b/.github/workflows/_mac-test.yml index 29ff3a72817fb..a7cc4b9ff926d 100644 --- a/.github/workflows/_mac-test.yml +++ b/.github/workflows/_mac-test.yml @@ -88,6 +88,7 @@ jobs: pkill "${PROCESS}" || true done +<<<<<<< HEAD - name: Clean up brew miniconda, if installed continue-on-error: true run: | @@ -95,6 +96,11 @@ jobs: brew uninstall miniconda echo "REINSTALL_BREW_MINICONDA=1" >> "${GITHUB_ENV}" fi +======= + - name: Clean up leftover miniconda installation + continue-on-error: true + run: brew uninstall miniconda || true +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: Clean up leftover local python3 site-packages on MacOS pet runner continue-on-error: true @@ -105,11 +111,19 @@ jobs: done - name: Clean up disk space before running MacOS workflow +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/check-disk-space@release/2.9 # [see note: pytorch repo ref] - name: Checkout PyTorch uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/check-disk-space@release/2.8 + + # [see note: pytorch repo ref] + - name: Checkout PyTorch + uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: Get workflow job id id: get-job-id @@ -118,12 +132,15 @@ jobs: with: github-token: ${{ secrets.GITHUB_TOKEN }} +<<<<<<< HEAD - name: Setup Python uses: pytorch/test-infra/.github/actions/setup-python@release/2.9 with: python-version: ${{ inputs.python-version }} pip-requirements-file: .github/requirements/pip-requirements-macOS.txt +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: Start monitoring script id: monitor-script if: ${{ !inputs.disable-monitor }} @@ -136,8 +153,13 @@ jobs: MONITOR_LOG_INTERVAL: ${{ inputs.monitor-log-interval }} MONITOR_DATA_COLLECT_INTERVAL: ${{ inputs.monitor-data-collect-interval }} run: | +<<<<<<< HEAD "$VENV_PATH/bin/python3" -m pip install psutil==5.9.8 dataclasses_json==0.6.7 "$VENV_PATH/bin/python3" -m tools.stats.monitor --log-interval "$MONITOR_LOG_INTERVAL" --data-collect-interval "$MONITOR_DATA_COLLECT_INTERVAL" > usage_log.txt 2>&1 & +======= + python3 -m pip install psutil==5.9.1 dataclasses_json==0.6.7 + python3 -m tools.stats.monitor --log-interval "$MONITOR_LOG_INTERVAL" --data-collect-interval "$MONITOR_DATA_COLLECT_INTERVAL" > usage_log.txt 2>&1 & +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) echo "monitor-script-pid=${!}" >> "${GITHUB_OUTPUT}" - name: Download build artifacts @@ -152,6 +174,16 @@ jobs: with: use-gha: true +<<<<<<< HEAD +======= + - name: Setup Python + uses: pytorch/test-infra/.github/actions/setup-python@release/2.8 + with: + python-version: ${{ inputs.python-version }} + pip-requirements-file: .github/requirements/pip-requirements-macOS.txt + default-packages: "" + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: Parse ref id: parse-ref run: .github/scripts/parse_ref.py @@ -202,7 +234,11 @@ jobs: set -ex # TODO: Remove me later, and properly activate venv +<<<<<<< HEAD PATH="$VENV_PATH/bin:$PATH" +======= + PATH="$(dirname "$(which python)"):$PATH" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) export PATH # Print out some information about the test environment @@ -257,7 +293,11 @@ jobs: file-suffix: ${{ github.job }}-${{ matrix.config }}-${{ matrix.shard }}-${{ matrix.num_shards }}-${{ matrix.runner }}_${{ steps.get-job-id.outputs.job-id }} - name: Upload the benchmark results +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/upload-benchmark-results@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/upload-benchmark-results@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: benchmark-results-dir: test/test-reports dry-run: false @@ -276,6 +316,7 @@ jobs: workflow_attempt: ${{github.run_attempt}} local_path: usage_log.txt +<<<<<<< HEAD - name: Reinstall brew miniconda, if was installed if: always() continue-on-error: true @@ -288,3 +329,9 @@ jobs: if: always() continue-on-error: true uses: pytorch/test-infra/.github/actions/check-disk-space@release/2.9 +======= + - name: Clean up disk space + if: always() + continue-on-error: true + uses: pytorch/test-infra/.github/actions/check-disk-space@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.github/workflows/_rocm-test.yml b/.github/workflows/_rocm-test.yml index b6cd5d88a0941..3b690f338e927 100644 --- a/.github/workflows/_rocm-test.yml +++ b/.github/workflows/_rocm-test.yml @@ -81,13 +81,18 @@ jobs: steps: # [see note: pytorch repo ref] - name: Checkout PyTorch +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.9 +======= + uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: no-sudo: true - name: Setup ROCm uses: ./.github/actions/setup-rocm +<<<<<<< HEAD - name: Runner check GPU count (distributed jobs) if: ${{ contains(matrix.config, 'distributed') }} shell: bash @@ -98,6 +103,8 @@ jobs: exit 1 fi +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: configure aws credentials id: aws_creds uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0 @@ -113,12 +120,20 @@ jobs: - name: Calculate docker image id: calculate-docker-image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image-name: ${{ inputs.docker-image }} - name: Pull docker image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} @@ -142,7 +157,11 @@ jobs: shell: bash continue-on-error: true run: | +<<<<<<< HEAD python3 -m pip install psutil==5.9.8 dataclasses_json==0.6.7 +======= + python3 -m pip install psutil==5.9.1 dataclasses_json==0.6.7 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) python3 -m tools.stats.monitor --log-interval "$MONITOR_LOG_INTERVAL" --data-collect-interval "$MONITOR_DATA_COLLECT_INTERVAL" > usage_log.txt 2>&1 & echo "monitor-script-pid=${!}" >> "${GITHUB_OUTPUT}" @@ -279,8 +298,13 @@ jobs: # copy test results back to the mounted workspace, needed sudo, resulting permissions were correct docker exec -t "${{ env.CONTAINER_NAME }}" sh -c "cd ../pytorch && sudo cp -R test/test-reports ../workspace/test" +<<<<<<< HEAD - name: Change permissions (only needed for kubernetes runners for now) if: ${{ always() && steps.test.conclusion && (contains(matrix.runner, 'gfx942') || contains(matrix.runner, 'mi355')) }} +======= + - name: Change permissions (only needed for MI300 runners for now) + if: ${{ always() && steps.test.conclusion && contains(matrix.runner, 'mi300') }} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) run: | docker exec -t "${{ env.CONTAINER_NAME }}" sh -c "sudo chown -R 1001:1001 test" @@ -330,7 +354,11 @@ jobs: aws-region: us-east-1 - name: Upload the benchmark results +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/upload-benchmark-results@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/upload-benchmark-results@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: benchmark-results-dir: test/test-reports dry-run: false diff --git a/.github/workflows/_runner-determinator.yml b/.github/workflows/_runner-determinator.yml index dd28024dbd806..b71a3a6db130f 100644 --- a/.github/workflows/_runner-determinator.yml +++ b/.github/workflows/_runner-determinator.yml @@ -59,7 +59,11 @@ jobs: PR_NUMBER: ${{ github.event.pull_request.number }} steps: # - name: Checkout PyTorch +<<<<<<< HEAD # uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.9 +======= + # uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # with: # fetch-depth: 1 # submodules: true diff --git a/.github/workflows/_win-build.yml b/.github/workflows/_win-build.yml index 92543128265d1..b2a63480f8417 100644 --- a/.github/workflows/_win-build.yml +++ b/.github/workflows/_win-build.yml @@ -77,7 +77,10 @@ jobs: run: | git config --global core.longpaths true git config --global core.symlinks true +<<<<<<< HEAD git config --global core.ignorecase false +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock # the directory on Windows and prevent GHA from checking out as reported @@ -85,10 +88,17 @@ jobs: git config --global core.fsmonitor false - name: Clean up leftover processes on non-ephemeral Windows runner +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/cleanup-runner@release/2.9 - name: Setup SSH (Click me for login details) uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/cleanup-runner@release/2.8 + + - name: Setup SSH (Click me for login details) + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: github-secret: ${{ secrets.GITHUB_TOKEN }} instructions: | @@ -103,7 +113,11 @@ jobs: # [see note: pytorch repo ref] - name: Checkout PyTorch +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.9 +======= + uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: no-sudo: true @@ -151,7 +165,11 @@ jobs: BUILD_WHEEL: 1 MAX_JOBS: 8 CUDA_VERSION: ${{ inputs.cuda-version }} +<<<<<<< HEAD PYTHON_VERSION: "3.10" +======= + PYTHON_VERSION: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SCCACHE_BUCKET: "ossci-compiler-cache" SCCACHE_S3_KEY_PREFIX: ${{ github.workflow }} SCCACHE_REGION: us-east-1 diff --git a/.github/workflows/_win-test.yml b/.github/workflows/_win-test.yml index 37e48d99e2bed..0d5c1d5da50de 100644 --- a/.github/workflows/_win-test.yml +++ b/.github/workflows/_win-test.yml @@ -70,7 +70,10 @@ jobs: run: | git config --global core.longpaths true git config --global core.symlinks true +<<<<<<< HEAD git config --global core.ignorecase false +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock # the directory on Windows and prevent GHA from checking out as reported @@ -78,10 +81,17 @@ jobs: git config --global core.fsmonitor false - name: Clean up leftover processes on non-ephemeral Windows runner +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/cleanup-runner@release/2.9 - name: Setup SSH (Click me for login details) uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/cleanup-runner@release/2.8 + + - name: Setup SSH (Click me for login details) + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: github-secret: ${{ secrets.GITHUB_TOKEN }} instructions: | @@ -97,7 +107,11 @@ jobs: # [see note: pytorch repo ref] - name: Checkout PyTorch +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.9 +======= + uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: no-sudo: true @@ -139,7 +153,11 @@ jobs: continue-on-error: true run: | # Windows conda doesn't have python3 binary, only python, but it's python3 +<<<<<<< HEAD ${CONDA_RUN} python -m pip install psutil==5.9.8 dataclasses_json==0.6.7 nvidia-ml-py==11.525.84 +======= + ${CONDA_RUN} python -m pip install psutil==5.9.1 dataclasses_json==0.6.7 nvidia-ml-py==11.525.84 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ${CONDA_RUN} python -m tools.stats.monitor --log-interval "$MONITOR_LOG_INTERVAL" --data-collect-interval "$MONITOR_DATA_COLLECT_INTERVAL" > usage_log.txt 2>&1 & echo "monitor-script-pid=${!}" >> "${GITHUB_OUTPUT}" @@ -184,7 +202,11 @@ jobs: env: USE_CUDA: ${{ inputs.cuda-version != 'cpu' && '1' || '0' }} INSTALL_WINDOWS_SDK: 1 +<<<<<<< HEAD PYTHON_VERSION: "3.10" +======= + PYTHON_VERSION: 3.9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CONTINUE_THROUGH_ERROR: ${{ steps.keep-going.outputs.keep-going }} VERBOSE_TEST_LOGS: ${{ steps.keep-going.outputs.ci-verbose-test-logs }} TEST_SHOWLOCALS: ${{ steps.keep-going.outputs.ci-test-showlocals }} diff --git a/.github/workflows/_xpu-test.yml b/.github/workflows/_xpu-test.yml index 6bceb4eef6ba9..546c6a8900ec9 100644 --- a/.github/workflows/_xpu-test.yml +++ b/.github/workflows/_xpu-test.yml @@ -77,7 +77,11 @@ jobs: steps: # [see note: pytorch repo ref] - name: Checkout PyTorch +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.9 +======= + uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: Setup XPU uses: ./.github/actions/setup-xpu @@ -95,7 +99,11 @@ jobs: - name: Calculate docker image id: calculate-docker-image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image-name: ${{ inputs.docker-image }} @@ -109,7 +117,11 @@ jobs: echo "docker pull ghcr.io/pytorch/ci-image:${tag/:/-}" - name: Pull docker image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} @@ -133,7 +145,11 @@ jobs: MONITOR_LOG_INTERVAL: ${{ inputs.monitor-log-interval }} MONITOR_DATA_COLLECT_INTERVAL: ${{ inputs.monitor-data-collect-interval }} run: | +<<<<<<< HEAD python3 -m pip install psutil==5.9.8 dataclasses_json==0.6.7 nvidia-ml-py==11.525.84 +======= + python3 -m pip install psutil==5.9.1 dataclasses_json==0.6.7 nvidia-ml-py==11.525.84 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) python3 -m tools.stats.monitor --log-interval "$MONITOR_LOG_INTERVAL" --data-collect-interval "$MONITOR_DATA_COLLECT_INTERVAL" > usage_log.txt 2>&1 & echo "monitor-script-pid=${!}" >> "${GITHUB_OUTPUT}" @@ -191,6 +207,12 @@ jobs: SHARD_NUMBER: ${{ matrix.shard }} NUM_TEST_SHARDS: ${{ matrix.num_shards }} REENABLED_ISSUES: ${{ steps.keep-going.outputs.reenabled-issues }} +<<<<<<< HEAD +======= + SCCACHE_BUCKET: ossci-compiler-cache-circleci-v2 + SCCACHE_REGION: us-east-1 + SCCACHE_S3_KEY_PREFIX: ${{ github.workflow }} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DOCKER_IMAGE: ${{ inputs.docker-image }} XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0' }} @@ -275,7 +297,11 @@ jobs: - name: Change permissions if: ${{ always() && steps.test.conclusion }} run: | +<<<<<<< HEAD docker exec -t "${{ env.CONTAINER_NAME }}" sh -c "sudo chown -R 1000:1000 test" +======= + docker exec -t "${{ env.CONTAINER_NAME }}" sh -c "sudo chown -R 1001:1001 test" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: Print remaining test logs shell: bash diff --git a/.github/workflows/build-almalinux-images.yml b/.github/workflows/build-almalinux-images.yml index e0492f7364421..96e1b8c544abd 100644 --- a/.github/workflows/build-almalinux-images.yml +++ b/.github/workflows/build-almalinux-images.yml @@ -36,10 +36,17 @@ jobs: runs-on: linux.9xlarge.ephemeral strategy: matrix: +<<<<<<< HEAD tag: ["cuda12.6", "cuda12.8", "cuda12.9", "cuda13.0", "rocm6.3", "rocm6.4", "cpu"] steps: - name: Build docker image uses: pytorch/pytorch/.github/actions/binary-docker-build@release/2.9 +======= + tag: ["cuda12.6", "cuda12.8", "cuda12.9", "rocm6.3", "rocm6.4", "cpu"] + steps: + - name: Build docker image + uses: pytorch/pytorch/.github/actions/binary-docker-build@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image-name: almalinux-builder custom-tag-prefix: ${{matrix.tag}} diff --git a/.github/workflows/build-libtorch-images.yml b/.github/workflows/build-libtorch-images.yml index edfa0168e19fe..d5f2dddae1df9 100644 --- a/.github/workflows/build-libtorch-images.yml +++ b/.github/workflows/build-libtorch-images.yml @@ -32,7 +32,11 @@ jobs: get-label-type: if: github.repository_owner == 'pytorch' name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -48,7 +52,10 @@ jobs: fail-fast: false matrix: include: [ +<<<<<<< HEAD { tag: "cuda13.0" }, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) { tag: "cuda12.9" }, { tag: "cuda12.8" }, { tag: "cuda12.6" }, @@ -58,7 +65,11 @@ jobs: ] steps: - name: Build docker image +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/binary-docker-build@release/2.9 +======= + uses: pytorch/pytorch/.github/actions/binary-docker-build@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image-name: libtorch-cxx11-builder custom-tag-prefix: ${{ matrix.tag }} diff --git a/.github/workflows/build-magma-linux.yml b/.github/workflows/build-magma-linux.yml index be8f613169e8c..d96f1505826ce 100644 --- a/.github/workflows/build-magma-linux.yml +++ b/.github/workflows/build-magma-linux.yml @@ -34,7 +34,11 @@ jobs: id-token: write strategy: matrix: +<<<<<<< HEAD cuda_version: ["130", "129", "128", "126"] +======= + cuda_version: ["129", "128", "126"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) steps: - name: Checkout PyTorch uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 diff --git a/.github/workflows/build-magma-windows.yml b/.github/workflows/build-magma-windows.yml index b7d293a5cec11..fc5ea76151cd5 100644 --- a/.github/workflows/build-magma-windows.yml +++ b/.github/workflows/build-magma-windows.yml @@ -22,7 +22,11 @@ jobs: runs-on: windows-2022 strategy: matrix: +<<<<<<< HEAD cuda_version: ["130", "129", "128", "126"] +======= + cuda_version: ["129", "128", "126"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) config: ["Release", "Debug"] env: CUDA_VERSION: ${{ matrix.cuda_version }} diff --git a/.github/workflows/build-manywheel-images-s390x.yml b/.github/workflows/build-manywheel-images-s390x.yml index a719bf21a1ca4..6dda8a17af207 100644 --- a/.github/workflows/build-manywheel-images-s390x.yml +++ b/.github/workflows/build-manywheel-images-s390x.yml @@ -25,7 +25,11 @@ jobs: runs-on: linux.s390x steps: - name: Checkout PyTorch +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.9 +======= + uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: submodules: false no-sudo: true diff --git a/.github/workflows/build-manywheel-images.yml b/.github/workflows/build-manywheel-images.yml index e3549cd6284a0..eeda122ca140a 100644 --- a/.github/workflows/build-manywheel-images.yml +++ b/.github/workflows/build-manywheel-images.yml @@ -32,7 +32,11 @@ jobs: get-label-type: if: github.repository_owner == 'pytorch' name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -46,12 +50,20 @@ jobs: fail-fast: false matrix: include: [ +<<<<<<< HEAD { name: "manylinux2_28-builder", tag: "cuda13.0", runner: "linux.9xlarge.ephemeral" }, { name: "manylinux2_28-builder", tag: "cuda12.8", runner: "linux.9xlarge.ephemeral" }, { name: "manylinux2_28-builder", tag: "cuda12.6", runner: "linux.9xlarge.ephemeral" }, { name: "manylinuxaarch64-builder", tag: "cuda13.0", runner: "linux.arm64.2xlarge.ephemeral" }, { name: "manylinuxaarch64-builder", tag: "cuda12.8", runner: "linux.arm64.2xlarge.ephemeral" }, { name: "manylinuxaarch64-builder", tag: "cuda12.6", runner: "linux.arm64.2xlarge.ephemeral" }, +======= + { name: "manylinux2_28-builder", tag: "cuda12.9", runner: "linux.9xlarge.ephemeral" }, + { name: "manylinux2_28-builder", tag: "cuda12.8", runner: "linux.9xlarge.ephemeral" }, + { name: "manylinux2_28-builder", tag: "cuda12.6", runner: "linux.9xlarge.ephemeral" }, + { name: "manylinuxaarch64-builder", tag: "cuda12.9", runner: "linux.arm64.2xlarge.ephemeral" }, + { name: "manylinuxaarch64-builder", tag: "cuda12.8", runner: "linux.arm64.2xlarge.ephemeral" }, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) { name: "manylinux2_28-builder", tag: "rocm6.3", runner: "linux.9xlarge.ephemeral" }, { name: "manylinux2_28-builder", tag: "rocm6.4", runner: "linux.9xlarge.ephemeral" }, { name: "manylinux2_28-builder", tag: "cpu", runner: "linux.9xlarge.ephemeral" }, @@ -63,7 +75,11 @@ jobs: name: ${{ matrix.name }}:${{ matrix.tag }} steps: - name: Build docker image +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/binary-docker-build@release/2.9 +======= + uses: pytorch/pytorch/.github/actions/binary-docker-build@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image-name: ${{ matrix.name }} custom-tag-prefix: ${{ matrix.tag }} diff --git a/.github/workflows/build-triton-wheel.yml b/.github/workflows/build-triton-wheel.yml index 8f066de47534c..03d5f64cd227c 100644 --- a/.github/workflows/build-triton-wheel.yml +++ b/.github/workflows/build-triton-wheel.yml @@ -3,12 +3,19 @@ name: Build Triton wheels on: push: branches: +<<<<<<< HEAD - release/2.9 +======= + - release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tags: # NOTE: Binary build pipelines should only get triggered on release candidate builds # Release candidate tags look like: v1.11.0-rc1 - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ +<<<<<<< HEAD - 'ciflow/triton_binaries/*' +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) paths: - .github/workflows/build-triton-wheel.yml - .github/scripts/build_triton_wheel.py @@ -36,7 +43,11 @@ jobs: get-label-type: if: github.repository_owner == 'pytorch' name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -50,7 +61,11 @@ jobs: strategy: fail-fast: false matrix: +<<<<<<< HEAD py_vers: [ "3.9", "3.10", "3.11", "3.12", "3.13", "3.13t", "3.14", "3.14t" ] +======= + py_vers: [ "3.9", "3.10", "3.11", "3.12", "3.13", "3.13t" ] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device: ["cuda", "rocm", "xpu", "aarch64"] docker-image: ["pytorch/manylinux2_28-builder:cpu"] include: @@ -74,12 +89,20 @@ jobs: PLATFORM: 'manylinux_2_28_x86_64' steps: - name: Setup SSH (Click me for login details) +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: github-secret: ${{ secrets.GITHUB_TOKEN }} - name: Checkout PyTorch +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.9 +======= + uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: submodules: false @@ -87,7 +110,11 @@ jobs: uses: ./.github/actions/setup-linux - name: Pull Docker image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image: ${{ env.DOCKER_IMAGE }} @@ -126,12 +153,15 @@ jobs: 3.13t) PYTHON_EXECUTABLE=/opt/python/cp313-cp313t/bin/python ;; +<<<<<<< HEAD 3.14) PYTHON_EXECUTABLE=/opt/python/cp314-cp314/bin/python ;; 3.14t) PYTHON_EXECUTABLE=/opt/python/cp314-cp314t/bin/python ;; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) *) echo "Unsupported python version ${PY_VERS}" exit 1 @@ -145,7 +175,11 @@ jobs: fi docker exec -t "${container_name}" yum install -y zlib-devel zip +<<<<<<< HEAD docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" -m pip install -U setuptools==78.1.0 pybind11==3.0.1 auditwheel wheel +======= + docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" -m pip install -U setuptools==78.1.0 pybind11==2.13.1 auditwheel wheel +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) set +e docker exec -t "${container_name}" command -v pip has_pip=$? @@ -184,7 +218,11 @@ jobs: path: ${{ runner.temp }}/artifacts/wheelhouse/* - name: Teardown Linux +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: always() build-wheel-win: @@ -194,7 +232,11 @@ jobs: strategy: fail-fast: false matrix: +<<<<<<< HEAD py_vers: [ "3.9", "3.10", "3.11", "3.12", "3.13", "3.13t", "3.14", "3.14t" ] +======= + py_vers: [ "3.9", "3.10", "3.11", "3.12", "3.13", "3.13t" ] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device: ["xpu"] timeout-minutes: 40 env: @@ -217,7 +259,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/check-labels.yml b/.github/workflows/check-labels.yml index 1174a1c502f62..82b0e350529a1 100644 --- a/.github/workflows/check-labels.yml +++ b/.github/workflows/check-labels.yml @@ -38,7 +38,11 @@ jobs: runs-on: linux.24_04.4x steps: - name: Checkout PyTorch +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.9 +======= + uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: submodules: false fetch-depth: 1 diff --git a/.github/workflows/check_mergeability_ghstack.yml b/.github/workflows/check_mergeability_ghstack.yml index 569a174665ba8..c94545096896f 100644 --- a/.github/workflows/check_mergeability_ghstack.yml +++ b/.github/workflows/check_mergeability_ghstack.yml @@ -56,7 +56,11 @@ jobs: cache: pip architecture: x64 +<<<<<<< HEAD - run: pip install pyyaml==6.0.2 +======= + - run: pip install pyyaml==6.0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shell: bash - name: Verify mergeability diff --git a/.github/workflows/cherry-pick.yml b/.github/workflows/cherry-pick.yml index 310857782ea14..3153a0fc07175 100644 --- a/.github/workflows/cherry-pick.yml +++ b/.github/workflows/cherry-pick.yml @@ -26,7 +26,11 @@ jobs: cache: pip # Not the direct dependencies but the script uses trymerge +<<<<<<< HEAD - run: pip install pyyaml==6.0.2 +======= + - run: pip install pyyaml==6.0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: Setup committer id run: | diff --git a/.github/workflows/close-nonexistent-disable-issues.yml b/.github/workflows/close-nonexistent-disable-issues.yml index da83019a59084..8113b63237ffd 100644 --- a/.github/workflows/close-nonexistent-disable-issues.yml +++ b/.github/workflows/close-nonexistent-disable-issues.yml @@ -13,7 +13,11 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout PyTorch +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.9 +======= + uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: submodules: false fetch-depth: 1 diff --git a/.github/workflows/create_release.yml b/.github/workflows/create_release.yml index 03631be3e5630..d3ad75ca954e6 100644 --- a/.github/workflows/create_release.yml +++ b/.github/workflows/create_release.yml @@ -19,7 +19,11 @@ jobs: get-label-type: if: github.repository_owner == 'pytorch' name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -57,11 +61,14 @@ jobs: echo "PT_RELEASE_FILE=pytorch-$tag_or_branch.tar.gz" >> "$GITHUB_ENV" - name: Checkout optional submodules run: python3 tools/optional_submodules.py +<<<<<<< HEAD - name: Copy docs requirements for inclusion run: | # Replace symlink with actual file rm docs/requirements.txt || true cp .ci/docker/requirements-docs.txt docs/requirements.txt +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: Create source distribution run: | # Create new folder with specified name so extracting the archive yields that diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index f88244a13ffc0..a490addcd6ac5 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -33,7 +33,11 @@ jobs: get-label-type: if: github.repository_owner == 'pytorch' name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -50,6 +54,7 @@ jobs: runner: [linux.12xlarge] docker-image-name: [ pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11, +<<<<<<< HEAD pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11, pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc11-vllm, pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks, @@ -75,6 +80,34 @@ jobs: # pytorch-linux-jammy-py3-clang12-executorch, pytorch-linux-jammy-py3.12-triton-cpu, pytorch-linux-noble-riscv64-py3.12-gcc14 +======= + pytorch-linux-jammy-cuda12.6-cudnn9-py3-gcc9-inductor-benchmarks, + pytorch-linux-jammy-cuda12.6-cudnn9-py3.12-gcc9-inductor-benchmarks, + pytorch-linux-jammy-cuda12.6-cudnn9-py3.13-gcc9-inductor-benchmarks, + pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks, + pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc9-inductor-benchmarks, + pytorch-linux-jammy-cuda12.8-cudnn9-py3.13-gcc9-inductor-benchmarks, + pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9, + pytorch-linux-jammy-py3.9-clang12, + pytorch-linux-jammy-py3.11-clang12, + pytorch-linux-jammy-py3.12-clang12, + pytorch-linux-jammy-py3.13-clang12, + pytorch-linux-jammy-rocm-n-1-py3, + pytorch-linux-jammy-rocm-n-py3, + pytorch-linux-jammy-cuda12.8-cudnn9-py3.9-clang12, + pytorch-linux-jammy-py3.9-gcc11, + pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks, + pytorch-linux-jammy-py3.12-halide, + pytorch-linux-jammy-xpu-2025.0-py3, + pytorch-linux-jammy-xpu-2025.1-py3, + pytorch-linux-jammy-py3-clang15-asan, + pytorch-linux-jammy-py3-clang18-asan, + pytorch-linux-jammy-py3-clang12-onnx, + pytorch-linux-jammy-linter, + pytorch-linux-jammy-cuda12.8-cudnn9-py3.9-linter, + pytorch-linux-jammy-py3-clang12-executorch, + pytorch-linux-jammy-py3.12-triton-cpu +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] include: - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11 @@ -96,21 +129,33 @@ jobs: # [see note: pytorch repo ref] # deep clone (fetch-depth 0) required for git merge-base - name: Checkout PyTorch +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.9 +======= + uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: Setup Linux uses: ./.github/actions/setup-linux - name: Build docker image id: build-docker-image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image-name: ci-image:${{ matrix.docker-image-name }} always-rebuild: true push: true - name: Pull docker image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image: ${{ steps.build-docker-image.outputs.docker-image }} @@ -123,7 +168,11 @@ jobs: GHCR_PAT: ${{ secrets.GHCR_PAT }} with: shell: bash +<<<<<<< HEAD timeout_minutes: 60 +======= + timeout_minutes: 30 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) max_attempts: 5 retry_wait_seconds: 90 command: | @@ -141,5 +190,9 @@ jobs: if: always() - name: Teardown Linux +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: always() diff --git a/.github/workflows/docker-cache-mi300.yml b/.github/workflows/docker-cache-mi300.yml index bc2ae450f7c20..9f2f5bf97c057 100644 --- a/.github/workflows/docker-cache-mi300.yml +++ b/.github/workflows/docker-cache-mi300.yml @@ -20,7 +20,11 @@ jobs: runs-on: rocm-docker steps: - name: Checkout PyTorch +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.9 +======= + uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: no-sudo: true @@ -39,13 +43,21 @@ jobs: - name: Calculate docker image id: calculate-docker-image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 push: false - name: Pull docker image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} diff --git a/.github/workflows/docker-release.yml b/.github/workflows/docker-release.yml index 134e4caf30882..e940c00412046 100644 --- a/.github/workflows/docker-release.yml +++ b/.github/workflows/docker-release.yml @@ -37,7 +37,11 @@ jobs: get-label-type: if: github.repository_owner == 'pytorch' name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -52,7 +56,11 @@ jobs: matrix: ${{ steps.generate-matrix.outputs.matrix }} steps: - name: Checkout PyTorch +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.9 +======= + uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: fetch-depth: 1 submodules: true @@ -82,7 +90,11 @@ jobs: CUDNN_VERSION: ${{ matrix.cudnn_version }} steps: - name: Setup SSH (Click me for login details) +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: github-secret: ${{ secrets.GITHUB_TOKEN }} # [see note: pytorch repo ref] @@ -144,7 +156,11 @@ jobs: run: | make -f docker.Makefile "${BUILD_IMAGE_TYPE}-image" - name: Push nightly tags +<<<<<<< HEAD if: ${{ github.event.ref == 'refs/heads/nightly' && matrix.image_type == 'runtime' && matrix.platform == 'linux/amd4' }} +======= + if: ${{ github.event.ref == 'refs/heads/nightly' && matrix.image_type == 'runtime' && matrix.build_platforms == 'linux/amd4' }} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) run: | PYTORCH_DOCKER_TAG="${PYTORCH_VERSION}-cuda${CUDA_VERSION_SHORT}-cudnn${CUDNN_VERSION}-runtime" CUDA_SUFFIX="-cu${CUDA_VERSION}" @@ -164,12 +180,20 @@ jobs: fi - name: Teardown Linux +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: always() validate: needs: build +<<<<<<< HEAD uses: pytorch/test-infra/.github/workflows/validate-docker-images.yml@release/2.9 +======= + uses: pytorch/test-infra/.github/workflows/validate-docker-images.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: channel: test ref: main diff --git a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml index 7e36c82644dc8..835feddbb26ef 100644 --- a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml @@ -41,12 +41,135 @@ jobs: get-label-type: if: github.repository_owner == 'pytorch' name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} +<<<<<<< HEAD +======= + manywheel-py3_9-cpu-aarch64-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu-aarch64 + DOCKER_IMAGE: manylinux2_28_aarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 + use_split_build: False + DESIRED_PYTHON: "3.9" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.arm64.m7g.4xlarge.ephemeral + ALPINE_IMAGE: "arm64v8/alpine" + build_name: manywheel-py3_9-cpu-aarch64 + build_environment: linux-aarch64-binary-manywheel + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cpu-aarch64-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_9-cpu-aarch64-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu-aarch64 + DOCKER_IMAGE: manylinux2_28_aarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 + use_split_build: False + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cpu-aarch64 + build_environment: linux-aarch64-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.arm64.2xlarge + ALPINE_IMAGE: "arm64v8/alpine" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cpu-aarch64-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_9-cpu-aarch64-test + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu-aarch64 + DOCKER_IMAGE: manylinux2_28_aarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 + use_split_build: False + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cpu-aarch64 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_9-cuda-aarch64-12_9-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9-aarch64 + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + use_split_build: False + DESIRED_PYTHON: "3.9" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.arm64.m7g.4xlarge.ephemeral + ALPINE_IMAGE: "arm64v8/alpine" + build_name: manywheel-py3_9-cuda-aarch64-12_9 + build_environment: linux-aarch64-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' + timeout-minutes: 420 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cuda-aarch64-12_9-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_9-cuda-aarch64-12_9-build + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9-aarch64 + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + use_split_build: False + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cuda-aarch64-12_9 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + uses: ./.github/workflows/_binary-upload.yml + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) manywheel-py3_10-cpu-aarch64-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml @@ -60,6 +183,10 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral @@ -83,6 +210,10 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -106,13 +237,21 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu-aarch64 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml +<<<<<<< HEAD manywheel-py3_10-cuda-aarch64-12_6-build: +======= + manywheel-py3_10-cuda-aarch64-12_9-build: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -121,15 +260,25 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu126 GPU_ARCH_VERSION: "12.6-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9-aarch64 + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" +<<<<<<< HEAD build_name: manywheel-py3_10-cuda-aarch64-12_6 build_environment: linux-aarch64-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' @@ -137,16 +286,30 @@ jobs: secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda-aarch64-12_6-upload: # Uploading +======= + build_name: manywheel-py3_10-cuda-aarch64-12_9 + build_environment: linux-aarch64-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' + timeout-minutes: 420 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_10-cuda-aarch64-12_9-upload: # Uploading +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read +<<<<<<< HEAD needs: manywheel-py3_10-cuda-aarch64-12_6-build +======= + needs: manywheel-py3_10-cuda-aarch64-12_9-build +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu126 GPU_ARCH_VERSION: "12.6-aarch64" GPU_ARCH_TYPE: cuda-aarch64 @@ -246,6 +409,16 @@ jobs: DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda-aarch64-13_0 +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9-aarch64 + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + use_split_build: False + DESIRED_PYTHON: "3.10" + build_name: manywheel-py3_10-cuda-aarch64-12_9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -263,6 +436,10 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral @@ -286,6 +463,10 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -309,13 +490,21 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu-aarch64 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml +<<<<<<< HEAD manywheel-py3_11-cuda-aarch64-12_6-build: +======= + manywheel-py3_11-cuda-aarch64-12_9-build: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -324,15 +513,25 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu126 GPU_ARCH_VERSION: "12.6-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9-aarch64 + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" +<<<<<<< HEAD build_name: manywheel-py3_11-cuda-aarch64-12_6 build_environment: linux-aarch64-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' @@ -340,16 +539,30 @@ jobs: secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cuda-aarch64-12_6-upload: # Uploading +======= + build_name: manywheel-py3_11-cuda-aarch64-12_9 + build_environment: linux-aarch64-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' + timeout-minutes: 420 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-cuda-aarch64-12_9-upload: # Uploading +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read +<<<<<<< HEAD needs: manywheel-py3_11-cuda-aarch64-12_6-build +======= + needs: manywheel-py3_11-cuda-aarch64-12_9-build +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu126 GPU_ARCH_VERSION: "12.6-aarch64" GPU_ARCH_TYPE: cuda-aarch64 @@ -449,6 +662,16 @@ jobs: DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda-aarch64-13_0 +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9-aarch64 + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + use_split_build: False + DESIRED_PYTHON: "3.11" + build_name: manywheel-py3_11-cuda-aarch64-12_9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -466,6 +689,10 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral @@ -489,6 +716,10 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -512,13 +743,21 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu-aarch64 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml +<<<<<<< HEAD manywheel-py3_12-cuda-aarch64-12_6-build: +======= + manywheel-py3_12-cuda-aarch64-12_9-build: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -527,15 +766,25 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu126 GPU_ARCH_VERSION: "12.6-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9-aarch64 + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" +<<<<<<< HEAD build_name: manywheel-py3_12-cuda-aarch64-12_6 build_environment: linux-aarch64-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' @@ -543,16 +792,30 @@ jobs: secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cuda-aarch64-12_6-upload: # Uploading +======= + build_name: manywheel-py3_12-cuda-aarch64-12_9 + build_environment: linux-aarch64-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' + timeout-minutes: 420 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_12-cuda-aarch64-12_9-upload: # Uploading +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read +<<<<<<< HEAD needs: manywheel-py3_12-cuda-aarch64-12_6-build +======= + needs: manywheel-py3_12-cuda-aarch64-12_9-build +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu126 GPU_ARCH_VERSION: "12.6-aarch64" GPU_ARCH_TYPE: cuda-aarch64 @@ -652,6 +915,16 @@ jobs: DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda-aarch64-13_0 +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9-aarch64 + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + use_split_build: False + DESIRED_PYTHON: "3.12" + build_name: manywheel-py3_12-cuda-aarch64-12_9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -669,6 +942,10 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral @@ -692,6 +969,10 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -715,13 +996,21 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cpu-aarch64 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml +<<<<<<< HEAD manywheel-py3_13-cuda-aarch64-12_6-build: +======= + manywheel-py3_13-cuda-aarch64-12_9-build: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -730,15 +1019,25 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu126 GPU_ARCH_VERSION: "12.6-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9-aarch64 + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" +<<<<<<< HEAD build_name: manywheel-py3_13-cuda-aarch64-12_6 build_environment: linux-aarch64-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' @@ -746,16 +1045,30 @@ jobs: secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13-cuda-aarch64-12_6-upload: # Uploading +======= + build_name: manywheel-py3_13-cuda-aarch64-12_9 + build_environment: linux-aarch64-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' + timeout-minutes: 420 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13-cuda-aarch64-12_9-upload: # Uploading +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read +<<<<<<< HEAD needs: manywheel-py3_13-cuda-aarch64-12_6-build +======= + needs: manywheel-py3_13-cuda-aarch64-12_9-build +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu126 GPU_ARCH_VERSION: "12.6-aarch64" GPU_ARCH_TYPE: cuda-aarch64 @@ -855,6 +1168,16 @@ jobs: DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda-aarch64-13_0 +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9-aarch64 + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + use_split_build: False + DESIRED_PYTHON: "3.13" + build_name: manywheel-py3_13-cuda-aarch64-12_9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -872,6 +1195,10 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral @@ -895,6 +1222,10 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -918,13 +1249,21 @@ jobs: GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: manylinux2_28_aarch64-builder DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-cpu-aarch64 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml +<<<<<<< HEAD manywheel-py3_13t-cuda-aarch64-12_6-build: +======= + manywheel-py3_13t-cuda-aarch64-12_9-build: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -933,15 +1272,25 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu126 GPU_ARCH_VERSION: "12.6-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9-aarch64 + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" +<<<<<<< HEAD build_name: manywheel-py3_13t-cuda-aarch64-12_6 build_environment: linux-aarch64-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' @@ -949,16 +1298,30 @@ jobs: secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13t-cuda-aarch64-12_6-upload: # Uploading +======= + build_name: manywheel-py3_13t-cuda-aarch64-12_9 + build_environment: linux-aarch64-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' + timeout-minutes: 420 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13t-cuda-aarch64-12_9-upload: # Uploading +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read +<<<<<<< HEAD needs: manywheel-py3_13t-cuda-aarch64-12_6-build +======= + needs: manywheel-py3_13t-cuda-aarch64-12_9-build +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu126 GPU_ARCH_VERSION: "12.6-aarch64" GPU_ARCH_TYPE: cuda-aarch64 @@ -1464,6 +1827,16 @@ jobs: DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.14t" build_name: manywheel-py3_14t-cuda-aarch64-13_0 +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9-aarch64 + GPU_ARCH_TYPE: cuda-aarch64 + DOCKER_IMAGE: manylinuxaarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + use_split_build: False + DESIRED_PYTHON: "3.13t" + build_name: manywheel-py3_13t-cuda-aarch64-12_9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml diff --git a/.github/workflows/generated-linux-binary-libtorch-nightly.yml b/.github/workflows/generated-linux-binary-libtorch-nightly.yml index bc671ae80ae2a..4b36b93bbcba5 100644 --- a/.github/workflows/generated-linux-binary-libtorch-nightly.yml +++ b/.github/workflows/generated-linux-binary-libtorch-nightly.yml @@ -41,7 +41,11 @@ jobs: get-label-type: if: github.repository_owner == 'pytorch' name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -122,7 +126,11 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" +======= + GPU_ARCH_VERSION: 12.6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda DOCKER_IMAGE: libtorch-cxx11-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 @@ -145,7 +153,11 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" +======= + GPU_ARCH_VERSION: 12.6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda DOCKER_IMAGE: libtorch-cxx11-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 @@ -154,7 +166,11 @@ jobs: build_name: libtorch-cuda12_6-shared-with-deps-release build_environment: linux-binary-libtorch runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +<<<<<<< HEAD runs_on: linux.4xlarge.nvidia.gpu # 12.6 build can use maxwell (sm_50) runner +======= + runs_on: linux.4xlarge.nvidia.gpu # for other cuda versions, we use 4xlarge runner +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} libtorch-cuda12_6-shared-with-deps-release-upload: # Uploading @@ -169,7 +185,11 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" +======= + GPU_ARCH_VERSION: 12.6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda DOCKER_IMAGE: libtorch-cxx11-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 @@ -190,7 +210,11 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" +======= + GPU_ARCH_VERSION: 12.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda DOCKER_IMAGE: libtorch-cxx11-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 @@ -213,7 +237,11 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" +======= + GPU_ARCH_VERSION: 12.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda DOCKER_IMAGE: libtorch-cxx11-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 @@ -222,7 +250,11 @@ jobs: build_name: libtorch-cuda12_8-shared-with-deps-release build_environment: linux-binary-libtorch runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +<<<<<<< HEAD runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8+ builds need sm_70+ runner +======= + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} libtorch-cuda12_8-shared-with-deps-release-upload: # Uploading @@ -237,7 +269,11 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" +======= + GPU_ARCH_VERSION: 12.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda DOCKER_IMAGE: libtorch-cxx11-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 @@ -248,7 +284,11 @@ jobs: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml +<<<<<<< HEAD libtorch-cuda13_0-shared-with-deps-release-build: +======= + libtorch-cuda12_9-shared-with-deps-release-build: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -257,6 +297,7 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda @@ -273,6 +314,24 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: - libtorch-cuda13_0-shared-with-deps-release-build +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: libtorch-cxx11-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + LIBTORCH_CONFIG: release + LIBTORCH_VARIANT: shared-with-deps + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: libtorch-cuda12_9-shared-with-deps-release + build_environment: linux-binary-libtorch + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + libtorch-cuda12_9-shared-with-deps-release-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - libtorch-cuda12_9-shared-with-deps-release-build +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -280,6 +339,7 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda @@ -294,16 +354,37 @@ jobs: secrets: github-token: ${{ secrets.GITHUB_TOKEN }} libtorch-cuda13_0-shared-with-deps-release-upload: # Uploading +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: libtorch-cxx11-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + LIBTORCH_CONFIG: release + LIBTORCH_VARIANT: shared-with-deps + build_name: libtorch-cuda12_9-shared-with-deps-release + build_environment: linux-binary-libtorch + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + libtorch-cuda12_9-shared-with-deps-release-upload: # Uploading +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read +<<<<<<< HEAD needs: libtorch-cuda13_0-shared-with-deps-release-test +======= + needs: libtorch-cuda12_9-shared-with-deps-release-test +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda @@ -312,6 +393,16 @@ jobs: LIBTORCH_CONFIG: release LIBTORCH_VARIANT: shared-with-deps build_name: libtorch-cuda13_0-shared-with-deps-release +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: libtorch-cxx11-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + LIBTORCH_CONFIG: release + LIBTORCH_VARIANT: shared-with-deps + build_name: libtorch-cuda12_9-shared-with-deps-release +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -326,14 +417,21 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.3 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.3" +======= + GPU_ARCH_VERSION: 6.3 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: rocm DOCKER_IMAGE: libtorch-cxx11-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 LIBTORCH_CONFIG: release LIBTORCH_VARIANT: shared-with-deps runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +<<<<<<< HEAD timeout-minutes: 300 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) build_name: libtorch-rocm6_3-shared-with-deps-release build_environment: linux-binary-libtorch secrets: @@ -351,7 +449,11 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.3 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.3" +======= + GPU_ARCH_VERSION: 6.3 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: libtorch-cxx11-builder @@ -390,7 +492,11 @@ jobs: role-duration-seconds: 18000 - name: Calculate docker image id: calculate-docker-image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: libtorch-cxx11-builder @@ -398,7 +504,11 @@ jobs: docker-build-dir: .ci/docker working-directory: pytorch - name: Pull Docker image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Test Pytorch binary @@ -419,7 +529,11 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.3 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.3" +======= + GPU_ARCH_VERSION: 6.3 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: rocm DOCKER_IMAGE: libtorch-cxx11-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 @@ -440,14 +554,21 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.4 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.4" +======= + GPU_ARCH_VERSION: 6.4 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: rocm DOCKER_IMAGE: libtorch-cxx11-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 LIBTORCH_CONFIG: release LIBTORCH_VARIANT: shared-with-deps runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +<<<<<<< HEAD timeout-minutes: 300 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) build_name: libtorch-rocm6_4-shared-with-deps-release build_environment: linux-binary-libtorch secrets: @@ -465,7 +586,11 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.4 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.4" +======= + GPU_ARCH_VERSION: 6.4 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: libtorch-cxx11-builder @@ -504,7 +629,11 @@ jobs: role-duration-seconds: 18000 - name: Calculate docker image id: calculate-docker-image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: libtorch-cxx11-builder @@ -512,7 +641,11 @@ jobs: docker-build-dir: .ci/docker working-directory: pytorch - name: Pull Docker image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Test Pytorch binary @@ -533,7 +666,11 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.4 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.4" +======= + GPU_ARCH_VERSION: 6.4 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: rocm DOCKER_IMAGE: libtorch-cxx11-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 diff --git a/.github/workflows/generated-linux-binary-libtorch-release-main.yml b/.github/workflows/generated-linux-binary-libtorch-release-main.yml index 9d55fc6e50ab7..ddf9f894fba7d 100644 --- a/.github/workflows/generated-linux-binary-libtorch-release-main.yml +++ b/.github/workflows/generated-linux-binary-libtorch-release-main.yml @@ -36,7 +36,11 @@ jobs: get-label-type: if: github.repository_owner == 'pytorch' name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} diff --git a/.github/workflows/generated-linux-binary-manywheel-main.yml b/.github/workflows/generated-linux-binary-manywheel-main.yml index 85b91378b253a..c38939b625786 100644 --- a/.github/workflows/generated-linux-binary-manywheel-main.yml +++ b/.github/workflows/generated-linux-binary-manywheel-main.yml @@ -36,13 +36,68 @@ jobs: get-label-type: if: github.repository_owner == 'pytorch' name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} +<<<<<<< HEAD manywheel-py3_12-cuda12_8-build: +======= + manywheel-py3_9-cuda12_6-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + use_split_build: False + DESIRED_PYTHON: "3.9" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_9-cuda12_6 + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cuda12_6-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_9-cuda12_6-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + use_split_build: False + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cuda12_6 + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu # for other cuda versions, we use 4xlarge runner + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_9-cuda12_8-build: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -52,6 +107,7 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder @@ -67,6 +123,24 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: - manywheel-py3_12-cuda12_8-build +======= + GPU_ARCH_VERSION: 12.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + use_split_build: False + DESIRED_PYTHON: "3.9" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_9-cuda12_8 + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cuda12_8-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_9-cuda12_8-build +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -75,6 +149,7 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder @@ -86,3 +161,155 @@ jobs: runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8+ builds need sm_70+ runner secrets: github-token: ${{ secrets.GITHUB_TOKEN }} +======= + GPU_ARCH_VERSION: 12.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + use_split_build: False + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cuda12_8 + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_9-cuda12_9-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + use_split_build: False + DESIRED_PYTHON: "3.9" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_9-cuda12_9 + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cuda12_9-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_9-cuda12_9-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + use_split_build: False + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cuda12_9 + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_9-rocm6_4-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: rocm6.4 + GPU_ARCH_VERSION: 6.4 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + use_split_build: False + DESIRED_PYTHON: "3.9" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_9-rocm6_4 + build_environment: linux-binary-manywheel + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-rocm6_4-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_9-rocm6_4-build + - get-label-type + runs-on: linux.rocm.gpu.mi250 + timeout-minutes: 240 + env: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: rocm6.4 + GPU_ARCH_VERSION: 6.4 + GPU_ARCH_TYPE: rocm + SKIP_ALL_TESTS: 1 + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + use_split_build: False + DESIRED_PYTHON: "3.9" + steps: + - name: Setup ROCm + uses: ./.github/actions/setup-rocm + - uses: actions/download-artifact@v4.1.7 + name: Download Build Artifacts + with: + name: manywheel-py3_9-rocm6_4 + path: "${{ runner.temp }}/artifacts/" + - name: Checkout PyTorch + uses: actions/checkout@v4 + with: + submodules: recursive + path: pytorch + show-progress: false + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: ROCm set GPU_FLAG + run: | + echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}" + - name: configure aws credentials + id: aws_creds + if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }} + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only + aws-region: us-east-1 + role-duration-seconds: 18000 + - name: Calculate docker image + id: calculate-docker-image + uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.8 + with: + docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} + docker-image-name: manylinux2_28-builder + custom-tag-prefix: rocm6.4 + docker-build-dir: .ci/docker + working-directory: pytorch + - name: Pull Docker image + uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.8 + with: + docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} + - name: Test Pytorch binary + uses: ./pytorch/.github/actions/test-pytorch-binary + env: + DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} + - name: Teardown ROCm + uses: ./.github/actions/teardown-rocm +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.github/workflows/generated-linux-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-binary-manywheel-nightly.yml index 5f9eaab976a62..2ca9d976629ce 100644 --- a/.github/workflows/generated-linux-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-binary-manywheel-nightly.yml @@ -41,12 +41,629 @@ jobs: get-label-type: if: github.repository_owner == 'pytorch' name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} +<<<<<<< HEAD +======= + manywheel-py3_9-cpu-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cpu + use_split_build: False + DESIRED_PYTHON: "3.9" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_9-cpu + build_environment: linux-binary-manywheel + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cpu-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_9-cpu-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cpu + use_split_build: False + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cpu + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cpu-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_9-cpu-test + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cpu + use_split_build: False + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_9-cuda12_6-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + use_split_build: False + DESIRED_PYTHON: "3.9" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_9-cuda12_6 + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cuda12_6-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_9-cuda12_6-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + use_split_build: False + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cuda12_6 + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu # for other cuda versions, we use 4xlarge runner + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cuda12_6-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_9-cuda12_6-test + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + use_split_build: False + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cuda12_6 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_9-cuda12_8-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + use_split_build: False + DESIRED_PYTHON: "3.9" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_9-cuda12_8 + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cuda12_8-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_9-cuda12_8-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + use_split_build: False + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cuda12_8 + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cuda12_8-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_9-cuda12_8-test + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + use_split_build: False + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cuda12_8 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_9-cuda12_9-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + use_split_build: False + DESIRED_PYTHON: "3.9" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_9-cuda12_9 + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cuda12_9-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_9-cuda12_9-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + use_split_build: False + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cuda12_9 + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cuda12_9-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_9-cuda12_9-test + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + use_split_build: False + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cuda12_9 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_9-rocm6_3-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: rocm6.3 + GPU_ARCH_VERSION: 6.3 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.3 + use_split_build: False + DESIRED_PYTHON: "3.9" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_9-rocm6_3 + build_environment: linux-binary-manywheel + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-rocm6_3-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_9-rocm6_3-build + - get-label-type + runs-on: linux.rocm.gpu.mi250 + timeout-minutes: 240 + env: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: rocm6.3 + GPU_ARCH_VERSION: 6.3 + GPU_ARCH_TYPE: rocm + SKIP_ALL_TESTS: 1 + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.3 + use_split_build: False + DESIRED_PYTHON: "3.9" + steps: + - name: Setup ROCm + uses: ./.github/actions/setup-rocm + - uses: actions/download-artifact@v4.1.7 + name: Download Build Artifacts + with: + name: manywheel-py3_9-rocm6_3 + path: "${{ runner.temp }}/artifacts/" + - name: Checkout PyTorch + uses: actions/checkout@v4 + with: + submodules: recursive + path: pytorch + show-progress: false + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: ROCm set GPU_FLAG + run: | + echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}" + - name: configure aws credentials + id: aws_creds + if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }} + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only + aws-region: us-east-1 + role-duration-seconds: 18000 + - name: Calculate docker image + id: calculate-docker-image + uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.8 + with: + docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} + docker-image-name: manylinux2_28-builder + custom-tag-prefix: rocm6.3 + docker-build-dir: .ci/docker + working-directory: pytorch + - name: Pull Docker image + uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.8 + with: + docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} + - name: Test Pytorch binary + uses: ./pytorch/.github/actions/test-pytorch-binary + env: + DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} + - name: Teardown ROCm + uses: ./.github/actions/teardown-rocm + manywheel-py3_9-rocm6_3-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_9-rocm6_3-test + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: rocm6.3 + GPU_ARCH_VERSION: 6.3 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.3 + use_split_build: False + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-rocm6_3 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_9-rocm6_4-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: rocm6.4 + GPU_ARCH_VERSION: 6.4 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + use_split_build: False + DESIRED_PYTHON: "3.9" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_9-rocm6_4 + build_environment: linux-binary-manywheel + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-rocm6_4-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_9-rocm6_4-build + - get-label-type + runs-on: linux.rocm.gpu.mi250 + timeout-minutes: 240 + env: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: rocm6.4 + GPU_ARCH_VERSION: 6.4 + GPU_ARCH_TYPE: rocm + SKIP_ALL_TESTS: 1 + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + use_split_build: False + DESIRED_PYTHON: "3.9" + steps: + - name: Setup ROCm + uses: ./.github/actions/setup-rocm + - uses: actions/download-artifact@v4.1.7 + name: Download Build Artifacts + with: + name: manywheel-py3_9-rocm6_4 + path: "${{ runner.temp }}/artifacts/" + - name: Checkout PyTorch + uses: actions/checkout@v4 + with: + submodules: recursive + path: pytorch + show-progress: false + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: ROCm set GPU_FLAG + run: | + echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}" + - name: configure aws credentials + id: aws_creds + if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }} + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only + aws-region: us-east-1 + role-duration-seconds: 18000 + - name: Calculate docker image + id: calculate-docker-image + uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.8 + with: + docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} + docker-image-name: manylinux2_28-builder + custom-tag-prefix: rocm6.4 + docker-build-dir: .ci/docker + working-directory: pytorch + - name: Pull Docker image + uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.8 + with: + docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} + - name: Test Pytorch binary + uses: ./pytorch/.github/actions/test-pytorch-binary + env: + DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} + - name: Teardown ROCm + uses: ./.github/actions/teardown-rocm + manywheel-py3_9-rocm6_4-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_9-rocm6_4-test + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: rocm6.4 + GPU_ARCH_VERSION: 6.4 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + use_split_build: False + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-rocm6_4 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_9-xpu-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: xpu + use_split_build: False + DESIRED_PYTHON: "3.9" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_9-xpu + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.3 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-xpu-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_9-xpu-build + - get-label-type + runs-on: linux.idc.xpu + timeout-minutes: 240 + env: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu + SKIP_ALL_TESTS: 1 + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: xpu + use_split_build: False + DESIRED_PYTHON: "3.9" + permissions: + id-token: write + contents: read + steps: + - name: Setup XPU + uses: ./.github/actions/setup-xpu + - name: configure aws credentials + id: aws_creds + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only + aws-region: us-east-1 + - name: Login to Amazon ECR + id: login-ecr + uses: aws-actions/amazon-ecr-login@v2 + - uses: actions/download-artifact@v4.1.7 + name: Download Build Artifacts + with: + name: manywheel-py3_9-xpu + path: "${{ runner.temp }}/artifacts/" + - name: Checkout PyTorch + uses: actions/checkout@v4 + with: + submodules: recursive + path: pytorch + show-progress: false + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: Calculate docker image + id: calculate-docker-image + uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.8 + with: + docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} + docker-image-name: manylinux2_28-builder + custom-tag-prefix: xpu + docker-build-dir: .ci/docker + working-directory: pytorch + - name: Pull Docker image + uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.8 + with: + docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} + - name: Test Pytorch binary + uses: ./pytorch/.github/actions/test-pytorch-binary + env: + DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} + - name: Teardown XPU + uses: ./.github/actions/teardown-xpu + manywheel-py3_9-xpu-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_9-xpu-test + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: xpu + use_split_build: False + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-xpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + uses: ./.github/workflows/_binary-upload.yml + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) manywheel-py3_10-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml @@ -60,6 +677,10 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cpu @@ -81,6 +702,10 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu build_environment: linux-binary-manywheel @@ -103,6 +728,10 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu secrets: @@ -119,15 +748,27 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 +======= + GPU_ARCH_VERSION: 12.6 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cuda12_6 build_environment: linux-binary-manywheel +<<<<<<< HEAD PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' +======= + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' and platform_machine == 'x86_64' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda12_6-test: # Testing @@ -142,15 +783,27 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 +======= + GPU_ARCH_VERSION: 12.6 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_6 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +<<<<<<< HEAD runs_on: linux.4xlarge.nvidia.gpu # 12.6 build can use maxwell (sm_50) runner +======= + runs_on: linux.4xlarge.nvidia.gpu # for other cuda versions, we use 4xlarge runner +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda12_6-upload: # Uploading @@ -165,10 +818,18 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 +======= + GPU_ARCH_VERSION: 12.6 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_6 secrets: @@ -185,15 +846,27 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 +======= + GPU_ARCH_VERSION: 12.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cuda12_8 build_environment: linux-binary-manywheel +<<<<<<< HEAD PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' +======= + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda12_8-test: # Testing @@ -208,15 +881,27 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 +======= + GPU_ARCH_VERSION: 12.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_8 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +<<<<<<< HEAD runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8+ builds need sm_70+ runner +======= + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda12_8-upload: # Uploading @@ -231,17 +916,29 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 +======= + GPU_ARCH_VERSION: 12.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_8 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml +<<<<<<< HEAD manywheel-py3_10-cuda13_0-build: +======= + manywheel-py3_10-cuda12_9-build: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -250,6 +947,7 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda @@ -266,6 +964,25 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: - manywheel-py3_10-cuda13_0-build +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + use_split_build: False + DESIRED_PYTHON: "3.10" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_10-cuda12_9 + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_10-cuda12_9-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_10-cuda12_9-build +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -273,6 +990,7 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda @@ -286,16 +1004,37 @@ jobs: secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda13_0-upload: # Uploading +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + use_split_build: False + DESIRED_PYTHON: "3.10" + build_name: manywheel-py3_10-cuda12_9 + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_10-cuda12_9-upload: # Uploading +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read +<<<<<<< HEAD needs: manywheel-py3_10-cuda13_0-test +======= + needs: manywheel-py3_10-cuda12_9-test +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda @@ -303,6 +1042,16 @@ jobs: DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda13_0 +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + use_split_build: False + DESIRED_PYTHON: "3.10" + build_name: manywheel-py3_10-cuda12_9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -317,6 +1066,7 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.3 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.3" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder @@ -324,6 +1074,15 @@ jobs: DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" timeout-minutes: 300 +======= + GPU_ARCH_VERSION: 6.3 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.3 + use_split_build: False + DESIRED_PYTHON: "3.10" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) build_name: manywheel-py3_10-rocm6_3 build_environment: linux-binary-manywheel secrets: @@ -341,11 +1100,19 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.3 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.3" +======= + GPU_ARCH_VERSION: 6.3 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.10" steps: - name: Setup ROCm @@ -379,7 +1146,11 @@ jobs: role-duration-seconds: 18000 - name: Calculate docker image id: calculate-docker-image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: manylinux2_28-builder @@ -387,7 +1158,11 @@ jobs: docker-build-dir: .ci/docker working-directory: pytorch - name: Pull Docker image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Test Pytorch binary @@ -408,10 +1183,18 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.3 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.3" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 +======= + GPU_ARCH_VERSION: 6.3 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.3 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-rocm6_3 secrets: @@ -428,6 +1211,7 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.4 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.4" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder @@ -435,6 +1219,15 @@ jobs: DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" timeout-minutes: 300 +======= + GPU_ARCH_VERSION: 6.4 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + use_split_build: False + DESIRED_PYTHON: "3.10" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) build_name: manywheel-py3_10-rocm6_4 build_environment: linux-binary-manywheel secrets: @@ -452,11 +1245,19 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.4 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.4" +======= + GPU_ARCH_VERSION: 6.4 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.10" steps: - name: Setup ROCm @@ -490,7 +1291,11 @@ jobs: role-duration-seconds: 18000 - name: Calculate docker image id: calculate-docker-image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: manylinux2_28-builder @@ -498,7 +1303,11 @@ jobs: docker-build-dir: .ci/docker working-directory: pytorch - name: Pull Docker image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Test Pytorch binary @@ -519,10 +1328,18 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.4 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.4" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 +======= + GPU_ARCH_VERSION: 6.4 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-rocm6_4 secrets: @@ -542,11 +1359,19 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-xpu build_environment: linux-binary-manywheel +<<<<<<< HEAD PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.2.1 | intel-cmplr-lib-ur==2025.2.1 | intel-cmplr-lic-rt==2025.2.1 | intel-sycl-rt==2025.2.1 | oneccl-devel==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.2.0 | onemkl-sycl-dft==2025.2.0 | onemkl-sycl-lapack==2025.2.0 | onemkl-sycl-rng==2025.2.0 | onemkl-sycl-sparse==2025.2.0 | dpcpp-cpp-rt==2025.2.1 | intel-opencl-rt==2025.2.1 | mkl==2025.2.0 | intel-openmp==2025.2.1 | tbb==2022.2.0 | tcmlib==1.4.0 | umf==0.11.0 | intel-pti==0.13.1 +======= + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.3 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-xpu-test: # Testing @@ -566,13 +1391,21 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.10" permissions: id-token: write contents: read steps: - name: Setup XPU +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/setup-xpu@release/2.9 +======= + uses: ./.github/actions/setup-xpu +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: configure aws credentials id: aws_creds uses: aws-actions/configure-aws-credentials@v4 @@ -600,7 +1433,11 @@ jobs: working-directory: pytorch - name: Calculate docker image id: calculate-docker-image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: manylinux2_28-builder @@ -608,7 +1445,11 @@ jobs: docker-build-dir: .ci/docker working-directory: pytorch - name: Pull Docker image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Test Pytorch binary @@ -632,6 +1473,10 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-xpu secrets: @@ -651,6 +1496,10 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cpu @@ -672,6 +1521,10 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu build_environment: linux-binary-manywheel @@ -694,6 +1547,10 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu secrets: @@ -710,15 +1567,27 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 +======= + GPU_ARCH_VERSION: 12.6 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cuda12_6 build_environment: linux-binary-manywheel +<<<<<<< HEAD PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' +======= + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' and platform_machine == 'x86_64' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cuda12_6-test: # Testing @@ -733,15 +1602,27 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 +======= + GPU_ARCH_VERSION: 12.6 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_6 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +<<<<<<< HEAD runs_on: linux.4xlarge.nvidia.gpu # 12.6 build can use maxwell (sm_50) runner +======= + runs_on: linux.4xlarge.nvidia.gpu # for other cuda versions, we use 4xlarge runner +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cuda12_6-upload: # Uploading @@ -756,10 +1637,18 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 +======= + GPU_ARCH_VERSION: 12.6 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_6 secrets: @@ -776,15 +1665,27 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 +======= + GPU_ARCH_VERSION: 12.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cuda12_8 build_environment: linux-binary-manywheel +<<<<<<< HEAD PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' +======= + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cuda12_8-test: # Testing @@ -799,15 +1700,27 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 +======= + GPU_ARCH_VERSION: 12.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_8 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +<<<<<<< HEAD runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8+ builds need sm_70+ runner +======= + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cuda12_8-upload: # Uploading @@ -822,17 +1735,29 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 +======= + GPU_ARCH_VERSION: 12.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_8 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml +<<<<<<< HEAD manywheel-py3_11-cuda13_0-build: +======= + manywheel-py3_11-cuda12_8-full-build: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -841,6 +1766,7 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda @@ -857,6 +1783,24 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: - manywheel-py3_11-cuda13_0-build +======= + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + use_split_build: False + DESIRED_PYTHON: "3.11" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_11-cuda12_8-full + build_environment: linux-binary-manywheel + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-cuda12_8-full-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_11-cuda12_8-full-build +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -864,6 +1808,7 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda @@ -877,16 +1822,37 @@ jobs: secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cuda13_0-upload: # Uploading +======= + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + use_split_build: False + DESIRED_PYTHON: "3.11" + build_name: manywheel-py3_11-cuda12_8-full + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-cuda12_8-full-upload: # Uploading +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read +<<<<<<< HEAD needs: manywheel-py3_11-cuda13_0-test +======= + needs: manywheel-py3_11-cuda12_8-full-test +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda @@ -894,6 +1860,85 @@ jobs: DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda13_0 +======= + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + use_split_build: False + DESIRED_PYTHON: "3.11" + build_name: manywheel-py3_11-cuda12_8-full + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_11-cuda12_9-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + use_split_build: False + DESIRED_PYTHON: "3.11" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_11-cuda12_9 + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-cuda12_9-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_11-cuda12_9-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + use_split_build: False + DESIRED_PYTHON: "3.11" + build_name: manywheel-py3_11-cuda12_9 + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-cuda12_9-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_11-cuda12_9-test + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + use_split_build: False + DESIRED_PYTHON: "3.11" + build_name: manywheel-py3_11-cuda12_9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -908,6 +1953,7 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.3 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.3" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder @@ -915,6 +1961,15 @@ jobs: DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" timeout-minutes: 300 +======= + GPU_ARCH_VERSION: 6.3 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.3 + use_split_build: False + DESIRED_PYTHON: "3.11" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) build_name: manywheel-py3_11-rocm6_3 build_environment: linux-binary-manywheel secrets: @@ -932,11 +1987,19 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.3 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.3" +======= + GPU_ARCH_VERSION: 6.3 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.11" steps: - name: Setup ROCm @@ -970,7 +2033,11 @@ jobs: role-duration-seconds: 18000 - name: Calculate docker image id: calculate-docker-image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: manylinux2_28-builder @@ -978,7 +2045,11 @@ jobs: docker-build-dir: .ci/docker working-directory: pytorch - name: Pull Docker image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Test Pytorch binary @@ -999,10 +2070,18 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.3 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.3" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 +======= + GPU_ARCH_VERSION: 6.3 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.3 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-rocm6_3 secrets: @@ -1019,6 +2098,7 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.4 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.4" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder @@ -1026,6 +2106,15 @@ jobs: DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" timeout-minutes: 300 +======= + GPU_ARCH_VERSION: 6.4 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + use_split_build: False + DESIRED_PYTHON: "3.11" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) build_name: manywheel-py3_11-rocm6_4 build_environment: linux-binary-manywheel secrets: @@ -1043,11 +2132,19 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.4 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.4" +======= + GPU_ARCH_VERSION: 6.4 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.11" steps: - name: Setup ROCm @@ -1081,7 +2178,11 @@ jobs: role-duration-seconds: 18000 - name: Calculate docker image id: calculate-docker-image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: manylinux2_28-builder @@ -1089,7 +2190,11 @@ jobs: docker-build-dir: .ci/docker working-directory: pytorch - name: Pull Docker image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Test Pytorch binary @@ -1110,10 +2215,18 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.4 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.4" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 +======= + GPU_ARCH_VERSION: 6.4 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-rocm6_4 secrets: @@ -1133,11 +2246,19 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-xpu build_environment: linux-binary-manywheel +<<<<<<< HEAD PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.2.1 | intel-cmplr-lib-ur==2025.2.1 | intel-cmplr-lic-rt==2025.2.1 | intel-sycl-rt==2025.2.1 | oneccl-devel==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.2.0 | onemkl-sycl-dft==2025.2.0 | onemkl-sycl-lapack==2025.2.0 | onemkl-sycl-rng==2025.2.0 | onemkl-sycl-sparse==2025.2.0 | dpcpp-cpp-rt==2025.2.1 | intel-opencl-rt==2025.2.1 | mkl==2025.2.0 | intel-openmp==2025.2.1 | tbb==2022.2.0 | tcmlib==1.4.0 | umf==0.11.0 | intel-pti==0.13.1 +======= + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.3 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-xpu-test: # Testing @@ -1157,13 +2278,21 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.11" permissions: id-token: write contents: read steps: - name: Setup XPU +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/setup-xpu@release/2.9 +======= + uses: ./.github/actions/setup-xpu +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: configure aws credentials id: aws_creds uses: aws-actions/configure-aws-credentials@v4 @@ -1191,7 +2320,11 @@ jobs: working-directory: pytorch - name: Calculate docker image id: calculate-docker-image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: manylinux2_28-builder @@ -1199,7 +2332,11 @@ jobs: docker-build-dir: .ci/docker working-directory: pytorch - name: Pull Docker image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Test Pytorch binary @@ -1223,6 +2360,10 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-xpu secrets: @@ -1242,6 +2383,10 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cpu @@ -1263,6 +2408,10 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu build_environment: linux-binary-manywheel @@ -1285,6 +2434,10 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu secrets: @@ -1301,15 +2454,27 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 +======= + GPU_ARCH_VERSION: 12.6 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cuda12_6 build_environment: linux-binary-manywheel +<<<<<<< HEAD PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' +======= + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' and platform_machine == 'x86_64' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cuda12_6-test: # Testing @@ -1324,15 +2489,27 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 +======= + GPU_ARCH_VERSION: 12.6 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_6 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +<<<<<<< HEAD runs_on: linux.4xlarge.nvidia.gpu # 12.6 build can use maxwell (sm_50) runner +======= + runs_on: linux.4xlarge.nvidia.gpu # for other cuda versions, we use 4xlarge runner +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cuda12_6-upload: # Uploading @@ -1347,10 +2524,18 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 +======= + GPU_ARCH_VERSION: 12.6 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_6 secrets: @@ -1367,15 +2552,27 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 +======= + GPU_ARCH_VERSION: 12.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cuda12_8 build_environment: linux-binary-manywheel +<<<<<<< HEAD PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' +======= + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cuda12_8-test: # Testing @@ -1390,15 +2587,27 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 +======= + GPU_ARCH_VERSION: 12.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_8 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +<<<<<<< HEAD runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8+ builds need sm_70+ runner +======= + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cuda12_8-upload: # Uploading @@ -1413,17 +2622,29 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 +======= + GPU_ARCH_VERSION: 12.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_8 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml +<<<<<<< HEAD manywheel-py3_12-cuda13_0-build: +======= + manywheel-py3_12-cuda12_9-build: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -1432,6 +2653,7 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda @@ -1448,6 +2670,25 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: - manywheel-py3_12-cuda13_0-build +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + use_split_build: False + DESIRED_PYTHON: "3.12" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_12-cuda12_9 + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_12-cuda12_9-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_12-cuda12_9-build +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -1455,6 +2696,7 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda @@ -1468,16 +2710,37 @@ jobs: secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cuda13_0-upload: # Uploading +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + use_split_build: False + DESIRED_PYTHON: "3.12" + build_name: manywheel-py3_12-cuda12_9 + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_12-cuda12_9-upload: # Uploading +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read +<<<<<<< HEAD needs: manywheel-py3_12-cuda13_0-test +======= + needs: manywheel-py3_12-cuda12_9-test +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda @@ -1485,6 +2748,16 @@ jobs: DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda13_0 +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + use_split_build: False + DESIRED_PYTHON: "3.12" + build_name: manywheel-py3_12-cuda12_9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -1499,6 +2772,7 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.3 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.3" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder @@ -1506,6 +2780,15 @@ jobs: DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" timeout-minutes: 300 +======= + GPU_ARCH_VERSION: 6.3 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.3 + use_split_build: False + DESIRED_PYTHON: "3.12" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) build_name: manywheel-py3_12-rocm6_3 build_environment: linux-binary-manywheel secrets: @@ -1523,11 +2806,19 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.3 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.3" +======= + GPU_ARCH_VERSION: 6.3 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.12" steps: - name: Setup ROCm @@ -1561,7 +2852,11 @@ jobs: role-duration-seconds: 18000 - name: Calculate docker image id: calculate-docker-image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: manylinux2_28-builder @@ -1569,7 +2864,11 @@ jobs: docker-build-dir: .ci/docker working-directory: pytorch - name: Pull Docker image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Test Pytorch binary @@ -1590,10 +2889,18 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.3 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.3" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 +======= + GPU_ARCH_VERSION: 6.3 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.3 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-rocm6_3 secrets: @@ -1610,6 +2917,7 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.4 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.4" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder @@ -1617,6 +2925,15 @@ jobs: DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" timeout-minutes: 300 +======= + GPU_ARCH_VERSION: 6.4 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + use_split_build: False + DESIRED_PYTHON: "3.12" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) build_name: manywheel-py3_12-rocm6_4 build_environment: linux-binary-manywheel secrets: @@ -1634,11 +2951,19 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.4 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.4" +======= + GPU_ARCH_VERSION: 6.4 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.12" steps: - name: Setup ROCm @@ -1672,7 +2997,11 @@ jobs: role-duration-seconds: 18000 - name: Calculate docker image id: calculate-docker-image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: manylinux2_28-builder @@ -1680,7 +3009,11 @@ jobs: docker-build-dir: .ci/docker working-directory: pytorch - name: Pull Docker image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Test Pytorch binary @@ -1701,10 +3034,18 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.4 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.4" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 +======= + GPU_ARCH_VERSION: 6.4 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-rocm6_4 secrets: @@ -1724,11 +3065,19 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-xpu build_environment: linux-binary-manywheel +<<<<<<< HEAD PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.2.1 | intel-cmplr-lib-ur==2025.2.1 | intel-cmplr-lic-rt==2025.2.1 | intel-sycl-rt==2025.2.1 | oneccl-devel==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.2.0 | onemkl-sycl-dft==2025.2.0 | onemkl-sycl-lapack==2025.2.0 | onemkl-sycl-rng==2025.2.0 | onemkl-sycl-sparse==2025.2.0 | dpcpp-cpp-rt==2025.2.1 | intel-opencl-rt==2025.2.1 | mkl==2025.2.0 | intel-openmp==2025.2.1 | tbb==2022.2.0 | tcmlib==1.4.0 | umf==0.11.0 | intel-pti==0.13.1 +======= + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.3 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-xpu-test: # Testing @@ -1748,13 +3097,21 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.12" permissions: id-token: write contents: read steps: - name: Setup XPU +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/setup-xpu@release/2.9 +======= + uses: ./.github/actions/setup-xpu +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: configure aws credentials id: aws_creds uses: aws-actions/configure-aws-credentials@v4 @@ -1782,7 +3139,11 @@ jobs: working-directory: pytorch - name: Calculate docker image id: calculate-docker-image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: manylinux2_28-builder @@ -1790,7 +3151,11 @@ jobs: docker-build-dir: .ci/docker working-directory: pytorch - name: Pull Docker image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Test Pytorch binary @@ -1814,6 +3179,10 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-xpu secrets: @@ -1833,6 +3202,10 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cpu @@ -1854,6 +3227,10 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cpu build_environment: linux-binary-manywheel @@ -1876,6 +3253,10 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cpu secrets: @@ -1892,15 +3273,27 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 +======= + GPU_ARCH_VERSION: 12.6 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda12_6 build_environment: linux-binary-manywheel +<<<<<<< HEAD PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' +======= + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' and platform_machine == 'x86_64' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13-cuda12_6-test: # Testing @@ -1915,15 +3308,27 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 +======= + GPU_ARCH_VERSION: 12.6 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda12_6 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +<<<<<<< HEAD runs_on: linux.4xlarge.nvidia.gpu # 12.6 build can use maxwell (sm_50) runner +======= + runs_on: linux.4xlarge.nvidia.gpu # for other cuda versions, we use 4xlarge runner +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13-cuda12_6-upload: # Uploading @@ -1938,10 +3343,18 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 +======= + GPU_ARCH_VERSION: 12.6 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda12_6 secrets: @@ -1958,15 +3371,27 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 +======= + GPU_ARCH_VERSION: 12.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda12_8 build_environment: linux-binary-manywheel +<<<<<<< HEAD PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' +======= + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13-cuda12_8-test: # Testing @@ -1981,15 +3406,27 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 +======= + GPU_ARCH_VERSION: 12.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda12_8 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +<<<<<<< HEAD runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8+ builds need sm_70+ runner +======= + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13-cuda12_8-upload: # Uploading @@ -2004,17 +3441,29 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 +======= + GPU_ARCH_VERSION: 12.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda12_8 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml +<<<<<<< HEAD manywheel-py3_13-cuda13_0-build: +======= + manywheel-py3_13-cuda12_9-build: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -2023,6 +3472,7 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda @@ -2039,6 +3489,25 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: - manywheel-py3_13-cuda13_0-build +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + use_split_build: False + DESIRED_PYTHON: "3.13" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_13-cuda12_9 + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13-cuda12_9-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_13-cuda12_9-build +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -2046,6 +3515,7 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda @@ -2059,16 +3529,37 @@ jobs: secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13-cuda13_0-upload: # Uploading +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + use_split_build: False + DESIRED_PYTHON: "3.13" + build_name: manywheel-py3_13-cuda12_9 + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13-cuda12_9-upload: # Uploading +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read +<<<<<<< HEAD needs: manywheel-py3_13-cuda13_0-test +======= + needs: manywheel-py3_13-cuda12_9-test +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda @@ -2076,6 +3567,16 @@ jobs: DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda13_0 +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + use_split_build: False + DESIRED_PYTHON: "3.13" + build_name: manywheel-py3_13-cuda12_9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -2090,6 +3591,7 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.3 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.3" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder @@ -2097,6 +3599,15 @@ jobs: DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" timeout-minutes: 300 +======= + GPU_ARCH_VERSION: 6.3 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.3 + use_split_build: False + DESIRED_PYTHON: "3.13" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) build_name: manywheel-py3_13-rocm6_3 build_environment: linux-binary-manywheel secrets: @@ -2114,11 +3625,19 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.3 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.3" +======= + GPU_ARCH_VERSION: 6.3 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13" steps: - name: Setup ROCm @@ -2152,7 +3671,11 @@ jobs: role-duration-seconds: 18000 - name: Calculate docker image id: calculate-docker-image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: manylinux2_28-builder @@ -2160,7 +3683,11 @@ jobs: docker-build-dir: .ci/docker working-directory: pytorch - name: Pull Docker image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Test Pytorch binary @@ -2181,10 +3708,18 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.3 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.3" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 +======= + GPU_ARCH_VERSION: 6.3 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.3 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-rocm6_3 secrets: @@ -2201,6 +3736,7 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.4 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.4" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder @@ -2208,6 +3744,15 @@ jobs: DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" timeout-minutes: 300 +======= + GPU_ARCH_VERSION: 6.4 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + use_split_build: False + DESIRED_PYTHON: "3.13" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) build_name: manywheel-py3_13-rocm6_4 build_environment: linux-binary-manywheel secrets: @@ -2225,11 +3770,19 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.4 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.4" +======= + GPU_ARCH_VERSION: 6.4 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13" steps: - name: Setup ROCm @@ -2263,7 +3816,11 @@ jobs: role-duration-seconds: 18000 - name: Calculate docker image id: calculate-docker-image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: manylinux2_28-builder @@ -2271,7 +3828,11 @@ jobs: docker-build-dir: .ci/docker working-directory: pytorch - name: Pull Docker image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Test Pytorch binary @@ -2292,10 +3853,18 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.4 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.4" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 +======= + GPU_ARCH_VERSION: 6.4 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-rocm6_4 secrets: @@ -2315,11 +3884,19 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-xpu build_environment: linux-binary-manywheel +<<<<<<< HEAD PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.2.1 | intel-cmplr-lib-ur==2025.2.1 | intel-cmplr-lic-rt==2025.2.1 | intel-sycl-rt==2025.2.1 | oneccl-devel==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.2.0 | onemkl-sycl-dft==2025.2.0 | onemkl-sycl-lapack==2025.2.0 | onemkl-sycl-rng==2025.2.0 | onemkl-sycl-sparse==2025.2.0 | dpcpp-cpp-rt==2025.2.1 | intel-opencl-rt==2025.2.1 | mkl==2025.2.0 | intel-openmp==2025.2.1 | tbb==2022.2.0 | tcmlib==1.4.0 | umf==0.11.0 | intel-pti==0.13.1 +======= + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.3 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13-xpu-test: # Testing @@ -2339,13 +3916,21 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13" permissions: id-token: write contents: read steps: - name: Setup XPU +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/setup-xpu@release/2.9 +======= + uses: ./.github/actions/setup-xpu +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: configure aws credentials id: aws_creds uses: aws-actions/configure-aws-credentials@v4 @@ -2373,7 +3958,11 @@ jobs: working-directory: pytorch - name: Calculate docker image id: calculate-docker-image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: manylinux2_28-builder @@ -2381,7 +3970,11 @@ jobs: docker-build-dir: .ci/docker working-directory: pytorch - name: Pull Docker image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Test Pytorch binary @@ -2405,6 +3998,10 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-xpu secrets: @@ -2424,6 +4021,10 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13t-cpu @@ -2445,6 +4046,10 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-cpu build_environment: linux-binary-manywheel @@ -2467,6 +4072,10 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cpu +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-cpu secrets: @@ -2483,15 +4092,27 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 +======= + GPU_ARCH_VERSION: 12.6 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13t-cuda12_6 build_environment: linux-binary-manywheel +<<<<<<< HEAD PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' +======= + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' and platform_machine == 'x86_64' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13t-cuda12_6-test: # Testing @@ -2506,15 +4127,27 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 +======= + GPU_ARCH_VERSION: 12.6 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-cuda12_6 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +<<<<<<< HEAD runs_on: linux.4xlarge.nvidia.gpu # 12.6 build can use maxwell (sm_50) runner +======= + runs_on: linux.4xlarge.nvidia.gpu # for other cuda versions, we use 4xlarge runner +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13t-cuda12_6-upload: # Uploading @@ -2529,10 +4162,18 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.6 +======= + GPU_ARCH_VERSION: 12.6 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-cuda12_6 secrets: @@ -2549,15 +4190,27 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 +======= + GPU_ARCH_VERSION: 12.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13t-cuda12_8 build_environment: linux-binary-manywheel +<<<<<<< HEAD PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' +======= + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13t-cuda12_8-test: # Testing @@ -2572,15 +4225,27 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 +======= + GPU_ARCH_VERSION: 12.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-cuda12_8 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +<<<<<<< HEAD runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8+ builds need sm_70+ runner +======= + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13t-cuda12_8-upload: # Uploading @@ -2595,17 +4260,29 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: cuda12.8 +======= + GPU_ARCH_VERSION: 12.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-cuda12_8 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml +<<<<<<< HEAD manywheel-py3_13t-cuda13_0-build: +======= + manywheel-py3_13t-cuda12_9-build: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -2614,6 +4291,7 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda @@ -2630,6 +4308,25 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: - manywheel-py3_13t-cuda13_0-build +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + use_split_build: False + DESIRED_PYTHON: "3.13t" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_13t-cuda12_9 + build_environment: linux-binary-manywheel + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13t-cuda12_9-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_13t-cuda12_9-build +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -2637,6 +4334,7 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda @@ -2650,16 +4348,37 @@ jobs: secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13t-cuda13_0-upload: # Uploading +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + use_split_build: False + DESIRED_PYTHON: "3.13t" + build_name: manywheel-py3_13t-cuda12_9 + build_environment: linux-binary-manywheel + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13t-cuda12_9-upload: # Uploading +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read +<<<<<<< HEAD needs: manywheel-py3_13t-cuda13_0-test +======= + needs: manywheel-py3_13t-cuda12_9-test +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda @@ -2667,6 +4386,16 @@ jobs: DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-cuda13_0 +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + use_split_build: False + DESIRED_PYTHON: "3.13t" + build_name: manywheel-py3_13t-cuda12_9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -2681,6 +4410,7 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.3 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.3" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder @@ -2688,6 +4418,15 @@ jobs: DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" timeout-minutes: 300 +======= + GPU_ARCH_VERSION: 6.3 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.3 + use_split_build: False + DESIRED_PYTHON: "3.13t" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) build_name: manywheel-py3_13t-rocm6_3 build_environment: linux-binary-manywheel secrets: @@ -2705,11 +4444,19 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.3 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.3" +======= + GPU_ARCH_VERSION: 6.3 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13t" steps: - name: Setup ROCm @@ -2743,7 +4490,11 @@ jobs: role-duration-seconds: 18000 - name: Calculate docker image id: calculate-docker-image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: manylinux2_28-builder @@ -2751,7 +4502,11 @@ jobs: docker-build-dir: .ci/docker working-directory: pytorch - name: Pull Docker image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Test Pytorch binary @@ -2772,10 +4527,18 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.3 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.3" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.3 +======= + GPU_ARCH_VERSION: 6.3 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.3 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-rocm6_3 secrets: @@ -2792,6 +4555,7 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.4 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.4" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder @@ -2799,6 +4563,15 @@ jobs: DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" timeout-minutes: 300 +======= + GPU_ARCH_VERSION: 6.4 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + use_split_build: False + DESIRED_PYTHON: "3.13t" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) build_name: manywheel-py3_13t-rocm6_4 build_environment: linux-binary-manywheel secrets: @@ -2816,11 +4589,19 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.4 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.4" +======= + GPU_ARCH_VERSION: 6.4 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13t" steps: - name: Setup ROCm @@ -2854,7 +4635,11 @@ jobs: role-duration-seconds: 18000 - name: Calculate docker image id: calculate-docker-image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: manylinux2_28-builder @@ -2862,7 +4647,11 @@ jobs: docker-build-dir: .ci/docker working-directory: pytorch - name: Pull Docker image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Test Pytorch binary @@ -2883,10 +4672,18 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: rocm6.4 +<<<<<<< HEAD GPU_ARCH_VERSION: "6.4" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: rocm6.4 +======= + GPU_ARCH_VERSION: 6.4 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-rocm6_4 secrets: @@ -2906,11 +4703,19 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13t-xpu build_environment: linux-binary-manywheel +<<<<<<< HEAD PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.2.1 | intel-cmplr-lib-ur==2025.2.1 | intel-cmplr-lic-rt==2025.2.1 | intel-sycl-rt==2025.2.1 | oneccl-devel==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.2.0 | onemkl-sycl-dft==2025.2.0 | onemkl-sycl-lapack==2025.2.0 | onemkl-sycl-rng==2025.2.0 | onemkl-sycl-sparse==2025.2.0 | dpcpp-cpp-rt==2025.2.1 | intel-opencl-rt==2025.2.1 | mkl==2025.2.0 | intel-openmp==2025.2.1 | tbb==2022.2.0 | tcmlib==1.4.0 | umf==0.11.0 | intel-pti==0.13.1 +======= + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.3 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13t-xpu-test: # Testing @@ -2930,13 +4735,21 @@ jobs: SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13t" permissions: id-token: write contents: read steps: - name: Setup XPU +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/setup-xpu@release/2.9 +======= + uses: ./.github/actions/setup-xpu +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: configure aws credentials id: aws_creds uses: aws-actions/configure-aws-credentials@v4 @@ -2964,7 +4777,11 @@ jobs: working-directory: pytorch - name: Calculate docker image id: calculate-docker-image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: manylinux2_28-builder @@ -2972,7 +4789,11 @@ jobs: docker-build-dir: .ci/docker working-directory: pytorch - name: Pull Docker image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Test Pytorch binary @@ -2996,11 +4817,16 @@ jobs: GPU_ARCH_TYPE: xpu DOCKER_IMAGE: manylinux2_28-builder DOCKER_IMAGE_TAG_PREFIX: xpu +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13t" build_name: manywheel-py3_13t-xpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml +<<<<<<< HEAD manywheel-py3_14-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} @@ -4183,3 +6009,5 @@ jobs: secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml index d7fd44031be22..fc09e4a4c29f9 100644 --- a/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml @@ -41,12 +41,86 @@ jobs: get-label-type: if: github.repository_owner == 'pytorch' name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} +<<<<<<< HEAD +======= + manywheel-py3_9-cpu-s390x-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu-s390x + DOCKER_IMAGE: pytorch/manylinuxs390x-builder + DOCKER_IMAGE_TAG_PREFIX: cpu-s390x + use_split_build: False + DESIRED_PYTHON: "3.9" + runs_on: linux.s390x + ALPINE_IMAGE: "docker.io/s390x/alpine" + timeout-minutes: 420 + build_name: manywheel-py3_9-cpu-s390x + build_environment: linux-s390x-binary-manywheel + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cpu-s390x-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_9-cpu-s390x-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu-s390x + DOCKER_IMAGE: pytorch/manylinuxs390x-builder + DOCKER_IMAGE_TAG_PREFIX: cpu-s390x + use_split_build: False + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cpu-s390x + build_environment: linux-s390x-binary-manywheel + runs_on: linux.s390x + ALPINE_IMAGE: "docker.io/s390x/alpine" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cpu-s390x-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_9-cpu-s390x-test + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu-s390x + DOCKER_IMAGE: pytorch/manylinuxs390x-builder + DOCKER_IMAGE_TAG_PREFIX: cpu-s390x + use_split_build: False + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cpu-s390x + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + uses: ./.github/workflows/_binary-upload.yml + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) manywheel-py3_10-cpu-s390x-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml @@ -60,6 +134,10 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.10" runs_on: linux.s390x ALPINE_IMAGE: "docker.io/s390x/alpine" @@ -83,6 +161,10 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu-s390x build_environment: linux-s390x-binary-manywheel @@ -105,6 +187,10 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu-s390x secrets: @@ -124,6 +210,10 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.11" runs_on: linux.s390x ALPINE_IMAGE: "docker.io/s390x/alpine" @@ -147,6 +237,10 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu-s390x build_environment: linux-s390x-binary-manywheel @@ -169,6 +263,10 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu-s390x secrets: @@ -188,6 +286,10 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.12" runs_on: linux.s390x ALPINE_IMAGE: "docker.io/s390x/alpine" @@ -211,6 +313,10 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu-s390x build_environment: linux-s390x-binary-manywheel @@ -233,6 +339,10 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu-s390x secrets: @@ -252,6 +362,10 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13" runs_on: linux.s390x ALPINE_IMAGE: "docker.io/s390x/alpine" @@ -275,6 +389,10 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cpu-s390x build_environment: linux-s390x-binary-manywheel @@ -297,11 +415,16 @@ jobs: GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder DOCKER_IMAGE_TAG_PREFIX: cpu-s390x +<<<<<<< HEAD +======= + use_split_build: False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cpu-s390x secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml +<<<<<<< HEAD manywheel-py3_13t-cpu-s390x-build: if: ${{ github.repository_owner == 'pytorch' }} @@ -494,3 +617,5 @@ jobs: secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.github/workflows/generated-macos-arm64-binary-libtorch-release-nightly.yml b/.github/workflows/generated-macos-arm64-binary-libtorch-release-nightly.yml index 5f21fc565901d..a12fbb84ad537 100644 --- a/.github/workflows/generated-macos-arm64-binary-libtorch-release-nightly.yml +++ b/.github/workflows/generated-macos-arm64-binary-libtorch-release-nightly.yml @@ -46,7 +46,11 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" +======= + DESIRED_PYTHON: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -67,6 +71,14 @@ jobs: chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" +<<<<<<< HEAD +======= + if [ -d "/Applications/Xcode_14.3.1.app" ]; then + echo "DEVELOPER_DIR=/Applications/Xcode_14.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" + elif [ -d "/Applications/Xcode_13.3.1.app" ]; then + echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" + fi +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: Checkout PyTorch uses: actions/checkout@v4 with: diff --git a/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml b/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml index b12a5212cd4e7..e8d02e5b9b701 100644 --- a/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml +++ b/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml @@ -30,7 +30,11 @@ concurrency: cancel-in-progress: true jobs: +<<<<<<< HEAD wheel-py3_10-cpu-build: +======= + wheel-py3_9-cpu-build: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ github.repository_owner == 'pytorch' }} runs-on: macos-14-xlarge timeout-minutes: 240 @@ -42,7 +46,11 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 +<<<<<<< HEAD DESIRED_PYTHON: "3.10" +======= + DESIRED_PYTHON: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -63,6 +71,14 @@ jobs: chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" +<<<<<<< HEAD +======= + if [ -d "/Applications/Xcode_14.3.1.app" ]; then + echo "DEVELOPER_DIR=/Applications/Xcode_14.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" + elif [ -d "/Applications/Xcode_13.3.1.app" ]; then + echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" + fi +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: Checkout PyTorch uses: actions/checkout@v4 with: @@ -109,6 +125,7 @@ jobs: # Create new "clean" conda environment for testing SMOKE_TEST_PARAMS="" +<<<<<<< HEAD EXTRA_CONDA_INSTALL_FLAGS="" CONDA_ENV_CREATE_FLAGS="" @@ -136,6 +153,137 @@ jobs: # shellcheck disable=SC2086 conda create -yn "test_conda_env" python="$desired_python" ${CONDA_ENV_CREATE_FLAGS} ${EXTRA_CONDA_INSTALL_FLAGS} +======= + if [[ $DESIRED_PYTHON == "3.13t" ]]; then + conda create -yn "test_conda_env" python="3.13" python-freethreading -c conda-forge + SMOKE_TEST_PARAMS="--torch-compile-check disabled" + else + conda create -yn "test_conda_env" python="$DESIRED_PYTHON" + fi + conda activate test_conda_env + pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v + + # shellcheck disable=SC2086 + python "${PYTORCH_ROOT}/.ci/pytorch/smoke_test/smoke_test.py" --package torchonly ${SMOKE_TEST_PARAMS} + - uses: actions/upload-artifact@v4.4.0 + if: always() + with: + name: wheel-py3_9-cpu + retention-days: 14 + if-no-files-found: error + path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" + wheel-py3_9-cpu-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: wheel-py3_9-cpu-build + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: cpu + DESIRED_PYTHON: "3.9" + build_name: wheel-py3_9-cpu + use_s3: False + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + uses: ./.github/workflows/_binary-upload.yml + wheel-py3_10-cpu-build: + if: ${{ github.repository_owner == 'pytorch' }} + runs-on: macos-14-xlarge + timeout-minutes: 240 + env: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + SKIP_ALL_TESTS: 1 + DESIRED_PYTHON: "3.10" + steps: + # NOTE: These environment variables are put here so that they can be applied on every job equally + # They are also here because setting them at a workflow level doesn't give us access to the + # runner.temp variable, which we need. + - name: Populate binary env + shell: bash + run: | + # shellcheck disable=SC2129 + echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" + # shellcheck disable=SC2129 + echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" + # shellcheck disable=SC2129 + echo "MAC_PACKAGE_WORK_DIR=${RUNNER_TEMP}" >> "${GITHUB_ENV}" + - name: Install conda and dependencies + run: | + # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" "https://repo.anaconda.com/miniconda/Miniconda3-py310_23.5.2-0-MacOSX-$(uname -m).sh" + chmod +x "${RUNNER_TEMP}/conda.sh" + /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" + echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" + if [ -d "/Applications/Xcode_14.3.1.app" ]; then + echo "DEVELOPER_DIR=/Applications/Xcode_14.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" + elif [ -d "/Applications/Xcode_13.3.1.app" ]; then + echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" + fi + - name: Checkout PyTorch + uses: actions/checkout@v4 + with: + submodules: recursive + path: pytorch + show-progress: false + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: Populate binary env + run: | + # shellcheck disable=SC1091 + source "${RUNNER_TEMP}/anaconda/bin/activate" + "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" + - name: Build PyTorch binary + run: | + # shellcheck disable=SC1091 + source "${RUNNER_TEMP}/anaconda/bin/activate" + set -eux -o pipefail + # shellcheck disable=SC1090 + source "${BINARY_ENV_FILE:-/Users/distiller/project/env}" + mkdir -p "$PYTORCH_FINAL_PACKAGE_DIR" + + # Build + USE_PYTORCH_METAL_EXPORT=1 + USE_COREML_DELEGATE=1 + TORCH_PACKAGE_NAME="${TORCH_PACKAGE_NAME//-/_}" + export USE_PYTORCH_METAL_EXPORT + export USE_COREML_DELEGATE + export TORCH_PACKAGE_NAME + "${PYTORCH_ROOT}/.ci/wheel/build_wheel.sh" + - name: Test PyTorch wheel + run: | + # shellcheck disable=SC1091 + source "${RUNNER_TEMP}/anaconda/bin/activate" + set -eux -o pipefail + # shellcheck disable=SC1090 + source "${BINARY_ENV_FILE:-/Users/distiller/project/env}" + pip uninstall -y "$TORCH_PACKAGE_NAME" || true + pip uninstall -y "$TORCH_PACKAGE_NAME" || true + + # Create new "clean" conda environment for testing + + SMOKE_TEST_PARAMS="" + if [[ $DESIRED_PYTHON == "3.13t" ]]; then + conda create -yn "test_conda_env" python="3.13" python-freethreading -c conda-forge + SMOKE_TEST_PARAMS="--torch-compile-check disabled" + else + conda create -yn "test_conda_env" python="$DESIRED_PYTHON" + fi +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) conda activate test_conda_env pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v @@ -202,6 +350,14 @@ jobs: chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" +<<<<<<< HEAD +======= + if [ -d "/Applications/Xcode_14.3.1.app" ]; then + echo "DEVELOPER_DIR=/Applications/Xcode_14.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" + elif [ -d "/Applications/Xcode_13.3.1.app" ]; then + echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" + fi +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: Checkout PyTorch uses: actions/checkout@v4 with: @@ -248,6 +404,7 @@ jobs: # Create new "clean" conda environment for testing SMOKE_TEST_PARAMS="" +<<<<<<< HEAD EXTRA_CONDA_INSTALL_FLAGS="" CONDA_ENV_CREATE_FLAGS="" @@ -275,6 +432,14 @@ jobs: # shellcheck disable=SC2086 conda create -yn "test_conda_env" python="$desired_python" ${CONDA_ENV_CREATE_FLAGS} ${EXTRA_CONDA_INSTALL_FLAGS} +======= + if [[ $DESIRED_PYTHON == "3.13t" ]]; then + conda create -yn "test_conda_env" python="3.13" python-freethreading -c conda-forge + SMOKE_TEST_PARAMS="--torch-compile-check disabled" + else + conda create -yn "test_conda_env" python="$DESIRED_PYTHON" + fi +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) conda activate test_conda_env pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v @@ -341,6 +506,14 @@ jobs: chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" +<<<<<<< HEAD +======= + if [ -d "/Applications/Xcode_14.3.1.app" ]; then + echo "DEVELOPER_DIR=/Applications/Xcode_14.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" + elif [ -d "/Applications/Xcode_13.3.1.app" ]; then + echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" + fi +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: Checkout PyTorch uses: actions/checkout@v4 with: @@ -387,6 +560,7 @@ jobs: # Create new "clean" conda environment for testing SMOKE_TEST_PARAMS="" +<<<<<<< HEAD EXTRA_CONDA_INSTALL_FLAGS="" CONDA_ENV_CREATE_FLAGS="" @@ -414,6 +588,14 @@ jobs: # shellcheck disable=SC2086 conda create -yn "test_conda_env" python="$desired_python" ${CONDA_ENV_CREATE_FLAGS} ${EXTRA_CONDA_INSTALL_FLAGS} +======= + if [[ $DESIRED_PYTHON == "3.13t" ]]; then + conda create -yn "test_conda_env" python="3.13" python-freethreading -c conda-forge + SMOKE_TEST_PARAMS="--torch-compile-check disabled" + else + conda create -yn "test_conda_env" python="$DESIRED_PYTHON" + fi +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) conda activate test_conda_env pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v @@ -480,6 +662,14 @@ jobs: chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" +<<<<<<< HEAD +======= + if [ -d "/Applications/Xcode_14.3.1.app" ]; then + echo "DEVELOPER_DIR=/Applications/Xcode_14.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" + elif [ -d "/Applications/Xcode_13.3.1.app" ]; then + echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" + fi +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: Checkout PyTorch uses: actions/checkout@v4 with: @@ -526,6 +716,7 @@ jobs: # Create new "clean" conda environment for testing SMOKE_TEST_PARAMS="" +<<<<<<< HEAD EXTRA_CONDA_INSTALL_FLAGS="" CONDA_ENV_CREATE_FLAGS="" @@ -553,6 +744,14 @@ jobs: # shellcheck disable=SC2086 conda create -yn "test_conda_env" python="$desired_python" ${CONDA_ENV_CREATE_FLAGS} ${EXTRA_CONDA_INSTALL_FLAGS} +======= + if [[ $DESIRED_PYTHON == "3.13t" ]]; then + conda create -yn "test_conda_env" python="3.13" python-freethreading -c conda-forge + SMOKE_TEST_PARAMS="--torch-compile-check disabled" + else + conda create -yn "test_conda_env" python="$DESIRED_PYTHON" + fi +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) conda activate test_conda_env pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v @@ -619,6 +818,14 @@ jobs: chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" +<<<<<<< HEAD +======= + if [ -d "/Applications/Xcode_14.3.1.app" ]; then + echo "DEVELOPER_DIR=/Applications/Xcode_14.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" + elif [ -d "/Applications/Xcode_13.3.1.app" ]; then + echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" + fi +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: Checkout PyTorch uses: actions/checkout@v4 with: @@ -665,6 +872,7 @@ jobs: # Create new "clean" conda environment for testing SMOKE_TEST_PARAMS="" +<<<<<<< HEAD EXTRA_CONDA_INSTALL_FLAGS="" CONDA_ENV_CREATE_FLAGS="" @@ -692,6 +900,14 @@ jobs: # shellcheck disable=SC2086 conda create -yn "test_conda_env" python="$desired_python" ${CONDA_ENV_CREATE_FLAGS} ${EXTRA_CONDA_INSTALL_FLAGS} +======= + if [[ $DESIRED_PYTHON == "3.13t" ]]; then + conda create -yn "test_conda_env" python="3.13" python-freethreading -c conda-forge + SMOKE_TEST_PARAMS="--torch-compile-check disabled" + else + conda create -yn "test_conda_env" python="$DESIRED_PYTHON" + fi +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) conda activate test_conda_env pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v @@ -725,6 +941,7 @@ jobs: secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml +<<<<<<< HEAD wheel-py3_14-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} runs-on: macos-14-xlarge @@ -1003,3 +1220,5 @@ jobs: secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.github/workflows/generated-windows-arm64-binary-libtorch-debug-nightly.yml b/.github/workflows/generated-windows-arm64-binary-libtorch-debug-nightly.yml index 7a8ea9cbfa2c1..612c61e356da3 100644 --- a/.github/workflows/generated-windows-arm64-binary-libtorch-debug-nightly.yml +++ b/.github/workflows/generated-windows-arm64-binary-libtorch-debug-nightly.yml @@ -41,7 +41,11 @@ jobs: get-label-type: if: github.repository_owner == 'pytorch' name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -51,7 +55,11 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "windows-11-arm64-preview" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: libtorch @@ -64,7 +72,11 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" +======= + DESIRED_PYTHON: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) steps: - name: Populate binary env shell: cmd @@ -128,7 +140,11 @@ jobs: - libtorch-cpu-shared-with-deps-debug-build - get-label-type runs-on: "windows-11-arm64-preview" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: libtorch @@ -141,7 +157,11 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" +======= + DESIRED_PYTHON: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) steps: - name: Populate binary env shell: cmd @@ -201,7 +221,11 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" +======= + DESIRED_PYTHON: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) build_name: libtorch-cpu-shared-with-deps-debug secrets: github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/generated-windows-arm64-binary-libtorch-release-nightly.yml b/.github/workflows/generated-windows-arm64-binary-libtorch-release-nightly.yml index 14081649d370d..c27025775e915 100644 --- a/.github/workflows/generated-windows-arm64-binary-libtorch-release-nightly.yml +++ b/.github/workflows/generated-windows-arm64-binary-libtorch-release-nightly.yml @@ -41,7 +41,11 @@ jobs: get-label-type: if: github.repository_owner == 'pytorch' name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -51,7 +55,11 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "windows-11-arm64-preview" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: libtorch @@ -64,7 +72,11 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" +======= + DESIRED_PYTHON: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) steps: - name: Populate binary env shell: cmd @@ -128,7 +140,11 @@ jobs: - libtorch-cpu-shared-with-deps-release-build - get-label-type runs-on: "windows-11-arm64-preview" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: libtorch @@ -141,7 +157,11 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" +======= + DESIRED_PYTHON: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) steps: - name: Populate binary env shell: cmd @@ -201,7 +221,11 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" +======= + DESIRED_PYTHON: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) build_name: libtorch-cpu-shared-with-deps-release secrets: github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/generated-windows-arm64-binary-wheel-nightly.yml b/.github/workflows/generated-windows-arm64-binary-wheel-nightly.yml index d0e02dade2998..3e474738945e5 100644 --- a/.github/workflows/generated-windows-arm64-binary-wheel-nightly.yml +++ b/.github/workflows/generated-windows-arm64-binary-wheel-nightly.yml @@ -41,7 +41,11 @@ jobs: get-label-type: if: github.repository_owner == 'pytorch' name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -51,7 +55,11 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "windows-11-arm64-preview" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel @@ -124,7 +132,11 @@ jobs: - wheel-py3_11-cpu-build - get-label-type runs-on: "windows-11-arm64-preview" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel @@ -198,7 +210,11 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "windows-11-arm64-preview" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel @@ -271,7 +287,11 @@ jobs: - wheel-py3_12-cpu-build - get-label-type runs-on: "windows-11-arm64-preview" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel @@ -345,7 +365,11 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "windows-11-arm64-preview" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel @@ -418,7 +442,11 @@ jobs: - wheel-py3_13-cpu-build - get-label-type runs-on: "windows-11-arm64-preview" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel diff --git a/.github/workflows/generated-windows-binary-libtorch-debug-main.yml b/.github/workflows/generated-windows-binary-libtorch-debug-main.yml index 3df2c65440a5f..ef6eff32d3ae8 100644 --- a/.github/workflows/generated-windows-binary-libtorch-debug-main.yml +++ b/.github/workflows/generated-windows-binary-libtorch-debug-main.yml @@ -28,7 +28,11 @@ jobs: get-label-type: if: github.repository_owner == 'pytorch' name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -38,7 +42,11 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: libtorch @@ -51,7 +59,11 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" +======= + DESIRED_PYTHON: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -77,7 +89,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -152,7 +168,11 @@ jobs: - libtorch-cpu-shared-with-deps-debug-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: libtorch @@ -165,7 +185,11 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" +======= + DESIRED_PYTHON: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) steps: - name: Display EC2 information shell: bash @@ -182,7 +206,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml b/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml index f4413a86c6578..740452e7eb6a3 100644 --- a/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml +++ b/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml @@ -35,7 +35,11 @@ jobs: get-label-type: if: github.repository_owner == 'pytorch' name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -45,7 +49,11 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: libtorch @@ -58,7 +66,11 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" +======= + DESIRED_PYTHON: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -84,7 +96,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -159,7 +175,11 @@ jobs: - libtorch-cpu-shared-with-deps-debug-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: libtorch @@ -172,7 +192,11 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" +======= + DESIRED_PYTHON: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) steps: - name: Display EC2 information shell: bash @@ -189,7 +213,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -281,7 +309,11 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" +======= + DESIRED_PYTHON: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) build_name: libtorch-cpu-shared-with-deps-debug secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -290,21 +322,33 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" +======= + GPU_ARCH_VERSION: 12.6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 LIBTORCH_CONFIG: debug LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" +======= + DESIRED_PYTHON: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -330,7 +374,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -405,21 +453,33 @@ jobs: - libtorch-cuda12_6-shared-with-deps-debug-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" +======= + GPU_ARCH_VERSION: 12.6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 LIBTORCH_CONFIG: debug LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" +======= + DESIRED_PYTHON: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) steps: - name: Display EC2 information shell: bash @@ -436,7 +496,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -523,13 +587,21 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" +======= + GPU_ARCH_VERSION: 12.6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda LIBTORCH_CONFIG: debug LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" +======= + DESIRED_PYTHON: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) build_name: libtorch-cuda12_6-shared-with-deps-debug secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -538,21 +610,33 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" +======= + GPU_ARCH_VERSION: 12.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 LIBTORCH_CONFIG: debug LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" +======= + DESIRED_PYTHON: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -578,7 +662,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -653,21 +741,33 @@ jobs: - libtorch-cuda12_8-shared-with-deps-debug-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" +======= + GPU_ARCH_VERSION: 12.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 LIBTORCH_CONFIG: debug LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" +======= + DESIRED_PYTHON: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) steps: - name: Display EC2 information shell: bash @@ -684,7 +784,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -771,36 +875,61 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" +======= + GPU_ARCH_VERSION: 12.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda LIBTORCH_CONFIG: debug LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" +======= + DESIRED_PYTHON: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) build_name: libtorch-cuda12_8-shared-with-deps-debug secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml +<<<<<<< HEAD libtorch-cuda13_0-shared-with-deps-debug-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 360 +======= + libtorch-cuda12_9-shared-with-deps-debug-build: + if: ${{ github.repository_owner == 'pytorch' }} + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 LIBTORCH_CONFIG: debug LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" +======= + DESIRED_PYTHON: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -826,7 +955,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -877,7 +1010,11 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: +<<<<<<< HEAD name: libtorch-cuda13_0-shared-with-deps-debug +======= + name: libtorch-cuda12_9-shared-with-deps-debug +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -895,6 +1032,7 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 +<<<<<<< HEAD libtorch-cuda13_0-shared-with-deps-debug-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -902,20 +1040,38 @@ jobs: - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 360 +======= + libtorch-cuda12_9-shared-with-deps-debug-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - libtorch-cuda12_9-shared-with-deps-debug-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 LIBTORCH_CONFIG: debug LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" +======= + DESIRED_PYTHON: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) steps: - name: Display EC2 information shell: bash @@ -932,7 +1088,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -984,7 +1144,11 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: +<<<<<<< HEAD name: libtorch-cuda13_0-shared-with-deps-debug +======= + name: libtorch-cuda12_9-shared-with-deps-debug +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -1007,26 +1171,44 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 +<<<<<<< HEAD libtorch-cuda13_0-shared-with-deps-debug-upload: # Uploading +======= + libtorch-cuda12_9-shared-with-deps-debug-upload: # Uploading +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read +<<<<<<< HEAD needs: libtorch-cuda13_0-shared-with-deps-debug-test +======= + needs: libtorch-cuda12_9-shared-with-deps-debug-test +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda LIBTORCH_CONFIG: debug LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" build_name: libtorch-cuda13_0-shared-with-deps-debug +======= + DESIRED_PYTHON: "3.9" + build_name: libtorch-cuda12_9-shared-with-deps-debug +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml diff --git a/.github/workflows/generated-windows-binary-libtorch-release-main.yml b/.github/workflows/generated-windows-binary-libtorch-release-main.yml index ef94d6212af35..6cc90db6a7762 100644 --- a/.github/workflows/generated-windows-binary-libtorch-release-main.yml +++ b/.github/workflows/generated-windows-binary-libtorch-release-main.yml @@ -28,7 +28,11 @@ jobs: get-label-type: if: github.repository_owner == 'pytorch' name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -38,7 +42,11 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: libtorch @@ -51,7 +59,11 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" +======= + DESIRED_PYTHON: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -77,7 +89,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -152,7 +168,11 @@ jobs: - libtorch-cpu-shared-with-deps-release-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: libtorch @@ -165,7 +185,11 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" +======= + DESIRED_PYTHON: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) steps: - name: Display EC2 information shell: bash @@ -182,7 +206,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml b/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml index 8f4ec6e0b2054..50e8123e049b6 100644 --- a/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml +++ b/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml @@ -35,7 +35,11 @@ jobs: get-label-type: if: github.repository_owner == 'pytorch' name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -45,7 +49,11 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: libtorch @@ -58,7 +66,11 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" +======= + DESIRED_PYTHON: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -84,7 +96,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -159,7 +175,11 @@ jobs: - libtorch-cpu-shared-with-deps-release-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: libtorch @@ -172,7 +192,11 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" +======= + DESIRED_PYTHON: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) steps: - name: Display EC2 information shell: bash @@ -189,7 +213,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -281,7 +309,11 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" +======= + DESIRED_PYTHON: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) build_name: libtorch-cpu-shared-with-deps-release secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -290,21 +322,33 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" +======= + GPU_ARCH_VERSION: 12.6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 LIBTORCH_CONFIG: release LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" +======= + DESIRED_PYTHON: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -330,7 +374,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -405,21 +453,33 @@ jobs: - libtorch-cuda12_6-shared-with-deps-release-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" +======= + GPU_ARCH_VERSION: 12.6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 LIBTORCH_CONFIG: release LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" +======= + DESIRED_PYTHON: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) steps: - name: Display EC2 information shell: bash @@ -436,7 +496,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -523,13 +587,21 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" +======= + GPU_ARCH_VERSION: 12.6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda LIBTORCH_CONFIG: release LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" +======= + DESIRED_PYTHON: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) build_name: libtorch-cuda12_6-shared-with-deps-release secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -538,21 +610,33 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" +======= + GPU_ARCH_VERSION: 12.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 LIBTORCH_CONFIG: release LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" +======= + DESIRED_PYTHON: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -578,7 +662,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -653,21 +741,33 @@ jobs: - libtorch-cuda12_8-shared-with-deps-release-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" +======= + GPU_ARCH_VERSION: 12.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 LIBTORCH_CONFIG: release LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" +======= + DESIRED_PYTHON: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) steps: - name: Display EC2 information shell: bash @@ -684,7 +784,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -771,36 +875,61 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" +======= + GPU_ARCH_VERSION: 12.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda LIBTORCH_CONFIG: release LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" +======= + DESIRED_PYTHON: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) build_name: libtorch-cuda12_8-shared-with-deps-release secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml +<<<<<<< HEAD libtorch-cuda13_0-shared-with-deps-release-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 360 +======= + libtorch-cuda12_9-shared-with-deps-release-build: + if: ${{ github.repository_owner == 'pytorch' }} + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 LIBTORCH_CONFIG: release LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" +======= + DESIRED_PYTHON: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -826,7 +955,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -877,7 +1010,11 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: +<<<<<<< HEAD name: libtorch-cuda13_0-shared-with-deps-release +======= + name: libtorch-cuda12_9-shared-with-deps-release +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -895,6 +1032,7 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 +<<<<<<< HEAD libtorch-cuda13_0-shared-with-deps-release-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -902,20 +1040,38 @@ jobs: - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 360 +======= + libtorch-cuda12_9-shared-with-deps-release-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - libtorch-cuda12_9-shared-with-deps-release-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 LIBTORCH_CONFIG: release LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" +======= + DESIRED_PYTHON: "3.9" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) steps: - name: Display EC2 information shell: bash @@ -932,7 +1088,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -984,7 +1144,11 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: +<<<<<<< HEAD name: libtorch-cuda13_0-shared-with-deps-release +======= + name: libtorch-cuda12_9-shared-with-deps-release +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -1007,26 +1171,44 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 +<<<<<<< HEAD libtorch-cuda13_0-shared-with-deps-release-upload: # Uploading +======= + libtorch-cuda12_9-shared-with-deps-release-upload: # Uploading +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read +<<<<<<< HEAD needs: libtorch-cuda13_0-shared-with-deps-release-test +======= + needs: libtorch-cuda12_9-shared-with-deps-release-test +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda LIBTORCH_CONFIG: release LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason +<<<<<<< HEAD DESIRED_PYTHON: "3.10" build_name: libtorch-cuda13_0-shared-with-deps-release +======= + DESIRED_PYTHON: "3.9" + build_name: libtorch-cuda12_9-shared-with-deps-release +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml diff --git a/.github/workflows/generated-windows-binary-wheel-nightly.yml b/.github/workflows/generated-windows-binary-wheel-nightly.yml index bca8d4843463a..4dc97a4ab94f2 100644 --- a/.github/workflows/generated-windows-binary-wheel-nightly.yml +++ b/.github/workflows/generated-windows-binary-wheel-nightly.yml @@ -35,17 +35,1203 @@ jobs: get-label-type: if: github.repository_owner == 'pytorch' name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} +<<<<<<< HEAD +======= + wheel-py3_9-cpu-build: + if: ${{ github.repository_owner == 'pytorch' }} + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" + timeout-minutes: 300 + env: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + SKIP_ALL_TESTS: 1 + DESIRED_PYTHON: "3.9" + steps: + # NOTE: These environment variables are put here so that they can be applied on every job equally + # They are also here because setting them at a workflow level doesn't give us access to the + # runner.temp variable, which we need. + - name: Populate binary env + shell: bash + run: | + echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" + echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" + echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + echo "system info $(uname -a)" + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 + continue-on-error: true + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false + # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 + - name: Enable long paths on Windows + shell: powershell + run: | + Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1 + # Since it's just a defensive command, the workflow should continue even the command fails. This step can be + # removed once Windows Defender is removed from the AMI + - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch + continue-on-error: true + shell: powershell + run: | + Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore + # Let's both exclude the path and disable Windows Defender completely just to be sure + # that it doesn't interfere + Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore + - name: Checkout PyTorch + uses: actions/checkout@v4 + with: + submodules: recursive + path: pytorch + show-progress: false + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: Populate binary env + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" + - name: Build PyTorch binary + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" + - uses: actions/upload-artifact@v4.4.0 + if: always() + with: + name: wheel-py3_9-cpu + retention-days: 14 + if-no-files-found: error + path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" + - name: Wait until all sessions have drained + shell: powershell + working-directory: pytorch + if: always() + timeout-minutes: 120 + run: | + .github\scripts\wait_for_ssh_to_drain.ps1 + - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) + shell: powershell + working-directory: pytorch + if: always() + run: | + .github\scripts\kill_active_ssh_sessions.ps1 + + wheel-py3_9-cpu-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - wheel-py3_9-cpu-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" + timeout-minutes: 300 + env: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + SKIP_ALL_TESTS: 1 + DESIRED_PYTHON: "3.9" + steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + echo "system info $(uname -a)" + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 + continue-on-error: true + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false + # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 + - name: Enable long paths on Windows + shell: powershell + run: | + Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1 + # Since it's just a defensive command, the workflow should continue even the command fails. This step can be + # removed once Windows Defender is removed from the AMI + - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch + continue-on-error: true + shell: powershell + run: | + Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore + # Let's both exclude the path and disable Windows Defender completely just to be sure + # that it doesn't interfere + Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore + - name: Checkout PyTorch + uses: actions/checkout@v4 + with: + submodules: recursive + path: pytorch + show-progress: false + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + # NOTE: These environment variables are put here so that they can be applied on every job equally + # They are also here because setting them at a workflow level doesn't give us access to the + # runner.temp variable, which we need. + - name: Populate binary env + shell: bash + run: | + echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" + echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" + echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" + - uses: actions/download-artifact@v4.1.7 + name: Download Build Artifacts + with: + name: wheel-py3_9-cpu + path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" + - name: Populate binary env + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" + - name: Test PyTorch binary + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh" + - name: Wait until all sessions have drained + shell: powershell + working-directory: pytorch + if: always() + timeout-minutes: 120 + run: | + .github\scripts\wait_for_ssh_to_drain.ps1 + - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) + shell: powershell + working-directory: pytorch + if: always() + run: | + .github\scripts\kill_active_ssh_sessions.ps1 + wheel-py3_9-cpu-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: wheel-py3_9-cpu-test + with: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DESIRED_PYTHON: "3.9" + build_name: wheel-py3_9-cpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + uses: ./.github/workflows/_binary-upload.yml + wheel-py3_9-cuda12_6-build: + if: ${{ github.repository_owner == 'pytorch' }} + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" + timeout-minutes: 300 + env: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 + GPU_ARCH_TYPE: cuda + SKIP_ALL_TESTS: 1 + DESIRED_PYTHON: "3.9" + steps: + # NOTE: These environment variables are put here so that they can be applied on every job equally + # They are also here because setting them at a workflow level doesn't give us access to the + # runner.temp variable, which we need. + - name: Populate binary env + shell: bash + run: | + echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" + echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" + echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + echo "system info $(uname -a)" + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 + continue-on-error: true + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false + # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 + - name: Enable long paths on Windows + shell: powershell + run: | + Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1 + # Since it's just a defensive command, the workflow should continue even the command fails. This step can be + # removed once Windows Defender is removed from the AMI + - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch + continue-on-error: true + shell: powershell + run: | + Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore + # Let's both exclude the path and disable Windows Defender completely just to be sure + # that it doesn't interfere + Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore + - name: Checkout PyTorch + uses: actions/checkout@v4 + with: + submodules: recursive + path: pytorch + show-progress: false + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: Populate binary env + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" + - name: Build PyTorch binary + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" + - uses: actions/upload-artifact@v4.4.0 + if: always() + with: + name: wheel-py3_9-cuda12_6 + retention-days: 14 + if-no-files-found: error + path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" + - name: Wait until all sessions have drained + shell: powershell + working-directory: pytorch + if: always() + timeout-minutes: 120 + run: | + .github\scripts\wait_for_ssh_to_drain.ps1 + - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) + shell: powershell + working-directory: pytorch + if: always() + run: | + .github\scripts\kill_active_ssh_sessions.ps1 + + wheel-py3_9-cuda12_6-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - wheel-py3_9-cuda12_6-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" + timeout-minutes: 300 + env: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 + GPU_ARCH_TYPE: cuda + SKIP_ALL_TESTS: 1 + DESIRED_PYTHON: "3.9" + steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + echo "system info $(uname -a)" + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 + continue-on-error: true + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false + # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 + - name: Enable long paths on Windows + shell: powershell + run: | + Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1 + # Since it's just a defensive command, the workflow should continue even the command fails. This step can be + # removed once Windows Defender is removed from the AMI + - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch + continue-on-error: true + shell: powershell + run: | + Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore + # Let's both exclude the path and disable Windows Defender completely just to be sure + # that it doesn't interfere + Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore + - name: Checkout PyTorch + uses: actions/checkout@v4 + with: + submodules: recursive + path: pytorch + show-progress: false + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + # NOTE: These environment variables are put here so that they can be applied on every job equally + # They are also here because setting them at a workflow level doesn't give us access to the + # runner.temp variable, which we need. + - name: Populate binary env + shell: bash + run: | + echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" + echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" + echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" + - uses: actions/download-artifact@v4.1.7 + name: Download Build Artifacts + with: + name: wheel-py3_9-cuda12_6 + path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" + - name: Populate binary env + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" + - name: Test PyTorch binary + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh" + - name: Wait until all sessions have drained + shell: powershell + working-directory: pytorch + if: always() + timeout-minutes: 120 + run: | + .github\scripts\wait_for_ssh_to_drain.ps1 + - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) + shell: powershell + working-directory: pytorch + if: always() + run: | + .github\scripts\kill_active_ssh_sessions.ps1 + wheel-py3_9-cuda12_6-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: wheel-py3_9-cuda12_6-test + with: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: 12.6 + GPU_ARCH_TYPE: cuda + DESIRED_PYTHON: "3.9" + build_name: wheel-py3_9-cuda12_6 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + uses: ./.github/workflows/_binary-upload.yml + wheel-py3_9-cuda12_8-build: + if: ${{ github.repository_owner == 'pytorch' }} + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" + timeout-minutes: 300 + env: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 + GPU_ARCH_TYPE: cuda + SKIP_ALL_TESTS: 1 + DESIRED_PYTHON: "3.9" + steps: + # NOTE: These environment variables are put here so that they can be applied on every job equally + # They are also here because setting them at a workflow level doesn't give us access to the + # runner.temp variable, which we need. + - name: Populate binary env + shell: bash + run: | + echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" + echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" + echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + echo "system info $(uname -a)" + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 + continue-on-error: true + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false + # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 + - name: Enable long paths on Windows + shell: powershell + run: | + Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1 + # Since it's just a defensive command, the workflow should continue even the command fails. This step can be + # removed once Windows Defender is removed from the AMI + - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch + continue-on-error: true + shell: powershell + run: | + Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore + # Let's both exclude the path and disable Windows Defender completely just to be sure + # that it doesn't interfere + Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore + - name: Checkout PyTorch + uses: actions/checkout@v4 + with: + submodules: recursive + path: pytorch + show-progress: false + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: Populate binary env + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" + - name: Build PyTorch binary + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" + - uses: actions/upload-artifact@v4.4.0 + if: always() + with: + name: wheel-py3_9-cuda12_8 + retention-days: 14 + if-no-files-found: error + path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" + - name: Wait until all sessions have drained + shell: powershell + working-directory: pytorch + if: always() + timeout-minutes: 120 + run: | + .github\scripts\wait_for_ssh_to_drain.ps1 + - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) + shell: powershell + working-directory: pytorch + if: always() + run: | + .github\scripts\kill_active_ssh_sessions.ps1 + + wheel-py3_9-cuda12_8-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - wheel-py3_9-cuda12_8-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" + timeout-minutes: 300 + env: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 + GPU_ARCH_TYPE: cuda + SKIP_ALL_TESTS: 1 + DESIRED_PYTHON: "3.9" + steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + echo "system info $(uname -a)" + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 + continue-on-error: true + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false + # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 + - name: Enable long paths on Windows + shell: powershell + run: | + Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1 + # Since it's just a defensive command, the workflow should continue even the command fails. This step can be + # removed once Windows Defender is removed from the AMI + - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch + continue-on-error: true + shell: powershell + run: | + Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore + # Let's both exclude the path and disable Windows Defender completely just to be sure + # that it doesn't interfere + Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore + - name: Checkout PyTorch + uses: actions/checkout@v4 + with: + submodules: recursive + path: pytorch + show-progress: false + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + # NOTE: These environment variables are put here so that they can be applied on every job equally + # They are also here because setting them at a workflow level doesn't give us access to the + # runner.temp variable, which we need. + - name: Populate binary env + shell: bash + run: | + echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" + echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" + echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" + - uses: actions/download-artifact@v4.1.7 + name: Download Build Artifacts + with: + name: wheel-py3_9-cuda12_8 + path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" + - name: Populate binary env + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" + - name: Test PyTorch binary + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh" + - name: Wait until all sessions have drained + shell: powershell + working-directory: pytorch + if: always() + timeout-minutes: 120 + run: | + .github\scripts\wait_for_ssh_to_drain.ps1 + - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) + shell: powershell + working-directory: pytorch + if: always() + run: | + .github\scripts\kill_active_ssh_sessions.ps1 + wheel-py3_9-cuda12_8-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: wheel-py3_9-cuda12_8-test + with: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu128 + GPU_ARCH_VERSION: 12.8 + GPU_ARCH_TYPE: cuda + DESIRED_PYTHON: "3.9" + build_name: wheel-py3_9-cuda12_8 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + uses: ./.github/workflows/_binary-upload.yml + wheel-py3_9-cuda12_9-build: + if: ${{ github.repository_owner == 'pytorch' }} + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" + timeout-minutes: 300 + env: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + SKIP_ALL_TESTS: 1 + DESIRED_PYTHON: "3.9" + steps: + # NOTE: These environment variables are put here so that they can be applied on every job equally + # They are also here because setting them at a workflow level doesn't give us access to the + # runner.temp variable, which we need. + - name: Populate binary env + shell: bash + run: | + echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" + echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" + echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + echo "system info $(uname -a)" + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 + continue-on-error: true + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false + # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 + - name: Enable long paths on Windows + shell: powershell + run: | + Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1 + # Since it's just a defensive command, the workflow should continue even the command fails. This step can be + # removed once Windows Defender is removed from the AMI + - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch + continue-on-error: true + shell: powershell + run: | + Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore + # Let's both exclude the path and disable Windows Defender completely just to be sure + # that it doesn't interfere + Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore + - name: Checkout PyTorch + uses: actions/checkout@v4 + with: + submodules: recursive + path: pytorch + show-progress: false + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: Populate binary env + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" + - name: Build PyTorch binary + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" + - uses: actions/upload-artifact@v4.4.0 + if: always() + with: + name: wheel-py3_9-cuda12_9 + retention-days: 14 + if-no-files-found: error + path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" + - name: Wait until all sessions have drained + shell: powershell + working-directory: pytorch + if: always() + timeout-minutes: 120 + run: | + .github\scripts\wait_for_ssh_to_drain.ps1 + - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) + shell: powershell + working-directory: pytorch + if: always() + run: | + .github\scripts\kill_active_ssh_sessions.ps1 + + wheel-py3_9-cuda12_9-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - wheel-py3_9-cuda12_9-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" + timeout-minutes: 300 + env: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + SKIP_ALL_TESTS: 1 + DESIRED_PYTHON: "3.9" + steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + echo "system info $(uname -a)" + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 + continue-on-error: true + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false + # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 + - name: Enable long paths on Windows + shell: powershell + run: | + Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1 + # Since it's just a defensive command, the workflow should continue even the command fails. This step can be + # removed once Windows Defender is removed from the AMI + - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch + continue-on-error: true + shell: powershell + run: | + Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore + # Let's both exclude the path and disable Windows Defender completely just to be sure + # that it doesn't interfere + Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore + - name: Checkout PyTorch + uses: actions/checkout@v4 + with: + submodules: recursive + path: pytorch + show-progress: false + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + # NOTE: These environment variables are put here so that they can be applied on every job equally + # They are also here because setting them at a workflow level doesn't give us access to the + # runner.temp variable, which we need. + - name: Populate binary env + shell: bash + run: | + echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" + echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" + echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" + - uses: actions/download-artifact@v4.1.7 + name: Download Build Artifacts + with: + name: wheel-py3_9-cuda12_9 + path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" + - name: Populate binary env + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" + - name: Test PyTorch binary + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh" + - name: Wait until all sessions have drained + shell: powershell + working-directory: pytorch + if: always() + timeout-minutes: 120 + run: | + .github\scripts\wait_for_ssh_to_drain.ps1 + - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) + shell: powershell + working-directory: pytorch + if: always() + run: | + .github\scripts\kill_active_ssh_sessions.ps1 + wheel-py3_9-cuda12_9-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: wheel-py3_9-cuda12_9-test + with: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + DESIRED_PYTHON: "3.9" + build_name: wheel-py3_9-cuda12_9 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + uses: ./.github/workflows/_binary-upload.yml + wheel-py3_9-xpu-build: + if: ${{ github.repository_owner == 'pytorch' }} + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" + timeout-minutes: 300 + env: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu + SKIP_ALL_TESTS: 1 + DESIRED_PYTHON: "3.9" + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.3 + steps: + # NOTE: These environment variables are put here so that they can be applied on every job equally + # They are also here because setting them at a workflow level doesn't give us access to the + # runner.temp variable, which we need. + - name: Populate binary env + shell: bash + run: | + echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" + echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" + echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + echo "system info $(uname -a)" + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 + continue-on-error: true + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false + # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 + - name: Enable long paths on Windows + shell: powershell + run: | + Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1 + # Since it's just a defensive command, the workflow should continue even the command fails. This step can be + # removed once Windows Defender is removed from the AMI + - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch + continue-on-error: true + shell: powershell + run: | + Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore + # Let's both exclude the path and disable Windows Defender completely just to be sure + # that it doesn't interfere + Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore + - name: Checkout PyTorch + uses: actions/checkout@v4 + with: + submodules: recursive + path: pytorch + show-progress: false + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: Populate binary env + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" + - name: Build PyTorch binary + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" + - uses: actions/upload-artifact@v4.4.0 + if: always() + with: + name: wheel-py3_9-xpu + retention-days: 14 + if-no-files-found: error + path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" + - name: Wait until all sessions have drained + shell: powershell + working-directory: pytorch + if: always() + timeout-minutes: 120 + run: | + .github\scripts\wait_for_ssh_to_drain.ps1 + - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) + shell: powershell + working-directory: pytorch + if: always() + run: | + .github\scripts\kill_active_ssh_sessions.ps1 + + wheel-py3_9-xpu-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - wheel-py3_9-xpu-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" + timeout-minutes: 300 + env: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu + SKIP_ALL_TESTS: 1 + DESIRED_PYTHON: "3.9" + steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + echo "system info $(uname -a)" + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 + continue-on-error: true + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false + # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 + - name: Enable long paths on Windows + shell: powershell + run: | + Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1 + # Since it's just a defensive command, the workflow should continue even the command fails. This step can be + # removed once Windows Defender is removed from the AMI + - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch + continue-on-error: true + shell: powershell + run: | + Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore + # Let's both exclude the path and disable Windows Defender completely just to be sure + # that it doesn't interfere + Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore + - name: Checkout PyTorch + uses: actions/checkout@v4 + with: + submodules: recursive + path: pytorch + show-progress: false + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + # NOTE: These environment variables are put here so that they can be applied on every job equally + # They are also here because setting them at a workflow level doesn't give us access to the + # runner.temp variable, which we need. + - name: Populate binary env + shell: bash + run: | + echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" + echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" + echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" + - uses: actions/download-artifact@v4.1.7 + name: Download Build Artifacts + with: + name: wheel-py3_9-xpu + path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" + - name: Populate binary env + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" + - name: Test PyTorch binary + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh" + - name: Wait until all sessions have drained + shell: powershell + working-directory: pytorch + if: always() + timeout-minutes: 120 + run: | + .github\scripts\wait_for_ssh_to_drain.ps1 + - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) + shell: powershell + working-directory: pytorch + if: always() + run: | + .github\scripts\kill_active_ssh_sessions.ps1 + wheel-py3_9-xpu-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: wheel-py3_9-xpu-test + with: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu + DESIRED_PYTHON: "3.9" + build_name: wheel-py3_9-xpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + uses: ./.github/workflows/_binary-upload.yml +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) wheel-py3_10-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel @@ -80,7 +1266,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -155,7 +1345,11 @@ jobs: - wheel-py3_10-cpu-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel @@ -181,7 +1375,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -278,14 +1476,22 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" +======= + GPU_ARCH_VERSION: 12.6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" @@ -314,7 +1520,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -389,14 +1599,22 @@ jobs: - wheel-py3_10-cuda12_6-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" +======= + GPU_ARCH_VERSION: 12.6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" @@ -416,7 +1634,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -503,7 +1725,11 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" +======= + GPU_ARCH_VERSION: 12.6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.10" build_name: wheel-py3_10-cuda12_6 @@ -514,14 +1740,22 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" +======= + GPU_ARCH_VERSION: 12.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" @@ -550,7 +1784,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -625,14 +1863,22 @@ jobs: - wheel-py3_10-cuda12_8-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" +======= + GPU_ARCH_VERSION: 12.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" @@ -652,7 +1898,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -739,25 +1989,42 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" +======= + GPU_ARCH_VERSION: 12.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.10" build_name: wheel-py3_10-cuda12_8 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml +<<<<<<< HEAD wheel-py3_10-cuda13_0-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 360 +======= + wheel-py3_10-cuda12_9-build: + if: ${{ github.repository_owner == 'pytorch' }} + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" @@ -786,7 +2053,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -837,7 +2108,11 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: +<<<<<<< HEAD name: wheel-py3_10-cuda13_0 +======= + name: wheel-py3_10-cuda12_9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -855,6 +2130,7 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 +<<<<<<< HEAD wheel-py3_10-cuda13_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -862,13 +2138,27 @@ jobs: - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 360 +======= + wheel-py3_10-cuda12_9-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - wheel-py3_10-cuda12_9-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" @@ -888,7 +2178,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -940,7 +2234,11 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: +<<<<<<< HEAD name: wheel-py3_10-cuda13_0 +======= + name: wheel-py3_10-cuda12_9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -963,22 +2261,38 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 +<<<<<<< HEAD wheel-py3_10-cuda13_0-upload: # Uploading +======= + wheel-py3_10-cuda12_9-upload: # Uploading +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read +<<<<<<< HEAD needs: wheel-py3_10-cuda13_0-test +======= + needs: wheel-py3_10-cuda12_9-test +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.10" build_name: wheel-py3_10-cuda13_0 +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + DESIRED_PYTHON: "3.10" + build_name: wheel-py3_10-cuda12_9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -986,7 +2300,11 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel @@ -996,7 +2314,11 @@ jobs: GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" +<<<<<<< HEAD PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.2.1 | intel-cmplr-lib-ur==2025.2.1 | intel-cmplr-lic-rt==2025.2.1 | intel-sycl-rt==2025.2.1 | oneccl-devel==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.2.0 | onemkl-sycl-dft==2025.2.0 | onemkl-sycl-lapack==2025.2.0 | onemkl-sycl-rng==2025.2.0 | onemkl-sycl-sparse==2025.2.0 | dpcpp-cpp-rt==2025.2.1 | intel-opencl-rt==2025.2.1 | mkl==2025.2.0 | intel-openmp==2025.2.1 | tbb==2022.2.0 | tcmlib==1.4.0 | umf==0.11.0 | intel-pti==0.13.1 +======= + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.3 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -1022,7 +2344,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -1097,7 +2423,11 @@ jobs: - wheel-py3_10-xpu-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel @@ -1123,7 +2453,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -1220,7 +2554,11 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel @@ -1255,7 +2593,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -1330,7 +2672,11 @@ jobs: - wheel-py3_11-cpu-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel @@ -1356,7 +2702,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -1453,14 +2803,22 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" +======= + GPU_ARCH_VERSION: 12.6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" @@ -1489,7 +2847,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -1564,14 +2926,22 @@ jobs: - wheel-py3_11-cuda12_6-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" +======= + GPU_ARCH_VERSION: 12.6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" @@ -1591,7 +2961,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -1678,7 +3052,11 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" +======= + GPU_ARCH_VERSION: 12.6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.11" build_name: wheel-py3_11-cuda12_6 @@ -1689,14 +3067,22 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" +======= + GPU_ARCH_VERSION: 12.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" @@ -1725,7 +3111,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -1800,14 +3190,22 @@ jobs: - wheel-py3_11-cuda12_8-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" +======= + GPU_ARCH_VERSION: 12.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" @@ -1827,7 +3225,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -1914,25 +3316,42 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" +======= + GPU_ARCH_VERSION: 12.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.11" build_name: wheel-py3_11-cuda12_8 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml +<<<<<<< HEAD wheel-py3_11-cuda13_0-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 360 +======= + wheel-py3_11-cuda12_9-build: + if: ${{ github.repository_owner == 'pytorch' }} + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" @@ -1961,7 +3380,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -2012,7 +3435,11 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: +<<<<<<< HEAD name: wheel-py3_11-cuda13_0 +======= + name: wheel-py3_11-cuda12_9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -2030,6 +3457,7 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 +<<<<<<< HEAD wheel-py3_11-cuda13_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -2037,13 +3465,27 @@ jobs: - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 360 +======= + wheel-py3_11-cuda12_9-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - wheel-py3_11-cuda12_9-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" @@ -2063,7 +3505,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -2115,7 +3561,11 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: +<<<<<<< HEAD name: wheel-py3_11-cuda13_0 +======= + name: wheel-py3_11-cuda12_9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -2138,22 +3588,38 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 +<<<<<<< HEAD wheel-py3_11-cuda13_0-upload: # Uploading +======= + wheel-py3_11-cuda12_9-upload: # Uploading +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read +<<<<<<< HEAD needs: wheel-py3_11-cuda13_0-test +======= + needs: wheel-py3_11-cuda12_9-test +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.11" build_name: wheel-py3_11-cuda13_0 +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + DESIRED_PYTHON: "3.11" + build_name: wheel-py3_11-cuda12_9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -2161,7 +3627,11 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel @@ -2171,7 +3641,11 @@ jobs: GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" +<<<<<<< HEAD PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.2.1 | intel-cmplr-lib-ur==2025.2.1 | intel-cmplr-lic-rt==2025.2.1 | intel-sycl-rt==2025.2.1 | oneccl-devel==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.2.0 | onemkl-sycl-dft==2025.2.0 | onemkl-sycl-lapack==2025.2.0 | onemkl-sycl-rng==2025.2.0 | onemkl-sycl-sparse==2025.2.0 | dpcpp-cpp-rt==2025.2.1 | intel-opencl-rt==2025.2.1 | mkl==2025.2.0 | intel-openmp==2025.2.1 | tbb==2022.2.0 | tcmlib==1.4.0 | umf==0.11.0 | intel-pti==0.13.1 +======= + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.3 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -2197,7 +3671,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -2272,7 +3750,11 @@ jobs: - wheel-py3_11-xpu-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel @@ -2298,7 +3780,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -2395,7 +3881,11 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel @@ -2430,7 +3920,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -2505,7 +3999,11 @@ jobs: - wheel-py3_12-cpu-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel @@ -2531,7 +4029,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -2628,14 +4130,22 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" +======= + GPU_ARCH_VERSION: 12.6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" @@ -2664,7 +4174,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -2739,14 +4253,22 @@ jobs: - wheel-py3_12-cuda12_6-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" +======= + GPU_ARCH_VERSION: 12.6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" @@ -2766,7 +4288,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -2853,7 +4379,11 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" +======= + GPU_ARCH_VERSION: 12.6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.12" build_name: wheel-py3_12-cuda12_6 @@ -2864,14 +4394,22 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" +======= + GPU_ARCH_VERSION: 12.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" @@ -2900,7 +4438,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -2975,14 +4517,22 @@ jobs: - wheel-py3_12-cuda12_8-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" +======= + GPU_ARCH_VERSION: 12.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" @@ -3002,7 +4552,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -3089,25 +4643,42 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" +======= + GPU_ARCH_VERSION: 12.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.12" build_name: wheel-py3_12-cuda12_8 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml +<<<<<<< HEAD wheel-py3_12-cuda13_0-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 360 +======= + wheel-py3_12-cuda12_9-build: + if: ${{ github.repository_owner == 'pytorch' }} + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" @@ -3136,7 +4707,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -3187,7 +4762,11 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: +<<<<<<< HEAD name: wheel-py3_12-cuda13_0 +======= + name: wheel-py3_12-cuda12_9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -3205,6 +4784,7 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 +<<<<<<< HEAD wheel-py3_12-cuda13_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -3212,13 +4792,27 @@ jobs: - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 360 +======= + wheel-py3_12-cuda12_9-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - wheel-py3_12-cuda12_9-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" @@ -3238,7 +4832,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -3290,7 +4888,11 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: +<<<<<<< HEAD name: wheel-py3_12-cuda13_0 +======= + name: wheel-py3_12-cuda12_9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -3313,22 +4915,38 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 +<<<<<<< HEAD wheel-py3_12-cuda13_0-upload: # Uploading +======= + wheel-py3_12-cuda12_9-upload: # Uploading +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read +<<<<<<< HEAD needs: wheel-py3_12-cuda13_0-test +======= + needs: wheel-py3_12-cuda12_9-test +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.12" build_name: wheel-py3_12-cuda13_0 +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + DESIRED_PYTHON: "3.12" + build_name: wheel-py3_12-cuda12_9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -3336,7 +4954,11 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel @@ -3346,7 +4968,11 @@ jobs: GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" +<<<<<<< HEAD PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.2.1 | intel-cmplr-lib-ur==2025.2.1 | intel-cmplr-lic-rt==2025.2.1 | intel-sycl-rt==2025.2.1 | oneccl-devel==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.2.0 | onemkl-sycl-dft==2025.2.0 | onemkl-sycl-lapack==2025.2.0 | onemkl-sycl-rng==2025.2.0 | onemkl-sycl-sparse==2025.2.0 | dpcpp-cpp-rt==2025.2.1 | intel-opencl-rt==2025.2.1 | mkl==2025.2.0 | intel-openmp==2025.2.1 | tbb==2022.2.0 | tcmlib==1.4.0 | umf==0.11.0 | intel-pti==0.13.1 +======= + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.3 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -3372,7 +4998,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -3447,7 +5077,11 @@ jobs: - wheel-py3_12-xpu-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel @@ -3473,7 +5107,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -3570,7 +5208,11 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel @@ -3605,7 +5247,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -3680,7 +5326,11 @@ jobs: - wheel-py3_13-cpu-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel @@ -3706,7 +5356,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -3803,14 +5457,22 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" +======= + GPU_ARCH_VERSION: 12.6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13" @@ -3839,7 +5501,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -3914,14 +5580,22 @@ jobs: - wheel-py3_13-cuda12_6-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" +======= + GPU_ARCH_VERSION: 12.6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13" @@ -3941,7 +5615,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -4028,7 +5706,11 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" +======= + GPU_ARCH_VERSION: 12.6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.13" build_name: wheel-py3_13-cuda12_6 @@ -4039,14 +5721,22 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" +======= + GPU_ARCH_VERSION: 12.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13" @@ -4075,7 +5765,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -4150,14 +5844,22 @@ jobs: - wheel-py3_13-cuda12_8-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" +======= + GPU_ARCH_VERSION: 12.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13" @@ -4177,7 +5879,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -4264,25 +5970,42 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" +======= + GPU_ARCH_VERSION: 12.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.13" build_name: wheel-py3_13-cuda12_8 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml +<<<<<<< HEAD wheel-py3_13-cuda13_0-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 360 +======= + wheel-py3_13-cuda12_9-build: + if: ${{ github.repository_owner == 'pytorch' }} + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13" @@ -4311,7 +6034,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -4362,7 +6089,11 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: +<<<<<<< HEAD name: wheel-py3_13-cuda13_0 +======= + name: wheel-py3_13-cuda12_9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -4380,6 +6111,7 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 +<<<<<<< HEAD wheel-py3_13-cuda13_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -4387,13 +6119,27 @@ jobs: - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 360 +======= + wheel-py3_13-cuda12_9-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - wheel-py3_13-cuda12_9-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13" @@ -4413,7 +6159,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -4465,7 +6215,11 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: +<<<<<<< HEAD name: wheel-py3_13-cuda13_0 +======= + name: wheel-py3_13-cuda12_9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -4488,22 +6242,38 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 +<<<<<<< HEAD wheel-py3_13-cuda13_0-upload: # Uploading +======= + wheel-py3_13-cuda12_9-upload: # Uploading +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read +<<<<<<< HEAD needs: wheel-py3_13-cuda13_0-test +======= + needs: wheel-py3_13-cuda12_9-test +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.13" build_name: wheel-py3_13-cuda13_0 +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + DESIRED_PYTHON: "3.13" + build_name: wheel-py3_13-cuda12_9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -4511,7 +6281,11 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel @@ -4521,7 +6295,11 @@ jobs: GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13" +<<<<<<< HEAD PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.2.1 | intel-cmplr-lib-ur==2025.2.1 | intel-cmplr-lic-rt==2025.2.1 | intel-sycl-rt==2025.2.1 | oneccl-devel==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.2.0 | onemkl-sycl-dft==2025.2.0 | onemkl-sycl-lapack==2025.2.0 | onemkl-sycl-rng==2025.2.0 | onemkl-sycl-sparse==2025.2.0 | dpcpp-cpp-rt==2025.2.1 | intel-opencl-rt==2025.2.1 | mkl==2025.2.0 | intel-openmp==2025.2.1 | tbb==2022.2.0 | tcmlib==1.4.0 | umf==0.11.0 | intel-pti==0.13.1 +======= + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.3 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -4547,7 +6325,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -4622,7 +6404,11 @@ jobs: - wheel-py3_13-xpu-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel @@ -4648,7 +6434,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -4745,7 +6535,11 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel @@ -4780,7 +6574,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -4855,7 +6653,11 @@ jobs: - wheel-py3_13t-cpu-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel @@ -4881,7 +6683,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -4978,14 +6784,22 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" +======= + GPU_ARCH_VERSION: 12.6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13t" @@ -5014,7 +6828,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -5089,14 +6907,22 @@ jobs: - wheel-py3_13t-cuda12_6-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" +======= + GPU_ARCH_VERSION: 12.6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13t" @@ -5116,7 +6942,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -5203,7 +7033,11 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu126 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.6" +======= + GPU_ARCH_VERSION: 12.6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.13t" build_name: wheel-py3_13t-cuda12_6 @@ -5214,14 +7048,22 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" +======= + GPU_ARCH_VERSION: 12.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13t" @@ -5250,7 +7092,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -5325,14 +7171,22 @@ jobs: - wheel-py3_13t-cuda12_8-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" +======= + GPU_ARCH_VERSION: 12.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13t" @@ -5352,7 +7206,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -5439,25 +7297,42 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu128 +<<<<<<< HEAD GPU_ARCH_VERSION: "12.8" +======= + GPU_ARCH_VERSION: 12.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.13t" build_name: wheel-py3_13t-cuda12_8 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml +<<<<<<< HEAD wheel-py3_13t-cuda13_0-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 360 +======= + wheel-py3_13t-cuda12_9-build: + if: ${{ github.repository_owner == 'pytorch' }} + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13t" @@ -5486,7 +7361,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -5537,7 +7416,11 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: +<<<<<<< HEAD name: wheel-py3_13t-cuda13_0 +======= + name: wheel-py3_13t-cuda12_9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -5555,6 +7438,7 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 +<<<<<<< HEAD wheel-py3_13t-cuda13_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: @@ -5562,13 +7446,27 @@ jobs: - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 360 +======= + wheel-py3_13t-cuda12_9-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - wheel-py3_13t-cuda12_9-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13t" @@ -5588,7 +7486,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -5640,7 +7542,11 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: +<<<<<<< HEAD name: wheel-py3_13t-cuda13_0 +======= + name: wheel-py3_13t-cuda12_9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -5663,22 +7569,38 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 +<<<<<<< HEAD wheel-py3_13t-cuda13_0-upload: # Uploading +======= + wheel-py3_13t-cuda12_9-upload: # Uploading +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read +<<<<<<< HEAD needs: wheel-py3_13t-cuda13_0-test +======= + needs: wheel-py3_13t-cuda12_9-test +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION +<<<<<<< HEAD DESIRED_CUDA: cu130 GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.13t" build_name: wheel-py3_13t-cuda13_0 +======= + DESIRED_CUDA: cu129 + GPU_ARCH_VERSION: 12.9 + GPU_ARCH_TYPE: cuda + DESIRED_PYTHON: "3.13t" + build_name: wheel-py3_13t-cuda12_9 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -5686,7 +7608,11 @@ jobs: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel @@ -5696,7 +7622,11 @@ jobs: GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13t" +<<<<<<< HEAD PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.2.1 | intel-cmplr-lib-ur==2025.2.1 | intel-cmplr-lic-rt==2025.2.1 | intel-sycl-rt==2025.2.1 | oneccl-devel==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.16.1; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.2.0 | onemkl-sycl-dft==2025.2.0 | onemkl-sycl-lapack==2025.2.0 | onemkl-sycl-rng==2025.2.0 | onemkl-sycl-sparse==2025.2.0 | dpcpp-cpp-rt==2025.2.1 | intel-opencl-rt==2025.2.1 | mkl==2025.2.0 | intel-openmp==2025.2.1 | tbb==2022.2.0 | tcmlib==1.4.0 | umf==0.11.0 | intel-pti==0.13.1 +======= + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.2; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.3 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -5722,7 +7652,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -5797,7 +7731,11 @@ jobs: - wheel-py3_13t-xpu-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" +<<<<<<< HEAD timeout-minutes: 360 +======= + timeout-minutes: 300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch PACKAGE_TYPE: wheel @@ -5823,7 +7761,11 @@ jobs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -5916,6 +7858,7 @@ jobs: secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml +<<<<<<< HEAD wheel-py3_14-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type @@ -8266,3 +10209,5 @@ jobs: secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.github/workflows/h100-distributed.yml b/.github/workflows/h100-distributed.yml index 8996add88383b..0281289b40b75 100644 --- a/.github/workflows/h100-distributed.yml +++ b/.github/workflows/h100-distributed.yml @@ -8,23 +8,33 @@ on: push: tags: - ciflow/h100-distributed/* +<<<<<<< HEAD schedule: - cron: 46 8 * * * # about 1:46am PDT +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} cancel-in-progress: true +<<<<<<< HEAD permissions: id-token: write contents: read +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) jobs: get-label-type: if: github.repository_owner == 'pytorch' name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} diff --git a/.github/workflows/inductor-micro-benchmark-x86.yml b/.github/workflows/inductor-micro-benchmark-x86.yml index c6cc075e6b270..ce2b5f9bdec16 100644 --- a/.github/workflows/inductor-micro-benchmark-x86.yml +++ b/.github/workflows/inductor-micro-benchmark-x86.yml @@ -13,6 +13,7 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} cancel-in-progress: true +<<<<<<< HEAD permissions: id-token: write contents: read @@ -25,6 +26,18 @@ jobs: with: build-environment: linux-jammy-py3.9-gcc11 docker-image-name: ci-image:pytorch-linux-jammy-py3-gcc11-inductor-benchmarks +======= +permissions: read-all + +jobs: + linux-jammy-cpu-py3_9-gcc11-inductor-build: + if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} + name: linux-jammy-cpu-py3.9-gcc11-inductor + uses: ./.github/workflows/_linux-build.yml + with: + build-environment: linux-jammy-py3.9-gcc11 + docker-image-name: ci-image:pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Use metal host for benchmark jobs test-matrix: | { include: [ @@ -32,6 +45,7 @@ jobs: ]} secrets: inherit +<<<<<<< HEAD inductor-micro-benchmark-test: name: inductor-micro-benchmark-test uses: ./.github/workflows/_linux-test.yml @@ -40,5 +54,15 @@ jobs: build-environment: linux-jammy-py3.9-gcc11 docker-image: ${{ needs.inductor-build.outputs.docker-image }} test-matrix: ${{ needs.inductor-build.outputs.test-matrix }} +======= + linux-jammy-cpu-py3_9-gcc11-inductor-micro-benchmark-test: + name: linux-jammy-cpu-py3.9-gcc11-inductor + uses: ./.github/workflows/_linux-test.yml + needs: linux-jammy-cpu-py3_9-gcc11-inductor-build + with: + build-environment: linux-jammy-py3.9-gcc11 + docker-image: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.test-matrix }} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) timeout-minutes: 720 secrets: inherit diff --git a/.github/workflows/inductor-micro-benchmark.yml b/.github/workflows/inductor-micro-benchmark.yml index 842094e0eb484..397a41fce1fef 100644 --- a/.github/workflows/inductor-micro-benchmark.yml +++ b/.github/workflows/inductor-micro-benchmark.yml @@ -13,14 +13,22 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} cancel-in-progress: true +<<<<<<< HEAD permissions: id-token: write contents: read +======= +permissions: read-all +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) jobs: get-default-label-prefix: name: get-default-label-prefix +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} with: triggering_actor: ${{ github.triggering_actor }} diff --git a/.github/workflows/inductor-nightly.yml b/.github/workflows/inductor-nightly.yml index 7502381de93d0..f1a0682c55735 100644 --- a/.github/workflows/inductor-nightly.yml +++ b/.github/workflows/inductor-nightly.yml @@ -16,14 +16,22 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true +<<<<<<< HEAD permissions: id-token: write contents: read +======= +permissions: read-all +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) jobs: get-default-label-prefix: name: get-default-label-prefix +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} with: triggering_actor: ${{ github.triggering_actor }} @@ -32,6 +40,7 @@ jobs: curr_ref_type: ${{ github.ref_type }} opt_out_experiments: lf +<<<<<<< HEAD nightly-dynamo-benchmarks-build: name: nightly-dynamo-benchmarks-build uses: ./.github/workflows/_linux-build.yml @@ -39,6 +48,15 @@ jobs: with: build-environment: linux-jammy-py3.10-gcc11-build docker-image-name: ci-image:pytorch-linux-jammy-py3-gcc11-inductor-benchmarks +======= + linux-jammy-cpu-py3_9-gcc11-nightly-dynamo-benchmarks-build: + name: linux-jammy-cpu-py3.9-gcc11-nightly-dynamo-benchmarks + uses: ./.github/workflows/_linux-build.yml + needs: get-default-label-prefix + with: + build-environment: linux-jammy-py3.9-gcc11-build + docker-image-name: ci-image:pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) runner_prefix: "${{ needs.get-default-label-prefix.outputs.label-type }}" test-matrix: | { include: [ @@ -48,6 +66,7 @@ jobs: { config: "dynamic_cpu_max_autotune_inductor_amp_freezing_torchbench", shard: 1, num_shards: 2, runner: "linux.8xlarge.amx" }, { config: "dynamic_cpu_max_autotune_inductor_amp_freezing_torchbench", shard: 2, num_shards: 2, runner: "linux.8xlarge.amx" }, ]} +<<<<<<< HEAD build-additional-packages: "vision audio torchao" secrets: inherit @@ -59,5 +78,17 @@ jobs: build-environment: linux-jammy-py3.10-gcc11-build docker-image: ${{ needs.nightly-dynamo-benchmarks-build.outputs.docker-image }} test-matrix: ${{ needs.nightly-dynamo-benchmarks-build.outputs.test-matrix }} +======= + secrets: inherit + + linux-jammy-cpu-py3_9-gcc11-nightly-dynamo-benchmarks-test: + name: linux-jammy-cpu-py3.9-gcc11-nightly-dynamo-benchmarks + uses: ./.github/workflows/_linux-test.yml + needs: linux-jammy-cpu-py3_9-gcc11-nightly-dynamo-benchmarks-build + with: + build-environment: linux-jammy-py3.9-gcc11-build + docker-image: ${{ needs.linux-jammy-cpu-py3_9-gcc11-nightly-dynamo-benchmarks-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cpu-py3_9-gcc11-nightly-dynamo-benchmarks-build.outputs.test-matrix }} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) timeout-minutes: 720 secrets: inherit diff --git a/.github/workflows/inductor-perf-compare.yml b/.github/workflows/inductor-perf-compare.yml index 35217f72bf1ae..16cd9d2dfa6d6 100644 --- a/.github/workflows/inductor-perf-compare.yml +++ b/.github/workflows/inductor-perf-compare.yml @@ -10,15 +10,23 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true +<<<<<<< HEAD permissions: id-token: write contents: read +======= +permissions: read-all +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) jobs: get-default-label-prefix: if: github.repository_owner == 'pytorch' name: get-default-label-prefix +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -43,7 +51,10 @@ jobs: { config: "inductor_timm_perf_compare", shard: 2, num_shards: 2, runner: "linux.aws.a100" }, { config: "inductor_torchbench_perf_compare", shard: 1, num_shards: 1, runner: "linux.aws.a100" }, ]} +<<<<<<< HEAD build-additional-packages: "vision audio fbgemm torchao" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: inherit test: diff --git a/.github/workflows/inductor-perf-test-nightly-aarch64.yml b/.github/workflows/inductor-perf-test-nightly-aarch64.yml index 9e3165fe11ea9..b96fe8032c88b 100644 --- a/.github/workflows/inductor-perf-test-nightly-aarch64.yml +++ b/.github/workflows/inductor-perf-test-nightly-aarch64.yml @@ -48,14 +48,22 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} cancel-in-progress: true +<<<<<<< HEAD permissions: id-token: write contents: read +======= +permissions: read-all +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) jobs: get-label-type: name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} with: triggering_actor: ${{ github.triggering_actor }} @@ -116,7 +124,10 @@ jobs: { config: "inductor_torchbench_perf_cpu_aarch64", shard: 15, num_shards: 15, runner: "linux.arm64.m7g.metal" }, ]} selected-test-configs: ${{ inputs.benchmark_configs }} +<<<<<<< HEAD build-additional-packages: "vision audio torchao" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: inherit diff --git a/.github/workflows/inductor-perf-test-nightly-h100.yml b/.github/workflows/inductor-perf-test-nightly-h100.yml index 7e323fa5a92ed..336bebe553c24 100644 --- a/.github/workflows/inductor-perf-test-nightly-h100.yml +++ b/.github/workflows/inductor-perf-test-nightly-h100.yml @@ -2,7 +2,11 @@ name: inductor-perf-nightly-h100 on: schedule: +<<<<<<< HEAD - cron: 15 0,12 * * 1-6 +======= + - cron: 15 0,4,8,12,16,20 * * 1-6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - cron: 0 7 * * 0 # NB: GitHub has an upper limit of 10 inputs here, so before we can sort it # out, let try to run torchao cudagraphs_low_precision as part of cudagraphs @@ -58,6 +62,7 @@ on: required: false type: string default: inductor_huggingface_perf_cuda_h100,inductor_timm_perf_cuda_h100,inductor_torchbench_perf_cuda_h100 +<<<<<<< HEAD pull_request: # Changing these files guarantees that this workflow needs to be run paths: @@ -71,11 +76,23 @@ concurrency: permissions: id-token: write contents: read +======= + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} + cancel-in-progress: true + +permissions: read-all +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) jobs: get-label-type: name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} with: triggering_actor: ${{ github.triggering_actor }} @@ -84,17 +101,26 @@ jobs: curr_ref_type: ${{ github.ref_type }} opt_out_experiments: lf +<<<<<<< HEAD build: name: build +======= + # NB: Keep this in sync with trunk.yml + build: + name: cuda12.8-py3.10-gcc9-sm90 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +<<<<<<< HEAD # Use a bigger runner here because CUDA_ARCH 9.0 is only built for H100 # or newer GPUs, so it doesn't benefit much from existing compiler cache # from trunk. Also use a memory-intensive runner here because memory is # usually the bottleneck runner: linux.12xlarge.memory +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm90 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '9.0' @@ -123,6 +149,7 @@ jobs: { config: "inductor_torchbench_perf_cuda_h100", shard: 9, num_shards: 9, runner: "linux.aws.h100" }, ]} selected-test-configs: ${{ inputs.benchmark_configs }} +<<<<<<< HEAD build-additional-packages: "vision audio fbgemm torchao" secrets: inherit @@ -131,6 +158,15 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: build if: github.event.schedule == '15 0,12 * * 1-6' +======= + secrets: inherit + + test-periodically: + name: cuda12.8-py3.10-gcc9-sm90 + uses: ./.github/workflows/_linux-test.yml + needs: build + if: github.event.schedule == '15 0,4,8,12,16,20 * * 1-6' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm90 dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true @@ -144,7 +180,11 @@ jobs: secrets: inherit test-weekly: +<<<<<<< HEAD name: test-weekly +======= + name: cuda12.8-py3.10-gcc9-sm90 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uses: ./.github/workflows/_linux-test.yml needs: build if: github.event.schedule == '0 7 * * 0' @@ -161,6 +201,7 @@ jobs: secrets: inherit test: +<<<<<<< HEAD name: test uses: ./.github/workflows/_linux-test.yml needs: build @@ -170,6 +211,15 @@ jobs: with: build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm90 dashboard-tag: training-${{ inputs.training || 'true' }}-inference-${{ inputs.inference || 'true' }}-default-${{ inputs.default || 'true' }}-dynamic-${{ inputs.dynamic || 'true' }}-cudagraphs-${{ inputs.cudagraphs || 'true' }}-cppwrapper-${{ inputs.cppwrapper || 'false' }}-aotinductor-${{ inputs.aotinductor || 'false' }}-maxautotune-${{ inputs.maxautotune || 'false' }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs || 'false' }}-cudagraphs_low_precision-${{ inputs.cudagraphs || 'false' }} +======= + name: cuda12.8-py3.10-gcc9-sm90 + uses: ./.github/workflows/_linux-test.yml + needs: build + if: github.event_name == 'workflow_dispatch' + with: + build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm90 + dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} timeout-minutes: 720 diff --git a/.github/workflows/inductor-perf-test-nightly-macos.yml b/.github/workflows/inductor-perf-test-nightly-macos.yml index c3b9a42299247..b1615aee436e2 100644 --- a/.github/workflows/inductor-perf-test-nightly-macos.yml +++ b/.github/workflows/inductor-perf-test-nightly-macos.yml @@ -48,9 +48,12 @@ jobs: { config: "perf_smoketest", shard: 1, num_shards: 3, runner: "macos-m2-15" }, { config: "perf_smoketest", shard: 2, num_shards: 3, runner: "macos-m2-15" }, { config: "perf_smoketest", shard: 3, num_shards: 3, runner: "macos-m2-15" }, +<<<<<<< HEAD { config: "aot_inductor_perf_smoketest", shard: 1, num_shards: 3, runner: "macos-m2-15" }, { config: "aot_inductor_perf_smoketest", shard: 2, num_shards: 3, runner: "macos-m2-15" }, { config: "aot_inductor_perf_smoketest", shard: 3, num_shards: 3, runner: "macos-m2-15" }, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ]} secrets: inherit diff --git a/.github/workflows/inductor-perf-test-nightly-rocm.yml b/.github/workflows/inductor-perf-test-nightly-rocm.yml index dddf68091fdb5..382344facb9c6 100644 --- a/.github/workflows/inductor-perf-test-nightly-rocm.yml +++ b/.github/workflows/inductor-perf-test-nightly-rocm.yml @@ -5,7 +5,11 @@ on: tags: - ciflow/inductor-perf-test-nightly-rocm/* schedule: +<<<<<<< HEAD - cron: 0 7 * * 0,3 +======= + - cron: 0 7 * * 0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # NB: GitHub has an upper limit of 10 inputs here, so before we can sort it # out, let try to run torchao cudagraphs_low_precision as part of cudagraphs workflow_dispatch: @@ -70,7 +74,11 @@ permissions: read-all jobs: get-label-type: name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} with: triggering_actor: ${{ github.triggering_actor }} @@ -85,6 +93,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: build-environment: linux-jammy-rocm-py3_10 +<<<<<<< HEAD docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks test-matrix: | { include: [ @@ -105,6 +114,23 @@ jobs: { config: "inductor_torchbench_perf_rocm", shard: 6, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" }, { config: "inductor_torchbench_perf_rocm", shard: 7, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" }, { config: "inductor_torchbench_perf_rocm", shard: 8, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" }, +======= + docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 + test-matrix: | + { include: [ + { config: "inductor_huggingface_perf_rocm", shard: 1, num_shards: 3, runner: "linux.rocm.gpu.mi300.2" }, + { config: "inductor_huggingface_perf_rocm", shard: 2, num_shards: 3, runner: "linux.rocm.gpu.mi300.2" }, + { config: "inductor_huggingface_perf_rocm", shard: 3, num_shards: 3, runner: "linux.rocm.gpu.mi300.2" }, + { config: "inductor_timm_perf_rocm", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.mi300.2" }, + { config: "inductor_timm_perf_rocm", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.mi300.2" }, + { config: "inductor_timm_perf_rocm", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.mi300.2" }, + { config: "inductor_timm_perf_rocm", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.mi300.2" }, + { config: "inductor_timm_perf_rocm", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.mi300.2" }, + { config: "inductor_torchbench_perf_rocm", shard: 1, num_shards: 4, runner: "linux.rocm.gpu.mi300.2" }, + { config: "inductor_torchbench_perf_rocm", shard: 2, num_shards: 4, runner: "linux.rocm.gpu.mi300.2" }, + { config: "inductor_torchbench_perf_rocm", shard: 3, num_shards: 4, runner: "linux.rocm.gpu.mi300.2" }, + { config: "inductor_torchbench_perf_rocm", shard: 4, num_shards: 4, runner: "linux.rocm.gpu.mi300.2" }, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ]} secrets: inherit diff --git a/.github/workflows/inductor-perf-test-nightly-x86-zen.yml b/.github/workflows/inductor-perf-test-nightly-x86-zen.yml index 8057b10426768..544b78b03e117 100644 --- a/.github/workflows/inductor-perf-test-nightly-x86-zen.yml +++ b/.github/workflows/inductor-perf-test-nightly-x86-zen.yml @@ -47,20 +47,32 @@ on: description: The list of configs used the benchmark required: false type: string +<<<<<<< HEAD default: inductor_huggingface_perf_cpu_x86_zen,inductor_timm_perf_cpu_x86_zen,inductor_torchbench_perf_cpu_x86_zen +======= + default: inductor_huggingface_perf_zen_cpu_x86,inductor_timm_perf_zen_cpu_x86,inductor_torchbench_perf_zen_cpu_x86 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} cancel-in-progress: true +<<<<<<< HEAD permissions: id-token: write contents: read +======= +permissions: read-all +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) jobs: get-label-type: name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} with: triggering_actor: ${{ github.triggering_actor }} @@ -69,12 +81,18 @@ jobs: curr_ref_type: ${{ github.ref_type }} opt_out_experiments: lf +<<<<<<< HEAD inductor-build: name: inductor-build +======= + linux-jammy-zen-cpu-py3_9-gcc11-inductor-build: + name: linux-jammy-zen-cpu-py3.9-gcc11-inductor +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +<<<<<<< HEAD build-environment: linux-jammy-py3.10-gcc11-build docker-image-name: ci-image:pytorch-linux-jammy-py3-gcc11-inductor-benchmarks test-matrix: | @@ -91,10 +109,29 @@ jobs: { config: "inductor_torchbench_perf_cpu_x86_zen", shard: 2, num_shards: 4, runner: "linux.24xlarge.amd" }, { config: "inductor_torchbench_perf_cpu_x86_zen", shard: 3, num_shards: 4, runner: "linux.24xlarge.amd" }, { config: "inductor_torchbench_perf_cpu_x86_zen", shard: 4, num_shards: 4, runner: "linux.24xlarge.amd" }, +======= + build-environment: linux-jammy-py3.9-gcc11-build + docker-image-name: ci-image:pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks + test-matrix: | + { include: [ + { config: "inductor_huggingface_perf_zen_cpu_x86", shard: 1, num_shards: 3, runner: "linux.24xlarge.amd" }, + { config: "inductor_huggingface_perf_zen_cpu_x86", shard: 2, num_shards: 3, runner: "linux.24xlarge.amd" }, + { config: "inductor_huggingface_perf_zen_cpu_x86", shard: 3, num_shards: 3, runner: "linux.24xlarge.amd" }, + { config: "inductor_timm_perf_zen_cpu_x86", shard: 1, num_shards: 5, runner: "linux.24xlarge.amd" }, + { config: "inductor_timm_perf_zen_cpu_x86", shard: 2, num_shards: 5, runner: "linux.24xlarge.amd" }, + { config: "inductor_timm_perf_zen_cpu_x86", shard: 3, num_shards: 5, runner: "linux.24xlarge.amd" }, + { config: "inductor_timm_perf_zen_cpu_x86", shard: 4, num_shards: 5, runner: "linux.24xlarge.amd" }, + { config: "inductor_timm_perf_zen_cpu_x86", shard: 5, num_shards: 5, runner: "linux.24xlarge.amd" }, + { config: "inductor_torchbench_perf_zen_cpu_x86", shard: 1, num_shards: 4, runner: "linux.24xlarge.amd" }, + { config: "inductor_torchbench_perf_zen_cpu_x86", shard: 2, num_shards: 4, runner: "linux.24xlarge.amd" }, + { config: "inductor_torchbench_perf_zen_cpu_x86", shard: 3, num_shards: 4, runner: "linux.24xlarge.amd" }, + { config: "inductor_torchbench_perf_zen_cpu_x86", shard: 4, num_shards: 4, runner: "linux.24xlarge.amd" }, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ]} selected-test-configs: ${{ inputs.benchmark_configs }} secrets: inherit +<<<<<<< HEAD inductor-test-nightly: name: inductor-test-nightly uses: ./.github/workflows/_linux-test.yml @@ -105,6 +142,18 @@ jobs: dashboard-tag: training-false-inference-true-default-true-dynamic-true-cppwrapper-true-aotinductor-true docker-image: ${{ needs.inductor-build.outputs.docker-image }} test-matrix: ${{ needs.inductor-build.outputs.test-matrix }} +======= + linux-jammy-zen-cpu-py3_9-gcc11-inductor-test-nightly: + name: linux-jammy-zen-cpu-py3.9-gcc11-inductor + uses: ./.github/workflows/_linux-test.yml + needs: linux-jammy-zen-cpu-py3_9-gcc11-inductor-build + if: github.event.schedule == '0 7 * * *' + with: + build-environment: linux-jammy-py3.9-gcc11-build + dashboard-tag: training-false-inference-true-default-true-dynamic-true-cppwrapper-true-aotinductor-true + docker-image: ${{ needs.linux-jammy-zen-cpu-py3_9-gcc11-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-zen-cpu-py3_9-gcc11-inductor-build.outputs.test-matrix }} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) timeout-minutes: 720 # disable monitor in perf tests disable-monitor: false @@ -112,6 +161,7 @@ jobs: monitor-data-collect-interval: 4 secrets: inherit +<<<<<<< HEAD inductor-test: name: inductor-test uses: ./.github/workflows/_linux-test.yml @@ -122,6 +172,19 @@ jobs: dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }} docker-image: ${{ needs.inductor-build.outputs.docker-image }} test-matrix: ${{ needs.inductor-build.outputs.test-matrix }} +======= + + linux-jammy-zen-cpu-py3_9-gcc11-inductor-test: + name: linux-jammy-zen-cpu-py3.9-gcc11-inductor + uses: ./.github/workflows/_linux-test.yml + needs: linux-jammy-zen-cpu-py3_9-gcc11-inductor-build + if: github.event_name == 'workflow_dispatch' + with: + build-environment: linux-jammy-py3.9-gcc11-build + dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }} + docker-image: ${{ needs.linux-jammy-zen-cpu-py3_9-gcc11-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-zen-cpu-py3_9-gcc11-inductor-build.outputs.test-matrix }} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) timeout-minutes: 720 # disable monitor in perf tests disable-monitor: false diff --git a/.github/workflows/inductor-perf-test-nightly-x86.yml b/.github/workflows/inductor-perf-test-nightly-x86.yml index b68e9ad95ca40..dbc82851fc5e2 100644 --- a/.github/workflows/inductor-perf-test-nightly-x86.yml +++ b/.github/workflows/inductor-perf-test-nightly-x86.yml @@ -1,9 +1,12 @@ name: inductor-perf-nightly-x86 on: +<<<<<<< HEAD pull_request: paths: - .github/workflows/inductor-perf-test-nightly-x86.yml +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) schedule: # - cron: 0 7 * * 1-6 # - cron: 0 7 * * 0 @@ -43,11 +46,14 @@ on: required: false type: boolean default: false +<<<<<<< HEAD freezing: description: Run freezing? required: false type: boolean default: true +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) benchmark_configs: description: The list of configs used the benchmark required: false @@ -55,17 +61,28 @@ on: default: inductor_huggingface_perf_cpu_x86,inductor_timm_perf_cpu_x86,inductor_torchbench_perf_cpu_x86 concurrency: +<<<<<<< HEAD group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'schedule' }} cancel-in-progress: true permissions: id-token: write contents: read +======= + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} + cancel-in-progress: true + +permissions: read-all +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) jobs: get-label-type: name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} with: triggering_actor: ${{ github.triggering_actor }} @@ -74,14 +91,24 @@ jobs: curr_ref_type: ${{ github.ref_type }} opt_out_experiments: lf +<<<<<<< HEAD inductor-build: name: inductor-build +======= + linux-jammy-cpu-py3_9-gcc11-inductor-build: + name: linux-jammy-cpu-py3.9-gcc11-inductor +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +<<<<<<< HEAD build-environment: linux-jammy-py3.10-gcc11-build docker-image-name: ci-image:pytorch-linux-jammy-py3-gcc11-inductor-benchmarks +======= + build-environment: linux-jammy-py3.9-gcc11-build + docker-image-name: ci-image:pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) test-matrix: | { include: [ { config: "inductor_huggingface_perf_cpu_x86", shard: 1, num_shards: 3, runner: "linux.24xl.spr-metal" }, @@ -98,6 +125,7 @@ jobs: { config: "inductor_torchbench_perf_cpu_x86", shard: 4, num_shards: 4, runner: "linux.24xl.spr-metal" }, ]} selected-test-configs: ${{ inputs.benchmark_configs }} +<<<<<<< HEAD build-additional-packages: "vision audio torchao" secrets: inherit @@ -111,6 +139,21 @@ jobs: dashboard-tag: training-false-inference-true-default-true-dynamic-true-cppwrapper-true-aotinductor-true-freezing-true docker-image: ${{ needs.inductor-build.outputs.docker-image }} test-matrix: ${{ needs.inductor-build.outputs.test-matrix }} +======= + secrets: inherit + + + linux-jammy-cpu-py3_9-gcc11-inductor-test-nightly: + name: linux-jammy-cpu-py3.9-gcc11-inductor + uses: ./.github/workflows/_linux-test.yml + needs: linux-jammy-cpu-py3_9-gcc11-inductor-build + if: github.event.schedule == '0 7 * * *' + with: + build-environment: linux-jammy-py3.9-gcc11-build + dashboard-tag: training-false-inference-true-default-true-dynamic-true-cppwrapper-true-aotinductor-true + docker-image: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.test-matrix }} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) timeout-minutes: 720 # disable monitor in perf tests disable-monitor: false @@ -118,6 +161,7 @@ jobs: monitor-data-collect-interval: 4 secrets: inherit +<<<<<<< HEAD inductor-test: name: inductor-test uses: ./.github/workflows/_linux-test.yml @@ -128,6 +172,19 @@ jobs: dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-freezing-${{ inputs.freezing }} docker-image: ${{ needs.inductor-build.outputs.docker-image }} test-matrix: ${{ needs.inductor-build.outputs.test-matrix }} +======= + + linux-jammy-cpu-py3_9-gcc11-inductor-test: + name: linux-jammy-cpu-py3.9-gcc11-inductor + uses: ./.github/workflows/_linux-test.yml + needs: linux-jammy-cpu-py3_9-gcc11-inductor-build + if: github.event_name == 'workflow_dispatch' + with: + build-environment: linux-jammy-py3.9-gcc11-build + dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }} + docker-image: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.test-matrix }} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) timeout-minutes: 720 # disable monitor in perf tests disable-monitor: false diff --git a/.github/workflows/inductor-perf-test-nightly.yml b/.github/workflows/inductor-perf-test-nightly.yml index 7c573d4d25716..1bb52eaf381f4 100644 --- a/.github/workflows/inductor-perf-test-nightly.yml +++ b/.github/workflows/inductor-perf-test-nightly.yml @@ -63,14 +63,22 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} cancel-in-progress: true +<<<<<<< HEAD permissions: id-token: write contents: read +======= +permissions: read-all +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) jobs: get-label-type: name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} with: triggering_actor: ${{ github.triggering_actor }} @@ -79,14 +87,21 @@ jobs: curr_ref_type: ${{ github.ref_type }} opt_out_experiments: lf +<<<<<<< HEAD +======= + # NB: Keep this in sync with trunk.yml +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) build: name: cuda12.8-py3.10-gcc9-sm80 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +<<<<<<< HEAD # Every bit to make perf run faster helps runner: linux.12xlarge.memory +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' @@ -113,7 +128,10 @@ jobs: { config: "cachebench", shard: 2, num_shards: 2, runner: "linux.aws.a100" }, ]} selected-test-configs: ${{ inputs.benchmark_configs }} +<<<<<<< HEAD build-additional-packages: "vision audio fbgemm torchao" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: inherit test-nightly: diff --git a/.github/workflows/inductor-periodic.yml b/.github/workflows/inductor-periodic.yml index b17ebb84d5d38..a97f3e797329e 100644 --- a/.github/workflows/inductor-periodic.yml +++ b/.github/workflows/inductor-periodic.yml @@ -15,14 +15,22 @@ concurrency: cancel-in-progress: true +<<<<<<< HEAD permissions: id-token: write contents: read +======= +permissions: read-all +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) jobs: get-default-label-prefix: name: get-default-label-prefix +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} with: triggering_actor: ${{ github.triggering_actor }} @@ -31,8 +39,13 @@ jobs: curr_ref_type: ${{ github.ref_type }} opt_out_experiments: lf +<<<<<<< HEAD periodic-dynamo-benchmarks-build: name: periodic-dynamo-benchmarks-build +======= + linux-jammy-cuda12_8-py3_10-gcc9-periodic-dynamo-benchmarks-build: + name: cuda12.8-py3.10-gcc9-sm86-periodic-dynamo-benchmarks +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uses: ./.github/workflows/_linux-build.yml needs: get-default-label-prefix with: @@ -57,6 +70,7 @@ jobs: { config: "dynamic_aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "dynamic_aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "dynamic_aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, +<<<<<<< HEAD { config: "dynamic_inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "dynamic_inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "dynamic_inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, @@ -124,6 +138,64 @@ jobs: inductor-smoke-build: name: inductor-smoke-build +======= + ]} + secrets: inherit + + linux-jammy-cuda12_8-py3_10-gcc9-periodic-dynamo-benchmarks-test: + name: cuda12.8-py3.10-gcc9-sm86-periodic-dynamo-benchmarks + uses: ./.github/workflows/_linux-test.yml + needs: linux-jammy-cuda12_8-py3_10-gcc9-periodic-dynamo-benchmarks-build + with: + build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm86 + docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-periodic-dynamo-benchmarks-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-periodic-dynamo-benchmarks-build.outputs.test-matrix }} + secrets: inherit + + linux-jammy-rocm-py3_10-periodic-dynamo-benchmarks-build: + if: github.repository_owner == 'pytorch' + name: rocm-py3_10-periodic-dynamo-benchmarks + uses: ./.github/workflows/_linux-build.yml + with: + build-environment: linux-jammy-rocm-py3_10 + docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 + sync-tag: rocm-build + test-matrix: | + { include: [ + { config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" }, + { config: "dynamo_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" }, + { config: "dynamo_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.mi300.2" }, + { config: "dynamo_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" }, + { config: "dynamo_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" }, + { config: "aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" }, + { config: "aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" }, + { config: "aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.mi300.2" }, + { config: "aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" }, + { config: "aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" }, + { config: "dynamic_aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" }, + { config: "dynamic_aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" }, + { config: "dynamic_aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.mi300.2" }, + { config: "dynamic_aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" }, + { config: "dynamic_aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" }, + ]} + secrets: inherit + + linux-jammy-rocm-py3_10-periodic-dynamo-benchmarks-test: + permissions: + id-token: write + contents: read + name: rocm-py3_10-periodic-dynamo-benchmarks + uses: ./.github/workflows/_rocm-test.yml + needs: linux-jammy-rocm-py3_10-periodic-dynamo-benchmarks-build + with: + build-environment: linux-jammy-rocm-py3_10 + docker-image: ${{ needs.linux-jammy-rocm-py3_10-periodic-dynamo-benchmarks-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-rocm-py3_10-periodic-dynamo-benchmarks-build.outputs.test-matrix }} + secrets: inherit + + linux-jammy-cuda12_8-py3_10-gcc9-inductor-smoke-build: + name: cuda12.8-py3.10-gcc9-sm80 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uses: ./.github/workflows/_linux-build.yml needs: - get-default-label-prefix @@ -136,6 +208,7 @@ jobs: { include: [ { config: "inductor_torchbench_smoketest_perf", shard: 1, num_shards: 1, runner: "linux.aws.a100" }, ]} +<<<<<<< HEAD build-additional-packages: "vision audio fbgemm torchao" secrets: inherit @@ -156,6 +229,27 @@ jobs: with: build-environment: linux-jammy-py3.10-gcc11-build docker-image-name: ci-image:pytorch-linux-jammy-py3-gcc11-inductor-benchmarks +======= + secrets: inherit + + linux-jammy-cuda12_8-py3_10-gcc9-inductor-smoke-test: + name: cuda12.8-py3.10-gcc9-sm80 + uses: ./.github/workflows/_linux-test.yml + needs: linux-jammy-cuda12_8-py3_10-gcc9-inductor-smoke-build + with: + build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-smoke-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-smoke-build.outputs.test-matrix }} + secrets: inherit + + linux-jammy-cpu-py3_9-gcc11-periodic-dynamo-benchmarks-build: + name: linux-jammy-cpu-py3.9-gcc11-periodic-dynamo-benchmarks + uses: ./.github/workflows/_linux-build.yml + needs: get-default-label-prefix + with: + build-environment: linux-jammy-py3.9-gcc11-build + docker-image-name: ci-image:pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) runner_prefix: "${{ needs.get-default-label-prefix.outputs.label-type }}" test-matrix: | { include: [ @@ -170,6 +264,69 @@ jobs: { config: "cpu_inductor_freezing_avx2_torchbench", shard: 2, num_shards: 2, runner: "linux.10xlarge.avx2" }, { config: "cpu_inductor_freezing_avx2_timm", shard: 1, num_shards: 2, runner: "linux.10xlarge.avx2" }, { config: "cpu_inductor_freezing_avx2_timm", shard: 2, num_shards: 2, runner: "linux.10xlarge.avx2" }, +<<<<<<< HEAD +======= + ]} + secrets: inherit + + linux-jammy-cpu-py3_9-gcc11-periodic-dynamo-benchmarks-test: + name: linux-jammy-cpu-py3.9-gcc11-periodic-dynamo-benchmarks + uses: ./.github/workflows/_linux-test.yml + needs: linux-jammy-cpu-py3_9-gcc11-periodic-dynamo-benchmarks-build + with: + build-environment: linux-jammy-py3.9-gcc11-build + docker-image: ${{ needs.linux-jammy-cpu-py3_9-gcc11-periodic-dynamo-benchmarks-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cpu-py3_9-gcc11-periodic-dynamo-benchmarks-build.outputs.test-matrix }} + secrets: inherit + + + linux-jammy-cuda12_8-py3_10-gcc9-inductor-build: + name: cuda12.8-py3.10-gcc9-sm86 + uses: ./.github/workflows/_linux-build.yml + needs: get-default-label-prefix + with: + build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm86 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + cuda-arch-list: '8.6' + runner_prefix: "${{ needs.get-default-label-prefix.outputs.label-type }}" + sync-tag: linux-jammy-cuda12_8-py3_10-gcc9-inductor-build + test-matrix: | + { include: [ + { config: "dynamic_inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + ]} + secrets: inherit + + linux-jammy-cuda12_8-py3_10-gcc9-inductor-test: + name: cuda12.8-py3.10-gcc9-sm86 + uses: ./.github/workflows/_linux-test.yml + needs: linux-jammy-cuda12_8-py3_10-gcc9-inductor-build + with: + build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm86 + docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-build.outputs.test-matrix }} + secrets: inherit + + linux-jammy-cpu-py3_9-gcc11-inductor-build: + name: linux-jammy-cpu-py3.9-gcc11-inductor + uses: ./.github/workflows/_linux-build.yml + needs: get-default-label-prefix + with: + build-environment: linux-jammy-py3.9-gcc11-build + docker-image-name: ci-image:pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks + runner_prefix: "${{ needs.get-default-label-prefix.outputs.label-type }}" + sync-tag: linux-jammy-cpu-py3_9-gcc11-inductor-build + test-matrix: | + { include: [ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) { config: "cpu_inductor_freezing_huggingface", shard: 1, num_shards: 1, runner: "linux.8xlarge.amx" }, { config: "cpu_inductor_freezing_timm", shard: 1, num_shards: 2, runner: "linux.8xlarge.amx" }, { config: "cpu_inductor_freezing_timm", shard: 2, num_shards: 2, runner: "linux.8xlarge.amx" }, @@ -192,6 +349,7 @@ jobs: { config: "dynamic_cpu_aot_inductor_amp_freezing_torchbench", shard: 1, num_shards: 2, runner: "linux.8xlarge.amx" }, { config: "dynamic_cpu_aot_inductor_amp_freezing_torchbench", shard: 2, num_shards: 2, runner: "linux.8xlarge.amx" }, ]} +<<<<<<< HEAD build-additional-packages: "vision audio torchao" secrets: inherit @@ -203,4 +361,16 @@ jobs: build-environment: linux-jammy-py3.10-gcc11-build docker-image: ${{ needs.periodic-dynamo-benchmarks-cpu-build.outputs.docker-image }} test-matrix: ${{ needs.periodic-dynamo-benchmarks-cpu-build.outputs.test-matrix }} +======= + secrets: inherit + + linux-jammy-cpu-py3_9-gcc11-inductor-test: + name: linux-jammy-cpu-py3.9-gcc11-inductor + uses: ./.github/workflows/_linux-test.yml + needs: linux-jammy-cpu-py3_9-gcc11-inductor-build + with: + build-environment: linux-jammy-py3.9-gcc11-build + docker-image: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.test-matrix }} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: inherit diff --git a/.github/workflows/inductor-rocm-mi300.yml b/.github/workflows/inductor-rocm-mi300.yml index 369eee791dd62..0d2b9e1f27c91 100644 --- a/.github/workflows/inductor-rocm-mi300.yml +++ b/.github/workflows/inductor-rocm-mi300.yml @@ -28,7 +28,11 @@ jobs: get-label-type: name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} with: triggering_actor: ${{ github.triggering_actor }} @@ -47,8 +51,13 @@ jobs: docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 test-matrix: | { include: [ +<<<<<<< HEAD { config: "inductor", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, { config: "inductor", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, +======= + { config: "inductor", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" }, + { config: "inductor", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" }, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ]} secrets: inherit diff --git a/.github/workflows/inductor-rocm.yml b/.github/workflows/inductor-rocm.yml index 87d78b600f44e..35c30fe637ec2 100644 --- a/.github/workflows/inductor-rocm.yml +++ b/.github/workflows/inductor-rocm.yml @@ -7,6 +7,10 @@ on: - release/* tags: - ciflow/inductor-rocm/* +<<<<<<< HEAD +======= + - ciflow/inductor/* +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) workflow_dispatch: concurrency: @@ -20,7 +24,11 @@ permissions: jobs: get-label-type: name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} with: triggering_actor: ${{ github.triggering_actor }} diff --git a/.github/workflows/inductor-unittest.yml b/.github/workflows/inductor-unittest.yml index 31ca8e6faa3ba..48a3c23bcde2f 100644 --- a/.github/workflows/inductor-unittest.yml +++ b/.github/workflows/inductor-unittest.yml @@ -12,14 +12,22 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-unittest cancel-in-progress: true +<<<<<<< HEAD permissions: id-token: write contents: read +======= +permissions: read-all +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) jobs: get-label-type: name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} with: triggering_actor: ${{ github.triggering_actor }} @@ -28,8 +36,13 @@ jobs: curr_ref_type: ${{ github.ref_type }} opt_out_experiments: lf +<<<<<<< HEAD inductor-build: name: inductor-build +======= + linux-jammy-cuda12_8-py3_10-gcc9-inductor-build: + name: cuda12.8-py3.10-gcc9-sm86 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: @@ -47,6 +60,7 @@ jobs: ]} secrets: inherit +<<<<<<< HEAD inductor-test: name: inductor-test uses: ./.github/workflows/_linux-test.yml @@ -59,6 +73,46 @@ jobs: inductor-halide-build: name: inductor-halide-build +======= + linux-jammy-cuda12_8-py3_10-gcc9-inductor-test: + name: cuda12.8-py3.10-gcc9-sm86 + uses: ./.github/workflows/_linux-test.yml + needs: linux-jammy-cuda12_8-py3_10-gcc9-inductor-build + with: + build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm86 + docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-build.outputs.test-matrix }} + secrets: inherit + + linux-jammy-cuda12_8-py3_12-gcc9-inductor-build: + name: cuda12.8-py3.12-gcc9-sm86 + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + build-environment: linux-jammy-cuda12.8-py3.12-gcc9-sm86 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc9-inductor-benchmarks + cuda-arch-list: '8.6' + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + test-matrix: | + { include: [ + { config: "inductor", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + ]} + secrets: inherit + + linux-jammy-cuda12_8-py3_12-gcc9-inductor-test: + name: cuda12.8-py3.12-gcc9-sm86 + uses: ./.github/workflows/_linux-test.yml + needs: linux-jammy-cuda12_8-py3_12-gcc9-inductor-build + with: + build-environment: linux-jammy-cuda12.8-py3.12-gcc9-sm86 + docker-image: ${{ needs.linux-jammy-cuda12_8-py3_12-gcc9-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_12-gcc9-inductor-build.outputs.test-matrix }} + secrets: inherit + + linux-jammy-cpu-py3_12-inductor-halide-build: + name: linux-jammy-cpu-py3.12-gcc11-inductor-halide +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: @@ -71,6 +125,7 @@ jobs: ]} secrets: inherit +<<<<<<< HEAD inductor-halide-test: name: inductor-halide-test uses: ./.github/workflows/_linux-test.yml @@ -83,6 +138,20 @@ jobs: inductor-triton-cpu-build: name: inductor-triton-cpu-build +======= + linux-jammy-cpu-py3_12-inductor-halide-test: + name: linux-jammy-cpu-py3.12-gcc11-inductor-halide + uses: ./.github/workflows/_linux-test.yml + needs: linux-jammy-cpu-py3_12-inductor-halide-build + with: + build-environment: linux-jammy-py3.12-gcc11 + docker-image: ${{ needs.linux-jammy-cpu-py3_12-inductor-halide-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cpu-py3_12-inductor-halide-build.outputs.test-matrix }} + secrets: inherit + + linux-jammy-cpu-py3_12-inductor-triton-cpu-build: + name: linux-jammy-cpu-py3.12-gcc11-inductor-triton-cpu +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: @@ -95,6 +164,7 @@ jobs: ]} secrets: inherit +<<<<<<< HEAD inductor-triton-cpu-test: name: linux-jammy-cpu-py3.12-gcc11-inductor-triton-cpu uses: ./.github/workflows/_linux-test.yml @@ -112,6 +182,25 @@ jobs: with: build-environment: linux-jammy-py3.10-gcc11-build docker-image-name: ci-image:pytorch-linux-jammy-py3-gcc11-inductor-benchmarks +======= + linux-jammy-cpu-py3_12-inductor-triton-cpu-test: + name: linux-jammy-cpu-py3.12-gcc11-inductor-triton-cpu + uses: ./.github/workflows/_linux-test.yml + needs: linux-jammy-cpu-py3_12-inductor-triton-cpu-build + with: + build-environment: linux-jammy-py3.12-gcc11 + docker-image: ${{ needs.linux-jammy-cpu-py3_12-inductor-triton-cpu-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cpu-py3_12-inductor-triton-cpu-build.outputs.test-matrix }} + secrets: inherit + + linux-jammy-cpu-py3_9-gcc11-inductor-build: + name: linux-jammy-cpu-py3.9-gcc11-inductor + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + build-environment: linux-jammy-py3.9-gcc11-build + docker-image-name: ci-image:pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" test-matrix: | { include: [ @@ -122,6 +211,7 @@ jobs: ]} secrets: inherit +<<<<<<< HEAD inductor-cpu-test: name: inductor-cpu-test uses: ./.github/workflows/_linux-test.yml @@ -130,4 +220,39 @@ jobs: build-environment: linux-jammy-py3.10-gcc11-build docker-image: ${{ needs.inductor-cpu-build.outputs.docker-image }} test-matrix: ${{ needs.inductor-cpu-build.outputs.test-matrix }} +======= + linux-jammy-cpu-py3_9-gcc11-inductor-test: + name: linux-jammy-cpu-py3.9-gcc11-inductor + uses: ./.github/workflows/_linux-test.yml + needs: linux-jammy-cpu-py3_9-gcc11-inductor-build + with: + build-environment: linux-jammy-py3.9-gcc11-build + docker-image: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.test-matrix }} + secrets: inherit + + linux-jammy-cuda12_8-py3_13-gcc9-inductor-build: + name: cuda12.8-py3.13-gcc9-sm86 + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + build-environment: linux-jammy-cuda12.8-py3.13-gcc9-sm86 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3.13-gcc9-inductor-benchmarks + cuda-arch-list: '8.6' + test-matrix: | + { include: [ + { config: "inductor", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + ]} + secrets: inherit + + linux-jammy-cuda12_8-py3_13-gcc9-inductor-test: + name: cuda12.8-py3.13-gcc9-sm86 + uses: ./.github/workflows/_linux-test.yml + needs: linux-jammy-cuda12_8-py3_13-gcc9-inductor-build + with: + build-environment: linux-jammy-cuda12.8-py3.13-gcc9-sm86 + docker-image: ${{ needs.linux-jammy-cuda12_8-py3_13-gcc9-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_13-gcc9-inductor-build.outputs.test-matrix }} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: inherit diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml index a70929dd868d1..c0385d1369275 100644 --- a/.github/workflows/inductor.yml +++ b/.github/workflows/inductor.yml @@ -22,9 +22,13 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true +<<<<<<< HEAD permissions: id-token: write contents: read +======= +permissions: read-all +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) jobs: unit-test: @@ -35,7 +39,11 @@ jobs: get-label-type: name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} with: triggering_actor: ${{ github.triggering_actor }} @@ -44,8 +52,13 @@ jobs: curr_ref_type: ${{ github.ref_type }} opt_out_experiments: lf +<<<<<<< HEAD inductor-build: name: inductor-build +======= + linux-jammy-cuda12_8-py3_10-gcc9-inductor-build: + name: cuda12.8-py3.10-gcc9-sm86 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: @@ -53,6 +66,10 @@ jobs: docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.6' runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +<<<<<<< HEAD +======= + sync-tag: linux-jammy-cuda12_8-py3_10-gcc9-inductor-build +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) test-matrix: | { include: [ { config: "inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, @@ -61,6 +78,7 @@ jobs: { config: "inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, { config: "inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, ]} +<<<<<<< HEAD build-additional-packages: "vision audio fbgemm torchao" secrets: inherit @@ -82,6 +100,29 @@ jobs: build-environment: linux-jammy-py3.10-gcc11-build docker-image-name: ci-image:pytorch-linux-jammy-py3-gcc11-inductor-benchmarks runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +======= + secrets: inherit + + linux-jammy-cuda12_8-py3_10-gcc9-inductor-test: + name: cuda12.8-py3.10-gcc9-sm86 + uses: ./.github/workflows/_linux-test.yml + needs: linux-jammy-cuda12_8-py3_10-gcc9-inductor-build + with: + build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm86 + docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-build.outputs.test-matrix }} + secrets: inherit + + linux-jammy-cpu-py3_9-gcc11-inductor-build: + name: linux-jammy-cpu-py3.9-gcc11-inductor + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + build-environment: linux-jammy-py3.9-gcc11-build + docker-image-name: ci-image:pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + sync-tag: linux-jammy-cpu-py3_9-gcc11-inductor-build +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) test-matrix: | { include: [ { config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" }, @@ -93,6 +134,7 @@ jobs: { config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" }, { config: "inductor_torchbench_cpu_smoketest_perf", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.24xl.spr-metal" }, ]} +<<<<<<< HEAD build-additional-packages: "vision audio torchao" secrets: inherit @@ -104,4 +146,16 @@ jobs: build-environment: linux-jammy-py3.10-gcc11-build docker-image: ${{ needs.inductor-cpu-build.outputs.docker-image }} test-matrix: ${{ needs.inductor-cpu-build.outputs.test-matrix }} +======= + secrets: inherit + + linux-jammy-cpu-py3_9-gcc11-inductor-test: + name: linux-jammy-cpu-py3.9-gcc11-inductor + uses: ./.github/workflows/_linux-test.yml + needs: linux-jammy-cpu-py3_9-gcc11-inductor-build + with: + build-environment: linux-jammy-py3.9-gcc11-build + docker-image: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.test-matrix }} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: inherit diff --git a/.github/workflows/lint-autoformat.yml b/.github/workflows/lint-autoformat.yml index f64c9973d698f..25a85ce64f02c 100644 --- a/.github/workflows/lint-autoformat.yml +++ b/.github/workflows/lint-autoformat.yml @@ -13,7 +13,11 @@ jobs: if: ${{ github.repository_owner == 'pytorch' && contains(github.event.pull_request.labels.*.name, 'autoformat') }} steps: - name: Checkout pytorch +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.9 +======= + uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: submodules: true fetch-depth: 0 diff --git a/.github/workflows/lint-bc.yml b/.github/workflows/lint-bc.yml index 98adf44aefd82..fbf6afdb58b81 100644 --- a/.github/workflows/lint-bc.yml +++ b/.github/workflows/lint-bc.yml @@ -20,7 +20,11 @@ jobs: runs-on: ubuntu-latest steps: - name: Run BC Lint Action +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/bc-lint@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/bc-lint@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: repo: ${{ github.event.pull_request.head.repo.full_name }} base_sha: ${{ github.event.pull_request.base.sha }} diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 534c15824715e..7993423774be2 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -21,12 +21,17 @@ jobs: get-label-type: if: github.repository_owner == 'pytorch' name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} curr_branch: ${{ github.head_ref || github.ref_name }} +<<<<<<< HEAD get-changed-files: if: github.repository_owner == 'pytorch' name: Get changed files @@ -54,12 +59,22 @@ jobs: timeout: 120 runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" docker-image: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3.10-linter +======= + lintrunner-clang: + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@release/2.8 + needs: get-label-type + with: + timeout: 120 + runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" + docker-image: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3.9-linter +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # NB: A shallow checkout won't work here because calculate-docker-image requires a full checkout # to run git rev-parse HEAD~:.ci/docker when a new image is needed fetch-depth: 0 submodules: true ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} script: | +<<<<<<< HEAD CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}" if [ "$CHANGED_FILES" = "*" ]; then export ADDITIONAL_LINTRUNNER_ARGS="--take CLANGTIDY,CLANGFORMAT --all-files" @@ -102,12 +117,26 @@ jobs: timeout: 120 runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" docker-image: ci-image:pytorch-linux-jammy-linter +======= + export ADDITIONAL_LINTRUNNER_ARGS="--take CLANGTIDY,CLANGFORMAT --all-files" + export CLANG=1 + .github/scripts/lintrunner.sh + + lintrunner-noclang: + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@release/2.8 + needs: get-label-type + with: + timeout: 120 + runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" + docker-image: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3.9-linter +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # NB: A shallow checkout won't work here because calculate-docker-image requires a full checkout # to run git rev-parse HEAD~:.ci/docker when a new image is needed fetch-depth: 0 submodules: true ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} script: | +<<<<<<< HEAD CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}" echo "Running all other linters" if [ "$CHANGED_FILES" = '*' ]; then @@ -118,6 +147,13 @@ jobs: quick-checks: uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@release/2.9 +======= + export ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT --all-files" + .github/scripts/lintrunner.sh + + quick-checks: + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) needs: get-label-type with: timeout: 120 @@ -157,7 +193,11 @@ jobs: if: github.event_name == 'pull_request' && !contains(github.event.pull_request.labels.*.name, 'skip-pr-sanity-checks') steps: - name: Checkout PyTorch +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.9 +======= + uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: submodules: false fetch-depth: -1 @@ -170,7 +210,11 @@ jobs: bash .github/scripts/pr-sanity-check.sh workflow-checks: +<<<<<<< HEAD uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@release/2.9 +======= + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) needs: get-label-type with: timeout: 120 @@ -181,7 +225,11 @@ jobs: ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} script: | # Regenerate workflows +<<<<<<< HEAD export RELEASE_VERSION_TAG=2.9 +======= + export RELEASE_VERSION_TAG=2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .github/scripts/generate_ci_workflows.py RC=0 @@ -191,7 +239,11 @@ jobs: echo 'As shown by the above diff, the committed .github/workflows' echo 'are not up to date according to .github/templates.' echo 'Please run this command, commit, and push again to your PR:' +<<<<<<< HEAD echo export RELEASE_VERSION_TAG=2.9 +======= + echo +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) echo ' .github/scripts/generate_ci_workflows.py' echo echo 'If running that command does nothing, you may need to rebase' @@ -205,7 +257,11 @@ jobs: exit $RC toc: +<<<<<<< HEAD uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@release/2.9 +======= + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) needs: get-label-type with: timeout: 120 @@ -241,7 +297,11 @@ jobs: test-tools: name: Test tools if: ${{ github.repository == 'pytorch/pytorch' }} +<<<<<<< HEAD uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@release/2.9 +======= + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) needs: get-label-type with: timeout: 120 @@ -261,6 +321,7 @@ jobs: runs-on: linux.24_04.4x steps: - name: Checkout PyTorch +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.9 with: submodules: false @@ -269,6 +330,16 @@ jobs: uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 with: python-version: '3.10' +======= + uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.8 + with: + submodules: false + fetch-depth: 1 + - name: Setup Python 3.9 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 + with: + python-version: '3.9' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) architecture: x64 cache: pip - name: Install dependencies @@ -298,7 +369,11 @@ jobs: # [see note: pytorch repo ref] # deep clone (fetch-depth 0) required, to allow us to use git log - name: Checkout PyTorch +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.9 +======= + uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: submodules: false fetch-depth: 1 @@ -318,7 +393,10 @@ jobs: check-latest: false cache: pip cache-dependency-path: | +<<<<<<< HEAD **/requirements-build.txt +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) **/requirements.txt - name: Setup Min Python version if: matrix.test_type != 'older_python_version' @@ -329,7 +407,10 @@ jobs: check-latest: false cache: pip cache-dependency-path: | +<<<<<<< HEAD **/requirements-build.txt +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) **/requirements.txt - name: Install torch if: matrix.test_type == 'with_torch' diff --git a/.github/workflows/linux-aarch64.yml b/.github/workflows/linux-aarch64.yml index 357347f781381..7dc7f8df8418c 100644 --- a/.github/workflows/linux-aarch64.yml +++ b/.github/workflows/linux-aarch64.yml @@ -19,7 +19,11 @@ jobs: get-label-type: if: github.repository_owner == 'pytorch' name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} diff --git a/.github/workflows/llm_td_retrieval.yml b/.github/workflows/llm_td_retrieval.yml index 292f0a956c35d..2f19eb6bb9a0e 100644 --- a/.github/workflows/llm_td_retrieval.yml +++ b/.github/workflows/llm_td_retrieval.yml @@ -12,7 +12,11 @@ jobs: name: get-label-type # Don't run on forked repos if: github.repository_owner == 'pytorch' +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -116,5 +120,9 @@ jobs: AWS_REGION: "" - name: Teardown Linux +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: always() diff --git a/.github/workflows/mac-mps.yml b/.github/workflows/mac-mps.yml index c80599fe89988..87477dda1e2e3 100644 --- a/.github/workflows/mac-mps.yml +++ b/.github/workflows/mac-mps.yml @@ -28,6 +28,10 @@ jobs: # than our AWS macos-m1-14 runners test-matrix: | { include: [ +<<<<<<< HEAD +======= + { config: "test_mps", shard: 1, num_shards: 1, runner: "macos-m1-13" }, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) { config: "test_mps", shard: 1, num_shards: 1, runner: "macos-m1-14" }, { config: "test_mps", shard: 1, num_shards: 1, runner: "macos-m2-15" }, ]} diff --git a/.github/workflows/nightly-s3-uploads.yml b/.github/workflows/nightly-s3-uploads.yml index 1cafca0e0c850..b1afd9f74156f 100644 --- a/.github/workflows/nightly-s3-uploads.yml +++ b/.github/workflows/nightly-s3-uploads.yml @@ -23,7 +23,11 @@ jobs: environment: upload-stats steps: - name: Checkout PyTorch +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.9 +======= + uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: fetch-depth: 1 submodules: false diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index eddb21ea2ca58..3bfac12d436ff 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -19,7 +19,11 @@ concurrency: jobs: get-label-type: name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} with: triggering_actor: ${{ github.triggering_actor }} @@ -42,8 +46,13 @@ jobs: needs: get-label-type with: runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" +<<<<<<< HEAD build-environment: linux-jammy-py3.10-gcc11 docker-image-name: ci-image:pytorch-linux-jammy-py3.10-gcc11 +======= + build-environment: linux-jammy-py3.9-gcc11 + docker-image-name: ci-image:pytorch-linux-jammy-py3.9-gcc11 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: inherit docs-push: @@ -54,7 +63,11 @@ jobs: - get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +<<<<<<< HEAD build-environment: linux-jammy-py3.10-gcc11 +======= + build-environment: linux-jammy-py3.9-gcc11 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) docker-image: ${{ needs.docs-build.outputs.docker-image }} push: ${{ github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' || startsWith(github.event.ref, 'refs/tags/v') }} run-doxygen: true @@ -75,24 +88,38 @@ jobs: repo-owner: pytorch branch: main pin-folder: .github/ci_commit_pins +<<<<<<< HEAD # executorch jobs are disabled since it needs some manual work for the hash update # - repo-name: executorch # repo-owner: pytorch # branch: main # pin-folder: .ci/docker/ci_commit_pins +======= + - repo-name: executorch + repo-owner: pytorch + branch: main + pin-folder: .ci/docker/ci_commit_pins +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - repo-name: triton repo-owner: triton-lang branch: main pin-folder: .ci/docker/ci_commit_pins +<<<<<<< HEAD - repo-name: vllm repo-owner: vllm-project branch: main pin-folder: .github/ci_commit_pins +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Allow this to be triggered on either a schedule or on workflow_dispatch to allow for easier testing if: github.repository_owner == 'pytorch' && (github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') steps: - name: "${{ matrix.repo-owner }}/${{ matrix.repo-name }} update-commit-hash" +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/update-commit-hash@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/update-commit-hash@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: repo-owner: ${{ matrix.repo-owner }} repo-name: ${{ matrix.repo-name }} diff --git a/.github/workflows/nitpicker.yml b/.github/workflows/nitpicker.yml index 242f021e46fa9..b5d2c3a5af960 100644 --- a/.github/workflows/nitpicker.yml +++ b/.github/workflows/nitpicker.yml @@ -19,7 +19,11 @@ jobs: if: ${{ github.event.pull_request.number != 26921 && github.repository_owner == 'pytorch' }} steps: - name: Checkout PyTorch +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.9 +======= + uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - uses: ethanis/nitpicker@v1 with: nitpicks: '.github/nitpicks.yml' diff --git a/.github/workflows/operator_benchmark.yml b/.github/workflows/operator_benchmark.yml index dcdc2cd0ba24e..4b9bcf35a836e 100644 --- a/.github/workflows/operator_benchmark.yml +++ b/.github/workflows/operator_benchmark.yml @@ -14,15 +14,19 @@ on: schedule: # Run at 07:00 UTC every Sunday - cron: 0 7 * * 0 +<<<<<<< HEAD pull_request: paths: - benchmarks/operator_benchmark/** - .github/workflows/operator_benchmark.yml +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true +<<<<<<< HEAD permissions: id-token: write contents: read @@ -35,12 +39,25 @@ jobs: with: build-environment: linux-jammy-py3.10-gcc11-build docker-image-name: ci-image:pytorch-linux-jammy-py3-gcc11-inductor-benchmarks +======= +permissions: read-all + +jobs: + linux-jammy-cpu-py3_9-gcc11-opbenchmark-build: + if: github.repository_owner == 'pytorch' + name: linux-jammy-cpu-py3.9-gcc11-opbenchmark + uses: ./.github/workflows/_linux-build.yml + with: + build-environment: linux-jammy-py3.9-gcc11-build + docker-image-name: ci-image:pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) test-matrix: | { include: [ { config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.12xlarge" }, ]} secrets: inherit +<<<<<<< HEAD opbenchmark-on-demand-build: if: ${{ github.event_name == 'workflow_dispatch' && github.repository_owner == 'pytorch' }} name: opbenchmark-on-demand-build @@ -48,12 +65,22 @@ jobs: with: build-environment: linux-jammy-py3.10-gcc11-build docker-image-name: ci-image:pytorch-linux-jammy-py3-gcc11-inductor-benchmarks +======= + linux-jammy-cpu-py3_9-gcc11-opbenchmark-on-demand-build: + if: ${{ github.event_name == 'workflow_dispatch' && github.repository_owner == 'pytorch' }} + name: linux-jammy-cpu-py3.9-gcc11-opbenchmark + uses: ./.github/workflows/_linux-build.yml + with: + build-environment: linux-jammy-py3.9-gcc11-build + docker-image-name: ci-image:pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) test-matrix: | { include: [ { config: "cpu_operator_benchmark_${{ inputs.test_mode }}", shard: 1, num_shards: 1, runner: "linux.12xlarge" }, ]} secrets: inherit +<<<<<<< HEAD opbenchmark-test: name: opbenchmark-test uses: ./.github/workflows/_linux-test.yml @@ -62,4 +89,14 @@ jobs: build-environment: linux-jammy-py3.10-gcc11-build docker-image: ${{ needs.opbenchmark-build.outputs.docker-image }} test-matrix: ${{ needs.opbenchmark-build.outputs.test-matrix }} +======= + linux-jammy-cpu-py3_9-gcc11-opbenchmark-test: + name: linux-jammy-cpu-py3.9-gcc11-opbenchmark + uses: ./.github/workflows/_linux-test.yml + needs: linux-jammy-cpu-py3_9-gcc11-opbenchmark-build + with: + build-environment: linux-jammy-py3.9-gcc11-build + docker-image: ${{ needs.linux-jammy-cpu-py3_9-gcc11-opbenchmark-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cpu-py3_9-gcc11-opbenchmark-build.outputs.test-matrix }} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: inherit diff --git a/.github/workflows/periodic-rocm-mi300.yml b/.github/workflows/periodic-rocm-mi300.yml index 850c98b3fa81b..d17236c0e6da3 100644 --- a/.github/workflows/periodic-rocm-mi300.yml +++ b/.github/workflows/periodic-rocm-mi300.yml @@ -41,7 +41,11 @@ jobs: get-label-type: name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' with: triggering_actor: ${{ github.triggering_actor }} @@ -59,9 +63,15 @@ jobs: docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 test-matrix: | { include: [ +<<<<<<< HEAD { config: "distributed", shard: 1, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4", owners: ["module:rocm", "oncall:distributed"] }, { config: "distributed", shard: 2, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4", owners: ["module:rocm", "oncall:distributed"] }, { config: "distributed", shard: 3, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4", owners: ["module:rocm", "oncall:distributed"] }, +======= + { config: "distributed", shard: 1, num_shards: 3, runner: "linux.rocm.gpu.mi300.4", owners: ["module:rocm", "oncall:distributed"] }, + { config: "distributed", shard: 2, num_shards: 3, runner: "linux.rocm.gpu.mi300.4", owners: ["module:rocm", "oncall:distributed"] }, + { config: "distributed", shard: 3, num_shards: 3, runner: "linux.rocm.gpu.mi300.4", owners: ["module:rocm", "oncall:distributed"] }, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ]} secrets: inherit diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index 418699cb5f5a7..69deba048a9c9 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -20,9 +20,13 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}-${{ github.event.schedule }} cancel-in-progress: true +<<<<<<< HEAD permissions: id-token: write contents: read +======= +permissions: read-all +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) jobs: llm-td: @@ -43,7 +47,11 @@ jobs: get-label-type: name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' with: triggering_actor: ${{ github.triggering_actor }} @@ -51,6 +59,7 @@ jobs: curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} +<<<<<<< HEAD linux-jammy-cuda12_4-py3_10-gcc11-build: name: linux-jammy-cuda12.4-py3.10-gcc11 uses: ./.github/workflows/_linux-build.yml @@ -82,6 +91,8 @@ jobs: test-matrix: ${{ needs.linux-jammy-cuda12_4-py3_10-gcc11-build.outputs.test-matrix }} secrets: inherit +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) linux-jammy-cuda12_8-py3_10-gcc11-build: name: linux-jammy-cuda12.8-py3.10-gcc11 uses: ./.github/workflows/_linux-build.yml @@ -127,6 +138,10 @@ jobs: { config: "multigpu", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.12xlarge.nvidia.gpu", owners: ["oncall:distributed"] }, { config: "multigpu", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.12xlarge.nvidia.gpu", owners: ["oncall:distributed"] }, ]} +<<<<<<< HEAD +======= + build-with-debug: false +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: inherit linux-jammy-cuda12_8-py3_9-gcc9-test: @@ -147,6 +162,10 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9 +<<<<<<< HEAD +======= + build-with-debug: true +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] }, @@ -171,6 +190,7 @@ jobs: test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-debug-build.outputs.test-matrix }} secrets: inherit +<<<<<<< HEAD linux-jammy-cuda13_0-py3_10-gcc11-build: name: linux-jammy-cuda13.0-py3.10-gcc11 uses: ./.github/workflows/_linux-build.yml @@ -203,6 +223,8 @@ jobs: test-matrix: ${{ needs.linux-jammy-cuda13_0-py3_10-gcc11-build.outputs.test-matrix }} secrets: inherit +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) linux-jammy-rocm-py3_10-build: name: linux-jammy-rocm-py3.10 uses: ./.github/workflows/_linux-build.yml diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index f884fee53fc7c..e20fe43b9bb33 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -19,9 +19,13 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} cancel-in-progress: true +<<<<<<< HEAD permissions: id-token: write contents: read +======= +permissions: read-all +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) jobs: llm-td: @@ -42,21 +46,35 @@ jobs: get-label-type: name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} curr_branch: ${{ github.head_ref || github.ref_name }} +<<<<<<< HEAD linux-jammy-py3_10-gcc11-build: name: linux-jammy-py3.10-gcc11 +======= + linux-jammy-py3_9-gcc11-build: + name: linux-jammy-py3.9-gcc11 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +<<<<<<< HEAD build-environment: linux-jammy-py3.10-gcc11 docker-image-name: ci-image:pytorch-linux-jammy-py3.10-gcc11 +======= + build-environment: linux-jammy-py3.9-gcc11 + docker-image-name: ci-image:pytorch-linux-jammy-py3.9-gcc11 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, @@ -73,6 +91,7 @@ jobs: ]} secrets: inherit +<<<<<<< HEAD linux-jammy-py3_10-gcc11-test: name: linux-jammy-py3.10-gcc11 uses: ./.github/workflows/_linux-test.yml @@ -83,11 +102,24 @@ jobs: build-environment: linux-jammy-py3.10-gcc11 docker-image: ${{ needs.linux-jammy-py3_10-gcc11-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-py3_10-gcc11-build.outputs.test-matrix }} +======= + linux-jammy-py3_9-gcc11-test: + name: linux-jammy-py3.9-gcc11 + uses: ./.github/workflows/_linux-test.yml + needs: + - linux-jammy-py3_9-gcc11-build + - target-determination + with: + build-environment: linux-jammy-py3.9-gcc11 + docker-image: ${{ needs.linux-jammy-py3_9-gcc11-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-py3_9-gcc11-build.outputs.test-matrix }} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: inherit linux-docs: name: linux-docs uses: ./.github/workflows/_docs.yml +<<<<<<< HEAD needs: linux-jammy-py3_10-gcc11-build with: build-environment: linux-jammy-py3.10-gcc11 @@ -96,26 +128,51 @@ jobs: linux-jammy-py3_10-gcc11-no-ops: name: linux-jammy-py3.10-gcc11-no-ops +======= + needs: linux-jammy-py3_9-gcc11-build + with: + build-environment: linux-jammy-py3.9-gcc11 + docker-image: ${{ needs.linux-jammy-py3_9-gcc11-build.outputs.docker-image }} + secrets: inherit + + linux-jammy-py3_9-gcc11-no-ops: + name: linux-jammy-py3.9-gcc11-no-ops +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +<<<<<<< HEAD build-environment: linux-jammy-py3.10-gcc11-no-ops docker-image-name: ci-image:pytorch-linux-jammy-py3.10-gcc11 +======= + build-environment: linux-jammy-py3.9-gcc11-no-ops + docker-image-name: ci-image:pytorch-linux-jammy-py3.9-gcc11 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 1 }, ]} secrets: inherit +<<<<<<< HEAD linux-jammy-py3_10-gcc11-pch: name: linux-jammy-py3.10-gcc11-pch +======= + linux-jammy-py3_9-gcc11-pch: + name: linux-jammy-py3.9-gcc11-pch +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +<<<<<<< HEAD build-environment: linux-jammy-py3.10-gcc11-pch docker-image-name: ci-image:pytorch-linux-jammy-py3.10-gcc11 +======= + build-environment: linux-jammy-py3.9-gcc11-pch + docker-image-name: ci-image:pytorch-linux-jammy-py3.9-gcc11 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 1 }, @@ -132,6 +189,7 @@ jobs: docker-image-name: ci-image:pytorch-linux-jammy-py3-clang18-asan test-matrix: | { include: [ +<<<<<<< HEAD { config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, { config: "default", shard: 2, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, { config: "default", shard: 3, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, @@ -139,10 +197,22 @@ jobs: { config: "default", shard: 5, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, { config: "default", shard: 6, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, { config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, +======= + { config: "default", shard: 1, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "default", shard: 2, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "default", shard: 3, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "default", shard: 4, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "default", shard: 5, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "default", shard: 6, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ]} sync-tag: asan-build secrets: inherit +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) linux-jammy-py3_10-clang18-asan-test: name: linux-jammy-py3.10-clang18-asan uses: ./.github/workflows/_linux-test.yml @@ -156,13 +226,22 @@ jobs: sync-tag: asan-test secrets: inherit +<<<<<<< HEAD linux-jammy-py3_10-clang12-onnx-build: name: linux-jammy-py3.10-clang12-onnx +======= + linux-jammy-py3_9-clang12-onnx-build: + name: linux-jammy-py3.9-clang12-onnx +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +<<<<<<< HEAD build-environment: linux-jammy-py3.10-clang12-onnx +======= + build-environment: linux-jammy-py3.9-clang12-onnx +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) docker-image-name: ci-image:pytorch-linux-jammy-py3-clang12-onnx test-matrix: | { include: [ @@ -171,6 +250,7 @@ jobs: ]} secrets: inherit +<<<<<<< HEAD linux-jammy-py3_10-clang12-onnx-test: name: linux-jammy-py3.10-clang12-onnx uses: ./.github/workflows/_linux-test.yml @@ -185,12 +265,33 @@ jobs: linux-jammy-py3_10-clang12-build: name: linux-jammy-py3.10-clang12 +======= + linux-jammy-py3_9-clang12-onnx-test: + name: linux-jammy-py3.9-clang12-onnx + uses: ./.github/workflows/_linux-test.yml + needs: + - linux-jammy-py3_9-clang12-onnx-build + - target-determination + with: + build-environment: linux-jammy-py3.9-clang12-onnx + docker-image: ${{ needs.linux-jammy-py3_9-clang12-onnx-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-py3_9-clang12-onnx-build.outputs.test-matrix }} + secrets: inherit + + linux-jammy-py3_9-clang12-build: + name: linux-jammy-py3.9-clang12 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +<<<<<<< HEAD build-environment: linux-jammy-py3.10-clang12 docker-image-name: ci-image:pytorch-linux-jammy-py3.10-clang12 +======= + build-environment: linux-jammy-py3.9-clang12 + docker-image-name: ci-image:pytorch-linux-jammy-py3.9-clang12 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, @@ -207,6 +308,7 @@ jobs: ]} secrets: inherit +<<<<<<< HEAD linux-jammy-py3_10-clang12-test: name: linux-jammy-py3.10-clang12 uses: ./.github/workflows/_linux-test.yml @@ -217,6 +319,18 @@ jobs: build-environment: linux-jammy-py3.10-clang12 docker-image: ${{ needs.linux-jammy-py3_10-clang12-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-py3_10-clang12-build.outputs.test-matrix }} +======= + linux-jammy-py3_9-clang12-test: + name: linux-jammy-py3.9-clang12 + uses: ./.github/workflows/_linux-test.yml + needs: + - linux-jammy-py3_9-clang12-build + - target-determination + with: + build-environment: linux-jammy-py3.9-clang12 + docker-image: ${{ needs.linux-jammy-py3_9-clang12-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-py3_9-clang12-build.outputs.test-matrix }} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: inherit linux-jammy-py3_13-clang12-build: @@ -251,22 +365,138 @@ jobs: build-environment: linux-jammy-py3.13-clang12 docker-image: ${{ needs.linux-jammy-py3_13-clang12-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-py3_13-clang12-build.outputs.test-matrix }} +<<<<<<< HEAD secrets: inherit linux-jammy-cuda12_8-cudnn9-py3_10-clang12-build: name: linux-jammy-cuda12.8-cudnn9-py3.10-clang12 +======= + timeout-minutes: 600 + secrets: inherit + + linux-jammy-cuda12_8-py3_10-gcc11-build-distributed: + name: linux-jammy-cuda12.8-py3.10-gcc11-build-distributed +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +<<<<<<< HEAD build-environment: linux-jammy-cuda12.8-cudnn9-py3.10-clang12 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3.10-clang12 +======= + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 + cuda-arch-list: '7.5' + test-matrix: | + { include: [ + { config: "distributed", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" }, + { config: "distributed", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" }, + { config: "distributed", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" }, + ]} + secrets: inherit + + linux-jammy-cuda12_8-py3_10-gcc11-test-distributed: + name: linux-jammy-cuda12.8-py3.10-gcc11-test + uses: ./.github/workflows/_linux-test.yml + needs: + - linux-jammy-cuda12_8-py3_10-gcc11-build-distributed + - target-determination + with: + timeout-minutes: 360 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed + docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed.outputs.test-matrix }} + secrets: inherit + + linux-jammy-cuda12_8-py3_10-gcc11-build: + name: linux-jammy-cuda12.8-py3.10-gcc11 + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-jammy-cuda12.8-py3.10-gcc11 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 + test-matrix: | + { include: [ + { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + ]} + secrets: inherit + + linux-jammy-cuda12_8-py3_10-gcc11-test: + name: linux-jammy-cuda12.8-py3.10-gcc11 + uses: ./.github/workflows/_linux-test.yml + needs: + - linux-jammy-cuda12_8-py3_10-gcc11-build + - target-determination + with: + timeout-minutes: 360 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11 + docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.test-matrix }} + secrets: inherit + + linux-jammy-py3-clang12-mobile-build: + name: linux-jammy-py3-clang12-mobile-build + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-jammy-py3-clang12-mobile-build + docker-image-name: ci-image:pytorch-linux-jammy-py3-clang15-asan + build-generates-artifacts: false +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 1 }, ]} secrets: inherit +<<<<<<< HEAD +======= + linux-jammy-cuda12_8-cudnn9-py3_9-clang12-build: + name: linux-jammy-cuda12.8-cudnn9-py3.9-clang12 + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-jammy-cuda12.8-cudnn9-py3.9-clang12 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3.9-clang12 + test-matrix: | + { include: [ + { config: "default", shard: 1, num_shards: 1 }, + ]} + secrets: inherit + + linux-jammy-py3_9-clang9-xla-build: + name: linux-jammy-py3_9-clang9-xla + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-jammy-py3.9-clang9-xla + docker-image-name: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/xla_base:v1.3-lite + test-matrix: | + { include: [ + { config: "xla", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + ]} + secrets: inherit + + linux-jammy-py3_9-clang9-xla-test: + name: linux-jammy-py3_9-clang9-xla + uses: ./.github/workflows/_linux-test.yml + needs: linux-jammy-py3_9-clang9-xla-build + with: + build-environment: linux-jammy-py3.9-clang9-xla + docker-image: ${{ needs.linux-jammy-py3_9-clang9-xla-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-py3_9-clang9-xla-build.outputs.test-matrix }} + secrets: inherit + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) linux-jammy-cpu-py3_10-gcc11-bazel-test: name: linux-jammy-cpu-py3.10-gcc11-bazel-test uses: ./.github/workflows/_bazel-build-test.yml @@ -282,14 +512,24 @@ jobs: ]} secrets: inherit +<<<<<<< HEAD linux-jammy-py3_10-gcc11-mobile-lightweight-dispatch-build: name: linux-jammy-py3.10-gcc11-mobile-lightweight-dispatch-build +======= + linux-jammy-py3_9-gcc11-mobile-lightweight-dispatch-build: + name: linux-jammy-py3.9-gcc11-mobile-lightweight-dispatch-build +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +<<<<<<< HEAD build-environment: linux-jammy-py3.10-gcc11-mobile-lightweight-dispatch-build docker-image-name: ci-image:pytorch-linux-jammy-py3.10-gcc11 +======= + build-environment: linux-jammy-py3.9-gcc11-mobile-lightweight-dispatch-build + docker-image-name: ci-image:pytorch-linux-jammy-py3.9-gcc11 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) build-generates-artifacts: false test-matrix: | { include: [ @@ -316,8 +556,43 @@ jobs: ]} secrets: inherit +<<<<<<< HEAD linux-jammy-py3-clang12-executorch-build: if: false # Docker build needs pin update +======= + linux-jammy-cuda12_8-py3_10-gcc11-sm89-build: + name: linux-jammy-cuda12.8-py3.10-gcc11-sm89 + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm89 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 + cuda-arch-list: 8.9 + test-matrix: | + { include: [ + { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, + { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, + { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, + { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, + { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, + ]} + secrets: inherit + + linux-jammy-cuda12_8-py3_10-gcc11-sm89-test: + name: linux-jammy-cuda12.8-py3.10-gcc11-sm89 + uses: ./.github/workflows/_linux-test.yml + needs: + - linux-jammy-cuda12_8-py3_10-gcc11-sm89-build + - target-determination + with: + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm89 + docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm89-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm89-build.outputs.test-matrix }} + secrets: inherit + + linux-jammy-py3-clang12-executorch-build: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) name: linux-jammy-py3-clang12-executorch uses: ./.github/workflows/_linux-build.yml needs: get-label-type @@ -367,6 +642,7 @@ jobs: test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-build.outputs.test-matrix }} secrets: inherit +<<<<<<< HEAD linux-jammy-xpu-n-py3_9-build: name: linux-jammy-xpu-n-py3.9 uses: ./.github/workflows/_linux-build.yml @@ -376,6 +652,17 @@ jobs: runner_prefix: ${{ needs.get-label-type.outputs.label-type }} build-environment: linux-jammy-xpu-n-py3.9 docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3 +======= + linux-jammy-xpu-2025_1-py3_9-build: + name: linux-jammy-xpu-2025.1-py3.9 + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + sync-tag: linux-xpu-2025-1-build + runner_prefix: ${{ needs.get-label-type.outputs.label-type }} + build-environment: linux-jammy-xpu-2025.1-py3.9 + docker-image-name: ci-image:pytorch-linux-jammy-xpu-2025.1-py3 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 4, runner: "linux.idc.xpu" }, diff --git a/.github/workflows/revert.yml b/.github/workflows/revert.yml index 226d773e48977..476691f003e69 100644 --- a/.github/workflows/revert.yml +++ b/.github/workflows/revert.yml @@ -26,7 +26,11 @@ jobs: architecture: x64 check-latest: false cache: pip +<<<<<<< HEAD - run: pip install pyyaml==6.0.2 +======= + - run: pip install pyyaml==6.0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: Setup committer id run: | diff --git a/.github/workflows/rocm-mi300.yml b/.github/workflows/rocm-mi300.yml index 51a807250f549..1c49cc5397a17 100644 --- a/.github/workflows/rocm-mi300.yml +++ b/.github/workflows/rocm-mi300.yml @@ -28,7 +28,11 @@ jobs: get-label-type: name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} with: triggering_actor: ${{ github.triggering_actor }} @@ -36,13 +40,20 @@ jobs: curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} +<<<<<<< HEAD linux-noble-rocm-py3_12-build: if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} name: linux-noble-rocm-py3.12-mi300 +======= + linux-jammy-rocm-py3_10-build: + if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} + name: linux-jammy-rocm-py3.10-mi300 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +<<<<<<< HEAD build-environment: linux-noble-rocm-py3.12-mi300 docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3 sync-tag: rocm-build @@ -70,4 +81,33 @@ jobs: build-environment: linux-noble-rocm-py3.12-mi300 docker-image: ${{ needs.linux-noble-rocm-py3_12-build.outputs.docker-image }} test-matrix: ${{ needs.linux-noble-rocm-py3_12-build.outputs.test-matrix }} +======= + build-environment: linux-jammy-rocm-py3.10-mi300 + docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 + sync-tag: rocm-build + test-matrix: | + { include: [ + { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi300.2" }, + { config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.mi300.2" }, + { config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.mi300.2" }, + { config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.mi300.2" }, + { config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.mi300.2" }, + { config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.mi300.2" }, + ]} + secrets: inherit + + linux-jammy-rocm-py3_10-test: + permissions: + id-token: write + contents: read + name: linux-jammy-rocm-py3.10-mi300 + uses: ./.github/workflows/_rocm-test.yml + needs: + - linux-jammy-rocm-py3_10-build + - target-determination + with: + build-environment: linux-jammy-rocm-py3.10-mi300 + docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: inherit diff --git a/.github/workflows/s390x-periodic.yml b/.github/workflows/s390x-periodic.yml index 405e3e1a581cc..2723fa23dc2ba 100644 --- a/.github/workflows/s390x-periodic.yml +++ b/.github/workflows/s390x-periodic.yml @@ -15,9 +15,13 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}-${{ github.event.schedule }} cancel-in-progress: true +<<<<<<< HEAD permissions: id-token: write contents: read +======= +permissions: read-all +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) jobs: llm-td: diff --git a/.github/workflows/slow.yml b/.github/workflows/slow.yml index 197a04054bfee..558b71f4c82a7 100644 --- a/.github/workflows/slow.yml +++ b/.github/workflows/slow.yml @@ -18,9 +18,13 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}-${{ github.event.schedule }} cancel-in-progress: true +<<<<<<< HEAD permissions: id-token: write contents: read +======= +permissions: read-all +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) jobs: llm-td: @@ -41,7 +45,11 @@ jobs: get-label-type: name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} with: triggering_actor: ${{ github.triggering_actor }} @@ -78,14 +86,24 @@ jobs: test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm86-build.outputs.test-matrix }} secrets: inherit +<<<<<<< HEAD linux-jammy-py3_10-clang12-build: name: linux-jammy-py3.10-clang12 +======= + linux-jammy-py3_9-clang12-build: + name: linux-jammy-py3.9-clang12 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +<<<<<<< HEAD build-environment: linux-jammy-py3.10-clang12 docker-image-name: ci-image:pytorch-linux-jammy-py3.10-clang12 +======= + build-environment: linux-jammy-py3.9-clang12 + docker-image-name: ci-image:pytorch-linux-jammy-py3.9-clang12 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) test-matrix: | { include: [ { config: "slow", shard: 1, num_shards: 2, runner: "linux.2xlarge" }, @@ -93,6 +111,7 @@ jobs: ]} secrets: inherit +<<<<<<< HEAD linux-jammy-py3_10-clang12-test: name: linux-jammy-py3.10-clang12 uses: ./.github/workflows/_linux-test.yml @@ -103,6 +122,18 @@ jobs: build-environment: linux-jammy-py3.10-clang12 docker-image: ${{ needs.linux-jammy-py3_10-clang12-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-py3_10-clang12-build.outputs.test-matrix }} +======= + linux-jammy-py3_9-clang12-test: + name: linux-jammy-py3.9-clang12 + uses: ./.github/workflows/_linux-test.yml + needs: + - linux-jammy-py3_9-clang12-build + - target-determination + with: + build-environment: linux-jammy-py3.9-clang12 + docker-image: ${{ needs.linux-jammy-py3_9-clang12-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-py3_9-clang12-build.outputs.test-matrix }} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secrets: inherit linux-jammy-rocm-py3_10-build: diff --git a/.github/workflows/target-determination-indexer.yml b/.github/workflows/target-determination-indexer.yml index f5f29c9646f40..c387b4c667e83 100644 --- a/.github/workflows/target-determination-indexer.yml +++ b/.github/workflows/target-determination-indexer.yml @@ -13,7 +13,11 @@ jobs: get-label-type: if: github.repository_owner == 'pytorch' name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -35,7 +39,11 @@ jobs: - name: Calculate docker image id: calculate-docker-image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 working-directory: pytorch @@ -50,13 +58,21 @@ jobs: echo "docker pull ghcr.io/pytorch/ci-image:${tag/:/-}" - name: Pull docker image +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Install nvidia driver, nvidia-docker runtime, set GPU_FLAG id: install-nvidia-driver +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/setup-nvidia@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/setup-nvidia@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: Clone CodeLlama uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -149,7 +165,11 @@ jobs: "s3://target-determinator-assets/indexes/latest/${ZIP_NAME}" - name: Teardown Linux +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: always() concurrency: diff --git a/.github/workflows/target_determination.yml b/.github/workflows/target_determination.yml index 3e9f848e9e09a..87a543a2e9f3c 100644 --- a/.github/workflows/target_determination.yml +++ b/.github/workflows/target_determination.yml @@ -9,7 +9,11 @@ jobs: name: get-label-type # Don't run on forked repos if: github.repository_owner == 'pytorch' +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -27,7 +31,11 @@ jobs: # checkout because when we run this action we don't *have* a local # checkout. In other cases you should prefer a local checkout. - name: Checkout PyTorch +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.9 +======= + uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: submodules: false diff --git a/.github/workflows/test-check-binary.yml b/.github/workflows/test-check-binary.yml index a13e1d027f130..522a9311b00fb 100644 --- a/.github/workflows/test-check-binary.yml +++ b/.github/workflows/test-check-binary.yml @@ -15,7 +15,11 @@ jobs: check_binary_linux_cpu: if: github.repository_owner == 'pytorch' name: Test check_binary.sh for Linux CPU +<<<<<<< HEAD uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@release/2.9 +======= + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: docker-image: python:3.11 docker-build-dir: "skip-docker-build" @@ -28,9 +32,15 @@ jobs: check_binary_linux_cuda: if: github.repository_owner == 'pytorch' name: Test check_binary.sh for Linux CUDA +<<<<<<< HEAD uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@release/2.9 with: runner: linux.g4dn.4xlarge.nvidia.gpu +======= + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@release/2.8 + with: + runner: linux.4xlarge.nvidia.gpu +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) docker-image: python:3.11 docker-build-dir: "skip-docker-build" script: | diff --git a/.github/workflows/test-h100.yml b/.github/workflows/test-h100.yml index d08d6033c47e8..fb08ebc72cbc6 100644 --- a/.github/workflows/test-h100.yml +++ b/.github/workflows/test-h100.yml @@ -4,10 +4,13 @@ on: pull_request: paths: - .github/workflows/test-h100.yml +<<<<<<< HEAD - test/inductor/test_max_autotune.py - torch/_inductor/kernel/mm.py - torch/_inductor/kernel/mm_grouped.py +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) workflow_dispatch: schedule: - cron: 0 4,10,16,22 * * * # every 6 hours @@ -19,16 +22,23 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} cancel-in-progress: true +<<<<<<< HEAD permissions: id-token: write contents: read +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) jobs: get-label-type: if: github.repository_owner == 'pytorch' name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -41,7 +51,11 @@ jobs: needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +<<<<<<< HEAD runner: linux.12xlarge.memory +======= + runner: "linux.12xlarge" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 cuda-arch-list: '9.0' diff --git a/.github/workflows/torchbench.yml b/.github/workflows/torchbench.yml index e4f0c692e9764..5ab5a80c50221 100644 --- a/.github/workflows/torchbench.yml +++ b/.github/workflows/torchbench.yml @@ -10,15 +10,22 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} cancel-in-progress: true +<<<<<<< HEAD permissions: id-token: write contents: read +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) jobs: get-default-label-prefix: if: github.repository_owner == 'pytorch' name: get-default-label-prefix +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index efc027ad2acb2..cdacda4edcddf 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -16,9 +16,13 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} cancel-in-progress: true +<<<<<<< HEAD permissions: id-token: write contents: read +======= +permissions: read-all +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) jobs: llm-td: @@ -39,7 +43,11 @@ jobs: get-label-type: name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} with: triggering_actor: ${{ github.triggering_actor }} @@ -63,6 +71,7 @@ jobs: ]} secrets: inherit +<<<<<<< HEAD linux-jammy-cuda12_8-py3_10-gcc11-build: name: linux-jammy-cuda12.8-py3.10-gcc11 uses: ./.github/workflows/_linux-build.yml @@ -100,6 +109,8 @@ jobs: secrets: inherit +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # no-ops builds test USE_PER_OPERATOR_HEADERS=0 where ATen/ops is not generated linux-jammy-cuda12_8-py3_10-gcc11-no-ops-build: name: linux-jammy-cuda12.8-py3.10-gcc11-no-ops @@ -131,6 +142,10 @@ jobs: { config: "default", shard: 1, num_shards: 3, runner: "macos-m1-stable" }, { config: "default", shard: 2, num_shards: 3, runner: "macos-m1-stable" }, { config: "default", shard: 3, num_shards: 3, runner: "macos-m1-stable" }, +<<<<<<< HEAD +======= + { config: "mps", shard: 1, num_shards: 1, runner: "macos-m1-13" }, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) { config: "mps", shard: 1, num_shards: 1, runner: "macos-m1-14" }, { config: "mps", shard: 1, num_shards: 1, runner: "macos-m2-15" }, ]} @@ -201,9 +216,15 @@ jobs: sync-tag: rocm-build test-matrix: | { include: [ +<<<<<<< HEAD { config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, { config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, { config: "distributed", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.4" }, +======= + { config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.2" }, + { config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.2" }, + { config: "distributed", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.4" }, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ]} secrets: inherit @@ -224,12 +245,22 @@ jobs: tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor distributed/test_c10d_common distributed/test_c10d_nccl" secrets: inherit +<<<<<<< HEAD inductor-build: name: inductor-build uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: build-environment: linux-jammy-cuda12.8-py3.12-gcc9-sm80 +======= + # NB: Keep this in sync with inductor-perf-test-nightly.yml + linux-jammy-cuda12_8-py3_10-gcc9-inductor-build: + name: cuda12.8-py3.10-gcc9-sm80 + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' secrets: inherit @@ -240,8 +271,13 @@ jobs: needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" +<<<<<<< HEAD build-environment: linux-jammy-py3.10-gcc11 docker-image-name: ci-image:pytorch-linux-jammy-py3-gcc11-inductor-benchmarks +======= + build-environment: linux-jammy-py3.9-gcc11 + docker-image-name: ci-image:pytorch-linux-jammy-py3.9-gcc11 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) test-matrix: | { include: [ { config: "verify_cachebench", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, @@ -255,7 +291,11 @@ jobs: - verify-cachebench-cpu-build - target-determination with: +<<<<<<< HEAD build-environment: linux-jammy-py3.10-gcc11 +======= + build-environment: linux-jammy-py3.9-gcc11 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) docker-image: ${{ needs.verify-cachebench-cpu-build.outputs.docker-image }} test-matrix: ${{ needs.verify-cachebench-cpu-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/trymerge.yml b/.github/workflows/trymerge.yml index 5c456c607c887..e18659f799f1f 100644 --- a/.github/workflows/trymerge.yml +++ b/.github/workflows/trymerge.yml @@ -28,7 +28,11 @@ jobs: check-latest: false cache: pip architecture: x64 +<<<<<<< HEAD - run: pip install pyyaml==6.0.2 +======= + - run: pip install pyyaml==6.0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: Setup committer id run: | @@ -59,6 +63,7 @@ jobs: # on the PR appear in chronological order (timing issues can shuffle them around) sleep 60 fi +<<<<<<< HEAD # Require a comment id for merge operations if [ -z "${COMMENT_ID}" ]; then @@ -72,6 +77,24 @@ jobs: python3 .github/scripts/trymerge.py --ignore-current --comment-id "${COMMENT_ID}" "${PR_NUM}" else python3 .github/scripts/trymerge.py --comment-id "${COMMENT_ID}" "${PR_NUM}" +======= + if [ -n "${FORCE}" ]; then + if [ -n "${COMMENT_ID}" ]; then + python3 .github/scripts/trymerge.py --force --comment-id "${COMMENT_ID}" "${PR_NUM}" + else + python3 .github/scripts/trymerge.py --force "${PR_NUM}" + fi + elif [ -n "${IGNORE_CURRENT}" ]; then + if [ -n "${COMMENT_ID}" ]; then + python3 .github/scripts/trymerge.py --ignore-current --comment-id "${COMMENT_ID}" "${PR_NUM}" + else + python3 .github/scripts/trymerge.py --ignore-current "${PR_NUM}" + fi + elif [ -n "${COMMENT_ID}" ]; then + python3 .github/scripts/trymerge.py --comment-id "${COMMENT_ID}" "${PR_NUM}" + else + python3 .github/scripts/trymerge.py "${PR_NUM}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fi - name: Comment on Canceled if: ${{ cancelled() && steps.checkout.outcome == 'success' }} diff --git a/.github/workflows/tryrebase.yml b/.github/workflows/tryrebase.yml index 1a8e00e4390be..43275303c3acc 100644 --- a/.github/workflows/tryrebase.yml +++ b/.github/workflows/tryrebase.yml @@ -25,7 +25,11 @@ jobs: architecture: x64 check-latest: false cache: pip +<<<<<<< HEAD - run: pip install pyyaml==6.0.2 +======= + - run: pip install pyyaml==6.0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: Setup committer id run: | diff --git a/.github/workflows/unstable.yml b/.github/workflows/unstable.yml index 5eeb8b19a325a..07fa6a6e17cb1 100644 --- a/.github/workflows/unstable.yml +++ b/.github/workflows/unstable.yml @@ -12,9 +12,13 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true +<<<<<<< HEAD permissions: id-token: write contents: read +======= +permissions: read-all +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) jobs: # There must be at least one job here to satisfy GitHub action workflow syntax @@ -46,13 +50,18 @@ jobs: get-label-type: name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} +<<<<<<< HEAD linux-jammy-py3_9-clang9-xla-build: name: linux-jammy-py3_9-clang9-xla @@ -77,3 +86,5 @@ jobs: docker-image: ${{ needs.linux-jammy-py3_9-clang9-xla-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-py3_9-clang9-xla-build.outputs.test-matrix }} secrets: inherit +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.github/workflows/update-viablestrict.yml b/.github/workflows/update-viablestrict.yml index e3ca35d2d01dc..4cd42f6e4f516 100644 --- a/.github/workflows/update-viablestrict.yml +++ b/.github/workflows/update-viablestrict.yml @@ -7,7 +7,11 @@ on: concurrency: group: ${{ github.workflow }} +<<<<<<< HEAD cancel-in-progress: true +======= + cancel-in-progress: false +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) jobs: do_update_viablestrict: @@ -18,12 +22,20 @@ jobs: environment: ${{ (github.event_name == 'schedule') && 'mergebot' || '' }} steps: - name: Update viable/strict +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/update-viablestrict@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/update-viablestrict@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) id: update_viablestrict with: repository: pytorch/pytorch stable-branch: viable/strict +<<<<<<< HEAD requires: '[\"pull\", \"trunk\", \"lint\", \"^linux-binary-manywheel$\", \"^linux-binary-libtorch-release$\", \"linux-aarch64\"]' +======= + requires: '[\"pull\", \"trunk\", \"lint\", \"linux-binary\"]' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) secret-bot-token: ${{ secrets.MERGEBOT_TOKEN }} clickhouse-url: ${{ secrets.CLICKHOUSE_URL }} clickhouse-username: ${{ secrets.CLICKHOUSE_VIABLESTRICT_USERNAME }} diff --git a/.github/workflows/update_pytorch_labels.yml b/.github/workflows/update_pytorch_labels.yml index 535950b3c0b73..b7c9b016173fe 100644 --- a/.github/workflows/update_pytorch_labels.yml +++ b/.github/workflows/update_pytorch_labels.yml @@ -17,7 +17,11 @@ jobs: contents: read steps: - name: Checkout PyTorch +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.9 +======= + uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: fetch-depth: 1 submodules: false diff --git a/.github/workflows/upload-test-stats-while-running.yml b/.github/workflows/upload-test-stats-while-running.yml index 82c21467dc6a0..038bed54e5f9c 100644 --- a/.github/workflows/upload-test-stats-while-running.yml +++ b/.github/workflows/upload-test-stats-while-running.yml @@ -16,7 +16,11 @@ jobs: runs-on: linux.2xlarge steps: - name: Checkout PyTorch +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.9 +======= + uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: fetch-depth: 1 submodules: false diff --git a/.github/workflows/upload-test-stats.yml b/.github/workflows/upload-test-stats.yml index 3cfc651b2a62d..72ac9c1d140dd 100644 --- a/.github/workflows/upload-test-stats.yml +++ b/.github/workflows/upload-test-stats.yml @@ -14,7 +14,10 @@ on: - inductor-periodic - rocm - rocm-mi300 +<<<<<<< HEAD - rocm-mi355 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - inductor-micro-benchmark - inductor-micro-benchmark-x86 - inductor-cu124 @@ -58,7 +61,11 @@ jobs: run: echo "${TRIGGERING_WORKFLOW}" - name: Checkout PyTorch +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.9 +======= + uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: Configure aws credentials uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0 diff --git a/.github/workflows/upload-torch-dynamo-perf-stats.yml b/.github/workflows/upload-torch-dynamo-perf-stats.yml index db3fc72e68e92..555c7dcd7832d 100644 --- a/.github/workflows/upload-torch-dynamo-perf-stats.yml +++ b/.github/workflows/upload-torch-dynamo-perf-stats.yml @@ -32,7 +32,11 @@ jobs: name: Upload dynamo performance stats for ${{ github.event.workflow_run.id }}, attempt ${{ github.event.workflow_run.run_attempt }} steps: - name: Checkout PyTorch +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.9 +======= + uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: submodules: false fetch-depth: 1 diff --git a/.github/workflows/upload_test_stats_intermediate.yml b/.github/workflows/upload_test_stats_intermediate.yml index 1764139fed25c..5366f84b60fd8 100644 --- a/.github/workflows/upload_test_stats_intermediate.yml +++ b/.github/workflows/upload_test_stats_intermediate.yml @@ -17,7 +17,11 @@ jobs: environment: upload-stats steps: - name: Checkout PyTorch +<<<<<<< HEAD uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.9 +======= + uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: fetch-depth: 1 submodules: false diff --git a/.github/workflows/weekly.yml b/.github/workflows/weekly.yml index 2c534891c6e2d..081d6d3a93454 100644 --- a/.github/workflows/weekly.yml +++ b/.github/workflows/weekly.yml @@ -22,7 +22,11 @@ jobs: fetch-depth: 0 - name: update-xla-commit-hash continue-on-error: true +<<<<<<< HEAD uses: pytorch/test-infra/.github/actions/update-commit-hash@release/2.9 +======= + uses: pytorch/test-infra/.github/actions/update-commit-hash@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: repo-name: xla branch: master diff --git a/.github/workflows/xpu.yml b/.github/workflows/xpu.yml index 3a17bb9d70a19..cd20b01463ffd 100644 --- a/.github/workflows/xpu.yml +++ b/.github/workflows/xpu.yml @@ -5,10 +5,13 @@ on: tags: - ciflow/xpu/* workflow_dispatch: +<<<<<<< HEAD schedule: # Run 3 times on weekdays and less frequently on weekends. - cron: 45 0,8,16 * * 1-5 - cron: 45 4 * * 0,6 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} @@ -19,13 +22,18 @@ jobs: get-label-type: if: github.repository_owner == 'pytorch' name: get-label-type +<<<<<<< HEAD uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.9 +======= + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} +<<<<<<< HEAD linux-jammy-xpu-n-1-py3_10-build: name: linux-jammy-xpu-n-1-py3.10 uses: ./.github/workflows/_linux-build.yml @@ -35,6 +43,17 @@ jobs: runner_prefix: ${{ needs.get-label-type.outputs.label-type }} build-environment: linux-jammy-xpu-n-1-py3.10 docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-1-py3 +======= + linux-jammy-xpu-2025_0-py3_9-build: + name: linux-jammy-xpu-2025.0-py3.9 + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + sync-tag: linux-xpu-2025-0-build + runner_prefix: ${{ needs.get-label-type.outputs.label-type }} + build-environment: linux-jammy-xpu-2025.0-py3.9 + docker-image-name: ci-image:pytorch-linux-jammy-xpu-2025.0-py3 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) runner: linux.12xlarge test-matrix: | { include: [ @@ -47,6 +66,7 @@ jobs: ]} secrets: inherit +<<<<<<< HEAD linux-jammy-xpu-n-py3_10-build: name: linux-jammy-xpu-n-py3.10 uses: ./.github/workflows/_linux-build.yml @@ -74,10 +94,38 @@ jobs: name: linux-jammy-xpu-n-py3.10 uses: ./.github/workflows/_xpu-test.yml needs: linux-jammy-xpu-n-py3_10-build +======= + linux-jammy-xpu-2025_1-py3_9-build: + name: linux-jammy-xpu-2025.1-py3.9 + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + sync-tag: linux-xpu-2025-1-build + runner_prefix: ${{ needs.get-label-type.outputs.label-type }} + build-environment: linux-jammy-xpu-2025.1-py3.9 + docker-image-name: ci-image:pytorch-linux-jammy-xpu-2025.1-py3 + runner: linux.12xlarge + test-matrix: | + { include: [ + { config: "default", shard: 1, num_shards: 6, runner: "linux.idc.xpu" }, + { config: "default", shard: 2, num_shards: 6, runner: "linux.idc.xpu" }, + { config: "default", shard: 3, num_shards: 6, runner: "linux.idc.xpu" }, + { config: "default", shard: 4, num_shards: 6, runner: "linux.idc.xpu" }, + { config: "default", shard: 5, num_shards: 6, runner: "linux.idc.xpu" }, + { config: "default", shard: 6, num_shards: 6, runner: "linux.idc.xpu" }, + ]} + secrets: inherit + + linux-jammy-xpu-2025_1-py3_9-test: + name: linux-jammy-xpu-2025.1-py3.9 + uses: ./.github/workflows/_xpu-test.yml + needs: linux-jammy-xpu-2025_1-py3_9-build +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) permissions: id-token: write contents: read with: +<<<<<<< HEAD build-environment: linux-jammy-xpu-n-py3.10 docker-image: ${{ needs.linux-jammy-xpu-n-py3_10-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-xpu-n-py3_10-build.outputs.test-matrix }} @@ -89,11 +137,37 @@ jobs: uses: ./.github/workflows/_win-build.yml with: build-environment: win-vs2022-xpu-n-1-py3 +======= + build-environment: linux-jammy-xpu-2025.1-py3.9 + docker-image: ${{ needs.linux-jammy-xpu-2025_1-py3_9-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-xpu-2025_1-py3_9-build.outputs.test-matrix }} + secrets: inherit + + windows-xpu-2025_0-build: + if: github.repository_owner == 'pytorch' + name: win-vs2022-xpu-2025_0-py3 + uses: ./.github/workflows/_win-build.yml + with: + build-environment: win-vs2022-xpu-py3 + cuda-version: cpu + use-xpu: true + xpu-version: '2025.0' + vc-year: '2022' + secrets: inherit + + windows-xpu-2025_1-build: + if: github.repository_owner == 'pytorch' + name: win-vs2022-xpu-2025_1-py3 + uses: ./.github/workflows/_win-build.yml + with: + build-environment: win-vs2022-xpu-py3 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cuda-version: cpu use-xpu: true xpu-version: '2025.1' vc-year: '2022' secrets: inherit +<<<<<<< HEAD windows-xpu-n-build: if: github.repository_owner == 'pytorch' @@ -106,3 +180,5 @@ jobs: xpu-version: '2025.2' vc-year: '2022' secrets: inherit +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.gitignore b/.gitignore index f204868067965..47bb7a10d3c61 100644 --- a/.gitignore +++ b/.gitignore @@ -32,7 +32,10 @@ coverage.xml aten/build/ aten/src/ATen/Config.h aten/src/ATen/cuda/CUDAConfig.h +<<<<<<< HEAD aten/src/ATen/hip/HIPConfig.h +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) benchmarks/.data caffe2/cpp_test/ dist/ @@ -82,7 +85,10 @@ torch/return_types.pyi torch/nn/functional.pyi torch/utils/data/datapipes/datapipe.pyi torch/csrc/autograd/generated/* +<<<<<<< HEAD torch/csrc/functionalization/generated/* +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch/csrc/lazy/generated/*.[!m]* torch_compile_debug/ # Listed manually because some files in this directory are not generated @@ -148,9 +154,12 @@ merge_record.json torchgen/packaged/* !torchgen/packaged/README.md +<<<<<<< HEAD # This file is injected by ROCm build scripts to bootstrap in torch/__init__.py. torch/_rocm_init.py +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # IPython notebook checkpoints .ipynb_checkpoints diff --git a/.gitmodules b/.gitmodules index 4eb6e511127d0..b8457c3aad784 100644 --- a/.gitmodules +++ b/.gitmodules @@ -129,6 +129,9 @@ [submodule "third_party/flash-attention"] path = third_party/flash-attention url = https://github.com/Dao-AILab/flash-attention.git +<<<<<<< HEAD [submodule "third_party/aiter"] path = third_party/aiter url = https://github.com/ROCm/aiter.git +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/.lintrunner.toml b/.lintrunner.toml index 944829fa38977..7a98ea90e076e 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -39,6 +39,7 @@ init_command = [ 'python3', 'tools/linter/adapters/pip_init.py', '--dry-run={{DRYRUN}}', +<<<<<<< HEAD 'flake8==7.3.0', 'flake8-bugbear==24.12.12', 'flake8-comprehensions==3.16.0', @@ -49,6 +50,18 @@ init_command = [ 'mccabe==0.7.0', 'pycodestyle==2.14.0', 'pyflakes==3.4.0', +======= + 'flake8==6.1.0', + 'flake8-bugbear==23.3.23', + 'flake8-comprehensions==3.15.0', + 'flake8-executable==2.1.3', + 'flake8-logging-format==0.9.0', + 'flake8-pyi==23.3.1', + 'flake8-simplify==0.19.3', + 'mccabe==0.7.0', + 'pycodestyle==2.11.1', + 'pyflakes==3.1.0', +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) 'torchfix==0.4.0 ; python_version >= "3.9" and python_version < "3.13"', ] @@ -122,7 +135,10 @@ is_formatter = true [[linter]] code = 'MYPY' include_patterns = [ +<<<<<<< HEAD 'setup.py', +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) 'torch/**/*.py', 'torch/**/*.pyi', 'caffe2/**/*.py', @@ -132,7 +148,11 @@ include_patterns = [ 'test/test_complex.py', 'test/test_datapipe.py', 'test/test_futures.py', +<<<<<<< HEAD 'test/test_numpy_interop.py', +======= + # 'test/test_numpy_interop.py', +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) 'test/test_torch.py', 'test/test_type_hints.py', 'test/test_type_info.py', @@ -158,16 +178,27 @@ init_command = [ 'mypy==1.16.0', 'sympy==1.13.3', 'types-requests==2.27.25', +<<<<<<< HEAD 'types-pyyaml==6.0.2', +======= + 'types-pyyaml==6.0.1', +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) 'types-tabulate==0.8.8', 'types-protobuf==5.29.1.20250403', 'types-setuptools==79.0.0.20250422', 'types-jinja2==2.11.9', 'types-colorama==0.4.6', +<<<<<<< HEAD 'filelock==3.18.0', 'junitparser==2.1.1', 'rich==14.1.0', 'pyyaml==6.0.2', +======= + 'filelock==3.13.1', + 'junitparser==2.1.1', + 'rich==10.9.0', + 'pyyaml==6.0.1', +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) 'optree==0.13.0', 'dataclasses-json==0.6.7', 'pandas==2.2.3', @@ -231,8 +262,12 @@ include_patterns = [ 'c10/**/*.cpp', 'c10/**/*.h', 'torch/*.h', +<<<<<<< HEAD 'torch/_inductor/codegen/aoti_runtime/*.h', 'torch/_inductor/codegen/aoti_runtime/*.cpp', +======= + 'torch/_inductor/codegen/aoti_runtime/interface.cpp', +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) 'torch/csrc/*.h', 'torch/csrc/*.cpp', 'torch/csrc/**/*.h', @@ -500,7 +535,11 @@ include_patterns = [ '**/*.h', ] exclude_patterns = [ +<<<<<<< HEAD 'torch/headeronly/macros/Macros.h', +======= + 'c10/macros/Macros.h', +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] command = [ 'python3', @@ -523,7 +562,11 @@ include_patterns = [ '**/*.h', ] exclude_patterns = [ +<<<<<<< HEAD 'torch/headeronly/macros/Macros.h', +======= + 'c10/macros/Macros.h', +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] command = [ 'python3', @@ -583,7 +626,11 @@ exclude_patterns = [ command = [ 'python3', 'tools/linter/adapters/grep_linter.py', +<<<<<<< HEAD '--pattern=#include >>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) '--allowlist-pattern=#include ', '--linter-name=PYBIND11_INCLUDE', '--match-first-only', @@ -1111,7 +1158,11 @@ init_command = [ 'python3', 'tools/linter/adapters/pip_init.py', '--dry-run={{DRYRUN}}', +<<<<<<< HEAD 'pyyaml==6.0.2', +======= + 'PyYAML==6.0.1', +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] [[linter]] @@ -1133,7 +1184,11 @@ init_command = [ 'python3', 'tools/linter/adapters/pip_init.py', '--dry-run={{DRYRUN}}', +<<<<<<< HEAD 'pyyaml==6.0.2', +======= + 'PyYAML==6.0.1', +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] [[linter]] @@ -1158,6 +1213,7 @@ exclude_patterns = [ 'torch/_vendor/**', 'torch/_inductor/fx_passes/serialized_patterns/**', 'torch/_inductor/autoheuristic/artifacts/**', +<<<<<<< HEAD 'torch/utils/model_dump/preact.mjs', # These files are all grandfathered in, feel free to remove from this list # as necessary @@ -1165,6 +1221,31 @@ exclude_patterns = [ 'aten/src/ATen/native/[a-pA-P]*/**', 'aten/src/ATen/[a-mA-M]*/**', 'test/**', +======= + # These files are all grandfathered in, feel free to remove from this list + # as necessary + # NOTE: remove the patterns in the order they are listed + 'aten/**', + 'aten/src/ATen/native/**', + 'aten/src/ATen/native/q*/**', + 'aten/src/ATen/native/[a-pA-P]*/**', + 'aten/src/ATen/[a-mA-M]*/**', + 'test/**', + 'test/test_*', + 'test/[a-hA-h]*/**', + 'test/inductor/**', + 'test/dynamo/**', + 'test/distributed/**', + 'torch/**', + 'torch/_*/**', + 'torch/ao/**', + 'torch/fx/**', + 'torch/distributed/tensor/**', + 'torch/[j-o]*/**', + 'torch/utils/**', + 'torch/csrc/jit/**', + 'torch/csrc/jit/[a-o]*/**', +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] init_command = [ 'python3', @@ -1452,13 +1533,22 @@ init_command = [ 'python3', 'tools/linter/adapters/pip_init.py', '--dry-run={{DRYRUN}}', +<<<<<<< HEAD 'usort==1.0.8.post1', 'isort==6.0.1', 'ruff==0.12.9', # sync with RUFF +======= + '--no-black-binary', + 'black==23.12.1', + 'usort==1.0.8.post1', + 'isort==6.0.1', + 'ruff==0.11.13', # sync with RUFF +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] is_formatter = true [[linter]] +<<<<<<< HEAD code = 'PYPROJECT' command = [ 'python3', @@ -1503,6 +1593,8 @@ init_command = [ ] [[linter]] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) code = 'COPYRIGHT' include_patterns = ['**'] exclude_patterns = [ @@ -1589,7 +1681,11 @@ init_command = [ 'python3', 'tools/linter/adapters/pip_init.py', '--dry-run={{DRYRUN}}', +<<<<<<< HEAD 'ruff==0.12.9', # sync with PYFMT +======= + 'ruff==0.11.13', # sync with PYFMT +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] is_formatter = true @@ -1598,10 +1694,14 @@ is_formatter = true # the same line, merge conflicts should not arise in git or hg [[linter]] code = 'MERGE_CONFLICTLESS_CSV' +<<<<<<< HEAD include_patterns = [ 'benchmarks/dynamo/ci_expected_accuracy/*.csv', 'benchmarks/dynamo/pr_time_benchmarks/expected_results.csv', ] +======= +include_patterns = ['benchmarks/dynamo/ci_expected_accuracy/*.csv'] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) command = [ 'python3', 'tools/linter/adapters/no_merge_conflict_csv_linter.py', @@ -1792,6 +1892,7 @@ include_patterns = [ 'torch/header_only_apis.txt', ] is_formatter = false +<<<<<<< HEAD [[linter]] @@ -1801,3 +1902,5 @@ command = [ "python3", "tools/linter/adapters/gb_registry_linter.py", ] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/AGENTS.md b/AGENTS.md index 3d5436a02a85d..dd27ff6213af6 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,4 +1,5 @@ - This is the only AGENTS.md, there are no recursive AGENTS.md +<<<<<<< HEAD - When you are working on a bug, first create a standalone file that reproduces the bug and verify it fails in the expected way. Use this to test if your changes work. Once the change is passing, find an appropriate @@ -15,3 +16,5 @@ - git reset --hard $(cat /tmp/orig_work.txt) # NB: reset to the LOCAL branch, do NOT fetch - git stash pop - Resolve conflicts if necessary +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/BUILD.bazel b/BUILD.bazel index f13da6bfbe431..b47bb0a308fd6 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -91,8 +91,11 @@ generated_cpu_cpp = [ "aten/src/ATen/NativeMetaFunctions.h", "aten/src/ATen/RegistrationDeclarations.h", "aten/src/ATen/VmapGeneratedPlumbing.h", +<<<<<<< HEAD "aten/src/ATen/ViewMetaClasses.h", "aten/src/ATen/ViewMetaClasses.cpp", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "aten/src/ATen/core/aten_interned_strings.h", "aten/src/ATen/core/enum_tag.h", "aten/src/ATen/core/TensorBody.h", @@ -281,7 +284,10 @@ header_template_rule( "@AT_BLAS_F2C@": "0", "@AT_BLAS_USE_CBLAS_DOT@": "1", "@AT_KLEIDIAI_ENABLED@": "0", +<<<<<<< HEAD "@AT_USE_EIGEN_SPARSE@": "0", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, ) @@ -682,7 +688,10 @@ cc_library( [ "torch/*.h", "torch/csrc/**/*.h", +<<<<<<< HEAD "torch/nativert/**/*.h", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "torch/csrc/distributed/c10d/**/*.hpp", "torch/lib/libshm/*.h", ], @@ -749,7 +758,10 @@ cc_library( "torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu", "torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu", "torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp", +<<<<<<< HEAD "torch/csrc/distributed/c10d/symm_mem/cuda_mem_pool.cpp", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cu", ], )) + torch_sources, @@ -1108,7 +1120,10 @@ test_suite( "aten/src/ATen/templates/LazyNonNativeIr.h", "aten/src/ATen/templates/RegisterDispatchKey.cpp", "aten/src/ATen/templates/RegisterDispatchDefinitions.ini", +<<<<<<< HEAD "aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "aten/src/ATen/native/native_functions.yaml", "aten/src/ATen/native/tags.yaml", "aten/src/ATen/native/ts_native_functions.yaml", diff --git a/CMakeLists.txt b/CMakeLists.txt index 91181735750d6..5a659f90ed3a6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -239,9 +239,13 @@ option(USE_XPU "Use XPU" ON) cmake_dependent_option( BUILD_LAZY_CUDA_LINALG "Build cuda linalg ops as separate library" ON "USE_CUDA AND LINUX AND BUILD_PYTHON" OFF) +<<<<<<< HEAD cmake_dependent_option(USE_ROCM "Use ROCm" ON "LINUX OR WIN32" OFF) cmake_dependent_option(USE_ROCM_CK_GEMM "Use ROCm Composable Kernel for GEMMs" ON "USE_ROCM;NOT WIN32" OFF) option(USE_ROCM_CK_SDPA "Use ROCm Composable Kernel for SDPA" OFF) +======= +cmake_dependent_option(USE_ROCM "Use ROCm" ON "LINUX" OFF) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) option(CAFFE2_STATIC_LINK_CUDA "Statically link CUDA libraries" OFF) cmake_dependent_option(USE_CUDNN "Use cuDNN" ON "USE_CUDA" OFF) cmake_dependent_option(USE_STATIC_CUDNN "Use cuDNN static libraries" OFF @@ -253,6 +257,10 @@ cmake_dependent_option(USE_CUFILE "Use cuFile" ON "USE_CUDA AND NOT WIN32" OFF) option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON) option(USE_KINETO "Use Kineto profiling library" ON) option(USE_CUPTI_SO "Use CUPTI as a shared library" ON) +<<<<<<< HEAD +======= +option(USE_FAKELOWP "Use FakeLowp operators" OFF) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) option(USE_GFLAGS "Use GFLAGS" OFF) option(USE_GLOG "Use GLOG" OFF) option(USE_LITE_PROTO "Use lite protobuf instead of full." OFF) @@ -261,6 +269,7 @@ option(USE_PYTORCH_METAL "Use Metal for PyTorch iOS build" OFF) option(USE_PYTORCH_METAL_EXPORT "Export Metal models on MacOSX desktop" OFF) option(USE_NATIVE_ARCH "Use -march=native" OFF) cmake_dependent_option(USE_MPS "Use MPS for macOS build" ON "MPS_FOUND" OFF) +<<<<<<< HEAD option(USE_DISTRIBUTED "Use distributed" ON) cmake_dependent_option(USE_NCCL "Use NCCL" ON "USE_DISTRIBUTED;USE_CUDA OR USE_ROCM;UNIX;NOT APPLE" OFF) @@ -273,6 +282,16 @@ cmake_dependent_option(USE_SYSTEM_NCCL "Use system-wide NCCL" OFF "USE_NCCL" OFF) cmake_dependent_option(USE_NVSHMEM "Use NVSHMEM" ON "USE_DISTRIBUTED;USE_CUDA OR USE_ROCM;UNIX;NOT APPLE" OFF) +======= +cmake_dependent_option(USE_NCCL "Use NCCL" ON + "USE_CUDA OR USE_ROCM;UNIX;NOT APPLE" OFF) +cmake_dependent_option(USE_XCCL "Use XCCL" ON + "USE_XPU;UNIX;NOT APPLE" OFF) +cmake_dependent_option(USE_RCCL "Use RCCL" ON USE_NCCL OFF) +cmake_dependent_option(USE_STATIC_NCCL "Use static NCCL" OFF "USE_NCCL" OFF) +cmake_dependent_option(USE_SYSTEM_NCCL "Use system-wide NCCL" OFF "USE_NCCL" + OFF) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) option(USE_NNAPI "Use NNAPI" OFF) option(USE_NNPACK "Use NNPACK" ON) cmake_dependent_option(USE_NUMA "Use NUMA. Only available on Linux." ON "LINUX" @@ -289,7 +308,10 @@ option(USE_PRECOMPILED_HEADERS "Use pre-compiled headers to accelerate build." option(USE_PROF "Use profiling" OFF) option(USE_PYTORCH_QNNPACK "Use ATen/QNNPACK (quantized 8-bit operators)" ON) option(USE_SNPE "Use Qualcomm's SNPE library" OFF) +<<<<<<< HEAD option(USE_EIGEN_SPARSE "Use Eigen Sparse Matrices" OFF) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) option(USE_SYSTEM_EIGEN_INSTALL "Use system Eigen instead of the one under third_party" OFF) cmake_dependent_option( @@ -326,6 +348,10 @@ set(MKLDNN_ENABLE_CONCURRENT_EXEC ${USE_MKLDNN}) cmake_dependent_option(USE_MKLDNN_CBLAS "Use CBLAS in MKLDNN" OFF "USE_MKLDNN" OFF) option(USE_STATIC_MKL "Prefer to link with MKL statically (Unix only)" OFF) +<<<<<<< HEAD +======= +option(USE_DISTRIBUTED "Use distributed" ON) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cmake_dependent_option( USE_MPI "Use MPI for Caffe2. Only available if USE_DISTRIBUTED is on." ON "USE_DISTRIBUTED" OFF) @@ -567,7 +593,11 @@ if(MSVC) set(CMAKE_NINJA_CMCLDEPS_RC OFF) if(MSVC_Z7_OVERRIDE) # CMake set debug flags to use /Z7 +<<<<<<< HEAD set(CMAKE_MSVC_DEBUG_INFORMATION_FORMAT "$<$:Embedded>") +======= + set(CMAKE_MSVC_DEBUG_INFORMATION_FORMAT Embedded) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) endif() foreach( flag_var @@ -837,11 +867,18 @@ include(ExternalProject) # ---[ Dependencies ---[ FBGEMM doesn't work on x86 32bit and # CMAKE_SYSTEM_PROCESSOR thinks its 64bit +<<<<<<< HEAD if(USE_FBGEMM AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") message(WARNING "x64 operating system is required for FBGEMM. " "Not compiling with FBGEMM. " "Turn this warning off by USE_FBGEMM=OFF.") +======= +if(USE_FBGEMM + AND((CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND CMAKE_SIZEOF_VOID_P EQUAL + 4) + OR CMAKE_SYSTEM_PROCESSOR STREQUAL "x86")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) set(USE_FBGEMM OFF) endif() @@ -876,6 +913,7 @@ cmake_dependent_option( "(USE_CUDA AND NOT MSVC) OR USE_ROCM" OFF) +<<<<<<< HEAD cmake_dependent_option( USE_FBGEMM_GENAI "Whether to build FBGEMM GenAI quantized GEMM kernels.\ @@ -895,6 +933,8 @@ if(USE_CUDA AND "$ENV{TORCH_CUDA_ARCH_LIST}" MATCHES "10.0a") set(USE_FBGEMM_GENAI ON) endif() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # CAVEAT: Again, Flash Attention2 will error while building for sm52 while Mem # Eff Attention won't cmake_dependent_option( @@ -928,10 +968,13 @@ if(USE_FBGEMM) string(APPEND CMAKE_CXX_FLAGS " -DUSE_FBGEMM") endif() +<<<<<<< HEAD if(USE_FBGEMM_GENAI) string(APPEND CMAKE_CXX_FLAGS " -DUSE_FBGEMM_GENAI") endif() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if(USE_PYTORCH_QNNPACK) string(APPEND CMAKE_CXX_FLAGS " -DUSE_PYTORCH_QNNPACK") endif() @@ -1208,7 +1251,11 @@ if(APPLE) string( APPEND CMAKE_SHARED_LINKER_FLAGS +<<<<<<< HEAD " -weak_framework Foundation -weak_framework MetalPerformanceShaders -weak_framework MetalPerformanceShadersGraph -weak_framework Metal -weak_framework IOKit" +======= + " -weak_framework Foundation -weak_framework MetalPerformanceShaders -weak_framework MetalPerformanceShadersGraph -weak_framework Metal" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # To suppress MPSGraph availability warnings append_cxx_flag_if_supported("-Wno-unguarded-availability-new" @@ -1217,6 +1264,13 @@ if(APPLE) append_cxx_flag_if_supported("-Wno-missing-braces" CMAKE_CXX_FLAGS) endif() +<<<<<<< HEAD +======= +if(USE_XPU) + string(APPEND CMAKE_CXX_FLAGS " -DUSE_XPU") +endif() + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if(EMSCRIPTEN) string( APPEND @@ -1268,7 +1322,10 @@ if(USE_MIMALLOC AND USE_MIMALLOC_ON_MKL) endif() # ---[ Main build +<<<<<<< HEAD add_subdirectory(torch/headeronly) # headeronly headers +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) add_subdirectory(c10) add_subdirectory(caffe2) diff --git a/CODEOWNERS b/CODEOWNERS index 1d91adacb0629..ee66b0dca5f78 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -14,6 +14,10 @@ /torch/csrc/autograd/ @albanD @soulitzer /torch/autograd/ @albanD @soulitzer /tools/autograd/ @albanD @soulitzer +<<<<<<< HEAD +======= +/torch/header_only_apis.txt @janeyx99 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) /torch/nn/ @albanD @jbschlosser @mikaylagawarecki /torch/optim/ @albanD @janeyx99 /test/test_public_bindings.py @albanD @@ -50,12 +54,21 @@ nn/qat/ @jerryzh168 /torch/csrc/distributed/c10d/Ops.* @kwen2501 # ONNX Export +<<<<<<< HEAD /torch/_dynamo/backends/onnxrt.py @titaiwangms @xadupre @justinchuby /torch/csrc/jit/passes/onnx.h @titaiwangms @xadupre /torch/csrc/jit/passes/onnx.cpp @titaiwangms @xadupre /torch/csrc/jit/passes/onnx/ @titaiwangms @xadupre /torch/onnx/ @titaiwangms @xadupre @justinchuby /test/onnx/ @titaiwangms @xadupre @justinchuby +======= +/torch/_dynamo/backends/onnxrt.py @wschin +/torch/csrc/jit/passes/onnx.h @titaiwangms @shubhambhokare1 +/torch/csrc/jit/passes/onnx.cpp @titaiwangms @shubhambhokare1 +/torch/csrc/jit/passes/onnx/ @titaiwangms @shubhambhokare1 +/torch/onnx/ @titaiwangms @shubhambhokare1 @justinchuby @wschin +/test/onnx/ @titaiwangms @shubhambhokare1 @justinchuby @wschin +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # CI /.ci @pytorch/pytorch-dev-infra @@ -135,7 +148,11 @@ torch/profiler/ @sraikund16 test/functorch/test_aotdispatch.py @ezyang @Chillee # Dataloader +<<<<<<< HEAD torch/utils/data/ @divyanshk @ramanishsingh @scotts +======= +torch/utils/data/ @divyanshk @ramanishsingh +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # hipify torch/utils/hipify/ @jeffdaily @jithunnair-amd @@ -164,7 +181,10 @@ caffe2/utils/hip @jeffdaily @jithunnair-amd # torch.export /torch/export/ @avikchaudhuri @tugsbayasgalan @zhxchen17 @ydwu4 @angelayi /torch/_export/ @avikchaudhuri @tugsbayasgalan @zhxchen17 @ydwu4 @angelayi +<<<<<<< HEAD /torch/_export/serde/schema.py @SherlockNoMad @zhxchen17 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Dynamic Shapes /torch/fx/experimental/symbolic_shapes.py @bobrenjc93 @laithsakka @@ -196,8 +216,11 @@ torch/backends/cudnn/ @eqy @syed-ahmed /torch/utils/_cxx_pytree.py @XuehaiPan /torch/utils/pytree/ @XuehaiPan /torch/_dynamo/polyfills/pytree.py @XuehaiPan +<<<<<<< HEAD # Relating to libtorch ABI /torch/csrc/stable/ @janeyx99 @mikaylagawarecki /torch/headeronly/ @janeyx99 /torch/header_only_apis.txt @janeyx99 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 9d2b5d3553910..9f5edc4027483 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -88,6 +88,7 @@ source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows * If you want to have no-op incremental rebuilds (which are fast), see [Make no-op build fast](#make-no-op-build-fast) below. +<<<<<<< HEAD * When installing with `python -m pip install -e . -v --no-build-isolation` (in contrast to `python -m pip install . -v --no-build-isolation`) Python runtime will use the current local source-tree when importing `torch` package. (This is done by creating [`.egg-link`](https://wiki.python.org/moin/PythonPackagingTerminology#egg-link) file in `site-packages` folder) This way you do not need to repeatedly install after modifying Python files (`.py`). @@ -101,6 +102,22 @@ source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows ``` Afterwards rebuilding a library (for example to rebuild `libtorch_cpu.so` issue `ninja torch_cpu` from `build` folder), would be sufficient to make change visible in `torch` package. +======= +* When installing with `python setup.py develop` (in contrast to `python setup.py install`) Python runtime will use + the current local source-tree when importing `torch` package. (This is done by creating [`.egg-link`](https://wiki.python.org/moin/PythonPackagingTerminology#egg-link) file in `site-packages` folder) + This way you do not need to repeatedly install after modifying Python files (`.py`). + However, you would need to reinstall if you modify Python interface (`.pyi`, `.pyi.in`) or + non-Python files (`.cpp`, `.cc`, `.cu`, `.h`, ...). + + + One way to avoid running `python setup.py develop` every time one makes a change to C++/CUDA/ObjectiveC files on Linux/Mac, + is to create a symbolic link from `build` folder to `torch/lib`, for example, by issuing following: + ```bash + pushd torch/lib; sh -c "ln -sf ../../build/lib/libtorch_cpu.* ."; popd + ``` + Afterwards rebuilding a library (for example to rebuild `libtorch_cpu.so` issue `ninja torch_cpu` from `build` folder), + would be sufficient to make change visible in `torch` package. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) To reinstall, first uninstall all existing PyTorch installs. You may need to run `pip @@ -114,9 +131,15 @@ source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows pip uninstall torch ``` +<<<<<<< HEAD Next run `python setup.py clean`. After that, you can install in editable mode again. * If you run into errors when running `python -m pip install -e . -v --no-build-isolation`, here are some debugging steps: +======= + Next run `python setup.py clean`. After that, you can install in `develop` mode again. + +* If you run into errors when running `python setup.py develop`, here are some debugging steps: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) 1. Run `printf '#include \nint main() { printf("Hello World");}'|clang -x c -; ./a.out` to make sure your CMake works and can compile this simple Hello World program without errors. 2. Nuke your `build` directory. The `setup.py` script compiles binaries into the `build` folder and caches many @@ -129,6 +152,7 @@ source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows git clean -xdf python setup.py clean git submodule update --init --recursive +<<<<<<< HEAD python -m pip install --group dev python -m pip install --no-build-isolation -v -e . ``` @@ -143,6 +167,15 @@ source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows python -m pip install --no-build-isolation -v -e . ``` +======= + python setup.py develop + ``` + 4. The main step within `python setup.py develop` is running `make` from the `build` directory. If you want to + experiment with some environment variables, you can pass them into the command: + ```bash + ENV_KEY1=ENV_VAL1[, ENV_KEY2=ENV_VAL2]* python setup.py develop + ``` +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) * If you run into issue running `git submodule update --init --recursive`. Please try the following: - If you encounter an error such as @@ -259,7 +292,10 @@ dependencies as well as the nightly binaries into the repo directory. support for PyTorch. * [tools](tools) - Code generation scripts for the PyTorch library. See [README](tools/README.md) of this directory for more details. +<<<<<<< HEAD * [torchgen](torchgen) - contains the logic and tooling for generating PyTorch's low-level C++ and Python bindings from operator definitions, typically specified in native_functions.yaml +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) * [test](test) - Python unit tests for PyTorch Python frontend. * [test_torch.py](test/test_torch.py) - Basic tests for PyTorch functionality. @@ -295,7 +331,11 @@ The following packages should be installed with `pip`: - `pytest` - recommended to run tests more selectively Running ``` +<<<<<<< HEAD pip install --group dev +======= +pip install -r requirements.txt +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ``` will install these dependencies for you. @@ -646,9 +686,15 @@ can be selected interactively with your mouse to zoom in on a particular part of the program execution timeline. The `--native` command-line option tells `py-spy` to record stack frame entries for PyTorch C++ code. To get line numbers for C++ code it may be necessary to compile PyTorch in debug mode by prepending +<<<<<<< HEAD your `python -m pip install -e . -v --no-build-isolation` call to compile PyTorch with `DEBUG=1`. Depending on your operating system it may also be necessary to run `py-spy` with root privileges. +======= +your `setup.py develop` call to compile PyTorch with `DEBUG=1`. Depending on +your operating system it may also be necessary to run `py-spy` with root +privileges. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) `py-spy` can also work in an `htop`-like "live profiling" mode and can be tweaked to adjust the stack sampling rate, see the `py-spy` readme for more @@ -656,10 +702,17 @@ details. ## Managing multiple build trees +<<<<<<< HEAD One downside to using `python -m pip install -e . -v --no-build-isolation` is that your development version of PyTorch will be installed globally on your account (e.g., if you run `import torch` anywhere else, the development version will be used). +======= +One downside to using `python setup.py develop` is that your development +version of PyTorch will be installed globally on your account (e.g., if +you run `import torch` anywhere else, the development version will be +used). +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) If you want to manage multiple builds of PyTorch, you can make use of [venv environments](https://docs.python.org/3/library/venv.html) to maintain @@ -670,7 +723,11 @@ specific build of PyTorch. To set one up: python -m venv pytorch-myfeature source pytorch-myfeature/bin/activate # or `& .\pytorch-myfeature\Scripts\Activate.ps1` on Windows # if you run python now, torch will NOT be installed +<<<<<<< HEAD python -m pip install --no-build-isolation -v -e . +======= +python setup.py develop +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ``` ## C++ development tips @@ -708,9 +765,13 @@ variables `DEBUG`, `USE_DISTRIBUTED`, `USE_MKLDNN`, `USE_CUDA`, `USE_FLASH_ATTEN For example: ```bash +<<<<<<< HEAD DEBUG=1 USE_DISTRIBUTED=0 USE_MKLDNN=0 USE_CUDA=0 BUILD_TEST=0 \ USE_FBGEMM=0 USE_NNPACK=0 USE_QNNPACK=0 USE_XNNPACK=0 \ python -m pip install --no-build-isolation -v -e . +======= +DEBUG=1 USE_DISTRIBUTED=0 USE_MKLDNN=0 USE_CUDA=0 BUILD_TEST=0 USE_FBGEMM=0 USE_NNPACK=0 USE_QNNPACK=0 USE_XNNPACK=0 python setup.py develop +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ``` For subsequent builds (i.e., when `build/CMakeCache.txt` exists), the build @@ -720,7 +781,11 @@ options. ### Code completion and IDE support +<<<<<<< HEAD When using `python -m pip install -e . -v --no-build-isolation`, PyTorch will generate +======= +When using `python setup.py develop`, PyTorch will generate +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) a `compile_commands.json` file that can be used by many editors to provide command completion and error highlighting for PyTorch's C++ code. You need to `pip install ninja` to generate accurate @@ -781,7 +846,11 @@ If not, you can define these variables on the command line before invoking `setu export CMAKE_C_COMPILER_LAUNCHER=ccache export CMAKE_CXX_COMPILER_LAUNCHER=ccache export CMAKE_CUDA_COMPILER_LAUNCHER=ccache +<<<<<<< HEAD python -m pip install --no-build-isolation -v -e . +======= +python setup.py develop +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ``` #### Use a faster linker @@ -794,7 +863,11 @@ If you are editing a single file and rebuilding in a tight loop, the time spent Starting with CMake 3.29, you can specify the linker type using the [`CMAKE_LINKER_TYPE`](https://cmake.org/cmake/help/latest/variable/CMAKE_LINKER_TYPE.html) variable. For example, with `mold` installed: ```sh +<<<<<<< HEAD CMAKE_LINKER_TYPE=MOLD python -m pip install --no-build-isolation -v -e . +======= +CMAKE_LINKER_TYPE=MOLD python setup.py develop +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ``` #### Use pre-compiled headers @@ -806,7 +879,11 @@ setting `USE_PRECOMPILED_HEADERS=1` either on first setup, or in the `CMakeCache.txt` file. ```sh +<<<<<<< HEAD USE_PRECOMPILED_HEADERS=1 python -m pip install --no-build-isolation -v -e . +======= +USE_PRECOMPILED_HEADERS=1 python setup.py develop +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ``` This adds a build step where the compiler takes `` and essentially @@ -829,7 +906,11 @@ A compiler-wrapper to fix this is provided in `tools/nvcc_fix_deps.py`. You can this as a compiler launcher, similar to `ccache` ```bash export CMAKE_CUDA_COMPILER_LAUNCHER="python;`pwd`/tools/nvcc_fix_deps.py;ccache" +<<<<<<< HEAD python -m pip install --no-build-isolation -v -e . +======= +python setup.py develop +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ``` ### Rebuild few files with debug information @@ -1180,7 +1261,11 @@ build_with_asan() CFLAGS="-fsanitize=address -fno-sanitize-recover=all -shared-libasan -pthread" \ CXX_FLAGS="-pthread" \ USE_CUDA=0 USE_OPENMP=0 USE_DISTRIBUTED=0 DEBUG=1 \ +<<<<<<< HEAD python -m pip install --no-build-isolation -v -e . +======= + python setup.py develop +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } run_with_asan() diff --git a/Dockerfile b/Dockerfile index 331cf00593cb2..dceb1b1bb9663 100644 --- a/Dockerfile +++ b/Dockerfile @@ -33,7 +33,11 @@ RUN case ${TARGETPLATFORM} in \ *) MINICONDA_ARCH=x86_64 ;; \ esac && \ curl -fsSL -v -o ~/miniconda.sh -O "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-${MINICONDA_ARCH}.sh" +<<<<<<< HEAD COPY requirements.txt requirements-build.txt . +======= +COPY requirements.txt . +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Manually invoke bash on miniconda script per https://github.com/conda/conda/issues/10431 RUN chmod +x ~/miniconda.sh && \ bash ~/miniconda.sh -b -p /opt/conda && \ @@ -47,6 +51,7 @@ WORKDIR /opt/pytorch COPY . . RUN git submodule update --init --recursive +<<<<<<< HEAD FROM conda as conda-installs ARG PYTHON_VERSION=3.11 ARG CUDA_PATH=cu121 @@ -54,6 +59,28 @@ ARG INSTALL_CHANNEL=whl/nightly # Automatically set by buildx # pinning version of conda here see: https://github.com/pytorch/pytorch/issues/164574 RUN /opt/conda/bin/conda install -y python=${PYTHON_VERSION} conda=25.7.0 +======= +FROM conda as build +ARG CMAKE_VARS +WORKDIR /opt/pytorch +COPY --from=conda /opt/conda /opt/conda +COPY --from=submodule-update /opt/pytorch /opt/pytorch +RUN make triton +RUN --mount=type=cache,target=/opt/ccache \ + export eval ${CMAKE_VARS} && \ + TORCH_CUDA_ARCH_LIST="7.0 7.2 7.5 8.0 8.6 8.7 8.9 9.0 9.0a" TORCH_NVCC_FLAGS="-Xfatbin -compress-all" \ + CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" \ + python setup.py install + +FROM conda as conda-installs +ARG PYTHON_VERSION=3.11 +ARG CUDA_PATH=cu121 +ARG CUDA_CHANNEL=nvidia +ARG INSTALL_CHANNEL=whl/nightly +# Automatically set by buildx +RUN /opt/conda/bin/conda update -y -n base -c defaults conda +RUN /opt/conda/bin/conda install -y python=${PYTHON_VERSION} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ARG TARGETPLATFORM @@ -96,5 +123,9 @@ WORKDIR /workspace FROM official as dev # Should override the already installed version from the official-image stage +<<<<<<< HEAD COPY --from=conda /opt/conda /opt/conda COPY --from=submodule-update /opt/pytorch /opt/pytorch +======= +COPY --from=build /opt/conda /opt/conda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/MANIFEST.in b/MANIFEST.in index ec00f251160b7..01ed986a3d4db 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,3 +1,4 @@ +<<<<<<< HEAD # Reference: https://setuptools.pypa.io/en/latest/userguide/miscellaneous.html # Include source files in SDist @@ -48,3 +49,36 @@ global-exclude *.o *.obj *.so *.a *.dylib *.pxd *.dll *.lib *.py[cod] prune */.git global-exclude .git *~ *.swp +======= +include MANIFEST.in +include CMakeLists.txt +include CITATION.cff +include LICENSE +include NOTICE +include .gitmodules +include build_variables.bzl +include mypy.ini +include requirements.txt +include ufunc_defs.bzl +include version.txt +recursive-include android *.* +recursive-include aten *.* +recursive-include binaries *.* +recursive-include c10 *.* +recursive-include caffe2 *.* +recursive-include cmake *.* +recursive-include torch *.* +recursive-include tools *.* +recursive-include test *.* +recursive-include docs *.* +recursive-include ios *.* +recursive-include third_party * +recursive-include test *.* +recursive-include benchmarks *.* +recursive-include scripts *.* +recursive-include mypy_plugins *.* +recursive-include modules *.* +recursive-include functorch *.* +prune */__pycache__ +global-exclude *.o *.so *.dylib *.a .git *.pyc *.swp +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/Makefile b/Makefile index 3db2b7aa44e76..3dc907d125f8f 100644 --- a/Makefile +++ b/Makefile @@ -57,8 +57,12 @@ setup-env-cuda: setup-env-rocm: $(MAKE) setup-env PYTHON="$(PYTHON)" NIGHTLY_TOOL_OPTS="$(NIGHTLY_TOOL_OPTS) --rocm" +<<<<<<< HEAD .PHONY: setup-lint setup-lint .lintbin/.lintrunner.sha256: requirements.txt pyproject.toml .lintrunner.toml +======= +.lintbin/.lintrunner.sha256: requirements.txt pyproject.toml .lintrunner.toml +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @echo "Setting up lintrunner..." $(PIP) install lintrunner lintrunner init @@ -66,6 +70,12 @@ setup-lint .lintbin/.lintrunner.sha256: requirements.txt pyproject.toml .lintrun @mkdir -p .lintbin @sha256sum requirements.txt pyproject.toml .lintrunner.toml > .lintbin/.lintrunner.sha256 +<<<<<<< HEAD +======= +.PHONY: setup-lint +setup-lint: .lintbin/.lintrunner.sha256 + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .PHONY: lazy-setup-lint lazy-setup-lint: .lintbin/.lintrunner.sha256 @if [ ! -x "$(shell command -v lintrunner)" ]; then \ diff --git a/README.md b/README.md index 99e6dabd16181..79206f3c9747a 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,8 @@ +<<<<<<< HEAD ![PyTorch Logo](https://github.com/pytorch/pytorch/blob/9708fcf92db88b80b9010c68662d634434da3106/docs/source/_static/img/pytorch-logo-dark.png) +======= +![PyTorch Logo](https://github.com/pytorch/pytorch/raw/main/docs/source/_static/img/pytorch-logo-dark.png) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -------------------------------------------------------------------------------- @@ -72,7 +76,11 @@ Elaborating Further: If you use NumPy, then you have used Tensors (a.k.a. ndarray). +<<<<<<< HEAD ![Tensor illustration](https://github.com/pytorch/pytorch/blob/9708fcf92db88b80b9010c68662d634434da3106/docs/source/_static/img/tensor_illustration.png) +======= +![Tensor illustration](./docs/source/_static/img/tensor_illustration.png) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) PyTorch provides Tensors that can live either on the CPU or the GPU and accelerates the computation by a huge amount. @@ -99,7 +107,11 @@ from several research papers on this topic, as well as current and past work suc While this technique is not unique to PyTorch, it's one of the fastest implementations of it to date. You get the best of speed and flexibility for your crazy research. +<<<<<<< HEAD ![Dynamic graph](https://github.com/pytorch/pytorch/blob/9708fcf92db88b80b9010c68662d634434da3106/docs/source/_static/img/dynamic_graph.gif) +======= +![Dynamic graph](https://github.com/pytorch/pytorch/raw/main/docs/source/_static/img/dynamic_graph.gif) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ### Python First @@ -200,7 +212,11 @@ If you want to compile with CUDA support, [select a supported version of CUDA fr - [NVIDIA cuDNN](https://developer.nvidia.com/cudnn) v8.5 or above - [Compiler](https://gist.github.com/ax3l/9489132) compatible with CUDA +<<<<<<< HEAD Note: You could refer to the [cuDNN Support Matrix](https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html) for cuDNN versions with the various supported CUDA, CUDA driver, and NVIDIA hardware. +======= +Note: You could refer to the [cuDNN Support Matrix](https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html) for cuDNN versions with the various supported CUDA, CUDA driver and NVIDIA hardware +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) If you want to disable CUDA support, export the environment variable `USE_CUDA=0`. Other potentially useful environment variables may be found in `setup.py`. If @@ -228,7 +244,10 @@ If you want to disable Intel GPU support, export the environment variable `USE_X Other potentially useful environment variables may be found in `setup.py`. #### Get the PyTorch Source +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ```bash git clone https://github.com/pytorch/pytorch cd pytorch @@ -242,8 +261,14 @@ git submodule update --init --recursive **Common** ```bash +<<<<<<< HEAD # Run this command from the PyTorch directory after cloning the source code using the ā€œGet the PyTorch Sourceā€œ section above pip install --group dev +======= +conda install cmake ninja +# Run this command from the PyTorch directory after cloning the source code using the ā€œGet the PyTorch Sourceā€œ section below +pip install -r requirements.txt +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ``` **On Linux** @@ -275,6 +300,7 @@ conda install pkg-config libuv pip install mkl-static mkl-include # Add these packages if torch.distributed is needed. # Distributed package support on Windows is a prototype feature and is subject to changes. +<<<<<<< HEAD conda install -c conda-forge libuv ``` @@ -284,22 +310,41 @@ conda install -c conda-forge libuv If you're compiling for AMD ROCm then first run this command: +======= +conda install -c conda-forge libuv=1.39 +``` + +#### Install PyTorch +**On Linux** + +If you're compiling for AMD ROCm then first run this command: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ```bash # Only run this if you're compiling for ROCm python tools/amd_build/build_amd.py ``` Install PyTorch +<<<<<<< HEAD ```bash export CMAKE_PREFIX_PATH="${CONDA_PREFIX:-'$(dirname $(which conda))/../'}:${CMAKE_PREFIX_PATH}" python -m pip install --no-build-isolation -v -e . +======= +```bash +export CMAKE_PREFIX_PATH="${CONDA_PREFIX:-'$(dirname $(which conda))/../'}:${CMAKE_PREFIX_PATH}" +python setup.py develop +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ``` **On macOS** ```bash +<<<<<<< HEAD python -m pip install --no-build-isolation -v -e . +======= +python3 setup.py develop +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ``` **On Windows** @@ -311,7 +356,11 @@ If you want to build legacy python code, please refer to [Building on legacy cod In this mode PyTorch computations will run on your CPU, not your GPU. ```cmd +<<<<<<< HEAD python -m pip install --no-build-isolation -v -e . +======= +python setup.py develop +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ``` Note on OpenMP: The desired OpenMP implementation is Intel OpenMP (iomp). In order to link against iomp, you'll need to manually download the library and set up the building environment by tweaking `CMAKE_INCLUDE_PATH` and `LIB`. The instruction [here](https://github.com/pytorch/pytorch/blob/main/docs/source/notes/windows.rst#building-from-source) is an example for setting up both MKL and Intel OpenMP. Without these configurations for CMake, Microsoft Visual C OpenMP runtime (vcomp) will be used. @@ -332,6 +381,10 @@ Additional libraries such as You can refer to the [build_pytorch.bat](https://github.com/pytorch/pytorch/blob/main/.ci/pytorch/win-test-helpers/build_pytorch.bat) script for some other environment variables configurations +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ```cmd cmd @@ -351,7 +404,12 @@ for /f "usebackq tokens=*" %i in (`"%ProgramFiles(x86)%\Microsoft Visual Studio\ :: [Optional] If you want to override the CUDA host compiler set CUDAHOSTCXX=C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Tools\MSVC\14.27.29110\bin\HostX64\x64\cl.exe +<<<<<<< HEAD python -m pip install --no-build-isolation -v -e . +======= +python setup.py develop + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ``` **Intel GPU builds** @@ -373,7 +431,11 @@ if defined CMAKE_PREFIX_PATH ( set "CMAKE_PREFIX_PATH=%CONDA_PREFIX%\Library" ) +<<<<<<< HEAD python -m pip install --no-build-isolation -v -e . +======= +python setup.py develop +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ``` ##### Adjust Build Options (Optional) @@ -383,7 +445,10 @@ the following. For example, adjusting the pre-detected directories for CuDNN or with such a step. On Linux +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ```bash export CMAKE_PREFIX_PATH="${CONDA_PREFIX:-'$(dirname $(which conda))/../'}:${CMAKE_PREFIX_PATH}" CMAKE_ONLY=1 python setup.py build @@ -391,10 +456,16 @@ ccmake build # or cmake-gui build ``` On macOS +<<<<<<< HEAD ```bash export CMAKE_PREFIX_PATH="${CONDA_PREFIX:-'$(dirname $(which conda))/../'}:${CMAKE_PREFIX_PATH}" MACOSX_DEPLOYMENT_TARGET=11.0 CMAKE_ONLY=1 python setup.py build +======= +```bash +export CMAKE_PREFIX_PATH="${CONDA_PREFIX:-'$(dirname $(which conda))/../'}:${CMAKE_PREFIX_PATH}" +MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ CMAKE_ONLY=1 python setup.py build +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ccmake build # or cmake-gui build ``` @@ -517,7 +588,11 @@ on [our website](https://pytorch.org/get-started/previous-versions). ## Getting Started +<<<<<<< HEAD Three pointers to get you started: +======= +Three-pointers to get you started: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - [Tutorials: get you started with understanding and using PyTorch](https://pytorch.org/tutorials/) - [Examples: easy to understand PyTorch code across all domains](https://github.com/pytorch/examples) - [The API Reference](https://pytorch.org/docs/) @@ -559,7 +634,11 @@ To learn more about making a contribution to Pytorch, please see our [Contributi PyTorch is a community-driven project with several skillful engineers and researchers contributing to it. +<<<<<<< HEAD PyTorch is currently maintained by [Soumith Chintala](http://soumith.ch), [Gregory Chanan](https://github.com/gchanan), [Dmytro Dzhulgakov](https://github.com/dzhulgakov), [Edward Yang](https://github.com/ezyang), [Alban Desmaison](https://github.com/albanD), [Piotr Bialecki](https://github.com/ptrblck) and [Nikita Shulga](https://github.com/malfet) with major contributions coming from hundreds of talented individuals in various forms and means. +======= +PyTorch is currently maintained by [Soumith Chintala](http://soumith.ch), [Gregory Chanan](https://github.com/gchanan), [Dmytro Dzhulgakov](https://github.com/dzhulgakov), [Edward Yang](https://github.com/ezyang), and [Nikita Shulga](https://github.com/malfet) with major contributions coming from hundreds of talented individuals in various forms and means. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) A non-exhaustive but growing list needs to mention: [Trevor Killeen](https://github.com/killeent), [Sasank Chilamkurthy](https://github.com/chsasank), [Sergey Zagoruyko](https://github.com/szagoruyko), [Adam Lerer](https://github.com/adamlerer), [Francisco Massa](https://github.com/fmassa), [Alykhan Tejani](https://github.com/alykhantejani), [Luca Antiga](https://github.com/lantiga), [Alban Desmaison](https://github.com/albanD), [Andreas Koepf](https://github.com/andreaskoepf), [James Bradbury](https://github.com/jekbradbury), [Zeming Lin](https://github.com/ebetica), [Yuandong Tian](https://github.com/yuandong-tian), [Guillaume Lample](https://github.com/glample), [Marat Dukhan](https://github.com/Maratyszcza), [Natalia Gimelshein](https://github.com/ngimel), [Christian Sarofeen](https://github.com/csarofeen), [Martin Raison](https://github.com/martinraison), [Edward Yang](https://github.com/ezyang), [Zachary Devito](https://github.com/zdevito). Note: This project is unrelated to [hughperkins/pytorch](https://github.com/hughperkins/pytorch) with the same name. Hugh is a valuable contributor to the Torch community and has helped with many things Torch and PyTorch. diff --git a/RELEASE.md b/RELEASE.md index 047bb10161f71..df3f57b29d554 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -50,7 +50,10 @@ Following is the Release Compatibility Matrix for PyTorch releases: | PyTorch version | Python | C++ | Stable CUDA | Experimental CUDA | Stable ROCm | | --- | --- | --- | --- | --- | --- | +<<<<<<< HEAD | 2.8 | >=3.9, <=3.13, (3.13t experimental) | C++17 | CUDA 12.6 (CUDNN 9.10.2.21), CUDA 12.8 (CUDNN 9.10.2.21) | CUDA 12.9 (CUDNN 9.10.2.21) | ROCm 6.4 | +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) | 2.7 | >=3.9, <=3.13, (3.13t experimental) | C++17 | CUDA 11.8 (CUDNN 9.1.0.70), CUDA 12.6 (CUDNN 9.5.1.17) | CUDA 12.8 (CUDNN 9.7.1.26) | ROCm 6.3 | | 2.6 | >=3.9, <=3.13, (3.13t experimental) | C++17 | CUDA 11.8, CUDA 12.4 (CUDNN 9.1.0.70) | CUDA 12.6 (CUDNN 9.5.1.17) | ROCm 6.2.4 | | 2.5 | >=3.9, <=3.12, (3.13 experimental) | C++17 | CUDA 11.8, CUDA 12.1, CUDA 12.4, CUDNN 9.1.0.70 | None | ROCm 6.2 | @@ -74,9 +77,15 @@ Following is the release cadence. All future dates below are tentative. For late | 2.4 | Jun 2024 | Jul 2024 | Sept 2024 | Not planned | | 2.5 | Sep 2024 | Oct 2024 | Nov 2024 | Not planned | | 2.6 | Dec 2024 | Jan 2025 | Not planned | Not planned | +<<<<<<< HEAD | 2.7 | Mar 2025 | Apr 2025 | Jun 2025 | Not planned | | 2.8 | Jun 2025 | Jul 2025 | (Aug 2025) | (Sep 2025) | | 2.9 | Sept 2025 | Oct 2025 | (Nov 2025) | (Dec 2025) | +======= +| 2.7 | Mar 2025 | Apr 2025 | (May 2025) | (Jun 2025) | +| 2.8 | Jun 2025 | Jul 2025 | (Aug 2025) | (Sep 2025) | +| 2.9 | Aug 2025 | Oct 2025 | (Nov 2025) | (Dec 2025) | +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) | 2.10 | Dec 2025 | Jan 2026 | (Feb 2026) | (Mar 2026) | | 2.11 | Mar 2026 | Apr 2026 | (Jun 2026) | (Jul 2026) | diff --git a/android/README.md b/android/README.md index f0c74750522de..102a795fed980 100644 --- a/android/README.md +++ b/android/README.md @@ -2,7 +2,11 @@ ## Demo applications and tutorials +<<<<<<< HEAD Please refer to [meta-pytorch/executorch-examples](https://github.com/meta-pytorch/executorch-examples/tree/main/dl3/android/DeepLabV3Demo) for the Android demo app based on [ExecuTorch](https://github.com/pytorch/executorch). +======= +Please refer to [pytorch-labs/executorch-examples](https://github.com/pytorch-labs/executorch-examples/tree/main/dl3/android/DeepLabV3Demo) for the Android demo app based on [ExecuTorch](https://github.com/pytorch/executorch). +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Please join our [Discord](https://discord.com/channels/1334270993966825602/1349854760299270284) for any questions. diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index b30d8336e8ec9..124691d1c79d2 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -96,8 +96,11 @@ file(GLOB native_mkldnn_cpp "native/mkldnn/*.cpp") file(GLOB vulkan_cpp "vulkan/*.cpp") file(GLOB native_vulkan_cpp "native/vulkan/*.cpp" "native/vulkan/api/*.cpp" "native/vulkan/impl/*.cpp" "native/vulkan/ops/*.cpp") +<<<<<<< HEAD file(GLOB native_eigen_cpp "native/sparse/eigen/*.cpp") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Metal file(GLOB metal_h "metal/*.h") file(GLOB metal_cpp "metal/*.cpp") @@ -121,8 +124,11 @@ file(GLOB_RECURSE native_mps_cpp "native/mps/*.cpp") file(GLOB_RECURSE native_mps_mm "native/mps/*.mm") file(GLOB_RECURSE native_mps_metal "native/mps/*.metal") file(GLOB_RECURSE native_mps_h "native/mps/*.h") +<<<<<<< HEAD file(GLOB_RECURSE native_sparse_mps_mm "native/sparse/mps/*.mm") file(GLOB_RECURSE native_mps_sparse_metal "native/sparse/mps/*.metal") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) file(GLOB native_sparse_cpp "native/sparse/*.cpp") file(GLOB native_quantized_cpp @@ -182,6 +188,7 @@ file(GLOB native_flash_attn_api_cpp "native/transformers/cuda/flash_attn/flash_a file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip") # if USE_FLASH_ATTENTION is set, ensure CK instances get generated if(USE_FLASH_ATTENTION) +<<<<<<< HEAD if("$ENV{USE_CK_FLASH_ATTENTION}" STREQUAL "1") message(STATUS "USE_CK_FLASH_ATTENTION is being deprecated. Please use USE_ROCM_CK_SDPA instead") caffe2_update_option(USE_ROCM_CK_SDPA ON) @@ -203,6 +210,24 @@ if(USE_FLASH_ATTENTION) add_subdirectory(native/transformers/hip/flash_attn/ck/fav_v3) file(GLOB flash_attention_v3_hip "native/transformers/hip/flash_attn/ck/fav_v3/*.hip") list(APPEND native_transformers_hip_hip ${flash_attention_v3_hip}) +======= + if(DEFINED ENV{USE_CK_FLASH_ATTENTION}) + set(USE_CK_FLASH_ATTENTION $ENV{USE_CK_FLASH_ATTENTION}) + if(USE_CK_FLASH_ATTENTION STREQUAL "1") + if(DEFINED ENV{PYTORCH_ROCM_ARCH}) + list(LENGTH PYTORCH_ROCM_ARCH NUM_ARCHS) + if(NUM_ARCHS GREATER 1) + message(WARNING "Building CK for multiple archs can increase build time considerably! + Consider setting PYTORCH_ROCM_ARCH env var value as the gfx arch you need to build for") + endif() + endif() + message(STATUS "USE_CK_FLASH_ATTENTION is set; building PyTorch with CK Flash Attention enabled") + message(STATUS "Generating CK kernel instances...") + add_subdirectory(native/transformers/hip/flash_attn/ck) + file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip") + list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip}) + endif() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) endif() file(GLOB flash_attention_hip_aot_hip "native/transformers/hip/flash_attn/aot/*.hip") file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip") @@ -216,7 +241,11 @@ file(GLOB mem_eff_attention_cuda_cpp "native/transformers/cuda/mem_eff_attention if(USE_CUDA AND (USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)) add_library(flash_attention OBJECT EXCLUDE_FROM_ALL ${flash_attention_cuda_kernels_cu} ${flash_attention_cuda_cpp}) +<<<<<<< HEAD target_include_directories(flash_attention SYSTEM PUBLIC +======= + target_include_directories(flash_attention PUBLIC +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ${PROJECT_SOURCE_DIR}/third_party/flash-attention/csrc ${PROJECT_SOURCE_DIR}/third_party/flash-attention/include ${PROJECT_SOURCE_DIR}/third_party/cutlass/include @@ -252,6 +281,7 @@ if(USE_MEM_EFF_ATTENTION) list(APPEND ATen_ATTENTION_KERNEL_SRCS ${mem_eff_attention_cuda_kernels_cu}) endif() +<<<<<<< HEAD # FBGEMM GenAI IF(USE_FBGEMM_GENAI) set(FBGEMM_THIRD_PARTY ${PROJECT_SOURCE_DIR}/third_party/fbgemm/external/) @@ -330,6 +360,8 @@ IF(USE_FBGEMM_GENAI) endif() endif() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # XNNPACK file(GLOB native_xnnpack "native/xnnpack/*.cpp") @@ -377,9 +409,12 @@ if(USE_VULKAN) else() set(all_cpu_cpp ${all_cpu_cpp} ${vulkan_cpp}) endif() +<<<<<<< HEAD if(USE_EIGEN_SPARSE) set(all_cpu_cpp ${all_cpu_cpp} ${native_eigen_cpp}) endif() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if(USE_MTIA) set(ATen_MTIA_SRCS ${ATen_MTIA_SRCS} ${mtia_cpp} ${mtia_h} ${native_mtia_cpp} ${native_mtia_h}) @@ -458,6 +493,7 @@ if(USE_CUDA) endif() if(USE_ROCM) +<<<<<<< HEAD if((USE_FLASH_ATTENTION AND USE_ROCM_CK_SDPA) OR USE_ROCM_CK_GEMM) # NOTE: The PyTorch build does not actually add_subdirectory # third_party/composable_kernel or use it as a CMake library. What is used @@ -487,13 +523,46 @@ if(USE_ROCM) list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/aiter/csrc/include) _pytorch_rocm_generate_ck_conf() endif() +======= + # NOTE: The PyTorch build does not actually add_subdirectory + # third_party/composable_kernel or use it as a CMake library. What is used + # is header only, so this should be ok, except that the CMake build generates + # a ck/config.h. We just do that part here. Without this, the ck.h from the + # ROCM SDK may get accidentally used instead. + function(_pytorch_rocm_generate_ck_conf) + set(CK_ENABLE_INT8 "ON") + set(CK_ENABLE_FP16 "ON") + set(CK_ENABLE_FP32 "ON") + set(CK_ENABLE_FP64 "ON") + set(CK_ENABLE_BF16 "ON") + set(CK_ENABLE_FP8 "ON") + set(CK_ENABLE_BF8 "ON") + set(CK_USE_XDL "ON") + set(CK_USE_WMMA "ON") + configure_file( + "${Torch_SOURCE_DIR}/third_party/composable_kernel/include/ck/config.h.in" + "${CMAKE_CURRENT_BINARY_DIR}/composable_kernel/ck/config.h" + ) + endfunction() + list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip) + list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include) + list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include) + list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel) + _pytorch_rocm_generate_ck_conf() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Next two lines are needed because TunableOp uses third-party/fmt list(APPEND ATen_HIP_INCLUDE $) list(APPEND ATen_HIP_DEPENDENCY_LIBS fmt::fmt-header-only) +<<<<<<< HEAD if(USE_FLASH_ATTENTION AND USE_ROCM_CK_SDPA) list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/native/transformers/hip/flash_attn/ck) endif() +======= +if(USE_FLASH_ATTENTION) + list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/native/transformers/hip/flash_attn/ck) +endif() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) list(APPEND ATen_HIP_SRCS ${ATen_HIP_SRCS} ${hip_hip} @@ -503,13 +572,20 @@ if(USE_ROCM) ${native_quantized_hip_hip} ${native_transformers_hip_hip} ${native_transformers_src_hip_hip} ) +<<<<<<< HEAD if(NOT USE_ROCM_CK_GEMM) +======= + if(WIN32) # Windows doesn't support Composable Kernels +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) file(GLOB native_hip_bgemm "native/hip/bgemm_kernels/*.hip") file(GLOB native_hip_ck "native/hip/ck*.hip") exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}" ${native_hip_bgemm} ${native_hip_ck}) endif() +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO: Codegen separate files for HIP and use those (s/cuda_generated_sources/hip_generated_sources) list(APPEND all_hip_cpp ${native_nested_hip_cpp} @@ -548,7 +624,11 @@ if(LAPACK_FOUND) # would not need this at all), some of our libraries (magma in particular) # backend to CPU BLAS/LAPACK implementations, and so it is very important # we get the *right* implementation, because even if the symbols are the +<<<<<<< HEAD # same, LAPACK implementations may have different calling conventions. +======= + # same, LAPACK implementions may have different calling conventions. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This caused https://github.com/pytorch/pytorch/issues/7353 # # We do NOT do this on Linux, since we just rely on torch_cpu to @@ -669,6 +749,7 @@ if(USE_CUDA AND NOT USE_ROCM) add_definitions(-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include) list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/tools/util/include) +<<<<<<< HEAD # Add FBGEMM_GENAI include directories for torch_ops.h if(USE_FBGEMM_GENAI) @@ -694,6 +775,26 @@ if(USE_CUDA AND NOT USE_ROCM) CUDA::cusolver_static ${CUDAToolkit_LIBRARY_DIR}/libcusolver_lapack_static.a # needed for libcusolver_static ) +======= + if($ENV{ATEN_STATIC_CUDA}) + list(APPEND ATen_CUDA_DEPENDENCY_LIBS + ${CUDA_LIBRARIES} + CUDA::cusparse_static + CUDA::cufft_static_nocallback + ) + if(NOT BUILD_LAZY_CUDA_LINALG) + if(CUDA_VERSION_MAJOR LESS_EQUAL 11) + list(APPEND ATen_CUDA_DEPENDENCY_LIBS + CUDA::cusolver_static + ${CUDAToolkit_LIBRARY_DIR}/liblapack_static.a # needed for libcusolver_static + ) + elseif(CUDA_VERSION_MAJOR GREATER_EQUAL 12) + list(APPEND ATen_CUDA_DEPENDENCY_LIBS + CUDA::cusolver_static + ${CUDAToolkit_LIBRARY_DIR}/libcusolver_lapack_static.a # needed for libcusolver_static + ) + endif() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) endif() else() list(APPEND ATen_CUDA_DEPENDENCY_LIBS @@ -758,6 +859,7 @@ endif() if(USE_MPS) include(../../../cmake/Metal.cmake) +<<<<<<< HEAD set(ATen_MPS_SRCS ${ATen_MPS_SRCS} ${mps_cpp} ${mps_mm} ${mps_h} ${native_mps_cpp} ${native_mps_mm} ${native_mps_h} ${native_sparse_mps_mm}) if(CAN_COMPILE_METAL) @@ -777,6 +879,31 @@ if(USE_MPS) else() file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/native/mps") foreach(SHADER ${native_mps_metal} ${native_mps_sparse_metal}) +======= + set(ATen_MPS_SRCS ${ATen_MPS_SRCS} ${mps_cpp} ${mps_mm} ${mps_h} ${native_mps_cpp} ${native_mps_mm} ${native_mps_h}) + + if(CAN_COMPILE_METAL) + foreach(SHADER ${native_mps_metal}) + cmake_path(GET SHADER STEM TGT_STEM) + string(CONCAT TGT_BASIC ${TGT_STEM} "_30.air") + string(CONCAT TGT_BFLOAT ${TGT_STEM} "_31.air") + list(APPEND AIR_BASIC ${TGT_BASIC}) + list(APPEND AIR_BFLOAT ${TGT_BFLOAT}) + metal_to_air(${SHADER} ${TGT_BASIC} "-std=metal3.0") + metal_to_air(${SHADER} ${TGT_BFLOAT} "-std=metal3.1") + endforeach() + air_to_metallib(kernels_basic.metallib ${AIR_BASIC}) + air_to_metallib(kernels_bfloat.metallib ${AIR_BFLOAT}) + add_custom_command( + COMMAND echo "// $$(date)" > metallib_dummy.cpp + DEPENDS kernels_basic.metallib kernels_bfloat.metallib + OUTPUT metallib_dummy.cpp + COMMENT "Updating metallibs timestamp") + add_custom_target(metallibs DEPENDS kernels_basic.metallib kernels_bfloat.metallib metallib_dummy.cpp) + else() + file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/native/mps") + foreach(SHADER ${native_mps_metal}) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cmake_path(GET SHADER STEM TGT_STEM) string(CONCAT SHADER_HDR_NAME "${CMAKE_CURRENT_BINARY_DIR}" /native/mps/ ${TGT_STEM} "_metallib.h") metal_to_metallib_h(${SHADER} ${SHADER_HDR_NAME}) diff --git a/aten/src/ATen/CPUGeneratorImpl.cpp b/aten/src/ATen/CPUGeneratorImpl.cpp index 44ad24b81755f..f27391e7ee73c 100644 --- a/aten/src/ATen/CPUGeneratorImpl.cpp +++ b/aten/src/ATen/CPUGeneratorImpl.cpp @@ -131,18 +131,36 @@ uint64_t CPUGeneratorImpl::seed() { /** * Sets the internal state of CPUGeneratorImpl. The new internal state +<<<<<<< HEAD * must be a strided CPU byte tensor and of the same size as CPUGeneratorImplState. +======= + * must be a strided CPU byte tensor and of the same size as either + * CPUGeneratorImplStateLegacy (for legacy CPU generator state) or + * CPUGeneratorImplState (for new state). + * + * FIXME: Remove support of the legacy state in the future? +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) */ void CPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) { using detail::CPUGeneratorImplState; using detail::CPUGeneratorImplStateLegacy; +<<<<<<< HEAD static_assert(std::is_standard_layout_v, "CPUGeneratorImplState is not a PODType"); constexpr size_t size = sizeof(CPUGeneratorImplState); +======= + static_assert(std::is_standard_layout_v, "CPUGeneratorImplStateLegacy is not a PODType"); + static_assert(std::is_standard_layout_v, "CPUGeneratorImplState is not a PODType"); + + static const size_t size_legacy = sizeof(CPUGeneratorImplStateLegacy); + static const size_t size_current = sizeof(CPUGeneratorImplState); + static_assert(size_legacy != size_current, "CPUGeneratorImplStateLegacy and CPUGeneratorImplState can't be of the same size"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) detail::check_rng_state(new_state); at::mt19937 engine; +<<<<<<< HEAD auto new_state_size = new_state.numel(); TORCH_CHECK(new_state_size == size, "Expected a CPUGeneratorImplState of size ", size, @@ -150,6 +168,51 @@ void CPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) { auto rng_state = new_state.data_ptr_impl(); auto legacy_pod = &(rng_state->legacy_pod); +======= + auto float_normal_sample = std::optional(); + auto double_normal_sample = std::optional(); + + // Construct the state of at::CPUGeneratorImpl based on input byte tensor size. + CPUGeneratorImplStateLegacy* legacy_pod{nullptr}; + auto new_state_size = new_state.numel(); + if (new_state_size == size_legacy) { + legacy_pod = (CPUGeneratorImplStateLegacy*)new_state.data(); + // Note that in CPUGeneratorImplStateLegacy, we didn't have float version + // of normal sample and hence we leave the std::optional as is + + // Update next_double_normal_sample. + // Note that CPUGeneratorImplStateLegacy stores two uniform values (normal_x, normal_y) + // and a rho value (normal_rho). These three values were redundant and in the new + // DistributionsHelper.h, we store the actual extra normal sample, rather than three + // intermediate values. + if (legacy_pod->normal_is_valid) { + auto r = legacy_pod->normal_rho; + auto theta = 2.0 * c10::pi * legacy_pod->normal_x; + // we return the sin version of the normal sample when in caching mode + double_normal_sample = std::optional(r * ::sin(theta)); + } + } else if (new_state_size == size_current) { + auto rng_state = (CPUGeneratorImplState*)new_state.data(); + legacy_pod = &rng_state->legacy_pod; + // update next_float_normal_sample + if (rng_state->is_next_float_normal_sample_valid) { + float_normal_sample = std::optional(rng_state->next_float_normal_sample); + } + + // Update next_double_normal_sample. + // Note that in getRNGState, we now return the actual normal sample in normal_y + // and if it's valid in normal_is_valid. The redundant normal_x and normal_rho + // are squashed to 0.0. + if (legacy_pod->normal_is_valid) { + double_normal_sample = std::optional(legacy_pod->normal_y); + } + } else { + TORCH_CHECK(false, "Expected either a CPUGeneratorImplStateLegacy of size ", size_legacy, + " or a CPUGeneratorImplState of size ", size_current, + " but found the input RNG state size to be ", new_state_size); + } + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // construct engine_ // Note that CPUGeneratorImplStateLegacy stored a state array of 64 bit uints, whereas in our // redefined mt19937, we have changed to a state array of 32 bit uints. Hence, we are @@ -163,12 +226,17 @@ void CPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) { engine.set_data(rng_data); TORCH_CHECK(engine.is_valid(), "Invalid mt19937 state"); this->engine_ = engine; +<<<<<<< HEAD this->next_float_normal_sample_ = rng_state->is_next_float_normal_sample_valid ? std::optional(rng_state->next_float_normal_sample) : std::optional(); this->next_double_normal_sample_ = legacy_pod->normal_is_valid ? std::optional(legacy_pod->normal_y) : std::optional(); +======= + this->next_float_normal_sample_ = float_normal_sample; + this->next_double_normal_sample_ = double_normal_sample; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } /** diff --git a/aten/src/ATen/Config.h.in b/aten/src/ATen/Config.h.in index 0bae6d4af6e5e..c4475dc390fce 100644 --- a/aten/src/ATen/Config.h.in +++ b/aten/src/ATen/Config.h.in @@ -20,4 +20,7 @@ #define AT_BLAS_F2C() @AT_BLAS_F2C@ #define AT_BLAS_USE_CBLAS_DOT() @AT_BLAS_USE_CBLAS_DOT@ #define AT_KLEIDIAI_ENABLED() @AT_KLEIDIAI_ENABLED@ +<<<<<<< HEAD #define AT_USE_EIGEN_SPARSE() @AT_USE_EIGEN_SPARSE@ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index 7a8d02be530e3..6ecbce44b60de 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -14,13 +14,18 @@ #include #ifdef USE_FBGEMM +<<<<<<< HEAD C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include C10_DIAGNOSTIC_POP() +======= +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif // USE_FBGEMM #if defined(__aarch64__) && !defined(C10_MOBILE) #include #endif +<<<<<<< HEAD namespace at { namespace { @@ -86,6 +91,11 @@ void check_fp32_prec_backend_and_op( } } // namespace +======= + +namespace at { + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Context::Context() = default; // TODO: This could be bad juju if someone calls globalContext() in the @@ -179,6 +189,7 @@ void Context::setUserEnabledNNPACK(bool e) { enabled_nnpack = e; } +<<<<<<< HEAD bool Context::allowTF32CuDNN(const std::string& op) const { if (op.empty()){ bool allow_tf32_rnn = float32Precision("cuda", "rnn") == "tf32"; @@ -194,14 +205,21 @@ bool Context::allowTF32CuDNN(const std::string& op) const { return float32Precision("cuda", op) == "tf32"; } warn_deprecated_fp32_precision_api(); +======= +bool Context::allowTF32CuDNN() const { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return allow_tf32_cudnn; } void Context::setAllowTF32CuDNN(bool b) { +<<<<<<< HEAD setFloat32Precision("cuda", "rnn", b ? "tf32" : "none"); setFloat32Precision("cuda", "conv", b ? "tf32" : "none"); allow_tf32_cudnn = b; warn_deprecated_fp32_precision_api(); +======= + allow_tf32_cudnn = b; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } void Context::setSDPPriorityOrder(const std::vector& order) { @@ -222,6 +240,7 @@ bool Context::allowTF32OneDNN() const { return allow_tf32_onednn; } +<<<<<<< HEAD // NOLINTNEXTLINE(clang-diagnostic-unused-parameter) void Context::setAllowTF32OneDNN(bool b){ #ifdef USE_XPU @@ -229,6 +248,14 @@ bool Context::allowTF32OneDNN() const { #else TORCH_WARN("TF32 acceleration on top of oneDNN is available for Intel GPUs. The current Torch version does not have Intel GPU Support."); #endif +======= +void Context::setAllowTF32OneDNN(bool b){ +#ifdef USE_XPU + allow_tf32_onednn = b; +#else + TORCH_WARN("TF32 acceleration on top of oneDNN is available for Intel GPUs. The current Torch version does not have Intel GPU Support."); +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } bool Context::userEnabledFlashSDP() const { @@ -281,6 +308,12 @@ bool Context::userEnabledOverrideableSDP() const { static constexpr const auto cublas_config_var_name = "CUBLAS_WORKSPACE_CONFIG"; static constexpr const std::array cublas_deterministic_configs = {":4096:8", ":16:8"}; +<<<<<<< HEAD +======= +#ifdef USE_ROCM +static constexpr const auto hipblaslt_allow_tf32 = "HIPBLASLT_ALLOW_TF32"; +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool Context::checkCuBLASConfigDeterministic() { // If using CUDA 10.2 or greater, need to make sure CuBLAS workspace config @@ -331,6 +364,7 @@ void Context::setBenchmarkLimitCuDNN(int b) { benchmark_limit_cudnn = b; } +<<<<<<< HEAD bool Context::immediateMiopen() const { return immediate_miopen; } @@ -385,10 +419,41 @@ std::string Context::float32Precision(const std::string& backend, const std::str precision = fp32_precision.find("generic")->second.find("all")->second; bool valid_prec = validate_fp32_prec(backend, precision); return valid_prec ? precision : "none"; +======= +bool Context::allowTF32CuBLAS() const { +#ifdef USE_ROCM + const auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32); + if (allow_tf32 != true) { + return false; + } +#endif + return float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST; +} + +void Context::setAllowTF32CuBLAS(bool b) { +#ifdef USE_ROCM + const auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32); + if (allow_tf32 != true) { + C10_LOG_FIRST_N(INFO, 10) << "torch.backends.cuda.matmul.allow_tf32 is not supported on ROCm by default. " + << "Please set environment variable HIPBLASLT_ALLOW_TF32=1 to enable it."; + return; + } +#endif + float32_matmul_precision = b ? at::Float32MatmulPrecision::HIGH : at::Float32MatmulPrecision::HIGHEST; +} + +Float32MatmulPrecision Context::float32MatmulPrecision() const { + return float32_matmul_precision; +} + +void Context::setFloat32MatmulPrecision(Float32MatmulPrecision p) { + float32_matmul_precision = p; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } void Context::setFloat32MatmulPrecision(const std::string &s) { auto match = [this](const std::string & s_) { +<<<<<<< HEAD warn_deprecated_fp32_precision_api(); // TODO: consider if CuDNN field needs to also be set for potential future CuDNN ops like multi-headed attention if (s_ == "highest") { @@ -405,6 +470,17 @@ void Context::setFloat32MatmulPrecision(const std::string &s) { float32_matmul_precision = at::Float32MatmulPrecision::MEDIUM; setFloat32Precision("cuda", "matmul", "tf32"); setFloat32Precision("mkldnn", "matmul", "bf16"); +======= + // TODO: consider if CuDNN field needs to also be set for potential future CuDNN ops like multi-headed attention + if (s_ == "highest") { + float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST; + return true; + } else if (s_ == "high") { + float32_matmul_precision = at::Float32MatmulPrecision::HIGH; + return true; + } else if (s_ == "medium") { + float32_matmul_precision = at::Float32MatmulPrecision::MEDIUM; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return true; } return false; @@ -418,6 +494,7 @@ void Context::setFloat32MatmulPrecision(const std::string &s) { "setFloat32MatmulPrecision call has no effect."); } +<<<<<<< HEAD void Context::setFloat32Precision(const std::string& backend, const std::string& op, const std::string& p) { check_fp32_prec_backend_and_op(backend, op); if (validate_fp32_prec(backend, p)) { @@ -439,6 +516,8 @@ void Context::setFloat32Precision(const std::string& backend, const std::string& } } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) at::LinalgBackend Context::linalgPreferredBackend() const { return linalg_preferred_backend; } @@ -463,9 +542,12 @@ at::BlasBackend Context::blasPreferredBackend() { // call site for blasPreferredBackend(), we set it to an actual value. if (blas_preferred_backend == at::BlasBackend::Default) { blas_preferred_backend = at::BlasBackend::Cublas; +<<<<<<< HEAD // This logic sits in the getter because it needs to validate // values set via env vars such as TORCH_BLAS_PREFER_CUBLASLT // which initialize the backend without calling the setter +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #ifdef USE_ROCM // AMD Instinct targets prefer hipblaslt static const bool hipblaslt_preferred = []() { @@ -474,6 +556,12 @@ at::BlasBackend Context::blasPreferredBackend() { #if ROCM_VERSION >= 60400 "gfx1200", "gfx1201", #endif +<<<<<<< HEAD +======= +#if ROCM_VERSION >= 60402 + "gfx1150", "gfx1151", +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #if ROCM_VERSION >= 60500 "gfx950" #endif @@ -495,6 +583,7 @@ at::BlasBackend Context::blasPreferredBackend() { // hipblaslt support for all archs is not as complete as hipblas if (blas_preferred_backend == at::BlasBackend::Cublaslt) { static const bool hipblaslt_unsupported = []() { +<<<<<<< HEAD if(!hasCuBLASLt()) { return true; @@ -503,6 +592,15 @@ at::BlasBackend Context::blasPreferredBackend() { "gfx90a", "gfx942", #if ROCM_VERSION >= 60300 "gfx1100", "gfx1101", "gfx1200", "gfx1201", "gfx908", +======= + static const std::vector archs = { + "gfx90a", "gfx942", +#if ROCM_VERSION >= 60300 + "gfx1100", "gfx1101", "gfx1200", "gfx1201", +#endif +#if ROCM_VERSION >= 60402 + "gfx1150", "gfx1151", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif #if ROCM_VERSION >= 60500 "gfx950" @@ -524,6 +622,7 @@ at::BlasBackend Context::blasPreferredBackend() { return blas_preferred_backend; } +<<<<<<< HEAD bool Context::ckSupported() { #ifdef USE_ROCM static const std::vector supported_archs = { @@ -542,6 +641,8 @@ bool Context::ckSupported() { #endif } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void Context::setBlasPreferredBackend(at::BlasBackend b) { #ifdef _MSC_VER TORCH_WARN_ONCE( @@ -551,6 +652,7 @@ void Context::setBlasPreferredBackend(at::BlasBackend b) { #else TORCH_CHECK((b != at::BlasBackend::Cublaslt) || hasCuBLASLt(), "Cannot set preferred backend to cuBLASLt if PyTorch has not been compiled with cuBLASLt."); +<<<<<<< HEAD #ifdef USE_ROCM static const bool ckSupportedFlag = ckSupported(); static const bool hasCKGEMMFlag = hasCKGEMM(); @@ -559,6 +661,10 @@ void Context::setBlasPreferredBackend(at::BlasBackend b) { "architecture supported for CK: ", ckSupportedFlag, ", PyTorch built with CK GEMM support: ", hasCKGEMMFlag); #endif +======= + TORCH_CHECK((b != at::BlasBackend::Ck) || hasROCM(), + "Cannot set preferred backend to Ck if PyTorch has not been compiled for ROCm."); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (b != at::BlasBackend::Default && b != at::BlasBackend::Cublas) { TORCH_WARN_ONCE( "torch.backends.cuda.preferred_blas_library is an experimental feature. " @@ -570,6 +676,7 @@ void Context::setBlasPreferredBackend(at::BlasBackend b) { #endif } +<<<<<<< HEAD at::ROCmFABackend Context::getROCmFAPreferredBackend() { #ifdef USE_ROCM // Set potential "Default" value so we don't have to interpret at call sites. @@ -593,10 +700,14 @@ at::ROCmFABackend Context::getROCmFAPreferredBackend() { } #endif +======= +at::ROCmFABackend Context::getROCmFAPreferredBackend() const { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return rocm_fa_preferred_backend; } void Context::setROCmFAPreferredBackend(at::ROCmFABackend b) { +<<<<<<< HEAD #ifdef USE_ROCM static const bool hasCKSDPAFlag = hasCKSDPA(); static const bool ckSupportedFlag = ckSupported(); @@ -604,6 +715,32 @@ void Context::setROCmFAPreferredBackend(at::ROCmFABackend b) { "Cannot set preferred SDPA backend to CK since following conditions are not true: ", "architecture supported for CK: ", ckSupportedFlag, ", PyTorch built with CK SDPA support: ", hasCKSDPAFlag); +======= + + // TODO: add plumbing for hasCK for validity checking + TORCH_CHECK((b != at::ROCmFABackend::Ck) || hasROCM(), + "Cannot set preferred flash attention backend to Ck if PyTorch has not been compiled for ROCm."); +#ifdef USE_ROCM + if(b == at::ROCmFABackend::Ck) { + static const bool ck_unsupported = []() { + static const std::vector archs = { + "gfx90a", "gfx942", "gfx950" + }; + for (auto index: c10::irange(detail::getCUDAHooks().deviceCount())) { + if (!detail::getCUDAHooks().isGPUArch(archs, index)) { + TORCH_WARN_ONCE( + "Attempting to use CK on an unsupported architecture! Cannot set backend to CK"); + return true; + } + } + return false; + }(); + if(!ck_unsupported) rocm_fa_preferred_backend = b; + } + else { + rocm_fa_preferred_backend = b; + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif rocm_fa_preferred_backend = b; } @@ -681,6 +818,7 @@ bool Context::hasLAPACK() { #endif } +<<<<<<< HEAD bool Context::hasEigenSparse() { #if AT_USE_EIGEN_SPARSE() return true; @@ -689,6 +827,8 @@ bool Context::hasEigenSparse() { #endif } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) at::QEngine Context::qEngine() const { static auto _quantized_engine = []() { at::QEngine qengine = at::kNoQEngine; @@ -712,14 +852,22 @@ at::QEngine Context::qEngine() const { #endif return qengine; }(); +<<<<<<< HEAD auto qt_engine = quantized_engine.load(); return qt_engine == at::QEngine::NoQEngine ? _quantized_engine : qt_engine; +======= + return quantized_engine.value_or(_quantized_engine); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } void Context::setQEngine(at::QEngine e) { const auto& qengines = supportedQEngines(); if (std::find(qengines.begin(), qengines.end(), e) != qengines.end()) { +<<<<<<< HEAD quantized_engine.store(e); +======= + quantized_engine = e; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return; } TORCH_CHECK(false, "quantized engine ", toString(e), " is not supported"); @@ -731,9 +879,23 @@ const std::vector& Context::supportedQEngines() { // Engines are listed in priority order: later one wins // By default we prefer FBGEMM if we're running on server side // QNNPACK on server side has some issue, so we disable it by default. +<<<<<<< HEAD +#ifdef USE_PYTORCH_QNNPACK + engines.push_back(at::kQNNPACK); +#endif +======= +#ifdef C10_MOBILE + engines.push_back(at::kNoQEngine); +#ifdef USE_PYTORCH_QNNPACK + engines.push_back(at::kQNNPACK); +#endif +#else // C10_MOBILE #ifdef USE_PYTORCH_QNNPACK engines.push_back(at::kQNNPACK); #endif + engines.push_back(at::kNoQEngine); +#endif // C10_MOBILE +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #if AT_MKLDNN_ENABLED() engines.push_back(at::kONEDNN); @@ -865,7 +1027,10 @@ void Context::setAllowFP16ReductionCPU(bool b) { #if defined(__aarch64__) && !defined(C10_MOBILE) if (!cpuinfo_initialize() || !cpuinfo_has_arm_fp16_arith()) #else +<<<<<<< HEAD // NOLINTNEXTLINE(facebook-hte-MissingBraces) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (true) #endif TORCH_CHECK(false, "Float16 arithmetic is not supported by the CPU!"); diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index 5cfa9b23e20aa..887a1831eef96 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -28,7 +28,10 @@ #include #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include namespace at { @@ -132,8 +135,11 @@ class TORCH_API Context { static bool hasKleidiAI(); static bool hasLAPACK(); static bool hasMKLDNN(); +<<<<<<< HEAD static bool ckSupported(); static bool hasEigenSparse(); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static bool hasMAGMA() { return detail::getCUDAHooks().hasMAGMA(); } @@ -164,12 +170,15 @@ class TORCH_API Context { static bool hasROCM() { return detail::getCUDAHooks().hasROCM(); } +<<<<<<< HEAD static bool hasCKSDPA() { return detail::getCUDAHooks().hasCKSDPA(); } static bool hasCKGEMM() { return detail::getCUDAHooks().hasCKGEMM(); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static bool hasHIP() { return detail::getHIPHooks().hasHIP(); } @@ -213,8 +222,11 @@ class TORCH_API Context { void setBenchmarkCuDNN(bool); int benchmarkLimitCuDNN() const; void setBenchmarkLimitCuDNN(int); +<<<<<<< HEAD bool immediateMiopen() const; void setImmediateMiopen(bool); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool deterministicCuDNN() const; void setDeterministicCuDNN(bool); bool deterministicMkldnn() const; @@ -260,7 +272,11 @@ class TORCH_API Context { at::BlasBackend blasPreferredBackend(); void setBlasPreferredBackend(at::BlasBackend); +<<<<<<< HEAD at::ROCmFABackend getROCmFAPreferredBackend(); +======= + at::ROCmFABackend getROCmFAPreferredBackend() const; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void setROCmFAPreferredBackend(at::ROCmFABackend); // Note [Enabling Deterministic Operations] @@ -347,20 +363,28 @@ class TORCH_API Context { void alertCuBLASConfigNotDeterministic() const; void setFloat32MatmulPrecision(const std::string& s); +<<<<<<< HEAD void setFloat32Precision( const std::string& backend, const std::string& op, const std::string& s); bool allowTF32CuDNN(const std::string& op = std::string()) const; +======= + bool allowTF32CuDNN() const; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void setAllowTF32CuDNN(bool); bool allowTF32OneDNN() const; void setAllowTF32OneDNN(bool); bool allowTF32CuBLAS() const; void setAllowTF32CuBLAS(bool); Float32MatmulPrecision float32MatmulPrecision() const; +<<<<<<< HEAD std::string float32Precision( const std::string& backend, const std::string& op) const; +======= + void setFloat32MatmulPrecision(Float32MatmulPrecision p); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool allowFP16ReductionCuBLAS() const; void setAllowFP16ReductionCuBLAS(bool); bool allowBF16ReductionCuBLAS() const; @@ -441,8 +465,12 @@ class TORCH_API Context { at::SDPBackend::flash_attention, at::SDPBackend::efficient_attention, at::SDPBackend::math, +<<<<<<< HEAD at::SDPBackend::cudnn_attention, at::SDPBackend::overrideable}; +======= + at::SDPBackend::cudnn_attention}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool enabled_flashSDP = true; bool enabled_mem_efficientSDP = true; bool enabled_mathSDP = true; @@ -450,7 +478,10 @@ class TORCH_API Context { bool enabled_overrideable = true; bool allow_fp16_bf16_reduction_mathSDP = false; bool benchmark_cudnn = false; +<<<<<<< HEAD bool immediate_miopen = false; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Float32MatmulPrecision float32_matmul_precision = c10::utils::check_env("TORCH_ALLOW_TF32_CUBLAS_OVERRIDE") == true ? at::Float32MatmulPrecision::HIGH @@ -484,6 +515,7 @@ class TORCH_API Context { bool release_original_weights = false; #endif bool display_vmap_fallback_warnings_ = false; +<<<<<<< HEAD std::atomic quantized_engine = at::QEngine::NoQEngine; bool enable_sparse_tensor_invariant_checks = false; bool allow_fp16_reduction_cpu = false; @@ -505,6 +537,12 @@ class TORCH_API Context { {"all", "none"}}}, }; +======= + std::optional quantized_engine = std::nullopt; + bool enable_sparse_tensor_invariant_checks = false; + bool allow_fp16_reduction_cpu = false; + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Allocator* prev_allocator_ptr_{nullptr}; }; @@ -616,10 +654,13 @@ inline bool hasLAPACK() { return globalContext().hasLAPACK(); } +<<<<<<< HEAD inline bool hasEigenSparse() { return globalContext().hasEigenSparse(); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inline bool hasMAGMA() { return globalContext().hasMAGMA(); } diff --git a/aten/src/ATen/DLConvertor.cpp b/aten/src/ATen/DLConvertor.cpp index 98ad757946bec..5e67271c2f212 100644 --- a/aten/src/ATen/DLConvertor.cpp +++ b/aten/src/ATen/DLConvertor.cpp @@ -69,33 +69,54 @@ DLDataType getDLDataType(const Tensor& t) { case ScalarType::Float8_e4m3fn: case ScalarType::Float8_e4m3fnuz: case ScalarType::Float8_e8m0fnu: +<<<<<<< HEAD TORCH_CHECK_BUFFER(false, "float8 types are not supported by dlpack"); break; case ScalarType::Float4_e2m1fn_x2: TORCH_CHECK_BUFFER(false, "float4 types are not supported by dlpack"); +======= + TORCH_CHECK(false, "float8 types are not supported by dlpack"); + break; + case ScalarType::Float4_e2m1fn_x2: + TORCH_CHECK(false, "float4 types are not supported by dlpack"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) break; case ScalarType::QInt8: case ScalarType::QUInt8: case ScalarType::QInt32: case ScalarType::QUInt4x2: case ScalarType::QUInt2x4: +<<<<<<< HEAD TORCH_CHECK_BUFFER(false, "QUInt/QInt types are not supported by dlpack"); +======= + TORCH_CHECK(false, "QUInt/QInt types are not supported by dlpack"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) break; case ScalarType::Bits1x8: case ScalarType::Bits2x4: case ScalarType::Bits4x2: case ScalarType::Bits8: case ScalarType::Bits16: +<<<<<<< HEAD TORCH_CHECK_BUFFER(false, "Bit types are not supported by dlpack"); break; case ScalarType::Undefined: TORCH_CHECK_BUFFER(false, "Undefined is not a valid ScalarType"); case ScalarType::NumOptions: TORCH_CHECK_BUFFER(false, "NumOptions is not a valid ScalarType"); +======= + TORCH_CHECK(false, "Bit types are not supported by dlpack"); + break; + case ScalarType::Undefined: + TORCH_CHECK(false, "Undefined is not a valid ScalarType"); + case ScalarType::NumOptions: + TORCH_CHECK(false, "NumOptions is not a valid ScalarType"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } return dtype; } +<<<<<<< HEAD DLDevice torchDeviceToDLDevice(at::Device device) { DLDevice ctx; @@ -104,6 +125,12 @@ DLDevice torchDeviceToDLDevice(at::Device device) { : 0; switch (device.type()) { +======= +static DLDevice getDLDevice(const Tensor& tensor, c10::DeviceIndex device_id) { + DLDevice ctx; + ctx.device_id = static_cast(static_cast(device_id)); + switch (tensor.device().type()) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) case DeviceType::CPU: ctx.device_type = DLDeviceType::kDLCPU; break; @@ -124,7 +151,12 @@ DLDevice torchDeviceToDLDevice(at::Device device) { break; case DeviceType::XPU: ctx.device_type = DLDeviceType::kDLOneAPI; +<<<<<<< HEAD ctx.device_id = at::detail::getXPUHooks().getGlobalIdxFromDevice(device); +======= + ctx.device_id = + at::detail::getXPUHooks().getGlobalIdxFromDevice(tensor.device()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) break; case DeviceType::MAIA: ctx.device_type = DLDeviceType::kDLMAIA; @@ -132,6 +164,7 @@ DLDevice torchDeviceToDLDevice(at::Device device) { case DeviceType::PrivateUse1: ctx.device_type = DLDeviceType::kDLExtDev; break; +<<<<<<< HEAD case DeviceType::MPS: ctx.device_type = DLDeviceType::kDLMetal; break; @@ -144,11 +177,22 @@ DLDevice torchDeviceToDLDevice(at::Device device) { static Device getATenDevice(DLDeviceType type, c10::DeviceIndex index, void* data = nullptr) { switch (type) { +======= + default: + TORCH_CHECK(false, "Cannot pack tensors on " + tensor.device().str()); + } + return ctx; +} + +static Device getATenDevice(const DLDevice& ctx, void* data) { + switch (ctx.device_type) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) case DLDeviceType::kDLCPU: return at::Device(DeviceType::CPU); #ifndef USE_ROCM // if we are compiled under HIP, we cannot do cuda case DLDeviceType::kDLCUDA: +<<<<<<< HEAD return at::Device(DeviceType::CUDA, index); #endif case DLDeviceType::kDLOpenCL: @@ -172,12 +216,38 @@ static Device getATenDevice(DLDeviceType type, c10::DeviceIndex index, void* dat default: TORCH_CHECK_BUFFER( false, "Unsupported device_type: ", std::to_string(type)); +======= + return at::Device(DeviceType::CUDA, static_cast(ctx.device_id)); +#endif + case DLDeviceType::kDLOpenCL: + return at::Device(DeviceType::OPENCL, static_cast(ctx.device_id)); + case DLDeviceType::kDLROCM: +#ifdef USE_ROCM + // this looks funny, we need to return CUDA here to masquerade + return at::Device(DeviceType::CUDA, static_cast(ctx.device_id)); +#else + return at::Device(DeviceType::HIP, static_cast(ctx.device_id)); +#endif + case DLDeviceType::kDLOneAPI: + return at::detail::getXPUHooks().getDeviceFromPtr(data); + case DLDeviceType::kDLMAIA: + return at::Device(DeviceType::MAIA, static_cast(ctx.device_id)); + case DLDeviceType::kDLExtDev: + return at::Device(DeviceType::PrivateUse1, static_cast(ctx.device_id)); + default: + TORCH_CHECK( + false, "Unsupported device_type: ", std::to_string(ctx.device_type)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } ScalarType toScalarType(const DLDataType& dtype) { ScalarType stype = ScalarType::Undefined; +<<<<<<< HEAD TORCH_CHECK_BUFFER(dtype.lanes == 1, "ATen does not support lanes != 1"); +======= + TORCH_CHECK(dtype.lanes == 1, "ATen does not support lanes != 1"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) switch (dtype.code) { case DLDataTypeCode::kDLUInt: switch (dtype.bits) { @@ -194,7 +264,11 @@ ScalarType toScalarType(const DLDataType& dtype) { stype = ScalarType::UInt64; break; default: +<<<<<<< HEAD TORCH_CHECK_BUFFER( +======= + TORCH_CHECK( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) false, "Unsupported kUInt bits ", std::to_string(dtype.bits)); } break; @@ -213,7 +287,11 @@ ScalarType toScalarType(const DLDataType& dtype) { stype = ScalarType::Long; break; default: +<<<<<<< HEAD TORCH_CHECK_BUFFER( +======= + TORCH_CHECK( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) false, "Unsupported kInt bits ", std::to_string(dtype.bits)); } break; @@ -229,7 +307,11 @@ ScalarType toScalarType(const DLDataType& dtype) { stype = ScalarType::Double; break; default: +<<<<<<< HEAD TORCH_CHECK_BUFFER( +======= + TORCH_CHECK( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) false, "Unsupported kFloat bits ", std::to_string(dtype.bits)); } break; @@ -239,7 +321,11 @@ ScalarType toScalarType(const DLDataType& dtype) { stype = ScalarType::BFloat16; break; default: +<<<<<<< HEAD TORCH_CHECK_BUFFER( +======= + TORCH_CHECK( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) false, "Unsupported kFloat bits ", std::to_string(dtype.bits)); } break; @@ -255,7 +341,11 @@ ScalarType toScalarType(const DLDataType& dtype) { stype = ScalarType::ComplexDouble; break; default: +<<<<<<< HEAD TORCH_CHECK_BUFFER( +======= + TORCH_CHECK( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) false, "Unsupported kFloat bits ", std::to_string(dtype.bits)); } break; @@ -265,17 +355,26 @@ ScalarType toScalarType(const DLDataType& dtype) { stype = ScalarType::Bool; break; default: +<<<<<<< HEAD TORCH_CHECK_BUFFER( +======= + TORCH_CHECK( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) false, "Unsupported kDLBool bits ", std::to_string(dtype.bits)); } break; default: +<<<<<<< HEAD TORCH_CHECK_BUFFER(false, "Unsupported code ", std::to_string(dtype.code)); +======= + TORCH_CHECK(false, "Unsupported code ", std::to_string(dtype.code)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } return stype; } namespace { +<<<<<<< HEAD // The templated classes below are needed for supporting both: // - DLManagedTensor @@ -302,10 +401,21 @@ void fillVersion( tensor->flags = 0; tensor->version.major = DLPACK_MAJOR_VERSION; tensor->version.minor = DLPACK_MINOR_VERSION; +======= +struct ATenDLMTensor { + Tensor handle; + DLManagedTensor tensor{}; +}; +} // namespace + +static void deleter(DLManagedTensor* arg) { + delete static_cast(arg->manager_ctx); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // This function returns a shared_ptr to memory managed DLpack tensor // constructed out of ATen tensor +<<<<<<< HEAD template T* toDLPackImpl(const Tensor& src) { auto view = src; @@ -352,11 +462,36 @@ T* toDLPackImpl(const Tensor& src) { atDLMTensor->tensor.deleter = &deleter; atDLMTensor->tensor.dl_tensor.data = view.data_ptr(); atDLMTensor->tensor.dl_tensor.device = torchDeviceToDLDevice(src.device()); +======= +DLManagedTensor* toDLPack(const Tensor& src) { + // create a new tensor with possibly normalized strides + // gh-83069 + auto shape = src.sizes(); + auto strides = src.strides().vec(); + for (int i = 0; i < src.dim(); i++) { + if (shape[i] < 2) { + strides[i] = 1; + } + } + + auto view = src.as_strided(shape, strides, src.storage_offset()); + ATenDLMTensor* atDLMTensor(new ATenDLMTensor); + atDLMTensor->handle = view; + atDLMTensor->tensor.manager_ctx = atDLMTensor; + atDLMTensor->tensor.deleter = &deleter; + atDLMTensor->tensor.dl_tensor.data = view.data_ptr(); + c10::DeviceIndex device_id = 0; + if (src.is_cuda() || src.is_privateuseone()) { + device_id = src.get_device(); + } + atDLMTensor->tensor.dl_tensor.device = getDLDevice(src, device_id); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) atDLMTensor->tensor.dl_tensor.ndim = static_cast(src.dim()); atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src); atDLMTensor->tensor.dl_tensor.shape = view.sizes().data(); atDLMTensor->tensor.dl_tensor.strides = view.strides().data(); atDLMTensor->tensor.dl_tensor.byte_offset = 0; +<<<<<<< HEAD fillVersion(&atDLMTensor->tensor); return &(atDLMTensor->tensor); @@ -386,18 +521,46 @@ at::Tensor fromDLPackImpl(T* src, std::function deleter) { return at::from_blob( dl_tensor.data, IntArrayRef(dl_tensor.shape, dl_tensor.ndim), +======= + return &(atDLMTensor->tensor); +} + +Tensor fromDLPack(DLManagedTensor* src) { + auto deleter = [src](void* self [[maybe_unused]]) { + if (src->deleter) { + src->deleter(src); + } + }; + return fromDLPack(src, std::move(deleter)); +} + +Tensor fromDLPack(DLManagedTensor* src, std::function deleter) { + Device device = getATenDevice(src->dl_tensor.device, src->dl_tensor.data); + ScalarType stype = toScalarType(src->dl_tensor.dtype); + if (!src->dl_tensor.strides) { + return at::from_blob( + src->dl_tensor.data, + IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::move(deleter), at::device(device).dtype(stype), {device}); } return at::from_blob( +<<<<<<< HEAD dl_tensor.data, IntArrayRef(dl_tensor.shape, dl_tensor.ndim), IntArrayRef(dl_tensor.strides, dl_tensor.ndim), +======= + src->dl_tensor.data, + IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim), + IntArrayRef(src->dl_tensor.strides, src->dl_tensor.ndim), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) deleter, at::device(device).dtype(stype), {device}); } +<<<<<<< HEAD // Explicitly instantiate the template above for both classes. template at::Tensor fromDLPackImpl(DLManagedTensor* src, std::function deleter); @@ -452,4 +615,6 @@ Tensor maybeCopyTensor( return data; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace at diff --git a/aten/src/ATen/DLConvertor.h b/aten/src/ATen/DLConvertor.h index b1c2eaa2d6eae..8853fdd3f3b2a 100644 --- a/aten/src/ATen/DLConvertor.h +++ b/aten/src/ATen/DLConvertor.h @@ -4,7 +4,11 @@ #include #include +<<<<<<< HEAD // this converter will: +======= +// this convertor will: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // 1) take a Tensor object and wrap it in the DLPack tensor // 2) take a dlpack tensor and convert it to the ATen Tensor @@ -12,6 +16,7 @@ namespace at { TORCH_API ScalarType toScalarType(const DLDataType& dtype); TORCH_API DLManagedTensor* toDLPack(const Tensor& src); +<<<<<<< HEAD TORCH_API struct DLManagedTensorVersioned* toDLPackVersioned(const Tensor& src); TORCH_API Tensor fromDLPack(DLManagedTensor* src, std::function deleter = {}); @@ -66,4 +71,12 @@ struct DLPackTraits { inline static auto fromDLPack = at::fromDLPackVersioned; }; +======= +TORCH_API Tensor fromDLPack(DLManagedTensor* src); +TORCH_API Tensor +fromDLPack(DLManagedTensor* src, std::function deleter); +TORCH_API DLDataType getDLDataType(const Tensor& t); +TORCH_API DLDevice getDLContext(const Tensor& tensor, const int64_t& device_id); + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace at diff --git a/aten/src/ATen/DeviceAccelerator.h b/aten/src/ATen/DeviceAccelerator.h index f23b35047fcc8..0f01d3278929a 100644 --- a/aten/src/ATen/DeviceAccelerator.h +++ b/aten/src/ATen/DeviceAccelerator.h @@ -1,6 +1,9 @@ #pragma once +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include @@ -31,7 +34,11 @@ TORCH_API bool isAccelerator(c10::DeviceType device_type); template < typename... T, typename = std::enable_if_t<(std::is_same_v && ...)>> +<<<<<<< HEAD inline bool isAcceleratorExcluded( +======= +TORCH_API inline bool isAcceleratorExcluded( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) c10::DeviceType device_type, c10::DeviceType first_excluded, T... rest_excluded) { @@ -73,6 +80,7 @@ TORCH_API c10::DeviceIndex exchangeDevice(c10::DeviceIndex device_index); // original device index that was active before the change. TORCH_API c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device_index); +<<<<<<< HEAD TORCH_API inline void emptyCache() { const auto device_type = getAccelerator(true).value(); at::getDeviceAllocator(device_type)->emptyCache(); @@ -94,6 +102,8 @@ TORCH_API inline void resetPeakStats(c10::DeviceIndex device_index) { at::getDeviceAllocator(device_type)->resetPeakStats(device_index); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace at::accelerator namespace at { diff --git a/aten/src/ATen/EmptyTensor.cpp b/aten/src/ATen/EmptyTensor.cpp index 0e535ab20cd21..e6faaaf1e7f12 100644 --- a/aten/src/ATen/EmptyTensor.cpp +++ b/aten/src/ATen/EmptyTensor.cpp @@ -31,9 +31,13 @@ c10::Allocator* GetCPUAllocatorMaybePinned(bool pin_memory) { return at::globalContext().getPinnedMemoryAllocator(opt_device_type); } else { TORCH_CHECK( +<<<<<<< HEAD false, "pin_memory=True requires a CUDA or other accelerator backend; " "no pinned memory allocator is available on this system.") +======= + false, "Need to provide pin_memory allocator to use pin memory.") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } diff --git a/aten/src/ATen/FunctionalInverses.cpp b/aten/src/ATen/FunctionalInverses.cpp index 123d87b304148..1af022b0411ae 100644 --- a/aten/src/ATen/FunctionalInverses.cpp +++ b/aten/src/ATen/FunctionalInverses.cpp @@ -233,8 +233,13 @@ Tensor FunctionalInverses::slice_Tensor_inverse(const Tensor& base, const Tensor // NOLINTNEXTLINE(performance-unnecessary-value-param) Tensor FunctionalInverses::split_Tensor_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t mutated_view_idx, c10::SymInt split_size, int64_t dim) { +<<<<<<< HEAD // It would be nice if this logic could be reused from autograd's split_backward(), but I don't think it can. // For functionalization, we have only have one of the tensors from the TensorList outputted by split(), and we want to layer i +======= + // It would be nice if this logic could be re-used from autograd's split_backward(), but I don't think it can. + // For functionalization, we have only have one of the tensors from the TensorList outputed by split(), and we want to layer i +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // on top of the base tensor. // For autograd, we have all of the tensors outputted by split() and we just want to stack them. dim = at::maybe_wrap_dim(dim, base.dim()); diff --git a/aten/src/ATen/FunctionalStorageImpl.cpp b/aten/src/ATen/FunctionalStorageImpl.cpp index 8bca495abdc6d..e1b1a7ce3bc45 100644 --- a/aten/src/ATen/FunctionalStorageImpl.cpp +++ b/aten/src/ATen/FunctionalStorageImpl.cpp @@ -9,6 +9,14 @@ namespace at::functionalization { +<<<<<<< HEAD +======= +ViewMeta ViewMeta::to_out_idx(int64_t out_idx) { + if (out_idx == this->out_index) return *this; + return ViewMeta(forward_fn, reverse_fn, has_symbolic_inputs, is_multi_output, is_as_strided, out_idx); +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Note [Functionalization: Alias Removal Part 2] // See Note [Functionalization: Alias Removal] for more details. // This function applies a single update from one of the views to the StorageImpl. @@ -37,12 +45,20 @@ namespace at::functionalization { static const Tensor apply_update(const FunctionalStorageImpl::Update& update, const Tensor& base) { at::Tensor t = update.new_val; TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t)); +<<<<<<< HEAD if (update.view_metas.empty()) { return t; } +======= + if (update.view_metas.empty()) return t; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::vector tmp_values({base}); tmp_values.reserve(update.view_metas.size()); for (size_t i = 0; i < update.view_metas.size() - 1; ++i) { +<<<<<<< HEAD at::Tensor next_view = update.view_metas[i]->forward(tmp_values.back()); +======= + at::Tensor next_view = update.view_metas[i].forward_fn(tmp_values.back(), update.view_metas[i].out_index); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // NB: We only actually need tmp_values for ops like select/slice/diagonal/squeeze/as_strided // All of these ops require additional information to recover the sizes of the original tensor. // If need to, we could probably apply this optimization and only bother computing tmp_values @@ -50,8 +66,14 @@ static const Tensor apply_update(const FunctionalStorageImpl::Update& update, co tmp_values.push_back(std::move(next_view)); } for(int64_t i = static_cast(update.view_metas.size()) - 1; i >= 0; --i) { +<<<<<<< HEAD // Each view inverse is implemented in ViewInverses.cpp. t = update.view_metas[i]->reverse(tmp_values[i], t); +======= + int64_t out_idx = update.view_metas[i].out_index; + // Each view inverse is implemented in ViewInverses.cpp. + t = update.view_metas[i].reverse_fn(tmp_values[i], t, out_idx); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t)); return t; @@ -105,13 +127,21 @@ FunctionalStorageImpl::FunctionalStorageImpl(const Tensor& base) TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(base_)); } +<<<<<<< HEAD void FunctionalStorageImpl::add_update(const Tensor& updated_val, const std::vector>& metas) { +======= +void FunctionalStorageImpl::add_update(const Tensor& updated_val, const std::vector& metas) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_CHECK(!frozen_, "cannot mutate tensors with frozen storage"); if (metas.size() > 1) { for (size_t i = 1; i < metas.size(); ++i) { // Skipping this check for XLA. Would be good to add it back, but it is failing XLA CI +<<<<<<< HEAD TORCH_CHECK(updated_val.device().type() == c10::DeviceType::XLA || !metas[i]->is_as_strided, +======= + TORCH_CHECK(updated_val.device().type() == c10::DeviceType::XLA || !metas[i].is_as_strided, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "During torch.compile, encountered a mutation on a view chain of length ", metas.size(), ", where view ", i, " was an as_strided() call. as_strided() is non-compositional, and therefore is not possible to functionalize properly today," "so this behavior is banned in compile. As a workaround, you can either remove the mutation from the model code, or you " diff --git a/aten/src/ATen/FunctionalStorageImpl.h b/aten/src/ATen/FunctionalStorageImpl.h index 0c9c1fd775f32..9fff2c6e0b677 100644 --- a/aten/src/ATen/FunctionalStorageImpl.h +++ b/aten/src/ATen/FunctionalStorageImpl.h @@ -8,6 +8,7 @@ namespace at::functionalization { // See Note [Functionalization Pass In Core] +<<<<<<< HEAD enum class InverseReturnMode { /// Specifies that functional inverses should always return a view. AlwaysView, @@ -77,20 +78,58 @@ enum class InverseReturnMode { // a type are used for supporting pickle serialization. struct ViewMeta { ViewMeta( +======= +// ViewMeta is a class used by the functionalization pass to navigate between +// a base tensor and a view tensor. +// For example, if I call `b = a.view1(...)` +// the functionalization pass will generate and store a ViewMeta on b that looks +// like: +// +// ViewMeta( +// [](const Tensor& base, int64_t mutated_view_idx) { +// return base.view1(...); +// }, +// [](const at::Tensor& base, const at::Tensor& mutated_view, +// int64_t mutated_view_idx) -> at::Tensor { +// return at::functionalization::impl::view1_inverse(base, mutated_view, +// ...); +// } +// +// The forward_fn lambda describes how to replay view1 on a tensor. +// +// The reverse_fn lambda describes how, given a tensor that is already a view, +// how to get the corresponding base tensor. See Note [Functionalization Pass: +// View Inverses] for details. +struct ViewMeta { + ViewMeta( + std::function forward, + std::function reverse, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool has_symbolic_inputs, bool is_multi_output = false, bool is_as_strided = false, int64_t out_idx = 0) +<<<<<<< HEAD : out_index(out_idx), +======= + : forward_fn(std::move(forward)), + reverse_fn(std::move(reverse)), + out_index(out_idx), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) is_multi_output(is_multi_output), is_as_strided(is_as_strided), has_symbolic_inputs(has_symbolic_inputs) {} +<<<<<<< HEAD virtual ~ViewMeta() = default; virtual Tensor forward(const Tensor& base) = 0; virtual Tensor reverse(const Tensor& base, const Tensor& mutated_view) = 0; +======= + std::function forward_fn; + std::function reverse_fn; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // See Note [out_idx in ViewMeta] int64_t out_index; @@ -102,6 +141,7 @@ struct ViewMeta { // Tells us if this view operation has any symbolic inputs bool has_symbolic_inputs; +<<<<<<< HEAD // Returns a new ViewMeta with the same forward/reverse // functions, but a new out index. // @@ -113,6 +153,12 @@ struct ViewMeta { "ViewMeta::to_out_index not implemented. ", "Likely because there's only one output."); } +======= + // Returns a copy of the current ViewMeta, if out_idx matches the current + // out_index. Otherwise, returns a new ViewMeta with the same forward/reverse + // functions, but a new out index. + ViewMeta to_out_idx(int64_t out_idx); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; // FunctionalStorageImpl is a subclass of StorageImpl used by the @@ -145,14 +191,22 @@ struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl { // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) const at::Tensor new_val; // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) +<<<<<<< HEAD const std::vector> view_metas; +======= + const std::vector view_metas; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; explicit FunctionalStorageImpl(const Tensor& value); void add_update( const Tensor& updated_val, +<<<<<<< HEAD const std::vector>& view_metas); +======= + const std::vector& view_metas); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool apply_updates(); const Tensor& base() { return base_; @@ -174,9 +228,12 @@ struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl { ~FunctionalStorageImpl() override = default; +<<<<<<< HEAD uint64_t mutation_counter() { return mutation_counter_; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void mark_mutation() { mutation_counter_++; } @@ -205,17 +262,23 @@ struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl { void mark_inductor_storage_resize(c10::SymInt new_size) { inductor_storage_resized_ = true; curr_storage_size_ = std::move(new_size); +<<<<<<< HEAD inductor_storage_resized_counter_++; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } bool was_inductor_storage_resized() { return inductor_storage_resized_; } +<<<<<<< HEAD uint64_t inductor_storage_resized_counter() { return inductor_storage_resized_counter_; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) private: // NB: base_ should always point to a tensor BELOW the current // functionalization layer. This is mainly to avoid reference cycles. e.g. @@ -261,7 +324,10 @@ struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl { // (1) There were any storage resizes on a graph input // (2) The original/curr storage size tell us if these resizes result in a nop bool inductor_storage_resized_ = false; +<<<<<<< HEAD uint64_t inductor_storage_resized_counter_ = 0; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) c10::SymInt original_storage_size_; c10::SymInt curr_storage_size_; }; diff --git a/aten/src/ATen/FunctionalTensorWrapper.cpp b/aten/src/ATen/FunctionalTensorWrapper.cpp index 3a574fa7d491c..83923c7253e39 100644 --- a/aten/src/ATen/FunctionalTensorWrapper.cpp +++ b/aten/src/ATen/FunctionalTensorWrapper.cpp @@ -129,6 +129,7 @@ void FunctionalTensorWrapper::freeze_storage() const { // - view_value: The output tensor that we need to wrap. // - base: The "base" of the view that `view_value` was generated from. // See Note [Functionalization: Alias Removal Part 2] for more details on the mutation replay logic. +<<<<<<< HEAD FunctionalTensorWrapper::FunctionalTensorWrapper( const Tensor& view_value, const FunctionalTensorWrapper* base, @@ -142,6 +143,19 @@ FunctionalTensorWrapper::FunctionalTensorWrapper( base->is_multi_output_view_ || meta->is_multi_output), was_storage_changed_(base->was_storage_changed_), is_symbolic_(base->is_symbolic_) { +======= +FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const FunctionalTensorWrapper* base, const functionalization::ViewMeta& meta) + : c10::TensorImpl( + c10::DispatchKeySet(DispatchKey::Functionalize), + view_value.dtype(), + view_value.device() + ), + value_(view_value), + is_multi_output_view_(base->is_multi_output_view_ || meta.is_multi_output), + was_storage_changed_(base->was_storage_changed_), + is_symbolic_(base->is_symbolic_) +{ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(value_)); TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize)); set_constructor_metadata(); @@ -150,10 +164,18 @@ FunctionalTensorWrapper::FunctionalTensorWrapper( view_metas_ = base->view_metas_; // copy } view_metas_.push_back(meta); +<<<<<<< HEAD maybe_mark_symbolic(meta.get()); storage_ = base->storage_; // alias this tensor's storage with the base tensor's } +======= + maybe_mark_symbolic(meta); + storage_ = base->storage_; // alias this tensor's storage with the base tensor's +} + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) functionalization::FunctionalStorageImpl* FunctionalTensorWrapper::functional_storage_impl() const { return static_cast(storage_.unsafeGetStorageImpl()); } @@ -177,18 +199,31 @@ bool FunctionalTensorWrapper::is_up_to_date() const { } // See Note [Functionalization Pass - Inplace View Ops] +<<<<<<< HEAD void FunctionalTensorWrapper::mutate_view_meta(const std::shared_ptr& meta) { view_metas_.push_back(meta); // Manually track the fact that this tensor received a metadata mutation! has_metadata_mutation_ = true; // Mark this tensor as being symbolic if there are any symbolic inputs used by the view operation. maybe_mark_symbolic(meta.get()); +======= +void FunctionalTensorWrapper::mutate_view_meta(const at::functionalization::ViewMeta& meta) { + view_metas_.push_back(meta); + // Manually track the fact that this tensor recieved a metadata mutation! + has_metadata_mutation_ = true; + // Mark this tensor as being symbolic if there are any symbolic inputs used by the view operation. + maybe_mark_symbolic(meta); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Note [Functionalization Pass - Inplace View Ops] // So, these ops are special - they're mutation AND view ops. They get special codegen. // An example is transpose_, e.g. `a.transpose_()` // Calling transpose_() should ensure that a gets an alias, and append the new ViewMeta to a's current list of ViewMetas. at::AutoDispatchSkipFunctionalize guard; +<<<<<<< HEAD value_ = meta->forward(value_); +======= + value_ = meta.forward_fn(value_, meta.out_index); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize)); } @@ -274,7 +309,11 @@ void FunctionalTensorWrapper::set__impl(const FunctionalTensorWrapper* other) { // (We could check if the updated value has a new storage than the original value, // but this won't also let us uniquely determine if the tensor **also** // experienced a data mutation). +<<<<<<< HEAD mark_storage_changed(); +======= + was_storage_changed_ = true; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto sizes_ = value_.sym_sizes(); auto strides_ = value_.sym_strides(); @@ -287,11 +326,19 @@ void FunctionalTensorWrapper::storage_resize_(const c10::SymInt& new_size) { // storage resizing is severely limited: we only support resizing either to zero, or from zero bytes. TORCH_CHECK(new_size == 0 || curr_storage_size == 0, "new_size: ", new_size, ". curr_storage_size: ", curr_storage_size); // The "functionalization rule" for storage resizing is a giant no-op, mainly because we don't want +<<<<<<< HEAD // resize_() calls to actually emit any ops in the functional graph. // How does it work? // Resizing up (old size == 0): // We do nothing in this case. // The expectation is that for the user code to be valid, the next op that should run against the current tensor "x" +======= + // resize_() calls to actualy emit any ops in the functional graph. + // How does it work? + // Resizing up (old size == 0): + // We do nothing in this case. + // The expection is that for the user code to be valid, the next op that should run against the current tensor "x" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // will be a x.copy_(y) (or similar), that will fully overwrite the data of x. // If there are any outstanding aliases of x, we expect them not to be used until after the copy_() call // (otherwise the eager code would be invalid), @@ -328,7 +375,11 @@ void FunctionalTensorWrapper::maybe_replace_storage(const Tensor& other) { // We're also no longer re-generate "b" fully from "a" anymore, since "a" refers to a slice of "b"'s data. // // This is probably fixable in theory, but: +<<<<<<< HEAD // - the fix would likely complicated the functionalization logic quite a bit. +======= + // - the fix would likey complicated the functionalization logic quite a bit. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // - the primary use case for resize_() today is resizing zero-sized tensors in out= variants of operators // - resize_() also can give you weird results today if you try to resize_() a weirdly strided tensor. // @@ -345,7 +396,11 @@ void FunctionalTensorWrapper::maybe_replace_storage(const Tensor& other) { set_sizes_and_strides(value_.sizes(), value_.strides()); refresh_numel(); // (Technically we should be guaranteed that the tensor was already contiguous, +<<<<<<< HEAD // since it's guaranteed not to have been a view. Doesn't hurt to run though) +======= + // since it's guaranteed not to have been a view. Doesnt hurt to run though) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) refresh_contiguous(); // Swapping out the storage of a tensor (aka from a resize_() call) will update the sizes and strides of the tensor, // so we need to record the fact that metadata was mutated. @@ -369,8 +424,20 @@ void FunctionalTensorWrapper::sync_() { regenerate_from_base(); } +<<<<<<< HEAD const std::vector>& FunctionalTensorWrapper::view_metas() const { return view_metas_; +======= +Tensor FunctionalTensorWrapper::apply_view_metas(const Tensor& base) { + auto t = base; + + // Reapply views to get the viewed tensor from the base in alias_ + for (auto& view_meta: view_metas_) { + t = view_meta.forward_fn(t, view_meta.out_index); + } + + return t; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } void FunctionalTensorWrapper::regenerate_from_base() { @@ -379,7 +446,11 @@ void FunctionalTensorWrapper::regenerate_from_base() { auto t = storage_impl->base(); TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t)); +<<<<<<< HEAD t = at::functionalization::impl::apply_view_meta_sequence(t, view_metas_); +======= + t = apply_view_metas(t); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t)); replace_(t, /*from_lazy_regenerate=*/true); @@ -493,8 +564,13 @@ int64_t FunctionalTensorWrapper::dim_custom() const { int64_t FunctionalTensorWrapper::numel_custom() const { return value_.unsafeGetTensorImpl()->numel(); } +<<<<<<< HEAD c10::SymBool FunctionalTensorWrapper::sym_is_contiguous_custom(at::MemoryFormat memory_format) const { return value_.unsafeGetTensorImpl()->sym_is_contiguous(memory_format); +======= +bool FunctionalTensorWrapper::is_contiguous_custom(at::MemoryFormat memory_format) const { + return value_.unsafeGetTensorImpl()->is_contiguous(memory_format); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } c10::SymIntArrayRef FunctionalTensorWrapper::sym_sizes_custom() const { return value_.unsafeGetTensorImpl()->sym_sizes(); @@ -573,7 +649,11 @@ std::vector from_functional_tensor(ITensorListRef t_list) { for (const auto& tensor : t_list) { // from_functional_tensor(Tensor) has asserts to make sure you don't accidentally call // it on a non-functional input, +<<<<<<< HEAD // but from_functional_tensor(TensorList) can receive a list containing both +======= + // but from_functional_tensor(TensorList) can recieve a list containing both +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // functional and non-functional tensors. // Example of when that can happen: torch.cat(function_input_tensor, global_state_tensor). // When that happens, we're okay with only unwrapping the functional tensors. @@ -718,11 +798,19 @@ bool isFunctionalTensor(const std::optional& t) { } bool isFunctionalTensor(const c10::List<::std::optional>& t_list) { +<<<<<<< HEAD if (t_list.empty()) { return false; } auto functional_count = 0; for (const auto i : c10::irange(t_list.size())) { auto const & e= t_list[i]; if (!e.has_value() || !e->defined()) { continue; } +======= + if (t_list.empty()) return false; + auto functional_count = 0; + for (const auto i : c10::irange(t_list.size())) { + auto const & e= t_list[i]; + if (!e.has_value() || !e->defined()) continue; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (isFunctionalTensor(e)) { ++functional_count; } @@ -732,10 +820,17 @@ bool isFunctionalTensor(const c10::List<::std::optional>& t_list) { template static bool isFunctionalTensorIListRef(c10::IListRef list) { +<<<<<<< HEAD if (list.size() == 0) { return false; } auto functional_count = 0; for (const auto& tensor : list) { if (!tensor.defined()) { continue; } +======= + if (list.size() == 0) return false; + auto functional_count = 0; + for (const auto& tensor : list) { + if (!tensor.defined()) continue; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (isFunctionalTensor(tensor)) { ++functional_count; } @@ -753,6 +848,7 @@ void freeze_functional_tensor(const Tensor& tensor) { functional_base_impl->freeze_storage(); } +<<<<<<< HEAD Tensor create_functional_tensor_with_view_meta( const at::Tensor& view_to_wrap, const at::Tensor& base, @@ -762,10 +858,17 @@ Tensor create_functional_tensor_with_view_meta( TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(base)); auto functional_base_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(base); auto meta_ = meta; +======= +Tensor create_functional_tensor_with_view_meta(const at::Tensor& view_to_wrap, const at::Tensor& base, functionalization::ViewMeta meta, int64_t out_idx) { + TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(view_to_wrap)); + TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(base)); + auto functional_base_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(base); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (out_idx != 0) { // Note [out_idx in ViewMeta] // When a view op outputs multiple tensors, each output needs its own separate ViewMeta. // Each ViewMeta also tracks the index of the particular output tensor, which is needed in the reverse function. +<<<<<<< HEAD meta_ = meta->to_out_index(out_idx); } return at::detail::make_tensor(view_to_wrap, functional_base_impl, meta_); @@ -775,6 +878,14 @@ std::vector create_functional_tensor_with_view_meta( ITensorListRef view_to_wrap, const at::Tensor& base, const std::shared_ptr& meta) { +======= + meta = meta.to_out_idx(out_idx); + } + return at::detail::make_tensor(view_to_wrap, functional_base_impl, meta); +} + +std::vector create_functional_tensor_with_view_meta(ITensorListRef view_to_wrap, const at::Tensor& base, const functionalization::ViewMeta& meta) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::vector outputs(view_to_wrap.size()); int64_t i = 0; for (const auto& tensor : view_to_wrap) { @@ -784,12 +895,17 @@ std::vector create_functional_tensor_with_view_meta( return outputs; } +<<<<<<< HEAD void mutate_view_meta(const at::Tensor& self, const std::shared_ptr& meta) { +======= +void mutate_view_meta(const at::Tensor& self, const functionalization::ViewMeta& meta) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self)); auto self_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self); self_impl->mutate_view_meta(meta); } +<<<<<<< HEAD Tensor apply_view_meta_sequence( const Tensor& base, const std::vector>& sequence) { @@ -800,6 +916,8 @@ Tensor apply_view_meta_sequence( return r; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Note [Propagating strides in the functionalization pass] // In order to properly compute stride information, the functionalization pass // calls each {view} reference implementations with meta tensors. @@ -831,7 +949,11 @@ void setFunctionalizationReapplyViewsTLS(bool reapply_views) { // This function will "functionalize" it. // That is, it will call the operator, but removing any intermediate views/mutations // that are performed inside of it. +<<<<<<< HEAD // This is useful for LTC/XLA, which would like to reuse some of our composite kernels +======= +// This is useful for LTC/XLA, which would like to re-use some of our composite kernels +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // from pytorch core but not have to worry about the view ops that they might call. // e.g. at::block_diag void functionalize_op_helper(const c10::OperatorHandle& op, torch::jit::Stack* stack) { @@ -893,7 +1015,11 @@ void functionalize_op_helper(const c10::OperatorHandle& op, torch::jit::Stack* s const auto& ivalue = returns[idx]; if (ivalue.isTensor()) { const auto& t = ivalue.toTensor(); +<<<<<<< HEAD if (!t.defined()) { continue; } +======= + if (!t.defined()) continue; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) at::functionalization::impl::sync(t); auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(t)); (*stack)[returns_begin + idx] = t_new; diff --git a/aten/src/ATen/FunctionalTensorWrapper.h b/aten/src/ATen/FunctionalTensorWrapper.h index 6d9050728da70..3beade67d507a 100644 --- a/aten/src/ATen/FunctionalTensorWrapper.h +++ b/aten/src/ATen/FunctionalTensorWrapper.h @@ -56,7 +56,11 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { explicit FunctionalTensorWrapper( const Tensor& view_value, const FunctionalTensorWrapper* base, +<<<<<<< HEAD const std::shared_ptr& meta); +======= + const functionalization::ViewMeta& meta); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Get the underlying, actual tensor, that doesn't know anything about // functionalization. @@ -74,9 +78,13 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { bool has_metadata_mutation() const { return has_metadata_mutation_; } +<<<<<<< HEAD uint64_t mutation_counter() const { return functional_storage_impl()->mutation_counter(); } +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void mark_mutation() { functional_storage_impl()->mark_mutation(); } @@ -99,17 +107,28 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { ->are_all_mutations_under_no_grad_or_inference_mode(); } +<<<<<<< HEAD void maybe_mark_symbolic(functionalization::ViewMeta* meta) { is_symbolic_ = is_symbolic_ | meta->has_symbolic_inputs; +======= + void maybe_mark_symbolic(const functionalization::ViewMeta& meta) { + is_symbolic_ = is_symbolic_ | meta.has_symbolic_inputs; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } bool is_symbolic() const { return is_symbolic_; } +<<<<<<< HEAD // Retrieves the ViewMeta sequence of this tensor. const std::vector>& view_metas() const; +======= + // Runs the forward_fn of every ViewMeta collected in the current instance + // to some other base. + Tensor apply_view_metas(const Tensor& base); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Sync's the underlying tensor with its alias, if it's out of date. This // involves two steps: 1) Apply any pending updates/mutations to the alias 2) @@ -146,8 +165,12 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { // from the base tensor. This method is used by inplace-view ops like // transpose_. It appends a ViewMeta to the existing stack, and refreshes the // tensor by replaying the views off of the alias. +<<<<<<< HEAD void mutate_view_meta( const std::shared_ptr& meta); +======= + void mutate_view_meta(const at::functionalization::ViewMeta& meta); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Custom implementation of self.set_(src) void set__impl(const FunctionalTensorWrapper* other); @@ -164,6 +187,7 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { return was_storage_changed_; } +<<<<<<< HEAD void mark_storage_changed() { was_storage_changed_ = true; storage_changed_counter_++; @@ -171,6 +195,10 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { uint64_t storage_changed_counter() { return storage_changed_counter_; +======= + void set_storage_changed() { + was_storage_changed_ = true; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // A FunctionalTensor is considered a base if its not a view of another @@ -189,9 +217,12 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { return functional_storage_impl()->was_inductor_storage_resized(); } +<<<<<<< HEAD bool inductor_storage_resized_counter() { return functional_storage_impl()->inductor_storage_resized_counter(); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // The functionalization pass can be used to remove mutations. // It does so by replacing any mutation op with it's corresponding // out-of-place op, followed by a call to replace_(). e.g: @@ -237,8 +268,12 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { at::IntArrayRef strides_custom() const override; int64_t dim_custom() const override; int64_t numel_custom() const override; +<<<<<<< HEAD c10::SymBool sym_is_contiguous_custom( at::MemoryFormat memory_format) const override; +======= + bool is_contiguous_custom(at::MemoryFormat memory_format) const override; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) c10::SymIntArrayRef sym_sizes_custom() const override; c10::SymInt sym_size_custom(int64_t d) const override; c10::SymIntArrayRef sym_strides_custom() const override; @@ -281,12 +316,19 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { bool is_multi_output_view_ = false; // Did the tensor experience a set_() call. bool was_storage_changed_ = false; +<<<<<<< HEAD uint64_t storage_changed_counter_ = 0; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Did the tensor experience any view operation with symbolic int. bool is_symbolic_ = false; size_t generation_ = 0; +<<<<<<< HEAD std::vector> view_metas_; +======= + std::vector view_metas_; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) protected: static void copy_tensor_metadata( @@ -301,7 +343,11 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { namespace functionalization { namespace impl { +<<<<<<< HEAD inline FunctionalTensorWrapper* unsafeGetFunctionalWrapper( +======= +TORCH_API inline FunctionalTensorWrapper* unsafeGetFunctionalWrapper( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const Tensor& tensor) { auto functional_impl = static_cast(tensor.unsafeGetTensorImpl()); @@ -378,11 +424,16 @@ TORCH_API void propagate_xla_data_direct( Tensor create_functional_tensor_with_view_meta( const Tensor& view_to_wrap, const Tensor& base, +<<<<<<< HEAD const std::shared_ptr& meta, +======= + functionalization::ViewMeta meta, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int64_t out_idx = 0); std::vector create_functional_tensor_with_view_meta( ITensorListRef view_to_wrap, const Tensor& base, +<<<<<<< HEAD const std::shared_ptr& meta); void mutate_view_meta( @@ -392,6 +443,13 @@ void mutate_view_meta( TORCH_API Tensor apply_view_meta_sequence( const Tensor& base, const std::vector>& sequence); +======= + const functionalization::ViewMeta& meta); + +void mutate_view_meta( + const Tensor& self, + const functionalization::ViewMeta& meta); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void set_sizes_strides_offset(const Tensor& out, const Tensor& meta_out); void set_sizes_strides_offset( diff --git a/aten/src/ATen/FunctionalizeFallbackKernel.cpp b/aten/src/ATen/FunctionalizeFallbackKernel.cpp index 10f988b4d2815..51f1e43a68498 100644 --- a/aten/src/ATen/FunctionalizeFallbackKernel.cpp +++ b/aten/src/ATen/FunctionalizeFallbackKernel.cpp @@ -1,5 +1,8 @@ +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include @@ -9,6 +12,10 @@ #include #include #include +<<<<<<< HEAD +======= +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #ifndef AT_PER_OPERATOR_HEADERS #include @@ -29,6 +36,7 @@ #include #endif +<<<<<<< HEAD namespace at::functionalization { Tensor resize__ViewMeta::forward(const Tensor& base) { @@ -54,6 +62,8 @@ Tensor _unsafe_view_ViewMeta::reverse(const Tensor& base, const Tensor& mutated_ } // namespace at::functionalization +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace { void functionalizeFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatchKeySet [[maybe_unused]], torch::jit::Stack* stack) { const auto& schema = op.schema(); @@ -132,9 +142,13 @@ namespace { const auto& ivalue = returns[idx]; if (ivalue.isTensor() && should_wrap_outputs) { const auto& t = ivalue.toTensor(); +<<<<<<< HEAD if (!t.defined()) { continue; } +======= + if (!t.defined()) continue; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(t)); (*stack)[returns_begin + idx] = t_new; } else if (ivalue.isTensorList() && should_wrap_outputs) { @@ -197,8 +211,24 @@ static const at::Tensor & resize__functionalization(c10::DispatchKeySet dispatch // The output of resizing is equivalent to taking a slice of a larger tensor. // We have to emulate this "slicing" with an as_strided call. auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS(); +<<<<<<< HEAD auto view_meta = std::make_shared( reapply_views, size.vec()); +======= + at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta( + [reapply_views = reapply_views, size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor { + if (reapply_views) { + return base.as_strided(size, c10::contiguous_strides(size)); + } else { + return at::as_strided_copy(base, size, c10::contiguous_strides(size)); + } + }, + [size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor { + return base.as_strided_scatter(mutated_view, size, c10::contiguous_strides(size)); + }, + /*has_symbolic_inputs=*/false + ); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) at::functionalization::impl::mutate_view_meta(self, view_meta); return self; } @@ -317,11 +347,25 @@ static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymInt tmp_output = at::_unsafe_view_symint(self_, size); } +<<<<<<< HEAD bool has_symbolic_inputs = std::any_of( size.begin(), size.end(), [=](auto& s) { return s.is_symbolic(); }); auto view_meta = std::make_shared( has_symbolic_inputs, size.vec()); +======= + bool has_symbolic_inputs = std::any_of(size.begin(), size.end(), [=](auto& s) { return s.is_symbolic(); }); + + at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta( + [size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor { + return at::_unsafe_view_symint(base, size); + }, + [size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor { + return at::_unsafe_view_symint(mutated_view, base.sym_sizes()); + }, + /*has_symbolic_inputs=*/has_symbolic_inputs + ); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, self, std::move(view_meta)); // See Note [Propagating strides in the functionalization pass] @@ -331,9 +375,17 @@ static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymInt auto stride = at::detail::computeStride(self.sym_sizes(), self.sym_strides(), inferred_size); if (!stride.has_value()) { +<<<<<<< HEAD TORCH_SYM_CHECK( self.sym_is_contiguous(), +======= + // With unbacked symints, computeStride could fail even on contiguous + // tensors. In this case, we can use the strides of an empty tensor of + // inferred_size. + TORCH_CHECK( + self.is_contiguous(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "View is not valid from size:", self.sym_sizes(), " stride: ", @@ -342,9 +394,12 @@ static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymInt inferred_size, " in case of unbacked symbols consider adding torch.check to guide computing strides."); +<<<<<<< HEAD // With unbacked symints, computeStride could fail even on contiguous // tensors. In this case, we can use the strides of an empty tensor of // inferred_size. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) stride = at::detail::empty_symint_meta( inferred_size, std::nullopt, diff --git a/aten/src/ATen/LegacyBatchedFallback.cpp b/aten/src/ATen/LegacyBatchedFallback.cpp index f2b527302a97b..f49559e21f97f 100644 --- a/aten/src/ATen/LegacyBatchedFallback.cpp +++ b/aten/src/ATen/LegacyBatchedFallback.cpp @@ -218,7 +218,11 @@ static Tensor safeStack(TensorList tensors) { // is possible for the backward function to return an undefined grad for some // grad_input for each example. In that case, we return an undefined grad. // +<<<<<<< HEAD // It is theoretically possible for *some* of the examples to produce an +======= + // It is theoretically posssible for *some* of the examples to produce an +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // undefined grad (a kernel could peek at the gradient values and return an // undefined tensor if it determines the gradient is full of zeros). We // could handle this by treating the undefined grad as a zero-filled tensor diff --git a/aten/src/ATen/LegacyBatchedTensorImpl.cpp b/aten/src/ATen/LegacyBatchedTensorImpl.cpp index cceefe985a7e2..d944682a2e8e2 100644 --- a/aten/src/ATen/LegacyBatchedTensorImpl.cpp +++ b/aten/src/ATen/LegacyBatchedTensorImpl.cpp @@ -84,7 +84,11 @@ IntArrayRef BatchedTensorImpl::strides_custom() const { // TODO: implement proper contiguity on batched tensor, then put // sizes_strides_policy back to Default +<<<<<<< HEAD c10::SymBool BatchedTensorImpl::sym_is_contiguous_custom(at::MemoryFormat memory_format) const { +======= +bool BatchedTensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_CHECK(memory_format == MemoryFormat::Contiguous, "NYI: querying is_contiguous inside of vmap for memory_format ", "other than torch.contiguous_format"); diff --git a/aten/src/ATen/LegacyBatchedTensorImpl.h b/aten/src/ATen/LegacyBatchedTensorImpl.h index 798e3535af3fb..22d2400b26a9e 100644 --- a/aten/src/ATen/LegacyBatchedTensorImpl.h +++ b/aten/src/ATen/LegacyBatchedTensorImpl.h @@ -82,8 +82,12 @@ struct TORCH_API BatchedTensorImpl : public c10::TensorImpl { IntArrayRef strides_custom() const override; // Override a bunch of methods inherited from TensorImpl to return error // messages. +<<<<<<< HEAD c10::SymBool sym_is_contiguous_custom( at::MemoryFormat memory_format) const override; +======= + bool is_contiguous_custom(at::MemoryFormat memory_format) const override; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void set_size(int64_t dim, int64_t new_size) override; void set_stride(int64_t dim, int64_t new_stride) override; void set_storage_offset(int64_t storage_offset) override; diff --git a/aten/src/ATen/LegacyVmapTransforms.h b/aten/src/ATen/LegacyVmapTransforms.h index be6cf1b697a22..3ca5c09332f98 100644 --- a/aten/src/ATen/LegacyVmapTransforms.h +++ b/aten/src/ATen/LegacyVmapTransforms.h @@ -140,7 +140,11 @@ struct TORCH_API VmapPhysicalView { // mapping a physical tensor to a new logical tensor (BatchedTensor) VmapPhysicalToLogicalMap getPhysicalToLogicalMap() const; +<<<<<<< HEAD // Maps a logical shape to a physical shape by prepending the batch +======= + // Maps a logical shape to a physical shape by pre-pending the batch +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // sizes to the logical shape. VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const; diff --git a/aten/src/ATen/MapAllocator.cpp b/aten/src/ATen/MapAllocator.cpp index 63a278050e8a7..8aad67f3879e1 100644 --- a/aten/src/ATen/MapAllocator.cpp +++ b/aten/src/ATen/MapAllocator.cpp @@ -299,7 +299,11 @@ MapAllocator::MapAllocator(WithFd, std::string_view filename, int fd, int flags, ::close(fd); TORCH_CHECK(false, "unable to stretch file <", filename_, "> to the right size: ", c10::utils::str_error(last_err), " (", last_err, ")"); } +<<<<<<< HEAD /* on macOS write returns with errno 45 (Operation not supported) when used +======= +/* on macOS write returns with errno 45 (Opperation not supported) when used +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) * with a file descriptor obtained via shm_open */ #ifndef __APPLE__ diff --git a/aten/src/ATen/MemoryOverlap.cpp b/aten/src/ATen/MemoryOverlap.cpp index 1bc8c30158aec..004e35b82904c 100644 --- a/aten/src/ATen/MemoryOverlap.cpp +++ b/aten/src/ATen/MemoryOverlap.cpp @@ -24,7 +24,11 @@ MemOverlap has_internal_overlap(TensorImpl* t) { } } +<<<<<<< HEAD if (t->is_non_overlapping_and_dense_or_false()) { +======= + if (t->is_non_overlapping_and_dense()) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return MemOverlap::No; } @@ -63,7 +67,11 @@ MemOverlapStatus get_overlap_status(const TensorImpl* a, const TensorImpl* b) { if (a->numel() == 0 || b->numel() == 0) { return MemOverlapStatus::No; } +<<<<<<< HEAD if (!a->is_non_overlapping_and_dense_or_false() || !b->is_non_overlapping_and_dense_or_false()) { +======= + if (!a->is_non_overlapping_and_dense() || !b->is_non_overlapping_and_dense()) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return MemOverlapStatus::TooHard; } // Test for storage equality, rather than pointer equality. diff --git a/aten/src/ATen/NamedTensorUtils.h b/aten/src/ATen/NamedTensorUtils.h index c6198dccd2431..9f0388c7fccf0 100644 --- a/aten/src/ATen/NamedTensorUtils.h +++ b/aten/src/ATen/NamedTensorUtils.h @@ -167,14 +167,22 @@ TORCH_API TensorImpl* propagate_names( TORCH_API void propagate_names(TensorImpl* result, /*const */ TensorImpl* src); +<<<<<<< HEAD inline void propagate_names( +======= +TORCH_API inline void propagate_names( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const TensorBase& result, DimnameList names, bool validate_names = false) { propagate_names(result.unsafeGetTensorImpl(), names, validate_names); } +<<<<<<< HEAD inline void propagate_names_if_nonempty( +======= +TORCH_API inline void propagate_names_if_nonempty( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const TensorBase& result, DimnameList names, bool validate_names = false) { @@ -182,7 +190,13 @@ inline void propagate_names_if_nonempty( result.unsafeGetTensorImpl(), names, validate_names); } +<<<<<<< HEAD inline void propagate_names(const TensorBase& result, const TensorBase& src) { +======= +TORCH_API inline void propagate_names( + const TensorBase& result, + const TensorBase& src) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) propagate_names(result.unsafeGetTensorImpl(), src.unsafeGetTensorImpl()); } diff --git a/aten/src/ATen/NestedTensorImpl.cpp b/aten/src/ATen/NestedTensorImpl.cpp index 63bd867f90220..2d57df694e03d 100644 --- a/aten/src/ATen/NestedTensorImpl.cpp +++ b/aten/src/ATen/NestedTensorImpl.cpp @@ -211,7 +211,11 @@ NestedTensorImpl::NestedTensorImpl( } // assume contiguous, `nested_strides` and `offsets` +<<<<<<< HEAD // can be inferred from `nested_sizes` +======= +// can be infered from `nested_sizes` +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) NestedTensorImpl::NestedTensorImpl( const at::Tensor& buffer, const at::Tensor& nested_sizes) @@ -273,7 +277,11 @@ c10::SymInt NestedTensorImpl::sym_numel_custom() const { return NestedTensorImpl::numel_custom(); } +<<<<<<< HEAD c10::SymBool NestedTensorImpl::sym_is_contiguous_custom(MemoryFormat) const { +======= +bool NestedTensorImpl::is_contiguous_custom(MemoryFormat) const { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return nested_tensor_impl_is_contiguous(this); } IntArrayRef NestedTensorImpl::sizes_custom() const { diff --git a/aten/src/ATen/NestedTensorImpl.h b/aten/src/ATen/NestedTensorImpl.h index cddf37df34a52..48f83708088d1 100644 --- a/aten/src/ATen/NestedTensorImpl.h +++ b/aten/src/ATen/NestedTensorImpl.h @@ -32,7 +32,11 @@ struct TORCH_API NestedTensorImpl : public c10::TensorImpl { at::Tensor nested_strides, at::Tensor storage_offsets); // assume contiguous, `nested_strides` and `offsets` +<<<<<<< HEAD // can be inferred from `nested_sizes` +======= + // can be infered from `nested_sizes` +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) explicit NestedTensorImpl( const at::Tensor& buffer, const at::Tensor& nested_sizes); @@ -115,7 +119,11 @@ struct TORCH_API NestedTensorImpl : public c10::TensorImpl { // with real implementations int64_t numel_custom() const override; c10::SymInt sym_numel_custom() const override; +<<<<<<< HEAD c10::SymBool sym_is_contiguous_custom(MemoryFormat) const override; +======= + bool is_contiguous_custom(MemoryFormat) const override; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int64_t size_custom(int64_t d) const override { return this->size(d); } diff --git a/aten/src/ATen/Parallel.h b/aten/src/ATen/Parallel.h index b55dad02f347e..dfcb6432a424e 100644 --- a/aten/src/ATen/Parallel.h +++ b/aten/src/ATen/Parallel.h @@ -93,12 +93,20 @@ ident: identity for binary combination function sf. sf(ident, x) needs to return x. f: function for reduction over a chunk. f needs to be of signature scalar_t +<<<<<<< HEAD f(int64_t partial_begin, int64_t partial_end, scalar_t identify) +======= +f(int64_t partial_begin, int64_t partial_end, scalar_t identifiy) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sf: function to combine two partial results. sf needs to be of signature scalar_t sf(scalar_t x, scalar_t y) +<<<<<<< HEAD For example, you might have a tensor of 10000 entries and want to sum together +======= +For example, you might have a tensor of 10000 entires and want to sum together +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) all the elements. Parallel_reduce with a grain_size of 2500 will then allocate an intermediate result tensor with 4 elements. Then it will execute the function "f" you provide and pass the beginning and end index of these chunks, so diff --git a/aten/src/ATen/SparseCsrTensorImpl.cpp b/aten/src/ATen/SparseCsrTensorImpl.cpp index f73d75ab53ad9..6f4d58cd81844 100644 --- a/aten/src/ATen/SparseCsrTensorImpl.cpp +++ b/aten/src/ATen/SparseCsrTensorImpl.cpp @@ -252,7 +252,14 @@ void SparseCsrTensorImpl::set_stride(int64_t dim, int64_t new_stride) { void SparseCsrTensorImpl::set_storage_offset(int64_t storage_offset) { TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have set_storage_offset."); } +<<<<<<< HEAD c10::SymBool SparseCsrTensorImpl::sym_is_contiguous_custom(MemoryFormat) const { TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have is_contiguous"); } +======= +bool SparseCsrTensorImpl::is_contiguous_custom(MemoryFormat) const { + TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have is_contiguous"); +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace at diff --git a/aten/src/ATen/SparseCsrTensorImpl.h b/aten/src/ATen/SparseCsrTensorImpl.h index 14688163a374f..b95b4001c8c06 100644 --- a/aten/src/ATen/SparseCsrTensorImpl.h +++ b/aten/src/ATen/SparseCsrTensorImpl.h @@ -86,7 +86,11 @@ struct TORCH_API SparseCsrTensorImpl : public TensorImpl { protected: IntArrayRef strides_custom() const override; SymIntArrayRef sym_strides_custom() const override; +<<<<<<< HEAD SymBool sym_is_contiguous_custom(MemoryFormat) const override; +======= + bool is_contiguous_custom(MemoryFormat) const override; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) public: void set_size(int64_t dim, int64_t new_size) override; diff --git a/aten/src/ATen/TensorIndexing.h b/aten/src/ATen/TensorIndexing.h index a487589833e8c..005d2ff1d11ee 100644 --- a/aten/src/ATen/TensorIndexing.h +++ b/aten/src/ATen/TensorIndexing.h @@ -214,7 +214,11 @@ inline Tensor applySlice( "step must be greater than zero"); // See NOTE [nested tensor size for indexing] +<<<<<<< HEAD if (self_sizes.has_value() && self_sizes.value().size() > 0) { +======= + if (self_sizes.has_value()) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Skip this optimization if we are tracing, as the trace may be polymorphic // over the shape of the `self` tensor, and we still want to record // the slice. @@ -223,7 +227,11 @@ inline Tensor applySlice( : self.sym_size(dim); if (!disable_slice_optimization && TORCH_STATICALLY_KNOWN_TRUE(start.sym_eq(0)) && +<<<<<<< HEAD TORCH_STATICALLY_KNOWN_TRUE(length.sym_le(stop)) && step == 1) { +======= + TORCH_STATICALLY_KNOWN_TRUE(length.sym_eq(stop)) && step == 1) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self; } } @@ -252,7 +260,11 @@ inline Tensor applySelect( // Note: `size >= -index` is not equivalent to `size > -1 - index` if index // is INT64_MIN For std::numeric_limits::min() result of unary // minus is undefined by the standard but in practice is equal to self. On +<<<<<<< HEAD // the other hand, indexing wrapping is valid for all negative int64_t +======= + // the other hand, indexing wraping is valid for all negative int64_t +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // values, as x[INT64_MIN] is the same as x[INT64_MAX] TORCH_CHECK_INDEX( size.sym_gt(-1 - index) @@ -315,6 +327,7 @@ inline void recordTensorIndex( const Tensor& tensor, std::vector& outIndices, int64_t* dim_ptr) { +<<<<<<< HEAD if (outIndices.empty()) { outIndices.resize(*dim_ptr + 1); outIndices[*dim_ptr] = tensor; @@ -326,6 +339,12 @@ inline void recordTensorIndex( } else { *dim_ptr += 1; } +======= + // TODO: check scalarType + outIndices.resize(*dim_ptr + 1); + outIndices[*dim_ptr] = tensor; + (*dim_ptr)++; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } inline c10::List<::std::optional> typeConvertIndices( @@ -465,6 +484,7 @@ inline Tensor handleDimInMultiDimIndexing( original_tensor_device, prev_dim_result_sizes); (*dim_ptr)++; +<<<<<<< HEAD if (!outIndices.empty()) { outIndices.resize(outIndices.size() + 1); } @@ -475,13 +495,21 @@ inline Tensor handleDimInMultiDimIndexing( if (!outIndices.empty()) { outIndices.resize(outIndices.size() + ellipsis_ndims); } +======= + return result; + } else if (index.is_ellipsis()) { + (*dim_ptr) += original_tensor.dim() - (*specified_dims_ptr); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return prev_dim_result; } else if (index.is_none()) { Tensor result = prev_dim_result.unsqueeze(*dim_ptr); (*dim_ptr)++; +<<<<<<< HEAD if (!outIndices.empty()) { outIndices.resize(outIndices.size() + 1); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return result; } else if (index.is_boolean()) { Tensor result = prev_dim_result.unsqueeze(*dim_ptr); @@ -577,10 +605,13 @@ inline Tensor applySlicing( inline Tensor dispatch_index( const Tensor& self, std::vector&& indices) { +<<<<<<< HEAD // Remove trailing null elements from indices while (!indices.empty() && !indices.back().defined()) { indices.pop_back(); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.index(impl::typeConvertIndices(self, std::move(indices))); } @@ -588,10 +619,13 @@ inline Tensor dispatch_index_put_( Tensor& self, std::vector&& indices, const Tensor& value) { +<<<<<<< HEAD // Remove trailing null elements from indices while (!indices.empty() && !indices.back().defined()) { indices.pop_back(); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.index_put_( impl::typeConvertIndices(self, std::move(indices)), value); } diff --git a/aten/src/ATen/TensorIterator.cpp b/aten/src/ATen/TensorIterator.cpp index 9096cbfc68eb6..7343d4a0472cf 100644 --- a/aten/src/ATen/TensorIterator.cpp +++ b/aten/src/ATen/TensorIterator.cpp @@ -208,7 +208,11 @@ bool TensorIteratorConfig::is_tensor_const(size_t idx) { // same strides are increasing. If dimensions are non-increasing, we move on to the next input to break the tie. // // Instead of applying rule 4 for tie breaking, we could move on to the next tensor directly. This would result in possibly +<<<<<<< HEAD // losing the correct permutation of the first tensor if there are permuted trivial dimensions, but could potentially +======= +// losing the correct permuation of the first tensor if there are permuted trivial dimensions, but could potentially +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // improve traversal order of the second tensor. We chose the former option to better propagate channels last layout // for example for a tensor with the sizes N1H1 // These rules result in the intuitive behavior that in most cases recovers permutation of either the first argument (if all @@ -244,7 +248,11 @@ void TensorIteratorBase::reorder_dimensions() { // initialize perm with n-1, n-2, ..., 1, 0 std::iota(perm_.rbegin(), perm_.rend(), 0); +<<<<<<< HEAD // Reordering dimensions changes iteration order +======= + // Reordering dimensions changes iteraton order +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (enforce_linear_iteration_) { permute_dimensions(perm_); return; diff --git a/aten/src/ATen/TensorIterator.h b/aten/src/ATen/TensorIterator.h index d8eebd4c06a42..cd087df807d80 100644 --- a/aten/src/ATen/TensorIterator.h +++ b/aten/src/ATen/TensorIterator.h @@ -388,7 +388,11 @@ struct TORCH_API TensorIteratorBase : public impl::MetaBase { /// Return scalar value from original_tensor_base if it is defined. When /// common_dtype is Half, casting scalar input to common_dtype might overflow. +<<<<<<< HEAD /// If the scalar is already given in the type of Half, then return scalar +======= + /// If the scalar is aleady given in the type of Half, then return scalar +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) /// value from tensor_base. template T original_scalar_value(int64_t arg) { @@ -502,7 +506,11 @@ struct TORCH_API TensorIteratorBase : public impl::MetaBase { /// kernels bool can_use_32bit_indexing() const; +<<<<<<< HEAD /// An "iterable" object that recursively splits this iterator into +======= + /// An "iteratable" object that recursively splits this iterator into +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) /// sub-iterators that can use 32-bit indexing. SplitUntil32Bit with_32bit_indexing() const; @@ -878,7 +886,11 @@ class TORCH_API TensorIteratorConfig final { // Sets the enforce_linear_iteration_ flag, which is false by default. // If true, iteration goes in the same order as a C-contiguous tensor +<<<<<<< HEAD // is laid out in memory. i.e. last dimension iterates fastest. +======= + // is layed out in memory. i.e. last dimension iterates fastest. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // // This iteration order can be less efficient and may even prevent // vectorization. So only use if the correctness of your kernel depends on it. diff --git a/aten/src/ATen/TensorSubclassLikeUtils.h b/aten/src/ATen/TensorSubclassLikeUtils.h index 515642a0c51d2..73c3b3c2d87b2 100644 --- a/aten/src/ATen/TensorSubclassLikeUtils.h +++ b/aten/src/ATen/TensorSubclassLikeUtils.h @@ -78,7 +78,11 @@ inline bool areAnyOptionalTensorSubclassLike( // NOTE: This function expects a scalar tensor of boolean dtype. // Eg. // Non-Composite Compliant Pattern : (t == 0).all().item() +<<<<<<< HEAD // Composite Compliant Pattern : is_salar_tensor_true((t == 0).all()) +======= +// Composite Compliant Patter : is_salar_tensor_true((t == 0).all()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inline bool is_scalar_tensor_true(const Tensor& t) { TORCH_INTERNAL_ASSERT(t.dim() == 0) TORCH_INTERNAL_ASSERT(t.scalar_type() == kBool) diff --git a/aten/src/ATen/TensorUtils.cpp b/aten/src/ATen/TensorUtils.cpp index 34cb5329de6a3..87b56ef6eeb92 100644 --- a/aten/src/ATen/TensorUtils.cpp +++ b/aten/src/ATen/TensorUtils.cpp @@ -378,9 +378,15 @@ inline static std::optional computeStride_impl( (TORCH_GUARD_OR_TRUE(sym_ne(oldshape[tensor_d - 1], 1)) && TORCH_GUARD_OR_TRUE(sym_ne(oldstride[tensor_d - 1], tensor_numel * chunk_base_stride)))) { // We want to accumulate stuff in view_numel until view_numel == tensor_numel, if we do not +<<<<<<< HEAD // know if that is satisfied we keep accumulating. For example if view_numel = 1 and tensor_numel = u1, // we want to take that path, view_numel will become u0. Next iteration if u0==u1 we want to stop. // That's why we use TORCH_GUARD_OR_TRUE below. +======= + // know if that is satisfied we keep accumalating. For example if view_numel = 1 and tensor_numel = u1, + // we want to take that path, view_numel will become u0. Next iteration if u0==u1 we want to stop. + // Thats why we use TORCH_GUARD_OR_TRUE below. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // we use TORCH_GUARD_OR_FALSE and not TORCH_GUARD_OR_TRUE when comparing newshape[view_d] ==1 because // if we know view_numel < tensor_numel is false, we want to stop. Unless we know for sure newshape[view_d]==1 diff --git a/aten/src/ATen/ThreadLocalState.cpp b/aten/src/ATen/ThreadLocalState.cpp index 22509c7be4e19..bfdd8f9bf51c5 100644 --- a/aten/src/ATen/ThreadLocalState.cpp +++ b/aten/src/ATen/ThreadLocalState.cpp @@ -8,7 +8,10 @@ #include #include #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace at { @@ -20,7 +23,10 @@ ThreadLocalState::ThreadLocalState() torch_dispatch_mode_state_(c10::impl::TorchDispatchModeTLS::get_state()), python_dispatcher_state_(c10::impl::PythonDispatcherTLS::get_state()), python_torch_function_state_(at::impl::PythonTorchFunctionTLS::get_state()), saved_tensors_default_hooks_state_(at::SavedTensorDefaultHooks::get_tls_state()), functionalization_reapply_views_state_(at::functionalization::impl::getFunctionalizationReapplyViewsTLS()), +<<<<<<< HEAD dtensor_allow_implicit_replication_(at::get_dtensor_allow_implicit_replication()), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) saved_objects_(at::impl::ThreadLocalPythonObjects::get_state()) { #if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && !defined(BUILD_LITE_INTERPRETER) for(size_t i=0; i>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) c10::ThreadLocalDebugInfo::_forceCurrentDebugInfo(state.debug_info_); c10::impl::_force_tls_local_dispatch_key_set(state.dispatch_key_); diff --git a/aten/src/ATen/ThreadLocalState.h b/aten/src/ATen/ThreadLocalState.h index d0d8112fc4cec..7728f99799d60 100644 --- a/aten/src/ATen/ThreadLocalState.h +++ b/aten/src/ATen/ThreadLocalState.h @@ -75,8 +75,11 @@ class TORCH_API ThreadLocalState { bool functionalization_reapply_views_state_; +<<<<<<< HEAD bool dtensor_allow_implicit_replication_; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // TLS for arbitrary python objects that is registered via hooks at::impl::ThreadLocalPythonObjects saved_objects_; diff --git a/aten/src/ATen/TracerMode.h b/aten/src/ATen/TracerMode.h index d0d4c93a84f53..bffd0abda1533 100644 --- a/aten/src/ATen/TracerMode.h +++ b/aten/src/ATen/TracerMode.h @@ -27,7 +27,11 @@ // ops (ops being called by other ops). After the intermediate op call // finishes it's set back to the original `TracingState` object. // +<<<<<<< HEAD // The `TracingState` object in TLS can also be read/written via its Python +======= +// The `TracingState` obect in TLS can also be read/written via its Python +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // binding in `python_tracer.cpp`, and `get/setTracingState()` C++ APIs, // which are also exposed as `TORCH_API`. // diff --git a/aten/src/ATen/Version.cpp b/aten/src/ATen/Version.cpp index 7239f357fdd64..5bc698b6cb3a7 100644 --- a/aten/src/ATen/Version.cpp +++ b/aten/src/ATen/Version.cpp @@ -95,6 +95,7 @@ std::string get_cpu_capability() { // environment variable auto capability = native::get_cpu_capability(); switch (capability) { +<<<<<<< HEAD case native::CPUCapability::DEFAULT: return "DEFAULT"; #if defined(HAVE_VSX_CPU_DEFINITION) @@ -107,6 +108,26 @@ std::string get_cpu_capability() { case native::CPUCapability::SVE256: return "SVE256"; #else +======= +#if defined(HAVE_VSX_CPU_DEFINITION) + case native::CPUCapability::DEFAULT: + return "DEFAULT"; + case native::CPUCapability::VSX: + return "VSX"; +#elif defined(HAVE_ZVECTOR_CPU_DEFINITION) + case native::CPUCapability::DEFAULT: + return "DEFAULT"; + case native::CPUCapability::ZVECTOR: + return "Z VECTOR"; +#elif defined(HAVE_SVE256_CPU_DEFINITION) && defined(HAVE_ARM_BF16_CPU_DEFINITION) + case native::CPUCapability::DEFAULT: + return "DEFAULT"; + case native::CPUCapability::SVE256: + return "SVE256"; +#else + case native::CPUCapability::DEFAULT: + return "NO AVX"; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) case native::CPUCapability::AVX2: return "AVX2"; case native::CPUCapability::AVX512: diff --git a/aten/src/ATen/WrapDimUtils.h b/aten/src/ATen/WrapDimUtils.h index aa000b118daa2..a4c5d1e5b1224 100644 --- a/aten/src/ATen/WrapDimUtils.h +++ b/aten/src/ATen/WrapDimUtils.h @@ -121,7 +121,11 @@ inline int64_t legacy_cat_wrap_dim_symint( const std::vector>& tensor_sizes) { for (auto& sizes : tensor_sizes) { if (sizes.size() == 1) { +<<<<<<< HEAD if (TORCH_GUARD_OR_FALSE(sizes[0].sym_eq(0))) { +======= + if (TORCH_GUARD_SIZE_OBLIVIOUS(sizes[0].sym_eq(0))) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue; } } @@ -135,7 +139,11 @@ inline int64_t legacy_cat_wrap_dim( const MaterializedITensorListRef& tensors) { for (const Tensor& tensor : tensors) { if (tensor.dim() == 1) { +<<<<<<< HEAD if (TORCH_GUARD_OR_FALSE(tensor.sym_sizes()[0].sym_eq(0))) { +======= + if (TORCH_GUARD_SIZE_OBLIVIOUS(tensor.sym_sizes()[0].sym_eq(0))) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue; } } diff --git a/aten/src/ATen/ZeroTensorFallback.cpp b/aten/src/ATen/ZeroTensorFallback.cpp index 40b34030b85b9..d29d2c981ad26 100644 --- a/aten/src/ATen/ZeroTensorFallback.cpp +++ b/aten/src/ATen/ZeroTensorFallback.cpp @@ -9,6 +9,7 @@ namespace at { +<<<<<<< HEAD /* * Design: * 1. ZeroTensors are regular tensors with TensorOptions, a storage @@ -39,6 +40,9 @@ namespace at { * it does not perfectly handle NaNs and Infs as we don't check the actual values * and assume that they are non-zero, non-inf, non-NaN etc. */ +======= + // TODO: add a note explaining the design decisions +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // ZeroTensors are designed to be immutable. Thus, we error out when an in-place operation is performed on ZeroTensors static void zeroTensorFallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) { const auto& arguments = op.schema().arguments(); @@ -124,7 +128,11 @@ namespace at { m.impl("clone", torch::CppFunction::makeFallthrough()); m.impl("dot", torch::CppFunction::makeFallthrough()); m.impl("vdot", torch::CppFunction::makeFallthrough()); +<<<<<<< HEAD // The functions in the list below have a specific registration in native_functions.yaml and +======= + // The functions in the list below have a specific registeration in native_functions.yaml and +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // do not use the fallback. // m.impl("mul.Tensor", torch::CppFunction::makeFallthrough()); // m.impl("add.Tensor", torch::CppFunction::makeFallthrough()); diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp index 4b8b5f6c5d187..8591df64ece48 100644 --- a/aten/src/ATen/autocast_mode.cpp +++ b/aten/src/ATen/autocast_mode.cpp @@ -216,7 +216,10 @@ TORCH_LIBRARY_IMPL(aten, AutocastMPS, m) { KERNEL_MPS(_convolution, lower_precision_fp) KERNEL_MPS(conv1d, lower_precision_fp) KERNEL_MPS(conv2d, lower_precision_fp) +<<<<<<< HEAD KERNEL_MPS(conv3d, lower_precision_fp) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) KERNEL_MPS(conv_tbc, lower_precision_fp) KERNEL_MPS(conv_transpose1d, lower_precision_fp) KERNEL_MPS(conv_transpose2d, input, lower_precision_fp) @@ -240,7 +243,10 @@ TORCH_LIBRARY_IMPL(aten, AutocastMPS, m) { KERNEL_MPS(scaled_dot_product_attention, lower_precision_fp) // fp32 +<<<<<<< HEAD KERNEL_MPS(conv_transpose3d, input, fp32) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) KERNEL_MPS(acos, fp32) KERNEL_MPS(asin, fp32) KERNEL_MPS(cosh, fp32) diff --git a/aten/src/ATen/autocast_mode.h b/aten/src/ATen/autocast_mode.h index 655b2343d5d5c..96d719004904c 100644 --- a/aten/src/ATen/autocast_mode.h +++ b/aten/src/ATen/autocast_mode.h @@ -25,7 +25,11 @@ TORCH_API void set_autocast_cache_enabled(bool enabled); // deprecated CUDA-specific autocast APIs C10_DEPRECATED_MESSAGE( "at::autocast::is_enabled() is deprecated. Please use at::autocast::is_autocast_enabled(at::kCUDA) instead.") +<<<<<<< HEAD inline bool is_enabled() { +======= +TORCH_API inline bool is_enabled() { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_WARN_DEPRECATION( "at::autocast::", __func__, @@ -34,7 +38,11 @@ inline bool is_enabled() { } C10_DEPRECATED_MESSAGE( "at::autocast::set_enabled(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(at::kCUDA, enabled) instead.") +<<<<<<< HEAD inline void set_enabled(bool enabled) { +======= +TORCH_API inline void set_enabled(bool enabled) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_WARN_DEPRECATION( "at::autocast::", __func__, @@ -43,7 +51,11 @@ inline void set_enabled(bool enabled) { } C10_DEPRECATED_MESSAGE( "at::autocast::get_autocast_gpu_dtype() is deprecated. Please use at::autocast::get_autocast_dtype(at::kCUDA) instead.") +<<<<<<< HEAD inline at::ScalarType get_autocast_gpu_dtype() { +======= +TORCH_API inline at::ScalarType get_autocast_gpu_dtype() { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_WARN_DEPRECATION( "at::autocast::", __func__, @@ -52,7 +64,11 @@ inline at::ScalarType get_autocast_gpu_dtype() { } C10_DEPRECATED_MESSAGE( "at::autocast::set_autocast_gpu_dtype(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(at::kCUDA, dtype) instead.") +<<<<<<< HEAD inline void set_autocast_gpu_dtype(at::ScalarType dtype) { +======= +TORCH_API inline void set_autocast_gpu_dtype(at::ScalarType dtype) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_WARN_DEPRECATION( "at::autocast::", __func__, @@ -65,7 +81,11 @@ inline void set_autocast_gpu_dtype(at::ScalarType dtype) { "at::autocast::is_" #name \ "_enabled() is deprecated. Please use at::autocast::is_autocast_enabled(" #device_type \ ") instead.") \ +<<<<<<< HEAD inline bool is_##name##_enabled() { \ +======= + TORCH_API inline bool is_##name##_enabled() { \ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_WARN_DEPRECATION( \ "at::autocast::", \ __func__, \ @@ -78,7 +98,11 @@ inline void set_autocast_gpu_dtype(at::ScalarType dtype) { "at::autocast::set_" #name \ "_enabled(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(" #device_type \ ", enabled) instead.") \ +<<<<<<< HEAD inline void set_##name##_enabled(bool enabled) { \ +======= + TORCH_API inline void set_##name##_enabled(bool enabled) { \ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_WARN_DEPRECATION( \ "at::autocast::", \ __func__, \ @@ -91,7 +115,11 @@ inline void set_autocast_gpu_dtype(at::ScalarType dtype) { "at::autocast::get_autocast_" #name \ "_dtype() is deprecated. Please use at::autocast::get_autocast_dtype(" #device_type \ ") instead.") \ +<<<<<<< HEAD inline at::ScalarType get_autocast_##name##_dtype() { \ +======= + TORCH_API inline at::ScalarType get_autocast_##name##_dtype() { \ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_WARN_DEPRECATION( \ "at::autocast::", \ __func__, \ @@ -104,7 +132,11 @@ inline void set_autocast_gpu_dtype(at::ScalarType dtype) { "at::autocast::set_autocast_" #name \ "_dtype(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(" #device_type \ ", dtype) instead.") \ +<<<<<<< HEAD inline void set_autocast_##name##_dtype(at::ScalarType dtype) { \ +======= + TORCH_API inline void set_autocast_##name##_dtype(at::ScalarType dtype) { \ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_WARN_DEPRECATION( \ "at::autocast::", \ __func__, \ @@ -377,7 +409,11 @@ Keep it simple for now by assuming only one such flag is present in the argument list. If I ever need a function with more than flag I'll figure out something else. The policy is: +<<<<<<< HEAD If the user has explicitly specified a dtype, respect it. +======= +If the user has explicity specified a dtype, respect it. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Otherwise, set it to the autocast type. ********************************************************/ diff --git a/aten/src/ATen/core/CachingHostAllocator.h b/aten/src/ATen/core/CachingHostAllocator.h index 53e95cd2d4cfd..ea4f2d13fffd6 100644 --- a/aten/src/ATen/core/CachingHostAllocator.h +++ b/aten/src/ATen/core/CachingHostAllocator.h @@ -1,7 +1,10 @@ #pragma once #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include @@ -252,7 +255,10 @@ struct CachingHostAllocatorImpl { auto* block = reinterpret_cast(ctx); std::optional> events; +<<<<<<< HEAD ska::flat_hash_set streams; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) { std::lock_guard g(block->mutex_); block->allocated_ = false; @@ -261,19 +267,29 @@ struct CachingHostAllocatorImpl { } else { events = std::vector(); events->reserve(block->streams_.size()); +<<<<<<< HEAD block->event_count_ += block->streams_.size(); // Move out streams to avoid holding the mutex during event recording streams = std::move(block->streams_); +======= + for (auto stream : block->streams_) { + record_stream(events, stream); + } + block->event_count_ += events->size(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) block->streams_.clear(); } } +<<<<<<< HEAD // Event recording must be done outside the mutex to avoid potential // deadlocks (e.g., when Python GIL is involved) for (auto stream : streams) { record_stream(events, stream); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (!events) { auto index = size_index(block->size_); std::lock_guard g(free_list_[index].mutex_); @@ -352,8 +368,12 @@ struct CachingHostAllocatorImpl { } virtual bool pinned_use_background_threads() { +<<<<<<< HEAD return c10::CachingAllocator::AcceleratorAllocatorConfig:: pinned_use_background_threads(); +======= + return false; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } virtual void copy_data(void* dest [[maybe_unused]], const void* src [[maybe_unused]], std::size_t count [[maybe_unused]]) const { diff --git a/aten/src/ATen/core/TensorBase.h b/aten/src/ATen/core/TensorBase.h index 5f43738ea0faf..e2c17c60ac7eb 100644 --- a/aten/src/ATen/core/TensorBase.h +++ b/aten/src/ATen/core/TensorBase.h @@ -1,5 +1,6 @@ #pragma once +<<<<<<< HEAD // See https://github.com/pytorch/pytorch/issues/161660 // This compile flag is intended to be passed in to CppExtensions that rely on // the stable ABI via the `extra_compile_args` argument. This is a stopgap @@ -13,6 +14,8 @@ "TensorBase.h should not be included when TORCH_STABLE_ONLY compile flag is passed" #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include @@ -137,7 +140,11 @@ class TORCH_API TensorBase { } TensorBase contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const { +<<<<<<< HEAD if (is_contiguous_or_false(memory_format)) { +======= + if (is_contiguous(memory_format)) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return *this; } else { return __dispatch_contiguous(memory_format); @@ -278,6 +285,7 @@ class TORCH_API TensorBase { return impl_->is_contiguous(memory_format); } +<<<<<<< HEAD // Like is_contiguous, but more dynamic shape-friendly. May return a symbolic representation of // contiguity instead of SymTrue SymFalse, when results are data-dependent. c10::SymBool sym_is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const { @@ -297,6 +305,8 @@ class TORCH_API TensorBase { return impl_->is_contiguous(memory_format); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool is_non_overlapping_and_dense() const { return impl_->is_non_overlapping_and_dense(); } diff --git a/aten/src/ATen/core/boxing/KernelFunction_impl.h b/aten/src/ATen/core/boxing/KernelFunction_impl.h index a89a0e8952b6e..a801b32583072 100644 --- a/aten/src/ATen/core/boxing/KernelFunction_impl.h +++ b/aten/src/ATen/core/boxing/KernelFunction_impl.h @@ -15,7 +15,11 @@ std::enable_if_t< std::is_base_of_v, std::unique_ptr> make_unique_base(Args&&... args) { +<<<<<<< HEAD return std::make_unique(std::forward(args)...); +======= + return std::unique_ptr(new Child(std::forward(args)...)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } // namespace detail diff --git a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h index 20dfde846e648..8381f88edd0cf 100644 --- a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h +++ b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h @@ -105,7 +105,11 @@ using supported_primitive_arg_types = guts::typelist::typelist< // So a valid input type is one that our boxed functor wrapper can // unbox from an IValue into a C++ value. // +<<<<<<< HEAD // Whereas a valid output type is one that our wrapper can receive +======= +// Whereas a valid output type is one that our wrapper can recieve +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // as a C++ value from the unboxed functor, and box into an IValue. // diff --git a/aten/src/ATen/core/dispatch/Dispatcher.cpp b/aten/src/ATen/core/dispatch/Dispatcher.cpp index 91a5f64596177..fd0c78b5d572e 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.cpp +++ b/aten/src/ATen/core/dispatch/Dispatcher.cpp @@ -568,9 +568,15 @@ bool Dispatcher::profilingOperatorEvents() { return TORCH_SDT_IS_ENABLED(operator_start) || TORCH_SDT_IS_ENABLED(operator_end); } +<<<<<<< HEAD C10_NOINLINE void Dispatcher::fireOpStartUSDT(at::RecordFunction::schema_ref_t schema_ref, std::vector& argsAddresses, std::vector& argsTypes) { if (TORCH_SDT_IS_ENABLED(operator_start)) { TORCH_SDT_WITH_SEMAPHORE(operator_start, schema_ref.get().name().c_str(), argsAddresses.size(), argsAddresses.data(), argsTypes.data()); +======= +C10_NOINLINE void Dispatcher::fireOpStartUSDT(at::RecordFunction::schema_ref_t schema_ref) { + if (TORCH_SDT_IS_ENABLED(operator_start)) { + TORCH_SDT_WITH_SEMAPHORE(operator_start, schema_ref.get().name().c_str()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h index bc043df6a93e9..2bd7f18a02321 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.h +++ b/aten/src/ATen/core/dispatch/Dispatcher.h @@ -371,10 +371,14 @@ class TORCH_API Dispatcher final { #ifdef FBCODE_CAFFE2 static bool profilingOperatorEvents(); +<<<<<<< HEAD static void fireOpStartUSDT( at::RecordFunction::schema_ref_t schema_ref, std::vector& argsAddresses, std::vector& argsTypes); +======= + static void fireOpStartUSDT(at::RecordFunction::schema_ref_t schema_ref); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static void fireOpEndUSDT(at::RecordFunction::schema_ref_t schema_ref); #endif // FBCODE_CAFFE2 @@ -798,6 +802,7 @@ C10_ALWAYS_INLINE_UNLESS_MOBILE Return Dispatcher::call( #ifdef FBCODE_CAFFE2 if (profilingOperatorEvents()) { +<<<<<<< HEAD std::vector argsAddresses = {(void*)(&args)...}; std::vector argsTypes = {(typeid(args).name())...}; struct FireOpRAII { @@ -807,12 +812,22 @@ C10_ALWAYS_INLINE_UNLESS_MOBILE Return Dispatcher::call( std::vector& argsTypes) : schema_ref_(schema_ref) { fireOpStartUSDT(schema_ref, argsAddresses, argsTypes); +======= + struct FireOpRAII { + FireOpRAII(at::RecordFunction::schema_ref_t schema_ref) + : schema_ref_(schema_ref) { + fireOpStartUSDT(schema_ref); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } ~FireOpRAII() { fireOpEndUSDT(schema_ref_); } at::RecordFunction::schema_ref_t schema_ref_; +<<<<<<< HEAD } event(op.schema(), argsAddresses, argsTypes); +======= + } event(op.schema()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return kernel.template call( op, dispatchKeySet, std::forward(args)...); } else { diff --git a/aten/src/ATen/core/dynamic_type.h b/aten/src/ATen/core/dynamic_type.h index 2ba841e44e202..696d8605cfc1c 100644 --- a/aten/src/ATen/core/dynamic_type.h +++ b/aten/src/ATen/core/dynamic_type.h @@ -64,7 +64,10 @@ constexpr DynamicTypeBits kDynamicClassTypeBit = DYNAMIC_TYPE_BIT(10); _(ScalarType, kDynamicIntTypeBit, 1) \ _(Layout, kDynamicIntTypeBit, 1) \ _(SymInt, kDynamicIntTypeBit, 1) \ +<<<<<<< HEAD _(SymBool, kDynamicIntTypeBit, 1) \ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _(MemoryFormat, kDynamicIntTypeBit, 1) #define FORWARD_DECL_TYPE(NAME, _, __) struct NAME ## Type; diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp index 72589436606ec..47e09b1bbfda4 100644 --- a/aten/src/ATen/core/ivalue.cpp +++ b/aten/src/ATen/core/ivalue.cpp @@ -97,8 +97,11 @@ c10::TypePtr IValue::TagType::get(const IValue& v) { return ComplexType::get(); case Tag::Int: return IntType::get(); +<<<<<<< HEAD case Tag::UInt: return IntType::get(); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) case Tag::SymInt: return c10::SymIntType::get(); case Tag::SymFloat: @@ -322,8 +325,11 @@ IValue IValue::equals(const IValue& rhs) const { return rhs.isComplexDouble() && lhs.toComplexDouble() == rhs.toComplexDouble(); case Tag::Int: return rhs.isInt() && lhs.toInt() == rhs.toInt(); +<<<<<<< HEAD case Tag::UInt: return rhs.isUnsigned() && lhs.toUInt() == rhs.toUInt(); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) case Tag::SymInt: return rhs.isSymInt() && lhs.toSymInt() == rhs.toSymInt(); case Tag::SymFloat: @@ -383,8 +389,11 @@ size_t IValue::hash(const IValue& v) { case Tag::Int: return c10::get_hash(v.payload.u.as_int); // NB: these are technically strict aliasing violations +<<<<<<< HEAD case Tag::UInt: return c10::get_hash(v.payload.u.as_int); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) case Tag::SymInt: return c10::get_hash(v.payload.u.as_int); case Tag::SymFloat: @@ -812,8 +821,11 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) { return printComplex(out, v); } case IValue::Tag::Int: return out << v.toInt(); +<<<<<<< HEAD case IValue::Tag::UInt: return out << v.toUInt(); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) case IValue::Tag::SymInt: return out << v.toSymInt(); case IValue::Tag::SymFloat: diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h index ab2039e058201..f7facc66a7906 100644 --- a/aten/src/ATen/core/ivalue.h +++ b/aten/src/ATen/core/ivalue.h @@ -12,7 +12,10 @@ #include #include #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include @@ -161,7 +164,10 @@ struct Capsule { _(Double) \ _(ComplexDouble) \ _(Int) \ +<<<<<<< HEAD _(UInt) \ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _(SymInt) \ _(SymFloat) \ _(SymBool) \ @@ -655,6 +661,7 @@ struct TORCH_API IValue final { } } +<<<<<<< HEAD // Unsigned IValue(uint64_t u) : tag( u <= std::numeric_limits::max() ? Tag::Int : Tag::UInt) { payload.u.as_uint = u; @@ -678,6 +685,8 @@ struct TORCH_API IValue final { } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Bool IValue(bool b) : tag(Tag::Bool) { #if defined(__clang__) && defined(__x86_64__) @@ -918,6 +927,7 @@ struct TORCH_API IValue final { } else { TORCH_INTERNAL_ASSERT_DEBUG_ONLY( s.isIntegral(false), "Unknown type in Scalar"); +<<<<<<< HEAD if (s.isUnsigned()) { const auto val = s.toUInt64(); payload.u.as_uint = val; @@ -926,6 +936,10 @@ struct TORCH_API IValue final { payload.u.as_int = s.toLong(); tag = Tag::Int; } +======= + tag = Tag::Int; + payload.u.as_int = s.toLong(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } @@ -949,8 +963,11 @@ struct TORCH_API IValue final { return toSymFloat(); else if (isSymBool()) return toSymBool(); +<<<<<<< HEAD else if (isUnsigned()) return toUInt(); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_CHECK(false, "IValue is not a Scalar"); } @@ -1280,8 +1297,11 @@ struct TORCH_API IValue final { return true; case Tag::Int: return false; +<<<<<<< HEAD case Tag::UInt: return false; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) case Tag::SymInt: return true; case Tag::SymFloat: @@ -1378,8 +1398,11 @@ struct TORCH_API IValue final { union TriviallyCopyablePayload { TriviallyCopyablePayload() : as_int(0) {} int64_t as_int; +<<<<<<< HEAD // See Note [Meaning of HAS_u] uint64_t as_uint; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) double as_double; bool as_bool; // Invariant: never nullptr; null state is represented as diff --git a/aten/src/ATen/core/jit_type_base.h b/aten/src/ATen/core/jit_type_base.h index 18077ad9f6b3a..a6dcf45f1d4f1 100644 --- a/aten/src/ATen/core/jit_type_base.h +++ b/aten/src/ATen/core/jit_type_base.h @@ -677,7 +677,11 @@ inline TypePtr Type::withContained(std::vector contained_types) { } +<<<<<<< HEAD inline bool operator==(const Type& lhs, const Type& rhs) { +======= +TORCH_API inline bool operator==(const Type& lhs, const Type& rhs) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (C10_UNLIKELY(!rhs.symmetric())) { return rhs.equals(lhs); } diff --git a/aten/src/ATen/cpu/vec/intrinsics.h b/aten/src/ATen/cpu/vec/intrinsics.h index 70223700f6364..65752fe8628cf 100644 --- a/aten/src/ATen/cpu/vec/intrinsics.h +++ b/aten/src/ATen/cpu/vec/intrinsics.h @@ -1 +1,59 @@ +<<<<<<< HEAD #include +======= +#pragma once +#if defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__)) +/* GCC or clang-compatible compiler, targeting x86/x86-64 */ +#include +#elif defined(__clang__) && (defined(__ARM_NEON__) || defined(__aarch64__)) +/* Clang-compatible compiler, targeting arm neon */ +#include +#if defined(__ARM_FEATURE_SVE) +/* CLANG-compatible compiler, targeting ARM with SVE */ +#include +#endif +#elif defined(_MSC_VER) +/* Microsoft C/C++-compatible compiler */ +#include +#if _MSC_VER <= 1900 +#define _mm256_extract_epi64(X, Y) \ + (_mm_extract_epi64(_mm256_extractf128_si256(X, Y >> 1), Y % 2)) +#define _mm256_extract_epi32(X, Y) \ + (_mm_extract_epi32(_mm256_extractf128_si256(X, Y >> 2), Y % 4)) +#define _mm256_extract_epi16(X, Y) \ + (_mm_extract_epi16(_mm256_extractf128_si256(X, Y >> 3), Y % 8)) +#define _mm256_extract_epi8(X, Y) \ + (_mm_extract_epi8(_mm256_extractf128_si256(X, Y >> 4), Y % 16)) +#endif +#elif defined(__GNUC__) && (defined(__ARM_NEON__) || defined(__aarch64__)) +/* GCC-compatible compiler, targeting ARM with NEON */ +#include +#if defined(__ARM_FEATURE_SVE) +/* GCC-compatible compiler, targeting ARM with SVE */ +#include +#endif +#if defined(MISSING_ARM_VLD1) +#include +#elif defined(MISSING_ARM_VST1) +#include +#endif +#elif defined(__GNUC__) && defined(__IWMMXT__) +/* GCC-compatible compiler, targeting ARM with WMMX */ +#include +#elif defined(__s390x__) +// targets Z/architecture +// we will include vecintrin later +#elif (defined(__GNUC__) || defined(__xlC__)) && \ + (defined(__VEC__) || defined(__ALTIVEC__)) +/* XLC or GCC-compatible compiler, targeting PowerPC with VMX/VSX */ +#include +/* We need to undef those tokens defined by to avoid conflicts + with the C++ types. => Can still use __bool/__vector */ +#undef bool +#undef vector +#undef pixel +#elif defined(__GNUC__) && defined(__SPE__) +/* GCC-compatible compiler, targeting PowerPC with SPE */ +#include +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/cpu/vec/sve/vec_bfloat16.h b/aten/src/ATen/cpu/vec/sve/vec_bfloat16.h index d269e10739599..7a1c33b765fc7 100644 --- a/aten/src/ATen/cpu/vec/sve/vec_bfloat16.h +++ b/aten/src/ATen/cpu/vec/sve/vec_bfloat16.h @@ -5,7 +5,10 @@ #include #include #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include namespace at { namespace vec { @@ -37,7 +40,11 @@ class Vectorized { return VECTOR_WIDTH / sizeof(BFloat16); } +<<<<<<< HEAD Vectorized(); +======= + Vectorized() {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized(svbfloat16_t v) : values(v) {} Vectorized(int val); Vectorized(BFloat16 val); @@ -164,9 +171,12 @@ class Vectorized { Vectorized exp_u20() const { return exp(); } +<<<<<<< HEAD Vectorized fexp_u20() const { return exp(); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized fmod(const Vectorized& q) const; Vectorized hypot(const Vectorized& b) const; Vectorized i0() const; @@ -224,12 +234,17 @@ class Vectorized { Vectorized le(const Vectorized& other) const; }; +<<<<<<< HEAD #if defined(__GNUC__) && __GNUC__ == 14 // Workaround for gcc-14.2.0 ICE during RTL pass: vregs when compiling for SVE __attribute__((optimize("no-tree-vectorize"))) #endif inline std::tuple, Vectorized> convert_bfloat16_float(const Vectorized& a) { +======= +inline std::tuple, Vectorized> convert_bfloat16_float( + const Vectorized& a) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static_assert( Vectorized::size() == 2 * Vectorized::size()); auto zero = svreinterpret_bf16_f32(svdup_n_f32(0.0f)); @@ -307,11 +322,14 @@ Vectorized inline operator/( return binary_operator_via_float(std::divides>(), a, b); } +<<<<<<< HEAD inline Vectorized::Vectorized() { const short zero = 0; values = svdup_n_bf16(c10::bit_cast(zero)); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inline Vectorized::Vectorized(int val) { auto vals_f = svdup_n_f32(val); values = convert_float_bfloat16(vals_f, vals_f); diff --git a/aten/src/ATen/cpu/vec/sve/vec_double.h b/aten/src/ATen/cpu/vec/sve/vec_double.h index 474652be17a1a..6867d1687451a 100644 --- a/aten/src/ATen/cpu/vec/sve/vec_double.h +++ b/aten/src/ATen/cpu/vec/sve/vec_double.h @@ -38,9 +38,13 @@ class Vectorized { static constexpr size_type size() { return VECTOR_WIDTH / sizeof(double); } +<<<<<<< HEAD Vectorized() { values = svdup_n_f64(0); } +======= + Vectorized() {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized(svfloat64_t v) : values(v) {} Vectorized(double val) { values = svdup_n_f64(val); @@ -251,9 +255,12 @@ class Vectorized { Vectorized exp_u20() const { return exp(); } +<<<<<<< HEAD Vectorized fexp_u20() const { return exp(); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized fmod(const Vectorized& q) const {USE_SLEEF( { return Vectorized(Sleef_fmoddx_sve(values, q)); }, { @@ -587,6 +594,7 @@ Vectorized inline fmadd( return svmad_f64_x(ptrue, a, b, c); } +<<<<<<< HEAD template <> Vectorized inline fnmadd( const Vectorized& a, @@ -611,6 +619,8 @@ Vectorized inline fnmsub( return svnmad_f64_x(ptrue, a, b, c); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif // defined(CPU_CAPABILITY_SVE) } // namespace CPU_CAPABILITY diff --git a/aten/src/ATen/cpu/vec/sve/vec_float.h b/aten/src/ATen/cpu/vec/sve/vec_float.h index 89bce507c4849..f38a6a8140b35 100644 --- a/aten/src/ATen/cpu/vec/sve/vec_float.h +++ b/aten/src/ATen/cpu/vec/sve/vec_float.h @@ -38,9 +38,13 @@ class Vectorized { static constexpr size_type size() { return VECTOR_WIDTH / sizeof(float); } +<<<<<<< HEAD Vectorized() { values = svdup_n_f32(0); } +======= + Vectorized() {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized(svfloat32_t v) : values(v) {} Vectorized(float val) { values = svdup_n_f32(val); @@ -316,9 +320,12 @@ class Vectorized { Vectorized exp_u20() const { return exp(); } +<<<<<<< HEAD Vectorized fexp_u20() const { return exp(); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized fmod(const Vectorized& q) const {USE_SLEEF( { return Vectorized(Sleef_fmodfx_sve(values, q)); }, { @@ -758,6 +765,7 @@ Vectorized inline fmadd( return svmad_f32_x(ptrue, a, b, c); } +<<<<<<< HEAD template <> Vectorized inline fnmadd( const Vectorized& a, @@ -782,6 +790,8 @@ Vectorized inline fnmsub( return svnmad_f32_x(ptrue, a, b, c); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif // defined(CPU_CAPABILITY_SVE) } // namespace CPU_CAPABILITY diff --git a/aten/src/ATen/cpu/vec/sve/vec_int.h b/aten/src/ATen/cpu/vec/sve/vec_int.h index f0bc42caa9502..2a5f8a2468851 100644 --- a/aten/src/ATen/cpu/vec/sve/vec_int.h +++ b/aten/src/ATen/cpu/vec/sve/vec_int.h @@ -32,9 +32,13 @@ inline namespace CPU_CAPABILITY { static constexpr size_type size() { \ return vl; \ } \ +<<<<<<< HEAD Vectorized() { \ values = svdup_n_s##bit(0); \ } \ +======= + Vectorized() {} \ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized(svint##bit##_t v) : values(v) {} \ Vectorized(int##bit##_t val) { \ values = svdup_n_s##bit(val); \ diff --git a/aten/src/ATen/cpu/vec/vec128/vec128_bfloat16_neon.h b/aten/src/ATen/cpu/vec/vec128/vec128_bfloat16_neon.h index 02f64af3bb088..2deea025774bb 100644 --- a/aten/src/ATen/cpu/vec/vec128/vec128_bfloat16_neon.h +++ b/aten/src/ATen/cpu/vec/vec128/vec128_bfloat16_neon.h @@ -553,6 +553,7 @@ Vectorized inline fmadd( } template <> +<<<<<<< HEAD Vectorized inline fnmadd( const Vectorized& a, const Vectorized& b, @@ -562,6 +563,8 @@ Vectorized inline fnmadd( } template <> +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized inline fmsub( const Vectorized& a, const Vectorized& b, @@ -570,6 +573,7 @@ Vectorized inline fmsub( return a * b - c; } +<<<<<<< HEAD template <> Vectorized inline fnmsub( const Vectorized& a, @@ -579,6 +583,8 @@ Vectorized inline fnmsub( return -a * b - c; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif // !defined(C10_MOBILE) && defined(__aarch64__) } // namespace CPU_CAPABILITY diff --git a/aten/src/ATen/cpu/vec/vec128/vec128_float_neon.h b/aten/src/ATen/cpu/vec/vec128/vec128_float_neon.h index c6c34222c5cf6..a437d5cc42135 100644 --- a/aten/src/ATen/cpu/vec/vec128/vec128_float_neon.h +++ b/aten/src/ATen/cpu/vec/vec128/vec128_float_neon.h @@ -83,9 +83,13 @@ class Vectorized { static constexpr size_type size() { return 4; } +<<<<<<< HEAD Vectorized() { values = vmovq_n_f32(0); } +======= + Vectorized() {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized(float32x4_t v) : values(v) {} Vectorized(float val) : values{vdupq_n_f32(val)} {} Vectorized(float val0, float val1, float val2, float val3) @@ -204,6 +208,7 @@ class Vectorized { store(tmp); return tmp[idx]; } +<<<<<<< HEAD int zero_mask() const { uint32x4_t is_zero_vec = vceqzq_f32(values); const int32x4_t shift = vcombine_s32( @@ -212,6 +217,20 @@ class Vectorized { uint32x4_t bits_vec = vshlq_u32(vandq_u32(is_zero_vec, vdupq_n_u32(1)), shift); return vaddvq_u32(bits_vec); +======= + // For boolean version where we want to if any 1/all zero + // etc. can be done faster in a different way. + int zero_mask() const { + __at_align__ float tmp[size()]; + store(tmp); + int mask = 0; + for (int i = 0; i < size(); ++i) { + if (tmp[i] == 0.f) { + mask |= (1 << i); + } + } + return mask; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } Vectorized isnan() const { return vreinterpretq_f32_u32(vmvnq_u32(vceqq_f32(values, values))); @@ -310,9 +329,12 @@ class Vectorized { Vectorized exp_u20() const { return exp(); } +<<<<<<< HEAD Vectorized fexp_u20() const { return exp(); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME( fmod, Sleef_fmodf4) @@ -585,6 +607,7 @@ Vectorized inline fmadd( } template <> +<<<<<<< HEAD Vectorized inline fnmadd( const Vectorized& a, const Vectorized& b, @@ -593,6 +616,8 @@ Vectorized inline fnmadd( } template <> +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized inline fmsub( const Vectorized& a, const Vectorized& b, @@ -600,6 +625,7 @@ Vectorized inline fmsub( return Vectorized(vnegq_f32(vfmsq_f32(c, a, b))); } +<<<<<<< HEAD template <> Vectorized inline fnmsub( const Vectorized& a, @@ -608,6 +634,8 @@ Vectorized inline fnmsub( return Vectorized(vnegq_f32(vfmaq_f32(c, a, b))); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inline Vectorized Vectorized::erf() const { // constants const Vectorized neg_zero_vec(-0.f); diff --git a/aten/src/ATen/cpu/vec/vec128/vec128_half_neon.h b/aten/src/ATen/cpu/vec/vec128/vec128_half_neon.h index ab4a5a89cba77..0a409dff051f1 100644 --- a/aten/src/ATen/cpu/vec/vec128/vec128_half_neon.h +++ b/aten/src/ATen/cpu/vec/vec128/vec128_half_neon.h @@ -220,6 +220,7 @@ class Vectorized : public Vectorized16< std::memcpy(ptr, tmp_values, count * sizeof(float16_t)); } } +<<<<<<< HEAD int zero_mask() const { #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC uint16x8_t is_zero_vec = vceqzq_f16(values); @@ -246,6 +247,10 @@ class Vectorized : public Vectorized16< return mask; #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC } +======= + // For boolean version where we want to if any 1/all zero + // etc. can be done faster in a different way. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized isnan() const { #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC return vreinterpretq_f16_u16(vmvnq_u16(vceqq_f16(values, values))); @@ -622,6 +627,7 @@ Vectorized inline fmadd( } template <> +<<<<<<< HEAD Vectorized inline fnmadd( const Vectorized& a, const Vectorized& b, @@ -634,6 +640,8 @@ Vectorized inline fnmadd( } template <> +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized inline fmsub( const Vectorized& a, const Vectorized& b, @@ -644,6 +652,7 @@ Vectorized inline fmsub( return a * b - c; #endif } +<<<<<<< HEAD template <> Vectorized inline fnmsub( @@ -656,6 +665,8 @@ Vectorized inline fnmsub( return -a * b - c; #endif } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif // !defined(C10_MOBILE) && defined(__aarch64__) } // namespace CPU_CAPABILITY diff --git a/aten/src/ATen/cpu/vec/vec128/vec128_reduced_precision_common_neon.h b/aten/src/ATen/cpu/vec/vec128/vec128_reduced_precision_common_neon.h index 5fb3679f37239..c7855f0a4fb4d 100644 --- a/aten/src/ATen/cpu/vec/vec128/vec128_reduced_precision_common_neon.h +++ b/aten/src/ATen/cpu/vec/vec128/vec128_reduced_precision_common_neon.h @@ -206,10 +206,13 @@ struct Vectorized16 { return static_cast(this)->map_with_vec_float_method( &Vectorized::exp_u20); } +<<<<<<< HEAD Derived fexp_u20() const { return static_cast(this)->map_with_vec_float_method( &Vectorized::exp_u20); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Derived fmod(const Derived& q) const { // This function is questionable with a conversion, so we use map2 return map2(q, std::fmod); diff --git a/aten/src/ATen/cpu/vec/vec256/missing_vld1_neon.h b/aten/src/ATen/cpu/vec/vec256/missing_vld1_neon.h index aa40000b6ccdb..21ae6d5aef1a8 100644 --- a/aten/src/ATen/cpu/vec/vec256/missing_vld1_neon.h +++ b/aten/src/ATen/cpu/vec/vec256/missing_vld1_neon.h @@ -1 +1,400 @@ +<<<<<<< HEAD #include +======= +/* Workaround for missing vld1_*_x2 and vst1_*_x2 intrinsics in gcc-7. */ + +__extension__ extern __inline uint8x8x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_u8_x2(const uint8_t* __a) { + uint8x8x2_t ret; + asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline int8x8x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_s8_x2(const int8_t* __a) { + int8x8x2_t ret; + asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline uint16x4x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_u16_x2(const uint16_t* __a) { + uint16x4x2_t ret; + asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline int16x4x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_s16_x2(const int16_t* __a) { + int16x4x2_t ret; + asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline uint32x2x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_u32_x2(const uint32_t* __a) { + uint32x2x2_t ret; + asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline int32x2x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_s32_x2(const int32_t* __a) { + int32x2x2_t ret; + asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline uint64x1x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_u64_x2(const uint64_t* __a) { + uint64x1x2_t ret; + asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline int64x1x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_s64_x2(const int64_t* __a) { + int64x1x2_t ret; + __builtin_aarch64_simd_oi __o; + asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline float16x4x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_f16_x2(const float16_t* __a) { + float16x4x2_t ret; + asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline float32x2x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_f32_x2(const float32_t* __a) { + float32x2x2_t ret; + asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline float64x1x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_f64_x2(const float64_t* __a) { + float64x1x2_t ret; + asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline poly8x8x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_p8_x2(const poly8_t* __a) { + poly8x8x2_t ret; + asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline poly16x4x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_p16_x2(const poly16_t* __a) { + poly16x4x2_t ret; + asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline poly64x1x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1_p64_x2(const poly64_t* __a) { + poly64x1x2_t ret; + asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline uint8x16x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_u8_x2(const uint8_t* __a) { + uint8x16x2_t ret; + asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline int8x16x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_s8_x2(const int8_t* __a) { + int8x16x2_t ret; + asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline uint16x8x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_u16_x2(const uint16_t* __a) { + uint16x8x2_t ret; + asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline int16x8x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_s16_x2(const int16_t* __a) { + int16x8x2_t ret; + asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline uint32x4x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_u32_x2(const uint32_t* __a) { + uint32x4x2_t ret; + asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline int32x4x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_s32_x2(const int32_t* __a) { + int32x4x2_t ret; + asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline uint64x2x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_u64_x2(const uint64_t* __a) { + uint64x2x2_t ret; + asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline int64x2x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_s64_x2(const int64_t* __a) { + int64x2x2_t ret; + asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline float16x8x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_f16_x2(const float16_t* __a) { + float16x8x2_t ret; + asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline float32x4x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_f32_x2(const float32_t* __a) { + float32x4x2_t ret; + asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline float64x2x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_f64_x2(const float64_t* __a) { + float64x2x2_t ret; + asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline poly8x16x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_p8_x2(const poly8_t* __a) { + poly8x16x2_t ret; + asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline poly16x8x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_p16_x2(const poly16_t* __a) { + poly16x8x2_t ret; + asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +__extension__ extern __inline poly64x2x2_t + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vld1q_p64_x2(const poly64_t* __a) { + poly64x2x2_t ret; + asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a)); + return ret; +} + +/* vst1x2 */ + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_s64_x2(int64_t* __a, int64x1x2_t val) { + asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_u64_x2(uint64_t* __a, uint64x1x2_t val) { + asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_f64_x2(float64_t* __a, float64x1x2_t val) { + asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_s8_x2(int8_t* __a, int8x8x2_t val) { + asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_p8_x2(poly8_t* __a, poly8x8x2_t val) { + asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_s16_x2(int16_t* __a, int16x4x2_t val) { + asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_p16_x2(poly16_t* __a, poly16x4x2_t val) { + asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_s32_x2(int32_t* __a, int32x2x2_t val) { + asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_u8_x2(uint8_t* __a, uint8x8x2_t val) { + asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_u16_x2(uint16_t* __a, uint16x4x2_t val) { + asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_u32_x2(uint32_t* __a, uint32x2x2_t val) { + asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_f16_x2(float16_t* __a, float16x4x2_t val) { + asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_f32_x2(float32_t* __a, float32x2x2_t val) { + asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1_p64_x2(poly64_t* __a, poly64x1x2_t val) { + asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_s8_x2(int8_t* __a, int8x16x2_t val) { + asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_p8_x2(poly8_t* __a, poly8x16x2_t val) { + asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_s16_x2(int16_t* __a, int16x8x2_t val) { + asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_p16_x2(poly16_t* __a, poly16x8x2_t val) { + asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_s32_x2(int32_t* __a, int32x4x2_t val) { + asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_s64_x2(int64_t* __a, int64x2x2_t val) { + asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_u8_x2(uint8_t* __a, uint8x16x2_t val) { + asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_u16_x2(uint16_t* __a, uint16x8x2_t val) { + asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_u32_x2(uint32_t* __a, uint32x4x2_t val) { + asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_u64_x2(uint64_t* __a, uint64x2x2_t val) { + asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_f16_x2(float16_t* __a, float16x8x2_t val) { + asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_f32_x2(float32_t* __a, float32x4x2_t val) { + asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_f64_x2(float64_t* __a, float64x2x2_t val) { + asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val)); +} + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_p64_x2(poly64_t* __a, poly64x2x2_t val) { + asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val)); +} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/cpu/vec/vec256/missing_vst1_neon.h b/aten/src/ATen/cpu/vec/vec256/missing_vst1_neon.h index b3d721531d246..c2c0c0d91e29c 100644 --- a/aten/src/ATen/cpu/vec/vec256/missing_vst1_neon.h +++ b/aten/src/ATen/cpu/vec/vec256/missing_vst1_neon.h @@ -1 +1,11 @@ +<<<<<<< HEAD #include +======= +/* Workaround for missing vst1q_f32_x2 in gcc-8. */ + +__extension__ extern __inline void + __attribute__((__always_inline__, __gnu_inline__, __artificial__)) + vst1q_f32_x2(float32_t* __a, float32x4x2_t val) { + asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val)); +} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_16bit_float.h b/aten/src/ATen/cpu/vec/vec256/vec256_16bit_float.h index 425fb6aa79e13..19eed91ff9199 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_16bit_float.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_16bit_float.h @@ -488,9 +488,12 @@ class Vectorized16 { Vectorized expm1() const { return map(Sleef_expm1f8_u10); } +<<<<<<< HEAD Vectorized fexp_u20() const { return exp(); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized exp_u20() const { return exp(); } diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h b/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h index ba57ca034e9a6..ef695c37fe462 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h @@ -34,9 +34,13 @@ class Vectorized> { static constexpr size_type size() { return 2; } +<<<<<<< HEAD Vectorized() { values = _mm256_setzero_pd(); } +======= + Vectorized() {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized(__m256d v) : values(v) {} Vectorized(c10::complex val) { double real_value = val.real(); diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h b/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h index 5d8c69a34b9d2..4b6f518e96c7c 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h @@ -33,9 +33,13 @@ class Vectorized> { static constexpr size_type size() { return 4; } +<<<<<<< HEAD Vectorized() { values = _mm256_setzero_ps(); } +======= + Vectorized() {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized(__m256 v) : values(v) {} Vectorized(c10::complex val) { float real_value = val.real(); diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_double.h b/aten/src/ATen/cpu/vec/vec256/vec256_double.h index d5abafedec2e6..75df7b555381a 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_double.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_double.h @@ -31,9 +31,13 @@ class Vectorized { static constexpr size_type size() { return 4; } +<<<<<<< HEAD Vectorized() { values = _mm256_setzero_pd(); } +======= + Vectorized() {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized(__m256d v) : values(v) {} Vectorized(double val) { values = _mm256_set1_pd(val); @@ -200,9 +204,12 @@ class Vectorized { Vectorized exp_u20() const { return exp(); } +<<<<<<< HEAD Vectorized fexp_u20() const { return exp(); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized fmod(const Vectorized& q) const { return Vectorized(Sleef_fmodd4(values, q)); } @@ -496,6 +503,7 @@ Vectorized inline fmadd( } template <> +<<<<<<< HEAD Vectorized inline fnmadd( const Vectorized& a, const Vectorized& b, @@ -504,12 +512,15 @@ Vectorized inline fnmadd( } template <> +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized inline fmsub( const Vectorized& a, const Vectorized& b, const Vectorized& c) { return _mm256_fmsub_pd(a, b, c); } +<<<<<<< HEAD template <> Vectorized inline fnmsub( @@ -518,6 +529,8 @@ Vectorized inline fnmsub( const Vectorized& c) { return _mm256_fnmsub_pd(a, b, c); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif #endif diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_float.h b/aten/src/ATen/cpu/vec/vec256/vec256_float.h index a42a51e567a63..c8584f3ad84b8 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_float.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_float.h @@ -1,4 +1,8 @@ #pragma once +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // DO NOT DEFINE STATIC DATA IN THIS HEADER! // See Note [Do not compile initializers with AVX] @@ -30,9 +34,13 @@ class Vectorized { static constexpr size_type size() { return 8; } +<<<<<<< HEAD Vectorized() { values = _mm256_setzero_ps(); } +======= + Vectorized() {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized(__m256 v) : values(v) {} Vectorized(float val) { values = _mm256_set1_ps(val); @@ -257,6 +265,7 @@ class Vectorized { Vectorized expm1() const { return Vectorized(Sleef_expm1f8_u10(values)); } +<<<<<<< HEAD Vectorized fexp_u20() const { const __m256 vec_c0 = _mm256_set1_ps(0.00010703434948458272f); const __m256 vec_c1 = _mm256_set1_ps(0.30354260500649682f); @@ -314,6 +323,8 @@ class Vectorized { return result; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized exp_u20() const { // A faster version of exp with ULP=20 const __m256 vec_factorial_1 = @@ -697,6 +708,7 @@ Vectorized inline fmadd( } template <> +<<<<<<< HEAD Vectorized inline fnmadd( const Vectorized& a, const Vectorized& b, @@ -705,6 +717,8 @@ Vectorized inline fnmadd( } template <> +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized inline fmsub( const Vectorized& a, const Vectorized& b, @@ -712,6 +726,7 @@ Vectorized inline fmsub( return _mm256_fmsub_ps(a, b, c); } +<<<<<<< HEAD template <> Vectorized inline fnmsub( const Vectorized& a, @@ -720,6 +735,8 @@ Vectorized inline fnmsub( return _mm256_fnmsub_ps(a, b, c); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // TODO: rewrite with ATEN vectorized (need to add unpack and shuffle) // Used by Inductor CPP codegen for micro gemm inline void transpose_block(at::vec::VectorizedN& input) { diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_int.h b/aten/src/ATen/cpu/vec/vec256/vec256_int.h index 515cbff730d9b..fb078e92f6759 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_int.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_int.h @@ -23,9 +23,13 @@ struct Vectorizedi { } public: +<<<<<<< HEAD Vectorizedi() { values = _mm256_setzero_si256(); } +======= + Vectorizedi() {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorizedi(__m256i v) : values(v) {} operator __m256i() const { return values; @@ -55,9 +59,13 @@ class Vectorized : public Vectorizedi { return 4; } using Vectorizedi::Vectorizedi; +<<<<<<< HEAD Vectorized() { values = _mm256_setzero_si256(); } +======= + Vectorized() {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized(int64_t v) { values = _mm256_set1_epi64x(v); } diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_qint.h b/aten/src/ATen/cpu/vec/vec256/vec256_qint.h index dafe444163eb1..8612e804c18e7 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_qint.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_qint.h @@ -54,9 +54,13 @@ struct Vectorizedqi { #endif public: +<<<<<<< HEAD Vectorizedqi() { vals = _mm256_setzero_si256(); } +======= + Vectorizedqi() {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorizedqi(__m256i v) : vals(v) {} operator __m256i() const { return vals; @@ -123,29 +127,46 @@ typename std::enable_if_t< } template +<<<<<<< HEAD at::vec::Vectorized inline convert_float_to_int8( at::vec::Vectorized src); template <> at::vec::Vectorized inline convert_float_to_int8( at::vec::Vectorized src) { +======= +typename std::enable_if_t< + std::is_same_v || std::is_same_v, + at::vec::Vectorized< + T>> inline convert_float_to_int8(at::vec::Vectorized src) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Convert from float32 to int32 with truncation __m256i x_values_int32 = _mm256_cvttps_epi32(src); // Convert from int32 to int16 using signed saturation __m256i xy_packed_v = _mm256_packs_epi32(x_values_int32, x_values_int32); +<<<<<<< HEAD constexpr auto min_val = std::numeric_limits::min(); constexpr auto max_val = std::numeric_limits::max(); // Convert from int16 to int8 using unsigned saturation __m256i xyzw_clamped_v = pack_saturate_and_clamp( xy_packed_v, xy_packed_v, min_val, max_val); +======= + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + + // Convert from int16 to uint8/int8 using unsigned saturation + __m256i xyzw_clamped_v = + pack_saturate_and_clamp(xy_packed_v, xy_packed_v, min_val, max_val); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __m256i permute_mask_v = _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00); return _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v); } +<<<<<<< HEAD template <> at::vec::Vectorized inline convert_float_to_int8( at::vec::Vectorized src) { @@ -169,6 +190,8 @@ at::vec::Vectorized inline convert_float_to_int8( return _mm256_castsi128_si256(result); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template __FORCE_INLINE void QuantizeAvx2( const float* src, diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h index 0f24ed3f69355..689c44f35688d 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h @@ -273,9 +273,12 @@ class Vectorized { Vectorized C10_ALWAYS_INLINE exp_u20() const { return exp(); } +<<<<<<< HEAD Vectorized C10_ALWAYS_INLINE fexp_u20() const { return exp(); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized lgamma() const __ubsan_ignore_undefined__ { return {Sleef_lgammad2_u10(_vec0), Sleef_lgammad2_u10(_vec1)}; diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h index c02f85d08e261..6fb7328cd7142 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h @@ -352,9 +352,12 @@ class Vectorized { Vectorized C10_ALWAYS_INLINE exp_u20() const { return exp(); } +<<<<<<< HEAD Vectorized C10_ALWAYS_INLINE fexp_u20() const { return exp(); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized C10_ALWAYS_INLINE log() const { return {Sleef_logf4_u10(_vec0), Sleef_logf4_u10(_vec1)}; diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h index 2c2a199da80dc..7fe375fe7dca2 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h @@ -349,6 +349,29 @@ class Vectorized { }; template <> +<<<<<<< HEAD +======= +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + vuint16 shift_vec0 = reinterpret_cast(b.vec0()); + vuint16 shift_vec1 = reinterpret_cast(b.vec1()); + return Vectorized{ + vec_sl(a.vec0(), shift_vec0), vec_sl(a.vec1(), shift_vec1)}; +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + vuint16 shift_vec0 = reinterpret_cast(b.vec0()); + vuint16 shift_vec1 = reinterpret_cast(b.vec1()); + return Vectorized{ + vec_sr(a.vec0(), shift_vec0), vec_sr(a.vec1(), shift_vec1)}; +} + +template <> +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized inline maximum( const Vectorized& a, const Vectorized& b) { @@ -362,8 +385,11 @@ Vectorized inline minimum( return a.minimum(b); } +<<<<<<< HEAD DEFINE_SHIFT_FUNCS(int16_t) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template <> Vectorized C10_ALWAYS_INLINE operator+(const Vectorized& a, const Vectorized& b) { diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h index ea22e8dde2df2..743a144aaa6b5 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h @@ -279,6 +279,29 @@ class Vectorized { }; template <> +<<<<<<< HEAD +======= +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + vuint32 shift_vec0 = reinterpret_cast(b.vec0()); + vuint32 shift_vec1 = reinterpret_cast(b.vec1()); + return Vectorized{ + vec_sl(a.vec0(), shift_vec0), vec_sl(a.vec1(), shift_vec1)}; +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + vuint32 shift_vec0 = reinterpret_cast(b.vec0()); + vuint32 shift_vec1 = reinterpret_cast(b.vec1()); + return Vectorized{ + vec_sr(a.vec0(), shift_vec0), vec_sr(a.vec1(), shift_vec1)}; +} + +template <> +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized inline maximum( const Vectorized& a, const Vectorized& b) { @@ -292,8 +315,11 @@ Vectorized inline minimum( return a.minimum(b); } +<<<<<<< HEAD DEFINE_SHIFT_FUNCS(int32_t) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template <> Vectorized C10_ALWAYS_INLINE operator+(const Vectorized& a, const Vectorized& b) { diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h index 8d0bd52c90103..71459aa6d9984 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h @@ -232,6 +232,29 @@ class Vectorized { }; template <> +<<<<<<< HEAD +======= +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { + vuint64 shift_vec0 = reinterpret_cast(b.vec0()); + vuint64 shift_vec1 = reinterpret_cast(b.vec1()); + return Vectorized{ + vec_sl(a.vec0(), shift_vec0), vec_sl(a.vec1(), shift_vec1)}; +} + +template <> +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { + vuint64 shift_vec0 = reinterpret_cast(b.vec0()); + vuint64 shift_vec1 = reinterpret_cast(b.vec1()); + return Vectorized{ + vec_sr(a.vec0(), shift_vec0), vec_sr(a.vec1(), shift_vec1)}; +} + +template <> +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized inline maximum( const Vectorized& a, const Vectorized& b) { @@ -245,8 +268,11 @@ Vectorized inline minimum( return a.minimum(b); } +<<<<<<< HEAD DEFINE_SHIFT_FUNCS(int64_t) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template <> Vectorized C10_ALWAYS_INLINE operator+(const Vectorized& a, const Vectorized& b) { diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vsx_helpers.h b/aten/src/ATen/cpu/vec/vec256/vsx/vsx_helpers.h index 7ca603c0b91df..95b7905203127 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vsx_helpers.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vsx_helpers.h @@ -1,6 +1,9 @@ #pragma once #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include @@ -40,6 +43,7 @@ using vfloat32 = __attribute__((altivec(vector__))) float; using vfloat64 = __attribute__((altivec(vector__))) double; #endif +<<<<<<< HEAD inline auto make_vuint(vint8 v) { return reinterpret_cast(v); } @@ -53,6 +57,8 @@ inline auto make_vuint(vint64 v) { return reinterpret_cast(v); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #if !defined(vec_float) C10_ALWAYS_INLINE vfloat32 vec_float(const vint32& vec_in) { vfloat32 vec_out; @@ -535,6 +541,7 @@ const vfloat64 vd_imag_half = vfloat64{0.0, 0.5}; const vfloat64 vd_sqrt2_2 = vfloat64{0.70710678118654757, 0.70710678118654757}; const vfloat64 vd_pi_2 = vfloat64{M_PI / 2.0, 0.0}; +<<<<<<< HEAD template Vectorized VsxShiftRightArith( const Vectorized& a, @@ -571,6 +578,8 @@ Vectorized VsxShiftLeftArith( return VsxShiftLeftArith(a, b); \ } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace CPU_CAPABILITY } // namespace vec } // namespace at diff --git a/aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h b/aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h index efb97b3c614db..cd430913e695c 100644 --- a/aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h +++ b/aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h @@ -1023,9 +1023,12 @@ struct Vectorized()>> { Vectorized exp_u20() const { return exp(); } +<<<<<<< HEAD Vectorized fexp_u20() const { return exp(); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized log() const { return mapSleef(Sleef_logf4_u10, Sleef_logd2_u10); diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h b/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h index 844b3b1fcc1e8..cd1b61eeb2e85 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h @@ -192,9 +192,13 @@ class Vectorized16 { static constexpr size_type size() { return 32; } +<<<<<<< HEAD Vectorized16() { values = _mm512_setzero_si512(); } +======= + Vectorized16() {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized16(__m512i v) : values(v) {} Vectorized16(T val) { value_type uw = val.x; @@ -537,9 +541,12 @@ class Vectorized16 { Vectorized expm1() const { return map(Sleef_expm1f16_u10); } +<<<<<<< HEAD Vectorized fexp_u20() const { return exp(); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized exp_u20() const { return exp(); } diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h b/aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h index 3776001fc8720..3d11a98ee0815 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h @@ -34,9 +34,13 @@ class Vectorized> { static constexpr size_type size() { return 4; } +<<<<<<< HEAD Vectorized() { values = _mm512_setzero_pd(); } +======= + Vectorized() {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized(__m512d v) : values(v) {} Vectorized(c10::complex val) { double real_value = val.real(); diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h b/aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h index d434b2a1e2070..bb91ac64c4549 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h @@ -34,9 +34,13 @@ class Vectorized> { static constexpr size_type size() { return 8; } +<<<<<<< HEAD Vectorized() { values = _mm512_setzero_ps(); } +======= + Vectorized() {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized(__m512 v) : values(v) {} Vectorized(c10::complex val) { float real_value = val.real(); diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_double.h b/aten/src/ATen/cpu/vec/vec512/vec512_double.h index 438fd31e91618..4fcab45731748 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_double.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_double.h @@ -34,9 +34,13 @@ class Vectorized { static constexpr size_type size() { return 8; } +<<<<<<< HEAD Vectorized() { values = _mm512_setzero_pd(); } +======= + Vectorized() {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized(__m512d v) : values(v) {} Vectorized(double val) { values = _mm512_set1_pd(val); @@ -223,9 +227,12 @@ class Vectorized { Vectorized exp_u20() const { return exp(); } +<<<<<<< HEAD Vectorized fexp_u20() const { return exp(); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized fmod(const Vectorized& q) const { return Vectorized(Sleef_fmodd8(values, q)); } @@ -537,6 +544,7 @@ Vectorized inline fmadd( } template <> +<<<<<<< HEAD Vectorized inline fnmadd( const Vectorized& a, const Vectorized& b, @@ -545,6 +553,8 @@ Vectorized inline fnmadd( } template <> +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized inline fmsub( const Vectorized& a, const Vectorized& b, @@ -552,6 +562,7 @@ Vectorized inline fmsub( return _mm512_fmsub_pd(a, b, c); } +<<<<<<< HEAD template <> Vectorized inline fnmsub( const Vectorized& a, @@ -560,6 +571,8 @@ Vectorized inline fnmsub( return _mm512_fnmsub_pd(a, b, c); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif } // namespace CPU_CAPABILITY diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_float.h b/aten/src/ATen/cpu/vec/vec512/vec512_float.h index 7a9e69b76c851..de0c9ea3fca26 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_float.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_float.h @@ -32,9 +32,13 @@ class Vectorized { static constexpr size_type size() { return 16; } +<<<<<<< HEAD Vectorized() { values = _mm512_setzero_ps(); } +======= + Vectorized() {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized(__m512 v) : values(v) {} Vectorized(float val) { values = _mm512_set1_ps(val); @@ -312,6 +316,7 @@ class Vectorized { Vectorized expm1() const { return Vectorized(Sleef_expm1f16_u10(values)); } +<<<<<<< HEAD Vectorized fexp_u20() const { const __m512 vec_c0 = _mm512_set1_ps(0.00010703434948458272f); const __m512 vec_c1 = _mm512_set1_ps(0.30354260500649682f); @@ -366,6 +371,8 @@ class Vectorized { // final interpretation to float return _mm512_castsi512_ps(casted_integer); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized exp_u20() const { // A faster version of exp with ULP=20 const __m512 vec_factorial_1 = @@ -750,6 +757,7 @@ Vectorized inline fmadd( } template <> +<<<<<<< HEAD Vectorized inline fnmadd( const Vectorized& a, const Vectorized& b, @@ -758,6 +766,8 @@ Vectorized inline fnmadd( } template <> +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized inline fmsub( const Vectorized& a, const Vectorized& b, @@ -765,6 +775,7 @@ Vectorized inline fmsub( return _mm512_fmsub_ps(a, b, c); } +<<<<<<< HEAD template <> Vectorized inline fnmsub( const Vectorized& a, @@ -773,6 +784,8 @@ Vectorized inline fnmsub( return _mm512_fnmsub_ps(a, b, c); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // TODO: rewrite with ATEN vectorized (need to add unpack and shuffle) // Used by Inductor CPP codegen for micro gemm // Code referred to FBGEMM: diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_int.h b/aten/src/ATen/cpu/vec/vec512/vec512_int.h index 5f80a7c2bcff0..01d20b33a1c03 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_int.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_int.h @@ -53,9 +53,13 @@ class Vectorized : public Vectorizedi { return 8; } using Vectorizedi::Vectorizedi; +<<<<<<< HEAD Vectorized() { values = _mm512_setzero_si512(); } +======= + Vectorized() {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized(int64_t v) { values = _mm512_set1_epi64(v); } diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_qint.h b/aten/src/ATen/cpu/vec/vec512/vec512_qint.h index 64ba47e0f0646..ba4fa82bbeed5 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_qint.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_qint.h @@ -55,9 +55,13 @@ struct Vectorizedqi { #endif public: +<<<<<<< HEAD Vectorizedqi() { vals = _mm512_setzero_si512(); } +======= + Vectorizedqi() {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorizedqi(__m512i v) : vals(v) {} operator __m512i() const { return vals; @@ -125,24 +129,40 @@ typename std::enable_if_t< } template +<<<<<<< HEAD at::vec::Vectorized inline convert_float_to_int8( at::vec::Vectorized src); template <> at::vec::Vectorized inline convert_float_to_int8( at::vec::Vectorized src) { +======= +typename std::enable_if_t< + std::is_same_v || std::is_same_v, + at::vec::Vectorized< + T>> inline convert_float_to_int8(at::vec::Vectorized src) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Convert from float32 to int32 with truncation __m512i x_values_int32 = _mm512_cvttps_epi32(src); // Convert from int32 to int16 using signed saturation __m512i xy_packed_v = _mm512_packs_epi32(x_values_int32, x_values_int32); +<<<<<<< HEAD constexpr auto min_val = std::numeric_limits::min(); constexpr auto max_val = std::numeric_limits::max(); // Convert from int16 to int8 using unsigned saturation __m512i xyzw_clamped_v = pack_saturate_and_clamp( xy_packed_v, xy_packed_v, min_val, max_val); +======= + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + + // Convert from int16 to uint8/int8 using unsigned saturation + __m512i xyzw_clamped_v = + pack_saturate_and_clamp(xy_packed_v, xy_packed_v, min_val, max_val); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __m512i permute_mask_v = _mm512_set_epi32( 0x0f, 0x0b, @@ -163,6 +183,7 @@ at::vec::Vectorized inline convert_float_to_int8( return _mm512_permutexvar_epi32(permute_mask_v, xyzw_clamped_v); } +<<<<<<< HEAD template <> at::vec::Vectorized inline convert_float_to_int8( at::vec::Vectorized src) { @@ -178,6 +199,8 @@ at::vec::Vectorized inline convert_float_to_int8( return _mm512_castsi128_si512(int8_src); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template __FORCE_INLINE void QuantizeAvx512( const float* src, diff --git a/aten/src/ATen/cpu/vec/vec_base.h b/aten/src/ATen/cpu/vec/vec_base.h index bfecfa3f933a2..6e89dc5162bfe 100644 --- a/aten/src/ATen/cpu/vec/vec_base.h +++ b/aten/src/ATen/cpu/vec/vec_base.h @@ -238,6 +238,12 @@ struct Vectorized { Vectorized vector; int_same_size_t buffer[size()]; mask.store(buffer); +<<<<<<< HEAD +======= +#if defined(__clang__) && __ARM_FEATURE_SVE +#pragma clang loop vectorize(disable) +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (const auto i : c10::irange(size())) { if (buffer[i] & 0x01) { vector[i] = b[i]; @@ -544,9 +550,12 @@ struct Vectorized { Vectorized exp_u20() const { return map(std::exp); } +<<<<<<< HEAD Vectorized fexp_u20() const { return map(std::exp); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized frac() const { return *this - this->trunc(); } @@ -1248,6 +1257,7 @@ inline Vectorized fmadd( VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(fmadd) template +<<<<<<< HEAD inline Vectorized fnmadd( const Vectorized& a, const Vectorized& b, @@ -1258,6 +1268,8 @@ inline Vectorized fnmadd( VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(fnmadd) template +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inline Vectorized fmsub( const Vectorized& a, const Vectorized& b, @@ -1268,6 +1280,7 @@ inline Vectorized fmsub( VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(fmsub) template +<<<<<<< HEAD inline Vectorized fnmsub( const Vectorized& a, const Vectorized& b, @@ -1278,6 +1291,8 @@ inline Vectorized fnmsub( VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(fnmsub) template +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Vectorized inline operator&&( const Vectorized& a, const Vectorized& b) { diff --git a/aten/src/ATen/cpu/vec/vec_half.h b/aten/src/ATen/cpu/vec/vec_half.h index dc1c23c74ae52..2bf13659596c5 100644 --- a/aten/src/ATen/cpu/vec/vec_half.h +++ b/aten/src/ATen/cpu/vec/vec_half.h @@ -3,12 +3,58 @@ #include #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace at::vec { // See Note [CPU_CAPABILITY namespace] inline namespace CPU_CAPABILITY { +<<<<<<< HEAD +======= +#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \ + !defined(__APPLE__) +static inline uint16_t float2half_scalar(float val) { +#if defined(CPU_CAPABILITY_AVX2) +#if defined(_MSC_VER) + __m256 v = _mm256_set1_ps(val); + __m128i o = + _mm256_cvtps_ph(v, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + return static_cast(_mm_cvtsi128_si32(o)); +#else + return _cvtss_sh(val, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); +#endif +#elif defined(CPU_CAPABILITY_AVX512) + __m512 v = _mm512_set1_ps(val); + __m256i o = + _mm512_cvtps_ph(v, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + return static_cast( + _mm_cvtsi128_si32(_mm256_castsi256_si128(o))); +#endif +} + +static inline float half2float_scalar(uint16_t val) { +#if defined(CPU_CAPABILITY_AVX2) +#if defined(_MSC_VER) + __m128i v = _mm_cvtsi32_si128(val); + __m256 o = _mm256_cvtph_ps(v); + return _mm256_cvtss_f32(o); +#else + return _cvtsh_ss(val); +#endif +#elif defined(CPU_CAPABILITY_AVX512) + __m256i v = + _mm256_setr_epi16(val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0); + __m512 o = _mm512_cvtph_ps(v); + return _mm512_cvtss_f32(o); +#endif +} + +#endif + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Transpose a [2, 32] matrix to [32, 2] // Note: the output leading dimension should be 2, // that is, the output must be contiguous diff --git a/aten/src/ATen/cpu/vec/vec_n.h b/aten/src/ATen/cpu/vec/vec_n.h index 3de55de6f1b85..93c61ad3a44b4 100644 --- a/aten/src/ATen/cpu/vec/vec_n.h +++ b/aten/src/ATen/cpu/vec/vec_n.h @@ -263,7 +263,10 @@ class VectorizedN { VECTORIZEDN_DEFINE_UNARY_OP(exp2) VECTORIZEDN_DEFINE_UNARY_OP(expm1) VECTORIZEDN_DEFINE_UNARY_OP(exp_u20) +<<<<<<< HEAD VECTORIZEDN_DEFINE_UNARY_OP(fexp_u20) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) VECTORIZEDN_DEFINE_UNARY_OP(frac) VECTORIZEDN_DEFINE_BINARY_OP(fmod) VECTORIZEDN_DEFINE_UNARY_OP(log) diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index 0d319ea593840..1849d864e0ef8 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -17,7 +17,10 @@ #include #ifdef USE_ROCM +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include // until hipblas has an API to accept flags, we must use rocblas here #include @@ -189,6 +192,7 @@ uint32_t _getAlignment(uintptr_t address) { } #endif +<<<<<<< HEAD #ifdef USE_ROCM static c10::cuda::CUDAStream _getCarveoutStream(int32_t value) { static int32_t last_value = 0; @@ -243,6 +247,8 @@ static void _syncCurrentWithCarveoutStream(hipStream_t stream, bool presync) { } #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) struct CublasLtWorkspace { CublasLtWorkspace() { size = at::cuda::getCUDABlasLtWorkspaceSize(); @@ -391,7 +397,11 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D computeType = CUBLAS_COMPUTE_64F; scaleType = CUDA_R_64F; } else if constexpr (std::is_same_v) { +<<<<<<< HEAD if (at::globalContext().float32Precision("cuda", "matmul") == "tf32") { +======= + if (at::globalContext().allowTF32CuBLAS()) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) computeType = CUBLAS_COMPUTE_32F_FAST_TF32; } } else if constexpr (std::is_same_v>) { @@ -445,7 +455,10 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType); computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, opa); computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, opb); +<<<<<<< HEAD auto stream = at::cuda::getCurrentCUDAStream(); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #ifndef USE_ROCM if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { computeDesc.setAttribute( @@ -453,12 +466,15 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D at::cuda::getCurrentDeviceProperties()->multiProcessorCount - at::globalContext()._SMCarveout_EXPERIMENTAL().value()); } +<<<<<<< HEAD #else if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { stream = _getCarveoutStream( at::globalContext()._SMCarveout_EXPERIMENTAL().value()); _syncCurrentWithCarveoutStream(stream, true); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif CuBlasLtMatrixLayout Adesc(abType, m, k, lda, opa == CUBLAS_OP_T); CuBlasLtMatrixLayout Bdesc(abType, k, n, ldb, opb == CUBLAS_OP_T); @@ -521,12 +537,16 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D &heuristicResult.algo, ltworkspace.ptr, ltworkspace.size, +<<<<<<< HEAD stream); #ifdef USE_ROCM if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { _syncCurrentWithCarveoutStream(stream, false); } #endif +======= + at::cuda::getCurrentCUDAStream()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } if (cublasStatus != CUBLAS_STATUS_SUCCESS) { TORCH_WARN( @@ -832,7 +852,11 @@ void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) bgemm_internal_cublas(CUDABLAS_BGEMM_ARGS(at::BFloat16)); } } +<<<<<<< HEAD #if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM) +======= +#if defined(USE_ROCM) && !defined(_MSC_VER) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) { at::native::bgemm_internal_ck(CUDABLAS_BGEMM_ARGS(at::BFloat16)); } @@ -996,6 +1020,12 @@ void bgemm(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) { template <> void bgemm(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(at::Half, float)) { +<<<<<<< HEAD +======= + #ifdef USE_ROCM + TORCH_CHECK(false, "bgemm input type at::Half and output type float is not supported for ROCm"); + #endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // TODO: Support tuning for Half inputs and FP32 output bgemm_internal(CUDABLAS_BGEMM_ARGS(at::Half)); } @@ -1003,7 +1033,13 @@ void bgemm(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(at::Half, float) template <> void bgemm(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(at::BFloat16, float)) { +<<<<<<< HEAD #ifndef USE_ROCM +======= + #ifdef USE_ROCM + TORCH_CHECK(false, "bgemm input type at::BFloat16 and output type float is not supported for ROCm"); + #else +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); if (prop->major < 8) @@ -1268,7 +1304,11 @@ void gemm_internal(CUDABLAS_GEMM_ARGTYPES(double)) gemm_internal_cublaslt(CUDABLAS_GEMM_ARGS(double)); #endif } +<<<<<<< HEAD #if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM) +======= +#if defined(USE_ROCM) && !defined(_MSC_VER) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) { at::native::gemm_internal_ck(CUDABLAS_GEMM_ARGS(double)); } @@ -1284,7 +1324,11 @@ void gemm_internal(CUDABLAS_GEMM_ARGTYPES(float)) if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { gemm_internal_cublaslt(CUDABLAS_GEMM_ARGS(float)); } +<<<<<<< HEAD #if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM) +======= +#if defined(USE_ROCM) && !defined(_MSC_VER) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) { if (at::detail::getCUDAHooks().isGPUArch({"gfx1100"})) { //no CK GEMM version for gfx1100 gemm_internal_cublaslt(CUDABLAS_GEMM_ARGS(float)); @@ -1336,7 +1380,11 @@ void gemm_internal(CUDABLAS_GEMM_ARGTYPES(at::Half)) if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { gemm_internal_cublaslt(CUDABLAS_GEMM_ARGS(at::Half)); } +<<<<<<< HEAD #if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM) +======= +#if defined(USE_ROCM) && !defined(_MSC_VER) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) { at::native::gemm_internal_ck(CUDABLAS_GEMM_ARGS(at::Half)); } @@ -1352,7 +1400,11 @@ void gemm_internal(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { gemm_internal_cublaslt(CUDABLAS_GEMM_ARGS(at::BFloat16)); } +<<<<<<< HEAD #if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM) +======= +#if defined(USE_ROCM) && !defined(_MSC_VER) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) { at::native::gemm_internal_ck(CUDABLAS_GEMM_ARGS(at::BFloat16)); } @@ -1508,6 +1560,12 @@ void gemm(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) { template <> void gemm(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(at::Half, float)) { +<<<<<<< HEAD +======= + #ifdef USE_ROCM + TORCH_CHECK(false, "gemm input type at::Half and output type float is not supported for ROCm"); + #endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // TODO: Support Tuning for fp16-fp32 gemm gemm_internal(CUDABLAS_GEMM_ARGS(at::Half)); } @@ -1515,7 +1573,13 @@ void gemm(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(at::Half, float)) template <> void gemm(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(at::BFloat16, float)) { +<<<<<<< HEAD #ifndef USE_ROCM +======= + #ifdef USE_ROCM + TORCH_CHECK(false, "gemm input type at::BFloat16 and output type float is not supported for ROCm"); + #else +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); if (prop->major < 8) @@ -1575,7 +1639,11 @@ bool gemm_and_bias( computeType = CUBLAS_COMPUTE_64F; scaleType = CUDA_R_64F; } else if constexpr (std::is_same_v) { +<<<<<<< HEAD if (at::globalContext().float32Precision("cuda", "matmul") == "tf32") { +======= + if (at::globalContext().allowTF32CuBLAS()) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) computeType = CUBLAS_COMPUTE_32F_FAST_TF32; } } else if constexpr (std::is_same_v) { @@ -1614,7 +1682,10 @@ bool gemm_and_bias( computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa); cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N; computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb); +<<<<<<< HEAD auto stream = at::cuda::getCurrentCUDAStream(); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #ifndef USE_ROCM if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { computeDesc.setAttribute( @@ -1622,12 +1693,15 @@ bool gemm_and_bias( at::cuda::getCurrentDeviceProperties()->multiProcessorCount - at::globalContext()._SMCarveout_EXPERIMENTAL().value()); } +<<<<<<< HEAD #else if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { stream = _getCarveoutStream( at::globalContext()._SMCarveout_EXPERIMENTAL().value()); _syncCurrentWithCarveoutStream(stream, true); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS; if (activation == GEMMAndBiasActivationEpilogue::RELU) { @@ -1696,12 +1770,16 @@ bool gemm_and_bias( &heuristicResult.algo, ltworkspace.ptr, ltworkspace.size, +<<<<<<< HEAD stream); #ifdef USE_ROCM if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { _syncCurrentWithCarveoutStream(stream, false); } #endif +======= + at::cuda::getCurrentCUDAStream()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } if (cublasStatus != CUBLAS_STATUS_SUCCESS) { TORCH_WARN( @@ -1833,6 +1911,7 @@ template bool gemm_and_bias( int64_t result_ld, GEMMAndBiasActivationEpilogue activation); +<<<<<<< HEAD int get_scale_mode(ScalingType scaling_type, ScalarType scale_dtype, bool use_fast_accum) { switch (scaling_type) { case ScalingType::BlockWise1x32: @@ -1900,6 +1979,8 @@ case ScalingType::TensorWise: } } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void scaled_gemm( char transa, char transb, @@ -1911,20 +1992,31 @@ void scaled_gemm( int64_t mat1_ld, ScalarType mat1_dtype, ScalarType mat1_scale_dtype, +<<<<<<< HEAD ScalingType mat1_scaling_type, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const void* mat2_ptr, const void* mat2_scale_ptr, int64_t mat2_ld, ScalarType mat2_dtype, ScalarType mat2_scale_dtype, +<<<<<<< HEAD ScalingType mat2_scaling_type, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const void* bias_ptr, ScalarType bias_dtype, void* result_ptr, const void *result_scale_ptr, int64_t result_ld, ScalarType result_dtype, +<<<<<<< HEAD bool use_fast_accum) { +======= + bool use_fast_accum, + bool use_rowwise) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Note: see `cublasCommonArgs` for various non-intuitive manupulations // of input arguments to this function. #if CUDA_VERSION >= 11080 || defined(USE_ROCM) @@ -1937,15 +2029,23 @@ void scaled_gemm( computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb)); cublasLtMatmulDescAttributes_t matmulDescA = CUBLASLT_MATMUL_DESC_A_SCALE_POINTER; cublasLtMatmulDescAttributes_t matmulDescB = CUBLASLT_MATMUL_DESC_B_SCALE_POINTER; +<<<<<<< HEAD #if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT) // hipblaslt supported row-wise before cublas, and did so their own way (via // the SCALE_POINTERSs), but then migrated to match how cublas does it (via // the SCALE_MODEs). Here we check for this early custom mode. bool use_rowwise = (mat1_scaling_type == ScalingType::RowWise && mat2_scaling_type == ScalingType::RowWise); +======= +#if defined(USE_ROCM) +#if defined(HIPBLASLT_OUTER_VEC) + // this case is handled later as hipified CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F +#elif defined(HIPBLASLT_VEC_EXT) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (use_rowwise) { matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT; matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT; } +<<<<<<< HEAD else if (mat1_scale_dtype == kFloat8_e8m0fnu && mat2_scale_dtype == kFloat8_e8m0fnu) { #if ROCM_VERSION >= 70000 if (at::detail::getCUDAHooks().isGPUArch({"gfx950"})) { @@ -1964,12 +2064,32 @@ void scaled_gemm( // rowwise isn't supported using older cublaslt or older hipblaslt TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with blaslt"); #endif // if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT) +======= + else if(mat1_scale_dtype == kFloat8_e8m0fnu && mat2_scale_dtype == kFloat8_e8m0fnu) { +#if ROCM_VERSION >= 70000 + if (at::detail::getCUDAHooks().isGPUArch(0, {"gfx950"})) { + // Validate matrix dimensions for MX format + TORCH_CHECK((m % 32 == 0) && (n % 32 == 0) && (k % 32 == 0), + "Matrix dimensions must be multiples of 32 for MX format. ", + "Got m=", m, ", n=", n, ", k=", k); + } +#endif + } +#else + // rowwise isn't supported using older hipblaslt + TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with older hipblaslt"); +#endif +#endif // defined(USE_ROCM) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) computeDesc.setAttribute(matmulDescA, mat1_scale_ptr); computeDesc.setAttribute(matmulDescB, mat2_scale_ptr); if (result_scale_ptr != nullptr) { computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr); } +<<<<<<< HEAD auto stream = at::cuda::getCurrentCUDAStream(); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #ifndef USE_ROCM if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { computeDesc.setAttribute( @@ -1977,12 +2097,15 @@ void scaled_gemm( at::cuda::getCurrentDeviceProperties()->multiProcessorCount - at::globalContext()._SMCarveout_EXPERIMENTAL().value()); } +<<<<<<< HEAD #else if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { stream = _getCarveoutStream( at::globalContext()._SMCarveout_EXPERIMENTAL().value()); _syncCurrentWithCarveoutStream(stream, true); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif // ifndef USE_ROCM #ifndef USE_ROCM const int8_t fastAccuMode = use_fast_accum ? 1 : 0; @@ -2002,6 +2125,7 @@ void scaled_gemm( computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, CUBLASLT_EPILOGUE_BIAS); computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype)); } +<<<<<<< HEAD // For other data types, use the get_scale_mode function based on scaling type // The SCALE_MODE attrs only exist in cuBLAS 12.8+/ROCm 7.0 or in recent hipblaslt, // but we must invoke get_scale_mode anyways to trigger the version checks. @@ -2013,6 +2137,35 @@ void scaled_gemm( computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, b_scale_mode); #endif // if CUDA_VERSION >= 12080 || (defined(USE_ROCM) && ROCM_VERSION >= 70000 && defined(HIPBLASLT_OUTER_VEC)) +======= + + if (mat1_scale_dtype == kFloat8_e8m0fnu && mat2_scale_dtype == kFloat8_e8m0fnu) { +#if (!defined(USE_ROCM) && CUDA_VERSION >= 12080) || (defined(USE_ROCM) && ROCM_VERSION >= 70000) + computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0); + computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0); +#else + TORCH_CHECK(false, "scaled_gemm with `torch.float8_e8m0fnu` scales is only supported for CUDA 12.8 or ROCm 7.0(with gfx950) and above"); +#endif // if CUDA_VERSION >= 12080 + } else if (mat1_scale_dtype == kFloat8_e4m3fn && mat2_scale_dtype == kFloat8_e4m3fn) { +#if CUDA_VERSION >= 12080 + computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3); + computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3); +#else + TORCH_CHECK(false, "scaled_gemm with `torch.float8_e4m3fn` scales is only supported for CUDA 12.8 and above"); +#endif // if CUDA_VERSION >= 12080 + } else if (mat1_scale_dtype == kFloat && mat2_scale_dtype == kFloat && use_rowwise) { +#if CUDA_VERSION >= 12090 || (defined(USE_ROCM) && defined(HIPBLASLT_OUTER_VEC)) + computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); + computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); +#elif defined(USE_ROCM) && defined(HIPBLASLT_VEC_EXT) + // no-op here for older hipblaslt ext enums, to avoid TORCH_CHECK below +#else + TORCH_CHECK(false, "scaled_gemm with `torch.float` outer vector scaling is only supported for CUDA 12.9 and above"); +#endif // if CUDA_VERSION >= 12090 + } + + auto stream = c10::cuda::getCurrentCUDAStream(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CuBlasLtMatmulPreference preference; auto ltworkspace = CublasLtWorkspace(); preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, ltworkspace.size); @@ -2099,11 +2252,14 @@ void scaled_gemm( ltworkspace.ptr, ltworkspace.size, stream); +<<<<<<< HEAD #ifdef USE_ROCM if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { _syncCurrentWithCarveoutStream(stream, false); } #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_CHECK( cublasStatus == CUBLAS_STATUS_SUCCESS, "CUDA error: ", @@ -2157,7 +2313,10 @@ void int8_gemm( computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa); cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N; computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb); +<<<<<<< HEAD auto stream = at::cuda::getCurrentCUDAStream(); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #ifndef USE_ROCM if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { computeDesc.setAttribute( @@ -2165,12 +2324,15 @@ void int8_gemm( at::cuda::getCurrentDeviceProperties()->multiProcessorCount - at::globalContext()._SMCarveout_EXPERIMENTAL().value()); } +<<<<<<< HEAD #else if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { stream = _getCarveoutStream( at::globalContext()._SMCarveout_EXPERIMENTAL().value()); _syncCurrentWithCarveoutStream(stream, true); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif CuBlasLtMatrixLayout Adesc(abType, m, k, mat1_ld, transpose_mat1); @@ -2232,7 +2394,11 @@ void int8_gemm( #else 0, #endif +<<<<<<< HEAD stream); +======= + at::cuda::getCurrentCUDAStream()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_CHECK( cublasStatus == CUBLAS_STATUS_SUCCESS, "CUDA error: ", @@ -2261,11 +2427,14 @@ void int8_gemm( computeType, " scaleType ", scaleType); +<<<<<<< HEAD #ifdef USE_ROCM if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { _syncCurrentWithCarveoutStream(stream, false); } #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } template <> @@ -2577,6 +2746,11 @@ void vdot>(CUDABLAS_DOT_ARGTYPES(c10::complex)) { reinterpret_cast(result))); } +<<<<<<< HEAD +======= +// HIP on Windows does not support +#if !(defined(USE_ROCM) && defined(_MSC_VER)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template <> void getrsBatched(CUDABLAS_GETRS_ARGTYPES(float)) { TORCH_CUDABLAS_CHECK(cublasSgetrsBatched( @@ -2775,5 +2949,9 @@ void gelsBatched>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::comple devInfoArray, batchSize)); } +<<<<<<< HEAD +======= +#endif // !(defined(USE_ROCM) && defined(_MSC_VER)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace at::cuda::blas diff --git a/aten/src/ATen/cuda/CUDABlas.h b/aten/src/ATen/cuda/CUDABlas.h index b235840418e25..cf15571cb36a1 100644 --- a/aten/src/ATen/cuda/CUDABlas.h +++ b/aten/src/ATen/cuda/CUDABlas.h @@ -136,6 +136,7 @@ void int8_gemm( int32_t* result_ptr, int64_t result_ld); +<<<<<<< HEAD enum class ScalingType : std::uint8_t { TensorWise, // fp32 scales RowWise, // fp32 scales @@ -145,6 +146,8 @@ enum class ScalingType : std::uint8_t { BlockWise128x128, // fp32 scales }; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void scaled_gemm( char transa, char transb, @@ -156,20 +159,31 @@ void scaled_gemm( int64_t mat1_ld, ScalarType mat1_dtype, ScalarType mat1_scale_dtype, +<<<<<<< HEAD ScalingType mat1_scaling_type, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const void* mat2_ptr, const void* mat2_scale_ptr, int64_t mat2_ld, ScalarType mat2_dtype, ScalarType mat2_scale_dtype, +<<<<<<< HEAD ScalingType mat2_scaling_type, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const void* bias_ptr, ScalarType bias_dtype, void* result_ptr, const void* result_scale_ptr, int64_t result_ld, ScalarType result_dtype, +<<<<<<< HEAD bool use_fast_accum); +======= + bool use_fast_accum, + bool use_rowwise); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #define CUDABLAS_BGEMM_ARGTYPES(Dtype) CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, Dtype) @@ -343,6 +357,12 @@ void vdot>(CUDABLAS_DOT_ARGTYPES(c10::complex)); int m, int n, int nrhs, Dtype** dA_array, int ldda, \ Dtype** dC_array, int lddc, int* info, int *devInfoArray, int batchSize +<<<<<<< HEAD +======= +// HIP on Windows does not support getrs, geqrf, getrf, gels +#if !(defined(USE_ROCM) && defined(_MSC_VER)) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template void getrsBatched(CUDABLAS_GETRS_ARGTYPES(Dtype)) { static_assert(false&&sizeof(Dtype),"at::cuda::blas::getrsBatched: not implemented"); @@ -397,4 +417,31 @@ TORCH_CUDA_CU_API void gelsBatched>(CUDABLAS_GELS_BATCHED_A template<> TORCH_CUDA_CU_API void gelsBatched>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex)); +<<<<<<< HEAD +======= +#else // !(defined(USE_ROCM) && defined(_MSC_VER)) + +template +void getrsBatched(CUDABLAS_GETRS_ARGTYPES(Dtype)) { + TORCH_CHECK(false, "at::cuda::blas::getrsBatched: not supported for HIP on Windows"); +} + +template +void geqrfBatched(CUDABLAS_GEQRF_BATCHED_ARGTYPES(Dtype)) { + TORCH_CHECK(false, "at::cuda::blas::geqrfBatched: not supported for HIP on Windows"); +} + +template +void getrfBatched(CUDABLAS_GETRF_ARGTYPES(Dtype)) { + TORCH_CHECK(false, "at::cuda::blas::getrfBatched: not supported for HIP on Windows"); +} + +template +void gelsBatched(CUDABLAS_GELS_BATCHED_ARGTYPES(Dtype)) { + TORCH_CHECK(false, "at::cuda::blas::gelsBatched: not supported for HIP on Windows"); +} + +#endif // !(defined(USE_ROCM) && defined(_MSC_VER)) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace at::cuda::blas diff --git a/aten/src/ATen/cuda/CUDAGraph.cpp b/aten/src/ATen/cuda/CUDAGraph.cpp index b8cd84c56daef..97e4e84126cec 100644 --- a/aten/src/ATen/cuda/CUDAGraph.cpp +++ b/aten/src/ATen/cuda/CUDAGraph.cpp @@ -2,6 +2,10 @@ #include #include #include +<<<<<<< HEAD +======= +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include @@ -252,6 +256,7 @@ cudaGraph_t CUDAGraph::raw_cuda_graph() { return graph_; } +<<<<<<< HEAD cudaGraphExec_t CUDAGraph::raw_cuda_graph_exec() { TORCH_CHECK( has_graph_exec_, @@ -259,6 +264,8 @@ cudaGraphExec_t CUDAGraph::raw_cuda_graph_exec() { return graph_exec_; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void CUDAGraph::reset() { // I'd prefer these checks throw exceptions, not print warnings, // but the destructor calls reset(), and at least one CI build diff --git a/aten/src/ATen/cuda/CUDAGraph.h b/aten/src/ATen/cuda/CUDAGraph.h index c18ad66b20809..260bfe5952eda 100644 --- a/aten/src/ATen/cuda/CUDAGraph.h +++ b/aten/src/ATen/cuda/CUDAGraph.h @@ -2,7 +2,10 @@ #include #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include @@ -37,7 +40,10 @@ struct TORCH_CUDA_CPP_API CUDAGraph { void enable_debug_mode(); void debug_dump(const std::string& debug_path); cudaGraph_t raw_cuda_graph(); +<<<<<<< HEAD cudaGraphExec_t raw_cuda_graph_exec(); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) protected: cudaGraph_t graph_ = nullptr; diff --git a/aten/src/ATen/cuda/CachingHostAllocator.cpp b/aten/src/ATen/cuda/CachingHostAllocator.cpp index 34aa15d0c06cf..f8f792f2a9911 100644 --- a/aten/src/ATen/cuda/CachingHostAllocator.cpp +++ b/aten/src/ATen/cuda/CachingHostAllocator.cpp @@ -258,7 +258,11 @@ DECLARE_HOST_ALLOCATOR( CUDACachingHostAllocator, CUDACachingHostAllocatorImpl, raw_local_deleter, +<<<<<<< HEAD caching_host_allocator) +======= + caching_host_allocator); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) REGISTER_HOST_ALLOCATOR(at::kCUDA, &caching_host_allocator) diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp index 3298513414438..9c8601e8e118a 100644 --- a/aten/src/ATen/cuda/CublasHandlePool.cpp +++ b/aten/src/ATen/cuda/CublasHandlePool.cpp @@ -309,8 +309,12 @@ cublasHandle_t getCurrentCUDABlasHandle() { // On CUDA >= 11, and architecture >= Ampere, cuBLAS can use TF32 to speedup // FP32 data type calculations based on the value of the allow_tf32 flag. // To enable TF32, set the math mode of the handle to CUBLAS_TF32_TENSOR_OP_MATH. +<<<<<<< HEAD if (!NoTF32Guard::should_disable_tf32() && at::globalContext().float32Precision("cuda", "matmul") == "tf32") { +======= + if (!NoTF32Guard::should_disable_tf32() && at::globalContext().allowTF32CuBLAS()) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH)); } else { TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); diff --git a/aten/src/ATen/cuda/PeerToPeerAccess.cpp b/aten/src/ATen/cuda/PeerToPeerAccess.cpp index 66a75db6ea067..93ca53c77d6df 100644 --- a/aten/src/ATen/cuda/PeerToPeerAccess.cpp +++ b/aten/src/ATen/cuda/PeerToPeerAccess.cpp @@ -4,9 +4,12 @@ #include #include +<<<<<<< HEAD #if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) #include #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include @@ -15,7 +18,10 @@ namespace at::cuda { static std::vector p2pAccessEnabled_; +<<<<<<< HEAD static std::vector fabricAccessEnabled_; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static int64_t num_devices_ = -1; namespace detail { @@ -33,15 +39,22 @@ void init_p2p_access_cache(int64_t num_devices) { for (const auto i : c10::irange(num_devices)) { p2pAccessEnabled_[i * num_devices + i] = 1; } +<<<<<<< HEAD fabricAccessEnabled_.clear(); fabricAccessEnabled_.resize(num_devices, -1); } } // namespace detail +======= +} + +} // namespace detail +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool get_p2p_access(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) { at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); +<<<<<<< HEAD TORCH_CHECK(dev >= 0 || dev < num_devices_, dev, " is not a device"); TORCH_CHECK( dev_to_access >= 0 || dev_to_access < num_devices_, @@ -50,6 +63,15 @@ bool get_p2p_access(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) { TORCH_INTERNAL_ASSERT(num_devices_ >= 0, "p2p access cache not initialized"); auto& cache = p2pAccessEnabled_[dev * num_devices_ + dev_to_access]; +======= + TORCH_CHECK(dev >= 0 || dev < num_devices_, + dev, " is not a device"); + TORCH_CHECK(dev_to_access >= 0 || dev_to_access < num_devices_, + dev_to_access, " is not a device"); + TORCH_INTERNAL_ASSERT(num_devices_ >= 0, "p2p access cache not initialized"); + + auto &cache = p2pAccessEnabled_[dev * num_devices_ + dev_to_access]; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (cache != -1) { return cache; @@ -65,6 +87,7 @@ bool get_p2p_access(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) { return cache; } +<<<<<<< HEAD namespace { #if !defined USE_ROCM && defined CUDA_VERSION && CUDA_VERSION >= 12040 && defined PYTORCH_C10_DRIVER_API_SUPPORTED @@ -180,3 +203,6 @@ bool get_fabric_access(c10::DeviceIndex dev) { } } // namespace at::cuda +======= +} // namespace at::cuda::detail +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/cuda/PeerToPeerAccess.h b/aten/src/ATen/cuda/PeerToPeerAccess.h index 30d21af83ed88..c041aa7d6f107 100644 --- a/aten/src/ATen/cuda/PeerToPeerAccess.h +++ b/aten/src/ATen/cuda/PeerToPeerAccess.h @@ -8,6 +8,9 @@ void init_p2p_access_cache(int64_t num_devices); } TORCH_CUDA_CPP_API bool get_p2p_access(c10::DeviceIndex source_dev, c10::DeviceIndex dest_dev); +<<<<<<< HEAD TORCH_CUDA_CPP_API bool get_fabric_access(c10::DeviceIndex device); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace at::cuda diff --git a/aten/src/ATen/cuda/cub.cu b/aten/src/ATen/cuda/cub.cu index bc863b8880da7..3bcbe9d60dc9b 100644 --- a/aten/src/ATen/cuda/cub.cu +++ b/aten/src/ATen/cuda/cub.cu @@ -15,7 +15,12 @@ struct SumOp { template void inclusive_sum_truncating(const input_t *input, output_t *output, int64_t num_items) { +<<<<<<< HEAD inclusive_scan(input, output, NO_ROCM(::cuda)::std::plus<>{}, num_items); +======= + using NO_ROCM(at_cuda_detail)::cub::Sum; + inclusive_scan(input, output, Sum{}, num_items); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } template void inclusive_sum_truncating(const int32_t *input, int32_t *output, int64_t num_items); @@ -41,7 +46,12 @@ struct CountMaskOp { void mask_exclusive_sum(const uint8_t *mask, int64_t *output_idx, int64_t n) { CountMaskOp op{}; +<<<<<<< HEAD auto iter = ATEN_CUB_TRANSFORM_ITERATOR(bool, decltype(op), decltype(mask))(mask, op); +======= + auto iter = NO_ROCM(at_cuda_detail)::cub::TransformInputIterator< + bool, decltype(op), decltype(mask)>(mask, op); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) exclusive_scan(iter, output_idx, SumOp{}, int64_t{0}, n); } diff --git a/aten/src/ATen/cuda/cub.cuh b/aten/src/ATen/cuda/cub.cuh index 23a3ff8c8958c..2b7366a517894 100644 --- a/aten/src/ATen/cuda/cub.cuh +++ b/aten/src/ATen/cuda/cub.cuh @@ -6,10 +6,13 @@ #include #include +<<<<<<< HEAD #ifndef USE_ROCM #include #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include @@ -55,6 +58,7 @@ #define ROCM_HIPCUB(x) x #endif +<<<<<<< HEAD #if CUB_V3_PLUS() #include #include @@ -70,6 +74,8 @@ #define ATEN_CUB_MAXIMUM() NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::Max() #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #if (!defined(USE_ROCM) && !CUB_SUPPORTS_NV_BFLOAT16()) || defined(USE_ROCM) #if !defined(USE_ROCM) @@ -289,7 +295,11 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT return x.value; } }; +<<<<<<< HEAD auto input_ = ATEN_CUB_TRANSFORM_ITERATOR(input_t, decltype(input_iter_transform), ArgIndexInputIterator)( +======= + auto input_ = NO_ROCM(at_cuda_detail)::cub::TransformInputIterator( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ArgIndexInputIterator(input + i), input_iter_transform); CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan, input_, @@ -444,7 +454,11 @@ __global__ void calc_block_sums(const T * d_in, aggT * agg, int64_t nelem, int i aggT data[ITEMS_PER_THREAD]; aggT agg_val = 0; TransformFunctor transform_functor; +<<<<<<< HEAD auto iter_in = ATEN_CUB_TRANSFORM_ITERATOR(aggT, TransformFunctor, const T*)(d_in, transform_functor); +======= + auto iter_in = ROCM_HIPCUB(at_cuda_detail::cub)::TransformInputIterator, const T*>(d_in, transform_functor); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (int i=0; i= BLOCK_THREADS * ITEMS_PER_THREAD) { BlockLoadT(temp_storage.load).Load(iter_in, data); @@ -587,7 +601,11 @@ inline void inclusive_sum_by_key(KeysInputIteratorT keys, ValuesInputIteratorT i "cub InclusiveSumByKey does not support more than INT_MAX elements"); #if !defined(USE_ROCM) CUB_WRAPPER(at_cuda_detail::cub::DeviceScan::InclusiveSumByKey, +<<<<<<< HEAD keys, input, output, num_items, NO_ROCM(::cuda)::std::equal_to<>(), at::cuda::getCurrentCUDAStream()); +======= + keys, input, output, num_items, at_cuda_detail::cub::Equality(), at::cuda::getCurrentCUDAStream()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #else CUB_WRAPPER(cub::DeviceScan::InclusiveSumByKey, keys, input, output, num_items, hipcub::Equality(), at::cuda::getCurrentCUDAStream()); @@ -600,7 +618,11 @@ inline void inclusive_scan_by_key(KeysInputIteratorT keys, ValuesInputIteratorT "cub InclusiveSumByKey does not support more than INT_MAX elements"); #if !defined(USE_ROCM) CUB_WRAPPER(at_cuda_detail::cub::DeviceScan::InclusiveScanByKey, +<<<<<<< HEAD keys, input, output, scan_op, num_items, NO_ROCM(::cuda)::std::equal_to<>(), at::cuda::getCurrentCUDAStream()); +======= + keys, input, output, scan_op, num_items, at_cuda_detail::cub::Equality(), at::cuda::getCurrentCUDAStream()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #else CUB_WRAPPER(cub::DeviceScan::InclusiveScanByKey, keys, input, output, scan_op, num_items, hipcub::Equality(), at::cuda::getCurrentCUDAStream()); diff --git a/aten/src/ATen/cuda/cub_definitions.cuh b/aten/src/ATen/cuda/cub_definitions.cuh index b809512692093..9a475c17cc7dd 100644 --- a/aten/src/ATen/cuda/cub_definitions.cuh +++ b/aten/src/ATen/cuda/cub_definitions.cuh @@ -51,6 +51,7 @@ #else #define CUB_SUPPORTS_FUTURE_VALUE() false #endif +<<<<<<< HEAD // There were many bc-breaking changes in major version release of CCCL v3.0.0 // Please see https://nvidia.github.io/cccl/cccl/3.0_migration_guide.html @@ -59,3 +60,5 @@ #else #define CUB_V3_PLUS() false #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.cpp b/aten/src/ATen/cuda/detail/CUDAHooks.cpp index 72826b5847925..29b209e5a3817 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.cpp +++ b/aten/src/ATen/cuda/detail/CUDAHooks.cpp @@ -19,6 +19,13 @@ #include #include +<<<<<<< HEAD +======= +#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) +#include +#endif + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #if AT_CUDNN_ENABLED() #include #endif @@ -89,6 +96,32 @@ void CUDAHooks::init() const { // have a chance to enable vitals. at::vitals::VitalsAPI.setVital("CUDA", "used", "true", /* force = */ true); +<<<<<<< HEAD +======= + // Sets the CUDA_MODULE_LOADING environment variable + // if it's not set by the user. + // CUDA_MODULE_LOADING="LAZY" is default for all drivers released for CUDA 12.2+. + // Check the driver version and only set the env variable if needed. + bool set_lazy_module_loading = true; + #if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) + auto driver_api = c10::cuda::DriverAPI::get(); + // Initialize NVML + if (driver_api->nvmlInit_v2_() == NVML_SUCCESS) { + // Get the driver version + int version = -1; + auto res = driver_api->nvmlSystemGetCudaDriverVersion_v2_(&version); + if (res == NVML_SUCCESS) { + // Check if driver is sufficiently new + if (version >= 12020) { + set_lazy_module_loading = false; + } + } + } + #endif + if (set_lazy_module_loading) { + c10::utils::set_env("CUDA_MODULE_LOADING", "LAZY", false); + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const auto num_devices = c10::cuda::device_count_ensure_non_zero(); c10::cuda::CUDACachingAllocator::init(num_devices); at::cuda::detail::init_p2p_access_cache(num_devices); @@ -180,6 +213,7 @@ bool CUDAHooks::hasCuBLASLt() const { #endif } +<<<<<<< HEAD bool CUDAHooks::hasCKSDPA() const { #if !defined(USE_ROCM) @@ -201,6 +235,8 @@ bool CUDAHooks::hasCKGEMM() const { #endif } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool CUDAHooks::hasROCM() const { // Currently, this is same as `compiledWithMIOpen`. // But in future if there are ROCm builds without MIOpen, diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.h b/aten/src/ATen/cuda/detail/CUDAHooks.h index 2780369a37b71..81df53f994f23 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.h +++ b/aten/src/ATen/cuda/detail/CUDAHooks.h @@ -31,8 +31,11 @@ struct CUDAHooks : public at::CUDAHooksInterface { bool hasCuSOLVER() const override; bool hasCuBLASLt() const override; bool hasROCM() const override; +<<<<<<< HEAD bool hasCKSDPA() const override; bool hasCKGEMM() const override; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const at::cuda::NVRTC& nvrtc() const override; DeviceIndex current_device() const override; bool isBuilt() const override {return true;} diff --git a/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h b/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h index aca83386ad421..cbf4d811ebf14 100644 --- a/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h +++ b/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h @@ -117,8 +117,11 @@ namespace at::cuda { _(nvrtcGetPTXSize) \ _(nvrtcGetPTX) \ _(cuModuleLoadData) \ +<<<<<<< HEAD _(cuModuleLoad) \ _(cuGetErrorString) \ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _(cuModuleGetFunction) \ _(HIPOCCUPANCYMAXACTIVEBLOCKSPERMULTIPROCESSOR) \ _(nvrtcGetErrorString) \ diff --git a/aten/src/ATen/cuda/tunable/GemmCommon.h b/aten/src/ATen/cuda/tunable/GemmCommon.h index 6d19907aba4ad..7687f2ca74ff3 100644 --- a/aten/src/ATen/cuda/tunable/GemmCommon.h +++ b/aten/src/ATen/cuda/tunable/GemmCommon.h @@ -29,8 +29,11 @@ namespace at::cuda::tunable { +<<<<<<< HEAD using at::cuda::blas::ScalingType; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) enum class BlasOp { N = 0, T = 1 @@ -162,7 +165,11 @@ inline std::string ComputeTypeFor() { // ROCBLAS and hipBLASLt. template <> inline std::string ComputeTypeFor() { +<<<<<<< HEAD if (at::globalContext().float32Precision("cuda", "matmul") != "tf32") { +======= + if (!at::globalContext().allowTF32CuBLAS()) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return "f32_r"; } else { return "xf32_r"; @@ -600,8 +607,12 @@ struct ScaledGemmParams : OpParams { // // In TunableOp, we must distinguish in param signature these two cases: with and without a bias vector. return fmt::sprintf("%c%c_%ld_%ld_%ld_ld_%ld_%ld_%ld_rw_%d_bias_%s", +<<<<<<< HEAD transa, transb, m, n, k, lda, ldb, ldc, a_scaling_type == ScalingType::RowWise && b_scaling_type == ScalingType::RowWise, +======= + transa, transb, m, n, k, lda, ldb, ldc, use_rowwise, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bias_ptr == nullptr ? "None" : at::toString(bias_dtype)); } @@ -676,13 +687,19 @@ struct ScaledGemmParams : OpParams { int64_t lda{}; ScalarType a_dtype{}; ScalarType a_scale_dtype{}; +<<<<<<< HEAD ScalingType a_scaling_type{}; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const void* b{}; const void* b_scale_ptr{}; int64_t ldb{}; ScalarType b_dtype{}; ScalarType b_scale_dtype{}; +<<<<<<< HEAD ScalingType b_scaling_type{}; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const void* bias_ptr{}; ScalarType bias_dtype{}; void* c{}; @@ -691,6 +708,10 @@ struct ScaledGemmParams : OpParams { ScalarType c_dtype{}; void* amax_ptr{}; bool use_fast_accum{}; +<<<<<<< HEAD +======= + bool use_rowwise{}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) private: bool duplicate_inputs_{false}; }; diff --git a/aten/src/ATen/cuda/tunable/GemmHipblaslt.h b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h index 1a0d968999067..cd57bc65488bd 100644 --- a/aten/src/ATen/cuda/tunable/GemmHipblaslt.h +++ b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h @@ -215,6 +215,7 @@ float GetBetaFromParams(const ScaledGemmParams* params) { } template +<<<<<<< HEAD ScalingType GetAScalingTypeFromParams(const GemmParams* params) { return ScalingType::TensorWise; } @@ -252,6 +253,25 @@ ScalingType GetAScalingTypeFromParams(const ScaledGemmParams* params) { template ScalingType GetBScalingTypeFromParams(const ScaledGemmParams* params) { return params->b_scaling_type; +======= +bool GetUseRowwiseFromParams(const GemmParams* params) { + return false; +} + +template +bool GetUseRowwiseFromParams(const GemmAndBiasParams* params) { + return false; +} + +template +bool GetUseRowwiseFromParams(const GemmStridedBatchedParams* params) { + return false; +} + +template +bool GetUseRowwiseFromParams(const ScaledGemmParams* params) { + return params->use_rowwise; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } template @@ -506,7 +526,11 @@ class HipblasltGemmOp : public Callable { } hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F; +<<<<<<< HEAD if (at::globalContext().float32Precision("cuda", "matmul") == "tf32") { +======= + if (at::globalContext().allowTF32CuBLAS()) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) computeType = HIPBLAS_COMPUTE_32F_FAST_TF32; } HipBlasLtMatmulDescriptor matmul(computeType, HIP_R_32F); @@ -518,6 +542,7 @@ class HipblasltGemmOp : public Callable { const void* mat2_scale_ptr = GetBScalePointerFromParams(params); const void* result_scale_ptr = GetDScalePointerFromParams(params); if (mat1_scale_ptr && mat2_scale_ptr) { +<<<<<<< HEAD hipblasLtMatmulDescAttributes_t a_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER; hipblasLtMatmulDescAttributes_t b_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER; if (GetAScalingTypeFromParams(params) == ScalingType::RowWise) { @@ -536,6 +561,25 @@ class HipblasltGemmOp : public Callable { } matmul.setAttribute(a_scale_ptr_desc, mat1_scale_ptr); matmul.setAttribute(b_scale_ptr_desc, mat2_scale_ptr); +======= +#ifdef HIPBLASLT_VEC_EXT + if (GetUseRowwiseFromParams(params)) { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT, mat1_scale_ptr); + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT, mat2_scale_ptr); + } + else +#endif + { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr); + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr); + } +#ifdef HIPBLASLT_OUTER_VEC + if (GetUseRowwiseFromParams(params)) { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); + } +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } if (result_scale_ptr) { matmul.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr); diff --git a/aten/src/ATen/cuda/tunable/GemmRocblas.h b/aten/src/ATen/cuda/tunable/GemmRocblas.h index d7c45dc91c212..24c6eba2fb894 100644 --- a/aten/src/ATen/cuda/tunable/GemmRocblas.h +++ b/aten/src/ATen/cuda/tunable/GemmRocblas.h @@ -141,7 +141,11 @@ class RocblasGemmOp : public Callable> { TuningStatus Call(const GemmParams* params) override { auto input_output_type = RocBlasDataTypeFor(); +<<<<<<< HEAD if (at::globalContext().float32Precision("cuda", "matmul") == "tf32" && input_output_type == rocblas_datatype_f32_r) +======= + if (at::globalContext().allowTF32CuBLAS() && input_output_type == rocblas_datatype_f32_r) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return FAIL; // no support for TF32 in rocBLAS auto compute_type = RocBlasComputeTypeFor(); auto h_a = DoCastForHalfOrBfloat16(params->alpha); @@ -209,7 +213,11 @@ class RocblasGemmStridedBatchedOp : public Callable> TuningStatus Call(const GemmStridedBatchedParams* params) override { auto input_output_type = RocBlasDataTypeFor(); +<<<<<<< HEAD if (at::globalContext().float32Precision("cuda", "matmul") == "tf32" && input_output_type == rocblas_datatype_f32_r) +======= + if (at::globalContext().allowTF32CuBLAS() && input_output_type == rocblas_datatype_f32_r) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return FAIL; // no support for TF32 in rocBLAS auto compute_type = RocBlasComputeTypeFor(); auto h_a = DoCastForHalfOrBfloat16(params->alpha); diff --git a/aten/src/ATen/cuda/tunable/README.md b/aten/src/ATen/cuda/tunable/README.md index b30040b7e2842..a304d2ed02042 100644 --- a/aten/src/ATen/cuda/tunable/README.md +++ b/aten/src/ATen/cuda/tunable/README.md @@ -154,7 +154,11 @@ programmatically since the settings become fixed. Use the C++ or Python APIs ins | PYTORCH_TUNABLEOP_MAX_WARMUP_ITERATIONS | Default is 0, meaning it is not used. | | PYTORCH_TUNABLEOP_ICACHE_FLUSH_ENABLED | Default is 1. Set to 0 to disable. | | PYTORCH_TUNABLEOP_ROTATING_BUFFER_SIZE | Default (or < 0) is to query L2 cache size. Set to 0 to disable. Otherwise, set to the number of MiB to use for the pool of operator parameters. For example, setting this to the size of your device's memory cache will guarantee that every tuning iteration will use a cold cache. | +<<<<<<< HEAD | PYTORCH_TUNABLEOP_BLAS_LOG | Default is 0. Set to 1 to enable. Write BLAS parameters to tuning CSV file. | +======= +| PYTORCH_TUNABLEOP_BLAS_LOG | Default is 0. Set to 1 to enable. Write BLAS paramters to tuning CSV file. | +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ### Python Interface All python APIs exist in the `torch.cuda.tunable` module. diff --git a/aten/src/ATen/cuda/tunable/Tunable.cpp b/aten/src/ATen/cuda/tunable/Tunable.cpp index 3511e48ae061a..6ba1880d6ba3c 100644 --- a/aten/src/ATen/cuda/tunable/Tunable.cpp +++ b/aten/src/ATen/cuda/tunable/Tunable.cpp @@ -220,6 +220,7 @@ TuningResultsValidator::TuningResultsValidator() { []() { return GetPyTorchVersion(); }, [this](auto&& k) { return ValidatePyTorchVersion(std::forward(k)); }); #ifdef USE_ROCM +<<<<<<< HEAD // hip { // HIP version is more accurate than ROCm version. User's environment could be a stock @@ -231,6 +232,21 @@ TuningResultsValidator::TuningResultsValidator() { [hip_version](auto&& k) { TUNABLE_LOG1("HIP_VERSION validation: expect ", k, " to match ", hip_version); return hip_version == k ? OK : FAIL; +======= + // rocm + { +#ifdef _WIN32 + std::string rocm_version = HIP_VERSION_BUILD_NAME; +#else + std::string rocm_version = ROCM_BUILD_INFO; +#endif + RegisterValidator( + "ROCM_VERSION", + [rocm_version]() { return rocm_version; }, + [rocm_version](auto&& k) { + TUNABLE_LOG1("ROCM_VERSION validation: expect ", k, " to match ", rocm_version); + return rocm_version == k ? OK : FAIL; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }); } // gfx arch diff --git a/aten/src/ATen/cuda/tunable/TunableGemm.h b/aten/src/ATen/cuda/tunable/TunableGemm.h index d941c230630c4..f96ed7de52b2d 100644 --- a/aten/src/ATen/cuda/tunable/TunableGemm.h +++ b/aten/src/ATen/cuda/tunable/TunableGemm.h @@ -96,20 +96,31 @@ class DefaultScaledGemmOp : public Callable> { params->lda, params->a_dtype, params->a_scale_dtype, +<<<<<<< HEAD params->a_scaling_type, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) params->b, params->b_scale_ptr, params->ldb, params->b_dtype, params->b_scale_dtype, +<<<<<<< HEAD params->b_scaling_type, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) params->bias_ptr, params->bias_dtype, params->c, params->c_scale_ptr, params->ldc, params->c_dtype, +<<<<<<< HEAD params->use_fast_accum); +======= + params->use_fast_accum, + params->use_rowwise); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return OK; } }; diff --git a/aten/src/ATen/detail/CUDAHooksInterface.h b/aten/src/ATen/detail/CUDAHooksInterface.h index 00573e3cf701b..5bb9ecc85197b 100644 --- a/aten/src/ATen/detail/CUDAHooksInterface.h +++ b/aten/src/ATen/detail/CUDAHooksInterface.h @@ -118,6 +118,7 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface { return false; } +<<<<<<< HEAD virtual bool hasCKSDPA() const { return false; } @@ -126,6 +127,8 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface { return false; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) virtual const at::cuda::NVRTC& nvrtc() const { TORCH_CHECK(false, "NVRTC requires CUDA. ", CUDA_HELP); } diff --git a/aten/src/ATen/detail/MTIAHooksInterface.cpp b/aten/src/ATen/detail/MTIAHooksInterface.cpp index d2e331abb0c04..f4bd44b96649a 100644 --- a/aten/src/ATen/detail/MTIAHooksInterface.cpp +++ b/aten/src/ATen/detail/MTIAHooksInterface.cpp @@ -21,10 +21,13 @@ bool isMTIAHooksBuilt() { } // namespace detail +<<<<<<< HEAD bool MTIAHooksInterface::isAvailable() const { return detail::isMTIAHooksBuilt() && detail::getMTIAHooks().deviceCount() > 0; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) C10_DEFINE_REGISTRY(MTIAHooksRegistry, MTIAHooksInterface, MTIAHooksArgs) } // namespace at diff --git a/aten/src/ATen/detail/MTIAHooksInterface.h b/aten/src/ATen/detail/MTIAHooksInterface.h index b415862f29e7c..5e140a93a1190 100644 --- a/aten/src/ATen/detail/MTIAHooksInterface.h +++ b/aten/src/ATen/detail/MTIAHooksInterface.h @@ -149,8 +149,11 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface { FAIL_MTIAHOOKS_FUNC(__func__); return; } +<<<<<<< HEAD virtual bool isAvailable() const override; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; struct TORCH_API MTIAHooksArgs {}; diff --git a/aten/src/ATen/dlpack.h b/aten/src/ATen/dlpack.h index 82c0668211188..60fe71cc64380 100644 --- a/aten/src/ATen/dlpack.h +++ b/aten/src/ATen/dlpack.h @@ -15,11 +15,19 @@ #define DLPACK_EXTERN_C #endif +<<<<<<< HEAD /*! \brief The current major version of dlpack */ #define DLPACK_MAJOR_VERSION 1 /*! \brief The current minor version of dlpack */ #define DLPACK_MINOR_VERSION 0 +======= +/*! \brief The current version of dlpack */ +#define DLPACK_VERSION 80 + +/*! \brief The current ABI version of dlpack */ +#define DLPACK_ABI_VERSION 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) /*! \brief DLPACK_DLL prefix for windows */ #ifdef _WIN32 @@ -40,6 +48,7 @@ #ifdef __cplusplus extern "C" { #endif +<<<<<<< HEAD /*! * \brief The DLPack version. @@ -67,6 +76,8 @@ typedef struct { uint32_t minor; } DLPackVersion; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) /*! * \brief The device type in DLDevice. */ @@ -118,7 +129,11 @@ typedef enum { kDLWebGPU = 15, /*! \brief Qualcomm Hexagon DSP */ kDLHexagon = 16, +<<<<<<< HEAD /*! \brief Microsoft MAIA devices */ +======= + /*! \brief Microsoft AI Accelerator */ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kDLMAIA = 17, } DLDeviceType; @@ -199,7 +214,11 @@ typedef struct { * `byte_offset` field should be used to point to the beginning of the data. * * Note that as of Nov 2021, multiply libraries (CuPy, PyTorch, TensorFlow, +<<<<<<< HEAD * TVM, perhaps others) do not adhere to this 256 byte alignment requirement +======= + * TVM, perhaps others) do not adhere to this 256 byte aligment requirement +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) * on CPU/CUDA/ROCm, and always use `byte_offset=0`. This must be fixed * (after which this note will be updated); at the moment it is recommended * to not rely on the data pointer being correctly aligned. @@ -217,9 +236,12 @@ typedef struct { * return size; * } * \endcode +<<<<<<< HEAD * * Note that if the tensor is of size zero, then the data pointer should be * set to `NULL`. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) */ void* data; /*! \brief The device of the tensor */ @@ -245,6 +267,7 @@ typedef struct { * not meant to transfer the tensor. When the borrowing framework doesn't need * the tensor, it should call the deleter to notify the host that the resource * is no longer needed. +<<<<<<< HEAD * * \note This data structure is used as Legacy DLManagedTensor * in DLPack exchange and is deprecated after DLPack v0.8 @@ -252,6 +275,8 @@ typedef struct { * This data structure may get renamed or deleted in future versions. * * \sa DLManagedTensorVersioned +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) */ typedef struct DLManagedTensor { /*! \brief DLTensor which is being memory managed */ @@ -260,6 +285,7 @@ typedef struct DLManagedTensor { * which DLManagedTensor is used in the framework. It can also be NULL. */ void * manager_ctx; +<<<<<<< HEAD /*! * \brief Destructor - this should be called * to destruct the manager_ctx which backs the DLManagedTensor. It can be @@ -328,6 +354,15 @@ struct DLManagedTensorVersioned { DLTensor dl_tensor; }; +======= + /*! \brief Destructor signature void (*)(void*) - this should be called + * to destruct manager_ctx which holds the DLManagedTensor. It can be NULL + * if there is no way for the caller to provide a reasonable destructor. + * The destructors deletes the argument self as well. + */ + void (*deleter)(struct DLManagedTensor * self); +} DLManagedTensor; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #ifdef __cplusplus } // DLPACK_EXTERN_C #endif diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp index d58d436c511d1..5cb6684c332f2 100644 --- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp +++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp @@ -158,7 +158,10 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) { OP_DECOMPOSE(kron); OP_DECOMPOSE(l1_loss); m.impl("layer_norm", native::layer_norm_symint); +<<<<<<< HEAD m.impl("_fused_rms_norm", native::rms_norm_composite); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) OP_DECOMPOSE2(ldexp, Tensor); OP_DECOMPOSE2(less_equal, Tensor ); OP_DECOMPOSE2(less, Tensor ); diff --git a/aten/src/ATen/functorch/BatchRulesModules.cpp b/aten/src/ATen/functorch/BatchRulesModules.cpp index 6e63708a90f4a..40125d4a7e076 100644 --- a/aten/src/ATen/functorch/BatchRulesModules.cpp +++ b/aten/src/ATen/functorch/BatchRulesModules.cpp @@ -7,7 +7,10 @@ #include #include #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include @@ -45,6 +48,7 @@ static std::tuple> embedding_batch_rule( const auto weight_ = reshape_dim_into(*weight_bdim, 0, weight); auto indices_ = moveBatchDimToFront(indices, indices_bdim); +<<<<<<< HEAD { // getStepTensor returns a regular Tensor. If indices_ is a DTensor // we want to allow this mixed DTensor-Tensor operation. @@ -52,6 +56,10 @@ static std::tuple> embedding_batch_rule( const auto range = getStepTensor(indices, batch_size, num_embeddings); indices_ = indices_ + range; } +======= + const auto range = getStepTensor(indices, batch_size, num_embeddings); + indices_ = indices_ + range; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto result = at::embedding_symint(weight_, indices_, std::move(padding_idx), scale_grad_by_freq, sparse); return std::make_tuple(std::move(result), 0); } diff --git a/aten/src/ATen/functorch/BatchRulesUnaryOps.cpp b/aten/src/ATen/functorch/BatchRulesUnaryOps.cpp index 48a735c3e5332..88c1fc755aae4 100644 --- a/aten/src/ATen/functorch/BatchRulesUnaryOps.cpp +++ b/aten/src/ATen/functorch/BatchRulesUnaryOps.cpp @@ -171,8 +171,11 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { POINTWISE_BOXED(fill_.Scalar); POINTWISE_BOXED(zero_); +<<<<<<< HEAD // This is special because this op doesn't return anything m.impl("_assert_tensor_metadata", native::_assert_tensor_metadata); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #undef UNARY_POINTWISE #undef UNARY_POINTWISE_ALL diff --git a/aten/src/ATen/functorch/BatchedTensorImpl.cpp b/aten/src/ATen/functorch/BatchedTensorImpl.cpp index ee222b4e61a52..c3bc9236d7efd 100644 --- a/aten/src/ATen/functorch/BatchedTensorImpl.cpp +++ b/aten/src/ATen/functorch/BatchedTensorImpl.cpp @@ -126,7 +126,11 @@ SymIntArrayRef BatchedTensorImpl::sym_strides_custom() const { // TODO: implement proper contiguity on batched tensor, then put // sizes_strides_policy back to Default +<<<<<<< HEAD c10::SymBool BatchedTensorImpl::sym_is_contiguous_custom(at::MemoryFormat memory_format) const { +======= +bool BatchedTensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_CHECK(memory_format == MemoryFormat::Contiguous, "NYI: querying is_contiguous inside of vmap for memory_format ", "other than torch.contiguous_format"); diff --git a/aten/src/ATen/functorch/BatchedTensorImpl.h b/aten/src/ATen/functorch/BatchedTensorImpl.h index 3eccc94d3ea60..3d9b0989d181a 100644 --- a/aten/src/ATen/functorch/BatchedTensorImpl.h +++ b/aten/src/ATen/functorch/BatchedTensorImpl.h @@ -69,7 +69,11 @@ struct TORCH_API BatchedTensorImpl : public c10::TensorImpl { IntArrayRef strides_custom() const override; SymIntArrayRef sym_strides_custom() const override; // Override a bunch of methods inherited from TensorImpl to return error messages. +<<<<<<< HEAD c10::SymBool sym_is_contiguous_custom(at::MemoryFormat memory_format) const override; +======= + bool is_contiguous_custom(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const override; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void set_size(int64_t dim, int64_t new_size) override; void set_stride(int64_t dim, int64_t new_stride) override; c10::intrusive_ptr shallow_copy_and_detach( diff --git a/aten/src/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h b/aten/src/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h index f4316def4fb42..70277756869da 100644 --- a/aten/src/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h +++ b/aten/src/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h @@ -37,6 +37,7 @@ class HIPAllocatorMasqueradingAsCUDA final : public HIPCachingAllocator::HIPAllo allocator_->copy_data(dest, src, count); } +<<<<<<< HEAD // From DeviceAllocator bool initialized() override { @@ -64,6 +65,8 @@ class HIPAllocatorMasqueradingAsCUDA final : public HIPCachingAllocator::HIPAllo allocator_->resetPeakStats(device); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // From CUDAAllocator void* raw_alloc(size_t nbytes) override { @@ -82,6 +85,13 @@ class HIPAllocatorMasqueradingAsCUDA final : public HIPCachingAllocator::HIPAllo allocator_->init(device_count); } +<<<<<<< HEAD +======= + bool initialized() override { + return allocator_->initialized(); + } + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) double getMemoryFraction(c10::DeviceIndex device) override { return allocator_->getMemoryFraction(device); } @@ -90,6 +100,13 @@ class HIPAllocatorMasqueradingAsCUDA final : public HIPCachingAllocator::HIPAllo allocator_->setMemoryFraction(fraction, device); } +<<<<<<< HEAD +======= + void emptyCache(MempoolId_t mempool_id = {0, 0}) override { + allocator_->emptyCache(mempool_id); + } + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void enable(bool value) override { allocator_->enable(value); } @@ -110,6 +127,21 @@ class HIPAllocatorMasqueradingAsCUDA final : public HIPCachingAllocator::HIPAllo allocator_->recordStream(ptr, stream); } +<<<<<<< HEAD +======= + CachingDeviceAllocator::DeviceStats getDeviceStats(c10::DeviceIndex device) override { + return allocator_->getDeviceStats(device); + } + + void resetAccumulatedStats(c10::DeviceIndex device) override { + allocator_->resetAccumulatedStats(device); + } + + void resetPeakStats(c10::DeviceIndex device) override { + allocator_->resetPeakStats(device); + } + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) HIPCachingAllocator::SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) override { return allocator_->snapshot(mempool_id); } diff --git a/aten/src/ATen/mps/EmptyTensor.cpp b/aten/src/ATen/mps/EmptyTensor.cpp index 6c58de099648d..e8c178b18f00d 100644 --- a/aten/src/ATen/mps/EmptyTensor.cpp +++ b/aten/src/ATen/mps/EmptyTensor.cpp @@ -12,7 +12,11 @@ #define MPS_ERROR_NOT_COMPILED "PyTorch code is not compiled with MPS enabled" #define MPS_ERROR_RUNTIME_TOO_LOW \ +<<<<<<< HEAD "The MPS backend is supported on MacOS 14.0+. ", \ +======= + "The MPS backend is supported on MacOS 13.0+.", \ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "Current OS version can be queried using `sw_vers`" #define MPS_ERROR_DOUBLE_NOT_SUPPORTED "Cannot convert a MPS Tensor to float64 dtype " \ "as the MPS framework doesn't support float64. Please use float32 instead." @@ -43,6 +47,10 @@ TensorBase empty_mps( int64_t nelements = c10::multiply_integers(size); auto dtype = dtype_or_default(dtype_opt); TORCH_CHECK_TYPE(dtype != ScalarType::Double, MPS_ERROR_DOUBLE_NOT_SUPPORTED); +<<<<<<< HEAD +======= + TORCH_CHECK_TYPE(dtype != ScalarType::BFloat16 || is_macos_13_or_newer(mps::MacOSVersion::MACOS_VER_14_0_PLUS), "MPS BFloat16 is only supported on MacOS 14 or newer"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto dtype_meta = scalarTypeToTypeMeta(dtype); diff --git a/aten/src/ATen/mps/MPSDevice.h b/aten/src/ATen/mps/MPSDevice.h index 9b58477104978..9075e733eed99 100644 --- a/aten/src/ATen/mps/MPSDevice.h +++ b/aten/src/ATen/mps/MPSDevice.h @@ -18,7 +18,15 @@ namespace at::mps { // Helper enum to check if a MPSGraph op is supported in a given macOS version enum class MacOSVersion : uint32_t { +<<<<<<< HEAD MACOS_VER_14_4_PLUS = 0, +======= + MACOS_VER_13_1_PLUS = 0, + MACOS_VER_13_2_PLUS, + MACOS_VER_13_3_PLUS, + MACOS_VER_14_0_PLUS, + MACOS_VER_14_4_PLUS, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MACOS_VER_15_0_PLUS, MACOS_VER_15_1_PLUS, MACOS_VER_15_2_PLUS, @@ -55,6 +63,7 @@ class TORCH_API MPSDevice { */ bool isMacOS13Plus(MacOSVersion version) const; +<<<<<<< HEAD /** * Returns device name */ @@ -66,6 +75,8 @@ class TORCH_API MPSDevice { */ unsigned getCoreCount() const; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ~MPSDevice(); private: diff --git a/aten/src/ATen/mps/MPSDevice.mm b/aten/src/ATen/mps/MPSDevice.mm index 5a37490c02402..39ddd7c1ceec0 100644 --- a/aten/src/ATen/mps/MPSDevice.mm +++ b/aten/src/ATen/mps/MPSDevice.mm @@ -32,11 +32,19 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id& de MPSDevice::MPSDevice() : _mtl_device(nil) { // Check that MacOS 13.0+ version of MPS framework is available +<<<<<<< HEAD // Create the MPSGraph and check method introduced in 14.0 // which is used by MPS backend. id mpsCD = NSClassFromString(@"MPSGraph"); if ([mpsCD instancesRespondToSelector:@selector(HermiteanToRealFFTWithTensor:axes:descriptor:name:)] == NO) { +======= + // Create the MPSGraph and check method introduced in 13.0 + // which is used by MPS backend. + id mpsCD = NSClassFromString(@"MPSGraph"); + + if ([mpsCD instancesRespondToSelector:@selector(cumulativeSumWithTensor:axis:name:)] == NO) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return; } @@ -66,12 +74,30 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id& de isOperatingSystemAtLeastVersion:{.majorVersion = major, .minorVersion = minor, .patchVersion = 0}]; } }; +<<<<<<< HEAD +======= + static bool _macos_13_1_plus = is_os_version_at_least(13, 1); + static bool _macos_13_2_plus = is_os_version_at_least(13, 2); + static bool _macos_13_3_plus = is_os_version_at_least(13, 3); + static bool _macos_14_0_plus = is_os_version_at_least(14, 0); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static bool _macos_14_4_plus = is_os_version_at_least(14, 4); static bool _macos_15_0_plus = is_os_version_at_least(15, 0); static bool _macos_15_1_plus = is_os_version_at_least(15, 1); static bool _macos_15_2_plus = is_os_version_at_least(15, 2); switch (version) { +<<<<<<< HEAD +======= + case MacOSVersion::MACOS_VER_13_1_PLUS: + return _macos_13_1_plus; + case MacOSVersion::MACOS_VER_13_2_PLUS: + return _macos_13_2_plus; + case MacOSVersion::MACOS_VER_13_3_PLUS: + return _macos_13_3_plus; + case MacOSVersion::MACOS_VER_14_0_PLUS: + return _macos_14_0_plus; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) case MacOSVersion::MACOS_VER_14_4_PLUS: return _macos_14_4_plus; case MacOSVersion::MACOS_VER_15_0_PLUS: @@ -85,6 +111,7 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id& de } } +<<<<<<< HEAD std::string MPSDevice::getName() const { @autoreleasepool { return [[_mtl_device name] UTF8String]; @@ -115,6 +142,12 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id& de at::Allocator* GetMPSAllocator(bool useSharedAllocator) { return getIMPSAllocator(useSharedAllocator); } +======= +at::Allocator* GetMPSAllocator(bool useSharedAllocator) { + return getIMPSAllocator(useSharedAllocator); +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool is_available() { return MPSDevice::getInstance()->device() != nil; } diff --git a/aten/src/ATen/mps/MPSHooks.mm b/aten/src/ATen/mps/MPSHooks.mm index 34fbd31af91da..812a09750c225 100644 --- a/aten/src/ATen/mps/MPSHooks.mm +++ b/aten/src/ATen/mps/MPSHooks.mm @@ -34,7 +34,11 @@ case 14: switch (minor) { case 0: +<<<<<<< HEAD return true; +======= + return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) case 4: return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_4_PLUS); default: @@ -42,7 +46,23 @@ return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_4_PLUS); } case 13: +<<<<<<< HEAD return true; +======= + switch (minor) { + case 0: + return true; + case 1: + return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_1_PLUS); + case 2: + return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS); + case 3: + return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); + default: + TORCH_WARN("Can't check whether running on 13.", minor, "+ returning one for 13.3+"); + return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) default: TORCH_WARN("Checking for unexpected MacOS ", major, ".", minor, " returning false"); return false; @@ -70,10 +90,14 @@ } void* MPSHooks::getCommandBuffer() const { +<<<<<<< HEAD auto stream = at::mps::getDefaultMPSStream(); // Release pending computeCommandEncoder, as extensions is likely to allocate new one stream->endKernelCoalescing(); return stream->commandBuffer(); +======= + return at::mps::getDefaultMPSStream()->commandBuffer(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } void* MPSHooks::getDispatchQueue() const { diff --git a/aten/src/ATen/mps/MPSStream.mm b/aten/src/ATen/mps/MPSStream.mm index 71325bd69e1df..6d4cec8929611 100644 --- a/aten/src/ATen/mps/MPSStream.mm +++ b/aten/src/ATen/mps/MPSStream.mm @@ -158,6 +158,7 @@ @interface MPSGraphExecutionDescriptor () endKernelCoalescing(); id blitEncoder = [commandBuffer() blitCommandEncoder]; +<<<<<<< HEAD // For some reason fillBufferfor stopped working for lengh > 4Gb on MacOS 26 // See https://github.com/pytorch/pytorch/issues/163962 // Workaround by batching copy commands into 4Gb chunks @@ -170,6 +171,9 @@ @interface MPSGraphExecutionDescriptor () bytes_filled += bytes_to_copy; bytes_remains -= bytes_to_copy; } +======= + [blitEncoder fillBuffer:buffer range:NSMakeRange(offset, length) value:value]; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) [blitEncoder endEncoding]; synchronize(syncType); } diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index d323e54a95abe..ef47b26ccf4d8 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -2453,7 +2453,11 @@ TORCH_IMPL_FUNC(linalg_qr_out)(const Tensor& A, // geqrf requires m x n workspace input that is modified in-place // We try to use Q. If it doesn't fit, we try to use R +<<<<<<< HEAD // If m > n and compute_q==false, it won't fit into Q or R, so we need to create an auxiliary tensor +======= + // If m > n and compute_q==false, it won't fit into Q or R, so we neet to create an auxiliary tensor +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Tensor QR; if (compute_q && Q.size(-1) == n) { QR = Q; @@ -4095,7 +4099,11 @@ Tensor linalg_vander_symint( const auto n = N.value_or(shape.back()); TORCH_CHECK(n > 1, "N must be greater than 1."); +<<<<<<< HEAD // Append cumprod of the other 0...n-1 powers +======= + // Append cumprod of the oher 0...n-1 powers +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shape.push_back(n - 1); auto result = at::cumprod(x_.unsqueeze(-1).expand_symint(shape), -1); // The row of ones diff --git a/aten/src/ATen/native/Blas.cpp b/aten/src/ATen/native/Blas.cpp index 49366151ae60b..29058b477d5f7 100644 --- a/aten/src/ATen/native/Blas.cpp +++ b/aten/src/ATen/native/Blas.cpp @@ -9,7 +9,10 @@ #include #include #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #if !defined(__s390x__) && !defined(__powerpc__) #include #endif @@ -333,6 +336,7 @@ _scaled_mm_cpu(const Tensor& mat_a, const Tensor& mat_b, return _scaled_mm_out_cpu(mat_a, mat_b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out); } +<<<<<<< HEAD // TODO(vasiliy, future PR): figure out why we need to declare this function, when // other functions that live in ATen/native/*.cpp without declarations // or headers work just fine. @@ -352,4 +356,6 @@ std::optional out_dtype) { return out; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace at::native diff --git a/aten/src/ATen/native/CPUBlas.cpp b/aten/src/ATen/native/CPUBlas.cpp index e06afddd05aa7..8b5c5724cd20b 100644 --- a/aten/src/ATen/native/CPUBlas.cpp +++ b/aten/src/ATen/native/CPUBlas.cpp @@ -51,7 +51,11 @@ extern "C" void zaxpy_(int *n, void *a, const void *x, int *incx, void *y, int * // brgemm_pack_B is changed to transform and the setting of brgemm beta is changed to set_add_C #if (IDEEP_VERSION_MAJOR == 3 && IDEEP_VERSION_MINOR == 5) #define ONEDNN_UKERNEL_1 +<<<<<<< HEAD #elif ((IDEEP_VERSION_MAJOR == 3 && IDEEP_VERSION_MINOR >= 6) || (IDEEP_VERSION_MAJOR > 3)) +======= +#elif (IDEEP_VERSION_MAJOR >= 3 && IDEEP_VERSION_MINOR >= 6) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #define ONEDNN_UKERNEL_2 #endif #if ((defined(ONEDNN_UKERNEL_1) || defined(ONEDNN_UKERNEL_2)) && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))) @@ -202,7 +206,11 @@ void gemm( float *c, int64_t ldc) { internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc); #if AT_MKLDNN_ENABLED() +<<<<<<< HEAD if (mkldnn_reduced_f32_gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)) { +======= + if (mkldnn_bf32_gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return; } #endif @@ -358,6 +366,7 @@ void gemm( int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc; char transa_ = to_blas(transa), transb_ = to_blas(transb); float alpha_ = alpha, beta_ = beta; +<<<<<<< HEAD int c_size = n_ * m_; // C matrix in OpenBLAS sbgemm are of type "float" so we have to convert, copy and copy back. std::vector float_v(c_size, 0.0f); @@ -366,17 +375,28 @@ void gemm( float_v[j * m_ + i] = c10::convert(c[j * ldc_ + i]); } } +======= + int c_size = n_ * ldc_; + // C matrix in OpenBLAS sbgemm are of type "float" so we have to convert, copy and copy back. + std::vector float_v(c, c + c_size); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sbgemm_(&transa_, &transb_, &m_, &n_, &k_, &alpha_, a, &lda_, b, &ldb_, &beta_, +<<<<<<< HEAD float_v.data(), &m_); for (const auto j : c10::irange(n)) { for (const auto i : c10::irange(m)) { c[j * ldc_ + i] = c10::convert(float_v[j * m_ + i]); } +======= + float_v.data(), &ldc_); + for (auto cv: float_v) { + *(c++) = c10::convert(cv); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } return; } @@ -496,18 +516,31 @@ void gemm( // for the fallback path, first compute gemm with beta = 0, // and then add c in full precision. int64_t c_size = n * m; +<<<<<<< HEAD std::vector float_c(c_size, 0.f); gemm_no_downcast_stub( at::kCPU, at::kHalf, transa, transb, m, n, k, alpha, a, lda, b, ldb, 0.f, float_c.data(), m); +======= + std::vector float16_c(c_size, 0.f); + gemm_stub( + at::kCPU, at::kHalf, + transa, transb, m, n, k, alpha, a, lda, b, ldb, 0.f, float16_c.data(), m); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (const auto j : c10::irange(n)) { for (const auto i : c10::irange(m)) { auto offset = j * ldc + i; // beta == 0 won't propagate NaN from C if (beta == 0.f) { +<<<<<<< HEAD c[offset] = float_c[j * m + i]; } else { c[offset] = beta * c[offset] + float_c[j * m + i]; +======= + c[offset] = c10::convert(float16_c[j * m + i]); + } else { + c[offset] = beta * c[offset] + c10::convert(float16_c[j * m + i]); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } } diff --git a/aten/src/ATen/native/CPUBlas.h b/aten/src/ATen/native/CPUBlas.h index 8b75f12ebaf21..8512af333fb8e 100644 --- a/aten/src/ATen/native/CPUBlas.h +++ b/aten/src/ATen/native/CPUBlas.h @@ -206,6 +206,7 @@ void copy(int64_t n, const c10::complex *x, int64_t incx, c10::complex *x, int64_t incx, c10::complex int32 #define CPUBLAS_BRGEMM_I8I8I32 // signed char * signed char -> int32 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_API void brgemm( int64_t M, int64_t N, diff --git a/aten/src/ATen/native/ComparisonUtils.cpp b/aten/src/ATen/native/ComparisonUtils.cpp index 13bef0a00b9c9..3c7134cd18617 100644 --- a/aten/src/ATen/native/ComparisonUtils.cpp +++ b/aten/src/ATen/native/ComparisonUtils.cpp @@ -24,6 +24,7 @@ static void _assert_match(const O& original, const C& compared, const std::strin } } +<<<<<<< HEAD template<> void _assert_match>( const c10::Device& original, @@ -47,6 +48,8 @@ void _assert_match>( } } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void _assert_tensor_metadata_meta_symint(at::Tensor const& tensor, at::OptionalSymIntArrayRef sizes, at::OptionalSymIntArrayRef strides, std::optional dtype, std::optional device, std::optional layout) { _assert_match(tensor.sym_sizes(), sizes, "sizes"); _assert_match(tensor.sym_strides(), strides, "strides"); diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index 2e0e4a47f37be..eb826bdda906b 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -3,7 +3,10 @@ #include #include #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include @@ -14,7 +17,10 @@ #include #include #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include @@ -301,6 +307,7 @@ struct ConvParams { bool allow_tf32{}; bool is_strided() const { +<<<<<<< HEAD return std::any_of( stride.cbegin(), stride.cend(), [](const T& s) { return s != 1; }); } @@ -345,6 +352,69 @@ struct ConvParams { bool is_stride_nonpos() const { return std::any_of( stride.cbegin(), stride.cend(), [](const T& s) { return s <= 0; }); +======= + bool is_strided = false; + for (const auto& s : stride) { + is_strided |= (s != 1); + } + return is_strided; + } + + bool is_dilated() const { + bool is_dilated = false; + for (const auto& d : dilation) { + is_dilated |= (d != 1); + } + return is_dilated; + } + + bool is_padded() const { + bool is_padded = false; + for (auto p : padding) { + is_padded |= (p != 0); + } + return is_padded; + } + + bool is_output_padding_neg() const { + bool is_non_neg = false; + for (const auto& p : output_padding) { + is_non_neg |= (p < 0); + } + return is_non_neg; + } + + bool is_output_padding_big() const { + bool is_big = false; + for (auto i: c10::irange(output_padding.size())) { + is_big |= (output_padding[i] >= stride[i]); + } + return is_big; + } + + bool is_padding_neg() const { + bool is_non_neg = false; + for (const auto& p : padding) { + is_non_neg |= (p < 0); + } + return is_non_neg; + } + + bool is_dilation_neg() const { + bool is_non_neg = false; + for (const auto& p : dilation) { + is_non_neg |= (p < 0); + } + return is_non_neg; + } + + bool is_stride_nonpos() const { + bool is_nonpos = false; + for (const auto& s : stride) { + is_nonpos |= (s <= 0); + } + return is_nonpos; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } void view1d_as_2d() { @@ -410,6 +480,7 @@ struct ConvParams { // cudnn and miopen are guaranteed not to be on mobile, and T102591915 / T110194934 suggest // that maybe the compiledWithCuDNN() check sometimes segfaults (though I can't imagine how) #if !defined(C10_MOBILE) +<<<<<<< HEAD if (!detail::getCUDAHooks().compiledWithCuDNN() || !input.is_cuda() || !cudnn_enabled) { return false; } @@ -427,6 +498,13 @@ struct ConvParams { } } if (needs_64bit_indexing_no_split(input, weight)) { +======= + if (!detail::getCUDAHooks().compiledWithCuDNN()) { + return false; + } + if (needs_64bit_indexing_no_split(input, weight)) { + static long cudnn_version = detail::getCUDAHooks().versionCuDNN(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (!(cudnn_version >= 90300 && at::native::cudnnv8_enabled_check_debug())) { TORCH_WARN_ONCE("cuDNN cannot be used for large non-batch-splittable convolutions" " if the V8 API is not enabled or before cuDNN version 9.3+." @@ -434,6 +512,12 @@ struct ConvParams { return false; } } +<<<<<<< HEAD +======= + if (!input.is_cuda() || !cudnn_enabled) { + return false; + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (input.scalar_type() == at::kBFloat16 || weight.scalar_type() == at::kBFloat16) { if (!(detail::getCUDAHooks().supportsBFloat16ConvolutionWithCuDNNv8() && at::native::cudnnv8_enabled_check_debug())) { return false; @@ -452,6 +536,7 @@ struct ConvParams { // Use cudnn for FP16 depthwise convolutions bool use_cudnn_depthwise(const at::Tensor& input, const at::Tensor& weight) const { +<<<<<<< HEAD if (!cudnn_enabled || !detail::getCUDAHooks().compiledWithCuDNN() || !input.is_cuda()) { return false; } @@ -465,6 +550,18 @@ struct ConvParams { } } +======= + if (!detail::getCUDAHooks().compiledWithCuDNN()) { + return false; + } + if (cudnn_conv_suggest_memory_format(input, weight) != at::MemoryFormat::Contiguous && use_cudnn(input, weight)) { + // always use cudnn_depthwise for channels_last format + return true; + } + // native kernel doesn't support 64-bit non-splittable case + if (cudnn_enabled && needs_64bit_indexing_no_split(input, weight)) { + static long cudnn_version = detail::getCUDAHooks().compiledWithCuDNN() ? detail::getCUDAHooks().versionCuDNN() : -1; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (!(cudnn_version >= 90300 && at::native::cudnnv8_enabled_check_debug())) { TORCH_WARN_ONCE("cuDNN cannot be used for large non-batch-splittable convolutions" " if the V8 API is not enabled or before cuDNN version 9.3+." @@ -474,10 +571,13 @@ struct ConvParams { return true; } } +<<<<<<< HEAD if (cudnn_conv_suggest_memory_format(input, weight) != at::MemoryFormat::Contiguous) { // always use cudnn_depthwise for channels_last format return true; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (detail::getCUDAHooks().supportsDepthwiseConvolutionWithCuDNN()) { bool kernel_cond = (use_cudnn(input, weight) && input.scalar_type() == kHalf && // only for FP16 @@ -1178,7 +1278,11 @@ at::Tensor convolution( bool deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms(); return at::_convolution(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, +<<<<<<< HEAD ctx.benchmarkCuDNN(), deterministic, ctx.userEnabledCuDNN(), ctx.allowTF32CuDNN("conv")); +======= + ctx.benchmarkCuDNN(), deterministic, ctx.userEnabledCuDNN(), ctx.allowTF32CuDNN()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } at::Tensor convolution_overrideable( @@ -1323,7 +1427,11 @@ ConvBackend select_conv_backend( params.benchmark = ctx.benchmarkCuDNN(); params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms(); params.cudnn_enabled = ctx.userEnabledCuDNN(); +<<<<<<< HEAD params.allow_tf32 = ctx.allowTF32CuDNN("conv"); +======= + params.allow_tf32 = ctx.allowTF32CuDNN(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto input = input_r; auto weight = weight_r; @@ -1707,7 +1815,11 @@ at::Tensor _convolution( c10::MaybeOwned bias_r_maybe_owned = at::borrow_from_optional_tensor(bias_r_opt); const Tensor& bias_r = *bias_r_maybe_owned; +<<<<<<< HEAD return at::_convolution(input_r, weight_r, bias_r, stride_, padding_, dilation_, transposed_, output_padding_, groups_, benchmark, deterministic, cudnn_enabled, at::globalContext().allowTF32CuDNN("conv")); +======= + return at::_convolution(input_r, weight_r, bias_r, stride_, padding_, dilation_, transposed_, output_padding_, groups_, benchmark, deterministic, cudnn_enabled, at::globalContext().allowTF32CuDNN()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } std::tuple convolution_backward_overrideable( @@ -2005,7 +2117,11 @@ std::tuple convolution_backward( params.benchmark = ctx.benchmarkCuDNN(); params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms(); params.cudnn_enabled = ctx.userEnabledCuDNN(); +<<<<<<< HEAD params.allow_tf32 = ctx.allowTF32CuDNN("conv"); +======= + params.allow_tf32 = ctx.allowTF32CuDNN(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Validate inputs. check_shape_backward(input, weight.sizes(), params); diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp index 3d388194ea49d..c6e79961d76e7 100644 --- a/aten/src/ATen/native/Copy.cpp +++ b/aten/src/ATen/native/Copy.cpp @@ -36,10 +36,15 @@ #endif #ifdef USE_FBGEMM +<<<<<<< HEAD C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include #include C10_DIAGNOSTIC_POP() +======= +#include +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif namespace { diff --git a/aten/src/ATen/native/DilatedMaxPool2d.cpp b/aten/src/ATen/native/DilatedMaxPool2d.cpp index 641e9f14dd711..ad084a913ef86 100644 --- a/aten/src/ATen/native/DilatedMaxPool2d.cpp +++ b/aten/src/ATen/native/DilatedMaxPool2d.cpp @@ -54,7 +54,11 @@ bool ceil_mode) { TORCH_CHECK((input.ndimension() == 3 || input.ndimension() == 4), "non-empty 3D or 4D (batch mode) tensor expected for input"); } else { +<<<<<<< HEAD TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous"); +======= + TORCH_CHECK(false, "Unsupport memory format. Supports only ChannelsLast, Contiguous"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } /* sizes */ @@ -130,7 +134,11 @@ const Tensor& indices) { TORCH_CHECK((input.ndimension() == 3 || input.ndimension() == 4), "non-empty 3D or 4D (batch mode) tensor expected for input"); } else { +<<<<<<< HEAD TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous"); +======= + TORCH_CHECK(false, "Unsupport memory format. Supports only ChannelsLast, Contiguous"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } /* sizes */ diff --git a/aten/src/ATen/native/DilatedMaxPool3d.cpp b/aten/src/ATen/native/DilatedMaxPool3d.cpp index 23d77cb210720..afa493eda70ad 100644 --- a/aten/src/ATen/native/DilatedMaxPool3d.cpp +++ b/aten/src/ATen/native/DilatedMaxPool3d.cpp @@ -63,7 +63,11 @@ void max_pool3d_with_indices_out_cpu_template( TORCH_CHECK((input.ndimension() == 4 || input.ndimension() == 5), "non-empty 4D or 5D (batch mode) tensor expected for input"); } else { +<<<<<<< HEAD TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast3d, Contiguous"); +======= + TORCH_CHECK(false, "Unsupport memory format. Supports only ChannelsLast3d, Contiguous"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } const int64_t nslices = input.size(-4); @@ -158,7 +162,11 @@ Tensor& max_pool3d_with_indices_backward_out_cpu_template( TORCH_CHECK((input.ndimension() == 4 || input.ndimension() == 5), "non-empty 4D or 5D (batch mode) tensor expected for input"); } else { +<<<<<<< HEAD TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast3d, Contiguous"); +======= + TORCH_CHECK(false, "Unsupport memory format. Supports only ChannelsLast3d, Contiguous"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } const int64_t nslices = input.size(-4); diff --git a/aten/src/ATen/native/DistributionTemplates.h b/aten/src/ATen/native/DistributionTemplates.h index 21a15b80c9c84..b5675ecb79d1a 100644 --- a/aten/src/ATen/native/DistributionTemplates.h +++ b/aten/src/ATen/native/DistributionTemplates.h @@ -28,13 +28,21 @@ namespace at::native::templates { // ==================================================== Random ======================================================== // The purpose of `update_from` and `update_to` is to find the closest valid int64_t number that can be used as actual `from`. +<<<<<<< HEAD // The current implementation of `random_` uses uint64_t arithmetic and casts the result to the target dtype(scalar_t). +======= +// The current implementation of `random_` uses uint64_t arithmetics and casts the result to the target dtype(scalar_t). +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // This casting can result in generating numbers that happen to be greater or equal to `to` value. For instance: // // auto actual = torch::empty({3, 3}, torch::half); // actual.random_(0, 65504); // +<<<<<<< HEAD // If random's uint64_t arithmetic produces 65503 as a random value after casting to torch::half it becomes 65504 +======= +// If random's uint64_t arithmetics produces 65503 as a random value after casting to torch::half it becomes 65504 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // and violates the requirement that random value must be less than `to`. To resolve this issue `update_from` and `update_to` // moves `from` to the right and `to` to the left to the next closest value that won't go outside [from, to) after casting to // the target dtype. For `to` = 65504 it moves left for (1 << (log2(to) - 11 + 1)) = 32 and becomes 65472, which is previous diff --git a/aten/src/ATen/native/Distributions.cpp b/aten/src/ATen/native/Distributions.cpp index 5f34ed9d24c17..03f8acb547cd1 100644 --- a/aten/src/ATen/native/Distributions.cpp +++ b/aten/src/ATen/native/Distributions.cpp @@ -424,6 +424,7 @@ Tensor _dirichlet_grad_cpu(const Tensor& x, const Tensor& alpha, const Tensor& t */ Tensor _s_binomial_cpu(const Tensor& count, const Tensor& prob, std::optional gen) { +<<<<<<< HEAD TORCH_CHECK_VALUE( at::isFloatingType(count.scalar_type()), "binomial only supports floating-point dtypes for count, got: ", @@ -432,6 +433,8 @@ Tensor _s_binomial_cpu(const Tensor& count, const Tensor& prob, std::optional>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Tensor ret = at::zeros(count.sizes(), count.options()); auto iter = TensorIteratorConfig() .add_output(ret) diff --git a/aten/src/ATen/native/EmbeddingBag.cpp b/aten/src/ATen/native/EmbeddingBag.cpp index 150970edc5076..de6404e576731 100644 --- a/aten/src/ATen/native/EmbeddingBag.cpp +++ b/aten/src/ATen/native/EmbeddingBag.cpp @@ -14,10 +14,15 @@ #include #ifdef USE_FBGEMM +<<<<<<< HEAD C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include #include C10_DIAGNOSTIC_POP() +======= +#include +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #else #include #endif diff --git a/aten/src/ATen/native/ForeachOpsKernels.cpp b/aten/src/ATen/native/ForeachOpsKernels.cpp index cb437fb45ce21..4065c93f87bae 100644 --- a/aten/src/ATen/native/ForeachOpsKernels.cpp +++ b/aten/src/ATen/native/ForeachOpsKernels.cpp @@ -260,7 +260,10 @@ namespace at::native { check_foreach_api_restrictions(input, tensors1, tensors2); \ \ std::vector result; \ +<<<<<<< HEAD result.reserve(input.size()); \ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (const auto i : c10::irange(input.size())) { \ result.emplace_back(input[i].OP(tensors1[i], tensors2[i], scalar)); \ } \ @@ -289,7 +292,10 @@ namespace at::native { check_foreach_api_restrictions(input, tensors1, tensors2, scalars); \ \ std::vector result; \ +<<<<<<< HEAD result.reserve(input.size()); \ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (const auto i : c10::irange(input.size())) { \ result.emplace_back(input[i].OP(tensors1[i], tensors2[i], scalars[i])); \ } \ @@ -419,7 +425,10 @@ std::vector foreach_tensor_ternary_lerp_slow( TensorList tensors3) { check_foreach_api_restrictions(tensors1, tensors2, tensors3); std::vector result; +<<<<<<< HEAD result.reserve(tensors1.size()); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (const auto i : c10::irange(tensors1.size())) { result.emplace_back(tensors1[i].lerp(tensors2[i], tensors3[i])); } @@ -442,7 +451,10 @@ std::vector foreach_tensor_lerp_scalarlist_kernel_slow( at::ArrayRef scalars) { check_foreach_api_restrictions(tensors1, tensors2, scalars); std::vector result; +<<<<<<< HEAD result.reserve(tensors1.size()); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (const auto i : c10::irange(tensors1.size())) { result.emplace_back(tensors1[i].lerp(tensors2[i], scalars[i])); } @@ -473,7 +485,10 @@ std::vector foreach_tensor_norm_slow( std::optional dtype) { check_foreach_api_restrictions(tensors); std::vector result; +<<<<<<< HEAD result.reserve(tensors.size()); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (const auto& t : tensors) { result.emplace_back(at::linalg_vector_norm(t, ord, {}, false, dtype)); } @@ -483,7 +498,10 @@ std::vector foreach_tensor_norm_slow( std::vector foreach_tensor_max_slow(TensorList tensors) { check_foreach_api_restrictions(tensors); std::vector result; +<<<<<<< HEAD result.reserve(tensors.size()); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (const auto& t : tensors) { result.emplace_back(at::max(t)); } diff --git a/aten/src/ATen/native/ForeachUtils.h b/aten/src/ATen/native/ForeachUtils.h index f0dce20a6eff4..a2bfe650a1219 100644 --- a/aten/src/ATen/native/ForeachUtils.h +++ b/aten/src/ATen/native/ForeachUtils.h @@ -22,7 +22,11 @@ namespace { // Check if tensor list has either a boolean tensor or a integer tensor inline bool has_integral_tensor(TensorList tensors, const bool includeBool) { return std::any_of( +<<<<<<< HEAD tensors.begin(), tensors.end(), [includeBool](const auto& t) { +======= + tensors.begin(), tensors.end(), [&includeBool](const auto& t) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return at::isIntegralType(t.scalar_type(), includeBool); }); } @@ -53,8 +57,13 @@ inline void check_foreach_api_restrictions( inline void check_foreach_api_restrictions( TensorList tensors1, TensorList tensors2) { +<<<<<<< HEAD check_foreach_api_restrictions(tensors1); check_foreach_api_restrictions(tensors2); +======= + TORCH_CHECK(!tensors1.empty(), "Tensor list must have at least one tensor."); + TORCH_CHECK(!tensors2.empty(), "Tensor list must have at least one tensor."); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_CHECK( tensors1.size() == tensors2.size(), "Tensor lists must have the same number of tensors, got ", @@ -67,8 +76,26 @@ inline void check_foreach_api_restrictions( TensorList tensors1, TensorList tensors2, TensorList tensors3) { +<<<<<<< HEAD check_foreach_api_restrictions(tensors1, tensors2); check_foreach_api_restrictions(tensors1, tensors3); +======= + TORCH_CHECK(!tensors1.empty(), "Tensor list must have at least one tensor."); + TORCH_CHECK(!tensors2.empty(), "Tensor list must have at least one tensor."); + TORCH_CHECK(!tensors3.empty(), "Tensor list must have at least one tensor."); + TORCH_CHECK( + tensors1.size() == tensors2.size(), + "Tensor lists must have the same number of tensors, got ", + tensors1.size(), + " and ", + tensors2.size()); + TORCH_CHECK( + tensors1.size() == tensors3.size(), + "Tensor lists must have the same number of tensors, got ", + tensors1.size(), + " and ", + tensors3.size()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } inline void check_foreach_api_restrictions( @@ -77,7 +104,16 @@ inline void check_foreach_api_restrictions( TensorList tensors3, ArrayRef scalars) { check_foreach_api_restrictions(tensors1, tensors2, tensors3); +<<<<<<< HEAD check_foreach_api_restrictions(tensors1, scalars); +======= + TORCH_CHECK( + tensors1.size() == scalars.size(), + "Tensor list must have same number of elements as scalar list, got ", + tensors1.size(), + " and ", + scalars.size()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } inline void check_foreach_api_restrictions( @@ -85,7 +121,16 @@ inline void check_foreach_api_restrictions( TensorList tensors2, ArrayRef scalars) { check_foreach_api_restrictions(tensors1, tensors2); +<<<<<<< HEAD check_foreach_api_restrictions(tensors1, scalars); +======= + TORCH_CHECK( + tensors1.size() == scalars.size(), + "Tensor list must have same number of elements as scalar list, got ", + tensors1.size(), + " and ", + scalars.size()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // Helper function called in check_fast_path_restrictions to check whether all @@ -103,6 +148,7 @@ inline bool _check_tensors_share_device_and_dtype( tensor.is_non_overlapping_and_dense(); }; +<<<<<<< HEAD return std::all_of( tensorLists.cbegin(), tensorLists.cend(), @@ -110,6 +156,17 @@ inline bool _check_tensors_share_device_and_dtype( return std::all_of( tensorList.cbegin(), tensorList.cend(), is_tensor_okay); }); +======= + for (const auto& tensorList : tensorLists) { + for (const auto& tensor : tensorList) { + if (!is_tensor_okay(tensor)) { + return false; + } + } + } + + return true; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // Helper function called in check_fast_path_restrictions to check if @@ -155,9 +212,17 @@ inline bool _check_tensors_do_type_promotion_with_scalars( bool does_op_promote_integer_inputs_to_float = false) { for (const auto i : c10::irange(tensorList.size())) { // For division, integer inputs will result in float. +<<<<<<< HEAD if (does_op_promote_integer_inputs_to_float && at::isIntegralType(tensorList[i].scalar_type(), /*includeBool*/ true)) { return false; +======= + if (does_op_promote_integer_inputs_to_float) { + if (at::isIntegralType( + tensorList[i].scalar_type(), /*includeBool*/ true)) { + return false; + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } if (!scalarList.empty()) { const auto& scalar = @@ -334,6 +399,7 @@ inline FlatMap _group_tensors_by_first_tensors_device_and_dtype( } }), "Tensors of the same index must be on the same device and the same dtype except `step` tensors that can be CPU and float32/64 notwithstanding"); +<<<<<<< HEAD grouped_tensors_with_indices.try_emplace( key, TensorsAndIndicesT{ @@ -362,6 +428,38 @@ inline FlatMap _group_tensors_by_first_tensors_device_and_dtype( return indices; } }()}); +======= + if (!grouped_tensors_with_indices.count(key)) { + grouped_tensors_with_indices.insert( + {key, + TensorsAndIndicesT{ + [&]() -> nested_optional_tensorvec_t { + nested_optional_tensorvec_t nested_tensorvec; + nested_tensorvec.reserve(num_lists); + for (const auto& i : c10::irange(num_lists)) { + std::vector> tensors; + if (!nested_tensorlist[i].empty()) { + // NB: num_tensors is the max possible length for any of + // the inner lists of tensor references. Reserving the max + // trades memory for perf. This should not have significant + // impact. + tensors.reserve(num_tensors); + } + nested_tensorvec.emplace_back(tensors); + } + return nested_tensorvec; + }(), + [&]() -> IndicesT { + if (!with_indices) { + return {}; + } else { + IndicesT indices; + indices.reserve(num_tensors); + return indices; + } + }()}}); + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (const auto& list_index : c10::irange(num_lists)) { if (!nested_tensorlist[list_index].empty()) { grouped_tensors_with_indices[key].first[list_index].emplace_back( diff --git a/aten/src/ATen/native/GridSampler.cpp b/aten/src/ATen/native/GridSampler.cpp index 0ca8ec2a3a887..37454e5a0f260 100644 --- a/aten/src/ATen/native/GridSampler.cpp +++ b/aten/src/ATen/native/GridSampler.cpp @@ -86,7 +86,11 @@ namespace { for (const auto d : c10::irange(out_D)) { for (const auto h : c10::irange(out_H)) { for (const auto w : c10::irange(out_W)) { +<<<<<<< HEAD // get the corresponding input x, y, z coordinates from grid +======= + // get the corresponding input x, y, z co-ordinates from grid +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const scalar_t *grid_ptr_NDHW = grid_ptr_N + d * grid_sD + h * grid_sH + w * grid_sW; scalar_t ix = *grid_ptr_NDHW; scalar_t iy = grid_ptr_NDHW[grid_sCoor]; @@ -285,7 +289,11 @@ namespace { for (const auto d : c10::irange(out_D)) { for (const auto h : c10::irange(out_H)) { for (int64_t w = 0; w < out_W; ++w, gGrid_ptr_NDHW += gGrid_sW /* grad_grid is contiguous */ ) { +<<<<<<< HEAD // get the corresponding input x, y, z coordinates from grid +======= + // get the corresponding input x, y, z co-ordinates from grid +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const scalar_t *grid_ptr_NDHW = grid_ptr_N + d * grid_sD + h * grid_sH + w * grid_sW; scalar_t ix = *grid_ptr_NDHW; scalar_t iy = grid_ptr_NDHW[grid_sCoor]; @@ -496,7 +504,11 @@ static Tensor _grid_sampler_2d_cpu_quantized( uint8_t* inp_ptr_N = inp_ptr + n * inp_sN; for (const auto h : c10::irange(out_H)) { for (const auto w : c10::irange(out_W)) { +<<<<<<< HEAD // get the corresponding input x, y, z coordinates from grid +======= + // get the corresponding input x, y, z co-ordinates from grid +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) float* grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW; float x = *grid_ptr_NHW; float y = grid_ptr_NHW[grid_sCoor]; @@ -599,7 +611,11 @@ Tensor _grid_sampler_2d_cpu_fallback(const Tensor& input, const Tensor& grid, const scalar_t *inp_ptr_N = inp_ptr + n * inp_sN; for (const auto h : c10::irange(out_H)) { for (const auto w : c10::irange(out_W)) { +<<<<<<< HEAD // get the corresponding input x, y, z coordinates from grid +======= + // get the corresponding input x, y, z co-ordinates from grid +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const scalar_t *grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW; scalar_t x = *grid_ptr_NHW; scalar_t y = grid_ptr_NHW[grid_sCoor]; @@ -771,7 +787,11 @@ _grid_sampler_2d_cpu_fallback_backward(const Tensor& grad_output, scalar_t *gGrid_ptr_NHW = gGrid_ptr + n * gGrid_sN; for (const auto h : c10::irange(out_H)) { for (int64_t w = 0; w < out_W; ++w, gGrid_ptr_NHW += gGrid_sW /* grad_grid is contiguous */ ) { +<<<<<<< HEAD // get the corresponding input x, y coordinates from grid +======= + // get the corresponding input x, y co-ordinates from grid +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const scalar_t *grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW; scalar_t x = *grid_ptr_NHW; scalar_t y = grid_ptr_NHW[grid_sCoor]; diff --git a/aten/src/ATen/native/IndexingUtils.h b/aten/src/ATen/native/IndexingUtils.h index 948a6b8320a4e..612f479ddeaa5 100644 --- a/aten/src/ATen/native/IndexingUtils.h +++ b/aten/src/ATen/native/IndexingUtils.h @@ -5,6 +5,7 @@ #include #include +<<<<<<< HEAD #ifndef AT_PER_OPERATOR_HEADERS #include #else @@ -12,6 +13,8 @@ #include #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace at::native { [[noreturn]] @@ -22,8 +25,12 @@ static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask, [[maybe_unused]] static std::vector expandTensors( const Tensor& self, +<<<<<<< HEAD IOptTensorListRef indices, bool ensure_same_device = false) { +======= + IOptTensorListRef indices) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // If indices come in as ByteTensor or BoolTensor (masks), expand them into // the equivalent indexing by LongTensors std::vector result; @@ -46,6 +53,7 @@ static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask, } } // Replace with nonzeros +<<<<<<< HEAD at::Tensor nonzero; if (ensure_same_device && index.device() != self.device()) { bool non_blocking = index.is_cpu() && self.device().is_cuda(); @@ -59,6 +67,12 @@ static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask, } } else if (ensure_same_device && index.device() != self.device()) { result.emplace_back(index.to(self.device())); +======= + auto nonzero = index.nonzero(); + for (const auto j : c10::irange(index.dim())) { + result.emplace_back(nonzero.select(1, j)); + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else { result.emplace_back(index); } diff --git a/aten/src/ATen/native/Linear.cpp b/aten/src/ATen/native/Linear.cpp index a744da3bcad2e..3ad3024e59735 100644 --- a/aten/src/ATen/native/Linear.cpp +++ b/aten/src/ATen/native/Linear.cpp @@ -93,7 +93,11 @@ Tensor linear(const Tensor& input, const Tensor& weight, const std::optionaldefined() && !input.is_xla()) { // Also hit the fused path for contiguous 3D input, if not using xla // backend. Reshaping/flattening has some performance implications on xla. +<<<<<<< HEAD bool is_contiguous = input.is_contiguous_or_false(); +======= + bool is_contiguous = definitely_contiguous(input.sym_sizes(), input.sym_strides(), input.sym_numel()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (is_contiguous && input_dim == 3) { return _flatten_nd_linear(input, weight, *bias); } else if (is_contiguous && input.layout() == c10::kStrided && weight.layout() == c10::kStrided && bias->dim() == 1) { @@ -154,8 +158,13 @@ static Tensor sumproduct_pair(const Tensor& left_, const Tensor& right_, IntArra Tensor left = left_; Tensor right = right_; for (const auto i : c10::irange(dim)) { +<<<<<<< HEAD auto sl = TORCH_GUARD_OR_TRUE(left.sym_size(i).sym_ne(1)); auto sr = TORCH_GUARD_OR_TRUE(right.sym_size(i).sym_ne(1)); +======= + auto sl = TORCH_GUARD_SIZE_OBLIVIOUS(left.sym_size(i).sym_ne(1)); + auto sr = TORCH_GUARD_SIZE_OBLIVIOUS(right.sym_size(i).sym_ne(1)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (sum_dims[i]) { // first dimensions that will be summed over after multiplication if (sl && sr) { // dimensions nontrivially in both left and right must be of the same size TORCH_SYM_CHECK(left.sym_size(i).sym_eq(right.sym_size(i)), "non-broadcast dimensions must match"); @@ -185,6 +194,7 @@ static Tensor sumproduct_pair(const Tensor& left_, const Tensor& right_, IntArra // right: "lro, summed, ro" permuted with rpermutation and the three flattened // then the permuted output is a view of bmm(left, right) // finally, opermutation reverts the permutation to the original order of dimensions +<<<<<<< HEAD // By default the output is "lro, lo, 1-for-summed-dims, ro" with original shape dimensions. // However, if all dimensions from the right operand appear before those from the left // operand in the final output, we can swap the operands so that bmm directly produces @@ -196,6 +206,8 @@ static Tensor sumproduct_pair(const Tensor& left_, const Tensor& right_, IntArra std::swap(lo, ro); std::swap(lo_size, ro_size); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto out_num_dim = lro.size() + lo.size() + sum_dims_.size() + ro.size(); std::vector out_size; out_size.reserve(out_num_dim); @@ -488,7 +500,11 @@ Tensor einsum(std::string_view equation, TensorList operands, at::OptionalIntArr // Iterate over each dimension covered by ellipsis const auto ndim = operands[i].ndimension() - (static_cast(op_labels[i].size()) - 1); for (auto j = ell_num_dim - ndim; j < ell_num_dim; ++j) { +<<<<<<< HEAD if (TORCH_GUARD_OR_TRUE(op.sym_size(dim).sym_ne(1))) { +======= + if (TORCH_GUARD_SIZE_OBLIVIOUS(op.sym_size(dim).sym_ne(1))) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Update ellipsis size TORCH_SYM_CHECK( ell_sizes[j].sym_eq(1).sym_or(ell_sizes[j].sym_eq(op.sym_size(dim))), @@ -507,7 +523,11 @@ Tensor einsum(std::string_view equation, TensorList operands, at::OptionalIntArr permutation[ell_index + j] = dim++; } } else if (permutation[label_perm_index[s]] == -1) { +<<<<<<< HEAD if (TORCH_GUARD_OR_TRUE(op.sym_size(dim).sym_ne(1))) { +======= + if (TORCH_GUARD_SIZE_OBLIVIOUS(op.sym_size(dim).sym_ne(1))) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Update subscript TORCH_SYM_CHECK( label_size[s].sym_eq(1).sym_or(label_size[s].sym_eq(op.sym_size(dim))), @@ -585,6 +605,7 @@ Tensor einsum(std::string_view equation, TensorList operands, at::OptionalIntArr SmallVector a_dims_to_sum; SmallVector b_dims_to_sum; for (auto dim = out_num_dim; dim < perm_index; ++dim) { +<<<<<<< HEAD auto sa = TORCH_GUARD_OR_TRUE(a.sym_size(dim).sym_ne(1)); auto sb = TORCH_GUARD_OR_TRUE(b.sym_size(dim).sym_ne(1)); @@ -592,15 +613,26 @@ Tensor einsum(std::string_view equation, TensorList operands, at::OptionalIntArr // if both a and b are equal, or we can't tell that its a broadcast for sure, // we assume non-broadcast. TORCH_SYM_CHECK(a.sym_size(dim).sym_eq(b.sym_size(dim)), "non-broadcast dimensions must match"); +======= + if (TORCH_GUARD_SIZE_OBLIVIOUS(a.sym_size(dim).sym_ne(1)) + && TORCH_GUARD_SIZE_OBLIVIOUS(b.sym_size(dim).sym_ne(1))) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (--dim_counts[dim] == 1) { sum_dims.push_back(dim); dim_counts[dim] = 0; } } else if (dim_counts[dim] == 1) { +<<<<<<< HEAD if (sa) { a_dims_to_sum.push_back(dim); dim_counts[dim] = 0; } else if (sb) { +======= + if (TORCH_GUARD_SIZE_OBLIVIOUS(a.sym_size(dim).sym_ne(1))) { + a_dims_to_sum.push_back(dim); + dim_counts[dim] = 0; + } else if (TORCH_GUARD_SIZE_OBLIVIOUS(b.sym_size(dim).sym_ne(1))) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) b_dims_to_sum.push_back(dim); dim_counts[dim] = 0; } diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 616e6ec60e13d..2e05ee88a491f 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1360,8 +1360,11 @@ Tensor outer(const Tensor& self, const Tensor& vec2) { #endif +<<<<<<< HEAD #if !defined(__aarch64__) || AT_MKLDNN_ACL_ENABLED() // Used by default on x86 platforms and on AArch64+ACL +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static inline int64_t get_mkldnn_matmul_min_dim() { static auto value = [&] { const int64_t default_min_dim = [&] { @@ -1395,7 +1398,12 @@ static inline bool apply_mkldnn_matmul_heur(int64_t m, int64_t k, int64_t n) { const int64_t min_size = get_mkldnn_matmul_min_size(); return at::globalContext().userEnabledMkldnn() && m > min_dim && k > min_dim && n > min_dim && m * k * n > min_size; } +<<<<<<< HEAD #endif +======= + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static void addmm_impl_cpu_( Tensor &result, const Tensor &self, Tensor m1, Tensor m2, const Scalar& beta, const Scalar& alpha) { TORCH_INTERNAL_ASSERT(self.dim() == 2 && m1.dim() == 2 && m2.dim() == 2); @@ -1771,8 +1779,12 @@ static inline void bmm_out_or_baddbmm_(const Tensor& self_or_result_, const Tens return (strides[2] == 1 && (sizes[1] == 1 || strides[1] >= sizes[2])) || (strides[1] == 1 && (sizes[2] == 1 || strides[2] >= sizes[1])); }; +<<<<<<< HEAD #if !defined(__aarch64__) || AT_MKLDNN_ACL_ENABLED() // Always apply mkldnn heuristic on x86 platform, but on ARM only if compiled with ACL +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool apply_heur = apply_mkldnn_matmul_heur(batch1.sizes()[1], batch1.sizes()[2], batch2.sizes()[2]); if (apply_heur && use_mkldnn_matmul(batch1, batch2, self_or_result)) { try { @@ -1783,7 +1795,11 @@ static inline void bmm_out_or_baddbmm_(const Tensor& self_or_result_, const Tens at::globalContext().setUserEnabledMkldnn(false); } } +<<<<<<< HEAD #endif +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (contraction_size * res_rows * res_cols < 400) { if (is_bmm_out) { AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, batch1.scalar_type(), "bmm", [&] { diff --git a/aten/src/ATen/native/Loss.cpp b/aten/src/ATen/native/Loss.cpp index 265bc112adcc2..c0267fe957aac 100644 --- a/aten/src/ATen/native/Loss.cpp +++ b/aten/src/ATen/native/Loss.cpp @@ -127,9 +127,12 @@ TORCH_IMPL_FUNC(smooth_l1_loss_out) TORCH_IMPL_FUNC(mse_loss_out) (const Tensor& input, const Tensor& target, int64_t reduction, const Tensor& result) { +<<<<<<< HEAD TORCH_CHECK(input.device() == target.device(), "Expected all tensors to be on the same device, but found at least two devices, ", input.device(), " and ", target.device(), "!"); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (reduction != Reduction::None) { Tensor loss; auto iter = TensorIterator::borrowing_binary_op(loss, input, target); diff --git a/aten/src/ATen/native/LossNLL.cpp b/aten/src/ATen/native/LossNLL.cpp index ca86292403fbf..e687749aa72f8 100644 --- a/aten/src/ATen/native/LossNLL.cpp +++ b/aten/src/ATen/native/LossNLL.cpp @@ -47,6 +47,7 @@ TORCH_META_FUNC(nll_loss_forward) TORCH_CHECK( target.dim() <= 1, "0D or 1D target tensor expected, multi-target not supported"); +<<<<<<< HEAD if (self.dim() == 1 && target.dim() == 1) { TORCH_CHECK_VALUE( target.size(0) == 1, @@ -55,6 +56,12 @@ TORCH_META_FUNC(nll_loss_forward) } TORCH_CHECK( self.dim() == 1 || (self.size(0) == target.size(0)), +======= + + auto no_batch_dim = self.dim() == 1 && target.dim() == 0; + TORCH_CHECK( + no_batch_dim || (self.size(0) == target.size(0)), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "size mismatch (got input: ", self.sizes(), ", target: ", diff --git a/aten/src/ATen/native/Math.h b/aten/src/ATen/native/Math.h index b261da5fe54ee..3cca9dcfecd68 100644 --- a/aten/src/ATen/native/Math.h +++ b/aten/src/ATen/native/Math.h @@ -1068,7 +1068,11 @@ inline scalar_t calc_igammac(scalar_t a, scalar_t x) { * result at the boundary * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for * Large Parameter (see DLMF 8.12.4 [igam1]) +<<<<<<< HEAD * - if x > 1.1 and x < a, using the subtraction from the regularized lower +======= + * - if x > 1.1 and x < a, using the substraction from the regularized lower +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) * incomplete gamma * - otherwise, calculate the series from [igam2] eq (5) */ @@ -1148,7 +1152,11 @@ scalar_t calc_igamma(scalar_t a, scalar_t x) { * result at the boundary * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for * Large Parameter (see DLMF 8.12.3 [igam1]) +<<<<<<< HEAD * - if x > 1 and x > a, using the subtraction from the regularized upper +======= + * - if x > 1 and x > a, using the substraction from the regularized upper +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) * incomplete gamma * - otherwise, calculate the series from [igam2] eq (4) */ @@ -1730,7 +1738,11 @@ inline C10_HOST_DEVICE T calc_ndtri(T y0) { with the usual checks for overflow etcetera. Performance-wise, it seems to be substantially faster than either +<<<<<<< HEAD the SLATEC DERFC function [or an erfcx function derived there from] +======= + the SLATEC DERFC function [or an erfcx function derived therefrom] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) or Cody's CALERF function (from netlib.org/specfun), while retaining near machine precision in accuracy. */ @@ -2862,7 +2874,11 @@ inline C10_HOST_DEVICE T chebyshev_polynomial_t_forward(T x, int64_t n) { T q = x; T r; +<<<<<<< HEAD for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) { +======= + for (int64_t k = 2; k <= n; k++) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r = (x + x) * q - p; p = q; q = r; @@ -2910,7 +2926,11 @@ inline C10_HOST_DEVICE T chebyshev_polynomial_u_forward(T x, int64_t n) { T q = x + x; T r; +<<<<<<< HEAD for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) { +======= + for (int64_t k = 2; k <= n; k++) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r = (x + x) * q - p; p = q; q = r; @@ -2966,7 +2986,11 @@ inline C10_HOST_DEVICE T chebyshev_polynomial_v_forward(T x, int64_t n) { T q = x + x - T(1.0); T r; +<<<<<<< HEAD for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) { +======= + for (int64_t k = 2; k <= n; k++) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r = (x + x) * q - p; p = q; q = r; @@ -3026,7 +3050,11 @@ inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, int64_t n) { T q = x + x + T(1.0); T r; +<<<<<<< HEAD for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) { +======= + for (int64_t k = 2; k <= n; k++) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r = (x + x) * q - p; p = q; q = r; @@ -3150,7 +3178,11 @@ inline C10_HOST_DEVICE T laguerre_polynomial_l_forward(T x, int64_t n) { T q = T(1.0) - x; T r; +<<<<<<< HEAD for (int64_t k = 1; (k < n) && !std::isnan(q); k++) { +======= + for (int64_t k = 1; k < n; k++) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r = (((k + k) + (T(1.0) - x)) * q - k * p) / (k + 1); p = q; q = r; @@ -3190,7 +3222,11 @@ inline C10_HOST_DEVICE T legendre_polynomial_p_forward(T x, int64_t n) { T q = x; T r; +<<<<<<< HEAD for (int64_t k = 1; (k < n) && !std::isnan(q); k++) { +======= + for (int64_t k = 1; k < n; k++) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r = ((k + k + 1) * x * q - k * p) / (k + 1); p = q; q = r; @@ -3733,7 +3769,11 @@ inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_t_forward(T x, int64_t n) T q = x + x - T(1.0); T r; +<<<<<<< HEAD for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) { +======= + for (int64_t k = 2; k <= n; k++) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; p = q; q = r; @@ -3785,7 +3825,11 @@ inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_u_forward(T x, int64_t n) T q = x + x - T(1.0) + (x + x - T(1.0)); T r; +<<<<<<< HEAD for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) { +======= + for (int64_t k = 2; k <= n; k++) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; p = q; q = r; @@ -3841,7 +3885,11 @@ inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, int64_t n) T q = x + x - T(1.0) + (x + x - T(1.0)) - T(1.0); T r; +<<<<<<< HEAD for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) { +======= + for (int64_t k = 2; k <= n; k++) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; p = q; q = r; @@ -3897,7 +3945,11 @@ inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_w_forward(T x, int64_t n) T q = x + x - T(1.0) + (x + x - T(1.0)) + T(1.0); T r; +<<<<<<< HEAD for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) { +======= + for (int64_t k = 2; k <= n; k++) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; p = q; q = r; diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp index 7327bf2d7e30b..41fe6f34d058a 100644 --- a/aten/src/ATen/native/Normalization.cpp +++ b/aten/src/ATen/native/Normalization.cpp @@ -521,6 +521,7 @@ BatchNormBackend _select_batch_norm_backend( } // TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM once ROCm officially supports NHWC in MIOpen +<<<<<<< HEAD // See https://github.com/pytorch/pytorch/issues/64427. // non static variable is used to be able to change environment variable in runtime for testing // enabled by default for ROCm >= 7.0.0 with miopen 3.5 @@ -528,6 +529,12 @@ BatchNormBackend _select_batch_norm_backend( bool is_miopen_3_4 = miopen_version >= 30400; // ROCm 6.4 bool is_miopen_3_5 = miopen_version >= 30500; // ROCm 7.0 bool PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM").value_or(is_miopen_3_5); +======= + // See #64427 + // non static variable is used to be able to change environment variable in runtime for testing + // enabled by default for ROCm >= 7.0.0 + bool PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM").value_or(ROCM_VERSION >= 70000); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if ( detail::getCUDAHooks().compiledWithMIOpen() @@ -536,15 +543,30 @@ BatchNormBackend _select_batch_norm_backend( && input.dim() <= MIOPEN_DIM_MAX && input.dim() >= 3 && input.scalar_type() != at::kDouble +<<<<<<< HEAD && (is_miopen_3_4 || input.scalar_type() != at::kBFloat16) +======= +#if (defined(USE_ROCM) && ROCM_VERSION < 60400) + && (input.scalar_type() != at::kBFloat16) +#endif + && (detail::getCUDAHooks().versionMIOpen() >= 30400 || input.scalar_type() != at::kBFloat16) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) && weight.scalar_type() == at::kFloat // only FP32 weight for FP32 or FP16/BF16(mixed) input && weight.defined() && bias.defined() && ((running_mean.defined() && running_var.defined()) || (!running_mean.defined() && !running_var.defined() && training)) && (input.suggest_memory_format() == MemoryFormat::Contiguous +<<<<<<< HEAD || (is_miopen_3_5 && PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM && (input.suggest_memory_format() == MemoryFormat::ChannelsLast || input.suggest_memory_format() == MemoryFormat::ChannelsLast3d))) +======= +#if (defined(USE_ROCM) && ROCM_VERSION >= 60500) + || (input.suggest_memory_format() == MemoryFormat::ChannelsLast && PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM) + || (input.suggest_memory_format() == MemoryFormat::ChannelsLast3d && PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM) +#endif + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) { return BatchNormBackend::Miopen; } @@ -624,7 +646,11 @@ std::tuple _batch_norm_impl_index( if (backend == BatchNormBackend::Miopen) { return std::tuple_cat( at::miopen_batch_norm( +<<<<<<< HEAD input.contiguous(), weight.contiguous(), bias.contiguous(), +======= + input.contiguous(input.suggest_memory_format()), weight.contiguous(), bias.contiguous(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) running_mean.defined() ? running_mean.contiguous() : running_mean, running_var.defined() ? running_var.contiguous() : running_var, training, momentum, eps), diff --git a/aten/src/ATen/native/Onehot.cpp b/aten/src/ATen/native/Onehot.cpp index 8833bdb6e471d..ba0c335999877 100644 --- a/aten/src/ATen/native/Onehot.cpp +++ b/aten/src/ATen/native/Onehot.cpp @@ -1,6 +1,9 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #ifndef AT_PER_OPERATOR_HEADERS #include @@ -25,6 +28,7 @@ Tensor one_hot(const Tensor &self, int64_t num_classes) { if (num_classes == -1) { num_classes = self.max().item().toLong() + 1; } +<<<<<<< HEAD { // If `self` is a DTensor, then allow implicit replication // of the `index` Tensor. @@ -32,6 +36,10 @@ Tensor one_hot(const Tensor &self, int64_t num_classes) { at::Tensor index = at::arange(num_classes, self.options()); return at::eq(self.unsqueeze(-1), index).to(kLong); } +======= + at::Tensor index = at::arange(num_classes, self.options()); + return at::eq(self.unsqueeze(-1), index).to(kLong); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } auto shape = self.sizes().vec(); diff --git a/aten/src/ATen/native/PadNd.cpp b/aten/src/ATen/native/PadNd.cpp index 8099648d37b29..53ba273a4f643 100644 --- a/aten/src/ATen/native/PadNd.cpp +++ b/aten/src/ATen/native/PadNd.cpp @@ -240,6 +240,7 @@ Tensor _pad_enum_symint(const Tensor &self, c10::SymIntArrayRef pad, int64_t mod default: {} } } +<<<<<<< HEAD std::ostringstream error_msg; error_msg << "Padding size " << pad.size() << " is not supported for " << input_dim << "D input tensor.\n"; @@ -249,6 +250,10 @@ Tensor _pad_enum_symint(const Tensor &self, c10::SymIntArrayRef pad, int64_t mod error_msg << " - 4D or 5D input: padding size = 6 (pads last 3 dimensions)"; C10_THROW_ERROR(NotImplementedError, error_msg.str()); +======= + C10_THROW_ERROR(NotImplementedError, + "Only 2D, 3D, 4D, 5D padding with non-constant padding are supported for now"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } Tensor pad_symint(const Tensor &self, c10::SymIntArrayRef pad, std::string_view mode, std::optional value) { diff --git a/aten/src/ATen/native/Pool.h b/aten/src/ATen/native/Pool.h index 7f335de04b90a..204b3afe97ad3 100644 --- a/aten/src/ATen/native/Pool.h +++ b/aten/src/ATen/native/Pool.h @@ -17,7 +17,11 @@ using max_pool2d_backward_fn = void(*)(const Tensor& grad_input, const Tensor& g DECLARE_DISPATCH(max_pool2d_fn, max_pool2d_kernel) DECLARE_DISPATCH(max_pool2d_backward_fn, max_pool2d_backward_kernel) +<<<<<<< HEAD // average pooling has same signature for forward and backward +======= +// averge pooling has same signature for forward and backward +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) using avg_pool2d_fn = void(*)(const Tensor& output, const Tensor& input, int64_t kW, int64_t kH, int64_t dW, int64_t dH, int64_t padW, int64_t padH, bool count_include_pad, std::optional divisor_override); using avg_pool2d_backward_fn = void(*)(const Tensor& output, const Tensor& input, int kW, int kH, @@ -26,7 +30,11 @@ using avg_pool2d_backward_fn = void(*)(const Tensor& output, const Tensor& input DECLARE_DISPATCH(avg_pool2d_fn, avg_pool2d_kernel) DECLARE_DISPATCH(avg_pool2d_backward_fn, avg_pool2d_backward_kernel) +<<<<<<< HEAD // average pooling has same signature for forward and backward +======= +// averge pooling has same signature for forward and backward +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) using avg_pool3d_fn = void(*)(const Tensor& output, const Tensor& input, int64_t kW, int64_t kH, int64_t kD, int64_t dW, int64_t dH, int64_t dD, int64_t padW, int64_t padH, int64_t padD, bool count_include_pad, diff --git a/aten/src/ATen/native/QuantizedLinear.cpp b/aten/src/ATen/native/QuantizedLinear.cpp index 746d8c1a2db4f..77d6957d18d00 100644 --- a/aten/src/ATen/native/QuantizedLinear.cpp +++ b/aten/src/ATen/native/QuantizedLinear.cpp @@ -25,11 +25,17 @@ #include #ifdef USE_FBGEMM +<<<<<<< HEAD C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include #include #include C10_DIAGNOSTIC_POP() +======= +#include +#include +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif // USE_FBGEMM namespace caffe2 { @@ -411,8 +417,12 @@ Tensor fbgemm_pack_gemm_matrix_fp16(const Tensor& weight) { Tensor fbgemm_linear_fp16_weight_fp32_activation( const Tensor& input, const Tensor& packed_weight, +<<<<<<< HEAD const std::optional& bias, at::Tensor& output) { +======= + const Tensor& bias) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_WARN_ONCE("fbgemm_linear_fp16_weight_fp32_activation is deprecated " "and will be removed in a future PyTorch release.") @@ -433,15 +443,25 @@ Tensor fbgemm_linear_fp16_weight_fp32_activation( TORCH_CHECK(input.size(input.dim() - 1) == packed_weight_fp16.numRows()) TORCH_CHECK(input.dim() >= 2); +<<<<<<< HEAD +======= + TORCH_CHECK(bias.dim() == 1); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) const int64_t M = size_to_dim_(input.dim() - 1, input.sizes()); const int64_t N = packed_weight_fp16.numCols(); +<<<<<<< HEAD std::vector output_size = input.sizes().vec(); output_size.back() = N; // Resize output Tensor output.resize_(output_size); +======= + std::vector output_size = input.sizes().vec(); + output_size.back() = N; + Tensor output = at::empty(output_size, input.options().dtype(at::kFloat)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Call the fp16 gemm interface fbgemm::cblas_gemm_compute( @@ -453,16 +473,21 @@ Tensor fbgemm_linear_fp16_weight_fp32_activation( output.data_ptr()); // Add bias term +<<<<<<< HEAD c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias); const Tensor& bias_ = *bias_maybe_owned; if (bias_.defined()) { TORCH_CHECK(bias_.dim() == 1); output.add_(bias_); } +======= + output.add_(bias); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return output; } +<<<<<<< HEAD Tensor fbgemm_linear_fp16_weight_fp32_activation( const Tensor& input, const Tensor& packed_weight, @@ -471,6 +496,8 @@ Tensor fbgemm_linear_fp16_weight_fp32_activation( return at::native::fbgemm_linear_fp16_weight_fp32_activation(input, packed_weight, bias, output); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Tensor fbgemm_linear_fp16_weight( const Tensor& input, const Tensor& packed_weight, @@ -479,6 +506,7 @@ Tensor fbgemm_linear_fp16_weight( input, packed_weight, bias); } +<<<<<<< HEAD Tensor fbgemm_linear_fp16_weight( const Tensor& input, const Tensor& packed_weight, @@ -488,6 +516,8 @@ Tensor fbgemm_linear_fp16_weight( input, packed_weight, bias, output); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #else // USE_FBGEMM Tensor fbgemm_linear_int8_weight_fp32_activation( @@ -577,8 +607,12 @@ Tensor fbgemm_pack_gemm_matrix_fp16(const Tensor& weight) { Tensor fbgemm_linear_fp16_weight_fp32_activation( const Tensor& input, const Tensor& packed_weight, +<<<<<<< HEAD const std::optional& bias, at::Tensor& output) { +======= + const Tensor& bias) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_WARN_ONCE("fbgemm_linear_fp16_weight_fp32_activation is deprecated " "and will be removed in a future PyTorch release.") @@ -589,6 +623,7 @@ Tensor fbgemm_linear_fp16_weight_fp32_activation( false, "This PyTorch installation was not built with FBGEMM operators"); } +<<<<<<< HEAD Tensor fbgemm_linear_fp16_weight_fp32_activation( const Tensor& input, const Tensor& packed_weight, @@ -618,6 +653,8 @@ Tensor fbgemm_linear_fp16_weight( false, "This PyTorch installation was not built with FBGEMM operators"); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Tensor fbgemm_linear_fp16_weight( const Tensor& input, const Tensor& packed_weight, diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp index db046428bb683..0a5a154d59cd5 100644 --- a/aten/src/ATen/native/ReduceOps.cpp +++ b/aten/src/ATen/native/ReduceOps.cpp @@ -71,8 +71,11 @@ #include #include #include +<<<<<<< HEAD #include #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include @@ -220,8 +223,11 @@ static void check_argmax_argmin( const char* name, const Tensor& self, const std::optional& dim) { +<<<<<<< HEAD TORCH_CHECK(!self.is_complex(), name, ": does not support complex input"); TORCH_CHECK(!(self.scalar_type() == kBool), name, ": does not support bool input"); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (dim.has_value()) { auto dim_ = maybe_wrap_dim(dim.value(), self.dim()); native::zero_numel_check_dims(self, dim_, name); @@ -402,6 +408,7 @@ TORCH_META_FUNC(amin) resize_reduction(*this, self, dim, keepdim, out_dtype); } +<<<<<<< HEAD TORCH_META_FUNC(hash_tensor) (const Tensor& self, IntArrayRef dim, bool keepdim, int64_t mode) { auto maybe_result = maybe_get_output(); @@ -415,6 +422,8 @@ TORCH_META_FUNC(hash_tensor) } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace at::meta namespace at::native { @@ -458,7 +467,10 @@ DEFINE_DISPATCH(argmin_stub); DEFINE_DISPATCH(cumsum_stub); DEFINE_DISPATCH(cumprod_stub); DEFINE_DISPATCH(logcumsumexp_stub); +<<<<<<< HEAD DEFINE_DISPATCH(xor_sum_stub); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Tensor _logcumsumexp_cpu(const Tensor& self, int64_t dim) { Tensor result = at::empty_like(self, MemoryFormat::Contiguous); @@ -1469,7 +1481,11 @@ Tensor& nanmean_out( "nanmean(): expected input to have floating point or complex dtype but got ", self.scalar_type()); const auto factor = at::native::isnan(self).logical_not_().sum(dim, keepdim); +<<<<<<< HEAD at::nansum_out(result, self, dim, keepdim, opt_dtype).div_(factor); +======= + at::native::nansum_out(self, dim, keepdim, opt_dtype, result).div_(factor); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return result; } @@ -2251,6 +2267,7 @@ Tensor dist(const Tensor &self, const Tensor& other, const Scalar& p){ return at::norm(self - other, p); } +<<<<<<< HEAD enum class HashMode { XOR_SUM = 0 }; TORCH_IMPL_FUNC(hash_tensor_out) (const Tensor& self, IntArrayRef dim, bool keepdim, int64_t mode, const Tensor& result) { @@ -2269,6 +2286,8 @@ TORCH_IMPL_FUNC(hash_tensor_out) (const Tensor& self, IntArrayRef dim, bool keep } } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool cpu_equal(const Tensor& self, const Tensor& other) { if (!at::namedinference::are_names_equal( self.unsafeGetTensorImpl(), other.unsafeGetTensorImpl())) { diff --git a/aten/src/ATen/native/ReduceOps.h b/aten/src/ATen/native/ReduceOps.h index c562bf548403b..818a69b597693 100644 --- a/aten/src/ATen/native/ReduceOps.h +++ b/aten/src/ATen/native/ReduceOps.h @@ -27,7 +27,10 @@ DECLARE_DISPATCH(reduce_fn, min_values_stub) DECLARE_DISPATCH(reduce_fn, max_values_stub) DECLARE_DISPATCH(reduce_fn, argmax_stub) DECLARE_DISPATCH(reduce_fn, argmin_stub) +<<<<<<< HEAD DECLARE_DISPATCH(reduce_fn, xor_sum_stub) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) using reduce_std_var_function = void (*)(TensorIterator&, double correction, bool take_sqrt); diff --git a/aten/src/ATen/native/ReplicationPadding.cpp b/aten/src/ATen/native/ReplicationPadding.cpp index 0c66c7a632997..18e57899bfd9d 100644 --- a/aten/src/ATen/native/ReplicationPadding.cpp +++ b/aten/src/ATen/native/ReplicationPadding.cpp @@ -229,20 +229,29 @@ void replication_pad3d_backward_out_cpu_template( int pbottom = paddingSize[3]; int pfront = paddingSize[4]; int pback = paddingSize[5]; +<<<<<<< HEAD int dimc = 0; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int dimw = 3; int dimh = 2; int dimd = 1; if (input.dim() == 5) { +<<<<<<< HEAD dimc++; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dimw++; dimh++; dimd++; } /* sizes */ +<<<<<<< HEAD int64_t ichannel = input.size(dimc); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int64_t idepth = input.size(dimd); int64_t iheight = input.size(dimh); int64_t iwidth = input.size(dimw); @@ -252,9 +261,12 @@ void replication_pad3d_backward_out_cpu_template( at::native::padding::check_valid_input<3>(input, paddingSize); +<<<<<<< HEAD TORCH_CHECK(ichannel == gradOutput.size(dimc), "gradOutput width unexpected. Expected: ", ichannel, ", Got: ", gradOutput.size(dimc)); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_CHECK(owidth == gradOutput.size(dimw), "gradOutput width unexpected. Expected: ", owidth, ", Got: ", gradOutput.size(dimw)); diff --git a/aten/src/ATen/native/SegmentReduce.cpp b/aten/src/ATen/native/SegmentReduce.cpp index 2b61bcec6a828..13e34c9c25f75 100644 --- a/aten/src/ATen/native/SegmentReduce.cpp +++ b/aten/src/ATen/native/SegmentReduce.cpp @@ -480,7 +480,11 @@ REGISTER_ZVECTOR_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets REGISTER_SVE256_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel) // Currently some computation is being duplicated across forward and backward. +<<<<<<< HEAD // TODO: Cache indices in forward pass to reuse in backward +======= +// TODO: Cache indices in forward pass to re-use in backward +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Tensor _segment_reduce_backward_kernel( const Tensor& grad, const Tensor& output, diff --git a/aten/src/ATen/native/Sorting.cpp b/aten/src/ATen/native/Sorting.cpp index 44215a26018f0..2f7b649c26795 100644 --- a/aten/src/ATen/native/Sorting.cpp +++ b/aten/src/ATen/native/Sorting.cpp @@ -59,8 +59,11 @@ TORCH_META_FUNC(topk) "selected index k out of range"); int64_t sliceSize = self.dim() == 0 ? 1 : self.size(dim); TORCH_CHECK(k >= 0 && k <= sliceSize, "k not in range for dimension"); +<<<<<<< HEAD TORCH_CHECK(!self.is_complex(), " topk does not support complex dtypes on CPU"); TORCH_CHECK(!(self.scalar_type() == kBool), "topk does not support bool dtypes on CPU"); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Build the output size, which is the dim being selected set to // size k @@ -76,7 +79,15 @@ TORCH_META_FUNC2(sort, stable) (const Tensor& self, std::optional stable, int64_t dim, bool descending) { maybe_wrap_dim(dim, self.dim()); +<<<<<<< HEAD TORCH_CHECK(!self.is_complex(), " Sort does not support complex dtypes on CPU"); +======= + const auto self_dtype = self.dtype(); + TORCH_CHECK_VALUE( + self_dtype != ScalarType::ComplexFloat && + self_dtype != ScalarType::ComplexDouble, + "Sort currently does not support complex dtypes on CPU."); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // See issue: https://github.com/pytorch/pytorch/issues/65863 // Strides should be dense, so as not to allocate too much memory. diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp index 7d613fc023120..af5e7e42fdd0e 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp +++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp @@ -475,7 +475,11 @@ static void build_index_op( TensorIteratorBase& iter, const at::native::AdvancedIndex& info, const Tensor& result) { +<<<<<<< HEAD // 'TensorIterator' needs to own the things coming from 'info', since +======= + // 'TensorIterator' needs to own the things comming from 'info', since +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // 'info' will be destroyed after the META function. TensorIteratorConfig config; // info.src is a restrided view of result @@ -2153,6 +2157,7 @@ static void _scatter_via_index_put( const Tensor& src, const Tensor& mut_out, bool accumulate) { +<<<<<<< HEAD // If index is expanded with zero strides across non-scatter dimensions, // advanced indexing with the index tensor alone achieves the desired // semantics and avoids creating large intermediate tensors. @@ -2200,6 +2205,83 @@ static void _scatter_via_index_put( } } mut_out.index_put_(indices, src_view, accumulate); +======= + if (self.dim() == 1) { + torch::List> indices; + indices.reserve(1); + indices.push_back(index); + mut_out.index_put_(indices, src, accumulate); + } else { + Tensor mut_out_contig = mut_out.contiguous(); + + auto index_coords_sizes = index.sizes().vec(); + index_coords_sizes.push_back(self.dim()); + auto index_coords = at::empty( + index_coords_sizes, + at::TensorOptions().dtype(at::ScalarType::Long).device(self.device())); + + for (int64_t dim_other = 0; dim_other < self.dim(); dim_other++) { + if (dim_other == dim) { + continue; + } + auto dim_coord_vals = at::arange( + index.size(dim_other), at::TensorOptions().device(self.device())); + + for (int64_t dim_unsqueeze = 0; dim_unsqueeze < self.dim() - 1; + dim_unsqueeze++) { + dim_coord_vals = + dim_coord_vals.unsqueeze((dim_unsqueeze >= dim_other) ? -1 : 0); + } + + auto view_sizes = index.sizes().vec(); + view_sizes.push_back(1); + auto view_strides = index_coords.strides().vec(); + view_strides[self.dim()] = self.dim(); + + at::as_strided(index_coords, view_sizes, view_strides, dim_other) + .copy_(dim_coord_vals.unsqueeze(-1)); + } + + auto view_sizes = index.sizes().vec(); + view_sizes.push_back(1); + auto view_strides = index_coords.strides().vec(); + view_strides[self.dim()] = self.dim(); + + at::as_strided(index_coords, view_sizes, view_strides, dim) + .copy_(index.unsqueeze(-1)); + + Tensor index_coords_flat = index_coords.flatten(0, -2); + + // Copy mut_out_contig's strides into a tensor + // TODO: Is there a utility function that already does this? + IntArrayRef mut_out_contig_strides = mut_out_contig.strides(); + Tensor coord_strides = at::empty( + {mut_out_contig.dim()}, + TensorOptions().dtype(at::ScalarType::Long).device(at::kCPU)); + std::memcpy( + coord_strides.mutable_data_ptr(), + mut_out_contig_strides.data(), + coord_strides.nbytes()); + coord_strides = coord_strides.to(mut_out_contig.device()); + + // `index_flat` contains the 1-D indices corresponding with the + // flattened `mut_out` + Tensor index_flat = (index_coords_flat * coord_strides).sum({-1}); + Tensor mut_out_flat = mut_out_contig.flatten(); + Tensor src_flat = + at::as_strided(src, index.sizes(), src.strides()).flatten(); + + torch::List> indices; + indices.reserve(1); + indices.push_back(index_flat); + + mut_out_flat.index_put_(indices, src_flat, accumulate); + + if (!mut_out.is_contiguous()) { + mut_out.copy_(mut_out_flat.reshape(mut_out.sizes())); + } + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } template < diff --git a/aten/src/ATen/native/TensorAdvancedIndexingUtils.h b/aten/src/ATen/native/TensorAdvancedIndexingUtils.h index bc6c2533eac5c..3ee903d6a01ae 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexingUtils.h +++ b/aten/src/ATen/native/TensorAdvancedIndexingUtils.h @@ -35,9 +35,13 @@ inline std::tuple canDispatchToMaskedFill( auto self_device = self.device(); for (const std::optional& i : indices) { if (!i.has_value() || !(*i).defined()) { +<<<<<<< HEAD if (!mask.defined()) { num_ind++; } +======= + num_ind++; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else { const Tensor& index = *i; if ((index.scalar_type() != kByte && index.scalar_type() != kBool) || @@ -73,7 +77,11 @@ inline AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) { checkIndexTensorTypes(orig, /*allow_int*/ true); // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more // LongTensors +<<<<<<< HEAD auto indices = expandTensors(self, orig, /*ensure_same_device=*/true); +======= + auto indices = expandTensors(self, orig); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // next broadcast all index tensors together try { indices = expand_outplace(indices); @@ -93,6 +101,15 @@ inline AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) { if (!hasContiguousSubspace(indices)) { std::tie(self, indices) = transposeToFront(self, indices); } +<<<<<<< HEAD +======= + // Ensure indices are on the same device as self + for (auto& indice : indices) { + if (indice.defined() && indice.device() != self.device()) { + indice = indice.to(self.device()); + } + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (auto& indice : indices) { if (indice.defined() && indice.dtype() == at::kInt) { indice = indice.to(at::kLong); diff --git a/aten/src/ATen/native/TensorConversions.cpp b/aten/src/ATen/native/TensorConversions.cpp index 7df7745fc5077..ef8065a3ff90b 100644 --- a/aten/src/ATen/native/TensorConversions.cpp +++ b/aten/src/ATen/native/TensorConversions.cpp @@ -67,7 +67,11 @@ namespace at::native { namespace { // dense_to_sparse_{csr,bsr,csc,bsc} common helpers +<<<<<<< HEAD // Preparation for the N-D dense -> sparse compressed conversion. +======= +// Preparation fo the N-D dense -> sparse compressed conversion. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // The N-D input is converted to 3-D (single batch dim) where we check that the // product of batch dims is nonzero and for each batch the sparse matrix // contained within has the same number of non-zero elements. diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp index 1886e65fc1edc..20349ae43ea90 100644 --- a/aten/src/ATen/native/TensorFactories.cpp +++ b/aten/src/ATen/native/TensorFactories.cpp @@ -1367,9 +1367,15 @@ void randperm_cpu(Tensor& result, int64_t n, CPUGeneratorImpl* generator) { for (int64_t i = 0; i < n - 1; i++) { // NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.rand) int64_t z = generator->random() % (n - i); +<<<<<<< HEAD scalar_t save = r__data[i * r__stride_0]; r__data[i * r__stride_0] = r__data[(z + i) * r__stride_0]; r__data[(z + i) * r__stride_0] = save; +======= + scalar_t sav = r__data[i * r__stride_0]; + r__data[i * r__stride_0] = r__data[(z + i) * r__stride_0]; + r__data[(z + i) * r__stride_0] = sav; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } return; } @@ -1640,9 +1646,12 @@ Tensor zeros_symint( std::optional layout, std::optional device, std::optional pin_memory) { +<<<<<<< HEAD for (const auto& dim_size : size) { TORCH_CHECK(dim_size >= 0, "zeros: Dimension size must be non-negative."); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Layout layout_ = layout.value_or(Layout::Strided); if (at::sparse_csr::is_sparse_compressed(layout_)) { return zeros_sparse_compressed_symint( diff --git a/aten/src/ATen/native/TensorIteratorReduce.cpp b/aten/src/ATen/native/TensorIteratorReduce.cpp index ce2987eb251ae..a32782eec7763 100644 --- a/aten/src/ATen/native/TensorIteratorReduce.cpp +++ b/aten/src/ATen/native/TensorIteratorReduce.cpp @@ -80,7 +80,11 @@ static void two_pass_reduction(TensorIteratorBase& iter, loop2d_t loop) { } /// Chooses a dimension over which to parallelize. Prefers the outer-most +<<<<<<< HEAD /// dimension that's larger than the number of available threads. +======= +/// dimension thats larger than the number of available threads. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static int find_split_dim(TensorIteratorBase& iter) { int num_threads = at::get_num_threads(); auto shape = iter.shape(); diff --git a/aten/src/ATen/native/TensorProperties.cpp b/aten/src/ATen/native/TensorProperties.cpp index 4fa0556ad7859..fc66f6d01ae24 100644 --- a/aten/src/ATen/native/TensorProperties.cpp +++ b/aten/src/ATen/native/TensorProperties.cpp @@ -18,7 +18,10 @@ #include #include #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include @@ -58,12 +61,15 @@ c10::SymInt sym_size(const Tensor& self, int64_t dim) { return self.sym_size(dim); } +<<<<<<< HEAD c10::SymBool sym_is_contiguous( const Tensor& self, c10::MemoryFormat memory_format) { return self.sym_is_contiguous(memory_format); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) c10::SymInt sym_stride(const Tensor& self, int64_t dim) { return self.sym_stride(dim); } @@ -120,7 +126,11 @@ Tensor& detach_(Tensor& self) { } Tensor contiguous(const Tensor& self, MemoryFormat memory_format) { +<<<<<<< HEAD if (self.is_contiguous_or_false(memory_format)) { +======= + if (self.is_contiguous(memory_format)) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self; } TORCH_CHECK( diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index c2d0856c3cd4c..4e08d69c4c983 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -247,7 +247,11 @@ TORCH_PRECOMPUTE_META_FUNC(cat)(const ITensorListRef& tensors, int64_t dim) { // Checking names before the actual dimensions. auto maybe_outnames = namedinference::compute_cat_outnames(materialized); +<<<<<<< HEAD TORCH_CHECK_VALUE( +======= + TORCH_CHECK( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) !materialized.empty(), "torch.cat(): expected a non-empty list of Tensors"); @@ -274,7 +278,11 @@ TORCH_PRECOMPUTE_META_FUNC(cat)(const ITensorListRef& tensors, int64_t dim) { // when computing the actual output dtype and the flags. if (is_out_defined) { // Check for type promotion, if the output tensor is defined. +<<<<<<< HEAD TORCH_CHECK_TYPE( +======= + TORCH_CHECK( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) canCast(out_dtype, result.scalar_type()), "torch.cat(): input types can't be cast to the desired output type ", result.scalar_type()); @@ -293,7 +301,11 @@ TORCH_PRECOMPUTE_META_FUNC(cat)(const ITensorListRef& tensors, int64_t dim) { // are compatible, i.e. we can execute `cat` on them. bool found_valid_tensor = valid < materialized.size(); if (found_valid_tensor) { +<<<<<<< HEAD TORCH_CHECK_INDEX( +======= + TORCH_CHECK( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dim <= materialized[valid].get().dim(), "torch.cat(): dimension ", dim, @@ -384,7 +396,11 @@ Tensor& set_storage_cpu_( result.unsafeGetTensorImpl()->set_storage_offset(storage_offset); at::OptionalIntArrayRef stride_opt = stride.data() != nullptr ? at::OptionalIntArrayRef(stride) : std::nullopt; +<<<<<<< HEAD // We can reuse this kernel for the meta device. +======= + // We can re-use this kernel for the meta device. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // We just need to make sure we don't actually try to resize the (null) // storage. at::native::resize_impl_cpu_( @@ -459,7 +475,12 @@ Tensor& set_storage_meta__symint( size, stride, itemsize, std::move(storage_offset)); if (new_size_bytes.has_hint() && storage.sym_nbytes().has_hint() && +<<<<<<< HEAD (new_size_bytes > storage.sym_nbytes())) { +======= + TORCH_GUARD_SIZE_OBLIVIOUS( + new_size_bytes.sym_gt(storage.sym_nbytes()))) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) storage.set_nbytes(std::move(new_size_bytes)); } } @@ -505,7 +526,11 @@ Tensor& set_cpu_(Tensor& result) { return result; } +<<<<<<< HEAD // We can't reuse the cpu kernel here because we don't want to use the cpu +======= +// We can't re-use the cpu kernel here because we don't want to use the cpu +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // allocator. Tensor& set_meta_(Tensor& result) { caffe2::TypeMeta dtype = result.dtype(); @@ -1408,6 +1433,12 @@ Tensor as_strided_tensorimpl( IntArrayRef size, IntArrayRef stride, std::optional storage_offset_) { +<<<<<<< HEAD +======= + TORCH_INTERNAL_ASSERT( + !self.is_mps(), + "as_strided_tensorimpl does not work with MPS; call self.as_strided(...) instead"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto storage_offset = storage_offset_.value_or(self.storage_offset()); auto result = at::detail::make_tensor( c10::TensorImpl::VIEW, @@ -1904,7 +1935,11 @@ Tensor repeat(const Tensor& self, IntArrayRef repeats) { } Tensor tile_symint(const Tensor& self, SymIntArrayRef reps) { +<<<<<<< HEAD // If self.size() > len(reps), reps is promoted to self.size() by prepending +======= + // If self.size() > len(reps), reps is promoted to self.size() by pre-pending +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // 1’s to it to keep the same behaviour as `numpy.tile`. // Thus for a tensor of shape (2, 3, 4, 5), a dims of (2, 2) is treated // as (1, 1, 2, 2). @@ -1995,18 +2030,32 @@ Tensor reshape_symint(const Tensor& self, c10::SymIntArrayRef proposed_shape) { TORCH_CHECK(false, "reshape is not implemented for sparse tensors"); } +<<<<<<< HEAD if (self.is_contiguous_or_false() && !self.is_mkldnn()) { return self.view_symint(proposed_shape); } auto sym_numel = self.sym_numel(); +======= + auto sym_sizes = self.sym_sizes(); + auto sym_strides = self.sym_strides(); + auto sym_numel = self.sym_numel(); + if (definitely_contiguous(sym_sizes, sym_strides, sym_numel) && + !self.is_mkldnn()) { + return self.view_symint(proposed_shape); + } + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) c10::SymDimVector shape = infer_size_dv(proposed_shape, sym_numel); if (self.is_mkldnn()) { return at::_mkldnn_reshape(self, C10_AS_INTARRAYREF_SLOW(shape)); } +<<<<<<< HEAD auto sym_sizes = self.sym_sizes(); auto sym_strides = self.sym_strides(); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // `computeStride` returns the proper strides to use if this // `reshape` can be just a view. @@ -2428,7 +2477,11 @@ Tensor index_select_sparse_cpu( const auto dim_indices = indices[dim].contiguous(); // If nnz is smaller than size, then either indices[dim] or index gets +<<<<<<< HEAD // sorted, then this is followed by a binary search to find intersections. +======= + // sorted, then this is followed by a binary search to find interesections. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const auto get_selected_indices_small_nnz_large_size = [&]() -> std::tuple { const auto grain_size = at::internal::GRAIN_SIZE; @@ -3934,7 +3987,11 @@ Tensor squeeze_qtensor(const Tensor& self, c10::OptionalIntArrayRef dims) { quantizer->scalar_type()); } // TODO: quantized Tensor support for SymInt needs to be added but basic +<<<<<<< HEAD // building blocks are missing for now. +======= + // building blocs are missing for now. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto result = make_qtensor( self, C10_AS_INTARRAYREF_SLOW(sizes), diff --git a/aten/src/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h b/aten/src/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h index 9a122cd7cf05e..6730187ebd385 100644 --- a/aten/src/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h +++ b/aten/src/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h @@ -4,11 +4,17 @@ #include #ifdef USE_FBGEMM +<<<<<<< HEAD C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include #include #include C10_DIAGNOSTIC_POP() +======= +#include +#include +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace ao::sparse { diff --git a/aten/src/ATen/native/cpu/Activation.cpp b/aten/src/ATen/native/cpu/Activation.cpp index 00c9f4eb25348..80043c9fbc721 100644 --- a/aten/src/ATen/native/cpu/Activation.cpp +++ b/aten/src/ATen/native/cpu/Activation.cpp @@ -26,10 +26,13 @@ namespace at::native { namespace { +<<<<<<< HEAD #if defined(__GNUC__) && __GNUC__ == 14 && defined(__aarch64__) && !defined(__ARM_FEATURE_SVE) // Workaround for gcc-14.2.0 ICE during RTL pass: expand when compiling for NEON __attribute__((optimize("no-tree-vectorize"))) #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static void log_sigmoid_cpu_kernel(TensorBase &output, TensorBase &buffer, const TensorBase &input) { if (at::isReducedFloatingType(input.scalar_type())) { AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "log_sigmoid_cpu", [&]() { diff --git a/aten/src/ATen/native/cpu/DistanceOpsKernel.cpp b/aten/src/ATen/native/cpu/DistanceOpsKernel.cpp index a1a7059b7d64f..dd2b99d6a9164 100644 --- a/aten/src/ATen/native/cpu/DistanceOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/DistanceOpsKernel.cpp @@ -139,7 +139,11 @@ struct Dist { static inline data_t map(const data_t& diff, const data_t& p) { return diff; } static inline data_t red(const data_t& agg, const data_t& up) { return max(agg, up); } static inline scalar_t finish(const scalar_t agg, const scalar_t p) { return agg; } +<<<<<<< HEAD // TODO This backward pass uses a very complex expression to compute (diff +======= + // TODO This backward pass uses a very complext expression to compute (diff +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // == dist) that could be much faster if using SSE instructions. static inline Vec backward(const Vec& diff, const scalar_t grad, const scalar_t dist, const Vec& p) { return Vec(grad) * sign(diff) * (Vec(1) - vec::minimum(Vec(1), (diff.abs() - Vec(dist)).abs().ceil())); } }; diff --git a/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp b/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp index 4432b9ace7911..e57b6203227cf 100644 --- a/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp +++ b/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp @@ -96,6 +96,7 @@ inline void _exp_reduce_sum_fusion_kernel( for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) { auto tmp0 = vec::Vectorized::loadu(a + i); auto tmp1 = tmp0 - vec_max; +<<<<<<< HEAD Vectorized tmp2; if constexpr (std::is_same_v && (std::is_same_v || std::is_same_v)) @@ -104,6 +105,9 @@ inline void _exp_reduce_sum_fusion_kernel( } else { tmp2 = tmp1.exp_u20(); } +======= + auto tmp2 = tmp1.exp_u20(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) vec_tmp_sum += tmp2; _store(out + i, tmp2); } @@ -315,12 +319,21 @@ void cpu_flash_attention( bool is_causal, std::optional attn_mask, std::optional scale) { +<<<<<<< HEAD // Query (Batch x Num_heads x Q_seq_len x Dim_per_head) // -> (Batch x Q_seq_len x Num_heads x Dim_per_head) // Key (Batch x KV_num_heads x KV_seq_len x Dim_per_head) // -> (Batch x KV_seq_len x KV_num_heads x Dim_per_head) // Value (Batch x KV_num_heads x KV_seq_len x Dim_per_head) // -> (Batch x KV_seq_len x KV_num_heads x Dim_per_head) +======= + // Query (Batch x Num_heads x Q_seq_len x Dim_per_head) + // -> (Batch x Q_seq_len x Num_heads x Dim_per_head) + // Key (Batch x Num_heads x KV_seq_len x Dim_per_head) + // -> (Batch x KV_seq_len x Num_heads x Dim_per_head) + // Value (Batch x Num_heads x KV_seq_len x Dim_per_head) + // -> (Batch x KV_seq_len x Num_heads x Dim_per_head) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) at::Tensor query = q.transpose(1, 2); at::Tensor key = k.transpose(1, 2); at::Tensor value = v.transpose(1, 2); @@ -338,8 +351,11 @@ void cpu_flash_attention( int64_t qSize = query.size(1); int64_t kvSize = value.size(1); int64_t num_head = query.size(2); +<<<<<<< HEAD int64_t kv_num_head = key.size(2); int64_t repeat_factor = num_head / kv_num_head; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int64_t headSize = query.size(3); bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel(); @@ -400,7 +416,11 @@ void cpu_flash_attention( // When the number of gemm is greater than the number of pack, // the pack overhead can be overlapped. if (need_pack) { +<<<<<<< HEAD double pack_size = batchSize * kv_num_head * kvSize * headSize; +======= + double pack_size = batchSize * num_head * kvSize * headSize; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) double qs_per_thread = (batchSize * num_head * qSlice + num_thread - 1) / num_thread; double gemm_size_per_thread = qs_per_thread * qSplitSize * (is_causal ? std::min(qSize, kvSize) : kvSize) * headSize; @@ -450,10 +470,17 @@ void cpu_flash_attention( at::Tensor qeury_t_padding; if (need_pack) { key_t_reorder = at::empty( +<<<<<<< HEAD {batchSize, kv_num_head, eheadSize, kvSize}, c10::CppTypeToScalarType::value); value_t_reorder = at::empty( {batchSize, kv_num_head, kv_padding_size, headSize}, +======= + {batchSize, num_head, eheadSize, kvSize}, + c10::CppTypeToScalarType::value); + value_t_reorder = at::empty( + {batchSize, num_head, kv_padding_size, headSize}, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) c10::CppTypeToScalarType::value); key_reorder_ptr = key_t_reorder.data_ptr(); value_reorder_ptr = value_t_reorder.data_ptr(); @@ -472,11 +499,19 @@ void cpu_flash_attention( {num_thread, kvSplitSize, headSize}, c10::CppTypeToScalarType::value); scalar_t* transpose_buffer_ptr = tranpose_t_reorder.data_ptr(); +<<<<<<< HEAD at::parallel_for(0, batchSize * kv_num_head * kvSlice, 1, [&](int64_t begin, int64_t end) { int ompIdx = at::get_thread_num(); int64_t i = 0, kv_j = 0, l = 0, n = 0; scalar_t* transpose_ptr = transpose_buffer_ptr + ompIdx * kvSplitSize * headSize; at::native::data_index_init(begin, i, batchSize, kv_j, kv_num_head, l, kvSlice); +======= + at::parallel_for(0, batchSize * num_head * kvSlice, 1, [&](int64_t begin, int64_t end) { + int ompIdx = at::get_thread_num(); + int64_t i = 0, j = 0, l = 0, n = 0; + scalar_t* transpose_ptr = transpose_buffer_ptr + ompIdx * kvSplitSize * headSize; + at::native::data_index_init(begin, i, batchSize, j, num_head, l, kvSlice); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for ([[maybe_unused]] auto z : c10::irange(begin, end)) { n = l * kvSplitSize; int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); @@ -486,7 +521,11 @@ void cpu_flash_attention( kvBlockSize, headSize, /* src_ptr */ +<<<<<<< HEAD reinterpret_cast(k_data + i * kStrideB + kv_j * kStrideH + n * kStrideN), +======= + reinterpret_cast(k_data + i * kStrideB + j * kStrideH + n * kStrideN), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) /* ld_src */ kStrideN, /* dst */ reinterpret_cast(transpose_ptr), /* ld_dst */ kvBlockSize); @@ -494,24 +533,40 @@ void cpu_flash_attention( // Pack [headSize, kvBlockSize] at::vec::pack_vnni2( /* src */ reinterpret_cast(transpose_ptr), +<<<<<<< HEAD /* dst */ reinterpret_cast(key_reorder_ptr + i * kv_num_head * eheadSize * kvSize + kv_j * eheadSize * kvSize + n * eheadSize), +======= + /* dst */ reinterpret_cast(key_reorder_ptr + i * num_head * eheadSize * kvSize + + j * eheadSize * kvSize + n * eheadSize), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) /* ld_src */ kvBlockSize, /* K */ headSize, /* N */ kvBlockSize); // Pack [kvBlockSize, headSize] at::vec::pack_vnni2( +<<<<<<< HEAD /* src */ reinterpret_cast(v_data + i * vStrideB + kv_j * vStrideH + n * vStrideN), /* dst */ reinterpret_cast(value_reorder_ptr + i * kv_num_head * kv_padding_size * headSize + kv_j * kv_padding_size * headSize + n * headSize), +======= + /* src */ reinterpret_cast(v_data + i * vStrideB + j * vStrideH + n * vStrideN), + /* dst */ reinterpret_cast(value_reorder_ptr + + i * num_head * kv_padding_size * headSize + + j * kv_padding_size * headSize + n * headSize), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) /* ld_src */ vStrideN, /* K */ kvBlockSize, /* N */ headSize); // Move to the next query +<<<<<<< HEAD at::native::data_index_step(i, batchSize, kv_j, kv_num_head, l, kvSlice); +======= + at::native::data_index_step(i, batchSize, j, num_head, l, kvSlice); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } }); } @@ -533,7 +588,10 @@ void cpu_flash_attention( for ([[maybe_unused]] auto z : c10::irange(begin, end)) { int64_t m = k * qSplitSize; int64_t qBlockSize = std::min(qSplitSize, qSize - m); +<<<<<<< HEAD int64_t kv_j = j / repeat_factor; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Initialize max and sum fill_stub(qk_max_data, -std::numeric_limits::infinity(), qBlockSize); @@ -570,8 +628,13 @@ void cpu_flash_attention( !headSize_even ? query_t_padding_ptr : q_data + i * qStrideB + j * qStrideH + m * qStrideM, +<<<<<<< HEAD key_reorder_ptr + i * kv_num_head * eheadSize * kvSize + kv_j * eheadSize * kvSize + n * eheadSize, +======= + key_reorder_ptr + i * num_head * eheadSize * kvSize + + j * eheadSize * kvSize + n * eheadSize, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) qk_data); } } else { @@ -582,7 +645,11 @@ void cpu_flash_attention( qBlockSize, headSize, static_cast(1), +<<<<<<< HEAD k_data + i * kStrideB + kv_j * kStrideH + +======= + k_data + i * kStrideB + j * kStrideH + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) n * kStrideN, kStrideN, q_data + i * qStrideB + j * qStrideH + @@ -701,8 +768,13 @@ void cpu_flash_attention( n > 0, qk_reduced_data, value_reorder_ptr + +<<<<<<< HEAD i * kv_num_head * kv_padding_size * headSize + kv_j * kv_padding_size * headSize + psize * headSize, +======= + i * num_head * kv_padding_size * headSize + + j * kv_padding_size * headSize + psize * headSize, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dst_data); } } else { @@ -713,7 +785,11 @@ void cpu_flash_attention( qBlockSize, kvBlockSize, static_cast(1), +<<<<<<< HEAD v_data + i * vStrideB + kv_j * vStrideH + +======= + v_data + i * vStrideB + j * vStrideH + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) n * vStrideN, vStrideN, conditional_data_ptr(qk_data, qk_reduced_data), @@ -778,15 +854,24 @@ void cpu_flash_attention_backward( // Sizes TORCH_CHECK((query.size(3) == value.size(3)) && (key.size(3) == value.size(3)), "scaled_dot_product_attention_flash_attention_backward: Q/K/V should have the same head size"); +<<<<<<< HEAD // Query (Batch x Q_seq_len x Num_heads x Dim_per_head) // Key (Batch x KV_seq_len x KV_num_heads x Dim_per_head) // Value (Batch x KV_seq_len x KV_num_heads x Dim_per_head) +======= + // Query (Batch x Q_seq_len x Num_heads x Dim_per_head) + // Key (Batch x KV_seq_len x Num_heads x Dim_per_head) + // Value (Batch x KV_seq_len x Num_heads x Dim_per_head) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int64_t batchSize = query.size(0); int64_t qSize = query.size(1); int64_t kvSize = value.size(1); int64_t num_head = query.size(2); +<<<<<<< HEAD int64_t kv_num_head = key.size(2); int64_t repeat_factor = num_head / kv_num_head; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int64_t headSize = query.size(3); bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel(); @@ -877,9 +962,15 @@ void cpu_flash_attention_backward( accum_t* buf_data = buf.data_ptr(); scalar_t* buf_reduced_data = is_reduced_type ? buf_reduced.data_ptr() : nullptr; +<<<<<<< HEAD at::parallel_for(0, batchSize * kv_num_head, 1, [&](int64_t begin, int64_t end) { int64_t i = 0, kv_j = 0; data_index_init(begin, i, batchSize, kv_j, kv_num_head); +======= + at::parallel_for(0, batchSize * num_head, 1, [&](int64_t begin, int64_t end) { + int64_t i = 0, j = 0; + data_index_init(begin, i, batchSize, j, num_head); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int ompIdx = at::get_thread_num(); accum_t* buf_ptr = buf_data + ompIdx * size_per_thread; accum_t* attn_data = buf_ptr; @@ -891,6 +982,7 @@ void cpu_flash_attention_backward( at::Tensor dsum = at::empty({qSplitSize}, query.options().dtype(accumulate_dtype)); accum_t* dsum_data = dsum.data_ptr(); for ([[maybe_unused]] auto z : c10::irange(begin, end)) { +<<<<<<< HEAD for (int64_t r = 0; r < repeat_factor; r++) { int64_t j = kv_j * repeat_factor + r; // rowsum of grad_out * out @@ -1084,6 +1176,198 @@ void cpu_flash_attention_backward( } // Move to the next query data_index_step(i, batchSize, kv_j, kv_num_head); +======= + // rowsum of grad_out * out + for (int64_t m = 0; m < qSize; m += qSplitSize) { + int64_t qBlockSize = std::min(qSplitSize, qSize - m); + // dsum <- rowsum(grad_out * out) + for (const auto row : c10::irange(qBlockSize)) { + *(dsum_data + row) = vec::map2_reduce_all( + [](Vec x, Vec y) { return x * y; }, + [](Vec x, Vec y) { return x + y; }, + grad_out_data + i * grad_oStrideB + j * grad_oStrideH + (m + row) * grad_oStrideM, + out_data + i * oStrideB + j * oStrideH + (m + row) * oStrideM, + headSize); + } + int64_t num_keys = is_causal ? std::min(m + qBlockSize, kvSize) : kvSize; + for (int64_t n = 0; n < num_keys; n += kvSplitSize) { + int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); + // attn <- scale * q @ k.T + cpublas::gemm( + TransposeType::Transpose, + TransposeType::NoTranspose, + kvBlockSize, + qBlockSize, + headSize, + scaling_factor, + k_data + i * kStrideB + j * kStrideH + + n * kStrideN, + kStrideN, + q_data + i * qStrideB + j * qStrideH + + m * qStrideM, + qStrideM, + static_cast(0), + attn_data, + kvBlockSize); + // attn <- attn + mask + if (has_attn_mask) { + accum_t one = accum_t(1); + for (const auto row : c10::irange(qBlockSize)) { +#if __GNUC__ == 11 && defined(__ARM_FEATURE_SVE) + _scale_attn_mask_fusion_kernel( + attn_data + row * kvBlockSize, + mask_data + i * mStrideB + j * mStrideH + + (m + row) * mStrideM + (mStrideN == 0 ? 0 : n), + kvBlockSize, + attn_data + row * kvBlockSize, + one, + mStrideN == 0); +#else + if (mStrideN == 0) { + _scale_attn_mask_fusion_kernel( + attn_data + row * kvBlockSize, + mask_data + i * mStrideB + j * mStrideH + + (m + row) * mStrideM, + kvBlockSize, + attn_data + row * kvBlockSize, + one); + } else { + _scale_attn_mask_fusion_kernel( + attn_data + row * kvBlockSize, + mask_data + i * mStrideB + j * mStrideH + + (m + row) * mStrideM + n, + kvBlockSize, + attn_data + row * kvBlockSize, + one); + } +#endif + } + } + // restore self attention after softmax from logsumexp + // attn <- exp(attn - normalizer) + for (const auto row : c10::irange(qBlockSize)) { + accum_t normalizer = lse_data[i * lStrideB + j * lStrideH + (m + row) * lStrideM]; + vec::map( + [normalizer](Vec x) { return (x - Vec(normalizer)).exp(); }, + attn_data + row * kvBlockSize, + attn_data + row * kvBlockSize, + kvBlockSize); + } + // Apply causal mask, filled unused with 0 + if (is_causal && num_keys - n <= kvSplitSize) { + for (const auto row : c10::irange(qBlockSize)) { + int64_t last_col = m + row - n; + accum_t* row_ptr = attn_data + row * kvBlockSize; + fill_stub(row_ptr + last_col + 1, static_cast(0), kvBlockSize - last_col - 1); + } + } +#ifdef _MSC_VER + if (is_reduced_type) { +#else + if constexpr (is_reduced_type) { +#endif + for (const auto row : c10::irange(qBlockSize)) { + convert( + attn_data + row * kvBlockSize, + attn_reduced_data + row * kvBlockSize, + kvBlockSize); + } + } + // grad_v <- grad_v + attn.T @ grad_out + cpublas::gemm( + TransposeType::NoTranspose, + TransposeType::Transpose, + headSize, + kvBlockSize, + qBlockSize, + static_cast(1), + grad_out_data + i * grad_oStrideB + j * grad_oStrideH + + m * grad_oStrideM, + grad_oStrideM, + conditional_data_ptr(attn_data, attn_reduced_data), + kvBlockSize, + static_cast(1), + grad_v_data + i * grad_vStrideB + j * grad_vStrideH + + n * grad_vStrideN, + grad_vStrideN); + // grad_attn <- grad_out @ v.T + cpublas::gemm( + TransposeType::Transpose, + TransposeType::NoTranspose, + kvBlockSize, + qBlockSize, + headSize, + static_cast(1), + v_data + i * vStrideB + j * vStrideH + + n * vStrideN, + vStrideN, + grad_out_data + i * grad_oStrideB + j * grad_oStrideH + + m * grad_oStrideM, + grad_oStrideM, + static_cast(0), + grad_attn_data, + kvBlockSize); + // grad_attn <- attn * (grad_attn - dsum) + for (const auto row : c10::irange(qBlockSize)) { + accum_t d = *(dsum_data + row); + vec::map2( + [d](Vec attn, Vec grad_attn) { return attn * (grad_attn - Vec(d)); }, + grad_attn_data + row * kvBlockSize, + attn_data + row * kvBlockSize, + grad_attn_data + row * kvBlockSize, + kvBlockSize); + } +#ifdef _MSC_VER + if (is_reduced_type) { +#else + if constexpr (is_reduced_type) { +#endif + for (const auto row : c10::irange(qBlockSize)) { + convert( + grad_attn_data + row * kvBlockSize, + grad_attn_reduced_data + row * kvBlockSize, + kvBlockSize); + } + } + // grad_q <- grad_q + scale * grad_attn @ k + cpublas::gemm( + TransposeType::NoTranspose, + TransposeType::NoTranspose, + headSize, + qBlockSize, + kvBlockSize, + scaling_factor, + k_data + i * kStrideB + j * kStrideH + + n * kStrideN, + kStrideN, + conditional_data_ptr(grad_attn_data, grad_attn_reduced_data), + kvBlockSize, + static_cast(1), + grad_q_data + i * grad_qStrideB + j * grad_qStrideH + + m * grad_qStrideM, + grad_qStrideM); + // grad_k <- grad_k + scale * grad_attn.T @ q + cpublas::gemm( + TransposeType::NoTranspose, + TransposeType::Transpose, + headSize, + kvBlockSize, + qBlockSize, + scaling_factor, + q_data + i * qStrideB + j * qStrideH + + m * qStrideM, + qStrideM, + conditional_data_ptr(grad_attn_data, grad_attn_reduced_data), + kvBlockSize, + static_cast(1), + grad_k_data + i * grad_kStrideB + j * grad_kStrideH + + n * grad_kStrideN, + grad_kStrideN); + } + } + // Move to the next query + data_index_step(i, batchSize, j, num_head); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } }); } diff --git a/aten/src/ATen/native/cpu/Loops.h b/aten/src/ATen/native/cpu/Loops.h index 83b51a9985637..088973da98a10 100644 --- a/aten/src/ATen/native/cpu/Loops.h +++ b/aten/src/ATen/native/cpu/Loops.h @@ -89,7 +89,11 @@ execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t using result_type = typename traits::result_type; for (; i < n; i++) { result_type* out_ptr = (result_type*)(data[0] + i * strides[0]); +<<<<<<< HEAD *out_ptr = std::apply(op, dereference( +======= + *out_ptr = c10::guts::apply(op, dereference( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) &data[1], &strides[1], i)); @@ -102,7 +106,11 @@ inline void execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t&& op) { using traits = function_traits; for (; i < n; i++) { +<<<<<<< HEAD std::apply(op, dereference( +======= + c10::guts::apply(op, dereference( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) &data[0], &strides[0], i)); @@ -162,7 +170,11 @@ void handle_tuple_outputs(char* C10_RESTRICT data[], } // Loop operation for `cpu_kernel_multiple_outputs`. +<<<<<<< HEAD // 1. Use `std::apply` to make dynamic method invocation +======= +// 1. Use `c10::guts::apply` to make dynamic method invocation +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // for the lambda passed in `cpu_kernel_multiple_outputs`. // 2. Iterate over the members of the returned tuple, set the corresponding // output tensor by the tuple member in `handle_tuple_outputs` function. @@ -183,7 +195,11 @@ multiple_outputs_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_ } for (; i < n; i++) { +<<<<<<< HEAD auto output = std::apply(op, dereference( +======= + auto output = c10::guts::apply(op, dereference( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) &data[num_outputs], &strides[num_outputs], i)); @@ -213,8 +229,13 @@ vectorized_loop(char** C10_RESTRICT data_, int64_t n, int64_t S, func_t&& op, ve for (; i <= n - 2 * Vec::size(); i += 2 * Vec::size()) { auto args1 = dereference_vec(&data[1], opt_scalar, S, i); auto args2 = dereference_vec(&data[1], opt_scalar, S, i + Vec::size()); +<<<<<<< HEAD auto out1 = std::apply(vop, std::move(args1)); auto out2 = std::apply(vop, std::move(args2)); +======= + auto out1 = c10::guts::apply(vop, std::move(args1)); + auto out2 = c10::guts::apply(vop, std::move(args2)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out1.store(data[0] + i * sizeof(scalar_t)); out2.store(data[0] + (i + Vec::size()) * sizeof(scalar_t)); } diff --git a/aten/src/ATen/native/cpu/PaddingKernel.cpp b/aten/src/ATen/native/cpu/PaddingKernel.cpp index 59d838b9782da..b4b655210a788 100644 --- a/aten/src/ATen/native/cpu/PaddingKernel.cpp +++ b/aten/src/ATen/native/cpu/PaddingKernel.cpp @@ -156,7 +156,11 @@ void cpu_padding( int64_t offset_h = ndim >= 2 ? p.offsets[ndim - 2] : 0; int64_t offset_w = p.offsets[ndim - 1]; +<<<<<<< HEAD // do vectorized copy when output is overlapped with input on W, +======= + // do vectorized copy whe output is overlapped with input on W, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // only applies to positive padding auto loop = [=](scalar_t* out, const scalar_t* in, bool positive_padding) { if (positive_padding) { diff --git a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp index c06731dfc718c..be89b6326c393 100644 --- a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp @@ -425,6 +425,7 @@ static void argmin_kernel_impl(TensorIterator &iter) { }); } +<<<<<<< HEAD template struct XorSumOps { inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const { @@ -468,6 +469,8 @@ static void xor_sum_kernel_impl(TensorIterator& iter) { }); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // anonymous namespace REGISTER_DISPATCH(std_var_stub, &std_var_kernel_impl) @@ -482,7 +485,10 @@ REGISTER_DISPATCH(min_values_stub, &min_values_kernel_impl) REGISTER_DISPATCH(max_values_stub, &max_values_kernel_impl) REGISTER_DISPATCH(argmax_stub, &argmax_kernel_impl) REGISTER_DISPATCH(argmin_stub, &argmin_kernel_impl) +<<<<<<< HEAD REGISTER_DISPATCH(xor_sum_stub, &xor_sum_kernel_impl) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) REGISTER_DISPATCH(cumprod_stub, &cumprod_cpu_kernel) REGISTER_DISPATCH(cumsum_stub, &cumsum_cpu_kernel) diff --git a/aten/src/ATen/native/cpu/SoftMaxKernel.cpp b/aten/src/ATen/native/cpu/SoftMaxKernel.cpp index dac0f3bef25ee..4d10c0c42b57f 100644 --- a/aten/src/ATen/native/cpu/SoftMaxKernel.cpp +++ b/aten/src/ATen/native/cpu/SoftMaxKernel.cpp @@ -7,7 +7,10 @@ #include #include #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include @@ -648,10 +651,17 @@ _vec_softmax( parallel_for( 0, outer_size * inner_size, 0, [&](int64_t begin, int64_t end) { int64_t idx = begin; +<<<<<<< HEAD std::vector temp_vec_input(dim_size * vectorized_step); std::vector temp_vec_output(dim_size * vectorized_step); float* temp_vec_input_data = temp_vec_input.data(); float* temp_vec_output_data = temp_vec_output.data(); +======= + std::unique_ptr temp_vec_input(new float[dim_size*vectorized_step]()); + std::unique_ptr temp_vec_output(new float[dim_size*vectorized_step]()); + float* temp_vec_input_data = temp_vec_input.get(); + float* temp_vec_output_data = temp_vec_output.get(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) while (idx < end) { int64_t outer_idx = idx / inner_size; int64_t inner_idx = idx % inner_size; diff --git a/aten/src/ATen/native/cpu/Unfold2d.cpp b/aten/src/ATen/native/cpu/Unfold2d.cpp index 8c94decfff023..14020dfc3c07b 100644 --- a/aten/src/ATen/native/cpu/Unfold2d.cpp +++ b/aten/src/ATen/native/cpu/Unfold2d.cpp @@ -169,10 +169,13 @@ static void unfolded2d_acc_channels_last( /* note: due to write issues, this one cannot be parallelized as well as * unfolded2d_copy */ +<<<<<<< HEAD #if defined(__GNUC__) && __GNUC__ == 14 && defined(__ARM_FEATURE_SVE) && !defined(__ARM_FEATURE_BF16) // Workaround for gcc-14.2.0 ICE during RTL pass: vregs when compiling for SVE without BF16 __attribute__((optimize("no-tree-vectorize"))) #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void unfolded2d_acc_kernel( ScalarType dtype, void *finput_data, diff --git a/aten/src/ATen/native/cpu/batch_norm_kernel.cpp b/aten/src/ATen/native/cpu/batch_norm_kernel.cpp index d013dfa0485e0..a471ff71cce07 100644 --- a/aten/src/ATen/native/cpu/batch_norm_kernel.cpp +++ b/aten/src/ATen/native/cpu/batch_norm_kernel.cpp @@ -318,7 +318,11 @@ batch_norm_cpu_collect_stats_channels_last_impl( // // The optimal THRESHOLD to tile was found empirically. // When C > THRESHOLD, C is large enough that the benefit from tiling and vectorization outweigh the synchronization overhead. +<<<<<<< HEAD // When C <= TILE_SIZE, the problem size is small enough (C <= TILE_SIZE && NHW <= max_threads) that it's better to launch single thread with vectorization than C threads without vectorization. +======= + // Wehn C <= TILE_SIZE, the problem size is small enough (C <= TILE_SIZE && NHW <= max_threads) that it's better to launch single thread with vectorization than C threads without vectorization. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // // When num_threads == 1, always use Method 2 as there is no synchronization overhead. // diff --git a/aten/src/ATen/native/cpu/int8mm_kernel.cpp b/aten/src/ATen/native/cpu/int8mm_kernel.cpp index 7e2cba98ff1d7..4c8aebe40d8ac 100644 --- a/aten/src/ATen/native/cpu/int8mm_kernel.cpp +++ b/aten/src/ATen/native/cpu/int8mm_kernel.cpp @@ -367,6 +367,7 @@ void int8pack_mm_kernel_( auto* C_data = C.data_ptr(); const auto* S_data = scales.const_data_ptr(); +<<<<<<< HEAD int64_t M = A.size(0); int64_t N = B.size(0); int64_t K = A.size(1); @@ -379,15 +380,36 @@ void int8pack_mm_kernel_( at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { int64_t mb{0}, nb{0}; +======= + int M = A.size(0); + int N = B.size(0); + int K = A.size(1); + int lda = A.stride(0); + constexpr int BLOCK_M = 4; + constexpr int BLOCK_N = 4; + + const int MB = (M + BLOCK_M - 1) / BLOCK_M; + const int NB = (N + BLOCK_N - 1) / BLOCK_N; + + at::parallel_for(0, MB * NB, 0, [&](int begin, int end) { + int mb{0}, nb{0}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) data_index_init(begin, mb, MB, nb, NB); for (const auto i : c10::irange(begin, end)) { (void)i; +<<<<<<< HEAD int64_t mb_start = mb * BLOCK_M; int64_t mb_size = std::min(BLOCK_M, M - mb_start); int64_t nb_start = nb * BLOCK_N; int64_t nb_size = std::min(BLOCK_N, N - nb_start); +======= + int mb_start = mb * BLOCK_M; + int mb_size = std::min(BLOCK_M, M - mb_start); + int nb_start = nb * BLOCK_N; + int nb_size = std::min(BLOCK_N, N - nb_start); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const auto* A_ptr = A_data + mb_start * lda; const auto* B_ptr = B_data + nb_start * K; diff --git a/aten/src/ATen/native/cpu/moments_utils.h b/aten/src/ATen/native/cpu/moments_utils.h index 8aba425e89637..d52bc276fbbff 100644 --- a/aten/src/ATen/native/cpu/moments_utils.h +++ b/aten/src/ATen/native/cpu/moments_utils.h @@ -8,6 +8,10 @@ #include #include #include +<<<<<<< HEAD +======= +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include namespace at::native { @@ -117,11 +121,17 @@ std::pair, opmath_t> RowwiseMomentsImpl(const T* X, int64_t N, in using Vec = vec::Vectorized; const Vec kZeroVec(math_t(0)); +<<<<<<< HEAD std::array m0_stk = {{0}}; std::array m1_stk; m1_stk.fill(kZeroVec); std::array m2_stk; m2_stk.fill(kZeroVec); +======= + c10::SmallVector m0_stk(depth, 0); + c10::SmallVector m1_stk(depth, kZeroVec); + c10::SmallVector m2_stk(depth, kZeroVec); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (const auto i : c10::irange(m)) { const T* X_ptr = X + i * kChunkSize * kVecSize; diff --git a/aten/src/ATen/native/cpu/utils.h b/aten/src/ATen/native/cpu/utils.h index 827c69629eb37..968835d1874d2 100644 --- a/aten/src/ATen/native/cpu/utils.h +++ b/aten/src/ATen/native/cpu/utils.h @@ -6,9 +6,13 @@ #include #ifdef USE_FBGEMM +<<<<<<< HEAD C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include C10_DIAGNOSTIC_POP() +======= +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif namespace at::native { @@ -167,12 +171,15 @@ inline void transpose(int64_t M, int64_t N, const uint16_t* src, int64 TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM."); fbgemm::transpose_simd(M, N, src, ld_src, dst, ld_dst); } +<<<<<<< HEAD template <> inline void transpose(int64_t M, int64_t N, const uint8_t* src, int64_t ld_src, uint8_t* dst, int64_t ld_dst) { TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM."); fbgemm::transpose_simd(M, N, src, ld_src, dst, ld_dst); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif template diff --git a/aten/src/ATen/native/cuda/AdaptiveAveragePooling.cu b/aten/src/ATen/native/cuda/AdaptiveAveragePooling.cu index 47c705a667b52..2ac0cbed4d2d6 100644 --- a/aten/src/ATen/native/cuda/AdaptiveAveragePooling.cu +++ b/aten/src/ATen/native/cuda/AdaptiveAveragePooling.cu @@ -526,7 +526,11 @@ namespace { // we are dealing with packed tensor here. max index is the same as numel. +<<<<<<< HEAD // TODO: to really support input tensor large enough to go beyond int32, +======= + // TODO: to really support input tensor large enought to go beyond int32, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // we will need to restrict out shared memory usage and adjust the launch // config; AT_ASSERT(input_.numel() < std::numeric_limits::max()); @@ -681,7 +685,11 @@ namespace { const dim3 grid(grid_x, grid_y, grid_z); // we are dealing with packed tensor here. max index is the same as numel. +<<<<<<< HEAD // TODO: to really support input tensor large enough to go beyond int32, +======= + // TODO: to really support input tensor large enought to go beyond int32, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // we will need to restrict out shared memory usage and adjust the launch // config; AT_ASSERT(input.numel() < std::numeric_limits::max()); diff --git a/aten/src/ATen/native/cuda/AdaptiveAveragePooling3d.cu b/aten/src/ATen/native/cuda/AdaptiveAveragePooling3d.cu index d9a0b0059917f..265b74036e321 100644 --- a/aten/src/ATen/native/cuda/AdaptiveAveragePooling3d.cu +++ b/aten/src/ATen/native/cuda/AdaptiveAveragePooling3d.cu @@ -53,7 +53,11 @@ __global__ void adaptiveaveragepool( const scalar_t *input, scalar_t *output, int isizeT, int isizeH, int isizeW, int osizeT, int osizeH, int osizeW, +<<<<<<< HEAD int64_t sizeD, int64_t istrideB, int64_t istrideD, +======= + int64_t istrideD, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int64_t istrideT, int64_t istrideH, int64_t istrideW, int64_t offsetZ) { // iterates on output pixels @@ -70,17 +74,26 @@ __global__ void adaptiveaveragepool( // select output plane int64_t o_plane = blockIdx.x + offsetZ; ot = o_plane % osizeT; // output frame/time +<<<<<<< HEAD int d = o_plane / osizeT; // flattened (batch, channel) index // Decompose d into batch and channel indices int batch_idx = d / sizeD; int channel_idx = d % sizeD; +======= + int d = o_plane / osizeT; // slice/feature +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // input frame/time range is fixed. int istartT = start_index(ot, osizeT, isizeT); int iendT = end_index(ot, osizeT, isizeT); int kT = iendT - istartT; +<<<<<<< HEAD +======= + // input offset by slice/feature and earliest relevant frame/time + const scalar_t *input_dt = input + d*istrideD + istartT*istrideT; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // output offset by slice/feature and frame/time scalar_t *output_dt = output + o_plane*osizeH*osizeW; @@ -95,6 +108,11 @@ __global__ void adaptiveaveragepool( int iendW = end_index(ow, osizeW, isizeW); int kW = iendW - istartW; +<<<<<<< HEAD +======= + // Compute the average pooling from corresponding input pixels + const scalar_t *ptr_input = input_dt + istartH*istrideH + istartW*istrideW; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) scalar_t *ptr_output = output_dt + oh*osizeW + ow; accscalar_t sum = static_cast(0); @@ -102,6 +120,7 @@ __global__ void adaptiveaveragepool( for (it = 0; it < kT; ++it) { for (ih = 0; ih < kH; ++ih) { for (iw = 0; iw < kW; ++iw) { +<<<<<<< HEAD int64_t input_offset = batch_idx * istrideB + channel_idx * istrideD + (istartT + it) * istrideT + (istartH + ih) * istrideH + (istartW + iw) * istrideW; @@ -109,6 +128,13 @@ __global__ void adaptiveaveragepool( sum += static_cast(val); } } +======= + scalar_t val = ptr_input[ih*istrideH + iw*istrideW]; + sum += static_cast(val); + } + } + ptr_input += istrideT; // next input frame +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // Update output const accscalar_t divide_factor = static_cast(kT * kH * kW); @@ -123,7 +149,11 @@ void adaptiveaveragepool_loop( int64_t totalZ, int isizeT, int isizeH, int isizeW, int osizeT, int osizeH, int osizeW, +<<<<<<< HEAD int64_t sizeD, int64_t istrideB, int64_t istrideD, int64_t istrideT, int64_t istrideH, int64_t istrideW) { +======= + int64_t istrideD, int64_t istrideT, int64_t istrideH, int64_t istrideW) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int64_t offsetZ = 0; dim3 threads(32, 8); // each H*W plane is processed by blocksH thread blocks @@ -135,7 +165,11 @@ void adaptiveaveragepool_loop( input_data, output_data, isizeT, isizeH, isizeW, osizeT, osizeH, osizeW, +<<<<<<< HEAD sizeD, istrideB, istrideD, +======= + istrideD, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) istrideT, istrideH, istrideW, offsetZ); C10_CUDA_KERNEL_LAUNCH_CHECK(); @@ -366,7 +400,11 @@ void adaptive_avg_pool3d_out_cuda_template( int64_t osizeW = output_size[2]; int64_t sizeD, isizeT, isizeH, isizeW; +<<<<<<< HEAD int64_t istrideB, istrideD, istrideT, istrideH, istrideW; +======= + int64_t istrideD, istrideT, istrideH, istrideW; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int64_t totalZ; const Tensor& input = input_.ndimension() == 4 ? input_ : input_.contiguous(); @@ -377,7 +415,10 @@ void adaptive_avg_pool3d_out_cuda_template( isizeH = input.size(2); isizeW = input.size(3); +<<<<<<< HEAD istrideB = 0; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) istrideD = input.stride(0); istrideT = input.stride(1); istrideH = input.stride(2); @@ -393,7 +434,10 @@ void adaptive_avg_pool3d_out_cuda_template( isizeH = input.size(3); isizeW = input.size(4); +<<<<<<< HEAD istrideB = input.stride(0); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) istrideD = input.stride(1); istrideT = input.stride(2); istrideH = input.stride(3); @@ -419,7 +463,11 @@ void adaptive_avg_pool3d_out_cuda_template( totalZ, isizeT, isizeH, isizeW, osizeT, osizeH, osizeW, +<<<<<<< HEAD sizeD, istrideB, istrideD, istrideT, istrideH, istrideW); +======= + istrideD, istrideT, istrideH, istrideW); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }); } diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index 23447c7e09b3f..f2408c54f3e76 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -16,6 +16,7 @@ #include #include #include +<<<<<<< HEAD #include #include #include @@ -25,6 +26,11 @@ #ifdef USE_FBGEMM_GENAI #include #endif +======= +#include +#include +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #ifndef AT_PER_OPERATOR_HEADERS #include @@ -105,7 +111,10 @@ c10::MaybeOwned inline prepare_matrix_for_cublas(const Tensor& tensor, b } } +<<<<<<< HEAD using at::cuda::blas::ScalingType; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) /** * @brief Prepares matrices for CUBLAS operation @@ -147,9 +156,13 @@ struct cublasCommonArgs { Tensor& c, const std::optional& scale_a = std::nullopt, const std::optional& scale_b = std::nullopt, +<<<<<<< HEAD const std::optional& scale_result = std::nullopt, const std::optional& scaling_choice_a = std::nullopt, const std::optional& scaling_choice_b = std::nullopt) { +======= + const std::optional& scale_result = std::nullopt) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool transpose_result = false, transpose_a = false, transpose_b = false; result = prepare_matrix_for_cublas(c, transpose_result); mata = prepare_matrix_for_cublas(transpose_result ? mat2 : mat1, transpose_a, transpose_result); @@ -161,10 +174,15 @@ struct cublasCommonArgs { // as B.T @ A.T, check transpose_result to determine if we flip the scales scale_mata_ptr = transpose_result ? scale_b->data_ptr() : scale_a->data_ptr(); scale_mata_dtype = transpose_result ? scale_b->scalar_type() : scale_a->scalar_type(); +<<<<<<< HEAD scaling_mata_type = transpose_result ? scaling_choice_b : scaling_choice_a; scale_matb_ptr = transpose_result ? scale_a->data_ptr() : scale_b->data_ptr(); scale_matb_dtype = transpose_result ? scale_a->scalar_type() : scale_b->scalar_type(); scaling_matb_type = transpose_result ? scaling_choice_a : scaling_choice_b; +======= + scale_matb_ptr = transpose_result ? scale_a->data_ptr() : scale_b->data_ptr(); + scale_matb_dtype = transpose_result ? scale_a->scalar_type() : scale_b->scalar_type(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } if (scale_result) { @@ -210,9 +228,13 @@ struct cublasCommonArgs { void* scale_matb_ptr = nullptr; void* scale_result_ptr = nullptr; std::optional scale_mata_dtype; +<<<<<<< HEAD std::optional scaling_mata_type; std::optional scale_matb_dtype; std::optional scaling_matb_type; +======= + std::optional scale_matb_dtype; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::optional scale_result_dtype; }; } // namespace @@ -283,7 +305,14 @@ static bool isSupportedHipLtROCmArch(int index) { static const std::vector archs = { "gfx90a", "gfx942", #if ROCM_VERSION >= 60300 +<<<<<<< HEAD "gfx1100", "gfx1101", "gfx1200", "gfx1201", "gfx908", +======= + "gfx1100", "gfx1101", "gfx1200", "gfx1201", +#endif +#if ROCM_VERSION >= 60402 + "gfx1150", "gfx1151", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif #if ROCM_VERSION >= 60500 "gfx950" @@ -1057,7 +1086,11 @@ Tensor _int_mm_cuda(const Tensor& self, const Tensor& mat2) { return _int_mm_out_cuda(self, mat2, result); } +<<<<<<< HEAD static bool _scaled_mm_allowed_device(bool sm90_only=false, bool sm100_only=false) { +======= +static bool _scaled_mm_allowed_device(bool sm90_only=false) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #ifdef USE_ROCM static const std::vector archs = { "gfx942", @@ -1071,15 +1104,21 @@ static bool _scaled_mm_allowed_device(bool sm90_only=false, bool sm100_only=fals return at::detail::getCUDAHooks().isGPUArch(archs); #else auto dprops = at::cuda::getCurrentDeviceProperties(); +<<<<<<< HEAD if (sm90_only || sm100_only) { return (sm90_only && dprops->major == 9) || (sm100_only && dprops->major == 10); +======= + if (sm90_only) { + return dprops->major == 9; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else { return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9); } #endif } +<<<<<<< HEAD static bool _grouped_mm_allowed_device() { #ifdef USE_ROCM return false; @@ -1090,6 +1129,8 @@ static bool _grouped_mm_allowed_device() { #endif } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #ifdef USE_ROCM static bool _scaled_mm_is_fnuz() { return at::detail::getCUDAHooks().isGPUArch({"gfx942"}); @@ -1098,11 +1139,21 @@ static bool _scaled_mm_is_fnuz() { namespace{ +<<<<<<< HEAD +======= +enum class ScalingType : std::uint8_t { + TensorWise, + RowWise, + BlockWise, + Error +}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) /* * Scaling Type Determination: * --------------------------- * Conditions and corresponding Scaling Types: * +<<<<<<< HEAD * - If scale tensor is `Float8_e8m0fnu` or `Float8_e4m3fn`: * - Returns BlockWise (with additional size checks). * @@ -1118,10 +1169,22 @@ namespace{ * - Else if scale.dim() == 2 && scale.size(0) == outer_dim / 128 && scale.size(1) == inner_dim / 128: * - Returns BlockWise 128x128. * +======= + * - If scale tensors are both `Float8_e8m0fnu` or `Float8_e4m3fn`: + * - Returns BlockWise (with additional size checks). + * + * - If scale_a.numel() == 1 && scale_b.numel() == 1: + * - Returns TensorWise. + * + * - Else if scale_a.dim() == 2 && scale_a.size(0) == dim_m && scale_b.size(0) == dim_n: + * - Returns RowWise. + * +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) * - Otherwise: * - Returns Error. */ +<<<<<<< HEAD using at::cuda::blas::ScalingType; bool is_tensorwise_scaling(const at::Tensor& t, const at::Tensor& scale) { @@ -1211,10 +1274,122 @@ std::pair get_joint_scaling( "Got a.dtype()=", a.scalar_type(), ", scale_a.dtype()=", scale_a.scalar_type(), ", scale_a.size()=", scale_a.sizes(), ", scale_a.stride()=", scale_a.strides(), ", ", "b.dtype()=", b.scalar_type(), ", scale_b.dtype()=", scale_b.scalar_type(), ", scale_b.size()=", scale_b.sizes(), " and scale_b.stride()=", scale_b.strides() ); +======= +// Validates the scale tensors to scaled_mm +// And returns the type of scaling/which kernel to use +ScalingType get_scaling_type( + const at::Tensor& scale_a, + const at::Tensor& scale_b, + int64_t dim_m, + int64_t dim_k, + int64_t dim_n) { + // Check for BlockWise scaling (FP8_E8M0 and FP8_E4M3 types) + if ((scale_a.scalar_type() == scale_b.scalar_type()) && + ((scale_a.scalar_type() == at::kFloat8_e8m0fnu) || (scale_a.scalar_type() == at::kFloat8_e4m3fn))) { + const bool is_nvfp4 = scale_a.scalar_type() == at::kFloat8_e4m3fn; + + // cuBLAS's mxfp8 gemm: block_size is 1 scale per 32 elements + // cuBLAS's nvfp4 gemm: block_size is 1 scale per 16 unpacked elements. + const auto BLOCK_SIZE_K = is_nvfp4 ? 16 : 32; + + constexpr int64_t BLOCK_SIZE_MN = 128; + + // adjust for fp4x2 packing if necessary + const auto dim_k_unpacked = is_nvfp4 ? dim_k * 2 : dim_k; + + auto ceil_div = [](auto a, auto b) { return (a + b - 1) / b; }; + auto num_k_blocks = ceil_div(dim_k_unpacked, BLOCK_SIZE_K); + auto padded_num_k_blocks = ceil_div(num_k_blocks, 4) * 4; + + // TODO: We might want to enforce some structure on the shapes of the scale + // tensors + + // Check expected sizes for block-wise scaling + auto expected_a_size = + BLOCK_SIZE_MN * ceil_div(dim_m, BLOCK_SIZE_MN) * padded_num_k_blocks; + auto expected_b_size = + BLOCK_SIZE_MN * ceil_div(dim_n, BLOCK_SIZE_MN) * padded_num_k_blocks; + + //TODO: enable the checks for ROCm +#ifndef USE_ROCM + TORCH_CHECK(scale_a.numel() == expected_a_size, + "For BlockWise scaling: Expected scale_a size to be ", + expected_a_size, " but got ", scale_a.numel()); + TORCH_CHECK(scale_b.numel() == expected_b_size, + "For BlockWise scaling: Expected scale_b size to be ", + expected_b_size, " but got ", scale_b.numel()); +#endif + + TORCH_CHECK( + scale_a.is_contiguous() && scale_b.is_contiguous(), + "For BlockWise scaling: Both scale_a and scale_b must be contiguous"); + + return ScalingType::BlockWise; + } + // Both Per-Tensor and Row-wise scaling expect fp32 tensors + TORCH_CHECK( + scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat, + "Both scale_a and scale_b must be float (fp32) tensors."); + + // Check the singluar scale case for per-tensor scaling + if (scale_a.numel() == 1 && scale_b.numel() == 1) { + return ScalingType::TensorWise; + } + + // For non-TensorWise scaling, enforce 2D input tensors + TORCH_CHECK( + scale_a.dim() == 2 && scale_b.dim() == 2, + "For non-TensorWise scaling, scale tensors must be 2-dimensional, " + "but got scale_a.dim()=", + scale_a.dim(), + " and scale_b.dim()=", + scale_b.dim()); + + // Check for RowWise scaling + if (scale_a.size(0) == dim_m && scale_a.size(1) == 1 && + scale_b.size(0) == 1 && scale_b.size(1) == dim_n) { +#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || \ + (defined(USE_ROCM) && (defined(HIPBLASLT_VEC_EXT) || defined(HIPBLASLT_OUTER_VEC))) + TORCH_CHECK( + scale_a.is_contiguous() && scale_b.is_contiguous(), + "Both scale_a and scale_b must be contiguous for RowWise scaling."); + return ScalingType::RowWise; +#else + TORCH_CHECK(false, "Per-row scaling is not supported for this platform!"); + return ScalingType::Error; +#endif + } + + // If we reach here, the input doesn't match any valid scaling type + TORCH_CHECK( + false, + "Invalid scaling configuration. For TensorWise scaling, both scales should be scalar. " + "For RowWise scaling, scale_a should be (", + dim_m, + ", 1) and scale_b should be (1, ", + dim_n, + "). " + "Got scale_a.size()=(", + scale_a.size(0), + ", ", + scale_a.size(1), + ") and ", + "scale_b.size()=(", + scale_b.size(0), + ", ", + scale_b.size(1), + ")"); + + return ScalingType::Error; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } // namespace +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Computes matrix multiply + bias while applying scaling to input and output matrices // Scales are only applicable when matrices are of Float8 type and assumed to be equal to 1.0 by default. // If output matrix type is 16 or 32-bit type, scale_result is not applied. @@ -1228,10 +1403,17 @@ std::pair get_joint_scaling( // - `mat2`: the second operand of the matrix multiply, can be type `torch.float8_e4m3fn` or `torch.float8_e5m2` // - `bias`: the bias, can be type `torch.float16` or `torch.bfloat16` // - `out_dtype`: the output dtype, can either be a float8 or a higher precision floating point type +<<<<<<< HEAD // - `scale_a`: a tensor with the inverse scale of `mat1`, whose shape/strides/dtype depend on the scaling scheme // - `scale_b`: a tensor with the inverse scale of `mat2`, whose shape/strides/dtype depend on the scaling scheme // - `scale_result`: a scalar tensor with the scale of the output, only utilized if the output is a float8 type // - `use_fast_accum`: if true, enables fast float8 accumulation. Backends may ignore this option if not applicable. +======= +// - `scale_a`: a scalar or 1-dimensional tensor with the inverse scale of `mat1`, only needed if `mat1` is a float8 type +// - `scale_b`: a scalar or 1-dimensional tensor with the inverse scale of `mat2`, only needed if `mat2` is a float8 type +// - `scale_result`: a scalar tensor with the scale of the output, only utilized if the output is a float8 type +// - `use_fast_accum`: if true, enables fast float8 accumulation +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // - `out`: a reference to the output tensor Tensor& @@ -1252,6 +1434,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (", mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")"); +<<<<<<< HEAD // Check what type of scaling we are doing based on inputs. This list is sorted // by decreasing priority. We prefer "simpler" schemes as they are supported // more broadly (more GPU archs, more CUDA versions) and because they are more @@ -1267,6 +1450,11 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, std::make_pair(ScalingType::BlockWise1x16, ScalingType::BlockWise1x16) }, mat1, mat2, scale_a, scale_b); +======= + // Check what type of scaling we are doing based on inputs + ScalingType scaling_choice = get_scaling_type(scale_a, scale_b, mat1.size(0), mat1.size(1), mat2.size(1)); + TORCH_INTERNAL_ASSERT(scaling_choice != ScalingType::Error, "Scaling type not supported"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_CHECK(!scale_result || (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat), "scale_result must be a float scalar"); @@ -1291,6 +1479,17 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, TORCH_CHECK(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2, "Multiplication of two Float8_e5m2 matrices is not supported"); #endif +<<<<<<< HEAD +======= +#ifdef USE_ROCM + if (mat1.scalar_type() == ScalarType::Float8_e5m2 || mat2.scalar_type() == ScalarType::Float8_e5m2) { + TORCH_CHECK(ROCM_VERSION >= 60000, "Float8_e5m2 is only supported for ROCm 6.0 and above"); + } + if (mat1.scalar_type() == ScalarType::Float8_e4m3fn || mat2.scalar_type() == ScalarType::Float8_e4m3fn) { + TORCH_CHECK(ROCM_VERSION >= 60000, "Float8_e4m3fn is only supported for ROCm 6.0 and above"); + } +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (use_fast_accum) { TORCH_CHECK(mat1.scalar_type() != ScalarType::Float4_e2m1fn_x2 && mat2.scalar_type() != ScalarType::Float4_e2m1fn_x2, "`use_fast_accum` is not supported when `mat1` or `mat2` tensors have the `Float4_e2m1fn_x2` dtype."); } @@ -1299,6 +1498,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, TORCH_CHECK(ROCM_VERSION >= 70000, "Float4_e2m1fn_x2 is only supported for ROCm 7.0 and above"); } if (mat1.scalar_type() == ScalarType::Float8_e5m2 || mat2.scalar_type() == ScalarType::Float8_e5m2) { +<<<<<<< HEAD TORCH_CHECK(ROCM_VERSION >= 60500, "Float8_e5m2 is only supported for ROCm 6.5 and above"); } if (mat1.scalar_type() == ScalarType::Float8_e4m3fn || mat2.scalar_type() == ScalarType::Float8_e4m3fn) { @@ -1323,6 +1523,23 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, bias->scalar_type() == ScalarType::Half, "Bias must be Float16 to compute ", out.scalar_type(), " output, but got ", bias->scalar_type()); +======= + TORCH_CHECK(ROCM_VERSION >= 70000, "Float8_e5m2 is only supported for ROCm 7.0 and above"); + } + if (mat1.scalar_type() == ScalarType::Float8_e4m3fn || mat2.scalar_type() == ScalarType::Float8_e4m3fn) { + TORCH_CHECK(ROCM_VERSION >= 70000, "Float8_e4m3fn is only supported for ROCm 7.0 and above"); + } +#endif + if (bias) { + TORCH_CHECK(out.scalar_type() != kFloat, "Bias is not supported when out_dtype is set to Float32"); + TORCH_CHECK(bias->scalar_type() == ScalarType::BFloat16 || bias->scalar_type() == ScalarType::Half, + "Bias must be either Half or BFloat16, but got ", bias->scalar_type()); + TORCH_CHECK((out.scalar_type() != kFloat && out.scalar_type() != ScalarType::BFloat16) || + bias->scalar_type() == ScalarType::BFloat16, + "Bias must be BFloat16 to compute ", out.scalar_type(), " output, but got ", bias->scalar_type()); + TORCH_CHECK(out.scalar_type() != ScalarType::Half || bias->scalar_type() == ScalarType::Half, + "Bias must be Float16 to compute ", out.scalar_type(), " output, but got ", bias->scalar_type()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } { auto bias_ = bias.value_or(Tensor()); @@ -1352,6 +1569,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, return out; } +<<<<<<< HEAD // NVIDIA's cuBLAS only started supporting row-wise scaling in version 12.9, // and only for compute capability 9.0+. In other cases we use CUTLASS. #ifndef USE_ROCM @@ -1361,6 +1579,12 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, && ((dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900) // cuBLAS only supports tiled 1D factor layout for 1D block scaling, no 2D block scales || (dprops->major >= 10 && (scale_a.sizes().size() || scale_b.sizes().size())))) { +======= + // ROCm's hipblaslt supports rowwise, so skip this check that sends this to cutlass. +#ifndef USE_ROCM + // We are doing row-wise scaling + if (scaling_choice == ScalingType::RowWise) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling."); at::cuda::detail::f8f8bf16_rowwise( mat1, @@ -1373,8 +1597,13 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, return out; } #else +<<<<<<< HEAD if (scaling_choice_a == ScalingType::RowWise && scaling_choice_b == ScalingType::RowWise) { // For ROCm, match behavior of f8f8bf16_rowwise type checking, for unit test purposes. +======= + if (scaling_choice == ScalingType::RowWise) { + // For ROCm, match behavior of f8f8bf16_rowwise type checking +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Tensor b = mat2; if (_scaled_mm_is_fnuz()) { TORCH_CHECK(b.dtype() == at::kFloat8_e4m3fnuz); @@ -1382,6 +1611,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, else { TORCH_CHECK(b.dtype() == at::kFloat8_e4m3fn); } +<<<<<<< HEAD // Until more than bf16 is supported. TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16, "hipblaslt rowwise _scaled_mm only supports BFloat16 output but got ", out.scalar_type()); @@ -1394,6 +1624,20 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, TORCH_CHECK(mat1.size(0) % 32 == 0 && mat1.size(1) % 32 == 0 && mat2.size(0) % 32 == 0 && mat2.size(1) % 32 == 0, "Matrix dimensions must be multiples of 32 for block-wise scaling"); +======= + // Until more than bf16 is supported + TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16, + "hipblaslt rowwise _scaled_mm only supports BFloat16 output"); + } + else if (scaling_choice == ScalingType::BlockWise) { +#if ROCM_VERSION >= 70000 + TORCH_CHECK(at::detail::getCUDAHooks().isGPUArch({"gfx950"}, 0), + "Block-wise scaling for Float8_e8m0fnu is only supported on gfx950"); + + TORCH_CHECK(mat1.size(0) % 32 == 0 && mat1.size(1) % 32 == 0 && + mat2.size(0) % 32 == 0 && mat2.size(1) % 32 == 0, + "Matrix dimensions must be multiples of 32 for block-wise scaling"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16 || out.scalar_type() == ScalarType::Half, @@ -1404,7 +1648,11 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, } #endif +<<<<<<< HEAD cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, scale_result, scaling_choice_a, scaling_choice_b); +======= + cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, scale_result); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const auto out_dtype_ = args.result->scalar_type(); TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt"); @@ -1479,6 +1727,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, params.k = args.k; params.a = args.mata->data_ptr(); params.a_scale_ptr = args.scale_mata_ptr; +<<<<<<< HEAD params.a_scale_dtype = args.scale_mata_dtype.value(); params.lda = args.lda; params.a_dtype = args.mata->scalar_type(); @@ -1491,6 +1740,16 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, params.b_dtype = args.matb->scalar_type(); params.b_scale_dtype = args.scale_matb_dtype.value(); params.b_scaling_type = args.scaling_matb_type.value(); +======= + params.a_scale_dtype = scale_a.scalar_type(); + params.lda = args.lda; + params.a_dtype = args.mata->scalar_type(); + params.b = args.matb->data_ptr(); + params.b_scale_ptr = args.scale_matb_ptr; + params.b_scale_dtype = scale_b.scalar_type(); + params.ldb = args.ldb; + params.b_dtype = args.matb->scalar_type(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) params.bias_ptr = bias ? bias->data_ptr(): nullptr; params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_; params.c = args.result->data_ptr(); @@ -1498,6 +1757,10 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, params.ldc = args.result_ld; params.c_dtype = out_dtype_; params.use_fast_accum = use_fast_accum; +<<<<<<< HEAD +======= + params.use_rowwise = scaling_choice == ScalingType::RowWise; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (transa_ && transb_) { TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T) } @@ -1531,28 +1794,105 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, args.lda, args.mata->scalar_type(), args.scale_mata_dtype.value(), +<<<<<<< HEAD args.scaling_mata_type.value(), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) args.matb->data_ptr(), args.scale_matb_ptr, args.ldb, args.matb->scalar_type(), args.scale_matb_dtype.value(), +<<<<<<< HEAD args.scaling_matb_type.value(), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bias ? bias->data_ptr(): nullptr, bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_, args.result->data_ptr(), args.scale_result_ptr, args.result_ld, out_dtype_, +<<<<<<< HEAD use_fast_accum); +======= + use_fast_accum, + scaling_choice == ScalingType::RowWise); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } return out; } namespace { +<<<<<<< HEAD void _check_scales_fp8_rowwise(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) { // Checks scales for 2d or 3d target tensors (`mat`). +======= + at::Tensor create_grouped_gemm_output_tensor(const Tensor& mat_a, + const Tensor& mat_b, + const std::optional& offs, + std::optional out_dtype + ) { + c10::SmallVector out_size; + const bool a_is_2d = mat_a.dim() == 2; + const bool b_is_2d = mat_b.dim() == 2; + if (a_is_2d) { + if (b_is_2d) { + out_size = {offs->size(0), mat_a.size(0), mat_b.size(1)}; + } else { + TORCH_CHECK(offs->size(0) == mat_b.size(0), "matrix batch sizes have to match"); + out_size = {mat_a.size(0), mat_b.size(-1)}; + } + } else { + if (b_is_2d) { + // this case is not actually encountered for MoE gemms + TORCH_CHECK(offs->size(0) == mat_a.size(0), "matrix batch sizes have to match"); + out_size = {mat_a.size(1), mat_b.size(1)}; + } else { // regular bmm + TORCH_CHECK(mat_a.size(0) == mat_b.size(0), "batched dimension has to match"); + out_size = {mat_a.size(0), mat_a.size(1), mat_b.size(-1)}; + } + } + + const auto out_dtype_ = out_dtype.value_or(kBFloat16); + TORCH_CHECK(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm"); + + // For TMA transfers, strides of output tensor have to be either + // 1, or aligned to 16 bytes. + const auto last_dim = out_size.size() - 1; + const auto alignment = 16 / c10::elementSize(out_dtype_); + const int64_t size_padded = (out_size[last_dim] + alignment - 1) / alignment * alignment; + std::vector out_stride; + if (a_is_2d != b_is_2d) { + out_stride = {size_padded, 1}; + } else { + out_stride = {out_size[1] * size_padded, size_padded, 1}; + } + auto out = at::empty_strided(out_size, out_stride, mat_a.options().dtype(out_dtype_)); + + return out; + } + + bool check_valid_strides_and_return_transposed(const Tensor& mat) { + IntArrayRef tensor_strides = mat.strides(); + IntArrayRef tensor_sizes = mat.sizes(); + int end_dim = mat.dim() - 1; + int alignment = 16 / mat.element_size(); + TORCH_CHECK(uint64_t(mat.data_ptr()) % 16 ==0, "expected data_ptr to be aligned to 16 bytes\n"); + if ((tensor_strides[end_dim - 1] == 1) && (tensor_strides[end_dim] >= std::max(1, tensor_sizes[end_dim - 1]))) { + TORCH_CHECK(tensor_strides[end_dim] % alignment == 0, "strides should be multiple of 16 bytes"); + return true; + } else if ((tensor_strides[end_dim] == 1) && (tensor_strides[end_dim - 1] >= std::max(1, tensor_sizes[end_dim]))) { + TORCH_CHECK(tensor_strides[end_dim - 1] % alignment == 0, "strides should be multiple of 16 bytes"); + return false; + } else { + TORCH_CHECK(false, "Invalid strides/sizes, got ", mat.strides(), " for strides and ", mat.sizes(), " for sizes"); + } + } + + void check_scale(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (mat.dim() == 2) { TORCH_CHECK( scale.dim() == 1, @@ -1586,6 +1926,7 @@ namespace { "scale must have the same first dimension as mat for arg ", arg_idx); } +<<<<<<< HEAD } void _check_scales_mxfp8(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx) { @@ -1646,6 +1987,11 @@ namespace { TORCH_CHECK(false, "scale must be float32 or float8_e8m0fnu, but got ", scale.dtype()); } } +======= +} + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } Tensor @@ -1670,18 +2016,30 @@ const std::optional& bias, const std::optional& scale_result, std::optional out_dtype, bool use_fast_accum) { +<<<<<<< HEAD bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true); TORCH_CHECK(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+"); +======= +#ifndef USE_ROCM + bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true); + TORCH_CHECK(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = 9.0"); + + TORCH_CHECK(mat_a.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_a.scalar_type()); + TORCH_CHECK(mat_b.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_b.scalar_type()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_CHECK(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed"); TORCH_CHECK(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed"); TORCH_CHECK(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d"); TORCH_CHECK(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d"); const bool a_is_2d = mat_a.dim() == 2; const bool b_is_2d = mat_b.dim() == 2; +<<<<<<< HEAD if (!a_is_2d || !b_is_2d) { TORCH_CHECK(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match"); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_CHECK( mat_a.size(-1) % 16 == 0, "Expected trailing dimension of mat_a to be divisible by 16 ", @@ -1704,17 +2062,25 @@ bool use_fast_accum) { TORCH_CHECK(offs->dtype() == at::kInt, "Offsets have to be int32"); } +<<<<<<< HEAD // FP8 per-tensor and per-row scaling expect fp32 scales. // MXFP8 expects float8_e8m0fnu scales. TORCH_CHECK( (scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat) || (scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu), "For FP8 tensorwise and rowwise, both scales must both be float32 tensors. For MXFP8, scales must both be float8_e8m0fnu tensors."); +======= + // Both Per-Tensor and Row-wise scaling expect fp32 tensors + TORCH_CHECK( + scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat, + "Both scale_a and scale_b must be float (fp32) tensors."); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1; check_scale(mat_a, scale_a, 0 ,0, scale_multiplier); check_scale(mat_b, scale_b, 1, 1, scale_multiplier); +<<<<<<< HEAD const auto out_dtype_ = out_dtype.value_or(kBFloat16); TORCH_CHECK(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm"); @@ -1749,6 +2115,9 @@ bool use_fast_accum) { #ifndef USE_ROCM TORCH_CHECK(mat_a.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_a.scalar_type()); TORCH_CHECK(mat_b.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_b.scalar_type()); +======= + Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) at::cuda::detail::f8f8bf16_grouped_mm( mat_a, @@ -1760,6 +2129,7 @@ bool use_fast_accum) { use_fast_accum, out); return out; +<<<<<<< HEAD #else #ifdef USE_FBGEMM_GENAI TORCH_CHECK(mat_a.dtype() == at::kFloat8_e4m3fnuz, "Expected mat_a to be Float8_e4m3fnuz matrix got ", mat_a.scalar_type()); @@ -1778,6 +2148,14 @@ bool use_fast_accum) { TORCH_CHECK(false, "grouped gemm is not supported without USE_FBGEMM_GENAI on ROCM") #endif +======= + + + + +#else + TORCH_CHECK(false, "grouped gemm is not supported on ROCM") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif } @@ -1787,6 +2165,7 @@ const std::optional& offs, const std::optional& bias, std::optional out_dtype) { #ifndef USE_ROCM +<<<<<<< HEAD _grouped_mm_validate_inputs(mat_a, mat_b, offs, bias, out_dtype); bool a_b_and_out_are_bf16 = ( mat_a.dtype() == at::kBFloat16 && @@ -1802,6 +2181,32 @@ std::optional out_dtype) { } else { _grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out); } +======= + bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true); + TORCH_CHECK(allowed_device, "torch._grouped_mm is only supported on CUDA devices with compute capability = 9.0"); + + TORCH_CHECK(mat_a.dtype() == at::kBFloat16, "Expected mat_a to be BFloat16 matrix got ", mat_a.scalar_type()); + TORCH_CHECK(mat_b.dtype() == at::kBFloat16, "Expected mat_a to be BFloat16 matrix got ", mat_b.scalar_type()); + TORCH_CHECK(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d"); + TORCH_CHECK(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d"); + const bool a_is_2d = mat_a.dim() == 2; + const bool b_is_2d = mat_b.dim() == 2; + + // check that the strides are valid, the fn will throw an error if not + check_valid_strides_and_return_transposed(mat_a); + check_valid_strides_and_return_transposed(mat_b); + TORCH_CHECK(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix, or no offset if both matrices are 3d"); + + if (offs.has_value()) { + TORCH_CHECK(offs->dim() == 1, "offs has to be 1D"); + TORCH_CHECK(offs->dtype() == at::kInt, "Offsets have to be int32"); + } + TORCH_CHECK(!bias.has_value(), "Bias not supported yet"); + + Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype); + + at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return out; #else TORCH_CHECK(false, "grouped gemm is not supported on ROCM") diff --git a/aten/src/ATen/native/cuda/CUDALoops.cuh b/aten/src/ATen/native/cuda/CUDALoops.cuh index ee28c5c1693f4..6ec67e998f8fe 100644 --- a/aten/src/ATen/native/cuda/CUDALoops.cuh +++ b/aten/src/ATen/native/cuda/CUDALoops.cuh @@ -297,7 +297,10 @@ static inline void launch_vectorized_kernel( int vec_size = memory::can_vectorize_up_to(data); c10::DeviceIndex curDevice = -1; AT_CUDA_CHECK(c10::cuda::GetDevice(&curDevice)); +<<<<<<< HEAD // Similar check in vectorized_elementwise_kernel() as well. Both should be in sync. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int tws = at::detail::getCUDAHooks().isGPUArch({"gfx942"}, curDevice) ? 16 : elems_per_thread(); #else using cpp_type = typename function_traits::result_type; @@ -436,6 +439,10 @@ static inline void launch_vectorized_templated_kernel( loader_t l, storer_t s) { TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max()); +<<<<<<< HEAD +======= + using traits = function_traits; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int64_t grid = (N + vectorized_templated_config::block_work_size() - 1) / vectorized_templated_config::block_work_size(); auto stream = at::cuda::getCurrentCUDAStream(); @@ -882,6 +889,72 @@ struct type_specialized_kernel_launcher { } }; +<<<<<<< HEAD +======= +template +struct type_specialized_broadcast_kernel_launcher { + template < + typename func_t, + typename array_t, + typename dtypes_t, + typename calc_t> + static void apply( + int64_t numel, + func_t f, + array_t data, + dtypes_t dtypes, + calc_t offset_calc) { + using traits = function_traits; + using ret_t = typename traits::result_type; + using arg0_t = typename traits::template arg<0>::type; + using arg1_t = typename traits::template arg<1>::type; + if (dtypes[0] == rt_binary_specializations[arg_index][0] && + dtypes[1] == rt_binary_specializations[arg_index][1] && + dtypes[2] == rt_binary_specializations[arg_index][2]) { + using ret_cpp_t = c10::impl::ScalarTypeToCPPTypeT; + using arg0_cpp_t = c10::impl::ScalarTypeToCPPTypeT; + using arg1_cpp_t = c10::impl::ScalarTypeToCPPTypeT; + constexpr int grp_sz = 128; + launch_legacy_kernel_manual_unroll(numel, [=] GPU_LAMBDA(int idx, bool unrl) { + if (unrl) { + auto offsets0 = offset_calc.get(idx); + auto offsets1 = offset_calc.get(idx + grp_sz); + auto offsets2 = offset_calc.get(idx + grp_sz * 2); + auto offsets3 = offset_calc.get(idx + grp_sz * 3); + void* out0 = data[0] + offsets0[0]; + void* out1 = data[0] + offsets1[0]; + void* out2 = data[0] + offsets2[0]; + void* out3 = data[0] + offsets3[0]; + auto u = c10::load(data[1] + offsets0[1]); + auto v = c10::load(data[2] + offsets0[2]); + ret_t result0 = f(c10::convert(u), c10::convert(v)); + auto u1 = c10::load(data[1] + offsets1[1]); + auto v1 = c10::load(data[2]+ offsets1[2]); + ret_t result1 = f(c10::convert(u1), c10::convert(v1)); + auto u2 = c10::load(data[1] + offsets2[1]); + auto v2 = c10::load(data[2] + offsets2[2]); + ret_t result2 = f(c10::convert(u2), c10::convert(v2)); + auto u3 = c10::load(data[1] + offsets3[1]); + auto v3 = c10::load(data[2] + offsets3[2]); + ret_t result3 = f(c10::convert(u3), c10::convert(v3)); + *(ret_cpp_t*)out0 = c10::convert(result0); + *(ret_cpp_t*)out1 = c10::convert(result1); + *(ret_cpp_t*)out2 = c10::convert(result2); + *(ret_cpp_t*)out3 = c10::convert(result3); + } else { + auto offsets = offset_calc.get(idx); + void* out = data[0] + offsets[0]; + auto u = c10::load(data[1] + offsets[1]); + auto v = c10::load(data[2] + offsets[2]); + ret_t result = f(c10::convert(u), c10::convert(v)); + *(ret_cpp_t*)out = c10::convert(result); + } + }); + } + } +}; + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace #endif @@ -1000,6 +1073,35 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) { } auto offset_calc = ::make_offset_calculator(iter); #ifdef USE_ROCM +<<<<<<< HEAD +======= + if (check_binary_rt_types_for_specialization(iter)) { + // constexpr to reduce the amount of kernels generated for + // broadcast elementwise with mexed dtypes and limit which functors are actually + // applied to the load and store at compile time. + using func_tuple = typename traits::ArgsTuple; + if constexpr ( + std::is_same_v && traits::arity == 2 && + check_binary_functor_types_for_specialization< + func_tuple, + float, + float, + traits::arity, + /*arg_num=*/0>::check()) { + memory::detail::static_unroll< + type_specialized_broadcast_kernel_launcher, + rt_binary_specializations.size()>::with_args( + numel, + f, + data, + dtypes, + offset_calc + ); + return; + } + } + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) constexpr int grp_sz = 128; launch_legacy_kernel_manual_unroll(numel, [=] GPU_LAMBDA(int idx, bool unrl) { if (unrl) { diff --git a/aten/src/ATen/native/cuda/CUDAScalar.cu b/aten/src/ATen/native/cuda/CUDAScalar.cu index 0d34bd52f211a..524578b07c90b 100644 --- a/aten/src/ATen/native/cuda/CUDAScalar.cu +++ b/aten/src/ATen/native/cuda/CUDAScalar.cu @@ -11,11 +11,31 @@ #include +<<<<<<< HEAD +======= +#if defined(USE_ROCM) +// TODO(lufang): Tensor.item() on AMD HIP is not synced in the Recsys models. +// This is just a short term workaround. Issue is tracked as FBA-388 on the AMD side. +namespace { + bool use_sync_mode() { + static const bool sync_mode = c10::utils::check_env("HIP_DOUBLE_SYNC_ON_LOCAL_SCALE_DENSE") == true; + return sync_mode; + } +} +#endif + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace at::native { Scalar _local_scalar_dense_cuda(const Tensor& self) { Scalar r; TORCH_CHECK(self.numel() > 0, "_local_scalar_dense: Empty tensor not supported"); +<<<<<<< HEAD +======= +#if defined(USE_ROCM) + if (!use_sync_mode()){ +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AT_DISPATCH_V2( self.scalar_type(), "_local_scalar_dense_cuda", AT_WRAP([&] { // Create pinned memory for the scalar value to avoid implicit @@ -32,6 +52,18 @@ Scalar _local_scalar_dense_cuda(const Tensor& self) { at::cuda::memcpy_and_sync((void *)value.const_data_ptr(), self.const_data_ptr(), sizeof(scalar_t), cudaMemcpyDeviceToHost, stream); r = Scalar(*value.const_data_ptr()); }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); +<<<<<<< HEAD +======= +#if defined(USE_ROCM) + } else { + auto cpu_self = self.cpu(); + AT_DISPATCH_V2( + self.scalar_type(), "_local_scalar_dense_hip", AT_WRAP([&] { + r = Scalar(*cpu_self.const_data_ptr()); + }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); + } +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return r; } diff --git a/aten/src/ATen/native/cuda/CuFFTPlanCache.h b/aten/src/ATen/native/cuda/CuFFTPlanCache.h index 333c21e94f18e..49e14b9b51540 100644 --- a/aten/src/ATen/native/cuda/CuFFTPlanCache.h +++ b/aten/src/ATen/native/cuda/CuFFTPlanCache.h @@ -223,7 +223,11 @@ inline CuFFTDataLayout as_cufft_embed(IntArrayRef strides, IntArrayRef sizes, bo class CuFFTConfig { public: +<<<<<<< HEAD // Only move semantics is enough for this class. Although we already use +======= + // Only move semantics is enought for this class. Although we already use +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // unique_ptr for the plan, still remove copy constructor and assignment op so // we don't accidentally copy and take perf hit. CuFFTConfig(const CuFFTConfig&) = delete; diff --git a/aten/src/ATen/native/cuda/CuFFTUtils.h b/aten/src/ATen/native/cuda/CuFFTUtils.h index 38013137f0a40..42e5e9fee11a0 100644 --- a/aten/src/ATen/native/cuda/CuFFTUtils.h +++ b/aten/src/ATen/native/cuda/CuFFTUtils.h @@ -38,12 +38,22 @@ static inline std::string _cudaGetErrorEnum(cufftResult error) return "CUFFT_INVALID_SIZE"; case CUFFT_UNALIGNED_DATA: return "CUFFT_UNALIGNED_DATA"; +<<<<<<< HEAD case CUFFT_INVALID_DEVICE: return "CUFFT_INVALID_DEVICE"; +======= + case CUFFT_INCOMPLETE_PARAMETER_LIST: + return "CUFFT_INCOMPLETE_PARAMETER_LIST"; + case CUFFT_INVALID_DEVICE: + return "CUFFT_INVALID_DEVICE"; + case CUFFT_PARSE_ERROR: + return "CUFFT_PARSE_ERROR"; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) case CUFFT_NO_WORKSPACE: return "CUFFT_NO_WORKSPACE"; case CUFFT_NOT_IMPLEMENTED: return "CUFFT_NOT_IMPLEMENTED"; +<<<<<<< HEAD #if CUDA_VERSION <= 12090 case CUFFT_INCOMPLETE_PARAMETER_LIST: return "CUFFT_INCOMPLETE_PARAMETER_LIST"; @@ -51,6 +61,9 @@ static inline std::string _cudaGetErrorEnum(cufftResult error) return "CUFFT_PARSE_ERROR"; #endif #if !defined(USE_ROCM) && CUDA_VERSION <= 12090 +======= +#if !defined(USE_ROCM) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) case CUFFT_LICENSE_ERROR: return "CUFFT_LICENSE_ERROR"; #endif diff --git a/aten/src/ATen/native/cuda/Embedding.cu b/aten/src/ATen/native/cuda/Embedding.cu index 602dfd6e52882..663553e343fc0 100644 --- a/aten/src/ATen/native/cuda/Embedding.cu +++ b/aten/src/ATen/native/cuda/Embedding.cu @@ -317,7 +317,11 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice auto count_data = count.mutable_data_ptr(); cuda::cub::inclusive_sum_by_key( sorted_data, +<<<<<<< HEAD ATEN_CUB_CONSTANT_ITERATOR(index_t)(1), +======= + NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::ConstantInputIterator(1), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) count_data, num_indices ); @@ -329,7 +333,11 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice thrust::make_reverse_iterator(sorted_data + num_indices), thrust::make_reverse_iterator(static_cast(count_data) + num_indices), thrust::make_reverse_iterator(count_data + num_indices), +<<<<<<< HEAD ATEN_CUB_MAXIMUM(), +======= + NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::Max(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) num_indices ); }); @@ -369,7 +377,11 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices, int warp_size = at::cuda::warp_size(); TORCH_INTERNAL_ASSERT(num_threads() % warp_size == 0 && +<<<<<<< HEAD num_threads() <= static_cast(cuda_utils::kCUDABlockReduceMaxThreads()), +======= + num_threads() <= cuda_utils::kCUDABlockReduceMaxThreads(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "BlockReduceSum requires all warps be active"); const int64_t *num_unique_indices_ptr = num_unique_indices.const_data_ptr(); dim3 grid = unique_indices.numel(); diff --git a/aten/src/ATen/native/cuda/EmbeddingBag.cu b/aten/src/ATen/native/cuda/EmbeddingBag.cu index fb92c7488a152..12d35ef840dc6 100644 --- a/aten/src/ATen/native/cuda/EmbeddingBag.cu +++ b/aten/src/ATen/native/cuda/EmbeddingBag.cu @@ -210,7 +210,11 @@ Tensor embedding_bag_backward_cuda_sum_avg( auto count_data = count.mutable_data_ptr(); cuda::cub::inclusive_sum_by_key( sorted_data, +<<<<<<< HEAD ATEN_CUB_CONSTANT_ITERATOR(index_t)(1), +======= + NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::ConstantInputIterator(1), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) count_data, num_indices ); @@ -222,7 +226,11 @@ Tensor embedding_bag_backward_cuda_sum_avg( thrust::make_reverse_iterator(sorted_data + num_indices), thrust::make_reverse_iterator(count_data + num_indices), thrust::make_reverse_iterator(count_data + num_indices), +<<<<<<< HEAD ATEN_CUB_MAXIMUM(), +======= + NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::Max(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) num_indices ); }); diff --git a/aten/src/ATen/native/cuda/ForeachFunctors.cuh b/aten/src/ATen/native/cuda/ForeachFunctors.cuh index c121d971cd7be..a6938aeb94d24 100644 --- a/aten/src/ATen/native/cuda/ForeachFunctors.cuh +++ b/aten/src/ATen/native/cuda/ForeachFunctors.cuh @@ -208,7 +208,11 @@ struct BinaryOpScalarFunctor { using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( +<<<<<<< HEAD int64_t chunk_size, +======= + int chunk_size, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TensorListMetadata& tl, Op op, opmath_t scalar) { @@ -232,7 +236,11 @@ struct BinaryOpScalarListFunctor { using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( +<<<<<<< HEAD int64_t chunk_size, +======= + int chunk_size, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TensorListScalarListMetadata& tl, Op op) { const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; @@ -256,7 +264,11 @@ struct BinaryOpListAlphaFunctor { using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( +<<<<<<< HEAD int64_t chunk_size, +======= + int chunk_size, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TensorListMetadata& tl, Op op, opmath_t alpha) { @@ -308,7 +320,11 @@ struct BinaryOpScalarTensorFunctor { using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( +<<<<<<< HEAD int64_t chunk_size, +======= + int chunk_size, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TensorListMetadata& tl, Op op, T* scalar, @@ -364,7 +380,11 @@ struct BinaryOpScalarTensorFunctor { template struct ZeroFunctor { __device__ __forceinline__ void operator()( +<<<<<<< HEAD int64_t chunk_size, +======= + int chunk_size, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TensorListMetadata<1>& tl) { const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; @@ -406,7 +426,11 @@ struct UnaryOpFunctor { using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( +<<<<<<< HEAD int64_t chunk_size, +======= + int chunk_size, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TensorListMetadata& tl, Op op) { const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; @@ -458,7 +482,11 @@ struct PointwiseOpScalarFunctor { using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( +<<<<<<< HEAD int64_t chunk_size, +======= + int chunk_size, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TensorListMetadata& tl, Op op, opmath_t scalar) { @@ -482,7 +510,11 @@ struct PointwiseOpScalarListFunctor { using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( +<<<<<<< HEAD int64_t chunk_size, +======= + int chunk_size, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TensorListScalarListMetadata& tl, Op op) { const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; @@ -506,7 +538,11 @@ struct PointwiseOpListFunctor { using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( +<<<<<<< HEAD int64_t chunk_size, +======= + int chunk_size, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TensorListMetadata& tl, Op op) { const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; @@ -557,7 +593,11 @@ struct TernaryOpListFunctor { using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( +<<<<<<< HEAD int64_t chunk_size, +======= + int chunk_size, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TensorListMetadata& tl, Op op) { static_assert(depth == 3 || depth == 4, ""); @@ -611,7 +651,11 @@ struct TernaryOpScalarFunctor { using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( +<<<<<<< HEAD int64_t chunk_size, +======= + int chunk_size, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TensorListMetadata& tl, Op op, opmath_t alpha) { @@ -668,7 +712,11 @@ struct TernaryOpScalarListFunctor { using opmath_t = at::opmath_type; template __device__ __forceinline__ void operator()( +<<<<<<< HEAD int64_t chunk_size, +======= + int chunk_size, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TensorListScalarListMetadata& tl, Op op) { static_assert(depth == 2 || depth == 3, ""); diff --git a/aten/src/ATen/native/cuda/ForeachReduceOp.cu b/aten/src/ATen/native/cuda/ForeachReduceOp.cu index 2da8e634981f9..ee674fbed2767 100644 --- a/aten/src/ATen/native/cuda/ForeachReduceOp.cu +++ b/aten/src/ATen/native/cuda/ForeachReduceOp.cu @@ -53,7 +53,11 @@ template < int res_arg_index = 0> struct LpMaxFunctor { __device__ __forceinline__ void operator()( +<<<<<<< HEAD int64_t chunk_size, +======= + int chunk_size, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TensorListMetadata& tl, T* output_per_tensor_ptr, const int max_chunks_per_tensor) { @@ -243,7 +247,11 @@ template < struct LpNormFunctor { using out_opmath_t = typename at::opmath_type; __device__ __forceinline__ void operator()( +<<<<<<< HEAD int64_t chunk_size, +======= + int chunk_size, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TensorListMetadata& tl, out_opmath_t* output_per_tensor_ptr, const int max_chunks_per_tensor) { diff --git a/aten/src/ATen/native/cuda/FusedSgdKernel.cu b/aten/src/ATen/native/cuda/FusedSgdKernel.cu index d0cf7e06c8688..12c68d73c2b4a 100644 --- a/aten/src/ATen/native/cuda/FusedSgdKernel.cu +++ b/aten/src/ATen/native/cuda/FusedSgdKernel.cu @@ -62,7 +62,11 @@ struct FusedSgdMathFunctor { depth == 2 || depth == 3, "depth of 2 for SGD w/ momentum == 0, 3 for SGD w/ momentum != 0"); C10_DEVICE __forceinline__ void operator()( +<<<<<<< HEAD const int64_t chunk_size, +======= + const int chunk_size, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TensorListMetadata& tl, const double weight_decay, const double momentum, diff --git a/aten/src/ATen/native/cuda/GroupMM.cu b/aten/src/ATen/native/cuda/GroupMM.cu index a917b0d6163fa..44e151d1b6ab7 100644 --- a/aten/src/ATen/native/cuda/GroupMM.cu +++ b/aten/src/ATen/native/cuda/GroupMM.cu @@ -8,10 +8,16 @@ #include +<<<<<<< HEAD // Three warninngs in Cutlass included header files C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wset-but-not-used") C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter") C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-variable") +======= +// Two warninngs in Cutlass included header files +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wset-but-not-used") +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Determine if the architecture supports rowwise scaled mm // Currently failing on windows with: @@ -44,6 +50,7 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-variable") #include #include +<<<<<<< HEAD #include namespace { @@ -52,6 +59,13 @@ using Strides = at::cuda::detail::Strides; // std::array; template struct Schedule { // SM90 +======= +namespace { +using Strides = at::cuda::detail::Strides; // std::array; + +template +struct Schedule { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) using CooperativeSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; using PongSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; @@ -59,6 +73,7 @@ struct Schedule { cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; using PongEpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; +<<<<<<< HEAD // SM100 using MMA1SMKernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; using MMA1SMEpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; @@ -72,6 +87,12 @@ struct Schedule { cute::conditional_t, cute::conditional_t>; +======= + using KernelSchedule = + cute::conditional_t; + using EpilogueSchedule = cute:: + conditional_t; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; int ceildiv(int a, int b) { @@ -83,6 +104,7 @@ int round_up_to_nearest_multiple(int a, int b) { } template < +<<<<<<< HEAD typename ArchTag, bool a_row_major, bool b_row_major, @@ -91,6 +113,15 @@ template < typename TB_N, typename TB_K> void bf16bf16_grouped_gemm_impl_sm90_sm100( +======= + bool a_row_major, + bool b_row_major, + bool Pong, + typename TB_M, + typename TB_N, + typename TB_K> +void bf16bf16_grouped_gemm_impl_sm90( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) at::Tensor mat_a, // bf16 at::Tensor mat_b, // bf16 std::optional offs, @@ -113,13 +144,23 @@ void bf16bf16_grouped_gemm_impl_sm90_sm100( constexpr int AlignmentB = 16 / sizeof(DtypeB); using LayoutOutput = cutlass::layout::RowMajor; constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput); +<<<<<<< HEAD +======= + using ArchTag = cutlass::arch::Sm90; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) using OperatorClass = cutlass::arch::OpClassTensorOp; using TileShape = cute::Shape; using ClusterShape = cute::Shape; using KernelSchedule = +<<<<<<< HEAD typename Schedule::KernelSchedule; using EpilogueSchedule = typename Schedule::EpilogueSchedule; +======= + typename Schedule::KernelSchedule; + using EpilogueSchedule = + typename Schedule::EpilogueSchedule; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) using ProblemShape = cutlass::gemm::GroupProblemShape< cute::Shape>; // per // group @@ -159,6 +200,7 @@ void bf16bf16_grouped_gemm_impl_sm90_sm100( cutlass::gemm::collective::StageCountAutoCarveout( sizeof(typename CollectiveEpilogue::SharedStorage))>, KernelSchedule>::CollectiveOp; +<<<<<<< HEAD using GemmKernelBase = cutlass::gemm::kernel::GemmUniversal< ProblemShape, @@ -169,6 +211,10 @@ void bf16bf16_grouped_gemm_impl_sm90_sm100( std::is_same_v, at::cuda::detail::enable_3x_kernel_for_sm10, at::cuda::detail::enable_3x_kernel_for_sm9x>; +======= + using GemmKernel = cutlass::gemm::kernel:: + GemmUniversal; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) using Gemm = cutlass::gemm::device::GemmUniversalAdapter; using StrideA = typename Gemm::GemmKernel::InternalStrideA; @@ -241,8 +287,11 @@ void bf16bf16_grouped_gemm_impl_sm90_sm100( Strides tensor_StrideA = make_strides(mat_a.strides()); Strides tensor_StrideB = make_strides(mat_b.strides()); Strides tensor_StrideOutput = make_strides(out.strides()); +<<<<<<< HEAD Strides tensor_ShapeA = make_strides(mat_a.sizes()); Strides tensor_ShapeB = make_strides(mat_b.sizes()); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) at::cuda::detail::prepare_grouped_gemm_data<<<1, group_count, 0, stream>>>( reinterpret_cast(mat_a.data_ptr()), @@ -266,8 +315,11 @@ void bf16bf16_grouped_gemm_impl_sm90_sm100( tensor_StrideA, tensor_StrideB, tensor_StrideOutput, +<<<<<<< HEAD tensor_ShapeA, tensor_ShapeB, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) 0, 0, a_row_major, @@ -344,6 +396,7 @@ void dispatch_bf16_grouped_kernel_on_tile_size( // ((M >= 2048 && K >= 2048) || (M >= 2048 && N >= 2048) || // (K >= 2048 && N >= 2048)); bool small = (M <= 128 || N <= 128); +<<<<<<< HEAD cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties(); const bool sm10x = properties != nullptr && properties->major == 10; @@ -387,6 +440,24 @@ void dispatch_bf16_grouped_kernel_on_tile_size( cute::_256, cute::_64>(mat_a, mat_b, offs, bias, out); } +======= + if (small) { + bf16bf16_grouped_gemm_impl_sm90< + a_row_major, + b_row_major, + /*Pong*/ true, + cute::_64, + cute::_128, + cute::_128>(mat_a, mat_b, offs, bias, out); + } else { + bf16bf16_grouped_gemm_impl_sm90< + a_row_major, + b_row_major, + /*Pong*/ false, + cute::_128, + cute::_256, + cute::_64>(mat_a, mat_b, offs, bias, out); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } diff --git a/aten/src/ATen/native/cuda/GroupMMCommon.cuh b/aten/src/ATen/native/cuda/GroupMMCommon.cuh index ed8176b53f84c..9fecb6c1b760d 100644 --- a/aten/src/ATen/native/cuda/GroupMMCommon.cuh +++ b/aten/src/ATen/native/cuda/GroupMMCommon.cuh @@ -38,20 +38,37 @@ __global__ void prepare_grouped_gemm_data( Strides tensor_StrideA, Strides tensor_StrideB, Strides tensor_StrideOutput, +<<<<<<< HEAD Strides tensor_ShapeA, Strides tensor_ShapeB, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int64_t a_scale_stride, int64_t b_scale_stride, bool a_row_major = true, bool b_row_major = false) { int32_t tid = threadIdx.x; int32_t delta = 0; +<<<<<<< HEAD int32_t offset = 0; if (offs != nullptr) { int32_t start = tid == 0 ? 0 : offs[tid - 1]; offset = offs[tid]; delta = offset - start; CUDA_KERNEL_ASSERT(delta >=0 && "expected gemm dimension to be greater or equal 0\n"); +======= + if (offs != nullptr) { + int32_t start = tid == 0 ? 0 : offs[tid - 1]; + delta = offs[tid] - start; + if (K < 0) { + if (!a_row_major && b_row_major) { + CUDA_KERNEL_ASSERT(delta >=0 && "expected ofsets to be greater or equal 0\n"); + } else { + // CUTLASS cannot handle delta=0 here. + CUDA_KERNEL_ASSERT(delta >0 && "expected ofsets to be greater than 0\n"); + } + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // TMA transfers require global memory tensor addresses to be // aligned to 16 bytes. @@ -86,7 +103,10 @@ __global__ void prepare_grouped_gemm_data( int64_t lda, ldb, ldoutput; if (M < 0) { // A and output is 2d +<<<<<<< HEAD CUDA_KERNEL_ASSERT(offset <= tensor_ShapeA[0] && "expected offset to be less than tensor size\n"); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) M = delta; lda = a_row_major ? tensor_StrideA[0] : tensor_StrideA[1]; ldb = b_row_major ? tensor_StrideB[1] : tensor_StrideB[2]; @@ -99,7 +119,10 @@ __global__ void prepare_grouped_gemm_data( output_ptrs[tid] = tid == 0 ? output : output + offs[tid - 1] * ldoutput; B_ptrs[tid] = B + tid * tensor_StrideB[0]; } else if (N < 0) { +<<<<<<< HEAD CUDA_KERNEL_ASSERT(offset <= tensor_ShapeB[1] && "expected offset to be less than tensor size\n"); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) N = delta; lda = a_row_major ? tensor_StrideA[1] : tensor_StrideA[2]; ldb = b_row_major ? tensor_StrideB[0] : tensor_StrideB[1]; // B is transposed @@ -112,7 +135,10 @@ __global__ void prepare_grouped_gemm_data( inputB_scale_ptrs[tid] = tid == 0 ? scale_B : scale_B + offs[tid - 1]; } } else if (K < 0) { +<<<<<<< HEAD CUDA_KERNEL_ASSERT(offset <= tensor_ShapeA[1] && offset <= tensor_ShapeB[0] && "expected offset to be less than tensor size\n"); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // A, B is 2d, output is 3d K = delta; lda = a_row_major ? tensor_StrideA[0] : tensor_StrideA[1]; diff --git a/aten/src/ATen/native/cuda/KernelUtils.cuh b/aten/src/ATen/native/cuda/KernelUtils.cuh index 75fdd6922a8bd..03f9b419999c2 100644 --- a/aten/src/ATen/native/cuda/KernelUtils.cuh +++ b/aten/src/ATen/native/cuda/KernelUtils.cuh @@ -316,8 +316,11 @@ __device__ __forceinline__ void opportunistic_fastAtomicAdd( } } +<<<<<<< HEAD // not coalsced, so now let try to capture lane-matches... +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (numel > 16 /*<-hueristic threshold*/ * 64 ) { // well shucks, unlikely to capture same-dest atomics in a wave. // fall back to direct fastAtomic... @@ -325,6 +328,10 @@ __device__ __forceinline__ void opportunistic_fastAtomicAdd( return; } +<<<<<<< HEAD +======= + // not coalsced, so now let try to capture lane-matches... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // __activemask() -- finds the set of threads in the warp that are about to perform atomicAdd // __match_any_sync() -- returns bit mask of the threads that have same dest addr auto mask = __match_any_sync(__activemask(), (int64_t)dst); diff --git a/aten/src/ATen/native/cuda/LossCTC.cu b/aten/src/ATen/native/cuda/LossCTC.cu index c6d3c25200d50..26894d844203f 100644 --- a/aten/src/ATen/native/cuda/LossCTC.cu +++ b/aten/src/ATen/native/cuda/LossCTC.cu @@ -644,12 +644,16 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_ Tensor grad = at::full_like(log_probs, neginf, LEGACY_CONTIGUOUS_MEMORY_FORMAT); // initialization for log(sum (alpha beta)) // As above, there may be better configurations to use. +<<<<<<< HEAD constexpr int max_threads_ = std::is_same_v ? 1024 : 896; // we need 72 or so 32 bit registers for double int max_threads = max_threads_; // Blackwell launch bounds if (at::cuda::getCurrentDeviceProperties()->major >= 10) { max_threads = 512; } +======= + constexpr int max_threads = std::is_same_v ? 1024 : 896; // we need 72 or so 32 bit registers for double +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int threads_target = max_threads; while (threads_target / 2 >= 2*max_target_length+1) { threads_target /= 2; diff --git a/aten/src/ATen/native/cuda/Math.cuh b/aten/src/ATen/native/cuda/Math.cuh index 1d603132e6893..bfd1b74ed865a 100644 --- a/aten/src/ATen/native/cuda/Math.cuh +++ b/aten/src/ATen/native/cuda/Math.cuh @@ -1946,7 +1946,11 @@ const auto chebyshev_polynomial_t_string = jiterator_stringify( T q = x; T r; +<<<<<<< HEAD for (int64_t k = 2; (k <= n) && !isnan(q); k++) { +======= + for (int64_t k = 2; k <= n; k++) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r = (x + x) * q - p; p = q; q = r; @@ -1996,7 +2000,11 @@ const auto chebyshev_polynomial_u_string = jiterator_stringify( T q = x + x; T r; +<<<<<<< HEAD for (int64_t k = 2; (k <= n) && !isnan(q); k++) { +======= + for (int64_t k = 2; k <= n; k++) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r = (x + x) * q - p; p = q; q = r; @@ -2054,7 +2062,11 @@ const auto chebyshev_polynomial_v_string = jiterator_stringify( T q = x + x - T(1.0); T r; +<<<<<<< HEAD for (int64_t k = 2; (k <= n) && !isnan(q); k++) { +======= + for (int64_t k = 2; k <= n; k++) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r = (x + x) * q - p; p = q; q = r; @@ -2116,7 +2128,11 @@ const auto chebyshev_polynomial_w_string = jiterator_stringify( T q = x + x + T(1.0); T r; +<<<<<<< HEAD for (int64_t k = 2; (k <= n) && !isnan(q); k++) { +======= + for (int64_t k = 2; k <= n; k++) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r = (x + x) * q - p; p = q; q = r; @@ -2252,7 +2268,11 @@ const auto laguerre_polynomial_l_string = jiterator_stringify( T q = T(1.0) - x; T r; +<<<<<<< HEAD for (int64_t k = 1; (k < n) && !isnan(q); k++) { +======= + for (int64_t k = 1; k < n; k++) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r = (((k + k) + (T(1.0) - x)) * q - k * p) / (k + 1); p = q; q = r; @@ -2294,7 +2314,11 @@ const auto legendre_polynomial_p_string = jiterator_stringify( T q = x; T r; +<<<<<<< HEAD for (int64_t k = 1; (k < n) && !isnan(q); k++) { +======= + for (int64_t k = 1; k < n; k++) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r = ((k + k + 1) * x * q - k * p) / (k + 1); p = q; q = r; @@ -2851,7 +2875,11 @@ const auto shifted_chebyshev_polynomial_t_string = jiterator_stringify( T q = x + x - T(1.0); T r; +<<<<<<< HEAD for (int64_t k = 2; (k <= n) && !isnan(q); k++) { +======= + for (int64_t k = 2; k <= n; k++) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; p = q; q = r; @@ -2905,7 +2933,11 @@ const auto shifted_chebyshev_polynomial_u_string = jiterator_stringify( T q = x + x - T(1.0) + (x + x - T(1.0)); T r; +<<<<<<< HEAD for (int64_t k = 2; (k <= n) && !isnan(q); k++) { +======= + for (int64_t k = 2; k <= n; k++) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; p = q; q = r; @@ -2963,7 +2995,11 @@ const auto shifted_chebyshev_polynomial_v_string = jiterator_stringify( T q = x + x - T(1.0) + (x + x - T(1.0)) - T(1.0); T r; +<<<<<<< HEAD for (int64_t k = 2; (k <= n) && !isnan(q); k++) { +======= + for (int64_t k = 2; k <= n; k++) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; p = q; q = r; @@ -3021,7 +3057,11 @@ const auto shifted_chebyshev_polynomial_w_string = jiterator_stringify( T q = x + x - T(1.0) + (x + x - T(1.0)) + T(1.0); T r; +<<<<<<< HEAD for (int64_t k = 2; (k <= n) && !isnan(q); k++) { +======= + for (int64_t k = 2; k <= n; k++) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r = (x + x - T(1.0) + (x + x - T(1.0))) * q - p; p = q; q = r; diff --git a/aten/src/ATen/native/cuda/MultiMarginLoss.cu b/aten/src/ATen/native/cuda/MultiMarginLoss.cu index ad7b3638b489d..3c179f61c6e6e 100644 --- a/aten/src/ATen/native/cuda/MultiMarginLoss.cu +++ b/aten/src/ATen/native/cuda/MultiMarginLoss.cu @@ -121,7 +121,10 @@ __global__ void MultiMarginLoss_backward_kernel( gradInput_k[target_k] = static_cast(gradInput_target_k); } +<<<<<<< HEAD __syncthreads(); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (int i=i_start; i; +<<<<<<< HEAD using TransformInputIteratorT = ATEN_CUB_TRANSFORM_ITERATOR(int, NonZeroOp, const T*); +======= + using TransformInputIteratorT = ROCM_HIPCUB(at_cuda_detail::cub)::TransformInputIterator, const T*>; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) using BlockExchangeT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockExchange; // Shared memory @@ -184,7 +188,11 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out) { auto num_nonzeros = allocator.allocate(sizeof(int) * num_chunks); for (int64_t idx = 0; idx < num_chunks; idx++) { int64_t remaining = std::min(chunk_size, self.numel() - idx * chunk_size); +<<<<<<< HEAD ATEN_CUB_TRANSFORM_ITERATOR(bool, NonZeroOp, const scalar_t*) itr( +======= + cub::TransformInputIterator, const scalar_t*> itr( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self_.const_data_ptr() + idx * chunk_size, NonZeroOp()); AT_CUDA_CHECK(cub::DeviceReduce::Sum( @@ -243,8 +251,13 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out) { for (int64_t idx = 0; idx < num_chunks; idx++) { int remaining = std::min(chunk_size, self.numel() - idx * chunk_size); +<<<<<<< HEAD ATEN_CUB_COUNTING_ITERATOR(int64_t) counting_itr(idx * chunk_size); ATEN_CUB_TRANSFORM_ITERATOR(bool, NonZeroOp, const scalar_t*) +======= + cub::CountingInputIterator counting_itr(idx * chunk_size); + cub::TransformInputIterator, const scalar_t*> +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) itr(self_.const_data_ptr() + idx * chunk_size, NonZeroOp()); temp_storage_bytes = 0; diff --git a/aten/src/ATen/native/cuda/Pow.cuh b/aten/src/ATen/native/cuda/Pow.cuh index fe249c1cdaef3..17ac5781b37eb 100644 --- a/aten/src/ATen/native/cuda/Pow.cuh +++ b/aten/src/ATen/native/cuda/Pow.cuh @@ -14,7 +14,11 @@ namespace { // pow(double, int) // pow(float, float) // pow(double, double) +<<<<<<< HEAD #if defined(_MSC_VER) || defined(_LIBCPP_VERSION) +======= +#ifdef _MSC_VER +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Functions for pow // pow for at::Half static inline __host__ __device__ at::Half pow_(at::Half base, at::Half exp) { diff --git a/aten/src/ATen/native/cuda/PowKernel.cu b/aten/src/ATen/native/cuda/PowKernel.cu index 2698207c45ef5..06759dfdb5ea9 100644 --- a/aten/src/ATen/native/cuda/PowKernel.cu +++ b/aten/src/ATen/native/cuda/PowKernel.cu @@ -185,12 +185,15 @@ void pow_tensor_scalar_kernel(TensorIteratorBase& iter, const Scalar& exp_scalar return; } AT_DISPATCH_COMPLEX_TYPES(iter.common_dtype(), "pow_cuda", [&]() { +<<<<<<< HEAD if (exp_scalar.equal(2.0)) { gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t base) -> scalar_t { return base * base; }); return; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const auto exp = exp_scalar.to(); gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t base) -> scalar_t { return pow_(base, exp); diff --git a/aten/src/ATen/native/cuda/ReduceSumProdKernel.cu b/aten/src/ATen/native/cuda/ReduceSumProdKernel.cu index ea1bd955b8ddf..a97949f00fa60 100644 --- a/aten/src/ATen/native/cuda/ReduceSumProdKernel.cu +++ b/aten/src/ATen/native/cuda/ReduceSumProdKernel.cu @@ -87,7 +87,11 @@ struct nansum_functor_complex { #else void operator()(TensorIterator& iter) { using acc_t = at::opmath_type; +<<<<<<< HEAD gpu_reduce_kernel( +======= + gpu_reduce_kernel( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) iter, NanSumOps{}); } #endif @@ -154,6 +158,7 @@ struct prod_functor> { #endif }; +<<<<<<< HEAD template struct xor_sum_functor { void operator()(TensorIterator& iter) { @@ -199,6 +204,8 @@ struct xor_sum_functor } }; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // The function `reduce_dispatch` below dispatches to the kernel based // on the type of `iter`. It takes care of the common logic // for handling Half-Precision floating types. @@ -267,6 +274,7 @@ static void prod_kernel_cuda(TensorIterator& iter) { reduce_dispatch(iter, general_dispatcher); } +<<<<<<< HEAD static void xor_sum_kernel_cuda(TensorIterator& iter) { // Use iter.dtype(1) to dispatch based on the type of the input tensor AT_DISPATCH_ALL_TYPES_AND3( @@ -279,5 +287,10 @@ REGISTER_DISPATCH(sum_stub, &sum_kernel_cuda) REGISTER_DISPATCH(nansum_stub, &nansum_kernel_cuda) REGISTER_DISPATCH(prod_stub, &prod_kernel_cuda) REGISTER_DISPATCH(xor_sum_stub, &xor_sum_kernel_cuda) +======= +REGISTER_DISPATCH(sum_stub, &sum_kernel_cuda) +REGISTER_DISPATCH(nansum_stub, &nansum_kernel_cuda) +REGISTER_DISPATCH(prod_stub, &prod_kernel_cuda) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace at::native diff --git a/aten/src/ATen/native/cuda/Repeat.cu b/aten/src/ATen/native/cuda/Repeat.cu index 1e2364ae50913..00de1c9f33eaf 100644 --- a/aten/src/ATen/native/cuda/Repeat.cu +++ b/aten/src/ATen/native/cuda/Repeat.cu @@ -17,6 +17,7 @@ __global__ static void compute_cuda_kernel( index_t* result_ptr, int64_t size, int64_t result_size) { +<<<<<<< HEAD if (C10_UNLIKELY((result_size != cumsum_ptr[size - 1]))) { printf("%s:%d:%s: block: [%d,%d,%d], thread: [%d,%d,%d] " "Invalid input! In `repeat_interleave`, the `output_size` argument (%ld) must be the same as the sum of the elements in the `repeats` tensor (%ld).\n", @@ -24,6 +25,9 @@ __global__ static void compute_cuda_kernel( CUDA_KERNEL_ASSERT(result_size == cumsum_ptr[size - 1]) } +======= + CUDA_KERNEL_ASSERT(result_size == cumsum_ptr[size - 1]); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int64_t idx = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x; int64_t stride = (blockDim.x * gridDim.x) / C10_WARP_SIZE; int warp_id = idx / C10_WARP_SIZE; diff --git a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu index 3eeca901a18d5..519d201945dff 100644 --- a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu +++ b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu @@ -9,7 +9,10 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wset-but-not-used") C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter") C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wmissing-field-initializers") +<<<<<<< HEAD C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-variable") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Determine if the architecture supports rowwise scaled mm // Currently failing on windows with: @@ -47,7 +50,10 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-variable") C10_DIAGNOSTIC_POP() C10_DIAGNOSTIC_POP() +<<<<<<< HEAD C10_DIAGNOSTIC_POP() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace { diff --git a/aten/src/ATen/native/cuda/ScaledGroupMM.cu b/aten/src/ATen/native/cuda/ScaledGroupMM.cu index 9a06c5907febc..837ce032d0f9f 100644 --- a/aten/src/ATen/native/cuda/ScaledGroupMM.cu +++ b/aten/src/ATen/native/cuda/ScaledGroupMM.cu @@ -10,7 +10,10 @@ // Two warninngs in Cutlass included header files C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wset-but-not-used") C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter") +<<<<<<< HEAD C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-variable") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Determine if the architecture supports rowwise scaled mm // Currently failing on windows with: @@ -47,7 +50,10 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-variable") C10_DIAGNOSTIC_POP() C10_DIAGNOSTIC_POP() +<<<<<<< HEAD C10_DIAGNOSTIC_POP() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace { @@ -298,9 +304,12 @@ void f8f8bf16_grouped_gemm_impl_sm90( Strides tensor_StrideA = make_strides(mat_a.strides()); Strides tensor_StrideB = make_strides(mat_b.strides()); Strides tensor_StrideOutput = make_strides(out.strides()); +<<<<<<< HEAD Strides tensor_ShapeA = make_strides(mat_a.sizes()); Strides tensor_ShapeB = make_strides(mat_b.sizes()); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // scale stride will be used inside the kernel only if needed, // so for 1d scales the "1" assigned here won't be used int64_t a_scale_stride = scale_a.stride(0); @@ -328,8 +337,11 @@ void f8f8bf16_grouped_gemm_impl_sm90( tensor_StrideA, tensor_StrideB, tensor_StrideOutput, +<<<<<<< HEAD tensor_ShapeA, tensor_ShapeB, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) a_scale_stride, b_scale_stride); diff --git a/aten/src/ATen/native/cuda/SegmentReduce.cu b/aten/src/ATen/native/cuda/SegmentReduce.cu index c6f88692a8a5c..ffd18f3f337ed 100644 --- a/aten/src/ATen/native/cuda/SegmentReduce.cu +++ b/aten/src/ATen/native/cuda/SegmentReduce.cu @@ -20,7 +20,11 @@ // SegmentReduce compilation with CUDA-12.9 causes NVCC crash on Windows // See https://github.com/pytorch/pytorch/issues/156181 +<<<<<<< HEAD #if !(defined(_WIN32) && CUDART_VERSION == 12090) +======= +#if !defined(_WIN32) || CUDART_VERSION < 12090 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace at::native { @@ -606,4 +610,8 @@ REGISTER_DISPATCH( } // namespace at::native -#endif \ No newline at end of file +<<<<<<< HEAD +#endif +======= +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/native/cuda/TensorTopK.cu b/aten/src/ATen/native/cuda/TensorTopK.cu index 584c1c49a03ca..7c68bd79fde47 100644 --- a/aten/src/ATen/native/cuda/TensorTopK.cu +++ b/aten/src/ATen/native/cuda/TensorTopK.cu @@ -725,8 +725,13 @@ void launch( desired, counts, num_blocks, blocks_per_slice, kthCounts); C10_CUDA_KERNEL_LAUNCH_CHECK(); // Do a prefix scan of withinKCounts and kthCounts using slice_idx as keys to get the starting index of each block +<<<<<<< HEAD using counting_iter_t = ATEN_CUB_COUNTING_ITERATOR(uint32_t, uint32_t); using slice_idx_iter_t = ATEN_CUB_TRANSFORM_ITERATOR(uint32_t, BlockIdxToKey, counting_iter_t); +======= + using counting_iter_t = cub::CountingInputIterator; + using slice_idx_iter_t = cub::TransformInputIterator; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) slice_idx_iter_t slice_idx_iter(counting_iter_t(0), BlockIdxToKey(blocks_per_slice)); at::cuda::cub::inclusive_sum_by_key(slice_idx_iter, withinKCounts, withinKCounts, num_blocks); at::cuda::cub::inclusive_sum_by_key(slice_idx_iter, kthCounts, kthCounts, num_blocks); diff --git a/aten/src/ATen/native/cuda/UniqueCub.cu b/aten/src/ATen/native/cuda/UniqueCub.cu index 0a1f3408e783d..b3c5340aac393 100644 --- a/aten/src/ATen/native/cuda/UniqueCub.cu +++ b/aten/src/ATen/native/cuda/UniqueCub.cu @@ -54,7 +54,11 @@ struct LoadBoolOp { auto wrap_input_iterator(const bool *data) { // See NOTE [Loading boolean values] LoadBoolOp op; +<<<<<<< HEAD return ATEN_CUB_TRANSFORM_ITERATOR(bool, LoadBoolOp, const uint8_t*, int)( +======= + return NO_ROCM(at_cuda_detail)::cub::TransformInputIterator( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) reinterpret_cast(data), op); } @@ -259,10 +263,17 @@ struct UniqueCub { const bool* self_data = self.const_data_ptr(); MapNumberOfTrueValues op; +<<<<<<< HEAD ATEN_CUB_TRANSFORM_ITERATOR(int, MapNumberOfTrueValues, const uint8_t*, int) data_iter(reinterpret_cast(self_data), op); at::cuda::cub::reduce(data_iter, tmp_num_true.get(), num_inp, NO_ROCM(::cuda)::std::plus<>{}, 0); +======= + NO_ROCM(at_cuda_detail)::cub::TransformInputIterator + data_iter(reinterpret_cast(self_data), op); + at::cuda::cub::reduce(data_iter, tmp_num_true.get(), num_inp, + NO_ROCM(at_cuda_detail)::cub::Sum{}, 0); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto options = self.options(); output = at::empty({2}, self.options()); diff --git a/aten/src/ATen/native/cuda/cutlass_common.cuh b/aten/src/ATen/native/cuda/cutlass_common.cuh index 8f5143713aa99..b6503c9a73273 100644 --- a/aten/src/ATen/native/cuda/cutlass_common.cuh +++ b/aten/src/ATen/native/cuda/cutlass_common.cuh @@ -26,6 +26,7 @@ struct enable_3x_kernel_for_sm9x : Kernel { }; template +<<<<<<< HEAD struct enable_3x_kernel_for_sm10 : Kernel { template CUTLASS_DEVICE void operator()(Args&&... args) { @@ -36,6 +37,8 @@ struct enable_3x_kernel_for_sm10 : Kernel { }; template +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) struct enable_3x_kernel_for_sm10_or_later : Kernel { template CUTLASS_DEVICE void operator()(Args&&... args) { diff --git a/aten/src/ATen/native/cuda/fused_adam_utils.cuh b/aten/src/ATen/native/cuda/fused_adam_utils.cuh index 7a8f4a0d0e7e2..dc64c1f27914b 100644 --- a/aten/src/ATen/native/cuda/fused_adam_utils.cuh +++ b/aten/src/ATen/native/cuda/fused_adam_utils.cuh @@ -108,7 +108,11 @@ struct FusedAdamMathFunctor { "depth of 4 for Adam, depth of 5 for Adam with AMSGrad."); using opmath_t = at::opmath_type; C10_DEVICE __forceinline__ void operator()( +<<<<<<< HEAD int64_t chunk_size, +======= + int chunk_size, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) FusedOptimizerTensorListMetadata& tl, const float* lr_ptr, const double& lr, diff --git a/aten/src/ATen/native/cuda/int4mm.cu b/aten/src/ATen/native/cuda/int4mm.cu index 5444bb57eba7c..bcc8ecb652a85 100644 --- a/aten/src/ATen/native/cuda/int4mm.cu +++ b/aten/src/ATen/native/cuda/int4mm.cu @@ -1304,7 +1304,11 @@ at::Tensor _convert_weight_to_int4pack_cuda( constexpr int32_t kKTileSize = 16; // GPT-FAST assumes nTileSize of 8 for quantized weight tensor. +<<<<<<< HEAD // See https://github.com/meta-pytorch/gpt-fast/blob/091515ab5b06f91c0d6a3b92f9c27463f738cc9b/quantize.py#L510 +======= + // See https://github.com/pytorch-labs/gpt-fast/blob/091515ab5b06f91c0d6a3b92f9c27463f738cc9b/quantize.py#L510 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Torch dynamo also requires the torch ops has the same output shape for each device. // See https://github.com/pytorch/pytorch/blob/ec284d3a74ec1863685febd53687d491fd99a161/torch/_meta_registrations.py#L3263 constexpr int32_t kNTileSizeTensor = 8; diff --git a/aten/src/ATen/native/cuda/jit_utils.cpp b/aten/src/ATen/native/cuda/jit_utils.cpp index 152aa324002fb..c873658947457 100644 --- a/aten/src/ATen/native/cuda/jit_utils.cpp +++ b/aten/src/ATen/native/cuda/jit_utils.cpp @@ -45,7 +45,11 @@ namespace at::cuda::jit { // Copied from aten/src/ATen/cuda/llvm_basic.cpp, then modified as above. // If not compiling for ROCm, return the original get_traits_string(). std::string get_traits_string_but_hiprtc_safe() { +<<<<<<< HEAD #if defined(USE_ROCM) && HIP_VERSION_MAJOR < 7 +======= +#if defined(USE_ROCM) && ROCM_VERSION < 70000 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return R"ESCAPE( namespace std { diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu index 81387bcceaf01..a9602f6ef425b 100644 --- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu @@ -55,7 +55,11 @@ bool can_vectorize(const T * ptr, int alignment) { }; +<<<<<<< HEAD template +======= +template +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __global__ void RowwiseMomentsCUDAKernel( int64_t N, T_ACC eps, @@ -89,6 +93,7 @@ __global__ void RowwiseMomentsCUDAKernel( T_ACC m1; T_ACC m2; thrust::tie(m2, m1) = welford_op.project(val); +<<<<<<< HEAD if constexpr (!rms_norm){ mean[i] = m1; rstd[i] = c10::cuda::compat::rsqrt(m2 + eps); @@ -100,6 +105,14 @@ __global__ void RowwiseMomentsCUDAKernel( } template +======= + mean[i] = m1; + rstd[i] = c10::cuda::compat::rsqrt(m2 + eps); + } +} + +template +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __global__ void LayerNormForwardCUDAKernel( int64_t N, const T* X, @@ -113,6 +126,7 @@ __global__ void LayerNormForwardCUDAKernel( const int64_t index = i * N + j; const T_ACC gamma_v = gamma == nullptr ? T_ACC(1) : static_cast(gamma[j]); +<<<<<<< HEAD if constexpr (!rms_norm){ const T_ACC beta_v = beta == nullptr ? T_ACC(0) : static_cast(beta[j]); @@ -122,6 +136,13 @@ __global__ void LayerNormForwardCUDAKernel( } else { Y[index] = (static_cast(X[index])) * static_cast(rstd[i]) * gamma_v; } +======= + const T_ACC beta_v = + beta == nullptr ? T_ACC(0) : static_cast(beta[j]); + Y[index] = (static_cast(X[index]) - static_cast(mean[i])) * + static_cast(rstd[i]) * gamma_v + + beta_v; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } @@ -133,11 +154,16 @@ struct WelfordDataLN{ C10_HOST_DEVICE WelfordDataLN(float mean, float sigma2, float count): mean(mean), sigma2(sigma2), count(count) {} }; +<<<<<<< HEAD template __device__ +======= +template __device__ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) WelfordDataLN cuWelfordOnlineSum( const U val, const WelfordDataLN& curr_sum) { +<<<<<<< HEAD if constexpr (!rms_norm){ U delta = val - curr_sum.mean; U new_count = curr_sum.count + 1.f; @@ -153,10 +179,24 @@ WelfordDataLN cuWelfordOnlineSum( } template __device__ +======= + U delta = val - curr_sum.mean; + U new_count = curr_sum.count + 1.f; +#if defined(USE_ROCM) && defined(PYTORCH_LAYERNORM_FAST_RECIPROCAL) + U new_mean = curr_sum.mean + delta * __builtin_amdgcn_rcpf(new_count); +#else + U new_mean = curr_sum.mean + delta * (1.f/new_count); //proper division is slow, this is less accurate but noticeably faster +#endif + return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count}; +} + +__device__ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) WelfordDataLN cuWelfordCombine( const WelfordDataLN dataB, const WelfordDataLN dataA ) { +<<<<<<< HEAD if constexpr (!rms_norm){ using U = decltype(dataB.count); U delta = dataB.mean - dataA.mean; @@ -183,6 +223,30 @@ WelfordDataLN cuWelfordCombine( } template +======= + using U = decltype(dataB.count); + U delta = dataB.mean - dataA.mean; + U count = dataA.count + dataB.count; + U mean, sigma2; + if (count > decltype(dataB.count){0}) { +#if defined(USE_ROCM) && defined(PYTORCH_LAYERNORM_FAST_RECIPROCAL) + auto coef = __builtin_amdgcn_rcpf(count); +#else + auto coef = 1.f/count; //NB we don't use --use_fast_math, but this is emulation, 1./count goes to intrinsic, `* coef` is multiplication, instead of slow fp division +#endif + auto nA = dataA.count * coef; + auto nB = dataB.count * coef; + mean = nA*dataA.mean + nB*dataB.mean; + sigma2 = dataA.sigma2 + dataB.sigma2 + delta * delta * dataA.count * nB; + } else { + mean = U(0); + sigma2 = U(0); + } + return {mean, sigma2, count}; +} + +template +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __device__ WelfordDataLN compute_stats( const T* __restrict__ X, const int N, @@ -201,13 +265,23 @@ __device__ WelfordDataLN compute_stats( vec_t data = X_vec[i]; #pragma unroll for (int ii=0; ii < vec_size; ii++){ +<<<<<<< HEAD wd = cuWelfordOnlineSum(static_cast(data.val[ii]), wd); +======= + wd = cuWelfordOnlineSum(static_cast(data.val[ii]), wd); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } // intra-warp reduction for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) { +<<<<<<< HEAD WelfordDataLN wdB{WARP_SHFL_DOWN(wd.mean, offset), WARP_SHFL_DOWN(wd.sigma2, offset), WARP_SHFL_DOWN(wd.count, offset)}; wd = cuWelfordCombine(wd, wdB); +======= + WelfordDataLN wdB{WARP_SHFL_DOWN(wd.mean, offset), + WARP_SHFL_DOWN(wd.sigma2, offset), WARP_SHFL_DOWN(wd.count, offset)}; + wd = cuWelfordCombine(wd, wdB); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // threadIdx.x == 0 has correct values for each warp // inter-warp reductions @@ -228,7 +302,11 @@ __device__ WelfordDataLN compute_stats( WelfordDataLN wdB{meansigmabuf[2*threadIdx.y], meansigmabuf[2*threadIdx.y+1], countbuf[threadIdx.y]}; +<<<<<<< HEAD wd = cuWelfordCombine(wd, wdB); +======= + wd = cuWelfordCombine(wd, wdB); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } __syncthreads(); } @@ -245,7 +323,11 @@ __device__ WelfordDataLN compute_stats( } +<<<<<<< HEAD template >>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) typename std::enable_if_t, int> = 0> __device__ __inline__ void vectorized_layer_norm_kernel_impl( const int N, @@ -260,7 +342,11 @@ __device__ __inline__ void vectorized_layer_norm_kernel_impl( //as one thread would have to write 3 consecutive floats auto i1 = blockIdx.x; const T * block_row = X + i1 * N; +<<<<<<< HEAD WelfordDataLN wd = compute_stats(block_row, N, s_data); +======= + WelfordDataLN wd = compute_stats(block_row, N, s_data); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) using vec_t = aligned_vector; const vec_t * X_vec = reinterpret_cast(block_row); @@ -283,48 +369,73 @@ __device__ __inline__ void vectorized_layer_norm_kernel_impl( if (gamma_vec != nullptr && beta_vec != nullptr) { #pragma unroll for (int ii=0; ii < vec_size; ii++){ +<<<<<<< HEAD if constexpr (!rms_norm){ out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * (static_cast(data.val[ii]) - wd.mean)) + static_cast(beta_vec[i].val[ii]); } else { out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * static_cast(data.val[ii])); } +======= + out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * (static_cast(data.val[ii]) - wd.mean)) + + static_cast(beta_vec[i].val[ii]); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } else if (gamma_vec != nullptr) { #pragma unroll for (int ii=0; ii < vec_size; ii++){ +<<<<<<< HEAD if constexpr (!rms_norm){ out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * (static_cast(data.val[ii]) - wd.mean)); } else { out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * static_cast(data.val[ii])); } +======= + out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * (static_cast(data.val[ii]) - wd.mean)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } else if (beta_vec != nullptr) { #pragma unroll for (int ii=0; ii < vec_size; ii++){ +<<<<<<< HEAD out.val[ii] = (rstd_val * (static_cast(data.val[ii]) - wd.mean)) + static_cast(beta_vec[i].val[ii]); +======= + out.val[ii] = (rstd_val * (static_cast(data.val[ii]) - wd.mean)) + static_cast(beta_vec[i].val[ii]); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } else { #pragma unroll for (int ii=0; ii < vec_size; ii++){ +<<<<<<< HEAD if constexpr (!rms_norm){ out.val[ii] = rstd_val * (static_cast(data.val[ii]) - wd.mean); } else { out.val[ii] = rstd_val * static_cast(data.val[ii]); } +======= + out.val[ii] = rstd_val * (static_cast(data.val[ii]) - wd.mean); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } Y_vec[i] = out; } if (thrx == 0) { +<<<<<<< HEAD if constexpr (!rms_norm){ mean[i1] = wd.mean; } +======= + mean[i1] = wd.mean; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) rstd[i1] = rstd_val; } } +<<<<<<< HEAD template >>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) typename std::enable_if_t, int> = 0> __device__ __inline__ void vectorized_layer_norm_kernel_impl( const int /*N*/, @@ -339,7 +450,11 @@ __device__ __inline__ void vectorized_layer_norm_kernel_impl( } //to avoid windows SFINAE errors +<<<<<<< HEAD template +======= +template +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __global__ void vectorized_layer_norm_kernel( const int N, T_ACC eps, @@ -349,11 +464,19 @@ __global__ void vectorized_layer_norm_kernel( T_ACC* mean, T_ACC* rstd, T* Y){ +<<<<<<< HEAD vectorized_layer_norm_kernel_impl(N, eps, X, gamma, beta, mean, rstd, Y); } template +======= + vectorized_layer_norm_kernel_impl(N, eps, X, gamma, beta, mean, rstd, Y); + } + + +template +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __device__ __inline__ void compute_gI( const T* __restrict__ dY, const T* __restrict__ X, @@ -364,10 +487,14 @@ __device__ __inline__ void compute_gI( const int N, T_ACC * buf){ const auto i1 = blockIdx.x; +<<<<<<< HEAD T_ACC mean_val = 0; if constexpr (!rms_norm){ mean_val = mean[i1]; } +======= + const T_ACC mean_val = mean[i1]; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const T_ACC rstd_val = rstd[i1]; T_ACC stats_x1{0}, stats_x2{0}; constexpr int unroll = 4; @@ -383,18 +510,24 @@ __device__ __inline__ void compute_gI( const auto gamma_val = (gamma != nullptr) ? static_cast(gamma[l+k]) : T_ACC(1); const auto c_h = static_cast(X_i[l+k]); const auto c_loss = static_cast(dY_i[l+k]); +<<<<<<< HEAD if constexpr (!rms_norm){ stats_x1 += c_loss * gamma_val; stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; } else { stats_x2 += c_loss * gamma_val * (c_h) * rstd_val; } +======= + stats_x1 += c_loss * gamma_val; + stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } for (; l < N; l ++) { const auto gamma_val = (gamma != nullptr) ? static_cast(gamma[l]) : T_ACC(1); const auto c_h = static_cast(X_i[l]); const auto c_loss = static_cast(dY_i[l]); +<<<<<<< HEAD if constexpr (!rms_norm){ stats_x1 += c_loss * gamma_val; stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; @@ -416,6 +549,20 @@ __device__ __inline__ void compute_gI( if constexpr (!rms_norm){ stats_x1 = buf[0]; } +======= + stats_x1 += c_loss * gamma_val; + stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; + } + + stats_x1 = cuda_utils::BlockReduceSum(stats_x1, buf); + stats_x2 = cuda_utils::BlockReduceSum(stats_x2, buf); + if (threadIdx.x == 0) { + buf[0] = stats_x1; + buf[1] = stats_x2; + } + __syncthreads(); + stats_x1 = buf[0]; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) stats_x2 = buf[1]; T_ACC fH = N; T_ACC term1 = (T_ACC(1) / fH) * rstd_val; @@ -426,6 +573,7 @@ __device__ __inline__ void compute_gI( const auto gamma_val = (gamma != nullptr) ? static_cast(gamma[l]) : T_ACC(1); T_ACC f_grad_input = fH * gamma_val * dy; +<<<<<<< HEAD if constexpr (!rms_norm){ f_grad_input -= (x - mean_val) * rstd_val * stats_x2; f_grad_input -= stats_x1; @@ -433,13 +581,21 @@ __device__ __inline__ void compute_gI( f_grad_input -= (x) * rstd_val * stats_x2; } +======= + f_grad_input -= (x - mean_val) * rstd_val * stats_x2; + f_grad_input -= stats_x1; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f_grad_input *= term1; dX_i[l] = f_grad_input; } } +<<<<<<< HEAD template +======= +template +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __global__ void layer_norm_grad_input_kernel( const T* __restrict__ dY, const T* __restrict__ X, @@ -451,7 +607,11 @@ __global__ void layer_norm_grad_input_kernel( alignas(sizeof(double)) extern __shared__ char s_data1[]; T_ACC * buf = reinterpret_cast(&s_data1); +<<<<<<< HEAD compute_gI(dY, X, mean, rstd, gamma, dX, N, buf); +======= + compute_gI(dY, X, mean, rstd, gamma, dX, N, buf); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } @@ -460,7 +620,11 @@ __global__ void layer_norm_grad_input_kernel( // faster measured at PT operator level, with cases seeing a 2X speedup (where N >> M). // There are no noticeable regressions on the rest of the sizes. +<<<<<<< HEAD template +======= +template +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __global__ void layer_norm_grad_input_kernel_vectorized( const T* __restrict__ dY, const T* __restrict__ X, @@ -473,10 +637,14 @@ __global__ void layer_norm_grad_input_kernel_vectorized( T_ACC* reduce_buf = reinterpret_cast(&shared_data); const auto bIdx = blockIdx.x; +<<<<<<< HEAD T_ACC mean_val = 0; if constexpr (!rms_norm){ mean_val = mean[bIdx]; } +======= + const T_ACC mean_val = mean[bIdx]; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const T_ACC rstd_val = rstd[bIdx]; const T* X_i = X + bIdx * N; const T* dY_i = dY + bIdx * N; @@ -508,12 +676,17 @@ __global__ void layer_norm_grad_input_kernel_vectorized( const auto gamma_val = static_cast(gamma_vec_reg.val[k]); const auto c_h = static_cast(X_i_vec_reg.val[k]); const auto c_loss = static_cast(dY_i_vec_reg.val[k]); +<<<<<<< HEAD if constexpr (!rms_norm){ stats_x1 += c_loss * gamma_val; stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; } else { stats_x2 += c_loss * gamma_val * (c_h) * rstd_val; } +======= + stats_x1 += c_loss * gamma_val; + stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } @@ -522,6 +695,7 @@ __global__ void layer_norm_grad_input_kernel_vectorized( const auto gamma_val = (gamma != nullptr) ? static_cast(gamma[l]) : T_ACC(1); const auto c_h = static_cast(X_i[l]); const auto c_loss = static_cast(dY_i[l]); +<<<<<<< HEAD if constexpr (!rms_norm){ stats_x1 += c_loss * gamma_val; stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; @@ -545,6 +719,21 @@ __global__ void layer_norm_grad_input_kernel_vectorized( if constexpr (!rms_norm){ stats_x1 = reduce_buf[0]; } +======= + stats_x1 += c_loss * gamma_val; + stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; + } + + // Reduction in Shared Memory + stats_x1 = cuda_utils::BlockReduceSum(stats_x1, reduce_buf); + stats_x2 = cuda_utils::BlockReduceSum(stats_x2, reduce_buf); + if (threadIdx.x == 0) { + reduce_buf[0] = stats_x1; + reduce_buf[1] = stats_x2; + } + __syncthreads(); + stats_x1 = reduce_buf[0]; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) stats_x2 = reduce_buf[1]; T_ACC fH = N; @@ -566,12 +755,17 @@ __global__ void layer_norm_grad_input_kernel_vectorized( const auto dy = static_cast(dY_i_vec_reg.val[k]); T_ACC f_grad_input = fH * gamma_val * dy; +<<<<<<< HEAD if constexpr (!rms_norm){ f_grad_input -= (x - mean_val) * rstd_val * stats_x2; f_grad_input -= stats_x1; } else { f_grad_input -= (x) * rstd_val * stats_x2; } +======= + f_grad_input -= (x - mean_val) * rstd_val * stats_x2; + f_grad_input -= stats_x1; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f_grad_input *= term1; dX_i_vec_reg.val[k] = f_grad_input; } @@ -586,19 +780,28 @@ __global__ void layer_norm_grad_input_kernel_vectorized( const auto gamma_val = (gamma != nullptr) ? static_cast(gamma[l]) : T_ACC(1); T_ACC f_grad_input = fH * gamma_val * dy; +<<<<<<< HEAD if constexpr (!rms_norm){ f_grad_input -= (x - mean_val) * rstd_val * stats_x2; f_grad_input -= stats_x1; } else { f_grad_input -= (x) * rstd_val * stats_x2; } +======= + f_grad_input -= (x - mean_val) * rstd_val * stats_x2; + f_grad_input -= stats_x1; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f_grad_input *= term1; dX_i[l] = f_grad_input; } } +<<<<<<< HEAD template +======= +template +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __global__ void GammaBetaBackwardSimpleCUDAKernel( int64_t M, int64_t N, @@ -614,6 +817,7 @@ __global__ void GammaBetaBackwardSimpleCUDAKernel( T_ACC sum2 = 0; for (int64_t i = 0; i < M; ++i) { const int64_t index = i * N + j; +<<<<<<< HEAD if constexpr (!rms_norm){ sum1 += dg == nullptr ? T_ACC(0) : static_cast(dY[index]) * @@ -625,14 +829,25 @@ __global__ void GammaBetaBackwardSimpleCUDAKernel( : static_cast(dY[index]) * (static_cast(X[index])) * static_cast(rstd[i]); } +======= + sum1 += dg == nullptr ? T_ACC(0) + : static_cast(dY[index]) * + (static_cast(X[index]) - static_cast(mean[i])) * + static_cast(rstd[i]); + sum2 += db == nullptr ? T_ACC(0) : static_cast(dY[index]); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } if (dg != nullptr) { dg[j] = sum1; } if (db != nullptr) { +<<<<<<< HEAD if constexpr (!rms_norm){ db[j] = sum2; } +======= + db[j] = sum2; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } } @@ -642,8 +857,12 @@ unsigned int block_dim_x, unsigned int block_dim_y, unsigned int rows_per_block_y, bool check_x, +<<<<<<< HEAD bool check_y, bool rms_norm> +======= +bool check_y> +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __device__ __forceinline__ void @@ -667,9 +886,13 @@ blockReduceGammaBetaBackwardsHelper( int64_t mean_index = M_start + threadIdx.y * rows_per_thread_y; T_ACC warp_mean = 0, warp_rstd = 0; if (lane_id < rows_per_thread_y && mean_index + lane_id < M) { +<<<<<<< HEAD if constexpr (!rms_norm){ warp_mean = mean[mean_index + lane_id]; } +======= + warp_mean = mean[mean_index + lane_id]; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) warp_rstd = rstd[mean_index + lane_id]; } // We do a WARP_SYNC() here because we use WARP_SHFL below to access @@ -696,6 +919,7 @@ blockReduceGammaBetaBackwardsHelper( #pragma unroll for (int i = 0; i < rows_per_thread_y; ++i) { +<<<<<<< HEAD T_ACC rstd_reg = WARP_SHFL(warp_rstd, i, kWarpSize); if constexpr (!rms_norm){ T_ACC mean_reg = WARP_SHFL(warp_mean, i, kWarpSize); @@ -704,6 +928,12 @@ blockReduceGammaBetaBackwardsHelper( } else{ dg_sum += dY_regs[i] * (X_regs[i]) * rstd_reg; } +======= + T_ACC mean_reg = WARP_SHFL(warp_mean, i, kWarpSize); + T_ACC rstd_reg = WARP_SHFL(warp_rstd, i, kWarpSize); + dg_sum += dY_regs[i] * (X_regs[i] - mean_reg) * rstd_reg; + db_sum += dY_regs[i]; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } @@ -712,8 +942,12 @@ unsigned int block_dim_x, unsigned int block_dim_y, unsigned int rows_per_block_y, bool check_x, +<<<<<<< HEAD bool check_y, bool rms_norm> +======= +bool check_y> +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __device__ __forceinline__ void @@ -734,10 +968,17 @@ blockReduceGammaBetaBackwardsWithChecks( M_start += rows_per_block_y * gridDim.y) { int64_t M_end = M_start + rows_per_block_y - 1; if (!check_y || M_end < M) { +<<<<<<< HEAD blockReduceGammaBetaBackwardsHelper (M_start, M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); } else { blockReduceGammaBetaBackwardsHelper +======= + blockReduceGammaBetaBackwardsHelper + (M_start, M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); + } else { + blockReduceGammaBetaBackwardsHelper +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) (M_start, M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); } } @@ -759,8 +1000,12 @@ template >>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) > __global__ void @@ -785,7 +1030,11 @@ __launch_bounds__(block_dim_x * block_dim_y) // When N and M align perfectly with block_dim_x and block_dim_y, we // can skip boundary condition checks that waste instruction issue slots. blockReduceGammaBetaBackwardsWithChecks +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); } else { // In the general case we need to check boundary conditions in the M @@ -793,11 +1042,19 @@ __launch_bounds__(block_dim_x * block_dim_y) // for the inner blocks. So try to avoid those checks when possible. if (blockIdx.x * block_dim_x + block_dim_x - 1 < N) { blockReduceGammaBetaBackwardsWithChecks +<<<<<<< HEAD (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); } else { blockReduceGammaBetaBackwardsWithChecks +======= + + (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); + } else { + blockReduceGammaBetaBackwardsWithChecks + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); } } @@ -812,7 +1069,11 @@ __launch_bounds__(block_dim_x * block_dim_y) if (dg) { dg[thread_y * N + thread_x] = dg_sum; } +<<<<<<< HEAD if (db && !rms_norm) { +======= + if (db) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) db[thread_y * N + thread_x] = db_sum; } } @@ -858,7 +1119,11 @@ __launch_bounds__(block_dim_x * block_dim_y) if (dg) { dg[out_index] = reg_dg; } +<<<<<<< HEAD if (db && !rms_norm) { +======= + if (db) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) db[out_index] = reg_db; } } @@ -869,8 +1134,12 @@ __launch_bounds__(block_dim_x * block_dim_y) template +======= +bool partial_reduction> +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void LaunchAndCheckGammaBetaBackwardKernel( bool aligned_grid, dim3 blocks, @@ -886,7 +1155,11 @@ void LaunchAndCheckGammaBetaBackwardKernel( T* dgamma_data, T* dbeta_data) { if (aligned_grid) { +<<<<<<< HEAD GammaBetaBackwardCUDAKernelTemplate +======= + GammaBetaBackwardCUDAKernelTemplate +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) <<>>( M, N, @@ -897,7 +1170,11 @@ if (aligned_grid) { dgamma_data, dbeta_data); } else { +<<<<<<< HEAD GammaBetaBackwardCUDAKernelTemplate +======= + GammaBetaBackwardCUDAKernelTemplate +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) <<>>( M, N, @@ -913,7 +1190,11 @@ if (aligned_grid) { template +======= +int rows_per_block_y> +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void ConfigureAndLaunchGammaBetaBackwardKernel( const T* dY_data, const T* X_data, @@ -936,16 +1217,27 @@ void ConfigureAndLaunchGammaBetaBackwardKernel( if (blocks.y == 1 && threads.y == 1) { // Optimization: since there is just one thread doing all the summation, we don't need a reduction // across threads. So we set partial_reduction to true. +<<<<<<< HEAD LaunchAndCheckGammaBetaBackwardKernel( aligned_grid, blocks, threads, shmem_sz, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_data, dbeta_data); } else { LaunchAndCheckGammaBetaBackwardKernel( +======= + LaunchAndCheckGammaBetaBackwardKernel( + aligned_grid, blocks, threads, shmem_sz, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_data, dbeta_data); + } else { + LaunchAndCheckGammaBetaBackwardKernel( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) aligned_grid, blocks, threads, shmem_sz, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_data, dbeta_data); } } +<<<<<<< HEAD template +======= +template +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void LaunchGammaBetaBackwardCUDAKernel( const T* dY_data, const T* X_data, @@ -983,21 +1275,34 @@ void LaunchGammaBetaBackwardCUDAKernel( dgamma_blocks = at::empty({blocks.y * threads.y, dgamma->size(-1)}, options); dgamma_blocks_ptr = dgamma_blocks.data_ptr(); } +<<<<<<< HEAD if (dbeta->defined() && !rms_norm) { +======= + if (dbeta->defined()) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto options = dbeta->options(); dbeta_blocks = at::empty({blocks.y * threads.y, dgamma->size(-1)}, options); dbeta_blocks_ptr = dbeta_blocks.data_ptr(); } +<<<<<<< HEAD LaunchAndCheckGammaBetaBackwardKernel( +======= + LaunchAndCheckGammaBetaBackwardKernel( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) aligned_grid, blocks, threads, 0, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_blocks_ptr, dbeta_blocks_ptr); if (dgamma_blocks.defined()) { *dgamma = dgamma_blocks.sum(0); } +<<<<<<< HEAD if constexpr (!rms_norm){ if (dbeta_blocks.defined()) { *dbeta = dbeta_blocks.sum(0); } +======= + if (dbeta_blocks.defined()) { + *dbeta = dbeta_blocks.sum(0); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } else { // We are in the normal case where M is not that large. @@ -1005,6 +1310,7 @@ void LaunchGammaBetaBackwardCUDAKernel( // For small M it is faster to have a smaller tile, otherwise we could have idle threads. // For larger M we use a bigger tile size. if (M < 64) { +<<<<<<< HEAD ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); } else if (M < 128) { ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); @@ -1012,11 +1318,24 @@ void LaunchGammaBetaBackwardCUDAKernel( ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); } else { ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); +======= + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + } else if (M < 128) { + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + } else if (M < 256) { + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + } else { + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } } +<<<<<<< HEAD template +======= +template +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void launch_vectorized_layer_norm_kernel( int N, int64_t M, @@ -1045,7 +1364,11 @@ void launch_vectorized_layer_norm_kernel( TORCH_INTERNAL_ASSERT_DEBUG_ONLY(threads.y % 2 == 0 || threads.y == 1); int nshared = threads.y > 1 ? threads.y * 3/2 *sizeof(T_ACC) : 0; +<<<<<<< HEAD vectorized_layer_norm_kernel<<>>(N, eps, X_data, +======= + vectorized_layer_norm_kernel<<>>(N, eps, X_data, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) gamma_data, beta_data, mean_data, rstd_data, Y_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); @@ -1067,7 +1390,11 @@ void launch_vectorized_layer_norm_kernel( blocks.x = (remaining > blocks.x) ? blocks.x : remaining; +<<<<<<< HEAD vectorized_layer_norm_kernel<<>>(N, eps, X_data2, +======= + vectorized_layer_norm_kernel<<>>(N, eps, X_data2, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) gamma_data, beta_data, mean_data2, rstd_data2, Y_data2); C10_CUDA_KERNEL_LAUNCH_CHECK(); @@ -1077,7 +1404,11 @@ void launch_vectorized_layer_norm_kernel( } +<<<<<<< HEAD template +======= +template +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void LayerNormKernelImplInternal( const Tensor& X, const Tensor& gamma, @@ -1096,7 +1427,11 @@ void LayerNormKernelImplInternal( const T* gamma_data = gamma.defined() ? gamma.const_data_ptr() : nullptr; const T* beta_data = beta.defined() ? beta.const_data_ptr() : nullptr; T* Y_data = Y->data_ptr(); +<<<<<<< HEAD T_ACC* mean_data = !rms_norm ? mean->data_ptr() : nullptr; +======= + T_ACC* mean_data = mean->data_ptr(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T_ACC* rstd_data = rstd->data_ptr(); // check if can take fast path - all tensors are properly aligned, N is less than 2^24 (to use float count), @@ -1111,6 +1446,7 @@ void LayerNormKernelImplInternal( if ((std::is_same_v || std::is_same_v || std::is_same_v) && N <= static_cast(1ULL << std::numeric_limits::digits) && N % num_vec_elems == 0 && can_vec_X && can_vec_Y && can_vec_gamma && can_vec_beta) { +<<<<<<< HEAD launch_vectorized_layer_norm_kernel(static_cast(N), M, eps, X_data, gamma_data, beta_data, Y_data, mean_data, rstd_data); } else { cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream(); @@ -1119,6 +1455,16 @@ void LayerNormKernelImplInternal( N, eps, X_data, mean_data, rstd_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); LayerNormForwardCUDAKernel<<>>( +======= + launch_vectorized_layer_norm_kernel(static_cast(N), M, eps, X_data, gamma_data, beta_data, Y_data, mean_data, rstd_data); + } else { + cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream(); + RowwiseMomentsCUDAKernel + <<>>( + N, eps, X_data, mean_data, rstd_data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + LayerNormForwardCUDAKernel<<>>( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) N, X_data, mean_data, rstd_data, gamma_data, beta_data, Y_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -1146,6 +1492,7 @@ void LayerNormKernelImpl( }); } +<<<<<<< HEAD void RmsNormKernelImpl( const Tensor& X, const Tensor& gamma, @@ -1169,6 +1516,9 @@ AT_DISPATCH_FLOATING_TYPES_AND2( } template __device__ +======= +template __device__ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void cuLoadWriteStridedInputs( const int i1_block, const int thr_load_row_off, @@ -1186,10 +1536,14 @@ void cuLoadWriteStridedInputs( { int i1 = i1_block+thr_load_row_off; if (i1 < i1_end) { +<<<<<<< HEAD T_ACC curr_mean = 0; if constexpr (!rms_norm){ curr_mean = mean[i1]; } +======= + T_ACC curr_mean = mean[i1]; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T_ACC curr_rstd = rstd[i1]; for (int k = 0; k < blockDim.y; ++k) { int i2 = i2_off + k; @@ -1214,7 +1568,11 @@ void cuLoadWriteStridedInputs( } } +<<<<<<< HEAD template __device__ +======= +template __device__ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void cuLoadAddStridedInputs( const int i1_block, const int thr_load_row_off, @@ -1232,11 +1590,15 @@ void cuLoadAddStridedInputs( { int i1 = i1_block+thr_load_row_off; if (i1 < i1_end) { +<<<<<<< HEAD T_ACC curr_mean = 0; if constexpr (!rms_norm){ curr_mean = mean[i1]; } +======= + T_ACC curr_mean = mean[i1]; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T_ACC curr_rstd = rstd[i1]; for (int k = 0; k < blockDim.y; ++k) { int i2 = i2_off + k; @@ -1252,7 +1614,11 @@ void cuLoadAddStridedInputs( } } +<<<<<<< HEAD template __global__ +======= +template __global__ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void cuComputePartGradGammaBeta( const T* __restrict__ dout, const T* __restrict__ input, @@ -1278,9 +1644,15 @@ void cuComputePartGradGammaBeta( T_ACC* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; // compute partial sums from strided inputs // do this to increase number of loads in flight +<<<<<<< HEAD cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd); for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) { cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd); +======= + cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd); + for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) { + cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } __syncthreads(); // inter-warp reductions @@ -1319,7 +1691,11 @@ void cuComputePartGradGammaBeta( } } +<<<<<<< HEAD template __global__ +======= +template __global__ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void cuComputeGradGammaBeta( const T_ACC* part_grad_gamma, const T_ACC* part_grad_beta, @@ -1344,9 +1720,13 @@ void cuComputeGradGammaBeta( if (i2 < N) { for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) { sum_gamma += part_grad_gamma_ptr[warp_offset*N]; +<<<<<<< HEAD if constexpr (!rms_norm){ sum_beta += part_grad_beta_ptr[warp_offset*N]; } +======= + sum_beta += part_grad_beta_ptr[warp_offset*N]; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } @@ -1364,9 +1744,13 @@ void cuComputeGradGammaBeta( if (threadIdx.y < offset) { const int read_idx = threadIdx.y * blockDim.x + threadIdx.x; sum_gamma += buf[read_idx]; +<<<<<<< HEAD if constexpr (!rms_norm){ sum_beta += buf[read_idx+nbsize3]; } +======= + sum_beta += buf[read_idx+nbsize3]; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } __syncthreads(); } @@ -1377,14 +1761,22 @@ void cuComputeGradGammaBeta( grad_gamma[i2] = sum_gamma; } if (grad_beta) { +<<<<<<< HEAD if constexpr (!rms_norm){ grad_beta[i2] = sum_beta; } +======= + grad_beta[i2] = sum_beta; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } } +<<<<<<< HEAD template __global__ +======= +template __global__ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void cuComputeGradInput( const T* __restrict__ dout, const T* __restrict__ input, @@ -1398,10 +1790,14 @@ void cuComputeGradInput( for (int i1=blockIdx.y; i1 < M; i1 += gridDim.y) { T_ACC sum_loss1 = T_ACC(0); T_ACC sum_loss2 = T_ACC(0); +<<<<<<< HEAD T_ACC c_mean = 0; if constexpr (!rms_norm){ c_mean = mean[i1]; } +======= + T_ACC c_mean = mean[i1]; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const T_ACC c_rstd = rstd[i1]; const T* k_input = input + i1*N; const T* k_dout = dout + i1*N; @@ -1414,31 +1810,45 @@ void cuComputeGradInput( const T_ACC gamma_idx = static_cast((idx((idx((idx>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } else { for( int l = 0; l < N ; l += numx) { int idx = l + thrx; const T_ACC c_h = static_cast((idx((idx>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } // intra-warp reductions for (int mask = blockDim.x/2; mask > 0; mask /= 2) { +<<<<<<< HEAD if constexpr (!rms_norm){ sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); } +======= + sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask); } // inter-warp reductions @@ -1449,33 +1859,49 @@ void cuComputeGradInput( // upper half of warps write to shared if (threadIdx.y >= offset && threadIdx.y < 2*offset) { const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x; +<<<<<<< HEAD if constexpr (!rms_norm){ buf[2*wrt_i] = sum_loss1; } +======= + buf[2*wrt_i] = sum_loss1; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) buf[2*wrt_i+1] = sum_loss2; } __syncthreads(); // lower half merges if (threadIdx.y < offset) { const int read_i = threadIdx.y * blockDim.x + threadIdx.x; +<<<<<<< HEAD if constexpr (!rms_norm){ sum_loss1 += buf[2*read_i]; } +======= + sum_loss1 += buf[2*read_i]; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sum_loss2 += buf[2*read_i+1]; } __syncthreads(); } if (threadIdx.y == 0) { +<<<<<<< HEAD if constexpr (!rms_norm){ buf[2*threadIdx.x] = sum_loss1; } +======= + buf[2*threadIdx.x] = sum_loss1; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) buf[2*threadIdx.x+1] = sum_loss2; } __syncthreads(); if (threadIdx.y !=0) { +<<<<<<< HEAD if constexpr (!rms_norm){ sum_loss1 = buf[2*threadIdx.x]; } +======= + sum_loss1 = buf[2*threadIdx.x]; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sum_loss2 = buf[2*threadIdx.x+1]; } } @@ -1488,12 +1914,17 @@ void cuComputeGradInput( const T_ACC c_h = static_cast(k_input[l]); const T_ACC c_loss = static_cast(k_dout[l]); T_ACC f_grad_input = fH * c_loss * gamma[l]; +<<<<<<< HEAD if constexpr (!rms_norm){ f_grad_input -= sum_loss1; f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2; } else { f_grad_input -= (c_h) * c_rstd * sum_loss2; } +======= + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f_grad_input *= term1; k_grad_input[l] = static_cast(f_grad_input); } @@ -1502,12 +1933,17 @@ void cuComputeGradInput( const T_ACC c_h = static_cast(k_input[l]); const T_ACC c_loss = static_cast(k_dout[l]); T_ACC f_grad_input = fH * c_loss; +<<<<<<< HEAD if constexpr (!rms_norm){ f_grad_input -= sum_loss1; f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2; } else { f_grad_input -= (c_h) * c_rstd * sum_loss2; } +======= + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f_grad_input *= term1; k_grad_input[l] = static_cast(f_grad_input); } @@ -1517,7 +1953,11 @@ void cuComputeGradInput( } } +<<<<<<< HEAD template +======= +template +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void LayerNormBackwardKernelImplInternal( const Tensor& dY, const Tensor& X, @@ -1531,9 +1971,13 @@ void LayerNormBackwardKernelImplInternal( Tensor* dbeta) { using T_ACC = acc_type; TORCH_CHECK(dY.numel() == M * N); +<<<<<<< HEAD if constexpr (!rms_norm){ TORCH_CHECK(mean.numel() == M); } +======= + TORCH_CHECK(mean.numel() == M); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_CHECK(rstd.numel() == M); TORCH_CHECK(M <= at::cuda::getCurrentDeviceProperties()->maxGridSize[0], "M should be less than maximum CUDA grid size, \ file a support request to support bigger batches"); @@ -1559,7 +2003,11 @@ void LayerNormBackwardKernelImplInternal( threads1.y > 1 ? threads1.y*threads1.x*sizeof(T_ACC) : 0; +<<<<<<< HEAD cuComputeGradInput<<>>( +======= + cuComputeGradInput<<>>( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dY_data, X_data, M, N, @@ -1571,7 +2019,11 @@ void LayerNormBackwardKernelImplInternal( } else { const dim3 blocks(M); int nshared = (num_threads()/warp_size) * sizeof(T_ACC); +<<<<<<< HEAD layer_norm_grad_input_kernel<<>>(dY_data, +======= + layer_norm_grad_input_kernel<<>>(dY_data, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) X_data, mean_data, rstd_data, gamma_data, dX_data, N); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -1585,12 +2037,22 @@ void LayerNormBackwardKernelImplInternal( const unsigned int alignment = sizeof(T) * vec_size; bool bAlignedBuffers = can_vectorize(dY_data, alignment) && can_vectorize(X_data, alignment) && can_vectorize(gamma_data, alignment) && can_vectorize(dX_data, alignment); +<<<<<<< HEAD if (bAlignedBuffers && bTargetDataTypes && bVectorSizeMultiple) { layer_norm_grad_input_kernel_vectorized<<>>(dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data, N); C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { layer_norm_grad_input_kernel<<>>(dY_data, +======= + + if (bAlignedBuffers && bTargetDataTypes && bVectorSizeMultiple) { + layer_norm_grad_input_kernel_vectorized<<>>(dY_data, + X_data, mean_data, rstd_data, gamma_data, dX_data, N); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + layer_norm_grad_input_kernel<<>>(dY_data, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) X_data, mean_data, rstd_data, gamma_data, dX_data, N); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -1606,7 +2068,11 @@ void LayerNormBackwardKernelImplInternal( if (M < 128) { // For small batch size, do colwise reduce directly. const int64_t B = (N + kCUDANumThreads - 1) / kCUDANumThreads; +<<<<<<< HEAD GammaBetaBackwardSimpleCUDAKernel +======= + GammaBetaBackwardSimpleCUDAKernel +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) <<>>( M, N, @@ -1630,7 +2096,11 @@ void LayerNormBackwardKernelImplInternal( Tensor part_grad_gamma = at::empty({part_size,N}, gamma.options().dtype(part_grad_dtype)); Tensor part_grad_beta = at::native::empty_like(part_grad_gamma); +<<<<<<< HEAD cuComputePartGradGammaBeta<<>>( +======= + cuComputePartGradGammaBeta<<>>( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dY_data, X_data, M,N, @@ -1644,7 +2114,11 @@ void LayerNormBackwardKernelImplInternal( const dim3 blocks3((N + threads3.x - 1) / threads3.x, 1, 1); const int nshared3 = threads3.x * threads3.y * sizeof(T_ACC); +<<<<<<< HEAD cuComputeGradGammaBeta<<>>( +======= + cuComputeGradGammaBeta<<>>( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) part_grad_gamma.template data_ptr(), part_grad_beta.template data_ptr(), part_size, @@ -1654,7 +2128,11 @@ void LayerNormBackwardKernelImplInternal( C10_CUDA_KERNEL_LAUNCH_CHECK(); } #else +<<<<<<< HEAD LaunchGammaBetaBackwardCUDAKernel( +======= + LaunchGammaBetaBackwardCUDAKernel( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); #endif } @@ -1682,6 +2160,7 @@ void LayerNormBackwardKernelImpl( }); } +<<<<<<< HEAD void RMSNormBackwardKernelImpl( const Tensor& dY, const Tensor& X, @@ -1705,6 +2184,10 @@ void RMSNormBackwardKernelImpl( } // namespace +======= +} // namespace + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::tuple layer_norm_cuda( const Tensor& input, IntArrayRef normalized_shape, @@ -1833,6 +2316,7 @@ std::tuple layer_norm_backward_cuda( return std::make_tuple(std::move(dX), std::move(dgamma), std::move(dbeta)); } +<<<<<<< HEAD /* RMSNorm is implemented by reusing layer_norm's kernels */ std::tuple _fused_rms_norm_cuda( const Tensor& input, @@ -1940,6 +2424,8 @@ std::tuple _fused_rms_norm_backward_cuda( return std::make_tuple(std::move(dX), std::move(dgamma)); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) REGISTER_DISPATCH(LayerNormKernel, &LayerNormKernelImpl) REGISTER_DISPATCH(LayerNormBackwardKernel, &LayerNormBackwardKernelImpl) diff --git a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp index 71cbe361a0373..d82f7fa14a8d2 100644 --- a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp @@ -1433,7 +1433,11 @@ Tensor& cholesky_inverse_kernel_impl(Tensor &result, Tensor& infos, bool upper) // This function calculates the inverse matrix in-place // result should be in column major order and contain matrices to invert // the content of result is overwritten by 'apply_cholesky_inverse' +<<<<<<< HEAD #if defined(USE_LINALG_SOLVER) +======= +#if defined(USE_LINALG_SOLVER) && !defined(USE_ROCM) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto preferred_backend = at::globalContext().linalgPreferredBackend(); switch (preferred_backend) { case at::LinalgBackend::Cusolver: diff --git a/aten/src/ATen/native/cudnn/BatchNorm.cpp b/aten/src/ATen/native/cudnn/BatchNorm.cpp index 371b77722cd54..befe3e129e546 100644 --- a/aten/src/ATen/native/cudnn/BatchNorm.cpp +++ b/aten/src/ATen/native/cudnn/BatchNorm.cpp @@ -28,6 +28,7 @@ std::tuple cudnn_batch_norm( TORCH_CHECK(false, "cudnn_batch_norm: ATen not compiled with cuDNN support"); } +<<<<<<< HEAD std::tuple cudnn_batch_norm_out( const Tensor& input, const Tensor& weight, @@ -44,6 +45,8 @@ std::tuple cudnn_batch_norm_out( AT_ERROR("cudnn_batch_norm_out: ATen not compiled with cuDNN support"); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::tuple cudnn_batch_norm_backward( const Tensor& input, const Tensor& grad_output, @@ -136,12 +139,16 @@ size_t _get_cudnn_batch_norm_reserve_space_size( return reserve_size; } +<<<<<<< HEAD // Param `reserve` is a placeholder, just passing an empty tensor. // usage: // auto reserve = torch::empty({0}, torch::device(torch::kCUDA)); // at::native::cudnn_batch_norm_out(..., epsilon, output, save_mean, save_var, // reserve); std::tuple cudnn_batch_norm_out( +======= +std::tuple cudnn_batch_norm( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const Tensor& input_t, const Tensor& weight_t, const std::optional& bias_t_opt, @@ -149,11 +156,15 @@ std::tuple cudnn_batch_norm_out( const std::optional& running_var_t_opt, bool training, double exponential_average_factor, +<<<<<<< HEAD double epsilon, Tensor& output_t, Tensor& save_mean, Tensor& save_var, Tensor& reserve) { +======= + double epsilon) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned bias_t_maybe_owned = at::borrow_from_optional_tensor(bias_t_opt); @@ -193,6 +204,12 @@ std::tuple cudnn_batch_norm_out( cudnnBatchNormMode_t mode = getCudnnBatchNormMode( training, input->suggest_memory_format(), input->dim()); +<<<<<<< HEAD +======= + auto output_t = + at::empty_like(*input, input->options(), input->suggest_memory_format()); + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TensorArg output{output_t, "output", 0}; auto handle = getCudnnHandle(); @@ -204,8 +221,20 @@ std::tuple cudnn_batch_norm_out( Constant one(dataType, 1); Constant zero(dataType, 0); +<<<<<<< HEAD + + if (training) { +======= + Tensor save_mean, save_var; + + Tensor reserve; if (training) { + int64_t num_features = input_t.size(1); + save_mean = at::empty({num_features}, weight_t.options()); + save_var = at::empty({num_features}, weight_t.options()); + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto op = CUDNN_BATCHNORM_OPS_BN; size_t workspace_size; AT_CUDNN_CHECK(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize( @@ -253,6 +282,12 @@ std::tuple cudnn_batch_norm_out( reserve_size)); } else { reserve = at::empty({0}, input->options().dtype(kByte)); +<<<<<<< HEAD +======= + // This keeps a consistent output with native_batch_norm + save_mean = at::empty({0}, weight_t.options()); + save_var = at::empty({0}, weight_t.options()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AT_CUDNN_CHECK(cudnnBatchNormalizationForwardInference( handle, mode, @@ -273,6 +308,7 @@ std::tuple cudnn_batch_norm_out( // save_mean and save_var can be undefined // If this causes problems, we can initialize them to empty tensors // of the correct type +<<<<<<< HEAD return std::tuple{ output_t, save_mean, save_var, reserve}; } @@ -315,6 +351,12 @@ std::tuple cudnn_batch_norm( reserve); } +======= + return std::tuple{ + output_t, save_mean, save_var, reserve}; +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // NB: CuDNN only implements the backward algorithm for batchnorm // in training mode (evaluation mode batchnorm has a different algorithm), // which is why this doesn't accept a 'training' parameter. diff --git a/aten/src/ATen/native/cudnn/ConvShared.cpp b/aten/src/ATen/native/cudnn/ConvShared.cpp index 9b32f05482d5c..24a72a85883b9 100644 --- a/aten/src/ATen/native/cudnn/ConvShared.cpp +++ b/aten/src/ATen/native/cudnn/ConvShared.cpp @@ -169,8 +169,12 @@ std::string repro_from_args(const ConvolutionParams& params) { ss << "If that doesn't trigger the error, please include your original repro script when reporting this issue.\n\n"; ss << "import torch\n"; ss << "torch.backends.cuda.matmul.allow_tf32 = " +<<<<<<< HEAD << pybool(at::globalContext().float32Precision("cuda", "matmul") == "tf32") << "\n"; +======= + << pybool(at::globalContext().allowTF32CuBLAS()) << "\n"; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ss << "torch.backends.cudnn.benchmark = " << pybool(at::globalContext().benchmarkCuDNN()) << "\n"; ss << "torch.backends.cudnn.deterministic = " << pybool(params.deterministic) @@ -726,7 +730,11 @@ Tensor cudnn_convolution_relu( auto& ctx = at::globalContext(); bool benchmark = ctx.benchmarkCuDNN(); +<<<<<<< HEAD bool allow_tf32 = ctx.allowTF32CuDNN("conv"); +======= + bool allow_tf32 = ctx.allowTF32CuDNN(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto _bias = bias_t.has_value() ? bias_t.value() : at::zeros( @@ -784,7 +792,11 @@ Tensor cudnn_convolution_add_relu( } auto& ctx = at::globalContext(); +<<<<<<< HEAD bool allow_tf32 = ctx.allowTF32CuDNN("conv"); +======= + bool allow_tf32 = ctx.allowTF32CuDNN(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool benchmark = ctx.benchmarkCuDNN(); auto _alpha = alpha.has_value() ? alpha.value().to() : 1.0; auto _bias = bias_t.has_value() diff --git a/aten/src/ATen/native/cudnn/Conv_v7.cpp b/aten/src/ATen/native/cudnn/Conv_v7.cpp index 081b4afa15ac5..8dc97a8cd853e 100644 --- a/aten/src/ATen/native/cudnn/Conv_v7.cpp +++ b/aten/src/ATen/native/cudnn/Conv_v7.cpp @@ -285,7 +285,11 @@ struct algorithm_search { sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution forward algorithms"); int perf_count; +<<<<<<< HEAD c10::SmallVector perf_results; +======= + std::unique_ptr perf_results(new perf_t[num_algos]); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (!benchmark) { AT_CUDNN_CHECK_WITH_SHAPES( cudnnGetConvolutionForwardAlgorithm_v7( @@ -296,7 +300,11 @@ struct algorithm_search { args.odesc.desc(), num_algos, &perf_count, +<<<<<<< HEAD perf_results.data()), +======= + perf_results.get()), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) args); } else { size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos); @@ -314,7 +322,11 @@ struct algorithm_search { args.output.data_ptr(), num_algos, &perf_count, +<<<<<<< HEAD perf_results.data(), +======= + perf_results.get(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ws.data, ws.size), args); @@ -324,7 +336,11 @@ struct algorithm_search { // memory, e.g. a few GBs. c10::cuda::CUDACachingAllocator::emptyCache(); } +<<<<<<< HEAD return getValidAlgorithms(perf_results.data(), args, perf_count); +======= + return getValidAlgorithms(perf_results.get(), args, perf_count); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } static void getWorkspaceSize( @@ -369,8 +385,12 @@ struct algorithm_search { sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward data algorithms."); int perf_count; +<<<<<<< HEAD c10::SmallVector perf_results; +======= + std::unique_ptr perf_results(new perf_t[num_algos]); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (!benchmark) { AT_CUDNN_CHECK_WITH_SHAPES( cudnnGetConvolutionBackwardDataAlgorithm_v7( @@ -381,7 +401,11 @@ struct algorithm_search { args.idesc.desc(), num_algos, &perf_count, +<<<<<<< HEAD perf_results.data()), +======= + perf_results.get()), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) args); } else { size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos); @@ -399,7 +423,11 @@ struct algorithm_search { args.input.data_ptr(), num_algos, &perf_count, +<<<<<<< HEAD perf_results.data(), +======= + perf_results.get(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ws.data, ws.size), args); @@ -409,7 +437,11 @@ struct algorithm_search { // memory, e.g. a few GBs. c10::cuda::CUDACachingAllocator::emptyCache(); } +<<<<<<< HEAD return getValidAlgorithms(perf_results.data(), args, perf_count); +======= + return getValidAlgorithms(perf_results.get(), args, perf_count); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } static void getWorkspaceSize( @@ -457,8 +489,12 @@ struct algorithm_search { static_assert( sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward filter algorithms."); +<<<<<<< HEAD c10::SmallVector perf_results; +======= + std::unique_ptr perf_results(new perf_t[num_algos]); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int perf_count; if (!benchmark) { AT_CUDNN_CHECK_WITH_SHAPES( @@ -470,7 +506,11 @@ struct algorithm_search { args.wdesc.desc(), num_algos, &perf_count, +<<<<<<< HEAD perf_results.data()), +======= + perf_results.get()), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) args); } else { size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos); @@ -488,7 +528,11 @@ struct algorithm_search { args.weight.data_ptr(), num_algos, &perf_count, +<<<<<<< HEAD perf_results.data(), +======= + perf_results.get(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ws.data, ws.size), args); @@ -498,7 +542,11 @@ struct algorithm_search { // memory, e.g. a few GBs. c10::cuda::CUDACachingAllocator::emptyCache(); } +<<<<<<< HEAD return getValidAlgorithms(perf_results.data(), args, perf_count); +======= + return getValidAlgorithms(perf_results.get(), args, perf_count); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } static void getWorkspaceSize( diff --git a/aten/src/ATen/native/cudnn/MHA.cpp b/aten/src/ATen/native/cudnn/MHA.cpp index 1658ce34ca6c5..523cbf5561b1d 100644 --- a/aten/src/ATen/native/cudnn/MHA.cpp +++ b/aten/src/ATen/native/cudnn/MHA.cpp @@ -2,6 +2,7 @@ #include #include +<<<<<<< HEAD #if AT_CUDNN_ENABLED() #include #endif @@ -9,6 +10,11 @@ #if defined(USE_ROCM) || !AT_CUDNN_ENABLED() || \ (defined(CUDNN_VERSION) && CUDNN_VERSION < 8900) || \ (defined(CUDNN_FRONTEND_VERSION) && CUDNN_FRONTEND_VERSION < 10100) +======= +#if defined(USE_ROCM) || !AT_CUDNN_ENABLED() || \ + (defined(CUDNN_VERSION) && CUDNN_VERSION < 8900) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace at { namespace native { @@ -88,6 +94,7 @@ void run_cudnn_SDP_bprop( false, "PyTorch was not compiled with cuDNN Flash Attention enabled!"); } +<<<<<<< HEAD void run_cudnn_SDP_bprop_nestedtensor( int64_t b, int64_t h_q, @@ -119,6 +126,8 @@ void run_cudnn_SDP_bprop_nestedtensor( false, "PyTorch was not compiled with cuDNN Flash Attention enabled!"); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace native } // namespace at @@ -130,6 +139,10 @@ void run_cudnn_SDP_bprop_nestedtensor( #include #include +<<<<<<< HEAD +======= +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include @@ -145,6 +158,7 @@ namespace native { #include namespace fe = cudnn_frontend; +<<<<<<< HEAD constexpr uint8_t MAX_MHA_DIM = 4; @@ -197,6 +211,44 @@ int roundup_power2(int dim) { dim++; return dim; } +======= +using graph_and_tensors = std::tuple< + std::shared_ptr, + std::shared_ptr, // Q, + std::shared_ptr, // K, + std::shared_ptr, // V, + std::optional>, // Bias + std::shared_ptr, // Attn_scale, + // TODO(eqy): additional options + // std::shared_ptr, // SEQ_LEN_Q, + // std::shared_ptr, // SEQ_LEN_KV, + std::shared_ptr, // Seed, + std::shared_ptr, // Offset, + // std::shared_ptr, // Dropout_mask, + // std::shared_ptr, // Dropout_scale + std::shared_ptr, // O + std::shared_ptr // Stats + >; + +using graph_and_tensors_backward = std::tuple< + std::shared_ptr, + std::shared_ptr, // Q, + std::shared_ptr, // K, + std::shared_ptr, // V, + std::optional>, // Bias, + std::shared_ptr, // Attn_scale, + std::shared_ptr, // Seed, + std::shared_ptr, // Offset, + std::shared_ptr, // O, + std::shared_ptr, // dO, + std::shared_ptr, // stats, + std::shared_ptr, // dQ, + std::shared_ptr, // dK,, + std::shared_ptr // dV, + >; + +#define MAX_MHA_DIM 4 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) struct MHAParams { c10::DeviceIndex device_id; @@ -221,7 +273,10 @@ struct MHAParams { // might be redundant if we take 0 dim/stride // as signaling no-bias bool has_attn_bias; +<<<<<<< HEAD bool use_ragged; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; void setMHAParams( @@ -238,8 +293,12 @@ void setMHAParams( const std::optional& attn_bias, double dropout_probability, bool is_causal, +<<<<<<< HEAD bool return_softmaxstats, bool is_nested) { +======= + bool return_softmaxstats) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) memset(¶ms, 0, sizeof(MHAParams)); params.device_id = at::cuda::current_device(); params.dataType = fe::DataType_t::HALF; @@ -256,6 +315,7 @@ void setMHAParams( params.is_causal = is_causal; params.return_softmaxstats = return_softmaxstats; params.has_attn_bias = attn_bias.has_value(); +<<<<<<< HEAD // Expect 4D dense tensor, 3D nested case (THD) TORCH_INTERNAL_ASSERT( q.sizes().size() == (uint8_t)(MAX_MHA_DIM - (uint8_t)is_nested), @@ -274,6 +334,25 @@ void setMHAParams( "V tensor has unexpected number of dims, please report a bug to PyTorch."); TORCH_INTERNAL_ASSERT( v.strides().size() == (uint8_t)(MAX_MHA_DIM - (uint8_t)is_nested), +======= + TORCH_INTERNAL_ASSERT( + q.sizes().size() == MAX_MHA_DIM, + "Q tensor has unexpected number of dims, please report a bug to PyTorch."); + TORCH_INTERNAL_ASSERT( + q.strides().size() == MAX_MHA_DIM, + "Q tensor has unexpected number of dims, please report a bug to PyTorch."); + TORCH_INTERNAL_ASSERT( + k.sizes().size() == MAX_MHA_DIM, + "K tensor has unexpected number of dims, please report a bug to PyTorch."); + TORCH_INTERNAL_ASSERT( + k.strides().size() == MAX_MHA_DIM, + "K tensor has unexpected number of dims, please report a bug to PyTorch."); + TORCH_INTERNAL_ASSERT( + v.sizes().size() == MAX_MHA_DIM, + "V tensor has unexpected number of dims, please report a bug to PyTorch."); + TORCH_INTERNAL_ASSERT( + v.strides().size() == MAX_MHA_DIM, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "V tensor has unexpected number of dims, please report a bug to PyTorch."); std::copy(q.sizes().begin(), q.sizes().end(), params.q_dim.begin()); std::copy(q.strides().begin(), q.strides().end(), params.q_stride.begin()); @@ -281,6 +360,7 @@ void setMHAParams( std::copy(k.strides().begin(), k.strides().end(), params.k_stride.begin()); std::copy(v.sizes().begin(), v.sizes().end(), params.v_dim.begin()); std::copy(v.strides().begin(), v.strides().end(), params.v_stride.begin()); +<<<<<<< HEAD bool use_ragged = use_ragged_in_dense(q, k, v, q, params.has_attn_bias); params.use_ragged = use_ragged; if (use_ragged) { @@ -295,6 +375,8 @@ void setMHAParams( params.k_dim[2] = roundup_power2(params.k_dim[2]); params.v_dim[2] = roundup_power2(params.v_dim[2]); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // uninit is OK as the struct is memset 0'd if (params.has_attn_bias) { std::copy( @@ -322,8 +404,12 @@ struct MHACacheKeyWrapper : ParamsWrapper { const std::optional& attn_bias, double dropout_probability, bool is_causal, +<<<<<<< HEAD bool return_softmaxstats, bool is_nested) { +======= + bool return_softmaxstats) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) setMHAParams( this->pod, b, @@ -338,21 +424,29 @@ struct MHACacheKeyWrapper : ParamsWrapper { attn_bias, dropout_probability, is_causal, +<<<<<<< HEAD return_softmaxstats, is_nested); +======= + return_softmaxstats); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } }; template struct MHAGraphCache { std::unordered_map> engine_cache; +<<<<<<< HEAD int count = 0; int hits = 0; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // no mutexes here as caches are now thread local for v8, can also return a // pointer to the Execution Plan if we know it will not be invalidated by // another thread T* find(const KeyType& key) { +<<<<<<< HEAD static bool flag = c10::utils::check_env("TORCH_CUDNN_SDPA_CACHE_DEBUG") == true; if (flag && count) { @@ -364,11 +458,16 @@ struct MHAGraphCache { "%"); } count++; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto it = engine_cache.find(key); if (it == engine_cache.end()) { return nullptr; } +<<<<<<< HEAD hits++; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return &(it->second); } @@ -381,6 +480,7 @@ struct MHAGraphCache { // @eqy: use thread local caches as cuDNN Execution Plans are not guaranteed to // be thread safe across all engines see Limitations in // https://docs.nvidia.com/deeplearning/cudnn/backend/latest/release-notes.html +<<<<<<< HEAD // We also leak the caches to workaround potential teardown race issues. auto& getMHAGraphCache_() { @@ -420,6 +520,13 @@ enum UIDS { RAG_LSE_OFF }; +======= +thread_local MHAGraphCache mhagraphcache; +thread_local MHAGraphCache + mhagraphbackwardcache; + +namespace { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // analogous to the same function in Descriptors.h for cuDNN Convolutions... auto fixSizeOneDimStrideSDPA( const IntArrayRef sizes, @@ -437,10 +544,16 @@ auto fixSizeOneDimStrideSDPA( } return strides; } +<<<<<<< HEAD } // namespace auto build_graph( +======= +} // namespace + +auto build_graph_and_tensors( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int64_t b, int64_t h, int64_t s_q, @@ -473,16 +586,39 @@ auto build_graph( .set_compute_data_type(fe::DataType_t::FLOAT); auto attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() +<<<<<<< HEAD .set_uid(SCALE) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .set_name("Attn_scale") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_is_pass_by_value(true) .set_data_type(fe::DataType_t::FLOAT)); +<<<<<<< HEAD +======= + auto seed = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Seed") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type( + dropoutseed.dtype() == kInt + ? fe::DataType_t::INT32 + : fe::DataType_t::INT64)); + auto offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Offset") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type( + dropoutoffset.dtype() == kInt + ? fe::DataType_t::INT32 + : fe::DataType_t::INT64)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto scaled_dot_product_flash_attention_options = fe::graph::SDPA_attributes() .set_name("CUDNN_SDPA") .set_is_inference(return_softmaxstats == false) +<<<<<<< HEAD // TODO(eqy): switch to this API once cuDNN FE is upgraded // .set_generate_stats(return_softmaxstats) .set_causal_mask(is_causal) @@ -534,17 +670,41 @@ auto build_graph( fe::graph::Tensor_attributes().set_uid(K).set_name("K")); auto V_ = mha_graph->tensor( fe::graph::Tensor_attributes().set_uid(V).set_name("V")); +======= + .set_causal_mask(is_causal) + .set_attn_scale(attn_scale) + .set_dropout(dropout_probability, seed, offset); + auto Q = mha_graph->tensor( + fe::graph::Tensor_attributes() + .set_name("Q") + .set_dim(q.sizes().vec()) + .set_stride(fixSizeOneDimStrideSDPA(q.sizes(), q.strides().vec()))); + auto K = mha_graph->tensor( + fe::graph::Tensor_attributes() + .set_name("K") + .set_dim(k.sizes().vec()) + .set_stride(fixSizeOneDimStrideSDPA(k.sizes(), k.strides().vec()))); + auto V = mha_graph->tensor( + fe::graph::Tensor_attributes() + .set_name("V") + .set_dim(v.sizes().vec()) + .set_stride(fixSizeOneDimStrideSDPA(v.sizes(), v.strides().vec()))); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::optional> bias; if (attn_bias.has_value()) { bias = mha_graph->tensor(fe::graph::Tensor_attributes() +<<<<<<< HEAD .set_uid(BIAS) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .set_name("bias") .set_dim(attn_bias.value().sizes().vec()) .set_stride(attn_bias.value().strides().vec())); scaled_dot_product_flash_attention_options.set_bias(bias.value()); } +<<<<<<< HEAD auto [O_, Stats] = mha_graph->sdpa(Q_, K_, V_, scaled_dot_product_flash_attention_options); O_->set_uid(O).set_output(true); @@ -631,6 +791,14 @@ auto build_graph( if (Stats) { Stats->set_dim(softmaxstats.sizes().vec()); } +======= + auto [O, Stats] = + mha_graph->sdpa(Q, K, V, scaled_dot_product_flash_attention_options); + O->set_output(true).set_dim(o.sizes().vec()).set_stride(o.strides().vec()); + + if (Stats) { + Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } AT_CUDNN_FRONTEND_CHECK(mha_graph->validate()); @@ -640,10 +808,27 @@ auto build_graph( AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle)); AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle)); +<<<<<<< HEAD return mha_graph; } auto build_graph_nestedtensor( +======= + return std::make_tuple( + std::move(mha_graph), + std::move(Q), + std::move(K), + std::move(V), + std::move(bias), + std::move(attn_scale), + std::move(seed), + std::move(offset), + std::move(O), + std::move(Stats)); +} + +auto build_graph_and_tensors_nestedtensor( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int64_t b, int64_t h_q, int64_t h_k, @@ -680,12 +865,16 @@ auto build_graph_nestedtensor( .set_compute_data_type(fe::DataType_t::FLOAT); auto attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() +<<<<<<< HEAD .set_uid(SCALE) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .set_name("Attn_scale") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_is_pass_by_value(true) .set_data_type(fe::DataType_t::FLOAT)); +<<<<<<< HEAD auto SEQ_LEN_Q_ = mha_graph->tensor(fe::graph::Tensor_attributes() .set_uid(SEQ_LEN_Q) @@ -696,6 +885,25 @@ auto build_graph_nestedtensor( auto SEQ_LEN_KV_ = mha_graph->tensor(fe::graph::Tensor_attributes() .set_uid(SEQ_LEN_KV) +======= + auto seed = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Seed") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Offset") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto SEQ_LEN_Q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Seq_q") + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto SEQ_LEN_KV = + mha_graph->tensor(fe::graph::Tensor_attributes() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .set_name("Seq_kv") .set_dim({b, 1, 1, 1}) .set_stride({1, 1, 1, 1}) @@ -705,6 +913,7 @@ auto build_graph_nestedtensor( fe::graph::SDPA_attributes() .set_name("CUDNN_SDPA_NESTEDTENSOR") .set_is_inference(return_softmaxstats == false) +<<<<<<< HEAD // TODO(eqy): switch to this API once cuDNN FE is upgraded // .set_generate_stats(return_softmaxstats) .set_causal_mask(is_causal) @@ -734,10 +943,19 @@ auto build_graph_nestedtensor( scaled_dot_product_flash_attention_options.set_dropout( dropout_probability, seed, offset); } +======= + .set_causal_mask(is_causal) + .set_attn_scale(attn_scale) + .set_dropout(dropout_probability, seed, offset) + .set_seq_len_q(SEQ_LEN_Q) + .set_seq_len_kv(SEQ_LEN_KV) + .set_padding_mask(true); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // We hardcode BSHD to cuDNN even though the underlying layout is THD auto q_strides = q.strides(); auto k_strides = k.strides(); auto v_strides = v.strides(); +<<<<<<< HEAD // NB: cuDNN API shape is transposed: we pass it nominally as HTD constexpr int strideidx0 = 1; constexpr int strideidx1 = 0; @@ -769,6 +987,35 @@ auto build_graph_nestedtensor( v_strides[strideidx0], v_strides[strideidx1], v_strides[strideidx2]})); +======= + constexpr int strideidx0 = 1; + constexpr int strideidx1 = 0; + constexpr int strideidx2 = 2; + auto Q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q") + .set_dim({b, h_q, s_q, d_qk}) + .set_stride( + {INT_MAX, + q_strides[strideidx0], + q_strides[strideidx1], + q_strides[strideidx2]})); + auto K = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("K") + .set_dim({b, h_k, s_kv, d_qk}) + .set_stride( + {INT_MAX, + k_strides[strideidx0], + k_strides[strideidx1], + k_strides[strideidx2]})); + auto V = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("V") + .set_dim({b, h_v, s_kv, d_v}) + .set_stride( + {INT_MAX, + v_strides[strideidx0], + v_strides[strideidx1], + v_strides[strideidx2]})); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::optional> bias; if (attn_bias.has_value()) { TORCH_CHECK( @@ -776,12 +1023,16 @@ auto build_graph_nestedtensor( "attn_bias not yet supportd with cuDNN Attention and NestedTensor"); bias = mha_graph->tensor(fe::graph::Tensor_attributes() +<<<<<<< HEAD .set_uid(BIAS) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .set_name("bias") .set_dim(attn_bias.value().sizes().vec()) .set_stride(attn_bias.value().strides().vec())); scaled_dot_product_flash_attention_options.set_bias(bias.value()); } +<<<<<<< HEAD auto RAG_Q_OFF_ = mha_graph->tensor(fe::graph::Tensor_attributes() .set_uid(RAG_Q_OFF) @@ -818,6 +1069,41 @@ auto build_graph_nestedtensor( auto o_strides = o.strides(); O_->set_output(true) .set_uid(O) +======= + auto RAG_Q_OFF = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("cum_seq_q") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto RAG_K_OFF = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("cum_seq_k") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto RAG_V_OFF = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("cum_seq_v") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto RAG_O_OFF = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("cum_seq_o") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + // auto RAG_STATS_OFF = mha_graph->tensor(fe::graph::Tensor_attributes() + // .set_name("cum_seq_stats") + // .set_dim({b + 1, 1, 1, 1}) + // .set_stride({1, 1, 1, 1}) + // .set_data_type(fe::DataType_t::INT32)); + auto RAG_STATS_OFF = nullptr; + Q->set_ragged_offset(RAG_Q_OFF); + K->set_ragged_offset(RAG_K_OFF); + V->set_ragged_offset(RAG_V_OFF); + auto [O, Stats] = + mha_graph->sdpa(Q, K, V, scaled_dot_product_flash_attention_options); + auto o_strides = o.strides(); + O->set_output(true) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .set_dim({b, h_q, s_q, d_v}) .set_stride( {INT_MAX, @@ -825,6 +1111,7 @@ auto build_graph_nestedtensor( o_strides[strideidx1], o_strides[strideidx2]}); +<<<<<<< HEAD O_->set_ragged_offset(RAG_O_OFF_); if (Stats) { auto RAG_STATS_OFF = @@ -839,6 +1126,18 @@ auto build_graph_nestedtensor( .set_data_type(fe::DataType_t::FLOAT) .set_dim({b, h_q, s_q, 1}) .set_stride({h_q * s_q, 1, h_q, 1}); +======= + O->set_ragged_offset(RAG_O_OFF); + if (Stats) { + TORCH_CHECK( + false, + "cuDNN SDPA Nested Tensor does not yet handle backwards/logsumexp computation"); + // TODO(eqy): fix when stats (backward) support is added + Stats->set_output(true) + .set_data_type(fe::DataType_t::FLOAT) + .set_dim({b, h_q, s_q, 1}) + .set_stride({h_q * s_q * d_v, d_v, s_q * d_v, 1}); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Stats->set_ragged_offset(RAG_STATS_OFF); } AT_CUDNN_FRONTEND_CHECK(mha_graph->validate()); @@ -847,10 +1146,34 @@ auto build_graph_nestedtensor( mha_graph->create_execution_plans({fe::HeurMode_t::A})); AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle)); AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle)); +<<<<<<< HEAD return mha_graph; } auto build_graph_backward( +======= + return std::make_tuple( + std::move(mha_graph), + std::move(Q), + std::move(K), + std::move(V), + std::move(bias), + std::move(attn_scale), + std::move(seed), + std::move(offset), + std::move(O), + std::move(Stats), + std::move(RAG_Q_OFF), + std::move(RAG_K_OFF), + std::move(RAG_V_OFF), + std::move(RAG_O_OFF), + std::move(RAG_STATS_OFF), + std::move(SEQ_LEN_Q), + std::move(SEQ_LEN_KV)); +} + +auto build_graph_and_tensors_backward( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int64_t b, int64_t h, int64_t s_q, @@ -886,7 +1209,10 @@ auto build_graph_backward( .set_compute_data_type(fe::DataType_t::FLOAT); auto attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() +<<<<<<< HEAD .set_uid(SCALE) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .set_name("Attn_scale") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) @@ -896,6 +1222,7 @@ auto build_graph_backward( .set_name("CUDNN_SDPA_BACKWARD") .set_causal_mask(is_causal) .set_attn_scale(attn_scale); +<<<<<<< HEAD if (use_ragged_in_dense(q, k, v, o, attn_bias.has_value())) { auto SEQ_LEN_Q_ = mha_graph->tensor(fe::graph::Tensor_attributes() @@ -922,16 +1249,34 @@ auto build_graph_backward( fe::graph::Tensor_attributes().set_uid(K).set_name("K")); auto V_ = mha_graph->tensor( fe::graph::Tensor_attributes().set_uid(V).set_name("V")); +======= + auto Q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q") + .set_dim(q.sizes().vec()) + .set_stride(q.strides().vec())); + auto K = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("K") + .set_dim(k.sizes().vec()) + .set_stride(k.strides().vec())); + auto V = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("V") + .set_dim(v.sizes().vec()) + .set_stride(v.strides().vec())); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::optional> bias; if (attn_bias.has_value()) { bias = mha_graph->tensor(fe::graph::Tensor_attributes() +<<<<<<< HEAD .set_uid(BIAS) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .set_name("bias") .set_dim(attn_bias.value().sizes().vec()) .set_stride(attn_bias.value().strides().vec())); sdpa_backward_options.set_bias(bias.value()); } +<<<<<<< HEAD if (dropout_probability != 0.0f) { auto seed = mha_graph->tensor(fe::graph::Tensor_attributes() .set_uid(SEED) @@ -1298,13 +1643,72 @@ auto build_graph_backward_nestedtensor( v_strides[strideidx1], v_strides[strideidx2]}); +======= + auto Seed = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Seed") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type( + dropoutseed.dtype() == kInt + ? fe::DataType_t::INT32 + : fe::DataType_t::INT64)); + + auto Offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Offset") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type( + dropoutoffset.dtype() == kInt + ? fe::DataType_t::INT32 + : fe::DataType_t::INT64)); + + auto O = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("O") + .set_dim(o.sizes().vec()) + .set_stride(o.strides().vec())); + auto STATS = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Stats") + .set_dim(softmaxstats.sizes().vec()) + .set_stride(softmaxstats.strides().vec()) + .set_data_type(fe::DataType_t::FLOAT)); + auto DO = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("DO") + .set_dim(dO.sizes().vec()) + .set_stride(dO.strides().vec())); + if (dropout_probability != 0.0f) { + sdpa_backward_options.set_dropout(dropout_probability, Seed, Offset); + } + auto [DQ, DK, DV] = + mha_graph->sdpa_backward(Q, K, V, O, DO, STATS, sdpa_backward_options); + DQ->set_output(true).set_dim(dQ.sizes().vec()).set_stride(dQ.strides().vec()); + DK->set_output(true).set_dim(dK.sizes().vec()).set_stride(dK.strides().vec()); + DV->set_output(true).set_dim(dV.sizes().vec()).set_stride(dV.strides().vec()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AT_CUDNN_FRONTEND_CHECK(mha_graph->validate()); AT_CUDNN_FRONTEND_CHECK(mha_graph->build_operation_graph(handle)); AT_CUDNN_FRONTEND_CHECK( mha_graph->create_execution_plans({fe::HeurMode_t::A})); AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle)); AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle)); +<<<<<<< HEAD return mha_graph; +======= + return std::make_tuple( + std::move(mha_graph), + std::move(Q), + std::move(K), + std::move(V), + std::move(bias), + std::move(attn_scale), + std::move(Seed), + std::move(Offset), + std::move(O), + std::move(DO), + std::move(STATS), + std::move(DQ), + std::move(DK), + std::move(DV)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } void run_cudnn_SDP_fprop( @@ -1326,6 +1730,7 @@ void run_cudnn_SDP_fprop( Tensor& o, Tensor& dropoutseed, Tensor& dropoutoffset) { +<<<<<<< HEAD // do nothing if we got 0-element tensors if (!q.numel() || !k.numel() || !v.numel()) { return; @@ -1367,6 +1772,8 @@ void run_cudnn_SDP_fprop( } } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const auto dprops = at::cuda::getCurrentDeviceProperties(); auto _dropoutseed = dropoutseed; auto _dropoutoffset = dropoutoffset; @@ -1377,10 +1784,28 @@ void run_cudnn_SDP_fprop( } cudnnHandle_t handle = getCudnnHandle(); +<<<<<<< HEAD // NB: The key initialization will round up sequence length, stride data etc. // if use_ragged_in_dense is enabled (to allow multiple sequence lenghths to // reuse the same cached value/graph) +======= + if (!o.defined()) { + // q is passed to us in BHSD dim order + alloc_with_matching_layout(q, o, {b, h, s_q, d_v}); + } + + if (return_softmaxstats && !softmaxstats.defined()) { + // TODO(eqy): verify that this is correct + softmaxstats = at::empty({b, h, s_q}, q.options().dtype(kFloat)); + } + + // do nothing if we got 0-element tensors + if (!q.numel() || !k.numel() || !v.numel()) { + return; + } + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto key = MHACacheKeyWrapper( b, h, @@ -1394,6 +1819,7 @@ void run_cudnn_SDP_fprop( attn_bias, dropout_probability, is_causal, +<<<<<<< HEAD return_softmaxstats, false); auto graph_ptr = getMHAGraphCache_().find(key); @@ -1402,6 +1828,15 @@ void run_cudnn_SDP_fprop( mha_graph = *graph_ptr; } else { mha_graph = build_graph( +======= + return_softmaxstats); + auto graph_and_tensors_ptr = mhagraphcache.find(key); + graph_and_tensors graph_and_tensors_values; + if (graph_and_tensors_ptr) { + graph_and_tensors_values = *graph_and_tensors_ptr; + } else { + graph_and_tensors_values = build_graph_and_tensors( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) b, h, s_q, @@ -1422,6 +1857,7 @@ void run_cudnn_SDP_fprop( _dropoutoffset, handle); } +<<<<<<< HEAD std::unordered_map variant_pack = { {Q, q.data_ptr()}, {K, k.data_ptr()}, @@ -1448,13 +1884,35 @@ void run_cudnn_SDP_fprop( if (return_softmaxstats) { variant_pack[RAG_LSE_OFF] = rag_off_lse.data_ptr(); } +======= + auto [mha_graph, Q, K, V, bias, attn_scale, seed, offset, O, Stats] = + graph_and_tensors_values; + std::unordered_map, void*> + variant_pack = { + {Q, q.data_ptr()}, + {K, k.data_ptr()}, + {V, v.data_ptr()}, + {attn_scale, &scaling_factor}, + {seed, _dropoutseed.data_ptr()}, + {offset, _dropoutoffset.data_ptr()}, + {O, o.data_ptr()}}; + if (return_softmaxstats) { + variant_pack[Stats] = softmaxstats.data_ptr(); + } + if (attn_bias.has_value()) { + variant_pack[bias.value()] = attn_bias.value().data_ptr(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } auto workspace_size = mha_graph->get_workspace_size(); auto workspace_ptr = c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size); TORCH_CHECK( mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good()); +<<<<<<< HEAD getMHAGraphCache_().update(key, mha_graph); +======= + mhagraphcache.update(key, graph_and_tensors_values); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } void run_cudnn_SDP_fprop_nestedtensor( @@ -1493,6 +1951,7 @@ void run_cudnn_SDP_fprop_nestedtensor( if (return_softmaxstats && !softmaxstats.defined()) { softmaxstats = at::empty({q.size(0), h_q, 1}, q.options().dtype(kFloat)); } +<<<<<<< HEAD auto key = MHACacheKeyWrapper( b, @@ -1565,6 +2024,74 @@ void run_cudnn_SDP_fprop_nestedtensor( if (dropout_probability != 0.0f) { variant_pack[SEED] = dropoutseed.data_ptr(); variant_pack[OFFSET] = dropoutoffset.data_ptr(); +======= + auto + [mha_graph, + Q, + K, + V, + bias, + attn_scale, + seed, + offset, + O, + Stats, + RAG_Q_OFF, + RAG_K_OFF, + RAG_V_OFF, + RAG_O_OFF, + RAG_STATS_OFF, + SEQ_LEN_Q, + SEQ_LEN_KV] = + build_graph_and_tensors_nestedtensor( + b, + h_q, + h_k, + h_v, + s_q, + s_kv, + d_qk, + d_v, + scaling_factor, + return_softmaxstats, + is_causal, + dropout_probability, + cum_seqlen_q, + cum_seqlen_kv, + q, + k, + v, + attn_bias, + softmaxstats, + o, + dropoutseed, + dropoutoffset, + handle); + auto seqlen_q = at::diff(cum_seqlen_q, 1, 0); + auto seqlen_kv = at::diff(cum_seqlen_kv, 1, 0); + auto rag_q_off = cum_seqlen_q.mul(h_q * d_qk); + auto rag_k_off = cum_seqlen_kv.mul(h_k * d_qk); + auto rag_v_off = cum_seqlen_kv.mul(h_v * d_v); + auto rag_stats_off = cum_seqlen_q.mul(h_q); + std::unordered_map, void*> + variant_pack = { + {Q, q.data_ptr()}, + {K, k.data_ptr()}, + {V, v.data_ptr()}, + {attn_scale, &scaling_factor}, + {seed, dropoutseed.data_ptr()}, + {offset, dropoutoffset.data_ptr()}, + {O, o.data_ptr()}, + {RAG_Q_OFF, rag_q_off.data_ptr()}, + {RAG_O_OFF, rag_q_off.data_ptr()}, + {RAG_K_OFF, rag_k_off.data_ptr()}, + {RAG_V_OFF, rag_v_off.data_ptr()}, + {SEQ_LEN_Q, seqlen_q.data_ptr()}, + {SEQ_LEN_KV, seqlen_kv.data_ptr()}}; + if (return_softmaxstats) { + variant_pack[Stats] = softmaxstats.data_ptr(); + variant_pack[RAG_STATS_OFF] = cum_seqlen_q.data_ptr(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } if (attn_bias.has_value()) { TORCH_CHECK("bias not supported with nestedtensor"); @@ -1603,9 +2130,12 @@ void run_cudnn_SDP_bprop( !softmaxstats.numel()) { return; } +<<<<<<< HEAD Tensor seqlen_q, seqlen_kv; Tensor rag_off_q, rag_off_k, rag_off_v, rag_off_o, rag_off_lse; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto dprops = at::cuda::getCurrentDeviceProperties(); auto _dropoutseed = dropoutseed; auto _dropoutoffset = dropoutoffset; @@ -1632,6 +2162,7 @@ void run_cudnn_SDP_bprop( "with matching strides..."); #else const auto innermost_dO_stride = dO.strides()[dO.strides().size() - 1]; +<<<<<<< HEAD if (innermost_dO_stride != 1 || use_ragged_in_dense(q, k, v, o, attn_bias.has_value())) { permute_to_matching_layout(o, dO_); @@ -1654,6 +2185,12 @@ void run_cudnn_SDP_bprop( rag_off_lse = cum_seqlen_q.mul(softmaxstats.stride(-2)); } +======= + if (innermost_dO_stride != 1) { + permute_to_matching_layout(o, dO_); + } +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cudnnHandle_t handle = getCudnnHandle(); auto key = MHACacheKeyWrapper( b, @@ -1668,6 +2205,7 @@ void run_cudnn_SDP_bprop( attn_bias, dropout_probability, is_causal, +<<<<<<< HEAD true, false); auto graph_backward_ptr = getMHAGraphBackwardCache_().find(key); @@ -1676,6 +2214,15 @@ void run_cudnn_SDP_bprop( mha_graph = *graph_backward_ptr; } else { mha_graph = build_graph_backward( +======= + true); + auto graph_and_tensors_backward_ptr = mhagraphbackwardcache.find(key); + graph_and_tensors_backward graph_and_tensors_backward_values; + if (graph_and_tensors_backward_ptr) { + graph_and_tensors_backward_values = *graph_and_tensors_backward_ptr; + } else { + graph_and_tensors_backward_values = build_graph_and_tensors_backward( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) b, h, s_q, @@ -1699,6 +2246,7 @@ void run_cudnn_SDP_bprop( _dropoutoffset, handle); } +<<<<<<< HEAD std::unordered_map variant_pack = { // inputs {Q, q.data_ptr()}, @@ -1872,12 +2420,54 @@ void run_cudnn_SDP_bprop_nestedtensor( !attn_bias.has_value(), "attn_bias not yet supportd with cuDNN Attention and NestedTensor"); +======= + auto + [mha_graph, + Q, + K, + V, + bias, + attn_scale, + Seed, + Offset, + O, + Do, + Stats, + Dq, + Dk, + Dv] = graph_and_tensors_backward_values; + std::unordered_map, void*> + variant_pack = {// inputs + {Q, q.data_ptr()}, + {K, k.data_ptr()}, + {V, v.data_ptr()}, + {O, o.data_ptr()}, + {Do, dO_.data_ptr()}, + {Stats, softmaxstats.data_ptr()}, + // outputs + {Dq, dQ.data_ptr()}, + {Dk, dK.data_ptr()}, + {Dv, dV.data_ptr()}, + // pass by value + {attn_scale, &scaling_factor}}; + if (dropout_probability != 0.0f) { + variant_pack[Seed] = _dropoutseed.data_ptr(); + variant_pack[Offset] = _dropoutoffset.data_ptr(); + } + if (attn_bias.has_value()) { + variant_pack[bias.value()] = attn_bias.value().data_ptr(); + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto workspace_size = mha_graph->get_workspace_size(); auto workspace_ptr = c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size); TORCH_CHECK(!workspace_size || workspace_ptr.get()); TORCH_CHECK( mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good()); +<<<<<<< HEAD +======= + mhagraphbackwardcache.update(key, graph_and_tensors_backward_values); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } // namespace native diff --git a/aten/src/ATen/native/cudnn/MHA.h b/aten/src/ATen/native/cudnn/MHA.h index 620abc1aa0a8e..bd983f0e65dea 100644 --- a/aten/src/ATen/native/cudnn/MHA.h +++ b/aten/src/ATen/native/cudnn/MHA.h @@ -70,6 +70,7 @@ void run_cudnn_SDP_bprop( const Tensor& dropoutseed, const Tensor& dropoutoffset); +<<<<<<< HEAD void run_cudnn_SDP_bprop_nestedtensor( int64_t b, int64_t h_q, @@ -97,4 +98,6 @@ void run_cudnn_SDP_bprop_nestedtensor( const Tensor& dropoutseed, const Tensor& dropoutoffset); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace at::native diff --git a/aten/src/ATen/native/cudnn/RNN.cpp b/aten/src/ATen/native/cudnn/RNN.cpp index 7d73ed5305108..25404b555b719 100644 --- a/aten/src/ATen/native/cudnn/RNN.cpp +++ b/aten/src/ATen/native/cudnn/RNN.cpp @@ -245,7 +245,11 @@ descriptor(cudnnHandle_t handle, DropoutDescriptor&& dropout_desc) const { datatype, input_datatype, algo, +<<<<<<< HEAD at::globalContext().allowTF32CuDNN("rnn")); +======= + at::globalContext().allowTF32CuDNN()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #else rnn_desc.set( handle, @@ -261,7 +265,11 @@ descriptor(cudnnHandle_t handle, DropoutDescriptor&& dropout_desc) const { datatype, input_datatype, algo, +<<<<<<< HEAD at::globalContext().allowTF32CuDNN("rnn")); +======= + at::globalContext().allowTF32CuDNN()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif return rnn_desc; } diff --git a/aten/src/ATen/native/hip/ck_gemm.h b/aten/src/ATen/native/hip/ck_gemm.h index 0d42cad56fcda..c3d757faebfac 100644 --- a/aten/src/ATen/native/hip/ck_gemm.h +++ b/aten/src/ATen/native/hip/ck_gemm.h @@ -10,7 +10,10 @@ inline void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(Dtype)) { static_assert(false&&sizeof(Dtype),"at::cuda::blas_gemm_internal_ck: not implemented"); } +<<<<<<< HEAD #if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template <> void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(double)); template <> @@ -19,7 +22,11 @@ template <> void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(at::Half)); template <> void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)); +<<<<<<< HEAD #endif +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace at::native diff --git a/aten/src/ATen/native/hip/ck_gemm_bfloat16.hip b/aten/src/ATen/native/hip/ck_gemm_bfloat16.hip index 7561cede386fb..482dab71d4be3 100644 --- a/aten/src/ATen/native/hip/ck_gemm_bfloat16.hip +++ b/aten/src/ATen/native/hip/ck_gemm_bfloat16.hip @@ -1,7 +1,12 @@ #undef __HIP_NO_HALF_CONVERSIONS__ +<<<<<<< HEAD #include #if defined(USE_ROCM_CK_GEMM) +======= + +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include @@ -782,4 +787,7 @@ void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) { } } // namespace at::native +<<<<<<< HEAD #endif // USE_ROCM_CK_GEMM +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/native/hip/ck_gemm_float.hip b/aten/src/ATen/native/hip/ck_gemm_float.hip index c4fea6088d3f0..358739d46b7e9 100644 --- a/aten/src/ATen/native/hip/ck_gemm_float.hip +++ b/aten/src/ATen/native/hip/ck_gemm_float.hip @@ -1,7 +1,10 @@ #undef __HIP_NO_HALF_CONVERSIONS__ #include +<<<<<<< HEAD #if defined(USE_ROCM_CK_GEMM) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include @@ -485,4 +488,7 @@ void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(double)) { } } // namespace at::native +<<<<<<< HEAD #endif // USE_ROCM_CK_GEMM +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/native/hip/ck_gemm_half.hip b/aten/src/ATen/native/hip/ck_gemm_half.hip index ebe044c389721..cf101a64cfdaf 100644 --- a/aten/src/ATen/native/hip/ck_gemm_half.hip +++ b/aten/src/ATen/native/hip/ck_gemm_half.hip @@ -1,7 +1,10 @@ #undef __HIP_NO_HALF_CONVERSIONS__ #include +<<<<<<< HEAD #if defined(USE_ROCM_CK_GEMM) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include @@ -607,4 +610,7 @@ void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(at::Half)) { } } // namespace at::native +<<<<<<< HEAD #endif // USE_ROCM_CK_GEMM +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/native/layer_norm.cpp b/aten/src/ATen/native/layer_norm.cpp index dadfe8aef5fd6..414f6c0aaaef1 100644 --- a/aten/src/ATen/native/layer_norm.cpp +++ b/aten/src/ATen/native/layer_norm.cpp @@ -261,11 +261,38 @@ std::tuple math_native_layer_norm( return outputs; } +<<<<<<< HEAD std::tuple rms_norm_composite( const Tensor& input, IntArrayRef normalized_shape, const std::optional& weight_opt /* optional */, std::optional eps) { +======= +Tensor rms_norm_symint( + const Tensor& input, + c10::SymIntArrayRef normalized_shape, + const std::optional& weight_opt /* optional */, + std::optional eps) { + // See [Note: hacky wrapper removal for optional tensor] + c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); + const Tensor& weight = *weight_maybe_owned; + _check_rms_norm_inputs_symint(input, normalized_shape, weight); + +#ifdef USE_MPS + if (input.device().type() == DeviceType::MPS && weight_opt.has_value()) { + const Tensor weight = weight_opt.value(); + const bool any_nested = input.is_nested() || weight.is_nested(); + const bool any_inputs_require_grad = input.requires_grad() || weight.requires_grad(); + const bool is_input_fp = isFloatingType(input.scalar_type()); + const bool is_weight_fp = isFloatingType(weight.scalar_type()); + + if (!(GradMode::is_enabled() && any_inputs_require_grad) && !any_nested && is_input_fp && is_weight_fp) { + auto eps_val = eps.value_or(std::numeric_limits::epsilon()); + return at::_fused_rms_norm(input.contiguous(), normalized_shape.size(), weight.contiguous(), eps_val); + } + } +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::vector dims_to_reduce; for (const auto i : c10::irange(normalized_shape.size())) { @@ -302,6 +329,7 @@ std::tuple rms_norm_composite( upcasted_result = upcasted_result.mul(weight_opt.value()); } +<<<<<<< HEAD // if nested do not make contiguous if(input.is_nested() || (weight_opt.has_value() && weight_opt.value().is_nested())){ return std::make_tuple(upcasted_result, rqrst_input); @@ -365,4 +393,12 @@ Tensor rms_norm_symint( return std::get<0>(at::_fused_rms_norm(input, IntArrayRef(reinterpret_cast(normalized_shape.data()), normalized_shape.size()), weight_opt, eps)); } +======= + return upcasted_result; + }); + + return result.type_as(input); + +} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace at::native diff --git a/aten/src/ATen/native/layer_norm.h b/aten/src/ATen/native/layer_norm.h index 0debe942dd0a6..fd0c82a7128ea 100644 --- a/aten/src/ATen/native/layer_norm.h +++ b/aten/src/ATen/native/layer_norm.h @@ -106,12 +106,15 @@ void layer_norm_cpu_out( int64_t M, int64_t N); +<<<<<<< HEAD std::tuple rms_norm_composite( const Tensor& input, IntArrayRef normalized_shape, const std::optional& weight_opt /* optional */, std::optional eps); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Tensor rms_norm_symint( const Tensor& input, c10::SymIntArrayRef normalized_shape, diff --git a/aten/src/ATen/native/metal/MetalTensorImpl.h b/aten/src/ATen/native/metal/MetalTensorImpl.h index 44152dd3c6d03..17248081e7063 100644 --- a/aten/src/ATen/native/metal/MetalTensorImpl.h +++ b/aten/src/ATen/native/metal/MetalTensorImpl.h @@ -35,7 +35,11 @@ struct TORCH_API MetalTensorImpl : public OpaqueTensorImpl { return c10::fromIntArrayRefKnownNonNegative(strides_); } +<<<<<<< HEAD c10::SymBool sym_is_contiguous_custom(c10::MemoryFormat memory_format) const override { +======= + bool is_contiguous_custom(c10::MemoryFormat memory_format) const override { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return true; } diff --git a/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp b/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp index af69dfc76e571..2dc4dc978606a 100644 --- a/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp +++ b/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp @@ -7,6 +7,10 @@ #include #else #include +<<<<<<< HEAD +======= +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #endif @@ -102,7 +106,11 @@ std::tuple miopen_batch_norm( mode = miopenBNSpatial; } +<<<<<<< HEAD auto output_t = at::empty(input->sizes(), input->options()); +======= + auto output_t = at::empty_like(input_t, input_t.options(), input_t.suggest_memory_format()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TensorArg output{ output_t, "output", 0 }; auto handle = getMiopenHandle(); @@ -170,6 +178,7 @@ std::tuple miopen_batch_norm_backward( const std::optional& save_var_t_opt, double epsilon) { // See [Note: hacky wrapper removal for optional tensor] +<<<<<<< HEAD const Tensor& running_mean = running_mean_opt.value_or(Tensor()); const Tensor& running_var = @@ -184,6 +193,17 @@ std::tuple miopen_batch_norm_backward( weight{ weight_t, "weight", 3 }, save_mean{ save_mean_t, "save_mean", 4 }, save_var{ save_var_t, "save_var", 5 }; +======= + const Tensor& save_mean_t = save_mean_t_opt.value_or(Tensor()); + const Tensor& save_var_t = save_var_t_opt.value_or(Tensor()); + + auto grad_output_contig = + grad_output_t.contiguous(input_t.suggest_memory_format()); + TensorArg input{input_t, "input", 1}, + grad_output{grad_output_contig, "grad_output", 2}, + weight{weight_t, "weight", 3}, save_mean{save_mean_t, "save_mean", 4}, + save_var{save_var_t, "save_var", 5}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CheckedFrom c = "miopen_batch_norm_backward"; checkAllDefined(c, {input, grad_output, weight, save_mean, save_var}); @@ -195,7 +215,13 @@ std::tuple miopen_batch_norm_backward( } checkAllSameType(c, {input, grad_output}); checkAllSameType(c, {weight, save_mean, save_var}); +<<<<<<< HEAD checkAllContiguous(c, {input, grad_output, save_mean, save_var}); +======= + checkAllContiguous(c, {save_mean, save_var}); + TORCH_CHECK(input->is_contiguous(input->suggest_memory_format())); + TORCH_CHECK(grad_output->is_contiguous(input->suggest_memory_format())); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) checkDimRange(c, input, 2, 6 /* exclusive */); checkSameSize(c, input, grad_output); auto num_features = input->size(1); @@ -210,7 +236,12 @@ std::tuple miopen_batch_norm_backward( mode = miopenBNSpatial; } +<<<<<<< HEAD auto grad_input_t = at::empty(input->sizes(), input->options()); +======= + auto grad_input_t = at::empty( + input->sizes(), input->options(), input->suggest_memory_format()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto grad_weight_t = at::empty(weight->sizes(), weight->options()); auto grad_bias_t = at::empty(weight->sizes(), weight->options()); diff --git a/aten/src/ATen/native/miopen/Conv_miopen.cpp b/aten/src/ATen/native/miopen/Conv_miopen.cpp index 41226680c4b58..400d005f035cb 100644 --- a/aten/src/ATen/native/miopen/Conv_miopen.cpp +++ b/aten/src/ATen/native/miopen/Conv_miopen.cpp @@ -855,7 +855,12 @@ void raw_miopen_convolution_forward_out_32bit( benchmark, deterministic); +<<<<<<< HEAD if (at::globalContext().immediateMiopen()) { +======= + if (deterministic && !benchmark) { + // immediate mode is triggered for the specific combination of benchmark=off deterministic=on +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uint64_t solution_id; Workspace workspace = chooseSolution(args, &solution_id); @@ -1144,7 +1149,12 @@ void raw_miopen_convolution_backward_input_out_32bit( benchmark, deterministic); +<<<<<<< HEAD if (at::globalContext().immediateMiopen()) { +======= + if (deterministic && !benchmark) { + // immediate mode is triggered for the specific combination of benchmark=off deterministic=on +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uint64_t solution_id; Workspace workspace = chooseSolution(args, &solution_id); @@ -1318,7 +1328,12 @@ void raw_miopen_convolution_backward_weight_out_32bit( benchmark, deterministic); +<<<<<<< HEAD if (at::globalContext().immediateMiopen()) { +======= + if (deterministic && !benchmark) { + // immediate mode is triggered for the specific combination of benchmark=off deterministic=on +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uint64_t solution_id; Workspace workspace = chooseSolution(args, &solution_id); diff --git a/aten/src/ATen/native/mkl/SpectralOps.cpp b/aten/src/ATen/native/mkl/SpectralOps.cpp index 4aa53c5e794b8..ffd144b8b16d4 100644 --- a/aten/src/ATen/native/mkl/SpectralOps.cpp +++ b/aten/src/ATen/native/mkl/SpectralOps.cpp @@ -337,7 +337,10 @@ Tensor _fft_c2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, #include #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include @@ -480,6 +483,7 @@ static Tensor& _exec_fft(Tensor& out, const Tensor& self, IntArrayRef out_sizes, const auto value_type = c10::toRealValueType(input.scalar_type()); out.resize_(batched_out_sizes, MemoryFormat::Contiguous); +<<<<<<< HEAD // fix mkl issue // https://github.com/pytorch/pytorch/issues/154477 #ifdef INTEL_MKL_VERSION @@ -493,6 +497,8 @@ static Tensor& _exec_fft(Tensor& out, const Tensor& self, IntArrayRef out_sizes, #endif #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto descriptor = _plan_mkl_fft( input.strides(), out.strides(), signal_size, input.is_complex(), out.is_complex(), normalization, forward, value_type); diff --git a/aten/src/ATen/native/mkldnn/Conv.cpp b/aten/src/ATen/native/mkldnn/Conv.cpp index 8222304e6d072..78335323a318f 100644 --- a/aten/src/ATen/native/mkldnn/Conv.cpp +++ b/aten/src/ATen/native/mkldnn/Conv.cpp @@ -155,6 +155,7 @@ static void check_shape_forward(const Tensor& input, // but weight/bias and grad_weight/grad_bias are always CPU tensor. // +<<<<<<< HEAD static bool mkldnn_conv_enabled_fpmath_mode_bf16(){ return at::globalContext().float32Precision("mkldnn", "conv") == "bf16" && mkldnn_bf16_device_check(); @@ -165,6 +166,8 @@ static bool mkldnn_conv_enabled_fpmath_mode_tf32(){ cpuinfo_has_x86_amx_fp16(); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static inline at::MemoryFormat mkldnn_convolution_memory_format(int64_t dims, bool is_channels_last) { auto memory_format = at::MemoryFormat::Contiguous; if (is_channels_last) { @@ -173,7 +176,11 @@ static inline at::MemoryFormat mkldnn_convolution_memory_format(int64_t dims, bo return memory_format; } +<<<<<<< HEAD static void _mkldnn_convolution_out( +======= +static void _mkldnn_convolution_out ( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const Tensor& input_t, const Tensor& weight_t, const Tensor& bias, @@ -271,6 +278,7 @@ static Tensor _mkldnn_convolution( output.resize_(output_sizes, memory_format); y = itensor_from_tensor(output); } +<<<<<<< HEAD if (mkldnn_conv_enabled_fpmath_mode_bf16() && input_t.scalar_type() == at::kFloat) { op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); @@ -279,6 +287,8 @@ static Tensor _mkldnn_convolution( input_t.scalar_type() == at::kFloat) { op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _mkldnn_convolution_out( input_t, weight_t, @@ -460,6 +470,7 @@ Tensor mkldnn_convolution_pointwise_binary( op_attr.set_post_ops(po); auto aprop_kind = ideep::prop_kind::forward_inference; +<<<<<<< HEAD if (mkldnn_conv_enabled_fpmath_mode_bf16() && input_t.scalar_type() ==at::kFloat){ op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); } @@ -467,6 +478,8 @@ Tensor mkldnn_convolution_pointwise_binary( op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (bias.defined()) { const ideep::tensor b = itensor_from_tensor(bias); ideep::convolution_forward::compute_binary( @@ -604,6 +617,7 @@ Tensor& mkldnn_convolution_pointwise_binary_( op_attr = ideep::attr_t::fuse_sum(); } auto aprop_kind = ideep::prop_kind::forward_inference; +<<<<<<< HEAD if (mkldnn_conv_enabled_fpmath_mode_bf16() && input_t.scalar_type() == at::kFloat) { op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); @@ -612,6 +626,8 @@ Tensor& mkldnn_convolution_pointwise_binary_( input_t.scalar_type() == at::kFloat) { op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _mkldnn_convolution_out( input_t, weight_t, @@ -730,6 +746,7 @@ Tensor _mkldnn_convolution_transpose( y = itensor_from_tensor(output); } +<<<<<<< HEAD if (mkldnn_conv_enabled_fpmath_mode_bf16() && input_t.scalar_type() ==at::kFloat){ op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); } @@ -737,6 +754,8 @@ Tensor _mkldnn_convolution_transpose( op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (bias.defined()) { const ideep::tensor b = itensor_from_tensor(bias, /*from_const_data_ptr*/true); ideep::convolution_transpose_forward::compute_v3( @@ -821,6 +840,7 @@ Tensor mkldnn_convolution_backward_input( grad_input.resize_(input_size, memory_format); grad_x = itensor_from_tensor(grad_input); } +<<<<<<< HEAD ideep::attr_t op_attr = ideep::attr_t(); if (mkldnn_conv_enabled_fpmath_mode_bf16() && weight.scalar_type() == at::kFloat) { @@ -830,6 +850,8 @@ Tensor mkldnn_convolution_backward_input( weight.scalar_type() == at::kFloat) { op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ideep::convolution_backward_data::compute_v2( grad_y, w, @@ -840,6 +862,7 @@ Tensor mkldnn_convolution_backward_input( padding.vec(), padding.vec(), groups, +<<<<<<< HEAD #if IDEEP_PREREQ(3, 4, 1, 3) is_channels_last, op_attr); @@ -856,6 +879,9 @@ Tensor mkldnn_convolution_backward_input( "Unexpected ideep version to support fpmath_mode_tf32, please update ideep version to align with pytorch main branch"); } #endif +======= + is_channels_last); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (grad_output.is_mkldnn()) { return MKLDNNTensor(grad_x, grad_output.options()); @@ -880,6 +906,7 @@ std::tuple mkldnn_convolution_backward_weights( const ideep::tensor x = itensor_from_tensor(input, /*from_const_data_ptr*/true); ideep::tensor grad_w, grad_b; +<<<<<<< HEAD ideep::attr_t op_attr = ideep::attr_t(); if (mkldnn_conv_enabled_fpmath_mode_bf16() && input.scalar_type() == at::kFloat) { @@ -889,6 +916,8 @@ std::tuple mkldnn_convolution_backward_weights( input.scalar_type() == at::kFloat) { op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (bias_defined) { ideep::convolution_backward_weights::compute_v2( x, @@ -901,8 +930,12 @@ std::tuple mkldnn_convolution_backward_weights( padding.vec(), padding.vec(), groups, +<<<<<<< HEAD is_channels_last, op_attr); +======= + is_channels_last); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else { ideep::convolution_backward_weights::compute_v2( x, @@ -914,8 +947,12 @@ std::tuple mkldnn_convolution_backward_weights( padding.vec(), padding.vec(), groups, +<<<<<<< HEAD is_channels_last, op_attr); +======= + is_channels_last); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } if (!is_channels_last) { @@ -1037,6 +1074,7 @@ Tensor mkldnn_convolution_transpose_backward_input( grad_input.resize_(input_size, memory_format); grad_x = itensor_from_tensor(grad_input); } +<<<<<<< HEAD ideep::attr_t op_attr = ideep::attr_t(); if (mkldnn_conv_enabled_fpmath_mode_bf16() && weight.scalar_type() == at::kFloat) { @@ -1046,6 +1084,8 @@ Tensor mkldnn_convolution_transpose_backward_input( weight.scalar_type() == at::kFloat) { op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ideep::convolution_transpose_backward_data::compute_v3( grad_y, w, @@ -1056,8 +1096,12 @@ Tensor mkldnn_convolution_transpose_backward_input( padding_r(padding, output_padding), dilation.vec(), groups, +<<<<<<< HEAD is_channels_last, op_attr); +======= + is_channels_last); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (grad_output.is_mkldnn()) { return MKLDNNTensor(grad_x, grad_output.options()); @@ -1083,6 +1127,7 @@ std::tuple mkldnn_convolution_transpose_backward_weights( auto x = itensor_from_tensor(input, /*from_const_data_ptr*/true); ideep::tensor grad_w, grad_b; +<<<<<<< HEAD ideep::attr_t op_attr = ideep::attr_t(); if (mkldnn_conv_enabled_fpmath_mode_bf16() && input.scalar_type() == at::kFloat) { @@ -1092,6 +1137,8 @@ std::tuple mkldnn_convolution_transpose_backward_weights( input.scalar_type() == at::kFloat) { op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (bias_defined) { ideep::convolution_transpose_backward_weights::compute_v3( x, @@ -1104,8 +1151,12 @@ std::tuple mkldnn_convolution_transpose_backward_weights( padding_r(padding, output_padding), dilation.vec(), groups, +<<<<<<< HEAD is_channels_last, op_attr); +======= + is_channels_last); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else { ideep::convolution_transpose_backward_weights::compute_v3( x, @@ -1117,8 +1168,12 @@ std::tuple mkldnn_convolution_transpose_backward_weights( padding_r(padding, output_padding), dilation.vec(), groups, +<<<<<<< HEAD is_channels_last, op_attr); +======= + is_channels_last); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } if (!is_channels_last) { diff --git a/aten/src/ATen/native/mkldnn/Linear.cpp b/aten/src/ATen/native/mkldnn/Linear.cpp index 8f0b91b3e3f7e..55bbbf8749020 100644 --- a/aten/src/ATen/native/mkldnn/Linear.cpp +++ b/aten/src/ATen/native/mkldnn/Linear.cpp @@ -68,6 +68,7 @@ mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2, namespace at::native { +<<<<<<< HEAD static bool use_mkldnn_bf32_linear() { return at::globalContext().float32Precision("mkldnn", "matmul") == "bf16" && mkldnn_bf16_device_check(); @@ -78,6 +79,8 @@ static bool use_mkldnn_tf32_linear() { cpuinfo_has_x86_amx_fp16(); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Tensor mkldnn_linear( const Tensor& self, const Tensor& weight_t, const std::optional& bias_opt) { @@ -261,12 +264,16 @@ Tensor mkldnn_linear_pointwise( it != fusion_unary_attr_map().end(), "Fusion behavior undefined."); op_attr = it->second(scalars, algorithm); } +<<<<<<< HEAD if (use_mkldnn_bf32_linear() && input_t.scalar_type() == at::kFloat){ op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); } if (use_mkldnn_tf32_linear() && input_t.scalar_type() == at::kFloat){ op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); } +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (mkldnn_bias.has_value()) { ideep::inner_product_forward::compute( mkldnn_input, @@ -356,6 +363,7 @@ Tensor mkldnn_linear_pointwise_binary( auto op_attr = ideep::attr_t::fuse_binary(it_binary->second, other_desc); auto aprop_kind = ideep::prop_kind::forward_inference; +<<<<<<< HEAD if (use_mkldnn_bf32_linear() && input_t.scalar_type() == at::kFloat){ op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); } @@ -364,6 +372,8 @@ Tensor mkldnn_linear_pointwise_binary( op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (mkldnn_bias.has_value()) { ideep::inner_product_forward::compute_binary( mkldnn_input, diff --git a/aten/src/ATen/native/mkldnn/Matmul.cpp b/aten/src/ATen/native/mkldnn/Matmul.cpp index 44c06a74a2228..b317e14d4bc4a 100644 --- a/aten/src/ATen/native/mkldnn/Matmul.cpp +++ b/aten/src/ATen/native/mkldnn/Matmul.cpp @@ -1,7 +1,13 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS +<<<<<<< HEAD #include #include #include +======= +#include +#include +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #if !AT_MKLDNN_ENABLED() @@ -53,7 +59,11 @@ bool mkldnn_fp16_gemm( c10::Half *c, int64_t ldc) { return false; } +<<<<<<< HEAD bool mkldnn_reduced_f32_gemm( +======= +bool mkldnn_bf32_gemm( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TransposeType transa, TransposeType transb, int64_t m, int64_t n, int64_t k, float alpha, @@ -85,6 +95,7 @@ void mkldnn_matmul_i8i8i32( TORCH_INTERNAL_ASSERT(false, __func__, ": ATen not compiled with MKLDNN support"); } +<<<<<<< HEAD bool use_mkldnn_tf32_matmul( const Tensor& mat1, const Tensor& mat2, @@ -92,6 +103,8 @@ bool use_mkldnn_tf32_matmul( return false; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace at::native @@ -111,11 +124,15 @@ static bool use_mkldnn_fp16_matmul() { } static bool use_mkldnn_bf32_matmul() { +<<<<<<< HEAD return use_mkldnn_bf16_matmul() && at::globalContext().float32Precision("mkldnn", "matmul") == "bf16"; } static bool use_mkldnn_tf32_matmul() { return cpuinfo_has_x86_amx_fp16() && at::globalContext().float32Precision("mkldnn", "matmul") == "tf32"; +======= + return use_mkldnn_bf16_matmul() && at::globalContext().float32MatmulPrecision() == at::Float32MatmulPrecision::MEDIUM; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // returns an ideep::tensor @@ -155,8 +172,12 @@ mkldnn_gemm( bool bf16_usable = std::is_same_v && use_mkldnn_bf16_matmul(); bool fp16_usable = std::is_same_v && use_mkldnn_fp16_matmul(); bool bf32_usable = std::is_same_v && use_mkldnn_bf32_matmul(); +<<<<<<< HEAD bool tf32_usable = std::is_same_v && use_mkldnn_tf32_matmul(); if ( !(bf16_usable || fp16_usable || bf32_usable || tf32_usable) || +======= + if ( !(bf16_usable || fp16_usable || bf32_usable) || +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) (m * n * k <= 16 * 16 * 16) || (alpha == 0.0f)) { return false; } @@ -167,7 +188,10 @@ mkldnn_gemm( op_attr = ideep::attr_t::fuse_sum(); } if (bf32_usable) op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); // bf32 path +<<<<<<< HEAD if (tf32_usable) op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); // tf32 path +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // NOTE: View as c-contiguous to avoid extra reordering in mkldnn // Use identity: C = AB <=> C^T = B^T A^T @@ -294,7 +318,11 @@ bool mkldnn_fp16_gemm( return mkldnn_gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } +<<<<<<< HEAD bool mkldnn_reduced_f32_gemm( +======= +bool mkldnn_bf32_gemm( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TransposeType transa, TransposeType transb, int64_t m, int64_t n, int64_t k, float alpha, @@ -352,7 +380,10 @@ void mkldnn_matmul( auto mat2_unsqueezed = mat2.dim() == 1 ? mat2.unsqueeze(1) : mat2; auto result_unsqueezed = result.dim() == 1 ? result.unsqueeze(1) : result; bool bf32_usable = mat1.scalar_type() == at::kFloat && use_mkldnn_bf32_matmul(); +<<<<<<< HEAD bool tf32_usable = mat1.scalar_type() == at::kFloat && use_mkldnn_tf32_matmul(); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ideep::attr_t op_attr; // "addmm", "addbmm" "baddbmm" in pytorch allow bias to be 2-D or 3-D tensor @@ -360,7 +391,10 @@ void mkldnn_matmul( // to address their differences, we use mkldnn post ops to perform a fused "add" after matrix multiplication is over if (beta != 0.0f) op_attr = ideep::attr_t::fuse_sum(); if (bf32_usable) op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); // bf32 path +<<<<<<< HEAD if (tf32_usable) op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); // tf32 path +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // If alpha = 0, dose not need actually do gemm computation if (alpha == 0) return; @@ -433,6 +467,7 @@ bool use_mkldnn_bf16_matmul( const Tensor& result) { #if defined(__aarch64__) if (mkldnn_bf16_device_check_arm()) { +<<<<<<< HEAD // onednn fastmath mode can leverage bf16 HW even for the fp32 input, e.g. // Arm Neoverse V1 so, don't restrict the mkldnn_matmul only for bf16 // inputs, allow it for float as well @@ -450,6 +485,28 @@ bool use_mkldnn_bf16_matmul( mat2.scalar_type() == kBFloat16 && (!result.defined() || result.scalar_type() == kBFloat16) && mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2)); +======= + //onednn fastmath mode can leverage bf16 HW even for the fp32 input, e.g. Arm Neoverse V1 + //so, don't restrict the mkldnn_matmul only for bf16 inputs, allow it for float as well + return ( + use_mkldnn_bf16_matmul() && + (mat1.scalar_type() == mat2.scalar_type()) && (!result.defined() || (mat1.scalar_type() == result.scalar_type())) && + ((mat1.scalar_type() == kFloat) || (mat1.scalar_type() == kBFloat16)) && + mat1.numel() != 0 && + mat2.numel() != 0 && + checksize(mat1, mat2)); + } else +#endif + { + return ( + use_mkldnn_bf16_matmul() && + mat1.scalar_type() == kBFloat16 && + mat2.scalar_type() == kBFloat16 && + (!result.defined() || result.scalar_type() == kBFloat16) && + mat1.numel() != 0 && + mat2.numel() != 0 && + checksize(mat1, mat2)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } @@ -457,17 +514,30 @@ bool use_mkldnn_fp16_matmul( const Tensor& mat1, const Tensor& mat2, const Tensor& result) { +<<<<<<< HEAD return ( use_mkldnn_fp16_matmul() && mat1.scalar_type() == kHalf && mat2.scalar_type() == kHalf && (!result.defined() || result.scalar_type() == kHalf) && mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2)); +======= + + return ( + use_mkldnn_fp16_matmul() && + mat1.scalar_type() == kHalf && + mat2.scalar_type() == kHalf && + (!result.defined() || result.scalar_type() == kHalf) && + mat1.numel() != 0 && + mat2.numel() != 0 && + checksize(mat1, mat2)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } bool use_mkldnn_bf32_matmul( const Tensor& mat1, const Tensor& mat2, const Tensor& result) { +<<<<<<< HEAD return ( use_mkldnn_bf32_matmul() && mat1.scalar_type() == kFloat && mat2.scalar_type() == kFloat && @@ -484,17 +554,32 @@ bool use_mkldnn_tf32_matmul( mat2.scalar_type() == kFloat && (!result.defined() || result.scalar_type() == kFloat) && mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2)); +======= + + return ( + use_mkldnn_bf32_matmul() && + mat1.scalar_type() == kFloat && + mat2.scalar_type() == kFloat && + (!result.defined() || result.scalar_type() == kFloat) && + mat1.numel() != 0 && + mat2.numel() != 0 && + checksize(mat1, mat2)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } bool use_mkldnn_matmul( const Tensor& mat1, const Tensor& mat2, const Tensor& result) { +<<<<<<< HEAD return ( use_mkldnn_bf16_matmul(mat1, mat2, result) || use_mkldnn_fp16_matmul(mat1, mat2, result) || use_mkldnn_bf32_matmul(mat1, mat2, result) || use_mkldnn_tf32_matmul(mat1, mat2, result)); +======= + return (use_mkldnn_bf16_matmul(mat1, mat2, result) || use_mkldnn_fp16_matmul(mat1, mat2, result) || use_mkldnn_bf32_matmul(mat1, mat2, result)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } static void _mkldnn_matmul_i8i8i32_with_primitive( diff --git a/aten/src/ATen/native/mkldnn/Matmul.h b/aten/src/ATen/native/mkldnn/Matmul.h index 80247497d58f0..5a3a1bfe22745 100644 --- a/aten/src/ATen/native/mkldnn/Matmul.h +++ b/aten/src/ATen/native/mkldnn/Matmul.h @@ -29,11 +29,14 @@ bool use_mkldnn_bf32_matmul( const Tensor& mat2, const Tensor& result_opt); +<<<<<<< HEAD bool use_mkldnn_tf32_matmul( const Tensor& mat1, const Tensor& mat2, const Tensor& result_opt); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Try running mkldnn optimized gemm, or returns false if naive gemm would be faster bool mkldnn_bf16_gemm( TransposeType transa, TransposeType transb, @@ -67,7 +70,11 @@ oneDNN implicit reduced precision arithmetic feature https://github.com/mgouicem/oneDNN/tree/mgouicem/rfcs/implicit_downconvert/rfcs/20210301-computation-datatype to allow implicitly cast data type from FP32 to BF16 in onednn compute primitives */ +<<<<<<< HEAD bool mkldnn_reduced_f32_gemm( +======= +bool mkldnn_bf32_gemm( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TransposeType transa, TransposeType transb, int64_t m, int64_t n, int64_t k, float alpha, diff --git a/aten/src/ATen/native/mkldnn/xpu/Attention.cpp b/aten/src/ATen/native/mkldnn/xpu/Attention.cpp index 873005b3dd2bc..531b918c6801e 100644 --- a/aten/src/ATen/native/mkldnn/xpu/Attention.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/Attention.cpp @@ -1,4 +1,7 @@ +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include @@ -50,7 +53,11 @@ bool check_no_grad(sdp::sdp_params const& params, bool debug) { return !any_inputs_require_grad || !gradmode_enabled; } +<<<<<<< HEAD bool can_use_overrideable_attention(sdp::sdp_params const& params, bool debug) { +======= +bool use_overrideable_xpu(sdp::sdp_params const& params, bool debug) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) constexpr auto supported_dtypes = c10::array_of( at::kFloat, at::kBFloat16, at::kHalf); // double is not supported @@ -74,6 +81,7 @@ bool can_use_overrideable_attention(sdp::sdp_params const& params, bool debug) { return sdp::check_tensor_dtype(params, supported_dtypes, debug); } +<<<<<<< HEAD bool can_use_flash_attention(sdp::sdp_params const& params, bool debug) { // Currently, XPU fallbacks flash attention to overrideable return can_use_overrideable_attention(params, debug); @@ -110,28 +118,49 @@ std::array priority_order( return at::globalContext().sDPPriorityOrder(); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sdp::SDPBackend select_sdp_backend_xpu(sdp::sdp_params const& kernel_params) { // This function defines the priority order of the different sdp backends // 1. Flash Attention // 2. Math fallback auto& ctx = at::globalContext(); // use overrideable linked to onednn as overrideable implementation +<<<<<<< HEAD if (!ctx.userEnabledMathSDP() && !ctx.userEnabledOverrideableSDP() && !ctx.userEnabledFlashSDP()) { +======= + if (!ctx.userEnabledMathSDP() && !ctx.userEnabledOverrideableSDP()) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return sdp::SDPBackend::error; } // Get ideal kernel ordering +<<<<<<< HEAD const auto ordering = priority_order(kernel_params); +======= + const std::array priority_order{ + sdp::SDPBackend::overrideable, + sdp::SDPBackend::math, + }; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Because TORCHCHECK checks if condition is true we negate debug so that // The statements will be printed when debug is true bool print_debug = false; +<<<<<<< HEAD for (auto& backend : ordering) { switch (backend) { case sdp::SDPBackend::overrideable: if (ctx.userEnabledOverrideableSDP() && can_use_overrideable_attention(kernel_params, print_debug)) { +======= + for (auto& backend : priority_order) { + switch (backend) { + case sdp::SDPBackend::overrideable: + if (ctx.userEnabledOverrideableSDP() && + use_overrideable_xpu(kernel_params, print_debug)) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return sdp::SDPBackend::overrideable; } break; @@ -140,6 +169,7 @@ sdp::SDPBackend select_sdp_backend_xpu(sdp::sdp_params const& kernel_params) { return sdp::SDPBackend::math; } break; +<<<<<<< HEAD case sdp::SDPBackend::flash_attention: if (ctx.userEnabledFlashSDP() && can_use_flash_attention(kernel_params, print_debug)) { @@ -160,17 +190,24 @@ sdp::SDPBackend select_sdp_backend_xpu(sdp::sdp_params const& kernel_params) { TORCH_CHECK(false, "Invalid backend"); } break; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) default: TORCH_CHECK(false, "Invalid backend"); } } // If we have gotten to this point then two things have happened: +<<<<<<< HEAD // 1. can_use_overrideable_attention did not satisfy the constraints to be ran +======= + // 1. use_overrideable_xpu did not satisfy the constraints to be ran +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // 2. The user has explicitly disabled the math kernel // We then re-run the kernel checks with debug enabled to print out the // reason why the kernel was not selected print_debug = true; +<<<<<<< HEAD TORCH_WARN("Flash attention kernel not used because:"); can_use_flash_attention(kernel_params, print_debug); TORCH_WARN("Overrideable attention kernel not used because:"); @@ -179,6 +216,10 @@ sdp::SDPBackend select_sdp_backend_xpu(sdp::sdp_params const& kernel_params) { can_use_cudnn_attention(kernel_params, print_debug); TORCH_WARN("Memory Efficient attention kernel not used because:"); can_use_mem_efficien_attention(kernel_params, print_debug); +======= + TORCH_WARN("OneDNN kernel not used because:"); + use_overrideable_xpu(kernel_params, print_debug); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_CHECK(!print_debug, "No available kernel. Aborting execution.") return sdp::SDPBackend::error; } @@ -202,7 +243,11 @@ int64_t _fused_sdp_choice_xpu( TORCH_CHECK( false, "No viable backend for scaled_dot_product_attention was found. ", +<<<<<<< HEAD "This is likely due to turning off both the math kernel and the overrideable kernels."); +======= + "This is likely due to turning off both the math kernel and the fused kernels."); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } return static_cast(backend); } @@ -260,7 +305,11 @@ _scaled_dot_product_fused_attention_overrideable_xpu( alloc_with_matching_layout(query, output, output_shape); at::Tensor logsumexp, debug_attn_mask; // not supported +<<<<<<< HEAD at::native::onednn::sdpa( +======= + at::native::onednn::gpu_float_sdpa( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) batch_size, seq_len_q, seq_len_kv, @@ -274,9 +323,13 @@ _scaled_dot_product_fused_attention_overrideable_xpu( attn_bias, is_causal, scale.has_value() ? scale.value() : (1.0 / std::sqrt(head_dim_qk)), +<<<<<<< HEAD output, false, logsumexp); +======= + output); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // rng not used auto philox_seed = at::empty({}, at::dtype(at::kLong)); diff --git a/aten/src/ATen/native/mkldnn/xpu/Blas.cpp b/aten/src/ATen/native/mkldnn/xpu/Blas.cpp index 6a66abc7b062f..374e118135d3d 100644 --- a/aten/src/ATen/native/mkldnn/xpu/Blas.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/Blas.cpp @@ -469,6 +469,7 @@ Tensor _weight_int4pack_mm_xpu( return C; } +<<<<<<< HEAD Tensor& _int_mm_out_xpu( const Tensor& self, @@ -559,4 +560,6 @@ Tensor _int_mm_xpu(const Tensor& self, const Tensor& mat2) { at::empty({self.size(0), mat2.size(1)}, self.options().dtype(at::kInt)); return _int_mm_out_xpu(self, mat2, result); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace at::native diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp b/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp index e840e21f4f7a1..a2aba3abb76f0 100644 --- a/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/detail/Attention.cpp @@ -13,9 +13,12 @@ using dims = logical_tensor::dims; using op = dnnl::graph::op; using partition = dnnl::graph::partition; +<<<<<<< HEAD constexpr logical_tensor::data_type sdpa_intermediate_dtype = logical_tensor::data_type::f32; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inline data_type to_logical_tensor_data_type(c10::ScalarType scalar_type) { return scalar_type == c10::ScalarType::Float ? data_type::f32 : scalar_type == c10::ScalarType::Half ? data_type::f16 @@ -23,8 +26,11 @@ inline data_type to_logical_tensor_data_type(c10::ScalarType scalar_type) { : data_type::undef; } +<<<<<<< HEAD namespace sdpa_forward { +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) struct SDPALogicalParams { enum class TensorID { query, @@ -33,8 +39,12 @@ struct SDPALogicalParams { neg_inf, attn_mask, value, +<<<<<<< HEAD attention, logsumexp, +======= + output, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) end, }; @@ -44,16 +54,24 @@ struct SDPALogicalParams { std::optional neg_inf; std::optional attn_mask; logical_tensor value{}; +<<<<<<< HEAD logical_tensor attention{}; std::optional logsumexp; +======= + logical_tensor output{}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SDPALogicalParams( const at::Tensor& query_, const at::Tensor& key_, const at::Tensor& value_, const std::optional& attn_mask_, +<<<<<<< HEAD const at::Tensor& attention_, const at::Tensor& logsumexp_, +======= + const at::Tensor& output_, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int batch_size, int seq_len_q, int seq_len_kv, @@ -61,26 +79,40 @@ struct SDPALogicalParams { int num_head_kv, int head_dim_qk, int head_dim_v, +<<<<<<< HEAD bool is_causal, bool compute_logsumexp) { +======= + bool is_causal) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const data_type dtype = to_logical_tensor_data_type(query_.scalar_type()); TORCH_INTERNAL_ASSERT( (dtype != data_type::undef), "Only FP16/BF16/FP32 datatypes are currently supported"); +<<<<<<< HEAD TORCH_INTERNAL_ASSERT( query_.scalar_type() == attention_.scalar_type(), "scaled_dot_product_attention_xpu: query and attention tensors should have the same data type."); const dims scalar_shape = {1}; +======= + const dims scalar_shape = {1}; + std::vector inputLogicalTensors; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) at::Tensor reshaped_query = query_; at::Tensor reshaped_key = key_; at::Tensor reshaped_value = value_; +<<<<<<< HEAD at::Tensor reshaped_attention = attention_; at::Tensor reshaped_logsumexp = compute_logsumexp ? logsumexp_.unsqueeze(-1) : logsumexp_; at::Tensor reshaped_attn_mask = attn_mask_.value_or(at::Tensor()); // handle broadcasted input tensors for OneDNN +======= + at::Tensor reshaped_output = output_; + at::Tensor reshaped_attn_mask = attn_mask_.value_or(at::Tensor()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (at::native::onednn::is_broadcast(reshaped_query)) { at::native::onednn::undo_broadcast(reshaped_query); } @@ -90,6 +122,12 @@ struct SDPALogicalParams { if (at::native::onednn::is_broadcast(reshaped_value)) { at::native::onednn::undo_broadcast(reshaped_value); } +<<<<<<< HEAD +======= + if (at::native::onednn::is_broadcast(reshaped_output)) { + at::native::onednn::undo_broadcast(reshaped_output); + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (attn_mask_.has_value() && at::native::onednn::is_broadcast(reshaped_attn_mask)) { at::native::onednn::undo_broadcast(reshaped_attn_mask); @@ -107,13 +145,18 @@ struct SDPALogicalParams { {batch_size, group_num, group_size, seq_len_q, head_dim_qk}); reshaped_key = key_.unsqueeze(2); reshaped_value = value_.unsqueeze(2); +<<<<<<< HEAD reshaped_attention = attention_.view( +======= + reshaped_output = output_.view( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {batch_size, group_num, group_size, seq_len_q, head_dim_v}); if (attn_mask_.has_value() && attn_mask_.value().dim() == 4) { reshaped_attn_mask = attn_mask_.value().unsqueeze(2); } } +<<<<<<< HEAD #define LOGIC_TENSOR_DESC(name, dtype) \ name = { \ static_cast(TensorID::name), \ @@ -123,6 +166,18 @@ struct SDPALogicalParams { LOGIC_TENSOR_DESC(query, dtype); LOGIC_TENSOR_DESC(key, dtype); +======= + query = { + static_cast(TensorID::query), + dtype, + reshaped_query.sizes().vec(), + reshaped_query.strides().vec()}; + key = { + static_cast(TensorID::key), + dtype, + reshaped_key.sizes().vec(), + reshaped_key.strides().vec()}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) scale = { static_cast(TensorID::scale), to_logical_tensor_data_type(at::toOpMathType(query_.scalar_type())), @@ -143,6 +198,7 @@ struct SDPALogicalParams { TORCH_INTERNAL_ASSERT( (mask_dtype != data_type::undef), "Only FP16/BF16/FP32 datatypes are currently supported for attn_mask"); +<<<<<<< HEAD LOGIC_TENSOR_DESC(attn_mask, mask_dtype); } LOGIC_TENSOR_DESC(value, dtype); @@ -156,6 +212,24 @@ struct SDPALogicalParams { LOGIC_TENSOR_DESC(logsumexp, sdpa_intermediate_dtype); } #undef LOGIC_TENSOR_DESC +======= + attn_mask = { + static_cast(TensorID::attn_mask), + mask_dtype, + reshaped_attn_mask.sizes().vec(), + reshaped_attn_mask.strides().vec()}; + } + value = { + static_cast(TensorID::value), + dtype, + reshaped_value.sizes().vec(), + reshaped_value.strides().vec()}; + output = { + static_cast(TensorID::output), + dtype, + reshaped_output.sizes().vec(), + reshaped_output.strides().vec()}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } std::vector get_input() const { std::vector input = {query, key, scale}; @@ -169,21 +243,32 @@ struct SDPALogicalParams { return input; } std::vector get_output() const { +<<<<<<< HEAD std::vector output; output.push_back(attention); if (logsumexp.has_value()) { output.push_back(logsumexp.value()); } return output; +======= + return {output}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } }; partition create_sdpa_graph_partition( bool is_causal, +<<<<<<< HEAD bool compute_logsumexp, data_type dtype, const SDPALogicalParams& params) { // graph building and partitioning +======= + data_type dtype, + const SDPALogicalParams& params) { + // graph building and partitioning + // currently, we assume that Q and K have same sequence length +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) size_t lt_id = static_cast(SDPALogicalParams::TensorID::end); size_t op_id = 0; @@ -193,7 +278,11 @@ partition create_sdpa_graph_partition( // Matrix Extensions (Intel(R) XMX) support, which means the // Q/K/V tensors have bf16 or f16 data type while the output of the first // MatMul, Scale, Mask, and the input of SoftMax are in f32 data type. +<<<<<<< HEAD logical_tensor matmul_qk_out{lt_id++, sdpa_intermediate_dtype}; +======= + logical_tensor matmul_qk_out{lt_id++, data_type::f32}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) op matmul_qk{ op_id++, op::kind::MatMul, @@ -202,7 +291,11 @@ partition create_sdpa_graph_partition( "matmul_qk"}; matmul_qk.set_attr(op::attr::transpose_b, true); +<<<<<<< HEAD logical_tensor scaled_qk_out{lt_id++, sdpa_intermediate_dtype}; +======= + logical_tensor scaled_qk_out{lt_id++, data_type::f32}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) op scale_mul{ op_id++, op::kind::Multiply, @@ -227,7 +320,11 @@ partition create_sdpa_graph_partition( if (params.attn_mask.has_value()) { TORCH_INTERNAL_ASSERT( !is_causal, "Additive mask cannot use with is_causal."); +<<<<<<< HEAD masked_qk_out = {lt_id++, sdpa_intermediate_dtype}; +======= + masked_qk_out = {lt_id++, data_type::f32}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mask_add = { op_id++, op::kind::Add, @@ -262,7 +359,11 @@ partition create_sdpa_graph_partition( {mask_gt_out.value()}, "mask_gt"}; +<<<<<<< HEAD masked_qk_out = {lt_id++, sdpa_intermediate_dtype}; +======= + masked_qk_out = {lt_id++, data_type::f32}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mask_select = { op_id++, op::kind::Select, @@ -283,15 +384,22 @@ partition create_sdpa_graph_partition( logical_tensor softmax_out{lt_id++, dtype}; softmax.add_input(masked_qk_out.value_or(scaled_qk_out)); softmax.add_output(softmax_out); +<<<<<<< HEAD if (compute_logsumexp) { softmax.add_output(params.logsumexp.value()); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) op matmul_v{ op_id++, op::kind::MatMul, {softmax_out, params.value}, +<<<<<<< HEAD {params.attention}, +======= + {params.output}, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "matmul_v"}; constexpr auto ekind = dnnl::engine::kind::gpu; @@ -320,15 +428,21 @@ partition create_sdpa_graph_partition( partition& find_or_create_graph_partition( bool is_causal, +<<<<<<< HEAD bool compute_logsumexp, const SDPALogicalParams& params) { thread_local PartitionCache cache; +======= + const SDPALogicalParams& params) { + thread_local static PartitionCache cache; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const data_type dtype = params.query.get_data_type(); // cache key creation // patternID is determined on the basis of the arguments provided std::bitset<32> patternID; if (dtype == data_type::f32) { +<<<<<<< HEAD patternID.set(static_cast(PartitionCache::BitType::Float32), 1); } if (dtype == data_type::bf16) { @@ -336,25 +450,45 @@ partition& find_or_create_graph_partition( } // sdp pattern patternID.set(static_cast(PartitionCache::BitType::SdpaPattern), 1); +======= + // bit 3 corresponds to float32 dtype + patternID.set(3, 1); + } + if (dtype == data_type::bf16) { + // bit 2 corresponds to fp16/bf16 dtype + patternID.set(2, 1); + } + // sdp pattern + patternID.set(4, 1); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Refer to comments in Utils.h. The first 8 bits are reserved int pos = 8; // attn_mask patternID.set(pos++, params.attn_mask.has_value()); patternID.set(pos++, is_causal); +<<<<<<< HEAD // compute_logsumexp patternID.set(pos++, compute_logsumexp); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto partition_ = cache.find_partition(patternID); if (!partition_.has_value()) { // partition cache no hit // graph building and partitioning +<<<<<<< HEAD partition sdp_partition = create_sdpa_graph_partition( is_causal, compute_logsumexp, dtype, params); +======= + partition sdp_partition = + create_sdpa_graph_partition(is_causal, dtype, params); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) partition_ = cache.insert_partition_cache(patternID, sdp_partition); } return *partition_; } +<<<<<<< HEAD } // namespace sdpa_forward namespace sdpa_backward { @@ -783,6 +917,12 @@ partition& find_or_create_backward_graph_partition( namespace at::native::onednn { void sdpa( +======= +} // namespace + +namespace at::native::onednn { +void gpu_float_sdpa( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int batch_size, int seq_len_q, int seq_len_kv, @@ -796,9 +936,13 @@ void sdpa( std::optional attn_mask, bool is_causal, float softmax_scale, +<<<<<<< HEAD const Tensor& attention, bool compute_logsumexp, const Tensor& logsumexp) { +======= + const Tensor& output) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto& eng = GpuEngineManager::Instance().get_engine(); auto& strm = GpuStreamManager::Instance().get_stream(); @@ -813,8 +957,13 @@ void sdpa( }; // OneDNN doesn't support fp32 ukernel for implicit causal mask, +<<<<<<< HEAD // and the reference implementation is worse than aten math + explicit causal // mask. Fall back to explicit causal mask until OneDNN v3.9 which has fp32 +======= + // and the reference implementation is worse than aten math + explict causal + // mask. Fall back to explict causal mask until OneDNN v3.9 which has fp32 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // ukernel for implicit causal mask. if (is_causal && query.dtype() == at::kFloat) { attn_mask = get_tril_mask(); @@ -824,6 +973,7 @@ void sdpa( std::vector l_inputs, l_outputs; std::optional compiled_partition; +<<<<<<< HEAD const sdpa_forward::SDPALogicalParams logical_params( query, key, @@ -845,6 +995,34 @@ void sdpa( l_inputs = std::move(logical_params.get_input()); l_outputs = std::move(logical_params.get_output()); compiled_partition = partition.compile(l_inputs, l_outputs, eng); +======= + auto get_compiled_partition = [&]() { + const SDPALogicalParams logical_params( + query, + key, + value, + attn_mask, + output, + batch_size, + seq_len_q, + seq_len_kv, + num_head_q, + num_head_kv, + head_dim_qk, + head_dim_v, + is_causal); + auto& partition_ = + find_or_create_graph_partition(is_causal, logical_params); + auto i = logical_params.get_input(); + auto o = logical_params.get_output(); + auto compiled_partition = partition_.compile(i, o, eng); + l_inputs = std::move(i); + l_outputs = std::move(o); + return compiled_partition; + }; + + compiled_partition = get_compiled_partition(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Tensor softmax_scale1 = at::full( {}, @@ -854,11 +1032,16 @@ void sdpa( if (is_causal) { neg_inf = at::full( {}, +<<<<<<< HEAD -std::numeric_limits::infinity(), +======= + -INFINITY, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) query.options().dtype(at::toOpMathType(query.scalar_type()))); } std::vector outputs = { +<<<<<<< HEAD {l_outputs[0], eng, attention.data_ptr()}, }; if (compute_logsumexp) { @@ -995,6 +1178,23 @@ void sdpa_backward( } #undef ADD_INPUT +======= + {l_outputs[0], eng, output.data_ptr()}, + }; + size_t i = 0; + std::vector inputs; + inputs.reserve(l_inputs.size()); + inputs.emplace_back(l_inputs[i++], eng, query.data_ptr()); + inputs.emplace_back(l_inputs[i++], eng, key.data_ptr()); + inputs.emplace_back(l_inputs[i++], eng, softmax_scale1.data_ptr()); + if (neg_inf.has_value()) { + inputs.emplace_back(l_inputs[i++], eng, neg_inf->data_ptr()); + } + if (attn_mask.has_value()) { + inputs.emplace_back(l_inputs[i++], eng, attn_mask->data_ptr()); + } + inputs.emplace_back(l_inputs[i++], eng, value.data_ptr()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) compiled_partition->execute(strm, inputs, outputs); } } // namespace at::native::onednn diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/Utils.h b/aten/src/ATen/native/mkldnn/xpu/detail/Utils.h index 52f89bc1395d7..113b695f64648 100644 --- a/aten/src/ATen/native/mkldnn/xpu/detail/Utils.h +++ b/aten/src/ATen/native/mkldnn/xpu/detail/Utils.h @@ -110,6 +110,7 @@ struct PartitionCache { // bit 1: is uint8 // bit 2: fp16(0) / bf16(1) // bit 3: is fp32 +<<<<<<< HEAD // bit 4: is sdpa pattern // bit 5: is sdpa backward pattern // bit 6-7: reserved for future use @@ -125,6 +126,13 @@ struct PartitionCache { SdpaBwdPattern = 5 }; +======= + // bit 4: is sdp pattern + // bit 5-7: N/A + // The rest of the bits depend upon the arguments provided + // However, down the line, we might have different bitsets for different + // patterns +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dnnl::graph::partition& insert_partition_cache( std::bitset<32>& patternID, dnnl::graph::partition& p) { diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h b/aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h index 6b2bf01e6d73d..45b828b3c5c86 100644 --- a/aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h +++ b/aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h @@ -164,7 +164,11 @@ void quantized_matmul( std::string_view unary_post_op_algorithm, bool m2_trnas); +<<<<<<< HEAD void sdpa( +======= +void gpu_float_sdpa( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int batch_size, int seq_len_q, int seq_len_kv, @@ -178,6 +182,7 @@ void sdpa( std::optional attn_mask, bool is_causal, float softmax_scale, +<<<<<<< HEAD const Tensor& attention, bool compute_logsumexp, const Tensor& logsumexp); @@ -202,4 +207,7 @@ void sdpa_backward( Tensor& grad_query, Tensor& grad_key, Tensor& grad_value); +======= + const Tensor& output); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace at::native::onednn diff --git a/aten/src/ATen/native/mkldnn/xpu/qconv.cpp b/aten/src/ATen/native/mkldnn/xpu/qconv.cpp index c014313a5b35d..1614e45ef0d24 100644 --- a/aten/src/ATen/native/mkldnn/xpu/qconv.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/qconv.cpp @@ -1,7 +1,10 @@ #include #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include @@ -9,7 +12,11 @@ using namespace at::native::onednn; namespace at::native::xpu { +<<<<<<< HEAD inline c10::ScalarType QConvoneDNNXPU::qconv_decide_out_dtype( +======= +static inline c10::ScalarType qconv_decide_out_dtype( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const at::Tensor& act, const std::optional output_dtype) { bool fp32_output = output_dtype.has_value() && (output_dtype == c10::kFloat); @@ -21,7 +28,11 @@ inline c10::ScalarType QConvoneDNNXPU::qconv_decide_out_dtype( return dst_dtype; } +<<<<<<< HEAD at::Tensor QConvoneDNNXPU::qconv_prepack_xpu( +======= +static at::Tensor qconv_prepack_xpu( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) at::Tensor weight, at::Tensor weight_scales, double input_scale, @@ -35,6 +46,7 @@ at::Tensor QConvoneDNNXPU::qconv_prepack_xpu( return weight; } +<<<<<<< HEAD at::Tensor QConvoneDNNXPU::run_pointwise( at::Tensor act, double act_scale, @@ -289,11 +301,228 @@ at::Tensor QConvoneDNNXPU::run_pointwise_binary_tensor( unary_scalars, unary_algorithm); } +======= +class QConvoneDNNXPU final { + public: + static at::Tensor run_pointwise( + at::Tensor act, + double act_scale, + int64_t act_zero_point, + at::Tensor weight, + at::Tensor weight_scales, + at::Tensor weight_zero_points, + std::optional bias, + torch::List stride, + torch::List padding, + torch::List dilation, + int64_t groups, + double inv_output_scale, + int64_t output_zero_point, + std::optional output_dtype, + std::string_view attr, + torch::List> scalars, + std::optional algorithm) { + if (act.dim() == 3 || act.dim() == 5) { + TORCH_CHECK( + attr == "none", + "quantized pointwise conv", + act.dim() - 2, + "d doesn't support unary_post_op fusion. Got unary_post_op:", + attr, + "."); + } else { + TORCH_CHECK( + attr == "none" || attr == "relu" || attr == "hardtanh" || + attr == "hardswish" || attr == "swish", + "We support quantized convolution without any post-ops or combinations for Quantized Conv + ReLU, Hardtanh, GELU, Swish, and Hardswish are supported. However, encountered unsupported post operation:", + attr, + "."); + } + + bool is_channels_last_suggested = use_channels_last_for_conv(act, weight); + auto mfmt = is_channels_last_suggested + ? get_cl_tag_by_ndim(act.ndimension()) + : at::MemoryFormat::Contiguous; + Tensor input_ = act.contiguous(mfmt); + Tensor weight_ = weight.contiguous(mfmt); + + auto dst_tz = conv_dst_size( + input_.ndimension(), + input_.sizes(), + weight_.sizes(), + padding.vec(), + padding.vec(), + stride.vec(), + dilation.vec()); + + auto dst_dtype = qconv_decide_out_dtype(act, output_dtype); + Tensor output = + at::empty(dst_tz, act.options().dtype(dst_dtype).memory_format(mfmt)); + + return quantized_convolution( + act, + act_scale, + act_zero_point, + weight, + weight_scales, + weight_zero_points, + bias, + stride, + padding, + dilation, + /*transposed*/ false, + groups, + output, + inv_output_scale, + output_zero_point, + /*accum*/ std::nullopt, + /*accum_scale*/ 0.0, + /*accum_zero_point*/ 0, + /*output_dtype*/ output_dtype, + /*binary_attr*/ std::nullopt, + /*binary_alpha*/ std::nullopt, + /*unary_attr*/ attr, + /*unary_scalars*/ scalars, + /*unary_algorithm*/ algorithm); + } + + static at::Tensor run_pointwise_tensor( + at::Tensor act, + at::Tensor act_scale, + at::Tensor act_zero_point, + at::Tensor weight, + at::Tensor weight_scales, + at::Tensor weight_zero_points, + std::optional bias, + torch::List stride, + torch::List padding, + torch::List dilation, + int64_t groups, + double output_scale, + int64_t output_zero_point, + std::optional output_dtype, + std::string_view attr, + torch::List> scalars, + std::optional algorithm) { + return run_pointwise( + act, + act_scale.item().toDouble(), + act_zero_point.item().toLong(), + weight, + weight_scales, + weight_zero_points, + bias, + stride, + padding, + dilation, + groups, + output_scale, + output_zero_point, + output_dtype, + /*unary_attr*/ attr, + /*unary_scalars*/ scalars, + /*unary_algorithm*/ algorithm); + } + + static at::Tensor run_pointwise_binary( + at::Tensor act, + double act_scale, + int64_t act_zero_point, + at::Tensor weight, + at::Tensor weight_scales, + at::Tensor weight_zero_points, + at::Tensor accum, + std::optional bias, + torch::List stride, + torch::List padding, + torch::List dilation, + int64_t groups, + double output_scale, + int64_t output_zero_point, + std::optional output_dtype, + double accum_scale, + int64_t accum_zero_point, + std::string_view binary_attr, + std::optional alpha, + std::optional unary_attr, + torch::List> unary_scalars, + std::optional unary_algorithm) { + TORCH_CHECK( + act.dim() == 4 && binary_attr == "sum" && + (!unary_attr.has_value() || + (unary_attr.has_value() && + (unary_attr.value() == "none" || unary_attr.value() == "relu"))), + "post_op sum or post_op sum_relu is supported for quantized pointwise conv2d. Got binary_post_op: ", + binary_attr, + " unary_post_op: ", + unary_attr.has_value() ? unary_attr.value() : "none", + ".") + + bool is_channels_last_suggested = use_channels_last_for_conv(act, weight); + auto mfmt = is_channels_last_suggested + ? get_cl_tag_by_ndim(act.ndimension()) + : at::MemoryFormat::Contiguous; + Tensor input_ = act.contiguous(mfmt); + Tensor weight_ = weight.contiguous(mfmt); + + auto dst_tz = conv_dst_size( + input_.ndimension(), + input_.sizes(), + weight_.sizes(), + padding.vec(), + padding.vec(), + stride.vec(), + dilation.vec()); + + auto dst_dtype = qconv_decide_out_dtype(act, output_dtype); + bool has_accum_postop_sum = binary_attr == "sum"; + Tensor output = has_accum_postop_sum + ? accum + : at::empty(dst_tz, act.options().dtype(dst_dtype).memory_format(mfmt)); + + output = quantized_convolution( + act, + act_scale, + act_zero_point, + weight, + weight_scales, + weight_zero_points, + bias, + stride, + padding, + dilation, + /*transposed*/ false, + groups, + output, + output_scale, + output_zero_point, + /*accum*/ accum, + /*accum_scale*/ accum_scale, + /*accum_zero_point*/ accum_zero_point, + /*output_dtype*/ output_dtype, + /*binary_attr*/ binary_attr, + /*binary_alpha*/ alpha, + /*unary_attr*/ unary_attr, + /*unary_scalars*/ unary_scalars, + /*unary_algorithm*/ unary_algorithm); + + if (!has_accum_postop_sum) { + return output; + } else { + return accum; + } + } +}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_LIBRARY_IMPL(onednn, XPU, m) { m.impl( TORCH_SELECTIVE_NAME("onednn::qconv_prepack"), +<<<<<<< HEAD TORCH_FN(QConvoneDNNXPU::qconv_prepack_xpu)); +======= + TORCH_FN(xpu::qconv_prepack_xpu)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) m.impl( TORCH_SELECTIVE_NAME("onednn::qconv1d_pointwise"), QConvoneDNNXPU::run_pointwise); @@ -312,9 +541,12 @@ TORCH_LIBRARY_IMPL(onednn, XPU, m) { m.impl( TORCH_SELECTIVE_NAME("onednn::qconv_pointwise.tensor"), QConvoneDNNXPU::run_pointwise_tensor); +<<<<<<< HEAD m.impl( TORCH_SELECTIVE_NAME("onednn::qconv2d_pointwise.binary_tensor"), QConvoneDNNXPU::run_pointwise_binary_tensor); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } // namespace at::native::xpu diff --git a/aten/src/ATen/native/mkldnn/xpu/qlinear.cpp b/aten/src/ATen/native/mkldnn/xpu/qlinear.cpp index e9584e8289eb2..2e9bc2019c0d0 100644 --- a/aten/src/ATen/native/mkldnn/xpu/qlinear.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/qlinear.cpp @@ -1,14 +1,21 @@ #include #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include using namespace at::native::onednn; namespace at::native::xpu { +<<<<<<< HEAD inline c10::ScalarType QLinearOnednnXPU::qlinear_decide_out_dtype( +======= +static inline c10::ScalarType qlinear_decide_out_dtype( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const at::Tensor& act, const std::optional output_dtype) { bool fp32_output = output_dtype.has_value() && (output_dtype == c10::kFloat); @@ -20,7 +27,11 @@ inline c10::ScalarType QLinearOnednnXPU::qlinear_decide_out_dtype( return dst_dtype; } +<<<<<<< HEAD Tensor QLinearOnednnXPU::q_linear_pointwise( +======= +static Tensor q_linear_pointwise( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Tensor act, double act_scale, int64_t act_zero_point, @@ -79,7 +90,11 @@ Tensor QLinearOnednnXPU::q_linear_pointwise( return qout; } +<<<<<<< HEAD Tensor QLinearOnednnXPU::q_linear_pointwise_tensor( +======= +static Tensor q_linear_pointwise_tensor( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Tensor act, Tensor act_scale, Tensor act_zero_point, @@ -138,7 +153,11 @@ Tensor QLinearOnednnXPU::q_linear_pointwise_tensor( return qout; } +<<<<<<< HEAD Tensor QLinearOnednnXPU::q_linear_pointwise_binary( +======= +static Tensor q_linear_pointwise_binary( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Tensor act, double act_scale, int64_t act_zero_point, @@ -209,7 +228,11 @@ Tensor QLinearOnednnXPU::q_linear_pointwise_binary( return dim == 3 ? qout.reshape({act.size(0), -1, N}) : qout; } +<<<<<<< HEAD Tensor QLinearOnednnXPU::q_linear_pointwise_binary_tensor( +======= +static Tensor q_linear_pointwise_binary_tensor( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Tensor act, Tensor act_scale, Tensor act_zero_point, @@ -249,7 +272,11 @@ Tensor QLinearOnednnXPU::q_linear_pointwise_binary_tensor( unary_post_op_algorithm); } +<<<<<<< HEAD Tensor QLinearOnednnXPU::q_linear_prepack_onednn( +======= +static at::Tensor q_linear_prepack_onednn( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) at::Tensor weight, std::optional> input_shape) { at::Tensor weight_transposed = weight.transpose(0, 1); @@ -259,6 +286,7 @@ Tensor QLinearOnednnXPU::q_linear_prepack_onednn( TORCH_LIBRARY_IMPL(onednn, XPU, m) { m.impl( TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise"), +<<<<<<< HEAD TORCH_FN(QLinearOnednnXPU::q_linear_pointwise)); m.impl( TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise.tensor"), @@ -272,6 +300,21 @@ TORCH_LIBRARY_IMPL(onednn, XPU, m) { m.impl( TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise.binary_tensor"), TORCH_FN(QLinearOnednnXPU::q_linear_pointwise_binary_tensor)); +======= + TORCH_FN(q_linear_pointwise)); + m.impl( + TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise.tensor"), + TORCH_FN(q_linear_pointwise_tensor)); + m.impl( + TORCH_SELECTIVE_NAME("onednn::qlinear_prepack"), + TORCH_FN(q_linear_prepack_onednn)); + m.impl( + TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise.binary"), + TORCH_FN(q_linear_pointwise_binary)); + m.impl( + TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise.binary_tensor"), + TORCH_FN(q_linear_pointwise_binary_tensor)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } // namespace at::native::xpu diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h index f9cd28ca06fa8..427aef2e52d11 100644 --- a/aten/src/ATen/native/mps/OperationUtils.h +++ b/aten/src/ATen/native/mps/OperationUtils.h @@ -88,8 +88,19 @@ std::string getArrayRefString(const IntArrayRef s); // use has_storage() on the returned tensor to determine if src actually is a view Tensor gatherViewTensor(const Tensor& src, Tensor& dst); Tensor& scatterViewTensor(const Tensor& src, Tensor& output); +<<<<<<< HEAD MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input); MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input); +======= +MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, + MPSGraphTensor* inputTensor, + const TensorBase& input, + bool includesInt64 = false); +MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, + MPSGraphTensor* inputTensor, + const TensorBase& input, + bool includesInt64 = false); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MPSNDArray* getStridedMPSNDArray(const TensorBase& src, MPSNDArray* srcNDArray); MPSNDArray* getMPSNDArray(const TensorBase& t, const IntArrayRef& sizes = {}, const IntArrayRef& strides = {}); @@ -139,6 +150,11 @@ MPSGraphTensorData* getMPSGraphTensorData(MPSGraph* mpsGraph, MPSStream* mpsStre MPSGraphTensorData* getMPSGraphTensorFromScalar(MPSStream* mpsStream, MPSScalar& scalar); MPSGraph* make_mps_graph(); +<<<<<<< HEAD +======= +void printTensorNDArray(const TensorBase& t); +MPSNDArray* ndArrayFromTensor(const TensorBase& tensor, MPSShape* shape, MPSDataType mpsType); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType); MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType, MPSShape* mpsShape); @@ -429,6 +445,17 @@ inline T* LookUpOrCreateCachedGraph(const std::string& key, std::function>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) /** * Returns distance from lowest to highest element offset in given tensor. */ @@ -604,6 +631,13 @@ inline void runMPSGraph(MPSStream* stream, MPSGraph* graph, NSDictionary* feeds, runMPSGraph(stream, graph, feeds, dictionaryFromPlaceholders(result)); } +<<<<<<< HEAD +======= +inline bool supportsComplex() { + return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS); +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // MPS yet to support double types, but starting from MacOS 14, supports bfloat16 inline bool supportedFloatingType(ScalarType dtype) { return dtype == kFloat || dtype == kHalf || dtype == kBFloat16; @@ -615,7 +649,11 @@ inline bool supportedFloatingType(const TensorBase& t) { inline bool supportedFloatingOrComplexType(ScalarType dtype) { if (dtype == kComplexFloat || dtype == kComplexHalf) { +<<<<<<< HEAD return true; +======= + return supportsComplex(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } return supportedFloatingType(dtype); } @@ -623,6 +661,14 @@ inline bool supportedFloatingOrComplexType(const TensorBase& t) { return supportedFloatingOrComplexType(t.scalar_type()); } +<<<<<<< HEAD +======= +inline void checkSupportsBFloat16() { + TORCH_CHECK_TYPE(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS), + "MPS bfloat16 type is supported on MacOS 14.0 or newer."); +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inline bool needsGather(const TensorBase& t) { static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS); return !is_macOS_15_0_or_newer && (!t.is_contiguous() || t.storage_offset()); diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm index bf3e94207e25b..531db3067eb5a 100644 --- a/aten/src/ATen/native/mps/OperationUtils.mm +++ b/aten/src/ATen/native/mps/OperationUtils.mm @@ -89,6 +89,13 @@ void runMPSGraph(MPSStream* mpsStream, MPSGraph* mpsGraph, NSDictionary* feeds, mpsStream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_ADAPTIVE); } +<<<<<<< HEAD +======= +static inline void checkSupportsComplex() { + TORCH_CHECK_TYPE(supportsComplex(), "MPS complex types are only supported on MacOS 14.0 or newer."); +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MPSDataType getMPSDataType(ScalarType scalar_type) { switch (scalar_type) { case ScalarType::Float: @@ -96,6 +103,10 @@ MPSDataType getMPSDataType(ScalarType scalar_type) { case ScalarType::Half: return MPSDataTypeFloat16; case ScalarType::BFloat16: +<<<<<<< HEAD +======= + checkSupportsBFloat16(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return MPSDataTypeBFloat16; case ScalarType::Int: return MPSDataTypeInt32; @@ -114,6 +125,7 @@ MPSDataType getMPSDataType(ScalarType scalar_type) { "Cannot convert a float64 Tensor to MPS as the MPS framework doesn't support float64. " "Please use float32 instead.") case ScalarType::ComplexHalf: +<<<<<<< HEAD return MPSDataTypeComplexFloat16; case ScalarType::ComplexFloat: return MPSDataTypeComplexFloat32; @@ -124,6 +136,13 @@ MPSDataType getMPSDataType(ScalarType scalar_type) { return MPSDataTypeUInt32; case ScalarType::UInt16: return MPSDataTypeUInt16; +======= + checkSupportsComplex(); + return MPSDataTypeComplexFloat16; + case ScalarType::ComplexFloat: + checkSupportsComplex(); + return MPSDataTypeComplexFloat32; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) default: TORCH_CHECK_TYPE( false, "Trying to convert ", scalar_type, " to the MPS backend but it does not have support for that dtype.") @@ -133,10 +152,23 @@ MPSDataType getMPSDataType(ScalarType scalar_type) { // #issue 104398441 sortWithTensor and argsortWithTensor has support of // Int32, Half and Float32 types. These utilities are to help cast to these // types. +<<<<<<< HEAD MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input) { MPSDataType dataType = getMPSDataType(input.scalar_type()); bool condition = (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16) && (dataType != MPSDataTypeInt64); +======= +MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, + MPSGraphTensor* inputTensor, + const TensorBase& input, + bool includesInt64) { + MPSDataType dataType = getMPSDataType(input.scalar_type()); + bool condition = + (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16); + if (includesInt64) { + condition = condition && (dataType != MPSDataTypeInt64); + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (condition) { dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32; return [mpsGraph castTensor:inputTensor toType:dataType name:@"castInputTensor"]; @@ -147,10 +179,23 @@ MPSDataType getMPSDataType(ScalarType scalar_type) { // #issue 104398441 sortWithTensor and argsortWithTensor has support of // Int32, Half and Float32 types. These utilities are to help cast from these // types. +<<<<<<< HEAD MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input) { MPSDataType dataType = getMPSDataType(input.scalar_type()); bool condition = (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16) && (dataType != MPSDataTypeInt64); +======= +MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, + MPSGraphTensor* inputTensor, + const TensorBase& input, + bool includesInt64) { + MPSDataType dataType = getMPSDataType(input.scalar_type()); + bool condition = + (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16); + if (includesInt64) { + condition = condition && (dataType != MPSDataTypeInt64); + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (condition) { inputTensor = [mpsGraph castTensor:inputTensor toType:dataType name:@"castInputTensor"]; } @@ -167,6 +212,10 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { case ScalarType::Half: return MPSDataTypeFloat16; case ScalarType::BFloat16: +<<<<<<< HEAD +======= + checkSupportsBFloat16(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return MPSDataTypeBFloat16; case ScalarType::Int: return MPSDataTypeInt32; @@ -181,11 +230,16 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { case ScalarType::Bool: return MPSDataTypeBool; case ScalarType::ComplexHalf: +<<<<<<< HEAD +======= + checkSupportsComplex(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return MPSDataTypeComplexFloat16; // This is an intentional fallthrough supporting ComplexDouble for Scalar // types as they are casted to Complex64 currently. case ScalarType::ComplexDouble: case ScalarType::ComplexFloat: +<<<<<<< HEAD return MPSDataTypeComplexFloat32; // Unsigned types case ScalarType::UInt64: @@ -194,6 +248,10 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { return MPSDataTypeUInt32; case ScalarType::UInt16: return MPSDataTypeUInt16; +======= + checkSupportsComplex(); + return MPSDataTypeComplexFloat32; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) default: TORCH_CHECK_TYPE( false, "Trying to convert ", scalar_type, " to the MPS backend but it does not have support for that dtype.") @@ -226,6 +284,7 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { return short_name ? "c16" : "ComplexFloat16"; case ScalarType::ComplexFloat: return short_name ? "c32" : "ComplexFloat32"; +<<<<<<< HEAD // Unsigned types case ScalarType::UInt64: return short_name ? "u64" : "UInt64"; @@ -233,6 +292,8 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { return short_name ? "u32" : "UInt32"; case ScalarType::UInt16: return short_name ? "u16" : "UInt16"; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) default: return "Undefined"; } @@ -245,6 +306,10 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { case ScalarType::Half: return "half"; case ScalarType::BFloat16: +<<<<<<< HEAD +======= + checkSupportsBFloat16(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return "bfloat"; case ScalarType::Int: return "int"; @@ -262,6 +327,7 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { return "half2"; case ScalarType::ComplexFloat: return "float2"; +<<<<<<< HEAD // Unsigned types case ScalarType::UInt64: return "ulong"; @@ -269,6 +335,8 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { return "uint"; case ScalarType::UInt16: return "ushort"; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) default: TORCH_CHECK(false, "Undefined type ", scalar_type); return "Undefined"; @@ -332,7 +400,11 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { if (exclude_shape) { fmt::format_to(buf_iterator, "-1"); } else { +<<<<<<< HEAD fmt::format_to(buf_iterator, "{}", getArrayRefString(tensor.sizes())); +======= + fmt::format_to(buf_iterator, getArrayRefString(tensor.sizes())); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } fmt::format_to(buf_iterator, "]"); @@ -382,6 +454,39 @@ Tensor getTensorView(const Tensor& t, MPSShape* shape) { return [NSArray arrayWithObjects:numbers.data() count:numbers.size()]; } +<<<<<<< HEAD +======= +void printTensorNDArray(const TensorBase& t) { + if (!t.is_mps()) + return; + if (t.numel() == 0) + return; + // Get shape and data type + auto selfShape = getMPSShape(t); + auto selfDType = getMPSDataType(t.scalar_type()); + + // Initialize data + id selfBuf = getMTLBufferStorage(t); + MPSGraphTensorData* tdata = [[[MPSGraphTensorData alloc] initWithMTLBuffer:selfBuf shape:selfShape + dataType:selfDType] autorelease]; + C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wobjc-method-access") + C10_CLANG_DIAGNOSTIC_IGNORE("-Wobjc-method-access") +#endif + [tdata printNDArray]; + C10_CLANG_DIAGNOSTIC_POP() +} + +MPSNDArray* ndArrayFromTensor(const TensorBase& tensor, MPSShape* shape, MPSDataType mpsType) { + id buffer = getMTLBufferStorage(tensor); + MPSGraphTensorData* tmpGraphTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer:buffer + shape:shape + dataType:mpsType] autorelease]; + + return [tmpGraphTensorData mpsndarray]; +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static std::vector getSortedStrides(const IntArrayRef& s) { std::vector idx(s.size()); iota(idx.begin(), idx.end(), 0); @@ -432,6 +537,7 @@ Tensor getTensorView(const Tensor& t, MPSShape* shape) { return result; } +<<<<<<< HEAD // Should be called before initWithBuffer to prevent hard crashes with // '[MPSNDArray initWithDevice:descriptor:isTextureBacked:] Error: NDArray dimension length > INT_MAX' static void check_mps_shape(MPSShape* shape) { @@ -441,13 +547,18 @@ static void check_mps_shape(MPSShape* shape) { } } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MPSNDArray* getMPSNDArray(const TensorBase& t, MPSShape* sizes, MPSShape* strides) { id srcBuf = getMTLBufferStorage(t); MPSDataType mpsDataType = getMPSDataType(t.scalar_type()); MPSNDArrayDescriptor* srcTensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:mpsDataType shape:sizes]; srcTensorDesc.preferPackedRows = YES; +<<<<<<< HEAD check_mps_shape(sizes); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MPSNDArray* srcNDArray = [[[MPSNDArray alloc] initWithBuffer:srcBuf offset:t.storage_offset() * t.element_size() descriptor:srcTensorDesc] autorelease]; @@ -557,9 +668,15 @@ static void check_mps_shape(MPSShape* shape) { // Tensor is contiguous and has no storage offset. // Wrap it directly inside MPSGraphTensorData if ((_tensor.is_contiguous() && !_tensor.storage_offset()) || !useMPSStridedAPI || !is_macOS_15_0_or_newer) { +<<<<<<< HEAD auto shape = mpsShape_ ? mpsShape_ : getMPSShape(_tensor); check_mps_shape(shape); _value = [[[MPSGraphTensorData alloc] initWithMTLBuffer:srcBuf shape:shape dataType:dataType] autorelease]; +======= + _value = [[[MPSGraphTensorData alloc] initWithMTLBuffer:srcBuf + shape:mpsShape_ ? mpsShape_ : getMPSShape(_tensor) + dataType:dataType] autorelease]; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else { IntArrayRef view_shape; if (mpsShape_) { @@ -568,11 +685,16 @@ static void check_mps_shape(MPSShape* shape) { MPSShape* mpsShape = getMPSShape(_tensor); MPSShape* mpsStrides = getMPSShape(_tensor.strides()); +<<<<<<< HEAD check_mps_shape(mpsShape); auto storage_numel = src.storage().nbytes() / src.element_size(); TORCH_CHECK(storage_numel <= std::numeric_limits::max(), "MPSGaph does not support tensor dims larger than INT_MAX"); +======= + + auto storage_numel = src.storage().nbytes() / src.element_size(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MPSNDArrayDescriptor* srcTensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:dataType shape:@[ @(storage_numel) ]]; srcTensorDesc.preferPackedRows = YES; @@ -656,11 +778,14 @@ MPSScalar getMPSScalar(const Scalar& scalar, ScalarType type) { case ScalarType::ComplexFloat: case ScalarType::ComplexDouble: return {.size = sizeof(int64_t), .type = type, .value.cf = scalar.to>()}; +<<<<<<< HEAD // Unsigned types case ScalarType::UInt32: return {.size = sizeof(uint32_t), .type = type, .value.i = scalar.to()}; case ScalarType::UInt16: return {.size = sizeof(uint16_t), .type = type, .value.i = scalar.to()}; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) default: TORCH_INTERNAL_ASSERT(false, "Unsupported scalar type '", type, "' on MPS backend."); } @@ -856,7 +981,13 @@ void executeMPSAllocatorCallback(void* ptr, EventType event) override {} MTLCompileOptions* options = compile_options; if (!options) { options = [[MTLCompileOptions new] autorelease]; +<<<<<<< HEAD [options setLanguageVersion:MTLLanguageVersion3_1]; +======= + // Need 3.0 for atomic oprations, 3.1 introduces bfloat support + [options setLanguageVersion:is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? MTLLanguageVersion3_1 + : MTLLanguageVersion3_0]; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS)) { options.mathMode = fast_math ? MTLMathModeFast : MTLMathModeSafe; options.mathFloatingPointFunctions = @@ -928,7 +1059,12 @@ void executeMPSAllocatorCallback(void* ptr, EventType event) override {} if (C10_UNLIKELY(!library)) { auto device = MPSDevice::getInstance()->device(); NSError* error = nil; +<<<<<<< HEAD library = [device newLibraryWithData:getSectionData("metal_basic") error:&error]; +======= + auto section_name = is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? "metal_bfloat" : "metal_basic"; + library = [device newLibraryWithData:getSectionData(section_name) error:&error]; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_CHECK(library, "Failed to create metal library, error: ", [[error description] UTF8String]); } return library; diff --git a/aten/src/ATen/native/mps/kernels/ActivationKernel.metal b/aten/src/ATen/native/mps/kernels/ActivationKernel.metal index ae1fda66c3b38..be9169c323162 100644 --- a/aten/src/ATen/native/mps/kernels/ActivationKernel.metal +++ b/aten/src/ATen/native/mps/kernels/ActivationKernel.metal @@ -33,6 +33,7 @@ struct shrink_backward_functor { REGISTER_UNARY_ALPHA_OP(hardshrink, float, float, float); REGISTER_UNARY_ALPHA_OP(hardshrink, half, half, half); +<<<<<<< HEAD REGISTER_UNARY_ALPHA_OP(hardshrink, bfloat, bfloat, bfloat); REGISTER_UNARY_ALPHA_OP(softshrink, float, float, float); @@ -42,6 +43,23 @@ REGISTER_UNARY_ALPHA_OP(softshrink, bfloat, bfloat, bfloat); REGISTER_BINARY_ALPHA_OP(shrink_backward, float, float, float); REGISTER_BINARY_ALPHA_OP(shrink_backward, half, half, half); REGISTER_BINARY_ALPHA_OP(shrink_backward, bfloat, bfloat, bfloat); +======= +#if __METAL_VERSION__ >= 310 +REGISTER_UNARY_ALPHA_OP(hardshrink, bfloat, bfloat, bfloat); +#endif + +REGISTER_UNARY_ALPHA_OP(softshrink, float, float, float); +REGISTER_UNARY_ALPHA_OP(softshrink, half, half, half); +#if __METAL_VERSION__ >= 310 +REGISTER_UNARY_ALPHA_OP(softshrink, bfloat, bfloat, bfloat); +#endif + +REGISTER_BINARY_ALPHA_OP(shrink_backward, float, float, float); +REGISTER_BINARY_ALPHA_OP(shrink_backward, half, half, half); +#if __METAL_VERSION__ >= 310 +REGISTER_BINARY_ALPHA_OP(shrink_backward, bfloat, bfloat, bfloat); +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) struct hardsigmoid_functor { template @@ -61,11 +79,23 @@ struct hardsigmoid_backward_functor { REGISTER_UNARY_OP(hardsigmoid, float, float); REGISTER_UNARY_OP(hardsigmoid, half, half); +<<<<<<< HEAD +REGISTER_UNARY_OP(hardsigmoid, bfloat, bfloat); + +REGISTER_BINARY_OP(hardsigmoid_backward, float, float); +REGISTER_BINARY_OP(hardsigmoid_backward, half, half); +REGISTER_BINARY_OP(hardsigmoid_backward, bfloat, bfloat); +======= +#if __METAL_VERSION__ >= 310 REGISTER_UNARY_OP(hardsigmoid, bfloat, bfloat); +#endif REGISTER_BINARY_OP(hardsigmoid_backward, float, float); REGISTER_BINARY_OP(hardsigmoid_backward, half, half); +#if __METAL_VERSION__ >= 310 REGISTER_BINARY_OP(hardsigmoid_backward, bfloat, bfloat); +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) struct hardswish_functor { template @@ -93,11 +123,23 @@ struct hardswish_backward_functor { REGISTER_UNARY_OP(hardswish, float, float); REGISTER_UNARY_OP(hardswish, half, half); +<<<<<<< HEAD +REGISTER_UNARY_OP(hardswish, bfloat, bfloat); + +REGISTER_BINARY_OP(hardswish_backward, float, float); +REGISTER_BINARY_OP(hardswish_backward, half, half); +REGISTER_BINARY_OP(hardswish_backward, bfloat, bfloat); +======= +#if __METAL_VERSION__ >= 310 REGISTER_UNARY_OP(hardswish, bfloat, bfloat); +#endif REGISTER_BINARY_OP(hardswish_backward, float, float); REGISTER_BINARY_OP(hardswish_backward, half, half); +#if __METAL_VERSION__ >= 310 REGISTER_BINARY_OP(hardswish_backward, bfloat, bfloat); +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) struct leaky_relu_functor { template @@ -121,8 +163,20 @@ struct leaky_relu_backward_functor { REGISTER_UNARY_ALPHA_OP(leaky_relu, float, float, float); REGISTER_UNARY_ALPHA_OP(leaky_relu, half, half, half); +<<<<<<< HEAD +REGISTER_UNARY_ALPHA_OP(leaky_relu, bfloat, bfloat, bfloat); + +REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, float, float, float); +REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, half, half, half); +REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, bfloat, bfloat, bfloat); +======= +#if __METAL_VERSION__ >= 310 REGISTER_UNARY_ALPHA_OP(leaky_relu, bfloat, bfloat, bfloat); +#endif REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, float, float, float); REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, half, half, half); +#if __METAL_VERSION__ >= 310 REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, bfloat, bfloat, bfloat); +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/native/mps/kernels/Amp.metal b/aten/src/ATen/native/mps/kernels/Amp.metal index 653c2057d498d..873a98ca6ac0d 100644 --- a/aten/src/ATen/native/mps/kernels/Amp.metal +++ b/aten/src/ATen/native/mps/kernels/Amp.metal @@ -113,6 +113,7 @@ kernel void ampUpdateScale( INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(float); INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(half); +<<<<<<< HEAD INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(bfloat); INSTANTIATE_AMP_UPDATE_SCALE(float); @@ -122,3 +123,20 @@ INSTANTIATE_AMP_UPDATE_SCALE(bfloat); INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(float); INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(half); INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(bfloat); +======= +#if __METAL_VERSION__ >= 310 +INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(bfloat); +#endif + +INSTANTIATE_AMP_UPDATE_SCALE(float); +INSTANTIATE_AMP_UPDATE_SCALE(half); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_AMP_UPDATE_SCALE(bfloat); +#endif + +INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(float); +INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(half); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(bfloat); +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/native/mps/kernels/Attention.metal b/aten/src/ATen/native/mps/kernels/Attention.metal index 5a317895f508e..91e7552f1335a 100644 --- a/aten/src/ATen/native/mps/kernels/Attention.metal +++ b/aten/src/ATen/native/mps/kernels/Attention.metal @@ -14,8 +14,13 @@ template device T* out [[buffer(3)]], const constant uint& gqa_factor [[buffer(4)]], const constant uint& N [[buffer(5)]], +<<<<<<< HEAD const constant uint3& qkv_head_strides [[buffer(6)]], const constant uint3& qkv_seq_strides [[buffer(7)]], +======= + const constant uint2& k_head_seq_stride [[buffer(6)]], + const constant uint2& v_head_seq_stride [[buffer(7)]], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const constant float& scale [[buffer(8)]], const device bool* mask [[buffer(9)]], const constant uint3& mask_strides [[buffer(10)]], @@ -28,12 +33,19 @@ template constexpr uint BD = 32; constexpr uint qk_per_thread = D / BD; constexpr uint v_per_thread = V / BD; +<<<<<<< HEAD const uint q_head_stride = qkv_head_strides.x; const uint q_seq_stride = qkv_seq_strides.x; const uint k_head_stride = qkv_head_strides.y; const uint k_seq_stride = qkv_seq_strides.y; const uint v_head_stride = qkv_head_strides.z; const uint v_seq_stride = qkv_seq_strides.z; +======= + const uint k_head_stride = k_head_seq_stride.x; + const uint k_seq_stride = k_head_seq_stride.y; + const uint v_head_stride = v_head_seq_stride.x; + const uint v_seq_stride = v_head_seq_stride.y; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const uint mask_head_stride = mask_strides.x; const uint mask_kv_seq_stride = mask_strides.y; const uint mask_q_seq_stride = mask_strides.z; @@ -56,9 +68,15 @@ template const int kv_head_idx = head_idx / gqa_factor; const int Q = tpg.y; const int group_offset = head_idx * Q + q_seq_idx; +<<<<<<< HEAD const int o_offset = group_offset; queries += head_idx * q_head_stride + q_seq_idx * q_seq_stride + simd_lid * qk_per_thread; +======= + const int q_offset = group_offset; + const int o_offset = group_offset; + queries += q_offset * D + simd_lid * qk_per_thread; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride + simd_lid * qk_per_thread; values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride + @@ -158,8 +176,13 @@ template device float* maxs [[buffer(5)]], const constant uint& gqa_factor [[buffer(6)]], const constant uint& N [[buffer(7)]], +<<<<<<< HEAD const constant uint3& qkv_head_strides [[buffer(8)]], const constant uint3& qkv_seq_strides [[buffer(9)]], +======= + const constant uint2& k_head_seq_stride [[buffer(8)]], + const constant uint2& v_head_seq_stride [[buffer(9)]], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const constant float& scale [[buffer(10)]], const device bool* mask [[buffer(11)]], const constant uint3& mask_strides [[buffer(12)]], @@ -172,12 +195,19 @@ template constexpr int BD = 32; constexpr int qk_per_thread = D / BD; constexpr int v_per_thread = V / BD; +<<<<<<< HEAD const int q_head_stride = qkv_head_strides.x; const int q_seq_stride = qkv_seq_strides.x; const int k_head_stride = qkv_head_strides.y; const int k_seq_stride = qkv_seq_strides.y; const int v_head_stride = qkv_head_strides.z; const int v_seq_stride = qkv_seq_strides.z; +======= + const int k_head_stride = k_head_seq_stride.x; + const int k_seq_stride = k_head_seq_stride.y; + const int v_head_stride = v_head_seq_stride.x; + const int v_seq_stride = v_head_seq_stride.y; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const int mask_kv_seq_stride = mask_strides.x; const int mask_q_seq_stride = mask_strides.y; const int mask_head_stride = mask_strides.z; @@ -200,10 +230,17 @@ template const int head_idx = tid.x; const int q_seq_idx = tid.y; const int o_offset = head_idx * tpg.y + q_seq_idx; +<<<<<<< HEAD const int kv_head_idx = head_idx / gqa_factor; queries += head_idx * q_head_stride + q_seq_idx * q_seq_stride + simd_lid * qk_per_thread; +======= + const int q_offset = o_offset; + const int kv_head_idx = head_idx / gqa_factor; + + queries += q_offset * D + simd_lid * qk_per_thread; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) keys += kv_head_idx * k_head_stride + (block_idx * BN + simd_gid) * k_seq_stride + simd_lid * qk_per_thread; values += kv_head_idx * v_head_stride + @@ -524,6 +561,7 @@ kernel void attention( } } +<<<<<<< HEAD #define INSTANTIATE_SDPA_VECTOR(DTYPE, QK_DIM, VALUE_DIM) \ template [[host_name("sdpa_vector_" #DTYPE "_" #QK_DIM \ "_" #VALUE_DIM)]] kernel void \ @@ -543,6 +581,27 @@ kernel void attention( uint3 tid [[threadgroup_position_in_grid]], \ uint3 tpg [[threadgroups_per_grid]], \ uint simd_gid [[simdgroup_index_in_threadgroup]], \ +======= +#define INSTANTIATE_SDPA_VECTOR(DTYPE, QK_DIM, VALUE_DIM) \ + template [[host_name("sdpa_vector_" #DTYPE "_" #QK_DIM \ + "_" #VALUE_DIM)]] kernel void \ + sdpa_vector( \ + const device DTYPE* queries [[buffer(0)]], \ + const device DTYPE* keys [[buffer(1)]], \ + const device DTYPE* values [[buffer(2)]], \ + device DTYPE* out [[buffer(3)]], \ + const constant uint& gqa_factor [[buffer(4)]], \ + const constant uint& N [[buffer(5)]], \ + const constant uint2& k_head_seq_stride [[buffer(6)]], \ + const constant uint2& v_head_seq_stride [[buffer(7)]], \ + const constant float& scale [[buffer(8)]], \ + const device bool* mask [[buffer(9)]], \ + const constant uint3& mask_strides [[buffer(10)]], \ + const constant bool& has_mask [[buffer(11)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 tpg [[threadgroups_per_grid]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uint simd_lid [[thread_index_in_simdgroup]]); #define INSTANTIATE_SDPA_VECTOR_2PASS_1(DTYPE, QK_DIM, VALUE_DIM) \ @@ -557,8 +616,13 @@ kernel void attention( device float* maxs [[buffer(5)]], \ const constant uint& gqa_factor [[buffer(6)]], \ const constant uint& N [[buffer(7)]], \ +<<<<<<< HEAD const constant uint3& qkv_head_strides [[buffer(8)]], \ const constant uint3& qkv_seq_strides [[buffer(9)]], \ +======= + const constant uint2& k_head_seq_stride [[buffer(8)]], \ + const constant uint2& v_head_seq_stride [[buffer(9)]], \ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const constant float& scale [[buffer(10)]], \ const device bool* mask [[buffer(11)]], \ const constant uint3& mask_strides [[buffer(12)]], \ @@ -594,7 +658,13 @@ kernel void attention( INSTANTIATE_SDPA_VECTOR_HEADS(float); INSTANTIATE_SDPA_VECTOR_HEADS(half); +<<<<<<< HEAD INSTANTIATE_SDPA_VECTOR_HEADS(bfloat); +======= +#if __METAL_VERSION__ >= 310 +INSTANTIATE_SDPA_VECTOR_HEADS(bfloat); +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #define INSTANTIATE_ATTN(DTYPE, bq, bk, bd, wm, wn) \ template [[host_name("attention_" #DTYPE "_bq" #bq "_bk" #bk "_bd" #bd \ @@ -623,4 +693,10 @@ INSTANTIATE_SDPA_VECTOR_HEADS(bfloat); INSTANTIATE_ATTN_SHAPES_HELPER(float); INSTANTIATE_ATTN_SHAPES_HELPER(half); +<<<<<<< HEAD +INSTANTIATE_ATTN_SHAPES_HELPER(bfloat); +======= +#if __METAL_VERSION__ >= 310 INSTANTIATE_ATTN_SHAPES_HELPER(bfloat); +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal index 0539eab79500d..7c7becd761f80 100644 --- a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal +++ b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal @@ -39,6 +39,7 @@ struct lerp_alpha_functor { } }; +<<<<<<< HEAD struct native_dropout_mask_and_scale_functor { template inline TA operator()(const TI a, const TI b, const TA scale) { @@ -46,6 +47,8 @@ struct native_dropout_mask_and_scale_functor { } }; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) struct fmax_functor { template inline T operator()(const T a, const T b) { @@ -141,6 +144,7 @@ struct chebyshev_polynomial_w_functor { } }; +<<<<<<< HEAD struct shifted_chebyshev_polynomial_t_functor { template , bool> = true> inline T operator()(const T a, const T b) { @@ -193,6 +197,8 @@ struct shifted_chebyshev_polynomial_w_functor { } }; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) struct hermite_polynomial_h_functor { template , bool> = true> inline T operator()(const T a, const T b) { @@ -216,9 +222,44 @@ struct hermite_polynomial_he_functor { }; struct nextafter_functor { +<<<<<<< HEAD template inline T operator()(const T a, const T b) { return static_cast(::metal::nextafter(a, b)); +======= +#if __METAL_VERSION__ < 310 + template + struct bit_type {}; + template <> + struct bit_type { + using type = int; + }; + template <> + struct bit_type { + using type = short; + }; +#endif + template + inline T operator()(const T a, const T b) { +#if __METAL_VERSION__ >= 310 + return static_cast(::metal::nextafter(a, b)); +#else + using U = typename bit_type::type; + if (a == b) { + return a; + } + if (::metal::isunordered(a, b)) { + return NAN; + } + if (a == 0) { + constexpr auto eps = as_type(static_cast(1)); + return b > 0 ? eps : -eps; + } + auto bits = as_type(a); + (a > 0) ^ (a > b) ? bits++ : bits--; + return as_type(bits); +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } }; @@ -322,6 +363,7 @@ struct fmod_functor { } }; +<<<<<<< HEAD struct igamma_functor { template inline T operator()(const T a, const T b) { @@ -335,6 +377,14 @@ struct igammac_functor { return c10::metal::igammac(a, b); } }; +======= +// Some helper defines +#if __METAL_VERSION__ >= 310 +#define _METAL_310_PLUS(x) x +#else +#define _METAL_310_PLUS(x) +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #define REGISTER_INTEGER_BINARY_OP(NAME) \ REGISTER_BINARY_OP(NAME, long, long); \ @@ -355,12 +405,20 @@ struct igammac_functor { #define REGISTER_FLOAT_BINARY_OP(NAME) \ REGISTER_BINARY_OP(NAME, float, float); \ REGISTER_BINARY_OP(NAME, half, half); \ +<<<<<<< HEAD REGISTER_BINARY_OP(NAME, bfloat, bfloat) +======= + _METAL_310_PLUS(REGISTER_BINARY_OP(NAME, bfloat, bfloat)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #define REGISTER_OPMATH_FLOAT_BINARY_OP(NAME) \ REGISTER_OPMATH_BINARY_OP(NAME, float, float); \ REGISTER_OPMATH_BINARY_OP(NAME, half, half); \ +<<<<<<< HEAD REGISTER_OPMATH_BINARY_OP(NAME, bfloat, bfloat) +======= + _METAL_310_PLUS(REGISTER_OPMATH_BINARY_OP(NAME, bfloat, bfloat)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) REGISTER_FLOAT_BINARY_OP(copysign); REGISTER_INT2FLOAT_BINARY_OP(copysign); @@ -379,6 +437,7 @@ REGISTER_FLOAT_BINARY_OP(chebyshev_polynomial_v); REGISTER_INT2FLOAT_BINARY_OP(chebyshev_polynomial_w); REGISTER_FLOAT_BINARY_OP(chebyshev_polynomial_w); REGISTER_INT2FLOAT_BINARY_OP(chebyshev_polynomial_v); +<<<<<<< HEAD REGISTER_FLOAT_BINARY_OP(shifted_chebyshev_polynomial_t); REGISTER_INT2FLOAT_BINARY_OP(shifted_chebyshev_polynomial_t); REGISTER_FLOAT_BINARY_OP(shifted_chebyshev_polynomial_u); @@ -387,6 +446,8 @@ REGISTER_FLOAT_BINARY_OP(shifted_chebyshev_polynomial_v); REGISTER_INT2FLOAT_BINARY_OP(shifted_chebyshev_polynomial_v); REGISTER_FLOAT_BINARY_OP(shifted_chebyshev_polynomial_w); REGISTER_INT2FLOAT_BINARY_OP(shifted_chebyshev_polynomial_w); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) REGISTER_FLOAT_BINARY_OP(hermite_polynomial_h); REGISTER_INT2FLOAT_BINARY_OP(hermite_polynomial_h); REGISTER_FLOAT_BINARY_OP(hermite_polynomial_he); @@ -407,8 +468,11 @@ REGISTER_OPMATH_FLOAT_BINARY_OP(remainder); REGISTER_INTEGER_BINARY_OP(remainder); REGISTER_OPMATH_FLOAT_BINARY_OP(fmod); REGISTER_INTEGER_BINARY_OP(fmod); +<<<<<<< HEAD REGISTER_OPMATH_FLOAT_BINARY_OP(igamma); REGISTER_OPMATH_FLOAT_BINARY_OP(igammac); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) REGISTER_BINARY_ALPHA_OP(add_alpha, long, long, long); REGISTER_BINARY_ALPHA_OP(add_alpha, int, int, int); REGISTER_BINARY_ALPHA_OP(add_alpha, float, float, float); @@ -434,6 +498,7 @@ REGISTER_BINARY_ALPHA_OP(lerp_alpha, uchar, uchar, uchar); REGISTER_BINARY_ALPHA_OP(lerp_alpha, char, char, char); REGISTER_BINARY_ALPHA_OP(lerp_alpha, bool, bool, bool); +<<<<<<< HEAD REGISTER_BINARY_ALPHA_OP(native_dropout_mask_and_scale, float, float, float); REGISTER_BINARY_ALPHA_OP(native_dropout_mask_and_scale, bfloat, bfloat, bfloat); REGISTER_BINARY_ALPHA_OP(native_dropout_mask_and_scale, half, half, half); @@ -441,6 +506,13 @@ REGISTER_BINARY_ALPHA_OP(native_dropout_mask_and_scale, half, half, half); REGISTER_BINARY_ALPHA_OP(add_alpha, bfloat, bfloat, bfloat); REGISTER_BINARY_ALPHA_OP(sub_alpha, bfloat, bfloat, bfloat); REGISTER_BINARY_ALPHA_OP(lerp_alpha, bfloat, bfloat, bfloat); +======= +#if __METAL_VERSION__ >= 310 +REGISTER_BINARY_ALPHA_OP(add_alpha, bfloat, bfloat, bfloat); +REGISTER_BINARY_ALPHA_OP(sub_alpha, bfloat, bfloat, bfloat); +REGISTER_BINARY_ALPHA_OP(lerp_alpha, bfloat, bfloat, bfloat); +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Complex binary functions REGISTER_BINARY_OP(polar, float, float2); diff --git a/aten/src/ATen/native/mps/kernels/Bucketization.metal b/aten/src/ATen/native/mps/kernels/Bucketization.metal index a84698d77f57c..51ea3ec43f63a 100644 --- a/aten/src/ATen/native/mps/kernels/Bucketization.metal +++ b/aten/src/ATen/native/mps/kernels/Bucketization.metal @@ -180,8 +180,15 @@ REGISTER_SEARCHSORTED_OP(float, int); REGISTER_SEARCHSORTED_OP(float, long); REGISTER_SEARCHSORTED_OP(half, int); REGISTER_SEARCHSORTED_OP(half, long); +<<<<<<< HEAD REGISTER_SEARCHSORTED_OP(bfloat, int); REGISTER_SEARCHSORTED_OP(bfloat, long); +======= +#if __METAL_VERSION__ >= 310 +REGISTER_SEARCHSORTED_OP(bfloat, int); +REGISTER_SEARCHSORTED_OP(bfloat, long); +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) REGISTER_SEARCHSORTED_OP(char, int); REGISTER_SEARCHSORTED_OP(char, long); REGISTER_SEARCHSORTED_OP(uchar, int); diff --git a/aten/src/ATen/native/mps/kernels/Col2Im.metal b/aten/src/ATen/native/mps/kernels/Col2Im.metal index 61f596a9250f4..a0ac3f63e5b38 100644 --- a/aten/src/ATen/native/mps/kernels/Col2Im.metal +++ b/aten/src/ATen/native/mps/kernels/Col2Im.metal @@ -96,4 +96,10 @@ kernel void col2im_kernel( INSTANTIATE_COL2IM(bool); INSTANTIATE_COL2IM(float); INSTANTIATE_COL2IM(half); +<<<<<<< HEAD INSTANTIATE_COL2IM(bfloat); +======= +#if __METAL_VERSION__ >= 310 +INSTANTIATE_COL2IM(bfloat); +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/native/mps/kernels/CrossKernel.metal b/aten/src/ATen/native/mps/kernels/CrossKernel.metal index bceae51c02db4..5d31426969504 100644 --- a/aten/src/ATen/native/mps/kernels/CrossKernel.metal +++ b/aten/src/ATen/native/mps/kernels/CrossKernel.metal @@ -20,7 +20,13 @@ REGISTER_CROSS_FUNC(short); REGISTER_CROSS_FUNC(char); REGISTER_CROSS_FUNC(uchar); REGISTER_CROSS_FUNC(bool); +<<<<<<< HEAD REGISTER_CROSS_FUNC(bfloat); +======= +#if __METAL_VERSION__ >= 310 +REGISTER_CROSS_FUNC(bfloat); +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template kernel void cross( @@ -66,4 +72,10 @@ REGISTER_CROSS_OP(short); REGISTER_CROSS_OP(char); REGISTER_CROSS_OP(uchar); REGISTER_CROSS_OP(bool); +<<<<<<< HEAD +REGISTER_CROSS_OP(bfloat); +======= +#if __METAL_VERSION__ >= 310 REGISTER_CROSS_OP(bfloat); +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/native/mps/kernels/FusedOptimizerOps.metal b/aten/src/ATen/native/mps/kernels/FusedOptimizerOps.metal index f46b10f99bf4b..0626d13acb87b 100644 --- a/aten/src/ATen/native/mps/kernels/FusedOptimizerOps.metal +++ b/aten/src/ATen/native/mps/kernels/FusedOptimizerOps.metal @@ -1,9 +1,17 @@ #include using metal::max; +<<<<<<< HEAD bfloat max(bfloat a, bfloat b) { return a > b ? a : b; } +======= +#if __METAL_VERSION__ >= 310 +bfloat max(bfloat a, bfloat b) { + return a > b ? a : b; +} +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #define kmaxThreadGroups 32 #define kmaxTensors 32 @@ -304,9 +312,17 @@ REGISTER_ADAM_OPS_QUART(float, float); REGISTER_ADAM_OPS_QUART(float, half); REGISTER_ADAM_OPS_QUART(half, float); REGISTER_ADAM_OPS_QUART(half, half); +<<<<<<< HEAD +REGISTER_ADAM_OPS_QUART(float, bfloat); +REGISTER_ADAM_OPS_QUART(bfloat, bfloat); +REGISTER_ADAM_OPS_QUART(bfloat, float); +======= +#if __METAL_VERSION__ >= 310 REGISTER_ADAM_OPS_QUART(float, bfloat); REGISTER_ADAM_OPS_QUART(bfloat, bfloat); REGISTER_ADAM_OPS_QUART(bfloat, float); +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template inline void sgd_momentum_math( @@ -456,5 +472,12 @@ REGISTER_FUSED_SGD_OP(float); REGISTER_FUSED_SGD_OP(half); REGISTER_FUSED_SGD_MOMENTUM_OP(float); REGISTER_FUSED_SGD_MOMENTUM_OP(half); +<<<<<<< HEAD +REGISTER_FUSED_SGD_OP(bfloat); +REGISTER_FUSED_SGD_MOMENTUM_OP(bfloat); +======= +#if __METAL_VERSION__ >= 310 REGISTER_FUSED_SGD_OP(bfloat); REGISTER_FUSED_SGD_MOMENTUM_OP(bfloat); +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/native/mps/kernels/Gamma.metal b/aten/src/ATen/native/mps/kernels/Gamma.metal index 1c150a726edb1..1e82f33b9213b 100644 --- a/aten/src/ATen/native/mps/kernels/Gamma.metal +++ b/aten/src/ATen/native/mps/kernels/Gamma.metal @@ -106,7 +106,13 @@ kernel void polygamma( constant int64_t& order [[buffer(2)]], \ uint id [[thread_position_in_grid]]); +<<<<<<< HEAD INSTANTIATE_GAMMA_KERNELS(bfloat, bfloat); +======= +#if __METAL_VERSION__ >= 310 +INSTANTIATE_GAMMA_KERNELS(bfloat, bfloat); +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) INSTANTIATE_GAMMA_KERNELS(half, half); INSTANTIATE_GAMMA_KERNELS(float, float); INSTANTIATE_GAMMA_KERNELS(bool, float); diff --git a/aten/src/ATen/native/mps/kernels/Im2Col.metal b/aten/src/ATen/native/mps/kernels/Im2Col.metal index 191462bbd3d08..39a8b4a1ed39c 100644 --- a/aten/src/ATen/native/mps/kernels/Im2Col.metal +++ b/aten/src/ATen/native/mps/kernels/Im2Col.metal @@ -76,4 +76,10 @@ INSTANTIATE_IM2COL(float); INSTANTIATE_IM2COL(float2); INSTANTIATE_IM2COL(half); INSTANTIATE_IM2COL(half2); +<<<<<<< HEAD INSTANTIATE_IM2COL(bfloat); +======= +#if __METAL_VERSION__ >= 310 +INSTANTIATE_IM2COL(bfloat); +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/native/mps/kernels/Indexing.metal b/aten/src/ATen/native/mps/kernels/Indexing.metal index b41e64d70ced5..3b4a49010e108 100644 --- a/aten/src/ATen/native/mps/kernels/Indexing.metal +++ b/aten/src/ATen/native/mps/kernels/Indexing.metal @@ -9,6 +9,7 @@ struct IndexAB { constant int64_t* indexArray; }; +<<<<<<< HEAD uint3 index_get_offsets( constant int64_t* sizes, constant int64_t* output_strides, @@ -43,6 +44,8 @@ OffsetT index_apply_indices( return rc; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template kernel void index_select( device T* output, @@ -58,6 +61,7 @@ kernel void index_select( uint thread_index [[thread_position_in_grid]]) { const auto ndim = ndim_nindices_numel.x; const auto num_indices = ndim_nindices_numel.y; +<<<<<<< HEAD const auto offs = index_get_offsets( sizes, output_strides, @@ -68,6 +72,22 @@ kernel void index_select( auto input_offs = index_apply_indices( offs.yz, indices, index_sizes, index_strides, num_indices); output[offs.x / sizeof(T)] = input[input_offs / sizeof(T)]; +======= + uint pos[max_ndim]; + pos_from_thread_index(thread_index, pos, sizes, ndim); + const auto output_offs = offset_from_coord(pos, output_strides, ndim); + OffsetT input_offs = offset_from_coord(pos, input_strides, ndim); + const auto indices_offs = + offset_from_coord(pos, indices_strides, ndim) / sizeof(int64_t); + for (uint i = 0; i < num_indices; i++) { + auto idx = indices[i].indexArray[indices_offs]; + if (idx < 0) { + idx += index_sizes[i]; + } + input_offs += idx * index_strides[i]; + } + output[output_offs / sizeof(T)] = input[input_offs / sizeof(T)]; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } template @@ -85,6 +105,7 @@ inline void index_put_impl( uint thread_index) { const auto ndim = ndim_nindices_numel.x; const auto num_indices = ndim_nindices_numel.y; +<<<<<<< HEAD const auto offs = index_get_offsets( sizes, output_strides, @@ -95,6 +116,22 @@ inline void index_put_impl( auto output_offs = index_apply_indices( offs.xz, indices, index_sizes, index_strides, num_indices); output[output_offs / sizeof(T)] = input[offs.y / sizeof(T)]; +======= + uint pos[max_ndim]; + pos_from_thread_index(thread_index, pos, sizes, ndim); + OffsetT output_offs = offset_from_coord(pos, output_strides, ndim); + const auto input_offs = offset_from_coord(pos, input_strides, ndim); + const auto indices_offs = + offset_from_coord(pos, indices_strides, ndim) / sizeof(int64_t); + for (uint i = 0; i < num_indices; i++) { + auto idx = indices[i].indexArray[indices_offs]; + if (idx < 0) { + idx += index_sizes[i]; + } + output_offs += idx * index_strides[i]; + } + output[output_offs / sizeof(T)] = input[input_offs / sizeof(T)]; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } template @@ -169,6 +206,7 @@ kernel void index_put_accumulate( uint thread_index [[thread_position_in_grid]]) { const auto ndim = ndim_nindices_numel.x; const auto num_indices = ndim_nindices_numel.y; +<<<<<<< HEAD const auto offs = index_get_offsets( sizes, output_strides, @@ -182,6 +220,25 @@ kernel void index_put_accumulate( reinterpret_cast*>(output), output_offs / sizeof(T), input[offs.y / sizeof(T)]); +======= + uint pos[max_ndim]; + pos_from_thread_index(thread_index, pos, sizes, ndim); + OffsetT output_offs = offset_from_coord(pos, output_strides, ndim); + const auto input_offs = offset_from_coord(pos, input_strides, ndim); + const auto indices_offs = + offset_from_coord(pos, indices_strides, ndim) / sizeof(int64_t); + for (uint i = 0; i < num_indices; i++) { + auto idx = indices[i].indexArray[indices_offs]; + if (idx < 0) { + idx += index_sizes[i]; + } + output_offs += idx * index_strides[i]; + } + AtomicType::atomic_add( + reinterpret_cast*>(output), + output_offs / sizeof(T), + input[input_offs / sizeof(T)]); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } #define REGISTER_INDEX_OP(OP_NAME, SUFFIX, DTYPE) \ @@ -211,6 +268,7 @@ REGISTER_INDEX_OP_ALL_DTYPES(put_serial); REGISTER_INDEX_OP(put_accumulate, float, float); REGISTER_INDEX_OP(put_accumulate, half, half); +<<<<<<< HEAD REGISTER_INDEX_OP(put_accumulate, bfloat, bfloat); REGISTER_INDEX_OP(put_accumulate, long, long); REGISTER_INDEX_OP(put_accumulate, int, int); @@ -220,6 +278,13 @@ REGISTER_INDEX_OP(put_accumulate, uchar, uchar); REGISTER_INDEX_OP(put_accumulate, bool, bool); REGISTER_INDEX_OP(put_accumulate, float2, float2); REGISTER_INDEX_OP(put_accumulate, half2, half2); +======= +REGISTER_INDEX_OP(put_accumulate, int, int); +REGISTER_INDEX_OP(put_accumulate, bool, bool); +#if __METAL_VERSION__ >= 310 +REGISTER_INDEX_OP(put_accumulate, bfloat, bfloat); +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template kernel void kernel_index_offsets( @@ -358,7 +423,10 @@ kernel void index_copy_strided( constant long* input_strides, constant long* output_strides, constant long* source_strides, +<<<<<<< HEAD constant long& indices_stride, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uint thread_index [[thread_position_in_grid]]) { int pos[max_ndim]; pos_from_thread_index(int(thread_index), pos, sizes, ndim); @@ -375,7 +443,11 @@ kernel void index_copy_strided( // find the last index in the indices array that equals this coordinate int last_matching_index = -1; for (uint i = 0; i < indices_numel; i++) { +<<<<<<< HEAD if (indices[i * indices_stride] == orig_dim) { +======= + if (indices[i] == orig_dim) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) last_matching_index = int(i); } } @@ -414,7 +486,10 @@ kernel void index_copy_strided( constant long*, \ constant long*, \ constant long*, \ +<<<<<<< HEAD constant long&, \ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uint); #define REGISTER_MASKED_FILL_SCALAR(SIZE, DTYPE) \ @@ -456,8 +531,15 @@ INSTANTIATE_INDEX_COPY(char, long); INSTANTIATE_INDEX_COPY(uchar, int); INSTANTIATE_INDEX_COPY(uchar, long); +<<<<<<< HEAD +INSTANTIATE_INDEX_COPY(bfloat, int); +INSTANTIATE_INDEX_COPY(bfloat, long); +======= +#if __METAL_VERSION__ >= 310 INSTANTIATE_INDEX_COPY(bfloat, int); INSTANTIATE_INDEX_COPY(bfloat, long); +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) INSTANTIATE_INDEX_COPY(float2, int); INSTANTIATE_INDEX_COPY(float2, long); INSTANTIATE_INDEX_COPY(half2, int); diff --git a/aten/src/ATen/native/mps/kernels/LayerNorm.metal b/aten/src/ATen/native/mps/kernels/LayerNorm.metal index 7b4a789ed292a..a3e2ae8c670f8 100644 --- a/aten/src/ATen/native/mps/kernels/LayerNorm.metal +++ b/aten/src/ATen/native/mps/kernels/LayerNorm.metal @@ -1,8 +1,14 @@ +<<<<<<< HEAD #include #include #include using namespace metal; using c10::metal::simdgroup_size; +======= +#include +#include +using namespace metal; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template kernel void layer_norm_single_row( @@ -20,6 +26,10 @@ kernel void layer_norm_single_row( uint tid [[thread_position_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simdgroup_id [[simdgroup_index_in_threadgroup]]) { +<<<<<<< HEAD +======= + constexpr int SIMD_SIZE = 32; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) constexpr int N_READS = 4; // each threadgroup handles one full ā€œrowā€ of length axis_size @@ -53,8 +63,13 @@ kernel void layer_norm_single_row( } // threadgroup‐wide reduction +<<<<<<< HEAD threadgroup float local_sums[simdgroup_size]; threadgroup float local_sums_sq[simdgroup_size]; +======= + threadgroup float local_sums[SIMD_SIZE]; + threadgroup float local_sums_sq[SIMD_SIZE]; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) threadgroup float tg_mean[1]; threadgroup float tg_inv_std[1]; @@ -143,6 +158,10 @@ kernel void layer_norm_looped( uint lsize [[threads_per_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simdgroup_id [[simdgroup_index_in_threadgroup]]) { +<<<<<<< HEAD +======= + constexpr int SIMD_SIZE = 32; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) constexpr int N_READS = 4; uint row_offset = tg_id * axis_size; @@ -178,8 +197,13 @@ kernel void layer_norm_looped( partial_sum = simd_sum(partial_sum); partial_sum_sq = simd_sum(partial_sum_sq); +<<<<<<< HEAD threadgroup float local_sums[simdgroup_size]; threadgroup float local_sums_sq[simdgroup_size]; +======= + threadgroup float local_sums[SIMD_SIZE]; + threadgroup float local_sums_sq[SIMD_SIZE]; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) threadgroup float tg_mean[1]; threadgroup float tg_inv_std[1]; @@ -288,6 +312,13 @@ kernel void layer_norm_looped( #define instantiate_layer_norm(DTYPE) \ instantiate_layer_norm_single_row(DTYPE) instantiate_layer_norm_looped(DTYPE) +<<<<<<< HEAD instantiate_layer_norm(float); instantiate_layer_norm(half); instantiate_layer_norm(bfloat); +======= +instantiate_layer_norm(float) instantiate_layer_norm(half) +#if __METAL_VERSION__ >= 310 + instantiate_layer_norm(bfloat) +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal b/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal index 4ba2bca720db7..3ff7253354ff3 100644 --- a/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal +++ b/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal @@ -69,6 +69,7 @@ kernel void matmul( } template +<<<<<<< HEAD kernel void addmm( constant T* mat1Data [[buffer(0)]], constant T* mat2Data [[buffer(1)]], @@ -100,6 +101,8 @@ kernel void addmm( } template +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kernel void naive_bmm( constant T* mat1Data [[buffer(0)]], constant T* mat2Data [[buffer(1)]], @@ -644,6 +647,7 @@ kernel void applyPivots( } } +<<<<<<< HEAD #define INSTANTIATE_MM_OPS(DTYPE) \ template [[host_name("matmul_" #DTYPE)]] kernel void matmul( \ constant DTYPE * mat1Data [[buffer(0)]], \ @@ -653,6 +657,19 @@ kernel void applyPivots( constant uint3 & sizes [[buffer(4)]], \ uint2 tid [[thread_position_in_threadgroup]], \ uint2 group_id [[threadgroup_position_in_grid]]); \ +======= +#define INSTANTIATE_NAIVE_MM(DTYPE) \ + template [[host_name("matmul_" #DTYPE)]] kernel void matmul( \ + constant DTYPE * mat1Data [[buffer(0)]], \ + constant DTYPE * mat2Data [[buffer(1)]], \ + device DTYPE * outputData [[buffer(2)]], \ + constant array & strides [[buffer(3)]], \ + constant uint3 & sizes [[buffer(4)]], \ + uint2 tid [[thread_position_in_threadgroup]], \ + uint2 group_id [[threadgroup_position_in_grid]]) + +#define INSTANTIATE_NAIVE_BMM(DTYPE) \ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template [[host_name("naive_bmm_" #DTYPE)]] kernel void naive_bmm( \ constant DTYPE * mat1Data [[buffer(0)]], \ constant DTYPE * mat2Data [[buffer(1)]], \ @@ -660,6 +677,7 @@ kernel void applyPivots( constant array & strides [[buffer(3)]], \ constant uint4 & sizes [[buffer(4)]], \ uint3 tid [[thread_position_in_threadgroup]], \ +<<<<<<< HEAD uint3 group_id [[threadgroup_position_in_grid]]); \ template [[host_name("addmm_" #DTYPE)]] kernel void addmm( \ constant DTYPE * mat1Data [[buffer(0)]], \ @@ -683,3 +701,24 @@ INSTANTIATE_MM_OPS(int); INSTANTIATE_MM_OPS(short); INSTANTIATE_MM_OPS(char); INSTANTIATE_MM_OPS(uchar); +======= + uint3 group_id [[threadgroup_position_in_grid]]) + +INSTANTIATE_NAIVE_MM(float); +INSTANTIATE_NAIVE_MM(half); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_NAIVE_MM(bfloat); +#endif + +// Integral MM +INSTANTIATE_NAIVE_MM(short); +INSTANTIATE_NAIVE_MM(int); +INSTANTIATE_NAIVE_MM(long); +INSTANTIATE_NAIVE_MM(char); +INSTANTIATE_NAIVE_MM(uchar); +INSTANTIATE_NAIVE_BMM(short); +INSTANTIATE_NAIVE_BMM(int); +INSTANTIATE_NAIVE_BMM(long); +INSTANTIATE_NAIVE_BMM(char); +INSTANTIATE_NAIVE_BMM(uchar); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/native/mps/kernels/Quantized.metal b/aten/src/ATen/native/mps/kernels/Quantized.metal index b84c033a07f49..3973b1f04de96 100644 --- a/aten/src/ATen/native/mps/kernels/Quantized.metal +++ b/aten/src/ATen/native/mps/kernels/Quantized.metal @@ -197,10 +197,18 @@ INSTANTIATE_INT4MV(float, 128); INSTANTIATE_INT4MV(half, 128); INSTANTIATE_INT4MV(float, 256); INSTANTIATE_INT4MV(half, 256); +<<<<<<< HEAD +======= +#if __METAL_VERSION__ >= 310 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) INSTANTIATE_INT4MV(bfloat, 32); INSTANTIATE_INT4MV(bfloat, 64); INSTANTIATE_INT4MV(bfloat, 128); INSTANTIATE_INT4MV(bfloat, 256); +<<<<<<< HEAD +======= +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // ------------------------------ int8 MM For M >= 12 ------------------------------------ /** @@ -232,10 +240,18 @@ template <> struct BlockType { using simdgroup_type8x8 = simdgroup_half8x8; using type4 = half4; }; +<<<<<<< HEAD +======= +#if __METAL_VERSION__ >= 310 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template <> struct BlockType { using simdgroup_type8x8 = simdgroup_bfloat8x8; using type4 = bfloat4; }; +<<<<<<< HEAD +======= +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template float2 get_scale_zero_q8(constant T * scalesAndZeros, uint2 index) { @@ -486,7 +502,13 @@ kernel void kernel_mul_mm( \ INSTANTIATE_MM(float, char, get_scale_zero_q8); INSTANTIATE_MM(half, char, get_scale_zero_q8); +<<<<<<< HEAD INSTANTIATE_MM(bfloat, char, get_scale_zero_q8); +======= +#if __METAL_VERSION__ >= 310 +INSTANTIATE_MM(bfloat, char, get_scale_zero_q8); +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // ------------------------------ int8 MM For M < 12 ------------------------------------ /* Matrix vector multiplication, used for small M size for matrix multiplication as well. @@ -640,4 +662,10 @@ kernel void kernel_mul_mv( INSTANTIATE_MV(float); INSTANTIATE_MV(half); +<<<<<<< HEAD +INSTANTIATE_MV(bfloat); +======= +#if __METAL_VERSION__ >= 310 INSTANTIATE_MV(bfloat); +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/native/mps/kernels/RMSNorm.metal b/aten/src/ATen/native/mps/kernels/RMSNorm.metal index d6c69217e65f3..51d94b4540efb 100644 --- a/aten/src/ATen/native/mps/kernels/RMSNorm.metal +++ b/aten/src/ATen/native/mps/kernels/RMSNorm.metal @@ -2,13 +2,19 @@ // https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/rms_norm.metal // Copyright Ā© 2024 Apple Inc. +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include using namespace metal; +<<<<<<< HEAD using c10::metal::simdgroup_size; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template [[kernel]] void rms_single_row( @@ -22,10 +28,18 @@ template uint lid [[thread_position_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { +<<<<<<< HEAD constexpr int N_READS = 4; threadgroup float local_inv_mean[1]; threadgroup float local_sums[simdgroup_size]; +======= + constexpr int SIMD_SIZE = 32; + constexpr int N_READS = 4; + + threadgroup float local_inv_mean[1]; + threadgroup float local_sums[SIMD_SIZE]; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) float acc = 0; x += gid * size_t(axis_size) + lid * N_READS; @@ -93,9 +107,16 @@ template uint lsize [[threads_per_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { +<<<<<<< HEAD constexpr int N_READS = 4; threadgroup float local_inv_mean[1]; threadgroup float local_sums[simdgroup_size]; +======= + constexpr int SIMD_SIZE = 32; + constexpr int N_READS = 4; + threadgroup float local_inv_mean[1]; + threadgroup float local_sums[SIMD_SIZE]; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) float acc = 0; x += gid * size_t(axis_size) + lid * N_READS; @@ -192,4 +213,10 @@ template instantiate_rms(float) instantiate_rms(half) +<<<<<<< HEAD +instantiate_rms(bfloat) +======= +#if __METAL_VERSION__ >= 310 instantiate_rms(bfloat) +#endif // clang-format on +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/native/mps/kernels/RenormKernel.metal b/aten/src/ATen/native/mps/kernels/RenormKernel.metal index 0bfd60b04c162..518ebf380bcb7 100644 --- a/aten/src/ATen/native/mps/kernels/RenormKernel.metal +++ b/aten/src/ATen/native/mps/kernels/RenormKernel.metal @@ -23,4 +23,10 @@ kernel void renorm( REGISTER_RENORM_OP(float); REGISTER_RENORM_OP(half); +<<<<<<< HEAD REGISTER_RENORM_OP(bfloat); +======= +#if __METAL_VERSION__ >= 310 +REGISTER_RENORM_OP(bfloat); +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/native/mps/kernels/ScanKernel.metal b/aten/src/ATen/native/mps/kernels/ScanKernel.metal index de493af7aaa05..4fa91dea9b4d4 100644 --- a/aten/src/ATen/native/mps/kernels/ScanKernel.metal +++ b/aten/src/ATen/native/mps/kernels/ScanKernel.metal @@ -1,4 +1,7 @@ +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include using namespace metal; @@ -7,6 +10,7 @@ using namespace metal; using c10::metal::accum_t; +<<<<<<< HEAD struct LogAddExp { template T operator()(T x, T y) { @@ -212,11 +216,43 @@ struct CumMinOp { return make_pair( simd_shuffle_and_fill_up(data.value, filling.value, delta), simd_shuffle_and_fill_up(data.index, filling.index, delta)); +======= +template > +struct CumSumOp { + static acc_t apply(acc_t a, acc_t b) { + return a + b; + } + static acc_t identity() { + return acc_t(0); + } +}; + +template > +struct CumProdOp { + static acc_t apply(acc_t a, acc_t b) { + return a * b; + } + static acc_t identity() { + return acc_t(1); + } +}; + +template > +struct CumMinOp { + static acc_t apply(acc_t a, acc_t b) { + return metal::min(a, b); + } + static acc_t identity() { + return static_cast( + metal::is_floating_point_v ? metal::numeric_limits::infinity() + : metal::numeric_limits::max()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } }; template > struct CumMaxOp { +<<<<<<< HEAD using pair_t = ValueIndexPair; static constexpr constant acc_t init_val = static_cast( @@ -420,10 +456,43 @@ kernel void scan_innermost_dim( } threadgroup_barrier(mem_flags::mem_threadgroup); prefix = simdgroup_sums[0]; +======= + static acc_t apply(acc_t a, acc_t b) { + return metal::max(a, b); + } + static acc_t identity() { + return static_cast( + metal::is_floating_point_v ? -metal::numeric_limits::infinity() + : metal::numeric_limits::lowest()); + } +}; + +// Inclusive scan along innermost dimension for contiguous tensors +template > +kernel void scan_contiguous_innermost_dim( + constant T* input [[buffer(0)]], + device T* output [[buffer(1)]], + constant uint& num_rows [[buffer(2)]], + constant uint& row_size [[buffer(3)]], + uint row [[thread_position_in_grid]]) { + if (row >= num_rows) + return; + + const uint offset = row * row_size; + + acc_t accumulator = Op::identity(); + + for (uint col = 0; col < row_size; col++) { + T val = input[offset + col]; + acc_t accum_val = static_cast(val); + accumulator = Op::apply(accumulator, accum_val); + output[offset + col] = static_cast(accumulator); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } // Inclusive scan along outer dimension for contiguous tensors +<<<<<<< HEAD template > kernel void scan_outer_dim( const device T* in [[buffer(0)]], @@ -518,10 +587,126 @@ kernel void scan_outer_dim( out[index_y * stride + i] = static_cast(read_into[i]); } } +======= +template > +kernel void scan_contiguous_outer_dim( + constant T* input [[buffer(0)]], + device T* output [[buffer(1)]], + constant uint& num_orows [[buffer(2)]], + constant uint& num_irows [[buffer(3)]], + constant uint& row_size [[buffer(4)]], + uint thread_index [[thread_position_in_grid]]) { + const uint orow = thread_index / num_irows; + const uint irow = thread_index % num_irows; + + if (orow >= num_orows) + return; + + acc_t accumulator = Op::identity(); + + const uint idx_base = orow * row_size * num_irows + irow; + for (uint col = 0, idx = idx_base; col < row_size; col++, idx += num_irows) { + T val = input[idx]; + acc_t accum_val = static_cast(val); + accumulator = Op::apply(accumulator, accum_val); + output[idx] = static_cast(accumulator); + } +} + +// Inclusive scan with indices along innermost dimension for contiguous tensors +template > +kernel void scan_with_indices_contiguous_innermost_dim( + constant T* input [[buffer(0)]], + device T* values [[buffer(1)]], + device int64_t* indices [[buffer(2)]], + constant uint& num_rows [[buffer(3)]], + constant uint& row_size [[buffer(4)]], + uint row [[thread_position_in_grid]]) { + if (row >= num_rows) + return; + + const uint offset = row * row_size; + + acc_t accumulator = Op::identity(); + int64_t best_idx = 0; + + for (uint col = 0; col < row_size; col++) { + T val = input[offset + col]; + acc_t accum_val = static_cast(val); + if (col == 0 || Op::apply(accum_val, accumulator) == accum_val) { + accumulator = accum_val; + best_idx = col; } + values[offset + col] = static_cast(accumulator); + indices[offset + col] = best_idx; } } +// Inclusive scan with indices along outer dimension for contiguous tensors +template > +kernel void scan_with_indices_contiguous_outer_dim( + constant T* input [[buffer(0)]], + device T* values [[buffer(1)]], + device int64_t* indices [[buffer(2)]], + constant uint& num_orows [[buffer(3)]], + constant uint& num_irows [[buffer(4)]], + constant uint& row_size [[buffer(5)]], + uint thread_index [[thread_position_in_grid]]) { + const uint orow = thread_index / num_irows; + const uint irow = thread_index % num_irows; + + if (orow >= num_orows) + return; + + acc_t accumulator = Op::identity(); + int64_t best_idx = 0; + + const uint idx_base = orow * row_size * num_irows + irow; + for (uint col = 0, idx = idx_base; col < row_size; col++, idx += num_irows) { + T val = input[idx]; + acc_t accum_val = static_cast(val); + if (col == 0 || Op::apply(accum_val, accumulator) == accum_val) { + accumulator = accum_val; + best_idx = col; + } + values[idx] = static_cast(accumulator); + indices[idx] = best_idx; + } +} + +// Shared utility functions for strided kernels +inline long calculate_non_scan_elements( + constant long* sizes, + uint ndim, + uint scan_dim) { + long total = 1; + for (uint i = 0; i < ndim; ++i) { + if (i != scan_dim) { + total *= sizes[i]; + } + } + return total; +} + +inline void thread_index_to_coordinates( + uint index, + int pos[c10::metal::max_ndim], + constant long* sizes, + uint ndim, + uint scan_dim) { + long remaining_index = index; + for (uint i = 0; i < ndim; ++i) { + if (i != scan_dim) { + pos[i] = remaining_index % sizes[i]; + remaining_index /= sizes[i]; + } else { + pos[i] = 0; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) + } + } +} + +<<<<<<< HEAD template > kernel void scan_with_indices_innermost_dim( const device T* in [[buffer(0)]], @@ -786,3 +971,200 @@ REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, short, 4); REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, char, 4); REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, uchar, 4); REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, bool, 4); +======= +inline long calculate_base_offset( + int pos[c10::metal::max_ndim], + constant long* strides, + uint ndim, + uint scan_dim) { + long offset = 0; + for (uint i = 0; i < ndim; ++i) { + if (i != scan_dim) { + offset += pos[i] * strides[i]; + } + } + return offset; +} + +// Generic strided scan kernel +template > +kernel void scan_strided( + constant T* input [[buffer(0)]], + device T* output [[buffer(1)]], + constant long* sizes [[buffer(2)]], + constant long* input_strides [[buffer(3)]], + constant long* output_strides [[buffer(4)]], + constant uint& ndim [[buffer(5)]], + constant uint& scan_dim [[buffer(6)]], + uint thread_index [[thread_position_in_grid]]) { + const long total_non_scan_elements = + calculate_non_scan_elements(sizes, ndim, scan_dim); + if (thread_index >= total_non_scan_elements) { + return; + } + + int pos[c10::metal::max_ndim]; + thread_index_to_coordinates(thread_index, pos, sizes, ndim, scan_dim); + + const long input_base_offset = + calculate_base_offset(pos, input_strides, ndim, scan_dim); + const long output_base_offset = + calculate_base_offset(pos, output_strides, ndim, scan_dim); + + acc_t accumulator = Op::identity(); + const long scan_size = sizes[scan_dim]; + const long input_scan_stride = input_strides[scan_dim]; + const long output_scan_stride = output_strides[scan_dim]; + + for (long scan_idx = 0; scan_idx < scan_size; scan_idx++) { + const long input_offset = input_base_offset + scan_idx * input_scan_stride; + const long output_offset = + output_base_offset + scan_idx * output_scan_stride; + + T val = input[input_offset]; + acc_t accum_val = static_cast(val); + accumulator = Op::apply(accumulator, accum_val); + output[output_offset] = static_cast(accumulator); + } +} + +// Generic strided scan with indices kernel +template > +kernel void scan_with_indices_strided( + constant T* input [[buffer(0)]], + device T* values [[buffer(1)]], + device int64_t* indices [[buffer(2)]], + constant long* sizes [[buffer(3)]], + constant long* input_strides [[buffer(4)]], + constant long* values_strides [[buffer(5)]], + constant long* indices_strides [[buffer(6)]], + constant uint& ndim [[buffer(7)]], + constant uint& scan_dim [[buffer(8)]], + uint thread_index [[thread_position_in_grid]]) { + const long total_non_scan_elements = + calculate_non_scan_elements(sizes, ndim, scan_dim); + if (thread_index >= total_non_scan_elements) { + return; + } + + int pos[c10::metal::max_ndim]; + thread_index_to_coordinates(thread_index, pos, sizes, ndim, scan_dim); + + const long input_base_offset = + calculate_base_offset(pos, input_strides, ndim, scan_dim); + const long values_base_offset = + calculate_base_offset(pos, values_strides, ndim, scan_dim); + const long indices_base_offset = + calculate_base_offset(pos, indices_strides, ndim, scan_dim); + + acc_t accumulator = Op::identity(); + int64_t best_idx = 0; + const long scan_size = sizes[scan_dim]; + const long input_scan_stride = input_strides[scan_dim]; + const long values_scan_stride = values_strides[scan_dim]; + const long indices_scan_stride = indices_strides[scan_dim]; + + for (long scan_idx = 0; scan_idx < scan_size; scan_idx++) { + const long input_offset = input_base_offset + scan_idx * input_scan_stride; + const long values_offset = + values_base_offset + scan_idx * values_scan_stride; + const long indices_offset = + indices_base_offset + scan_idx * indices_scan_stride; + + T val = input[input_offset]; + acc_t accum_val = static_cast(val); + if (scan_idx == 0 || Op::apply(accum_val, accumulator) == accum_val) { + accumulator = accum_val; + best_idx = scan_idx; + } + values[values_offset] = static_cast(accumulator); + indices[indices_offset] = best_idx; + } +} + +#define REGISTER_SCAN_OP(OP_NAME, OP_CLASS, DTYPE) \ + template [[host_name(#OP_NAME "_contiguous_innermost_" #DTYPE)]] kernel void \ + scan_contiguous_innermost_dim>( \ + constant DTYPE * input [[buffer(0)]], \ + device DTYPE * output [[buffer(1)]], \ + constant uint & num_rows [[buffer(2)]], \ + constant uint & row_size [[buffer(3)]], \ + uint row [[thread_position_in_grid]]); \ + \ + template [[host_name(#OP_NAME "_contiguous_outer_" #DTYPE)]] kernel void \ + scan_contiguous_outer_dim>( \ + constant DTYPE * input [[buffer(0)]], \ + device DTYPE * output [[buffer(1)]], \ + constant uint & num_orows [[buffer(2)]], \ + constant uint & num_irows [[buffer(3)]], \ + constant uint & row_size [[buffer(4)]], \ + uint thread_index [[thread_position_in_grid]]); \ + \ + template [[host_name(#OP_NAME "_strided_" #DTYPE)]] kernel void \ + scan_strided>( \ + constant DTYPE * input [[buffer(0)]], \ + device DTYPE * output [[buffer(1)]], \ + constant long* sizes [[buffer(2)]], \ + constant long* input_strides [[buffer(3)]], \ + constant long* output_strides [[buffer(4)]], \ + constant uint& ndim [[buffer(5)]], \ + constant uint& scan_dim [[buffer(6)]], \ + uint thread_index [[thread_position_in_grid]]); + +#define REGISTER_SCAN_WITH_INDICES_OP(OP_NAME, OP_CLASS, DTYPE) \ + template [[host_name(#OP_NAME "_contiguous_innermost_" #DTYPE)]] kernel void \ + scan_with_indices_contiguous_innermost_dim>( \ + constant DTYPE * input [[buffer(0)]], \ + device DTYPE * values [[buffer(1)]], \ + device int64_t* indices [[buffer(2)]], \ + constant uint& num_rows [[buffer(3)]], \ + constant uint& row_size [[buffer(4)]], \ + uint row [[thread_position_in_grid]]); \ + \ + template [[host_name(#OP_NAME "_contiguous_outer_" #DTYPE)]] kernel void \ + scan_with_indices_contiguous_outer_dim>( \ + constant DTYPE * input [[buffer(0)]], \ + device DTYPE * values [[buffer(1)]], \ + device int64_t* indices [[buffer(2)]], \ + constant uint& num_orows [[buffer(3)]], \ + constant uint& num_irows [[buffer(4)]], \ + constant uint& row_size [[buffer(5)]], \ + uint thread_index [[thread_position_in_grid]]); \ + \ + template [[host_name(#OP_NAME "_strided_" #DTYPE)]] kernel void \ + scan_with_indices_strided>( \ + constant DTYPE * input [[buffer(0)]], \ + device DTYPE * values [[buffer(1)]], \ + device int64_t* indices [[buffer(2)]], \ + constant long* sizes [[buffer(3)]], \ + constant long* input_strides [[buffer(4)]], \ + constant long* values_strides [[buffer(5)]], \ + constant long* indices_strides [[buffer(6)]], \ + constant uint& ndim [[buffer(7)]], \ + constant uint& scan_dim [[buffer(8)]], \ + uint thread_index [[thread_position_in_grid]]); + +// Scan operations with indices +REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, float); +REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, half); +REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, long); +REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, int); +REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, short); +REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, char); +REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, uchar); +REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, bool); + +REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, float); +REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, half); +REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, long); +REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, int); +REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, short); +REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, char); +REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, uchar); +REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, bool); + +#if __METAL_VERSION__ >= 310 +REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, bfloat); +REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, bfloat); +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/native/mps/kernels/SpecialOps.metal b/aten/src/ATen/native/mps/kernels/SpecialOps.metal index 1e37573a36e88..e5e6c312a56e4 100644 --- a/aten/src/ATen/native/mps/kernels/SpecialOps.metal +++ b/aten/src/ATen/native/mps/kernels/SpecialOps.metal @@ -89,4 +89,10 @@ REGISTER_SPECIAL(short, float); REGISTER_SPECIAL(int, float); REGISTER_SPECIAL(long, float); REGISTER_SPECIAL(half, half); +<<<<<<< HEAD REGISTER_SPECIAL(bfloat, bfloat); +======= +#if __METAL_VERSION__ >= 310 +REGISTER_SPECIAL(bfloat, bfloat); +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/native/mps/kernels/TriangularOps.metal b/aten/src/ATen/native/mps/kernels/TriangularOps.metal index ad1a0f93a217a..5b994f740f6d3 100644 --- a/aten/src/ATen/native/mps/kernels/TriangularOps.metal +++ b/aten/src/ATen/native/mps/kernels/TriangularOps.metal @@ -100,7 +100,13 @@ kernel void triul( INSTANTIATE_TRIUL_KERNELS(float, int); INSTANTIATE_TRIUL_KERNELS(half, int); +<<<<<<< HEAD INSTANTIATE_TRIUL_KERNELS(bfloat, int); +======= +#if __METAL_VERSION__ >= 310 +INSTANTIATE_TRIUL_KERNELS(bfloat, int); +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) INSTANTIATE_TRIUL_KERNELS(float2, int); INSTANTIATE_TRIUL_KERNELS(half2, int); diff --git a/aten/src/ATen/native/mps/kernels/UnaryKernel.metal b/aten/src/ATen/native/mps/kernels/UnaryKernel.metal index 7db38da80532f..fa555fc2b2824 100644 --- a/aten/src/ATen/native/mps/kernels/UnaryKernel.metal +++ b/aten/src/ATen/native/mps/kernels/UnaryKernel.metal @@ -490,6 +490,14 @@ struct bitwise_not_functor { } }; +<<<<<<< HEAD +======= +template +float erfc(T x) { + return 1.0 - erf(x); +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) struct round_decimals_functor { template inline T operator()(const T x, const long ndigits) { @@ -498,6 +506,7 @@ struct round_decimals_functor { } }; +<<<<<<< HEAD struct round_functor { template , bool> = true> inline T operator()(const T x) { @@ -509,6 +518,8 @@ struct round_functor { } }; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DEFINE_UNARY_FLOATING_FUNCTOR(erf); DEFINE_UNARY_FLOATING_FUNCTOR(erfc); DEFINE_UNARY_FLOATING_FUNCTOR(erfinv); @@ -521,6 +532,7 @@ REGISTER_UNARY_OP(neg, char, char); REGISTER_UNARY_OP(neg, uchar, uchar); REGISTER_UNARY_OP(neg, float, float); REGISTER_UNARY_OP(neg, half, half); +<<<<<<< HEAD REGISTER_UNARY_OP(round, int, int); REGISTER_UNARY_OP(round, long, long); REGISTER_UNARY_OP(round, short, short); @@ -528,6 +540,8 @@ REGISTER_UNARY_OP(round, char, char); REGISTER_UNARY_OP(round, uchar, uchar); REGISTER_UNARY_OP(round, float, float); REGISTER_UNARY_OP(round, half, half); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) REGISTER_UNARY_OP(bitwise_not, int, int); REGISTER_UNARY_OP(bitwise_not, long, long); @@ -569,10 +583,18 @@ REGISTER_UNARY_OP(abs, half, half); REGISTER_UNARY_OP(acos, DTYPE1, DTYPE0); \ REGISTER_UNARY_OP(atan, DTYPE1, DTYPE0) +<<<<<<< HEAD INSTANTIATE_UNARY_KERNELS2(bfloat, bfloat); REGISTER_UNARY_OP(neg, bfloat, bfloat); REGISTER_UNARY_OP(round, bfloat, bfloat); REGISTER_UNARY_OP(abs, bfloat, bfloat); +======= +#if __METAL_VERSION__ >= 310 +INSTANTIATE_UNARY_KERNELS2(bfloat, bfloat); +REGISTER_UNARY_OP(neg, bfloat, bfloat); +REGISTER_UNARY_OP(abs, bfloat, bfloat); +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) INSTANTIATE_UNARY_KERNELS2(half, half); INSTANTIATE_UNARY_KERNELS2(float, float); INSTANTIATE_UNARY_KERNELS2(float, bool); @@ -612,4 +634,10 @@ INSTANTIATE_UNARY_KERNELS_VEC2(float); REGISTER_UNARY_ALPHA_OP(round_decimals, float, long, float); REGISTER_UNARY_ALPHA_OP(round_decimals, half, long, half); +<<<<<<< HEAD +REGISTER_UNARY_ALPHA_OP(round_decimals, bfloat, long, bfloat); +======= +#if __METAL_VERSION__ >= 310 REGISTER_UNARY_ALPHA_OP(round_decimals, bfloat, long, bfloat); +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/native/mps/kernels/UnfoldBackward.metal b/aten/src/ATen/native/mps/kernels/UnfoldBackward.metal index 8369258a30a6a..25a0cff0a1039 100644 --- a/aten/src/ATen/native/mps/kernels/UnfoldBackward.metal +++ b/aten/src/ATen/native/mps/kernels/UnfoldBackward.metal @@ -70,4 +70,10 @@ kernel void unfold_backward( INSTANTIATE_UNFOLD_BACKWARD(float); INSTANTIATE_UNFOLD_BACKWARD(half); +<<<<<<< HEAD INSTANTIATE_UNFOLD_BACKWARD(bfloat); +======= +#if __METAL_VERSION__ >= 310 +INSTANTIATE_UNFOLD_BACKWARD(bfloat); +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/native/mps/kernels/UpSample.h b/aten/src/ATen/native/mps/kernels/UpSample.h index e9fb5f8b631ed..5d3f0d007dc9d 100644 --- a/aten/src/ATen/native/mps/kernels/UpSample.h +++ b/aten/src/ATen/native/mps/kernels/UpSample.h @@ -1,4 +1,5 @@ #pragma once +<<<<<<< HEAD #include template @@ -8,5 +9,24 @@ struct UpsampleParams { ::c10::metal::array output_strides; ::c10::metal::array output_sizes; ::c10::metal::array scales; +======= + +#ifndef __METAL__ +#include +using ulong = unsigned long; +#define _ARRAY_NS std +#else +#include +#define _ARRAY_NS metal +#endif + +template +struct UpsampleParams { + _ARRAY_NS::array input_strides; + _ARRAY_NS::array input_sizes; + _ARRAY_NS::array output_strides; + _ARRAY_NS::array output_sizes; + _ARRAY_NS::array scales; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool align_corners; }; diff --git a/aten/src/ATen/native/mps/kernels/UpSample.metal b/aten/src/ATen/native/mps/kernels/UpSample.metal index 393c9e1b4d422..90d0c3cc1e304 100644 --- a/aten/src/ATen/native/mps/kernels/UpSample.metal +++ b/aten/src/ATen/native/mps/kernels/UpSample.metal @@ -66,7 +66,11 @@ template scalar_t upsample_get_value_bounded( constant scalar_t* data, uint3 dim, +<<<<<<< HEAD ::metal::array strides, +======= + array strides, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uint n, uint c, uint z, @@ -131,7 +135,11 @@ template void upsample_increment_value_bounded( device AtomicType_t* data, uint3 dim, +<<<<<<< HEAD ::metal::array strides, +======= + array strides, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uint n, uint c, uint z, @@ -852,4 +860,10 @@ INSTANTIATE_UPSAMPLE_2D(bilinear2d, uchar); INSTANTIATE_UPSAMPLE_3D(uchar); INSTANTIATE_UPSAMPLE_ALL(float); INSTANTIATE_UPSAMPLE_ALL(half); +<<<<<<< HEAD INSTANTIATE_UPSAMPLE_ALL(bfloat); +======= +#if __METAL_VERSION__ >= 310 +INSTANTIATE_UPSAMPLE_ALL(bfloat); +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/native/mps/operations/Attention.mm b/aten/src/ATen/native/mps/operations/Attention.mm index 11498ade6fd08..be17aac55ec1c 100644 --- a/aten/src/ATen/native/mps/operations/Attention.mm +++ b/aten/src/ATen/native/mps/operations/Attention.mm @@ -114,6 +114,7 @@ graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask); maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM secondaryTensor:graph->maskTensor name:nil]; } +<<<<<<< HEAD // Account for case where all values were masked causing division by 0 in softmax (issue:#156707) // Overwrites expected NANs in sm with zeros. @@ -130,6 +131,10 @@ name:nil]; auto output = [mpsGraph matrixMultiplicationWithPrimaryTensor:correctedSM secondaryTensor:vTensor name:nil]; +======= + auto sm = [mpsGraph softMaxWithTensor:maskedMM axis:3 name:nil]; + auto output = [mpsGraph matrixMultiplicationWithPrimaryTensor:sm secondaryTensor:vTensor name:nil]; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) graph->qTensor = qTensor; graph->kTensor = kTensor; graph->vTensor = vTensor; @@ -182,8 +187,11 @@ uint maxSeqLength = k_.size(2); uint N = k_.size(2); uint B = q_.size(0) * q_.size(1); +<<<<<<< HEAD uint q_head_stride = q_.stride(1); uint q_seq_stride = q_.stride(2); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uint k_head_stride = k_.stride(1); uint k_seq_stride = k_.stride(2); uint v_head_stride = v_.stride(1); @@ -211,8 +219,13 @@ out, 1, N, +<<<<<<< HEAD std::array{q_head_stride, k_head_stride, v_head_stride}, std::array{q_seq_stride, k_seq_stride, v_seq_stride}, +======= + std::array{k_head_stride, k_seq_stride}, + std::array{v_head_stride, v_seq_stride}, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) scale_factor); if (has_mask) { @@ -259,8 +272,11 @@ uint B = batchSize * num_heads; uint gqa_factor = q_.size(1) / k_.size(1); +<<<<<<< HEAD uint q_head_stride = q_.stride(1); uint q_seq_stride = q_.stride(2); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uint k_head_stride = k_.stride(1); uint k_seq_stride = k_.stride(2); uint v_head_stride = v_.stride(1); @@ -298,8 +314,13 @@ maxs, gqa_factor, N, +<<<<<<< HEAD std::array{q_head_stride, k_head_stride, v_head_stride}, std::array{q_seq_stride, k_seq_stride, v_seq_stride}, +======= + std::array{k_head_stride, k_seq_stride}, + std::array{v_head_stride, v_seq_stride}, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) scale_factor); if (has_mask) { diff --git a/aten/src/ATen/native/mps/operations/BinaryKernel.mm b/aten/src/ATen/native/mps/operations/BinaryKernel.mm index 0b303f48028f4..f5f88b6b1878a 100644 --- a/aten/src/ATen/native/mps/operations/BinaryKernel.mm +++ b/aten/src/ATen/native/mps/operations/BinaryKernel.mm @@ -53,7 +53,10 @@ void binary_op_kernel(const std::string func_name, .add_input(input) .add_input(other) .check_all_same_dtype(false) +<<<<<<< HEAD .promote_inputs_to_common_dtype(true) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .build(); lib.exec_binary_kernel(iter, func_name, alpha); @@ -120,6 +123,7 @@ static void chebyshev_polynomial_w_mps_kernel(TensorIteratorBase& iter) { lib.exec_binary_kernel(iter, "chebyshev_polynomial_w"); } +<<<<<<< HEAD static void shifted_chebyshev_polynomial_t_mps_kernel(TensorIteratorBase& iter) { TORCH_CHECK_TYPE(isFloatingType(iter.common_dtype()), "shifted_chebyshev_polynomial_t_mps not implemented for non-floating types"); @@ -144,6 +148,8 @@ static void shifted_chebyshev_polynomial_w_mps_kernel(TensorIteratorBase& iter) lib.exec_binary_kernel(iter, "shifted_chebyshev_polynomial_w"); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static void hermite_polynomial_h_mps_kernel(TensorIteratorBase& iter) { TORCH_CHECK_TYPE(isFloatingType(iter.common_dtype()), "hermite_polynomial_h_mps not implemented for non-floating types"); @@ -168,10 +174,13 @@ static void lerp_scalar_mps_kernel(at::TensorIteratorBase& iter, const Scalar& w lib.exec_binary_kernel(iter, "lerp_alpha", weight); } +<<<<<<< HEAD static void native_dropout_mask_and_scale_mps_kernel(at::TensorIteratorBase& iter, const Scalar& scale) { lib.exec_binary_kernel(iter, "native_dropout_mask_and_scale", scale); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static void mul_mps_kernel(TensorIteratorBase& iter) { lib.exec_binary_kernel(iter, "mul"); } @@ -196,6 +205,7 @@ static void fmod_mps_kernel(TensorIteratorBase& iter) { lib.exec_binary_kernel(iter, "fmod"); } +<<<<<<< HEAD static void igamma_mps_kernel(TensorIteratorBase& iter) { lib.exec_binary_kernel(iter, "igamma"); } @@ -204,6 +214,8 @@ static void igammac_mps_kernel(TensorIteratorBase& iter) { lib.exec_binary_kernel(iter, "igammac"); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) REGISTER_DISPATCH(fmax_stub, &fmax_mps_kernel) REGISTER_DISPATCH(fmin_stub, &fmin_mps_kernel) REGISTER_DISPATCH(copysign_stub, ©sign_mps_kernel) @@ -214,10 +226,13 @@ static void igammac_mps_kernel(TensorIteratorBase& iter) { REGISTER_DISPATCH(chebyshev_polynomial_u_stub, &chebyshev_polynomial_u_mps_kernel) REGISTER_DISPATCH(chebyshev_polynomial_v_stub, &chebyshev_polynomial_v_mps_kernel) REGISTER_DISPATCH(chebyshev_polynomial_w_stub, &chebyshev_polynomial_w_mps_kernel) +<<<<<<< HEAD REGISTER_DISPATCH(shifted_chebyshev_polynomial_t_stub, &shifted_chebyshev_polynomial_t_mps_kernel) REGISTER_DISPATCH(shifted_chebyshev_polynomial_u_stub, &shifted_chebyshev_polynomial_u_mps_kernel) REGISTER_DISPATCH(shifted_chebyshev_polynomial_v_stub, &shifted_chebyshev_polynomial_v_mps_kernel) REGISTER_DISPATCH(shifted_chebyshev_polynomial_w_stub, &shifted_chebyshev_polynomial_w_mps_kernel) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) REGISTER_DISPATCH(hermite_polynomial_h_stub, &hermite_polynomial_h_mps_kernel) REGISTER_DISPATCH(hermite_polynomial_he_stub, &hermite_polynomial_he_mps_kernel) REGISTER_DISPATCH(polar_stub, &polar_mps_kernel); @@ -229,6 +244,9 @@ static void igammac_mps_kernel(TensorIteratorBase& iter) { REGISTER_DISPATCH(div_trunc_stub, &div_trunc_mps_kernel) REGISTER_DISPATCH(fmod_stub, &fmod_mps_kernel) REGISTER_DISPATCH(remainder_stub, &remainder_mps_kernel) +<<<<<<< HEAD REGISTER_DISPATCH(igamma_stub, &igamma_mps_kernel) REGISTER_DISPATCH(igammac_stub, &igammac_mps_kernel) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm index 06b6edcff9407..622e0327ce23b 100644 --- a/aten/src/ATen/native/mps/operations/BinaryOps.mm +++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm @@ -48,11 +48,34 @@ #define BinaryOpFn(graph, primary, secondary) \ MPSGraphTensor*(mps::BinaryOpCachedGraph * graph, MPSGraphTensor * primary, MPSGraphTensor * secondary) +<<<<<<< HEAD +======= +static inline Tensor legacy_complex_as_view(const Tensor& t) { + // Convert non-complex types (and cdouble CPU scalars) to cfloat + if (!isComplexType(t.scalar_type()) || t.scalar_type() == kComplexDouble) { + return at::view_as_real(t.to(kMPS, kComplexFloat)); + } + return at::view_as_real(t.dim() != 0 ? t : t.to(kMPS)); +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static void binaryOpTensor(const Tensor& self, const Tensor& other, const Tensor& output_, std::string op_name, BinaryOpBlock binaryBlock) { +<<<<<<< HEAD +======= + TORCH_CHECK(!(op_name == "power" && !is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS) && + (self.scalar_type() == ScalarType::Long || + (other.scalar_type() == ScalarType::Long && + (self.scalar_type() != ScalarType::Half && self.scalar_type() != ScalarType::Float)))), + "MPS: ", + op_name, + " op with int64 input is supported natively starting from macOS 13.2"); + TORCH_CHECK_TYPE(!isComplexType(self.scalar_type()) || mps::supportsComplex(), + "Complex types are supported starting from MacOS 14.0+"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MPSStream* mpsStream = getCurrentMPSStream(); const bool is_self_scalar = self.dim() == 0; diff --git a/aten/src/ATen/native/mps/operations/Blas.mm b/aten/src/ATen/native/mps/operations/Blas.mm index 101ef5feb224e..45140f657d26c 100644 --- a/aten/src/ATen/native/mps/operations/Blas.mm +++ b/aten/src/ATen/native/mps/operations/Blas.mm @@ -51,6 +51,12 @@ inline void dot_check(const Tensor& self, const Tensor& other) { } // namespace mps Tensor dot_mps(const Tensor& self, const Tensor& other) { +<<<<<<< HEAD +======= + TORCH_CHECK(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) || self.scalar_type() != ScalarType::Long, + "MPS: dot op doesn't support int64 input on MacOS13") + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) using namespace mps; using CachedGraph = MPSBinaryCachedGraph; diff --git a/aten/src/ATen/native/mps/operations/ConstantOps.mm b/aten/src/ATen/native/mps/operations/ConstantOps.mm index e36ac4dc45246..7cf692f455745 100644 --- a/aten/src/ATen/native/mps/operations/ConstantOps.mm +++ b/aten/src/ATen/native/mps/operations/ConstantOps.mm @@ -62,12 +62,24 @@ return self; } +<<<<<<< HEAD static Tensor& fill_mps_tensor_(Tensor& self, uint8_t value) { TORCH_INTERNAL_ASSERT(self.is_contiguous()); const auto stream = getCurrentMPSStream(); auto storage_byte_offset = self.storage_offset() * self.itemsize(); stream->fill(mps::getMTLBufferStorage(self), value, self.nbytes(), storage_byte_offset); return self; +======= +// returns false if tensor cannot be filled with fillBuffer() +static bool fill_mps_tensor_(Tensor& self, uint8_t value) { + if (self.is_contiguous()) { + MPSStream* stream = getCurrentMPSStream(); + auto storage_byte_offset = self.storage_offset() * self.itemsize(); + stream->fill(mps::getMTLBufferStorage(self), value, self.nbytes(), storage_byte_offset); + return true; + } + return false; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } Tensor& fill_scalar_mps(Tensor& self, const Scalar& value) { @@ -86,6 +98,7 @@ return self; } // check if it's possible to use fillBuffer() to fill the Tensor's storage +<<<<<<< HEAD if (self.is_contiguous()) { if (value.toDouble() == 0.0) { return fill_mps_tensor_(self, 0); @@ -100,6 +113,10 @@ return fill_mps_tensor_(self, value.toChar()); } } +======= + if (value.toDouble() == 0.0 && fill_mps_tensor_(self, 0) == true) + return self; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return fill_scalar_mps_impl(self, value); } @@ -110,6 +127,11 @@ value.dim(), " dimensions."); Scalar scalar_value = value.item(); +<<<<<<< HEAD +======= + if (scalar_value.toDouble() == 0.0 && fill_mps_tensor_(self, 0) == true) + return self; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return fill_scalar_mps(self, scalar_value); } diff --git a/aten/src/ATen/native/mps/operations/Convolution.mm b/aten/src/ATen/native/mps/operations/Convolution.mm index d572d52d103a1..436ad89f3b4ac 100644 --- a/aten/src/ATen/native/mps/operations/Convolution.mm +++ b/aten/src/ATen/native/mps/operations/Convolution.mm @@ -124,6 +124,10 @@ static Tensor _mps_convolution_impl(const Tensor& input_t_, IntArrayRef dilation, int64_t groups, std::optional input_shape) { +<<<<<<< HEAD +======= + const bool is_macOS_13_2_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS); Tensor input_t = input_t_; bool is3DConv = input_t.dim() == 5; @@ -131,6 +135,12 @@ static Tensor _mps_convolution_impl(const Tensor& input_t_, input_t = input_t.contiguous(); } +<<<<<<< HEAD +======= + TORCH_CHECK(((input_t.dim() < 5) || is_macOS_13_2_or_newer), + "Conv3D is only supported on MPS for MacOS_13_2 or newer"); + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_CHECK(isFloatingType(input_t.scalar_type()), "Convolution is supported only for Floating types"); using namespace at::native::mps; diff --git a/aten/src/ATen/native/mps/operations/Copy.mm b/aten/src/ATen/native/mps/operations/Copy.mm index 0c121cee8fb62..c590517bdb640 100644 --- a/aten/src/ATen/native/mps/operations/Copy.mm +++ b/aten/src/ATen/native/mps/operations/Copy.mm @@ -60,6 +60,10 @@ static void copy_cast_mps(at::Tensor& dst, outputTensor = [mpsGraph castTensor:outputTensor toType:dstDType name:@"cast"]; } if (needs_conj) { +<<<<<<< HEAD +======= + TORCH_CHECK(supportsComplex(), "MPS complex tensors conjugation needs MacOS14+"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) outputTensor = [mpsGraph conjugateWithTensor:outputTensor name:nil]; } @@ -274,7 +278,28 @@ void copy_blit_mps(void* dst, const void* src, size_t size) { // for GPU to GPU copies we only encode to stream's command buffer (no flushing) stream->copy(sourceBuffer, destBuffer, src.nbytes(), src_byte_offset, dst_byte_offset, profile_id); } else { +<<<<<<< HEAD if (dst_byte_offset) { +======= + // Simulate cast to Complex on older MacOS by initializing real and imag parts + if (dst_.is_complex() && !supportsComplex()) { + if (!src.is_complex()) { + at::real(dst_).copy_(src); + at::imag(dst_).fill_(0); + } else if (src.is_conj() || dst_.is_conj()) { + // One cannot take view of conjugated tensor, but for some reason real and imag views are fine + // Use this to implement a conjugation + at::real(dst_).copy_(at::real(src)); + if (src.is_conj() != dst_.is_conj()) { + at::imag(dst_).copy_(at::neg(at::imag(src))); + } else { + at::imag(dst_).copy_(at::imag(src)); + } + } else { + at::view_as_real(dst_).copy_(at::view_as_real(src)); + } + } else if (dst_byte_offset) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto maybeCastedSource = at::empty(dst_.sizes(), dst_.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt); auto maybeCastedSourceBuffer = getMTLBufferStorage(maybeCastedSource); diff --git a/aten/src/ATen/native/mps/operations/Distributions.mm b/aten/src/ATen/native/mps/operations/Distributions.mm index 4d3f99ea9e02d..b83fa3e47628c 100644 --- a/aten/src/ATen/native/mps/operations/Distributions.mm +++ b/aten/src/ATen/native/mps/operations/Distributions.mm @@ -87,6 +87,10 @@ case kFloat: return MPSDataTypeFloat32; case kBFloat16: { +<<<<<<< HEAD +======= + checkSupportsBFloat16(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return MPSDataTypeBFloat16; } default: @@ -417,9 +421,14 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, std::optional::epsilon(); return mps::random_mps_impl(self, eps, +======= + return mps::random_mps_impl(self, + 0.0, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) 1.0, std::nullopt, std::nullopt, diff --git a/aten/src/ATen/native/mps/operations/FastFourierTransform.mm b/aten/src/ATen/native/mps/operations/FastFourierTransform.mm index 7e9867c9b948d..f6924bdf6785b 100644 --- a/aten/src/ATen/native/mps/operations/FastFourierTransform.mm +++ b/aten/src/ATen/native/mps/operations/FastFourierTransform.mm @@ -88,6 +88,10 @@ Tensor _fft_c2c_mps(const Tensor& self, IntArrayRef dim, int64_t normalization, // TODO: Investigate numerical discrepancies see https://github.com/pytorch/pytorch/issues/120237 Tensor& _fft_r2c_mps_out(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided, Tensor& out) { +<<<<<<< HEAD +======= + TORCH_CHECK(supportsComplex(), "FFT operations are only supported on MacOS 14+"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto key = __func__ + getTensorsStringKey({self, out}) + ":" + getArrayRefString(dim) + ":" + std::to_string(normalization) + ":" + std::to_string(onesided); @autoreleasepool { @@ -128,6 +132,10 @@ Tensor _fft_c2c_mps(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t normalization, int64_t last_dim_size, Tensor& out) { +<<<<<<< HEAD +======= + TORCH_CHECK(supportsComplex(), "FFT operations are only supported on MacOS 14+"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto key = __func__ + getTensorsStringKey({self}) + ":" + getArrayRefString(dim) + ":" + std::to_string(normalization) + ":" + std::to_string(last_dim_size); @autoreleasepool { @@ -153,6 +161,10 @@ Tensor _fft_c2c_mps(const Tensor& self, IntArrayRef dim, int64_t normalization, } Tensor& _fft_c2c_mps_out(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward, Tensor& out) { +<<<<<<< HEAD +======= + TORCH_CHECK(supportsComplex(), "FFT operations are only supported on MacOS 14+"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto key = __func__ + getTensorsStringKey({self}) + ":" + getArrayRefString(dim) + ":" + std::to_string(normalization) + ":" + std::to_string(forward); @autoreleasepool { diff --git a/aten/src/ATen/native/mps/operations/GridSampler.mm b/aten/src/ATen/native/mps/operations/GridSampler.mm index ef85633889487..efd567cb24edc 100644 --- a/aten/src/ATen/native/mps/operations/GridSampler.mm +++ b/aten/src/ATen/native/mps/operations/GridSampler.mm @@ -1,10 +1,16 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS +<<<<<<< HEAD #include #include #include #include #include #include +======= +#include +#include +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #ifndef AT_PER_OPERATOR_HEADERS #include @@ -12,6 +18,7 @@ #else #include #include +<<<<<<< HEAD #include #endif @@ -23,6 +30,11 @@ #include #endif +======= +#endif + +namespace at::native { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace mps { static void grid_sampler_2d_mps_impl(Tensor& output, const Tensor& input, @@ -131,6 +143,7 @@ static void grid_sampler_2d_mps_impl(Tensor& output, runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder); } } +<<<<<<< HEAD static void grid_sampler_template(Tensor& output, const Tensor& input, @@ -221,6 +234,8 @@ static void grid_sampler_template(Tensor& output, }); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace mps Tensor grid_sampler_2d_mps(const Tensor& input, @@ -228,6 +243,18 @@ Tensor grid_sampler_2d_mps(const Tensor& input, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { +<<<<<<< HEAD +======= + if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS)) { + TORCH_WARN_ONCE("MPS: grid_sampler_2d op is supported natively starting from macOS 13.2. ", + "Falling back on CPU. This may have performance implications."); + + return at::grid_sampler_2d(input.to("cpu"), grid.to("cpu"), interpolation_mode, padding_mode, align_corners) + .clone() + .to("mps"); + } + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto in_size = input.sizes(); auto grid_size = grid.sizes(); auto output = at::empty({in_size[0], in_size[1], grid_size[1], grid_size[2]}, input.options()); @@ -236,6 +263,7 @@ Tensor grid_sampler_2d_mps(const Tensor& input, return output; } +<<<<<<< HEAD Tensor grid_sampler_3d_mps(const Tensor& input, const Tensor& grid, int64_t interpolation_mode, @@ -253,4 +281,6 @@ Tensor grid_sampler_3d_mps(const Tensor& input, return output; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/Indexing.h b/aten/src/ATen/native/mps/operations/Indexing.h new file mode 100644 index 0000000000000..f52e5cd7334c3 --- /dev/null +++ b/aten/src/ATen/native/mps/operations/Indexing.h @@ -0,0 +1,8 @@ +// Copyright Ā© 2022 Apple Inc. +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include + +using namespace at::mps; diff --git a/aten/src/ATen/native/mps/operations/Indexing.mm b/aten/src/ATen/native/mps/operations/Indexing.mm index fa19d2f4d127f..f225f0164bbbd 100644 --- a/aten/src/ATen/native/mps/operations/Indexing.mm +++ b/aten/src/ATen/native/mps/operations/Indexing.mm @@ -18,6 +18,11 @@ #include #include #include +<<<<<<< HEAD +======= +#include +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include @@ -108,12 +113,33 @@ static void validateInputData(const TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride, +<<<<<<< HEAD const std::string& op) { +======= + const std::string& op, + bool accumulate) { + using namespace mps; + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const auto num_indices = index_size.size(); TORCH_CHECK(num_indices <= 16, "Current limit allows up to 16 indices to be used in MPS indexing kernels"); AT_ASSERT(num_indices == index_stride.size()); AT_ASSERT(static_cast(num_indices) == iter.ntensors() - 2); +<<<<<<< HEAD +======= + const Tensor& inputTensor = iter.tensor(1); + const auto scalar_type = inputTensor.scalar_type(); + + if (accumulate) { + // No atomic support for the rest of dtypes + TORCH_CHECK(supportedFloatingType(scalar_type) || scalar_type == kInt || scalar_type == kBool); + } else { + TORCH_CHECK(c10::isIntegralType(scalar_type, /*includesBool=*/true) || supportedFloatingType(scalar_type) || + scalar_type == ScalarType::ComplexFloat || scalar_type == ScalarType::ComplexHalf, + getMPSTypeString(inputTensor) + std::string(" not supported for index.Tensor_out")); + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } static Tensor& masked_select_out_mps_impl(Tensor& result, const Tensor& self, const Tensor& mask) { @@ -144,7 +170,11 @@ static void dispatch_index_kernel(TensorIteratorBase& iter, IntArrayRef index_stride, const std::string& kernel_name, const bool serial = false) { +<<<<<<< HEAD validateInputData(iter, index_size, index_stride, "index.Tensor_out"); +======= + validateInputData(iter, index_size, index_stride, "index.Tensor_out", /*accumulate=*/false); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (iter.numel() == 0) return; if (!iter.can_use_32bit_indexing()) { @@ -186,7 +216,11 @@ static void dispatch_index_kernel(TensorIteratorBase& iter, } static void index_kernel_mps(TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride) { +<<<<<<< HEAD validateInputData(iter, index_size, index_stride, "index.Tensor_out"); +======= + validateInputData(iter, index_size, index_stride, "index.Tensor_out", /*accumulate=*/false); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dispatch_index_kernel( iter, index_size, index_stride, fmt::format("index_select_{}", getBitSizeString(iter.tensor_base(0)))); } @@ -196,7 +230,11 @@ static void index_put_kernel_mps(TensorIterator& iter, IntArrayRef index_stride, bool accumulate) { @autoreleasepool { +<<<<<<< HEAD validateInputData(iter, index_size, index_stride, "index_put_impl"); +======= + validateInputData(iter, index_size, index_stride, "index_put_impl", accumulate); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (accumulate) { dispatch_index_kernel(iter, index_size, @@ -230,7 +268,11 @@ static void index_put_kernel_mps(TensorIterator& iter, index.numel()); int64_t idx = index.item(); TORCH_CHECK(idx == 0, "index_copy_(): the only valid index for a 0-dim tensor is 0, but got ", idx); +<<<<<<< HEAD result.copy_(source.squeeze()); +======= + result.copy_(source); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return; } @@ -254,12 +296,20 @@ static void index_put_kernel_mps(TensorIterator& iter, } } +<<<<<<< HEAD const auto source_size_dim = source.dim() > 0 ? source.size(dim) : 1; TORCH_CHECK(index.numel() == source_size_dim, "index_copy_(): Number of indices (", index.numel(), ") should be equal to source.size(dim) (", source_size_dim, +======= + TORCH_CHECK(source.size(dim) == index.numel(), + "index_copy_(): Number of indices (", + index.numel(), + ") should be equal to source.size(dim) (", + source.size(dim), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ")"); auto stream = getCurrentMPSStream(); @@ -282,7 +332,11 @@ static void index_put_kernel_mps(TensorIterator& iter, [computeEncoder setComputePipelineState:indexCopyPSO]; mtl_setArgs(computeEncoder, result, self, source, index, dim_arg, self.sizes(), ndim, indices_numel); if (!is_dense) { +<<<<<<< HEAD mtl_setArgs<8>(computeEncoder, self.strides(), result.strides(), source.strides(), index.strides()); +======= + mtl_setArgs<8>(computeEncoder, self.strides(), result.strides(), source.strides()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } mtl_dispatch1DJob(computeEncoder, indexCopyPSO, result.numel()); } @@ -340,7 +394,18 @@ static Tensor nonzero_fallback(const Tensor& self) { } Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_) { +<<<<<<< HEAD if (self.is_complex()) { +======= + if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) { + TORCH_WARN_ONCE("MPS: nonzero op is supported natively starting from macOS 14.0. ", + "Falling back on CPU. This may have performance implications."); + Tensor out_fallback = nonzero_fallback(self); + at::native::resize_output(out_, out_fallback.sizes()); + out_.copy_(out_fallback); + return out_; + } else if (self.is_complex()) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_WARN_ONCE("MPS: nonzero op is not supported for complex datatypes. ", "Falling back on CPU. This may have performance implications."); Tensor out_fallback = nonzero_fallback(self); @@ -425,7 +490,15 @@ static Tensor nonzero_fallback(const Tensor& self) { } Tensor nonzero_mps(const Tensor& self) { +<<<<<<< HEAD if (self.is_complex()) { +======= + if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) { + TORCH_WARN_ONCE("MPS: nonzero op is supported natively starting from macOS 14.0. ", + "Falling back on CPU. This may have performance implications."); + return nonzero_fallback(self); + } else if (self.is_complex()) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_WARN_ONCE("MPS: nonzero op is not supported for complex datatypes ", "Falling back on CPU. This may have performance implications."); return nonzero_fallback(self); @@ -513,6 +586,7 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) { return; } +<<<<<<< HEAD bool use_deterministic_algorithm = globalContext().deterministicAlgorithms(); // TODO: Do not use deterministic algorithm for long/complex but rather implement it as Metal shader @@ -535,6 +609,9 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) { return; } +======= + TORCH_CHECK(source.scalar_type() != ScalarType::Long, "index_add(): Expected non int64 dtype for source."); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto casted_type = isFloatingType(source.scalar_type()) ? ScalarType::Float : ScalarType::Int; struct CachedGraph : public MPSCachedGraph { @@ -596,7 +673,32 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) { } Tensor index_select_mps(const Tensor& self, int64_t dim, const Tensor& index) { +<<<<<<< HEAD Tensor result = at::empty({0}, self.options()); +======= + IntArrayRef input_shape = self.sizes(); + auto num_input_dims = input_shape.size(); + + auto num_indices = index.numel(); + TORCH_CHECK_INDEX(index.dim() <= 1, "index_select(): Index is supposed to be a vector"); + + dim = maybe_wrap_dim(dim, self.dim()); + std::vector shape_data(num_input_dims); + + // Calculate new shape + for (const auto i : c10::irange(num_input_dims)) { + if (i == static_cast(dim)) { + shape_data[i] = num_indices; + } else { + shape_data[i] = input_shape[i]; + } + } + + IntArrayRef output_shape = IntArrayRef(shape_data.data(), num_input_dims); + + Tensor result = at::empty(output_shape, self.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt); + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) index_select_out_mps(self, dim, index, result); return result; } @@ -618,11 +720,33 @@ Tensor index_select_mps(const Tensor& self, int64_t dim, const Tensor& index) { TORCH_CHECK(self.scalar_type() == output.scalar_type(), "index_select(): self and output must have the same scalar type"); TORCH_CHECK(dim == 0 || dim < self.dim(), "index_select(): Indexing dim ", dim, " is out of bounds of tensor"); +<<<<<<< HEAD auto output_size = self.sizes().vec(); if (self.dim() > 0) { output_size[dim] = num_indices; } at::native::resize_output(output, output_size); +======= + TORCH_CHECK(output.dim() == 0 || index.size(-1) == output.size(dim), + "index_select(): index and output must have the same size at `dim`th dimension, but got ", + index.size(-1), + " and ", + output.size(dim), + "."); + + for (const auto i : irange(self.dim())) { + if (i == dim) + continue; + TORCH_CHECK(self.size(i) == output.size(i), + "index_select(): self and output must have the same dimensions except for `dim`th dimension, but got ", + self.size(i), + " and ", + output.size(i), + " at dimension ", + i, + "."); + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Empty index if (num_indices == 0 || self.numel() == 0) { @@ -908,8 +1032,11 @@ Tensor embedding_dense_backward_mps(const Tensor& grad_, TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int, "index_fill_(): Expected dtype int32 or int64 for index"); TORCH_CHECK(dim == 0 || dim < self.dim(), "index_fill_(): Indexing dim ", dim, " is out of bounds of tensor"); +<<<<<<< HEAD // MPS.scatter crashes if used with complex dtypes TORCH_CHECK(!c10::isComplexType(self.scalar_type()), "index_fill_(): Complex types are yet not supported"); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Empty index if (num_indices == 0) { diff --git a/aten/src/ATen/native/mps/operations/Linear.mm b/aten/src/ATen/native/mps/operations/Linear.mm index 219086edd8e37..4964cbe29d46d 100644 --- a/aten/src/ATen/native/mps/operations/Linear.mm +++ b/aten/src/ATen/native/mps/operations/Linear.mm @@ -115,10 +115,14 @@ Tensor _mps_linear(const Tensor& input, const Tensor& weight_arg, const std::opt return output; } +<<<<<<< HEAD // No-graph execution causes nonsense if these are non-contiguous. const bool is_contiguous = input.is_contiguous() && weight.is_contiguous() && bias.is_contiguous(); if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS) && is_contiguous) { +======= + if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS)) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _mps_linear_nograph(input, weight, bias, output); // Squeeze last dim of 1D linear return weight_arg.dim() != 1 ? output : output.squeeze(-1); diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm index 7a3dde679c05f..f46fb0bb38201 100644 --- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm +++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm @@ -112,6 +112,7 @@ return output; } +<<<<<<< HEAD Tensor& do_metal_addmm(const Tensor& self, const Tensor& other, Tensor& output, @@ -167,6 +168,8 @@ return output; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::tuple do_mm(MPSGraph* graph, const Tensor& self, const Tensor& other) { @@ -699,6 +702,10 @@ static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const TORCH_CHECK(output.is_mps()); TORCH_CHECK(self.dim() == 2 && other.dim() == 2, "tensors must be 2-D"); +<<<<<<< HEAD +======= + TORCH_CHECK(supportedFloatingOrComplexType(self), "MPS device does not support addmm for non-float input"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TensorArg args[]{{output, "out", 0}, {bias, "self", 1}, {self, "mat1", 2}, {other, "mat2", 3}}; checkAllSameGPU(__func__, args); @@ -725,10 +732,13 @@ static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const return output; } +<<<<<<< HEAD if (use_metal_mm(self, other, output)) { return do_metal_addmm(self, other, output, alpha, beta, *bias_); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool is_beta_non_zero = beta.toDouble() != 0.0; struct CachedGraph : public mps::MPSCachedGraph { diff --git a/aten/src/ATen/native/mps/operations/Normalization.mm b/aten/src/ATen/native/mps/operations/Normalization.mm index f5264cf32d9f2..cb1d4d0365b58 100644 --- a/aten/src/ATen/native/mps/operations/Normalization.mm +++ b/aten/src/ATen/native/mps/operations/Normalization.mm @@ -597,10 +597,14 @@ Check if running mean exists (maybe do this check before making graph) const bool has_weight = (weight_opt.has_value() && weight_opt->defined()); +<<<<<<< HEAD bool any_grad_needed = (grad_input_mask[0] && grad_input.numel() > 0) || (grad_input_mask[1] && grad_weight.numel() > 0) || (grad_input_mask[2] && grad_bias.numel() > 0); if (!any_grad_needed) { +======= + if (grad_input.numel() == 0) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return std::make_tuple(grad_input, grad_weight, grad_bias); } diff --git a/aten/src/ATen/native/mps/operations/Pad.mm b/aten/src/ATen/native/mps/operations/Pad.mm index 2945ebf715f27..10e9b4c70013e 100644 --- a/aten/src/ATen/native/mps/operations/Pad.mm +++ b/aten/src/ATen/native/mps/operations/Pad.mm @@ -460,9 +460,12 @@ Tensor replication_pad3d_backward_mps(const Tensor& grad_output, const Tensor& i // backward pass is explicitly handled in autograd by negating the "pad" argument Tensor constant_pad_nd_mps(const Tensor& self, IntArrayRef pad, const Scalar& value) { +<<<<<<< HEAD if (pad.empty()) { return self.clone(); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (pad.size() > 6) { TORCH_WARN_ONCE("MPS: The constant padding of more than 3 dimensions is not currently supported natively. ", "It uses View Ops default implementation to run. This may have performance implications."); diff --git a/aten/src/ATen/native/mps/operations/Pooling.mm b/aten/src/ATen/native/mps/operations/Pooling.mm index d916320b2e238..587ff99c80cfd 100644 --- a/aten/src/ATen/native/mps/operations/Pooling.mm +++ b/aten/src/ATen/native/mps/operations/Pooling.mm @@ -1,10 +1,15 @@ // Copyright Ā© 2022 Apple Inc. #define TORCH_ASSERT_ONLY_METHOD_OPERATORS +<<<<<<< HEAD #include #include #include #include #include +======= +#include +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #ifndef AT_PER_OPERATOR_HEADERS #include @@ -14,12 +19,16 @@ #include #include #include +<<<<<<< HEAD #include #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include #include +<<<<<<< HEAD #include #include #include @@ -34,6 +43,11 @@ #include #endif +======= +#endif + +namespace at::native { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace mps { struct PoolingCachedGraph : public MPSCachedGraph { @@ -256,6 +270,7 @@ static void pool2d_template(const Tensor& input, } } +<<<<<<< HEAD static std::vector copy_and_maybe_expand(IntArrayRef a, int32_t pooling_dims) { std::vector b(pooling_dims); for (const auto dim : c10::irange(pooling_dims)) { @@ -565,6 +580,8 @@ static void max_unpool_out_mps_template(const Tensor& input, }); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static void avg_pool2d_template(const Tensor& input, const Tensor& output, const std::optional& grad_output_opt, @@ -686,6 +703,7 @@ static void avg_pool2d_template(const Tensor& input, op_name); } +<<<<<<< HEAD static void avg_pool_out_mps_template(const Tensor& output, const Tensor& input, IntArrayRef _kernel_size, @@ -812,6 +830,10 @@ static bool use_graph_for_max_pool2d(IntArrayRef kernel_size, IntArrayRef stride return (stride[0] == 1) && (stride.size() == 1 || stride[1] == 1); } +======= +} // namespace mps + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Tensor mps_max_pool2d(const Tensor& input, IntArrayRef kernel_size, IntArrayRef stride, @@ -819,6 +841,7 @@ Tensor mps_max_pool2d(const Tensor& input, IntArrayRef dilation, bool ceil_mode) { Tensor output = at::empty({0}, input.options(), MemoryFormat::Contiguous); +<<<<<<< HEAD bool use_graph = use_graph_for_max_pool2d(kernel_size, stride); if (use_graph) { mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) { @@ -850,6 +873,26 @@ Tensor mps_max_pool2d(const Tensor& input, /*pooling_dims=*/2, "max_pool2d"); } +======= + mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) { + MPSGraph* mpsGraph = cachedGraph.graph(); + return [mpsGraph maxPooling2DWithSourceTensor:cachedGraph.inputTensor descriptor:desc name:nil]; + }; + mps::pool2d_template(input, + output, + std::nullopt, + std::nullopt, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + false, + std::nullopt, + pooling_op_block, + "max_pool2d"); + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return output; } @@ -894,6 +937,7 @@ Tensor mps_max_pool2d_backward(const Tensor& grad_output, bool ceil_mode, const Tensor& output, const Tensor& indices) { +<<<<<<< HEAD bool use_graph = use_graph_for_max_pool2d(kernel_size, stride); if (use_graph) { auto indices_memory_format = indices.suggest_memory_format(); @@ -933,6 +977,34 @@ Tensor mps_max_pool2d_backward(const Tensor& grad_output, ceil_mode, /*pooling_dims=*/2, "max_pool2d"); +======= + auto indices_memory_format = indices.suggest_memory_format(); + + mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) { + MPSGraph* mpsGraph = cachedGraph.graph(); + NSArray* poolOutputs = [mpsGraph maxPooling2DReturnIndicesWithSourceTensor:cachedGraph.inputTensor + descriptor:desc + name:nil]; + cachedGraph.indicesTensor = mps::castMPSTensor(mpsGraph, poolOutputs[1], ScalarType::Long); + return poolOutputs[0]; + }; + mps::pool2d_template(input, + output, + indices, + std::nullopt, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + false, + std::nullopt, + pooling_op_block, + "max_pool2d_indices"); + + if (indices_memory_format == MemoryFormat::ChannelsLast) { + const_cast(indices) = indices.to(MemoryFormat::ChannelsLast); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } @@ -968,6 +1040,7 @@ Tensor mps_max_pool2d_backward(const Tensor& grad_output, "max_pool2d_indices_backward"); } +<<<<<<< HEAD std::tuple max_pool3d_with_indices_out_mps(const Tensor& input, IntArrayRef kernel_size, IntArrayRef stride, @@ -1125,6 +1198,8 @@ Tensor max_unpooling3d_forward_mps(const Tensor& self, return output; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_IMPL_FUNC(avg_pool2d_out_mps) (const Tensor& input, int64_t kH, @@ -1137,6 +1212,7 @@ Tensor max_unpooling3d_forward_mps(const Tensor& self, bool count_include_pad, std::optional divisor_override, const Tensor& output) { +<<<<<<< HEAD if (ceil_mode) { mps::avg_pool_out_mps_template(output, input, @@ -1161,6 +1237,19 @@ Tensor max_unpooling3d_forward_mps(const Tensor& self, divisor_override, "avg_pool2d"); } +======= + mps::avg_pool2d_template(input, + output, + std::nullopt, + {kH, kW}, + {dH, dW}, + {padH, padW}, + {1, 1}, + ceil_mode, + count_include_pad, + divisor_override, + "avg_pool2d"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } TORCH_IMPL_FUNC(avg_pool2d_backward_out_mps) @@ -1186,6 +1275,7 @@ Tensor max_unpooling3d_forward_mps(const Tensor& self, "avg_pool2d_backward"); } +<<<<<<< HEAD TORCH_IMPL_FUNC(avg_pool3d_out_mps) (const Tensor& input, IntArrayRef kernel_size, @@ -1229,4 +1319,6 @@ Tensor max_unpooling3d_forward_mps(const Tensor& self, "avg_pool3d_backward"); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/RMSNorm.mm b/aten/src/ATen/native/mps/operations/RMSNorm.mm index 7948b5acd8e93..1572d56ddb9cf 100644 --- a/aten/src/ATen/native/mps/operations/RMSNorm.mm +++ b/aten/src/ATen/native/mps/operations/RMSNorm.mm @@ -19,6 +19,7 @@ #include #endif +<<<<<<< HEAD std::tuple _fused_rms_norm_mps(const Tensor& input, IntArrayRef normalized_shape, const std::optional& weight_opt, @@ -27,6 +28,9 @@ const int64_t normalized_ndim = normalized_shape.size(); auto eps_val = eps.value_or(std::numeric_limits::epsilon()); +======= +Tensor _fused_rms_norm_mps(const Tensor& input, const int64_t normalized_ndim, const Tensor& weight, const double eps) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_CHECK(input.is_contiguous() && weight.is_contiguous(), "Expected contiguous input and weight tensors"); auto output = at::empty_like(input); const auto input_shape = input.sizes(); @@ -48,7 +52,11 @@ const std::string kernel = fmt::format("{}_{}", name, scalarToMetalTypeString(output)); id rms_norm_pso = lib.getPipelineStateForFunc(kernel); [computeEncoder setComputePipelineState:rms_norm_pso]; +<<<<<<< HEAD mtl_setArgs(computeEncoder, input, weight, output, eps_val, N, 1); +======= + mtl_setArgs(computeEncoder, input, weight, output, eps, N, 1); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const auto maxThreadsPerGroup = static_cast([rms_norm_pso maxTotalThreadsPerThreadgroup]); size_t threadgroup_size = maxThreadsPerGroup; @@ -65,7 +73,11 @@ } }); +<<<<<<< HEAD return std::make_tuple(output, Tensor()); +======= + return output; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm index ae13504d9003e..c166f437be039 100644 --- a/aten/src/ATen/native/mps/operations/ReduceOps.mm +++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm @@ -152,6 +152,11 @@ static void reduction_out_mps(const Tensor& input_t, const Tensor& output_t, MPSReductionType reduction_type, const std::string& func_name) { +<<<<<<< HEAD +======= + bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); + MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, func_name); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // NS: TODO: get rid of all those shenanigans and just call reduction_op with view tensor bool canSqueezeLastDim = true; IntArrayRef input_shape = input_t.sizes(); @@ -234,10 +239,19 @@ static void reduction_out_mps(const Tensor& input_t, MPSGraphTensor* castInputTensor = inputTensor; MPSDataType inputCastType = MPSDataTypeInvalid; if (dtype.has_value() && +<<<<<<< HEAD (dtype.value() == kFloat || dtype.value() == kHalf || dtype.value() == kInt || dtype.value() == kLong)) { inputCastType = getMPSDataType(dtype.value()); } else if (inputScalarType != kInt && inputScalarType != kHalf && inputScalarType != kFloat && inputScalarType != kComplexFloat && inputScalarType != kComplexHalf && inputScalarType != kLong) { +======= + (dtype.value() == kFloat || dtype.value() == kHalf || dtype.value() == kInt || + (dtype.value() == kLong && macOS13_3_plus))) { + inputCastType = getMPSDataType(dtype.value()); + } else if (inputScalarType != kInt && inputScalarType != kHalf && inputScalarType != kFloat && + inputScalarType != kComplexFloat && inputScalarType != kComplexHalf && + (inputScalarType != kLong || !macOS13_3_plus)) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inputCastType = getMPSDataType(kFloat); } @@ -456,7 +470,11 @@ static Tensor std_var_common_impl_mps(const Tensor& input_t, errMessage += ": reduction dim must be in the range of input shape"; for (const auto dim : dim_value) { auto wrap_dim = maybe_wrap_dim(dim, num_input_dims); +<<<<<<< HEAD TORCH_CHECK(wrap_dim < (num_input_dims ? num_input_dims : 1), errMessage.c_str()) +======= + TORCH_CHECK(wrap_dim < static_cast(input_shape.size()), errMessage.c_str()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } @@ -611,6 +629,12 @@ static Tensor std_var_common_impl_mps(const Tensor& input_t, } static Tensor median_common_mps(const Tensor& input_t, bool nanmedian) { +<<<<<<< HEAD +======= + bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); + MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, nanmedian ? "nanmedian" : "median"); + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) IntArrayRef input_shape = input_t.sizes(); int64_t num_in_elements = c10::multiply_integers(input_shape); @@ -627,7 +651,12 @@ static Tensor median_common_mps(const Tensor& input_t, bool nanmedian) { auto medianCachedGraph = LookUpOrCreateCachedGraph(medianKey, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); +<<<<<<< HEAD MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t); +======= + MPSGraphTensor* castInputTensor = + castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MPSGraphTensor* reshapedTensor = [mpsGraph reshapeTensor:castInputTensor withShape:@[ @-1 ] name:nil]; @@ -685,6 +714,12 @@ static Tensor median_common_mps(const Tensor& input_t, bool nanmedian) { } static Tensor min_max_mps_impl(const Tensor& input_t, MPSReductionType reduction_type, const std::string& func_name) { +<<<<<<< HEAD +======= + bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); + MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "min_max"); + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) using CachedGraph = MPSUnaryCachedGraph; IntArrayRef input_shape = input_t.sizes(); @@ -702,7 +737,12 @@ static Tensor min_max_mps_impl(const Tensor& input_t, MPSReductionType reduction MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); MPSGraphTensor* castOutputTensor = nil; +<<<<<<< HEAD MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t); +======= + MPSGraphTensor* castInputTensor = + castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) NSArray* axes = getTensorAxes(input_t); if (reduction_type == MPSReductionType::MAX) { @@ -737,6 +777,12 @@ static void min_max_out_mps(const Tensor& input_t, const Tensor& indices_t, MPSReductionType reduction_type, const std::string& func_name) { +<<<<<<< HEAD +======= + bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); + MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "min_max_out"); + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (output_t.numel() == 0) { return; } @@ -774,7 +820,12 @@ static void min_max_out_mps(const Tensor& input_t, auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); MPSGraphTensor* outputTensor = nil; +<<<<<<< HEAD MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t); +======= + MPSGraphTensor* castInputTensor = + castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (reduction_type == MPSReductionType::MAX) { outputTensor = [mpsGraph reductionMaximumPropagateNaNWithTensor:castInputTensor axis:(NSInteger)dim_ name:nil]; @@ -880,6 +931,12 @@ static void argmax_argmin_out_mps(const Tensor& input_t, const std::string& func_name) { using CachedGraph = MPSUnaryCachedGraph; +<<<<<<< HEAD +======= + bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); + MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "argmax_argmin_out"); + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int64_t dim_ = -1; if (dim.has_value()) { @@ -934,7 +991,11 @@ static void argmax_argmin_out_mps(const Tensor& input_t, MPSGraphTensor* castInputTensor = inputTensor; if (inputScalarType != kInt && inputScalarType != kHalf && inputScalarType != kFloat && +<<<<<<< HEAD inputScalarType != kLong) { +======= + (inputScalarType != kLong || !macOS13_3_plus)) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) castInputTensor = castMPSTensor(mpsGraph, inputTensor, kFloat); } if (reduction_type == MPSReductionType::MAX) { @@ -1263,6 +1324,12 @@ static void all_any_common_impl_mps(const Tensor& input_t, return; } +<<<<<<< HEAD +======= + bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); + MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, op_name); + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int64_t dim_ = maybe_wrap_dim(dim, input_t.dim()); native::zero_numel_check_dims(input_t, dim_, op_name.c_str()); @@ -1281,7 +1348,11 @@ static void all_any_common_impl_mps(const Tensor& input_t, auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); +<<<<<<< HEAD auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t); +======= + auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // reductionOrWithTensor:axis: will throw an internal assert if number of dimentions is more than 4 // See https://github.com/pytorch/pytorch/issues/95538 MPSGraphTensor* outputTensor = nil; @@ -1347,11 +1418,21 @@ static void all_any_common_impl_mps(const Tensor& input_t, return; } +<<<<<<< HEAD +======= + bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); + MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "any_all_out"); + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @autoreleasepool { std::string key = std::string("any_all_out_mps:") + getTensorsStringKey(input_t); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); +<<<<<<< HEAD auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t); +======= + auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // reductionOrWithTensor:axes: will throw an internal assert if number of dimentions is more than 4 // See https://github.com/pytorch/pytorch/issues/95538 if (input_t.dim() > 4) { @@ -1395,11 +1476,21 @@ static void all_any_common_impl_mps(const Tensor& input_t, return; } +<<<<<<< HEAD +======= + bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); + MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "all_all_out"); + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @autoreleasepool { std::string key = std::string("all_all_out_mps:") + getTensorsStringKey(input_t); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); +<<<<<<< HEAD auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t); +======= + auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // reductionAndWithTensor:axes: will throw an internal assert if number of dimentions is more than 4 // See https://github.com/pytorch/pytorch/issues/95538 if (input_t.ndimension() > 4) { @@ -1484,6 +1575,12 @@ static void median_out_mps_common(const Tensor& input_t, Tensor& indices, const std::string& func_name, bool nanmedian) { +<<<<<<< HEAD +======= + bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); + MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "median_out"); + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int64_t dim_ = maybe_wrap_dim(dim, input_t.dim()); native::zero_numel_check_dims(input_t, dim_, "max()"); @@ -1554,7 +1651,12 @@ static void median_out_mps_common(const Tensor& input_t, getTensorsStringKey(indices); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); +<<<<<<< HEAD MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t); +======= + MPSGraphTensor* castInputTensor = + castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MPSGraphTensor* effectiveLengthTensor = nil; if (nanmedian) { diff --git a/aten/src/ATen/native/mps/operations/Repeat.mm b/aten/src/ATen/native/mps/operations/Repeat.mm index 40afa15b4f700..41cb8876e1dab 100644 --- a/aten/src/ATen/native/mps/operations/Repeat.mm +++ b/aten/src/ATen/native/mps/operations/Repeat.mm @@ -129,8 +129,21 @@ void computeRepeatIndices(const index_t* repeat_ptr, }); } +<<<<<<< HEAD Tensor repeat_interleave_mps(const Tensor& repeat, std::optional output_size) { Tensor output; +======= +Tensor repeat_interleave_mps(const Tensor& repeat_, std::optional output_size) { + Tensor output; + Tensor repeat = repeat_; + if (repeat.scalar_type() == kLong && !is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS)) { + // #103810551: `repeat_interleave_common` uses cumsum to calculate the final shape of output, + // which currently doesn't support int64_t as input. Casting internally the indices to int32_t. + TORCH_WARN_ONCE( + "MPS: no support for int64 repeats mask, casting it to int32. Support has been added in macOS 13.3"); + repeat = repeat.to(kInt); + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AT_DISPATCH_INDEX_TYPES(repeat.scalar_type(), "repeat_interleave_mps", [&]() { output = repeat_interleave_common>(repeat, output_size); }); diff --git a/aten/src/ATen/native/mps/operations/Scalar.mm b/aten/src/ATen/native/mps/operations/Scalar.mm index afda8557c9524..409643bedfd46 100644 --- a/aten/src/ATen/native/mps/operations/Scalar.mm +++ b/aten/src/ATen/native/mps/operations/Scalar.mm @@ -5,6 +5,13 @@ #include #include +<<<<<<< HEAD +======= +#ifdef __OBJC__ +#include +#endif + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) using namespace at::mps; namespace at::native { @@ -15,12 +22,18 @@ Scalar _local_scalar_dense_mps(const Tensor& self) { auto output = at::empty_like(self, TensorOptions(kCPU)); mps::mps_copy_(output, self, false); +<<<<<<< HEAD AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::UInt16, at::ScalarType::UInt32, at::ScalarType::UInt64, +======= + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, + at::ScalarType::Bool, + at::ScalarType::BFloat16, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.scalar_type(), "_local_scalar_dense_mps", [&] { diff --git a/aten/src/ATen/native/mps/operations/ScanKernel.mm b/aten/src/ATen/native/mps/operations/ScanKernel.mm index 80495ba9d501d..169d2abed4844 100644 --- a/aten/src/ATen/native/mps/operations/ScanKernel.mm +++ b/aten/src/ATen/native/mps/operations/ScanKernel.mm @@ -10,9 +10,13 @@ #else #include #include +<<<<<<< HEAD #include #endif #include +======= +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace at::native { namespace mps { @@ -23,6 +27,7 @@ #include #endif +<<<<<<< HEAD // Utility function to get 2D grid dimensions for dispatch static std::pair get_2d_grid_dims(const IntArrayRef& shape, const int64_t dim) { size_t grid_x = 1; @@ -47,11 +52,20 @@ static void scan_simple_mps_impl(const Tensor& self, const Tensor& output, int64_t dim, const std::string& op_name) { if (output.numel() == 0) { +======= +// Generic scan implementation that handles both simple scans and scans with indices +static void scan_mps_impl(const Tensor& self, + const std::vector& outputs, + int64_t dim, + const std::string& op_name) { + if (outputs[0].numel() == 0) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return; } const int64_t ndim = self.dim(); const int64_t wrapped_dim = maybe_wrap_dim(dim, ndim); +<<<<<<< HEAD const int64_t axis_size = self.size(wrapped_dim); // Preprocess input tensor - ensure it's contiguous for Metal shaders @@ -70,26 +84,71 @@ static void scan_simple_mps_impl(const Tensor& self, const Tensor& output, int64 // Determine which kernel to use based on scan dimension position bool is_innermost_scan = (wrapped_dim == ndim - 1); +======= + + // Calculate dimensions for scan operation + int64_t row_size = self.size(wrapped_dim); + auto sizes = self.sizes(); + + bool is_innermost = (wrapped_dim == ndim - 1); + + // Check if all tensors are contiguous + bool is_contiguous = self.is_contiguous(); + for (const auto& output : outputs) { + is_contiguous = is_contiguous && output.is_contiguous(); + } + + uint32_t num_rows, num_orows, num_irows, num_threads; + + if (is_innermost) { + // Treat all outer dimensions as a single dimension + num_rows = self.numel() / row_size; + num_threads = num_rows; + } else { + // Treat all outer dimensions (i.e. dim_ < dim) as one + num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + wrapped_dim); + // Treat all inner dimensions (i.e. dim > dimension) as one + num_irows = c10::multiply_integers(sizes.begin() + wrapped_dim + 1, sizes.end()); + num_threads = num_orows * num_irows; + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MPSStream* mpsStream = getCurrentMPSStream(); dispatch_sync_with_rethrow(mpsStream->queue(), ^() { @autoreleasepool { id computeEncoder = mpsStream->commandEncoder(); +<<<<<<< HEAD // Build kernel name based on scan dimension position const auto type_str = scalarToMetalTypeString(input_tensor); const auto kernel_name = fmt::format("{}_{}_{}", op_name, is_innermost_scan ? "innermost" : "outer", type_str); +======= + // Choose kernel based on contiguity and dimension + std::string kernel_name; + if (is_contiguous) { + kernel_name = + op_name + "_contiguous_" + (is_innermost ? "innermost_" : "outer_") + scalarToMetalTypeString(self); + } else { + kernel_name = op_name + "_strided_" + scalarToMetalTypeString(self); + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) id scanPSO = lib.getPipelineStateForFunc(kernel_name); // this function call is a no-op if MPS Profiler is not enabled getMPSProfiler().beginProfileKernel(scanPSO, op_name, [&]() { +<<<<<<< HEAD std::vector all_tensors = {input_tensor, output_tensor}; +======= + std::vector all_tensors = {self}; + all_tensors.insert(all_tensors.end(), outputs.begin(), outputs.end()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return all_tensors; }()); [computeEncoder setComputePipelineState:scanPSO]; +<<<<<<< HEAD // Set input and output buffers (both guaranteed contiguous) mtl_setBuffer(computeEncoder, input_tensor, 0); mtl_setBuffer(computeEncoder, output_tensor, 1); @@ -251,11 +310,68 @@ static void scan_with_indices_mps_impl(const Tensor& self, if (indices_needs_copy) { indices_output.copy_(indices_tensor); } +======= + // Set input tensor + mtl_setBuffer(computeEncoder, self, 0); + + // Set output tensors + for (size_t i = 0; i < outputs.size(); ++i) { + mtl_setBuffer(computeEncoder, outputs[i], i + 1); + } + + if (is_contiguous) { + // Contiguous kernels + if (is_innermost) { + if (outputs.size() == 1) { + // Simple scan + mtl_setArgs<2>(computeEncoder, num_rows, static_cast(row_size)); + } else { + // Scan with indices + mtl_setArgs<3>(computeEncoder, num_rows, static_cast(row_size)); + } + } else { + if (outputs.size() == 1) { + // Simple scan + mtl_setArgs<2>(computeEncoder, num_orows, num_irows, static_cast(row_size)); + } else { + // Scan with indices + mtl_setArgs<3>(computeEncoder, num_orows, num_irows, static_cast(row_size)); + } + } + } else { + // Strided kernels - pass full tensor information + if (outputs.size() == 1) { + // Simple scan + mtl_setArgs<2>(computeEncoder, + self.sizes(), + self.strides(), + outputs[0].strides(), + static_cast(self.ndimension()), + static_cast(wrapped_dim)); + } else { + // Scan with indices + mtl_setArgs<3>(computeEncoder, + self.sizes(), + self.strides(), + outputs[0].strides(), + outputs[1].strides(), + static_cast(self.ndimension()), + static_cast(wrapped_dim)); + } + } + + mtl_dispatch1DJob(computeEncoder, scanPSO, num_threads); + + getMPSProfiler().endProfileKernel(scanPSO); + } + }); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } // namespace mps void cummax_helper_mps(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim) { +<<<<<<< HEAD mps::scan_with_indices_mps_impl(self, values, indices, dim, "cummax"); } @@ -282,6 +398,13 @@ void cummin_helper_mps(const Tensor& self, Tensor& values, Tensor& indices, int6 Tensor _logcumsumexp_mps(const Tensor& self, int64_t dim) { Tensor result = at::empty_like(self, MemoryFormat::Contiguous); return _logcumsumexp_out_mps(self, dim, result); +======= + mps::scan_mps_impl(self, {values, indices}, dim, "cummax"); +} + +void cummin_helper_mps(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim) { + mps::scan_mps_impl(self, {values, indices}, dim, "cummin"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/Sort.mm b/aten/src/ATen/native/mps/operations/Sort.mm index 6ff47044df133..c54ec51c36041 100644 --- a/aten/src/ATen/native/mps/operations/Sort.mm +++ b/aten/src/ATen/native/mps/operations/Sort.mm @@ -2,7 +2,10 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include @@ -12,11 +15,15 @@ #include #include #else +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #endif namespace at::native { +<<<<<<< HEAD namespace { void kthvalue_out_mps_impl(const Tensor& self, int64_t k, int64_t dim, Tensor& values, Tensor& indices) { @@ -91,6 +98,8 @@ void kthvalue_out_mps_impl(const Tensor& self, int64_t k, int64_t dim, Tensor& v } } } // anonymous namespace +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // sort TORCH_IMPL_FUNC(sort_stable_out_mps) @@ -102,6 +111,12 @@ void kthvalue_out_mps_impl(const Tensor& self, int64_t k, int64_t dim, Tensor& v const Tensor& indices) { using namespace mps; +<<<<<<< HEAD +======= + bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); + MPS_CHECK_INT64_OP_SUPPORTED(self, macOS13_3_plus, "sort_stable_out"); + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (self.numel() == 0) { return; } @@ -128,7 +143,12 @@ void kthvalue_out_mps_impl(const Tensor& self, int64_t k, int64_t dim, Tensor& v auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), input_shape); +<<<<<<< HEAD MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, newCachedGraph->selfTensor, self); +======= + MPSGraphTensor* castInputTensor = + castToIHFTypes(mpsGraph, newCachedGraph->selfTensor, self, /*includesInt64=*/macOS13_3_plus); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MPSGraphTensor* sortedTensor = [mpsGraph sortWithTensor:castInputTensor axis:(NSInteger)dim descending:(BOOL)descending @@ -157,6 +177,7 @@ void kthvalue_out_mps_impl(const Tensor& self, int64_t k, int64_t dim, Tensor& v runMPSGraph(stream, cachedGraph->graph(), feeds, results); } } +<<<<<<< HEAD std::tuple kthvalue_out_mps(const Tensor& self, int64_t k, @@ -184,4 +205,6 @@ void kthvalue_out_mps_impl(const Tensor& self, int64_t k, int64_t dim, Tensor& v return std::forward_as_tuple(values, indices); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/TensorCompare.mm b/aten/src/ATen/native/mps/operations/TensorCompare.mm index 7b637d896f850..b7a2b8aefb84c 100644 --- a/aten/src/ATen/native/mps/operations/TensorCompare.mm +++ b/aten/src/ATen/native/mps/operations/TensorCompare.mm @@ -297,6 +297,12 @@ static void isin_Tensor_Tensor_out_mps(const Tensor& elements, const auto common_type = at::result_type(elements, test_elements); TORCH_CHECK(elements.is_mps() && test_elements.is_mps()); +<<<<<<< HEAD +======= + TORCH_CHECK(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) || supportedFloatingType(common_type), + "isin_Tensor_Tensor_out only works on floating types on MPS for pre MacOS_14_0. Received dtype: ", + common_type); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @autoreleasepool { std::string key = op_name + getTensorsStringKey({elements, test_elements}) + std::to_string(invert); @@ -335,9 +341,12 @@ static void isin_Tensor_Tensor_out_mps(const Tensor& elements, } static void is_posneginf_helper(TensorIteratorBase& iter, bool is_neg) { +<<<<<<< HEAD if (iter.numel() == 0) { return; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const auto& self = iter.input(0); auto& out = iter.output(0); @autoreleasepool { diff --git a/aten/src/ATen/native/mps/operations/UnaryKernel.mm b/aten/src/ATen/native/mps/operations/UnaryKernel.mm index 7e150b133cc65..d01341cc35b00 100644 --- a/aten/src/ATen/native/mps/operations/UnaryKernel.mm +++ b/aten/src/ATen/native/mps/operations/UnaryKernel.mm @@ -14,7 +14,10 @@ #include #endif +<<<<<<< HEAD // KURT: call site of `exec_unary_kernel` +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #define REGISTER_UNARY_TI_DISPATCH(NAME) \ static void NAME##_kernel_mps(TensorIteratorBase& iter) { \ lib.exec_unary_kernel(iter, #NAME); \ @@ -50,7 +53,10 @@ static void round_decimals_kernel(TensorIteratorBase& iter, int64_t decimals) { REGISTER_UNARY_TI_DISPATCH(log); REGISTER_UNARY_TI_DISPATCH(log1p); REGISTER_UNARY_TI_DISPATCH(bitwise_not); +<<<<<<< HEAD REGISTER_UNARY_TI_DISPATCH(round); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) REGISTER_UNARY_TI_DISPATCH(sigmoid); REGISTER_DISPATCH(round_decimals_stub, round_decimals_kernel); } // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/UnaryOps.mm b/aten/src/ATen/native/mps/operations/UnaryOps.mm index d7ce40e5cbb4f..cd6d5400e8d69 100644 --- a/aten/src/ATen/native/mps/operations/UnaryOps.mm +++ b/aten/src/ATen/native/mps/operations/UnaryOps.mm @@ -184,6 +184,10 @@ static void unary_op(const Tensor& self, REGISTER_MPS_UNARY_STUB(ceil, ceil); REGISTER_MPS_UNARY_STUB(floor, floor); +<<<<<<< HEAD +======= +REGISTER_MPS_UNARY_STUB(round, round); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) REGISTER_MPS_UNARY_STUB(trunc, truncate); #define CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(func_out, func_stub) \ @@ -207,12 +211,37 @@ static void unary_op(const Tensor& self, } Tensor& angle_out_mps(const Tensor& self, Tensor& output) { +<<<<<<< HEAD mps::unary_op(self, output, "angle_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { auto realPart = [mpsGraph realPartOfTensor:inputTensor name:nil]; auto imagPart = [mpsGraph imaginaryPartOfTensor:inputTensor name:nil]; return [mpsGraph atan2WithPrimaryTensor:imagPart secondaryTensor:realPart name:nil]; }); return output; +======= + if (mps::supportsComplex()) { + mps::unary_op(self, output, "angle_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { + auto realPart = [mpsGraph realPartOfTensor:inputTensor name:nil]; + auto imagPart = [mpsGraph imaginaryPartOfTensor:inputTensor name:nil]; + return [mpsGraph atan2WithPrimaryTensor:imagPart secondaryTensor:realPart name:nil]; + }); + return output; + } else { + TORCH_CHECK(!self.is_complex(), "MPS does not support angle with complex input on macOS13") + mps::unary_op(self, output, "angle_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { + // On macOS 13 with non-complex input, realPartOfTensor and imaginaryPartOfTensor are + // not available, and NaN is not propagated correctly: + auto imagPart = [mpsGraph constantWithScalar:0.0 shape:inputTensor.shape dataType:inputTensor.dataType]; + auto result = [mpsGraph atan2WithPrimaryTensor:imagPart secondaryTensor:inputTensor name:nil]; + auto nanMask = [mpsGraph isNaNWithTensor:inputTensor name:nil]; + return [mpsGraph selectWithPredicateTensor:nanMask + truePredicateTensor:inputTensor + falsePredicateTensor:result + name:nil]; + }); + return output; + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } Tensor angle_mps(const Tensor& self) { @@ -345,6 +374,10 @@ static void cumulative_op_impl(const Tensor& self, const Tensor& result, MPSCumulativeOpType cumulativeOpType, const std::string& op_name) { +<<<<<<< HEAD +======= + bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto nDims = self.dim(); auto wrapped_dim = maybe_wrap_dim(dim, nDims); TORCH_CHECK(wrapped_dim >= 0 && wrapped_dim < std::max(1LL, self.ndimension()), @@ -363,6 +396,14 @@ static void cumulative_op_impl(const Tensor& self, bool castInputData = (isIntegralType(input.scalar_type(), true) && input.scalar_type() != ScalarType::Int && input.scalar_type() != ScalarType::Long); +<<<<<<< HEAD +======= + TORCH_CHECK(macOS13_3_plus || input.scalar_type() != ScalarType::Long, + "MPS does not support ", + op_name, + " op with int64 input. Support has been added in macOS 13.3"); + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mps::unary_op( input, result, op_name + std::to_string(dim), ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { if (castInputData) { @@ -417,10 +458,24 @@ static void cumulative_op_impl(const Tensor& self, Tensor& conj_physical_out_mps(const Tensor& self, Tensor& result) { TORCH_CHECK(self.is_complex()); +<<<<<<< HEAD TORCH_CHECK(self.dtype() != at::kComplexDouble); mps::unary_op(self, result, "conj", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { return [mpsGraph conjugateWithTensor:inputTensor name:nil]; }); +======= + if (!mps::supportsComplex()) { + if (!result.is_same_size(self)) { + result.resize_(self.sizes()); + } + at::real(result).copy_(at::real(self)); + at::imag(result).copy_(at::neg(at::imag(self))); + } else { + mps::unary_op(self, result, "conj", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { + return [mpsGraph conjugateWithTensor:inputTensor name:nil]; + }); + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return result; } diff --git a/aten/src/ATen/native/mps/operations/View.mm b/aten/src/ATen/native/mps/operations/View.mm index 5efd4a3cfbdf3..47a7192c515da 100644 --- a/aten/src/ATen/native/mps/operations/View.mm +++ b/aten/src/ATen/native/mps/operations/View.mm @@ -17,7 +17,30 @@ #include #endif +<<<<<<< HEAD namespace at::native::mps { +======= +namespace at::native { +namespace mps { + +static IntArrayRef updateTensorBaseShape(const Tensor& self) { + IntArrayRef base_shape = getIMPSAllocator()->getBufferShape(self.storage().data()); + // if there's no base_shape stored in MPSAllocator, then infer it from tensor's size and store it + if (base_shape.size() == 0) { + // IntArrayRef wouldn't own the data, so we use a static storage + static const int64_t shape_1d = 1; + // self.sizes().size() could be zero + base_shape = self.sizes().size() + ? self.sizes() + : ((self.is_view() && self._base().sizes().size()) ? self._base().sizes() : IntArrayRef(&shape_1d, 1)); + + // base_shape will be retained in MPSAllocator until buffer gets recycled + if (self.storage().data()) + getIMPSAllocator()->setBufferShape(self.storage().data(), base_shape); + } + return base_shape; +} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // For both scatter and gather kernels, there are 4 specized ones (for 1D to 4D tensor) // and one generic, for 5+D ones. Assumption (to be tested) about specialized kernels @@ -179,4 +202,30 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst) { return output; } +<<<<<<< HEAD } // namespace at::native::mps +======= +} // namespace mps + +// implementation of as_strided() op +Tensor as_strided_tensorimpl_mps(const Tensor& self, + IntArrayRef size, + IntArrayRef stride, + std::optional storage_offset_) { + auto storage_offset = storage_offset_.value_or(self.storage_offset()); + auto result = + detail::make_tensor(c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype()); + setStrided(result, size, stride, storage_offset); + + // creating the view graph will be deferred until gatherViewTensor() or scatterViewTensor() are called. + // In as_strided, we just update the base shape of the buffer in order to retrieve it later + // when we create/run the view graph. + IntArrayRef base_shape = mps::updateTensorBaseShape(self); + TORCH_INTERNAL_ASSERT( + !base_shape.empty(), "Failed to update the base shape of tensor's buffer at ", self.storage().data()); + + return result; +} + +} // namespace at::native +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index abb061afc5c95..2e2adbc54e3e9 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -288,7 +288,10 @@ dispatch: CPU: native_dropout_cpu CUDA: native_dropout_cuda +<<<<<<< HEAD MPS: native_dropout_mps +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: native_dropout_nested tags: [nondeterministic_seeded, core] autogen: native_dropout.out @@ -297,7 +300,10 @@ dispatch: CPU, NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: native_dropout_backward CUDA: native_dropout_backward_cuda +<<<<<<< HEAD MPS: native_dropout_backward_mps +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) autogen: native_dropout_backward.out tags: pointwise @@ -342,8 +348,13 @@ variants: function, method dispatch: CompositeExplicitAutograd: abs +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: abs_sparse SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: abs_sparse_csr +======= + SparseCPU, SparseCUDA: abs_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: abs_sparse_csr +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_abs tags: [core, pointwise] @@ -352,16 +363,27 @@ variants: function, method dispatch: CompositeExplicitAutograd: abs_ +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: abs_sparse_ SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: abs_sparse_csr_ +======= + SparseCPU, SparseCUDA: abs_sparse_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: abs_sparse_csr_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_abs_ - func: abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator dispatch: +<<<<<<< HEAD CPU, CUDA, MPS, MTIA: abs_out SparseCPU, SparseCUDA, SparseMPS: abs_sparse_out SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: abs_sparse_csr_out +======= + CPU, CUDA, MPS: abs_out + SparseCPU, SparseCUDA: abs_sparse_out + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: abs_sparse_csr_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tags: pointwise # Note [Adding an alias] @@ -430,7 +452,11 @@ variants: function, method structured_delegate: sgn.out dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: sgn_sparse +======= + SparseCPU, SparseCUDA: sgn_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sgn_sparse_csr NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_sgn tags: pointwise @@ -439,7 +465,11 @@ variants: method structured_delegate: sgn.out dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: sgn_sparse_ +======= + SparseCPU, SparseCUDA: sgn_sparse_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sgn_sparse_csr_ NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_sgn_ tags: pointwise @@ -450,7 +480,11 @@ dispatch: CPU, CUDA: sgn_out MPS: sgn_out_mps +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: sgn_sparse_out +======= + SparseCPU, SparseCUDA: sgn_sparse_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sgn_sparse_csr_out tags: pointwise @@ -478,7 +512,11 @@ variants: function, method dispatch: CompositeExplicitAutograd: _conj_physical +<<<<<<< HEAD SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: conj_physical_sparse_csr +======= + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: conj_physical_sparse_csr +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) autogen: _conj_physical.out - func: conj_physical(Tensor self) -> Tensor @@ -489,8 +527,13 @@ dispatch: CPU, CUDA: conj_physical_out MPS: conj_physical_out_mps +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: conj_physical_out_sparse SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: conj_physical_sparse_csr_out +======= + SparseCPU, SparseCUDA: conj_physical_out_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: conj_physical_sparse_csr_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tags: pointwise - func: conj_physical_(Tensor(a!) self) -> Tensor(a!) @@ -556,7 +599,11 @@ structured_delegate: add.out variants: function, method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS, SparseMeta: add_sparse +======= + SparseCPU, SparseCUDA, SparseMeta: add_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: add_sparse_csr MkldnnCPU: mkldnn_add ZeroTensor: add_zerotensor @@ -568,7 +615,11 @@ variants: method structured_delegate: add.out dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS, SparseMeta: add_sparse_ +======= + SparseCPU, SparseCUDA, SparseMeta: add_sparse_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: add_sparse_csr_ MkldnnCPU: mkldnn_add_ NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_add__Tensor @@ -584,12 +635,18 @@ dispatch: SparseCPU, SparseMeta: add_out_sparse_cpu SparseCUDA: add_out_sparse_cuda +<<<<<<< HEAD SparseMPS: add_out_sparse_mps +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrMeta: add_out_sparse_compressed_cpu SparseCsrCUDA: add_out_sparse_compressed_cuda MkldnnCPU: mkldnn_add_out MPS: add_out_mps +<<<<<<< HEAD MTIA: add_out_mtia +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tags: pointwise - func: _add_relu.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor @@ -722,7 +779,10 @@ dispatch: CPU, CUDA: all_out MPS: all_out_mps +<<<<<<< HEAD MTIA: all_out_mtia +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - func: all.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -812,7 +872,10 @@ CPU, Meta: arange_out CUDA: arange_cuda_out MPS: arange_mps_out +<<<<<<< HEAD MTIA: arange_mtia_out +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cpp_no_default_args: ['step'] # This function is a temporary hack to allow tracing of arange like constructs with dynamic @@ -877,7 +940,11 @@ variants: function, method structured_delegate: asinh.out dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: asinh_sparse +======= + SparseCPU, SparseCUDA: asinh_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: asinh_sparse_csr tags: [core, pointwise] @@ -885,7 +952,11 @@ variants: function, method structured_delegate: asinh.out dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: asinh_sparse_ +======= + SparseCPU, SparseCUDA: asinh_sparse_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: asinh_sparse_csr_ tags: pointwise @@ -895,7 +966,11 @@ dispatch: CPU, CUDA: asinh_out MPS: asinh_out_mps +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: asinh_sparse_out +======= + SparseCPU, SparseCUDA: asinh_sparse_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: asinh_sparse_csr_out tags: pointwise @@ -912,7 +987,11 @@ structured_delegate: atanh.out variants: function, method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: atanh_sparse +======= + SparseCPU, SparseCUDA: atanh_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: atanh_sparse_csr tags: [core, pointwise] @@ -920,7 +999,11 @@ structured_delegate: atanh.out variants: function, method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: atanh_sparse_ +======= + SparseCPU, SparseCUDA: atanh_sparse_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: atanh_sparse_csr_ tags: pointwise @@ -930,7 +1013,11 @@ dispatch: CPU, CUDA: atanh_out MPS: atanh_out_mps +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: atanh_sparse_out +======= + SparseCPU, SparseCUDA: atanh_sparse_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: atanh_sparse_csr_out tags: pointwise # arctanh, alias for atanh @@ -946,8 +1033,14 @@ - func: as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a) variants: function, method dispatch: +<<<<<<< HEAD ZeroTensor, CPU, CUDA, MTIA, MPS: as_strided_tensorimpl Meta: as_strided_tensorimpl_meta_symint +======= + ZeroTensor, CPU, CUDA, MTIA: as_strided_tensorimpl + Meta: as_strided_tensorimpl_meta_symint + MPS: as_strided_tensorimpl_mps +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) QuantizedCPU, QuantizedCUDA: as_strided_qtensorimpl device_check: NoCheck device_guard: False @@ -967,7 +1060,11 @@ variants: function, method structured_delegate: asin.out dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: asin_sparse +======= + SparseCPU, SparseCUDA: asin_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: asin_sparse_csr tags: [core, pointwise] @@ -976,7 +1073,11 @@ variants: function, method structured_delegate: asin.out dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: asin_sparse_ +======= + SparseCPU, SparseCUDA: asin_sparse_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: asin_sparse_csr_ tags: pointwise @@ -986,7 +1087,11 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA, MPS: asin_out +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: asin_sparse_out +======= + SparseCPU, SparseCUDA: asin_sparse_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: asin_sparse_csr_out tags: pointwise @@ -1004,7 +1109,11 @@ structured_delegate: atan.out variants: function, method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: atan_sparse +======= + SparseCPU, SparseCUDA: atan_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: atan_sparse_csr tags: [core, pointwise] @@ -1013,7 +1122,11 @@ structured_delegate: atan.out variants: function, method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: atan_sparse_ +======= + SparseCPU, SparseCUDA: atan_sparse_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: atan_sparse_csr_ tags: pointwise @@ -1023,7 +1136,11 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA, MPS: atan_out +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: atan_sparse_out +======= + SparseCPU, SparseCUDA: atan_sparse_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: atan_sparse_csr_out tags: pointwise @@ -1072,7 +1189,10 @@ CUDA: baddbmm_out_cuda MPS: baddbmm_out_mps XPU: baddbmm_out_xpu +<<<<<<< HEAD MTIA: baddbmm_out_mtia +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCUDA: baddbmm_out_sparse_csr_cuda - func: baddbmm.dtype(Tensor self, Tensor batch1, Tensor batch2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1) -> Tensor @@ -1287,7 +1407,11 @@ - func: logical_not.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator dispatch: +<<<<<<< HEAD CPU, CUDA, MTIA: logical_not_out +======= + CPU, CUDA: logical_not_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MPS: logical_not_out_mps tags: pointwise @@ -1382,7 +1506,10 @@ CUDA: bmm_out_cuda MPS: bmm_out_mps XPU: bmm_out_xpu +<<<<<<< HEAD MTIA: bmm_out_mtia +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCPU: bmm_out_sparse_cpu SparseCUDA: bmm_out_sparse_cuda SparseCsrCUDA: bmm_out_sparse_csr_cuda @@ -1462,7 +1589,11 @@ structured_delegate: ceil.out variants: function, method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: ceil_sparse +======= + SparseCPU, SparseCUDA: ceil_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: ceil_sparse_csr tags: [core, pointwise] @@ -1471,7 +1602,11 @@ structured_delegate: ceil.out variants: function, method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: ceil_sparse_ +======= + SparseCPU, SparseCUDA: ceil_sparse_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: ceil_sparse_csr_ tags: pointwise @@ -1481,7 +1616,11 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA, MPS: ceil_out +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: ceil_sparse_out +======= + SparseCPU, SparseCUDA: ceil_sparse_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: ceil_sparse_csr_out tags: pointwise @@ -1894,10 +2033,14 @@ - func: cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor) dispatch: CUDA: cudnn_batch_norm +<<<<<<< HEAD - func: cudnn_batch_norm.out(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!)) dispatch: CUDA: cudnn_batch_norm_out +======= + autogen: cudnn_batch_norm.out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # NB: You can only use this if you used cudnn_batch_norm training=True - func: cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace) -> (Tensor, Tensor, Tensor) @@ -2178,7 +2321,11 @@ structured: True structured_inherits: TensorIteratorBase dispatch: +<<<<<<< HEAD CPU, CUDA, MPS, MTIA: div_out +======= + CPU, CUDA, MPS: div_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCPU, SparseCUDA: div_out_sparse_zerodim tags: pointwise @@ -2409,7 +2556,11 @@ MPS: empty_mps Meta: empty_meta_symint MkldnnCPU: empty_mkldnn +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: empty_sparse +======= + SparseCPU, SparseCUDA: empty_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseMeta: empty_sparse_symint SparseCsrCPU, SparseCsrCUDA: empty_sparse_compressed SparseCsrMeta: empty_sparse_compressed_symint @@ -2537,7 +2688,11 @@ structured_delegate: erf.out variants: function, method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: erf_sparse +======= + SparseCPU, SparseCUDA: erf_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erf_sparse_csr tags: [core, pointwise] @@ -2546,7 +2701,11 @@ structured_delegate: erf.out variants: function, method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: erf_sparse_ +======= + SparseCPU, SparseCUDA: erf_sparse_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erf_sparse_csr_ tags: pointwise @@ -2556,7 +2715,11 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA, MPS, MTIA: erf_out +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: erf_sparse_out +======= + SparseCPU, SparseCUDA: erf_sparse_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erf_sparse_csr_out tags: pointwise @@ -2622,7 +2785,11 @@ structured_delegate: expm1.out variants: function, method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: expm1_sparse +======= + SparseCPU, SparseCUDA: expm1_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: expm1_sparse_csr tags: [core, pointwise] @@ -2631,7 +2798,11 @@ structured_delegate: expm1.out variants: function, method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: expm1_sparse_ +======= + SparseCPU, SparseCUDA: expm1_sparse_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: expm1_sparse_csr_ tags: pointwise @@ -2641,7 +2812,11 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA, MPS: expm1_out +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: expm1_sparse_out +======= + SparseCPU, SparseCUDA: expm1_sparse_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: expm1_sparse_csr_out tags: pointwise @@ -2740,7 +2915,11 @@ structured_delegate: floor.out variants: function, method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: floor_sparse +======= + SparseCPU, SparseCUDA: floor_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: floor_sparse_csr tags: [core, pointwise] @@ -2749,7 +2928,11 @@ structured_delegate: floor.out variants: function, method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: floor_sparse_ +======= + SparseCPU, SparseCUDA: floor_sparse_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: floor_sparse_csr_ tags: pointwise @@ -2759,7 +2942,11 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA, MPS: floor_out +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: floor_sparse_out +======= + SparseCPU, SparseCUDA: floor_sparse_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: floor_sparse_csr_out tags: pointwise @@ -2767,7 +2954,11 @@ device_check: NoCheck # TensorIterator variants: function, method dispatch: +<<<<<<< HEAD CPU, CUDA, MPS, MTIA: floor_divide +======= + CPU, CUDA, MPS: floor_divide +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCPU, SparseCUDA: floor_divide_sparse - func: floor_divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) @@ -2801,7 +2992,11 @@ structured_delegate: frac.out variants: function, method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: frac_sparse +======= + SparseCPU, SparseCUDA: frac_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: frac_sparse_csr tags: pointwise @@ -2810,7 +3005,11 @@ structured_delegate: frac.out variants: function, method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: frac_sparse_ +======= + SparseCPU, SparseCUDA: frac_sparse_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: frac_sparse_csr_ tags: pointwise @@ -2821,7 +3020,11 @@ dispatch: CPU, CUDA: frac_out MPS: frac_out_mps +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: frac_sparse_out +======= + SparseCPU, SparseCUDA: frac_sparse_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: frac_sparse_csr_out tags: pointwise @@ -2934,7 +3137,10 @@ dispatch: CPU: grid_sampler_3d_cpu CUDA: grid_sampler_3d_cuda +<<<<<<< HEAD MPS: grid_sampler_3d_mps +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) autogen: grid_sampler_3d.out # `grid_sampler_3d_backward` takes in `output_mask` to optimize performance for @@ -3211,7 +3417,11 @@ dispatch: CPU, CUDA, MPS, MTIA: isnan NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_isnan +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: isnan_sparse +======= + SparseCPU, SparseCUDA: isnan_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: isnan_sparse_csr autogen: isnan.out tags: [core, pointwise] @@ -3292,7 +3502,10 @@ dispatch: CPU: kthvalue_out_cpu CUDA: kthvalue_out_cuda +<<<<<<< HEAD MPS: kthvalue_out_mps +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - func: kthvalue.dimname(Tensor self, SymInt k, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) variants: function, method @@ -3326,6 +3539,7 @@ dispatch: CompositeImplicitAutograd: rms_norm_symint +<<<<<<< HEAD - func: _fused_rms_norm(Tensor input, int[] normalized_shape, Tensor? weight, float? eps) -> (Tensor, Tensor) dispatch: CUDA: _fused_rms_norm_cuda @@ -3335,26 +3549,43 @@ - func: _fused_rms_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor rstd, Tensor? weight, bool[2] output_mask) -> (Tensor, Tensor) dispatch: CUDA: _fused_rms_norm_backward_cuda +======= +- func: _fused_rms_norm(Tensor input, int normalized_shape_ndim, Tensor weight, float eps) -> Tensor + dispatch: + MPS: _fused_rms_norm_mps +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - func: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor variants: function, method dispatch: CompositeExplicitAutograd: nan_to_num +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: nan_to_num_sparse +======= + SparseCPU, SparseCUDA: nan_to_num_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tags: pointwise - func: nan_to_num_(Tensor(a!) self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor(a!) variants: function, method dispatch: CompositeExplicitAutograd: nan_to_num_ +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: nan_to_num_sparse_ +======= + SparseCPU, SparseCUDA: nan_to_num_sparse_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tags: pointwise - func: nan_to_num.out(Tensor self, float? nan=None, float? posinf=None, float? neginf=None, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU, CUDA, MTIA: nan_to_num_out MPS: nan_to_num_out_mps +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: nan_to_num_sparse_out +======= + SparseCPU, SparseCUDA: nan_to_num_sparse_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tags: pointwise - func: linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor @@ -3450,6 +3681,7 @@ - func: _wrapped_quantized_linear_prepacked(Tensor input, Tensor input_scale, Tensor input_zero_point, Tensor packed_weight, Tensor output_scale, Tensor output_zero_point, int out_channel) -> Tensor +<<<<<<< HEAD - func: fbgemm_linear_fp16_weight_fp32_activation(Tensor input, Tensor packed_weight, Tensor? bias) -> Tensor - func: fbgemm_linear_fp16_weight_fp32_activation.out(Tensor input, Tensor packed_weight, Tensor? bias, Tensor(a!) output) -> Tensor @@ -3458,6 +3690,12 @@ - func: fbgemm_linear_fp16_weight.out(Tensor input, Tensor packed_weight, Tensor bias, Tensor(a!) output) -> Tensor +======= +- func: fbgemm_linear_fp16_weight_fp32_activation(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor + +- func: fbgemm_linear_fp16_weight(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - func: fbgemm_pack_quantized_matrix(Tensor input) -> Tensor - func: fbgemm_pack_quantized_matrix.KN(Tensor input, int K, int N) -> Tensor @@ -3557,7 +3795,11 @@ structured_delegate: log1p.out variants: function, method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: log1p_sparse +======= + SparseCPU, SparseCUDA: log1p_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: log1p_sparse_csr tags: [core, pointwise] @@ -3566,7 +3808,11 @@ structured_delegate: log1p.out variants: function, method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: log1p_sparse_ +======= + SparseCPU, SparseCUDA: log1p_sparse_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: log1p_sparse_csr_ tags: pointwise @@ -3576,7 +3822,11 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA, MPS: log1p_out +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: log1p_sparse_out +======= + SparseCPU, SparseCUDA: log1p_sparse_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: log1p_sparse_csr_out tags: pointwise @@ -3762,13 +4012,19 @@ dispatch: CPU: _logcumsumexp_cpu CUDA: _logcumsumexp_cuda +<<<<<<< HEAD MPS: _logcumsumexp_mps +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - func: _logcumsumexp.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU: _logcumsumexp_out_cpu CUDA: _logcumsumexp_out_cuda +<<<<<<< HEAD MPS: _logcumsumexp_out_mps +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - func: logcumsumexp(Tensor self, int dim) -> Tensor variants: function, method @@ -4196,13 +4452,19 @@ dispatch: CPU: _int_mm_cpu CUDA: _int_mm_cuda +<<<<<<< HEAD XPU: _int_mm_xpu +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - func: _int_mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU: _int_mm_out_cpu CUDA: _int_mm_out_cuda +<<<<<<< HEAD XPU: _int_mm_out_xpu +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - func: _convert_weight_to_int4pack(Tensor self, int innerKTiles) -> Tensor dispatch: @@ -4239,7 +4501,10 @@ - func: _weight_int8pack_mm(Tensor self, Tensor mat2, Tensor scales) -> Tensor dispatch: CPU: _weight_int8pack_mm_cpu +<<<<<<< HEAD CUDA: _weight_int8pack_mm_cuda +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MPS: _weight_int8pack_mm_mps - func: _sparse_mm(Tensor sparse, Tensor dense) -> Tensor @@ -4296,7 +4561,11 @@ structured: True structured_inherits: TensorIteratorBase dispatch: +<<<<<<< HEAD CPU, CUDA, MPS, MTIA: mul_out +======= + CPU, CUDA, MPS: mul_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCPU: mul_out_sparse_cpu SparseCUDA: mul_out_sparse_cuda SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: mul_out_sparse_csr @@ -4668,7 +4937,11 @@ variants: function, method dispatch: CompositeExplicitAutograd: rad2deg +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: rad2deg_sparse +======= + SparseCPU, SparseCUDA: rad2deg_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: rad2deg_sparse_csr tags: pointwise @@ -4676,14 +4949,22 @@ variants: function, method dispatch: CompositeExplicitAutograd: rad2deg_ +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: rad2deg_sparse_ +======= + SparseCPU, SparseCUDA: rad2deg_sparse_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: rad2deg_sparse_csr_ tags: pointwise - func: rad2deg.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) dispatch: CompositeExplicitAutograd: rad2deg_out +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: rad2deg_sparse_out +======= + SparseCPU, SparseCUDA: rad2deg_sparse_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: rad2deg_sparse_csr_out tags: pointwise @@ -4691,7 +4972,11 @@ variants: function, method dispatch: CompositeExplicitAutograd: deg2rad +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: deg2rad_sparse +======= + SparseCPU, SparseCUDA: deg2rad_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: deg2rad_sparse_csr tags: pointwise @@ -4699,14 +4984,22 @@ variants: function, method dispatch: CompositeExplicitAutograd: deg2rad_ +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: deg2rad_sparse_ +======= + SparseCPU, SparseCUDA: deg2rad_sparse_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: deg2rad_sparse_csr_ tags: pointwise - func: deg2rad.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) dispatch: CompositeExplicitAutograd: deg2rad_out +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: deg2rad_sparse_out +======= + SparseCPU, SparseCUDA: deg2rad_sparse_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: deg2rad_sparse_csr_out tags: pointwise @@ -4932,7 +5225,11 @@ structured_delegate: neg.out variants: function, method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: neg_sparse +======= + SparseCPU, SparseCUDA: neg_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: neg_sparse_csr NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_neg tags: [core, pointwise] @@ -4942,7 +5239,11 @@ structured_delegate: neg.out variants: function, method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: neg_sparse_ +======= + SparseCPU, SparseCUDA: neg_sparse_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: neg_sparse_csr_ NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_neg_ tags: pointwise @@ -4953,7 +5254,11 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA, MPS, MTIA: neg_out +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: neg_out_sparse +======= + SparseCPU, SparseCUDA: neg_out_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: neg_sparse_csr_out tags: pointwise # Alias for neg @@ -5037,7 +5342,11 @@ structured_delegate: round.out variants: function, method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: round_sparse +======= + SparseCPU, SparseCUDA: round_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: round_sparse_csr tags: [core, pointwise] @@ -5046,7 +5355,11 @@ structured_delegate: round.out variants: function, method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: round_sparse_ +======= + SparseCPU, SparseCUDA: round_sparse_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: round_sparse_csr_ tags: pointwise @@ -5056,7 +5369,11 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA, MPS: round_out +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: round_sparse_out +======= + SparseCPU, SparseCUDA: round_sparse_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: round_sparse_csr_out tags: pointwise @@ -5092,14 +5409,23 @@ device_check: NoCheck # TensorIterator variants: function, method dispatch: +<<<<<<< HEAD CPU, CUDA: relu MPS: relu_mps MTIA: relu_mtia +======= + CPU, CUDA, MTIA: relu + MPS: relu_mps +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MkldnnCPU: mkldnn_relu QuantizedCPU: relu_quantized_cpu QuantizedCUDA: relu_quantized_cuda NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_relu +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: relu_sparse +======= + SparseCPU, SparseCUDA: relu_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: relu_sparse_csr tags: [core, pointwise] @@ -5107,14 +5433,23 @@ device_check: NoCheck # TensorIterator variants: function, method dispatch: +<<<<<<< HEAD CPU, CUDA: relu_ MPS: relu_mps_ MTIA: relu_mtia_ +======= + CPU, CUDA, MTIA: relu_ + MPS: relu_mps_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MkldnnCPU: mkldnn_relu_ QuantizedCPU: relu_quantized_cpu_ QuantizedCUDA: relu_quantized_cuda_ NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_relu_ +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: relu_sparse_ +======= + SparseCPU, SparseCUDA: relu_sparse_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: relu_sparse_csr_ autogen: relu.out tags: pointwise @@ -5401,7 +5736,11 @@ variants: function, method dispatch: SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sin_sparse_csr +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: sin_sparse +======= + SparseCPU, SparseCUDA: sin_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_sin tags: [core, pointwise] @@ -5411,7 +5750,11 @@ variants: function, method dispatch: SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sin_sparse_csr_ +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: sin_sparse_ +======= + SparseCPU, SparseCUDA: sin_sparse_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tags: pointwise - func: sin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) @@ -5421,7 +5764,11 @@ dispatch: CPU, CUDA, MPS, MTIA: sin_out SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sin_sparse_csr_out +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: sin_sparse_out +======= + SparseCPU, SparseCUDA: sin_sparse_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tags: pointwise - func: sinc(Tensor self) -> Tensor @@ -5446,7 +5793,11 @@ structured_delegate: sinh.out variants: function, method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: sinh_sparse +======= + SparseCPU, SparseCUDA: sinh_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sinh_sparse_csr tags: [core, pointwise] @@ -5455,7 +5806,11 @@ structured_delegate: sinh.out variants: function, method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: sinh_sparse_ +======= + SparseCPU, SparseCUDA: sinh_sparse_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sinh_sparse_csr_ tags: pointwise @@ -5465,7 +5820,11 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA, MPS: sinh_out +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: sinh_sparse_out +======= + SparseCPU, SparseCUDA: sinh_sparse_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sinh_sparse_csr_out # Returns a copy of this `Variable` that is detached from its autograd graph. @@ -5513,6 +5872,7 @@ tags: core manual_cpp_binding: True +<<<<<<< HEAD - func: sym_is_contiguous(Tensor self, MemoryFormat memory_format=contiguous_format) -> SymBool variants: function device_check: NoCheck @@ -5520,6 +5880,8 @@ tags: core manual_cpp_binding: True +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - func: sym_numel(Tensor self) -> SymInt variants: function device_check: NoCheck @@ -5893,6 +6255,7 @@ CPU, CUDA: nansum_out MPS: nansum_out_mps +<<<<<<< HEAD - func: hash_tensor(Tensor self, int[1] dim=[], *, bool keepdim=False, int mode=0) -> Tensor variants: function, method structured_delegate: hash_tensor.out @@ -5902,6 +6265,8 @@ dispatch: CPU, CUDA: hash_tensor_out +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - func: sum_to_size(Tensor self, SymInt[] size) -> Tensor variants: method device_check: NoCheck @@ -5915,7 +6280,11 @@ variants: function, method dispatch: NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_sqrt +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: sqrt_sparse +======= + SparseCPU, SparseCUDA: sqrt_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sqrt_sparse_csr tags: [core, pointwise] @@ -5924,7 +6293,11 @@ structured_delegate: sqrt.out variants: function, method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: sqrt_sparse_ +======= + SparseCPU, SparseCUDA: sqrt_sparse_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sqrt_sparse_csr_ tags: pointwise @@ -5934,7 +6307,11 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA, MPS, MTIA: sqrt_out +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: sqrt_sparse_out +======= + SparseCPU, SparseCUDA: sqrt_sparse_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sqrt_sparse_csr_out tags: pointwise @@ -6072,7 +6449,11 @@ structured_delegate: tan.out variants: function, method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: tan_sparse +======= + SparseCPU, SparseCUDA: tan_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: tan_sparse_csr tags: [core, pointwise] @@ -6081,7 +6462,11 @@ structured_delegate: tan.out variants: function, method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: tan_sparse_ +======= + SparseCPU, SparseCUDA: tan_sparse_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: tan_sparse_csr_ tags: pointwise @@ -6091,7 +6476,11 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA, MPS: tan_out +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: tan_sparse_out +======= + SparseCPU, SparseCUDA: tan_sparse_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: tan_sparse_csr_out tags: pointwise @@ -6102,7 +6491,11 @@ dispatch: QuantizedCPU: tanh_quantized_cpu MkldnnCPU: mkldnn_tanh +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: tanh_sparse +======= + SparseCPU, SparseCUDA: tanh_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: tanh_sparse_csr NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_tanh tags: [core, pointwise] @@ -6113,7 +6506,11 @@ variants: function, method dispatch: MkldnnCPU: mkldnn_tanh_ +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: tanh_sparse_ +======= + SparseCPU, SparseCUDA: tanh_sparse_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: tanh_sparse_csr_ NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_tanh_ tags: pointwise @@ -6124,7 +6521,11 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA, MPS, MTIA: tanh_out +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: tanh_sparse_out +======= + SparseCPU, SparseCUDA: tanh_sparse_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: tanh_sparse_csr_out tags: pointwise @@ -6396,8 +6797,13 @@ device_check: NoCheck # TensorIterator variants: function, method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: trunc_sparse SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: trunc_sparse_csr +======= + SparseCPU, SparseCUDA: trunc_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: trunc_sparse_csr +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tags: [core, pointwise] - func: trunc_(Tensor(a!) self) -> Tensor(a!) @@ -6405,8 +6811,13 @@ device_check: NoCheck # TensorIterator variants: function, method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: trunc_sparse_ SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: trunc_sparse_csr_ +======= + SparseCPU, SparseCUDA: trunc_sparse_ + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: trunc_sparse_csr_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tags: pointwise - func: trunc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) @@ -6415,8 +6826,13 @@ device_check: NoCheck # TensorIterator dispatch: CPU, CUDA, MPS: trunc_out +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: trunc_sparse_out SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: trunc_sparse_csr_out +======= + SparseCPU, SparseCUDA: trunc_sparse_out + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: trunc_sparse_csr_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tags: pointwise # Alias for trunc @@ -6926,7 +7342,11 @@ variants: function, method dispatch: CompositeExplicitAutograd: clone +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: clone_sparse +======= + SparseCPU, SparseCUDA: clone_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: clone_sparse_compressed MkldnnCPU: mkldnn_clone QuantizedCPU, QuantizedCUDA: quantized_clone @@ -6961,7 +7381,11 @@ CPU, CUDA: zero_ MPS: zero_mps_ Meta: zero_meta_ +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS, SparseMeta: zero_sparse_ +======= + SparseCPU, SparseCUDA, SparseMeta: zero_sparse_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: zero_sparse_csr_ MkldnnCPU: mkldnn_zero_ NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: zero_nested_ @@ -6974,7 +7398,10 @@ dispatch: CPU, CUDA: sub_out MPS: sub_out_mps +<<<<<<< HEAD MTIA: sub_out_mtia +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCPU, SparseCUDA: sub_out_sparse tags: pointwise @@ -7032,7 +7459,11 @@ device_check: NoCheck # TensorIterator variants: function dispatch: +<<<<<<< HEAD CPU, CUDA, MPS, MTIA: rsub +======= + CPU, CUDA, MPS: rsub +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) autogen: rsub.Tensor_out - func: heaviside.out(Tensor self, Tensor values, *, Tensor(a!) out) -> Tensor(a!) @@ -7100,7 +7531,10 @@ CUDA: addmm_out_cuda MPS: addmm_out_mps XPU: addmm_out_xpu +<<<<<<< HEAD MTIA: addmm_out_mtia +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCPU: addmm_out_sparse_dense_cpu SparseCUDA: addmm_out_sparse_dense_cuda SparseCsrCPU: addmm_out_sparse_compressed_cpu @@ -7148,26 +7582,38 @@ dispatch: CPU: _scaled_mm_cpu CUDA: _scaled_mm_cuda +<<<<<<< HEAD tags: needs_exact_strides +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - func: _scaled_mm.out(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!) variants: function dispatch: CPU: _scaled_mm_out_cpu CUDA: _scaled_mm_out_cuda +<<<<<<< HEAD tags: needs_exact_strides +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - func: _scaled_grouped_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? offs=None, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor variants: function dispatch: CUDA: _scaled_grouped_mm_cuda +<<<<<<< HEAD tags: needs_exact_strides +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - func: _grouped_mm(Tensor self, Tensor mat2, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None) -> Tensor variants: function dispatch: +<<<<<<< HEAD CompositeExplicitAutograd: _grouped_mm +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CUDA: _grouped_mm_cuda # NOTE [ Sparse: autograd and API ] @@ -7334,26 +7780,42 @@ - func: _sparse_coo_tensor_with_dims(int sparse_dim, int dense_dim, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMeta, SparseMPS, Meta: new_with_dims_sparse +======= + SparseCPU, SparseCUDA, SparseMeta, Meta: new_with_dims_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) autogen: _sparse_coo_tensor_with_dims.out - func: _sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False, bool? is_coalesced=None) -> Tensor dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMeta, SparseMPS, Meta: new_with_dims_and_tensor_sparse_symint +======= + SparseCPU, SparseCUDA, SparseMeta, Meta: new_with_dims_and_tensor_sparse_symint +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) autogen: _sparse_coo_tensor_with_dims_and_tensors.out - func: sparse_resize_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!) use_const_ref_for_mutable_tensors: True variants: method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS, SparseMeta: sparse_resize_ +======= + SparseCPU, SparseCUDA, SparseMeta: sparse_resize_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) autogen: sparse_resize, sparse_resize.out - func: sparse_resize_and_clear_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!) use_const_ref_for_mutable_tensors: True variants: method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS, SparseMeta: sparse_resize_and_clear_ +======= + SparseCPU, SparseCUDA, SparseMeta: sparse_resize_and_clear_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) autogen: sparse_resize_and_clear, sparse_resize_and_clear.out - func: sparse_mask(Tensor self, Tensor mask) -> Tensor @@ -7379,8 +7841,13 @@ - func: _to_dense(Tensor self, ScalarType? dtype=None, bool? masked_grad=None) -> Tensor variants: method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: sparse_to_dense SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: sparse_compressed_to_dense +======= + SparseCPU, SparseCUDA: sparse_to_dense + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_compressed_to_dense +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MkldnnCPU: mkldnn_to_dense autogen: _to_dense.out @@ -7389,8 +7856,13 @@ - func: sparse_dim(Tensor self) -> int variants: method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS, SparseMeta: sparse_dim_sparse SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: sparse_dim_sparse_csr +======= + SparseCPU, SparseCUDA, SparseMeta: sparse_dim_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_dim_sparse_csr +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CompositeExplicitAutograd: sparse_dim_default device_check: NoCheck device_guard: False @@ -7406,8 +7878,13 @@ - func: dense_dim(Tensor self) -> int variants: method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS, SparseMeta: dense_dim_sparse SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: dense_dim_sparse_csr +======= + SparseCPU, SparseCUDA, SparseMeta: dense_dim_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: dense_dim_sparse_csr +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CompositeExplicitAutograd: dense_dim_default device_check: NoCheck device_guard: False @@ -7423,8 +7900,13 @@ - func: _nnz(Tensor self) -> int variants: method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS, SparseMeta: _nnz_sparse SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: _nnz_sparse_csr +======= + SparseCPU, SparseCUDA, SparseMeta: _nnz_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: _nnz_sparse_csr +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device_check: NoCheck device_guard: False @@ -7440,13 +7922,20 @@ dispatch: SparseCPU: _coalesce_sparse_cpu SparseCUDA: _coalesce_sparse_cuda +<<<<<<< HEAD SparseMPS: _coalesce_sparse_mps +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) autogen: _coalesce.out - func: is_coalesced(Tensor self) -> bool variants: method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS, SparseMeta: is_coalesced_sparse +======= + SparseCPU, SparseCUDA, SparseMeta: is_coalesced_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CompositeExplicitAutograd: is_coalesced_default device_check: NoCheck device_guard: False @@ -7454,14 +7943,22 @@ - func: _indices(Tensor(a) self) -> Tensor(a) variants: method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS, SparseMeta: _indices_sparse +======= + SparseCPU, SparseCUDA, SparseMeta: _indices_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device_check: NoCheck device_guard: False - func: _values(Tensor(a) self) -> Tensor(a) variants: method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS, SparseMeta: _values_sparse +======= + SparseCPU, SparseCUDA, SparseMeta: _values_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device_check: NoCheck device_guard: False @@ -7471,7 +7968,11 @@ - func: _coalesced_(Tensor(a!) self, bool coalesced) -> Tensor(a!) variants: method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS, SparseMeta: _coalesced_sparse_ +======= + SparseCPU, SparseCUDA, SparseMeta: _coalesced_sparse_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device_check: NoCheck device_guard: False autogen: _coalesced, _coalesced.out @@ -7479,7 +7980,11 @@ - func: indices(Tensor(a) self) -> Tensor(a) variants: method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS, SparseMeta: indices_sparse +======= + SparseCPU, SparseCUDA, SparseMeta: indices_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CompositeExplicitAutograd: indices_default device_check: NoCheck device_guard: False @@ -7487,7 +7992,11 @@ - func: values(Tensor(a) self) -> Tensor(a) variants: method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS, SparseMeta: values_sparse +======= + SparseCPU, SparseCUDA, SparseMeta: values_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: values_sparse_csr NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: values_nested CompositeExplicitAutograd: values_default @@ -7540,7 +8049,11 @@ device_check: NoCheck # Allows copy into different device variants: function dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS, SparseMeta: copy_sparse_ +======= + SparseCPU, SparseCUDA, SparseMeta: copy_sparse_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) autogen: copy_sparse_to_sparse, copy_sparse_to_sparse.out # By adding the AutogradNestedTensor this makes this function CompositeImplicit-like for nested tensors @@ -7560,9 +8073,15 @@ - func: _to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor variants: method dispatch: +<<<<<<< HEAD CPU, CUDA, MPS: dense_to_sparse SparseCPU, SparseCUDA, SparseMPS: sparse_coo_to_sparse SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta, SparseCsrMPS: sparse_compressed_to_sparse +======= + CPU, CUDA: dense_to_sparse + SparseCPU, SparseCUDA: sparse_coo_to_sparse + SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_compressed_to_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) autogen: _to_sparse.sparse_dim_out - func: to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor @@ -7572,8 +8091,13 @@ - func: _to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor variants: method dispatch: +<<<<<<< HEAD CPU, CUDA, MPS: dense_to_sparse SparseCPU, SparseCUDA, SparseMPS: sparse_coo_to_sparse +======= + CPU, CUDA: dense_to_sparse + SparseCPU, SparseCUDA: sparse_coo_to_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_compressed_to_sparse autogen: _to_sparse.out @@ -8946,7 +9470,11 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: +<<<<<<< HEAD CPU, CUDA, MTIA: ne_Scalar_out +======= + CPU, CUDA: ne_Scalar_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MPS: ne_scalar_out_mps QuantizedCPU: ne_out_quantized_cpu tags: pointwise @@ -8964,7 +9492,11 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: +<<<<<<< HEAD CPU, CUDA, MTIA: ne_Tensor_out +======= + CPU, CUDA: ne_Tensor_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MPS: ne_tensor_out_mps QuantizedCPU: ne_out_quantized_cpu tags: pointwise @@ -9009,7 +9541,11 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: +<<<<<<< HEAD CPU, CUDA, MTIA: eq_Scalar_out +======= + CPU, CUDA: eq_Scalar_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MPS: eq_scalar_out_mps QuantizedCPU: eq_out_quantized_cpu tags: pointwise @@ -9028,7 +9564,11 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: +<<<<<<< HEAD CPU, CUDA, MTIA: eq_Tensor_out +======= + CPU, CUDA: eq_Tensor_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MPS: eq_tensor_out_mps QuantizedCPU: eq_out_quantized_cpu tags: pointwise @@ -9047,7 +9587,11 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: +<<<<<<< HEAD CPU, CUDA, MTIA: ge_Scalar_out +======= + CPU, CUDA: ge_Scalar_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MPS: ge_scalar_out_mps QuantizedCPU: ge_out_quantized_cpu tags: pointwise @@ -9066,7 +9610,11 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: +<<<<<<< HEAD CPU, CUDA, MTIA: ge_Tensor_out +======= + CPU, CUDA: ge_Tensor_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MPS: ge_tensor_out_mps QuantizedCPU: ge_out_quantized_cpu tags: pointwise @@ -9111,7 +9659,11 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: +<<<<<<< HEAD CPU, CUDA, MTIA: le_Scalar_out +======= + CPU, CUDA: le_Scalar_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MPS: le_scalar_out_mps QuantizedCPU: le_out_quantized_cpu tags: pointwise @@ -9129,7 +9681,11 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: +<<<<<<< HEAD CPU, CUDA, MTIA: le_Tensor_out +======= + CPU, CUDA: le_Tensor_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MPS: le_tensor_out_mps QuantizedCPU: le_out_quantized_cpu tags: pointwise @@ -9174,7 +9730,11 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: +<<<<<<< HEAD CPU, CUDA,MTIA: gt_Scalar_out +======= + CPU, CUDA: gt_Scalar_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MPS: gt_scalar_out_mps QuantizedCPU: gt_out_quantized_cpu tags: pointwise @@ -9193,7 +9753,11 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: +<<<<<<< HEAD CPU, CUDA, MTIA: gt_Tensor_out +======= + CPU, CUDA: gt_Tensor_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MPS: gt_tensor_out_mps QuantizedCPU: gt_out_quantized_cpu tags: pointwise @@ -9421,7 +9985,11 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: +<<<<<<< HEAD CPU, CUDA, MTIA: addcmul_out +======= + CPU, CUDA: addcmul_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MPS: addcmul_out_mps tags: pointwise @@ -9442,7 +10010,11 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: +<<<<<<< HEAD CPU, CUDA, MTIA: addcdiv_out +======= + CPU, CUDA: addcdiv_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MPS: addcdiv_out_mps tags: pointwise @@ -9731,7 +10303,11 @@ structured_delegate: sign.out variants: function, method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: sign_sparse +======= + SparseCPU, SparseCUDA: sign_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sign_sparse_csr tags: [core, pointwise] @@ -9740,7 +10316,11 @@ structured_delegate: sign.out variants: method dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: sign_sparse_ +======= + SparseCPU, SparseCUDA: sign_sparse_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sign_sparse_csr_ tags: pointwise @@ -9751,7 +10331,11 @@ dispatch: CPU, CUDA: sign_out MPS: sign_out_mps +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: sign_sparse_out +======= + SparseCPU, SparseCUDA: sign_sparse_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sign_sparse_csr_out tags: pointwise @@ -9759,7 +10343,11 @@ variants: function, method structured_delegate: signbit.out dispatch: +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: signbit_sparse +======= + SparseCPU, SparseCUDA: signbit_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: signbit_sparse_csr tags: pointwise @@ -9770,7 +10358,11 @@ CPU: signbit_out CUDA: signbit_out MPS: signbit_out_mps +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: signbit_sparse_out +======= + SparseCPU, SparseCUDA: signbit_sparse_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: signbit_sparse_csr_out tags: pointwise @@ -9916,7 +10508,11 @@ structured: True structured_inherits: TensorIteratorBase dispatch: +<<<<<<< HEAD CPU, CUDA, MPS, MTIA: fmod_out +======= + CPU, CUDA, MPS: fmod_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tags: pointwise - func: fmod.Tensor(Tensor self, Tensor other) -> Tensor @@ -9953,7 +10549,11 @@ structured: True structured_inherits: TensorIteratorBase dispatch: +<<<<<<< HEAD CPU, CUDA, MPS: igamma_out +======= + CPU, CUDA: igamma_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tags: pointwise - func: igamma(Tensor self, Tensor other) -> Tensor @@ -9970,7 +10570,11 @@ structured: True structured_inherits: TensorIteratorBase dispatch: +<<<<<<< HEAD CPU, CUDA, MPS: igammac_out +======= + CPU, CUDA: igammac_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tags: pointwise - func: igammac(Tensor self, Tensor other) -> Tensor @@ -10516,7 +11120,10 @@ dispatch: CompositeExplicitAutograd: foreach_tensor_add_scalar_kernel_slow_ CUDA: foreach_tensor_add_scalar_kernel_cuda_ +<<<<<<< HEAD MTIA: foreach_tensor_add_scalar_kernel_mtia_ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) autogen: _foreach_add.Scalar_out - func: _foreach_add.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[] @@ -10525,7 +11132,10 @@ dispatch: CompositeExplicitAutograd: foreach_tensor_add_list_kernel_slow CUDA: foreach_tensor_add_list_kernel_cuda +<<<<<<< HEAD MTIA: foreach_tensor_add_list_kernel_mtia +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - func: _foreach_add_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10533,7 +11143,10 @@ dispatch: CompositeExplicitAutograd: foreach_tensor_add_list_kernel_slow_ CUDA: foreach_tensor_add_list_kernel_cuda_ +<<<<<<< HEAD MTIA: foreach_tensor_add_list_kernel_mtia_ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) autogen: _foreach_add.List_out - func: _foreach_add.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] @@ -10564,7 +11177,10 @@ dispatch: CompositeExplicitAutograd: foreach_tensor_add_tensor_kernel_slow_ CUDA: foreach_tensor_add_tensor_kernel_cuda_ +<<<<<<< HEAD MTIA: foreach_tensor_add_tensor_kernel_mtia_ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) autogen: _foreach_add.Tensor_out - func: _foreach_sub.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] @@ -10625,7 +11241,10 @@ dispatch: CompositeExplicitAutograd: foreach_tensor_mul_scalar_kernel_slow_ CUDA: foreach_tensor_mul_scalar_kernel_cuda_ +<<<<<<< HEAD MTIA: foreach_tensor_mul_scalar_kernel_mtia_ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) autogen: _foreach_mul.Scalar_out - func: _foreach_mul.List(Tensor[] self, Tensor[] other) -> Tensor[] @@ -10634,7 +11253,10 @@ dispatch: CompositeExplicitAutograd: foreach_tensor_mul_list_kernel_slow CUDA: foreach_tensor_mul_list_kernel_cuda +<<<<<<< HEAD MTIA: foreach_tensor_mul_list_kernel_mtia +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - func: _foreach_mul_.List(Tensor(a!)[] self, Tensor[] other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10642,7 +11264,10 @@ dispatch: CompositeExplicitAutograd: foreach_tensor_mul_list_kernel_slow_ CUDA: foreach_tensor_mul_list_kernel_cuda_ +<<<<<<< HEAD MTIA: foreach_tensor_mul_list_kernel_mtia_ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) autogen: _foreach_mul.List_out - func: _foreach_mul.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] @@ -10666,7 +11291,10 @@ dispatch: CompositeExplicitAutograd: foreach_tensor_mul_tensor_kernel_slow CUDA: foreach_tensor_mul_tensor_kernel_cuda +<<<<<<< HEAD MTIA: foreach_tensor_mul_tensor_kernel_mtia +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - func: _foreach_mul_.Tensor(Tensor(a!)[] self, Tensor other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10674,7 +11302,10 @@ dispatch: CompositeExplicitAutograd: foreach_tensor_mul_tensor_kernel_slow_ CUDA: foreach_tensor_mul_tensor_kernel_cuda_ +<<<<<<< HEAD MTIA: foreach_tensor_mul_tensor_kernel_mtia_ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) autogen: _foreach_mul.Tensor_out - func: _foreach_div.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] @@ -10971,7 +11602,10 @@ dispatch: CompositeExplicitAutograd: foreach_tensor_addcmul_scalar_slow CUDA: foreach_tensor_addcmul_scalar_cuda +<<<<<<< HEAD MTIA: foreach_tensor_addcmul_scalar_mtia +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - func: _foreach_addcmul.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10993,7 +11627,10 @@ dispatch: CompositeExplicitAutograd: foreach_tensor_addcmul_scalar_slow_ CUDA: foreach_tensor_addcmul_scalar_cuda_ +<<<<<<< HEAD MTIA: foreach_tensor_addcmul_scalar_mtia_ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) autogen: _foreach_addcmul.Scalar_out - func: _foreach_addcmul_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> () @@ -11018,7 +11655,10 @@ dispatch: CompositeExplicitAutograd: foreach_tensor_abs_slow CUDA: foreach_tensor_abs_cuda +<<<<<<< HEAD MTIA: foreach_tensor_abs_mtia +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - func: _foreach_abs_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11026,7 +11666,10 @@ dispatch: CompositeExplicitAutograd: foreach_tensor_abs_slow_ CUDA: foreach_tensor_abs_cuda_ +<<<<<<< HEAD MTIA: foreach_tensor_abs_mtia_ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) autogen: _foreach_abs.out - func: _foreach_acos(Tensor[] self) -> Tensor[] @@ -11361,7 +12004,10 @@ dispatch: CompositeExplicitAutograd: foreach_tensor_norm_slow CUDA: foreach_tensor_norm_cuda +<<<<<<< HEAD MTIA: foreach_tensor_norm_mtia +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) autogen: _foreach_norm.Scalar_out - func: _foreach_pow.List(Tensor[] self, Tensor[] exponent) -> Tensor[] @@ -11534,7 +12180,10 @@ dispatch: CompositeExplicitAutograd: foreach_tensor_sqrt_slow_ CUDA: foreach_tensor_sqrt_cuda_ +<<<<<<< HEAD MTIA: foreach_tensor_sqrt_mtia_ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) autogen: _foreach_sqrt.out - func: _foreach_tan(Tensor[] self) -> Tensor[] @@ -11596,7 +12245,10 @@ dispatch: CompositeExplicitAutograd: foreach_tensor_copy_list_kernel_slow_ CUDA: foreach_tensor_copy_list_kernel_cuda_ +<<<<<<< HEAD MTIA: foreach_tensor_copy_list_kernel_mtia_ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) autogen: _foreach_copy.out - func: _foreach_copy(Tensor[] self, Tensor[] src, bool non_blocking=False) -> Tensor[] self_out @@ -11604,7 +12256,10 @@ variants: function dispatch: CompositeExplicitAutograd: _foreach_copy +<<<<<<< HEAD MTIA: foreach_tensor_copy_list_kernel_mtia +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - func: bucketize.Tensor(Tensor self, Tensor boundaries, *, bool out_int32=False, bool right=False) -> Tensor dispatch: @@ -12380,7 +13035,10 @@ dispatch: CPU: avg_pool3d_out_cpu CUDA: avg_pool3d_out_cuda +<<<<<<< HEAD MPS: avg_pool3d_out_mps +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MkldnnCPU: mkldnn_avg_pool3d_out - func: avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor @@ -12397,7 +13055,10 @@ dispatch: CPU: avg_pool3d_backward_out_cpu CUDA: avg_pool3d_backward_out_cuda +<<<<<<< HEAD MPS: avg_pool3d_backward_out_mps +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MkldnnCPU: mkldnn_avg_pool3d_backward_out - func: avg_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor @@ -12493,7 +13154,10 @@ dispatch: CPU: max_pool3d_with_indices_out_cpu CUDA: max_pool3d_with_indices_out_cuda +<<<<<<< HEAD MPS: max_pool3d_with_indices_out_mps +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Return: (Tensor output, Tensor indices) - func: max_pool3d_with_indices(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) @@ -12501,7 +13165,10 @@ dispatch: CPU: max_pool3d_with_indices_cpu CUDA: max_pool3d_with_indices_cuda +<<<<<<< HEAD MPS: max_pool3d_with_indices_mps +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tags: core - func: max_pool3d_with_indices_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) @@ -12509,42 +13176,60 @@ dispatch: CPU: max_pool3d_with_indices_backward_out_cpu CUDA: max_pool3d_with_indices_backward_out_cuda +<<<<<<< HEAD MPS: max_pool3d_with_indices_backward_out_mps +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - func: max_pool3d_with_indices_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices) -> Tensor python_module: nn dispatch: CPU: max_pool3d_with_indices_backward_cpu CUDA: max_pool3d_with_indices_backward_cuda +<<<<<<< HEAD MPS: max_pool3d_with_indices_backward_mps +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - func: max_unpool2d.out(Tensor self, Tensor indices, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) python_module: nn dispatch: CPU: max_unpooling2d_forward_out_cpu CUDA: max_unpooling2d_forward_out_cuda +<<<<<<< HEAD MPS: max_unpooling2d_forward_out_mps +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - func: max_unpool2d(Tensor self, Tensor indices, SymInt[2] output_size) -> Tensor python_module: nn dispatch: CPU: max_unpooling2d_forward_cpu CUDA: max_unpooling2d_forward_cuda +<<<<<<< HEAD MPS: max_unpooling2d_forward_mps +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - func: max_unpool3d.out(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding, *, Tensor(a!) out) -> Tensor(a!) python_module: nn dispatch: CPU: max_unpooling3d_forward_out_cpu CUDA: max_unpooling3d_forward_out_cuda +<<<<<<< HEAD MPS: max_unpooling3d_forward_out_mps +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - func: max_unpool3d(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding) -> Tensor python_module: nn dispatch: CPU: max_unpooling3d_forward_cpu CUDA: max_unpooling3d_forward_cuda +<<<<<<< HEAD MPS: max_unpooling3d_forward_mps +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - func: reflection_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!) python_module: nn @@ -13274,7 +13959,11 @@ dispatch: CompositeExplicitAutograd: isinf NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_isinf +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: isinf_sparse +======= + SparseCPU, SparseCUDA: isinf_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseMeta: isinf_sparse_meta SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: isinf_sparse_csr autogen: isinf.out @@ -13290,7 +13979,11 @@ structured_delegate: isposinf.out dispatch: NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_isposinf +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: isposinf_sparse +======= + SparseCPU, SparseCUDA: isposinf_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: isposinf_sparse_csr tags: pointwise @@ -13299,7 +13992,11 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA, MPS: isposinf_out +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: isposinf_sparse_out +======= + SparseCPU, SparseCUDA: isposinf_sparse_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: isposinf_sparse_csr_out tags: pointwise @@ -13308,7 +14005,11 @@ structured_delegate: isneginf.out dispatch: NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_isneginf +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: isneginf_sparse +======= + SparseCPU, SparseCUDA: isneginf_sparse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: isneginf_sparse_csr tags: pointwise @@ -13317,7 +14018,11 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA, MPS: isneginf_out +<<<<<<< HEAD SparseCPU, SparseCUDA, SparseMPS: isneginf_sparse_out +======= + SparseCPU, SparseCUDA: isneginf_sparse_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: isneginf_sparse_csr_out tags: pointwise @@ -15030,7 +15735,10 @@ - func: _scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor) dispatch: CUDA: _scaled_dot_product_cudnn_attention_backward_cuda +<<<<<<< HEAD NestedTensorCUDA: _scaled_dot_product_cudnn_attention_nestedtensor_backward_cuda +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tags: nondeterministic_seeded - func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor rng_state, Tensor unused, Tensor debug_attn_mask) @@ -15063,11 +15771,14 @@ CUDA: _cudnn_attention_forward tags: nondeterministic_seeded +<<<<<<< HEAD - func: _cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor) dispatch: CUDA: _cudnn_attention_backward tags: nondeterministic_seeded +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - func: _triton_scaled_dot_attention(Tensor q, Tensor k, Tensor v, float dropout_p=0.0) -> Tensor variants: function dispatch: @@ -15670,7 +16381,11 @@ - func: special_shifted_chebyshev_polynomial_t.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck dispatch: +<<<<<<< HEAD CPU, CUDA, MPS: special_shifted_chebyshev_polynomial_t_out +======= + CPU, CUDA: special_shifted_chebyshev_polynomial_t_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) python_module: special structured_inherits: TensorIteratorBase structured: True @@ -15719,7 +16434,11 @@ - func: special_shifted_chebyshev_polynomial_u.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck dispatch: +<<<<<<< HEAD CPU, CUDA, MPS: special_shifted_chebyshev_polynomial_u_out +======= + CPU, CUDA: special_shifted_chebyshev_polynomial_u_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) python_module: special structured_inherits: TensorIteratorBase structured: True @@ -15768,7 +16487,11 @@ - func: special_shifted_chebyshev_polynomial_v.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck dispatch: +<<<<<<< HEAD CPU, CUDA, MPS: special_shifted_chebyshev_polynomial_v_out +======= + CPU, CUDA: special_shifted_chebyshev_polynomial_v_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) python_module: special structured_inherits: TensorIteratorBase structured: True @@ -15817,7 +16540,11 @@ - func: special_shifted_chebyshev_polynomial_w.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck dispatch: +<<<<<<< HEAD CPU, CUDA, MPS: special_shifted_chebyshev_polynomial_w_out +======= + CPU, CUDA: special_shifted_chebyshev_polynomial_w_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) python_module: special structured_inherits: TensorIteratorBase structured: True @@ -15926,7 +16653,10 @@ variants: function dispatch: CPU: _fused_adagrad_kernel_cpu_ +<<<<<<< HEAD CUDA: _fused_adagrad_kernel_cuda_ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) autogen: _fused_adagrad, _fused_adagrad.out - func: _fused_adagrad_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] state_sums, Tensor[] state_steps, *, Tensor lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> () @@ -15934,7 +16664,10 @@ variants: function dispatch: CPU: _fused_adagrad_kernel_cpu_ +<<<<<<< HEAD CUDA: _fused_adagrad_kernel_cuda_ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) autogen: _fused_adagrad.tensor_lr, _fused_adagrad.tensor_lr_out # This op is ONLY used by pytorch/XLA in functionalization, and should never show up in vanilla eager mode or in any pytorch tracing contexts. diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp index 96c6ab8310f80..fa47c12907b9c 100644 --- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp +++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp @@ -349,6 +349,7 @@ _scaled_dot_product_cudnn_attention_nestedtensor_cuda( return std::make_tuple(std::move(attention), std::move(log_sumexp), cumulative_sequence_length_q, cumulative_sequence_length_kv, max_seqlen_batch_q, max_seqlen_batch_kv, std::move(cudnn_seed), std::move(cudnn_offset), Tensor()); } +<<<<<<< HEAD std::tuple _scaled_dot_product_cudnn_attention_nestedtensor_backward_cuda( const Tensor& grad_out, const Tensor& query, @@ -406,6 +407,8 @@ std::tuple _scaled_dot_product_cudnn_attention_nestedten } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::tuple _scaled_dot_product_flash_attention_backward_nested( const at::Tensor& grad_out_, const at::Tensor& query, diff --git a/aten/src/ATen/native/quantized/QTensor.cpp b/aten/src/ATen/native/quantized/QTensor.cpp index f804670c31538..37c3b3a17f381 100644 --- a/aten/src/ATen/native/quantized/QTensor.cpp +++ b/aten/src/ATen/native/quantized/QTensor.cpp @@ -335,8 +335,11 @@ std::tuple choose_qparams_optimized( const int64_t n_bins, const double ratio, int64_t bit_width) { +<<<<<<< HEAD const float* input_row = input_tensor.const_data_ptr(); TORCH_CHECK_VALUE(input_row != nullptr, "input tensor is empty and has no data"); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (numel < 0 || numel > input_tensor.numel()) { TORCH_CHECK(false, "numel is out of the bound of input tensor"); @@ -344,7 +347,11 @@ std::tuple choose_qparams_optimized( TORCH_CHECK(numel <= input_tensor.numel(), "numel ", numel, " greater than input_tensor.numel() ", input_tensor.numel()); +<<<<<<< HEAD +======= + const float* input_row = input_tensor.const_data_ptr(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) float xmin = *std::min_element(input_row, input_row + numel); float xmax = *std::max_element(input_row, input_row + numel); float n_bins_float = static_cast(n_bins); diff --git a/aten/src/ATen/native/quantized/cpu/ACLUtils.cpp b/aten/src/ATen/native/quantized/cpu/ACLUtils.cpp index c689132c7692e..546de933a5f86 100644 --- a/aten/src/ATen/native/quantized/cpu/ACLUtils.cpp +++ b/aten/src/ATen/native/quantized/cpu/ACLUtils.cpp @@ -81,7 +81,11 @@ DynamicQuantMatmul::DynamicQuantMatmul( auto src_q_tensor_info = arm_compute::TensorInfo( arm_compute::TensorShape(weight_dim_0, m), 1, +<<<<<<< HEAD // ACL dynamically quantized matmuls only support (signed) int8_t +======= + // ACL dyanamically quantized matmuls only support (signed) int8_t +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) arm_compute::DataType::QASYMM8_SIGNED, // TODO: setting the initial offset value to int8_t max instead of zero, // because ACL currently skips MatrixBReduction calculation if the diff --git a/aten/src/ATen/native/quantized/cpu/OnednnUtils.h b/aten/src/ATen/native/quantized/cpu/OnednnUtils.h index 963a47a21fa9f..122f6f658889f 100644 --- a/aten/src/ATen/native/quantized/cpu/OnednnUtils.h +++ b/aten/src/ATen/native/quantized/cpu/OnednnUtils.h @@ -460,6 +460,9 @@ at::Tensor _qconv_prepack_onednn( int64_t groups, std::optional> input_shape=std::nullopt); +<<<<<<< HEAD #define FP8E4M3_MAX 448.0 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif // #if AT_MKLDNN_ENABLED() diff --git a/aten/src/ATen/native/quantized/cpu/QnnpackUtils.h b/aten/src/ATen/native/quantized/cpu/QnnpackUtils.h index 764d237e68b4c..8a61a578bf2f1 100644 --- a/aten/src/ATen/native/quantized/cpu/QnnpackUtils.h +++ b/aten/src/ATen/native/quantized/cpu/QnnpackUtils.h @@ -456,7 +456,11 @@ make_zero_points_and_scales_tensor( uint32_t groups = 1) { const int out_ch_idx = transpose ? 1 : 0; const auto num_output_channels = weight_contig.size(out_ch_idx) * (transpose ? groups : 1); +<<<<<<< HEAD // Add 8 to account for buffering needed by QNNPACK. +======= + // Add 8 to account for bufferring needed by QNNPACK. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const auto num_output_channels_padded = num_output_channels + kPaddingChannels; const auto qtype = weight_contig.qscheme(); std::vector weight_zp(num_output_channels_padded, 0); diff --git a/aten/src/ATen/native/quantized/cpu/UpSampleNearest2d.cpp b/aten/src/ATen/native/quantized/cpu/UpSampleNearest2d.cpp index 42c000ee09d5c..091cbe5b90427 100644 --- a/aten/src/ATen/native/quantized/cpu/UpSampleNearest2d.cpp +++ b/aten/src/ATen/native/quantized/cpu/UpSampleNearest2d.cpp @@ -17,7 +17,10 @@ #include #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace at::native { @@ -54,8 +57,13 @@ static void upsample_nearest2d_out_frame( return; } +<<<<<<< HEAD std::vector input_offset_arr(output_width); int64_t* input_offset = input_offset_arr.data(); +======= + std::unique_ptr input_offset_arr(new int64_t[output_width]); + int64_t* input_offset = input_offset_arr.get(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (const auto w2 : c10::irange(output_width)) { const int64_t w1 = nn_compute_source_index_fn(width_scale, w2, input_width); diff --git a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp index 0919acd21deb5..39bd5cffccd87 100644 --- a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp +++ b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp @@ -83,8 +83,15 @@ void CopyICFirst3dTensorToChannelsLast3dTensor( for (int64_t i = 0; i < G * OC_G; ++i) { for (const auto j : c10::irange(inner_size)) { for (const auto ic : c10::irange(IC_G)) { +<<<<<<< HEAD int g = static_cast(i / OC_G); int oc = static_cast(i % OC_G); +======= + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + int g = i / OC_G; + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + int oc = i % OC_G; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dst[(i * inner_size + j) * IC_G + ic] = src[((g * IC_G + ic) * OC_G + oc) * inner_size + j]; } @@ -110,6 +117,7 @@ fbgemm::conv_param_t MakeFbgemmConvParam( std::array image_shape_{}; std::array kernels_{}; std::array strides_{}; +<<<<<<< HEAD std::array pads_{}; std::array dilations_{}; std::array output_padding_{}; @@ -130,6 +138,26 @@ fbgemm::conv_param_t MakeFbgemmConvParam( std::copy(pads.begin(), pads.begin() + static_cast(pads.size()), pads_.begin()); const auto pads_size = static_cast(pads.size()); std::move(pads.begin(), pads.begin() + pads_size, pads_.begin() + pads_size); +======= + std::array pads_{}; + std::array dilations_{}; + std::array output_padding_{}; + std::move(image_shape.begin(), image_shape.begin() + image_shape.size(), image_shape_.begin()); + std::move( + kernels.begin(), kernels.begin() + kernels.size(), kernels_.begin()); + std::move( + strides.begin(), strides.begin() + strides.size(), strides_.begin()); + std::move( + dilations.begin(), + dilations.begin() + dilations.size(), + dilations_.begin()); + std::move( + output_padding.begin(), + output_padding.begin() + output_padding.size(), + output_padding_.begin()); + std::copy(pads.begin(), pads.begin() + pads.size(), pads_.begin()); + std::move(pads.begin(), pads.begin() + pads.size(), pads_.begin() + pads.size()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return fbgemm::conv_param_t( N, // batch size @@ -158,7 +186,11 @@ Tensor MakeStridedQTensorCPU( TORCH_CHECK( isQIntType(typeMetaToScalarType(dtype)), "ScalarType is not supported in new_qtensor_cpu."); +<<<<<<< HEAD int64_t size_bytes = static_cast(nelements * dtype.itemsize()); +======= + int64_t size_bytes = nelements * dtype.itemsize(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto storage = c10::make_intrusive( StorageImpl::use_byte_size_t(), size_bytes, @@ -366,7 +398,11 @@ Tensor ConvertConvWeightsToChannelLastTensor<3>( #endif // USE_FBGEMM namespace { +<<<<<<< HEAD // This is really terrible, but couldn't figure out a better way to constexpr convert int to +======= + // This is really terrible, but couldnt figure out a better way to constexpr convert int to +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // string and then perform string concatenation on/with it constexpr const char* _hack_int_to_class_name(int x) { switch(x) { @@ -531,8 +567,13 @@ int register_embedding_params() { TORCH_INTERNAL_ASSERT(longs.size() == 1, "EmbeddingPackedParams: Expected bit_rate to be serialized"); TORCH_CHECK(version == 1, "EmbeddingPackedParams: Currently only version 1 supported."); +<<<<<<< HEAD const auto& weight = tensors[0]; return PackedEmbeddingBagWeight::prepack(weight); +======= + at::Tensor weight = std::move(tensors[0]); + return PackedEmbeddingBagWeight::prepack(std::move(weight)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }) .def("bit_rate", &EmbeddingPackedParamsBase::bit_rate) .def("unpack", &EmbeddingPackedParamsBase::unpack) diff --git a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h index a1139be833f87..ef0947ef34c5a 100644 --- a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h +++ b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h @@ -7,13 +7,19 @@ #include #ifdef USE_FBGEMM +<<<<<<< HEAD C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Winconsistent-missing-destructor-override") #include C10_DIAGNOSTIC_POP() #include +<<<<<<< HEAD C10_DIAGNOSTIC_POP() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // The struct for the packed weight matrix (PackBMatrix) and the corresponding // column offsets used for the fully connect layer, which are both prepared in @@ -380,7 +386,11 @@ struct TORCH_API PackedEmbeddingBagWeight : public EmbeddingPackedParamsBase { at::Tensor unpack() override; static c10::intrusive_ptr prepack( +<<<<<<< HEAD const at::Tensor& weight); +======= + at::Tensor weight); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int64_t bit_rate() const override { return bit_rate_; diff --git a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp index b5b887b98bb08..55aa754ad04f0 100644 --- a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp +++ b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp @@ -2696,11 +2696,18 @@ void _fake_quantize_tensor_helper( bool* mask_val = (bool*)(data[1] + i * strides[1]); scalar_t* input_val = (scalar_t*)(data[2] + i * strides[2]); +<<<<<<< HEAD if (fake_quant_on) { auto qval_f = z_point + std::nearbyint(*input_val * inv_scale); const auto qval = static_cast(std::fmin(std::fmax(qval_f, quant_min), quant_max)); *output_val = (qval - z_point) * sc; *mask_val = ((quant_min <= qval_f) && (qval_f <= quant_max)); +======= + const auto qval = static_cast(z_point + std::nearbyint(*input_val * inv_scale)); + if (fake_quant_on) { + *output_val = (std::fmin(std::fmax(qval, quant_min), quant_max) - z_point) * sc; + *mask_val = ((quant_min <= qval) && (qval <= quant_max)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else { *output_val = *input_val; *mask_val = 1; @@ -2716,11 +2723,18 @@ void _fake_quantize_tensor_helper( bool* mask_val = (bool*)(data[1] + i * strides[1]); scalar_t* input_val = (scalar_t*)(data[2] + i * strides[2]); +<<<<<<< HEAD if (fake_quant_on) { auto qval_f = z_point + std::nearbyint(*input_val * inv_scale); const auto qval = static_cast(std::fmin(std::fmax(qval_f, quant_min), quant_max)); *output_val = (qval - z_point) * sc; *mask_val = ((quant_min <= qval_f) && (qval_f <= quant_max)); +======= + const auto qval = static_cast(z_point + std::nearbyint(*input_val * inv_scale)); + if (fake_quant_on) { + *output_val = (std::fmin(std::fmax(qval, quant_min), quant_max) - z_point) * sc; + *mask_val = ((quant_min <= qval) && (qval <= quant_max)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else { *output_val = *input_val; *mask_val = 1; diff --git a/aten/src/ATen/native/quantized/cpu/qconv.cpp b/aten/src/ATen/native/quantized/cpu/qconv.cpp index 3b50bad579023..52ab5fd13ab4d 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv.cpp @@ -34,6 +34,7 @@ #include #include #include +<<<<<<< HEAD #include #include #include @@ -43,6 +44,8 @@ #include #include #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif #include @@ -1277,7 +1280,11 @@ at::Tensor PackedConvWeightsOnednn::apply_impl( float sum_scale = has_accum ? accum.value().q_scale() : 1.0; int32_t sum_zero_point = has_accum ? accum.value().q_zero_point() : 0; if (has_accum) { +<<<<<<< HEAD // Just tells we have these post op, the actual value such as scale and zero point will be set later. +======= + // Just tells we have these post op, the actual value such as scale and zero point will be setted later. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) op_attr = kReluFused ? ideep::attr_t::residual_with_sum_zero_point() : ideep::attr_t::fuse_sum(); const ideep::scale_t accum_scale = ideep::scale_t(1, 1.0/sum_scale); const ideep::zero_point_t accum_zero_points = ideep::zero_point_t(1, sum_zero_point); @@ -1393,6 +1400,7 @@ template at::Tensor PackedConvWeightsOnednn<3>::apply_relu( double output_scale, int64_t output_zero_point); +<<<<<<< HEAD static at::Tensor _fp8_convolution_onednn_ref( at::Tensor act, // contains quantized values but not QTensor double act_scale, @@ -1507,6 +1515,8 @@ static at::Tensor _fp8_convolution_onednn_ref( return y_f32.to(out_dtype); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static at::Tensor _quantized_convolution_onednn( at::Tensor act, // contains quantized values but not QTensor double act_scale, @@ -1531,7 +1541,10 @@ static at::Tensor _quantized_convolution_onednn( std::optional unary_attr, torch::List> unary_scalars, std::optional unary_algorithm) { +<<<<<<< HEAD using ideep::tensor; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) /*********************************/ /* Checks */ /*********************************/ @@ -1588,6 +1601,13 @@ static at::Tensor _quantized_convolution_onednn( if (kSpatialDim == 1) { kSpatialDim += 1; } +<<<<<<< HEAD +======= + TORCH_CHECK( + weight.is_mkldnn(), + func_name, ": Weight should be prepacked as an MKLDNN tensor" + ); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (transposed) { TORCH_CHECK( false, @@ -1601,6 +1621,7 @@ static at::Tensor _quantized_convolution_onednn( padding = quant_utils::MakeArgForConv1d(padding, 0); dilation = quant_utils::MakeArgForConv1d(dilation, 1); } +<<<<<<< HEAD auto act_dtype = act.scalar_type(); TORCH_CHECK( act_dtype == c10::ScalarType::Byte || act_dtype == c10::ScalarType::Float8_e4m3fn, @@ -1608,6 +1629,14 @@ static at::Tensor _quantized_convolution_onednn( TORCH_CHECK( weight.scalar_type() == c10::ScalarType::Char || weight.scalar_type() == c10::ScalarType::Float8_e4m3fn, func_name, ": Weight tensor should have int8 (char) or fp8 data type"); +======= + TORCH_CHECK( + act.scalar_type() == c10::ScalarType::Byte, + func_name, ": Input tensor should have uint8 (unsigned char) data type"); + TORCH_CHECK( + weight.scalar_type() == c10::ScalarType::Char, + func_name, ": Weight tensor should have int8 (char) data type"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_CHECK( weight.ndimension() == kSpatialDim + 2, func_name, ": Weights are expected to have ", kSpatialDim + 2, " dimensions"); @@ -1623,6 +1652,7 @@ static at::Tensor _quantized_convolution_onednn( dilation.size() == (decltype(dilation.size()))kSpatialDim, func_name, ": dilation should contain ", kSpatialDim, " elements for ", kSpatialDim, "D convolution."); +<<<<<<< HEAD bool is_fp8 = weight.scalar_type() == c10::ScalarType::Float8_e4m3fn; if (is_fp8) { TORCH_CHECK(act_dtype == c10::ScalarType::Float8_e4m3fn, @@ -1647,6 +1677,8 @@ static at::Tensor _quantized_convolution_onednn( weight.is_mkldnn(), func_name, ": Weight should be prepacked as an MKLDNN tensor" ); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Parameters #if IDEEP_PREREQ(3, 1, 0, 1) @@ -1722,7 +1754,11 @@ static at::Tensor _quantized_convolution_onednn( c10::MemoryFormat::ChannelsLast : c10::MemoryFormat::ChannelsLast3d); auto src_dims = act_contig.sizes().vec(); +<<<<<<< HEAD auto src_data_type = at::native::get_mkldnn_dtype(act.scalar_type()); +======= + auto src_data_type = dnnl::memory::data_type::u8; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto src_desc = ideep::tensor::desc(src_dims, src_data_type, kSpatialDim == 2 ? ideep::format_tag::nhwc : ideep::format_tag::ndhwc); ideep::tensor src; @@ -1734,13 +1770,20 @@ static at::Tensor _quantized_convolution_onednn( output_sizes = at::native::conv_output_size(input_size, kernel_size, padding.vec(), stride.vec(), dilation.vec()); ideep::dims dst_dims = ideep::dims({output_sizes.cbegin(), output_sizes.cend()}); // Output is not a quantized tensor but data type is uint8 +<<<<<<< HEAD auto out_dtype = output_dtype.has_value() ? output_dtype.value() : act_dtype; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) at::Tensor output = has_accum_postop_sum ? accum.value() : at::empty( dst_dims, at::device(c10::kCPU) +<<<<<<< HEAD .dtype(out_dtype) +======= + .dtype(fp32_output ? c10::kFloat : (bfloat16_output ? c10::kBFloat16 : c10::kByte)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .memory_format(kSpatialDim == 2 ? c10::MemoryFormat::ChannelsLast : c10::MemoryFormat::ChannelsLast3d) @@ -1760,6 +1803,7 @@ static at::Tensor _quantized_convolution_onednn( unary_scalars, unary_algorithm.has_value() ? unary_algorithm.value() : "" ); +<<<<<<< HEAD // Avoid NaN if output dtype is fp8 if (out_dtype == c10::kFloat8_e4m3fn) { // To avoid NaN, we need to clamp the intermediate results (in fp32) to [-488, 488] @@ -1770,13 +1814,21 @@ static at::Tensor _quantized_convolution_onednn( op_attr.set_post_ops(post_ops); output_scale = 1.0f; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #if IDEEP_PREREQ(3, 1, 0, 0) // Use oneDNN's APIs instead of prepare/compute from ideep to reduce integration overhead. // The functions from ideep are heavy because they have complex data structures for unified API // oneDNN version >= 3.1.0 is required. +<<<<<<< HEAD auto weight_grouped = packed_weight.make_grouped_weights(groups, /* is_deconv */false); auto weights_desc = tensor::desc(weight_grouped.get_dims(), packed_weight.get_data_type(), ideep::format_tag::any); +======= + using ideep::tensor; + auto weight_grouped = packed_weight.make_grouped_weights(groups, /* is_deconv */false); + auto weights_desc = tensor::desc(weight_grouped.get_dims(), ideep::data_type::s8, ideep::format_tag::any); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (groups > 1) { weights_desc = weights_desc.to_grouped(groups); } @@ -2245,6 +2297,7 @@ TORCH_LIBRARY_IMPL(onednn, MkldnnCPU, m) { m.impl(TORCH_SELECTIVE_NAME("onednn::qconv2d_pointwise.binary_tensor"), at::native::QConvoneDNN::run_pointwise_binary_tensor); } +<<<<<<< HEAD TORCH_LIBRARY_IMPL(onednn, CPU, m) { m.impl(TORCH_SELECTIVE_NAME("onednn::qconv_pointwise"), at::native::QConvoneDNN::run_pointwise); m.impl(TORCH_SELECTIVE_NAME("onednn::qconv_pointwise.tensor"), at::native::QConvoneDNN::run_pointwise_tensor); @@ -2252,5 +2305,7 @@ TORCH_LIBRARY_IMPL(onednn, CPU, m) { m.impl(TORCH_SELECTIVE_NAME("onednn::qconv2d_pointwise.binary_tensor"), at::native::QConvoneDNN::run_pointwise_binary_tensor); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace } // namespace at::native diff --git a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp index e9043f06b3018..9c764c1dfd9e6 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp @@ -519,10 +519,13 @@ at::Tensor _qconv_prepack_onednn( dilation.size() == (decltype(dilation.size()))kSpatialDim, "dilation should contain ", kSpatialDim, " elements for ", kSpatialDim, "D convolution."); +<<<<<<< HEAD TORCH_CHECK( weight.scalar_type() == at::kChar || weight.scalar_type() == at::kFloat8_e4m3fn, "Weight should have dtype int8 or fp8_e4m3fn but got ", weight.scalar_type()); bool is_fp8 = weight.scalar_type() == at::kFloat8_e4m3fn; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool is_1d = (1 == kSpatialDim); auto x_dims = input_shape.has_value()?input_shape.value().vec():ideep::dims(); @@ -539,12 +542,15 @@ at::Tensor _qconv_prepack_onednn( dilation = quant_utils::MakeArgForConv1d(dilation, 1); kSpatialDim += 1; } +<<<<<<< HEAD if (is_fp8) { // The current version of oneDNN does not support fp8 conv // TODO(weiwen) Remove this when oneDNN supports fp8 conv // FP8 convolution is not supported by oneDNN until v3.9 return weight; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto w_dims = weight.sizes().vec(); auto strides = stride.vec(); auto padding_l = padding.vec(); @@ -591,6 +597,7 @@ at::Tensor _qconv_prepack_onednn( ideep::dims dims_iohw, dims_giohw; ideep::tag w_tag = ideep::tag::any; const bool with_groups = groups > 1; +<<<<<<< HEAD auto w_dnnl_dtype = at::native::get_mkldnn_dtype(weight.scalar_type()); auto x_dnnl_dtype = is_fp8 ? dnnl::memory::data_type::f8_e4m3 : dnnl::memory::data_type::u8; w_desc = ideep::convolution_forward::expected_weights_desc( @@ -598,6 +605,13 @@ at::Tensor _qconv_prepack_onednn( strides, padding_l, padding_r, dilates, groups, dnnl::algorithm::convolution_direct, dnnl::prop_kind::forward_inference, x_dnnl_dtype, x_dims, op_attr, /*is_channels_last=*/true); +======= + w_desc = ideep::convolution_forward::expected_weights_desc( + w_dims, dnnl::memory::data_type::s8, + strides, padding_l, padding_r, dilates, groups, + dnnl::algorithm::convolution_direct, dnnl::prop_kind::forward_inference, + dnnl::memory::data_type::u8, x_dims, op_attr, /*is_channels_last=*/true); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Note: Weight in Conv1D will unsqueeze into Conv2D in previous step weight_copy = weight.clone(c10::MemoryFormat::Contiguous); @@ -610,7 +624,11 @@ at::Tensor _qconv_prepack_onednn( ideep::dims wei_dims = with_groups ? ideep::utils::group_dims(w_desc.get_dims(), groups) : w_desc.get_dims(); ideep::tensor wgt = ideep::tensor( +<<<<<<< HEAD ideep::tensor::desc({wei_dims, w_dnnl_dtype, w_tag}, groups), +======= + ideep::tensor::desc({wei_dims, dnnl::memory::data_type::s8, w_tag}, groups), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) weight_copy.data_ptr()); wgt.set_scale(weights_scales); // Scales are needed for feed_from(). diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp index 807a9b25d3772..c7a2bb8a1edfa 100644 --- a/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp @@ -33,7 +33,11 @@ * for each row along with the quantized weights. */ c10::intrusive_ptr PackedEmbeddingBagWeight::prepack( +<<<<<<< HEAD const at::Tensor& qweight) { +======= + at::Tensor qweight) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static constexpr int64_t version = 1; TORCH_CHECK( qweight.dim() == 2, @@ -67,8 +71,13 @@ c10::intrusive_ptr PackedEmbeddingBagWeight::prepack( "Expect embedding_bag weights to be quantized using kPerChannelAffineFloatQParams"); std::vector weight_bias(embedding_rows); +<<<<<<< HEAD const auto& channel_scales = qweight.q_per_channel_scales(); const auto& channel_zero_points = qweight.q_per_channel_zero_points(); +======= + at::Tensor channel_scales = qweight.q_per_channel_scales(); + at::Tensor channel_zero_points = qweight.q_per_channel_zero_points(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::vector weight_scales( channel_scales.data_ptr(), channel_scales.data_ptr() + embedding_rows); @@ -77,11 +86,14 @@ c10::intrusive_ptr PackedEmbeddingBagWeight::prepack( channel_zero_points.data_ptr() + embedding_rows); for (const auto i : c10::irange(embedding_rows)) { +<<<<<<< HEAD // As of now weight_zero_points and weight_scales are initialized with // the size of embedding_rows. Hence, this linter is a false positive. // However, if this assumption changes in the future, we need to // ensure that the bounds are checked. // NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) weight_bias[i] = weight_zero_points[i] * weight_scales[i] * -1; } @@ -242,16 +254,27 @@ Tensor& qembeddingbag_byte_prepack_out(Tensor& output, const Tensor& weight) { const auto weight_sizes = weight.sizes(); const auto cols_dim = weight_sizes.size() - 1; +<<<<<<< HEAD const int64_t embedding_rows = c10::size_to_dim_(static_cast(cols_dim), weight_sizes); const int32_t embedding_cols = static_cast(weight_sizes[cols_dim]); // Add 8 bytes per column to store FP32 scale and zero_point per row. const int32_t output_columns = static_cast(embedding_cols + 2 * sizeof(float)); +======= + const int64_t embedding_rows = c10::size_to_dim_(cols_dim, weight_sizes); + const int32_t embedding_cols = weight_sizes[cols_dim]; + // Add 8 bytes per column to store FP32 scale and zero_point per row. + const int32_t output_columns = embedding_cols + 2 * sizeof(float); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const auto weight_contig = weight.expect_contiguous(weight.suggest_memory_format()); // Adjust output dimensions to account for FP32 scale and zero_points. std::vector output_shape = weight_sizes.vec(); +<<<<<<< HEAD output_shape.at(cols_dim) = output_columns; +======= + output_shape[cols_dim] = output_columns; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) at::native::resize_(output, output_shape, std::nullopt); auto* output_data = output.data_ptr(); @@ -333,6 +356,7 @@ Tensor qembeddingbag_byte_prepack_meta(const Tensor& weight) { weight.scalar_type() == at::ScalarType::Float || weight.scalar_type() == at::ScalarType::Half, "'embedding_bag_byte_prepack' only support float32 or float16."); +<<<<<<< HEAD const auto weight_sizes = weight.sym_sizes(); const auto cols_dim = weight.ndimension() - 1; const auto embedding_cols = weight_sizes[cols_dim]; @@ -342,6 +366,17 @@ Tensor qembeddingbag_byte_prepack_meta(const Tensor& weight) { // Adjust output dimensions to account for FP32 scale and zero_points. auto output_shape = weight_sizes.vec(); output_shape.at(cols_dim) = output_columns; +======= + const auto weight_sizes = weight.sizes(); + const auto cols_dim = weight_sizes.size() - 1; + const int32_t embedding_cols = weight_sizes[cols_dim]; + // Add 8 bytes per column to store FP32 scale and zero_point per row. + const int32_t output_columns = embedding_cols + 2 * sizeof(float); + + // Adjust output dimensions to account for FP32 scale and zero_points. + std::vector output_shape = weight_sizes.vec(); + output_shape[cols_dim] = output_columns; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) at::SymDimVector output_shape_vec(output_shape); return at::empty_symint( @@ -412,7 +447,11 @@ Tensor _qembeddingbag_nbit_prepack_helper( bit_width, weight_data + start_idx * embedding_cols, end_idx - start_idx, +<<<<<<< HEAD static_cast(embedding_cols), +======= + embedding_cols, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) output_data + start_idx * output_shape[1]); }); } else { @@ -423,7 +462,11 @@ Tensor _qembeddingbag_nbit_prepack_helper( bit_width, weight_data + start_idx * embedding_cols, end_idx - start_idx, +<<<<<<< HEAD static_cast(embedding_cols), +======= + embedding_cols, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) output_data + start_idx * output_shape[1]); }); } @@ -480,7 +523,11 @@ Tensor _qembeddingbag_nbit_prepack_helper( std::uint8_t quantized = std::max( 0, std::min( +<<<<<<< HEAD static_cast(lrintf((X - Xmin) * inverse_scale)), (1 << bit_width) - 1)); +======= + lrintf((X - Xmin) * inverse_scale), (1 << bit_width) - 1)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // We pack 2 4-bit values in a byte. Index 0 is packed in the lower // 4-bits and index 1 is packed in the upper 4-bits. if (col % NUM_ELEM_PER_BYTE == 0) { @@ -533,8 +580,13 @@ Tensor qembeddingbag_2bit_prepack( class QEmbeddingPackWeights final { public: +<<<<<<< HEAD static c10::intrusive_ptr run(const at::Tensor& weight) { return PackedEmbeddingBagWeight::prepack(weight); +======= + static c10::intrusive_ptr run(at::Tensor weight) { + return PackedEmbeddingBagWeight::prepack(std::move(weight)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } }; diff --git a/aten/src/ATen/native/quantized/cpu/qlinear.cpp b/aten/src/ATen/native/quantized/cpu/qlinear.cpp index a3a494d16fd69..e1867cd8b6fcf 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear.cpp @@ -3,7 +3,10 @@ #include #include #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include @@ -28,6 +31,7 @@ #include // for quantize_per_te... #include #include +<<<<<<< HEAD #include #include #include @@ -36,6 +40,8 @@ #include #include #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif #include @@ -927,6 +933,7 @@ at::Tensor PackedLinearWeightsOnednn:: apply_tanh( std::move(input), output_scale, output_zero_point); } +<<<<<<< HEAD static at::Tensor fp8_qlinear_onednn_ref( at::Tensor input, double input_scale, @@ -1053,6 +1060,8 @@ static at::Tensor fp8_qlinear_onednn_ref( return y_f32.to(out_dtype); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static at::Tensor linear_int8_with_onednn_weight( at::Tensor input, // int8 CPU Tensor, not QTensor double input_scale, @@ -1074,6 +1083,7 @@ static at::Tensor linear_int8_with_onednn_weight( std::string_view& unary_post_op_algorithm) { using ideep::tensor; const int64_t dim = input.dim(); +<<<<<<< HEAD TORCH_CHECK(input.scalar_type() == c10::ScalarType::Byte || input.scalar_type() == c10::ScalarType::Char || input.scalar_type() == c10::ScalarType::Float8_e4m3fn, "qlinear with mkldnn tensor: data type of input should be uint8, int8 or float8_e4m3fn."); TORCH_CHECK(onednn_weight.scalar_type() == c10::ScalarType::Char || onednn_weight.scalar_type() == c10::ScalarType::Float8_e4m3fn, @@ -1086,6 +1096,12 @@ static at::Tensor linear_int8_with_onednn_weight( input.scalar_type(), " and ", onednn_weight.scalar_type()); is_fp8 = true; } +======= + TORCH_CHECK(input.scalar_type() == c10::ScalarType::Byte || input.scalar_type() == c10::ScalarType::Char, + "qlinear with mkldnn tensor: data type of input should be uint8 or int8 (unsigned char or char)."); + TORCH_CHECK(onednn_weight.scalar_type() == c10::ScalarType::Char, + "qlinear with mkldnn tensor: data type of weight should be int8 (char)."); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_CHECK( weight_scales.scalar_type() == c10::ScalarType::Float, "weight scales should be dtype c10::ScalarType::Float."); TORCH_CHECK( @@ -1119,7 +1135,11 @@ static at::Tensor linear_int8_with_onednn_weight( ); } if (binary_post_op == "sum") { +<<<<<<< HEAD auto expected_dtype = output_dtype.has_value() ? output_dtype.value() : input.scalar_type(); +======= + auto expected_dtype = output_dtype.has_value() ? output_dtype.value() : c10::kByte; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_CHECK( other.value().scalar_type() == expected_dtype, "onednn qlinear: the dtype of extra input for binary post op should be ", expected_dtype, @@ -1127,6 +1147,7 @@ static at::Tensor linear_int8_with_onednn_weight( ); } } +<<<<<<< HEAD #if defined(__powerpc__) if (is_fp8) { #else @@ -1140,6 +1161,8 @@ static at::Tensor linear_int8_with_onednn_weight( binary_post_op, binary_alpha, unary_post_op, unary_post_op_args, unary_post_op_algorithm); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // If the input has more than two dimensions, we will reshape it to a 2-dimensional form // for calculation and subsequently reshape the output back. @@ -1167,13 +1190,20 @@ static at::Tensor linear_int8_with_onednn_weight( } std::vector src_dims = {M, K}; std::vector dst_dims = {M, N}; +<<<<<<< HEAD auto out_dtype = output_dtype.has_value() ? output_dtype.value() : input.scalar_type(); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) at::Tensor output = binary_post_op == "sum" ? other.value() : at::empty( dst_dims, at::device(c10::kCPU) +<<<<<<< HEAD .dtype(out_dtype) +======= + .dtype(fp32_output ? c10::kFloat : (bf16_output ? c10::kBFloat16 : c10::kByte)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ); if (output.numel() == 0) { return output; @@ -1186,7 +1216,11 @@ static at::Tensor linear_int8_with_onednn_weight( empty_tensor; // Create onednn primitive +<<<<<<< HEAD auto src_dtype = at::native::get_mkldnn_dtype(input.scalar_type()); +======= + auto src_dtype = input.scalar_type() == c10::kByte ? ideep::data_type::u8 : ideep::data_type::s8; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto src_desc = tensor::desc(src_dims, src_dtype, ideep::format_tag::any); auto weights_desc = packed_weight.get_desc(); auto dst_dtype = dst.get_data_type(); @@ -1208,6 +1242,7 @@ static at::Tensor linear_int8_with_onednn_weight( unary_post_op_args, unary_post_op_algorithm ); +<<<<<<< HEAD // Avoid NaN if output dtype is fp8 if (out_dtype == c10::kFloat8_e4m3fn) { // To avoid NaN, we need to clamp the intermediate results (in fp32) to [-488, 488] @@ -1218,6 +1253,8 @@ static at::Tensor linear_int8_with_onednn_weight( op_attr.set_post_ops(post_ops); output_scale = 1.0f; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (input_scale != 1.0f) { op_attr.set_scales_mask(DNNL_ARG_SRC, 0); } @@ -1630,6 +1667,7 @@ TORCH_LIBRARY_IMPL(onednn, MkldnnCPU, m) { TORCH_FN(at::native::QLinearOnednn::run_pointwise_binary_tensor)); } +<<<<<<< HEAD TORCH_LIBRARY_IMPL(onednn, CPU, m) { m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise"), TORCH_FN(QLinearOnednn::run_pointwise)); @@ -1641,5 +1679,7 @@ TORCH_LIBRARY_IMPL(onednn, CPU, m) { TORCH_FN(at::native::QLinearOnednn::run_pointwise_binary_tensor)); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace } // namespace at::native diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp index 4ed50f6f8735a..72446d74f957c 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp @@ -888,7 +888,11 @@ class QLinearUnpackedDynamicFp16 final { static at::Tensor run( at::Tensor input, const at::Tensor& weight, +<<<<<<< HEAD const std::optional& bias) { +======= + const at::Tensor& bias) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // We make a strong guarantee that models using these operators will have // the same numerics across different machines. Therefore, we do not provide // a fallback path and rather fail loudly if we cannot run FBGEMM. @@ -908,7 +912,11 @@ class QLinearUnpackedDynamicFp16 final { static at::Tensor meta( at::Tensor input, const at::Tensor& weight, +<<<<<<< HEAD const std::optional& bias) { +======= + const at::Tensor& bias) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // We make a strong guarantee that models using these operators will have // the same numerics across different machines. Therefore, we do not provide // a fallback path and rather fail loudly if we cannot run FBGEMM. @@ -929,7 +937,11 @@ class QLinearUnpackedDynamicFp16 final { static at::Tensor run( at::Tensor /* input */, const at::Tensor& weight, +<<<<<<< HEAD const std::optional& bias) { +======= + const at::Tensor& bias) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // We make a strong guarantee that models using these operators will have // the same numerics across different machines. Therefore, we do not provide // a fallback path and rather fail loudly if we cannot run FBGEMM. @@ -940,7 +952,11 @@ class QLinearUnpackedDynamicFp16 final { static at::Tensor meta( at::Tensor /* input */, const at::Tensor& weight, +<<<<<<< HEAD const std::optional& bias) { +======= + const at::Tensor& bias) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_CHECK( false, "This PyTorch installation was not built with FBGEMM operators"); } diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp index b4ae4e677bcd2..b1444316243bb 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp @@ -297,6 +297,7 @@ c10::intrusive_ptr PackedLinearWeightsOnednn::prepack( static inline at::Tensor pack_weight_to_onednn_tensor( const at::Tensor& weight, std::optional>& input_shape) { +<<<<<<< HEAD at::ScalarType weigh_dtype = weight.scalar_type(); TORCH_CHECK( weigh_dtype == at::kChar || weigh_dtype == at::kFloat8_e4m3fn, @@ -328,6 +329,16 @@ static inline at::Tensor pack_weight_to_onednn_tensor( : dnnl::memory::data_type::u8; auto w_desc = ideep::matmul_forward::expected_weights_desc( wei.get_dims(), input_dims, w_data_type, x_data_type, op_attr); +======= + std::vector w_dims = weight.sizes().vec(); + ideep::tensor wei = ideep::tensor({w_dims, dnnl::memory::data_type::s8}, weight.data_ptr()); + wei.transpose_(0, 1); // oneDNN requires transposed weight + ideep::dims input_dims = input_shape.has_value() ? input_shape.value().vec() : ideep::dims(); + ideep::attr_t op_attr; + op_attr.set_zero_points_mask(DNNL_ARG_SRC, 0); + auto w_desc = ideep::matmul_forward::expected_weights_desc( + wei.get_dims(), input_dims, dnnl::memory::data_type::s8, dnnl::memory::data_type::u8, op_attr); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ideep::tensor expected_weight(w_desc); expected_weight.feed_from(wei); auto packed_weight = at::native::new_with_itensor_mkldnn( diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-arm64.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-arm64.sh index 5c52f1a020f1e..d99d19eb19cd1 100755 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-arm64.sh +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-arm64.sh @@ -53,7 +53,11 @@ CMAKE_ARGS+=("-DANDROID_PIE=ON") CMAKE_ARGS+=("-DANDROID_STL=c++_static") CMAKE_ARGS+=("-DANDROID_CPP_FEATURES=exceptions") +<<<<<<< HEAD # Use-specified CMake arguments go last to allow overriding defaults +======= +# Use-specified CMake arguments go last to allow overridding defaults +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CMAKE_ARGS+=($@) cd build/android/arm64-v8a && cmake ../../.. \ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-armv7.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-armv7.sh index 81da44097801f..5d70431398ac8 100755 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-armv7.sh +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-armv7.sh @@ -53,7 +53,11 @@ CMAKE_ARGS+=("-DANDROID_PIE=ON") CMAKE_ARGS+=("-DANDROID_STL=c++_static") CMAKE_ARGS+=("-DANDROID_CPP_FEATURES=exceptions") +<<<<<<< HEAD # Use-specified CMake arguments go last to allow overriding defaults +======= +# Use-specified CMake arguments go last to allow overridding defaults +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CMAKE_ARGS+=($@) cd build/android/armeabi-v7a && cmake ../../.. \ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-x86.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-x86.sh index 747704f1edfea..63e8aada1d80a 100755 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-x86.sh +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-x86.sh @@ -53,7 +53,11 @@ CMAKE_ARGS+=("-DANDROID_PIE=ON") CMAKE_ARGS+=("-DANDROID_STL=c++_static") CMAKE_ARGS+=("-DANDROID_CPP_FEATURES=exceptions") +<<<<<<< HEAD # Use-specified CMake arguments go last to allow overriding defaults +======= +# Use-specified CMake arguments go last to allow overridding defaults +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CMAKE_ARGS+=($@) cd build/android/x86 && cmake ../../.. \ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-arm64.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-arm64.sh index 8e867f18d3f91..7c6a88af0bbf4 100755 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-arm64.sh +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-arm64.sh @@ -40,7 +40,11 @@ CMAKE_ARGS+=("-DIOS_ARCH=arm64") CMAKE_ARGS+=("-DENABLE_BITCODE=OFF") CMAKE_ARGS+=("-DENABLE_ARC=OFF") +<<<<<<< HEAD # Use-specified CMake arguments go last to allow overriding defaults +======= +# Use-specified CMake arguments go last to allow overridding defaults +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CMAKE_ARGS+=($@) cd build/ios/arm64 && cmake ../../.. \ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-arm64e.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-arm64e.sh index 34a95d1944148..eaead37025b79 100755 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-arm64e.sh +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-arm64e.sh @@ -40,7 +40,11 @@ CMAKE_ARGS+=("-DIOS_ARCH=arm64e") CMAKE_ARGS+=("-DENABLE_BITCODE=OFF") CMAKE_ARGS+=("-DENABLE_ARC=OFF") +<<<<<<< HEAD # Use-specified CMake arguments go last to allow overriding defaults +======= +# Use-specified CMake arguments go last to allow overridding defaults +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CMAKE_ARGS+=($@) cd build/ios/arm64e && cmake ../../.. \ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-armv7.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-armv7.sh index 37e57ab557fcc..ab94bf8c61aca 100755 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-armv7.sh +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-armv7.sh @@ -40,7 +40,11 @@ CMAKE_ARGS+=("-DIOS_ARCH=armv7") CMAKE_ARGS+=("-DENABLE_BITCODE=OFF") CMAKE_ARGS+=("-DENABLE_ARC=OFF") +<<<<<<< HEAD # Use-specified CMake arguments go last to allow overriding defaults +======= +# Use-specified CMake arguments go last to allow overridding defaults +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CMAKE_ARGS+=($@) cd build/ios/armv7 && cmake ../../.. \ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-armv7s.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-armv7s.sh index 2fd2732191112..eb24af313b6e0 100755 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-armv7s.sh +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-armv7s.sh @@ -40,7 +40,11 @@ CMAKE_ARGS+=("-DIOS_ARCH=armv7s") CMAKE_ARGS+=("-DENABLE_BITCODE=OFF") CMAKE_ARGS+=("-DENABLE_ARC=OFF") +<<<<<<< HEAD # Use-specified CMake arguments go last to allow overriding defaults +======= +# Use-specified CMake arguments go last to allow overridding defaults +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CMAKE_ARGS+=($@) cd build/ios/armv7s && cmake ../../.. \ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-i386.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-i386.sh index b51b574d8136a..747f1e207402c 100755 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-i386.sh +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-i386.sh @@ -40,7 +40,11 @@ CMAKE_ARGS+=("-DIOS_ARCH=i386") CMAKE_ARGS+=("-DENABLE_BITCODE=OFF") CMAKE_ARGS+=("-DENABLE_ARC=OFF") +<<<<<<< HEAD # Use-specified CMake arguments go last to allow overriding defaults +======= +# Use-specified CMake arguments go last to allow overridding defaults +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CMAKE_ARGS+=($@) cd build/ios/i386 && cmake ../../.. \ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-x86_64.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-x86_64.sh index a3430082e3e57..6a688da2abba6 100755 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-x86_64.sh +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-x86_64.sh @@ -45,7 +45,11 @@ CMAKE_ARGS+=("-DIOS_ARCH=x86_64") CMAKE_ARGS+=("-DENABLE_BITCODE=OFF") CMAKE_ARGS+=("-DENABLE_ARC=OFF") +<<<<<<< HEAD # Use-specified CMake arguments go last to allow overriding defaults +======= +# Use-specified CMake arguments go last to allow overridding defaults +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CMAKE_ARGS+=($@) cd build/ios/x86_64 && cmake ../../.. \ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-local.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-local.sh index ac61a4061b90c..eb65d8021abec 100755 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-local.sh +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-local.sh @@ -27,7 +27,11 @@ CMAKE_ARGS+=("-DPYTORCH_QNNPACK_LIBRARY_TYPE=static") CMAKE_ARGS+=("-DPYTORCH_QNNPACK_BUILD_BENCHMARKS=ON") CMAKE_ARGS+=("-DPYTORCH_QNNPACK_BUILD_TESTS=ON") +<<<<<<< HEAD # Use-specified CMake arguments go last to allow overriding defaults +======= +# Use-specified CMake arguments go last to allow overridding defaults +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CMAKE_ARGS+=($@) cd build/local && cmake ../.. \ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/convolution.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/convolution.c index 29f5338f5c734..4a709c83810a8 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/convolution.c +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/convolution.c @@ -368,7 +368,11 @@ static enum pytorch_qnnp_status pytorch_qnnp_create_convolution_ndhwc_q8( case pytorch_qnnp_ukernel_type_xzp_gemm: { // TODO: XZP kernels won't be supporting per channel quantization. // For now we dont use XZP kernels anywhere. Probably deprecate it for now +<<<<<<< HEAD // and resurrect later if needed. +======= + // and ressurrect later if needed. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const uint32_t nr = pytorch_qnnp_params.q8conv_xzp.nr; const uint32_t kr = pytorch_qnnp_params.q8conv_xzp.kr; const uint32_t sr = pytorch_qnnp_params.q8conv_xzp.kc; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/4x8-aarch32-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/4x8-aarch32-neon.S index ac06fa5973eca..0708e433d5c0c 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/4x8-aarch32-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/4x8-aarch32-neon.S @@ -20,6 +20,7 @@ # Args passed via stack. # TOS +<<<<<<< HEAD # |------------| # |a | 0 # |w | 4 @@ -28,10 +29,21 @@ # |out ch index| 16 # |params | 20 # |------------| +======= +# |-----------| +# |a | 0 +# |w | 4 +# |c | 8 +# |c_stride | 12 +# |out ch indx| 16 +# |params | 20 +# |-----------| +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # # After loading w pointer in ip reg. # And after pushing r4-r8 and d8-d15 on stack +<<<<<<< HEAD # |------------| # |d8 - d15 | 0 # |r4 - r11 | 64 @@ -42,6 +54,18 @@ # |out ch index| 112 # |params | 116 # |------------| +======= +# |-----------| +# |d8 - d15 | 0 +# |r4 - r11 | 64 +# |a | 96 +# |w | 100 +# |c | 104 +# |c_stride | 108 +# |out ch indx| 112 +# |params | 116 +# |-----------| +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # # void pytorch_q8conv_ukernel_4x8__aarch32_neon( diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/8x8-aarch64-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/8x8-aarch64-neon.S index 1653b46e2d374..ebb7b2f2979ae 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/8x8-aarch64-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/8x8-aarch64-neon.S @@ -23,10 +23,17 @@ # Args passed via stack. # TOS +<<<<<<< HEAD # |------------| # |out ch index| 0 # |params | 8 # |------------| +======= +# |-----------| +# |out ch indx| 0 +# |params | 8 +# |-----------| +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # void pytorch_q8conv_ukernel_8x8__aarch64_neon( # size_t mr, diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8-aarch32-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8-aarch32-neon.S index f18605124356e..a8b2ffc8eb3d3 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8-aarch32-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8-aarch32-neon.S @@ -20,6 +20,7 @@ # Args passed via stack. # TOS +<<<<<<< HEAD # |------------| # |a_stride | 0 # |w | 4 @@ -28,10 +29,21 @@ # |out ch index| 16 # |params | 20 # |------------| +======= +# |-----------| +# |a_stride | 0 +# |w | 4 +# |c | 8 +# |c_stride | 12 +# |out ch indx| 16 +# |params | 20 +# |-----------| +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # # After loading w pointer in ip reg. # And after pushing r4-r9 and d8-d15 on stack +<<<<<<< HEAD # |------------| # |d8 - d15 | 0 # |r4 - r9 | 64 @@ -42,6 +54,18 @@ # |out ch index| 104 # |params | 108 # |------------| +======= +# |-----------| +# |d8 - d15 | 0 +# |r4 - r9 | 64 +# |a_stride | 88 +# |w | 92 +# |c | 96 +# |c_stride | 100 +# |out ch indx| 104 +# |params | 108 +# |-----------| +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # # diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8-dq-aarch32-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8-dq-aarch32-neon.S index c964bf2be7c44..1fcb59061000d 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8-dq-aarch32-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8-dq-aarch32-neon.S @@ -33,6 +33,7 @@ # Args passed via stack. # TOS +<<<<<<< HEAD # |------------| # |a_stride | 0 # |w | 4 @@ -41,10 +42,21 @@ # |out ch index| 16 # |params | 20 # |------------| +======= +# |-----------| +# |a_stride | 0 +# |w | 4 +# |c | 8 +# |c_stride | 12 +# |out ch indx| 16 +# |params | 20 +# |-----------| +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # # After loading w pointer in ip reg. # And after pushing r4-r8 and d8-d15 on stack +<<<<<<< HEAD # |------------| # |d8 - d15 | 0 # |r4 - r7 | 64 @@ -56,6 +68,19 @@ # |out ch index| 100 # |params | 104 # |------------| +======= +# |-----------| +# |d8 - d15 | 0 +# |r4 - r7 | 64 +# |a_stride | 80 +# |w | 84 +# |b | 88 +# |c | 92 +# |c_stride | 96 +# |out ch indx| 100 +# |params | 104 +# |-----------| +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # # void pytorch_q8gemm_ukernel_4x8__aarch32_neon( diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/8x8-aarch64-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/8x8-aarch64-neon.S index 51866fd3b1ed1..1156d3ae6ecd5 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/8x8-aarch64-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/8x8-aarch64-neon.S @@ -22,10 +22,17 @@ # Args passed via stack. # TOS +<<<<<<< HEAD # |------------| # |out ch index| 0 # |params | 8 # |------------| +======= +# |-----------| +# |out ch indx| 0 +# |params | 8 +# |-----------| +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # void pytorch_q8gemm_ukernel_8x8__aarch64_neon( # size_t mr, diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/8x8-dq-aarch64-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/8x8-dq-aarch64-neon.S index 63f667b04a283..42ad88a09523e 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/8x8-dq-aarch64-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/8x8-dq-aarch64-neon.S @@ -14,11 +14,19 @@ # Args passed via stack. # TOS +<<<<<<< HEAD # |------------| # |c_stride | 0 # |out ch index| 8 # |params | 16 # |------------| +======= +# |-----------| +# |c_stride | 0 +# |out ch indx| 8 +# |params | 16 +# |-----------| +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # void pytorch_q8gemm_dq_ukernel_8x8__aarch64_neon( # size_t mr, diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x4-packA-aarch32-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x4-packA-aarch32-neon.S index 4583e50046d69..a0e9d01250125 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x4-packA-aarch32-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x4-packA-aarch32-neon.S @@ -32,7 +32,11 @@ # # Packed A format. +<<<<<<< HEAD # 4kx4m blocks for all blocks given 4 rows (4m) are placed in contiguous memory. +======= +# 4kx4m blocks for alls blocks given 4 rows (4m) are placed in contiguous memory. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Original A # --------- K ----------- -- (K + 4 - 1) / 4 -- # | | | | @@ -53,7 +57,11 @@ # This locality helps in loading 8kx4m blocks of activations # Note when M is not multiple of 4, the rest can contain arbitrary # data in packed A as we will not be writing those out. +<<<<<<< HEAD # This will be taken care by just copying the appropriate valid data +======= +# This wil be taken care by just copying the appropriate valid data +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Also note that this packing is same as taking for 4x1 pattern. # This is because all the adjacent k's are laid next to each other @@ -109,7 +117,11 @@ k_loop: VLD1.8 {d2}, [r6]! VLD1.8 {d3}, [r7]! +<<<<<<< HEAD # Now we have 4x8 block of values that we will transpose +======= + # Now we have 4x8 block of values that we will tranpose +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # A matrix # -------------------------------- # | | @@ -155,7 +167,11 @@ k_loop: VTRN.32 d2, d3 VSWP d1, d2 +<<<<<<< HEAD # Now store the transposed values +======= + # Now store the tranposed values +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # d0, d1, d2, d3 VST1.8 {q0}, [r2]! VST1.8 {q1}, [r2]! @@ -172,7 +188,11 @@ k_loop: VLD1.32 {d2[]}, [r6] VLD1.32 {d3[]}, [r7] +<<<<<<< HEAD # Now we have 4x8 block of values that we will transpose +======= + # Now we have 4x8 block of values that we will tranpose +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # _d{0-3} are arm neon vector registers # va0 = _d0 = a0 a1 a2 a3 # va1 = _d1 = b0 b1 b2 b3 @@ -218,7 +238,11 @@ k_loop: VEXT.8 d0, d0, d1, #4 VEXT.8 d1, d2, d3, #4 +<<<<<<< HEAD # Now store the transposed values +======= + # Now store the tranposed values +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # d0, d1, d2, d3 VST1.8 {q0}, [r2] .p2align 4 diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x8c1x4-dq-packedA-aarch32-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x8c1x4-dq-packedA-aarch32-neon.S index d7a3aa6eaaf74..0d6149c25090f 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x8c1x4-dq-packedA-aarch32-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x8c1x4-dq-packedA-aarch32-neon.S @@ -46,7 +46,11 @@ # |b | 12 # |c | 16 # |c_stride | 20 +<<<<<<< HEAD # |out ch index | 24 +======= +# |out ch indx | 24 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # |params | 28 # |----------------| # @@ -61,7 +65,11 @@ # |b | 108 # |c | 112 # |c_stride | 116 +<<<<<<< HEAD # |out ch index | 120 +======= +# |out ch indx | 120 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # |params | 124 # |----------------| # @@ -101,7 +109,11 @@ /* Add output_channel_index to the b_zero_point pointer */ ;\ ADD r4, r4, r5 ;\ ;\ +<<<<<<< HEAD /* We enter the loop if r1 is at least 1. */ ;\ +======= + /* We enter the loop if r1 is atleast 1. */ ;\ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) /* r1 = r1 - 1 will happen in the epilogue */ ;\ /* of the loop */ ;\ CMP r1, 1 ;\ @@ -222,7 +234,11 @@ /* Thus we will load accumulators back in q0, q1, q2, q3, q4, q5, q6, q7 */ ;\ /* When nr < 4, extra q values will be fetched from stack which may overlap */ ;\ /* with other parts of stack storing local variables. To avoid that we just */ ;\ +<<<<<<< HEAD /* create a buffer of 128 bytes in between to make sure pointer increment */ ;\ +======= + /* create a buffer of 128 bytes inbetween to make sure pointer increment */ ;\ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) /* never produces address that is beyond the stack frame of this function. */ ;\ SUB r9, sp, 140 ;\ /* Each iteration produce 4 values each of 4 bytes */ ;\ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x8c8x1-dq-packedA-aarch32-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x8c8x1-dq-packedA-aarch32-neon.S index 37db2adcad069..56c4f705b93c7 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x8c8x1-dq-packedA-aarch32-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x8c8x1-dq-packedA-aarch32-neon.S @@ -46,7 +46,11 @@ # |b | 12 # |c | 16 # |c_stride | 20 +<<<<<<< HEAD # |out ch index | 24 +======= +# |out ch indx | 24 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # |params | 28 # |----------------| # @@ -61,7 +65,11 @@ # |b | 108 # |c | 112 # |c_stride | 116 +<<<<<<< HEAD # |out ch index | 120 +======= +# |out ch indx | 120 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # |params | 124 # |----------------| # diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-aarch32-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-aarch32-neon.S index a5a91b9cb64f7..6d2dd6260f6c8 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-aarch32-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-aarch32-neon.S @@ -32,7 +32,11 @@ # # Packed A format. +<<<<<<< HEAD # 8kx4m blocks for all blocks given 4 rows (4m) are placed in contiguous memory. +======= +# 8kx4m blocks for alls blocks given 4 rows (4m) are placed in contiguous memory. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Original A # --------- K ----------- -- (K + 4 - 1) / 4 -- # | | | | @@ -53,7 +57,11 @@ # This locality helps in loading 8kx8m blocks of activations # Note when M is not multiple of 8, the rest can contain arbitrary # data in packed A as we will not be writing those out. +<<<<<<< HEAD # This will be taken care by just copying the appropriate valid data +======= +# This wil be taken care by just copying the appropriate valid data +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # void pytorch_q8gemm_sparse_packA_ukernel_8x4__aarch32_neon( # size_t mr, @@ -125,7 +133,11 @@ k_loop: VLD1.8 {d6}, [r10]! VLD1.8 {d7}, [r11]! +<<<<<<< HEAD # Now we have 8x8 block of values that we will transpose +======= + # Now we have 8x8 block of values that we will tranpose +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # A matrix # -------------------------------- # | | @@ -189,7 +201,11 @@ k_loop: VTRN.32 q0, q2 VTRN.32 q1, q3 +<<<<<<< HEAD # Now store the transposed values +======= + # Now store the tranposed values +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # d0, d1, d2, d3 # then d4, d5, d6, d7 contiguously VST1.8 {q0}, [r2]! @@ -213,7 +229,11 @@ k_loop: VLD1.32 {d6[]}, [r7] VLD1.32 {d7[]}, [r11] +<<<<<<< HEAD # Now we have 4x8 block of values that we will transpose +======= + # Now we have 4x8 block of values that we will tranpose +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # _d{0-3} are arm neon vector registers # va04 = _d0 = a0 a1 a2 a3 e0 e1 e2 e3 # va15 = _d1 = b0 b1 b2 b3 f0 f1 f2 f3 @@ -260,7 +280,11 @@ k_loop: VTRN.16 d0, d2 VTRN.16 d1, d3 +<<<<<<< HEAD # Now store the transposed values +======= + # Now store the tranposed values +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # d0, d1, d2, d3 # then d4, d5, d6, d7 contiguously VST1.8 {q0}, [r2]! diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-aarch64-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-aarch64-neon.S index b1f8fe719ca44..abf3da38bb7c6 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-aarch64-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-aarch64-neon.S @@ -9,7 +9,11 @@ #include # Packed A format. +<<<<<<< HEAD # 8kx4m blocks for all blocks given 4 rows (4m) are placed in contiguous memory. +======= +# 8kx4m blocks for alls blocks given 4 rows (4m) are placed in contiguous memory. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Original A # --------- K ----------- -- (K + 4 - 1) / 4 -- # | | | | @@ -30,7 +34,11 @@ # This locality helps in loading 8kx8m blocks of activations # Note when M is not multiple of 8, the rest can contain arbitrary # data in packed A as we will not be writing those out. +<<<<<<< HEAD # This will be taken care by just copying the appropriate valid data +======= +# This wil be taken care by just copying the appropriate valid data +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # void pytorch_q8gemm_sparse_packA_ukernel_8x4__aarch32_neon( # size_t mr, @@ -93,7 +101,11 @@ k_loop: LD1 {v3.d}[0], [x7], 8 LD1 {v3.d}[1], [x11], 8 +<<<<<<< HEAD # Now we have 8x8 block of values that we will transpose +======= + # Now we have 8x8 block of values that we will tranpose +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # A matrix # ------------------------ # | | @@ -180,7 +192,11 @@ k_loop: LD1 {v3.s}[0], [x7] LD1 {v3.s}[1], [x11] +<<<<<<< HEAD # Now we have 8x4 block of values that we will transpose +======= + # Now we have 8x4 block of values that we will tranpose +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # A matrix # ---------------------------- # | | diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-sse2.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-sse2.c index df707d3d800ea..cff379b7daf58 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-sse2.c +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-sse2.c @@ -14,7 +14,11 @@ #include "8x4c1x4-packed-sse2.h" // This is a super slow kernel in that it does not use intrinsics to +<<<<<<< HEAD // transpose. Since this is for x86 we are not optimizing it. +======= +// tranpose. Since this is for x86 we are not optimizing it. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // For ARM this will be optimized. void pytorch_q8gemm_sparse_packA_ukernel_8x4__sse2( const size_t mr, @@ -24,7 +28,11 @@ void pytorch_q8gemm_sparse_packA_ukernel_8x4__sse2( uint8_t* a_packed) { // Packed A format. +<<<<<<< HEAD // 8kx4m blocks for all blocks given 4 rows (4m) are placed in contiguous memory. +======= + // 8kx4m blocks for alls blocks given 4 rows (4m) are placed in contiguous memory. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Original A // --------- K ----------- -- (K + 4 - 1) / 4 -- // | | | | @@ -45,7 +53,11 @@ void pytorch_q8gemm_sparse_packA_ukernel_8x4__sse2( // This locality helps in loading 8kx8m blocks of activations // Note when M is not multiple of 8, the rest can contain arbitrary // data in packed A as we will not be writing those out. +<<<<<<< HEAD // This will be taken care by just copying the appropriate valid data +======= + // This wil be taken care by just copying the appropriate valid data +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Note that parts of A that are not filled are: // Remainder of M blocks. So some m values are random. This is ok diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4c1x4-dq-packedA-sse2.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4c1x4-dq-packedA-sse2.h index ef771b4187b82..503344bc98b4c 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4c1x4-dq-packedA-sse2.h +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4c1x4-dq-packedA-sse2.h @@ -47,7 +47,11 @@ void KERNEL_NAME( const __m128i vzero = _mm_setzero_si128(); // Packed A format. +<<<<<<< HEAD // 8kx4m blocks for all blocks given 4 rows (4m) are placed in contiguous memory. +======= + // 8kx4m blocks for alls blocks given 4 rows (4m) are placed in contiguous memory. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Original A // --------- K ----------- -- (K + 4 - 1) / 4 -- // | | | | @@ -68,7 +72,11 @@ void KERNEL_NAME( // This locality helps in loading 8kx8m blocks of activations // Note when M is not multiple of 8, the rest can contain arbitrary // data in packed A as we will not be writing those out. +<<<<<<< HEAD // This will be taken care by just copying the appropriate valid data +======= + // This wil be taken care by just copying the appropriate valid data +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __m128i vacc_low[4]; __m128i vacc_high[4]; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x8c1x4-dq-packedA-aarch64-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x8c1x4-dq-packedA-aarch64-neon.S index 8af5c417da31f..d87c9b2a359d2 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x8c1x4-dq-packedA-aarch64-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x8c1x4-dq-packedA-aarch64-neon.S @@ -42,11 +42,19 @@ # Args passed via stack. # TOS +<<<<<<< HEAD # |------------| # |c_stride | 0 # |out ch index| 8 # |params | 16 # |------------| +======= +# |-----------| +# |c_stride | 0 +# |out ch indx| 8 +# |params | 16 +# |-----------| +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # void pytorch_q8gemm_dq_sparse_1x4_ukernel_8x8_packedA_w##W_INDEX_DTYPE_NUM_BITS##__aarch64_neon( # size_t mr, @@ -234,7 +242,11 @@ /* v16, v17, v18, v19, v20, v21, v22, v23 */ XX\ /* When nr < 8, say nr = 1, extra v values will be fetched from stack which may overlap */ XX\ /* with other parts of stack storing local variables. To avoid that we just */ XX\ +<<<<<<< HEAD /* create a buffer of 256 bytes in between to make sure pointer increment */ XX\ +======= + /* create a buffer of 256 bytes inbetween to make sure pointer increment */ XX\ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) /* never produces address that is beyond the stack frame of this function. */ XX\ SUB x9, sp, 320 XX\ /* Each iteration produce 8 values each of 4 bytes */ XX\ @@ -287,7 +299,11 @@ LD1 {v22.4s}, [x9], 16 XX\ LD1 {v23.4s}, [x9] XX\ XX\ +<<<<<<< HEAD /* We can transpose one 4x4 block using macro */ XX\ +======= + /* We can tranpose one 4x4 block using macro */ XX\ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) /* TRANSPOSE_4X4_S32 v8, v10, v12, v14, v0, v1, v2, v3 */ XX\ /* After this we have */ XX\ /* v8 : x00, x01, x02, x03 */ XX\ @@ -302,7 +318,11 @@ /* v20 : x24, x25, x26, x27 */ XX\ /* v22 : x34, x35, x36, x37 */ XX\ /* Similarly we can transpose other two 4x4 blocks and we get */ XX\ +<<<<<<< HEAD /* transposed 8x8 */ XX\ +======= + /* tranposed 8x8 */ XX\ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) XX\ TRANSPOSE_4X4_S32 v8, v10, v12, v14, v0, v1, v2, v3 XX\ TRANSPOSE_4X4_S32 v16, v18, v20, v22, v4, v5, v6, v7 XX\ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x8c8x1-dq-packedA-aarch64-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x8c8x1-dq-packedA-aarch64-neon.S index 58602beb030d1..317ef065456ee 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x8c8x1-dq-packedA-aarch64-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x8c8x1-dq-packedA-aarch64-neon.S @@ -31,11 +31,19 @@ # Args passed via stack. # TOS +<<<<<<< HEAD # |------------| # |c_stride | 0 # |out ch index| 8 # |params | 16 # |------------| +======= +# |-----------| +# |c_stride | 0 +# |out ch indx| 8 +# |params | 16 +# |-----------| +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # void pytorch_q8gemm_dq_sparse_8x1_ukernel_8x8_packedA_w##W_INDEX_DTYPE_NUM_BITS##__aarch64_neon( # size_t mr, diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/pack.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/pack.h index 14365d1ab3ddc..da1431c22f727 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/pack.h +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/pack.h @@ -238,7 +238,11 @@ static inline void pytorch_pack_q8conv_wrq( } } if (kzp != 0) { +<<<<<<< HEAD // This part fills the packed weights with zero points for output channels +======= + // This part fills the packed wights with zero points for output channels +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // when they are not divisible by nr blocking parameter. // In that case for (size_t nr_block_offset = 0; nr_block_offset < (nr - nr_block_size); @@ -360,7 +364,11 @@ static inline void pytorch_pack_q8deconv_wrq( } } if (kzp != 0) { +<<<<<<< HEAD // This part fills the packed weights with zero points for output channels +======= + // This part fills the packed wights with zero points for output channels +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // when they are not divisible by nr blocking parameter. // In that case for (size_t nr_block_offset = 0; nr_block_offset < (nr - nr_block_size); diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/q31-scalar.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/q31-scalar.c index 74961b51ff638..da2f0bcbc42f4 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/q31-scalar.c +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/q31-scalar.c @@ -93,7 +93,11 @@ void pytorch_qnnp_requantize_q31__scalar( * overflow is possible only when input is positive, and even when addition * of a rounding constant overflows 32-bit signed integer, it still doesn't * overflow 32-bit unsigned integer. Thus, in case of signed overflow, we +<<<<<<< HEAD * can compute the result using unsigned arithmetic, specifically using +======= + * can compute the result using unsigned arithmetics, specifically using +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) * logical shift right instead of arithmetic shift right. * 3. Performs arithmetic shift as is, which will produce division result * rounded down. Then compute remainder of this division by a power of 2, diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/fully-connected-sparse-operator-tester.h b/aten/src/ATen/native/quantized/cpu/qnnpack/test/fully-connected-sparse-operator-tester.h index 597662fbbbae4..b97ee1bf4ef9d 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/test/fully-connected-sparse-operator-tester.h +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/fully-connected-sparse-operator-tester.h @@ -579,9 +579,15 @@ class FullyConnectedSparseOperatorTester { for (size_t i = 0; i < batchSize(); i++) { for (size_t c = 0; c < outputChannels(); c++) { +<<<<<<< HEAD ASSERT_NEAR( output_dynamic[i * outputChannels() + c], accumulators_float[i * outputChannels() + c], 1e-3) +======= + ASSERT_FLOAT_EQ( + output_dynamic[i * outputChannels() + c], + accumulators_float[i * outputChannels() + c]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) << "at " << i << ", " << c << ": reference = " << accumulators_float[i * outputChannels() + c] diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/requantization.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/test/requantization.cc index f535e4b99ed76..005e15606af4a 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/test/requantization.cc +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/requantization.cc @@ -17,7 +17,11 @@ #include "requantization-tester.h" /* +<<<<<<< HEAD * Precise scalar implementation using unsigned 32-bit arithmetic. +======= + * Precise scalar implementation using unsigned 32-bit arithmetics. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) */ TEST(PRECISE__SCALAR_UNSIGNED32, exact_divide_by_po2) { @@ -83,7 +87,11 @@ TEST(PRECISE__SCALAR_UNSIGNED32, random_cases) { } /* +<<<<<<< HEAD * Precise scalar implementation using unsigned 64-bit arithmetic. +======= + * Precise scalar implementation using unsigned 64-bit arithmetics. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) */ TEST(PRECISE__SCALAR_UNSIGNED64, exact_divide_by_po2) { @@ -149,7 +157,11 @@ TEST(PRECISE__SCALAR_UNSIGNED64, random_cases) { } /* +<<<<<<< HEAD * Precise scalar implementation using signed 64-bit arithmetic. +======= + * Precise scalar implementation using signed 64-bit arithmetics. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) */ TEST(PRECISE__SCALAR_SIGNED64, exact_divide_by_po2) { @@ -302,7 +314,11 @@ TEST(GEMMLOWP__SCALAR, random_cases) { } /* +<<<<<<< HEAD * Precise PSIMD implementation using unsigned 32-bit arithmetic. +======= + * Precise PSIMD implementation using unsigned 32-bit arithmetics. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) */ TEST(PRECISE__PSIMD, exact_divide_by_po2) { diff --git a/aten/src/ATen/native/quantized/cudnn/Linear.cpp b/aten/src/ATen/native/quantized/cudnn/Linear.cpp index 230850998fda1..b82d345310159 100644 --- a/aten/src/ATen/native/quantized/cudnn/Linear.cpp +++ b/aten/src/ATen/native/quantized/cudnn/Linear.cpp @@ -171,7 +171,11 @@ void PackedLinearWeightCudnn::apply_impl_helper(const at::Tensor& quantized_outp return; } +<<<<<<< HEAD // linear_op computes act_int8 * transpose(w_int8) (matrix multiplication) +======= + // linear_op computes act_int8 * tranpose(w_int8) (matrix multiplication) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // where act_int8 and w_int8 are the input and weight variables, resp. // output is a fp32 tensor auto linear_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) diff --git a/aten/src/ATen/native/quantized/cudnn/Pooling.cpp b/aten/src/ATen/native/quantized/cudnn/Pooling.cpp index 7fe44de11e54c..39f0693bc81b4 100644 --- a/aten/src/ATen/native/quantized/cudnn/Pooling.cpp +++ b/aten/src/ATen/native/quantized/cudnn/Pooling.cpp @@ -54,7 +54,11 @@ void check_maxpool2d_params( Tensor adaptive_avg_pool2d_quantized_cuda( const at::Tensor& input, IntArrayRef output_size) { +<<<<<<< HEAD // TODO: re-enable these cudnn preprocessors like quantized_max_pool2d_cudnn below when we implement this function with cudnn +======= +// TODO: renable these cudnn preprocessors like quantized_max_pool2d_cudnn below when we implement this function with cudnn +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #ifdef USE_CUDA // #if AT_CUDNN_ENABLED() // TODO: limit this to per tensor quantized tensors for now, though should be easy to adapt diff --git a/aten/src/ATen/native/quantized/library.cpp b/aten/src/ATen/native/quantized/library.cpp index 550280dbf6d3e..2d27136ce9c75 100644 --- a/aten/src/ATen/native/quantized/library.cpp +++ b/aten/src/ATen/native/quantized/library.cpp @@ -142,7 +142,11 @@ TORCH_LIBRARY(quantized, m) { m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_dynamic(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, bool reduce_range=False) -> Tensor Y"), {at::Tag::pt2_compliant_tag}); m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_relu_dynamic(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, bool reduce_range=False) -> Tensor Y"), {at::Tag::pt2_compliant_tag}); m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_dynamic_fp16(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack) -> Tensor Y"), {at::Tag::pt2_compliant_tag}); +<<<<<<< HEAD m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_dynamic_fp16_unpacked_weight(Tensor X, Tensor weight, Tensor? bias) -> Tensor Y"), {at::Tag::pt2_compliant_tag}); +======= + m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_dynamic_fp16_unpacked_weight(Tensor X, Tensor weight, Tensor bias) -> Tensor Y"), {at::Tag::pt2_compliant_tag}); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_relu_dynamic_fp16(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack) -> Tensor Y"), {at::Tag::pt2_compliant_tag}); m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_leaky_relu(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i, float negative_slope) -> Tensor Y"), {at::Tag::pt2_compliant_tag}); m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_tanh(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i) -> Tensor Y"), {at::Tag::pt2_compliant_tag}); diff --git a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h index 805035cdd6263..03d0bcdaa86dc 100644 --- a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h +++ b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h @@ -51,10 +51,17 @@ ForwardIt find_bound(ForwardIt first, ForwardIt last, const T& value) { // Similarly, an upper bound is a value at *it with the smallest index // such that *it > value if such value exists, or last if does not. // Let is_lower = true and *it < value, then we know that *it and values +<<<<<<< HEAD // preceding *it cannot contain a lower bound, so we adjust initial iterator range // from [first, first + count] to [first + step + 1, first + count - (step + 1)], // where +1 skips the element at which we have just evaluated *it < value. // Similar logic holds when is_lower = false. +======= + // preceeding *it cannot contain a lower bound, so we adjust initial iterator range + // from [first, first + count] to [first + step + 1, first + count - (step + 1)], + // where +1 skips the element at which we have just evaluated *it < value. + // Samilar logic holds when is_lower = false. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (is_lower ? *it < value : value >= *it) { first = ++it; count -= step + 1; diff --git a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp index cf854a84e7dad..62a04b9119469 100644 --- a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp +++ b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp @@ -79,7 +79,11 @@ struct CPUValueSelectionIntersectionKernel { const auto* ptr_argsort = argsort.const_data_ptr(); for (int64_t i = 0; i < n; ++i) { +<<<<<<< HEAD // Extract data +======= + // Exctract data +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto* ptr_res_values = reinterpret_cast(ptr_res_values_bytes); const auto* ptr_lhs_values = reinterpret_cast(ptr_lhs_values_bytes); const auto lhs_nnz_idx = *reinterpret_cast(ptr_lhs_select_idx_bytes); diff --git a/aten/src/ATen/native/sparse/SparseBlasImpl.cpp b/aten/src/ATen/native/sparse/SparseBlasImpl.cpp index c841da8354b5f..cd5372bc57b34 100644 --- a/aten/src/ATen/native/sparse/SparseBlasImpl.cpp +++ b/aten/src/ATen/native/sparse/SparseBlasImpl.cpp @@ -23,9 +23,12 @@ #include #endif +<<<<<<< HEAD #if AT_USE_EIGEN_SPARSE() #include #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace at::native::sparse::impl { @@ -445,6 +448,7 @@ void add_out_sparse_csr( const Tensor& mat2, const Scalar& alpha, const Tensor& result) { +<<<<<<< HEAD #if AT_USE_MKL_SPARSE() sparse::impl::mkl::add_out_sparse_csr(mat1, mat2, alpha, result); #elif AT_USE_EIGEN_SPARSE() @@ -454,6 +458,15 @@ void add_out_sparse_csr( false, "Calling add on a sparse CPU tensor requires compiling PyTorch with MKL. ", "Please use PyTorch built MKL support."); +======= +#if !AT_MKL_ENABLED() + TORCH_CHECK( + false, + "Calling add on a sparse CPU tensor requires compiling PyTorch with MKL. ", + "Please use PyTorch built MKL support."); +#else + sparse::impl::mkl::add_out_sparse_csr(mat1, mat2, alpha, result); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif } @@ -464,7 +477,11 @@ void triangular_solve_out_sparse_csr( bool upper, bool transpose, bool unitriangular) { +<<<<<<< HEAD #if !AT_USE_MKL_SPARSE() +======= +#if !AT_MKL_ENABLED() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_CHECK( false, "Calling triangular_solve on a sparse CPU tensor requires compiling PyTorch with MKL. ", diff --git a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp index 4faa135713d65..95b1664c3743a 100644 --- a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp +++ b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp @@ -127,10 +127,13 @@ #include #endif +<<<<<<< HEAD #if AT_USE_EIGEN_SPARSE() #include #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include namespace at { @@ -540,12 +543,16 @@ static void addmm_out_sparse_csr_native_cpu( auto values = sparse.values(); scalar_t cast_alpha = alpha.to(); +<<<<<<< HEAD // If beta is zero NaN and Inf should not be propagated to the result if (beta.toComplexDouble() == 0.) { r.zero_(); } else { r.mul_(beta); } +======= + r.mul_(beta); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AT_DISPATCH_INDEX_TYPES( col_indices.scalar_type(), "csr_mm_crow_indices", [&]() { auto csr_accessor = csr.accessor(); @@ -657,6 +664,7 @@ Tensor& addmm_out_sparse_compressed_cpu( return result; } +<<<<<<< HEAD #if AT_USE_EIGEN_SPARSE() if ((result.layout() == kSparseCsr || result.layout() == kSparseCsc) && (mat1.layout() == kSparseCsr || mat1.layout() == kSparseCsc) && @@ -666,6 +674,8 @@ Tensor& addmm_out_sparse_compressed_cpu( } #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #if !AT_USE_MKL_SPARSE() // The custom impl addmm_out_sparse_csr_native_cpu only supports CSR @ // strided -> strided diff --git a/aten/src/ATen/native/sparse/SparseTensor.cpp b/aten/src/ATen/native/sparse/SparseTensor.cpp index 752365d545dee..197b2ebe022db 100644 --- a/aten/src/ATen/native/sparse/SparseTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseTensor.cpp @@ -730,7 +730,11 @@ static std::tuple sparse_mask_like_prepare_sparse_inp // is that these primitives might project first argument onto second one or // the other way around depending on which arguments are coalesced and which are // larger. This function prepares inputs for `sparse_mask` such that `t` is +<<<<<<< HEAD // projected onto `mask` by sorting `t` if uncoalesced and artificially marking it +======= + // projected onto `mask` by sorting `t` if uncoalesced and artifically marking it +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // as coalesced all while `mask` is set to uncoalesced. // The result of this projectionk is going to be uncoalesced, so it is up to the // user to set the corresponding flag correctly with respect to the operations' diff --git a/aten/src/ATen/native/sparse/ValidateCompressedIndicesCommon.h b/aten/src/ATen/native/sparse/ValidateCompressedIndicesCommon.h index 267c19561a29d..09da4fa07fad3 100644 --- a/aten/src/ATen/native/sparse/ValidateCompressedIndicesCommon.h +++ b/aten/src/ATen/native/sparse/ValidateCompressedIndicesCommon.h @@ -242,7 +242,11 @@ void _validate_compressed_sparse_indices_kernel( // Catch integer overflow from large dimensions. Otherwise, the // invariant checks may fail with bogus exceptions or succeed with // false-positive results when int64_t typed dimensions are cast to +<<<<<<< HEAD // index dtype that corresponds to smaller integer type such as +======= + // index dtype that corresponds to smaller interger type such as +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // int32_t. { AT_DISPATCH_INDEX_TYPES(idx.scalar_type(), NAME, [cdim, dim, nnz]() { diff --git a/aten/src/ATen/native/sparse/cuda/ComputeSparseTile.h b/aten/src/ATen/native/sparse/cuda/ComputeSparseTile.h index 530804099b6fd..5d644bb8b9987 100644 --- a/aten/src/ATen/native/sparse/cuda/ComputeSparseTile.h +++ b/aten/src/ATen/native/sparse/cuda/ComputeSparseTile.h @@ -112,7 +112,11 @@ struct LargestValuesGreedy { } }; +<<<<<<< HEAD // We consider each rows independently in order +======= +// We consider each rows independantly in order +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // This is to ensure that a row's sparsity pattern is only determined // by its values and the rows before (but never the rows after) // This enforces causality strictly diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh b/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh index c11588a32ba05..bd60f6c7e2127 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh @@ -196,6 +196,7 @@ C10_LAUNCH_BOUNDS_1(num_threads()) __global__ void coalesceValuesKernel( int64_t *segment_offsets, int64_t *value_indices, Dtype *values, Dtype *newValues, +<<<<<<< HEAD int64_t nnz, int64_t newNnz, #ifdef USE_ROCM int64_t nsegments, @@ -207,6 +208,11 @@ __global__ void coalesceValuesKernel( #else int64_t seg = blockIdx.x * 4 + threadIdx.y; #endif +======= + int64_t nnz, int64_t newNnz, int64_t stride) { + + int seg = blockIdx.x * 4 + threadIdx.y; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Number of values processed by each thread (grain size) const int SZ = 4; @@ -215,11 +221,15 @@ __global__ void coalesceValuesKernel( const int newValueRow = seg * stride; const int begin = segment_offsets[seg]; const int end = (seg < newNnz - 1) ? segment_offsets[seg + 1] : nnz; +<<<<<<< HEAD #ifdef USE_ROCM const int startFeature = threadIdx.x + blockIdx.z * nsegments * SZ; #else const int startFeature = threadIdx.x + blockIdx.y * blockDim.x * SZ; #endif +======= + const int startFeature = threadIdx.x + blockIdx.y * blockDim.x * SZ; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Acctype tmp[SZ]; #pragma unroll for (int ii = 0; ii < SZ; ii++) { @@ -262,6 +272,7 @@ C10_LAUNCH_BOUNDS_1(C10_WARP_SIZE*4) __global__ void coalesceValuesKernel( int64_t *segment_offsets, int64_t *value_indices, bool *values, bool *newValues, +<<<<<<< HEAD int64_t nnz, int64_t newNnz, #ifdef USE_ROCM int64_t nsegments, @@ -273,6 +284,11 @@ __global__ void coalesceValuesKernel( #else int64_t seg = blockIdx.x * 4 + threadIdx.y; #endif +======= + int64_t nnz, int64_t newNnz, int64_t stride) { + + int seg = blockIdx.x * 4 + threadIdx.y; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Number of values processed by each thread (grain size) const int SZ = 4; @@ -281,11 +297,15 @@ __global__ void coalesceValuesKernel( const int newValueRow = seg * stride; const int begin = segment_offsets[seg]; const int end = (seg < newNnz - 1) ? segment_offsets[seg + 1] : nnz; +<<<<<<< HEAD #ifdef USE_ROCM const int startFeature = threadIdx.x + blockIdx.z * nsegments * SZ; #else const int startFeature = threadIdx.x + blockIdx.y * blockDim.x * SZ; #endif +======= + const int startFeature = threadIdx.x + blockIdx.y * blockDim.x * SZ; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool tmp[SZ]; #pragma unroll for (int ii = 0; ii < SZ; ii++) { diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp b/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp index c656dc71a660d..b092d07b8a916 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp @@ -65,7 +65,11 @@ void _csrmm2( csrvala, /* values of the sparse matrix, size = nnz */ CUSPARSE_INDEX_32I, /* data type of row offsets index */ CUSPARSE_INDEX_32I, /* data type of col indices */ +<<<<<<< HEAD CUSPARSE_INDEX_BASE_ZERO, /* base index of row offset and col index */ +======= + CUSPARSE_INDEX_BASE_ZERO, /* base index of row offset and col indes */ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cusparse_value_type /* data type of values */ )); diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu index b59221a3231a5..6ae2d7b6dc160 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu @@ -106,6 +106,7 @@ SparseTensor _coalesce_sparse_cuda(const SparseTensor& self) { values = values.contiguous(); int64_t stride = c10::multiply_integers(values.sizes().slice(1)); int warp_size = at::cuda::warp_size(); +<<<<<<< HEAD #ifdef USE_ROCM const int64_t BATCHING_SEGMENT = 4096; int64_t nsegments = ceil_div(newNnz, (int64_t) SZ); @@ -134,6 +135,10 @@ SparseTensor _coalesce_sparse_cuda(const SparseTensor& self) { C10_CUDA_KERNEL_LAUNCH_CHECK(); }); #else +======= + dim3 grid(ceil_div(newNnz, (int64_t) SZ), ceil_div(stride, (int64_t) warp_size*SZ)); + dim3 block(warp_size, SZ); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( at::ScalarType::ComplexHalf, at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, values.scalar_type(), "coalesce_sparse_cuda", [&] { @@ -149,7 +154,10 @@ SparseTensor _coalesce_sparse_cuda(const SparseTensor& self) { ); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); +<<<<<<< HEAD #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // this grid-strided version is slower but probably more flexible diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu index 3730ceb913547..6bfb0ff0e3257 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu @@ -800,7 +800,11 @@ Tensor& bmm_out_sparse_cuda(const SparseTensor& self, const Tensor& mat2, Tensor Tensor indices_dim1 = indices[1].to(ScalarType::Int); Tensor indices_dim2 = indices[2].to(ScalarType::Int); +<<<<<<< HEAD std::vector mat_el_end_indices_host(num_matrices); +======= + std::unique_ptr mat_el_end_indices_host(new int64_t[num_matrices]); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) { auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); @@ -809,14 +813,22 @@ Tensor& bmm_out_sparse_cuda(const SparseTensor& self, const Tensor& mat2, Tensor search_end_matrix_indices(mat_el_end_indices_device, num_matrices, indices_dim0); AT_CUDA_CHECK(cudaMemcpy( +<<<<<<< HEAD mat_el_end_indices_host.data(), +======= + mat_el_end_indices_host.get(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mat_el_end_indices_device, num_matrices*sizeof(int64_t), cudaMemcpyDeviceToHost )); } // Need a pointer to an array to access within a lambda +<<<<<<< HEAD int64_t* mat_el_end_indices = mat_el_end_indices_host.data(); +======= + int64_t* mat_el_end_indices = &mat_el_end_indices_host[0]; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Scalar beta = 0; Scalar alpha = 1; diff --git a/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu b/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu index c6e3197a22a8b..8661f5b8f24b7 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu @@ -93,7 +93,11 @@ void create_general_description_(cusparseMatDescr_t& description_) { } // csrMatrixRef is used to have a representation of a raw CSR matrix representation +<<<<<<< HEAD // coming from `sparse_sparse_matmul_cuda_kernel` function. +======= +// comming from `sparse_sparse_matmul_cuda_kernel` function. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Moreover this implements a RAII guard for a cusparse descriptor template struct csrMatrixRef { diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp index 7aad4309924d4..6b238d692e055 100644 --- a/aten/src/ATen/native/transformers/attention.cpp +++ b/aten/src/ATen/native/transformers/attention.cpp @@ -207,7 +207,11 @@ Tensor qkv_projection( } else { // encoder-decoder attention // TODO: is there a more efficient way to set this up? +<<<<<<< HEAD // TODO: can we stay nested instead of using cat? Probably just make a +======= + // TODO: can we stay nested insted of using cat? Probably just make a +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // NestedTensor out of the matmul results or something? auto q_kv_weight_s = at::native::split_with_sizes(qkv_weight, {embed_dim, embed_dim * 2}, 0); @@ -776,7 +780,11 @@ Tensor scaled_dot_product_attention( #ifdef USE_MPS const auto any_nested = query_.is_nested() || key.is_nested() || value.is_nested(); const bool any_inputs_require_grad = query_.requires_grad() || key.requires_grad() || value.requires_grad(); +<<<<<<< HEAD const auto all_contiguous = query_.is_contiguous_or_false() && key.is_contiguous_or_false() && value.is_contiguous_or_false(); +======= + const auto all_contiguous = query_.is_contiguous() && key.is_contiguous() && value.is_contiguous(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (query_device_type == DeviceType::MPS && dropout_p == 0.0 && !(GradMode::is_enabled() && any_inputs_require_grad) && (all_contiguous || mps::is_macos_13_or_newer(mps::MacOSVersion::MACOS_VER_15_0_PLUS)) diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index c2193f2378dd5..dacd26a628822 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -915,6 +915,19 @@ std::tuple>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // TODO(eqy): debug mask support // BHSD ... const int64_t batch_size = cumulative_sequence_length_q.value().size(0) - 1; @@ -1402,7 +1415,11 @@ std::tuple _efficient_ if(at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) { +<<<<<<< HEAD #if defined(USE_ROCM_CK_SDPA) +======= +#if defined(USE_CK_FLASH_ATTENTION) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::optional out(res); std::optional seqused_k = std::nullopt; std::optional alibi_slopes = std::nullopt; diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu index 55fc1e261219e..46843688064ef 100644 --- a/aten/src/ATen/native/transformers/cuda/attention_backward.cu +++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu @@ -27,8 +27,11 @@ #include #include #include +<<<<<<< HEAD #include #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include @@ -100,14 +103,22 @@ std::tuple _flash_attention_backward( std::optional dk{std::nullopt}; std::optional dv{std::nullopt}; +<<<<<<< HEAD // The kernel computes regardless we will drop for this functions return +======= + // The kernel computes irregardless we will drop for this functions return +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Tensor grad_softmax; // Currently unused args: std::optional alibi_slopes{std::nullopt}; const float softcap = 0.0; +<<<<<<< HEAD bool deterministic{false}; +======= + bool determinisitic{false}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto& ctx = at::globalContext(); if (ctx.deterministicAlgorithms()) { if (ctx.deterministicAlgorithmsWarnOnly()) { @@ -115,7 +126,11 @@ std::tuple _flash_attention_backward( "Flash Attention defaults to a non-deterministic algorithm. ", "To explicitly enable determinism call torch.use_deterministic_algorithms(True, warn_only=False)."); } else { +<<<<<<< HEAD deterministic = true; +======= + determinisitic = true; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } @@ -150,7 +165,11 @@ std::tuple _flash_attention_backward( non_null_window_right, #endif softcap, +<<<<<<< HEAD deterministic, +======= + determinisitic, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) philox_seed, philox_offset); return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue)); @@ -178,7 +197,11 @@ std::tuple _flash_attention_backward( non_null_window_right, #endif softcap, +<<<<<<< HEAD deterministic, +======= + determinisitic, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) philox_seed, philox_offset); return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue)); @@ -188,7 +211,11 @@ std::tuple _flash_attention_backward( return std::make_tuple(Tensor(), Tensor(), Tensor()); } +<<<<<<< HEAD std::tuple _cudnn_attention_backward( +======= +std::tuple _scaled_dot_product_cudnn_attention_backward_cuda( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const Tensor& grad_out, const Tensor& query, const Tensor& key, @@ -215,6 +242,7 @@ std::tuple _cudnn_attention_backward( } } +<<<<<<< HEAD const bool is_nested = cum_seq_q.defined(); const int64_t max_seqlen_batch_q = query.size(2); const int64_t max_seqlen_batch_k = key.size(2); @@ -326,6 +354,59 @@ std::tuple _cudnn_attention_backward( philox_offset); return std::make_tuple(std::move(dq), std::move(dk), std::move(dv)); } +======= + const int64_t batch_size = query.size(0); + const int64_t num_heads = query.size(1); + const int64_t head_dim_qk = query.size(3); + const int64_t head_dim_v = value.size(3); + const int64_t max_seqlen_batch_q = query.size(2); + const int64_t max_seqlen_batch_k = key.size(2); + + // This is needed because SaveVariable automatically converts + // std::optional to undefined tensor + std::optional attn_bias_; + if (attn_bias.defined()) { + attn_bias_ = attn_bias; + } + if (attn_bias_.has_value()) { + const auto bias_dim = attn_bias_.value().dim(); + if (bias_dim == 2) { + attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); + } else if (bias_dim == 3) { + attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); + } else { + TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D"); + attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k}); + } + } + + const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float(); + auto dq = at::empty_like(query); + auto dk = at::empty_like(key); + auto dv = at::empty_like(value); + run_cudnn_SDP_bprop(batch_size /*int64_t b*/, + num_heads /*int64_t h*/, + max_q/*int64_t s_q*/, + max_k/*int64_t s_kv*/, + head_dim_qk /*int64_t d_qk*/, + head_dim_v /*int64_t d_v*/, + softmax_scale /*float scaling_factor*/, + is_causal /*bool is_causal*/, + dropout_p /*float dropout_probability*/, + query /*const Tensor& q*/, + key /*const Tensor& k*/, + value /*const Tensor& v*/, + attn_bias_ /*const std::optional& attn_bias*/, + out /*const Tensor& o*/, + grad_out/*const Tensor& dO*/, + logsumexp.unsqueeze(-1)/*const Tensor& softmaxstats*/, + dq/*Tensor& dQ*/, + dk/*Tensor& dK*/, + dv/*Tensor& dV*/, + philox_seed/*Tensor& dropoutseed*/, + philox_offset/*Tensor& dropoutoffset*/); + return std::make_tuple(std::move(dq), std::move(dk), std::move(dv)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } std::tuple @@ -495,7 +576,11 @@ _efficient_attention_backward( // ROCM Implementation if(at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) { +<<<<<<< HEAD #if defined(USE_ROCM_CK_SDPA) +======= +#if defined(USE_CK_FLASH_ATTENTION) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const auto my_softmax_scale = sdp::calculate_scale(query, scale).expect_float(); // Store grad_bias in optional std::optional opt_grad_bias = grad_bias; @@ -1185,6 +1270,7 @@ std::tuple _scaled_dot_product_e } } +<<<<<<< HEAD std::tuple _scaled_dot_product_cudnn_attention_backward_cuda( const Tensor& grad_out, const Tensor& query, @@ -1221,4 +1307,6 @@ std::tuple _scaled_dot_product_cudnn_attention_backward_ scale); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace at::native diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp index a4e37da1a4ae9..818bd8ca4fd57 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp @@ -32,9 +32,13 @@ #endif +<<<<<<< HEAD C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include C10_DIAGNOSTIC_POP() +======= +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include @@ -391,14 +395,29 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head std::optional gen_) { auto dprops = at::cuda::getCurrentDeviceProperties(); +<<<<<<< HEAD bool is_sm80_or_newer = (dprops->major * 10) >= 80; TORCH_CHECK(is_sm80_or_newer, "FlashAttention only supports Ampere GPUs or newer."); +======= + // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; + bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; + bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + bool is_sm10x = dprops->major == 10 && dprops->minor >= 0; + bool is_sm120_or_sm121 = dprops->major == 12 && dprops->minor <= 1; + TORCH_CHECK(is_sm120_or_sm121 || is_sm10x || is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); + // We will support Turing in the near future + // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto q_dtype = q.dtype(); TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, "FlashAttention only support fp16 and bf16 data type"); if (q_dtype == at::kBFloat16) { +<<<<<<< HEAD TORCH_CHECK(is_sm80_or_newer, "bfloat16 is only supported on Ampere GPUs or newer"); +======= + TORCH_CHECK(is_sm120_or_sm121 || is_sm10x || is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); @@ -573,14 +592,29 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q std::optional gen_) { auto dprops = at::cuda::getCurrentDeviceProperties(); +<<<<<<< HEAD bool is_sm80_or_newer = (dprops->major * 10) >= 80; TORCH_CHECK(is_sm80_or_newer, "FlashAttention only supports Ampere GPUs or newer."); +======= + // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; + bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; + bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + bool is_sm10x = dprops->major == 10 && dprops->minor >= 0; + bool is_sm120_or_sm121 = dprops->major == 12 && dprops->minor <= 1; + TORCH_CHECK(is_sm120_or_sm121 || is_sm10x || is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); + // We will support Turing in the near future + // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto q_dtype = q.dtype(); TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, "FlashAttention only support fp16 and bf16 data type"); if (q_dtype == at::kBFloat16) { +<<<<<<< HEAD TORCH_CHECK(is_sm80_or_newer, "bfloat16 is only supported on Ampere GPUs or newer"); +======= + TORCH_CHECK(is_sm120_or_sm121 || is_sm10x || is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); @@ -828,8 +862,20 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si #endif if (is_causal) { window_size_right = 0; } auto dprops = at::cuda::getCurrentDeviceProperties(); +<<<<<<< HEAD bool is_sm80_or_newer = (dprops->major * 10) >= 80; TORCH_CHECK(is_sm80_or_newer, "FlashAttention only supports Ampere GPUs or newer."); +======= + // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; + bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; + bool is_sm80 = dprops->major == 8 && dprops->minor == 0; + bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + bool is_sm10x = dprops->major == 10 && dprops->minor >= 0; + bool is_sm120_or_sm121 = dprops->major == 12 && dprops->minor <= 1; + TORCH_CHECK(is_sm120_or_sm121 || is_sm10x || is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); + // We will support Turing in the near future + // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool is_dropout = p_dropout > 0.0; auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -838,7 +884,11 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, "FlashAttention only support fp16 and bf16 data type"); if (q_dtype == at::kBFloat16) { +<<<<<<< HEAD TORCH_CHECK(is_sm80_or_newer, "bfloat16 is only supported on Ampere GPUs or newer"); +======= + TORCH_CHECK(is_sm120_or_sm121 || is_sm10x || is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); @@ -868,7 +918,11 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!"); TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256"); if (head_size > 192 && (head_size <= 224 || is_dropout)) { +<<<<<<< HEAD TORCH_CHECK(is_sm80_or_newer, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800"); +======= + TORCH_CHECK(is_sm80 || is_sm90 || is_sm10x || is_sm120_or_sm121, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); @@ -1038,9 +1092,21 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size if (is_causal) { window_size_right = 0; } auto dprops = at::cuda::getCurrentDeviceProperties(); +<<<<<<< HEAD bool is_sm80_or_newer = (dprops->major * 10) >= 80; TORCH_CHECK(is_sm80_or_newer, "FlashAttention only supports Ampere GPUs or newer."); +======= + // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; + bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; + bool is_sm80 = dprops->major == 8 && dprops->minor == 0; + bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + bool is_sm10x = dprops->major == 10 && dprops->minor >= 0; + bool is_sm120_or_sm121 = dprops->major == 12 && dprops->minor <= 1; + TORCH_CHECK(is_sm120_or_sm121 || is_sm10x || is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); + // We will support Turing in the near future + // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool is_dropout = p_dropout > 0.0; auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -1048,7 +1114,11 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, "FlashAttention only support fp16 and bf16 data type"); if (q_dtype == at::kBFloat16) { +<<<<<<< HEAD TORCH_CHECK(is_sm80_or_newer, "bfloat16 is only supported on Ampere GPUs or newer"); +======= + TORCH_CHECK(is_sm120_or_sm121 || is_sm10x || is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); @@ -1083,7 +1153,11 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!"); TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256"); if (head_size > 192 && (head_size <= 224 || is_dropout)) { +<<<<<<< HEAD TORCH_CHECK(is_sm80_or_newer, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800"); +======= + TORCH_CHECK(is_sm80 || is_sm90 || is_sm10x || is_sm120_or_sm121, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); @@ -1257,14 +1331,29 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he ) { auto dprops = at::cuda::getCurrentDeviceProperties(); +<<<<<<< HEAD bool is_sm80_or_newer = (dprops->major * 10) >= 80; TORCH_CHECK(is_sm80_or_newer, "FlashAttention only supports Ampere GPUs or newer."); +======= + // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; + bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; + bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + bool is_sm10x = dprops->major == 10 && dprops->minor >= 0; + bool is_sm120_or_sm121 = dprops->major == 12 && dprops->minor <= 1; + TORCH_CHECK(is_sm120_or_sm121 || is_sm10x || is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); + // We will support Turing in the near future + // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto q_dtype = q.dtype(); TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, "FlashAttention only support fp16 and bf16 data type"); if (q_dtype == at::kBFloat16) { +<<<<<<< HEAD TORCH_CHECK(is_sm80_or_newer, "bfloat16 is only supported on Ampere GPUs or newer"); +======= + TORCH_CHECK(is_sm120_or_sm121 || is_sm10x || is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(vcache.dtype() == q_dtype, "query and value must have the same dtype"); @@ -1299,7 +1388,11 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size; const int num_heads_k = kcache.size(2); const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size; +<<<<<<< HEAD TORCH_CHECK(batch_size > 0, "batch size must be positive"); +======= + TORCH_CHECK(batch_size > 0, "batch size must be postive"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_rescale_output.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_rescale_output.h index 7115cb07a793e..156f0d2ad2c7e 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_rescale_output.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_rescale_output.h @@ -125,7 +125,11 @@ class MemoryEfficientAttentionNormalize { FragmentSource const& source) const { assert(!isFirst); +<<<<<<< HEAD // Convert source to internal compute numeric type +======= + // Convert source to interal compute numeric type +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) NumericArrayConverter source_converter; NumericArrayConverter @@ -164,7 +168,11 @@ class MemoryEfficientAttentionNormalize { const { assert(isFirst); +<<<<<<< HEAD // Convert source to internal compute numeric type +======= + // Convert source to interal compute numeric type +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) NumericArrayConverter accumulator_converter; diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_base.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_base.h index 3c3566512b45c..beaab057c376d 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_base.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_base.h @@ -88,7 +88,11 @@ class CustomMmaBase { Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>; +<<<<<<< HEAD /// Number of warp-level GEMM operations +======= + /// Number of warp-level GEMM oeprations +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/iterators/epilogue_predicated_tile_iterator.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/iterators/epilogue_predicated_tile_iterator.h index e75a1b9001e02..cca1157d6d426 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/iterators/epilogue_predicated_tile_iterator.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/iterators/epilogue_predicated_tile_iterator.h @@ -68,7 +68,11 @@ namespace threadblock { /// ForwardTileIterator /// template < +<<<<<<< HEAD typename ThreadMap_, ///< Thread map (concept: OutputTileThreadMap) +======= + typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) typename Element_, ///< Element data type bool ScatterD = false, ///< Scatter D operand or not bool UseCUDAStore = false> diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h index 20495a05474b0..1a770c970451b 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h @@ -245,7 +245,11 @@ struct AttentionBackwardKernel { static constexpr int64_t kWarpSize = 32; // If this is true, we store and accumulate dK/dV in RF +<<<<<<< HEAD // rather than going back to gmem every time +======= + // rather than going back to gmem everytime +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static constexpr bool kIsHalf = cutlass::sizeof_bits::value <= 16; static constexpr bool kOutputInRF = kIsHalf && kMaxK <= kBlockSizeI; static_assert( diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index 8eec0de7773f3..f09479de9a886 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -61,6 +61,7 @@ namespace sdp { namespace { +<<<<<<< HEAD // tracks whether we've set the default priority order once, to avoid setting // it redundantly or overwriting a user-specified priority order // when the priority order context manager is used before the default priority @@ -84,6 +85,23 @@ bool check_prefer_cudnn_attention() { auto dprops = at::cuda::getCurrentDeviceProperties(); auto major = dprops->major; return (major == 9 || major == 10) && !dprops->minor; +======= +// TODO(eqy): more benchmarking to determine whether this should include sm86/89 +// Needs to be kept in-sync with test_fused_chocie in test_transformers.py +bool check_prefer_cudnn_attention() { + // TODO(eqy): Re-enable by default after upgrading to a release later than 9.5.0 + // see context: https://github.com/pytorch/pytorch/issues/138340 + // return false; +#if defined(CUDNN_VERSION) + +#if CUDNN_VERSION > 90000 + auto dprops = at::cuda::getCurrentDeviceProperties(); + return dprops->major >= 9; +#else + return false; +#endif + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #else return false; #endif @@ -91,6 +109,7 @@ bool check_prefer_cudnn_attention() { // flash_attention V2 is universally faster than efficient_attention and Math std::array priority_order(sdp_params const& params) { +<<<<<<< HEAD if (!priority_order_init_) { priority_order_init_ = true; if (check_prefer_cudnn_attention()) { @@ -101,6 +120,8 @@ std::array priority_order(sdp_params const& params) { at::globalContext().setSDPPriorityOrder(cudnn_order); } } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return at::globalContext().sDPPriorityOrder(); } @@ -437,7 +458,11 @@ bool check_flash_causal_non_square_seqlens(sdp_params const& params, bool debug) bool check_all_tensors_on_device(sdp_params const& params, bool debug) { // Check that all tensors are on the GPU device +<<<<<<< HEAD // This should be handled by the stub dispatch, but we call can_use_*_attention +======= + // This should be handled by the stub dispatch, but whe call can_use_*_attention +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // directly from python we need to ensure that the tensors are on cuda if (params.query.device().type() != at::DeviceType::CUDA) { if (debug) { @@ -473,9 +498,15 @@ bool check_cudnn_tensor_shapes(sdp_params const& params, bool debug) { return false; } auto head_dim_limit = 128; +<<<<<<< HEAD if (cudnn_version >= 91000) { auto dprops = at::cuda::getCurrentDeviceProperties(); if (dprops->major == 9 && !dprops->minor) { +======= + if (cudnn_version >= 90501) { + auto dprops = at::cuda::getCurrentDeviceProperties(); + if (dprops->major == 9 && !dprops->minor) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) head_dim_limit = 256; } } @@ -512,6 +543,7 @@ bool check_cudnn_tensor_shapes(sdp_params const& params, bool debug) { return false; } } +<<<<<<< HEAD if (s_k == 1) { if (debug) { TORCH_WARN_ONCE("cudnn SDPA does not support key/value sequence length 1."); @@ -521,6 +553,11 @@ bool check_cudnn_tensor_shapes(sdp_params const& params, bool debug) { if (s_q == 1 && params.dropout != 0.0) { if (debug) { TORCH_WARN_ONCE("cudnn SDPA does not support query sequence length 1 with dropout."); +======= + if (s_q == 1 || s_k == 1) { + if (debug) { + TORCH_WARN_ONCE("cudnn SDPA does not support sequence length 1."); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } return false; } @@ -628,9 +665,15 @@ bool check_for_nested_inputs(sdp_params const& params, bool debug) { const auto dprop = at::cuda::getCurrentDeviceProperties(); // Check that the input is nested +<<<<<<< HEAD if (!(dprop->major == 9 || dprop->major == 10) && has_for_nested_inputs(params)) { if (debug) { TORCH_WARN("cuDNN SDPA supports nested tensors on SM 9.0, SM 10.0."); +======= + if (dprop->major != 9 && has_for_nested_inputs(params)) { + if (debug) { + TORCH_WARN("CuDNN SDPA supports nested tensors on SM 9.0."); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } return false; } @@ -654,7 +697,11 @@ bool check_runtime_disabled_cudnn(sdp_params const& params, bool debug) { // sdp kernels if (!at::globalContext().userEnabledCuDNNSDP()) { if (debug) { +<<<<<<< HEAD TORCH_WARN("cuDNN attention has been runtime disabled."); +======= + TORCH_WARN("CuDNN attention has been runtime disabled."); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } return false; } @@ -685,6 +732,7 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) { #endif #if defined(CUDNN_VERSION) && CUDNN_VERSION < 90000 if (debug) { +<<<<<<< HEAD TORCH_WARN(CUDNN_VERSION, " cuDNN version too old to use cuDNN Attention (< v9.0.0)"); } return false; @@ -698,14 +746,27 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) { return false; } #endif +======= + TORCH_WARN(CUDNN_VERSION, " cuDNN version too old to use CuDNN Attention (< v9.0.0)"); + } + return false; +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Define gate functions that determine if a flash kernel can be ran // Replace with std::to_array when we migrate to c++20 constexpr auto general_constraints = c10::array_of( check_runtime_disabled_cudnn, check_for_nested_inputs, +<<<<<<< HEAD + check_all_tensors_on_device, + check_tensor_shapes, +======= + check_nonzero_sequence_lengths_dense, check_all_tensors_on_device, check_tensor_shapes, + check_cudnn_tensor_shapes, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) check_cudnn_deterministic, check_dtypes_low_precision, check_attn_mask_shape, @@ -718,10 +779,15 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) { } constexpr auto dense_constraints = c10::array_of( +<<<<<<< HEAD check_nonzero_sequence_lengths_dense, check_last_dim_stride_equals_1_dense, check_batch_size_and_num_heads_dense, check_cudnn_tensor_shapes +======= + check_last_dim_stride_equals_1_dense, + check_batch_size_and_num_heads_dense +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ); if (has_only_dense_inputs(params)) { @@ -917,11 +983,14 @@ SDPBackend select_sdp_backend(sdp_params const& kernel_params) { return SDPBackend::math; } break; +<<<<<<< HEAD case SDPBackend::overrideable: if (ctx.userEnabledOverrideableSDP()) { TORCH_CHECK(false, "Invalid backend"); } break; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) default: TORCH_CHECK(false, "Invalid backend"); } @@ -938,7 +1007,11 @@ SDPBackend select_sdp_backend(sdp_params const& kernel_params) { sdp::can_use_mem_efficient_attention(kernel_params, print_debug); TORCH_WARN("Flash attention kernel not used because:"); sdp::can_use_flash_attention(kernel_params, print_debug); +<<<<<<< HEAD TORCH_WARN("cuDNN attention kernel not used because:"); +======= + TORCH_WARN("CuDNN attention kernel not used because:"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sdp::can_use_cudnn_attention(kernel_params, print_debug); TORCH_CHECK(!print_debug, "No available kernel. Aborting execution.") return SDPBackend::error; diff --git a/aten/src/ATen/native/transformers/hip/aotriton_adapter.h b/aten/src/ATen/native/transformers/hip/aotriton_adapter.h index d316808cf9bef..07d2deb9ba677 100644 --- a/aten/src/ATen/native/transformers/hip/aotriton_adapter.h +++ b/aten/src/ATen/native/transformers/hip/aotriton_adapter.h @@ -86,7 +86,11 @@ aotriton::TensorView mk_aotensor(const at::Tensor& q, std::string_view ten { const auto strides = q.strides(); int real_rank = strides.size(); +<<<<<<< HEAD if (real_rank != Rank) { // Lazy conversion of tensor_name +======= + if (real_rank != Rank) { // Lazy convertion of tensor_name +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_CHECK(false, std::string(tensor_name) + "'s rank should be " + std::to_string(Rank) + " but is " + std::to_string(real_rank)); diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip index 2467cb809fdbf..0b5839a67af04 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip @@ -236,6 +236,15 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x } else { softmax_fa_t = at::empty({ 0, 0, 0, 0 }, opts); } +<<<<<<< HEAD +======= + + at::Tensor atomic_counter; + if (is_causal) { + atomic_counter = at::zeros({1}, opts.dtype(at::kInt)); + } + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto [needs_swa, window_left, window_right] = calculate_swa(window_size_left, window_size_right, seqlen_q, @@ -249,6 +258,7 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x constexpr bool uses_swa = false; #endif +<<<<<<< HEAD // SWA in AOTriton Kernels is treated as "Generalized Causal masks" is_causal = is_causal || uses_swa; @@ -257,6 +267,8 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x atomic_counter = at::zeros({1}, opts.dtype(at::kInt)); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) hipError_t err; // TODO: Error handling using aotriton::v2::flash::attn_fwd; using sdp::aotriton_adapter::mk_aotensor; @@ -396,7 +408,11 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot CHECK_SHAPE(cu_seqlens_k, batch_size + 1); // AOTriton's varlen API needs input shapes be +<<<<<<< HEAD // (1, num_heads, total sequence length, head dimension) +======= + // (1, num_heads, total sequence lenght, head dimension) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) at::Tensor q_padded, k_padded, v_padded; at::Tensor out, out_padded; q_padded = q.unsqueeze(0).transpose(1, 2); @@ -450,9 +466,12 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot constexpr bool uses_swa = false; #endif +<<<<<<< HEAD // SWA in AOTriton Kernels is treated as "Generalized Causal masks" is_causal = is_causal || needs_swa; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto [seed_t, offset_t, philox_state, use_philox_state] = prepare_philox_arguments(p_dropout, batch_size * num_heads * 32); diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/CMakeLists.txt b/aten/src/ATen/native/transformers/hip/flash_attn/ck/CMakeLists.txt index 819880cf3bc5c..5e1980b33d581 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/CMakeLists.txt +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/CMakeLists.txt @@ -1,6 +1,10 @@ # generate a list of kernels, but not actually emit files at config stage execute_process( +<<<<<<< HEAD COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py +======= + COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/../../../../../../../../third_party/composable_kernel/example/ck_tile/01_fmha/generate.py +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) --api fwd --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_blob_list.txt RESULT_VARIABLE ret ) @@ -10,6 +14,7 @@ if(ret AND NOT ret EQUAL 0) endif() execute_process( +<<<<<<< HEAD COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd_splitkv --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_splitkv_blob_list.txt RESULT_VARIABLE ret @@ -31,6 +36,9 @@ endif() execute_process( COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py +======= + COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/../../../../../../../../third_party/composable_kernel/example/ck_tile/01_fmha/generate.py +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) --api bwd --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/bwd_blob_list.txt RESULT_VARIABLE ret ) @@ -39,14 +47,20 @@ if(ret AND NOT ret EQUAL 0) message( FATAL_ERROR "CK Tile FMHA FAILED to generate a list of BWD kernels via Python.") endif() +<<<<<<< HEAD # Generate the files for both fwd, fwd_splitkv, fwd_appendkv, and bwd execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR} +======= +# Generate the files for both fwd and bwd +execute_process(COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/../../../../../../../../third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if(ret AND NOT ret EQUAL 0) message( FATAL_ERROR "CK Tile FMHA FAILED to generate FWD kernels.") endif() +<<<<<<< HEAD execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd_splitkv --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR} ) @@ -62,6 +76,9 @@ if(ret AND NOT ret EQUAL 0) endif() execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api bwd --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR} +======= +execute_process(COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/../../../../../../../../third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api bwd --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) RESULT_VARIABLE ret ) @@ -78,6 +95,7 @@ if(ret AND NOT ret EQUAL 0) message( FATAL_ERROR "CK Tile FMHA FAILED to change make_kernel to make_kernel_pt for the fwd pass") endif() +<<<<<<< HEAD execute_process( COMMAND bash -c "${CMAKE_CURRENT_LIST_DIR}/add_make_kernel_pt.sh ${CMAKE_CURRENT_LIST_DIR}/fwd_splitkv_blob_list.txt" RESULT_VARIABLE ret) @@ -94,6 +112,8 @@ if(ret AND NOT ret EQUAL 0) message( FATAL_ERROR "CK Tile FMHA FAILED to change make_kernel to make_kernel_pt for the fwd appendkv pass") endif() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Change make_kernel to make_kernel_pt for bwd execute_process( COMMAND bash -c "${CMAKE_CURRENT_LIST_DIR}/add_make_kernel_pt.sh ${CMAKE_CURRENT_LIST_DIR}/bwd_blob_list.txt" diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/add_make_kernel_pt.sh b/aten/src/ATen/native/transformers/hip/flash_attn/ck/add_make_kernel_pt.sh index 849613f795692..51bb1bb58d656 100755 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/add_make_kernel_pt.sh +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/add_make_kernel_pt.sh @@ -21,8 +21,11 @@ while IFS= read -r file; do if [ -f "$file" ]; then # Use sed to replace "make_kernel" with "make_kernel_pt" in place sed -i 's/make_kernel/make_kernel_pt/g' "$file" +<<<<<<< HEAD sed -i 's/\#include \"fmha_fwd.hpp\"/\#include \"fmha_fwd.hpp\"\n\#include \"launch_kernel_pt.hpp\"/g' "$file" sed -i 's/\#include \"fmha_bwd.hpp\"/\#include \"fmha_bwd.hpp\"\n\#include \"launch_kernel_pt.hpp\"/g' "$file" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) echo "Updated: $file" else echo "Skipping: $file (not found)" diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/bias.hpp b/aten/src/ATen/native/transformers/hip/flash_attn/ck/bias.hpp new file mode 100644 index 0000000000000..8115288fb8877 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/bias.hpp @@ -0,0 +1,100 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +// keep sync with BlockAttentionBiasEnum +enum class bias_enum +{ + no_bias = 0, + elementwise_bias = 1, + alibi = 2, +}; + +struct bias_info +{ + bias_enum type; + /* + * simple dispatch logic + * + * if type == elementwise_bias: + * if rank_info == 0: + * bias is 1*1*s*s + * elif rank_info == 1: + * bias is 1*h*s*s + * elif rank_info == 2: + * bias is b*h*s*s + * + * elif type == alibi: + * if rank_info == 0: + * alibi in 1*h + * elif rank_info == 1: + * alibi in b*h + */ + int rank_info; + + void serialize(std::ostream& os) const + { + if(type == bias_enum::no_bias) + os << "n"; + else if(type == bias_enum::elementwise_bias) + { + os << "e"; + if(rank_info != 0) + { + os << "[" << rank_info << "]"; + } + } + else if(type == bias_enum::alibi) + { + os << "alibi"; + if(rank_info != 0) + { + os << "[" << rank_info << "]"; + } + } + } + + static bias_info decode(std::string str) + { + bias_info info{bias_enum::no_bias, 0}; + if(str == "0" || str == "n") + { + info.type = bias_enum::no_bias; + } + else if(str.compare(0, 1, "1") == 0 || str.compare(0, 1, "e") == 0 || + str.compare(0, 11, "elementwise") == 0) + { + info.type = bias_enum::elementwise_bias; + auto found_0 = str.find(':'); + if(found_0 != std::string::npos) + { + std::string e = str.substr(found_0 + 1); + info.rank_info = atoi(e.c_str()); + } + } + else if(str.compare(0, 1, "2") == 0 || str.compare(0, 1, "a") == 0 || + str.compare(0, 5, "alibi") == 0) + { + info.type = bias_enum::alibi; + auto found_0 = str.find(':'); + if(found_0 != std::string::npos) + { + std::string e = str.substr(found_0 + 1); + info.rank_info = atoi(e.c_str()); + } + } + return info; + } + + friend std::ostream& operator<<(std::ostream& os, const bias_info& bi) + { + bi.serialize(os); + return os; + } +}; diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd.hpp b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd.hpp new file mode 100644 index 0000000000000..affa40619b598 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd.hpp @@ -0,0 +1,457 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +struct FmhaBwdFp16 +{ +}; + +struct FmhaBwdBf16 +{ +}; + +template +struct FmhaBwdTypeConfig; + +template <> +struct FmhaBwdTypeConfig +{ + using QDataType = ck_tile::half_t; + using KDataType = ck_tile::half_t; + using VDataType = ck_tile::half_t; + using GemmDataType = ck_tile::half_t; + using BiasDataType = ck_tile::half_t; + using LSEDataType = float; + using AccDataType = float; // data type for gemm accumulation + using DDataType = float; + using RandValOutputDataType = uint8_t; + using ODataType = ck_tile::half_t; + using OGradDataType = ck_tile::half_t; + using QGradDataType = ck_tile::half_t; + using KGradDataType = ck_tile::half_t; + using VGradDataType = ck_tile::half_t; + using BiasGradDataType = ck_tile::half_t; +}; + +template <> +struct FmhaBwdTypeConfig +{ + using QDataType = ck_tile::bf16_t; + using KDataType = ck_tile::bf16_t; + using VDataType = ck_tile::bf16_t; + using GemmDataType = ck_tile::bf16_t; + using BiasDataType = ck_tile::bf16_t; + using LSEDataType = float; + using AccDataType = float; // data type for gemm accumulation + using DDataType = float; + using RandValOutputDataType = uint8_t; + using ODataType = ck_tile::bf16_t; + using OGradDataType = ck_tile::bf16_t; + using QGradDataType = ck_tile::bf16_t; + using KGradDataType = ck_tile::bf16_t; + using VGradDataType = ck_tile::bf16_t; + using BiasGradDataType = ck_tile::bf16_t; +}; + +struct FmhaMasks +{ + using NoMask = ck_tile::GenericAttentionMask; + using GenericMask = ck_tile::GenericAttentionMask; + using CausalMask = ck_tile::GenericAttentionMask; +}; + +// runtime args, some will passed to karg, some will used to compute grids/blocks +struct fmha_bwd_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; // bias or alibi_slope pointer + const void* o_ptr; + const void* lse_ptr; + const void* do_ptr; + void* d_ptr; + void* rand_val_ptr; + void* dq_ptr; + void* dk_ptr; + void* dv_ptr; + void* dbias_ptr; + void* dq_acc_ptr; + const void* seqstart_q_ptr; + const void* seqstart_k_ptr; + const void* seqlen_k_ptr; + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t max_seqlen_k; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + float scale; + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 + ck_tile::index_t stride_o; + ck_tile::index_t stride_randval; + ck_tile::index_t stride_do; + ck_tile::index_t stride_dq_acc; + ck_tile::index_t stride_dq; + ck_tile::index_t stride_dk; + ck_tile::index_t stride_dv; + ck_tile::index_t stride_dbias; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t nhead_stride_randval; + ck_tile::index_t nhead_stride_do; + ck_tile::index_t nhead_stride_lsed; + ck_tile::index_t nhead_stride_dq_acc; + ck_tile::index_t nhead_stride_dq; + ck_tile::index_t nhead_stride_dk; + ck_tile::index_t nhead_stride_dv; + ck_tile::index_t nhead_stride_dbias; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_o; + ck_tile::index_t batch_stride_randval; + ck_tile::index_t batch_stride_do; + ck_tile::index_t batch_stride_lsed; + ck_tile::index_t batch_stride_dq_acc; + ck_tile::index_t batch_stride_dq; + ck_tile::index_t batch_stride_dk; + ck_tile::index_t batch_stride_dv; + ck_tile::index_t batch_stride_dbias; + ck_tile::index_t split_stride_dq_acc; + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; + float p_drop; + float p_undrop; + std::variant, std::pair> + drop_seed_offset; +}; + +template +auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode) + { + return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.do_ptr, + args.d_ptr, + args.rand_val_ptr, + args.dk_ptr, + args.dv_ptr, + args.dbias_ptr, + args.dq_acc_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_do, + args.stride_dq_acc, + args.stride_dk, + args.stride_dv, + args.stride_dbias, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_do, + args.nhead_stride_lsed, + args.nhead_stride_dq_acc, + args.nhead_stride_dk, + args.nhead_stride_dv, + args.nhead_stride_dbias, + args.split_stride_dq_acc, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.drop_seed_offset); + } + else + { // create batch mode kernel arguments + return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.do_ptr, + args.d_ptr, + args.rand_val_ptr, + args.dk_ptr, + args.dv_ptr, + args.dbias_ptr, + args.dq_acc_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_do, + args.stride_dq_acc, + args.stride_dk, + args.stride_dv, + args.stride_dbias, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_do, + args.nhead_stride_lsed, + args.nhead_stride_dq_acc, + args.nhead_stride_dk, + args.nhead_stride_dv, + args.nhead_stride_dbias, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_randval, + args.batch_stride_do, + args.batch_stride_lsed, + args.batch_stride_dq_acc, + args.batch_stride_dk, + args.batch_stride_dv, + args.batch_stride_dbias, + args.split_stride_dq_acc, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.drop_seed_offset); + } + }(); + + dim3 grids = FmhaBwdDQDKDVKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_k); + return ck_tile::make_tuple(kargs, grids); +} + +template +auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args) +{ + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaBwdOGradDotOKernel::kIsGroupMode) + { + return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr, + args.do_ptr, + args.d_ptr, + args.p_undrop, + args.seqstart_q_ptr, + args.hdim_v, + args.stride_do, + args.stride_o, + args.nhead_stride_do, + args.nhead_stride_o, + args.nhead_stride_lsed); + } + else + { // create batch mode kernel arguments + return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr, + args.do_ptr, + args.d_ptr, + args.p_undrop, + args.seqlen_q, + args.hdim_v, + args.stride_do, + args.stride_o, + args.nhead_stride_do, + args.nhead_stride_o, + args.nhead_stride_lsed, + args.batch_stride_do, + args.batch_stride_o, + args.batch_stride_lsed); + } + }(); + + dim3 grids = FmhaBwdOGradDotOKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q); + return ck_tile::make_tuple(kargs, grids); +} + +template +auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args) +{ + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaBwdConvertQGradKernel::kIsGroupMode) + { + return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr, + args.dq_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.hdim_q, + args.stride_dq, + args.stride_dq_acc, + args.nhead_stride_dq, + args.nhead_stride_dq_acc, + args.split_stride_dq_acc); + } + else + { // create batch mode kernel arguments + return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr, + args.dq_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.stride_dq, + args.stride_dq_acc, + args.nhead_stride_dq, + args.nhead_stride_dq_acc, + args.batch_stride_dq, + args.batch_stride_dq_acc, + args.split_stride_dq_acc); + } + }(); + + dim3 grids = FmhaBwdConvertQGradKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q); + return ck_tile::make_tuple(kargs, grids); +} + +// this is used to pattern-match internl kernel implementation, not to instantiate kernel +template +struct fmha_bwd_dq_dk_dv_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_; + using FmhaMask = ck_tile::remove_cvref_t; + using FmhaDropout = ck_tile::remove_cvref_t; + static constexpr auto BiasEnum = BiasEnum_; + static constexpr bool kHasBiasGrad = kHasBiasGrad_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSK = kPadSK_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; + static constexpr bool kIsDeterministic = kIsDeterministic_; +}; + +template +float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config&, fmha_bwd_args); + +template +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); + +template +std::string fmha_bwd_dq_dk_dv_get_name_(); + +template +struct fmha_bwd_dot_do_o_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadDv = kPadDv_; +}; + +template +float fmha_bwd_dot_do_o_(const ck_tile::stream_config&, fmha_bwd_args); + +template +void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); + +template +std::string fmha_bwd_dot_do_o_get_name_(); + +template +struct fmha_bwd_convert_dq_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kIsDeterministic = kIsDeterministic_; +}; + +template +float fmha_bwd_convert_dq_(const ck_tile::stream_config&, fmha_bwd_args); + +template +void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); + +template +std::string fmha_bwd_convert_dq_get_name_(); + +// This is the public API, will be generated by script +struct fmha_bwd_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_group_mode; + mask_enum mask_type; + bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum + bool has_dbias; + bool has_dropout; + bool is_store_randval; + bool is_deterministic; + // TODO: padding check is inside this api +}; +template +float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&); diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd.hpp b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd.hpp new file mode 100644 index 0000000000000..2de70cd49bbb7 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd.hpp @@ -0,0 +1,824 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include + +struct FmhaFwdFp16 +{ +}; + +struct FmhaFwdBf16 +{ +}; + +struct FmhaFwdFp8 +{ +}; + +struct FmhaFwdBf8 +{ +}; + +struct FmhaFwdFp8Fp16 +{ +}; + +struct FmhaFwdFp8Bf16 +{ +}; + +template +struct FmhaFwdTypeConfig; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::half_t; + using KDataType = ck_tile::half_t; + using VDataType = ck_tile::half_t; + using BiasDataType = ck_tile::half_t; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::half_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::half_t; +}; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::bf16_t; + using KDataType = ck_tile::bf16_t; + using VDataType = ck_tile::bf16_t; + using BiasDataType = ck_tile::bf16_t; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::bf16_t; +}; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::fp8_t; + using KDataType = ck_tile::fp8_t; + using VDataType = ck_tile::fp8_t; + using BiasDataType = float; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::fp8_t; +}; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::bf8_t; + using KDataType = ck_tile::bf8_t; + using VDataType = ck_tile::bf8_t; + using BiasDataType = ck_tile::bf8_t; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::bf8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::bf8_t; +}; + +struct FmhaMasks +{ + using NoMask = ck_tile::GenericAttentionMask; + using GenericMask = ck_tile::GenericAttentionMask; + using CausalMask = ck_tile::GenericAttentionMask; +}; + +// runtime args, some will passed to karg, some will used to compute grids/blocks +struct fmha_fwd_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; // bias or alibi_slope pointer + void* rand_val_ptr; + void* lse_ptr; + void* o_ptr; + + const void* seqstart_q_ptr; + const void* seqstart_k_ptr; + const void* + seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + float scale_s; + float scale_p; + float scale_o; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 + ck_tile::index_t stride_randval; + ck_tile::index_t stride_o; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_randval; + ck_tile::index_t nhead_stride_lse; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_randval; + ck_tile::index_t batch_stride_lse; + ck_tile::index_t batch_stride_o; + + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; + + float p_drop; + bool s_randval; + + std::variant, std::pair> + drop_seed_offset; +}; + +struct fmha_fwd_splitkv_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; // bias or alibi_slope pointer + void* lse_acc_ptr; + void* o_acc_ptr; + void* lse_ptr; + void* o_ptr; + + void* block_table_ptr; + ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr + ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr + bool is_gappy; // differentiate seqstart_k_ptr usage. only used if 'block_table_ptr' is not + // nullptr. + + const void* cache_batch_idx; + + // the real seqlen_q & seqlen_k are decided by following: + // batch mode: seqlen_q = kargs.seqlen_q + // seqlen_k = kargs.seqlen_k + // group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b] + // seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b] + // or kargs.seqlen_k_ptr[b] + // + // batch mode (kvcache): + // seqlen_q = kargs.seqlen_q + // seqlen_k = kargs.seqlen_k_ptr[b] + // group mode (kvcache): + // seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b] + // + // when is_gappy=true: + // seqlen_k = kargs.seqlen_k_ptr[b] + // seqstart_k_ptr[b] now store local offset of each batch + // + // when is_gappy=false: + // seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b] + // or kargs.seqlen_k_ptr[b] + const void* seqstart_q_ptr; + const void* seqstart_k_ptr; + const void* seqlen_k_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + ck_tile::index_t num_splits; + + float scale_s; + float scale_p; + float scale_o; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 + ck_tile::index_t stride_o_acc; + ck_tile::index_t stride_o; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_lse; + ck_tile::index_t nhead_stride_lse_acc; + ck_tile::index_t nhead_stride_o_acc; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_lse; + ck_tile::index_t batch_stride_lse_acc; + ck_tile::index_t batch_stride_o_acc; + ck_tile::index_t batch_stride_o; + ck_tile::index_t split_stride_lse_acc; + ck_tile::index_t split_stride_o_acc; + + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; +}; + +struct fmha_fwd_appendkv_args +{ + void* q_ptr; + void* k_ptr; + const void* knew_ptr; + void* v_ptr; + const void* vnew_ptr; + + const void* seqlen_k_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_knew; + ck_tile::index_t batch; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + const void* rotary_cos_ptr; // only used if 'rotary_dim' > 0 + const void* rotary_sin_ptr; // only used if 'rotary_dim' > 0 + ck_tile::index_t rotary_dim; + bool has_mask; + + void* block_table_ptr; + ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr + ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr + + const void* cache_batch_idx; // only used if block_table_ptr is nullptr -> batch mode (kvcache) + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_knew; + ck_tile::index_t stride_v; + ck_tile::index_t stride_vnew; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_knew; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_vnew; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_knew; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_vnew; +}; + +template +auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaKernel::kIsGroupMode) + { + return FmhaKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.scale_p, + args.scale_o, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + else + { // create batch mode kernel arguments + return FmhaKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.scale_p, + args.scale_o, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_randval, + args.batch_stride_lse, + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + }(); + + if constexpr(FmhaKernel::kIsGroupMode) + { + dim3 grids = FmhaKernel::GridSize( + args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.seqlen_k_ptr != nullptr); + return ck_tile::make_tuple(kargs, grids); + } + else + { + dim3 grids = + FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, false); + return ck_tile::make_tuple(kargs, grids); + } +} + +template +auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(Kernel::kIsGroupMode) + { + return Kernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_acc_ptr, + args.o_acc_ptr, + args.batch, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.num_splits, + args.block_table_ptr, + args.batch_stride_block_table, + args.page_block_size, + args.is_gappy, + args.scale_s, + args.scale_p, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_o_acc, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_lse_acc, + args.nhead_stride_o_acc, + args.batch_stride_k, // only used for paged-kvcache + args.batch_stride_v, // only used for paged-kvcache + args.split_stride_lse_acc, + args.split_stride_o_acc, + args.window_size_left, + args.window_size_right, + args.mask_type); + } + else + { // create batch mode kernel arguments + return Kernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_acc_ptr, + args.o_acc_ptr, + args.batch, + args.seqlen_q, + args.seqlen_k, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.num_splits, + args.block_table_ptr, + args.batch_stride_block_table, + args.page_block_size, + args.cache_batch_idx, + args.scale_s, + args.scale_p, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_o_acc, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_lse_acc, + args.nhead_stride_o_acc, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_lse_acc, + args.batch_stride_o_acc, + args.split_stride_lse_acc, + args.split_stride_o_acc, + args.window_size_left, + args.window_size_right, + args.mask_type); + } + }(); + + dim3 grids = Kernel::GridSize( + args.batch, args.nhead_q, args.nhead_k, args.max_seqlen_q, args.hdim_v, args.num_splits); + + return ck_tile::make_tuple(kargs, grids); +} + +template +auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel argumentszs + if constexpr(Kernel::kIsGroupMode) + { + return Kernel::MakeKargs(args.lse_acc_ptr, + args.o_acc_ptr, + args.lse_ptr, + args.o_ptr, + args.batch, + args.seqstart_q_ptr, + args.hdim_v, + args.num_splits, + args.scale_o, + args.stride_o_acc, + args.stride_o, + args.nhead_stride_lse_acc, + args.nhead_stride_o_acc, + args.nhead_stride_lse, + args.nhead_stride_o, + args.split_stride_lse_acc, + args.split_stride_o_acc); + } + else + { // create batch mode kernel arguments + return Kernel::MakeKargs(args.lse_acc_ptr, + args.o_acc_ptr, + args.lse_ptr, + args.o_ptr, + args.batch, + args.seqlen_q, + args.hdim_v, + args.num_splits, + args.scale_o, + args.stride_o_acc, + args.stride_o, + args.nhead_stride_lse_acc, + args.nhead_stride_o_acc, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_lse_acc, + args.batch_stride_o_acc, + args.batch_stride_lse, + args.batch_stride_o, + args.split_stride_lse_acc, + args.split_stride_o_acc); + } + }(); + + dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); + + return ck_tile::make_tuple(kargs, grids); +} + +template +auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = Kernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.knew_ptr, + args.v_ptr, + args.vnew_ptr, + args.seqlen_q, + args.seqlen_k_ptr, + args.seqlen_knew, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.rotary_cos_ptr, + args.rotary_sin_ptr, + args.rotary_dim, + args.has_mask, + args.block_table_ptr, + args.batch_stride_block_table, + args.page_block_size, + args.cache_batch_idx, + args.stride_q, + args.stride_k, + args.stride_knew, + args.stride_v, + args.stride_vnew, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_knew, + args.nhead_stride_v, + args.nhead_stride_vnew, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_knew, + args.batch_stride_v, + args.batch_stride_vnew); + + dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.seqlen_knew); + + return ck_tile::make_tuple(kargs, grids); +} + +// this is used to pattern-match internl kernel implementation, not to instantiate kernel +template +struct fmha_fwd_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr ck_tile::index_t kM0 = kM0_; + static constexpr ck_tile::index_t kN0 = kN0_; + static constexpr ck_tile::index_t kK0 = kK0_; + static constexpr ck_tile::index_t kN1 = kN1_; + static constexpr ck_tile::index_t kK1 = kK1_; + static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_; + static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; + static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr auto BiasEnum = BiasEnum_; + static constexpr bool kStoreLse = kStoreLse_; + static constexpr bool kHasDropout = kHasDropout_; + static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSK = kPadSK_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; +}; + +template +float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args); + +template +struct fmha_fwd_splitkv_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr ck_tile::index_t kM0 = kM0_; + static constexpr ck_tile::index_t kN0 = kN0_; + static constexpr ck_tile::index_t kK0 = kK0_; + static constexpr ck_tile::index_t kN1 = kN1_; + static constexpr ck_tile::index_t kK1 = kK1_; + static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_; + static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; + static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr auto BiasEnum = BiasEnum_; + static constexpr bool kStoreLse = kStoreLse_; + static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSK = kPadSK_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; + static constexpr bool kIsPagedKV = kIsPagedKV_; +}; + +template +void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args); + +template +std::string fmha_fwd_splitkv_get_name_(); + +template +struct fmha_fwd_splitkv_combine_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr ck_tile::index_t kN1 = kN1_; + static constexpr bool kStoreLse = kStoreLse_; + static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadDv = kPadDv_; +}; + +template +void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args); + +template +std::string fmha_fwd_splitkv_combine_get_name_(); + +// this is used to pattern-match internl kernel implementation, not to instantiate kernel +template +struct fmha_fwd_appendkv_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kTileSizeS = kTileSizeS_; + static constexpr ck_tile::index_t kTileSizeSk = kTileSizeSk_; + static constexpr ck_tile::index_t kTileSizeD = kTileSizeD_; + static constexpr ck_tile::index_t kTileSizeDv = kTileSizeDv_; + static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSk = kPadSk_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; + static constexpr auto RotaryEnum = RotaryEnum_; + static constexpr bool kIsPagedKV = kIsPagedKV_; +}; + +template +float fmha_fwd_appendkv_(const ck_tile::stream_config&, fmha_fwd_appendkv_args); + +// This is the public API, will be generated by script +struct fmha_fwd_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_group_mode; + bool is_v_rowmajor; + mask_enum mask_type; + bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum + bool has_lse; + bool has_dropout; + bool do_fp8_static_quant; + // TODO: padding check is inside this api +}; +float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&); + +struct fmha_fwd_splitkv_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_group_mode; + bool is_v_rowmajor; + mask_enum mask_type; + bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum + bool has_lse; + bool do_fp8_static_quant; + // TODO: padding check is inside this api +}; +float fmha_fwd_splitkv(fmha_fwd_splitkv_traits, + fmha_fwd_splitkv_args, + const ck_tile::stream_config&); + +struct fmha_fwd_appendkv_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_v_rowmajor; + rope_enum rope_type; +}; +float fmha_fwd_appendkv(fmha_fwd_appendkv_traits, + fmha_fwd_appendkv_args, + const ck_tile::stream_config&); diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mask.hpp b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mask.hpp new file mode 100644 index 0000000000000..133049057d782 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mask.hpp @@ -0,0 +1,157 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include +#include + +// keep this in sync with ck_tile::GenericAttentionMaskEnum +enum class mask_enum +{ + no_mask = 0, + mask_top_left, + mask_bottom_right, + window_generic, +}; + +struct mask_info +{ + mask_enum type; + ck_tile::index_t y, x; + ck_tile::index_t left, right; // FA style SWA left/right + + void serialize(std::ostream& os) const + { + if(type == mask_enum::no_mask) + os << "n"; + else if(type == mask_enum::mask_top_left) + os << "t(" << left << ":" << right << ")"; + else if(type == mask_enum::mask_bottom_right) + os << "b(" << left << ":" << right << ")"; + else + { + os << "g(" << y << ":" << x << ")"; + } + } + static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k) + { + ck_tile::index_t x_total = seqlen_k; + ck_tile::index_t y_total = seqlen_q; + mask_info tmp; + auto found_0 = str.find(':'); + if(found_0 != std::string::npos) + { + std::string t = str.substr(0, found_0); + std::string v = str.substr(found_0 + 1); + if(t == "xt" || t == "xb") + { + // xformer style sliding window attn from top-left + ck_tile::index_t window_size = atoi(v.c_str()); + ck_tile::index_t left_size = -1; + ck_tile::index_t right_size = 0; + if(window_size > 0) + { + left_size = window_size / 2; + right_size = window_size - 1 - left_size; + } + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + left_size, right_size, y_total, x_total, t == "xt"); + + tmp.type = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right; + tmp.y = r.at(ck_tile::number<0>{}); + tmp.x = r.at(ck_tile::number<1>{}); + tmp.left = left_size; + tmp.right = right_size; + } + else + { + auto found_1 = v.find(","); + if(found_1 == std::string::npos) + { + printf("not supported value %s, %s\n", v.c_str(), str.c_str()); + assert(0); + } + tmp.type = mask_enum::window_generic; + ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str()); + ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str()); + // TODO: some validation + if(t == "t") + { + tmp.type = mask_enum::mask_top_left; + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + v0, v1, y_total, x_total, true); + tmp.y = r.at(ck_tile::number<0>{}); + tmp.x = r.at(ck_tile::number<1>{}); + tmp.left = v0; + tmp.right = v1; + } + else if(t == "b") + { + tmp.type = mask_enum::mask_bottom_right; + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + v0, v1, y_total, x_total, false); + tmp.y = r.at(ck_tile::number<0>{}); + tmp.x = r.at(ck_tile::number<1>{}); + tmp.left = v0; + tmp.right = v1; + } + else if(t == "g") + { + tmp.y = v0; + tmp.x = v1; + tmp.left = v0; // TODO: don't use this? + tmp.right = v1; + } + else + { + printf("not supported type %s, %s\n", t.c_str(), str.c_str()); + assert(0); + } + } + } + else + { + auto set_causal_top_left = [&]() { + tmp.type = mask_enum::mask_top_left; + tmp.y = seqlen_q; + tmp.x = 1; + tmp.left = -1; + tmp.right = 0; + }; + auto set_causal_bottom_right = [&]() { + tmp.type = mask_enum::mask_bottom_right; + tmp.y = seqlen_q; + tmp.x = seqlen_k - seqlen_q + 1; + tmp.left = -1; + tmp.right = 0; + }; + if(str == "t") + set_causal_top_left(); + else if(str == "b") + set_causal_bottom_right(); + else + { + tmp.type = static_cast(atoi(str.c_str())); + if(tmp.type == mask_enum::mask_top_left) + { + set_causal_top_left(); + } + else if(tmp.type == mask_enum::mask_bottom_right) + { + set_causal_bottom_right(); + } + } + } + return tmp; + } + + friend std::ostream& operator<<(std::ostream& os, const mask_info& mi) + { + mi.serialize(os); + return os; + } +}; diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_bwd_ck.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_bwd_ck.hip index 59669afb93d2f..c430a611347db 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_bwd_ck.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_bwd_ck.hip @@ -1,7 +1,11 @@ #include #include +<<<<<<< HEAD #if defined(USE_ROCM_CK_SDPA) +======= +#if defined(USE_CK_FLASH_ATTENTION) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace pytorch_flash { std::tuple< at::Tensor, // dQ @@ -117,4 +121,8 @@ mem_eff_backward_ck( } } // namespace pytorch_flash +<<<<<<< HEAD #endif // USE_ROCM_CK_SDPA +======= +#endif // USE_CK_FLASH_ATTENTION +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_ck_api.h b/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_ck_api.h index e92006ef6315c..7390b24cc9592 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_ck_api.h +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_ck_api.h @@ -3,7 +3,11 @@ #include +<<<<<<< HEAD #if defined(USE_ROCM_CK_SDPA) +======= +#if defined(USE_CK_FLASH_ATTENTION) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace pytorch_flash { std::tuple< @@ -64,4 +68,8 @@ mem_eff_backward_ck( const at::Tensor philox_offset); } // namespace pytorch_flash +<<<<<<< HEAD #endif // USE_ROCM_CK_SDPA +======= +#endif // USE_CK_FLASH_ATTENTION +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_fwd_ck.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_fwd_ck.hip index d15c5105d0b46..97f35a22ac788 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_fwd_ck.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/me_fwd_ck.hip @@ -1,7 +1,11 @@ #include #include +<<<<<<< HEAD #if defined(USE_ROCM_CK_SDPA) +======= +#if defined(USE_CK_FLASH_ATTENTION) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace pytorch_flash { std::tuple< at::Tensor, // output @@ -93,4 +97,8 @@ mem_eff_forward_ck( } } // namespace pytorch_flash +<<<<<<< HEAD #endif // USE_ROCM_CK_SDPA +======= +#endif // USE_CK_FLASH_ATTENTION +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_bwd_ck.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_bwd_ck.hip index 854ac950a867d..46eb9bb9d3b7d 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_bwd_ck.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_bwd_ck.hip @@ -3,7 +3,10 @@ ******************************************************************************/ #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include @@ -29,6 +32,7 @@ fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask, deterministic}; } +<<<<<<< HEAD aiter::mha_bwd_traits get_mha_bwd_traits(fmha_bwd_traits t, mask_info mask) @@ -49,6 +53,8 @@ aiter::mha_bwd_traits get_mha_bwd_traits(fmha_bwd_traits t, mask_info mask) } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask, // sizes const int b, @@ -122,11 +128,19 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask, ck_tile::index_t stride_dv = dv.stride(1); ck_tile::index_t nhead_stride_dv = dv.stride(2); +<<<<<<< HEAD // dq_acc: (split, batch_size, nheads, seqlen_q, hdim) ck_tile::index_t split_stride_dq_acc = dq_acc.stride(0); ck_tile::index_t batch_stride_dq_acc = dq_acc.stride(1); ck_tile::index_t stride_dq_acc = dq_acc.stride(3); ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(2); +======= + // dq_acc: (split, batch_size, seqlen_q, nheads, hdim) + ck_tile::index_t split_stride_dq_acc = dq_acc.stride(0); + ck_tile::index_t batch_stride_dq_acc = dq_acc.stride(1); + ck_tile::index_t stride_dq_acc = dq_acc.stride(2); + ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(3); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // bias: (batch_size, nheads, seqlen_q, seqlen_k) void *attn_bias_ptr = nullptr; @@ -372,11 +386,19 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x at::Tensor dq_accum; if (!deterministic) { +<<<<<<< HEAD dq_accum = at::zeros({1, batch_size, num_heads, seqlen_q, head_size_8x}, opts.dtype(at::kFloat)); } else { const ck_tile::index_t kN0 = head_size_8x <= 128 ? 128 : 64; const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(seqlen_k, kN0); dq_accum = at::zeros({nsplits, batch_size, num_heads, seqlen_q, head_size_8x}, opts.dtype(at::kFloat)); +======= + dq_accum = at::zeros({1, batch_size, seqlen_q, num_heads, head_size_8x}, opts.dtype(at::kFloat)); + } else { + const ck_tile::index_t kN0 = head_size_8x <= 128 ? 128 : 64; + const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(seqlen_k, kN0); + dq_accum = at::zeros({nsplits, batch_size, seqlen_q, num_heads, head_size_8x}, opts.dtype(at::kFloat)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } at::Tensor dk_expanded, dv_expanded; @@ -397,6 +419,17 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x if (seqlen_q > 0) { ck_tile::stream_config stream_config{stream}; dq.zero_(); // ck use atomic operation on dq +<<<<<<< HEAD +======= + auto traits = + get_ck_fmha_bwd_traits(mask, + q_dtype_str, + head_size_8x, + is_dropout, + attn_bias_.has_value(), + deterministic, + bias_requires_grad); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto args = get_ck_fmha_bwd_args( @@ -424,6 +457,7 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x softmax_scale, p_dropout, drop_seed_offset); +<<<<<<< HEAD float t = aiter::mha_bwd(args, stream_config, @@ -441,6 +475,9 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x +======= + float t = fmha_bwd(traits, args, stream_config); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd"); } else { // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip index 05f97414acdd8..1539bcff520e1 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip @@ -22,7 +22,10 @@ fmha_fwd_traits get_ck_fmha_fwd_traits(const mask_info &mask, dtype, false, // is_group_mode true, // is_v_rowmajor +<<<<<<< HEAD false, // has_logits_soft_cap +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mask.type, enable_bias ? bias_enum::elementwise_bias : bias_enum::no_bias, has_lse, @@ -86,7 +89,10 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, ck_tile::index_t stride_attn_bias = 0; ck_tile::index_t batch_stride_bias = 0; ck_tile::index_t nhead_stride_bias = 0; +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (attn_bias_.has_value()) { auto a_b = attn_bias_.value(); CHECK_DEVICE(a_b); @@ -96,6 +102,10 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, nhead_stride_bias = a_b.stride(1); batch_stride_bias = a_b.stride(0); } +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return fmha_fwd_args{q.data_ptr(), k.data_ptr(), v.data_ptr(), @@ -117,7 +127,10 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, softmax_scale, // scale_s 1, // scale_p 1, // scale_o +<<<<<<< HEAD 0.0f, // logits_soft_cap +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) stride_q, stride_k, stride_v, @@ -141,7 +154,10 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, mask.left, mask.right, static_cast(mask.type), +<<<<<<< HEAD -1, // min_seqlen_q +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) p_dropout, has_dropout_randval, drop_seed_offset}; diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_fwd_ck.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_fwd_ck.hip index ee6261df8a91a..e9a724170e772 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_fwd_ck.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_fwd_ck.hip @@ -20,7 +20,10 @@ fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask, dtype, true, // is_group_mode true, // is_v_rowmajor +<<<<<<< HEAD false, // has_logits_soft_cap +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mask.type, enable_bias ? bias_enum::elementwise_bias : bias_enum::no_bias, has_lse, @@ -118,7 +121,10 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, softmax_scale, // scale_s 1, // scale_p 1, // scale_o +<<<<<<< HEAD 0.0f, // logits_soft_cap +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) stride_q, stride_k, stride_v, @@ -142,7 +148,10 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, mask.left, mask.right, static_cast(mask.type), +<<<<<<< HEAD -1, // min_seqlen_q +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) p_dropout, has_dropout_randval, drop_seed_offset}; @@ -212,7 +221,11 @@ mha_varlen_fwd_ck(const at::Tensor &q, // total_q x num_heads const int total_q = q.size(0); const int total_k = k.size(0); +<<<<<<< HEAD TORCH_CHECK(batch_size > 0, "batch size must be positive"); +======= + TORCH_CHECK(batch_size > 0, "batch size must be postive"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_CHECK(head_size_og <= 256, "CK only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/rotary.hpp b/aten/src/ATen/native/transformers/hip/flash_attn/ck/rotary.hpp new file mode 100644 index 0000000000000..85754c0378725 --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/rotary.hpp @@ -0,0 +1,84 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +// keep sync with RotaryEmbeddingEnum +enum class rope_enum +{ + none = 0, + interleaved = 1, + half_rotated = 2, +}; + +template +std::tuple, ck_tile::HostTensor> +generate_rotary_cos_sin(ck_tile::index_t seqlen, + ck_tile::index_t rotary_dim, + std::optional seed = std::nullopt) +{ + // return dummy tensors if we won't apply RoPE at all + if(rotary_dim <= 0) + { + ck_tile::HostTensor dummy({1, 1}); + return std::make_tuple(dummy, dummy); + } + + std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}()); + std::uniform_real_distribution generator(0.0f, 1.0f); + + const ck_tile::index_t num_rows = seqlen * 2; + const ck_tile::index_t num_cols = rotary_dim / 2; + + using std::begin, std::end; + + ck_tile::HostTensor angle({num_rows, num_cols}); + std::generate(begin(angle), end(angle), [&] { return generator(random_engine) * 2 * M_PI; }); + + ck_tile::HostTensor cos({num_rows, num_cols}); + std::transform(begin(angle), end(angle), begin(cos), [](float origin_value) { + return ck_tile::type_convert(std::cos(origin_value)); + }); + + ck_tile::HostTensor sin({num_rows, num_cols}); + std::transform(begin(angle), end(angle), begin(sin), [](float origin_value) { + return ck_tile::type_convert(std::sin(origin_value)); + }); + + return std::make_tuple(cos, sin); +} + +template +std::tuple, ck_tile::HostTensor> +slice_rotary_cos_sin(const ck_tile::HostTensor& cos, + const ck_tile::HostTensor& sin, + ck_tile::index_t seqlen_offset, + ck_tile::index_t seqlen) +{ + assert(cos.get_num_of_dimension() == 2 && sin.get_num_of_dimension() == 2); + assert(cos.get_length(0) == sin.get_length(0) && cos.get_length(1) == sin.get_length(1)); + + assert(static_cast(seqlen_offset + seqlen) <= cos.get_length(0)); + + const ck_tile::index_t num_rows = seqlen; + const ck_tile::index_t num_cols = cos.get_length(1); + + ck_tile::HostTensor cos_pt({num_rows, num_cols}); + cos_pt.ForEach([&](auto& self, auto i) { self(i) = cos(i[0] + seqlen_offset, i[1]); }); + + ck_tile::HostTensor sin_pt({num_rows, num_cols}); + sin_pt.ForEach([&](auto& self, auto i) { self(i) = sin(i[0] + seqlen_offset, i[1]); }); + + return std::make_tuple(cos_pt, sin_pt); +} diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h index 71a1959065970..e7348297c054d 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h +++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h @@ -147,7 +147,11 @@ std::tuple mha_varlen_bwd_aot( const at::Tensor& philox_seed, const at::Tensor& philox_offset); +<<<<<<< HEAD #if defined(USE_ROCM_CK_SDPA) +======= +#if defined(USE_CK_FLASH_ATTENTION) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // CK implementation TORCH_API std::tuple< @@ -333,7 +337,11 @@ mha_varlen_fwd( const float softcap, const bool return_softmax, std::optional gen_) { +<<<<<<< HEAD #if defined(USE_ROCM_CK_SDPA) +======= +#if defined(USE_CK_FLASH_ATTENTION) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) { std::optional dummy_attn_bias = std::nullopt; @@ -406,10 +414,16 @@ inline std::tuple mha_bwd( const bool deterministic, const at::Tensor philox_seed, const at::Tensor philox_offset) { +<<<<<<< HEAD #if defined(USE_ROCM_CK_SDPA) if (at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) { +======= + if (at::globalContext().getROCmFAPreferredBackend() == + at::ROCmFABackend::Ck) { +#if defined(USE_CK_FLASH_ATTENTION) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::optional non_null_dbias = std::nullopt; const int non_null_window_left = window_size_left.value_or(-1); const int non_null_window_right = window_size_right.value_or(-1); @@ -440,8 +454,15 @@ inline std::tuple mha_bwd( philox_offset); // for FA return [dQ, dV, dK, dSoftmax] return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax)); +<<<<<<< HEAD } #endif +======= +#else + TORCH_WARN_ONCE("Warning! You have opted to use CK flash attention backend in a build that was not compiled using USE_CK_FLASH_ATTENTION=1. Please set this variable and try again. Defaulting to use aotriton backend..."); +#endif + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return mha_bwd_aot( dout, q, @@ -494,7 +515,11 @@ inline std::tuple mha_varlen_bwd const bool deterministic, const at::Tensor philox_seed, const at::Tensor philox_offset) { +<<<<<<< HEAD #if defined(USE_ROCM_CK_SDPA) +======= +#if defined(USE_CK_FLASH_ATTENTION) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) { std::optional non_null_dbias = std::nullopt; diff --git a/aten/src/ATen/native/transformers/sdp_utils.h b/aten/src/ATen/native/transformers/sdp_utils.h index 809abe50178ec..28aae7f26bc58 100644 --- a/aten/src/ATen/native/transformers/sdp_utils.h +++ b/aten/src/ATen/native/transformers/sdp_utils.h @@ -23,7 +23,11 @@ void alloc_with_matching_layout( const auto q_strides = q.strides(); std::stable_sort( fill_order.begin(), fill_order.end(), [&q_strides](int idx1, int idx2) { +<<<<<<< HEAD return q_strides[idx1] ? q_strides[idx1] : 1 < q_strides[idx2] ? q_strides[idx2] : 1; +======= + return q_strides[idx1] < q_strides[idx2]; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }); std::vector ordered_strides(shape.size()); int64_t current_stride = 1; diff --git a/aten/src/ATen/native/transformers/sdp_utils_cpp.cpp b/aten/src/ATen/native/transformers/sdp_utils_cpp.cpp index 931b66cbef9ad..63dfd6a8e20bc 100644 --- a/aten/src/ATen/native/transformers/sdp_utils_cpp.cpp +++ b/aten/src/ATen/native/transformers/sdp_utils_cpp.cpp @@ -43,7 +43,11 @@ bool use_flash_attention_cpp(sdp_params const& params, bool debug) { check_nested_tensor, check_for_dropout, check_tensor_shapes, +<<<<<<< HEAD check_batch_size_and_num_heads_dense, +======= + check_batch_size_and_num_heads_dense, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) check_attn_mask_shape, check_head_dim_size_cpp, check_nonzero_sequence_lengths_dense, diff --git a/aten/src/ATen/native/utils/ParamsHash.h b/aten/src/ATen/native/utils/ParamsHash.h index 4c9d97328ad61..dccc77222e1dd 100644 --- a/aten/src/ATen/native/utils/ParamsHash.h +++ b/aten/src/ATen/native/utils/ParamsHash.h @@ -41,7 +41,11 @@ struct ParamsEqual { }; // Provide explicit byte-for-byte constructors to avoid uwittingly leaving +<<<<<<< HEAD // padding bytes uninitialized (e.g., when passing Params by value) +======= +// padding bytes unitialized (e.g., when passing Params by value) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template struct ParamsWrapper { T pod; diff --git a/aten/src/ATen/native/vulkan/VulkanOpaqueTensorImpl.h b/aten/src/ATen/native/vulkan/VulkanOpaqueTensorImpl.h index 532caa62687a8..ad2dce583c88d 100644 --- a/aten/src/ATen/native/vulkan/VulkanOpaqueTensorImpl.h +++ b/aten/src/ATen/native/vulkan/VulkanOpaqueTensorImpl.h @@ -33,8 +33,12 @@ struct VulkanOpaqueTensorImpl : public OpaqueTensorImpl { return c10::fromIntArrayRefKnownNonNegative(strides_); } +<<<<<<< HEAD c10::SymBool sym_is_contiguous_custom( c10::MemoryFormat memory_format) const override { +======= + bool is_contiguous_custom(c10::MemoryFormat memory_format) const override { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) (void)memory_format; return true; } diff --git a/aten/src/ATen/native/vulkan/api/Types.h b/aten/src/ATen/native/vulkan/api/Types.h index 1202a3bd73938..c89cd57be4ea0 100644 --- a/aten/src/ATen/native/vulkan/api/Types.h +++ b/aten/src/ATen/native/vulkan/api/Types.h @@ -71,7 +71,11 @@ inline VkFormat to_vkformat(const ScalarType t) { /* * Given a `VkFormat`, return the `ScalarType` that best represents the data +<<<<<<< HEAD * type of individual elements in an image texture of the `VkFormat`. Note that +======= + * type of invidivual elements in an image texture of the `VkFormat`. Note that +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) * this mapping is different from the `to_vkformat()` function, since different * `ScalarType`s may use the same `VkFormat`. */ diff --git a/aten/src/ATen/native/vulkan/glsl/conv2d.glsl b/aten/src/ATen/native/vulkan/glsl/conv2d.glsl index 47a2630aaafbe..ab3d1e147533e 100644 --- a/aten/src/ATen/native/vulkan/glsl/conv2d.glsl +++ b/aten/src/ATen/native/vulkan/glsl/conv2d.glsl @@ -75,7 +75,11 @@ void main() { // During prepacking, the weight tensor was rearranged in order to optimize // for data access linearity in this shader. Therefore we need to adjust the // canonical coordinates to the corresponding index in the rearranged weight +<<<<<<< HEAD // tensor. the x coordinate is multiplied by 4 since each group of 4 channels +======= + // tensor. the x coordinate is multipled by 4 since each group of 4 channels +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // is folded into the X axis. The y coordinate is offset based on the z // coordinate because the 2D planes were stacked atop each other vertically. kstart.x *= 4; diff --git a/aten/src/ATen/native/vulkan/glsl/conv2d_pw.glsl b/aten/src/ATen/native/vulkan/glsl/conv2d_pw.glsl index d4188d6580599..20413e345e9c9 100644 --- a/aten/src/ATen/native/vulkan/glsl/conv2d_pw.glsl +++ b/aten/src/ATen/native/vulkan/glsl/conv2d_pw.glsl @@ -39,7 +39,11 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; /* * Computes a 2D pointwise convolution of a 2x2 output tile. Calculating an * output tile for pointwise convolution is more efficient because the kernel +<<<<<<< HEAD * size is only 1x1, making it much easier to reuse loaded texels from uKernel. +======= + * size is only 1x1, making it much easier to re-use loaded texels from uKernel. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) */ void main() { const ivec3 gpos = ivec3(gl_GlobalInvocationID); diff --git a/aten/src/ATen/native/vulkan/glsl/image_to_nchw_uint.glsl b/aten/src/ATen/native/vulkan/glsl/image_to_nchw_uint.glsl index 1f66a5fe19151..4707dd4b539e0 100644 --- a/aten/src/ATen/native/vulkan/glsl/image_to_nchw_uint.glsl +++ b/aten/src/ATen/native/vulkan/glsl/image_to_nchw_uint.glsl @@ -57,7 +57,11 @@ void main() { // out CxHxW plane. ivec4 c_index = pos_in_batch / uBlock.in_extents.w; +<<<<<<< HEAD // we divide pos_in_batch by HxW, to compute the channel index +======= + // we devide pos_in_batch by HxW, to compute the channel index +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ivec4 pos_in_hw = pos_in_batch % uBlock.in_extents.w; // we compute the reminder mod HxW, to find the positions in the flatten diff --git a/aten/src/ATen/native/vulkan/glsl/indexing.h b/aten/src/ATen/native/vulkan/glsl/indexing.h index c34ce25001ef5..5a3e530e2bc74 100644 --- a/aten/src/ATen/native/vulkan/glsl/indexing.h +++ b/aten/src/ATen/native/vulkan/glsl/indexing.h @@ -1,12 +1,20 @@ /* +<<<<<<< HEAD * Computes a 4D tensor coordinate from a linearized index +======= + * Computes a 4D tensor co-ordinate from a linearized index +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) */ uvec4 idx_to_coord(const uint idx, const uvec4 strides, const uvec4 sizes) { return ivec4(mod(idx / strides, sizes)); } /* +<<<<<<< HEAD * Computes a linearized index from a 4D tensor coordinate +======= + * Computes a linearized index from a 4D tensor co-ordinate +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) */ uint coord_to_idx(const uvec4 coord, const uvec4 strides) { return int(dot(coord * strides, ivec4(1))); diff --git a/aten/src/ATen/native/vulkan/glsl/quantized_conv2d.glsl b/aten/src/ATen/native/vulkan/glsl/quantized_conv2d.glsl index bc13655d01e07..829b70aa7d7f5 100644 --- a/aten/src/ATen/native/vulkan/glsl/quantized_conv2d.glsl +++ b/aten/src/ATen/native/vulkan/glsl/quantized_conv2d.glsl @@ -96,7 +96,11 @@ void main() { // During prepacking, the weight tensor was rearranged in order to optimize // for data access linearity in this shader. Therefore we need to adjust the // canonical coordinates to the corresponding index in the rearranged weight +<<<<<<< HEAD // tensor. the x coordinate is multiplied by 4 since each group of 4 channels +======= + // tensor. the x coordinate is multipled by 4 since each group of 4 channels +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // is folded into the X axis. The y coordinate is offset based on the z // coordinate because the 2D planes were stacked atop each other vertically. kstart.x *= 4; diff --git a/aten/src/ATen/native/vulkan/ops/Tile.cpp b/aten/src/ATen/native/vulkan/ops/Tile.cpp index d39fd951106c6..94975bc7449c0 100644 --- a/aten/src/ATen/native/vulkan/ops/Tile.cpp +++ b/aten/src/ATen/native/vulkan/ops/Tile.cpp @@ -18,7 +18,11 @@ namespace { using namespace api::utils; Tensor tile(const Tensor& self, const IntArrayRef repeats) { +<<<<<<< HEAD // If self.size() > len(reps), reps is promoted to self.size() by prepending +======= + // If self.size() > len(reps), reps is promoted to self.size() by pre-pending +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // 1’s to it to keep the same behaviour as `numpy.tile`. // Thus for a tensor of shape (2, 3, 4, 5), a dims of (2, 2) is treated // as (1, 1, 2, 2). diff --git a/aten/src/ATen/nnapi/nnapi_bind.cpp b/aten/src/ATen/nnapi/nnapi_bind.cpp index 8f40ee4045681..ecdb727186cfe 100644 --- a/aten/src/ATen/nnapi/nnapi_bind.cpp +++ b/aten/src/ATen/nnapi/nnapi_bind.cpp @@ -26,7 +26,11 @@ static void load_platform_library() { (void)run_once; } +<<<<<<< HEAD // NnapiCompilation function definitions: +======= +// NnapiCompilation functon definitions: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Could possibly call load_platform_library in constructor, but error reporting // can be complicated if the constructor is called during model loading. diff --git a/aten/src/ATen/ops/from_blob.h b/aten/src/ATen/ops/from_blob.h index a209380abb64e..551addffcc0c7 100644 --- a/aten/src/ATen/ops/from_blob.h +++ b/aten/src/ATen/ops/from_blob.h @@ -5,7 +5,11 @@ namespace at { namespace detail { +<<<<<<< HEAD inline void noopDelete(void*) {} +======= +TORCH_API inline void noopDelete(void*) {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace detail diff --git a/aten/src/ATen/record_function.cpp b/aten/src/ATen/record_function.cpp index 94a2bf56f8d7b..31f8afc93ffb9 100644 --- a/aten/src/ATen/record_function.cpp +++ b/aten/src/ATen/record_function.cpp @@ -724,6 +724,7 @@ uint64_t RecordFunction::currentThreadId() { return current_thread_id_; } +<<<<<<< HEAD void RecordFunction::before(RecordFunction::FunctionDescriptor fn, int64_t sequence_nr) { std::visit([this](auto&& fn) { if constexpr (std::is_same_v, std::string_view>) { @@ -735,6 +736,38 @@ void RecordFunction::before(RecordFunction::FunctionDescriptor fn, int64_t seque } }, fn); sequence_nr_ = sequence_nr; +======= +void RecordFunction::before(const char* name, int64_t sequence_nr) { + fn_ = name; + sequence_nr_ = sequence_nr; + is_nccl_meta_ = (std::strcmp(name, kParamCommsCallName.c_str()) == 0); + +#ifndef NDEBUG + inputs_valid_ = true; +#endif + runStartCallbacks(); + invalidateInputs(); +} + +void RecordFunction::before(std::string name, int64_t sequence_nr) { + is_nccl_meta_ = (name == kParamCommsCallName); + fn_ = std::move(name); + sequence_nr_ = sequence_nr; + +#ifndef NDEBUG + inputs_valid_ = true; +#endif + runStartCallbacks(); + invalidateInputs(); +} + +void RecordFunction::before( + RecordFunction::schema_ref_t schema, + int64_t sequence_nr) { + sequence_nr_ = sequence_nr; + fn_ = schema; + is_nccl_meta_ = (schema.get().name() == kParamCommsCallName); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #ifndef NDEBUG inputs_valid_ = true; diff --git a/aten/src/ATen/record_function.h b/aten/src/ATen/record_function.h index 8ec70a1682f37..69cf67b90f7b9 100644 --- a/aten/src/ATen/record_function.h +++ b/aten/src/ATen/record_function.h @@ -9,7 +9,10 @@ #include #include #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include namespace c10 { @@ -288,11 +291,17 @@ struct TORCH_API RecordFunction { explicit RecordFunction(RecordScope scope = RecordScope::FUNCTION); explicit RecordFunction(StepCallbacks&& step_callbacks); +<<<<<<< HEAD using schema_ref_t = std::reference_wrapper; using FunctionDescriptor = std::variant; void before( FunctionDescriptor fn, +======= + template + void before( + F fn, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) c10::ArrayRef args, int64_t current_sequence_nr = -1) { if (!isActive()) { @@ -302,8 +311,14 @@ struct TORCH_API RecordFunction { before(fn, current_sequence_nr); } +<<<<<<< HEAD void before( FunctionDescriptor fn, +======= + template + void before( + F fn, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) c10::ArrayRef args, const std::unordered_map* kwargs, int64_t current_sequence_nr = -1) { @@ -311,11 +326,20 @@ struct TORCH_API RecordFunction { return; } kwinputs_ = *kwargs; +<<<<<<< HEAD before(fn, args, current_sequence_nr); } void before( FunctionDescriptor fn, +======= + before(std::move(fn), args, current_sequence_nr); + } + + template + void before( + F fn, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const std::unordered_map* kwargs, int64_t current_sequence_nr = -1) { if (!isActive()) { @@ -325,18 +349,34 @@ struct TORCH_API RecordFunction { before(fn, current_sequence_nr); } +<<<<<<< HEAD void before( FunctionDescriptor fn, const std::vector* args, int64_t current_sequence_nr = -1) { before( fn, +======= + template + void before( + F fn, + const std::vector* args, + int64_t current_sequence_nr = -1) { + before( + std::move(fn), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) c10::ArrayRef(args->data(), args->size()), current_sequence_nr); } +<<<<<<< HEAD void before( FunctionDescriptor fn, +======= + template + void before( + F fn, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const std::vector* args, const std::unordered_map* kwargs, int64_t current_sequence_nr = -1) { @@ -425,7 +465,14 @@ struct TORCH_API RecordFunction { // before functions initialize RecordFunction members and call // start callbacks +<<<<<<< HEAD void before(FunctionDescriptor schema, int64_t sequence_nr = -1); +======= + using schema_ref_t = std::reference_wrapper; + void before(const char* name, int64_t sequence_nr = -1); + void before(std::string name, int64_t sequence_nr = -1); + void before(schema_ref_t schema, int64_t sequence_nr = -1); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Sets node ID for distributed profiling static void setDefaultNodeId(int64_t defaultNodeId); @@ -549,10 +596,17 @@ TORCH_API std::optional getStepCallbacksUnlessEmpty( RecordScope scope); namespace detail { +<<<<<<< HEAD template void record_function_with_scope( RecordFunction& guard, RecordFunction::FunctionDescriptor fn, +======= +template +void record_function_with_scope( + RecordFunction& guard, + F fn, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const Inputs& inputs, Args&&... args) { if (guard.needsInputs()) { @@ -565,10 +619,17 @@ void record_function_with_scope( } } +<<<<<<< HEAD template void record_function_with_scope_and_debug_handle( RecordFunction& guard, RecordFunction::FunctionDescriptor fn, +======= +template +void record_function_with_scope_and_debug_handle( + RecordFunction& guard, + F fn, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int64_t debug_handle, const Inputs& inputs, Args&&... args) { @@ -583,6 +644,7 @@ void record_function_with_scope_and_debug_handle( } } +<<<<<<< HEAD template void record_function_with_scope( RecordFunction& guard, @@ -597,12 +659,36 @@ template void record_function_with_scope_and_debug_handle( RecordFunction& guard, RecordFunction::FunctionDescriptor fn, +======= +template +void record_function_with_scope( + RecordFunction& guard, + F fn, + c10::ArrayRef inputs, + Args&&... args) { + return record_function_with_scope< + c10::ArrayRef, + F, + Args...>(guard, std::move(fn), inputs, std::forward(args)...); +} + +template +void record_function_with_scope_and_debug_handle( + RecordFunction& guard, + F fn, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int64_t debug_handle, c10::ArrayRef inputs, Args&&... args) { return record_function_with_scope_and_debug_handle< c10::ArrayRef, +<<<<<<< HEAD Args...>(guard, fn, debug_handle, inputs, std::forward(args)...); +======= + F, + Args...>( + guard, std::move(fn), debug_handle, inputs, std::forward(args)...); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } // namespace detail @@ -666,7 +752,11 @@ void record_function_with_scope_and_debug_handle( guard, fn, debug_handle, inputs, ##__VA_ARGS__); \ } +<<<<<<< HEAD // Helper macros to record LITE INTERPRETER scope events with debug handles +======= +// Helper macros to record LITE INTERPETER scope events with debug handles +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #define RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS( \ fn, debug_handle, inputs) \ RECORD_WITH_SCOPE_DEBUG_HANDLE_AND_INPUTS( \ diff --git a/aten/src/ATen/templates/FunctionalInverses.h b/aten/src/ATen/templates/FunctionalInverses.h index b15cd09a6c65d..72b34ed5794ea 100644 --- a/aten/src/ATen/templates/FunctionalInverses.h +++ b/aten/src/ATen/templates/FunctionalInverses.h @@ -2,12 +2,29 @@ // ${generated_comment} +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include namespace at { namespace functionalization { +<<<<<<< HEAD +======= +enum class InverseReturnMode { + /// Specifies that functional inverses should always return a view. + AlwaysView, + /// Specifies that functional inverses should always return a non-view / copy. + NeverView, + /// Specifies that functional inverses should return a view unless a (copying) scatter + /// inverse exists, in which case that will be used instead. + /// This avoids as_strided() calls that can be difficult for subclasses to handle. + ViewOrScatterInverse, +}; + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) struct FunctionalInverses { ${view_inverse_declarations} diff --git a/aten/src/ATen/templates/Functions.cpp b/aten/src/ATen/templates/Functions.cpp index f210402e543aa..ba67c27fd387a 100644 --- a/aten/src/ATen/templates/Functions.cpp +++ b/aten/src/ATen/templates/Functions.cpp @@ -64,7 +64,11 @@ Tensor TensorMaker::make_tensor() { if (strides_) { auto storage_size = detail::computeStorageNbytes(sizes_, *strides_, itemsize); if (storage_offset_) { +<<<<<<< HEAD storage_size += storage_offset_.value() * itemsize; +======= + storage_size += storage_offset_.value(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } return storage_size; } @@ -75,7 +79,11 @@ Tensor TensorMaker::make_tensor() { } auto storage_size = size * itemsize; if (storage_offset_) { +<<<<<<< HEAD storage_size += storage_offset_.value() * itemsize; +======= + storage_size += storage_offset_.value(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } return storage_size; } diff --git a/aten/src/ATen/templates/Functions.h b/aten/src/ATen/templates/Functions.h index b1feaf9d4daa9..f75421a64ccd0 100644 --- a/aten/src/ATen/templates/Functions.h +++ b/aten/src/ATen/templates/Functions.h @@ -83,6 +83,7 @@ namespace at { // Special C++ only overloads for std()-like functions (See gh-40287) // These are needed because int -> bool conversion takes precedence over int -> IntArrayRef // So, for example std(0) would select the std(unbiased=False) overload +<<<<<<< HEAD inline Tensor var(const Tensor& self, int dim) { return at::var(self, IntArrayRef{dim}); } @@ -93,6 +94,18 @@ inline Tensor std(const Tensor& self, int dim) { return at::std(self, IntArrayRef{dim}); } inline std::tuple std_mean(const Tensor& self, int dim) { +======= +TORCH_API inline Tensor var(const Tensor& self, int dim) { + return at::var(self, IntArrayRef{dim}); +} +TORCH_API inline std::tuple var_mean(const Tensor& self, int dim) { + return at::var_mean(self, IntArrayRef{dim}); +} +TORCH_API inline Tensor std(const Tensor& self, int dim) { + return at::std(self, IntArrayRef{dim}); +} +TORCH_API inline std::tuple std_mean(const Tensor& self, int dim) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return at::std_mean(self, IntArrayRef{dim}); } diff --git a/aten/src/ATen/templates/RegisterDispatchKey.cpp b/aten/src/ATen/templates/RegisterDispatchKey.cpp index 39c85b00d7a1b..7a4d488c1382d 100644 --- a/aten/src/ATen/templates/RegisterDispatchKey.cpp +++ b/aten/src/ATen/templates/RegisterDispatchKey.cpp @@ -5,11 +5,21 @@ // NOTE: This condition is true for all PyTorch internal libraries, it // just excludes external projects such as torch_xla which +<<<<<<< HEAD // reuse some of the PyTorch codegen machinery. #if defined(CAFFE2_BUILD_MAIN_LIB) || \ defined(TORCH_CUDA_BUILD_MAIN_LIB) || \ defined(TORCH_HIP_BUILD_MAIN_LIB) || \ defined(TORCH_XPU_BUILD_MAIN_LIB) +======= +// re-use some of the PyTorch codegen machinery. +#if defined(CAFFE2_BUILD_MAIN_LIB) || \ + defined(TORCH_CUDA_BUILD_MAIN_LIB) || \ + defined(TORCH_HIP_BUILD_MAIN_LIB) || \ + defined(TORCH_XPU_BUILD_MAIN_LIB) || \ + defined(TORCH_CUDA_CU_BUILD_MAIN_LIB) || \ + defined(TORCH_CUDA_CPP_BUILD_MAIN_LIB) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #endif diff --git a/aten/src/ATen/templates/RegisterFunctionalization.cpp b/aten/src/ATen/templates/RegisterFunctionalization.cpp index 408aff0cdab40..36a534b9fb1c7 100644 --- a/aten/src/ATen/templates/RegisterFunctionalization.cpp +++ b/aten/src/ATen/templates/RegisterFunctionalization.cpp @@ -4,7 +4,11 @@ #include #include #include +<<<<<<< HEAD #include +======= +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include diff --git a/aten/src/ATen/templates/TensorBody.h b/aten/src/ATen/templates/TensorBody.h index 8ae2dee1ce50c..b5a6ed225258b 100644 --- a/aten/src/ATen/templates/TensorBody.h +++ b/aten/src/ATen/templates/TensorBody.h @@ -491,7 +491,11 @@ class TORCH_API Tensor: public TensorBase { "attribute won't be populated during autograd.backward(). If you indeed want the .grad " "field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. " "If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor " +<<<<<<< HEAD "instead. See github.com/pytorch/pytorch/pull/30531 for more information."); +======= + "instead. See github.com/pytorch/pytorch/pull/30531 for more informations."); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } return maybe_grad; } diff --git a/aten/src/ATen/test/basic.cpp b/aten/src/ATen/test/basic.cpp index 0937de4552821..7219b84c5dfd2 100644 --- a/aten/src/ATen/test/basic.cpp +++ b/aten/src/ATen/test/basic.cpp @@ -517,6 +517,7 @@ TEST(BasicTest, BasicStdTestCPU) { t3.join(); t4.join(); } +<<<<<<< HEAD TEST(BasicTest, TestForBlobResizeCPU) { // Checks that for_blob can correctly create tensors with non-empty offset and resize them @@ -535,3 +536,5 @@ TEST(BasicTest, TestForBlobStridesResizeCPU) { auto te = *at::expand_size(t, {3, 3}); ASSERT_EQ(te[1][1].item(), 5); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/test/cpu_caching_allocator_test.cpp b/aten/src/ATen/test/cpu_caching_allocator_test.cpp index fa650df221fa0..fe9bec05e871d 100644 --- a/aten/src/ATen/test/cpu_caching_allocator_test.cpp +++ b/aten/src/ATen/test/cpu_caching_allocator_test.cpp @@ -5,9 +5,12 @@ #include +<<<<<<< HEAD // At the moment caching allocator is only exposed to mobile cpu allocator. #ifdef C10_MOBILE +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TEST(CPUCachingAllocatorTest, check_alloc_free) { c10::CPUCachingAllocator caching_allocator; c10::WithCPUCachingAllocatorGuard cachine_allocator_guard( @@ -44,9 +47,19 @@ TEST(CPUCachingAllocatorTest, check_alloc_inside_free_outside) { } int main(int argc, char* argv[]) { +<<<<<<< HEAD ::testing::InitGoogleTest(&argc, argv); at::manual_seed(42); return RUN_ALL_TESTS(); } #endif /* C10_Mobile */ +======= +// At the moment caching allocator is only exposed to mobile cpu allocator. +#ifdef C10_MOBILE + ::testing::InitGoogleTest(&argc, argv); + at::manual_seed(42); + return RUN_ALL_TESTS(); +#endif /* C10_Mobile */ +} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/test/cpu_profiling_allocator_test.cpp b/aten/src/ATen/test/cpu_profiling_allocator_test.cpp index 15220e58e2485..d71ada29b0a78 100644 --- a/aten/src/ATen/test/cpu_profiling_allocator_test.cpp +++ b/aten/src/ATen/test/cpu_profiling_allocator_test.cpp @@ -199,7 +199,11 @@ int main(int argc, char* argv[]) { #ifdef C10_MOBILE // Need to disable mkldnn for this test since it allocated memory +<<<<<<< HEAD // via raw_allocate interface which requires context pointer and raw +======= + // via raw_allocate inteface which requires context pointer and raw +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // pointer to be the same. Tis is not true for mobile allocator. at::globalContext().setUserEnabledMkldnn(false); #endif diff --git a/aten/src/ATen/test/cuda_allocator_test.cpp b/aten/src/ATen/test/cuda_allocator_test.cpp index 27a352e7d5a26..c9d8ad6fa35f6 100644 --- a/aten/src/ATen/test/cuda_allocator_test.cpp +++ b/aten/src/ATen/test/cuda_allocator_test.cpp @@ -5,6 +5,57 @@ #include +<<<<<<< HEAD TEST(AllocatorTestCUDA, test_clone) { test_allocator_clone(c10::cuda::CUDACachingAllocator::get()); } +======= +#include + +TEST(AllocatorTestCUDA, test_clone) { + test_allocator_clone(c10::cuda::CUDACachingAllocator::get()); +} + +static int called_dummy_free_0 = 0; +static int called_dummy_free_1 = 0; + +void* dummy_alloc_0(size_t size, int device, void* stream) {return nullptr;} +void dummy_free_0(void* data, size_t size, int device, void* stream) { + called_dummy_free_0++; +} +void dummy_free_1(void* data, size_t size, int device, void* stream) { + called_dummy_free_1++; +} + +// Tests that data_ptrs have their respective deleters +// when mixing allocators +TEST(AllocatorTestCUDA, test_pluggable_allocator_deleters) { + // Create a tensor with dummy_allocator_0, where dummy_free_0 is the deleter + auto dummy_allocator_0 = torch::cuda::CUDAPluggableAllocator::createCustomAllocator(dummy_alloc_0, dummy_free_0); + c10::cuda::CUDACachingAllocator::allocator.store(dummy_allocator_0.get()); + at::Tensor a = at::empty({0}, at::TensorOptions().device(at::kCUDA)); + + // Create a tensor with dummy_allocator_1, where dummy_free_1 is the deleter + auto dummy_allocator_1 = torch::cuda::CUDAPluggableAllocator::createCustomAllocator(dummy_alloc_0, dummy_free_1); + c10::cuda::CUDACachingAllocator::allocator.store(dummy_allocator_1.get()); + at::Tensor b = at::empty({0}, at::TensorOptions().device(at::kCUDA)); + + // Manually use a's deleter + auto* ctx = a.storage().data_ptr().get_context(); + a.storage().data_ptr().get_deleter()(ctx); + a.storage().mutable_data_ptr().release_context(); + + // a's deleter is dummy_free_0 + // dummy_free_0 should be called above, so called_dummy_free_0 should be 1 + ASSERT_TRUE(called_dummy_free_0 == 1); + + // Manually use b's deleter + ctx = b.storage().data_ptr().get_context(); + b.storage().data_ptr().get_deleter()(ctx); + b.storage().mutable_data_ptr().release_context(); + + // b's deleter is dummy_free_1 + // dummy_free_1 should be called above, so called_dummy_free_1 should be 1 + ASSERT_TRUE(called_dummy_free_1 == 1); +} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/test/cuda_complex_test.cu b/aten/src/ATen/test/cuda_complex_test.cu index 5736f73330760..8558e74094b33 100644 --- a/aten/src/ATen/test/cuda_complex_test.cu +++ b/aten/src/ATen/test/cuda_complex_test.cu @@ -5,14 +5,24 @@ __global__ void test_thrust_kernel() { // thrust conversion { +<<<<<<< HEAD [[maybe_unused]] constexpr float num1 = float(1.23); [[maybe_unused]] constexpr float num2 = float(4.56); +======= + constexpr float num1 = float(1.23); + constexpr float num2 = float(4.56); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert(c10::complex(thrust::complex(num1, num2)).real() == num1); assert(c10::complex(thrust::complex(num1, num2)).imag() == num2); } { +<<<<<<< HEAD [[maybe_unused]] constexpr double num1 = double(1.23); [[maybe_unused]] constexpr double num2 = double(4.56); +======= + constexpr double num1 = double(1.23); + constexpr double num2 = double(4.56); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert(c10::complex(thrust::complex(num1, num2)).real() == num1); assert(c10::complex(thrust::complex(num1, num2)).imag() == num2); } @@ -46,11 +56,19 @@ __global__ void test_reinterpret_cast() { assert(zzzz.real() == double(1)); assert(zzzz.imag() == double(2)); +<<<<<<< HEAD [[maybe_unused]] cuComplex cuComplex_zz = *reinterpret_cast(&zz); assert(cuComplex_zz.x == float(1)); assert(cuComplex_zz.y == float(2)); [[maybe_unused]] cuDoubleComplex cuDoubleComplex_zzzz = *reinterpret_cast(&zzzz); +======= + cuComplex cuComplex_zz = *reinterpret_cast(&zz); + assert(cuComplex_zz.x == float(1)); + assert(cuComplex_zz.y == float(2)); + + cuDoubleComplex cuDoubleComplex_zzzz = *reinterpret_cast(&zzzz); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert(cuDoubleComplex_zzzz.x == double(1)); assert(cuDoubleComplex_zzzz.y == double(2)); } diff --git a/aten/src/ATen/test/cuda_cub_test.cu b/aten/src/ATen/test/cuda_cub_test.cu index 6865984102b4b..f22b3f59023f1 100644 --- a/aten/src/ATen/test/cuda_cub_test.cu +++ b/aten/src/ATen/test/cuda_cub_test.cu @@ -146,8 +146,13 @@ TEST(InclusiveScanSplit, CubTest) { cudaMallocManaged(&output1, sizeof(int) * 10); cudaDeviceSynchronize(); +<<<<<<< HEAD at::cuda::cub::inclusive_scan, /*max_cub_size=*/2>( input, output1, NO_ROCM(::cuda)::std::plus<>(), 10); +======= + at::cuda::cub::inclusive_scan( + input, output1, ::at_cuda_detail::cub::Sum(), 10); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cudaDeviceSynchronize(); ASSERT_EQ(output1[0], 1); @@ -172,8 +177,13 @@ TEST(ExclusiveScanSplit, CubTest) { cudaMallocManaged(&output2, sizeof(int) * 10); cudaDeviceSynchronize(); +<<<<<<< HEAD at::cuda::cub::exclusive_scan, int, /*max_cub_size=*/2>( input, output2, NO_ROCM(::cuda)::std::plus<>(), 0, 10); +======= + at::cuda::cub::exclusive_scan( + input, output2, ::at_cuda_detail::cub::Sum(), 0, 10); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cudaDeviceSynchronize(); ASSERT_EQ(output2[0], 0); diff --git a/aten/src/ATen/test/cuda_dlconvertor_test.cpp b/aten/src/ATen/test/cuda_dlconvertor_test.cpp index 34f8589391d5e..3ae94d53d6128 100644 --- a/aten/src/ATen/test/cuda_dlconvertor_test.cpp +++ b/aten/src/ATen/test/cuda_dlconvertor_test.cpp @@ -9,7 +9,10 @@ #include using namespace at; +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TEST(TestDlconvertor, TestDlconvertorCUDA) { manual_seed(123); @@ -51,6 +54,7 @@ TEST(TestDlconvertor, TestDlconvertorCUDAHIP) { ASSERT_TRUE(a.equal(b)); } +<<<<<<< HEAD TEST(TestDlconvertorVersioned, TestDlconvertorCUDA) { manual_seed(123); @@ -93,3 +97,5 @@ TEST(TestDlconvertorVersioned, TestDlconvertorCUDAHIP) { ASSERT_TRUE(a.equal(b)); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/test/cuda_half_test.cu b/aten/src/ATen/test/cuda_half_test.cu index 6f45acc30f9ea..4323dfe824846 100644 --- a/aten/src/ATen/test/cuda_half_test.cu +++ b/aten/src/ATen/test/cuda_half_test.cu @@ -33,7 +33,11 @@ __device__ void test(){ // use the std namespace, but just "::" so that the function // gets resolved from nvcc math_functions.hpp +<<<<<<< HEAD [[maybe_unused]] float threshold = 0.00001; +======= + float threshold = 0.00001; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert(::abs(::lgamma(Half(10.0)) - ::lgamma(10.0f)) <= threshold); assert(::abs(::exp(Half(1.0)) - ::exp(1.0f)) <= threshold); assert(::abs(::log(Half(1.0)) - ::log(1.0f)) <= threshold); diff --git a/aten/src/ATen/test/cuda_vectorized_test.cu b/aten/src/ATen/test/cuda_vectorized_test.cu index e4c18102526ac..a15bef3178a9f 100644 --- a/aten/src/ATen/test/cuda_vectorized_test.cu +++ b/aten/src/ATen/test/cuda_vectorized_test.cu @@ -10,6 +10,7 @@ using namespace at::native::memory; constexpr int buffer_size = 1024; +<<<<<<< HEAD #if defined(CUDA_VERSION) && CUDA_VERSION < 13000 __managed__ double4 buffer1[buffer_size]; __managed__ double4 buffer2[buffer_size]; @@ -17,6 +18,10 @@ __managed__ double4 buffer2[buffer_size]; __managed__ double4_16a buffer1[buffer_size]; __managed__ double4_16a buffer2[buffer_size]; #endif +======= +__managed__ double4 buffer1[buffer_size]; +__managed__ double4 buffer2[buffer_size]; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void reset_buffers() { for (int i = 0; i < buffer_size; i++) { diff --git a/aten/src/ATen/test/dlconvertor_test.cpp b/aten/src/ATen/test/dlconvertor_test.cpp index dca9126c7cde3..de882a7d82723 100644 --- a/aten/src/ATen/test/dlconvertor_test.cpp +++ b/aten/src/ATen/test/dlconvertor_test.cpp @@ -3,8 +3,17 @@ #include #include +<<<<<<< HEAD using namespace at; +======= +#include +// NOLINTNEXTLINE(modernize-deprecated-headers) +#include +#include + +using namespace at; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TEST(TestDlconvertor, TestDlconvertor) { manual_seed(123); @@ -27,6 +36,7 @@ TEST(TestDlconvertor, TestDlconvertorNoStrides) { ASSERT_TRUE(a.equal(b)); } +<<<<<<< HEAD TEST(TestDlconvertorUnversioned, TestDlconvertor) { manual_seed(123); @@ -50,3 +60,5 @@ TEST(TestDlconvertorUnversioned, TestDlconvertorNoStrides) { ASSERT_TRUE(a.equal(b)); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/test/half_test.cpp b/aten/src/ATen/test/half_test.cpp index 9e594196c6925..e53eeddbd2e23 100644 --- a/aten/src/ATen/test/half_test.cpp +++ b/aten/src/ATen/test/half_test.cpp @@ -25,7 +25,11 @@ TEST(TestHalf, Arithmetic) { ASSERT_EQ(one + one, 2); } +<<<<<<< HEAD TEST(TestHalf, Comparisons) { +======= +TEST(TestHalf, Comparisions) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Half zero = 0; Half one = 1; ASSERT_LT(zero, one); diff --git a/aten/src/ATen/test/thread_init_test.cpp b/aten/src/ATen/test/thread_init_test.cpp index 60dd52d1dffcb..d57188d3b0e69 100644 --- a/aten/src/ATen/test/thread_init_test.cpp +++ b/aten/src/ATen/test/thread_init_test.cpp @@ -1,8 +1,15 @@ +<<<<<<< HEAD #include #include #include #include +======= +#include +#include +#include +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include @@ -10,7 +17,11 @@ // numbers of threads set and also whether the scheduler // will throw an exception when multiple threads call // their first parallel construct. +<<<<<<< HEAD static void test(int given_num_threads) { +======= +void test(int given_num_threads) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto t = at::ones({1000 * 1000}, at::CPU(at::kFloat)); ASSERT_TRUE(given_num_threads >= 0); ASSERT_EQ(at::get_num_threads(), given_num_threads); @@ -20,7 +31,11 @@ static void test(int given_num_threads) { } } +<<<<<<< HEAD TEST(ThreadInitTest, ThreadInit) { +======= +int main() { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) at::init_num_threads(); at::set_num_threads(4); @@ -33,11 +48,20 @@ TEST(ThreadInitTest, ThreadInit) { #if !AT_PARALLEL_NATIVE at::set_num_threads(5); +<<<<<<< HEAD ASSERT_EQ(at::get_num_threads(), 5); +======= + ASSERT_TRUE(at::get_num_threads() == 5); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif // test inter-op settings at::set_num_interop_threads(5); ASSERT_EQ(at::get_num_interop_threads(), 5); ASSERT_ANY_THROW(at::set_num_interop_threads(6)); +<<<<<<< HEAD +======= + + return 0; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } diff --git a/aten/src/ATen/test/undefined_tensor_test.cpp b/aten/src/ATen/test/undefined_tensor_test.cpp index ec6997fae9b05..85d6cbeaa7c11 100644 --- a/aten/src/ATen/test/undefined_tensor_test.cpp +++ b/aten/src/ATen/test/undefined_tensor_test.cpp @@ -9,7 +9,11 @@ using namespace at; TEST(TestUndefined, UndefinedTest) { manual_seed(123); +<<<<<<< HEAD // mainly test ops on undefined tensors don't segfault and give a reasonable error message. +======= + // mainly test ops on undefined tensors don't segfault and give a reasonable errror message. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Tensor und; Tensor ft = ones({1}, CPU(kFloat)); diff --git a/aten/src/ATen/test/vec_test_all_types.cpp b/aten/src/ATen/test/vec_test_all_types.cpp index b7b756f74ba1f..37e35f01844d8 100644 --- a/aten/src/ATen/test/vec_test_all_types.cpp +++ b/aten/src/ATen/test/vec_test_all_types.cpp @@ -5,7 +5,11 @@ namespace { template class Memory : public ::testing::Test {}; template +<<<<<<< HEAD class Arithmetic : public ::testing::Test {}; +======= + class Arithmetics : public ::testing::Test {}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template class Comparison : public ::testing::Test {}; template @@ -61,8 +65,11 @@ namespace { template class QuantizationTests : public ::testing::Test {}; template +<<<<<<< HEAD class Quantization8BitTests : public ::testing::Test {}; template +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class Quantization8BitWithTailTests : public ::testing::Test {}; template class FunctionalTests : public ::testing::Test {}; @@ -81,7 +88,10 @@ namespace { using FloatTestedTypes = ::testing::Types; using ALLTestedTypes = ::testing::Types; using QuantTestedTypes = ::testing::Types; +<<<<<<< HEAD using Quantization8BitTestedTypes = ::testing::Types; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER) using Quantization8BitWithTailTestedTypes = ::testing::Types; @@ -92,7 +102,11 @@ namespace { using ComplexTypes = ::testing::Types; using ReducedFloatTestedTypes = ::testing::Types; TYPED_TEST_SUITE(Memory, ALLTestedTypes); +<<<<<<< HEAD TYPED_TEST_SUITE(Arithmetic, FloatIntTestedTypes); +======= + TYPED_TEST_SUITE(Arithmetics, FloatIntTestedTypes); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TYPED_TEST_SUITE(Comparison, RealFloatIntReducedFloatTestedTypes); TYPED_TEST_SUITE(Bitwise, FloatIntTestedTypes); TYPED_TEST_SUITE(MinMax, RealFloatIntTestedTypes); @@ -119,7 +133,10 @@ namespace { TYPED_TEST_SUITE(BitwiseFloatsAdditional, RealFloatReducedFloatTestedTypes); TYPED_TEST_SUITE(BitwiseFloatsAdditional2, FloatTestedTypes); TYPED_TEST_SUITE(QuantizationTests, QuantTestedTypes); +<<<<<<< HEAD TYPED_TEST_SUITE(Quantization8BitTests, Quantization8BitTestedTypes); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TYPED_TEST_SUITE(InfiniteTests, RealFloatTestedTypes); #if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER) TYPED_TEST_SUITE( @@ -691,7 +708,11 @@ namespace { AssertVectorized(NAME_INFO(DeInterleave FirstHalf), std::get<0>(cc), vec::loadu(vals)).check(true); AssertVectorized(NAME_INFO(DeInterleave SecondHalf), std::get<1>(cc), vec::loadu(vals + vec::size())).check(true); } +<<<<<<< HEAD TYPED_TEST(Arithmetic, Plus) { +======= + TYPED_TEST(Arithmetics, Plus) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) using vec = TypeParam; using VT = ValueType; test_binary( @@ -703,7 +724,11 @@ namespace { createDefaultBinaryTestCase(TestSeed()), RESOLVE_OVERLOAD(filter_add_overflow)); } +<<<<<<< HEAD TYPED_TEST(Arithmetic, Minus) { +======= + TYPED_TEST(Arithmetics, Minus) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) using vec = TypeParam; using VT = ValueType; test_binary( @@ -715,7 +740,11 @@ namespace { createDefaultBinaryTestCase(TestSeed()), RESOLVE_OVERLOAD(filter_sub_overflow)); } +<<<<<<< HEAD TYPED_TEST(Arithmetic, Multiplication) { +======= + TYPED_TEST(Arithmetics, Multiplication) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) using vec = TypeParam; test_binary( NAME_INFO(mult), @@ -724,7 +753,11 @@ namespace { createDefaultBinaryTestCase(TestSeed(), false, true), RESOLVE_OVERLOAD(filter_mult_overflow)); } +<<<<<<< HEAD TYPED_TEST(Arithmetic, Division) { +======= + TYPED_TEST(Arithmetics, Division) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) using vec = TypeParam; TestSeed seed; test_binary( @@ -1500,6 +1533,7 @@ namespace { }, test_case); } +<<<<<<< HEAD #ifndef _WIN32 TYPED_TEST(Quantization8BitTests, Transpose) { using VT = ValueType; @@ -1562,6 +1596,8 @@ namespace { } } #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TYPED_TEST(FunctionalTests, Map) { using vec = TypeParam; using VT = ValueType; diff --git a/aten/src/ATen/test/vec_test_all_types.h b/aten/src/ATen/test/vec_test_all_types.h index f7206cc340973..29a3287a88b13 100644 --- a/aten/src/ATen/test/vec_test_all_types.h +++ b/aten/src/ATen/test/vec_test_all_types.h @@ -1,7 +1,12 @@ #pragma once +<<<<<<< HEAD #include #include #include +======= +#include +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include @@ -22,9 +27,13 @@ #else #define CACHE_LINE 32 #endif +<<<<<<< HEAD #ifndef _WIN32 #include #endif +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #if defined(__GNUC__) #define CACHE_ALIGN __attribute__((aligned(CACHE_LINE))) #define not_inline __attribute__((noinline)) @@ -531,7 +540,11 @@ template std::enable_if_t::value, void> filter_div_ub(T& val1, T& val2) { //missing +<<<<<<< HEAD //at least consider zero division +======= + //at least consdier zero division +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto ret = std::abs(val2); if (ret == 0) { val2 = T(1, 2); @@ -1291,7 +1304,11 @@ std::enable_if_t>::value, Complex> local_multiply(Compl T y_real = y.real(); T y_imag = y.imag(); #if defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_ZVECTOR) +<<<<<<< HEAD //check multiplication considering swap and fma +======= + //check multiplication considerin swap and fma +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T rr = x_real * y_real; T ii = x_imag * y_real; T neg_imag = -y_imag; @@ -1362,7 +1379,11 @@ std::enable_if_t>::value, Complex> local_division(Compl return Complex(rr, ii); #else /* defined(CPU_CAPABILITY_ZVECTOR) */ #if defined(CPU_CAPABILITY_VSX) +<<<<<<< HEAD //check multiplication considering swap and fma +======= + //check multiplication considerin swap and fma +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T rr = x_real * y_real; T ii = x_imag * y_real; T neg_imag = -y_imag; diff --git a/aten/src/ATen/test/verify_api_visibility.cpp b/aten/src/ATen/test/verify_api_visibility.cpp index c6d2fcc6fb865..c5bf24fd6fd35 100644 --- a/aten/src/ATen/test/verify_api_visibility.cpp +++ b/aten/src/ATen/test/verify_api_visibility.cpp @@ -20,8 +20,12 @@ #error "CAFFE2_STATIC_LINK_CUDA should not be visible in public headers" #endif +<<<<<<< HEAD #include TEST(VerifyApiVisibility, Test) { ASSERT_EQ(1, 1); } +======= +auto main() -> int {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/aten/src/ATen/test/vulkan_api_test.cpp b/aten/src/ATen/test/vulkan_api_test.cpp index 263918af2662c..2d5a514ec3477 100644 --- a/aten/src/ATen/test/vulkan_api_test.cpp +++ b/aten/src/ATen/test/vulkan_api_test.cpp @@ -1232,7 +1232,11 @@ void test_matmul( } TEST_F(VulkanAPITest, DISABLED_matmul_3d_weight_vulkan) { +<<<<<<< HEAD // This will call at::bmm. Will crash for unknown reason. +======= + // This will call at::bmm. Will crash for unknow reason. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const auto m1_cpu = at::rand({13, 23, 45}, at::device(at::kCPU).dtype(at::kFloat)); const auto m2_cpu = @@ -1241,7 +1245,11 @@ TEST_F(VulkanAPITest, DISABLED_matmul_3d_weight_vulkan) { } TEST_F(VulkanAPITest, DISABLED_matmul_3d_weight_cpu) { +<<<<<<< HEAD // This will call at::bmm. Will crash for unknown reason. +======= + // This will call at::bmm. Will crash for unknow reason. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const auto m1_cpu = at::rand({13, 23, 45}, at::device(at::kCPU).dtype(at::kFloat)); const auto m2_cpu = @@ -2004,7 +2012,11 @@ TEST_F(VulkanAPITest, conv2d_pw_prepack_bc_medium) { 1); // groups } +<<<<<<< HEAD // The following 2 tests failed on Meta's CI when all tests are executed. Output +======= +// The followin 2 tests failed on Meta's CI when all tests are executed. Output +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // has lots of nan. Cause unknown. // When this test is run alone (with gtest_filter), it passes. // The test also passes with smaller planes, see "conv2d_pw_prepack_medium". @@ -5664,7 +5676,11 @@ TEST_F(VulkanAPITest, var_2d_unbiased) { test_var({3, 5}, {1}, true, true); test_var({3, 5}, {1}, true, false); +<<<<<<< HEAD // input.dim() == dim_list.size(), only keepdim == true is supported +======= + // inpu.dim() == dim_list.size(), only keepdim == true is supported +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) test_var({3, 5}, {0, 1}, true, true); } @@ -5672,7 +5688,11 @@ TEST_F(VulkanAPITest, var_2d_biased) { test_var({3, 5}, {1}, false, true); test_var({3, 5}, {1}, false, false); +<<<<<<< HEAD // input.dim() == dim_list.size(), only keepdim == true is supported +======= + // inpu.dim() == dim_list.size(), only keepdim == true is supported +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) test_var({3, 5}, {0, 1}, false, true); } @@ -7142,12 +7162,20 @@ TEST_F(VulkanAPITest, clone_success) { } TEST_F(VulkanAPITest, clone_invalidinputs_exceptions) { +<<<<<<< HEAD // Act: Vulkan supports Preserve and Contiguous memory formats +======= + // Act: Vulkan supports Preserve and Contiguous memory foramts +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) EXPECT_THROW({ clone_test({2, 3, 5, 161}, c10::MemoryFormat::ChannelsLast); }, ::std::exception); +<<<<<<< HEAD // Act: Vulkan supports Preserve and Contiguous memory formats +======= + // Act: Vulkan supports Preserve and Contiguous memory foramts +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) EXPECT_THROW({ clone_test({2, 3, 5, 161}, c10::MemoryFormat::ChannelsLast3d); }, ::std::exception); diff --git a/aten/src/ATen/test/vulkan_quantized_api_test.cpp b/aten/src/ATen/test/vulkan_quantized_api_test.cpp index 2829aed94def9..7dfc6e76f4009 100644 --- a/aten/src/ATen/test/vulkan_quantized_api_test.cpp +++ b/aten/src/ATen/test/vulkan_quantized_api_test.cpp @@ -2116,7 +2116,11 @@ std::tuple produce_inputs_for_binary_op( input2_cpu = produce_random_tensor(input2_shape); if (compute_quantization_params) { +<<<<<<< HEAD // compute appropriate scale and zero point for inputs +======= + // compute appropiate scale and zero point for inputs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const auto in1_quant_params = compute_quant_params(input1_cpu); in1_scale = std::get<0>(in1_quant_params); in1_zero_point = std::get<1>(in1_quant_params); @@ -2287,7 +2291,11 @@ void test_quantized_binary_op( apply_cpu_quantized_binary_op(op_name, input1_cpu_deq, input2_cpu_deq); if (compute_quantization_params || random_quantization_params) { +<<<<<<< HEAD // compute appropriate scale and zero point for output +======= + // compute appropiate scale and zero point for output +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const auto out_quant_params = compute_quant_params(output_cpu); out_scale = std::get<0>(out_quant_params); out_zero_point = std::get<1>(out_quant_params); @@ -2540,7 +2548,11 @@ void test_quantized_conv2d( bias_cpu = produce_random_tensor(bias_shape, 1.26, 5.97, 0.59); if (compute_quantization_params) { +<<<<<<< HEAD // compute appropriate scale and zero point for input, weight and bias +======= + // compute appropiate scale and zero point for input, weight and bias +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const auto in_quant_params = compute_quant_params(input_cpu, in_dtype); in_scale = std::get<0>(in_quant_params); in_zero_point = std::get<1>(in_quant_params); @@ -2624,7 +2636,11 @@ void test_quantized_conv2d( groups); if (compute_quantization_params || random_quantization_params) { +<<<<<<< HEAD // compute appropriate scale and zero point for output +======= + // compute appropiate scale and zero point for output +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const auto out_quant_params = compute_quant_params(output_cpu, out_dtype); out_scale = std::get<0>(out_quant_params); out_zero_point = std::get<1>(out_quant_params); @@ -3524,7 +3540,11 @@ TEST_F(VulkanAPITest, linear_4d_large) { test_quantized_linear({9, 13, 11, 17}, {23, 17}, {23}); } +<<<<<<< HEAD // The following code is not directly related to quantization. We put it here +======= +// The following code is not directly releated to quantization. We put it here +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // since we are not able to run this test on GH's CI: for some unknown reason, // we are not able to reference symbols in the vulkan directory, hence the build // on GH fails. Moving the test here so we are still able to run it on @@ -3566,7 +3586,11 @@ TEST_F(VulkanAPITest, extract_texel_test) { // is the channel count. // We always start a new batch on a new z. Hence, when c cannot be divided by // 4, there are some undefined values in the padding area. We use -1 to +<<<<<<< HEAD // indicate that we are not performing comparison on those values. +======= + // indicate that we are not performing comparsion on those values. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::tuple test_cases[]{ {{0, 0, 0}, {0, hw, 2 * hw, 3 * hw}}, {{1, 0, 0}, {1, hw + 1, 2 * hw + 1, 3 * hw + 1}}, @@ -3672,7 +3696,11 @@ TEST_F(VulkanAPITest, channel_to_width_packing_test) { at::Tensor output = at::native::vulkan::ops::convert(v_output); // This tensor will be width-packed. Meaning that each texel represent +<<<<<<< HEAD // consecutive elements along the width dimension. The difference between +======= + // consecutive elements along the width dimension. The differece between +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // consecutive texels is 1. std::tuple test_cases[]{ {{0, 0, 0}, {0, 1, 2, 3}}, diff --git a/aten/src/ATen/xpu/CachingHostAllocator.cpp b/aten/src/ATen/xpu/CachingHostAllocator.cpp index d531b46c3c554..f8ff2d9eaee27 100644 --- a/aten/src/ATen/xpu/CachingHostAllocator.cpp +++ b/aten/src/ATen/xpu/CachingHostAllocator.cpp @@ -30,12 +30,15 @@ struct XPUCachingHostAllocatorImpl bool query_event(XPUEvent& event) override { return event.query(); } +<<<<<<< HEAD bool pinned_use_background_threads() override { // Using background threads for XPU causes a hang on Windows during program // exit. Will be enabled once the issue is resolved. return false; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; DECLARE_HOST_ALLOCATOR( diff --git a/aten/src/ATen/xpu/XPUEvent.h b/aten/src/ATen/xpu/XPUEvent.h index 19d42aae080f1..e76aef28e2ead 100644 --- a/aten/src/ATen/xpu/XPUEvent.h +++ b/aten/src/ATen/xpu/XPUEvent.h @@ -12,7 +12,11 @@ namespace at::xpu { * must match the same device. * * Currently, XPUEvent does NOT support to export an inter-process event from +<<<<<<< HEAD * another process via inter-process communication(IPC). So it means that +======= + * another process via inter-process comunication(IPC). So it means that +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) * inter-process communication for event handles between different processes is * not available. This could impact some applications that rely on cross-process * synchronization and communication. diff --git a/aten/src/README.md b/aten/src/README.md index fa279c89d26ca..904d1bede039f 100644 --- a/aten/src/README.md +++ b/aten/src/README.md @@ -8,7 +8,11 @@ multiple variants of the library, summarized here: * THC = TorcH Cuda * THCS = TorcH Cuda Sparse (now defunct) * THNN = TorcH Neural Network (now defunct) +<<<<<<< HEAD * THS = TorcH Sparse (now defunct) +======= +* THS = TorcH Sparse (now defunct) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) (You'll also see these abbreviations show up in symbol names.) diff --git a/benchmarks/README.md b/benchmarks/README.md index 4ea84bcafab46..d004333a6990b 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -8,12 +8,20 @@ It also provides mechanisms to compare PyTorch with other frameworks. Make sure you're on a machine with CUDA, torchvision, and pytorch installed. Install in the following order: ``` # Install torchvision. It comes with the pytorch stable release binary +<<<<<<< HEAD python -m pip install torch torchvision +======= +pip3 install torch torchvision +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Install the latest pytorch master from source. # It should supersede the installation from the release binary. cd $PYTORCH_HOME +<<<<<<< HEAD python -m pip install --no-build-isolation -v -e . +======= +python setup.py build develop +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Check the pytorch installation version python -c "import torch; print(torch.__version__)" @@ -31,4 +39,7 @@ Please refer to each subfolder to discover each benchmark suite. Links are provi * [Overrides](overrides_benchmark/README.md) * [Sparse](sparse/README.md) * [Tensor expression](tensorexpr/HowToRun.md) +<<<<<<< HEAD * [Data](data/README.md) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/benchmarks/dynamo/Makefile b/benchmarks/dynamo/Makefile index c5a0a20aaa690..690d468e5dd84 100644 --- a/benchmarks/dynamo/Makefile +++ b/benchmarks/dynamo/Makefile @@ -27,7 +27,11 @@ pull-deps: clone-deps (cd ../../../torchbenchmark && git fetch && git checkout "$$(cat ../pytorch/.github/ci_commit_pins/torchbench.txt)" && git submodule update --init --recursive) build-deps: clone-deps +<<<<<<< HEAD uv pip install numpy scipy ninja pyyaml six mkl mkl-include setuptools wheel cmake \ +======= + uv pip install astunparse numpy scipy ninja pyyaml mkl mkl-include setuptools cmake \ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) typing-extensions requests protobuf numba cython scikit-learn librosa (cd ../../../torchvision && uv pip install -e . --no-build-isolation) (cd ../../../torchdata && uv pip install -e .) diff --git a/benchmarks/dynamo/benchmarks.py b/benchmarks/dynamo/benchmarks.py index 25c1e8203a0a9..01281723a1733 100755 --- a/benchmarks/dynamo/benchmarks.py +++ b/benchmarks/dynamo/benchmarks.py @@ -5,12 +5,15 @@ import sys +<<<<<<< HEAD # Run only this selected group of models, leave this empty to run everything TORCHBENCH_ONLY_MODELS = [ m.strip() for m in os.getenv("TORCHBENCH_ONLY_MODELS", "").split(",") if m.strip() ] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Note - hf and timm have their own version of this, torchbench does not # TODO(voz): Someday, consolidate all the files into one runner instead of a shim like this... def model_names(filename: str) -> set[str]: @@ -23,8 +26,11 @@ def model_names(filename: str) -> set[str]: if len(line_parts) == 1: line_parts = line.split(",") model_name = line_parts[0] +<<<<<<< HEAD if TORCHBENCH_ONLY_MODELS and model_name not in TORCHBENCH_ONLY_MODELS: continue +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) names.add(model_name) return names diff --git a/benchmarks/dynamo/check_accuracy.py b/benchmarks/dynamo/check_accuracy.py index 678cee5f752c3..1d6e7d496a578 100644 --- a/benchmarks/dynamo/check_accuracy.py +++ b/benchmarks/dynamo/check_accuracy.py @@ -14,9 +14,12 @@ "detectron2_maskrcnn_r_101_c4", "timm_efficientnet", # see https://github.com/pytorch/pytorch/issues/148699 "XGLMForCausalLM", # discovered in https://github.com/pytorch/pytorch/pull/128148 +<<<<<<< HEAD "moondream", # discovered in https://github.com/pytorch/pytorch/pull/159291 # discovered in https://github.com/pytorch/pytorch/issues/161419. Its not flaky but really hard to repro, so skipping it "mobilenetv3_large_100", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } diff --git a/benchmarks/dynamo/check_graph_breaks.py b/benchmarks/dynamo/check_graph_breaks.py index 57814dacd00b3..0d8f2843823a6 100644 --- a/benchmarks/dynamo/check_graph_breaks.py +++ b/benchmarks/dynamo/check_graph_breaks.py @@ -13,7 +13,10 @@ "gluon_inception_v3", "detectron2_maskrcnn_r_101_c4", "XGLMForCausalLM", # discovered in https://github.com/pytorch/pytorch/pull/128148 +<<<<<<< HEAD "detectron2_fcos_r_50_fpn", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_inference.csv index 0f088e7892d8f..a112b859fdc5d 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_inference.csv @@ -46,6 +46,17 @@ CamemBert,pass,0 +<<<<<<< HEAD +======= +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -62,7 +73,11 @@ DistilBertForQuestionAnswering,pass,0 +<<<<<<< HEAD DistillGPT2,pass,2 +======= +DistillGPT2,pass,0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -150,6 +165,13 @@ RobertaForQuestionAnswering,pass,0 +<<<<<<< HEAD +======= +Speech2Text2ForCausalLM,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_training.csv index f65909f3a24ea..bb1bb98df6647 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_training.csv @@ -46,6 +46,17 @@ CamemBert,pass,5 +<<<<<<< HEAD +======= +DebertaForMaskedLM,pass,5 + + + +DebertaForQuestionAnswering,pass,5 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -62,7 +73,11 @@ DistilBertForQuestionAnswering,pass,5 +<<<<<<< HEAD DistillGPT2,pass,7 +======= +DistillGPT2,pass,5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -122,7 +137,11 @@ MobileBertForQuestionAnswering,pass,3 +<<<<<<< HEAD OPTForCausalLM,pass,8 +======= +OPTForCausalLM,pass,6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -150,6 +169,13 @@ RobertaForQuestionAnswering,pass,5 +<<<<<<< HEAD +======= +Speech2Text2ForCausalLM,pass,6 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T5ForConditionalGeneration,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv index 1d199fe8ea664..a1e3d8aaf928f 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv @@ -106,7 +106,11 @@ dlrm,pass,0 +<<<<<<< HEAD doctr_det_predictor,pass,3 +======= +doctr_det_predictor,pass,4 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -162,7 +166,11 @@ hf_GPT2_large,pass_due_to_skip,0 +<<<<<<< HEAD hf_Reformer,pass,8 +======= +hf_Reformer,pass,5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -178,7 +186,11 @@ hf_T5_base,eager_fail_to_run,0 +<<<<<<< HEAD hf_T5_generate,pass,11 +======= +hf_T5_generate,pass,5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv index 54b7d63f3a4bc..b0a38a0e4e3ce 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv @@ -102,7 +102,11 @@ hf_DistilBert,pass,6 +<<<<<<< HEAD hf_GPT2,pass,8 +======= +hf_GPT2,pass,6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -110,7 +114,11 @@ hf_GPT2_large,pass_due_to_skip,0 +<<<<<<< HEAD hf_Reformer,pass,25 +======= +hf_Reformer,pass,23 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_huggingface_inference.csv index 169a42ff7cd41..8d5e81c966f48 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_huggingface_inference.csv @@ -42,6 +42,17 @@ CamemBert,pass,0 +<<<<<<< HEAD +======= +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -146,6 +157,13 @@ RobertaForQuestionAnswering,pass,0 +<<<<<<< HEAD +======= +Speech2Text2ForCausalLM,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_huggingface_inference.csv index 169a42ff7cd41..8d5e81c966f48 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_huggingface_inference.csv @@ -42,6 +42,17 @@ CamemBert,pass,0 +<<<<<<< HEAD +======= +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -146,6 +157,13 @@ RobertaForQuestionAnswering,pass,0 +<<<<<<< HEAD +======= +Speech2Text2ForCausalLM,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_huggingface_inference.csv index 0f088e7892d8f..a112b859fdc5d 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_huggingface_inference.csv @@ -46,6 +46,17 @@ CamemBert,pass,0 +<<<<<<< HEAD +======= +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -62,7 +73,11 @@ DistilBertForQuestionAnswering,pass,0 +<<<<<<< HEAD DistillGPT2,pass,2 +======= +DistillGPT2,pass,0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -150,6 +165,13 @@ RobertaForQuestionAnswering,pass,0 +<<<<<<< HEAD +======= +Speech2Text2ForCausalLM,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_timm_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_timm_inference.csv index c7d283b9aa52d..a303245efbb78 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_timm_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_timm_inference.csv @@ -46,7 +46,11 @@ deit_base_distilled_patch16_224,pass,0 +<<<<<<< HEAD dla102,timeout,0 +======= +dla102,pass,0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv index a4dbaeb7b546d..aca1aa460742c 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv @@ -98,11 +98,19 @@ dlrm,pass,0 +<<<<<<< HEAD doctr_det_predictor,pass,3 doctr_reco_predictor,pass,1 +======= +doctr_det_predictor,pass,5 + + + +doctr_reco_predictor,pass,4 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -138,7 +146,11 @@ hf_Bert_large,pass,0 +<<<<<<< HEAD hf_BigBird,pass,25 +======= +hf_BigBird,pass,24 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -158,7 +170,11 @@ hf_Longformer,pass,4 +<<<<<<< HEAD hf_Reformer,pass,8 +======= +hf_Reformer,pass,5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -210,7 +226,11 @@ mobilenet_v2,pass,0 +<<<<<<< HEAD mobilenet_v2_quantized_qat,pass,3 +======= +mobilenet_v2_quantized_qat,pass,2 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -274,7 +294,11 @@ resnet50,pass,0 +<<<<<<< HEAD resnet50_quantized_qat,pass,3 +======= +resnet50_quantized_qat,pass,2 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -346,7 +370,11 @@ vgg16,pass,0 +<<<<<<< HEAD vision_maskrcnn,fail_accuracy,29 +======= +vision_maskrcnn,fail_accuracy,30 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_huggingface_inference.csv index 0f088e7892d8f..a112b859fdc5d 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_huggingface_inference.csv @@ -46,6 +46,17 @@ CamemBert,pass,0 +<<<<<<< HEAD +======= +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -62,7 +73,11 @@ DistilBertForQuestionAnswering,pass,0 +<<<<<<< HEAD DistillGPT2,pass,2 +======= +DistillGPT2,pass,0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -150,6 +165,13 @@ RobertaForQuestionAnswering,pass,0 +<<<<<<< HEAD +======= +Speech2Text2ForCausalLM,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_timm_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_timm_inference.csv index c7d283b9aa52d..a303245efbb78 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_timm_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_timm_inference.csv @@ -46,7 +46,11 @@ deit_base_distilled_patch16_224,pass,0 +<<<<<<< HEAD dla102,timeout,0 +======= +dla102,pass,0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_torchbench_inference.csv index 885029ba8c56e..8dae78891bd80 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_torchbench_inference.csv @@ -98,11 +98,19 @@ dlrm,pass,0 +<<<<<<< HEAD doctr_det_predictor,pass,3 doctr_reco_predictor,pass,1 +======= +doctr_det_predictor,pass,5 + + + +doctr_reco_predictor,pass,4 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -138,7 +146,11 @@ hf_Bert_large,pass,0 +<<<<<<< HEAD hf_BigBird,pass,25 +======= +hf_BigBird,pass,24 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -158,7 +170,11 @@ hf_Longformer,pass,4 +<<<<<<< HEAD hf_Reformer,pass,8 +======= +hf_Reformer,pass,5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -210,7 +226,11 @@ mobilenet_v2,pass,0 +<<<<<<< HEAD mobilenet_v2_quantized_qat,pass,3 +======= +mobilenet_v2_quantized_qat,pass,2 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -274,7 +294,11 @@ resnet50,pass,0 +<<<<<<< HEAD resnet50_quantized_qat,pass,3 +======= +resnet50_quantized_qat,pass,2 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_inference.csv index 0f088e7892d8f..a112b859fdc5d 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_inference.csv @@ -46,6 +46,17 @@ CamemBert,pass,0 +<<<<<<< HEAD +======= +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -62,7 +73,11 @@ DistilBertForQuestionAnswering,pass,0 +<<<<<<< HEAD DistillGPT2,pass,2 +======= +DistillGPT2,pass,0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -150,6 +165,13 @@ RobertaForQuestionAnswering,pass,0 +<<<<<<< HEAD +======= +Speech2Text2ForCausalLM,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv index aa7a3161afcc6..9e550ce75b3c1 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv @@ -98,11 +98,19 @@ dlrm,pass,0 +<<<<<<< HEAD doctr_det_predictor,pass,3 doctr_reco_predictor,pass,1 +======= +doctr_det_predictor,pass,5 + + + +doctr_reco_predictor,pass,4 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -138,7 +146,11 @@ hf_Bert_large,pass,0 +<<<<<<< HEAD hf_BigBird,pass,25 +======= +hf_BigBird,pass,24 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -158,7 +170,11 @@ hf_Longformer,pass,4 +<<<<<<< HEAD hf_Reformer,pass,8 +======= +hf_Reformer,pass,5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -210,7 +226,11 @@ mobilenet_v2,pass,0 +<<<<<<< HEAD mobilenet_v2_quantized_qat,pass,3 +======= +mobilenet_v2_quantized_qat,pass,2 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -274,7 +294,11 @@ resnet50,pass,0 +<<<<<<< HEAD resnet50_quantized_qat,pass,3 +======= +resnet50_quantized_qat,pass,2 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_inference.csv index 0f088e7892d8f..a112b859fdc5d 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_inference.csv @@ -46,6 +46,17 @@ CamemBert,pass,0 +<<<<<<< HEAD +======= +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -62,7 +73,11 @@ DistilBertForQuestionAnswering,pass,0 +<<<<<<< HEAD DistillGPT2,pass,2 +======= +DistillGPT2,pass,0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -150,6 +165,13 @@ RobertaForQuestionAnswering,pass,0 +<<<<<<< HEAD +======= +Speech2Text2ForCausalLM,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_training.csv index f65909f3a24ea..bb1bb98df6647 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_training.csv @@ -46,6 +46,17 @@ CamemBert,pass,5 +<<<<<<< HEAD +======= +DebertaForMaskedLM,pass,5 + + + +DebertaForQuestionAnswering,pass,5 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -62,7 +73,11 @@ DistilBertForQuestionAnswering,pass,5 +<<<<<<< HEAD DistillGPT2,pass,7 +======= +DistillGPT2,pass,5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -122,7 +137,11 @@ MobileBertForQuestionAnswering,pass,3 +<<<<<<< HEAD OPTForCausalLM,pass,8 +======= +OPTForCausalLM,pass,6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -150,6 +169,13 @@ RobertaForQuestionAnswering,pass,5 +<<<<<<< HEAD +======= +Speech2Text2ForCausalLM,pass,6 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T5ForConditionalGeneration,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv index 20cad351b1275..c22044beb01b6 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv @@ -106,7 +106,11 @@ dlrm,pass,0 +<<<<<<< HEAD doctr_det_predictor,pass,3 +======= +doctr_det_predictor,pass,4 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -146,7 +150,11 @@ hf_Bert_large,pass,0 +<<<<<<< HEAD hf_BigBird,pass,0 +======= +hf_BigBird,fail_to_run,0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -162,7 +170,11 @@ hf_GPT2_large,pass_due_to_skip,0 +<<<<<<< HEAD hf_Reformer,pass,8 +======= +hf_Reformer,pass,5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -178,7 +190,11 @@ hf_T5_base,eager_fail_to_run,0 +<<<<<<< HEAD hf_T5_generate,pass,11 +======= +hf_T5_generate,pass,5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv index 5050b3762ed96..9ebf64cedafaa 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv @@ -2,7 +2,11 @@ name,accuracy,graph_breaks +<<<<<<< HEAD torchrec_dlrm,pass,6 +======= +torchrec_dlrm,fail_to_run,3 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -94,7 +98,11 @@ hf_Bert_large,pass,6 +<<<<<<< HEAD hf_BigBird,pass,6 +======= +hf_BigBird,fail_to_run,3 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -102,7 +110,11 @@ hf_DistilBert,pass,6 +<<<<<<< HEAD hf_GPT2,pass,8 +======= +hf_GPT2,pass,6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -110,7 +122,11 @@ hf_GPT2_large,pass_due_to_skip,0 +<<<<<<< HEAD hf_Reformer,pass,25 +======= +hf_Reformer,fail_to_run,19 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_freezing_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_freezing_torchbench_inference.csv index b0e8f34b964ec..6a1a552036dec 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_freezing_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_freezing_torchbench_inference.csv @@ -34,7 +34,11 @@ basic_gnn_gin,pass,0 +<<<<<<< HEAD basic_gnn_sage,pass,0 +======= +basic_gnn_sage,fail_to_run,0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_huggingface_inference.csv index 0f088e7892d8f..a112b859fdc5d 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_huggingface_inference.csv @@ -46,6 +46,17 @@ CamemBert,pass,0 +<<<<<<< HEAD +======= +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -62,7 +73,11 @@ DistilBertForQuestionAnswering,pass,0 +<<<<<<< HEAD DistillGPT2,pass,2 +======= +DistillGPT2,pass,0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -150,6 +165,13 @@ RobertaForQuestionAnswering,pass,0 +<<<<<<< HEAD +======= +Speech2Text2ForCausalLM,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv index f26dea6f692ef..c2a09eae9b022 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv @@ -82,11 +82,19 @@ dlrm,pass,0 +<<<<<<< HEAD doctr_det_predictor,pass,3 doctr_reco_predictor,pass,1 +======= +doctr_det_predictor,pass,5 + + + +doctr_reco_predictor,pass,4 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -122,7 +130,11 @@ hf_Bert_large,pass,0 +<<<<<<< HEAD hf_BigBird,pass,25 +======= +hf_BigBird,pass,24 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -142,7 +154,11 @@ hf_Longformer,pass,4 +<<<<<<< HEAD hf_Reformer,pass,8 +======= +hf_Reformer,pass,5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -194,7 +210,11 @@ mobilenet_v2,pass,0 +<<<<<<< HEAD mobilenet_v2_quantized_qat,pass,3 +======= +mobilenet_v2_quantized_qat,pass,2 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -258,7 +278,11 @@ resnet50,pass,0 +<<<<<<< HEAD resnet50_quantized_qat,pass,3 +======= +resnet50_quantized_qat,pass,2 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_huggingface_inference.csv index 0f088e7892d8f..a112b859fdc5d 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_huggingface_inference.csv @@ -46,6 +46,17 @@ CamemBert,pass,0 +<<<<<<< HEAD +======= +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -62,7 +73,11 @@ DistilBertForQuestionAnswering,pass,0 +<<<<<<< HEAD DistillGPT2,pass,2 +======= +DistillGPT2,pass,0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -150,6 +165,13 @@ RobertaForQuestionAnswering,pass,0 +<<<<<<< HEAD +======= +Speech2Text2ForCausalLM,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_torchbench_inference.csv index 39149853947c3..a9ed51d405985 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_torchbench_inference.csv @@ -98,11 +98,19 @@ dlrm,pass,0 +<<<<<<< HEAD doctr_det_predictor,pass,3 doctr_reco_predictor,pass,1 +======= +doctr_det_predictor,pass,5 + + + +doctr_reco_predictor,pass,4 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -138,7 +146,11 @@ hf_Bert_large,pass,0 +<<<<<<< HEAD hf_BigBird,pass,25 +======= +hf_BigBird,pass,24 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -158,7 +170,11 @@ hf_Longformer,pass,4 +<<<<<<< HEAD hf_Reformer,pass,8 +======= +hf_Reformer,pass,5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -210,7 +226,11 @@ mobilenet_v2,pass,0 +<<<<<<< HEAD mobilenet_v2_quantized_qat,pass,3 +======= +mobilenet_v2_quantized_qat,pass,2 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -274,7 +294,11 @@ resnet50,pass,0 +<<<<<<< HEAD resnet50_quantized_qat,pass,3 +======= +resnet50_quantized_qat,pass,2 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -346,7 +370,11 @@ vgg16,pass,0 +<<<<<<< HEAD vision_maskrcnn,fail_accuracy,29 +======= +vision_maskrcnn,fail_accuracy,30 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_inference.csv index 0f088e7892d8f..a112b859fdc5d 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_inference.csv @@ -46,6 +46,17 @@ CamemBert,pass,0 +<<<<<<< HEAD +======= +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -62,7 +73,11 @@ DistilBertForQuestionAnswering,pass,0 +<<<<<<< HEAD DistillGPT2,pass,2 +======= +DistillGPT2,pass,0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -150,6 +165,13 @@ RobertaForQuestionAnswering,pass,0 +<<<<<<< HEAD +======= +Speech2Text2ForCausalLM,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_training.csv index f65909f3a24ea..bb1bb98df6647 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_training.csv @@ -46,6 +46,17 @@ CamemBert,pass,5 +<<<<<<< HEAD +======= +DebertaForMaskedLM,pass,5 + + + +DebertaForQuestionAnswering,pass,5 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -62,7 +73,11 @@ DistilBertForQuestionAnswering,pass,5 +<<<<<<< HEAD DistillGPT2,pass,7 +======= +DistillGPT2,pass,5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -122,7 +137,11 @@ MobileBertForQuestionAnswering,pass,3 +<<<<<<< HEAD OPTForCausalLM,pass,8 +======= +OPTForCausalLM,pass,6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -150,6 +169,13 @@ RobertaForQuestionAnswering,pass,5 +<<<<<<< HEAD +======= +Speech2Text2ForCausalLM,pass,6 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T5ForConditionalGeneration,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv index 2b2c1a504647f..b5ffbd507fe1b 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv @@ -106,7 +106,11 @@ dlrm,pass,0 +<<<<<<< HEAD doctr_det_predictor,pass,3 +======= +doctr_det_predictor,pass,4 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -146,7 +150,11 @@ hf_Bert_large,pass,0 +<<<<<<< HEAD hf_BigBird,fail_accuracy,0 +======= +hf_BigBird,fail_to_run,0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -162,7 +170,11 @@ hf_GPT2_large,pass_due_to_skip,0 +<<<<<<< HEAD hf_Reformer,pass,8 +======= +hf_Reformer,pass,5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -178,7 +190,11 @@ hf_T5_base,eager_fail_to_run,0 +<<<<<<< HEAD hf_T5_generate,pass,11 +======= +hf_T5_generate,pass,5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv index 89871fd49a04b..d08e90053e9d0 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv @@ -2,7 +2,11 @@ name,accuracy,graph_breaks +<<<<<<< HEAD torchrec_dlrm,pass,6 +======= +torchrec_dlrm,fail_to_run,3 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -46,7 +50,11 @@ dcgan,pass,6 +<<<<<<< HEAD demucs,pass,9 +======= +demucs,fail_to_run,4 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -94,7 +102,11 @@ hf_Bert_large,pass,6 +<<<<<<< HEAD hf_BigBird,pass,6 +======= +hf_BigBird,fail_to_run,3 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -102,7 +114,11 @@ hf_DistilBert,pass,6 +<<<<<<< HEAD hf_GPT2,pass,8 +======= +hf_GPT2,pass,6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -110,7 +126,11 @@ hf_GPT2_large,pass_due_to_skip,0 +<<<<<<< HEAD hf_Reformer,pass,25 +======= +hf_Reformer,fail_to_run,19 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_inference.csv index 0f088e7892d8f..a112b859fdc5d 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_inference.csv @@ -46,6 +46,17 @@ CamemBert,pass,0 +<<<<<<< HEAD +======= +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -62,7 +73,11 @@ DistilBertForQuestionAnswering,pass,0 +<<<<<<< HEAD DistillGPT2,pass,2 +======= +DistillGPT2,pass,0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -150,6 +165,13 @@ RobertaForQuestionAnswering,pass,0 +<<<<<<< HEAD +======= +Speech2Text2ForCausalLM,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_training.csv index f65909f3a24ea..bb1bb98df6647 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_training.csv @@ -46,6 +46,17 @@ CamemBert,pass,5 +<<<<<<< HEAD +======= +DebertaForMaskedLM,pass,5 + + + +DebertaForQuestionAnswering,pass,5 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -62,7 +73,11 @@ DistilBertForQuestionAnswering,pass,5 +<<<<<<< HEAD DistillGPT2,pass,7 +======= +DistillGPT2,pass,5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -122,7 +137,11 @@ MobileBertForQuestionAnswering,pass,3 +<<<<<<< HEAD OPTForCausalLM,pass,8 +======= +OPTForCausalLM,pass,6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -150,6 +169,13 @@ RobertaForQuestionAnswering,pass,5 +<<<<<<< HEAD +======= +Speech2Text2ForCausalLM,pass,6 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T5ForConditionalGeneration,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv index 1d199fe8ea664..a1e3d8aaf928f 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv @@ -106,7 +106,11 @@ dlrm,pass,0 +<<<<<<< HEAD doctr_det_predictor,pass,3 +======= +doctr_det_predictor,pass,4 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -162,7 +166,11 @@ hf_GPT2_large,pass_due_to_skip,0 +<<<<<<< HEAD hf_Reformer,pass,8 +======= +hf_Reformer,pass,5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -178,7 +186,11 @@ hf_T5_base,eager_fail_to_run,0 +<<<<<<< HEAD hf_T5_generate,pass,11 +======= +hf_T5_generate,pass,5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv index 0985e42fc5cb9..0880de6576175 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv @@ -102,7 +102,11 @@ hf_DistilBert,pass,6 +<<<<<<< HEAD hf_GPT2,pass,8 +======= +hf_GPT2,pass,6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -110,7 +114,11 @@ hf_GPT2_large,pass_due_to_skip,0 +<<<<<<< HEAD hf_Reformer,pass,25 +======= +hf_Reformer,pass,23 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_inference.csv index 0f088e7892d8f..a112b859fdc5d 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_inference.csv @@ -46,6 +46,17 @@ CamemBert,pass,0 +<<<<<<< HEAD +======= +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -62,7 +73,11 @@ DistilBertForQuestionAnswering,pass,0 +<<<<<<< HEAD DistillGPT2,pass,2 +======= +DistillGPT2,pass,0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -150,6 +165,13 @@ RobertaForQuestionAnswering,pass,0 +<<<<<<< HEAD +======= +Speech2Text2ForCausalLM,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_training.csv index f65909f3a24ea..bb1bb98df6647 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_training.csv @@ -46,6 +46,17 @@ CamemBert,pass,5 +<<<<<<< HEAD +======= +DebertaForMaskedLM,pass,5 + + + +DebertaForQuestionAnswering,pass,5 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -62,7 +73,11 @@ DistilBertForQuestionAnswering,pass,5 +<<<<<<< HEAD DistillGPT2,pass,7 +======= +DistillGPT2,pass,5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -122,7 +137,11 @@ MobileBertForQuestionAnswering,pass,3 +<<<<<<< HEAD OPTForCausalLM,pass,8 +======= +OPTForCausalLM,pass,6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -150,6 +169,13 @@ RobertaForQuestionAnswering,pass,5 +<<<<<<< HEAD +======= +Speech2Text2ForCausalLM,pass,6 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T5ForConditionalGeneration,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv index e41018657c0e2..22857c6b31c52 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv @@ -106,7 +106,11 @@ dlrm,pass,0 +<<<<<<< HEAD doctr_det_predictor,pass,3 +======= +doctr_det_predictor,pass,4 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -162,7 +166,11 @@ hf_GPT2_large,pass_due_to_skip,0 +<<<<<<< HEAD hf_Reformer,pass,8 +======= +hf_Reformer,pass,5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -178,7 +186,11 @@ hf_T5_base,eager_fail_to_run,0 +<<<<<<< HEAD hf_T5_generate,pass,11 +======= +hf_T5_generate,pass,5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv index 54b7d63f3a4bc..b0a38a0e4e3ce 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv @@ -102,7 +102,11 @@ hf_DistilBert,pass,6 +<<<<<<< HEAD hf_GPT2,pass,8 +======= +hf_GPT2,pass,6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -110,7 +114,11 @@ hf_GPT2_large,pass_due_to_skip,0 +<<<<<<< HEAD hf_Reformer,pass,25 +======= +hf_Reformer,pass,23 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_huggingface_inference.csv index 0f088e7892d8f..a112b859fdc5d 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_huggingface_inference.csv @@ -46,6 +46,17 @@ CamemBert,pass,0 +<<<<<<< HEAD +======= +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -62,7 +73,11 @@ DistilBertForQuestionAnswering,pass,0 +<<<<<<< HEAD DistillGPT2,pass,2 +======= +DistillGPT2,pass,0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -150,6 +165,13 @@ RobertaForQuestionAnswering,pass,0 +<<<<<<< HEAD +======= +Speech2Text2ForCausalLM,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_huggingface_training.csv index 08061de428d71..ef3b967e2092d 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_huggingface_training.csv @@ -46,6 +46,17 @@ CamemBert,pass,5 +<<<<<<< HEAD +======= +DebertaForMaskedLM,pass,5 + + + +DebertaForQuestionAnswering,pass,5 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -62,7 +73,11 @@ DistilBertForQuestionAnswering,pass,5 +<<<<<<< HEAD DistillGPT2,pass,7 +======= +DistillGPT2,pass,5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -122,7 +137,11 @@ MobileBertForQuestionAnswering,pass,3 +<<<<<<< HEAD OPTForCausalLM,pass,8 +======= +OPTForCausalLM,pass,6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -150,6 +169,13 @@ RobertaForQuestionAnswering,pass,5 +<<<<<<< HEAD +======= +Speech2Text2ForCausalLM,pass,6 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T5ForConditionalGeneration,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_torchbench_inference.csv index bf70642a855ef..5372e646ec66e 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_torchbench_inference.csv @@ -106,11 +106,19 @@ dlrm,pass,0 +<<<<<<< HEAD doctr_det_predictor,eager_fail_to_run,3 doctr_reco_predictor,eager_fail_to_run,1 +======= +doctr_det_predictor,eager_fail_to_run,5 + + + +doctr_reco_predictor,eager_fail_to_run,4 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -166,7 +174,11 @@ hf_Longformer,pass,4 +<<<<<<< HEAD hf_Reformer,pass,8 +======= +hf_Reformer,pass,5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -181,7 +193,11 @@ hf_T5_base,pass,0 +<<<<<<< HEAD hf_T5_generate,pass,11 +======= +hf_T5_generate,pass,5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_torchbench_training.csv index 48d0b111788f7..57b60fdaa2da9 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_torchbench_training.csv @@ -102,7 +102,11 @@ hf_DistilBert,pass,6 +<<<<<<< HEAD hf_GPT2,pass,8 +======= +hf_GPT2,pass,6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -114,7 +118,11 @@ hf_Longformer,pass,4 +<<<<<<< HEAD hf_Reformer,pass,25 +======= +hf_Reformer,pass,23 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_inductor_huggingface_inference.csv index ce334e22c698b..fe67f0ad22f66 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_inductor_huggingface_inference.csv @@ -42,6 +42,17 @@ CamemBert,pass,0 +<<<<<<< HEAD +======= +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -58,7 +69,11 @@ DistilBertForQuestionAnswering,pass,0 +<<<<<<< HEAD DistillGPT2,pass,2 +======= +DistillGPT2,pass,0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -146,6 +161,13 @@ RobertaForQuestionAnswering,pass,0 +<<<<<<< HEAD +======= +Speech2Text2ForCausalLM,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_huggingface_inference.csv index 0f088e7892d8f..a112b859fdc5d 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_huggingface_inference.csv @@ -46,6 +46,17 @@ CamemBert,pass,0 +<<<<<<< HEAD +======= +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -62,7 +73,11 @@ DistilBertForQuestionAnswering,pass,0 +<<<<<<< HEAD DistillGPT2,pass,2 +======= +DistillGPT2,pass,0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -150,6 +165,13 @@ RobertaForQuestionAnswering,pass,0 +<<<<<<< HEAD +======= +Speech2Text2ForCausalLM,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_huggingface_training.csv index 08061de428d71..ef3b967e2092d 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_huggingface_training.csv @@ -46,6 +46,17 @@ CamemBert,pass,5 +<<<<<<< HEAD +======= +DebertaForMaskedLM,pass,5 + + + +DebertaForQuestionAnswering,pass,5 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -62,7 +73,11 @@ DistilBertForQuestionAnswering,pass,5 +<<<<<<< HEAD DistillGPT2,pass,7 +======= +DistillGPT2,pass,5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -122,7 +137,11 @@ MobileBertForQuestionAnswering,pass,3 +<<<<<<< HEAD OPTForCausalLM,pass,8 +======= +OPTForCausalLM,pass,6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -150,6 +169,13 @@ RobertaForQuestionAnswering,pass,5 +<<<<<<< HEAD +======= +Speech2Text2ForCausalLM,pass,6 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T5ForConditionalGeneration,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_torchbench_inference.csv index e019365ccbfdb..df0e2eff672ef 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_torchbench_inference.csv @@ -106,11 +106,19 @@ dlrm,pass,0 +<<<<<<< HEAD doctr_det_predictor,eager_fail_to_run,3 doctr_reco_predictor,eager_fail_to_run,1 +======= +doctr_det_predictor,eager_fail_to_run,5 + + + +doctr_reco_predictor,eager_fail_to_run,4 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -166,7 +174,11 @@ hf_Longformer,pass,4 +<<<<<<< HEAD hf_Reformer,pass,8 +======= +hf_Reformer,pass,5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -181,7 +193,11 @@ hf_T5_base,pass,0 +<<<<<<< HEAD hf_T5_generate,pass,11 +======= +hf_T5_generate,pass,5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_torchbench_training.csv index 643a02fdca8fd..1b879006c2c7e 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_torchbench_training.csv @@ -102,7 +102,11 @@ hf_DistilBert,pass,6 +<<<<<<< HEAD hf_GPT2,pass,8 +======= +hf_GPT2,pass,6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -114,7 +118,11 @@ hf_Longformer,pass,4 +<<<<<<< HEAD hf_Reformer,pass,25 +======= +hf_Reformer,pass,23 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_huggingface_inference.csv index 0f088e7892d8f..a112b859fdc5d 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_huggingface_inference.csv @@ -46,6 +46,17 @@ CamemBert,pass,0 +<<<<<<< HEAD +======= +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -62,7 +73,11 @@ DistilBertForQuestionAnswering,pass,0 +<<<<<<< HEAD DistillGPT2,pass,2 +======= +DistillGPT2,pass,0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -150,6 +165,13 @@ RobertaForQuestionAnswering,pass,0 +<<<<<<< HEAD +======= +Speech2Text2ForCausalLM,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_huggingface_training.csv index f65909f3a24ea..bb1bb98df6647 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_huggingface_training.csv @@ -46,6 +46,17 @@ CamemBert,pass,5 +<<<<<<< HEAD +======= +DebertaForMaskedLM,pass,5 + + + +DebertaForQuestionAnswering,pass,5 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -62,7 +73,11 @@ DistilBertForQuestionAnswering,pass,5 +<<<<<<< HEAD DistillGPT2,pass,7 +======= +DistillGPT2,pass,5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -122,7 +137,11 @@ MobileBertForQuestionAnswering,pass,3 +<<<<<<< HEAD OPTForCausalLM,pass,8 +======= +OPTForCausalLM,pass,6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -150,6 +169,13 @@ RobertaForQuestionAnswering,pass,5 +<<<<<<< HEAD +======= +Speech2Text2ForCausalLM,pass,6 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T5ForConditionalGeneration,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_inference.csv index fed8ebded682c..95d87bc9a8043 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_inference.csv @@ -106,11 +106,19 @@ dlrm,pass,0 +<<<<<<< HEAD doctr_det_predictor,eager_fail_to_run,3 doctr_reco_predictor,eager_fail_to_run,1 +======= +doctr_det_predictor,eager_fail_to_run,5 + + + +doctr_reco_predictor,eager_fail_to_run,4 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -162,7 +170,11 @@ hf_GPT2_large,pass_due_to_skip,0 +<<<<<<< HEAD hf_Reformer,pass,8 +======= +hf_Reformer,pass,5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -174,7 +186,11 @@ hf_T5_base,eager_fail_to_run,0 +<<<<<<< HEAD hf_T5_generate,pass,11 +======= +hf_T5_generate,pass,5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_training.csv index ced88884720b7..2b150d617701e 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_training.csv @@ -102,7 +102,11 @@ hf_DistilBert,pass,6 +<<<<<<< HEAD hf_GPT2,pass,8 +======= +hf_GPT2,pass,6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -110,7 +114,11 @@ hf_GPT2_large,pass_due_to_skip,0 +<<<<<<< HEAD hf_Reformer,pass,25 +======= +hf_Reformer,pass,23 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_huggingface_inference.csv index 0f088e7892d8f..a112b859fdc5d 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_huggingface_inference.csv @@ -46,6 +46,17 @@ CamemBert,pass,0 +<<<<<<< HEAD +======= +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -62,7 +73,11 @@ DistilBertForQuestionAnswering,pass,0 +<<<<<<< HEAD DistillGPT2,pass,2 +======= +DistillGPT2,pass,0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -150,6 +165,13 @@ RobertaForQuestionAnswering,pass,0 +<<<<<<< HEAD +======= +Speech2Text2ForCausalLM,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_huggingface_training.csv index 08061de428d71..ef3b967e2092d 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_huggingface_training.csv @@ -46,6 +46,17 @@ CamemBert,pass,5 +<<<<<<< HEAD +======= +DebertaForMaskedLM,pass,5 + + + +DebertaForQuestionAnswering,pass,5 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -62,7 +73,11 @@ DistilBertForQuestionAnswering,pass,5 +<<<<<<< HEAD DistillGPT2,pass,7 +======= +DistillGPT2,pass,5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -122,7 +137,11 @@ MobileBertForQuestionAnswering,pass,3 +<<<<<<< HEAD OPTForCausalLM,pass,8 +======= +OPTForCausalLM,pass,6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -150,6 +169,13 @@ RobertaForQuestionAnswering,pass,5 +<<<<<<< HEAD +======= +Speech2Text2ForCausalLM,pass,6 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T5ForConditionalGeneration,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_torchbench_inference.csv index bf70642a855ef..5372e646ec66e 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_torchbench_inference.csv @@ -106,11 +106,19 @@ dlrm,pass,0 +<<<<<<< HEAD doctr_det_predictor,eager_fail_to_run,3 doctr_reco_predictor,eager_fail_to_run,1 +======= +doctr_det_predictor,eager_fail_to_run,5 + + + +doctr_reco_predictor,eager_fail_to_run,4 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -166,7 +174,11 @@ hf_Longformer,pass,4 +<<<<<<< HEAD hf_Reformer,pass,8 +======= +hf_Reformer,pass,5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -181,7 +193,11 @@ hf_T5_base,pass,0 +<<<<<<< HEAD hf_T5_generate,pass,11 +======= +hf_T5_generate,pass,5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_torchbench_training.csv index d1606b622639e..d2dfb42ab817e 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_torchbench_training.csv @@ -102,7 +102,11 @@ hf_DistilBert,pass,6 +<<<<<<< HEAD hf_GPT2,pass,8 +======= +hf_GPT2,pass,6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -114,7 +118,11 @@ hf_Longformer,pass,4 +<<<<<<< HEAD hf_Reformer,pass,25 +======= +hf_Reformer,pass,23 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_huggingface_inference.csv index 0f088e7892d8f..a112b859fdc5d 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_huggingface_inference.csv @@ -46,6 +46,17 @@ CamemBert,pass,0 +<<<<<<< HEAD +======= +DebertaForMaskedLM,pass,0 + + + +DebertaForQuestionAnswering,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -62,7 +73,11 @@ DistilBertForQuestionAnswering,pass,0 +<<<<<<< HEAD DistillGPT2,pass,2 +======= +DistillGPT2,pass,0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -150,6 +165,13 @@ RobertaForQuestionAnswering,pass,0 +<<<<<<< HEAD +======= +Speech2Text2ForCausalLM,pass,0 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T5ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_huggingface_training.csv index f65909f3a24ea..bb1bb98df6647 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_huggingface_training.csv @@ -46,6 +46,17 @@ CamemBert,pass,5 +<<<<<<< HEAD +======= +DebertaForMaskedLM,pass,5 + + + +DebertaForQuestionAnswering,pass,5 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DebertaV2ForMaskedLM,pass_due_to_skip,0 @@ -62,7 +73,11 @@ DistilBertForQuestionAnswering,pass,5 +<<<<<<< HEAD DistillGPT2,pass,7 +======= +DistillGPT2,pass,5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -122,7 +137,11 @@ MobileBertForQuestionAnswering,pass,3 +<<<<<<< HEAD OPTForCausalLM,pass,8 +======= +OPTForCausalLM,pass,6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -150,6 +169,13 @@ RobertaForQuestionAnswering,pass,5 +<<<<<<< HEAD +======= +Speech2Text2ForCausalLM,pass,6 + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T5ForConditionalGeneration,pass,5 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_torchbench_inference.csv index 014e23e41cb31..ab3f6d79f60a4 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_torchbench_inference.csv @@ -106,11 +106,19 @@ dlrm,pass,0 +<<<<<<< HEAD doctr_det_predictor,eager_fail_to_run,3 doctr_reco_predictor,eager_fail_to_run,1 +======= +doctr_det_predictor,eager_fail_to_run,5 + + + +doctr_reco_predictor,eager_fail_to_run,4 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -162,7 +170,11 @@ hf_GPT2_large,pass_due_to_skip,0 +<<<<<<< HEAD hf_Reformer,pass,8 +======= +hf_Reformer,pass,5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -174,7 +186,11 @@ hf_T5_base,eager_fail_to_run,0 +<<<<<<< HEAD hf_T5_generate,pass,11 +======= +hf_T5_generate,pass,5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_torchbench_training.csv index e842ac7cb8e1f..f69676c7dc6ac 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_torchbench_training.csv @@ -102,7 +102,11 @@ hf_DistilBert,pass,6 +<<<<<<< HEAD hf_GPT2,pass,8 +======= +hf_GPT2,pass,6 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @@ -110,7 +114,11 @@ hf_GPT2_large,pass_due_to_skip,0 +<<<<<<< HEAD hf_Reformer,pass,25 +======= +hf_Reformer,pass,23 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 2901009f7c4d1..77cbe1727b923 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -9,23 +9,36 @@ import csv import dataclasses import functools +<<<<<<< HEAD import gc +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import importlib import itertools import json import logging import os +<<<<<<< HEAD import platform +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import random import shutil import signal import subprocess import sys +<<<<<<< HEAD import tempfile import time import weakref from contextlib import contextmanager from typing import Any, NamedTuple, Optional, overload, TYPE_CHECKING, TypeVar +======= +import time +import weakref +from contextlib import contextmanager +from typing import Any, NamedTuple, TYPE_CHECKING +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from unittest.mock import MagicMock import numpy as np @@ -42,7 +55,10 @@ import torch.distributed import torch.multiprocessing as mp from torch._C import _has_cuda as HAS_CUDA, _has_xpu as HAS_XPU +<<<<<<< HEAD from torch._C._nativert import PyModelRunner +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._dynamo.profiler import fx_insert_profiling, Profiler from torch._dynamo.testing import ( dummy_fx_compile, @@ -58,7 +74,10 @@ from torch._inductor.utils import fresh_cache except ImportError: from _dynamo.utils import clone_inputs, graph_break_reasons +<<<<<<< HEAD from _inductor.utils import fresh_cache +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch._functorch.config from torch._functorch.aot_autograd import set_model_name @@ -80,10 +99,14 @@ if TYPE_CHECKING: +<<<<<<< HEAD from collections.abc import Sequence _D = TypeVar("_D", bound=dict[str, Any]) _T = TypeVar("_T") +======= + from collections.abc import Mapping +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) log = logging.getLogger(__name__) @@ -204,6 +227,10 @@ class CI(NamedTuple): "PLBartForCausalLM", "PLBartForConditionalGeneration", "PegasusForCausalLM", +<<<<<<< HEAD +======= + "Speech2Text2ForCausalLM", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "TrOCRForCausalLM", "XGLMForCausalLM", # TIMM @@ -773,6 +800,7 @@ def vary_batch(t: torch.Tensor, new_batch_size) -> torch.Tensor: return (time_total, result) if return_result else time_total +<<<<<<< HEAD @overload def _normalize_bench_inputs(example_inputs: _D) -> tuple[tuple[()], _D]: ... @@ -784,6 +812,9 @@ def _normalize_bench_inputs( def _normalize_bench_inputs(example_inputs): +======= +def _normalize_bench_inputs(example_inputs) -> tuple[tuple[Any], Mapping[str, Any]]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # NOTE(bowbao): For huggingface benchmark, example_inputs are formatted as dictionary, # and consumed like `model(**example_inputs)`. # For other benchmarks, example_inputs are formatted as tuple and consumed @@ -1101,10 +1132,13 @@ def maybe_mark_profile(*args, **kwargs): frozen_model_iter_fn = export_aot_inductor( model, example_inputs, args.inductor_compile_mode ) +<<<<<<< HEAD elif args.export_nativert: frozen_model_iter_fn = export_nativert(model, example_inputs) elif args.torchscript_jit_trace: frozen_model_iter_fn = torchscript_jit_trace(model, example_inputs) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: frozen_model_iter_fn = torch._dynamo.run(model_iter_fn) @@ -1451,6 +1485,7 @@ def get_excess_memory(cls, model) -> float: return cls.cache.get(weakref.ref(model), (None, 0.0))[1] +<<<<<<< HEAD class NativeRTCache: cache: dict[weakref.ref, Any] = {} @@ -1505,6 +1540,8 @@ def load(cls, model, example_inputs): return cls.cache[key] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def export(model, example_inputs): from torch.export.dynamic_shapes import _combine_args, _tree_map_with_path @@ -1531,6 +1568,7 @@ def opt_export(_, example_inputs): return opt_export +<<<<<<< HEAD def export_nativert(model, example_inputs): optimized = NativeRTCache.load(model, example_inputs) @@ -1541,6 +1579,8 @@ def opt_nativert(_, example_inputs, collect_outputs=False): return opt_nativert +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def export_aot_inductor(model, example_inputs, mode): optimized = AOTInductorModelCache.load(model, example_inputs, mode) @@ -1551,6 +1591,7 @@ def opt_aot_inductor(_, example_inputs, collect_outputs=False): return opt_aot_inductor +<<<<<<< HEAD def torchscript_jit_trace(model, example_inputs): optimized = JitTracedCache.load(model, example_inputs) @@ -1561,6 +1602,8 @@ def opt_jit_trace(_, example_inputs, collect_outputs=False): return opt_jit_trace +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def download_retry_decorator(download_fn): """ Decorator function for applying retry logic to a download function. @@ -1766,7 +1809,11 @@ def __init__(self): self.grad_scaler = DummyGradScaler() self.autocast = contextlib.nullcontext self.autocast_arg = {} +<<<<<<< HEAD self.optimizer: Optional[torch.optim.Optimizer] = None +======= + self.optimizer = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._args = None def setup_amp(self, current_device=None): @@ -1845,10 +1892,13 @@ def skip_models_for_cpu(self): return set() @property +<<<<<<< HEAD def skip_models_for_cpu_aarch64(self): return set() @property +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def skip_models_for_freezing_cpu(self): return set() @@ -2307,12 +2357,16 @@ def record_status(accuracy_status, dynamo_start_stats): try: model_copy = self.deepcopy_and_maybe_parallelize(model) self.init_optimizer(name, current_device, model_copy.parameters()) +<<<<<<< HEAD if ( self.args.export or self.args.export_aot_inductor or self.args.export_nativert or self.args.torchscript_jit_trace ): +======= + if self.args.export or self.args.export_aot_inductor: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # apply export on module directly # no need for n iterations # the logic should be the same to self.model_iter_fn (forward_pass) @@ -2472,7 +2526,10 @@ def run_performance_test_non_alternate( ) def warmup(fn, model, example_inputs, mode, niters=10): +<<<<<<< HEAD gc.collect() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) peak_mem = 0 start_stats = get_dynamo_stats() try: @@ -2511,8 +2568,11 @@ def warmup(fn, model, example_inputs, mode, niters=10): # Use distributed wrapping as necessary model = self.deepcopy_and_maybe_parallelize(model) +<<<<<<< HEAD if not hasattr(model, name): model.name = name +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.init_optimizer(name, current_device, model.parameters()) # The self.autocast context is needed for the model we export with aot_compile, @@ -2616,6 +2676,11 @@ def warmup(fn, model, example_inputs, mode, niters=10): result_summary = latency_experiment_summary( self.suite_name, self.args, model, timings, **experiment_kwargs ) +<<<<<<< HEAD +======= + if not hasattr(model, name): + model.name = name +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) results.append(result_summary) return " ".join(map(str, results)) @@ -2634,7 +2699,10 @@ def run_performance_test( return experiment(*self.maybe_cast(model, example_inputs)) def warmup(fn, model, example_inputs, mode, niters=5): +<<<<<<< HEAD gc.collect() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) peak_mem = 0 start_stats = get_dynamo_stats() try: @@ -2673,9 +2741,12 @@ def warmup(fn, model, example_inputs, mode, niters=5): # Use distributed wrapping as necessary model = self.deepcopy_and_maybe_parallelize(model) +<<<<<<< HEAD if not hasattr(model, name): model.name = name +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.init_optimizer(name, current_device, model.parameters()) # The self.autocast context is needed for the model we export with aot_compile, @@ -2708,11 +2779,15 @@ def warmup(fn, model, example_inputs, mode, niters=5): niters=1, ) +<<<<<<< HEAD if ( self.args.export_aot_inductor or self.args.export_nativert or self.args.torchscript_jit_trace ): +======= + if self.args.export_aot_inductor: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) optimized_model_iter_fn = optimize_ctx else: optimized_model_iter_fn = optimize_ctx(self.model_iter_fn) @@ -2793,6 +2868,11 @@ def warmup(fn, model, example_inputs, mode, niters=5): f"{ok:3}/{total:3} +{frames_third_pass} frames {compilation_time:3.0f}s" ) +<<<<<<< HEAD +======= + if not hasattr(model, name): + model.name = name +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) results.append(experiment(model, example_inputs, **experiment_kwargs)) return " ".join(map(str, results)) @@ -3361,12 +3441,15 @@ def get_example_inputs(self): instead of deleting it and creating a new one.", ) +<<<<<<< HEAD parser.add_argument( "--caching-precompile", action="store_true", help="Enables caching precompile, serializing artifacts to DynamoCache between runs", ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) group_latency = parser.add_mutually_exclusive_group() group_latency.add_argument( "--cold-start-latency", @@ -3466,6 +3549,7 @@ def get_example_inputs(self): help="Measure pass rate with Export+AOTInductor", ) group.add_argument( +<<<<<<< HEAD "--export-nativert", action="store_true", help="Measure pass rate with Export+NativeRT", @@ -3476,6 +3560,8 @@ def get_example_inputs(self): help="Measure pass rate with TorchScript jit.trace", ) group.add_argument( +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "--xla", action="store_true", help="Compare TorchXLA to eager PyTorch" ) group.add_argument( @@ -3527,6 +3613,7 @@ def get_example_inputs(self): return parser.parse_args(args) +<<<<<<< HEAD def process_caching_precompile(): """ After every process_entry, save precompile artifacts to DynamoCache @@ -3550,6 +3637,8 @@ def process_caching_precompile(): PrecompileContext.populate_caches(results) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def process_entry(rank, runner, original_dir, args): args.rank = rank with maybe_init_distributed( @@ -3558,10 +3647,14 @@ def process_entry(rank, runner, original_dir, args): world_size=args.world_size, port=args.distributed_master_port, ): +<<<<<<< HEAD result = run(runner, args, original_dir) if args.caching_precompile: process_caching_precompile() return result +======= + return run(runner, args, original_dir) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def maybe_fresh_cache(args): @@ -3597,10 +3690,13 @@ def main(runner, original_dir=None, args=None): ) with maybe_fresh_cache(args): +<<<<<<< HEAD if args.caching_precompile: os.environ["TORCH_CACHING_PRECOMPILE"] = "1" torch._dynamo.config.caching_precompile = True +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) args.init_distributed = args.only and args.multiprocess if args.init_distributed: # NB: Do NOT query device count before CUDA initialization; we're @@ -3858,10 +3954,14 @@ def run(runner, args, original_dir=None): runner.skip_models.update(runner.slow_models) if args.devices == ["cpu"]: +<<<<<<< HEAD arch = platform.machine() runner.skip_models.update(runner.skip_models_for_cpu) if arch == "aarch64": runner.skip_models.update(runner.skip_models_for_cpu_aarch64) +======= + runner.skip_models.update(runner.skip_models_for_cpu) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif args.devices == ["cuda"]: runner.skip_models.update(runner.skip_models_for_cuda) @@ -3916,6 +4016,7 @@ def run(runner, args, original_dir=None): optimize_ctx = export experiment = speedup_experiment output_filename = "export.csv" +<<<<<<< HEAD elif args.export_nativert: optimize_ctx = export_nativert experiment = speedup_experiment @@ -3924,6 +4025,8 @@ def run(runner, args, original_dir=None): optimize_ctx = torchscript_jit_trace experiment = speedup_experiment output_filename = "torchscript_jit_trace.csv" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif args.xla: (dev,) = args.devices os.environ["PJRT_DEVICE"] = {"cuda": "GPU", "cpu": "CPU"}[dev] @@ -4238,7 +4341,11 @@ def detect_and_mark_batch(t): nonlocal marked for i, s in enumerate(t.size()): if s == batch_size: +<<<<<<< HEAD torch._dynamo.maybe_mark_dynamic(t, i) +======= + torch._dynamo.mark_dynamic(t, i) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) marked = True break diff --git a/benchmarks/dynamo/huggingface.py b/benchmarks/dynamo/huggingface.py index 76026731fe890..6d4451f56e6f5 100755 --- a/benchmarks/dynamo/huggingface.py +++ b/benchmarks/dynamo/huggingface.py @@ -106,11 +106,14 @@ def process_hf_reformer_output(out): # on A100 GPUs - 40 GB. BATCH_SIZE_KNOWN_MODELS = {} +<<<<<<< HEAD # Run only this selected group of models, leave this empty to run everything TORCHBENCH_ONLY_MODELS = [ m.strip() for m in os.getenv("TORCHBENCH_ONLY_MODELS", "").split(",") if m.strip() ] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO(sdym): use batch-size-file parameter of common.main, like torchbench.py # Get the list of models and their batch sizes @@ -121,8 +124,11 @@ def process_hf_reformer_output(out): lines = [line.rstrip() for line in lines] for line in lines: model_name, batch_size = line.split(",") +<<<<<<< HEAD if TORCHBENCH_ONLY_MODELS and model_name not in TORCHBENCH_ONLY_MODELS: continue +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) batch_size = int(batch_size) BATCH_SIZE_KNOWN_MODELS[model_name] = batch_size assert len(BATCH_SIZE_KNOWN_MODELS) @@ -370,7 +376,10 @@ def use_larger_multiplier_for_smaller_tensor(self, name): return name in [ "ElectraForQuestionAnswering", "MegatronBertForQuestionAnswering", +<<<<<<< HEAD "GPT2ForSequenceClassification", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] def _get_model_cls_and_config(self, model_name): @@ -460,12 +469,15 @@ def load_model( else: model.eval() +<<<<<<< HEAD # Turning off kv cache for torchbench models. This is not the right # thing to do, but the pt2 dashboard is outdated. Real transformers # benchmarks will be added soon using a different infra. if hasattr(model, "config") and hasattr(model.config, "use_cache"): model.config.use_cache = False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.validate_model(model, example_inputs) return device, model_name, model, example_inputs, batch_size diff --git a/benchmarks/dynamo/huggingface.yaml b/benchmarks/dynamo/huggingface.yaml index 5640776117096..6734203c54862 100644 --- a/benchmarks/dynamo/huggingface.yaml +++ b/benchmarks/dynamo/huggingface.yaml @@ -31,6 +31,11 @@ batch_size: BlenderbotSmallForCausalLM: 4 BlenderbotSmallForConditionalGeneration: 2 CamemBert: 2 +<<<<<<< HEAD +======= + DebertaForMaskedLM: 4 + DebertaForQuestionAnswering: 2 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DebertaV2ForMaskedLM: 4 DebertaV2ForQuestionAnswering: 8 DistilBertForMaskedLM: 2 @@ -61,6 +66,10 @@ batch_size: PegasusForConditionalGeneration: 2 RobertaForCausalLM: 2 RobertaForQuestionAnswering: 2 +<<<<<<< HEAD +======= + Speech2Text2ForCausalLM: 4 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T5ForConditionalGeneration: 2 T5Small: 2 TrOCRForCausalLM: 2 diff --git a/benchmarks/dynamo/huggingface_models_list.txt b/benchmarks/dynamo/huggingface_models_list.txt index 12ceedd5c4ccc..d2c52ac376787 100644 --- a/benchmarks/dynamo/huggingface_models_list.txt +++ b/benchmarks/dynamo/huggingface_models_list.txt @@ -10,6 +10,11 @@ BlenderbotForConditionalGeneration,16 BlenderbotSmallForCausalLM,256 BlenderbotSmallForConditionalGeneration,128 CamemBert,32 +<<<<<<< HEAD +======= +DebertaForMaskedLM,32 +DebertaForQuestionAnswering,32 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DebertaV2ForMaskedLM,8 DebertaV2ForQuestionAnswering,8 DistilBertForMaskedLM,256 @@ -40,6 +45,10 @@ PegasusForCausalLM,128 PegasusForConditionalGeneration,64 RobertaForCausalLM,32 RobertaForQuestionAnswering,32 +<<<<<<< HEAD +======= +Speech2Text2ForCausalLM,1024 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T5ForConditionalGeneration,8 T5Small,8 TrOCRForCausalLM,64 diff --git a/benchmarks/dynamo/huggingface_models_list_cpu.txt b/benchmarks/dynamo/huggingface_models_list_cpu.txt index 4078368a69c44..f5b6157a43979 100644 --- a/benchmarks/dynamo/huggingface_models_list_cpu.txt +++ b/benchmarks/dynamo/huggingface_models_list_cpu.txt @@ -10,6 +10,11 @@ BlenderbotForCausalLM,32 BlenderbotSmallForCausalLM,64 BlenderbotSmallForConditionalGeneration,64 CamemBert,16 +<<<<<<< HEAD +======= +DebertaForMaskedLM,32 +DebertaForQuestionAnswering,8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DebertaV2ForMaskedLM,16 DebertaV2ForQuestionAnswering,2 DistilBertForMaskedLM,128 @@ -36,6 +41,10 @@ PLBartForCausalLM,8 PLBartForConditionalGeneration,4 RobertaForCausalLM,16 RobertaForQuestionAnswering,16 +<<<<<<< HEAD +======= +Speech2Text2ForCausalLM,32 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T5ForConditionalGeneration,4 T5Small,1 TrOCRForCausalLM,32 diff --git a/benchmarks/dynamo/pr_time_benchmarks/check_results.py b/benchmarks/dynamo/pr_time_benchmarks/check_results.py index 734d3a01c1e82..a072e00d97c6d 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/check_results.py +++ b/benchmarks/dynamo/pr_time_benchmarks/check_results.py @@ -132,10 +132,17 @@ def log(event_name): ) new_entry = copy.deepcopy(entry) +<<<<<<< HEAD # only change if abs(ratio) > entry.noise_margin /5. new_entry.expected_value = ( replace_with_zeros(result) if abs(ratio) > entry.noise_margin * 100 / 5 +======= + # only change if abs(ratio) > entry.noise_margin /3. + new_entry.expected_value = ( + replace_with_zeros(result) + if abs(ratio) > entry.noise_margin * 100 / 3 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else entry.expected_value ) new_expected[key] = new_entry diff --git a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv index fc11be9ba6528..fc7940e1c62fb 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv +++ b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv @@ -1,3 +1,4 @@ +<<<<<<< HEAD add_loop_eager,compile_time_instruction_count,3070000000,0.1 @@ -87,3 +88,78 @@ basic_NestedModule_eager,compile_time_instruction_count,9554000000,0.1 basic_InlineMod_eager,compile_time_instruction_count,7464000000,0.1 +======= +add_loop_eager,compile_time_instruction_count,2937000000,0.015 + + + +add_loop_eager_dynamic,compile_time_instruction_count,4300194436,0.025 + + + +add_loop_inductor,compile_time_instruction_count,29630000000,0.015 + + + +add_loop_inductor_dynamic_gpu,compile_time_instruction_count,39110000000,0.025 + + + +add_loop_inductor_gpu,compile_time_instruction_count,26180000000,0.015 + + + +basic_modules_ListOfLinears_eager,compile_time_instruction_count,942514329,0.015 + + + +basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18660000000,0.015 + + + +basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16750000000,0.015 + + + +basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10370000000,0.2 + + + +update_hint_regression,compile_time_instruction_count,1677000000,0.02 + + + +sum_floordiv_regression,compile_time_instruction_count,984411080,0.015 + + + +symint_sum,compile_time_instruction_count,3252000000,0.015 + + + +symint_sum_loop,compile_time_instruction_count,4216000000,0.015 + + + +aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2113000000,0.015 + + + +aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6022000000,0.015 + + + +aotdispatcher_partitioner_cpu,compile_time_instruction_count,8844000000,0.015 + + + +aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1963000000,0.015 + + + +aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3875000000,0.015 + + + +aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10420000000,0.015 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/benchmarks/dynamo/runner.py b/benchmarks/dynamo/runner.py index 7f80d107ff9e7..8c0420b079ec6 100755 --- a/benchmarks/dynamo/runner.py +++ b/benchmarks/dynamo/runner.py @@ -32,7 +32,10 @@ import itertools import logging import os +<<<<<<< HEAD import platform +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import re import shutil import subprocess @@ -375,7 +378,10 @@ def get_skip_tests(suite, device, is_training: bool): original_dir = abspath(os.getcwd()) module = importlib.import_module(suite) os.chdir(original_dir) +<<<<<<< HEAD arch = platform.machine() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if suite == "torchbench": skip_tests.update(module.TorchBenchmarkRunner().skip_models) @@ -385,10 +391,13 @@ def get_skip_tests(suite, device, is_training: bool): ) if device == "cpu": skip_tests.update(module.TorchBenchmarkRunner().skip_models_for_cpu) +<<<<<<< HEAD if arch == "aarch64": skip_tests.update( module.TorchBenchmarkRunner().skip_models_for_cpu_aarch64 ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif device == "cuda": skip_tests.update(module.TorchBenchmarkRunner().skip_models_for_cuda) diff --git a/benchmarks/dynamo/timm_models.py b/benchmarks/dynamo/timm_models.py index b63c41947b9ad..ce65bbe3c0f42 100755 --- a/benchmarks/dynamo/timm_models.py +++ b/benchmarks/dynamo/timm_models.py @@ -39,6 +39,7 @@ def pip_install(package): from timm.models import create_model TIMM_MODELS = {} +<<<<<<< HEAD # Run only this selected group of models, leave this empty to run everything TORCHBENCH_ONLY_MODELS = [ @@ -46,13 +47,20 @@ def pip_install(package): ] filename = os.path.join(os.path.dirname(__file__), "timm_models_list.txt") +======= +filename = os.path.join(os.path.dirname(__file__), "timm_models_list.txt") + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with open(filename) as fh: lines = fh.readlines() lines = [line.rstrip() for line in lines] for line in lines: model_name, batch_size = line.split(" ") +<<<<<<< HEAD if TORCHBENCH_ONLY_MODELS and model_name not in TORCHBENCH_ONLY_MODELS: continue +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TIMM_MODELS[model_name] = int(batch_size) @@ -238,6 +246,7 @@ def _skip(self): return self._config["skip"] @property +<<<<<<< HEAD def skip_models_for_cpu(self): return self._skip["device"]["cpu"] @@ -246,6 +255,8 @@ def skip_models_for_cpu_aarch64(self): return self._skip["device"]["cpu_aarch64"] @property +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def skip_models(self): return self._skip["all"] diff --git a/benchmarks/dynamo/timm_models.yaml b/benchmarks/dynamo/timm_models.yaml index 6a6fdde849abc..132c243c4ae5d 100644 --- a/benchmarks/dynamo/timm_models.yaml +++ b/benchmarks/dynamo/timm_models.yaml @@ -2,6 +2,7 @@ skip: all: - ~ +<<<<<<< HEAD device: cpu: - ~ @@ -13,3 +14,5 @@ skip: - resnest101e - swsl_resnext101_32x16d - visformer_small +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/benchmarks/dynamo/torchbench.py b/benchmarks/dynamo/torchbench.py index 1f10ecc661d8e..39be7f36947ee 100755 --- a/benchmarks/dynamo/torchbench.py +++ b/benchmarks/dynamo/torchbench.py @@ -139,10 +139,13 @@ def skip_models_for_cpu(self): return self._skip["device"]["cpu"] @property +<<<<<<< HEAD def skip_models_for_cpu_aarch64(self): return self._skip["device"]["cpu_aarch64"] @property +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def skip_models_for_cuda(self): return self._skip["device"]["cuda"] @@ -382,6 +385,7 @@ def load_model( if self.args.trace_on_xla: # work around for: https://github.com/pytorch/xla/issues/4174 import torch_xla # noqa: F401 +<<<<<<< HEAD # Turning off kv cache for torchbench models. This is not the right # thing to do, but the torchbench models are way outdated, and since we @@ -398,6 +402,8 @@ def load_model( if model_name == "hf_T5_generate": model.model.config.use_cache = False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.validate_model(model, example_inputs) return device, benchmark.name, model, example_inputs, batch_size @@ -458,8 +464,11 @@ def get_tolerance_and_cosine_flag(self, is_training, current_device, name): if self.args.bfloat16: if name in self._tolerance["higher_bf16"]: return 1e-2, cosine +<<<<<<< HEAD elif current_device == "xpu" and name in self._tolerance["higher_bf16_xpu"]: return 8 * 1e-2, cosine +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if is_training and (current_device == "cuda" or current_device == "xpu"): tolerance = 1e-3 diff --git a/benchmarks/dynamo/torchbench.yaml b/benchmarks/dynamo/torchbench.yaml index 6a15cf33222b2..c4a613c95efb2 100644 --- a/benchmarks/dynamo/torchbench.yaml +++ b/benchmarks/dynamo/torchbench.yaml @@ -55,10 +55,13 @@ tolerance: - drq - hf_Whisper +<<<<<<< HEAD # These models need higher tolerance for xpu devices with bf16 higher_bf16_xpu: - squeezenet1_1 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) freezing: # Similar logic to timm_models.py:get_tolerance_and_cosine_flag # the conv-batchnorm fusion used under freezing may cause relatively @@ -213,6 +216,7 @@ skip: - llava - moco +<<<<<<< HEAD # Skip these additional models when running on aarch64 cpu_aarch64: # timeout on aarch64 @@ -222,6 +226,9 @@ skip: cuda: # Temporary until https://github.com/pytorch/pytorch/issues/162282 is fixed - sam_fast +======= + cuda: [] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) test: training: diff --git a/benchmarks/functional_autograd_benchmark/README.md b/benchmarks/functional_autograd_benchmark/README.md index 457f01265fbff..b4e313796e16f 100644 --- a/benchmarks/functional_autograd_benchmark/README.md +++ b/benchmarks/functional_autograd_benchmark/README.md @@ -17,8 +17,13 @@ export DEBUG=0 export OMP_NUM_THREADS=10 # Compile pytorch with the base revision +<<<<<<< HEAD git checkout main python -m pip install --no-build-isolation -v -e . +======= +git checkout master +python setup.py develop +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Install dependencies: # Scipy is required by detr @@ -32,7 +37,11 @@ python functional_autograd_benchmark.py --output before.txt # Compile pytorch with your change popd git checkout your_feature_branch +<<<<<<< HEAD python -m pip install --no-build-isolation -v -e . +======= +python setup.py develop +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Run the benchmark for the new version pushd benchmarks/functional_autograd_benchmark diff --git a/benchmarks/instruction_counts/worker/main.py b/benchmarks/instruction_counts/worker/main.py index 33021ec650049..23a526534f10d 100644 --- a/benchmarks/instruction_counts/worker/main.py +++ b/benchmarks/instruction_counts/worker/main.py @@ -170,7 +170,11 @@ def main(communication_file: str) -> None: # Runner process sent SIGINT. sys.exit() +<<<<<<< HEAD except BaseException: # noqa: B036 +======= + except BaseException: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) trace_f = io.StringIO() traceback.print_exc(file=trace_f) result = WorkerFailure(failure_trace=trace_f.getvalue()) diff --git a/benchmarks/operator_benchmark/README.md b/benchmarks/operator_benchmark/README.md index 0a8ad5846bf24..d17212ffedb4a 100644 --- a/benchmarks/operator_benchmark/README.md +++ b/benchmarks/operator_benchmark/README.md @@ -20,7 +20,11 @@ Key Features: The instruction below installs a cpp\_extension for PyTorch and it is required to run the benchmark suite. ```bash cd pt_extension +<<<<<<< HEAD python -m pip install . +======= +python setup.py install +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ``` ## How to run the benchmarks: diff --git a/benchmarks/operator_benchmark/benchmark_core.py b/benchmarks/operator_benchmark/benchmark_core.py index 3f79ed2318c4f..e032478c43ed3 100644 --- a/benchmarks/operator_benchmark/benchmark_core.py +++ b/benchmarks/operator_benchmark/benchmark_core.py @@ -4,7 +4,10 @@ import functools import json import os +<<<<<<< HEAD import platform +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import timeit from collections import namedtuple from dataclasses import asdict, dataclass @@ -18,7 +21,10 @@ # needs to be imported after torch import torch.utils.cpp_extension as cpp_extension # noqa: F401 +<<<<<<< HEAD from torch.utils.benchmark import Timer +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Performance microbenchmarks. @@ -193,11 +199,14 @@ def __init__(self, args): self.predefined_minimum_secs = 1 self.max_iters = 1e6 self.use_jit = args.use_jit +<<<<<<< HEAD self.use_compile = args.use_compile if self.use_jit and self.use_compile: raise ValueError( "use_jit and use_compile are mutually exclusive, please specify one." ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.num_runs = args.num_runs self.print_per_iter = False self.output_csv = args.output_csv @@ -229,7 +238,11 @@ def _print_header(self): if self.args.operators: print(f"# {self.args.operators}") +<<<<<<< HEAD def _print_perf_result(self, results, test_case): +======= + def _print_perf_result(self, reported_run_time_us, test_case): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.args.report_aibench: # Output for AIBench # Print out per iteration execution time instead of avg time @@ -243,14 +256,22 @@ def _print_perf_result(self, results, test_case): "type": test_name, "metric": "latency", "unit": "us", +<<<<<<< HEAD "value": str(results["reported_run_time_us"[run]]), +======= + "value": str(reported_run_time_us[run]), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } ) ) else: +<<<<<<< HEAD print( f"# Mode: {'JIT' if self.use_jit else 'Compile' if self.use_compile else 'Eager'}" ) +======= + print(f"# Mode: {'JIT' if self.use_jit else 'Eager'}") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) print( f"# Name: {test_case.test_config.test_name}\n# Input: {test_case.test_config.input_config}" ) @@ -259,6 +280,7 @@ def _print_perf_result(self, results, test_case): if self.num_runs > 1: for run in range(self.num_runs): print( +<<<<<<< HEAD f"Run: {run}, {mode} Execution Time (us) : {results['reported_run_time_us'][run]:.3f}" ) print() @@ -269,11 +291,21 @@ def _print_perf_result(self, results, test_case): print(f"Peak Memory (KB) : {results['peak_memory']}\n") def _perf_result_to_dict(self, results, test_case): +======= + f"Run: {run}, {mode} Execution Time (us) : {reported_run_time_us[run]:.3f}" + ) + print() + else: + print(f"{mode} Execution Time (us) : {reported_run_time_us[0]:.3f}\n") + + def _perf_result_to_dict(self, reported_run_time_us, test_case): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """This function is the parallel of _print_perf_result, which instead of writing information to terminal, returns a dictionary. """ if self.args.report_aibench: return {} +<<<<<<< HEAD out = { "test_name": test_case.test_config.test_name, @@ -286,6 +318,15 @@ def _perf_result_to_dict(self, results, test_case): "latency unit": "us", "peak memory": results["peak_memory"], "memory unit": "KB", +======= + out = { + "test_name": test_case.test_config.test_name, + "input_config": test_case.test_config.input_config, + "mode": "JIT" if self.use_jit else "Eager", + "run": "Backward" if test_case.test_config.run_backward else "Forward", + "latency": round(reported_run_time_us[0], 3), + "latency unit": "us", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } # parsing test_case.test_config.input_config, adding it as entries to the 'out' dictionary @@ -347,6 +388,7 @@ def _launch_forward(self, test_case, iters, print_per_iter): func = test_case.run_forward if self.use_jit: func = test_case.run_jit_forward +<<<<<<< HEAD if self.use_compile: func = test_case.run_compile_forward @@ -367,6 +409,12 @@ def _launch_forward(self, test_case, iters, print_per_iter): ) result = timer.adaptive_autorange(min_run_time=0.0001) return result.median * iters +======= + forward_time = timeit.timeit( + functools.partial(func, iters, print_per_iter, cuda_sync), number=1 + ) + return forward_time +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _launch_backward(self, test_case, iters, print_per_iter=False): """This function runs forward path of an op to get an output. Then the backward path is executed @@ -379,7 +427,11 @@ def _launch_backward(self, test_case, iters, print_per_iter=False): ) return backward_time +<<<<<<< HEAD def _measure_metrics(self, launch_test, test_case, iters, print_per_iter): +======= + def _measure_time(self, launch_test, test_case, iters, print_per_iter): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ This function execute the operator for iterations then look at the time. If it's not significant, the number of iterations will be increased before rerun. @@ -387,6 +439,7 @@ def _measure_metrics(self, launch_test, test_case, iters, print_per_iter): """ curr_test_total_time = 0 time_trace = [] +<<<<<<< HEAD peak_memory = 0 input_values = test_case.op_bench.inputs.values() device, device_module = None, None @@ -406,6 +459,10 @@ def _measure_metrics(self, launch_test, test_case, iters, print_per_iter): # Memory measurement process if hasattr(device_module, "max_memory_allocated"): peak_memory = device_module.max_memory_allocated(device) +======= + while True: + run_time_sec = launch_test(test_case, iters, print_per_iter) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) curr_test_total_time += run_time_sec # Analyze time after each run to decide if the result is stable results_are_significant = self._iteration_result_is_significant( @@ -419,6 +476,7 @@ def _measure_metrics(self, launch_test, test_case, iters, print_per_iter): time_trace.append(report_run_time) # Print out the time spent in each epoch in ms if self.args.report_aibench: +<<<<<<< HEAD mode = ( "JIT" if self.use_jit @@ -426,6 +484,9 @@ def _measure_metrics(self, launch_test, test_case, iters, print_per_iter): if self.use_compile else "Eager" ) +======= + mode = "JIT" if self.use_jit else "Eager" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) test_name = "_".join( [test_case.framework, test_case.test_config.test_name, mode] ) @@ -437,7 +498,11 @@ def _measure_metrics(self, launch_test, test_case, iters, print_per_iter): "metric": "latency", "unit": "ms", "value": str(report_run_time / 1e3), +<<<<<<< HEAD }, +======= + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) ) if results_are_significant: @@ -447,7 +512,11 @@ def _measure_metrics(self, launch_test, test_case, iters, print_per_iter): # iteration count, and run the benchmark again... iters = self._predict_num_iter_needed(iters) reported_run_time_us = np.percentile(np.array(time_trace), 50) +<<<<<<< HEAD return reported_run_time_us, peak_memory / 1024 +======= + return reported_run_time_us +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _check_keep(self, test_flag, cmd_flag): return cmd_flag is None or test_flag == cmd_flag @@ -534,7 +603,10 @@ def _output_json( self, perf_list, output_file, +<<<<<<< HEAD benchmark_name="PyTorch operator benchmark", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): """ Write the result into JSON format, so that it can be uploaded to the benchmark database @@ -552,10 +624,15 @@ def _output_json( input_config = perf_item.get("input_config", "") run_type = perf_item.get("run") latency = perf_item.get("latency", 0) +<<<<<<< HEAD peak_memory = perf_item.get("peak memory", 0) device = perf_item.get("device", "unknown") dtype = perf_item.get("dtype", "torch.float").split(".")[1] runtime = perf_item.get("runtime", None) +======= + + dtype = "float32" # default +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Extract mode based on run_type mode = None @@ -564,6 +641,7 @@ def _output_json( elif run_type == "Backward": mode = "training" +<<<<<<< HEAD # Extract use_compile from it if runtime == "Compile": use_compile = True @@ -580,6 +658,8 @@ def _output_json( else "unknown" ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Create the record @dataclass class BenchmarkInfo: @@ -607,6 +687,7 @@ class BenchmarkRecord: model: ModelInfo metric: MetricInfo +<<<<<<< HEAD # Add record for latency record_latency = BenchmarkRecord( benchmark=BenchmarkInfo( @@ -619,6 +700,14 @@ class BenchmarkRecord: "arch": device_arch, "use_compile": use_compile, }, +======= + record = BenchmarkRecord( + benchmark=BenchmarkInfo( + name="PyTorch operator benchmark", + mode=mode, + dtype=dtype, + extra_info={"input_config": input_config}, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), model=ModelInfo( name=test_name, type="micro-benchmark", origins=["pytorch"] @@ -630,6 +719,7 @@ class BenchmarkRecord: target_value=None, ), ) +<<<<<<< HEAD records.append(asdict(record_latency)) # Add record for peak memory @@ -641,6 +731,10 @@ class BenchmarkRecord: target_value=None, ) records.append(asdict(record_memory)) +======= + + records.append(asdict(record)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Write all records to the output file with open(output_file, "w", encoding="utf-8") as f: @@ -656,7 +750,10 @@ def run(self): "tag", "run_backward", "Execution Time", +<<<<<<< HEAD "Peak Memory (KB)", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] if self.args.output_json or self.args.output_json_for_dashboard: @@ -694,16 +791,25 @@ def run(self): test_case, self.args.warmup_iterations, print_per_iter=False ) # Actual Execution +<<<<<<< HEAD results = [ self._measure_metrics( +======= + reported_time = [ + self._measure_time( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) launch_func, test_case, self.iters, self.print_per_iter ) for _ in range(self.num_runs) ] +<<<<<<< HEAD result_dict = dict() result_dict["reported_run_time_us"] = [r[0] for r in results] result_dict["peak_memory"] = results[0][1] self._print_perf_result(results=result_dict, test_case=test_case) +======= + self._print_perf_result(reported_time, test_case) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # output results to csv self._output_csv( @@ -719,6 +825,7 @@ def run(self): ), test_case.test_config.tag, test_case.test_config.run_backward, +<<<<<<< HEAD result_dict["reported_run_time_us"][0], result_dict["peak_memory"], ], @@ -730,6 +837,18 @@ def run(self): self._output_json( perf_list, self.args.output_json_for_dashboard, self.args.benchmark_name ) +======= + reported_time[0], + ], + ) + if self.args.output_json or self.args.output_json_for_dashboard: + perf_list.append( + self._perf_result_to_dict(reported_time, test_case) + ) + + if self.args.output_json_for_dashboard: + self._output_json(perf_list, self.args.output_json_for_dashboard) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.args.output_json: with open(self.args.output_json, "w") as f: diff --git a/benchmarks/operator_benchmark/benchmark_pytorch.py b/benchmarks/operator_benchmark/benchmark_pytorch.py index cfed9ebac04b1..55c6ca4f62e79 100644 --- a/benchmarks/operator_benchmark/benchmark_pytorch.py +++ b/benchmarks/operator_benchmark/benchmark_pytorch.py @@ -4,6 +4,7 @@ import torch +<<<<<<< HEAD # Import the C++ extension to register the _consume operator try: import benchmark_cpp_extension # noqa: F401 @@ -13,6 +14,8 @@ "Failed to import C++ extension, please build it using \ncd pt_extension \npython -m pip install ." ) from err +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """PyTorch performance microbenchmarks. This module contains PyTorch-specific functionalities for performance @@ -80,6 +83,7 @@ def forward_consume(self, iters: int): for _ in range(iters): torch.ops.operator_benchmark._consume(self.forward_impl()) +<<<<<<< HEAD def forward_impl_eager(self): # This is to supply the inputs to the forward function which # will be called in both the eager and compile mode of local runs @@ -90,6 +94,8 @@ def forward_consume_eager(self, iters: int): for _ in range(iters): torch.ops.operator_benchmark._consume(self.forward_impl_eager()) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def module_name(self): """this is used to label the operator being benchmarked""" if self.user_given_name: @@ -136,13 +142,17 @@ def __init__(self, op_bench, test_config): self.framework = "PyTorch" self.time_series = [] self._jit_forward_graph = None +<<<<<<< HEAD self._compile_forward_graph = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _generate_jit_forward_graph(self): """generate a graph for the forward function via scripting""" scripted_op_bench = torch.jit.script(self.op_bench) return scripted_op_bench.forward_consume +<<<<<<< HEAD def _generate_compile_forward_graph(self): """generate a compiled graph for the forward function via torch.compile""" compiled_forward_consume = torch.compile( @@ -150,12 +160,15 @@ def _generate_compile_forward_graph(self): ) return compiled_forward_consume +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def run_jit_forward(self, num_runs, print_per_iter=False, cuda_sync=False): """Run the forward path of an op with JIT mode""" if self._jit_forward_graph is None: self._jit_forward_graph = self._generate_jit_forward_graph() self._jit_forward_graph(num_runs) +<<<<<<< HEAD def run_compile_forward(self, num_runs, print_per_iter=False, cuda_sync=False): """Run the forward path of an op with compile mode""" if self._compile_forward_graph is None: @@ -164,6 +177,8 @@ def run_compile_forward(self, num_runs, print_per_iter=False, cuda_sync=False): if cuda_sync: torch.cuda.synchronize(torch.cuda.current_device()) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _print_per_iter(self): # print last 50 values length = min(len(self.time_series), 50) @@ -185,14 +200,22 @@ def run_forward(self, num_runs, print_per_iter, cuda_sync): if print_per_iter: for _ in range(num_runs): start_time = time.time() +<<<<<<< HEAD self.output = self.op_bench.forward_impl_eager() +======= + self.output = self.op_bench.forward_impl() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if cuda_sync: torch.cuda.synchronize(torch.cuda.current_device()) end_time = time.time() self.time_series.append((end_time - start_time) * 1e3) else: for _ in range(num_runs): +<<<<<<< HEAD self.output = self.op_bench.forward_impl_eager() +======= + self.output = self.op_bench.forward_impl() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if cuda_sync: torch.cuda.synchronize(torch.cuda.current_device()) diff --git a/benchmarks/operator_benchmark/benchmark_runner.py b/benchmarks/operator_benchmark/benchmark_runner.py index 6568cf9bf3ee6..3badfd05ac53f 100644 --- a/benchmarks/operator_benchmark/benchmark_runner.py +++ b/benchmarks/operator_benchmark/benchmark_runner.py @@ -63,6 +63,7 @@ def parse_args(): ) parser.add_argument( +<<<<<<< HEAD "--benchmark-name", "--benchmark_name", help="Name of the benchmark to store results to", @@ -70,6 +71,8 @@ def parse_args(): ) parser.add_argument( +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "--list-tests", "--list_tests", help="List all test cases without running them", @@ -143,6 +146,7 @@ def parse_args(): ) parser.add_argument( +<<<<<<< HEAD "--use-compile", "--use_compile", type=benchmark_utils.str2bool, @@ -153,6 +157,8 @@ def parse_args(): ) parser.add_argument( +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "--forward-only", "--forward_only", type=benchmark_utils.str2bool, @@ -179,7 +185,11 @@ def parse_args(): "--output-json-for-dashboard", "--output_json_for_dashboard", help="Save results in JSON format for display on the OSS dashboard", +<<<<<<< HEAD default="benchmark-results.json", +======= + default="False", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) args, _ = parser.parse_known_args() diff --git a/benchmarks/operator_benchmark/expected_ci_operator_benchmark_eager_float32_cpu.csv b/benchmarks/operator_benchmark/expected_ci_operator_benchmark_eager_float32_cpu.csv index 9a7b6797e982a..cd29834d56c1a 100644 --- a/benchmarks/operator_benchmark/expected_ci_operator_benchmark_eager_float32_cpu.csv +++ b/benchmarks/operator_benchmark/expected_ci_operator_benchmark_eager_float32_cpu.csv @@ -1,5 +1,9 @@ Benchmarking Framework,Benchmarking Module Name,Case Name,tag,run_backward,Execution Time +<<<<<<< HEAD PyTorch,add,add_M1_N1_K1_cpu,short,FALSE,2.459 +======= +PyTorch,add,add_M1_N1_K1_cpu,short,FALSE,3.9497 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) PyTorch,add,add_M64_N64_K64_cpu,short,FALSE,14.3181 PyTorch,add,add_M64_N64_K128_cpu,short,FALSE,14.6826 PyTorch,add,add_M1_N1_K1_cpu_bwdall_BACKWARD,short,TRUE,58.1449 @@ -376,10 +380,17 @@ PyTorch,relu6,"relu6_dims(3,4,5)_contigFalse_inplaceFalse_dtypetorch.qint32",sho PyTorch,relu6,"relu6_dims(2,3,4,5)_contigFalse_inplaceFalse_dtypetorch.quint8",short,FALSE,9.6588 PyTorch,relu6,"relu6_dims(2,3,4,5)_contigFalse_inplaceFalse_dtypetorch.qint8",short,FALSE,9.5969 PyTorch,relu6,"relu6_dims(2,3,4,5)_contigFalse_inplaceFalse_dtypetorch.qint32",short,FALSE,9.547 +<<<<<<< HEAD PyTorch,relu6,"relu6_dims(512,512)_contigFalse_inplaceFalse_dtypetorch.quint8",short,FALSE,50.21375 PyTorch,relu6,"relu6_dims(512,512)_contigFalse_inplaceFalse_dtypetorch.qint8",short,FALSE,45.14133333 PyTorch,relu6,"relu6_dims(512,512)_contigFalse_inplaceFalse_dtypetorch.qint32",short,FALSE,52.6664 PyTorch,relu6,"relu6_dims(256,1024)_contigFalse_inplaceFalse_dtypetorch.quint8",short,FALSE,51.49525 +======= +PyTorch,relu6,"relu6_dims(512,512)_contigFalse_inplaceFalse_dtypetorch.quint8",short,FALSE,68.739 +PyTorch,relu6,"relu6_dims(512,512)_contigFalse_inplaceFalse_dtypetorch.qint8",short,FALSE,45.14133333 +PyTorch,relu6,"relu6_dims(512,512)_contigFalse_inplaceFalse_dtypetorch.qint32",short,FALSE,52.6664 +PyTorch,relu6,"relu6_dims(256,1024)_contigFalse_inplaceFalse_dtypetorch.quint8",short,FALSE,69.1875 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) PyTorch,relu6,"relu6_dims(256,1024)_contigFalse_inplaceFalse_dtypetorch.qint8",short,FALSE,48.3458 PyTorch,relu6,"relu6_dims(256,1024)_contigFalse_inplaceFalse_dtypetorch.qint32",short,FALSE,62.0719 PyTorch,functional.hardtanh,"functional.hardtanh_dims(3,4,5)_contigFalse_inplaceFalse_dtypetorch.quint8",short,FALSE,7.5728 @@ -388,10 +399,17 @@ PyTorch,functional.hardtanh,"functional.hardtanh_dims(3,4,5)_contigFalse_inplace PyTorch,functional.hardtanh,"functional.hardtanh_dims(2,3,4,5)_contigFalse_inplaceFalse_dtypetorch.quint8",short,FALSE,8.1647 PyTorch,functional.hardtanh,"functional.hardtanh_dims(2,3,4,5)_contigFalse_inplaceFalse_dtypetorch.qint8",short,FALSE,8.1768 PyTorch,functional.hardtanh,"functional.hardtanh_dims(2,3,4,5)_contigFalse_inplaceFalse_dtypetorch.qint32",short,FALSE,8.0619 +<<<<<<< HEAD PyTorch,functional.hardtanh,"functional.hardtanh_dims(512,512)_contigFalse_inplaceFalse_dtypetorch.quint8",short,FALSE,48.88475 PyTorch,functional.hardtanh,"functional.hardtanh_dims(512,512)_contigFalse_inplaceFalse_dtypetorch.qint8",short,FALSE,43.702 PyTorch,functional.hardtanh,"functional.hardtanh_dims(512,512)_contigFalse_inplaceFalse_dtypetorch.qint32",short,FALSE,50.3613 PyTorch,functional.hardtanh,"functional.hardtanh_dims(256,1024)_contigFalse_inplaceFalse_dtypetorch.quint8",short,FALSE,50.3995 +======= +PyTorch,functional.hardtanh,"functional.hardtanh_dims(512,512)_contigFalse_inplaceFalse_dtypetorch.quint8",short,FALSE,67.118 +PyTorch,functional.hardtanh,"functional.hardtanh_dims(512,512)_contigFalse_inplaceFalse_dtypetorch.qint8",short,FALSE,43.702 +PyTorch,functional.hardtanh,"functional.hardtanh_dims(512,512)_contigFalse_inplaceFalse_dtypetorch.qint32",short,FALSE,50.3613 +PyTorch,functional.hardtanh,"functional.hardtanh_dims(256,1024)_contigFalse_inplaceFalse_dtypetorch.quint8",short,FALSE,67.436 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) PyTorch,functional.hardtanh,"functional.hardtanh_dims(256,1024)_contigFalse_inplaceFalse_dtypetorch.qint8",short,FALSE,46.9813 PyTorch,functional.hardtanh,"functional.hardtanh_dims(256,1024)_contigFalse_inplaceFalse_dtypetorch.qint32",short,FALSE,59.2295 PyTorch,functional.hardsigmoid,"functional.hardsigmoid_dims(3,4,5)_contigFalse_inplaceFalse_dtypetorch.quint8",short,FALSE,6.5189 @@ -1316,4 +1334,8 @@ PyTorch,where,"where_cond_shape(8,16,1)_input_shape(1,)_other_shape(1,)_cpu_dtyp PyTorch,where,"where_cond_shape(8,16,1)_input_shape(16,1)_other_shape(8,16,1)_cpu_dtypetorch.float32",short,FALSE,5.763 PyTorch,where,"where_cond_shape(8,16,1)_input_shape(8,1,1)_other_shape(1,)_cpu_dtypetorch.float32",short,FALSE,5.744666667 PyTorch,clamp,clamp_M512_N512_cpu,short,FALSE,15.26233333 +<<<<<<< HEAD +PyTorch,gelu,gelu_M512_N512_cpu,short,FALSE,31.33166667 +======= PyTorch,gelu,gelu_M512_N512_cpu,short,FALSE,31.33166667 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/benchmarks/operator_benchmark/pt/add_test.py b/benchmarks/operator_benchmark/pt/add_test.py index 739b8ef14a54b..bec80e0d8cd13 100644 --- a/benchmarks/operator_benchmark/pt/add_test.py +++ b/benchmarks/operator_benchmark/pt/add_test.py @@ -52,6 +52,30 @@ def forward(self, input_one, input_two): op_bench.generate_pt_test(add_long_configs + add_short_configs, AddBenchmark) op_bench.generate_pt_gradient_test(add_long_configs + add_short_configs, AddBenchmark) +<<<<<<< HEAD +======= + +"""Mircobenchmark for addmm operator.""" + + +class AddmmBenchmark(op_bench.TorchBenchmarkBase): + def init(self, M, N, K, device): + self.inputs = { + "input_one": torch.rand(M, K, device=device, requires_grad=self.auto_set()), + "mat1": torch.rand(M, N, device=device, requires_grad=self.auto_set()), + "mat2": torch.rand(N, K, device=device, requires_grad=self.auto_set()), + } + self.set_module_name("addmm") + + def forward(self, input_one, mat1, mat2): + return torch.addmm(input_one, mat1, mat2) + + +op_bench.generate_pt_test(add_long_configs + add_short_configs, AddmmBenchmark) +op_bench.generate_pt_gradient_test(add_long_configs + add_short_configs, AddmmBenchmark) + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Mircobenchmark for addr operator.""" @@ -85,5 +109,49 @@ def forward(self, input_one, vec1, vec2): op_bench.generate_pt_test(addr_configs, AddrBenchmark) op_bench.generate_pt_gradient_test(addr_configs, AddrBenchmark) +<<<<<<< HEAD +======= + +"""Mircobenchmark for addbmm operator.""" + + +class AddbmmBenchmark(op_bench.TorchBenchmarkBase): + def init(self, B, M, N, K, device): + self.inputs = { + "input_one": torch.rand( + (M, N), device=device, requires_grad=self.auto_set() + ), + "batch1": torch.rand( + (B, M, K), device=device, requires_grad=self.auto_set() + ), + "batch2": torch.rand( + ( + B, + K, + N, + ), + device=device, + requires_grad=self.auto_set(), + ), + } + self.set_module_name("addbmm") + + def forward(self, input_one, batch1, batch2): + return torch.addbmm(input_one, batch1, batch2) + + +addbmm_configs = op_bench.cross_product_configs( + B=[2, 100], + M=[8, 256], + N=[256, 16], + K=[15, 16], + device=["cpu", "cuda"], + tags=["addbmm"], +) + +op_bench.generate_pt_test(addbmm_configs, AddbmmBenchmark) +op_bench.generate_pt_gradient_test(addbmm_configs, AddbmmBenchmark) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/pt/bmm_test.py b/benchmarks/operator_benchmark/pt/bmm_test.py index f867f6ac09f8d..3fde51257e4ff 100644 --- a/benchmarks/operator_benchmark/pt/bmm_test.py +++ b/benchmarks/operator_benchmark/pt/bmm_test.py @@ -27,12 +27,21 @@ ) batched_binary_configs_long = op_bench.cross_product_configs( +<<<<<<< HEAD B=[8, 32], M=[256, 1024], N=[256, 1024], K=[64, 128], device=["cuda"], dtype=[torch.float32, torch.bfloat16, torch.float16], +======= + B=[1, 128], + M=[8, 128], + N=[32, 64], + K=[4, 256], + device=["cpu", "cuda"], + dtype=[torch.float, torch.bfloat16], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tags=["long"], ) @@ -40,12 +49,17 @@ class BatchedBinaryOpBenchmark(op_bench.TorchBenchmarkBase): def init(self, B, M, N, K, device, dtype, op_func): self.inputs = { +<<<<<<< HEAD "batch1": torch.rand( (B, M, N), device=device, dtype=dtype, requires_grad=self.auto_set() ), "batch2": torch.rand( (B, N, K), device=device, dtype=dtype, requires_grad=self.auto_set() ), +======= + "batch1": torch.rand((B, M, N), device=device).to(dtype=dtype), + "batch2": torch.rand((B, N, K), device=device).to(dtype=dtype), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } self.op_func = op_func @@ -58,11 +72,14 @@ def forward(self, batch1, batch2): batched_binary_configs_short + batched_binary_configs_long, BatchedBinaryOpBenchmark, ) +<<<<<<< HEAD op_bench.generate_pt_gradient_tests_from_op_list( batched_binary_ops, batched_binary_configs_long, BatchedBinaryOpBenchmark, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # batched ternary ops @@ -75,6 +92,7 @@ def forward(self, batch1, batch2): class BatchedTernaryOpBenchmark(op_bench.TorchBenchmarkBase): def init(self, B, M, N, K, device, dtype, op_func): self.inputs = { +<<<<<<< HEAD "input_": torch.rand( (B, M, K), device=device, dtype=dtype, requires_grad=self.auto_set() ), @@ -84,6 +102,11 @@ def init(self, B, M, N, K, device, dtype, op_func): "batch2": torch.rand( (B, N, K), device=device, dtype=dtype, requires_grad=self.auto_set() ), +======= + "input_": torch.rand((B, M, K), device=device).to(dtype=dtype), + "batch1": torch.rand((B, M, N), device=device).to(dtype=dtype), + "batch2": torch.rand((B, N, K), device=device).to(dtype=dtype), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } self.op_func = op_func @@ -96,12 +119,15 @@ def forward(self, input_, batch1, batch2): batched_binary_configs_short + batched_binary_configs_long, BatchedTernaryOpBenchmark, ) +<<<<<<< HEAD op_bench.generate_pt_gradient_tests_from_op_list( batched_ternary_ops, batched_binary_configs_long, BatchedTernaryOpBenchmark, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO: does it automatically register new scripts? diff --git a/benchmarks/operator_benchmark/pt/matmul_test.py b/benchmarks/operator_benchmark/pt/matmul_test.py index d0c58aa16e8f3..94dcc3542ec20 100644 --- a/benchmarks/operator_benchmark/pt/matmul_test.py +++ b/benchmarks/operator_benchmark/pt/matmul_test.py @@ -13,12 +13,19 @@ [128, 128, 128, True, False], [256, 256, 256, False, True], ], +<<<<<<< HEAD cross_product_configs={"device": ["cpu", "cuda"]}, +======= + cross_product_configs={ + "device": ["cpu", "cuda"], + }, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tags=["short"], ) mm_long_configs = op_bench.cross_product_configs( +<<<<<<< HEAD M=[256, 1024, 3000], N=[512, 4096], K=[512, 4096], @@ -26,11 +33,20 @@ trans_b=[True, False], device=["cuda"], dtype=[torch.float16, torch.bfloat16, torch.float32], +======= + M=[32], + N=[512, 128], + K=[64], + trans_a=[False, True], + trans_b=[True, False], + device=["cpu", "cuda"], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tags=["long"], ) class MatMulBenchmark(op_bench.TorchBenchmarkBase): +<<<<<<< HEAD def init(self, M, N, K, trans_a, trans_b, device, dtype=torch.float): # Create tensors without requires_grad first, then set it separately # This avoids creating graph leaves that cannot be deep copied @@ -53,6 +69,16 @@ def init(self, M, N, K, trans_a, trans_b, device, dtype=torch.float): self.inputs = { "input_one": input_one, "input_two": input_two, +======= + def init(self, M, N, K, trans_a, trans_b, device): + self.inputs = { + "input_one": torch.rand(M, N, device=device) + if trans_a + else torch.rand(N, M, device=device).t(), + "input_two": torch.rand(N, K, device=device) + if trans_b + else torch.rand(K, N, device=device).t(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } self.set_module_name("matmul") @@ -61,7 +87,10 @@ def forward(self, input_one, input_two): op_bench.generate_pt_test(mm_long_configs + mm_short_configs, MatMulBenchmark) +<<<<<<< HEAD op_bench.generate_pt_gradient_test(mm_long_configs, MatMulBenchmark) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": diff --git a/benchmarks/operator_benchmark/pt/mm_test.py b/benchmarks/operator_benchmark/pt/mm_test.py index f9e0743ba7125..813456a6ef076 100644 --- a/benchmarks/operator_benchmark/pt/mm_test.py +++ b/benchmarks/operator_benchmark/pt/mm_test.py @@ -23,11 +23,19 @@ ) mm_long_configs = op_bench.cross_product_configs( +<<<<<<< HEAD M=[256, 1024, 3000], N=[512, 4096], K=[512, 4096], device=["cuda"], dtype=[torch.float16, torch.bfloat16, torch.float32], +======= + M=[8, 128], + N=[32, 64], + K=[256, 512], + device=["cpu", "cuda"], + dtype=[torch.float, torch.bfloat16], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tags=["long"], ) @@ -35,12 +43,17 @@ class MmOpBenchmark(op_bench.TorchBenchmarkBase): def init(self, M, N, K, device, dtype, op_func): self.inputs = { +<<<<<<< HEAD "input_one": torch.randn( M, N, device=device, requires_grad=self.auto_set(), dtype=dtype ), "input_two": torch.randn( N, K, device=device, requires_grad=self.auto_set(), dtype=dtype ), +======= + "input_one": torch.randn(M, N, device=device).to(dtype=dtype), + "input_two": torch.randn(N, K, device=device).to(dtype=dtype), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } self.op_func = op_func @@ -51,9 +64,12 @@ def forward(self, input_one, input_two): op_bench.generate_pt_tests_from_op_list( ops_list, mm_short_configs + mm_long_configs, MmOpBenchmark ) +<<<<<<< HEAD op_bench.generate_pt_gradient_tests_from_op_list( ops_list, mm_long_configs, MmOpBenchmark ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": diff --git a/benchmarks/sparse/test_csr.sh b/benchmarks/sparse/test_csr.sh index f0b460b8a882b..e7833f0e6589b 100644 --- a/benchmarks/sparse/test_csr.sh +++ b/benchmarks/sparse/test_csr.sh @@ -11,7 +11,11 @@ export USE_MKL=1 CMAKE_ONLY=1 python setup.py build ccmake build # or cmake-gui build +<<<<<<< HEAD python -m pip install --no-build-isolation -v . +======= +python setup.py install +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cd benchmarks echo "!! SPARSE SPMM TIME BENCHMARK!! " >> $OUTFILE @@ -28,7 +32,11 @@ echo "----- USE_MKL=0 ------" >> $OUTFILE rm -rf build export USE_MKL=0 +<<<<<<< HEAD python -m pip install --no-build-isolation -v . +======= +python setup.py install +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cd benchmarks for dim0 in 1000 5000 10000; do diff --git a/buckbuild.bzl b/buckbuild.bzl index 193c16fbd4e5f..e375056ef0077 100644 --- a/buckbuild.bzl +++ b/buckbuild.bzl @@ -11,7 +11,11 @@ load("//tools/build_defs:glob_defs.bzl", "subdir_glob") load("//tools/build_defs:platform_defs.bzl", "APPLETVOS", "IOS", "MACOSX") load("//tools/build_defs:type_defs.bzl", "is_list", "is_string") load("//tools/build_defs/android:build_mode_defs.bzl", is_production_build_android = "is_production_build") +<<<<<<< HEAD load("//tools/build_defs/apple:build_mode_defs.bzl", is_production_build_ios = "is_production_build", is_profile_build_ios = "is_profile_build") +======= +load("//tools/build_defs/apple:build_mode_defs.bzl", is_production_build_ios = "is_production_build") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) load( ":build_variables.bzl", "aten_cpu_source_list", @@ -74,7 +78,11 @@ def _is_build_mode_dev(): if is_production_build_android(): # Android Prod builds return False +<<<<<<< HEAD if is_production_build_ios() or is_profile_build_ios(): +======= + if is_production_build_ios(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # iOS Prod builds return False @@ -391,8 +399,11 @@ def get_aten_generated_files(enabled_backends): "CompositeExplicitAutogradFunctions_inl.h", "CompositeExplicitAutogradNonFunctionalFunctions.h", "CompositeExplicitAutogradNonFunctionalFunctions_inl.h", +<<<<<<< HEAD "ViewMetaClasses.h", "ViewMetaClasses.cpp", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "VmapGeneratedPlumbing.h", "core/ATenOpList.cpp", "core/TensorBody.h", @@ -826,6 +837,7 @@ def get_pt_operator_registry_dict( apple_sdks = kwargs.get("apple_sdks"), ) +<<<<<<< HEAD # Extract existing linker_flags from kwargs and combine with default flags existing_linker_flags = kwargs.pop("linker_flags", []) combined_linker_flags = get_no_as_needed_linker_flag() + existing_linker_flags @@ -833,6 +845,11 @@ def get_pt_operator_registry_dict( return dict( srcs = code_gen_files["srcs"], linker_flags = combined_linker_flags, +======= + return dict( + srcs = code_gen_files["srcs"], + linker_flags = get_no_as_needed_linker_flag(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # @lint-ignore BUCKLINT link_whole link_whole = True, soname = "libtorch-code-gen.$(ext)", @@ -950,7 +967,10 @@ def define_buck_targets( [ ("torch/csrc/api/include", "torch/**/*.h"), ("", "torch/csrc/**/*.h"), +<<<<<<< HEAD ("", "torch/nativert/**/*.h"), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ("", "torch/headeronly/**/*.h"), ("", "torch/script.h"), ("", "torch/library.h"), @@ -1150,9 +1170,12 @@ def define_buck_targets( "--replace", "@AT_KLEIDIAI_ENABLED@", "0", +<<<<<<< HEAD "--replace", "@AT_USE_EIGEN_SPARSE@", "0", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ]), outs = { "Config.h": ["Config.h"], @@ -1194,7 +1217,10 @@ def define_buck_targets( "NativeMetaFunctions.h": ":gen_aten[NativeMetaFunctions.h]", "Operators.h": ":gen_aten[Operators.h]", "RedispatchFunctions.h": ":gen_aten[RedispatchFunctions.h]", +<<<<<<< HEAD "ViewMetaClasses.h": ":gen_aten[ViewMetaClasses.h]", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "core/TensorBody.h": ":gen_aten[core/TensorBody.h]", "core/aten_interned_strings.h": ":gen_aten[core/aten_interned_strings.h]", "core/enum_tag.h": ":gen_aten[core/enum_tag.h]", @@ -1255,7 +1281,10 @@ def define_buck_targets( "torch/csrc/jit/mobile/parse_operators.cpp", "torch/csrc/jit/mobile/upgrader_mobile.cpp", "torch/csrc/jit/serialization/import_read.cpp", +<<<<<<< HEAD "torch/csrc/jit/serialization/pickler_helper.cpp", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "torch/csrc/jit/serialization/unpickler.cpp", ], header_namespace = "", diff --git a/build.bzl b/build.bzl index 91529e75c9f09..a7f2b2033587e 100644 --- a/build.bzl +++ b/build.bzl @@ -118,9 +118,12 @@ def define_targets(rules): ":LazyNonNativeIr.h", ":RegisterDispatchDefinitions.ini", ":RegisterDispatchKey.cpp", +<<<<<<< HEAD ":ViewMetaClassesPythonBinding.cpp", ":ViewMetaClasses.cpp", ":ViewMetaClasses.h", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ":native_functions.yaml", ":shape_inference.h", ":tags.yaml", @@ -173,7 +176,10 @@ GENERATED_H = [ "FunctionalInverses.h", "RedispatchFunctions.h", "RegistrationDeclarations.h", +<<<<<<< HEAD "ViewMetaClasses.h", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "VmapGeneratedPlumbing.h", ] @@ -250,7 +256,10 @@ GENERATED_CPP = [ "RegisterFunctionalization_1.cpp", "RegisterFunctionalization_2.cpp", "RegisterFunctionalization_3.cpp", +<<<<<<< HEAD "ViewMetaClasses.cpp", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] GENERATED_CPP_CORE = [ @@ -312,7 +321,10 @@ _GENERATED_AUTOGRAD_PYTHON_CPP = [ "torch/csrc/autograd/generated/python_torch_functions_1.cpp", "torch/csrc/autograd/generated/python_torch_functions_2.cpp", "torch/csrc/autograd/generated/python_variable_methods.cpp", +<<<<<<< HEAD "torch/csrc/functionalization/generated/ViewMetaClassesPythonBinding.cpp" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] GENERATED_AUTOGRAD_PYTHON = _GENERATED_AUTOGRAD_PYTHON_HEADERS + _GENERATED_AUTOGRAD_PYTHON_CPP diff --git a/build_variables.bzl b/build_variables.bzl index 05f5fb1068c84..31da2c3854831 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -89,7 +89,10 @@ core_sources_common = [ torch_unpickler_common = [ "torch/csrc/jit/serialization/import_read.cpp", +<<<<<<< HEAD "torch/csrc/jit/serialization/pickler_helper.cpp", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "torch/csrc/jit/serialization/unpickler.cpp", ] @@ -512,14 +515,20 @@ libtorch_distributed_base_sources = [ "torch/csrc/distributed/c10d/TCPStore.cpp", "torch/csrc/distributed/c10d/TCPStoreBackend.cpp", "torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp", +<<<<<<< HEAD "torch/csrc/distributed/c10d/Types.cpp", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "torch/csrc/distributed/c10d/Utils.cpp", "torch/csrc/distributed/c10d/Work.cpp", "torch/csrc/distributed/c10d/comm.cpp", "torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp", "torch/csrc/distributed/c10d/control_plane/Handlers.cpp", "torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp", +<<<<<<< HEAD "torch/csrc/distributed/c10d/cuda/StreamBlock.cpp", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "torch/csrc/distributed/c10d/debug.cpp", "torch/csrc/distributed/c10d/default_comm_hooks.cpp", "torch/csrc/distributed/c10d/logger.cpp", @@ -594,20 +603,30 @@ libtorch_core_jit_sources = sorted(jit_sources_full) libtorch_nativert_sources = [ +<<<<<<< HEAD "torch/nativert/ModelRunner.cpp", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "torch/nativert/graph/Graph.cpp", "torch/nativert/graph/GraphPasses.cpp", "torch/nativert/graph/GraphSignature.cpp", "torch/nativert/graph/Serialization.cpp", "torch/nativert/graph/TensorMeta.cpp", +<<<<<<< HEAD "torch/nativert/graph/GraphUtils.cpp", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "torch/nativert/executor/DelegateExecutor.cpp", "torch/nativert/executor/Placement.cpp", "torch/nativert/executor/ExecutionPlanner.cpp", "torch/nativert/executor/ExecutionFrame.cpp", +<<<<<<< HEAD "torch/nativert/executor/Executor.cpp", "torch/nativert/executor/GraphExecutorBase.cpp", "torch/nativert/executor/ConstantFolder.cpp", +======= + "torch/nativert/executor/GraphExecutorBase.cpp", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "torch/nativert/executor/OpKernel.cpp", "torch/nativert/executor/PlacementUtils.cpp", "torch/nativert/executor/SerialGraphExecutor.cpp", @@ -620,6 +639,7 @@ libtorch_nativert_sources = [ "torch/nativert/kernels/HigherOrderKernel.cpp", "torch/nativert/executor/memory/GreedyBySize.cpp", "torch/nativert/executor/memory/Bump.cpp", +<<<<<<< HEAD "torch/nativert/executor/ParallelGraphExecutor.cpp", "torch/nativert/kernels/CallTorchBindKernel.cpp", "torch/nativert/kernels/KernelFactory.cpp", @@ -642,6 +662,8 @@ libtorch_nativert_sources = [ libtorch_nativert_cuda_sources = [ "torch/nativert/executor/triton/CudaTritonKernelManager.cpp", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] torch_mobile_tracer_sources = [ @@ -664,7 +686,10 @@ libtorch_lite_eager_symbolication = [ # Later we can split serialization and deserialization logic # to have better separation within build and only build relevant parts. "torch/csrc/jit/serialization/pickle.cpp", +<<<<<<< HEAD "torch/csrc/jit/serialization/pickler_helper.cpp", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "torch/csrc/jit/serialization/pickler.cpp", "torch/csrc/jit/serialization/unpickler.cpp", ] @@ -751,9 +776,13 @@ libtorch_cuda_distributed_extra_sources = [ "torch/csrc/distributed/c10d/UCCTracing.cpp", "torch/csrc/distributed/c10d/UCCUtils.cpp", "torch/csrc/distributed/c10d/cuda/AsyncMM.cu", +<<<<<<< HEAD "torch/csrc/distributed/c10d/cuda/CUDAEventCache.cpp", "torch/csrc/distributed/c10d/cuda/utils.cpp", "torch/csrc/distributed/c10d/cuda/StreamBlock.cu", +======= + "torch/csrc/distributed/c10d/cuda/utils.cpp", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", "torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu", "torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu", @@ -762,6 +791,7 @@ libtorch_cuda_distributed_extra_sources = [ "torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu", "torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cpp", "torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cu", +<<<<<<< HEAD "torch/csrc/distributed/c10d/symm_mem/cuda_mem_pool.cpp", "torch/csrc/distributed/rpc/tensorpipe_cuda.cpp", ] @@ -773,11 +803,20 @@ libtorch_nvshmem_sources = [ "torch/csrc/distributed/c10d/symm_mem/NVSHMEMSymmetricMemory.cu", ] +======= + "torch/csrc/distributed/rpc/tensorpipe_cuda.cpp", +] + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) libtorch_cuda_distributed_sources = libtorch_cuda_distributed_base_sources + libtorch_cuda_distributed_extra_sources libtorch_cuda_sources = libtorch_cuda_core_sources + libtorch_cuda_distributed_sources + [ "torch/csrc/cuda/nccl.cpp", +<<<<<<< HEAD ] + libtorch_nativert_cuda_sources +======= +] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch_cpp_srcs = [ "torch/csrc/api/src/cuda.cpp", # this just forwards stuff, no real CUDA @@ -885,7 +924,10 @@ libtorch_python_core_sources = [ "torch/csrc/QScheme.cpp", "torch/csrc/Module.cpp", "torch/csrc/PyInterpreter.cpp", +<<<<<<< HEAD "torch/csrc/PyInterpreterHooks.cpp", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "torch/csrc/python_dimname.cpp", "torch/csrc/Size.cpp", "torch/csrc/Storage.cpp", @@ -925,8 +967,11 @@ libtorch_python_core_sources = [ "torch/csrc/mps/Module.cpp", "torch/csrc/mtia/Module.cpp", "torch/csrc/export/pybind.cpp", +<<<<<<< HEAD "torch/csrc/export/upgrader.cpp", "torch/csrc/export/example_upgraders.cpp", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "torch/csrc/inductor/aoti_package/pybind.cpp", "torch/csrc/inductor/aoti_runner/pybind.cpp", "torch/csrc/inductor/aoti_eager/kernel_holder.cpp", @@ -1007,9 +1052,13 @@ libtorch_python_core_sources = [ "torch/csrc/utils/disable_torch_function.cpp", "torch/csrc/utils/verbose.cpp", "torch/csrc/cpu/Module.cpp", +<<<<<<< HEAD "torch/csrc/functionalization/Module.cpp", "torch/csrc/instruction_counter/Module.cpp", "torch/nativert/python/Bindings.cpp", +======= + "torch/csrc/instruction_counter/Module.cpp", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] + lazy_tensor_core_python_sources libtorch_python_distributed_core_sources = [ @@ -1050,7 +1099,10 @@ def glob_libtorch_python_sources(gencode_pattern = ":generate-code[{}]"): "torch/csrc/autograd/generated/python_torch_functions_1.cpp", "torch/csrc/autograd/generated/python_torch_functions_2.cpp", "torch/csrc/autograd/generated/python_variable_methods.cpp", +<<<<<<< HEAD "torch/csrc/functionalization/generated/ViewMetaClassesPythonBinding.cpp", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ]] _libtorch_python_sources.extend(libtorch_python_core_sources) @@ -1096,7 +1148,10 @@ aten_cpu_source_non_codegen_list = [ "aten/src/ATen/DeviceAccelerator.cpp", "aten/src/ATen/Context.cpp", "aten/src/ATen/DLConvertor.cpp", +<<<<<<< HEAD "aten/src/ATen/DTensorState.cpp", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "aten/src/ATen/EmptyTensor.cpp", "aten/src/ATen/ExpandUtils.cpp", "aten/src/ATen/CachedTensorUtils.cpp", diff --git a/c10/BUCK.oss b/c10/BUCK.oss index 4ec4ab5beabb4..415e5e9188feb 100644 --- a/c10/BUCK.oss +++ b/c10/BUCK.oss @@ -37,6 +37,11 @@ cxx_library( ), exported_linker_flags = [], exported_preprocessor_flags = [ +<<<<<<< HEAD +======= + '-DC10_USING_CUSTOM_GENERATED_MACROS', + '-DC10_USE_GLOG', +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) '-DC10_USE_MINIMAL_GLOG', '-DC10_MOBILE', '-fexceptions', diff --git a/c10/CMakeLists.txt b/c10/CMakeLists.txt index f82e460cafc31..83eced0156995 100644 --- a/c10/CMakeLists.txt +++ b/c10/CMakeLists.txt @@ -18,12 +18,25 @@ else() set(C10_LIB c10) endif() +<<<<<<< HEAD set(C10_USE_GFLAGS ${USE_GFLAGS}) # also used in torch/headeronly set(C10_USE_GLOG ${USE_GLOG}) # also used in torch/headeronly set(C10_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}) # also used in torch/headeronly set(C10_USE_NUMA ${USE_NUMA}) # also used in torch/headeronly set(C10_USE_MSVC_STATIC_RUNTIME ${CAFFE2_USE_MSVC_STATIC_RUNTIME}) # also used in torch/headeronly set(C10_USE_ROCM_KERNEL_ASSERT ${USE_ROCM_KERNEL_ASSERT}) # also used in torch/headeronly +======= + # ---[ Configure macro file. + set(C10_USE_GFLAGS ${USE_GFLAGS}) # used in cmake_macros.h.in + set(C10_USE_GLOG ${USE_GLOG}) # used in cmake_macros.h.in + set(C10_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}) # used in cmake_macros.h.in + set(C10_USE_NUMA ${USE_NUMA}) + set(C10_USE_MSVC_STATIC_RUNTIME ${CAFFE2_USE_MSVC_STATIC_RUNTIME}) + set(C10_USE_ROCM_KERNEL_ASSERT ${USE_ROCM_KERNEL_ASSERT}) + configure_file( + ${CMAKE_CURRENT_LIST_DIR}/macros/cmake_macros.h.in + ${CMAKE_BINARY_DIR}/c10/macros/cmake_macros.h) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Note: if you want to add ANY dependency to the c10 library, make sure you # check with the core PyTorch developers as the dependency will be @@ -90,8 +103,11 @@ if(NOT BUILD_LIBTORCHLESS) if(C10_USE_GLOG) target_link_libraries(c10 PUBLIC glog::glog) endif() +<<<<<<< HEAD target_link_libraries(c10 PUBLIC headeronly) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) target_link_libraries(c10 PRIVATE fmt::fmt-header-only) target_link_libraries(c10 PRIVATE nlohmann) target_link_libraries(c10 PRIVATE moodycamel) @@ -168,6 +184,11 @@ endif() install(DIRECTORY ${CMAKE_CURRENT_LIST_DIR} DESTINATION include FILES_MATCHING PATTERN "*.h") +<<<<<<< HEAD +======= +install(FILES ${CMAKE_BINARY_DIR}/c10/macros/cmake_macros.h + DESTINATION include/c10/macros) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if(MSVC AND C10_BUILD_SHARED_LIBS) install(FILES $ DESTINATION lib OPTIONAL) diff --git a/c10/core/Backend.h b/c10/core/Backend.h index 0497d72b95703..de8dd2563dfc4 100644 --- a/c10/core/Backend.h +++ b/c10/core/Backend.h @@ -38,8 +38,11 @@ enum class Backend { SparseCUDA, SparseCsrCPU, SparseCsrCUDA, +<<<<<<< HEAD SparseCsrMPS, SparseMPS, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SparseHIP, SparseVE, SparseXPU, @@ -96,10 +99,13 @@ inline Backend dispatchKeyToBackend(DispatchKey t) { return Backend::SparseCPU; } else if (t == DispatchKey::SparseCUDA) { return Backend::SparseCUDA; +<<<<<<< HEAD } else if (t == DispatchKey::SparseMPS) { return Backend::SparseMPS; } else if (t == DispatchKey::SparseCsrMPS) { return Backend::SparseCsrMPS; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else if (t == DispatchKey::SparseHIP) { return Backend::SparseHIP; } else if (t == DispatchKey::SparseVE) { @@ -178,10 +184,13 @@ inline DispatchKey backendToDispatchKey(Backend b) { return DispatchKey::SparseCPU; case Backend::SparseCUDA: return DispatchKey::SparseCUDA; +<<<<<<< HEAD case Backend::SparseMPS: return DispatchKey::SparseMPS; case Backend::SparseCsrMPS: return DispatchKey::SparseCsrMPS; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) case Backend::SparseHIP: return DispatchKey::SparseHIP; case Backend::SparseVE: @@ -274,8 +283,11 @@ inline DeviceType backendToDeviceType(Backend b) { case Backend::Meta: return DeviceType::Meta; case Backend::MPS: +<<<<<<< HEAD case Backend::SparseMPS: case Backend::SparseCsrMPS: +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return DeviceType::MPS; case Backend::HPU: return DeviceType::HPU; @@ -321,10 +333,13 @@ inline const char* toString(Backend b) { return "SparseCPU"; case Backend::SparseCUDA: return "SparseCUDA"; +<<<<<<< HEAD case Backend::SparseMPS: return "SparseMPS"; case Backend::SparseCsrMPS: return "SparseCsrMPS"; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) case Backend::SparseHIP: return "SparseHIP"; case Backend::SparseVE: @@ -377,7 +392,10 @@ inline bool isSparse(Backend b) { case Backend::SparseXPU: case Backend::SparseCPU: case Backend::SparseCUDA: +<<<<<<< HEAD case Backend::SparseMPS: +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) case Backend::SparseHIP: case Backend::SparseVE: case Backend::SparsePrivateUse1: diff --git a/c10/core/CachingDeviceAllocator.h b/c10/core/CachingDeviceAllocator.h index 0bec03ae417fa..15bcb2d62081c 100644 --- a/c10/core/CachingDeviceAllocator.h +++ b/c10/core/CachingDeviceAllocator.h @@ -1,7 +1,10 @@ #pragma once #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace c10::CachingDeviceAllocator { @@ -60,6 +63,7 @@ struct DeviceStats { }; } // namespace c10::CachingDeviceAllocator +<<<<<<< HEAD namespace c10 { @@ -112,3 +116,5 @@ C10_API inline DeviceAllocator* getDeviceAllocator(const DeviceType& t) { } } // namespace c10 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/c10/core/Contiguity.h b/c10/core/Contiguity.h index eed3f24983424..532e2915baf66 100644 --- a/c10/core/Contiguity.h +++ b/c10/core/Contiguity.h @@ -12,7 +12,11 @@ namespace c10 { template bool _compute_contiguous(ArrayRef sizes, ArrayRef strides, T numel) { +<<<<<<< HEAD if (numel == 0) { +======= + if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(numel, 0))) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return true; } @@ -20,11 +24,19 @@ bool _compute_contiguous(ArrayRef sizes, ArrayRef strides, T numel) { // NB: make sure we do signed arithmetic for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) { const auto& size_d = sizes[d]; +<<<<<<< HEAD if (size_d == 1) { continue; } if (strides[d] != expected_stride) { +======= + if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(size_d, 1))) { + continue; + } + + if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(strides[d], expected_stride))) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return false; } expected_stride *= size_d; @@ -32,6 +44,7 @@ bool _compute_contiguous(ArrayRef sizes, ArrayRef strides, T numel) { return true; } +<<<<<<< HEAD // Return a SymBool with underlying symbolic expression that represents // contiguity. Guaranteed not to throw DDE, may returns a symbolic expressions // or symbolic True. @@ -100,6 +113,33 @@ inline static c10::SymBool _compute_contiguous_sym( // When T is SymInt this function may throw a data dependent error. // _compute_channels_last_contiguous_2d_sym does not. Only use this function // when inputs are hinted. +======= +// This function will return True if the tensor is contiguous, and False if the +// its not or if we can't determine if it is contiguous due to unbacked symbols +// (it could be either in that case based on the actual runtime data). +template +bool definitely_contiguous(ArrayRef sizes, ArrayRef strides, T numel) { + if (TORCH_GUARD_OR_FALSE(sym_eq(numel, 0))) { + return true; + } + + T expected_stride = 1; + // NB: make sure we do signed arithmetic + for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) { + const auto& size_d = sizes[d]; + if (TORCH_GUARD_OR_FALSE(sym_eq(size_d, 1))) { + continue; + } + + if (TORCH_GUARD_OR_TRUE(sym_ne(strides[d], expected_stride))) { + return false; + } + expected_stride *= size_d; + } + return true; +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template bool _compute_channels_last_contiguous_2d( ArrayRef sizes, @@ -111,8 +151,13 @@ bool _compute_channels_last_contiguous_2d( T expected = 1; for (auto& d : {1, 3, 2, 0}) { const auto& size_d = sizes[d]; +<<<<<<< HEAD if (size_d != 1) { if (strides[d] != expected) { +======= + if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(size_d, 1))) { + if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(strides[d], expected))) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return false; } expected *= size_d; @@ -129,6 +174,7 @@ bool _compute_channels_last_contiguous_2d( } } +<<<<<<< HEAD // Return a SymBool with underlying symbolic expression that represents // contiguity. Guaranteed not to throw DDE, may returns a symbolic expressions // or symbolic True. @@ -188,6 +234,8 @@ inline static c10::SymBool _compute_channels_last_contiguous_2d_sym( // When T is SymInt this function may throw a data dependent error. // _compute_channels_last_contiguous_3d_sym does not. Only use this function // when inputs are hinted. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template bool _compute_channels_last_contiguous_3d( ArrayRef sizes, @@ -199,8 +247,13 @@ bool _compute_channels_last_contiguous_3d( T expected = 1; for (auto& d : {1, 4, 3, 2, 0}) { const auto& size_d = sizes[d]; +<<<<<<< HEAD if (size_d != 1) { if (strides[d] != expected) { +======= + if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(size_d, 1))) { + if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(strides[d], expected))) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return false; } expected *= size_d; @@ -217,6 +270,7 @@ bool _compute_channels_last_contiguous_3d( } } +<<<<<<< HEAD inline static c10::SymBool _compute_channels_last_contiguous_3d_sym( ArrayRef sizes, ArrayRef strides) { @@ -270,6 +324,8 @@ inline static c10::SymBool _compute_channels_last_contiguous_3d_sym( } } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template bool _compute_non_overlapping_and_dense( ArrayRef sizes, diff --git a/c10/core/Device.cpp b/c10/core/Device.cpp index 68fa6f91979ab..ea6473bb999a1 100644 --- a/c10/core/Device.cpp +++ b/c10/core/Device.cpp @@ -41,9 +41,12 @@ DeviceType parse_type(const std::string& device_string) { "'mkldnn' is no longer used as device type. So torch.device('mkldnn') will be " "deprecated and removed in the future. Please use other valid device types instead."); } +<<<<<<< HEAD if (device_string == get_privateuse1_backend()) { return DeviceType::PrivateUse1; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto device = std::find_if( types.begin(), types.end(), @@ -53,6 +56,12 @@ DeviceType parse_type(const std::string& device_string) { if (device != types.end()) { return device->second; } +<<<<<<< HEAD +======= + if (device_string == get_privateuse1_backend()) { + return DeviceType::PrivateUse1; + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::vector device_names; for (const auto& it : types) { if (it.first) { diff --git a/c10/core/Device.h b/c10/core/Device.h index 52a116d4e3f6a..2e05283773edd 100644 --- a/c10/core/Device.h +++ b/c10/core/Device.h @@ -160,7 +160,11 @@ struct C10_API Device final { /// Return true if the device supports arbitrary strides. bool supports_as_strided() const noexcept { return type_ != DeviceType::IPU && type_ != DeviceType::XLA && +<<<<<<< HEAD type_ != DeviceType::Lazy; +======= + type_ != DeviceType::Lazy && type_ != DeviceType::MTIA; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } /// Same string as returned from operator<<. diff --git a/c10/core/DeviceType.cpp b/c10/core/DeviceType.cpp index 907493981e117..f8b9c33f23ec6 100644 --- a/c10/core/DeviceType.cpp +++ b/c10/core/DeviceType.cpp @@ -158,7 +158,11 @@ void register_privateuse1_backend(const std::string& backend_name) { privateuse1_backend_name = backend_name; // Invariant: once this flag is set, privateuse1_backend_name is NEVER written // to. +<<<<<<< HEAD privateuse1_backend_name_set.store(true, std::memory_order_release); +======= + privateuse1_backend_name_set.store(true, std::memory_order_relaxed); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } bool is_privateuse1_backend_registered() { diff --git a/c10/core/DispatchKey.cpp b/c10/core/DispatchKey.cpp index 7c239ecddede2..0b0ee5228c618 100644 --- a/c10/core/DispatchKey.cpp +++ b/c10/core/DispatchKey.cpp @@ -354,8 +354,11 @@ c10::DispatchKey parseDispatchKey(const std::string& k) { {"SparseCPU", c10::DispatchKey::SparseCPU}, {"SparseCUDA", c10::DispatchKey::SparseCUDA}, +<<<<<<< HEAD {"SparseMPS", c10::DispatchKey::SparseMPS}, {"SparseCsrMPS", c10::DispatchKey::SparseCsrMPS}, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"SparseHIP", c10::DispatchKey::SparseHIP}, {"SparseXPU", c10::DispatchKey::SparseXPU}, {"SparseVE", c10::DispatchKey::SparseVE}, diff --git a/c10/core/DispatchKeySet.h b/c10/core/DispatchKeySet.h index dea4c5a55de79..1035a64ac7ed4 100644 --- a/c10/core/DispatchKeySet.h +++ b/c10/core/DispatchKeySet.h @@ -634,7 +634,11 @@ class DispatchKeySet final { C10_API std::string toString(DispatchKeySet); C10_API std::ostream& operator<<(std::ostream&, DispatchKeySet); +<<<<<<< HEAD inline int getDispatchTableIndexForDispatchKey(DispatchKey k) { +======= +C10_API inline int getDispatchTableIndexForDispatchKey(DispatchKey k) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return DispatchKeySet(k).getDispatchTableIndexForDispatchKeySet(); } diff --git a/c10/core/Layout.h b/c10/core/Layout.h index 0d09e0ed46f4e..07b2b3a20ce27 100644 --- a/c10/core/Layout.h +++ b/c10/core/Layout.h @@ -32,7 +32,10 @@ inline Layout layout_from_backend(Backend backend) { switch (backend) { case Backend::SparseCPU: case Backend::SparseCUDA: +<<<<<<< HEAD case Backend::SparseMPS: +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) case Backend::SparseHIP: case Backend::SparseVE: case Backend::SparseXPU: @@ -42,13 +45,20 @@ inline Layout layout_from_backend(Backend backend) { return Layout::Mkldnn; case Backend::SparseCsrCPU: case Backend::SparseCsrCUDA: +<<<<<<< HEAD case Backend::SparseCsrMPS: +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) case Backend::SparseCsrHIP: case Backend::SparseCsrVE: case Backend::SparseCsrXPU: TORCH_CHECK( false, +<<<<<<< HEAD "Cannot map Backend SparseCsr(CPU|CUDA|HIP|VE|XPU|MPS) to a unique layout."); +======= + "Cannot map Backend SparseCsr(CPU|CUDA|HIP|VE|XPU) to a unique layout."); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) default: return Layout::Strided; } diff --git a/c10/core/Scalar.h b/c10/core/Scalar.h index 646a1dde39940..0ed9e176aa9c2 100644 --- a/c10/core/Scalar.h +++ b/c10/core/Scalar.h @@ -191,17 +191,23 @@ class C10_API Scalar { isIntegral() const { return Tag::HAS_i == tag || Tag::HAS_si == tag || Tag::HAS_u == tag; } +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool isIntegral(bool includeBool) const { return Tag::HAS_i == tag || Tag::HAS_si == tag || Tag::HAS_u == tag || (includeBool && isBoolean()); } +<<<<<<< HEAD // See Note [Meaning of HAS_u] bool isUnsigned() const { return Tag::HAS_u == tag || (Tag::HAS_i == tag && v.i >= 0); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool isComplex() const { return Tag::HAS_z == tag; } diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h index 4a15eb23ac63c..15828731191e0 100644 --- a/c10/core/ScalarType.h +++ b/c10/core/ScalarType.h @@ -19,16 +19,36 @@ #include #include +<<<<<<< HEAD +======= +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include #include +<<<<<<< HEAD #include namespace c10 { // [dtype Macros note] For the macros below: +======= +namespace c10 { + +// dummy struct for uint1 to uint7, actual functionality +// of these dtypes will be implemented in python with Tensor subclass +template +struct dummy_uint1_7_t {}; + +// dummy struct for int1 to int7, actual functionality +// of these dtypes will be implemented in python with Tensor subclass +template +struct dummy_int1_7_t {}; + +// For the macros below: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // // For users: If you want to macro some code for all non-QInt scalar types // (i.e. types with complete information, you probably want one of the @@ -48,6 +68,59 @@ namespace c10 { // some old PRs where we added new dtypes (check history of this file) can // help give you an idea where to start. +<<<<<<< HEAD +======= +// NB: Order matters for this macro; it is relied upon in +// _promoteTypesLookup and the serialization format. +#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \ + _(uint8_t, Byte) /* 0 */ \ + _(int8_t, Char) /* 1 */ \ + _(int16_t, Short) /* 2 */ \ + _(int, Int) /* 3 */ \ + _(int64_t, Long) /* 4 */ \ + _(at::Half, Half) /* 5 */ \ + _(float, Float) /* 6 */ \ + _(double, Double) /* 7 */ \ + _(c10::complex, ComplexHalf) /* 8 */ \ + _(c10::complex, ComplexFloat) /* 9 */ \ + _(c10::complex, ComplexDouble) /* 10 */ \ + _(bool, Bool) /* 11 */ \ + _(c10::qint8, QInt8) /* 12 */ \ + _(c10::quint8, QUInt8) /* 13 */ \ + _(c10::qint32, QInt32) /* 14 */ \ + _(at::BFloat16, BFloat16) /* 15 */ \ + _(c10::quint4x2, QUInt4x2) /* 16 */ \ + _(c10::quint2x4, QUInt2x4) /* 17 */ \ + _(c10::bits1x8, Bits1x8) /* 18 */ \ + _(c10::bits2x4, Bits2x4) /* 19 */ \ + _(c10::bits4x2, Bits4x2) /* 20 */ \ + _(c10::bits8, Bits8) /* 21 */ \ + _(c10::bits16, Bits16) /* 22 */ \ + _(c10::Float8_e5m2, Float8_e5m2) /* 23 */ \ + _(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */ \ + _(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */ \ + _(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */ \ + _(uint16_t, UInt16) /* 27 */ \ + _(uint32_t, UInt32) /* 28 */ \ + _(uint64_t, UInt64) /* 29 */ \ + _(c10::dummy_uint1_7_t<1>, UInt1) /* 30 */ \ + _(c10::dummy_uint1_7_t<2>, UInt2) /* 31 */ \ + _(c10::dummy_uint1_7_t<3>, UInt3) /* 32 */ \ + _(c10::dummy_uint1_7_t<4>, UInt4) /* 33 */ \ + _(c10::dummy_uint1_7_t<5>, UInt5) /* 34 */ \ + _(c10::dummy_uint1_7_t<6>, UInt6) /* 35 */ \ + _(c10::dummy_uint1_7_t<7>, UInt7) /* 36 */ \ + _(c10::dummy_int1_7_t<1>, Int1) /* 37 */ \ + _(c10::dummy_int1_7_t<2>, Int2) /* 38 */ \ + _(c10::dummy_int1_7_t<3>, Int3) /* 39 */ \ + _(c10::dummy_int1_7_t<4>, Int4) /* 40 */ \ + _(c10::dummy_int1_7_t<5>, Int5) /* 41 */ \ + _(c10::dummy_int1_7_t<6>, Int6) /* 42 */ \ + _(c10::dummy_int1_7_t<7>, Int7) /* 43 */ \ + _(c10::Float8_e8m0fnu, Float8_e8m0fnu) /* 44 */ \ + _(c10::Float4_e2m1fn_x2, Float4_e2m1fn_x2) /* 45 */ + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // If you want to support ComplexHalf for real, add ComplexHalf // into this macro (and change the name). But beware: convert() // doesn't work for all the conversions you need... @@ -93,6 +166,20 @@ namespace c10 { _(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \ _(at::Float8_e8m0fnu, Float8_e8m0fnu) +<<<<<<< HEAD +======= +enum class ScalarType : int8_t { +#define DEFINE_ST_ENUM_VAL_(_1, n) n, + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ST_ENUM_VAL_) +#undef DEFINE_ENUM_ST_ENUM_VAL_ + Undefined, + NumOptions +}; + +constexpr uint16_t NumScalarTypes = + static_cast(ScalarType::NumOptions); + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace impl { // These are used to map ScalarTypes to C++ types. diff --git a/c10/core/SymInt.cpp b/c10/core/SymInt.cpp index b78ca94dc5145..676125d0bcf81 100644 --- a/c10/core/SymInt.cpp +++ b/c10/core/SymInt.cpp @@ -20,6 +20,7 @@ void SymInt::promote_to_negative() { s.data_ = 0; } +<<<<<<< HEAD std::optional SymInt::maybe_as_int_slow_path() const { auto* node = toSymNodeImplUnowned(); if (auto c = node->constant_int()) { @@ -28,6 +29,8 @@ std::optional SymInt::maybe_as_int_slow_path() const { return node->maybe_as_int(); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SymNode SymInt::toSymNode() const { TORCH_CHECK_ALWAYS_SHOW_CPP_STACKTRACE( is_heap_allocated(), "SymInt::toSymNode is_heap_allocated"); @@ -53,11 +56,20 @@ bool SymInt::has_hint() const { #define DEFINE_BINARY(API, OP, METHOD, RET) \ RET SymInt::API(const SymInt& sci) const { \ if (auto ma = maybe_as_int()) { \ +<<<<<<< HEAD TORCH_INTERNAL_ASSERT_DEBUG_ONLY( \ !sci.maybe_as_int(), \ "should have hit fast path in the header in this case."); \ auto b = sci.toSymNode(); \ return RET(b->wrap_int(*ma)->METHOD(b)); \ +======= + if (auto mb = sci.maybe_as_int()) { \ + return RET(OP(*ma, *mb)); \ + } else { \ + auto b = sci.toSymNode(); \ + return RET(b->wrap_int(*ma)->METHOD(b)); \ + } \ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else { \ if (auto mb = sci.maybe_as_int()) { \ auto a = toSymNodeImplUnowned(); \ @@ -68,6 +80,7 @@ bool SymInt::has_hint() const { } \ } +<<<<<<< HEAD DEFINE_BINARY(operator_add_slow_path, std::plus<>(), add, SymInt) DEFINE_BINARY(operator_sub_slow_path, std::minus<>(), sub, SymInt) DEFINE_BINARY(operator_mul_slow_path, std::multiplies<>(), mul, SymInt) @@ -81,6 +94,21 @@ DEFINE_BINARY(sym_gt_slow_path, std::greater<>(), gt, SymBool) DEFINE_BINARY(sym_ge_slow_path, std::greater_equal<>(), ge, SymBool) DEFINE_BINARY(min_slow_path, std::min, sym_min, SymInt) DEFINE_BINARY(max_slow_path, std::max, sym_max, SymInt) +======= +DEFINE_BINARY(operator+, std::plus<>(), add, SymInt) +DEFINE_BINARY(operator-, std::minus<>(), sub, SymInt) +DEFINE_BINARY(operator*, std::multiplies<>(), mul, SymInt) +DEFINE_BINARY(operator/, std::divides<>(), floordiv, SymInt) +DEFINE_BINARY(operator%, std::modulus<>(), mod, SymInt) +DEFINE_BINARY(sym_eq, std::equal_to<>(), eq, SymBool) +DEFINE_BINARY(sym_ne, std::not_equal_to<>(), ne, SymBool) +DEFINE_BINARY(sym_lt, std::less<>(), lt, SymBool) +DEFINE_BINARY(sym_le, std::less_equal<>(), le, SymBool) +DEFINE_BINARY(sym_gt, std::greater<>(), gt, SymBool) +DEFINE_BINARY(sym_ge, std::greater_equal<>(), ge, SymBool) +DEFINE_BINARY(min, std::min, sym_min, SymInt) +DEFINE_BINARY(max, std::max, sym_max, SymInt) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SymInt::operator SymFloat() const { if (auto ma = maybe_as_int()) { @@ -160,6 +188,7 @@ SymInt operator-(const SymInt& s) { } } +<<<<<<< HEAD void SymInt::operator_imul_slow_path(const SymInt& sci) { *this = *this * sci; } @@ -169,6 +198,17 @@ void SymInt::operator_idiv_slow_path(const SymInt& sci) { } void SymInt::operator_iadd_slow_path(const SymInt& sci) { +======= +void SymInt::operator*=(const SymInt& sci) { + *this = *this * sci; +} + +void SymInt::operator/=(const SymInt& sci) { + *this = *this / sci; +} + +void SymInt::operator+=(const SymInt& sci) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) *this = *this + sci; } diff --git a/c10/core/SymInt.h b/c10/core/SymInt.h index 9b1c776cbe2ab..96aa42a847032 100644 --- a/c10/core/SymInt.h +++ b/c10/core/SymInt.h @@ -7,7 +7,10 @@ #include #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include @@ -178,6 +181,7 @@ class C10_API SymInt { #endif } +<<<<<<< HEAD SymInt operator+(const SymInt& sci) const { if (auto ma = maybe_as_int()) { if (auto mb = sci.maybe_as_int()) { @@ -308,6 +312,25 @@ class C10_API SymInt { } return sym_ge_slow_path(sci); } +======= + SymInt operator+(const SymInt& sci) const; + SymInt operator-(const SymInt& sci) const; + SymInt operator*(const SymInt& sci) const; + SymInt operator/(const SymInt& sci) const; + SymInt operator%(const SymInt& sci) const; + void operator*=(const SymInt& sci); + void operator+=(const SymInt& sci); + void operator/=(const SymInt& sci); + + SymInt clone() const; + + SymBool sym_eq(const SymInt&) const; + SymBool sym_ne(const SymInt&) const; + SymBool sym_lt(const SymInt&) const; + SymBool sym_le(const SymInt&) const; + SymBool sym_gt(const SymInt&) const; + SymBool sym_ge(const SymInt&) const; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool operator==(const SymInt& o) const { return sym_eq(o).guard_bool(__FILE__, __LINE__); @@ -328,6 +351,7 @@ class C10_API SymInt { return sym_ge(o).guard_bool(__FILE__, __LINE__); } +<<<<<<< HEAD SymInt min(const SymInt& sci) const { if (auto ma = maybe_as_int()) { if (auto mb = sci.maybe_as_int()) { @@ -345,6 +369,10 @@ class C10_API SymInt { } return max_slow_path(sci); } +======= + SymInt min(const SymInt& sci) const; + SymInt max(const SymInt& sci) const; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // If both are symbolic, this checks if // they share the same node. @@ -368,7 +396,15 @@ class C10_API SymInt { if (!is_heap_allocated()) { return data_; } +<<<<<<< HEAD return maybe_as_int_slow_path(); +======= + auto* node = toSymNodeImplUnowned(); + if (auto c = node->constant_int()) { + return c; + } + return node->maybe_as_int(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // Return whether the integer is directly coercible to a SymInt @@ -389,6 +425,7 @@ class C10_API SymInt { private: void promote_to_negative(); +<<<<<<< HEAD SymInt operator_add_slow_path(const SymInt& sci) const; SymInt operator_sub_slow_path(const SymInt& sci) const; SymInt operator_mul_slow_path(const SymInt& sci) const; @@ -408,6 +445,8 @@ class C10_API SymInt { SymInt max_slow_path(const SymInt& sci) const; std::optional maybe_as_int_slow_path() const; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Constraints on the internal representation: // diff --git a/c10/core/SymbolicShapeMeta.cpp b/c10/core/SymbolicShapeMeta.cpp index 01276d416fbb8..28dfe3869c94e 100644 --- a/c10/core/SymbolicShapeMeta.cpp +++ b/c10/core/SymbolicShapeMeta.cpp @@ -71,6 +71,7 @@ normalize_sym_sizes_strides(SymIntArrayRef sizes, SymIntArrayRef strides) { return std::tuple, std::vector>( std::move(base), std::move(size_nodes), std::move(stride_nodes)); } +<<<<<<< HEAD namespace { bool all_hinted( const c10::SymIntArrayRef& sizes, @@ -92,6 +93,8 @@ bool all_hinted( return all_hinted; } } // namespace +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Special treatment because of numel SymBool SymbolicShapeMeta::compute_contiguous() const { @@ -100,6 +103,7 @@ SymBool SymbolicShapeMeta::compute_contiguous() const { } c10::SymIntArrayRef sizes(sizes_); c10::SymIntArrayRef strides(strides_); +<<<<<<< HEAD auto result = _compute_contiguous_sym(sizes, strides, numel()); @@ -178,6 +182,20 @@ SymBool SymbolicShapeMeta::compute_channels_last_contiguous_3d() const { c10::SymIntArrayRef sizes(sizes_); \ c10::SymIntArrayRef strides(strides_); \ return fallback(sizes, strides); \ +======= + return _compute_contiguous(sizes, strides, numel()); +} + +// The rest of them +#define DEFINE_EAGER_SYMBOOL_COMPUTE(name, nodeimpl, fallback) \ + SymBool SymbolicShapeMeta::name() const { \ + if (!strides_valid_) { \ + return false; \ + } \ + c10::SymIntArrayRef sizes(sizes_); \ + c10::SymIntArrayRef strides(strides_); \ + return fallback(sizes, strides); \ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } #define DEFINE_SYMBOOL_COMPUTE(name, nodeimpl, fallback) \ @@ -197,11 +215,19 @@ SymBool SymbolicShapeMeta::compute_channels_last_contiguous_3d() const { } // clang-format off +<<<<<<< HEAD DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_2d, is_channels_last_strides_2d) DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_3d, is_channels_last_strides_3d) DEFINE_SYMBOOL_COMPUTE(compute_non_overlapping_and_dense, is_non_overlapping_and_dense, _compute_non_overlapping_and_dense) +======= +DEFINE_EAGER_SYMBOOL_COMPUTE(compute_channels_last_contiguous_2d, is_channels_last_contiguous_2d, _compute_channels_last_contiguous_2d) +DEFINE_EAGER_SYMBOOL_COMPUTE(compute_channels_last_contiguous_3d, is_channels_last_contiguous_3d, _compute_channels_last_contiguous_3d) +DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_2d, is_channels_last_strides_2d, is_channels_last_strides_2d) +DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_3d, is_channels_last_strides_3d, is_channels_last_strides_3d) +DEFINE_SYMBOOL_COMPUTE(compute_non_overlapping_and_dense, is_non_overlapping_and_dense, _compute_non_overlapping_and_dense) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // clang-format on #undef DEFINE_SYMBOOL_COMPUTE @@ -279,7 +305,10 @@ void SymbolicShapeMeta::set_numel(SymInt val) const { numel_ = std::move(val); available_.fetch_or(numel_avail); } +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void SymbolicShapeMeta::set_is_contiguous(SymBool val) const { std::scoped_lock lock(mutables_); if (has_is_contiguous()) { @@ -288,7 +317,10 @@ void SymbolicShapeMeta::set_is_contiguous(SymBool val) const { is_contiguous_ = std::move(val); available_.fetch_or(is_contiguous_avail); } +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void SymbolicShapeMeta::set_is_channels_last_contiguous(SymBool val) const { std::scoped_lock lock(mutables_); if (has_is_channels_last_contiguous()) { @@ -297,7 +329,10 @@ void SymbolicShapeMeta::set_is_channels_last_contiguous(SymBool val) const { is_channels_last_contiguous_ = std::move(val); available_.fetch_or(is_channels_last_contiguous_avail); } +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void SymbolicShapeMeta::set_is_channels_last_3d_contiguous(SymBool val) const { std::scoped_lock lock(mutables_); if (has_is_channels_last_3d_contiguous()) { @@ -306,7 +341,10 @@ void SymbolicShapeMeta::set_is_channels_last_3d_contiguous(SymBool val) const { is_channels_last_3d_contiguous_ = std::move(val); available_.fetch_or(is_channels_last_3d_contiguous_avail); } +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void SymbolicShapeMeta::set_is_channels_last(SymBool val) const { std::scoped_lock lock(mutables_); if (has_is_channels_last()) { @@ -315,7 +353,10 @@ void SymbolicShapeMeta::set_is_channels_last(SymBool val) const { is_channels_last_ = std::move(val); available_.fetch_or(is_channels_last_avail); } +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void SymbolicShapeMeta::set_is_channels_last_3d(SymBool val) const { std::scoped_lock lock(mutables_); if (has_is_channels_last_3d()) { diff --git a/c10/core/SymbolicShapeMeta.h b/c10/core/SymbolicShapeMeta.h index 0820038968a8e..630e98da5fb06 100644 --- a/c10/core/SymbolicShapeMeta.h +++ b/c10/core/SymbolicShapeMeta.h @@ -1,5 +1,8 @@ #pragma once +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include @@ -83,6 +86,7 @@ class C10_API SymbolicShapeMeta { return numel_; } +<<<<<<< HEAD const SymBool& is_contiguous(at::MemoryFormat memory_format) const { if (memory_format == at::MemoryFormat::ChannelsLast) { return this->is_channels_last_contiguous(); @@ -92,6 +96,8 @@ class C10_API SymbolicShapeMeta { return this->is_contiguous(); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const SymBool& is_contiguous() const { if (C10_UNLIKELY(!has_is_contiguous())) { init_is_contiguous(); @@ -204,7 +210,10 @@ class C10_API SymbolicShapeMeta { // Lazily initialized variables, with the corresponding available_ flag // indicating whether the value has been initialized mutable std::atomic available_{0}; +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) enum avail { numel_avail = 1 << 0, is_contiguous_avail = 1 << 1, diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index cd0321d3bb6f5..d8c6d2709fe75 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -310,6 +310,7 @@ void TensorImpl::throw_data_ptr_access_error() const { false, "Cannot access data pointer of Tensor that doesn't have storage"); } +<<<<<<< HEAD c10::SymBool TensorImpl::sym_is_contiguous_custom( at::MemoryFormat memory_format) const { if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) { @@ -325,6 +326,14 @@ c10::SymBool TensorImpl::sym_is_contiguous_custom( } return sym_is_contiguous_default(memory_format); +======= +bool TensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const { + if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) { + return pyobj_slot_.load_pyobj_interpreter()->is_contiguous( + this, memory_format); + } + return is_contiguous_default(memory_format); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } bool TensorImpl::is_strides_like_custom(at::MemoryFormat memory_format) const { @@ -335,12 +344,20 @@ bool TensorImpl::is_strides_like_custom(at::MemoryFormat memory_format) const { return is_strides_like_default(memory_format); } +<<<<<<< HEAD c10::SymBool TensorImpl::sym_is_non_overlapping_and_dense_custom() const { +======= +bool TensorImpl::is_non_overlapping_and_dense_custom() const { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) { return pyobj_slot_.load_pyobj_interpreter()->is_non_overlapping_and_dense( this); } +<<<<<<< HEAD return sym_is_non_overlapping_and_dense_default(); +======= + return is_non_overlapping_and_dense_default(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } IntArrayRef TensorImpl::sizes_custom() const { diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index 972181327b1f6..530c1f236dee6 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -643,6 +643,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { } } +<<<<<<< HEAD template ArrayRef generic_sizes() { static_assert( @@ -654,10 +655,30 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { } else { return sym_sizes(); } +======= + // From https://stackoverflow.com/a/3057522/23845 + // TODO: does C++14 have a stdlib template for this? + template + struct identity { + typedef T type; + }; + + template + ArrayRef generic_sizes() { + return _generic_sizes(identity()); + } + + ArrayRef _generic_sizes(identity) { + return sizes(); + } + ArrayRef _generic_sizes(identity) { + return sym_sizes(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } template ArrayRef generic_strides() { +<<<<<<< HEAD static_assert( std::is_same_v || std::is_same_v, "Only supports int64_t and c10::SymInt."); @@ -667,10 +688,21 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { } else { return sym_strides(); } +======= + return _generic_strides(identity()); + } + + ArrayRef _generic_strides(identity) { + return strides(); + } + ArrayRef _generic_strides(identity) { + return sym_strides(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } template T generic_storage_offset() { +<<<<<<< HEAD static_assert( std::is_same_v || std::is_same_v, "Only supports int64_t and c10::SymInt."); @@ -680,6 +712,16 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { } else { return sym_storage_offset(); } +======= + return _generic_storage_offset(identity()); + } + + int64_t _generic_storage_offset(identity) { + return storage_offset(); + } + c10::SymInt _generic_storage_offset(identity) { + return sym_storage_offset(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } /** @@ -808,6 +850,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { } } +<<<<<<< HEAD c10::SymBool sym_is_contiguous( at::MemoryFormat memory_format = at::MemoryFormat::Contiguous) const { if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) { @@ -845,6 +888,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return is_contiguous_default_impl(memory_format); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) /** * Whether or not a tensor is laid out in contiguous memory. * @@ -860,6 +905,33 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return is_contiguous_default(memory_format); } +<<<<<<< HEAD +======= + // These are factored into separate functions in case subclasses + // want to use them + bool is_contiguous_default(at::MemoryFormat memory_format) const { + if (has_symbolic_sizes_strides_) { + if (memory_format == at::MemoryFormat::ChannelsLast) { + return symbolic_shape_meta().is_channels_last_contiguous().guard_bool( + __FILE__, __LINE__); + } else if (memory_format == at::MemoryFormat::ChannelsLast3d) { + return symbolic_shape_meta() + .is_channels_last_3d_contiguous() + .guard_bool(__FILE__, __LINE__); + } + return symbolic_shape_meta().is_contiguous().guard_bool( + __FILE__, __LINE__); + } + + if (memory_format == at::MemoryFormat::ChannelsLast) { + return is_channels_last_contiguous_; + } else if (memory_format == at::MemoryFormat::ChannelsLast3d) { + return is_channels_last_3d_contiguous_; + } + return is_contiguous_; + } + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool is_strides_like_default(at::MemoryFormat memory_format) const { if (has_symbolic_sizes_strides_) { if (memory_format == at::MemoryFormat::ChannelsLast) { @@ -882,6 +954,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { } } +<<<<<<< HEAD SymBool sym_is_non_overlapping_and_dense_default() const { if (has_symbolic_sizes_strides_) { return symbolic_shape_meta().is_non_overlapping_and_dense(); @@ -893,6 +966,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { bool is_non_overlapping_and_dense_default() const { if (has_symbolic_sizes_strides_) { return sym_is_non_overlapping_and_dense_default().guard_bool( +======= + bool is_non_overlapping_and_dense_default() const { + if (has_symbolic_sizes_strides_) { + return symbolic_shape_meta().is_non_overlapping_and_dense().guard_bool( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __FILE__, __LINE__); } else { return is_non_overlapping_and_dense_; @@ -985,6 +1063,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * for a tensor to have rank, but not well defined sizes. */ // sizes_strides_policy_ >= CustomStrides +<<<<<<< HEAD virtual bool is_strides_like_custom(at::MemoryFormat memory_format) const; @@ -1003,6 +1082,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { .guard_bool(__FILE__, __LINE__); } +======= + virtual bool is_contiguous_custom(at::MemoryFormat memory_format) const; + virtual bool is_strides_like_custom(at::MemoryFormat memory_format) const; + virtual bool is_non_overlapping_and_dense_custom() const; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // sizes_strides_policy_ >= CustomSizes // Currently this method only exists to be overwritten by subclasses such as // NestedTensorImpl. @@ -1036,9 +1120,15 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { virtual c10::SymInt sym_storage_offset_custom() const; public: +<<<<<<< HEAD /** * True if this tensor has storage. See storage() for details. */ +======= + /** + * True if this tensor has storage. See storage() for details. + */ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #ifdef DEBUG // Allow subclasses to check that their storage_ is never getting set in debug // builds. @@ -1048,11 +1138,19 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { #endif bool has_storage() const +<<<<<<< HEAD // NOTE: we devirtualize this because it arguably shouldn't be an // error just to ask subclasses if they have storage. // This used to throw for most subclasses, but OpaqueTensorImpl // wanted it to successfully return false, so we went ahead and made // it a non-error. +======= + // NOTE: we devirtualize this because it arguably shouldn't be an + // error just to ask subclasses if they have storage. + // This used to throw for most subclasses, but OpaqueTensorImpl + // wanted it to successfully return false, so we went ahead and made + // it a non-error. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #ifdef C10_DISABLE_TENSORIMPL_EXTENSIBILITY { return storage_; @@ -2086,7 +2184,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { constexpr auto sparse_backends = DispatchKeySet( {BackendComponent::CPUBit, BackendComponent::CUDABit, +<<<<<<< HEAD BackendComponent::MPSBit, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) BackendComponent::HIPBit, BackendComponent::XPUBit}); constexpr auto sparse_k = DispatchKeySet(DispatchKey::Sparse); @@ -2480,11 +2581,14 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return is_strides_like(at::MemoryFormat::ChannelsLast3d); } +<<<<<<< HEAD bool is_non_overlapping_and_dense_or_false() const { return sym_is_non_overlapping_and_dense().guard_or_false( __FILE__, __LINE__); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool is_non_overlapping_and_dense() const { if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) { return is_non_overlapping_and_dense_custom(); @@ -2492,6 +2596,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return is_non_overlapping_and_dense_default(); } +<<<<<<< HEAD SymBool sym_is_non_overlapping_and_dense() const { if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) { return sym_is_non_overlapping_and_dense_custom(); @@ -2499,6 +2604,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return sym_is_non_overlapping_and_dense_default(); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // if this returns true, then it is guaranteed that this tensor has symbolic // sizes/strides bool has_symbolic_sizes_strides() const { diff --git a/c10/core/UndefinedTensorImpl.cpp b/c10/core/UndefinedTensorImpl.cpp index b42d3a92545f0..a09fafe50594c 100644 --- a/c10/core/UndefinedTensorImpl.cpp +++ b/c10/core/UndefinedTensorImpl.cpp @@ -12,8 +12,12 @@ UndefinedTensorImpl::UndefinedTensorImpl() set_custom_sizes_strides(SizesStridesPolicy::CustomStrides); } +<<<<<<< HEAD c10::SymBool UndefinedTensorImpl::sym_is_contiguous_custom( MemoryFormat format) const { +======= +bool UndefinedTensorImpl::is_contiguous_custom(MemoryFormat format) const { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return is_contiguous_default(format); } IntArrayRef UndefinedTensorImpl::strides_custom() const { diff --git a/c10/core/UndefinedTensorImpl.h b/c10/core/UndefinedTensorImpl.h index 6b7573a69388a..d1067e98559ec 100644 --- a/c10/core/UndefinedTensorImpl.h +++ b/c10/core/UndefinedTensorImpl.h @@ -32,7 +32,11 @@ struct C10_API UndefinedTensorImpl final : public TensorImpl { void set_storage_offset(int64_t offset) override; protected: +<<<<<<< HEAD c10::SymBool sym_is_contiguous_custom(MemoryFormat format) const override; +======= + bool is_contiguous_custom(MemoryFormat format) const override; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) IntArrayRef strides_custom() const override; SymIntArrayRef sym_strides_custom() const override; diff --git a/c10/core/impl/PyInterpreter.cpp b/c10/core/impl/PyInterpreter.cpp index 913bc78726576..08a998b7b7da1 100644 --- a/c10/core/impl/PyInterpreter.cpp +++ b/c10/core/impl/PyInterpreter.cpp @@ -60,10 +60,13 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable { bool is_contiguous(const TensorImpl* self, at::MemoryFormat) const override { PANIC(is_contiguous); } +<<<<<<< HEAD c10::SymBool sym_is_contiguous(const TensorImpl* self, at::MemoryFormat) const override { PANIC(sym_is_contiguous); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool is_strides_like(const TensorImpl* self, at::MemoryFormat) const override { PANIC(is_strides_like); diff --git a/c10/core/impl/PyInterpreter.h b/c10/core/impl/PyInterpreter.h index def708c24b802..abde7c8420131 100644 --- a/c10/core/impl/PyInterpreter.h +++ b/c10/core/impl/PyInterpreter.h @@ -168,9 +168,12 @@ struct C10_API PyInterpreterVTable { virtual bool is_contiguous(const TensorImpl* self, at::MemoryFormat) const = 0; +<<<<<<< HEAD virtual c10::SymBool sym_is_contiguous( const TensorImpl* self, at::MemoryFormat) const = 0; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) virtual bool is_strides_like(const TensorImpl* self, at::MemoryFormat) const = 0; virtual bool is_non_overlapping_and_dense(const TensorImpl* self) const = 0; @@ -243,4 +246,27 @@ struct C10_API PyInterpreter { void disarm() noexcept; }; +<<<<<<< HEAD +======= +// PyInterpreterStatus describes what the state of its interpreter tag +// is, relative to the thread currently holding the GIL. +enum class PyInterpreterStatus { + // We just allocated the Tensor, it hasn't escaped to other threads, + // we know that it definitely hasn't been tagged to be associated + // with an interpreter. + DEFINITELY_UNINITIALIZED, + // We queried the interpreter field and it looked uninitialized. But + // another thread may have raced with us to tag it with some other + // interpreter id. So we will have to do a CEX to make sure we can + // actually nab it. + MAYBE_UNINITIALIZED, + // We queried the interpreter field and it was tagged to belong to us. + // This means we have sole write access (as we hold the GIL for this + // interpreter) + TAGGED_BY_US, + // Someone else tagged this. We can't use this TensorImpl from Python. + TAGGED_BY_OTHER, +}; + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace c10::impl diff --git a/c10/core/impl/PyObjectSlot.cpp b/c10/core/impl/PyObjectSlot.cpp index 0f1bfb2110747..7b4fd8f30ff1b 100644 --- a/c10/core/impl/PyObjectSlot.cpp +++ b/c10/core/impl/PyObjectSlot.cpp @@ -34,12 +34,36 @@ PyObject* PyObjectSlot::_unchecked_untagged_pyobj() const { reinterpret_cast(pyobj_) & ~0x1ULL); } +<<<<<<< HEAD +======= +void PyObjectSlot::unchecked_clear_pyobj(PyInterpreter* interpreter) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(interpreter == pyobj_interpreter_.load()); + pyobj_ = nullptr; +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) PyInterpreter& PyObjectSlot::load_pyobj_interpreter() const { auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire); if (interpreter) { return *interpreter; } +<<<<<<< HEAD TORCH_CHECK(false, "cannot access PyObject for Tensor - no interpreter set"); +======= + TORCH_CHECK( + false, + "cannot access PyObject for Tensor on interpreter ", + (*pyobj_interpreter_.load())->name()); +} + +bool PyObjectSlot::check_interpreter(PyInterpreter* interpreter) { + return interpreter == pyobj_interpreter(); +} + +bool PyObjectSlot::has_pyobj_nonhermetic() { + return check_pyobj(pyobj_interpreter(), /*ignore_hermetic_tls=*/true) + .has_value(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } bool PyObjectSlot::owns_pyobj() { diff --git a/c10/core/impl/PyObjectSlot.h b/c10/core/impl/PyObjectSlot.h index 58b2490eba001..562eb3cf1fcfb 100644 --- a/c10/core/impl/PyObjectSlot.h +++ b/c10/core/impl/PyObjectSlot.h @@ -2,7 +2,10 @@ #include #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include @@ -25,9 +28,58 @@ struct C10_API PyObjectSlot { // // NB: THIS FUNCTION CAN RAISE AN EXCEPTION. Make sure to clean up after // PyObject if necessary! +<<<<<<< HEAD void init_pyobj(PyObject* pyobj) { pyobj_interpreter_.store( getGlobalPyInterpreter(), std::memory_order_relaxed); +======= + void init_pyobj( + PyInterpreter* self_interpreter, + PyObject* pyobj, + PyInterpreterStatus status) { + impl::PyInterpreter* expected = nullptr; + switch (status) { + case impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED: + // caller guarantees there is no multithreaded access; if there is + // no data race OK to do a relaxed store + pyobj_interpreter_.store(self_interpreter, std::memory_order_relaxed); + break; + case impl::PyInterpreterStatus::TAGGED_BY_US: + // no tagging is necessary, the tag is already correct + break; + case impl::PyInterpreterStatus::MAYBE_UNINITIALIZED: + // attempt to claim this TensorImpl with the specified interpreter + // tag + if (pyobj_interpreter_.compare_exchange_strong( + expected, self_interpreter, std::memory_order_acq_rel)) { + break; + } + // test if, actually, it was already tagged by us! this situation can't + // be caused by a race, but it could be caused by a situation + // where someone conservatively tagged the tensor as MAYBE_UNINITIALIZED + // (because they didn't pre-check the tag) when actually it was + // owned by the interpreter + if (expected == self_interpreter) { + break; + } + // fallthrough, we lost the race. We are guaranteed not to lose the + // race with ourself, as calls to init_pyobj with the same interpreter + // ID must be sequentialized by the GIL + [[fallthrough]]; + case impl::PyInterpreterStatus::TAGGED_BY_OTHER: + TORCH_CHECK( + false, + "cannot allocate PyObject for Tensor on interpreter ", + self_interpreter, + " that has already been used by another torch deploy interpreter ", + pyobj_interpreter_.load()); + } + + // we are the ONLY thread that can have gotten to this point. It is not + // possible to conflict with another zero interpreter as access is protected + // by GIL + // NB: owns_pyobj tag is initially false +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pyobj_ = pyobj; } @@ -52,6 +104,7 @@ struct C10_API PyObjectSlot { // // NB: this lives in header so that we can avoid actually creating the // std::optional +<<<<<<< HEAD // @todo alban: I'm not too sure what's going on here, we can probably delete // it but it's worthwhile making sure @@ -71,6 +124,51 @@ struct C10_API PyObjectSlot { PyInterpreter& load_pyobj_interpreter() const; +======= + std::optional check_pyobj( + PyInterpreter* self_interpreter, + bool ignore_hermetic_tls = false) const { + // Note [Memory ordering on Python interpreter tag] + impl::PyInterpreter* interpreter = + pyobj_interpreter_.load(std::memory_order_acquire); + if (interpreter == nullptr) { + // NB: This never returns DEFINITELY_UNINITIALIZED because there is + // always the possibility that another thread races to initialize + // after we query here. The only time when we can conclude a tensor + // is definitely uninitialized is when we have just allocated it and + // it cannot have escaped to other threads yet + return std::nullopt; + } else if (interpreter == self_interpreter) { + // NB: pyobj_ could still be null! + if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) { + return std::nullopt; + } else { + return _unchecked_untagged_pyobj(); + } + } else { + TORCH_CHECK( + false, + "cannot access PyObject for Tensor on interpreter ", + (*self_interpreter)->name(), + " that has already been used by another torch deploy interpreter ", + (*pyobj_interpreter_.load())->name()); + } + } + + // Clear the PyObject field for an interpreter, in situations where we + // statically know the tensor is tagged with our interpreter. + void unchecked_clear_pyobj(PyInterpreter* interpreter); + + PyInterpreter& load_pyobj_interpreter() const; + + // Check if the PyObjectSlot's interpreter is the same as the specified + // interpreter + bool check_interpreter(PyInterpreter* interpreter); + + // Check if the PyObjectSlot is holding a PyObject, owned or non-owned + bool has_pyobj_nonhermetic(); + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool owns_pyobj(); void set_owns_pyobj(bool b); diff --git a/c10/core/impl/SizesAndStrides.h b/c10/core/impl/SizesAndStrides.h index 6cc87e1d6be3e..aaa90b7d97aff 100644 --- a/c10/core/impl/SizesAndStrides.h +++ b/c10/core/impl/SizesAndStrides.h @@ -64,10 +64,13 @@ class C10_API SizesAndStrides { storageBytes(size_))); } +<<<<<<< HEAD bool operator!=(const SizesAndStrides& other) const { return !(*this == other); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SizesAndStrides& operator=(const SizesAndStrides& rhs) { if (this == &rhs) { return *this; diff --git a/c10/cuda/CUDAAllocatorConfig.cpp b/c10/cuda/CUDAAllocatorConfig.cpp index 8706f7362a3d2..1e9756c734f19 100644 --- a/c10/cuda/CUDAAllocatorConfig.cpp +++ b/c10/cuda/CUDAAllocatorConfig.cpp @@ -25,7 +25,10 @@ CUDAAllocatorConfig::CUDAAllocatorConfig() #endif m_release_lock_on_cudamalloc(false), m_pinned_use_cuda_host_register(false), +<<<<<<< HEAD m_graph_capture_record_stream_reuse(false), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) m_pinned_use_background_threads(false) { m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0); } @@ -374,9 +377,12 @@ void CUDAAllocatorConfig::parseArgs(const std::optional& env) { } else if (config_item_view == "pinned_use_background_threads") { i = parsePinnedUseBackgroundThreads(config, i); used_native_specific_option = true; +<<<<<<< HEAD } else if (config_item_view == "graph_capture_record_stream_reuse") { i = parseGraphCaptureRecordStreamReuse(config, i); used_native_specific_option = true; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else { TORCH_CHECK( false, "Unrecognized CachingAllocator option: ", config_item_view); @@ -410,6 +416,7 @@ size_t CUDAAllocatorConfig::parsePinnedUseCudaHostRegister( return i; } +<<<<<<< HEAD size_t CUDAAllocatorConfig::parseGraphCaptureRecordStreamReuse( const std::vector& config, size_t i) { @@ -427,6 +434,8 @@ size_t CUDAAllocatorConfig::parseGraphCaptureRecordStreamReuse( return i; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) size_t CUDAAllocatorConfig::parsePinnedNumRegisterThreads( const std::vector& config, size_t i) { diff --git a/c10/cuda/CUDAAllocatorConfig.h b/c10/cuda/CUDAAllocatorConfig.h index 54c41ba70fb6f..6d77e638e5a72 100644 --- a/c10/cuda/CUDAAllocatorConfig.h +++ b/c10/cuda/CUDAAllocatorConfig.h @@ -53,10 +53,13 @@ class C10_CUDA_API CUDAAllocatorConfig { return instance().m_release_lock_on_cudamalloc; } +<<<<<<< HEAD static bool graph_capture_record_stream_reuse() { return instance().m_graph_capture_record_stream_reuse; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) /** Pinned memory allocator settings */ static bool pinned_use_cuda_host_register() { return instance().m_pinned_use_cuda_host_register; @@ -146,9 +149,12 @@ class C10_CUDA_API CUDAAllocatorConfig { size_t parsePinnedUseBackgroundThreads( const std::vector& config, size_t i); +<<<<<<< HEAD size_t parseGraphCaptureRecordStreamReuse( const std::vector& config, size_t i); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::atomic m_max_split_size; std::atomic m_max_non_split_rounding_size; @@ -160,7 +166,10 @@ class C10_CUDA_API CUDAAllocatorConfig { m_expandable_segments_handle_type; std::atomic m_release_lock_on_cudamalloc; std::atomic m_pinned_use_cuda_host_register; +<<<<<<< HEAD std::atomic m_graph_capture_record_stream_reuse; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::atomic m_pinned_use_background_threads; std::string m_last_allocator_settings; std::mutex m_last_allocator_settings_mutex; diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 93ac4f7a4c649..46a1657a0e5de 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -373,12 +373,20 @@ struct ExpandableSegment { ExpandableSegment( c10::DeviceIndex device, std::optional stream, +<<<<<<< HEAD +======= + size_t address_space_size, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) size_t segment_size, std::vector peers) : device_(device), stream_(stream), // 2MB for small pool, 20MB for large pool segment_size_(segment_size), +<<<<<<< HEAD +======= + max_handles_(numSegments(address_space_size)), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) peers_(std::move(peers)) { cudaDeviceProp prop{}; C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_)); @@ -547,7 +555,15 @@ struct ExpandableSegment { ShareHeader header{}; buf.read((char*)&header, sizeof(ShareHeader)); auto segment = std::make_unique( +<<<<<<< HEAD device, std::nullopt, header.segment_size, std::move(peers)); +======= + device, + std::nullopt, + header.num_handles * header.segment_size, + header.segment_size, + std::move(peers)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // older build setups (e.g. multiwheels) do not have this syscall, added 2020 // but the kernel on the system might still support it. #ifndef SYS_pidfd_open @@ -745,6 +761,10 @@ struct ExpandableSegment { ExpandableSegment( c10::DeviceIndex device, std::optional stream, +<<<<<<< HEAD +======= + size_t address_space_size, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) size_t segment_size, std::vector peers) { TORCH_INTERNAL_ASSERT(false, "expandable segment not supported"); @@ -836,7 +856,12 @@ struct AllocParams { size_t size, cudaStream_t stream, BlockPool* pool, +<<<<<<< HEAD size_t alloc_size) +======= + size_t alloc_size, + DeviceStats& stats) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) : search_key(device, stream, size), pool(pool), alloc_size(alloc_size) {} c10::DeviceIndex device() const { @@ -1167,6 +1192,7 @@ class DeviceCachingAllocator { // tracks which pools we can use as a last resort before ooming ska::flat_hash_set use_on_oom_pools; +<<<<<<< HEAD // Map of blocks whose freeing is deferred until after CUDA graph capture. // - Key: Block* to be freed. // - Value: List of "empty nodes" inserted as free markers during capture. @@ -1174,6 +1200,10 @@ class DeviceCachingAllocator { // ends. ska::flat_hash_map> deferred_blocks; +======= + // See free() for this thing's purpose + std::vector needs_events_deferred_until_no_capture; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // outstanding cuda events ska::flat_hash_map< cuda::CUDAStream, @@ -1334,16 +1364,23 @@ class DeviceCachingAllocator { // capture. Cross-stream memory use is uncommon, so the deferral's // effect on memory use during capture should be small. process_events(context); +<<<<<<< HEAD } else { if (CUDAAllocatorConfig::graph_capture_record_stream_reuse()) { // We check if there is some block that is safe to reuse on this stream free_safe_blocks_in_capture(context, stream); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } size_t size = round_size(orig_size); auto& pool = get_pool(size, stream); const size_t alloc_size = get_allocation_size(size); +<<<<<<< HEAD AllocParams params(device, size, stream, &pool, alloc_size); +======= + AllocParams params(device, size, stream, &pool, alloc_size, stats); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) params.stat_types = get_stat_types_for_pool(pool); // First, try to get a block from the existing pool. @@ -1390,7 +1427,11 @@ class DeviceCachingAllocator { beginAllocateToPool(mempool_id, filter); auto& mempool = get_pool(size, stream); AllocParams mempool_params( +<<<<<<< HEAD device, size, stream, &mempool, alloc_size); +======= + device, size, stream, &mempool, alloc_size, stats); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mempool_params.stat_types = get_stat_types_for_pool(mempool); block_found = get_free_block(mempool_params); endAllocateToPool(mempool_id); @@ -1629,6 +1670,7 @@ class DeviceCachingAllocator { return block; } +<<<<<<< HEAD // Insert "free marker" (empty nodes) into the CUDA graph for all streams that // have used the block, including the allocation stream. These nodes mark the // last use of the block in the capture graph. Returns a vector of the @@ -1871,6 +1913,8 @@ class DeviceCachingAllocator { } } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void free(Block* block) { std::shared_ptr context = maybeGatherContext(RecordContext::ALL); @@ -1906,6 +1950,7 @@ class DeviceCachingAllocator { if (block->size >= CUDAAllocatorConfig::max_split_size()) stats.oversize_allocations.decrease(1); +<<<<<<< HEAD // If the block has been used on more than one stream, handle accordingly. if (!block->stream_uses.empty()) { if (C10_UNLIKELY(!captures_underway.empty())) { @@ -1922,6 +1967,16 @@ class DeviceCachingAllocator { } } else { // If not in a capture, insert events for the block. +======= + if (!block->stream_uses.empty()) { + if (C10_UNLIKELY(!captures_underway.empty())) { + // It's forbidden to cudaEventQuery an event recorded during CUDA graph + // capture. We conservatively defer recording end-of-life events until + // the next call to process_events() (which won't happen until no + // captures are underway) + needs_events_deferred_until_no_capture.push_back(block); + } else { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) insert_events(block); } } else { @@ -2181,7 +2236,12 @@ class DeviceCachingAllocator { block_state.size, block_state.stream, &pool, +<<<<<<< HEAD block_state.size); +======= + block_state.size, + stats); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pool.blocks.erase(curr_block); params.block = curr_block; params.stat_types = get_stat_types_for_pool(pool); @@ -2676,8 +2736,24 @@ class DeviceCachingAllocator { } } auto segment_size = pool->is_small ? kSmallBuffer : kLargeBuffer; +<<<<<<< HEAD expandable_segments_.emplace_back(new ExpandableSegment( device, stream, segment_size, devices_with_peer_access_)); +======= + cudaDeviceProp prop{}; + C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, device)); + // we allocate enough address space for 1 1/8 the total memory on the GPU. + // This allows for some cases where we have to unmap pages earlier in the + // segment to put them at the end. + size_t address_space_size = prop.totalGlobalMem + prop.totalGlobalMem / 8; + + expandable_segments_.emplace_back(new ExpandableSegment( + device, + stream, + address_space_size, + segment_size, + devices_with_peer_access_)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ExpandableSegment* es = expandable_segments_.back(); Block* candidate = new Block(device, stream, es->size(), pool, es->ptr()); @@ -3237,8 +3313,13 @@ class DeviceCachingAllocator { --it; } if (!(*cur)->expandable_segment_) { +<<<<<<< HEAD totalReleased += (*cur)->size; release_block(*cur, context); +======= + release_block(*cur, context); + totalReleased += (*cur)->size; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } if (is_first) { break; @@ -3547,8 +3628,13 @@ class DeviceCachingAllocator { void insert_events_deferred_until_no_capture( const std::shared_ptr& context) { +<<<<<<< HEAD if (C10_UNLIKELY(!deferred_blocks.empty())) { for (auto& [block, inserted_empty_nodes] : deferred_blocks) { +======= + if (C10_UNLIKELY(!needs_events_deferred_until_no_capture.empty())) { + for (auto* block : needs_events_deferred_until_no_capture) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_INTERNAL_ASSERT(!block->stream_uses.empty()); // only streams recorded before cudagraph will be used to insert events // since we know all streams recorded during cudagraph must have @@ -3560,7 +3646,11 @@ class DeviceCachingAllocator { free_block(block, context); } } +<<<<<<< HEAD deferred_blocks.clear(); +======= + needs_events_deferred_until_no_capture.clear(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } @@ -3991,8 +4081,11 @@ class NativeCachingAllocator : public CUDAAllocator { md.pinned_use_host_register = CUDAAllocatorConfig::pinned_use_cuda_host_register(); md.last_allocator_settings = CUDAAllocatorConfig::last_allocator_settings(); +<<<<<<< HEAD md.graph_capture_record_stream_reuse = CUDAAllocatorConfig::graph_capture_record_stream_reuse(); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) md.roundup_power2_divisions = CUDAAllocatorConfig::roundup_power2_divisions(); @@ -4421,6 +4514,7 @@ struct BackendStaticInitializer { BackendStaticInitializer() { auto r = parseEnvForBackend(); +<<<<<<< HEAD // Register this HIP allocator as the CUDA allocator to allow it to work // with both c10::GetAllocator(kCUDA) and c10::getDeviceAllocator(kCUDA) // APIs. We don't perform this masquerading inside @@ -4433,6 +4527,9 @@ struct BackendStaticInitializer { at::SetAllocator(c10::Device(HIP_MASQUERADING_AS_CUDA).type(), r, 0); allocator.store(r); #undef HIP_MASQUERADING_AS_CUDA +======= + allocator.store(r); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } }; @@ -4459,8 +4556,16 @@ std::atomic MemPool::uuid_{1}; MemPool::MemPool( CUDACachingAllocator::CUDAAllocator* allocator, bool is_user_created, +<<<<<<< HEAD bool use_on_oom) : allocator_(allocator), is_user_created_(is_user_created) { +======= + bool use_on_oom, + bool symmetric) + : allocator_(allocator), + is_user_created_(is_user_created), + symmetric_(symmetric) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (is_user_created_) { id_ = {0, uid_++}; } else { @@ -4483,6 +4588,13 @@ MempoolId_t MemPool::id() { return id_; } +<<<<<<< HEAD +======= +bool MemPool::is_symmetric() { + return symmetric_; +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CUDACachingAllocator::CUDAAllocator* MemPool::allocator() { return allocator_; } diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h index bfc486d69fcff..0a03c1332d0a1 100644 --- a/c10/cuda/CUDACachingAllocator.h +++ b/c10/cuda/CUDACachingAllocator.h @@ -163,7 +163,10 @@ struct AllocatorConfigInfo { bool expandable_segments; bool release_lock_on_malloc; bool pinned_use_host_register; +<<<<<<< HEAD bool graph_capture_record_stream_reuse; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::string last_allocator_settings; std::vector roundup_power2_divisions; }; @@ -203,24 +206,43 @@ struct ShareableHandle { std::string handle; }; +<<<<<<< HEAD class CUDAAllocator : public DeviceAllocator { +======= +class CUDAAllocator : public Allocator { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) public: virtual void* raw_alloc(size_t nbytes) = 0; virtual void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) = 0; virtual void raw_delete(void* ptr) = 0; virtual void init(int device_count) = 0; +<<<<<<< HEAD virtual double getMemoryFraction(c10::DeviceIndex device) = 0; virtual void setMemoryFraction(double fraction, c10::DeviceIndex device) = 0; +======= + virtual bool initialized() = 0; + virtual double getMemoryFraction(c10::DeviceIndex device) = 0; + virtual void setMemoryFraction(double fraction, c10::DeviceIndex device) = 0; + virtual void emptyCache(MempoolId_t mempool_id = {0, 0}) = 0; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) virtual void enable(bool value) = 0; virtual bool isEnabled() const = 0; virtual void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) = 0; virtual void* getBaseAllocation(void* ptr, size_t* size) = 0; +<<<<<<< HEAD // Keep for BC only virtual void recordStream(const DataPtr& ptr, CUDAStream stream) = 0; void recordStream(const DataPtr& ptr, c10::Stream stream) override { CUDAStream cuda_stream = CUDAStream(stream); recordStream(ptr, cuda_stream); } +======= + virtual void recordStream(const DataPtr&, CUDAStream stream) = 0; + virtual c10::CachingDeviceAllocator::DeviceStats getDeviceStats( + c10::DeviceIndex device) = 0; + virtual void resetAccumulatedStats(c10::DeviceIndex device) = 0; + virtual void resetPeakStats(c10::DeviceIndex device) = 0; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) virtual SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) = 0; virtual void beginAllocateToPool( c10::DeviceIndex device, @@ -525,10 +547,13 @@ inline void enablePeerAccess( namespace c10::cuda { +<<<<<<< HEAD // Keep BC only using c10::CaptureId_t; using c10::MempoolId_t; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // MemPool represents a pool of memory in a caching allocator. Currently, // it's just the ID of the pool object maintained in the CUDACachingAllocator. // @@ -539,7 +564,12 @@ struct C10_CUDA_API MemPool { MemPool( CUDACachingAllocator::CUDAAllocator* allocator = nullptr, bool is_user_created = true, +<<<<<<< HEAD bool use_on_oom = false); +======= + bool use_on_oom = false, + bool symmetric = false); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MemPool(const MemPool&) = delete; MemPool(MemPool&&) = default; MemPool& operator=(const MemPool&) = delete; @@ -547,6 +577,10 @@ struct C10_CUDA_API MemPool { ~MemPool(); MempoolId_t id(); +<<<<<<< HEAD +======= + bool is_symmetric(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CUDACachingAllocator::CUDAAllocator* allocator(); int use_count(); c10::DeviceIndex device(); @@ -558,6 +592,10 @@ struct C10_CUDA_API MemPool { CUDACachingAllocator::CUDAAllocator* allocator_; bool is_user_created_; MempoolId_t id_; +<<<<<<< HEAD +======= + bool symmetric_; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) c10::DeviceIndex device_; }; diff --git a/c10/cuda/CUDAException.cpp b/c10/cuda/CUDAException.cpp index 457d35f020bbe..6919e3c5c061d 100644 --- a/c10/cuda/CUDAException.cpp +++ b/c10/cuda/CUDAException.cpp @@ -28,9 +28,13 @@ void c10_cuda_check_implementation( std::string check_message; #ifndef STRIP_ERROR_MESSAGES check_message.append("CUDA error: "); +<<<<<<< HEAD const char* error_string = cudaGetErrorString(cuda_error); check_message.append(error_string); check_message.append(c10::cuda::get_cuda_error_help(cuda_error)); +======= + check_message.append(cudaGetErrorString(cuda_error)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) check_message.append(c10::cuda::get_cuda_check_suffix()); check_message.append("\n"); if (include_device_assertions) { diff --git a/c10/cuda/CUDAFunctions.cpp b/c10/cuda/CUDAFunctions.cpp index 9839e4e72049e..5b8f7fa7c94da 100644 --- a/c10/cuda/CUDAFunctions.cpp +++ b/c10/cuda/CUDAFunctions.cpp @@ -53,12 +53,20 @@ int device_count_impl(bool fail_if_no_driver) { "https://pytorch.org to install a PyTorch version that has been " "compiled with your version of the CUDA driver."); } +<<<<<<< HEAD } +======= + } break; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) case cudaErrorInitializationError: TORCH_CHECK( false, "CUDA driver initialization failed, you might not " "have a CUDA gpu."); +<<<<<<< HEAD +======= + break; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) case cudaErrorUnknown: TORCH_CHECK( false, @@ -66,6 +74,10 @@ int device_count_impl(bool fail_if_no_driver) { "incorrectly set up environment, e.g. changing env " "variable CUDA_VISIBLE_DEVICES after program start. " "Setting the available devices to be zero."); +<<<<<<< HEAD +======= + break; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #if C10_ASAN_ENABLED case cudaErrorMemoryAllocation: // In ASAN mode, we know that a cudaErrorMemoryAllocation error will @@ -78,6 +90,7 @@ int device_count_impl(bool fail_if_no_driver) { "would like to use GPUs, turn off ASAN."); break; #endif // C10_ASAN_ENABLED +<<<<<<< HEAD #if _WIN32 && CUDA_VERSION >= 13000 // Workaround for CUDA-13.0 error handling on Windows, see // https://github.com/pytorch/pytorch/issues/162333#issuecomment-3267929585 @@ -90,6 +103,8 @@ int device_count_impl(bool fail_if_no_driver) { break; } #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) default: TORCH_CHECK( false, diff --git a/c10/cuda/CUDAFunctions.h b/c10/cuda/CUDAFunctions.h index 543c866027464..61116becc971a 100644 --- a/c10/cuda/CUDAFunctions.h +++ b/c10/cuda/CUDAFunctions.h @@ -90,6 +90,7 @@ C10_CUDA_API void __inline__ memcpy_and_sync( (*interp)->trace_gpu_stream_synchronization( c10::kCUDA, reinterpret_cast(stream)); } +<<<<<<< HEAD #if defined(USE_ROCM) && USE_ROCM // As of ROCm 6.4.1, HIP runtime does not raise an error during capture of // hipMemcpyWithStream which is a synchronous call. Thus, we add a check @@ -101,6 +102,10 @@ C10_CUDA_API void __inline__ memcpy_and_sync( } else { C10_CUDA_CHECK(hipErrorStreamCaptureUnsupported); } +======= +#if defined(TORCH_HIP_VERSION) && (TORCH_HIP_VERSION >= 301) + C10_CUDA_CHECK(hipMemcpyWithStream(dst, src, nbytes, kind, stream)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #else C10_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, kind, stream)); C10_CUDA_CHECK(cudaStreamSynchronize(stream)); diff --git a/c10/cuda/CUDAGraphsC10Utils.h b/c10/cuda/CUDAGraphsC10Utils.h index 936875fd71d5c..26639395cf620 100644 --- a/c10/cuda/CUDAGraphsC10Utils.h +++ b/c10/cuda/CUDAGraphsC10Utils.h @@ -9,6 +9,15 @@ namespace c10::cuda { +<<<<<<< HEAD +======= +using CaptureId_t = unsigned long long; + +// first is set if the instance is created by CUDAGraph::capture_begin. +// second is set if the instance is created by at::cuda::graph_pool_handle. +using MempoolId_t = std::pair; + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // RAII guard for "cudaStreamCaptureMode", a thread-local value // that controls the error-checking strictness of a capture. struct C10_CUDA_API CUDAStreamCaptureModeGuard { diff --git a/c10/cuda/CUDAMiscFunctions.cpp b/c10/cuda/CUDAMiscFunctions.cpp index b1b6170f891e9..a4d75d098b8aa 100644 --- a/c10/cuda/CUDAMiscFunctions.cpp +++ b/c10/cuda/CUDAMiscFunctions.cpp @@ -1,5 +1,6 @@ #include #include +<<<<<<< HEAD #include #include #include @@ -25,6 +26,11 @@ std::string get_cuda_error_help(cudaError_t error) noexcept { return help_text; } +======= + +namespace c10::cuda { + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // NOLINTNEXTLINE(bugprone-exception-escape,-warnings-as-errors) const char* get_cuda_check_suffix() noexcept { static auto device_blocking_flag = diff --git a/c10/cuda/CUDAMiscFunctions.h b/c10/cuda/CUDAMiscFunctions.h index ec1114935457e..d216db81999d3 100644 --- a/c10/cuda/CUDAMiscFunctions.h +++ b/c10/cuda/CUDAMiscFunctions.h @@ -3,6 +3,7 @@ // CUDAExceptions.h #include +<<<<<<< HEAD #include #include @@ -10,6 +11,12 @@ namespace c10::cuda { C10_CUDA_API std::string get_cuda_error_help(cudaError_t) noexcept; +======= + +#include + +namespace c10::cuda { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) C10_CUDA_API const char* get_cuda_check_suffix() noexcept; C10_CUDA_API std::mutex* getFreeMutex(); } // namespace c10::cuda diff --git a/c10/cuda/CUDAStream.cpp b/c10/cuda/CUDAStream.cpp index 6d2b1e06fda9b..f464a182d0b58 100644 --- a/c10/cuda/CUDAStream.cpp +++ b/c10/cuda/CUDAStream.cpp @@ -147,7 +147,11 @@ static inline StreamIdType streamIdType(StreamId s) { // rightmost bit int mask_for_type = (1 << kStreamTypeBits) - 1; auto val = (s >> 1) & mask_for_type; +<<<<<<< HEAD TORCH_CHECK(val || !(s & 1), "invalid StreamId", s); +======= + TORCH_INTERNAL_ASSERT(val || !(s & 1), "invalid StreamId", s); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return StreamIdType(val); } @@ -216,6 +220,12 @@ static void initSingleStream(int p, DeviceIndex device_index, int i) { // Creates the low and high priority stream pools for the specified device // Warning: only call once per device! static void initDeviceStreamState(DeviceIndex device_index) { +<<<<<<< HEAD +======= + // Switches to the requested device so streams are properly associated + // with it. + CUDAGuard device_guard{device_index}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (const auto i : c10::irange(kStreamsPerPool)) { for (const auto p : c10::irange(max_stream_priorities)) { initSingleStream(p, device_index, i); @@ -276,7 +286,11 @@ cudaStream_t CUDAStream::stream() const { StreamIdType st = streamIdType(stream_id); size_t si = streamIdIndex(stream_id); if (st.isDefault()) { +<<<<<<< HEAD TORCH_CHECK( +======= + TORCH_INTERNAL_ASSERT( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) si == 0, "Unrecognized stream ", stream_, @@ -291,7 +305,11 @@ cudaStream_t CUDAStream::stream() const { return reinterpret_cast(stream_id); } else { auto streamType = st.getStreamType(); +<<<<<<< HEAD TORCH_CHECK( +======= + TORCH_INTERNAL_ASSERT( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) streamType >= 1 && streamType <= max_stream_priorities, "Unrecognized stream ", stream_, diff --git a/c10/cuda/driver_api.cpp b/c10/cuda/driver_api.cpp index d545bf5477b64..3acc80a21d34d 100644 --- a/c10/cuda/driver_api.cpp +++ b/c10/cuda/driver_api.cpp @@ -38,6 +38,7 @@ DriverAPI create_driver_api() { C10_NVML_DRIVER_API(LOOKUP_NVML_ENTRY) #undef LOOKUP_NVML_ENTRY } +<<<<<<< HEAD if (handle_1) { #define LOOKUP_NVML_ENTRY_OPTIONAL(name) \ @@ -45,6 +46,8 @@ DriverAPI create_driver_api() { C10_NVML_DRIVER_API_OPTIONAL(LOOKUP_NVML_ENTRY_OPTIONAL) #undef LOOKUP_NVML_ENTRY_OPTIONAL } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return r; } @@ -61,14 +64,20 @@ void* get_symbol(const char* name, int version) { } #endif +<<<<<<< HEAD // As of CUDA 13, this API is deprecated. #if defined(CUDA_VERSION) && (CUDA_VERSION < 13000) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // This fallback to the old API to try getting the symbol again. if (auto st = cudaGetDriverEntryPoint(name, &out, cudaEnableDefault, &qres); st == cudaSuccess && qres == cudaDriverEntryPointSuccess && out) { return out; } +<<<<<<< HEAD #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // If the symbol cannot be resolved, report and return nullptr; // the caller is responsible for checking the pointer. diff --git a/c10/cuda/driver_api.h b/c10/cuda/driver_api.h index 8910e581a1a4e..b95e4df0cf0f5 100644 --- a/c10/cuda/driver_api.h +++ b/c10/cuda/driver_api.h @@ -53,8 +53,12 @@ #define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) \ _(cuMulticastAddDevice, 12030) \ _(cuMulticastBindMem, 12030) \ +<<<<<<< HEAD _(cuMulticastCreate, 12030) \ _(cuMulticastUnbind, 12030) +======= + _(cuMulticastCreate, 12030) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #else #define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) #endif @@ -67,12 +71,15 @@ _(nvmlDeviceGetComputeRunningProcesses) \ _(nvmlSystemGetCudaDriverVersion_v2) +<<<<<<< HEAD #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12040) #define C10_NVML_DRIVER_API_OPTIONAL(_) _(nvmlDeviceGetGpuFabricInfoV) #else #define C10_NVML_DRIVER_API_OPTIONAL(_) #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace c10::cuda { struct DriverAPI { @@ -81,7 +88,10 @@ struct DriverAPI { C10_LIBCUDA_DRIVER_API_REQUIRED(CREATE_MEMBER_VERSIONED) C10_LIBCUDA_DRIVER_API_OPTIONAL(CREATE_MEMBER_VERSIONED) C10_NVML_DRIVER_API(CREATE_MEMBER) +<<<<<<< HEAD C10_NVML_DRIVER_API_OPTIONAL(CREATE_MEMBER) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #undef CREATE_MEMBER_VERSIONED #undef CREATE_MEMBER diff --git a/c10/macros/Export.h b/c10/macros/Export.h index 1b8a6811c53f5..dd1885243ebd9 100644 --- a/c10/macros/Export.h +++ b/c10/macros/Export.h @@ -1 +1,82 @@ +<<<<<<< HEAD #include +======= +#ifndef C10_MACROS_EXPORT_H_ +#define C10_MACROS_EXPORT_H_ + +#ifndef C10_USING_CUSTOM_GENERATED_MACROS +#include +#endif // C10_USING_CUSTOM_GENERATED_MACROS + +#include + +// This one is being used by libtorch.so +#ifdef CAFFE2_BUILD_MAIN_LIB +#define TORCH_API C10_EXPORT +#else +#define TORCH_API C10_IMPORT +#endif + +// You may be wondering: Whose brilliant idea was it to split torch_cuda into +// two pieces with confusing names? +// Once upon a time, there _was_ only TORCH_CUDA_API. All was happy until we +// tried to compile PyTorch for CUDA 11.1, which ran into relocation marker +// issues when linking big binaries. +// (https://github.com/pytorch/pytorch/issues/39968) We had two choices: +// (1) Stop supporting so many GPU architectures +// (2) Do something else +// We chose #2 and decided to split the behemoth that was torch_cuda into two +// smaller libraries, one with most of the core kernel functions (torch_cuda_cu) +// and the other that had..well..everything else (torch_cuda_cpp). The idea was +// this: instead of linking our static libraries (like the hefty +// libcudnn_static.a) with another huge library, torch_cuda, and run into pesky +// relocation marker issues, we could link our static libraries to a smaller +// part of torch_cuda (torch_cuda_cpp) and avoid the issues. + +// libtorch_cuda_cu.so +#ifdef TORCH_CUDA_CU_BUILD_MAIN_LIB +#define TORCH_CUDA_CU_API C10_EXPORT +#elif defined(BUILD_SPLIT_CUDA) +#define TORCH_CUDA_CU_API C10_IMPORT +#endif + +// libtorch_cuda_cpp.so +#ifdef TORCH_CUDA_CPP_BUILD_MAIN_LIB +#define TORCH_CUDA_CPP_API C10_EXPORT +#elif defined(BUILD_SPLIT_CUDA) +#define TORCH_CUDA_CPP_API C10_IMPORT +#endif + +// libtorch_cuda.so (where torch_cuda_cu and torch_cuda_cpp are a part of the +// same api) +#ifdef TORCH_CUDA_BUILD_MAIN_LIB +#define TORCH_CUDA_CPP_API C10_EXPORT +#define TORCH_CUDA_CU_API C10_EXPORT +#elif !defined(BUILD_SPLIT_CUDA) +#define TORCH_CUDA_CPP_API C10_IMPORT +#define TORCH_CUDA_CU_API C10_IMPORT +#endif + +#if defined(TORCH_HIP_BUILD_MAIN_LIB) +#define TORCH_HIP_CPP_API C10_EXPORT +#define TORCH_HIP_API C10_EXPORT +#else +#define TORCH_HIP_CPP_API C10_IMPORT +#define TORCH_HIP_API C10_IMPORT +#endif + +#if defined(TORCH_XPU_BUILD_MAIN_LIB) +#define TORCH_XPU_API C10_EXPORT +#else +#define TORCH_XPU_API C10_IMPORT +#endif + +// Enums only need to be exported on windows for non-CUDA files +#if defined(_WIN32) && defined(__CUDACC__) +#define C10_API_ENUM C10_API +#else +#define C10_API_ENUM +#endif + +#endif // C10_MACROS_EXPORT_H_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/c10/macros/Macros.h b/c10/macros/Macros.h index 87ebc4f422c4c..fbf0071ae91e8 100644 --- a/c10/macros/Macros.h +++ b/c10/macros/Macros.h @@ -1 +1,568 @@ +<<<<<<< HEAD #include +======= +#ifndef C10_MACROS_MACROS_H_ +#define C10_MACROS_MACROS_H_ +#include + +/* Main entry for c10/macros. + * + * In your code, include c10/macros/Macros.h directly, instead of individual + * files in this folder. + */ + +// For build systems that do not directly depend on CMake and directly build +// from the source directory (such as Buck), one may not have a cmake_macros.h +// file at all. In this case, the build system is responsible for providing +// correct macro definitions corresponding to the cmake_macros.h.in file. +// +// In such scenarios, one should define the macro +// C10_USING_CUSTOM_GENERATED_MACROS +// to inform this header that it does not need to include the cmake_macros.h +// file. + +#ifndef C10_USING_CUSTOM_GENERATED_MACROS +#include +#endif // C10_USING_CUSTOM_GENERATED_MACROS + +#include + +#if defined(__clang__) +#define __ubsan_ignore_float_divide_by_zero__ \ + __attribute__((no_sanitize("float-divide-by-zero"))) +#define __ubsan_ignore_undefined__ __attribute__((no_sanitize("undefined"))) +#define __ubsan_ignore_signed_int_overflow__ \ + __attribute__((no_sanitize("signed-integer-overflow"))) +#define __ubsan_ignore_pointer_overflow__ \ + __attribute__((no_sanitize("pointer-overflow"))) +#define __ubsan_ignore_function__ __attribute__((no_sanitize("function"))) +#define __ubsan_ignore_float_cast_overflow__ \ + __attribute__((no_sanitize("float-cast-overflow"))) +#else +#define __ubsan_ignore_float_divide_by_zero__ +#define __ubsan_ignore_undefined__ +#define __ubsan_ignore_signed_int_overflow__ +#define __ubsan_ignore_pointer_overflow__ +#define __ubsan_ignore_function__ +#define __ubsan_ignore_float_cast_overflow__ +#endif + +// Detect address sanitizer as some stuff doesn't work with it +#undef C10_ASAN_ENABLED + +// for clang +#if defined(__has_feature) +#if ((__has_feature(address_sanitizer))) +#define C10_ASAN_ENABLED 1 +#endif +#endif + +// for gcc +#if defined(__SANITIZE_ADDRESS__) +#if __SANITIZE_ADDRESS__ +#if !defined(C10_ASAN_ENABLED) +#define C10_ASAN_ENABLED 1 +#endif +#endif +#endif + +#if !defined(C10_ASAN_ENABLED) +#define C10_ASAN_ENABLED 0 +#endif + +// Detect undefined-behavior sanitizer (UBSAN) +#undef C10_UBSAN_ENABLED + +// for clang or gcc >= 14 +// NB: gcc 14 adds support for Clang's __has_feature +// https://gcc.gnu.org/gcc-14/changes.html +// gcc < 14 doesn't have a macro for UBSAN +// (e.g. __SANITIZE_UNDEFINED__ does not exist in gcc) +// https://github.com/google/sanitizers/issues/765 +#if defined(__has_feature) +#if ((__has_feature(undefined_behavior_sanitizer))) +#define C10_UBSAN_ENABLED 1 +#endif +#endif + +#if !defined(C10_UBSAN_ENABLED) +#define C10_UBSAN_ENABLED 0 +#endif + +// Disable the copy and assignment operator for a class. Note that this will +// disable the usage of the class in std containers. +#define C10_DISABLE_COPY_AND_ASSIGN(classname) \ + classname(const classname&) = delete; \ + classname& operator=(const classname&) = delete + +#define C10_CONCATENATE_IMPL(s1, s2) s1##s2 +#define C10_CONCATENATE(s1, s2) C10_CONCATENATE_IMPL(s1, s2) + +#define C10_MACRO_EXPAND(args) args + +#define C10_STRINGIZE_IMPL(x) #x +#define C10_STRINGIZE(x) C10_STRINGIZE_IMPL(x) + +/** + * C10_ANONYMOUS_VARIABLE(str) introduces a new identifier which starts with + * str and ends with a unique number. + */ +#ifdef __COUNTER__ +#define C10_UID __COUNTER__ +#define C10_ANONYMOUS_VARIABLE(str) C10_CONCATENATE(str, __COUNTER__) +#else +#define C10_UID __LINE__ +#define C10_ANONYMOUS_VARIABLE(str) C10_CONCATENATE(str, __LINE__) +#endif + +#ifdef __has_cpp_attribute +#define C10_HAS_CPP_ATTRIBUTE(x) __has_cpp_attribute(x) +#else +#define C10_HAS_CPP_ATTRIBUTE(x) (0) +#endif + +#ifndef FBCODE_CAFFE2 +/// DEPRECATED: Warn if a type or return value is discarded. +#define C10_NODISCARD [[nodiscard]] + +/// DEPRECATED: Suppress an unused variable. +#define C10_UNUSED [[maybe_unused]] +#endif + +#if !defined(__has_attribute) +#define __has_attribute(x) 0 +#endif + +// Direct port of LLVM_ATTRIBUTE_USED. +#if __has_attribute(used) +#define C10_USED __attribute__((__used__)) +#else +#define C10_USED +#endif + +#define C10_RESTRICT __restrict + +// Simply define the namespace, in case a dependent library want to refer to +// the c10 namespace but not any nontrivial files. +namespace c10 {} +namespace c10::cuda {} +namespace c10::hip {} +namespace c10::xpu {} + +// Since C10 is the core library for caffe2 (and aten), we will simply reroute +// all abstractions defined in c10 to be available in caffe2 as well. +// This is only for backwards compatibility. Please use the symbols from the +// c10 namespace where possible. +namespace caffe2 { +using namespace c10; +} +namespace at { +using namespace c10; +} +namespace at::cuda { +using namespace c10::cuda; +} // namespace at::cuda + +// WARNING!!! THIS IS A GIANT HACK!!! +// This line means you cannot simultaneously include c10/hip +// and c10/cuda and then use them from the at::cuda namespace. +// This is true in practice, because HIPIFY works inplace on +// files in ATen/cuda, so it assumes that c10::hip is available +// from at::cuda. This namespace makes that happen. When +// HIPIFY is no longer out-of-place, we can switch the cuda +// here to hip and everyone is happy. +namespace at::cuda { +using namespace c10::hip; +} // namespace at::cuda + +namespace at::xpu { +using namespace c10::xpu; +} // namespace at::xpu + +// C10_LIKELY/C10_UNLIKELY +// +// These macros provide parentheses, so you can use these macros as: +// +// if C10_LIKELY(some_expr) { +// ... +// } +// +// NB: static_cast to boolean is mandatory in C++, because __builtin_expect +// takes a long argument, which means you may trigger the wrong conversion +// without it. +// +#if defined(__GNUC__) || defined(__ICL) || defined(__clang__) +#define C10_LIKELY(expr) (__builtin_expect(static_cast(expr), 1)) +#define C10_UNLIKELY(expr) (__builtin_expect(static_cast(expr), 0)) +#else +#define C10_LIKELY(expr) (expr) +#define C10_UNLIKELY(expr) (expr) +#endif + +/// C10_NOINLINE - Functions whose declaration is annotated with this will not +/// be inlined. +#ifdef __GNUC__ +#define C10_NOINLINE __attribute__((noinline)) +#elif _MSC_VER +#define C10_NOINLINE __declspec(noinline) +#else +#define C10_NOINLINE +#endif + +#if defined(_MSC_VER) +#define C10_ALWAYS_INLINE __forceinline +#elif __has_attribute(always_inline) || defined(__GNUC__) +#define C10_ALWAYS_INLINE __attribute__((__always_inline__)) inline +#else +#define C10_ALWAYS_INLINE inline +#endif + +// Unlike C10_ALWAYS_INLINE, C10_ALWAYS_INLINE_ATTRIBUTE can be used +// on a lambda. +#if defined(_MSC_VER) +// MSVC 14.39 is reasonably recent and doesn't like +// [[msvc::forceinline]] on a lambda, so don't try to use it. +#define C10_ALWAYS_INLINE_ATTRIBUTE +#elif __has_attribute(always_inline) || defined(__GNUC__) +#define C10_ALWAYS_INLINE_ATTRIBUTE __attribute__((__always_inline__)) +#else +#define C10_ALWAYS_INLINE_ATTRIBUTE +#endif + +#if defined(_MSC_VER) +#define C10_ATTR_VISIBILITY_HIDDEN +#elif defined(__GNUC__) +#define C10_ATTR_VISIBILITY_HIDDEN __attribute__((__visibility__("hidden"))) +#else +#define C10_ATTR_VISIBILITY_HIDDEN +#endif + +#define C10_ERASE C10_ALWAYS_INLINE C10_ATTR_VISIBILITY_HIDDEN + +#include + +#ifdef __HIPCC__ +// Unlike CUDA, HIP requires a HIP header to be included for __host__ to work. +// We do this #include here so that C10_HOST_DEVICE and friends will Just Work. +// See https://github.com/ROCm/hip/issues/441 +#include +#endif + +#if defined(__CUDACC__) || defined(__HIPCC__) +// Designates functions callable from the host (CPU) and the device (GPU) +#define C10_HOST_DEVICE __host__ __device__ +#define C10_DEVICE __device__ +#define C10_HOST __host__ +// constants from +// (https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications) +// The maximum number of threads per multiprocessor is 1024 for Turing +// architecture (7.5), 1536 for Geforce Ampere (8.6)/Jetson Orin (8.7), and +// 2048 for all other architectures. You'll get warnings if you exceed these +// constants. Hence, the following macros adjust the input values from the user +// to resolve potential warnings. +#if __CUDA_ARCH__ == 750 +constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 1024; +#elif __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 870 || __CUDA_ARCH__ == 890 +constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 1536; +#else +constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 2048; +#endif +// CUDA_MAX_THREADS_PER_BLOCK is same for all architectures currently +constexpr uint32_t CUDA_MAX_THREADS_PER_BLOCK = 1024; +// CUDA_THREADS_PER_BLOCK_FALLBACK is the "canonical fallback" choice of block +// size. 256 is a good number for this fallback and should give good occupancy +// and versatility across all architectures. +constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256; +// NOTE: if you are thinking of constexpr-ify the inputs to launch bounds, it +// turns out that although __launch_bounds__ can take constexpr, it +// can't take a constexpr that has anything to do with templates. +// Currently we use launch_bounds that depend on template arguments in +// Loops.cuh, Reduce.cuh and LossCTC.cuh. Hence, C10_MAX_THREADS_PER_BLOCK +// and C10_MIN_BLOCKS_PER_SM are kept as macros. +// Suppose you were planning to write __launch_bounds__(a, b), based on your +// performance tuning on a modern GPU. Instead, you should write +// __launch_bounds__(C10_MAX_THREADS_PER_BLOCK(a), C10_MIN_BLOCKS_PER_SM(a, b)), +// which will also properly respect limits on old architectures. +#define C10_MAX_THREADS_PER_BLOCK(val) \ + (((val) <= CUDA_MAX_THREADS_PER_BLOCK) ? (val) \ + : CUDA_THREADS_PER_BLOCK_FALLBACK) +#define C10_MIN_BLOCKS_PER_SM(threads_per_block, blocks_per_sm) \ + ((((threads_per_block) * (blocks_per_sm) <= CUDA_MAX_THREADS_PER_SM) \ + ? (blocks_per_sm) \ + : ((CUDA_MAX_THREADS_PER_SM + (threads_per_block) - 1) / \ + (threads_per_block)))) +// C10_LAUNCH_BOUNDS is analogous to __launch_bounds__ +#define C10_LAUNCH_BOUNDS_0 \ + __launch_bounds__( \ + 256, 4) // default launch bounds that should give good occupancy and + // versatility across all architectures. +#define C10_LAUNCH_BOUNDS_1(max_threads_per_block) \ + __launch_bounds__((C10_MAX_THREADS_PER_BLOCK((max_threads_per_block)))) +#define C10_LAUNCH_BOUNDS_2(max_threads_per_block, min_blocks_per_sm) \ + __launch_bounds__( \ + (C10_MAX_THREADS_PER_BLOCK((max_threads_per_block))), \ + (C10_MIN_BLOCKS_PER_SM((max_threads_per_block), (min_blocks_per_sm)))) +#else +#define C10_HOST_DEVICE +#define C10_HOST +#define C10_DEVICE +#endif + +#if defined(USE_ROCM) +#define C10_HIP_HOST_DEVICE __host__ __device__ +#else +#define C10_HIP_HOST_DEVICE +#endif + +#if defined(USE_ROCM) +// C10_WARP_SIZE is only allowed for device code. +// Host code _must_ use at::cuda::warp_size() +// HIP header used to define warpSize as a constexpr that was either 32 or 64 +// depending on the target device, and then always set it to 64 for host code. +// Host pass of HIP compiler needs C10_WARP_SIZE defined to _something_ so we +// set it to something unreasonable to trigger obvious host code errors. + +namespace at::cuda { +TORCH_CUDA_CPP_API int warp_size(); +} +#ifdef __HIPCC__ +static inline int __host__ C10_WARP_SIZE_INTERNAL() { + return at::cuda::warp_size(); +} + +static inline constexpr int __device__ C10_WARP_SIZE_INTERNAL() { +#if defined(__GFX9__) + return 64; +#else // __GFX9__ + return 32; +#endif // __GFX9__ +} +#else // __HIPCC__ +inline int C10_WARP_SIZE_INTERNAL() { + return at::cuda::warp_size(); +} +#endif // __HIPCC__ + +#define C10_WARP_SIZE (C10_WARP_SIZE_INTERNAL()) +#define C10_WARP_SIZE_STATIC 64 + +#else // defined(USE_ROCM) +#define C10_WARP_SIZE 32 +#endif + +#if defined(_MSC_VER) && _MSC_VER <= 1900 +#define __func__ __FUNCTION__ +#endif + +// CUDA_KERNEL_ASSERT checks the assertion +// even when NDEBUG is defined. This is useful for important assertions in CUDA +// code that would otherwise be suppressed when building Release. +#if defined(__ANDROID__) || defined(__APPLE__) || defined(__FreeBSD__) +// Those platforms do not support assert() +#define CUDA_KERNEL_ASSERT(cond) +#define CUDA_KERNEL_ASSERT_MSG(cond, msg) +#define SYCL_KERNEL_ASSERT(cond) +#elif defined(_MSC_VER) +#if defined(NDEBUG) +extern "C" { +C10_IMPORT +#if defined(__SYCL_DEVICE_ONLY__) +extern SYCL_EXTERNAL void _wassert( + const wchar_t* wexpr, + const wchar_t* wfile, + unsigned line); +#else +#if defined(__CUDA_ARCH__) +__host__ __device__ +#endif // __CUDA_ARCH__ + void + _wassert(wchar_t const* _Message, wchar_t const* _File, unsigned _Line); +#endif // __SYCL_DEVICE_ONLY__ +} +#endif // NDEBUG +#define CUDA_KERNEL_ASSERT(cond) \ + if (C10_UNLIKELY(!(cond))) { \ + (void)(_wassert( \ + _CRT_WIDE(#cond), \ + _CRT_WIDE(__FILE__), \ + static_cast(__LINE__)), \ + 0); \ + } +// TODO: This doesn't assert the message because I (chilli) couldn't figure out +// a nice way to convert a char* to a wchar_t* +#define CUDA_KERNEL_ASSERT_MSG(cond, msg) \ + if (C10_UNLIKELY(!(cond))) { \ + (void)(_wassert( \ + _CRT_WIDE(#cond), \ + _CRT_WIDE(__FILE__), \ + static_cast(__LINE__)), \ + 0); \ + } +#define SYCL_KERNEL_ASSERT(cond) \ + if (C10_UNLIKELY(!(cond))) { \ + (void)(_wassert( \ + _CRT_WIDE(#cond), \ + _CRT_WIDE(__FILE__), \ + static_cast(__LINE__)), \ + 0); \ + } +#else // __APPLE__, _MSC_VER +#if defined(NDEBUG) +extern "C" { +#if defined(__SYCL_DEVICE_ONLY__) +extern SYCL_EXTERNAL void __assert_fail( + const char* expr, + const char* file, + unsigned int line, + const char* func); +#else // __SYCL_DEVICE_ONLY__ +#if (defined(__CUDA_ARCH__) && !(defined(__clang__) && defined(__CUDA__))) +// CUDA supports __assert_fail function which are common for both device +// and host side code. +__host__ __device__ +#endif + + // This forward declaration matching the declaration of __assert_fail + // exactly how it is in glibc in case parts of the program are compiled with + // different NDEBUG settings. Otherwise we might get 'ambiguous declaration' + // error. Note: On ROCm - this declaration serves for host side compilation. + void + __assert_fail( + const char* assertion, + const char* file, + unsigned int line, + const char* function) noexcept __attribute__((__noreturn__)); + +#endif // __SYCL_DEVICE_ONLY__ +} +#endif // NDEBUG +// ROCm disables kernel assert by default for performance considerations. +// Though ROCm supports __assert_fail, it uses kernel printf which has +// a non-negligible performance impact even if the assert condition is +// never triggered. We choose to use abort() instead which will still +// terminate the application but without a more useful error message. +#if !defined(C10_USE_ROCM_KERNEL_ASSERT) and defined(USE_ROCM) +#define CUDA_KERNEL_ASSERT(cond) \ + if C10_UNLIKELY (!(cond)) { \ + abort(); \ + } +#define CUDA_KERNEL_ASSERT_MSG(cond, msg) \ + if C10_UNLIKELY (!(cond)) { \ + abort(); \ + } +#define SYCL_KERNEL_ASSERT(cond) \ + if C10_UNLIKELY (!(cond)) { \ + abort(); \ + } +#else +#define CUDA_KERNEL_ASSERT(cond) \ + if (C10_UNLIKELY(!(cond))) { \ + __assert_fail( \ + #cond, __FILE__, static_cast(__LINE__), __func__); \ + } +#define CUDA_KERNEL_ASSERT_MSG(cond, msg) \ + if (C10_UNLIKELY(!(cond))) { \ + __assert_fail( \ + msg, __FILE__, static_cast(__LINE__), __func__); \ + } +#define SYCL_KERNEL_ASSERT(cond) \ + if (C10_UNLIKELY(!(cond))) { \ + __assert_fail( \ + #cond, __FILE__, static_cast(__LINE__), __func__); \ + } +#endif // C10_USE_ROCM_KERNEL_ASSERT and USE_ROCM +#endif // __APPLE__ + +#ifdef __APPLE__ +#include +#endif + +#if defined(__ANDROID__) +#define C10_ANDROID 1 +#define C10_MOBILE 1 +#elif ( \ + defined(__APPLE__) && \ + (TARGET_IPHONE_SIMULATOR || TARGET_OS_SIMULATOR || TARGET_OS_IPHONE)) +#define C10_IOS 1 +#define C10_MOBILE 1 +#endif // ANDROID / IOS + +#if defined(C10_MOBILE) && C10_MOBILE +#define C10_ALWAYS_INLINE_UNLESS_MOBILE inline +#else +#define C10_ALWAYS_INLINE_UNLESS_MOBILE C10_ALWAYS_INLINE +#endif + +#if !defined(FBCODE_CAFFE2) && !defined(C10_NODEPRECATED) +#define CONSTEXPR_EXCEPT_WIN_CUDA constexpr +#define C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA constexpr + +#define STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(field, val) \ + static constexpr const char field[] = val; +#define STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA(cls, field, val) +#endif // !defined(FBCODE_CAFFE2) && !defined(C10_NODEPRECATED) + +#ifndef HAS_DEMANGLE +#if defined(__ANDROID__) || defined(_WIN32) || defined(__EMSCRIPTEN__) +#define HAS_DEMANGLE 0 +#elif defined(__APPLE__) && \ + (TARGET_IPHONE_SIMULATOR || TARGET_OS_SIMULATOR || TARGET_OS_IPHONE) +#define HAS_DEMANGLE 0 +#else +#define HAS_DEMANGLE 1 +#endif +#endif // HAS_DEMANGLE + +#define _C10_PRAGMA__(string) _Pragma(#string) +#define _C10_PRAGMA_(string) _C10_PRAGMA__(string) + +#ifdef __clang__ +#define C10_CLANG_DIAGNOSTIC_PUSH() _Pragma("clang diagnostic push") +#define C10_CLANG_DIAGNOSTIC_POP() _Pragma("clang diagnostic pop") +#define C10_CLANG_DIAGNOSTIC_IGNORE(flag) \ + _C10_PRAGMA_(clang diagnostic ignored flag) +#define C10_CLANG_HAS_WARNING(flag) __has_warning(flag) +#else +#define C10_CLANG_DIAGNOSTIC_PUSH() +#define C10_CLANG_DIAGNOSTIC_POP() +#define C10_CLANG_DIAGNOSTIC_IGNORE(flag) +#define C10_CLANG_HAS_WARNING(flag) 0 +#endif + +#ifdef __clang__ + +#define C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED(warning) \ + _C10_PRAGMA_(clang diagnostic push) \ + _C10_PRAGMA_(clang diagnostic ignored "-Wunknown-warning-option") \ + _C10_PRAGMA_(clang diagnostic ignored warning) + +#define C10_DIAGNOSTIC_POP() _C10_PRAGMA_(clang diagnostic pop) + +#elif __GNUC__ + +#define C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED(warning) \ + _C10_PRAGMA_(GCC diagnostic push) \ + _C10_PRAGMA_(GCC diagnostic ignored "-Wpragmas") \ + _C10_PRAGMA_(GCC diagnostic ignored warning) + +#define C10_DIAGNOSTIC_POP() _C10_PRAGMA_(GCC diagnostic pop) + +#else + +#define C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED(warning) +#define C10_DIAGNOSTIC_POP() + +#endif + +// This macro is used to find older C++ compilers +// that don't support move optimization for return values. + +#if (defined(__GNUC__) && __GNUC__ < 13) || \ + (defined(__clang_major__) && __clang_major__ < 13) +#define C10_RETURN_MOVE_IF_OLD_COMPILER 1 +#else +#define C10_RETURN_MOVE_IF_OLD_COMPILER 0 +#endif + +#endif // C10_MACROS_MACROS_H_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/c10/macros/build.bzl b/c10/macros/build.bzl index d5809d36687d7..ef44cdf922adf 100644 --- a/c10/macros/build.bzl +++ b/c10/macros/build.bzl @@ -1,13 +1,20 @@ def define_targets(rules): rules.cc_library( name = "macros", +<<<<<<< HEAD +======= + srcs = [":cmake_macros_h"], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) hdrs = [ "Macros.h", # Despite the documentation in Macros.h, Export.h is included # directly by many downstream files. Thus, we declare it as a # public header in this file. "Export.h", +<<<<<<< HEAD "cmake_macros.h", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ], linkstatic = True, local_defines = ["C10_BUILD_MAIN_LIB"], @@ -17,6 +24,25 @@ def define_targets(rules): ], ) +<<<<<<< HEAD +======= + rules.cmake_configure_file( + name = "cmake_macros_h", + src = "cmake_macros.h.in", + out = "cmake_macros.h", + definitions = [ + "C10_BUILD_SHARED_LIBS", + "C10_USE_MSVC_STATIC_RUNTIME", + ] + rules.select({ + "//c10:using_gflags": ["C10_USE_GFLAGS"], + "//conditions:default": [], + }) + rules.select({ + "//c10:using_glog": ["C10_USE_GLOG"], + "//conditions:default": [], + }), + ) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) rules.filegroup( name = "headers", srcs = rules.glob( diff --git a/c10/macros/cmake_configure_file.bzl b/c10/macros/cmake_configure_file.bzl new file mode 100644 index 0000000000000..16d09cc9ee30c --- /dev/null +++ b/c10/macros/cmake_configure_file.bzl @@ -0,0 +1,65 @@ +# Forked from header_template_rule. header_template_rule is not +# compatible with our usage of select because its substitutions +# attribute is a dict, and dicts may not be appended with select. We +# get around this limitation by using a list as our substitutions. +def _cmake_configure_file_impl(ctx): + command = ["cat $1"] + for definition in ctx.attr.definitions: + command.append( + "| sed 's@#cmakedefine {}@#define {}@'".format( + definition, + definition, + ), + ) + + # Replace any that remain with /* #undef FOO */. + command.append("| sed -r 's@#cmakedefine ([A-Z0-9_]+)@/* #undef \\1 */@'") + command.append("> $2") + + ctx.actions.run_shell( + inputs = [ctx.file.src], + outputs = [ctx.outputs.out], + command = " ".join(command), + arguments = [ + ctx.file.src.path, + ctx.outputs.out.path, + ], + ) + return [ + # create a provider which says that this + # out file should be made available as a header + CcInfo(compilation_context = cc_common.create_compilation_context( + + # pass out the include path for finding this header + includes = depset([ctx.outputs.out.dirname, ctx.bin_dir.path]), + + # and the actual header here. + headers = depset([ctx.outputs.out]), + )), + ] + +cmake_configure_file = rule( + implementation = _cmake_configure_file_impl, + doc = """ +Mimics CMake's configure_file in Bazel. + +Args: + name: A unique name for this rule. + src: The input file template. + out: The generated output. + definitions: A mapping of identifier in template to its value. +""", + attrs = { + # We use attr.string_list for compatibility with select and + # config_setting. See the comment above _cmake_configure_file_impl + # for more information. + "definitions": attr.string_list(mandatory = True), + "out": attr.output(mandatory = True), + "src": attr.label( + mandatory = True, + allow_single_file = True, + ), + }, + # output_to_genfiles is required for header files. + output_to_genfiles = True, +) diff --git a/c10/macros/cmake_macros.h.in b/c10/macros/cmake_macros.h.in new file mode 100644 index 0000000000000..76c185b55236c --- /dev/null +++ b/c10/macros/cmake_macros.h.in @@ -0,0 +1,14 @@ +#ifndef C10_MACROS_CMAKE_MACROS_H_ +#define C10_MACROS_CMAKE_MACROS_H_ + +// Automatically generated header file for the C10 library. +// Do not include this file directly. Instead, include c10/macros/Macros.h. + +#cmakedefine C10_BUILD_SHARED_LIBS +#cmakedefine C10_USE_GLOG +#cmakedefine C10_USE_GFLAGS +#cmakedefine C10_USE_NUMA +#cmakedefine C10_USE_MSVC_STATIC_RUNTIME +#cmakedefine C10_USE_ROCM_KERNEL_ASSERT + +#endif // C10_MACROS_CMAKE_MACROS_H_ diff --git a/c10/metal/atomic.h b/c10/metal/atomic.h index d0cbc03916989..1bd646c9b7573 100644 --- a/c10/metal/atomic.h +++ b/c10/metal/atomic.h @@ -35,6 +35,7 @@ static inline void atomic_add_helper( device ::metal::atomic* data, long offset, T value) { +<<<<<<< HEAD constexpr auto elem_per_enum = sizeof(uint) / sizeof(T); auto ptr = data + (offset / elem_per_enum); auto old = ::metal::atomic_load_explicit(ptr, ::metal::memory_order_relaxed); @@ -45,6 +46,17 @@ static inline void atomic_add_helper( do { val.i = old; val.t[offset & (elem_per_enum - 1)] += value; +======= + auto ptr = data + (offset >> 1); + auto old = ::metal::atomic_load_explicit(ptr, ::metal::memory_order_relaxed); + union { + uint i; + T t[2]; + } val; + do { + val.i = old; + val.t[offset & 1] += value; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } while (!::metal::atomic_compare_exchange_weak_explicit( ptr, &old, @@ -57,6 +69,7 @@ template <> struct AtomicType { using type = ::metal::atomic; static inline void atomic_add(device type* data, long offset, half value) { +<<<<<<< HEAD atomic_add_helper(data, offset, value); } }; @@ -85,6 +98,13 @@ struct AtomicType { } }; +======= + atomic_add_helper(data, offset, value); + } +}; + +#if __METAL_VERSION__ >= 310 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template <> struct AtomicType { using type = ::metal::atomic; @@ -92,6 +112,10 @@ struct AtomicType { atomic_add_helper(data, offset, value); } }; +<<<<<<< HEAD +======= +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Metal supports atomic_store_explicit for bools, but // sizeof(::metal::atomic_bool) is 4 Therefore it could not be used to @@ -124,6 +148,7 @@ struct AtomicType { } }; +<<<<<<< HEAD // ComplexHalf atomic op template <> struct AtomicType { @@ -173,5 +198,7 @@ struct AtomicType { } }; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace metal } // namespace c10 diff --git a/c10/metal/common.h b/c10/metal/common.h index e4b4d1a38ca4e..265c55375ef60 100644 --- a/c10/metal/common.h +++ b/c10/metal/common.h @@ -2,6 +2,7 @@ // Set of global constants that could be shareable between CPU and Metal code #ifdef __METAL__ +<<<<<<< HEAD #include #define C10_METAL_CONSTEXPR constant constexpr #else @@ -9,6 +10,14 @@ #define C10_METAL_CONSTEXPR constexpr #endif +======= +#define C10_METAL_CONSTEXPR constant constexpr +#else +#define C10_METAL_CONSTEXPR constexpr +#endif + +#if !defined(__METAL__) || __METAL_VERSION__ >= 310 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #define C10_METAL_ALL_TYPES_FUNCTOR(_) \ _(Byte, 0) \ _(Char, 1) \ @@ -21,10 +30,27 @@ _(ComplexFloat, 9) \ _(Bool, 11) \ _(BFloat16, 15) +<<<<<<< HEAD +======= +#else +#define C10_METAL_ALL_TYPES_FUNCTOR(_) \ + _(Byte, 0) \ + _(Char, 1) \ + _(Short, 2) \ + _(Int, 3) \ + _(Long, 4) \ + _(Half, 5) \ + _(Float, 6) \ + _(ComplexHalf, 8) \ + _(ComplexFloat, 9) \ + _(Bool, 11) +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace c10 { namespace metal { C10_METAL_CONSTEXPR unsigned max_ndim = 16; +<<<<<<< HEAD C10_METAL_CONSTEXPR unsigned simdgroup_size = 32; #ifdef __METAL__ @@ -34,6 +60,8 @@ using array = ::metal::array; template using array = std::array; #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) enum class ScalarType { #define _DEFINE_ENUM_VAL_(_v, _n) _v = _n, diff --git a/c10/metal/indexing.h b/c10/metal/indexing.h index 9cfe65f6a03a8..804f8c1c8e229 100644 --- a/c10/metal/indexing.h +++ b/c10/metal/indexing.h @@ -186,8 +186,15 @@ inline T val_at_offs(constant void* ptr, long offs, ScalarType type) { return cast_to(val_at_offs(ptr, offs)); case ScalarType::Half: return cast_to(val_at_offs(ptr, offs)); +<<<<<<< HEAD case ScalarType::BFloat16: return cast_to(val_at_offs(ptr, offs)); +======= +#if __METAL_VERSION__ >= 310 + case ScalarType::BFloat16: + return cast_to(val_at_offs(ptr, offs)); +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Complex case ScalarType::ComplexHalf: return cast_to(val_at_offs(ptr, offs)); diff --git a/c10/metal/reduction_utils.h b/c10/metal/reduction_utils.h index 2d97820191663..0018bafaef04b 100644 --- a/c10/metal/reduction_utils.h +++ b/c10/metal/reduction_utils.h @@ -5,6 +5,7 @@ namespace c10 { namespace metal { +<<<<<<< HEAD namespace detail { template struct simd_type { @@ -24,10 +25,19 @@ struct simd_type { template inline ::metal::enable_if_t, T> simd_sum(T val) { return T(::metal::simd_sum(detail::simd_type_t(val))); +======= + +constant constexpr ushort simdgroup_size = 32; + +template +inline ::metal::enable_if_t, T> simd_sum(T val) { + return ::metal::simd_sum(val); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } template inline ::metal::enable_if_t, T> simd_prod(T val) { +<<<<<<< HEAD return T(::metal::simd_product(detail::simd_type_t(val))); } @@ -83,6 +93,9 @@ template < true> inline T simd_min(T val) { return ::metal::simd_min(val); +======= + return ::metal::simd_product(val); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // Metal does not support SIMD reductions over 64-bit types, but it could be @@ -97,7 +110,11 @@ inline ::metal::enable_if_t<::metal::is_same_v, T> simd_sum(T val) { val += as_type( ::metal::simd_shuffle_and_fill_down(as_type(val), int2(0), i)); } +<<<<<<< HEAD return simd_broadcast(val, 0); +======= + return as_type(::metal::simd_broadcast(as_type(val), 0)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } template @@ -106,6 +123,7 @@ inline ::metal::enable_if_t<::metal::is_same_v, T> simd_prod(T val) { val *= as_type( ::metal::simd_shuffle_and_fill_down(as_type(val), int2(0), i)); } +<<<<<<< HEAD return simd_broadcast(val, 0); } @@ -178,6 +196,9 @@ template inline c10::metal::pair simd_argmax(ARG_T val, IDX_T idx_val) { auto rc = simd_argmax(val); return {rc.first, simd_broadcast(idx_val, rc.second)}; +======= + return as_type(::metal::simd_broadcast(as_type(val), 0)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // Below algorithms are written with hardcoded assumption that simdgroup is 32 @@ -229,6 +250,7 @@ opmath_t threadgroup_prod( } template +<<<<<<< HEAD T threadgroup_max(threadgroup T* data, T val, unsigned idx, unsigned size) { auto rc = simd_max(val); if (idx % simdgroup_size == 0) { @@ -267,6 +289,8 @@ T threadgroup_min(threadgroup T* data, T val, unsigned idx, unsigned size) { } template +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) float3 threadgroup_welford_reduce(threadgroup T* data, unsigned size) { ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup); float m = data[0]; @@ -301,6 +325,7 @@ float3 threadgroup_welford_combine(threadgroup T* data, unsigned size) { return rc; } +<<<<<<< HEAD template IDX_T threadgroup_argmax( threadgroup ARG_T* arg_data, @@ -353,6 +378,54 @@ IDX_T threadgroup_argmin( } ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup); return idx_data[0]; +======= +template +T threadgroup_max(threadgroup T* data, unsigned size) { + // TODO: This should be moved to the callee + ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup); + T rc = data[0]; + for (unsigned idx = 1; idx < size; ++idx) { + rc = ::c10::metal::max(rc, data[idx]); + } + return rc; +} + +template +T threadgroup_min(threadgroup T* data, unsigned size) { + // TODO: This should be moved to the callee + ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup); + T rc = data[0]; + for (unsigned idx = 1; idx < size; ++idx) { + rc = ::c10::metal::min(rc, data[idx]); + } + return rc; +} + +template +int threadgroup_argmax(threadgroup T* data, unsigned size) { + // TODO: This should be moved to the callee + ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup); + int rc = 0; + for (unsigned idx = 1; idx < size; ++idx) { + if (data[idx] > data[rc]) { + rc = idx; + } + } + return rc; +} + +template +int threadgroup_argmin(threadgroup T* data, unsigned size) { + // TODO: This should be moved to the callee + ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup); + int rc = 0; + for (unsigned idx = 1; idx < size; ++idx) { + if (data[idx] < data[rc]) { + rc = idx; + } + } + return rc; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } // namespace metal diff --git a/c10/metal/special_math.h b/c10/metal/special_math.h index 29a45ff4c30b6..3737858785e75 100644 --- a/c10/metal/special_math.h +++ b/c10/metal/special_math.h @@ -1,7 +1,10 @@ // Implementation of specal math functions for Metal #pragma once #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include @@ -49,11 +52,14 @@ inline float erf(T x) { } template +<<<<<<< HEAD float erfc(T x) { return 1.0 - erf(x); } template +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inline float erfinv(T y) { /* coefficients in rational expansion */ constexpr float a[4] = {0.886226899, -1.645349621, 0.914624893, -0.140543331}; @@ -1565,7 +1571,11 @@ float chebyshev_polynomial_t_forward(T x, int64_t n) { float q = x; float r; +<<<<<<< HEAD for (int64_t k = 2; (k <= n) && !::metal::isnan(q); k++) { +======= + for (int64_t k = 2; k <= n; k++) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r = (x + x) * q - p; p = q; q = r; @@ -1609,7 +1619,11 @@ float chebyshev_polynomial_u_forward(T x, int64_t n) { auto p = 1.0; float r; +<<<<<<< HEAD for (int64_t k = 2; (k <= n) && !::metal::isnan(q); k++) { +======= + for (int64_t k = 2; k <= n; k++) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r = 2 * x * q - p; p = q; q = r; @@ -1662,7 +1676,11 @@ float chebyshev_polynomial_v_forward(T x, int64_t n) { auto p = 1.0; float r; +<<<<<<< HEAD for (int64_t k = 2; (k <= n) && !::metal::isnan(q); k++) { +======= + for (int64_t k = 2; k <= n; k++) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r = 2 * x * q - p; p = q; q = r; @@ -1719,7 +1737,11 @@ float chebyshev_polynomial_w_forward(T x, int64_t n) { auto p = 1.0; float r; +<<<<<<< HEAD for (int64_t k = 2; (k <= n) && !::metal::isnan(q); k++) { +======= + for (int64_t k = 2; k <= n; k++) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r = 2.0 * x * q - p; p = q; q = r; @@ -1729,6 +1751,7 @@ float chebyshev_polynomial_w_forward(T x, int64_t n) { } // chebyshev_polynomial_w_forward(T x, int64_t n) template +<<<<<<< HEAD float shifted_chebyshev_polynomial_t_forward(T x, int64_t n) { if (n < 0) { return 0.0; @@ -1930,6 +1953,8 @@ float shifted_chebyshev_polynomial_w_forward(T x, int64_t n) { } // shifted_chebyshev_polynomial_w_forward(T x, int64_t n) template +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // TODO: Add 512 if/when double will be supported in Metal inline constexpr int getHermitianLimit() { return 128; diff --git a/c10/metal/utils.h b/c10/metal/utils.h index aaa0e1741240d..f2a1b2114502b 100644 --- a/c10/metal/utils.h +++ b/c10/metal/utils.h @@ -24,12 +24,20 @@ struct vectypes { using type2 = half2; }; +<<<<<<< HEAD +======= +#if __METAL_VERSION__ >= 310 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template <> struct vectypes { using type4 = bfloat4; using type3 = bfloat3; using type2 = bfloat2; }; +<<<<<<< HEAD +======= +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template <> struct vectypes { @@ -77,10 +85,18 @@ struct OpMathType { using type = int; }; +<<<<<<< HEAD +======= +#if __METAL_VERSION__ >= 310 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template <> struct OpMathType { using type = float; }; +<<<<<<< HEAD +======= +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Type promotion structure for higher precision accumulation template @@ -94,11 +110,19 @@ struct AccumulationType { using type = float; }; +<<<<<<< HEAD +======= +#if __METAL_VERSION__ >= 310 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Specialization for bfloat - promote to float for accumulation template <> struct AccumulationType { using type = float; }; +<<<<<<< HEAD +======= +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace detail @@ -124,6 +148,10 @@ min(T a, U b) { return ::metal::min(a, static_cast(b)); } +<<<<<<< HEAD +======= +#if __METAL_VERSION__ >= 310 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template <> inline bfloat min(bfloat a, bfloat b) { return bfloat( @@ -135,6 +163,10 @@ inline bfloat max(bfloat a, bfloat b) { return bfloat( ::metal::isunordered(a, b) ? NAN : ::metal::max(float(a), float(b))); } +<<<<<<< HEAD +======= +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template using vec2type_t = typename detail::vectypes::type2; @@ -322,11 +354,14 @@ inline float log1p(float x) { return rc; } +<<<<<<< HEAD template struct pair { T1 first; T2 second; }; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace metal } // namespace c10 diff --git a/c10/ovrsource_defs.bzl b/c10/ovrsource_defs.bzl index aafe5a4de8c42..ab5e5f1a6f9d6 100644 --- a/c10/ovrsource_defs.bzl +++ b/c10/ovrsource_defs.bzl @@ -63,6 +63,10 @@ def define_c10_ovrsource(name, is_mobile): "core/impl/*.h", ]), reexport_all_header_dependencies = False, +<<<<<<< HEAD +======= + # tests = C10_CPU_TEST_TARGETS, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) visibility = [ "//xplat/caffe2/c10:c10_ovrsource", ], @@ -73,7 +77,12 @@ def define_c10_ovrsource(name, is_mobile): ], }), exported_deps = [ +<<<<<<< HEAD "//xplat/caffe2/torch/headeronly:torch_headeronly_ovrsource", +======= + "//xplat/caffe2/torch/headeronly:torch_headeronly", + ":ovrsource_c10_cmake_macros.h", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "//arvr/third-party/gflags:gflags", "//third-party/cpuinfo:cpuinfo", "//third-party/fmt:fmt", @@ -82,6 +91,77 @@ def define_c10_ovrsource(name, is_mobile): ) def define_ovrsource_targets(): +<<<<<<< HEAD +======= + # C10_CPU_TEST_FILES = native.glob([ + # "test/core/*.cpp", + # "test/util/*.cpp", + # ]) + + # C10_GPU_TEST_FILES = native.glob([ + # "cuda/test/**/*.cpp", + # ]) + + # C10_CPU_TEST_TARGETS = [ + # ":" + paths.basename(test)[:-len(".cpp")] + "_ovrsource" + # for test in C10_CPU_TEST_FILES + # ] + + # C10_GPU_TEST_TARGETS = [ + # ":" + paths.basename(test)[:-len(".cpp")] + "_ovrsource" + # for test in C10_GPU_TEST_FILES + # ] + + common_c10_cmake_defines = [ + ("#cmakedefine C10_BUILD_SHARED_LIBS", ""), + ("#cmakedefine C10_USE_NUMA", ""), + ("#cmakedefine C10_USE_MSVC_STATIC_RUNTIME", ""), + ("#cmakedefine C10_USE_ROCM_KERNEL_ASSERT", ""), + ] + + mobile_c10_cmake_defines = [ + ("#cmakedefine C10_USE_GLOG", ""), + ("#cmakedefine C10_USE_GFLAGS", ""), + ] + + non_mobile_c10_cmake_defines = [ + ("#cmakedefine C10_USE_GLOG", "#define C10_USE_GLOG 1"), + ("#cmakedefine C10_USE_GFLAGS", "#define C10_USE_GFLAGS 1"), + ] + + gen_cmake_header( + src = "macros/cmake_macros.h.in", + defines = common_c10_cmake_defines + mobile_c10_cmake_defines, + header = "c10/macros/cmake_macros.h", + prefix = "ovrsource_c10_mobile_", + ) + + gen_cmake_header( + src = "macros/cmake_macros.h.in", + defines = common_c10_cmake_defines + non_mobile_c10_cmake_defines, + header = "c10/macros/cmake_macros.h", + prefix = "ovrsource_c10_non_mobile_", + ) + + oxx_static_library( + name = "ovrsource_c10_cmake_macros.h", + compatible_with = [ + "ovr_config//os:android", + "ovr_config//os:iphoneos", + "ovr_config//os:linux", + "ovr_config//os:macos", + "ovr_config//os:windows", + ], + deps = select({ + "ovr_config//os:android": [":ovrsource_c10_mobile_cmake_macros.h"], + "ovr_config//os:iphoneos": [":ovrsource_c10_mobile_cmake_macros.h"], + "ovr_config//os:linux": [":ovrsource_c10_non_mobile_cmake_macros.h"], + "ovr_config//os:macos": [":ovrsource_c10_non_mobile_cmake_macros.h"], + "ovr_config//os:windows": [":ovrsource_c10_non_mobile_cmake_macros.h"], + }), + ) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) c10_cuda_macros = gen_cmake_header( src = "cuda/impl/cuda_cmake_macros.h.in", defines = [ @@ -137,6 +217,10 @@ def define_ovrsource_targets(): "cuda/impl/*.h", ]), reexport_all_header_dependencies = False, +<<<<<<< HEAD +======= + # tests = C10_GPU_TEST_TARGETS, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) visibility = ["PUBLIC"], deps = [ "//third-party/cuda:libcuda", @@ -146,3 +230,67 @@ def define_ovrsource_targets(): ":c10_ovrsource", ], ) +<<<<<<< HEAD +======= + + # [ + # oxx_test( + # name = paths.basename(test)[:-len(".cpp")] + "_ovrsource", + # srcs = [test], + # compatible_with = cpu_supported_platforms, + # compiler_flags = select({ + # "DEFAULT": [], + # "ovr_config//compiler:cl": [ + # "/w", + # ], + # "ovr_config//compiler:clang": [ + # "-Wno-error", + # "-Wno-self-assign-overloaded", + # "-Wno-self-move", + # "-Wno-shadow", + # "-Wno-undef", + # "-Wno-unused-function", + # "-Wno-unused-variable", + # ], + # }), + # framework = "gtest", + # oncall = "ovrsource_pytorch", + # raw_headers = native.glob([ + # "test/**/*.h", + # ]), + # deps = [ + # ":c10_ovrsource", + # ], + # ) + # for test in C10_CPU_TEST_FILES + # ] + + # [ + # oxx_test( + # name = paths.basename(test)[:-len(".cpp")] + "_ovrsource", + # srcs = [test], + # compatible_with = cuda_supported_platforms, + # compiler_flags = select({ + # "DEFAULT": [], + # "ovr_config//compiler:cl": [ + # "/w", + # ], + # "ovr_config//compiler:clang": [ + # "-Wno-error", + # ], + # }), + # framework = "gtest", + # oncall = "ovrsource_pytorch", + # raw_headers = native.glob([ + # "test/**/*.h", + # ]), + # runtime_shared_libraries = [ + # "//third-party/cuda:cudart", + # ], + # deps = [ + # ":c10_cuda_ovrsource", + # ], + # ) + # for test in C10_GPU_TEST_FILES + # ] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/c10/test/build.bzl b/c10/test/build.bzl index deb917dd8fcf3..baee2d5357159 100644 --- a/c10/test/build.bzl +++ b/c10/test/build.bzl @@ -46,7 +46,11 @@ def define_targets(rules): "util/typeid_test.cpp", ], ), +<<<<<<< HEAD copts = ["-Wno-deprecated-declarations", "-Wno-ctad-maybe-unsupported"], +======= + copts = ["-Wno-deprecated-declarations"], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) deps = [ ":Macros", ":complex_math_test_common", diff --git a/c10/test/core/SymInt_test.cpp b/c10/test/core/SymInt_test.cpp index e408543f5362c..de177a63e99cf 100644 --- a/c10/test/core/SymInt_test.cpp +++ b/c10/test/core/SymInt_test.cpp @@ -1,6 +1,9 @@ #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include @@ -36,6 +39,7 @@ TEST(SymIntTest, Overflows) { } #endif +<<<<<<< HEAD namespace { // We need a SymNodeImpl that 1) has working arithmetic with @@ -201,4 +205,6 @@ TEST(SymIntTest, MinMax) { test_operator(); test_operator(); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif diff --git a/c10/test/util/generic_math_test.cpp b/c10/test/util/generic_math_test.cpp index 461d55819c65c..e1872b52b32fd 100644 --- a/c10/test/util/generic_math_test.cpp +++ b/c10/test/util/generic_math_test.cpp @@ -14,6 +14,9 @@ TEST(GenericMathTest, div_floor_test) { EXPECT_DOUBLE_EQ(c10::div_floor_floating(5., -2.), -3.); EXPECT_EQ(c10::div_floor_integer(5, 2), 2); EXPECT_EQ(c10::div_floor_integer(5, -2), -3); +<<<<<<< HEAD EXPECT_EQ(c10::div_mod(-9, -3), 0); EXPECT_EQ(c10::div_mod(-9., -3.), 0.); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } diff --git a/c10/util/BFloat16-inl.h b/c10/util/BFloat16-inl.h index 6d3510cd5be83..46f04ba907d09 100644 --- a/c10/util/BFloat16-inl.h +++ b/c10/util/BFloat16-inl.h @@ -1 +1,344 @@ +<<<<<<< HEAD #include +======= +#pragma once + +#include +#include + +#include + +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +#if defined(CL_SYCL_LANGUAGE_VERSION) +#include // for SYCL 1.2.1 +#elif defined(SYCL_LANGUAGE_VERSION) +#include // for SYCL 2020 +#endif + +namespace c10 { + +/// Constructors +inline C10_HOST_DEVICE BFloat16::BFloat16(float value) + : +#if defined(__CUDACC__) && !defined(USE_ROCM) && defined(__CUDA_ARCH__) && \ + __CUDA_ARCH__ >= 800 + x(__bfloat16_as_ushort(__float2bfloat16(value))) +#elif defined(__SYCL_DEVICE_ONLY__) && \ + defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) + x(c10::bit_cast(sycl::ext::oneapi::bfloat16(value))) +#else + // RNE by default + x(detail::round_to_nearest_even(value)) +#endif +{ +} + +/// Implicit conversions +inline C10_HOST_DEVICE BFloat16::operator float() const { +#if defined(__CUDACC__) && !defined(USE_ROCM) + return __bfloat162float(*reinterpret_cast(&x)); +#elif defined(__SYCL_DEVICE_ONLY__) && \ + defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) + return float(*reinterpret_cast(&x)); +#else + return detail::f32_from_bits(x); +#endif +} + +#if defined(__CUDACC__) && !defined(USE_ROCM) +inline C10_HOST_DEVICE BFloat16::BFloat16(const __nv_bfloat16& value) { + x = *reinterpret_cast(&value); +} +inline C10_HOST_DEVICE BFloat16::operator __nv_bfloat16() const { + return *reinterpret_cast(&x); +} +#endif + +#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) +inline C10_HOST_DEVICE BFloat16::BFloat16( + const sycl::ext::oneapi::bfloat16& value) { + x = *reinterpret_cast(&value); +} +inline C10_HOST_DEVICE BFloat16::operator sycl::ext::oneapi::bfloat16() const { + return *reinterpret_cast(&x); +} +#endif + +// CUDA intrinsics + +#if defined(__CUDACC__) || defined(__HIPCC__) +inline C10_DEVICE BFloat16 __ldg(const BFloat16* ptr) { +#if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + return __ldg(reinterpret_cast(ptr)); +#else + return *ptr; +#endif +} +#endif + +/// Arithmetic + +inline C10_HOST_DEVICE BFloat16 +operator+(const BFloat16& a, const BFloat16& b) { + return static_cast(a) + static_cast(b); +} + +inline C10_HOST_DEVICE BFloat16 +operator-(const BFloat16& a, const BFloat16& b) { + return static_cast(a) - static_cast(b); +} + +inline C10_HOST_DEVICE BFloat16 +operator*(const BFloat16& a, const BFloat16& b) { + return static_cast(a) * static_cast(b); +} + +inline C10_HOST_DEVICE BFloat16 operator/(const BFloat16& a, const BFloat16& b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / static_cast(b); +} + +inline C10_HOST_DEVICE BFloat16 operator-(const BFloat16& a) { + return -static_cast(a); +} + +inline C10_HOST_DEVICE BFloat16& operator+=(BFloat16& a, const BFloat16& b) { + a = a + b; + return a; +} + +inline C10_HOST_DEVICE BFloat16& operator-=(BFloat16& a, const BFloat16& b) { + a = a - b; + return a; +} + +inline C10_HOST_DEVICE BFloat16& operator*=(BFloat16& a, const BFloat16& b) { + a = a * b; + return a; +} + +inline C10_HOST_DEVICE BFloat16& operator/=(BFloat16& a, const BFloat16& b) { + a = a / b; + return a; +} + +inline C10_HOST_DEVICE BFloat16& operator|(BFloat16& a, const BFloat16& b) { + a.x = a.x | b.x; + return a; +} + +inline C10_HOST_DEVICE BFloat16& operator^(BFloat16& a, const BFloat16& b) { + a.x = a.x ^ b.x; + return a; +} + +inline C10_HOST_DEVICE BFloat16& operator&(BFloat16& a, const BFloat16& b) { + a.x = a.x & b.x; + return a; +} + +/// Arithmetic with floats + +inline C10_HOST_DEVICE float operator+(BFloat16 a, float b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE float operator-(BFloat16 a, float b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE float operator*(BFloat16 a, float b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE float operator/(BFloat16 a, float b) { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE float operator+(float a, BFloat16 b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE float operator-(float a, BFloat16 b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE float operator*(float a, BFloat16 b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE float operator/(float a, BFloat16 b) { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE float& operator+=(float& a, const BFloat16& b) { + return a += static_cast(b); +} +inline C10_HOST_DEVICE float& operator-=(float& a, const BFloat16& b) { + return a -= static_cast(b); +} +inline C10_HOST_DEVICE float& operator*=(float& a, const BFloat16& b) { + return a *= static_cast(b); +} +inline C10_HOST_DEVICE float& operator/=(float& a, const BFloat16& b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline C10_HOST_DEVICE double operator+(BFloat16 a, double b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE double operator-(BFloat16 a, double b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE double operator*(BFloat16 a, double b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE double operator/(BFloat16 a, double b) { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE double operator+(double a, BFloat16 b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE double operator-(double a, BFloat16 b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE double operator*(double a, BFloat16 b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE double operator/(double a, BFloat16 b) { + return a / static_cast(b); +} + +/// Arithmetic with ints + +inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int b) { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE BFloat16 operator+(int a, BFloat16 b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE BFloat16 operator-(int a, BFloat16 b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE BFloat16 operator*(int a, BFloat16 b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE BFloat16 operator/(int a, BFloat16 b) { + return static_cast(a) / b; +} + +//// Arithmetic with int64_t + +inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int64_t b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int64_t b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int64_t b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int64_t b) { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE BFloat16 operator+(int64_t a, BFloat16 b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE BFloat16 operator-(int64_t a, BFloat16 b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE BFloat16 operator*(int64_t a, BFloat16 b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE BFloat16 operator/(int64_t a, BFloat16 b) { + return static_cast(a) / b; +} + +// Overloading < and > operators, because std::max and std::min use them. + +inline C10_HOST_DEVICE bool operator>(BFloat16& lhs, BFloat16& rhs) { + return float(lhs) > float(rhs); +} + +inline C10_HOST_DEVICE bool operator<(BFloat16& lhs, BFloat16& rhs) { + return float(lhs) < float(rhs); +} + +} // namespace c10 + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_signed = true; + static constexpr bool is_specialized = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = true; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = true; + static constexpr auto has_denorm = numeric_limits::has_denorm; + static constexpr auto has_denorm_loss = + numeric_limits::has_denorm_loss; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 8; + static constexpr int digits10 = 2; + static constexpr int max_digits10 = 4; + static constexpr int radix = 2; + static constexpr int min_exponent = -125; + static constexpr int min_exponent10 = -37; + static constexpr int max_exponent = 128; + static constexpr int max_exponent10 = 38; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = + numeric_limits::tinyness_before; + + static constexpr c10::BFloat16 min() { + return c10::BFloat16(0x0080, c10::BFloat16::from_bits()); + } + static constexpr c10::BFloat16 lowest() { + return c10::BFloat16(0xFF7F, c10::BFloat16::from_bits()); + } + static constexpr c10::BFloat16 max() { + return c10::BFloat16(0x7F7F, c10::BFloat16::from_bits()); + } + static constexpr c10::BFloat16 epsilon() { + return c10::BFloat16(0x3C00, c10::BFloat16::from_bits()); + } + static constexpr c10::BFloat16 round_error() { + return c10::BFloat16(0x3F00, c10::BFloat16::from_bits()); + } + static constexpr c10::BFloat16 infinity() { + return c10::BFloat16(0x7F80, c10::BFloat16::from_bits()); + } + static constexpr c10::BFloat16 quiet_NaN() { + return c10::BFloat16(0x7FC0, c10::BFloat16::from_bits()); + } + static constexpr c10::BFloat16 signaling_NaN() { + return c10::BFloat16(0x7F80, c10::BFloat16::from_bits()); + } + static constexpr c10::BFloat16 denorm_min() { + return c10::BFloat16(0x0001, c10::BFloat16::from_bits()); + } +}; + +} // namespace std + +C10_CLANG_DIAGNOSTIC_POP() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/c10/util/BFloat16.h b/c10/util/BFloat16.h index 6d3510cd5be83..f4bbb9fafe5e6 100644 --- a/c10/util/BFloat16.h +++ b/c10/util/BFloat16.h @@ -1 +1,127 @@ +<<<<<<< HEAD #include +======= +#pragma once + +// Defines the bloat16 type (brain floating-point). This representation uses +// 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa. + +#include +#include +#include +#include +#include +#include + +#if defined(__CUDACC__) && !defined(USE_ROCM) +#include +#endif + +#if defined(CL_SYCL_LANGUAGE_VERSION) +#include // for SYCL 1.2.1 +#elif defined(SYCL_LANGUAGE_VERSION) +#include // for SYCL 2020 +#endif + +namespace c10 { + +namespace detail { +inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) { + float res = 0; + uint32_t tmp = src; + tmp <<= 16; + +#if defined(USE_ROCM) && defined(__HIPCC__) + float* tempRes; + + // We should be using memcpy in order to respect the strict aliasing rule + // but it fails in the HIP environment. + tempRes = reinterpret_cast(&tmp); + res = *tempRes; +#else + std::memcpy(&res, &tmp, sizeof(tmp)); +#endif + + return res; +} + +inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) { + uint32_t res = 0; + +#if defined(USE_ROCM) && defined(__HIPCC__) + // We should be using memcpy in order to respect the strict aliasing rule + // but it fails in the HIP environment. + uint32_t* tempRes = reinterpret_cast(&src); + res = *tempRes; +#else + std::memcpy(&res, &src, sizeof(res)); +#endif + + return res >> 16; +} + +inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) { +#if defined(USE_ROCM) && defined(__HIPCC__) + if (src != src) { +#elif defined(_MSC_VER) + if (isnan(src)) { +#else + if (std::isnan(src)) { +#endif + return UINT16_C(0x7FC0); + } else { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) + union { + uint32_t U32; // NOLINT(facebook-hte-BadMemberName) + float F32; // NOLINT(facebook-hte-BadMemberName) + }; + + F32 = src; + uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); + return static_cast((U32 + rounding_bias) >> 16); + } +} +} // namespace detail + +struct alignas(2) BFloat16 { + uint16_t x; + + // HIP wants __host__ __device__ tag, CUDA does not +#if defined(USE_ROCM) && defined(__HIPCC__) + C10_HOST_DEVICE BFloat16() = default; +#else + BFloat16() = default; +#endif + + struct from_bits_t {}; + static constexpr C10_HOST_DEVICE from_bits_t from_bits() { + return from_bits_t(); + } + + constexpr C10_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t) + : x(bits) {} + /* implicit */ inline C10_HOST_DEVICE BFloat16(float value); + inline C10_HOST_DEVICE operator float() const; + +#if defined(__CUDACC__) && !defined(USE_ROCM) + inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16& value); + explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const; +#endif + +#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) + inline C10_HOST_DEVICE BFloat16(const sycl::ext::oneapi::bfloat16& value); + explicit inline C10_HOST_DEVICE operator sycl::ext::oneapi::bfloat16() const; +#endif +}; + +C10_API inline std::ostream& operator<<( + std::ostream& out, + const BFloat16& value) { + out << (float)value; + return out; +} + +} // namespace c10 + +#include // IWYU pragma: keep +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/c10/util/Exception.h b/c10/util/Exception.h index 545cef5351380..baa277d6c4652 100644 --- a/c10/util/Exception.h +++ b/c10/util/Exception.h @@ -267,6 +267,7 @@ class C10_API NotImplementedError : public Error { using Error::Error; }; +<<<<<<< HEAD // Used in ATen for buffer-related errors, e.g. trying to create a DLPack of // an unsupported device. These turn into BufferError when they cross to // Python. @@ -274,6 +275,8 @@ class C10_API BufferError : public Error { using Error::Error; }; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Used in ATen for non finite indices. These turn into // ExitException when they cross to Python. class C10_API EnforceFiniteError : public Error { @@ -372,7 +375,30 @@ C10_API std::string GetExceptionString(const std::exception& e); // https://stackoverflow.com/questions/5134523/msvc-doesnt-expand-va-args-correctly #define C10_EXPAND_MSVC_WORKAROUND(x) x +<<<<<<< HEAD #include +======= +// On nvcc, C10_UNLIKELY thwarts missing return statement analysis. In cases +// where the unlikely expression may be a constant, use this macro to ensure +// return statement analysis keeps working (at the cost of not getting the +// likely/unlikely annotation on nvcc). +// https://github.com/pytorch/pytorch/issues/21418 +// +// Currently, this is only used in the error reporting macros below. If you +// want to use it more generally, move me to Macros.h +// +// TODO: Brian Vaughan observed that we might be able to get this to work on +// nvcc by writing some sort of C++ overload that distinguishes constexpr inputs +// from non-constexpr. Since there isn't any evidence that losing C10_UNLIKELY +// in nvcc is causing us perf problems, this is not yet implemented, but this +// might be an interesting piece of C++ code for an intrepid bootcamper to +// write. +#if defined(__CUDACC__) +#define C10_UNLIKELY_OR_CONST(e) e +#else +#define C10_UNLIKELY_OR_CONST(e) C10_UNLIKELY(e) +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // ---------------------------------------------------------------------------- // Error reporting macros @@ -642,10 +668,13 @@ namespace c10::detail { #define TORCH_CHECK_NOT_IMPLEMENTED(cond, ...) \ TORCH_CHECK_WITH_MSG(NotImplementedError, cond, "TYPE", __VA_ARGS__) +<<<<<<< HEAD // Like TORCH_CHECK, but raises BufferError instead of Errors. #define TORCH_CHECK_BUFFER(cond, ...) \ TORCH_CHECK_WITH_MSG(BufferError, cond, "TYPE", __VA_ARGS__) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #define TORCH_CHECK_ALWAYS_SHOW_CPP_STACKTRACE(cond, ...) \ TORCH_CHECK_WITH_MSG( \ ErrorAlwaysShowCppStacktrace, cond, "TYPE", ##__VA_ARGS__) diff --git a/c10/util/Float4_e2m1fn_x2.h b/c10/util/Float4_e2m1fn_x2.h index 15f7ac70c4e83..6285e22828eb8 100644 --- a/c10/util/Float4_e2m1fn_x2.h +++ b/c10/util/Float4_e2m1fn_x2.h @@ -1 +1,32 @@ +<<<<<<< HEAD #include +======= +#pragma once +#include + +#include + +/// Defines the Float4_e2m1fn_x2 type (4-bit floating-point, two elements packed +/// into one byte). This is the FP4 dtype from the OCP MX format spec +/// (https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf, +/// Section 5.3.3) +/// +/// Given two high precision values val0 and val1, here is the +/// binary configuration of their packed representation, from MSB to LSB: +/// +/// original value | val1 : val0 +/// ======================================== +/// bit index (MSB==7, LSB==0) | 7654 : 3210 +/// sign/exponent/mantissa | seem : seem +/// + +namespace c10 { + +struct alignas(1) Float4_e2m1fn_x2 { + uint8_t val_; + Float4_e2m1fn_x2() = default; + C10_HOST_DEVICE explicit Float4_e2m1fn_x2(uint8_t val) : val_(val) {} +}; + +} // namespace c10 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/c10/util/Float8_e4m3fn-inl.h b/c10/util/Float8_e4m3fn-inl.h index ef52e38f506da..1581c1bf0ede1 100644 --- a/c10/util/Float8_e4m3fn-inl.h +++ b/c10/util/Float8_e4m3fn-inl.h @@ -1 +1,278 @@ +<<<<<<< HEAD #include +======= +#pragma once + +#include +#include +#include + +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +namespace c10 { + +/// Constructors + +inline C10_HOST_DEVICE Float8_e4m3fn::Float8_e4m3fn(float value) + : x(detail::fp8e4m3fn_from_fp32_value(value)) {} + +/// Implicit conversions + +inline C10_HOST_DEVICE Float8_e4m3fn::operator float() const { + return detail::fp8e4m3fn_to_fp32_value(x); +} + +/// Special values helper + +inline C10_HOST_DEVICE bool Float8_e4m3fn::isnan() const { + return (x & 0b01111111) == 0b01111111; +} + +/// Arithmetic + +inline C10_HOST_DEVICE Float8_e4m3fn +operator+(const Float8_e4m3fn& a, const Float8_e4m3fn& b) { + return static_cast(a) + static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fn +operator-(const Float8_e4m3fn& a, const Float8_e4m3fn& b) { + return static_cast(a) - static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fn +operator*(const Float8_e4m3fn& a, const Float8_e4m3fn& b) { + return static_cast(a) * static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fn operator/( + const Float8_e4m3fn& a, + const Float8_e4m3fn& b) __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fn operator-(const Float8_e4m3fn& a) { + return -static_cast(a); +} + +inline C10_HOST_DEVICE Float8_e4m3fn& operator+=( + Float8_e4m3fn& a, + const Float8_e4m3fn& b) { + a = a + b; + return a; +} + +inline C10_HOST_DEVICE Float8_e4m3fn& operator-=( + Float8_e4m3fn& a, + const Float8_e4m3fn& b) { + a = a - b; + return a; +} + +inline C10_HOST_DEVICE Float8_e4m3fn& operator*=( + Float8_e4m3fn& a, + const Float8_e4m3fn& b) { + a = a * b; + return a; +} + +inline C10_HOST_DEVICE Float8_e4m3fn& operator/=( + Float8_e4m3fn& a, + const Float8_e4m3fn& b) { + a = a / b; + return a; +} + +/// Arithmetic with floats + +inline C10_HOST_DEVICE float operator+(Float8_e4m3fn a, float b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE float operator-(Float8_e4m3fn a, float b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE float operator*(Float8_e4m3fn a, float b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE float operator/(Float8_e4m3fn a, float b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE float operator+(float a, Float8_e4m3fn b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE float operator-(float a, Float8_e4m3fn b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE float operator*(float a, Float8_e4m3fn b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE float operator/(float a, Float8_e4m3fn b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e4m3fn& b) { + return a += static_cast(b); +} +inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e4m3fn& b) { + return a -= static_cast(b); +} +inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e4m3fn& b) { + return a *= static_cast(b); +} +inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e4m3fn& b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline C10_HOST_DEVICE double operator+(Float8_e4m3fn a, double b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE double operator-(Float8_e4m3fn a, double b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE double operator*(Float8_e4m3fn a, double b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE double operator/(Float8_e4m3fn a, double b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE double operator+(double a, Float8_e4m3fn b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE double operator-(double a, Float8_e4m3fn b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE double operator*(double a, Float8_e4m3fn b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE double operator/(double a, Float8_e4m3fn b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +/// Arithmetic with ints + +inline C10_HOST_DEVICE Float8_e4m3fn operator+(Float8_e4m3fn a, int b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fn operator-(Float8_e4m3fn a, int b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fn operator*(Float8_e4m3fn a, int b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fn operator/(Float8_e4m3fn a, int b) { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fn operator+(int a, Float8_e4m3fn b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE Float8_e4m3fn operator-(int a, Float8_e4m3fn b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE Float8_e4m3fn operator*(int a, Float8_e4m3fn b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE Float8_e4m3fn operator/(int a, Float8_e4m3fn b) { + return static_cast(a) / b; +} + +//// Arithmetic with int64_t + +inline C10_HOST_DEVICE Float8_e4m3fn operator+(Float8_e4m3fn a, int64_t b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fn operator-(Float8_e4m3fn a, int64_t b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fn operator*(Float8_e4m3fn a, int64_t b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fn operator/(Float8_e4m3fn a, int64_t b) { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fn operator+(int64_t a, Float8_e4m3fn b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE Float8_e4m3fn operator-(int64_t a, Float8_e4m3fn b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE Float8_e4m3fn operator*(int64_t a, Float8_e4m3fn b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE Float8_e4m3fn operator/(int64_t a, Float8_e4m3fn b) { + return static_cast(a) / b; +} + +/// NOTE: we do not define comparisons directly and instead rely on the implicit +/// conversion from c10::Float8_e4m3fn to float. + +} // namespace c10 + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_specialized = true; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = false; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = true; + static constexpr auto has_denorm_loss = true; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 4; + static constexpr int digits10 = 0; + static constexpr int max_digits10 = 3; + static constexpr int radix = 2; + static constexpr int min_exponent = -5; + static constexpr int min_exponent10 = -1; + static constexpr int max_exponent = 8; + static constexpr int max_exponent10 = 2; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = false; + + static constexpr c10::Float8_e4m3fn min() { + return c10::Float8_e4m3fn(0x08, c10::Float8_e4m3fn::from_bits()); + } + static constexpr c10::Float8_e4m3fn lowest() { + return c10::Float8_e4m3fn(0xFE, c10::Float8_e4m3fn::from_bits()); + } + static constexpr c10::Float8_e4m3fn max() { + return c10::Float8_e4m3fn(0x7E, c10::Float8_e4m3fn::from_bits()); + } + static constexpr c10::Float8_e4m3fn epsilon() { + return c10::Float8_e4m3fn(0x20, c10::Float8_e4m3fn::from_bits()); + } + static constexpr c10::Float8_e4m3fn round_error() { + return c10::Float8_e4m3fn(0x30, c10::Float8_e4m3fn::from_bits()); + } + static constexpr c10::Float8_e4m3fn quiet_NaN() { + return c10::Float8_e4m3fn(0x7F, c10::Float8_e4m3fn::from_bits()); + } + static constexpr c10::Float8_e4m3fn denorm_min() { + return c10::Float8_e4m3fn(0x01, c10::Float8_e4m3fn::from_bits()); + } +}; + +} // namespace std + +C10_CLANG_DIAGNOSTIC_POP() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/c10/util/Float8_e4m3fn.h b/c10/util/Float8_e4m3fn.h index ef52e38f506da..5c642018c2815 100644 --- a/c10/util/Float8_e4m3fn.h +++ b/c10/util/Float8_e4m3fn.h @@ -1 +1,244 @@ +<<<<<<< HEAD #include +======= +#pragma once + +/// Defines the Float8_e4m3fn type (8-bit floating-point) including conversions +/// to standard C types and basic arithmetic operations. Note that arithmetic +/// operations are implemented by converting to floating point and +/// performing the operation in float32. +/// Binary configuration: +/// s eeee mmm +/// 1 sign bit +/// 4 exponent bits +/// 3 mantissa bits +/// bias = 7 +/// +/// Implementation based on the paper https://arxiv.org/pdf/2209.05433.pdf +/// and inspired by Half implementation from pytorch/c10/util/Half.h + +#include +#include + +#if defined(__cplusplus) +#include +#include +#elif !defined(__OPENCL_VERSION__) +#include +#include +#endif + +#ifdef _MSC_VER +#include +#endif + +#include +#include + +namespace c10 { + +namespace detail { + +/* + * Convert a 8-bit floating-point number in fp8 E4M3FN format, in bit + * representation, to a 32-bit floating-point number in IEEE single-precision + * format, in bit representation. + * + * @note The implementation doesn't use any floating-point operations. + */ +inline C10_HOST_DEVICE float fp8e4m3fn_to_fp32_value(uint8_t input) { + /* + * Extend the fp8 E4M3FN number to 32 bits and shift to the + * upper part of the 32-bit word: + * +---+----+---+-----------------------------+ + * | S |EEEE|MMM|0000 0000 0000 0000 0000 0000| + * +---+----+---+-----------------------------+ + * Bits 31 27-30 24-26 0-23 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 + * - zero bits. + */ + const uint32_t w = (uint32_t)input << 24; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the bits 0-30 + * of the 32-bit word: + * + * +---+----+---+-----------------------------+ + * | S |EEEE|MMM|0000 0000 0000 0000 0000 0000| + * +---+----+---+-----------------------------+ + * Bits 31 27-30 24-26 0-23 + */ + const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF); + /* + * Renorm shift is the number of bits to shift mantissa left to make the + * half-precision number normalized. If the initial number is normalized, some + * of its high 5 bits (sign == 0 and 4-bit exponent) equals one. In this case + * renorm_shift == 0. If the number is denormalize, renorm_shift > 0. Note + * that if we shift denormalized nonsign by renorm_shift, the unit bit of + * mantissa will shift into exponent, turning the biased exponent into 1, and + * making mantissa normalized (i.e. without leading 1). + */ +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + uint32_t renorm_shift = __clz(nonsign); +#elif defined(__SYCL_DEVICE_ONLY__) + // Note: zero is not a supported input into `__builtin_clz` + uint32_t renorm_shift = + nonsign != 0 ? __builtin_clz(nonsign) : sizeof(uint32_t) * CHAR_BIT; +#elif defined(_MSC_VER) && !defined(__clang__) + unsigned long nonsign_bsr; + _BitScanReverse(&nonsign_bsr, (unsigned long)nonsign); + uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31; +#else + // Note: zero is not a supported input into `__builtin_clz` + uint32_t renorm_shift = + nonsign != 0 ? __builtin_clz(nonsign) : sizeof(uint32_t) * CHAR_BIT; +#endif + renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0; + /* + * Iff fp8e4m3fn number has all exponent and mantissa bits set to 1, + * the addition overflows it into bit 31, and the subsequent shift turns the + * high 9 bits into 1. Thus inf_nan_mask == 0x7F800000 if the fp8e4m3fn number + * is Nan, 0x00000000 otherwise + */ + const int32_t inf_nan_mask = + ((int32_t)(nonsign + 0x01000000) >> 8) & INT32_C(0x7F800000); + /* + * Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31 + * into 1. Otherwise, bit 31 remains 0. The signed shift right by 31 + * broadcasts bit 31 into all bits of the zero_mask. Thus zero_mask == + * 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h) + * 0x00000000 otherwise + */ + const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31; + /* + * 1. Shift nonsign left by renorm_shift to normalize it (if the input + * was denormal) + * 2. Shift nonsign right by 4 so the exponent (4 bits originally) + * becomes an 8-bit field and 3-bit mantissa shifts into the 3 high + * bits of the 23-bit mantissa of IEEE single-precision number. + * 3. Add 0x78 to the exponent (starting at bit 23) to compensate the + * different in exponent bias (0x7F for single-precision number less 0x07 + * for fp8e4m3fn number). + * 4. Subtract renorm_shift from the exponent (starting at bit 23) to + * account for renormalization. As renorm_shift is less than 0x78, this + * can be combined with step 3. + * 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the + * input was NaN or infinity. + * 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent + * into zero if the input was zero. + * 7. Combine with the sign of the input number. + */ + uint32_t result = sign | + ((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) | + inf_nan_mask) & + ~zero_mask); + return fp32_from_bits(result); +} + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a + * 8-bit floating-point number in fp8 E4M3FN format, in bit representation. + */ +inline C10_HOST_DEVICE uint8_t fp8e4m3fn_from_fp32_value(float f) { + /* + * Binary representation of 480.0f, which is the first value + * not representable in fp8e4m3fn range: + * 0 1111 111 - fp8e4m3fn + * 0 10000111 11100000000000000000000 - fp32 + */ + constexpr uint32_t fp8_max = UINT32_C(1087) << 20; + + /* + * A mask for converting fp32 numbers lower than fp8e4m3fn normal range + * into denorm representation + * magic number: ((127 - 7) + (23 - 3) + 1) + */ + constexpr uint32_t denorm_mask = UINT32_C(141) << 23; + + uint32_t f_bits = fp32_to_bits(f); + + uint8_t result = 0u; + + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = f_bits & UINT32_C(0x80000000); + + /* + * Set sign bit to 0 + */ + f_bits ^= sign; + + if (f_bits >= fp8_max) { + // NaN - all exponent and mantissa bits set to 1 + result = 0x7f; + } else { + if (f_bits < (UINT32_C(121) << 23)) { + // Input number is smaller than 2^(-6), which is the smallest + // fp8e4m3fn normal number + f_bits = + fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask)); + result = static_cast(f_bits - denorm_mask); + } else { + // resulting mantissa is odd + uint8_t mant_odd = (f_bits >> 20) & 1; + + // update exponent, rounding bias part 1 + f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF; + + // rounding bias part 2 + f_bits += mant_odd; + + // take the bits! + result = static_cast(f_bits >> 20); + } + } + + result |= static_cast(sign >> 24); + return result; +} + +} // namespace detail + +struct alignas(1) Float8_e4m3fn { + uint8_t x; + + struct from_bits_t {}; + C10_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + + Float8_e4m3fn() = default; + + constexpr C10_HOST_DEVICE Float8_e4m3fn(uint8_t bits, from_bits_t) + : x(bits) {} + inline C10_HOST_DEVICE Float8_e4m3fn(float value); + inline C10_HOST_DEVICE operator float() const; + inline C10_HOST_DEVICE bool isnan() const; +}; + +C10_API inline std::ostream& operator<<( + std::ostream& out, + const Float8_e4m3fn& value) { + out << (float)value; + return out; +} + +} // namespace c10 + +#include // IWYU pragma: keep +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/c10/util/Float8_e4m3fnuz-inl.h b/c10/util/Float8_e4m3fnuz-inl.h index f8fab7180e1e7..958731b624a3f 100644 --- a/c10/util/Float8_e4m3fnuz-inl.h +++ b/c10/util/Float8_e4m3fnuz-inl.h @@ -1 +1,283 @@ +<<<<<<< HEAD #include +======= +#pragma once + +#include +#include +#include +#include + +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +namespace c10 { + +/// Constructors + +inline C10_HOST_DEVICE Float8_e4m3fnuz::Float8_e4m3fnuz(float value) + : x(detail::fp8e4m3fnuz_from_fp32_value(value)) {} + +/// Implicit conversions + +inline C10_HOST_DEVICE Float8_e4m3fnuz::operator float() const { + return detail::fp8_fnuz_to_fp32_value<4, 3>(x); +} + +/// Special values helper + +inline C10_HOST_DEVICE bool Float8_e4m3fnuz::isnan() const { + return x == 0b10000000; +} + +/// Arithmetic + +inline C10_HOST_DEVICE Float8_e4m3fnuz +operator+(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) { + return static_cast(a) + static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz +operator-(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) { + return static_cast(a) - static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz +operator*(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) { + return static_cast(a) * static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz operator/( + const Float8_e4m3fnuz& a, + const Float8_e4m3fnuz& b) __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(const Float8_e4m3fnuz& a) { + return -static_cast(a); +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz& operator+=( + Float8_e4m3fnuz& a, + const Float8_e4m3fnuz& b) { + a = a + b; + return a; +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz& operator-=( + Float8_e4m3fnuz& a, + const Float8_e4m3fnuz& b) { + a = a - b; + return a; +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz& operator*=( + Float8_e4m3fnuz& a, + const Float8_e4m3fnuz& b) { + a = a * b; + return a; +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz& operator/=( + Float8_e4m3fnuz& a, + const Float8_e4m3fnuz& b) { + a = a / b; + return a; +} + +/// Arithmetic with floats + +inline C10_HOST_DEVICE float operator+(Float8_e4m3fnuz a, float b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE float operator-(Float8_e4m3fnuz a, float b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE float operator*(Float8_e4m3fnuz a, float b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE float operator/(Float8_e4m3fnuz a, float b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE float operator+(float a, Float8_e4m3fnuz b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE float operator-(float a, Float8_e4m3fnuz b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE float operator*(float a, Float8_e4m3fnuz b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE float operator/(float a, Float8_e4m3fnuz b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e4m3fnuz& b) { + return a += static_cast(b); +} +inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e4m3fnuz& b) { + return a -= static_cast(b); +} +inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e4m3fnuz& b) { + return a *= static_cast(b); +} +inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e4m3fnuz& b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline C10_HOST_DEVICE double operator+(Float8_e4m3fnuz a, double b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE double operator-(Float8_e4m3fnuz a, double b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE double operator*(Float8_e4m3fnuz a, double b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE double operator/(Float8_e4m3fnuz a, double b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE double operator+(double a, Float8_e4m3fnuz b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE double operator-(double a, Float8_e4m3fnuz b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE double operator*(double a, Float8_e4m3fnuz b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE double operator/(double a, Float8_e4m3fnuz b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +/// Arithmetic with ints + +inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(Float8_e4m3fnuz a, int b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(Float8_e4m3fnuz a, int b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(Float8_e4m3fnuz a, int b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(Float8_e4m3fnuz a, int b) { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(int a, Float8_e4m3fnuz b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(int a, Float8_e4m3fnuz b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(int a, Float8_e4m3fnuz b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(int a, Float8_e4m3fnuz b) { + return static_cast(a) / b; +} + +//// Arithmetic with int64_t + +inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(Float8_e4m3fnuz a, int64_t b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(Float8_e4m3fnuz a, int64_t b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(Float8_e4m3fnuz a, int64_t b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(Float8_e4m3fnuz a, int64_t b) { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(int64_t a, Float8_e4m3fnuz b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(int64_t a, Float8_e4m3fnuz b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(int64_t a, Float8_e4m3fnuz b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(int64_t a, Float8_e4m3fnuz b) { + return static_cast(a) / b; +} + +/// NOTE: we do not define comparisons directly and instead rely on the implicit +/// conversion from c10::Float8_e4m3fnuz to float. + +} // namespace c10 + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_specialized = true; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = false; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = true; + static constexpr auto has_denorm_loss = true; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 4; + static constexpr int digits10 = 0; + static constexpr int max_digits10 = 3; + static constexpr int radix = 2; + static constexpr int min_exponent = -6; + static constexpr int min_exponent10 = -1; + static constexpr int max_exponent = 8; + static constexpr int max_exponent10 = 2; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = false; + + static constexpr c10::Float8_e4m3fnuz min() { + return c10::Float8_e4m3fnuz(0x08, c10::Float8_e4m3fnuz::from_bits()); + } + static constexpr c10::Float8_e4m3fnuz lowest() { + return c10::Float8_e4m3fnuz(0xFF, c10::Float8_e4m3fnuz::from_bits()); + } + static constexpr c10::Float8_e4m3fnuz max() { + return c10::Float8_e4m3fnuz(0x7F, c10::Float8_e4m3fnuz::from_bits()); + } + static constexpr c10::Float8_e4m3fnuz epsilon() { + return c10::Float8_e4m3fnuz(0x28, c10::Float8_e4m3fnuz::from_bits()); + } + static constexpr c10::Float8_e4m3fnuz round_error() { + return c10::Float8_e4m3fnuz(0x38, c10::Float8_e4m3fnuz::from_bits()); + } + static constexpr c10::Float8_e4m3fnuz infinity() { + // NaN (no infinities) + return c10::Float8_e4m3fnuz(0x80, c10::Float8_e4m3fnuz::from_bits()); + } + static constexpr c10::Float8_e4m3fnuz quiet_NaN() { + return c10::Float8_e4m3fnuz(0x80, c10::Float8_e4m3fnuz::from_bits()); + } + static constexpr c10::Float8_e4m3fnuz denorm_min() { + return c10::Float8_e4m3fnuz(0x01, c10::Float8_e4m3fnuz::from_bits()); + } +}; + +} // namespace std + +C10_CLANG_DIAGNOSTIC_POP() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/c10/util/Float8_e4m3fnuz.h b/c10/util/Float8_e4m3fnuz.h index f8fab7180e1e7..c25b872f97444 100644 --- a/c10/util/Float8_e4m3fnuz.h +++ b/c10/util/Float8_e4m3fnuz.h @@ -1 +1,143 @@ +<<<<<<< HEAD #include +======= +#pragma once + +/// Defines the Float8_e4m3fnuz type (8-bit floating-point) including +/// conversions to standard C types and basic arithmetic operations. Note that +/// arithmetic operations are implemented by converting to floating point and +/// performing the operation in float32. +/// Binary configuration remains the same as Float8_e4m3fn: +/// s eeee mmm +/// 1 sign bit +/// 4 exponent bits +/// 3 mantissa bits +/// The key differences versus Float8_e4m3fn are: +/// bias = 8 +/// no infinities or negative zero +/// NaN only when sign bit is 1, rest all 0s +/// +/// Implementation based on the paper https://arxiv.org/pdf/2206.02915.pdf and +/// the existing Float8_e4m3fn implementation. + +#include +#include +#include +#include + +#if defined(__cplusplus) +#include +#elif !defined(__OPENCL_VERSION__) +#include +#include +#endif + +#include +#include + +namespace c10 { + +namespace detail { + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a + * 8-bit floating-point number in fp8 E4M3FNUZ format, in bit representation. + */ +inline C10_HOST_DEVICE uint8_t fp8e4m3fnuz_from_fp32_value(float f) { + /* + * Binary representation of 256.0f, which is the first value not representable + * (i.e. the first value which would overflow in to the sign bit, resulting in + * a NaN) in fp8e4m3fnuz range: + * 1 0000 000 - fp8e4m3fnuz + * 0 10000111 00000000000000000000000 - fp32 + */ + constexpr uint32_t fnuz_max = UINT32_C(0x87) << 23; + + /* + * A mask for converting fp32 numbers lower than fp8e4m3fnuz normal range + * into denorm representation + * magic number: ((127 - 8) + (23 - 3) + 1) + */ + constexpr uint32_t denorm_mask = UINT32_C(0x8C) << 23; + + uint32_t f_bits = fp32_to_bits(f); + + uint32_t result = 0u; + + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = f_bits & UINT32_C(0x80000000); + + /* + * Set sign bit to 0 + */ + f_bits ^= sign; + + if (f_bits >= fnuz_max) { + // NaN -- sign bit set to 1, rest 0s. + return 0x80; + } + + if (f_bits < (UINT32_C(0x78) << 23) /* 2^-7 in float32 */) { + // Input exponent is less than -7, the smallest e4m3fnuz exponent, so the + // number will become subnormal. + f_bits = fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask)); + result = static_cast(f_bits - denorm_mask); + if (result == 0) { + // fnuz types don't have negative zero. + return 0; + } + } else { + // resulting mantissa is odd + uint8_t mant_odd = (f_bits >> 20) & 1; + + // update exponent, rounding bias part 1 + f_bits += ((uint32_t)(8 - 127) << 23) + 0x7FFFF; + + // rounding bias part 2 + f_bits += mant_odd; + + // take the bits! + result = static_cast(f_bits >> 20); + } + + result |= sign >> 24; + return result; +} + +} // namespace detail + +struct alignas(1) Float8_e4m3fnuz { + uint8_t x; + + struct from_bits_t {}; + C10_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + + Float8_e4m3fnuz() = default; + + constexpr C10_HOST_DEVICE Float8_e4m3fnuz(uint8_t bits, from_bits_t) + : x(bits) {} + inline C10_HOST_DEVICE Float8_e4m3fnuz(float value); + inline C10_HOST_DEVICE operator float() const; + inline C10_HOST_DEVICE bool isnan() const; +}; + +C10_API inline std::ostream& operator<<( + std::ostream& out, + const Float8_e4m3fnuz& value) { + out << (float)value; + return out; +} + +} // namespace c10 + +#include // IWYU pragma: keep +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/c10/util/Float8_e5m2-inl.h b/c10/util/Float8_e5m2-inl.h index 2e21840fba376..13bfe9a9629ba 100644 --- a/c10/util/Float8_e5m2-inl.h +++ b/c10/util/Float8_e5m2-inl.h @@ -1 +1,290 @@ +<<<<<<< HEAD #include +======= +#pragma once + +#include +#include +#include + +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +#define EXP_WIDTH_FP8 5 +#define MAN_WIDTH_FP8 2 +#define EXP_BIAS_FP8 15 + +namespace c10 { + +/// Constructors + +inline C10_HOST_DEVICE Float8_e5m2::Float8_e5m2(float value) + : x(detail::fp8e5m2_from_fp32_value(value)) {} + +/// Implicit conversions + +inline C10_HOST_DEVICE Float8_e5m2::operator float() const { + return detail::fp8e5m2_to_fp32_value(x); +} + +/// Special values helpers + +inline C10_HOST_DEVICE bool Float8_e5m2::isnan() const { + return (x & 0b01111111) > 0b01111100; +} + +inline C10_HOST_DEVICE bool Float8_e5m2::isinf() const { + return (x & 0b01111111) == 0b01111100; +} + +/// Arithmetic + +inline C10_HOST_DEVICE Float8_e5m2 +operator+(const Float8_e5m2& a, const Float8_e5m2& b) { + return static_cast(a) + static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2 +operator-(const Float8_e5m2& a, const Float8_e5m2& b) { + return static_cast(a) - static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2 +operator*(const Float8_e5m2& a, const Float8_e5m2& b) { + return static_cast(a) * static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2 operator/( + const Float8_e5m2& a, + const Float8_e5m2& b) __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2 operator-(const Float8_e5m2& a) { + return -static_cast(a); +} + +inline C10_HOST_DEVICE Float8_e5m2& operator+=( + Float8_e5m2& a, + const Float8_e5m2& b) { + a = a + b; + return a; +} + +inline C10_HOST_DEVICE Float8_e5m2& operator-=( + Float8_e5m2& a, + const Float8_e5m2& b) { + a = a - b; + return a; +} + +inline C10_HOST_DEVICE Float8_e5m2& operator*=( + Float8_e5m2& a, + const Float8_e5m2& b) { + a = a * b; + return a; +} + +inline C10_HOST_DEVICE Float8_e5m2& operator/=( + Float8_e5m2& a, + const Float8_e5m2& b) { + a = a / b; + return a; +} + +/// Arithmetic with floats + +inline C10_HOST_DEVICE float operator+(Float8_e5m2 a, float b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE float operator-(Float8_e5m2 a, float b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE float operator*(Float8_e5m2 a, float b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE float operator/(Float8_e5m2 a, float b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE float operator+(float a, Float8_e5m2 b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE float operator-(float a, Float8_e5m2 b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE float operator*(float a, Float8_e5m2 b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE float operator/(float a, Float8_e5m2 b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e5m2& b) { + return a += static_cast(b); +} +inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e5m2& b) { + return a -= static_cast(b); +} +inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e5m2& b) { + return a *= static_cast(b); +} +inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e5m2& b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline C10_HOST_DEVICE double operator+(Float8_e5m2 a, double b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE double operator-(Float8_e5m2 a, double b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE double operator*(Float8_e5m2 a, double b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE double operator/(Float8_e5m2 a, double b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE double operator+(double a, Float8_e5m2 b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE double operator-(double a, Float8_e5m2 b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE double operator*(double a, Float8_e5m2 b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE double operator/(double a, Float8_e5m2 b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +/// Arithmetic with ints + +inline C10_HOST_DEVICE Float8_e5m2 operator+(Float8_e5m2 a, int b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2 operator-(Float8_e5m2 a, int b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2 operator*(Float8_e5m2 a, int b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2 operator/(Float8_e5m2 a, int b) { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2 operator+(int a, Float8_e5m2 b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE Float8_e5m2 operator-(int a, Float8_e5m2 b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE Float8_e5m2 operator*(int a, Float8_e5m2 b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE Float8_e5m2 operator/(int a, Float8_e5m2 b) { + return static_cast(a) / b; +} + +//// Arithmetic with int64_t + +inline C10_HOST_DEVICE Float8_e5m2 operator+(Float8_e5m2 a, int64_t b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2 operator-(Float8_e5m2 a, int64_t b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2 operator*(Float8_e5m2 a, int64_t b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2 operator/(Float8_e5m2 a, int64_t b) { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2 operator+(int64_t a, Float8_e5m2 b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE Float8_e5m2 operator-(int64_t a, Float8_e5m2 b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE Float8_e5m2 operator*(int64_t a, Float8_e5m2 b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE Float8_e5m2 operator/(int64_t a, Float8_e5m2 b) { + return static_cast(a) / b; +} + +/// NOTE: we do not define comparisons directly and instead rely on the implicit +/// conversion from c10::Float8_e5m2 to float. + +} // namespace c10 + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_specialized = true; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = true; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = true; + static constexpr auto has_denorm_loss = true; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 3; + static constexpr int digits10 = 0; + static constexpr int max_digits10 = 2; + static constexpr int radix = 2; + static constexpr int min_exponent = -13; + static constexpr int min_exponent10 = -4; + static constexpr int max_exponent = 16; + static constexpr int max_exponent10 = 4; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = + numeric_limits::tinyness_before; + + static constexpr c10::Float8_e5m2 min() { + return c10::Float8_e5m2(0x4, c10::Float8_e5m2::from_bits()); + } + static constexpr c10::Float8_e5m2 max() { + return c10::Float8_e5m2(0x7B, c10::Float8_e5m2::from_bits()); + } + static constexpr c10::Float8_e5m2 lowest() { + return c10::Float8_e5m2(0xFB, c10::Float8_e5m2::from_bits()); + } + static constexpr c10::Float8_e5m2 epsilon() { + return c10::Float8_e5m2(0x34, c10::Float8_e5m2::from_bits()); + } + static constexpr c10::Float8_e5m2 round_error() { + return c10::Float8_e5m2(0x38, c10::Float8_e5m2::from_bits()); + } + static constexpr c10::Float8_e5m2 infinity() { + return c10::Float8_e5m2(0x7C, c10::Float8_e5m2::from_bits()); + } + static constexpr c10::Float8_e5m2 quiet_NaN() { + return c10::Float8_e5m2(0x7F, c10::Float8_e5m2::from_bits()); + } + static constexpr c10::Float8_e5m2 denorm_min() { + return c10::Float8_e5m2(0x01, c10::Float8_e5m2::from_bits()); + } +}; + +} // namespace std + +C10_CLANG_DIAGNOSTIC_POP() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/c10/util/Float8_e5m2.h b/c10/util/Float8_e5m2.h index 2e21840fba376..aea7cdb74a43d 100644 --- a/c10/util/Float8_e5m2.h +++ b/c10/util/Float8_e5m2.h @@ -1 +1,152 @@ +<<<<<<< HEAD #include +======= +#pragma once + +/// Defines the Float8_e5m2 type (8-bit floating-point) including conversions +/// to standard C types and basic arithmetic operations. Note that arithmetic +/// operations are implemented by converting to floating point and +/// performing the operation in float32. +/// Binary configuration: +/// s eeeee mm +/// 1 sign bit +/// 5 exponent bits +/// 2 mantissa bits +/// bias = 15 +/// +/// Implementation based on the paper https://arxiv.org/pdf/2209.05433.pdf +/// and inspired by Half implementation from pytorch/c10/util/Half.h + +#include + +namespace c10 { + +namespace detail { + +/* + * Convert a 8-bit floating-point number in fp8 E5M2 format, in bit + * representation, to a 32-bit floating-point number in IEEE single-precision + * format, in bit representation. + * + * @note The implementation doesn't use any floating-point operations. + */ +inline C10_HOST_DEVICE float fp8e5m2_to_fp32_value(uint8_t input) { + /* + * Extend the fp8 E5M2 number to 32 bits and shift to the + * upper part of the 32-bit word: + * +---+----+---+-----------------------------+ + * | S |EEEEE|MM|0000 0000 0000 0000 0000 0000| + * +---+----+---+-----------------------------+ + * Bits 31 26-30 24-25 0-23 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 + * - zero bits. + */ + uint16_t half_representation = input; + half_representation <<= 8; + return fp16_ieee_to_fp32_value(half_representation); +} + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a + * 8-bit floating-point number in fp8 E5M2 format, in bit representation. + */ +inline C10_HOST_DEVICE uint8_t fp8e5m2_from_fp32_value(float f) { + /* + * Binary representation of fp32 infinity + * 0 11111111 00000000000000000000000 + */ + constexpr uint32_t fp32_inf = UINT32_C(255) << 23; + + /* + * Binary representation of 65536.0f, which is the first value + * not representable in fp8e5m2 range: + * 0 11111 00 - fp8e5m2 + * 0 10001111 00000000000000000000000 - fp32 + */ + constexpr uint32_t fp8_max = UINT32_C(143) << 23; + + /* + * A mask for converting fp32 numbers lower than fp8e5m2 normal range + * into denorm representation + * magic number: ((127 - 15) + (23 - 2) + 1) + */ + constexpr uint32_t denorm_mask = UINT32_C(134) << 23; + + uint32_t f_bits = fp32_to_bits(f); + uint8_t result = 0u; + + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = f_bits & UINT32_C(0x80000000); + + /* + * Set sign bit to 0 + */ + f_bits ^= sign; + + if (f_bits >= fp8_max) { + // NaN - all exponent and mantissa bits set to 1 + result = f_bits > fp32_inf ? UINT8_C(0x7F) : UINT8_C(0x7C); + } else { + if (f_bits < (UINT32_C(113) << 23)) { + // Input number is smaller than 2^(-14), which is the smallest + // fp8e5m2 normal number + f_bits = + fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask)); + result = static_cast(f_bits - denorm_mask); + } else { + // resulting mantissa is odd + uint32_t mant_odd = (f_bits >> 21) & 1; + + // update exponent, rounding bias part 1 + f_bits += ((uint32_t)(15 - 127) << 23) + 0xFFFFF; + + // rounding bias part 2 + f_bits += mant_odd; + + // take the bits! + result = static_cast(f_bits >> 21); + } + } + + result |= static_cast(sign >> 24); + return result; +} + +} // namespace detail + +struct alignas(1) Float8_e5m2 { + uint8_t x; + + struct from_bits_t {}; + C10_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + + Float8_e5m2() = default; + + constexpr C10_HOST_DEVICE Float8_e5m2(uint8_t bits, from_bits_t) : x(bits) {} + inline C10_HOST_DEVICE Float8_e5m2(float value); + inline C10_HOST_DEVICE operator float() const; + inline C10_HOST_DEVICE bool isnan() const; + inline C10_HOST_DEVICE bool isinf() const; +}; + +C10_API inline std::ostream& operator<<( + std::ostream& out, + const Float8_e5m2& value) { + out << (float)value; + return out; +} + +} // namespace c10 + +#include // IWYU pragma: keep +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/c10/util/Float8_e5m2fnuz-inl.h b/c10/util/Float8_e5m2fnuz-inl.h index 1f2d3db723d02..0ac20143324c0 100644 --- a/c10/util/Float8_e5m2fnuz-inl.h +++ b/c10/util/Float8_e5m2fnuz-inl.h @@ -1 +1,289 @@ +<<<<<<< HEAD #include +======= +#pragma once + +#include +#include +#include +#include + +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +namespace c10 { + +/// Constructors + +inline C10_HOST_DEVICE Float8_e5m2fnuz::Float8_e5m2fnuz(float value) + : x(detail::fp8e5m2fnuz_from_fp32_value(value)) {} + +/// Implicit conversions + +inline C10_HOST_DEVICE Float8_e5m2fnuz::operator float() const { + return detail::fp8_fnuz_to_fp32_value<5, 2>(x); +} + +/// Special values helpers + +inline C10_HOST_DEVICE bool Float8_e5m2fnuz::isnan() const { + return x == 0b10000000; +} + +inline C10_HOST_DEVICE bool Float8_e5m2fnuz::isinf() const { + return false; +} + +/// Arithmetic + +inline C10_HOST_DEVICE Float8_e5m2fnuz +operator+(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) { + return static_cast(a) + static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz +operator-(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) { + return static_cast(a) - static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz +operator*(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) { + return static_cast(a) * static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz operator/( + const Float8_e5m2fnuz& a, + const Float8_e5m2fnuz& b) __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(const Float8_e5m2fnuz& a) { + return -static_cast(a); +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz& operator+=( + Float8_e5m2fnuz& a, + const Float8_e5m2fnuz& b) { + a = a + b; + return a; +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz& operator-=( + Float8_e5m2fnuz& a, + const Float8_e5m2fnuz& b) { + a = a - b; + return a; +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz& operator*=( + Float8_e5m2fnuz& a, + const Float8_e5m2fnuz& b) { + a = a * b; + return a; +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz& operator/=( + Float8_e5m2fnuz& a, + const Float8_e5m2fnuz& b) { + a = a / b; + return a; +} + +/// Arithmetic with floats + +inline C10_HOST_DEVICE float operator+(Float8_e5m2fnuz a, float b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE float operator-(Float8_e5m2fnuz a, float b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE float operator*(Float8_e5m2fnuz a, float b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE float operator/(Float8_e5m2fnuz a, float b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE float operator+(float a, Float8_e5m2fnuz b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE float operator-(float a, Float8_e5m2fnuz b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE float operator*(float a, Float8_e5m2fnuz b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE float operator/(float a, Float8_e5m2fnuz b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e5m2fnuz& b) { + return a += static_cast(b); +} +inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e5m2fnuz& b) { + return a -= static_cast(b); +} +inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e5m2fnuz& b) { + return a *= static_cast(b); +} +inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e5m2fnuz& b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline C10_HOST_DEVICE double operator+(Float8_e5m2fnuz a, double b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE double operator-(Float8_e5m2fnuz a, double b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE double operator*(Float8_e5m2fnuz a, double b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE double operator/(Float8_e5m2fnuz a, double b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE double operator+(double a, Float8_e5m2fnuz b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE double operator-(double a, Float8_e5m2fnuz b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE double operator*(double a, Float8_e5m2fnuz b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE double operator/(double a, Float8_e5m2fnuz b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +/// Arithmetic with ints + +inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(Float8_e5m2fnuz a, int b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(Float8_e5m2fnuz a, int b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(Float8_e5m2fnuz a, int b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(Float8_e5m2fnuz a, int b) { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(int a, Float8_e5m2fnuz b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(int a, Float8_e5m2fnuz b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(int a, Float8_e5m2fnuz b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(int a, Float8_e5m2fnuz b) { + return static_cast(a) / b; +} + +//// Arithmetic with int64_t + +inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(Float8_e5m2fnuz a, int64_t b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(Float8_e5m2fnuz a, int64_t b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(Float8_e5m2fnuz a, int64_t b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(Float8_e5m2fnuz a, int64_t b) { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(int64_t a, Float8_e5m2fnuz b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(int64_t a, Float8_e5m2fnuz b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(int64_t a, Float8_e5m2fnuz b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(int64_t a, Float8_e5m2fnuz b) { + return static_cast(a) / b; +} + +/// NOTE: we do not define comparisons directly and instead rely on the implicit +/// conversion from c10::Float8_e5m2fnuz to float. + +} // namespace c10 + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_specialized = true; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = false; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = true; + static constexpr auto has_denorm_loss = true; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 3; + static constexpr int digits10 = 0; + static constexpr int max_digits10 = 2; + static constexpr int radix = 2; + static constexpr int min_exponent = -14; + static constexpr int min_exponent10 = -4; + static constexpr int max_exponent = 16; + static constexpr int max_exponent10 = 4; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = + numeric_limits::tinyness_before; + + static constexpr c10::Float8_e5m2fnuz min() { + return c10::Float8_e5m2fnuz(0x04, c10::Float8_e5m2fnuz::from_bits()); + } + static constexpr c10::Float8_e5m2fnuz max() { + return c10::Float8_e5m2fnuz(0x7F, c10::Float8_e5m2fnuz::from_bits()); + } + static constexpr c10::Float8_e5m2fnuz lowest() { + return c10::Float8_e5m2fnuz(0xFF, c10::Float8_e5m2fnuz::from_bits()); + } + static constexpr c10::Float8_e5m2fnuz epsilon() { + return c10::Float8_e5m2fnuz(0x34, c10::Float8_e5m2fnuz::from_bits()); + } + static constexpr c10::Float8_e5m2fnuz round_error() { + return c10::Float8_e5m2fnuz(0x38, c10::Float8_e5m2fnuz::from_bits()); + } + static constexpr c10::Float8_e5m2fnuz infinity() { + return c10::Float8_e5m2fnuz(0x80, c10::Float8_e5m2fnuz::from_bits()); + } + // TODO(future): we are mapping neg_zero to both inf and NaN, this is + // surprising and we should figure out what to do about it. + static constexpr c10::Float8_e5m2fnuz quiet_NaN() { + return c10::Float8_e5m2fnuz(0x80, c10::Float8_e5m2fnuz::from_bits()); + } + static constexpr c10::Float8_e5m2fnuz denorm_min() { + return c10::Float8_e5m2fnuz(0x01, c10::Float8_e5m2fnuz::from_bits()); + } +}; + +} // namespace std + +C10_CLANG_DIAGNOSTIC_POP() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/c10/util/Float8_e5m2fnuz.h b/c10/util/Float8_e5m2fnuz.h index 1f2d3db723d02..02eac597a0bee 100644 --- a/c10/util/Float8_e5m2fnuz.h +++ b/c10/util/Float8_e5m2fnuz.h @@ -1 +1,142 @@ +<<<<<<< HEAD #include +======= +#pragma once + +/// Defines the Float8_e5m2fnuz type (8-bit floating-point) including +/// conversions to standard C types and basic arithmetic operations. Note that +/// arithmetic operations are implemented by converting to floating point and +/// performing the operation in float32. +/// Binary configuration remains the same as e5m2: +/// s eeeee mm +/// 1 sign bit +/// 5 exponent bits +/// 2 mantissa bits +/// The key differences that e5m2fnuz brings are: +/// bias = 16 +/// no infinities or negative zero +/// NaN only when sign bit is 1, rest all 0s +/// +/// Implementation based on the paper https://arxiv.org/pdf/2206.02915.pdf and +/// the existing Float8_e4m3fn implementation. + +#include +#include +#include + +#if defined(__cplusplus) +#include +#elif !defined(__OPENCL_VERSION__) +#include +#include +#endif + +#include +#include + +namespace c10 { + +namespace detail { + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a + * 8-bit floating-point number in fp8 E5M2 format, in bit representation. + */ +inline C10_HOST_DEVICE uint8_t fp8e5m2fnuz_from_fp32_value(float f) { + /* + * Binary representation of 65536.0f, which is the first value not + * representable (i.e. the first value which would overflow in to the sign + * bit, resulting in a NaN) in fp8e4m3fnuz range: + * 1 00000 00 - fp8e5m2fnuz + * 0 10001111 00000000000000000000000 - fp32 + */ + constexpr uint32_t fnuz_max = UINT32_C(0x8F) << 23; + + /* + * A mask for converting fp32 numbers lower than fp8e5m2fnuz normal range + * into denormalized representation. + * magic number: ((127 - 16) + (23 - 2) + 1) + */ + constexpr uint32_t denorm_mask = UINT32_C(0x85) << 23; + + uint32_t f_bits = fp32_to_bits(f); + uint32_t result = 0u; + + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = f_bits & UINT32_C(0x80000000); + + /* + * Set sign bit to 0 + */ + f_bits ^= sign; + + if (f_bits >= fnuz_max) { + // NaN -- sign bit set to 1, rest 0s + return 0x80; + } + + if (f_bits < (UINT32_C(0x70) << 23) /* 2^-15 in float32 */) { + // Input exponent is less than -15, the smallest e5m2fnuz exponent, so the + // number will become subnormal. + f_bits = fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask)); + result = static_cast(f_bits - denorm_mask); + if (result == 0) { + // fnuz types don't have negative zero. + return 0; + } + } else { + // resulting mantissa is odd + uint8_t mant_odd = (f_bits >> 21) & 1; + + // update exponent, rounding bias part 1 + f_bits += ((uint32_t)(16 - 127) << 23) + 0xFFFFF; + + // rounding bias part 2 + f_bits += mant_odd; + + // take the bits! + result = static_cast(f_bits >> 21); + } + + result |= sign >> 24; + return result; +} + +} // namespace detail + +struct alignas(1) Float8_e5m2fnuz { + uint8_t x; + + struct from_bits_t {}; + C10_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + + Float8_e5m2fnuz() = default; + + constexpr C10_HOST_DEVICE Float8_e5m2fnuz(uint8_t bits, from_bits_t) + : x(bits) {} + inline C10_HOST_DEVICE Float8_e5m2fnuz(float value); + inline C10_HOST_DEVICE operator float() const; + inline C10_HOST_DEVICE bool isnan() const; + inline C10_HOST_DEVICE bool isinf() const; +}; + +C10_API inline std::ostream& operator<<( + std::ostream& out, + const Float8_e5m2fnuz& value) { + out << (float)value; + return out; +} + +} // namespace c10 + +#include // IWYU pragma: keep +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/c10/util/Float8_e8m0fnu-inl.h b/c10/util/Float8_e8m0fnu-inl.h index 9982faa07976b..2392cb8b0bc27 100644 --- a/c10/util/Float8_e8m0fnu-inl.h +++ b/c10/util/Float8_e8m0fnu-inl.h @@ -1 +1,116 @@ +<<<<<<< HEAD #include +======= +#pragma once + +#include +#include +#include +#include + +// TODO(#146647): Can we remove the below warning? +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +namespace c10 { + +/// Constructors + +inline C10_HOST_DEVICE Float8_e8m0fnu::Float8_e8m0fnu(float value) + : x(detail::fp8e8m0fnu_from_fp32_value(value)) {} + +/// Implicit conversions + +inline C10_HOST_DEVICE Float8_e8m0fnu::operator float() const { + // TODO(#146647): maybe rewrite without control flow + + // if exponent is zero, need to special case to return 2^-127 instead of zero + if (x == 0) { + return c10::detail::fp32_from_bits(0x00400000); + } + + // if exponent is NaN, need to special case to return properly encoded NaN + if (isnan()) { + return c10::detail::fp32_from_bits(0x7f800001); + } + + // leave sign at 0, set the exponent bits, leave stored mantissa at 0 + uint32_t res = x << 23; + + return c10::detail::fp32_from_bits(res); +} + +/// Special values helper + +inline C10_HOST_DEVICE bool Float8_e8m0fnu::isnan() const { + return x == 0b11111111; +} + +/// NOTE: we do not define comparisons directly and instead rely on the implicit +/// conversion from c10::Float8_e8m0fnu to float. + +} // namespace c10 + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_specialized = true; + static constexpr bool is_signed = false; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = false; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = false; + static constexpr auto has_denorm_loss = false; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 1; + static constexpr int digits10 = 0; + static constexpr int max_digits10 = 1; // just a 2! + static constexpr int radix = 2; + static constexpr int min_exponent = -126; + static constexpr int min_exponent10 = -38; + static constexpr int max_exponent = 128; + static constexpr int max_exponent10 = 38; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = false; + + static constexpr c10::Float8_e8m0fnu min() { + // 2^-127 + return c10::Float8_e8m0fnu(0b00000000, c10::Float8_e8m0fnu::from_bits()); + } + static constexpr c10::Float8_e8m0fnu lowest() { + // 2^-127 + return c10::Float8_e8m0fnu(0b00000000, c10::Float8_e8m0fnu::from_bits()); + } + static constexpr c10::Float8_e8m0fnu max() { + // 254 biased, which is 127 unbiased, so 2^127 + return c10::Float8_e8m0fnu(0b11111110, c10::Float8_e8m0fnu::from_bits()); + } + static constexpr c10::Float8_e8m0fnu epsilon() { + // according to https://en.cppreference.com/w/cpp/types/numeric_limits, this + // is "the difference between 1.0 and the next representable value of the + // given floating-point type". The next representable value is 2.0, so the + // difference is 1.0 which is 2^0. 0 unbiased is 127 biased. + return c10::Float8_e8m0fnu(0b01111111, c10::Float8_e8m0fnu::from_bits()); + } + static constexpr c10::Float8_e8m0fnu round_error() { + // 0.5 in float, which is 2^-1, and -1 + 127 = 126 + return c10::Float8_e8m0fnu(0b01111110, c10::Float8_e8m0fnu::from_bits()); + } + static constexpr c10::Float8_e8m0fnu quiet_NaN() { + return c10::Float8_e8m0fnu(0b11111111, c10::Float8_e8m0fnu::from_bits()); + } +}; + +} // namespace std + +C10_CLANG_DIAGNOSTIC_POP() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/c10/util/Float8_e8m0fnu.h b/c10/util/Float8_e8m0fnu.h index 9982faa07976b..338937f32e2e4 100644 --- a/c10/util/Float8_e8m0fnu.h +++ b/c10/util/Float8_e8m0fnu.h @@ -1 +1,124 @@ +<<<<<<< HEAD #include +======= +#pragma once + +/// Defines the Float8_e8m0fnu type (8-bit floating-point) including +/// conversions to standard C types +/// Binary configuration : +/// eeeeeeee +/// no sign bits +/// 8 exponent bits +/// no mantissa bits +/// +/// This is the E8M0 dtype from the OCP MX format spec +/// (https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf, +/// Section 5.4.1) + +#include +#include +#include +#include + +// TODO(#146647): do we need to special case OPENCL? +#if defined(__cplusplus) +#include +#elif !defined(__OPENCL_VERSION__) +#include +#include +#endif + +#include +#include + +namespace c10 { + +namespace detail { + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a + * 8-bit floating-point number in fp8 e8m0fnu format, in bit representation. + */ +inline C10_HOST_DEVICE uint8_t fp8e8m0fnu_from_fp32_value(float f) { + // TODO(#146647): maybe rewrite without control flow + + uint32_t f_bits = c10::detail::fp32_to_bits(f); + + // extract the exponent + uint32_t exponent = (f_bits >> 23) & 0b11111111; + + // special case float32 NaN and +-inf to map to e8m0 nan + if (exponent == 0b11111111) { + return exponent; + } + + // next, we use guard, round, sticky bits and the LSB to implement round to + // nearest, with ties to even + + // guard bit - bit 23, or 22 zero-indexed + uint8_t g = (f_bits & 0x400000) > 0; + // round bit - bit 22, or 21 zero-indexed + uint8_t r = (f_bits & 0x200000) > 0; + // sticky bit - bits 21 to 1, or 20 to 0 zero-indexed + uint8_t s = (f_bits & 0x1FFFFF) > 0; + // in casting to e8m0, LSB is the implied mantissa bit. It equals to 0 if the + // original float32 is denormal, and to 1 if the original float32 is normal. + uint8_t lsb = exponent > 0; + + // implement the RNE logic + bool round_up = false; + + // if g == 0, round down (no-op) + if (g == 1) { + if ((r == 1) || (s == 1)) { + // round up + round_up = true; + } else { + if (lsb == 1) { + // round up + round_up = true; + } + // if lsb == 0, round down (no-op) + } + } + + if (round_up) { + // adjust exponent + // note that if exponent was 255 we would have already returned earlier, so + // we know we can add one safely without running out of bounds + exponent++; + } + + return exponent; +} + +} // namespace detail + +struct alignas(1) Float8_e8m0fnu { + uint8_t x; + + struct from_bits_t {}; + C10_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + + Float8_e8m0fnu() = default; + + constexpr C10_HOST_DEVICE Float8_e8m0fnu(uint8_t bits, from_bits_t) + : x(bits) {} + inline C10_HOST_DEVICE Float8_e8m0fnu(float value); + inline C10_HOST_DEVICE operator float() const; + inline C10_HOST_DEVICE bool isnan() const; +}; + +C10_API inline std::ostream& operator<<( + std::ostream& out, + const Float8_e8m0fnu& value) { + out << (float)value; + return out; +} + +} // namespace c10 + +#include // IWYU pragma: keep +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/c10/util/Float8_fnuz_cvt.h b/c10/util/Float8_fnuz_cvt.h new file mode 100644 index 0000000000000..327f90d11a719 --- /dev/null +++ b/c10/util/Float8_fnuz_cvt.h @@ -0,0 +1,64 @@ +#pragma once + +#include + +#include + +#if defined(SYCL_LANGUAGE_VERSION) +#include +#endif + +namespace c10::detail { + +/* + * Convert a 8-bit floating-point number in either f8 E4M3FNUZ or bf8 E5M2FNUZ + * format, in bit representation, to a 32-bit floating-point number. + */ +template +inline C10_HOST_DEVICE float fp8_fnuz_to_fp32_value(uint8_t x) { + static_assert((we == 4 && wm == 3) || (we == 5 && wm == 2)); + constexpr uint32_t weo = 8; + constexpr uint32_t wmo = 23; + + if (x == 0) { + return 0; + } + + if (x == 0x80) { + constexpr uint32_t ifNaN = 0x7F800001; + return fp32_from_bits(ifNaN); + } + + uint32_t mantissa = x & ((1 << wm) - 1); + uint32_t exponent = (x & 0x7F) >> wm; + + // subnormal input + if (exponent == 0) { + // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + uint32_t renorm_shift = __clz(mantissa); +#elif defined(__SYCL_DEVICE_ONLY__) + uint32_t renorm_shift = sycl::clz(mantissa); +#elif defined(_MSC_VER) + unsigned long nonsign_bsr; + _BitScanReverse(&nonsign_bsr, (unsigned long)mantissa); + uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31; +#else + uint32_t renorm_shift = __builtin_clz(mantissa); +#endif + uint32_t sh = 1 + renorm_shift - (32 - wm); + mantissa <<= sh; + exponent += 1 - sh; + mantissa &= ((1 << wm) - 1); + } + + const uint32_t exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)); + exponent += exp_low_cutoff - 1; + mantissa <<= wmo - wm; + + uint32_t sign = x >> 7; + uint32_t retval = (sign << 31) | (exponent << 23) | mantissa; + return fp32_from_bits(retval); +} + +} // namespace c10::detail diff --git a/c10/util/Half-inl.h b/c10/util/Half-inl.h index fe66779a0e51d..90138f04399dc 100644 --- a/c10/util/Half-inl.h +++ b/c10/util/Half-inl.h @@ -1 +1,354 @@ +<<<<<<< HEAD #include +======= +#pragma once + +#include +#include + +#include +#include + +#ifdef __CUDACC__ +#include +#endif + +#ifdef __HIPCC__ +#include +#endif + +#if defined(CL_SYCL_LANGUAGE_VERSION) +#include // for SYCL 1.2.1 +#elif defined(SYCL_LANGUAGE_VERSION) +#include // for SYCL 2020 +#endif + +#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \ + !defined(__APPLE__) +#include +#endif + +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +namespace c10 { + +#if defined(__aarch64__) && !defined(__CUDACC__) +/// Constructors +inline Half::Half(float16_t value) : x(detail::fp16_to_bits(value)) {} +inline Half::operator float16_t() const { + return detail::fp16_from_bits(x); +} +#else + +inline C10_HOST_DEVICE Half::Half(float value) + : +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + x(__half_as_short(__float2half(value))) +#elif defined(__SYCL_DEVICE_ONLY__) + x(c10::bit_cast(sycl::half(value))) +#elif (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \ + !defined(__APPLE__) + x(at::vec::float2half_scalar(value)) +#else + x(detail::fp16_ieee_from_fp32_value(value)) +#endif +{ +} + +/// Implicit conversions + +inline C10_HOST_DEVICE Half::operator float() const { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + return __half2float(*reinterpret_cast(&x)); +#elif defined(__SYCL_DEVICE_ONLY__) + return float(c10::bit_cast(x)); +#elif (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \ + !defined(__APPLE__) + return at::vec::half2float_scalar(x); +#elif defined(__aarch64__) && !defined(__CUDACC__) + return detail::native_fp16_to_fp32_value(x); +#else + return detail::fp16_ieee_to_fp32_value(x); +#endif +} + +#endif /* !defined(__aarch64__) || defined(__CUDACC__) \ + */ + +#if defined(__CUDACC__) || defined(__HIPCC__) +inline C10_HOST_DEVICE Half::Half(const __half& value) { + x = *reinterpret_cast(&value); +} +inline C10_HOST_DEVICE Half::operator __half() const { + return *reinterpret_cast(&x); +} +#endif + +#ifdef SYCL_LANGUAGE_VERSION +inline C10_HOST_DEVICE Half::Half(const sycl::half& value) { + x = *reinterpret_cast(&value); +} +inline C10_HOST_DEVICE Half::operator sycl::half() const { + return *reinterpret_cast(&x); +} +#endif + +// CUDA intrinsics + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 350)) || \ + (defined(__clang__) && defined(__CUDA__)) +inline __device__ Half __ldg(const Half* ptr) { + return __ldg(reinterpret_cast(ptr)); +} +#endif + +/// Arithmetic + +inline C10_HOST_DEVICE Half operator+(const Half& a, const Half& b) { + return static_cast(a) + static_cast(b); +} + +inline C10_HOST_DEVICE Half operator-(const Half& a, const Half& b) { + return static_cast(a) - static_cast(b); +} + +inline C10_HOST_DEVICE Half operator*(const Half& a, const Half& b) { + return static_cast(a) * static_cast(b); +} + +inline C10_HOST_DEVICE Half operator/(const Half& a, const Half& b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / static_cast(b); +} + +inline C10_HOST_DEVICE Half operator-(const Half& a) { +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \ + defined(__HIP_DEVICE_COMPILE__) + return __hneg(a); +#elif defined(__SYCL_DEVICE_ONLY__) + return -c10::bit_cast(a); +#else + return -static_cast(a); +#endif +} + +inline C10_HOST_DEVICE Half& operator+=(Half& a, const Half& b) { + a = a + b; + return a; +} + +inline C10_HOST_DEVICE Half& operator-=(Half& a, const Half& b) { + a = a - b; + return a; +} + +inline C10_HOST_DEVICE Half& operator*=(Half& a, const Half& b) { + a = a * b; + return a; +} + +inline C10_HOST_DEVICE Half& operator/=(Half& a, const Half& b) { + a = a / b; + return a; +} + +/// Arithmetic with floats + +inline C10_HOST_DEVICE float operator+(Half a, float b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE float operator-(Half a, float b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE float operator*(Half a, float b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE float operator/(Half a, float b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE float operator+(float a, Half b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE float operator-(float a, Half b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE float operator*(float a, Half b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE float operator/(float a, Half b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE float& operator+=(float& a, const Half& b) { + return a += static_cast(b); +} +inline C10_HOST_DEVICE float& operator-=(float& a, const Half& b) { + return a -= static_cast(b); +} +inline C10_HOST_DEVICE float& operator*=(float& a, const Half& b) { + return a *= static_cast(b); +} +inline C10_HOST_DEVICE float& operator/=(float& a, const Half& b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline C10_HOST_DEVICE double operator+(Half a, double b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE double operator-(Half a, double b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE double operator*(Half a, double b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE double operator/(Half a, double b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE double operator+(double a, Half b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE double operator-(double a, Half b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE double operator*(double a, Half b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE double operator/(double a, Half b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +/// Arithmetic with ints + +inline C10_HOST_DEVICE Half operator+(Half a, int b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE Half operator-(Half a, int b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE Half operator*(Half a, int b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE Half operator/(Half a, int b) { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE Half operator+(int a, Half b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE Half operator-(int a, Half b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE Half operator*(int a, Half b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE Half operator/(int a, Half b) { + return static_cast(a) / b; +} + +//// Arithmetic with int64_t + +inline C10_HOST_DEVICE Half operator+(Half a, int64_t b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE Half operator-(Half a, int64_t b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE Half operator*(Half a, int64_t b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE Half operator/(Half a, int64_t b) { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE Half operator+(int64_t a, Half b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE Half operator-(int64_t a, Half b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE Half operator*(int64_t a, Half b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE Half operator/(int64_t a, Half b) { + return static_cast(a) / b; +} + +/// NOTE: we do not define comparisons directly and instead rely on the implicit +/// conversion from c10::Half to float. + +} // namespace c10 + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_specialized = true; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = true; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = true; + static constexpr auto has_denorm = numeric_limits::has_denorm; + static constexpr auto has_denorm_loss = + numeric_limits::has_denorm_loss; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = true; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 11; + static constexpr int digits10 = 3; + static constexpr int max_digits10 = 5; + static constexpr int radix = 2; + static constexpr int min_exponent = -13; + static constexpr int min_exponent10 = -4; + static constexpr int max_exponent = 16; + static constexpr int max_exponent10 = 4; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = + numeric_limits::tinyness_before; + static constexpr c10::Half min() { + return c10::Half(0x0400, c10::Half::from_bits()); + } + static constexpr c10::Half lowest() { + return c10::Half(0xFBFF, c10::Half::from_bits()); + } + static constexpr c10::Half max() { + return c10::Half(0x7BFF, c10::Half::from_bits()); + } + static constexpr c10::Half epsilon() { + return c10::Half(0x1400, c10::Half::from_bits()); + } + static constexpr c10::Half round_error() { + return c10::Half(0x3800, c10::Half::from_bits()); + } + static constexpr c10::Half infinity() { + return c10::Half(0x7C00, c10::Half::from_bits()); + } + static constexpr c10::Half quiet_NaN() { + return c10::Half(0x7E00, c10::Half::from_bits()); + } + static constexpr c10::Half signaling_NaN() { + return c10::Half(0x7D00, c10::Half::from_bits()); + } + static constexpr c10::Half denorm_min() { + return c10::Half(0x0001, c10::Half::from_bits()); + } +}; + +} // namespace std + +C10_CLANG_DIAGNOSTIC_POP() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/c10/util/Half.h b/c10/util/Half.h index 98480b22db334..5071731362e31 100644 --- a/c10/util/Half.h +++ b/c10/util/Half.h @@ -1,3 +1,4 @@ +<<<<<<< HEAD #include // need to keep the following for BC because the APIs in here were exposed @@ -6,3 +7,429 @@ !defined(__APPLE__) #include #endif +======= +#pragma once + +/// Defines the Half type (half-precision floating-point) including conversions +/// to standard C types and basic arithmetic operations. Note that arithmetic +/// operations are implemented by converting to floating point and +/// performing the operation in float32, instead of using CUDA half intrinsics. +/// Most uses of this type within ATen are memory bound, including the +/// element-wise kernels, and the half intrinsics aren't efficient on all GPUs. +/// If you are writing a compute bound kernel, you can use the CUDA half +/// intrinsics directly on the Half type from device code. + +#include +#include +#include +#include +#include + +#if defined(__cplusplus) +#include +#elif !defined(__OPENCL_VERSION__) +#include +#endif + +#ifdef _MSC_VER +#include +#endif + +#include +#include +#include +#include +#include + +#ifdef __CUDACC__ +#include +#endif + +#ifdef __HIPCC__ +#include +#endif + +#if defined(CL_SYCL_LANGUAGE_VERSION) +#include // for SYCL 1.2.1 +#elif defined(SYCL_LANGUAGE_VERSION) +#include // for SYCL 2020 +#endif + +#if defined(__aarch64__) && !defined(__CUDACC__) +#include +#endif + +#if defined(__GNUC__) || defined(__clang__) +#if defined(__x86_64__) || defined(_M_X64) || defined(__i386) || \ + defined(_M_IX86) +#if defined(__F16C__) && \ + !(defined(__CUDA_ARCH__) || defined(__CUDACC__) || \ + defined(__HIP_DEVICE_COMPILE__)) +#define C10_X86_F16 1 +#include // import conversion ops from f16cintrin.h +#endif // defined(__F16C__) && !(defined(__CUDA_ARCH__) || defined(__CUDACC__) + // || defined(__HIP_DEVICE_COMPILE__)) +#endif // __x86_64__ || _M_X64 || __i386 || _M_IX86 +#endif // __GNUC__ || __clang__ + +namespace c10 { + +namespace detail { + +/* + * Convert a 16-bit floating-point number in IEEE half-precision format, in bit + * representation, to a 32-bit floating-point number in IEEE single-precision + * format, in bit representation. + * + * @note The implementation doesn't use any floating-point operations. + */ +inline uint32_t fp16_ieee_to_fp32_bits(uint16_t h) { + /* + * Extend the half-precision floating-point number to 32 bits and shift to the + * upper part of the 32-bit word: + * +---+-----+------------+-------------------+ + * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 31 26-30 16-25 0-15 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 + * - zero bits. + */ + const uint32_t w = (uint32_t)h << 16; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the bits 0-30 + * of the 32-bit word: + * + * +---+-----+------------+-------------------+ + * | 0 |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 30 27-31 17-26 0-16 + */ + const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF); + /* + * Renorm shift is the number of bits to shift mantissa left to make the + * half-precision number normalized. If the initial number is normalized, some + * of its high 6 bits (sign == 0 and 5-bit exponent) equals one. In this case + * renorm_shift == 0. If the number is denormalize, renorm_shift > 0. Note + * that if we shift denormalized nonsign by renorm_shift, the unit bit of + * mantissa will shift into exponent, turning the biased exponent into 1, and + * making mantissa normalized (i.e. without leading 1). + */ +#ifdef _MSC_VER + unsigned long nonsign_bsr; + _BitScanReverse(&nonsign_bsr, (unsigned long)nonsign); + uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31; +#else + uint32_t renorm_shift = __builtin_clz(nonsign); +#endif + renorm_shift = renorm_shift > 5 ? renorm_shift - 5 : 0; + /* + * Iff half-precision number has exponent of 15, the addition overflows + * it into bit 31, and the subsequent shift turns the high 9 bits + * into 1. Thus inf_nan_mask == 0x7F800000 if the half-precision number + * had exponent of 15 (i.e. was NaN or infinity) 0x00000000 otherwise + */ + const int32_t inf_nan_mask = + ((int32_t)(nonsign + 0x04000000) >> 8) & INT32_C(0x7F800000); + /* + * Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31 + * into 1. Otherwise, bit 31 remains 0. The signed shift right by 31 + * broadcasts bit 31 into all bits of the zero_mask. Thus zero_mask == + * 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h) + * 0x00000000 otherwise + */ + const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31; + /* + * 1. Shift nonsign left by renorm_shift to normalize it (if the input + * was denormal) + * 2. Shift nonsign right by 3 so the exponent (5 bits originally) + * becomes an 8-bit field and 10-bit mantissa shifts into the 10 high + * bits of the 23-bit mantissa of IEEE single-precision number. + * 3. Add 0x70 to the exponent (starting at bit 23) to compensate the + * different in exponent bias (0x7F for single-precision number less 0xF + * for half-precision number). + * 4. Subtract renorm_shift from the exponent (starting at bit 23) to + * account for renormalization. As renorm_shift is less than 0x70, this + * can be combined with step 3. + * 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the + * input was NaN or infinity. + * 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent + * into zero if the input was zero. + * 7. Combine with the sign of the input number. + */ + return sign | + ((((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)) | + inf_nan_mask) & + ~zero_mask); +} + +/* + * Convert a 16-bit floating-point number in IEEE half-precision format, in bit + * representation, to a 32-bit floating-point number in IEEE single-precision + * format. + * + * @note The implementation relies on IEEE-like (no assumption about rounding + * mode and no operations on denormals) floating-point operations and bitcasts + * between integer and floating-point variables. + */ +C10_HOST_DEVICE inline float fp16_ieee_to_fp32_value(uint16_t h) { +#ifdef C10_X86_F16 + return _cvtsh_ss(h); +#else + /* + * Extend the half-precision floating-point number to 32 bits and shift to the + * upper part of the 32-bit word: + * +---+-----+------------+-------------------+ + * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 31 26-30 16-25 0-15 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 + * - zero bits. + */ + const uint32_t w = (uint32_t)h << 16; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the high bits + * of the 32-bit word: + * + * +-----+------------+---------------------+ + * |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000| + * +-----+------------+---------------------+ + * Bits 27-31 17-26 0-16 + */ + const uint32_t two_w = w + w; + + /* + * Shift mantissa and exponent into bits 23-28 and bits 13-22 so they become + * mantissa and exponent of a single-precision floating-point number: + * + * S|Exponent | Mantissa + * +-+---+-----+------------+----------------+ + * |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000| + * +-+---+-----+------------+----------------+ + * Bits | 23-31 | 0-22 + * + * Next, there are some adjustments to the exponent: + * - The exponent needs to be corrected by the difference in exponent bias + * between single-precision and half-precision formats (0x7F - 0xF = 0x70) + * - Inf and NaN values in the inputs should become Inf and NaN values after + * conversion to the single-precision number. Therefore, if the biased + * exponent of the half-precision input was 0x1F (max possible value), the + * biased exponent of the single-precision output must be 0xFF (max possible + * value). We do this correction in two steps: + * - First, we adjust the exponent by (0xFF - 0x1F) = 0xE0 (see exp_offset + * below) rather than by 0x70 suggested by the difference in the exponent bias + * (see above). + * - Then we multiply the single-precision result of exponent adjustment by + * 2**(-112) to reverse the effect of exponent adjustment by 0xE0 less the + * necessary exponent adjustment by 0x70 due to difference in exponent bias. + * The floating-point multiplication hardware would ensure than Inf and + * NaN would retain their value on at least partially IEEE754-compliant + * implementations. + * + * Note that the above operations do not handle denormal inputs (where biased + * exponent == 0). However, they also do not operate on denormal inputs, and + * do not produce denormal results. + */ + constexpr uint32_t exp_offset = UINT32_C(0xE0) << 23; + // const float exp_scale = 0x1.0p-112f; + constexpr uint32_t scale_bits = (uint32_t)15 << 23; + float exp_scale_val = 0; +#if defined(_MSC_VER) && defined(__clang__) + __builtin_memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val)); +#else + std::memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val)); +#endif + + const float exp_scale = exp_scale_val; + const float normalized_value = + fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; + + /* + * Convert denormalized half-precision inputs into single-precision results + * (always normalized). Zero inputs are also handled here. + * + * In a denormalized number the biased exponent is zero, and mantissa has + * on-zero bits. First, we shift mantissa into bits 0-9 of the 32-bit word. + * + * zeros | mantissa + * +---------------------------+------------+ + * |0000 0000 0000 0000 0000 00|MM MMMM MMMM| + * +---------------------------+------------+ + * Bits 10-31 0-9 + * + * Now, remember that denormalized half-precision numbers are represented as: + * FP16 = mantissa * 2**(-24). + * The trick is to construct a normalized single-precision number with the + * same mantissa and thehalf-precision input and with an exponent which would + * scale the corresponding mantissa bits to 2**(-24). A normalized + * single-precision floating-point number is represented as: FP32 = (1 + + * mantissa * 2**(-23)) * 2**(exponent - 127) Therefore, when the biased + * exponent is 126, a unit change in the mantissa of the input denormalized + * half-precision number causes a change of the constructed single-precision + * number by 2**(-24), i.e. the same amount. + * + * The last step is to adjust the bias of the constructed single-precision + * number. When the input half-precision number is zero, the constructed + * single-precision number has the value of FP32 = 1 * 2**(126 - 127) = + * 2**(-1) = 0.5 Therefore, we need to subtract 0.5 from the constructed + * single-precision number to get the numerical equivalent of the input + * half-precision number. + */ + constexpr uint32_t magic_mask = UINT32_C(126) << 23; + constexpr float magic_bias = 0.5f; + const float denormalized_value = + fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; + + /* + * - Choose either results of conversion of input as a normalized number, or + * as a denormalized number, depending on the input exponent. The variable + * two_w contains input exponent in bits 27-31, therefore if its smaller than + * 2**27, the input is either a denormal number, or zero. + * - Combine the result of conversion of exponent and mantissa with the sign + * of the input number. + */ + constexpr uint32_t denormalized_cutoff = UINT32_C(1) << 27; + const uint32_t result = sign | + (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) + : fp32_to_bits(normalized_value)); + return fp32_from_bits(result); +#endif // C10_X86_F16 +} + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a + * 16-bit floating-point number in IEEE half-precision format, in bit + * representation. + * + * @note The implementation relies on IEEE-like (no assumption about rounding + * mode and no operations on denormals) floating-point operations and bitcasts + * between integer and floating-point variables. + */ +inline uint16_t fp16_ieee_from_fp32_value(float f) { +#ifdef C10_X86_F16 + return _cvtss_sh(f, _MM_FROUND_TO_NEAREST_INT); +#else + // const float scale_to_inf = 0x1.0p+112f; + // const float scale_to_zero = 0x1.0p-110f; + constexpr uint32_t scale_to_inf_bits = (uint32_t)239 << 23; + constexpr uint32_t scale_to_zero_bits = (uint32_t)17 << 23; + float scale_to_inf_val = 0, scale_to_zero_val = 0; + std::memcpy(&scale_to_inf_val, &scale_to_inf_bits, sizeof(scale_to_inf_val)); + std::memcpy( + &scale_to_zero_val, &scale_to_zero_bits, sizeof(scale_to_zero_val)); + const float scale_to_inf = scale_to_inf_val; + const float scale_to_zero = scale_to_zero_val; + +#if defined(_MSC_VER) && _MSC_VER == 1916 + float base = ((signbit(f) != 0 ? -f : f) * scale_to_inf) * scale_to_zero; +#else + float base = (fabsf(f) * scale_to_inf) * scale_to_zero; +#endif + + const uint32_t w = fp32_to_bits(f); + const uint32_t shl1_w = w + w; + const uint32_t sign = w & UINT32_C(0x80000000); + uint32_t bias = shl1_w & UINT32_C(0xFF000000); + if (bias < UINT32_C(0x71000000)) { + bias = UINT32_C(0x71000000); + } + + base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; + const uint32_t bits = fp32_to_bits(base); + const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); + const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); + const uint32_t nonsign = exp_bits + mantissa_bits; + return static_cast( + (sign >> 16) | + (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign)); +#endif // C10_X86_F16 +} + +#ifdef C10_X86_F16 +#undef C10_X86_F16 +#endif // C10_X86_F16 + +#if defined(__aarch64__) && !defined(__CUDACC__) +inline float16_t fp16_from_bits(uint16_t h) { + return c10::bit_cast(h); +} + +inline uint16_t fp16_to_bits(float16_t f) { + return c10::bit_cast(f); +} + +// According to https://godbolt.org/z/frExdbsWG it would translate to single +// fcvt s0, h0 +inline float native_fp16_to_fp32_value(uint16_t h) { + return static_cast(fp16_from_bits(h)); +} + +inline uint16_t native_fp16_from_fp32_value(float f) { + return fp16_to_bits(static_cast(f)); +} +#endif + +} // namespace detail + +struct alignas(2) Half { + unsigned short x; + + struct from_bits_t {}; + C10_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + + // HIP wants __host__ __device__ tag, CUDA does not +#if defined(USE_ROCM) + C10_HOST_DEVICE Half() = default; +#else + Half() = default; +#endif + + constexpr C10_HOST_DEVICE Half(unsigned short bits, from_bits_t) : x(bits) {} +#if defined(__aarch64__) && !defined(__CUDACC__) + inline Half(float16_t value); + inline operator float16_t() const; +#else + inline C10_HOST_DEVICE Half(float value); + inline C10_HOST_DEVICE operator float() const; +#endif + +#if defined(__CUDACC__) || defined(__HIPCC__) + inline C10_HOST_DEVICE Half(const __half& value); + inline C10_HOST_DEVICE operator __half() const; +#endif +#ifdef SYCL_LANGUAGE_VERSION + inline C10_HOST_DEVICE Half(const sycl::half& value); + inline C10_HOST_DEVICE operator sycl::half() const; +#endif +}; + +C10_API inline std::ostream& operator<<(std::ostream& out, const Half& value) { + out << (float)value; + return out; +} + +} // namespace c10 + +#include // IWYU pragma: keep +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/c10/util/Logging.h b/c10/util/Logging.h index 2a08b1f1ce698..f78d39a10a001 100644 --- a/c10/util/Logging.h +++ b/c10/util/Logging.h @@ -79,7 +79,11 @@ C10_API void UpdateLoggingLevelsFromFlags(); const char* msg, const void* caller = nullptr); +<<<<<<< HEAD [[noreturn]] inline void ThrowEnforceNotMet( +======= +[[noreturn]] C10_API inline void ThrowEnforceNotMet( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const char* file, const int line, const char* condition, @@ -102,7 +106,11 @@ C10_API void UpdateLoggingLevelsFromFlags(); const char* msg, const void* caller = nullptr); +<<<<<<< HEAD [[noreturn]] inline void ThrowEnforceFiniteNotMet( +======= +[[noreturn]] C10_API inline void ThrowEnforceFiniteNotMet( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const char* file, const int line, const char* condition, diff --git a/c10/util/TypeSafeSignMath.h b/c10/util/TypeSafeSignMath.h index 28520225d4b26..7482b16dd27ea 100644 --- a/c10/util/TypeSafeSignMath.h +++ b/c10/util/TypeSafeSignMath.h @@ -1 +1,144 @@ +<<<<<<< HEAD #include +======= +#pragma once + +#include +#include +#include + +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wstring-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wstring-conversion") +#endif +#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +namespace c10 { + +/// Returns false since we cannot have x < 0 if x is unsigned. +template +inline constexpr bool is_negative( + const T& /*x*/, + std::true_type /*is_unsigned*/) { + return false; +} + +/// Returns true if a signed variable x < 0 +template +inline constexpr bool is_negative(const T& x, std::false_type /*is_unsigned*/) { + return x < T(0); +} + +/// Returns true if x < 0 +/// NOTE: Will fail on an unsigned custom type +/// For the most part it's possible to fix this if +/// the custom type has a constexpr constructor. +/// However, notably, c10::Half does not :-( +template +inline constexpr bool is_negative(const T& x) { + return is_negative(x, std::is_unsigned()); +} + +/// Returns the sign of an unsigned variable x as 0, 1 +template +inline constexpr int signum(const T& x, std::true_type /*is_unsigned*/) { + return T(0) < x; +} + +/// Returns the sign of a signed variable x as -1, 0, 1 +template +inline constexpr int signum(const T& x, std::false_type /*is_unsigned*/) { + return (T(0) < x) - (x < T(0)); +} + +/// Returns the sign of x as -1, 0, 1 +/// NOTE: Will fail on an unsigned custom type +/// For the most part it's possible to fix this if +/// the custom type has a constexpr constructor. +/// However, notably, c10::Half does not :-( +template +inline constexpr int signum(const T& x) { + return signum(x, std::is_unsigned()); +} + +/// Returns true if a and b are not both negative +template +inline constexpr bool signs_differ(const T& a, const U& b) { + return is_negative(a) != is_negative(b); +} + +// Suppress sign compare warning when compiling with GCC +// as later does not account for short-circuit rule before +// raising the warning, see https://godbolt.org/z/Tr3Msnz99 +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wsign-compare" +#endif + +/// Returns true if x is greater than the greatest value of the type Limit +template +inline constexpr bool greater_than_max(const T& x) { + constexpr bool can_overflow = + std::numeric_limits::digits > std::numeric_limits::digits; + return can_overflow && x > std::numeric_limits::max(); +} + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif + +/// Returns true if x < lowest(Limit). Standard comparison +template +inline constexpr bool less_than_lowest( + const T& x, + std::false_type /*limit_is_unsigned*/, + std::false_type /*x_is_unsigned*/) { + return x < std::numeric_limits::lowest(); +} + +/// Returns false since all the limit is signed and therefore includes +/// negative values but x cannot be negative because it is unsigned +template +inline constexpr bool less_than_lowest( + const T& /*x*/, + std::false_type /*limit_is_unsigned*/, + std::true_type /*x_is_unsigned*/) { + return false; +} + +/// Returns true if x < 0, where 0 is constructed from T. +/// Limit is not signed, so its lower value is zero +template +inline constexpr bool less_than_lowest( + const T& x, + std::true_type /*limit_is_unsigned*/, + std::false_type /*x_is_unsigned*/) { + return x < T(0); +} + +/// Returns false sign both types are unsigned +template +inline constexpr bool less_than_lowest( + const T& /*x*/, + std::true_type /*limit_is_unsigned*/, + std::true_type /*x_is_unsigned*/) { + return false; +} + +/// Returns true if x is less than the lowest value of type T +/// NOTE: Will fail on an unsigned custom type +/// For the most part it's possible to fix this if +/// the custom type has a constexpr constructor. +/// However, notably, c10::Half does not : +template +inline constexpr bool less_than_lowest(const T& x) { + return less_than_lowest( + x, std::is_unsigned(), std::is_unsigned()); +} + +} // namespace c10 + +C10_CLANG_DIAGNOSTIC_POP() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/c10/util/WaitCounter.h b/c10/util/WaitCounter.h index c87c2e3293e5d..c2ca90018c037 100644 --- a/c10/util/WaitCounter.h +++ b/c10/util/WaitCounter.h @@ -3,7 +3,10 @@ #include #include #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include diff --git a/c10/util/bit_cast.h b/c10/util/bit_cast.h index 49d0822d94f1b..0c7251cc64260 100644 --- a/c10/util/bit_cast.h +++ b/c10/util/bit_cast.h @@ -1 +1,48 @@ +<<<<<<< HEAD #include +======= +#pragma once + +#include +#include + +#if __has_include() && (defined(__cpp_lib_bit_cast) && __cpp_lib_bit_cast >= 201806L) +#include +#define C10_HAVE_STD_BIT_CAST 1 +#else +#define C10_HAVE_STD_BIT_CAST 0 +#endif // __has_include() && (__cplusplus >= 202002L || + // (defined(__cpp_lib_bit_cast) && __cpp_lib_bit_cast >= 201806L)) + +namespace c10 { + +#if C10_HAVE_STD_BIT_CAST +using std::bit_cast; +#else +// Implementations of std::bit_cast() from C++ 20. +// +// This is a less sketchy version of reinterpret_cast. +// +// See https://en.cppreference.com/w/cpp/numeric/bit_cast for more +// information as well as the source of our implementations. +template +std::enable_if_t< + sizeof(To) == sizeof(From) && std::is_trivially_copyable_v && + std::is_trivially_copyable_v, + To> +// constexpr support needs compiler magic +bit_cast(const From& src) noexcept { + static_assert( + std::is_trivially_constructible_v, + "This implementation additionally requires " + "destination type to be trivially constructible"); + + To dst; + std::memcpy(&dst, &src, sizeof(To)); + return dst; +} +#endif // C10_HAVE_STD_BIT_CAST +#undef C10_HAVE_STD_BIT_CAST + +} // namespace c10 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/c10/util/bits.h b/c10/util/bits.h index 1e3c4e5151aed..ba753fb920280 100644 --- a/c10/util/bits.h +++ b/c10/util/bits.h @@ -1 +1,65 @@ +<<<<<<< HEAD #include +======= +#pragma once +#include + +#include + +namespace c10 { + +/** + * bits1x8 is an uninterpreted dtype of a tensor with 1 bit (packed to byte + * boundary), without any semantics defined. + */ +struct alignas(1) bits1x8 { + using underlying = uint8_t; + uint8_t val_; + bits1x8() = default; + C10_HOST_DEVICE explicit bits1x8(uint8_t val) : val_(val) {} +}; + +/** + * bits2x4 is an uninterpreted dtype of a tensor with 2 bits (packed to byte + * boundary), without any semantics defined. + */ +struct alignas(1) bits2x4 { + using underlying = uint8_t; + uint8_t val_; + bits2x4() = default; + C10_HOST_DEVICE explicit bits2x4(uint8_t val) : val_(val) {} +}; + +/** + * bits4x2 is an uninterpreted dtype of a tensor with 4 bits (packed to byte + * boundary), without any semantics defined. + */ +struct alignas(1) bits4x2 { + using underlying = uint8_t; + uint8_t val_; + bits4x2() = default; + C10_HOST_DEVICE explicit bits4x2(uint8_t val) : val_(val) {} +}; + +/** + * bits8 is an uninterpreted dtype of a tensor with 8 bits, without any + * semantics defined. + */ +struct alignas(1) bits8 { + uint8_t val_; + bits8() = default; + C10_HOST_DEVICE explicit bits8(uint8_t val) : val_(val) {} +}; + +/** + * bits16 is an uninterpreted dtype of a tensor with 16 bits, without any + * semantics defined. + */ +struct alignas(2) bits16 { + uint16_t val_; + bits16() = default; + C10_HOST_DEVICE explicit bits16(uint16_t val) : val_(val) {} +}; + +} // namespace c10 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/c10/util/build.bzl b/c10/util/build.bzl index f061d28b4ad29..53ecf0b8e154c 100644 --- a/c10/util/build.bzl +++ b/c10/util/build.bzl @@ -58,9 +58,12 @@ def define_targets(rules): name = "bit_cast", hdrs = ["bit_cast.h"], visibility = ["//:__subpackages__"], +<<<<<<< HEAD deps = [ "//c10/macros", ], +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) rules.cc_library( diff --git a/c10/util/complex.h b/c10/util/complex.h index 4e699684bc38f..3baf8d1a85bf4 100644 --- a/c10/util/complex.h +++ b/c10/util/complex.h @@ -4,7 +4,535 @@ #include #include +<<<<<<< HEAD #include +======= + +#if defined(__CUDACC__) || defined(__HIPCC__) +#include +#endif + +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion") +#endif +#if C10_CLANG_HAS_WARNING("-Wfloat-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wfloat-conversion") +#endif + +namespace c10 { + +// c10::complex is an implementation of complex numbers that aims +// to work on all devices supported by PyTorch +// +// Most of the APIs duplicates std::complex +// Reference: https://en.cppreference.com/w/cpp/numeric/complex +// +// [NOTE: Complex Operator Unification] +// Operators currently use a mix of std::complex, thrust::complex, and +// c10::complex internally. The end state is that all operators will use +// c10::complex internally. Until then, there may be some hacks to support all +// variants. +// +// +// [Note on Constructors] +// +// The APIs of constructors are mostly copied from C++ standard: +// https://en.cppreference.com/w/cpp/numeric/complex/complex +// +// Since C++14, all constructors are constexpr in std::complex +// +// There are three types of constructors: +// - initializing from real and imag: +// `constexpr complex( const T& re = T(), const T& im = T() );` +// - implicitly-declared copy constructor +// - converting constructors +// +// Converting constructors: +// - std::complex defines converting constructor between float/double/long +// double, +// while we define converting constructor between float/double. +// - For these converting constructors, upcasting is implicit, downcasting is +// explicit. +// - We also define explicit casting from std::complex/thrust::complex +// - Note that the conversion from thrust is not constexpr, because +// thrust does not define them as constexpr ???? +// +// +// [Operator =] +// +// The APIs of operator = are mostly copied from C++ standard: +// https://en.cppreference.com/w/cpp/numeric/complex/operator%3D +// +// Since C++20, all operator= are constexpr. Although we are not building with +// C++20, we also obey this behavior. +// +// There are three types of assign operator: +// - Assign a real value from the same scalar type +// - In std, this is templated as complex& operator=(const T& x) +// with specialization `complex& operator=(T x)` for float/double/long +// double Since we only support float and double, on will use `complex& +// operator=(T x)` +// - Copy assignment operator and converting assignment operator +// - There is no specialization of converting assignment operators, which type +// is +// convertible is solely dependent on whether the scalar type is convertible +// +// In addition to the standard assignment, we also provide assignment operators +// with std and thrust +// +// +// [Casting operators] +// +// std::complex does not have casting operators. We define casting operators +// casting to std::complex and thrust::complex +// +// +// [Operator ""] +// +// std::complex has custom literals `i`, `if` and `il` defined in namespace +// `std::literals::complex_literals`. We define our own custom literals in the +// namespace `c10::complex_literals`. Our custom literals does not follow the +// same behavior as in std::complex, instead, we define _if, _id to construct +// float/double complex literals. +// +// +// [real() and imag()] +// +// In C++20, there are two overload of these functions, one it to return the +// real/imag, another is to set real/imag, they are both constexpr. We follow +// this design. +// +// +// [Operator +=,-=,*=,/=] +// +// Since C++20, these operators become constexpr. In our implementation, they +// are also constexpr. +// +// There are two types of such operators: operating with a real number, or +// operating with another complex number. For the operating with a real number, +// the generic template form has argument type `const T &`, while the overload +// for float/double/long double has `T`. We will follow the same type as +// float/double/long double in std. +// +// [Unary operator +-] +// +// Since C++20, they are constexpr. We also make them expr +// +// [Binary operators +-*/] +// +// Each operator has three versions (taking + as example): +// - complex + complex +// - complex + real +// - real + complex +// +// [Operator ==, !=] +// +// Each operator has three versions (taking == as example): +// - complex == complex +// - complex == real +// - real == complex +// +// Some of them are removed on C++20, but we decide to keep them +// +// [Operator <<, >>] +// +// These are implemented by casting to std::complex +// +// +// +// TODO(@zasdfgbnm): c10::complex is not currently supported, +// because: +// - lots of members and functions of c10::Half are not constexpr +// - thrust::complex only support float and double + +template +struct alignas(sizeof(T) * 2) complex { + using value_type = T; + + T real_ = T(0); + T imag_ = T(0); + + constexpr complex() = default; + C10_HOST_DEVICE constexpr complex(const T& re, const T& im = T()) + : real_(re), imag_(im) {} + template + explicit constexpr complex(const std::complex& other) + : complex(other.real(), other.imag()) {} +#if defined(__CUDACC__) || defined(__HIPCC__) + template + explicit C10_HOST_DEVICE complex(const thrust::complex& other) + : real_(other.real()), imag_(other.imag()) {} +// NOTE can not be implemented as follow due to ROCm bug: +// explicit C10_HOST_DEVICE complex(const thrust::complex &other): +// complex(other.real(), other.imag()) {} +#endif + + // Use SFINAE to specialize casting constructor for c10::complex and + // c10::complex + template + C10_HOST_DEVICE explicit constexpr complex( + const std::enable_if_t, complex>& other) + : real_(other.real_), imag_(other.imag_) {} + template + C10_HOST_DEVICE constexpr complex( + const std::enable_if_t, complex>& other) + : real_(other.real_), imag_(other.imag_) {} + + constexpr complex& operator=(T re) { + real_ = re; + imag_ = 0; + return *this; + } + + constexpr complex& operator+=(T re) { + real_ += re; + return *this; + } + + constexpr complex& operator-=(T re) { + real_ -= re; + return *this; + } + + constexpr complex& operator*=(T re) { + real_ *= re; + imag_ *= re; + return *this; + } + + constexpr complex& operator/=(T re) { + real_ /= re; + imag_ /= re; + return *this; + } + + template + constexpr complex& operator=(const complex& rhs) { + real_ = rhs.real(); + imag_ = rhs.imag(); + return *this; + } + + template + constexpr complex& operator+=(const complex& rhs) { + real_ += rhs.real(); + imag_ += rhs.imag(); + return *this; + } + + template + constexpr complex& operator-=(const complex& rhs) { + real_ -= rhs.real(); + imag_ -= rhs.imag(); + return *this; + } + + template + constexpr complex& operator*=(const complex& rhs) { + // (a + bi) * (c + di) = (a*c - b*d) + (a * d + b * c) i + T a = real_; + T b = imag_; + U c = rhs.real(); + U d = rhs.imag(); + real_ = a * c - b * d; + imag_ = a * d + b * c; + return *this; + } + +#ifdef __APPLE__ +#define FORCE_INLINE_APPLE __attribute__((always_inline)) +#else +#define FORCE_INLINE_APPLE +#endif + template + constexpr FORCE_INLINE_APPLE complex& operator/=(const complex& rhs) + __ubsan_ignore_float_divide_by_zero__ { + // (a + bi) / (c + di) = (ac + bd)/(c^2 + d^2) + (bc - ad)/(c^2 + d^2) i + // the calculation below follows numpy's complex division + T a = real_; + T b = imag_; + U c = rhs.real(); + U d = rhs.imag(); + +#if defined(__GNUC__) && !defined(__clang__) + // std::abs is already constexpr by gcc + auto abs_c = std::abs(c); + auto abs_d = std::abs(d); +#else + auto abs_c = c < 0 ? -c : c; + auto abs_d = d < 0 ? -d : d; +#endif + + if (abs_c >= abs_d) { + if (abs_c == U(0) && abs_d == U(0)) { + /* divide by zeros should yield a complex inf or nan */ + real_ = a / abs_c; + imag_ = b / abs_d; + } else { + auto rat = d / c; + auto scl = U(1.0) / (c + d * rat); + real_ = (a + b * rat) * scl; + imag_ = (b - a * rat) * scl; + } + } else { + auto rat = c / d; + auto scl = U(1.0) / (d + c * rat); + real_ = (a * rat + b) * scl; + imag_ = (b * rat - a) * scl; + } + return *this; + } +#undef FORCE_INLINE_APPLE + + template + constexpr complex& operator=(const std::complex& rhs) { + real_ = rhs.real(); + imag_ = rhs.imag(); + return *this; + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + template + C10_HOST_DEVICE complex& operator=(const thrust::complex& rhs) { + real_ = rhs.real(); + imag_ = rhs.imag(); + return *this; + } +#endif + + template + explicit constexpr operator std::complex() const { + return std::complex(std::complex(real(), imag())); + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + template + C10_HOST_DEVICE explicit operator thrust::complex() const { + return static_cast>(thrust::complex(real(), imag())); + } +#endif + + // consistent with NumPy behavior + explicit constexpr operator bool() const { + return real() || imag(); + } + + C10_HOST_DEVICE constexpr T real() const { + return real_; + } + constexpr void real(T value) { + real_ = value; + } + C10_HOST_DEVICE constexpr T imag() const { + return imag_; + } + constexpr void imag(T value) { + imag_ = value; + } +}; + +namespace complex_literals { + +constexpr complex operator""_if(long double imag) { + return complex(0.0f, static_cast(imag)); +} + +constexpr complex operator""_id(long double imag) { + return complex(0.0, static_cast(imag)); +} + +constexpr complex operator""_if(unsigned long long imag) { + return complex(0.0f, static_cast(imag)); +} + +constexpr complex operator""_id(unsigned long long imag) { + return complex(0.0, static_cast(imag)); +} + +} // namespace complex_literals + +template +constexpr complex operator+(const complex& val) { + return val; +} + +template +constexpr complex operator-(const complex& val) { + return complex(-val.real(), -val.imag()); +} + +template +constexpr complex operator+(const complex& lhs, const complex& rhs) { + complex result = lhs; + return result += rhs; +} + +template +constexpr complex operator+(const complex& lhs, const T& rhs) { + complex result = lhs; + return result += rhs; +} + +template +constexpr complex operator+(const T& lhs, const complex& rhs) { + return complex(lhs + rhs.real(), rhs.imag()); +} + +template +constexpr complex operator-(const complex& lhs, const complex& rhs) { + complex result = lhs; + return result -= rhs; +} + +template +constexpr complex operator-(const complex& lhs, const T& rhs) { + complex result = lhs; + return result -= rhs; +} + +template +constexpr complex operator-(const T& lhs, const complex& rhs) { + complex result = -rhs; + return result += lhs; +} + +template +constexpr complex operator*(const complex& lhs, const complex& rhs) { + complex result = lhs; + return result *= rhs; +} + +template +constexpr complex operator*(const complex& lhs, const T& rhs) { + complex result = lhs; + return result *= rhs; +} + +template +constexpr complex operator*(const T& lhs, const complex& rhs) { + complex result = rhs; + return result *= lhs; +} + +template +constexpr complex operator/(const complex& lhs, const complex& rhs) { + complex result = lhs; + return result /= rhs; +} + +template +constexpr complex operator/(const complex& lhs, const T& rhs) { + complex result = lhs; + return result /= rhs; +} + +template +constexpr complex operator/(const T& lhs, const complex& rhs) { + complex result(lhs, T()); + return result /= rhs; +} + +// Define operators between integral scalars and c10::complex. std::complex does +// not support this when T is a floating-point number. This is useful because it +// saves a lot of "static_cast" when operate a complex and an integer. This +// makes the code both less verbose and potentially more efficient. +#define COMPLEX_INTEGER_OP_TEMPLATE_CONDITION \ + typename std::enable_if_t< \ + std::is_floating_point_v && std::is_integral_v, \ + int> = 0 + +template +constexpr c10::complex operator+(const c10::complex& a, const iT& b) { + return a + static_cast(b); +} + +template +constexpr c10::complex operator+(const iT& a, const c10::complex& b) { + return static_cast(a) + b; +} + +template +constexpr c10::complex operator-(const c10::complex& a, const iT& b) { + return a - static_cast(b); +} + +template +constexpr c10::complex operator-(const iT& a, const c10::complex& b) { + return static_cast(a) - b; +} + +template +constexpr c10::complex operator*(const c10::complex& a, const iT& b) { + return a * static_cast(b); +} + +template +constexpr c10::complex operator*(const iT& a, const c10::complex& b) { + return static_cast(a) * b; +} + +template +constexpr c10::complex operator/(const c10::complex& a, const iT& b) { + return a / static_cast(b); +} + +template +constexpr c10::complex operator/(const iT& a, const c10::complex& b) { + return static_cast(a) / b; +} + +#undef COMPLEX_INTEGER_OP_TEMPLATE_CONDITION + +template +constexpr bool operator==(const complex& lhs, const complex& rhs) { + return (lhs.real() == rhs.real()) && (lhs.imag() == rhs.imag()); +} + +template +constexpr bool operator==(const complex& lhs, const T& rhs) { + return (lhs.real() == rhs) && (lhs.imag() == T()); +} + +template +constexpr bool operator==(const T& lhs, const complex& rhs) { + return (lhs == rhs.real()) && (T() == rhs.imag()); +} + +template +constexpr bool operator!=(const complex& lhs, const complex& rhs) { + return !(lhs == rhs); +} + +template +constexpr bool operator!=(const complex& lhs, const T& rhs) { + return !(lhs == rhs); +} + +template +constexpr bool operator!=(const T& lhs, const complex& rhs) { + return !(lhs == rhs); +} + +template +std::basic_ostream& operator<<( + std::basic_ostream& os, + const complex& x) { + return (os << static_cast>(x)); +} + +template +std::basic_istream& operator>>( + std::basic_istream& is, + complex& x) { + std::complex tmp; + is >> tmp; + x = tmp; + return is; +} + +} // namespace c10 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // std functions // @@ -70,6 +598,75 @@ constexpr c10::complex conj(const c10::complex& z) { } // namespace std +<<<<<<< HEAD +======= +namespace c10 { + +template +C10_HOST_DEVICE complex polar(const T& r, const T& theta = T()) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>(thrust::polar(r, theta)); +#else + // std::polar() requires r >= 0, so spell out the explicit implementation to + // avoid a branch. + return complex(r * std::cos(theta), r * std::sin(theta)); +#endif +} + +template <> +struct alignas(4) complex { + Half real_; + Half imag_; + + // Constructors + complex() = default; + // Half constructor is not constexpr so the following constructor can't + // be constexpr + C10_HOST_DEVICE explicit inline complex(const Half& real, const Half& imag) + : real_(real), imag_(imag) {} + C10_HOST_DEVICE inline complex(const c10::complex& value) + : real_(value.real()), imag_(value.imag()) {} + + // Conversion operator + inline C10_HOST_DEVICE operator c10::complex() const { + return {real_, imag_}; + } + + constexpr C10_HOST_DEVICE Half real() const { + return real_; + } + constexpr C10_HOST_DEVICE Half imag() const { + return imag_; + } + + C10_HOST_DEVICE complex& operator+=(const complex& other) { + real_ = static_cast(real_) + static_cast(other.real_); + imag_ = static_cast(imag_) + static_cast(other.imag_); + return *this; + } + + C10_HOST_DEVICE complex& operator-=(const complex& other) { + real_ = static_cast(real_) - static_cast(other.real_); + imag_ = static_cast(imag_) - static_cast(other.imag_); + return *this; + } + + C10_HOST_DEVICE complex& operator*=(const complex& other) { + auto a = static_cast(real_); + auto b = static_cast(imag_); + auto c = static_cast(other.real()); + auto d = static_cast(other.imag()); + real_ = a * c - b * d; + imag_ = a * d + b * c; + return *this; + } +}; + +} // namespace c10 + +C10_CLANG_DIAGNOSTIC_POP() + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #define C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H // math functions are included in a separate file #include // IWYU pragma: keep diff --git a/c10/util/floating_point_utils.h b/c10/util/floating_point_utils.h index 10aa67c7cb843..0d17d2bb4031a 100644 --- a/c10/util/floating_point_utils.h +++ b/c10/util/floating_point_utils.h @@ -1 +1,37 @@ +<<<<<<< HEAD #include +======= +#pragma once + +#include +#include +#include + +namespace c10::detail { + +C10_HOST_DEVICE inline float fp32_from_bits(uint32_t w) { +#if defined(__OPENCL_VERSION__) + return as_float(w); +#elif defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + return __uint_as_float((unsigned int)w); +#elif defined(__INTEL_COMPILER) + return _castu32_f32(w); +#else + return c10::bit_cast(w); +#endif +} + +C10_HOST_DEVICE inline uint32_t fp32_to_bits(float f) { +#if defined(__OPENCL_VERSION__) + return as_uint(f); +#elif defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + return (uint32_t)__float_as_uint(f); +#elif defined(__INTEL_COMPILER) + return _castf32_u32(f); +#else + return c10::bit_cast(f); +#endif +} + +} // namespace c10::detail +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/c10/util/generic_math.h b/c10/util/generic_math.h index 493c03cb42e64..44ec87e45d221 100644 --- a/c10/util/generic_math.h +++ b/c10/util/generic_math.h @@ -93,7 +93,11 @@ template < std::enable_if_t, int> = 0> inline C10_HOST_DEVICE scalar_t div_mod(scalar_t a, scalar_t b) { auto mod = a % b; +<<<<<<< HEAD if (mod != 0 && (b < 0) != (mod < 0)) { +======= + if ((b < 0) != (mod < 0)) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mod += b; } return mod; diff --git a/c10/util/int128.h b/c10/util/int128.h index 4bea5a5f1197c..81a04f22920fb 100644 --- a/c10/util/int128.h +++ b/c10/util/int128.h @@ -154,6 +154,7 @@ inline bool operator!=(const uint128& lhs, const uint128& rhs) { return !(lhs == rhs); } +<<<<<<< HEAD inline UINT128_CONSTEXPR uint128::uint128() : lo_(0), hi_(0) {} inline UINT128_CONSTEXPR uint128::uint128(uint64_t top, uint64_t bottom) : lo_(bottom), hi_(top) {} @@ -165,12 +166,29 @@ inline UINT128_CONSTEXPR uint128::uint128(uint64_t bottom) inline UINT128_CONSTEXPR uint128::uint128(uint32_t bottom) : lo_(bottom), hi_(0) {} inline UINT128_CONSTEXPR uint128::uint128(int bottom) +======= +C10_API inline UINT128_CONSTEXPR uint128::uint128() : lo_(0), hi_(0) {} +C10_API inline UINT128_CONSTEXPR uint128::uint128(uint64_t top, uint64_t bottom) + : lo_(bottom), hi_(top) {} +C10_API inline UINT128_CONSTEXPR uint128::uint128(const uint128_pod& v) + : lo_(v.lo), hi_(v.hi) {} +C10_API inline UINT128_CONSTEXPR uint128::uint128(uint64_t bottom) + : lo_(bottom), hi_(0) {} +#ifndef SWIG +C10_API inline UINT128_CONSTEXPR uint128::uint128(uint32_t bottom) + : lo_(bottom), hi_(0) {} +C10_API inline UINT128_CONSTEXPR uint128::uint128(int bottom) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) : lo_(bottom), hi_(static_cast((bottom < 0) ? -1 : 0)) {} #endif #undef UINT128_CONSTEXPR +<<<<<<< HEAD inline void uint128::Initialize(uint64_t top, uint64_t bottom) { +======= +C10_API inline void uint128::Initialize(uint64_t top, uint64_t bottom) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) hi_ = top; lo_ = bottom; } @@ -226,11 +244,19 @@ LOGIC128(^) #undef LOGIC128 +<<<<<<< HEAD #define LOGICASSIGN128(op) \ inline uint128& uint128::operator op(const uint128 & other) { \ hi_ op other.hi_; \ lo_ op other.lo_; \ return *this; \ +======= +#define LOGICASSIGN128(op) \ + C10_API inline uint128& uint128::operator op(const uint128 & other) { \ + hi_ op other.hi_; \ + lo_ op other.lo_; \ + return *this; \ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } LOGICASSIGN128(|=) @@ -295,7 +321,11 @@ inline uint128& operator<<=(uint128& self, int amount) { return self; } +<<<<<<< HEAD inline uint128& uint128::operator>>=(int amount) { +======= +C10_API inline uint128& uint128::operator>>=(int amount) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // uint64_t shifts of >= 64 are undefined, so we will need some // special-casing. if (amount < 64) { @@ -333,7 +363,11 @@ inline uint128 operator%(const uint128& lhs, const uint128& rhs) { return uint128(lhs) %= rhs; } +<<<<<<< HEAD inline uint128& uint128::operator+=(const uint128& b) { +======= +C10_API inline uint128& uint128::operator+=(const uint128& b) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) hi_ += b.hi_; uint64_t lolo = lo_ + b.lo_; if (lolo < lo_) @@ -342,7 +376,11 @@ inline uint128& uint128::operator+=(const uint128& b) { return *this; } +<<<<<<< HEAD inline uint128& uint128::operator-=(const uint128& b) { +======= +C10_API inline uint128& uint128::operator-=(const uint128& b) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) hi_ -= b.hi_; if (b.lo_ > lo_) --hi_; @@ -350,7 +388,11 @@ inline uint128& uint128::operator-=(const uint128& b) { return *this; } +<<<<<<< HEAD inline uint128& uint128::operator*=(const uint128& b) { +======= +C10_API inline uint128& uint128::operator*=(const uint128& b) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uint64_t a96 = hi_ >> 32; uint64_t a64 = hi_ & 0xffffffffu; uint64_t a32 = lo_ >> 32; @@ -373,24 +415,40 @@ inline uint128& uint128::operator*=(const uint128& b) { return *this; } +<<<<<<< HEAD inline uint128 uint128::operator++(int) { +======= +C10_API inline uint128 uint128::operator++(int) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uint128 tmp(*this); *this += 1; return tmp; } +<<<<<<< HEAD inline uint128 uint128::operator--(int) { +======= +C10_API inline uint128 uint128::operator--(int) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uint128 tmp(*this); *this -= 1; return tmp; } +<<<<<<< HEAD inline uint128& uint128::operator++() { +======= +C10_API inline uint128& uint128::operator++() { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) *this += 1; return *this; } +<<<<<<< HEAD inline uint128& uint128::operator--() { +======= +C10_API inline uint128& uint128::operator--() { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) *this -= 1; return *this; } diff --git a/c10/util/irange.h b/c10/util/irange.h index cc52d443ee5f3..1dec16987b8f1 100644 --- a/c10/util/irange.h +++ b/c10/util/irange.h @@ -24,7 +24,11 @@ struct integer_iterator { using pointer = I*; using reference = I&; +<<<<<<< HEAD explicit constexpr integer_iterator(I val) : value(val) {} +======= + explicit constexpr integer_iterator(I value) : value(value) {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) constexpr I operator*() const { return value; diff --git a/c10/util/qint32.h b/c10/util/qint32.h index 2d1f877f98d48..68bf4dc545cc3 100644 --- a/c10/util/qint32.h +++ b/c10/util/qint32.h @@ -1 +1,22 @@ +<<<<<<< HEAD #include +======= +#pragma once +#include + +#include + +namespace c10 { + +/** + * qint32 is for signed 32 bit quantized Tensors + */ +struct alignas(4) qint32 { + using underlying = int32_t; + int32_t val_; + qint32() = default; + C10_HOST_DEVICE explicit qint32(int32_t val) : val_(val) {} +}; + +} // namespace c10 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/c10/util/qint8.h b/c10/util/qint8.h index 6eb25f755c901..aa7dd08743b54 100644 --- a/c10/util/qint8.h +++ b/c10/util/qint8.h @@ -1 +1,24 @@ +<<<<<<< HEAD #include +======= +#pragma once +#include + +#include + +namespace c10 { + +/** + * This is the data type for quantized Tensors. Right now we only have + * qint8 which is for 8 bit Tensors, and qint32 for 32 bit int Tensors, + * we might have 4 bit, 2 bit or 1 bit data types in the future. + */ +struct alignas(1) qint8 { + using underlying = int8_t; + int8_t val_; + qint8() = default; + C10_HOST_DEVICE explicit qint8(int8_t val) : val_(val) {} +}; + +} // namespace c10 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/c10/util/quint2x4.h b/c10/util/quint2x4.h index 67c846159dfd8..9be7cbdd0c84e 100644 --- a/c10/util/quint2x4.h +++ b/c10/util/quint2x4.h @@ -1 +1,23 @@ +<<<<<<< HEAD #include +======= +#pragma once +#include + +#include + +namespace c10 { + +/** + * quint2x4 is for un-signed 2 bit quantized Tensors that are packed to byte + * boundary. + */ +struct alignas(1) quint2x4 { + using underlying = uint8_t; + uint8_t val_; + quint2x4() = default; + C10_HOST_DEVICE explicit quint2x4(uint8_t val) : val_(val) {} +}; + +} // namespace c10 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/c10/util/quint4x2.h b/c10/util/quint4x2.h index c9e06e6131777..ee14e5342102b 100644 --- a/c10/util/quint4x2.h +++ b/c10/util/quint4x2.h @@ -1 +1,23 @@ +<<<<<<< HEAD #include +======= +#pragma once +#include + +#include + +namespace c10 { + +/** + * quint4x2 is for un-signed 4 bit quantized Tensors that are packed to byte + * boundary. + */ +struct alignas(1) quint4x2 { + using underlying = uint8_t; + uint8_t val_; + quint4x2() = default; + C10_HOST_DEVICE explicit quint4x2(uint8_t val) : val_(val) {} +}; + +} // namespace c10 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/c10/util/quint8.h b/c10/util/quint8.h index 4d5719750c627..f9a2ccebecd63 100644 --- a/c10/util/quint8.h +++ b/c10/util/quint8.h @@ -1 +1,22 @@ +<<<<<<< HEAD #include +======= +#pragma once +#include + +#include + +namespace c10 { + +/** + * quint8 is for unsigned 8 bit quantized Tensors + */ +struct alignas(1) quint8 { + using underlying = uint8_t; + uint8_t val_; + quint8() = default; + C10_HOST_DEVICE explicit quint8(uint8_t val) : val_(val) {} +}; + +} // namespace c10 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/c10/util/safe_numerics.h b/c10/util/safe_numerics.h index 32ffca52e4864..842d95e36a027 100644 --- a/c10/util/safe_numerics.h +++ b/c10/util/safe_numerics.h @@ -1,7 +1,10 @@ #pragma once #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include // GCC has __builtin_mul_overflow from before it supported __has_builtin @@ -32,6 +35,7 @@ C10_ALWAYS_INLINE bool add_overflows(uint64_t a, uint64_t b, uint64_t* out) { #endif } +<<<<<<< HEAD template C10_ALWAYS_INLINE bool mul_overflows(T a, T b, T* out) { #if C10_HAS_BUILTIN_OVERFLOW() @@ -62,6 +66,30 @@ C10_ALWAYS_INLINE bool mul_overflows(T a, T b, T* out) { C10_ALWAYS_INLINE bool mul_overflows(uint64_t a, uint64_t b, uint64_t* out) { return mul_overflows(a, b, out); +======= +C10_ALWAYS_INLINE bool mul_overflows(uint64_t a, uint64_t b, uint64_t* out) { +#if C10_HAS_BUILTIN_OVERFLOW() + return __builtin_mul_overflow(a, b, out); +#else + *out = a * b; + // This test isn't exact, but avoids doing integer division + return ( + (c10::llvm::countLeadingZeros(a) + c10::llvm::countLeadingZeros(b)) < 64); +#endif +} + +C10_ALWAYS_INLINE bool mul_overflows(int64_t a, int64_t b, int64_t* out) { +#if C10_HAS_BUILTIN_OVERFLOW() + return __builtin_mul_overflow(a, b, out); +#else + volatile int64_t tmp = a * b; + *out = tmp; + if (a == 0 || b == 0) { + return false; + } + return !(a == tmp / b); +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } template diff --git a/c10/xpu/XPUCachingAllocator.cpp b/c10/xpu/XPUCachingAllocator.cpp index a5e088515ff55..c974e63aa7dd8 100644 --- a/c10/xpu/XPUCachingAllocator.cpp +++ b/c10/xpu/XPUCachingAllocator.cpp @@ -540,7 +540,11 @@ class DeviceCachingAllocator { static void local_raw_delete(void* ptr); +<<<<<<< HEAD class XPUAllocator : public DeviceAllocator { +======= +class XPUAllocator : public Allocator { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) private: std::mutex mutex; ska::flat_hash_map allocated_blocks; @@ -576,10 +580,13 @@ class XPUAllocator : public DeviceAllocator { } } +<<<<<<< HEAD bool initialized() override { return !device_allocators.empty(); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void malloc( void** devPtr, DeviceIndex device, @@ -614,13 +621,21 @@ class XPUAllocator : public DeviceAllocator { } } +<<<<<<< HEAD void emptyCache(MempoolId_t mempool_id [[maybe_unused]] = {0, 0}) override { +======= + void emptyCache() { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (auto& da : device_allocators) { da->emptyCache(); } } +<<<<<<< HEAD void recordStream(const DataPtr& ptr, c10::Stream stream) override { +======= + void recordStream(const DataPtr& ptr, XPUStream stream) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (!ptr.get()) { return; } @@ -630,8 +645,12 @@ class XPUAllocator : public DeviceAllocator { Block* block = get_allocated_block(ptr.get()); TORCH_CHECK(block, "No allocated block can be found."); +<<<<<<< HEAD c10::xpu::XPUStream xpu_stream{stream}; device_allocators[block->device]->recordStream(block, xpu_stream); +======= + device_allocators[block->device]->recordStream(block, stream); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } DataPtr allocate(size_t size) override { @@ -684,17 +703,29 @@ class XPUAllocator : public DeviceAllocator { ": did you call init?"); } +<<<<<<< HEAD DeviceStats getDeviceStats(DeviceIndex device) override { +======= + DeviceStats getDeviceStats(DeviceIndex device) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assertValidDevice(device); return device_allocators[device]->getStats(); } +<<<<<<< HEAD void resetPeakStats(DeviceIndex device) override { +======= + void resetPeakStats(DeviceIndex device) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assertValidDevice(device); device_allocators[device]->resetPeakStats(); } +<<<<<<< HEAD void resetAccumulatedStats(DeviceIndex device) override { +======= + void resetAccumulatedStats(DeviceIndex device) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assertValidDevice(device); device_allocators[device]->resetAccumulatedStats(); } diff --git a/c10/xpu/XPUDeviceProp.h b/c10/xpu/XPUDeviceProp.h index 085c6367477f0..12b0c2330d1a7 100644 --- a/c10/xpu/XPUDeviceProp.h +++ b/c10/xpu/XPUDeviceProp.h @@ -113,6 +113,7 @@ namespace c10::xpu { _(native_vector_width_double) \ _(native_vector_width_half) +<<<<<<< HEAD #define AT_FORALL_XPU_EXT_DEVICE_PROPERTIES(_) \ /* the number of EUs associated with the Intel GPU. */ \ _(gpu_eu_count, gpu_eu_count, 512) \ @@ -131,6 +132,20 @@ namespace c10::xpu { \ /* the device descriptor for device Universal Unique ID, 16 bytes*/ \ _(uuid, device_info_uuid, (std::array{})) +======= +#define AT_FORALL_XPU_EXT_DEVICE_PROPERTIES(_) \ + /* the number of EUs associated with the Intel GPU. */ \ + _(gpu_eu_count, 512) \ + \ + /* the number of EUs in a subslice. */ \ + _(gpu_eu_count_per_subslice, 8) \ + \ + /* the simd width of EU of GPU. */ \ + _(gpu_eu_simd_width, 8) \ + \ + /* the number of hardware threads per EU of GPU. */ \ + _(gpu_hw_threads_per_eu, 8) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #define AT_FORALL_XPU_DEVICE_ASPECT(_) \ /* sycl::half is supported on device. */ \ diff --git a/c10/xpu/XPUFunctions.cpp b/c10/xpu/XPUFunctions.cpp index 6947c078483eb..927fc08d88b5d 100644 --- a/c10/xpu/XPUFunctions.cpp +++ b/c10/xpu/XPUFunctions.cpp @@ -157,9 +157,15 @@ void initDeviceProperties(DeviceProp* device_prop, DeviceIndex device) { #define ASSIGN_DEVICE_PROP(property) \ device_prop->property = raw_device.get_info(); +<<<<<<< HEAD #define ASSIGN_EXT_DEVICE_PROP(property, aspect_tag, default_value) \ device_prop->property = raw_device.has(sycl::aspect::ext_intel_##aspect_tag) \ ? raw_device.get_info() \ +======= +#define ASSIGN_EXT_DEVICE_PROP(property, default_value) \ + device_prop->property = raw_device.has(sycl::aspect::ext_intel_##property) \ + ? raw_device.get_info() \ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) : default_value; #define ASSIGN_DEVICE_ASPECT(member) \ diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 6ab41b6c84793..43b8c70716865 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -267,7 +267,10 @@ if(NOT INTERN_DISABLE_AUTOGRAD AND NOT BUILD_LITE_INTERPRETER) "${TORCH_SRC_DIR}/csrc/autograd/generated/ADInplaceOrViewType_0.cpp" "${TORCH_SRC_DIR}/csrc/autograd/generated/ADInplaceOrViewType_1.cpp" "${TORCH_SRC_DIR}/csrc/inductor/aoti_torch/generated/c_shim_cpu.cpp" +<<<<<<< HEAD "${TORCH_SRC_DIR}/csrc/inductor/aoti_torch/generated/c_shim_aten.cpp" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if(BUILD_LAZY_TS_BACKEND) list(APPEND GENERATED_CXX_TORCH @@ -316,7 +319,10 @@ set(GENERATED_CXX_PYTHON "${TORCH_SRC_DIR}/csrc/autograd/generated/python_special_functions.cpp" "${TORCH_SRC_DIR}/csrc/autograd/generated/python_return_types.cpp" "${TORCH_SRC_DIR}/csrc/autograd/generated/python_enum_tag.cpp" +<<<<<<< HEAD "${TORCH_SRC_DIR}/csrc/functionalization/generated/ViewMetaClassesPythonBinding.cpp" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) set(GENERATED_H_PYTHON @@ -380,9 +386,12 @@ add_custom_command( "${TORCH_ROOT}/aten/src/ATen/templates/LazyIr.h" "${TORCH_ROOT}/aten/src/ATen/templates/LazyNonNativeIr.h" "${TORCH_ROOT}/aten/src/ATen/templates/RegisterDispatchKey.cpp" +<<<<<<< HEAD "${TORCH_ROOT}/aten/src/ATen/templates/ViewMetaClasses.h" "${TORCH_ROOT}/aten/src/ATen/templates/ViewMetaClasses.cpp" "${TORCH_ROOT}/aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ${autograd_python} ${autograd_yaml} ${autograd_templates} @@ -585,7 +594,10 @@ if(USE_CUDA) ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu +<<<<<<< HEAD ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/cuda_mem_pool.cpp +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1" ) endif() @@ -830,6 +842,10 @@ if(USE_MPS) if(CAN_COMPILE_METAL) add_dependencies(torch_cpu metallibs) target_link_options(torch_cpu PRIVATE -Wl,-sectcreate,__TEXT,metal_basic,${CMAKE_CURRENT_BINARY_DIR}/aten/src/ATen/kernels_basic.metallib) +<<<<<<< HEAD +======= + target_link_options(torch_cpu PRIVATE -Wl,-sectcreate,__TEXT,metal_bfloat,${CMAKE_CURRENT_BINARY_DIR}/aten/src/ATen/kernels_bfloat.metallib) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else() target_compile_definitions(torch_cpu PRIVATE PYTORCH_JIT_COMPILE_SHADERS) endif() @@ -992,6 +1008,7 @@ elseif(USE_CUDA) target_compile_definitions(torch_cuda PRIVATE USE_NCCL) endif() +<<<<<<< HEAD # Compile with NVSHMEM # Default value of `USE_NVSHMEM` is set in CMakeLists.txt under root, to ON. if(USE_NVSHMEM) @@ -1028,17 +1045,44 @@ elseif(USE_CUDA) # If NVSHMEM_LIBRARY is found, we build torch_cuda with NVSHMEM support. if(NVSHMEM_HOST_LIB AND NVSHMEM_DEVICE_LIB AND NVSHMEM_INCLUDE_DIR) message(STATUS "NVSHMEM found, building with NVSHMEM support") +======= + # Use env var for these for now for prototyping purposes + set(USE_NVSHMEM $ENV{USE_NVSHMEM} CACHE BOOL "Whether to build with NVSHMEM support") + # If user has specified NVSHMEM_HOME, we use it; + # Otherwise, NVSHMEM_HOME is auto detected in tools/setup_helpers/cmake.py + if($ENV{NVSHMEM_HOME}) + set(NVSHMEM_HOME $ENV{NVSHMEM_HOME} CACHE PATH "Path to NVSHMEM build dir") + endif() + + if(USE_NVSHMEM AND NOT DEFINED NVSHMEM_HOME) + message(WARNING "USE_NVSHMEM set to 1 but NVSHMEM_HOME not found. Please run `pip install nvidia-nvshmem-`, or set NVSHMEM_HOME to the NVSHMEM build dir") + # Disable nvshmem if NVSHMEM_HOME is not found + set(USE_NVSHMEM FALSE CACHE BOOL "Whether to build with NVSHMEM support") + endif() + + if(USE_NVSHMEM) + message("Building with NVSHMEM support: '${NVSHMEM_HOME}'") + set(NVSHMEM_INCLUDE_DIR "${NVSHMEM_HOME}/include") + set(NVSHMEM_LIB_DIR "${NVSHMEM_HOME}/lib") + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) include_directories(${NVSHMEM_INCLUDE_DIR}) # Linking with nvshmem requires the source binary to be built with -rdc # which is not viable for libtorch_cuda. So we isolate the linking of +<<<<<<< HEAD # nvshmem in torch_nvshmem. add_library(torch_nvshmem SHARED +======= + # nvshmem in nvshmem_extension. + add_library(nvshmem_extension SHARED +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "${TORCH_SRC_DIR}/csrc/distributed/c10d/cuda/utils.cpp" "${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu" "${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/NVSHMEMSymmetricMemory.cu" "${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp" ) +<<<<<<< HEAD set_target_properties(torch_nvshmem PROPERTIES CUDA_SEPARABLE_COMPILATION ON) target_compile_options(torch_nvshmem PRIVATE $<$:-rdc=true>) target_compile_options(torch_nvshmem PRIVATE "-U__CUDA_NO_HALF_OPERATORS__") @@ -1052,6 +1096,22 @@ elseif(USE_CUDA) install(TARGETS torch_nvshmem EXPORT Caffe2Targets DESTINATION "${TORCH_INSTALL_LIB_DIR}") else() message(STATUS "NVSHMEM not found, not building with NVSHMEM support.") +======= + set_target_properties(nvshmem_extension PROPERTIES CUDA_SEPARABLE_COMPILATION ON) + target_compile_options(nvshmem_extension PRIVATE $<$:-rdc=true>) + target_compile_options(nvshmem_extension PRIVATE "-U__CUDA_NO_HALF_OPERATORS__") + target_link_directories(nvshmem_extension PRIVATE ${NVSHMEM_LIB_DIR}) + target_link_libraries(nvshmem_extension PRIVATE + # Full path needed bc nvshmem wheel ships with .so.3 instead of .so; + # otherwise, we could just write `nvshmem_host` + ${NVSHMEM_LIB_DIR}/libnvshmem_host.so.3 + nvshmem_device + ) + target_compile_definitions(torch_cuda PUBLIC USE_NVSHMEM) + target_compile_definitions(nvshmem_extension PUBLIC USE_NVSHMEM) + target_link_libraries(torch_cuda PRIVATE nvshmem_extension) + install(TARGETS nvshmem_extension EXPORT Caffe2Targets DESTINATION lib) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) endif() if(USE_UCC) @@ -1067,7 +1127,11 @@ elseif(USE_CUDA) UNFUSE_FMA # Addressing issue #121558 ) target_sources(torch_cuda PRIVATE $) +<<<<<<< HEAD target_include_directories(torch_cuda SYSTEM PUBLIC +======= + target_include_directories(torch_cuda PUBLIC +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) $ $ $ @@ -1097,10 +1161,24 @@ elseif(USE_CUDA) torch_cuda ) if($ENV{ATEN_STATIC_CUDA}) +<<<<<<< HEAD target_link_libraries(torch_cuda_linalg PRIVATE CUDA::cusolver_static ${CUDAToolkit_LIBRARY_DIR}/libcusolver_lapack_static.a # needed for libcusolver_static ) +======= + if(CUDA_VERSION_MAJOR LESS_EQUAL 11) + target_link_libraries(torch_cuda_linalg PRIVATE + CUDA::cusolver_static + ${CUDAToolkit_LIBRARY_DIR}/liblapack_static.a # needed for libcusolver_static + ) + elseif(CUDA_VERSION_MAJOR GREATER_EQUAL 12) + target_link_libraries(torch_cuda_linalg PRIVATE + CUDA::cusolver_static + ${CUDAToolkit_LIBRARY_DIR}/libcusolver_lapack_static.a # needed for libcusolver_static + ) + endif() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else() target_link_libraries(torch_cuda_linalg PRIVATE CUDA::cusolver @@ -1127,11 +1205,14 @@ elseif(USE_CUDA) set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp PROPERTIES COMPILE_FLAGS -Wno-deprecated-declarations) set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp PROPERTIES COMPILE_FLAGS -Wno-deprecated-declarations) endif() +<<<<<<< HEAD # Set driver api defined for PeerToPeerAccess if(NOT WIN32) set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/cuda/PeerToPeerAccess.cpp PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1") endif() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) endif() if(USE_XPU) @@ -1188,11 +1269,27 @@ if(USE_XPU) if(NOT TARGET torch_xpu_ops) message(WARNING "Failed to include ATen XPU implementation target") else() +<<<<<<< HEAD +======= + target_link_libraries(torch_xpu PRIVATE torch_xpu_ops) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # USE_C10D_XCCL to decide if XCCL backend is enabled in torch-xpu-ops build. if(USE_C10D_XCCL) target_compile_definitions(torch_xpu PUBLIC USE_C10D_XCCL) endif() +<<<<<<< HEAD target_link_libraries(torch_xpu PRIVATE $) +======= + if(MSVC) + # Windows + target_link_options(torch_xpu PRIVATE + "-WHOLEARCHIVE:$") + else() + # Linux + target_link_options(torch_xpu PRIVATE + "-Wl,--whole-archive,$,--no-whole-archive") + endif() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Set cached ${ATen_XPU_INCLUDE_DIRS} to torch include_directories(SYSTEM ${ATen_XPU_INCLUDE_DIRS}) @@ -1355,6 +1452,13 @@ if(BUILD_TEST) add_subdirectory(${TORCH_ROOT}/test/cpp/jit ${CMAKE_BINARY_DIR}/test_jit) add_subdirectory(${TORCH_ROOT}/test/cpp/nativert ${CMAKE_BINARY_DIR}/test_nativert) add_subdirectory(${TORCH_ROOT}/test/inductor ${CMAKE_BINARY_DIR}/test_inductor) +<<<<<<< HEAD +======= + add_subdirectory( + ${TORCH_ROOT}/test/cpp/tensorexpr + ${CMAKE_BINARY_DIR}/test_tensorexpr + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if(USE_DISTRIBUTED) add_subdirectory(${TORCH_ROOT}/test/cpp/c10d ${CMAKE_BINARY_DIR}/test_cpp_c10d) if(NOT WIN32) @@ -1452,8 +1556,13 @@ if(USE_ROCM) if(USE_MEM_EFF_ATTENTION) target_compile_definitions(torch_hip PRIVATE USE_MEM_EFF_ATTENTION) endif() +<<<<<<< HEAD if(USE_ROCM_CK_SDPA) target_compile_definitions(torch_hip PRIVATE USE_ROCM_CK_SDPA) +======= + if(USE_CK_FLASH_ATTENTION) + target_compile_definitions(torch_hip PRIVATE USE_CK_FLASH_ATTENTION) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) endif() endif() @@ -1639,11 +1748,15 @@ if(USE_CUDA) endif() target_link_libraries(torch_cuda INTERFACE torch::cudart) target_link_libraries(torch_cuda PUBLIC c10_cuda) +<<<<<<< HEAD if(TARGET torch::nvtx3) target_link_libraries(torch_cuda PRIVATE torch::nvtx3) else() target_link_libraries(torch_cuda PUBLIC torch::nvtoolsext) endif() +======= + target_link_libraries(torch_cuda PRIVATE CUDA::nvtx3) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) target_include_directories( torch_cuda INTERFACE $) @@ -1656,10 +1769,13 @@ if(USE_CUDA) # order of the libraries in the linker call matters here when statically # linking; libculibos and cublas must be last. target_link_libraries(torch_cuda PUBLIC torch_cpu_library ${Caffe2_PUBLIC_CUDA_DEPENDENCY_LIBS}) +<<<<<<< HEAD if(USE_FBGEMM_GENAI) # Link fbgemm_genai to torch_cuda (only for (1) CUDA build for SM100). target_link_libraries(torch_cuda PRIVATE fbgemm_genai) endif() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) endif() # ---[ XPU library. @@ -1739,9 +1855,12 @@ if(BUILD_SHARED_LIBS) if(USE_CUDA) target_link_libraries(torch_global_deps ${Caffe2_PUBLIC_CUDA_DEPENDENCY_LIBS}) target_link_libraries(torch_global_deps torch::cudart) +<<<<<<< HEAD if(TARGET torch::nvtoolsext) target_link_libraries(torch_global_deps torch::nvtoolsext) endif() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) endif() install(TARGETS torch_global_deps DESTINATION "${TORCH_INSTALL_LIB_DIR}") endif() @@ -1780,11 +1899,14 @@ if(USE_ROCM) target_link_libraries(torch_hip PUBLIC torch_cpu_library ${Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS}) target_link_libraries(torch_hip PRIVATE ${Caffe2_HIP_DEPENDENCY_LIBS}) +<<<<<<< HEAD if(USE_FBGEMM_GENAI) if(USE_ROCM) target_link_libraries(torch_hip PRIVATE fbgemm_genai) endif() endif() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Since PyTorch files contain HIP headers, this is also needed to capture the includes. # ROCM_INCLUDE_DIRS is defined in LoadHIP.cmake target_include_directories(torch_hip PRIVATE ${Caffe2_HIP_INCLUDE} ${ROCM_INCLUDE_DIRS}) diff --git a/cmake/Codegen.cmake b/cmake/Codegen.cmake index e4973c849a18f..56749687d9f94 100644 --- a/cmake/Codegen.cmake +++ b/cmake/Codegen.cmake @@ -91,6 +91,7 @@ if(INTERN_BUILD_ATEN_OPS) torch_cuda_get_nvcc_gencode_flag(_existing_arch_flags) set(_file_compile_flags "") +<<<<<<< HEAD foreach(_arch ${archs}) if("${_arch}" STREQUAL "89") if(_existing_arch_flags MATCHES ".*compute_86.*") @@ -113,6 +114,32 @@ if(INTERN_BUILD_ATEN_OPS) endif() endif() endforeach() +======= + if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0) + foreach(_arch ${archs}) + if("${_arch}" STREQUAL "89") + if(_existing_arch_flags MATCHES ".*compute_86.*") + list(APPEND _file_compile_flags "-gencode;arch=compute_89,code=sm_89") + endif() + endif() + if("${_arch}" STREQUAL "90a") + if(_existing_arch_flags MATCHES ".*compute_90.*") + list(APPEND _file_compile_flags "-gencode;arch=compute_90a,code=sm_90a") + endif() + endif() + if("${_arch}" STREQUAL "100a") + if(_existing_arch_flags MATCHES ".*compute_100.*") + list(APPEND _file_compile_flags "-gencode;arch=compute_100a,code=sm_100a") + endif() + endif() + if("${_arch}" STREQUAL "120a") + if(_existing_arch_flags MATCHES ".*compute_120.*") + list(APPEND _file_compile_flags "-gencode;arch=compute_120a,code=sm_120a") + endif() + endif() + endforeach() + endif() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) list(JOIN _file_compile_flags " " _file_compile_flags) set_source_files_properties(${file} PROPERTIES COMPILE_FLAGS "${_file_compile_flags}") @@ -126,7 +153,11 @@ if(INTERN_BUILD_ATEN_OPS) "90a") _BUILD_FOR_ADDITIONAL_ARCHS( "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/GroupMM.cu" +<<<<<<< HEAD "90a;100a") +======= + "90a") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) endif() diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index daceebd8bc889..589ecb2695552 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -153,7 +153,10 @@ set(AT_MKLDNN_ACL_ENABLED 0) set(AT_MKLDNN_ENABLED 0) set(AT_MKL_ENABLED 0) set(AT_KLEIDIAI_ENABLED 0) +<<<<<<< HEAD set(AT_USE_EIGEN_SPARSE 0) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # setting default preferred BLAS options if not already present. if(NOT INTERN_BUILD_MOBILE) set(BLAS "MKL" CACHE STRING "Selected BLAS library") @@ -164,7 +167,10 @@ else() endif() set_property(CACHE BLAS PROPERTY STRINGS "ATLAS;BLIS;Eigen;FLAME;Generic;MKL;OpenBLAS;vecLib;APL") message(STATUS "Trying to find preferred BLAS backend of choice: " ${BLAS}) +<<<<<<< HEAD set(BLAS_CHECK_F2C 0) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if(BLAS STREQUAL "Eigen") # Eigen is header-only and we do not have any dependent libraries @@ -177,7 +183,10 @@ elseif(BLAS STREQUAL "ATLAS") set(BLAS_INFO "atlas") set(BLAS_FOUND 1) set(BLAS_LIBRARIES ${ATLAS_LIBRARIES} cblas) +<<<<<<< HEAD set(BLAS_CHECK_F2C 1) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elseif(BLAS STREQUAL "OpenBLAS") find_package(OpenBLAS REQUIRED) include_directories(SYSTEM ${OpenBLAS_INCLUDE_DIR}) @@ -185,12 +194,18 @@ elseif(BLAS STREQUAL "OpenBLAS") set(BLAS_INFO "open") set(BLAS_FOUND 1) set(BLAS_LIBRARIES ${OpenBLAS_LIB}) +<<<<<<< HEAD set(BLAS_CHECK_F2C 1) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elseif(BLAS STREQUAL "BLIS") find_package(BLIS REQUIRED) include_directories(SYSTEM ${BLIS_INCLUDE_DIR}) list(APPEND Caffe2_DEPENDENCY_LIBS ${BLIS_LIB}) +<<<<<<< HEAD set(BLAS_CHECK_F2C 1) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elseif(BLAS STREQUAL "MKL") if(BLAS_SET_BY_USER) find_package(MKL REQUIRED) @@ -220,7 +235,10 @@ elseif(BLAS STREQUAL "NVPL") set(BLAS_INFO "nvpl") set(BLAS_FOUND 1) set(BLAS_USE_CBLAS_DOT TRUE) +<<<<<<< HEAD set(BLAS_CHECK_F2C 1) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elseif(BLAS STREQUAL "vecLib") find_package(vecLib REQUIRED) include_directories(SYSTEM ${vecLib_INCLUDE_DIR}) @@ -232,14 +250,20 @@ elseif(BLAS STREQUAL "FlexiBLAS") find_package(FlexiBLAS REQUIRED) include_directories(SYSTEM ${FlexiBLAS_INCLUDE_DIR}) list(APPEND Caffe2_DEPENDENCY_LIBS ${FlexiBLAS_LIB}) +<<<<<<< HEAD set(BLAS_CHECK_F2C 1) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elseif(BLAS STREQUAL "APL") find_package(APL REQUIRED) include_directories(SYSTEM ${APL_INCLUDE_DIR}) set(BLAS_INFO "apl") set(BLAS_FOUND 1) set(BLAS_LIBRARIES ${APL_LIBRARIES}) +<<<<<<< HEAD set(BLAS_CHECK_F2C 1) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elseif(BLAS STREQUAL "Generic") # On Debian family, the CBLAS ABIs have been merged into libblas.so if(ENV{GENERIC_BLAS_LIBRARIES} STREQUAL "") @@ -253,11 +277,15 @@ elseif(BLAS STREQUAL "Generic") set(GENERIC_BLAS_FOUND TRUE) set(BLAS_INFO "generic") set(BLAS_FOUND 1) +<<<<<<< HEAD set(BLAS_CHECK_F2C 1) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else() message(FATAL_ERROR "Unrecognized BLAS option: " ${BLAS}) endif() +<<<<<<< HEAD # Determine if blas was compiled with the f2c conventions if(BLAS_LIBRARIES AND BLAS_CHECK_F2C) include(cmake/BLAS_ABI.cmake) @@ -272,6 +300,8 @@ if(USE_EIGEN_SPARSE) set(AT_USE_EIGEN_SPARSE 1) endif() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if(NOT INTERN_BUILD_MOBILE) set(AT_MKL_SEQUENTIAL 0) set(USE_BLAS 1) @@ -586,7 +616,11 @@ elseif(NOT TARGET XNNPACK AND USE_SYSTEM_XNNPACK) find_library(microkernels-prod_LIBRARY microkernels-prod) set_property(TARGET XNNPACK PROPERTY IMPORTED_LOCATION "${XNNPACK_LIBRARY}") set_property(TARGET microkernels-prod PROPERTY IMPORTED_LOCATION "${microkernels-prod_LIBRARY}") +<<<<<<< HEAD if(NOT XNNPACK_LIBRARY OR NOT microkernels-prod_LIBRARY) +======= + if(NOT XNNPACK_LIBRARY or NOT microkernels-prod_LIBRARY) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) message(FATAL_ERROR "Cannot find XNNPACK") endif() message("-- Found XNNPACK: ${XNNPACK_LIBRARY}") @@ -674,20 +708,71 @@ if(USE_FBGEMM) if(NOT DEFINED FBGEMM_SOURCE_DIR) set(FBGEMM_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/fbgemm" CACHE STRING "FBGEMM source directory") endif() +<<<<<<< HEAD if(USE_FBGEMM AND NOT TARGET fbgemm) set(FBGEMM_BUILD_TESTS OFF CACHE BOOL "") set(FBGEMM_BUILD_BENCHMARKS OFF CACHE BOOL "") set(FBGEMM_LIBRARY_TYPE "static" CACHE STRING "") add_subdirectory("${FBGEMM_SOURCE_DIR}") +======= + if(NOT CAFFE2_COMPILER_SUPPORTS_AVX512_EXTENSIONS) + message(WARNING + "A compiler with AVX512 support is required for FBGEMM. " + "Not compiling with FBGEMM. " + "Turn this warning off by USE_FBGEMM=OFF.") + set(USE_FBGEMM OFF) + endif() + if(NOT CMAKE_SIZEOF_VOID_P EQUAL 8) + message(WARNING + "x64 operating system is required for FBGEMM. " + "Not compiling with FBGEMM. " + "Turn this warning off by USE_FBGEMM=OFF.") + set(USE_FBGEMM OFF) + endif() + if(USE_FBGEMM AND NOT TARGET fbgemm) + set(FBGEMM_BUILD_TESTS OFF CACHE BOOL "") + set(FBGEMM_BUILD_BENCHMARKS OFF CACHE BOOL "") + if(MSVC AND BUILD_SHARED_LIBS) + set(FBGEMM_LIBRARY_TYPE "shared" CACHE STRING "") + else() + set(FBGEMM_LIBRARY_TYPE "static" CACHE STRING "") + endif() + if(USE_ASAN) + set(USE_SANITIZER "address,undefined" CACHE STRING "-fsanitize options for FBGEMM") + endif() + add_subdirectory("${FBGEMM_SOURCE_DIR}") + set_property(TARGET fbgemm_generic PROPERTY POSITION_INDEPENDENT_CODE ON) + set_property(TARGET fbgemm_avx2 PROPERTY POSITION_INDEPENDENT_CODE ON) + set_property(TARGET fbgemm_avx512 PROPERTY POSITION_INDEPENDENT_CODE ON) + set_property(TARGET fbgemm PROPERTY POSITION_INDEPENDENT_CODE ON) + + # Disabling autovec in fbgemm due to large library size causing symbol relocation issues, which is only allowed in static builds. + # Long-term solution involves modularizing fbgemm targets. + target_compile_definitions(fbgemm_generic PUBLIC DISABLE_FBGEMM_AUTOVEC) + target_compile_definitions(fbgemm_avx2 PUBLIC DISABLE_FBGEMM_AUTOVEC) + target_compile_definitions(fbgemm_avx512 PUBLIC DISABLE_FBGEMM_AUTOVEC) + + if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 13.0.0) + # See https://github.com/pytorch/pytorch/issues/74352 + target_compile_options_if_supported(asmjit -Wno-deprecated-copy) + target_compile_options_if_supported(asmjit -Wno-unused-but-set-variable) + endif() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") target_compile_options_if_supported(asmjit -Wno-extra-semi) target_compile_options_if_supported(fbgemm -Wno-extra-semi) endif() +<<<<<<< HEAD target_compile_options_if_supported(asmjit -Wno-unused-but-set-variable) target_compile_options_if_supported(asmjit -Wno-unused-variable) endif() if(USE_FBGEMM) +======= + endif() + if(USE_FBGEMM) + target_compile_definitions(fbgemm PUBLIC DISABLE_FBGEMM_AUTOVEC) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) list(APPEND Caffe2_DEPENDENCY_LIBS fbgemm) endif() endif() @@ -696,6 +781,12 @@ if(USE_FBGEMM) caffe2_update_option(USE_FBGEMM ON) else() caffe2_update_option(USE_FBGEMM OFF) +<<<<<<< HEAD +======= + message(WARNING + "Turning USE_FAKELOWP off as it depends on USE_FBGEMM.") + caffe2_update_option(USE_FAKELOWP OFF) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) endif() if(USE_OPENCL) @@ -953,6 +1044,7 @@ endif() # ---[ nvtx if(USE_SYSTEM_NVTX) find_path(nvtx3_dir NAMES nvtx3 PATHS ${CUDA_INCLUDE_DIRS}) +<<<<<<< HEAD else() find_path(nvtx3_dir NAMES nvtx3 PATHS "${PROJECT_SOURCE_DIR}/third_party/NVTX/c/include" NO_DEFAULT_PATH) endif() @@ -965,6 +1057,19 @@ else() message(WARNING "Cannot find NVTX3, find old NVTX instead") add_library(torch::nvtoolsext INTERFACE IMPORTED) set_property(TARGET torch::nvtoolsext PROPERTY INTERFACE_LINK_LIBRARIES CUDA::nvToolsExt) +======= + find_package_handle_standard_args(nvtx3 DEFAULT_MSG nvtx3_dir) + if(NOT nvtx3_FOUND) + message(WARNING "Cannot find system NVTX3, find shipped NVTX3 instead") + endif() +endif() +if(NOT TARGET CUDA::nvtx3) + add_library(CUDA::nvtx3 INTERFACE IMPORTED) +endif() +if(NOT nvtx3_dir) + find_path(nvtx3_dir NAMES nvtx3 PATHS "${PROJECT_SOURCE_DIR}/third_party/NVTX/c/include" NO_DEFAULT_PATH) + target_include_directories(CUDA::nvtx3 INTERFACE "${nvtx3_dir}") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) endif() @@ -1017,9 +1122,12 @@ if(USE_ROCM) if(HIPBLASLT_VEC_EXT) list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_VEC_EXT) endif() +<<<<<<< HEAD if(USE_ROCM_CK_GEMM) list(APPEND HIP_CXX_FLAGS -DUSE_ROCM_CK_GEMM) endif() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) list(APPEND HIP_HIPCC_FLAGS --offload-compress) if(WIN32) add_definitions(-DROCM_ON_WINDOWS) @@ -1134,7 +1242,11 @@ if(USE_UCC) endif() # ---[ CUB +<<<<<<< HEAD if(USE_CUDA AND CUDA_VERSION VERSION_LESS 13.0) +======= +if(USE_CUDA) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) find_package(CUB) if(NOT CUB_FOUND) message(FATAL_ERROR "Cannot find CUB.") @@ -1157,10 +1269,23 @@ if(USE_DISTRIBUTED AND USE_TENSORPIPE) # Tensorpipe uses cuda_add_library torch_update_find_cuda_flags() +<<<<<<< HEAD +======= + if(CMAKE_VERSION VERSION_GREATER_EQUAL "4.0.0") + message(WARNING "Archived TensorPipe forces CMake compatibility mode") + set(CMAKE_POLICY_VERSION_MINIMUM 3.5) + endif() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/tensorpipe) # Suppress warning to unblock libnop compilation by clang-17 # See https://github.com/pytorch/pytorch/issues/151316 target_compile_options_if_supported(tensorpipe -Wno-missing-template-arg-list-after-template-kw) +<<<<<<< HEAD +======= + if(CMAKE_VERSION VERSION_GREATER_EQUAL "4.0.0") + unset(CMAKE_POLICY_VERSION_MINIMUM) + endif() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) list(APPEND Caffe2_DEPENDENCY_LIBS tensorpipe) list(APPEND Caffe2_DEPENDENCY_LIBS nlohmann) @@ -1226,6 +1351,7 @@ if(USE_GLOO) if(NOT Gloo_FOUND) message(FATAL_ERROR "Cannot find gloo") endif() +<<<<<<< HEAD message("Found gloo: ${Gloo_NATIVE_LIBRARY}, cuda lib: ${Gloo_CUDA_LIBRARY}, hip lib: ${Gloo_HIP_LIBRARY}") message("Found gloo include directories: ${Gloo_INCLUDE_DIRS}") add_library(gloo SHARED IMPORTED) @@ -1237,6 +1363,12 @@ if(USE_GLOO) add_library(gloo_hip SHARED IMPORTED) set_target_properties(gloo_hip PROPERTIES IMPORTED_LOCATION ${Gloo_HIP_LIBRARY}) endif() +======= + message("Found gloo: ${Gloo_LIBRARY}") + message("Found gloo include directories: ${Gloo_INCLUDE_DIRS}") + add_library(gloo SHARED IMPORTED) + set_target_properties(gloo PROPERTIES IMPORTED_LOCATION ${Gloo_LIBRARY}) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # need to use Gloo_INCLUDE_DIRS over third_party/gloo to find Gloo's auto-generated config.h include_directories(BEFORE SYSTEM ${Gloo_INCLUDE_DIRS}) endif() @@ -1682,9 +1814,15 @@ if(USE_KINETO) set(CMAKE_REQUIRED_LINK_OPTIONS "") if(NOT EXCEPTIONS_WORK) message(FATAL_ERROR +<<<<<<< HEAD "Detected that statically linking against CUPTI causes exceptions to stop working. " "See https://github.com/pytorch/pytorch/issues/57744 for more details. " "Perhaps try: USE_CUPTI_SO=1 CMAKE_FRESH=1 python -m pip install -e . -v --no-build-isolation") +======= + "Detected that statically linking against CUPTI causes exceptions to stop working. " + "See https://github.com/pytorch/pytorch/issues/57744 for more details. " + "Perhaps try: USE_CUPTI_SO=1 CMAKE_FRESH=1 python setup.py develop") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) endif() endif() diff --git a/cmake/MiscCheck.cmake b/cmake/MiscCheck.cmake index 54126b1f130dc..a502a416c0997 100644 --- a/cmake/MiscCheck.cmake +++ b/cmake/MiscCheck.cmake @@ -2,6 +2,27 @@ include(CheckCXXSourceCompiles) include(CheckCXXCompilerFlag) include(CMakePushCheckState) +<<<<<<< HEAD +======= +# ---[ Check if we want to turn off deprecated warning due to glog. +if(USE_GLOG) + cmake_push_check_state(RESET) + set(CMAKE_REQUIRED_FLAGS "-std=c++17") + CHECK_CXX_SOURCE_COMPILES( + "#include + int main(int argc, char** argv) { + return 0; + }" CAFFE2_NEED_TO_TURN_OFF_DEPRECATION_WARNING + FAIL_REGEX ".*-Wno-deprecated.*") + + if(NOT CAFFE2_NEED_TO_TURN_OFF_DEPRECATION_WARNING AND NOT MSVC) + message(STATUS "Turning off deprecation warning due to glog.") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-deprecated") + endif() + cmake_pop_check_state() +endif() + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # ---[ Check if the compiler has AVX/AVX2 support. We only check AVX2. if(NOT INTERN_BUILD_MOBILE) find_package(AVX) # checks AVX and AVX2 @@ -12,6 +33,49 @@ if(NOT INTERN_BUILD_MOBILE) set(CAFFE2_PERF_WITH_AVX2 1) endif() endif() +<<<<<<< HEAD +======= +# ---[ Check if the compiler has AVX512 support. +cmake_push_check_state(RESET) +if(MSVC AND NOT CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + # We could've used MSVC's hidden option /arch:AVX512 that defines __AVX512F__, + # __AVX512DQ__, and __AVX512VL__, and /arch:AVX512F that defines __AVX512F__. + # But, we chose not to do that not to rely on hidden options. + set(CMAKE_REQUIRED_FLAGS "/D__AVX512F__ /D__AVX512DQ__ /D__AVX512VL__") +else() + # We only consider the case where all of avx512f, avx512dq, and avx512vl are + # supported. + # Platforms where avx512f is supported by not avx512dq and avx512vl as of + # Jan 15 2019 : linux_manywheel_2.7mu_cpu_build and + # linux_conda_3.7_cu100_build + set(CMAKE_REQUIRED_FLAGS "-mavx512f -mavx512dq -mavx512vl") +endif() +CHECK_CXX_SOURCE_COMPILES( + "#if defined(_MSC_VER) + #include + #else + #include + #endif + // check avx512f + __m512 addConstant(__m512 arg) { + return _mm512_add_ps(arg, _mm512_set1_ps(1.f)); + } + // check avx512dq + __m512 andConstant(__m512 arg) { + return _mm512_and_ps(arg, _mm512_set1_ps(1.f)); + } + int main() { + __m512i a = _mm512_set1_epi32(1); + __m256i ymm = _mm512_extracti64x4_epi64(a, 0); + ymm = _mm256_abs_epi64(ymm); // check avx512vl + __mmask16 m = _mm512_cmp_epi32_mask(a, a, _MM_CMPINT_EQ); + __m512i r = _mm512_andnot_si512(a, a); + }" CAFFE2_COMPILER_SUPPORTS_AVX512_EXTENSIONS) +if(CAFFE2_COMPILER_SUPPORTS_AVX512_EXTENSIONS) + message(STATUS "Current compiler supports avx512f extension. Will build fbgemm.") +endif() +cmake_pop_check_state() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # ---[ Checks if compiler supports -fvisibility=hidden check_cxx_compiler_flag("-fvisibility=hidden" COMPILER_SUPPORTS_HIDDEN_VISIBILITY) diff --git a/cmake/Modules/FindBLAS.cmake b/cmake/Modules/FindBLAS.cmake index b4b158fc4965c..9d8baf1eeac1a 100644 --- a/cmake/Modules/FindBLAS.cmake +++ b/cmake/Modules/FindBLAS.cmake @@ -311,8 +311,85 @@ endif() # Determine if blas was compiled with the f2c conventions IF (BLAS_LIBRARIES) +<<<<<<< HEAD include(cmake/BLAS_ABI.cmake) endif(BLAS_LIBRARIES) +======= + # Push host architecture when cross-compiling otherwise check would fail + # when cross-compiling for arm64 on x86_64 + cmake_push_check_state(RESET) + if(CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND CMAKE_OSX_ARCHITECTURES MATCHES "^(x86_64|arm64)$") + list(APPEND CMAKE_REQUIRED_FLAGS "-arch ${CMAKE_HOST_SYSTEM_PROCESSOR}") + endif() + +# Set values through env variables if cross compiling + IF (CMAKE_CROSSCOMPILING) + IF("$ENV{PYTORCH_BLAS_F2C}" STREQUAL "ON") + SET(BLAS_F2C TRUE) + ELSE() + SET(BLAS_F2C FALSE) + ENDIF() + + IF("$ENV{PYTORCH_BLAS_USE_CBLAS_DOT}" STREQUAL "ON") + SET(BLAS_USE_CBLAS_DOT TRUE) + ELSE() + SET(BLAS_USE_CBLAS_DOT FALSE) + ENDIF() + ELSE () + SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES}) + CHECK_C_SOURCE_RUNS(" + #include + #include + float x[4] = { 1, 2, 3, 4 }; + float y[4] = { .1, .01, .001, .0001 }; + int four = 4; + int one = 1; + extern double sdot_(); + int main() { + int i; + double r = sdot_(&four, x, &one, y, &one); + exit((float)r != (float).1234); + }" BLAS_F2C_DOUBLE_WORKS ) + CHECK_C_SOURCE_RUNS(" + #include + #include + float x[4] = { 1, 2, 3, 4 }; + float y[4] = { .1, .01, .001, .0001 }; + int four = 4; + int one = 1; + extern float sdot_(); + int main() { + int i; + double r = sdot_(&four, x, &one, y, &one); + exit((float)r != (float).1234); + }" BLAS_F2C_FLOAT_WORKS ) + IF (BLAS_F2C_DOUBLE_WORKS AND NOT BLAS_F2C_FLOAT_WORKS) + MESSAGE(STATUS "This BLAS uses the F2C return conventions") + SET(BLAS_F2C TRUE) + ELSE (BLAS_F2C_DOUBLE_WORKS AND NOT BLAS_F2C_FLOAT_WORKS) + SET(BLAS_F2C FALSE) + ENDIF(BLAS_F2C_DOUBLE_WORKS AND NOT BLAS_F2C_FLOAT_WORKS) + CHECK_C_SOURCE_RUNS(" + #include + #include + float x[4] = { 1, 2, 3, 4 }; + float y[4] = { .1, .01, .001, .0001 }; + extern float cblas_sdot(); + int main() { + int i; + double r = cblas_sdot(4, x, 1, y, 1); + exit((float)r != (float).1234); + }" BLAS_USE_CBLAS_DOT ) + IF (BLAS_USE_CBLAS_DOT) + SET(BLAS_USE_CBLAS_DOT TRUE) + ELSE (BLAS_USE_CBLAS_DOT) + SET(BLAS_USE_CBLAS_DOT FALSE) + ENDIF(BLAS_USE_CBLAS_DOT) + SET(CMAKE_REQUIRED_LIBRARIES) + ENDIF(CMAKE_CROSSCOMPILING) + cmake_pop_check_state() +ENDIF(BLAS_LIBRARIES) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # epilogue diff --git a/cmake/Modules/FindCUB.cmake b/cmake/Modules/FindCUB.cmake index 1b04d07674961..b13bd1d6b48c3 100644 --- a/cmake/Modules/FindCUB.cmake +++ b/cmake/Modules/FindCUB.cmake @@ -3,7 +3,11 @@ # CUB_INCLUDE_DIRS - the CUB include directory find_path(CUB_INCLUDE_DIR +<<<<<<< HEAD HINTS "${CUDAToolkit_INCLUDE_DIRS}" +======= + HINTS "${CUDA_TOOLKIT_INCLUDE}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) NAMES cub/cub.cuh DOC "The directory where CUB includes reside" ) diff --git a/cmake/Modules/FindGloo.cmake b/cmake/Modules/FindGloo.cmake index 944cd4d8d2573..1f1d407cafe3e 100644 --- a/cmake/Modules/FindGloo.cmake +++ b/cmake/Modules/FindGloo.cmake @@ -1,8 +1,12 @@ # Try to find the Gloo library and headers. # Gloo_FOUND - system has Gloo lib # Gloo_INCLUDE_DIRS - the Gloo include directory +<<<<<<< HEAD # Gloo_NATIVE_LIBRARY - base gloo library, needs to be linked # Gloo_CUDA_LIBRARY/Gloo_HIP_LIBRARY - CUDA/HIP support library in Gloo +======= +# Gloo_LIBRARY/Gloo_NATIVE_LIBRARY - libraries needed to use Gloo +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) find_path(Gloo_INCLUDE_DIR NAMES gloo/common/common.h @@ -11,6 +15,7 @@ find_path(Gloo_INCLUDE_DIR find_library(Gloo_NATIVE_LIBRARY NAMES gloo +<<<<<<< HEAD DOC "The Gloo library" ) @@ -28,15 +33,49 @@ find_library(Gloo_CUDA_LIBRARY find_library(Gloo_HIP_LIBRARY NAMES gloo_hiop DOC "Gloo's HIP support/code" +======= + DOC "The Gloo library (without CUDA)" +) + +find_library(Gloo_CUDA_LIBRARY + NAMES gloo_cuda + DOC "The Gloo library (with CUDA)" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) set(Gloo_INCLUDE_DIRS ${Gloo_INCLUDE_DIR}) +<<<<<<< HEAD +======= +# use the CUDA library depending on the Gloo_USE_CUDA variable +if (DEFINED Gloo_USE_CUDA) + if (${Gloo_USE_CUDA}) + set(Gloo_LIBRARY ${Gloo_CUDA_LIBRARY}) + set(Gloo_NATIVE_LIBRARY ${Gloo_NATIVE_LIBRARY}) + else() + set(Gloo_LIBRARY ${Gloo_NATIVE_LIBRARY}) + set(Gloo_NATIVE_LIBRARY ${Gloo_NATIVE_LIBRARY}) + endif() +else() + # else try to use the CUDA library if found + if (${Gloo_CUDA_LIBRARY} STREQUAL "Gloo_CUDA_LIBRARY-NOTFOUND") + set(Gloo_LIBRARY ${Gloo_NATIVE_LIBRARY}) + set(Gloo_NATIVE_LIBRARY ${Gloo_NATIVE_LIBRARY}) + else() + set(Gloo_LIBRARY ${Gloo_CUDA_LIBRARY}) + set(Gloo_NATIVE_LIBRARY ${Gloo_NATIVE_LIBRARY}) + endif() +endif() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) include(FindPackageHandleStandardArgs) find_package_handle_standard_args(Gloo FOUND_VAR Gloo_FOUND +<<<<<<< HEAD REQUIRED_VARS Gloo_INCLUDE_DIR Gloo_NATIVE_LIBRARY +======= + REQUIRED_VARS Gloo_INCLUDE_DIR Gloo_LIBRARY +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) mark_as_advanced(Gloo_FOUND) diff --git a/cmake/Modules/FindMKLDNN.cmake b/cmake/Modules/FindMKLDNN.cmake index 2018d5ec9370b..7b21873035f20 100644 --- a/cmake/Modules/FindMKLDNN.cmake +++ b/cmake/Modules/FindMKLDNN.cmake @@ -46,8 +46,13 @@ IF(NOT MKLDNN_FOUND) endif() endif() ExternalProject_Add(xpu_mkldnn_proj +<<<<<<< HEAD GIT_REPOSITORY https://github.com/uxlfoundation/oneDNN GIT_TAG v3.9.1 +======= + GIT_REPOSITORY https://github.com/oneapi-src/oneDNN + GIT_TAG v3.8.1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) PREFIX ${XPU_MKLDNN_DIR_PREFIX} BUILD_IN_SOURCE 0 CMAKE_ARGS -DCMAKE_C_COMPILER=icx diff --git a/cmake/Modules_CUDA_fix/FindCUDA.cmake b/cmake/Modules_CUDA_fix/FindCUDA.cmake index 55c4e83012d82..7f9d73d248402 100644 --- a/cmake/Modules_CUDA_fix/FindCUDA.cmake +++ b/cmake/Modules_CUDA_fix/FindCUDA.cmake @@ -7,4 +7,8 @@ set(UPSTREAM_FIND_CUDA_DIR "${CMAKE_CURRENT_LIST_DIR}/upstream/") +<<<<<<< HEAD +======= +include("${UPSTREAM_FIND_CUDA_DIR}/CMakeInitializeConfigs.cmake") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) include("${UPSTREAM_FIND_CUDA_DIR}/FindCUDA.cmake") diff --git a/cmake/Modules_CUDA_fix/upstream/FindCUDA.cmake b/cmake/Modules_CUDA_fix/upstream/FindCUDA.cmake index 411a246656b3b..27b94b27586a6 100644 --- a/cmake/Modules_CUDA_fix/upstream/FindCUDA.cmake +++ b/cmake/Modules_CUDA_fix/upstream/FindCUDA.cmake @@ -414,7 +414,10 @@ # FindCUDA.cmake +<<<<<<< HEAD include(FindPackageHandleStandardArgs) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This macro helps us find the location of helper files we will need the full path to macro(CUDA_FIND_HELPER_FILE _name _extension) set(_full_name "${_name}.${_extension}") @@ -1066,6 +1069,11 @@ set(CUDA_TOOLKIT_TARGET_DIR_INTERNAL "${CUDA_TOOLKIT_TARGET_DIR}" CACHE INTERNAL set(CUDA_SDK_ROOT_DIR_INTERNAL "${CUDA_SDK_ROOT_DIR}" CACHE INTERNAL "This is the value of the last time CUDA_SDK_ROOT_DIR was set successfully." FORCE) +<<<<<<< HEAD +======= +include(${CMAKE_CURRENT_LIST_DIR}/FindPackageHandleStandardArgs.cmake) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) find_package_handle_standard_args(CUDA REQUIRED_VARS CUDA_TOOLKIT_ROOT_DIR diff --git a/cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake b/cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake index bf7edd69ccd13..0c19b49979169 100644 --- a/cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake +++ b/cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake @@ -201,7 +201,11 @@ function(CUDA_SELECT_NVCC_ARCH_FLAGS out_variable) set(add_ptx TRUE) set(arch_name ${CMAKE_MATCH_1}) endif() +<<<<<<< HEAD if(arch_name MATCHES "^([0-9]+\\.[0-9][af]?(\\([0-9]+\\.[0-9]\\))?)$") +======= + if(arch_name MATCHES "^([0-9]+\\.[0-9]a?(\\([0-9]+\\.[0-9]\\))?)$") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) set(arch_bin ${CMAKE_MATCH_1}) set(arch_ptx ${arch_bin}) else() @@ -262,8 +266,13 @@ function(CUDA_SELECT_NVCC_ARCH_FLAGS out_variable) # remove dots and convert to lists string(REGEX REPLACE "\\." "" cuda_arch_bin "${cuda_arch_bin}") string(REGEX REPLACE "\\." "" cuda_arch_ptx "${cuda_arch_ptx}") +<<<<<<< HEAD string(REGEX MATCHALL "[0-9()]+[af]?" cuda_arch_bin "${cuda_arch_bin}") string(REGEX MATCHALL "[0-9]+[af]?" cuda_arch_ptx "${cuda_arch_ptx}") +======= + string(REGEX MATCHALL "[0-9()]+a?" cuda_arch_bin "${cuda_arch_bin}") + string(REGEX MATCHALL "[0-9]+a?" cuda_arch_ptx "${cuda_arch_ptx}") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if(cuda_arch_bin) list(REMOVE_DUPLICATES cuda_arch_bin) diff --git a/cmake/Modules_CUDA_fix/upstream/FindPackageHandleStandardArgs.cmake b/cmake/Modules_CUDA_fix/upstream/FindPackageHandleStandardArgs.cmake new file mode 100644 index 0000000000000..67f6bd6f2bcd1 --- /dev/null +++ b/cmake/Modules_CUDA_fix/upstream/FindPackageHandleStandardArgs.cmake @@ -0,0 +1,386 @@ +# Distributed under the OSI-approved BSD 3-Clause License. See accompanying +# file Copyright.txt or https://cmake.org/licensing for details. + +#[=======================================================================[.rst: +FindPackageHandleStandardArgs +----------------------------- + +This module provides a function intended to be used in :ref:`Find Modules` +implementing :command:`find_package()` calls. It handles the +``REQUIRED``, ``QUIET`` and version-related arguments of ``find_package``. +It also sets the ``_FOUND`` variable. The package is +considered found if all variables listed contain valid results, e.g. +valid filepaths. + +.. command:: find_package_handle_standard_args + + There are two signatures:: + + find_package_handle_standard_args( + (DEFAULT_MSG|) + ... + ) + + find_package_handle_standard_args( + [FOUND_VAR ] + [REQUIRED_VARS ...] + [VERSION_VAR ] + [HANDLE_COMPONENTS] + [CONFIG_MODE] + [FAIL_MESSAGE ] + ) + + The ``_FOUND`` variable will be set to ``TRUE`` if all + the variables ``...`` are valid and any optional + constraints are satisfied, and ``FALSE`` otherwise. A success or + failure message may be displayed based on the results and on + whether the ``REQUIRED`` and/or ``QUIET`` option was given to + the :command:`find_package` call. + + The options are: + + ``(DEFAULT_MSG|)`` + In the simple signature this specifies the failure message. + Use ``DEFAULT_MSG`` to ask for a default message to be computed + (recommended). Not valid in the full signature. + + ``FOUND_VAR `` + Obsolete. Specifies either ``_FOUND`` or + ``_FOUND`` as the result variable. This exists only + for compatibility with older versions of CMake and is now ignored. + Result variables of both names are always set for compatibility. + + ``REQUIRED_VARS ...`` + Specify the variables which are required for this package. + These may be named in the generated failure message asking the + user to set the missing variable values. Therefore these should + typically be cache entries such as ``FOO_LIBRARY`` and not output + variables like ``FOO_LIBRARIES``. + + ``VERSION_VAR `` + Specify the name of a variable that holds the version of the package + that has been found. This version will be checked against the + (potentially) specified required version given to the + :command:`find_package` call, including its ``EXACT`` option. + The default messages include information about the required + version and the version which has been actually found, both + if the version is ok or not. + + ``HANDLE_COMPONENTS`` + Enable handling of package components. In this case, the command + will report which components have been found and which are missing, + and the ``_FOUND`` variable will be set to ``FALSE`` + if any of the required components (i.e. not the ones listed after + the ``OPTIONAL_COMPONENTS`` option of :command:`find_package`) are + missing. + + ``CONFIG_MODE`` + Specify that the calling find module is a wrapper around a + call to ``find_package( NO_MODULE)``. This implies + a ``VERSION_VAR`` value of ``_VERSION``. The command + will automatically check whether the package configuration file + was found. + + ``FAIL_MESSAGE `` + Specify a custom failure message instead of using the default + generated message. Not recommended. + +Example for the simple signature: + +.. code-block:: cmake + + find_package_handle_standard_args(LibXml2 DEFAULT_MSG + LIBXML2_LIBRARY LIBXML2_INCLUDE_DIR) + +The ``LibXml2`` package is considered to be found if both +``LIBXML2_LIBRARY`` and ``LIBXML2_INCLUDE_DIR`` are valid. +Then also ``LibXml2_FOUND`` is set to ``TRUE``. If it is not found +and ``REQUIRED`` was used, it fails with a +:command:`message(FATAL_ERROR)`, independent whether ``QUIET`` was +used or not. If it is found, success will be reported, including +the content of the first ````. On repeated CMake runs, +the same message will not be printed again. + +Example for the full signature: + +.. code-block:: cmake + + find_package_handle_standard_args(LibArchive + REQUIRED_VARS LibArchive_LIBRARY LibArchive_INCLUDE_DIR + VERSION_VAR LibArchive_VERSION) + +In this case, the ``LibArchive`` package is considered to be found if +both ``LibArchive_LIBRARY`` and ``LibArchive_INCLUDE_DIR`` are valid. +Also the version of ``LibArchive`` will be checked by using the version +contained in ``LibArchive_VERSION``. Since no ``FAIL_MESSAGE`` is given, +the default messages will be printed. + +Another example for the full signature: + +.. code-block:: cmake + + find_package(Automoc4 QUIET NO_MODULE HINTS /opt/automoc4) + find_package_handle_standard_args(Automoc4 CONFIG_MODE) + +In this case, a ``FindAutmoc4.cmake`` module wraps a call to +``find_package(Automoc4 NO_MODULE)`` and adds an additional search +directory for ``automoc4``. Then the call to +``find_package_handle_standard_args`` produces a proper success/failure +message. +#]=======================================================================] + +include(${CMAKE_CURRENT_LIST_DIR}/FindPackageMessage.cmake) + +# internal helper macro +macro(_FPHSA_FAILURE_MESSAGE _msg) + if (${_NAME}_FIND_REQUIRED) + message(FATAL_ERROR "${_msg}") + else () + if (NOT ${_NAME}_FIND_QUIETLY) + message(STATUS "${_msg}") + endif () + endif () +endmacro() + + +# internal helper macro to generate the failure message when used in CONFIG_MODE: +macro(_FPHSA_HANDLE_FAILURE_CONFIG_MODE) + # _CONFIG is set, but FOUND is false, this means that some other of the REQUIRED_VARS was not found: + if(${_NAME}_CONFIG) + _FPHSA_FAILURE_MESSAGE("${FPHSA_FAIL_MESSAGE}: missing:${MISSING_VARS} (found ${${_NAME}_CONFIG} ${VERSION_MSG})") + else() + # If _CONSIDERED_CONFIGS is set, the config-file has been found, but no suitable version. + # List them all in the error message: + if(${_NAME}_CONSIDERED_CONFIGS) + set(configsText "") + list(LENGTH ${_NAME}_CONSIDERED_CONFIGS configsCount) + math(EXPR configsCount "${configsCount} - 1") + foreach(currentConfigIndex RANGE ${configsCount}) + list(GET ${_NAME}_CONSIDERED_CONFIGS ${currentConfigIndex} filename) + list(GET ${_NAME}_CONSIDERED_VERSIONS ${currentConfigIndex} version) + string(APPEND configsText " ${filename} (version ${version})\n") + endforeach() + if (${_NAME}_NOT_FOUND_MESSAGE) + string(APPEND configsText " Reason given by package: ${${_NAME}_NOT_FOUND_MESSAGE}\n") + endif() + _FPHSA_FAILURE_MESSAGE("${FPHSA_FAIL_MESSAGE} ${VERSION_MSG}, checked the following files:\n${configsText}") + + else() + # Simple case: No Config-file was found at all: + _FPHSA_FAILURE_MESSAGE("${FPHSA_FAIL_MESSAGE}: found neither ${_NAME}Config.cmake nor ${_NAME_LOWER}-config.cmake ${VERSION_MSG}") + endif() + endif() +endmacro() + + +function(FIND_PACKAGE_HANDLE_STANDARD_ARGS _NAME _FIRST_ARG) + +# Set up the arguments for `cmake_parse_arguments`. + set(options CONFIG_MODE HANDLE_COMPONENTS) + set(oneValueArgs FAIL_MESSAGE VERSION_VAR FOUND_VAR) + set(multiValueArgs REQUIRED_VARS) + +# Check whether we are in 'simple' or 'extended' mode: + set(_KEYWORDS_FOR_EXTENDED_MODE ${options} ${oneValueArgs} ${multiValueArgs} ) + list(FIND _KEYWORDS_FOR_EXTENDED_MODE "${_FIRST_ARG}" INDEX) + + if(${INDEX} EQUAL -1) + set(FPHSA_FAIL_MESSAGE ${_FIRST_ARG}) + set(FPHSA_REQUIRED_VARS ${ARGN}) + set(FPHSA_VERSION_VAR) + else() + cmake_parse_arguments(FPHSA "${options}" "${oneValueArgs}" "${multiValueArgs}" ${_FIRST_ARG} ${ARGN}) + + if(FPHSA_UNPARSED_ARGUMENTS) + message(FATAL_ERROR "Unknown keywords given to FIND_PACKAGE_HANDLE_STANDARD_ARGS(): \"${FPHSA_UNPARSED_ARGUMENTS}\"") + endif() + + if(NOT FPHSA_FAIL_MESSAGE) + set(FPHSA_FAIL_MESSAGE "DEFAULT_MSG") + endif() + + # In config-mode, we rely on the variable _CONFIG, which is set by find_package() + # when it successfully found the config-file, including version checking: + if(FPHSA_CONFIG_MODE) + list(INSERT FPHSA_REQUIRED_VARS 0 ${_NAME}_CONFIG) + list(REMOVE_DUPLICATES FPHSA_REQUIRED_VARS) + set(FPHSA_VERSION_VAR ${_NAME}_VERSION) + endif() + + if(NOT FPHSA_REQUIRED_VARS) + message(FATAL_ERROR "No REQUIRED_VARS specified for FIND_PACKAGE_HANDLE_STANDARD_ARGS()") + endif() + endif() + +# now that we collected all arguments, process them + + if("x${FPHSA_FAIL_MESSAGE}" STREQUAL "xDEFAULT_MSG") + set(FPHSA_FAIL_MESSAGE "Could NOT find ${_NAME}") + endif() + + list(GET FPHSA_REQUIRED_VARS 0 _FIRST_REQUIRED_VAR) + + string(TOUPPER ${_NAME} _NAME_UPPER) + string(TOLOWER ${_NAME} _NAME_LOWER) + + if(FPHSA_FOUND_VAR) + if(FPHSA_FOUND_VAR MATCHES "^${_NAME}_FOUND$" OR FPHSA_FOUND_VAR MATCHES "^${_NAME_UPPER}_FOUND$") + set(_FOUND_VAR ${FPHSA_FOUND_VAR}) + else() + message(FATAL_ERROR "The argument for FOUND_VAR is \"${FPHSA_FOUND_VAR}\", but only \"${_NAME}_FOUND\" and \"${_NAME_UPPER}_FOUND\" are valid names.") + endif() + else() + set(_FOUND_VAR ${_NAME_UPPER}_FOUND) + endif() + + # collect all variables which were not found, so they can be printed, so the + # user knows better what went wrong (#6375) + set(MISSING_VARS "") + set(DETAILS "") + # check if all passed variables are valid + set(FPHSA_FOUND_${_NAME} TRUE) + foreach(_CURRENT_VAR ${FPHSA_REQUIRED_VARS}) + if(NOT ${_CURRENT_VAR}) + set(FPHSA_FOUND_${_NAME} FALSE) + string(APPEND MISSING_VARS " ${_CURRENT_VAR}") + else() + string(APPEND DETAILS "[${${_CURRENT_VAR}}]") + endif() + endforeach() + if(FPHSA_FOUND_${_NAME}) + set(${_NAME}_FOUND TRUE) + set(${_NAME_UPPER}_FOUND TRUE) + else() + set(${_NAME}_FOUND FALSE) + set(${_NAME_UPPER}_FOUND FALSE) + endif() + + # component handling + unset(FOUND_COMPONENTS_MSG) + unset(MISSING_COMPONENTS_MSG) + + if(FPHSA_HANDLE_COMPONENTS) + foreach(comp ${${_NAME}_FIND_COMPONENTS}) + if(${_NAME}_${comp}_FOUND) + + if(NOT DEFINED FOUND_COMPONENTS_MSG) + set(FOUND_COMPONENTS_MSG "found components: ") + endif() + string(APPEND FOUND_COMPONENTS_MSG " ${comp}") + + else() + + if(NOT DEFINED MISSING_COMPONENTS_MSG) + set(MISSING_COMPONENTS_MSG "missing components: ") + endif() + string(APPEND MISSING_COMPONENTS_MSG " ${comp}") + + if(${_NAME}_FIND_REQUIRED_${comp}) + set(${_NAME}_FOUND FALSE) + string(APPEND MISSING_VARS " ${comp}") + endif() + + endif() + endforeach() + set(COMPONENT_MSG "${FOUND_COMPONENTS_MSG} ${MISSING_COMPONENTS_MSG}") + string(APPEND DETAILS "[c${COMPONENT_MSG}]") + endif() + + # version handling: + set(VERSION_MSG "") + set(VERSION_OK TRUE) + + # check with DEFINED here as the requested or found version may be "0" + if (DEFINED ${_NAME}_FIND_VERSION) + if(DEFINED ${FPHSA_VERSION_VAR}) + set(_FOUND_VERSION ${${FPHSA_VERSION_VAR}}) + + if(${_NAME}_FIND_VERSION_EXACT) # exact version required + # count the dots in the version string + string(REGEX REPLACE "[^.]" "" _VERSION_DOTS "${_FOUND_VERSION}") + # add one dot because there is one dot more than there are components + string(LENGTH "${_VERSION_DOTS}." _VERSION_DOTS) + if (_VERSION_DOTS GREATER ${_NAME}_FIND_VERSION_COUNT) + # Because of the C++ implementation of find_package() ${_NAME}_FIND_VERSION_COUNT + # is at most 4 here. Therefore a simple lookup table is used. + if (${_NAME}_FIND_VERSION_COUNT EQUAL 1) + set(_VERSION_REGEX "[^.]*") + elseif (${_NAME}_FIND_VERSION_COUNT EQUAL 2) + set(_VERSION_REGEX "[^.]*\\.[^.]*") + elseif (${_NAME}_FIND_VERSION_COUNT EQUAL 3) + set(_VERSION_REGEX "[^.]*\\.[^.]*\\.[^.]*") + else () + set(_VERSION_REGEX "[^.]*\\.[^.]*\\.[^.]*\\.[^.]*") + endif () + string(REGEX REPLACE "^(${_VERSION_REGEX})\\..*" "\\1" _VERSION_HEAD "${_FOUND_VERSION}") + unset(_VERSION_REGEX) + if (NOT ${_NAME}_FIND_VERSION VERSION_EQUAL _VERSION_HEAD) + set(VERSION_MSG "Found unsuitable version \"${_FOUND_VERSION}\", but required is exact version \"${${_NAME}_FIND_VERSION}\"") + set(VERSION_OK FALSE) + else () + set(VERSION_MSG "(found suitable exact version \"${_FOUND_VERSION}\")") + endif () + unset(_VERSION_HEAD) + else () + if (NOT ${_NAME}_FIND_VERSION VERSION_EQUAL _FOUND_VERSION) + set(VERSION_MSG "Found unsuitable version \"${_FOUND_VERSION}\", but required is exact version \"${${_NAME}_FIND_VERSION}\"") + set(VERSION_OK FALSE) + else () + set(VERSION_MSG "(found suitable exact version \"${_FOUND_VERSION}\")") + endif () + endif () + unset(_VERSION_DOTS) + + else() # minimum version specified: + if (${_NAME}_FIND_VERSION VERSION_GREATER _FOUND_VERSION) + set(VERSION_MSG "Found unsuitable version \"${_FOUND_VERSION}\", but required is at least \"${${_NAME}_FIND_VERSION}\"") + set(VERSION_OK FALSE) + else () + set(VERSION_MSG "(found suitable version \"${_FOUND_VERSION}\", minimum required is \"${${_NAME}_FIND_VERSION}\")") + endif () + endif() + + else() + + # if the package was not found, but a version was given, add that to the output: + if(${_NAME}_FIND_VERSION_EXACT) + set(VERSION_MSG "(Required is exact version \"${${_NAME}_FIND_VERSION}\")") + else() + set(VERSION_MSG "(Required is at least version \"${${_NAME}_FIND_VERSION}\")") + endif() + + endif() + else () + # Check with DEFINED as the found version may be 0. + if(DEFINED ${FPHSA_VERSION_VAR}) + set(VERSION_MSG "(found version \"${${FPHSA_VERSION_VAR}}\")") + endif() + endif () + + if(VERSION_OK) + string(APPEND DETAILS "[v${${FPHSA_VERSION_VAR}}(${${_NAME}_FIND_VERSION})]") + else() + set(${_NAME}_FOUND FALSE) + endif() + + + # print the result: + if (${_NAME}_FOUND) + FIND_PACKAGE_MESSAGE(${_NAME} "Found ${_NAME}: ${${_FIRST_REQUIRED_VAR}} ${VERSION_MSG} ${COMPONENT_MSG}" "${DETAILS}") + else () + + if(FPHSA_CONFIG_MODE) + _FPHSA_HANDLE_FAILURE_CONFIG_MODE() + else() + if(NOT VERSION_OK) + _FPHSA_FAILURE_MESSAGE("${FPHSA_FAIL_MESSAGE}: ${VERSION_MSG} (found ${${_FIRST_REQUIRED_VAR}})") + else() + _FPHSA_FAILURE_MESSAGE("${FPHSA_FAIL_MESSAGE} (missing:${MISSING_VARS}) ${VERSION_MSG}") + endif() + endif() + + endif () + + set(${_NAME}_FOUND ${${_NAME}_FOUND} PARENT_SCOPE) + set(${_NAME_UPPER}_FOUND ${${_NAME}_FOUND} PARENT_SCOPE) +endfunction() diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index 745d9ea058687..2721f34d1ed14 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -70,6 +70,10 @@ function(caffe2_print_configuration_summary) message(STATUS " USE_CPP_CODE_COVERAGE : ${USE_CPP_CODE_COVERAGE}") message(STATUS " USE_CUDA : ${USE_CUDA}") if(${USE_CUDA}) +<<<<<<< HEAD +======= + message(STATUS " Split CUDA : ${BUILD_SPLIT_CUDA}") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) message(STATUS " CUDA static link : ${CAFFE2_STATIC_LINK_CUDA}") message(STATUS " USE_CUDNN : ${USE_CUDNN}") message(STATUS " USE_CUSPARSELT : ${USE_CUSPARSELT}") @@ -127,6 +131,7 @@ function(caffe2_print_configuration_summary) endif() message(STATUS " USE_ROCM : ${USE_ROCM}") if(${USE_ROCM}) +<<<<<<< HEAD message(STATUS " ROCM_VERSION : ${ROCM_VERSION}") message(STATUS " USE_FLASH_ATTENTION : ${USE_FLASH_ATTENTION}") message(STATUS " USE_MEM_EFF_ATTENTION : ${USE_MEM_EFF_ATTENTION}") @@ -137,6 +142,17 @@ function(caffe2_print_configuration_summary) message(STATUS " USE_EIGEN_FOR_BLAS : ${CAFFE2_USE_EIGEN_FOR_BLAS}") message(STATUS " USE_EIGEN_FOR_SPARSE : ${USE_EIGEN_SPARSE}") message(STATUS " USE_FBGEMM : ${USE_FBGEMM}") +======= + message(STATUS " ROCM_VERSION : ${ROCM_VERSION}") + message(STATUS " USE_FLASH_ATTENTION : ${USE_FLASH_ATTENTION}") + message(STATUS " USE_CK_FLASH_ATTENTION : ${USE_CK_FLASH_ATTENTION}") + message(STATUS " USE_MEM_EFF_ATTENTION : ${USE_MEM_EFF_ATTENTION}") + endif() + message(STATUS " BUILD_NVFUSER : ${BUILD_NVFUSER}") + message(STATUS " USE_EIGEN_FOR_BLAS : ${CAFFE2_USE_EIGEN_FOR_BLAS}") + message(STATUS " USE_FBGEMM : ${USE_FBGEMM}") + message(STATUS " USE_FAKELOWP : ${USE_FAKELOWP}") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) message(STATUS " USE_KINETO : ${USE_KINETO}") message(STATUS " USE_GFLAGS : ${USE_GFLAGS}") message(STATUS " USE_GLOG : ${USE_GLOG}") @@ -172,7 +188,10 @@ function(caffe2_print_configuration_summary) if(${USE_NCCL}) message(STATUS " USE_SYSTEM_NCCL : ${USE_SYSTEM_NCCL}") endif() +<<<<<<< HEAD message(STATUS " Found NVSHMEM : ${NVSHMEM_INCLUDE_DIR}") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) message(STATUS " USE_NNPACK : ${USE_NNPACK}") message(STATUS " USE_NUMPY : ${USE_NUMPY}") message(STATUS " USE_OBSERVERS : ${USE_OBSERVERS}") diff --git a/cmake/TorchConfig.cmake.in b/cmake/TorchConfig.cmake.in index 0b32ffa99ceb5..db80ec7796133 100644 --- a/cmake/TorchConfig.cmake.in +++ b/cmake/TorchConfig.cmake.in @@ -132,9 +132,12 @@ if(@USE_CUDA@) else() set(TORCH_CUDA_LIBRARIES ${CUDA_NVRTC_LIB}) endif() +<<<<<<< HEAD if(TARGET torch::nvtoolsext) list(APPEND TORCH_CUDA_LIBRARIES torch::nvtoolsext) endif() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if(@BUILD_SHARED_LIBS@) find_library(C10_CUDA_LIBRARY c10_cuda PATHS "${TORCH_INSTALL_PREFIX}/lib") diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake index 018bca837a5a8..8af6b682a1c4f 100644 --- a/cmake/public/LoadHIP.cmake +++ b/cmake/public/LoadHIP.cmake @@ -6,7 +6,11 @@ set(PYTORCH_FOUND_HIP FALSE) # In the latter case, if /opt/rocm does not exist emit status # message and return. if(DEFINED ENV{ROCM_PATH}) +<<<<<<< HEAD file(TO_CMAKE_PATH "$ENV{ROCM_PATH}" ROCM_PATH) +======= + set(ROCM_PATH $ENV{ROCM_PATH}) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if(NOT EXISTS ${ROCM_PATH}) message(FATAL_ERROR "ROCM_PATH environment variable is set to ${ROCM_PATH} but does not exist.\n" @@ -31,7 +35,11 @@ if(NOT DEFINED ENV{MAGMA_HOME}) set(MAGMA_HOME ${ROCM_PATH}/magma) set(ENV{MAGMA_HOME} ${ROCM_PATH}/magma) else() +<<<<<<< HEAD file(TO_CMAKE_PATH "$ENV{MAGMA_HOME}" MAGMA_HOME) +======= + set(MAGMA_HOME $ENV{MAGMA_HOME}) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) endif() # MIOpen isn't a part of HIP-SDK for Windows and hence, may have a different diff --git a/cmake/public/cuda.cmake b/cmake/public/cuda.cmake index 218c50a69c6fb..852a7e43e6278 100644 --- a/cmake/public/cuda.cmake +++ b/cmake/public/cuda.cmake @@ -69,8 +69,13 @@ endif() message(STATUS "PyTorch: CUDA detected: " ${CUDA_VERSION}) message(STATUS "PyTorch: CUDA nvcc is: " ${CUDA_NVCC_EXECUTABLE}) message(STATUS "PyTorch: CUDA toolkit directory: " ${CUDA_TOOLKIT_ROOT_DIR}) +<<<<<<< HEAD if(CUDA_VERSION VERSION_LESS 12.0) message(FATAL_ERROR "PyTorch requires CUDA 12.0 or above.") +======= +if(CUDA_VERSION VERSION_LESS 11.0) + message(FATAL_ERROR "PyTorch requires CUDA 11.0 or above.") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) endif() if(CUDA_FOUND) @@ -110,7 +115,11 @@ if(CUDA_FOUND) # Force CUDA to be processed for again next time # TODO: I'm not sure if this counts as an implementation detail of # FindCUDA +<<<<<<< HEAD set(cuda_version_from_findcuda ${CUDA_VERSION_STRING}) +======= + set(${cuda_version_from_findcuda} ${CUDA_VERSION_STRING}) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) unset(CUDA_TOOLKIT_ROOT_DIR_INTERNAL CACHE) # Not strictly necessary, but for good luck. unset(CUDA_VERSION CACHE) @@ -282,6 +291,7 @@ endif() # cufft add_library(caffe2::cufft INTERFACE IMPORTED) if(CAFFE2_STATIC_LINK_CUDA AND NOT WIN32) +<<<<<<< HEAD if(CUDA_VERSION VERSION_LESS_EQUAL 12.9) set_property( TARGET caffe2::cufft PROPERTY INTERFACE_LINK_LIBRARIES @@ -291,6 +301,11 @@ if(CAFFE2_STATIC_LINK_CUDA AND NOT WIN32) TARGET caffe2::cufft PROPERTY INTERFACE_LINK_LIBRARIES CUDA::cufft_static) endif() +======= + set_property( + TARGET caffe2::cufft PROPERTY INTERFACE_LINK_LIBRARIES + CUDA::cufft_static_nocallback) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else() set_property( TARGET caffe2::cufft PROPERTY INTERFACE_LINK_LIBRARIES @@ -319,6 +334,7 @@ endif() # setting nvcc arch flags torch_cuda_get_nvcc_gencode_flag(NVCC_FLAGS_EXTRA) # CMake 3.18 adds integrated support for architecture selection, but we can't rely on it +<<<<<<< HEAD if(DEFINED CMAKE_CUDA_ARCHITECTURES) message(WARNING "pytorch is not compatible with `CMAKE_CUDA_ARCHITECTURES` and will ignore its value. " @@ -326,6 +342,9 @@ if(DEFINED CMAKE_CUDA_ARCHITECTURES) set(CMAKE_CUDA_ARCHITECTURES OFF) endif() +======= +set(CMAKE_CUDA_ARCHITECTURES OFF) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) list(APPEND CUDA_NVCC_FLAGS ${NVCC_FLAGS_EXTRA}) message(STATUS "Added CUDA NVCC flags for: ${NVCC_FLAGS_EXTRA}") diff --git a/cmake/public/utils.cmake b/cmake/public/utils.cmake index 68e66bb3fc386..7de0752b63fc1 100644 --- a/cmake/public/utils.cmake +++ b/cmake/public/utils.cmake @@ -163,7 +163,24 @@ macro(caffe2_interface_library SRC DST) # link command for the specific SRC library. if(${__src_target_type} STREQUAL "STATIC_LIBRARY") # In the case of static library, we will need to add whole-static flags. +<<<<<<< HEAD target_link_libraries(${DST} INTERFACE $) +======= + if(APPLE) + target_link_libraries( + ${DST} INTERFACE -Wl,-force_load,\"$\") + elseif(MSVC) + # In MSVC, we will add whole archive in default. + target_link_libraries( + ${DST} INTERFACE "$") + target_link_options( + ${DST} INTERFACE "-WHOLEARCHIVE:$") + else() + # Assume everything else is like gcc + target_link_libraries(${DST} INTERFACE + "-Wl,--whole-archive,\"$\" -Wl,--no-whole-archive") + endif() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Link all interface link libraries of the src target as well. # For static library, we need to explicitly depend on all the libraries # that are the dependent library of the source library. Note that we cannot @@ -362,6 +379,17 @@ function(torch_compile_options libname) # For MS official doc: https://learn.microsoft.com/en-us/cpp/build/reference/zc-preprocessor set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Zc:preprocessor" PARENT_SCOPE) +<<<<<<< HEAD +======= + if(${MSVC_TOOLSET_VERSION} GREATER_EQUAL 143) + # Add /d2implyavx512upperregs- to disable compiler over-aggressive optimization, which caused involeved AVX512 register on AVX2 machine. + # Reference: https://github.com/pytorch/pytorch/issues/145702#issuecomment-2874029459 + target_compile_options(${libname} PUBLIC $<$:/d2implyavx512upperregs->) + endif() + + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) target_compile_options(${libname} PUBLIC $<$: ${MSVC_RUNTIME_LIBRARY_OPTION} @@ -386,7 +414,11 @@ function(torch_compile_options libname) list(APPEND private_compile_options -Wredundant-move) endif() if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") +<<<<<<< HEAD list(APPEND private_compile_options -Wextra-semi -Wmove) +======= + list(APPEND private_compile_options -Wextra-semi -Wno-error=extra-semi -Wmove) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else() list(APPEND private_compile_options # Considered to be flaky. See the discussion at @@ -397,7 +429,10 @@ function(torch_compile_options libname) if(WERROR) list(APPEND private_compile_options -Werror +<<<<<<< HEAD -Werror=ignored-attributes +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -Werror=inconsistent-missing-override -Werror=inconsistent-missing-destructor-override -Werror=pedantic diff --git a/cmake/public/xpu.cmake b/cmake/public/xpu.cmake index b39e31d0ade8a..d78e4a1e3eb83 100644 --- a/cmake/public/xpu.cmake +++ b/cmake/public/xpu.cmake @@ -11,7 +11,10 @@ set(XPU_HOST_CXX_FLAGS) find_package(SYCLToolkit REQUIRED) if(NOT SYCL_FOUND) set(PYTORCH_FOUND_XPU FALSE) +<<<<<<< HEAD # Exit early to avoid populating XPU_HOST_CXX_FLAGS. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return() endif() set(PYTORCH_FOUND_XPU TRUE) @@ -37,8 +40,11 @@ torch_xpu_get_arch_list(XPU_ARCH_FLAGS) # propagate to torch-xpu-ops set(TORCH_XPU_ARCH_LIST ${XPU_ARCH_FLAGS}) +<<<<<<< HEAD # Ensure USE_XPU is enabled. string(APPEND XPU_HOST_CXX_FLAGS " -DUSE_XPU") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) string(APPEND XPU_HOST_CXX_FLAGS " -DSYCL_COMPILER_VERSION=${SYCL_COMPILER_VERSION}") if(DEFINED ENV{XPU_ENABLE_KINETO}) diff --git a/docs/Makefile b/docs/Makefile index 1337a1fc5dc03..998403b3b7df7 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -15,8 +15,17 @@ help: figures: @$(PYCMD) source/scripts/build_activation_images.py +<<<<<<< HEAD @$(PYCMD) source/scripts/build_lr_scheduler_images.py +======= + @$(PYCMD) source/scripts/build_quantization_configs.py + @$(PYCMD) source/scripts/build_lr_scheduler_images.py + +onnx: + @$(PYCMD) source/scripts/onnx/build_onnx_torchscript_supported_aten_op_csv_table.py + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) opset: @$(PYCMD) source/scripts/build_opsets.py diff --git a/docs/cpp/source/conf.py b/docs/cpp/source/conf.py index 10d854c21db4f..34c13466742a3 100644 --- a/docs/cpp/source/conf.py +++ b/docs/cpp/source/conf.py @@ -40,6 +40,7 @@ "sphinx.ext.intersphinx", ] + (["breathe", "exhale"] if run_doxygen else []) +<<<<<<< HEAD intersphinx_mapping = {"pytorch": ("https://docs.pytorch.org/docs/main", None)} # Configure Sphinx warnings and error handling @@ -68,6 +69,9 @@ collections.MutableMapping = MutableMapping except ImportError: pass +======= +intersphinx_mapping = {"pytorch": ("https://pytorch.org/docs/main", None)} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Setup absolute paths for communicating with breathe / exhale where # items are expected / should be trimmed by. @@ -128,6 +132,7 @@ Welcome to the developer reference for the PyTorch C++ API. """ ), +<<<<<<< HEAD ############################################################################ # Duplicate handling and error management. # ############################################################################ @@ -143,6 +148,8 @@ "variable", }, "fullToctreeMaxDepth": 2, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } # Tell sphinx what the primary language being documented is. @@ -216,7 +223,10 @@ # html_theme_options = { "canonical_url": "https://pytorch.org/docs/stable/", +<<<<<<< HEAD "analytics_id": "GTM-T8XT4PS", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "collapse_navigation": False, "logo": {"text": "Home"}, "icon_links": [ diff --git a/docs/source/_static/img/onnx/torch_script_exporter_memory_usage.png b/docs/source/_static/img/onnx/torch_script_exporter_memory_usage.png new file mode 100644 index 0000000000000..b9c81a71ef3c0 Binary files /dev/null and b/docs/source/_static/img/onnx/torch_script_exporter_memory_usage.png differ diff --git a/docs/source/accelerator.md b/docs/source/accelerator.md index ce593a9acf518..c36d094ba17a3 100644 --- a/docs/source/accelerator.md +++ b/docs/source/accelerator.md @@ -25,6 +25,7 @@ synchronize device_index ``` +<<<<<<< HEAD ```{eval-rst} .. automodule:: torch.accelerator.memory @@ -48,3 +49,5 @@ reset_accumulated_memory_stats reset_peak_memory_stats ``` +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/docs/source/backends.md b/docs/source/backends.md index 3e6cdc9697bf0..7a46919de07da 100644 --- a/docs/source/backends.md +++ b/docs/source/backends.md @@ -54,7 +54,11 @@ These backends include: .. attribute:: allow_tf32 A :class:`bool` that controls whether TensorFloat-32 tensor cores may be used in matrix +<<<<<<< HEAD multiplications on Ampere or newer GPUs. allow_tf32 is going to be deprecated. See :ref:`tf32_on_ampere`. +======= + multiplications on Ampere or newer GPUs. See :ref:`tf32_on_ampere`. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ``` ```{eval-rst} @@ -193,7 +197,11 @@ These backends include: .. attribute:: allow_tf32 A :class:`bool` that controls where TensorFloat-32 tensor cores may be used in cuDNN +<<<<<<< HEAD convolutions on Ampere or newer GPUs. allow_tf32 is going to be deprecated. See :ref:`tf32_on_ampere`. +======= + convolutions on Ampere or newer GPUs. See :ref:`tf32_on_ampere`. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ``` ```{eval-rst} @@ -253,6 +261,7 @@ These backends include: ``` +<<<<<<< HEAD ## torch.backends.miopen ```{eval-rst} @@ -266,6 +275,8 @@ These backends include: (https://rocm.docs.amd.com/projects/MIOpen/en/latest/how-to/find-and-immediate.html). ``` +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ## torch.backends.mps ```{eval-rst} diff --git a/docs/source/community/persons_of_interest.rst b/docs/source/community/persons_of_interest.rst index d66cf86b4444d..18e7f07689fca 100644 --- a/docs/source/community/persons_of_interest.rst +++ b/docs/source/community/persons_of_interest.rst @@ -131,12 +131,19 @@ Distributed - Ke Wen (`kwen2501 `__) - Chien-Chin Huang (`fegin `__) - Tristan Rice (`d4l3k `__) +<<<<<<< HEAD - Junjie Wang (`fduwjj `__) - Wei Feng (`weifengpy `__) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - (emeritus) Shen Li (`mrshenli `__) - (emeritus) Pritam Damania (`pritamdamania87 `__) - (emeritus) Yanli Zhao (`zhaojuanmao `__) - (emeritus) Rohan Varma (`rohan-varma `__) +<<<<<<< HEAD +======= +- (emeritus) Junjie Wang (`fduwjj `__) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - (emeritus) Alisson Azzolini (`aazzolini `__) - (emeritus) James Reed (`jamesr66a `__) - (emeritus) Kiuk Chung (`kiukchung `__) @@ -350,9 +357,15 @@ XLA TorchServe ~~~~~~~~~~ +<<<<<<< HEAD - (emeritus) Li Ning (`lxning `__) - (emeritus) Ankith Gunapal (`agunapal `__) - (emeritus) Hamid Shojanazeri (`HamidShojanazeri `__) +======= +- Li Ning (`lxning `__) +- Ankith Gunapal (`agunapal `__) +- Hamid Shojanazeri (`HamidShojanazeri `__) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - (emeritus) Mark Saroufim (`msaroufIm `__) - (emeritus) Manoj Rao (`mycpuorg `__) - (emeritus) Vamshi Dantu (`vdantu `__) diff --git a/docs/source/conf.py b/docs/source/conf.py index 17bdb33721be8..861c23599ebde 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -62,7 +62,11 @@ "sphinxcontrib.katex", "sphinx_copybutton", "sphinx_design", +<<<<<<< HEAD "myst_nb", +======= + "myst_parser", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "sphinx.ext.linkcode", "sphinxcontrib.mermaid", "sphinx_sitemap", @@ -133,7 +137,11 @@ html_theme_options = { "logo": {"text": "Home"}, "analytics_id": "GTM-T8XT4PS", +<<<<<<< HEAD "canonical_url": "https://docs.pytorch.org/docs/stable/", +======= + "canonical_url": "https://pytorch.org/docs/stable/", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "switcher": { "json_url": "https://docs.pytorch.org/docs/pytorch-versions.json", "version_match": switcher_version, @@ -143,7 +151,11 @@ "external_links": [ { "name": "Tutorials", +<<<<<<< HEAD "url": "https://docs.pytorch.org/tutorials/", +======= + "url": "https://pytorch.org/tutorials/", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, ], "show_version_warning_banner": True, @@ -181,6 +193,10 @@ theme_variables = pytorch_sphinx_theme2.get_theme_variables() html_context = { +<<<<<<< HEAD +======= + "theme_variables": theme_variables, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "github_url": "https://github.com", "github_user": "pytorch", "github_repo": "pytorch", @@ -188,7 +204,11 @@ "github_version": "main", "pytorch_project": "docs", "doc_path": "docs/source", +<<<<<<< HEAD "theme_variables": theme_variables, +======= + "theme_variables": theme_variables, # noqa: F601 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # library links are defined in # pytorch_sphinx_theme2/pytorch_sphinx_theme2/links.json "library_links": theme_variables.get("library_links", []), @@ -263,6 +283,11 @@ "flags_frozen", # torch.distributed.algorithms.ddp_comm_hooks "register_ddp_comm_hook", +<<<<<<< HEAD +======= + # torch.nn + "factory_kwargs", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # torch.nn.parallel "DistributedDataParallelCPU", # torch.utils @@ -736,6 +761,30 @@ "probs_to_logits", "tril_matrix_to_vec", "vec_to_tril_matrix", +<<<<<<< HEAD +======= + # torch.functional + "align_tensors", + "atleast_1d", + "atleast_2d", + "atleast_3d", + "block_diag", + "broadcast_shapes", + "broadcast_tensors", + "cartesian_prod", + "cdist", + "chain_matmul", + "einsum", + "lu", + "meshgrid", + "norm", + "split", + "stft", + "tensordot", + "unique", + "unique_consecutive", + "unravel_index", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # torch.fx.annotate "annotate", # torch.fx.experimental.accelerator_partitioner @@ -1062,6 +1111,10 @@ "z3op", "z3str", # torch.fx.graph_module +<<<<<<< HEAD +======= + "reduce_deploy_graph_module", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "reduce_graph_module", "reduce_package_graph_module", # torch.fx.node @@ -1223,6 +1276,7 @@ # torch.multiprocessing.spawn "start_processes", # torch.nn.functional +<<<<<<< HEAD "adaptive_max_pool1d_with_indices", # documented as adaptive_max_pool1d "adaptive_max_pool2d_with_indices", # documented as adaptive_max_pool2d "adaptive_max_pool3d_with_indices", # documented as adaptive_max_pool3d @@ -1254,6 +1308,39 @@ "xavier_uniform", # deprecated # torch.nn.modules.rnn "apply_permutation", # deprecated +======= + "adaptive_max_pool1d_with_indices", + "adaptive_max_pool2d_with_indices", + "adaptive_max_pool3d_with_indices", + "assert_int_or_pair", + "fractional_max_pool2d_with_indices", + "fractional_max_pool3d_with_indices", + "max_pool1d_with_indices", + "max_pool2d_with_indices", + "max_pool3d_with_indices", + "multi_head_attention_forward", + # torch.nn.grad + "conv1d_input", + "conv1d_weight", + "conv2d_input", + "conv2d_weight", + "conv3d_input", + "conv3d_weight", + # torch.nn.init + "constant", + "dirac", + "eye", + "kaiming_normal", + "kaiming_uniform", + "normal", + "orthogonal", + "sparse", + "uniform", + "xavier_normal", + "xavier_uniform", + # torch.nn.modules.rnn + "apply_permutation", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # torch.nn.modules.utils "consume_prefix_in_state_dict_if_present", # torch.nn.parallel.comm @@ -1275,8 +1362,39 @@ "is_namedtuple", "scatter", "scatter_kwargs", +<<<<<<< HEAD # torch.nn.utils.rnn "bind", # looks unintentionally public +======= + # torch.nn.parameter + "is_lazy", + # torch.nn.utils.convert_parameters + "parameters_to_vector", + "vector_to_parameters", + # torch.nn.utils.fusion + "fuse_conv_bn_eval", + "fuse_conv_bn_weights", + "fuse_linear_bn_eval", + "fuse_linear_bn_weights", + # torch.nn.utils.init + "skip_init", + # torch.nn.utils.memory_format + "convert_conv2d_weight_memory_format", + # torch.nn.utils.parametrizations + "weight_norm", + # torch.nn.utils.parametrize + "transfer_parametrizations_and_params", + "type_before_parametrizations", + # torch.nn.utils.rnn + "bind", + "invert_permutation", + # torch.nn.utils.spectral_norm + "remove_spectral_norm", + "spectral_norm", + # torch.nn.utils.weight_norm + "remove_weight_norm", + "weight_norm", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # torch.onnx.operators "reshape_from_tensor_shape", "shape_as_tensor", @@ -1790,9 +1908,43 @@ "check_export_model_diff", "verify", "verify_aten_graph", +<<<<<<< HEAD # torch.optim.optimizer "register_optimizer_step_post_hook", "register_optimizer_step_pre_hook", +======= + # torch.optim.adadelta + "adadelta", + # torch.optim.adagrad + "adagrad", + # torch.optim.adam + "adam", + # torch.optim.adamax + "adamax", + # torch.optim.adamw + "adamw", + # torch.optim.asgd + "asgd", + # torch.optim.nadam + "nadam", + # torch.optim.optimizer + "register_optimizer_step_post_hook", + "register_optimizer_step_pre_hook", + # torch.optim.radam + "radam", + # torch.optim.rmsprop + "rmsprop", + # torch.optim.rprop + "rprop", + # torch.optim.sgd + "sgd", + # torch.optim.swa_utils + "get_ema_avg_fn", + "get_ema_multi_avg_fn", + "get_swa_avg_fn", + "get_swa_multi_avg_fn", + "update_bn", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # torch.overrides "enable_reentrant_dispatch", # torch.package.analyze.find_first_use_of_broken_modules @@ -2514,8 +2666,11 @@ # torch.distributed.checkpoint.hf_storage "HuggingFaceStorageReader", "HuggingFaceStorageWriter", +<<<<<<< HEAD # torch.distributed.checkpoint.quantized_hf_storage "QuantizedHuggingFaceStorageReader", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # torch.distributed.checkpoint.metadata "BytesStorageMetadata", "ChunkStorageMetadata", @@ -2652,8 +2807,11 @@ "ExpRelaxedCategorical", # torch.distributions.utils "lazy_property", +<<<<<<< HEAD # torch.export.unflatten "UnflattenedModule", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # torch.export.exported_program "ConstantArgument", "ExportedProgram", @@ -2863,6 +3021,7 @@ # torch.nn.cpp "ModuleWrapper", "OrderedDictWrapper", +<<<<<<< HEAD # torch.nn.modules.container "Container", # deprecated # torch.nn.modules.linear @@ -2875,6 +3034,153 @@ "NLLLoss2d", # deprecated # torch.nn.modules.normalization "CrossMapLRN2d", +======= + # torch.nn.modules.activation + "CELU", + "ELU", + "GELU", + "GLU", + "Hardshrink", + "Hardsigmoid", + "Hardswish", + "Hardtanh", + "LeakyReLU", + "LogSigmoid", + "LogSoftmax", + "Mish", + "MultiheadAttention", + "PReLU", + "RReLU", + "ReLU", + "ReLU6", + "SELU", + "SiLU", + "Sigmoid", + "Softmax", + "Softmax2d", + "Softmin", + "Softplus", + "Softshrink", + "Softsign", + "Tanh", + "Tanhshrink", + "Threshold", + # torch.nn.modules.adaptive + "AdaptiveLogSoftmaxWithLoss", + # torch.nn.modules.batchnorm + "SyncBatchNorm", + # torch.nn.modules.channelshuffle + "ChannelShuffle", + # torch.nn.modules.container + "Container", + "ModuleList", + "ParameterList", + "Sequential", + # torch.nn.modules.conv + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", + # torch.nn.modules.distance + "CosineSimilarity", + "PairwiseDistance", + # torch.nn.modules.dropout + "AlphaDropout", + "Dropout", + "Dropout1d", + "Dropout2d", + "Dropout3d", + "FeatureAlphaDropout", + # torch.nn.modules.flatten + "Flatten", + "Unflatten", + # torch.nn.modules.fold + "Fold", + "Unfold", + # torch.nn.modules.linear + "Bilinear", + "Identity", + "LazyLinear", + "Linear", + "NonDynamicallyQuantizableLinear", + # torch.nn.modules.loss + "BCELoss", + "BCEWithLogitsLoss", + "CTCLoss", + "CosineEmbeddingLoss", + "CrossEntropyLoss", + "GaussianNLLLoss", + "HingeEmbeddingLoss", + "HuberLoss", + "KLDivLoss", + "L1Loss", + "MSELoss", + "MarginRankingLoss", + "MultiLabelMarginLoss", + "MultiLabelSoftMarginLoss", + "MultiMarginLoss", + "NLLLoss", + "NLLLoss2d", + "PoissonNLLLoss", + "SmoothL1Loss", + "SoftMarginLoss", + "TripletMarginLoss", + "TripletMarginWithDistanceLoss", + # torch.nn.modules.module + "Module", + # torch.nn.modules.normalization + "CrossMapLRN2d", + "GroupNorm", + "LayerNorm", + "LocalResponseNorm", + # torch.nn.modules.padding + "CircularPad1d", + "CircularPad2d", + "CircularPad3d", + "ZeroPad1d", + "ZeroPad2d", + "ZeroPad3d", + # torch.nn.modules.pixelshuffle + "PixelShuffle", + "PixelUnshuffle", + # torch.nn.modules.pooling + "AdaptiveAvgPool1d", + "AdaptiveAvgPool2d", + "AdaptiveAvgPool3d", + "AdaptiveMaxPool1d", + "AdaptiveMaxPool2d", + "AdaptiveMaxPool3d", + "AvgPool1d", + "AvgPool2d", + "AvgPool3d", + "FractionalMaxPool2d", + "FractionalMaxPool3d", + "LPPool1d", + "LPPool2d", + "LPPool3d", + "MaxPool1d", + "MaxPool2d", + "MaxPool3d", + "MaxUnpool1d", + "MaxUnpool2d", + "MaxUnpool3d", + # torch.nn.modules.rnn + "GRU", + "GRUCell", + "LSTM", + "LSTMCell", + "RNN", + "RNNBase", + "RNNCell", + "RNNCellBase", + # torch.nn.modules.sparse + "Embedding", + "EmbeddingBag", + # torch.nn.modules.upsampling + "Upsample", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # torch.nn.parallel.data_parallel "DataParallel", # torch.nn.parallel.distributed @@ -2905,8 +3211,59 @@ # torch.onnx.verification "OnnxBackend", "OnnxTestCaseRepro", +<<<<<<< HEAD + # torch.optim.optimizer + "Optimizer", +======= + # torch.optim.adadelta + "Adadelta", + # torch.optim.adagrad + "Adagrad", + # torch.optim.adam + "Adam", + # torch.optim.adamax + "Adamax", + # torch.optim.adamw + "AdamW", + # torch.optim.asgd + "ASGD", + # torch.optim.lbfgs + "LBFGS", + # torch.optim.lr_scheduler + "ChainedScheduler", + "ConstantLR", + "CosineAnnealingLR", + "CosineAnnealingWarmRestarts", + "CyclicLR", + "ExponentialLR", + "LRScheduler", + "LambdaLR", + "LinearLR", + "MultiStepLR", + "MultiplicativeLR", + "OneCycleLR", + "PolynomialLR", + "ReduceLROnPlateau", + "SequentialLR", + "StepLR", + # torch.optim.nadam + "NAdam", # torch.optim.optimizer "Optimizer", + # torch.optim.radam + "RAdam", + # torch.optim.rmsprop + "RMSprop", + # torch.optim.rprop + "Rprop", + # torch.optim.sgd + "SGD", + # torch.optim.sparse_adam + "SparseAdam", + # torch.optim.swa_utils + "AveragedModel", + "SWALR", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # torch.overrides "BaseTorchFunctionMode", "TorchFunctionMode", @@ -3310,8 +3667,11 @@ def linkcode_resolve(domain, info): "https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/katex.min.css", ] +<<<<<<< HEAD html_js_files = ["js/runllm-widget.js"] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from sphinx.ext.coverage import CoverageBuilder diff --git a/docs/source/cpp_index.rst b/docs/source/cpp_index.rst index 37571b9c60bc2..feb55d46037c1 100644 --- a/docs/source/cpp_index.rst +++ b/docs/source/cpp_index.rst @@ -7,6 +7,23 @@ C++ PyTorch provides several features for working with C++, and it’s best to choose from them based on your needs. At a high level, the following support is available: +<<<<<<< HEAD +======= +TorchScript C++ API +-------------------- +`TorchScript `__ allows PyTorch models defined in Python to be serialized and then loaded and run in C++ capturing the model code via compilation or tracing its execution. You can learn more in the `Loading a TorchScript Model in C++ tutorial `__. This means you can define your models in Python as much as possible, but subsequently export them via TorchScript for doing no-Python execution in production or embedded environments. The TorchScript C++ API is used to interact with these models and the TorchScript execution engine, including: + +* Loading serialized TorchScript models saved from Python +* Doing simple model modifications if needed (e.g. pulling out submodules) +* Constructing the input and doing preprocessing using C++ Tensor API + +Extending PyTorch and TorchScript with C++ Extensions +------------------------------------------------------ +TorchScript can be augmented with user-supplied code through custom operators and custom classes. +Once registered with TorchScript, these operators and classes can be invoked in TorchScript code run from +Python or from C++ as part of a serialized TorchScript model. The `Extending TorchScript with Custom C++ Operators `__ tutorial walks through interfacing TorchScript with OpenCV. In addition to wrapping a function call with a custom operator, C++ classes and structs can be bound into TorchScript through a pybind11-like interface which is explained in the `Extending TorchScript with Custom C++ Classes `__ tutorial. + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Tensor and Autograd in C++ --------------------------- Most of the tensor and autograd operations in PyTorch Python API are also available in the C++ API. These include: @@ -17,7 +34,13 @@ Most of the tensor and autograd operations in PyTorch Python API are also availa Authoring Models in C++ ------------------------ +<<<<<<< HEAD We provide the full capability of authoring and training a neural net model purely in C++, with familiar components such as ``torch::nn`` / ``torch::nn::functional`` / ``torch::optim`` that closely resemble the Python API. +======= +The "author in TorchScript, infer in C++" workflow requires model authoring to be done in TorchScript. +However, there might be cases where the model has to be authored in C++ (e.g. in workflows where a Python +component is undesirable). To serve such use cases, we provide the full capability of authoring and training a neural net model purely in C++, with familiar components such as ``torch::nn`` / ``torch::nn::functional`` / ``torch::optim`` that closely resemble the Python API. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) * For an overview of the PyTorch C++ model authoring and training API, please see: https://pytorch.org/cppdocs/frontend.html * For a detailed tutorial on how to use the API, please see: https://pytorch.org/tutorials/advanced/cpp_frontend.html diff --git a/docs/source/cuda.md b/docs/source/cuda.md index 24830cacdd4f6..13c9c7d3c957d 100644 --- a/docs/source/cuda.md +++ b/docs/source/cuda.md @@ -269,6 +269,13 @@ See the docs for {class}`~torch.cuda.gds.GdsFile` for an example of how to use t ``` ```{eval-rst} +<<<<<<< HEAD +======= +.. py:module:: torch.cuda.error +``` + +```{eval-rst} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .. py:module:: torch.cuda.gds ``` diff --git a/docs/source/deploy.md b/docs/source/deploy.md new file mode 100644 index 0000000000000..ef5131717bf7b --- /dev/null +++ b/docs/source/deploy.md @@ -0,0 +1,8 @@ +--- +orphan: true +--- + +# torch::deploy has been moved to pytorch/multipy + + +``torch::deploy`` has been moved to its new home at [https://github.com/pytorch/multipy](https://github.com/pytorch/multipy). diff --git a/docs/source/distributed.checkpoint.md b/docs/source/distributed.checkpoint.md index c733ffef18d97..f13bf2decbd13 100644 --- a/docs/source/distributed.checkpoint.md +++ b/docs/source/distributed.checkpoint.md @@ -36,11 +36,14 @@ The entrypoints to load and save a checkpoint are the following: ``` ```{eval-rst} +<<<<<<< HEAD .. autoclass:: torch.distributed.checkpoint.state_dict_saver.AsyncSaveResponse :members: ``` ```{eval-rst} +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .. autofunction:: save ``` @@ -76,6 +79,7 @@ The following module is also useful for additional customization of the staging ``` ```{eval-rst} +<<<<<<< HEAD .. autoclass:: torch.distributed.checkpoint.staging.DefaultStager :members: ``` @@ -86,6 +90,8 @@ The following module is also useful for additional customization of the staging ``` ```{eval-rst} +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .. autoclass:: torch.distributed.checkpoint.staging.BlockingAsyncStager :members: ``` @@ -173,9 +179,12 @@ We also provide other storage layers, including ones to interact with HuggingFac .. autoclass:: torch.distributed.checkpoint.HuggingFaceStorageWriter :members: +<<<<<<< HEAD .. autoclass:: torch.distributed.checkpoint.QuantizedHuggingFaceStorageReader :members: +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) We provide default implementations of `LoadPlanner` and `SavePlanner` that can handle all of torch.distributed constructs such as FSDP, DDP, ShardedTensor and DistributedTensor. diff --git a/docs/source/distributed.elastic.md b/docs/source/distributed.elastic.md index 1c7177dd4a9a0..5b5db578e207a 100644 --- a/docs/source/distributed.elastic.md +++ b/docs/source/distributed.elastic.md @@ -29,7 +29,10 @@ elastic/metrics elastic/events elastic/subprocess_handler elastic/control_plane +<<<<<<< HEAD elastic/numa +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ``` ```{toctree} diff --git a/docs/source/distributed.md b/docs/source/distributed.md index 1a5f8d2b6f3fd..a324c33cbc65c 100644 --- a/docs/source/distributed.md +++ b/docs/source/distributed.md @@ -20,6 +20,7 @@ for a brief introduction to all features related to distributed training. ## Backends +<<<<<<< HEAD `torch.distributed` supports four built-in backends, each with different capabilities. The table below shows which functions are available for use with a CPU or GPU for each backend. For NCCL, GPU refers to CUDA GPU @@ -55,6 +56,41 @@ MPI supports CUDA only if the implementation used to build PyTorch supports it. +----------------+-----+-----+-----+-----+-----+-----+-----+-----+ | barrier | āœ“ | ✘ | āœ“ | ? | ✘ | āœ“ | ✘ | āœ“ | +----------------+-----+-----+-----+-----+-----+-----+-----+-----+ +======= +`torch.distributed` supports three built-in backends, each with +different capabilities. The table below shows which functions are available +for use with CPU / CUDA tensors. +MPI supports CUDA only if the implementation used to build PyTorch supports it. + +```{eval-rst} ++----------------+-----------+-----------+-----------+ +| Backend | ``gloo`` | ``mpi`` | ``nccl`` | ++----------------+-----+-----+-----+-----+-----+-----+ +| Device | CPU | GPU | CPU | GPU | CPU | GPU | ++================+=====+=====+=====+=====+=====+=====+ +| send | āœ“ | ✘ | āœ“ | ? | ✘ | āœ“ | ++----------------+-----+-----+-----+-----+-----+-----+ +| recv | āœ“ | ✘ | āœ“ | ? | ✘ | āœ“ | ++----------------+-----+-----+-----+-----+-----+-----+ +| broadcast | āœ“ | āœ“ | āœ“ | ? | ✘ | āœ“ | ++----------------+-----+-----+-----+-----+-----+-----+ +| all_reduce | āœ“ | āœ“ | āœ“ | ? | ✘ | āœ“ | ++----------------+-----+-----+-----+-----+-----+-----+ +| reduce | āœ“ | āœ“ | āœ“ | ? | ✘ | āœ“ | ++----------------+-----+-----+-----+-----+-----+-----+ +| all_gather | āœ“ | āœ“ | āœ“ | ? | ✘ | āœ“ | ++----------------+-----+-----+-----+-----+-----+-----+ +| gather | āœ“ | āœ“ | āœ“ | ? | ✘ | āœ“ | ++----------------+-----+-----+-----+-----+-----+-----+ +| scatter | āœ“ | āœ“ | āœ“ | ? | ✘ | āœ“ | ++----------------+-----+-----+-----+-----+-----+-----+ +| reduce_scatter | āœ“ | āœ“ | ✘ | ✘ | ✘ | āœ“ | ++----------------+-----+-----+-----+-----+-----+-----+ +| all_to_all | āœ“ | āœ“ | āœ“ | ? | ✘ | āœ“ | ++----------------+-----+-----+-----+-----+-----+-----+ +| barrier | āœ“ | ✘ | āœ“ | ? | ✘ | āœ“ | ++----------------+-----+-----+-----+-----+-----+-----+ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ``` ### Backends that come with PyTorch @@ -83,9 +119,14 @@ In the past, we were often asked: "which backend should I use?". - Rule of thumb +<<<<<<< HEAD - Use the NCCL backend for distributed training with CUDA **GPU**. - Use the XCCL backend for distributed training with XPU **GPU**. - Use the Gloo backend for distributed training with **CPU**. +======= + - Use the NCCL backend for distributed **GPU** training + - Use the Gloo backend for distributed **CPU** training. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - GPU hosts with InfiniBand interconnect @@ -1140,10 +1181,13 @@ If you are running single node training, it may be convenient to interactively b ``` ```{eval-rst} +<<<<<<< HEAD .. py:module:: torch.distributed.checkpoint.quantized_hf_storage ``` ```{eval-rst} +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .. py:module:: torch.distributed.checkpoint.metadata ``` @@ -1482,9 +1526,12 @@ If you are running single node training, it may be convenient to interactively b ```{eval-rst} .. py:module:: torch.distributed.checkpoint.state_dict ``` +<<<<<<< HEAD ```{toctree} :hidden: distributed._dist2 ``` +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/docs/source/distributed.pipelining.md b/docs/source/distributed.pipelining.md index 9d8b6998aae43..9db6d553646fd 100644 --- a/docs/source/distributed.pipelining.md +++ b/docs/source/distributed.pipelining.md @@ -505,10 +505,13 @@ The following set of APIs transform your model into a pipeline representation. ``` ```{eval-rst} +<<<<<<< HEAD .. autoclass:: ScheduleDualPipeV ``` ```{eval-rst} +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .. autoclass:: PipelineScheduleSingle :members: ``` diff --git a/docs/source/distributed.tensor.md b/docs/source/distributed.tensor.md index cb12eb195c02c..776910420ac47 100644 --- a/docs/source/distributed.tensor.md +++ b/docs/source/distributed.tensor.md @@ -179,6 +179,7 @@ specifying the {class}`DeviceMesh` and {class}`Placement` for the {class}`DTenso ``` +<<<<<<< HEAD ### Random Operations DTensor provides distributed RNG functionality to ensure that random operations on sharded tensors get unique values, and random operations on replicated tensors get the same values. This system requires that all participating @@ -191,6 +192,8 @@ When using DTensor together with Pipeline Parallelism, ranks for each pipeline s DTensor's RNG infra is based on the philox based RNG algorithm, and supports any philox based backend (cuda, and other cuda-like devices), but unfortunately does not yet support the CPU backend. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ## Debugging ```{eval-rst} diff --git a/docs/source/draft_export.md b/docs/source/draft_export.md new file mode 100644 index 0000000000000..cc7247d3b526d --- /dev/null +++ b/docs/source/draft_export.md @@ -0,0 +1,262 @@ +(draft-export)= + +# Draft Export + +:::{warning} +This feature is not meant to be used in production and is designed to be +used as a tool for debugging torch.export tracing errors. +::: + +Draft-export is a new version of export, which is designed to consistently +produce a graph, even if there are potential soundness issues, and to generate a +report listing out all of the issues export encountered during +tracing and providing additional debugging information. For custom operators that +don't have fake kernels, it will also generate a profile which you can register +to automatically generate a fake kernel. + +Have you ever tried to export a model using {func}`torch.export.export`, only to +encounter a data-dependent issue? You fix it, but then run into a missing fake +kernel problem. And after resolving that, you get hit with another +data-dependent issue. You wonder to yourself, I wish there was a way I could +just get a graph to play around with, and be able to view all the issues in one +place so that I can fix them later… + +`draft_export` to the rescue! + +`draft_export` is a version of export which will always successfully export a +graph, even if there are potential soundness issues. These issues will then be +compiled into a report for clearer visualization, which can be fixed later on. + +## What sort of errors does it catch? + +Draft-export helps to catch and debug the following errors: + +- Guard on data-dependent errors +- Constraint violation errors +- Missing fake kernels +- Incorrectly written fake kernels + +## How does it work? + +In normal export, we will convert the sample inputs into FakeTensors and use +them to record operations and trace the program into a graph. Input tensor +shapes that can change (which are marked through `dynamic_shapes`), or values +within tensors (typically from an `.item()` call) will be represented as a symbolic +shape (`SymInt`) instead of a concrete integer. However some issues may occur +while tracing - we may run into guards that we cannot evaluate, like if we want +to check if some item in a tensor is greater than 0 (`u0 >= 0`). Since the tracer +doesn't know anything about the value of `u0`, it will throw a data-dependent +error. If the model uses a custom operator but a fake kernel hasn't been +defined for it, then we will error with `fake_tensor.UnsupportedOperatorException` +because export doesn't know how to apply this on `FakeTensors`. If a custom +operator has a fake kernel implemented incorrectly, export will silently produce +an incorrect graph that doesn't match the eager behavior. + +To fix the above errors, draft-export uses *real tensor tracing* to guide us on +how to proceed when tracing. As we trace the model with fake tensors, for every +operation that happens on a fake tensor, draft-export will also run the operator +on stored real tensors which come from the example inputs passed to export. This +allows us to address the above errors: When we reach a guard that we cannot +evaluate, like `u0 >= 0`, we will use the stored real tensor values to +evaluate this guard. Runtime asserts will be added into the graph to ensure that +the graph asserts the same guard that we assumed while tracing. If we run into +a custom operator without a fake kernel, we will run the operator's normal +kernel with the stored real tensors, and return a fake tensor with the same rank +but unbacked shapes. Since we have the real tensor output for every operation, +we will compare this with the fake tensor output from the fake kernel. If the +fake kernel is implemented incorrectly, we will then catch this behavior and +generate a more correct fake kernel. + +## How can I use draft export? + +Let's say you're trying to export this piece of code: + +```python +class M(torch.nn.Module): + def forward(self, x, y, z): + res = torch.ops.mylib.foo2(x, y) + + a = res.item() + a = -a + a = a // 3 + a = a + 5 + + z = torch.cat([z, z]) + + torch._check_is_size(a) + torch._check(a < z.shape[0]) + + return z[:a] + +inp = (torch.tensor(3), torch.tensor(4), torch.ones(3, 3)) + +ep = torch.export.export(M(), inp) +``` + +This runs into a ā€œmissing fake kernelā€ error for `mylib.foo2` and then a +`GuardOnDataDependentExpression` because of the slicing of `z` with `a`, +an unbacked symint. + +To call `draft-export`, we can replace the `torch.export` line with the following: + +```python +ep = torch.export.draft_export(M(), inp) +``` + +`ep` is a valid ExportedProgram which can now be passed through further environments! + +## Debugging with draft-export + +In the terminal output from draft-export, you should see the following message: + +``` +######################################################################################### +WARNING: 2 issue(s) found during export, and it was not able to soundly produce a graph. +To view the report of failures in an html page, please run the command: + `tlparse /tmp/export_angelayi/dedicated_log_torch_trace_axpofwe2.log --export` +Or, you can view the errors in python by inspecting `print(ep._report)`. +######################################################################################## +``` + +Draft-export automatically dumps logs for `tlparse`. You can view the tracing +errors by using `print(ep._report)`, or you can pass the logs into `tlparse` +to generate an html report. + +Running the `tlparse` command in the terminal will generate a +[tlparse](https://github.com/pytorch/tlparse) +HTML report. Here is an example of the `tlparse` report: + +```{image} _static/img/export/draft_export_report.png +``` + +Clicking into the Data Dependent Error, we will see the following page which +contains information to help debug this error. Specifically, it contains: + +- The stacktrace at which this error occurs +- A list of local variables and their shapes +- Information for how this guard was created + +```{image} _static/img/export/draft_export_report_dde.png +``` + +## The returned Exported Program + +Because draft-export specializes on code paths based on the example inputs, the +exported program resulting from draft-export is guaranteed to be runnable and +return correct results for **at least** the given example inputs. Other inputs can +work, as long as they match the same guards that were taken when we were +draft-exporting. + +For example, if we have a graph branching on if a value is greater than 5, if in +draft-export our example inputs were greater than 5, then the returned +`ExportedProgram` will specialize on that branch, and will assert that the value +is greater than 5. This means that the program will succeed if you pass in +another value greater than 5, but will fail if you pass in a value less than 5. +This is more sound than `torch.jit.trace`, which will silently specialize on the +branch. The proper way for `torch.export` to support both branches would be to +rewrite the code using `torch.cond`, which will then capture both branches. + +Because of the runtime assertions in the graph, the returned exported-program is +also retraceable with `torch.export` or `torch.compile`, with a minor addition in +the case where a custom operator is missing a fake kernel. + +## Generating Fake Kernels + +If a custom operator does not contain a fake implementation, currently +draft-export will use the real-tensor propagation to get an output for the +operator and continue tracing. However, if we run the exported program with fake +tensors or retrace the exported model, we will still fail because there is still +no fake kernel implementation. + +To address this, after draft-export, we will generate an operator profile for +each custom operator call that we encounter, and store this on the report +attached to the exported program: `ep._report.op_profiles`. Users can then use the +context manager `torch._library.fake_profile.unsafe_generate_fake_kernels` to +generate and register a fake implementation based on these operator profiles. +This way future fake tensor retracing will work. + +The workflow would look something like: + +```python +class M(torch.nn.Module): + def forward(self, a, b): + res = torch.ops.mylib.foo(a, b) # no fake impl + return res + +ep = draft_export(M(), (torch.ones(3, 4), torch.ones(3, 4))) + +with torch._library.fake_profile.unsafe_generate_fake_kernels(ep._report.op_profiles): + decomp = ep.run_decompositions() + +new_inp = ( + torch.ones(2, 3, 4), + torch.ones(2, 3, 4), +) + +# Save the profile to a yaml and check it into a codebase +save_op_profiles(ep._report.op_profiles, "op_profile.yaml") +# Load the yaml +loaded_op_profile = load_op_profiles("op_profile.yaml") +``` + +The operator profile is a dictionary mapping operator name to a set of profiles +which describe the input and outputs of the operator, and could be manually +written, saved into a yaml file, and checked into a codebase. Here's an example +of a profile for `mylib.foo.default`: + +```python +"mylib.foo.default": { + OpProfile( + args_profile=( + TensorMetadata( + rank=2, + dtype=torch.float32, + device=torch.device("cpu"), + layout=torch.strided, + ), + TensorMetadata( + rank=2, + dtype=torch.float32, + device=torch.device("cpu"), + layout=torch.strided, + ), + ), + out_profile=TensorMetadata( + rank=2, + dtype=torch.float32, + device=torch.device("cpu"), + layout=torch.strided, + ), + ) +} +``` + +`mylib.foo.default`'s profile contains only one profile, which says that for 2 +input tensors of rank 2, dtype `torch.float32`, device `cpu`, we will return +one tensor of rank 2, dtype `torch.float32`, and device `cpu`. Using the +context manager, will then generate a fake kernel where given 2 input tensors of +rank 2 (and the other tensor metadata), we will output one tensor of rank 2 (and +the other tensor metadata). + +If the operator also supports other input ranks, then we can add the profile to +this list of profiles, either by manually adding it into the existing profile or +rerunning draft-export with new inputs to get new profiles, so that the +generated fake kernel will support more input types. Otherwise it will error. + +## Where to go from here? + +Now that we have successfully created an `ExportedProgram` using draft-export, +we can use further compilers such as `AOTInductor` to optimize its performance +and produce a runnable artifact. This optimized version can then be used for +deployment. In parallel, we can utilize the report generated by draft-export to +identify and fix `torch.export` errors that were encountered so that the +original model can be directly traceable with `torch.export`. + +```{toctree} +:caption: Additional Links +:maxdepth: 1 + +torch.compiler_fake_tensor +torch.compiler_dynamic_shapes +torch.compiler_aot_inductor +``` diff --git a/docs/source/export.ir_spec.md b/docs/source/export.ir_spec.md new file mode 100644 index 0000000000000..355539ecfcc94 --- /dev/null +++ b/docs/source/export.ir_spec.md @@ -0,0 +1,487 @@ +(export.ir_spec)= + +# torch.export IR Specification + +Export IR is an intermediate representation (IR) for compilers, which bears +similarities to MLIR and TorchScript. It is specifically designed to express the +semantics of PyTorch programs. Export IR primarily represents computation in a +streamlined list of operations, with limited support for dynamism such as +control flows. + +To create an Export IR graph, a frontend can be used that soundly captures a +PyTorch program via a trace-specializing mechanism. The resulting Export IR can +then be optimized and executed by a backend. This can be done today through +{func}`torch.export.export`. + +The key concepts that will be covered in this document include: + +- ExportedProgram: the data structure containing the Export IR program +- Graph: which consists of a list of nodes. +- Nodes: which represents operations, control flow, and metadata stored on this node. +- Values are produced and consumed by nodes. +- Types are associated with values and nodes. +- The size and memory layout of values are also defined. + +## Assumptions + +This doc assumes that the audience is sufficiently familiar with PyTorch, +specifically with {class}`torch.fx` and its related toolings. Thus it will stop +describing contents present in {class}`torch.fx` documentation and paper. + +## What is Export IR + +Export IR is a graph-based intermediate representation IR of PyTorch programs. +Export IR is realized on top of {class}`torch.fx.Graph`. In other words, **all +Export IR graphs are also valid FX graphs**, and if interpreted using standard +FX semantics, Export IR can be interpreted soundly. One implication is that an +exported graph can be converted to a valid Python program via standard FX +codegen. + +This documentation will primarily focus on highlighting areas where Export IR +differs from FX in terms of its strictness, while skipping parts where it shares +similarities with FX. + +## ExportedProgram + +The top-level Export IR construct is an {class}`torch.export.ExportedProgram` +class. It bundles the computational graph of a PyTorch model (which is usually a +{class}`torch.nn.Module`) with the parameters or weights that this model +consumes. + +Some notable attributes of the {class}`torch.export.ExportedProgram` class are: + +- `graph_module` ({class}`torch.fx.GraphModule`): Data structure containing + the flattened computational graph of the PyTorch model. The graph can be + directly accessed through `ExportedProgram.graph`. +- `graph_signature` ({class}`torch.export.ExportGraphSignature`): The graph + signature, which specifies the parameters and buffer names used and mutated + within the graph. Instead of storing parameters and buffers as attributes of + the graph, they are lifted as inputs to the graph. The graph_signature is + utilized to keep track of additional information on these parameters and + buffers. +- `state_dict` (`Dict[str, Union[torch.Tensor, torch.nn.Parameter]]`): Data + structure containing the parameters and buffers. +- `range_constraints` (`Dict[sympy.Symbol, RangeConstraint]`): For programs + that are exported with data dependent behavior, the metadata on each node will + contain symbolic shapes (which look like `s0`, `i0`). This attribute maps + the symbolic shapes to their lower/upper ranges. + +## Graph + +An Export IR Graph is a PyTorch program represented in the form of a DAG +(directed acyclic graph). Each node in this graph represents a particular +computation or operation, and edges of this graph consist of references between +nodes. + +We can view Graph having this schema: + +```python +class Graph: + nodes: List[Node] +``` + +In practice, Export IR's graph is realized as {class}`torch.fx.Graph` Python class. + +An Export IR graph contains the following nodes (Nodes will be described in more +details in the next section): + +- 0 or more nodes of op type `placeholder` +- 0 or more nodes of op type `call_function` +- exactly 1 node of op type `output` + +**Collorary:** The smallest valid Graph will be of one node. i.e. nodes is never empty. + +**Definition:** +The set of `placeholder` nodes of a Graph represents the **inputs** of the +Graph of GraphModule. The `output` node of a Graph represents the **outputs** +of the Graph of GraphModule. + +Example: + +```python +import torch +from torch import nn + +class MyModule(nn.Module): + + def forward(self, x, y): + return x + y + +example_args = (torch.randn(1), torch.randn(1)) +mod = torch.export.export(MyModule(), example_args) +print(mod.graph) +``` + +```python +graph(): + %x : [num_users=1] = placeholder[target=x] + %y : [num_users=1] = placeholder[target=y] + %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %y), kwargs = {}) + return (add,) +``` + +The above is the textual representation of a Graph, with each line being a node. + +## Node + +A Node represents a particular computation or operation and is represented in +Python using the {class}`torch.fx.Node` class. Edges between nodes are +represented as direct references to other nodes via the `args` property of the +Node class. Using the same FX machinery, we can represent the following +operations that a computational graph typically needs, such as operator calls, +placeholders (aka inputs), conditionals, and loops. + +The Node has the following schema: + +```python +class Node: + name: str # name of node + op_name: str # type of operation + + # interpretation of the fields below depends on op_name + target: [str|Callable] + args: List[object] + kwargs: Dict[str, object] + meta: Dict[str, object] +``` + +**FX Text Format** + +As in the example above, notice that each line has this format: + +``` +%:[...] = [target=](args = (%arg1, %arg2, arg3, arg4, …)), kwargs = {"keyword": arg5}) +``` + +This format captures everything present in the Node class, with the exception of +`meta`, in a compact format. + +Concretely: + +- **** is the name of the node as it would appear in `node.name`. +- **** is the `node.op` field, which must be one of these: + ``, ``, + ``, or ``. +- **** is the target of the node as `node.target`. The meaning of this + field depends on `op_name`. +- **args1, … args 4…** are what is listed in the `node.args` tuple. If a + value in the list is an {class}`torch.fx.Node`, then it will be especially + indicated with a leading **%.** + +For example, a call to the add operator would appear as: + +``` +%add1 = call_function[target = torch.op.aten.add.Tensor](args = (%x, %y), kwargs = {}) +``` + +Where `%x`, `%y` are two other Nodes that have names x and y. Worth noting +that the string `torch.op.aten.add.Tensor` represents the callable object that +is actually stored in the target field, not merely its string name. + +The final line of this text format is: + +``` +return [add] +``` + +which is a Node with `op_name = output`, indicating that we are returning this +one element. + +### call_function + +A `call_function` node represents a call to an operator. + +**Definitions** + +- **Functional:** We say a callable is ā€œfunctionalā€ if it satisfies all the + following requirements: + + - Non-mutating: The operator does not mutate the value of its input (for + tensors, this includes both metadata and data). + - No side effects: The operator does not mutate states that are visible + from outside, like changing values of module parameters. + +- **Operator:** is a functional callable with a predefined schema. Examples of + such operators include functional ATen operators. + +**Representation in FX** + +``` +%name = call_function[target = operator](args = (%x, %y, …), kwargs = {}) +``` + +**Differences from vanilla FX call_function** + +1. In FX graph, a call_function can refer to any callable, in Export IR, we + restrict it to only a select subset of ATen operators, custom operators, and + control flow operators. +2. In Export IR, constant arguments will be embedded within the graph. +3. In FX graph, a get_attr node can represent reading any attribute stored in + the graph module. However, in Export IR this is restricted to reading only + submodules as all parameters/buffers will be passed in as inputs to the graph + module. + +#### Metadata + +`Node.meta` is a dict attached to every FX node. However, the FX spec does not +specify what metadata can or will be there. Export IR provides a stronger +contract, specifically all `call_function` nodes will guarantee having and +only having the following metadata fields: + +- `node.meta["stack_trace"]` is a string containing the Python stack trace + referencing the original Python source code. An example stack trace looks + like: + + ``` + File "my_module.py", line 19, in forward + return x + dummy_helper(y) + File "helper_utility.py", line 89, in dummy_helper + return y + 1 + ``` + +- `node.meta["val"]` describes the output of running the operation. It can be + of type ``, ``, a + `List[Union[FakeTensor, SymInt]]`, or `None`. + +- `node.meta["nn_module_stack"]` describes the "stacktrace" of the + {class}`torch.nn.Module` from which the node came, if it was from a + {class}`torch.nn.Module` call. For example, if a node containing the `addmm` + op called from a {class}`torch.nn.Linear` module inside of a + {class}`torch.nn.Sequential` module, the `nn_module_stack` would look + something like: + + ``` + {'self_linear': ('self.linear', ), 'self_sequential': ('self.sequential', )} + ``` + +- `node.meta["source_fn_stack"]` contains the torch function or the leaf + {class}`torch.nn.Module` class this node was called from before decomposition. + For example, a node containing the `addmm` op from a + {class}`torch.nn.Linear` module call would contain {class}`torch.nn.Linear` in + their `source_fn`, and a node containing the `addmm` op from a + {class}`torch.nn.functional.Linear` module call would contain + {class}`torch.nn.functional.Linear` in their `source_fn`. + +### placeholder + +Placeholder represents an input to a graph. Its semantics are exactly the same as in FX. +Placeholder nodes must be the first N nodes in the nodes list of a graph. N can be zero. + +**Representation in FX** + +```python +%name = placeholder[target = name](args = ()) +``` + +The target field is a string which is the name of input. + +`args`, if non-empty, should be of size 1 representing the default value of this input. + +**Metadata** + +Placeholder nodes also have `meta[ā€˜val’]`, like `call_function` nodes. The +`val` field in this case represents the input shape/dtype that the graph is +expected to receive for this input parameter. + +### output + +An output call represents a return statement in a function; it thus terminates the +current graph. There is one and only one output node, and it will always be the +last node of the graph. + +**Representation in FX** + +``` +output[](args = (%something, …)) +``` + +This has the exact semantics as in {class}`torch.fx`. `args` represents the node +to be returned. + +**Metadata** + +Output node has the same metadata as `call_function` nodes. + +### get_attr + +`get_attr` nodes represent reading a submodule from the encapsulating +{class}`torch.fx.GraphModule`. Unlike a vanilla FX graph from +{func}`torch.fx.symbolic_trace` in which `get_attr` nodes are used to read +attributes such as parameters and buffers from the top-level +{class}`torch.fx.GraphModule`, parameters and buffers are passed in as +inputs to the graph module, and stored in the top-level +{class}`torch.export.ExportedProgram`. + +**Representation in FX** + +```python +%name = get_attr[target = name](args = ()) +``` + +**Example** + +Consider the following model: + +```python +from functorch.experimental.control_flow import cond + +def true_fn(x): + return x.sin() + +def false_fn(x): + return x.cos() + +def f(x, y): + return cond(y, true_fn, false_fn, [x]) +``` + +Graph: + +``` +graph(): + %x_1 : [num_users=1] = placeholder[target=x_1] + %y_1 : [num_users=1] = placeholder[target=y_1] + %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0] + %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0] + %conditional : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%y_1, %true_graph_0, %false_graph_0, [%x_1]), kwargs = {}) + return conditional +``` + +The line, `%true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]`, +reads the submodule `true_graph_0` which contains the `sin` operator. + +## References + +### SymInt + +A SymInt is an object that can either be a literal integer or a symbol that represents +an Integer (represented in Python by `sympy.Symbol` class). When SymInt is a +symbol, it describes a variable of type integer that is unknown to the graph at +compile time, that is, its value is only known at runtime. + +### FakeTensor + +A FakeTensor is an object that contains the metadata of a tensor. It can be +viewed as having the following metadata. + +```python +class FakeTensor: + size: List[SymInt] + dtype: torch.dtype + device: torch.device + dim_order: List[int] # This doesn't exist yet +``` + +The size field of FakeTensor is a list of integers or SymInts. If SymInts are +present, this means this tensor has a dynamic shape. If integers are present, it +is assumed that the tensor will have that exact static shape. The rank of the +TensorMeta is never dynamic. The dtype field represents the dtype of the +output of that node. There are no implicit type promotions in Edge IR. There +are no strides in FakeTensor. + +In other words: + +- If the operator in node.target returns a Tensor, then `node.meta['val']` is a + FakeTensor describing that tensor. +- If the operator in node.target returns an n-tuple of Tensors, then + `node.meta['val']` is an n-tuple of FakeTensors describing each tensor. +- If the operator in node.target returns an int/float/scalar that is known at + compile time, then `node.meta['val']` is None. +- If the operator in node.target returns an int/float/scalar that is not known + at compile time, then `node.meta['val']` is of type SymInt. + +For example: + +- `aten::add` returns a Tensor; so its spec will be a FakeTensor with dtype + and size of the tensor returned by this operator. +- `aten::sym_size` returns an integer; so its val will be a SymInt because its + value is only available at runtime. +- `max_pool2d_with_indexes` returns a tuple of (Tensor, Tensor); so the spec + will also be a 2-tuple of FakeTensor objects, the first TensorMeta describes + the first element of the return value etc. + +Python code: + +```python +def add_one(x): + return torch.ops.aten(x, 1) +``` + +Graph: + +``` +graph(): + %ph_0 : [#users=1] = placeholder[target=ph_0] + %add_tensor : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%ph_0, 1), kwargs = {}) + return [add_tensor] +``` + +FakeTensor: + +```python +FakeTensor(dtype=torch.int, size=[2,], device=CPU) +``` + +### Pytree-able Types + +We define a type ā€œPytree-ableā€, if it is either a leaf type or a container type +that contains other Pytree-able types. + +Note: + +> The concept of pytree is the same as the one documented +> [here](https://jax.readthedocs.io/en/latest/pytrees.html) for JAX: + +The following types are defined as **leaf type**: + +```{eval-rst} +.. list-table:: + :widths: 50 50 + :header-rows: 1 + + * - Type + - Definition + * - Tensor + - :class:`torch.Tensor` + * - Scalar + - Any numerical types from Python, including integral types, floating point types, and zero dimensional tensors. + * - int + - Python int (bound as int64_t in C++) + * - float + - Python float (bound as double in C++) + * - bool + - Python bool + * - str + - Python string + * - ScalarType + - :class:`torch.dtype` + * - Layout + - :class:`torch.layout` + * - MemoryFormat + - :class:`torch.memory_format` + * - Device + - :class:`torch.device` +``` + +The following types are defined as **container type**: + +```{eval-rst} +.. list-table:: + :widths: 50 50 + :header-rows: 1 + + * - Type + - Definition + * - Tuple + - Python tuple + * - List + - Python list + * - Dict + - Python dict with Scalar keys + * - NamedTuple + - Python namedtuple + * - Dataclass + - Must be registered through `register_dataclass `__ + * - Custom class + - Any custom class defined with `_register_pytree_node `__ +``` diff --git a/docs/source/export.md b/docs/source/export.md index b550e0270b325..e69024f80ee3f 100644 --- a/docs/source/export.md +++ b/docs/source/export.md @@ -1,3 +1,4 @@ +<<<<<<< HEAD --- file_format: mystnb kernelspec: @@ -8,10 +9,20 @@ mystnb: merge_streams: True --- +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) (torch.export)= # torch.export +<<<<<<< HEAD +======= +:::{warning} +This feature is a prototype under active development and there WILL BE +BREAKING CHANGES in the future. +::: + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ## Overview {func}`torch.export.export` takes a {class}`torch.nn.Module` and produces a traced graph @@ -19,9 +30,15 @@ representing only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion, which can subsequently be executed with different outputs or serialized. +<<<<<<< HEAD ```{code-cell} import torch from torch.export import export, ExportedProgram +======= +```python +import torch +from torch.export import export +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class Mod(torch.nn.Module): def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: @@ -31,10 +48,60 @@ class Mod(torch.nn.Module): example_args = (torch.randn(10, 10), torch.randn(10, 10)) +<<<<<<< HEAD exported_program: ExportedProgram = export(Mod(), args=example_args) print(exported_program) ``` +======= +exported_program: torch.export.ExportedProgram = export( + Mod(), args=example_args +) +print(exported_program) +``` + +```python +ExportedProgram: + class GraphModule(torch.nn.Module): + def forward(self, x: "f32[10, 10]", y: "f32[10, 10]"): + # code: a = torch.sin(x) + sin: "f32[10, 10]" = torch.ops.aten.sin.default(x) + + # code: b = torch.cos(y) + cos: "f32[10, 10]" = torch.ops.aten.cos.default(y) + + # code: return a + b + add: f32[10, 10] = torch.ops.aten.add.Tensor(sin, cos) + return (add,) + + Graph signature: + ExportGraphSignature( + input_specs=[ + InputSpec( + kind=, + arg=TensorArgument(name='x'), + target=None, + persistent=None + ), + InputSpec( + kind=, + arg=TensorArgument(name='y'), + target=None, + persistent=None + ) + ], + output_specs=[ + OutputSpec( + kind=, + arg=TensorArgument(name='add'), + target=None + ) + ] + ) + Range constraints: {} +``` + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) `torch.export` produces a clean intermediate representation (IR) with the following invariants. More specifications about the IR can be found {ref}`here `. @@ -92,6 +159,7 @@ level). Note that users can still use {func}`torch.fx.symbolic_trace` as a preprocessing step before `torch.export`. Compared to {func}`torch.jit.script`, `torch.export` does not capture Python +<<<<<<< HEAD control flow or data structures, unless using explicit {ref}`control flow operators `, but it supports more Python language features due to its comprehensive coverage over Python bytecodes. The resulting graphs are simpler and only have straight @@ -112,6 +180,30 @@ example: ```{code-cell} import torch from torch.export import export, ExportedProgram +======= +control flow or data structures, but it supports more Python language features +than TorchScript (as it is easier to have comprehensive coverage over Python +bytecodes). The resulting graphs are simpler and only have straight line control +flow (except for explicit control flow operators). + +Compared to {func}`torch.jit.trace`, `torch.export` is sound: it is able to +trace code that performs integer computation on sizes and records all of the +side-conditions necessary to show that a particular trace is valid for other +inputs. + +## Exporting a PyTorch Model + +### An Example + +The main entrypoint is through {func}`torch.export.export`, which takes a +callable ({class}`torch.nn.Module`, function, or method) and sample inputs, and +captures the computation graph into an {class}`torch.export.ExportedProgram`. An +example: + +```python +import torch +from torch.export import export +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Simple module for demonstration class M(torch.nn.Module): @@ -131,6 +223,7 @@ class M(torch.nn.Module): example_args = (torch.randn(1, 3, 256, 256),) example_kwargs = {"constant": torch.ones(1, 16, 256, 256)} +<<<<<<< HEAD exported_program: ExportedProgram = export( M(), args=example_args, kwargs=example_kwargs ) @@ -138,6 +231,66 @@ print(exported_program) # To run the exported program, we can use the `module()` method print(exported_program.module()(torch.randn(1, 3, 256, 256), constant=torch.ones(1, 16, 256, 256))) +======= +exported_program: torch.export.ExportedProgram = export( + M(), args=example_args, kwargs=example_kwargs +) +print(exported_program) +``` + +```python +ExportedProgram: + class GraphModule(torch.nn.Module): + def forward(self, p_conv_weight: "f32[16, 3, 3, 3]", p_conv_bias: "f32[16]", x: "f32[1, 3, 256, 256]", constant: "f32[1, 16, 256, 256]"): + # code: a = self.conv(x) + conv2d: "f32[1, 16, 256, 256]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias, [1, 1], [1, 1]) + + # code: a.add_(constant) + add_: "f32[1, 16, 256, 256]" = torch.ops.aten.add_.Tensor(conv2d, constant) + + # code: return self.maxpool(self.relu(a)) + relu: "f32[1, 16, 256, 256]" = torch.ops.aten.relu.default(add_) + max_pool2d: "f32[1, 16, 85, 85]" = torch.ops.aten.max_pool2d.default(relu, [3, 3], [3, 3]) + return (max_pool2d,) + +Graph signature: + ExportGraphSignature( + input_specs=[ + InputSpec( + kind=, + arg=TensorArgument(name='p_conv_weight'), + target='conv.weight', + persistent=None + ), + InputSpec( + kind=, + arg=TensorArgument(name='p_conv_bias'), + target='conv.bias', + persistent=None + ), + InputSpec( + kind=, + arg=TensorArgument(name='x'), + target=None, + persistent=None + ), + InputSpec( + kind=, + arg=TensorArgument(name='constant'), + target=None, + persistent=None + ) + ], + output_specs=[ + OutputSpec( + kind=, + arg=TensorArgument(name='max_pool2d'), + target=None + ) + ] + ) +Range constraints: {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ``` Inspecting the `ExportedProgram`, we can note the following: @@ -145,13 +298,19 @@ Inspecting the `ExportedProgram`, we can note the following: - The {class}`torch.fx.Graph` contains the computation graph of the original program, along with records of the original code for easy debugging. - The graph contains only `torch.ops.aten` operators found [here](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml) +<<<<<<< HEAD and custom operators. +======= + and custom operators, and is fully functional, without any inplace operators + such as `torch.add_`. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - The parameters (weight and bias to conv) are lifted as inputs to the graph, resulting in no `get_attr` nodes in the graph, which previously existed in the result of {func}`torch.fx.symbolic_trace`. - The {class}`torch.export.ExportGraphSignature` models the input and output signature, along with specifying which inputs are parameters. - The resulting shape and dtype of tensors produced by each node in the graph is +<<<<<<< HEAD noted. For example, the `conv2d` node will result in a tensor of dtype `torch.float32` and shape (1, 16, 256, 256). @@ -205,6 +364,185 @@ from run to run. Such dimensions must be specified by using the ```{code-cell} import torch +======= + noted. For example, the `convolution` node will result in a tensor of dtype + `torch.float32` and shape (1, 16, 256, 256). + +(non-strict-export)= + +### Non-Strict Export + +In PyTorch 2.3, we introduced a new mode of tracing called **non-strict mode**. +It's still going through hardening, so if you run into any issues, please file +them to Github with the "oncall: export" tag. + +In *non-strict mode*, we trace through the program using the Python interpreter. +Your code will execute exactly as it would in eager mode; the only difference is +that all Tensor objects will be replaced by ProxyTensors, which will record all +their operations into a graph. + +In *strict* mode, which is currently the default, we first trace through the +program using TorchDynamo, a bytecode analysis engine. TorchDynamo does not +actually execute your Python code. Instead, it symbolically analyzes it and +builds a graph based on the results. This analysis allows torch.export to +provide stronger guarantees about safety, but not all Python code is supported. + +An example of a case where one might want to use non-strict mode is if you run +into a unsupported TorchDynamo feature that might not be easily solved, and you +know the python code is not exactly needed for computation. For example: + +```python +import contextlib +import torch + +class ContextManager(): + def __init__(self): + self.count = 0 + def __enter__(self): + self.count += 1 + def __exit__(self, exc_type, exc_value, traceback): + self.count -= 1 + +class M(torch.nn.Module): + def forward(self, x): + with ContextManager(): + return x.sin() + x.cos() + +export(M(), (torch.ones(3, 3),), strict=False) # Non-strict traces successfully +export(M(), (torch.ones(3, 3),)) # Strict mode fails with torch._dynamo.exc.Unsupported: ContextManager +``` + +In this example, the first call using non-strict mode (through the +`strict=False` flag) traces successfully whereas the second call using strict +mode (default) results with a failure, where TorchDynamo is unable to support +context managers. One option is to rewrite the code (see {ref}`Limitations of torch.export `), +but seeing as the context manager does not affect the tensor +computations in the model, we can go with the non-strict mode's result. + +(training-export)= + +### Export for Training and Inference + +In PyTorch 2.5, we introduced a new API called {func}`export_for_training`. +It's still going through hardening, so if you run into any issues, please file +them to Github with the "oncall: export" tag. + +In this API, we produce the most generic IR that contains all ATen operators +(including both functional and non-functional) which can be used to train in +eager PyTorch Autograd. This API is intended for eager training use cases such as PT2 Quantization +and will soon be the default IR of torch.export.export. To read further about +the motivation behind this change, please refer to + + +When this API is combined with {func}`run_decompositions()`, you should be able to get inference IR with +any desired decomposition behavior. + +To show some examples: + +```python +class ConvBatchnorm(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(1, 3, 1, 1) + self.bn = torch.nn.BatchNorm2d(3) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return (x,) + +mod = ConvBatchnorm() +inp = torch.randn(1, 1, 3, 3) + +ep_for_training = torch.export.export_for_training(mod, (inp,)) +print(ep_for_training) +``` + +```python +ExportedProgram: + class GraphModule(torch.nn.Module): + def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"): + conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias) + add_: "i64[]" = torch.ops.aten.add_.Tensor(b_bn_num_batches_tracked, 1) + batch_norm: "f32[1, 3, 3, 3]" = torch.ops.aten.batch_norm.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05, True) + return (batch_norm,) +``` + +From the above output, you can see that {func}`export_for_training` produces pretty much the same ExportedProgram +as {func}`export` except for the operators in the graph. You can see that we captured batch_norm in the most general +form. This op is non-functional and will be lowered to different ops when running inference. + +You can also go from this IR to an inference IR via {func}`run_decompositions` with arbitrary customizations. + +```python +# Lower to core aten inference IR, but keep conv2d +decomp_table = torch.export.default_decompositions() +del decomp_table[torch.ops.aten.conv2d.default] +ep_for_inference = ep_for_training.run_decompositions(decomp_table) + +print(ep_for_inference) +``` + +```python +ExportedProgram: + class GraphModule(torch.nn.Module): + def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"): + conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias) + add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1) + _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05) + getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0] + getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3] + getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4] + return (getitem_3, getitem_4, add, getitem) +``` + +Here you can see that we kept `conv2d` op in the IR while decomposing the rest. Now the IR is a functional IR +containing core aten operators except for `conv2d`. + +You can do even more customization by directly registering your chosen decomposition behaviors. + +You can do even more customizations by directly registering custom decomp behaviour + +```python +# Lower to core aten inference IR, but customize conv2d +decomp_table = torch.export.default_decompositions() + +def my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1): + return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups) + +decomp_table[torch.ops.aten.conv2d.default] = my_awesome_conv2d_function +ep_for_inference = ep_for_training.run_decompositions(decomp_table) + +print(ep_for_inference) +``` + +```python +ExportedProgram: + class GraphModule(torch.nn.Module): + def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"): + convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1) + mul: "f32[1, 3, 3, 3]" = torch.ops.aten.mul.Tensor(convolution, 2) + add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1) + _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(mul, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05) + getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0] + getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3] + getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]; + return (getitem_3, getitem_4, add, getitem) +``` + +### Expressing Dynamism + +By default `torch.export` will trace the program assuming all input shapes are +**static**, and specializing the exported program to those dimensions. However, +some dimensions, such as a batch dimension, can be dynamic and vary from run to +run. Such dimensions must be specified by using the +{func}`torch.export.Dim` API to create them and by passing them into +{func}`torch.export.export` through the `dynamic_shapes` argument. An example: + +```python +import torch +from torch.export import Dim, export +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class M(torch.nn.Module): def __init__(self): @@ -226,6 +564,7 @@ class M(torch.nn.Module): example_args = (torch.randn(32, 64), torch.randn(32, 128)) # Create a dynamic batch size +<<<<<<< HEAD batch = torch.export.Dim("batch") # Specify that the first dimension of each input is that batch size dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}} @@ -237,14 +576,49 @@ print(ep) example_args2 = (torch.randn(64, 64), torch.randn(64, 128)) ep.module()(*example_args2) # success +======= +batch = Dim("batch") +# Specify that the first dimension of each input is that batch size +dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}} + +exported_program: torch.export.ExportedProgram = export( + M(), args=example_args, dynamic_shapes=dynamic_shapes +) +print(exported_program) +``` + +```python +ExportedProgram: +class GraphModule(torch.nn.Module): + def forward(self, p_branch1_0_weight: "f32[32, 64]", p_branch1_0_bias: "f32[32]", p_branch2_0_weight: "f32[64, 128]", p_branch2_0_bias: "f32[64]", c_buffer: "f32[32]", x1: "f32[s0, 64]", x2: "f32[s0, 128]"): + + # code: out1 = self.branch1(x1) + linear: "f32[s0, 32]" = torch.ops.aten.linear.default(x1, p_branch1_0_weight, p_branch1_0_bias) + relu: "f32[s0, 32]" = torch.ops.aten.relu.default(linear) + + # code: out2 = self.branch2(x2) + linear_1: "f32[s0, 64]" = torch.ops.aten.linear.default(x2, p_branch2_0_weight, p_branch2_0_bias) + relu_1: "f32[s0, 64]" = torch.ops.aten.relu.default(linear_1) + + # code: return (out1 + self.buffer, out2) + add: "f32[s0, 32]" = torch.ops.aten.add.Tensor(relu, c_buffer) + return (add, relu_1) + +Range constraints: {s0: VR[0, int_oo]} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ``` Some additional things to note: - Through the {func}`torch.export.Dim` API and the `dynamic_shapes` argument, we specified the first dimension of each input to be dynamic. Looking at the inputs `x1` and +<<<<<<< HEAD `x2`, they have a symbolic shape of `(s0, 64)` and `(s0, 128)`, instead of the `(32, 64)` and `(32, 128)` shaped tensors that we passed in as example inputs. +======= + `x2`, they have a symbolic shape of (s0, 64) and (s0, 128), instead of + the (32, 64) and (32, 128) shaped tensors that we passed in as example inputs. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) `s0` is a symbol representing that this dimension can be a range of values. - `exported_program.range_constraints` describes the ranges of each symbol @@ -255,6 +629,7 @@ Some additional things to note: [The 0/1 Specialization Problem](https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ#heading=h.ez923tomjvyk) for an in-depth discussion of this topic. +<<<<<<< HEAD In the example, we used `Dim("batch")` to create a dynamic dimension. This is the most explicit way to specify dynamism. We can also use `Dim.DYNAMIC` and @@ -426,22 +801,82 @@ To save the `ExportedProgram`, users can use the {func}`torch.export.save` and {func}`torch.export.load` APIs. The resulting file is a zipfile with a specific structure. The details of the structure are defined in the {ref}`PT2 Archive Spec `. +======= +We can also specify more expressive relationships between input shapes, such as +where a pair of shapes might differ by one, a shape might be double of +another, or a shape is even. An example: + +```python +class M(torch.nn.Module): + def forward(self, x, y): + return x + y[1:] + +x, y = torch.randn(5), torch.randn(6) +dimx = torch.export.Dim("dimx", min=3, max=6) +dimy = dimx + 1 + +exported_program = torch.export.export( + M(), (x, y), dynamic_shapes=({0: dimx}, {0: dimy}), +) +print(exported_program) +``` + +```python +ExportedProgram: +class GraphModule(torch.nn.Module): + def forward(self, x: "f32[s0]", y: "f32[s0 + 1]"): + # code: return x + y[1:] + slice_1: "f32[s0]" = torch.ops.aten.slice.Tensor(y, 0, 1, 9223372036854775807) + add: "f32[s0]" = torch.ops.aten.add.Tensor(x, slice_1) + return (add,) + +Range constraints: {s0: VR[3, 6], s0 + 1: VR[4, 7]} +``` + +Some things to note: + +- By specifying `{0: dimx}` for the first input, we see that the resulting + shape of the first input is now dynamic, being `[s0]`. And now by specifying + `{0: dimy}` for the second input, we see that the resulting shape of the + second input is also dynamic. However, because we expressed `dimy = dimx + 1`, + instead of `y`'s shape containing a new symbol, we see that it is + now being represented with the same symbol used in `x`, `s0`. We can + see that relationship of `dimy = dimx + 1` is being shown through `s0 + 1`. +- Looking at the range constraints, we see that `s0` has the range [3, 6], + which is specified initially, and we can see that `s0 + 1` has the solved + range of [4, 7]. + +### Serialization + +To save the `ExportedProgram`, users can use the {func}`torch.export.save` and +{func}`torch.export.load` APIs. A convention is to save the `ExportedProgram` +using a `.pt2` file extension. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) An example: ```python import torch +<<<<<<< HEAD +======= +import io +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class MyModule(torch.nn.Module): def forward(self, x): return x + 10 +<<<<<<< HEAD exported_program = torch.export.export(MyModule(), (torch.randn(5),)) +======= +exported_program = torch.export.export(MyModule(), torch.randn(5)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.export.save(exported_program, 'exported_program.pt2') saved_exported_program = torch.export.load('exported_program.pt2') ``` +<<<<<<< HEAD (training-export)= ## Export IR, Decompositions @@ -575,17 +1010,140 @@ Notice that instead of `torch.ops.aten.conv2d.default` being decomposed into `torch.ops.aten.convolution.default`, it is now decomposed into `torch.ops.aten.convolution.default` and `torch.ops.aten.mul.Tensor`, which matches our custom decomposition rule. +======= +### Specializations + +A key concept in understanding the behavior of `torch.export` is the +difference between *static* and *dynamic* values. + +A *dynamic* value is one that can change from run to run. These behave like +normal arguments to a Python function—you can pass different values for an +argument and expect your function to do the right thing. Tensor *data* is +treated as dynamic. + +A *static* value is a value that is fixed at export time and cannot change +between executions of the exported program. When the value is encountered during +tracing, the exporter will treat it as a constant and hard-code it into the +graph. + +When an operation is performed (e.g. `x + y`) and all inputs are static, then +the output of the operation will be directly hard-coded into the graph, and the +operation won’t show up (i.e. it will get constant-folded). + +When a value has been hard-coded into the graph, we say that the graph has been +*specialized* to that value. + +The following values are static: + +#### Input Tensor Shapes + +By default, `torch.export` will trace the program specializing on the input +tensors' shapes, unless a dimension is specified as dynamic via the +`dynamic_shapes` argument to `torch.export`. This means that if there exists +shape-dependent control flow, `torch.export` will specialize on the branch +that is being taken with the given sample inputs. For example: + +```python +import torch +from torch.export import export + +class Mod(torch.nn.Module): + def forward(self, x): + if x.shape[0] > 5: + return x + 1 + else: + return x - 1 + +example_inputs = (torch.rand(10, 2),) +exported_program = export(Mod(), example_inputs) +print(exported_program) +``` + +```python +ExportedProgram: +class GraphModule(torch.nn.Module): + def forward(self, x: "f32[10, 2]"): + # code: return x + 1 + add: "f32[10, 2]" = torch.ops.aten.add.Tensor(x, 1) + return (add,) +``` + +The conditional of (`x.shape[0] > 5`) does not appear in the +`ExportedProgram` because the example inputs have the static +shape of (10, 2). Since `torch.export` specializes on the inputs' static +shapes, the else branch (`x - 1`) will never be reached. To preserve the dynamic +branching behavior based on the shape of a tensor in the traced graph, +{func}`torch.export.Dim` will need to be used to specify the dimension +of the input tensor (`x.shape[0]`) to be dynamic, and the source code will +need to be {ref}`rewritten `. + +Note that tensors that are part of the module state (e.g. parameters and +buffers) always have static shapes. + +#### Python Primitives + +`torch.export` also specializes on Python primitives, +such as `int`, `float`, `bool`, and `str`. However they do have dynamic +variants such as `SymInt`, `SymFloat`, and `SymBool`. + +For example: + +```python +import torch +from torch.export import export + +class Mod(torch.nn.Module): + def forward(self, x: torch.Tensor, const: int, times: int): + for i in range(times): + x = x + const + return x + +example_inputs = (torch.rand(2, 2), 1, 3) +exported_program = export(Mod(), example_inputs) +print(exported_program) +``` + +```python +ExportedProgram: + class GraphModule(torch.nn.Module): + def forward(self, x: "f32[2, 2]", const, times): + # code: x = x + const + add: "f32[2, 2]" = torch.ops.aten.add.Tensor(x, 1) + add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 1) + add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 1) + return (add_2,) +``` + +Because integers are specialized, the `torch.ops.aten.add.Tensor` operations +are all computed with the hard-coded constant `1`, rather than `const`. If +a user passes a different value for `const` at runtime, like 2, than the one used +during export time, 1, this will result in an error. +Additionally, the `times` iterator used in the `for` loop is also "inlined" +in the graph through the 3 repeated `torch.ops.aten.add.Tensor` calls, and the +input `times` is never used. + +#### Python Containers + +Python containers (`List`, `Dict`, `NamedTuple`, etc.) are considered to +have static structure. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) (limitations-of-torch-export)= ## Limitations of torch.export +<<<<<<< HEAD +======= +### Graph Breaks + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) As `torch.export` is a one-shot process for capturing a computation graph from a PyTorch program, it might ultimately run into untraceable parts of programs as it is nearly impossible to support tracing all PyTorch and Python features. In the case of `torch.compile`, an unsupported operation will cause a "graph break" and the unsupported operation will be run with default Python evaluation. In contrast, `torch.export` will require users to provide additional +<<<<<<< HEAD information or rewrite parts of their code to make it traceable. {ref}`Draft-export ` is a great resource for listing out @@ -606,6 +1164,19 @@ some Python features that are unsupported. An option to get past dealing with this graph breaks is by using {ref}`non-strict export ` through changing the `strict` flag to `strict=False`. +======= +information or rewrite parts of their code to make it traceable. As the +tracing is based on TorchDynamo, which evaluates at the Python +bytecode level, there will be significantly fewer rewrites required compared to +previous tracing frameworks. + +When a graph break is encountered, {ref}`ExportDB ` is a great +resource for learning about the kinds of programs that are supported and +unsupported, along with ways to rewrite programs to make them traceable. + +An option to get past dealing with this graph breaks is by using +{ref}`non-strict export ` +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) (data-shape-dependent-control-flow)= @@ -618,6 +1189,7 @@ number of paths. In such cases, users will need to rewrite their code using special control flow operators. Currently, we support {ref}`torch.cond ` to express if-else like control flow (more coming soon!). +<<<<<<< HEAD You can also refer to this [tutorial](https://docs.pytorch.org/tutorials/intermediate/torch_export_tutorial.html#data-dependent-errors) for more ways of addressing data-dependent errors. @@ -630,6 +1202,15 @@ operator. Please see this [tutorial](https://docs.pytorch.org/tutorials/advanced/custom_ops_landing_page.html) for more details. +======= +### Missing Fake/Meta/Abstract Kernels for Operators + +When tracing, a FakeTensor kernel (aka meta kernel, abstract impl) is +required for all operators. This is used to reason about the input/output shapes +for this operator. + +Please see {func}`torch.library.register_fake` for more details. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) In the unfortunate case where your model uses an ATen operator that is does not have a FakeTensor kernel implementation yet, please file an issue. @@ -640,6 +1221,7 @@ have a FakeTensor kernel implementation yet, please file an issue. :caption: Additional Links for Export Users :maxdepth: 1 +<<<<<<< HEAD export/api_reference export/programming_model export/ir_spec @@ -650,13 +1232,210 @@ cond generated/exportdb/index torch.compiler_aot_inductor torch.compiler_ir +======= +export.programming_model +export.ir_spec +draft_export +torch.compiler_transformations +torch.compiler_ir +generated/exportdb/index +cond +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ``` ```{toctree} :caption: Deep Dive for PyTorch Developers :maxdepth: 1 +<<<<<<< HEAD torch.compiler_dynamic_shapes torch.compiler_fake_tensor torch.compiler_transformations +======= +torch.compiler_dynamo_overview +torch.compiler_dynamo_deepdive +torch.compiler_dynamic_shapes +torch.compiler_fake_tensor +``` + +## API Reference + +```{eval-rst} +.. automodule:: torch.export +``` + +```{eval-rst} +.. autofunction:: export +``` + +```{eval-rst} +.. autofunction:: save +``` + +```{eval-rst} +.. autofunction:: load +``` + +```{eval-rst} +.. autofunction:: draft_export +``` + +```{eval-rst} +.. autofunction:: register_dataclass +``` + +```{eval-rst} +.. autoclass:: torch.export.dynamic_shapes.Dim +``` + +```{eval-rst} +.. autoclass:: torch.export.dynamic_shapes.ShapesCollection + + .. automethod:: dynamic_shapes +``` + +```{eval-rst} +.. autoclass:: torch.export.dynamic_shapes.AdditionalInputs + + .. automethod:: add + .. automethod:: dynamic_shapes + .. automethod:: verify +``` + +```{eval-rst} +.. autofunction:: torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes +``` + +```{eval-rst} +.. autoclass:: ExportedProgram + + .. attribute:: graph + .. attribute:: graph_signature + .. attribute:: state_dict + .. attribute:: constants + .. attribute:: range_constraints + .. attribute:: module_call_graph + .. attribute:: example_inputs + .. automethod:: module + .. automethod:: run_decompositions +``` + +```{eval-rst} +.. autoclass:: ExportGraphSignature +``` + +```{eval-rst} +.. autoclass:: ModuleCallSignature +``` + +```{eval-rst} +.. autoclass:: ModuleCallEntry +``` + +```{eval-rst} +.. automodule:: torch.export.decomp_utils +``` + +```{eval-rst} +.. autoclass:: CustomDecompTable + + .. automethod:: copy + .. automethod:: items + .. automethod:: keys + .. automethod:: materialize + .. automethod:: pop + .. automethod:: update +``` + +```{eval-rst} +.. autofunction:: torch.export.exported_program.default_decompositions +``` + +```{eval-rst} +.. automodule:: torch.export.exported_program +``` + +```{eval-rst} +.. automodule:: torch.export.graph_signature +``` + +```{eval-rst} +.. autoclass:: ExportGraphSignature + + .. automethod:: replace_all_uses + .. automethod:: get_replace_hook +``` + +```{eval-rst} +.. autoclass:: ExportBackwardSignature +``` + +```{eval-rst} +.. autoclass:: InputKind +``` + +```{eval-rst} +.. autoclass:: InputSpec +``` + +```{eval-rst} +.. autoclass:: OutputKind +``` + +```{eval-rst} +.. autoclass:: OutputSpec +``` + +```{eval-rst} +.. autoclass:: SymIntArgument +``` + +```{eval-rst} +.. autoclass:: SymBoolArgument +``` + +```{eval-rst} +.. autoclass:: SymFloatArgument +``` + +```{eval-rst} +.. autoclass:: CustomObjArgument +``` + +```{eval-rst} +.. py:module:: torch.export.dynamic_shapes +``` + +```{eval-rst} +.. py:module:: torch.export.custom_ops +``` + +```{eval-rst} +.. automodule:: torch.export.unflatten + :members: +``` + +```{eval-rst} +.. automodule:: torch.export.custom_obj +``` + +```{eval-rst} +.. automodule:: torch.export.experimental +``` + +```{eval-rst} +.. automodule:: torch.export.passes +``` + +```{eval-rst} +.. autofunction:: torch.export.passes.move_to_device_pass +``` + +```{eval-rst} +.. automodule:: torch.export.pt2_archive +``` + +```{eval-rst} +.. automodule:: torch.export.pt2_archive.constants +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ``` diff --git a/docs/source/export.programming_model.md b/docs/source/export.programming_model.md new file mode 100644 index 0000000000000..9a21db78464aa --- /dev/null +++ b/docs/source/export.programming_model.md @@ -0,0 +1,523 @@ +(export-programming-model)= + +# torch.export Programming Model + +This document aims to explain the behaviors and capabilities of +{func}`torch.export.export`. It is intended to help build your intuition +for how {func}`torch.export.export` handles code. + +## Basics of Tracing + +{func}`torch.export.export` captures a graph representing your model by +tracing its execution on "example" inputs and recording the PyTorch operations +and conditions observed along the traced path. This graph can then be run +on different inputs as long as they satisfy the same conditions. + +The basic output of {func}`torch.export.export` is a single graph of PyTorch +operations, with associated metadata. The exact format of this output is +covered in the {ref}`export.ir_spec`. + +### Strict vs. Non-Strict Tracing + +{func}`torch.export.export` provides two modes of tracing. + +In *non-strict mode*, we trace through the program using the normal Python +interpreter. Your code executes exactly as it would in eager mode; the only +difference is that all Tensors are replaced by +[fake Tensors](https://pytorch.org/docs/main/torch.compiler_fake_tensor.html), +**which have shapes and other forms of metadata but no data**, wrapped in +[Proxy objects](https://pytorch.org/docs/main/fx.html) that record all +operations on them into a graph. We also capture +[conditions on Tensor shapes](https://pytorch.org/docs/main/torch.compiler_dynamic_shapes.html#the-guard-model) +**that guard the correctness of the generated code**. + +In *strict mode*, we first trace through the program using +{ref}`TorchDynamo `, a Python bytecode +analysis engine. TorchDynamo does not actually execute your Python code. +Instead, it symbolically analyzes it and builds a graph based on the results. +On the one hand, this analysis allows {func}`torch.export.export` to provide +additional guarantees on Python-level safety (beyond capturing conditions on +Tensor shapes, as in non-strict mode). On the other hand, not all Python +features are supported by this analysis. + +Although currently the default mode of tracing is strict, **we strongly +recommend using non-strict**, which will soon become the default. +For most models, conditions on Tensor shapes are enough for soundness, and +the additional guarantees on Python-level safety have no impact; at the same +time, the possibility of hitting unsupported Python features in TorchDynamo +presents an unnecessary risk. + +In the rest of this document we assume we are tracing in +[non-strict mode](https://pytorch.org/docs/main/export.html#non-strict-export); +in particular, we assume that **all Python features are supported**. + +## Values: Static vs. Dynamic + +A key concept in understanding the behavior of {func}`torch.export.export` is +the difference between *static* and *dynamic* values. + +### Static Values + +A *static* value is a value that is **fixed at export time and cannot change +between executions of the exported program**. When the value is encountered +during tracing, we treat it as a constant and hard-code it into the graph. + +When an operation is performed (e.g. `x + y`) and all inputs are static, +the output of the operation is directly hard-coded into the graph and the +operation does not show up (i.e. it gets "constant-folded"). + +When a value has been hard-coded into the graph, we say that the graph has +been *specialized* to that value. For example: + +```python +import torch + +class MyMod(torch.nn.Module): + def forward(self, x, y): + z = y + 7 + return x + z + +m = torch.export.export(MyMod(), (torch.randn(1), 3)) +print(m.graph_module.code) + +""" +def forward(self, arg0_1, arg1_1): + add = torch.ops.aten.add.Tensor(arg0_1, 10); arg0_1 = None + return (add,) + +""" +``` + +Here, we provide `3` as the traced value for `y`; it is treated as a static +value and added to `7`, burning in the static value `10` in the graph. + +### Dynamic Values + +A *dynamic* value is one that **can change from run to run**. It behaves just +like a "normal" function argument: you can pass different inputs and expect +your function to do the right thing. + +### Which values are static vs. dynamic? + +Whether a value is static or dynamic depends on its type: + +- For Tensor: + + - Tensor *data* is treated as dynamic. + + - Tensor *shapes* can be treated by the system as static or dynamic. + + - By default, shapes of all input Tensors are considered static. + The user can override this behavior for any input Tensor by specifying + a [dynamic shape](https://pytorch.org/docs/main/export.html#expressing-dynamism) + for it. + - Tensors that are part of module state, i.e., parameters and buffers, + always have static shapes. + + - Other forms of Tensor *metadata* (e.g. `device`, `dtype`) are static. + +- Python *primitives* (`int`, `float`, `bool`, `str`, `None`) are static. + + - There are dynamic variants for some primitive types (`SymInt`, + `SymFloat`, `SymBool`). Typically users do not have to deal with them. + +- For Python *standard containers* (`list`, `tuple`, `dict`, `namedtuple`): + + - The structure (i.e., length for `list` and `tuple` values, and key + sequence for `dict` and `namedtuple` values) is static. + - The contained elements have these rules applied to them recursively + (basically the + [PyTree](https://jax.readthedocs.io/en/latest/pytrees.html) scheme) + with leaves that are either Tensor or primitive types. + +- Other *classes* (including data classes) can be registered with PyTree + (see below), and follow the same rules as the standard containers. + +## Input types + +Inputs will be treated as either static or dynamic, based on their type +(as explained above). + +- A static input will get hard-coded into the graph, and passing a different + value at run time will result in an error. Recall that these are mostly + values of primitive types. +- A dynamic input behaves like a "normal" function input. Recall that these + are mostly values of Tensor types. + +By default, the types of inputs you can use for your program are: + +- Tensor +- Python primitives (`int`, `float`, `bool`, `str`, `None`) +- Python standard containers (`list`, `tuple`, `dict`, `namedtuple`) + +### Custom Input Types + +In addition, you can also define your own (custom) class and use it as an +input type, but you will need to register such a class as a PyTree. + +Here's an example of using an utility to register a dataclass that is used as +an input type. + +```python +@dataclass +class Input: + f: torch.Tensor + p: torch.Tensor + +torch.export.register_dataclass(Input) + +class M(torch.nn.Module): + def forward(self, x: Input): + return x.f + 1 + +torch.export.export(M(), (Input(f=torch.ones(10, 4), p=torch.zeros(10, 4)),)) +``` + +### Optional input types + +For optional inputs to the program that are not passed in, +{func}`torch.export.export` will specialize to their default values. As a +result, the exported program will require users to explicitly pass in all +arguments, and will lose the defaulting behavior. For example: + +```python +class M(torch.nn.Module): + def forward(self, x, y=None): + if y is not None: + return y * x + return x + x + +# Optional input is passed in +ep = torch.export.export(M(), (torch.randn(3, 3), torch.randn(3, 3))) +print(ep) +""" +ExportedProgram: + class GraphModule(torch.nn.Module): + def forward(self, x: "f32[3, 3]", y: "f32[3, 3]"): + # File: /data/users/angelayi/pytorch/moo.py:15 in forward, code: return y * x + mul: "f32[3, 3]" = torch.ops.aten.mul.Tensor(y, x); y = x = None + return (mul,) +""" + +# Optional input is not passed in +ep = torch.export.export(M(), (torch.randn(3, 3),)) +print(ep) +""" +ExportedProgram: + class GraphModule(torch.nn.Module): + def forward(self, x: "f32[3, 3]", y): + # File: /data/users/angelayi/pytorch/moo.py:16 in forward, code: return x + x + add: "f32[3, 3]" = torch.ops.aten.add.Tensor(x, x); x = None + return (add,) +""" +``` + +## Control Flow: Static vs. Dynamic + +Control flow is supported by {func}`torch.export.export`. The behavior of +control flow depends on whether the value you are branching on is static or +dynamic. + +### Static Control Flow + +**Python control flow over static values is supported transparently**. (Recall +that static values include static shapes, so control flow over static shapes +is also covered by this case.) + +As mentioned above, we "burn in" static values, so the exported graph will +never see any control flow over static values. + +In the case of an `if` statement, we will continue tracing the branch taken +at export time. In the case of a `for` or `while` statement, we will continue +tracing by unrolling the loop. + +### Dynamic Control Flow: Shape-Dependent vs. Data-Dependent + +When the value involved in a control flow is dynamic, it could depend on +dynamic shapes or dynamic data. Given that the compiler traces with +information on shapes rather than data, the implications on the programming +model are different in these cases. + +#### Dynamic Shape-Dependent Control Flow + +When the value involved in a control flow is a +[dynamic shape](https://pytorch.org/docs/main/torch.compiler_dynamic_shapes.html), +in most cases **we will also know the concrete value of the dynamic shape +during tracing**: see the following section for more details on how the +compiler tracks this information. + +In these cases we say that the control flow is shape-dependent. **We use the +concrete value of the dynamic shape to evaluate the condition** to either +`True` or `False` and continue tracing (as discussed above), additionally +emitting a guard corresponding to the condition just evaluated. + +Otherwise the control flow is considered data-dependent. We cannot evaluate +the condition to either `True` or `False`, so cannot continue tracing and have to +raise an error at export time. See next section. + +#### Dynamic Data-Dependent Control Flow + +**Data-dependent control flow over dynamic values is supported, but you must +use one of PyTorch's explicit operators** to continue tracing. Using Python +control flow statements over dynamic values is not permitted, because the +compiler cannot evaluate the conditions necessary to continue tracing and +thus an error must be raised at export time. + +We provide **operators to express general conditionals and loops over dynamic +values**, e.g., `torch.cond`, `torch.map`. Note that you only need to use these +if you truly want *data-dependent control flow*. + +Here's an example of an `if` statement on a data-dependent condition, +`x.sum() > 0`, where `x` is an input Tensor, rewritten using `torch.cond`. +Instead of having to decide which branch to trace, now both branches are +traced. + +```python +class M_old(torch.nn.Module): + def forward(self, x): + if x.sum() > 0: + return x.sin() + else: + return x.cos() + +class M_new(torch.nn.Module): + def forward(self, x): + return torch.cond( + pred=x.sum() > 0, + true_fn=lambda x: x.sin(), + false_fn=lambda x: x.cos(), + operands=(x,), + ) +``` + +A special case of data-dependent control flow is where it involves a +[data-dependent dynamic shape](https://pytorch.org/docs/main/torch.compiler_dynamic_shapes.html#unbacked-symints): +typically, the shape of some intermediate Tensor that depends on input data +rather than on input shapes (thus not shape-dependent). Instead of using a +control flow operator, in this case you can provide an assertion that decides +whether the condition is `True` or `False`. Given such an assertion, we can +continue tracing, emitting a guard as above. + +We provide **operators to express assertions on dynamic shapes**, e.g., +`torch._check`. Note that you only need to use this when there is control +flow on data-dependent dynamic shapes. + +Here's an example of an `if` statement on a condition involving a +data-dependent dynamic shape, `nz.shape[0] > 0`, where `nz` is the result of +calling {func}`torch.nonzero`, an operator whose output shape depends on input +data. Instead of rewriting it, you can add an assertion using `torch._check` +to effectively decide which branch to trace. + +```python +class M_old(torch.nn.Module): + def forward(self, x): + nz = x.nonzero() + if nz.shape[0] > 0: + return x.sin() + else: + return x.cos() + +class M_new(torch.nn.Module): + def forward(self, x): + nz = x.nonzero() + torch._check(nz.shape[0] > 0) + if nz.shape[0] > 0: + return x.sin() + else: + return x.cos() +``` + +## Basics of Symbolic Shapes + +During tracing, dynamic Tensor shapes and conditions over them are encoded as +"symbolic expressions." (In contrast, static Tensor shapes and conditions +over them are simply `int` and `bool` values.) + +A *symbol* is like a variable; it describes a dynamic Tensor shape. + +As tracing proceeds, shapes of intermediate Tensors may be described by more +general expressions, typically involving integer arithmetic operators. This +is because **for most PyTorch operators, shapes of output Tensors can be +described as functions of shapes of input Tensors**. For example, the shape of +the output of {func}`torch.cat` is the sum of the shapes of its inputs. + +Moreover, as we encounter control flow in the program, we create boolean +expressions, typically involving relational operators, describing conditions +along the traced path. These **expressions are evaluated to decide which path +to trace through the program**, and recorded in a +[shape environment](https://pytorch.org/docs/main/torch.compiler_dynamic_shapes.html#overall-architecture) +to guard the correctness of the traced path and to evaluate subsequently +created expressions. + +We briefly introduce these subsystems next. + +### Fake Implementations of PyTorch Operators + +Recall that during tracing, we are executing the program with +[fake Tensors](https://pytorch.org/docs/main/torch.compiler_fake_tensor.html), +which have no data. In general we cannot call the actual implementations of +PyTorch operators with fake Tensors. Thus each operator needs to have an +additional fake (a.k.a. "meta") implementation, which inputs and outputs fake +Tensors, that matches the behavior of the actual implementation in terms of +shapes and other forms of metadata carried by fake Tensors. + +For example, note how the fake implementation of {func}`torch.index_select` +computes the shape of the output using the shape of the input (while ignoring +input data and returning empty output data). + +```python +def meta_index_select(self, dim, index): + result_size = list(self.size()) + if self.dim() > 0: + result_size[dim] = index.numel() + return self.new_empty(result_size) +``` + +#### Shape Propagation: Backed vs. Unbacked Dynamic Shapes + +Shapes are propagated using fake implementations of PyTorch operators. + +A key concept to understand the propagation of dynamic shapes in particular +is the difference between *backed* and *unbacked* dynamic shapes: we know the +concrete values of the former but not the latter. + +Propagation of shapes, including tracking backed and unbacked dynamic shapes, +proceeds as follows: + +- The shapes of Tensors representing inputs can be static or dynamic. When + dynamic, they are described by symbols; moreover, **such symbols are backed + since we also know their concrete values given the "real" example inputs + provided by the user at export time**. + +- The output shape of an operator is computed by its fake implementation, and + is either static or dynamic. When dynamic, in general it is described by a + symbolic expression. Moreover: + + - If the output shape depends only on input shapes, it is either static or + backed dynamic whenever the input shapes are all static or backed dynamic. + - On the other hand, **if the output shape depends on input data**, it is + necessarily dynamic, and moreover, **because we cannot know its concrete + value it is unbacked**. + +### Control Flow: Guards and Assertions + +When a condition on shapes is encountered, it either involves only static +shapes, in which case it is a `bool`, or it involves dynamic shapes, in which +case it is a symbolic boolean expression. For the latter: + +- When the condition involves only backed dynamic shapes, we can use the + concrete values of those dynamic shapes to evaluate the condition to `True` + or `False`. We can then add a guard to the shape environment that states + that the corresponding symbolic boolean expression is `True` or `False`, + and continue tracing. +- Otherwise the condition involves unbacked dynamic shapes. In general we + cannot evaluate such a condition without additional information; thus we + cannot continue tracing, and we must raise an error at export time. The + user is expected to use an explicit PyTorch operator for tracing to + continue. This information is added as a guard in the shape environment, + and can also possibly help evaluate other subsequently encountered + conditions to `True` or `False`. + +Once the model is exported, **any guards on backed dynamic shapes can be +understood as conditions on input dynamic shapes**. These are verified against +a dynamic shape specification that must have been provided to export, +describing conditions on dynamic shapes that not only example inputs but also +all future inputs are expected to satisfy for the generated code to be +correct. More precisely, the dynamic shape specification must logically imply +the generated guards, otherwise an error is raised at export time (along with +suggested fixes to the dynamic shape specification). On the other hand, when +there are no generated guards on backed dynamic shapes (in particular, when +all shapes are static) no dynamic shape specification needs to be provided to +export. In general, the dynamic shape specification is converted to runtime +assertions on the inputs of the generated code. + +Finally, **any guards on unbacked dynamic shapes are converted to "inline" +runtime assertions**. These are added in the generated code at the locations +where those unbacked dynamic shapes were created: typically, right after +data-dependent operator calls. + +## Allowed PyTorch operators + +All PyTorch operators are permitted. + +### Custom operators + +In addition, you can define and use +[custom operators](https://pytorch.org/tutorials/advanced/python_custom_ops#python-custom-ops-tutorial). +Defining a custom operator includes defining a fake implementation for it, +just like any other PyTorch operator (see previous section). + +Here's an example of a custom `sin` operator that wraps NumPy, and its +registered (trivial) fake implementation. + +```python +@torch.library.custom_op("mylib::sin", mutates_args=()) +def sin(x: Tensor) -> Tensor: + x_np = x.numpy() + y_np = np.sin(x_np) + return torch.from_numpy(y_np) + +@torch.library.register_fake("mylib::sin") +def _(x: Tensor) -> Tensor: + return torch.empty_like(x) +``` + +**Sometimes your custom operator's fake implementation will involve +data-dependent shapes**. Here's how a fake implementation for a custom +`nonzero` might look like. + +```python +... + +@torch.library.register_fake("mylib::custom_nonzero") +def _(x): + nnz = torch.library.get_ctx().new_dynamic_size() + shape = [nnz, x.dim()] + return x.new_empty(shape, dtype=torch.int64) +``` + +## Module State: Reads vs. Updates + +Module states include parameters, buffers, and regular attributes. + +- A regular attribute can be of any type. +- On the other hand, parameters and buffers are always Tensors. + +Module states can be dynamic or static, based on their types as outlined +above. For example, `self.training` is a `bool`, which means it is static; on +the other hand, any parameter or buffer is dynamic. + +The *shapes* of any Tensors contained in module states cannot be dynamic, i.e., +those shapes are fixed at export time, and cannot change between executions +of the exported program. + +### Access rules + +**All module states must be initialized**. Accessing a module state that is +not already initialized causes an error to be raised at export time. + +**Reading module states is always permitted**. + +Updating module states is possible, but must follow the rules below: + +- **A static regular attribute** (e.g., of primitive type) **can be updated**. + Reads and updates can be freely interleaved, and as expected, any reads + will always see the values of the latest updates. Because these attributes + are static, we will also burn the values in, so the generated code will not + have any instructions to actually "get" or "set" such attributes. +- **A dynamic regular attribute** (e.g., of Tensor type) **cannot be updated**. + To do so, it must be registered as a buffer during module initialization. +- **A buffer can be updated**, where the updating can be in-place (e.g., + `self.buffer[:] = ...`) or not (e.g., `self.buffer = ...`). +- **A parameter cannot be updated**. Typically parameters are updated only + during training, not during inference. We recommend exporting with + {func}`torch.no_grad` to avoid parameter updates at export time. + +### Effects of functionalization + +Any dynamic module state that is read and/or updated is "lifted" +(respectively) as an input and/or output of the generated code. + +The exported program stores, along with the generated code, the initial +values of parameters and buffers and the constant values of other Tensor +attributes. diff --git a/docs/source/fx.md b/docs/source/fx.md index 831534606abe0..b9149fbb4b694 100644 --- a/docs/source/fx.md +++ b/docs/source/fx.md @@ -44,7 +44,12 @@ Your transform will take in a {class}`torch.nn.Module`, acquire a {class}`Graph` from it, do some modifications, and return a new {class}`torch.nn.Module`. You should think of the {class}`torch.nn.Module` that your FX transform returns as identical to a regular {class}`torch.nn.Module` -- you can pass it to another +<<<<<<< HEAD FX transform, or you can run it. Ensuring that the inputs and outputs of your FX transform are a +======= +FX transform, you can pass it to TorchScript, or you can +run it. Ensuring that the inputs and outputs of your FX transform are a +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {class}`torch.nn.Module` will allow for composability. ```{note} diff --git a/docs/source/index.md b/docs/source/index.md index df012d1d6e177..1d1d06cfcac3a 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -23,11 +23,25 @@ The APIs and performance characteristics of these features may change. :glob: :maxdepth: 2 +<<<<<<< HEAD Install PyTorch user_guide/index pytorch-api notes community/index +======= +pytorch-api +notes +``` + +```{toctree} +:glob: +:hidden: +:maxdepth: 2 + +community/index +C++ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ``` ## Indices and tables diff --git a/docs/source/jit.rst b/docs/source/jit.rst index 5295f82f9ac19..377be5d7c8e7b 100644 --- a/docs/source/jit.rst +++ b/docs/source/jit.rst @@ -2,6 +2,7 @@ TorchScript =========== .. toctree:: +<<<<<<< HEAD :maxdepth: 1 :hidden: @@ -16,6 +17,49 @@ TorchScript `torch.export `__ instead. .. automodule:: torch.jit +======= + :maxdepth: 1 + :caption: Builtin Functions + :hidden: + + torch.jit.supported_ops + + +.. toctree:: + :maxdepth: 1 + :caption: Language Reference + :hidden: + + jit_language_reference + + +.. toctree:: + :maxdepth: 1 + + jit_language_reference_v2 + + +.. contents:: :local: + :depth: 2 + +.. automodule:: torch.jit +.. currentmodule:: torch.jit + +TorchScript is a way to create serializable and optimizable models from PyTorch code. +Any TorchScript program can be saved from a Python +process and loaded in a process where there is no Python dependency. + +We provide tools to incrementally transition a model from a pure Python program +to a TorchScript program that can be run independently from Python, such as in a standalone C++ program. +This makes it possible to train models in PyTorch using familiar tools in Python and then export +the model via TorchScript to a production environment where Python programs may be disadvantageous +for performance and multi-threading reasons. + +For a gentle introduction to TorchScript, see the `Introduction to TorchScript `_ tutorial. + +For an end-to-end example of converting a PyTorch model to TorchScript and running it in C++, see the +`Loading a PyTorch Model in C++ `_ tutorial. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Creating TorchScript Code -------------------------- @@ -47,11 +91,825 @@ Creating TorchScript Code Attribute annotate +<<<<<<< HEAD .. This package is missing doc. Adding it here for coverage .. This does not add anything to the rendered page. .. py:module:: torch.jit.supported_ops .. py:module:: torch.jit.unsupported_tensor_ops +======= +Mixing Tracing and Scripting +---------------------------- + +In many cases either tracing or scripting is an easier approach for converting a model to TorchScript. +Tracing and scripting can be composed to suit the particular requirements +of a part of a model. + +Scripted functions can call traced functions. This is particularly useful when you need +to use control-flow around a simple feed-forward model. For instance the beam search +of a sequence to sequence model will typically be written in script but can call an +encoder module generated using tracing. + + +.. testsetup:: + + # These are hidden from the docs, but these are necessary for `doctest` + # since the `inspect` module doesn't play nicely with the execution + # environment for `doctest` + import torch + + original_script = torch.jit.script + def script_wrapper(obj, *args, **kwargs): + obj.__module__ = 'FakeMod' + return original_script(obj, *args, **kwargs) + + torch.jit.script = script_wrapper + + original_trace = torch.jit.trace + def trace_wrapper(obj, *args, **kwargs): + obj.__module__ = 'FakeMod' + return original_trace(obj, *args, **kwargs) + + torch.jit.trace = trace_wrapper + + +Example (calling a traced function in script): + +.. testcode:: + + import torch + + def foo(x, y): + return 2 * x + y + + traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3))) + + @torch.jit.script + def bar(x): + return traced_foo(x, x) + +Traced functions can call script functions. This is useful when a small part of +a model requires some control-flow even though most of the model is just a feed-forward +network. Control-flow inside of a script function called by a traced function is +preserved correctly. + +Example (calling a script function in a traced function): + +.. testcode:: + + import torch + + @torch.jit.script + def foo(x, y): + if x.max() > y.max(): + r = x + else: + r = y + return r + + + def bar(x, y, z): + return foo(x, y) + z + + traced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3))) + +This composition also works for ``nn.Module``\s as well, where it can be used to generate +a submodule using tracing that can be called from the methods of a script module. + +Example (using a traced module): + +.. testcode:: + :skipif: torchvision is None + + import torch + import torchvision + + class MyScriptModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68]) + .resize_(1, 3, 1, 1)) + self.resnet = torch.jit.trace(torchvision.models.resnet18(), + torch.rand(1, 3, 224, 224)) + + def forward(self, input): + return self.resnet(input - self.means) + + my_script_module = torch.jit.script(MyScriptModule()) + + +TorchScript Language +-------------------- + +TorchScript is a statically typed subset of Python, so many Python features apply +directly to TorchScript. See the full :ref:`language-reference` for details. + + +.. _builtin functions: + +Built-in Functions and Modules +------------------------------ + +TorchScript supports the use of most PyTorch functions and many Python built-ins. +See :ref:`builtin-functions` for a full reference of supported functions. + +PyTorch Functions and Modules +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +TorchScript supports a subset of the tensor and neural network +functions that PyTorch provides. Most methods on Tensor as well as functions in +the ``torch`` namespace, all functions in ``torch.nn.functional`` and +most modules from ``torch.nn`` are supported in TorchScript. + +See :ref:`jit_unsupported` for a list of unsupported PyTorch functions and modules. + + +Python Functions and Modules +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Many of Python's `built-in functions `_ are supported in TorchScript. +The :any:`math` module is also supported (see :ref:`math-module` for details), but no other Python modules +(built-in or third party) are supported. + + +Python Language Reference Comparison +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +For a full listing of supported Python features, see :ref:`python-language-reference`. + +Debugging +--------- + +.. _`disable TorchScript`: + +Disable JIT for Debugging +~~~~~~~~~~~~~~~~~~~~~~~~~ +.. envvar:: PYTORCH_JIT + +Setting the environment variable ``PYTORCH_JIT=0`` will disable all script +and tracing annotations. If there is hard-to-debug error in one of your +TorchScript models, you can use this flag to force everything to run using native +Python. Since TorchScript (scripting and tracing) is disabled with this flag, +you can use tools like ``pdb`` to debug the model code. For example:: + + @torch.jit.script + def scripted_fn(x : torch.Tensor): + for i in range(12): + x = x + x + return x + + def fn(x): + x = torch.neg(x) + import pdb; pdb.set_trace() + return scripted_fn(x) + + traced_fn = torch.jit.trace(fn, (torch.rand(4, 5),)) + traced_fn(torch.rand(3, 4)) + +Debugging this script with ``pdb`` works except for when we invoke the +:func:`@torch.jit.script ` function. We can globally disable +JIT, so that we can call the :func:`@torch.jit.script ` +function as a normal Python function and not compile it. If the above script +is called ``disable_jit_example.py``, we can invoke it like so:: + + $ PYTORCH_JIT=0 python disable_jit_example.py + +and we will be able to step into the :func:`@torch.jit.script +` function as a normal Python function. To disable the +TorchScript compiler for a specific function, see +:func:`@torch.jit.ignore `. + +.. _inspecting-code: + +Inspecting Code +~~~~~~~~~~~~~~~ + +TorchScript provides a code pretty-printer for all :class:`ScriptModule` instances. This +pretty-printer gives an interpretation of the script method's code as valid +Python syntax. For example: + +.. testcode:: + + @torch.jit.script + def foo(len): + # type: (int) -> torch.Tensor + rv = torch.zeros(3, 4) + for i in range(len): + if i < 10: + rv = rv - 1.0 + else: + rv = rv + 1.0 + return rv + + print(foo.code) + +.. testoutput:: + :hide: + + ... + +A :class:`ScriptModule` with a single ``forward`` method will have an attribute +``code``, which you can use to inspect the :class:`ScriptModule`'s code. +If the :class:`ScriptModule` has more than one method, you will need to access +``.code`` on the method itself and not the module. We can inspect the +code of a method named ``foo`` on a :class:`ScriptModule` by accessing ``.foo.code``. +The example above produces this output: :: + + def foo(len: int) -> Tensor: + rv = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None) + rv0 = rv + for i in range(len): + if torch.lt(i, 10): + rv1 = torch.sub(rv0, 1., 1) + else: + rv1 = torch.add(rv0, 1., 1) + rv0 = rv1 + return rv0 + +This is TorchScript's compilation of the code for the ``forward`` method. +You can use this to ensure TorchScript (tracing or scripting) has captured +your model code correctly. + + +.. _interpreting-graphs: + +Interpreting Graphs +~~~~~~~~~~~~~~~~~~~ +TorchScript also has a representation at a lower level than the code pretty-\ +printer, in the form of IR graphs. + +TorchScript uses a static single assignment (SSA) intermediate representation +(IR) to represent computation. The instructions in this format consist of +ATen (the C++ backend of PyTorch) operators and other primitive operators, +including control flow operators for loops and conditionals. As an example: + +.. testcode:: + + @torch.jit.script + def foo(len): + # type: (int) -> torch.Tensor + rv = torch.zeros(3, 4) + for i in range(len): + if i < 10: + rv = rv - 1.0 + else: + rv = rv + 1.0 + return rv + + print(foo.graph) + +.. testoutput:: + :hide: + + ... + +``graph`` follows the same rules described in the :ref:`inspecting-code` section +with regard to ``forward`` method lookup. + +The example script above produces the graph:: + + graph(%len.1 : int): + %24 : int = prim::Constant[value=1]() + %17 : bool = prim::Constant[value=1]() # test.py:10:5 + %12 : bool? = prim::Constant() + %10 : Device? = prim::Constant() + %6 : int? = prim::Constant() + %1 : int = prim::Constant[value=3]() # test.py:9:22 + %2 : int = prim::Constant[value=4]() # test.py:9:25 + %20 : int = prim::Constant[value=10]() # test.py:11:16 + %23 : float = prim::Constant[value=1]() # test.py:12:23 + %4 : int[] = prim::ListConstruct(%1, %2) + %rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10 + %rv : Tensor = prim::Loop(%len.1, %17, %rv.1) # test.py:10:5 + block0(%i.1 : int, %rv.14 : Tensor): + %21 : bool = aten::lt(%i.1, %20) # test.py:11:12 + %rv.13 : Tensor = prim::If(%21) # test.py:11:9 + block0(): + %rv.3 : Tensor = aten::sub(%rv.14, %23, %24) # test.py:12:18 + -> (%rv.3) + block1(): + %rv.6 : Tensor = aten::add(%rv.14, %23, %24) # test.py:14:18 + -> (%rv.6) + -> (%17, %rv.13) + return (%rv) + + +Take the instruction ``%rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10`` for +example. + +* ``%rv.1 : Tensor`` means we assign the output to a (unique) value named ``rv.1``, that value is of ``Tensor`` type and that we do not know its concrete shape. +* ``aten::zeros`` is the operator (equivalent to ``torch.zeros``) and the input list ``(%4, %6, %6, %10, %12)`` specifies which values in scope should be passed as inputs. The schema for built-in functions like ``aten::zeros`` can be found at `Builtin Functions`_. +* ``# test.py:9:10`` is the location in the original source file that generated this instruction. In this case, it is a file named `test.py`, on line 9, and at character 10. + +Notice that operators can also have associated ``blocks``, namely the +``prim::Loop`` and ``prim::If`` operators. In the graph print-out, these +operators are formatted to reflect their equivalent source code forms +to facilitate easy debugging. + +Graphs can be inspected as shown to confirm that the computation described +by a :class:`ScriptModule` is correct, in both automated and manual fashion, as +described below. + +Tracer +~~~~~~ + + +Tracing Edge Cases +^^^^^^^^^^^^^^^^^^ +There are some edge cases that exist where the trace of a given Python +function/module will not be representative of the underlying code. These +cases can include: + +* Tracing of control flow that is dependent on inputs (e.g. tensor shapes) +* Tracing of in-place operations of tensor views (e.g. indexing on the left-hand side of an assignment) + +Note that these cases may in fact be traceable in the future. + + +Automatic Trace Checking +^^^^^^^^^^^^^^^^^^^^^^^^ +One way to automatically catch many errors in traces is by using ``check_inputs`` +on the ``torch.jit.trace()`` API. ``check_inputs`` takes a list of tuples +of inputs that will be used to re-trace the computation and verify the +results. For example:: + + def loop_in_traced_fn(x): + result = x[0] + for i in range(x.size(0)): + result = result * x[i] + return result + + inputs = (torch.rand(3, 4, 5),) + check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)] + + traced = torch.jit.trace(loop_in_traced_fn, inputs, check_inputs=check_inputs) + +Gives us the following diagnostic information:: + + ERROR: Graphs differed across invocations! + Graph diff: + + graph(%x : Tensor) { + %1 : int = prim::Constant[value=0]() + %2 : int = prim::Constant[value=0]() + %result.1 : Tensor = aten::select(%x, %1, %2) + %4 : int = prim::Constant[value=0]() + %5 : int = prim::Constant[value=0]() + %6 : Tensor = aten::select(%x, %4, %5) + %result.2 : Tensor = aten::mul(%result.1, %6) + %8 : int = prim::Constant[value=0]() + %9 : int = prim::Constant[value=1]() + %10 : Tensor = aten::select(%x, %8, %9) + - %result : Tensor = aten::mul(%result.2, %10) + + %result.3 : Tensor = aten::mul(%result.2, %10) + ? ++ + %12 : int = prim::Constant[value=0]() + %13 : int = prim::Constant[value=2]() + %14 : Tensor = aten::select(%x, %12, %13) + + %result : Tensor = aten::mul(%result.3, %14) + + %16 : int = prim::Constant[value=0]() + + %17 : int = prim::Constant[value=3]() + + %18 : Tensor = aten::select(%x, %16, %17) + - %15 : Tensor = aten::mul(%result, %14) + ? ^ ^ + + %19 : Tensor = aten::mul(%result, %18) + ? ^ ^ + - return (%15); + ? ^ + + return (%19); + ? ^ + } + + +This message indicates to us that the computation differed between when +we first traced it and when we traced it with the ``check_inputs``. Indeed, +the loop within the body of ``loop_in_traced_fn`` depends on the shape +of the input ``x``, and thus when we try another ``x`` with a different +shape, the trace differs. + +In this case, data-dependent control flow like this can be captured using +:func:`torch.jit.script` instead: + +.. testcode:: + + def fn(x): + result = x[0] + for i in range(x.size(0)): + result = result * x[i] + return result + + inputs = (torch.rand(3, 4, 5),) + check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)] + + scripted_fn = torch.jit.script(fn) + print(scripted_fn.graph) + #print(str(scripted_fn.graph).strip()) + + for input_tuple in [inputs] + check_inputs: + torch.testing.assert_close(fn(*input_tuple), scripted_fn(*input_tuple)) + +.. testoutput:: + :hide: + + ... + + +Which produces:: + + graph(%x : Tensor) { + %5 : bool = prim::Constant[value=1]() + %1 : int = prim::Constant[value=0]() + %result.1 : Tensor = aten::select(%x, %1, %1) + %4 : int = aten::size(%x, %1) + %result : Tensor = prim::Loop(%4, %5, %result.1) + block0(%i : int, %7 : Tensor) { + %10 : Tensor = aten::select(%x, %1, %i) + %result.2 : Tensor = aten::mul(%7, %10) + -> (%5, %result.2) + } + return (%result); + } + +Tracer Warnings +^^^^^^^^^^^^^^^ +The tracer produces warnings for several problematic patterns in traced +computation. As an example, take a trace of a function that contains an +in-place assignment on a slice (a view) of a Tensor: + +.. testcode:: + + def fill_row_zero(x): + x[0] = torch.rand(*x.shape[1:2]) + return x + + traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),)) + print(traced.graph) + +.. testoutput:: + :hide: + + ... + +Produces several warnings and a graph which simply returns the input:: + + fill_row_zero.py:4: TracerWarning: There are 2 live references to the data region being modified when tracing in-place operator copy_ (possibly due to an assignment). This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe. + x[0] = torch.rand(*x.shape[1:2]) + fill_row_zero.py:6: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error: + Not within tolerance rtol=1e-05 atol=1e-05 at input[0, 1] (0.09115803241729736 vs. 0.6782537698745728) and 3 other locations (33.00%) + traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),)) + graph(%0 : Float(3, 4)) { + return (%0); + } + +We can fix this by modifying the code to not use the in-place update, but +rather build up the result tensor out-of-place with ``torch.cat``: + +.. testcode:: + + def fill_row_zero(x): + x = torch.cat((torch.rand(1, *x.shape[1:2]), x[1:2]), dim=0) + return x + + traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),)) + print(traced.graph) + +.. testoutput:: + :hide: + + ... + +Frequently Asked Questions +-------------------------- + +Q: I would like to train a model on GPU and do inference on CPU. What are the +best practices? + + First convert your model from GPU to CPU and then save it, like so: :: + + cpu_model = gpu_model.cpu() + sample_input_cpu = sample_input_gpu.cpu() + traced_cpu = torch.jit.trace(cpu_model, sample_input_cpu) + torch.jit.save(traced_cpu, "cpu.pt") + + traced_gpu = torch.jit.trace(gpu_model, sample_input_gpu) + torch.jit.save(traced_gpu, "gpu.pt") + + # ... later, when using the model: + + if use_gpu: + model = torch.jit.load("gpu.pt") + else: + model = torch.jit.load("cpu.pt") + + model(input) + + This is recommended because the tracer may witness tensor creation on a + specific device, so casting an already-loaded model may have unexpected + effects. Casting the model *before* saving it ensures that the tracer has + the correct device information. + + +Q: How do I store attributes on a :class:`ScriptModule`? + + Say we have a model like: + + .. testcode:: + + import torch + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.x = 2 + + def forward(self): + return self.x + + m = torch.jit.script(Model()) + + + + If ``Model`` is instantiated it will result in a compilation error + since the compiler doesn't know about ``x``. There are 4 ways to inform the + compiler of attributes on :class:`ScriptModule`: + + 1. ``nn.Parameter`` - Values wrapped in ``nn.Parameter`` will work as they + do on ``nn.Module``\s + + 2. ``register_buffer`` - Values wrapped in ``register_buffer`` will work as + they do on ``nn.Module``\s. This is equivalent to an attribute (see 4) of type + ``Tensor``. + + 3. Constants - Annotating a class member as ``Final`` (or adding it to a list called + ``__constants__`` at the class definition level) will mark the contained names + as constants. Constants are saved directly in the code of the model. See + `builtin-constants` for details. + + 4. Attributes - Values that are a `supported type` can be added as mutable + attributes. Most types can be inferred but some may need to be specified, see + `module attributes` for details. + +Q: I would like to trace module's method but I keep getting this error: + +``RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient`` + + This error usually means that the method you are tracing uses a module's parameters and + you are passing the module's method instead of the module instance (e.g. ``my_module_instance.forward`` vs ``my_module_instance``). + + - Invoking ``trace`` with a module's method captures module parameters (which may require gradients) as **constants**. + - On the other hand, invoking ``trace`` with module's instance (e.g. ``my_module``) creates a new module and correctly copies parameters into the new module, so they can accumulate gradients if required. + + To trace a specific method on a module, see :func:`torch.jit.trace_module ` + +Known Issues +--------------- + +If you're using ``Sequential`` with TorchScript, the inputs of some +of the ``Sequential`` submodules may be falsely inferred to be +``Tensor``, even if they're annotated otherwise. The canonical +solution is to subclass ``nn.Sequential`` and redeclare ``forward`` +with the input typed correctly. + +Appendix +-------- + +Migrating to PyTorch 1.2 Recursive Scripting API +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +This section details the changes to TorchScript in PyTorch 1.2. If you are new to TorchScript you can +skip this section. There are two main changes to the TorchScript API with PyTorch 1.2. + +1. :func:`torch.jit.script ` will now attempt to recursively compile functions, +methods, and classes that it encounters. Once you call ``torch.jit.script``, +compilation is "opt-out", rather than "opt-in". + +2. ``torch.jit.script(nn_module_instance)`` is now the preferred way to create +:class:`ScriptModule`\s, instead of inheriting from ``torch.jit.ScriptModule``. +These changes combine to provide a simpler, easier-to-use API for converting +your ``nn.Module``\s into :class:`ScriptModule`\s, ready to be optimized and executed in a +non-Python environment. + +The new usage looks like this: + +.. testcode:: + + import torch + import torch.nn as nn + import torch.nn.functional as F + + class Model(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1, 20, 5) + self.conv2 = nn.Conv2d(20, 20, 5) + + def forward(self, x): + x = F.relu(self.conv1(x)) + return F.relu(self.conv2(x)) + + my_model = Model() + my_scripted_model = torch.jit.script(my_model) + + +* The module's ``forward`` is compiled by default. Methods called from ``forward`` are lazily compiled in the order they are used in ``forward``. +* To compile a method other than ``forward`` that is not called from ``forward``, add ``@torch.jit.export``. +* To stop the compiler from compiling a method, add :func:`@torch.jit.ignore ` or :func:`@torch.jit.unused `. ``@ignore`` leaves the +* method as a call to python, and ``@unused`` replaces it with an exception. ``@ignored`` cannot be exported; ``@unused`` can. +* Most attribute types can be inferred, so ``torch.jit.Attribute`` is not necessary. For empty container types, annotate their types using `PEP 526-style `_ class annotations. +* Constants can be marked with a ``Final`` class annotation instead of adding the name of the member to ``__constants__``. +* Python 3 type hints can be used in place of ``torch.jit.annotate`` + +As a result of these changes, the following items are considered deprecated and should not appear in new code: + * The ``@torch.jit.script_method`` decorator + * Classes that inherit from ``torch.jit.ScriptModule`` + * The ``torch.jit.Attribute`` wrapper class + * The ``__constants__`` array + * The ``torch.jit.annotate`` function + +Modules +^^^^^^^ +.. warning:: + + The :func:`@torch.jit.ignore ` annotation's behavior changes in + PyTorch 1.2. Before PyTorch 1.2 the @ignore decorator was used to make a function + or method callable from code that is exported. To get this functionality back, + use ``@torch.jit.unused()``. ``@torch.jit.ignore`` is now equivalent + to ``@torch.jit.ignore(drop=False)``. See :func:`@torch.jit.ignore ` + and :func:`@torch.jit.unused` for details. + +When passed to the :func:`torch.jit.script ` function, a ``torch.nn.Module``\'s data is +copied to a :class:`ScriptModule` and the TorchScript compiler compiles the module. +The module's ``forward`` is compiled by default. Methods called from ``forward`` are +lazily compiled in the order they are used in ``forward``, as well as any +``@torch.jit.export`` methods. + +.. autofunction:: export + +Functions +^^^^^^^^^ +Functions don't change much, they can be decorated with :func:`@torch.jit.ignore ` or :func:`torch.jit.unused ` if needed. + +.. testcode:: + + # Same behavior as pre-PyTorch 1.2 + @torch.jit.script + def some_fn(): + return 2 + + # Marks a function as ignored, if nothing + # ever calls it then this has no effect + @torch.jit.ignore + def some_fn2(): + return 2 + + # As with ignore, if nothing calls it then it has no effect. + # If it is called in script it is replaced with an exception. + @torch.jit.unused + def some_fn3(): + import pdb; pdb.set_trace() + return 4 + + # Doesn't do anything, this function is already + # the main entry point + @torch.jit.export + def some_fn4(): + return 2 + +TorchScript Classes +^^^^^^^^^^^^^^^^^^^ + +.. warning:: + + TorchScript class support is experimental. Currently it is best suited + for simple record-like types (think a ``NamedTuple`` with methods + attached). + +Everything in a user defined `TorchScript Class `_ is +exported by default, functions can be decorated with :func:`@torch.jit.ignore +` if needed. + +Attributes +^^^^^^^^^^ +The TorchScript compiler needs to know the types of `module attributes`. Most types +can be inferred from the value of the member. Empty lists and dicts cannot have their +types inferred and must have their types annotated with `PEP 526-style `_ class annotations. +If a type cannot be inferred and is not explicitly annotated, it will not be added as an attribute +to the resulting :class:`ScriptModule` + + +Old API: + +.. testcode:: + + from typing import Dict + import torch + + class MyModule(torch.jit.ScriptModule): + def __init__(self): + super().__init__() + self.my_dict = torch.jit.Attribute({}, Dict[str, int]) + self.my_int = torch.jit.Attribute(20, int) + + m = MyModule() + +New API: + +.. testcode:: + + from typing import Dict + + class MyModule(torch.nn.Module): + my_dict: Dict[str, int] + + def __init__(self): + super().__init__() + # This type cannot be inferred and must be specified + self.my_dict = {} + + # The attribute type here is inferred to be `int` + self.my_int = 20 + + def forward(self): + pass + + m = torch.jit.script(MyModule()) + + +Constants +^^^^^^^^^ +The ``Final`` type constructor can be used to mark members as `constant`. If members are not marked constant, they will be copied to the resulting :class:`ScriptModule` as an attribute. Using ``Final`` opens opportunities for optimization if the value is known to be fixed and gives additional type safety. + +Old API: + +.. testcode:: + + class MyModule(torch.jit.ScriptModule): + __constants__ = ['my_constant'] + + def __init__(self): + super().__init__() + self.my_constant = 2 + + def forward(self): + pass + m = MyModule() + +New API: + +:: + + from typing import Final + + class MyModule(torch.nn.Module): + + my_constant: Final[int] + + def __init__(self): + super().__init__() + self.my_constant = 2 + + def forward(self): + pass + + m = torch.jit.script(MyModule()) + +.. _Python 3 type hints: + +Variables +^^^^^^^^^ +Containers are assumed to have type ``Tensor`` and be non-optional (see +`Default Types` for more information). Previously, ``torch.jit.annotate`` was used to +tell the TorchScript compiler what the type should be. Python 3 style type hints are +now supported. + +.. testcode:: + + import torch + from typing import Dict, Optional + + @torch.jit.script + def make_dict(flag: bool): + x: Dict[str, int] = {} + x['hi'] = 2 + b: Optional[int] = None + if flag: + b = 2 + return x, b + +Fusion Backends +~~~~~~~~~~~~~~~ +There are a couple of fusion backends available to optimize TorchScript execution. The default fuser on CPUs is NNC, which can perform fusions for both CPUs and GPUs. The default fuser on GPUs is NVFuser, which supports a wider range of operators and has demonstrated generated kernels with improved throughput. See the `NVFuser documentation `_ for more details on usage and debugging. + + +References +~~~~~~~~~~ +.. toctree:: + :maxdepth: 1 + + jit_python_reference + jit_unsupported + +.. This package is missing doc. Adding it here for coverage +.. This does not add anything to the rendered page. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .. py:module:: torch.jit.mobile .. py:module:: torch.jit.annotations .. py:module:: torch.jit.frontend diff --git a/docs/source/jit_builtin_functions.rst b/docs/source/jit_builtin_functions.rst index 6fd514f6e6fca..4e07ab9baceb4 100644 --- a/docs/source/jit_builtin_functions.rst +++ b/docs/source/jit_builtin_functions.rst @@ -3,6 +3,14 @@ TorchScript Builtins ==================== +<<<<<<< HEAD .. warning:: TorchScript is deprecated, please use `torch.export `__ instead. +======= +This is a full reference of functions and Tensor methods accessible in TorchScript + +.. contents:: :local: + +.. automodule:: torch.jit.supported_ops +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/docs/source/jit_language_reference.md b/docs/source/jit_language_reference.md index f2b31768e2d58..fbbc92f26855b 100644 --- a/docs/source/jit_language_reference.md +++ b/docs/source/jit_language_reference.md @@ -30,7 +30,930 @@ # TorchScript Language Reference +<<<<<<< HEAD :::{warning} TorchScript is deprecated, please use [torch.export](https://docs.pytorch.org/docs/stable/export.html) instead. -::: \ No newline at end of file +::: +======= +TorchScript is a statically typed subset of Python that can either be written directly (using +the {func}`@torch.jit.script ` decorator) or generated automatically from Python code via +tracing. When using tracing, code is automatically converted into this subset of +Python by recording only the actual operators on tensors and simply executing and +discarding the other surrounding Python code. + +When writing TorchScript directly using `@torch.jit.script` decorator, the programmer must +only use the subset of Python supported in TorchScript. This section documents +what is supported in TorchScript as if it were a language reference for a stand +alone language. Any features of Python not mentioned in this reference are not +part of TorchScript. See `Builtin Functions` for a complete reference of available +PyTorch tensor methods, modules, and functions. + +As a subset of Python, any valid TorchScript function is also a valid Python +function. This makes it possible to `disable TorchScript` and debug the +function using standard Python tools like `pdb`. The reverse is not true: there +are many valid Python programs that are not valid TorchScript programs. +Instead, TorchScript focuses specifically on the features of Python that are +needed to represent neural network models in PyTorch. + +(types)= + +(supported-type)= + +## Types + +The largest difference between TorchScript and the full Python language is that +TorchScript only supports a small set of types that are needed to express neural +net models. In particular, TorchScript supports: + +```{eval-rst} +.. csv-table:: + :header: "Type", "Description" + + "``Tensor``", "A PyTorch tensor of any dtype, dimension, or backend" + "``Tuple[T0, T1, ..., TN]``", "A tuple containing subtypes ``T0``, ``T1``, etc. (e.g. ``Tuple[Tensor, Tensor]``)" + "``bool``", "A boolean value" + "``int``", "A scalar integer" + "``float``", "A scalar floating point number" + "``str``", "A string" + "``List[T]``", "A list of which all members are type ``T``" + "``Optional[T]``", "A value which is either None or type ``T``" + "``Dict[K, V]``", "A dict with key type ``K`` and value type ``V``. Only ``str``, ``int``, and ``float`` are allowed as key types." + "``T``", "A {ref}`TorchScript Class`" + "``E``", "A {ref}`TorchScript Enum`" + "``NamedTuple[T0, T1, ...]``", "A :func:`collections.namedtuple ` tuple type" + "``Union[T0, T1, ...]``", "One of the subtypes ``T0``, ``T1``, etc." +``` + +Unlike Python, each variable in TorchScript function must have a single static type. +This makes it easier to optimize TorchScript functions. + +Example (a type mismatch) + +```{eval-rst} +.. testcode:: + + import torch + + @torch.jit.script + def an_error(x): + if x: + r = torch.rand(1) + else: + r = 4 + return r + +``` + +```{eval-rst} +.. testoutput:: + + Traceback (most recent call last): + ... + RuntimeError: ... + + Type mismatch: r is set to type Tensor in the true branch and type int in the false branch: + @torch.jit.script + def an_error(x): + if x: + ~~~~~ + r = torch.rand(1) + ~~~~~~~~~~~~~~~~~ + else: + ~~~~~ + r = 4 + ~~~~~ <--- HERE + return r + and was used here: + else: + r = 4 + return r + ~ <--- HERE... +``` + +### Unsupported Typing Constructs + +TorchScript does not support all features and types of the {mod}`typing` module. Some of these +are more fundamental things that are unlikely to be added in the future while others +may be added if there is enough user demand to make it a priority. + +These types and features from the {mod}`typing` module are unavailable in TorchScript. + +```{eval-rst} +.. csv-table:: + :header: "Item", "Description" + + ":any:`typing.Any`", ":any:`typing.Any` is currently in development but not yet released" + ":any:`typing.NoReturn`", "Not implemented" + ":any:`typing.Sequence`", "Not implemented" + ":any:`typing.Callable`", "Not implemented" + ":any:`typing.Literal`", "Not implemented" + ":any:`typing.ClassVar`", "Not implemented" + ":any:`typing.Final`", "This is supported for :any:`module attributes ` class attribute annotations but not for functions" + ":any:`typing.AnyStr`", "TorchScript does not support :any:`bytes` so this type is not used" + ":any:`typing.overload`", ":any:`typing.overload` is currently in development but not yet released" + "Type aliases", "Not implemented" + "Nominal vs structural subtyping", "Nominal typing is in development, but structural typing is not" + "NewType", "Unlikely to be implemented" + "Generics", "Unlikely to be implemented" +``` + +Any other functionality from the {any}`typing` module not explicitly listed in this documentation is unsupported. + +### Default Types + +By default, all parameters to a TorchScript function are assumed to be Tensor. +To specify that an argument to a TorchScript function is another type, it is possible to use +MyPy-style type annotations using the types listed above. + +```{eval-rst} +.. testcode:: + + import torch + + @torch.jit.script + def foo(x, tup): + # type: (int, Tuple[Tensor, Tensor]) -> Tensor + t0, t1 = tup + return t0 + t1 + x + + print(foo(3, (torch.rand(3), torch.rand(3)))) +``` + +```{eval-rst} +.. testoutput:: + :hide: + + ... +``` + +:::{note} +It is also possible to annotate types with Python 3 type hints from the +`typing` module. + +```{eval-rst} +.. testcode:: + + import torch + from typing import Tuple + + @torch.jit.script + def foo(x: int, tup: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + t0, t1 = tup + return t0 + t1 + x + + print(foo(3, (torch.rand(3), torch.rand(3)))) +``` + +```{eval-rst} +.. testoutput:: + :hide: + + ... +``` +::: + +An empty list is assumed to be `List[Tensor]` and empty dicts +`Dict[str, Tensor]`. To instantiate an empty list or dict of other types, +use `Python 3 type hints`. + +Example (type annotations for Python 3): + +```{eval-rst} +.. testcode:: + + import torch + import torch.nn as nn + from typing import Dict, List, Tuple + + class EmptyDataStructures(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]]: + # This annotates the list to be a `List[Tuple[int, float]]` + my_list: List[Tuple[int, float]] = [] + for i in range(10): + my_list.append((i, x.item())) + + my_dict: Dict[str, int] = {} + return my_list, my_dict + + x = torch.jit.script(EmptyDataStructures()) + + + +``` + +### Optional Type Refinement + +TorchScript will refine the type of a variable of type `Optional[T]` when +a comparison to `None` is made inside the conditional of an if-statement or checked in an `assert`. +The compiler can reason about multiple `None` checks that are combined with +`and`, `or`, and `not`. Refinement will also occur for else blocks of if-statements +that are not explicitly written. + +The `None` check must be within the if-statement's condition; assigning +a `None` check to a variable and using it in the if-statement's condition will +not refine the types of variables in the check. +Only local variables will be refined, an attribute like `self.x` will not and must assigned to +a local variable to be refined. + +Example (refining types on parameters and locals): + +```{eval-rst} +.. testcode:: + + import torch + import torch.nn as nn + from typing import Optional + + class M(nn.Module): + z: Optional[int] + + def __init__(self, z): + super().__init__() + # If `z` is None, its type cannot be inferred, so it must + # be specified (above) + self.z = z + + def forward(self, x, y, z): + # type: (Optional[int], Optional[int], Optional[int]) -> int + if x is None: + x = 1 + x = x + 1 + + # Refinement for an attribute by assigning it to a local + z = self.z + if y is not None and z is not None: + x = y + z + + # Refinement via an `assert` + assert z is not None + x += z + return x + + module = torch.jit.script(M(2)) + module = torch.jit.script(M(None)) + +``` + +(TorchScript Class)= + +(TorchScript Classes)= + +(torchscript-classes)= + +### TorchScript Classes + +:::{warning} +TorchScript class support is experimental. Currently it is best suited +for simple record-like types (think a `NamedTuple` with methods +attached). +::: + +Python classes can be used in TorchScript if they are annotated with {func}`@torch.jit.script `, +similar to how you would declare a TorchScript function: + +```{eval-rst} +.. testcode:: + :skipif: True # TODO: fix the source file resolving so this can be tested + + @torch.jit.script + class Foo: + def __init__(self, x, y): + self.x = x + + def aug_add_x(self, inc): + self.x += inc + +``` + +This subset is restricted: + +- All functions must be valid TorchScript functions (including `__init__()`). + +- Classes must be new-style classes, as we use `__new__()` to construct them with pybind11. + +- TorchScript classes are statically typed. Members can only be declared by assigning to + self in the `__init__()` method. + + > For example, assigning to `self` outside of the `__init__()` method: + > + > ``` + > @torch.jit.script + > class Foo: + > def assign_x(self): + > self.x = torch.rand(2, 3) + > ``` + > + > Will result in: + > + > ``` + > RuntimeError: + > Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?: + > def assign_x(self): + > self.x = torch.rand(2, 3) + > ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE + > ``` + +- No expressions except method definitions are allowed in the body of the class. + +- No support for inheritance or any other polymorphism strategy, except for inheriting + from `object` to specify a new-style class. + +After a class is defined, it can be used in both TorchScript and Python interchangeably +like any other TorchScript type: + +``` +# Declare a TorchScript class +@torch.jit.script +class Pair: + def __init__(self, first, second): + self.first = first + self.second = second + +@torch.jit.script +def sum_pair(p): + # type: (Pair) -> Tensor + return p.first + p.second + +p = Pair(torch.rand(2, 3), torch.rand(2, 3)) +print(sum_pair(p)) +``` + +(TorchScript Enum)= + +(TorchScript Enums)= + +(torchscript-enums)= + +### TorchScript Enums + +Python enums can be used in TorchScript without any extra annotation or code: + +``` +from enum import Enum + + +class Color(Enum): + RED = 1 + GREEN = 2 + +@torch.jit.script +def enum_fn(x: Color, y: Color) -> bool: + if x == Color.RED: + return True + + return x == y +``` + +After an enum is defined, it can be used in both TorchScript and Python interchangeably +like any other TorchScript type. The type of the values of an enum must be `int`, +`float`, or `str`. All values must be of the same type; heterogeneous types for enum +values are not supported. + +### Named Tuples + +Types produced by {func}`collections.namedtuple ` can be used in TorchScript. + +```{eval-rst} +.. testcode:: + + import torch + import collections + + Point = collections.namedtuple('Point', ['x', 'y']) + + @torch.jit.script + def total(point): + # type: (Point) -> Tensor + return point.x + point.y + + p = Point(x=torch.rand(3), y=torch.rand(3)) + print(total(p)) +``` + +```{eval-rst} +.. testoutput:: + :hide: + + ... + +``` + +(jit_iterables)= + +### Iterables + +Some functions (for example, {any}`zip` and {any}`enumerate`) can only operate on iterable types. +Iterable types in TorchScript include `Tensor`s, lists, tuples, dictionaries, strings, +{any}`torch.nn.ModuleList` and {any}`torch.nn.ModuleDict`. + +## Expressions + +The following Python Expressions are supported. + +### Literals + +``` +True +False +None +'string literals' +"string literals" +3 # interpreted as int +3.4 # interpreted as a float +``` + +#### List Construction + +An empty list is assumed have type `List[Tensor]`. +The types of other list literals are derived from the type of the members. +See [Default Types] for more details. + +``` +[3, 4] +[] +[torch.rand(3), torch.rand(4)] +``` + +#### Tuple Construction + +``` +(3, 4) +(3,) +``` + +#### Dict Construction + +An empty dict is assumed have type `Dict[str, Tensor]`. +The types of other dict literals are derived from the type of the members. +See [Default Types] for more details. + +``` +{'hello': 3} +{} +{'a': torch.rand(3), 'b': torch.rand(4)} +``` + +### Variables + +See [Variable Resolution] for how variables are resolved. + +``` +my_variable_name +``` + +### Arithmetic Operators + +``` +a + b +a - b +a * b +a / b +a ^ b +a @ b +``` + +### Comparison Operators + +``` +a == b +a != b +a < b +a > b +a <= b +a >= b +``` + +### Logical Operators + +``` +a and b +a or b +not b +``` + +### Subscripts and Slicing + +``` +t[0] +t[-1] +t[0:2] +t[1:] +t[:1] +t[:] +t[0, 1] +t[0, 1:2] +t[0, :1] +t[-1, 1:, 0] +t[1:, -1, 0] +t[i:j, i] +``` + +### Function Calls + +Calls to `builtin functions` + +``` +torch.rand(3, dtype=torch.int) +``` + +Calls to other script functions: + +```{eval-rst} +.. testcode:: + + import torch + + @torch.jit.script + def foo(x): + return x + 1 + + @torch.jit.script + def bar(x): + return foo(x) +``` + +### Method Calls + +Calls to methods of builtin types like tensor: `x.mm(y)` + +On modules, methods must be compiled before they can be called. The TorchScript +compiler recursively compiles methods it sees when compiling other methods. By default, +compilation starts on the `forward` method. Any methods called by `forward` will +be compiled, and any methods called by those methods, and so on. To start compilation at +a method other than `forward`, use the {func}`@torch.jit.export ` decorator +(`forward` implicitly is marked `@torch.jit.export`). + +Calling a submodule directly (e.g. `self.resnet(input)`) is equivalent to +calling its `forward` method (e.g. `self.resnet.forward(input)`). + +```{eval-rst} +.. testcode:: + :skipif: torchvision is None + + import torch + import torch.nn as nn + import torchvision + + class MyModule(nn.Module): + def __init__(self): + super().__init__() + means = torch.tensor([103.939, 116.779, 123.68]) + self.means = torch.nn.Parameter(means.resize_(1, 3, 1, 1)) + resnet = torchvision.models.resnet18() + self.resnet = torch.jit.trace(resnet, torch.rand(1, 3, 224, 224)) + + def helper(self, input): + return self.resnet(input - self.means) + + def forward(self, input): + return self.helper(input) + + # Since nothing in the model calls `top_level_method`, the compiler + # must be explicitly told to compile this method + @torch.jit.export + def top_level_method(self, input): + return self.other_helper(input) + + def other_helper(self, input): + return input + 10 + + # `my_script_module` will have the compiled methods `forward`, `helper`, + # `top_level_method`, and `other_helper` + my_script_module = torch.jit.script(MyModule()) + +``` + +### Ternary Expressions + +``` +x if x > y else y +``` + +### Casts + +``` +float(ten) +int(3.5) +bool(ten) +str(2)`` +``` + +### Accessing Module Parameters + +``` +self.my_parameter +self.my_submodule.my_parameter +``` + +## Statements + +TorchScript supports the following types of statements: + +### Simple Assignments + +``` +a = b +a += b # short-hand for a = a + b, does not operate in-place on a +a -= b +``` + +### Pattern Matching Assignments + +``` +a, b = tuple_or_list +a, b, *c = a_tuple +``` + +Multiple Assignments + +``` +a = b, c = tup +``` + +### Print Statements + +``` +print("the result of an add:", a + b) +``` + +### If Statements + +``` +if a < 4: + r = -a +elif a < 3: + r = a + a +else: + r = 3 * a +``` + +In addition to bools, floats, ints, and Tensors can be used in a conditional +and will be implicitly casted to a boolean. + +### While Loops + +``` +a = 0 +while a < 4: + print(a) + a += 1 +``` + +### For loops with range + +``` +x = 0 +for i in range(10): + x *= i +``` + +### For loops over tuples + +These unroll the loop, generating a body for +each member of the tuple. The body must type-check correctly for each member. + +``` +tup = (3, torch.rand(4)) +for x in tup: + print(x) +``` + +### For loops over constant nn.ModuleList + +To use a `nn.ModuleList` inside a compiled method, it must be marked +constant by adding the name of the attribute to the `__constants__` +list for the type. For loops over a `nn.ModuleList` will unroll the body of the +loop at compile time, with each member of the constant module list. + +```{eval-rst} +.. testcode:: + + class SubModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(2)) + + def forward(self, input): + return self.weight + input + + class MyModule(torch.nn.Module): + __constants__ = ['mods'] + + def __init__(self): + super().__init__() + self.mods = torch.nn.ModuleList([SubModule() for i in range(10)]) + + def forward(self, v): + for module in self.mods: + v = module(v) + return v + + + m = torch.jit.script(MyModule()) + + +``` + +### Break and Continue + +``` +for i in range(5): + if i == 1: + continue + if i == 3: + break + print(i) +``` + +### Return + +``` +return a, b +``` + +## Variable Resolution + +TorchScript supports a subset of Python's variable resolution (i.e. scoping) +rules. Local variables behave the same as in Python, except for the restriction +that a variable must have the same type along all paths through a function. +If a variable has a different type on different branches of an if statement, it +is an error to use it after the end of the if statement. + +Similarly, a variable is not allowed to be used if it is only *defined* along some +paths through the function. + +Example: + +```{eval-rst} +.. testcode:: + + @torch.jit.script + def foo(x): + if x < 0: + y = 4 + print(y) +``` + +```{eval-rst} +.. testoutput:: + + Traceback (most recent call last): + ... + RuntimeError: ... + + y is not defined in the false branch... + @torch.jit.script... + def foo(x): + if x < 0: + ~~~~~~~~~ + y = 4 + ~~~~~ <--- HERE + print(y) + and was used here: + if x < 0: + y = 4 + print(y) + ~ <--- HERE... +``` + +Non-local variables are resolved to Python values at compile time when the +function is defined. These values are then converted into TorchScript values using +the rules described in [Use of Python Values]. + +## Use of Python Values + +To make writing TorchScript more convenient, we allow script code to refer +to Python values in the surrounding scope. For instance, any time there is a +reference to `torch`, the TorchScript compiler is actually resolving it to the +`torch` Python module when the function is declared. These Python values are +not a first class part of TorchScript. Instead they are de-sugared at compile-time +into the primitive types that TorchScript supports. This depends +on the dynamic type of the Python valued referenced when compilation occurs. +This section describes the rules that are used when accessing Python values in TorchScript. + +### Functions + +TorchScript can call Python functions. This functionality is very useful when +incrementally converting a model to TorchScript. The model can be moved function-by-function +to TorchScript, leaving calls to Python functions in place. This way you can incrementally +check the correctness of the model as you go. + +```{eval-rst} +.. autofunction:: torch.jit.is_scripting +``` + +```{eval-rst} +.. autofunction:: torch.jit.is_tracing + +``` + +### Attribute Lookup On Python Modules + +TorchScript can lookup attributes on modules. `Builtin functions` like `torch.add` +are accessed this way. This allows TorchScript to call functions defined in +other modules. + +(constant)= + +### Python-defined Constants + +TorchScript also provides a way to use constants that are defined in Python. +These can be used to hard-code hyper-parameters into the function, or to +define universal constants. There are two ways of specifying that a Python +value should be treated as a constant. + +1. Values looked up as attributes of a module are assumed to be constant: + +```{eval-rst} +.. testcode:: + + import math + import torch + + @torch.jit.script + def fn(): + return math.pi +``` + +2. Attributes of a ScriptModule can be marked constant by annotating them with `Final[T]` + +``` +import torch +import torch.nn as nn + +class Foo(nn.Module): + # `Final` from the `typing_extensions` module can also be used + a : torch.jit.Final[int] + + def __init__(self): + super().__init__() + self.a = 1 + 4 + + def forward(self, input): + return self.a + input + +f = torch.jit.script(Foo()) +``` + +Supported constant Python types are + +- `int` +- `float` +- `bool` +- `torch.device` +- `torch.layout` +- `torch.dtype` +- tuples containing supported types +- `torch.nn.ModuleList` which can be used in a TorchScript for loop + +(module-attributes)= +(Module Attributes)= + +### Module Attributes + +The `torch.nn.Parameter` wrapper and `register_buffer` can be used to assign +tensors to a module. Other values assigned to a module that is compiled +will be added to the compiled module if their types can be inferred. All [types] +available in TorchScript can be used as module attributes. Tensor attributes are +semantically the same as buffers. The type of empty lists and dictionaries and `None` +values cannot be inferred and must be specified via +[PEP 526-style](https://www.python.org/dev/peps/pep-0526/#class-and-instance-variable-annotations) class annotations. +If a type cannot be inferred and is not explicitly annotated, it will not be added as an attribute +to the resulting {class}`ScriptModule`. + +Example: + +```{eval-rst} +.. testcode:: + + from typing import List, Dict + + class Foo(nn.Module): + # `words` is initialized as an empty list, so its type must be specified + words: List[str] + + # The type could potentially be inferred if `a_dict` (below) was not + # empty, but this annotation ensures `some_dict` will be made into the + # proper type + some_dict: Dict[str, int] + + def __init__(self, a_dict): + super().__init__() + self.words = [] + self.some_dict = a_dict + + # `int`s can be inferred + self.my_int = 10 + + def forward(self, input): + # type: (str) -> int + self.words.append(input) + return self.some_dict[input] + self.my_int + + f = torch.jit.script(Foo({'hi': 2})) +``` +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/docs/source/jit_language_reference_v2.md b/docs/source/jit_language_reference_v2.md index 40da0740963ba..cf3bf88bc9f8b 100644 --- a/docs/source/jit_language_reference_v2.md +++ b/docs/source/jit_language_reference_v2.md @@ -25,7 +25,1837 @@ # TorchScript Language Reference +<<<<<<< HEAD :::{warning} TorchScript is deprecated, please use [torch.export](https://docs.pytorch.org/docs/stable/export.html) instead. -::: \ No newline at end of file +::: +======= +This reference manual describes the syntax and core semantics of the TorchScript language. +TorchScript is a statically typed subset of the Python language. This document explains the supported features of +Python in TorchScript and also how the language diverges from regular Python. Any features of Python that are not mentioned in +this reference manual are not part of TorchScript. TorchScript focuses specifically on the features of Python that are needed to +represent neural network models in PyTorch. + +```{contents} +:depth: 1 +:local: true +``` + +(type-system)= + +## Terminology + +This document uses the following terminologies: + +```{eval-rst} +.. list-table:: + :widths: 25 25 + :header-rows: 1 + + * - Pattern + - Notes + * - ``::=`` + - Indicates that the given symbol is defined as. + * - ``" "`` + - Represents real keywords and delimiters that are part of the syntax. + * - ``A | B`` + - Indicates either A or B. + * - ``( )`` + - Indicates grouping. + * - ``[]`` + - Indicates optional. + * - ``A+`` + - Indicates a regular expression where term A is repeated at least once. + * - ``A*`` + - Indicates a regular expression where term A is repeated zero or more times. +``` + +## Type System + +TorchScript is a statically typed subset of Python. The largest difference between TorchScript and the full Python language is that TorchScript only supports a small set of types that are needed to express +neural net models. + +### TorchScript Types + +The TorchScript type system consists of `TSType` and `TSModuleType` as defined below. + +``` +TSAllType ::= TSType | TSModuleType +TSType ::= TSMetaType | TSPrimitiveType | TSStructuralType | TSNominalType +``` + +`TSType` represents the majority of TorchScript types that are composable and that can be used in TorchScript type annotations. +`TSType` refers to any of the following: + +- Meta Types, e.g., `Any` +- Primitive Types, e.g., `int`, `float`, and `str` +- Structural Types, e.g., `Optional[int]` or `List[MyClass]` +- Nominal Types (Python classes), e.g., `MyClass` (user-defined), `torch.tensor` (built-in) + +`TSModuleType` represents `torch.nn.Module` and its subclasses. It is treated differently from `TSType` because its type schema is inferred partly from the object instance and partly from the class definition. +As such, instances of a `TSModuleType` may not follow the same static type schema. `TSModuleType` cannot be used as a TorchScript type annotation or be composed with `TSType` for type safety considerations. + +### Meta Types + +Meta types are so abstract that they are more like type constraints than concrete types. +Currently TorchScript defines one meta-type, `Any`, that represents any TorchScript type. + +#### `Any` Type + +The `Any` type represents any TorchScript type. `Any` specifies no type constraints, thus there is no type-checking on `Any`. +As such it can be bound to any Python or TorchScript data types (e.g., `int`, TorchScript `tuple`, or an arbitrary Python class that is not scripted). + +``` +TSMetaType ::= "Any" +``` + +Where: + +- `Any` is the Python class name from the typing module. Therefore, to use the `Any` type, you must import it from `typing` (e.g., `from typing import Any`). +- Since `Any` can represent any TorchScript type, the set of operators that are allowed to operate on values of this type on `Any` is limited. + +#### Operators Supported for `Any` Type + +- Assignment to data of `Any` type. +- Binding to parameter or return of `Any` type. +- `x is`, `x is not` where `x` is of `Any` type. +- `isinstance(x, Type)` where `x` is of `Any` type. +- Data of `Any` type is printable. +- Data of `List[Any]` type may be sortable if the data is a list of values of the same type `T` and that `T` supports comparison operators. + +**Compared to Python** + +`Any` is the least constrained type in the TorchScript type system. In that sense, it is quite similar to the +`Object` class in Python. However, `Any` only supports a subset of the operators and methods that are supported by `Object`. + +#### Design Notes + +When we script a PyTorch module, we may encounter data that is not involved in the execution of the script. Nevertheless, it has to be described +by a type schema. It is not only cumbersome to describe static types for unused data (in the context of the script), but also may lead to unnecessary +scripting failures. `Any` is introduced to describe the type of the data where precise static types are not necessary for compilation. + +**Example 1** + +This example illustrates how `Any` can be used to allow the second element of the tuple parameter to be of any type. This is possible +because `x[1]` is not involved in any computation that requires knowing its precise type. + +```{eval-rst} +.. testcode:: + + import torch + + from typing import Tuple + from typing import Any + + @torch.jit.export + def inc_first_element(x: Tuple[int, Any]): + return (x[0]+1, x[1]) + + m = torch.jit.script(inc_first_element) + print(m((1,2.0))) + print(m((1,(100,200)))) +``` + +The example above produces the following output: + +```{eval-rst} +.. testoutput:: + + (2, 2.0) + (2, (100, 200)) +``` + +The second element of the tuple is of `Any` type, thus can bind to multiple types. +For example, `(1, 2.0)` binds a float type to `Any` as in `Tuple[int, Any]`, +whereas `(1, (100, 200))` binds a tuple to `Any` in the second invocation. + +**Example 2** + +This example illustrates how we can use `isinstance` to dynamically check the type of the data that is annotated as `Any` type: + +```{eval-rst} +.. testcode:: + + import torch + from typing import Any + + def f(a:Any): + print(a) + return (isinstance(a, torch.Tensor)) + + ones = torch.ones([2]) + m = torch.jit.script(f) + print(m(ones)) +``` + +The example above produces the following output: + +```{eval-rst} +.. testoutput:: + + 1 + 1 + [ CPUFloatType{2} ] + True +``` + +### Primitive Types + +Primitive TorchScript types are types that represent a single type of value and go with a single pre-defined +type name. + +``` +TSPrimitiveType ::= "int" | "float" | "double" | "complex" | "bool" | "str" | "None" +``` + +### Structural Types + +Structural types are types that are structurally defined without a user-defined name (unlike nominal types), +such as `Future[int]`. Structural types are composable with any `TSType`. + +``` +TSStructuralType ::= TSTuple | TSNamedTuple | TSList | TSDict | + TSOptional | TSUnion | TSFuture | TSRRef | TSAwait + +TSTuple ::= "Tuple" "[" (TSType ",")* TSType "]" +TSNamedTuple ::= "namedtuple" "(" (TSType ",")* TSType ")" +TSList ::= "List" "[" TSType "]" +TSOptional ::= "Optional" "[" TSType "]" +TSUnion ::= "Union" "[" (TSType ",")* TSType "]" +TSFuture ::= "Future" "[" TSType "]" +TSRRef ::= "RRef" "[" TSType "]" +TSAwait ::= "Await" "[" TSType "]" +TSDict ::= "Dict" "[" KeyType "," TSType "]" +KeyType ::= "str" | "int" | "float" | "bool" | TensorType | "Any" +``` + +Where: + +- `Tuple`, `List`, `Optional`, `Union`, `Future`, `Dict` represent Python type class names that are defined in the module `typing`. To use these type names, you must import them from `typing` (e.g., `from typing import Tuple`). +- `namedtuple` represents the Python class `collections.namedtuple` or `typing.NamedTuple`. +- `Future` and `RRef` represent the Python classes `torch.futures` and `torch.distributed.rpc`. +- `Await` represent the Python class `torch._awaits._Await` + +**Compared to Python** + +Apart from being composable with TorchScript types, these TorchScript structural types often support a common subset of the operators and methods of their Python counterparts. + +**Example 1** + +This example uses `typing.NamedTuple` syntax to define a tuple: + +```{eval-rst} +.. testcode:: + + import torch + from typing import NamedTuple + from typing import Tuple + + class MyTuple(NamedTuple): + first: int + second: int + + def inc(x: MyTuple) -> Tuple[int, int]: + return (x.first+1, x.second+1) + + t = MyTuple(first=1, second=2) + scripted_inc = torch.jit.script(inc) + print("TorchScript:", scripted_inc(t)) +``` + +The example above produces the following output: + +```{eval-rst} +.. testoutput:: + + TorchScript: (2, 3) +``` + +**Example 2** + +This example uses `collections.namedtuple` syntax to define a tuple: + +```{eval-rst} +.. testcode:: + + import torch + from typing import NamedTuple + from typing import Tuple + from collections import namedtuple + + _AnnotatedNamedTuple = NamedTuple('_NamedTupleAnnotated', [('first', int), ('second', int)]) + _UnannotatedNamedTuple = namedtuple('_NamedTupleAnnotated', ['first', 'second']) + + def inc(x: _AnnotatedNamedTuple) -> Tuple[int, int]: + return (x.first+1, x.second+1) + + m = torch.jit.script(inc) + print(inc(_UnannotatedNamedTuple(1,2))) +``` + +The example above produces the following output: + +```{eval-rst} +.. testoutput:: + + (2, 3) +``` + +**Example 3** + +This example illustrates a common mistake of annotating structural types, i.e., not importing the composite type +classes from the `typing` module: + +```python +import torch + +# ERROR: Tuple not recognized because not imported from typing +@torch.jit.export +def inc(x: Tuple[int, int]): + return (x[0]+1, x[1]+1) + +m = torch.jit.script(inc) +print(m((1,2))) +``` + +Running the above code yields the following scripting error: + +```python +File "test-tuple.py", line 5, in + def inc(x: Tuple[int, int]): +NameError: name 'Tuple' is not defined +``` + +The remedy is to add the line `from typing import Tuple` to the beginning of the code. + +### Nominal Types + +Nominal TorchScript types are Python classes. These types are called nominal because they are declared with a custom +name and are compared using class names. Nominal classes are further classified into the following categories: + +``` +TSNominalType ::= TSBuiltinClasses | TSCustomClass | TSEnum +``` + +Among them, `TSCustomClass` and `TSEnum` must be compilable to TorchScript Intermediate Representation (IR). This is enforced by the type-checker. + +### Built-in Class + +Built-in nominal types are Python classes whose semantics are built into the TorchScript system (e.g., tensor types). +TorchScript defines the semantics of these built-in nominal types, and often supports only a subset of the methods or +attributes of its Python class definition. + +``` +TSBuiltinClass ::= TSTensor | "torch.device" | "torch.Stream" | "torch.dtype" | + "torch.nn.ModuleList" | "torch.nn.ModuleDict" | ... +TSTensor ::= "torch.Tensor" | "common.SubTensor" | "common.SubWithTorchFunction" | + "torch.nn.parameter.Parameter" | and subclasses of torch.Tensor +``` + +#### Special Note on torch.nn.ModuleList and torch.nn.ModuleDict + +Although `torch.nn.ModuleList` and `torch.nn.ModuleDict` are defined as a list and dictionary in Python, +they behave more like tuples in TorchScript: + +- In TorchScript, instances of `torch.nn.ModuleList` or `torch.nn.ModuleDict` are immutable. +- Code that iterates over `torch.nn.ModuleList` or `torch.nn.ModuleDict` is completely unrolled so that elements of `torch.nn.ModuleList` or keys of `torch.nn.ModuleDict` can be of different subclasses of `torch.nn.Module`. + +**Example** + +The following example highlights the use of a few built-in Torchscript classes (`torch.*`): + +```python +import torch + +@torch.jit.script +class A: + def __init__(self): + self.x = torch.rand(3) + + def f(self, y: torch.device): + return self.x.to(device=y) + +def g(): + a = A() + return a.f(torch.device("cpu")) + +script_g = torch.jit.script(g) +print(script_g.graph) +``` + +### Custom Class + +Unlike built-in classes, semantics of custom classes are user-defined and the entire class definition must be compilable to TorchScript IR and subject to TorchScript type-checking rules. + +``` +TSClassDef ::= [ "@torch.jit.script" ] + "class" ClassName [ "(object)" ] ":" + MethodDefinition | + [ "@torch.jit.ignore" ] | [ "@torch.jit.unused" ] + MethodDefinition +``` + +Where: + +- Classes must be new-style classes. Python 3 supports only new-style classes. In Python 2.x, a new-style class is specified by subclassing from the object. +- Instance data attributes are statically typed, and instance attributes must be declared by assignments inside the `__init__()` method. +- Method overloading is not supported (i.e., you cannot have multiple methods with the same method name). +- `MethodDefinition` must be compilable to TorchScript IR and adhere to TorchScript’s type-checking rules, (i.e., all methods must be valid TorchScript functions and class attribute definitions must be valid TorchScript statements). +- `torch.jit.ignore` and `torch.jit.unused` can be used to ignore the method or function that is not fully torchscriptable or should be ignored by the compiler. + +**Compared to Python** + +TorchScript custom classes are quite limited compared to their Python counterpart. Torchscript custom classes: + +- Do not support class attributes. +- Do not support subclassing except for subclassing an interface type or object. +- Do not support method overloading. +- Must initialize all its instance attributes in `__init__()`; this is because TorchScript constructs a static schema of the class by inferring attribute types in `__init__()`. +- Must contain only methods that satisfy TorchScript type-checking rules and are compilable to TorchScript IRs. + +**Example 1** + +Python classes can be used in TorchScript if they are annotated with `@torch.jit.script`, similar to how a TorchScript function would be declared: + +```python +@torch.jit.script +class MyClass: + def __init__(self, x: int): + self.x = x + + def inc(self, val: int): + self.x += val +``` + +**Example 2** + +A TorchScript custom class type must "declare" all its instance attributes by assignments in `__init__()`. If an instance attribute is not defined in `__init__()` but accessed in other methods of the class, the class cannot be compiled as a TorchScript class, as shown in the following example: + +```python +import torch + +@torch.jit.script +class foo: + def __init__(self): + self.y = 1 + +# ERROR: self.x is not defined in __init__ +def assign_x(self): + self.x = torch.rand(2, 3) +``` + +The class will fail to compile and issue the following error: + +``` +RuntimeError: +Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?: +def assign_x(self): + self.x = torch.rand(2, 3) + ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE +``` + +**Example 3** + +In this example, a TorchScript custom class defines a class variable name, which is not allowed: + +```python +import torch + +@torch.jit.script +class MyClass(object): + name = "MyClass" + def __init__(self, x: int): + self.x = x + +def fn(a: MyClass): + return a.name +``` + +It leads to the following compile-time error: + +``` +RuntimeError: +'__torch__.MyClass' object has no attribute or method 'name'. Did you forget to initialize an attribute in __init__()?: + File "test-class2.py", line 10 +def fn(a: MyClass): + return a.name + ~~~~~~ <--- HERE +``` + +### Enum Type + +Like custom classes, semantics of the enum type are user-defined and the entire class definition must be compilable to TorchScript IR and adhere to TorchScript type-checking rules. + +``` +TSEnumDef ::= "class" Identifier "(enum.Enum | TSEnumType)" ":" + ( MemberIdentifier "=" Value )+ + ( MethodDefinition )* +``` + +Where: + +- Value must be a TorchScript literal of type `int`, `float`, or `str`, and must be of the same TorchScript type. +- `TSEnumType` is the name of a TorchScript enumerated type. Similar to Python enum, TorchScript allows restricted `Enum` subclassing, that is, subclassing an enumerated is allowed only if it does not define any members. + +**Compared to Python** + +- TorchScript supports only `enum.Enum`. It does not support other variations such as `enum.IntEnum`, `enum.Flag`, `enum.IntFlag`, and `enum.auto`. +- Values of TorchScript enum members must be of the same type and can only be `int`, `float`, or `str` types, whereas Python enum members can be of any type. +- Enums containing methods are ignored in TorchScript. + +**Example 1** + +The following example defines the class `Color` as an `Enum` type: + +```python +import torch +from enum import Enum + +class Color(Enum): + RED = 1 + GREEN = 2 + +def enum_fn(x: Color, y: Color) -> bool: + if x == Color.RED: + return True + return x == y + +m = torch.jit.script(enum_fn) + +print("Eager: ", enum_fn(Color.RED, Color.GREEN)) +print("TorchScript: ", m(Color.RED, Color.GREEN)) +``` + +**Example 2** + +The following example shows the case of restricted enum subclassing, where `BaseColor` does not define any member, thus can be subclassed by `Color`: + +```python +import torch +from enum import Enum + +class BaseColor(Enum): + def foo(self): + pass + +class Color(BaseColor): + RED = 1 + GREEN = 2 + +def enum_fn(x: Color, y: Color) -> bool: + if x == Color.RED: + return True + return x == y + +m = torch.jit.script(enum_fn) + +print("TorchScript: ", m(Color.RED, Color.GREEN)) +print("Eager: ", enum_fn(Color.RED, Color.GREEN)) +``` + +### TorchScript Module Class + +`TSModuleType` is a special class type that is inferred from object instances that are created outside TorchScript. `TSModuleType` is named by the Python class of the object instance. The `__init__()` method of the Python class is not considered a TorchScript method, so it does not have to comply with TorchScript’s type-checking rules. + +The type schema of a module instance class is constructed directly from an instance object (created outside the scope of TorchScript) rather than inferred from `__init__()` like custom classes. It is possible that two objects of the same instance class type follow two different type schemas. + +In this sense, `TSModuleType` is not really a static type. Therefore, for type safety considerations, `TSModuleType` cannot be used in a TorchScript type annotation or be composed with `TSType`. + +### Module Instance Class + +TorchScript module type represents the type schema of a user-defined PyTorch module instance. When scripting a PyTorch module, the module object is always created outside TorchScript (i.e., passed in as parameter to `forward`). The Python module class is treated as a module instance class, so the `__init__()` method of the Python module class is not subject to the type-checking rules of TorchScript. + +``` +TSModuleType ::= "class" Identifier "(torch.nn.Module)" ":" + ClassBodyDefinition +``` + +Where: + +- `forward()` and other methods decorated with `@torch.jit.export` must be compilable to TorchScript IR and subject to TorchScript’s type-checking rules. + +Unlike custom classes, only the forward method and other methods decorated with `@torch.jit.export` of the module type need to be compilable. Most notably, `__init__()` is not considered a TorchScript method. Consequently, module type constructors cannot be invoked within the scope of TorchScript. Instead, TorchScript module objects are always constructed outside and passed into `torch.jit.script(ModuleObj)`. + +**Example 1** + +This example illustrates a few features of module types: + +- The `TestModule` instance is created outside the scope of TorchScript (i.e., before invoking `torch.jit.script`). +- `__init__()` is not considered a TorchScript method, therefore, it does not have to be annotated and can contain arbitrary Python code. In addition, the `__init__()` method of an instance class cannot be invoked in TorchScript code. Because `TestModule` instances are instantiated in Python, in this example, `TestModule(2.0)` and `TestModule(2)` create two instances with different types for its data attributes. `self.x` is of type `float` for `TestModule(2.0)`, whereas `self.y` is of type `int` for `TestModule(2.0)`. +- TorchScript automatically compiles other methods (e.g., `mul()`) invoked by methods annotated via `@torch.jit.export` or `forward()` methods. +- Entry-points to a TorchScript program are either `forward()` of a module type, functions annotated as `torch.jit.script`, or methods annotated as `torch.jit.export`. + +```{eval-rst} +.. testcode:: + + import torch + + class TestModule(torch.nn.Module): + def __init__(self, v): + super().__init__() + self.x = v + + def forward(self, inc: int): + return self.x + inc + + m = torch.jit.script(TestModule(1)) + print(f"First instance: {m(3)}") + + m = torch.jit.script(TestModule(torch.ones([5]))) + print(f"Second instance: {m(3)}") +``` + +The example above produces the following output: + +```{eval-rst} +.. testoutput:: + + First instance: 4 + Second instance: tensor([4., 4., 4., 4., 4.]) +``` + +**Example 2** + +The following example shows an incorrect usage of module type. Specifically, this example invokes the constructor of `TestModule` inside the scope of TorchScript: + +```{eval-rst} +.. testcode:: + + import torch + + class TestModule(torch.nn.Module): + def __init__(self, v): + super().__init__() + self.x = v + + def forward(self, x: int): + return self.x + x + + class MyModel: + def __init__(self, v: int): + self.val = v + + @torch.jit.export + def doSomething(self, val: int) -> int: + # error: should not invoke the constructor of module type + myModel = TestModule(self.val) + return myModel(val) + + # m = torch.jit.script(MyModel(2)) # Results in below RuntimeError + # RuntimeError: Could not get name of python class object +``` + +(type-annotation)= + +## Type Annotation + +Since TorchScript is statically typed, programmers need to annotate types at *strategic points* of TorchScript code so that every local variable or +instance data attribute has a static type, and every function and method has a statically typed signature. + +### When to Annotate Types + +In general, type annotations are only needed in places where static types cannot be automatically inferred (e.g., parameters or sometimes return types to +methods or functions). Types of local variables and data attributes are often automatically inferred from their assignment statements. Sometimes an inferred type +may be too restrictive, e.g., `x` being inferred as `NoneType` through assignment `x = None`, whereas `x` is actually used as an `Optional`. In such +cases, type annotations may be needed to overwrite auto inference, e.g., `x: Optional[int] = None`. Note that it is always safe to type annotate a local variable +or data attribute even if its type can be automatically inferred. The annotated type must be congruent with TorchScript’s type-checking. + +When a parameter, local variable, or data attribute is not type annotated and its type cannot be automatically inferred, TorchScript assumes it to be a +default type of `TensorType`, `List[TensorType]`, or `Dict[str, TensorType]`. + +### Annotate Function Signature + +Since a parameter may not be automatically inferred from the body of the function (including both functions and methods), they need to be type annotated. Otherwise, they assume the default type `TensorType`. + +TorchScript supports two styles for method and function signature type annotation: + +- **Python3-style** annotates types directly on the signature. As such, it allows individual parameters to be left unannotated (whose type will be the default type of `TensorType`), or allows the return type to be left unannotated (whose type will be automatically inferred). + +``` +Python3Annotation ::= "def" Identifier [ "(" ParamAnnot* ")" ] [ReturnAnnot] ":" + FuncOrMethodBody +ParamAnnot ::= Identifier [ ":" TSType ] "," +ReturnAnnot ::= "->" TSType +``` + +Note that when using Python3 style, the type `self` is automatically inferred and should not be annotated. + +- **Mypy style** annotates types as a comment right below the function/method declaration. In the Mypy style, since parameter names do not appear in the annotation, all parameters have to be annotated. + +``` +MyPyAnnotation ::= "# type:" "(" ParamAnnot* ")" [ ReturnAnnot ] +ParamAnnot ::= TSType "," +ReturnAnnot ::= "->" TSType +``` + +**Example 1** + +In this example: + +- `a` is not annotated and assumes the default type of `TensorType`. +- `b` is annotated as type `int`. +- The return type is not annotated and is automatically inferred as type `TensorType` (based on the type of the value being returned). + +```python +import torch + +def f(a, b: int): + return a+b + +m = torch.jit.script(f) +print("TorchScript:", m(torch.ones([6]), 100)) +``` + +**Example 2** + +The following example uses Mypy style annotation. Note that parameters or return values must be annotated even if some of +them assume the default type. + +```python +import torch + +def f(a, b): + # type: (torch.Tensor, int) → torch.Tensor + return a+b + +m = torch.jit.script(f) +print("TorchScript:", m(torch.ones([6]), 100)) +``` + +### Annotate Variables and Data Attributes + +In general, types of data attributes (including class and instance data attributes) and local variables can be automatically inferred from assignment statements. +Sometimes, however, if a variable or attribute is associated with values of different types (e.g., as `None` or `TensorType`), then they may need to be explicitly +type annotated as a *wider* type such as `Optional[int]` or `Any`. + +#### Local Variables + +Local variables can be annotated according to Python3 typing module annotation rules, i.e., + +``` +LocalVarAnnotation ::= Identifier [":" TSType] "=" Expr +``` + +In general, types of local variables can be automatically inferred. In some cases, however, you may need to annotate a multi-type for local variables +that may be associated with different concrete types. Typical multi-types include `Optional[T]` and `Any`. + +**Example** + +```python +import torch + +def f(a, setVal: bool): + value: Optional[torch.Tensor] = None + if setVal: + value = a + return value + +ones = torch.ones([6]) +m = torch.jit.script(f) +print("TorchScript:", m(ones, True), m(ones, False)) +``` + +#### Instance Data Attributes + +For `ModuleType` classes, instance data attributes can be annotated according to Python3 typing module annotation rules. Instance data attributes can be annotated (optionally) as final +via `Final`. + +``` +"class" ClassIdentifier "(torch.nn.Module):" +InstanceAttrIdentifier ":" ["Final("] TSType [")"] +... +``` + +Where: + +- `InstanceAttrIdentifier` is the name of an instance attribute. +- `Final` indicates that the attribute cannot be re-assigned outside of `__init__` or overridden in subclasses. + +**Example** + +```python +import torch + +class MyModule(torch.nn.Module): + offset_: int + +def __init__(self, offset): + self.offset_ = offset + +... +``` + +### Type Annotation APIs + +#### `torch.jit.annotate(T, expr)` + +This API annotates type `T` to an expression `expr`. This is often used when the default type of an expression is not the type intended by the programmer. +For instance, an empty list (dictionary) has the default type of `List[TensorType]` (`Dict[TensorType, TensorType]`), but sometimes it may be used to initialize +a list of some other types. Another common use case is for annotating the return type of `tensor.tolist()`. Note, however, that it cannot be used to annotate +the type of a module attribute in `__init__`; `torch.jit.Attribute` should be used for this instead. + +**Example** + +In this example, `[]` is declared as a list of integers via `torch.jit.annotate` (instead of assuming `[]` to be the default type of `List[TensorType]`): + +```python +import torch +from typing import List + +def g(l: List[int], val: int): + l.append(val) + return l + +def f(val: int): + l = g(torch.jit.annotate(List[int], []), val) + return l + +m = torch.jit.script(f) +print("Eager:", f(3)) +print("TorchScript:", m(3)) +``` + +See {meth}`torch.jit.annotate` for more information. + +### Type Annotation Appendix + +#### TorchScript Type System Definition + +``` +TSAllType ::= TSType | TSModuleType +TSType ::= TSMetaType | TSPrimitiveType | TSStructuralType | TSNominalType + +TSMetaType ::= "Any" +TSPrimitiveType ::= "int" | "float" | "double" | "complex" | "bool" | "str" | "None" + +TSStructuralType ::= TSTuple | TSNamedTuple | TSList | TSDict | TSOptional | + TSUnion | TSFuture | TSRRef | TSAwait +TSTuple ::= "Tuple" "[" (TSType ",")* TSType "]" +TSNamedTuple ::= "namedtuple" "(" (TSType ",")* TSType ")" +TSList ::= "List" "[" TSType "]" +TSOptional ::= "Optional" "[" TSType "]" +TSUnion ::= "Union" "[" (TSType ",")* TSType "]" +TSFuture ::= "Future" "[" TSType "]" +TSRRef ::= "RRef" "[" TSType "]" +TSAwait ::= "Await" "[" TSType "]" +TSDict ::= "Dict" "[" KeyType "," TSType "]" +KeyType ::= "str" | "int" | "float" | "bool" | TensorType | "Any" + +TSNominalType ::= TSBuiltinClasses | TSCustomClass | TSEnum +TSBuiltinClass ::= TSTensor | "torch.device" | "torch.stream"| + "torch.dtype" | "torch.nn.ModuleList" | + "torch.nn.ModuleDict" | ... +TSTensor ::= "torch.tensor" and subclasses +``` + +#### Unsupported Typing Constructs + +TorchScript does not support all features and types of the Python3 [typing](https://docs.python.org/3/library/typing.html#module-typing) module. +Any functionality from the [typing](https://docs.python.org/3/library/typing.html#module-typing) module that is not explicitly specified in this +documentation is unsupported. The following table summarizes `typing` constructs that are either unsupported or supported with restrictions in TorchScript. + +```{eval-rst} +============================= ================ + Item Description +----------------------------- ---------------- +``typing.Any`` In development +``typing.NoReturn`` Not supported +``typing.Callable`` Not supported +``typing.Literal`` Not supported +``typing.ClassVar`` Not supported +``typing.Final`` Supported for module attributes, class attribute, and annotations, but not for functions. +``typing.AnyStr`` Not supported +``typing.overload`` In development +Type aliases Not supported +Nominal typing In development +Structural typing Not supported +NewType Not supported +Generics Not supported +============================= ================ +``` + +(expressions)= + +## Expressions + +The following section describes the grammar of expressions that are supported in TorchScript. +It is modeled after [the expressions chapter of the Python language reference](https://docs.python.org/3/reference/expressions.html). + +### Arithmetic Conversions + +There are a number of implicit type conversions that are performed in TorchScript: + +- A `Tensor` with a `float` or `int` data type can be implicitly converted to an instance of `FloatType` or `IntType` provided that it has a size of 0, does not have `require_grad` set to `True`, and will not require narrowing. +- Instances of `StringType` can be implicitly converted to `DeviceType`. +- The implicit conversion rules from the two bullet points above can be applied to instances of `TupleType` to produce instances of `ListType` with the appropriate contained type. + +Explicit conversions can be invoked using the `float`, `int`, `bool`, and `str` built-in functions +that accept primitive data types as arguments and can accept user-defined types if they implement +`__bool__`, `__str__`, etc. + +### Atoms + +Atoms are the most basic elements of expressions. + +``` +atom ::= identifier | literal | enclosure +enclosure ::= parenth_form | list_display | dict_display +``` + +#### Identifiers + +The rules that dictate what is a legal identifier in TorchScript are the same as +their [Python counterparts](https://docs.python.org/3/reference/lexical_analysis.html#identifiers). + +#### Literals + +``` +literal ::= stringliteral | integer | floatnumber +``` + +Evaluation of a literal yields an object of the appropriate type with the specific value +(with approximations applied as necessary for floats). Literals are immutable, and multiple evaluations +of identical literals may obtain the same object or distinct objects with the same value. +[stringliteral](https://docs.python.org/3/reference/lexical_analysis.html#string-and-bytes-literals), +[integer](https://docs.python.org/3/reference/lexical_analysis.html#integer-literals), and +[floatnumber](https://docs.python.org/3/reference/lexical_analysis.html#floating-point-literals) +are defined in the same way as their Python counterparts. + +#### Parenthesized Forms + +``` +parenth_form ::= '(' [expression_list] ')' +``` + +A parenthesized expression list yields whatever the expression list yields. If the list contains at least one +comma, it yields a `Tuple`; otherwise, it yields the single expression inside the expression list. An empty +pair of parentheses yields an empty `Tuple` object (`Tuple[]`). + +#### List and Dictionary Displays + +``` +list_comprehension ::= expression comp_for +comp_for ::= 'for' target_list 'in' or_expr +list_display ::= '[' [expression_list | list_comprehension] ']' +dict_display ::= '{' [key_datum_list | dict_comprehension] '}' +key_datum_list ::= key_datum (',' key_datum)* +key_datum ::= expression ':' expression +dict_comprehension ::= key_datum comp_for +``` + +Lists and dicts can be constructed by either listing the container contents explicitly or by providing +instructions on how to compute them via a set of looping instructions (i.e. a *comprehension*). A comprehension +is semantically equivalent to using a for loop and appending to an ongoing list. +Comprehensions implicitly create their own scope to make sure that the items of the target list do not leak into the +enclosing scope. In the case that container items are explicitly listed, the expressions in the expression list +are evaluated left-to-right. If a key is repeated in a `dict_display` that has a `key_datum_list`, the +resultant dictionary uses the value from the rightmost datum in the list that uses the repeated key. + +### Primaries + +``` +primary ::= atom | attributeref | subscription | slicing | call +``` + +#### Attribute References + +``` +attributeref ::= primary '.' identifier +``` + +The `primary` must evaluate to an object of a type that supports attribute references that have an attribute named +`identifier`. + +#### Subscriptions + +``` +subscription ::= primary '[' expression_list ']' +``` + +The `primary` must evaluate to an object that supports subscription. + +- If the primary is a `List`, `Tuple`, or `str`, the expression list must evaluate to an integer or slice. +- If the primary is a `Dict`, the expression list must evaluate to an object of the same type as the key type of the `Dict`. +- If the primary is a `ModuleList`, the expression list must be an `integer` literal. +- If the primary is a `ModuleDict`, the expression must be a `stringliteral`. + +#### Slicings + +A slicing selects a range of items in a `str`, `Tuple`, `List`, or `Tensor`. Slicings may be used as +expressions or targets in assignment or `del` statements. + +``` +slicing ::= primary '[' slice_list ']' +slice_list ::= slice_item (',' slice_item)* [','] +slice_item ::= expression | proper_slice +proper_slice ::= [expression] ':' [expression] [':' [expression] ] +``` + +Slicings with more than one slice item in their slice lists can only be used with primaries that evaluate to an +object of type `Tensor`. + +#### Calls + +``` +call ::= primary '(' argument_list ')' +argument_list ::= args [',' kwargs] | kwargs +args ::= [arg (',' arg)*] +kwargs ::= [kwarg (',' kwarg)*] +kwarg ::= arg '=' expression +arg ::= identifier +``` + +The `primary` must desugar or evaluate to a callable object. All argument expressions are evaluated +before the call is attempted. + +### Power Operator + +``` +power ::= primary ['**' u_expr] +``` + +The power operator has the same semantics as the built-in pow function (not supported); it computes its +left argument raised to the power of its right argument. It binds more tightly than unary operators on the +left, but less tightly than unary operators on the right; i.e. `-2 ** -3 == -(2 ** (-3))`. The left and right +operands can be `int`, `float` or `Tensor`. Scalars are broadcast in the case of scalar-tensor/tensor-scalar +exponentiation operations, and tensor-tensor exponentiation is done elementwise without any broadcasting. + +### Unary and Arithmetic Bitwise Operations + +``` +u_expr ::= power | '-' power | '~' power +``` + +The unary `-` operator yields the negation of its argument. The unary `~` operator yields the bitwise inversion +of its argument. `-` can be used with `int`, `float`, and `Tensor` of `int` and `float`. +`~` can only be used with `int` and `Tensor` of `int`. + +### Binary Arithmetic Operations + +``` +m_expr ::= u_expr | m_expr '*' u_expr | m_expr '@' m_expr | m_expr '//' u_expr | m_expr '/' u_expr | m_expr '%' u_expr +a_expr ::= m_expr | a_expr '+' m_expr | a_expr '-' m_expr +``` + +The binary arithmetic operators can operate on `Tensor`, `int`, and `float`. For tensor-tensor ops, both arguments must +have the same shape. For scalar-tensor or tensor-scalar ops, the scalar is usually broadcast to the size of the +tensor. Division ops can only accept scalars as their right-hand side argument, and do not support broadcasting. +The `@` operator is for matrix multiplication and only operates on `Tensor` arguments. The multiplication operator +(`*`) can be used with a list and integer in order to get a result that is the original list repeated a certain +number of times. + +### Shifting Operations + +``` +shift_expr ::= a_expr | shift_expr ( '<<' | '>>' ) a_expr +``` + +These operators accept two `int` arguments, two `Tensor` arguments, or a `Tensor` argument and an `int` or +`float` argument. In all cases, a right shift by `n` is defined as floor division by `pow(2, n)`, and a left shift +by `n` is defined as multiplication by `pow(2, n)`. When both arguments are `Tensors`, they must have the same +shape. When one is a scalar and the other is a `Tensor`, the scalar is logically broadcast to match the size of +the `Tensor`. + +### Binary Bitwise Operations + +``` +and_expr ::= shift_expr | and_expr '&' shift_expr +xor_expr ::= and_expr | xor_expr '^' and_expr +or_expr ::= xor_expr | or_expr '|' xor_expr +``` + +The `&` operator computes the bitwise AND of its arguments, the `^` the bitwise XOR, and the `|` the bitwise OR. +Both operands must be `int` or `Tensor`, or the left operand must be `Tensor` and the right operand must be +`int`. When both operands are `Tensor`, they must have the same shape. When the right operand is `int`, and +the left operand is `Tensor`, the right operand is logically broadcast to match the shape of the `Tensor`. + +### Comparisons + +``` +comparison ::= or_expr (comp_operator or_expr)* +comp_operator ::= '<' | '>' | '==' | '>=' | '<=' | '!=' | 'is' ['not'] | ['not'] 'in' +``` + +A comparison yields a boolean value (`True` or `False`), or if one of the operands is a `Tensor`, a boolean +`Tensor`. Comparisons can be chained arbitrarily as long as they do not yield boolean `Tensors` that have more +than one element. `a op1 b op2 c ...` is equivalent to `a op1 b and b op2 c and ...`. + +#### Value Comparisons + +The operators `<`, `>`, `==`, `>=`, `<=`, and `!=` compare the values of two objects. The two objects generally need to be of +the same type, unless there is an implicit type conversion available between the objects. User-defined types can +be compared if rich comparison methods (e.g., `__lt__`) are defined on them. Built-in type comparison works like +Python: + +- Numbers are compared mathematically. +- Strings are compared lexicographically. +- `lists`, `tuples`, and `dicts` can be compared only to other `lists`, `tuples`, and `dicts` of the same type and are compared using the comparison operator of corresponding elements. + +#### Membership Test Operations + +The operators `in` and `not in` test for membership. `x in s` evaluates to `True` if `x` is a member of `s` and `False` otherwise. +`x not in s` is equivalent to `not x in s`. This operator is supported for `lists`, `dicts`, and `tuples`, and can be used with +user-defined types if they implement the `__contains__` method. + +#### Identity Comparisons + +For all types except `int`, `double`, `bool`, and `torch.device`, operators `is` and `is not` test for the object’s identity; +`x is y` is `True` if and only if `x` and `y` are the same object. For all other types, `is` is equivalent to +comparing them using `==`. `x is not y` yields the inverse of `x is y`. + +### Boolean Operations + +``` +or_test ::= and_test | or_test 'or' and_test +and_test ::= not_test | and_test 'and' not_test +not_test ::= 'bool' '(' or_expr ')' | comparison | 'not' not_test +``` + +User-defined objects can customize their conversion to `bool` by implementing a `__bool__` method. The operator `not` +yields `True` if its operand is false, `False` otherwise. The expression `x` and `y` first evaluates `x`; if it is `False`, its +value (`False`) is returned; otherwise, `y` is evaluated and its value is returned (`False` or `True`). The expression `x` or `y` +first evaluates `x`; if it is `True`, its value (`True`) is returned; otherwise, `y` is evaluated and its value is returned +(`False` or `True`). + +### Conditional Expressions + +``` +conditional_expression ::= or_expr ['if' or_test 'else' conditional_expression] +expression ::= conditional_expression +``` + +The expression `x if c else y` first evaluates the condition `c` rather than x. If `c` is `True`, `x` is +evaluated and its value is returned; otherwise, `y` is evaluated and its value is returned. As with if-statements, +`x` and `y` must evaluate to a value of the same type. + +### Expression Lists + +``` +expression_list ::= expression (',' expression)* [','] +starred_item ::= '*' primary +``` + +A starred item can only appear on the left-hand side of an assignment statement, e.g., `a, *b, c = ...`. + +% statements: + +## Simple Statements + +The following section describes the syntax of simple statements that are supported in TorchScript. +It is modeled after [the simple statements chapter of the Python language reference](https://docs.python.org/3/reference/simple_stmts.html). + +### Expression Statements + +``` +expression_stmt ::= starred_expression +starred_expression ::= expression | (starred_item ",")* [starred_item] +starred_item ::= assignment_expression | "*" or_expr +``` + +### Assignment Statements + +``` +assignment_stmt ::= (target_list "=")+ (starred_expression) +target_list ::= target ("," target)* [","] +target ::= identifier + | "(" [target_list] ")" + | "[" [target_list] "]" + | attributeref + | subscription + | slicing + | "*" target +``` + +### Augmented Assignment Statements + +``` +augmented_assignment_stmt ::= augtarget augop (expression_list) +augtarget ::= identifier | attributeref | subscription +augop ::= "+=" | "-=" | "*=" | "/=" | "//=" | "%=" | + "**="| ">>=" | "<<=" | "&=" | "^=" | "|=" +``` + +### Annotated Assignment Statements + +``` +annotated_assignment_stmt ::= augtarget ":" expression + ["=" (starred_expression)] +``` + +### The `raise` Statement + +``` +raise_stmt ::= "raise" [expression ["from" expression]] +``` + +Raise statements in TorchScript do not support `try\except\finally`. + +### The `assert` Statement + +``` +assert_stmt ::= "assert" expression ["," expression] +``` + +Assert statements in TorchScript do not support `try\except\finally`. + +### The `return` Statement + +``` +return_stmt ::= "return" [expression_list] +``` + +Return statements in TorchScript do not support `try\except\finally`. + +### The `del` Statement + +``` +del_stmt ::= "del" target_list +``` + +### The `pass` Statement + +``` +pass_stmt ::= "pass" +``` + +### The `print` Statement + +``` +print_stmt ::= "print" "(" expression [, expression] [.format{expression_list}] ")" +``` + +### The `break` Statement + +``` +break_stmt ::= "break" +``` + +### The `continue` Statement: + +``` +continue_stmt ::= "continue" +``` + +## Compound Statements + +The following section describes the syntax of compound statements that are supported in TorchScript. +The section also highlights how Torchscript differs from regular Python statements. +It is modeled after [the compound statements chapter of the Python language reference](https://docs.python.org/3/reference/compound_stmts.html). + +### The `if` Statement + +Torchscript supports both basic `if/else` and ternary `if/else`. + +#### Basic `if/else` Statement + +``` +if_stmt ::= "if" assignment_expression ":" suite + ("elif" assignment_expression ":" suite) + ["else" ":" suite] +``` + +`elif` statements can repeat for an arbitrary number of times, but it needs to be before `else` statement. + +#### Ternary `if/else` Statement + +``` +if_stmt ::= return [expression_list] "if" assignment_expression "else" [expression_list] +``` + +**Example 1** + +A `tensor` with 1 dimension is promoted to `bool`: + +```{eval-rst} +.. testcode:: + + import torch + + @torch.jit.script + def fn(x: torch.Tensor): + if x: # The tensor gets promoted to bool + return True + return False + print(fn(torch.rand(1))) +``` + +The example above produces the following output: + +```{eval-rst} +.. testoutput:: + + True +``` + +**Example 2** + +A `tensor` with multi dimensions are not promoted to `bool`: + +```python +import torch + +# Multi dimensional Tensors error out. + +@torch.jit.script +def fn(): + if torch.rand(2): + print("Tensor is available") + + if torch.rand(4,5,6): + print("Tensor is available") + +print(fn()) +``` + +Running the above code yields the following `RuntimeError`. + +``` +RuntimeError: The following operation failed in the TorchScript interpreter. +Traceback of TorchScript (most recent call last): +@torch.jit.script +def fn(): + if torch.rand(2): + ~~~~~~~~~~~~ <--- HERE + print("Tensor is available") +RuntimeError: Boolean value of Tensor with more than one value is ambiguous +``` + +If a conditional variable is annotated as `final`, either the true or false branch is evaluated depending on the evaluation of the conditional variable. + +**Example 3** + +In this example, only the True branch is evaluated, since `a` is annotated as `final` and set to `True`: + +```python +import torch + +a : torch.jit.final[Bool] = True + +if a: + return torch.empty(2,3) +else: + return [] +``` + +### The `while` Statement + +``` +while_stmt ::= "while" assignment_expression ":" suite +``` + +`while...else` statements are not supported in Torchscript. It results in a `RuntimeError`. + +### The `for-in` Statement + +``` +for_stmt ::= "for" target_list "in" expression_list ":" suite + ["else" ":" suite] +``` + +`for...else` statements are not supported in Torchscript. It results in a `RuntimeError`. + +**Example 1** + +For loops on tuples: these unroll the loop, generating a body for each member of the tuple. The body must type-check correctly for each member. + +```{eval-rst} +.. testcode:: + + import torch + from typing import Tuple + + @torch.jit.script + def fn(): + tup = (3, torch.ones(4)) + for x in tup: + print(x) + + fn() +``` + +The example above produces the following output: + +```{eval-rst} +.. testoutput:: + + 3 + 1 + 1 + 1 + 1 + [ CPUFloatType{4} ] + +``` + +**Example 2** + +For loops on lists: for loops over a `nn.ModuleList` will unroll the body of the loop at compile time, with each member of the module list. + +```python +class SubModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(2)) + + def forward(self, input): + return self.weight + input + +class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.mods = torch.nn.ModuleList([SubModule() for i in range(10)]) + + def forward(self, v): + for module in self.mods: + v = module(v) + return v + +model = torch.jit.script(MyModule()) +``` + +### The `with` Statement + +The `with` statement is used to wrap the execution of a block with methods defined by a context manager. + +``` +with_stmt ::= "with" with_item ("," with_item) ":" suite +with_item ::= expression ["as" target] +``` + +- If a target was included in the `with` statement, the return value from the context manager’s `__enter__()` is assigned to it. Unlike python, if an exception caused the suite to be exited, its type, value, and traceback are not passed as arguments to `__exit__()`. Three `None` arguments are supplied. +- `try`, `except`, and `finally` statements are not supported inside `with` blocks. +- Exceptions raised within `with` block cannot be suppressed. + +### The `tuple` Statement + +``` +tuple_stmt ::= tuple([iterables]) +``` + +- Iterable types in TorchScript include `Tensors`, `lists`, `tuples`, `dictionaries`, `strings`, `torch.nn.ModuleList`, and `torch.nn.ModuleDict`. +- You cannot convert a List to Tuple by using this built-in function. + +Unpacking all outputs into a tuple is covered by: + +``` +abc = func() # Function that returns a tuple +a,b = func() +``` + +### The `getattr` Statement + +``` +getattr_stmt ::= getattr(object, name[, default]) +``` + +- Attribute name must be a literal string. +- Module type object is not supported (e.g., torch.\_C). +- Custom class object is not supported (e.g., torch.classes.\*). + +### The `hasattr` Statement + +``` +hasattr_stmt ::= hasattr(object, name) +``` + +- Attribute name must be a literal string. +- Module type object is not supported (e.g., torch.\_C). +- Custom class object is not supported (e.g., torch.classes.\*). + +### The `zip` Statement + +``` +zip_stmt ::= zip(iterable1, iterable2) +``` + +- Arguments must be iterables. +- Two iterables of same outer container type but different length are supported. + +**Example 1** + +Both the iterables must be of the same container type: + +```{eval-rst} +.. testcode:: + + a = [1, 2] # List + b = [2, 3, 4] # List + zip(a, b) # works +``` + +**Example 2** + +This example fails because the iterables are of different container types: + +``` +a = (1, 2) # Tuple +b = [2, 3, 4] # List +zip(a, b) # Runtime error +``` + +Running the above code yields the following `RuntimeError`. + +``` +RuntimeError: Can not iterate over a module list or + tuple with a value that does not have a statically determinable length. +``` + +**Example 3** + +Two iterables of the same container Type but different data type is supported: + +```{eval-rst} +.. testcode:: + + a = [1.3, 2.4] + b = [2, 3, 4] + zip(a, b) # Works +``` + +Iterable types in TorchScript include `Tensors`, `lists`, `tuples`, `dictionaries`, `strings`, `torch.nn.ModuleList`, and `torch.nn.ModuleDict`. + +### The `enumerate` Statement + +``` +enumerate_stmt ::= enumerate([iterable]) +``` + +- Arguments must be iterables. +- Iterable types in TorchScript include `Tensors`, `lists`, `tuples`, `dictionaries`, `strings`, `torch.nn.ModuleList` and `torch.nn.ModuleDict`. + +(python-values-torch-script)= + +## Python Values + +(python-builtin-functions-values-resolution)= + +### Resolution Rules + +When given a Python value, TorchScript attempts to resolve it in the following five different ways: + +- Compilable Python Implementation: + : - When a Python value is backed by a Python implementation that can be compiled by TorchScript, TorchScript compiles and uses the underlying Python implementation. + - Example: `torch.jit.Attribute` +- Op Python Wrapper: + : - When a Python value is a wrapper of a native PyTorch op, TorchScript emits the corresponding operator. + - Example: `torch.jit._logging.add_stat_value` +- Python Object Identity Match: + : - For a limited set of `torch.*` API calls (in the form of Python values) that TorchScript supports, TorchScript attempts to match a Python value against each item in the set. + - When matched, TorchScript generates a corresponding `SugaredValue` instance that contains lowering logic for these values. + - Example: `torch.jit.isinstance()` +- Name Match: + : - For Python built-in functions and constants, TorchScript identifies them by name, and creates a corresponding `SugaredValue` instance that implements their functionality. + - Example: `all()` +- Value Snapshot: + : - For Python values from unrecognized modules, TorchScript attempts to take a snapshot of the value and converts it to a constant in the graph of the function(s) or method(s) that are being compiled. + - Example: `math.pi` + +(python-builtin-functions-support)= + +### Python Built-in Functions Support + +```{eval-rst} +.. list-table:: TorchScript Support for Python Built-in Functions + :widths: 25 25 50 + :header-rows: 1 + + * - Built-in Function + - Support Level + - Notes + * - ``abs()`` + - Partial + - Only supports ``Tensor``/``Int``/``Float`` type inputs. | Doesn't honor ``__abs__`` override. + * - ``all()`` + - Full + - + * - ``any()`` + - Full + - + * - ``ascii()`` + - None + - + * - ``bin()`` + - Partial + - Only supports ``Int`` type input. + * - ``bool()`` + - Partial + - Only supports ``Tensor``/``Int``/``Float`` type inputs. + * - ``breakpoint()`` + - None + - + * - ``bytearray()`` + - None + - + * - ``bytes()`` + - None + - + * - ``callable()`` + - None + - + * - ``chr()`` + - Partial + - Only ASCII character set is supported. + * - ``classmethod()`` + - Full + - + * - ``compile()`` + - None + - + * - ``complex()`` + - None + - + * - ``delattr()`` + - None + - + * - ``dict()`` + - Full + - + * - ``dir()`` + - None + - + * - ``divmod()`` + - Full + - + * - ``enumerate()`` + - Full + - + * - ``eval()`` + - None + - + * - ``exec()`` + - None + - + * - ``filter()`` + - None + - + * - ``float()`` + - Partial + - Doesn't honor ``__index__`` override. + * - ``format()`` + - Partial + - Manual index specification not supported. | Format type modifier not supported. + * - ``frozenset()`` + - None + - + * - ``getattr()`` + - Partial + - Attribute name must be string literal. + * - ``globals()`` + - None + - + * - ``hasattr()`` + - Partial + - Attribute name must be string literal. + * - ``hash()`` + - Full + - ``Tensor``'s hash is based on identity not numeric value. + * - ``hex()`` + - Partial + - Only supports ``Int`` type input. + * - ``id()`` + - Full + - Only supports ``Int`` type input. + * - ``input()`` + - None + - + * - ``int()`` + - Partial + - ``base`` argument not supported. | Doesn't honor ``__index__`` override. + * - ``isinstance()`` + - Full + - ``torch.jit.isintance`` provides better support when checking against container types like ``Dict[str, int]``. + * - ``issubclass()`` + - None + - + * - ``iter()`` + - None + - + * - ``len()`` + - Full + - + * - ``list()`` + - Full + - + * - ``ord()`` + - Partial + - Only ASCII character set is supported. + * - ``pow()`` + - Full + - + * - ``print()`` + - Partial + - ``separate``, ``end`` and ``file`` arguments are not supported. + * - ``property()`` + - None + - + * - ``range()`` + - Full + - + * - ``repr()`` + - None + - + * - ``reversed()`` + - None + - + * - ``round()`` + - Partial + - ``ndigits`` argument is not supported. + * - ``set()`` + - None + - + * - ``setattr()`` + - None + - + * - ``slice()`` + - Full + - + * - ``sorted()`` + - Partial + - ``key`` argument is not supported. + * - ``staticmethod()`` + - Full + - + * - ``str()`` + - Partial + - ``encoding`` and ``errors`` arguments are not supported. + * - ``sum()`` + - Full + - + * - ``super()`` + - Partial + - It can only be used in ``nn.Module``'s ``__init__`` method. + * - ``type()`` + - None + - + * - ``vars()`` + - None + - + * - ``zip()`` + - Full + - + * - ``__import__()`` + - None + - +``` + +(python-builtin-values-support)= + +### Python Built-in Values Support + +```{eval-rst} +.. list-table:: TorchScript Support for Python Built-in Values + :widths: 25 25 50 + :header-rows: 1 + + * - Built-in Value + - Support Level + - Notes + * - ``False`` + - Full + - + * - ``True`` + - Full + - + * - ``None`` + - Full + - + * - ``NotImplemented`` + - None + - + * - ``Ellipsis`` + - Full + - + +``` + +(torch-apis-in-torchscript)= + +## torch.\* APIs + +(torch-apis-in-torchscript-rpc)= + +### Remote Procedure Calls + +TorchScript supports a subset of RPC APIs that supports running a function on +a specified remote worker instead of locally. + +Specifically, following APIs are fully supported: + +- `torch.distributed.rpc.rpc_sync()` + : - `rpc_sync()` makes a blocking RPC call to run a function on a remote worker. RPC messages are sent and received in parallel to execution of Python code. + - More details about its usage and examples can be found in {meth}`~torch.distributed.rpc.rpc_sync`. +- `torch.distributed.rpc.rpc_async()` + : - `rpc_async()` makes a non-blocking RPC call to run a function on a remote worker. RPC messages are sent and received in parallel to execution of Python code. + - More details about its usage and examples can be found in {meth}`~torch.distributed.rpc.rpc_async`. +- `torch.distributed.rpc.remote()` + : - `remote.()` executes a remote call on a worker and gets a Remote Reference `RRef` as the return value. + - More details about its usage and examples can be found in {meth}`~torch.distributed.rpc.remote`. + +(torch-apis-in-torchscript-async)= + +### Asynchronous Execution + +TorchScript enables you to create asynchronous computation tasks to make better use +of computation resources. This is done via supporting a list of APIs that are +only usable within TorchScript: + +- `torch.jit.fork()` + : - Creates an asynchronous task executing func and a reference to the value of the result of this execution. Fork will return immediately. + - Synonymous to `torch.jit._fork()`, which is only kept for backward compatibility reasons. + - More details about its usage and examples can be found in {meth}`~torch.jit.fork`. +- `torch.jit.wait()` + : - Forces completion of a `torch.jit.Future[T]` asynchronous task, returning the result of the task. + - Synonymous to `torch.jit._wait()`, which is only kept for backward compatibility reasons. + - More details about its usage and examples can be found in {meth}`~torch.jit.wait`. + +(torch-apis-in-torchscript-annotation)= + +### Type Annotations + +TorchScript is statically-typed. It provides and supports a set of utilities to help annotate variables and attributes: + +- `torch.jit.annotate()` + : - Provides a type hint to TorchScript where Python 3 style type hints do not work well. + - One common example is to annotate type for expressions like `[]`. `[]` is treated as `List[torch.Tensor]` by default. When a different type is needed, you can use this code to hint TorchScript: `torch.jit.annotate(List[int], [])`. + - More details can be found in {meth}`~torch.jit.annotate` +- `torch.jit.Attribute` + : - Common use cases include providing type hint for `torch.nn.Module` attributes. Because their `__init__` methods are not parsed by TorchScript, `torch.jit.Attribute` should be used instead of `torch.jit.annotate` in the module's `__init__` methods. + - More details can be found in {meth}`~torch.jit.Attribute` +- `torch.jit.Final` + : - An alias for Python's `typing.Final`. `torch.jit.Final` is kept only for backward compatibility reasons. + +(torch-apis-in-torchscript-meta-programming)= + +### Meta Programming + +TorchScript provides a set of utilities to facilitate meta programming: + +- `torch.jit.is_scripting()` + : - Returns a boolean value indicating whether the current program is compiled by `torch.jit.script` or not. + - When used in an `assert` or an `if` statement, the scope or branch where `torch.jit.is_scripting()` evaluates to `False` is not compiled. + - Its value can be evaluated statically at compile time, thus commonly used in `if` statements to stop TorchScript from compiling one of the branches. + - More details and examples can be found in {meth}`~torch.jit.is_scripting` +- `torch.jit.is_tracing()` + : - Returns a boolean value indicating whether the current program is traced by `torch.jit.trace` / `torch.jit.trace_module` or not. + - More details can be found in {meth}`~torch.jit.is_tracing` +- `@torch.jit.ignore` + : - This decorator indicates to the compiler that a function or method should be ignored and left as a Python function. + - This allows you to leave code in your model that is not yet TorchScript compatible. + - If a function decorated by `@torch.jit.ignore` is called from TorchScript, ignored functions will dispatch the call to the Python interpreter. + - Models with ignored functions cannot be exported. + - More details and examples can be found in {meth}`~torch.jit.ignore` +- `@torch.jit.unused` + : - This decorator indicates to the compiler that a function or method should be ignored and replaced with the raising of an exception. + - This allows you to leave code in your model that is not yet TorchScript compatible and still export your model. + - If a function decorated by `@torch.jit.unused` is called from TorchScript, a runtime error will be raised. + - More details and examples can be found in {meth}`~torch.jit.unused` + +(torch-apis-in-torchscript-type-refinement)= + +### Type Refinement + +- `torch.jit.isinstance()` + : - Returns a boolean indicating whether a variable is of the specified type. + - More details about its usage and examples can be found in {meth}`~torch.jit.isinstance`. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/docs/source/jit_python_reference.rst b/docs/source/jit_python_reference.rst new file mode 100644 index 0000000000000..96e0fe13037c1 --- /dev/null +++ b/docs/source/jit_python_reference.rst @@ -0,0 +1,432 @@ +.. _python-language-reference: + +Python Language Reference Coverage +================================== + +This is a 1:1 mapping of the features listed in https://docs.python.org/3/reference/ and their +support in TorchScript. The categorizations are as follows: + + +.. list-table:: + :header-rows: 1 + + * - Section + - Status + - Note + * - `1. Introduction `_ + - Not Relevant + - + * - `1.1. Alternate Implementations `_ + - Not Relevant + - + * - `1.2. Notation `_ + - Not Relevant + - + * - `2. Lexical analysis `_ + - Not Relevant + - + * - `2.1. Line structure `_ + - Not Relevant + - + * - `2.1.1. Logical lines `_ + - Not Relevant + - + * - `2.1.2. Physical lines `_ + - Supported + - + * - `2.1.3. Comments `_ + - Supported + - + * - `2.1.4. Encoding declarations `_ + - Not Supported + - TorchScript explicitly don't support unicode + * - `2.1.5. Explicit line joining `_ + - Supported + - + * - `2.1.6. Implicit line joining `_ + - Supported + - + * - `2.1.7. Blank lines `_ + - Supported + - + * - `2.1.8. Indentation `_ + - Supported + - + * - `2.1.9. Whitespace between tokens `_ + - Not Relevant + - + * - `2.2. Other tokens `_ + - Not Relevant + - + * - `2.3. Identifiers and keywords `_ + - Supported + - + * - `2.3.1. Keywords `_ + - Supported + - + * - `2.3.2. Reserved classes of identifiers `_ + - Supported + - + * - `2.4. Literals `_ + - Not Relevant + - + * - `2.4.1. String and Bytes literals `_ + - Supported + - + * - `2.4.2. String literal concatenation `_ + - Supported + - + * - `2.4.3. Formatted string literals `_ + - Partially Supported + - + * - `2.4.4. Numeric literals `_ + - Supported + - + * - `2.4.5. Integer literals `_ + - Supported + - + * - `2.4.6. Floating point literals `_ + - Supported + - + * - `2.4.7. Imaginary literals `_ + - Not Supported + - + * - `2.5. Operators `_ + - Partially Supported + - Not supported: ``<<``, ``>>``, ``:=`` + * - `2.6. Delimiters `_ + - Partially Supported + - Not supported: ``**=``, ``<<=``, ``>>=``, ``%=``, ``^=``, ``@=``, ``&=``, ``//=``, ``%`` operator for some types (e.g. ``str``\ ) + * - `3. Data model `_ + - Not Relevant + - + * - `3.1. Objects, values and types `_ + - Not Relevant + - + * - `3.2. The standard type hierarchy `_ + - Partially Supported + - Not supported: NotImplemented, Ellipsis, numbers.Complex, bytes, byte arrays, sets, frozen sets, generators, coroutines, async generators, modules, I/O objects, internal objects, slice objects ( though slicing is supported), classmethod + * - `3.3. Special method names `_ + - Supported + - + * - `3.3.1. Basic customization `_ + - Partially Supported + - Not supported: ``__new__`` , ``__del__`` , ``__bytes__`` , ``__format__`` , ``__hash__`` , + * - `3.3.2. Customizing attribute access `_ + - Not Supported + - + * - `3.3.2.1. Customizing module attribute access `_ + - Not Supported + - + * - `3.3.2.2. Implementing Descriptors `_ + - Not Supported + - + * - `3.3.2.3. Invoking Descriptors `_ + - Not Supported + - + * - `3.3.2.4. __slots__ `_ + - Not Supported + - + * - `3.3.2.4.1. Notes on using __slots__ `_ + - Not Supported + - + * - `3.3.3. Customizing class creation `_ + - Not Supported + - + * - `3.3.3.1. Metaclasses `_ + - Not Supported + - + * - `3.3.3.2. Resolving MRO entries `_ + - Not Supported + - ``super()`` is not supported + * - `3.3.3.3. Determining the appropriate metaclass `_ + - Not relevant + - + * - `3.3.3.4. Preparing the class namespace `_ + - Not relevant + - + * - `3.3.3.5. Executing the class body `_ + - Not relevant + - + * - `3.3.3.6. Creating the class object `_ + - Not relevant + - + * - `3.3.3.7. Uses for metaclasses `_ + - Not relevant + - + * - `3.3.4. Customizing instance and subclass checks `_ + - Not Supported + - + * - `3.3.5. Emulating generic types `_ + - Not Supported + - + * - `3.3.6. Emulating callable objects `_ + - Supported + - + * - `3.3.7. Emulating container types `_ + - Partially Supported + - Some magic methods not supported (e.g. ``__iter__`` ) + * - `3.3.8. Emulating numeric types `_ + - Partially Supported + - Magic methods with swapped operands not supported (``__r*__``) + * - `3.3.9. With Statement Context Managers `_ + - Not Supported + - + * - `3.3.10. Special method lookup `_ + - Not relevant + - + * - `3.4. Coroutines `_ + - Not Supported + - + * - `3.4.1. Awaitable Objects `_ + - Not Supported + - + * - `3.4.2. Coroutine Objects `_ + - Not Supported + - + * - `3.4.3. Asynchronous Iterators `_ + - Not Supported + - + * - `3.4.4. Asynchronous Context Managers `_ + - Not Supported + - + * - `4. Execution model `_ + - Not Relevant + - + * - `4.1. Structure of a program `_ + - Not Relevant + - + * - `4.2. Naming and binding `_ + - Not Relevant + - Names are bound at compile time in TorchScript + * - `4.2.1. Binding of names `_ + - Not Relevant + - See ``global`` and ``nonlocal`` statements section + * - `4.2.2. Resolution of names `_ + - Not Relevant + - See ``global`` and ``nonlocal`` statements section + * - `4.2.3. Builtins and restricted execution `_ + - Not Relevant + - + * - `4.2.4. Interaction with dynamic features `_ + - Not Supported + - Python values cannot be captured + * - `4.3. Exceptions `_ + - Partially Supported + - See ``try`` and ``raise`` statement section + * - `5. The import system `_ + - Not Relevant + - + * - `6. Expressions `_ + - Not Relevant + - See expressions section + * - `6.1. Arithmetic conversions `_ + - Supported + - + * - `6.2. Atoms `_ + - Not Relevant + - + * - `6.2.1. Identifiers (Names) `_ + - Supported + - + * - `6.2.2. Literals `_ + - Partially Supported + - ``bytesliteral``\ , ``imagnumber`` not supported + * - `6.2.3. Parenthesized forms `_ + - Supported + - + * - `6.2.4. Displays for lists, sets and dictionaries `_ + - Partially Supported + - Not supported: comprehension ifs, async iterators + * - `6.2.5. List displays `_ + - Supported + - + * - `6.2.6. Set displays `_ + - Not Supported + - + * - `6.2.7. Dictionary displays `_ + - Supported + - dict() constructor with kwargs doesn't work, dict comprehensions, dictionary unpacking + * - `6.2.8. Generator expressions `_ + - Not Supported + - + * - `6.2.9. Yield expressions `_ + - Not Supported + - + * - `6.2.9.1. Generator-iterator methods `_ + - Not Supported + - + * - `6.2.9.2. Examples `_ + - Not Supported + - + * - `6.2.9.3. Asynchronous generator functions `_ + - Not Supported + - + * - `6.2.9.4. Asynchronous generator-iterator methods `_ + - Not Supported + - + * - `6.3. Primaries `_ + - Supported + - + * - `6.3.1. Attribute references `_ + - Supported + - + * - `6.3.2. Subscriptions `_ + - Supported + - + * - `6.3.3. Slicings `_ + - Partially Supported + - Tuple slicing with stride is not supported + * - `6.3.4. Calls `_ + - Partially Supported + - Args unpack / kwargs unpack is not supported + * - `6.4. Await expression `_ + - Not Supported + - + * - `6.5. The power operator `_ + - Supported + - + * - `6.6. Unary arithmetic and bitwise operations `_ + - Partially Supported + - Some bitwise operators are not implemented for primitive types (e.g. ``~x`` where ``x`` is an ``int`` is not currently supported) + * - `6.7. Binary arithmetic operations `_ + - Partially Supported + - See delimiters section + * - `6.8. Shifting operations `_ + - Not Supported + - + * - `6.9. Binary bitwise operations `_ + - Supported + - + * - `6.10. Comparisons `_ + - Supported + - + * - `6.10.1. Value comparisons `_ + - Partially Supported + - Dictionary equality checks are not currently supported + * - `6.10.2. Membership test operations `_ + - Partially Supported + - Not supported for TorchScript classes + * - `6.10.3. Identity comparisons `_ + - Supported + - + * - `6.11. Boolean operations `_ + - Supported + - + * - `6.12. Conditional expressions `_ + - Supported + - + * - `6.13. Lambdas `_ + - Not Supported + - + * - `6.14. Expression lists `_ + - Partially Supported + - Iterable unpacking not supported + * - `6.15. Evaluation order `_ + - Supported + - + * - `6.16. Operator precedence `_ + - Supported + - + * - `7. Simple statements `_ + - Supported + - + * - `7.1. Expression statements `_ + - Supported + - + * - `7.2. Assignment statements `_ + - Supported + - + * - `7.2.1. Augmented assignment statements `_ + - Partially Supported + - See delimiters section + * - `7.2.2. Annotated assignment statements `_ + - Supported + - + * - `7.3. The assert statement `_ + - Partially Supported + - Exception message is not customizable + * - `7.4. The pass statement `_ + - Supported + - + * - `7.5. The del statement `_ + - Not Supported + - + * - `7.6. The return statement `_ + - Supported + - Some other features of returning (e.g. behavior with try..finally) are unsupported + * - `7.7. The yield statement `_ + - Not Supported + - + * - `7.8. The raise statement `_ + - Partially Supported + - Exception message is not customizable + * - `7.9. The break statement `_ + - Supported + - Some other features of returning (e.g. behavior with try..finally) are unsupported + * - `7.10. The continue statement `_ + - Supported + - Some other features of returning (e.g. behavior with try..finally) are unsupported + * - `7.11. The import statement `_ + - Not Supported + - + * - `7.11.1. Future statements `_ + - Not Supported + - + * - `7.12. The global statement `_ + - Not Supported + - + * - `7.13. The nonlocal statement `_ + - Not Supported + - + * - `8. Compound statements `_ + - Irrelevant + - + * - `8.1. The if statement `_ + - Supported + - + * - `8.2. The while statement `_ + - Partially Supported + - while..else is not supported + * - `8.3. The for statement `_ + - Partially Supported + - for..else is not supported + * - `8.4. The try statement `_ + - Not Supported + - + * - `8.5. The with statement `_ + - Partially Supported + - ``__exit__`` is always called with ``exc_type``, ``exc_value``, and ``traceback`` set to None, even if an exception was raised, and ``__exit__``'s return value is ignored. + * - `8.6. Function definitions `_ + - Not Supported + - + * - `8.7. Class definitions `_ + - Not Supported + - + * - `8.8. Coroutines `_ + - Not Supported + - + * - `8.8.1. Coroutine function definition `_ + - Not Supported + - + * - `8.8.2. The async for statement `_ + - Not Supported + - + * - `8.8.3. The async with statement `_ + - Not Supported + - + * - `9. Top-level components `_ + - Not Relevant + - + * - `9.1. Complete Python programs `_ + - Not Relevant + - + * - `9.2. File input `_ + - Not Relevant + - + * - `9.3. Interactive input `_ + - Not Relevant + - + * - `9.4. Expression input `_ + - Not Relevant + - diff --git a/docs/source/jit_unsupported.rst b/docs/source/jit_unsupported.rst new file mode 100644 index 0000000000000..60bca7d6d92c6 --- /dev/null +++ b/docs/source/jit_unsupported.rst @@ -0,0 +1,90 @@ +.. _jit_unsupported: + +TorchScript Unsupported PyTorch Constructs +============================================ + +Torch and Tensor Unsupported Attributes +------------------------------------------ + + +TorchScript supports most methods defined on ``torch`` and ``torch.Tensor``, but we do not have full coverage. +Here are specific known ops and categories of ops which have diverging behavior between +Python and TorchScript. If you encounter something else that is not supported please +file a GitHub issue. Deprecated ops are not listed below. + + + +.. automodule:: torch.jit.unsupported_tensor_ops + + +Functions Not Correctly Bound on Torch +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The following functions will fail if used in TorchScript, either because they +are not bound on `torch` or because Python expects a different schema than +TorchScript. + + * :func:`torch.tensordot` + * :func:`torch.nn.init.calculate_gain` + * :func:`torch.nn.init.eye_` + * :func:`torch.nn.init.dirac_` + * :func:`torch.nn.init.kaiming_normal_` + * :func:`torch.nn.init.orthogonal_` + * :func:`torch.nn.init.sparse` + + +Ops With Divergent Schemas Between Torch & Python +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The following categories of ops have divergent schemas: + +Functions which construct tensors from non-tensor inputs do not support the `requires_grad` +argument, except for `torch.tensor`. This covers the following ops: + + * :func:`torch.norm` + * :func:`torch.bartlett_window` + * :func:`torch.blackman_window` + * :func:`torch.empty` + * :func:`torch.empty_like` + * :func:`torch.empty_strided` + * :func:`torch.eye` + * :func:`torch.full` + * :func:`torch.full_like` + * :func:`torch.hamming_window` + * :func:`torch.hann_window` + * :func:`torch.linspace` + * :func:`torch.logspace` + * :func:`torch.normal` + * :func:`torch.ones` + * :func:`torch.rand` + * :func:`torch.rand_like` + * :func:`torch.randint_like` + * :func:`torch.randn` + * :func:`torch.randn_like` + * :func:`torch.randperm` + * :func:`torch.tril_indices` + * :func:`torch.triu_indices` + * :func:`torch.vander` + * :func:`torch.zeros` + * :func:`torch.zeros_like` + +The following functions require `dtype`, `layout`, `device` as parameters in TorchScript, +but these parameters are optional in Python. + + * :func:`torch.randint` + * :func:`torch.sparse_coo_tensor` + * :meth:`~torch.Tensor.to` + + +PyTorch Unsupported Modules and Classes +------------------------------------------ + +TorchScript cannot currently compile a number of other commonly used PyTorch +constructs. Below are listed the modules that TorchScript does not support, and +an incomplete list of PyTorch classes that are not supported. For unsupported modules +we suggest using :meth:`torch.jit.trace`. + + * :class:`torch.nn.RNN` + * :class:`torch.nn.AdaptiveLogSoftmaxWithLoss` + * :class:`torch.autograd.Function` + * :class:`torch.autograd.enable_grad` diff --git a/docs/source/jit_utils.rst b/docs/source/jit_utils.rst new file mode 100644 index 0000000000000..abc4235912321 --- /dev/null +++ b/docs/source/jit_utils.rst @@ -0,0 +1,4 @@ +JIT Utils - torch.utils.jit +================================================== + +.. automodule:: torch.utils.jit diff --git a/docs/source/library.rst b/docs/source/library.rst new file mode 100644 index 0000000000000..6cefdf1eb10c7 --- /dev/null +++ b/docs/source/library.rst @@ -0,0 +1,80 @@ +.. _torch-library-docs: + +torch.library +=================================== +.. py:module:: torch.library +.. currentmodule:: torch.library + +torch.library is a collection of APIs for extending PyTorch's core library +of operators. It contains utilities for testing custom operators, creating new +custom operators, and extending operators defined with PyTorch's C++ operator +registration APIs (e.g. aten operators). + +For a detailed guide on effectively using these APIs, please see +`PyTorch Custom Operators Landing Page `_ +for more details on how to effectively use these APIs. + +Testing custom ops +------------------ + +Use :func:`torch.library.opcheck` to test custom ops for incorrect usage of the +Python torch.library and/or C++ TORCH_LIBRARY APIs. Also, if your operator supports +training, use :func:`torch.autograd.gradcheck` to test that the gradients are +mathematically correct. + +.. autofunction:: opcheck + +Creating new custom ops in Python +--------------------------------- + +Use :func:`torch.library.custom_op` to create new custom ops. + +.. autofunction:: custom_op +.. autofunction:: triton_op +.. autofunction:: wrap_triton + +Extending custom ops (created from Python or C++) +------------------------------------------------- + +Use the register.* methods, such as :func:`torch.library.register_kernel` and +:func:`torch.library.register_fake`, to add implementations +for any operators (they may have been created using :func:`torch.library.custom_op` or +via PyTorch's C++ operator registration APIs). + +.. autofunction:: register_kernel +.. autofunction:: register_autocast +.. autofunction:: register_autograd +.. autofunction:: register_fake +.. autofunction:: register_vmap +.. autofunction:: impl_abstract +.. autofunction:: get_ctx +.. autofunction:: register_torch_dispatch +.. autofunction:: infer_schema +.. autoclass:: torch._library.custom_ops.CustomOpDef + + .. automethod:: set_kernel_enabled + + +Low-level APIs +-------------- + +The following APIs are direct bindings to PyTorch's C++ low-level +operator registration APIs. + +.. warning:: + The low-level operator registration APIs and the PyTorch Dispatcher are a + complicated PyTorch concept. We recommend you use the higher level APIs above + (that do not require a torch.library.Library object) when possible. + `This blog post `_ + is a good starting point to learn about the PyTorch Dispatcher. + +A tutorial that walks you through some examples on how to use this API is available on `Google Colab `_. + +.. autoclass:: torch.library.Library + :members: + +.. autofunction:: fallthrough_kernel + +.. autofunction:: define + +.. autofunction:: impl diff --git a/docs/source/mtia.memory.md b/docs/source/mtia.memory.md index 43187c6f2e033..47826619d7e9e 100644 --- a/docs/source/mtia.memory.md +++ b/docs/source/mtia.memory.md @@ -16,5 +16,8 @@ The MTIA backend is implemented out of the tree, only interfaces are be defined :nosignatures: memory_stats +<<<<<<< HEAD memory_allocated +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ``` diff --git a/docs/source/nn.attention.flex_attention.md b/docs/source/nn.attention.flex_attention.md index 4cfb51c5945c0..a9d4a5e4fac2e 100644 --- a/docs/source/nn.attention.flex_attention.md +++ b/docs/source/nn.attention.flex_attention.md @@ -14,12 +14,15 @@ ```{eval-rst} .. autofunction:: flex_attention ``` +<<<<<<< HEAD ```{eval-rst} .. autoclass:: AuxOutput ``` ```{eval-rst} .. autoclass:: AuxRequest ``` +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ## BlockMask Utilities @@ -42,6 +45,7 @@ .. autofunction:: noop_mask ``` +<<<<<<< HEAD ## FlexKernelOptions ```{eval-rst} @@ -50,6 +54,8 @@ :undoc-members: ``` +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ## BlockMask ```{eval-rst} diff --git a/docs/source/nn.rst b/docs/source/nn.rst index 5d15e90a55499..fff24350e52f0 100644 --- a/docs/source/nn.rst +++ b/docs/source/nn.rst @@ -1,6 +1,7 @@ .. role:: hidden :class: hidden-section +<<<<<<< HEAD .. toctree:: :maxdepth: 2 :hidden: @@ -10,6 +11,12 @@ torch.nn =================================== .. automodule:: torch.nn +======= +torch.nn +=================================== +.. automodule:: torch.nn +.. automodule:: torch.nn.modules +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) These are the basic building blocks for graphs: @@ -480,8 +487,11 @@ for more information on how to implement your own parametrizations. parametrize.remove_parametrizations parametrize.cached parametrize.is_parametrized +<<<<<<< HEAD parametrize.transfer_parametrizations_and_params parametrize.type_before_parametrizations +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .. autosummary:: :toctree: generated @@ -512,17 +522,25 @@ Utility functions in other modules nn.utils.rnn.pack_sequence nn.utils.rnn.unpack_sequence nn.utils.rnn.unpad_sequence +<<<<<<< HEAD nn.utils.rnn.invert_permutation nn.parameter.is_lazy nn.factory_kwargs +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .. autosummary:: :toctree: generated :nosignatures: :template: classtemplate.rst +<<<<<<< HEAD nn.modules.flatten.Flatten nn.modules.flatten.Unflatten +======= + nn.Flatten + nn.Unflatten +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Quantized Functions -------------------- @@ -541,6 +559,21 @@ Lazy Modules Initialization nn.modules.lazy.LazyModuleMixin +<<<<<<< HEAD +======= +Aliases +_______ + +The following are aliases to their counterparts in ``torch.nn``: + +.. currentmodule:: torch +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + nn.modules.normalization.RMSNorm +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .. This module needs to be documented. Adding here in the meantime .. for tracking purposes diff --git a/docs/source/notes/cpu_threading_runtimes.svg b/docs/source/notes/cpu_threading_runtimes.svg new file mode 100644 index 0000000000000..e36ec598f063c --- /dev/null +++ b/docs/source/notes/cpu_threading_runtimes.svg @@ -0,0 +1,208 @@ + +image/svg+xml0102030400.51.01.52.02.5# ThreadsTime, s diff --git a/docs/source/notes/cpu_threading_torchscript_inference.rst b/docs/source/notes/cpu_threading_torchscript_inference.rst index 8cac34c8c36fd..634ea6b309c5b 100644 --- a/docs/source/notes/cpu_threading_torchscript_inference.rst +++ b/docs/source/notes/cpu_threading_torchscript_inference.rst @@ -3,6 +3,166 @@ CPU threading and TorchScript inference ================================================= +<<<<<<< HEAD .. warning:: TorchScript is deprecated, please use `torch.export `__ instead. +======= +PyTorch allows using multiple CPU threads during TorchScript model inference. +The following figure shows different levels of parallelism one would find in a +typical application: + +.. image:: cpu_threading_torchscript_inference.svg + :width: 75% + +One or more inference threads execute a model's forward pass on the given inputs. +Each inference thread invokes a JIT interpreter that executes the ops +of a model inline, one by one. A model can utilize a ``fork`` TorchScript +primitive to launch an asynchronous task. Forking several operations at once +results in a task that is executed in parallel. The ``fork`` operator returns a +``Future`` object which can be used to synchronize on later, for example: + +.. code-block:: python + + @torch.jit.script + def compute_z(x): + return torch.mm(x, self.w_z) + + @torch.jit.script + def forward(x): + # launch compute_z asynchronously: + fut = torch.jit._fork(compute_z, x) + # execute the next operation in parallel to compute_z: + y = torch.mm(x, self.w_y) + # wait for the result of compute_z: + z = torch.jit._wait(fut) + return y + z + + +PyTorch uses a single thread pool for the inter-op parallelism, this thread pool +is shared by all inference tasks that are forked within the application process. + +In addition to the inter-op parallelism, PyTorch can also utilize multiple threads +within the ops (`intra-op parallelism`). This can be useful in many cases, +including element-wise ops on large tensors, convolutions, GEMMs, embedding +lookups and others. + + +Build options +------------- + +PyTorch uses an internal ATen library to implement ops. In addition to that, +PyTorch can also be built with support of external libraries, such as MKL_ and MKL-DNN_, +to speed up computations on CPU. + +ATen, MKL and MKL-DNN support intra-op parallelism and depend on the +following parallelization libraries to implement it: + +* OpenMP_ - a standard (and a library, usually shipped with a compiler), widely used in external libraries; +* TBB_ - a newer parallelization library optimized for task-based parallelism and concurrent environments. + +OpenMP historically has been used by a large number of libraries. It is known +for a relative ease of use and support for loop-based parallelism and other primitives. + +TBB is used to a lesser extent in external libraries, but, at the same time, +is optimized for the concurrent environments. PyTorch's TBB backend guarantees that +there's a separate, single, per-process intra-op thread pool used by all of the +ops running in the application. + +Depending of the use case, one might find one or another parallelization +library a better choice in their application. + +PyTorch allows selecting of the parallelization backend used by ATen and other +libraries at the build time with the following build options: + ++------------+------------------------+-----------------------------+----------------------------------------+ +| Library | Build Option | Values | Notes | ++============+========================+=============================+========================================+ +| ATen | ``ATEN_THREADING`` | ``OMP`` (default), ``TBB`` | | ++------------+------------------------+-----------------------------+----------------------------------------+ +| MKL | ``MKL_THREADING`` | (same) | To enable MKL use ``BLAS=MKL`` | ++------------+------------------------+-----------------------------+----------------------------------------+ +| MKL-DNN | ``MKLDNN_CPU_RUNTIME`` | (same) | To enable MKL-DNN use ``USE_MKLDNN=1`` | ++------------+------------------------+-----------------------------+----------------------------------------+ + +It is recommended not to mix OpenMP and TBB within one build. + +Any of the ``TBB`` values above require ``USE_TBB=1`` build setting (default: OFF). +A separate setting ``USE_OPENMP=1`` (default: ON) is required for OpenMP parallelism. + +Runtime API +----------- + +The following API is used to control thread settings: + ++------------------------+-----------------------------------------------------------+---------------------------------------------------------+ +| Type of parallelism | Settings | Notes | ++========================+===========================================================+=========================================================+ +| Inter-op parallelism | ``at::set_num_interop_threads``, | Default number of threads: number of CPU cores. | +| | ``at::get_num_interop_threads`` (C++) | | +| | | | +| | ``set_num_interop_threads``, | | +| | ``get_num_interop_threads`` (Python, :mod:`torch` module) | | ++------------------------+-----------------------------------------------------------+ | +| Intra-op parallelism | ``at::set_num_threads``, | | +| | ``at::get_num_threads`` (C++) | | +| | ``set_num_threads``, | | +| | ``get_num_threads`` (Python, :mod:`torch` module) | | +| | | | +| | Environment variables: | | +| | ``OMP_NUM_THREADS`` and ``MKL_NUM_THREADS`` | | ++------------------------+-----------------------------------------------------------+---------------------------------------------------------+ + +For the intra-op parallelism settings, ``at::set_num_threads``, ``torch.set_num_threads`` always take precedence +over environment variables, ``MKL_NUM_THREADS`` variable takes precedence over ``OMP_NUM_THREADS``. + +Tuning the number of threads +---------------------------- + +The following simple script shows how a runtime of matrix multiplication changes with the number of threads: + +.. code-block:: python + + import timeit + runtimes = [] + threads = [1] + [t for t in range(2, 49, 2)] + for t in threads: + torch.set_num_threads(t) + r = timeit.timeit(setup = "import torch; x = torch.randn(1024, 1024); y = torch.randn(1024, 1024)", stmt="torch.mm(x, y)", number=100) + runtimes.append(r) + # ... plotting (threads, runtimes) ... + +Running the script on a system with 24 physical CPU cores (Xeon E5-2680, MKL and OpenMP based build) results in the following runtimes: + +.. image:: cpu_threading_runtimes.svg + :width: 75% + +The following considerations should be taken into account when tuning the number of intra- and inter-op threads: + +* When choosing the number of threads one needs to avoid `oversubscription` (using too many threads, leads to performance degradation). For example, in an application that uses a large application thread pool or heavily relies on + inter-op parallelism, one might find disabling intra-op parallelism as a possible option (i.e. by calling ``set_num_threads(1)``); + +* In a typical application one might encounter a trade off between `latency` (time spent on processing an inference request) and `throughput` (amount of work done per unit of time). Tuning the number of threads can be a useful + tool to adjust this trade off in one way or another. For example, in latency critical applications one might want to increase the number of intra-op threads to process each request as fast as possible. At the same time, parallel implementations + of ops may add an extra overhead that increases amount work done per single request and thus reduces the overall throughput. + +.. warning:: + OpenMP does not guarantee that a single per-process intra-op thread + pool is going to be used in the application. On the contrary, two different application or inter-op + threads may use different OpenMP thread pools for intra-op work. + This might result in a large number of threads used by the application. + Extra care in tuning the number of threads is needed to avoid + oversubscription in multi-threaded applications in OpenMP case. + +.. note:: + Pre-built PyTorch releases are compiled with OpenMP support. + +.. note:: + ``parallel_info`` utility prints information about thread settings and can be used for debugging. + Similar output can be also obtained in Python with ``torch.__config__.parallel_info()`` call. + +.. _OpenMP: https://www.openmp.org/ +.. _TBB: https://github.com/intel/tbb +.. _MKL: https://software.intel.com/en-us/mkl +.. _MKL-DNN: https://github.com/intel/mkl-dnn +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/docs/source/notes/cpu_threading_torchscript_inference.svg b/docs/source/notes/cpu_threading_torchscript_inference.svg new file mode 100644 index 0000000000000..f09884cc5f274 --- /dev/null +++ b/docs/source/notes/cpu_threading_torchscript_inference.svg @@ -0,0 +1,681 @@ + +image/svg+xml… +Inputs +Application Thread Pool +… +Op +Op +Op +Inference thread +Fork +Op +Join +… +… +Inter +- +op parallelism +Intra +- +op parallelism +• +ATen/Parallel +(e.g. at::parallel_for) +• +MKL +• +MKL +- +DNN +• +... +OpenMP +TBB +… + diff --git a/docs/source/notes/cuda.rst b/docs/source/notes/cuda.rst index 8981ac1bf6ed4..705e876f776f9 100644 --- a/docs/source/notes/cuda.rst +++ b/docs/source/notes/cuda.rst @@ -64,6 +64,7 @@ Below you can find a small example showcasing this:: TensorFloat-32 (TF32) on Ampere (and later) devices --------------------------------------------------- +<<<<<<< HEAD After Pytorch 2.9, we provide a new sets of APIs to control the TF32 behavior in a more fine-grained way, and suggest to use the new APIs for better control. We can set float32 precision per backend and per operators. We can also override the global setting for a specific operator. @@ -107,6 +108,8 @@ We suggest to use the new settings for better control. And we do not support to Old settings with `allow_tf32` as follows is going to be deprecated. We suggest to use the above new settings for better control. And we do not support to use mix of old and new settings. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Starting in PyTorch 1.7, there is a new flag called `allow_tf32`. This flag defaults to True in PyTorch 1.7 to PyTorch 1.11, and False in PyTorch 1.12 and later. This flag controls whether PyTorch is allowed to use the TensorFloat32 (TF32) tensor cores, @@ -128,7 +131,11 @@ matmuls and convolutions are controlled separately, and their corresponding flag # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True. torch.backends.cudnn.allow_tf32 = True +<<<<<<< HEAD The precision of matmuls can also be set more broadly (limited not just to CUDA) via :meth:`~torch.set_float32_matmul_precision`. +======= +The precision of matmuls can also be set more broadly (limited not just to CUDA) via :meth:`~torch.set_float_32_matmul_precision`. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Note that besides matmuls and convolutions themselves, functions and nn modules that internally uses matmuls or convolutions are also affected. These include `nn.Linear`, `nn.Conv*`, cdist, tensordot, affine grid and grid sample, adaptive log softmax, GRU and LSTM. @@ -608,6 +615,7 @@ Available options: for processing events. This avoids any slow path associated with querying/processing of events in the fast allocation path. This feature is disabled by default. +<<<<<<< HEAD * ``graph_capture_record_stream_reuse`` (experimental, default: `False`) If set to `True`, the CUDA caching allocator will attempt to reclaim device memory during CUDA Graph capture by using the graph topology (instead of CUDA events) to determine @@ -616,6 +624,8 @@ Available options: reaches joined frontiers. Note: Enabling this option can significantly increase the time spent capturing the graph. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .. note:: Some stats reported by the @@ -904,6 +914,7 @@ APIs can be used for debugging purposes: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/bufferreg.html#memory-allocator +<<<<<<< HEAD Tuning NVLink Performance with Custom Memory Allocator on H100/H200 GPUs ------------------------------------------------------------------------ In rare cases, performance of NVLink on H100/H200 GPUs can be influenced by the physical memory @@ -1028,6 +1039,8 @@ functions are: } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cuBLAS workspaces ----------------- diff --git a/docs/source/notes/fsdp.rst b/docs/source/notes/fsdp.rst new file mode 100644 index 0000000000000..ce713fc1697f7 --- /dev/null +++ b/docs/source/notes/fsdp.rst @@ -0,0 +1,148 @@ +.. _fsdp_notes: + +FSDP Notes +========== + +.. _fsdp_prefetch: + +FSDP Prefetch Nuances +--------------------- + +For overlapping ``forward`` all-gathers with ``forward`` compute, there are two possible mechanisms: + +1. Implicit forward prefetching (always enabled) +2. Explicit forward prefetching (``forward_prefetch=True``) + +Implicit ``forward`` prefetching refers to relying on issuing the all-gathers from a separate CUDA +stream to allow for overlapping an all-gather with ``forward`` compute issued before it (from the CPU +perspective). For example, if we have layer 0 all-gather -> layer 0 ``forward`` compute -> layer 1 +all-gather -> …, then layer 1 all-gather can overlap with layer 0 ``forward`` compute even though the +CPU thread issued it afterwards. (The 1st all-gather will not be able to overlap with anything.) + +Explicit ``forward`` prefetching refers to changing the CPU thread’s issue order: e.g. layer 0 +all-gather -> layer 1 all-gather -> layer 0 ``forward`` compute -> …. In eager mode, there is no way to +know in general which layer is the next layer (e.g. layer 1 in the example) when still executing on +layer 0. Therefore, explicit ``forward`` prefetching should only be used for models whose execution +order is fixed from iteration to iteration (which we sometimes call ā€œstatic graphā€). An example of a +model that does not satisfy this constraint is `FLAVA +`_). + +Explicit ``forward`` prefetching only saves the time taken to issue a layer’s ``forward`` compute kernels at +the cost that the next all-gather’s output tensor must be allocated while the current one is still +in use. By issuing the next all- gather before the current ``forward`` compute kernels, the next +all-gather can start sooner on GPU. For most LLM workloads, this is not the case, so there is no +motivation for enabling ``forward_prefetch=True``. + +In contrast, for ``backward``, we must use explicit ``backward`` prefetching or else there will be 0 overlap +of communication and computation. The reason is because we use a single NCCL process group for both +all-gather and reduce-scatter (partially because in earlier NCCL versions, it was not safe to use +multiple concurrently on the same device over the same ranks). A single NCCL process group means a +single internal NCCL stream on which reduce-scatters and all-gathers run serially. As such, unless +we explicitly reorder the CPU issue order to be next all-gather -> current reduce-scatter, then the +current reduce-scatter would block the next all-gather and hence the next ``backward`` computation, +preventing the current reduce-scatter from overlapping. + +.. _fsdp_comms_payload_size: + +Communication payload size +-------------------------- + +In FSDP the communications are: + +1. all-gather on parameters in ``forward`` +2. all-gather on parameters in ``backward`` +3. reduce-scatter on gradients in ``backward`` + +If activation checkpointing (:func:`~torch.utils.checkpoint.checkpoint`) is used there is no +additional communication since the parameters are prefetched anyway during ``backward``. + +In the FSDP design, the communication payload per rank is determined as follows: Each call to +:class:`FullyShardedDataParallel` creates one communication group consisting of the parameters in +``module.parameters()`` except any already assigned to a nested :class:`FullyShardedDataParallel` +instance. For example, for Llama, if you apply :class:`FullyShardedDataParallel` to every +transformer block and also to the root module, then there is one communication group for each +transformer block and finally one communication group with the initial embedding and final linear. +Each communication group corresponds to a single all-gather call and single reduce-scatter call. In +that way, how you apply :class:`FullyShardedDataParallel` determines the communication size. In +general, applying FSDP to each transformer block is a good heuristic for LLMs, and it is hard to do +better than that given the current design. + +Let's consider an example where we have a Transformer-based model sharded over 8 GPUs, where the +sharding happens at the transformer block-level only, and each transformer block contains 1.6B +parameters and the parameters are in fp32 (4 bytes each). Which means that once sharded, each +transformer block will contain 0.2B parameters on each rank. + +* The ``forward`` pass will communicate in chunks of ``0.2*4 = 0.8GB`` in all-gather +* The ``backward`` pass will communicate 2 times ``0.8GB`` each (1x all-gather and 1x reduce-scatter) + +In other words there will be 3 communications with a payload of ``0.8GB`` each. If the model was +comprised of 10 transformer blocks there would be a total of 30 communications for a total of +``30*0.8=24GB``. + +To formalize the payload size per communication per rank is +``total_transformer_block_params_in_B*dtype_bytes/num_gpus`` (GBs). + +Please note that in this example we didn't include the additional communications required for the +embedding, which should be accounted for as well. And the math would depend on whether the input and +output embeddings are tied or not. If they aren't tied there will be 2x more communications. + +.. _fsdp_buffers_sizes: + +FSDP buffers sizes +------------------ + +First, let's cover the buffers allocated for communications: + +``forward`` currently requires 2x all-gather buffer size. Here is why: + +As explained in :ref:`fsdp_prefetch` in the case of explicit ``forward`` prefetching +(``forward_prefetch=True``) case of layer 0 all-gather -> layer 0 forward compute -> layer 1 +all-gather there is a need for 2 all-gather-sized buffers, because one buffer is used in the current ``forward`` while the other is used to do the prefetching. + +While the implicit ``forward`` prefetching (``forward_prefetch=False``, default) case of the same sequence in theory should need only 1 buffer, in reality it's still 2x all-gather-sized buffers. The reason is that in the flat-parameter FSDP design, we do not copy-out of the all-gather buffer. The parameters used for compute are directly viewed into the all-gather buffer (in fact, the main benefit of the "flat parameter" is exactly this reason). In that case, while 'layer 1 all-gather' is overlapping with 'layer 0 forward compute', the 'layer 0 forward compute' is using the parameters viewed into the 'layer 0 all-gather' buffer. + +A natural question then is, when would you want ``forward_prefetch=False``? For static-graph models (like most LLMs), there is a major technical reason. It is more that, practically, we added this option quickly for some CPU-bound internal models and have not tested every code path with it in unit testing, so we are less confident in it. ``forward_prefetching=False`` can be slightly easier to reason about since we do not have to check the recorded forward order as a possible 'failure mode'; a module's all-gather can always be found under its own ``record_function`` label in its profiler trace. + +``backward`` currently requires at least 2x all-gather buffer size and potentially a bit more. Here is why: + +The current FSDP design uses ``recordStream`` to manage allocations produced in one stream consumed in another, which can lead to more memory usage than expected. How much more can be "non-deterministic" in that it depends on GPU kernel timing relative to the CPU. The ``limit_all_gathers=True`` argument is a mitigation to that - for more details refer to this discussion is `FSDP & CUDACachingAllocator `_. + +The way existing FSDP works with autograd: + +* Existing FSDP all-gathers the ``flat_param``, which is the autograd leaf. +* It calls ``torch.split`` to get 1D views into the ``flat_param`` corresponding to its constituent original parameters. +* It calls ``torch.view`` on each 1D split to view back to ND. +* This means that in ``backward``, we end up with ``ViewBackward`` (ND -> 1D) and ``SplitWithSizesBackward`` (which is a concat). In particular, each individual gradient is computed as a separate allocation, and an explicit concat happens to construct the reduce-scatter input buffer. This implies actually a 2x buffer size for reduce-scatter at that peak memory point. + +In summary, for ``backward``, it is about 2x buffer size for reduce-scatter plus any ``recordStream`` effects. + +Second, let's discuss the additional buffers: + +Once the sharded parameters are gathered from all ranks, they require an additional buffer of `total_transformer_block_params_in_B*dtype_bytes` for the full parameters - so continuing the example from earlier if each transformer block is 1.6B parameters and the parameters are in fp32, then it'd be `1.6*4=6.4GB` buffer. + +And there is a need for 2 of those buffers, since there is one currently being used and another being prefetched. + +To summarize, we have: + +1. 2 times communication buffers of ``total_transformer_block_params_in_B*dtype_bytes/num_gpus`` +2. 2 times unsharded transformer block parameters buffer ````total_transformer_block_params_in_B*dtype_bytes`` + +or if you have been following the example: + +1. ``2*1.6*4/8=1.6GB`` +2. ``2**1.6*4=12.8GB`` + +and the total of ``14.4GB``. + +Now let's briefly discuss what happens to the embeddings as we have left those out from the calculations: + +Given the rule we discussed that you included in the note starting with "the communication buffer +size is determined as follows", we can analyze as follows: + +* Suppose we apply FSDP to the root module (e.g. the ``Transformer`` class). Suppose we further apply FSDP to each transformer block (e.g. the ``TransformerBlock`` class). +* Most commonly, the embedding and final linear projection are direct children of the root ``Transformer`` class. +* Following our rule, that means that the embedding and final linear projection are assigned to the root ``Transformer``'s flat parameter. +* We have _another_ special rule, which is that the root does not free its parameters after forward because they will be anyways immediately all-gathered in backward. +* Putting this together, this means that the root's flat parameter including the embedding and final projection are all-gathered to begin forward and kept in GPU memory until the end of backward. +* If the embedding and final linear are not weight-tied, then we _could_ further apply FSDP to the embedding and to the final linear. For weight-tied parameters, we require them to be part of the same flat parameter (or else it would get double-counted). That would allow the embedding to be freed after its usage in forward and only all-gathered toward the end of backward. +* Hopefully, this gives a better sense -- each FSDP module gets assigned parameters in its ``module.parameters`` except those already assigned to another nested FSDP module, and the FSDP module's ``forward`` defines the 'live' interval for its parameters. Hence, the nested ``nn.Module`` structure can affect the all-gather/free schedule and hence the memory/throughput performance. diff --git a/docs/source/notes/get_start_xpu.rst b/docs/source/notes/get_start_xpu.rst index 57cb47bd840d4..ea68ddb0264a8 100644 --- a/docs/source/notes/get_start_xpu.rst +++ b/docs/source/notes/get_start_xpu.rst @@ -24,12 +24,24 @@ For Intel Client GPU +-------------------------------------+----------------------------------------------------------------------------------------------------+ | Supported OS | Validated Hardware | +=====================================+====================================================================================================+ +<<<<<<< HEAD || Windows 11 & Ubuntu 24.04/25.04 || IntelĀ® Arc A-Series Graphics (CodeName: Alchemist) | +======= +|| Windows 11 & Ubuntu 24.10 || IntelĀ® Arc A-Series Graphics (CodeName: Alchemist) | +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) || || IntelĀ® Arc B-Series Graphics (CodeName: Battlemage) | || || IntelĀ® Coreā„¢ Ultra Processors with IntelĀ® Arcā„¢ Graphics (CodeName: Meteor Lake-H) | || || IntelĀ® Coreā„¢ Ultra Desktop Processors (Series 2) with IntelĀ® Arcā„¢ Graphics (CodeName: Lunar Lake) | || || IntelĀ® Coreā„¢ Ultra Mobile Processors (Series 2) with IntelĀ® Arcā„¢ Graphics (CodeName: Arrow Lake-H)| +-------------------------------------+----------------------------------------------------------------------------------------------------+ +<<<<<<< HEAD +======= +|| Ubuntu 24.04 & WSL2 (Ubuntu 24.04) || IntelĀ® Arc A-Series Graphics (CodeName: Alchemist) | +|| || IntelĀ® Coreā„¢ Ultra Processors with IntelĀ® Arcā„¢ Graphics (CodeName: Meteor Lake-H) | +|| || IntelĀ® Coreā„¢ Ultra Desktop Processors (Series 2) with IntelĀ® Arcā„¢ Graphics (CodeName: Lunar Lake) | +|| || IntelĀ® Coreā„¢ Ultra Mobile Processors (Series 2) with IntelĀ® Arcā„¢ Graphics (CodeName: Arrow Lake-H)| ++-------------------------------------+----------------------------------------------------------------------------------------------------+ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Intel GPUs support (Prototype) is ready from PyTorch* 2.5 for IntelĀ® Client GPUs and IntelĀ® Data Center GPU Max Series on both Linux and Windows, which brings Intel GPUs and the SYCL* software stack into the official PyTorch stack with consistent user experience to embrace more AI application scenarios. @@ -102,7 +114,11 @@ If you are migrating code from ``cuda``, you would change references from ``cuda The following points outline the support and limitations for PyTorch with Intel GPU: #. Both training and inference workflows are supported. +<<<<<<< HEAD #. Both eager mode and ``torch.compile`` is supported. The feature ``torch.compile`` is also supported on Windows from PyTorch* 2.7 with Intel GPU, refer to `How to use torch.compile on Windows CPU/XPU `_. +======= +#. Both eager mode and ``torch.compile`` is supported. The feature ``torch.compile`` is also supported on Windows from PyTorch* 2.7 with Intel GPU, refer to `How to Use Inductor on Windows with CPU/XPU `_. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #. Data types such as FP32, BF16, FP16, and Automatic Mixed Precision (AMP) are all supported. Examples diff --git a/docs/source/notes/hip.rst b/docs/source/notes/hip.rst index 7ee596b53f9cc..423c3676ccfea 100644 --- a/docs/source/notes/hip.rst +++ b/docs/source/notes/hip.rst @@ -179,6 +179,7 @@ by recompiling the PyTorch from source. Please add below line as an argument to cmake command parameters:: -DROCM_FORCE_ENABLE_GPU_ASSERTS:BOOL=ON +<<<<<<< HEAD Enabling/Disabling ROCm Composable Kernel ----------------------------------------- @@ -206,3 +207,5 @@ To enable CK in either scenario, simply pass 'ck' to those functions. In order to set the backend to CK, the user MUST have built with the correct environment variable. If not, PyTorch will print a warning and use the "default" backend. For GEMMs, this will route to hipblas and for SDPA it routes to aotriton. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/docs/source/notes/large_scale_deployments.rst b/docs/source/notes/large_scale_deployments.rst index 27380a68cf338..6f0a495768410 100644 --- a/docs/source/notes/large_scale_deployments.rst +++ b/docs/source/notes/large_scale_deployments.rst @@ -7,6 +7,12 @@ This note talks about several extension points and tricks that might be useful when running PyTorch within a larger system or operating multiple systems using PyTorch in a larger organization. +<<<<<<< HEAD +======= +It doesn't cover topics of deploying models to production. Check +:mod:`torch.jit` or one of the corresponding tutorials. + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) The note assumes that you either build PyTorch from source in your organization or have an ability to statically link additional code to be loaded when PyTorch is used. Therefore, many of the hooks are exposed as C++ APIs that @@ -83,7 +89,12 @@ scripts, the callback fires only once for a given process for each of the APIs. ``c10::SetAPIUsageHandler`` can be used to register API usage instrumentation handler. Passed argument is going to be an "api key" identifying used point, for +<<<<<<< HEAD example ``python.import`` for PyTorch extension import. +======= +example ``python.import`` for PyTorch extension import or +``torch.script.compile`` if TorchScript compilation was triggered. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .. code-block:: cpp @@ -95,6 +106,45 @@ Note for developers: new API trigger points can be added in code with ``C10_LOG_API_USAGE_ONCE("my_api")`` in C++ or ``torch._C._log_api_usage_once("my.api")`` in Python. +<<<<<<< HEAD +======= +Attaching metadata to saved TorchScript models +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +TorchScript modules can be saved as an archive file that bundles serialized +parameters and module code as TorchScript (see :meth:`torch.jit.save`). It's +often convenient to bundle additional information together with the model, for +example, description of model producer or auxiliary artifacts. + +It can be achieved by passing the ``_extra_files`` argument to +:meth:`torch.jit.save` and ``torch::jit::load`` to store and retrieve +arbitrary binary blobs during saving process. Since TorchScript files are +regular ZIP archives, extra information gets stored as regular files inside +archive's ``extra/`` directory. + +There's also a global hook allowing to attach extra files to any TorchScript +archive produced in the current process. It might be useful to tag models with +producer metadata, akin to JPEG metadata produced by digital cameras. Example +usage might look like: + +.. code-block:: cpp + + SetExportModuleExtraFilesHook([](const Module&) { + ExtraFilesMap files; + files["producer_info.json"] = "{\"user\": \"" + getenv("USER") + "\"}"; + return files; + }); + + +Build environment considerations +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +TorchScript's compilation needs to have access to the original python files as +it uses python's ``inspect.getsource`` call. In certain production environments +it might require explicitly deploying ``.py`` files along with precompiled +``.pyc``. + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Common extension points ^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/source/notes/libtorch_stable_abi.md b/docs/source/notes/libtorch_stable_abi.md index 1180a85d0eaa9..cb21737089c2a 100644 --- a/docs/source/notes/libtorch_stable_abi.md +++ b/docs/source/notes/libtorch_stable_abi.md @@ -9,9 +9,14 @@ This note will eventually contain more details on how to use the APIs in torch/c | type in custom extension | StableIValue representation | type in libtorch | Schema Type | | -------- | ------- | ------- | ------- | | std::optional\ | if there is a value, raw bitwise copy into leading bytes of uint64_t of pointer to a new StableIValue representing S. if there is no value, nullptr. | std::optional\ | Type? | +<<<<<<< HEAD | torch::stable::Tensor | raw bitwise copy of underlying AtenTensorHandle into leading bytes of uint64_t | at::Tensor | Tensor | | RAIIATH (outdated) | raw bitwise copy of underlying AtenTensorHandle into leading bytes of uint64_t | at::Tensor | Tensor | | torch::headeronly::ScalarType | raw bitwise copy of the translated underlying enum into leading bytes of uint64_t | torch::headeronly::ScalarType | ScalarType | +======= +| RAIIATH | raw bitwise copy of underlying AtenTensorHandle into leading bytes of uint64_t | at::Tensor | Tensor | +| int32_t | raw bitwise copy into leading bytes of uint64_t | at::ScalarType | ScalarType | +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) | int32_t | raw bitwise copy into leading bytes of uint64_t | at::Layout | Layout | | int32_t | raw bitwise copy into leading bytes of uint64_t | at::MemoryFormat | MemoryFormat | | bool | raw bitwise copy into leading bytes of uint64_t | bool | bool | @@ -31,16 +36,26 @@ This note will eventually contain more details on how to use the APIs in torch/c | ? | ? | c10::SymBool | SymBool | | ? | ? | at::QScheme | QScheme | +<<<<<<< HEAD Our confidently supported types are the ones in the table that have completed rows. You can rely on this subset for proper ABI stability. For a limited set of use cases, we also implicitly support any literal type that is representable within 64 bits as StableIValues, as the default reinterpret_cast will succeed. (For example: c10::Device.) These types are currently ABI-stable on best effort but might break in the future and thus should be used for short term testing only. +======= +Our confidently supported types are the ones in the table that have completed rows. You can rely on this subset proper ABI stability. + +For a limited set of use cases, we also implicitly support any literal type that is representable within 64 bits as StableIValues, as the default reinterpret_cast will succeed. These types are currently ABI-stable on best effort but might break in the future and thus should be used for short term testing only. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) You can always work with StableIValue abstractions in your custom kernel for types such as c10::Device even if there is no standard defined representation of device in custom extensions by not introspecting into the StableIValue. For example, a custom operator can take as argument a StableIValue device and directly pass it through to an aten operator with `aoti_torch_call_dispatcher`. ## How to use stack-based APIs +<<<<<<< HEAD `aoti_torch_call_dispatcher` is what we consider a stack-based API because it takes as input a stack of StableIValues, which correlates with a `torch::jit::stack` of IValues. Working with the dispatcher will likely bring you into proximity with stack-based APIs, so we are documenting some invariants: +======= +`aoti_torch_call_dispatcher` is what we consider a stack-based API because it takes as input a stack of StableIValues. Working with the dispatcher will likely bring you into proximity with stack-based APIs, so we are documenting some invariants: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) 1. The stack is populated left to right. a. For example, a stack representing arguments `arg0`, `arg1`, and `arg2` will have `arg0` at index 0, `arg1` at index 1, and `arg2` at index 2. diff --git a/docs/source/notes/numerical_accuracy.rst b/docs/source/notes/numerical_accuracy.rst index 8944ecc05f277..13061b84e5733 100644 --- a/docs/source/notes/numerical_accuracy.rst +++ b/docs/source/notes/numerical_accuracy.rst @@ -93,8 +93,13 @@ On Ampere (and later) Nvidia GPUs, PyTorch can use TensorFloat32 (TF32) to speed When an operation is performed using TF32 tensor cores, only the first 10 bits of the input mantissa are read. This may reduce accuracy and produce surprising results (e.g., multiplying a matrix by the identity matrix may produce results that are different from the input). By default, TF32 tensor cores are disabled for matrix multiplications and enabled for convolutions, although most neural network workloads have the same convergence behavior when using TF32 as they have with fp32. +<<<<<<< HEAD We recommend enabling TF32 tensor cores for matrix multiplications with ``torch.backends.cuda.matmul.fp32_precision = "tf32"`` (```torch.backends.cuda.matmul.allow_tf32 = True`` is going to be deprecated) if your network does not need full float32 precision. If your network needs full float32 precision for both matrix multiplications and convolutions, then TF32 tensor cores can also be disabled for convolutions with ``torch.backends.cudnn.conv.fp32_precision = "ieee"`` (``torch.backends.cudnn.allow_tf32 = False`` is going to be deprecated). +======= +We recommend enabling TF32 tensor cores for matrix multiplications with ``torch.backends.cuda.matmul.allow_tf32 = True`` if your network does not need full float32 precision. +If your network needs full float32 precision for both matrix multiplications and convolutions, then TF32 tensor cores can also be disabled for convolutions with ``torch.backends.cudnn.allow_tf32 = False``. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) For more information see :ref:`TensorFloat32`. diff --git a/docs/source/notes/serialization.rst b/docs/source/notes/serialization.rst index 42997694f762b..2fcf52f624290 100644 --- a/docs/source/notes/serialization.rst +++ b/docs/source/notes/serialization.rst @@ -339,6 +339,175 @@ if one does not have access to the ``torch.load`` callsites. if ``weights_only`` was not passed as an argument. +<<<<<<< HEAD +======= +.. _serializing-python-modules: + +Serializing torch.nn.Modules and loading them in C++ +---------------------------------------------------- + +See also: `Tutorial: Loading a TorchScript Model in C++ `_ + +ScriptModules can be serialized as a TorchScript program and loaded +using :func:`torch.jit.load`. +This serialization encodes all the modules’ methods, submodules, parameters, +and attributes, and it allows the serialized program to be loaded in C++ +(i.e. without Python). + +The distinction between :func:`torch.jit.save` and :func:`torch.save` may not +be immediately clear. :func:`torch.save` saves Python objects with pickle. +This is especially useful for prototyping, researching, and training. +:func:`torch.jit.save`, on the other hand, serializes ScriptModules to a format +that can be loaded in Python or C++. This is useful when saving and loading C++ +modules or for running modules trained in Python with C++, a common practice +when deploying PyTorch models. + +To script, serialize and load a module in Python: + +:: + + >>> scripted_module = torch.jit.script(MyModule()) + >>> torch.jit.save(scripted_module, 'mymodule.pt') + >>> torch.jit.load('mymodule.pt') + RecursiveScriptModule( original_name=MyModule + (l0): RecursiveScriptModule(original_name=Linear) + (l1): RecursiveScriptModule(original_name=Linear) ) + + +Traced modules can also be saved with :func:`torch.jit.save`, with the caveat +that only the traced code path is serialized. The following example demonstrates +this: + +:: + + # A module with control flow + >>> class ControlFlowModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.l0 = torch.nn.Linear(4, 2) + self.l1 = torch.nn.Linear(2, 1) + + def forward(self, input): + if input.dim() > 1: + return torch.tensor(0) + + out0 = self.l0(input) + out0_relu = torch.nn.functional.relu(out0) + return self.l1(out0_relu) + + >>> traced_module = torch.jit.trace(ControlFlowModule(), torch.randn(4)) + >>> torch.jit.save(traced_module, 'controlflowmodule_traced.pt') + >>> loaded = torch.jit.load('controlflowmodule_traced.pt') + >>> loaded(torch.randn(2, 4))) + tensor([[-0.1571], [-0.3793]], grad_fn=) + + >>> scripted_module = torch.jit.script(ControlFlowModule(), torch.randn(4)) + >>> torch.jit.save(scripted_module, 'controlflowmodule_scripted.pt') + >>> loaded = torch.jit.load('controlflowmodule_scripted.pt') + >> loaded(torch.randn(2, 4)) + tensor(0) + +The above module has an if statement that is not triggered by the traced inputs, +and so is not part of the traced module and not serialized with it. +The scripted module, however, contains the if statement and is serialized with it. +See the `TorchScript documentation `_ +for more on scripting and tracing. + +Finally, to load the module in C++: + +:: + + >>> torch::jit::script::Module module; + >>> module = torch::jit::load('controlflowmodule_scripted.pt'); + +See the `PyTorch C++ API documentation `_ +for details about how to use PyTorch modules in C++. + +.. _saving-loading-across-versions: + +Saving and loading ScriptModules across PyTorch versions +----------------------------------------------------------- + +The PyTorch Team recommends saving and loading modules with the same version of +PyTorch. Older versions of PyTorch may not support newer modules, and newer +versions may have removed or modified older behavior. These changes are +explicitly described in +PyTorch’s `release notes `_, +and modules relying on functionality that has changed may need to be updated +to continue working properly. In limited cases, detailed below, PyTorch will +preserve the historic behavior of serialized ScriptModules so they do not require +an update. + +torch.div performing integer division +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In PyTorch 1.5 and earlier :func:`torch.div` would perform floor division when +given two integer inputs: + +:: + + # PyTorch 1.5 (and earlier) + >>> a = torch.tensor(5) + >>> b = torch.tensor(3) + >>> a / b + tensor(1) + +In PyTorch 1.7, however, :func:`torch.div` will always perform a true division +of its inputs, just like division in Python 3: + +:: + + # PyTorch 1.7 + >>> a = torch.tensor(5) + >>> b = torch.tensor(3) + >>> a / b + tensor(1.6667) + +The behavior of :func:`torch.div` is preserved in serialized ScriptModules. +That is, ScriptModules serialized with versions of PyTorch before 1.6 will continue +to see :func:`torch.div` perform floor division when given two integer inputs +even when loaded with newer versions of PyTorch. ScriptModules using :func:`torch.div` +and serialized on PyTorch 1.6 and later cannot be loaded in earlier versions of +PyTorch, however, since those earlier versions do not understand the new behavior. + +torch.full always inferring a float dtype +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In PyTorch 1.5 and earlier :func:`torch.full` always returned a float tensor, +regardless of the fill value it’s given: + +:: + + # PyTorch 1.5 and earlier + >>> torch.full((3,), 1) # Note the integer fill value... + tensor([1., 1., 1.]) # ...but float tensor! + +In PyTorch 1.7, however, :func:`torch.full` will infer the returned tensor’s +dtype from the fill value: + +:: + + # PyTorch 1.7 + >>> torch.full((3,), 1) + tensor([1, 1, 1]) + + >>> torch.full((3,), True) + tensor([True, True, True]) + + >>> torch.full((3,), 1.) + tensor([1., 1., 1.]) + + >>> torch.full((3,), 1 + 1j) + tensor([1.+1.j, 1.+1.j, 1.+1.j]) + +The behavior of :func:`torch.full` is preserved in serialized ScriptModules. That is, +ScriptModules serialized with versions of PyTorch before 1.6 will continue to see +torch.full return float tensors by default, even when given bool or +integer fill values. ScriptModules using :func:`torch.full` and serialized on PyTorch 1.6 +and later cannot be loaded in earlier versions of PyTorch, however, since those +earlier versions do not understand the new behavior. + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .. _utility functions: Utility functions diff --git a/docs/source/onnx.md b/docs/source/onnx.md index 73a24b671553c..095131306daed 100644 --- a/docs/source/onnx.md +++ b/docs/source/onnx.md @@ -12,6 +12,11 @@ The exported model can be consumed by any of the many [runtimes that support ONNX](https://onnx.ai/supported-tools.html#deployModel), including Microsoft's [ONNX Runtime](https://www.onnxruntime.ai). +<<<<<<< HEAD +======= +**There are two flavors of ONNX exporter API that you can use, as listed below.** +Both can be called through function {func}`torch.onnx.export`. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Next example shows how to export a simple model. ```python @@ -38,6 +43,7 @@ torch.onnx.export( ) ``` +<<<<<<< HEAD ## torch.export-based ONNX Exporter @@ -63,6 +69,41 @@ Q: How to export models containing loops? See {ref}`torch.cond `. +======= +Next sections introduce the two versions of the exporter. + +## TorchDynamo-based ONNX Exporter + +*The TorchDynamo-based ONNX exporter is the newest (and Beta) exporter for PyTorch 2.1 and newer* + +TorchDynamo engine is leveraged to hook into Python's frame evaluation API and dynamically rewrite its +bytecode into an FX Graph. The resulting FX Graph is then polished before it is finally translated into an +ONNX graph. + +The main advantage of this approach is that the [FX graph](https://pytorch.org/docs/stable/fx.html) is captured using +bytecode analysis that preserves the dynamic nature of the model instead of using traditional static tracing techniques. + +{doc}`Learn more about the TorchDynamo-based ONNX Exporter ` + +## TorchScript-based ONNX Exporter + +*The TorchScript-based ONNX exporter is available since PyTorch 1.2.0* + +[TorchScript](https://pytorch.org/docs/stable/jit.html) is leveraged to trace (through {func}`torch.jit.trace`) +the model and capture a static computation graph. + +As a consequence, the resulting graph has a couple limitations: + +* It does not record any control-flow, like if-statements or loops; +* Does not handle nuances between `training` and `eval` mode; +* Does not truly handle dynamic inputs + +As an attempt to support the static tracing limitations, the exporter also supports TorchScript scripting +(through {func}`torch.jit.script`), which adds support for data-dependent control-flow, for example. However, TorchScript +itself is a subset of the Python language, so not all features in Python are supported, such as in-place operations. + +{doc}`Learn more about the TorchScript-based ONNX Exporter ` +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ## Contributing / Developing @@ -70,6 +111,7 @@ The ONNX exporter is a community project and we welcome contributions. We follow [PyTorch guidelines for contributions](https://github.com/pytorch/pytorch/blob/main/CONTRIBUTING.md), but you might also be interested in reading our [development wiki](https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter). +<<<<<<< HEAD ## torch.onnx APIs @@ -95,10 +137,13 @@ also be interested in reading our [development wiki](https://github.com/pytorch/ :noindex: ``` +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ```{eval-rst} .. toctree:: :hidden: +<<<<<<< HEAD onnx_export onnx_ops onnx_verification @@ -118,6 +163,21 @@ also be interested in reading our [development wiki](https://github.com/pytorch/ ```{eval-rst} .. py:module:: torch.onnx.errors .. py:module:: torch.onnx.operators +======= + onnx_dynamo + onnx_ops + onnx_verification + onnx_dynamo_onnxruntime_backend + onnx_torchscript +``` + + +```{eval-rst} +.. py:module:: torch.onnx.errors +.. py:module:: torch.onnx.operators +.. py:module:: torch.onnx.symbolic_caffe2 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .. py:module:: torch.onnx.symbolic_helper .. py:module:: torch.onnx.symbolic_opset10 .. py:module:: torch.onnx.symbolic_opset11 diff --git a/docs/source/onnx_dynamo.md b/docs/source/onnx_dynamo.md new file mode 100644 index 0000000000000..c5077ef360a5e --- /dev/null +++ b/docs/source/onnx_dynamo.md @@ -0,0 +1,274 @@ +# TorchDynamo-based ONNX Exporter + +```{eval-rst} +.. automodule:: torch.onnx + :noindex: +``` + +```{contents} +:local: +:depth: 1 +``` + +## Overview + +The ONNX exporter leverages TorchDynamo engine to hook into Python's frame evaluation API +and dynamically rewrite its bytecode into an FX Graph. +The resulting FX Graph is then polished before it is finally translated into an ONNX graph. + +The main advantage of this approach is that the [FX graph](https://pytorch.org/docs/stable/fx.html) is captured using +bytecode analysis that preserves the dynamic nature of the model instead of using traditional static tracing techniques. + +In addition, during the export process, memory usage is significantly reduced compared to the TorchScript-enabled exporter. +See the {doc}`memory usage documentation ` for more information. + + +## Dependencies + +The ONNX exporter depends on extra Python packages: + + - [ONNX](https://onnx.ai) + - [ONNX Script](https://microsoft.github.io/onnxscript) + +They can be installed through [pip](https://pypi.org/project/pip/): + +```{code-block} bash + + pip install --upgrade onnx onnxscript +``` + +[onnxruntime](https://onnxruntime.ai) can then be used to execute the model +on a large variety of processors. + +## A simple example + +See below a demonstration of exporter API in action with a simple Multilayer Perceptron (MLP) as example: + +```{code-block} python +import torch +import torch.nn as nn + +class MLPModel(nn.Module): + def __init__(self): + super().__init__() + self.fc0 = nn.Linear(8, 8, bias=True) + self.fc1 = nn.Linear(8, 4, bias=True) + self.fc2 = nn.Linear(4, 2, bias=True) + self.fc3 = nn.Linear(2, 2, bias=True) + self.fc_combined = nn.Linear(8 + 8 + 8, 8, bias=True) # Combine all inputs + + def forward(self, tensor_x: torch.Tensor, input_dict: dict, input_list: list): + """ + Forward method that requires all inputs: + - tensor_x: A direct tensor input. + - input_dict: A dictionary containing the tensor under the key 'tensor_x'. + - input_list: A list where the first element is the tensor. + """ + # Extract tensors from inputs + dict_tensor = input_dict['tensor_x'] + list_tensor = input_list[0] + + # Combine all inputs into a single tensor + combined_tensor = torch.cat([tensor_x, dict_tensor, list_tensor], dim=1) + + # Process the combined tensor through the layers + combined_tensor = self.fc_combined(combined_tensor) + combined_tensor = torch.sigmoid(combined_tensor) + combined_tensor = self.fc0(combined_tensor) + combined_tensor = torch.sigmoid(combined_tensor) + combined_tensor = self.fc1(combined_tensor) + combined_tensor = torch.sigmoid(combined_tensor) + combined_tensor = self.fc2(combined_tensor) + combined_tensor = torch.sigmoid(combined_tensor) + output = self.fc3(combined_tensor) + return output + +model = MLPModel() + +# Example inputs +tensor_input = torch.rand((97, 8), dtype=torch.float32) +dict_input = {'tensor_x': torch.rand((97, 8), dtype=torch.float32)} +list_input = [torch.rand((97, 8), dtype=torch.float32)] + +# The input_names and output_names are used to identify the inputs and outputs of the ONNX model +input_names = ['tensor_input', 'tensor_x', 'list_input_index_0'] +output_names = ['output'] + +# Exporting the model with all required inputs +onnx_program = torch.onnx.export(model,(tensor_input, dict_input, list_input), dynamic_shapes=({0: "batch_size"},{"tensor_x": {0: "batch_size"}},[{0: "batch_size"}]), input_names=input_names, output_names=output_names, dynamo=True,) + +# Check the exported ONNX model is dynamic +assert onnx_program.model.graph.inputs[0].shape == ("batch_size", 8) +assert onnx_program.model.graph.inputs[1].shape == ("batch_size", 8) +assert onnx_program.model.graph.inputs[2].shape == ("batch_size", 8) +``` + +As the code above shows, all you need is to provide {func}`torch.onnx.export` with an instance of the model and its input. +The exporter will then return an instance of {class}`torch.onnx.ONNXProgram` that contains the exported ONNX graph along with extra information. + +The in-memory model available through ``onnx_program.model_proto`` is an ``onnx.ModelProto`` object in compliance with the [ONNX IR spec](https://github.com/onnx/onnx/blob/main/docs/IR.md). +The ONNX model may then be serialized into a [Protobuf file](https://protobuf.dev/) using the {meth}`torch.onnx.ONNXProgram.save` API. + +```{code-block} python + onnx_program.save("mlp.onnx") +``` + +## Use the same model to compare with the TorchScript-enabled exporter + +The biggest difference between the TorchScript-enabled exporter and the TorchDynamo-based exporter is that the latter +requires dynamic_shapes to be the same tree structure as the input, while the former +requires the dynamic_shapes to be a single and flatten dictionary. + +```{code-block} python + torch.onnx.export(model,(tensor_input, dict_input, list_input), "mlp.onnx", dynamic_axes={"tensor_input":{0: "batch_size"}, "tensor_x": {0: "batch_size"}, "list_input_index_0": {0: "batch_size"}}, input_names=input_names, output_names=output_names) +``` + +## Inspecting the ONNX model using GUI + +You can view the exported model using [Netron](https://netron.app/). + +```{image} _static/img/onnx/onnx_dynamo_mlp_model.png +:alt: MLP model as viewed using Netron +:width: 30% +:align: center +``` + +## When the conversion fails + +Function {func}`torch.onnx.export` should be called a second time with +parameter ``report=True``. A markdown report is generated to help the user +to resolve the issue. + +```{toctree} +:hidden: +onnx_dynamo_memory_usage +``` +## Metadata + +During ONNX export, each ONNX node is annotated with metadata that helps trace its origin and context from the original PyTorch model. This metadata is useful for debugging, model inspection, and understanding the mapping between PyTorch and ONNX graphs. + +The following metadata fields are added to each ONNX node: + +- **namespace** + + A string representing the hierarchical namespace of the node, consisting of a stack trace of modules/methods. + + *Example:* + `__main__.SimpleAddModel/add: aten.add.Tensor` + +- **pkg.torch.onnx.class_hierarchy** + + A list of class names representing the hierarchy of modules leading to this node. + + *Example:* + `['__main__.SimpleAddModel', 'aten.add.Tensor']` + +- **pkg.torch.onnx.fx_node** + + The string representation of the original FX node, including its name, number of consumers, the targeted torch op, arguments, and keyword arguments. + + *Example:* + `%cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%tensor_x, %input_dict_tensor_x, %input_list_0], 1), kwargs = {})` + +- **pkg.torch.onnx.name_scopes** + + A list of name scopes (methods) representing the path to this node in the PyTorch model. + + *Example:* + `['', 'add']` + +- **pkg.torch.onnx.stack_trace** + + The stack trace from the original code where this node was created, if available. + + *Example:* + ``` + File "simpleadd.py", line 7, in forward + return torch.add(x, y) + ``` + +These metadata fields are stored in the metadata_props attribute of each ONNX node and can be inspected using Netron or programmatically. + +The overall ONNX graph has the following `metadata_props`: + +- **pkg.torch.export.ExportedProgram.graph_signature** + + This property contains a string representation of the graph_signature from the original PyTorch ExportedProgram. The graph signature describes the structure of the model's inputs and outputs and how they map to the ONNX graph. The inputs are defined as `InputSpec` objects, which include the kind of input (e.g., `InputKind.PARAMETER` for parameters, `InputKind.USER_INPUT` for user-defined inputs), the argument name, the target (which can be a specific node in the model), and whether the input is persistent. The outputs are defined as `OutputSpec` objects, which specify the kind of output (e.g., `OutputKind.USER_OUTPUT`) and the argument name. + + To read more about the graph signature, please see the {doc}`torch.export ` for more information. + +- **pkg.torch.export.ExportedProgram.range_constraints** + + This property contains a string representation of any range constraints that were present in the original PyTorch ExportedProgram. Range constraints specify valid ranges for symbolic shapes or values in the model, which can be important for models that use dynamic shapes or symbolic dimensions. + + *Example:* + `s0: VR[2, int_oo]`, which indicates that the size of the input tensor must be at least 2. + + To read more about range constraints, please see the {doc}`torch.export ` for more information. + +Each input value in the ONNX graph may have the following metadata property: + +- **pkg.torch.export.graph_signature.InputSpec.kind** + + The kind of input, as defined by PyTorch's InputKind enum. + + *Example values:* + - "USER_INPUT": A user-provided input to the model. + - "PARAMETER": A model parameter (e.g., weight). + - "BUFFER": A model buffer (e.g., running mean in BatchNorm). + - "CONSTANT_TENSOR": A constant tensor argument. + - "CUSTOM_OBJ": A custom object input. + - "TOKEN": A token input. + +- **pkg.torch.export.graph_signature.InputSpec.persistent** + + Indicates whether the input is persistent (i.e., should be saved as part of the model's state). + + *Example values:* + - "True" + - "False" + +Each output value in the ONNX graph may have the following metadata property: + +- **pkg.torch.export.graph_signature.OutputSpec.kind** + + The kind of input, as defined by PyTorch's OutputKind enum. + + *Example values:* + - "USER_OUTPUT": A user-visible output. + - "LOSS_OUTPUT": A loss value output. + - "BUFFER_MUTATION": Indicates a buffer was mutated. + - "GRADIENT_TO_PARAMETER": Gradient output for a parameter. + - "GRADIENT_TO_USER_INPUT": Gradient output for a user input. + - "USER_INPUT_MUTATION": Indicates a user input was mutated. + - "TOKEN": A token output. + +Each initialized value, input, output has the following metadata: + +- **pkg.torch.onnx.original_node_name** + + The original name of the node in the PyTorch FX graph that produced this value in the case where the value was renamed. This helps trace initializers back to their source in the original model. + + *Example:* + `fc1.weight` + +## API Reference + +```{eval-rst} +.. autofunction:: torch.onnx.export +.. autoclass:: torch.onnx.ONNXProgram + :members: +.. autofunction:: is_in_onnx_export +.. autoclass:: torch.onnx.OnnxExporterError + :members: +.. autofunction:: torch.onnx.enable_fake_mode +``` + +## Deprecated + +The following classes and functions are deprecated and will be removed. + +```{eval-rst} +.. autofunction:: torch.onnx.dynamo_export +.. autoclass:: torch.onnx.ExportOptions +``` diff --git a/docs/source/onnx_dynamo_memory_usage.rst b/docs/source/onnx_dynamo_memory_usage.rst new file mode 100644 index 0000000000000..ba1213c6ee085 --- /dev/null +++ b/docs/source/onnx_dynamo_memory_usage.rst @@ -0,0 +1,112 @@ +Understanding TorchDynamo-based ONNX Exporter Memory Usage +========================================================== +The previous TorchScript-based ONNX exporter would execute the model once to trace its execution, which could cause it to run out of +memory on your GPU if the model's memory requirements exceeded the available GPU memory. This issue has been addressed with the new +TorchDynamo-based ONNX exporter. + +The TorchDynamo-based ONNX exporter utilizes torch.export.export() function to leverage +`FakeTensorMode `_ to avoid performing actual tensor computations +during the export process. This approach results in significantly lower memory usage compared to the TorchScript-based ONNX exporter. + +Below is an example demonstrating the memory usage difference between TorchScript-based and TorchDynamo-based ONNX exporters. +In this example, we use the HighResNet model from MONAI. Before proceeding, please install it from PyPI: + +.. code-block:: bash + + pip install monai + + +PyTorch offers a tool for capturing and visualizing memory usage traces. We will use this tool to record the memory usage of the two +exporters during the export process and compare the results. You can find more details about this tool on +`Understanding CUDA Memory Usage `_. + + +TorchScript-based exporter +========================== +The code below could be run to generate a snapshot file which records the state of allocated CUDA memory during the export process. + +.. code-block:: python + + import torch + + from monai.networks.nets import ( + HighResNet, + ) + + torch.cuda.memory._record_memory_history() + + model = HighResNet( + spatial_dims=3, in_channels=1, out_channels=3, norm_type="batch" + ).eval() + + model = model.to("cuda") + data = torch.randn(30, 1, 48, 48, 48, dtype=torch.float32).to("cuda") + + with torch.no_grad(): + onnx_program = torch.onnx.export( + model, + data, + "torchscript_exporter_highresnet.onnx", + dynamo=False, + ) + + snapshot_name = "torchscript_exporter_example.pickle" + print(f"generate {snapshot_name}") + + torch.cuda.memory._dump_snapshot(snapshot_name) + print("Export is done.") + + +Open `pytorch.org/memory_viz `_ and drag/drop the generated pickled snapshot file into the visualizer. +The memory usage is described as below: + +.. image:: _static/img/onnx/torch_script_exporter_memory_usage.png + + +By this figure, we can see the memory usage peak is above 2.8GB. + + +TorchDynamo-based exporter +========================== + +The code below could be run to generate a snapshot file which records the state of allocated CUDA memory during the export process. + +.. code-block:: python + + import torch + + from monai.networks.nets import ( + HighResNet, + ) + + torch.cuda.memory._record_memory_history() + + model = HighResNet( + spatial_dims=3, in_channels=1, out_channels=3, norm_type="batch" + ).eval() + + model = model.to("cuda") + data = torch.randn(30, 1, 48, 48, 48, dtype=torch.float32).to("cuda") + + with torch.no_grad(): + onnx_program = torch.onnx.export( + model, + data, + "test_faketensor.onnx", + dynamo=True, + ) + + snapshot_name = f"torchdynamo_exporter_example.pickle" + print(f"generate {snapshot_name}") + + torch.cuda.memory._dump_snapshot(snapshot_name) + print(f"Export is done.") + +Open `pytorch.org/memory_viz `_ and drag/drop the generated pickled snapshot file into the visualizer. +The memory usage is described as below: + +.. image:: _static/img/onnx/torch_dynamo_exporter_memory_usage.png + + +By this figure, we can see the memory usage peak is only around 45MB. Comparing to the memory usage peak of TorchScript-based exporter, +it reduces 98% memory usage. diff --git a/docs/source/onnx_dynamo_onnxruntime_backend.md b/docs/source/onnx_dynamo_onnxruntime_backend.md new file mode 100644 index 0000000000000..a59cd4ab919cd --- /dev/null +++ b/docs/source/onnx_dynamo_onnxruntime_backend.md @@ -0,0 +1,11 @@ +# ONNX Backend for TorchDynamo + +For a quick overview of `torch.compiler`, see {ref}`torch.compiler_overview`. + +```{warning} + The ONNX backend for torch.compile is a rapidly evolving beta technology. +``` + +```{eval-rst} +.. autofunction:: torch.onnx.is_onnxrt_backend_supported +``` \ No newline at end of file diff --git a/docs/source/onnx_torchscript.rst b/docs/source/onnx_torchscript.rst new file mode 100644 index 0000000000000..23a7adb06f7b1 --- /dev/null +++ b/docs/source/onnx_torchscript.rst @@ -0,0 +1,715 @@ +TorchScript-based ONNX Exporter +=============================== + +.. note:: + To export an ONNX model using TorchDynamo instead of TorchScript, please see :doc:`Learn more about the TorchDynamo-based ONNX Exporter ` + +.. contents:: :local: + +Example: AlexNet from PyTorch to ONNX +------------------------------------- + +Here is a simple script which exports a pretrained AlexNet to an ONNX file named ``alexnet.onnx``. +The call to ``torch.onnx.export`` runs the model once to trace its execution and then exports the +traced model to the specified file:: + + import torch + import torchvision + + dummy_input = torch.randn(10, 3, 224, 224, device="cuda") + model = torchvision.models.alexnet(pretrained=True).cuda() + + # Providing input and output names sets the display names for values + # within the model's graph. Setting these does not change the semantics + # of the graph; it is only for readability. + # + # The inputs to the network consist of the flat list of inputs (i.e. + # the values you would pass to the forward() method) followed by the + # flat list of parameters. You can partially specify names, i.e. provide + # a list here shorter than the number of inputs to the model, and we will + # only set that subset of names, starting from the beginning. + input_names = [ "actual_input_1" ] + [ "learned_%d" % i for i in range(16) ] + output_names = [ "output1" ] + + torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True, input_names=input_names, output_names=output_names) + +The resulting ``alexnet.onnx`` file contains a binary `protocol buffer `_ +which contains both the network structure and parameters of the model you exported +(in this case, AlexNet). The argument ``verbose=True`` causes the +exporter to print out a human-readable representation of the model:: + + # These are the inputs and parameters to the network, which have taken on + # the names we specified earlier. + graph(%actual_input_1 : Float(10, 3, 224, 224) + %learned_0 : Float(64, 3, 11, 11) + %learned_1 : Float(64) + %learned_2 : Float(192, 64, 5, 5) + %learned_3 : Float(192) + # ---- omitted for brevity ---- + %learned_14 : Float(1000, 4096) + %learned_15 : Float(1000)) { + # Every statement consists of some output tensors (and their types), + # the operator to be run (with its attributes, e.g., kernels, strides, + # etc.), its input tensors (%actual_input_1, %learned_0, %learned_1) + %17 : Float(10, 64, 55, 55) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[11, 11], pads=[2, 2, 2, 2], strides=[4, 4]](%actual_input_1, %learned_0, %learned_1), scope: AlexNet/Sequential[features]/Conv2d[0] + %18 : Float(10, 64, 55, 55) = onnx::Relu(%17), scope: AlexNet/Sequential[features]/ReLU[1] + %19 : Float(10, 64, 27, 27) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%18), scope: AlexNet/Sequential[features]/MaxPool2d[2] + # ---- omitted for brevity ---- + %29 : Float(10, 256, 6, 6) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%28), scope: AlexNet/Sequential[features]/MaxPool2d[12] + # Dynamic means that the shape is not known. This may be because of a + # limitation of our implementation (which we would like to fix in a + # future release) or shapes which are truly dynamic. + %30 : Dynamic = onnx::Shape(%29), scope: AlexNet + %31 : Dynamic = onnx::Slice[axes=[0], ends=[1], starts=[0]](%30), scope: AlexNet + %32 : Long() = onnx::Squeeze[axes=[0]](%31), scope: AlexNet + %33 : Long() = onnx::Constant[value={9216}](), scope: AlexNet + # ---- omitted for brevity ---- + %output1 : Float(10, 1000) = onnx::Gemm[alpha=1, beta=1, broadcast=1, transB=1](%45, %learned_14, %learned_15), scope: AlexNet/Sequential[classifier]/Linear[6] + return (%output1); + } + +You can also verify the output using the `ONNX `_ library, +which you can install using ``pip``:: + + pip install onnx + +Then, you can run:: + + import onnx + + # Load the ONNX model + model = onnx.load("alexnet.onnx") + + # Check that the model is well formed + onnx.checker.check_model(model) + + # Print a human readable representation of the graph + print(onnx.helper.printable_graph(model.graph)) + +You can also run the exported model with one of the many +`runtimes that support ONNX `_. +For example after installing `ONNX Runtime `_, you can +load and run the model:: + + import onnxruntime as ort + import numpy as np + + ort_session = ort.InferenceSession("alexnet.onnx") + + outputs = ort_session.run( + None, + {"actual_input_1": np.random.randn(10, 3, 224, 224).astype(np.float32)}, + ) + print(outputs[0]) + +Here is a more involved `tutorial on exporting a model and running it with ONNX Runtime `_. + +.. _tracing-vs-scripting: + +Tracing vs Scripting +-------------------- + +Internally, :func:`torch.onnx.export()` requires a :class:`torch.jit.ScriptModule` rather than +a :class:`torch.nn.Module`. If the passed-in model is not already a ``ScriptModule``, +``export()`` will use *tracing* to convert it to one: + +.. TODO(justinchuby): Add a word on recommending tracing over scripting for most use cases. + +* **Tracing**: If ``torch.onnx.export()`` is called with a Module that is not already a + ``ScriptModule``, it first does the equivalent of :func:`torch.jit.trace`, which executes the model + once with the given ``args`` and records all operations that happen during that execution. This + means that if your model is dynamic, e.g., changes behavior depending on input data, the exported + model will *not* capture this dynamic behavior. + We recommend examining the exported model and making sure the operators look + reasonable. Tracing will unroll loops and if statements, exporting a static graph that is exactly + the same as the traced run. If you want to export your model with dynamic control flow, you will + need to use *scripting*. + +* **Scripting**: Compiling a model via scripting preserves dynamic control flow and is valid for inputs + of different sizes. To use scripting: + + * Use :func:`torch.jit.script` to produce a ``ScriptModule``. + * Call ``torch.onnx.export()`` with the ``ScriptModule`` as the model. The ``args`` are still required, + but they will be used internally only to produce example outputs, so that the types and shapes of the + outputs can be captured. No tracing will be performed. + +See `Introduction to TorchScript `_ +and `TorchScript `_ for more details, including how to compose tracing and scripting to suit the +particular requirements of different models. + + +Avoiding Pitfalls +----------------- + +Avoid NumPy and built-in Python types +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +PyTorch models can be written using NumPy or Python types and functions, but +during :ref:`tracing`, any variables of NumPy or Python +types (rather than torch.Tensor) are converted to constants, which will produce +the wrong result if those values should change depending on the inputs. + +For example, rather than using numpy functions on numpy.ndarrays: :: + + # Bad! Will be replaced with constants during tracing. + x, y = np.random.rand(1, 2), np.random.rand(1, 2) + np.concatenate((x, y), axis=1) + +Use torch operators on torch.Tensors: :: + + # Good! Tensor operations will be captured during tracing. + x, y = torch.randn(1, 2), torch.randn(1, 2) + torch.cat((x, y), dim=1) + + +And rather than use :func:`torch.Tensor.item` (which converts a Tensor to a Python +built-in number): :: + + # Bad! y.item() will be replaced with a constant during tracing. + def forward(self, x, y): + return x.reshape(y.item(), -1) + +Use torch's support for implicit casting of single-element tensors: :: + + # Good! y will be preserved as a variable during tracing. + def forward(self, x, y): + return x.reshape(y, -1) + +Avoid Tensor.data +^^^^^^^^^^^^^^^^^ + +Using the Tensor.data field can produce an incorrect trace and therefore an incorrect ONNX graph. +Use :func:`torch.Tensor.detach` instead. (Work is ongoing to +`remove Tensor.data entirely `_). + +Avoid in-place operations when using tensor.shape in tracing mode +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In tracing mode, shapes obtained from ``tensor.shape`` are traced as tensors, +and share the same memory. This might cause a mismatch the final output values. +As a workaround, avoid the use of inplace operations in these scenarios. +For example, in the model:: + + class Model(torch.nn.Module): + def forward(self, states): + batch_size, seq_length = states.shape[:2] + real_seq_length = seq_length + real_seq_length += 2 + return real_seq_length + seq_length + +``real_seq_length`` and ``seq_length`` share the same memory in tracing mode. +This could be avoided by rewriting the inplace operation:: + + real_seq_length = real_seq_length + 2 + +Limitations +----------- + +Types +^^^^^ + +* Only :class:`torch.Tensors`, numeric types that can be trivially converted to torch.Tensors (e.g. float, int), + and tuples and lists of those types are supported as model inputs or outputs. Dict and str inputs and + outputs are accepted in :ref:`tracing` mode, but: + + * Any computation that depends on the value of a dict or a str input **will be replaced with the + constant value** seen during the one traced execution. + * Any output that is a dict will be silently replaced with a **flattened sequence of its values + (keys will be removed)**. E.g. ``{"foo": 1, "bar": 2}`` becomes ``(1, 2)``. + * Any output that is a str will be silently removed. + +* Certain operations involving tuples and lists are not supported in + :ref:`scripting` mode due to limited support in ONNX for nested sequences. + In particular appending a tuple to a list is not supported. In tracing mode, the nested sequences + will be flattened automatically during the tracing. + +Differences in Operator Implementations +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Due to differences in implementations of operators, running the exported model on different runtimes +may produce different results from each other or from PyTorch. Normally these differences are +numerically small, so this should only be a concern if your application is sensitive to these +small differences. + +.. _tensor-indexing: + +Unsupported Tensor Indexing Patterns +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Tensor indexing patterns that cannot be exported are listed below. +If you are experiencing issues exporting a model that does not include any of +the unsupported patterns below, please double check that you are exporting with +the latest ``opset_version``. + +Reads / Gets +~~~~~~~~~~~~ + +When indexing into a tensor for reading, the following patterns are not supported: :: + + # Tensor indices that includes negative values. + data[torch.tensor([[1, 2], [2, -3]]), torch.tensor([-2, 3])] + # Workarounds: use positive index values. + +Writes / Sets +~~~~~~~~~~~~~ + +When indexing into a Tensor for writing, the following patterns are not supported: :: + + # Multiple tensor indices if any has rank >= 2 + data[torch.tensor([[1, 2], [2, 3]]), torch.tensor([2, 3])] = new_data + # Workarounds: use single tensor index with rank >= 2, + # or multiple consecutive tensor indices with rank == 1. + + # Multiple tensor indices that are not consecutive + data[torch.tensor([2, 3]), :, torch.tensor([1, 2])] = new_data + # Workarounds: transpose `data` such that tensor indices are consecutive. + + # Tensor indices that includes negative values. + data[torch.tensor([1, -2]), torch.tensor([-2, 3])] = new_data + # Workarounds: use positive index values. + + # Implicit broadcasting required for new_data. + data[torch.tensor([[0, 2], [1, 1]]), 1:3] = new_data + # Workarounds: expand new_data explicitly. + # Example: + # data shape: [3, 4, 5] + # new_data shape: [5] + # expected new_data shape after broadcasting: [2, 2, 2, 5] + +Adding support for operators +---------------------------- + +When exporting a model that includes unsupported operators, you'll see an error message like: + +.. code-block:: text + + RuntimeError: ONNX export failed: Couldn't export operator foo + +When that happens, there are a few things you can do: + +#. Change the model to not use that operator. +#. Create a symbolic function to convert the operator and register it as a custom symbolic function. +#. Contribute to PyTorch to add the same symbolic function to :mod:`torch.onnx` itself. + +If you decided to implement a symbolic function (we hope you will contribute it back to PyTorch!), here is how you can get started: + +ONNX exporter internals +^^^^^^^^^^^^^^^^^^^^^^^ + +A "symbolic function" is a function that decomposes a PyTorch operator into a +composition of a series of ONNX operators. + +During export, each node (which contains a PyTorch operator) in the TorchScript +graph is visited by the exporter in topological order. +Upon visiting a node, the exporter looks for a registered symbolic functions for +that operator. Symbolic functions are implemented in Python. A symbolic function for +an op named ``foo`` would look something like:: + + + def foo( + g, + input_0: torch._C.Value, + input_1: torch._C.Value) -> Union[None, torch._C.Value, List[torch._C.Value]]: + """ + Adds the ONNX operations representing this PyTorch function by updating the + graph g with `g.op()` calls. + + Args: + g (Graph): graph to write the ONNX representation into. + input_0 (Value): value representing the variables which contain + the first input for this operator. + input_1 (Value): value representing the variables which contain + the second input for this operator. + + Returns: + A Value or List of Values specifying the ONNX nodes that compute something + equivalent to the original PyTorch operator with the given inputs. + + None if it cannot be converted to ONNX. + """ + ... + +The ``torch._C`` types are Python wrappers around the types defined in C++ in +`ir.h `_. + +The process for adding a symbolic function depends on the type of operator. + +.. _adding-support-aten: + +ATen operators +^^^^^^^^^^^^^^ + +`ATen `_ is PyTorch's built-in tensor library. +If the operator is an ATen operator (shows up in the TorchScript graph with the prefix +``aten::``), make sure it is not supported already. + +List of supported operators +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Visit the auto generated :doc:`list of supported TorchScript operators <../onnx_torchscript_supported_aten_ops>` +for details on which operator are supported in each ``opset_version``. + +Adding support for an aten or quantized operator +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +If the operator is not in the list above: + +* Define the symbolic function in ``torch/onnx/symbolic_opset.py``, for example + `torch/onnx/symbolic_opset9.py `_. + Make sure the function has the same name as the ATen function, which may be declared in + ``torch/_C/_VariableFunctions.pyi`` or ``torch/nn/functional.pyi`` (these files are generated at + build time, so will not appear in your checkout until you build PyTorch). +* By default, the first arg is the ONNX graph. + Other arg names must EXACTLY match the names in the ``.pyi`` file, + because dispatch is done with keyword arguments. +* In the symbolic function, if the operator is in the + `ONNX standard operator set `_, + we only need to create a node to represent the ONNX operator in the graph. + If not, we can compose several standard operators that have the + equivalent semantics to the ATen operator. + +Here is an example of handling missing symbolic function for the ``ELU`` operator. + +If we run the following code:: + + print( + torch.jit.trace( + torch.nn.ELU(), # module + torch.ones(1) # example input + ).graph + ) + +We see something like:: + + graph(%self : __torch__.torch.nn.modules.activation.___torch_mangle_0.ELU, + %input : Float(1, strides=[1], requires_grad=0, device=cpu)): + %4 : float = prim::Constant[value=1.]() + %5 : int = prim::Constant[value=1]() + %6 : int = prim::Constant[value=1]() + %7 : Float(1, strides=[1], requires_grad=0, device=cpu) = aten::elu(%input, %4, %5, %6) + return (%7) + +Since we see ``aten::elu`` in the graph, we know this is an ATen operator. + +We check the `ONNX operator list `_, +and confirm that ``Elu`` is standardized in ONNX. + +We find a signature for ``elu`` in ``torch/nn/functional.pyi``:: + + def elu(input: Tensor, alpha: float = ..., inplace: bool = ...) -> Tensor: ... + +We add the following lines to ``symbolic_opset9.py``:: + + def elu(g, input: torch.Value, alpha: torch.Value, inplace: bool = False): + return g.op("Elu", input, alpha_f=alpha) + +Now PyTorch is able to export models containing the ``aten::elu`` operator! + +See the ``torch/onnx/symbolic_opset*.py`` files for more examples. + + +torch.autograd.Functions +^^^^^^^^^^^^^^^^^^^^^^^^ + +If the operator is a sub-class of :class:`torch.autograd.Function`, there are three ways +to export it. + +Static Symbolic Method +~~~~~~~~~~~~~~~~~~~~~~ + +You can add a static method named ``symbolic`` to your function class. It should return +ONNX operators that represent the function's behavior in ONNX. For example:: + + class MyRelu(torch.autograd.Function): + @staticmethod + def forward(ctx, input: torch.Tensor) -> torch.Tensor: + ctx.save_for_backward(input) + return input.clamp(min=0) + + @staticmethod + def symbolic(g: torch.Graph, input: torch.Value) -> torch.Value: + return g.op("Clip", input, g.op("Constant", value_t=torch.tensor(0, dtype=torch.float))) + +.. FIXME(justinchuby): PythonOps are too complicated and the example below +.. uses private methods we do not expose. We are looking to +.. improve the experience. Since SymbolicContext is deprecated, we think +.. defining a symbolic staticmethod is a better way to go for now. + +.. PythonOp Symbolic +.. ~~~~~~~~~~~~~~~~~ + +.. Alternatively, you can register a custom symbolic function. +.. This gives the symbolic function access to more info through the +.. ``torch.onnx.SymbolicContext`` object, which gets passed in as the first +.. argument (before the ``Graph`` object). + +.. All autograd ``Function``\ s appear in the TorchScript graph as ``prim::PythonOp`` nodes. +.. In order to differentiate between different ``Function`` subclasses, the +.. symbolic function should use the ``name`` kwarg which gets set to the name of the class. + +.. Custom symbolic functions should add type and shape information by calling ``setType(...)`` +.. on Value objects before returning them (implemented in C++ by +.. . ``torch::jit::Value::setType``). This is not required, but it can help the exporter's +.. shape and type inference for down-stream nodes. For a non-trivial example of ``setType``, see +.. ``test_aten_embedding_2`` in +.. `test_operators.py `_. + +.. The example below shows how you can access ``requires_grad`` via the ``Node`` object: + +.. class MyClip(torch.autograd.Function): +.. @staticmethod +.. def forward(ctx, input, min): +.. ctx.save_for_backward(input) +.. return input.clamp(min=min) + +.. class MyRelu(torch.autograd.Function): +.. @staticmethod +.. def forward(ctx, input): +.. ctx.save_for_backward(input) +.. return input.clamp(min=0) + +.. def symbolic_python_op(g: "GraphContext", *args, **kwargs): +.. n = ctx.cur_node +.. print("original node: ", n) +.. for i, out in enumerate(n.outputs()): +.. print("original output {}: {}, requires grad: {}".format(i, out, out.requiresGrad())) +.. import torch.onnx.symbolic_helper as sym_helper +.. for i, arg in enumerate(args): +.. requires_grad = arg.requiresGrad() if sym_helper._is_value(arg) else False +.. print("arg {}: {}, requires grad: {}".format(i, arg, requires_grad)) + +.. name = kwargs["name"] +.. ret = None +.. if name == "MyClip": +.. ret = g.op("Clip", args[0], args[1]) +.. elif name == "MyRelu": +.. ret = g.op("Relu", args[0]) +.. else: +.. # Logs a warning and returns None +.. return _unimplemented("prim::PythonOp", "unknown node kind: " + name) +.. # Copy type and shape from original node. +.. ret.setType(n.type()) +.. return ret + +.. from torch.onnx import register_custom_op_symbolic +.. . register_custom_op_symbolic("prim::PythonOp", symbolic_python_op, 1) + +Inline Autograd Function +~~~~~~~~~~~~~~~~~~~~~~~~ + +In cases where a static symbolic method is not provided for its subsequent :class:`torch.autograd.Function` or +where a function to register ``prim::PythonOp`` as custom symbolic functions is not provided, +:func:`torch.onnx.export` tries to inline the graph that corresponds to that :class:`torch.autograd.Function` such that +this function is broken down into individual operators that were used within the function. +The export should be successful as long as these individual operators are supported. For example:: + + class MyLogExp(torch.autograd.Function): + @staticmethod + def forward(ctx, input: torch.Tensor) -> torch.Tensor: + ctx.save_for_backward(input) + h = input.exp() + return h.log().log() + +There is no static symbolic method present for this model, yet it is exported as follows:: + + graph(%input : Float(1, strides=[1], requires_grad=0, device=cpu)): + %1 : float = onnx::Exp[](%input) + %2 : float = onnx::Log[](%1) + %3 : float = onnx::Log[](%2) + return (%3) + +If you need to avoid inlining of :class:`torch.autograd.Function`, you should export models with +``operator_export_type`` set to ``ONNX_FALLTHROUGH`` or ``ONNX_ATEN_FALLBACK``. + +Custom operators +^^^^^^^^^^^^^^^^ + +You can export your model with custom operators that includes a combination of many standard ONNX ops, +or are driven by self-defined C++ backend. + +ONNX-script functions +~~~~~~~~~~~~~~~~~~~~~ + +If an operator is not a standard ONNX op, but can be composed of multiple existing ONNX ops, you can utilize +`ONNX-script `_ to create an external ONNX function to support the operator. +You can export it by following this example:: + + import onnxscript + # There are three opset version needed to be aligned + # This is (1) the opset version in ONNX function + from onnxscript.onnx_opset import opset15 as op + opset_version = 15 + + x = torch.randn(1, 2, 3, 4, requires_grad=True) + model = torch.nn.SELU() + + custom_opset = onnxscript.values.Opset(domain="onnx-script", version=1) + + @onnxscript.script(custom_opset) + def Selu(X): + alpha = 1.67326 # auto wrapped as Constants + gamma = 1.0507 + alphaX = op.CastLike(alpha, X) + gammaX = op.CastLike(gamma, X) + neg = gammaX * (alphaX * op.Exp(X) - alphaX) + pos = gammaX * X + zero = op.CastLike(0, X) + return op.Where(X <= zero, neg, pos) + + # setType API provides shape/type to ONNX shape/type inference + def custom_selu(g: jit_utils.GraphContext, X): + return g.onnxscript_op(Selu, X).setType(X.type()) + + # Register custom symbolic function + # There are three opset version needed to be aligned + # This is (2) the opset version in registry + torch.onnx.register_custom_op_symbolic( + symbolic_name="aten::selu", + symbolic_fn=custom_selu, + opset_version=opset_version, + ) + + # There are three opset version needed to be aligned + # This is (2) the opset version in exporter + torch.onnx.export( + model, + x, + "model.onnx", + opset_version=opset_version, + # only needed if you want to specify an opset version > 1. + custom_opsets={"onnx-script": 2} + ) + +The example above exports it as a custom operator in the "onnx-script" opset. +When exporting a custom operator, you can specify the custom domain version using the +``custom_opsets`` dictionary at export. If not specified, the custom opset version defaults to 1. + +NOTE: Be careful to align the opset version mentioned in the above example, and make sure they are consumed in exporter step. +The example usage of how to write a onnx-script function is a beta version in terms of the active development on onnx-script. +Please follow the latest `ONNX-script `_ + +C++ Operators +~~~~~~~~~~~~~ + +If a model uses a custom operator implemented in C++ as described in +`Extending TorchScript with Custom C++ Operators `_, +you can export it by following this example:: + + from torch.onnx import symbolic_helper + + + # Define custom symbolic function + @symbolic_helper.parse_args("v", "v", "f", "i") + def symbolic_foo_forward(g, input1, input2, attr1, attr2): + return g.op("custom_domain::Foo", input1, input2, attr1_f=attr1, attr2_i=attr2) + + + # Register custom symbolic function + torch.onnx.register_custom_op_symbolic("custom_ops::foo_forward", symbolic_foo_forward, 9) + + + class FooModel(torch.nn.Module): + def __init__(self, attr1, attr2): + super().__init__() + self.attr1 = attr1 + self.attr2 = attr2 + + def forward(self, input1, input2): + # Calling custom op + return torch.ops.custom_ops.foo_forward(input1, input2, self.attr1, self.attr2) + + + model = FooModel(attr1, attr2) + torch.onnx.export( + model, + (example_input1, example_input1), + "model.onnx", + # only needed if you want to specify an opset version > 1. + custom_opsets={"custom_domain": 2} + ) + +The example above exports it as a custom operator in the "custom_domain" opset. +When exporting a custom operator, you can specify the custom domain version using the +``custom_opsets`` dictionary at export. If not specified, the custom opset version defaults to 1. + +The runtime that consumes the model needs to support the custom op. See +`Caffe2 custom ops `_, +`ONNX Runtime custom ops `_, +or your runtime of choice's documentation. + + +Discovering all unconvertible ATen ops at once +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +When export fails due to an unconvertible ATen op, there may in fact be more +than one such op but the error message only mentions the first. To discover +all of the unconvertible ops in one go you can:: + + # prepare model, args, opset_version + ... + + torch_script_graph, unconvertible_ops = torch.onnx.utils.unconvertible_ops( + model, args, opset_version=opset_version + ) + + print(set(unconvertible_ops)) + +The set is approximated because some ops may be removed during the conversion +process and don't need to be converted. Some other ops may have partial support +that will fail conversion with particular inputs, but this should give you a +general idea of what ops are not supported. Please feel free to open GitHub Issues +for op support requests. + +Frequently Asked Questions +-------------------------- +Q: I have exported my LSTM model, but its input size seems to be fixed? + + The tracer records the shapes of the example inputs. If the model should accept + inputs of dynamic shapes, set ``dynamic_axes`` when calling :func:`torch.onnx.export`. + +Q: How to export models containing loops? + + See `Tracing vs Scripting`_. + +Q: How to export models with primitive type inputs (e.g. int, float)? + + Support for primitive numeric type inputs was added in PyTorch 1.9. + However, the exporter does not support models with str inputs. + +Q: Does ONNX support implicit scalar datatype casting? + + The ONNX standard does not, but the exporter will try to handle that part. + Scalars are exported as constant tensors. + The exporter will figure out the right data type for scalars. In rare cases when it is unable + to do so, you will need to manually specify the datatype with e.g. `dtype=torch.float32`. + If you see any errors, please [create a GitHub issue](https://github.com/pytorch/pytorch/issues). + +Q: Are lists of Tensors exportable to ONNX? + + Yes, for ``opset_version`` >= 11, since ONNX introduced the Sequence type in opset 11. + +Python API +---------- + +.. automodule:: torch.onnx + +Functions +^^^^^^^^^ + +.. autofunction:: export + :noindex: +.. autofunction:: register_custom_op_symbolic +.. autofunction:: unregister_custom_op_symbolic +.. autofunction:: select_model_mode_for_export +.. autofunction:: is_in_onnx_export + :noindex: + +Classes +^^^^^^^ + +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + JitScalarType \ No newline at end of file diff --git a/docs/source/onnx_torchscript_supported_aten_ops.rst b/docs/source/onnx_torchscript_supported_aten_ops.rst new file mode 100644 index 0000000000000..e5a795a903860 --- /dev/null +++ b/docs/source/onnx_torchscript_supported_aten_ops.rst @@ -0,0 +1,30 @@ +:orphan: + +ONNX supported TorchScript operators +==================================== + +.. This file is automatically generated during the documentation build +.. by cross referencing ONNX operator symbolics with TorchScript operators via +.. ``docs/source/scripts/build_onnx_torchscript_supported_aten_op_csv_table.py``. +.. Do not modify directly and instead `rebuild the docs `_. + +This page lists the TorchScript operators that are supported/unsupported by ONNX export. + +Supported operators +------------------- + +.. csv-table:: ONNX support for TorchScript operators + :file: ../build/onnx/auto_gen_supported_op_list.csv + :widths: 70, 30 + :header-rows: 1 + + +Unsupported operators +--------------------- + +Operators that are not yet supported + +.. csv-table:: Unsupported operators + :file: ../build/onnx/auto_gen_unsupported_op_list.csv + :widths: 70, 30 + :header-rows: 1 diff --git a/docs/source/onnx_verification.md b/docs/source/onnx_verification.md index 4036aea8f81a7..f35f33f42d019 100644 --- a/docs/source/onnx_verification.md +++ b/docs/source/onnx_verification.md @@ -1,5 +1,8 @@ # torch.onnx.verification +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ```{eval-rst} .. automodule:: torch.onnx.verification ``` @@ -12,3 +15,26 @@ .. autoclass:: VerificationInfo :members: ``` +<<<<<<< HEAD +======= + +```{eval-rst} +.. autofunction:: verify +``` + +## Deprecated + +The following classes and functions are deprecated. + + +```{eval-rst} +.. py:class:: check_export_model_diff +.. py:class:: GraphInfo +.. py:class:: GraphInfoPrettyPrinter +.. py:class:: OnnxBackend +.. py:class:: OnnxTestCaseRepro +.. py:class:: VerificationOptions +.. py:function:: find_mismatch +.. py:function:: verify_aten_graph +``` +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/docs/source/optim.md b/docs/source/optim.md index 8c3174c76fb29..bfb4241b66d48 100644 --- a/docs/source/optim.md +++ b/docs/source/optim.md @@ -165,7 +165,10 @@ for input, target in dataset: Adamax ASGD LBFGS +<<<<<<< HEAD Muon +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) NAdam RAdam RMSprop @@ -211,7 +214,10 @@ Below is a table showing the available and default implementations of each algor :class:`Adamax`;foreach;yes;no :class:`ASGD`;foreach;yes;no :class:`LBFGS`;for-loop;no;no +<<<<<<< HEAD :class:`Muon`;for-loop;no;no +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) :class:`NAdam`;foreach;yes;no :class:`RAdam`;foreach;yes;no :class:`RMSprop`;foreach;yes;no @@ -235,7 +241,10 @@ Below table is showing the stability status for fused implementations: :class:`Adamax`;unsupported;unsupported;unsupported :class:`ASGD`;unsupported;unsupported;unsupported :class:`LBFGS`;unsupported;unsupported;unsupported +<<<<<<< HEAD :class:`Muon`;unsupported;unsupported;unsupported +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) :class:`NAdam`;unsupported;unsupported;unsupported :class:`RAdam`;unsupported;unsupported;unsupported :class:`RMSprop`;unsupported;unsupported;unsupported @@ -691,6 +700,7 @@ We train the model for a total of 300 epochs and start to collect EMA averages i ```{eval-rst} +<<<<<<< HEAD .. py:module:: torch.optim.lr_scheduler .. py:module:: torch.optim.optimizer .. py:module:: torch.optim.swa_utils @@ -702,3 +712,22 @@ for tracking purposes --> optim.aliases.md ``` +======= +.. py:module:: torch.optim.adadelta +.. py:module:: torch.optim.adagrad +.. py:module:: torch.optim.adam +.. py:module:: torch.optim.adamax +.. py:module:: torch.optim.adamw +.. py:module:: torch.optim.asgd +.. py:module:: torch.optim.lbfgs +.. py:module:: torch.optim.lr_scheduler +.. py:module:: torch.optim.nadam +.. py:module:: torch.optim.optimizer +.. py:module:: torch.optim.radam +.. py:module:: torch.optim.rmsprop +.. py:module:: torch.optim.rprop +.. py:module:: torch.optim.sgd +.. py:module:: torch.optim.sparse_adam +.. py:module:: torch.optim.swa_utils +``` +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/docs/source/package.md b/docs/source/package.md index 1b50f743d5793..ebbdf0adb7bcf 100644 --- a/docs/source/package.md +++ b/docs/source/package.md @@ -416,6 +416,24 @@ with PackageExporter(f2, importer=(importer, sys_importer)) as exporter: exporter.save_pickle("model", "model.pkl", obj) ``` +<<<<<<< HEAD +======= +### Package a TorchScript module? +To package a TorchScript model, use the same `save_pickle` and `load_pickle` APIs as you would with any other object. +Saving TorchScript objects that are attributes or submodules is supported as well with no extra work. + +```python +# save TorchScript just like any other object +with PackageExporter(file_name) as e: + e.save_pickle("res", "script_model.pkl", scripted_model) + e.save_pickle("res", "mixed_model.pkl", python_model_with_scripted_submodule) +# load as normal +importer = PackageImporter(file_name) +loaded_script = importer.load_pickle("res", "script_model.pkl") +loaded_mixed = importer.load_pickle("res", "mixed_model.pkl" +``` + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ## Explanation ### `torch.package` Format Overview diff --git a/docs/source/pytorch-api.md b/docs/source/pytorch-api.md index 6ebf94c47a357..9afb2a3f69652 100644 --- a/docs/source/pytorch-api.md +++ b/docs/source/pytorch-api.md @@ -1,4 +1,5 @@ (pytorch_api)= +<<<<<<< HEAD # Reference API ```{toctree} @@ -6,11 +7,17 @@ C++ ``` +======= +# Python API +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ```{toctree} :glob: :maxdepth: 1 +<<<<<<< HEAD :caption: Python API +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch nn diff --git a/docs/source/quantization-accuracy-debugging.md b/docs/source/quantization-accuracy-debugging.md new file mode 100644 index 0000000000000..d13d83129570a --- /dev/null +++ b/docs/source/quantization-accuracy-debugging.md @@ -0,0 +1,96 @@ +# Quantization Accuracy Debugging + +This document provides high level strategies for improving quantization +accuracy. If a quantized model has error compared to the original model, +we can categorize the error into: + +1. **data insensitive error** - caused by intrinsic model quantization error, + large portion of input data has large error +2. **data sensitive error** - caused by outlier input data, small + portion of input data has large error +3. **implementation error** - quantized kernel is not matching reference implementation + +## Data insensitive error + +### General tips + +1. For PTQ, ensure that the data you are calibrating with is representative + of your dataset. For example, for a classification problem a general + guideline is to have multiple samples in every category, and the overall + number of samples should be at least 100. There is no penalty for + calibrating with more data other than calibration time. +2. If your model has Conv-BN or Linear-BN patterns, consider fusing them. + If you are using FX graph mode quantization, this is done automatically + by the workflow. If you are using Eager mode quantization, you can do + this manually with the ``torch.ao.quantization.fuse_modules`` API. +3. Increase the precision of dtype of the problematic ops. Usually, fp32 + will have the highest accuracy, followed by fp16, followed by dynamically + quantized int8, followed by statically quantized int8. + + 1. Note: this is trading off performance for accuracy. + 2. Note: availability of kernels per dtype per op can vary by backend. + 3. Note: dtype conversions add an additional performance cost. For example, + ``fp32_op -> quant -> int8_op -> dequant -> fp32_op -> quant -> int8_op -> dequant`` + will have a performance penalty compared to + ``fp32_op -> fp32_op -> quant -> int8_op -> int8_op -> dequant`` + because of a higher number of required dtype conversions. + +4. If you are using PTQ, consider using QAT to recover some of the accuracy loss + from quantization. + +### Int8 quantization tips + +1. If you are using per-tensor weight quantization, consider using per-channel + weight quantization. +2. If you are doing inference on `fbgemm`, ensure that you set the `reduce_range` + argument to `False` if your CPU is Cooperlake or newer, and to `True` otherwise. +3. Audit the input activation distribution variation across different samples. + If this variation is high, the layer may be suitable for dynamic quantization + but not static quantization. + +## Data sensitive error + +If you are using static quantization and a small portion of your input data is +resulting in high quantization error, you can try: + +1. Adjust your calibration dataset to make it more representative of your + inference dataset. +2. Manually inspect (using Numeric Suite) which layers have high quantization + error. For these layers, consider leaving them in floating point or adjusting + the observer settings to choose a better scale and zero_point. + + +## Implementation error + +If you are using PyTorch quantization with your own backend +you may see differences between the reference implementation of an +operation (such as ``dequant -> op_fp32 -> quant``) and the quantized implementation +(such as `op_int8`) of the op on the target hardware. This could mean one of two things: + +1. the differences (usually small) are expected due to specific behavior of + the target kernel on the target hardware compared to fp32/cpu. An example of this + is accumulating in an integer dtype. Unless the kernel guarantees bitwise + equivalency with the reference implementation, this is expected. +2. the kernel on the target hardware has an accuracy issue. In this case, reach + out to the kernel developer. + +## Numerical Debugging Tooling (prototype) + +```{eval-rst} +.. toctree:: + :hidden: + + torch.ao.ns._numeric_suite + torch.ao.ns._numeric_suite_fx +``` + +```{warning} +Numerical debugging tooling is early prototype and subject to change. +``` + +```{eval-rst} +* :ref:`torch_ao_ns_numeric_suite` + Eager mode numeric suite +* :ref:`torch_ao_ns_numeric_suite_fx` + FX numeric suite +``` diff --git a/docs/source/quantization-backend-configuration.md b/docs/source/quantization-backend-configuration.md new file mode 100644 index 0000000000000..fb28fbef54387 --- /dev/null +++ b/docs/source/quantization-backend-configuration.md @@ -0,0 +1,19 @@ +# Quantization Backend Configuration + +FX Graph Mode Quantization allows the user to configure various +quantization behaviors of an op in order to match the expectation +of their backend. + +In the future, this document will contain a detailed spec of +these configurations. + +## Default values for native configurations + +Below is the output of the configuration for quantization of ops +in x86 and qnnpack (PyTorch's default quantized backends). + +Results: + +```{eval-rst} +.. literalinclude:: scripts/quantization_backend_configs/default_backend_config.txt +``` diff --git a/docs/source/quantization.rst b/docs/source/quantization.rst index d8f7c162b5e04..3ca7087976ef6 100644 --- a/docs/source/quantization.rst +++ b/docs/source/quantization.rst @@ -6,6 +6,7 @@ Quantization .. automodule:: torch.ao.quantization .. automodule:: torch.ao.quantization.fx +<<<<<<< HEAD We are cetralizing all quantization related development to `torchao `__, please checkout our new doc page: https://docs.pytorch.org/ao/stable/index.html Plan for the existing quantization flows: @@ -25,6 +26,895 @@ We plan to delete `torch.ao.quantization` in 2.10 if there are no blockers, or i Quantization API Reference (Kept since APIs are still public) ----------------------------------------------------------------- +======= +.. warning :: + Quantization is in beta and subject to change. + +Introduction to Quantization +---------------------------- + +Quantization refers to techniques for performing computations and storing +tensors at lower bitwidths than floating point precision. A quantized model +executes some or all of the operations on tensors with reduced precision rather than +full precision (floating point) values. This allows for a more compact model representation and +the use of high performance vectorized operations on many hardware platforms. +PyTorch supports INT8 quantization compared to typical FP32 models allowing for +a 4x reduction in the model size and a 4x reduction in memory bandwidth +requirements. Hardware support for INT8 computations is typically 2 to 4 +times faster compared to FP32 compute. Quantization is primarily a technique to +speed up inference and only the forward pass is supported for quantized +operators. + +PyTorch supports multiple approaches to quantizing a deep learning model. In +most cases the model is trained in FP32 and then the model is converted to +INT8. In addition, PyTorch also supports quantization aware training, which +models quantization errors in both the forward and backward passes using +fake-quantization modules. Note that the entire computation is carried out in +floating point. At the end of quantization aware training, PyTorch provides +conversion functions to convert the trained model into lower precision. + +At lower level, PyTorch provides a way to represent quantized tensors and +perform operations with them. They can be used to directly construct models +that perform all or part of the computation in lower precision. Higher-level +APIs are provided that incorporate typical workflows of converting FP32 model +to lower precision with minimal accuracy loss. + +Quantization API Summary +----------------------------- + +PyTorch provides three different modes of quantization: Eager Mode Quantization, FX Graph Mode Quantization (maintenance) and PyTorch 2 Export Quantization. + +Eager Mode Quantization is a beta feature. User needs to do fusion and specify where quantization and dequantization happens manually, also it only supports modules and not functionals. + +FX Graph Mode Quantization is an automated quantization workflow in PyTorch, and currently it's a prototype feature, it is in maintenance mode since we have PyTorch 2 Export Quantization. It improves upon Eager Mode Quantization by adding support for functionals and automating the quantization process, although people might need to refactor the model to make the model compatible with FX Graph Mode Quantization (symbolically traceable with ``torch.fx``). Note that FX Graph Mode Quantization is not expected to work on arbitrary models since the model might not be symbolically traceable, we will integrate it into domain libraries like torchvision and users will be able to quantize models similar to the ones in supported domain libraries with FX Graph Mode Quantization. For arbitrary models we'll provide general guidelines, but to actually make it work, users might need to be familiar with ``torch.fx``, especially on how to make a model symbolically traceable. + +PyTorch 2 Export Quantization is the new full graph mode quantization workflow, released as prototype feature in PyTorch 2.1. With PyTorch 2, we are moving to a better solution for full program capture (torch.export) since it can capture a higher percentage (88.8% on 14K models) of models compared to torch.fx.symbolic_trace (72.7% on 14K models), the program capture solution used by FX Graph Mode Quantization. torch.export still has limitations around some python constructs and requires user involvement to support dynamism in the exported model, but overall it is an improvement over the previous program capture solution. PyTorch 2 Export Quantization is built for models captured by torch.export, with flexibility and productivity of both modeling users and backend developers in mind. The main features are +(1). Programmable API for configuring how a model is quantized that can scale to many more use cases +(2). Simplified UX for modeling users and backend developers since they only need to interact with a single object (Quantizer) for expressing user’s intention about how to quantize a model and what the backend support. +(3). Optional reference quantized model representation that can represent quantized computation with integer operations that maps closer to actual quantized computations that happens in hardware. + +New users of quantization are encouraged to try out PyTorch 2 Export Quantization first, if it does not work well, user can try eager mode quantization. + +The following table compares the differences between Eager Mode Quantization, FX Graph Mode Quantization and PyTorch 2 Export Quantization: + ++-----------------+-------------------+-------------------+-------------------------+ +| |Eager Mode |FX Graph |PyTorch 2 Export | +| |Quantization |Mode |Quantization | +| | |Quantization | | ++-----------------+-------------------+-------------------+-------------------------+ +|Release |beta |prototype |prototype | +|Status | |(maintenance) | | ++-----------------+-------------------+-------------------+-------------------------+ +|Operator |Manual |Automatic |Automatic | +|Fusion | | | | ++-----------------+-------------------+-------------------+-------------------------+ +|Quant/DeQuant |Manual |Automatic |Automatic | +|Placement | | | | ++-----------------+-------------------+-------------------+-------------------------+ +|Quantizing |Supported |Supported |Supported | +|Modules | | | | ++-----------------+-------------------+-------------------+-------------------------+ +|Quantizing |Manual |Automatic |Supported | +|Functionals/Torch| | | | +|Ops | | | | ++-----------------+-------------------+-------------------+-------------------------+ +|Support for |Limited Support |Fully |Fully Supported | +|Customization | |Supported | | ++-----------------+-------------------+-------------------+-------------------------+ +|Quantization Mode|Post Training |Post Training |Defined by | +|Support |Quantization: |Quantization: |Backend Specific | +| |Static, Dynamic, |Static, Dynamic, |Quantizer | +| |Weight Only |Weight Only | | +| | | | | +| |Quantization Aware |Quantization Aware | | +| |Training: |Training: | | +| |Static |Static | | ++-----------------+-------------------+-------------------+-------------------------+ +|Input/Output |``torch.nn.Module``|``torch.nn.Module``|``torch.fx.GraphModule`` | +|Model Type | |(May need some |(captured by | +| | |refactors to make |``torch.export`` | +| | |the model | | +| | |compatible with FX | | +| | |Graph Mode | | +| | |Quantization) | | ++-----------------+-------------------+-------------------+-------------------------+ + + + +There are three types of quantization supported: + +1. dynamic quantization (weights quantized with activations read/stored in + floating point and quantized for compute) +2. static quantization (weights quantized, activations quantized, calibration + required post training) +3. static quantization aware training (weights quantized, activations quantized, + quantization numerics modeled during training) + +Please see our `Introduction to Quantization on PyTorch +`_ blog post +for a more comprehensive overview of the tradeoffs between these quantization +types. + +Operator coverage varies between dynamic and static quantization and is captured in the table below. + ++---------------------------+-------------------+--------------------+ +| |Static | Dynamic | +| |Quantization | Quantization | ++---------------------------+-------------------+--------------------+ +| | nn.Linear | | Y | | Y | +| | nn.Conv1d/2d/3d | | Y | | N | ++---------------------------+-------------------+--------------------+ +| | nn.LSTM | | Y (through | | Y | +| | | | custom modules) | | | +| | nn.GRU | | N | | Y | ++---------------------------+-------------------+--------------------+ +| | nn.RNNCell | | N | | Y | +| | nn.GRUCell | | N | | Y | +| | nn.LSTMCell | | N | | Y | ++---------------------------+-------------------+--------------------+ +|nn.EmbeddingBag | Y (activations | | +| | are in fp32) | Y | ++---------------------------+-------------------+--------------------+ +|nn.Embedding | Y | Y | ++---------------------------+-------------------+--------------------+ +| nn.MultiheadAttention | Y (through | Not supported | +| | custom modules) | | ++---------------------------+-------------------+--------------------+ +| Activations | Broadly supported | Un-changed, | +| | | computations | +| | | stay in fp32 | ++---------------------------+-------------------+--------------------+ + + +Eager Mode Quantization +^^^^^^^^^^^^^^^^^^^^^^^ +For a general introduction to the quantization flow, including different types of quantization, please take a look at `General Quantization Flow`_. + +Post Training Dynamic Quantization +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +This is the simplest to apply form of quantization where the weights are +quantized ahead of time but the activations are dynamically quantized +during inference. This is used for situations where the model execution time +is dominated by loading weights from memory rather than computing the matrix +multiplications. This is true for LSTM and Transformer type models with +small batch size. + +Diagram:: + + # original model + # all tensors and computations are in floating point + previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32 + / + linear_weight_fp32 + + # dynamically quantized model + # linear and LSTM weights are in int8 + previous_layer_fp32 -- linear_int8_w_fp32_inp -- activation_fp32 -- next_layer_fp32 + / + linear_weight_int8 + +PTDQ API Example:: + + import torch + + # define a floating point model + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(4, 4) + + def forward(self, x): + x = self.fc(x) + return x + + # create a model instance + model_fp32 = M() + # create a quantized model instance + model_int8 = torch.ao.quantization.quantize_dynamic( + model_fp32, # the original model + {torch.nn.Linear}, # a set of layers to dynamically quantize + dtype=torch.qint8) # the target dtype for quantized weights + + # run the model + input_fp32 = torch.randn(4, 4, 4, 4) + res = model_int8(input_fp32) + +To learn more about dynamic quantization please see our `dynamic quantization tutorial +`_. + +Post Training Static Quantization +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Post Training Static Quantization (PTQ static) quantizes the weights and activations of the model. It +fuses activations into preceding layers where possible. It requires +calibration with a representative dataset to determine optimal quantization +parameters for activations. Post Training Static Quantization is typically used when +both memory bandwidth and compute savings are important with CNNs being a +typical use case. + +We may need to modify the model before applying post training static quantization. Please see `Model Preparation for Eager Mode Static Quantization`_. + +Diagram:: + + # original model + # all tensors and computations are in floating point + previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32 + / + linear_weight_fp32 + + # statically quantized model + # weights and activations are in int8 + previous_layer_int8 -- linear_with_activation_int8 -- next_layer_int8 + / + linear_weight_int8 + +PTSQ API Example:: + + import torch + + # define a floating point model where some layers could be statically quantized + class M(torch.nn.Module): + def __init__(self): + super().__init__() + # QuantStub converts tensors from floating point to quantized + self.quant = torch.ao.quantization.QuantStub() + self.conv = torch.nn.Conv2d(1, 1, 1) + self.relu = torch.nn.ReLU() + # DeQuantStub converts tensors from quantized to floating point + self.dequant = torch.ao.quantization.DeQuantStub() + + def forward(self, x): + # manually specify where tensors will be converted from floating + # point to quantized in the quantized model + x = self.quant(x) + x = self.conv(x) + x = self.relu(x) + # manually specify where tensors will be converted from quantized + # to floating point in the quantized model + x = self.dequant(x) + return x + + # create a model instance + model_fp32 = M() + + # model must be set to eval mode for static quantization logic to work + model_fp32.eval() + + # attach a global qconfig, which contains information about what kind + # of observers to attach. Use 'x86' for server inference and 'qnnpack' + # for mobile inference. Other quantization configurations such as selecting + # symmetric or asymmetric quantization and MinMax or L2Norm calibration techniques + # can be specified here. + # Note: the old 'fbgemm' is still available but 'x86' is the recommended default + # for server inference. + # model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm') + model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('x86') + + # Fuse the activations to preceding layers, where applicable. + # This needs to be done manually depending on the model architecture. + # Common fusions include `conv + relu` and `conv + batchnorm + relu` + model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32, [['conv', 'relu']]) + + # Prepare the model for static quantization. This inserts observers in + # the model that will observe activation tensors during calibration. + model_fp32_prepared = torch.ao.quantization.prepare(model_fp32_fused) + + # calibrate the prepared model to determine quantization parameters for activations + # in a real world setting, the calibration would be done with a representative dataset + input_fp32 = torch.randn(4, 1, 4, 4) + model_fp32_prepared(input_fp32) + + # Convert the observed model to a quantized model. This does several things: + # quantizes the weights, computes and stores the scale and bias value to be + # used with each activation tensor, and replaces key operators with quantized + # implementations. + model_int8 = torch.ao.quantization.convert(model_fp32_prepared) + + # run the model, relevant calculations will happen in int8 + res = model_int8(input_fp32) + +To learn more about static quantization, please see the `static quantization tutorial +`_. + +Quantization Aware Training for Static Quantization +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Quantization Aware Training (QAT) models the effects of quantization during training +allowing for higher accuracy compared to other quantization methods. We can do QAT for static, dynamic or weight only quantization. During +training, all calculations are done in floating point, with fake_quant modules +modeling the effects of quantization by clamping and rounding to simulate the +effects of INT8. After model conversion, weights and +activations are quantized, and activations are fused into the preceding layer +where possible. It is commonly used with CNNs and yields a higher accuracy +compared to static quantization. + +We may need to modify the model before applying post training static quantization. Please see `Model Preparation for Eager Mode Static Quantization`_. + +Diagram:: + + # original model + # all tensors and computations are in floating point + previous_layer_fp32 -- linear_fp32 -- activation_fp32 -- next_layer_fp32 + / + linear_weight_fp32 + + # model with fake_quants for modeling quantization numerics during training + previous_layer_fp32 -- fq -- linear_fp32 -- activation_fp32 -- fq -- next_layer_fp32 + / + linear_weight_fp32 -- fq + + # quantized model + # weights and activations are in int8 + previous_layer_int8 -- linear_with_activation_int8 -- next_layer_int8 + / + linear_weight_int8 + +QAT API Example:: + + import torch + + # define a floating point model where some layers could benefit from QAT + class M(torch.nn.Module): + def __init__(self): + super().__init__() + # QuantStub converts tensors from floating point to quantized + self.quant = torch.ao.quantization.QuantStub() + self.conv = torch.nn.Conv2d(1, 1, 1) + self.bn = torch.nn.BatchNorm2d(1) + self.relu = torch.nn.ReLU() + # DeQuantStub converts tensors from quantized to floating point + self.dequant = torch.ao.quantization.DeQuantStub() + + def forward(self, x): + x = self.quant(x) + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + x = self.dequant(x) + return x + + # create a model instance + model_fp32 = M() + + # model must be set to eval for fusion to work + model_fp32.eval() + + # attach a global qconfig, which contains information about what kind + # of observers to attach. Use 'x86' for server inference and 'qnnpack' + # for mobile inference. Other quantization configurations such as selecting + # symmetric or asymmetric quantization and MinMax or L2Norm calibration techniques + # can be specified here. + # Note: the old 'fbgemm' is still available but 'x86' is the recommended default + # for server inference. + # model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm') + model_fp32.qconfig = torch.ao.quantization.get_default_qat_qconfig('x86') + + # fuse the activations to preceding layers, where applicable + # this needs to be done manually depending on the model architecture + model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32, + [['conv', 'bn', 'relu']]) + + # Prepare the model for QAT. This inserts observers and fake_quants in + # the model needs to be set to train for QAT logic to work + # the model that will observe weight and activation tensors during calibration. + model_fp32_prepared = torch.ao.quantization.prepare_qat(model_fp32_fused.train()) + + # run the training loop (not shown) + training_loop(model_fp32_prepared) + + # Convert the observed model to a quantized model. This does several things: + # quantizes the weights, computes and stores the scale and bias value to be + # used with each activation tensor, fuses modules where appropriate, + # and replaces key operators with quantized implementations. + model_fp32_prepared.eval() + model_int8 = torch.ao.quantization.convert(model_fp32_prepared) + + # run the model, relevant calculations will happen in int8 + res = model_int8(input_fp32) + +To learn more about quantization aware training, please see the `QAT +tutorial +`_. + +Model Preparation for Eager Mode Static Quantization +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +It is necessary to currently make some modifications to the model definition +prior to Eager mode quantization. This is because currently quantization works on a module +by module basis. Specifically, for all quantization techniques, the user needs to: + +1. Convert any operations that require output requantization (and thus have + additional parameters) from functionals to module form (for example, + using ``torch.nn.ReLU`` instead of ``torch.nn.functional.relu``). +2. Specify which parts of the model need to be quantized either by assigning + ``.qconfig`` attributes on submodules or by specifying ``qconfig_mapping``. + For example, setting ``model.conv1.qconfig = None`` means that the + ``model.conv`` layer will not be quantized, and setting + ``model.linear1.qconfig = custom_qconfig`` means that the quantization + settings for ``model.linear1`` will be using ``custom_qconfig`` instead + of the global qconfig. + +For static quantization techniques which quantize activations, the user needs +to do the following in addition: + +1. Specify where activations are quantized and de-quantized. This is done using + :class:`~torch.ao.quantization.QuantStub` and + :class:`~torch.ao.quantization.DeQuantStub` modules. +2. Use :class:`~torch.ao.nn.quantized.FloatFunctional` to wrap tensor operations + that require special handling for quantization into modules. Examples + are operations like ``add`` and ``cat`` which require special handling to + determine output quantization parameters. +3. Fuse modules: combine operations/modules into a single module to obtain + higher accuracy and performance. This is done using the + :func:`~torch.ao.quantization.fuse_modules.fuse_modules` API, which takes in lists of modules + to be fused. We currently support the following fusions: + [Conv, Relu], [Conv, BatchNorm], [Conv, BatchNorm, Relu], [Linear, Relu] + +(Prototype - maintenance mode) FX Graph Mode Quantization +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +There are multiple quantization types in post training quantization (weight only, dynamic and static) and the configuration is done through `qconfig_mapping` (an argument of the `prepare_fx` function). + +FXPTQ API Example:: + + import torch + from torch.ao.quantization import ( + get_default_qconfig_mapping, + get_default_qat_qconfig_mapping, + QConfigMapping, + ) + import torch.ao.quantization.quantize_fx as quantize_fx + import copy + + model_fp = UserModel() + + # + # post training dynamic/weight_only quantization + # + + # we need to deepcopy if we still want to keep model_fp unchanged after quantization since quantization apis change the input model + model_to_quantize = copy.deepcopy(model_fp) + model_to_quantize.eval() + qconfig_mapping = QConfigMapping().set_global(torch.ao.quantization.default_dynamic_qconfig) + # a tuple of one or more example inputs are needed to trace the model + example_inputs = (input_fp32) + # prepare + model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs) + # no calibration needed when we only have dynamic/weight_only quantization + # quantize + model_quantized = quantize_fx.convert_fx(model_prepared) + + # + # post training static quantization + # + + model_to_quantize = copy.deepcopy(model_fp) + qconfig_mapping = get_default_qconfig_mapping("qnnpack") + model_to_quantize.eval() + # prepare + model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs) + # calibrate (not shown) + # quantize + model_quantized = quantize_fx.convert_fx(model_prepared) + + # + # quantization aware training for static quantization + # + + model_to_quantize = copy.deepcopy(model_fp) + qconfig_mapping = get_default_qat_qconfig_mapping("qnnpack") + model_to_quantize.train() + # prepare + model_prepared = quantize_fx.prepare_qat_fx(model_to_quantize, qconfig_mapping, example_inputs) + # training loop (not shown) + # quantize + model_quantized = quantize_fx.convert_fx(model_prepared) + + # + # fusion + # + model_to_quantize = copy.deepcopy(model_fp) + model_fused = quantize_fx.fuse_fx(model_to_quantize) + +Please follow the tutorials below to learn more about FX Graph Mode Quantization: + +- `User Guide on Using FX Graph Mode Quantization `_ +- `FX Graph Mode Post Training Static Quantization `_ +- `FX Graph Mode Post Training Dynamic Quantization `_ + +(Prototype) PyTorch 2 Export Quantization +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +API Example:: + + import torch + from torch.ao.quantization.quantize_pt2e import prepare_pt2e + from torch.export import export_for_training + from torch.ao.quantization.quantizer import ( + XNNPACKQuantizer, + get_symmetric_quantization_config, + ) + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(5, 10) + + def forward(self, x): + return self.linear(x) + + # initialize a floating point model + float_model = M().eval() + + # define calibration function + def calibrate(model, data_loader): + model.eval() + with torch.no_grad(): + for image, target in data_loader: + model(image) + + # Step 1. program capture + # NOTE: this API will be updated to torch.export API in the future, but the captured + # result should mostly stay the same + m = export_for_training(m, *example_inputs).module() + # we get a model with aten ops + + # Step 2. quantization + # backend developer will write their own Quantizer and expose methods to allow + # users to express how they + # want the model to be quantized + quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config()) + # or prepare_qat_pt2e for Quantization Aware Training + m = prepare_pt2e(m, quantizer) + + # run calibration + # calibrate(m, sample_inference_data) + m = convert_pt2e(m) + + # Step 3. lowering + # lower to target backend + + +Please follow these tutorials to get started on PyTorch 2 Export Quantization: + +Modeling Users: + +- `PyTorch 2 Export Post Training Quantization `_ +- `PyTorch 2 Export Post Training Quantization with X86 Backend through Inductor `_ +- `PyTorch 2 Export Quantization Aware Training `_ + +Backend Developers (please check out all Modeling Users docs as well): + +- `How to Write a Quantizer for PyTorch 2 Export Quantization `_ + + +Quantization Stack +------------------------ +Quantization is the process to convert a floating point model to a quantized model. So at high level the quantization stack can be split into two parts: 1). The building blocks or abstractions for a quantized model 2). The building blocks or abstractions for the quantization flow that converts a floating point model to a quantized model + +Quantized Model +^^^^^^^^^^^^^^^^^^^^^^^ +Quantized Tensor +~~~~~~~~~~~~~~~~~ +In order to do quantization in PyTorch, we need to be able to represent +quantized data in Tensors. A Quantized Tensor allows for storing +quantized data (represented as int8/uint8/int32) along with quantization +parameters like scale and zero\_point. Quantized Tensors allow for many +useful operations making quantized arithmetic easy, in addition to +allowing for serialization of data in a quantized format. + +PyTorch supports both per tensor and per channel symmetric and asymmetric quantization. Per tensor means that all the values within the tensor are quantized the same way with the same quantization parameters. Per channel means that for each dimension, typically the channel dimension of a tensor, the values in the tensor are quantized with different quantization parameters. This allows for less error in converting tensors to quantized values since outlier values would only impact the channel it was in, instead of the entire Tensor. + +The mapping is performed by converting the floating point tensors using + +.. image:: math-quantizer-equation.png + :width: 40% + +Note that, we ensure that zero in floating point is represented with no error +after quantization, thereby ensuring that operations like padding do not cause +additional quantization error. + +Here are a few key attributes for quantized Tensor: + +* QScheme (torch.qscheme): a enum that specifies the way we quantize the Tensor + + * torch.per_tensor_affine + * torch.per_tensor_symmetric + * torch.per_channel_affine + * torch.per_channel_symmetric + +* dtype (torch.dtype): data type of the quantized Tensor + + * torch.quint8 + * torch.qint8 + * torch.qint32 + * torch.float16 + +* quantization parameters (varies based on QScheme): parameters for the chosen way of quantization + + * torch.per_tensor_affine would have quantization parameters of + + * scale (float) + * zero_point (int) + * torch.per_channel_affine would have quantization parameters of + + * per_channel_scales (list of float) + * per_channel_zero_points (list of int) + * axis (int) + +Quantize and Dequantize +~~~~~~~~~~~~~~~~~~~~~~~ +The input and output of a model are floating point Tensors, but activations in the quantized model are quantized, so we need operators to convert between floating point and quantized Tensors. + +* Quantize (float -> quantized) + + * torch.quantize_per_tensor(x, scale, zero_point, dtype) + * torch.quantize_per_channel(x, scales, zero_points, axis, dtype) + * torch.quantize_per_tensor_dynamic(x, dtype, reduce_range) + * to(torch.float16) + +* Dequantize (quantized -> float) + + * quantized_tensor.dequantize() - calling dequantize on a torch.float16 Tensor will convert the Tensor back to torch.float + * torch.dequantize(x) + +Quantized Operators/Modules +~~~~~~~~~~~~~~~~~~~~~~~~~~~ +* Quantized Operator are the operators that takes quantized Tensor as inputs, and outputs a quantized Tensor. +* Quantized Modules are PyTorch Modules that performs quantized operations. They are typically defined for weighted operations like linear and conv. + +Quantized Engine +~~~~~~~~~~~~~~~~~~~~ +When a quantized model is executed, the qengine (torch.backends.quantized.engine) specifies which backend is to be used for execution. It is important to ensure that the qengine is compatible with the quantized model in terms of value range of quantized activation and weights. + +Quantization Flow +^^^^^^^^^^^^^^^^^^^^^^^ + +Observer and FakeQuantize +~~~~~~~~~~~~~~~~~~~~~~~~~~ +* Observer are PyTorch Modules used to: + + * collect tensor statistics like min value and max value of the Tensor passing through the observer + * and calculate quantization parameters based on the collected tensor statistics +* FakeQuantize are PyTorch Modules used to: + + * simulate quantization (performing quantize/dequantize) for a Tensor in the network + * it can calculate quantization parameters based on the collected statistics from observer, or it can learn the quantization parameters as well + +QConfig +~~~~~~~~~~~ +* QConfig is a namedtuple of Observer or FakeQuantize Module class that can are configurable with qscheme, dtype etc. it is used to configure how an operator should be observed + + * Quantization configuration for an operator/module + + * different types of Observer/FakeQuantize + * dtype + * qscheme + * quant_min/quant_max: can be used to simulate lower precision Tensors + * Currently supports configuration for activation and weight + * We insert input/weight/output observer based on the qconfig that is configured for a given operator or module + +General Quantization Flow +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +In general, the flow is the following + +* prepare + + * insert Observer/FakeQuantize modules based on user specified qconfig + +* calibrate/train (depending on post training quantization or quantization aware training) + + * allow Observers to collect statistics or FakeQuantize modules to learn the quantization parameters + +* convert + + * convert a calibrated/trained model to a quantized model + +There are different modes of quantization, they can be classified in two ways: + +In terms of where we apply the quantization flow, we have: + +1. Post Training Quantization (apply quantization after training, quantization parameters are calculated based on sample calibration data) +2. Quantization Aware Training (simulate quantization during training so that the quantization parameters can be learned together with the model using training data) + +And in terms of how we quantize the operators, we can have: + +- Weight Only Quantization (only weight is statically quantized) +- Dynamic Quantization (weight is statically quantized, activation is dynamically quantized) +- Static Quantization (both weight and activations are statically quantized) + +We can mix different ways of quantizing operators in the same quantization flow. For example, we can have post training quantization that has both statically and dynamically quantized operators. + +Quantization Support Matrix +-------------------------------------- +Quantization Mode Support +^^^^^^^^^^^^^^^^^^^^^^^^^^^ ++-----------------------------+------------------------------------------------------+----------------+----------------+------------+-----------------+ +| |Quantization |Dataset | Works Best For | Accuracy | Notes | +| |Mode |Requirement | | | | ++-----------------------------+---------------------------------+--------------------+----------------+----------------+------------+-----------------+ +|Post Training Quantization |Dynamic/Weight Only Quantization |activation |None |LSTM, MLP, |good |Easy to use, | +| | |dynamically | |Embedding, | |close to static | +| | |quantized (fp16, | |Transformer | |quantization when| +| | |int8) or not | | | |performance is | +| | |quantized, weight | | | |compute or memory| +| | |statically quantized| | | |bound due to | +| | |(fp16, int8, in4) | | | |weights | +| +---------------------------------+--------------------+----------------+----------------+------------+-----------------+ +| |Static Quantization |activation and |calibration |CNN |good |Provides best | +| | |weights statically |dataset | | |perf, may have | +| | |quantized (int8) | | | |big impact on | +| | | | | | |accuracy, good | +| | | | | | |for hardwares | +| | | | | | |that only support| +| | | | | | |int8 computation | ++-----------------------------+---------------------------------+--------------------+----------------+----------------+------------+-----------------+ +| |Dynamic Quantization |activation and |fine-tuning |MLP, Embedding |best |Limited support | +| | |weight are fake |dataset | | |for now | +| | |quantized | | | | | +| +---------------------------------+--------------------+----------------+----------------+------------+-----------------+ +| |Static Quantization |activation and |fine-tuning |CNN, MLP, |best |Typically used | +| | |weight are fake |dataset |Embedding | |when static | +| | |quantized | | | |quantization | +| | | | | | |leads to bad | +| | | | | | |accuracy, and | +| | | | | | |used to close the| +| | | | | | |accuracy gap | +|Quantization Aware Training | | | | | | | ++-----------------------------+---------------------------------+--------------------+----------------+----------------+------------+-----------------+ + +Please see our `Introduction to Quantization on Pytorch +`_ blog post +for a more comprehensive overview of the tradeoffs between these quantization +types. + +Quantization Flow Support +^^^^^^^^^^^^^^^^^^^^^^^^^^^ +PyTorch provides two modes of quantization: Eager Mode Quantization and FX Graph Mode Quantization. + +Eager Mode Quantization is a beta feature. User needs to do fusion and specify where quantization and dequantization happens manually, also it only supports modules and not functionals. + +FX Graph Mode Quantization is an automated quantization framework in PyTorch, and currently it's a prototype feature. It improves upon Eager Mode Quantization by adding support for functionals and automating the quantization process, although people might need to refactor the model to make the model compatible with FX Graph Mode Quantization (symbolically traceable with ``torch.fx``). Note that FX Graph Mode Quantization is not expected to work on arbitrary models since the model might not be symbolically traceable, we will integrate it into domain libraries like torchvision and users will be able to quantize models similar to the ones in supported domain libraries with FX Graph Mode Quantization. For arbitrary models we'll provide general guidelines, but to actually make it work, users might need to be familiar with ``torch.fx``, especially on how to make a model symbolically traceable. + +New users of quantization are encouraged to try out FX Graph Mode Quantization first, if it does not work, user may try to follow the guideline of `using FX Graph Mode Quantization `_ or fall back to eager mode quantization. + +The following table compares the differences between Eager Mode Quantization and FX Graph Mode Quantization: + ++-----------------+-------------------+-------------------+ +| |Eager Mode |FX Graph | +| |Quantization |Mode | +| | |Quantization | ++-----------------+-------------------+-------------------+ +|Release |beta |prototype | +|Status | | | ++-----------------+-------------------+-------------------+ +|Operator |Manual |Automatic | +|Fusion | | | ++-----------------+-------------------+-------------------+ +|Quant/DeQuant |Manual |Automatic | +|Placement | | | ++-----------------+-------------------+-------------------+ +|Quantizing |Supported |Supported | +|Modules | | | ++-----------------+-------------------+-------------------+ +|Quantizing |Manual |Automatic | +|Functionals/Torch| | | +|Ops | | | ++-----------------+-------------------+-------------------+ +|Support for |Limited Support |Fully | +|Customization | |Supported | ++-----------------+-------------------+-------------------+ +|Quantization Mode|Post Training |Post Training | +|Support |Quantization: |Quantization: | +| |Static, Dynamic, |Static, Dynamic, | +| |Weight Only |Weight Only | +| | | | +| |Quantization Aware |Quantization Aware | +| |Training: |Training: | +| |Static |Static | ++-----------------+-------------------+-------------------+ +|Input/Output |``torch.nn.Module``|``torch.nn.Module``| +|Model Type | |(May need some | +| | |refactors to make | +| | |the model | +| | |compatible with FX | +| | |Graph Mode | +| | |Quantization) | ++-----------------+-------------------+-------------------+ + +Backend/Hardware Support +^^^^^^^^^^^^^^^^^^^^^^^^^^^ ++-----------------+---------------+------------+------------+------------+ +|Hardware |Kernel Library |Eager Mode |FX Graph |Quantization| +| | |Quantization|Mode |Mode Support| +| | | |Quantization| | ++-----------------+---------------+------------+------------+------------+ +|server CPU |fbgemm/onednn |Supported |All | +| | | |Supported | ++-----------------+---------------+ | + +|mobile CPU |qnnpack/xnnpack| | | +| | | | | ++-----------------+---------------+------------+------------+------------+ +|server GPU |TensorRT (early|Not support |Supported |Static | +| |prototype) |this it | |Quantization| +| | |requires a | | | +| | |graph | | | ++-----------------+---------------+------------+------------+------------+ + +Today, PyTorch supports the following backends for running quantized operators efficiently: + +* x86 CPUs with AVX2 support or higher (without AVX2 some operations have inefficient implementations), via `x86` optimized by `fbgemm `_ and `onednn `_ (see the details at `RFC `_) +* ARM CPUs (typically found in mobile/embedded devices), via `qnnpack `_ +* (early prototype) support for NVidia GPU via `TensorRT `_ through `fx2trt` (to be open sourced) + + +Note for native CPU backends +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +We expose both `x86` and `qnnpack` with the same native pytorch quantized operators, so we need additional flag to distinguish between them. The corresponding implementation of `x86` and `qnnpack` is chosen automatically based on the PyTorch build mode, though users have the option to override this by setting `torch.backends.quantization.engine` to `x86` or `qnnpack`. + +When preparing a quantized model, it is necessary to ensure that qconfig +and the engine used for quantized computations match the backend on which +the model will be executed. The qconfig controls the type of observers used +during the quantization passes. The qengine controls whether `x86` or `qnnpack` +specific packing function is used when packing weights for +linear and convolution functions and modules. For example: + +Default settings for x86:: + + # set the qconfig for PTQ + # Note: the old 'fbgemm' is still available but 'x86' is the recommended default on x86 CPUs + qconfig = torch.ao.quantization.get_default_qconfig('x86') + # or, set the qconfig for QAT + qconfig = torch.ao.quantization.get_default_qat_qconfig('x86') + # set the qengine to control weight packing + torch.backends.quantized.engine = 'x86' + +Default settings for qnnpack:: + + # set the qconfig for PTQ + qconfig = torch.ao.quantization.get_default_qconfig('qnnpack') + # or, set the qconfig for QAT + qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack') + # set the qengine to control weight packing + torch.backends.quantized.engine = 'qnnpack' + +Operator Support +^^^^^^^^^^^^^^^^^^^^ + +Operator coverage varies between dynamic and static quantization and is captured in the table below. +Note that for FX Graph Mode Quantization, the corresponding functionals are also supported. + ++---------------------------+-------------------+--------------------+ +| |Static | Dynamic | +| |Quantization | Quantization | ++---------------------------+-------------------+--------------------+ +| | nn.Linear | | Y | | Y | +| | nn.Conv1d/2d/3d | | Y | | N | ++---------------------------+-------------------+--------------------+ +| | nn.LSTM | | N | | Y | +| | nn.GRU | | N | | Y | ++---------------------------+-------------------+--------------------+ +| | nn.RNNCell | | N | | Y | +| | nn.GRUCell | | N | | Y | +| | nn.LSTMCell | | N | | Y | ++---------------------------+-------------------+--------------------+ +|nn.EmbeddingBag | Y (activations | | +| | are in fp32) | Y | ++---------------------------+-------------------+--------------------+ +|nn.Embedding | Y | Y | ++---------------------------+-------------------+--------------------+ +|nn.MultiheadAttention |Not Supported | Not supported | ++---------------------------+-------------------+--------------------+ +|Activations |Broadly supported | Un-changed, | +| | | computations | +| | | stay in fp32 | ++---------------------------+-------------------+--------------------+ + +Note: this will be updated with some information generated from native backend_config_dict soon. + +Quantization API Reference +--------------------------- +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) The :doc:`Quantization API Reference ` contains documentation of quantization APIs, such as quantization passes, quantized tensor operations, @@ -35,6 +925,339 @@ and supported quantized modules and functions. quantization-support +<<<<<<< HEAD +======= +Quantization Backend Configuration +---------------------------------- + +The :doc:`Quantization Backend Configuration ` contains documentation +on how to configure the quantization workflows for various backends. + +.. toctree:: + :hidden: + + quantization-backend-configuration + +Quantization Accuracy Debugging +------------------------------- + +The :doc:`Quantization Accuracy Debugging ` contains documentation +on how to debug quantization accuracy. + +.. toctree:: + :hidden: + + quantization-accuracy-debugging + +Quantization Customizations +--------------------------- + +While default implementations of observers to select the scale factor and bias +based on observed tensor data are provided, developers can provide their own +quantization functions. Quantization can be applied selectively to different +parts of the model or configured differently for different parts of the model. + +We also provide support for per channel quantization for **conv1d()**, **conv2d()**, +**conv3d()** and **linear()**. + +Quantization workflows work by adding (e.g. adding observers as +``.observer`` submodule) or replacing (e.g. converting ``nn.Conv2d`` to +``nn.quantized.Conv2d``) submodules in the model's module hierarchy. It +means that the model stays a regular ``nn.Module``-based instance throughout the +process and thus can work with the rest of PyTorch APIs. + +Quantization Custom Module API +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Both Eager mode and FX graph mode quantization APIs provide a hook for the user +to specify module quantized in a custom way, with user defined logic for +observation and quantization. The user needs to specify: + +1. The Python type of the source fp32 module (existing in the model) +2. The Python type of the observed module (provided by user). This module needs + to define a `from_float` function which defines how the observed module is + created from the original fp32 module. +3. The Python type of the quantized module (provided by user). This module needs + to define a `from_observed` function which defines how the quantized module is + created from the observed module. +4. A configuration describing (1), (2), (3) above, passed to the quantization APIs. + + +The framework will then do the following: + +1. during the `prepare` module swaps, it will convert every module of type + specified in (1) to the type specified in (2), using the `from_float` function of + the class in (2). +2. during the `convert` module swaps, it will convert every module of type + specified in (2) to the type specified in (3), using the `from_observed` function + of the class in (3). + +Currently, there is a requirement that `ObservedCustomModule` will have a single +Tensor output, and an observer will be added by the framework (not by the user) +on that output. The observer will be stored under the `activation_post_process` key +as an attribute of the custom module instance. Relaxing these restrictions may +be done at a future time. + +Custom API Example:: + + import torch + import torch.ao.nn.quantized as nnq + from torch.ao.quantization import QConfigMapping + import torch.ao.quantization.quantize_fx + + # original fp32 module to replace + class CustomModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 3) + + def forward(self, x): + return self.linear(x) + + # custom observed module, provided by user + class ObservedCustomModule(torch.nn.Module): + def __init__(self, linear): + super().__init__() + self.linear = linear + + def forward(self, x): + return self.linear(x) + + @classmethod + def from_float(cls, float_module): + assert hasattr(float_module, 'qconfig') + observed = cls(float_module.linear) + observed.qconfig = float_module.qconfig + return observed + + # custom quantized module, provided by user + class StaticQuantCustomModule(torch.nn.Module): + def __init__(self, linear): + super().__init__() + self.linear = linear + + def forward(self, x): + return self.linear(x) + + @classmethod + def from_observed(cls, observed_module): + assert hasattr(observed_module, 'qconfig') + assert hasattr(observed_module, 'activation_post_process') + observed_module.linear.activation_post_process = \ + observed_module.activation_post_process + quantized = cls(nnq.Linear.from_float(observed_module.linear)) + return quantized + + # + # example API call (Eager mode quantization) + # + + m = torch.nn.Sequential(CustomModule()).eval() + prepare_custom_config_dict = { + "float_to_observed_custom_module_class": { + CustomModule: ObservedCustomModule + } + } + convert_custom_config_dict = { + "observed_to_quantized_custom_module_class": { + ObservedCustomModule: StaticQuantCustomModule + } + } + m.qconfig = torch.ao.quantization.default_qconfig + mp = torch.ao.quantization.prepare( + m, prepare_custom_config_dict=prepare_custom_config_dict) + # calibration (not shown) + mq = torch.ao.quantization.convert( + mp, convert_custom_config_dict=convert_custom_config_dict) + # + # example API call (FX graph mode quantization) + # + m = torch.nn.Sequential(CustomModule()).eval() + qconfig_mapping = QConfigMapping().set_global(torch.ao.quantization.default_qconfig) + prepare_custom_config_dict = { + "float_to_observed_custom_module_class": { + "static": { + CustomModule: ObservedCustomModule, + } + } + } + convert_custom_config_dict = { + "observed_to_quantized_custom_module_class": { + "static": { + ObservedCustomModule: StaticQuantCustomModule, + } + } + } + mp = torch.ao.quantization.quantize_fx.prepare_fx( + m, qconfig_mapping, torch.randn(3,3), prepare_custom_config=prepare_custom_config_dict) + # calibration (not shown) + mq = torch.ao.quantization.quantize_fx.convert_fx( + mp, convert_custom_config=convert_custom_config_dict) + +Best Practices +-------------- + +1. If you are using the ``x86`` backend, we need to use 7 bits instead of 8 bits. Make sure you reduce the range for the ``quant\_min``, ``quant\_max``, e.g. +if ``dtype`` is ``torch.quint8``, make sure to set a custom ``quant_min`` to be ``0`` and ``quant_max`` to be ``127`` (``255`` / ``2``) +if ``dtype`` is ``torch.qint8``, make sure to set a custom ``quant_min`` to be ``-64`` (``-128`` / ``2``) and ``quant_max`` to be ``63`` (``127`` / ``2``), we already set this correctly if +you call the `torch.ao.quantization.get_default_qconfig(backend)` or `torch.ao.quantization.get_default_qat_qconfig(backend)` function to get the default ``qconfig`` for +``x86`` or ``qnnpack`` backend + +2. If ``onednn`` backend is selected, 8 bits for activation will be used in the default qconfig mapping ``torch.ao.quantization.get_default_qconfig_mapping('onednn')`` +and default qconfig ``torch.ao.quantization.get_default_qconfig('onednn')``. It is recommended to be used on CPUs with Vector Neural Network Instruction (VNNI) +support. Otherwise, setting ``reduce_range`` to True of the activation's observer to get better accuracy on CPUs without VNNI support. + +Frequently Asked Questions +-------------------------- + +1. How can I do quantized inference on GPU?: + + We don't have official GPU support yet, but this is an area of active development, you can find more information + `here `_ + +2. Where can I get ONNX support for my quantized model? + + If you get errors exporting the model (using APIs under ``torch.onnx``), you may open an issue in the PyTorch repository. Prefix the issue title with ``[ONNX]`` and tag the issue as ``module: onnx``. + + If you encounter issues with ONNX Runtime, open an issue at `GitHub - microsoft/onnxruntime `_. + +3. How can I use quantization with LSTM's?: + + LSTM is supported through our custom module api in both eager mode and fx graph mode quantization. Examples can be found at + Eager Mode: `pytorch/test_quantized_op.py TestQuantizedOps.test_custom_module_lstm `_ + FX Graph Mode: `pytorch/test_quantize_fx.py TestQuantizeFx.test_static_lstm `_ + +Common Errors +--------------------------------------- + +Passing a non-quantized Tensor into a quantized kernel +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +If you see an error similar to:: + + RuntimeError: Could not run 'quantized::some_operator' with arguments from the 'CPU' backend... + +This means that you are trying to pass a non-quantized Tensor to a quantized +kernel. A common workaround is to use ``torch.ao.quantization.QuantStub`` to +quantize the tensor. This needs to be done manually in Eager mode quantization. +An e2e example:: + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.quant = torch.ao.quantization.QuantStub() + self.conv = torch.nn.Conv2d(1, 1, 1) + + def forward(self, x): + # during the convert step, this will be replaced with a + # `quantize_per_tensor` call + x = self.quant(x) + x = self.conv(x) + return x + +Passing a quantized Tensor into a non-quantized kernel +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +If you see an error similar to:: + + RuntimeError: Could not run 'aten::thnn_conv2d_forward' with arguments from the 'QuantizedCPU' backend. + +This means that you are trying to pass a quantized Tensor to a non-quantized +kernel. A common workaround is to use ``torch.ao.quantization.DeQuantStub`` to +dequantize the tensor. This needs to be done manually in Eager mode quantization. +An e2e example:: + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.quant = torch.ao.quantization.QuantStub() + self.conv1 = torch.nn.Conv2d(1, 1, 1) + # this module will not be quantized (see `qconfig = None` logic below) + self.conv2 = torch.nn.Conv2d(1, 1, 1) + self.dequant = torch.ao.quantization.DeQuantStub() + + def forward(self, x): + # during the convert step, this will be replaced with a + # `quantize_per_tensor` call + x = self.quant(x) + x = self.conv1(x) + # during the convert step, this will be replaced with a + # `dequantize` call + x = self.dequant(x) + x = self.conv2(x) + return x + + m = M() + m.qconfig = some_qconfig + # turn off quantization for conv2 + m.conv2.qconfig = None + +Saving and Loading Quantized models +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +When calling ``torch.load`` on a quantized model, if you see an error like:: + + AttributeError: 'LinearPackedParams' object has no attribute '_modules' + +This is because directly saving and loading a quantized model using ``torch.save`` and ``torch.load`` +is not supported. To save/load quantized models, the following ways can be used: + +1. Saving/Loading the quantized model state_dict + +An example:: + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(5, 5) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.linear(x) + x = self.relu(x) + return x + + m = M().eval() + prepare_orig = prepare_fx(m, {'' : default_qconfig}) + prepare_orig(torch.rand(5, 5)) + quantized_orig = convert_fx(prepare_orig) + + # Save/load using state_dict + b = io.BytesIO() + torch.save(quantized_orig.state_dict(), b) + + m2 = M().eval() + prepared = prepare_fx(m2, {'' : default_qconfig}) + quantized = convert_fx(prepared) + b.seek(0) + quantized.load_state_dict(torch.load(b)) + +2. Saving/Loading scripted quantized models using ``torch.jit.save`` and ``torch.jit.load`` + +An example:: + + # Note: using the same model M from previous example + m = M().eval() + prepare_orig = prepare_fx(m, {'' : default_qconfig}) + prepare_orig(torch.rand(5, 5)) + quantized_orig = convert_fx(prepare_orig) + + # save/load using scripted model + scripted = torch.jit.script(quantized_orig) + b = io.BytesIO() + torch.jit.save(scripted, b) + b.seek(0) + scripted_quantized = torch.jit.load(b) + +Symbolic Trace Error when using FX Graph Mode Quantization +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Symbolic traceability is a requirement for `(Prototype - maintenance mode) FX Graph Mode Quantization`_, so if you pass a PyTorch Model that is not symbolically traceable to `torch.ao.quantization.prepare_fx` or `torch.ao.quantization.prepare_qat_fx`, we might see an error like the following:: + + torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow + +Please take a look at `Limitations of Symbolic Tracing `_ and use - `User Guide on Using FX Graph Mode Quantization `_ to workaround the problem. + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .. torch.ao is missing documentation. Since part of it is mentioned here, adding them here for now. .. They are here for tracking purposes until they are more permanently fixed. @@ -97,8 +1320,13 @@ and supported quantized modules and functions. .. py:module:: torch.ao.ns.fx.ns_types .. py:module:: torch.ao.ns.fx.pattern_utils .. py:module:: torch.ao.ns.fx.qconfig_multi_mapping +<<<<<<< HEAD .. py:module:: torch.ao.ns.fx.weight_utils .. py:module:: torch.ao.ns.fx.utils +======= +.. py:module:: torch.ao.ns.fx.utils +.. py:module:: torch.ao.ns.fx.weight_utils +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .. py:module:: torch.ao.pruning.scheduler.base_scheduler .. py:module:: torch.ao.pruning.scheduler.cubic_scheduler .. py:module:: torch.ao.pruning.scheduler.lambda_scheduler @@ -110,6 +1338,10 @@ and supported quantized modules and functions. .. py:module:: torch.ao.quantization.backend_config.executorch .. py:module:: torch.ao.quantization.backend_config.fbgemm .. py:module:: torch.ao.quantization.backend_config.native +<<<<<<< HEAD +======= +.. py:module:: torch.ao.quantization.backend_config.observation_type +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .. py:module:: torch.ao.quantization.backend_config.onednn .. py:module:: torch.ao.quantization.backend_config.qnnpack .. py:module:: torch.ao.quantization.backend_config.tensorrt @@ -210,9 +1442,12 @@ and supported quantized modules and functions. .. py:module:: torch.quantization.quantize_jit .. py:module:: torch.quantization.stubs .. py:module:: torch.quantization.utils +<<<<<<< HEAD .. currentmodule:: torch.ao.ns.fx.utils .. autofunction:: torch.ao.ns.fx.utils.compute_sqnr(x, y) .. autofunction:: torch.ao.ns.fx.utils.compute_normalized_l2_error(x, y) .. autofunction:: torch.ao.ns.fx.utils.compute_cosine_similarity(x, y) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/docs/source/rpc.md b/docs/source/rpc.md index 38e9354f70d9f..24fa7cb157c59 100644 --- a/docs/source/rpc.md +++ b/docs/source/rpc.md @@ -8,6 +8,7 @@ higher-level API to automatically differentiate models split across several machines. ```{warning} +<<<<<<< HEAD APIs in the RPC package are stable and in maintenance mode. ``` @@ -16,6 +17,18 @@ CUDA support is a **beta** feature. Not all features of the RPC package are yet compatible with CUDA support and thus their use is discouraged. These unsupported features include: RRefs, JIT compatibility, dist autograd and dist optimizer, and profiling. +======= +APIs in the RPC package are stable. There are multiple ongoing work items +to improve performance and error handling, which will ship in future releases. +``` + +```{warning} +CUDA support was introduced in PyTorch 1.9 and is still a **beta** feature. +Not all features of the RPC package are yet compatible with CUDA support and +thus their use is discouraged. These unsupported features include: RRefs, +JIT compatibility, dist autograd and dist optimizer, and profiling. These +shortcomings will be addressed in future releases. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ``` ```{note} @@ -100,6 +113,16 @@ device lists on source and destination workers do not match. In such cases, applications can always explicitly move the input tensors to CPU on the caller and move it to the desired devices on the callee if necessary. +<<<<<<< HEAD +======= +```{warning} + TorchScript support in RPC is a prototype feature and subject to change. Since + v1.5.0, ``torch.distributed.rpc`` supports calling TorchScript functions as + RPC target functions, and this will help improve parallelism on the callee + side as executing TorchScript functions does not require GIL. +``` + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ```{eval-rst} .. autofunction:: rpc_sync .. autofunction:: rpc_async @@ -150,7 +173,13 @@ multiple different transports (TCP, of course, but also shared memory, NVLink, InfiniBand, ...) and can automatically detect their availability and negotiate the best transport to use for each pipe. +<<<<<<< HEAD The TensorPipe backend comes with a TCP-based transport, just like Gloo. It is also able to +======= +The TensorPipe backend has been introduced in PyTorch v1.6 and is being actively +developed. At the moment, it only supports CPU tensors, with GPU support coming +soon. It comes with a TCP-based transport, just like Gloo. It is also able to +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) automatically chunk and multiplex large tensors over multiple sockets and threads in order to achieve very high bandwidths. The agent will be able to pick the best transport on its own, with no intervention required. @@ -290,4 +319,10 @@ to use [the profiler](https://pytorch.org/docs/stable/autograd.html#profiler) to - [Getting started with Distributed RPC Framework](https://pytorch.org/tutorials/intermediate/rpc_tutorial.html) - [Implementing a Parameter Server using Distributed RPC Framework](https://pytorch.org/tutorials/intermediate/rpc_param_server_tutorial.html) - [Combining Distributed DataParallel with Distributed RPC Framework](https://pytorch.org/tutorials/advanced/rpc_ddp_tutorial.html) (covers **RemoteModule** as well) +<<<<<<< HEAD +- [Implementing batch RPC processing](https://pytorch.org/tutorials/intermediate/rpc_async_execution.html) +======= +- [Profiling RPC-based Workloads](https://pytorch.org/tutorials/recipes/distributed_rpc_profiling.html) - [Implementing batch RPC processing](https://pytorch.org/tutorials/intermediate/rpc_async_execution.html) +- [Distributed Pipeline Parallel](https://pytorch.org/tutorials/intermediate/dist_pipeline_parallel_tutorial.html) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/docs/source/scripts/build_quantization_configs.py b/docs/source/scripts/build_quantization_configs.py new file mode 100644 index 0000000000000..5d1f445ade9a1 --- /dev/null +++ b/docs/source/scripts/build_quantization_configs.py @@ -0,0 +1,64 @@ +""" +This script will generate default values of quantization configs. +These are for use in the documentation. +""" + +import os.path + +import torch +from torch.ao.quantization.backend_config import get_native_backend_config_dict +from torch.ao.quantization.backend_config.utils import ( + entry_to_pretty_str, + remove_boolean_dispatch_from_name, +) + + +# Create a directory for the images, if it doesn't exist +QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH = os.path.join( + os.path.realpath(os.path.dirname(__file__)), "quantization_backend_configs" +) + +if not os.path.exists(QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH): + os.mkdir(QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH) + +output_path = os.path.join( + QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH, "default_backend_config.txt" +) + +with open(output_path, "w") as f: + native_backend_config_dict = get_native_backend_config_dict() + + configs = native_backend_config_dict["configs"] + + def _sort_key_func(entry): + pattern = entry["pattern"] + while isinstance(pattern, tuple): + pattern = pattern[-1] + + pattern = remove_boolean_dispatch_from_name(pattern) + if not isinstance(pattern, str): + # methods are already strings + pattern = torch.typename(pattern) + + # we want + # + # torch.nn.modules.pooling.AdaptiveAvgPool1d + # + # and + # + # torch._VariableFunctionsClass.adaptive_avg_pool1d + # + # to be next to each other, so convert to all lower case + # and remove the underscores, and compare the last part + # of the string + pattern_str_normalized = pattern.lower().replace("_", "") + key = pattern_str_normalized.split(".")[-1] + return key + + configs.sort(key=_sort_key_func) + + entries = [] + for entry in configs: + entries.append(entry_to_pretty_str(entry)) + entries = ",\n".join(entries) + f.write(entries) diff --git a/docs/source/scripts/onnx/build_onnx_torchscript_supported_aten_op_csv_table.py b/docs/source/scripts/onnx/build_onnx_torchscript_supported_aten_op_csv_table.py new file mode 100644 index 0000000000000..6e512d59507c3 --- /dev/null +++ b/docs/source/scripts/onnx/build_onnx_torchscript_supported_aten_op_csv_table.py @@ -0,0 +1,59 @@ +""" +This script generates a CSV table with all ATen operators +supported by `torch.onnx.export`. The generated table is included by +docs/source/onnx_supported_aten_list.rst. +""" + +import os + +from torch.onnx import _onnx_supported_ops + + +# Constants +BUILD_DIR = "build/onnx" +SUPPORTED_OPS_CSV_FILE = "auto_gen_supported_op_list.csv" +UNSUPPORTED_OPS_CSV_FILE = "auto_gen_unsupported_op_list.csv" + + +def _sort_key(namespaced_opname): + return tuple(reversed(namespaced_opname.split("::"))) + + +def _get_op_lists(): + all_schemas = _onnx_supported_ops.all_forward_schemas() + symbolic_schemas = _onnx_supported_ops.all_symbolics_schemas() + supported_result = set() + not_supported_result = set() + for opname in all_schemas: + opname = opname.removesuffix("_") + if opname in symbolic_schemas: + # Supported op + opsets = symbolic_schemas[opname].opsets + supported_result.add((opname, f"Since opset {opsets[0]}")) + else: + # Unsupported op + not_supported_result.add((opname, "Not yet supported")) + return ( + sorted(supported_result, key=lambda x: _sort_key(x[0])), + sorted(not_supported_result), + ) + + +def main(): + os.makedirs(BUILD_DIR, exist_ok=True) + + supported, unsupported = _get_op_lists() + + with open(os.path.join(BUILD_DIR, SUPPORTED_OPS_CSV_FILE), "w") as f: + f.write("Operator,opset_version(s)\n") + for name, opset_version in supported: + f.write(f'"``{name}``","{opset_version}"\n') + + with open(os.path.join(BUILD_DIR, UNSUPPORTED_OPS_CSV_FILE), "w") as f: + f.write("Operator,opset_version(s)\n") + for name, opset_version in unsupported: + f.write(f'"``{name}``","{opset_version}"\n') + + +if __name__ == "__main__": + main() diff --git a/docs/source/tensor_attributes.rst b/docs/source/tensor_attributes.rst index eda8dbce234ce..856175a3c12af 100644 --- a/docs/source/tensor_attributes.rst +++ b/docs/source/tensor_attributes.rst @@ -17,6 +17,7 @@ torch.dtype A :class:`torch.dtype` is an object that represents the data type of a :class:`torch.Tensor`. PyTorch has several different data types: +<<<<<<< HEAD **Floating point dtypes** ========================================= =============================== @@ -76,6 +77,32 @@ dtype description **Note**: legacy constructors such as ``torch.*.FloatTensor``, ``torch.*.DoubleTensor``, ``torch.*.HalfTensor``, ``torch.*.BFloat16Tensor``, ``torch.*.ByteTensor``, ``torch.*.CharTensor``, ``torch.*.ShortTensor``, ``torch.*.IntTensor``, ``torch.*.LongTensor``, ``torch.*.BoolTensor`` only remain for backwards compatibility and should no longer be used. +======= +========================== =========================================== =========================== +Data type dtype Legacy Constructors +========================== =========================================== =========================== +32-bit floating point ``torch.float32`` or ``torch.float`` ``torch.*.FloatTensor`` +64-bit floating point ``torch.float64`` or ``torch.double`` ``torch.*.DoubleTensor`` +32-bit complex ``torch.complex32`` or ``torch.chalf`` +64-bit complex ``torch.complex64`` or ``torch.cfloat`` +128-bit complex ``torch.complex128`` or ``torch.cdouble`` +16-bit floating point [1]_ ``torch.float16`` or ``torch.half`` ``torch.*.HalfTensor`` +16-bit floating point [2]_ ``torch.bfloat16`` ``torch.*.BFloat16Tensor`` +8-bit integer (unsigned) ``torch.uint8`` ``torch.*.ByteTensor`` +8-bit integer (signed) ``torch.int8`` ``torch.*.CharTensor`` +16-bit integer (signed) ``torch.int16`` or ``torch.short`` ``torch.*.ShortTensor`` +32-bit integer (signed) ``torch.int32`` or ``torch.int`` ``torch.*.IntTensor`` +64-bit integer (signed) ``torch.int64`` or ``torch.long`` ``torch.*.LongTensor`` +Boolean ``torch.bool`` ``torch.*.BoolTensor`` +========================== =========================================== =========================== + +.. [1] Sometimes referred to as binary16: uses 1 sign, 5 exponent, and 10 + significand bits. Useful when precision is important. + +.. [2] Sometimes referred to as Brain Floating Point: use 1 sign, 8 exponent and 7 + significand bits. Useful when range is important, since it has the same + number of exponent bits as ``float32`` +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) To find out if a :class:`torch.dtype` is a floating point data type, the property :attr:`is_floating_point` can be used, which returns ``True`` if the data type is a floating point data type. @@ -99,8 +126,13 @@ by finding the minimum dtype that satisfies the following rules: A floating point scalar operand has dtype `torch.get_default_dtype()` and an integral non-boolean scalar operand has dtype `torch.int64`. Unlike numpy, we do not inspect +<<<<<<< HEAD values when determining the minimum `dtypes` of an operand. Complex types are not yet supported. Promotion for shell dtypes is not defined. +======= +values when determining the minimum `dtypes` of an operand. Quantized and complex types +are not yet supported. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Promotion Examples:: @@ -184,6 +216,7 @@ the result of :func:`torch.cuda.current_device()`. A :class:`torch.Tensor`'s device can be accessed via the :attr:`Tensor.device` property. +<<<<<<< HEAD A :class:`torch.device` can be constructed using: * A device string, which is a string representation of the device type and optionally the device ordinal. @@ -191,6 +224,11 @@ A :class:`torch.device` can be constructed using: * A device ordinal, where the current :ref:`accelerator` type will be used. Via a device string: +======= +A :class:`torch.device` can be constructed via a string or via a string and device ordinal + +Via a string: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) :: >>> torch.device('cuda:0') @@ -202,10 +240,17 @@ Via a device string: >>> torch.device('mps') device(type='mps') +<<<<<<< HEAD >>> torch.device('cuda') # implicit index is the "current device index" device(type='cuda') Via a device type and a device ordinal: +======= + >>> torch.device('cuda') # current cuda device + device(type='cuda') + +Via a string and device ordinal: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) :: @@ -218,6 +263,7 @@ Via a device type and a device ordinal: >>> torch.device('cpu', 0) device(type='cpu', index=0) +<<<<<<< HEAD Via a device ordinal: .. note:: @@ -236,6 +282,8 @@ Via a device ordinal: File "", line 1, in RuntimeError: Cannot access accelerator device when none is available. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) The device object can also be used as a context manager to change the default device tensors are allocated on: @@ -269,12 +317,30 @@ non-None device argument. To globally change the default device, see also >>> torch.randn((2,3), device='cuda:1') .. note:: +<<<<<<< HEAD Methods which take a device will generally accept a (properly formatted) string or an integer device ordinal, i.e. the following are all equivalent: >>> torch.randn((2,3), device=torch.device('cuda:1')) >>> torch.randn((2,3), device='cuda:1') >>> torch.randn((2,3), device=1) # equivalent to 'cuda:1' if the current accelerator is cuda +======= + For legacy reasons, a device can be constructed via a single device ordinal, which is treated + as the current :ref:`accelerator` type. + This matches :meth:`Tensor.get_device`, which returns an ordinal for device + tensors and is not supported for cpu tensors. + + >>> torch.device(1) + device(type='cuda', index=1) + +.. note:: + Methods which take a device will generally accept a (properly formatted) string + or (legacy) integer device ordinal, i.e. the following are all equivalent: + + >>> torch.randn((2,3), device=torch.device('cuda:1')) + >>> torch.randn((2,3), device='cuda:1') + >>> torch.randn((2,3), device=1) # legacy +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .. note:: Tensors are never moved automatically between devices and require an explicit call from the user. Scalar Tensors (with tensor.dim()==0) are the only exception to this rule and they are automatically transferred from CPU to GPU when needed as this operation can be done "for free". diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst index c2336dfd81ec0..07e51e7b1bb25 100644 --- a/docs/source/tensors.rst +++ b/docs/source/tensors.rst @@ -6,7 +6,88 @@ torch.Tensor =================================== A :class:`torch.Tensor` is a multi-dimensional matrix containing elements of +<<<<<<< HEAD a single data type. Please see :ref:`dtype-doc` for more details about dtype support. +======= +a single data type. + + +Data types +---------- + +Torch defines tensor types with the following data types: + +======================================= =========================================== +Data type dtype +======================================= =========================================== +32-bit floating point ``torch.float32`` or ``torch.float`` +64-bit floating point ``torch.float64`` or ``torch.double`` +16-bit floating point [1]_ ``torch.float16`` or ``torch.half`` +16-bit floating point [2]_ ``torch.bfloat16`` +32-bit complex ``torch.complex32`` or ``torch.chalf`` +64-bit complex ``torch.complex64`` or ``torch.cfloat`` +128-bit complex ``torch.complex128`` or ``torch.cdouble`` +8-bit integer (unsigned) ``torch.uint8`` +16-bit integer (unsigned) ``torch.uint16`` (limited support) [4]_ +32-bit integer (unsigned) ``torch.uint32`` (limited support) [4]_ +64-bit integer (unsigned) ``torch.uint64`` (limited support) [4]_ +8-bit integer (signed) ``torch.int8`` +16-bit integer (signed) ``torch.int16`` or ``torch.short`` +32-bit integer (signed) ``torch.int32`` or ``torch.int`` +64-bit integer (signed) ``torch.int64`` or ``torch.long`` +Boolean ``torch.bool`` +quantized 8-bit integer (unsigned) ``torch.quint8`` +quantized 8-bit integer (signed) ``torch.qint8`` +quantized 32-bit integer (signed) ``torch.qint32`` +quantized 4-bit integer (unsigned) [3]_ ``torch.quint4x2`` +8-bit floating point, e4m3 [5]_ ``torch.float8_e4m3fn`` (limited support) +8-bit floating point, e5m2 [5]_ ``torch.float8_e5m2`` (limited support) +======================================= =========================================== + +.. [1] + Sometimes referred to as binary16: uses 1 sign, 5 exponent, and 10 + significand bits. Useful when precision is important at the expense of range. +.. [2] + Sometimes referred to as Brain Floating Point: uses 1 sign, 8 exponent, and 7 + significand bits. Useful when range is important, since it has the same + number of exponent bits as ``float32`` +.. [3] + quantized 4-bit integer is stored as a 8-bit signed integer. Currently it's only supported in EmbeddingBag operator. +.. [4] + Unsigned types asides from ``uint8`` are currently planned to only have + limited support in eager mode (they primarily exist to assist usage with + torch.compile); if you need eager support and the extra range is not needed, + we recommend using their signed variants instead. See + https://github.com/pytorch/pytorch/issues/58734 for more details. +.. [5] + ``torch.float8_e4m3fn`` and ``torch.float8_e5m2`` implement the spec for 8-bit + floating point types from https://arxiv.org/abs/2209.05433. The op support + is very limited. + + +For backwards compatibility, we support the following alternate class names +for these data types: + +======================================= ============================= ================================ +Data type CPU tensor GPU tensor +======================================= ============================= ================================ +32-bit floating point :class:`torch.FloatTensor` :class:`torch.cuda.FloatTensor` +64-bit floating point :class:`torch.DoubleTensor` :class:`torch.cuda.DoubleTensor` +16-bit floating point :class:`torch.HalfTensor` :class:`torch.cuda.HalfTensor` +16-bit floating point :class:`torch.BFloat16Tensor` :class:`torch.cuda.BFloat16Tensor` +8-bit integer (unsigned) :class:`torch.ByteTensor` :class:`torch.cuda.ByteTensor` +8-bit integer (signed) :class:`torch.CharTensor` :class:`torch.cuda.CharTensor` +16-bit integer (signed) :class:`torch.ShortTensor` :class:`torch.cuda.ShortTensor` +32-bit integer (signed) :class:`torch.IntTensor` :class:`torch.cuda.IntTensor` +64-bit integer (signed) :class:`torch.LongTensor` :class:`torch.cuda.LongTensor` +Boolean :class:`torch.BoolTensor` :class:`torch.cuda.BoolTensor` +======================================= ============================= ================================ + +However, to construct tensors, we recommend using factory functions such as +:func:`torch.empty` with the ``dtype`` argument instead. The +:class:`torch.Tensor` constructor is an alias for the default tensor type +(:class:`torch.FloatTensor`). +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Initializing and basic operations --------------------------------- diff --git a/docs/source/torch.ao.ns._numeric_suite.md b/docs/source/torch.ao.ns._numeric_suite.md new file mode 100644 index 0000000000000..b1466470fe26c --- /dev/null +++ b/docs/source/torch.ao.ns._numeric_suite.md @@ -0,0 +1,16 @@ +(torch_ao_ns_numeric_suite)= + +# torch.ao.ns._numeric_suite + +```{warning} +This module is an early prototype and is subject to change. +``` + +```{eval-rst} +.. currentmodule:: torch.ao.ns._numeric_suite +``` +```{eval-rst} +.. automodule:: torch.ao.ns._numeric_suite + :members: + :member-order: bysource +``` diff --git a/docs/source/torch.ao.ns._numeric_suite_fx.md b/docs/source/torch.ao.ns._numeric_suite_fx.md new file mode 100644 index 0000000000000..46a46d598f4f5 --- /dev/null +++ b/docs/source/torch.ao.ns._numeric_suite_fx.md @@ -0,0 +1,39 @@ +(torch_ao_ns_numeric_suite_fx)= + +# torch.ao.ns._numeric_suite_fx + + +```{warning} + This module is an early prototype and is subject to change. +``` + +```{eval-rst} +.. automodule:: torch.ao.ns._numeric_suite_fx + :members: + :member-order: bysource + +``` +--- + +# torch.ao.ns.fx.utils + + +```{warning} + This module is an early prototype and is subject to change. +``` + +```{eval-rst} +.. currentmodule:: torch.ao.ns.fx.utils +``` + +```{eval-rst} +.. function:: compute_sqnr(x, y) +``` + +```{eval-rst} +.. function:: compute_normalized_l2_error(x, y) +``` + +```{eval-rst} +.. function:: compute_cosine_similarity(x, y) +``` \ No newline at end of file diff --git a/docs/source/torch.compiler.md b/docs/source/torch.compiler.md index fe37c3b42aea5..e9a1a5cae2ebd 100644 --- a/docs/source/torch.compiler.md +++ b/docs/source/torch.compiler.md @@ -26,9 +26,12 @@ written in Python and it marks the transition of PyTorch from C++ to Python. which results in capturing the backwards pass "ahead-of-time". This enables acceleration of both forwards and backwards pass using TorchInductor. +<<<<<<< HEAD To better understand how `torch.compile` tracing behavior on your code, or to learn more about the internals of `torch.compile`, please refer to the [`torch.compile` programming model](compile/programming_model.md). +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) :::{note} In some cases, the terms `torch.compile`, TorchDynamo, `torch.compiler` might be used interchangeably in this documentation. @@ -39,7 +42,11 @@ TorchDynamo requires a backend that converts the captured graphs into a fast machine code. Different backends can result in various optimization gains. The default backend is called TorchInductor, also known as *inductor*, TorchDynamo has a list of supported backends developed by our partners, +<<<<<<< HEAD which can be seen by running `torch.compiler.list_backends()` each of which +======= +which can be see by running `torch.compiler.list_backends()` each of which +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with its optional dependencies. Some of the most commonly used backends include: @@ -59,6 +66,11 @@ Some of the most commonly used backends include: - CUDA graphs with AOT Autograd. `Read more `__ * - ``torch.compile(m, backend="ipex")`` - Uses IPEX on CPU. `Read more `__ +<<<<<<< HEAD +======= + * - ``torch.compile(m, backend="onnxrt")`` + - Uses ONNX Runtime for training on CPU/GPU. :doc:`Read more ` +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ``` **Inference-only backends** @@ -91,7 +103,10 @@ Some of the most commonly used backends include: torch.compiler_api torch.compiler.config torch.compiler_fine_grain_apis +<<<<<<< HEAD torch.compiler_backward +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.compiler_aot_inductor torch.compiler_inductor_profiling torch.compiler_profiling_torch_compile @@ -101,6 +116,7 @@ Some of the most commonly used backends include: torch.compiler_inductor_provenance ``` +<<<<<<< HEAD ```{eval-rst} .. toctree:: :caption: `torch.compile` Programming Model @@ -108,6 +124,8 @@ Some of the most commonly used backends include: compile/programming_model ``` +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) % _If you want to contribute a developer-level topic % that provides in-depth overview of a torch._dynamo feature, % add in the below toc. diff --git a/docs/source/torch.compiler_aot_inductor.md b/docs/source/torch.compiler_aot_inductor.md index 0584cac0aa917..1abc7a8a6d50a 100644 --- a/docs/source/torch.compiler_aot_inductor.md +++ b/docs/source/torch.compiler_aot_inductor.md @@ -1,5 +1,8 @@ +<<<<<<< HEAD (torch.compiler_aot_inductor)= +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # AOTInductor: Ahead-Of-Time Compilation for Torch.Export-ed Models ```{warning} @@ -27,7 +30,11 @@ relies on. We will then use {func}`torch._inductor.aoti_compile_and_package` to compile the exported program using TorchInductor, and save the compiled artifacts into one +<<<<<<< HEAD package. The package is in the format of a {ref}`PT2 Archive Spec `. +======= +package. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ```{note} If you have a CUDA-enabled device on your machine and you installed PyTorch with CUDA support, @@ -202,7 +209,10 @@ Below are some useful tools for debugging AOT Inductor. logging torch.compiler_aot_inductor_minifier +<<<<<<< HEAD torch.compiler_aot_inductor_debugging_guide +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ``` To enable runtime checks on inputs, set the environment variable `AOTI_RUNTIME_CHECK_INPUTS` to 1. This will raise a `RuntimeError` if the inputs to the compiled model differ in size, data type, or strides from those used during export. diff --git a/docs/source/torch.compiler_cudagraph_trees.md b/docs/source/torch.compiler_cudagraph_trees.md index eb137625ea746..713a91304ed7e 100644 --- a/docs/source/torch.compiler_cudagraph_trees.md +++ b/docs/source/torch.compiler_cudagraph_trees.md @@ -219,7 +219,10 @@ may skip CUDAGraph when necessary. Here, we list common reasons for skipping CUD [dynamic shapes](https://pytorch.org/docs/stable/torch.compiler_dynamic_shapes.html). CUDAGraph Trees currently record a CUDAGraph for every unique input tensor shapes. Please see *Dynamic Shape Support* for more details. +<<<<<<< HEAD - **CUDAGraph-unsafe custom ops**: Some custom ops may include cudagraph unsafe ops, which causes cudagraph to be skipped. Please see *CUDAGraph Unsafe Custom Ops* for more details. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - **Incompatible operators**: CUDAGraph Trees skip a function if it contain incompatible operators. Please replace these operators in a function with supported operators. We show an exhaustive list of incompatible operators: @@ -250,6 +253,7 @@ aten._local_scalar_dense aten._assert_scalar ``` +<<<<<<< HEAD ### CUDAGraph Unsafe Custom Ops Custom ops are assumed to be safe for CUDAGraph by default. However, some custom ops may include unsupported ops such as cpu ops. Since custom op are treated as black boxes by the compiler, users must explicitly mark these ops as unsafe for CUDAGraph by setting the `torch._C.Tag.cudagraph_unsafe` tag, as demonstrated in the example below. When a function contains cudagraph-unsafe custom ops, it will be skipped by CUDAGraph unless *CUDAGraph partition* is enabled. @@ -293,6 +297,8 @@ Currently, CUDAGraph partition supports splitting off the following types of ops - **Unbacked Symints**: Please refer to *Dynamic Shape Support* section for more information. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ### Limitations Because CUDA Graph fixes memory addresses, CUDA Graphs do not have a great way of handling live tensors from a previous invocation. @@ -328,4 +334,8 @@ tensors of a prior iteration (outside of torch.compile) before you begin the nex |---------------|------------------------------------------------------------|------------------------------------------------------------------------| | Memory Can Increase | On each graph compilation (new sizes, etc.) | If you are also running non-cudagraph memory | | Recordings | On any new invocation of a graph | Will re-record on any new, unique path you take through your program | +<<<<<<< HEAD | Footguns | Invocation of one graph will overwrite prior invocation | Cannot persist memory between separate runs through your model - one training loop training, or one run of inference | +======= +| Footguns | Invocation of one graph will overwrite prior invocation | Cannot persist memory between separate runs through your model - one training loop training, or one run of inference | +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/docs/source/torch.compiler_dynamo_deepdive.md b/docs/source/torch.compiler_dynamo_deepdive.md index 9fa7654023ca5..835bca609c4f1 100644 --- a/docs/source/torch.compiler_dynamo_deepdive.md +++ b/docs/source/torch.compiler_dynamo_deepdive.md @@ -285,7 +285,11 @@ appear in the errors, and the `VariableTracker` method that throws the exception when you encounter a Dynamo error. In particular, sometimes we find that an object is tracked as a `UserDefinedObjectVariable` (this is Dynamo’s catch-all class), when it should have been tracked as +<<<<<<< HEAD something more specific. In these cases, the `VariableBuilder` +======= +something more specific. In these cases, the `SourceBuilder.__call__` +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) logic is often to blame. **Debugging tip**. When running a program with `TORCH_LOGS=dynamo`, diff --git a/docs/source/torch.compiler_ir.md b/docs/source/torch.compiler_ir.md index ff66b8cc7efce..0201d2befeb86 100644 --- a/docs/source/torch.compiler_ir.md +++ b/docs/source/torch.compiler_ir.md @@ -1,5 +1,8 @@ +<<<<<<< HEAD (torch.compiler_ir)= +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # IRs PyTorch 2.0 offers two set of IRs for backends to interface with: Core Aten IR and Prims IR. diff --git a/docs/source/torch.compiler_profiling_torch_compile.md b/docs/source/torch.compiler_profiling_torch_compile.md index 9c1a215920abf..1385607ad02da 100644 --- a/docs/source/torch.compiler_profiling_torch_compile.md +++ b/docs/source/torch.compiler_profiling_torch_compile.md @@ -134,7 +134,11 @@ Note a few things: Although there are logging tools for identifying graph breaks, the profiler provides a quick visual method of identifying :ref:`graph breaks `. There are two profiler events to look for: **Torch-Compiled Region** and **CompiledFunction**. +<<<<<<< HEAD **Torch-Compiled Region** - which was introduced in PyTorch 2.2 - is a profiler event that covers the entire compiled region. Graph breaks almost always look the same: nested ā€œTorch-Compiled Regionā€ events. Starting in PyTorch 2.5, the profiler event will also contain the frame ID and the frame compile ID. The frame ID is a unique identifier for the frame, and the frame compile ID denotes how many times the frame has been compiled. +======= +**Torch-Compiled Region** - which was introduced in PyTorch 2.2 - is a profiler event that covers the entire compiled region. Graph breaks almost always look the same: nested ā€œTorch-Compiled Regionā€ events. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) If you run two separate functions with torch.compile() applied independently on each of them, you should generally expect to see two adjacent (i.e NOT stacked/nested) Torch-Compiled regions. Meanwhile, if you encounter graph breaks (or disable()'ed/skipped regions), expect nested ā€œTorch-Compiled Regionā€ events. @@ -249,4 +253,8 @@ One common issue is bad GPU utilization. A quick way to identify this is if ther This is often the result of CPU overhead, e.g. if the amount of time spent on the CPU between kernel launches is larger than the amount of time spent by the GPU to process the kernels. The issue is more common for small batch sizes. +<<<<<<< HEAD When using inductor, enabling CUDA graphs can often help improve performance when launch overhead is a concern. +======= +When using inductor, enabling CUDA graphs can often help improve performance when launch overhead is a concern. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/docs/source/torch.compiler_troubleshooting.md b/docs/source/torch.compiler_troubleshooting.md index a4f7af3b9b8e9..8a00da1e2596a 100644 --- a/docs/source/torch.compiler_troubleshooting.md +++ b/docs/source/torch.compiler_troubleshooting.md @@ -192,8 +192,11 @@ For more information on dynamic shapes, see [The dynamic shapes manual](https:// ## Logging Tools +<<<<<<< HEAD (tlparse-torch-trace)= +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ### tlparse / TORCH_TRACE `tlparse` / `TORCH_TRACE` are a pair of tools that produce compilation reports that look like this: @@ -254,8 +257,11 @@ Here are some insights you can gain from a `tlparse`: For example, you can look at the high-level generated FX graph or the generated Triton code. - Is there relevant information for a particular frame? You can find these in `compilation_metrics`. +<<<<<<< HEAD (torch-logs)= +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ### TORCH_LOGS You can use the `TORCH_LOGS` environment variable to selectively enable parts of the `torch.compile` stack to log. diff --git a/docs/source/torch.compiler_troubleshooting_old.md b/docs/source/torch.compiler_troubleshooting_old.md index ef13fc1772374..980f21095cf30 100644 --- a/docs/source/torch.compiler_troubleshooting_old.md +++ b/docs/source/torch.compiler_troubleshooting_old.md @@ -717,5 +717,9 @@ backtrace is slow and very spammy so it is not included by default with extended In order to measure the cold start compilation time or debug a cache corruption, it is possible pass `TORCHINDUCTOR_FORCE_DISABLE_CACHES=1` or set +<<<<<<< HEAD `torch.compiler.config.force_disable_caches = True` which will override any +======= +`torch._inductor.config.force_disable_caches = True` which will override any +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) other caching config option and disable all compile time caching. diff --git a/docs/source/torch.rst b/docs/source/torch.rst index 645fdd52135f9..8a699afa1fcdc 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -122,7 +122,10 @@ Indexing, Slicing, Joining, Mutating Ops slice_scatter scatter_add scatter_reduce +<<<<<<< HEAD segment_reduce +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) split squeeze stack @@ -475,7 +478,10 @@ Reduction Ops var var_mean count_nonzero +<<<<<<< HEAD hash_tensor +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Comparison Ops ~~~~~~~~~~~~~~~~~~~~~~ @@ -806,6 +812,10 @@ Operator Tags .. for tracking purposes .. py:module:: torch.utils.model_dump .. py:module:: torch.utils.viz +<<<<<<< HEAD +======= +.. py:module:: torch.functional +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .. py:module:: torch.quasirandom .. py:module:: torch.return_types .. py:module:: torch.serialization @@ -815,6 +825,7 @@ Operator Tags .. py:module:: torch.torch_version .. py:module:: torch.types .. py:module:: torch.version +<<<<<<< HEAD .. Hidden aliases (e.g. torch.functional.broadcast_tensors()). We want `torch.broadcast_tensors()` to be visible only. @@ -822,3 +833,5 @@ Operator Tags :hidden: torch.aliases.md +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/docs/source/torch_cuda_memory.md b/docs/source/torch_cuda_memory.md index f7f1fe706dad3..0f2de798dfb9e 100644 --- a/docs/source/torch_cuda_memory.md +++ b/docs/source/torch_cuda_memory.md @@ -32,7 +32,11 @@ torch.cuda.memory._dump_snapshot("my_snapshot.pickle") ## Using the visualizer +<<<<<<< HEAD Open and drag/drop the pickled snapshot file into the visualizer. +======= +Open [pytorch.org/memory_viz](https://pytorch.org/memory_viz>) and drag/drop the pickled snapshot file into the visualizer. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) The visualizer is a javascript application that runs locally on your computer. It does not upload any snapshot data. diff --git a/docs/source/type_info.md b/docs/source/type_info.md index 9933d551506d9..400a23ffc85e2 100644 --- a/docs/source/type_info.md +++ b/docs/source/type_info.md @@ -20,6 +20,7 @@ This is similar to [numpy.finfo](https://numpy.org/doc/stable/reference/generate A {class}`torch.finfo` provides the following attributes: +<<<<<<< HEAD | Name | Type | Description | | :-------------- | :---- | :------------------------------------------------------------------------------------------ | | bits | int | The number of bits occupied by the type. | @@ -29,6 +30,17 @@ A {class}`torch.finfo` provides the following attributes: | tiny | float | The smallest positive normal number. Equivalent to ``smallest_normal``. | | smallest_normal | float | The smallest positive normal number. See notes. | | resolution | float | The approximate decimal resolution of this type, i.e., ``10**-precision``. | +======= +| Name | Type | Description | +| :-------------- | :---- | :------------------------------------------------------------------------- | +| bits | int | The number of bits occupied by the type. | +| eps | float | The smallest representable number such that ``1.0 + eps != 1.0``. | +| max | float | The largest representable number. | +| min | float | The smallest representable number (typically ``-max``). | +| tiny | float | The smallest positive normal number. Equivalent to ``smallest_normal``. | +| smallest_normal | float | The smallest positive normal number. See notes. | +| resolution | float | The approximate decimal resolution of this type, i.e., ``10**-precision``. | +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ```{note} The constructor of {class}`torch.finfo` can be called without argument, diff --git a/functorch/README.md b/functorch/README.md index 5e16966b1daa9..0eed90d46a9f3 100644 --- a/functorch/README.md +++ b/functorch/README.md @@ -7,7 +7,11 @@ | [**Future Plans**](#future-plans) **This library is currently under heavy development - if you have suggestions +<<<<<<< HEAD on the API or use-cases you'd like to be covered, please open a GitHub issue +======= +on the API or use-cases you'd like to be covered, please open an github issue +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) or reach out. We'd love to hear about how you're using the library.** `functorch` is [JAX-like](https://github.com/google/jax) composable function @@ -161,7 +165,11 @@ result = vmap(model)(examples) ### grad +<<<<<<< HEAD `grad(func)(*inputs)` assumes `func` returns a single-element Tensor. It computes +======= +`grad(func)(*inputs)` assumes `func` returns a single-element Tensor. It compute +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) the gradients of the output of func w.r.t. to `inputs[0]`. ```py @@ -192,7 +200,11 @@ def compute_loss(weights, example, target): weights = torch.randn(feature_size, requires_grad=True) examples = torch.randn(batch_size, feature_size) targets = torch.randn(batch_size) +<<<<<<< HEAD inputs = (weights, examples, targets) +======= +inputs = (weights,examples, targets) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs) ``` diff --git a/functorch/benchmarks/chrome_trace_parser.py b/functorch/benchmarks/chrome_trace_parser.py index cc641c1cf81c9..762551d16408c 100755 --- a/functorch/benchmarks/chrome_trace_parser.py +++ b/functorch/benchmarks/chrome_trace_parser.py @@ -66,7 +66,11 @@ def main(): filenames, total_length ) print(f"{modelname}, {utilization}, {mm_conv_utilization}") +<<<<<<< HEAD except BaseException: # noqa: B036 +======= + except BaseException: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) log.exception("%s, ERROR", filename) print(f"{filename}, ERROR") diff --git a/functorch/csrc/dim/dim.cpp b/functorch/csrc/dim/dim.cpp index 8f1e561e2051b..f2c97433aa269 100644 --- a/functorch/csrc/dim/dim.cpp +++ b/functorch/csrc/dim/dim.cpp @@ -6,6 +6,10 @@ #include +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Many APIs have changed/don't exist anymore #if IS_PYTHON_3_12_PLUS @@ -13,13 +17,19 @@ // Re-enable this some day PyObject* Dim_init() { +<<<<<<< HEAD PyErr_SetString( PyExc_RuntimeError, "First class dim doesn't work with python 3.12"); return nullptr; +======= + PyErr_SetString(PyExc_RuntimeError, "First class dim doesn't work with python 3.12"); + return nullptr; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } #else +<<<<<<< HEAD #include #include #include @@ -32,6 +42,20 @@ PyObject* Dim_init() { #include #include #include +======= +#include "minpybind.h" +#include +#include +#include +#include +#include +#include +//#include +#include +#include +#include +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include "arena.h" #include "dim.h" @@ -71,6 +95,7 @@ PyTypeObject* DimType = nullptr; PyObject* Tensor_getitem(PyObject* self, PyObject* index); int Tensor_setitem(PyObject* self, PyObject* index, PyObject* value); +<<<<<<< HEAD namespace { void maybeInitializeGlobals() { // globals that depend on the python dim library, @@ -119,20 +144,79 @@ void initializeGlobals(Arena& A) { THPVariable_setitem = TensorBase->tp_as_mapping->mp_ass_subscript; NamedTuple = mpy::import("typing").attr("NamedTuple"); no_slice = PySlice_New(NULL, NULL, NULL); +======= +namespace{ +void maybeInitializeGlobals() { + // globals that depend on the python dim library, + // which we can't lookup until we finish initializing the _C module + if (_Tensor.ptr()) { + return; + } + auto dim = mpy::import("functorch.dim"); + _Tensor = dim.attr("_Tensor"); + pointwise = dim.attr("pointwise"); + _Tensor_sum = _Tensor.attr("sum"); + DimType = (PyTypeObject*) mpy::import("functorch.dim").attr("Dim").ptr(); +} + +void replaceMappingIfMatches(mpy::handle tp) { + auto T = (PyTypeObject*) tp.ptr(); + bool recurse = false; + if (T->tp_as_mapping->mp_subscript == THPVariable_getitem) { + T->tp_as_mapping->mp_subscript = Tensor_getitem; + recurse = true; + } + if (T->tp_as_mapping->mp_ass_subscript == THPVariable_setitem) { + T->tp_as_mapping->mp_ass_subscript = Tensor_setitem; + recurse = true; + } + if (recurse) { + auto result = tp.attr("__subclasses__").call(); + mpy::list_view lv(result); + for (auto i : lv.enumerate()) { + replaceMappingIfMatches(lv[i]); + } + } +} + +void initializeGlobals(Arena & A) { + auto torch = mpy::import("torch"); + torch_Tensor = (PyTypeObject*) torch.attr("Tensor").ptr(); + torch_Tensor___mul__ = torch.attr("Tensor").attr("__mul__"); + + torch_Tensor_expand = torch.attr("_C").attr("TensorBase").attr("expand"); + torch_Tensor_split = torch.attr("_C").attr("TensorBase").attr("split"); + torch_Tensor_copy_ = torch.attr("Tensor").attr("copy_"); + auto py_TensorBase = torch.attr("_C").attr("TensorBase"); + auto TensorBase = (PyTypeObject*) py_TensorBase.ptr(); + THPVariable_getitem = TensorBase->tp_as_mapping->mp_subscript; + THPVariable_setitem = TensorBase->tp_as_mapping->mp_ass_subscript; + NamedTuple = mpy::import("typing").attr("NamedTuple"); + no_slice = PySlice_New(NULL, NULL, NULL); + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } mpy::handle DimensionBindError_; mpy::handle DimensionBindError() { +<<<<<<< HEAD if (!DimensionBindError_.ptr()) { DimensionBindError_ = mpy::import("functorch.dim").attr("DimensionBindError"); } return DimensionBindError_; +======= + if(!DimensionBindError_.ptr()) { + DimensionBindError_ = mpy::import("functorch.dim").attr("DimensionBindError"); + } + return DimensionBindError_; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } static int64_t n_dims_created = 65; struct Dim : public mpy::base { +<<<<<<< HEAD int64_t level_; // for stable comparisons in prototype mpy::object name_; Dim() : level_(n_dims_created++) {} @@ -229,10 +313,105 @@ struct DimEntry { private: int64_t data_; +======= + int64_t level_; // for stable comparisons in prototype + mpy::object name_; + Dim() + : level_(n_dims_created++) {} + void init(mpy::object name, int64_t s = -1) { + name_ = std::move(name); + size_ = s; + } + + static bool check_exact(mpy::handle v) { + return Py_TYPE(v.ptr()) == DimType; + } + + int64_t size() const { + if (size_ == -1) { + mpy::raise_error(PyExc_ValueError, "dimension %S is unbound", name_.ptr()); + } + return size_; + } + void set_size(int64_t v) { + if (size_ == -1) { + size_ = v; + } else if(size_ != v) { + mpy::raise_error(DimensionBindError(), "Dim '%R' previously bound to a dimension of size %lld cannot bind to a dimension of size %lld", this, this->size_, v); + } + } + bool is_bound() const { + return size_ != -1; + } + static mpy::obj create(mpy::object name, int64_t s = -1) { + if (!DimType) { + maybeInitializeGlobals(); + } + auto r = Dim::alloc(DimType); + r->init(std::move(name), s); + return r; + } + static PyTypeObject Type; + const at::Tensor& range() { + if (!range_.defined()) { + range_ = at::arange(size()); + } + return range_; + } + const at::Tensor& batchtensor() { + if (!batchtensor_.defined()) { + batchtensor_ = at::functorch::addBatchDim(range(), 0, level_); + } + return batchtensor_; + } +private: + int64_t size_{-1}; + at::Tensor range_; + at::Tensor batchtensor_; +}; + + +struct DimEntry { + // union of either a negative number indicating which dimension this is from the rhs, + // or a pointer to a first-class dimension. + // pointers do not have their highest bit set, so checking the number is negative tells us + // that it is not a dim. + bool is_positional() const { + return data_ < 0; + } + bool is_none() const { + return data_ == 0; + } + int64_t position() const { + return data_; + } + mpy::hdl dim() const { + Dim* result; + std::memcpy(&result, &data_, sizeof(Dim*)); + return mpy::hdl(result); + } + + DimEntry() + : data_(0) {} + + DimEntry(int64_t pos) + : data_(pos) { + AT_ASSERT(pos < 0); + } + DimEntry(mpy::hdl d) { + std::memcpy(&data_, &d, sizeof(int64_t)); + } + bool operator==(const DimEntry& rhs) const { + return data_ == rhs.data_; + } +private: + int64_t data_; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; // Dim wrapper methods DimEntry _wrap_dim(mpy::handle d, size_t N, bool keepdim) { +<<<<<<< HEAD if (Dim::check(d)) { if (keepdim) { mpy::raise_error( @@ -375,11 +554,144 @@ PyTypeObject Dim::Type = { Dim_init), /* tp_init */ 0, /* tp_alloc */ Dim::new_stub, /* tp_new */ +======= + if (Dim::check(d)) { + if (keepdim) { + mpy::raise_error(PyExc_ValueError, "cannot preserve first-class dimensions with keepdim=True"); + } + return Dim::unchecked_wrap(d); + } else if (mpy::is_int(d)) { + auto i = mpy::to_int(d); + while (i >= 0) { + i -= N; + } + return i; + } else { + return DimEntry(); + } +} + + +int Dim_init(mpy::hdl self, PyObject *args, PyObject *kwds) { + PY_BEGIN + static constexpr const char* kwlist[] = {"name", "size", nullptr}; + mpy::handle name; + mpy::handle size = nullptr; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O", const_cast(kwlist), &name, &size)) { + return -1; + } + self->init(mpy::object::borrow(name), (size.ptr() && !mpy::is_none(size)) ? mpy::to_int(size) : -1); + return 0; + PY_END(-1) +} + +PyObject* Dim_repr(Dim* self) { + PY_BEGIN + mpy::object name = (self->name_.ptr()) ? self->name_ : mpy::unicode_from_string(""); + return name.release(); + PY_END(nullptr) +} + + +PyObject* Dim_getsize(Dim* self, void*) { + PY_BEGIN + return mpy::from_int(self->size()).release(); + PY_END(nullptr) +} + +int Dim_setsize(Dim* self, PyObject* size, void*) { + PY_BEGIN + self->set_size(mpy::to_int(size)); + return 0; + PY_END(-1) +} + +PyObject* Dim_getis_bound(Dim* self, void*) { + return PyBool_FromLong(self->is_bound()); +} + +PyObject* Dim_getlevel(Dim* self, void*) { + return PyLong_FromLong(self->level_); +} + +PyObject* Dim_get_levels(Dim* self, void*) { + mpy::tuple t(1); + t.set(0, mpy::object::borrow(self->ptr())); + return t.release(); +} + +PyObject* Dim_get_has_device(Dim* self, void*) { + Py_RETURN_FALSE; +} + +PyObject* Dim_get_tensor(Dim* self, void*) { + return THPVariable_Wrap(self->range()); +} + +PyObject* Dim_get_batchtensor(Dim* self, void*) { + return THPVariable_Wrap(self->batchtensor()); +} + + +PyGetSetDef Dim_getsetters[] = { + {"size", (getter) Dim_getsize, (setter) Dim_setsize, + "Dimension size", NULL}, + {"is_bound", (getter) Dim_getis_bound, NULL, "is_bound", NULL}, + {"_level", (getter) Dim_getlevel, NULL, "_level", NULL}, + {"_levels", (getter) Dim_get_levels, NULL, "_levels", NULL}, + {"_has_device", (getter) Dim_get_has_device, NULL, "_has_device", NULL}, + {"_tensor", (getter) Dim_get_tensor, NULL, "_tensor", NULL}, + {"_batchtensor", (getter) Dim_get_batchtensor, NULL, "_batchtensor", NULL}, + {"ndim", (getter) [](PyObject* self, void*) -> PyObject* { return mpy::from_int(1).release(); }, NULL, "ndim", NULL}, + {NULL} /* Sentinel */ +}; +} +PyTypeObject Dim::Type = { + PyVarObject_HEAD_INIT(NULL, 0) + "_C.Dim", /* tp_name */ + sizeof(Dim), /* tp_basicsize */ + 0, /* tp_itemsize */ + Dim::dealloc_stub, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_as_async */ + (reprfunc)Dim_repr, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ + "Dim Object", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + 0, /* tp_methods */ + 0, /* tp_members */ + Dim_getsetters, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)(void*)static_cast,PyObject*,PyObject*)>(Dim_init), /* tp_init */ + 0, /* tp_alloc */ + Dim::new_stub, /* tp_new */ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; // class DimList ------------ struct DimList : public mpy::base { +<<<<<<< HEAD mpy::object name_; std::vector> dims_; static PyTypeObject Type; @@ -643,11 +955,268 @@ static int DimList_init(DimList* self, PyObject* args, PyObject* kwds) { } return 0; PY_END(-1); +======= + mpy::object name_; + std::vector> dims_; + static PyTypeObject Type; + void init(mpy::object name) { + name_ = std::move(name); + } + void set_dims(std::vector> dims) { + bound_ = true; + dims_ = std::move(dims); + } + bool is_bound() { + return bound_; + } + void bind_len(int64_t size) { + if (bound_) { + int64_t b_size = dims_.size(); + if (b_size != size) { + mpy::raise_error(DimensionBindError(), "Dimlist has size %lld but it is being bound to size %d", b_size, size); + } + } else { + bound_ = true; + dims_.resize(size); + for (Py_ssize_t i = 0; i < size; ++i) { + dims_[i] = Dim::create(mpy::unicode_from_format("%S%i", name_.ptr(), (int)i)); + } + } + } + int64_t size() const { + if (!bound_) { + mpy::raise_error(DimensionBindError(), "DimList not bound"); + } + return dims_.size(); + } + void set_bound(bool b) { + bound_ = b; + } +private: + bool bound_ = false; +}; + + +static int DimList_init(DimList *self, PyObject *args, PyObject *kwds); + +static PyObject* DimList_repr(DimList* self) { + PY_BEGIN + if (self->is_bound()) { + size_t size = self->dims_.size(); + mpy::tuple t(size); + for(size_t i = 0; i < size; ++i) { + t.set(i, self->dims_[i]); + } + return mpy::repr(t).release(); + } else if(!mpy::is_none(self->name_)) { + return mpy::unicode_from_format("*%S", self->name_.ptr()).release(); + } else { + return mpy::unicode_from_string("").release(); + } + PY_END(nullptr) +} + +static PyObject* DimList_bind(DimList *self, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + PY_BEGIN + mpy::handle sizes; + static const char * const _keywords[] = {"sizes", nullptr}; + static _PyArg_Parser parser = {"O", _keywords, 0}; + if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, &sizes)) { + return nullptr; + } + if (!mpy::is_sequence(sizes)) { + mpy::raise_error(PyExc_ValueError, "expected a sequence"); + } + mpy::sequence_view seq = sizes; + auto size = seq.size(); + self->bind_len(size); + for (Py_ssize_t i = 0; i < size; ++i) { + self->dims_[i]->set_size(mpy::to_int(seq[i])); + } + Py_RETURN_NONE; + PY_END(nullptr) +} + +static PyObject* DimList_bind_len(DimList *self, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + PY_BEGIN + int size; + static const char * const _keywords[] = {"N", nullptr}; + static _PyArg_Parser parser = {"i", _keywords, 0}; + if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, &size)) { + return nullptr; + } + self->bind_len(size); + Py_RETURN_NONE; + PY_END(nullptr) +} + +static PyMethodDef DimList_methods[] = { + {"bind", (PyCFunction)(void*) DimList_bind, METH_FASTCALL | METH_KEYWORDS}, + {"bind_len", (PyCFunction)(void*) DimList_bind_len, METH_FASTCALL | METH_KEYWORDS}, + {NULL, NULL, 0, NULL} /* Sentinel */ +}; + + +static Py_ssize_t DimList_len(DimList* self) { + PY_BEGIN + return self->size(); + PY_END(-1) +} + +static PyObject * DimList_item(DimList* self, Py_ssize_t idx) { + PY_BEGIN + if (!self->is_bound()) { + mpy::raise_error(DimensionBindError(), "DimList not bound"); + } + if (idx < 0 || (size_t) idx >= self->dims_.size()) { + mpy::raise_error(PyExc_IndexError, "index out of bounds"); + } + mpy::object r = self->dims_[idx]; + return r.release(); + PY_END(nullptr) +} + +PySequenceMethods DimList_seq { + (lenfunc) DimList_len, //lenfunc sq_length; + 0, //binaryfunc sq_concat; + 0, //ssizeargfunc sq_repeat; + (ssizeargfunc) DimList_item, //ssizeargfunc sq_item; + 0, //void *was_sq_slice; + 0, //ssizeobjargproc sq_ass_item; + 0, //void *was_sq_ass_slice; + 0, //objobjproc sq_contains; + + 0, //binaryfunc sq_inplace_concat; + 0, //ssizeargfunc sq_inplace_repeat; +}; + +static PyObject* DimList_getis_bound(DimList* self, void*) { + return PyBool_FromLong(self->is_bound()); +} + +static PyGetSetDef DimList_getsetters[] = { + {"is_bound", (getter) DimList_getis_bound, NULL, "is_bound", NULL}, + {NULL} /* Sentinel */ +}; + + +static PyObject* DimList_subscript(DimList* self, mpy::handle idx) { + PY_BEGIN + if (mpy::is_int(idx)) { + return DimList_item(self, mpy::to_int(idx)); + } else if (mpy::is_slice(idx)) { + if (!self->is_bound()) { + mpy::raise_error(DimensionBindError(), "DimList not bound"); + } + mpy::slice_view s(idx, self->dims_.size()); + mpy::tuple r(s.slicelength); + for (Py_ssize_t i = s.start, j = 0; i < s.stop; i += s.step) { + r.set(j++, self->dims_[i]); + } + return r.release(); + } else { + mpy::raise_error(PyExc_ValueError, "expected an int or a slice"); + return nullptr; + } + PY_END(nullptr) +} + +PyMappingMethods DimList_mapping = { + 0, //lenfunc mp_length; + (binaryfunc)(void*) DimList_subscript, //binaryfunc mp_subscript; + 0, //objobjargproc mp_ass_subscript; +}; + + + +PyTypeObject DimList::Type = { + PyVarObject_HEAD_INIT(NULL, 0) + "_C.DimList", /* tp_name */ + sizeof(DimList), /* tp_basicsize */ + 0, /* tp_itemsize */ + DimList::dealloc_stub, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_as_async */ + (reprfunc)DimList_repr, /* tp_repr */ + 0, /* tp_as_number */ + &DimList_seq, /* tp_as_sequence */ + &DimList_mapping, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + 0, /* tp_flags */ + "DimList Object", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + DimList_methods, /* tp_methods */ + 0, /* tp_members */ + DimList_getsetters, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc) DimList_init, /* tp_init */ + 0, /* tp_alloc */ + DimList::new_stub, /* tp_new */ +}; + +static int DimList_init(DimList *self, PyObject *args, PyObject *kwds) { + PY_BEGIN + static constexpr const char* kwlist[] = {"len_or_dims", "name", nullptr}; + mpy::handle len_or_dims = nullptr; + PyObject* name = nullptr; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OO", const_cast(kwlist), &len_or_dims, &name)) { + return -1; + } + self->init(mpy::object::borrow(name ? name : Py_None)); + if (len_or_dims.ptr()) { + if(mpy::is_int(len_or_dims)) { + self->bind_len(mpy::to_int(len_or_dims)); + } else if (mpy::is_sequence(len_or_dims)) { + mpy::sequence_view s(len_or_dims); + std::vector> dims; + size_t size = s.size(); + dims.reserve(size); + for (size_t i = 0; i < size; ++i) { + auto r = s[i]; + if (mpy::is_int(r)) { + dims.emplace_back(Dim::create(mpy::unicode_from_format("%S%i", self->name_.ptr(), (int)i), mpy::to_int(r))); + } else { + dims.emplace_back(Dim::wrap(r)); + } + } + self->set_dims(std::move(dims)); + } else { + PyErr_Format(PyExc_ValueError, "expected a length or a sequence of dimensions"); + return -1; + } + return 0; + } + return 0; + PY_END(-1); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // Tensor ----------------------------- PyTypeObject* TensorType = nullptr; // the python wrapper type. +<<<<<<< HEAD mpy::object run_torch_function( Arena& A, mpy::handle orig, @@ -1167,82 +1736,556 @@ mpy::object tree_map( elements[i] = fn(elements[i]); } return unflatten(elements); -} +======= +mpy::object run_torch_function(Arena &A, mpy::handle orig, mpy::vector_args args, bool is_pointwise); -// prereq: isinstance(h, _Tensor) -int64_t _Tensor_ndim(mpy::handle h) { - if (Tensor::check(h)) { - int64_t r = 0; - for (auto l : Tensor::unchecked_wrap(h)->levels()) { - if (l.is_positional()) { - ++r; - } +namespace{ + +at::Tensor _add_batch_dims(Arena& A, at::Tensor t, Slice levels_) { + auto levels = Slice(); + levels.extend(A, levels_); + while (true) { + int64_t min_real_index = -1; + int64_t min_index = -1; + int64_t min_value = INT_MAX; + int64_t i = 0; + int64_t r = 0; + for (auto l : levels) { + if (!l.is_none()) { + if (!l.is_positional() && l.dim()->level_ < min_value) { + min_value = l.dim()->level_; + min_index = i; + min_real_index = r; + } + ++i; + } + ++r; + } + if (min_index == -1) { + return t; + } + auto t2 = at::functorch::addBatchDim(std::move(t), min_index, min_value); + t = std::move(t2); + levels[min_real_index] = DimEntry(); } - return r; - } - // Dim or DelayedMulTensor - return 0; } -mpy::handle handle_from_tensor(Arena& A, TensorRef t) { - // fast case: tensor is live in python - std::optional mb_obj = - t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( - /*ignore_hermetic_tls=*/false); - if (mb_obj.has_value() && - !t->unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) { - return *mb_obj; - } - return A.autorelease(mpy::object::checked_steal(THPVariable_Wrap(*t))); -} -} // namespace -struct EnableAllLayers { - EnableAllLayers(Arena& A, Slice levels) { - std::vector> layers; - layers.reserve(levels.size()); - for (auto l : levels) { - if (!l.is_positional()) { - auto d = l.dim(); - levels_to_dim_.append(A, d); - } - } - std::sort( - levels_to_dim_.begin(), - levels_to_dim_.end(), - [](mpy::hdl lhs, mpy::hdl rhs) { - return lhs->level_ < rhs->level_; - }); - for (auto i : levels_to_dim_.enumerate()) { - auto batch_size = levels_to_dim_[i]->size(); - auto level = at::functorch::initAndPushDynamicLayer( - at::functorch::TransformType::Vmap, - batch_size, - at::functorch::RandomnessType::Different); - if (i == 0) { - levels_start_ = level; - } + +struct DelayedOperator { + DelayedOperator(mpy::object o, mpy::vector_args a) + : orig(std::move(o)), args(a) { + auto all = a.size(); + // this will outlive the call so + // take ownership of temporaries + // in vector args + auto buf = new mpy::handle[all]; + memcpy(buf, args.args, sizeof(mpy::handle)*all); + args.args = buf; + for (auto i : args.enumerate_all()) { + Py_INCREF(args.args[i].ptr()); + } + Py_XINCREF(args.kwnames.ptr()); } - } + ~DelayedOperator() { + for (auto i : args.enumerate_all()) { + Py_DECREF(args[i].ptr()); + } + if (args.has_keywords()) { + Py_XDECREF(args.kwnames.ptr()); + } + delete [] args.args; + } + mpy::object orig; + mpy::vector_args args; +}; - ~EnableAllLayers() { - auto to_remove = levels_start_ + levels_to_dim_.size() - 1; - for (auto i : levels_to_dim_.enumerate()) { - AT_ASSERT( - at::functorch::popDynamicLayerAndDeleteMetadata().layerId() == - to_remove - i); +void free_levels_dims(Slice levels) { + for(auto e : levels) { + if (!e.is_positional()) { + mpy::object::steal(e.dim()); + } } - } +} +} - mpy::obj from_batched( - Arena& A, - at::Tensor batchedtensor, - bool has_device) { - Slice levels; - for (auto i : irange(-batchedtensor.dim(), 0)) { - levels.append(A, i); +struct Tensor : public mpy::base { +private: + at::Tensor tensor_; + at::Tensor batchtensor_; + OwnedSlice levels_; + bool has_device_; + std::unique_ptr delayed_; +public: + + at::Tensor& tensor(Arena& A) { + if (C10_UNLIKELY(!tensor_.defined())) { + AT_ASSERT(delayed_); + auto t = Tensor::wrap(run_torch_function(A, delayed_->orig, delayed_->args, true)); + tensor_ = t->tensor(A); + delayed_.reset(); + // don't force creation of batch tensor if it wasn't already provided. + batchtensor_ = t->batchtensor_; + AT_ASSERT(levels() == t->levels()); + } + return tensor_; } - TensorRef tensor; + at::Tensor& batchtensor(Arena& A) { + if (C10_UNLIKELY(!batchtensor_.defined())) { + batchtensor_ = _add_batch_dims(A, tensor(A), levels_.slice()); + } + return batchtensor_; + } + Slice levels() { + return levels_.slice(); + } + bool has_device() { + return has_device_; + } + DelayedOperator* delayed() { + return delayed_.get(); + } + static PyTypeObject Type; + + static bool check_exact(mpy::handle v) { + return Py_TYPE(v.ptr()) == TensorType; + } + + + static mpy::obj create() { + if (!TensorType) { + TensorType = (PyTypeObject*) mpy::import("functorch.dim").attr("Tensor").release(); + } + return Tensor::alloc(TensorType); + } + void capture_levels(Slice levels) { + // grab ownership of the dims inside levels + for (auto l : levels) { + if (!l.is_positional()) { + mpy::object::borrow(l.dim()).release(); + } + } + levels_.set(levels, free_levels_dims); + } + static mpy::object from_positional(Arena & A, at::Tensor tensor, Slice levels, bool has_device); + static mpy::obj create_delayed(mpy::object op, mpy::vector_args args, Slice levels, bool has_device); + friend struct EnableAllLayers; +}; + +namespace{ +// version in header does a unnecessary refcount +/- +at::functorch::BatchedTensorImpl* maybeGetBatchedImpl(const at::Tensor& tensor) { + if (at::functorch::isBatchedTensor(tensor)) { + return static_cast(tensor.unsafeGetTensorImpl()); + } + return nullptr; +} + +TensorRef unchecked_tensor_from(mpy::handle p) { + auto v = (THPVariable*) p.ptr(); + return TensorRef(*v->cdata); +} + +static int64_t ndim_of_levels(Slice levels) { + int64_t r = 0; + for (auto l : levels) { + if (l.is_positional()) { + ++r; + } + } + return r; +} + +struct TensorInfo { + TensorRef tensor; + Slice levels; + bool has_device; + TensorRef batchedtensor; + int64_t ndim() const { + return ndim_of_levels(levels); + } + operator bool() const { + return tensor; + } + + static TensorInfo create(Arena& A, mpy::handle h, bool ensure_batched=true, bool ensure_present=true) { + if (Tensor::check_exact(h)) { + auto t = Tensor::unchecked_wrap(h); + return TensorInfo {t->tensor(A), t->levels(), t->has_device(), ensure_batched ? t->batchtensor(A) : TensorRef()}; + } else if (Dim::check_exact(h)) { + auto d = Dim::unchecked_wrap(h); + return TensorInfo {d->range(), Slice(A, DimEntry(d)), false, ensure_batched ? d->batchtensor() : TensorRef()}; + } else if (THPVariable_Check(h.ptr())) { + TensorRef t = unchecked_tensor_from(h); + Slice levels; + for (auto i : irange(-t->dim(), 0)) { + levels.append(A, i); + } + return TensorInfo {t, levels, true, t}; + } else { + if (ensure_present) { + mpy::raise_error(PyExc_ValueError, "expected a tensor object"); + } + return TensorInfo {}; + } + } + + +}; + +static PyObject* py_Tensor_from_positional(PyObject *self, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + Arena A; + PY_BEGIN + #define ARGS(_) _(mpy::handle, tensor) _(mpy::handle, py_levels) _(int, has_device) + MPY_PARSE_ARGS_KWNAMES("OOp", ARGS) + #undef ARGS + + if (!THPVariable_Check(tensor.ptr())) { + mpy::raise_error(PyExc_ValueError, "_tensor is not a Tensor?"); + } + + Slice levels; + mpy::sequence_view sq(py_levels); + for (auto i : sq.enumerate()) { + mpy::object v = sq[i]; + if (mpy::is_int(v)) { + auto vi = mpy::to_int(v); + levels.append(A, vi); + } else { + auto dim = Dim::wrap(std::move(v)); + mpy::hdl hdim = dim; + levels.append(A, hdim); + } + } + return Tensor::from_positional(A, THPVariable_Unpack(tensor.ptr()), levels, has_device != 0).release(); + PY_END(nullptr) +} +} + +mpy::object Tensor::from_positional(Arena & A, at::Tensor tensor, Slice levels, bool has_device) { + size_t seen_dims = 0; + int last = 0; + //auto sz = tensor.sizes(); + for (auto i : levels.enumerate()) { + auto l = levels[i]; + if (l.is_positional()) { + AT_ASSERT(last == 0 || last + 1 == l.position()); + last = l.position(); + } else { + mpy::object::borrow(l.dim()).release(); + //AT_ASSERT(sz[i] == l.dim()->size()); + ++seen_dims; + } + } + AT_ASSERT(last == 0 || last == -1); + if (!seen_dims) { + return mpy::object::steal(THPVariable_Wrap(tensor)); + } + + mpy::obj self = Tensor::create(); + self->tensor_ = std::move(tensor); + AT_ASSERT(self->tensor_.dim() == levels.size()); + self->levels_.set(levels, free_levels_dims); + self->has_device_ = has_device; + mpy::object r = std::move(self); + return r; +} + + +mpy::obj Tensor::create_delayed(mpy::object op, mpy::vector_args args, Slice levels, bool has_device) { + mpy::obj self = Tensor::create(); + self->capture_levels(levels); + self->has_device_ = has_device; + self->delayed_ = std::make_unique(std::move(op), args); + return self; +} + +namespace{ +mpy::list slice_to_list(Slice h) { + mpy::list lst(h.size()); + for (auto i : h.enumerate()) { + lst.set(i, mpy::object::borrow(h[i])); + } + return lst; +} + +mpy::tuple slice_to_tuple(Slice h) { + mpy::tuple lst(h.size()); + for (auto i : h.enumerate()) { + lst.set(i, mpy::object::borrow(h[i])); + } + return lst; +} + +enum UType { + U_ELEM, + U_TUPLE_LIKE, + U_DICT, +}; + +struct Unflatten { + mpy::object operator()(Slice& elements) { + mpy::object r; + switch (type) { + case U_ELEM: { + r = mpy::object::borrow(elements[0]); + elements = elements.slice(1); + } break; + case U_TUPLE_LIKE: { + mpy::tuple tup(children.size()); + for (auto i : children.enumerate()) { + tup.set(i, children[i](elements)); + } + r = obj.call(tup); + } break; + case U_DICT: { + r = mpy::object::checked_steal(PyDict_New()); + mpy::dict_view rv(r); + mpy::dict_view d(obj); + Py_ssize_t pos = 0; + mpy::handle k, v; + for (int i = 0; d.next(&pos, &k, &v); ++i) { + rv.set(k, children[i](elements)); + } + } break; + } + return r; + } + UType type; + mpy::handle obj; + Slice children; +}; + +Unflatten tree_flatten(Arena& A, mpy::handle agg, Slice& flat_elements) { + Slice c; + UType utype; + mpy::handle obj; + if (mpy::list_view::check(agg)) { + obj = agg.type(); + utype = U_TUPLE_LIKE; + mpy::list_view l(agg); + for (auto i : l.enumerate()) { + c.append(A, tree_flatten(A, l[i], flat_elements)); + } + } else if (mpy::tuple_view::check(agg)) { + obj = agg.type(); + utype = U_TUPLE_LIKE; + // includes named tuples + mpy::tuple_view l(agg); + for (auto i : l.enumerate()) { + c.append(A, tree_flatten(A, l[i], flat_elements)); + } + } else if (mpy::dict_view::check(agg)) { + utype = U_DICT; + mpy::dict_view d(agg); + obj = agg; + Py_ssize_t pos = 0; + mpy::handle k, v; + while (d.next(&pos, &k, &v)) { + c.append(A, tree_flatten(A, v, flat_elements)); + } + } else { + utype = U_ELEM; + flat_elements.append(A, agg); + } + return Unflatten {utype, obj, c}; +} + +struct UnflattenVectorArgs { + mpy::vector_args operator()(Arena& A, Slice& elements) { + if (!had_nested) { + auto args = elements.begin(); + elements = Slice(); + return mpy::vector_args(args, nargs, kwnames); + } + Slice args; + for (auto u : children) { + args.append(A, A.autorelease(u(elements))); + } + return mpy::vector_args(args.begin(), nargs, kwnames); + } + Slice children; + Py_ssize_t nargs; + mpy::handle kwnames; + bool had_nested; +}; + +UnflattenVectorArgs tree_flatten(Arena& A, mpy::vector_args args, Slice& flat_elements) { + UnflattenVectorArgs r; + r.kwnames = args.kwnames; + r.nargs = args.nargs; + r.had_nested = false; + auto N = args.size(); + for(auto i : irange(N)) { + auto typ = Py_TYPE(args[i].ptr()); + // fast checks that this thing isn't something that is nested. + bool is_element = !typ->tp_as_sequence || typ == torch_Tensor || typ == TensorType || typ == DimType; + if (!is_element) { + flat_elements.extend(A, args.args, args.args + i); + for (auto j : irange(i)) { + (void)j; + r.children.append(A, Unflatten {U_ELEM}); + } + for (auto j : irange(i, N)) { + r.children.append(A, tree_flatten(A, args[j], flat_elements)); + if (r.children.back().type != U_ELEM) { + r.had_nested = true; + } + } + return r; + } + } + flat_elements.extend(A, args.args, args.args + N); + return r; +} + + +struct UnflattenArena { + Arena A; + Unflatten unflatten; +}; + +PyObject* py_unflatten(PyObject *self, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + PY_BEGIN + #define ARGS(_) _(mpy::handle, ns) + MPY_PARSE_ARGS_KWNAMES("O", ARGS) + #undef ARGS + mpy::sequence_view sv(ns); + // because we do not have a autorelase pool yet... + Arena A; + Slice slice; + mpy::handle Tuple = (PyObject*) &PyTuple_Type; + auto inputs = Tuple.call(ns); + mpy::tuple_view tv(inputs); + for (auto i : tv.enumerate()) { + slice.append(A, tv[i]); + } + auto AA = (UnflattenArena*) PyCapsule_GetPointer(self, "arena"); + auto r = AA->unflatten(slice).release(); + AT_ASSERT(r != nullptr); + return r; + PY_END(nullptr) +} + +PyMethodDef py_unflatten_def = {"unflatten", (PyCFunction)(void*) py_unflatten, METH_FASTCALL | METH_KEYWORDS}; + +void free_unflatten_arena(PyObject * pc) { + delete (UnflattenArena*) PyCapsule_GetPointer(pc, "arena"); +} + +PyObject* py_tree_flatten(PyObject *self, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + PY_BEGIN + #define ARGS(_) _(mpy::handle, tree) + MPY_PARSE_ARGS_KWNAMES("O", ARGS) + #undef ARGS + auto A = new UnflattenArena; + Slice elements; + A->unflatten = tree_flatten(A->A, tree, elements); + auto cap = mpy::object::checked_steal(PyCapsule_New(A, "arena", free_unflatten_arena)); + auto unflatten = mpy::object::checked_steal(PyCFunction_New(&py_unflatten_def, cap.release())); + mpy::tuple r(2); + r.set(0, slice_to_list(elements)); + r.set(1, std::move(unflatten)); + return r.release(); + PY_END(nullptr) +} + + + +mpy::object tree_map(Arena& A, const std::function& fn, mpy::handle agg) { + Slice elements; + auto unflatten = tree_flatten(A, agg, elements); + for (auto i : elements.enumerate()) { + elements[i] = fn(elements[i]); + } + return unflatten(elements); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +} + +// prereq: isinstance(h, _Tensor) +int64_t _Tensor_ndim(mpy::handle h) { +<<<<<<< HEAD + if (Tensor::check(h)) { + int64_t r = 0; + for (auto l : Tensor::unchecked_wrap(h)->levels()) { + if (l.is_positional()) { + ++r; + } + } + return r; + } + // Dim or DelayedMulTensor + return 0; +} + +mpy::handle handle_from_tensor(Arena& A, TensorRef t) { + // fast case: tensor is live in python + std::optional mb_obj = + t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( + /*ignore_hermetic_tls=*/false); + if (mb_obj.has_value() && + !t->unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) { + return *mb_obj; + } + return A.autorelease(mpy::object::checked_steal(THPVariable_Wrap(*t))); +} +} // namespace +struct EnableAllLayers { + EnableAllLayers(Arena& A, Slice levels) { + std::vector> layers; + layers.reserve(levels.size()); + for (auto l : levels) { + if (!l.is_positional()) { + auto d = l.dim(); + levels_to_dim_.append(A, d); + } + } + std::sort( + levels_to_dim_.begin(), + levels_to_dim_.end(), + [](mpy::hdl lhs, mpy::hdl rhs) { + return lhs->level_ < rhs->level_; + }); + + for (auto i : levels_to_dim_.enumerate()) { + auto batch_size = levels_to_dim_[i]->size(); + auto level = at::functorch::initAndPushDynamicLayer( + at::functorch::TransformType::Vmap, + batch_size, + at::functorch::RandomnessType::Different); + if (i == 0) { + levels_start_ = level; + } + } + } + + ~EnableAllLayers() { + auto to_remove = levels_start_ + levels_to_dim_.size() - 1; + for (auto i : levels_to_dim_.enumerate()) { + AT_ASSERT( + at::functorch::popDynamicLayerAndDeleteMetadata().layerId() == + to_remove - i); + } + } + + mpy::obj from_batched( + Arena& A, + at::Tensor batchedtensor, + bool has_device) { + Slice levels; + for (auto i : irange(-batchedtensor.dim(), 0)) { + levels.append(A, i); + } + TensorRef tensor; at::functorch::BatchedTensorImpl* impl = maybeGetBatchedImpl(batchedtensor); while (true) { auto level = impl->level(); @@ -1639,10 +2682,412 @@ static mpy::object create_dimlist(mpy::object name, mpy::handle size) { // Python wrappers that make new reflection primitives available for older // runtimes +======= + if (Tensor::check(h)) { + int64_t r = 0; + for (auto l : Tensor::unchecked_wrap(h)->levels()) { + if (l.is_positional()) { + ++r; + } + } + return r; + } + // Dim or DelayedMulTensor + return 0; +} + +mpy::handle handle_from_tensor(Arena& A, TensorRef t) { + // fast case: tensor is live in python + std::optional mb_obj = + t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(getPyInterpreter(), /*ignore_hermetic_tls=*/false); + if (mb_obj.has_value() && !t->unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) { + return *mb_obj; + } + return A.autorelease(mpy::object::checked_steal(THPVariable_Wrap(*t))); +} +} +struct EnableAllLayers { + EnableAllLayers(Arena& A, Slice levels) { + std::vector> layers; + layers.reserve(levels.size()); + for (auto l : levels) { + if (!l.is_positional()) { + auto d = l.dim(); + levels_to_dim_.append(A, d); + } + } + std::sort(levels_to_dim_.begin(), levels_to_dim_.end(), [](mpy::hdl lhs, mpy::hdl rhs) { return lhs->level_ < rhs->level_;}); + + for (auto i : levels_to_dim_.enumerate()) { + auto batch_size = levels_to_dim_[i]->size(); + auto level = at::functorch::initAndPushDynamicLayer(at::functorch::TransformType::Vmap, batch_size, at::functorch::RandomnessType::Different); + if (i == 0) { + levels_start_ = level; + } + } + } + + ~EnableAllLayers() { + auto to_remove = levels_start_ + levels_to_dim_.size() - 1; + for (auto i : levels_to_dim_.enumerate()) { + AT_ASSERT(at::functorch::popDynamicLayerAndDeleteMetadata().layerId() == to_remove - i); + } + } + + mpy::obj from_batched(Arena& A, at::Tensor batchedtensor, bool has_device) { + Slice levels; + for (auto i : irange(-batchedtensor.dim(), 0)) { + levels.append(A, i); + } + TensorRef tensor; + at::functorch::BatchedTensorImpl * impl = maybeGetBatchedImpl(batchedtensor); + while(true) { + auto level = impl->level(); + AT_ASSERT(level >= levels_start_ && level < levels_start_ + levels_to_dim_.size()); + mpy::hdl dim = levels_to_dim_[level - levels_start_].ptr(); + levels.insert(A, impl->bdim(), dim); + at::functorch::BatchedTensorImpl * nimpl = maybeGetBatchedImpl(impl->value()); + if (!nimpl) { + tensor = impl->value(); + break; + } + impl = nimpl; + } + + mpy::obj self = Tensor::create(); + // grab ownership of the tensors + self->tensor_ = *tensor; + self->batchtensor_ = std::move(batchedtensor); + self->has_device_ = has_device; + self->capture_levels(levels); + return self; + } + void inplace_update_layers(TensorRef batchtensor, Slice levels) { + // XXX - requires a patch to functorch to att set_level + auto impl = maybeGetBatchedImpl(*batchtensor); + for (auto i : levels_to_dim_.reversed_enumerate()) { + if (!impl) { + break; + } + if (levels.contains(levels_to_dim_[i])) { + impl->_unsafe_set_level(levels_start_ + i); + impl = maybeGetBatchedImpl(impl->value()); + + } + } + } +private: + int64_t levels_start_{}; + Slice> levels_to_dim_; +}; + +namespace{ +TensorRef _match_levels(Arena& A, TensorRef v, Slice from_levels, Slice to_levels, bool drop_levels=false) { + if (from_levels == to_levels) { + return v; + } + // drop_levels -> if a dim appears in from_levels but not to_levels, it is assumed it has stride 0. + at::IntArrayRef sz = v->sizes(); + at::IntArrayRef sd = v->strides(); + AT_ASSERT(drop_levels || from_levels.size() <= to_levels.size()); + Slice nsz; + Slice nsd; + for (auto l : to_levels) { + auto oidx = from_levels.index(l); + if (!oidx) { + nsz.append(A, l.is_positional() ? 1 : l.dim()->size()); + nsd.append(A, 0); + } else { + auto idx = *oidx; + nsz.append(A, sz[idx]); + nsd.append(A, sd[idx]); + } + } + return A.autorelease(v->as_strided(at::IntArrayRef(nsz.begin(), nsz.end()), at::IntArrayRef(nsd.begin(), nsd.end()), v->storage_offset())); +} +} +mpy::object run_torch_function(Arena &A, mpy::handle orig, mpy::vector_args args, bool is_pointwise) { + if (!pointwise_optimize) { + is_pointwise = false; + } + // std::cout << "__torch_function__ " << ((is_pointwise) ? "pointwise" : "functorch") << " " << orig << "\n"; + + Slice> all_dims; + Slice flat_args; + auto unflatten_args = tree_flatten(A, args, flat_args); + TensorRef device_holding_tensor; + + Slice infos; + Slice result_levels; + for (auto f : flat_args) { + infos.append(A, TensorInfo::create(A, f, !is_pointwise, false)); + if (infos.back()) { + TensorInfo& info = infos.back(); + AT_ASSERT(is_pointwise || info.batchedtensor); + if (!device_holding_tensor && info.has_device) { + device_holding_tensor = infos.back().tensor; + } + for (auto l : info.levels) { + if (!result_levels.contains(l)) { + result_levels.append(A, l); + } + } + } + } + + if (is_pointwise) { + for (auto i : flat_args.enumerate()) { + if (infos[i]) { + TensorRef tensor = infos[i].tensor; + if (device_holding_tensor && !infos[i].has_device) { + tensor = A.autorelease(tensor->to(device_holding_tensor->device())); + } + auto ml = _match_levels(A, tensor, infos[i].levels, result_levels); + flat_args[i] = handle_from_tensor(A, std::move(ml)); + } + } + + Slice flat_it = flat_args; + mpy::vector_args uargs = unflatten_args(A, flat_it); + + mpy::object result = orig.call_vector(uargs); + + // fast wrap for normal case where operator just returns a tensor. + if (THPVariable_Check(result.ptr())) { + return Tensor::from_positional(A, THPVariable_Unpack(result.ptr()), result_levels, device_holding_tensor); + } + auto wrap = [&](mpy::handle h) { + if (THPVariable_Check(h.ptr())){ + return A.autorelease(Tensor::from_positional(A, THPVariable_Unpack(h.ptr()), result_levels, device_holding_tensor)); + } + return h; + }; + return tree_map(A, wrap, result); + } else { + // std::cout << orig << " calling functorch...\n"; + // std::cout << "rl: " << result_levels << "\n"; + EnableAllLayers guard(A, result_levels); + for (auto i : flat_args.enumerate()) { + if (infos[i]) { + TensorRef batched = infos[i].batchedtensor; + if (device_holding_tensor && !infos[i].has_device) { + batched = A.autorelease(batched->to(device_holding_tensor->device())); + } + guard.inplace_update_layers(batched, infos[i].levels); + flat_args[i] = handle_from_tensor(A, batched); + } + } + Slice flat_it = flat_args; + mpy::vector_args uargs = unflatten_args(A, flat_it); + AT_ASSERT(flat_it.size() == 0); + mpy::object result = orig.call_vector(uargs); + auto wrap = [&](mpy::handle h) { + if (THPVariable_Check(h.ptr())) { + return A.autorelease(guard.from_batched(A, THPVariable_Unpack(h.ptr()), device_holding_tensor)); + } + return h; + }; + if (THPVariable_Check(result.ptr())) { + return guard.from_batched(A, THPVariable_Unpack(result.ptr()), device_holding_tensor); + } + return tree_map(A, wrap, result); + } +} + +namespace{ + +mpy::object __torch_function__(Arena &A, mpy::handle orig, mpy::vector_args args, bool is_pointwise) { + if (orig == torch_Tensor___mul__) { + AT_ASSERT(args.nargs == 2 && !args.has_keywords()); + auto lhs = args[0]; + auto rhs = args[1]; + if (mpy::isinstance(lhs, _Tensor) && mpy::isinstance(rhs, _Tensor) && _Tensor_ndim(lhs) == 0 && _Tensor_ndim(rhs) == 0) { + bool has_device = false; + Slice levels; + for (auto i : args.enumerate_positional()) { + auto t = TensorInfo::create(A, args[i], false); + // something like a mask * rhs, which matrix multiplies don't correctly promote + if (!t.tensor->is_floating_point()) { + return run_torch_function(A, orig, args, is_pointwise); + } + has_device = has_device || t.has_device; + for (auto l : t.levels) { + if (!levels.contains(l)) { + levels.append(A, l); + } + } + } + // std::cout << "__torch_function__ " << "delay" << " " << orig << "\n"; + return Tensor::create_delayed(mpy::object::borrow(orig), args, levels, has_device); + } + } + return run_torch_function(A, orig, args, is_pointwise); +} + +mpy::vector_args as_vector_args(Arena& A, mpy::handle args, mpy::handle kwargs) { + auto pos_args = (mpy::handle*) &PyTuple_GET_ITEM(args.ptr(), 0); + auto pos_n = PyTuple_GET_SIZE(args.ptr()); + if (!kwargs.ptr()) { + return mpy::vector_args(pos_args, pos_n, nullptr); + } + Slice all_args; + Slice kwnames; + all_args.extend(A, pos_args, pos_args + pos_n); + mpy::dict_view dv(kwargs); + Py_ssize_t pos = 0; + mpy::handle key, value; + while (dv.next(&pos, &key, &value)) { + all_args.append(A, value); + kwnames.append(A, key); + } + return mpy::vector_args(all_args.begin(), pos_n, A.autorelease(slice_to_tuple(kwnames))); +} + +PyObject* py___torch_function__(PyObject *self, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + Arena A; + PY_BEGIN + maybeInitializeGlobals(); + AT_ASSERT(nargs == 4 || nargs == 5); + auto va = as_vector_args(A, args[3], nargs == 5 ? args[4] : nullptr); + bool is_pointwise = pointwise.contains(args[1]); + return __torch_function__(A, args[1], std::move(va), is_pointwise).release(); + PY_END(nullptr) +} + +mpy::object levels_to_tuple(Slice slice) { + mpy::tuple t(slice.size()); + for (auto i : slice.enumerate()) { + t.set(i, slice[i].is_positional() ? mpy::from_int(slice[i].position()) : mpy::object::borrow(slice[i].dim())); + } + mpy::object r = std::move(t); + return r; +} + +PyObject* Tensor_ndim(Tensor* self, void*) { + Py_ssize_t i = 0; + for (auto l : self->levels()) { + if (l.is_positional()) { + ++i; + } + } + return mpy::from_int(i).release(); +} + +PyGetSetDef Tensor_getsetters[] = { + {"_has_device", (getter) [](PyObject* self, void*) -> PyObject* { return mpy::from_bool(((Tensor*)self)->has_device()).release(); }, NULL}, + {"_tensor", (getter) [](PyObject* self, void*) -> PyObject* { + Arena A; + return THPVariable_Wrap(((Tensor*)self)->tensor(A)); }, NULL}, + {"_batchtensor", (getter) [](PyObject* self, void*) -> PyObject* { + Arena A; + return THPVariable_Wrap(((Tensor*)self)->batchtensor(A)); }, NULL}, + {"_levels", (getter) [](PyObject* self, void*) -> PyObject* { + PY_BEGIN + return levels_to_tuple(((Tensor*)self)->levels()).release(); + PY_END(nullptr) + }}, + {"ndim", (getter) Tensor_ndim, NULL, "ndim", NULL}, + {NULL} /* Sentinel */ +}; + +PyMethodDef Tensor_methods[] = { + {NULL, NULL, 0, NULL} /* Sentinel */ +}; +} + + +PyTypeObject Tensor::Type = { + PyVarObject_HEAD_INIT(NULL, 0) + "_C.Tensor", /* tp_name */ + sizeof(Tensor), /* tp_basicsize */ + 0, /* tp_itemsize */ + Tensor::dealloc_stub, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_as_async */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE , /* tp_flags */ + "Tensor Object", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + Tensor_methods, /* tp_methods */ + 0, /* tp_members */ + Tensor_getsetters, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + 0, /* tp_init */ + 0, /* tp_alloc */ + Tensor::new_stub, /* tp_new */ +}; + + +// dim() -------------------- + +static bool relevant_op(_Py_CODEUNIT c) { + switch(c) { + case STORE_NAME: + case STORE_GLOBAL: + case STORE_FAST: + case STORE_DEREF: + return true; + default: + return false; + } +} + +static mpy::object create_dim(mpy::object name, mpy::handle size) { + auto d = Dim::create(std::move(name)); + if (!mpy::is_none(size)) { + d->set_size(mpy::to_int(size)); + } + return std::move(d); +} + +static mpy::object create_dimlist(mpy::object name, mpy::handle size) { + auto d = DimList::create(std::move(name)); + if (!mpy::is_none(size)) { + if (mpy::is_int(size)) { + d->bind_len(mpy::to_int(size)); + } else { + mpy::sequence_view s(size); + d->bind_len(s.size()); + for (auto i : irange(d->size())) { + d->dims_[i]->set_size(mpy::to_int(s[i])); + } + } + } + return std::move(d); +} + + + +// Python wrappers that make new reflection primitives available for older runtimes +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #if !(IS_PYTHON_3_11_PLUS) #define _PyCode_CODE(CO) ((_Py_CODEUNIT*)PyBytes_AS_STRING((CO)->co_code)) #endif +<<<<<<< HEAD namespace { struct PyInstDecoder { PyInstDecoder(PyCodeObject* code_object, int lasti) @@ -2474,11 +3919,730 @@ mpy::object index( indices_list, has_dimpacks); return invoke_getitem(A, info); +======= +namespace{ +struct PyInstDecoder { + PyInstDecoder(PyCodeObject* code_object, int lasti) + : code_object_(code_object), code_(_PyCode_CODE(code_object)), offset_(lasti / sizeof(_Py_CODEUNIT)) {} + // On Windows, _PyOpcode_Caches and _PyOpcode_Deopt are private symbols + // See https://github.com/pytorch/pytorch/issues/93854 + void next() { + #if IS_PYTHON_3_11_PLUS + offset_ += _PyOpcode_Caches[opcode()]; + #endif + offset_ += 1; + } + int opcode() { + auto r = _Py_OPCODE(code_[offset_]); + #if IS_PYTHON_3_11_PLUS + r = _PyOpcode_Deopt[r]; + #endif + return r; + } + int oparg() { + return _Py_OPARG(code_[offset_]); + } + + mpy::object name() { + mpy::object names; + switch(opcode()) { + case STORE_NAME: + case STORE_GLOBAL: + names = mpy::object::borrow(code_object_->co_names); + break; + case STORE_FAST: + names = mpy::object::steal(PyCode_GetVarnames(code_object_)); + break; + case STORE_DEREF: + names = mpy::object::steal(PyCode_GetCellvars(code_object_)); + break; + default: + return mpy::object(); + } + return mpy::object::steal(PySequence_GetItem(names.ptr(), oparg())); + } +private: + PyCodeObject* code_object_; + _Py_CODEUNIT* code_; + int offset_; +}; + +template +static PyObject* _dims(PyObject *self, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + PY_BEGIN + Py_ssize_t specified_ndims = -1; + Py_ssize_t found_ndims = 0; + Py_ssize_t sizes = -1; + mpy::handle n = Py_None; + mpy::handle py_sizes = Py_None; + + if (nargs || kwnames) { + mpy::vector_args va(args, nargs, kwnames); + va.parse("dims", {"n", "sizes"}, {&n, &py_sizes}, 0); + if (!mpy::is_none(py_sizes)) { + sizes = mpy::sequence_view(py_sizes).size(); + specified_ndims = sizes; + } + if (!mpy::is_none(n)) { + specified_ndims = mpy::to_int(n); + } + } + + PyThreadState* state = PyThreadState_GET(); + auto f = mpy::obj::steal(PyThreadState_GetFrame(state)); + auto c = mpy::obj::steal(PyFrame_GetCode(f.ptr())); + auto lasti = PyFrame_GetLasti(f.ptr()); + auto decoder = PyInstDecoder(c.ptr(), lasti); + #if IS_PYTHON_3_11_PLUS + // When py3.11 adapts bytecode lasti points to the precall + // rather than the call instruction after it + if (decoder.opcode() == PRECALL) { + decoder.next(); + } + #endif + decoder.next(); + + if (relevant_op(decoder.opcode())) { + found_ndims = 1; + } else if (decoder.opcode() == UNPACK_SEQUENCE) { + found_ndims = decoder.oparg(); + decoder.next(); + } + + if (specified_ndims == -1) { + if (found_ndims == 0) { + mpy::raise_error(PyExc_SyntaxError, "dims() must be assigned to a sequence of variable names or have argument n specified"); + } + specified_ndims = found_ndims; + } + if (found_ndims != specified_ndims) { + found_ndims = 0; // avoid taking the wrong names for dimensions + } + + auto genobject = [&](int i) -> mpy::object { + mpy::object name; + if (i < found_ndims) { + name = decoder.name(); + } + if (!name.ptr()) { + name = mpy::unicode_from_format("d%d", i); + found_ndims = 0; // once we fail at finding a name, we can find any more + } else { + decoder.next(); + } + return create_object(std::move(name), sizes != -1 ? mpy::sequence_view(py_sizes)[i] : mpy::handle(Py_None)); + }; + if (sizes != -1 && sizes != specified_ndims) { + mpy::raise_error(PyExc_ValueError, "expected %d sizes but found %d", int(specified_ndims), int(sizes)); + } + if (specified_ndims == 1) { + return genobject(0).release(); + } + mpy::tuple result(specified_ndims); + for (int i = 0; i < specified_ndims; ++i) { + result.set(i, genobject(i)); + } + return result.release(); + PY_END(nullptr) +} + +struct DotPart { + Slice dims; + size_t total_size = 1; + void append(Arena& A, mpy::hdl d) { + total_size *= d->size(); + dims.append(A, d); + } +}; + +template +static at::ArrayRef as_array_ref(Slice t) { + return at::ArrayRef(t.begin(), t.end()); +} + +static TensorRef dot_prepare(Arena& A, std::initializer_list parts, const TensorInfo& t) { + Slice new_levels; + bool needs_reshape = false; + for (auto p : parts) { + if (p.dims.size() != 1) { + needs_reshape = true; + } + new_levels.extend(A, p.dims); + } + auto r = _match_levels(A, t.tensor, t.levels, new_levels, true); + if (!needs_reshape) { + return r; + } + Slice view; + for (auto p : parts) { + view.append(A, p.total_size); + } + return A.autorelease(r->reshape(at::IntArrayRef(view.begin(), view.end()))); +} + +static mpy::object dot_finish(Arena& A, std::initializer_list parts, at::Tensor r) { + Slice result_levels; + bool needs_reshape = false; + for (auto p : parts) { + if (p.dims.size() != 1) { + needs_reshape = true; + } + result_levels.extend(A, p.dims); + } + if (needs_reshape) { + Slice new_size; + for (auto l : result_levels) { + new_size.append(A, l.dim()->size()); + } + r = r.reshape(at::IntArrayRef(new_size.begin(), new_size.end())); + } + return Tensor::from_positional(A, std::move(r), result_levels, true); +} + + + +static mpy::object dot(Arena& A, TensorInfo lhs, TensorInfo rhs, Slice sum) { + auto lhs_strides = lhs.tensor->strides(); + auto rhs_strides = rhs.tensor->strides(); + + DotPart lro_dims; + DotPart lo_dims; + DotPart ro_dims; + DotPart lr_dims; + + auto insert_dim = [&] (mpy::hdl d, std::optional lhs_idx, std::optional rhs_idx) { + bool reduced = sum.contains(d); + int64_t lhs_stride = lhs_idx ? lhs_strides[*lhs_idx] : 0; + int64_t rhs_stride = rhs_idx ? rhs_strides[*rhs_idx] : 0; + if (reduced) { + // lr + lr_dims.append(A, d); + } else { + if ((lhs_stride == 0) == (rhs_stride == 0)) { + // lro + lro_dims.append(A, d); + } else if (lhs_stride != 0) { + // lo + lo_dims.append(A, d); + } else { + AT_ASSERT(rhs_stride != 0); + ro_dims.append(A, d); + } + } + }; + + + auto rhs_seen = A.allocate(rhs.levels.size()); + std::fill(rhs_seen, rhs_seen + rhs.levels.size(), false); + + for (auto i : lhs.levels.enumerate()) { + auto d = lhs.levels[i]; + auto rhs_idx = rhs.levels.index(d); + if (rhs_idx) { + rhs_seen[*rhs_idx] = true; + } + insert_dim(d.dim(), i, rhs_idx); + } + + for (auto i : rhs.levels.enumerate()) { + if (rhs_seen[i]) { + continue; + } + auto d = rhs.levels[i]; + insert_dim(d.dim(), std::nullopt, i); + } + + if (lr_dims.dims.size() != sum.size()) { + for (auto & d : sum) { + if (!lhs.levels.contains(d) && !rhs.levels.contains(d)) { + mpy::raise_error(DimensionBindError(), "summing over non-existent dimension %S", d.dim().ptr()); + } + } + } + + // std::cout << lhs.levels << " " << rhs.levels << " " << sum << "\n"; + // std::cout << lro_dims.dims << " " << lo_dims.dims << " " << ro_dims.dims << " " << lr_dims.dims << "\n"; + + // no batch, just call mm + if (lro_dims.dims.size() != 0) { + auto lhs_ = dot_prepare(A, {lro_dims, lo_dims, lr_dims}, lhs); + auto rhs_ = dot_prepare(A, {lro_dims, lr_dims, ro_dims}, rhs); + return dot_finish(A, {lro_dims, lo_dims, ro_dims}, at::bmm(*lhs_, *rhs_)); + } else { + auto lhs_ = dot_prepare(A, {lo_dims, lr_dims}, lhs); + auto rhs_ = dot_prepare(A, {lr_dims, ro_dims}, rhs); + return dot_finish(A, {lo_dims, ro_dims}, at::mm(*lhs_, *rhs_)); + } + +} + +static PyObject* test_c(PyObject *self, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + PY_BEGIN + + Arena A; + Slice s(A, 3, 4, 5); + AT_ASSERT(s.size() == 3 && s.capacity() == 8); + AT_ASSERT(s[0] == 3 && s[1] == 4 && s[2] == 5); + s.append(A, 6); + AT_ASSERT(s[3] == 6); + for(int i : irange(10)) { + s.append(A, i); + } + AT_ASSERT(s[0] == 3 && s.back() == 9 && s.size() == 14 && s.capacity() == 16); + + Slice s2(A, -1, -2, -3); + AT_ASSERT(s2[1] == -2 && s[0] == 3); + + auto ss = s.slice(1,2); + AT_ASSERT(ss.size() == 1); + AT_ASSERT(ss[0] == 4); + AT_ASSERT(ss.capacity() == 1); + ss.append(A, -4); + AT_ASSERT(ss.size() == 2 && ss[1] == -4); + ss[0] = 3; + AT_ASSERT(s[1] == 4); + + s.insert(A, s.slice(1, 4), ss); + AT_ASSERT(s[1] == 3 && s[2] == -4 && s[3] == 0); + + auto sz = s.size(); + s.insert(A, s.slice(1, 1), 4); + AT_ASSERT(s[1] == 4 && sz + 1 == s.size()); + + + Slice d(A, 0, 1, 2, 3, 4); + + Slice b(A, 0, 1, 2, 3, 4); + b.insert(A, b.slice(1,1), d); + AT_ASSERT(b.size() == 10); + AT_ASSERT(b[1] == 0); + AT_ASSERT(b[5] == 4); + AT_ASSERT(b.back() == 4); + + Py_RETURN_NONE; + + PY_END(nullptr); +} + + +static PyObject* order(PyObject *_, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + Arena A; + PY_BEGIN + if (kwnames) { + mpy::raise_error(PyExc_TypeError, "unexpected keyword arguments %S", kwnames); + } + AT_ASSERT(nargs-- > 0); + Slice orig_levels; + Slice levels; + TensorRef data; + mpy::handle self = args++[0]; + bool has_device; + if (Tensor::check_exact(self)) { + auto t = Tensor::unchecked_wrap(self); + orig_levels = t->levels(); + data = t->tensor(A); + has_device = t->has_device(); + } else { + auto d = Dim::unchecked_wrap(self); + orig_levels.append(A, d); + data = d->range(); + has_device = false; + } + + Slice flat_positional_dims; + Slice> to_flatten; + levels.extend(A, orig_levels); + + int orig_ndim = ndim_of_levels(levels); + auto append = [&](DimEntry d) { + auto midx = levels.index(d); + if (!midx) { + if (d.is_positional()) { + mpy::raise_error(PyExc_ValueError, "tensor has %d positional dimensions, but %d specified, or it was specified twice", int(orig_ndim), int(d.position() + orig_ndim)); + } else { + mpy::raise_error(PyExc_ValueError, "tensor of dimensions %R does not contain dim %R or it was specified twice", levels_to_tuple(orig_levels).ptr(), d.dim().ptr()); + } + } + levels[*midx] = DimEntry(); + flat_positional_dims.append(A, d); + }; + + int n_new_positional = 0; + for (auto i :irange(nargs)) { + mpy::handle arg = args[i]; + DimEntry entry = _wrap_dim(arg, orig_ndim, false); + if (!entry.is_none()) { + append(entry); + ++n_new_positional; + } else if (DimList::check(arg)) { + auto dl = DimList::unchecked_wrap(arg); + for (mpy::obj & d : dl->dims_) { + append(mpy::hdl(d)); + ++n_new_positional; + } + } else { + ++n_new_positional; + if (!mpy::is_sequence(arg)) { + mpy::raise_error(PyExc_ValueError, "expected a Dim, List[Dim], or Sequence[Dim]"); + } + mpy::sequence_view sq(arg); + auto N = sq.size(); + to_flatten.append(A, std::make_pair(flat_positional_dims.size(), N)); + for (auto j : irange(N)) { + DimEntry e = _wrap_dim(A.autorelease(sq[j]), orig_ndim, false); + if (e.is_none()) { + mpy::raise_error(PyExc_ValueError, "expected a Dim, or int"); + } + append(e); + } + } + } + + int insert_point = -1; + Slice new_levels; + for (auto l : levels) { + if (l.is_none()) { + continue; + } + if (l.is_positional()) { + if (insert_point == -1) { + insert_point = new_levels.size(); + new_levels.extend(A, flat_positional_dims); + } + } + new_levels.append(A, l); + } + if (insert_point == -1) { + insert_point = new_levels.size(); + new_levels.extend(A, flat_positional_dims); + } + + at::Tensor ndata = *_match_levels(A, data, orig_levels, new_levels); + + if (to_flatten.size()) { + Slice view; + auto sz = ndata.sizes(); + // before the new positional dims + for (auto i : irange(0, insert_point)) { + view.append(A, sz[i]); + } + int i = 0; + for (auto to_flat : to_flatten) { + for (;i < to_flat.first; ++i) { + view.append(A, sz[insert_point + i]); + } + int64_t new_size = 1; + int last = i + to_flat.second; + for (; i < last; ++i) { + new_size *= sz[insert_point + i]; + } + view.append(A, new_size); + } + for (; i < flat_positional_dims.size(); ++i) { + view.append(A, sz[insert_point + i]); + } + // after the new positional dims + for (auto i : irange(insert_point + flat_positional_dims.size(), levels.size())) { + view.append(A, sz[i]); + } + // we shorted the number of dimension, so remove them from new levels + // we will renumber them later + auto n_to_remove = flat_positional_dims.size() - n_new_positional; + new_levels.insert(A, new_levels.slice(insert_point, insert_point + n_to_remove), Slice()); + ndata = std::move(ndata).reshape(at::IntArrayRef(view.begin(), view.end())); + } + + // renumber the positional dimension + int seen = 0; + for (auto i : new_levels.reversed_enumerate()) { + if (new_levels[i].is_positional() || (i >= insert_point && i < insert_point + n_new_positional)) { + new_levels[i] = --seen; + } + } + return Tensor::from_positional(A, std::move(ndata), new_levels, has_device).release(); + + PY_END(nullptr) +} + +static PyObject* expand(PyObject *_, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + Arena A; + PY_BEGIN + AT_ASSERT(nargs-- > 0); + auto info = TensorInfo::create(A, args++[0], false); + for (auto i : irange(nargs)) { + if (!Dim::check(args[i])) { + maybeInitializeGlobals(); + mpy::vector_args vargs(args - 1, nargs + 1, kwnames); + if (THPVariable_Check(args[-1])) { + return torch_Tensor_expand.call_vector(vargs).release(); + } else { + return __torch_function__(A, torch_Tensor_expand, vargs, false).release(); + } + } + } + const at::Tensor& data = *info.tensor; + auto levels = info.levels; + Slice new_levels; + Slice sz; + Slice sd; + for (auto i : irange(nargs)) { + auto d = Dim::unchecked_wrap(args[i]); + if (levels.contains(d) || new_levels.contains(d)) { + mpy::raise_error(DimensionBindError(), "expanding dimension %R already exists in tensor with dims", d.ptr()); + } + new_levels.append(A, d); + sz.append(A, d->size()); + sd.append(A, 0); + } + new_levels.extend(A, levels); + at::IntArrayRef osz = data.sizes(); + at::IntArrayRef osd = data.strides(); + sz.extend(A, osz.begin(), osz.end()); + sd.extend(A, osd.begin(), osd.end()); + at::Tensor ndata = data.as_strided(at::IntArrayRef(sz.begin(), sz.end()), at::IntArrayRef(sd.begin(), sd.end()), data.storage_offset()); + return Tensor::from_positional(A, std::move(ndata), new_levels, info.has_device).release(); + PY_END(nullptr) +} + + +static void _bind_dims_to_size(Arena & A, int64_t sz, int64_t sd, + Slice> dims, Slice& nsz, Slice& nsd) { + int64_t rhs_prod = 1; + for (auto i : dims.enumerate()) { + if (!dims[i]->is_bound()) { + for (auto j : irange(i + 1, dims.size())) { + if (!dims[j]->is_bound()) { + mpy::raise_error(DimensionBindError(), "cannot infer the sizes of two dimensions at once %R and %R", dims[i].ptr(), dims[j].ptr()); + } + rhs_prod *= dims[j]->size(); + } + if (sz % rhs_prod != 0) { + mpy::tuple tup(dims.size()); + for (auto j : dims.enumerate()) { + tup.set(j, dims[j]->is_bound() ? mpy::from_int(dims[j]->size()) : mpy::unicode_from_string("?")); + } + mpy::raise_error(DimensionBindError(), "inferred dimension does not evenly fit into larger dimension: %d vs %R", (int) sz, tup.ptr()); + } + int64_t inferred_size = sz / rhs_prod; + dims[i]->set_size(inferred_size); + rhs_prod = sz; + break; + } + rhs_prod *= dims[i]->size(); + } + if (rhs_prod != sz) { + mpy::tuple tup(dims.size()); + for (auto j : dims.enumerate()) { + tup.set(j, mpy::object::borrow(dims[j])); + } + mpy::raise_error(DimensionBindError(), "Dimension sizes to do not match (%d != %d) when matching dimension pack %R", (int) sz, (int) rhs_prod, tup.ptr()); + } + auto new_strides = A.allocate(dims.size()); + auto prev_stride = sd; + for (auto i : dims.reversed_enumerate()) { + new_strides[i] = prev_stride; + prev_stride = dims[i]->size()*prev_stride; + } + for (auto i : dims.enumerate()) { + nsd.append(A, new_strides[i]); + nsz.append(A, dims[i]->size()); + } +} + +static bool has_dims(mpy::handle d) { + return Dim::check_exact(d) || Tensor::check_exact(d); +} + +struct IndexingInfo { + bool can_call_original; // if true, then it is safe to just call getitem or setitem, these objects do not need special handling + bool advanced_indexing; // requires actual lookup + TensorRef self; + Slice flat_inputs; + Slice result_levels; + bool has_device; +}; +} + +IndexingInfo getsetitem_flat(Arena& A, TensorInfo self_info, Slice input, Slice keys, Slice values, bool has_dimpacks_or_none); +namespace{ +Slice as_slice(mpy::tuple_view tv) { + PyObject** begin = &PyTuple_GET_ITEM(tv.ptr(),0); + return Slice((mpy::handle*)begin, (mpy::handle*) (begin + tv.size())); +} + +Slice as_slice(mpy::list_view tv) { + PyObject** begin = &PyList_GET_ITEM(tv.ptr(),0); + return Slice((mpy::handle*)begin, (mpy::handle*) (begin + tv.size())); +} + + +bool maybe_dimpack(Slice& elements, mpy::handle s, bool check_first=true) { + // can we avoid rechecking? + if (mpy::list_view::check(s)) { + mpy::list_view tv(s); + if (!check_first || (tv.size() && Dim::check_exact(tv[0]))) { + elements = as_slice(tv); + return true; + } + } + // can we avoid rechecking? + if (mpy::tuple_view::check(s)) { + mpy::tuple_view tv(s); + if (!check_first || (tv.size() && Dim::check_exact(tv[0]))) { + elements = as_slice(tv); + return true; + } + } + return false; +}; + +bool is_dimpack(mpy::handle s) { + Slice e; + return maybe_dimpack(e, s); +} + +mpy::object invoke_getitem(Arena& A, const IndexingInfo& iinfo) { + at::Tensor rtensor; + if (iinfo.advanced_indexing) { + auto self_hdl = handle_from_tensor(A, iinfo.self); + auto tup = slice_to_tuple(iinfo.flat_inputs); + // std::cout << "calling original getindex " << self_hdl << " " << tup << "\n"; + auto pytensor = mpy::object::checked_steal(THPVariable_getitem(self_hdl.ptr(), tup.ptr())); + rtensor = THPVariable_Unpack(pytensor.ptr()); + } else { + // std::cout << "skipping original getindex\n"; + rtensor = *iinfo.self; + } + // std::cout << "returning (from_positional)\n"; + return Tensor::from_positional(A, std::move(rtensor), iinfo.result_levels, iinfo.has_device); +} + +mpy::object index(Arena& A, mpy::handle self, mpy::handle dims, mpy::handle indices) { + maybeInitializeGlobals(); + Slice dims_list; + Slice indices_list; + // we allow for matching single dims to multiple dims, + // so we first have to normalize everything into the case where there is a list on lhs and the rhs + bool lhs_list = mpy::tuple_view::check(dims) || mpy::list_view::check(dims); + bool rhs_list = mpy::tuple_view::check(indices) || mpy::list_view::check(indices); + if (lhs_list && rhs_list) { + mpy::sequence_view dv(dims); + mpy::sequence_view ind(indices); + Py_ssize_t N = dv.size(); + if (N != ind.size()) { + mpy::raise_error(PyExc_TypeError, "dims (%d) and indices (%d) must have the same length", int(N), int(ind.size())); + } + for (auto i : irange(N)) { + dims_list.append(A, A.autorelease(dv[i])); + indices_list.append(A, A.autorelease(ind[i])); + } + } else { + dims_list.append(A, dims); + indices_list.append(A, indices); + } + + // dims being indexed can be grouped together into a single index space, and we have to + // flatten them int a single dimension before we can index them... + auto self_info = TensorInfo::create(A, self, false); + auto ndim = self_info.ndim(); + Slice new_levels; + Slice to_flatten; + Slice dims_list_flat; + auto parse_dim_entry = [&](mpy::handle s) -> DimEntry { + auto d = _wrap_dim(s, ndim, false); + if (d.is_none()) { + mpy::raise_error(PyExc_TypeError, "expected a dimension specifyer but found %R", s.ptr()); + } + return d; + }; + auto dim_not_present = [&](DimEntry d) { + if (d.is_positional()) { + mpy::raise_error(PyExc_TypeError, "dimension %d not in tensor of %d dimensions", d.position() + ndim , ndim); + } else { + mpy::raise_error(PyExc_TypeError, "dimension %R not in tensor", d.dim()->ptr()); + } + }; + + for (auto i : dims_list.enumerate()) { + Slice m; + if (maybe_dimpack(m, dims_list[i], /*check_first=*/false)) { + if (m.size() == 0) { + // plausible semantics work for this to have 0 elements (e.g. the index will always be 0) + dims_list_flat.append(A, DimEntry()); // value is just dropped + } + auto first = parse_dim_entry(m[0]); + dims_list_flat.append(A, first); + if (m.size() == 1) { + continue; + } + if (to_flatten.size() == 0) { + new_levels.extend(A, self_info.levels); + } + Slice rest; + for (auto i : irange(1, m.size())) { + auto d = parse_dim_entry(m[i]); + if (!new_levels.remove(A, d)) { + dim_not_present(d); + } + rest.append(A, d); + } + + auto first_idx = new_levels.index(first); + if (!first_idx) { + dim_not_present(first); + } + new_levels.insert(A, new_levels.slice(*first_idx + 1, *first_idx + 1), rest); + to_flatten.extend(A, rest); + } else { + dims_list_flat.append(A, parse_dim_entry(dims_list[i])); + } + } + if (to_flatten.size() > 0) { + TensorRef rearranged = _match_levels(A, self_info.tensor, self_info.levels, new_levels); + at::IntArrayRef sizes = rearranged->sizes(); + Slice new_sizes; + Slice reshape_levels; + for (auto i : new_levels.enumerate()) { + if (to_flatten.contains(new_levels[i])) { + new_sizes.back() *= sizes[i]; + } else { + new_sizes.append(A, sizes[i]); + reshape_levels.append(A, new_levels[i]); + } + } + self_info.tensor = A.autorelease(rearranged->reshape(at::IntArrayRef(new_sizes.begin(), new_sizes.end()))); + + self_info.levels = reshape_levels; // note: we are using the first level in a flattened group to represent the group for the rest of the op + // we need to be careful not to rely the dimensions size because it doesn't match the size of the whole group + } + bool has_dimpacks = false; + for (auto idx : indices_list) { + if (mpy::tuple_view::check(idx) || mpy::list_view::check(idx)) { + has_dimpacks = true; + break; + } + } + IndexingInfo info = getsetitem_flat(A, self_info, Slice(), dims_list_flat, indices_list, has_dimpacks); + return invoke_getitem(A, info); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // true -- the indices were flattened out of a tuple, list or sequence... Slice slice_from_sequence(Arena& A, mpy::handle value) { +<<<<<<< HEAD if (mpy::tuple_view::check(value)) { return as_slice(mpy::tuple_view(value)); } else if (mpy::list_view::check(value)) { @@ -3562,6 +5726,965 @@ PyObject* _patch_tensor_class( Py_RETURN_NONE; PY_END(nullptr) } +======= + if (mpy::tuple_view::check(value)) { + return as_slice(mpy::tuple_view(value)); + } else if (mpy::list_view::check(value)) { + return as_slice(mpy::list_view(value)); + } else { + mpy::sequence_view sv(value); + Slice r; + for (auto i : sv.enumerate()) { + r.append(A, A.autorelease(sv[i])); + } + return r; + } +} + +bool extractIndices(Arena& A, mpy::handle index, Slice& indices) { + if (mpy::tuple_view::check(index)) { + indices.extend(A, as_slice(mpy::tuple_view(index))); + return true; + } else if (THPVariable_Check(index.ptr())) { + indices.append(A, index); + return false; + } else if (!mpy::is_sequence(index)) { + indices.append(A, index); + return false; + } + // a copy of treatSequenceAsTuple modified to add Dim and our wrapped tensors.. + mpy::sequence_view sv(index); + if (sv.size() >= 32) { + indices.extend(A, slice_from_sequence(A, index)); + return true; + } + for (auto i : sv.enumerate()) { + mpy::handle item; + try { + item = sv[i]; + } catch (mpy::exception_set & e) { + PyErr_Clear(); + indices.append(A, index); + return false; + } + if (THPVariable_Check(item.ptr()) || mpy::is_sequence(item) || PySlice_Check(item.ptr()) || item.ptr() == Py_Ellipsis || mpy::is_none(item) || has_dims(item)) { + indices.extend(A, slice_from_sequence(A, index)); + return true; + } + } + indices.append(A, index); + return false; +} + +IndexingInfo getsetitem(Arena & A, mpy::handle self, mpy::handle index, bool tensors_have_dims) { + bool can_call_original_getitem = !tensors_have_dims; + + Slice input; + if (has_dims(index)) { + input.append(A, index); + } else { + bool is_sequence = extractIndices(A, index, input); + // nothing about first class dims here, fallback to getitem + if (can_call_original_getitem && !is_sequence) { + return { true }; + } + } + + int64_t dims_indexed = 0; + int64_t expanding_object = -1; + DimList* unbound_dim_list = nullptr; + auto check_expanding = [&](int64_t i) { + if (expanding_object != -1) { + mpy::raise_error(DimensionBindError(), "at most one ... or unbound dimension list can exist in indexing list but found 2 at offsets %d and %d", (int) expanding_object, (int) i); + } + expanding_object = i; + }; + Slice dimlists; + + // calculate how many dimensioned have been indexed in order to compute the size of ... + // or expand a potentially unbound dimension list. + + bool has_dimpacks_or_none = false; + for (auto i : input.enumerate()) { + mpy::handle s = input[i]; + if (Dim::check_exact(s) || Tensor::check_exact(s)) { + can_call_original_getitem = false; + ++dims_indexed; + } else if (s.ptr() == Py_Ellipsis) { + check_expanding(i); + } else if (DimList::check(s)) { + can_call_original_getitem = false; + auto dl = DimList::unchecked_wrap(s); + if (!dl->is_bound()) { + check_expanding(i); + unbound_dim_list = dl.ptr(); + } else { + dims_indexed += dl->dims_.size(); + } + dimlists.append(A, i); + } else if (mpy::is_none(s)) { + has_dimpacks_or_none = true; + } else if (is_dimpack(s)) { + can_call_original_getitem = false; + has_dimpacks_or_none = true; + ++dims_indexed; + } else { + ++dims_indexed; + } + } + + // at this point if we haven't seen any Dim objects, we also can fallback to the original getitem. + if (can_call_original_getitem) { + return {true}; + } + + // std::cout << "__getitem__ " << self << " " << index << "\n"; + + TensorInfo self_info = TensorInfo::create(A, self, false, true); + auto ndim = self_info.ndim(); + if (dims_indexed > ndim) { + mpy::raise_error(PyExc_ValueError, "at least %d indices were supplied but the tensor only has %d dimensions", (int) dims_indexed, (int) ndim); + } + // expand any unbound dimension list, or expand ... into individual : slices. + auto expanding_dims = ndim - dims_indexed; + if (expanding_object != -1) { + if (unbound_dim_list) { + unbound_dim_list->bind_len(expanding_dims); + } else { + // ... + Slice no_slices; + for (auto i : irange(expanding_dims)) { + (void) i; + no_slices.append(A, no_slice); + } + input.insert(A, input.slice(expanding_object, expanding_object + 1), no_slices); + } + } + + // flatten out any dimensions stored in dimlist elements directly into the inputs + // std::cout << dimlists << " <- dim lists!\n"; + for (int64_t i = dimlists.size() - 1; i >=0; --i) { + auto idx = dimlists[i]; + // we added more elements to input because of ... + // so we need to also adjust the index to get back to where the + // dimlist existed + if (!unbound_dim_list && expanding_object != -1 && idx > expanding_object) { + idx += expanding_dims; + } + auto dl = DimList::unchecked_wrap(input[idx]); + // XXX would be better if we used an OwnedSlice in DimList + Slice more_dims((mpy::handle*) &*dl->dims_.begin(), (mpy::handle*) &*dl->dims_.end()); + input.insert(A, input.slice(idx, idx + 1), more_dims); + } + + return getsetitem_flat(A, self_info, input, Slice(), Slice(), has_dimpacks_or_none); +} +} +IndexingInfo getsetitem_flat(Arena& A, TensorInfo self_info, Slice input, Slice keys, Slice values, bool has_dimpacks_or_none) { + // At this point: + // ..., DimList have been eliminated + // Dim, Tensor, Tuple[Dim,...], int, slice still remain + + + // we have to count how many times we see a dimension. + // A[i,j] is a simple binding operation, but A[i, i+j] or A[i, i] requires advanced indexing. + Slice> seen_dims; + Slice seen_dims_nuses; + auto add_dim = [&](mpy::hdl entry) { + auto midx = seen_dims.index(entry); + if (!midx) { + seen_dims.append(A, entry); + seen_dims_nuses.append(A, 1); + } else { + ++seen_dims_nuses[*midx]; + } + }; + + Slice input_it = input; + + Slice flat_inputs; + // flat inputs will start with an empty mpy::handle if the + // actual value is in the tensor-like object in the tensor info + Slice tensor_inputs; + + auto append_flat_handle = [&](mpy::handle h) { + flat_inputs.append(A, h); + tensor_inputs.append(A, TensorInfo()); + }; + TensorRef device_holding_tensor; + auto append_tensor_input = [&](TensorInfo ti) { + flat_inputs.append(A, mpy::handle()); + tensor_inputs.append(A, ti); + if (ti.has_device && !device_holding_tensor) { + device_holding_tensor = ti.tensor; + } + }; + + Slice nsz; + Slice nsd; + at::IntArrayRef sz = self_info.tensor->sizes(); + at::IntArrayRef sd = self_info.tensor->strides(); + + auto append_size = [&](int i) { + if (has_dimpacks_or_none) { + nsz.append(A, sz[i]); + nsd.append(A, sd[i]); + } + }; + // std::cout << "self levels: " << self_info.levels << "\n"; + + auto parse_nones = [&]() { + while (input_it.size() && mpy::is_none(input_it[0])) { + append_flat_handle(no_slice); + nsz.append(A, 1); + nsd.append(A, 0); + input_it = input_it.slice(1); + } + }; + + + auto append_item = [&](int i, mpy::handle arg) { + if (Dim::check_exact(arg)) { + auto d = Dim::unchecked_wrap(arg); + d->set_size(sz[i]); + add_dim(d); + append_size(i); + append_flat_handle(arg); + return; + } + auto info = TensorInfo::create(A, arg, false, false); + if (info) { + append_size(i); + append_tensor_input(info); + for (auto il : info.levels) { + if (!il.is_positional()) { + add_dim(il.dim()); + } + } + return; + } + + if (has_dimpacks_or_none) { + Slice mp; + if (maybe_dimpack(mp, arg)) { + // dim pack + Slice> dim_pack; + for (auto d : mp) { + dim_pack.append(A, Dim::wrap(d)); + add_dim(dim_pack.back()); + append_flat_handle(dim_pack.back()); + } + _bind_dims_to_size(A, sz[i], sd[i], dim_pack, nsz, nsd); + return; + } + } + + append_size(i); + append_flat_handle(arg); + }; + + // pair up the indexing expressions with dimension of self it indexes + // self may have first-class dims, which do not participate the indexing. + for (auto i : self_info.levels.enumerate()) { + auto l = self_info.levels[i]; + auto idx = keys.index(l); + if (idx) { + append_item(i, values[*idx]); + } else if (l.is_positional()) { + // grab and index from the positional list + parse_nones(); + if (!input_it.size()) { + // we might have fewer indices than tensor dimensions, + // which implicitly indexes the remaining dimensions with : + append_flat_handle(no_slice); + append_size(i); + } else { + mpy::handle arg = input_it[0]; + input_it = input_it.slice(1); + append_item(i, arg); + } + } else { + add_dim(l.dim()); + append_flat_handle(l.dim()); + append_size(i); + } + } + // any training Nones may have no existing dimension associated with them in self. + parse_nones(); + + // we have to restride the tensor to collapse dimension packs and introduce our none dimensions. + if (has_dimpacks_or_none) { + self_info.tensor = A.autorelease(self_info.tensor->as_strided(at::IntArrayRef(nsz.begin(), nsz.end()),at::IntArrayRef(nsd.begin(), nsd.end()), self_info.tensor->storage_offset())); + } + + + // figure out what the shape of the indexing tensors will be + // and what the shape of the resulting tensor will be + Slice result_levels; + Slice index_levels; + int64_t tensor_insert_point = -1; + bool requires_getindex = false; + auto mark_tensor_index = [&] { + if (tensor_insert_point == -1) { + tensor_insert_point = result_levels.size(); + } else if (tensor_insert_point != result_levels.size()) { + tensor_insert_point = 0; + } + }; + for (auto i : flat_inputs.enumerate()) { + auto inp = flat_inputs[i]; + if(tensor_inputs[i]) { + requires_getindex = true; + mark_tensor_index(); + for (auto l : tensor_inputs[i].levels) { + // std::cout << "Consider to add " << l << "\n"; + if (!index_levels.contains(l)) { + index_levels.append(A, l); + } + } + } else if (Dim::check_exact(inp)) { + auto d = Dim::unchecked_wrap(inp); + // dimensions used once are just binding operations + if (1 == seen_dims_nuses[*seen_dims.index(d)]) { + flat_inputs[i] = no_slice; + result_levels.append(A, d); + } else { + requires_getindex = true; + flat_inputs[i] = mpy::handle(); + tensor_inputs[i] = TensorInfo {d->range(), Slice(A, DimEntry(d)), false, TensorRef()}; + if (!index_levels.contains(d)) { + index_levels.append(A, d); + } + mark_tensor_index(); + } + } else { + if (inp.ptr() != no_slice.ptr()) { + requires_getindex = true; + } + if (!mpy::is_int(inp)) { + // note: actual positional indexes are accurately computed later + result_levels.append(A, -1); + } + } + } + + // indexing dimensions appear in the tensor at the _first use of a tensor_ in the indexing. So insert + // the indexing leveles into the result klevels at this spot + if (tensor_insert_point != -1) { + result_levels.insert(A, result_levels.slice(tensor_insert_point, tensor_insert_point), index_levels); + } + + // std::cout << "flat inputs: " << flat_inputs << "\n"; + // std::cout << "result_levels: " << result_levels << "\n"; + // std::cout << "index_levels: " << index_levels << "\n"; + + // get all the tensors to be the right shape for indexing + if (requires_getindex) { + for (auto i : flat_inputs.enumerate()) { + if (tensor_inputs[i]) { + AT_ASSERT(!flat_inputs[i].ptr()); + // std::cout << "tensor " << i << " " << tensor_inputs[i].levels << "\n"; + TensorRef t = tensor_inputs[i].tensor; + if (!tensor_inputs[i].has_device && device_holding_tensor) { + t = A.autorelease(t->to(device_holding_tensor->device())); + } + flat_inputs[i] = handle_from_tensor(A, _match_levels(A, t, tensor_inputs[i].levels, index_levels)); + } + } + } + + // previously we didn't know how many positional dimensions there would be so we couldn't number them right + // so fill it in now. + auto seen_positionals = 0; + for (auto i : result_levels.reversed_enumerate()) { + if (result_levels[i].is_positional()) { + result_levels[i] = -(++seen_positionals); + } + } + + return IndexingInfo {false, requires_getindex, self_info.tensor, flat_inputs, result_levels, self_info.has_device}; +} +namespace{ +mpy::object __getitem__(Arena & A, mpy::handle self, mpy::handle index) { + maybeInitializeGlobals(); + auto iinfo = getsetitem(A, self, index, has_dims(self)); + if (iinfo.can_call_original) { + return mpy::object::checked_steal(THPVariable_getitem(self.ptr(), index.ptr())); + } + + return invoke_getitem(A, iinfo); +} + + + +void __setitem__(Arena & A, mpy::handle self, mpy::handle index, mpy::handle rhs) { + maybeInitializeGlobals(); + auto iinfo = getsetitem(A, self, index, has_dims(self) || has_dims(rhs)); + if (iinfo.can_call_original) { + if (-1 == THPVariable_setitem(self.ptr(), index.ptr(), rhs.ptr())) { + throw mpy::exception_set(); + } + return; + } + + auto rhs_info = TensorInfo::create(A, rhs, false, false); + if (rhs_info) { // otherwise rhs can be a scalar... + for (auto l : rhs_info.levels) { + if (!iinfo.result_levels.contains(l)) { + if (l.is_positional()) { + mpy::raise_error(DimensionBindError(), "rhs contains too many dimensions (%d) compared to indexed value (%d)", ndim_of_levels(iinfo.result_levels), rhs_info.ndim()); + } else { + auto tup = levels_to_tuple(iinfo.result_levels); + mpy::raise_error(DimensionBindError(), "rhs of setitem contains dimension %R which is not in the dimension on the left (%R)", l.dim().ptr(), tup.ptr()); + } + } + } + auto rhs_matched = _match_levels(A, rhs_info.tensor, rhs_info.levels, iinfo.result_levels); + rhs = handle_from_tensor(A, rhs_matched); + } + self = handle_from_tensor(A, iinfo.self); + + if (iinfo.advanced_indexing) { + auto tup = slice_to_tuple(iinfo.flat_inputs); + if (-1 == THPVariable_setitem(self.ptr(), tup.ptr(), rhs.ptr())) { + throw mpy::exception_set(); + } + } else { + torch_Tensor_copy_.call(self, rhs); + } +} +} + +PyObject* Tensor_getitem(PyObject* self, PyObject* index) { + Arena A; + PY_BEGIN + return __getitem__(A, self, index).release(); + PY_END(nullptr); +} + +int Tensor_setitem(PyObject* self, PyObject* index, PyObject* value) { + Arena A; + PY_BEGIN + __setitem__(A, self, index, value); + return 0; + PY_END(-1); +} + +namespace{ +PyObject* py___getitem__(PyObject *_, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + Arena A; + PY_BEGIN + AT_ASSERT(nargs == 2); + return __getitem__(A, args[0], args[1]).release(); + PY_END(nullptr) +} + +PyObject* py___setitem__(PyObject *_, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + Arena A; + PY_BEGIN + AT_ASSERT(nargs == 3); + __setitem__(A, args[0], args[1], args[2]); + Py_RETURN_NONE; + PY_END(nullptr) +} + + +PyObject* py_index(PyObject *_, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + Arena A; + PY_BEGIN + mpy::vector_args va(args, nargs, kwnames); + mpy::handle self, dims, indices; + va.parse("index", {"self", "dims", "indices"}, {&self, &dims, &indices}, 3); + return index(A, self, dims, indices).release(); + PY_END(nullptr) +} + + +PyObject* py_stack(PyObject *_, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + Arena A; + PY_BEGIN + mpy::vector_args va(args, nargs, kwnames); + mpy::handle tensors, new_dim, dim; + va.parse("stack", {"tensors", "new_dim", "dim"}, {&tensors, &new_dim, &dim}, 2); + + Slice result_levels; + Slice infos; + mpy::sequence_view sv(tensors); + auto new_dim_d = Dim::wrap(new_dim); + for (auto i : sv.enumerate()) { + infos.append(A, TensorInfo::create(A, A.autorelease(sv[i]), false)); + for (auto l : infos.back().levels) { + if (!result_levels.contains(l)) { + result_levels.append(A, l); + } + } + } + new_dim_d->set_size(infos.size()); + std::vector inputs; + inputs.reserve(infos.size()); + for (auto in : infos) { + inputs.emplace_back(*_match_levels(A, in.tensor, in.levels, result_levels)); + } + auto ndim = ndim_of_levels(result_levels); + int64_t rawdim = 0; + if (dim.ptr()) { + auto d = _wrap_dim(dim, ndim, false); + auto idx = result_levels.index(d); + if (!idx) { + mpy::raise_error(PyExc_TypeError, "Dimension %R does not exist in inputs", dim.ptr()); + } + rawdim = *idx; + } + auto result = at::stack(inputs, rawdim); + result_levels.insert(A, rawdim, new_dim_d); + return Tensor::from_positional(A, std::move(result), result_levels, true).release(); + PY_END(nullptr) +} + +PyObject* py_split(PyObject *_, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + Arena A; + PY_BEGIN + maybeInitializeGlobals(); + mpy::vector_args va(args, nargs, kwnames); + mpy::handle self, split_size_or_sections, dim; + va.parse("split", {"self", "split_size_or_sections", "dim"}, {&self, &split_size_or_sections, &dim}, 2); + bool dim_is_object = dim.ptr() && Dim::check_exact(dim); + Slice sizes; + + bool all_dims = true; + bool all_ints = true; + + if (!mpy::is_int(split_size_or_sections)) { + mpy::sequence_view sv(split_size_or_sections); + for (auto i : sv.enumerate()) { + sizes.append(A, A.autorelease(sv[i])); + if (Dim::check_exact(sizes.back())) { + all_ints = false; + } else { + all_dims = false; + } + } + } + if (all_ints) { + if (dim_is_object) { + mpy::raise_error(PyExc_TypeError, "when dim is specified as a Dim object, split sizes must also be dimensions."); + } + // call original split (if self has dimensions this will use torch function to do the split) + return torch_Tensor_split.call_vector(mpy::vector_args(args, nargs, kwnames)).release(); + } + if (!all_dims) { + mpy::raise_error(PyExc_TypeError, "split list must be ints or dims but got a mix"); + } + + auto self_info = TensorInfo::create(A, self, false); + auto ndim = self_info.ndim(); + if (!dim_is_object&& ndim == 0) { + mpy::raise_error(PyExc_TypeError, "split expects at least a 1-dimension tensor"); + } + DimEntry dim_l = dim.ptr() ? _wrap_dim(dim, ndim, false) : -ndim; + + auto idx = self_info.levels.index(dim_l); + if (!idx) { + if (!dim.ptr()) { + dim = A.autorelease(mpy::from_int(0)); + } + mpy::raise_error(PyExc_TypeError, "tensor does not contain dimension %R", dim.ptr()); + } + Slice indices; + + int64_t total_size = 0; + Slice unbound; + for (auto i : sizes.enumerate()) { + auto d = Dim::unchecked_wrap(sizes[i]); + if (d->is_bound()) { + indices.append(A, d->size()); + total_size += indices.back(); + } else { + indices.append(A, 0); + unbound.append(A, i); + } + } + auto tensor_size = self_info.tensor->sizes()[*idx]; + + if (unbound.size()) { + if (total_size > tensor_size) { + mpy::raise_error(PyExc_TypeError, "sizes of target dimensions add up to more (%d) than source dim (%d)", int(total_size), int(tensor_size)); + } + auto remaining_size = tensor_size - total_size; + auto chunk_size = (remaining_size + unbound.size() - 1) / unbound.size(); + for (auto u : unbound) { + auto sz = std::min(chunk_size, remaining_size); + Dim::unchecked_wrap(sizes[u])->set_size(sz); + indices[u] = sz; + remaining_size -= sz; + } + } else if (tensor_size != total_size) { + mpy::raise_error(PyExc_TypeError, "sum of sizes of target dimensions (%d) do not match the than source dim (%d)", int(total_size), int(tensor_size)); + } + + auto result_tensors = self_info.tensor->split_with_sizes(at::IntArrayRef(indices.begin(), indices.end()), *idx); + mpy::tuple result(result_tensors.size()); + Slice new_levels; + new_levels.extend(A, self_info.levels); + for (auto i : sizes.enumerate()) { + new_levels[*idx] = Dim::unchecked_wrap(sizes[i]); + result.set(i, Tensor::from_positional(A, std::move(result_tensors[i]), new_levels, true)); + } + + return result.release(); + + PY_END(nullptr) +} + +Slice _wrap_dims(Arena& A, mpy::handle d, size_t N, bool keepdim) { + auto de = _wrap_dim(d, N, keepdim); + Slice r; + if (!de.is_none()) { + r.append(A, de); + } else { + mpy::sequence_view sq(d); + for (auto i : sq.enumerate()) { + r.append(A, _wrap_dim(A.autorelease(sq[i]), N, keepdim)); + } + } + return r; +} + +struct WrappedOperator : public mpy::base { + mpy::object orig; + PyMethodDef method_def; + mpy::object name, doc; + + bool is_pointwise = false; + int64_t dim_offset = 0; + int64_t keepdim_offset = 1; + std::string dim_name; + bool single_dim = false; + bool reduce = true; + + static PyTypeObject Type; + + void init(mpy::object orig_, PyCFunction wrapper_implementation, std::string dim_name_="") { + orig = std::move(orig_); + method_def.ml_meth = wrapper_implementation; + name = orig.attr("__name__"); + doc = orig.attr("__doc__"); + dim_name = std::move(dim_name_); + if (!mpy::is_none(doc) && !dim_name.empty()) { + doc = mpy::unicode_from_format("%S\nArgument '%s' can be either an integer or a torchdim.Dim object.\n", doc.ptr(), dim_name.c_str()); + } + method_def.ml_name = mpy::is_none(name) ? "" : PyUnicode_AsUTF8(name.ptr()); + method_def.ml_doc = mpy::is_none(doc) ? "" : PyUnicode_AsUTF8(doc.ptr()); + method_def.ml_flags = METH_FASTCALL | METH_KEYWORDS; + } + + mpy::object function() { + return mpy::object::checked_steal(PyCFunction_New(&method_def, ptr())); + } + +}; +} + +PyTypeObject WrappedOperator::Type = { + PyVarObject_HEAD_INIT(NULL, 0) + "_C.WrappedOperator", /* tp_name */ + sizeof(WrappedOperator), /* tp_basicsize */ + 0, /* tp_itemsize */ + WrappedOperator::dealloc_stub, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_as_async */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + "Wrapped Object Holder", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + 0, /* tp_methods */ + 0, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + 0, /* tp_init */ + 0, /* tp_alloc */ + WrappedOperator::new_stub, /* tp_new */ +}; + +namespace{ +PyObject* patched_dim_method(PyObject * self_, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + Arena A; + auto self = WrappedOperator::unchecked_wrap(self_); + PY_BEGIN + + mpy::vector_args va(args, nargs, kwnames); + + auto _getarg = [&](const char* name, int64_t offset_) -> mpy::handle { + auto offset = offset_ + 1; // do not include self + auto idx = va.index(name, offset); + return idx == -1 ? mpy::handle() : va[idx]; + }; + Slice patched_args; + patched_args.extend(A, va.begin(), va.end()); + auto _patcharg = [&](const char* name, int64_t offset_, mpy::handle value) { + auto offset = offset_ + 1; // do not include self + auto idx = va.index(name, offset); + if (idx == -1) { + mpy::raise_error(PyExc_ValueError, "Missing argument %s", name); + } + patched_args[idx] = value; + }; + + auto dim = _getarg(self->dim_name.c_str(), self->dim_offset); + if (!dim.ptr()) { + auto info = TensorInfo::create(A, args[0], true); + EnableAllLayers l(A, info.levels); + l.inplace_update_layers(info.batchedtensor, info.levels); + patched_args[0] = handle_from_tensor(A, info.batchedtensor); + auto r = self->orig.call_vector(patched_args.begin(), nargs, kwnames); + return l.from_batched(A, THPVariable_Unpack(r.ptr()), info.has_device).release(); + } + + auto info = TensorInfo::create(A, args[0]); + auto keepdim = false; + if (self->reduce) { + auto py_keepdim = _getarg("keepdim", self->keepdim_offset); + if (py_keepdim.ptr()) { + keepdim = mpy::to_bool(py_keepdim); + } + } + + auto ndim = info.ndim(); + auto dims = _wrap_dims(A, dim, ndim, keepdim); + Slice dim_indices; + auto seen = A.allocate(info.levels.size()); + std::fill(seen, seen + info.levels.size(), false); + + for (auto d : dims) { + auto midx = info.levels.index(d); + if (!midx) { + auto tup = levels_to_tuple(info.levels); + mpy::raise_error(PyExc_ValueError, "Tensor with dimensions %R does not contain one of %R\n", tup.ptr(), dim.ptr()); + } + seen[*midx] = true; + dim_indices.append(A, *midx); + } + Slice new_levels; + if (self->reduce && !keepdim) { + for (auto i : info.levels.enumerate()) { + if (!seen[i]) { + new_levels.append(A, info.levels[i]); + } + } + } else { + new_levels = info.levels; + } + mpy::object py_indices; + if (dim_indices.size() == 1) { + py_indices = mpy::from_int(dim_indices[0]); + } else { + mpy::tuple tup(dim_indices.size()); + for (auto i : dim_indices.enumerate()) { + tup.set(i, mpy::from_int(dim_indices[i])); + } + py_indices = std::move(tup); + } + _patcharg(self->dim_name.c_str(), self->dim_offset, py_indices); + patched_args[0] = handle_from_tensor(A, info.tensor); + auto r = self->orig.call_vector(patched_args.begin(), nargs, kwnames); + auto wrap = [&](mpy::handle h) { + if (THPVariable_Check(h.ptr())) { + return A.autorelease(Tensor::from_positional(A, THPVariable_Unpack(h.ptr()), new_levels, info.has_device)); + } + return h; + }; + return tree_map(A, wrap, r).release(); + PY_END(nullptr) +} + +PyObject* _wrap(PyObject * self_, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + Arena A; + PY_BEGIN + + #define ARGS(_) _(mpy::handle, orig) _(mpy::handle, dim_offset) _(mpy::handle, keepdim_offset) \ + _(mpy::handle, dim_name) _(mpy::handle, single_dim) _(mpy::handle, reduce) + MPY_PARSE_ARGS_KWNAMES("O|OOOOO", ARGS) + + std::string dim_name_str; + if (dim_name.ptr()) { + dim_name_str = PyUnicode_AsUTF8(dim_name.ptr()); + } else { + dim_name_str = "dim"; + } + auto info = WrappedOperator::create(mpy::object::borrow(orig), (PyCFunction)(void*) patched_dim_method, std::move(dim_name_str)); + if (dim_offset.ptr()) { + info->dim_offset = mpy::to_int(dim_offset); + } + if (keepdim_offset.ptr()) { + info->keepdim_offset = mpy::to_int(keepdim_offset); + } + + if (single_dim.ptr()) { + info->single_dim = mpy::to_bool(single_dim); + } + if (reduce.ptr()) { + info->reduce = mpy::to_bool(reduce); + } + return info->function().release(); + #undef ARGS + + PY_END(nullptr) +} + +PyObject* call_torch_function(PyObject *self, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + PY_BEGIN + Arena A; + maybeInitializeGlobals(); + auto info = WrappedOperator::unchecked_wrap(self); + return __torch_function__(A, info->orig, mpy::vector_args(args, nargs, kwnames), info->is_pointwise).release(); + PY_END(nullptr) +} + +PyObject* _wrap_method(PyObject *self, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + PY_BEGIN + AT_ASSERT(nargs == 2); + // XXX - ignore python function wrapped, we will call torch function directly + mpy::handle orig = args[0]; + if (!pointwise.ptr()) { + auto dim = mpy::import("functorch.dim"); + pointwise = dim.attr("pointwise"); + } + auto info = WrappedOperator::create(mpy::object::borrow(orig), (PyCFunction)(void*) call_torch_function); + info->is_pointwise = pointwise.contains(orig); + return PyInstanceMethod_New(info->function().release()); + PY_END(nullptr); +} + + +PyObject* Tensor_sum(PyObject * self_, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + Arena A; + PY_BEGIN + maybeInitializeGlobals(); + mpy::vector_args va(args, nargs, kwnames); + auto self_ = Tensor::unchecked_wrap(args[0]); + auto d = self_->delayed(); + if (!d) { + return _Tensor_sum.call_vector(va).release(); + } + mpy::handle self, dim, keepdim, dtype; + va.parse("sum", {"self", "dim", "keepdim", "dtype"}, {&self, &dim, &keepdim, &dtype}, 1, 1); + + if (dtype.ptr() || (keepdim.ptr() && mpy::to_bool(keepdim))) { + // std::cout << "SKIPPING fusion because dtype or keepdim=True specified\n"; + return _Tensor_sum.call_vector(va).release(); + } + auto levels = self_->levels(); + + auto N = ndim_of_levels(levels); + auto reduced_dims = _wrap_dims(A, dim, N, false); + + return dot(A, TensorInfo::create(A, d->args[0], false), TensorInfo::create(A, d->args[1], false), reduced_dims).release(); + PY_END(nullptr) +} + +PyObject* _parse_test(PyObject * self_, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + PY_BEGIN + maybeInitializeGlobals(); + + int required = mpy::to_int(args[0]); + int kwonly = mpy::to_int(args[1]); + + mpy::vector_args va(args + 2, nargs - 2, kwnames); + + + mpy::handle a, b, c, d; + va.parse("_parse_test", {"a", "b", "c", "d"}, {&a, &b, &c, &d}, required, kwonly); + mpy::tuple r(4); + r.set(0, mpy::object::borrow(a.ptr() ? a : Py_None)); + r.set(1, mpy::object::borrow(b.ptr() ? b : Py_None)); + r.set(2, mpy::object::borrow(c.ptr() ? c : Py_None)); + r.set(3, mpy::object::borrow(d.ptr() ? d : Py_None)); + return r.release(); + + PY_END(nullptr) +} + +PyObject* _set_pointwise_optimize(PyObject * self_, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + PY_BEGIN + mpy::handle value; + mpy::vector_args va(args, nargs, kwnames); + va.parse("_set_pointwise_optimization", {"value"}, {&value}, 1); + pointwise_optimize = mpy::to_bool(value); + Py_RETURN_NONE; + PY_END(nullptr) +} + +PyObject* _patch_tensor_class(PyObject * self_, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + PY_BEGIN + + auto torch = mpy::import("torch"); + auto py_TensorBase = torch.attr("_C").attr("TensorBase"); + replaceMappingIfMatches(py_TensorBase); + + Py_RETURN_NONE; + PY_END(nullptr) +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const char* dims_doc = R"""( dims(n=None, sizes=None) -> torchdim.Dim or Tuple[torchdim.Dim, ...] @@ -3579,6 +6702,7 @@ Example:: )"""; PyMethodDef methods[] = { +<<<<<<< HEAD {"dims", (PyCFunction)(void*)_dims, METH_FASTCALL | METH_KEYWORDS, @@ -3624,10 +6748,33 @@ PyMethodDef methods[] = { (PyCFunction)(void*)_patch_tensor_class, METH_FASTCALL | METH_KEYWORDS}, {NULL, NULL, 0, NULL} /* Sentinel */ +======= + {"dims", (PyCFunction)(void*) _dims, METH_FASTCALL | METH_KEYWORDS, dims_doc}, + {"dimlists", (PyCFunction)(void*) _dims, METH_FASTCALL | METH_KEYWORDS}, + {"_test_c", (PyCFunction)(void*) test_c, METH_FASTCALL | METH_KEYWORDS}, + {"_wrap_method", (PyCFunction)(void*) _wrap_method, METH_FASTCALL | METH_KEYWORDS}, + {"Tensor_from_positional", (PyCFunction)(void*) py_Tensor_from_positional, METH_FASTCALL | METH_KEYWORDS}, + {"__torch_function__", (PyCFunction)(void*) py___torch_function__, METH_FASTCALL | METH_KEYWORDS}, + {"tree_flatten", (PyCFunction)(void*) py_tree_flatten, METH_FASTCALL | METH_KEYWORDS}, + {"order", (PyCFunction)(void*) order, METH_FASTCALL | METH_KEYWORDS}, + {"index", (PyCFunction)(void*) py_index, METH_FASTCALL | METH_KEYWORDS}, + {"stack", (PyCFunction)(void*) py_stack, METH_FASTCALL | METH_KEYWORDS}, + {"split", (PyCFunction)(void*) py_split, METH_FASTCALL | METH_KEYWORDS}, + {"expand", (PyCFunction)(void*) expand, METH_FASTCALL | METH_KEYWORDS}, + {"__getitem__", (PyCFunction)(void*) py___getitem__, METH_FASTCALL | METH_KEYWORDS}, + {"__setitem__", (PyCFunction)(void*) py___setitem__, METH_FASTCALL | METH_KEYWORDS}, + {"_wrap", (PyCFunction)(void*) _wrap, METH_FASTCALL | METH_KEYWORDS}, + {"Tensor_sum", (PyCFunction)(void*) Tensor_sum, METH_FASTCALL | METH_KEYWORDS}, + {"_parse_test", (PyCFunction)(void*) _parse_test, METH_FASTCALL | METH_KEYWORDS}, + {"_set_pointwise_optimize", (PyCFunction)(void*) _set_pointwise_optimize, METH_FASTCALL | METH_KEYWORDS}, + {"_patch_tensor_class", (PyCFunction)(void*) _patch_tensor_class, METH_FASTCALL | METH_KEYWORDS}, + {NULL, NULL, 0, NULL} /* Sentinel */ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; struct PyModuleDef module_def = { PyModuleDef_HEAD_INIT, +<<<<<<< HEAD "_C", /* name of module */ NULL, /* module documentation, may be NULL */ -1, /* size of per-interpreter state of the module, @@ -3652,6 +6799,32 @@ PyObject* Dim_init() { } catch (mpy::exception_set& err) { return nullptr; } +======= + "_C", /* name of module */ + NULL, /* module documentation, may be NULL */ + -1, /* size of per-interpreter state of the module, + or -1 if the module keeps state in global variables. */ + methods +}; +} + +PyObject* Dim_init() { + Arena A; + try { + mpy::object mod = mpy::object::checked_steal(PyModule_Create(&module_def)); + Dim::ready(mod, "Dim"); + DimList::ready(mod, "DimList"); + Tensor::ready(mod, "Tensor"); + WrappedOperator::ready(mod, "_WrappedOperator"); + Py_INCREF(&PyInstanceMethod_Type); + PyModule_AddObject(mod.ptr(), "_instancemethod", (PyObject *)&PyInstanceMethod_Type); + + initializeGlobals(A); + return mod.release(); + } catch(mpy::exception_set& err) { + return nullptr; + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } #endif diff --git a/functorch/csrc/dim/minpybind.h b/functorch/csrc/dim/minpybind.h index ceced399b40d2..432a5b33e3634 100644 --- a/functorch/csrc/dim/minpybind.h +++ b/functorch/csrc/dim/minpybind.h @@ -602,7 +602,11 @@ struct vector_args { _PyArg_ParseStackAndKeywords((PyObject*const*)args, nargs, kwnames.ptr(), _parser, &dummy, &dummy, &dummy, &dummy, &dummy); #else _PyArg_Parser* _parser = new _PyArg_Parser{NULL, &names_buf[0], fname_cstr, 0}; +<<<<<<< HEAD auto buf = std::make_unique(names.size()); +======= + std::unique_ptr buf(new PyObject*[names.size()]); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _PyArg_UnpackKeywords((PyObject*const*)args, nargs, NULL, kwnames.ptr(), _parser, required, (Py_ssize_t)values.size() - kwonly, 0, &buf[0]); #endif throw exception_set(); diff --git a/functorch/dim/__init__.py b/functorch/dim/__init__.py index 95747181e848e..f6aa42bf9176e 100644 --- a/functorch/dim/__init__.py +++ b/functorch/dim/__init__.py @@ -24,6 +24,13 @@ class DimensionBindError(Exception): # use dict to avoid writing C++ bindings for set pointwise = dict.fromkeys(op_properties.pointwise, True) +<<<<<<< HEAD +======= +use_c = True +if not use_c: + from . import reference + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class _Tensor: # fast path around slow wrapping/unwrapping logic for simply queries used @@ -36,8 +43,17 @@ def dims(self): def dim(self): return self.ndim +<<<<<<< HEAD __torch_function__ = classmethod(_C.__torch_function__) expand = _C._instancemethod(_C.expand) +======= + if use_c: + __torch_function__ = classmethod(_C.__torch_function__) + expand = _C._instancemethod(_C.expand) + else: + __torch_function__ = reference.__torch_function__ + expand = reference.expand +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) index = _C._instancemethod(_C.index) @@ -56,6 +72,11 @@ class Dim(_C.Dim, _Tensor): class Tensor(_Tensor, _C.Tensor): +<<<<<<< HEAD +======= + if not use_c: + from_batched = staticmethod(_C.Tensor_from_batched) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from_positional = staticmethod(_C.Tensor_from_positional) sum = _C._instancemethod(_C.Tensor_sum) @@ -65,6 +86,7 @@ def cat(tensors, dim, new_dim): return stack(tensors, n, dim).index([n, dim], new_dim) +<<<<<<< HEAD _wrap = _C._wrap @@ -76,6 +98,23 @@ def _def(name, *args, **kwargs): t__getitem__ = _C._instancemethod(_C.__getitem__) stack = _C.stack split = _C._instancemethod(_C.split) +======= +if use_c: + _wrap = _C._wrap + + def _def(name, *args, **kwargs): + orig = getattr(torch.Tensor, name) + setattr(_Tensor, name, _C._instancemethod(_wrap(orig, *args, **kwargs))) + + t__getitem__ = _C._instancemethod(_C.__getitem__) + stack = _C.stack + split = _C._instancemethod(_C.split) +else: + _wrap, _def = reference._wrap, reference._def + t__getitem__ = reference.t__getitem__ + stack = reference.stack + split = reference.split +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # note: there is no python reference t__setitem__ = _C._instancemethod(_C.__setitem__) @@ -91,10 +130,20 @@ def _def(name, *args, **kwargs): _Tensor.split = split torch.Tensor.expand = _C._instancemethod(_C.expand) torch.Tensor.index = _C._instancemethod(_C.index) +<<<<<<< HEAD wrap_type(_Tensor, torch.Tensor, _Tensor.__torch_function__) del _Tensor.ndim _Tensor.order = _C._instancemethod(_C.order) +======= +wrap_type(use_c, _Tensor, torch.Tensor, _Tensor.__torch_function__) +del _Tensor.ndim + +if use_c: + _Tensor.order = _C._instancemethod(_C.order) +else: + _Tensor.order = reference.positional +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _def("mean") _def("sum") diff --git a/functorch/dim/batch_tensor.py b/functorch/dim/batch_tensor.py new file mode 100644 index 0000000000000..dae9b270896e9 --- /dev/null +++ b/functorch/dim/batch_tensor.py @@ -0,0 +1,26 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from contextlib import contextmanager + +from torch._C._functorch import _vmap_add_layers, _vmap_remove_layers + + +_enabled = False + + +@contextmanager +def _enable_layers(dims): + global _enabled + assert not _enabled + input = sorted((d._level, d.size) for d in dims if not isinstance(d, int)) + n = len(input) + try: + _vmap_add_layers(input) + _enabled = True + yield + finally: + _enabled = False + _vmap_remove_layers(n) diff --git a/functorch/dim/delayed_mul_tensor.py b/functorch/dim/delayed_mul_tensor.py new file mode 100644 index 0000000000000..3c136cfe1247d --- /dev/null +++ b/functorch/dim/delayed_mul_tensor.py @@ -0,0 +1,76 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import torch + +from . import _Tensor, Tensor +from .reference import _dims, _enable_layers, llist, ltuple + + +class DelayedMulTensor(_Tensor): + def __init__(self, lhs, rhs): + self._lhs, self._rhs = lhs, rhs + self._data = None + self._levels_data = None + self._has_device = lhs._has_device or rhs._has_device + self._batchtensor_data = None + self._tensor_data = None + + @property + def _levels(self): + if self._levels_data is None: + levels = llist(self._lhs._levels) + for l in self._rhs._levels: + if l not in levels: + levels.append(l) + self._levels_data = ltuple(levels) + return self._levels_data + + @property + def _batchtensor(self): + if self._batchtensor_data is None: + with _enable_layers(self._levels): + print("bt multiply fallback") + self._batchtensor_data = self._lhs._batchtensor * self._rhs._batchtensor + return self._batchtensor_data + + @property + def _tensor(self): + if self._tensor_data is None: + self._tensor_data = Tensor.from_batched( + self._batchtensor, self._has_device + )._tensor + return self._tensor_data + + @property + def ndim(self): + return self._batchtensor.ndim + + @property + def dims(self): + return ltuple(super().dims) + + def sum(self, dim): + dims = _dims(dim, 0, False, False) + n = ord("a") + all_levels = self._levels + + def to_char(d): + return chr(n + all_levels.index(d)) + + plhs, levelslhs = self._lhs._tensor, self._lhs._levels + prhs, levelsrhs = self._rhs._tensor, self._rhs._levels + new_levels = [l for l in self._levels if l not in dims] + fmt = "".join( + [ + *(to_char(d) for d in levelslhs), + ",", + *(to_char(d) for d in levelsrhs), + "->", + *(to_char(d) for d in new_levels), + ] + ) + result_data = torch.einsum(fmt, (plhs, prhs)) + return Tensor.from_positional(result_data, new_levels, True) diff --git a/functorch/dim/dim.py b/functorch/dim/dim.py new file mode 100644 index 0000000000000..9a4b568664849 --- /dev/null +++ b/functorch/dim/dim.py @@ -0,0 +1,120 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import dis +import inspect +from dataclasses import dataclass +from typing import Union + +from . import DimList + + +_vmap_levels = [] + + +@dataclass +class LevelInfo: + level: int + alive: bool = True + + +class Dim: + def __init__(self, name: str, size: Union[None, int] = None): + self.name = name + self._size = None + self._vmap_level = None + if size is not None: + self.size = size + + def __del__(self): + if self._vmap_level is not None: + _vmap_active_levels[self._vmap_stack].alive = False # noqa: F821 + while ( + not _vmap_levels[-1].alive and current_level() == _vmap_levels[-1].level # noqa: F821 + ): + _vmap_decrement_nesting() # noqa: F821 + _vmap_levels.pop() + + @property + def size(self): + assert self.is_bound + return self._size + + @size.setter + def size(self, size: int): + from . import DimensionBindError + + if self._size is None: + self._size = size + self._vmap_level = _vmap_increment_nesting(size, "same") # noqa: F821 + self._vmap_stack = len(_vmap_levels) + _vmap_levels.append(LevelInfo(self._vmap_level)) + + elif self._size != size: + raise DimensionBindError( + f"Dim '{self}' previously bound to a dimension of size {self._size} cannot bind to a dimension of size {size}" + ) + + @property + def is_bound(self): + return self._size is not None + + def __repr__(self): + return self.name + + +def extract_name(inst): + assert inst.opname == "STORE_FAST" or inst.opname == "STORE_NAME" + return inst.argval + + +_cache = {} + + +def dims(lists=0): + frame = inspect.currentframe() + assert frame is not None + calling_frame = frame.f_back + assert calling_frame is not None + code, lasti = calling_frame.f_code, calling_frame.f_lasti + key = (code, lasti) + if key not in _cache: + first = lasti // 2 + 1 + instructions = list(dis.get_instructions(calling_frame.f_code)) + unpack = instructions[first] + + if unpack.opname == "STORE_FAST" or unpack.opname == "STORE_NAME": + # just a single dim, not a list + name = unpack.argval + ctor = Dim if lists == 0 else DimList + _cache[key] = lambda: ctor(name=name) + else: + assert unpack.opname == "UNPACK_SEQUENCE" + ndims = unpack.argval + names = tuple( + extract_name(instructions[first + 1 + i]) for i in range(ndims) + ) + first_list = len(names) - lists + _cache[key] = lambda: tuple( + Dim(n) if i < first_list else DimList(name=n) + for i, n in enumerate(names) + ) + return _cache[key]() + + +def _dim_set(positional, arg): + def convert(a): + if isinstance(a, Dim): + return a + else: + assert isinstance(a, int) + return positional[a] + + if arg is None: + return positional + elif not isinstance(arg, (Dim, int)): + return tuple(convert(a) for a in arg) + else: + return (convert(arg),) diff --git a/functorch/dim/reference.py b/functorch/dim/reference.py new file mode 100644 index 0000000000000..fd934011d8238 --- /dev/null +++ b/functorch/dim/reference.py @@ -0,0 +1,645 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# reference python implementations for C ops +import torch +from functorch._C import dim as _C + +from . import op_properties +from .batch_tensor import _enable_layers +from .tree_map import tree_flatten, tree_map + + +DimList = _C.DimList +import operator +from functools import reduce + + +# use dict to avoid writing C++ bindings for set +pointwise = set(op_properties.pointwise) + + +def prod(x): + return reduce(operator.mul, x, 1) + + +def _wrap_dim(d, N, keepdim): + from . import Dim + + if isinstance(d, Dim): + assert not keepdim, "cannot preserve first-class dimensions with keepdim=True" + return d + elif d >= 0: + return d - N + else: + return d + + +def _dims(d, N, keepdim, single_dim): + from . import Dim + + if isinstance(d, (Dim, int)): + return ltuple((_wrap_dim(d, N, keepdim),)) + assert not single_dim, f"expected a single dimension or int but found: {d}" + return ltuple(_wrap_dim(x, N, keepdim) for x in d) + + +def _bind_dims_to_size(lhs_size, rhs, lhs_debug): + from . import DimensionMismatchError + + not_bound = tuple((i, r) for i, r in enumerate(rhs) if not r.is_bound) + if len(not_bound) == 1: + idx, d = not_bound[0] + rhs_so_far = prod(r.size for r in rhs if r.is_bound) + if lhs_size % rhs_so_far != 0: + rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs) + raise DimensionMismatchError( + f"inferred dimension does not evenly fit into larger dimension: {lhs_size} vs {rhs_s}" + ) + new_size = lhs_size // rhs_so_far + d.size = new_size + elif len(not_bound) > 1: + rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs) + raise DimensionMismatchError( + f"cannot infer the size of two dimensions at once: {rhs} with sizes {rhs_s}" + ) + else: + rhs_size = prod(r.size for r in rhs) + if lhs_size != rhs_size: + raise DimensionMismatchError( + f"Dimension sizes to do not match ({lhs_size} != {rhs_size}) when matching {lhs_debug} to {rhs}" + ) + + +def _tensor_levels(inp): + from . import _Tensor + + if isinstance(inp, _Tensor): + return inp._tensor, llist(inp._levels), inp._has_device + else: + return inp, llist(range(-inp.ndim, 0)), True + + +def _match_levels(v, from_levels, to_levels): + view = [] + permute = [] + requires_view = False + size = v.size() + for t in to_levels: + try: + idx = from_levels.index(t) + permute.append(idx) + view.append(size[idx]) + except ValueError: + view.append(1) + requires_view = True + if permute != list(range(len(permute))): + v = v.permute(*permute) + if requires_view: + v = v.view(*view) + return v + + +# make a single dimension positional but do not permute it, +# used to do multi-tensor operators where the dim being acted on +# should not physically move if possible +def _positional_no_permute(self, dim, expand_dim=False): + from . import Tensor + + ptensor, levels = self._tensor, llist(self._levels) + try: + idx = levels.index(dim) + except ValueError: + if not expand_dim: + raise + idx = 0 + ptensor = ptensor.expand(dim.size, *ptensor.size()) + levels.insert(0, 0) + idx_batched = 0 + for i in range(idx): + if isinstance(levels[i], int): + levels[i] -= 1 + idx_batched += 1 + levels[idx] = -idx_batched - 1 + return Tensor.from_positional(ptensor, levels, self._has_device), idx_batched + + +def seq(a, b): + from . import Dim + + if isinstance(a, Dim) != isinstance(b, Dim): + return False + if isinstance(a, Dim): + return a is b + else: + return a == b + + +class isin: + __slots__ = () + + def __contains__(self, item): + for x in self: + if seq(item, x): + return True + return False + + def index(self, item): + for i, x in enumerate(self): + if seq(item, x): + return i + raise ValueError + + +class llist(isin, list): + __slots__ = () + + +class ltuple(isin, tuple): + __slots__ = () + + +empty_dict = {} + + +@classmethod +def __torch_function__(self, orig, cls, args, kwargs=empty_dict): + from . import _Tensor, Tensor, TensorLike + from .delayed_mul_tensor import DelayedMulTensor + + if orig is torch.Tensor.__mul__: + lhs, rhs = args + if ( + isinstance(lhs, _Tensor) + and isinstance(rhs, _Tensor) + and lhs.ndim == 0 + and rhs.ndim == 0 + ): + return DelayedMulTensor(lhs, rhs) + all_dims = llist() + flat_args, unflatten = tree_flatten((args, kwargs)) + device_holding_tensor = None + for f in flat_args: + if isinstance(f, _Tensor): + if f._has_device: + device_holding_tensor = f._batchtensor + for d in f.dims: + if d not in all_dims: + all_dims.append(d) + + def unwrap(t): + if isinstance(t, _Tensor): + r = t._batchtensor + if device_holding_tensor is not None and not t._has_device: + r = r.to(device=device_holding_tensor.device) + return r + return t + + if orig in pointwise: + result_levels = llist() + to_expand = [] + for i, f in enumerate(flat_args): + if isinstance(f, TensorLike): + ptensor, levels, _ = _tensor_levels(f) + if ( + isinstance(f, _Tensor) + and not f._has_device + and device_holding_tensor is not None + ): + ptensor = ptensor.to(device=device_holding_tensor.device) + flat_args[i] = ptensor + for l in levels: + if l not in result_levels: + result_levels.append(l) + to_expand.append((i, levels)) + + for i, levels in to_expand: + flat_args[i] = _match_levels(flat_args[i], levels, result_levels) + args, kwargs = unflatten(flat_args) + result = orig(*args, **kwargs) + + def wrap(t): + if isinstance(t, TensorLike): + return Tensor.from_positional( + t, result_levels, device_holding_tensor is not None + ) + return t + + return tree_map(wrap, result) + else: + + def wrap(t): + if isinstance(t, TensorLike): + return Tensor.from_batched(t, device_holding_tensor is not None) + return t + + with _enable_layers(all_dims): + print(f"batch_tensor for {orig}") + args, kwargs = unflatten(unwrap(f) for f in flat_args) + result = orig(*args, **kwargs) + # print("END", orig) + return tree_map(wrap, result) + + +def positional(self, *dims): + from . import Dim, DimensionBindError, Tensor + + ptensor, levels = self._tensor, llist(self._levels) + flat_dims = llist() + view = [] + needs_view = False + ndim = self.ndim + for d in dims: + if isinstance(d, DimList): + flat_dims.extend(d) + view.extend(e.size for e in d) + elif isinstance(d, Dim): + flat_dims.append(d) + view.append(d.size) + elif isinstance(d, int): + d = _wrap_dim(d, ndim, False) + flat_dims.append(d) + view.append(ptensor.size(d)) + else: + flat_dims.extend(d) + view.append(prod(e.size for e in d)) + needs_view = True + + permute = list(range(len(levels))) + for i, d in enumerate(flat_dims): + try: + idx = levels.index(d) + except ValueError as e: + raise DimensionBindError( + f"tensor of dimensions {self.dims} does not contain dim {d}" + ) from e + p = permute[idx] + del levels[idx] + del permute[idx] + levels.insert(i, 0) + permute.insert(i, p) + ptensor = ptensor.permute(*permute) + seen = 0 + for i in range(len(levels) - 1, -1, -1): + if isinstance(levels[i], int): + seen += 1 + levels[i] = -seen + result = Tensor.from_positional(ptensor, levels, self._has_device) + if needs_view: + result = result.reshape(*view, *result.size()[len(flat_dims) :]) + return result + + +def _contains_dim(input): + from . import Dim + + for i in input: + if isinstance(i, Dim): + return True + + +def expand(self, *sizes): + if not _contains_dim(sizes): + return self.__torch_function__(torch.Tensor.expand, None, (self, *sizes)) + dims = sizes + sizes = [d.size for d in dims] + [-1] * self.ndim + self = self.expand(*sizes) + return self[dims] + + +_not_present = object() + + +def _getarg(name, offset, args, kwargs, default): + if len(args) > offset: + return args[offset] + return kwargs.get(name, default) + + +def _patcharg(name, offset, args, kwargs, value): + if len(args) > offset: + args[offset] = value + else: + kwargs[name] = value + + +def _wrap( + orig, dim_offset=0, keepdim_offset=1, dim_name="dim", single_dim=False, reduce=True +): + from . import Dim, Tensor, TensorLike + + def fn(self, *args, **kwargs): + dim = _getarg(dim_name, dim_offset, args, kwargs, _not_present) + if dim is _not_present or (single_dim and not isinstance(dim, Dim)): + with _enable_layers(self.dims): + print(f"dim fallback batch_tensor for {orig}") + return Tensor.from_batched( + orig(self._batchtensor, *args, **kwargs), self._has_device + ) + keepdim = ( + _getarg("keepdim", keepdim_offset, args, kwargs, False) if reduce else False + ) + t, levels = self._tensor, llist(self._levels) + dims = _dims(dim, self._batchtensor.ndim, keepdim, single_dim) + dim_indices = tuple(levels.index(d) for d in dims) + if reduce and not keepdim: + new_levels = [l for i, l in enumerate(levels) if i not in dim_indices] + else: + new_levels = levels + + if len(dim_indices) == 1: + dim_indices = dim_indices[ + 0 + ] # so that dims that really only take a single argument work... + args = list(args) + _patcharg(dim_name, dim_offset, args, kwargs, dim_indices) + + def wrap(t): + if isinstance(t, TensorLike): + return Tensor.from_positional(t, new_levels, self._has_device) + return t + + with _enable_layers(new_levels): + print(f"dim used batch_tensor for {orig}") + r = orig(t, *args, **kwargs) + return tree_map(wrap, r) + + return fn + + +def _def(name, *args, **kwargs): + from . import _Tensor + + orig = getattr(torch.Tensor, name) + setattr(_Tensor, name, _wrap(orig, *args, **kwargs)) + + +no_slice = slice(None) + +_orig_getitem = torch.Tensor.__getitem__ + + +class dim_tracker: + def __init__(self) -> None: + self.dims = llist() + self.count = [] + + def record(self, d): + if d not in self.dims: + self.dims.append(d) + self.count.append(1) + + def __getitem__(self, d): + return self.count[self.dims.index(d)] + + +def t__getitem__(self, input): + from . import _Tensor, Dim, DimensionBindError, DimList, Tensor, TensorLike + + # * bail to original example if we have a single non-Dim tensor, or a non-tensor + # * locate ... or an unbound tensor list, and determine its size, bind dim list + # (remember that None does not count to the total dim count) + # * bind simple dims and dim-packs to their sizes, count the number of uses of each dim, + # produce the re-view if needed + # * for each single-use dim index, replace with no_slice and mark that it will be added + # (keep track of whether we have to call super) + # * call super if needed + # * if we have dims to bind, bind them (it will help if we eliminated ... and None before) + # this handles bool indexing handling, as well as some other simple cases. + + is_simple = ( + not isinstance(input, Dim) + and not isinstance(input, (tuple, list)) + and + # WAR for functorch bug where zero time tensors in getitem are not handled correctly. + not (isinstance(input, TensorLike) and input.ndim == 0) + ) + + if is_simple: + if isinstance(self, _Tensor): + return _Tensor.__torch_function__(_orig_getitem, None, (self, input)) + else: + return _orig_getitem(self, input) + + # can further optimize this case + if not isinstance(input, tuple): + input = [input] + else: + input = list(input) + + dims_indexed = 0 + expanding_object = None + dimlists = [] + for i, s in enumerate(input): + if s is ... or isinstance(s, DimList) and not s.is_bound: + if expanding_object is not None: + msg = ( + "at most one ... or unbound dimension list can exist in indexing list but" + f" found 2 at offsets {i} and {expanding_object}" + ) + raise DimensionBindError(msg) + expanding_object = i + + if isinstance(s, DimList): + dims_indexed += len(s) if s.is_bound else 0 + dimlists.append(i) + elif s is not None and s is not ...: + dims_indexed += 1 + + ndim = self.ndim + if dims_indexed > ndim: + raise IndexError( + f"at least {dims_indexed} indices were supplied but the tensor only has {ndim} dimensions." + ) + if expanding_object is not None: + expanding_ndims = ndim - dims_indexed + obj = input[expanding_object] + if obj is ...: + input[expanding_object : expanding_object + 1] = [ + no_slice + ] * expanding_ndims + else: + obj.bind_len(expanding_ndims) + # flatten the dimslists into the indexing + for i in reversed(dimlists): + input[i : i + 1] = input[i] + dims_indexed = 0 + requires_view = False + size = self.size() + view_sizes = [] + dims_seen = dim_tracker() + + def add_dims(t): + if not isinstance(t, _Tensor): + return + for d in t.dims: + dims_seen.record(d) + + add_dims(self) + dim_packs = [] + for i, idx in enumerate(input): + if idx is None: + input[i] = no_slice + view_sizes.append(1) + requires_view = True + else: + sz = size[dims_indexed] + if isinstance(idx, Dim): + idx.size = sz + dims_seen.record(idx) + view_sizes.append(sz) + elif isinstance(idx, (tuple, list)) and idx and isinstance(idx[0], Dim): + for d in idx: + dims_seen.record(idx) + _bind_dims_to_size(sz, idx, f"offset {i}") + view_sizes.extend(d.size for d in idx) + requires_view = True + dim_packs.append(i) + else: + add_dims(idx) + view_sizes.append(sz) + dims_indexed += 1 + if requires_view: + self = self.view(*view_sizes) + for i in reversed(dim_packs): + input[i : i + 1] = input[i] + + # currently: + # input is flat, containing either Dim, or Tensor, or something valid for standard indexing + # self may have first-class dims as well. + + # to index: + # drop the first class dims from self, they just become direct indices of their positions + + # figure out the dimensions of the indexing tensors: union of all the dims in the tensors in the index. + # these dimensions will appear and need to be bound at the first place tensor occurs + + if isinstance(self, _Tensor): + ptensor_self, levels = self._tensor, list(self._levels) + # indices to ptensor rather than self which has first-class dimensions + input_it = iter(input) + flat_inputs = [next(input_it) if isinstance(l, int) else l for l in levels] + has_device = self._has_device + to_pad = 0 + else: + ptensor_self, flat_inputs = self, input + to_pad = ptensor_self.ndim - len(flat_inputs) + has_device = True + + result_levels = [] + index_levels = [] + tensor_insert_point = None + to_expand = {} + requires_getindex = False + for i, inp in enumerate(flat_inputs): + if isinstance(inp, Dim) and dims_seen[inp] == 1: + flat_inputs[i] = no_slice + result_levels.append(inp) + elif isinstance(inp, TensorLike): + requires_getindex = True + if tensor_insert_point is None: + tensor_insert_point = len(result_levels) + ptensor, levels, _ = _tensor_levels(inp) + to_expand[i] = levels + flat_inputs[i] = ptensor + for l in levels: + if l not in index_levels: + index_levels.append(l) + else: + requires_getindex = True + result_levels.append(0) + + if tensor_insert_point is not None: + result_levels[tensor_insert_point:tensor_insert_point] = index_levels + + for i, levels in to_expand.items(): + flat_inputs[i] = _match_levels(flat_inputs[i], levels, index_levels) + + if requires_getindex: + result = _orig_getitem(ptensor_self, flat_inputs) + else: + result = ptensor_self + + next_positional = -1 + if to_pad > 0: + result_levels.extend([0] * to_pad) + for i, r in enumerate(reversed(result_levels)): + if isinstance(r, int): + result_levels[-1 - i] = next_positional + next_positional -= 1 + + return Tensor.from_positional(result, result_levels, has_device) + + +# XXX - dim is optional and can be the outer-most dimension... +def stack(tensors, new_dim, dim=0, out=None): + if isinstance(dim, int): + return torch.stack(tensors, dim, out).index(dim, new_dim) + index = None + if out is not None: + out, index = _positional_no_permute(out, dim, expand_dim=True) + ptensors = [] + for t in tensors: + pt, pi = _positional_no_permute(t, dim, expand_dim=True) + if index is not None and pi != index: + pt = pt.move_dim(pi, index) + else: + index = pi + ptensors.append(pt) + pr = torch.stack(ptensors, index, out=out) + return pr.index((index, index + 1), (new_dim, dim)) + + +_orig_split = torch.Tensor.split + + +def split(self, split_size_or_sections, dim=0): + from . import _Tensor, Dim + + if isinstance(split_size_or_sections, int) or any( + isinstance(t, int) for t in split_size_or_sections + ): + if isinstance(dim, Dim): + raise ValueError( + "when dim is specified as a Dim object, split sizes must also be dimensions." + ) + return _orig_split(self, split_size_or_sections, dim=dim) + + if isinstance(dim, Dim): + assert isinstance(self, _Tensor), f"Tensor does not have dimension {dim}" + self, dim = _positional_no_permute(self, dim) + + size = self.size(dim) + total_bound_size = 0 + unbound = [] + sizes = [] + for i, d in enumerate(split_size_or_sections): + if d.is_bound: + sizes.append(d.size) + total_bound_size += d.size + else: + sizes.append(0) + unbound.append(i) + + if unbound: + assert total_bound_size <= size, ( + f"result dimensions are larger than original: {total_bound_size} vs {size} ({split_size_or_sections})" + ) + remaining_size = size - total_bound_size + chunk_size = -(-remaining_size // len(unbound)) + for u in unbound: + sz = min(chunk_size, remaining_size) + split_size_or_sections[u].size = sz + sizes[u] = sz + remaining_size -= sz + else: + assert total_bound_size == size, ( + f"result dimensions do not match original: {total_bound_size} vs {size} ({split_size_or_sections})" + ) + return tuple( + t.index(dim, d) + for d, t in zip(split_size_or_sections, _orig_split(self, sizes, dim=dim)) + ) diff --git a/functorch/dim/wrap_type.py b/functorch/dim/wrap_type.py index b9ebda47c4cfe..2e3fab10787d2 100644 --- a/functorch/dim/wrap_type.py +++ b/functorch/dim/wrap_type.py @@ -26,8 +26,23 @@ PROPERTY_TYPES = (GetSetDescriptorType, property) +<<<<<<< HEAD def wrap_type(to_patch, pattern, __torch_function__): wrap_method = _wrap_method +======= +def _py_wrap_method(orig, __torch_function__): + def impl(*args, **kwargs): + return __torch_function__(orig, None, args, kwargs) + + return impl + + +def wrap_type(use_c, to_patch, pattern, __torch_function__): + if use_c: + wrap_method = _wrap_method + else: + wrap_method = _py_wrap_method +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) all = {} for t in reversed(pattern.mro()[:-1]): # skip object diff --git a/functorch/experimental/__init__.py b/functorch/experimental/__init__.py index 0500fc2c29d35..c4f55416592c5 100644 --- a/functorch/experimental/__init__.py +++ b/functorch/experimental/__init__.py @@ -1,5 +1,12 @@ # PyTorch forward-mode is not mature yet +<<<<<<< HEAD from torch._functorch.apis import chunk_vmap from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_ from torch._functorch.eager_transforms import hessian, jacfwd, jvp from torch.func import functionalize +======= +from functorch import functionalize +from torch._functorch.apis import chunk_vmap +from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_ +from torch._functorch.eager_transforms import hessian, jacfwd, jvp +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/functorch/writing_batching_rules.md b/functorch/writing_batching_rules.md index 61872c8d52327..14e56b86f35f3 100644 --- a/functorch/writing_batching_rules.md +++ b/functorch/writing_batching_rules.md @@ -5,7 +5,11 @@ First off, what are batching rules and why do we need so many of them? Well, to ### How does vmap work? Vmap is a function transform (pioneered by Jax) that allows one to batch functions. That is, given a function `f(x: [N]) -> [N]`, `vmap(f)` now transforms the signature to be `f(x: [B, N]) -> [B, N]`. That is - it adds a batch dimension to both the input and the output of the function. +<<<<<<< HEAD This guide will gloss over all the cool things you can do with this (there are many!), so let's focus on how we actually implement this. +======= +This guide will gloss over all the cool things you can do this (there are many!), so let's focus on how we actually implement this. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) One misconception is that this is some magic compiler voodoo, or that it is inherently some function transform. It is not - and there's another framing of it that might make it more clear. diff --git a/mypy.ini b/mypy.ini index e6a8af4c88c20..6cd883a0576c1 100644 --- a/mypy.ini +++ b/mypy.ini @@ -55,6 +55,12 @@ python_version = 3.11 # Extension modules without stubs. # +<<<<<<< HEAD +======= +[mypy-torch._C._jit_tree_views] +ignore_missing_imports = True + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) [mypy-torch.for_onnx.onnx] ignore_missing_imports = True diff --git a/pt_template_srcs.bzl b/pt_template_srcs.bzl index 84f5f8bd3e627..441bdeb70bfc4 100644 --- a/pt_template_srcs.bzl +++ b/pt_template_srcs.bzl @@ -156,7 +156,10 @@ def get_generate_code_bin_outs(): "autograd/generated/python_torch_functions_1.cpp": ["autograd/generated/python_torch_functions_1.cpp"], "autograd/generated/python_torch_functions_2.cpp": ["autograd/generated/python_torch_functions_2.cpp"], "autograd/generated/python_variable_methods.cpp": ["autograd/generated/python_variable_methods.cpp"], +<<<<<<< HEAD "functionalization/generated/ViewMetaClassesPythonBinding.cpp": ["functionalization/generated/ViewMetaClassesPythonBinding.cpp"], +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }) return outs diff --git a/pyproject.toml b/pyproject.toml index 925742b4c3344..fe1b57c429824 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,4 @@ +<<<<<<< HEAD # Package ###################################################################### [build-system] @@ -83,11 +84,27 @@ classifiers = [ dynamic = [ "entry-points", "dependencies", +======= +[project] +name = "torch" +requires-python = ">=3.9" +license = {text = "BSD-3-Clause"} +dynamic = [ + "authors", + "classifiers", + "entry-points", + "dependencies", + "description", + "keywords", + "optional-dependencies", + "readme", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "scripts", "version", ] [project.urls] +<<<<<<< HEAD Homepage = "https://pytorch.org" Repository = "https://github.com/pytorch/pytorch" Documentation = "https://pytorch.org/docs" @@ -100,6 +117,36 @@ opt-einsum = ["opt-einsum>=3.3"] pyyaml = ["pyyaml"] # Linter tools ################################################################# +======= +Homepage = "https://pytorch.org/" +Documentation = "https://pytorch.org/docs/" +Source = "https://github.com/pytorch/pytorch" +Forum = "https://discuss.pytorch.org/" + + +[build-system] +requires = [ + # After 75.8.2 dropped dep disttools API. Please fix + # API temporarily restored and shim used. Please fix + # Setuptools will drop support for setup.py past 80 + # min version for recursive glob package data support + "setuptools>=62.3.0,<80.0", + "wheel", + "astunparse", + "numpy", + "ninja", + "pyyaml", + "cmake", + "typing-extensions>=4.10.0", + "requests", +] +# Use legacy backend to import local packages in setup.py +build-backend = "setuptools.build_meta:__legacy__" + + +[tool.black] +line-length = 88 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) [tool.isort] src_paths = ["caffe2", "torch", "torchgen", "functorch", "test"] @@ -115,10 +162,18 @@ multi_line_output = 3 include_trailing_comma = true combine_as_imports = true +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) [tool.usort.known] first_party = ["caffe2", "torch", "torchgen", "functorch", "test"] standard_library = ["typing_extensions"] +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) [tool.ruff] line-length = 88 src = ["caffe2", "torch", "torchgen", "functorch", "test"] @@ -159,7 +214,10 @@ ignore = [ "E741", "EXE001", "F405", +<<<<<<< HEAD "FURB122", # writelines +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # these ignores are from flake8-logging-format; please fix! "G101", # these ignores are from ruff NPY; please fix! @@ -182,6 +240,7 @@ ignore = [ "SIM117", "SIM118", "UP007", # keep-runtime-typing +<<<<<<< HEAD "UP045", # keep-runtime-typing "TC006", # TODO: Remove Python-3.10 specific suppressions @@ -191,6 +250,9 @@ ignore = [ "UP038", "UP041", "FURB161", +======= + "TC006", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] select = [ "B", @@ -271,10 +333,13 @@ select = [ "YTT", ] +<<<<<<< HEAD [tool.ruff.lint.pyupgrade] # Preserve types, even if a file imports `from __future__ import annotations`. keep-runtime-typing = true +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) [tool.ruff.lint.per-file-ignores] "__init__.py" = [ "F401", diff --git a/pyrefly.toml b/pyrefly.toml index 6b94aeb5c1ca5..f8085b0900e69 100644 --- a/pyrefly.toml +++ b/pyrefly.toml @@ -1,4 +1,8 @@ +<<<<<<< HEAD project-includes = [ +======= +project_includes = [ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "torch", "caffe2", "test/test_bundled_images.py", @@ -7,11 +11,19 @@ project-includes = [ "test/test_datapipe.py", "test/test_futures.py", "test/test_numpy_interop.py", +<<<<<<< HEAD +======= + "test/test_torch.py", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "test/test_type_hints.py", "test/test_type_info.py", "test/test_utils.py", ] +<<<<<<< HEAD project-excludes = [ +======= +project_excludes = [ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "torch/include/**", "torch/csrc/**", "torch/distributed/elastic/agent/server/api.py", @@ -26,7 +38,11 @@ project-excludes = [ "*/__pycache__/**", "*/.*", ] +<<<<<<< HEAD ignore-missing-imports = [ +======= +replace_imports_with_any = [ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "torch._C._jit_tree_views.*", "torch.for_onnx.onnx.*", "torch.ao.quantization.experimental.apot_utils.*", @@ -84,6 +100,7 @@ ignore-missing-imports = [ "redis.*" ] +<<<<<<< HEAD untyped_def_behavior = "check-and-infer-return-any" # Shut off noisy errors @@ -97,3 +114,6 @@ errors.implicit-import = false # [[tool.pyrefly.sub-config]] # matches = "test/test_torch.py" # untyped-def-behavior = "skip-and-infer-return-any" +======= +untyped_def_behavior = "check-and-infer-return-any" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/related_commits b/related_commits index b96cf18c181ab..01dd645c3dbf6 100644 --- a/related_commits +++ b/related_commits @@ -1,3 +1,4 @@ +<<<<<<< HEAD ubuntu|pytorch|apex|release/1.9.0|07c3ee5347294b7a07a65c2c3596f1b14c7d3daa|https://github.com/ROCm/apex centos|pytorch|apex|release/1.9.0|07c3ee5347294b7a07a65c2c3596f1b14c7d3daa|https://github.com/ROCm/apex ubuntu|pytorch|torchvision|release/0.24|b919bd0c56abbb3c5ca056a3a458af9fd1cabf52|https://github.com/pytorch/vision @@ -8,3 +9,15 @@ ubuntu|pytorch|torchaudio|release/2.9|e3c6ee2b6588b7cd27a84182de74bf12fe043831|h centos|pytorch|torchaudio|release/2.9|e3c6ee2b6588b7cd27a84182de74bf12fe043831|https://github.com/pytorch/audio ubuntu|pytorch|ao|main|a52a64aeb84fa6ff683ec2c7c42b97e27651a619|https://github.com/pytorch/ao centos|pytorch|ao|main|a52a64aeb84fa6ff683ec2c7c42b97e27651a619|https://github.com/pytorch/ao +======= +ubuntu|pytorch|apex|release/1.8.0|3f26640cff501d67d35acf424ed2566d50949f5b|https://github.com/ROCm/apex +centos|pytorch|apex|release/1.8.0|3f26640cff501d67d35acf424ed2566d50949f5b|https://github.com/ROCm/apex +ubuntu|pytorch|torchvision|release/0.23|824e8c8726b65fd9d5abdc9702f81c2b0c4c0dc8|https://github.com/pytorch/vision +centos|pytorch|torchvision|release/0.23|824e8c8726b65fd9d5abdc9702f81c2b0c4c0dc8|https://github.com/pytorch/vision +ubuntu|pytorch|torchdata|release/0.11|377e64c1be69a9be6649d14c9e3664070323e464|https://github.com/pytorch/data +centos|pytorch|torchdata|release/0.11|377e64c1be69a9be6649d14c9e3664070323e464|https://github.com/pytorch/data +ubuntu|pytorch|torchaudio|release/2.8|6e1c7fe9ff6d82b8665d0a46d859d3357d2ebaaa|https://github.com/pytorch/audio +centos|pytorch|torchaudio|release/2.8|6e1c7fe9ff6d82b8665d0a46d859d3357d2ebaaa|https://github.com/pytorch/audio +ubuntu|pytorch|ao|main|a96eeb1c7d7ba24cf0ccfc105141729acfed22bf|https://github.com/pytorch/ao +centos|pytorch|ao|main|a96eeb1c7d7ba24cf0ccfc105141729acfed22bf|https://github.com/pytorch/ao +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/requirements.txt b/requirements.txt index 824ca112602a0..b79d4fd0d24ea 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ # Python dependencies required for development +<<<<<<< HEAD # Build System requirements --requirement requirements-build.txt @@ -11,12 +12,35 @@ fsspec==2025.9.0 hypothesis==5.35.1 Jinja2==3.1.6 lintrunner==0.12.7 ; platform_machine != "s390x" and platform_machine != "riscv64" +======= +astunparse==1.6.3 +cmake>=3.31.4 +expecttest==0.3.0 +filelock==3.18.0 +fsspec==2025.7.0 +hypothesis==5.35.1 +jinja2==3.1.6 +lintrunner==0.12.7 ; platform_machine != "s390x" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) networkx==2.8.8 ninja==1.11.1.3 numpy==2.0.2 ; python_version == "3.9" numpy==2.1.2 ; python_version > "3.9" optree==0.13.0 +<<<<<<< HEAD psutil==7.1.0 sympy==1.13.3 typing_extensions==4.15.0 wheel +======= +packaging==25.0 +psutil==7.0.0 +pyyaml==6.0.2 +requests==2.32.4 +# setuptools develop deprecated on 80.0 +# issue on Windows after >= 75.8.2 - https://github.com/pytorch/pytorch/issues/148877 +setuptools==75.8.2 +sympy==1.13.3 +types-dataclasses==0.6.6 +typing-extensions==4.14.1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/scripts/README.md b/scripts/README.md index 367e7261f6a60..a63516bf8cb7c 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -1 +1,43 @@ This directory contains the useful tools. +<<<<<<< HEAD +======= + + +## build_android.sh +This script is to build PyTorch/Caffe2 library for Android. Take the following steps to start the build: + +- set ANDROID_NDK to the location of ndk + +```bash +export ANDROID_NDK=YOUR_NDK_PATH +``` + +- run build_android.sh +```bash +#in your PyTorch root directory +bash scripts/build_android.sh +``` +If succeeded, the libraries and headers would be generated to build_android/install directory. You can then copy these files from build_android/install to your Android project for further usage. + +You can also override the cmake flags via command line, e.g., following command will also compile the executable binary files: +```bash +bash scripts/build_android.sh -DBUILD_BINARY=ON +``` + +## build_ios.sh +This script is to build PyTorch/Caffe2 library for iOS, and can only be performed on macOS. Take the following steps to start the build: + +- Install Xcode from App Store, and configure "Command Line Tools" properly on Xcode. +- Install the dependencies: + +```bash +brew install cmake automake libtool +``` + +- run build_ios.sh +```bash +#in your PyTorch root directory +bash scripts/build_ios.sh +``` +If succeeded, the libraries and headers would be generated to build_ios/install directory. You can then copy these files to your Xcode project for further usage. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/scripts/add_apache_header.sh b/scripts/add_apache_header.sh new file mode 100755 index 0000000000000..a29a059d2d033 --- /dev/null +++ b/scripts/add_apache_header.sh @@ -0,0 +1 @@ +cat apache_header.txt $1 > _add_apache_header.txt && mv _add_apache_header.txt $1 diff --git a/scripts/apache_header.txt b/scripts/apache_header.txt new file mode 100644 index 0000000000000..b4eff258eb04d --- /dev/null +++ b/scripts/apache_header.txt @@ -0,0 +1,15 @@ +/** + * Copyright (c) 2016-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ diff --git a/scripts/apache_python.txt b/scripts/apache_python.txt new file mode 100644 index 0000000000000..bc104d8845154 --- /dev/null +++ b/scripts/apache_python.txt @@ -0,0 +1,14 @@ +# Copyright (c) 2016-present, Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## diff --git a/scripts/build_android.sh b/scripts/build_android.sh new file mode 100755 index 0000000000000..43f11b86828d4 --- /dev/null +++ b/scripts/build_android.sh @@ -0,0 +1,189 @@ +#!/bin/bash +############################################################################## +# Example command to build the android target. +############################################################################## +# +# This script shows how one can build a Caffe2 binary for the Android platform +# using android-cmake. A few notes: +# +# (1) This build also does a host build for protobuf. You will need autoconf +# to carry out this. If autoconf is not possible, you will need to provide +# a pre-built protoc binary that is the same version as the protobuf +# version under third_party. +# If you are building on Mac, you might need to install autotool and +# libtool. The easiest way is via homebrew: +# brew install automake +# brew install libtool +# (2) You will need to have android ndk installed. The current script assumes +# that you set ANDROID_NDK to the location of ndk. +# (3) The toolchain and the build target platform can be specified with the +# cmake arguments below. For more details, check out android-cmake's doc. + +set -e + +# Android specific flags +if [ -z "$ANDROID_ABI" ]; then + ANDROID_ABI="armeabi-v7a with NEON" +fi +ANDROID_NATIVE_API_LEVEL="21" +echo "Build with ANDROID_ABI[$ANDROID_ABI], ANDROID_NATIVE_API_LEVEL[$ANDROID_NATIVE_API_LEVEL]" + +CAFFE2_ROOT="$( cd "$(dirname "$0")"/.. ; pwd -P)" +if [ -z "$ANDROID_NDK" ]; then + echo "ANDROID_NDK not set; please set it to the Android NDK directory" + exit 1 +fi + +if [ ! -d "$ANDROID_NDK" ]; then + echo "ANDROID_NDK not a directory; did you install it under $ANDROID_NDK?" + exit 1 +fi + +if [ -z "$PYTHON" ]; then + PYTHON=python + PYTHON_VERSION_MAJOR=$($PYTHON -c 'import sys; print(sys.version_info[0])') + if [ "${PYTHON_VERSION_MAJOR}" -le 2 ]; then + echo "Default python executable is Python-2, trying to use python3 alias" + PYTHON=python3 + fi +fi + +ANDROID_NDK_PROPERTIES="$ANDROID_NDK/source.properties" +[ -f "$ANDROID_NDK_PROPERTIES" ] && ANDROID_NDK_VERSION=$(sed -n 's/^Pkg.Revision[^=]*= *\([0-9]*\)\..*$/\1/p' "$ANDROID_NDK_PROPERTIES") + +echo "Bash: $(/bin/bash --version | head -1)" +echo "Python: $($PYTHON -c 'import sys; print(sys.version)')" +echo "Caffe2 path: $CAFFE2_ROOT" +echo "Using Android NDK at $ANDROID_NDK" +echo "Android NDK version: $ANDROID_NDK_VERSION" + +CMAKE_ARGS=() + +# Build PyTorch mobile +CMAKE_ARGS+=("-DCMAKE_PREFIX_PATH=$($PYTHON -c 'import sysconfig; print(sysconfig.get_path("purelib"))')") +CMAKE_ARGS+=("-DPython_EXECUTABLE=$($PYTHON -c 'import sys; print(sys.executable)')") +CMAKE_ARGS+=("-DBUILD_CUSTOM_PROTOBUF=OFF") + +# custom build with selected ops +if [ -n "${SELECTED_OP_LIST}" ]; then + SELECTED_OP_LIST="$(cd $(dirname $SELECTED_OP_LIST); pwd -P)/$(basename $SELECTED_OP_LIST)" + echo "Choose SELECTED_OP_LIST file: $SELECTED_OP_LIST" + if [ ! -r ${SELECTED_OP_LIST} ]; then + echo "Error: SELECTED_OP_LIST file ${SELECTED_OP_LIST} not found." + exit 1 + fi + CMAKE_ARGS+=("-DSELECTED_OP_LIST=${SELECTED_OP_LIST}") +fi + +# If Ninja is installed, prefer it to Make +if [ -x "$(command -v ninja)" ]; then + CMAKE_ARGS+=("-GNinja") +fi + +# Use android-cmake to build Android project from CMake. +CMAKE_ARGS+=("-DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake") + +if [ -z "$BUILD_MOBILE_BENCHMARK" ]; then + BUILD_MOBILE_BENCHMARK=0 +fi + +if [ -z "$BUILD_MOBILE_TEST" ]; then + BUILD_MOBILE_TEST=0 +fi +# Don't build artifacts we don't need +CMAKE_ARGS+=("-DBUILD_TEST=OFF") +CMAKE_ARGS+=("-DBUILD_BINARY=OFF") + +# If there exists env variable and it equals to 0, build full jit interpreter. +# Default behavior is to build lite interpreter +# cmd: BUILD_LITE_INTERPRETER=0 ./scripts/build_android.sh +if [ "${BUILD_LITE_INTERPRETER}" == 0 ]; then + CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=OFF") +else + CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=ON") +fi +if [ "${TRACING_BASED}" == 1 ]; then + CMAKE_ARGS+=("-DTRACING_BASED=ON") +else + CMAKE_ARGS+=("-DTRACING_BASED=OFF") +fi +if [ "${USE_LIGHTWEIGHT_DISPATCH}" == 1 ]; then + CMAKE_ARGS+=("-DUSE_LIGHTWEIGHT_DISPATCH=ON") + CMAKE_ARGS+=("-DSTATIC_DISPATCH_BACKEND=CPU") +else + CMAKE_ARGS+=("-DUSE_LIGHTWEIGHT_DISPATCH=OFF") +fi + +CMAKE_ARGS+=("-DBUILD_MOBILE_BENCHMARK=$BUILD_MOBILE_BENCHMARK") +CMAKE_ARGS+=("-DBUILD_MOBILE_TEST=$BUILD_MOBILE_TEST") +CMAKE_ARGS+=("-DBUILD_PYTHON=OFF") +CMAKE_ARGS+=("-DBUILD_SHARED_LIBS=OFF") +if (( "${ANDROID_NDK_VERSION:-0}" < 18 )); then + CMAKE_ARGS+=("-DANDROID_TOOLCHAIN=gcc") +else + CMAKE_ARGS+=("-DANDROID_TOOLCHAIN=clang") +fi +# Disable unused dependencies +CMAKE_ARGS+=("-DUSE_CUDA=OFF") +CMAKE_ARGS+=("-DUSE_ITT=OFF") +CMAKE_ARGS+=("-DUSE_GFLAGS=OFF") +CMAKE_ARGS+=("-DUSE_OPENCV=OFF") +CMAKE_ARGS+=("-DUSE_MPI=OFF") +CMAKE_ARGS+=("-DUSE_OPENMP=OFF") +# Only toggle if VERBOSE=1 +if [ "${VERBOSE:-}" == '1' ]; then + CMAKE_ARGS+=("-DCMAKE_VERBOSE_MAKEFILE=1") +fi + +# Android specific flags +CMAKE_ARGS+=("-DANDROID_NDK=$ANDROID_NDK") +CMAKE_ARGS+=("-DANDROID_ABI=$ANDROID_ABI") +CMAKE_ARGS+=("-DANDROID_NATIVE_API_LEVEL=$ANDROID_NATIVE_API_LEVEL") +CMAKE_ARGS+=("-DANDROID_CPP_FEATURES=rtti exceptions") +if [ "${ANDROID_STL_SHARED:-}" == '1' ]; then + CMAKE_ARGS+=("-DANDROID_STL=c++_shared") +fi +if [ "${ANDROID_DEBUG_SYMBOLS:-}" == '1' ]; then + CMAKE_ARGS+=("-DANDROID_DEBUG_SYMBOLS=1") +fi + +if [ -n "${USE_VULKAN}" ]; then + CMAKE_ARGS+=("-DUSE_VULKAN=ON") + if [ -n "${USE_VULKAN_FP16_INFERENCE}" ]; then + CMAKE_ARGS+=("-DUSE_VULKAN_FP16_INFERENCE=ON") + fi + if [ -n "${USE_VULKAN_RELAXED_PRECISION}" ]; then + CMAKE_ARGS+=("-DUSE_VULKAN_RELAXED_PRECISION=ON") + fi +fi + +# Use-specified CMake arguments go last to allow overriding defaults +CMAKE_ARGS+=($@) + +# Patch pocketfft (as Android does not have aligned_alloc even if compiled with c++17 +if [ -f third_party/pocketfft/pocketfft_hdronly.h ]; then + sed -i -e "s/__cplusplus >= 201703L/0/" third_party/pocketfft/pocketfft_hdronly.h +fi + +# Now, actually build the Android target. +BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build_android"} +INSTALL_PREFIX=${BUILD_ROOT}/install +mkdir -p $BUILD_ROOT +cd $BUILD_ROOT +cmake "$CAFFE2_ROOT" \ + -DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX \ + -DCMAKE_BUILD_TYPE=Release \ + "${CMAKE_ARGS[@]}" + +# Cross-platform parallel build +if [ -z "$MAX_JOBS" ]; then + if [ "$(uname)" == 'Darwin' ]; then + MAX_JOBS=$(sysctl -n hw.ncpu) + else + MAX_JOBS=$(nproc) + fi +fi + +echo "Will install headers and libs to $INSTALL_PREFIX for further Android project usage." +cmake --build . --target install -- "-j${MAX_JOBS}" +echo "Installation completed, now you can copy the headers/libs from $INSTALL_PREFIX to your Android project directory." diff --git a/scripts/build_android_gradle.sh b/scripts/build_android_gradle.sh new file mode 100755 index 0000000000000..fc27c5dd2516b --- /dev/null +++ b/scripts/build_android_gradle.sh @@ -0,0 +1,102 @@ +#!/usr/bin/env bash +set -eux -o pipefail + +env +echo "BUILD_ENVIRONMENT:$BUILD_ENVIRONMENT" + +export ANDROID_NDK_HOME=/opt/ndk +export ANDROID_NDK=/opt/ndk +export ANDROID_HOME=/opt/android/sdk + +# Must be in sync with GRADLE_VERSION in docker image for android +# https://github.com/pietern/pytorch-dockerfiles/blob/master/build.sh#L155 +export GRADLE_VERSION=6.8.3 +export GRADLE_HOME=/opt/gradle/gradle-$GRADLE_VERSION +export GRADLE_PATH=$GRADLE_HOME/bin/gradle + +# touch gradle cache files to prevent expiration +while IFS= read -r -d '' file +do + touch "$file" || true +done < <(find /var/lib/jenkins/.gradle -type f -print0) + +# Patch pocketfft (as Android does not have aligned_alloc even if compiled with c++17 +if [ -f ~/workspace/third_party/pocketfft/pocketfft_hdronly.h ]; then + sed -i -e "s/__cplusplus >= 201703L/0/" ~/workspace/third_party/pocketfft/pocketfft_hdronly.h +fi + +export GRADLE_LOCAL_PROPERTIES=~/workspace/android/local.properties +rm -f $GRADLE_LOCAL_PROPERTIES +echo "sdk.dir=/opt/android/sdk" >> $GRADLE_LOCAL_PROPERTIES +echo "ndk.dir=/opt/ndk" >> $GRADLE_LOCAL_PROPERTIES +echo "cmake.dir=/usr/local" >> $GRADLE_LOCAL_PROPERTIES + +retry () { + $* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*) +} + +# Run custom build script +if [[ "${BUILD_ENVIRONMENT}" == *-gradle-custom-build* ]]; then + # Install torch & torchvision - used to download & dump used ops from test model. + retry pip install torch torchvision --progress-bar off + + exec "$(dirname "${BASH_SOURCE[0]}")/../android/build_test_app_custom.sh" armeabi-v7a +fi + +# Run default build +BUILD_ANDROID_INCLUDE_DIR_x86=~/workspace/build_android/install/include +BUILD_ANDROID_LIB_DIR_x86=~/workspace/build_android/install/lib + +BUILD_ANDROID_INCLUDE_DIR_x86_64=~/workspace/build_android_install_x86_64/install/include +BUILD_ANDROID_LIB_DIR_x86_64=~/workspace/build_android_install_x86_64/install/lib + +BUILD_ANDROID_INCLUDE_DIR_arm_v7a=~/workspace/build_android_install_arm_v7a/install/include +BUILD_ANDROID_LIB_DIR_arm_v7a=~/workspace/build_android_install_arm_v7a/install/lib + +BUILD_ANDROID_INCLUDE_DIR_arm_v8a=~/workspace/build_android_install_arm_v8a/install/include +BUILD_ANDROID_LIB_DIR_arm_v8a=~/workspace/build_android_install_arm_v8a/install/lib + +PYTORCH_ANDROID_SRC_MAIN_DIR=~/workspace/android/pytorch_android/src/main + +JNI_INCLUDE_DIR=${PYTORCH_ANDROID_SRC_MAIN_DIR}/cpp/libtorch_include +mkdir -p $JNI_INCLUDE_DIR + +JNI_LIBS_DIR=${PYTORCH_ANDROID_SRC_MAIN_DIR}/jniLibs +mkdir -p $JNI_LIBS_DIR + +ln -s ${BUILD_ANDROID_INCLUDE_DIR_x86} ${JNI_INCLUDE_DIR}/x86 +ln -s ${BUILD_ANDROID_LIB_DIR_x86} ${JNI_LIBS_DIR}/x86 + +if [[ "${BUILD_ENVIRONMENT}" != *-gradle-build-only-x86_32* ]]; then +ln -s ${BUILD_ANDROID_INCLUDE_DIR_x86_64} ${JNI_INCLUDE_DIR}/x86_64 +ln -s ${BUILD_ANDROID_LIB_DIR_x86_64} ${JNI_LIBS_DIR}/x86_64 + +ln -s ${BUILD_ANDROID_INCLUDE_DIR_arm_v7a} ${JNI_INCLUDE_DIR}/armeabi-v7a +ln -s ${BUILD_ANDROID_LIB_DIR_arm_v7a} ${JNI_LIBS_DIR}/armeabi-v7a + +ln -s ${BUILD_ANDROID_INCLUDE_DIR_arm_v8a} ${JNI_INCLUDE_DIR}/arm64-v8a +ln -s ${BUILD_ANDROID_LIB_DIR_arm_v8a} ${JNI_LIBS_DIR}/arm64-v8a +fi + +GRADLE_PARAMS="-p android assembleRelease --debug --stacktrace" +if [[ "${BUILD_ENVIRONMENT}" == *-gradle-build-only-x86_32* ]]; then + GRADLE_PARAMS+=" -PABI_FILTERS=x86" +fi + +if [ -n "${GRADLE_OFFLINE:-}" ]; then + GRADLE_PARAMS+=" --offline" +fi + +$GRADLE_PATH $GRADLE_PARAMS + +find . -type f -name "*.a" -exec ls -lh {} \; + +while IFS= read -r -d '' file +do + echo + echo "$file" + ls -lah "$file" + zipinfo -l "$file" +done < <(find . -type f -name '*.aar' -print0) + +find . -type f -name *aar -print | xargs tar cfvz ~/workspace/android/artifacts.tgz diff --git a/scripts/build_ios.sh b/scripts/build_ios.sh new file mode 100755 index 0000000000000..ad16cb940dcb8 --- /dev/null +++ b/scripts/build_ios.sh @@ -0,0 +1,155 @@ +#!/bin/bash -xe +############################################################################## +# Example command to build the iOS target. +############################################################################## +# +# This script shows how one can build a Caffe2 binary for the iOS platform +# using ios-cmake. This is very similar to the android-cmake - see +# build_android.sh for more details. + +CAFFE2_ROOT="$( cd "$(dirname "$0")"/.. ; pwd -P)" + +if [ -z "$PYTHON" ]; then + PYTHON=python + PYTHON_VERSION_MAJOR=$($PYTHON -c 'import sys; print(sys.version_info[0])') + if [ "${PYTHON_VERSION_MAJOR}" -le 2 ]; then + echo "Default python executable is Python-2, trying to use python3 alias" + PYTHON=python3 + fi +fi + +echo "Bash: $(/bin/bash --version | head -1)" +echo "Python: $($PYTHON -c 'import sys; print(sys.version)')" +echo "Caffe2 path: $CAFFE2_ROOT" + +CMAKE_ARGS=() + +# Build PyTorch mobile +CMAKE_ARGS+=("-DCMAKE_PREFIX_PATH=$($PYTHON -c 'import sysconfig; print(sysconfig.get_path("purelib"))')") +CMAKE_ARGS+=("-DPython_EXECUTABLE=$($PYTHON -c 'import sys; print(sys.executable)')") +CMAKE_ARGS+=("-DBUILD_CUSTOM_PROTOBUF=OFF") + +# custom build with selected ops +if [ -n "${SELECTED_OP_LIST}" ]; then + SELECTED_OP_LIST="$(cd $(dirname $SELECTED_OP_LIST); pwd -P)/$(basename $SELECTED_OP_LIST)" + echo "Choose SELECTED_OP_LIST file: $SELECTED_OP_LIST" + if [ ! -r ${SELECTED_OP_LIST} ]; then + echo "Error: SELECTED_OP_LIST file ${SELECTED_OP_LIST} not found." + exit 1 + fi + CMAKE_ARGS+=("-DSELECTED_OP_LIST=${SELECTED_OP_LIST}") +fi + +# bitcode +if [ "${ENABLE_BITCODE:-}" == '1' ]; then + CMAKE_ARGS+=("-DCMAKE_C_FLAGS=-fembed-bitcode") + CMAKE_ARGS+=("-DCMAKE_CXX_FLAGS=-fembed-bitcode") +fi + +# Use ios-cmake to build iOS project from CMake. +# This projects sets CMAKE_C_COMPILER to /usr/bin/gcc and +# CMAKE_CXX_COMPILER to /usr/bin/g++. In order to use ccache (if it is available) we +# must override these variables via CMake arguments. +CMAKE_ARGS+=("-DCMAKE_TOOLCHAIN_FILE=$CAFFE2_ROOT/cmake/iOS.cmake") +if [ -n "${CCACHE_WRAPPER_PATH:-}"]; then + CCACHE_WRAPPER_PATH=/usr/local/opt/ccache/libexec +fi +if [ -d "$CCACHE_WRAPPER_PATH" ]; then + CMAKE_ARGS+=("-DCMAKE_C_COMPILER=$CCACHE_WRAPPER_PATH/gcc") + CMAKE_ARGS+=("-DCMAKE_CXX_COMPILER=$CCACHE_WRAPPER_PATH/g++") +fi + +# IOS_PLATFORM controls type of iOS platform (see ios-cmake) +if [ -n "${IOS_PLATFORM:-}" ]; then + CMAKE_ARGS+=("-DIOS_PLATFORM=${IOS_PLATFORM}") + if [ "${IOS_PLATFORM}" == "WATCHOS" ]; then + # enable bitcode by default for watchos + CMAKE_ARGS+=("-DCMAKE_C_FLAGS=-fembed-bitcode") + CMAKE_ARGS+=("-DCMAKE_CXX_FLAGS=-fembed-bitcode") + # disable the QNNPACK + CMAKE_ARGS+=("-DUSE_PYTORCH_QNNPACK=OFF") + fi +else + # IOS_PLATFORM is not set, default to OS, which builds iOS. + CMAKE_ARGS+=("-DIOS_PLATFORM=OS") +fi + +if [ -n "${IOS_ARCH:-}" ]; then + CMAKE_ARGS+=("-DIOS_ARCH=${IOS_ARCH}") +fi + +if [ "${BUILD_LITE_INTERPRETER}" == 0 ]; then + CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=OFF") +else + CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=ON") +fi +if [ "${TRACING_BASED}" == 1 ]; then + CMAKE_ARGS+=("-DTRACING_BASED=ON") +else + CMAKE_ARGS+=("-DTRACING_BASED=OFF") +fi +if [ "${USE_LIGHTWEIGHT_DISPATCH}" == 1 ]; then + CMAKE_ARGS+=("-DUSE_LIGHTWEIGHT_DISPATCH=ON") + CMAKE_ARGS+=("-DSTATIC_DISPATCH_BACKEND=CPU") +else + CMAKE_ARGS+=("-DUSE_LIGHTWEIGHT_DISPATCH=OFF") +fi + +CMAKE_ARGS+=("-DUSE_LITE_INTERPRETER_PROFILER=OFF") + +# Don't build binaries or tests (only the library) +CMAKE_ARGS+=("-DBUILD_TEST=OFF") +CMAKE_ARGS+=("-DBUILD_BINARY=OFF") +CMAKE_ARGS+=("-DBUILD_PYTHON=OFF") + +# Disable unused dependencies +CMAKE_ARGS+=("-DUSE_CUDA=OFF") +CMAKE_ARGS+=("-DUSE_ITT=OFF") +CMAKE_ARGS+=("-DUSE_GFLAGS=OFF") +CMAKE_ARGS+=("-DUSE_OPENCV=OFF") +CMAKE_ARGS+=("-DUSE_MPI=OFF") +CMAKE_ARGS+=("-DUSE_NUMPY=OFF") +CMAKE_ARGS+=("-DUSE_NNPACK=OFF") +CMAKE_ARGS+=("-DUSE_MKLDNN=OFF") + +# Metal +if [ "${USE_PYTORCH_METAL:-}" == "1" ]; then + CMAKE_ARGS+=("-DUSE_PYTORCH_METAL=ON") +fi + +# Core ML +if [ "${USE_COREML_DELEGATE}" == "1" ]; then + CMAKE_ARGS+=("-DUSE_COREML_DELEGATE=ON") +fi + +# pthreads +CMAKE_ARGS+=("-DCMAKE_THREAD_LIBS_INIT=-lpthread") +CMAKE_ARGS+=("-DCMAKE_HAVE_THREADS_LIBRARY=1") +CMAKE_ARGS+=("-DCMAKE_USE_PTHREADS_INIT=1") + +# Only toggle if VERBOSE=1 +if [ "${VERBOSE:-}" == '1' ]; then + CMAKE_ARGS+=("-DCMAKE_VERBOSE_MAKEFILE=1") +fi + +# enable ARC +CMAKE_ARGS+=("-DCMAKE_CXX_FLAGS=-fobjc-arc") + +# Now, actually build the iOS target. +BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build_ios"} +INSTALL_PREFIX=${BUILD_ROOT}/install +mkdir -p $BUILD_ROOT +cd $BUILD_ROOT +cmake "$CAFFE2_ROOT" \ + -DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX \ + -DCMAKE_BUILD_TYPE=MinSizeRel \ + -DBUILD_SHARED_LIBS=OFF \ + ${CMAKE_ARGS[@]} \ + $@ + +cmake --build . -- "-j$(sysctl -n hw.ncpu)" + +# copy headers and libs to install directory +echo "Will install headers and libs to $INSTALL_PREFIX for further Xcode project usage." +make install +echo "Installation completed, now you can copy the headers/libs from $INSTALL_PREFIX to your Xcode project directory." diff --git a/scripts/build_local.sh b/scripts/build_local.sh new file mode 100755 index 0000000000000..b843671501256 --- /dev/null +++ b/scripts/build_local.sh @@ -0,0 +1,82 @@ +#!/bin/bash +# +############################################################################## +# Example command to build Caffe2 +############################################################################## +# + +set -ex + +CAFFE2_ROOT="$( cd "$(dirname "$0")"/.. ; pwd -P)" + +CMAKE_ARGS=() + +# If Ninja is installed, prefer it to Make +if [ -x "$(command -v ninja)" ]; then + CMAKE_ARGS+=("-GNinja") +fi + +# Use ccache if available (this path is where Homebrew installs ccache symlinks) +if [ "$(uname)" == 'Darwin' ]; then + if [ -n "${CCACHE_WRAPPER_PATH:-}"]; then + CCACHE_WRAPPER_PATH=/usr/local/opt/ccache/libexec + fi + if [ -d "$CCACHE_WRAPPER_PATH" ]; then + CMAKE_ARGS+=("-DCMAKE_C_COMPILER=$CCACHE_WRAPPER_PATH/gcc") + CMAKE_ARGS+=("-DCMAKE_CXX_COMPILER=$CCACHE_WRAPPER_PATH/g++") + fi +fi + +# Use special install script with Anaconda +if [ -n "${USE_ANACONDA}" ]; then + export SKIP_CONDA_TESTS=1 + export CONDA_INSTALL_LOCALLY=1 + "${ROOT_DIR}/scripts/build_anaconda.sh" "$@" +else + # Make sure that pyyaml is installed for the codegen of building Aten to work + if [[ -n "$(python -c 'import yaml' 2>&1)" ]]; then + echo "Installing pyyaml with pip at $(which pip)" + pip install --user pyyaml + fi + + # Make sure that typing is installed for the codegen of building Aten to work + if [[ -n "$(python -c 'import typing' 2>&1)" ]]; then + echo "Installing typing with pip at $(which pip)" + pip install --user typing + fi + + # Build protobuf compiler from third_party if configured to do so + if [ -n "${USE_HOST_PROTOC:-}" ]; then + echo "USE_HOST_PROTOC is set; building protoc before building Caffe2..." + "$CAFFE2_ROOT/scripts/build_host_protoc.sh" + CUSTOM_PROTOC_EXECUTABLE="$CAFFE2_ROOT/build_host_protoc/bin/protoc" + echo "Built protoc $("$CUSTOM_PROTOC_EXECUTABLE" --version)" + CMAKE_ARGS+=("-DCAFFE2_CUSTOM_PROTOC_EXECUTABLE=$CUSTOM_PROTOC_EXECUTABLE") + fi + + # We are going to build the target into build. + BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build"} + mkdir -p "$BUILD_ROOT" + cd "$BUILD_ROOT" + echo "Building Caffe2 in: $BUILD_ROOT" + + cmake "$CAFFE2_ROOT" \ + -DCMAKE_BUILD_TYPE=Release \ + "${CMAKE_ARGS[@]}" \ + "$@" + + # Determine the number of CPUs to build with. + # If the `CAFFE_MAKE_NCPUS` variable is not specified, use them all. + if [ -n "${MAX_JOBS}" ]; then + CAFFE_MAKE_NCPUS="$MAX_JOBS" + elif [ -n "${CAFFE_MAKE_NCPUS}" ]; then + CAFFE_MAKE_NCPUS="$CAFFE_MAKE_NCPUS" + elif [ "$(uname)" == 'Darwin' ]; then + CAFFE_MAKE_NCPUS="$(sysctl -n hw.ncpu)" + else + CAFFE_MAKE_NCPUS="$(nproc)" + fi + + # Now, actually build the target. + cmake --build . -- "-j$CAFFE_MAKE_NCPUS" +fi diff --git a/scripts/build_mobile.sh b/scripts/build_mobile.sh new file mode 100755 index 0000000000000..7b1995a61ebc7 --- /dev/null +++ b/scripts/build_mobile.sh @@ -0,0 +1,107 @@ +#!/bin/bash +############################################################################## +# Example command to build the mobile target. +############################################################################## +# +# This script shows how one can build a libtorch library optimized for mobile +# devices using host toolchain. + +set -e + +export BUILD_PYTORCH_MOBILE_WITH_HOST_TOOLCHAIN=1 +CAFFE2_ROOT="$( cd "$(dirname "$0")"/.. ; pwd -P)" + +echo "Bash: $(/bin/bash --version | head -1)" +echo "Caffe2 path: $CAFFE2_ROOT" + +CMAKE_ARGS=() +CMAKE_ARGS+=("-DCMAKE_PREFIX_PATH=$(python -c 'import sysconfig; print(sysconfig.get_path("purelib"))')") +CMAKE_ARGS+=("-DPython_EXECUTABLE=$(python -c 'import sys; print(sys.executable)')") +CMAKE_ARGS+=("-DBUILD_CUSTOM_PROTOBUF=OFF") +CMAKE_ARGS+=("-DBUILD_SHARED_LIBS=OFF") + +# custom build with selected ops +if [ -n "${SELECTED_OP_LIST}" ]; then + SELECTED_OP_LIST="$(cd $(dirname $SELECTED_OP_LIST); pwd -P)/$(basename $SELECTED_OP_LIST)" + echo "Choose SELECTED_OP_LIST file: $SELECTED_OP_LIST" + if [ ! -r ${SELECTED_OP_LIST} ]; then + echo "Error: SELECTED_OP_LIST file ${SELECTED_OP_LIST} not found." + exit 1 + fi + CMAKE_ARGS+=("-DSELECTED_OP_LIST=${SELECTED_OP_LIST}") +fi + +# If Ninja is installed, prefer it to Make +if [ -x "$(command -v ninja)" ]; then + CMAKE_ARGS+=("-GNinja") +fi + +# Don't build artifacts we don't need +CMAKE_ARGS+=("-DBUILD_TEST=OFF") +CMAKE_ARGS+=("-DBUILD_BINARY=OFF") + +# If there exists env variable and it equals to 1, build lite interpreter. +# Default behavior is to build full jit interpreter. +# cmd: BUILD_LITE_INTERPRETER=1 ./scripts/build_mobile.sh +if [ "x${BUILD_LITE_INTERPRETER}" == "x1" ]; then + CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=ON") +else + CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=OFF") +fi +if [ "x${TRACING_BASED}" == "x1" ]; then + CMAKE_ARGS+=("-DTRACING_BASED=ON") +else + CMAKE_ARGS+=("-DTRACING_BASED=OFF") +fi + +# Lightweight dispatch bypasses the PyTorch Dispatcher. +if [ "${USE_LIGHTWEIGHT_DISPATCH}" == 1 ]; then + CMAKE_ARGS+=("-DUSE_LIGHTWEIGHT_DISPATCH=ON") + CMAKE_ARGS+=("-DSTATIC_DISPATCH_BACKEND=CPU") +else + CMAKE_ARGS+=("-DUSE_LIGHTWEIGHT_DISPATCH=OFF") +fi + +# Disable unused dependencies +CMAKE_ARGS+=("-DUSE_ROCM=OFF") +CMAKE_ARGS+=("-DUSE_CUDA=OFF") +CMAKE_ARGS+=("-DUSE_ITT=OFF") +CMAKE_ARGS+=("-DUSE_GFLAGS=OFF") +CMAKE_ARGS+=("-DUSE_OPENCV=OFF") +CMAKE_ARGS+=("-DUSE_MPI=OFF") +CMAKE_ARGS+=("-DUSE_OPENMP=OFF") +CMAKE_ARGS+=("-DUSE_MKLDNN=OFF") +CMAKE_ARGS+=("-DUSE_NNPACK=OFF") +CMAKE_ARGS+=("-DUSE_NUMPY=OFF") +CMAKE_ARGS+=("-DUSE_BLAS=OFF") + +# Only toggle if VERBOSE=1 +if [ "${VERBOSE:-}" == '1' ]; then + CMAKE_ARGS+=("-DCMAKE_VERBOSE_MAKEFILE=1") +fi + +# Use-specified CMake arguments go last to allow overriding defaults +CMAKE_ARGS+=("$@") + +# Now, actually build the Android target. +BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build_mobile"} +INSTALL_PREFIX=${BUILD_ROOT}/install +mkdir -p $BUILD_ROOT +cd $BUILD_ROOT +cmake "$CAFFE2_ROOT" \ + -DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX \ + -DCMAKE_BUILD_TYPE=Release \ + "${CMAKE_ARGS[@]}" + +# Cross-platform parallel build +if [ -z "$MAX_JOBS" ]; then + if [ "$(uname)" == 'Darwin' ]; then + MAX_JOBS=$(sysctl -n hw.ncpu) + else + MAX_JOBS=$(nproc) + fi +fi + +echo "Will install headers and libs to $INSTALL_PREFIX for further project usage." +cmake --build . --target install -- "-j${MAX_JOBS}" +echo "Installation completed, now you can copy the headers/libs from $INSTALL_PREFIX to your project directory." diff --git a/scripts/build_pytorch_android.sh b/scripts/build_pytorch_android.sh new file mode 100755 index 0000000000000..7b80965e34b5c --- /dev/null +++ b/scripts/build_pytorch_android.sh @@ -0,0 +1,51 @@ +#!/bin/bash +set -eux + +############################################################################## +# Master script to build PyTorch Android library with Java bindings. +############################################################################## +# Example usage: +# - Build default AARs: +# scripts/build_pytorch_android.sh +# +# - Build for specific ABI(s): +# scripts/build_pytorch_android.sh armeabi-v7a +# scripts/build_pytorch_android.sh arm64-v8a,x86,x86_64 +# +# Script's workflow: +# 1. Builds libtorch for android for specified android abisi (by default for all 4). +# Custom list of android abis can be specified as a bash argument as comma separated list. +# For example just for testing on android x86 emulator we need only x86 build. +# ./scripts/build_pytorch_android.sh x86 +# 2. Creates symbolic links to android/pytorch_android/src/main/jniLibs/${abi} for libtorch build output, +# android/pytorch_android/src/main/cpp/libtorch_include/${abi} for headers. +# 3. Runs pyotrch_android gradle build: +# gradle assembleRelease + +PYTORCH_DIR="$(cd $(dirname $0)/..; pwd -P)" +PYTORCH_ANDROID_DIR=$PYTORCH_DIR/android + +echo "PYTORCH_DIR:$PYTORCH_DIR" + +source "$PYTORCH_ANDROID_DIR/common.sh" + +check_android_sdk +check_gradle +parse_abis_list "$@" +build_android + +# To set proxy for gradle add following lines to ./gradle/gradle.properties: +# systemProp.http.proxyHost=... +# systemProp.http.proxyPort=8080 +# systemProp.https.proxyHost=... +# systemProp.https.proxyPort=8080 + +if [ "$CUSTOM_ABIS_LIST" = true ]; then + # Skipping clean task here as android gradle plugin 3.3.2 exteralNativeBuild has problems + # with it when abiFilters are specified. + $GRADLE_PATH -PABI_FILTERS=$ABIS_LIST -p $PYTORCH_ANDROID_DIR assembleRelease +else + $GRADLE_PATH -p $PYTORCH_ANDROID_DIR clean assembleRelease +fi + +find $PYTORCH_ANDROID_DIR -type f -name *aar | xargs ls -lah diff --git a/scripts/build_raspbian.sh b/scripts/build_raspbian.sh new file mode 100755 index 0000000000000..b1fe85926219e --- /dev/null +++ b/scripts/build_raspbian.sh @@ -0,0 +1,44 @@ +#!/bin/bash +############################################################################## +# Example command to build the Raspbian target. +############################################################################## +# +# This script shows how one can build a Caffe2 binary for raspbian. The build +# is essentially much similar to a host build, with one additional change +# which is to specify -mfpu=neon for optimized speed. + +CAFFE2_ROOT="$( cd "$(dirname -- "$0")"/.. ; pwd -P)" +echo "Caffe2 codebase root is: $CAFFE2_ROOT" +BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build"} +mkdir -p $BUILD_ROOT +echo "Build Caffe2 raspbian into: $BUILD_ROOT" + +# obtain dependencies. +echo "Installing dependencies." +sudo apt-get install \ + cmake \ + libgflags-dev \ + libgoogle-glog-dev \ + libprotobuf-dev \ + libpython-dev \ + python-pip \ + python-numpy \ + protobuf-compiler \ + python-protobuf +# python dependencies +sudo pip install hypothesis + +# Now, actually build the raspbian target. +echo "Building caffe2" +cd $BUILD_ROOT + +# Note: you can add more dependencies above if you need libraries such as +# leveldb, lmdb, etc. +cmake "$CAFFE2_ROOT" \ + -DCMAKE_VERBOSE_MAKEFILE=1 \ + -DCAFFE2_CPU_FLAGS="-mfpu=neon -mfloat-abi=hard" \ + || exit 1 + +# Note: while Raspberry pi has 4 cores, running too many builds in parallel may +# cause out of memory errors so we will simply run -j 2 only. +make -j 2 || exit 1 diff --git a/scripts/build_tegra_x1.sh b/scripts/build_tegra_x1.sh new file mode 100755 index 0000000000000..063e17dfe3514 --- /dev/null +++ b/scripts/build_tegra_x1.sh @@ -0,0 +1,51 @@ +#!/bin/bash +############################################################################## +# Example command to build Caffe2 on Tegra X1. +############################################################################## +# +# This script shows how one can build a Caffe2 binary for NVidia's TX1. +# The build script assumes that you have the most recent libraries installed +# via the JetPack toolkit available at +# https://developer.nvidia.com/embedded/jetpack +# and it assumes that we are starting from a fresh system after the jetpack +# installation. If you have already installed some of the dependencies, you +# may be able to skip quite a few of the apt-get installs. + +CAFFE2_ROOT="$( cd "$(dirname -- "$0")"/.. ; pwd -P)" +echo "Caffe2 codebase root is: $CAFFE2_ROOT" +BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build"} +mkdir -p $BUILD_ROOT +echo "Build Caffe2 raspbian into: $BUILD_ROOT" + +# obtain necessary dependencies +echo "Installing dependencies." +sudo apt-get install \ + cmake \ + libgflags-dev \ + libgoogle-glog-dev \ + libprotobuf-dev \ + protobuf-compiler + +# obtain optional dependencies that are usually useful to have. +echo "Installing optional dependencies." +sudo apt-get install \ + libpython-dev \ + python-numpy \ + python-pip \ + python-protobuf + +# Obtain python hypothesis, which Caffe2 uses for unit testing. Note that +# the one provided by apt-get is quite old so we install it via pip +sudo pip install hypothesis + +# Now, actually build the android target. +echo "Building caffe2" +cd $BUILD_ROOT + +# CUDA_USE_STATIC_CUDA_RUNTIME needs to be set to off so that opencv can be +# properly used. Otherwise, opencv will complain that opencv_dep_cudart cannot +# be found. +cmake "$CAFFE2_ROOT" -DCUDA_USE_STATIC_CUDA_RUNTIME=OFF \ + || exit 1 + +make -j 4 || exit 1 diff --git a/scripts/build_tizen.sh b/scripts/build_tizen.sh new file mode 100755 index 0000000000000..2262a2503c1d0 --- /dev/null +++ b/scripts/build_tizen.sh @@ -0,0 +1,118 @@ +#!/usr/bin/env bash +############################################################################## +# Example command to build the Tizen target (RPi3). +############################################################################## +# +# This script shows how one can build a Caffe2 binary for a Tizen device (RPi3). +# The build is essentially much similar to a host build, with one additional change +# which is to specify -mfpu=neon for optimized speed. + +setup_environment(){ +# The rootfs image for a Tizen target (RPi3)is located at the below webpage: +# https://cdn.download.tizen.org/archive/releases/milestone/tizen/4.0.m1/tizen-unified_20170529.1/images/ +# If you do not have a Tizen device, Please, run qemu-arm-static and chroot command. +# $ sudo chroot ~/tizen-rootfs qemu-arm-static /usr/bin/bash + +CAFFE2_ROOT="$( cd "$(dirname -- "$0")"/.. ; pwd -P)" +echo "Caffe2 codebase root is: $CAFFE2_ROOT" +BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build"} +mkdir -p $BUILD_ROOT +echo "Build Caffe2 Tizen into: $BUILD_ROOT" +} + +caffe2_lite_dep_packages(){ +# Obtain necessary dependencies +# You can set-up a rpm repository with zypper, yum, and dnf because Tizen +# software platform officially support rpm format such as Fedora, OpenSUSE. +# The official Tizen repository is as following: +# https://cdn.download.tizen.org/archive/releases/milestone/tizen/4.0.m1/ +echo "Installing dependencies." +sudo zypper install \ + make \ + strace \ + cmake \ + gcc* \ + binutils \ + glibc* \ + cpp \ + protobuf-devel \ + libstdc++* +} + +caffe2_lite_build(){ +# Now, actually build the android target. +echo "Building caffe2" +cd $BUILD_ROOT + +# Note: add more dependencies above if you need libraries such as leveldb, lmdb, etc. +# If you have to disable a specific package due to a package absence +# from https://git.tizen.org/cgit/, append -Dxxx_xxx=OFF option before executing cmake. +cmake .. \ + -DCMAKE_VERBOSE_MAKEFILE=1 \ + -DUSE_GFLAGS=OFF \ + -DUSE_GLOG=OFF -DUSE_NNPACK=OFF \ + -DRUN_HAVE_STD_REGEX=0 \ + -DRUN_HAVE_POSIX_REGEX=0 \ + -DHAVE_GNU_POSIX_REGEX=0 \ + -DUSE_MPI=OFF -DUSE_OPENMP=OFF \ + -DBUILD_PYTHON=OFF \ + -DUSE_GLOO=OFF \ + -DUSE_OPENCV=OFF \ + -DCAFFE2_CPU_FLAGS="-mfpu=neon -mfloat-abi=soft" \ + || exit 1 + +make -j`nproc` || exit 1 +} + +caffe2_full_dep_packages(){ +# Obtain necessary dependencies +# You can set-up a rpm repository with zypper, yum, and dnf because Tizen +# software platform officially support rpm format such as Fedora, OpenSUSE. +# The official Tizen repository is as following: +# https://cdn.download.tizen.org/archive/releases/milestone/tizen/4.0.m1/ +echo "Installing dependencies." +sudo zypper install \ + cmake \ + libgflags-dev \ + libgoogle-glog-dev \ + libprotobuf-dev \ + protobuf-compiler + +# Obtain optional dependencies that are usually useful to have. +echo "Installing optional dependencies." +sudo zypper install \ + libpython-dev \ + python-numpy \ + python-pip \ + python-protobuf + +# Obtain python hypothesis, which Caffe2 uses for unit testing. Note that +# the one provided by zypper is quite old so we install it via pip +sudo pip install hypothesis +} + +caffe2_full_build(){ +# Now, actually build the android target. +echo "Building caffe2" +cd $BUILD_ROOT + +# Note: add more dependencies above if you need libraries such as leveldb, lmdb, etc. +# If you have to disable a specific package due to a package absence +# from https://git.tizen.org/cgit/, append -Dxxx_xxx=OFF option before executing cmake. +cmake "$CAFFE2_ROOT" \ + -DCMAKE_VERBOSE_MAKEFILE=1 \ + -DUSE_CUDA=OFF \ + -DUSE_ITT=OFF \ + -DUSE_OPENCV=OFF \ + -DCAFFE2_CPU_FLAGS="-mfpu=neon -mfloat-abi=soft" \ + || exit 1 + +make -j`nproc` || exit 1 +} + +#### Main +# Setup a build environment to compile Caffe2 deeplearning framework in Tizen platform. +setup_environment +# There are two build options to support 'full' version and 'lite' version (by default). +caffe2_lite_dep_packages +caffe2_lite_build diff --git a/scripts/build_windows.bat b/scripts/build_windows.bat new file mode 100644 index 0000000000000..60bfebad08c01 --- /dev/null +++ b/scripts/build_windows.bat @@ -0,0 +1,80 @@ +:: ############################################################################# +:: Example command to build on Windows. +:: ############################################################################# + +:: This script shows how one can build a Caffe2 binary for windows. + +@echo off +setlocal + +SET ORIGINAL_DIR=%cd% +SET CAFFE2_ROOT=%~dp0%.. + +if NOT DEFINED BUILD_BINARY ( + set BUILD_BINARY=OFF +) + +if NOT DEFINED BUILD_SHARED_LIBS ( + :: On CI, we test with BUILD_SHARED_LIBS=OFF. + :: By default, it will be BUILD_SHARED_LIBS=ON. + if NOT DEFINED BUILD_ENVIRONMENT ( + set BUILD_SHARED_LIBS=OFF + ) +) + +if NOT DEFINED CAFFE2_STATIC_LINK_CUDA ( + set CAFFE2_STATIC_LINK_CUDA=OFF +) + +if NOT DEFINED CMAKE_BUILD_TYPE ( + set CMAKE_BUILD_TYPE=Release +) + +if NOT DEFINED ONNX_NAMESPACE ( + set ONNX_NAMESPACE=onnx_c2 +) + +if NOT DEFINED TORCH_CUDA_ARCH_LIST ( + set TORCH_CUDA_ARCH_LIST=5.0 +) + +if NOT DEFINED USE_CUDA ( + set USE_CUDA=OFF +) + +if NOT DEFINED USE_OBSERVERS ( + set USE_OBSERVERS=OFF +) + +if NOT DEFINED MSVC_Z7_OVERRIDE ( + set MSVC_Z7_OVERRIDE=OFF +) + +if NOT DEFINED CMAKE_GENERATOR ( + set CMAKE_GENERATOR=Ninja +) + +set CMAKE_VERBOSE_MAKEFILE=1 + +:: Install pyyaml for Aten codegen +pip install pyyaml ninja + +echo CAFFE2_ROOT=%CAFFE2_ROOT% +echo CMAKE_GENERATOR=%CMAKE_GENERATOR% +echo CMAKE_BUILD_TYPE=%CMAKE_BUILD_TYPE% + +:: Set up cmake. We will skip building the test files right now. +pushd %CAFFE2_ROOT% +python tools\build_libtorch.py || goto :label_error +popd + +echo "Caffe2 built successfully" +cd %ORIGINAL_DIR% +endlocal +exit /b 0 + +:label_error +echo "Caffe2 building failed" +cd %ORIGINAL_DIR% +endlocal +exit /b 1 diff --git a/scripts/diagnose_protobuf.py b/scripts/diagnose_protobuf.py new file mode 100644 index 0000000000000..65af4618228db --- /dev/null +++ b/scripts/diagnose_protobuf.py @@ -0,0 +1,92 @@ +## @package diagnose_protobuf +# Module scripts.diagnose_protobuf +"""Diagnoses the current protobuf situation. + +Protocol buffer needs to be properly installed for Caffe2 to work, and +sometimes it is rather tricky. Specifically, we will need to have a +consistent version between C++ and python simultaneously. This is a +convenience script for one to quickly check if this is so on one's local +machine. + +Usage: + [set your environmental variables like PATH and PYTHONPATH] + python scripts/diagnose_protobuf.py +""" + +import os +import re +from subprocess import PIPE, Popen + + +# Get python protobuf version. +try: + import google.protobuf + + python_version = google.protobuf.__version__ + python_protobuf_installed = True +except ImportError: + print("DEBUG: cannot find python protobuf install.") + python_protobuf_installed = False + +if os.name == "nt": + protoc_name = "protoc.exe" +else: + protoc_name = "protoc" + +try: + p = Popen([protoc_name, "--version"], stdout=PIPE, stderr=PIPE) + out, err = p.communicate() +except: + print("DEBUG: did not find protoc binary.") + print("DEBUG: out: " + out) + print("DEBUG: err: " + err) + native_protobuf_installed = False +else: + if p.returncode: + print("DEBUG: protoc returned a non-zero return code.") + print("DEBUG: out: " + out) + print("DEBUG: err: " + err) + native_protobuf_installed = False + else: + tmp = re.search(r"\d\.\d\.\d", out) + if tmp: + native_version = tmp.group(0) + native_protobuf_installed = True + else: + print("DEBUG: cannot parse protoc version string.") + print("DEBUG: out: " + out) + native_protobuf_installed = False + +PYTHON_PROTOBUF_NOT_INSTALLED = """ +You have not installed python protobuf. Protobuf is needed to run caffe2. You +can install protobuf via pip or conda (if you are using anaconda python). +""" + +NATIVE_PROTOBUF_NOT_INSTALLED = """ +You have not installed the protoc binary. Protoc is needed to compile Caffe2 +protobuf source files. Depending on the platform you are on, you can install +protobuf via: + (1) Mac: using homebrew and do brew install protobuf. + (2) Linux: use apt and do apt-get install libprotobuf-dev + (3) Windows: install from source, or from the releases here: + https://github.com/google/protobuf/releases/ +""" + +VERSION_MISMATCH = f""" +Your python protobuf is of version {python_version} but your native protoc version is of +version {native_version}. This will cause the installation to produce incompatible +protobuf files. This is bad in general - consider installing the same version. +""" + +# Now, give actual recommendations +if not python_protobuf_installed: + print(PYTHON_PROTOBUF_NOT_INSTALLED) + +if not native_protobuf_installed: + print(NATIVE_PROTOBUF_NOT_INSTALLED) + +if python_protobuf_installed and native_protobuf_installed: + if python_version != native_version: + print(VERSION_MISMATCH) + else: + print("All looks good.") diff --git a/scripts/fbcode-dev-setup/ccache_setup.sh b/scripts/fbcode-dev-setup/ccache_setup.sh new file mode 100755 index 0000000000000..cb461bee2dd27 --- /dev/null +++ b/scripts/fbcode-dev-setup/ccache_setup.sh @@ -0,0 +1,92 @@ +#!/bin/bash + +# This script installs CCache with CUDA support. +# Example usage: +# ./ccache_setup.sh --path /installed/folder + +set -e +shopt -s expand_aliases + +# Setup the proxy +alias with_proxy="HTTPS_PROXY=http://fwdproxy:8080 HTTP_PROXY=http://fwdproxy:8080 FTP_PROXY=http://fwdproxy:8080 https_proxy=http://fwdproxy:8080 http_proxy=http://fwdproxy:8080 ftp_proxy=http://fwdproxy:8080 http_no_proxy='*.facebook.com|*.tfbnw.net|*.fb.com'" + +# Parse options +path="$HOME/ccache" +force=false + +while [[ $# -gt 0 ]]; do + case "$1" in + --path) + shift + path="$1" + path=$(realpath "$path") + ;; + --force) # Force install + force=true + ;; + --help) + echo 'usage: ./ccache_setup.py --path /installed/folder [--force]' + exit 0 + ;; + *) + echo "Invalid option: $1" + exit 1 + ;; + esac + shift +done + +# Check whether you put nvcc in PATH +set +e +nvcc_path=$(which nvcc) +if [[ -z "$nvcc_path" ]]; then + nvcc_path="/usr/local/cuda/bin/nvcc" + export PATH="/usr/local/cuda/bin:$PATH" +fi +set -e +if [ ! -f "$nvcc_path" ] && ! $force; then + # shellcheck disable=SC2016 + echo 'nvcc is not detected in $PATH' + exit 1 +fi +echo "nvcc is detected at $nvcc_path" + +if [ -f "$CUDA_NVCC_EXECUTABLE" ] && [[ "$CUDA_NVCC_EXECUTABLE" == *"ccache"* ]]; then # Heuristic rule + if $CUDA_NVCC_EXECUTABLE --version; then + if ! $force; then + echo "CCache with nvcc support is already installed at $CUDA_NVCC_EXECUTABLE, please add --force" + exit 0 + fi + fi +fi + +# Installing CCache +echo "CCache will be installed at $path" +if [ -e "$path" ]; then + mv --backup=t -T "$path" "${path}.old" +fi + +with_proxy git clone https://github.com/colesbury/ccache.git "$path" -b ccbin +cd "$path" +./autogen.sh +./configure +make install prefix="$path" + +mkdir -p "$path/lib" +mkdir -p "$path/cuda" +ln -sf "$path/bin/ccache" "$path/lib/cc" +ln -sf "$path/bin/ccache" "$path/lib/c++" +ln -sf "$path/bin/ccache" "$path/lib/gcc" +ln -sf "$path/bin/ccache" "$path/lib/g++" +ln -sf "$path/bin/ccache" "$path/cuda/nvcc" +"$path/bin/ccache" -M 25Gi + +# Make sure the nvcc wrapped in CCache is runnable +"$path/cuda/nvcc" --version +echo 'Congrats! The CCache with nvcc support is installed!' +echo -e "Please add the following lines to your bash init script:\\n" +echo "################ Env Var for CCache with CUDA support ################" +# shellcheck disable=SC2016 +echo 'export PATH="'"$path"'/lib:$PATH"' +echo 'export CUDA_NVCC_EXECUTABLE="'"$path"'/cuda/nvcc"' +echo '######################################################################' diff --git a/scripts/get_python_cmake_flags.py b/scripts/get_python_cmake_flags.py new file mode 100644 index 0000000000000..a49debcc884ad --- /dev/null +++ b/scripts/get_python_cmake_flags.py @@ -0,0 +1,24 @@ +## @package get_python_cmake_flags +# Module scripts.get_python_cmake_flags +############################################################################## +# Use this script to find your preferred python installation. +############################################################################## +# +# You can use the following to build with your preferred version of python +# if your installation is not being properly detected by CMake. +# +# mkdir -p build && cd build +# cmake $(python ../scripts/get_python_cmake_flags.py) .. +# make +# + + +import sys +import sysconfig + + +flags = [ + f"-DPython_EXECUTABLE:FILEPATH={sys.executable}", +] + +print(" ".join(flags), end="") diff --git a/scripts/onnx/install-develop.sh b/scripts/onnx/install-develop.sh index 9875f88fff18a..9b7f2de742256 100755 --- a/scripts/onnx/install-develop.sh +++ b/scripts/onnx/install-develop.sh @@ -15,4 +15,8 @@ pip install --no-use-pep517 -e "$tp2_dir/onnx" # Install caffe2 and pytorch pip install -r "$top_dir/caffe2/requirements.txt" pip install -r "$top_dir/requirements.txt" +<<<<<<< HEAD python -m pip install --no-build-isolation -v -e . +======= +python setup.py develop +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/scripts/onnx/install.sh b/scripts/onnx/install.sh index 3204b3212b3a9..0b4a1ffd72b0c 100755 --- a/scripts/onnx/install.sh +++ b/scripts/onnx/install.sh @@ -35,4 +35,8 @@ _pip_install -b "$BUILD_DIR/onnx" "file://$tp2_dir/onnx#egg=onnx" # Install caffe2 and pytorch pip install -r "$top_dir/caffe2/requirements.txt" pip install -r "$top_dir/requirements.txt" +<<<<<<< HEAD python -m pip install --no-build-isolation -v . +======= +python setup.py install +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/scripts/release/README.md b/scripts/release/README.md index bc32bd0cb656c..ddde425de3009 100644 --- a/scripts/release/README.md +++ b/scripts/release/README.md @@ -10,7 +10,12 @@ These are a collection of scripts that are to be used for release activities. ### Order of Execution 1. Run cut-release-branch.sh to cut the release branch +<<<<<<< HEAD 2. Run apply-release-changes.sh to apply release only changes to create a PR with release only changes similar to this [PR](https://github.com/pytorch/pytorch/pull/149056) +======= +2. Run tag-docker-images.sh to tag current docker images with release tag and push them to docker.io. These images will be used to build the release. +3. Run apply-release-changes.sh to apply release only changes to create a PR with release only changes similar to this [PR](https://github.com/pytorch/pytorch/pull/149056) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #### Promoting packages diff --git a/scripts/release/tag-docker-images.sh b/scripts/release/tag-docker-images.sh new file mode 100644 index 0000000000000..f2299d6c463ee --- /dev/null +++ b/scripts/release/tag-docker-images.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash +# +# Step 1 after branch cut is complete. +# +# Tags latest docker images for release branch. +# In case of failure. The script can be rerun. +# +# Before executing this script do: +# 1. Create and Check out to Release Branch +# git checkout -b "${RELEASE_BRANCH}" +# 2. Update submodules +# git submodule update --init --recursive +# +# Usage (run from root of project): +# DRY_RUN=disabled ./scripts/release/tag-docker-images.sh +# + +set -eou pipefail + +GIT_TOP_DIR=$(git rev-parse --show-toplevel) +RELEASE_VERSION=${RELEASE_VERSION:-$(cut -d'.' -f1-2 "${GIT_TOP_DIR}/version.txt")} +DRY_RUN=${DRY_RUN:-enabled} + +python3 .github/scripts/tag_docker_images_for_release.py --version ${RELEASE_VERSION} --dry-run ${DRY_RUN} diff --git a/scripts/release_notes/README.md b/scripts/release_notes/README.md index c88533f937e7d..dd71c6e24fa74 100644 --- a/scripts/release_notes/README.md +++ b/scripts/release_notes/README.md @@ -130,7 +130,11 @@ This part is a little tedious but it seems to work. May want to explore using pa 5. Install the google doc extension [docs to markdown](https://github.com/evbacher/gd2md-html) 6. Start to compile back down these markdown files into a single markdown file. +<<<<<<< HEAD `TODO`: This is by far the most manual process and is ripe for automation. If the next person up would like to investigate Google Doc APIS there is some room for improvement here. +======= +`TODO`: This is by far the most manual process and is ripe for automation. If the next person up would like to investigate Google Doc APIS there is some room hor improvement here. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ### Part 4: Cherry Picks @@ -187,7 +191,11 @@ You will then create a release at [Pytorch Release](https://github.com/pytorch/p #### Tidbits You will probably have a release note that doesn't fit into the character limit of github. I used the following regex: +<<<<<<< HEAD `\[#(\d+)\]\(https://github.com/pytorch/pytorch/pull/\d+\)` to replace the full links to (#). +======= +`\[#(\d+)\]\(https://github.com/pytorch/pytorch/pull/\d+\)` to replace the full lunks to (#). +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) This will get formatted correctly in the github UI and can be checked when creating a draft release. diff --git a/scripts/remove_apache_header.sh b/scripts/remove_apache_header.sh new file mode 100755 index 0000000000000..97980bfbb0ef6 --- /dev/null +++ b/scripts/remove_apache_header.sh @@ -0,0 +1,13 @@ +if [[ "$1" == *.py ]]; then + apache_header="apache_python.txt" +else + apache_header="apache_header.txt" +fi +apache_lines=$(wc -l < "${apache_header}") +apache_md5=$(cat "${apache_header}" | md5) +header_md5=$(head -n ${apache_lines} $1 | md5) +if [ "${header_md5}" == "${apache_md5}" ]; then + keep_lines=$(($(wc -l < $1) - ${apache_lines})) + tail -n ${keep_lines} $1 > _remove_apache_header.txt + mv _remove_apache_header.txt $1 +fi diff --git a/scripts/temp.sh b/scripts/temp.sh new file mode 100755 index 0000000000000..18eb2b4733816 --- /dev/null +++ b/scripts/temp.sh @@ -0,0 +1,7 @@ +find ../caffe2 -name "*.py" -exec ./remove_apache_header.sh {} \; +find ../caffe2 -name "*.h" -exec ./remove_apache_header.sh {} \; +find ../caffe2 -name "*.cc" -exec ./remove_apache_header.sh {} \; +find ../caffe2 -name "*.cpp" -exec ./remove_apache_header.sh {} \; +find ../caffe2 -name "*.cu" -exec ./remove_apache_header.sh {} \; +find ../caffe2 -name "*.mm" -exec ./remove_apache_header.sh {} \; +find ../caffe2 -name "*.m" -exec ./remove_apache_header.sh {} \; diff --git a/scripts/xcode_build.rb b/scripts/xcode_build.rb new file mode 100644 index 0000000000000..0734167bdda11 --- /dev/null +++ b/scripts/xcode_build.rb @@ -0,0 +1,76 @@ +require 'optparse' +require 'xcodeproj' + +options = {} +option_parser = OptionParser.new do |opts| + opts.banner = 'Tools for building PyTorch iOS framework on MacOS' + opts.on('-i', '--install_path ', 'path to the cmake install folder') { |value| + options[:install] = value + } + opts.on('-x', '--xcodeproj_path ', 'path to the XCode project file') { |value| + options[:xcodeproj] = value + } + opts.on('-p', '--platform ', 'platform for the current build, OS or SIMULATOR') { |value| + options[:platform] = value + } +end.parse! +puts options.inspect + +install_path = File.expand_path(options[:install]) +if not Dir.exist? (install_path) + raise "path don't exist:#{install_path}!" +end +xcodeproj_path = File.expand_path(options[:xcodeproj]) +if not File.exist? (xcodeproj_path) + raise "path don't exist:#{xcodeproj_path}!" +end + +project = Xcodeproj::Project.open(xcodeproj_path) +target = project.targets.first #TestApp +header_search_path = ['$(inherited)', "#{install_path}/include"] +libraries_search_path = ['$(inherited)', "#{install_path}/lib"] +other_linker_flags = ['$(inherited)', "-all_load"] + +target.build_configurations.each do |config| + config.build_settings['HEADER_SEARCH_PATHS'] = header_search_path + config.build_settings['LIBRARY_SEARCH_PATHS'] = libraries_search_path + config.build_settings['OTHER_LDFLAGS'] = other_linker_flags + config.build_settings['ENABLE_BITCODE'] = 'No' +end + +# link static libraries +target.frameworks_build_phases.clear +libs = ['libc10.a', 'libclog.a', 'libpthreadpool.a', 'libXNNPACK.a', 'libmicrokernels-prod.a', 'libeigen_blas.a', 'libcpuinfo.a', 'libpytorch_qnnpack.a', 'libtorch_cpu.a', 'libtorch.a', 'libkineto.a'] +for lib in libs do + path = "#{install_path}/lib/#{lib}" + if File.exist?(path) + libref = project.frameworks_group.new_file(path) + target.frameworks_build_phases.add_file_reference(libref) + end +end +# link system frameworks +frameworks = ['CoreML', 'Metal', 'MetalPerformanceShaders', 'Accelerate', 'UIKit'] +if frameworks + frameworks.each do |framework| + path = "System/Library/Frameworks/#{framework}.framework" + framework_ref = project.frameworks_group.new_reference(path) + framework_ref.name = "#{framework}.framework" + framework_ref.source_tree = 'SDKROOT' + target.frameworks_build_phases.add_file_reference(framework_ref) + end +end +project.save + +sdk = nil +arch = nil +if options[:platform] == 'SIMULATOR' + sdk = 'iphonesimulator' + arch = 'arm64' +elsif options[:platform] == 'OS' + sdk = 'iphoneos' + arch = 'arm64' +else + raise "unsupported platform #{options[:platform]}" +end + +exec "xcodebuild clean build -project #{xcodeproj_path} -alltargets -sdk #{sdk} -configuration Release -arch #{arch}" diff --git a/setup.py b/setup.py index ae0097465da66..06418e4b78394 100644 --- a/setup.py +++ b/setup.py @@ -58,9 +58,12 @@ # USE_FBGEMM=0 # disables the FBGEMM build # +<<<<<<< HEAD # USE_FBGEMM_GENAI=0 # disables the FBGEMM GenAI build # +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # USE_KINETO=0 # disables usage of libkineto library for profiling # @@ -156,12 +159,15 @@ # USE_ROCM_KERNEL_ASSERT=1 # Enable kernel assert in ROCm platform # +<<<<<<< HEAD # USE_ROCM_CK_GEMM=1 # Enable building CK GEMM backend in ROCm platform # # USE_ROCM_CK_SDPA=1 # Enable building CK SDPA backend in ROCm platform # +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # PYTORCH_LAYERNORM_FAST_RECIPROCAL # If set, enables the use of builtin functions for fast reciprocals (1/x) w.r.t. # layer normalization. Default: enabled. @@ -239,11 +245,14 @@ # # BUILD_PYTHON_ONLY # Builds pytorch as a wheel using libtorch.so from a separate wheel +<<<<<<< HEAD # # USE_NIGHTLY=VERSION # Skip cmake build and instead download and extract nightly PyTorch wheel # matching the specified version (e.g., USE_NIGHTLY="2.8.0.dev20250608+cpu") # into the local directory for development use +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from __future__ import annotations @@ -253,15 +262,20 @@ if sys.platform == "win32" and sys.maxsize.bit_length() == 31: print( +<<<<<<< HEAD "32-bit Windows Python runtime is not supported. " "Please switch to 64-bit Python.", file=sys.stderr, +======= + "32-bit Windows Python runtime is not supported. Please switch to 64-bit Python." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) sys.exit(-1) import platform +<<<<<<< HEAD # Also update `project.requires-python` in pyproject.toml when changing this python_min_version = (3, 10, 0) python_min_version_str = ".".join(map(str, python_min_version)) @@ -270,17 +284,29 @@ f"You are using Python {platform.python_version()}. " f"Python >={python_min_version_str} is required.", file=sys.stderr, +======= +python_min_version = (3, 9, 0) +python_min_version_str = ".".join(map(str, python_min_version)) +if sys.version_info < python_min_version: + print( + f"You are using Python {platform.python_version()}. Python >={python_min_version_str} is required." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) sys.exit(-1) import filecmp import glob import importlib +<<<<<<< HEAD import itertools +======= +import importlib.util +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import json import shutil import subprocess import sysconfig +<<<<<<< HEAD import tempfile import textwrap import time @@ -327,6 +353,20 @@ IS_LINUX, IS_WINDOWS, ) +======= +import time +from collections import defaultdict + +import setuptools.command.build_ext +import setuptools.command.install +import setuptools.command.sdist +from setuptools import Extension, find_packages, setup +from setuptools.dist import Distribution +from tools.build_pytorch_libs import build_pytorch +from tools.generate_torch_version import get_torch_version +from tools.setup_helpers.cmake import CMake +from tools.setup_helpers.env import build_type, IS_DARWIN, IS_LINUX, IS_WINDOWS +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from tools.setup_helpers.generate_linker_script import gen_linker_script @@ -371,20 +411,32 @@ def str2bool(value: str | None) -> bool: raise ValueError(f"Invalid string value for boolean conversion: {value}") +<<<<<<< HEAD def _get_package_path(package_name: str) -> Path: from importlib.util import find_spec spec = find_spec(package_name) +======= +def _get_package_path(package_name): + spec = importlib.util.find_spec(package_name) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if spec: # The package might be a namespace package, so get_data may fail try: loader = spec.loader if loader is not None: file_path = loader.get_filename() # type: ignore[attr-defined] +<<<<<<< HEAD return Path(file_path).parent except AttributeError: pass return CWD / package_name +======= + return os.path.dirname(file_path) + except AttributeError: + pass + return None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) BUILD_LIBTORCH_WHL = str2bool(os.getenv("BUILD_LIBTORCH_WHL")) @@ -398,7 +450,11 @@ def _get_package_path(package_name: str) -> Path: if BUILD_PYTHON_ONLY: os.environ["BUILD_LIBTORCHLESS"] = "ON" +<<<<<<< HEAD os.environ["LIBTORCH_LIB_PATH"] = (_get_package_path("torch") / "lib").as_posix() +======= + os.environ["LIBTORCH_LIB_PATH"] = f"{_get_package_path('torch')}/lib" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ################################################################################ # Parameters parsed from environment @@ -409,8 +465,13 @@ def _get_package_path(package_name: str) -> Path: # see if the user passed a quiet flag to setup.py arguments and respect # that in our parts of the build EMIT_BUILD_WARNING = False +<<<<<<< HEAD RERUN_CMAKE = str2bool(os.environ.pop("CMAKE_FRESH", None)) CMAKE_ONLY = str2bool(os.environ.pop("CMAKE_ONLY", None)) +======= +RERUN_CMAKE = str2bool(os.getenv("CMAKE_FRESH")) +CMAKE_ONLY = str2bool(os.getenv("CMAKE_ONLY")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) filtered_args = [] for i, arg in enumerate(sys.argv): if arg == "--cmake": @@ -424,6 +485,7 @@ def _get_package_path(package_name: str) -> Path: if arg == "rebuild" or arg == "build": arg = "build" # rebuild is gone, make it build EMIT_BUILD_WARNING = True +<<<<<<< HEAD if arg == "develop": print( ( @@ -459,18 +521,25 @@ def _get_package_path(package_name: str) -> Path: env={**os.environ}, ) sys.exit(result.returncode) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if arg == "--": filtered_args += sys.argv[i:] break if arg == "-q" or arg == "--quiet": VERBOSE_SCRIPT = False +<<<<<<< HEAD if arg in ["clean", "dist_info", "egg_info", "sdist"]: +======= + if arg in ["clean", "egg_info", "sdist"]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) RUN_BUILD_DEPS = False filtered_args.append(arg) sys.argv = filtered_args if VERBOSE_SCRIPT: +<<<<<<< HEAD def report( *args: Any, file: IO[str] = sys.stderr, flush: bool = True, **kwargs: Any ) -> None: @@ -509,12 +578,46 @@ def report( CMAKE_PYTHON_LIBRARY = Path( sysconfig.get_config_var("LIBDIR") ) / sysconfig.get_config_var("INSTSONAME") +======= + def report(*args): + print(*args) + +else: + + def report(*args): + pass + + # Make distutils respect --quiet too + setuptools.distutils.log.warn = report + +# Constant known variables used throughout this file +cwd = os.path.dirname(os.path.abspath(__file__)) +lib_path = os.path.join(cwd, "torch", "lib") +third_party_path = os.path.join(cwd, "third_party") + +# CMAKE: full path to python library +if IS_WINDOWS: + cmake_python_library = "{}/libs/python{}.lib".format( + sysconfig.get_config_var("prefix"), sysconfig.get_config_var("VERSION") + ) + # Fix virtualenv builds + if not os.path.exists(cmake_python_library): + cmake_python_library = "{}/libs/python{}.lib".format( + sys.base_prefix, sysconfig.get_config_var("VERSION") + ) +else: + cmake_python_library = "{}/{}".format( + sysconfig.get_config_var("LIBDIR"), sysconfig.get_config_var("INSTSONAME") + ) +cmake_python_include_dir = sysconfig.get_path("include") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ################################################################################ # Version, create_version_file, and package_name ################################################################################ +<<<<<<< HEAD TORCH_PACKAGE_NAME = os.getenv("TORCH_PACKAGE_NAME", "torch") LIBTORCH_PKG_NAME = os.getenv("LIBTORCH_PACKAGE_NAME", "torch_no_python") if BUILD_LIBTORCH_WHL: @@ -522,14 +625,32 @@ def report( TORCH_VERSION = get_torch_version() report(f"Building wheel {TORCH_PACKAGE_NAME}-{TORCH_VERSION}") +======= +package_name = os.getenv("TORCH_PACKAGE_NAME", "torch") +LIBTORCH_PKG_NAME = os.getenv("LIBTORCH_PACKAGE_NAME", "torch_no_python") +if BUILD_LIBTORCH_WHL: + package_name = LIBTORCH_PKG_NAME + + +package_type = os.getenv("PACKAGE_TYPE", "wheel") +version = get_torch_version() +report(f"Building wheel {package_name}-{version}") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cmake = CMake() +<<<<<<< HEAD def get_submodule_folders() -> list[Path]: git_modules_file = CWD / ".gitmodules" default_modules_path = [ THIRD_PARTY_DIR / name +======= +def get_submodule_folders(): + git_modules_path = os.path.join(cwd, ".gitmodules") + default_modules_path = [ + os.path.join(third_party_path, name) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for name in [ "gloo", "cpuinfo", @@ -538,26 +659,46 @@ def get_submodule_folders() -> list[Path]: "cutlass", ] ] +<<<<<<< HEAD if not git_modules_file.exists(): return default_modules_path with git_modules_file.open(encoding="utf-8") as f: return [ CWD / line.partition("=")[-1].strip() +======= + if not os.path.exists(git_modules_path): + return default_modules_path + with open(git_modules_path) as f: + return [ + os.path.join(cwd, line.split("=", 1)[1].strip()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for line in f if line.strip().startswith("path") ] +<<<<<<< HEAD def check_submodules() -> None: def check_for_files(folder: Path, files: list[str]) -> None: if not any((folder / f).exists() for f in files): +======= +def check_submodules(): + def check_for_files(folder, files): + if not any(os.path.exists(os.path.join(folder, f)) for f in files): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) report("Could not find any of {} in {}".format(", ".join(files), folder)) report("Did you run 'git submodule update --init --recursive'?") sys.exit(1) +<<<<<<< HEAD def not_exists_or_empty(folder: Path) -> bool: return not folder.exists() or ( folder.is_dir() and next(folder.iterdir(), None) is None +======= + def not_exists_or_empty(folder): + return not os.path.exists(folder) or ( + os.path.isdir(folder) and len(os.listdir(folder)) == 0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if str2bool(os.getenv("USE_SYSTEM_LIBS")): @@ -569,7 +710,11 @@ def not_exists_or_empty(folder: Path) -> bool: report(" --- Trying to initialize submodules") start = time.time() subprocess.check_call( +<<<<<<< HEAD ["git", "submodule", "update", "--init", "--recursive"], cwd=CWD +======= + ["git", "submodule", "update", "--init", "--recursive"], cwd=cwd +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) end = time.time() report(f" --- Submodule initialization took {end - start:.2f} sec") @@ -590,18 +735,27 @@ def not_exists_or_empty(folder: Path) -> bool: ], ) check_for_files( +<<<<<<< HEAD THIRD_PARTY_DIR / "fbgemm" / "external" / "asmjit", +======= + os.path.join(third_party_path, "fbgemm", "external", "asmjit"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ["CMakeLists.txt"], ) # Windows has very bad support for symbolic links. # Instead of using symlinks, we're going to copy files over +<<<<<<< HEAD def mirror_files_into_torchgen() -> None: +======= +def mirror_files_into_torchgen(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # (new_path, orig_path) # Directories are OK and are recursively mirrored. paths = [ ( +<<<<<<< HEAD CWD / "torchgen/packaged/ATen/native/native_functions.yaml", CWD / "aten/src/ATen/native/native_functions.yaml", ), @@ -633,6 +787,27 @@ def mirror_files_into_torchgen() -> None: continue if orig_path.is_dir(): if new_path.exists(): +======= + "torchgen/packaged/ATen/native/native_functions.yaml", + "aten/src/ATen/native/native_functions.yaml", + ), + ("torchgen/packaged/ATen/native/tags.yaml", "aten/src/ATen/native/tags.yaml"), + ("torchgen/packaged/ATen/templates", "aten/src/ATen/templates"), + ("torchgen/packaged/autograd", "tools/autograd"), + ("torchgen/packaged/autograd/templates", "tools/autograd/templates"), + ] + for new_path, orig_path in paths: + # Create the dirs involved in new_path if they don't exist + if not os.path.exists(new_path): + os.makedirs(os.path.dirname(new_path), exist_ok=True) + + # Copy the files from the orig location to the new location + if os.path.isfile(orig_path): + shutil.copyfile(orig_path, new_path) + continue + if os.path.isdir(orig_path): + if os.path.exists(new_path): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # copytree fails if the tree exists already, so remove it. shutil.rmtree(new_path) shutil.copytree(orig_path, new_path) @@ -640,6 +815,7 @@ def mirror_files_into_torchgen() -> None: raise RuntimeError("Check the file paths in `mirror_files_into_torchgen()`") +<<<<<<< HEAD # ATTENTION: THIS IS AI SLOP def extract_variant_from_version(version: str) -> str: """Extract variant from version string, defaulting to 'cpu'.""" @@ -1012,6 +1188,18 @@ def build_deps() -> None: version=TORCH_VERSION, cmake_python_library=CMAKE_PYTHON_LIBRARY.as_posix(), build_python=not BUILD_LIBTORCH_WHL, +======= +# all the work we need to do _before_ setup runs +def build_deps(): + report("-- Building version " + version) + check_submodules() + check_pydep("yaml", "pyyaml") + build_python = not BUILD_LIBTORCH_WHL + build_pytorch( + version=version, + cmake_python_library=cmake_python_library, + build_python=build_python, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) rerun_cmake=RERUN_CMAKE, cmake_only=CMAKE_ONLY, cmake=cmake, @@ -1021,13 +1209,18 @@ def build_deps() -> None: report( 'Finished running cmake. Run "ccmake build" or ' '"cmake-gui build" to adjust build options and ' +<<<<<<< HEAD '"python -m pip install --no-build-isolation -v ." to build.' +======= + '"python setup.py install" to build.' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) sys.exit() # Use copies instead of symbolic files. # Windows has very poor support for them. sym_files = [ +<<<<<<< HEAD CWD / "tools/shared/_utils_internal.py", CWD / "torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h", CWD / "torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h", @@ -1044,6 +1237,24 @@ def build_deps() -> None: same = True else: sym_file.unlink() +======= + "tools/shared/_utils_internal.py", + "torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h", + "torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h", + ] + orig_files = [ + "torch/_utils_internal.py", + "third_party/valgrind-headers/callgrind.h", + "third_party/valgrind-headers/valgrind.h", + ] + for sym_file, orig_file in zip(sym_files, orig_files): + same = False + if os.path.exists(sym_file): + if filecmp.cmp(sym_file, orig_file): + same = True + else: + os.remove(sym_file) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not same: shutil.copyfile(orig_file, sym_file) @@ -1058,7 +1269,11 @@ def build_deps() -> None: """.strip() +<<<<<<< HEAD def check_pydep(importname: str, module: str) -> None: +======= +def check_pydep(importname, module): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: importlib.import_module(importname) except ImportError as e: @@ -1068,6 +1283,7 @@ def check_pydep(importname: str, module: str) -> None: class build_ext(setuptools.command.build_ext.build_ext): +<<<<<<< HEAD def _embed_libomp(self) -> None: # Copy libiomp5.dylib/libomp.dylib inside the wheel package on MacOS build_lib = Path(self.build_lib) @@ -1084,6 +1300,21 @@ def _embed_libomp(self) -> None: ) rpaths: list[str] = [] libs: list[str] = [] +======= + def _embed_libomp(self): + # Copy libiomp5.dylib/libomp.dylib inside the wheel package on MacOS + lib_dir = os.path.join(self.build_lib, "torch", "lib") + libtorch_cpu_path = os.path.join(lib_dir, "libtorch_cpu.dylib") + if not os.path.exists(libtorch_cpu_path): + return + # Parse libtorch_cpu load commands + otool_cmds = ( + subprocess.check_output(["otool", "-l", libtorch_cpu_path]) + .decode("utf-8") + .split("\n") + ) + rpaths, libs = [], [] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for idx, line in enumerate(otool_cmds): if line.strip() == "cmd LC_LOAD_DYLIB": lib_name = otool_cmds[idx + 2].strip() @@ -1095,9 +1326,14 @@ def _embed_libomp(self) -> None: assert rpath.startswith("path ") rpaths.append(rpath.split(" ", 1)[1].rsplit("(", 1)[0][:-1]) +<<<<<<< HEAD omplib_path: str = get_cmake_cache_vars()["OpenMP_libomp_LIBRARY"] # type: ignore[assignment] omplib_name: str = get_cmake_cache_vars()["OpenMP_C_LIB_NAMES"] # type: ignore[assignment] omplib_name += ".dylib" +======= + omplib_path = get_cmake_cache_vars()["OpenMP_libomp_LIBRARY"] + omplib_name = get_cmake_cache_vars()["OpenMP_C_LIB_NAMES"] + ".dylib" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) omplib_rpath_path = os.path.join("@rpath", omplib_name) # This logic is fragile and checks only two cases: @@ -1107,9 +1343,14 @@ def _embed_libomp(self) -> None: return # Copy libomp/libiomp5 from rpath locations +<<<<<<< HEAD target_lib = build_torch_lib_dir / omplib_name libomp_relocated = False install_name_tool_args: list[str] = [] +======= + target_lib = os.path.join(self.build_lib, "torch", "lib", omplib_name) + libomp_relocated = False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for rpath in rpaths: source_lib = os.path.join(rpath, omplib_name) if not os.path.exists(source_lib): @@ -1140,6 +1381,7 @@ def _embed_libomp(self) -> None: ] libomp_relocated = True if libomp_relocated: +<<<<<<< HEAD install_name_tool_args = [ "install_name_tool", *install_name_tool_args, @@ -1166,6 +1408,27 @@ def run(self) -> None: # Report build options. This is run after the build completes so # `CMakeCache.txt` exists # and we can get an accurate report on what is used and what is not. cmake_cache_vars = get_cmake_cache_vars() +======= + install_name_tool_args.insert(0, "install_name_tool") + install_name_tool_args.append(libtorch_cpu_path) + subprocess.check_call(install_name_tool_args) + # Copy omp.h from OpenMP_C_FLAGS and copy it into include folder + omp_cflags = get_cmake_cache_vars()["OpenMP_C_FLAGS"] + if not omp_cflags: + return + for include_dir in [f[2:] for f in omp_cflags.split(" ") if f.startswith("-I")]: + omp_h = os.path.join(include_dir, "omp.h") + if not os.path.exists(omp_h): + continue + target_omp_h = os.path.join(self.build_lib, "torch", "include", "omp.h") + self.copy_file(omp_h, target_omp_h) + break + + def run(self): + # Report build options. This is run after the build completes so # `CMakeCache.txt` exists and we can get an + # accurate report on what is used and what is not. + cmake_cache_vars = defaultdict(lambda: False, cmake.get_cmake_cache_variables()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if cmake_cache_vars["USE_NUMPY"]: report("-- Building with NumPy bindings") else: @@ -1173,17 +1436,31 @@ def run(self) -> None: if cmake_cache_vars["USE_CUDNN"]: report( "-- Detected cuDNN at " +<<<<<<< HEAD f"{cmake_cache_vars['CUDNN_LIBRARY']}, " f"{cmake_cache_vars['CUDNN_INCLUDE_DIR']}" +======= + + cmake_cache_vars["CUDNN_LIBRARY"] + + ", " + + cmake_cache_vars["CUDNN_INCLUDE_DIR"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) else: report("-- Not using cuDNN") if cmake_cache_vars["USE_CUDA"]: +<<<<<<< HEAD report(f"-- Detected CUDA at {cmake_cache_vars['CUDA_TOOLKIT_ROOT_DIR']}") else: report("-- Not using CUDA") if cmake_cache_vars["USE_XPU"]: report(f"-- Detected XPU runtime at {cmake_cache_vars['SYCL_LIBRARY_DIR']}") +======= + report("-- Detected CUDA at " + cmake_cache_vars["CUDA_TOOLKIT_ROOT_DIR"]) + else: + report("-- Not using CUDA") + if cmake_cache_vars["USE_XPU"]: + report("-- Detected XPU runtime at " + cmake_cache_vars["SYCL_LIBRARY_DIR"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: report("-- Not using XPU") if cmake_cache_vars["USE_MKLDNN"]: @@ -1202,9 +1479,16 @@ def run(self) -> None: report("-- Not using MKLDNN") if cmake_cache_vars["USE_NCCL"] and cmake_cache_vars["USE_SYSTEM_NCCL"]: report( +<<<<<<< HEAD "-- Using system provided NCCL library at " f"{cmake_cache_vars['NCCL_LIBRARIES']}, " f"{cmake_cache_vars['NCCL_INCLUDE_DIRS']}" +======= + "-- Using system provided NCCL library at {}, {}".format( + cmake_cache_vars["NCCL_LIBRARIES"], + cmake_cache_vars["NCCL_INCLUDE_DIRS"], + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) elif cmake_cache_vars["USE_NCCL"]: report("-- Building NCCL library") @@ -1215,15 +1499,29 @@ def run(self) -> None: report("-- Building without distributed package") else: report("-- Building with distributed package: ") +<<<<<<< HEAD report(f" -- USE_TENSORPIPE={cmake_cache_vars['USE_TENSORPIPE']}") report(f" -- USE_GLOO={cmake_cache_vars['USE_GLOO']}") report(f" -- USE_MPI={cmake_cache_vars['USE_OPENMPI']}") +======= + report( + " -- USE_TENSORPIPE={}".format(cmake_cache_vars["USE_TENSORPIPE"]) + ) + report(" -- USE_GLOO={}".format(cmake_cache_vars["USE_GLOO"])) + report(" -- USE_MPI={}".format(cmake_cache_vars["USE_OPENMPI"])) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: report("-- Building without distributed package") if cmake_cache_vars["STATIC_DISPATCH_BACKEND"]: report( +<<<<<<< HEAD "-- Using static dispatch with " f"backend {cmake_cache_vars['STATIC_DISPATCH_BACKEND']}" +======= + "-- Using static dispatch with backend {}".format( + cmake_cache_vars["STATIC_DISPATCH_BACKEND"] + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if cmake_cache_vars["USE_LIGHTWEIGHT_DISPATCH"]: report("-- Using lightweight dispatch") @@ -1233,19 +1531,39 @@ def run(self) -> None: else: report("-- Not using ITT") +<<<<<<< HEAD super().run() +======= + # Do not use clang to compile extensions if `-fstack-clash-protection` is defined + # in system CFLAGS + c_flags = str(os.getenv("CFLAGS", "")) + if ( + IS_LINUX + and "-fstack-clash-protection" in c_flags + and "clang" in os.environ.get("CC", "") + ): + os.environ["CC"] = str(os.environ["CC"]) + + # It's an old-style class in Python 2.7... + setuptools.command.build_ext.build_ext.run(self) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if IS_DARWIN: self._embed_libomp() # Copy the essential export library to compile C++ extensions. if IS_WINDOWS: +<<<<<<< HEAD build_temp = Path(self.build_temp) build_lib = Path(self.build_lib) +======= + build_temp = self.build_temp +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ext_filename = self.get_ext_filename("_C") lib_filename = ".".join(ext_filename.split(".")[:-1]) + ".lib" +<<<<<<< HEAD export_lib = build_temp / "torch" / "csrc" / lib_filename target_lib = build_lib / "torch" / "lib" / "_C.lib" @@ -1277,10 +1595,72 @@ def build_extensions(self) -> None: def get_outputs(self) -> list[str]: outputs = super().get_outputs() +======= + export_lib = os.path.join( + build_temp, "torch", "csrc", lib_filename + ).replace("\\", "/") + + build_lib = self.build_lib + + target_lib = os.path.join(build_lib, "torch", "lib", "_C.lib").replace( + "\\", "/" + ) + + # Create "torch/lib" directory if not exists. + # (It is not created yet in "develop" mode.) + target_dir = os.path.dirname(target_lib) + if not os.path.exists(target_dir): + os.makedirs(target_dir) + + self.copy_file(export_lib, target_lib) + + # In ROCm on Windows case copy rocblas and hipblaslt files into + # torch/lib/rocblas/library and torch/lib/hipblaslt/library + if str2bool(os.getenv("USE_ROCM")): + rocm_dir_path = os.environ.get("ROCM_DIR") + rocm_bin_path = os.path.join(rocm_dir_path, "bin") + + rocblas_dir = os.path.join(rocm_bin_path, "rocblas") + target_rocblas_dir = os.path.join(target_dir, "rocblas") + os.makedirs(target_rocblas_dir, exist_ok=True) + self.copy_tree(rocblas_dir, target_rocblas_dir) + + hipblaslt_dir = os.path.join(rocm_bin_path, "hipblaslt") + target_hipblaslt_dir = os.path.join(target_dir, "hipblaslt") + os.makedirs(target_hipblaslt_dir, exist_ok=True) + self.copy_tree(hipblaslt_dir, target_hipblaslt_dir) + else: + report("The specified environment variable does not exist.") + + def build_extensions(self): + self.create_compile_commands() + + # Copy functorch extension + for i, ext in enumerate(self.extensions): + if ext.name != "functorch._C": + continue + fullname = self.get_ext_fullname(ext.name) + filename = self.get_ext_filename(fullname) + fileext = os.path.splitext(filename)[1] + src = os.path.join(os.path.dirname(filename), "functorch" + fileext) + dst = os.path.join(os.path.realpath(self.build_lib), filename) + if os.path.exists(src): + report(f"Copying {ext.name} from {src} to {dst}") + dst_dir = os.path.dirname(dst) + if not os.path.exists(dst_dir): + os.makedirs(dst_dir) + self.copy_file(src, dst) + + setuptools.command.build_ext.build_ext.build_extensions(self) + + def get_outputs(self): + outputs = setuptools.command.build_ext.build_ext.get_outputs(self) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) outputs.append(os.path.join(self.build_lib, "caffe2")) report(f"setup.py::get_outputs returning {outputs}") return outputs +<<<<<<< HEAD def create_compile_commands(self) -> None: def load(file: Path) -> list[dict[str, Any]]: return json.loads(file.read_text(encoding="utf-8")) @@ -1292,6 +1672,16 @@ def load(file: Path) -> list[dict[str, Any]]: for f in itertools.chain(ninja_files, cmake_files) for entry in load(f) ] +======= + def create_compile_commands(self): + def load(filename): + with open(filename) as f: + return json.load(f) + + ninja_files = glob.glob("build/*compile_commands.json") + cmake_files = glob.glob("torch/lib/build/*/compile_commands.json") + all_commands = [entry for f in ninja_files + cmake_files for entry in load(f)] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # cquery does not like c++ compiles that start with gcc. # It forgets to include the c++ header directories. @@ -1303,11 +1693,20 @@ def load(file: Path) -> list[dict[str, Any]]: new_contents = json.dumps(all_commands, indent=2) contents = "" +<<<<<<< HEAD compile_commands_json = CWD / "compile_commands.json" if compile_commands_json.exists(): contents = compile_commands_json.read_text(encoding="utf-8") if contents != new_contents: compile_commands_json.write_text(new_contents, encoding="utf-8") +======= + if os.path.exists("compile_commands.json"): + with open("compile_commands.json") as f: + contents = f.read() + if contents != new_contents: + with open("compile_commands.json", "w") as f: + f.write(new_contents) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class concat_license_files: @@ -1319,6 +1718,7 @@ class concat_license_files: licensing info. """ +<<<<<<< HEAD def __init__(self, include_files: bool = False) -> None: self.f1 = CWD / "LICENSE" self.f2 = THIRD_PARTY_DIR / "LICENSES_BUNDLED.txt" @@ -1405,10 +1805,117 @@ def run(self) -> None: # Need to dump submodule hashes and create the proper LICENSE.txt for the sdist class sdist(setuptools.command.sdist.sdist): def run(self) -> None: +======= + def __init__(self, include_files=False): + self.f1 = "LICENSE" + self.f2 = "third_party/LICENSES_BUNDLED.txt" + self.include_files = include_files + + def __enter__(self): + """Concatenate files""" + + old_path = sys.path + sys.path.append(third_party_path) + try: + from build_bundled import create_bundled + finally: + sys.path = old_path + + with open(self.f1) as f1: + self.bsd_text = f1.read() + + with open(self.f1, "a") as f1: + f1.write("\n\n") + create_bundled( + os.path.relpath(third_party_path), f1, include_files=self.include_files + ) + + def __exit__(self, exception_type, exception_value, traceback): + """Restore content of f1""" + with open(self.f1, "w") as f: + f.write(self.bsd_text) + + +try: + from wheel.bdist_wheel import bdist_wheel +except ImportError: + # This is useful when wheel is not installed and bdist_wheel is not + # specified on the command line. If it _is_ specified, parsing the command + # line will fail before wheel_concatenate is needed + wheel_concatenate = None +else: + # Need to create the proper LICENSE.txt for the wheel + class wheel_concatenate(bdist_wheel): + """check submodules on sdist to prevent incomplete tarballs""" + + def run(self): + with concat_license_files(include_files=True): + super().run() + + def write_wheelfile(self, *args, **kwargs): + super().write_wheelfile(*args, **kwargs) + + if BUILD_LIBTORCH_WHL: + # Remove extraneneous files in the libtorch wheel + for root, dirs, files in os.walk(self.bdist_dir): + for file in files: + if file.endswith((".a", ".so")) and os.path.isfile( + os.path.join(self.bdist_dir, file) + ): + os.remove(os.path.join(root, file)) + elif file.endswith(".py"): + os.remove(os.path.join(root, file)) + # need an __init__.py file otherwise we wouldn't have a package + open(os.path.join(self.bdist_dir, "torch", "__init__.py"), "w").close() + + +class install(setuptools.command.install.install): + def run(self): + super().run() + + +class clean(setuptools.Command): + user_options = [] + + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def run(self): + import glob + import re + + with open(".gitignore") as f: + ignores = f.read() + pat = re.compile(r"^#( BEGIN NOT-CLEAN-FILES )?") + for wildcard in filter(None, ignores.split("\n")): + match = pat.match(wildcard) + if match: + if match.group(1): + # Marker is found and stop reading .gitignore. + break + # Ignore lines which begin with '#'. + else: + # Don't remove absolute paths from the system + wildcard = wildcard.lstrip("./") + + for filename in glob.glob(wildcard): + try: + os.remove(filename) + except OSError: + shutil.rmtree(filename, ignore_errors=True) + + +class sdist(setuptools.command.sdist.sdist): + def run(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with concat_license_files(): super().run() +<<<<<<< HEAD def get_cmake_cache_vars() -> defaultdict[str, CMakeValue]: try: return defaultdict(lambda: False, cmake.get_cmake_cache_variables()) @@ -1425,6 +1932,17 @@ def configure_extension_build() -> tuple[ dict[str, list[str]], # entry_points list[str], # extra_install_requires ]: +======= +def get_cmake_cache_vars(): + try: + return defaultdict(lambda: False, cmake.get_cmake_cache_variables()) + except FileNotFoundError: + # CMakeCache.txt does not exist. Probably running "python setup.py clean" over a clean directory. + return defaultdict(lambda: False) + + +def configure_extension_build(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r"""Configures extension build options according to system environment and user's choice. Returns: @@ -1437,17 +1955,30 @@ def configure_extension_build() -> tuple[ # Configure compile flags ################################################################################ +<<<<<<< HEAD library_dirs: list[str] = [str(TORCH_LIB_DIR)] extra_install_requires: list[str] = [] +======= + library_dirs = [] + extra_install_requires = [] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if IS_WINDOWS: # /NODEFAULTLIB makes sure we only link to DLL runtime # and matches the flags set for protobuf and ONNX +<<<<<<< HEAD extra_link_args: list[str] = ["/NODEFAULTLIB:LIBCMT.LIB"] # /MD links against DLL runtime # and matches the flags set for protobuf and ONNX # /EHsc is about standard C++ exception handling extra_compile_args: list[str] = ["/MD", "/FS", "/EHsc"] +======= + extra_link_args = ["/NODEFAULTLIB:LIBCMT.LIB"] + # /MD links against DLL runtime + # and matches the flags set for protobuf and ONNX + # /EHsc is about standard C++ exception handling + extra_compile_args = ["/MD", "/FS", "/EHsc"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: extra_link_args = [] extra_compile_args = [ @@ -1463,11 +1994,21 @@ def configure_extension_build() -> tuple[ "-fno-strict-aliasing", ] +<<<<<<< HEAD main_compile_args: list[str] = [] main_libraries: list[str] = ["torch_python"] main_link_args: list[str] = [] main_sources: list[str] = ["torch/csrc/stub.c"] +======= + library_dirs.append(lib_path) + + main_compile_args = [] + main_libraries = ["torch_python"] + + main_link_args = [] + main_sources = ["torch/csrc/stub.c"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if BUILD_LIBTORCH_WHL: main_libraries = ["torch"] @@ -1475,28 +2016,49 @@ def configure_extension_build() -> tuple[ if build_type.is_debug(): if IS_WINDOWS: +<<<<<<< HEAD extra_compile_args += ["/Z7"] extra_link_args += ["/DEBUG:FULL"] +======= + extra_compile_args.append("/Z7") + extra_link_args.append("/DEBUG:FULL") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: extra_compile_args += ["-O0", "-g"] extra_link_args += ["-O0", "-g"] if build_type.is_rel_with_deb_info(): if IS_WINDOWS: +<<<<<<< HEAD extra_compile_args += ["/Z7"] extra_link_args += ["/DEBUG:FULL"] +======= + extra_compile_args.append("/Z7") + extra_link_args.append("/DEBUG:FULL") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: extra_compile_args += ["-g"] extra_link_args += ["-g"] # pypi cuda package that requires installation of cuda runtime, cudnn and cublas # should be included in all wheels uploaded to pypi +<<<<<<< HEAD pytorch_extra_install_requires = os.getenv("PYTORCH_EXTRA_INSTALL_REQUIREMENTS") if pytorch_extra_install_requires: report(f"pytorch_extra_install_requirements: {pytorch_extra_install_requires}") extra_install_requires.extend( map(str.strip, pytorch_extra_install_requires.split("|")) ) +======= + pytorch_extra_install_requirements = os.getenv( + "PYTORCH_EXTRA_INSTALL_REQUIREMENTS", "" + ) + if pytorch_extra_install_requirements: + report( + f"pytorch_extra_install_requirements: {pytorch_extra_install_requirements}" + ) + extra_install_requires += pytorch_extra_install_requirements.split("|") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Cross-compile for M1 if IS_DARWIN: @@ -1519,7 +2081,11 @@ def configure_extension_build() -> tuple[ ] extra_link_args += ["-arch", macos_target_arch] +<<<<<<< HEAD def make_relative_rpath_args(path: str) -> list[str]: +======= + def make_relative_rpath_args(path): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if IS_DARWIN: return ["-Wl,-rpath,@loader_path/" + path] elif IS_WINDOWS: @@ -1531,6 +2097,7 @@ def make_relative_rpath_args(path: str) -> list[str]: # Declare extensions and package ################################################################################ +<<<<<<< HEAD ext_modules: list[Extension] = [] # packages that we want to install into site-packages and include them in wheels includes = ["torch", "torch.*", "torchgen", "torchgen.*"] @@ -1541,11 +2108,19 @@ def make_relative_rpath_args(path: str) -> list[str]: else: excludes.extend(["functorch", "functorch.*"]) packages = find_packages(include=includes, exclude=excludes) +======= + extensions = [] + excludes = ["tools", "tools.*", "caffe2", "caffe2.*"] + if not cmake_cache_vars["BUILD_FUNCTORCH"]: + excludes.extend(["functorch", "functorch.*"]) + packages = find_packages(exclude=excludes) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) C = Extension( "torch._C", libraries=main_libraries, sources=main_sources, language="c", +<<<<<<< HEAD extra_compile_args=[ *main_compile_args, *extra_compile_args, @@ -1559,16 +2134,38 @@ def make_relative_rpath_args(path: str) -> list[str]: ], ) ext_modules.append(C) +======= + extra_compile_args=main_compile_args + extra_compile_args, + include_dirs=[], + library_dirs=library_dirs, + extra_link_args=extra_link_args + + main_link_args + + make_relative_rpath_args("lib"), + ) + extensions.append(C) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # These extensions are built by cmake and copied manually in build_extensions() # inside the build_ext implementation if cmake_cache_vars["BUILD_FUNCTORCH"]: +<<<<<<< HEAD ext_modules.append(Extension(name="functorch._C", sources=[])) cmdclass = { "bdist_wheel": bdist_wheel, "build_ext": build_ext, "clean": clean, +======= + extensions.append( + Extension(name="functorch._C", sources=[]), + ) + + cmdclass = { + "bdist_wheel": wheel_concatenate, + "build_ext": build_ext, + "clean": clean, + "install": install, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "sdist": sdist, } @@ -1586,11 +2183,16 @@ def make_relative_rpath_args(path: str) -> list[str]: entry_points["console_scripts"].append( "torchfrtrace = tools.flight_recorder.fr_trace:main", ) +<<<<<<< HEAD return ext_modules, cmdclass, packages, entry_points, extra_install_requires +======= + return extensions, cmdclass, packages, entry_points, extra_install_requires +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # post run, warnings, printed at the end to make them more visible build_update_message = """ +<<<<<<< HEAD It is no longer necessary to use the 'build' or 'rebuild' targets To install: @@ -1619,17 +2221,54 @@ def main() -> None: "Set one to 0 and rerun." ) +======= + It is no longer necessary to use the 'build' or 'rebuild' targets + + To install: + $ python setup.py install + To develop locally: + $ python setup.py develop + To force cmake to re-generate native build files (off by default): + $ CMAKE_FRESH=1 python setup.py develop +""" + + +def print_box(msg): + lines = msg.split("\n") + size = max(len(l) + 1 for l in lines) + print("-" * (size + 2)) + for l in lines: + print("|{}{}|".format(l, " " * (size - len(l)))) + print("-" * (size + 2)) + + +def main(): + if BUILD_LIBTORCH_WHL and BUILD_PYTHON_ONLY: + raise RuntimeError( + "Conflict: 'BUILD_LIBTORCH_WHL' and 'BUILD_PYTHON_ONLY' can't both be 1. Set one to 0 and rerun." + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) install_requires = [ "filelock", "typing-extensions>=4.10.0", 'setuptools ; python_version >= "3.12"', "sympy>=1.13.3", +<<<<<<< HEAD "networkx>=2.5.1", "jinja2", "fsspec>=0.8.5", ] if BUILD_PYTHON_ONLY: install_requires += [f"{LIBTORCH_PKG_NAME}=={TORCH_VERSION}"] +======= + "networkx", + "jinja2", + "fsspec", + ] + + if BUILD_PYTHON_ONLY: + install_requires.append(f"{LIBTORCH_PKG_NAME}=={get_torch_version()}") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if str2bool(os.getenv("USE_PRIORITIZED_TEXT_FOR_LD")): gen_linker_script( @@ -1658,8 +2297,13 @@ def main() -> None: dist.script_args = sys.argv[1:] try: dist.parse_command_line() +<<<<<<< HEAD except setuptools.errors.BaseError as e: print(e, file=sys.stderr) +======= + except setuptools.distutils.errors.DistutilsArgError as e: + print(e) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sys.exit(1) mirror_files_into_torchgen() @@ -1667,7 +2311,11 @@ def main() -> None: build_deps() ( +<<<<<<< HEAD ext_modules, +======= + extensions, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cmdclass, packages, entry_points, @@ -1675,6 +2323,20 @@ def main() -> None: ) = configure_extension_build() install_requires += extra_install_requires +<<<<<<< HEAD +======= + extras_require = { + "optree": ["optree>=0.13.0"], + "opt-einsum": ["opt-einsum>=3.3"], + "pyyaml": ["pyyaml"], + } + + # Read in README.md for our long_description + with open(os.path.join(cwd, "README.md"), encoding="utf-8") as f: + long_description = f.read() + + version_range_max = max(sys.version_info[1], 13) + 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch_package_data = [ "py.typed", "bin/*", @@ -1693,12 +2355,18 @@ def main() -> None: "include/**/*.hpp", "include/*.cuh", "include/**/*.cuh", +<<<<<<< HEAD "csrc/inductor/aoti_runtime/model.h", "_inductor/codegen/*.h", "_inductor/codegen/aoti_runtime/*.h", "_inductor/codegen/aoti_runtime/*.cpp", "_inductor/script.ld", "_inductor/kernel/flex/templates/*.jinja", +======= + "_inductor/codegen/*.h", + "_inductor/codegen/aoti_runtime/*.cpp", + "_inductor/script.ld", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "_export/serde/*.yaml", "_export/serde/*.thrift", "share/cmake/ATen/*.cmake", @@ -1717,6 +2385,7 @@ def main() -> None: "utils/model_dump/code.js", "utils/model_dump/*.mjs", "_dynamo/graph_break_registry.json", +<<<<<<< HEAD "tools/dynamo/gb_id_mapping.py", ] @@ -1751,6 +2420,48 @@ def main() -> None: "include/kineto/*.h", "include/kineto/**/*.h", ] +======= + ] + + if not BUILD_LIBTORCH_WHL: + torch_package_data.extend( + [ + "lib/libtorch_python.so", + "lib/libtorch_python.dylib", + "lib/libtorch_python.dll", + ] + ) + if not BUILD_PYTHON_ONLY: + torch_package_data.extend( + [ + "lib/*.so*", + "lib/*.dylib*", + "lib/*.dll", + "lib/*.lib", + ] + ) + aotriton_image_path = os.path.join(lib_path, "aotriton.images") + aks2_files = [] + for root, dirs, files in os.walk(aotriton_image_path): + subpath = os.path.relpath(root, start=aotriton_image_path) + for fn in files: + aks2_files.append(os.path.join("lib/aotriton.images", subpath, fn)) + torch_package_data += aks2_files + if get_cmake_cache_vars()["USE_TENSORPIPE"]: + torch_package_data.extend( + [ + "include/tensorpipe/*.h", + "include/tensorpipe/**/*.h", + ] + ) + if get_cmake_cache_vars()["USE_KINETO"]: + torch_package_data.extend( + [ + "include/kineto/*.h", + "include/kineto/**/*.h", + ] + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torchgen_package_data = [ "packaged/*", "packaged/**/*", @@ -1758,6 +2469,7 @@ def main() -> None: package_data = { "torch": torch_package_data, } +<<<<<<< HEAD # some win libraries are excluded # these are statically linked exclude_windows_libs = [ @@ -1782,15 +2494,67 @@ def main() -> None: name=TORCH_PACKAGE_NAME, version=TORCH_VERSION, ext_modules=ext_modules, +======= + + if not BUILD_LIBTORCH_WHL: + package_data["torchgen"] = torchgen_package_data + else: + # no extensions in BUILD_LIBTORCH_WHL mode + extensions = [] + + setup( + name=package_name, + version=version, + description=( + "Tensors and Dynamic neural networks in Python with strong GPU acceleration" + ), + long_description=long_description, + long_description_content_type="text/markdown", + ext_modules=extensions, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cmdclass=cmdclass, packages=packages, entry_points=entry_points, install_requires=install_requires, +<<<<<<< HEAD package_data=package_data, exclude_package_data=exclude_package_data, # Disable automatic inclusion of data files because we want to # explicitly control with `package_data` above. include_package_data=False, +======= + extras_require=extras_require, + package_data=package_data, + # TODO fix later Manifest.IN file was previously ignored + include_package_data=False, # defaults to True with pyproject.toml file + url="https://pytorch.org/", + download_url="https://github.com/pytorch/pytorch/tags", + author="PyTorch Team", + author_email="packages@pytorch.org", + python_requires=f">={python_min_version_str}", + # PyPI package information. + classifiers=[ + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: BSD License", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Mathematics", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + "Programming Language :: C++", + "Programming Language :: Python :: 3", + ] + + [ + f"Programming Language :: Python :: 3.{i}" + for i in range(python_min_version[1], version_range_max) + ], + license="BSD-3-Clause", + keywords="pytorch, machine learning", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if EMIT_BUILD_WARNING: print_box(build_update_message) diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index 21335a3617b43..2154c0d00078e 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -2007,6 +2007,11 @@ "cast_symbool_to_symint_guardless", "constrain_range", "constrain_unify", +<<<<<<< HEAD +======= + "guard_or_true", + "guard_or_false", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "error", "eval_guards", "eval_is_non_overlapping_and_dense", @@ -2648,6 +2653,14 @@ "torch.export.graph_signature": [ "TokenArgument" ], +<<<<<<< HEAD +======= + "torch.export.pt2_archive": [ + "PT2ArchiveWriter", + "PT2ArchiveReader", + "is_pt2_package" + ], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "torch.fx.experimental.shape_inference.infer_shape": [ "DimDynamic", "FakeTensorMode", diff --git a/test/ao/sparsity/test_activation_sparsifier.py b/test/ao/sparsity/test_activation_sparsifier.py index 8e1525b858795..a5d93c9a47e9c 100644 --- a/test/ao/sparsity/test_activation_sparsifier.py +++ b/test/ao/sparsity/test_activation_sparsifier.py @@ -50,7 +50,11 @@ def _check_constructor(self, activation_sparsifier, model, defaults, sparse_conf sparsifier_defaults = activation_sparsifier.defaults combined_defaults = {**defaults, "sparse_config": sparse_config} +<<<<<<< HEAD # more keys are populated in activation sparsifier (even though they may be None) +======= + # more keys are populated in activation sparsifier (eventhough they may be None) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert len(combined_defaults) <= len(activation_sparsifier.defaults) for key, config in sparsifier_defaults.items(): diff --git a/test/ao/sparsity/test_composability.py b/test/ao/sparsity/test_composability.py index 528fe9b83c65b..108128e80f687 100644 --- a/test/ao/sparsity/test_composability.py +++ b/test/ao/sparsity/test_composability.py @@ -411,6 +411,10 @@ def test_q_prep_fx_before_s_prep(self): ) self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"]) +<<<<<<< HEAD +======= + @xfailIfS390X +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_q_prep_fx_s_prep_ref_conv(self): r""" This checks that the ordering: prepare_fx -> sparse prepare -> convert_to_reference_fx @@ -585,6 +589,10 @@ def test_s_prep_before_qat_prep_fx(self): ) self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"]) +<<<<<<< HEAD +======= + @xfailIfS390X +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_s_prep_q_prep_fx_ref(self): r""" This checks that the ordering: sparse prepare -> prepare_fx -> convert_to_reference_fx diff --git a/test/ao/sparsity/test_data_sparsifier.py b/test/ao/sparsity/test_data_sparsifier.py index 5217049aafdfd..bcd7858b82395 100644 --- a/test/ao/sparsity/test_data_sparsifier.py +++ b/test/ao/sparsity/test_data_sparsifier.py @@ -265,7 +265,11 @@ def check_memory_reference(self, data_list, data_with_config, defaults, **kwargs class _NormDataSparsifierTestCase(_BaseDataSparsiferTestCase): r"""This helper test class takes in any supported type of and runs some tests. This inherits the TestBaseDataSparsifierRuner wherein some functions are +<<<<<<< HEAD over-ridden to take accommodate the specific sparsifier. +======= + over-ridden to take accomodate the specific sparsifier. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TODO: Change the structure by creating a separate test case class for each member function """ @@ -770,7 +774,11 @@ def test_ptq_quantize_first(self): # higher threshold as quantization occurs before sparsity threshold = ( +<<<<<<< HEAD 1 # zero points seem to have higher magnitude with sparsity occurring after +======= + 1 # zero points seem to have higher magnitude with sparsity occuring after +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) sl_emb1 = (torch.abs(dequant_emb1) < threshold).float().mean() diff --git a/test/ao/sparsity/test_scheduler.py b/test/ao/sparsity/test_scheduler.py index b563efac73bd7..88fa8fb3543e5 100644 --- a/test/ao/sparsity/test_scheduler.py +++ b/test/ao/sparsity/test_scheduler.py @@ -188,7 +188,11 @@ def test_step(self): self.assertEqual( self._get_sparsity_levels(sparsifier), self.sorted_sparse_levels, +<<<<<<< HEAD msg="Sparsity level is not reaching the target level after delta_t * n steps ", +======= + msg="Sparsity level is not reaching the target level afer delta_t * n steps ", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) diff --git a/test/bench_mps_ops.py b/test/bench_mps_ops.py index e81fb555c848a..f786622b6509e 100644 --- a/test/bench_mps_ops.py +++ b/test/bench_mps_ops.py @@ -71,6 +71,7 @@ def bench_binary( return rc +<<<<<<< HEAD def check_eager_vs_compile(rc_c, rc_e, func, dtype): if not torch.allclose(rc_c, rc_e): mdiff = (rc_c - rc_e).abs().max() @@ -80,6 +81,8 @@ def check_eager_vs_compile(rc_c, rc_e, func, dtype): ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def bench_reduction( reduction_func, device: str = "mps", dtype: torch.dtype = torch.float32 ) -> list[Measurement]: @@ -90,23 +93,40 @@ def f(t): return reduction_func(t, dim=0) f.__name__ = reduction_func.__name__ +<<<<<<< HEAD f_c = torch.compile(f, dynamic=False, fullgraph=True) +======= + f_c = torch.compile(f, dynamic=False) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for size in (512, 1024, 2048, 4096): x = torch.testing.make_tensor(size, size, device=device, dtype=dtype) rc_c, rc_e = f(x), f_c(x) rc_c, rc_e = (rc_c[0], rc_e[0]) if isinstance(rc_c, tuple) else (rc_c, rc_e) +<<<<<<< HEAD check_eager_vs_compile(rc_c, rc_e, reduction_func, dtype) +======= + if not torch.allclose(rc_c, rc_e): + mdiff = (rc_c - rc_e).abs().max() + warnings.warn( + f"Eager and compile reduction do not match for {reduction_func.__name__} and {dtype} max_diff={mdiff}", + stacklevel=2, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) rc.append(bench_unary_op(f, x, f"eager-{size}x{size}")) rc.append(bench_unary_op(f_c, x, f"compile-{size}x{size}")) return rc def bench_scan( +<<<<<<< HEAD scan_func, device: str = "mps", dtype: torch.dtype = torch.float32, with_indices: bool = False, +======= + scan_func, device: str = "mps", dtype: torch.dtype = torch.float32 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> list[Measurement]: rc = [] @@ -116,18 +136,31 @@ def bench_scan( def f(t): return scan_func(t, dim=dim) +<<<<<<< HEAD f_c = torch.compile(f, dynamic=False, fullgraph=True) +======= + f_c = torch.compile(f, dynamic=False) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for size in (32, 128, 512, 1024): f.__name__ = f"{scan_func.__name__}-dim{dim}-{size}x{size}" f_c.__name__ = f.__name__ x = torch.testing.make_tensor(size, size, device=device, dtype=dtype) rc_c, rc_e = f(x), f_c(x) +<<<<<<< HEAD if with_indices: check_eager_vs_compile(rc_c[0], rc_e[0], scan_func, dtype) check_eager_vs_compile(rc_c[1], rc_e[1], scan_func, dtype) else: check_eager_vs_compile(rc_c, rc_e, scan_func, dtype) +======= + if not torch.allclose(rc_c, rc_e): + mdiff = (rc_c - rc_e).abs().max() + warnings.warn( + f"Eager and compile scan do not match for {scan_func.__name__} dim={dim} and {dtype} max_diff={mdiff}", + stacklevel=2, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) rc.append(bench_unary_op(f, x, "eager")) rc.append(bench_unary_op(f_c, x, "compile")) @@ -135,18 +168,31 @@ def f(t): def f_1d(t): return scan_func(t, dim=0) +<<<<<<< HEAD f_1d_c = torch.compile(f_1d, dynamic=False, fullgraph=True) +======= + f_1d_c = torch.compile(f_1d, dynamic=False) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for size in (100, 10000, 1000000): f_1d.__name__ = f"{scan_func.__name__}-1d-{size}" f_1d_c.__name__ = f_1d.__name__ x = torch.testing.make_tensor(size, device=device, dtype=dtype) rc_c, rc_e = f_1d(x), f_1d_c(x) +<<<<<<< HEAD if with_indices: check_eager_vs_compile(rc_c[0], rc_e[0], scan_func, dtype) check_eager_vs_compile(rc_c[1], rc_e[1], scan_func, dtype) else: check_eager_vs_compile(rc_c, rc_e, scan_func, dtype) +======= + if not torch.allclose(rc_c, rc_e): + mdiff = (rc_c - rc_e).abs().max() + warnings.warn( + f"Eager and compile 1D scan do not match for {scan_func.__name__} and {dtype} max_diff={mdiff}", + stacklevel=2, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) rc.append(bench_unary_op(f_1d, x, "eager")) rc.append(bench_unary_op(f_1d_c, x, "compile")) @@ -154,7 +200,13 @@ def f_1d(t): def main() -> None: +<<<<<<< HEAD dtypes = [torch.float16, torch.float32, torch.bfloat16] +======= + dtypes = [torch.float16, torch.float32] + if torch.backends.mps.is_macos_or_newer(14, 0): + dtypes.append(torch.bfloat16) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Profile index ops B = 11 @@ -185,12 +237,15 @@ def main() -> None: rc.extend(bench_scan(torch.cumsum, dtype=dtype)) Compare(rc).print() +<<<<<<< HEAD # Profile scan with indices ops (cummin) rc = [] for dtype in dtypes: rc.extend(bench_scan(torch.cummin, dtype=dtype, with_indices=True)) Compare(rc).print() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Profile binary ops rc = [] ops = [torch.fmax, torch.add] @@ -202,5 +257,8 @@ def main() -> None: if __name__ == "__main__": +<<<<<<< HEAD torch._dynamo.config.cache_size_limit = 2**16 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) main() diff --git a/test/benchmark_utils/test_benchmark_utils.py b/test/benchmark_utils/test_benchmark_utils.py index f9120c26a132f..e272796212da5 100644 --- a/test/benchmark_utils/test_benchmark_utils.py +++ b/test/benchmark_utils/test_benchmark_utils.py @@ -699,16 +699,24 @@ def custom_transforms(fn: str): 8959166 /tmp/build/80754af9/python_15996 ... a3/envs/throwaway/bin/python3.6] ... 92821 /tmp/build/80754af9/python_15996 ... a3/envs/throwaway/bin/python3.6] +<<<<<<< HEAD 91000 build/../torch/csrc/tensor/pytho ... ch/torch/lib/libtorch_python.so] # codespell:ignore +======= + 91000 build/../torch/csrc/tensor/pytho ... ch/torch/lib/libtorch_python.so] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) 91000 /data/users/test_user/repos/pyto ... nsors::get_default_scalar_type() 90090 ???:pthread_mutex_lock [/usr/lib64/libpthread-2.28.so] 90000 build/../c10/core/TensorImpl.h:c ... ch/torch/lib/libtorch_python.so] 90000 build/../aten/src/ATen/record_fu ... torch/torch/lib/libtorch_cpu.so] 90000 /data/users/test_user/repos/pyto ... uard(std::optional) 90000 /data/users/test_user/repos/pyto ... ersionCounter::~VersionCounter() +<<<<<<< HEAD 88000 /data/users/test_user/repos/pyto ... ratorKernel*, at::Tensor const&)""".replace( " # codespell:ignore", "" ), +======= + 88000 /data/users/test_user/repos/pyto ... ratorKernel*, at::Tensor const&)""", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) self.regularizeAndAssertExpectedInline( diff --git a/test/cpp/aoti_abi_check/CMakeLists.txt b/test/cpp/aoti_abi_check/CMakeLists.txt index 6898e406fb3bd..e5b9252f4574a 100644 --- a/test/cpp/aoti_abi_check/CMakeLists.txt +++ b/test/cpp/aoti_abi_check/CMakeLists.txt @@ -5,11 +5,15 @@ set(AOTI_ABI_CHECK_TEST_SRCS ${AOTI_ABI_CHECK_TEST_ROOT}/main.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_cast.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_dtype.cpp +<<<<<<< HEAD ${AOTI_ABI_CHECK_TEST_ROOT}/test_exception.cpp +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ${AOTI_ABI_CHECK_TEST_ROOT}/test_macros.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_math.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_rand.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_vec.cpp +<<<<<<< HEAD ${AOTI_ABI_CHECK_TEST_ROOT}/test_vec_half.cpp ) @@ -17,6 +21,8 @@ set(AOTI_ABI_CHECK_TEST_SRCS # You may think test_vec.cpp needs to be in there, but it does not. set(AOTI_ABI_CHECK_VEC_TEST_SRCS ${AOTI_ABI_CHECK_TEST_ROOT}/test_vec_half.cpp +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) add_executable(test_aoti_abi_check @@ -31,6 +37,7 @@ target_compile_definitions(test_aoti_abi_check PRIVATE USE_GTEST) target_link_libraries(test_aoti_abi_check PRIVATE gtest_main) target_include_directories(test_aoti_abi_check PRIVATE ${ATen_CPU_INCLUDE}) +<<<<<<< HEAD foreach(test_src ${AOTI_ABI_CHECK_VEC_TEST_SRCS}) foreach(i RANGE ${NUM_CPU_CAPABILITY_NAMES}) get_filename_component(test_name ${test_src} NAME_WE) @@ -48,6 +55,8 @@ foreach(test_src ${AOTI_ABI_CHECK_VEC_TEST_SRCS}) endforeach() endforeach() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if(INSTALL_TEST) install(TARGETS test_aoti_abi_check DESTINATION bin) # Install PDB files for MSVC builds diff --git a/test/cpp/aoti_abi_check/test_dtype.cpp b/test/cpp/aoti_abi_check/test_dtype.cpp index e6e7e75867c8d..d993fcc3c086d 100644 --- a/test/cpp/aoti_abi_check/test_dtype.cpp +++ b/test/cpp/aoti_abi_check/test_dtype.cpp @@ -1,5 +1,6 @@ #include +<<<<<<< HEAD #include #include #include @@ -24,6 +25,27 @@ TEST(TestDtype, TestBFloat16) { torch::headeronly::BFloat16 sub = -1.0f; torch::headeronly::BFloat16 mul = 2.0f; torch::headeronly::BFloat16 div = 0.5f; +======= +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace aot_inductor { + +TEST(TestDtype, TestBFloat16) { + c10::BFloat16 a = 1.0f; + c10::BFloat16 b = 2.0f; + c10::BFloat16 add = 3.0f; + c10::BFloat16 sub = -1.0f; + c10::BFloat16 mul = 2.0f; + c10::BFloat16 div = 0.5f; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) EXPECT_EQ(a + b, add); EXPECT_EQ(a - b, sub); @@ -32,12 +54,21 @@ TEST(TestDtype, TestBFloat16) { } TEST(TestDtype, TestFloat8_e4m3fn) { +<<<<<<< HEAD torch::headeronly::Float8_e4m3fn a = 1.0f; torch::headeronly::Float8_e4m3fn b = 2.0f; torch::headeronly::Float8_e4m3fn add = 3.0f; torch::headeronly::Float8_e4m3fn sub = -1.0f; torch::headeronly::Float8_e4m3fn mul = 2.0f; torch::headeronly::Float8_e4m3fn div = 0.5f; +======= + c10::Float8_e4m3fn a = 1.0f; + c10::Float8_e4m3fn b = 2.0f; + c10::Float8_e4m3fn add = 3.0f; + c10::Float8_e4m3fn sub = -1.0f; + c10::Float8_e4m3fn mul = 2.0f; + c10::Float8_e4m3fn div = 0.5f; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) EXPECT_EQ(a + b, add); EXPECT_EQ(a - b, sub); @@ -46,12 +77,21 @@ TEST(TestDtype, TestFloat8_e4m3fn) { } TEST(TestDtype, TestFloat8_e4m3fuz) { +<<<<<<< HEAD torch::headeronly::Float8_e4m3fnuz a = 1.0f; torch::headeronly::Float8_e4m3fnuz b = 2.0f; torch::headeronly::Float8_e4m3fnuz add = 3.0f; torch::headeronly::Float8_e4m3fnuz sub = -1.0f; torch::headeronly::Float8_e4m3fnuz mul = 2.0f; torch::headeronly::Float8_e4m3fnuz div = 0.5f; +======= + c10::Float8_e4m3fnuz a = 1.0f; + c10::Float8_e4m3fnuz b = 2.0f; + c10::Float8_e4m3fnuz add = 3.0f; + c10::Float8_e4m3fnuz sub = -1.0f; + c10::Float8_e4m3fnuz mul = 2.0f; + c10::Float8_e4m3fnuz div = 0.5f; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) EXPECT_EQ(a + b, add); EXPECT_EQ(a - b, sub); @@ -60,12 +100,21 @@ TEST(TestDtype, TestFloat8_e4m3fuz) { } TEST(TestDtype, TestFloat8_e5m2) { +<<<<<<< HEAD torch::headeronly::Float8_e5m2 a = 1.0f; torch::headeronly::Float8_e5m2 b = 2.0f; torch::headeronly::Float8_e5m2 add = 3.0f; torch::headeronly::Float8_e5m2 sub = -1.0f; torch::headeronly::Float8_e5m2 mul = 2.0f; torch::headeronly::Float8_e5m2 div = 0.5f; +======= + c10::Float8_e5m2 a = 1.0f; + c10::Float8_e5m2 b = 2.0f; + c10::Float8_e5m2 add = 3.0f; + c10::Float8_e5m2 sub = -1.0f; + c10::Float8_e5m2 mul = 2.0f; + c10::Float8_e5m2 div = 0.5f; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) EXPECT_EQ(a + b, add); EXPECT_EQ(a - b, sub); @@ -74,12 +123,21 @@ TEST(TestDtype, TestFloat8_e5m2) { } TEST(TestDtype, TestFloat8_e5m2fnuz) { +<<<<<<< HEAD torch::headeronly::Float8_e5m2fnuz a = 1.0f; torch::headeronly::Float8_e5m2fnuz b = 2.0f; torch::headeronly::Float8_e5m2fnuz add = 3.0f; torch::headeronly::Float8_e5m2fnuz sub = -1.0f; torch::headeronly::Float8_e5m2fnuz mul = 2.0f; torch::headeronly::Float8_e5m2fnuz div = 0.5f; +======= + c10::Float8_e5m2fnuz a = 1.0f; + c10::Float8_e5m2fnuz b = 2.0f; + c10::Float8_e5m2fnuz add = 3.0f; + c10::Float8_e5m2fnuz sub = -1.0f; + c10::Float8_e5m2fnuz mul = 2.0f; + c10::Float8_e5m2fnuz div = 0.5f; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) EXPECT_EQ(a + b, add); EXPECT_EQ(a - b, sub); @@ -87,6 +145,7 @@ TEST(TestDtype, TestFloat8_e5m2fnuz) { EXPECT_EQ(a / b, div); } +<<<<<<< HEAD TEST(TestDtype, TestFloat8_e8m0fnu) { torch::headeronly::Float8_e8m0fnu a = 1.0f; ASSERT_FALSE(a.isnan()); @@ -104,11 +163,21 @@ TEST(TestDtype, TestHalf) { torch::headeronly::Half sub = -1.0f; torch::headeronly::Half mul = 2.0f; torch::headeronly::Half div = 0.5f; +======= +TEST(TestDtype, TestHalf) { + c10::Half a = 1.0f; + c10::Half b = 2.0f; + c10::Half add = 3.0f; + c10::Half sub = -1.0f; + c10::Half mul = 2.0f; + c10::Half div = 0.5f; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) EXPECT_EQ(a + b, add); EXPECT_EQ(a - b, sub); EXPECT_EQ(a * b, mul); EXPECT_EQ(a / b, div); +<<<<<<< HEAD EXPECT_EQ(a += b, add); EXPECT_EQ(a -= b, add - b); EXPECT_EQ(a *= b, b); @@ -129,6 +198,17 @@ TEST(TestDtype, TestComplexFloat) { torch::headeronly::complex sub(std::complex(-2.0f, -2.0f)); torch::headeronly::complex mul(std::complex(-5.0f, 10.0f)); torch::headeronly::complex div(std::complex(0.44f, 0.08f)); +======= +} + +TEST(TestDtype, TestComplexFloat) { + c10::complex a(std::complex(1.0f, 2.0f)); + c10::complex b(std::complex(3.0f, 4.0f)); + c10::complex add(std::complex(4.0f, 6.0f)); + c10::complex sub(std::complex(-2.0f, -2.0f)); + c10::complex mul(std::complex(-5.0f, 10.0f)); + c10::complex div(std::complex(0.44f, 0.08f)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) EXPECT_EQ(a + b, add); EXPECT_EQ(a - b, sub); @@ -136,6 +216,7 @@ TEST(TestDtype, TestComplexFloat) { EXPECT_EQ(a / b, div); } +<<<<<<< HEAD TEST(TestDtype, TestQuintsQintsAndBits) { // There's not much you can do with these dtypes... // so we'll just check that it compiles @@ -207,3 +288,7 @@ TEST(TestDtype, TestScalarType) { EXPECT_EQ(static_cast(i), expected_scalar_types[i]); } } +======= +} // namespace aot_inductor +} // namespace torch +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/test/cpp/aoti_inference/CMakeLists.txt b/test/cpp/aoti_inference/CMakeLists.txt index cd87ba6c5053d..041f39aa34764 100644 --- a/test/cpp/aoti_inference/CMakeLists.txt +++ b/test/cpp/aoti_inference/CMakeLists.txt @@ -18,9 +18,14 @@ add_custom_command( OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/script_data.pt ${CMAKE_CURRENT_BINARY_DIR}/script_model_cpu.pt ${CMAKE_CURRENT_BINARY_DIR}/script_model_cuda.pt +<<<<<<< HEAD # This script requires the torch package to be installed. COMMAND python ${AOT_INDUCTOR_TEST_ROOT}/compile_model.py DEPENDS torch torch_python aoti_custom_class ${AOT_INDUCTOR_TEST_ROOT}/compile_model.py +======= + COMMAND python ${AOT_INDUCTOR_TEST_ROOT}/compile_model.py + DEPENDS compile_model.py +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) add_custom_target(aoti_script_model ALL DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/script_data.pt diff --git a/test/cpp/aoti_inference/test.cpp b/test/cpp/aoti_inference/test.cpp index cf606d242d9fe..65e5e2b81acb0 100644 --- a/test/cpp/aoti_inference/test.cpp +++ b/test/cpp/aoti_inference/test.cpp @@ -144,8 +144,11 @@ void test_aoti_package_loader_multi_gpu( const std::string& device, bool use_runtime_constant_folding) { torch::NoGradGuard no_grad; +<<<<<<< HEAD // Ensure that this test will reset the default CUDA device on exit. torch::DeviceGuard device_guard(c10::Device("cuda")); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::string data_path = (std::filesystem::path(STRINGIZE(CMAKE_CURRENT_BINARY_DIR)) / "data.pt") @@ -855,6 +858,7 @@ void test_aoti_free_buffer(bool use_runtime_constant_folding) { } } +<<<<<<< HEAD #if defined(USE_CUDA) || defined(USE_ROCM) void test_cuda_alloc_test() { torch::NoGradGuard no_grad; @@ -894,6 +898,8 @@ void test_cuda_alloc_test() { } #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class ThreadPool { private: struct Task { @@ -1112,11 +1118,14 @@ TEST(AotInductorTest, FreeInactiveConstantBufferRuntimeConstantFoldingCuda) { TEST(AotInductorTest, MultiStreamTestCuda) { test_multi_cuda_streams("cuda"); } +<<<<<<< HEAD // TODO: ENABLE CUDACachingAllocator Test TEST(DISABLED_AotInductorTest, CudaAllocTestCuda) { test_cuda_alloc_test(); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif } // namespace torch::aot_inductor diff --git a/test/cpp/aoti_inference/test.py b/test/cpp/aoti_inference/test.py index 756fd4a172b87..b0cddb3f6379e 100644 --- a/test/cpp/aoti_inference/test.py +++ b/test/cpp/aoti_inference/test.py @@ -1,5 +1,8 @@ import torch +<<<<<<< HEAD import torch._inductor.config +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._export import aot_compile from torch.export import Dim @@ -32,7 +35,10 @@ def forward(self, x, y): data = {} large_data = {} +<<<<<<< HEAD cuda_alloc_data = {} +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) data_with_tensor_constants = {} @@ -87,6 +93,7 @@ def generate_basic_tests(): ) +<<<<<<< HEAD def generate_basic_tests_consts_cpp(): backup_consts_asm_cfg: bool = ( torch._inductor.config.aot_inductor.use_consts_asm_build @@ -99,6 +106,8 @@ def generate_basic_tests_consts_cpp(): torch._inductor.config.aot_inductor.use_consts_asm_build = backup_consts_asm_cfg +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def generate_large_tests(): device = "cuda" model = Net(device, size=4096).to(device=device) @@ -140,6 +149,7 @@ def generate_large_tests(): ) +<<<<<<< HEAD def generate_cuda_alloc_test(): device = "cuda" model = Net(device, size=4096).to(device=device) @@ -166,6 +176,8 @@ def generate_cuda_alloc_test(): ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # AOTI model which will create additional tensors during autograd. def generate_test_with_additional_tensors(): if not torch.cuda.is_available(): @@ -197,10 +209,15 @@ def generate_test_with_additional_tensors(): generate_basic_tests() +<<<<<<< HEAD generate_basic_tests_consts_cpp() generate_large_tests() generate_test_with_additional_tensors() generate_cuda_alloc_test() +======= +generate_large_tests() +generate_test_with_additional_tensors() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Use this to communicate tensors to the cpp code @@ -216,4 +233,7 @@ def __init__(self, data): torch.jit.script(Serializer(data_with_tensor_constants)).save( "data_with_tensor_constants.pt" ) +<<<<<<< HEAD torch.jit.script(Serializer(cuda_alloc_data)).save("cuda_alloc_data.pt") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/test/cpp/api/tensor_cuda.cpp b/test/cpp/api/tensor_cuda.cpp index 1c48a33fb7c0f..9a39a2dd8bf5d 100644 --- a/test/cpp/api/tensor_cuda.cpp +++ b/test/cpp/api/tensor_cuda.cpp @@ -1,8 +1,11 @@ #include #include +<<<<<<< HEAD #include #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include @@ -126,6 +129,7 @@ TEST(TensorTest, MagmaInitializesCorrectly_CUDA) { at::inverse(tensor); } } +<<<<<<< HEAD #ifdef USE_CUDA #include @@ -186,3 +190,5 @@ TEST(CuDNNBatchNormTest, OutVariantMatchesFunctional) { } #endif // AT_CUDNN_ENABLED() #endif // USE_CUDA +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/test/cpp/api/transformer.cpp b/test/cpp/api/transformer.cpp index fc4832d30157a..c4fdb59e10168 100644 --- a/test/cpp/api/transformer.cpp +++ b/test/cpp/api/transformer.cpp @@ -73,7 +73,11 @@ void transformer_encoder_layer_test_helper( ASSERT_TRUE( torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); +<<<<<<< HEAD // all 0 values are NOT masked. This shouldn't mask anything +======= + // all 0 values are NOT masked. This should't mask anything +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch::Tensor mask = torch::tensor({{0}}, tensor_options) == 1; result = model( encoder_input, diff --git a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp index 0831958da761d..fc1365029cb7e 100644 --- a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp +++ b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp @@ -386,7 +386,11 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) { ASSERT_TRUE( setenv(c10d::TORCH_NCCL_ENABLE_MONITORING[0].c_str(), "1", 1) == 0); auto tempFilename = c10::str( +<<<<<<< HEAD std::filesystem::temp_directory_path().string(), "/comm_lib_trace_rank_"); +======= + std::filesystem::temp_directory_path().string(), "/nccl_trace_rank_"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ASSERT_TRUE( setenv("TORCH_NCCL_DEBUG_INFO_TEMP_FILE", tempFilename.c_str(), 1) == 0); // Enable nccl flight recorder. @@ -401,7 +405,11 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) { // The only difference is that we are storing traces also in memory for // validation. std::string fileNamePrefix = c10d::getCvarString( +<<<<<<< HEAD {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/comm_lib_trace_rank_"); +======= + {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::unique_ptr wrterForTestPtr = std::make_unique(fileNamePrefix); std::vector& traces = wrterForTestPtr->getTraces(); diff --git a/test/cpp/c10d/ProcessGroupNCCLTest.cpp b/test/cpp/c10d/ProcessGroupNCCLTest.cpp index ac4ba4da01577..92642f24750b2 100644 --- a/test/cpp/c10d/ProcessGroupNCCLTest.cpp +++ b/test/cpp/c10d/ProcessGroupNCCLTest.cpp @@ -28,7 +28,11 @@ class NCCLTestBase { NCCLTestBase(NCCLTestBase&& other) noexcept = default; +<<<<<<< HEAD ::c10::intrusive_ptr<::c10d::ProcessGroupNCCL> getProcessGroup() { +======= + std::shared_ptr<::c10d::ProcessGroupNCCL> getProcessGroup() { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return pg_; } @@ -39,7 +43,11 @@ class NCCLTestBase { void initialize( int rank, size_t size, +<<<<<<< HEAD std::optional<::c10::intrusive_ptr<::c10d::ProcessGroupNCCL>> split_from = +======= + std::optional<::std::shared_ptr<::c10d::ProcessGroupNCCL>> split_from = +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::nullopt) { store_ = c10::make_intrusive<::c10d::FileStore>(path_, size); @@ -52,13 +60,21 @@ class NCCLTestBase { opts->split_color = ++color_; } #endif +<<<<<<< HEAD pg_ = c10::make_intrusive<::c10d::ProcessGroupNCCL>( +======= + pg_ = std::make_unique<::c10d::ProcessGroupNCCL>( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) store_, rank, size, std::move(opts)); } protected: std::string path_; +<<<<<<< HEAD ::c10::intrusive_ptr<::c10d::ProcessGroupNCCL> pg_; +======= + std::shared_ptr<::c10d::ProcessGroupNCCL> pg_; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::chrono::milliseconds pgTimeout_; ::c10::intrusive_ptr<::c10d::Store> store_; int color_{1}; @@ -767,8 +783,13 @@ TEST_F(ProcessGroupNCCLTest, CUDAEventCache) { } // Test that the CUDAEventCache can be used to create CUDA events and reuse. +<<<<<<< HEAD auto event1 = c10d::CUDAEventCache::get(1)->create(true); auto event2 = c10d::CUDAEventCache::get(1)->create(false); +======= + auto event1 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(true); + auto event2 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(false); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto event1_ptr = event1.get(); auto event2_ptr = event2.get(); @@ -777,6 +798,7 @@ TEST_F(ProcessGroupNCCLTest, CUDAEventCache) { event2 = nullptr; // Test that the CUDAEventCache is indeed reused. +<<<<<<< HEAD auto event3 = c10d::CUDAEventCache::get(2)->create(true); auto event4 = c10d::CUDAEventCache::get(2)->create(false); // The cache has been used up, new events should be created. @@ -785,6 +807,16 @@ TEST_F(ProcessGroupNCCLTest, CUDAEventCache) { // The cache has been used up, new events should be created. auto event7 = c10d::CUDAEventCache::get(1)->create(true); auto event8 = c10d::CUDAEventCache::get(1)->create(false); +======= + auto event3 = c10d::ProcessGroupNCCL::CUDAEventCache::get(2)->create(true); + auto event4 = c10d::ProcessGroupNCCL::CUDAEventCache::get(2)->create(false); + // The cache has been used up, new events should be created. + auto event5 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(true); + auto event6 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(false); + // The cache has been used up, new events should be created. + auto event7 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(true); + auto event8 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(false); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) EXPECT_NE(event1_ptr, event3.get()); EXPECT_NE(event2_ptr, event4.get()); EXPECT_EQ(event1_ptr, event5.get()); diff --git a/test/cpp/jit/CMakeLists.txt b/test/cpp/jit/CMakeLists.txt index 0b2a06b53c9a2..7f463c25546e1 100644 --- a/test/cpp/jit/CMakeLists.txt +++ b/test/cpp/jit/CMakeLists.txt @@ -17,7 +17,11 @@ set(BACKEND_WITH_COMPILER_SRCS ) if(USE_KINETO) # Testing edge profiler for backend use +<<<<<<< HEAD # profiler_edge should only be added when USE_KINETO flag is on +======= + # profiler_edge should only be aded when USE_KINETO flag is on +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) list(APPEND BACKEND_WITH_COMPILER_SRCS ${TORCH_SRC_DIR}/csrc/jit/mobile/profiler_edge.cpp) endif() @@ -88,7 +92,10 @@ set(JIT_TEST_SRCS ${JIT_TEST_ROOT}/test_subgraph_matcher.cpp ${JIT_TEST_ROOT}/test_subgraph_rewriter.cpp ${JIT_TEST_ROOT}/test_subgraph_utils.cpp +<<<<<<< HEAD ${JIT_TEST_ROOT}/test_te.cpp +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ${JIT_TEST_ROOT}/test_union.cpp ${JIT_TEST_ROOT}/test_utils.cpp ${JIT_TEST_ROOT}/test_script_profile.cpp diff --git a/test/cpp/jit/README.md b/test/cpp/jit/README.md index 06704be5d9706..512065b4463ad 100644 --- a/test/cpp/jit/README.md +++ b/test/cpp/jit/README.md @@ -36,7 +36,11 @@ The following commands assume you are in PyTorch root. ```bash # ... Build PyTorch from source, e.g. +<<<<<<< HEAD python -m pip install --no-build-isolation -v -e . +======= +python setup.py develop +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # (re)build just the binary ninja -C build bin/test_jit # run tests diff --git a/test/cpp/jit/test_backend.cpp b/test/cpp/jit/test_backend.cpp index 4a060e436f2b0..1ee50d7bedc02 100644 --- a/test/cpp/jit/test_backend.cpp +++ b/test/cpp/jit/test_backend.cpp @@ -789,7 +789,11 @@ TEST( c._save_for_mobile(ss, ExtraFilesMap(), true); auto c_loaded = _load_for_mobile(ss); /* +<<<<<<< HEAD * Error stack trace will look like this: +======= + * Erro stack trace will look like this: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) * Module hierarchy:top(C).A0(backend_with_compiler_demoLoweredModule).AA0(AA) * Traceback of TorchScript (most recent call last): * File "", line 3, in FunctionName_UNKNOWN diff --git a/test/cpp/jit/test_backend_compiler_lib.cpp b/test/cpp/jit/test_backend_compiler_lib.cpp index 55511c3e684a6..7708b53f909b3 100644 --- a/test/cpp/jit/test_backend_compiler_lib.cpp +++ b/test/cpp/jit/test_backend_compiler_lib.cpp @@ -79,7 +79,11 @@ class BackendWithCompiler : public PyTorchBackendInterface { // forwards everything along. In a non toy setup this could grab information // from that runtime that might be relevant to execute, such as build flags // the resolution of the devices camera, or basically any runtime specific +<<<<<<< HEAD // information that wouldn't be available server side where preprocess is +======= + // information that wouldnt be available server side where preprocess is +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // called. c10::impl::GenericDict compile( c10::IValue processed, diff --git a/test/cpp/jit/test_custom_class_registrations.cpp b/test/cpp/jit/test_custom_class_registrations.cpp index 3aa981a7883ba..f2d5ae6a4bc8e 100644 --- a/test/cpp/jit/test_custom_class_registrations.cpp +++ b/test/cpp/jit/test_custom_class_registrations.cpp @@ -376,7 +376,11 @@ struct ElementwiseInterpreter : torch::CustomClassHolder { // for more info. // This is the type we will use to marshall information on disk during +<<<<<<< HEAD // Ser/De. It is a simple tuple composed of primitive types and simple +======= + // ser/de. It is a simple tuple composed of primitive types and simple +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // collection types like vector, optional, and dict. using SerializationType = std::tuple< std::vector /*input_names_*/, @@ -421,9 +425,13 @@ struct FlattenWithTensorOp : public torch::CustomClassHolder { explicit FlattenWithTensorOp(at::Tensor t) : t_(t) {} at::Tensor get() { +<<<<<<< HEAD // Need to return a copy of the tensor, otherwise the tensor will be // aliased with a tensor that may be modified by the user or backend. return t_.clone(); +======= + return t_; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } std::tuple> __obj_flatten__() { @@ -439,9 +447,13 @@ struct ContainsTensor : public torch::CustomClassHolder { explicit ContainsTensor(at::Tensor t) : t_(t) {} at::Tensor get() { +<<<<<<< HEAD // Need to return a copy of the tensor, otherwise the tensor will be // aliased with a tensor that may be modified by the user or backend. return t_.clone(); +======= + return t_; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } std::tuple> __obj_flatten__() { @@ -507,6 +519,7 @@ TORCH_LIBRARY(_TorchScriptTesting, m) { m.class_("_FlattenWithTensorOp") .def(torch::init()) .def("get", &FlattenWithTensorOp::get) +<<<<<<< HEAD .def("__obj_flatten__", &FlattenWithTensorOp::__obj_flatten__) .def_pickle( // __getstate__ @@ -516,6 +529,9 @@ TORCH_LIBRARY(_TorchScriptTesting, m) { [](at::Tensor data) -> c10::intrusive_ptr { return c10::make_intrusive(std::move(data)); }); +======= + .def("__obj_flatten__", &FlattenWithTensorOp::__obj_flatten__); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) m.class_("_ConstantTensorContainer") .def(torch::init()) @@ -720,8 +736,12 @@ at::Tensor takes_foo_tensor_return(c10::intrusive_ptr foo, at::Tensor x) { } void queue_push(c10::intrusive_ptr tq, at::Tensor x) { +<<<<<<< HEAD // clone the tensor to avoid aliasing tq->push(x.clone()); +======= + tq->push(x); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } at::Tensor queue_pop(c10::intrusive_ptr tq) { @@ -754,11 +774,14 @@ TORCH_LIBRARY_IMPL(_TorchScriptTesting, CPU, m) { m.impl("takes_foo_tensor_return", takes_foo_tensor_return); } +<<<<<<< HEAD TORCH_LIBRARY_IMPL(_TorchScriptTesting, CUDA, m) { m.impl("queue_push", queue_push); m.impl("queue_pop", queue_pop); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_LIBRARY_IMPL(_TorchScriptTesting, Meta, m) { m.impl("takes_foo", &takes_foo); m.impl("takes_foo_list_return", takes_foo_list_return); diff --git a/test/cpp/jit/test_lite_trainer.cpp b/test/cpp/jit/test_lite_trainer.cpp index 950d0c524ad3a..dd4cfaeaebd2c 100644 --- a/test/cpp/jit/test_lite_trainer.cpp +++ b/test/cpp/jit/test_lite_trainer.cpp @@ -78,7 +78,11 @@ TEST(LiteTrainerTest, Params) { AT_ASSERT(parameters[0].item() == bc_parameters[0].item()); } +<<<<<<< HEAD // TODO Re-enable these tests after parameters are correctly loaded on mobile +======= +// TODO Renable these tests after parameters are correctly loaded on mobile +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) /* TEST(MobileTest, NamedParameters) { Module m("m"); diff --git a/test/cpp/jit/test_misc.cpp b/test/cpp/jit/test_misc.cpp index ebeeb953d95b6..e4e8b685c5562 100644 --- a/test/cpp/jit/test_misc.cpp +++ b/test/cpp/jit/test_misc.cpp @@ -2709,7 +2709,10 @@ TEST(ProfilerDisableInCallbackTest, Basic) { } TEST(RecordDebugHandles, Basic) { +<<<<<<< HEAD GTEST_SKIP() << "Test is flaky and sometimes hangs on CI. "; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Enable the profiler in this thread const std::set activities( {torch::autograd::profiler::ActivityType::CPU}); diff --git a/test/cpp/lite_interpreter_runtime/test_lite_interpreter_runtime.cpp b/test/cpp/lite_interpreter_runtime/test_lite_interpreter_runtime.cpp index b6467b7c5b490..99faaa6d17c48 100644 --- a/test/cpp/lite_interpreter_runtime/test_lite_interpreter_runtime.cpp +++ b/test/cpp/lite_interpreter_runtime/test_lite_interpreter_runtime.cpp @@ -106,7 +106,11 @@ TEST(RunTimeTest, DelegateException) { * inputs.emplace_back(torch::rand({2, 4})); * inputs.emplace_back(torch::rand({13, 9})); * Run with inputs and expect exception +<<<<<<< HEAD * Error stack trace will look like this: +======= + * Erro stack trace will look like this: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) * Module hierarchy:top(C).A0(backend_with_compiler_demoLoweredModule).AA0(AA) * Traceback of TorchScript (most recent call last): * File "", line 3, in FunctionName_UNKNOWN diff --git a/test/cpp/nativert/CMakeLists.txt b/test/cpp/nativert/CMakeLists.txt index 1b4752ed9089f..72d9176b2bb25 100644 --- a/test/cpp/nativert/CMakeLists.txt +++ b/test/cpp/nativert/CMakeLists.txt @@ -5,12 +5,18 @@ file(GLOB_RECURSE NATIVERT_ALL_TEST_FILES "${NATIVERT_TEST_ROOT}/test_*.cpp") # Build the cpp gtest binary containing the cpp-only tests. set(NATIVERT_TEST_SRCS ${NATIVERT_ALL_TEST_FILES} +<<<<<<< HEAD ${TORCH_ROOT}/torch/nativert/ModelRunner.cpp ${TORCH_ROOT}/torch/nativert/graph/TensorMeta.cpp ${TORCH_ROOT}/torch/nativert/graph/Graph.cpp ${TORCH_ROOT}/torch/nativert/graph/GraphPasses.cpp ${TORCH_ROOT}/torch/nativert/graph/GraphSignature.cpp ${TORCH_ROOT}/torch/nativert/graph/GraphUtils.cpp +======= + ${TORCH_ROOT}/torch/nativert/graph/TensorMeta.cpp + ${TORCH_ROOT}/torch/nativert/graph/Graph.cpp + ${TORCH_ROOT}/torch/nativert/graph/GraphSignature.cpp +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ${TORCH_ROOT}/torch/nativert/graph/Serialization.cpp ${TORCH_ROOT}/torch/nativert/executor/OpKernel.cpp ${TORCH_ROOT}/torch/nativert/executor/PlacementUtils.cpp @@ -23,6 +29,7 @@ set(NATIVERT_TEST_SRCS ${TORCH_ROOT}/torch/nativert/kernels/C10Kernel.cpp ${TORCH_ROOT}/torch/nativert/executor/memory/GreedyBySize.cpp ${TORCH_ROOT}/torch/nativert/executor/memory/Bump.cpp +<<<<<<< HEAD ${TORCH_ROOT}/torch/nativert/executor/memory/DisjointStorageGroups.cpp ${TORCH_ROOT}/torch/nativert/executor/memory/LayoutPlanner.cpp ${TORCH_ROOT}/torch/nativert/executor/memory/LayoutManager.cpp @@ -50,6 +57,10 @@ if(USE_CUDA) endif(MSVC) +======= +) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) add_executable(test_nativert ${TORCH_ROOT}/test/cpp/common/main.cpp ${NATIVERT_TEST_SRCS} diff --git a/test/cpp/nativert/test_c10_kernel.cpp b/test/cpp/nativert/test_c10_kernel.cpp index 84c04c39d408d..bc08c615e25f2 100644 --- a/test/cpp/nativert/test_c10_kernel.cpp +++ b/test/cpp/nativert/test_c10_kernel.cpp @@ -27,6 +27,11 @@ return (%x) std::advance(it, 1); const Node& node = *it; +<<<<<<< HEAD +======= + c10::Device device = torch::Device(torch::kCPU, 0); + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto a = at::randn({6, 6, 6}); auto b = at::randn({6, 6, 6}); @@ -34,7 +39,11 @@ return (%x) frame.setIValue(graph->getValue("a")->id(), a); frame.setIValue(graph->getValue("b")->id(), b); +<<<<<<< HEAD auto kernel = C10Kernel(&node); +======= + auto kernel = C10Kernel(&node, device); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kernel.computeInternal(frame); diff --git a/test/cpp/nativert/test_execution_frame.cpp b/test/cpp/nativert/test_execution_frame.cpp index d2a8b69cef20f..06ae17b480875 100644 --- a/test/cpp/nativert/test_execution_frame.cpp +++ b/test/cpp/nativert/test_execution_frame.cpp @@ -1,6 +1,9 @@ #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include namespace torch::nativert { @@ -92,9 +95,13 @@ TEST(ExecutionFrameTest, TestPersistentValue) { auto wid = graph->getValue("my_weight")->id(); EXPECT_NO_THROW(frame.getTensor(wid)); +<<<<<<< HEAD // can't release persistent value frame.releaseValueIfNeeded(wid); EXPECT_FALSE(frame.getIValue(wid).isNone()); +======= + EXPECT_DEATH(frame.releaseValue(wid), "Cannot release persistent value"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } // namespace torch::nativert diff --git a/test/cpp/nativert/test_itree.cpp b/test/cpp/nativert/test_itree.cpp index 4748c11c3e17a..bf1844649a11b 100644 --- a/test/cpp/nativert/test_itree.cpp +++ b/test/cpp/nativert/test_itree.cpp @@ -259,7 +259,11 @@ TEST(ITreeTest, NoContext) { c10::IValue(8), c10::IValue(9), }; +<<<<<<< HEAD EXPECT_THROW({ itreeUnflatten(flats, spec); }, c10::Error); +======= + ASSERT_DEATH({ itreeUnflatten(flats, spec); }, "Check failed"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } TEST(ITreeTest, TooManyContext) { @@ -304,7 +308,11 @@ TEST(ITreeTest, TooManyContext) { c10::IValue(8), c10::IValue(9), }; +<<<<<<< HEAD EXPECT_THROW({ itreeUnflatten(flats, spec); }, c10::Error); +======= + ASSERT_DEATH({ itreeUnflatten(flats, spec); }, "Check failed"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } TEST(ITreeTest, DoubleRegister) { @@ -375,7 +383,11 @@ TEST(ITreeTest, NotEnoughUnflatten) { c10::IValue(2), c10::IValue(7), }; +<<<<<<< HEAD EXPECT_THROW({ itreeUnflatten(flats, spec); }, c10::Error); +======= + ASSERT_DEATH({ itreeUnflatten(flats, spec); }, "Check failed"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } TEST(ITreeTest, TooManyUnflatten) { @@ -449,7 +461,11 @@ TEST(ITreeTest, TooManyUnflatten) { c10::IValue(2), c10::IValue(7), }; +<<<<<<< HEAD EXPECT_THROW({ itreeUnflatten(flats, spec); }, c10::Error); +======= + ASSERT_DEATH({ itreeUnflatten(flats, spec); }, "Check failed"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } TEST(ITreeTest, Flatten) { @@ -908,8 +924,13 @@ TEST(ITreeTest, UnmatchedDictFlatten) { list.push_back(std::move(tup)); list.push_back(c10::IValue(2)); list.push_back(std::move(dict)); +<<<<<<< HEAD EXPECT_THROW( { itreeFlatten(c10::IValue{std::move(list)}, spec); }, c10::Error); +======= + ASSERT_DEATH( + { itreeFlatten(c10::IValue{std::move(list)}, spec); }, "Check failed"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } TEST(ITreeTest, DictFlattenTest) { @@ -1025,8 +1046,13 @@ TEST(ITreeTest, UnmatchedTupleFlatten) { list.push_back(std::move(tup)); list.push_back(c10::IValue(2)); list.push_back(std::move(dict)); +<<<<<<< HEAD EXPECT_THROW( { itreeFlatten(c10::IValue{std::move(list)}, spec); }, c10::Error); +======= + ASSERT_DEATH( + { itreeFlatten(c10::IValue{std::move(list)}, spec); }, "Check failed"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } TEST(ITreeTest, ToAtenType) { diff --git a/test/cpp/nativert/test_layout_planner_algorithm.cpp b/test/cpp/nativert/test_layout_planner_algorithm.cpp index 0d4f8fb0d2737..8c568bf706b5c 100644 --- a/test/cpp/nativert/test_layout_planner_algorithm.cpp +++ b/test/cpp/nativert/test_layout_planner_algorithm.cpp @@ -2,7 +2,10 @@ #include #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include using namespace ::testing; @@ -62,6 +65,7 @@ TEST(LayoutPlannerAlgorithmTests, TestBump) { EXPECT_EQ(result.total_size, offset); } +<<<<<<< HEAD TEST(LayoutPlannerAlgorithmTests, TestStorageGroup) { auto specs = create_test_allocation_specs(); @@ -84,3 +88,5 @@ TEST(LayoutPlannerAlgorithmTests, TestStorageGroup) { EXPECT_EQ(result.total_size, 150); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/test/cpp/nativert/test_op_kernel.cpp b/test/cpp/nativert/test_op_kernel.cpp index 4854d39f5877f..f7423b591f5b0 100644 --- a/test/cpp/nativert/test_op_kernel.cpp +++ b/test/cpp/nativert/test_op_kernel.cpp @@ -1,6 +1,9 @@ #include #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include diff --git a/test/cpp/nativert/test_placement.cpp b/test/cpp/nativert/test_placement.cpp index ab65bfc07b917..e563b59f893e2 100644 --- a/test/cpp/nativert/test_placement.cpp +++ b/test/cpp/nativert/test_placement.cpp @@ -8,6 +8,26 @@ using namespace ::testing; namespace torch::nativert { +<<<<<<< HEAD +======= +TEST(PlacementTest, NormalizeDevice) { + c10::Device cpuDevice = c10::Device(c10::DeviceType::CPU); + c10::Device cpuDevice1 = c10::Device(c10::DeviceType::CPU); + cpuDevice1.set_index(1); + + EXPECT_EQ(normalizeDevice(cpuDevice), cpuDevice); + EXPECT_NE(normalizeDevice(cpuDevice1), cpuDevice1); + + c10::Device cudaDevice = c10::Device(c10::DeviceType::CUDA); + c10::Device cudaDevice1 = c10::Device(c10::DeviceType::CUDA, 1); + EXPECT_EQ(normalizeDevice(cudaDevice), c10::Device(c10::DeviceType::CUDA, 0)); + EXPECT_EQ( + normalizeDevice(cudaDevice1), c10::Device(c10::DeviceType::CUDA, 1)); + + EXPECT_NE( + normalizeDevice(cudaDevice1), c10::Device(c10::DeviceType::CUDA, 0)); +} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TEST(PlacementTest, IsSameDevice) { c10::Device cpuDevice = c10::Device(c10::DeviceType::CPU); @@ -73,11 +93,19 @@ TEST(PlacementTest, Placement) { {c10::Device("cuda:0"), c10::Device("cuda:1")}}; Placement p1(deviceMap1); EXPECT_EQ(p1.getMappedDevice(c10::Device("cpu")), c10::Device("cpu")); +<<<<<<< HEAD EXPECT_EQ(p1.getMappedDevice(c10::Device("cuda")), c10::Device("cuda")); EXPECT_EQ(p1.getMappedDevice(c10::Device("cuda:0")), c10::Device("cuda:1")); std::unordered_map deviceMap2 = { {c10::Device("cpu"), c10::Device("cuda:0")}}; +======= + EXPECT_EQ(p1.getMappedDevice(c10::Device("cuda")), c10::Device("cuda:1")); + EXPECT_EQ(p1.getMappedDevice(c10::Device("cuda:0")), c10::Device("cuda:1")); + + std::unordered_map deviceMap2 = { + {c10::Device("cpu"), c10::Device("cuda")}}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Placement p2(deviceMap2); EXPECT_EQ(p2.getMappedDevice(c10::Device("cpu")), c10::Device("cuda:0")); EXPECT_EQ(p2.getMappedDevice(c10::Device("cuda:0")), c10::Device("cuda:0")); diff --git a/test/cpp/nativert/test_weights.cpp b/test/cpp/nativert/test_weights.cpp index 566bc04698712..a31e6a1fa05d4 100644 --- a/test/cpp/nativert/test_weights.cpp +++ b/test/cpp/nativert/test_weights.cpp @@ -25,7 +25,11 @@ return(%o2, %baz) }; TEST_F(WeightsTest, ConstructEmptyStateDict) { std::unordered_map stateDict; +<<<<<<< HEAD Weights weights(graph.get(), stateDict); +======= + Weights weights(graph.get(), stateDict, *placement); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Check that weights are initialized correctly EXPECT_TRUE(weights.parameters().empty()); EXPECT_TRUE(weights.buffers().empty()); @@ -33,7 +37,11 @@ TEST_F(WeightsTest, ConstructEmptyStateDict) { } TEST_F(WeightsTest, SetAndGetValue) { std::unordered_map stateDict; +<<<<<<< HEAD Weights weights(graph.get(), stateDict); +======= + Weights weights(graph.get(), stateDict, *placement); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) at::Tensor tensor = at::ones({2, 2}); weights.setValue("added_weight", tensor); EXPECT_TRUE(weights.contains("added_weight")); diff --git a/test/cpp/tensorexpr/CMakeLists.txt b/test/cpp/tensorexpr/CMakeLists.txt new file mode 100644 index 0000000000000..8fe6ffd525e98 --- /dev/null +++ b/test/cpp/tensorexpr/CMakeLists.txt @@ -0,0 +1,83 @@ +set(TENSOREXPR_TEST_ROOT ${TORCH_ROOT}/test/cpp/tensorexpr) + +set(TENSOREXPR_TEST_SRCS + ${TENSOREXPR_TEST_ROOT}/test_approx.cpp + ${TENSOREXPR_TEST_ROOT}/test_aten.cpp + ${TENSOREXPR_TEST_ROOT}/test_boundsinference.cpp + ${TENSOREXPR_TEST_ROOT}/test_conv.cpp + ${TENSOREXPR_TEST_ROOT}/test_cpp_codegen.cpp + ${TENSOREXPR_TEST_ROOT}/test_dynamic_shapes.cpp + ${TENSOREXPR_TEST_ROOT}/test_expr.cpp + ${TENSOREXPR_TEST_ROOT}/test_external_calls.cpp + ${TENSOREXPR_TEST_ROOT}/test_graph_opt.cpp + ${TENSOREXPR_TEST_ROOT}/test_ir_printer.cpp + ${TENSOREXPR_TEST_ROOT}/test_ir_verifier.cpp + ${TENSOREXPR_TEST_ROOT}/test_kernel.cpp + ${TENSOREXPR_TEST_ROOT}/test_loopnest.cpp + ${TENSOREXPR_TEST_ROOT}/test_memdependency.cpp + ${TENSOREXPR_TEST_ROOT}/test_ops.cpp + ${TENSOREXPR_TEST_ROOT}/test_quantization.cpp + ${TENSOREXPR_TEST_ROOT}/test_memplanning.cpp + ${TENSOREXPR_TEST_ROOT}/test_reductions.cpp + ${TENSOREXPR_TEST_ROOT}/test_registerizer.cpp + ${TENSOREXPR_TEST_ROOT}/test_simplify.cpp + ${TENSOREXPR_TEST_ROOT}/test_te_fuser_pass.cpp + ${TENSOREXPR_TEST_ROOT}/test_type.cpp + ${TENSOREXPR_TEST_ROOT}/test_type_specializations.cpp +) + +if(USE_CUDA) + list(APPEND TENSOREXPR_TEST_SRCS ${TENSOREXPR_TEST_ROOT}/test_cuda.cpp) +endif() + +if(USE_LLVM AND LLVM_FOUND) + list(APPEND TENSOREXPR_TEST_SRCS ${TENSOREXPR_TEST_ROOT}/test_llvm.cpp) +endif() + +add_executable(test_tensorexpr + ${TORCH_ROOT}/test/cpp/common/main.cpp + ${TENSOREXPR_TEST_ROOT}/padded_buffer.cpp + ${TENSOREXPR_TEST_SRCS}) + +target_link_libraries(test_tensorexpr PRIVATE torch gtest_main) +target_include_directories(test_tensorexpr PRIVATE ${ATen_CPU_INCLUDE}) +target_compile_definitions(test_tensorexpr PRIVATE USE_GTEST) + +add_executable(tutorial_tensorexpr ${TENSOREXPR_TEST_ROOT}/tutorial.cpp) +target_link_libraries(tutorial_tensorexpr PRIVATE torch) +target_include_directories(tutorial_tensorexpr PRIVATE ${ATen_CPU_INCLUDE}) + +# The test case depends on the xnnpack header which in turn depends on the +# pthreadpool header. For some build environment we need add the dependency +# explicitly. +if(USE_PTHREADPOOL) + target_link_libraries(test_tensorexpr PRIVATE pthreadpool_interface) +endif() +if(USE_CUDA) + target_compile_definitions(test_tensorexpr PRIVATE USE_CUDA) + target_compile_definitions(tutorial_tensorexpr PRIVATE USE_CUDA) +elseif(USE_ROCM) + target_link_libraries(test_tensorexpr PRIVATE + hiprtc::hiprtc + hip::amdhip64 + ${TORCH_CUDA_LIBRARIES}) + target_compile_definitions(test_tensorexpr PRIVATE USE_ROCM) + + target_link_libraries(tutorial_tensorexpr PRIVATE + hiprtc::hiprtc + hip::amdhip64 + ${TORCH_CUDA_LIBRARIES}) + target_compile_definitions(tutorial_tensorexpr PRIVATE USE_ROCM) +endif() + +if(INSTALL_TEST) + set_target_properties(test_tensorexpr PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") + install(TARGETS test_tensorexpr DESTINATION bin) + set_target_properties(tutorial_tensorexpr PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") + install(TARGETS tutorial_tensorexpr DESTINATION bin) + # Install PDB files for MSVC builds + if(MSVC AND BUILD_SHARED_LIBS) + install(FILES $ DESTINATION bin OPTIONAL) + install(FILES $ DESTINATION bin OPTIONAL) + endif() +endif() diff --git a/test/cpp/tensorexpr/README.md b/test/cpp/tensorexpr/README.md new file mode 100644 index 0000000000000..f86a50a65e804 --- /dev/null +++ b/test/cpp/tensorexpr/README.md @@ -0,0 +1,55 @@ +# TensorExpr C++ Tests + +## How to add a new test +First, create a new test file. Test files should have be placed in this +directory, with a name that starts with `test_`, like `test_foo.cpp`. + +Here is an example test file you can copy-paste. +```cpp +#include + +// Tests go in torch::jit +namespace torch { +namespace jit { + +// 1. Test cases are void() functions. +// 2. They start with the prefix `test` +void testCaseOne() { + // ... +} + +void testCaseTwo() { + // ... +} +} +} +``` + +Then, register your test in `tests.h`: +```cpp +// Add to TH_FORALL_TESTS_CUDA instead for CUDA-requiring tests +#define TH_FORALL_TESTS(_) \ + _(ADFormulas) \ + _(Attributes) \ + ... + _(CaseOne) // note that the `test` prefix is omitted. + _(CaseTwo) +``` + +We glob all the test files together in `CMakeLists.txt` so that you don't +have to edit it every time you add a test. Unfortunately, this means that in +order to get the build to pick up your new test file, you need to re-run +cmake: +```bash +CMAKE_FRESH=1 python setup.py build +``` + +## How do I run the tests? +The following commands assume you are in PyTorch root. + + ```bash + # (re)build the test binary + ninja build/bin/test_tensorexpr + # run + build/bin/test_tensorexpr --gtest_filter='glob_style_filter*' + ``` diff --git a/test/cpp/tensorexpr/gtest_assert_float_eq.h b/test/cpp/tensorexpr/gtest_assert_float_eq.h new file mode 100644 index 0000000000000..f85264a8f5d3c --- /dev/null +++ b/test/cpp/tensorexpr/gtest_assert_float_eq.h @@ -0,0 +1,119 @@ +#pragma once + +#include +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// The Google C++ Testing and Mocking Framework (Google Test) +// +// This header file declares functions and macros used internally by +// Google Test. They are subject to change without notice. + +using Bits = uint32_t; + +// this avoids the "dereferencing type-punned pointer +// will break strict-aliasing rules" error +union Float { + float float_; + Bits bits_; +}; + +// # of bits in a number. +static const size_t kBitCount = 8 * sizeof(Bits); +// The mask for the sign bit. +static const Bits kSignBitMask = static_cast(1) << (kBitCount - 1); + +// GOOGLETEST_CM0001 DO NOT DELETE + +// Converts an integer from the sign-and-magnitude representation to +// the biased representation. More precisely, let N be 2 to the +// power of (kBitCount - 1), an integer x is represented by the +// unsigned number x + N. +// +// For instance, +// +// -N + 1 (the most negative number representable using +// sign-and-magnitude) is represented by 1; +// 0 is represented by N; and +// N - 1 (the biggest number representable using +// sign-and-magnitude) is represented by 2N - 1. +// +// Read http://en.wikipedia.org/wiki/Signed_number_representations +// for more details on signed number representations. +static Bits SignAndMagnitudeToBiased(const Bits& sam) { + if (kSignBitMask & sam) { + // sam represents a negative number. + return ~sam + 1; + } else { + // sam represents a positive number. + return kSignBitMask | sam; + } +} + +// Given two numbers in the sign-and-magnitude representation, +// returns the distance between them as an unsigned number. +static Bits DistanceBetweenSignAndMagnitudeNumbers( + const Bits& sam1, + const Bits& sam2) { + const Bits biased1 = SignAndMagnitudeToBiased(sam1); + const Bits biased2 = SignAndMagnitudeToBiased(sam2); + return (biased1 >= biased2) ? (biased1 - biased2) : (biased2 - biased1); +} + +// How many ULP's (Units in the Last Place) we want to tolerate when +// comparing two numbers. The larger the value, the more error we +// allow. A 0 value means that two numbers must be exactly the same +// to be considered equal. +// +// The maximum error of a single floating-point operation is 0.5 +// units in the last place. On Intel CPU's, all floating-point +// calculations are done with 80-bit precision, while double has 64 +// bits. Therefore, 4 should be enough for ordinary use. +// +// See the following article for more details on ULP: +// http://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/ +static const size_t kMaxUlps = 4; + +// Returns true if and only if this number is at most kMaxUlps ULP's away +// from rhs. In particular, this function: +// +// - returns false if either number is (or both are) NAN. +// - treats really large numbers as almost equal to infinity. +// - thinks +0.0 and -0.0 are 0 DLP's apart. +inline bool AlmostEquals(float lhs, float rhs) { + // The IEEE standard says that any comparison operation involving + // a NAN must return false. + if (std::isnan(lhs) || std::isnan(rhs)) + return false; + + Float l = {lhs}; + Float r = {rhs}; + + return DistanceBetweenSignAndMagnitudeNumbers(l.bits_, r.bits_) <= kMaxUlps; +} diff --git a/test/cpp/tensorexpr/padded_buffer.cpp b/test/cpp/tensorexpr/padded_buffer.cpp new file mode 100644 index 0000000000000..424d82c77453c --- /dev/null +++ b/test/cpp/tensorexpr/padded_buffer.cpp @@ -0,0 +1,37 @@ +#include "test/cpp/tensorexpr/padded_buffer.h" + +#include +#include +#include + +namespace torch { +namespace jit { +namespace tensorexpr { + +int PaddedBufferBase::Index(const std::vector& indices) const { + TORCH_DCHECK_EQ(dims_.size(), indices.size()); + int total_index = 0; + for (const auto i : c10::irange(dims_.size())) { + total_index += indices[i] * strides_[i]; + } + return total_index; +} + +PaddedBufferBase::PaddedBufferBase( + const std::vector& dims, + // NOLINTNEXTLINE(modernize-pass-by-value) + const std::string& name) + : dims_(dims), name_(name), strides_(dims.size()) { + for (int i = (int)dims.size() - 1; i >= 0; --i) { + if (i == (int)dims.size() - 1) { + strides_[i] = 1; + } else { + strides_[i] = strides_[i + 1] * dims[i + 1]; + } + } + total_size_ = strides_[0] * dims[0]; +} + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/padded_buffer.h b/test/cpp/tensorexpr/padded_buffer.h new file mode 100644 index 0000000000000..b3e5227ae7e62 --- /dev/null +++ b/test/cpp/tensorexpr/padded_buffer.h @@ -0,0 +1,242 @@ +#pragma once + +#include +#include + +#include +#include "torch/csrc/jit/tensorexpr/eval.h" + +namespace torch { +namespace jit { +namespace tensorexpr { + +template +struct DefaultPaddedValue; + +template <> +struct DefaultPaddedValue { + static const int kValue = static_cast(0xDEADBEEF); +}; + +template <> +struct DefaultPaddedValue { + static const int8_t kValue = static_cast(0xBE); +}; + +template <> +struct DefaultPaddedValue { + static const uint8_t kValue = static_cast(0xBE); +}; + +template <> +struct DefaultPaddedValue { + static const int16_t kValue = static_cast(0xBEEF); +}; + +template <> +struct DefaultPaddedValue { + static const int64_t kValue = static_cast(0xDEADBEEF); +}; + +template <> +struct DefaultPaddedValue { + static constexpr float kValue = 0.1357; +}; + +template <> +struct DefaultPaddedValue { + // at::Half ctor isn't constexpr, so just fill it with bits. + static constexpr uint16_t kValue = 1357; +}; + +template <> +struct DefaultPaddedValue { + static constexpr double kValue = 0.1357; +}; + +// A concrete base to be used in PaddedBase. +class PaddedBufferBase { + public: + const std::string& name() const { + return name_; + } + + int size() const { + return total_size_; + } + + int raw_size() const { + return total_size_ + 2 * kPaddingSize; + } + + virtual ~PaddedBufferBase() {} + + protected: + explicit PaddedBufferBase( + const std::vector& dims, + const std::string& name); + int Index(const std::vector& indices) const; + + std::vector dims_; + std::string name_; + std::vector strides_; + int total_size_; // total number of useful element, does not include the + // paddings + static constexpr int kPaddingSize = 64; +}; + +// A padded buffer with wartermarks for testing. +// The buffer carries padded watermarks on both sides to catch potential +// out-of-bounds writes. For read-only data that are not supposed to change, it +// can also make a backup and be compared later. +template +class PaddedBuffer : public PaddedBufferBase { + public: + PaddedBuffer(int d0, const std::string& name = "") + : PaddedBuffer(std::vector({d0}), name) {} + PaddedBuffer(int d0, int d1, const std::string& name = "") + : PaddedBuffer(std::vector({d0, d1}), name) {} + PaddedBuffer(int d0, int d1, int d2, const std::string& name = "") + : PaddedBuffer(std::vector({d0, d1, d2}), name) {} + PaddedBuffer(int d0, int d1, int d2, int d3, const std::string& name = "") + : PaddedBuffer(std::vector({d0, d1, d2, d3}), name) {} + PaddedBuffer(const std::vector& dims, const std::string& name = "") + : PaddedBufferBase(dims, name) { + data_.resize(total_size_ + 2 * kPaddingSize, kPaddingValue); + } + PaddedBuffer(const PaddedBuffer& other, const std::string& name) + : PaddedBuffer(other) { + this->name_ = name; + } + + T* data() { + return data_.data() + kPaddingSize; + } + const T* data() const { + return const_cast(this)->data(); + } + T* raw_data() { + return data_.data(); + } + const T* raw_data() const { + return const_cast(this)->raw_data(); + } + T& operator()(int i0) { + // There is a bit performance impact with forming a vector here. But this + // data structure is for testing only, and not performance critical. + return this->operator()(std::vector({i0})); + } + const T& operator()(int i0) const { + return const_cast(this)->operator()(i0); + } + T& operator()(int i0, int i1) { + return this->operator()(std::vector({i0, i1})); + } + const T& operator()(int i0, int i1) const { + return const_cast(this)->operator()(i0, i1); + } + T& operator()(int i0, int i1, int i2) { + return this->operator()(std::vector({i0, i1, i2})); + } + const T& operator()(int i0, int i1, int i2) const { + return const_cast(this)->operator()(i0, i1, i2); + } + T& operator()(int i0, int i1, int i2, int i3) { + return this->operator()(std::vector({i0, i1, i2, i3})); + } + const T& operator()(int i0, int i1, int i2, int i3) const { + return const_cast(this)->operator()(i0, i1, i2, i3); + } + T& operator()(const std::vector& indices) { + return data_[kPaddingSize + Index(indices)]; + } + const T& operator()(const std::vector& indices) const { + return const_cast(this)->operator()(indices); + } + + template + friend void ExpectAllNear( + const PaddedBuffer& v1, + const PaddedBuffer& v2, + float abs_error); + template + friend void ExpectAllEqual( + const PaddedBuffer& v1, + const PaddedBuffer& v2); + void Backup() { + backup_data_ = data_; + } + + // Verify the watermarks in the paddings are intact. + void ValidateWatermark() const { + for (const auto i : c10::irange(kPaddingSize)) { + ASSERT_EQ(data_[i], kPaddingValue); + ASSERT_EQ(data_[i + total_size_ + kPaddingSize], kPaddingValue); + } + } + + void CheckBackup() const { + ValidateWatermark(); + DCHECK(backup_data_.size() == data_.size()) + << "Please make sure you have call Backup() before calling CheckBackup()"; + for (const auto i : c10::irange(total_size_)) { + ASSERT_EQ(data_[i + kPaddingSize], backup_data_[i + kPaddingSize]); + } + } + + private: + std::vector data_; + std::vector backup_data_; + T kPaddingValue = DefaultPaddedValue::kValue; +}; + +template +inline CodeGen::CallArg::CallArg(const PaddedBuffer& buffer) + : data_(const_cast(buffer.data())) {} + +template +std::string CompareErrorMsg( + const PaddedBuffer& v1, + const PaddedBuffer& v2, + int index) { + std::ostringstream oss; + oss << "index: " << index << ", v1: (" << v1.name() << ", " << v1(index) + << ")" + << ", v2: (" << v2.name() << ", " << v2(index) << ")"; + return oss.str(); +} + +template +void ExpectAllEqual(const PaddedBuffer& f1, const PaddedBuffer& f2) { + const std::vector& v1 = f1.data_; + const std::vector& v2 = f2.data_; + const int kPaddingSize = f1.kPaddingSize; + const int total_size = f1.total_size_; + ASSERT_EQ(v1.size(), v2.size()); + f1.ValidateWatermark(); + f2.ValidateWatermark(); + for (const auto i : c10::irange(total_size)) { + ASSERT_EQ(v1[kPaddingSize + i], v2[kPaddingSize + i]); + } +} + +template +void ExpectAllNear( + const PaddedBuffer& f1, + const PaddedBuffer& f2, + float abs_error) { + const std::vector& v1 = f1.data_; + const std::vector& v2 = f2.data_; + const int kPaddingSize = f1.kPaddingSize; + const int total_size = f1.total_size_; + ASSERT_EQ(v1.size(), v2.size()); + f1.ValidateWatermark(); + f2.ValidateWatermark(); + for (const auto i : c10::irange(total_size)) { + ASSERT_NEAR(v1[kPaddingSize + i], v2[kPaddingSize + i], abs_error); + } +} + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_approx.cpp b/test/cpp/tensorexpr/test_approx.cpp new file mode 100644 index 0000000000000..e1a576aecf526 --- /dev/null +++ b/test/cpp/tensorexpr/test_approx.cpp @@ -0,0 +1,96 @@ +#ifdef TORCH_ENABLE_LLVM + +#include +#include +#include +#include +#include +#include +#include + +using namespace torch::indexing; +namespace te = torch::jit::tensorexpr; + +static void vectorize(te::LoopNest* ln, te::Tensor target, int width) { + auto loops = ln->getLoopStmtsFor(target); + te::ForPtr inner, tail; + ln->splitWithTail(loops[0], width, &inner, &tail); + ASSERT_TRUE(te::LoopNest::vectorize(inner)); +} + +std::string diffs(const at::Tensor& a, const at::Tensor& b) { + auto diff = torch::abs(a.flatten() - b.flatten()); + auto count_diffs = torch::sum(diff > 0.f); + auto greatest_diff_index = torch::argmax(diff); + std::stringstream ss; + ss << "Found " << count_diffs << " unequal element(s). " + << "The greatest difference was " << diff.index({greatest_diff_index}) + << " at index " << greatest_diff_index; + return ss.str(); +} + +TEST(Approx, log_vml) { + te::VarHandle N("N", te::kInt); + te::BufHandle A("A", {N}, te::kFloat); + te::Tensor B = te::Compute( + "B", {N}, [&](const te::VarHandle& i) { return log_vml(A.load(i)); }); + + te::LoopNest ln({B}); + ln.prepareForCodegen(); + vectorize(&ln, B, 8); + te::StmtPtr s = ln.root_stmt(); + s = te::IRSimplifier::simplify(s); + te::LLVMCodeGen cg(s, {A, B, N}); + + auto eps = std::numeric_limits::epsilon(); + auto test = [&](const at::Tensor& A_t) { + at::Tensor B_ref = at::log(A_t); + at::Tensor B_t = at::empty_like(A_t); + auto ap = A_t.data_ptr(); + auto bp = B_t.data_ptr(); + cg.call({ap, bp, A_t.numel()}); + // Results should be bit-identical. + ASSERT_TRUE(torch::allclose( + B_t, B_ref, /*rtol=*/eps, /*atol=*/0.0f, /*equal_nan=*/true)) + << "Input[:8]\n" + << A_t.index({Slice(0, 8)}) << "\n" + << "Test[:8]\n" + << B_t.index({Slice(0, 8)}) << "\n" + << "Ref[:8]\n" + << B_ref.index({Slice(0, 8)}) << diffs(B_t, B_ref); + }; + + // Generate every single-precision FP value in [1.0, 2.0). + at::Tensor A_t = torch::arange(1.0f, 2.0f, eps); + ASSERT_EQ(A_t.numel(), 1 << 23); + + test(A_t); + + test(A_t * 2.0f); + test(A_t * 0.5f); + + test(A_t * 4.0f); + test(A_t * 0.25f); + + test(A_t * powf(2.0f, 16)); + test(A_t * powf(2.0f, -16)); + + test(A_t * powf(2.0f, 126)); + test(A_t * powf(2.0f, -126)); + + test(torch::full({32}, INFINITY)); + test(torch::full({32}, NAN)); + + auto min = std::numeric_limits::min(); + auto denorm_min = std::numeric_limits::denorm_min(); + + // Denormals aren't bit precise, because sleef isn't bit-precise either. + A_t = torch::arange(0.0f, min, denorm_min); + ASSERT_EQ(A_t.numel(), 1 << 23); + auto B_ref = at::log(A_t); + auto B_t = at::empty_like(B_ref); + cg.call({A_t.data_ptr(), B_t.data_ptr(), A_t.numel()}); + ASSERT_TRUE(torch::allclose(B_t, B_ref)); +} + +#endif // TORCH_ENABLE_LLVM diff --git a/test/cpp/tensorexpr/test_aten.cpp b/test/cpp/tensorexpr/test_aten.cpp new file mode 100644 index 0000000000000..34ce2bd069d55 --- /dev/null +++ b/test/cpp/tensorexpr/test_aten.cpp @@ -0,0 +1,1068 @@ +#include +#include +#include + +#include + +#include +#include +#include "test/cpp/tensorexpr/padded_buffer.h" +#include "test/cpp/tensorexpr/test_base.h" +#include "torch/csrc/jit/tensorexpr/ir_printer.h" + +namespace torch { +namespace jit { + +using namespace torch::jit::tensorexpr; + +TEST(ATen, _cast_Float) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle to_float = Cast::make(kFloat, load_a); + StmtPtr store_b = b_buf.store({index}, to_float); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); + ir_eval(a_v, b_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i); + ASSERT_EQ(b_v(i), static_cast(i)); + } +} + +TEST(ATen, negInt) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle to_float = Sub::make(0, load_a); + StmtPtr store_b = b_buf.store({index}, to_float); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); + ir_eval(a_v, b_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i); + ASSERT_EQ(b_v(i), -static_cast(i)); + } +} + +TEST(ATen, negFloat) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle to_float = Sub::make(0, load_a); + StmtPtr store_b = b_buf.store({index}, to_float); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); + ir_eval(a_v, b_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i); + ASSERT_EQ(b_v(i), -i); + } +} + +TEST(ATen, addInt) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); + BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt); + BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kInt); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + ExprHandle load_c = c_buf.load(index); + StmtPtr store_d = d_buf.store({index}, load_a + load_b * load_c); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_d); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + PaddedBuffer d_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + c_v(i) = 3 * i + 2; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf}); + ir_eval(a_v, b_v, c_v, d_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i); + ASSERT_EQ(b_v(i), 2 * i + 1); + ASSERT_EQ(c_v(i), 3 * i + 2); + ASSERT_EQ(d_v(i), a_v(i) + b_v(i) * c_v(i)); + } +} + +TEST(ATen, addFloat) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); + BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + ExprHandle load_c = c_buf.load(index); + StmtPtr store_d = d_buf.store({index}, load_a + load_b * load_c); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_d); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + PaddedBuffer d_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + c_v(i) = 3 * i + 2; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf}); + ir_eval(a_v, b_v, c_v, d_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i); + ASSERT_EQ(b_v(i), 2 * i + 1); + ASSERT_EQ(c_v(i), 3 * i + 2); + ASSERT_EQ(d_v(i), a_v(i) + b_v(i) * c_v(i)); + } +} + +TEST(ATen, subInt) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); + BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt); + BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kInt); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + ExprHandle load_c = c_buf.load(index); + StmtPtr store_d = d_buf.store({index}, load_a - load_b * load_c); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_d); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + PaddedBuffer d_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + c_v(i) = 3 * i + 2; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf}); + ir_eval(a_v, b_v, c_v, d_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i); + ASSERT_EQ(b_v(i), 2 * i + 1); + ASSERT_EQ(c_v(i), 3 * i + 2); + ASSERT_EQ(d_v(i), a_v(i) - b_v(i) * c_v(i)); + } +} + +TEST(ATen, subFloat) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); + BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + ExprHandle load_c = c_buf.load(index); + StmtPtr store_d = d_buf.store({index}, load_a - load_b * load_c); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_d); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + PaddedBuffer d_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + c_v(i) = 3 * i + 2; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf}); + ir_eval(a_v, b_v, c_v, d_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i); + ASSERT_EQ(b_v(i), 2 * i + 1); + ASSERT_EQ(c_v(i), 3 * i + 2); + ASSERT_EQ(d_v(i), a_v(i) - b_v(i) * c_v(i)); + } +} + +TEST(ATen, lerp) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); + BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + ExprHandle load_c = c_buf.load(index); + StmtPtr store_d = d_buf.store({index}, load_a + load_c * (load_b - load_a)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_d); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + PaddedBuffer d_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + c_v(i) = 3 * i + 2; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf}); + ir_eval(a_v, b_v, c_v, d_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i); + ASSERT_EQ(b_v(i), 2 * i + 1); + ASSERT_EQ(c_v(i), 3 * i + 2); + ASSERT_EQ(d_v(i), a_v(i) + c_v(i) * (b_v(i) - a_v(i))); + } +} + +TEST(ATen, addcmulInt) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); + BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt); + BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kInt); + BufHandle e_buf("E", {ExprHandle(kTotalSize)}, kInt); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + ExprHandle load_c = c_buf.load(index); + ExprHandle load_d = d_buf.load(index); + StmtPtr store_e = e_buf.store({index}, load_a + load_b * load_c * load_d); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_e); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + PaddedBuffer d_v(kTotalSize); + PaddedBuffer e_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + c_v(i) = 3 * i + 2; + d_v(i) = 5 * i + 3; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf, e_buf}); + ir_eval(a_v, b_v, c_v, d_v, e_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i); + ASSERT_EQ(b_v(i), 2 * i + 1); + ASSERT_EQ(c_v(i), 3 * i + 2); + ASSERT_EQ(d_v(i), 5 * i + 3); + ASSERT_EQ(e_v(i), a_v(i) + b_v(i) * c_v(i) * d_v(i)); + } +} + +TEST(ATen, addcmulFloat) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); + BufHandle d_buf("D", {ExprHandle(kTotalSize)}, kFloat); + BufHandle e_buf("E", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + ExprHandle load_c = c_buf.load(index); + ExprHandle load_d = d_buf.load(index); + StmtPtr store_e = e_buf.store({index}, load_a + load_b * load_c * load_d); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_e); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + PaddedBuffer d_v(kTotalSize); + PaddedBuffer e_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + c_v(i) = 3 * i + 2; + d_v(i) = 5 * i + 3; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf, d_buf, e_buf}); + ir_eval(a_v, b_v, c_v, d_v, e_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i); + ASSERT_EQ(b_v(i), 2 * i + 1); + ASSERT_EQ(c_v(i), 3 * i + 2); + ASSERT_EQ(d_v(i), 5 * i + 3); + ASSERT_FLOAT_EQ(e_v(i), a_v(i) + b_v(i) * c_v(i) * d_v(i)); + } +} + +TEST(ATen, mulInt) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); + BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + StmtPtr store_c = c_buf.store({index}, load_a * load_b); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); + ir_eval(a_v, b_v, c_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i); + ASSERT_EQ(b_v(i), 2 * i + 1); + ASSERT_EQ(c_v(i), a_v(i) * b_v(i)); + } +} + +TEST(ATen, mulFloat) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + StmtPtr store_c = c_buf.store({index}, load_a * load_b); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); + ir_eval(a_v, b_v, c_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i); + ASSERT_EQ(b_v(i), 2 * i + 1); + ASSERT_EQ(c_v(i), a_v(i) * b_v(i)); + } +} + +TEST(ATen, divInt) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); + BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + StmtPtr store_c = c_buf.store({index}, load_a / load_b); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = 2 * i + 1; + b_v(i) = i + 1; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); + ir_eval(a_v, b_v, c_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), 2 * i + 1); + ASSERT_EQ(b_v(i), i + 1); + ASSERT_EQ(c_v(i), a_v(i) / b_v(i)); + } +} + +TEST(ATen, divFloat) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + StmtPtr store_c = c_buf.store({index}, load_a / load_b); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = 2 * i + 1; + b_v(i) = i + 1; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); + ir_eval(a_v, b_v, c_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), 2 * i + 1); + ASSERT_EQ(b_v(i), i + 1); + ASSERT_EQ(c_v(i), a_v(i) / b_v(i)); + } +} + +TEST(ATen, maxInt) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); + BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + StmtPtr store_c = c_buf.store({index}, Max::make(load_a, load_b, true)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); + ir_eval(a_v, b_v, c_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i); + ASSERT_EQ(b_v(i), 2 * i + 1); + ASSERT_EQ(c_v(i), std::max(a_v(i), b_v(i))); + } +} + +TEST(ATen, maxFloat) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + StmtPtr store_c = c_buf.store({index}, Max::make(load_a, load_b, true)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); + ir_eval(a_v, b_v, c_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i); + ASSERT_EQ(b_v(i), 2 * i + 1); + ASSERT_EQ(c_v(i), std::fmax(a_v(i), b_v(i))); + } +} + +TEST(ATen, minInt) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); + BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kInt); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + StmtPtr store_c = c_buf.store({index}, Min::make(load_a, load_b, true)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); + ir_eval(a_v, b_v, c_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i); + ASSERT_EQ(b_v(i), 2 * i + 1); + ASSERT_EQ(c_v(i), std::min(a_v(i), b_v(i))); + } +} + +TEST(ATen, minFloat) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + BufHandle c_buf("C", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + ExprHandle load_b = b_buf.load(index); + StmtPtr store_c = c_buf.store({index}, Min::make(load_a, load_b, true)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_c); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); + ir_eval(a_v, b_v, c_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i); + ASSERT_EQ(b_v(i), 2 * i + 1); + ASSERT_EQ(c_v(i), std::fmin(a_v(i), b_v(i))); + } +} + +void __ubsan_ignore_float_divide_by_zero__ testATenreciprocal() { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + StmtPtr store_b = b_buf.store({index}, FloatImm::make(1.0f) / load_a); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); + ir_eval(a_v, b_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i); + ASSERT_EQ(b_v(i), 1.0f / i); + } +} + +TEST(ATen, reluInt) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + StmtPtr store_b = b_buf.store({index}, Max::make(load_a, 0, false)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i - 64; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); + ir_eval(a_v, b_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i - 64); + ASSERT_EQ(b_v(i), std::max(a_v(i), 0)); + } +} + +TEST(ATen, reluFloat) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + StmtPtr store_b = b_buf.store( + {index}, Max::make(load_a, 0, false) // relu does not propagate nans + ); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i - 64; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); + ir_eval(a_v, b_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i - 64); + ASSERT_EQ(b_v(i), std::fmax(a_v(i), 0)); + } +} + +TEST(ATen, logFloat) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + StmtPtr store_b = b_buf.store({index}, log(load_a)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i + 10; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); + ir_eval(a_v, b_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i + 10); + ASSERT_EQ(b_v(i), std::log(a_v(i))); + } +} + +TEST(ATen, fastLogFloat) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + StmtPtr store_b = b_buf.store({index}, fast_log(load_a)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = at::randn({1}).item().to(); + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); + ir_eval(a_v, b_v); + + for (const auto i : c10::irange(kTotalSize)) { + auto test = b_v(i); + auto ref = std::log(a_v(i)); + if (std::isnan(ref)) { + ASSERT_EQ(std::isnan(test), true); + } else { + ASSERT_FLOAT_EQ(test, ref); + } + } +} + +TEST(ATen, fastTanhFloat) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + StmtPtr store_b = b_buf.store({index}, fast_tanh(load_a)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = at::randn({1}).item().to(); + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); + ir_eval(a_v, b_v); + + for (const auto i : c10::irange(kTotalSize)) { + auto test = b_v(i); + auto ref = std::tanh(a_v(i)); + if (std::isnan(ref)) { + ASSERT_EQ(std::isnan(test), true); + } else { + ASSERT_NEAR(test, ref, 1e-6); + } + } +} + +TEST(ATen, fastSigmoidFloat) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + StmtPtr store_b = b_buf.store({index}, fast_sigmoid(load_a)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = at::randn({1}).item().to(); + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); + ir_eval(a_v, b_v); + + for (const auto i : c10::irange(kTotalSize)) { + auto test = b_v(i); + at::Tensor t = at::ones({1}) * a_v(i); + float ref = at::sigmoid(t).item().to(); + if (std::isnan(ref)) { + ASSERT_EQ(std::isnan(test), true); + } else { + ASSERT_NEAR(test, ref, 1e-6); + } + } +} + +TEST(ATen, log10Float) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + StmtPtr store_b = b_buf.store({index}, log10(load_a)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i + 10; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); + ir_eval(a_v, b_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i + 10); + ASSERT_EQ(b_v(i), std::log10(a_v(i))); + } +} + +TEST(ATen, log2Float) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + StmtPtr store_b = b_buf.store({index}, log2(load_a)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i + 10; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); + ir_eval(a_v, b_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i + 10); + ASSERT_EQ(b_v(i), std::log2(a_v(i))); + } +} + +TEST(ATen, expFloat) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + StmtPtr store_b = b_buf.store({index}, exp(load_a)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + a_v(i) = i / 10.0f; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); + ir_eval(a_v, b_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i / 10.0f); + ASSERT_EQ(b_v(i), std::exp(a_v(i))); + } +} + +TEST(ATen, erfFloat) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + StmtPtr store_b = b_buf.store({index}, erf(load_a)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + a_v(i) = i / 10.0f; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); + ir_eval(a_v, b_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i / 10.0f); + ASSERT_EQ(b_v(i), std::erf(a_v(i))); + } +} + +TEST(ATen, cosFloat) { + const int kTotalSize = 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + StmtPtr store_b = b_buf.store({index}, cos(load_a)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + a_v(i) = i / 10.0f; + } + + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf}); + ir_eval(a_v, b_v); + + for (const auto i : c10::irange(kTotalSize)) { + ASSERT_EQ(a_v(i), i / 10.0f); + ASSERT_EQ(b_v(i), std::cos(a_v(i))); + } +} + +TEST(ATen, eqInt) { + constexpr int N = 128; + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + BufHandle c("C", {N}, kInt); + std::vector a_buffer(N, 1); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 0); + + VarHandle i("i", kInt); + auto memcpy_expr = For::make( + i, + 0, + N, + c.store( + {i}, + CompareSelect::make( + a.load(i), b.load(i), CompareSelectOperation::kEQ))); + + SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); + ir_eval(a_buffer, b_buffer, c_buffer); + + assertAllEqual(c_buffer, 1); +} + +TEST(ATen, geInt) { + constexpr int N = 128; + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + BufHandle c("C", {N}, kInt); + std::vector a_buffer(N, 5); + std::vector b_buffer(N, 5); + std::vector c_buffer(N, 0); + + VarHandle i("i", kInt); + auto memcpy_expr = For::make( + i, + 0, + N, + c.store( + {i}, + CompareSelect::make( + a.load(i), b.load(i), CompareSelectOperation::kGE))); + + SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); + ir_eval(a_buffer, b_buffer, c_buffer); + + assertAllEqual(c_buffer, 1); +} + +TEST(ATen, gtInt) { + constexpr int N = 128; + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + BufHandle c("C", {N}, kInt); + std::vector a_buffer(N, 6); + std::vector b_buffer(N, 3); + std::vector c_buffer(N, 0); + + VarHandle i("i", kInt); + auto memcpy_expr = For::make( + i, + 0, + N, + c.store( + {i}, + CompareSelect::make( + a.load(i), b.load(i), CompareSelectOperation::kGT))); + + SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); + ir_eval(a_buffer, b_buffer, c_buffer); + + assertAllEqual(c_buffer, 1); +} + +TEST(ATen, leInt) { + constexpr int N = 128; + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + BufHandle c("C", {N}, kInt); + std::vector a_buffer(N, 5); + std::vector b_buffer(N, 5); + std::vector c_buffer(N, 0); + + VarHandle i("i", kInt); + auto memcpy_expr = For::make( + i, + 0, + N, + c.store( + {i}, + CompareSelect::make( + a.load(i), b.load(i), CompareSelectOperation::kLE))); + + SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); + ir_eval(a_buffer, b_buffer, c_buffer); + + assertAllEqual(c_buffer, 1); +} + +TEST(ATen, ltInt) { + constexpr int N = 128; + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + BufHandle c("C", {N}, kInt); + std::vector a_buffer(N, 5); + std::vector b_buffer(N, 5); + std::vector c_buffer(N, 1); + + VarHandle i("i", kInt); + auto memcpy_expr = For::make( + i, + 0, + N, + c.store( + {i}, + CompareSelect::make( + a.load(i), b.load(i), CompareSelectOperation::kLT))); + + SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); + ir_eval(a_buffer, b_buffer, c_buffer); + + assertAllEqual(c_buffer, 0); +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_base.h b/test/cpp/tensorexpr/test_base.h new file mode 100644 index 0000000000000..68b96fe6c90f7 --- /dev/null +++ b/test/cpp/tensorexpr/test_base.h @@ -0,0 +1,89 @@ +#pragma once + +#if defined(USE_GTEST) +#include +#include +#else +#include +#include "c10/util/Exception.h" +#include "test/cpp/tensorexpr/gtest_assert_float_eq.h" +#define ASSERT_EQ(x, y, ...) TORCH_INTERNAL_ASSERT((x) == (y), __VA_ARGS__) +#define ASSERT_FLOAT_EQ(x, y, ...) \ + TORCH_INTERNAL_ASSERT(AlmostEquals((x), (y)), __VA_ARGS__) +#define ASSERT_NE(x, y, ...) TORCH_INTERNAL_ASSERT((x) != (y), __VA_ARGS__) +#define ASSERT_GT(x, y, ...) TORCH_INTERNAL_ASSERT((x) > (y), __VA_ARGS__) +#define ASSERT_GE(x, y, ...) TORCH_INTERNAL_ASSERT((x) >= (y), __VA_ARGS__) +#define ASSERT_LT(x, y, ...) TORCH_INTERNAL_ASSERT((x) < (y), __VA_ARGS__) +#define ASSERT_LE(x, y, ...) TORCH_INTERNAL_ASSERT((x) <= (y), __VA_ARGS__) + +#define ASSERT_NEAR(x, y, a, ...) \ + TORCH_INTERNAL_ASSERT(std::fabs((x) - (y)) < (a), __VA_ARGS__) + +#define ASSERT_TRUE TORCH_INTERNAL_ASSERT +#define ASSERT_FALSE(x) ASSERT_TRUE(!(x)) +#define ASSERT_THROWS_WITH(statement, substring) \ + try { \ + (void)statement; \ + ASSERT_TRUE(false); \ + } catch (const std::exception& e) { \ + ASSERT_NE(std::string(e.what()).find(substring), std::string::npos); \ + } +#define ASSERT_ANY_THROW(statement) \ + { \ + bool threw = false; \ + try { \ + (void)statement; \ + } catch (const std::exception& e) { \ + threw = true; \ + } \ + ASSERT_TRUE(threw); \ + } + +#endif // defined(USE_GTEST) +#include +#include + +namespace torch { +namespace jit { +namespace tensorexpr { + +template +void ExpectAllNear( + const std::vector& v1, + const std::vector& v2, + V threshold, + const std::string& name = "") { + ASSERT_EQ(v1.size(), v2.size()); + for (size_t i = 0; i < v1.size(); i++) { + ASSERT_NEAR(v1[i], v2[i], threshold); + } +} + +template +void ExpectAllNear( + const std::vector& vec, + const U& val, + V threshold, + const std::string& name = "") { + for (size_t i = 0; i < vec.size(); i++) { + ASSERT_NEAR(vec[i], val, threshold); + } +} + +template +static void assertAllEqual(const std::vector& vec, const T& val) { + for (auto const& elt : vec) { + ASSERT_EQ(elt, val); + } +} + +template +static void assertAllEqual(const std::vector& v1, const std::vector& v2) { + ASSERT_EQ(v1.size(), v2.size()); + for (size_t i = 0; i < v1.size(); ++i) { + ASSERT_EQ(v1[i], v2[i]); + } +} +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_boundsinference.cpp b/test/cpp/tensorexpr/test_boundsinference.cpp new file mode 100644 index 0000000000000..2605842d6e74d --- /dev/null +++ b/test/cpp/tensorexpr/test_boundsinference.cpp @@ -0,0 +1,1019 @@ +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +using namespace torch::jit::tensorexpr; + +static void verifyConstBounds( + const TensorAccessBoundsInfo& access_info, + const std::vector>& ref) { + size_t ndim = ref.size(); + ASSERT_EQ(access_info.start.size(), ndim); + ASSERT_EQ(access_info.stop.size(), ndim); + for (const auto i : c10::irange(ndim)) { + if (ref[i].first >= 0) { // Negative values are used to skip the check + ASSERT_TRUE(access_info.start[i]->isConstant()); + int start_i = immediateAs(access_info.start[i]); + ASSERT_EQ(start_i, ref[i].first); + } + if (ref[i].second >= 0) { + ASSERT_TRUE(access_info.stop[i]->isConstant()); + int stop_i = immediateAs(access_info.stop[i]); + ASSERT_EQ(stop_i, ref[i].second); + } + } +} + +TEST(BoundsInference, _1) { + // Verify that bounds inference works for the following example: + // for i in 0..100: + // b[i] = a[i] + // For this loop bounds inference should yield the following: + // {{b, kStore, 0, 99}, {a, kLoad, 0, 99}} + ExprHandle n(100); + BufHandle a("a", {n}, kFloat); + Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); }); + LoopNest l({b}); + auto bounds_info = inferBounds(l.root_stmt()); + + // We should have two entries: one for 'b' and one for 'a'. + ASSERT_EQ(bounds_info.size(), 2); + ASSERT_EQ(bounds_info.at(a.node()).size(), 1); + ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(a.node())[0], {{0, 99}}); + + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 99}}); +} + +TEST(BoundsInference, _2) { + // Verify that bounds inference works for the following example: + // for i in 0..n: + // b[i] = a[i] + // For this loop bounds inference should yield the following: + // {{b, kStore, 0, n-1}, {a, kLoad, 0, n-1}} + VarHandle n("n", kInt); + BufHandle a("a", {n}, kFloat); + Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); }); + LoopNest l({b}); + auto bounds_info = inferBounds(l.root_stmt()); + + // We should have two entries: one for 'b' and one for 'a'. + ASSERT_EQ(bounds_info.size(), 2); + ASSERT_EQ(bounds_info.at(a.node()).size(), 1); + ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(a.node())[0], {{0, -1}}); + + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(b.buf())[0], {{0, -1}}); +} + +TEST(BoundsInference, _3) { + // Verify that bounds inference works for the following example: + // for i in 0..100: + // b[i] = a[i] * a[i+10] + // For this loop bounds inference should yield the following: + // {{b, kStore, 0, 99}, {a, kLoad, 0, 109}} + ExprHandle n(100); + BufHandle a("a", {n + 10}, kFloat); + Tensor b = Compute( + "b", {n}, [&](const VarHandle& i) { return a.load(i) * a.load(i + 10); }); + LoopNest l({b}); + auto bounds_info = inferBounds(l.root_stmt()); + + // We should have two entries: one for 'b' and one for 'a'. + ASSERT_EQ(bounds_info.size(), 2); + ASSERT_EQ(bounds_info.at(a.node()).size(), 1); + ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(a.node())[0], {{0, 109}}); + + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 99}}); +} + +TEST(BoundsInference, _4) { + // Verify that bounds inference works for the following example: + // + // for y in 0..200: + // for x in 0..320: + // b[y,x] = x*y + // for y in 0..200: + // for x in 0..320: + // c[y,x] = a[y,x] * b[y,x] + ExprHandle W(320); + ExprHandle H(200); + BufHandle a("a", {H, W}, kFloat); + Tensor b = Compute("b", {H, W}, [&](const VarHandle& y, const VarHandle& x) { + return x * y; + }); + Tensor c = Compute("c", {H, W}, [&](const VarHandle& y, const VarHandle& x) { + return a.load(y, x) * b.load(y, x); + }); + LoopNest l({c}); + std::vector loops = l.getLoopStmtsFor(c); + StmtPtr body = l.getLoopBodyFor(c); + { + // Infer bounds on the top-level loop scope + auto bounds_info = inferBounds(loops[0]); + ASSERT_EQ(bounds_info.size(), 3); + + ASSERT_EQ(bounds_info.at(a.node()).size(), 1); + ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(a.node())[0], {{0, 199}, {0, 319}}); + + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 199}, {0, 319}}); + + ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 199}, {0, 319}}); + } + { + // Infer bounds on the inner loop scope + auto bounds_info = inferBounds(loops[1]); + ASSERT_EQ(bounds_info.size(), 3); + + ASSERT_EQ(bounds_info.at(a.node()).size(), 1); + ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(a.node())[0], {{-1, -1}, {0, 319}}); + + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, -1}, {0, 319}}); + + ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {0, 319}}); + } + { + // Infer bounds on the inner loop body's scope + auto bounds_info = inferBounds(body); + ASSERT_EQ(bounds_info.size(), 3); + + ASSERT_EQ(bounds_info.at(a.node()).size(), 1); + ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(a.node())[0], {{-1, -1}, {-1, -1}}); + + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, -1}, {-1, -1}}); + + ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {-1, -1}}); + } +} + +TEST(BoundsInference, _5) { + // Verify that bounds inference works for the following example: + // for i in 0..100: + // b[i] = a[i] + // + // ==> split ==> + // + // for i_outer in 0..100/16: + // for i_inner in 0..16: + // b[i_outer * 16 + i_inner] = a[i_outer * 16 + i_inner] + // for i_tail in 0..100%16: + // b[i_tail + (100/16)*16] = a[i_tail + (100/16)*16]; + ExprHandle n(100); + BufHandle a("a", {n}, kFloat); + Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); }); + LoopNest l({b}); + + ForPtr inner; + ForPtr tail; + std::vector loops = l.getLoopStmtsFor(b); + LoopNest::splitWithTail(loops[0], 16, &inner, &tail); + ForPtr outer = loops[0]; + + { + // Verify inferred bounds for the outer loop + auto bounds_info = inferBounds(outer); + ASSERT_EQ(bounds_info.size(), 2); + + ASSERT_EQ(bounds_info.at(a.node()).size(), 1); + ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(a.node())[0], {{0, 95}}); + + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 95}}); + } + { + // Verify inferred bounds for the tail loop + auto bounds_info = inferBounds(tail); + ASSERT_EQ(bounds_info.size(), 2); + + ASSERT_EQ(bounds_info.at(a.node()).size(), 1); + ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(a.node())[0], {{96, 99}}); + + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(b.buf())[0], {{96, 99}}); + } +} + +TEST(BoundsInference, _6) { + // Verify that bounds inference works for the following example: + // + // for y in 0..200: + // for x in 0..320: + // b[y,x] = x*y + // for y in 0..20: + // for x in 0..32: + // c[y,x] = a[y+100,x+100] * b[y*2,x*5] + ExprHandle W(320); + ExprHandle H(200); + ExprHandle CW(32); + ExprHandle CH(20); + BufHandle a("a", {H, W}, kFloat); + Tensor b = Compute("b", {H, W}, [&](const VarHandle& y, const VarHandle& x) { + return x * y; + }); + Tensor c = + Compute("c", {CH, CW}, [&](const VarHandle& y, const VarHandle& x) { + return a.load(y + 100, x + 100) * b.load(y * 2, x * 5); + }); + LoopNest l({c}); + std::vector loops = l.getLoopStmtsFor(c); + StmtPtr body = l.getLoopBodyFor(c); + { + // Infer bounds on the top-level loop scope + auto bounds_info = inferBounds(loops[0]); + ASSERT_EQ(bounds_info.size(), 3); + + ASSERT_EQ(bounds_info.at(a.node()).size(), 1); + ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(a.node())[0], {{100, 119}, {100, 131}}); + + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 38}, {0, 155}}); + + ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 19}, {0, 31}}); + } + { + // Infer bounds on the inner loop scope + auto bounds_info = inferBounds(loops[1]); + ASSERT_EQ(bounds_info.size(), 3); + + ASSERT_EQ(bounds_info.at(a.node()).size(), 1); + ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(a.node())[0], {{-1, -1}, {100, 131}}); + + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, -1}, {0, 155}}); + + ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {0, 31}}); + } + { + // Infer bounds on the inner loop body's scope + auto bounds_info = inferBounds(body); + ASSERT_EQ(bounds_info.size(), 3); + + ASSERT_EQ(bounds_info.at(a.node()).size(), 1); + ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(a.node())[0], {{-1, -1}, {-1, -1}}); + + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(b.buf())[0], {{-1, -1}, {-1, -1}}); + + ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(c.buf())[0], {{-1, -1}, {-1, -1}}); + } +} + +TEST(BoundsInference, Adjacent) { + ExprHandle H(6); + BufHandle a("a", {20}, kFloat); + Tensor b = Compute("b", {H}, [&](const VarHandle& x) { return a.load(x); }); + Tensor c = + Compute("c", {H}, [&](const VarHandle& x) { return a.load(x + H); }); + LoopNest l({b, c}); + std::vector loops = NodeFinder::find(l.root_stmt()); + + { + // Infer bounds on the top-level loop scope + auto bounds_info = inferBounds(loops[0]); + ASSERT_EQ(bounds_info.size(), 2); + + // reads from a[0:5], writes to b[0:5] + ASSERT_EQ(bounds_info.at(a.node()).size(), 1); + ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(a.node())[0], {{0, 5}}); + + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 5}}); + } + { + // Infer bounds on the inner loop scope + auto bounds_info = inferBounds(loops[1]); + ASSERT_EQ(bounds_info.size(), 2); + + // reads from a[0+6:5+6], writes to c[0:5] + ASSERT_EQ(bounds_info.at(a.node()).size(), 1); + ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(a.node())[0], {{6, 11}}); + + ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 5}}); + } + { + // Infer bounds on the high level program. + auto bounds_info = inferBounds(l.root_stmt()); + ASSERT_EQ(bounds_info.size(), 3); + + // Should be union of above 2 bounds, but this time the bounds of A can be + // merged. + ASSERT_EQ(bounds_info.at(a.node()).size(), 1); + ASSERT_EQ(bounds_info.at(a.node())[0].kind, kLoad); + verifyConstBounds(bounds_info.at(a.node())[0], {{0, 11}}); + + ASSERT_EQ(bounds_info.at(b.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(b.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(b.buf())[0], {{0, 5}}); + + ASSERT_EQ(bounds_info.at(c.buf()).size(), 1); + ASSERT_EQ(bounds_info.at(c.buf())[0].kind, kStore); + verifyConstBounds(bounds_info.at(c.buf())[0], {{0, 5}}); + } +} + +TEST(BoundsInference, MultipleTopLoopLoad) { + BufHandle a("a", {100}, kFloat); + Tensor b = Compute("b", {64}, [&](const VarHandle& x) { return a.load(x); }); + Tensor c = + Compute("c", {32}, [&](const VarHandle& x) { return a.load(x + 10); }); + Tensor d = + Compute("d", {96}, [&](const VarHandle& x) { return a.load(x + 2); }); + LoopNest l({b, c, d}); + + auto bounds_info = inferBounds(l.root_stmt()); + + ASSERT_EQ(bounds_info.size(), 4); + + // a only read. + { + auto bounds = bounds_info[a.node()]; + ASSERT_EQ(bounds.size(), 1); + // One dimension. + auto bound = bounds[0]; + ASSERT_EQ(bound.kind, TensorAccessKind::kLoad); + // Bounds: + // start: Min of the 3 load bounds = Min of loop starts + offset = 0+0 (b). + // stop: Max of the 3 load bounds = Max of loop stops + offset - 1 = + // 96 + 2 - 1 (d). + verifyConstBounds(bound, {{0, 97}}); + } + + // b, c, d only written. + { + auto bounds = bounds_info[b.buf()]; + ASSERT_EQ(bounds.size(), 1); + auto bound = bounds[0]; + ASSERT_EQ(bound.kind, TensorAccessKind::kStore); + // Just the loop extents for b. + verifyConstBounds(bound, {{0, 63}}); + } + { + auto bounds = bounds_info[c.buf()]; + ASSERT_EQ(bounds.size(), 1); + auto bound = bounds[0]; + ASSERT_EQ(bound.kind, TensorAccessKind::kStore); + // Just the loop extents for c. + verifyConstBounds(bound, {{0, 31}}); + } + { + auto bounds = bounds_info[d.buf()]; + ASSERT_EQ(bounds.size(), 1); + auto bound = bounds[0]; + ASSERT_EQ(bound.kind, TensorAccessKind::kStore); + // Just the loop extents for d. + verifyConstBounds(bound, {{0, 95}}); + } +} + +TEST(BoundsInference, MultipleTopLoopStore) { + BufHandle a("a", {100}, kFloat); + BufHandle b("b", {100}, kFloat); + BufHandle c("c", {100}, kFloat); + BufHandle d("d", {100}, kFloat); + VarHandle x("x", kInt); + + // Same as above but the offsets are on the Store now. + // Can't do this through ComputeAPI without transforms we don't have yet. + StmtPtr stmt = Block::make( + {For::make(x, 0, 64, Store::make(b, {x}, Load::make(a, {x}))), + For::make(x, 0, 32, Store::make(c, {x + 10}, Load::make(a, {x}))), + For::make(x, 0, 96, Store::make(d, {x + 2}, Load::make(a, {x})))}); + + auto bounds_info = inferBounds(stmt); + + ASSERT_EQ(bounds_info.size(), 4); + + // a only read. + { + auto bounds = bounds_info[a.node()]; + ASSERT_EQ(bounds.size(), 1); + // One dimension. + auto bound = bounds[0]; + ASSERT_EQ(bound.kind, TensorAccessKind::kLoad); + // Bounds: there are no offsets, so this is just the max loop bounds. + verifyConstBounds(bound, {{0, 95}}); + } + + // b, c, d only written. + { + auto bounds = bounds_info[b.node()]; + ASSERT_EQ(bounds.size(), 1); + auto bound = bounds[0]; + ASSERT_EQ(bound.kind, TensorAccessKind::kStore); + // This should be equivalent to {offset, extent + offset} for the b loop. + // b loop has no offset, so just the loop extents. + verifyConstBounds(bound, {{0, 63}}); + } + { + auto bounds = bounds_info[c.node()]; + ASSERT_EQ(bounds.size(), 1); + auto bound = bounds[0]; + ASSERT_EQ(bound.kind, TensorAccessKind::kStore); + // This should be equivalent to {offset, extent + offset} for the c loop. + // Offset is 10, extent is 32-1. + verifyConstBounds(bound, {{10, 41}}); + } + { + auto bounds = bounds_info[d.node()]; + ASSERT_EQ(bounds.size(), 1); + auto bound = bounds[0]; + ASSERT_EQ(bound.kind, TensorAccessKind::kStore); + // This should be equivalent to {offset, extent + offset} for the d loop. + // Offset is 2, extent is 96-1. + verifyConstBounds(bound, {{2, 97}}); + } +} + +TEST(BoundsInference, CacheReads) { + Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + Tensor B = + Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { + return A.load(i + 30, j + 3); + }); + Tensor C = + Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { + return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); + }); + + LoopNest l({B, C}); + auto bounds_info_before = inferBounds(l.root_stmt()); + + StmtPtr j_loop = l.getLoopStmtsFor(B)[1]; + LoopNest::cacheAccesses(A.buf(), "A_local", j_loop); + + auto bounds_info_after = inferBounds(l.root_stmt()); + + // CacheAccesses should not change existing bounds, but add a new one for the + // cache. + for (auto& pair : bounds_info_after) { + auto beforeIt = bounds_info_before.find(pair.first); + if (beforeIt != bounds_info_before.end()) { + // Same number of TensorAccessBoundInfos. + ASSERT_EQ(pair.second.size(), beforeIt->second.size()); + + for (const auto i : c10::irange(pair.second.size())) { + TensorAccessBoundsInfo& after = pair.second[i]; + TensorAccessBoundsInfo& before = beforeIt->second[i]; + // Same number of dimensions. + ASSERT_EQ(before.start.size(), after.start.size()); + + // Bounds are equal. + for (const auto j : c10::irange(before.start.size())) { + ASSERT_TRUE(exprEquals(before.start[j], after.start[j])); + ASSERT_TRUE(exprEquals(before.stop[j], after.stop[j])); + } + } + } else { + // This should be the cache. + ASSERT_EQ(pair.first->name_hint(), "A_local"); + // Should have both a load and a store. + ASSERT_EQ(pair.second.size(), 2); + TensorAccessBoundsInfo& first = pair.second[0]; + TensorAccessBoundsInfo& second = pair.second[1]; + + ASSERT_NE(first.kind, second.kind); + // 2 dimensions. + ASSERT_EQ(first.start.size(), second.start.size()); + ASSERT_EQ(first.start.size(), 2); + + // bounds for load and store are equal. + for (const auto j : c10::irange(first.start.size())) { + ASSERT_TRUE(exprEquals(first.start[j], second.start[j])); + ASSERT_TRUE(exprEquals(first.stop[j], second.stop[j])); + } + } + } +} + +TEST(BoundsInference, Flattened) { + Tensor b = Compute( + "b", + {3, 4, 5}, + [&](const VarHandle& z, const VarHandle& y, const VarHandle& x) { + return x * y + z; + }); + + LoopNest l({b}); + // Flatten indices. + l.prepareForCodegen(); + auto bounds_info = inferBounds(l.root_stmt()); + + // There's only one buffer. + ASSERT_EQ(bounds_info.size(), 1); + auto& TABI = bounds_info[b.buf()][0]; + ASSERT_EQ(TABI.kind, TensorAccessKind::kStore); + // Flattened bounds should have a single dimension. + ASSERT_EQ(TABI.start.size(), 1); + ASSERT_EQ(TABI.stop.size(), 1); + + // Bounds should be 0 -> (3*4*5)-1 + ASSERT_TRUE(exprEquals(TABI.start[0], alloc(0))); + ASSERT_TRUE(exprEquals(TABI.stop[0], alloc(3 * 4 * 5 - 1))); +} + +TEST(BoundsInference, GetPotentialHazards) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + using namespace analysis; + + { + /* + * A[0] = B[0]; + * B[0] = 3; WAR on B + * A[0] = B[0]; WAW on A, RAW on B + * C[0] = 5; + */ + + StorePtr store1 = Store::make(a, {0}, Load::make(b, {0})); + StorePtr store2 = Store::make(b, {0}, 3); + StorePtr store3 = Store::make(a, {0}, Load::make(b, {0})); + StorePtr store4 = Store::make(c, {0}, 5); + StmtPtr stmt = Block::make({store1, store2, store3, store4}); + + MemDependencyChecker analyzer; + stmt->accept(&analyzer); + + ASSERT_EQ( + HazardKind::WriteAfterRead, + getPotentialHazards(analyzer, store1, store2)); + + ASSERT_EQ( + HazardKind::ReadAfterWrite, + getPotentialHazards(analyzer, store2, store3)); + + ASSERT_EQ( + HazardKind::WriteAfterWrite, + getPotentialHazards(analyzer, store1, store3)); + + // Fourth store has no dependencies + ASSERT_EQ( + HazardKind::NoDependency, + getPotentialHazards(analyzer, store1, store4)); + ASSERT_EQ( + HazardKind::NoDependency, + getPotentialHazards(analyzer, store2, store4)); + ASSERT_EQ( + HazardKind::NoDependency, + getPotentialHazards(analyzer, store3, store4)); + } +} + +TEST(BoundsInference, GetPotentialHazardsLoopNoHazard) { + Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + Tensor B = Compute("B", {64, 64}, [](const VarHandle& i, const VarHandle& j) { + return (i + 1) * (j + 1); + }); + + LoopNest l({A, B}); + + using namespace analysis; + + MemDependencyChecker analyzer; + l.root_stmt()->accept(&analyzer); + + ForPtr loopRootA = l.getLoopStmtsFor(A)[0]; + ForPtr loopRootB = l.getLoopStmtsFor(B)[0]; + + // No dependencies between loops. + ASSERT_EQ( + HazardKind::NoDependency, + getPotentialHazards(analyzer, loopRootA, loopRootB)); +} + +TEST(BoundsInference, GetPotentialHazardsLoopCall) { + Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + Tensor B = + Compute("B", {64, 64}, [&](const VarHandle& i, const VarHandle& j) { + return A.load(i, j) + 5; + }); + + LoopNest l({A, B}); + + using namespace analysis; + + MemDependencyChecker analyzer; + l.root_stmt()->accept(&analyzer); + + ForPtr loopRootA = l.getLoopStmtsFor(A)[0]; + ForPtr loopRootB = l.getLoopStmtsFor(B)[0]; + + ASSERT_EQ( + HazardKind::ReadAfterWrite, + getPotentialHazards(analyzer, loopRootA, loopRootB)); +} + +TEST(BoundsInference, GetPotentialHazardsLoopSplit) { + Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + + LoopNest l({A}); + ForPtr inner, tail; + + // Splitting with tail by something offset creates a tail which also writes to + // A. + ForPtr outer = l.getLoopStmtsFor(A)[0]; + // `outer` loop get transformed to the outer loop after splitting. + LoopNest::splitWithTail(outer, 5, &inner, &tail); + + using namespace analysis; + + MemDependencyChecker analyzer; + l.root_stmt()->accept(&analyzer); + + ASSERT_EQ( + HazardKind::WriteAfterWrite, getPotentialHazards(analyzer, outer, tail)); +} + +TEST(BoundsInference, HasConflictingOverlapSameBufferWithPartialOverlap) { + // Input IR: + // for (const auto j : c10::irange(10, 100)) { + // A[j] = 10 * j; + // } + // for (const auto k : c10::irange(10, 100)) { + // A[k-1] = 20 * k; + // } + BufHandle a_buf("A", {200}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); + auto forK = + For::make(k, 10, 100, Store::make(a_buf, {k - 1}, Mul::make(20, k))); + auto par = Block::make({forJ, forK}); + + tensorexpr::analysis::MemDependencyChecker analyzer; + par->accept(&analyzer); + ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK)); + ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ)); +} + +TEST(BoundsInference, HasConflictingOverlapSameBufferWithFullOverlap) { + // Input IR: + // for (const auto j : c10::irange(10, 100)) { + // A[j] = 10 * j; + // } + // for (const auto k : c10::irange(10, 100)) { + // A[k] = 20 * k; + // } + BufHandle a_buf("A", {200}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); + auto forK = For::make(k, 10, 100, Store::make(a_buf, {k}, Mul::make(20, k))); + auto par = Block::make({forJ, forK}); + + tensorexpr::analysis::MemDependencyChecker analyzer; + par->accept(&analyzer); + ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK)); + ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ)); +} + +TEST(BoundsInference, HasConflictingOverlapSameBufferWithFullOverlapRAW) { + // Input IR: + // for (const auto j : c10::irange(10, 100)) { + // A[j] = 10 * j; + // } + // for (const auto k : c10::irange(10, 100)) { + // B[k] = A[k]; + // } + BufHandle a_buf("A", {200}, kInt); + BufHandle b_buf("B", {200}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); + auto forK = + For::make(k, 10, 100, Store::make(b_buf, {k}, Load::make(a_buf, {k}))); + auto par = Block::make({forJ, forK}); + + tensorexpr::analysis::MemDependencyChecker analyzer; + par->accept(&analyzer); + ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK)); + ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ)); +} + +TEST(BoundsInference, HasConflictingOverlapSameBufferNotOverlapping) { + // Input IR: + // for (const auto j : c10::irange(10, 100)) { + // A[j] = 10 * j; + // } + // for (const auto k : c10::irange(10, 100)) { + // A[k+100] = 20 * k; + // } + BufHandle a_buf("A", {200}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); + auto forK = + For::make(k, 10, 100, Store::make(a_buf, {k + 100}, Mul::make(20, k))); + auto par = Block::make({forJ, forK}); + + tensorexpr::analysis::MemDependencyChecker analyzer; + par->accept(&analyzer); + ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, forK)); + ASSERT_FALSE(hasConflictingOverlap(analyzer, forK, forJ)); +} + +TEST(BoundsInference, HasConflictingOverlap2DBufferWithOverlap) { + // Input IR: + // for (const auto i : c10::irange(20)) { + // for (const auto j : c10::irange(100)) { + // A[i,j] = i * j * 500; + // } + // } + // for (const auto m : c10::irange(20)) { + // for (const auto n : c10::irange(50)) { + // A[m+1,n] = m + n * 100; + // } + // } + BufHandle a_buf("A", {20, 100}, kInt); + BufHandle b_buf("B", {20, 50}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + auto storeA1 = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)); + auto forJ = For::make(j, 0, 100, storeA1); + auto forI = For::make(i, 0, 20, forJ); + auto storeA2 = + Store::make(a_buf, {m + 1, n}, Add::make(m, Mul::make(n, 100))); + auto forN = For::make(n, 0, 50, storeA2); + auto forM = For::make(m, 0, 20, forN); + auto par = Block::make({forI, forM}); + + tensorexpr::analysis::MemDependencyChecker analyzer; + par->accept(&analyzer); + ASSERT_TRUE(hasConflictingOverlap(analyzer, forI, forM)); + ASSERT_TRUE(hasConflictingOverlap(analyzer, forM, forI)); + ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forN)); + ASSERT_TRUE(hasConflictingOverlap(analyzer, forN, forJ)); + ASSERT_TRUE(hasConflictingOverlap(analyzer, storeA1, storeA2)); + ASSERT_TRUE(hasConflictingOverlap(analyzer, storeA2, storeA1)); + ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, storeA2)); + ASSERT_TRUE(hasConflictingOverlap(analyzer, storeA1, forM)); +} + +TEST(BoundsInference, HasConflictingOverlap2DBufferWithNoOverlap) { + // Input IR: + // for (const auto i : c10::irange(20)) { + // for (const auto j : c10::irange(100)) { + // A[i,j] = i * j * 500; + // } + // } + // for (const auto m : c10::irange(20)) { + // for (const auto n : c10::irange(50)) { + // A[m+20,n+100] = m + n * 100; + // } + // } + BufHandle a_buf("A", {20, 100}, kInt); + BufHandle b_buf("B", {20, 50}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + auto storeA1 = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)); + auto forJ = For::make(j, 0, 100, storeA1); + auto forI = For::make(i, 0, 20, forJ); + auto storeA2 = + Store::make(a_buf, {m + 20, n + 100}, Add::make(m, Mul::make(n, 100))); + auto forN = For::make(n, 0, 50, storeA2); + auto forM = For::make(m, 0, 20, forN); + auto par = Block::make({forI, forM}); + + tensorexpr::analysis::MemDependencyChecker analyzer; + par->accept(&analyzer); + ASSERT_FALSE(hasConflictingOverlap(analyzer, forI, forM)); + ASSERT_FALSE(hasConflictingOverlap(analyzer, forM, forI)); + ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, forN)); + ASSERT_FALSE(hasConflictingOverlap(analyzer, forN, forJ)); + ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA1, storeA2)); + ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA2, storeA1)); + ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, storeA2)); + ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA1, forM)); +} + +TEST(BoundsInference, HasConflictingOverlapDifferentBuffers) { + // Input IR: + // for (const auto i : c10::irange(20)) { + // for (const auto j : c10::irange(100)) { + // A[i,j] = i * j * 500; + // } + // } + // for (const auto m : c10::irange(20)) { + // for (const auto n : c10::irange(50)) { + // B[m,n] = m + n * 100; + // } + // } + BufHandle a_buf("A", {20, 100}, kInt); + BufHandle b_buf("B", {20, 50}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + auto storeA1 = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)); + auto forJ = For::make(j, 0, 100, storeA1); + auto forI = For::make(i, 0, 20, forJ); + auto storeA2 = Store::make(b_buf, {m, n}, Add::make(m, Mul::make(n, 100))); + auto forN = For::make(n, 0, 50, storeA2); + auto forM = For::make(m, 0, 20, forN); + auto par = Block::make({forI, forM}); + + tensorexpr::analysis::MemDependencyChecker analyzer; + par->accept(&analyzer); + ASSERT_FALSE(hasConflictingOverlap(analyzer, forI, forM)); + ASSERT_FALSE(hasConflictingOverlap(analyzer, forM, forI)); + ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, forN)); + ASSERT_FALSE(hasConflictingOverlap(analyzer, forN, forJ)); + ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA1, storeA2)); + ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA2, storeA1)); + ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, storeA2)); + ASSERT_FALSE(hasConflictingOverlap(analyzer, storeA1, forM)); +} + +TEST(BoundsInference, HasConflictingOverlapDueToRAWDependence) { + // Input IR: + // for (const auto j : c10::irange(100)) { + // A[j] = 10 * j; + // } + // for (const auto k : c10::irange(100)) { + // B[k] = 20 * A[99-k]; + // } + BufHandle a_buf("A", {100}, kInt); + BufHandle b_buf("B", {100}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); + auto forK = For::make( + k, + 0, + 100, + Store::make( + b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k})))); + auto par = Block::make({forJ, forK}); + + tensorexpr::analysis::MemDependencyChecker analyzer; + par->accept(&analyzer); + ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK)); + ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ)); +} + +TEST(BoundsInference, HasConflictingOverlapDueToWARDependence) { + // Input IR: + // for (const auto k : c10::irange(100)) { + // B[k] = 20 * A[99-k]; + // } + // for (const auto j : c10::irange(100)) { + // A[j] = 10 * j; + // } + BufHandle a_buf("A", {100}, kInt); + BufHandle b_buf("B", {100}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forK = For::make( + k, + 0, + 100, + Store::make( + b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k})))); + auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); + auto par = Block::make({forK, forJ}); + + tensorexpr::analysis::MemDependencyChecker analyzer; + par->accept(&analyzer); + ASSERT_TRUE(hasConflictingOverlap(analyzer, forJ, forK)); + ASSERT_TRUE(hasConflictingOverlap(analyzer, forK, forJ)); +} + +TEST(BoundsInference, HasConflictingOverlapWithLoads) { + // Input IR: + // for (const auto k : c10::irange(10, 100)) { + // B[k] = 20 * A[99-k]; + // } + // for (const auto j : c10::irange(10, 100)) { + // C[j] = 10 * A[j]; + // } + BufHandle a_buf("A", {100}, kInt); + BufHandle b_buf("B", {100}, kInt); + BufHandle c_buf("C", {100}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forK = For::make( + k, + 10, + 100, + Store::make( + b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k})))); + auto forJ = For::make( + j, + 10, + 100, + Store::make(c_buf, {j}, Mul::make(10, Load::make(a_buf, {j})))); + auto par = Block::make({forK, forJ}); + + tensorexpr::analysis::MemDependencyChecker analyzer; + par->accept(&analyzer); + ASSERT_FALSE(hasConflictingOverlap(analyzer, forJ, forK)); + ASSERT_FALSE(hasConflictingOverlap(analyzer, forK, forJ)); +} + +TEST(BoundsInference, IsOverlapping) { + // Input IR: + // for (const auto i : c10::irange(100)) { + // A[i] = i * 10; // storeA1 + // B[i] = A[99-i] * 20; // loadA1 + // C[i] = A[i + 100] * 10; // loadA2 + // A[i + 50] = i * 50; // storeA2 + // A[i + 150] = i * 150; // storeA3 + // } + BufHandle a_buf("A", {300}, kInt); + BufHandle b_buf("B", {100}, kInt); + BufHandle c_buf("C", {100}, kInt); + VarHandle i("i", kInt); + auto storeA1 = Store::make(a_buf, {i}, i * 10); + auto loadA1 = Load::make(a_buf, {ExprHandle(99) - i}); + auto storeB = Store::make(b_buf, {i}, Mul::make(loadA1, 20)); + auto loadA2 = Load::make(a_buf, {i + 100}); + auto storeC = Store::make(c_buf, {i}, Mul::make(loadA2, 10)); + auto storeA2 = Store::make(a_buf, {i + 50}, i * 50); + auto storeA3 = Store::make(a_buf, {i + 150}, i * 150); + auto forI = For::make( + i, 0, 100, Block::make({storeA1, storeB, storeC, storeA2, storeA3})); + tensorexpr::analysis::MemDependencyChecker analyzer; + forI->accept(&analyzer); + ASSERT_TRUE(isOverlapping(analyzer, storeA1, to(loadA1.node()))); + ASSERT_FALSE(isOverlapping(analyzer, storeA1, to(loadA2.node()))); + ASSERT_TRUE(isOverlapping(analyzer, storeA1, storeA2)); + ASSERT_FALSE(isOverlapping(analyzer, storeA1, storeA3)); +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_conv.cpp b/test/cpp/tensorexpr/test_conv.cpp new file mode 100644 index 0000000000000..e72303873a6cf --- /dev/null +++ b/test/cpp/tensorexpr/test_conv.cpp @@ -0,0 +1,234 @@ +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +namespace te = torch::jit::tensorexpr; +namespace F = torch::nn::functional; + +#ifdef TORCH_ENABLE_LLVM + +// Generate test data with few bits of precision, to minimize error +// accumulation from floating-point reordering. +static at::Tensor genTestData(c10::IntArrayRef args) { + return at::trunc(at::randn(args) * 256.0f) / 256.0f; +} + +TEST(Conv, DepthwiseConv2D) { + constexpr int N = 1, C = 72, H = 56, W = 56; + constexpr int K = 72, R = 3, S = 3; + constexpr int kPad = 1, kStride = 2, kGroups = C; + constexpr int CperG = C / kGroups; + + te::BufHandle input("input", {N, C, H, W}, te::kFloat); + te::BufHandle weight("weight", {K, CperG, R, S}, te::kFloat); + te::BufHandle bias("bias", {K}, te::kFloat); + te::Tensor output = + te::conv2d_depthwise(input, weight, bias, kStride, kPad, kGroups); + + te::LoopNest loop({output}); + loop.simplify(); + loop.prepareForCodegen(); + te::LLVMCodeGen cg(loop.root_stmt(), {input, weight, bias, output}); + + auto it = genTestData({N, C, H, W}); + auto wt = genTestData({K, CperG, R, S}); + auto bt = genTestData({K}); + auto ref = at::conv2d(it, wt, bt, kStride, kPad, /*dilation=*/1, kGroups); + auto ot = at::zeros_like(ref); + cg.call( + {it.data_ptr(), + wt.data_ptr(), + bt.data_ptr(), + ot.data_ptr()}); + + ASSERT_TRUE(at::allclose(ref, ot)); +} + +TEST(Conv, DepthwiseConv2DNoBias) { + constexpr int N = 1, C = 72, H = 56, W = 56; + constexpr int K = 72, R = 3, S = 3; + constexpr int kPad = 1, kStride = 2, kGroups = C; + constexpr int CperG = C / kGroups; + + te::BufHandle input("input", {N, C, H, W}, te::kFloat); + te::BufHandle weight("weight", {K, CperG, R, S}, te::kFloat); + te::Tensor output = + te::conv2d_depthwise(input, weight, kStride, kPad, kGroups); + + te::LoopNest loop({output}); + loop.simplify(); + loop.prepareForCodegen(); + te::LLVMCodeGen cg(loop.root_stmt(), {input, weight, output}); + + auto it = genTestData({N, C, H, W}); + auto wt = genTestData({K, CperG, R, S}); + auto ref = + at::conv2d(it, wt, at::Tensor(), kStride, kPad, /*dilation=*/1, kGroups); + auto ot = at::zeros_like(ref); + cg.call({it.data_ptr(), wt.data_ptr(), ot.data_ptr()}); + + ASSERT_TRUE(at::allclose(ref, ot)); +} + +TEST(Conv, DepthwiseConv2DDynamicShapes) { + te::VarHandle N_var("N", te::kInt); + te::VarHandle C_var("C", te::kInt); + te::VarHandle H_var("H", te::kInt); + te::VarHandle W_var("W", te::kInt); + te::VarHandle K_var("K", te::kInt); + te::VarHandle CperG_var("CperG", te::kInt); + te::VarHandle R_var("R", te::kInt); + te::VarHandle S_var("S", te::kInt); + te::VarHandle kPad_var("kPad", te::kInt); + te::VarHandle kStride_var("kStride", te::kInt); + te::VarHandle kGroups_var("kGroups", te::kInt); + + te::BufHandle input("input", {N_var, C_var, H_var, W_var}, te::kFloat); + te::BufHandle weight("weight", {K_var, CperG_var, R_var, S_var}, te::kFloat); + te::Tensor output = te::conv2d_depthwise( + input, + weight, + N_var, + C_var, + H_var, + W_var, + K_var, + CperG_var, + R_var, + S_var, + kStride_var, + kPad_var, + kGroups_var); + + te::LoopNest loop({output}); + loop.simplify(); + loop.prepareForCodegen(); + std::vector buffer_args = { + input, + weight, + N_var, + C_var, + H_var, + W_var, + K_var, + CperG_var, + R_var, + S_var, + kPad_var, + kStride_var, + kGroups_var, + output}; + te::LLVMCodeGen cg(loop.root_stmt(), buffer_args); + + constexpr int N = 1, C = 72, H = 56, W = 56; + constexpr int K = 72, R = 3, S = 3; + constexpr int kPad = 1, kStride = 2, kGroups = C; + constexpr int CperG = C / kGroups; + + auto it = genTestData({N, C, H, W}); + auto wt = genTestData({K, CperG, R, S}); + auto ref = + at::conv2d(it, wt, at::Tensor(), kStride, kPad, /*dilation=*/1, kGroups); + auto ot = at::zeros_like(ref); + std::vector call_args = { + it.data_ptr(), + wt.data_ptr(), + N, + C, + H, + W, + K, + CperG, + R, + S, + kPad, + kStride, + kGroups, + ot.data_ptr()}; + cg.call(call_args); + + ASSERT_TRUE(at::allclose(ref, ot)); +} + +#endif + +TEST(Conv, Conv2D) { + // Input dimensions. + constexpr int N = 1; + constexpr int C = 3; + constexpr int H = 11; + constexpr int W = 11; + + // Filter dimensions. + constexpr int K = 8; + constexpr int R = 3; + constexpr int S = 3; + + // Output dims. + constexpr int OH = H - R + 1; + constexpr int OW = W - S + 1; + + // Compute reference result. + at::Tensor input = torch::randn({N, C, H, W}); + at::Tensor filter = torch::randn({K, C, R, S}); + at::Tensor ref = F::conv2d(input, filter); + + // Double check the output size is as expected. + ASSERT_EQ(ref.size(0), N); + ASSERT_EQ(ref.size(1), K); + ASSERT_EQ(ref.size(2), OH); + ASSERT_EQ(ref.size(3), OW); + + te::BufHandle inputB("input", {N, C, H, W}, te::kFloat); + te::BufHandle filterB("filter", {K, C, R, S}, te::kFloat); + + te::Tensor conv = te::Reduce( + "conv", + {N, K, OH, OW}, + te::Sum(), + // FIXME: We have to use a `std::vector` parameter here and then unpack + // it, because we don't have an overload allowing for an arbitrary number + // of ExprHandle/VarHandle parameters. + [&](const std::vector& v) { + auto const& n = v[0]; + auto const& k = v[1]; + auto const& oh = v[2]; + auto const& ow = v[3]; + auto const& c = v[4]; + auto const& r = v[5]; + auto const& s = v[6]; + // FIXME: We have to use `call` and construct a `std::vector` here + // because the `operator()` overload is only specialized for a small + // number of arguments. + return inputB.load(n, c, oh + r, ow + s) * filterB.load(k, c, r, s); + }, + // FIXME: If you forget one of the reduction dims, you get a segfault. + // Could that be caught by a verifier? + {C, R, S}); + + // FIXME: It'd be nice to have a single header that pulls in things like + // LoopNest, IRSimplifier, etc. + te::LoopNest loop({conv}); + loop.prepareForCodegen(); + te::StmtPtr s = loop.root_stmt(); + s = te::IRSimplifier::simplify(s); + + at::Tensor result = at::empty_like(ref); + te::SimpleIREvaluator cg(s, {inputB, filterB, conv}); + cg.call( + {input.data_ptr(), + filter.data_ptr(), + result.data_ptr()}); + + ASSERT_TRUE(at::allclose(ref, result, 1e-3, 1e-3)); +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_cpp_codegen.cpp b/test/cpp/tensorexpr/test_cpp_codegen.cpp new file mode 100644 index 0000000000000..ed7679053637c --- /dev/null +++ b/test/cpp/tensorexpr/test_cpp_codegen.cpp @@ -0,0 +1,259 @@ +#include + +#include "test/cpp/tensorexpr/test_base.h" + +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +using namespace torch::jit::tensorexpr; + +#define STR_CHECK(node, expected) \ + std::stringstream ss; \ + CppPrinter printer(&ss); \ + printer.visit(node); \ + ASSERT_EQ(ss.str(), expected) + +#define FILE_CHECK(node, pattern) \ + std::stringstream ss; \ + CppPrinter printer(&ss); \ + printer.visit(node); \ + torch::jit::testing::FileCheck().run(pattern, ss.str()) + +TEST(CppPrinter, IntImm) { + auto i = alloc(10); + STR_CHECK(i, "10"); +} + +TEST(CppPrinter, FloatImm) { + auto f = alloc(10); + STR_CHECK(f, "10.f"); +} + +TEST(CppPrinter, FloatImm1) { + auto f = alloc(10); + STR_CHECK(f, "10.f"); +} + +TEST(CppPrinter, DoubleImm) { + auto d = alloc(10); + STR_CHECK(d, "10.0"); +} + +TEST(CppPrinter, DoubleImm1) { + auto d = alloc(10.1); + STR_CHECK(d, "10.1"); +} + +TEST(CppPrinter, HalfImm) { + auto h = alloc(10); + STR_CHECK(h, "10"); +} + +TEST(CppPrinter, Add) { + auto add = alloc(alloc(1), alloc(2)); + STR_CHECK(add, "1 + 2"); +} + +TEST(CppPrinter, AddExpr1) { + auto add = alloc( + alloc(alloc(0), alloc(1)), + alloc(alloc(2), alloc(3))); + STR_CHECK(add, "(0 + 1) + (2 - 3)"); +} + +TEST(CppPrinter, AddExpr2) { + auto add = alloc( + alloc(alloc(0), alloc(1)), + alloc(alloc(2), alloc(3))); + STR_CHECK(add, "0 * 1 + (2 - 3)"); +} + +TEST(CppPrinter, AddExpr3) { + auto add = alloc( + alloc(alloc(0), alloc(1)), + alloc
(alloc(2), alloc(3))); + STR_CHECK(add, "(0 + 1) + 2 / 3"); +} + +TEST(CppPrinter, Mod) { + auto mod = alloc(alloc(1), alloc(2)); + STR_CHECK(mod, "1 % 2"); +} + +TEST(CppPrinter, ModFloat) { + auto mod = alloc(alloc(1), alloc(2)); + STR_CHECK(mod, "std::fmod(1.f, 2.f)"); +} + +TEST(CppPrinter, Max) { + auto max = alloc(alloc(1), alloc(2), false); + STR_CHECK(max, "std::max(1, 2)"); +} + +TEST(CppPrinter, MaxFloat) { + auto max = alloc(alloc(1), alloc(2), false); + STR_CHECK(max, "std::max(1.f, 2.f)"); +} + +TEST(CppPrinter, MaxHalf) { + auto max = alloc(alloc(1), alloc(2), false); + STR_CHECK(max, "(1 < 2) ? 2 : 1"); +} + +TEST(CppPrinter, And) { + auto v = alloc(alloc(1), alloc(2)); + STR_CHECK(v, "1 & 2"); +} + +TEST(CppPrinter, CompareSelect) { + auto cs = alloc( + alloc(1), + alloc(2), + alloc(1), + alloc(2), + CompareSelectOperation::kLE); + STR_CHECK(cs, "((1 <= 2) ? 1.f : 2.f)"); +} + +TEST(CppPrinter, IfThenElse) { + auto cond = alloc(alloc(1), alloc(2)); + auto true_value = alloc(alloc(0), alloc(1)); + auto false_value = alloc(alloc(2), alloc(3)); + auto v = alloc(cond, true_value, false_value); + STR_CHECK(v, "((1 + 2) ? 0 - 1 : 2 * 3)"); +} + +TEST(CppPrinter, AllocateFree) { + BufHandle buf("x", {2, 3}, kInt); + AllocatePtr alloc = Allocate::make(buf); + FreePtr free = Free::make(buf); + BlockPtr block = Block::make({alloc, free}); + + const std::string pattern = R"( + # CHECK: { + # CHECK: int* x = static_cast(malloc(24)); + # CHECK: free(x); + # CHECK: } + )"; + FILE_CHECK(block, pattern); +} + +TEST(CppPrinter, LoadStore) { + BufHandle a("A", {2, 3}, kInt); + BufHandle b("B", {3, 4}, kInt); + auto store = b.store({2, 2}, a.load(1, 1)); + STR_CHECK( + store, "B[(0 + 2 * (1 * 4)) + 2 * 1] = A[(0 + 1 * (1 * 3)) + 1 * 1];\n"); +} + +TEST(CppPrinter, Var) { + auto var = alloc("x", kInt); + STR_CHECK(var, "x"); +} + +TEST(CppPrinter, Cast) { + auto cast = alloc(kFloat, alloc(1)); + STR_CHECK(cast, "static_cast(1)"); +} + +TEST(CppPrinter, BitCast) { + auto cast = alloc(kInt, alloc(20)); + STR_CHECK(cast, "std::bitcast(20.f)"); +} + +TEST(CppPrinter, Let) { + auto var = alloc("x", kFloat); + auto val = alloc(2); + auto let = alloc(var, val); + STR_CHECK(let, "float x = 2.f;\n"); +} + +TEST(CppPrinter, For) { + constexpr int N = 1024; + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + BufHandle c("C", {N}, kInt); + VarHandle i("i", kInt); + auto f = For::make(i, 0, N, c.store({i}, Add::make(a.load(i), b.load(i)))); + const std::string pattern = R"( + # CHECK: for (int i = 0; i < 1024; i++) { + # CHECK: C[i] = (A[i]) + (B[i]); + # CHECK: } + )"; + FILE_CHECK(f, pattern); +} + +TEST(CppPrinter, Cond) { + BufHandle x("X", {1}, kInt); + auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT); + auto cond = + Cond::make(cmp, x.store({0}, x.load(0) + 1), x.store({0}, x.load(0) - 1)); + const std::string pattern = R"( + # CHECK: if (((X[0] < 10) ? 1 : 0)) { + # CHECK: X[0] = (X[0]) + 1; + # CHECK: } else { + # CHECK: X[0] = (X[0]) - 1; + # CHECK: } + )"; + FILE_CHECK(cond, pattern); +} + +TEST(CppPrinter, Intrinsics) { + const std::unordered_set> unsupported_ops{ + kRand, kSigmoid}; + for (const auto i : c10::irange(static_cast(kMaxIntrinsicsOp))) { + IntrinsicsOp op = static_cast(i); + if (unsupported_ops.count(op)) { + continue; + } + + if (Intrinsics::OpArgCount(op) == 1) { + auto v = alloc(op, alloc(2.0f)); + STR_CHECK(v, "std::" + v->func_name() + "(2.f)"); + } else { + auto v = + alloc(op, alloc(1.0f), alloc(2.0f)); + STR_CHECK(v, "std::" + v->func_name() + "(1.f, 2.f)"); + } + } +} + +TEST(CppPrinter, ExternalCall) { + std::vector dims{alloc(2), alloc(2)}; + auto output = alloc("out", dims, kFloat); + auto buf_arg1 = alloc("a", dims, kFloat); + auto buf_arg2 = alloc("b", dims, kFloat); + auto scalar_arg = alloc(alloc(1), alloc(2)); + std::vector buf_args{buf_arg1, buf_arg2}; + std::vector scalar_args{scalar_arg}; + auto call = + alloc(output, "nnc_aten_matmul", buf_args, scalar_args); + const std::string pattern = R"( + # CHECK: { + # CHECK: void* buf_ptrs[]{out, a, b}; + # CHECK: int64_t buf_ranks[]{2, 2, 2}; + # CHECK: int64_t buf_dims[]{2, 2, 2, 2, 2, 2}; + # CHECK: int8_t buf_dtypes[]{6, 6, 6}; + # CHECK: int64_t extra_args[]{1 + 2}; + # CHECK: nnc_aten_matmul( + # CHECK: 3, + # CHECK: buf_ptrs, + # CHECK: buf_ranks, + # CHECK: buf_dims, + # CHECK: buf_dtypes, + # CHECK: 1, + # CHECK: extra_args); + # CHECK: } + )"; + FILE_CHECK(call, pattern); +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_cuda.cpp b/test/cpp/tensorexpr/test_cuda.cpp new file mode 100644 index 0000000000000..8a96c68dc75e4 --- /dev/null +++ b/test/cpp/tensorexpr/test_cuda.cpp @@ -0,0 +1,2344 @@ +#ifdef USE_CUDA + +#include +#include +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +namespace torch { +namespace jit { +using namespace torch::jit::tensorexpr; +using namespace torch::jit::tensorexpr; + +template +static void testCudaTestVectorAdd01_impl() { + const int num_iter = 3; + const int block_count = 16; + const int block_size = 128; + Dtype dtype = ToDtype(); + BufHandle a_buf("a", {num_iter, block_count, block_size}, dtype); + BufHandle b_buf("b", {num_iter, block_count, block_size}, dtype); + Tensor c = Compute( + "c", + { + num_iter, + block_count, + block_size, + }, + [&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) { + return a_buf.load(n, b_id, t_id) + b_buf.load(n, b_id, t_id); + }); + LoopNest l({c}); + std::vector loops = l.getLoopStmtsFor(c); + loops[1]->set_gpu_block_index(0); + loops[2]->set_gpu_thread_index(0); + l.prepareForCodegen(); + StmtPtr stmt = l.root_stmt(); + CudaCodeGen cuda_cg(stmt, c, a_buf, b_buf); + const int N = block_count * block_size * num_iter; + PaddedBuffer a_v(N); + PaddedBuffer b_v(N); + PaddedBuffer c_v(N); + PaddedBuffer c_ref(N); + + for (const auto i : c10::irange(N)) { + a_v(i) = ctype(i); + b_v(i) = ctype(i * 3 + 7); + c_ref(i) = a_v(i) + b_v(i); + } + + // TODO: move gpu support into PaddedBuffer + ctype* a_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, N * sizeof(ctype))); + ctype* b_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&b_dev, N * sizeof(ctype))); + ctype* c_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&c_dev, N * sizeof(ctype))); + C10_CUDA_CHECK( + cudaMemcpy(a_dev, a_v.data(), N * sizeof(ctype), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK( + cudaMemcpy(b_dev, b_v.data(), N * sizeof(ctype), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK( + cudaMemcpy(c_dev, c_v.data(), N * sizeof(ctype), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(c_dev, a_dev, b_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK( + cudaMemcpy(c_v.data(), c_dev, N * sizeof(ctype), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(c_v, c_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(a_dev)); + C10_CUDA_CHECK(cudaFree(b_dev)); + C10_CUDA_CHECK(cudaFree(c_dev)); +} + +float sigmoid(float x) { + return 1.0f / (1.0f + expf(-0.0f - x)); +} + +TEST(Cuda, Sigmoid_CUDA) { + const int num_iter = 3; + const int block_count = 16; + const int block_size = 128; + Dtype dtype = ToDtype(); + BufHandle a_buf("a", {num_iter, block_count, block_size}, dtype); + Tensor c = Compute( + "c", + { + num_iter, + block_count, + block_size, + }, + [&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) { + return sigmoid(sigmoid(a_buf.load(n, b_id, t_id))); + }); + LoopNest l({c}); + std::vector loops = l.getLoopStmtsFor(c); + loops[1]->set_gpu_block_index(0); + loops[2]->set_gpu_thread_index(0); + l.prepareForCodegen(); + StmtPtr stmt = l.root_stmt(); + CudaCodeGen cuda_cg(stmt, c, a_buf); + const int N = block_count * block_size * num_iter; + PaddedBuffer a_v(N); + PaddedBuffer c_v(N); + PaddedBuffer c_ref(N); + + for (const auto i : c10::irange(N)) { + a_v(i) = float(i); + c_ref(i) = sigmoid(sigmoid(a_v(i))); + } + + // TODO: move gpu support into PaddedBuffer + float* a_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, N * sizeof(float))); + float* c_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&c_dev, N * sizeof(float))); + C10_CUDA_CHECK( + cudaMemcpy(a_dev, a_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK( + cudaMemcpy(c_dev, c_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(c_dev, a_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK( + cudaMemcpy(c_v.data(), c_dev, N * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(c_v, c_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(a_dev)); + C10_CUDA_CHECK(cudaFree(c_dev)); +} + +TEST(Cuda, TestVectorAdd01_CUDA) { + // floating types. + testCudaTestVectorAdd01_impl(); + testCudaTestVectorAdd01_impl(); + testCudaTestVectorAdd01_impl(); + + // integer types. + testCudaTestVectorAdd01_impl(); + testCudaTestVectorAdd01_impl(); + testCudaTestVectorAdd01_impl(); + testCudaTestVectorAdd01_impl(); + testCudaTestVectorAdd01_impl(); +} + +static void testCudaTestVectorAdd02_impl(int64_t N, int64_t block_size) { + BufHandle a_buf("a", {N}, kFloat); + BufHandle b_buf("b", {N}, kFloat); + Tensor c = Compute("c", {N}, [&](const VarHandle& n) { + return a_buf.load(n) + b_buf.load(n); + }); + LoopNest l({c}); + ForPtr n_inner; + std::vector loops = l.getLoopStmtsFor(c); + l.splitWithMask(loops[0], block_size, &n_inner); + loops[0]->set_gpu_block_index(0); + n_inner->set_gpu_thread_index(0); + l.prepareForCodegen(); + StmtPtr stmt = l.root_stmt(); + CudaCodeGen cuda_cg(stmt, c, a_buf, b_buf); + PaddedBuffer a_v(N); + PaddedBuffer b_v(N); + PaddedBuffer c_v(N); + PaddedBuffer c_ref(N); + + for (const auto i : c10::irange(N)) { + a_v(i) = i; + b_v(i) = i * 3 + 7; + c_ref(i) = a_v(i) + b_v(i); + } + + // TODO: move gpu support into PaddedBuffer + float* a_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, N * sizeof(float))); + float* b_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&b_dev, N * sizeof(float))); + float* c_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&c_dev, N * sizeof(float))); + C10_CUDA_CHECK( + cudaMemcpy(a_dev, a_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK( + cudaMemcpy(b_dev, b_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK( + cudaMemcpy(c_dev, c_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(c_dev, a_dev, b_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK( + cudaMemcpy(c_v.data(), c_dev, N * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(c_v, c_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(a_dev)); + C10_CUDA_CHECK(cudaFree(b_dev)); + C10_CUDA_CHECK(cudaFree(c_dev)); +} + +TEST(Cuda, TestVectorAdd02_CUDA) { + testCudaTestVectorAdd02_impl(1024, 128); + testCudaTestVectorAdd02_impl(1030, 128); +} + +TEST(Cuda, HalfCast_CUDA) { + auto half = ToDtype(); + BufHandle a("a", {4}, half); + Tensor b = Compute("b", {4}, [&](const VarHandle& i) { + return Cast::make(kFloat, a.load(i)); + }); + + LoopNest l({b}); + l.prepareForCodegen(); + StmtPtr s = l.root_stmt(); + CudaCodeGen cg(s, {a, b}); + + std::vector aData(4, 2.0f); + std::vector bData(4, 0.0f); + at::Half* aDev = nullptr; + float* bDev = nullptr; + auto aSize = aData.size() * sizeof(aData[0]); + auto bSize = bData.size() * sizeof(bData[0]); + + C10_CUDA_CHECK(cudaMalloc(&aDev, aSize)); + C10_CUDA_CHECK(cudaMalloc(&bDev, bSize)); + C10_CUDA_CHECK(cudaMemcpy(aDev, aData.data(), aSize, cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy(bDev, bData.data(), bSize, cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cg.call({aDev, bDev}); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + C10_CUDA_CHECK(cudaMemcpy(aData.data(), aDev, aSize, cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaMemcpy(bData.data(), bDev, bSize, cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + assertAllEqual(bData, 2.0f); + + C10_CUDA_CHECK(cudaFree(aDev)); + C10_CUDA_CHECK(cudaFree(bDev)); +} + +TEST(Cuda, DynamicShape2D_CUDA) { + auto testWithSize = [](int32_t M, int32_t N) { + VarHandle m("m", kInt); + VarHandle n("n", kInt); + BufHandle a("a", {m, n}, kFloat); + BufHandle b("b", {m, n}, kFloat); + Tensor c = + Compute("c", {m, n}, [&](const VarHandle& i, const VarHandle& j) { + return a.load(i, j) + b.load(i, j); + }); + LoopNest l({c}); + l.prepareForCodegen(); + StmtPtr s = l.root_stmt(); + CudaCodeGen cg(s, {a, b, c, m, n}); + + std::vector aData(M * N, 1.0f); + std::vector bData(M * N, 2.0f); + std::vector cData(M * N, 0.0f); + float* aDev = nullptr; + float* bDev = nullptr; + float* cDev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&aDev, aData.size() * sizeof(aData[0]))); + C10_CUDA_CHECK(cudaMalloc(&bDev, bData.size() * sizeof(bData[0]))); + C10_CUDA_CHECK(cudaMalloc(&cDev, cData.size() * sizeof(cData[0]))); + C10_CUDA_CHECK(cudaMemcpy( + aDev, + aData.data(), + aData.size() * sizeof(aData[0]), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + bDev, + bData.data(), + bData.size() * sizeof(bData[0]), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + cDev, + cData.data(), + cData.size() * sizeof(cData[0]), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cg.call({aDev, bDev, cDev, M, N}); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + C10_CUDA_CHECK(cudaMemcpy( + cData.data(), + cDev, + cData.size() * sizeof(cData[0]), + cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(cData, std::vector(M * N, 3.0f), 1e-7); + + C10_CUDA_CHECK(cudaFree(aDev)); + C10_CUDA_CHECK(cudaFree(bDev)); + C10_CUDA_CHECK(cudaFree(cDev)); + }; + testWithSize(32, 32); + testWithSize(1, 16); + testWithSize(27, 13); +} + +TEST(Cuda, TestRand01_CUDA) { + const int num_iter = 3; + const int block_count = 16; + const int block_size = 128; + Tensor c = Compute( + "c", + { + num_iter, + block_count, + block_size, + }, + [&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) { + return Intrinsics::make(IntrinsicsOp::kRand, kFloat); + }); + LoopNest l({c}); + std::vector loops = l.getLoopStmtsFor(c); + loops[1]->set_gpu_block_index(0); + loops[2]->set_gpu_thread_index(0); + l.prepareForCodegen(); + StmtPtr stmt = l.root_stmt(); + CudaCodeGen cuda_cg(stmt, c); + const int N = block_count * block_size * num_iter; + PaddedBuffer c_v(N); + + // TODO: move gpu support into PaddedBuffer + float* c_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&c_dev, N * sizeof(float))); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(c_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK( + cudaMemcpy(c_v.data(), c_dev, N * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + float sum1 = 0; + float sum2 = 0; + float sum3 = 0; + for (const auto i : c10::irange(N)) { + float v = c_v.data()[i]; + sum1 += v; + sum2 += v * v; + sum3 += v * v * v; + ASSERT_TRUE(v >= 0 && v < 1); + } + sum1 /= N; + sum2 /= N; + sum3 /= N; + float sum1_mean = 1.f / 2; + float sum2_mean = 1.f / 3; + float sum3_mean = 1.f / 4; + + ASSERT_NEAR(sum1, sum1_mean, 2e-2); + ASSERT_NEAR(sum2, sum2_mean, 2e-2); + ASSERT_NEAR(sum3, sum3_mean, 2e-2); + C10_CUDA_CHECK(cudaFree(c_dev)); +} + +TEST(Cuda, DynamicShapeSplit_CUDA) { + constexpr int64_t N = 4096; + VarHandle n("n", kLong); + BufHandle a("a", {n}, kFloat); + Tensor b = + Compute("b", {n}, [&](const VarHandle& i) { return a.load(i) * 2.0f; }); + LoopNest l({b}); + ForPtr inner; + std::vector loops = l.getLoopStmtsFor(b); + l.splitWithMask(loops[0], 1024, &inner); + loops[0]->set_gpu_block_index(0); + inner->set_gpu_thread_index(0); + StmtPtr s = l.root_stmt(); + CudaCodeGen cg(s, {a, b, n}); + + std::vector aData(N, 1.0f); + std::vector bData(N, 1.0f); + float* aDev = nullptr; + float* bDev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&aDev, aData.size() * sizeof(aData[0]))); + C10_CUDA_CHECK(cudaMalloc(&bDev, bData.size() * sizeof(bData[0]))); + C10_CUDA_CHECK(cudaMemcpy( + aDev, + aData.data(), + aData.size() * sizeof(aData[0]), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + bDev, + bData.data(), + bData.size() * sizeof(aData[0]), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cg.call({aDev, bDev, N}); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + C10_CUDA_CHECK(cudaMemcpy( + bData.data(), + bDev, + bData.size() * sizeof(aData[0]), + cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(bData, std::vector(N, 2.0f), 1e-7); + + C10_CUDA_CHECK(cudaFree(aDev)); + C10_CUDA_CHECK(cudaFree(bDev)); +} + +TEST(Cuda, OneBlockOneThreadGlobalReduce1_CUDA) { + const static int N = 1024; + BufHandle data_buf("data", {N}, kFloat); + BufHandle output_buf("output", {1}, kFloat); + + // The test adds the following code for trivial reduction: + // for (const auto bidx : c10::irange(1)) { // blockIdx.x + // for (const auto tidx : c10::irange(1)) { // threadIdx.x + // output[0] = 0.f; + // for (const auto i1 : c10::irange(1024)) { + // output[0] = output[0] + data[i1]; + // } + // } + // } + + StorePtr init_store = output_buf.store({0}, 0.f); + VarHandle i1("i1", kInt); + ExprHandle load_data = Load::make(data_buf, {i1}); + ExprHandle load_output = Load::make(output_buf, {0}); + ExprHandle add_value = load_output + load_data; + StorePtr store_output = output_buf.store({0}, add_value); + ForPtr for_output = For::make(i1, 0, N, store_output); + StmtPtr reduce_block = Block::make({init_store, for_output}); + VarHandle thread_idx("tidx", kInt); + LoopOptions thread_idx_options; + thread_idx_options.set_gpu_thread_index(0); + ForPtr thread_idx_loop = + For::make(thread_idx, 0, 1, reduce_block, thread_idx_options); + VarHandle block_idx("bidx", kInt); + LoopOptions block_idx_options; + block_idx_options.set_gpu_block_index(0); + ForPtr block_idx_loop = + For::make(block_idx, 0, 1, thread_idx_loop, block_idx_options); + + CudaCodeGen cuda_cg(block_idx_loop, data_buf, output_buf); + PaddedBuffer data_v(N); + PaddedBuffer output_v(1, "output_v"); + PaddedBuffer output_ref(1, "output_ref"); + + output_ref(0) = 0; + for (const auto i : c10::irange(N)) { + data_v(i) = i; + output_ref(0) += data_v(i); + } + + float* data_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&data_dev, N * sizeof(float))); + C10_CUDA_CHECK(cudaMemcpy( + data_dev, data_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); + float* output_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&output_dev, 1 * sizeof(float))); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(data_dev, output_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK(cudaMemcpy( + output_v.data(), output_dev, 1 * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(output_v, output_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(data_dev)); + C10_CUDA_CHECK(cudaFree(output_dev)); +} + +TEST(Cuda, OneBlockMultiThreadGlobalReduce1_CUDA) { + const static int N = 1024; + + // This test does the following reduction: + // clang-format off + // for b in 0..1 // block-idx + // for t in 0..1024: // thread-idx + // if t < 1: + // b[0] = 0 + // // implied sync_threads + // for t in 0..1024: // thread-idx + // b[0] = b[0] + a[t] // implied atomic + // clang-format on + + BufHandle a_buf("a", {N}, kFloat); + BufHandle b_buf("b", {1}, kFloat); + + StorePtr init_store = b_buf.store({0}, 0.f); + VarHandle t("t", kInt); + VarHandle b("b", kInt); + + // for t in 0..1024: // thread-idx + // if t < 1: + // b[0] = 0 + ExprHandle cond_t_lt_1 = + CompareSelect::make(t, 1, CompareSelectOperation::kLT); + CondPtr masked_init_b = Cond::make(cond_t_lt_1, init_store, nullptr); + LoopOptions thread_idx_options; + thread_idx_options.set_gpu_thread_index(0); + ForPtr for_init = For::make(t, 0, N, masked_init_b, thread_idx_options); + + // for t in 0..1024: // thread-idx + // b[0] = b[0] + a[t] // implied atomic + ExprHandle load_a = Load::make(a_buf, {t}); + ExprHandle load_b = Load::make(b_buf, {0}); + ExprHandle add_value = load_b + load_a; + StorePtr store_b = b_buf.store({0}, add_value); + ForPtr for_b = For::make(t, 0, N, store_b, thread_idx_options); + + StmtPtr reduce_block = Block::make({for_init, for_b}); + + VarHandle block_idx("bidx", kInt); + LoopOptions block_idx_options; + block_idx_options.set_gpu_block_index(0); + ForPtr block_idx_loop = + For::make(block_idx, 0, 1, reduce_block, block_idx_options); + + CudaCodeGen cuda_cg(block_idx_loop, a_buf, b_buf); + PaddedBuffer a_v(N); + PaddedBuffer b_v(1, "b_v"); + PaddedBuffer b_ref(1, "b_ref"); + + b_ref(0) = 0; + for (const auto i : c10::irange(N)) { + a_v(i) = i; + b_ref(0) += a_v(i); + } + + float* a_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, N * sizeof(float))); + C10_CUDA_CHECK( + cudaMemcpy(a_dev, a_v.data(), N * sizeof(float), cudaMemcpyHostToDevice)); + float* b_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&b_dev, 1 * sizeof(float))); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(a_dev, b_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK( + cudaMemcpy(b_v.data(), b_dev, 1 * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(b_v, b_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(a_dev)); + C10_CUDA_CHECK(cudaFree(b_dev)); +} + +TEST(Cuda, NoThreadIdxWrite_1_CUDA) { + // This test does the following reduction: + // + // for k in 0..1: // block-idx + // a[0] = 0 + // for n in 0..2: + // a[0] = a[0] + n + // for m in 0..1024: // thread-idx + // b[m] = m + // a[1] = 1 + // for l in 0..2: + // a[1] = a[1] + n + // + // note that the statements not covered by thread-idx are supposed to be + // covered by its own thread-idx + + const static int N = 1024; + BufHandle a_buf("a", {2}, kFloat); + BufHandle b_buf("b", {N}, kFloat); + + VarHandle k("k", kInt); + VarHandle l("l", kInt); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + + // a[0] = 0 + // for n in 0..2: + // a[0] = a[0] + n + StorePtr store_a0_0 = a_buf.store({0}, 0.f); + ExprHandle load_a0 = Load::make(a_buf, {0}); + ExprHandle v1 = load_a0 + n; + StorePtr store_a0_v1 = a_buf.store({0}, v1); + ForPtr loop_a_0 = For::make(n, 0, 2, store_a0_v1); + + // for m in 0..1024: // thread-idx + // b[m] = m + StorePtr store_bm_m = b_buf.store({m}, m + 0.f); + LoopOptions thread_idx_options; + thread_idx_options.set_gpu_thread_index(0); + ForPtr loop_b_1 = For::make(m, 0, N, store_bm_m, thread_idx_options); + + // a[1] = 1 + // for l in 0..2: + // a[1] = a[1] + l + StorePtr store_a1_1 = a_buf.store({1}, 1.f); + ExprHandle load_a1 = a_buf.load(1); + ExprHandle v2 = load_a1 + l; + StorePtr store_a1_v2 = a_buf.store({1}, v2); + ForPtr loop_a_1 = For::make(l, 0, 2, store_a1_v2); + + StmtPtr reduce_block = + Block::make({store_a0_0, loop_a_0, loop_b_1, store_a1_1, loop_a_1}); + + VarHandle block_idx("bidx", kInt); + LoopOptions block_idx_options; + block_idx_options.set_gpu_block_index(0); + ForPtr block_idx_loop = + For::make(block_idx, 0, 1, reduce_block, block_idx_options); + + CudaCodeGen cuda_cg(block_idx_loop, a_buf, b_buf); + PaddedBuffer a_v(2); + PaddedBuffer b_v(N, "b_v"); + PaddedBuffer a_ref(2, "a_ref"); + PaddedBuffer b_ref(N, "b_ref"); + + a_ref(0) = 0; + for (const auto i : c10::irange(2)) { + a_ref(0) += i; + } + a_ref(1) = a_ref(0) + 1; + for (const auto i : c10::irange(N)) { + b_ref(i) = i; + } + + // TODO: add check of the generated code. + float* a_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, 2 * sizeof(float))); + float* b_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&b_dev, N * sizeof(float))); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(a_dev, b_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK( + cudaMemcpy(a_v.data(), a_dev, 2 * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK( + cudaMemcpy(b_v.data(), b_dev, N * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(a_v, a_ref, 1e-5); + ExpectAllNear(b_v, b_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(a_dev)); + C10_CUDA_CHECK(cudaFree(b_dev)); +} + +TEST(Cuda, SharedMemReduce_1_CUDA) { + // FIXME: this test is flaky in CI. + // This test does the following: + // for k in 0..1: // block-idx + // alloc(c, 64) + // for n in 0..64: // thread-idx + // c(n) = 0 + // for m in 0..128: + // for n in 0..64: // thread_idx + // c(n) = c(n) + a(k, m, n) + // b(k) = 0 + // for n in 0..64: // thread_idx + // b(k) = b(k) + c(n) + // free(c) + + const int M = 128; + const int N = 64; + const int kTotalSize = M * N; + LoopOptions thread_idx_opt; + thread_idx_opt.set_gpu_thread_index(0); + LoopOptions block_idx_opt; + block_idx_opt.set_gpu_block_index(0); + + BufHandle a("a", {1, M, N}, kFloat); + BufHandle b("b", {1}, kFloat); + VarHandle k("k", kInt); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + + std::vector block; + std::vector dims; + dims.push_back(ExprHandle(N).node()); + BufHandle c{alloc("c", dims, kFloat)}; + { + // alloc(c, 64); + AllocatePtr alloc = Allocate::make(c); + block.push_back(alloc); + } + + { + // for n in 0..64: // thread-idx + // c(n) = 0 + StorePtr store_cn_0 = Store::make(c, {n}, 0.f); + ForPtr loop_n1 = For::make(n, 0, N, store_cn_0, thread_idx_opt); + block.push_back(loop_n1); + } + + { + // for m in 0..128: + // for n in 0..64: // thread_idx + // c(n) = c(n) + a(k, m, n) + ExprHandle load_cn = Load::make(kFloat, c, {n}); + ExprHandle a_kmn = Load::make(a, {k * (M * N) + m * N + n}); + ExprHandle v_add = load_cn + a_kmn; + StorePtr store_cn_v = Store::make(c, {n}, v_add); + ForPtr loop_n2 = For::make(n, 0, N, store_cn_v, thread_idx_opt); + ForPtr loop_m1 = For::make(m, 0, M, loop_n2); + block.push_back(loop_m1); + } + + { + // b(k) = 0 + // for n in 0..64: // thread_idx + // b(k) = b(k) + c(n) + StorePtr store_bk_0 = b.store({k}, 0.f); + block.push_back(store_bk_0); + ExprHandle load_bk = b.load(k); + ExprHandle load_cn = Load::make(kFloat, c, {n}); + ExprHandle v_add = load_bk + load_cn; + StorePtr store_bk = b.store({k}, v_add); + ForPtr loop_n3 = For::make(n, 0, N, store_bk, thread_idx_opt); + block.push_back(loop_n3); + } + + { + // free(c) + FreePtr free_stmt = Free::make(c); + block.push_back(free_stmt); + } + + BlockPtr reduce_body = Block::make(block); + ForPtr loop_k1 = For::make(k, 0, 1, reduce_body, block_idx_opt); + + // TODO: check the generated code for correctness. + CudaCodeGen cuda_cg(loop_k1, a, b); + + std::ostringstream oss; + oss << *cuda_cg.stmt(); + + // Check the c write is not masked, but the d write is. + const std::string& verification_pattern = + R"IR( +# CHECK: c_1 = 0 +# CHECK: for (int m = 0; m < 128 +# CHECK: c_1 = c_1 + +# CHECK: __syncthreads(); +# CHECK: if (threadIdx.x<1 +# CHECK: b[blockIdx.x] = +# CHECK: __syncthreads(); +# CHECK: atomicAdd(&b[blockIdx.x], c_1) +)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + PaddedBuffer a_v(1, M, N, "a_v"); + PaddedBuffer b_v(1, "b_v"); + PaddedBuffer b_ref(1, "b_ref"); + + b_ref(0) = 0; + for (const auto i : c10::irange(M)) { + for (const auto j : c10::irange(N)) { + int v = i + j; + a_v(0, i, j) = v; + b_ref(0) += v; + } + } + + float* a_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, kTotalSize * sizeof(float))); + C10_CUDA_CHECK(cudaMemcpy( + a_dev, a_v.data(), kTotalSize * sizeof(float), cudaMemcpyHostToDevice)); + float* b_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&b_dev, 1 * sizeof(float))); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(a_dev, b_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK( + cudaMemcpy(b_v.data(), b_dev, 1 * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(b_v, b_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(a_dev)); + C10_CUDA_CHECK(cudaFree(b_dev)); +} + +TEST(Cuda, LocalMemReduce_1_CUDA) { + // This test does the following: + // for k in 0..1: // block-idx + // b(k) = 0 + // for n in 0..64: // thread-idx + // alloc(c, 1) + // c(0) = 0 + // for m in 0..128: + // c(0) = c(0) + a(k, m, n) + // b(k) = b(k) + c(0) + // free(c) + + const int M = 128; + const int N = 64; + const int kTotalSize = M * N; + LoopOptions thread_idx_opt; + thread_idx_opt.set_gpu_thread_index(0); + LoopOptions block_idx_opt; + block_idx_opt.set_gpu_block_index(0); + + BufHandle a("a", {1, M, N}, kFloat); + BufHandle b("b", {1}, kFloat); + VarHandle k("k", kInt); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + + BufHandle c{ + alloc("c", std::vector({alloc(1)}), kFloat)}; + std::vector block_k; + { + // b(k) = 0 + StorePtr store_bk_0 = b.store({k}, 0.f); + block_k.push_back(store_bk_0); + } + std::vector block_n; + { + // alloc(c, 1); + AllocatePtr alloc = Allocate::make(c); + block_n.push_back(alloc); + } + { + // c(0) = 0 + StorePtr store_c0_0 = Store::make(c, {0}, 0.f); + block_n.push_back(store_c0_0); + } + { + // for m in 0..128: + // c(0) = c(0) + a(k, m, n) + ExprHandle load_c0 = Load::make(kFloat, c, {0}); + ExprHandle a_kmn = a.load(k * (M * N) + m * N + n); + ExprHandle v_add = load_c0 + a_kmn; + StorePtr store_c0_v = Store::make(c, {0}, v_add); + ForPtr loop_m = For::make(m, 0, M, store_c0_v); + block_n.push_back(loop_m); + } + { + // b(k) = b(k) + c(0) + ExprHandle load_bk = b.load(k); + ExprHandle load_c0 = Load::make(kFloat, c, {0}); + ExprHandle v_add = load_bk + load_c0; + StorePtr store_bk = b.store({k}, v_add); + block_n.push_back(store_bk); + } + { + // free(c) + FreePtr free_stmt = Free::make(c); + block_n.push_back(free_stmt); + } + { + BlockPtr block_n_stmt = Block::make(block_n); + ForPtr for_n = For::make(n, 0, N, block_n_stmt, thread_idx_opt); + block_k.push_back(for_n); + } + BlockPtr block_k_stmt = Block::make(block_k); + ForPtr loop_k = For::make(k, 0, 1, block_k_stmt, block_idx_opt); + + CudaCodeGen cuda_cg(loop_k, a, b); + PaddedBuffer a_v(1, M, N, "a_v"); + PaddedBuffer b_v(1, "b_v"); + PaddedBuffer b_ref(1, "b_ref"); + + b_ref(0) = 0; + for (const auto i : c10::irange(M)) { + for (const auto j : c10::irange(N)) { + int v = i + j; + a_v(0, i, j) = v; + b_ref(0) += v; + } + } + + float* a_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, kTotalSize * sizeof(float))); + C10_CUDA_CHECK(cudaMemcpy( + a_dev, a_v.data(), kTotalSize * sizeof(float), cudaMemcpyHostToDevice)); + float* b_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&b_dev, 1 * sizeof(float))); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(a_dev, b_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK( + cudaMemcpy(b_v.data(), b_dev, 1 * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(b_v, b_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(a_dev)); + C10_CUDA_CHECK(cudaFree(b_dev)); +} + +TEST(Cuda, HalfSupport_CUDA) { + auto half = ToDtype(); + BufHandle a("a", {4}, half); + Tensor b = Compute("b", {4}, [&](const VarHandle& i) { + return Cast::make(half, ExprHandle(2.0f) * a.load(i)); + }); + + Tensor c = Compute("c", {4}, [&](const VarHandle& i) { + return Cast::make(kFloat, Cast::make(half, ExprHandle(42)) + b.load(i)); + }); + + Tensor d = Compute("d", {4}, [&](const VarHandle& i) { + return Cast::make(half, c.load(i)); + }); + + LoopNest l({b, c, d}); + l.prepareForCodegen(); + StmtPtr s = l.root_stmt(); + CudaCodeGen cg(s, {a, b, c, d}); + + std::vector aData(4, 2.0f); + std::vector cData(4, 0.0f); + std::vector dData(4, 0.0f); + at::Half* aDev = nullptr; + at::Half* bDev = nullptr; + at::Half* cDev = nullptr; + at::Half* dDev = nullptr; + auto aSize = aData.size() * sizeof(aData[0]); + auto bSize = aData.size() * sizeof(aData[0]); + auto cSize = cData.size() * sizeof(float); + auto dSize = dData.size() * sizeof(dData[0]); + + C10_CUDA_CHECK(cudaMalloc(&aDev, aSize)); + C10_CUDA_CHECK(cudaMalloc(&bDev, bSize)); + C10_CUDA_CHECK(cudaMalloc(&cDev, cSize)); + C10_CUDA_CHECK(cudaMalloc(&dDev, dSize)); + C10_CUDA_CHECK(cudaMemcpy(aDev, aData.data(), aSize, cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy(cDev, cData.data(), cSize, cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy(dDev, dData.data(), dSize, cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cg.call({aDev, bDev, cDev, dDev}); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + C10_CUDA_CHECK(cudaMemcpy(aData.data(), aDev, aSize, cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaMemcpy(cData.data(), cDev, cSize, cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaMemcpy(dData.data(), dDev, dSize, cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + assertAllEqual(cData, 46.0f); + + C10_CUDA_CHECK(cudaFree(aDev)); + C10_CUDA_CHECK(cudaFree(bDev)); + C10_CUDA_CHECK(cudaFree(cDev)); + C10_CUDA_CHECK(cudaFree(dDev)); +} + +TEST(Cuda, HalfPropagation_CUDA) { + auto half = ToDtype(); + BufHandle a("a", {4}, half); + Tensor relu = Compute("relu", {4}, [&](const VarHandle& i) { + return Max::make(a.load(i), ExprHandle(alloc(0)), true); + }); + + LoopNest l({relu}); + l.prepareForCodegen(); + StmtPtr s = l.root_stmt(); + CudaCodeGen cg(s, {a, relu}); + + std::ostringstream oss; + oss << *cg.stmt(); + + // Check the types used by the Max are Float. + const std::string& verification_pattern = + R"IR( +# CHECK: for ( +# CHECK: float v = float(a[i]); +# CHECK: relu[i] = half(Max(v, 0.f +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector aData(4, 2.0f); + std::vector reluData(4, 0.0f); + at::Half* aDev = nullptr; + at::Half* reluDev = nullptr; + auto aSize = aData.size() * sizeof(aData[0]); + auto reluSize = reluData.size() * sizeof(reluData[0]); + + C10_CUDA_CHECK(cudaMalloc(&aDev, aSize)); + C10_CUDA_CHECK(cudaMalloc(&reluDev, reluSize)); + C10_CUDA_CHECK(cudaMemcpy(aDev, aData.data(), aSize, cudaMemcpyHostToDevice)); + C10_CUDA_CHECK( + cudaMemcpy(reluDev, reluData.data(), reluSize, cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cg.call({aDev, reluDev}); + C10_CUDA_CHECK( + cudaMemcpy(reluData.data(), reluDev, reluSize, cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + assertAllEqual(aData, reluData); + + C10_CUDA_CHECK(cudaFree(aDev)); + C10_CUDA_CHECK(cudaFree(reluDev)); +} + +TEST(Cuda, UnusedHalfArgument_CUDA) { + BufHandle a("a", {4}, kFloat); + auto half = ToDtype(); + BufHandle b("b", {4}, half); + Tensor relu = Compute("relu", {4}, [&](const VarHandle& i) { + return Max::make(a.load(i), ExprHandle(alloc(0)), true); + }); + + LoopNest l({relu}); + l.prepareForCodegen(); + StmtPtr s = l.root_stmt(); + CudaCodeGen cg(s, {a, b, relu}); + + std::ostringstream oss; + oss << *cg.stmt(); + + // Check the types used by the Max are Float. + const std::string& verification_pattern = + R"IR( +# CHECK: for ( +# CHECK: float v = a[i]; +# CHECK: relu[i] = Max(v, 0.f +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // Sanity Cbeck; + std::vector aData(4, 2.0f); + std::vector bData(4, 2.0f); + std::vector reluData(4, 0.0f); + at::Half* aDev = nullptr; + at::Half* bDev = nullptr; + at::Half* reluDev = nullptr; + auto aSize = aData.size() * sizeof(aData[0]); + auto bSize = bData.size() * sizeof(bData[0]); + auto reluSize = reluData.size() * sizeof(reluData[0]); + + C10_CUDA_CHECK(cudaMalloc(&aDev, aSize)); + C10_CUDA_CHECK(cudaMalloc(&bDev, bSize)); + C10_CUDA_CHECK(cudaMalloc(&reluDev, reluSize)); + C10_CUDA_CHECK(cudaMemcpy(aDev, aData.data(), aSize, cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy(bDev, bData.data(), bSize, cudaMemcpyHostToDevice)); + C10_CUDA_CHECK( + cudaMemcpy(reluDev, reluData.data(), reluSize, cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cg.call({aDev, bDev, reluDev}); + C10_CUDA_CHECK( + cudaMemcpy(reluData.data(), reluDev, reluSize, cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + assertAllEqual(aData, reluData); + + C10_CUDA_CHECK(cudaFree(aDev)); + C10_CUDA_CHECK(cudaFree(bDev)); + C10_CUDA_CHECK(cudaFree(reluDev)); +} + +TEST(Cuda, PrioritizeDependents_CUDA) { + BufHandle a("a", {10}, kFloat); + BufHandle b("b", {12}, kFloat); + BufHandle c("c", {12}, kFloat); + + LoopOptions block_idx_opt; + block_idx_opt.set_gpu_block_index(0); + + VarHandle i("i", kInt); + VarHandle j("j", kInt); + + /* + * for (const auto i : c10::irange(12)) { + * c[i] = (i < 10 ? a[i] + b[i] : b[i]); + * } + */ + ExprHandle load_a = a.load({i}); + ExprHandle load_b = b.load({i}); + ExprHandle cmp = CompareSelect::make(i, 10, CompareSelectOperation::kLT); + ExprHandle ite = IfThenElse::make(cmp, Add::make(load_a, load_b), load_b); + + ForPtr loop = + For::make(i, 0, 12, Block::make({c.store({i}, ite)}), block_idx_opt); + + CudaCodeGen cuda_cg(loop, a, b, c); + + PaddedBuffer a_v(10, "a_v"); + PaddedBuffer b_v(12, "b_v"); + PaddedBuffer c_v(12, "c_v"); + PaddedBuffer c_ref(12, "c_ref"); + + for (const auto i : c10::irange(10)) { + a_v(i) = i * 100; + b_v(i) = i; + c_v(i) = 0; + } + + for (const auto i : c10::irange(10, 12)) { + b_v(i) = i; + c_v(i) = 0; + } + + float* a_dev = nullptr; + float* b_dev = nullptr; + float* c_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, 10 * sizeof(float))); + C10_CUDA_CHECK(cudaMalloc(&b_dev, 12 * sizeof(float))); + C10_CUDA_CHECK(cudaMalloc(&c_dev, 12 * sizeof(float))); + + C10_CUDA_CHECK(cudaMemcpy( + a_dev, a_v.data(), 10 * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + b_dev, b_v.data(), 12 * sizeof(float), cudaMemcpyHostToDevice)); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(a_dev, b_dev, c_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK(cudaMemcpy( + c_v.data(), c_dev, 12 * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + for (const auto i : c10::irange(12)) { + if (i < 10) { + c_ref(i) = i + i * 100; + } else { + c_ref(i) = i; + } + } + + ExpectAllNear(c_v, c_ref, 1e-5); +} + +/// Tests the case where there are two loops which have different extents bound +/// to the same block dimension. We must mask the smaller extent loop body. +TEST(Cuda, MaskBlockDim_CUDA) { + int A_SIZE = 100; + int B_SIZE = 50; + BufHandle a_buf("a", {A_SIZE}, kFloat); + BufHandle b_buf("b", {B_SIZE}, kFloat); + Tensor c = Compute( + "c", {A_SIZE}, [&](const VarHandle& i) { return a_buf.load(i) + 10; }); + Tensor d = Compute("d", {B_SIZE}, [&](const VarHandle& i) { + return a_buf.load(i) + b_buf.load(i); + }); + + LoopNest l({c, d}); + std::vector loops = l.getLoopStmtsFor(c); + loops[0]->set_gpu_block_index(0); + loops = l.getLoopStmtsFor(d); + loops[0]->set_gpu_block_index(0); + + l.prepareForCodegen(); + StmtPtr stmt = l.root_stmt(); + CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); + + std::ostringstream oss; + oss << *cuda_cg.stmt(); + + // Check the c write is not masked, but the d write is. + const std::string& verification_pattern = + R"IR( +# CHECK-NOT: if (blockIdx +# CHECK: c[blockIdx.x] = +# CHECK: if (blockIdx.x<50 +# CHECK: d[blockIdx.x] =)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto blockExtents = cuda_cg.gpu_block_extents(); + auto threadExtents = cuda_cg.gpu_thread_extents(); + ASSERT_TRUE(exprEquals(blockExtents[0], alloc(A_SIZE))); + ASSERT_TRUE(exprEquals(threadExtents[0], alloc(1))); + + // Sanity check that the kernel works. + PaddedBuffer a_v(A_SIZE); + PaddedBuffer b_v(B_SIZE); + PaddedBuffer c_v(A_SIZE); + PaddedBuffer d_v(B_SIZE); + + PaddedBuffer c_ref(A_SIZE); + PaddedBuffer d_ref(B_SIZE); + + for (const auto i : c10::irange(A_SIZE)) { + a_v(i) = (float)i; + c_ref(i) = (float)(i + 10); + } + + for (const auto i : c10::irange(B_SIZE)) { + b_v(i) = (float)(B_SIZE - i); + d_ref(i) = a_v(i) + b_v(i); + } + + float* a_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, A_SIZE * sizeof(float))); + float* b_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&b_dev, B_SIZE * sizeof(float))); + float* c_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&c_dev, A_SIZE * sizeof(float))); + float* d_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&d_dev, B_SIZE * sizeof(float))); + C10_CUDA_CHECK(cudaMemcpy( + a_dev, a_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + b_dev, b_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + c_dev, c_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + d_dev, d_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(c_dev, d_dev, a_dev, b_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK(cudaMemcpy( + c_v.data(), c_dev, A_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaMemcpy( + d_v.data(), d_dev, B_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(c_v, c_ref, 1e-5); + ExpectAllNear(d_v, d_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(a_dev)); + C10_CUDA_CHECK(cudaFree(b_dev)); + C10_CUDA_CHECK(cudaFree(c_dev)); + C10_CUDA_CHECK(cudaFree(d_dev)); +} + +/// Tests the case with two loops, which have different extents that are bound +/// to the same thread dimension. This is the same as the above - the smaller +/// rank write should be masked. But this time we also need to syncthreads. +TEST(Cuda, MaskThreadDim_CUDA) { + int A_SIZE = 50; + int B_SIZE = 100; + BufHandle a_buf("a", {A_SIZE}, kFloat); + BufHandle b_buf("b", {B_SIZE}, kFloat); + Tensor c = Compute( + "c", {A_SIZE}, [&](const VarHandle& i) { return a_buf.load(i) + 10; }); + Tensor d = Compute("d", {B_SIZE}, [&](const VarHandle& i) { + return a_buf.load(i / 2) + b_buf.load(i); + }); + + LoopNest l({c, d}); + std::vector loops = l.getLoopStmtsFor(c); + loops[0]->set_gpu_thread_index(0); + loops = l.getLoopStmtsFor(d); + loops[0]->set_gpu_thread_index(0); + + l.prepareForCodegen(); + StmtPtr stmt = l.root_stmt(); + CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); + + std::ostringstream oss; + oss << *cuda_cg.stmt(); + + // Check the c write is masked, but the d write is not. + const std::string& verification_pattern = + R"IR( +# CHECK: if (threadIdx.x<50 +# CHECK: c[threadIdx.x] = +# CHECK: __syncthreads(); +# CHECK-NOT: if (threadIdx.x +# CHECK: d[threadIdx.x] =)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto blockExtents = cuda_cg.gpu_block_extents(); + auto threadExtents = cuda_cg.gpu_thread_extents(); + ASSERT_TRUE(exprEquals(blockExtents[0], alloc(1))); + ASSERT_TRUE(exprEquals(threadExtents[0], alloc(B_SIZE))); + + PaddedBuffer a_v(A_SIZE); + PaddedBuffer b_v(B_SIZE); + PaddedBuffer c_v(A_SIZE); + PaddedBuffer d_v(B_SIZE); + + PaddedBuffer c_ref(A_SIZE); + PaddedBuffer d_ref(B_SIZE); + + for (const auto i : c10::irange(A_SIZE)) { + a_v(i) = (float)i; + c_ref(i) = (float)(i + 10); + } + + for (const auto i : c10::irange(B_SIZE)) { + b_v(i) = (float)(B_SIZE - i); + d_ref(i) = a_v(i / 2) + b_v(i); + } + + float* a_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, A_SIZE * sizeof(float))); + float* b_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&b_dev, B_SIZE * sizeof(float))); + float* c_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&c_dev, A_SIZE * sizeof(float))); + float* d_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&d_dev, B_SIZE * sizeof(float))); + C10_CUDA_CHECK(cudaMemcpy( + a_dev, a_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + b_dev, b_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + c_dev, c_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + d_dev, d_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(c_dev, d_dev, a_dev, b_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK(cudaMemcpy( + c_v.data(), c_dev, A_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaMemcpy( + d_v.data(), d_dev, B_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(c_v, c_ref, 1e-5); + ExpectAllNear(d_v, d_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(a_dev)); + C10_CUDA_CHECK(cudaFree(b_dev)); + C10_CUDA_CHECK(cudaFree(c_dev)); + C10_CUDA_CHECK(cudaFree(d_dev)); +} + +/// Tests the case where there are two loops, and each is bound to a different +/// block dimension. In this case all writes should be masked since they occur +/// in distinct dimensions. +// Note: this is an extremely dumb pattern which we should never see, but is a +// useful edge case to make sure we've got things covered. +TEST(Cuda, MaskMultiBlockDim_CUDA) { + int A_SIZE = 100; + int B_SIZE = 50; + BufHandle a_buf("a", {A_SIZE}, kFloat); + BufHandle b_buf("b", {B_SIZE}, kFloat); + Tensor c = Compute( + "c", {A_SIZE}, [&](const VarHandle& i) { return a_buf.load(i) + 10; }); + Tensor d = Compute("d", {B_SIZE}, [&](const VarHandle& i) { + return a_buf.load(i) + b_buf.load(i); + }); + + LoopNest l({c, d}); + std::vector loops = l.getLoopStmtsFor(c); + loops[0]->set_gpu_block_index(0); + loops = l.getLoopStmtsFor(d); + loops[0]->set_gpu_block_index(1); + + l.prepareForCodegen(); + StmtPtr stmt = l.root_stmt(); + CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); + + std::ostringstream oss; + oss << *cuda_cg.stmt(); + + // Write to c should be masked against y, write to d against x. + const std::string& verification_pattern = + R"IR( +# CHECK: if (blockIdx.y<1 +# CHECK: c[blockIdx.x] = +# CHECK: if (blockIdx.x<1 +# CHECK: d[blockIdx.y] =)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto blockExtents = cuda_cg.gpu_block_extents(); + auto threadExtents = cuda_cg.gpu_thread_extents(); + ASSERT_TRUE(exprEquals(blockExtents[0], alloc(A_SIZE))); + ASSERT_TRUE(exprEquals(blockExtents[1], alloc(B_SIZE))); + + PaddedBuffer a_v(A_SIZE); + PaddedBuffer b_v(B_SIZE); + PaddedBuffer c_v(A_SIZE); + PaddedBuffer d_v(B_SIZE); + + PaddedBuffer c_ref(A_SIZE); + PaddedBuffer d_ref(B_SIZE); + + for (const auto i : c10::irange(A_SIZE)) { + a_v(i) = (float)i; + c_ref(i) = (float)(i + 10); + } + + for (const auto i : c10::irange(B_SIZE)) { + b_v(i) = (float)(B_SIZE - i); + d_ref(i) = a_v(i) + b_v(i); + } + + float* a_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, A_SIZE * sizeof(float))); + float* b_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&b_dev, B_SIZE * sizeof(float))); + float* c_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&c_dev, A_SIZE * sizeof(float))); + float* d_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&d_dev, B_SIZE * sizeof(float))); + C10_CUDA_CHECK(cudaMemcpy( + a_dev, a_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + b_dev, b_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + c_dev, c_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + d_dev, d_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(c_dev, d_dev, a_dev, b_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK(cudaMemcpy( + c_v.data(), c_dev, A_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaMemcpy( + d_v.data(), d_dev, B_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(c_v, c_ref, 1e-5); + ExpectAllNear(d_v, d_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(a_dev)); + C10_CUDA_CHECK(cudaFree(b_dev)); + C10_CUDA_CHECK(cudaFree(c_dev)); + C10_CUDA_CHECK(cudaFree(d_dev)); +} + +/// Tests the case where both the blockDim and threadDim are bound to different +/// loops. In this instance both stores should be masked since they are +/// distinct. +// Note: this is an extremely dumb pattern which we should never see, but is a +// useful edge case to make sure we've got things covered. +TEST(Cuda, MaskBlockAndThreadDim_CUDA) { + int A_SIZE = 100; + int B_SIZE = 50; + BufHandle a_buf("a", {A_SIZE}, kFloat); + BufHandle b_buf("b", {B_SIZE}, kFloat); + Tensor c = Compute( + "c", {A_SIZE}, [&](const VarHandle& i) { return a_buf.load(i) + 10; }); + Tensor d = Compute("d", {B_SIZE}, [&](const VarHandle& i) { + return a_buf.load(i) + b_buf.load(i); + }); + + LoopNest l({c, d}); + std::vector loops = l.getLoopStmtsFor(c); + loops[0]->set_gpu_block_index(0); + loops = l.getLoopStmtsFor(d); + loops[0]->set_gpu_thread_index(0); + + l.prepareForCodegen(); + StmtPtr stmt = l.root_stmt(); + CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); + + std::ostringstream oss; + oss << *cuda_cg.stmt(); + + const std::string& verification_pattern = + R"IR( +# CHECK: if (threadIdx.x<1 +# CHECK: c[blockIdx.x] = +# CHECK: } +# CHECK: if (blockIdx.x<1 +# CHECK: d[threadIdx.x] =)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto blockExtents = cuda_cg.gpu_block_extents(); + auto threadExtents = cuda_cg.gpu_thread_extents(); + ASSERT_TRUE(exprEquals(blockExtents[0], alloc(A_SIZE))); + ASSERT_TRUE(exprEquals(threadExtents[0], alloc(B_SIZE))); + + PaddedBuffer a_v(A_SIZE); + PaddedBuffer b_v(B_SIZE); + PaddedBuffer c_v(A_SIZE); + PaddedBuffer d_v(B_SIZE); + + PaddedBuffer c_ref(A_SIZE); + PaddedBuffer d_ref(B_SIZE); + + for (const auto i : c10::irange(A_SIZE)) { + a_v(i) = (float)i; + c_ref(i) = (float)(i + 10); + } + + for (const auto i : c10::irange(B_SIZE)) { + b_v(i) = (float)(B_SIZE - i); + d_ref(i) = a_v(i) + b_v(i); + } + + float* a_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, A_SIZE * sizeof(float))); + float* b_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&b_dev, B_SIZE * sizeof(float))); + float* c_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&c_dev, A_SIZE * sizeof(float))); + float* d_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&d_dev, B_SIZE * sizeof(float))); + C10_CUDA_CHECK(cudaMemcpy( + a_dev, a_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + b_dev, b_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + c_dev, c_v.data(), A_SIZE * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + d_dev, d_v.data(), B_SIZE * sizeof(float), cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(c_dev, d_dev, a_dev, b_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK(cudaMemcpy( + c_v.data(), c_dev, A_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaMemcpy( + d_v.data(), d_dev, B_SIZE * sizeof(float), cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(c_v, c_ref, 1e-5); + ExpectAllNear(d_v, d_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(a_dev)); + C10_CUDA_CHECK(cudaFree(b_dev)); + C10_CUDA_CHECK(cudaFree(c_dev)); + C10_CUDA_CHECK(cudaFree(d_dev)); +} + +/// Tests the case where the loopnest has two loops of depth two: each with the +/// outer loop bound to blockDim.x and the inner loop bound to threadDim.x. In +/// this case all writes with a rank smaller than the max should be masked. +TEST(Cuda, MaskMultiDim_CUDA) { + int OUTER_SIZE = 10; + int A_SIZE = 100; + int B_SIZE = 50; + BufHandle a_buf("a", {OUTER_SIZE, A_SIZE}, kFloat); + BufHandle b_buf("b", {OUTER_SIZE, B_SIZE}, kFloat); + Tensor c = Compute( + "C", {OUTER_SIZE, A_SIZE}, [&](const VarHandle& i, const VarHandle& j) { + return ExprHandle(2) * a_buf.load(i, j); + }); + Tensor d = Compute( + "D", {OUTER_SIZE, B_SIZE}, [&](const VarHandle& i, const VarHandle& j) { + return c.load(i, j * 2) + b_buf.load(i, j); + }); + + LoopNest l({c, d}); + std::vector loops = l.getLoopStmtsFor(c); + loops[0]->set_gpu_block_index(0); + loops[1]->set_gpu_thread_index(0); + loops = l.getLoopStmtsFor(d); + loops[0]->set_gpu_block_index(0); + loops[1]->set_gpu_thread_index(0); + + l.prepareForCodegen(); + StmtPtr stmt = l.root_stmt(); + CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); + + std::ostringstream oss; + oss << *cuda_cg.stmt(); + + // The write to D should be masked, but not the write to C. + const std::string& verification_pattern = + R"IR( +# CHECK-NOT: if ( +# CHECK: C[threadIdx.x + 100 * blockIdx.x] = +# CHECK: __syncthreads(); +# CHECK: if (threadIdx.x<50 +# CHECK: D[threadIdx.x + 50 * blockIdx.x] =)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto blockExtents = cuda_cg.gpu_block_extents(); + auto threadExtents = cuda_cg.gpu_thread_extents(); + ASSERT_TRUE(exprEquals(blockExtents[0], alloc(OUTER_SIZE))); + ASSERT_TRUE(exprEquals(threadExtents[0], alloc(A_SIZE))); + + PaddedBuffer a_v(OUTER_SIZE, A_SIZE); + PaddedBuffer b_v(OUTER_SIZE, B_SIZE); + PaddedBuffer c_v(OUTER_SIZE, A_SIZE); + PaddedBuffer d_v(OUTER_SIZE, B_SIZE); + + PaddedBuffer c_ref(OUTER_SIZE, A_SIZE); + PaddedBuffer d_ref(OUTER_SIZE, B_SIZE); + + for (const auto o : c10::irange(OUTER_SIZE)) { + for (const auto i : c10::irange(A_SIZE)) { + a_v(o, i) = (float)i; + c_ref(o, i) = (float)(i * 2); + } + } + + for (const auto o : c10::irange(OUTER_SIZE)) { + for (const auto i : c10::irange(B_SIZE)) { + b_v(o, i) = (float)(B_SIZE - i); + d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i); + } + } + + float* a_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_SIZE * A_SIZE * sizeof(float))); + float* b_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_SIZE * B_SIZE * sizeof(float))); + float* c_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_SIZE * A_SIZE * sizeof(float))); + float* d_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_SIZE * B_SIZE * sizeof(float))); + C10_CUDA_CHECK(cudaMemcpy( + a_dev, + a_v.data(), + OUTER_SIZE * A_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + b_dev, + b_v.data(), + OUTER_SIZE * B_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + c_dev, + c_v.data(), + OUTER_SIZE * A_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + d_dev, + d_v.data(), + OUTER_SIZE * B_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(c_dev, d_dev, a_dev, b_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK(cudaMemcpy( + c_v.data(), + c_dev, + OUTER_SIZE * A_SIZE * sizeof(float), + cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaMemcpy( + d_v.data(), + d_dev, + OUTER_SIZE * B_SIZE * sizeof(float), + cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(c_v, c_ref, 1e-5); + ExpectAllNear(d_v, d_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(a_dev)); + C10_CUDA_CHECK(cudaFree(b_dev)); + C10_CUDA_CHECK(cudaFree(c_dev)); + C10_CUDA_CHECK(cudaFree(d_dev)); +} + +// Tests the case where loop extents are symbolic and not known at compile time. +// In this case both stores must be masked against the extent of the other loop, +// incase it is larger. +TEST(Cuda, MaskMultiDimSymbolic_CUDA) { + VarHandle OUTER_SIZE("OUTER_SIZE", kLong); + VarHandle A_SIZE("A_SIZE", kLong); + VarHandle B_SIZE("B_SIZE", kLong); + BufHandle a_buf("a", {OUTER_SIZE, A_SIZE}, kFloat); + BufHandle b_buf("b", {OUTER_SIZE, B_SIZE}, kFloat); + Tensor c = Compute( + "C", {OUTER_SIZE, A_SIZE}, [&](const VarHandle& i, const VarHandle& j) { + return ExprHandle(2) * a_buf.load(i, j); + }); + Tensor d = Compute( + "D", {OUTER_SIZE, B_SIZE}, [&](const VarHandle& i, const VarHandle& j) { + return c.load(i, j * 2) + b_buf.load(i, j); + }); + + LoopNest l({c, d}); + std::vector loops = l.getLoopStmtsFor(c); + loops[0]->set_gpu_block_index(0); + loops[1]->set_gpu_thread_index(0); + loops = l.getLoopStmtsFor(d); + loops[0]->set_gpu_block_index(0); + loops[1]->set_gpu_thread_index(0); + + l.prepareForCodegen(); + StmtPtr stmt = l.root_stmt(); + CudaCodeGen cuda_cg(stmt, c, d, OUTER_SIZE, A_SIZE, B_SIZE, a_buf, b_buf); + + std::ostringstream oss; + oss << *cuda_cg.stmt(); + + // Since we don't know which is bigger (A_SIZE or B_SIZE) we must mask both. + const std::string& verification_pattern = + R"IR( +# CHECK: if (threadIdx.x(A_SIZE.node(), B_SIZE.node(), true))); + + int64_t OUTER_EXTENT = 10; + int64_t A_EXTENT = 100; + int64_t B_EXTENT = 50; + + PaddedBuffer a_v(OUTER_EXTENT, A_EXTENT); + PaddedBuffer b_v(OUTER_EXTENT, B_EXTENT); + PaddedBuffer c_v(OUTER_EXTENT, A_EXTENT); + PaddedBuffer d_v(OUTER_EXTENT, B_EXTENT); + + PaddedBuffer c_ref(OUTER_EXTENT, A_EXTENT); + PaddedBuffer d_ref(OUTER_EXTENT, B_EXTENT); + + for (const auto o : c10::irange(OUTER_EXTENT)) { + for (const auto i : c10::irange(A_EXTENT)) { + a_v(o, i) = (float)i; + c_ref(o, i) = (float)(i * 2); + } + } + + for (const auto o : c10::irange(OUTER_EXTENT)) { + for (const auto i : c10::irange(B_EXTENT)) { + b_v(o, i) = (float)(B_EXTENT - i); + d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i); + } + } + + float* a_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_EXTENT * A_EXTENT * sizeof(float))); + float* b_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_EXTENT * B_EXTENT * sizeof(float))); + float* c_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_EXTENT * A_EXTENT * sizeof(float))); + float* d_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_EXTENT * B_EXTENT * sizeof(float))); + C10_CUDA_CHECK(cudaMemcpy( + a_dev, + a_v.data(), + OUTER_EXTENT * A_EXTENT * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + b_dev, + b_v.data(), + OUTER_EXTENT * B_EXTENT * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + c_dev, + c_v.data(), + OUTER_EXTENT * A_EXTENT * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + d_dev, + d_v.data(), + OUTER_EXTENT * B_EXTENT * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(c_dev, d_dev, OUTER_EXTENT, A_EXTENT, B_EXTENT, a_dev, b_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK(cudaMemcpy( + c_v.data(), + c_dev, + OUTER_EXTENT * A_EXTENT * sizeof(float), + cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaMemcpy( + d_v.data(), + d_dev, + OUTER_EXTENT * B_EXTENT * sizeof(float), + cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(c_v, c_ref, 1e-5); + ExpectAllNear(d_v, d_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(a_dev)); + C10_CUDA_CHECK(cudaFree(b_dev)); + C10_CUDA_CHECK(cudaFree(c_dev)); + C10_CUDA_CHECK(cudaFree(d_dev)); +} + +// Tests the case where two loops are fused at a common parent loop, which is +// bound to the block dimension. Internally the inner loops have different +// extents but are bound to the same thread dimension. The smaller loop should +// be masked. +TEST(Cuda, MaskCompoundInnerLoop_CUDA) { + int OUTER_SIZE = 10; + int A_SIZE = 100; + int B_SIZE = 50; + BufHandle a_buf("a", {OUTER_SIZE, A_SIZE}, kFloat); + BufHandle b_buf("b", {OUTER_SIZE, B_SIZE}, kFloat); + BufHandle c_buf("c", {OUTER_SIZE, A_SIZE}, kFloat); + BufHandle d_buf("d", {OUTER_SIZE, B_SIZE}, kFloat); + + // Can't build this using Compute and transforms yet. + LoopOptions blockBound; + blockBound.set_gpu_block_index(0); + LoopOptions threadBound; + threadBound.set_gpu_thread_index(0); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + + StmtPtr stmt = For::make( + i, + 0, + OUTER_SIZE, + Block::make( + {For::make( + j, + 0, + A_SIZE, + c_buf.store({i, j}, ExprHandle(2) * a_buf.load(i, j)), + threadBound), + For::make( + k, + 0, + B_SIZE, + d_buf.store({i, k}, c_buf.load(i, k * 2) + b_buf.load(i, k)), + threadBound)}), + blockBound); + + stmt = FlattenIndexes(stmt); + stmt = IRSimplifier::simplify(stmt); + + CudaCodeGen cuda_cg(stmt, a_buf, b_buf, c_buf, d_buf); + + std::ostringstream oss; + oss << *cuda_cg.stmt(); + + // The write to D should be masked, but not the write to C. + const std::string& verification_pattern = + R"IR( +# CHECK-NOT: if ( +# CHECK: c[threadIdx.x + 100 * blockIdx.x] = +# CHECK: __syncthreads(); +# CHECK: if (threadIdx.x<50 +# CHECK: d[threadIdx.x + 50 * blockIdx.x] =)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto blockExtents = cuda_cg.gpu_block_extents(); + auto threadExtents = cuda_cg.gpu_thread_extents(); + ASSERT_TRUE(exprEquals(blockExtents[0], alloc(OUTER_SIZE))); + ASSERT_TRUE(exprEquals(threadExtents[0], alloc(A_SIZE))); + + PaddedBuffer a_v(OUTER_SIZE, A_SIZE); + PaddedBuffer b_v(OUTER_SIZE, B_SIZE); + PaddedBuffer c_v(OUTER_SIZE, A_SIZE); + PaddedBuffer d_v(OUTER_SIZE, B_SIZE); + + PaddedBuffer c_ref(OUTER_SIZE, A_SIZE); + PaddedBuffer d_ref(OUTER_SIZE, B_SIZE); + + for (const auto o : c10::irange(OUTER_SIZE)) { + for (const auto i : c10::irange(A_SIZE)) { + a_v(o, i) = (float)i; + c_ref(o, i) = (float)(i * 2); + } + for (const auto i : c10::irange(B_SIZE)) { + b_v(o, i) = (float)(B_SIZE - i); + d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i); + } + } + + float* a_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_SIZE * A_SIZE * sizeof(float))); + float* b_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_SIZE * B_SIZE * sizeof(float))); + float* c_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_SIZE * A_SIZE * sizeof(float))); + float* d_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_SIZE * B_SIZE * sizeof(float))); + C10_CUDA_CHECK(cudaMemcpy( + a_dev, + a_v.data(), + OUTER_SIZE * A_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + b_dev, + b_v.data(), + OUTER_SIZE * B_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + c_dev, + c_v.data(), + OUTER_SIZE * A_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + d_dev, + d_v.data(), + OUTER_SIZE * B_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(a_dev, b_dev, c_dev, d_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK(cudaMemcpy( + c_v.data(), + c_dev, + OUTER_SIZE * A_SIZE * sizeof(float), + cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaMemcpy( + d_v.data(), + d_dev, + OUTER_SIZE * B_SIZE * sizeof(float), + cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(c_v, c_ref, 1e-5); + ExpectAllNear(d_v, d_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(a_dev)); + C10_CUDA_CHECK(cudaFree(b_dev)); + C10_CUDA_CHECK(cudaFree(c_dev)); + C10_CUDA_CHECK(cudaFree(d_dev)); +} + +// Tests the case with two loops fused into a common parent, which is not bound +// to any block or thread dimension - however it's two inner loops are bound to +// the first thread dimensions. This should work just like the MaskThreadDim +// test where the bigger loop is unmasked but the smaller is masked. +TEST(Cuda, MaskInnerLoopOneBlock_CUDA) { + int OUTER_SIZE = 10; + int A_SIZE = 100; + int B_SIZE = 50; + BufHandle a_buf("a", {OUTER_SIZE, A_SIZE}, kFloat); + BufHandle b_buf("b", {OUTER_SIZE, B_SIZE}, kFloat); + BufHandle c_buf("c", {OUTER_SIZE, A_SIZE}, kFloat); + BufHandle d_buf("d", {OUTER_SIZE, B_SIZE}, kFloat); + + // Can't build this using Compute and transforms yet. + LoopOptions blockBound; + blockBound.set_gpu_block_index(0); + LoopOptions threadBound; + threadBound.set_gpu_thread_index(0); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + + StmtPtr stmt = For::make( + i, + 0, + OUTER_SIZE, + Block::make( + {For::make( + j, + 0, + A_SIZE, + c_buf.store({i, j}, ExprHandle(2) * a_buf.load(i, j)), + threadBound), + For::make( + k, + 0, + B_SIZE, + d_buf.store({i, k}, c_buf.load(i, k * 2) + b_buf.load(i, k)), + threadBound)})); + + stmt = FlattenIndexes(stmt); + stmt = IRSimplifier::simplify(stmt); + + CudaCodeGen cuda_cg(stmt, a_buf, b_buf, c_buf, d_buf); + + std::ostringstream oss; + oss << *cuda_cg.stmt(); + + // The other loop remains the D write is masked. + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i = 0; i < 10 +# CHECK-NOT: if ( +# CHECK: c[threadIdx.x + 100 * i] = +# CHECK: __syncthreads(); +# CHECK: if (threadIdx.x<50 +# CHECK: d[threadIdx.x + 50 * i] =)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto blockExtents = cuda_cg.gpu_block_extents(); + auto threadExtents = cuda_cg.gpu_thread_extents(); + ASSERT_TRUE(exprEquals(blockExtents[0], alloc(1))); + ASSERT_TRUE(exprEquals(threadExtents[0], alloc(A_SIZE))); + + PaddedBuffer a_v(OUTER_SIZE, A_SIZE); + PaddedBuffer b_v(OUTER_SIZE, B_SIZE); + PaddedBuffer c_v(OUTER_SIZE, A_SIZE); + PaddedBuffer d_v(OUTER_SIZE, B_SIZE); + + PaddedBuffer c_ref(OUTER_SIZE, A_SIZE); + PaddedBuffer d_ref(OUTER_SIZE, B_SIZE); + + for (const auto o : c10::irange(OUTER_SIZE)) { + for (const auto i : c10::irange(A_SIZE)) { + a_v(o, i) = (float)i; + c_ref(o, i) = (float)(i * 2); + } + for (const auto i : c10::irange(B_SIZE)) { + b_v(o, i) = (float)(B_SIZE - i); + d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i); + } + } + + float* a_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_SIZE * A_SIZE * sizeof(float))); + float* b_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_SIZE * B_SIZE * sizeof(float))); + float* c_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_SIZE * A_SIZE * sizeof(float))); + float* d_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_SIZE * B_SIZE * sizeof(float))); + C10_CUDA_CHECK(cudaMemcpy( + a_dev, + a_v.data(), + OUTER_SIZE * A_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + b_dev, + b_v.data(), + OUTER_SIZE * B_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + c_dev, + c_v.data(), + OUTER_SIZE * A_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + d_dev, + d_v.data(), + OUTER_SIZE * B_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(a_dev, b_dev, c_dev, d_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK(cudaMemcpy( + c_v.data(), + c_dev, + OUTER_SIZE * A_SIZE * sizeof(float), + cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaMemcpy( + d_v.data(), + d_dev, + OUTER_SIZE * B_SIZE * sizeof(float), + cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(c_v, c_ref, 1e-5); + ExpectAllNear(d_v, d_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(a_dev)); + C10_CUDA_CHECK(cudaFree(b_dev)); + C10_CUDA_CHECK(cudaFree(c_dev)); + C10_CUDA_CHECK(cudaFree(d_dev)); +} + +// Tests the case with two loop nests, each of which bound to the same block +// size, but with internal loops bound to different thread rank (ie x and y). In +// this case both bodies must be masked against the other dimension being > 0. +// Note: this is a bit degenerate no one would actually write this for perf. +TEST(Cuda, MaskMultiDimMultiAxis_CUDA) { + int OUTER_SIZE = 10; + int A_SIZE = 30; + int B_SIZE = 15; + BufHandle a_buf("a", {OUTER_SIZE, A_SIZE}, kFloat); + BufHandle b_buf("b", {OUTER_SIZE, B_SIZE}, kFloat); + Tensor c = Compute( + "C", {OUTER_SIZE, A_SIZE}, [&](const VarHandle& i, const VarHandle& j) { + return ExprHandle(2) * a_buf.load(i, j); + }); + Tensor d = Compute( + "D", {OUTER_SIZE, B_SIZE}, [&](const VarHandle& i, const VarHandle& j) { + return c.load(i, j * 2) + b_buf.load(i, j); + }); + + LoopNest l({c, d}); + std::vector loops = l.getLoopStmtsFor(c); + loops[0]->set_gpu_block_index(0); + loops[1]->set_gpu_thread_index(0); + loops = l.getLoopStmtsFor(d); + loops[0]->set_gpu_block_index(0); + loops[1]->set_gpu_thread_index(1); + + l.prepareForCodegen(); + StmtPtr stmt = l.root_stmt(); + CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); + + std::ostringstream oss; + oss << *cuda_cg.stmt(); + + // Both stores masked against the other thread dim < 1. + const std::string& verification_pattern = + R"IR( +# CHECK: if (threadIdx.y<1 +# CHECK: C[threadIdx.x + 30 * blockIdx.x] = +# CHECK: __syncthreads(); +# CHECK: if (threadIdx.x<1 +# CHECK: D[threadIdx.y + 15 * blockIdx.x] =)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto blockExtents = cuda_cg.gpu_block_extents(); + auto threadExtents = cuda_cg.gpu_thread_extents(); + ASSERT_TRUE(exprEquals(blockExtents[0], alloc(OUTER_SIZE))); + ASSERT_TRUE(exprEquals(threadExtents[0], alloc(A_SIZE))); + + PaddedBuffer a_v(OUTER_SIZE, A_SIZE); + PaddedBuffer b_v(OUTER_SIZE, B_SIZE); + PaddedBuffer c_v(OUTER_SIZE, A_SIZE); + PaddedBuffer d_v(OUTER_SIZE, B_SIZE); + + PaddedBuffer c_ref(OUTER_SIZE, A_SIZE); + PaddedBuffer d_ref(OUTER_SIZE, B_SIZE); + + for (const auto o : c10::irange(OUTER_SIZE)) { + for (const auto i : c10::irange(A_SIZE)) { + a_v(o, i) = (float)i; + c_ref(o, i) = (float)(i * 2); + } + } + + for (const auto o : c10::irange(OUTER_SIZE)) { + for (const auto i : c10::irange(B_SIZE)) { + b_v(o, i) = (float)(B_SIZE - i); + d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i); + } + } + + float* a_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_SIZE * A_SIZE * sizeof(float))); + float* b_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_SIZE * B_SIZE * sizeof(float))); + float* c_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_SIZE * A_SIZE * sizeof(float))); + float* d_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_SIZE * B_SIZE * sizeof(float))); + C10_CUDA_CHECK(cudaMemcpy( + a_dev, + a_v.data(), + OUTER_SIZE * A_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + b_dev, + b_v.data(), + OUTER_SIZE * B_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + c_dev, + c_v.data(), + OUTER_SIZE * A_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + d_dev, + d_v.data(), + OUTER_SIZE * B_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(c_dev, d_dev, a_dev, b_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK(cudaMemcpy( + c_v.data(), + c_dev, + OUTER_SIZE * A_SIZE * sizeof(float), + cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaMemcpy( + d_v.data(), + d_dev, + OUTER_SIZE * B_SIZE * sizeof(float), + cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(c_v, c_ref, 1e-5); + ExpectAllNear(d_v, d_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(a_dev)); + C10_CUDA_CHECK(cudaFree(b_dev)); + C10_CUDA_CHECK(cudaFree(c_dev)); + C10_CUDA_CHECK(cudaFree(d_dev)); +} + +// Tests the case with two loop nests, each bound to both Block and Thread but +// the second loop is smaller in both cases - the second store must be masked +// for both the block and thread dimension. +TEST(Cuda, MaskMultiDimMultiLevel_CUDA) { + int OUTER_A_SIZE = 10; + int OUTER_B_SIZE = 5; + int A_SIZE = 30; + int B_SIZE = 15; + BufHandle a_buf("a", {OUTER_A_SIZE, A_SIZE}, kFloat); + BufHandle b_buf("b", {OUTER_B_SIZE, B_SIZE}, kFloat); + Tensor c = Compute( + "C", {OUTER_A_SIZE, A_SIZE}, [&](const VarHandle& i, const VarHandle& j) { + return ExprHandle(2) * a_buf.load(i, j); + }); + Tensor d = Compute( + "D", {OUTER_B_SIZE, B_SIZE}, [&](const VarHandle& i, const VarHandle& j) { + return c.load(i, j * 2) + b_buf.load(i, j); + }); + + LoopNest l({c, d}); + std::vector loops = l.getLoopStmtsFor(c); + loops[0]->set_gpu_block_index(0); + loops[1]->set_gpu_thread_index(0); + loops = l.getLoopStmtsFor(d); + loops[0]->set_gpu_block_index(0); + loops[1]->set_gpu_thread_index(0); + + l.prepareForCodegen(); + StmtPtr stmt = l.root_stmt(); + CudaCodeGen cuda_cg(stmt, c, d, a_buf, b_buf); + + std::ostringstream oss; + oss << *cuda_cg.stmt(); + + // The write to D should be masked twice, but not the write to C. + const std::string& verification_pattern = + R"IR( +# CHECK-NOT: if ( +# CHECK: C[threadIdx.x + 30 * blockIdx.x] = +# CHECK: __syncthreads(); +# CHECK: if (blockIdx.x<5 +# CHECK: if (threadIdx.x<15 +# CHECK: D[threadIdx.x + 15 * blockIdx.x] =)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto blockExtents = cuda_cg.gpu_block_extents(); + auto threadExtents = cuda_cg.gpu_thread_extents(); + ASSERT_TRUE(exprEquals(blockExtents[0], alloc(OUTER_A_SIZE))); + ASSERT_TRUE(exprEquals(threadExtents[0], alloc(A_SIZE))); + + PaddedBuffer a_v(OUTER_A_SIZE, A_SIZE); + PaddedBuffer b_v(OUTER_B_SIZE, B_SIZE); + PaddedBuffer c_v(OUTER_A_SIZE, A_SIZE); + PaddedBuffer d_v(OUTER_B_SIZE, B_SIZE); + + PaddedBuffer c_ref(OUTER_A_SIZE, A_SIZE); + PaddedBuffer d_ref(OUTER_B_SIZE, B_SIZE); + + for (const auto o : c10::irange(OUTER_A_SIZE)) { + for (const auto i : c10::irange(A_SIZE)) { + a_v(o, i) = (float)i; + c_ref(o, i) = (float)(i * 2); + } + } + + for (const auto o : c10::irange(OUTER_B_SIZE)) { + for (const auto i : c10::irange(B_SIZE)) { + b_v(o, i) = (float)(B_SIZE - i); + d_ref(o, i) = c_ref(o, i * 2) + b_v(o, i); + } + } + + float* a_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&a_dev, OUTER_A_SIZE * A_SIZE * sizeof(float))); + float* b_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&b_dev, OUTER_B_SIZE * B_SIZE * sizeof(float))); + float* c_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&c_dev, OUTER_A_SIZE * A_SIZE * sizeof(float))); + float* d_dev = nullptr; + C10_CUDA_CHECK(cudaMalloc(&d_dev, OUTER_B_SIZE * B_SIZE * sizeof(float))); + C10_CUDA_CHECK(cudaMemcpy( + a_dev, + a_v.data(), + OUTER_A_SIZE * A_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + b_dev, + b_v.data(), + OUTER_B_SIZE * B_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + c_dev, + c_v.data(), + OUTER_A_SIZE * A_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaMemcpy( + d_dev, + d_v.data(), + OUTER_B_SIZE * B_SIZE * sizeof(float), + cudaMemcpyHostToDevice)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + cuda_cg(c_dev, d_dev, a_dev, b_dev); + + C10_CUDA_CHECK(cudaDeviceSynchronize()); + C10_CUDA_CHECK(cudaMemcpy( + c_v.data(), + c_dev, + OUTER_A_SIZE * A_SIZE * sizeof(float), + cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaMemcpy( + d_v.data(), + d_dev, + OUTER_B_SIZE * B_SIZE * sizeof(float), + cudaMemcpyDeviceToHost)); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + ExpectAllNear(c_v, c_ref, 1e-5); + ExpectAllNear(d_v, d_ref, 1e-5); + + C10_CUDA_CHECK(cudaFree(a_dev)); + C10_CUDA_CHECK(cudaFree(b_dev)); + C10_CUDA_CHECK(cudaFree(c_dev)); + C10_CUDA_CHECK(cudaFree(d_dev)); +} + +} // namespace jit +} // namespace torch + +#endif diff --git a/test/cpp/tensorexpr/test_dynamic_shapes.cpp b/test/cpp/tensorexpr/test_dynamic_shapes.cpp new file mode 100644 index 0000000000000..07b9872fb8325 --- /dev/null +++ b/test/cpp/tensorexpr/test_dynamic_shapes.cpp @@ -0,0 +1,701 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +using namespace torch::indexing; +using namespace torch::jit::tensorexpr; + +TEST(DynamicShapes, SimpleGraph) { +#ifdef TORCH_ENABLE_LLVM + std::shared_ptr graph = std::make_shared(); + const auto graph_string = R"IR( + graph(%x : Tensor, + %SS_2 : int, + %SS_3 : int): + %3 : Tensor = aten::tanh(%x) + %4 : Tensor = aten::erf(%3) + return (%4))IR"; + torch::jit::parseIR(graph_string, graph.get()); + + auto x_inp = graph->inputs()[0]; + auto x_type = TensorType::create(at::rand({10, 5})); + std::vector x_sym_dims( + {c10::ShapeSymbol::newSymbol(), c10::ShapeSymbol::newSymbol()}); + auto x_sym_type = x_type->withSymbolicShapes(x_sym_dims); + graph->inputs().at(0)->setType(x_sym_type); + for (const auto n : graph->nodes()) { + n->output()->setType(x_sym_type); + } + + // Graph with symbolic shapes: + // + // graph(%x : Float(SS(-2), SS(-3)), + // %SS_2 : int, + // %SS_3 : int): + // %3 : Float(SS(-2), SS(-3)) = aten::tanh(%x) + // %4 : Float(SS(-2), SS(-3)) = aten::erf(%3) + // return (%4) + + std::vector input_desc = { + torch::jit::StrideInput::TENSOR_CONT}; + std::unordered_map< + const torch::jit::Value*, + std::vector> + symbolic_strides; + symbolic_strides[x_inp] = input_desc; + symbolic_strides[graph->outputs().at(0)] = input_desc; + std::vector symbolic_shape_inputs = c10::fmap( + x_sym_dims, + [](const c10::ShapeSymbol& shapeSym) { return shapeSym.value(); }); + + TensorExprKernel kernel( + graph, {}, symbolic_shape_inputs, false, symbolic_strides); + // Run with the same static dims as the one we initialized the graph with. + { + auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto ref = at::erf(at::tanh(a)); + + std::vector stack = fmap(std::vector({a})); + stack.push_back(10); + stack.push_back(5); + kernel.run(stack); + + auto o = stack[0].toTensor(); + ASSERT_TRUE(at::allclose(o, ref)); + } + + // Run with inputs having different dims. + { + auto a = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto ref = at::erf(at::tanh(a)); + + std::vector stack = fmap(std::vector({a})); + stack.push_back(50); + stack.push_back(100); + kernel.run(stack); + + auto o = stack[0].toTensor(); + ASSERT_TRUE(at::allclose(o, ref)); + } +#endif +} + +TEST(DynamicShapes, GraphWith2InputsSameDims) { +#ifdef TORCH_ENABLE_LLVM + // The two inputs in this graph must have the same dims. + std::shared_ptr graph = std::make_shared(); + const auto graph_string = R"IR( + graph(%x : Tensor, + %y : Tensor, + %SS_2 : int, + %SS_3 : int): + %3 : Tensor = aten::tanh(%x) + %4 : Tensor = aten::erf(%3) + %5 : Tensor = aten::mul(%4, %y) + return (%5))IR"; + torch::jit::parseIR(graph_string, graph.get()); + + auto x_inp = graph->inputs()[0]; + auto y_inp = graph->inputs()[1]; + auto x_type = TensorType::create(at::rand({10, 5})); + std::vector x_sym_dims( + {c10::ShapeSymbol::newSymbol(), c10::ShapeSymbol::newSymbol()}); + auto x_sym_type = x_type->withSymbolicShapes(x_sym_dims); + graph->inputs().at(0)->setType(x_sym_type); + graph->inputs().at(1)->setType(x_sym_type); + for (const auto n : graph->nodes()) { + n->output()->setType(x_sym_type); + } + + // Graph with symbolic shapes: + // + // graph(%x : Float(SS(-4), SS(-5)), + // %y : Float(SS(-4), SS(-5)), + // %SS_2 : int, + // %SS_3 : int): + // %4 : Float(SS(-4), SS(-5)) = aten::tanh(%x) + // %5 : Float(SS(-4), SS(-5)) = aten::erf(%4) + // %6 : Float(SS(-4), SS(-5)) = aten::mul(%5, %y) + // return (%6) + + std::vector symbolic_shape_inputs = c10::fmap( + x_sym_dims, + [](const c10::ShapeSymbol& shapeSym) { return shapeSym.value(); }); + + std::vector input_desc = { + torch::jit::StrideInput::TENSOR_CONT}; + std::unordered_map< + const torch::jit::Value*, + std::vector> + symbolic_strides; + symbolic_strides[x_inp] = input_desc; + symbolic_strides[y_inp] = input_desc; + symbolic_strides[graph->outputs().at(0)] = input_desc; + + TensorExprKernel kernel( + graph, {}, symbolic_shape_inputs, false, symbolic_strides); + + // Run with the same static dims as the one we initialized the graph with. + { + auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto b = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto ref = at::mul(at::erf(at::tanh(a)), b); + + std::vector stack = fmap(std::vector({a, b})); + stack.push_back(10); + stack.push_back(5); + kernel.run(stack); + + auto o = stack[0].toTensor(); + ASSERT_TRUE(at::allclose(o, ref)); + } + + // Run with inputs having different dims. + { + auto a = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto b = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto ref = at::mul(at::erf(at::tanh(a)), b); + + std::vector stack = fmap(std::vector({a, b})); + stack.push_back(50); + stack.push_back(100); + kernel.run(stack); + + auto o = stack[0].toTensor(); + ASSERT_TRUE(at::allclose(o, ref)); + } +#endif +} + +TEST(DynamicShapes, GraphWith2InputsAndBroadcast) { +#ifdef TORCH_ENABLE_LLVM + // The second input to the graph has a dim of size 1 which should be + // broadcasted in the at::mul op. + std::shared_ptr graph = std::make_shared(); + const auto graph_string = R"IR( + graph(%x : Float(10, 5, requires_grad=0, device=cpu), + %y : Float(1, 5, requires_grad=0, device=cpu), + %SS_2 : int, + %SS_3 : int): + %3 : Tensor = aten::tanh(%x) + %4 : Tensor = aten::erf(%3) + %5 : Tensor = aten::mul(%4, %y) + return (%5))IR"; + torch::jit::parseIR(graph_string, graph.get()); + + auto x_inp = graph->inputs()[0]; + auto y_inp = graph->inputs()[1]; + auto x_type = TensorType::create(at::rand({10, 5})); + auto y_type = TensorType::create(at::rand({1, 5})); + auto x_dim0_sym = c10::ShapeSymbol::newSymbol(); + auto x_dim1_sym = c10::ShapeSymbol::newSymbol(); + auto x_sym_type = x_type->withSymbolicShapes( + std::vector({x_dim0_sym, x_dim1_sym})); + auto y_sym_type = y_type->withSymbolicShapes(std::vector( + {c10::ShapeSymbol::fromStaticSize(1), x_dim1_sym})); + graph->inputs().at(0)->setType(x_sym_type); + graph->inputs().at(1)->setType(y_sym_type); + for (const auto n : graph->nodes()) { + n->output()->setType(x_sym_type); + } + + // Graph with symbolic shapes: + // + // graph(%x : Float(SS(-6), SS(-7)), + // %y : Float(1, SS(-7)), + // %SS_2 : int, + // %SS_3 : int): + // %4 : Float(SS(-6), SS(-7)) = aten::tanh(%x) + // %5 : Float(SS(-6), SS(-7)) = aten::erf(%4) + // %6 : Float(SS(-6), SS(-7)) = aten::mul(%5, %y) + // return (%6) + + std::vector symbolic_shape_inputs( + {x_dim0_sym.value(), x_dim1_sym.value()}); + + std::vector input_desc = { + torch::jit::StrideInput::TENSOR_CONT}; + std::unordered_map< + const torch::jit::Value*, + std::vector> + symbolic_strides; + symbolic_strides[x_inp] = input_desc; + symbolic_strides[y_inp] = input_desc; + symbolic_strides[graph->outputs().at(0)] = input_desc; + + TensorExprKernel kernel( + graph, {}, symbolic_shape_inputs, false, symbolic_strides); + + // Run with the same static dims as the one we initialized the graph with. + { + auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto b = at::rand({1, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto ref = at::mul(at::erf(at::tanh(a)), b); + + std::vector stack = fmap(std::vector({a, b})); + stack.push_back(10); + stack.push_back(5); + kernel.run(stack); + + auto o = stack[0].toTensor(); + ASSERT_TRUE(at::allclose(o, ref)); + } + + // Run with inputs having different dims. + { + auto a = at::rand({50, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto b = at::rand({1, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto ref = at::mul(at::erf(at::tanh(a)), b); + + std::vector stack = fmap(std::vector({a, b})); + stack.push_back(50); + stack.push_back(100); + kernel.run(stack); + + auto o = stack[0].toTensor(); + ASSERT_TRUE(at::allclose(o, ref)); + } +#endif +} + +TEST(DynamicShapes, GraphWithPartiallySymbolicOutput) { +#ifdef TORCH_ENABLE_LLVM + // The second input to the graph has a dim of size 1 which should be + // broadcasted in the at::mul op. + std::shared_ptr graph = std::make_shared(); + const auto graph_string = R"IR( + graph(%x : Float(1, 5, requires_grad=0, device=cpu), + %y : Float(1, 5, requires_grad=0, device=cpu), + %SS_2 : int): + %4 : Tensor = aten::tanh(%x) + %5 : Tensor = aten::mul(%4, %y) + return (%5))IR"; + torch::jit::parseIR(graph_string, graph.get()); + + auto x_inp = graph->inputs()[0]; + auto y_inp = graph->inputs()[1]; + auto x_type = TensorType::create(at::rand({1, 5})); + auto x_dim1_sym = c10::ShapeSymbol::newSymbol(); + auto x_sym_type = x_type->withSymbolicShapes(std::vector( + {c10::ShapeSymbol::fromStaticSize(1), x_dim1_sym})); + graph->inputs().at(0)->setType(x_sym_type); + graph->inputs().at(1)->setType(x_sym_type); + for (const auto n : graph->nodes()) { + n->output()->setType(x_sym_type); + } + + // Graph with symbolic shapes: + // + // graph(%x : Float(1, SS(-2)), + // %y : Float(1, SS(-2)), + // %SS_2 : int): + // %3 : Float(1, SS(-2)) = aten::tanh(%x) + // %4 : Float(1, SS(-2)) = aten::mul(%3, %y) + // return (%4) + + std::vector symbolic_shape_inputs({x_dim1_sym.value()}); + + std::vector input_desc = { + torch::jit::StrideInput::TENSOR_CONT}; + std::unordered_map< + const torch::jit::Value*, + std::vector> + symbolic_strides; + symbolic_strides[x_inp] = input_desc; + symbolic_strides[y_inp] = input_desc; + symbolic_strides[graph->outputs().at(0)] = input_desc; + + TensorExprKernel kernel( + graph, {}, symbolic_shape_inputs, false, symbolic_strides); + + // Run with the same static dims as the one we initialized the graph with. + { + auto a = at::rand({1, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto b = at::rand({1, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto ref = at::mul(at::tanh(a), b); + + std::vector stack = fmap(std::vector({a, b})); + stack.push_back(5); + kernel.run(stack); + + auto o = stack[0].toTensor(); + ASSERT_TRUE(at::allclose(o, ref)); + } + + // Run with inputs having different dims. + { + auto a = at::rand({1, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto b = at::rand({1, 100}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto ref = at::mul(at::tanh(a), b); + + std::vector stack = fmap(std::vector({a, b})); + stack.push_back(100); + kernel.run(stack); + + auto o = stack[0].toTensor(); + ASSERT_TRUE(at::allclose(o, ref)); + } +#endif +} + +TEST(DynamicShapes, GraphWithSymbolicStrides) { +#ifdef TORCH_ENABLE_LLVM + std::shared_ptr graph = std::make_shared(); + const auto graph_string = R"IR( + graph(%0 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu), + %1 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu), + %SS_3 : int, + %SS_2 : int): + %15 : int = prim::Constant[value=1]() + %21 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu) = aten::add(%0, %1, %15) + %22 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu) = aten::mul(%21, %0) + return (%22))IR"; + parseIR(graph_string, &*graph); + + std::vector input_desc = { + torch::jit::StrideInput::S_AS_ARG, torch::jit::StrideInput::S_ONE}; + std::vector output_desc = { + torch::jit::StrideInput::TENSOR_CONT}; + std::unordered_map< + const torch::jit::Value*, + std::vector> + symbolic_strides; + symbolic_strides[graph->inputs().at(0)] = input_desc; + symbolic_strides[graph->inputs().at(1)] = input_desc; + symbolic_strides[graph->outputs().at(0)] = output_desc; + std::vector symbolic_shape_inputs = {-3, -2}; + TensorExprKernel k(graph, {}, symbolic_shape_inputs, false, symbolic_strides); + + { + auto x0 = at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto x1 = at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto ref = at::mul(at::add(x0, x1, 1), x0); + + std::vector inputs = {x0, x1}; + std::vector stack = at::fmap(inputs); + stack.push_back(32); + stack.push_back(10); + k.run(stack); + + auto o = stack[0].toTensor(); + ASSERT_TRUE(at::allclose(o, ref)); + } + + { + auto x0 = at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto x1 = at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto out = + at::rand({10, 32}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto ref = at::mul(at::add(x0, x1, 1), x0); + + std::vector inputs = {out, x0, x1}; + std::vector stack = at::fmap(inputs); + stack.push_back(32); + stack.push_back(10); + k.runWithAllocatedOutputs(stack); + + ASSERT_TRUE(at::allclose(out, ref)); + } +#endif +} + +TEST(DynamicShapes, GraphWithCatAndBroadcast) { +#ifdef TORCH_ENABLE_LLVM + std::shared_ptr graph = std::make_shared(); + const auto graph_string = R"IR( + graph(%x : Float(10, 5, requires_grad=0, device=cpu), + %y : Float(4, 5, requires_grad=0, device=cpu), + %z : Float(1, 1, requires_grad=0, device=cpu), + %SS_2 : int, + %SS_3 : int, + %SS_4 : int, + %SS_5 : int): + %11 : int = prim::Constant[value=0]() + %3 : Tensor = aten::tanh(%x) + %out1 : Tensor = aten::erf(%3) + %out2 : Tensor = aten::relu(%y) + %10 : Tensor[] = prim::ListConstruct(%out1, %out2) + %25 : Tensor = aten::cat(%10, %11) + %28 : Tensor = aten::hardswish(%25) + %29 : Tensor = aten::mul(%28, %z) + return (%29))IR"; + torch::jit::parseIR(graph_string, graph.get()); + + auto x_inp = graph->inputs()[0]; + auto y_inp = graph->inputs()[1]; + auto z_inp = graph->inputs()[2]; + auto x_type = TensorType::create(at::rand({10, 5})); + auto y_type = TensorType::create(at::rand({4, 5})); + auto z_type = TensorType::create(at::rand({1, 1})); + auto x_dim0_sym = c10::ShapeSymbol::newSymbol(); + auto x_dim1_sym = c10::ShapeSymbol::newSymbol(); + auto x_sym_type = x_type->withSymbolicShapes( + std::vector({x_dim0_sym, x_dim1_sym})); + auto y_dim0_sym = c10::ShapeSymbol::newSymbol(); + auto y_sym_type = y_type->withSymbolicShapes( + std::vector({y_dim0_sym, x_dim1_sym})); + graph->inputs().at(0)->setType(x_sym_type); + graph->inputs().at(1)->setType(y_sym_type); + auto cat_dim0_sym = c10::ShapeSymbol::newSymbol(); + auto cat_out_type = x_type->withSymbolicShapes( + std::vector({cat_dim0_sym, x_dim1_sym})); + auto nodeIt = graph->nodes().begin(); + ++nodeIt; + nodeIt->output()->setType(x_sym_type); // aten::tanh + ++nodeIt; + nodeIt->output()->setType(x_sym_type); // aten::erf + ++nodeIt; + nodeIt->output()->setType(y_sym_type); // aten::relu + ++nodeIt; + ++nodeIt; + nodeIt->output()->setType(cat_out_type); // aten::cat + ++nodeIt; + nodeIt->output()->setType(cat_out_type); // aten::hardswish + ++nodeIt; + nodeIt->output()->setType(cat_out_type); // aten::mul + + // Graph with symbolic shapes: + // + // graph(%x : Float(SS(-2), SS(-3)), + // %y : Float(SS(-4), SS(-3)), + // %z : Float(1, 1), + // %SS_2 : int, + // %SS_3 : int, + // %SS_4 : int, + // %SS_5 : int): + // %7 : int = prim::Constant[value=0]() + // %8 : Float(SS(-2), SS(-3)) = aten::tanh(%x) + // %9 : Float(SS(-2), SS(-3)) = aten::erf(%8) + // %10 : Float(SS(-4), SS(-3)) = aten::relu(%y) + // %11 : Tensor[] = prim::ListConstruct(%9, %10) + // %12 : Float(SS(-5), SS(-3)) = aten::cat(%11, %7) + // %13 : Float(SS(-5), SS(-3)) = aten::hardswish(%12) + // %14 : Float(SS(-5), SS(-3)) = aten::mul(%13, %z) + // return (%14) + + std::vector symbolic_shape_inputs( + {x_dim0_sym.value(), + x_dim1_sym.value(), + y_dim0_sym.value(), + cat_dim0_sym.value()}); + + std::vector input_desc = { + torch::jit::StrideInput::TENSOR_CONT}; + std::unordered_map< + const torch::jit::Value*, + std::vector> + symbolic_strides; + symbolic_strides[x_inp] = input_desc; + symbolic_strides[y_inp] = input_desc; + symbolic_strides[z_inp] = input_desc; + symbolic_strides[graph->outputs().at(0)] = input_desc; + + TensorExprKernel kernel( + graph, {}, symbolic_shape_inputs, false, symbolic_strides); + + auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto b = at::rand({4, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto c = at::rand({1, 1}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto ref = at::mul( + at::hardswish(at::cat({at::erf(at::tanh(a)), at::relu(b)}, 0)), c); + + std::vector stack = fmap(std::vector({a, b, c})); + stack.push_back(10); + stack.push_back(5); + stack.push_back(4); + stack.push_back(14); + kernel.run(stack); + + auto o = stack[0].toTensor(); + ASSERT_TRUE(at::allclose(o, ref)); +#endif +} + +TEST(DynamicShapes, GraphFromModel) { +#ifdef TORCH_ENABLE_LLVM + std::shared_ptr graph = std::make_shared(); + const auto graph_string = R"IR( + graph(%0 : Float(SS(-2), SS(-3), requires_grad=0, device=cpu), + %1 : Float(SS(-2), SS(-4), requires_grad=0, device=cpu), + %2 : Float(SS(-2), SS(-5), requires_grad=0, device=cpu), + %input.4 : Long(SS(-2), SS(-6), requires_grad=0, device=cpu), + %4 : Float(SS(-7), requires_grad=0, device=cpu), + %5 : Float(SS(-7), requires_grad=0, device=cpu), + %SS_10 : int, + %SS_9 : int, + %SS_8 : int, + %SS_7 : int, + %SS_6 : int, + %SS_5 : int, + %SS_4 : int, + %SS_3 : int, + %SS_2 : int): + %15 : int = prim::Constant[value=1]() + %16 : bool = prim::Constant[value=0]() + %17 : int = prim::Constant[value=6]() + %18 : Float(SS(-2), SS(-6), strides=[139, 1], requires_grad=0, device=cpu) = aten::to(%input.4, %17, %16, %16) + %19 : Tensor[] = prim::ListConstruct(%0, %1, %18, %2) + %20 : Float(SS(-2), SS(-8), strides=[261, 1], requires_grad=0, device=cpu) = aten::cat(%19, %15) + %21 : Float(SS(-2), SS(-9), strides=[261, 1], requires_grad=0, device=cpu) = aten::add(%20, %5, %15) + %22 : Float(SS(-2), SS(-10), requires_grad=0, device=cpu) = aten::mul(%21, %4) + return (%22))IR"; + parseIR(graph_string, &*graph); + + std::vector input_desc = { + torch::jit::StrideInput::TENSOR_CONT}; + std::unordered_map< + const torch::jit::Value*, + std::vector> + symbolic_strides; + symbolic_strides[graph->inputs().at(0)] = input_desc; + symbolic_strides[graph->inputs().at(1)] = input_desc; + symbolic_strides[graph->inputs().at(2)] = input_desc; + symbolic_strides[graph->inputs().at(3)] = input_desc; + symbolic_strides[graph->inputs().at(4)] = input_desc; + symbolic_strides[graph->inputs().at(5)] = input_desc; + symbolic_strides[graph->outputs().at(0)] = input_desc; + std::vector symbolic_shape_inputs = { + -10, -9, -8, -7, -6, -5, -4, -3, -2}; + TensorExprKernel k(graph, {}, symbolic_shape_inputs, false, symbolic_strides); + + int64_t i2 = 10; + int64_t i3 = 32; + int64_t i4 = 19; + int64_t i5 = 71; + int64_t i6 = 139; + int64_t i7 = 261; + int64_t i8 = 261; + int64_t i9 = 261; + int64_t i10 = 261; + auto x0 = at::rand({i2, i3}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto x1 = at::rand({i2, i4}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto x2 = at::rand({i2, i5}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto x3 = at::ones({i2, i6}, at::TensorOptions(at::kCPU).dtype(at::kLong)); + auto x4 = at::rand({i7}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto x5 = at::rand({i8}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + auto ref = at::mul(at::add(at::cat({x0, x1, x3, x2}, 1), x5), x4); + + { + std::vector inputs = {x0, x1, x2, x3, x4, x5}; + std::vector stack = at::fmap(inputs); + stack.emplace_back(i10); + stack.emplace_back(i9); + stack.emplace_back(i8); + stack.emplace_back(i7); + stack.emplace_back(i6); + stack.emplace_back(i5); + stack.emplace_back(i4); + stack.emplace_back(i3); + stack.emplace_back(i2); + k.run(stack); + + auto o = stack[0].toTensor(); + ASSERT_TRUE(at::allclose(o, ref)); + } + + { + auto out = + at::rand({i2, i10}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); + std::vector inputs = {out, x0, x1, x2, x3, x4, x5}; + std::vector stack = at::fmap(inputs); + stack.emplace_back(i10); + stack.emplace_back(i9); + stack.emplace_back(i8); + stack.emplace_back(i7); + stack.emplace_back(i6); + stack.emplace_back(i5); + stack.emplace_back(i4); + stack.emplace_back(i3); + stack.emplace_back(i2); + k.runWithAllocatedOutputs(stack); + + ASSERT_TRUE(at::allclose(out, ref)); + } +#endif +} + +TEST(DynamicShapes, MultiThreadedExecution) { +#ifdef TORCH_ENABLE_LLVM + const auto graph_template = R"IR( + graph(%x : Float(SS(-2), SS(-3), requires_grad=0, device=${device}), + %y : Float(SS(-2), SS(-3), requires_grad=0, device=${device}), + %SS_2 : int, + %SS_3 : int): + %3 : Float(SS(-2), SS(-3), requires_grad=0, device=${device}) = aten::tanh(%x) + %4 : Float(SS(-2), SS(-3), requires_grad=0, device=${device}) = aten::erf(%3) + %5 : Float(SS(-2), SS(-3), requires_grad=0, device=${device}) = aten::mul(%4, %y) + return (%5))IR"; + for (bool use_cuda : {false, true}) { + if (!torch::cuda::is_available() && use_cuda) { + continue; + } + auto device = use_cuda ? at::kCUDA : at::kCPU; + at::jit::TemplateEnv env; + env.s("device", use_cuda ? "cuda:0" : "cpu"); + const auto graph_string = format(graph_template, env); + std::shared_ptr graph = std::make_shared(); + torch::jit::parseIR(graph_string, graph.get()); + + std::vector symbolic_shape_inputs = {-2, -3}; + + std::vector input_desc = { + torch::jit::StrideInput::TENSOR_CONT}; + std::unordered_map< + const torch::jit::Value*, + std::vector> + symbolic_strides; + symbolic_strides[graph->inputs().at(0)] = input_desc; + symbolic_strides[graph->inputs().at(1)] = input_desc; + symbolic_strides[graph->outputs().at(0)] = input_desc; + + TensorExprKernel kernel( + graph, {}, symbolic_shape_inputs, false, symbolic_strides); + + auto run_kernel = [&](int dim1, int dim2) { + auto a = + at::rand({dim1, dim2}, at::TensorOptions(device).dtype(at::kFloat)); + auto b = + at::rand({dim1, dim2}, at::TensorOptions(device).dtype(at::kFloat)); + + auto ref = at::mul(at::erf(at::tanh(a)), b); + + std::vector stack = fmap(std::vector({a, b})); + stack.emplace_back(dim1); + stack.emplace_back(dim2); + kernel.run(stack); + + auto o = stack[0].toTensor(); + ASSERT_TRUE(at::allclose(o, ref)); + }; + + // Run the kernel in parallel to ensure that the run() method calls in + // TensorExprKernel are not changing any state. + constexpr size_t kNumThreads = 4; + std::vector threads; + for (size_t id = 0; id < kNumThreads; ++id) { + threads.emplace_back(run_kernel, id + 5, id + 20); + } + for (auto& t : threads) { + t.join(); + } + } +#endif +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_expr.cpp b/test/cpp/tensorexpr/test_expr.cpp new file mode 100644 index 0000000000000..eb2d6296b2299 --- /dev/null +++ b/test/cpp/tensorexpr/test_expr.cpp @@ -0,0 +1,836 @@ +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { +using namespace torch::jit::tensorexpr; + +using SimpleIRExprEval = ExprEval; + +TEST(Expr, BasicValueTest) { + ExprHandle a = IntImm::make(2), b = IntImm::make(3); + ExprHandle c = Add::make(a, b); + SimpleIRExprEval eval(c); + ASSERT_EQ(eval.value(), 5); +} + +TEST(Expr, BasicValueTest02) { + ExprHandle a(2.0f); + ExprHandle b(3.0f); + ExprHandle c(4.0f); + ExprHandle d(5.0f); + ExprHandle f = (a + b) - (c + d); + SimpleIRExprEval eval(f); + ASSERT_EQ(eval.value(), -4.0f); +} + +TEST(Expr, IsChannelsLastContiguous) { + std::vector vars = { + VarHandle("var1", kLong), + VarHandle("var2", kLong), + VarHandle("var3", kLong), + VarHandle("var4", kLong), + VarHandle("var5", kLong)}; + + // { + // key: ndims, + // value: [ + // ... + // [dim_2, dim_1, ..., dim_n] + // ] + // } + using shapGenInfo = std::unordered_map>>; + + // { + // size: [ExprHandle_1, ExprHandle_2, ..., ExprHandle_n], + // strides: [ + // ... + // [ExprHandle_x, ExprHandle_y, ..., ExprHandle_z] + // ] + // } + using shapeInfo = + std::pair, std::vector>>; + + std::vector dims = {3, 4, 5}; + + std::unordered_map> dims_expr_vec_conf = { + {3, std::vector(vars.begin(), vars.begin() + 2)}, + {4, std::vector(vars.begin(), vars.begin() + 3)}, + {5, std::vector(vars.begin(), vars.begin() + 4)}, + }; + + shapGenInfo channels_last_cont_shape_conf = { + {3, {{1, 2, 0}}}, {4, {{1, 3, 2, 0}}}, {5, {{1, 4, 3, 2, 0}}}}; + shapGenInfo channels_last_non_cont_shape_conf = { + {3, {{2, 1, 0}, {1, 0, 2}}}, + {4, {{3, 1, 2, 0}, {1, 2, 3, 0}, {1, 0, 2, 3}}}, + {5, {{4, 3, 2, 1, 0}, {1, 3, 2, 4, 0}, {1, 4, 3, 2, 0}}}}; + + shapGenInfo cont_shape_conf = { + {3, {{0, 1, 2}}}, {4, {{0, 1, 2, 3}}}, {5, {{0, 1, 2, 3, 4}}}}; + + auto shape_gen_fn = [dims_expr_vec_conf]( + int ndims, shapGenInfo shape_gen_info) -> shapeInfo { + auto dims_expr_vec = dims_expr_vec_conf.at(ndims); + std::vector> strides_expr_vec; + for (size_t i = 0; i < strides_expr_vec.size(); i++) { + strides_expr_vec[i].resize(ndims); + } + + auto stride_gen_fn = [](int indicator, ExprHandle a, ExprHandle b) { + if (indicator % 2 == 0) { + return a * b; + } else { + return b * a; + } + }; + + auto stride_order_vec = shape_gen_info.at(ndims); + for (size_t i = 0; i < strides_expr_vec.size(); i++) { + auto stride_order = stride_order_vec[i]; + + strides_expr_vec[i][stride_order[0]] = 1; + for (size_t j = 1; j < stride_order.size(); j++) { + auto cur_dim_idx = stride_order[j]; + auto adjacent_dim_idx = stride_order[j - 1]; + + strides_expr_vec[i][cur_dim_idx] = stride_gen_fn( + i, + dims_expr_vec[adjacent_dim_idx], + strides_expr_vec[i][adjacent_dim_idx]); + } + } + + return {dims_expr_vec, strides_expr_vec}; + }; + + auto check_channels_last_fn = [](int ndims, BufHandle buf_handle) -> bool { + if (ndims == 3) { + return buf_handle.is_channels_last_1d_contiguous(); + } else if (ndims == 4) { + return buf_handle.is_contiguous(at::MemoryFormat::ChannelsLast); + } else { + return buf_handle.is_contiguous(at::MemoryFormat::ChannelsLast3d); + } + }; + + // channels-last contiguous + for (size_t i = 0; i < dims.size(); i++) { + auto shape_info = shape_gen_fn(dims[i], channels_last_cont_shape_conf); + for (size_t j = 0; j < shape_info.second.size(); j++) { + BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat); + ASSERT_EQ(check_channels_last_fn(dims[i], buf_handle), true); + } + } + + // channels-last non-contiguous + for (size_t i = 0; i < dims.size(); i++) { + auto shape_info = shape_gen_fn(dims[i], channels_last_non_cont_shape_conf); + for (size_t j = 0; j < shape_info.second.size(); j++) { + BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat); + ASSERT_EQ(check_channels_last_fn(dims[i], buf_handle), false); + } + } + + // contiguous + for (size_t i = 0; i < dims.size(); i++) { + auto shape_info = shape_gen_fn(dims[i], cont_shape_conf); + for (size_t j = 0; j < shape_info.second.size(); j++) { + BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat); + ASSERT_EQ(buf_handle.is_contiguous(), true); + } + } + + // non-contiguous + for (size_t i = 0; i < dims.size(); i++) { + auto shape_info = shape_gen_fn(dims[i], channels_last_cont_shape_conf); + for (size_t j = 0; j < shape_info.second.size(); j++) { + BufHandle buf_handle("a", shape_info.first, shape_info.second[j], kFloat); + ASSERT_EQ(buf_handle.is_contiguous(), false); + } + } +} + +TEST(Expr, LetTest01) { + VarHandle x("x", kFloat); + ExprHandle body = ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f)); + SimpleIRExprEval eval(body); + eval.bindVar(x, ExprHandle(3.f)); + ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); +} + +TEST(Expr, LetTest02) { + VarHandle x("x", kFloat); + VarHandle y("y", kFloat); + ExprHandle body = + ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f) * y); + SimpleIRExprEval eval(body); + eval.bindVar(x, ExprHandle(3.f)); + eval.bindVar(y, ExprHandle(6.f)); + ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4 * 6)); +} + +TEST(Expr, LetStmtTest01) { + BufHandle a_buf("a", {1}, kFloat); + BufHandle b_buf("b", {1}, kFloat); + + ExprHandle load_a = a_buf.load(0); + VarHandle var = VarHandle("v", kFloat); + StmtPtr let_store = Let::make(var, load_a); + StmtPtr store_b = b_buf.store({0}, var); + BlockPtr block = Block::make({let_store, store_b}); + + SimpleIREvaluator eval(block, {a_buf, b_buf}); + + PaddedBuffer a_v(1); + PaddedBuffer b_v(1); + PaddedBuffer b_ref(1); + + a_v(0) = 23; + b_ref(0) = a_v(0); + eval(a_v, b_v); + + ExpectAllNear(b_v, b_ref, 1e-5); +} + +TEST(Expr, IntTest) { + VarHandle x("x", kInt); + ExprHandle body = ExprHandle(2) + (x * ExprHandle(3) + ExprHandle(4)); + SimpleIRExprEval eval(body); + eval.bindVar(x, ExprHandle(3)); + ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); +} + +TEST(Expr, FloatTest) { + VarHandle x("x", kFloat); + ExprHandle body = ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f)); + SimpleIRExprEval eval(body); + eval.bindVar(x, ExprHandle(3.f)); + ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); +} + +TEST(Expr, ByteTest) { + VarHandle x("x", kByte); + ExprHandle body = ExprHandle((uint8_t)2) + + (x * ExprHandle((uint8_t)3) + ExprHandle((uint8_t)4)); + SimpleIRExprEval eval(body); + eval.bindVar(x, ExprHandle((uint8_t)3)); + ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); +} + +TEST(Expr, CharTest) { + VarHandle x("x", kChar); + ExprHandle body = ExprHandle((int8_t)2) + + (x * ExprHandle((int8_t)3) + ExprHandle((int8_t)4)); + SimpleIRExprEval eval(body); + eval.bindVar(x, ExprHandle((int8_t)3)); + ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); +} + +TEST(Expr, ShortTest) { + VarHandle x("x", kShort); + ExprHandle body = ExprHandle((int16_t)2) + + (x * ExprHandle((int16_t)3) + ExprHandle((int16_t)4)); + SimpleIRExprEval eval(body); + eval.bindVar(x, ExprHandle((int16_t)3)); + ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); +} + +TEST(Expr, LongTest) { + VarHandle x("x", kLong); + ExprHandle body = ExprHandle((int64_t)2) + + (x * ExprHandle((int64_t)3) + ExprHandle((int64_t)4)); + SimpleIRExprEval eval(body); + eval.bindVar(x, ExprHandle((int64_t)3)); + ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); +} + +TEST(Expr, HalfTest) { + VarHandle x("x", kHalf); + ExprHandle body = ExprHandle((at::Half)2) + + (x * ExprHandle((at::Half)3) + ExprHandle((at::Half)4)); + SimpleIRExprEval eval(body); + eval.bindVar(x, ExprHandle((at::Half)3)); + ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); +} + +TEST(Expr, DoubleTest) { + VarHandle x("x", kDouble); + ExprHandle body = ExprHandle((double)2) + + (x * ExprHandle((double)3) + ExprHandle((double)4)); + SimpleIRExprEval eval(body); + eval.bindVar(x, ExprHandle((double)3)); + ASSERT_EQ(eval.value(), 2 + (3 * 3 + 4)); +} + +TEST(Expr, VectorAdd01) { + const int kVectorSize = 8; + const int kVectorCount = 128; + const int kTotalSize = kVectorSize * kVectorCount; + + BufHandle a_buf("A", {kTotalSize}, kFloat); + BufHandle b_buf("B", {kTotalSize}, kFloat); + BufHandle c_buf("C", {kTotalSize}, kFloat); + + /* + Build the following: + for (const auto index : c10::irange(kVectorCount)) { + store(c_buf, ramp(index * 8, 1, 8), + load(a_buf, ramp(index * 8, 1, 8) + + load(b_buf, ramp(index * 8, 1, 8)))) + } + */ + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = + a_buf.load({Ramp::make(index * kVectorSize, 1, kVectorSize)}); + ExprHandle load_b = + b_buf.load({Ramp::make(index * kVectorSize, 1, kVectorSize)}); + ExprHandle value = load_a + load_b; + StmtPtr store_c = + c_buf.store({Ramp::make(index * kVectorSize, 1, kVectorSize)}, value); + StmtPtr stmt = For::make(index, 0, kVectorCount, store_c); + + ASSERT_EQ(load_a.dtype(), Dtype(kFloat, kVectorSize)); + ASSERT_EQ(load_b.dtype(), Dtype(kFloat, kVectorSize)); + ASSERT_EQ(value.dtype(), Dtype(kFloat, kVectorSize)); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + PaddedBuffer c_ref(kTotalSize); + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = i * i; + b_v(i) = i * i * 4; + c_ref(i) = a_v(i) + b_v(i); + } + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c_buf}); + ir_eval(a_v, b_v, c_v); + ExpectAllNear(c_v, c_ref, 1e-5); +} + +TEST(Expr, CompareSelectEQ) { + constexpr int N = 1024; + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + BufHandle c("C", {N}, kInt); + std::vector a_buffer(N, 1); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 0); + std::vector c_ref(N, 0); + + VarHandle i("i", kInt); + auto memcpy_expr = For::make( + i, + 0, + N, + c.store( + {i}, + CompareSelect::make( + a.load(i), b.load(i), CompareSelectOperation::kEQ))); + + SimpleIREvaluator ir_eval(memcpy_expr, {a, b, c}); + ir_eval(a_buffer, b_buffer, c_buffer); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + + assertAllEqual(a_buffer, 1); + assertAllEqual(b_buffer, 1); + assertAllEqual(c_buffer, 1); +} + +TEST(Expr, CompareSelectDtypes) { + // LHS and RHS expressions should have the same dtype, but this dtype could + // differ from the dtype of the return values (but dtypes of true and false + // return values should be the same). + // This test constructs a CompareSelect expression where the input dtype is + // different from the output dtype and verifies that it works correctly: + // result = ((int)lhs == (int)rhs) ? (float)retval1 : (float)retval2 + constexpr int N = 1024; + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + BufHandle c("C", {N}, kFloat); + std::vector a_buffer(N, 1); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 0.0f); + std::vector c_ref(N, 3.14f); + + VarHandle i("i", kInt); + // C[i] = (A[i] == B[i]) ? 3.14f : 2.78f + // A and B are int, C is float. + auto select_expr = For::make( + i, + 0, + N, + c.store( + {i}, + CompareSelect::make( + a.load(i), + b.load(i), + FloatImm::make(3.14f), + FloatImm::make(2.78f), + CompareSelectOperation::kEQ))); + + SimpleIREvaluator ir_eval(select_expr, {a, b, c}); + ir_eval(a_buffer, b_buffer, c_buffer); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + + assertAllEqual(a_buffer, 1); + assertAllEqual(b_buffer, 1); + ExpectAllNear(c_buffer, c_ref, 1e-7); +} + +TEST(Expr, IntrinsicsDtypes) { + constexpr int N = 256; + BufHandle a("A", {N}, kDouble); + BufHandle b("B", {N}, kDouble); + std::vector a_buffer(N, -10.0); + std::vector b_buffer(N, 0.0); + std::vector b_ref(N, 10.0); + + VarHandle i("i", kInt); + auto abs_expr = For::make(i, 0, N, b.store({i}, tensorexpr::abs(a.load(i)))); + + SimpleIREvaluator ir_eval(abs_expr, {a, b}); + ir_eval(a_buffer, b_buffer); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + + assertAllEqual(a_buffer, -10.0); + ExpectAllNear(b_buffer, b_ref, 1e-7); +} + +TEST(Expr, Substitute01) { + VarPtr x = alloc("x", kFloat); + VarPtr y = alloc("y", kFloat); + ExprPtr e = + alloc(alloc(x, alloc(1.0f)), alloc(x, y)); + + VarPtr z = alloc("z", kFloat); + ExprPtr e2 = Substitute(e, {{x, alloc(z, alloc(5.0f))}}); + ExprPtr e2_ref = alloc( + alloc(alloc(z, alloc(5.0f)), alloc(1.0f)), + alloc(alloc(z, alloc(5.0f)), y)); + std::ostringstream oss; + oss << *e2; + std::string e2_str = oss.str(); + + oss.str(""); + oss << *e2_ref; + std::string e2_ref_str = oss.str(); + ASSERT_EQ(e2_str, e2_ref_str); +} + +TEST(Expr, Math01) { + ExprHandle v = sin(ExprHandle(1.0f)); + + std::ostringstream oss; + oss << v; + ASSERT_EQ(oss.str(), "sin(1.f)"); + + SimpleIRExprEval eval(v); + float v_ref = std::sin(1.0f); + float res = eval.value(); + ASSERT_NEAR(res, v_ref, 1e-6); +} + +TEST(Expr, UnaryMath01) { + struct TestConfig { + std::function func; + std::function ref_func; + }; + + std::vector test_configs = { + {[](const ExprHandle& v) { return sin(v); }, + [](float v) { return std::sin(v); }}, + {[](const ExprHandle& v) { return sin(v); }, + [](float v) { return std::sin(v); }}, + {[](const ExprHandle& v) { return tan(v); }, + [](float v) { return std::tan(v); }}, + {[](const ExprHandle& v) { return asin(v); }, + [](float v) { return std::asin(v); }}, + {[](const ExprHandle& v) { return acos(v); }, + [](float v) { return std::acos(v); }}, + {[](const ExprHandle& v) { return atan(v); }, + [](float v) { return std::atan(v); }}, + {[](const ExprHandle& v) { return sinh(v); }, + [](float v) { return std::sinh(v); }}, + {[](const ExprHandle& v) { return cosh(v); }, + [](float v) { return std::cosh(v); }}, + {[](const ExprHandle& v) { return tanh(v); }, + [](float v) { return std::tanh(v); }}, + {[](const ExprHandle& v) { return exp(v); }, + [](float v) { return std::exp(v); }}, + {[](const ExprHandle& v) { return tensorexpr::abs(v); }, + [](float v) { return std::fabs(v); }}, + {[](const ExprHandle& v) { return log(v); }, + [](float v) { return std::log(v); }}, + {[](const ExprHandle& v) { return log2(v); }, + [](float v) { return std::log2(v); }}, + {[](const ExprHandle& v) { return log10(v); }, + [](float v) { return std::log10(v); }}, + {[](const ExprHandle& v) { return erf(v); }, + [](float v) { return std::erf(v); }}, + {[](const ExprHandle& v) { return sqrt(v); }, + [](float v) { return std::sqrt(v); }}, + {[](const ExprHandle& v) { return rsqrt(v); }, + [](float v) { return 1.0f / std::sqrt(v); }}, + {[](const ExprHandle& v) { return ceil(v); }, + [](float v) { return std::ceil(v); }}, + {[](const ExprHandle& v) { return floor(v); }, + [](float v) { return std::floor(v); }}, + {[](const ExprHandle& v) { return round(v); }, + [](float v) { return std::round(v); }}, + {[](const ExprHandle& v) { return trunc(v); }, + [](float v) { return std::trunc(v); }}, + }; + + for (const TestConfig& test_config : test_configs) { + const float input_v = 0.8765f; + ExprHandle v = test_config.func(ExprHandle(input_v)); + float v_ref = test_config.ref_func(input_v); + SimpleIRExprEval eval(v); + ASSERT_NEAR(eval.value(), v_ref, 1e-6); + } + + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + for (float input_v : {std::nan("1"), 0., .5}) { + ExprHandle v = FloatImm::make(input_v); + SimpleIRExprEval eval(Intrinsics::make(kIsNan, v)); + ASSERT_NEAR(eval.value(), std::isnan(input_v), 0); + } +} + +TEST(Expr, BinaryMath01) { + struct TestConfig { + std::function func; + std::function ref_func; + }; + + std::vector test_configs = { + {[](const ExprHandle& v1, const ExprHandle& v2) { return pow(v1, v2); }, + [](float v1, float v2) { return std::pow(v1, v2); }}, + {[](const ExprHandle& v1, const ExprHandle& v2) { return fmod(v1, v2); }, + [](float v1, float v2) { return std::fmod(v1, v2); }}, + }; + + for (const TestConfig& test_config : test_configs) { + const float v1 = 0.8765f; + float v2 = 1.2345f; + ExprHandle v_expr = test_config.func(ExprHandle(v1), ExprHandle(v2)); + float v_ref = test_config.ref_func(v1, v2); + SimpleIRExprEval eval(v_expr); + ASSERT_NEAR(eval.value(), v_ref, 1e-6); + } +} + +TEST(Expr, LogicalOps01) { + ExprHandle a(23); + ExprHandle b(11); + ExprHandle c(0.72f); + ExprHandle d(0.69f); + ExprHandle f1 = (a > b) && (c > d); + ExprHandle f2 = (a > b) && (c < d); + ExprHandle f3 = (a < b) && (c > d); + ExprHandle f4 = (a < b) && (c < d); + ExprHandle f5 = (a < b) || (c > d); + ExprHandle f6 = (a < b) || (c < d); + ExprHandle f7 = (a > b) || (c < d); + ExprHandle f8 = (a > b) || (c > d); + + SimpleIRExprEval eval1(f1); + SimpleIRExprEval eval2(f2); + SimpleIRExprEval eval3(f3); + SimpleIRExprEval eval4(f4); + SimpleIRExprEval eval5(f5); + SimpleIRExprEval eval6(f6); + SimpleIRExprEval eval7(f7); + SimpleIRExprEval eval8(f8); + ASSERT_EQ(eval1.value(), 1); + ASSERT_EQ(eval2.value(), 0); + ASSERT_EQ(eval3.value(), 0); + ASSERT_EQ(eval4.value(), 0); + ASSERT_EQ(eval5.value(), 1); + ASSERT_EQ(eval6.value(), 0); + ASSERT_EQ(eval7.value(), 1); + ASSERT_EQ(eval8.value(), 1); +} + +TEST(Expr, LogicalOps02) { + ExprHandle a(23); + ExprHandle b(11); + ExprHandle c(0.72f); + ExprHandle d(0.72f); + + ExprHandle f1 = (a > b) || (c > d); + ExprHandle f2 = (a > b) && (c <= d); + ExprHandle f3 = (a > b) && (c > d); + ExprHandle ff1 = f1 && f2; + ExprHandle ff2 = f2 || f3; + + SimpleIRExprEval eval1(ff1); + SimpleIRExprEval eval2(ff2); + ASSERT_EQ(eval1.value(), 1); + ASSERT_EQ(eval2.value(), 1); +} + +TEST(Expr, LogicalOps03) { + ExprHandle a(23); + ExprHandle b(11); + ExprHandle c(0.72f); + ExprHandle d(0.69f); + + // Bool types + ExprHandle bool_f1 = (a > b) && BoolImm::make(true); + ExprHandle bool_f2 = (c <= d) || BoolImm::make(true); + + // Int types + ExprHandle int_f1 = (a > b) && IntImm::make(1); + ExprHandle int_f2 = (c <= d) || IntImm::make(1); + + // Short types + ExprHandle short_f1 = (a > b) && ShortImm::make(1); + ExprHandle short_f2 = (c <= d) || ShortImm::make(1); + + // Long types + ExprHandle long_f1 = (a > b) && LongImm::make(1); + ExprHandle long_f2 = (c <= d) || LongImm::make(1); + + // Char types + ExprHandle char_f1 = (a > b) && CharImm::make(1); + ExprHandle char_f2 = (c <= d) || CharImm::make(1); + + // Byte types + ExprHandle byte_f1 = (a > b) && ByteImm::make(1); + ExprHandle byte_f2 = (c <= d) || ByteImm::make(1); + + SimpleIRExprEval eval1(bool_f1); + SimpleIRExprEval eval2(bool_f2); + SimpleIRExprEval eval3(int_f1); + SimpleIRExprEval eval4(int_f2); + SimpleIRExprEval eval5(short_f1); + SimpleIRExprEval eval6(short_f2); + SimpleIRExprEval eval7(long_f1); + SimpleIRExprEval eval8(long_f2); + SimpleIRExprEval eval9(char_f1); + SimpleIRExprEval eval10(char_f2); + SimpleIRExprEval eval11(byte_f1); + SimpleIRExprEval eval12(byte_f2); + + ASSERT_EQ(eval1.value(), true); + ASSERT_EQ(eval2.value(), true); + ASSERT_EQ(eval3.value(), 1); + ASSERT_EQ(eval4.value(), 1); + ASSERT_EQ(eval5.value(), 1); + ASSERT_EQ(eval6.value(), 1); + ASSERT_EQ(eval7.value(), 1); + ASSERT_EQ(eval8.value(), 1); + ASSERT_EQ(eval9.value(), 1); + ASSERT_EQ(eval10.value(), 1); + ASSERT_EQ(eval11.value(), 1); + ASSERT_EQ(eval12.value(), 1); +} + +TEST(Expr, BitwiseOps) { + ExprHandle a(59); + ExprHandle b(11); + ExprHandle c(101); + ExprHandle d(2); + ExprHandle f = (((a ^ (b << 1)) & c) >> 2) | d; + + SimpleIRExprEval eval(f); + ASSERT_EQ(eval.value(), 11); +} + +TEST(Expr, DynamicShapeAdd) { + auto testWithSize = [](int32_t size) { + VarHandle n("n", kInt); + BufHandle a("a", {n}, kFloat); + BufHandle b("b", {n}, kFloat); + BufHandle c("c", {n}, kFloat); + VarHandle i("i", kInt); + StmtPtr s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i))); + std::vector aData(size, 1.0f); + std::vector bData(size, 2.0f); + std::vector cData(size, 0.0f); + SimpleIREvaluator(s, {a, b, c, n})(aData, bData, cData, size); + ExpectAllNear(cData, std::vector(size, 3.0f), 1e-7); + }; + testWithSize(1); + testWithSize(16); + testWithSize(37); +} + +TEST(Expr, OutOfBounds) { + ExprHandle N(10); + ExprHandle start(0); + ExprHandle stop(15); + VarHandle i("i", kInt); + + BufHandle X("X", {N}, kInt); + + auto body = Store::make(X, {i}, i); + auto stmt = For::make(i, start, stop, body); + + PaddedBuffer data(20); + + EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data)); +} + +TEST(Expr, OutOfBounds2d) { + std::vector> size_options = {{10, 15}, {15, 10}}; + for (auto sizes : size_options) { + ExprHandle N(sizes.first); + ExprHandle M(sizes.second); + ExprHandle start(0); + ExprHandle stopInner(15); + ExprHandle stopOuter(15); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + + BufHandle X("X", {N, M}, kInt); + + auto body = Store::make(X, {i, j}, i); + auto inner = For::make(j, start, stopInner, body); + auto stmt = For::make(i, start, stopOuter, inner); + + PaddedBuffer data(400); + + EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data)); + } +} + +TEST(Expr, OutOfBounds2dFlattenedIndex) { + ExprHandle buf_size(149); + ExprHandle start(0); + ExprHandle stopInner(15); + ExprHandle stopOuter(10); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + + BufHandle X("X", {buf_size}, kInt); + + auto idx = Add::make(Mul::make(i, stopInner), j); + auto body = Store::make(X, {idx}, i); + auto inner = For::make(j, start, stopInner, body); + auto stmt = For::make(i, start, stopOuter, inner); + + PaddedBuffer data(400); + + EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data)); +} + +void testCond01() { + const int N = 16; + PaddedBuffer a_v(N); + BufHandle a_buf("a", {N}, kFloat); + VarHandle index = VarHandle("index", kInt); + StmtPtr assign_x2 = a_buf.store({index}, cast(index) * 2); + StmtPtr assign_x3 = a_buf.store({index}, cast(index) * 3); + ExprHandle even_cond = CompareSelect::make(Mod::make(index, 2), 0, kEQ); + StmtPtr assign = Cond::make(even_cond, assign_x2, assign_x3); + StmtPtr for_stmt = For::make(index, 0, N, assign); + SimpleIREvaluator(for_stmt, {a_buf})(a_v); + + PaddedBuffer a_ref(N); + for (const auto i : c10::irange(N)) { + if (i % 2 == 0) { + a_ref(i) = i * 2; + } else { + a_ref(i) = i * 3; + } + } + ExpectAllNear(a_v, a_ref, 1e-5); +} + +void testIfThenElse01() { + ExprHandle v = ifThenElse(ExprHandle(1), ExprHandle(1.0f), ExprHandle(2.0f)); + + std::ostringstream oss; + oss << v; + ASSERT_EQ(oss.str(), "IfThenElse(1, 1.f, 2.f)"); + + SimpleIRExprEval eval(v); + ASSERT_EQ(eval.value(), 1.0f); +} + +void testIfThenElse02() { + ExprHandle v = ifThenElse(ExprHandle(0), ExprHandle(1.0f), ExprHandle(2.0f)); + + std::ostringstream oss; + oss << v; + ASSERT_EQ(oss.str(), "IfThenElse(0, 1.f, 2.f)"); + + SimpleIRExprEval eval(v); + ASSERT_EQ(eval.value(), 2.0f); +} + +void testIfThenElse03() { + ExprHandle v = + ifThenElse(BoolImm::make(false), ExprHandle(1.0f), ExprHandle(2.0f)); + + std::ostringstream oss; + oss << v; + ASSERT_EQ(oss.str(), "IfThenElse(0, 1.f, 2.f)"); + + SimpleIRExprEval eval(v); + ASSERT_EQ(eval.value(), 2.0f); +} + +void testStmtClone() { + const int N = 16; + + BufHandle a_buf("a", {N}, kInt); + VarHandle index = VarHandle("index", kInt); + StmtPtr body = a_buf.store({index}, 5); + StmtPtr loop = For::make(index, 0, N, body); + + StmtPtr cloned_loop = Stmt::clone(loop); + std::vector orig_loop_results(N); + std::vector cloned_loop_results(N); + SimpleIREvaluator(loop, {a_buf})(orig_loop_results); + SimpleIREvaluator(cloned_loop, {a_buf})(cloned_loop_results); + + assertAllEqual(orig_loop_results, 5); + assertAllEqual(cloned_loop_results, 5); + + // Let's add another assign to the body in the cloned loop and verify that the + // original statement hasn't changed while the cloned one has. + StmtPtr body_addition = a_buf.store({index}, 33); + BlockPtr cloned_body = static_to(static_to(cloned_loop)->body()); + cloned_body->append_stmt(body_addition); + + std::vector orig_loop_results_after_mutation(N); + std::vector cloned_loop_results_after_mutation(N); + SimpleIREvaluator(loop, {a_buf})(orig_loop_results_after_mutation); + SimpleIREvaluator(cloned_loop, {a_buf})(cloned_loop_results_after_mutation); + + assertAllEqual(orig_loop_results_after_mutation, 5); + assertAllEqual(cloned_loop_results_after_mutation, 33); +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_external_calls.cpp b/test/cpp/tensorexpr/test_external_calls.cpp new file mode 100644 index 0000000000000..49f43d16b499d --- /dev/null +++ b/test/cpp/tensorexpr/test_external_calls.cpp @@ -0,0 +1,1061 @@ +#include + +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +namespace torch { +namespace jit { +using namespace torch::jit::tensorexpr; + +TEST(ExternalCall, Conv1d_float) { + BufHandle Input("Input", {1, 100, 115}, kFloat); + BufHandle Weight("Weight", {100, 1, 7}, kFloat); + BufHandle Bias("Bias", {100}, kFloat); + BufHandle ResultBuf("Result", {1, 100, 115}, kFloat); + int64_t stride = 1; + int64_t pad = 3; + int64_t dilation = 1; + int64_t groups = 100; + + Tensor Result = Tensor( + ResultBuf.node(), + ExternalCall::make( + ResultBuf, + "nnc_aten_conv1d", + {Input, Weight, Bias}, + {stride, pad, dilation, groups})); + LoopNest l({Result}); + l.prepareForCodegen(); + l.simplify(); + + auto options = at::TensorOptions() + .dtype(at::kFloat) + .layout(at::kStrided) + .device(at::kCPU) + .requires_grad(false); + at::Tensor input = at::ones({1, 100, 115}, options) * 5.f; + at::Tensor weight = at::ones({100, 1, 7}, options) * 6.f; + at::Tensor bias = at::ones({100}, options) * 11.f; + at::Tensor ref = + at::conv1d(input, weight, bias, {stride}, {pad}, {dilation}, groups); + + at::Tensor nnc_result; + std::vector input_buf(1 * 100 * 115, 5.f); + std::vector weight_buf(100 * 1 * 7, 6.f); + std::vector bias_buf(100, 11.f); + std::vector result_buf(1 * 100 * 115, -1.f); + +#ifdef TORCH_ENABLE_LLVM + LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result}); + + llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +#endif + + SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result}); + + ir_eval.call({input_buf, weight_buf, bias_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +} + +TEST(ExternalCall, Conv1d_int) { + // A similar test, but now using kInt tensors + BufHandle Input("Input", {1, 100, 115}, kInt); + BufHandle Weight("Weight", {100, 1, 7}, kInt); + BufHandle Bias("Bias", {100}, kInt); + BufHandle ResultBuf("Result", {1, 100, 115}, kInt); + int64_t stride = 1; + int64_t pad = 3; + int64_t dilation = 1; + int64_t groups = 100; + + Tensor Result = Tensor( + ResultBuf.node(), + ExternalCall::make( + ResultBuf, + "nnc_aten_conv1d", + {Input, Weight, Bias}, + {stride, pad, dilation, groups})); + LoopNest l({Result}); + l.prepareForCodegen(); + l.simplify(); + + auto options = at::TensorOptions() + .dtype(at::kInt) + .layout(at::kStrided) + .device(at::kCPU) + .requires_grad(false); + at::Tensor input = at::ones({1, 100, 115}, options) * 5; + at::Tensor weight = at::ones({100, 1, 7}, options) * 6; + at::Tensor bias = at::ones({100}, options) * 11; + at::Tensor ref = + at::conv1d(input, weight, bias, {stride}, {pad}, {dilation}, groups); + + at::Tensor nnc_result; + std::vector input_buf(1 * 100 * 115, 5); + std::vector weight_buf(100 * 1 * 7, 6); + std::vector bias_buf(100, 11); + std::vector result_buf(1 * 100 * 115, -1); + +#ifdef TORCH_ENABLE_LLVM + LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result}); + + llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +#endif + + SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result}); + + ir_eval.call({input_buf, weight_buf, bias_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +} + +TEST(ExternalCall, Conv1d_nobias_noargs) { + BufHandle Input("Input", {1, 1, 115}, kFloat); + BufHandle Weight("Weight", {10, 1, 7}, kFloat); + BufHandle ResultBuf("Result", {1, 10, 109}, kFloat); + + Tensor Result = Tensor( + ResultBuf.node(), + ExternalCall::make(ResultBuf, "nnc_aten_conv1d", {Input, Weight}, {})); + LoopNest l({Result}); + l.prepareForCodegen(); + l.simplify(); + + auto options = at::TensorOptions() + .dtype(at::kFloat) + .layout(at::kStrided) + .device(at::kCPU) + .requires_grad(false); + at::Tensor input = at::ones({1, 1, 115}, options) * 5.f; + at::Tensor weight = at::ones({10, 1, 7}, options) * 6.f; + at::Tensor ref = at::conv1d(input, weight); + + at::Tensor nnc_result; + std::vector input_buf(1 * 1 * 115, 5.f); + std::vector weight_buf(10 * 1 * 7, 6.f); + std::vector result_buf(1 * 10 * 109, -1.f); + +#ifdef TORCH_ENABLE_LLVM + LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Result}); + + llvm_codegen.call({input_buf, weight_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 10, 109}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +#endif + + SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Result}); + + ir_eval.call({input_buf, weight_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 10, 109}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +} + +TEST(ExternalCall, Conv2d_float) { + BufHandle Input("Input", {1, 3, 224, 224}, kFloat); + BufHandle Weight("Weight", {16, 3, 3, 3}, kFloat); + BufHandle Bias("Bias", {16}, kFloat); + BufHandle ResultBuf("Result", {1, 16, 112, 112}, kFloat); + int64_t stride = 2; + int64_t pad = 1; + int64_t dilation = 1; + int64_t groups = 1; + + Tensor Result = Tensor( + ResultBuf.node(), + ExternalCall::make( + ResultBuf, + "nnc_aten_conv2d", + {Input, Weight, Bias}, + {stride, stride, pad, pad, dilation, dilation, groups})); + LoopNest l({Result}); + l.prepareForCodegen(); + l.simplify(); + + auto options = at::TensorOptions() + .dtype(at::kFloat) + .layout(at::kStrided) + .device(at::kCPU) + .requires_grad(false); + at::Tensor input = at::ones({1, 3, 224, 224}, options) * 5.f; + at::Tensor weight = at::ones({16, 3, 3, 3}, options) * 6.f; + at::Tensor bias = at::ones({16}, options) * 11.f; + at::Tensor ref = at::conv2d( + input, + weight, + bias, + {stride, stride}, + {pad, pad}, + {dilation, dilation}, + groups); + + at::Tensor nnc_result; + std::vector input_buf(1 * 3 * 224 * 224, 5.f); + std::vector weight_buf(16 * 3 * 3 * 3, 6.f); + std::vector bias_buf(16, 11.f); + std::vector result_buf(1 * 16 * 112 * 112, -1.f); + +#ifdef TORCH_ENABLE_LLVM + LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result}); + + llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +#endif + + SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result}); + + ir_eval.call({input_buf, weight_buf, bias_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +} + +TEST(ExternalCall, Conv2d_int) { + // A similar test, but now using kInt tensors + + BufHandle Input("Input", {1, 3, 224, 224}, kInt); + BufHandle Weight("Weight", {16, 3, 3, 3}, kInt); + BufHandle Bias("Bias", {16}, kInt); + BufHandle ResultBuf("Result", {1, 16, 112, 112}, kInt); + int64_t stride = 2; + int64_t pad = 1; + int64_t dilation = 1; + int64_t groups = 1; + + Tensor Result = Tensor( + ResultBuf.node(), + ExternalCall::make( + ResultBuf, + "nnc_aten_conv2d", + {Input, Weight, Bias}, + {stride, stride, pad, pad, dilation, dilation, groups})); + LoopNest l({Result}); + l.prepareForCodegen(); + l.simplify(); + + auto options = at::TensorOptions() + .dtype(at::kInt) + .layout(at::kStrided) + .device(at::kCPU) + .requires_grad(false); + at::Tensor input = at::ones({1, 3, 224, 224}, options) * 5; + at::Tensor weight = at::ones({16, 3, 3, 3}, options) * 6; + at::Tensor bias = at::ones({16}, options) * 11; + at::Tensor ref = at::conv2d( + input, + weight, + bias, + {stride, stride}, + {pad, pad}, + {dilation, dilation}, + groups); + + at::Tensor nnc_result; + std::vector input_buf(1 * 3 * 224 * 224, 5); + std::vector weight_buf(16 * 3 * 3 * 3, 6); + std::vector bias_buf(16, 11); + std::vector result_buf(1 * 16 * 112 * 112, -1); + +#ifdef TORCH_ENABLE_LLVM + LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result}); + + llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +#endif + + SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result}); + + ir_eval.call({input_buf, weight_buf, bias_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +} + +TEST(ExternalCall, Conv2d_nobias_noargs) { + BufHandle Input("Input", {1, 16, 112, 112}, kFloat); + BufHandle Weight("Weight", {16, 16, 1, 1}, kFloat); + BufHandle ResultBuf("Result", {1, 16, 112, 112}, kFloat); + + Tensor Result = Tensor( + ResultBuf.node(), + ExternalCall::make(ResultBuf, "nnc_aten_conv2d", {Input, Weight}, {})); + LoopNest l({Result}); + l.prepareForCodegen(); + l.simplify(); + + auto options = at::TensorOptions() + .dtype(at::kFloat) + .layout(at::kStrided) + .device(at::kCPU) + .requires_grad(false); + at::Tensor input = at::ones({1, 16, 112, 112}, options) * 5.f; + at::Tensor weight = at::ones({16, 16, 1, 1}, options) * 6.f; + at::Tensor ref = at::conv2d(input, weight); + + at::Tensor nnc_result; + std::vector input_buf(1 * 16 * 112 * 112, 5.f); + std::vector weight_buf(16 * 16 * 1 * 1, 6.f); + std::vector result_buf(1 * 16 * 112 * 112, -1.f); + +#ifdef TORCH_ENABLE_LLVM + LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Result}); + + llvm_codegen.call({input_buf, weight_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +#endif + + SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Result}); + + ir_eval.call({input_buf, weight_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +} + +TEST(ExternalCall, Addmm_float) { + BufHandle Input("Input", {100, 300}, kFloat); + BufHandle Mat1("Mat1", {100, 200}, kFloat); + BufHandle Mat2("Mat2", {200, 300}, kFloat); + BufHandle ResultBuf("Result", {100, 300}, kFloat); + int64_t beta = 2; + int64_t alpha = 2; + + Tensor Result = Tensor( + ResultBuf.node(), + ExternalCall::make( + ResultBuf, "nnc_aten_addmm", {Input, Mat1, Mat2}, {beta, alpha})); + LoopNest l({Result}); + l.prepareForCodegen(); + l.simplify(); + + auto options = at::TensorOptions() + .dtype(at::kFloat) + .layout(at::kStrided) + .device(at::kCPU) + .requires_grad(false); + at::Tensor input = at::ones({100, 300}, options) * 5.f; + at::Tensor mat1 = at::ones({100, 200}, options) * 6.f; + at::Tensor mat2 = at::ones({200, 300}, options) * 11.f; + at::Tensor ref = at::addmm(input, mat1, mat2, beta, alpha); + + at::Tensor nnc_result; + std::vector input_buf(100 * 300, 5.f); + std::vector mat1_buf(100 * 200, 6.f); + std::vector mat2_buf(200 * 300, 11.f); + std::vector result_buf(100 * 300, -1.f); + +#ifdef TORCH_ENABLE_LLVM + LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Mat1, Mat2, Result}); + + llvm_codegen.call({input_buf, mat1_buf, mat2_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {100, 300}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +#endif + + SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Mat1, Mat2, Result}); + + ir_eval.call({input_buf, mat1_buf, mat2_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {100, 300}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +} + +TEST(ExternalCall, Embedding) { + BufHandle Weight("Weight", {256, 100}, kFloat); + BufHandle Indices("Indices", {1, 115}, kLong); + BufHandle ResultBuf("Result", {1, 115, 100}, kFloat); + int64_t padding_idx = -1; + bool scale_grad_by_freq = false; + bool sparse = false; + + Tensor Result = Tensor( + ResultBuf.node(), + ExternalCall::make( + ResultBuf, + "nnc_aten_embedding", + {Weight, Indices}, + {padding_idx, (int64_t)scale_grad_by_freq, (int64_t)sparse})); + LoopNest l({Result}); + l.prepareForCodegen(); + l.simplify(); + + auto options = at::TensorOptions() + .layout(at::kStrided) + .device(at::kCPU) + .requires_grad(false); + + at::Tensor weight = at::ones({256, 100}, options.dtype(at::kFloat)) * 5.f; + at::Tensor indices = at::ones({1, 115}, options.dtype(at::kLong)) * 6; + at::Tensor ref = + at::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse); + + at::Tensor nnc_result; + std::vector weight_buf(256 * 100, 5.f); + std::vector indices_buf(1 * 115, 6); + std::vector result_buf(1 * 115 * 100, -1.f); + +#ifdef TORCH_ENABLE_LLVM + LLVMCodeGen llvm_codegen(l.root_stmt(), {Weight, Indices, Result}); + + llvm_codegen.call({weight_buf, indices_buf, result_buf}); + nnc_result = at::from_blob( + result_buf.data(), {1, 115, 100}, options.dtype(at::kFloat)); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +#endif + + SimpleIREvaluator ir_eval(l.root_stmt(), {Weight, Indices, Result}); + + ir_eval.call({weight_buf, indices_buf, result_buf}); + nnc_result = at::from_blob( + result_buf.data(), {1, 115, 100}, options.dtype(at::kFloat)); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +} + +TEST(ExternalCall, MaxReduction) { + BufHandle Input("Input", {1, 115, 152}, kFloat); + BufHandle ResultBuf("Result", {1, 152}, kFloat); + int64_t dim = 1; + bool keep_dim = false; + + Tensor Result = Tensor( + ResultBuf.node(), + ExternalCall::make( + ResultBuf, "nnc_aten_max_red", {Input}, {dim, (int64_t)keep_dim})); + LoopNest l({Result}); + l.prepareForCodegen(); + l.simplify(); + + auto options = at::TensorOptions() + .dtype(at::kFloat) + .layout(at::kStrided) + .device(at::kCPU) + .requires_grad(false); + + at::Tensor input = at::ones({1, 115, 152}, options) * 5.f; + at::Tensor ref = std::get<0>(at::max(input, dim, keep_dim)); + + at::Tensor nnc_result; + std::vector input_buf(1 * 115 * 152, 5.f); + std::vector result_buf(1 * 152, -1.f); + +#ifdef TORCH_ENABLE_LLVM + LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Result}); + + llvm_codegen.call({input_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 152}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +#endif + + SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Result}); + + ir_eval.call({input_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 152}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +} + +#ifdef USE_XNNPACK + +TEST(ExternalCall, Prepacked_Linear_float) { + using namespace at::native::xnnpack; + + BufHandle Input("Input", {100, 200}, kFloat); + BufHandle ResultBuf("Result", {100, 300}, kFloat); + + // Calculate reference result using at::linear. + auto options = at::TensorOptions() + .dtype(at::kFloat) + .layout(at::kStrided) + .device(at::kCPU) + .requires_grad(false); + at::Tensor input = + at::linspace(-10.0, 10.0, 100 * 200, options).resize_({100, 200}); + at::Tensor weight = + at::linspace(-10.0, 10.0, 300 * 200, options).resize_({300, 200}); + at::Tensor bias = at::linspace(-10.0, 10.0, 300, options); + at::Tensor ref = at::linear(input, weight, bias); + + // Create prepacked xnnpack context object. + auto linear_clamp_prepack_op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("prepacked::linear_clamp_prepack", "") + .typed( + at::Tensor, + std::optional, + const std::optional&, + const std::optional&)>(); + auto prepacked = linear_clamp_prepack_op.call( + weight, bias, std::optional(), std::optional()); + + BufHandle DummyPrepacked("DummyPrepacked", {1}, kFloat); + Tensor Result = Tensor( + ResultBuf.node(), + ExternalCall::make( + ResultBuf, + "nnc_prepacked_linear_clamp_run", + {Input, DummyPrepacked}, + {})); + LoopNest l({Result}); + l.prepareForCodegen(); + l.simplify(); + + at::Tensor nnc_result; + std::vector input_buf( + input.data_ptr(), input.data_ptr() + 100 * 200); + std::vector result_buf(100 * 300, -1.f); + +#ifdef TORCH_ENABLE_LLVM + LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, DummyPrepacked, Result}); + + llvm_codegen.call({input_buf, prepacked.get(), result_buf}); + nnc_result = at::from_blob(result_buf.data(), {100, 300}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +#endif + + SimpleIREvaluator ir_eval(l.root_stmt(), {Input, DummyPrepacked, Result}); + + ir_eval.call({input_buf, prepacked.get(), result_buf}); + nnc_result = at::from_blob(result_buf.data(), {100, 300}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +} + +TEST(ExternalCall, Prepacked_Conv2d_float) { + using namespace at::native::xnnpack; + + BufHandle Input("Input", {1, 3, 224, 224}, kFloat); + BufHandle ResultBuf("Result", {1, 16, 112, 112}, kFloat); + int64_t stride = 2; + int64_t pad = 1; + int64_t dilation = 1; + int64_t groups = 1; + + // Calculate reference result using at::conv2d. + auto options = at::TensorOptions() + .dtype(at::kFloat) + .layout(at::kStrided) + .device(at::kCPU) + .requires_grad(false); + at::Tensor input = at::linspace(-10.0, 10.0, 1 * 3 * 224 * 224, options) + .resize_({1, 3, 224, 224}); + at::Tensor weight = + at::linspace(-10.0, 10.0, 16 * 3 * 3 * 3, options).resize_({16, 3, 3, 3}); + at::Tensor bias = at::linspace(-10.0, 10.0, 16, options); + at::Tensor ref = at::conv2d( + input, + weight, + bias, + {stride, stride}, + {pad, pad}, + {dilation, dilation}, + groups); + + // Create prepacked xnnpack context object. + auto conv2d_clamp_prepack_op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("prepacked::conv2d_clamp_prepack", "") + .typed( + at::Tensor, + std::optional, + std::vector, + std::vector, + std::vector, + int64_t, + const std::optional&, + const std::optional&)>(); + auto prepacked = conv2d_clamp_prepack_op.call( + weight, + bias, + {stride, stride}, + {pad, pad}, + {dilation, dilation}, + groups, + std::optional(), + std::optional()); + + BufHandle DummyPrepacked("DummyPrepacked", {1}, kFloat); + Tensor Result = Tensor( + ResultBuf.node(), + ExternalCall::make( + ResultBuf, + "nnc_prepacked_conv2d_clamp_run", + {Input, DummyPrepacked}, + {})); + LoopNest l({Result}); + l.prepareForCodegen(); + l.simplify(); + + at::Tensor nnc_result; + std::vector input_buf( + input.data_ptr(), input.data_ptr() + 1 * 3 * 224 * 224); + std::vector result_buf(1 * 16 * 112 * 112, -1.f); + +#ifdef TORCH_ENABLE_LLVM + LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, DummyPrepacked, Result}); + + llvm_codegen.call({input_buf, prepacked.get(), result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref, 1e-03, 1e-03)); +#endif + + SimpleIREvaluator ir_eval(l.root_stmt(), {Input, DummyPrepacked, Result}); + + ir_eval.call({input_buf, prepacked.get(), result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref, 1e-03, 1e-03)); +} + +#endif // USE_XNNPACK + +TEST(ExternalCall, BinaryFloat) { + using TensorFunc = std::function; + using Test = std::tuple< + std::vector, + std::vector, + std::vector, + TensorFunc, + std::string>; + std::vector tests = {}; + tests.push_back( + Test{{100, 200}, {200, 300}, {100, 300}, at::matmul, "nnc_aten_matmul"}); + tests.push_back(Test{{100, 300}, {300}, {100}, at::mv, "nnc_aten_mv"}); + tests.push_back(Test{ + {100, 200}, + {200, 300}, + {100, 300}, + [&](const at::Tensor& a, const at::Tensor& b) { return at::mm(a, b); }, + "nnc_aten_mm"}); + for (auto curTest : tests) { + auto [aShape, bShape, resShape, torchFunc, externCallName] = curTest; + auto toExprHandleVec = [](std::vector v) { + auto intV = std::vector(v.begin(), v.end()); + return std::vector(intV.begin(), intV.end()); + }; + BufHandle A("A", toExprHandleVec(aShape), kFloat); + BufHandle B("B", toExprHandleVec(bShape), kFloat); + BufHandle ResultBuf("Result", toExprHandleVec(resShape), kFloat); + + Tensor Result = Tensor( + ResultBuf.node(), + ExternalCall::make(ResultBuf, externCallName, {A, B}, {})); + LoopNest l({Result}); + l.prepareForCodegen(); + l.simplify(); + + auto options = at::TensorOptions() + .dtype(at::kFloat) + .layout(at::kStrided) + .device(at::kCPU) + .requires_grad(false); + at::Tensor a = at::ones(c10::IntArrayRef(aShape), options) * 5.f; + at::Tensor b = at::ones(c10::IntArrayRef(bShape), options) * 6.f; + at::Tensor ref = torchFunc(a, b); + + auto prod = [](std::vector v) { + // NOLINTNEXTLINE(modernize-use-transparent-functors) + return std::accumulate(v.begin(), v.end(), 1, std::multiplies()); + }; + + at::Tensor nnc_result; + std::vector a_buf(prod(aShape), 5.f); + std::vector b_buf(prod(bShape), 6.f); + std::vector result_buf(prod(resShape), -1.f); + +#ifdef TORCH_ENABLE_LLVM + LLVMCodeGen llvm_codegen(l.root_stmt(), {A, B, Result}); + + llvm_codegen.call({a_buf, b_buf, result_buf}); + nnc_result = + at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +#endif + + SimpleIREvaluator ir_eval(l.root_stmt(), {A, B, Result}); + ir_eval.call({a_buf, b_buf, result_buf}); + nnc_result = + at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); + } +} + +TEST(ExternalCall, UnaryFloat) { + using TensorFunc = std::function; + auto toExprHandleVec = [](std::vector v) { + auto intV = std::vector(v.begin(), v.end()); + return std::vector(intV.begin(), intV.end()); + }; + using Test = std::tuple< + std::vector, + std::vector, + TensorFunc, + std::string, + std::vector>; + std::vector tests = {}; + tests.push_back(Test{ + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + {1, 64, 8, 9}, + {1, 64, 5, 7}, + [](at::Tensor x) { return at::adaptive_avg_pool2d(x, {5, 7}); }, + "nnc_aten_adaptive_avg_pool2d", + toExprHandleVec({5, 7})}); + tests.push_back(Test{// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + {100, 200}, + {100}, + [](at::Tensor x) { return at::mean(x, {1}); }, + "nnc_aten_mean", + toExprHandleVec({1, /*keepdim=*/0})}); + for (auto curTest : tests) { + auto [aShape, resShape, torchFunc, externCallName, externCallArgs] = + curTest; + BufHandle A("A", toExprHandleVec(aShape), kFloat); + BufHandle ResultBuf("Result", toExprHandleVec(resShape), kFloat); + + Tensor Result = Tensor( + ResultBuf.node(), + ExternalCall::make(ResultBuf, externCallName, {A}, externCallArgs)); + LoopNest l({Result}); + l.prepareForCodegen(); + l.simplify(); + + auto options = at::TensorOptions() + .dtype(at::kFloat) + .layout(at::kStrided) + .device(at::kCPU) + .requires_grad(false); + at::Tensor a = at::ones(c10::IntArrayRef(aShape), options) * 5.f; + at::Tensor ref = torchFunc(a); + + auto prod = [](std::vector v) { + // NOLINTNEXTLINE(modernize-use-transparent-functors) + return std::accumulate(v.begin(), v.end(), 1, std::multiplies()); + }; + + at::Tensor nnc_result; + std::vector a_buf(prod(aShape), 5.f); + std::vector result_buf(prod(resShape), -1.f); + +#ifdef TORCH_ENABLE_LLVM + LLVMCodeGen llvm_codegen(l.root_stmt(), {A, Result}); + + llvm_codegen.call({a_buf, result_buf}); + nnc_result = + at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +#endif + + SimpleIREvaluator ir_eval(l.root_stmt(), {A, Result}); + ir_eval.call({a_buf, result_buf}); + nnc_result = + at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); + } +} + +TEST(ExternalCall, ComputeInterop) { + // This test verifies that Tensors using external calls can be used by and can + // use Tensors built with Compute API. + + BufHandle ConvResultBuf("ConvResult", {1, 16, 32, 32}, kFloat); + BufHandle MatmulResultBuf("MatmulResult", {1, 16, 32, 32}, kFloat); + + Tensor Input = Compute( + "Input", + {1, 16, 32, 32}, + [&](const VarHandle& n, + const VarHandle& c, + const VarHandle& h, + const VarHandle& w) { return FloatImm::make(5.0f); }); + Tensor Weight = Compute( + "Weight", + {16, 16, 1, 1}, + [&](const VarHandle& n, + const VarHandle& c, + const VarHandle& h, + const VarHandle& w) { return FloatImm::make(6.0f); }); + + Tensor ConvResult = Tensor( + ConvResultBuf.node(), + ExternalCall::make( + ConvResultBuf, + "nnc_aten_conv2d", + {BufHandle(Input.buf()), BufHandle(Weight.buf())}, + {})); + Tensor MatmulResult = Tensor( + MatmulResultBuf.node(), + ExternalCall::make( + MatmulResultBuf, + "nnc_aten_matmul", + {BufHandle(ConvResult.buf()), BufHandle(ConvResult.buf())}, + {})); + Tensor Result = Compute( + "Result", + {1, 16, 32, 32}, + [&](const VarHandle& n, + const VarHandle& c, + const VarHandle& h, + const VarHandle& w) { + return ConvResult.load(n, c, h, w) + MatmulResult.load(n, c, h, w); + }); + + LoopNest l({Input, Weight, ConvResult, MatmulResult, Result}); + + // Inlining should not inline anything here since all Bufs are either defined + // or used in ExternalCalls - we run it just for testing + l.inlineIntermediateBufs(true); + + l.prepareForCodegen(); + l.simplify(); + + auto options = at::TensorOptions() + .dtype(at::kFloat) + .layout(at::kStrided) + .device(at::kCPU) + .requires_grad(false); + at::Tensor input = at::ones({1, 16, 32, 32}, options) * 5.f; + at::Tensor weight = at::ones({16, 16, 1, 1}, options) * 6.f; + at::Tensor t = at::conv2d(input, weight); + at::Tensor t2 = at::matmul(t, t); + at::Tensor ref = t + t2; + + at::Tensor nnc_result; + std::vector input_buf(1 * 16 * 32 * 32, 5.f); + std::vector weight_buf(16 * 16 * 1 * 1, 6.f); + std::vector conv_result_buf(1 * 16 * 32 * 32, -1.f); + std::vector matmul_result_buf(1 * 16 * 32 * 32, -1.f); + std::vector result_buf(1 * 16 * 32 * 32, -1.f); + +#ifdef TORCH_ENABLE_LLVM + LLVMCodeGen llvm_codegen( + l.root_stmt(), {Input, Weight, ConvResult, MatmulResult, Result}); + + llvm_codegen.call( + {input_buf, weight_buf, conv_result_buf, matmul_result_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 16, 32, 32}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +#endif + + SimpleIREvaluator ir_eval( + l.root_stmt(), {Input, Weight, ConvResult, MatmulResult, Result}); + + ir_eval.call( + {input_buf, weight_buf, conv_result_buf, matmul_result_buf, result_buf}); + nnc_result = at::from_blob(result_buf.data(), {1, 16, 32, 32}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +} + +TEST(ExternalCall, Inlining) { + // This test verifies that Tensors using external calls can be used by and + // can use Tensors built with Compute API. + + BufHandle MatmulResultBuf("MatmulResult", {8, 8}, kFloat); + + Tensor A = Compute("A", {8, 8}, [&](const VarHandle& i, const VarHandle& j) { + return FloatImm::make(5.0f); + }); + Tensor B = Compute("B", {8, 8}, [&](const VarHandle& i, const VarHandle& j) { + return FloatImm::make(4.0f); + }); + Tensor MatmulResult = Tensor( + MatmulResultBuf.node(), + ExternalCall::make( + MatmulResultBuf, + "nnc_aten_matmul", + {BufHandle(A.buf()), BufHandle(B.buf())}, + {})); + Tensor Result = + Compute("Result", {8, 8}, [&](const VarHandle& i, const VarHandle& j) { + return MatmulResult.load(i, j) + FloatImm::make(3.0f); + }); + + StmtPtr root_stmt = alloc(std::vector( + {A.stmt(), B.stmt(), MatmulResult.stmt(), Result.stmt()})); + LoopNest l(root_stmt, {Result.buf()}); + + // Inlining should not inline anything here since all Bufs are either + // defined or used in ExternalCalls + l.inlineIntermediateBufs(false); + + l.prepareForCodegen(); + l.simplify(); + + auto options = at::TensorOptions() + .dtype(at::kFloat) + .layout(at::kStrided) + .device(at::kCPU) + .requires_grad(false); + at::Tensor a = at::ones({8, 8}, options) * 5.f; + at::Tensor b = at::ones({8, 8}, options) * 4.f; + at::Tensor t = at::matmul(a, b); + at::Tensor ref = t + 3.f; + + at::Tensor nnc_result; + std::vector result_buf(8 * 8); + +#ifdef TORCH_ENABLE_LLVM + LLVMCodeGen llvm_codegen(l.root_stmt(), {Result}); + + llvm_codegen.call({result_buf}); + nnc_result = at::from_blob(result_buf.data(), {8, 8}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +#endif + + SimpleIREvaluator ir_eval(l.root_stmt(), {Result}); + + ir_eval.call({result_buf}); + nnc_result = at::from_blob(result_buf.data(), {8, 8}, options); + ASSERT_TRUE(at::allclose(nnc_result, ref)); +} + +TEST(ExternalCall, JitCustomFusionOp) { + const char* custom_op_schema_literal = + "nnc_custom::add_mul(Tensor a, Tensor b, Tensor c) -> Tensor"; + const char* external_func_name = "nnc_add_mul"; + + auto add_mul_lowering_func = + [external_func_name]( + const std::vector& inputs, + const std::vector& output_shape, + const std::vector& output_strides, + const std::optional& output_type, + at::Device device) { + auto output_dtype = Dtype(*output_type); + torch::jit::tensorexpr::BufHandle result_buf( + "nnc_add_mul_res_buf", output_shape, output_dtype); + const torch::jit::tensorexpr::BufHandle& a = + std::get(inputs[0]); + const torch::jit::tensorexpr::BufHandle& b = + std::get(inputs[1]); + const torch::jit::tensorexpr::BufHandle& c = + std::get(inputs[1]); + torch::jit::tensorexpr::StmtPtr s = + torch::jit::tensorexpr::ExternalCall::make( + result_buf, external_func_name, {a, b, c}, {}); + return Tensor(result_buf.node(), s); + }; + + auto add_mul_external_func = [](int64_t bufs_num, + void** buf_data, + int64_t* buf_ranks, + int64_t* buf_dims, + int64_t* buf_strides, + int8_t* buf_dtypes, + int64_t args_num, + int64_t* extra_args) {}; + + torch::jit::RegisterOperators reg({Operator( + custom_op_schema_literal, + [](const Node* node) -> Operation { + return [](Stack& _stack) { + auto a = std::move(peek(_stack, 0, 3)).toTensor(); + auto b = std::move(peek(_stack, 1, 3)).toTensor(); + auto c = std::move(peek(_stack, 2, 3)).toTensor(); + drop(_stack, 3); + auto result = (a + b) * c; + pack(_stack, std::move(result)); + return 0; + }; + }, + c10::AliasAnalysisKind::FROM_SCHEMA)}); + + auto& custom_operator_set = torch::jit::tensorexpr::getCustomOperatorSet(); + custom_operator_set.insert({custom_op_schema_literal}); + + auto& te_lowering_registry = torch::jit::tensorexpr::getNNCLoweringRegistry(); + te_lowering_registry.insert( + parseSchema(custom_op_schema_literal), add_mul_lowering_func); + + auto& te_nnc_func_registry = torch::jit::tensorexpr::getNNCFunctionRegistry(); + te_nnc_func_registry[external_func_name] = add_mul_external_func; + + std::string graph_string = R"IR( + graph(%a : Float(10, 20, strides=[20, 1], device=cpu), + %b : Float(10, 20, strides=[20, 1], device=cpu), + %c : Float(10, 20, strides=[20, 1], device=cpu)): + %res : Float(10, 20, strides=[20, 1], device=cpu) = nnc_custom::add_mul(%a, %b, %c) + return (%res))IR"; + + auto graph = std::make_shared(); + torch::jit::parseIR(graph_string, graph.get()); + + std::string shape_compute_python_string = R"PY( + def computOutput(a: List[int], b: List[int], c: List[int]): + expandedSizes: List[int] = [] + dimsA = len(a) + dimsB = len(b) + dimsC = len(c) + ndim = max(dimsA, dimsB, dimsC) + for i in range(ndim): + offset = ndim - 1 - i + dimA = dimsA - 1 - offset + dimB = dimsB - 1 - offset + dimC = dimsC - 1 - offset + sizeA = a[dimA] if (dimA >= 0) else 1 + sizeB = b[dimB] if (dimB >= 0) else 1 + sizeC = a[dimC] if (dimC >= 0) else 1 + + if sizeA != sizeB and sizeB != sizeC and sizeA != 1 and sizeB != 1 and sizeC != 1: + # TODO: only assertion error is bound in C++ compilation right now + raise AssertionError( + "The size of tensor a {} must match the size of tensor b (" + "{} and c {}) at non-singleton dimension {}".format(sizeA, sizeB, sizeC, i) + ) + + expandedSizes.append(max(sizeA, sizeB, sizeC)) + + return expandedSizes + )PY"; + auto cu_ptr = torch::jit::compile(shape_compute_python_string); + torch::jit::GraphFunction* gf = + (torch::jit::GraphFunction*)&cu_ptr->get_function("computOutput"); + ASSERT_TRUE(gf); + +#ifdef TORCH_ENABLE_LLVM + auto static_graph_case = graph->copy(); + FuseTensorExprs(static_graph_case, 1); + torch::jit::testing::FileCheck() + .check("prim::TensorExprGroup_") + ->check("nnc_custom::add_mul") + ->run(*static_graph_case); + + auto dynamic_graph_case = graph->copy(); + auto custom_op = torch::jit::getOperatorForLiteral(custom_op_schema_literal); + ASSERT_TRUE(custom_op); + torch::jit::RegisterShapeComputeGraphForSchema( + custom_op->schema(), gf->graph()); + FuseTensorExprs(dynamic_graph_case, 1, false, true); + torch::jit::testing::FileCheck() + .check("prim::TensorExprGroup_") + ->check("nnc_custom::add_mul") + ->run(*dynamic_graph_case); +#else + torch::jit::testing::FileCheck().check("nnc_custom::add_mul")->run(*graph); +#endif +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_graph_opt.cpp b/test/cpp/tensorexpr/test_graph_opt.cpp new file mode 100644 index 0000000000000..aed73d09d14d5 --- /dev/null +++ b/test/cpp/tensorexpr/test_graph_opt.cpp @@ -0,0 +1,319 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace torch { +namespace jit { + +using namespace torch::jit::tensorexpr; + +class GraphOpt : public ::testing::Test { + public: + void SetUp() override { + old_cat_wo_conditionals_ = getCatWoConditionals(); + getCatWoConditionals() = true; + } + + void TearDown() override { + getCatWoConditionals() = old_cat_wo_conditionals_; + } + + private: + bool old_cat_wo_conditionals_; +}; + +TEST_F(GraphOpt, OptimizeCat) { +#ifdef TORCH_ENABLE_LLVM + const auto graph_string = R"IR( + graph(%x : Float(10, strides=[1], device=cpu), + %y : Float(20, strides=[1], device=cpu), + %z : Float(30, strides=[1], device=cpu)): + %dim : int = prim::Constant[value=0]() + %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) + %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) + %5 : Float(60, strides=[1], device=cpu) = aten::log(%cat) + return (%5))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + g->lint(); + + TensorExprKernel kernel(g); + + // The `aten::log` op must be moved to the inputs of `aten::cat`. + testing::FileCheck() + .check("aten::log") + ->check("aten::log") + ->check("aten::log") + ->check("aten::cat") + ->check_not("aten::log") + ->run(*kernel.graph()); + + auto x = at::rand({10}, at::kFloat); + auto y = at::rand({20}, at::kFloat); + auto z = at::rand({30}, at::kFloat); + auto ref = at::log(at::cat({x, y, z}, 0)); + + std::vector inputs = {x, y, z}; + std::vector stack = fmap(inputs); + kernel.run(stack); + auto out = stack[0].toTensor(); + ASSERT_EQ(out.sizes(), ref.sizes()); + ASSERT_EQ(out.dtype(), ref.dtype()); + ASSERT_TRUE(at::allclose(out, ref)); +#endif +} + +TEST_F(GraphOpt, OptimizeCat2) { +#ifdef TORCH_ENABLE_LLVM + const auto graph_string = R"IR( + graph(%x : Float(10, strides=[1], device=cpu), + %y : Float(20, strides=[1], device=cpu), + %z : Float(30, strides=[1], device=cpu)): + %dim : int = prim::Constant[value=0]() + %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) + %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) + %5 : Float(60, strides=[1], device=cpu) = aten::log(%cat) + %6 : Float(60, strides=[1], device=cpu) = aten::tanh(%5) + return (%6))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + g->lint(); + + TensorExprKernel kernel(g); + + // The `aten::log` and `aten::tanh` ops must be moved to the inputs of + // `aten::cat`. + testing::FileCheck() + .check("aten::log") + ->check("aten::log") + ->check("aten::log") + ->check("aten::tanh") + ->check("aten::tanh") + ->check("aten::tanh") + ->check("aten::cat") + ->check_not("aten::log") + ->check_not("aten::tanh") + ->run(*kernel.graph()); + + auto x = at::rand({10}, at::kFloat); + auto y = at::rand({20}, at::kFloat); + auto z = at::rand({30}, at::kFloat); + auto ref = at::tanh(at::log(at::cat({x, y, z}, 0))); + + std::vector inputs = {x, y, z}; + std::vector stack = fmap(inputs); + kernel.run(stack); + auto out = stack[0].toTensor(); + ASSERT_EQ(out.sizes(), ref.sizes()); + ASSERT_EQ(out.dtype(), ref.dtype()); + ASSERT_TRUE(at::allclose(out, ref)); +#endif +} + +TEST_F(GraphOpt, OptimizeCat3) { +#ifdef TORCH_ENABLE_LLVM + const auto graph_string = R"IR( + graph(%a : Float(60, strides=[1], device=cpu), + %x : Float(10, strides=[1], device=cpu), + %y : Float(20, strides=[1], device=cpu), + %z : Float(30, strides=[1], device=cpu)): + %dim : int = prim::Constant[value=0]() + %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) + %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) + %5 : Float(60, strides=[1], device=cpu) = aten::tanh(%cat) + %6 : Float(60, strides=[1], device=cpu) = aten::mul(%a, %5) + return (%6))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + g->lint(); + + TensorExprKernel kernel(g); + + // The `aten::tanh` op must be moved to the inputs of `aten::cat`. + // But the `aten::mul` op must not be moved since it is not a single-tensor + // op (it has 2 tensor inputs). + testing::FileCheck() + .check("aten::tanh") + ->check("aten::tanh") + ->check("aten::tanh") + ->check("aten::cat") + ->check("aten::mul") + ->check_not("aten::tanh") + ->run(*kernel.graph()); + + auto a = at::rand({60}, at::kFloat); + auto x = at::rand({10}, at::kFloat); + auto y = at::rand({20}, at::kFloat); + auto z = at::rand({30}, at::kFloat); + auto ref = at::tanh(at::cat({x, y, z}, 0)) * a; + + std::vector inputs = {a, x, y, z}; + std::vector stack = fmap(inputs); + kernel.run(stack); + auto out = stack[0].toTensor(); + ASSERT_EQ(out.sizes(), ref.sizes()); + ASSERT_EQ(out.dtype(), ref.dtype()); + ASSERT_TRUE(at::allclose(out, ref)); +#endif +} + +TEST_F(GraphOpt, OptimizeCatWithTypePromotionInUser) { +#ifdef TORCH_ENABLE_LLVM + const auto graph_string = R"IR( + graph(%x : Int(10, strides=[1], device=cpu), + %y : Int(20, strides=[1], device=cpu), + %z : Int(30, strides=[1], device=cpu)): + %dim : int = prim::Constant[value=0]() + %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) + %cat : Int(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) + %5 : Float(60, strides=[1], device=cpu) = aten::tanh(%cat) + return (%5))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + g->lint(); + + TensorExprKernel kernel(g); + + // The `aten::tanh` op must be moved to the inputs of `aten::cat`. + // The scalar type of the inputs to `cat` should now be `Float` since they + // are the result of `tanh` which does the type promotion. + testing::FileCheck() + .check("aten::tanh") + ->check("aten::tanh") + ->check("aten::tanh") + ->check("aten::cat") + ->check_not("aten::tanh") + ->run(*kernel.graph()); + + auto x = at::randint(std::numeric_limits::max(), {10}, at::kInt); + auto y = at::randint(std::numeric_limits::max(), {20}, at::kInt); + auto z = at::randint(std::numeric_limits::max(), {30}, at::kInt); + auto ref = at::tanh(at::cat({x, y, z}, 0)); + + std::vector inputs = {x, y, z}; + std::vector stack = fmap(inputs); + kernel.run(stack); + auto out = stack[0].toTensor(); + ASSERT_EQ(out.sizes(), ref.sizes()); + ASSERT_EQ(out.dtype(), ref.dtype()); + ASSERT_TRUE(at::allclose(out, ref)); +#endif +} + +TEST_F(GraphOpt, OptimizeCatWithTypePromotionInCat) { +#ifdef TORCH_ENABLE_LLVM + const auto graph_string = R"IR( + graph(%x : Float(10, strides=[1], device=cpu), + %y : Float(20, strides=[1], device=cpu), + %z : Double(30, strides=[1], device=cpu)): + %dim : int = prim::Constant[value=0]() + %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) + %cat : Double(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) + %5 : Double(60, strides=[1], device=cpu) = aten::log(%cat) + return (%5))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + g->lint(); + + TensorExprKernel kernel(g); + + // No transformation should have happened because the `aten::cat` op performs + // type promotion. This case is currently not handled. + testing::FileCheck() + .check("aten::cat") + ->check("aten::log") + ->check_not("aten::cat") + ->check_not("aten::log") + ->run(*kernel.graph()); +#endif +} + +TEST_F(GraphOpt, OptimizeCatNoSingleTensorElementwiseOp) { +#ifdef TORCH_ENABLE_LLVM + const auto graph_string = R"IR( + graph(%0 : Float(60, strides=[1], device=cpu), + %x : Float(10, strides=[1], device=cpu), + %y : Float(20, strides=[1], device=cpu), + %z : Float(30, strides=[1], device=cpu)): + %dim : int = prim::Constant[value=0]() + %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) + %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) + %5 : Float(60, strides=[1], device=cpu) = aten::mul(%0, %cat) + return (%5))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + g->lint(); + + TensorExprKernel kernel(g); + + // No transformation is expected since the consumers of cat are not + // single-tensor element-wise ops. + testing::FileCheck() + .check("aten::cat") + ->check("aten::mul") + ->check_not("aten::cat") + ->check_not("aten::mul") + ->run(*kernel.graph()); +#endif +} + +TEST_F(GraphOpt, OptimizeCatNoSingleTensorElementwiseOp2) { +#ifdef TORCH_ENABLE_LLVM + const auto graph_string = R"IR( + graph(%0 : Float(60, strides=[1], device=cpu), + %1 : Float(60, strides=[1], device=cpu), + %x : Float(10, strides=[1], device=cpu), + %y : Float(20, strides=[1], device=cpu), + %z : Float(30, strides=[1], device=cpu)): + %one : int = prim::Constant[value=1]() + %dim : int = prim::Constant[value=0]() + %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) + %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) + %5 : Float(60, strides=[1], device=cpu) = aten::mul(%0, %cat) + %6 : Float(60, strides=[1], device=cpu) = aten::add(%5, %1, %one) + return (%6))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + g->lint(); + + TensorExprKernel kernel(g); + + // No transformation is expected since the consumers of cat are not + // single-tensor element-wise ops. + testing::FileCheck() + .check("aten::cat") + ->check("aten::mul") + ->check("aten::add") + ->check_not("aten::cat") + ->check_not("aten::mul") + ->check_not("aten::add") + ->run(*kernel.graph()); +#endif +} + +TEST_F(GraphOpt, AOTGraphPrepPasses) { + const auto graph_string = R"IR( + graph(%x, %y, %z, %i : int): + %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) + return (%xyz_list, %i))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + + removeGraphOutput(g, 1); + replaceListOutputWithTuple(g); + LowerAllTuples(g); + + testing::FileCheck().check("return (%x, %y, %z)")->run(*g); +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_ir_printer.cpp b/test/cpp/tensorexpr/test_ir_printer.cpp new file mode 100644 index 0000000000000..4d2f8c6e906ee --- /dev/null +++ b/test/cpp/tensorexpr/test_ir_printer.cpp @@ -0,0 +1,98 @@ +#include + +#include +#include "test/cpp/tensorexpr/test_base.h" + +#include +#include +#include +#include +#include +#include + +#include +namespace torch { +namespace jit { + +using namespace torch::jit::tensorexpr; + +TEST(IRPrinter, BasicValueTest) { + ExprHandle a = IntImm::make(2), b = IntImm::make(3); + ExprHandle c = Add::make(a, b); + + std::stringstream ss; + ss << c; + ASSERT_EQ(ss.str(), "2 + 3"); +} + +TEST(IRPrinter, BasicValueTest02) { + ExprHandle a(2.0f); + ExprHandle b(3.0f); + ExprHandle c(4.0f); + ExprHandle d(5.0f); + ExprHandle f = (a + b) - (c + d); + + std::stringstream ss; + ss << f; + ASSERT_EQ(ss.str(), "(2.f + 3.f) - (4.f + 5.f)"); +} + +TEST(IRPrinter, BasicValueTest03) { + ExprHandle a(3.402823466385289e+38f); + ExprHandle b(-3.402823466385289e+38f); + std::stringstream ss; + ss << a << ", " << b; + ASSERT_EQ(ss.str(), "3.402823466385289e+38f, -3.402823466385289e+38f"); +} + +TEST(IRPrinter, CastTest) { + VarHandle x("x", kHalf); + VarHandle y("y", kFloat); + ExprHandle body = ExprHandle(2.f) + + (Cast::make(kFloat, x) * ExprHandle(3.f) + ExprHandle(4.f) * y); + + std::stringstream ss; + ss << body; + ASSERT_EQ(ss.str(), "2.f + (float(x) * 3.f + 4.f * y)"); +} + +TEST(IRPrinter, FunctionName) { + int M = 4; + int N = 20; + + Tensor producer = Compute( + "producer", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return m * n; + }); + + Tensor chunk_0 = Compute( + "chunk_0", {M, N / 2}, [&](const ExprHandle& m, const ExprHandle& n) { + return producer.load(m, n); + }); + + Tensor chunk_1 = Compute( + "chunk_1", {M, N / 2}, [&](const ExprHandle& m, const ExprHandle& n) { + return producer.load(m, n + ExprHandle(N / 2)); + }); + + Tensor consumer = Compute( + "consumer", {M, N / 2}, [&](const ExprHandle& i, const ExprHandle& j) { + return i * chunk_1.load(i, j); + }); + + LoopNest l({chunk_0, chunk_1, consumer}); + auto body = LoopNest::sanitizeNames(l.root_stmt()); + + std::stringstream ss; + ss << *body; + + const std::string& verification_pattern = + R"IR( + # CHECK: for (int i_2 + # CHECK: for (int j_2 + # CHECK: consumer[i_2, j_2] = i_2 * (chunk_1[i_2, j_2])IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, ss.str()); +} +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_ir_verifier.cpp b/test/cpp/tensorexpr/test_ir_verifier.cpp new file mode 100644 index 0000000000000..886213ea9c760 --- /dev/null +++ b/test/cpp/tensorexpr/test_ir_verifier.cpp @@ -0,0 +1,191 @@ +#include + +#include +#include "test/cpp/tensorexpr/test_base.h" + +#include +#include +#include +#include +#include +#include + +#include +namespace torch { +namespace jit { + +using namespace torch::jit::tensorexpr; + +TEST(IRVerifier, BitwiseOps) { + VarPtr X = alloc("x", kInt); + VarPtr Y = alloc("y", kFloat); + { + auto a = alloc(X, Y); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_ANY_THROW(verify(a)); + } + { + auto a = alloc(X, Y); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_ANY_THROW(verify(a)); + } + { + auto a = alloc(X, Y); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_ANY_THROW(verify(a)); + } + { + auto a = alloc(X, Y); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_ANY_THROW(verify(a)); + } + { + auto a = alloc(X, Y); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_ANY_THROW(verify(a)); + } +} + +TEST(IRVerifier, CompareSelect) { + ExprPtr X = alloc(1); + ExprPtr Y = alloc(3.14f); + { + auto a = alloc(X, X, X, Y, kEQ); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_ANY_THROW(verify(a)); + } + { + auto a = alloc(X, Y, X, X, kEQ); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_ANY_THROW(verify(a)); + } +} + +TEST(IRVerifier, Ramp) { + VarPtr I = alloc("i", kInt); + VarPtr J = alloc("j", kFloat); + { + auto a = alloc(I, J, 4); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_ANY_THROW(verify(a)); + } +} + +TEST(IRVerifier, Load) { + VarPtr I = alloc("i", kInt); + VarPtr J = alloc("j", kLong); + VarPtr K = alloc("k", kFloat); + BufPtr B = alloc( + "b", + std::vector({alloc(10), alloc(20)}), + kFloat); + { + // Indices with different int dtypes (kInt, kLong) are ok + auto a = alloc(B, std::vector({I, J})); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_NO_THROW(verify(a)); + } + { + // Float index + auto a = alloc(B, std::vector({K, K})); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_ANY_THROW(verify(a)); + } + { + // Multilanes are only allowed in flattened indices + auto multilane_index = alloc(I, alloc(1), 4); + auto a = alloc(B, std::vector({I, multilane_index})); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_ANY_THROW(verify(a)); + } +} + +TEST(IRVerifier, IfThenElse) { + VarPtr I = alloc("i", kInt); + VarPtr J = alloc("j", kLong); + VarPtr K = alloc("k", kFloat); + { + // Condition must be integral + auto a = alloc(K, I, I); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_ANY_THROW(verify(a)); + } + { + // Dtypes of true and false exprs must match + auto a = alloc(I, I, J); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_ANY_THROW(verify(a)); + } + { + // Can't have multiple lanes in condition expr + auto a = alloc(alloc(I, 4), I, I); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_ANY_THROW(verify(a)); + } +} + +TEST(IRVerifier, For) { + VarPtr I = alloc("i", kInt); + VarPtr J = alloc("j", kInt); + StmtPtr body = alloc(std::vector({})); + { + // Can't have nullptr as a Var + auto a = alloc(nullptr, I, J, body); + // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) + EXPECT_ANY_THROW(verify(a)); + } +} + +TEST(IRVerifier, Block) { + VarPtr I = alloc("i", kInt); + BufPtr B = alloc("B", std::vector({alloc(10)}), kInt); + { + StmtPtr store = alloc(B, std::vector({I}), I); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + StmtPtr block1 = alloc(std::vector({store})); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + StmtPtr block2 = alloc(std::vector({store})); + // Stmt can't have multiple parents, thus inserting it into several blocks + // is illegal + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_ANY_THROW(verify(block2)); + } +} + +TEST(IRVerifier, Store) { + VarPtr I = alloc("i", kInt); + VarPtr J = alloc("j", kLong); + VarPtr K = alloc("k", kFloat); + BufPtr B = alloc( + "b", + std::vector({alloc(10), alloc(20)}), + kFloat); + { + // Indices with different int dtypes (kInt, kLong) are ok + auto a = alloc(B, std::vector({I, J}), K); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_NO_THROW(verify(a)); + } + { + // Float index + auto a = alloc(B, std::vector({K, K}), K); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_ANY_THROW(verify(a)); + } + { + // Multilanes are only allowed in flattened indices + auto multilane_index = alloc(I, alloc(1), 4); + auto a = alloc(B, std::vector({I, multilane_index}), K); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_ANY_THROW(verify(a)); + } + { + // Value and buf dtypes mismatch + auto a = alloc(B, std::vector({I}), I); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks) + EXPECT_ANY_THROW(verify(a)); + } +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_kernel.cpp b/test/cpp/tensorexpr/test_kernel.cpp new file mode 100644 index 0000000000000..22f6b64efe1a8 --- /dev/null +++ b/test/cpp/tensorexpr/test_kernel.cpp @@ -0,0 +1,2133 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +using namespace torch::indexing; +using namespace torch::jit::tensorexpr; + +class Kernel : public ::testing::Test { + public: + void SetUp() override { + getTEMustUseLLVMOnCPU() = false; + } +}; + +TEST_F(Kernel, ParallelExternalCallBuf) { + const auto graph_string = R"IR( + graph(%0 : Float(1000, 5000, strides=[5000, 1], device=cpu), + %1 : Float(1000, 5000, strides=[5000, 1], device=cpu), + %2 : Float(5000, 1000, strides=[5000, 1], device=cpu)): + %3 : Float(1000, 5000, strides=[5000, 1], device=cpu) = aten::mul(%0, %1) + %4 : Float(1000, 5000, strides=[5000, 1], device=cpu) = aten::matmul(%3, %2) + return (%4))IR"; + auto graph = std::make_shared(); + torch::jit::parseIR(graph_string, &*graph); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int64_t i = 0ll; i < 5000ll; i++) /* parallel */{)IR"; + +#ifdef TORCH_ENABLE_LLVM + TensorExprKernel k(graph); + StmtPtr s = k.getCodeGenStmt(); + std::ostringstream oss; + oss << *s; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +#endif +} + +TEST_F(Kernel, InliningIntermediates) { + // here, each mul has only one use, so it should be completely inlined + { + const auto graph_string = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), + %1 : Float(5, 3, strides=[3, 1], device=cpu)): + %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) + %one : int = prim::Constant[value=1]() + %4 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) + %5: Float(5, 3, strides=[3, 1]) = aten::add(%4, %1, %one) + return (%5))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + TensorExprKernel k(graph); + auto stmt = k.getCodeGenStmt(); + std::ostringstream oss; + oss << *stmt; + torch::jit::testing::FileCheck().check_not("aten_mul")->run(oss.str()); + } + { + const auto graph_template = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=${device}), + %1 : Float(5, 3, strides=[3, 1], device=${device})): + %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) + %one : int = prim::Constant[value=1]() + %3 : Float(5, 3, strides=[3, 1]) = aten::sub(%0, %2, %one) + %4 : Float(5, 3, strides=[3, 1]) = aten::add(%3, %0, %one) + %5 : Float(5, 3, strides=[3, 1]) = aten::div(%3, %0) + return (%4, %5))IR"; + for (bool use_cuda : {false, true}) { + if (!torch::cuda::is_available() && use_cuda) { + continue; + } + + at::jit::TemplateEnv env; + env.s("device", use_cuda ? "cuda:0" : "cpu"); + const auto graph_string = format(graph_template, env); + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + TensorExprKernel k(graph); + auto stmt = k.getCodeGenStmt(); + std::ostringstream oss; + oss << *stmt; + // aten_mul only has one use, inlined completely + torch::jit::testing::FileCheck().check_not("aten_mul")->run(oss.str()); + + // aten_sub should be removed by the CUDA backend by metavar rewriting + // and by the CPU backend by horizontal fusion. + torch::jit::testing::FileCheck().check_not("aten_sub")->run(oss.str()); + } + } +} + +TEST_F(Kernel, PreAllocIntermediateBufs) { + const auto graph_string = R"IR( +graph(%a.1 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu), + %b.1 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu)): + %2 : int = prim::Constant[value=1]() + %c.2 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu) = aten::matmul(%a.1, %b.1) # test_matmul.py:12:12 + %3 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu) = aten::add(%a.1, %c.2, %2) # test_matmul.py:13:15 + return (%3))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto a = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat)); + auto o = at::zeros({8, 8}, TensorOptions(kCPU).dtype(at::kFloat)); + auto ref = at::matmul(a, b) + a; + TensorExprKernel k(graph, {}, {}, true); + + std::vector inputs = {a, b}; + auto stmt = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *stmt; + + // Check whether the intermediate buffer has been added to constants + auto constants = k.getConstantDescriptors(); + ASSERT_EQ(constants.size(), 1); + + // Check the IR we produced + torch::jit::testing::FileCheck().check_not("Alloc")->run(oss.str()); + torch::jit::testing::FileCheck().check_not("Free")->run(oss.str()); + + // Check correctness + std::vector stack = fmap(inputs); + k.run(stack); + o = stack[0].toTensor(); + ASSERT_TRUE(at::allclose(o, ref)); +} + +TEST_F(Kernel, _1) { + const auto graph_string = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), + %1 : Float(5, 3, strides=[3, 1], device=cpu)): + %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) + %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) + return (%3))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto ref = a * (a * b); + TensorExprKernel k(graph); + std::vector inputs = {a, b}; + StmtPtr s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + // Check the IR we produced + const std::string& verification_pattern = + R"IR( +# CHECK: for +# CHECK-NEXT: for +# CHECK-NOT: for)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + o = stack[0].toTensor(); + for (size_t i = 0; i < 5 * 3; i++) { + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + } +} + +TEST_F(Kernel, _2) { + const auto graph_string = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), + %1 : Float(5, 3, strides=[1, 5], device=cpu)): + %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) + %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) + return (%3))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = + at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1); + auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto ref = a * (a * b); + TensorExprKernel k(graph); + std::vector inputs = {a, b}; + StmtPtr s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + // Check the IR we produced + const std::string& verification_pattern = + R"IR( +# CHECK: for +# CHECK-NEXT: for +# CHECK-NOT: for)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + o = stack[0].toTensor(); + for (size_t i = 0; i < 5 * 3; i++) { + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + } +} + +TEST_F(Kernel, _3) { + const auto graph_string = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), + %1 : Float(5, 3, strides=[12, 2], device=cpu)): + %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) + %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) + return (%3))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({10, 6}, TensorOptions(kCPU).dtype(at::kFloat)) + .index({Slice(None, None, 2), Slice(None, None, 2)}); + auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto ref = a * (a * b); + TensorExprKernel k(graph); + std::vector inputs = {a, b}; + StmtPtr s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + // Check the IR we produced + const std::string& verification_pattern = + R"IR( +# CHECK: for +# CHECK-NEXT: for +# CHECK-NOT: for)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + o = stack[0].toTensor(); + for (size_t i = 0; i < 5 * 3; i++) { + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + } +} + +TEST_F(Kernel, Huge) { + const auto graph_string = R"IR( + graph(%x.1 : Float(4000000000, strides=[1], requires_grad=0, device=cpu)): + %1 : int = prim::Constant[value=0]() + %2 : Float(1, 4000000000, strides=[4000000000, 1], requires_grad=0, device=cpu) = aten::unsqueeze(%x.1, %1) + %3 : Float(1, 4000000000, strides=[4000000000, 1], requires_grad=0, device=cpu) = aten::relu(%2) + return (%3))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + TensorExprKernel k(graph); + std::ostringstream oss; + oss << *k.getCodeGenStmt(); + // The 4000000000 iterations loop will be split into 500000000 x 8 and the + // outer loop will be parallel. If LLVM is not present, it will not be split, + // and to cover both of these cases we're looking for 00000000ll; in the + // output. + const std::string& verification_pattern = R"IR(# CHECK: 00000000ll;)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST_F(Kernel, ParallelStrided) { + const auto graph_string = R"IR( + graph(%0 : Float(5, 3, 40005, strides=[120015, 40005, 1], device=cpu), + %1 : Float(5, 3, 40005, strides=[960120, 160020, 2], device=cpu)): + %2 : Float(5, 3, 40005, strides=[120015, 40005, 1]) = aten::mul(%0, %1) + %3 : Float(5, 3, 40005, strides=[120015, 40005, 1]) = aten::mul(%0, %2) + return (%3))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto a = at::rand({5, 3, 40005}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({10, 6, 80010}, TensorOptions(kCPU).dtype(at::kFloat)) + .index( + {Slice(None, None, 2), + Slice(None, None, 2), + Slice(None, None, 2)}); + auto ref = a * (a * b); + auto o = at::zeros_like(ref); + TensorExprKernel k(graph); + std::vector inputs = {a, b}; + std::vector stack = fmap(inputs); + k.run(stack); + o = stack[0].toTensor(); + for (size_t i = 0; i < 5 * 3; i++) { + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + } +} + +TEST_F(Kernel, DISABLED_Shape_Inference) { + // disabled: doesn't do stride propagation, and isn't being used currently + + // Test TensorExpr shape inference capabilities: it should only require shapes + // for the inputs + { + const auto graph_string = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), + %1 : Float(5, 3, strides=[12, 2], device=cpu)): + %2 : Tensor = aten::mul(%0, %1) + %3 : Tensor = aten::mul(%0, %2) + return (%3))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({10, 6}, TensorOptions(kCPU).dtype(at::kFloat)) + .index({Slice(None, None, 2), Slice(None, None, 2)}); + auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto ref = a * (a * b); + TensorExprKernel k(graph); + std::vector inputs = {a, b}; + StmtPtr s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + // Check the IR we produced + const std::string& verification_pattern = + R"IR( +# CHECK: for +# CHECK-NEXT: for +# CHECK-NOT: for)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + o = stack[0].toTensor(); + for (size_t i = 0; i < 5 * 3; i++) { + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + } + } + { + const auto graph_string = R"IR( + graph(%0 : Float(8, 8, strides=[8, 1], device=cpu), + %1 : Float(8, 8, strides=[8, 1], device=cpu)): + %2 : Tensor = aten::mul(%0, %1) + %3 : Tensor, %4 : Tensor = prim::ConstantChunk[dim=1,chunks=2](%2) + %r : Tensor = aten::mul(%3, %4) + return (%r))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto a = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat)); + auto o = at::zeros({8, 4}, TensorOptions(kCPU).dtype(at::kFloat)); + auto t = torch::chunk(a * b, 2, 1); + auto ref = t[0] * t[1]; + TensorExprKernel k(graph); + std::vector inputs = {a, b}; + StmtPtr s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + // Check the IR we produced + const std::string& verification_pattern = + R"IR( +# CHECK: for)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + o = stack[0].toTensor(); + TORCH_CHECK_EQ(o.sizes()[0], 8); + TORCH_CHECK_EQ(o.sizes()[1], 4); + for (size_t i = 0; i < 8 * 4; i++) { + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + } + } + { + // Test that shape inference handles aten::unsqueeze + + const auto graph_string = R"IR( + graph(%a : Float(4, 2, strides=[2, 1], device=cpu), + %b : Float(4, 3, 2, strides=[6, 2, 1], device=cpu), + %c : Float(3, 2, 2, strides=[4, 2, 1], device=cpu)): + %one : int = prim::Constant[value=1]() + %minus_one : int = prim::Constant[value=-1]() + %three : int = prim::Constant[value=3]() + %minus_four : int = prim::Constant[value=-4]() + %a1 : Tensor = aten::unsqueeze(%a, %one) # new size: [4,1,2] + %a2 : Tensor = aten::unsqueeze(%a1, %minus_one) # new size: [4,1,2,1] + %b1 : Tensor = aten::unsqueeze(%b, %three) # new size: [4,3,2,1] + %c1 : Tensor = aten::unsqueeze(%c, %minus_four) # new size: [1,3,2,2] + %ab : Tensor = aten::mul(%a2, %b1) # expected size: [4,3,2,1] + %abc : Tensor = aten::mul(%ab, %c1) # expected size: [4,3,2,2] + return (%abc))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto a = at::rand({4, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({4, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto c = at::rand({3, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto o = at::zeros({4, 3, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto ref = at::unsqueeze(at::unsqueeze(a, 1), -1) * at::unsqueeze(b, 3) * + at::unsqueeze(c, -4); + + TensorExprKernel k(graph); + std::vector inputs = {a, b, c}; + StmtPtr s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + // Check the IR we produced + const std::string& verification_pattern = + R"IR( +# CHECK: for +# CHECK-NEXT: for +# CHECK-NEXT: for +# CHECK-NEXT: for +# CHECK-NEXT: aten_mul)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + o = stack[0].toTensor(); + + // Check sizes + TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size()); + size_t num_el = 1; + for (const auto idx : c10::irange(ref.sizes().size())) { + TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); + num_el *= ref.sizes()[idx]; + } + + // Check the contents + for (const auto i : c10::irange(num_el)) { + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + } + } + { + // Test that shape inference handles aten::cat + + const auto graph_string = R"IR( + graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu), + %b : Float(5, 7, 2, strides=[14, 2, 1], device=cpu), + %c : Float(5, 9, 2, strides=[18, 2, 1], device=cpu)): + %dim : int = prim::Constant[value=1]() + %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) + %r : Tensor = aten::cat(%inputs, %dim) # new size: [5,19,2] + return (%r))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto a = at::rand({5, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({5, 7, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto c = at::rand({5, 9, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto o = at::zeros({5, 19, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto ref = at::cat({a, b, c}, 1); + + TensorExprKernel k(graph); + std::vector inputs = {a, b, c}; + StmtPtr s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + // Check the IR we produced + const std::string& verification_pattern = + R"IR( +# CHECK: for +# CHECK-NEXT: for +# CHECK-NEXT: for +# CHECK-NEXT: aten_cat)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + o = stack[0].toTensor(); + + // Check sizes + TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size()); + size_t num_el = 1; + for (const auto idx : c10::irange(ref.sizes().size())) { + TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); + num_el *= ref.sizes()[idx]; + } + + // Check the contents + for (const auto i : c10::irange(num_el)) { + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + } + } + { + // Test that we throw an error when input list for aten::cat is empty + + const auto graph_string = R"IR( + graph(): + %dim : int = prim::Constant[value=1]() + %inputs : Tensor[] = prim::ListConstruct() + %r : Tensor = aten::cat(%inputs, %dim) + return (%r))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + auto compile = [&]() { + TensorExprKernel k(graph); + k.getCodeGenStmt(); + }; + ASSERT_THROWS_WITH(compile(), "Empty input list is passed to aten::cat"); + } + { + // Test that we throw an error when 'dim' passed to aten::cat is invalid + + const auto ir_dim_99 = R"IR( + graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu), + %b : Float(5, 3, 2, strides=[6, 2, 1], device=cpu)): + %dim : int = prim::Constant[value=99]() + %inputs : Tensor[] = prim::ListConstruct(%a, %b) + %r : Float(5, 3, 2, strides=[6, 2, 1], device=cpu) = aten::cat(%inputs, %dim) + return (%r))IR"; + const auto ir_dim_minus_6 = R"IR( + graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu), + %b : Float(5, 3, 2, strides=[6, 2, 1], device=cpu)): + %dim : int = prim::Constant[value=-6]() + %inputs : Tensor[] = prim::ListConstruct(%a, %b) + %r : Float(5, 3, 2, strides=[6, 2, 1], device=cpu) = aten::cat(%inputs, %dim) + return (%r))IR"; + + auto compile = [](const std::string& graph_string) { + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + TensorExprKernel k(graph); + k.getCodeGenStmt(); + }; + ASSERT_THROWS_WITH(compile(ir_dim_99), "Invalid index"); + ASSERT_THROWS_WITH(compile(ir_dim_minus_6), "Invalid index"); + } +} + +TEST_F(Kernel, CatInputTypesPromotion) { + { + // Test that we properly promote input types for aten::cat + + const auto graph_string = R"IR( + graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu), + %b : Float(5, 7, 2, strides=[14, 2, 1], device=cpu), + %c : Double(5, 9, 2, strides=[18, 2, 1], device=cpu)): + %dim : int = prim::Constant[value=1]() + %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) + %r : Double(5, 19, 2, strides=[38, 2, 1]) = aten::cat(%inputs, %dim) + return (%r))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto a = at::rand({5, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({5, 7, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto c = at::rand({5, 9, 2}, TensorOptions(kCPU).dtype(at::kDouble)); + auto ref = at::cat({a, b, c}, 1); + + TensorExprKernel k(graph); + std::vector inputs = {a, b, c}; + StmtPtr s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + // Check the IR we produced + const std::string& verification_pattern = + R"IR( +# CHECK: for +# CHECK-NEXT: for +# CHECK-NEXT: for +# CHECK-NEXT: aten_cat)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + auto o = stack[0].toTensor(); + + // Check sizes + TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size()); + TORCH_CHECK_EQ(o.dtype(), ref.dtype()); + size_t num_el = 1; + for (const auto idx : c10::irange(ref.sizes().size())) { + TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); + num_el *= ref.sizes()[idx]; + } + + // Check the contents + for (const auto i : c10::irange(num_el)) { + TORCH_CHECK_EQ(((double*)o.data_ptr())[i], ((double*)ref.data_ptr())[i]); + } + } +} + +TEST_F(Kernel, ToDType) { +#ifdef TORCH_ENABLE_LLVM + const auto graph_string = R"IR( + graph(%x.1 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu)): + %1 : NoneType = prim::Constant() + %2 : bool = prim::Constant[value=0]() + %3 : int = prim::Constant[value=6]() + %4 : int = prim::Constant[value=15]() + %5 : int = prim::Constant[value=5]() + %6 : bool = prim::Constant[value=1]() + %y.3 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::sigmoid(%x.1) + %z.3 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::_autocast_to_reduced_precision(%y.3, %6, %6, %5, %4) + %h.3 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::_autocast_to_full_precision(%z.3, %6, %6) + %i.3 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::to(%h.3, %3, %2, %2, %1) + %j.3 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::to(%i.3, %4, %2, %2, %1) + %k.3 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::to(%j.3, %3, %2, %2, %1) + return (%k.3))IR"; + + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + TensorExprKernel k(graph); + StmtPtr s = k.getCodeGenStmt(); + std::ostringstream oss; + oss << *s; + + const std::string& verification_pattern = + R"IR( +# CHECK: for +# CHECK-NEXT: for +# CHECK-NEXT: aten_to +# CHECK-NEXT: } +# CHECK-NEXT: })IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto a = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kBFloat16)); + auto ref = + at::_to_copy(at::sigmoid(a), TensorOptions(kCPU).dtype(at::kFloat)); + + std::vector inputs = {a}; + std::vector stack = fmap(inputs); + k.run(stack); + auto o = stack[0].toTensor(); + ASSERT_EQ(o.sizes(), ref.sizes()); + ASSERT_EQ(o.dtype(), ref.dtype()); + ASSERT_TRUE(at::allclose(o, ref, 4E-3, 4E-3)); +#endif +} + +TEST_F(Kernel, CatAndInlineWithAConstantDim) { + const auto graph_string = R"IR( + graph(%0 : Float(1, 512, strides=[1024, 1], requires_grad=0, device=cpu), + %1 : Float(1, 512, strides=[1024, 1], requires_grad=0, device=cpu)): + %2 : bool = prim::Constant[value=0]() + %3 : int = prim::Constant[value=1]() + %4 : Tensor[] = prim::ListConstruct(%0, %1) + %5 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::cat(%4, %3) + %6 : Tensor[] = prim::ListConstruct(%5) + %7 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::cat(%6, %3) + %8 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::_cast_Float(%7, %2) + return (%8, %7))IR"; + + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + TensorExprKernel k(graph); + + auto a = at::rand({1, 512}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({1, 512}, TensorOptions(kCPU).dtype(at::kFloat)); + auto ref = at::_cast_Float(at::cat({a, b}, 1), 0); + + std::vector inputs = {a, b}; + std::vector stack = fmap(inputs); + k.run(stack); + auto o = stack[0].toTensor(); + ASSERT_EQ(o.sizes(), ref.sizes()); + ASSERT_EQ(o.dtype(), ref.dtype()); + ASSERT_TRUE(at::allclose(o, ref)); +} + +TEST_F(Kernel, CatWithEmptyInputs) { + bool curr_cat_wo_conditionals = getCatWoConditionals(); + for (auto cat_wo_conditionals : {true, false}) { + getCatWoConditionals() = cat_wo_conditionals; + const auto graph_string = R"IR( + graph(%0 : Float(0, 64, strides=[64, 1], requires_grad=0, device=cpu), + %1 : Float(10, 64, strides=[64, 1], requires_grad=0, device=cpu)): + %3 : int = prim::Constant[value=0]() + %6 : Float(0, 64, strides=[64, 1], requires_grad=0, device=cpu) = aten::tanh(%0) + %7 : Float(10, 64, strides=[64, 1], requires_grad=0, device=cpu) = aten::tanh(%1) + %10 : Tensor[] = prim::ListConstruct(%6, %7) + %11 : Float(10, 64, strides=[64, 1], requires_grad=0, device=cpu) = aten::cat(%10, %3) + return (%11))IR"; + + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + TensorExprKernel k(graph); + + auto a = at::rand({0, 64}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({10, 64}, TensorOptions(kCPU).dtype(at::kFloat)); + auto ref = at::cat({at::tanh(a), at::tanh(b)}, 0); + + std::vector inputs = {a, b}; + std::vector stack = fmap(inputs); + k.run(stack); + auto o = stack[0].toTensor(); + ASSERT_EQ(o.sizes(), ref.sizes()); + ASSERT_EQ(o.dtype(), ref.dtype()); + ASSERT_TRUE(at::allclose(o, ref)); + } + getCatWoConditionals() = curr_cat_wo_conditionals; +} + +TEST_F(Kernel, CatWoConditionals) { + bool old_cat_wo_conditionals = getCatWoConditionals(); + getCatWoConditionals() = true; + const auto graph_string = R"IR( + graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu), + %b : Float(5, 7, 2, strides=[14, 2, 1], device=cpu), + %c : Float(5, 9, 2, strides=[18, 2, 1], device=cpu)): + %dim : int = prim::Constant[value=1]() + %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) + %r : Float(5, 19, 2, strides=[38, 2, 1]) = aten::cat(%inputs, %dim) + return (%r))IR"; + + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + TensorExprKernel k(graph); + StmtPtr s = k.getCodeGenStmt(); + std::ostringstream oss; + oss << *s; + + const std::string& verification_pattern = + R"IR( +# CHECK: for +# CHECK: for +# CHECK: for +# CHECK: aten_cat +# CHECK: for +# CHECK: for +# CHECK: aten_cat +# CHECK: for +# CHECK: for +# CHECK: aten_cat)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto a = at::rand({5, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({5, 7, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto c = at::rand({5, 9, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto ref = at::cat({a, b, c}, 1); + + std::vector inputs = {a, b, c}; + std::vector stack = fmap(inputs); + k.run(stack); + auto o = stack[0].toTensor(); + + // Check sizes + TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size()); + TORCH_CHECK_EQ(o.dtype(), ref.dtype()); + size_t num_el = 1; + for (const auto idx : c10::irange(ref.sizes().size())) { + TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); + num_el *= ref.sizes()[idx]; + } + + // Check the contents + for (const auto i : c10::irange(num_el)) { + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + } + getCatWoConditionals() = old_cat_wo_conditionals; +} + +TEST_F(Kernel, OptimizeConditionals) { + bool old_cat_wo_conditionals = getCatWoConditionals(); + bool old_opt_conditionals = getOptConditionals(); + getCatWoConditionals() = false; + getOptConditionals() = true; + const auto graph_string = R"IR( + graph(%a : Float(5, 3, strides=[3, 1], device=cpu), + %b : Float(5, 7, strides=[7, 1], device=cpu), + %c : Float(5, 9, strides=[9, 1], device=cpu)): + %dim : int = prim::Constant[value=1]() + %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) + %r : Float(5, 19, strides=[19, 1]) = aten::cat(%inputs, %dim) + %t : Float(5, 19, strides=[19, 1]) = aten::relu(%r) + return (%t))IR"; + + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + TensorExprKernel k(graph); + StmtPtr s = k.getCodeGenStmt(); + std::ostringstream oss; + oss << *s; + + const std::string& verification_pattern = + R"IR( +# CHECK: for +# CHECK-NEXT: for +# CHECK-NEXT: aten_relu +# CHECK: for +# CHECK-NEXT: aten_relu +# CHECK: for +# CHECK-NEXT: aten_relu +# CHECK-NOT: Allocate +# CHECK-NOT: Free)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto b = at::rand({5, 7}, TensorOptions(kCPU).dtype(at::kFloat)); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto c = at::rand({5, 9}, TensorOptions(kCPU).dtype(at::kFloat)); + auto ref = at::relu(at::cat({a, b, c}, 1)); + + std::vector inputs = {a, b, c}; + std::vector stack = fmap(inputs); + k.run(stack); + auto o = stack[0].toTensor(); + + // Check sizes + TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size()); + TORCH_CHECK_EQ(o.dtype(), ref.dtype()); + size_t num_el = 1; + for (const auto idx : c10::irange(ref.sizes().size())) { + TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]); + num_el *= ref.sizes()[idx]; + } + + // Check the contents + for (const auto i : c10::irange(num_el)) { + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + } + getOptConditionals() = old_opt_conditionals; + getCatWoConditionals() = old_cat_wo_conditionals; +} + +namespace { + +std::string dtypeConstant(ScalarType scalar_type) { + if (scalar_type == ScalarType::Undefined) { + return "None = prim::Constant()"; + } else { + at::jit::TemplateEnv env_dtype; + env_dtype.d("scalar_type", static_cast(scalar_type)); + return format("int = prim::Constant[value=${scalar_type}]()", env_dtype); + } +} + +at::Tensor iotaTensor(IntArrayRef sizes, const at::TensorOptions& options) { + int64_t numel = std::accumulate( + sizes.begin(), + sizes.end(), + 1, + // NOLINTNEXTLINE(modernize-use-transparent-functors) + std::multiplies()); + std::vector values(numel); + std::iota(values.begin(), values.end(), 0); + auto a = at::tensor(values, options); + return a.reshape(sizes); +} + +} // namespace + +TEST_F(Kernel, SumAllAxes) { + // Test lowering of sum on all axes. + const auto graph_template = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)): + %1 : ${dtype} + %2 : ${out_dtype}(requires_grad=0, device=cpu) = aten::sum(%0, %1) + return (%2))IR"; + auto a = iotaTensor({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + + for (auto scalar_type : {ScalarType::Undefined, ScalarType::Double}) { + at::jit::TemplateEnv env; + env.s("dtype", dtypeConstant(scalar_type)); + if (scalar_type == ScalarType::Undefined) { + env.s("out_dtype", "Float"); + } else { + env.s("out_dtype", "Double"); + } + const auto graph_string = format(graph_template, env); + + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto o = at::empty({}, TensorOptions(kCPU)); + std::optional dtype; + if (scalar_type != ScalarType::Undefined) { + dtype = static_cast(scalar_type); + } + auto ref = a.sum(/*dtype=*/dtype); + TensorExprKernel k(graph); + std::vector inputs = {a}; + StmtPtr s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + // Check the IR we produced + const std::string& verification_pattern = + R"IR( +# CHECK: for +# CHECK-NEXT: for)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + o = stack[0].toTensor(); + ASSERT_EQ(o.sizes(), ref.sizes()); + ASSERT_EQ(o.dtype(), ref.dtype()); + ASSERT_TRUE(at::allclose(o, ref)); + } +} + +std::string li_to_str(at::ArrayRef li) { + std::stringstream out; + bool first = true; + for (auto elem : li) { + if (!first) { + out << ", "; + } + out << elem; + first = false; + } + return out.str(); +} + +TEST_F(Kernel, SumOneAxis) { + // Test lowering of sum on one axis. + const auto graph_template = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)): + %1 : int[] = prim::Constant[value=[${dim}]]() + %2 : bool = prim::Constant[value=${keepdim}]() + %3 : ${dtype} + %4 : ${out_dtype}(${size}, strides=[${strides}], device=cpu) = aten::sum(%0, %1, %2, %3) + return (%4))IR"; + auto a = iotaTensor({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + + for (int dim = -a.dim(); dim < a.dim(); ++dim) { + for (bool keepdim : {false, true}) { + for (auto scalar_type : {ScalarType::Undefined, ScalarType::Double}) { + at::jit::TemplateEnv env; + env.d("dim", dim); + env.d("keepdim", keepdim); + env.s("dtype", dtypeConstant(scalar_type)); + std::optional dtype; + if (scalar_type != ScalarType::Undefined) { + dtype = static_cast(scalar_type); + } + auto ref = a.sum({dim}, /*keepdim=*/keepdim, /*dtype=*/dtype); + if (scalar_type == ScalarType::Undefined) { + env.s("out_dtype", "Float"); + } else { + env.s("out_dtype", "Double"); + } + env.s("size", li_to_str(ref.sizes())); + env.s("strides", li_to_str(ref.strides())); + const auto graph_string = format(graph_template, env); + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto o = at::empty({}, TensorOptions(kCPU)); + TensorExprKernel k(graph); + std::vector inputs = {a}; + StmtPtr s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + // Check the IR we produced + const std::string& verification_pattern = + R"IR( +# CHECK: for (int64_t +# CHECK-NEXT: sum +# CHECK-NEXT: for (int64_t +# CHECK-NEXT: sum)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + o = stack[0].toTensor(); + ASSERT_EQ(o.sizes(), ref.sizes()); + ASSERT_EQ(o.dtype(), ref.dtype()); + ASSERT_TRUE(at::allclose(o, ref, 4E-3, 4E-3)); + } + } + } +} + +TEST_F(Kernel, SumMultipleAxes) { + // Test lowering of sum on multiple axes. + const auto graph_template = R"IR( + graph(%0 : Float(2, 3, 2, 3, strides=[18, 6, 3, 1], requires_grad=0, device=cpu)): + %1 : int = prim::Constant[value=${dim1}]() + %2 : int = prim::Constant[value=${dim2}]() + %3 : int[] = prim::ListConstruct(%1, %2) + %4 : bool = prim::Constant[value=${keepdim}]() + %5 : ${dtype} + %6 : Float(${size}, strides=[${strides}], requires_grad=0, device=cpu) = aten::sum(%0, %3, %4, %5) + return (%6))IR"; + auto a = iotaTensor({2, 3, 2, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + + // Only iterate over positive values of axes to keep the running time + // reasonable, since the number of pairs is quadratic. + for (const auto dim1 : c10::irange(a.dim())) { + for (int dim2 = dim1 + 1; dim2 < a.dim(); ++dim2) { + for (bool keepdim : {false, true}) { + at::jit::TemplateEnv env; + env.d("dim1", dim1); + env.d("dim2", dim2); + env.d("keepdim", keepdim); + env.s("dtype", dtypeConstant(ScalarType::Undefined)); + auto o = at::empty({}, TensorOptions(kCPU)); + auto ref = a.sum(IntArrayRef{dim1, dim2}, /*keepdim=*/keepdim); + + env.s("size", li_to_str(ref.sizes())); + env.s("strides", li_to_str(ref.strides())); + + const auto graph_string = format(graph_template, env); + + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + TensorExprKernel k(graph); + std::vector inputs = {a}; + StmtPtr s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + // Check the IR we produced + const std::string& verification_pattern = + R"IR( +# CHECK: for (int64_t +# CHECK: for (int64_t +# CHECK: for (int64_t +# CHECK: for (int64_t +# CHECK: sum)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + o = stack[0].toTensor(); + ASSERT_EQ(o.sizes(), ref.sizes()); + ASSERT_EQ(o.dtype(), ref.dtype()); + ASSERT_TRUE(at::allclose(o, ref)); + } + } + } +} + +// This test and the following ones testing Softmax only tests with dim set +// to one of the valid input dimensions. It does not test with dim=None +// because that is supposed to be deprecated. +TEST_F(Kernel, Softmax2D) { + const auto graph_template = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)): + %1 : int = prim::Constant[value=${dim}]() + %dt_float : int = prim::Constant[value=7]() + %dt_none : NoneType = prim::Constant() + %4 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %${dt}) + return (%4))IR"; + + auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + + const std::string& verification_template = + R"IR( + # CHECK: for (int i${other_dim} = 0; i${other_dim} < ${other_dim_size} + # CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size} + # CHECK-NEXT: aten_softmax_max + # CHECK: for (int i${other_dim}_1 = 0; i${other_dim}_1 < ${other_dim_size} + # CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size} + # CHECK-NEXT: aten_softmax_sum + # CHECK: for (int i0_2 = 0; i0_2 < 5 + # CHECK-NEXT: for (int i1_2 = 0; i1_2 < 3 + # CHECK-NEXT: aten_softmax)IR"; + + for (bool empty_dtype : {false, true}) { + for (auto log_softmax : {false, true}) { + for (const auto softmax_dim : c10::irange(a.dim())) { + auto softmax_dim_size = a.sizes()[softmax_dim]; + auto other_dim = (softmax_dim + 1) % a.dim(); + auto ref = + log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim); + at::jit::TemplateEnv env; + env.d("dim", softmax_dim); + env.s("op", log_softmax ? "log_softmax" : "softmax"); + env.s("size", li_to_str(ref.sizes())); + env.s("strides", li_to_str(ref.strides())); + env.s("dt", empty_dtype ? "dt_none" : "dt_float"); + + const auto graph_string = format(graph_template, env); + + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + TensorExprKernel k(graph); + std::vector inputs = {a}; + StmtPtr s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + at::jit::TemplateEnv ver_env; + ver_env.d("other_dim", other_dim); + ver_env.d("other_dim_size", a.sizes()[other_dim]); + ver_env.d("softmax_dim", softmax_dim); + ver_env.d("softmax_dim_size", softmax_dim_size); + const auto verification_pattern = + format(verification_template, ver_env); + + // verification sting temporarily disabled until + // inlining of exp() is benchmarked and determined + // torch::jit::testing::FileCheck().run(verification_pattern, + // oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + auto output = stack[0].toTensor(); + ASSERT_EQ(output.sizes(), ref.sizes()); + ASSERT_TRUE(at::allclose(output, ref)); + } + } + } +} + +TEST_F(Kernel, Softmax3D) { + const auto graph_template = R"IR( + graph(%0 : Float(3, 4, 5, strides=[20, 5, 1], device=cpu)): + %1 : int = prim::Constant[value=${dim}]() + %2 : int = prim::Constant[value=7]() + %3 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %2) + return (%3))IR"; + + auto a = at::rand({3, 4, 5}, TensorOptions(kCPU).dtype(at::kFloat)); + + const std::string& verification_template = + R"IR( + # CHECK: for (int i${dim1} = 0; i${dim1} < ${dim1_size} + # CHECK-NEXT: for (int i${dim2} = 0; i${dim2} < ${dim2_size} + # CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size} + # CHECK-NEXT: aten_softmax_max + # CHECK: for (int i${dim1}_1 = 0; i${dim1}_1 < ${dim1_size} + # CHECK-NEXT: for (int i${dim2}_1 = 0; i${dim2}_1 < ${dim2_size} + # CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size} + # CHECK-NEXT: aten_softmax_sum + # CHECK: for (int i0_2 = 0; i0_2 < 3 + # CHECK-NEXT: for (int i1_2 = 0; i1_2 < 4 + # CHECK-NEXT: for (int i2_2 = 0; i2_2 < 5 + # CHECK-NEXT: aten_softmax)IR"; + + for (auto log_softmax : {false, true}) { + for (const auto softmax_dim : c10::irange(a.dim())) { + auto softmax_dim_size = a.sizes()[softmax_dim]; + std::vector other_dims; + for (const auto i : c10::irange(a.dim())) { + if (i != softmax_dim) { + other_dims.push_back(i); + } + } + auto ref = + log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim); + + at::jit::TemplateEnv env; + env.d("dim", softmax_dim); + env.s("op", log_softmax ? "log_softmax" : "softmax"); + env.s("size", li_to_str(ref.sizes())); + env.s("strides", li_to_str(ref.strides())); + + const auto graph_string = format(graph_template, env); + + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + TensorExprKernel k(graph); + std::vector inputs = {a}; + StmtPtr s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + at::jit::TemplateEnv ver_env; + ver_env.d("dim1", other_dims[0]); + ver_env.d("dim1_size", a.sizes()[other_dims[0]]); + ver_env.d("dim2", other_dims[1]); + ver_env.d("dim2_size", a.sizes()[other_dims[1]]); + ver_env.d("softmax_dim", softmax_dim); + ver_env.d("softmax_dim_size", softmax_dim_size); + const auto verification_pattern = format(verification_template, ver_env); + + // verification sting temporarily disabled until + // inlining of exp() is benchmarked and determined + // torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + auto output = stack[0].toTensor(); + + ASSERT_EQ(output.sizes(), ref.sizes()); + ASSERT_TRUE(at::allclose(output, ref)); + } + } +} + +TEST_F(Kernel, Softmax4D) { + const auto graph_template = R"IR( + graph(%0 : Float(2, 3, 2, 3, strides=[18, 6, 3, 1], device=cpu)): + %1 : int = prim::Constant[value=${dim}]() + %2 : int = prim::Constant[value=7]() + %3 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %2) + return (%3))IR"; + + auto a = at::rand({2, 3, 2, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + + const std::string& verification_template = + R"IR( + # CHECK: for (int i${dim1} = 0; i${dim1} < ${dim1_size} + # CHECK-NEXT: for (int i${dim2} = 0; i${dim2} < ${dim2_size} + # CHECK-NEXT: for (int i${dim3} = 0; i${dim3} < ${dim3_size} + # CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size} + # CHECK-NEXT: aten_softmax_max + # CHECK: for (int i${dim1}_1 = 0; i${dim1}_1 < ${dim1_size} + # CHECK-NEXT: for (int i${dim2}_1 = 0; i${dim2}_1 < ${dim2_size} + # CHECK-NEXT: for (int i${dim3}_1 = 0; i${dim3}_1 < ${dim3_size} + # CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size} + # CHECK-NEXT: aten_softmax_sum + # CHECK: for (int i0_2 = 0; i0_2 < 2 + # CHECK-NEXT: for (int i1_2 = 0; i1_2 < 3 + # CHECK-NEXT: for (int i2_2 = 0; i2_2 < 2 + # CHECK-NEXT: for (int i3_2 = 0; i3_2 < 3 + # CHECK-NEXT: aten_softmax)IR"; + + for (auto log_softmax : {false, true}) { + for (const auto softmax_dim : c10::irange(a.dim())) { + auto softmax_dim_size = a.sizes()[softmax_dim]; + std::vector other_dims; + for (const auto i : c10::irange(a.dim())) { + if (i != softmax_dim) { + other_dims.push_back(i); + } + } + auto ref = + log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim); + + at::jit::TemplateEnv env; + env.d("dim", softmax_dim); + env.s("op", log_softmax ? "log_softmax" : "softmax"); + env.s("size", li_to_str(ref.sizes())); + env.s("strides", li_to_str(ref.strides())); + + const auto graph_string = format(graph_template, env); + + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + TensorExprKernel k(graph); + std::vector inputs = {a}; + StmtPtr s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + at::jit::TemplateEnv ver_env; + ver_env.d("dim1", other_dims[0]); + ver_env.d("dim1_size", a.sizes()[other_dims[0]]); + ver_env.d("dim2", other_dims[1]); + ver_env.d("dim2_size", a.sizes()[other_dims[1]]); + ver_env.d("dim3", other_dims[2]); + ver_env.d("dim3_size", a.sizes()[other_dims[2]]); + ver_env.d("softmax_dim", softmax_dim); + ver_env.d("softmax_dim_size", softmax_dim_size); + const auto verification_pattern = format(verification_template, ver_env); + + // verification sting temporarily disabled until + // inlining of exp() is benchmarked and determined + // torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + auto output = stack[0].toTensor(); + ASSERT_EQ(output.sizes(), ref.sizes()); + ASSERT_TRUE(at::allclose(output, ref)); + } + } +} + +TEST_F(Kernel, SignTest) { + const auto graph_template = R"IR( + graph(%0 : ${dtype}(${size}, strides=[1], device=cpu)): + %2 : ${dtype}(${size}, strides=[1]) = aten::sign(%0) + return (%2))IR"; + + auto run_test = [](const std::string& graph_string, const at::Tensor& input) { + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + TensorExprKernel k(graph); + StmtPtr s = k.getCodeGenStmt(); + + std::vector inputs = {input}; + std::vector stack = fmap(inputs); + k.run(stack); + auto o = stack[0].toTensor(); + auto ref = at::sign(input); + ASSERT_TRUE(at::allclose(o, ref)); + }; + auto common_options = at::TensorOptions() + .layout(at::kStrided) + .device(at::kCPU) + .requires_grad(false); + int default_input_size = 100; + for (auto scalar_type : {ScalarType::Float, ScalarType::Double}) { + at::Tensor corner_case_inputs; + at::jit::TemplateEnv env; + auto options = common_options; + switch (scalar_type) { + case ScalarType::Float: { + env.s("dtype", "Float"); + options = options.dtype(at::kFloat); + std::vector input_float = { + 0.0f, + -0.0f, + std::numeric_limits::infinity(), + -std::numeric_limits::infinity(), + std::nanf("1"), + -std::nanf("1")}; + corner_case_inputs = at::from_blob( + input_float.data(), + {static_cast(input_float.size())}, + options); + auto rand_input = at::rand({default_input_size}, options); + auto input = at::cat({rand_input, corner_case_inputs}); + env.d("size", at::numel(input)); + const auto graph_string = format(graph_template, env); + run_test(graph_string, input); + break; + } + case ScalarType::Double: { + env.s("dtype", "Double"); + options = options.dtype(at::kDouble); + std::vector input_double = { + 0.0, + -0.0, + std::numeric_limits::infinity(), + -std::numeric_limits::infinity(), + std::nan("1"), + -std::nan("1")}; + corner_case_inputs = at::from_blob( + input_double.data(), + {static_cast(input_double.size())}, + options); + auto rand_input = at::rand({default_input_size}, options); + auto input = at::cat({rand_input, corner_case_inputs}); + env.d("size", at::numel(input)); + const auto graph_string = format(graph_template, env); + run_test(graph_string, input); + break; + } + default: + throw unsupported_dtype(); + } + } +} + +TEST_F(Kernel, InlineProducerIntoReduction) { + // Inline producer (mul) into reduction (sum). + const auto graph_string = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), + %1 : Float(5, 3, strides=[3, 1], device=cpu)): + %2 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%0, %1) + %3 : int = prim::Constant[value=7]() + %4 : Double(device=cpu) = aten::sum(%2, %3) + return (%4))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + TensorExprKernel k(graph); + StmtPtr s = k.getCodeGenStmt(); + std::ostringstream oss; + oss << *s; + + // Check the IR we produced. + // We should have only one loop in the end. + const std::string& verification_pattern = + R"IR( + # CHECK: for (int64_t i_1 = 0ll; i_1 < 5 + # CHECK-NEXT: for (int64_t j_1 = 0ll; j_1 < 3 + # CHECK-NEXT: sum + # CHECK-NOT: for)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + std::vector inputs = {a, b}; + std::vector stack = fmap(inputs); + k.run(stack); + auto o = stack[0].toTensor(); + auto ref = (a * b).sum(at::kDouble); + ASSERT_TRUE(at::allclose(o, ref)); +} + +TEST_F(Kernel, InlineReductionIntoConsumer) { + // Inline producer (mul %2) into reduction (sum %4) but DO NOT + // inline the reduction into consumer (mul %4). + const auto graph_string = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), + %1 : Float(5, 3, strides=[3, 1], device=cpu)): + %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) + %3 : int = prim::Constant[value=6]() + %4 : Float(device=cpu) = aten::sum(%2, %3) + %5 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%2, %4) + return (%5))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + TensorExprKernel k(graph); + StmtPtr s = k.getCodeGenStmt(); + std::ostringstream oss; + oss << *s; + + // Check the IR we produced. + // We should have two loops in the end. + const std::string& verification_pattern = + R"IR( + # CHECK: for (int64_t i_1 = 0ll; i_1 < 5 + # CHECK-NEXT: for (int64_t j_1 = 0ll; j_1 < 3 + # CHECK-NEXT: sum + # CHECK: for (int64_t i_2 = 0ll; i_2 < 5 + # CHECK-NEXT: for (int64_t j_2 = 0ll; j_2 < 3 + # CHECK-NEXT: aten_mul + # CHECK-NOT: for)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + std::vector inputs = {a, b}; + std::vector stack = fmap(inputs); + k.run(stack); + auto o = stack[0].toTensor(); + auto ref = (a * b).sum(at::kFloat) * (a * b); + ASSERT_TRUE(at::allclose(o, ref)); +} + +TEST_F(Kernel, SanitizeNames_CUDA) { + const auto graph_string = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=cuda:0), + %1 : Float(5, 3, strides=[3, 1], device=cuda:0)): + %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) + %4 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) + return (%4))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + graph->inputs().at(0)->setDebugName("aten::add:"); + graph->inputs().at(1)->setDebugName("aten::add_"); + TensorExprKernel k(graph); + auto a = at::rand({5, 3}, TensorOptions(kCUDA).dtype(at::kFloat)); + auto b = at::rand({5, 3}, TensorOptions(kCUDA).dtype(at::kFloat)); + auto ref = a * (a * b); + std::vector inputs = {a, b}; + std::vector stack = fmap(inputs); + k.run(stack); + auto o = stack[0].toTensor(); + ASSERT_TRUE(at::allclose(o, ref)); +} + +TEST_F(Kernel, SanitizeConstants_CUDA) { + const auto graph_string = R"IR( + graph(%x : Float(16, 16, strides=[16, 1], device=cuda:0)): + %none : NoneType = prim::Constant() + %size : int = prim::Constant[value=16]() + %sizes : int[] = prim::ListConstruct(%size, %size) + %30 : Device = prim::Constant[value="cuda"]() + %y : Float(16, 16, strides=[16, 1], device=cuda:0) = aten::ones(%sizes, %none, %none, %30, %none) + %z : Float(16, 16, strides=[16, 1], device=cuda:0) = aten::mul(%x, %y) + return (%z))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + // IRParser doesn't support tensor constants, so we insert a call to + // aten::ones and then const-prop it + ConstantPropagation(graph); + + // We set the name of the constant to include special characters that are + // not allowed. This should be fixed by the sanitizer in TensorExprKernel. + graph->nodes().front()->output()->setDebugName("illegal.name"); + + // Check if we have a constant node with illegal name in the graph. + auto const_node = graph->nodes().front(); + ASSERT_EQ(const_node->kind(), prim::Constant); + ASSERT_NE(const_node->output()->debugName().find('.'), std::string::npos); + + TensorExprKernel k(graph); + + auto x = at::rand({16, 16}, TensorOptions(kCUDA).dtype(at::kFloat)); + std::vector inputs = {x}; + std::vector stack = fmap(inputs); + k.run(stack); + auto o = stack[0].toTensor(); + auto y = at::ones({16, 16}, TensorOptions(kCUDA).dtype(at::kFloat)); + auto ref = x * y; + ASSERT_TRUE(at::allclose(o, ref)); +} + +TEST_F(Kernel, ConstantTensors) { + const auto graph_string = R"IR( + graph(%x : Float(16, 16, strides=[16, 1], device=cpu)): + %none : NoneType = prim::Constant() + %size : int = prim::Constant[value=16]() + %sizes : int[] = prim::ListConstruct(%size, %size) + %y : Float(16, 16, strides=[16, 1], device=cpu) = aten::ones(%sizes, %none, %none, %none, %none) + %z : Float(16, 16, strides=[16, 1], device=cpu) = aten::mul(%x, %y) + return (%z))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + // IRParser doesn't support tensor constants, so we insert a call to + // aten::ones and then const-prop it + ConstantPropagation(graph); + + TensorExprKernel k(graph); + + auto x = at::rand({16, 16}, TensorOptions(kCPU).dtype(at::kFloat)); + std::vector inputs = {x}; + std::vector stack = fmap(inputs); + k.run(stack); + auto o = stack[0].toTensor(); + auto y = at::ones({16, 16}, TensorOptions(kCPU).dtype(at::kFloat)); + auto ref = x * y; + ASSERT_TRUE(at::allclose(o, ref)); +} + +TEST_F(Kernel, ConstantTensorsNonContiguous) { + const auto graph_string = R"IR( + graph(%x : Float(16, 16, strides=[16, 1], device=cpu)): + %none : NoneType = prim::Constant() + %dtype : int = prim::Constant[value=6]() + %c0 : int = prim::Constant[value=0]() + %c256 : int = prim::Constant[value=256]() + %c16 : int = prim::Constant[value=16]() + %y_flat : Tensor = aten::arange(%c0, %c256, %dtype, %none, %none, %none) + %sizes : int[] = prim::ListConstruct(%c16, %c16) + %y_t : Tensor = aten::view(%y_flat, %sizes) + %y : Tensor = aten::t(%y_t) + %z : Float(16, 16, strides=[16, 1], device=cpu) = aten::mul(%x, %y) + return (%z))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + // IRParser doesn't support tensor constants, so we generate several aten + // calls to produce non-contiguous constant tensor and then const-prop it + ConstantPropagation(graph); + + TensorExprKernel k(graph); + + auto x = at::rand({16, 16}, TensorOptions(kCPU).dtype(at::kFloat)); + std::vector inputs = {x}; + std::vector stack = fmap(inputs); + k.run(stack); + auto o = stack[0].toTensor(); + auto y = at::arange(0, 256, TensorOptions(kCPU).dtype(at::kFloat)) + .view({16, 16}) + .t(); + auto ref = x * y; + ASSERT_TRUE(at::allclose(o, ref)); +} + +TEST_F(Kernel, RunFast) { +#ifdef TORCH_ENABLE_LLVM + // TODO: Implement call_raw in IREval and remove the ifdef + + const auto graph_string = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), + %1 : Float(5, 3, strides=[1, 5], device=cpu)): + %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) + %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) + return (%3))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = + at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1); + auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto ref = a * (a * b); + TensorExprKernel k(graph); + + k.runFast({a.data_ptr(), b.data_ptr()}, {o.data_ptr()}); + for (size_t i = 0; i < 5 * 3; i++) { + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + } +#endif +} + +TEST_F(Kernel, RunWithAllocatedOutputs) { +#ifdef TORCH_ENABLE_LLVM + const auto graph_string = R"IR( + graph(%0 : Float(5, 3, strides=[3, 1], device=cpu), + %1 : Float(5, 3, strides=[1, 5], device=cpu)): + %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1) + %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2) + return (%3))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = + at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1); + auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto ref = a * (a * b); + TensorExprKernel k(graph); + + std::vector args = {o, a, b}; + std::vector stack = fmap(args); + k.runWithAllocatedOutputs(stack); + for (size_t i = 0; i < 5 * 3; i++) { + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + } +#endif +} + +TEST_F(Kernel, CodegenInspection) { +#ifdef TORCH_ENABLE_LLVM + const auto graph_string = R"IR( + graph(%x : Float(16, 16, strides=[16, 1], device=cpu)): + %none : NoneType = prim::Constant() + %dtype : int = prim::Constant[value=6]() + %c0 : int = prim::Constant[value=0]() + %c256 : int = prim::Constant[value=256]() + %c16 : int = prim::Constant[value=16]() + %y_flat : Tensor = aten::arange(%c0, %c256, %dtype, %none, %none, %none) + %sizes : int[] = prim::ListConstruct(%c16, %c16) + %y_t : Tensor = aten::view(%y_flat, %sizes) + %y : Tensor = aten::t(%y_t) + %z : Float(16, 16, strides=[16, 1], device=cpu) = aten::mul(%x, %y) + return (%z))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + // IRParser doesn't support tensor constants, so we generate several aten + // calls to produce non-contiguous constant tensor and then const-prop it + ConstantPropagation(graph); + + TensorExprKernel k(graph); + + // Check that we could retrieve generated assembly + auto asm_str = k.getCodeText("asm"); + const std::string& asm_verification_pattern = + R"ASM( + # CHECK: .text + # CHECK: retq)ASM"; + torch::jit::testing::FileCheck().run(asm_verification_pattern, asm_str); + + // Check that we could retrieve info about codegen parameters + auto constants = k.getConstantDescriptors(); + auto buf_args = k.getBufferArgs(); + // Expected buf args: [input0, output0, constant0] + ASSERT_EQ(buf_args.size(), 3); + ASSERT_EQ(constants.size(), 1); + ASSERT_TRUE( + !buf_args[0].isVar() && !buf_args[1].isVar() && !buf_args[2].isVar()); +#endif +} + +Tensor lowerNanToNum( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const std::optional& outputType, + at::Device device) { + auto input_buf = std::get(inputs[0]); + auto e = Compute( + "custom_nan_to_num", + outputShape, + outputStrides, + [&](const std::vector& axes) { + std::vector indices(axes.begin(), axes.end()); + auto load = input_buf.load(indices); + return IfThenElse::make(Cast::make(kBool, isnan(load)), 0.0f, load); + }); + return e; +} + +TEST_F(Kernel, CustomLowering) { + const auto graph_string = R"IR( + graph(%x : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu)): + %none : NoneType = prim::Constant() + %y : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::nan_to_num(%x, %none, %none, %none) + return (%y) +)IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + std::unordered_map lowerings = { + {aten::nan_to_num, lowerNanToNum}}; + TensorExprKernel k(graph, lowerings); + + auto stmt = k.getCodeGenStmt(); + std::ostringstream oss; + oss << *stmt; + + // Check that our custom lowering is actually used + torch::jit::testing::FileCheck().check("custom_nan_to_num")->run(oss.str()); + torch::jit::testing::FileCheck().check("isnan")->run(oss.str()); +} + +TEST_F(Kernel, Vectorize) { +#ifdef TORCH_ENABLE_LLVM + const auto graph_string = R"IR( + graph(%0 : Float(100, 16, strides=[16, 1], device=cpu), + %1 : Float(100, 16, strides=[16, 1], device=cpu)): + %2 : Float(100, 16, strides=[16, 1]) = aten::mul(%0, %1) + %3 : Float(100, 16, strides=[16, 1]) = aten::mul(%0, %2) + return (%3))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto a = at::rand({100, 16}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({100, 16}, TensorOptions(kCPU).dtype(at::kFloat)); + auto o = at::zeros({100, 16}, TensorOptions(kCPU).dtype(at::kFloat)); + auto ref = a * (a * b); + TensorExprKernel k(graph); + std::vector inputs = {a, b}; + StmtPtr s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + // Check the IR we produced + const std::string& verification_pattern = R"IR(# CHECK: Ramp)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + o = stack[0].toTensor(); + for (size_t i = 0; i < 100 * 16; i++) { + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + } +#endif +} + +// TODO: To vectorize loopnest for 100x3 case, we need to flatten loops first. +TEST_F(Kernel, DISABLED_FlattenVectorize) { +#ifdef TORCH_ENABLE_LLVM + const auto graph_string = R"IR( + graph(%0 : Float(100, 3, strides=[3, 1], device=cpu), + %1 : Float(100, 3, strides=[3, 1], device=cpu)): + %2 : Float(100, 3, strides=[3, 1]) = aten::mul(%0, %1) + %3 : Float(100, 3, strides=[3, 1]) = aten::mul(%0, %2) + return (%3))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto a = at::rand({100, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({100, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto o = at::zeros({100, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto ref = a * (a * b); + TensorExprKernel k(graph); + std::vector inputs = {a, b}; + StmtPtr s = k.getCodeGenStmt(); + + std::ostringstream oss; + oss << *s; + + // Check the IR we produced + const std::string& verification_pattern = R"IR(# CHECK: Ramp)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector stack = fmap(inputs); + k.run(stack); + o = stack[0].toTensor(); + for (size_t i = 0; i < 100 * 3; i++) { + TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]); + } +#endif +} + +TEST_F(Kernel, Strided1dWithinBounds) { + auto ir = R"IR( + graph(%0 : Float(3, strides=[1], device=cpu), + %1 : Float(3, strides=[2], device=cpu)): + %2 : int = prim::Constant[value=1]() + %3 : Float(3, strides=[1]) = aten::add(%0, %1, %2) + return (%3))IR"; + auto graph = std::make_shared(); + std::unordered_map vmap; + parseIR(ir, graph.get(), vmap); + TensorExprKernel k(graph); + + auto a = at::rand({3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto b = at::rand({6}, TensorOptions(kCPU).dtype(at::kFloat)) + .index({Slice(None, None, 2)}); + auto expect = a + b; + + std::vector inputs = {a, b}; + + std::vector stack = fmap(inputs); + k.run(stack); + + auto output = stack[0].toTensor(); + + for (size_t i = 0; i < 3; ++i) { + TORCH_CHECK_EQ( + ((float*)output.data_ptr())[i], ((float*)expect.data_ptr())[i]); + } +} + +TEST_F(Kernel, InputAsOutput) { + const auto graph_string = R"IR( + graph(%x : Float(5, 3, strides=[3, 1], device=cpu), + %y : Float(5, 3, strides=[1, 5], device=cpu)): + return (%x, %y))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto x = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); + auto y = + at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1); + TensorExprKernel k(graph); + std::vector inputs = {x, y}; + + std::vector stack = fmap(inputs); + k.run(stack); + CHECK(at::allclose(x, stack[0].toTensor())); + CHECK(at::allclose(y, stack[1].toTensor())); +} + +TEST_F(Kernel, ScalarOut) { + auto ir = R"IR( +graph(%x : int, %y : int): + %z : int = aten::mul(%x, %y) + %r : int = aten::mul(%z, %x) + return (%r, %z))IR"; + auto graph = std::make_shared(); + std::unordered_map vmap; + parseIR(ir, graph.get(), vmap); + TensorExprKernel k(graph); + + auto stmt = k.getCodeGenStmt(); + std::ostringstream oss; + oss << *stmt; + + // Verify the generated IR. We expect to see a scalar variable (Let) followed + // by a store to a 0-dim buffer. + const std::string& verification_pattern = R"IR( +# CHECK: int64_t +# CHECK-NEXT: [0ll] = +# CHECK-NEXT: int64_t +# CHECK-NEXT: [0ll] = +)IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + int64_t x = 2, y = 3, r = 0, z = 0; + + // Verify that TEK::runFast works correctly with scalar outputs + std::vector inputs = {&x, &y}; + std::vector outputs = {&r, &z}; + k.runFast(inputs, outputs); + TORCH_CHECK_EQ(z, x * y); + TORCH_CHECK_EQ(r, z * x); + + // Verify that TEK::run works correctly with scalar outputs + std::vector stack = {x, y}; + k.run(stack); + TORCH_CHECK_EQ(stack[0], x * y * x); + TORCH_CHECK_EQ(stack[1], x * y); +} + +TEST_F(Kernel, ScalarTensorOut) { + auto ir = R"IR( +graph(%x : int, + %xt : Long(3, strides=[1], device=cpu), + %y : int, + %yt : Long(3, strides=[1], device=cpu)): + %z : int = aten::mul(%x, %y) + %r : int = aten::mul(%z, %x) + %zt : Long(3, strides=[1], device=cpu) = aten::mul(%xt, %y) + %rt : Long(3, strides=[1], device=cpu) = aten::mul(%zt, %xt) + return (%r, %rt, %z, %zt))IR"; + auto graph = std::make_shared(); + std::unordered_map vmap; + parseIR(ir, graph.get(), vmap); + TensorExprKernel k(graph); + int64_t x = 2, y = 3, r = 0, z = 0; + auto xt = at::ones({3}, TensorOptions(kCPU).dtype(at::kLong)) * 2; + auto yt = at::ones({3}, TensorOptions(kCPU).dtype(at::kLong)) * 3; + auto zt = at::zeros({3}, TensorOptions(kCPU).dtype(at::kLong)); + auto rt = at::zeros({3}, TensorOptions(kCPU).dtype(at::kLong)); + + // Verify that TEK::runFast works correctly with mixed scalar and tensor + // inputs/utputs + std::vector inputs = {&x, xt.data_ptr(), &y, yt.data_ptr()}; + std::vector outputs = {&r, rt.data_ptr(), &z, zt.data_ptr()}; + k.runFast(inputs, outputs); + TORCH_CHECK_EQ(z, x * y); + TORCH_CHECK_EQ(r, z * x); + ASSERT_TRUE(at::equal(zt, xt * yt)); + ASSERT_TRUE(at::equal(rt, zt * xt)); + + // Verify that TEK::run works correctly with mixed scalar and tensor + // inputs/utputs + std::vector stack = {x, xt, y, yt}; + k.run(stack); + TORCH_CHECK_EQ(stack[0], x * y * x); + ASSERT_TRUE(at::equal(stack[1].toTensor(), xt * yt * xt)); + TORCH_CHECK_EQ(stack[2], x * y); + ASSERT_TRUE(at::equal(stack[3].toTensor(), xt * yt)); +} + +TEST_F(Kernel, FuseLoopsWithVariableBounds) { +#ifdef TORCH_ENABLE_LLVM + bool old_cat_wo_conditionals = getCatWoConditionals(); + getCatWoConditionals() = true; + const auto graph_string = R"IR( + graph(%a : Float(SS(-2), 3, SS(-3), requires_grad=0, device=cpu), + %b : Float(SS(-2), 7, SS(-3), requires_grad=0, device=cpu), + %c : Float(SS(-2), 9, SS(-3), requires_grad=0, device=cpu), + %SS_2 : int, + %SS_3 : int): + %dim : int = prim::Constant[value=1]() + %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) + %r : Float(SS(-2), 19, SS(-3), requires_grad=0, device=cpu) = aten::cat(%inputs, %dim) # new size: [5,19,2] + return (%r))IR"; + std::shared_ptr graph = std::make_shared(); + torch::jit::parseIR(graph_string, graph.get()); + + std::vector symbolic_shape_inputs = {-2, -3}; + + std::vector input_desc = { + torch::jit::StrideInput::TENSOR_CONT}; + std::unordered_map< + const torch::jit::Value*, + std::vector> + symbolic_strides; + symbolic_strides[graph->inputs().at(0)] = input_desc; + symbolic_strides[graph->inputs().at(1)] = input_desc; + symbolic_strides[graph->inputs().at(2)] = input_desc; + symbolic_strides[graph->outputs().at(0)] = input_desc; + + TensorExprKernel kernel( + graph, {}, symbolic_shape_inputs, false, symbolic_strides); + + std::ostringstream oss; + oss << *kernel.getCodeGenStmt(); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int64_t i +# CHECK-NEXT: for (int64_t j +# CHECK-NEXT: for (int64_t k +# CHECK: for (int64_t j +# CHECK-NEXT: for (int64_t k +# CHECK: for (int64_t j +# CHECK-NEXT: for (int64_t k +# CHECK-NOT: for (int64_t i + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto run_kernel = [&](int dim1, int dim2) { + auto a = + at::rand({dim1, 3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat)); + auto b = + at::rand({dim1, 7, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat)); + auto c = + at::rand({dim1, 9, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat)); + + auto ref = at::cat({a, b, c}, 1); + + std::vector stack = + fmap(std::vector({a, b, c})); + stack.emplace_back(dim1); + stack.emplace_back(dim2); + kernel.run(stack); + + auto o = stack[0].toTensor(); + ASSERT_TRUE(at::allclose(o, ref)); + }; + + run_kernel(10, 20); + getCatWoConditionals() = old_cat_wo_conditionals; +#endif +} + +TEST_F(Kernel, FuseLoopsWithVariableConcatDim) { +#ifdef TORCH_ENABLE_LLVM + bool old_cat_wo_conditionals = getCatWoConditionals(); + getCatWoConditionals() = true; + const auto graph_string = R"IR( + graph(%a : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu), + %b : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu), + %c : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu), + %SS_2 : int, + %SS_3 : int, + %SS_4 : int, + %SS_5 : int): + %dim : int = prim::Constant[value=1]() + %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c) + %r : Float(SS(-2), SS(-5), SS(-3), requires_grad=0, device=cpu) = aten::cat(%inputs, %dim) # new size: [5,19,2] + return (%r))IR"; + std::shared_ptr graph = std::make_shared(); + torch::jit::parseIR(graph_string, graph.get()); + + std::vector symbolic_shape_inputs = {-2, -3, -4, -5}; + + std::vector input_desc = { + torch::jit::StrideInput::TENSOR_CONT}; + std::unordered_map< + const torch::jit::Value*, + std::vector> + symbolic_strides; + symbolic_strides[graph->inputs().at(0)] = input_desc; + symbolic_strides[graph->inputs().at(1)] = input_desc; + symbolic_strides[graph->inputs().at(2)] = input_desc; + symbolic_strides[graph->outputs().at(0)] = input_desc; + + TensorExprKernel kernel( + graph, {}, symbolic_shape_inputs, false, symbolic_strides); + + std::ostringstream oss; + oss << *kernel.getCodeGenStmt(); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int64_t i +# CHECK-NEXT: for (int64_t j +# CHECK-NEXT: for (int64_t k +# CHECK: for (int64_t j +# CHECK-NEXT: for (int64_t k +# CHECK: for (int64_t j +# CHECK-NEXT: for (int64_t k +# CHECK-NOT: for (int64_t i + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto run_kernel = [&](int dim1, int dim2, int dim3) { + auto a = + at::rand({dim1, dim3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat)); + auto b = + at::rand({dim1, dim3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat)); + auto c = + at::rand({dim1, dim3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat)); + + auto ref = at::cat({a, b, c}, 1); + + std::vector stack = + fmap(std::vector({a, b, c})); + stack.emplace_back(dim1); + stack.emplace_back(dim2); + stack.emplace_back(dim3); + stack.emplace_back(3 * dim3); + kernel.run(stack); + + auto o = stack[0].toTensor(); + ASSERT_TRUE(at::allclose(o, ref)); + }; + + run_kernel(10, 20, 15); + getCatWoConditionals() = old_cat_wo_conditionals; +#endif +} + +TEST_F(Kernel, DoNotFuseLoopsWithMismatchingVariableDims) { +#ifdef TORCH_ENABLE_LLVM + bool old_cat_wo_conditionals = getCatWoConditionals(); + getCatWoConditionals() = true; + const auto graph_string = R"IR( + graph(%a : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu), + %b : Float(SS(-2), SS(-5), SS(-3), requires_grad=0, device=cpu), + %SS_2 : int, + %SS_3 : int, + %SS_4 : int, + %SS_5 : int, + %SS_6 : int): + %dim : int = prim::Constant[value=1]() + %inputs : Tensor[] = prim::ListConstruct(%a, %b) + %r : Float(SS(-2), SS(-6), SS(-3), requires_grad=0, device=cpu) = aten::cat(%inputs, %dim) # new size: [5,19,2] + return (%r))IR"; + std::shared_ptr graph = std::make_shared(); + torch::jit::parseIR(graph_string, graph.get()); + + std::vector symbolic_shape_inputs = {-2, -3, -4, -5, -6}; + + std::vector input_desc = { + torch::jit::StrideInput::TENSOR_CONT}; + std::unordered_map< + const torch::jit::Value*, + std::vector> + symbolic_strides; + symbolic_strides[graph->inputs().at(0)] = input_desc; + symbolic_strides[graph->inputs().at(1)] = input_desc; + symbolic_strides[graph->outputs().at(0)] = input_desc; + + TensorExprKernel kernel( + graph, {}, symbolic_shape_inputs, false, symbolic_strides); + + std::ostringstream oss; + oss << *kernel.getCodeGenStmt(); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int64_t i +# CHECK-NEXT: for (int64_t j +# CHECK-NEXT: for (int64_t k +# CHECK: for (int64_t j +# CHECK-NEXT: for (int64_t k +# CHECK-NOT: for (int64_t j +# CHECK-NOT: for (int64_t i + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + auto run_kernel = [&](int dim2, int dim3, int dim4, int dim5) { + auto a = + at::rand({dim2, dim4, dim3}, at::TensorOptions(kCPU).dtype(at::kFloat)); + auto b = + at::rand({dim2, dim5, dim3}, at::TensorOptions(kCPU).dtype(at::kFloat)); + + auto ref = at::cat({a, b}, 1); + + std::vector stack = fmap(std::vector({a, b})); + stack.emplace_back(dim2); + stack.emplace_back(dim3); + stack.emplace_back(dim4); + stack.emplace_back(dim5); + stack.emplace_back(dim4 + dim5); + kernel.run(stack); + + auto o = stack[0].toTensor(); + ASSERT_TRUE(at::allclose(o, ref)); + }; + + run_kernel(10, 20, 15, 8); + getCatWoConditionals() = old_cat_wo_conditionals; +#endif +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_llvm.cpp b/test/cpp/tensorexpr/test_llvm.cpp new file mode 100644 index 0000000000000..f6ffc84f62c09 --- /dev/null +++ b/test/cpp/tensorexpr/test_llvm.cpp @@ -0,0 +1,1799 @@ +#ifdef TORCH_ENABLE_LLVM +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace torch { +namespace jit { +using namespace torch::jit::tensorexpr; + +using LLVMExprEval = ExprEval; + +// Typed tests, can't use gtest params here due to the way we instantiate tests. +#define TEST_LLVM_SCALAR_TYPES(_) \ + _(uint8_t, Byte, 24) \ + _(int8_t, Char, -20) \ + _(int16_t, Short, 3332) \ + _(int, Int, 123456) \ + _(int64_t, Long, 2631563121321) \ + _(float, Float, 0.122) \ + _(double, Double, 0.21312) \ + _(at::Half, Half, 0.128f) + +#define IMM_TEST(Type, Name, Val) \ + TEST(LLVM, Name##ImmTest) { \ + auto a = Name##Imm::make(Val); \ + LLVMExprEval cg(a); \ + if (std::is_floating_point()) { \ + ASSERT_NEAR(cg.value(), Val, 0.1); \ + } else { \ + ASSERT_EQ(cg.value(), Val); \ + } \ + } +TEST_LLVM_SCALAR_TYPES(IMM_TEST) +#undef IMM_TEST + +#define ADD_TEST(Type, Name, Val) \ + TEST(LLVM, Name##AddTest) { \ + auto a = Name##Imm::make(Val); \ + auto b = Name##Imm::make(Val * 2); \ + auto c = Add::make(a, b); \ + LLVMExprEval cg(c); \ + if (std::is_floating_point()) { \ + ASSERT_NEAR(cg.value(), Val * 3, 0.1); \ + } else { \ + ASSERT_EQ(cg.value(), Val * 3); \ + } \ + } +TEST_LLVM_SCALAR_TYPES(ADD_TEST) +#undef ADD_TEST + +#define SUB_TEST(Type, Name, Val) \ + TEST(LLVM, Name##SubTest) { \ + auto a = Name##Imm::make(Val * 2); \ + auto b = Name##Imm::make(Val); \ + auto c = Sub::make(a, b); \ + LLVMExprEval cg(c); \ + if (std::is_floating_point()) { \ + ASSERT_NEAR(cg.value(), Val, 0.1); \ + } else { \ + ASSERT_EQ(cg.value(), Val); \ + } \ + } +TEST_LLVM_SCALAR_TYPES(SUB_TEST) +#undef SUB_TEST + +#define MUL_TEST(Type, Name, Val) \ + TEST(LLVM, Name##MulTest) { \ + auto a = Name##Imm::make(Val); \ + auto b = Name##Imm::make((Type)4); \ + auto c = Mul::make(a, b); \ + LLVMExprEval cg(c); \ + if (std::is_floating_point()) { \ + ASSERT_NEAR(cg.value(), Val * 4, 0.1); \ + } else { \ + ASSERT_EQ(cg.value(), Val * 4); \ + } \ + } +TEST_LLVM_SCALAR_TYPES(MUL_TEST) +#undef MUL_TEST + +#define DIV_TEST(Type, Name, Val) \ + TEST(LLVM, Name##DivTest) { \ + auto a = Name##Imm::make((Type)6); \ + auto b = Name##Imm::make((Type)3); \ + auto c = Div::make(a, b); \ + LLVMExprEval cg(c); \ + if (std::is_floating_point()) { \ + ASSERT_NEAR(cg.value(), 2, 0.1); \ + } else { \ + ASSERT_EQ(cg.value(), 2); \ + } \ + } +TEST_LLVM_SCALAR_TYPES(DIV_TEST) +#undef DIV_TEST + +TEST(LLVM, IntToFloatCastTest) { + auto a = IntImm::make(2); + auto b = Cast::make(kFloat, a); + LLVMExprEval cg(b, {}); + ASSERT_EQ(cg.value(), 2.0); +} + +TEST(LLVM, FloatToIntCastTest) { + auto a = FloatImm::make(2.0); + auto b = Cast::make(kInt, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), 2); +} + +TEST(LLVM, IntToLongCastTest) { + auto a = IntImm::make(12345); + auto b = Cast::make(kLong, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), 12345); +} + +TEST(LLVM, ByteToCharCastTest) { + auto a = ByteImm::make(250); + auto b = Cast::make(kChar, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), (int8_t)250); +} + +TEST(LLVM, HalfToLongCastTest) { + auto a = HalfImm::make(2.0); + auto b = Cast::make(kLong, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), 2); +} + +TEST(LLVM, ByteToDoubleCastTest) { + auto a = ByteImm::make(2); + auto b = Cast::make(kDouble, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), 2); +} + +TEST(LLVM, FloatToByteCastTest) { + auto a = FloatImm::make(254.0); + auto b = Cast::make(kByte, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), 254); +} + +TEST(LLVM, FloatToCharCastTest) { + auto a = FloatImm::make(-2.0); + auto b = Cast::make(kChar, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), -2); +} + +TEST(LLVM, ByteToFloatCastTest) { + auto a = ByteImm::make(254); + auto b = Cast::make(kFloat, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), 254.0); +} + +TEST(LLVM, CharToFloatCastTest) { + auto a = CharImm::make(-2); + auto b = Cast::make(kFloat, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), -2.0); +} + +TEST(LLVM, BitCast) { + /* constexpr int16_t ref16 = 1337; */ + constexpr int32_t ref32 = 1337; + constexpr int64_t ref64 = 1337; + constexpr float reff32 = 1337.0f; + constexpr double reff64 = 1337.0f; + + // this is broken + /*{ + at::Half k_; + at::Half* k = &k_; + *reinterpret_cast(k) = ref16; + auto a = HalfImm::make(k); + auto b = BitCast::make(kShort, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), ref16); + }*/ + + { + float k = raw_bitcast(ref32); + auto a = FloatImm::make(k); + auto b = BitCast::make(kInt, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), ref32); + } + + { + double k = raw_bitcast(ref64); + auto a = DoubleImm::make(k); + auto b = BitCast::make(kLong, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), ref64); + } + + { + int64_t k = raw_bitcast(reff64); + auto a = LongImm::make(k); + auto b = BitCast::make(kDouble, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), reff64); + } + + { + int32_t k = raw_bitcast(reff32); + auto a = IntImm::make(k); + auto b = BitCast::make(kFloat, a); + LLVMExprEval cg(b); + ASSERT_EQ(cg.value(), reff32); + } +} + +TEST(LLVM, fastLogFloat) { + const int kTotalSize = 128 * 128; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = a_buf.load(index); + StmtPtr store_b = b_buf.store({index}, fast_log(load_a)); + StmtPtr stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (const auto i : c10::irange(kTotalSize)) { + a_v(i) = at::randn({1}).item().to(); + } + + LLVMCodeGen ir_eval(stmt, {a_buf, b_buf}); + ir_eval.call({a_v, b_v}); + + for (const auto i : c10::irange(kTotalSize)) { + auto test = b_v(i); + auto ref = std::log(a_v(i)); + if (std::isnan(ref)) { + ASSERT_EQ(std::isnan(test), true); + } else { + ASSERT_FLOAT_EQ(test, ref); + } + } +} + +TEST(LLVM, LetTest01) { + BufHandle a("A", {1}, kFloat); + std::vector v = {1, 0}; + std::vector args({v.data()}); + VarHandle x("x", kFloat); + auto block = Block::make({ + Let::make(x, 3.f), + a.store({0}, ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f))), + }); + + LLVMCodeGen cg(block, {a}); + ASSERT_EQ(cg.value(args), 0); + ASSERT_EQ(v[0], 2.f + 3.f * 3.f + 4.f); +} + +TEST(LLVM, LetTest02) { + BufHandle a("A", {1}, kFloat); + std::vector v = {1, 0}; + std::vector args({v.data()}); + VarHandle x("x", kFloat); + VarHandle y("y", kFloat); + auto block = Block::make( + {Let::make(x, 3.f), + Let::make(y, 6.f), + a.store( + {IntImm::make(0)}, + ExprHandle(2.f) + (x * ExprHandle(3.f) + y * ExprHandle(4.f)))}); + + LLVMCodeGen cg(block, {a}); + ASSERT_EQ(cg.value(args), 0); + ASSERT_EQ(v[0], 2.f + 3.f * 3.f + 6.f * 4.f); +} + +TEST(LLVM, LetTestMultitype) { + BufHandle a("A", {1}, kDouble); + std::vector v = {1, 0}; + std::vector args({v.data()}); + VarHandle x("x", kByte); + VarHandle y("y", kHalf); + auto block = Block::make( + {Let::make(x, 3), + Let::make(y, 6.f), + a.store( + {0}, + Cast::make( + kDouble, + ExprHandle(2.f) + + (x * ExprHandle(3.f) + y * ExprHandle(4.f))))}); + + LLVMCodeGen cg(block, {a}); + ASSERT_EQ(cg.value(args), 0); + ASSERT_EQ(v[0], 2.f + 3 * 3.f + 6.f * 4.f); +} + +TEST(LLVM, BufferTest) { + BufHandle a("A", {32}, kFloat); + std::vector v(5); + std::vector args({v.data()}); + auto rv = IntImm::make(0); + LLVMExprEval cg(rv, {a}); + ASSERT_EQ(cg.value(args), 0); +} + +TEST(LLVM, BlockTest) { + BufHandle a("A", {32}, kInt); + std::vector v = {1, 2}; + std::vector args({v.data()}); + + auto block = Block::make({ + a.store({0}, 3), + a.store({1}, 4), + a.store({0}, 4), + }); + + LLVMCodeGen cg(block, {a}); + ASSERT_EQ(cg.value(args), 0); + ASSERT_EQ(v[0], 4); + ASSERT_EQ(v[1], 4); +} + +TEST(LLVM, LoadStoreTest) { + BufHandle a("A", {1}, kInt); + BufHandle b("B", {1}, kInt); + std::vector a_buffer = {42}; + std::vector b_buffer = {-11}; + + auto store = b.store({0}, a.load(0)); + LLVMCodeGen cg(store, {a, b}); + std::vector args({a_buffer.data(), b_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + ASSERT_EQ(a_buffer[0], 42); + ASSERT_EQ(b_buffer[0], 42); +} + +TEST(LLVM, IfThenElseTest) { + BufHandle a("A", {1}, kInt); + BufHandle b("B", {1}, kInt); + BufHandle c("C", {1}, kInt); + std::vector a_buffer = {42}; + std::vector b_buffer = {-11}; + std::vector c_buffer = {1}; + + auto store = b.store({0}, IfThenElse::make(c.load(0), a.load(0), 0)); + LLVMCodeGen cg(store, {a, b, c}); + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + ASSERT_EQ(a_buffer[0], 42); + ASSERT_EQ(b_buffer[0], 42); +} + +// if (x < 10) x = x + 1 +TEST(LLVM, CondNoFalseBlockTest) { + BufHandle x("X", {1}, kInt); + auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT); + auto cond = Cond::make(cmp, x.store({0}, x.load(0) + 1), nullptr); + + for (int32_t x_value : {0, 10, 20}) { + std::vector x_buffer = {x_value}; + std::vector args({x_buffer.data()}); + LLVMCodeGen cg(cond, {x}); + ASSERT_EQ(cg.value(args), 0); + if (x_value < 10) { + ASSERT_EQ(x_buffer[0], x_value + 1); + } else { + ASSERT_EQ(x_buffer[0], x_value); + } + } +} + +// if (x < 10) { +// x = x + 1; +// } else { +// x = x - 1; +// } +TEST(LLVM, CondTest) { + BufHandle x("X", {1}, kInt); + auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT); + auto cond = + Cond::make(cmp, x.store({0}, x.load(0) + 1), x.store({0}, x.load(0) - 1)); + auto block = Block::make({ + cond, + x.store({0}, x.load(0) * 2), + }); + + for (int32_t x_value : {0, 10, 20}) { + std::vector x_buffer = {x_value}; + std::vector args({x_buffer.data()}); + LLVMCodeGen cg(block, {x}); + ASSERT_EQ(cg.value(args), 0); + if (x_value < 10) { + ASSERT_EQ(x_buffer[0], (x_value + 1) * 2); + } else { + ASSERT_EQ(x_buffer[0], (x_value - 1) * 2); + } + } +} + +// if (x < 10) { +// if (x > 5) { +// x = x + 1; +// } else { +// x = x - 1; +// } +// } else { +// if (x <= 15) { +// x = x + 2; +// } else { +// x = x - 2; +// } +// } +TEST(LLVM, CondNestedTest) { + BufHandle x("X", {1}, kInt); + auto true_cmp = + CompareSelect::make(x.load(0), 5, CompareSelectOperation::kGT); + auto true_cond = Cond::make( + true_cmp, x.store({0}, x.load(0) + 1), x.store({0}, x.load(0) - 1)); + auto false_cmp = + CompareSelect::make(x.load(0), 15, CompareSelectOperation::kLE); + auto false_cond = Cond::make( + false_cmp, x.store({0}, x.load(0) + 2), x.store({0}, x.load(0) - 2)); + auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT); + auto cond = Cond::make(cmp, true_cond, false_cond); + + for (int32_t x_value : {0, 8, 15, 20}) { + std::vector x_buffer = {x_value}; + std::vector args({x_buffer.data()}); + LLVMCodeGen cg(cond, {x}); + ASSERT_EQ(cg.value(args), 0); + if (x_value < 10) { + if (x_value > 5) { + ASSERT_EQ(x_buffer[0], x_value + 1); + } else { + ASSERT_EQ(x_buffer[0], x_value - 1); + } + } else { + if (x_value <= 15) { + ASSERT_EQ(x_buffer[0], x_value + 2); + } else { + ASSERT_EQ(x_buffer[0], x_value - 2); + } + } + } +} + +TEST(LLVM, DirectVectorization) { + constexpr int M = 3; + constexpr int N = 64; + BufHandle a("a", {M, N}, kFloat); + BufHandle b("b", {M, N}, kFloat); + BufHandle c("c", {M, N}, kFloat); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + StmtPtr s = For::make( + m, + 0, + M, + Store::make( + c, + {Ramp::make(m * 64, 1, 64)}, + Load::make({kFloat, 64}, a, {Ramp::make(m * 64, 1, 64)}) * + Load::make({kFloat, 64}, b, {Ramp::make(m * 64, 1, 64)}))); + LLVMCodeGen cg(s, {a, b, c}); +} + +TEST(LLVM, VecLoadStoreTest) { + BufHandle a("A", {1}, kInt); + BufHandle b("B", {1}, kInt); + std::vector a_buffer = {1, 1, 1, 1}; + std::vector b_buffer = {2, 2, 2, 2}; + + auto store = b.store({Ramp::make(0, 1, 4)}, a.load({Ramp::make(0, 1, 4)})); + LLVMCodeGen cg(store, {a, b}); + std::vector args({a_buffer.data(), b_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + ASSERT_EQ(a_buffer[0], 1); + ASSERT_EQ(a_buffer[1], 1); + ASSERT_EQ(a_buffer[2], 1); + ASSERT_EQ(a_buffer[3], 1); + ASSERT_EQ(b_buffer[0], 1); + ASSERT_EQ(b_buffer[1], 1); + ASSERT_EQ(b_buffer[2], 1); + ASSERT_EQ(b_buffer[3], 1); +} + +#define FLOAT_INTRINSICS_TEST(Name, Lanes) \ + TEST(LLVM, VecFloat_##Name##Lane##Lanes##Test) { \ + BufHandle a("A", {1}, kFloat); \ + BufHandle b("B", {1}, kFloat); \ + float val = 0.5f; \ + std::vector a_buffer(Lanes, val); \ + std::vector b_buffer(Lanes, val); \ + auto store = b.store( \ + {Ramp::make(0, 1, Lanes)}, Name(a.load({Ramp::make(0, 1, Lanes)}))); \ + LLVMCodeGen cg(store, {a, b}); \ + std::vector args({a_buffer.data(), b_buffer.data()}); \ + ASSERT_EQ(cg.value(args), 0); \ + for (const auto i : c10::irange(Lanes)) { \ + ASSERT_FLOAT_EQ(a_buffer[i], val); \ + } \ + } // namespace jit +FLOAT_INTRINSICS_TEST(erf, 4) +FLOAT_INTRINSICS_TEST(erfc, 4) +FLOAT_INTRINSICS_TEST(acos, 4) +FLOAT_INTRINSICS_TEST(asin, 4) +FLOAT_INTRINSICS_TEST(atan, 4) +FLOAT_INTRINSICS_TEST(cosh, 4) +FLOAT_INTRINSICS_TEST(sinh, 4) +FLOAT_INTRINSICS_TEST(tanh, 4) +FLOAT_INTRINSICS_TEST(expm1, 4) +FLOAT_INTRINSICS_TEST(lgamma, 4) +FLOAT_INTRINSICS_TEST(erf, 8) +FLOAT_INTRINSICS_TEST(erfc, 8) +FLOAT_INTRINSICS_TEST(acos, 8) +FLOAT_INTRINSICS_TEST(asin, 8) +FLOAT_INTRINSICS_TEST(atan, 8) +FLOAT_INTRINSICS_TEST(cosh, 8) +FLOAT_INTRINSICS_TEST(sinh, 8) +FLOAT_INTRINSICS_TEST(tanh, 8) +FLOAT_INTRINSICS_TEST(expm1, 8) +FLOAT_INTRINSICS_TEST(lgamma, 8) +#undef FLOAT_INTRINSICS_TEST + +#define DOUBLE_INTRINSICS_TEST(Name, Lanes) \ + TEST(LLVM, VecDouble_##Name##Lane##Lanes##Test) { \ + BufHandle a("A", {1}, kDouble); \ + BufHandle b("B", {1}, kDouble); \ + float val = 0.5f; \ + std::vector a_buffer(Lanes, val); \ + std::vector b_buffer(Lanes, val); \ + auto store = b.store( \ + {Ramp::make(0, 1, Lanes)}, Name(a.load({Ramp::make(0, 1, Lanes)}))); \ + LLVMCodeGen cg(store, {a, b}); \ + std::vector args({a_buffer.data(), b_buffer.data()}); \ + ASSERT_EQ(cg.value(args), 0); \ + for (const auto i : c10::irange(Lanes)) { \ + ASSERT_FLOAT_EQ(a_buffer[i], val); \ + } \ + } // namespace jit +DOUBLE_INTRINSICS_TEST(erf, 2) +DOUBLE_INTRINSICS_TEST(erfc, 2) +DOUBLE_INTRINSICS_TEST(acos, 2) +DOUBLE_INTRINSICS_TEST(asin, 2) +DOUBLE_INTRINSICS_TEST(atan, 2) +DOUBLE_INTRINSICS_TEST(cosh, 2) +DOUBLE_INTRINSICS_TEST(sinh, 2) +DOUBLE_INTRINSICS_TEST(tanh, 2) +DOUBLE_INTRINSICS_TEST(expm1, 2) +DOUBLE_INTRINSICS_TEST(lgamma, 2) +DOUBLE_INTRINSICS_TEST(erf, 4) +DOUBLE_INTRINSICS_TEST(erfc, 4) +DOUBLE_INTRINSICS_TEST(acos, 4) +DOUBLE_INTRINSICS_TEST(asin, 4) +DOUBLE_INTRINSICS_TEST(atan, 4) +DOUBLE_INTRINSICS_TEST(cosh, 4) +DOUBLE_INTRINSICS_TEST(sinh, 4) +DOUBLE_INTRINSICS_TEST(tanh, 4) +DOUBLE_INTRINSICS_TEST(expm1, 4) +DOUBLE_INTRINSICS_TEST(lgamma, 4) +#undef DOUBLE_INTRINSICS_TEST + +TEST(LLVM, VectorizerLoadStoreTest) { + BufHandle a("A", {1}, kInt); + + Tensor c = Compute("c", {4}, [&](const VarHandle& i) { return a.load(i); }); + + BufHandle c_buf(c.buf()); + LoopNest l({c}); + StmtPtr s = l.root_stmt(); + ASSERT_TRUE(LoopNest::vectorize(to(to(s)->front()))); + + ASSERT_TRUE(to(to(s)->front()) == nullptr); + + LLVMCodeGen cg(s, {a, c_buf}); + + std::vector a_vec(4, 21); + std::vector c_vec(4, 0); + std::vector args({a_vec.data(), c_vec.data()}); + ASSERT_EQ(cg.value(args), 0); + assertAllEqual(c_vec, 21); +} + +TEST(LLVM, VectorizeBitCast) { + BufHandle a("A", {128}, kInt); + + Tensor c = Compute("c", {128}, [&](const VarHandle& i) { + return bitcast(a.load(i)); + }); + + BufHandle c_buf(c.buf()); + LoopNest l({c}); + StmtPtr s = l.root_stmt(); + ASSERT_TRUE(LoopNest::vectorize(to(to(s)->front()))); + ASSERT_TRUE(to(to(s)->front()) == nullptr); + + LLVMCodeGen cg(s, {a, c_buf}); + + std::vector a_vec(128); + std::vector c_vec(128); + for (const auto i : c10::irange(128)) { + a_vec[i] = raw_bitcast(1337.f); + } + std::vector args({a_vec.data(), c_vec.data()}); + ASSERT_EQ(cg.value(args), 0); + assertAllEqual(c_vec, 1337.f); +} + +TEST(LLVM, MemcpyTest) { + constexpr int N = 32; + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + std::vector a_buffer(N, 42); + std::vector b_buffer(N, 0); + + VarHandle i("i", kInt); + auto expr = For::make(i, 0, N, b.store({i}, a.load(i))); + + LLVMCodeGen cg(expr, {a, b}); + + std::vector args({a_buffer.data(), b_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + assertAllEqual(a_buffer, 42); + assertAllEqual(b_buffer, 42); +} + +TEST(LLVM, BzeroTest) { + constexpr int N = 32; + BufHandle b("B", {N}, kInt); + std::vector b_buffer(N, 11); + + VarHandle i("i", kInt); + auto expr = For::make(i, 0, N, b.store({i}, 0)); + + LLVMCodeGen cg(expr, {b}); + + std::vector args({b_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(b_buffer.size(), N); + assertAllEqual(b_buffer, 0); +} + +TEST(LLVM, ElemwiseAdd) { + constexpr int N = 1024; + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + BufHandle c("C", {N}, kInt); + std::vector a_buffer(N, 41); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 1); + + VarHandle i("i", kInt); + auto expr = For::make(i, 0, N, c.store({i}, Add::make(a.load(i), b.load(i)))); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + assertAllEqual(a_buffer, 41); + assertAllEqual(b_buffer, 1); + assertAllEqual(c_buffer, 42); +} + +TEST(LLVM, ElemwiseAddFloat) { + constexpr int N = 1024; + BufHandle a("A", {N}, kFloat); + BufHandle b("B", {N}, kFloat); + BufHandle c("C", {N}, kFloat); + std::vector a_buffer(N, 41); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 1); + + VarHandle i("i", kInt); + auto expr = For::make(i, 0, N, c.store({i}, a.load(i) + b.load(i))); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + assertAllEqual(a_buffer, 41.0f); + assertAllEqual(b_buffer, 1.0f); + assertAllEqual(c_buffer, 42.0f); +} + +TEST(LLVM, ElemwiseLog10Float) { + constexpr int N = 1024; + BufHandle a("A", {N}, kFloat); + BufHandle b("B", {N}, kFloat); + std::vector a_buffer(N, 10.0f); + std::vector b_buffer(N, 2.0f); + + VarHandle i("i", kInt); + auto expr = For::make( + i, + 0, + N / 4, + b.store( + {Ramp::make(i * 4, 1, 4)}, log10(a.load({Ramp::make(i * 4, 1, 4)})))); + + LLVMCodeGen cg(expr, {a, b}); + + std::vector args({a_buffer.data(), b_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + assertAllEqual(a_buffer, 10.0f); + assertAllEqual(b_buffer, 1.0f); +} + +TEST(LLVM, ElemwiseLog1pFloat) { + constexpr int N = 1024; + BufHandle a("A", {N}, kFloat); + BufHandle b("B", {N}, kFloat); + std::vector a_buffer(N, expf(3.0f) - 1); + std::vector b_buffer(N, 42.0f); + + VarHandle i("i", kInt); + auto expr = For::make( + i, + 0, + N / 4, + b.store( + {Ramp::make(i * 4, 1, 4)}, log1p(a.load({Ramp::make(i * 4, 1, 4)})))); + + LLVMCodeGen cg(expr, {a, b}); + + std::vector args({a_buffer.data(), b_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + assertAllEqual(a_buffer, expf(3.0f) - 1); + ExpectAllNear(b_buffer, 3.0f, 1e-5f); +} + +TEST(LLVM, ElemwiseMaxInt) { + constexpr int N = 1024; + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + BufHandle c("C", {N}, kInt); + std::vector a_buffer(N, 41); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 1); + + VarHandle i("i", kInt); + auto expr = + For::make(i, 0, N, c.store({i}, Max::make(a.load(i), b.load(i), false))); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + assertAllEqual(a_buffer, 41); + assertAllEqual(b_buffer, 1); + assertAllEqual(c_buffer, 41); +} + +TEST(LLVM, ElemwiseMinInt) { + constexpr int N = 1024; + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + BufHandle c("C", {N}, kInt); + std::vector a_buffer(N, 41); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 1); + + VarHandle i("i", kInt); + auto expr = + For::make(i, 0, N, c.store({i}, Min::make(a.load(i), b.load(i), false))); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + assertAllEqual(a_buffer, 41); + assertAllEqual(b_buffer, 1); + assertAllEqual(c_buffer, 1); +} + +TEST(LLVM, ElemwiseMaxFloat) { + constexpr int N = 1024; + BufHandle a("A", {N}, kFloat); + BufHandle b("B", {N}, kFloat); + BufHandle c("C", {N}, kFloat); + std::vector a_buffer(N, 41); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 1); + + VarHandle i("i", kInt); + auto expr = + For::make(i, 0, N, c.store({i}, Max::make(a.load(i), b.load(i), false))); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + assertAllEqual(a_buffer, 41.0f); + assertAllEqual(b_buffer, 1.0f); + assertAllEqual(c_buffer, 41.0f); +} + +TEST(LLVM, ElemwiseMaxNaNFloat) { + constexpr int N = 1024; + BufHandle a("A", {N}, kFloat); + BufHandle b("B", {N}, kFloat); + BufHandle c("C", {N}, kFloat); + std::vector a_buffer(N, NAN); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 1); + + VarHandle i("i", kInt); + auto expr = + For::make(i, 0, N, c.store({i}, Max::make(a.load(i), b.load(i), false))); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + assertAllEqual(b_buffer, 1.0f); + for (auto const& elt : c_buffer) { + ASSERT_TRUE(std::isnan(elt)); + } +} + +TEST(LLVM, ElemwiseMinFloat) { + constexpr int N = 1024; + BufHandle a("A", {N}, kFloat); + BufHandle b("B", {N}, kFloat); + BufHandle c("C", {N}, kFloat); + std::vector a_buffer(N, 41); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 1); + + VarHandle i("i", kInt); + auto expr = + For::make(i, 0, N, c.store({i}, Min::make(a.load(i), b.load(i), false))); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + assertAllEqual(a_buffer, 41.0f); + assertAllEqual(b_buffer, 1.0f); + assertAllEqual(c_buffer, 1.0f); +} + +TEST(LLVM, ElemwiseMinNaNFloat) { + constexpr int N = 1024; + BufHandle a("A", {N}, kFloat); + BufHandle b("B", {N}, kFloat); + BufHandle c("C", {N}, kFloat); + std::vector a_buffer(N, NAN); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 1); + + VarHandle i("i", kInt); + auto expr = + For::make(i, 0, N, c.store({i}, Min::make(a.load(i), b.load(i), false))); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + assertAllEqual(b_buffer, 1.0f); + for (auto const& elt : c_buffer) { + ASSERT_TRUE(std::isnan(elt)); + } +} + +TEST(LLVM, ElemwiseMod) { + constexpr int N = 1024; + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + BufHandle c("C", {N}, kInt); + std::vector a_buffer(N, 41); + std::vector b_buffer(N, 23); + std::vector c_buffer(N, 18); + + VarHandle i("i", kInt); + auto expr = For::make(i, 0, N, c.store({i}, Mod::make(a.load(i), b.load(i)))); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + assertAllEqual(a_buffer, 41); + assertAllEqual(b_buffer, 23); + assertAllEqual(c_buffer, 18); +} + +TEST(LLVM, CompareSelectIntEQ) { + constexpr int N = 1024; + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + BufHandle c("C", {N}, kInt); + std::vector a_buffer(N, 1); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 0); + std::vector c_ref(N, 1); + + for (int i = 0; i < N / 2; i++) { + b_buffer[i] = 0; + c_ref[i] = 0; + } + + VarHandle i("i", kInt); + auto expr = For::make( + i, + 0, + N, + c.store( + {i}, + CompareSelect::make( + a.load(i), b.load(i), CompareSelectOperation::kEQ))); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + + assertAllEqual(a_buffer, 1); + for (const auto i : c10::irange(N)) { + ASSERT_EQ(c_ref[i], c_buffer[i]); + } +} + +TEST(LLVM, CompareSelectFloatEQ) { + constexpr int N = 1024; + BufHandle a("A", {N}, kFloat); + BufHandle b("B", {N}, kFloat); + BufHandle c("C", {N}, kInt); + std::vector a_buffer(N, 1.0f); + std::vector b_buffer(N, 1.0f); + std::vector c_buffer(N, 0); + + VarHandle i("i", kInt); + auto expr = For::make( + i, + 0, + N, + c.store( + {i}, + CompareSelect::make( + a.load(i), b.load(i), CompareSelectOperation::kEQ))); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + + assertAllEqual(a_buffer, 1.0f); + assertAllEqual(b_buffer, 1.0f); + assertAllEqual(c_buffer, 1); +} + +TEST(LLVM, CompareSelectByteGT) { + constexpr int N = 1024; + BufHandle a("A", {N}, kByte); + BufHandle b("B", {N}, kByte); + BufHandle c("C", {N}, kInt); + std::vector a_buffer(N, 0); + std::vector b_buffer(N, 0); + std::vector c_buffer(N, 0); + std::vector c_ref(N, 0); + + for (int i = 0; i < N / 2; i++) { + a_buffer[i] = 128; + c_ref[i] = 1; + } + + VarHandle i("i", kInt); + auto expr = For::make( + i, + 0, + N, + c.store( + {i}, + CompareSelect::make( + a.load(i), b.load(i), CompareSelectOperation::kGT))); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + + assertAllEqual(b_buffer, uint8_t(0)); + for (const auto i : c10::irange(N)) { + ASSERT_EQ(c_ref[i], c_buffer[i]); + } +} + +TEST(LLVM, CompareSelectByteGE) { + constexpr int N = 1024; + BufHandle a("A", {N}, kByte); + BufHandle b("B", {N}, kByte); + BufHandle c("C", {N}, kInt); + std::vector a_buffer(N, 0); + std::vector b_buffer(N, 0); + std::vector c_buffer(N, 0); + std::vector c_ref(N, 1); + + VarHandle i("i", kInt); + auto expr = For::make( + i, + 0, + N, + c.store( + {i}, + CompareSelect::make( + a.load(i), b.load(i), CompareSelectOperation::kGE))); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + + assertAllEqual(b_buffer, uint8_t(0)); + for (const auto i : c10::irange(N)) { + ASSERT_EQ(c_ref[i], c_buffer[i]); + } +} + +TEST(LLVM, CompareSelectByteLT) { + constexpr int N = 1024; + BufHandle a("A", {N}, kByte); + BufHandle b("B", {N}, kByte); + BufHandle c("C", {N}, kInt); + std::vector a_buffer(N, 0); + std::vector b_buffer(N, 128); + std::vector c_buffer(N, 0); + std::vector c_ref(N, 1); + + for (int i = 0; i < N / 2; i++) { + a_buffer[i] = 128; + c_ref[i] = 0; + } + + VarHandle i("i", kInt); + auto expr = For::make( + i, + 0, + N, + c.store( + {i}, + CompareSelect::make( + a.load(i), b.load(i), CompareSelectOperation::kLT))); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + + assertAllEqual(b_buffer, uint8_t(128)); + for (const auto i : c10::irange(N)) { + ASSERT_EQ(c_ref[i], c_buffer[i]); + } +} + +TEST(LLVM, CompareSelectByteLE) { + constexpr int N = 1024; + BufHandle a("A", {N}, kByte); + BufHandle b("B", {N}, kByte); + BufHandle c("C", {N}, kInt); + std::vector a_buffer(N, 0); + std::vector b_buffer(N, 128); + std::vector c_buffer(N, 0); + std::vector c_ref(N, 1); + + VarHandle i("i", kInt); + auto expr = For::make( + i, + 0, + N, + c.store( + {i}, + CompareSelect::make( + a.load(i), b.load(i), CompareSelectOperation::kLE))); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + + assertAllEqual(b_buffer, uint8_t(128)); + for (const auto i : c10::irange(N)) { + ASSERT_EQ(c_ref[i], c_buffer[i]); + } +} + +TEST(LLVM, StoreFloat) { + BufHandle result("result", {1}, kFloat); + std::vector result_buffer = {0.0f}; + auto expr = result.store({0}, FloatImm::make(3.14f)); + LLVMCodeGen cg(expr, {result}); + std::vector args({result_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + ASSERT_EQ(result_buffer[0], 3.14f); +} + +TEST(LLVM, SimpleMath01) { + const int N = 1024; + Tensor tensor = Compute( + "f", {N}, [](const VarHandle& i) { return cast(i * i + 1); }); + LoopNest l({tensor}); + StmtPtr stmt = l.root_stmt(); + BufHandle f_buf(tensor.buf()); + LLVMCodeGen cg(stmt, {f_buf}); + + PaddedBuffer f_v(N, "f_v"); + std::vector args({f_v.data()}); + int value = cg.value(args); + ASSERT_EQ(value, 0); + PaddedBuffer f_ref(N, "f_ref"); + for (const auto i : c10::irange(N)) { + f_ref(i) = i * i + 1; + } + ExpectAllNear(f_v, f_ref, 1e-5); +} + +TEST(LLVM, ComputeMul) { + const int N = 1024; + BufHandle a("a", {N}, kFloat); + BufHandle b("b", {N}, kFloat); + Tensor c = Compute( + "c", {N}, [&](const VarHandle& i) { return a.load(i) * b.load(i); }); + + BufHandle c_buf(c.buf()); + LoopNest l({c}); + StmtPtr s = l.root_stmt(); + + LLVMCodeGen cg(s, {a, b, c_buf}); + + std::vector a_vec(N, 21.0f); + std::vector b_vec(N, 2.0f); + std::vector c_vec(N, 0.0f); + std::vector args({a_vec.data(), b_vec.data(), c_vec.data()}); + ASSERT_EQ(cg.value(args), 0); + assertAllEqual(c_vec, 42.0f); +} + +TEST(LLVM, BroadcastAdd) { + const int M = 32; + const int N = 1024; + BufHandle a("a", {M, N}, kFloat); + BufHandle b("b", {N}, kFloat); + Tensor c = Compute("c", {M, N}, [&](const VarHandle& i, const VarHandle& j) { + return a.load(i, j) + b.load(j); + }); + + BufHandle c_buf(c.buf()); + LoopNest l({c}); + l.prepareForCodegen(); + StmtPtr s = l.root_stmt(); + + LLVMCodeGen cg(s, {a, b, c_buf}); + + std::vector av(M * N); + std::iota(av.begin(), av.end(), 0); + std::vector bv(N); + std::iota(bv.begin(), bv.end(), 0); + std::vector cv(M * N, 0); + std::vector args({av.data(), bv.data(), cv.data()}); + ASSERT_EQ(cg.value(args), 0); + + for (const auto i : c10::irange(M)) { + for (const auto j : c10::irange(N)) { + ASSERT_EQ(cv[i * N + j], av[i * N + j] + bv[j]); + } + } +} + +TEST(LLVM, BitwiseOps) { + auto a = IntImm::make(59); + auto b = IntImm::make(11); + auto c = IntImm::make(101); + auto d = IntImm::make(2); + + ExprHandle f = (((a ^ (b << 1)) & c) >> 2) | d; + LLVMExprEval cg(f); + + ASSERT_EQ(cg.value(), 11); +} + +TEST(LLVM, ArithmeticRightShift) { + auto a = CharImm::make(-4); + auto b = CharImm::make(1); + ExprHandle f = a >> b; + LLVMExprEval cg(f); + ASSERT_EQ(cg.value(), -2); +} + +TEST(LLVM, LogicalRightShift) { + auto a = ByteImm::make(0xfc); + auto b = ByteImm::make(1); + ExprHandle f = a >> b; + LLVMExprEval cg(f); + ASSERT_EQ(cg.value(), 0x7e); +} + +TEST(LLVM, DynamicShapeAdd) { + auto testWithSize = [](int32_t size) { + VarHandle n("n", kInt); + BufHandle a("a", {n}, kFloat); + BufHandle b("b", {n}, kFloat); + BufHandle c("c", {n}, kFloat); + VarHandle i("i", kInt); + StmtPtr s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i))); + std::vector aData(size, 1.0f); + std::vector bData(size, 2.0f); + std::vector cData(size, 0.0f); + LLVMCodeGen cg(s, {a, b, c, n}); + std::vector args({aData.data(), bData.data(), cData.data(), &size}); + cg.value(args); + ExpectAllNear(cData, std::vector(size, 3.0f), 1e-7); + }; + testWithSize(1); + testWithSize(16); + testWithSize(37); +} + +TEST(LLVM, BindDynamicShapeAdd) { + auto testWithSize = [](int32_t size) { + VarHandle n("n", kInt); + BufHandle a("a", {n}, kFloat); + BufHandle b("b", {n}, kFloat); + BufHandle c("c", {n}, kFloat); + VarHandle i("i", kInt); + StmtPtr s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i))); + std::vector aData(size, 1.0f); + std::vector bData(size, 2.0f); + std::vector cData(size, 0.0f); + LLVMCodeGen cg(s, {a, b, c, n}); + cg.call({aData, bData, cData, size}); + ExpectAllNear(cData, std::vector(size, 3.0f), 1e-7); + }; + testWithSize(1); + testWithSize(16); + testWithSize(37); +} + +TEST(LLVM, TensorDynamicShapeAdd) { + auto testWithSize = [](int32_t size) { + VarHandle n("n", kInt); + BufHandle a("a", {n}, kFloat); + BufHandle b("b", {n}, kFloat); + Tensor c = Compute( + "c", {n}, [&](const VarHandle& i) { return a.load(i) + b.load(i); }); + LoopNest l({c}); + StmtPtr s = l.root_stmt(); + LLVMCodeGen cg(s, {a, b, c, n}); + std::vector aData(size, 1.0f); + std::vector bData(size, 2.0f); + std::vector cData(size, 0.0f); + cg.call({aData, bData, cData, size}); + ExpectAllNear(cData, std::vector(size, 3.0f), 1e-7); + }; + testWithSize(1); + testWithSize(16); + testWithSize(37); +} + +TEST(LLVM, DynamicShape2D) { + auto testWithSize = [](int32_t M, int32_t N) { + VarHandle m("m", kInt); + VarHandle n("n", kInt); + BufHandle a("a", {m, n}, kFloat); + BufHandle b("b", {m, n}, kFloat); + Tensor c = + Compute("c", {m, n}, [&](const VarHandle& i, const VarHandle& j) { + return a.load(i, j) + b.load(i, j); + }); + LoopNest l({c}); + l.prepareForCodegen(); + StmtPtr s = l.root_stmt(); + LLVMCodeGen cg(s, {a, b, c, m, n}); + std::vector aData(M * N, 1.0f); + std::vector bData(M * N, 2.0f); + std::vector cData(M * N, 0.0f); + cg.call({aData, bData, cData, M, N}); + ExpectAllNear(cData, std::vector(M * N, 3.0f), 1e-7); + }; + testWithSize(1, 8); + testWithSize(16, 32); + testWithSize(37, 11); +} + +TEST(LLVM, EmptyStmt) { + StmtPtr s = alloc(std::vector({})); + + LLVMCodeGen cg(s, {}); + cg.call({}); + // Just don't crash. +} + +TEST(LLVM, EliminatedStmt) { + BufHandle a("a", {1}, kFloat); + + Tensor c = Compute("c", {0}, [&](const VarHandle& m) { return m; }); + + LoopNest l({c}); + l.prepareForCodegen(); + StmtPtr s = l.root_stmt(); + s = IRSimplifier::simplify(s); + LLVMCodeGen cg(s, {a, c}); + std::vector aData(1, 1.0f); + std::vector cData(0, 0.0f); + cg.call({aData, cData}); +} + +TEST(LLVM, SimpleReduction) { + int M = 128; + int N = 64; + + BufHandle a("a", {1, M, N}, kFloat); + + Tensor b = Reduce("sum", {1}, Sum(), a, {M, N}); + LoopNest loop({b}); + + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + LLVMCodeGen cg(s, {a, b}); + + PaddedBuffer a_v(1, M, N, "a_v"); + PaddedBuffer b_v(1, "b_v"); + PaddedBuffer b_ref(1, "b_ref"); + + b_ref(0) = 0; + for (const auto i : c10::irange(M)) { + for (const auto j : c10::irange(N)) { + int v = i + j; + a_v(0, i, j) = v; + b_ref(0) += v; + } + } + + cg.call({a_v, b_v}); + + ExpectAllNear(b_v, b_ref, 1e-5); +} + +TEST(LLVM, RFactorReduction) { + int M = 128; + int N = 64; + + BufHandle a("a", {1, M, N}, kFloat); + + Tensor b = Reduce("sum", {1}, Sum(), a, {M, N}); + LoopNest loop({b}); + + std::vector loops = loop.getLoopStmtsFor(b); + ForPtr loop_m = loops.at(1); + ForPtr loop_n = loops.at(2); + loop.reorderAxis(loop_m, loop_n); + + loops = loop.getLoopStmtsFor(b); + loop_m = loops.at(2); + loop_n = loops.at(1); + auto b_body = loop.getAllWritesToBuf(b.buf())[1]; + ASSERT_TRUE(loop.rfactor(b_body, loop_n)); + + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + LLVMCodeGen cg(s, {a, b}); + + PaddedBuffer a_v(1, M, N, "a_v"); + PaddedBuffer b_v(1, "b_v"); + PaddedBuffer b_ref(1, "b_ref"); + + b_ref(0) = 0; + for (const auto i : c10::irange(M)) { + for (const auto j : c10::irange(N)) { + int v = i + j; + a_v(0, i, j) = v; + b_ref(0) += v; + } + } + + cg.call({a_v, b_v}); + + ExpectAllNear(b_v, b_ref, 1e-5); +} + +TEST(LLVM, RFactorVectorizedReduction) { + int M = 128; + int N = 64; + + BufHandle a("a", {1, M, N}, kFloat); + + Tensor b = Reduce("sum", {1}, Sum(), a, {M, N}); + LoopNest loopnest({b}); + std::vector loops = loopnest.getLoopStmtsFor(b); + // Reorder n and m loops + loopnest.reorderAxis(loops.at(1), loops.at(2)); + auto b_body = loopnest.getAllWritesToBuf(b.buf()).at(1); + auto all_loops = loopnest.getAllLoopNestsWritingToBuf(b.buf()); + ASSERT_TRUE(all_loops.size() == 2 && all_loops[1].size() == 3); + ASSERT_TRUE(loopnest.rfactor(b_body, all_loops[1][1])); + auto distributed_loops = loopnest.distributeLoop(all_loops[1][1]); + + // Vectorize initializer of rfac_buf + ASSERT_TRUE(LoopNest::vectorize(distributed_loops[0])); + // Vectorize producer of rfac_buf + ASSERT_TRUE(LoopNest::vectorize(distributed_loops[1])); + loopnest.simplify(); + + loopnest.prepareForCodegen(); + + StmtPtr s = IRSimplifier::simplify(loopnest.root_stmt()); + LLVMCodeGen cg(s, {a, b}); + + PaddedBuffer a_v(1, M, N, "a_v"); + PaddedBuffer b_v(1, "b_v"); + PaddedBuffer b_ref(1, "b_ref"); + + b_ref(0) = 0; + for (const auto i : c10::irange(M)) { + for (const auto j : c10::irange(N)) { + int v = i + j; + a_v(0, i, j) = v; + b_ref(0) += v; + } + } + + cg.call({a_v, b_v}); + + ExpectAllNear(b_v, b_ref, 1e-5); +} + +template +static void testSimpleParallel() { + // Compute a simple operation, and try all loop-axis combination to be + // parallel or sequential. + const int M = 4; + const int N = 6; + Tensor f = Compute("f", {M, N}, [](const VarHandle& m, const VarHandle& n) { + return cast(m + n); + }); + LoopNest loop_nest({f}); + auto const& loops = loop_nest.getLoopStmtsFor(f); + ForPtr m = loops[0]; + ForPtr n = loops[1]; + if (outer) { + m->set_parallel(); + } + if (inner) { + n->set_parallel(); + } + loop_nest.prepareForCodegen(); + StmtPtr stmt = loop_nest.root_stmt(); + LLVMCodeGen cg(stmt, {f}); + + PaddedBuffer f_v(M, N, "f_v"); + std::vector args({f_v.data()}); + int value = cg.value(args); + ASSERT_EQ(value, 0); + PaddedBuffer f_ref(M, N, "f_ref"); + for (const auto m : c10::irange(M)) { + for (const auto n : c10::irange(N)) { + f_ref(m, n) = m + n; + } + } + ExpectAllNear(f_v, f_ref, 1e-5); +} + +TEST(LLVM, SimpleParallelSS) { + testSimpleParallel(); +} +TEST(LLVM, SimpleParallelSP) { + testSimpleParallel(); +} +TEST(LLVM, SimpleParallelPS) { + testSimpleParallel(); +} +TEST(LLVM, SimpleParallelPP) { + testSimpleParallel(); +} + +TEST(LLVM, CompositeParallel) { + int loop_count = 6; + int test_count = 1 << loop_count; + // Compute a composite operation, and try all loop-axis combination to be + // parallel or sequential. + for (const auto test_cfg : c10::irange(test_count)) { + int M = 5; + int N = 7; + Tensor t1 = Compute("t1", {M}, [](const VarHandle& m) { return m + 1.f; }); + Tensor t2 = Compute("t2", {N}, [](const VarHandle& n) { return n + 2.f; }); + Tensor t3 = + Compute("t3", {M, N}, [=](const VarHandle& m, const VarHandle& n) { + return t1.load(m) * t2.load(n); + }); + Tensor t4 = + Compute("t4", {M, N}, [=](const VarHandle& m, const VarHandle& n) { + return t3.load(m, n) + m + n; + }); + LoopNest loop_nest({t4}, {t1, t2, t3, t4}); + std::vector loop_list; + { + auto const& loops = loop_nest.getLoopStmtsFor(t1); + loop_list.push_back(loops[0]); + } + { + auto const& loops = loop_nest.getLoopStmtsFor(t2); + loop_list.push_back(loops[0]); + } + { + auto const& loops = loop_nest.getLoopStmtsFor(t3); + loop_list.push_back(loops[0]); + loop_list.push_back(loops[1]); + } + { + auto const& loops = loop_nest.getLoopStmtsFor(t4); + loop_list.push_back(loops[0]); + loop_list.push_back(loops[1]); + } + ASSERT_EQ(loop_list.size(), loop_count); + for (const auto i : c10::irange(loop_count)) { + if (test_cfg & (1 << i)) { + loop_list[i]->set_parallel(); + } + } + loop_nest.prepareForCodegen(); + StmtPtr stmt = loop_nest.root_stmt(); + LLVMCodeGen cg(stmt, {t4}); + + PaddedBuffer t4_v(M, N, "t4_v"); + std::vector args({t4_v.data()}); + int value = cg.value(args); + ASSERT_EQ(value, 0); + PaddedBuffer t4_ref(M, N, "t4_ref"); + for (const auto m : c10::irange(M)) { + for (const auto n : c10::irange(N)) { + t4_ref(m, n) = (m + 1) * (n + 2) + m + n; + } + } + ExpectAllNear(t4_v, t4_ref, 1e-5); + } +} + +TEST(LLVM, VectorizedGEMM) { + int M = 32; + int N = 32; + int K = 48; + + BufHandle AP("A", {M, K}, kFloat); + BufHandle BP("B", {K, N}, kFloat); + Tensor CT = Reduce( + "gemm", + {M, N}, + Sum(), + [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { + return AP.load(m, k) * BP.load(k, n); + }, + {K}); + LoopNest loop({CT}); + + { + auto const& loops = loop.getLoopStmtsFor(CT); + ForPtr m = loops[0]; + loop.splitWithMask(m, 16); + } + { + auto const& loops = loop.getLoopStmtsFor(CT); + ForPtr n = loops[2]; + loop.splitWithMask(n, 16); + } + // mo, mi, no, ni, k -> + // mo, no, mi, ni, k + { + auto const& loops = loop.getLoopStmtsFor(CT); + ForPtr mi = loops[1]; + ForPtr no = loops[2]; + loop.reorderAxis(mi, no); + } + // mo, no, mi, ni, k -> + // mo, no, mi, k, ni + { + auto const& loops = loop.getLoopStmtsFor(CT); + ForPtr ni = loops[3]; + ForPtr k = loops[4]; + loop.reorderAxis(ni, k); + } + // mo, no, mi, k, ni -> + // mo, no, k, mi, ni + { + auto const& loops = loop.getLoopStmtsFor(CT); + ForPtr mi = loops[2]; + ForPtr k = loops[3]; + loop.reorderAxis(mi, k); + } + { + auto loops = NodeFinder::find(loop.root_stmt()); + ASSERT_TRUE(LoopNest::vectorize(loops[3])); + ASSERT_TRUE(LoopNest::vectorize(loops.back())); + } + + loop.prepareForCodegen(); + + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + LLVMCodeGen cg(s, {AP, BP, CT}); + + PaddedBuffer a_v(M, K, "a_v"); + PaddedBuffer b_v(K, N, "b_v"); + PaddedBuffer c_v(M, N, "c_v"); + PaddedBuffer c_ref(M, N, "c_ref"); + + for (const auto m : c10::irange(M)) { + for (const auto n : c10::irange(N)) { + c_ref(m, n) = 0.f; + for (const auto k : c10::irange(K)) { + c_ref(m, n) += a_v(m, k) * b_v(k, n); + } + } + } + + cg.call({a_v, b_v, c_v}); + + ExpectAllNear(c_v, c_ref, 1e-5); +} + +TEST(LLVM, CallRaw) { + const int M = 32; + VarHandle N("N", kInt); + BufHandle a("a", {M, N}, kFloat); + BufHandle b("b", {N}, kFloat); + Tensor c = Compute("c", {M, N}, [&](const VarHandle& i, const VarHandle& j) { + return a.load(i, j) + b.load(j); + }); + + LoopNest l({c}); + l.prepareForCodegen(); + StmtPtr s = l.root_stmt(); + + int32_t N_value = 1024; + std::vector av(M * N_value); + std::iota(av.begin(), av.end(), 0); + std::vector bv(N_value); + std::iota(bv.begin(), bv.end(), 0); + std::vector cv(M * N_value, 0); + std::vector args({av.data(), bv.data(), cv.data(), &N_value}); + + LLVMCodeGen cg(s, {a, b, BufHandle(c.buf()), N}); + cg.call_raw(args); + + for (const auto i : c10::irange(M)) { + for (const auto j : c10::irange(N_value)) { + ASSERT_EQ(cv[i * N_value + j], av[i * N_value + j] + bv[j]); + } + } + + SimpleIREvaluator eval(s, {a, b, BufHandle(c.buf()), N}); + eval.call_raw(args); + + for (const auto i : c10::irange(M)) { + for (const auto j : c10::irange(N_value)) { + ASSERT_EQ(cv[i * N_value + j], av[i * N_value + j] + bv[j]); + } + } +} + +TEST(LLVM, CustomTarget) { + constexpr int M = 16; + BufHandle a("a", {M}, kFloat); + BufHandle b("b", {M}, kFloat); + BufHandle c("c", {M}, kFloat); + Tensor d = Compute("d", {M}, [&](const VarHandle& m) { + return a.load(m) * b.load(m) + c.load(m); + }); + LoopNest nest({d}); + nest.prepareForCodegen(); + auto cg = LLVMCodeGenBuilder(nest.root_stmt(), {a, b, c, d}) + .triple("i686-elf") + .cpu("i386") + .build(); + std::ostringstream ss; + ss << cg->getCodeText("asm"); + torch::jit::testing::FileCheck() + .check("fadds") + ->check("fmuls") + ->check_not("vfmadd") + ->run(ss.str()); +} + +TEST(LLVM, CodeGenKernelFuncName) { + BufHandle a("A", {1}, kInt); + BufHandle b("B", {1}, kInt); + std::vector a_buffer = {42}; + std::vector b_buffer = {-11}; + auto store = b.store({0}, a.load(0)); + + { + LLVMCodeGen cg(store, {a, b}); + // Check that the kernel function name used by LLVMCodeGen + // is not empty. + ASSERT_NE(cg.kernel_func_name(), ""); + } + + { + LLVMCodeGen cg(store, {a, b}, at::kCPU, "new_func"); + // Check that the kernel function name used by LLVMCodeGen + // is the one that was given above. + ASSERT_EQ(cg.kernel_func_name(), "new_func"); + } +} + +} // namespace jit +} // namespace torch + +#endif // TORCH_ENABLE_LLVM diff --git a/test/cpp/tensorexpr/test_loopnest.cpp b/test/cpp/tensorexpr/test_loopnest.cpp new file mode 100644 index 0000000000000..a8bda8814dbae --- /dev/null +++ b/test/cpp/tensorexpr/test_loopnest.cpp @@ -0,0 +1,6894 @@ +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +using namespace torch::jit::tensorexpr; + +void checkIR(StmtPtr s, const std::string& pattern) { + std::ostringstream oss; + oss << *s; + torch::jit::testing::FileCheck().run(pattern, oss.str()); +} + +void checkExprIR(ExprPtr e, const std::string& pattern) { + std::string prefixed_pattern = "# CHECK: " + pattern + "\n"; + std::ostringstream oss; + oss << *e << "\n"; + torch::jit::testing::FileCheck().run(prefixed_pattern, oss.str()); +} + +void checkExprIR(const ExprHandle& e, const std::string& pattern) { + checkExprIR(e.node(), pattern); +} + +TEST(LoopNest, ExprSimple01) { + Tensor tensor = + Compute("f", {16, 5}, [](const VarHandle& x, const VarHandle& y) { + return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; + }); + LoopNest l({tensor}); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + + LoopNest::splitWithTail(loops[0], 2); + LoopNest::splitWithTail(loops[0], 2); +} + +TEST(LoopNest, ExprLower01) { + Tensor tensor = + Compute("f", {16, 5}, [](const VarHandle& x, const VarHandle& y) { + return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; + }); + LoopNest l({tensor}); + StmtPtr stmt = l.root_stmt(); + std::ostringstream oss; + oss << *stmt; + ASSERT_GT(oss.str().size(), 20); + ASSERT_LT(oss.str().size(), 200); +} + +TEST(LoopNest, ExprSimple02) { + auto func = [](const ExprHandle& x, const ExprHandle& y) { + return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; + }; + Tensor tensor = Compute("f", {26, 5}, func); + LoopNest l({tensor}); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + + LoopNest::splitWithTail(loops[0], 4); + + StmtPtr stmt = l.root_stmt(); + std::ostringstream oss; + oss << *stmt; + ASSERT_GT(oss.str().size(), 200); + ASSERT_LT(oss.str().size(), 600); + + { + // Compare to a reference loop structure structure. + VarHandle x_outer("i_outer", kInt); + VarHandle x_inner("i_inner", kInt); + VarHandle y("i", kInt); + VarHandle x_tail("i_tail", kInt); + BufHandle f("f", {26, 5}, kFloat); + ExprHandle x_1 = x_outer * 4 + x_inner; + ExprHandle x_outer_end = (ExprHandle(26) - 0) / 4; + ForPtr stmt1 = For::make( + x_outer, + 0, + x_outer_end, + For::make( + x_inner, + 0, + 4, + For::make(y, 0, 5, Store::make(f, {x_1, y}, func(x_1, y))))); + ExprHandle x_2 = x_tail + x_outer_end * 4; + ForPtr stmt2 = For::make( + x_tail, + 0, + (ExprHandle(26) - 0) % 4, + For::make(y, 0, 5, Store::make(f, {x_2, y}, func(x_2, y)))); + StmtPtr stmt = Block::make({stmt1, stmt2}); + + std::ostringstream oss_ref; + oss_ref << *stmt; + ASSERT_EQ(oss.str(), oss_ref.str()); + } + + { + PaddedBuffer f_v(26, 5, "f_v"); + PaddedBuffer f_ref(26, 5, "f_res"); + + stmt = FlattenIndexes(stmt); + SimpleIREvaluator ir_eval(stmt, {tensor}); + ir_eval(f_v); + + for (int x = 0; x < 26; x++) { + for (int y = 0; y < 5; y++) { + f_ref(x, y) = 1 + x * x + y * y; + } + } + + ExpectAllNear(f_v, f_ref, 1e-5); + } +} + +BlockPtr getSimplifiedBody(const LoopNest& l) { + StmtPtr stmt = l.root_stmt(); + StmtPtr simplified = IRSimplifier::simplify(stmt); + return to(simplified); +} + +void assertForRange(ForPtr f, int expected_start, int expected_stop) { + ASSERT_NE(f, nullptr); + IntImmPtr start = to(f->start()); + ASSERT_NE(start, nullptr); + ASSERT_EQ(start->value(), expected_start); + IntImmPtr stop = to(f->stop()); + ASSERT_NE(stop, nullptr); + ASSERT_EQ(stop->value(), expected_stop); +} + +void assertForRanges( + BlockPtr body, + const std::vector>& start_stops) { + ASSERT_EQ(body->nstmts(), start_stops.size()); + + auto it = body->begin(); + for (size_t i = 0; i < start_stops.size(); i++, it++) { + ForPtr loop = to(*it); + assertForRange(loop, start_stops[i].first, start_stops[i].second); + } +} + +TEST(LoopNest, ExprSliceHeadWithLoopOptions) { + auto func = [](const ExprHandle& x) { + return ExprHandle(1.0f) + cast(x); + }; + Tensor tensor = Compute("f", {10}, func); + LoopNest l({tensor}); + ForPtr head; + ForPtr tail; + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + loops[0]->set_gpu_block_index(LoopOptions::IDX_Y); + LoopNest::sliceHead(loops[0], 2, &head, &tail); + + BlockPtr body = getSimplifiedBody(l); + assertForRanges(body, {{0, 2}, {0, 8}}); + + ASSERT_TRUE(tail->loop_options().is_gpu_block_index()); + ASSERT_EQ(tail->loop_options().gpu_block_index(), LoopOptions::IDX_Y); + + ASSERT_TRUE(head->loop_options().isDefault()); +} + +TEST(LoopNest, ExprSliceTailWithLoopOptions) { + auto func = [](const ExprHandle& x) { + return ExprHandle(1.0f) + cast(x); + }; + Tensor tensor = Compute("f", {10}, func); + LoopNest l({tensor}); + ForPtr head; + ForPtr tail; + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::sliceTail(loops[0], 4, &head, &tail); + + ForPtr tail_head; + ForPtr tail_tail; + tail->set_gpu_block_index(LoopOptions::IDX_Y); + LoopNest::sliceTail(tail, 2, &tail_head, &tail_tail); + + BlockPtr body = getSimplifiedBody(l); + assertForRanges(body, {{0, 6}, {0, 2}, {8, 10}}); + + ASSERT_TRUE(tail_head->loop_options().is_gpu_block_index()); + ASSERT_EQ(tail_head->loop_options().gpu_block_index(), LoopOptions::IDX_Y); + + ASSERT_TRUE(head->loop_options().isDefault()); + ASSERT_TRUE(tail_tail->loop_options().isDefault()); +} + +TEST(LoopNest, ExprSliceHeadWhenFactorEqualsSize) { + // When factor equals the For loop's original size, keep using the original + // For loop. + auto func = [](const ExprHandle& x) { + return ExprHandle(1.0f) + cast(x); + }; + Tensor tensor = Compute("f", {10}, func); + LoopNest l({tensor}); + ForPtr head; + ForPtr tail; + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::sliceHead(loops[0], 10, &head, &tail); + + ASSERT_EQ(head, loops[0]); + ASSERT_EQ(tail, nullptr); + + BlockPtr body = getSimplifiedBody(l); + assertForRanges(body, {{0, 10}}); +} + +TEST(LoopNest, ExprSliceHeadWhenFactorLargerThanSize) { + auto func = [](const ExprHandle& x) { + return ExprHandle(1.0f) + cast(x); + }; + Tensor tensor = Compute("f", {10}, func); + LoopNest l({tensor}); + ForPtr head; + ForPtr tail; + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::sliceHead(loops[0], 100, &head, &tail); + + ASSERT_EQ(head, loops[0]); + ASSERT_EQ(tail, nullptr); + + BlockPtr body = getSimplifiedBody(l); + assertForRanges(body, {{0, 10}}); +} + +TEST(LoopNest, ExprSliceHead) { + auto func = [](const ExprHandle& x) { + return ExprHandle(1.0f) + cast(x); + }; + Tensor tensor = Compute("f", {10}, func); + LoopNest l({tensor}); + ForPtr head; + ForPtr tail; + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::sliceHead(loops[0], 4, &head, &tail); + + ASSERT_NE(head, nullptr); + ASSERT_NE(head, loops[0]); + ASSERT_NE(tail, nullptr); + ASSERT_EQ(tail, loops[0]); + + BlockPtr body = getSimplifiedBody(l); + assertForRanges(body, {{0, 4}, {4, 10}}); +} + +TEST(LoopNest, ExprSliceHeadWithNonZeroStart) { + auto func = [](const ExprHandle& x) { + return ExprHandle(1.0f) + cast(x); + }; + Tensor tensor = Compute("f", {10}, func); + LoopNest l({tensor}); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + + ForPtr head; + ForPtr tail; + LoopNest::sliceTail(loops[0], 4, &head, &tail); + // head: [0, 6) + // tail: [6, 10) + + LoopNest::sliceHead(tail, 2); + // tail_head: [6, 8) + // tail_tail: [8, 10) + + BlockPtr body = getSimplifiedBody(l); + assertForRanges(body, {{0, 6}, {6, 8}, {8, 10}}); +} + +TEST(LoopNest, ExprSliceTailWhenFactorEqualsSize) { + // When factor equals the For loop's original size, keep using the original + // For loop. + auto func = [](const ExprHandle& x) { + return ExprHandle(1.0f) + cast(x); + }; + Tensor tensor = Compute("f", {10}, func); + LoopNest l({tensor}); + ForPtr head; + ForPtr tail; + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::sliceTail(loops[0], 10, &head, &tail); + + ASSERT_EQ(head, nullptr); + ASSERT_EQ(tail, loops[0]); + + BlockPtr body = getSimplifiedBody(l); + assertForRanges(body, {{0, 10}}); +} + +TEST(LoopNest, ExprSliceTailWhenFactorLargerThanSize) { + // When factor equals the For loop's original size, keep using the original + // For loop. + auto func = [](const ExprHandle& x) { + return ExprHandle(1.0f) + cast(x); + }; + Tensor tensor = Compute("f", {10}, func); + LoopNest l({tensor}); + ForPtr head; + ForPtr tail; + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::sliceTail(loops[0], 100, &head, &tail); + + ASSERT_EQ(head, nullptr); + ASSERT_EQ(tail, loops[0]); + + BlockPtr body = getSimplifiedBody(l); + assertForRanges(body, {{0, 10}}); +} + +TEST(LoopNest, ExprSliceTail) { + auto func = [](const ExprHandle& x) { + return ExprHandle(1.0f) + cast(x); + }; + Tensor tensor = Compute("f", {10}, func); + LoopNest l({tensor}); + ForPtr head; + ForPtr tail; + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::sliceTail(loops[0], 4, &head, &tail); + + ASSERT_NE(head, nullptr); + ASSERT_EQ(head, loops[0]); + ASSERT_NE(tail, nullptr); + ASSERT_NE(tail, loops[0]); + + BlockPtr body = getSimplifiedBody(l); + assertForRanges(body, {{0, 6}, {6, 10}}); +} + +TEST(LoopNest, ExprSplitAndSlice) { + // 0: splitWithTail + // 1: sliceTail on inner loop + // 2: sliceHead on outer loop + auto func = [](const ExprHandle& x) { + return ExprHandle(1.0f) + cast(x); + }; + Tensor tensor = Compute("f", {100}, func); + LoopNest l({tensor}); + + ForPtr inner; + ForPtr tail; + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + // outer: [0, 4) + // inner: [0, 21) + // tail: [84, 100) + LoopNest::splitWithTail(loops[0], 21, &inner, &tail); + LoopNest::sliceTail(inner, 2); + LoopNest::sliceHead(loops[0], 2); + + // for (int x_outer = 0; x_outer < 2; x_outer++) { + // for (int x_inner = 0; x_inner < 19; x_inner++) { + // f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner); + // } + // for (int x_inner = 19; x_inner < 21; x_inner++) { + // f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner); + // } + // } + // for (int x_outer = 2; x_outer < 4; x_outer++) { + // for (int x_inner = 0; x_inner < 19; x_inner++) { + // f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner); + // } + // for (int x_inner = 19; x_inner < 21; x_inner++) { + // f[21 * x_outer + x_inner] = 1.f + float(21 * x_outer + x_inner); + // } + // } + // for (int x_tail = 0; x_tail < 16; x_tail++) { + // f[x_tail + 84] = 1.f + float(x_tail + 84); + // } + BlockPtr body = getSimplifiedBody(l); + assertForRanges(body, {{0, 2}, {2, 4}, {0, 16}}); + + auto biter = body->begin(); + + ForPtr loop = to(*biter++); + assertForRanges(loop->body(), {{0, 19}, {19, 21}}); + + loop = to(*biter); + assertForRanges(loop->body(), {{0, 19}, {19, 21}}); +} + +TEST(LoopNest, ExprSliceAndNormalize) { + // 0: sliceHead + // 1: normalize tail + auto func = [](const ExprHandle& x) { + return ExprHandle(1.0f) + cast(x); + }; + Tensor tensor = Compute("f", {10}, func); + LoopNest l({tensor}); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + + ForPtr head; + ForPtr tail; + LoopNest::sliceHead(loops[0], 2, &head, &tail); + // head: [0, 2) + // tail: [2, 10) + + LoopNest::normalize(tail); + // normalized_tail: [0, 8) + + BlockPtr body = getSimplifiedBody(l); + assertForRanges(body, {{0, 2}, {0, 8}}); +} + +template +T evalExpr(const ExprHandle& expr, const VarHandle& var, T value) { + ExprEval eval(expr, {var}); + return eval.value(value); +} + +TEST(LoopNest, ExprSliceWithVariableDimension) { + auto testWithDimension = + [](int dimension, + const std::vector>& expected_for_ranges) { + VarHandle dim("dim", kInt); + Tensor tensor = + Compute("f", {dim}, [](const ExprHandle& x) { return x; }); + LoopNest l({tensor}); + std::vector loops = + l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + + ForPtr head; + ForPtr tail; + LoopNest::sliceHead(loops[0], 2, &head, &tail); + + LoopNest::sliceTail(tail, 2); + + BlockPtr body = getSimplifiedBody(l); + ASSERT_EQ(expected_for_ranges.size(), 3); + auto it = body->begin(); + for (auto& start_stop : expected_for_ranges) { + ForPtr loop = to(*it++); + int start = evalExpr(ExprHandle(loop->start()), dim, dimension); + int stop = evalExpr(ExprHandle(loop->stop()), dim, dimension); + ASSERT_EQ(start, start_stop.first); + ASSERT_EQ(stop, start_stop.second); + } + }; + + testWithDimension(1, {{0, 1}, {1, 1}, {1, 1}}); + testWithDimension(2, {{0, 2}, {2, 2}, {2, 2}}); + testWithDimension(3, {{0, 2}, {2, 2}, {2, 3}}); + testWithDimension(4, {{0, 2}, {2, 2}, {2, 4}}); + testWithDimension(5, {{0, 2}, {2, 3}, {3, 5}}); + testWithDimension(10, {{0, 2}, {2, 8}, {8, 10}}); +} + +TEST(LoopNest, ExprSplitWithTail) { + auto func = [](const ExprHandle& x) { + return ExprHandle(1.0f) + cast(x); + }; + Tensor tensor = Compute("f", {199}, func); + LoopNest l({tensor}); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + LoopNest::splitWithTail(loops[0], 17); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + LoopNest::splitWithTail(loops[0], 7); + + StmtPtr stmt = l.root_stmt(); + StmtPtr simplified = IRSimplifier::simplify(stmt); + BlockPtr body = to(simplified); + ASSERT_EQ(body->nstmts(), 3); + auto biter = body->begin(); + + // Verify that the split loops are ordered correctly. + ForPtr loop = to(*biter++); + assertForRange(loop, 0, 7); + + loop = to(*biter++); + assertForRange(loop, 0, 4); + + loop = to(*biter); + assertForRange(loop, 0, 12); +} + +TEST(LoopNest, ExprSplitWithTailNone) { + auto func = [](const ExprHandle& x, const ExprHandle& y) { + return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; + }; + Tensor tensor = Compute("f", {24, 5}, func); + LoopNest l({tensor}); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::splitWithTail(loops[0], 4); + + StmtPtr stmt = l.root_stmt(); + std::ostringstream oss; + oss << *stmt; + ASSERT_GT(oss.str().size(), 200); + ASSERT_LT(oss.str().size(), 600); + + { + // Compare to a reference loop structure structure. + VarHandle x_outer("i_outer", kInt); + VarHandle x_inner("i_inner", kInt); + VarHandle y("i", kInt); + VarHandle x_tail("i_tail", kInt); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers) + BufHandle f("f", {24, 5}, kFloat); + ExprHandle x_1 = x_outer * 4 + x_inner; + ExprHandle x_outer_end = (ExprHandle(24) - 0) / 4; + StmtPtr stmt = alloc(std::vector({For::make( + x_outer, + 0, + x_outer_end, + For::make( + x_inner, + 0, + 4, + For::make(y, 0, 5, Store::make(f, {x_1, y}, func(x_1, y)))))})); + + std::ostringstream oss_ref; + oss_ref << *stmt; + ASSERT_EQ(oss.str(), oss_ref.str()); + } + + { + PaddedBuffer f_v(24, 5, "f_v"); + PaddedBuffer f_ref(24, 5, "f_res"); + + SimpleIREvaluator ir_eval(stmt, {tensor}); + ir_eval(f_v); + + for (int x = 0; x < 24; x++) { + for (int y = 0; y < 5; y++) { + f_ref(x, y) = 1 + x * x + y * y; + } + } + + ExpectAllNear(f_v, f_ref, 1e-5); + } +} + +TEST(LoopNest, ExprSplitWithMask01) { + const int M = 26; + const int N = 5; + BufHandle a_buf("a", {M, N}, kFloat); + BufHandle b_buf("b", {M, N}, kFloat); + Tensor tensor = + Compute("f", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return a_buf.load(m, n) + b_buf.load(m, n) + 1.0f; + }); + + LoopNest l({tensor}); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::splitWithMask(loops[1], 4); + + StmtPtr stmt = l.root_stmt(); + + PaddedBuffer a_v(M, N, "a"); + PaddedBuffer b_v(M, N, "b"); + PaddedBuffer c_v(M, N, "c"); + PaddedBuffer c_ref(M, N, "c_ref"); + for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { + a_v(m, n) = 2 * m; + b_v(m, n) = 3 * n; + c_ref(m, n) = a_v(m, n) + b_v(m, n) + 1.0f; + } + } + + SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v); + + ExpectAllNear(c_v, c_ref, 1e-5); +} + +// Tests the case where we split a loop cleanly multiple times, we should not +// insert any masks. +TEST(LoopNest, ExprSplitWithMaskRepeatedNoMask) { + const int M = 64; + BufHandle a_buf("a", {M}, kFloat); + BufHandle b_buf("b", {M}, kFloat); + Tensor tensor = Compute("f", {M}, [&](const ExprHandle& m) { + return a_buf.load(m) + b_buf.load(m) + 1.0f; + }); + + LoopNest l({tensor}); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::splitWithMask(loops[0], 4); + LoopNest::splitWithMask(loops[0], 4); + + StmtPtr stmt1 = IRSimplifier::simplify(l.root_stmt()); + + // Two splits mean 3 loops, but should need no masks in this case. + checkIR(stmt1, R"IR( +# CHECK: for ( +# CHECK-NOT: if ( +# CHECK: for ( +# CHECK-NOT: if ( +# CHECK: for ( +# CHECK-NOT: if ( +# CHECK: f[)IR"); +} + +TEST(LoopNest, getLoopAt) { + // Input IR: + // for (int i = 0; i < 100; i++) { + // for (int j = 0; j < 100; j++) { + // A[i, j] = sin(i * j); + // for (int k1 = 0; k1 < 200; k1++) { + // B[i, j, k1] = (A[i, j]) / (k1 + 1); + // } + // for (int k2 = 0; k2 < 300; k2++) { + // C[i, j, k2] = (A[i, j]) * (k2 + 1); + // } + // } + // } + BufPtr A = alloc( + "A", + std::vector({alloc(100), alloc(100)}), + kInt); + BufPtr B = alloc( + "B", + std::vector( + {alloc(100), alloc(100), alloc(200)}), + kInt); + BufPtr C = alloc( + "C", + std::vector( + {alloc(100), alloc(100), alloc(300)}), + kInt); + BufHandle a_buf(A); + BufHandle b_buf(B); + BufHandle c_buf(C); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k1("k1", kInt); + VarHandle k2("k2", kInt); + auto store1 = Store::make(a_buf, {i, j}, sin(i * j)); + auto store2 = Store::make( + b_buf, {i, j, k1}, Div::make(Load::make(a_buf, {i, j}), (k1 + 1))); + auto store3 = Store::make( + c_buf, {i, j, k2}, Mul::make(Load::make(a_buf, {i, j}), (k2 + 1))); + auto for_k2 = For::make(k2, 0, 300, Block::make({store3})); + auto for_k1 = For::make(k1, 0, 200, Block::make({store2})); + auto for_j = For::make(j, 0, 100, Block::make({store1, for_k1, for_k2})); + auto for_i = For::make(i, 0, 100, for_j); + LoopNest l(Block::make({for_i}), {B, C}); + auto ret_k2 = l.getLoopAt(for_i, {0, 2}); + TORCH_CHECK(ret_k2 == for_k2); + + std::ostringstream oss; + oss << *ret_k2; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int k2 +# CHECK-NEXT: C[i, j, k2] = + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(LoopNest, TileSimple) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + const int M = 64, N = 64; + BufHandle a_buf("a", {M, N}, kFloat); + BufHandle b_buf("b", {M, N}, kFloat); + Tensor tensor = + Compute("f", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return a_buf.load({m, n}) + b_buf.load({m, n}) + 1.0f; + }); + + LoopNest l({tensor}); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + l.tile(loops[0], loops[1], 4, 8); + + // IR check + StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); + checkIR(stmt, R"IR( +# CHECK: for (int i_outer +# CHECK: for (int i_outer_1 +# CHECK: for (int i_inner +# CHECK: for (int i_inner_1 +# CHECK: f[ +# CHECK-NOT: for (int i_tail +# CHECK-NOT: for (int i_tail)IR"); + + // Correctness check + PaddedBuffer a_v(M, N, "a"); + PaddedBuffer b_v(M, N, "b"); + PaddedBuffer c_v(M, N, "c"); + PaddedBuffer c_ref(M, N, "c_ref"); + for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { + a_v(m, n) = 2 * m; + b_v(m, n) = 3 * n; + c_ref(m, n) = a_v(m, n) + b_v(m, n) + 1.0f; + } + } + + SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v); + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + ExpectAllNear(c_v, c_ref, 1e-5); +} + +TEST(LoopNest, TileWithTails) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + const int M = 64, N = 64; + BufHandle a_buf("a", {M, N}, kFloat); + BufHandle b_buf("b", {M, N}, kFloat); + Tensor tensor = + Compute("f", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return a_buf.load({m, n}) + b_buf.load({m, n}) + 1.0f; + }); + + LoopNest l({tensor}); + std::vector loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + l.tile(loops[0], loops[1], 5, 9); + + // IR check + StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); + checkIR(stmt, R"IR( +# CHECK: for (int i_outer +# CHECK: for (int i_outer_1 +# CHECK: for (int i_inner +# CHECK: for (int i_inner_1 +# CHECK: f[ +# CHECK: for (int i_inner +# CHECK: f[ +# CHECK: for (int i_tail)IR"); + + // Correctness check + PaddedBuffer a_v(M, N, "a"); + PaddedBuffer b_v(M, N, "b"); + PaddedBuffer c_v(M, N, "c"); + PaddedBuffer c_ref(M, N, "c_ref"); + for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { + a_v(m, n) = 2 * m; + b_v(m, n) = 3 * n; + c_ref(m, n) = a_v(m, n) + b_v(m, n) + 1.0f; + } + } + + SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v); + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + ExpectAllNear(c_v, c_ref, 1e-5); +} + +TEST(LoopNest, TileInMiddle) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + const int M = 8, N = 8, L = 8, K = 8; + BufHandle a_buf("a", {M, N, L, K}, kFloat); + BufHandle b_buf("b", {M, N, L, K}, kFloat); + Tensor tensor = Compute( + "f", + {M, N, L, K}, + [&](const ExprHandle& m, + const ExprHandle& n, + const ExprHandle& l, + const ExprHandle& k) { + return a_buf.load({m, n, l, k}) + b_buf.load({m, n, l, k}) + 1.0f; + }); + + LoopNest nest({tensor}); + std::vector loops = + nest.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + nest.tile(loops[1], loops[2], 3, 3); + + // IR check + StmtPtr stmt = IRSimplifier::simplify(nest.root_stmt()); + checkIR(stmt, R"IR( +# CHECK: for (int i +# CHECK: for (int i_outer +# CHECK: for (int i_outer_1 +# CHECK: for (int i_inner +# CHECK: for (int i_inner_1 +# CHECK: for (int i_1 +# CHECK: f[ +# CHECK: for (int i_tail_1 +# CHECK: for (int i_inner_1 +# CHECK: for (int i_1 +# CHECK: f[ +# CHECK: for (int i_tail)IR"); + + // Correctness check + PaddedBuffer a_v(M, N, L, K, "a"); + PaddedBuffer b_v(M, N, L, K, "b"); + PaddedBuffer c_v(M, N, L, K, "c"); + PaddedBuffer c_ref(M, N, L, K, "c_ref"); + for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { + for (int l = 0; l < L; l++) { + for (int k = 0; k < K; k++) { + a_v(m, n, l, k) = 2 * (m + l); + b_v(m, n, l, k) = 3 * (n + k); + c_ref(m, n, l, k) = a_v(m, n, l, k) + b_v(m, n, l, k) + 1.0f; + } + } + } + } + + SimpleIREvaluator(stmt, {a_buf, b_buf, tensor})(a_v, b_v, c_v); + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + ExpectAllNear(c_v, c_ref, 1e-5); +} + +TEST(LoopNest, SplitWithTailWithLoopOptions) { + const int M = 21; + BufHandle a_buf("a", {M}, kFloat); + BufHandle b_buf("b", {M}, kFloat); + Tensor tensor = Compute("f", {M}, [&](const ExprHandle& m) { + return a_buf.load(m) + b_buf.load(m) + 1.0f; + }); + ForPtr inner, tail; + + LoopNest l({tensor}); + auto loops = NodeFinder::find(l.root_stmt()); + ASSERT_GT(loops.size(), 0); + loops[0]->set_gpu_block_index(LoopOptions::IDX_Y); + LoopNest::splitWithTail(loops[0], 4, &inner, &tail); + ASSERT_NE(inner, nullptr); + ASSERT_NE(tail, nullptr); + ForPtr outer = loops[0]; + + // Outer loop carries loop axis bindings. + ASSERT_TRUE(outer->loop_options().is_gpu_block_index()); + ASSERT_EQ(outer->loop_options().gpu_block_index(), LoopOptions::IDX_Y); + + // Inner loop has none. + ASSERT_TRUE(inner->loop_options().isDefault()); + + // Tail loop has none. + ASSERT_TRUE(tail->loop_options().isDefault()); +} + +TEST(LoopNest, SplitWithMaskWithLoopOptions) { + const int M = 21; + BufHandle a_buf("a", {M}, kFloat); + BufHandle b_buf("b", {M}, kFloat); + Tensor tensor = Compute("f", {M}, [&](const ExprHandle& m) { + return a_buf.load(m) + b_buf.load(m) + 1.0f; + }); + ForPtr inner; + + LoopNest l({tensor}); + auto loops = NodeFinder::find(l.root_stmt()); + loops[0]->set_gpu_block_index(LoopOptions::IDX_Y); + LoopNest::splitWithMask(loops[0], 4, &inner); + ForPtr outer = loops[0]; + + // Outer loop carries loop axis bindings. + ASSERT_TRUE(outer->loop_options().is_gpu_block_index()); + ASSERT_EQ(outer->loop_options().gpu_block_index(), LoopOptions::IDX_Y); + + // Inner loop has none. + ASSERT_TRUE(inner->loop_options().isDefault()); +} + +TEST(LoopNest, ScheduleBroadcastAddBuffer) { + const int M = 4; + const int N = 5; + const int K = 6; + BufHandle a_buf("a", {M, N}, kFloat); + BufHandle b_buf("b", {N, K}, kFloat); + Tensor c = Compute( + "broadcast_add", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf.load(m, n) + b_buf.load(n, k); + }); + LoopNest l({c}); + StmtPtr stmt = l.root_stmt(); + + PaddedBuffer a_v(M, N, "a_v"); + for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { + a_v(m, n) = 7 * m * n; + } + } + a_v.Backup(); + + PaddedBuffer b_v(N, K, "b_v"); + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k++) { + b_v(n, k) = 11 * n * k; + } + } + b_v.Backup(); + + PaddedBuffer c_v(M, N, K, "c_buf"); + SimpleIREvaluator ir_eval(stmt, {a_buf, b_buf, c}); + ir_eval(a_v, b_v, c_v); + + a_v.CheckBackup(); + b_v.CheckBackup(); + PaddedBuffer c_ref(M, N, K, "c_ref"); + for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k++) { + c_ref(m, n, k) = 7 * m * n + 11 * n * k; + } + } + } + ExpectAllNear(c_v, c_ref, 1e-5); +} + +TEST(LoopNest, ScheduleFunctionCall01) { + const int M = 4; + const int N = 5; + const int K = 6; + BufHandle a_buf("a", {M, N}, kFloat); + BufHandle b_buf("b", {N, K}, kFloat); + Tensor c = Compute( + "broadcast_add", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf.load(m, n) + b_buf.load(n, k); + }); + Tensor d = Compute( + "d", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return c.load(m, n, k) + 1; + }); + + LoopNest l({d}, {c, d}); + l.prepareForCodegen(); + StmtPtr stmt = l.root_stmt(); + std::ostringstream oss; + oss << *stmt; + ASSERT_GT(oss.str().size(), 100); + + PaddedBuffer a_v(M, N); + PaddedBuffer b_v(N, K); + PaddedBuffer c_v(M, N, K); + PaddedBuffer d_v(M, N, K); + PaddedBuffer d_ref(M, N, K); + + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + a_v(i, j) = i * i; + } + } + for (int i = 0; i < N; i++) { + for (int j = 0; j < K; j++) { + b_v(i, j) = j * j; + } + } + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + for (int k = 0; k < K; k++) { + d_ref(i, j, k) = a_v(i, j) + b_v(j, k) + 1; + } + } + } + + SimpleIREvaluator eval(stmt, {a_buf, b_buf, d}); + eval(a_v, b_v, d_v); + + ExpectAllNear(d_v, d_ref, 1e-5); +} + +TEST(LoopNest, ScheduleInlineSimple) { + const int M = 4; + const int N = 5; + const int K = 6; + BufHandle a_buf("a", {M, N}, kFloat); + BufHandle b_buf("b", {N, K}, kFloat); + BufHandle c_buf("c", {M, N}, kFloat); + BufHandle d_buf("d", {M, K}, kFloat); + + Tensor x = Compute( + "x", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf.load(m, n) * b_buf.load(n, k); + }); + Tensor y = Compute( + "y", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return c_buf.load(m, n) * d_buf.load(m, k) + x.load(m, n, k); + }); + + LoopNest l1({y}, {x, y}); + LoopNest l2(l1); + l2.computeInline(x.buf()); + + l1.prepareForCodegen(); + l2.prepareForCodegen(); + + StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); + StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt()); + + SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, c_buf, d_buf, y}); + SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, c_buf, d_buf, y}); + + PaddedBuffer a_v(M, N); + PaddedBuffer b_v(N, K); + PaddedBuffer c_v(M, N); + PaddedBuffer d_v(M, K); + + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + a_v(i, j) = i * i; + } + } + for (int i = 0; i < N; i++) { + for (int j = 0; j < K; j++) { + b_v(i, j) = j * j; + } + } + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + c_v(i, j) = i + j; + } + } + for (int i = 0; i < M; i++) { + for (int j = 0; j < K; j++) { + d_v(i, j) = i * j; + } + } + + PaddedBuffer y_1(M, N, K); + PaddedBuffer y_2(M, N, K); + + eval1(a_v, b_v, c_v, d_v, y_1); + eval2(a_v, b_v, c_v, d_v, y_2); + ExpectAllNear(y_1, y_2, 1e-5); + std::ostringstream oss1, oss2; + oss1 << *stmt1; + oss2 << *stmt2; + ASSERT_GT(oss1.str().size(), oss2.str().size()); +} + +static std::string remove_space(const std::string& str) { + std::string str_new = str; + str_new.erase( + remove_if(str_new.begin(), str_new.end(), isspace), str_new.end()); + return str_new; +} + +void InlineFunc01Helper(const std::vector& inline_order) { + const int M = 4; + const int N = 5; + const int K = 6; + BufHandle a_buf("a", {M, N}, kFloat); + BufHandle b_buf("b", {N, K}, kFloat); + BufHandle c_buf("c", {M, N}, kFloat); + BufHandle d_buf("d", {M, K}, kFloat); + + Tensor x = Compute( + "x", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf.load(m, n) * b_buf.load(n, k); + }); + Tensor y = Compute( + "y", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return c_buf.load(m, n) * d_buf.load(m, k) + x.load(m, n, k); + }); + Tensor z = Compute( + "z", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return x.load(m, n, k) + y.load(m, n, k); + }); + + LoopNest l({z}, {x, y, z}); + for (const std::string& order : inline_order) { + if (order == "x") { + l.computeInline(x.buf()); + } else if (order == "y") { + l.computeInline(y.buf()); + } else { + throw std::runtime_error("Invalid order: " + order); + } + } + l.prepareForCodegen(); + StmtPtr stmt = l.root_stmt(); + + std::ostringstream oss; + oss << *stmt; + std::string str1 = remove_space(oss.str()); + + { + PaddedBuffer a_v(M, N); + PaddedBuffer b_v(N, K); + PaddedBuffer c_v(M, N); + PaddedBuffer d_v(M, K); + + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + a_v(i, j) = i * i; + } + } + for (int i = 0; i < N; i++) { + for (int j = 0; j < K; j++) { + b_v(i, j) = j * j; + } + } + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + c_v(i, j) = i + j; + } + } + for (int i = 0; i < M; i++) { + for (int j = 0; j < K; j++) { + d_v(i, j) = i * j; + } + } + + PaddedBuffer z_v(M, N, K); + PaddedBuffer z_ref(M, N, K); + for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k++) { + z_ref(m, n, k) = a_v(m, n) * b_v(n, k) * 2 + c_v(m, n) * d_v(m, k); + } + } + } + + SimpleIREvaluator eval(stmt, {a_buf, b_buf, c_buf, d_buf, z}); + eval(a_v, b_v, c_v, d_v, z_v); + ExpectAllNear(z_v, z_ref, 1e-5); + } + + if (inline_order.size() == 2) { + Tensor z2 = Compute( + "z", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf.load(m, n) * b_buf.load(n, k) + + (c_buf.load(m, n) * d_buf.load(m, k) + + a_buf.load(m, n) * b_buf.load(n, k)); + }); + LoopNest l2({z2}); + l2.prepareForCodegen(); + StmtPtr stmt2 = l2.root_stmt(); + + std::ostringstream oss2; + oss2 << *stmt2; + std::string str2 = remove_space(oss2.str()); + + ASSERT_EQ(str1, str2); + ASSERT_GT(str1.size(), 100); + } +} + +TEST(LoopNest, ScheduleInlineFunc01) { + InlineFunc01Helper({"x", "y"}); + InlineFunc01Helper({"y", "x"}); + InlineFunc01Helper({"x"}); + InlineFunc01Helper({"y"}); + InlineFunc01Helper({}); +} + +// Make sure we cache random vars if we should. +TEST(LoopNest, ScheduleInlineRandom) { + const int M = 4; + const int N = 5; + const int K = 6; + + Tensor x = Compute( + "x", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return Mod::make(Intrinsics::make(kRand, kInt), 5); + }); + Tensor y = Compute( + "y", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return x.load(m, n, k) + x.load(m, n, k); + }); + + LoopNest l1({y}, {x, y}); + l1.computeInline(x.buf()); + + // would normally compare results but Rand isn't implemented in the + // SimpleIREvaluator, even if we could seed it. + StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); + + // Check the IR we produced + checkIR(stmt1, R"IR( +# CHECK: for (int i = 0; i < 4; i++) +# CHECK: for (int i_1 = 0; i_1 < 5; i_1++) +# CHECK: for (int i_2 = 0; i_2 < 6; i_2++) +# CHECK: int x = rand(); +# CHECK: y[i, i_1, i_2] = 2 * (x % 5);)IR"); +} + +// Make sure we don't cache random vars that are not being inlined. +TEST(LoopNest, ScheduleInlineRandomUnrelated) { + const int M = 4; + const int N = 5; + const int K = 6; + + Tensor x = Compute( + "x", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return m * n * k; + }); + Tensor y = Compute( + "y", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return x.load(m, n, k) + Intrinsics::make(kRand, kInt) + + Intrinsics::make(kRand, kInt); + }); + + LoopNest l1({y}, {x, y}); + l1.computeInline(x.buf()); + + // would normally compare results but Rand isn't implemented in the + // SimpleIREvaluator, even if we could seed it. + StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); + + // Check the IR we produced + checkIR(stmt1, R"IR( +# CHECK: for (int i = 0; i < 4; i++) +# CHECK: for (int i_1 = 0; i_1 < 5; i_1++) +# CHECK: for (int i_2 = 0; i_2 < 6; i_2++) +# CHECK: y[i, i_1, i_2] = ((i * i_1) * i_2 + (rand())) + (rand());)IR"); +} + +// Make sure we generate the right number of random values == the dimensionality +// of the production tensor. +TEST(LoopNest, ScheduleInlineRandomLowerDimensions) { + const int M = 4; + const int N = 5; + const int K = 6; + + Tensor x = Compute("x", {M}, [&](const VarHandle& m) { + return Mod::make(Intrinsics::make(kRand, kInt), 5); + }); + Tensor y = Compute( + "y", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return x.load(m) + x.load(m); + }); + + LoopNest l1({y}, {x, y}); + l1.computeInline(x.buf()); + + // would normally compare results but Rand isn't implemented in the + // SimpleIREvaluator, even if we could seed it. + StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); + + // Check the IR we produced + checkIR(stmt1, R"IR( +# CHECK: for (int i = 0; i < 4; i++) +# CHECK: int x = rand(); +# CHECK: for (int i_1 = 0; i_1 < 5; i_1++) +# CHECK: for (int i_2 = 0; i_2 < 6; i_2++) +# CHECK: y[i, i_1, i_2] = 2 * (x % 5);)IR"); +} + +// Make sure we don't screw up intrinsics thinking they're rand. +TEST(LoopNest, ScheduleInlineIntrinsics) { + const int M = 4; + const int N = 5; + const int K = 6; + BufHandle a_buf("a", {M, N}, kFloat); + BufHandle b_buf("b", {N, K}, kFloat); + + Tensor x = Compute( + "x", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf.load(m, n) * b_buf.load(n, k); + }); + Tensor y = Compute( + "y", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return Intrinsics::make(kSqrt, x.load(m, n, k)); + }); + + PaddedBuffer a_v(M, N); + PaddedBuffer b_v(N, K); + + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + a_v(i, j) = i * i; + } + } + for (int i = 0; i < N; i++) { + for (int j = 0; j < K; j++) { + b_v(i, j) = j * j; + } + } + + LoopNest l1({y}, {x, y}); + LoopNest l2(l1); + l2.computeInline(x.buf()); + + l1.prepareForCodegen(); + l2.prepareForCodegen(); + + StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); + StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt()); + + SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y}); + SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y}); + + PaddedBuffer y_1(M, N, K); + PaddedBuffer y_2(M, N, K); + + eval1(a_v, b_v, y_1); + eval2(a_v, b_v, y_2); + ExpectAllNear(y_1, y_2, 1e-5); + std::ostringstream oss1, oss2; + oss1 << *stmt1; + oss2 << *stmt2; + ASSERT_GT(oss1.str().size(), oss2.str().size()); +} + +// Make sure we can handle rand and non-rand intrinsics. +TEST(LoopNest, ScheduleInlineRandWithIntrinsics) { + const int M = 4; + const int N = 5; + const int K = 6; + + Tensor x = Compute( + "x", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return Intrinsics::make(kRand, kFloat); + }); + Tensor y = Compute( + "y", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return Intrinsics::make(kSqrt, x.load(m, n, k)); + }); + + LoopNest l1({y}, {x, y}); + l1.computeInline(x.buf()); + + StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); + + // Check the IR we produced + checkIR(stmt1, R"IR( +# CHECK: for (int i = 0; i < 4; i++) +# CHECK: for (int i_1 = 0; i_1 < 5; i_1++) +# CHECK: for (int i_2 = 0; i_2 < 6; i_2++) +# CHECK: float x = rand(); +# CHECK: y[i, i_1, i_2] = sqrt(x);)IR"); +} + +// Split a Compute then inline it into another compute. +TEST(LoopNest, ScheduleSplitAThenInline) { + Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); + Tensor b = Compute( + "b", {2}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); + + LoopNest l({b}, {a, b}); + std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); + LoopNest::splitWithMask(loops[0], 4); + ASSERT_FALSE(l.computeInline(a.buf())); +} + +// Split a Compute then inline another Compute into it. +TEST(LoopNest, ScheduleSplitBThenInline) { + Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); + Tensor b = Compute( + "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); + + LoopNest l({b}, {a, b}); + std::vector loops = l.getAllLoopNestsWritingToBuf(b.buf()).at(0); + LoopNest::splitWithMask(loops[0], 3); + l.computeInline(a.buf()); + l.prepareForCodegen(); + StmtPtr s = IRSimplifier::simplify(l.root_stmt()); + + std::vector output(6, 0); + SimpleIREvaluator eval(s, {b}); + eval(output); + + for (int i = 0; i < 6; ++i) { + ASSERT_EQ(output[i], (i + 8) * (i + 8)); + } +} + +// Split a Compute twice then inline it. +TEST(LoopNest, ScheduleSplitTwiceThenInline) { + Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); + Tensor b = Compute( + "b", {2}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); + ForPtr i_inner; + + LoopNest l({b}, {a, b}); + std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); + LoopNest::splitWithMask(loops[0], 4, &i_inner); + LoopNest::splitWithMask(i_inner, 2); + ASSERT_FALSE(l.computeInline(a.buf())); +} + +// Inline a Compute, then split. +TEST(LoopNest, ScheduleInlineThenSplit) { + Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); + Tensor b = Compute( + "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); + + LoopNest l({b}, {a, b}); + l.computeInline(a.buf()); + + std::vector loops = NodeFinder::find(l.root_stmt()); + LoopNest::splitWithMask(loops.back(), 3); + l.prepareForCodegen(); + StmtPtr s = IRSimplifier::simplify(l.root_stmt()); + std::vector output(6, 0); + SimpleIREvaluator eval(s, {b}); + eval(output); + + for (int i = 0; i < 6; ++i) { + ASSERT_EQ(output[i], (i + 8) * (i + 8)); + } +} + +// Split a Compute, inline it, then split the result. +TEST(LoopNest, ScheduleSplitInlineThenSplit) { + Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); + Tensor b = Compute( + "b", {16}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); + + LoopNest l({b}, {a, b}); + auto loops = NodeFinder::find(l.root_stmt()); + LoopNest::splitWithMask(loops.back(), 2); + l.computeInline(a.buf()); + + loops = NodeFinder::find(l.root_stmt()); + LoopNest::splitWithMask(loops.front(), 2); + l.prepareForCodegen(); + StmtPtr s = IRSimplifier::simplify(l.root_stmt()); + std::vector output(16, 0); + SimpleIREvaluator eval(s, {b}); + eval(output); + + for (int i = 0; i < 16; ++i) { + ASSERT_EQ(output[i], (i + 8) * (i + 8)); + } +} + +// Oversplit a loop that is simplified out after inlining. +TEST(LoopNest, ScheduleSplitInlineSimplify) { + Tensor a = Compute("a", {18}, [&](const VarHandle& i) { + return ExprHandle(4) * i - ExprHandle(2) * i; + }); + Tensor b = Compute( + "b", {2}, [&](const VarHandle& j) { return a.load(j) - ExprHandle(1); }); + + LoopNest l({b}, {a, b}); + std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); + LoopNest::splitWithMask(loops[0], 4); + ASSERT_FALSE(l.computeInline(a.buf())); +} + +// Inline a Compute with two consumers. +TEST(LoopNest, ScheduleInlineThreeMixedOnce) { + Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); + Tensor b = Compute( + "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); + Tensor c = Compute("c", {4, 3}, [&](const VarHandle& k, const VarHandle& l) { + return a.load(k) * b.load(l); + }); + + LoopNest l({c}, {a, b, c}); + std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); + l.computeInline(a.buf()); + l.prepareForCodegen(); + + StmtPtr s = IRSimplifier::simplify(l.root_stmt()); + std::vector output(4 * 3, 0); + SimpleIREvaluator eval(s, {c}); + eval(output); + + for (int k = 0; k < 4; ++k) { + for (int l = 0; l < 3; ++l) { + ASSERT_EQ(output[k * 3 + l], (k) * (k) * (l + 8) * (l + 8)); + } + } +} + +// Inline Compute A into B, then inline B into C. +TEST(LoopNest, ScheduleInlineThreeMixedTwice) { + Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); + Tensor b = Compute( + "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); + Tensor c = Compute("c", {4, 3}, [&](const VarHandle& k, const VarHandle& l) { + return a.load(k) * b.load(l); + }); + + LoopNest l({c}, {a, b, c}); + std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); + l.computeInline(a.buf()); + l.computeInline(b.buf()); + l.prepareForCodegen(); + + StmtPtr s = IRSimplifier::simplify(l.root_stmt()); + std::vector output(4 * 3, 0); + SimpleIREvaluator eval(s, {c}); + eval(output); + + for (int k = 0; k < 4; ++k) { + for (int l = 0; l < 3; ++l) { + ASSERT_EQ(output[k * 3 + l], (k) * (k) * (l + 8) * (l + 8)); + } + } +} + +// Inline a Compute that is both a producer and consumer. +TEST(LoopNest, ScheduleInlineThreeMixedInner) { + Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); + Tensor b = Compute( + "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); + Tensor c = Compute("c", {4, 3}, [&](const VarHandle& k, const VarHandle& l) { + return a.load(k) * b.load(l); + }); + + LoopNest l({c}, {a, b, c}); + std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); + l.computeInline(b.buf()); + l.prepareForCodegen(); + + StmtPtr s = IRSimplifier::simplify(l.root_stmt()); + std::vector output(4 * 3, 0); + SimpleIREvaluator eval(s, {c}); + eval(output); + + for (int k = 0; k < 4; ++k) { + for (int l = 0; l < 3; ++l) { + ASSERT_EQ(output[k * 3 + l], (k) * (k) * (l + 8) * (l + 8)); + } + } +} + +// Split 3 Computes, then inline the first two into the last. +TEST(LoopNest, ScheduleInlineThreeMixedSplit) { + Tensor a = Compute("a", {18}, [&](const VarHandle& i) { return i * i; }); + Tensor b = Compute( + "b", {6}, [&](const VarHandle& j) { return a.load(j + ExprHandle(8)); }); + Tensor c = Compute("c", {4, 3}, [&](const VarHandle& k, const VarHandle& l) { + return a.load(k) * b.load(l); + }); + + LoopNest l({c}, {a, b, c}); + std::vector loops = l.getAllLoopNestsWritingToBuf(a.buf()).at(0); + LoopNest::splitWithMask(loops[0], 4); + loops = l.getAllLoopNestsWritingToBuf(b.buf()).at(0); + LoopNest::splitWithMask(loops[0], 3); + loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); + LoopNest::splitWithMask(loops[0], 2); + + ASSERT_FALSE(l.computeInline(a.buf())); +} + +// Check that inlining works for output tensors too +TEST(LoopNest, ScheduleInlineOutputTensors) { + const int M = 4; + const int N = 5; + const int K = 6; + + Tensor x = Compute( + "x", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return m * n * k; + }); + Tensor y = Compute( + "y", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return x.load(m, n, k) + m; + }); + + LoopNest l1({x, y}); + l1.computeInline(x.buf()); + + // would normally compare results but Rand isn't implemented in the + // SimpleIREvaluator, even if we could seed it. + StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); + + // Check the IR we produced + checkIR(stmt1, R"IR( +# CHECK: for (int i = 0; i < 4; i++) +# CHECK: for (int i_1 = 0; i_1 < 5; i_1++) +# CHECK: for (int i_2 = 0; i_2 < 6; i_2++) +# CHECK: x[i, i_1, i_2] = (i * i_1) * i_2; +# CHECK: for (int i_3 = 0; i_3 < 4; i_3++) +# CHECK: for (int i_4 = 0; i_4 < 5; i_4++) +# CHECK: for (int i_5 = 0; i_5 < 6; i_5++) +# CHECK: y[i_3, i_4, i_5] = i_3 + (i_3 * i_4) * i_5;)IR"); +} + +TEST(LoopNest, ScheduleInlineWithCompoundIndices) { + // Input IR: + // for (int64_t i = 0; i < 100; i++) { + // A[i*2,i] = i * 500ll; + // } + // for (int64_t j = 0; j < 100; j++) { + // B[0ll,j] = A[0, j] + j * 100ll; + // } + BufHandle a_buf("A", {20, 100}, kLong); + BufHandle b_buf("B", {20, 100}, kLong); + VarHandle i("i", kLong); + VarHandle j("j", kLong); + auto forI = For::make( + i, + 0, + 100, + Store::make(a_buf, {i * 2, i}, Mul::make(i, static_cast(500)))); + auto forJ = For::make( + j, + 0, + 100, + Store::make( + b_buf, + {static_cast(0), j}, + Add::make( + Load::make(a_buf, {static_cast(0), j}), + Mul::make(j, static_cast(100))))); + auto par = Block::make({forI, forJ}); + + LoopNest l(par, {b_buf.node()}); + // Inlining should fail since the producer has compound expr as index. + ASSERT_FALSE(l.computeInline(a_buf.node())); + + // The input statement must remain as is. + checkIR(l.root_stmt(), R"IR( + # CHECK: for (int64_t i = 0; + # CHECK-NEXT: A[ + # CHECK: for (int64_t j = 0; + # CHECK-NEXT: B[)IR"); +} + +TEST(LoopNest, ScheduleInlineConsumerIndicesWithCast) { + // Input IR: + // for (int64_t i = 0; i < 100; i++) { + // A[0ll,i] = i * 500ll; + // } + // for (int64_t j = 0; j < 100; j++) { + // B[0ll,j] = A[(int64_t)0, j] + j * 100ll; + // } + BufHandle a_buf("A", {20, 100}, kLong); + BufHandle b_buf("B", {20, 100}, kLong); + VarHandle i("i", kLong); + VarHandle j("j", kLong); + auto forI = For::make( + i, + 0, + 100, + Store::make( + a_buf, + {static_cast(0), i}, + Mul::make(i, static_cast(500)))); + auto forJ = For::make( + j, + 0, + 100, + Store::make( + b_buf, + {static_cast(0), j}, + Add::make( + Load::make(a_buf, {0, j}), + Mul::make(j, static_cast(100))))); + auto par = Block::make({forI, forJ}); + + LoopNest l(par, {b_buf.node()}); + ASSERT_TRUE(l.computeInline(a_buf.node())); + + checkIR(l.root_stmt(), R"IR( + # CHECK: for (int64_t j = 0; j < 100; j++) { + # CHECK: B[0ll, j] = j * 500ll + j * 100ll; + # CHECK: })IR"); +} + +TEST(LoopNest, ScheduleInlineProducerIndicesWithCast) { + // Input IR: + // for (int64_t i = 0; i < 100; i++) { + // A[(int64_t)0,i] = i * 500ll; + // } + // for (int64_t j = 0; j < 100; j++) { + // B[0ll,j] = A[0ll, j] + j * 100ll; + // } + BufHandle a_buf("A", {20, 100}, kLong); + BufHandle b_buf("B", {20, 100}, kLong); + VarHandle i("i", kLong); + VarHandle j("j", kLong); + auto forI = For::make( + i, + 0, + 100, + Store::make(a_buf, {0, i}, Mul::make(i, static_cast(500)))); + auto forJ = For::make( + j, + 0, + 100, + Store::make( + b_buf, + {static_cast(0), j}, + Add::make( + Load::make(a_buf, {static_cast(0), j}), + Mul::make(j, static_cast(100))))); + auto par = Block::make({forI, forJ}); + + LoopNest l(par, {b_buf.node()}); + ASSERT_TRUE(l.computeInline(a_buf.node())); + + checkIR(l.root_stmt(), R"IR( + # CHECK: for (int64_t j = 0; j < 100; j++) { + # CHECK: B[0ll, j] = j * 500ll + j * 100ll; + # CHECK: })IR"); +} + +TEST(LoopNest, ScheduleFuserStyle) { + const int kVectorSize = 8; + const int kVectorCount = 128; + const int kTotalSize = kVectorSize * kVectorCount; + + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + + Tensor b = + Compute("f", {kTotalSize}, [&](const std::vector& axes) { + return a_buf.load(axes[0]) + 11.0f; + }); + + Tensor c = + Compute("g", {kTotalSize}, [&](const std::vector& axes) { + return b.load(axes[0]) + 1.0f; + }); + + LoopNest l({b, c}); + l.prepareForCodegen(); + StmtPtr s = l.root_stmt(); + + std::vector a_data(kTotalSize, 7.0f); + std::vector b_data(kTotalSize, 0.0f); + std::vector c_data(kTotalSize, 0.0f); + SimpleIREvaluator(s, {a_buf, b, c})(a_data, b_data, c_data); + + for (int i = 0; i < kTotalSize; i++) { + ASSERT_EQ(b_data[i], 18.0f); + ASSERT_EQ(c_data[i], 19.0f); + } +} + +TEST(LoopNest, ScheduleFuserThreeArg) { + const int kVectorSize = 8; + const int kVectorCount = 128; + const int kTotalSize = kVectorSize * kVectorCount; + + BufHandle a("A", {ExprHandle(kTotalSize)}, kFloat); + BufHandle b("B", {ExprHandle(kTotalSize)}, kFloat); + BufHandle c("C", {ExprHandle(kTotalSize)}, kFloat); + BufHandle d("D", {ExprHandle(kTotalSize)}, kFloat); + + Tensor e = Compute("e", {kTotalSize}, [&](const VarHandle& i) { + return a.load(i) + b.load(i); + }); + Tensor f = Compute("f", {kTotalSize}, [&](const VarHandle& i) { + return e.load(i) + c.load(i); + }); + Tensor g = Compute("g", {kTotalSize}, [&](const VarHandle& i) { + return f.load(i) + d.load(i); + }); + + LoopNest l({g}, {e, f, g}); + l.computeInline(l.getLoopBodyFor(e)); + l.computeInline(l.getLoopBodyFor(f)); + l.prepareForCodegen(); + StmtPtr s = l.root_stmt(); + + std::vector a_data(kTotalSize, 1.0f); + std::vector b_data(kTotalSize, 2.0f); + std::vector c_data(kTotalSize, 3.0f); + std::vector d_data(kTotalSize, 4.0f); + std::vector g_data(kTotalSize, 0.0f); + SimpleIREvaluator(s, {a, b, c, d, g})(a_data, b_data, c_data, d_data, g_data); + + for (int i = 0; i < kTotalSize; i++) { + ASSERT_EQ(g_data[i], 10.0f); + } +} + +TEST(LoopNest, ScheduleDynamicShape2D) { + auto testWithSize = [](int32_t M, int32_t N) { + VarHandle m("m", kInt); + VarHandle n("n", kInt); + BufHandle a("a", {m, n}, kFloat); + BufHandle b("b", {m, n}, kFloat); + Tensor c = + Compute("c", {m, n}, [&](const VarHandle& i, const VarHandle& j) { + return a.load(i, j) + b.load(i, j); + }); + LoopNest l({c}); + StmtPtr s = l.root_stmt(); + SimpleIREvaluator cg(s, {a, b, c, m, n}); + std::vector aData(M * N, 1.0f); + std::vector bData(M * N, 2.0f); + std::vector cData(M * N, 0.0f); + cg.call({aData, bData, cData, M, N}); + ExpectAllNear(cData, std::vector(M * N, 3.0f), 1e-7); + }; + testWithSize(1, 8); + testWithSize(16, 32); + testWithSize(37, 11); +} + +TEST(LoopNest, LoopNestComputeAt_1) { + // Verify that compute_at works on the following example: + // + // for (int i_a = 0; i_a < N; i_a++) { + // A[i_a] = i_a * i_a + // } + // for (int i_b = 0; i_b < N; i_b++) { + // B[i_b] = A[i_b] + // } + // + // After the transformation the i_b loop should have an allocation for a temp + // buffer and that buffer should be used in computation of B. No use of A + // should be in that loop after the transformation. Also, computation of A + // should not be inlined into B. Instead, it should be computed into the temp, + // and the temp should be used in B. + VarHandle N("N", kInt); + Tensor A = Compute("A", {N}, [&](const VarHandle& i_a) { return i_a * i_a; }); + Tensor B = + Compute("B", {N}, [&](const VarHandle& i_b) { return A.load(i_b); }); + LoopNest l({B}, {A, B}); + std::vector loops = l.getAllLoopNestsWritingToBuf(B.buf()).at(0); + LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]); + l.prepareForCodegen(); + SimpleIREvaluator cg(l.root_stmt(), {B, N}); + StmtPtr s = cg.stmt(); + + checkIR(s, R"IR( +# CHECK: Allocate(temp); // dtype=int, dims=[1] +# CHECK: for (int i = 0; i < N; i++) +# CHECK: temp[ +# CHECK-NOT: A[ +# CHECK: B[i_1] = temp[0] +# CHECK: Free(temp))IR"); + + // Now check that the loop still produces the correct result. + std::vector b_data(100, 0); + cg.call({b_data, 100}); + + std::vector b_ref(100, 0); + for (int i = 0; i < 100; i++) { + b_ref[i] = i * i; + } + assertAllEqual(b_data, b_ref); +} + +TEST(LoopNest, LoopNestComputeAt_2) { + // Verify that compute_at works on the following example: + // + // for (int py = 0; py < H+1; py++) { + // for (int px = 0; px < W+1; px++) { + // p[py, px] = py*px + // } + // } + // for (int cy = 0; cy < H; cy++) { + // for (int cx = 0; cx < W; cx++) { + // c[py, px] = p[cy,cx] + p[cy+1,cx] + + // p[cy,cx+1] + p[cy+1,cx+1] + // } + // } + + const int kW = 16, kH = 16; + VarHandle W("W", kInt); + VarHandle H("H", kInt); + Tensor p = Compute( + "prod", {H + 1, W + 1}, [&](const VarHandle& py, const VarHandle& px) { + return px * py; + }); + Tensor c = + Compute("cons", {H, W}, [&](const VarHandle& y, const VarHandle& x) { + return p.load(y, x) + p.load(y + 1, x) + p.load(y, x + 1) + + p.load(y + 1, x + 1); + }); + + std::vector c_ref(kW * kH, 0); + for (int y = 0; y < kH; y++) { + for (int x = 0; x < kW; x++) { + c_ref[y * kW + x] = y * x + (y + 1) * x + y * (x + 1) + (y + 1) * (x + 1); + } + } + LoopNest orig_loopnest({c}, {p, c}); + + { + // First let's try to compute P at axis cy (the outer loop) + LoopNest l(orig_loopnest); + std::vector loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); + LoopNest::computeAt(l.getLoopBodyFor(p), loops[0]); + l.prepareForCodegen(); + SimpleIREvaluator cg(l.root_stmt(), {c, W, H}); + StmtPtr s = cg.stmt(); + + // Check the IR we produced + checkIR(s, R"IR( +# CHECK: Allocate(temp); // dtype=int, dims=[2, W + 1] +# CHECK: for (int i_2 = 0; i_2 < H; i_2++) +# CHECK: for +# CHECK: for +# CHECK: for (int i_3 = 0; i_3 < W; i_3++) +# CHECK-NOT: prod[ +# CHECK: cons[ +# CHECK: Free(temp))IR"); + + // Now check that the loop still produces the correct result. + std::vector c_data(kW * kH, 0); + cg.call({c_data, kW, kH}); + + assertAllEqual(c_data, c_ref); + } + { + // Now let's try to compute P at axis cx (the inner loop) + LoopNest l(orig_loopnest); + std::vector loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); + LoopNest::computeAt(l.getLoopBodyFor(p), loops[1]); + l.prepareForCodegen(); + SimpleIREvaluator cg(l.root_stmt(), {c, W, H}); + StmtPtr s = cg.stmt(); + + // Check the IR we produced + checkIR(s, R"IR( +# CHECK: Allocate(temp); // dtype=int, dims=[2, 2] +# CHECK: for (int i_2 = 0; i_2 < H; i_2++) +# CHECK: for (int i_3 = 0; i_3 < W; i_3++) +# CHECK: for +# CHECK: for +# CHECK-NOT: prod[ +# CHECK: cons[ +# CHECK: Free(temp))IR"); + + // Now check that the loop still produces the correct result. + std::vector c_data(kW * kH, 0); + cg.call({c_data, kW, kH}); + + assertAllEqual(c_data, c_ref); + } +} + +TEST(LoopNest, LoopNestComputeAt_3) { + // Verify that compute_at works on the following example: + // + // A(x,y) = x*y + // B(x,y) = A(x, y) + // C(x,y) = B(x+1, y) + // D(x,y) = A(x, y+1) + C(x, y) + // + // i.e. when 'A' comes to 'D' directly and indirectly through 'C'. + + const int kW = 16, kH = 16; + VarHandle W("W", kInt); + VarHandle H("H", kInt); + Tensor A = Compute( + "A", {H + 1, W + 1}, [&](const VarHandle& ay, const VarHandle& ax) { + return ax * ay; + }); + Tensor B = Compute( + "B", {H + 1, W + 1}, [&](const VarHandle& by, const VarHandle& bx) { + return A.load(by, bx); + }); + Tensor C = + Compute("C", {H, W}, [&](const VarHandle& cy, const VarHandle& cx) { + return B.load(cy, cx + 1); + }); + Tensor D = + Compute("D", {H, W}, [&](const VarHandle& dy, const VarHandle& dx) { + return A.load(dy + 1, dx) + C.load(dy, dx); + }); + + std::vector c_ref(kW * kH, 0); + for (int y = 0; y < kH; y++) { + for (int x = 0; x < kW; x++) { + c_ref[y * kW + x] = (y + 1) * x + y * (x + 1); + } + } + + LoopNest orig_loopnest({D}, {A, B, C, D}); + { + // First let's try to compute A at axis dy (the outer loop) + LoopNest l(orig_loopnest); + std::vector loops = l.getAllLoopNestsWritingToBuf(D.buf()).at(0); + LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]); + l.prepareForCodegen(); + SimpleIREvaluator cg(l.root_stmt(), {D, W, H}); + StmtPtr s = cg.stmt(); + + // Check the IR we produced + checkIR(s, R"IR( +# CHECK: Allocate(temp); // dtype=int, dims=[1, W] +# CHECK: for (int i = 0; i < H + 1; i++) +# CHECK: for (int i_1 = 0; i_1 < W + 1; i_1++) +# CHECK: A[ +# CHECK: for (int i_2 = 0; i_2 < H + 1; i_2++) +# CHECK: for (int i_3 = 0; i_3 < W + 1; i_3++) +# CHECK: B[ +# CHECK: for (int i_4 = 0; i_4 < H; i_4++) +# CHECK: for (int i_5 = 0; i_5 < W; i_5++) +# CHECK: C[ +# CHECK: for (int i_6 = 0; i_6 < H; i_6++) +# CHECK: for (int i_7 = 0; i_7 < W; i_7++) +# CHECK-NOT: A[)IR"); + + // Now check that the loop still produces the correct result. + std::vector c_data(kW * kH, 0); + cg.call({c_data, kW, kH}); + + assertAllEqual(c_data, c_ref); + } + { + // Now let's try to compute A at axis dx (the inner loop) + LoopNest l(orig_loopnest); + std::vector loops = l.getAllLoopNestsWritingToBuf(D.buf()).at(0); + LoopNest::computeAt(l.getLoopBodyFor(A), loops[1]); + l.prepareForCodegen(); + SimpleIREvaluator cg(l.root_stmt(), {D, W, H}); + StmtPtr s = cg.stmt(); + + // Check the IR we produced + checkIR(s, R"IR( +# CHECK: Allocate(temp); // dtype=int, dims=[1, 1] +# CHECK: for (int i = 0; i < H + 1; i++) +# CHECK: for (int i_1 = 0; i_1 < W + 1; i_1++) +# CHECK: A[ +# CHECK: for (int i_2 = 0; i_2 < H + 1; i_2++) +# CHECK: for (int i_3 = 0; i_3 < W + 1; i_3++) +# CHECK: B[ +# CHECK: for (int i_4 = 0; i_4 < H; i_4++) +# CHECK: for (int i_5 = 0; i_5 < W; i_5++) +# CHECK: C[ +# CHECK: for (int i_6 = 0; i_6 < H; i_6++) +# CHECK: for (int i_7 = 0; i_7 < W; i_7++) +# CHECK-NOT: A[)IR"); + + // Now check that the loop still produces the correct result. + std::vector c_data(kW * kH, 0); + cg.call({c_data, kW, kH}); + + assertAllEqual(c_data, c_ref); + } +} + +using Axis = const VarHandle&; + +TEST(LoopNest, Reduce2dComputeAt) { + const int kW = 16, kH = 16; + VarHandle W("W", kInt); + VarHandle H("H", kInt); + + Tensor p = Compute( + "prod", {H + 1, W + 1}, [&](Axis py, Axis px) { return px * py; }); + Tensor c = Reduce( + "cons", + {H, W}, + Sum(), + [&](Axis y, Axis x, Axis r, Axis s) { return p.load(y + r, x + s); }, + {2, 2}); + + std::vector c_ref(kW * kH, 0); + for (int y = 0; y < kH; y++) { + for (int x = 0; x < kW; x++) { + c_ref[y * kW + x] = y * x + (y + 1) * x + y * (x + 1) + (y + 1) * (x + 1); + } + } + LoopNest orig_loopnest({c}, {p, c}); + checkIR(orig_loopnest.root_stmt(), R"IR( +# CHECK: for (int i = 0; i < H + 1; i++) { +# CHECK: for (int i_1 = 0; i_1 < W + 1; i_1++) { +# CHECK: prod[i, i_1] = i_1 * i; +# CHECK: } +# CHECK: } +# CHECK: for (int i_2 = 0; i_2 < H; i_2++) { +# CHECK: for (int i_3 = 0; i_3 < W; i_3++) { +# CHECK: cons[i_2, i_3] = int(0); +# CHECK: for (int i_4 = 0; i_4 < 2; i_4++) { +# CHECK: for (int i_5 = 0; i_5 < 2; i_5++) { +# CHECK: cons[i_2, i_3] = ReduceOp((cons[i_2, i_3]) + (prod[i_2 + i_4, i_3 + i_5]), reduce_args={i_4, i_5}); +# CHECK: } +# CHECK: } +# CHECK: } +# CHECK: } +)IR"); + + { + // First let's try to compute P at axis cy (the outer loop) + LoopNest l(orig_loopnest); + auto loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); + LoopNest::computeAt(l.getLoopBodyFor(p), loops[0]); + // FIXME: Calling simplify here breaks the IR: + // MALFORMED INPUT: could not find base node in Load - temp[...] + // l.simplify(); + l.eliminateDeadStores(); + l.prepareForCodegen(); + SimpleIREvaluator cg(l.root_stmt(), {c, W, H}); + checkIR(cg.stmt(), R"IR( +# CHECK: Allocate(temp); // dtype=int, dims=[2, W + 1] +# CHECK: for (int i = 0; i < H; i++) { +# CHECK: for (int idx0 = 0; idx0 < 2; idx0++) { +# CHECK: for (int idx1 = 0; idx1 < W + 1; idx1++) { +# CHECK: temp[(0 + idx0 * (1 * (W + 1))) + idx1 * 1] = (idx0 + i) * (idx1 + 0); +# CHECK: } +# CHECK: } +# CHECK: for (int i_1 = 0; i_1 < W; i_1++) { +# CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = int(0); +# CHECK: for (int i_2 = 0; i_2 < 2; i_2++) { +# CHECK: for (int i_3 = 0; i_3 < 2; i_3++) { +# CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = (cons[(0 + i * (1 * W)) + i_1 * 1]) + (temp[(0 + i_2 * (1 * (W + 1))) + (i_1 + i_3) * 1]); +# CHECK: } +# CHECK: } +# CHECK: } +# CHECK: } +# CHECK: Free(temp); +)IR"); + + // Now check that the loop still produces the correct result. + std::vector c_data(kW * kH, 0); + cg.call({c_data, kW, kH}); + assertAllEqual(c_data, c_ref); + } + { + // Now let's try to compute P at axis cx (the inner loop) + LoopNest l(orig_loopnest); + std::vector loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); + LoopNest::computeAt(l.getLoopBodyFor(p), loops[1]); + l.simplify(); + l.eliminateDeadStores(); + l.prepareForCodegen(); + SimpleIREvaluator cg(l.root_stmt(), {c, W, H}); + checkIR(cg.stmt(), R"IR( +# CHECK: Allocate(temp); // dtype=int, dims=[2, 2] +# CHECK: for (int i = 0; i < H; i++) { +# CHECK: for (int i_1 = 0; i_1 < W; i_1++) { +# CHECK: for (int idx0 = 0; idx0 < 2; idx0++) { +# CHECK: for (int idx1 = 0; idx1 < 2; idx1++) { +# CHECK: temp[(0 + idx0 * (1 * 2)) + idx1 * 1] = (i + idx0) * (i_1 + idx1); +# CHECK: } +# CHECK: } +# CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = 0; +# CHECK: for (int i_2 = 0; i_2 < 2; i_2++) { +# CHECK: for (int i_3 = 0; i_3 < 2; i_3++) { +# CHECK: cons[(0 + i * (1 * W)) + i_1 * 1] = (cons[(0 + i * (1 * W)) + i_1 * 1]) + (temp[(0 + i_2 * (1 * 2)) + i_3 * 1]); +# CHECK: } +# CHECK: } +# CHECK: } +# CHECK: } +# CHECK: Free(temp); +)IR"); + + // Now check that the loop still produces the correct result. + std::vector c_data(kW * kH, 0); + cg.call({c_data, kW, kH}); + assertAllEqual(c_data, c_ref); + } +} + +TEST(LoopNest, DISABLED_Conv1d_NH) { + // Lots of stuff is broken here. The computeAt swaps the axes for some odd + // reason. Even without that, the index flattener fails due to "dimensions + // mismatch in flatten index". + + int N = 4; + int H = 256; + int R = 3; + int Pad = 1; + BufHandle IP("input", {H}, kFloat); + + Tensor A = Compute("A", {N, H + 2 * Pad}, [&](Axis n, Axis h) { + auto cond = CompareSelect::make(h, Pad, 1, 0, kLT); + cond = CompareSelect::make(h, H + Pad, 1, cond, kGE); + return ifThenElse(cond, 0.f, IP.load(n, h - Pad)); + }); + Tensor B = Reduce( + "B", + {N, H}, + Sum(), + [&](Axis n, Axis h, Axis r) { return A.load(n, h + r); }, + {R}); + LoopNest l({B}); + checkIR(l.root_stmt(), R"IR( +# CHECK: for (int np = 0; np < 4; np++) { +# CHECK: for (int hp = 0; hp < 258; hp++) { +# CHECK: A[np, hp] = IfThenElse(hp>=257 ? 1 : (hp<1 ? 1 : 0), 0.f, input[np, hp - 1]); +# CHECK: } +# CHECK: } +# CHECK: for (int n = 0; n < 4; n++) { +# CHECK: for (int h = 0; h < 256; h++) { +# CHECK: B[n, h] = float(0); +# CHECK: for (int r = 0; r < 3; r++) { +# CHECK: B[n, h] = ReduceOp((B[n, h]) + (A(n, h + r)), reduce_args={r}); +# CHECK: } +# CHECK: } +# CHECK: } +)IR"); + std::vector loops = l.getAllLoopNestsWritingToBuf(B.buf()).at(0); + LoopNest::computeAt(l.getLoopBodyFor(A), loops[0]); + // FIXME: The current IR is totally broken. The body of the inlined loop is: + + // temp[idx0, idx1] = IfThenElse(idx0 + n>=257 ? 1 : (idx0 + n<1 ? 1 : 0), + // 0.f, input[idx1 + 0, (idx0 + n) - 1]); + + // Which seems to mix up the axes. The CHECK below is my best guess at what + // the input "should" look like + + checkIR(l.root_stmt(), R"IR( +# CHECK: for (int n = 0; n < 4; n++) { +# CHECK: for (int idx0 = 0; idx0 < 1; idx0++) { +# CHECK: for (int idx1 = 0; idx1 < 258; idx1++) { + temp[idx0, idx1] = IfThenElse(idx1>=257 ? 1 : (idx1<1 ? 1 : 0), 0.f, input[n, idx1 - 1]); +# CHECK: } +# CHECK: } +# CHECK: for (int h = 0; h < 256; h++) { +# CHECK: B[n, h] = float(0); +# CHECK: for (int r = 0; r < 3; r++) { +# CHECK: B[n, h] = ReduceOp((B[n, h]) + (temp[0, r + h]), reduce_args={r}); +# CHECK: } +# CHECK: } +# CHECK: } +)IR"); + + l.simplify(); + l.prepareForCodegen(); + StmtPtr s = l.root_stmt(); + + SimpleIREvaluator cg(s, {IP, B}); + // auto At = at::ones({N, H}, at::kFloat); + auto At = at::arange(N * H, at::kFloat).reshape({N, H}); + auto Rt = at::conv1d( + At, at::ones({1, 1, 3}), at::Tensor(), /*stride=*/1, /*padding=*/3); + auto Bt = at::empty_like(Rt); + cg.call({At.data_ptr(), Bt.data_ptr()}); + ASSERT_TRUE(at::allclose(Rt, Bt)); +} + +class LoopOrderHelper : public IRVisitor { + std::stringstream ordering; + + public: + std::string getOrder(StmtPtr s) { + ordering.str(""); + s->accept(this); + return ordering.str(); + } + + void visit(const ForPtr& v) final { + ordering << v->var()->name_hint() << ","; + IRVisitor::visit(v); + } +}; + +TEST(LoopNest, LoopNestReorderAxis1) { + Tensor tensor = + Compute("f", {2, 3}, [](const VarHandle& x, const VarHandle& y) { + return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; + }); + LoopNest l({tensor}); + StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); + + std::vector stmt1_output(6, 0); + SimpleIREvaluator cg(stmt1, {tensor}); + cg.call({stmt1_output}); + + auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::reorderAxis(loops[0], loops[1]); + StmtPtr stmt2 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); + + ASSERT_NE(stmt1, stmt2); + LoopOrderHelper loopOrderHelper; + std::string order1 = loopOrderHelper.getOrder(stmt1); + std::string order2 = loopOrderHelper.getOrder(stmt2); + + ASSERT_EQ(order1, "j,i,"); + ASSERT_EQ(order2, "i,j,"); + + std::vector stmt2_output(6, 0); + SimpleIREvaluator cg2(stmt2, {tensor}); + cg.call({stmt2_output}); + + for (int i = 0; i < 6; ++i) { + ASSERT_EQ(stmt1_output[i], stmt2_output[i]); + } + + // Reorder them back. + loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::reorderAxis(loops[0], loops[1]); + StmtPtr stmt3 = l.root_stmt(); + + std::string order3 = loopOrderHelper.getOrder(stmt3); + ASSERT_EQ(order3, order1); + + std::ostringstream oss1, oss2; + oss1 << *stmt1; + oss2 << *stmt3; + + // Should be identical to the unreordered statement. + ASSERT_EQ(oss1.str(), oss2.str()); +} + +TEST(LoopNest, LoopNestReorderPartialAxes) { + Tensor tensor = Compute( + "f", + {2, 3, 4}, + [](const VarHandle& x, const VarHandle& y, const VarHandle& z) { + return ExprHandle(1.0f) + cast(x) * x + cast(y) * y + + cast(z) * z; + }); + LoopNest l({tensor}); + + LoopOrderHelper loopOrderHelper; + StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); + ASSERT_EQ(loopOrderHelper.getOrder(stmt1), "i,j,k,"); + + std::vector stmt1_output(24, 0); + SimpleIREvaluator cg(stmt1, {tensor}); + cg.call({stmt1_output}); + + auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::reorderAxis(loops[0], loops[1]); + ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "j,i,k,"); + + StmtPtr stmt2 = Stmt::clone(l.root_stmt()); + + std::vector stmt2_output(24, 0); + SimpleIREvaluator cg2(stmt2, {tensor}); + cg2.call({stmt2_output}); + + for (int i = 0; i < 24; ++i) { + ASSERT_EQ(stmt1_output[i], stmt2_output[i]); + } + + loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::reorderAxis(loops[1], loops[2]); + ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "j,k,i,"); + + StmtPtr stmt3 = Stmt::clone(l.root_stmt()); + + std::vector stmt3_output(24, 0); + SimpleIREvaluator cg3(stmt3, {tensor}); + cg3.call({stmt3_output}); + + for (int i = 0; i < 24; ++i) { + ASSERT_EQ(stmt1_output[i], stmt3_output[i]); + } +} + +TEST(LoopNest, LoopNestReorderInternalAxis) { + Tensor tensor = Compute( + "f", + {1, 2, 3, 4}, + [](const VarHandle& w, + const VarHandle& x, + const VarHandle& y, + const VarHandle& z) { + return ExprHandle(1.0f) + w + cast(x) * x + cast(y) * y + + cast(z) * z; + }); + LoopNest l({tensor}); + + LoopOrderHelper loopOrderHelper; + StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); + ASSERT_EQ(loopOrderHelper.getOrder(stmt1), "i,j,k,l,"); + + std::vector stmt1_output(24, 0); + SimpleIREvaluator cg(stmt1, {tensor}); + cg.call({stmt1_output}); + + auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::reorderAxis(loops[2], loops[1]); + ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "i,k,j,l,"); + + StmtPtr stmt2 = l.root_stmt(); + + std::vector stmt2_output(24, 0); + SimpleIREvaluator cg2(stmt2, {tensor}); + cg2.call({stmt2_output}); + + for (int i = 0; i < 24; ++i) { + ASSERT_EQ(stmt1_output[i], stmt2_output[i]); + } +} + +TEST(LoopNest, LoopNestReorderEnclosingAxis) { + Tensor tensor = Compute( + "f", + {1, 2, 3, 4}, + [](const VarHandle& w, + const VarHandle& x, + const VarHandle& y, + const VarHandle& z) { + return ExprHandle(1.0f) + w + cast(x) * x + cast(y) * y + + cast(z) * z; + }); + LoopNest l({tensor}); + + LoopOrderHelper loopOrderHelper; + StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); + + std::vector stmt1_output(24, 0); + SimpleIREvaluator cg(stmt1, {tensor}); + cg.call({stmt1_output}); + + auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::reorderAxis(loops[0], loops[3]); + ASSERT_EQ(loopOrderHelper.getOrder(l.root_stmt()), "l,j,k,i,"); + + StmtPtr stmt2 = l.root_stmt(); + + std::vector stmt2_output(24, 0); + SimpleIREvaluator cg2(stmt2, {tensor}); + cg2.call({stmt2_output}); + + for (int i = 0; i < 24; ++i) { + ASSERT_EQ(stmt1_output[i], stmt2_output[i]); + } +} + +TEST(LoopNest, LoopNestReorderSameAxis) { + Tensor tensor = + Compute("f", {2, 3}, [](const VarHandle& x, const VarHandle& y) { + return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; + }); + LoopNest l({tensor}); + StmtPtr stmt1 = Stmt::clone(l.root_stmt()); + + auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::reorderAxis(loops[1], loops[1]); + StmtPtr stmt2 = Stmt::clone(l.root_stmt()); + + std::ostringstream oss, oss2; + oss << *stmt1; + oss2 << *stmt2; + ASSERT_EQ(oss.str(), oss2.str()); +} + +TEST(LoopNest, LoopNestReorderExtraStatements) { + /* We're going for a structure like this: + * for i in ... + * Stmt 1 + * for j in ... + * Stmt 2 + * for k in ... + * Stmt 3 + * Stmt 4 + */ + + Tensor tensor = Compute( + "f", + {2, 3, 4}, + [](const VarHandle& x, const VarHandle& y, const VarHandle& z) { + return ExprHandle(1.0f) + cast(x) * x + cast(y) * y + + cast(z) * z; + }); + LoopNest l({tensor}); + + BufHandle extra("res", {6, 3}, kFloat); + + auto loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + + VarHandle i = VarHandle(loops[0]->var()); + + StmtPtr store_1 = Store::make(extra, {i, 0}, 1.f); + StmtPtr store_2 = Store::make(extra, {i, 1}, 2.f); + // stmt 3 is the Function body. + StmtPtr store_3 = Store::make(extra, {i, 2}, 4.f); + + loops[0]->body()->prepend_stmt(store_1); + loops[1]->body()->prepend_stmt(store_2); + loops[1]->body()->append_stmt(store_3); + StmtPtr stmt1 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); + + std::vector extra1(6, 0); + std::vector res1(24, 0); + SimpleIREvaluator cg(stmt1, {tensor, extra}); + cg.call({res1, extra1}); + + /* Then we reorder loop y and z, we want it to look like: + * + * for i in ... + * Stmt 1 + * for j in ... + * Stmt 2 + * for j_1 in ... + * for k in ... + * Stmt 3 + * for j_2 in ... + * Stmt 4 + * + * We need extra loops because we don't have dependency info about stmt 3 + * and 4. + * + */ + + LoopNest::reorderAxis(loops[1], loops[2]); + StmtPtr stmt2 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); + + // Check the IR we produced + checkIR(stmt2, R"IR( +# CHECK: for +# CHECK: res[i, 0] = 1 +# CHECK: for +# CHECK: res[i, 1] = 2 +# CHECK: for +# CHECK: for +# CHECK: f[ +# CHECK: for +# CHECK: res[i, 2] = 4 +)IR"); + + std::vector extra2(6, 0); + std::vector res2(24, 0); + SimpleIREvaluator cg2(stmt2, {tensor, extra}); + cg2.call({res2, extra2}); + + for (int i = 0; i < 24; ++i) { + ASSERT_EQ(res1[i], res2[i]); + } + for (int i = 0; i < 6; ++i) { + ASSERT_EQ(extra1[i], extra2[i]); + } + + /* Now reorder x and the y above stmt 3: + * + * + * for x in ... + * Stmt 1 + * for y in ... + * Stmt 2 + * + * for y in ... + * for z in ... + * for x in ... + * Stmt 3 + * + * for x in ... + * for y in ... + * Stmt 4 + * + * + */ + loops = l.getAllLoopNestsWritingToBuf(tensor.buf()).at(0); + LoopNest::reorderAxis(loops[0], loops[2]); + StmtPtr stmt3 = LoopNest::sanitizeNames(Stmt::clone(l.root_stmt())); + + // Check the IR we produced + checkIR(stmt3, R"IR( +# CHECK: for +# CHECK: res[i, 0] = 1 +# CHECK: for +# CHECK: res[i, 1] = 2 +# CHECK: for +# CHECK: for +# CHECK: for +# CHECK: f[ +# CHECK: for +# CHECK: for +# CHECK: res[i_2, 2] = 4 +)IR"); + + std::vector extra3(6, 0); + std::vector res3(24, 0); + SimpleIREvaluator cg3(stmt3, {tensor, extra}); + cg3.call({res3, extra3}); + + for (int i = 0; i < 24; ++i) { + ASSERT_EQ(res1[i], res3[i]); + } + for (int i = 0; i < 6; ++i) { + ASSERT_EQ(extra1[i], extra3[i]); + } +} + +void LoopNestReorderTestHelper( + bool prepend, + bool append, + int index1, + int index2) { + Tensor c = Compute( + "5d", {2, 3, 2, 3, 2}, [](const std::vector&) { return -1; }); + LoopNest l({c}); + + BufHandle extra("extra", {5}, kInt); + + auto loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); + int j = 0; + for (auto l : loops) { + // Add an increment at each layer of the loop which counts the number of + // times the loop executes. + LoadPtr load = + alloc(extra.node(), std::vector({alloc(j)})); + AddPtr add = alloc(load, alloc(1)); + StmtPtr store = alloc( + extra.node(), std::vector({alloc(j)}), add); + if (prepend) { + l->body()->prepend_stmt(store); + } + if (append) { + l->body()->append_stmt(Stmt::clone(store)); + } + + j++; + } + + StmtPtr stmt1 = Stmt::clone(l.root_stmt()); + + std::vector extra1(5, 0); + std::vector res1(2 * 3 * 2 * 3 * 2, 0); + SimpleIREvaluator cg(stmt1, {c, extra}); + cg.call({res1, extra1}); + + std::vector loopExtents = {2, 3, 2, 3, 2}; + + int expected_loops = 0; + if (prepend) { + expected_loops++; + } + if (append) { + expected_loops++; + } + for (int i = 0; i < 5; ++i) { + expected_loops *= loopExtents[i]; + ASSERT_EQ(extra1[i], expected_loops); + } + + loops = l.getAllLoopNestsWritingToBuf(c.buf()).at(0); + LoopNest::reorderAxis(loops[index1], loops[index2]); + StmtPtr stmt2 = Stmt::clone(l.root_stmt()); + + std::ostringstream oss, oss2; + oss << *stmt1; + oss2 << *stmt2; + ASSERT_NE(oss.str(), oss2.str()); + + std::vector extra2(5, 0); + std::vector res2(2 * 3 * 2 * 3 * 2, 0); + SimpleIREvaluator cg2(stmt2, {c, extra}); + cg2.call({res2, extra2}); + + expected_loops = 0; + if (prepend) { + expected_loops++; + } + if (append) { + expected_loops++; + } + + for (int i = 0; i < 5; ++i) { + expected_loops *= loopExtents[i]; + ASSERT_EQ(extra2[i], expected_loops); + } + + for (int i = 0; i < 2 * 3 * 2 * 3 * 2; ++i) { + ASSERT_EQ(res2[i], res1[i]); + } +} + +TEST(LoopNest, LoopNestReorderLongStringOfPreOrphans) { + for (int i = 0; i < 5; ++i) { + for (int j = 0; j < 5; ++j) { + // skip noops, since we check the loop isn't the same after reordering. + if (i != j) { + LoopNestReorderTestHelper(true, false, i, j); + } + } + } +} + +TEST(LoopNest, LoopNestReorderLongStringOfPostOrphans) { + for (int i = 0; i < 5; ++i) { + for (int j = 0; j < 5; ++j) { + // skip noops, since we check the loop isn't the same after reordering. + if (i != j) { + LoopNestReorderTestHelper(false, true, i, j); + } + } + } +} + +TEST(LoopNest, LoopNestReorderLongStringFull) { + for (int i = 0; i < 5; ++i) { + for (int j = 0; j < 5; ++j) { + // skip noops, since we check the loop isn't the same after reordering. + if (i != j) { + LoopNestReorderTestHelper(true, true, i, j); + } + } + } +} + +TEST(LoopNest, LoopNestReorderInternalLoopNest) { + const int M = 4; + const int N = 5; + const int K = 6; + BufHandle a_buf("a", {M, N}, kFloat); + BufHandle b_buf("b", {N, K}, kFloat); + BufHandle c_buf("c", {M, N}, kFloat); + BufHandle d_buf("d", {M, K}, kFloat); + + Tensor x = Compute( + "x", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf.load(m, n) * b_buf.load(n, k); + }); + Tensor y = Compute( + "y", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return c_buf.load(m, n) * d_buf.load(m, k) + x.load(m, n, k); + }); + Tensor z = Compute( + "z", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return x.load(m, n, k) + y.load(m, n, k); + }); + + LoopNest l({z}, {x, y, z}); + ForPtr a = l.getAllLoopNestsWritingToBuf(y.buf())[0][2]; + ForPtr b = l.getAllLoopNestsWritingToBuf(y.buf())[0][0]; + LoopNest::reorderAxis(a, b); + + l.prepareForCodegen(); + StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); + + // Check the IR we produced has the 3 nests in the right order, but k and m + // swapped in the middle. + checkIR(stmt, R"IR( +# CHECK: < 4 +# CHECK: < 5 +# CHECK: < 6 +# CHECK: < 6 +# CHECK: < 5 +# CHECK: < 4 +# CHECK: < 4 +# CHECK: < 5 +# CHECK: < 6)IR"); + + { + PaddedBuffer a_v(M, N); + PaddedBuffer b_v(N, K); + PaddedBuffer c_v(M, N); + PaddedBuffer d_v(M, K); + + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + a_v(i, j) = i * i; + } + } + for (int i = 0; i < N; i++) { + for (int j = 0; j < K; j++) { + b_v(i, j) = j * j; + } + } + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + c_v(i, j) = i + j; + } + } + for (int i = 0; i < M; i++) { + for (int j = 0; j < K; j++) { + d_v(i, j) = i * j; + } + } + + PaddedBuffer z_v(M, N, K); + PaddedBuffer z_ref(M, N, K); + for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k++) { + z_ref(m, n, k) = a_v(m, n) * b_v(n, k) * 2 + c_v(m, n) * d_v(m, k); + } + } + } + + SimpleIREvaluator eval(stmt, {a_buf, b_buf, c_buf, d_buf, z}); + eval(a_v, b_v, c_v, d_v, z_v); + ExpectAllNear(z_v, z_ref, 1e-5); + } +} + +TEST(LoopNest, OuterLoopVectorization) { + Tensor tensor = + Compute("f", {8, 8}, [](const VarHandle& x, const VarHandle& y) { + return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; + }); + LoopNest l({tensor}); + + ASSERT_TRUE( + LoopNest::vectorize(l.getAllLoopNestsWritingToBuf(tensor.buf())[0][0])); + + StmtPtr root_stmt = l.root_stmt(); + BlockPtr outer_block = to(root_stmt); + ASSERT_NE(outer_block, nullptr); + while (BlockPtr inner_block = to(outer_block->front())) { + outer_block = inner_block; + } + + // Verify that we have only a single loop level remaining after + // vectorization. + ASSERT_EQ(outer_block->nstmts(), 1); + ForPtr for_loop = to(outer_block->front()); + ASSERT_NE(for_loop, nullptr); + BlockPtr for_body = for_loop->body(); + ASSERT_EQ(for_body->nstmts(), 1); + ASSERT_EQ(to(for_body->front()), nullptr); +} + +TEST(LoopNest, VectorizeLoopNotNormalized) { + // Input IR: + // for (int i = 0; i < 10; i++) { + // for (int j = 1; j < 5; j++) { + // A[i,j] = i * j; + // } + // } + BufHandle a_buf("A", {10, 5}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); + auto inner_for = For::make(j, 1, 5, for_body); + auto outer_for = For::make(i, 0, 10, inner_for); + auto block = Block::make({outer_for}); + LoopNest l(block, {a_buf.node()}); + + ASSERT_TRUE(LoopNest::vectorize(inner_for)); + ASSERT_EQ(outer_for->body()->nstmts(), 1); + ASSERT_EQ(to(outer_for->body()->front()), nullptr); +} + +namespace { + +std::string constantUpperBoundLoopIR(int upper_bound_val) { + ExprHandle upper_bound(upper_bound_val); + Tensor A = + Compute("A", {upper_bound}, [&](const VarHandle& x) { return x * 2; }); + LoopNest l({A}); + std::vector loops = l.getAllLoopNestsWritingToBuf(A.buf())[0]; + StmtPtr unrolled = nullptr; + LoopNest::fullUnroll(loops[0], &unrolled); + std::ostringstream oss; + oss << *unrolled; + return oss.str(); +} + +} // namespace + +TEST(LoopNest, Unroll) { + const std::string actual = constantUpperBoundLoopIR(3); + const std::string& verification_pattern = + R"IR( +# CHECK: A[0] = 0; +# CHECK: A[1] = 2; +# CHECK: A[2] = 4)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, actual); +} + +TEST(LoopNest, UnrollOuter) { + ExprHandle outer_bound(3); + ExprHandle inner_bound(4); + Tensor A = Compute( + "A", + {outer_bound, inner_bound}, + [&](const VarHandle& x, const VarHandle& y) { return x + y; }); + LoopNest l({A}); + std::vector loops = l.getAllLoopNestsWritingToBuf(A.buf())[0]; + StmtPtr unrolled = nullptr; + LoopNest::fullUnroll(loops[0], &unrolled); + checkIR(unrolled, R"IR( +# CHECK: for (int i = 0; i < 4; i++) { +# CHECK: A[0, i] = i; +# CHECK: } +# CHECK: for (int i = 0; i < 4; i++) { +# CHECK: A[1, i] = i + 1; +# CHECK: } +# CHECK: for (int i = 0; i < 4; i++) { +# CHECK: A[2, i] = i + 2; +# CHECK: })IR"); +} + +TEST(LoopNest, UnrollInner) { + ExprHandle outer_bound(3); + ExprHandle inner_bound(4); + Tensor A = Compute( + "A", + {outer_bound, inner_bound}, + [&](const VarHandle& x, const VarHandle& y) { return x + y; }); + LoopNest l({A}); + std::vector loops = l.getAllLoopNestsWritingToBuf(A.buf())[0]; + StmtPtr unrolled = nullptr; + LoopNest::fullUnroll( + static_to(loops[0]->body()->stmts().front()), &unrolled); + checkIR(loops[0], R"IR( +# CHECK: for (int i = 0; i < 3; i++) { +# CHECK: A[i, 0] = i; +# CHECK: A[i, 1] = i + 1; +# CHECK: A[i, 2] = i + 2; +# CHECK: A[i, 3] = i + 3; +# CHECK: })IR"); +} + +TEST(LoopNest, UnrollMultipleStatements) { + const int kTotalSize = 3; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); + + VarHandle x("x", kInt); + auto f = For::make( + x, + 0, + kTotalSize, + Block::make( + {Store::make(a_buf, {x}, x * 2), + Store::make(b_buf, {x}, Load::make(a_buf, {x}))})); + auto parent_block = Block::make({f}); + StmtPtr unrolled = nullptr; + LoopNest::fullUnroll(f, &unrolled); + checkIR(unrolled, R"IR( +# CHECK: A[0] = 0; +# CHECK: B[0] = A[0]; +# CHECK: A[1] = 2; +# CHECK: B[1] = A[1]; +# CHECK: A[2] = 4 +# CHECK: B[2] = A[2];)IR"); +} + +TEST(LoopNest, UnrollNonLiteralConstantBounds) { + // Input IR: + // for (int i = 2 - 1; i < 12 / 3; i++) { + // for (int j = 0; j < 4; j++) { + // A[i,j] = i * j; + // } + // } + BufHandle a_buf("A", {3, 4}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); + auto inner_for = For::make(j, 0, 4, for_body); + auto outer_for = For::make( + i, + IntImm::make(2) - IntImm::make(1), + IntImm::make(12) / IntImm::make(3), + inner_for); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + auto b = Block::make({outer_for}); + + std::vector loops = {outer_for, inner_for}; + StmtPtr unrolled = nullptr; + LoopNest::fullUnroll(loops[0], &unrolled); + checkIR(unrolled, R"IR( +# CHECK: for (int j = 0; j < 4; j++) { +# CHECK: A[1, j] = j; +# CHECK: } +# CHECK: for (int j = 0; j < 4; j++) { +# CHECK: A[2, j] = 2 * j; +# CHECK: } +# CHECK: for (int j = 0; j < 4; j++) { +# CHECK: A[3, j] = 3 * j; +# CHECK: })IR"); +} + +TEST(LoopNest, UnrollNonConstantBounds) { + // Input IR: + // for (int i = 0; i < M; i++) { + // for (int j = 0; j < N; j++) { + // A[i, j] = i * j; + // } + // } + VarHandle M("M", kInt); + VarHandle N("N", kInt); + BufHandle a_buf("A", {M, N}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); + auto inner_for = For::make(j, 0, N, for_body); + auto outer_for = For::make(i, 0, M, inner_for); + auto block = Block::make({outer_for}); + LoopNest l(block, {a_buf.node()}); + + LoopNest::unroll(inner_for, 8); + l.simplify(); + checkIR(l.root_stmt(), R"IR( + # CHECK: for (int i = 0; i < M; i++) { + # CHECK: for (int j_outer = 0; j_outer < N / 8; j_outer++) { + # CHECK: A[i, 8 * j_outer] = + # CHECK: A[i, 8 * j_outer + 1] = + # CHECK: A[i, 2 * (4 * j_outer + 1)] = + # CHECK: A[i, 8 * j_outer + 3] = + # CHECK: A[i, 4 * (2 * j_outer + 1)] = + # CHECK: A[i, 8 * j_outer + 5] = + # CHECK: A[i, 8 * j_outer + 6] = + # CHECK: A[i, 8 * j_outer + 7] = + # CHECK: } + # CHECK: for (int j_tail = 0; j_tail < N % 8; j_tail++) { + # CHECK: A[i, 8 * (N / 8) + j_tail] = + # CHECK: } + # CHECK: } + )IR"); +} + +TEST(LoopNest, UnrollByFactorsLessThan2) { + // Input IR: + // for (int i = 0; i < M; i++) { + // for (int j = 0; j < N; j++) { + // A[i, j] = i * j; + // } + // } + VarHandle M("M", kInt); + VarHandle N("N", kInt); + BufHandle a_buf("A", {M, N}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); + auto inner_for = For::make(j, 0, N, for_body); + auto outer_for = For::make(i, 0, M, inner_for); + auto block = Block::make({outer_for}); + LoopNest l(block, {a_buf.node()}); + + // Unrolling by factor = 1 should do nothing. + LoopNest::unroll(inner_for, 1); + checkIR(l.root_stmt(), R"IR( + # CHECK: for (int i = 0; i < M; i++) { + # CHECK: for (int j = 0; j < N; j++) { + # CHECK: A[i, j] = + # CHECK: } + # CHECK: } + )IR"); + + // Unrolling by factor = 0 should do nothing. + LoopNest::unroll(inner_for, 0); + checkIR(l.root_stmt(), R"IR( + # CHECK: for (int i = 0; i < M; i++) { + # CHECK: for (int j = 0; j < N; j++) { + # CHECK: A[i, j] = + # CHECK: } + # CHECK: } + )IR"); + + // Unrolling by negative factor should do nothing. + LoopNest::unroll(inner_for, -2); + checkIR(l.root_stmt(), R"IR( + # CHECK: for (int i = 0; i < M; i++) { + # CHECK: for (int j = 0; j < N; j++) { + # CHECK: A[i, j] = + # CHECK: } + # CHECK: } + )IR"); +} + +TEST(LoopNest, UnrollByFactorEqualToIters) { + // Input IR: + // for (int i = 0; i < 5; i++) { + // A[i] = i * i; + // } + BufHandle a_buf("A", {5}, kInt); + VarHandle i("i", kInt); + auto for_body = Block::make({Store::make(a_buf, {i}, i * i)}); + auto for_loop = For::make(i, 0, 5, for_body); + auto block = Block::make({for_loop}); + LoopNest l(block, {a_buf.node()}); + + LoopNest::unroll(for_loop, 5); + checkIR(l.root_stmt(), R"IR( + # CHECK: for (int i_outer = 0; i_outer < (5 - 0) / 5; i_outer++) + # CHECK: A[5 * i_outer] + # CHECK: A[5 * i_outer + 1] + # CHECK: A[5 * i_outer + 2] + # CHECK: A[5 * i_outer + 3] + # CHECK: A[5 * i_outer + 4] + )IR"); +} + +TEST(LoopNest, UnrollEmpty) { + const std::string actual = constantUpperBoundLoopIR(0); + const std::string& verification_pattern = R"IR( +# CHECK-NOT: A[ + )IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, actual); +} + +TEST(LoopNest, NoUnroll) { + VarHandle upper_bound("N", kInt); + Tensor A = + Compute("A", {upper_bound}, [&](const VarHandle& x) { return x * 2; }); + LoopNest l({A}); + std::vector loops = l.getAllLoopNestsWritingToBuf(A.buf())[0]; + StmtPtr unrolled = nullptr; + ASSERT_THROWS_WITH( + LoopNest::fullUnroll(loops[0], &unrolled), "non-constant loop"); +} + +TEST(LoopNest, UnrollWithLet) { + const int kTotalSize = 3; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); + + VarHandle e("e", kInt); + VarHandle x("x", kInt); + auto f = For::make( + x, + 0, + kTotalSize, + Block::make( + {Let::make(e, 7), + Store::make(a_buf, {x}, e), + Store::make(b_buf, {x}, e + 1)})); + auto parent_block = Block::make({f}); + StmtPtr unrolled = nullptr; + LoopNest::fullUnroll(f, &unrolled); + std::ostringstream oss; + oss << *unrolled; + const std::string& verification_pattern = + R"IR( +# CHECK: int e = 7; +# CHECK: A[0] = e; +# CHECK: B[0] = e + 1; +# CHECK: A[1] = e; +# CHECK: B[1] = e + 1; +# CHECK: A[2] = e; +# CHECK: B[2] = e + 1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + std::vector a_v(kTotalSize, 0); + std::vector b_v(kTotalSize, 0); + SimpleIREvaluator eval(unrolled, {a_buf, b_buf}); + eval(a_v, b_v); + for (int i = 0; i < kTotalSize; ++i) { + ASSERT_EQ(a_v[i], 7); + ASSERT_EQ(b_v[i], 8); + } +} + +TEST(LoopNest, IsNormalized) { + // Input IR: + // for (int i = 50; i < 100; i++) { + // A[i] = B[i]; + // } + BufHandle a_buf("A", {ExprHandle(100)}, kInt); + BufHandle b_buf("B", {ExprHandle(100)}, kInt); + VarHandle i("i", kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto for_stmt = + For::make(i, 50, 100, Store::make(a_buf, {i}, Load::make(b_buf, {i}))); + Block::make({for_stmt}); + ASSERT_FALSE(LoopNest::isNormalized(for_stmt)); + + for_stmt->set_start(alloc(0)); + ASSERT_TRUE(LoopNest::isNormalized(for_stmt)); + + VarHandle N("N", kInt); + for_stmt->set_start(N.node()); + ASSERT_FALSE(LoopNest::isNormalized(for_stmt)); +} + +TEST(LoopNest, NormalizeStartPositive) { + // Input IR: + // for (int x = 50; x < 100; x++) { + // A[x] = B[x]; + // B[x] = x * 2; + // } + const int kTotalSize = 50; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); + VarHandle x("x", kInt); + auto for_body = Block::make( + {Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x})), + Store::make(b_buf, {x}, x * 2)}); + auto for_stmt = For::make(x, 50, 100, for_body); + Block::make({for_stmt}); + + LoopNest::normalize(for_stmt); + + auto result = IRSimplifier::simplify(for_stmt); + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( + # CHECK: for (int x = 0; x < 50; x++) { + # CHECK: A[x + 50] = B[x + 50]; + # CHECK: B[x + 50] = 2 * (x + 50); + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +TEST(LoopNest, NormalizeStartNegative) { + // Input IR: + // for (int x = -50; x < 100; x++) { + // A[x + 50] = B[x + 50]; + // B[x + 50] = x * 2; + // } + const int kTotalSize = 150; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); + VarHandle x("x", kInt); + auto for_body = Block::make( + {Store::make(a_buf, {x + 50}, Load::make(kInt, b_buf, {x + 50})), + Store::make(b_buf, {x + 50}, x * 2)}); + auto for_stmt = For::make(x, -50, 100, for_body); + Block::make({for_stmt}); + + LoopNest::normalize(for_stmt); + + auto result = IRSimplifier::simplify(for_stmt); + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( + # CHECK: for (int x = 0; x < 150; x++) { + # CHECK: A[x] = B[x]; + # CHECK: B[x] = 2 * (x - 50); + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +TEST(LoopNest, NormalizeStartZero) { + // Input IR: + // for (int x = 0; x < 100; x++) { + // A[x] = B[x]; + // B[x] = x * 2; + // } + // Should not be modified. + + const int kTotalSize = 100; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); + VarHandle x("x", kInt); + auto for_body = Block::make( + {Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x})), + Store::make(b_buf, {x}, x * 2)}); + auto for_stmt = For::make(x, 0, 100, for_body); + Block::make({for_stmt}); + + LoopNest::normalize(for_stmt); + + auto result = IRSimplifier::simplify(for_stmt); + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( + # CHECK: for (int x = 0; x < 100; x++) { + # CHECK: A[x] = B[x]; + # CHECK: B[x] = 2 * x; + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +TEST(LoopNest, NormalizeStartVariable) { + // Input IR: + // for (int x = y; x < 100; x++) { + // A[x] = B[x]; + // B[x] = x * 2; + // } + + const int kTotalSize = 100; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + auto for_body = Block::make( + {Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x})), + Store::make(b_buf, {x}, x * 2)}); + auto for_stmt = For::make(x, y, 100, for_body); + auto parent_block = Block::make({for_stmt}); + + LoopNest::normalize(for_stmt); + + auto result = IRSimplifier::simplify(for_stmt); + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( + # CHECK: for (int x = 0; x < 100 - y; x++) { + # CHECK: A[x + y] = B[x + y]; + # CHECK: B[x + y] = 2 * (x + y); + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +TEST(LoopNest, NormalizeOnNestedOuterLoop) { + // Input IR: + // for (int x = 50; x < 100; x++) { + // for (int y = 10; y < 100; y++) { + // A[x] = A[x] + B[y] + y * 2; + // } + // } + + BufHandle a_buf("A", {ExprHandle(50)}, kInt); + BufHandle b_buf("B", {ExprHandle(100)}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + auto inner_for_body = Store::make( + a_buf, {x}, Load::make(a_buf, {x}) + Load::make(b_buf, {y}) + y * 2); + auto inner_for = For::make(y, 10, 100, inner_for_body); + auto for_stmt = For::make(x, 50, 100, inner_for); + Block::make({for_stmt}); + + LoopNest::normalize(for_stmt); + + auto result = IRSimplifier::simplify(for_stmt); + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( + # CHECK: for (int x = 0; x < 50; x++) { + # CHECK: for (int y = 10; y < 100; y++) { + # CHECK: A[x + 50] = ((A[x + 50]) + (B[y])) + 2 * y; + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +TEST(LoopNest, NormalizeOnNestedInnerLoop) { + // Input IR: + // for (int x = 50; x < 100; x++) { + // for (int y = 10; y < 100; y++) { + // A[x] = A[x] + B[y] + y * 2; + // } + // } + + BufHandle a_buf("A", {ExprHandle(50)}, kInt); + BufHandle b_buf("B", {ExprHandle(100)}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + auto inner_for_body = Store::make( + a_buf, {x}, Load::make(a_buf, {x}) + Load::make(b_buf, {y}) + y * 2); + auto inner_for = For::make(y, 10, 100, inner_for_body); + auto for_stmt = For::make(x, 50, 100, inner_for); + Block::make({for_stmt}); + + LoopNest::normalize(inner_for); + + auto result = IRSimplifier::simplify(for_stmt); + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( + # CHECK: for (int x = 50; x < 100; x++) { + # CHECK: for (int y = 0; y < 90; y++) { + # CHECK: A[x] = (((A[x]) + (B[y + 10])) + 2 * y) + 20; + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +TEST(LoopNest, NormalizeAndSplitWithTail) { + // Create a dummy tensor to construct LoopNest. + ExprHandle n(100); + BufHandle a("a", {n}, kFloat); + Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); }); + LoopNest l({b}); + + // Input IR: + // for (int x = 5; x < 10; x++) { + // A[x] = x * 2; + // } + const int kTotalSize = 5; + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt); + VarHandle x("x", kInt); + auto for_stmt = For::make(x, 5, 10, Store::make(a_buf, {x}, x * 2)); + auto parent_block = Block::make({for_stmt}); + + LoopNest::normalize(for_stmt); + + ForPtr x_inner; + ForPtr x_tail; + LoopNest::splitWithTail(for_stmt, 10, &x_inner, &x_tail); + + auto x_outer_result = IRSimplifier::simplify(for_stmt); + std::ostringstream oss_outer; + oss_outer << *x_outer_result; + const std::string& expected_outer_ir = + R"IR( + # CHECK: { + # CHECK: } + )IR"; + torch::jit::testing::FileCheck().run(expected_outer_ir, oss_outer.str()); + + auto x_tail_result = IRSimplifier::simplify(x_tail); + std::ostringstream oss_tail; + oss_tail << *x_tail_result; + const std::string& expected_tail_ir = + R"IR( + # CHECK: for (int x_tail = 0; x_tail < 5; x_tail++) { + # CHECK: A[x_tail + 5] = 2 * (x_tail + 5); + )IR"; + torch::jit::testing::FileCheck().run(expected_tail_ir, oss_tail.str()); +} + +TEST(LoopNest, NotNormalizeAndSplitWithTail) { + // Create a dummy tensor to construct LoopNest. + ExprHandle n(100); + BufHandle a("a", {n}, kFloat); + Tensor b = Compute("b", {n}, [&](const VarHandle& i) { return a.load(i); }); + LoopNest l({b}); + + // Input IR: + // for (int x = 5; x < 15; x++) { + // A[x] = x * 2; + // } + const int kTotalSize = 10; + BufHandle a_buf("A", {kTotalSize}, kInt); + VarHandle x("x", kInt); + auto for_stmt = For::make(x, 5, 15, Store::make(a_buf, {x}, x * 2)); + auto parent_block = Block::make({for_stmt}); + + ForPtr x_inner; + ForPtr x_tail; + LoopNest::splitWithTail(for_stmt, 8, &x_inner, &x_tail); + + auto x_outer_result = IRSimplifier::simplify(for_stmt); + std::ostringstream oss_outer; + oss_outer << *x_outer_result; + const std::string& expected_outer_ir = + R"IR( + # CHECK: { + # CHECK: } + )IR"; + torch::jit::testing::FileCheck().run(expected_outer_ir, oss_outer.str()); + + auto x_tail_result = IRSimplifier::simplify(x_tail); + std::ostringstream oss_tail; + oss_tail << *x_tail_result; + const std::string& expected_tail_ir = + R"IR( + # CHECK: for (int x_tail = 0; x_tail < 2; x_tail++) { + # CHECK: A[x_tail + 13] = 2 * (x_tail + 13); + )IR"; + torch::jit::testing::FileCheck().run(expected_tail_ir, oss_tail.str()); +} + +TEST(LoopNest, FlattenSimpleLoopNest2D) { + // Input IR: + // for (int i = 0; i < 10; i++) { + // for (int j = 0; j < 5; j++) { + // A[i,j] = i * j; + // } + // } + BufHandle a_buf("A", {10, 5}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); + auto inner_for = For::make(j, 0, 5, for_body); + auto outer_for = For::make(i, 0, 10, inner_for); + auto parent_block = Block::make({outer_for}); + + std::vector loops = {outer_for, inner_for}; + ForPtr flattened = nullptr; + ASSERT_TRUE(LoopNest::flatten(loops, &flattened)); + ASSERT_EQ(flattened, loops.front()); + + auto result = IRSimplifier::simplify(flattened); + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( + # CHECK: for (int i_flat = 0; i_flat < 50; i_flat++) { + # CHECK: A[i_flat / 5, i_flat % 5] = + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + { + SimpleIREvaluator eval1(loops[0], {a_buf}); + PaddedBuffer inp1(10, 5); + eval1(inp1); + SimpleIREvaluator eval2(flattened, {a_buf}); + PaddedBuffer inp2(10, 5); + eval2(inp2); + ExpectAllNear(inp1, inp2, 1e-5); + } +} + +TEST(LoopNest, FlattenSimpleLoopNest3D) { + // Input IR: + // for (int i = 0; i < 10; i++) { + // for (int j = 0; j < 5; j++) { + // for (int k = 0; k < 7; k++) { + // A[i,j,k] = i + j * k; + // } + // } + // } + BufHandle a_buf("A", {10, 5, 7}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto for_body = Block::make({Store::make(a_buf, {i, j, k}, i + j * k)}); + auto for1 = For::make(k, 0, 7, for_body); + auto for2 = For::make(j, 0, 5, for1); + auto for3 = For::make(i, 0, 10, for2); + auto parent_block = Block::make({for3}); + + std::vector loops = {for3, for2, for1}; + ForPtr flattened = nullptr; + ASSERT_TRUE(LoopNest::flatten(loops, &flattened)); + ASSERT_EQ(flattened, loops.front()); + + auto result = IRSimplifier::simplify(flattened); + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( + # CHECK: for (int i_flat = 0; i_flat < 350; i_flat++) { + # CHECK: A[i_flat / 35, (i_flat / 7) % 5, i_flat % 7] = + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + { + SimpleIREvaluator eval1(loops[0], {a_buf}); + PaddedBuffer inp1(10, 5, 7); + eval1(inp1); + SimpleIREvaluator eval2(flattened, {a_buf}); + PaddedBuffer inp2(10, 5, 7); + eval2(inp2); + ExpectAllNear(inp1, inp2, 1e-5); + } +} + +TEST(LoopNest, FlattenLoopNestAfterNormalize) { + // Input IR: + // for (int i = 2; i < 10; i++) { + // for (int j = 3; j < 15; j++) { + // A[i - 2,j - 3] = i * j; + // } + // } + BufHandle a_buf("A", {8, 12}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto for_body = Block::make({Store::make(a_buf, {i - 2, j - 3}, i * j)}); + auto inner_for = For::make(j, 3, 15, for_body); + auto outer_for = For::make(i, 2, 10, inner_for); + auto parent_block = Block::make({outer_for}); + + std::vector loops = {outer_for, inner_for}; + ForPtr flattened = nullptr; + ASSERT_TRUE(LoopNest::flatten(loops, &flattened)); + ASSERT_EQ(flattened, loops.front()); + + auto result = IRSimplifier::simplify(flattened); + std::ostringstream oss; + oss << *result; + const std::string& expected_ir = + R"IR( + # CHECK: for (int i_flat = 0; i_flat < 96; i_flat++) { + # CHECK: A[i_flat / 12, i_flat % 12] = + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + { + SimpleIREvaluator eval1(loops[0], {a_buf}); + PaddedBuffer inp1(8, 12); + eval1(inp1); + SimpleIREvaluator eval2(flattened, {a_buf}); + PaddedBuffer inp2(8, 12); + eval2(inp2); + ExpectAllNear(inp1, inp2, 1e-5); + } +} + +TEST(LoopNest, FlattenLoopNestWithNonLiteralConstantBounds) { + // Input IR: + // for (int i = 0; i < 15-5; i++) { + // for (int j = 0; j < 20/4; j++) { + // A[i,j] = i * j; + // } + // } + BufHandle a_buf("A", {10, 5}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); + auto inner_for = + For::make(j, 0, IntImm::make(20) / IntImm::make(4), for_body); + auto outer_for = + For::make(i, 0, IntImm::make(15) - IntImm::make(5), inner_for); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + auto b = Block::make({outer_for}); + + std::vector loops = {outer_for, inner_for}; + ForPtr flattened = nullptr; + ASSERT_TRUE(LoopNest::flatten(loops, &flattened)); + ASSERT_EQ(flattened, loops.front()); + + auto result = IRSimplifier::simplify(flattened); + checkIR(result, R"IR( + # CHECK: for (int i_flat = 0; i_flat < 50; i_flat++) { + # CHECK: A[i_flat / 5, i_flat % 5] = + )IR"); + + { + SimpleIREvaluator eval1(loops[0], {a_buf}); + PaddedBuffer inp1(10, 5); + eval1(inp1); + SimpleIREvaluator eval2(flattened, {a_buf}); + PaddedBuffer inp2(10, 5); + eval2(inp2); + ExpectAllNear(inp1, inp2, 1e-5); + } +} + +TEST(LoopNest, FlattenImperfectLoopNest) { + // Input IR: + // for (int i = 0; i < 10; i++) { + // A[i, i] = 0; + // for (int j = 0; j < 15; j++) { + // A[i,j] = i * j; + // } + // } + // Do not flatten. + + BufHandle a_buf("A", {10, 15}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)}); + auto inner_for = For::make(j, 0, 15, for_body); + auto outer_for = For::make( + i, 0, 10, Block::make({Store::make(a_buf, {i, i}, 0), inner_for})); + auto par = Block::make({outer_for}); + HashProvider hasher; + auto hash_before = hasher.hash(par); + + std::vector loops = {outer_for, inner_for}; + ForPtr flattened = nullptr; + ASSERT_FALSE(LoopNest::flatten(loops, &flattened)); + ASSERT_EQ(flattened, nullptr); + auto hash_after = hasher.hash(par); + ASSERT_EQ(hash_before, hash_after); +} + +TEST(LoopNest, FlattenReductionLoopNest) { + // Input IR: + // for (int i = 0; i < 10; i++) { + // S[i] = 0; + // for (int j = 0; j < 15; j++) { + // S[i] = S[i] + A[i,j]; + // } + // } + // Do not flatten. + + BufHandle a_buf("A", {10, 15}, kInt); + BufHandle s_buf("S", {10}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto for_body = Block::make({Store::make( + s_buf, {i}, Load::make(s_buf, {i}) + Load::make(a_buf, {i, j}))}); + auto inner_for = For::make(j, 0, 15, for_body); + auto outer_for = + For::make(i, 0, 10, Block::make({Store::make(s_buf, {i}, 0), inner_for})); + auto par = Block::make({outer_for}); + HashProvider hasher; + auto hash_before = hasher.hash(par); + + std::vector loops = {outer_for, inner_for}; + ForPtr flattened = nullptr; + ASSERT_FALSE(LoopNest::flatten(loops, &flattened)); + ASSERT_EQ(flattened, nullptr); + auto hash_after = hasher.hash(par); + ASSERT_EQ(hash_before, hash_after); +} + +TEST(LoopNest, FlattenReductionLoopNestFromTensor) { + const int M = 3; + const int N = 7; + VarHandle m("m", kInt); + VarHandle n("n", kInt); + BufHandle b("b", {m, n}, kFloat); + Tensor c = Reduce("sum", {M}, Sum(), b, {N}); + LoopNest loop({c}); + HashProvider hasher; + auto hash_before = hasher.hash(loop.root_stmt()); + + auto loops = loop.getAllLoopNestsWritingToBuf(c.buf())[1]; + ForPtr flattened = nullptr; + ASSERT_FALSE(LoopNest::flatten(loops, &flattened)); + ASSERT_EQ(flattened, nullptr); + auto hash_after = hasher.hash(loop.root_stmt()); + ASSERT_EQ(hash_before, hash_after); +} + +TEST(LoopNest, FlattenIncorrectLoopsAsInput) { + // Input IR: + // for (int i = 0; i < 10; i++) { + // for (int j = 0; j < 5; j++) { + // A[i,j] = i * j; + // } + // } + // for (int x = 0; x < 10; x++) { + // for (int y = 0; y < 5; y++) { + // A[x,y] = A[x,y] + x + y; + // } + // } + // Flatten({For_i, For_y}) => should not succeed + + BufHandle a_buf("A", {10, 5}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + auto for_body1 = Block::make({Store::make(a_buf, {i, j}, i * j)}); + auto inner_for1 = For::make(j, 0, 5, for_body1); + auto outer_for1 = For::make(i, 0, 10, inner_for1); + auto for_body2 = Block::make( + {Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}) + x + y)}); + auto inner_for2 = For::make(y, 0, 5, for_body2); + auto outer_for2 = For::make(x, 0, 10, inner_for2); + auto par = Block::make({outer_for1, outer_for2}); + HashProvider hasher; + auto hash_before = hasher.hash(par); + + std::vector loops = {outer_for1, inner_for2}; + ForPtr flattened = nullptr; + ASSERT_FALSE(LoopNest::flatten(loops, &flattened)); + ASSERT_EQ(flattened, nullptr); + auto hash_after = hasher.hash(par); + ASSERT_EQ(hash_before, hash_after); +} + +TEST(LoopNest, DetectInlineRankMismatch) { + const int kTotalSize = 8; + + BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat); + Tensor a = Compute( + "a", {kTotalSize}, [&](const VarHandle& i) { return a_buf.load(i); }); + Tensor reshape = Compute( + "reshape", + {kTotalSize / 2, 2}, + [&](const VarHandle& i, const VarHandle& j) { return a.load(i, j); }); + LoopNest l({reshape}, {a, reshape}); + ASSERT_FALSE(l.computeInline(l.getLoopBodyFor(a))); +} + +TEST(LoopNest, CacheReadsSimple) { + Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + Tensor B = + Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { + return A.load(i + 30, j + 3); + }); + Tensor C = + Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { + return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); + }); + + LoopNest l({B, C}, {A, B, C}); + StmtPtr j_loop = l.getAllLoopNestsWritingToBuf(B.buf())[0][1]; + LoopNest::cacheAccesses(A.buf(), "A_local", j_loop); + + l.prepareForCodegen(); + StmtPtr result = + LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); + SimpleIREvaluator cg(result, {B, C}); + result = cg.stmt(); + + // just this once: verify the whole thing. + checkIR(result, R"IR( +#CHECK: Allocate(A); // dtype=int, dims=[64, 64] +#CHECK: Allocate(A_local); // dtype=int, dims=[1, 10] +#CHECK: for (int i +#CHECK: for (int j +#CHECK: A[ +#CHECK: } +#CHECK: } +#CHECK: for (int i_1 +#CHECK: for (int j_1 +#CHECK: A_local[j_1] = A[ +#CHECK: } +#CHECK: for (int j_2 +#CHECK: B[j_2 + 10 * i_1] = A_local[j_2]; +#CHECK: } +#CHECK: } +#CHECK: for (int i_2 +#CHECK: for (int j_3 +#CHECK: C[ +#CHECK: } +#CHECK: } +#CHECK: Free(A_local); +#CHECK: Free(A); + )IR"); + + std::vector b_data(200, 0); + std::vector c_data(200, 0); + cg.call({b_data, c_data}); + + std::vector b_ref(200, 0); + std::vector c_ref(200, 0); + + for (int i = 0; i < 20; ++i) { + for (int j = 0; j < 10; ++j) { + b_ref[i * 10 + j] = (i + 30) * (j + 3); + c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); + } + } + + assertAllEqual(b_data, b_ref); + assertAllEqual(c_data, c_ref); +} + +TEST(LoopNest, CacheReadsOuter) { + Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + Tensor B = + Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { + return A.load(i + 30, j + 40) + A.load(i + 31, j + 41); + }); + Tensor C = + Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { + return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); + }); + + LoopNest l({B, C}, {A, B, C}); + StmtPtr i_loop = l.getAllLoopNestsWritingToBuf(B.buf())[0][0]; + LoopNest::cacheAccesses(A.buf(), "A_local", i_loop); + + l.prepareForCodegen(); + StmtPtr result = + LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); + SimpleIREvaluator cg(result, {B, C}); + result = cg.stmt(); + + checkIR(result, R"IR( +#CHECK: Allocate(A_local); // dtype=int, dims=[21, 11] +#CHECK: A_local[j_1 + 11 * i_1] = +#CHECK: B[j_2 + 10 * i_2] = (A_local[j_2 + 11 * i_2]) + (A_local[(j_2 + 11 * i_2) + 12]); + )IR"); + + std::vector b_data(200, 0); + std::vector c_data(200, 0); + cg.call({b_data, c_data}); + + std::vector b_ref(200, 0); + std::vector c_ref(200, 0); + + for (int i = 0; i < 20; ++i) { + for (int j = 0; j < 10; ++j) { + b_ref[i * 10 + j] = (i + 30) * (j + 40) + (i + 31) * (j + 41); + c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); + } + } + + assertAllEqual(b_data, b_ref); + assertAllEqual(c_data, c_ref); +} + +TEST(LoopNest, CacheReadsInternal) { + Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + Tensor B = + Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { + return A.load(i + 30, j + 40) + A.load(i + 31, j + 41); + }); + Tensor C = + Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { + return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); + }); + + LoopNest l({B, C}, {A, B, C}); + StmtPtr j_loop = l.getAllLoopNestsWritingToBuf(B.buf())[0][1]; + LoopNest::cacheAccesses(A.buf(), "A_local", j_loop); + l.prepareForCodegen(); + StmtPtr result = + LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); + SimpleIREvaluator cg(result, {B, C}); + result = cg.stmt(); + + checkIR(result, R"IR( +#CHECK: Allocate(A_local); // dtype=int, dims=[2, 11] +#CHECK: A_local[k + 11 * j_1] = +#CHECK: B[j_2 + 10 * i_1] = (A_local[j_2 + 12]) + (A_local[j_2]); + )IR"); + + std::vector b_data(200, 0); + std::vector c_data(200, 0); + cg.call({b_data, c_data}); + + std::vector b_ref(200, 0); + std::vector c_ref(200, 0); + + for (int i = 0; i < 20; ++i) { + for (int j = 0; j < 10; ++j) { + b_ref[i * 10 + j] = (i + 30) * (j + 40) + (i + 31) * (j + 41); + c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); + } + } + + assertAllEqual(b_data, b_ref); + assertAllEqual(c_data, c_ref); +} + +TEST(LoopNest, CacheReadsInner) { + Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + // note im changing the offset of the first arg of the first call to A. + Tensor B = + Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { + return A.load(i + 34, j + 40) + A.load(i + 30, j + 41); + }); + Tensor C = + Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { + return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); + }); + + LoopNest l({B, C}, {A, B, C}); + StmtPtr body = l.getLoopBodyFor(B); + LoopNest::cacheAccesses(A.buf(), "A_local", body); + l.prepareForCodegen(); + StmtPtr result = + LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); + SimpleIREvaluator cg(result, {B, C}); + result = cg.stmt(); + + checkIR(result, R"IR( +#CHECK: Allocate(A_local); // dtype=int, dims=[5, 2] +#CHECK: A_local[l + 2 * k] = +#CHECK: B[j_1 + 10 * i_1] = (A_local[1]) + (A_local[8]); + )IR"); + + std::vector b_data(200, 0); + std::vector c_data(200, 0); + cg.call({b_data, c_data}); + + std::vector b_ref(200, 0); + std::vector c_ref(200, 0); + + for (int i = 0; i < 20; ++i) { + for (int j = 0; j < 10; ++j) { + b_ref[i * 10 + j] = (i + 34) * (j + 40) + (i + 30) * (j + 41); + c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); + } + } + + assertAllEqual(b_data, b_ref); + assertAllEqual(c_data, c_ref); +} + +TEST(LoopNest, CacheWritesSimple) { + Tensor A = Compute("A", {64, 64}, [](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + Tensor B = + Compute("B", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { + return A.load(i + 30, j + 40) + A.load(i + 31, j + 41); + }); + Tensor C = + Compute("C", {20, 10}, [&](const VarHandle& i, const VarHandle& j) { + return A.load(i + 10, j + 20) + A.load(i + 30, j + 40); + }); + + LoopNest l({B, C}, {A, B, C}); + StmtPtr a_loop = l.getAllLoopNestsWritingToBuf(A.buf())[0][1]; + LoopNest::cacheAccesses(A.buf(), "A_local", a_loop); + + l.prepareForCodegen(); + StmtPtr result = + LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); + SimpleIREvaluator cg(result, {B, C}); + result = cg.stmt(); + + checkIR(result, R"IR( +#CHECK: Allocate(A_local); // dtype=int, dims=[1, 64] +#CHECK: for (int j = 0; j < 64 +#CHECK: A_local[j] = i * j; +#CHECK: for (int j_1 = 0; j_1 < 64 +#CHECK: A[j_1 + 64 * i] = A_local[ +#CHECK: Free(A_local); +#CHECK-NOT: A_local + )IR"); + + std::vector b_data(200, 0); + std::vector c_data(200, 0); + cg.call({b_data, c_data}); + + std::vector b_ref(200, 0); + std::vector c_ref(200, 0); + + for (int i = 0; i < 20; ++i) { + for (int j = 0; j < 10; ++j) { + b_ref[i * 10 + j] = (i + 30) * (j + 40) + (i + 31) * (j + 41); + c_ref[i * 10 + j] = (i + 10) * (j + 20) + (i + 30) * (j + 40); + } + } + + assertAllEqual(b_data, b_ref); + assertAllEqual(c_data, c_ref); +} + +TEST(LoopNest, DeadStoreElimination) { + VarHandle y("y", kInt); + VarHandle x("x_tail", kInt); + BufHandle f("f", {26, 5}, kInt); + BufHandle g("g", {26, 5}, kInt); + ExprHandle x_outer_end = 5; + ExprHandle x_2 = x + x_outer_end * 4; + ForPtr stmt1 = For::make( + x, + 0, + 5, + For::make( + y, + 0, + 5, + Block::make({ + Store::make(f, {x_2, y}, (x_2 + y)), + Store::make(g, {x_2, y}, (x_2 * y)), + }))); + StmtPtr stmt = Block::make({stmt1}); + + // Will eliminate if not used by an output. + LoopNest loop(Stmt::clone(stmt), {f.node()}); + loop.eliminateDeadStores(); + + checkIR(loop.root_stmt(), R"IR( +#CHECK: f[x_tail + 5 * 4, y] +#CHECK-NOT: g[x_tail + 5 * 4, y] + )IR"); + + // But won't eliminate if used by different outputs. + LoopNest loop2(stmt, {f.node(), g.node()}); + loop2.eliminateDeadStores(); + + checkIR(loop2.root_stmt(), R"IR( +#CHECK: f[x_tail + 5 * 4, y] +#CHECK: g[x_tail + 5 * 4, y] + )IR"); +} + +TEST(LoopNest, DeadStoreEliminationWithIntermediates) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + BufHandle f("f", {26 * 5}, kInt); + BufHandle g("g", {26 * 5}, kInt); + BufHandle h("h", {26, 5}, kInt); + ExprHandle x_outer_end = 5; + ExprHandle x_2 = x + x_outer_end * 4; + ForPtr stmt1 = For::make(x, 0, 26 * 5, Store::make(f, {x}, x)); + ForPtr stmt2 = For::make(z, 0, 26 * 5, Store::make(g, {z}, z + 1)); + ForPtr stmt3 = For::make( + x, + 0, + 5, + For::make( + y, + 0, + 5, + Block::make({ + Store::make(h, {x, y}, Load::make(f, {x * y})), + }))); + StmtPtr stmt = Block::make({stmt1, stmt2, stmt3}); + + // Will eliminate the write to g, but not f since it used by the producer of + // h. + LoopNest loop(Stmt::clone(stmt), {h.node()}); + loop.eliminateDeadStores(); + + checkIR(loop.root_stmt(), R"IR( + #CHECK: f[x] = x; + #CHECK-NOT: g[z] = + #CHECK: h[x, y] = f[x * y]; + )IR"); + + // Sanity check won't eliminate if g is an output. + LoopNest loop2(stmt, {h.node(), g.node()}); + loop2.eliminateDeadStores(); + + checkIR(loop2.root_stmt(), R"IR( + #CHECK: f[x] = x; + #CHECK: g[z] = z + 1; + #CHECK: h[x, y] = f[x * y]; + )IR"); +} + +TEST(LoopNest, CompoundTensorSimple) { + BufHandle a_buf("A", {10, 5}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + auto for_body1 = Block::make({Store::make(a_buf, {i, j}, i * j)}); + auto inner_for1 = For::make(j, 0, 5, for_body1); + auto outer_for1 = For::make(i, 0, 10, inner_for1); + auto for_body2 = Block::make( + {Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}) + x + y)}); + auto inner_for2 = For::make(y, 0, 5, for_body2); + auto outer_for2 = For::make(x, 0, 10, inner_for2); + BlockPtr body = Block::make({outer_for1, outer_for2}); + + Tensor A = Tensor(a_buf.node(), body); + + LoopNest l({A}); + l.prepareForCodegen(); + + std::vector a_data(50, 0); + + StmtPtr s = IRSimplifier::simplify(l.root_stmt()); + SimpleIREvaluator cg(s, {A}); + + std::vector a_ref(50, 0); + + for (int i = 0; i < 10; ++i) { + for (int j = 0; j < 5; ++j) { + a_ref[i * 5 + j] = (i * j) + i + j; + } + } + cg.call({a_data}); + + assertAllEqual(a_data, a_ref); +} + +TEST(LoopNest, InlineConstantIndex) { + const int N = 10; + BufHandle x_buf("a", {1, N, 1}, kFloat); + Tensor y = Compute( + "f", + {1, N, 1}, + [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& o) { + return x_buf.load(m, n, o); + }); + Tensor z = Compute( + "f", + {1, N, 1}, + [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& o) { + return y.load(m, n, o); + }); + + LoopNest l({z}, {y, z}); + l.simplify(); + ASSERT_TRUE(l.computeInline(y.buf())); +} + +TEST(LoopNest, CompoundTensorUsed) { + BufHandle a_buf("A", {10, 5}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + auto for_body1 = Block::make({Store::make(a_buf, {i, j}, i * j)}); + auto inner_for1 = For::make(j, 0, 5, for_body1); + auto outer_for1 = For::make(i, 0, 10, inner_for1); + auto for_body2 = Block::make( + {Store::make(a_buf, {x, y}, Load::make(a_buf, {x, y}) + x + y)}); + auto inner_for2 = For::make(y, 0, 5, for_body2); + auto outer_for2 = For::make(x, 0, 10, inner_for2); + BlockPtr body = Block::make({outer_for1, outer_for2}); + + Tensor A = Tensor(a_buf.node(), body); + Tensor B = Compute("B", {10, 3}, [&](const VarHandle& i, const VarHandle& j) { + return A.load(i, j + 1) + A.load(i, j + 2); + }); + + LoopNest l({B}, {A, B}); + ASSERT_FALSE(l.computeInline(A.buf())); + l.prepareForCodegen(); + + std::vector a_data(50, 0); + std::vector b_data(50, 0); + + StmtPtr s = IRSimplifier::simplify(l.root_stmt()); + SimpleIREvaluator cg(s, {B}); + + std::vector b_ref(50, 0); + + auto AT = [](int i, int j) { return i * j + i + j; }; + for (int i = 0; i < 10; ++i) { + for (int j = 0; j < 3; ++j) { + b_ref[i * 3 + j] = AT(i, j + 1) + AT(i, j + 2); + } + } + cg.call({b_data}); + + assertAllEqual(b_data, b_ref); +} + +TEST(LoopNest, InlineFromLoad) { + constexpr int N = 1024; + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto store_a = For::make(i, 0, N, Store::make(a, {i}, i)); + auto store_b = For::make(j, 0, N, Store::make(b, {j}, Load::make(a, {j}))); + LoopNest l(Block::make({store_a, store_b}), {b.node()}); + + l.computeInline(a.node()); + + // Check that A[j] is replaced with j after inlining + std::ostringstream oss; + oss << *l.root_stmt(); + torch::jit::testing::FileCheck().run( + R"IR( +# CHECK: for (int j +# CHECK-NOT: B[j] = A[j] +# CHECK-NEXT: B[j] = j +)IR", + oss.str()); +} + +TEST(LoopNest, OptimizeConditionalsSimple) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5]) + // } + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle a_buf("A", {20}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle b_buf("B", {5}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle c_buf("C", {15}, kInt); + VarHandle i("i", kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto store = Store::make( + a_buf, + {i}, + IfThenElse::make( + CompareSelect::make(i, 5, kLT), + Load::make(b_buf, {i}), + Load::make(c_buf, {i - 5}))); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto forI = For::make(i, 0, 20, store); + auto par = Block::make({forI}); + + LoopNest nest(par, {a_buf.node()}); + nest.optimizeConditionals(); + + std::ostringstream oss; + oss << *nest.root_stmt(); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i = 0; i < 5 +# CHECK-NEXT: A[i] = B[i] +# CHECK: for (int i = 0; i < 15 +# CHECK-NEXT: A[i + 5] = C[i] + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(LoopNest, OptimizeConditionalsNestedConditions) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // A[i] = IfThenElse(i<10, IfThenElse(i<5, B[i], C[i-5]), D[i-10]) + // } + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle a_buf("A", {20}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle b_buf("B", {5}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle c_buf("C", {5}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle d_buf("D", {10}, kInt); + VarHandle i("i", kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto store = Store::make( + a_buf, + {i}, + IfThenElse::make( + CompareSelect::make(i, 10, kLT), + IfThenElse::make( + CompareSelect::make(i, 5, kLT), + Load::make(b_buf, {i}), + Load::make(c_buf, {i - 5})), + Load::make(d_buf, {i - 10}))); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto forI = For::make(i, 0, 20, store); + auto par = Block::make({forI}); + + LoopNest nest(par, {a_buf.node()}); + nest.optimizeConditionals(); + + std::ostringstream oss; + oss << *nest.root_stmt(); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i = 0; i < 5 +# CHECK-NEXT: A[i] = B[i] +# CHECK: for (int i = 0; i < 5 +# CHECK-NEXT: A[i + 5] = C[i] +# CHECK: for (int i = 0; i < 10 +# CHECK-NEXT: A[i + 10] = D[i] + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(LoopNest, OptimizeConditionalsMultipleStores) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5]) + // } + // for (int j = 0; j < 100; j++) { + // B[j] = IfThenElse(j<30 ? 1 : 0, C[j], D[j]) + // } + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle a_buf("A", {20}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle b_buf("B", {5}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle c_buf("C", {100}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle d_buf("D", {100}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto storeA = Store::make( + a_buf, + {i}, + IfThenElse::make( + CompareSelect::make(i, 5, kLT), + Load::make(b_buf, {i}), + Load::make(c_buf, {i - 5}))); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto forI = For::make(i, 0, 20, storeA); + auto storeB = Store::make( + b_buf, + {j}, + IfThenElse::make( + CompareSelect::make(j, 30, kLT), + Load::make(c_buf, {j}), + Load::make(d_buf, {j}))); + auto forJ = For::make(j, 0, 100, storeB); + auto par = Block::make({forI, forJ}); + + LoopNest nest(par, {a_buf.node()}); + nest.optimizeConditionals(); + + std::ostringstream oss; + oss << *nest.root_stmt(); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i = 0; i < 5 +# CHECK-NEXT: A[i] = B[i] +# CHECK: for (int i = 0; i < 15 +# CHECK-NEXT: A[i + 5] = C[i] +# CHECK: for (int j = 0; j < 30 +# CHECK-NEXT: B[j] = C[j] +# CHECK: for (int j = 0; j < 70 +# CHECK-NEXT: B[j + 30] = D[j + 30] + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(LoopNest, OptimizeConditionalsMultipleStoresInOneLoop) { + // Input IR: + // for (int i = 0; i < 50; i++) { + // A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5]) + // B[j] = IfThenElse(j<30 ? 1 : 0, C[j], D[j]) + // } + // Only the first conditional, in the write to A, will be optimized. + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle a_buf("A", {100}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle b_buf("B", {100}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle c_buf("C", {100}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle d_buf("D", {100}, kInt); + VarHandle i("i", kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto storeA = Store::make( + a_buf, + {i}, + IfThenElse::make( + CompareSelect::make(i, 5, kLT), + Load::make(b_buf, {i}), + Load::make(c_buf, {i - 5}))); + auto storeB = Store::make( + b_buf, + {i}, + IfThenElse::make( + CompareSelect::make(i, 30, kLT), + Load::make(c_buf, {i}), + Load::make(d_buf, {i}))); + auto forI = For::make(i, 0, 50, Block::make({storeA, storeB})); + auto par = Block::make({forI}); + + LoopNest nest(par, {a_buf.node()}); + nest.optimizeConditionals(); + + std::ostringstream oss; + oss << *nest.root_stmt(); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i = 0; i < 5 +# CHECK-NEXT: A[i] = B[i] +# CHECK-NEXT: B[i] = C[i] +# CHECK: for (int i = 0; i < 45 +# CHECK-NEXT: A[i + 5] = C[i] +# CHECK-NEXT: B[i + 5] = IfThenElse(i + 5<30 ? 1 : 0, C[i + 5], D[i + 5]) + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(LoopNest, OptimizeConditionalsOuterLoopVar) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // for (int j = 0; j < 100; j++) { + // A[i] = IfThenElse(i<10, IfThenElse(i<5, B[i], C[i-5]), D[i-10]) + // } + // } + // Currently, this case where the condition variable `i` is not the + // inner-most loop variable, is not optimized. + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle a_buf("A", {20}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle b_buf("B", {5}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle c_buf("C", {5}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle d_buf("D", {10}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto store = Store::make( + a_buf, + {i}, + IfThenElse::make( + CompareSelect::make(i, 10, kLT), + IfThenElse::make( + CompareSelect::make(i, 5, kLT), + Load::make(b_buf, {i}), + Load::make(c_buf, {i - 5})), + Load::make(d_buf, {i - 10}))); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto forI = For::make(i, 0, 20, For::make(j, 0, 100, store)); + auto par = Block::make({forI}); + LoopNest nest(par, {a_buf.node()}); + + HashProvider hasher; + auto hash_before = hasher.hash(nest.root_stmt()); + nest.optimizeConditionals(); + auto hash_after = hasher.hash(nest.root_stmt()); + ASSERT_EQ(hash_before, hash_after); +} + +TEST(LoopNest, OptimizeConditionalsCompValuesNotOrdered) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // A[i] = IfThenElse(i<5, IfThenElse(i<10, B[i], C[i-5]), D[i-10]) + // } + // No optimization should be done here because one of the conditions use '>'. + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle a_buf("A", {20}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle b_buf("B", {5}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle c_buf("C", {5}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle d_buf("D", {10}, kInt); + VarHandle i("i", kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto store = Store::make( + a_buf, + {i}, + IfThenElse::make( + CompareSelect::make(i, 5, kLT), + IfThenElse::make( + CompareSelect::make(i, 10, kLT), + Load::make(b_buf, {i}), + Load::make(c_buf, {i - 5})), + Load::make(d_buf, {i - 10}))); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto forI = For::make(i, 0, 20, store); + auto par = Block::make({forI}); + LoopNest nest(par, {a_buf.node()}); + + HashProvider hasher; + auto hash_before = hasher.hash(nest.root_stmt()); + nest.optimizeConditionals(); + auto hash_after = hasher.hash(nest.root_stmt()); + ASSERT_EQ(hash_before, hash_after); +} + +TEST(LoopNest, OptimizeConditionalsCompValuesNotConstants) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // A[i] = IfThenElse(i'. + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle a_buf("A", {20}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle b_buf("B", {5}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle c_buf("C", {5}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle d_buf("D", {10}, kInt); + VarHandle i("i", kInt); + VarHandle N("N", kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto store = Store::make( + a_buf, + {i}, + IfThenElse::make( + CompareSelect::make(i, N, kLT), + IfThenElse::make( + CompareSelect::make(i, 5, kLT), + Load::make(b_buf, {i}), + Load::make(c_buf, {i - 5})), + Load::make(d_buf, {i - 10}))); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto forI = For::make(i, 0, 20, store); + auto par = Block::make({forI}); + LoopNest nest(par, {a_buf.node()}); + + HashProvider hasher; + auto hash_before = hasher.hash(nest.root_stmt()); + nest.optimizeConditionals(); + auto hash_after = hasher.hash(nest.root_stmt()); + ASSERT_EQ(hash_before, hash_after); +} + +TEST(LoopNest, OptimizeConditionalsInvalidCondition) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // A[i] = IfThenElse(i<10, IfThenElse(i>5, B[i], C[i-5]), D[i-10]) + // } + // No optimization should be done here because one of the conditions use '>'. + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle a_buf("A", {20}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle b_buf("B", {5}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle c_buf("C", {5}, kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + BufHandle d_buf("D", {10}, kInt); + VarHandle i("i", kInt); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto store = Store::make( + a_buf, + {i}, + IfThenElse::make( + CompareSelect::make(i, 10, kLT), + IfThenElse::make( + CompareSelect::make(i, 5, kGT), + Load::make(b_buf, {i}), + Load::make(c_buf, {i - 5})), + Load::make(d_buf, {i - 10}))); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + auto forI = For::make(i, 0, 20, store); + auto par = Block::make({forI}); + LoopNest nest(par, {a_buf.node()}); + + HashProvider hasher; + auto hash_before = hasher.hash(nest.root_stmt()); + nest.optimizeConditionals(); + auto hash_after = hasher.hash(nest.root_stmt()); + ASSERT_EQ(hash_before, hash_after); +} + +TEST(LoopNest, OptimizeConditionalsInvalidCondition2) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // A[i] = IfThenElse(10 colReduce(int M, int N) { + BufHandle a("a", {M, N}, kFloat); + Tensor t = Reduce( + "b", + {N}, + Sum(), + [&](const VarHandle& n, const VarHandle& m) { return a.load(m, n); }, + {M}); + return {a, Tensor(t.buf(), LoopNest::sanitizeNames(t.stmt()))}; +} + +static StmtPtr splitTailReorder(Tensor b) { + constexpr int kVectorWidth = 8; + LoopNest nest({b}); + auto loops = nest.getAllLoopNestsWritingToBuf(b.buf())[0]; + nest.splitWithTail(loops[0], kVectorWidth); + // Now the loopnests will look like: + // + // for (int i_outer = 0; ... + // for (int i_inner = 0; ... + // b[i_outer * 8 + i_inner] = float(0); + // for (int j = 0; ... + // b[i_outer * 8 + i_inner] = ReduceOp(...); + // + // for (int i_tail = 0; ... + // b[i_tail + ((100 - 0) / 8) * 8] = float(0); + // for (int j = 0; ... + // b[i_tail + ((100 - 0) / 8) * 8] = ReduceOp(...); + // + // Since there are 4 writes to b, we will get 4 loopnests from the + // call to `getAllLoopNestsWritingToBuf` below. + // + // Write #2: "b[i_outer * 8 + i_inner] = ReduceOp(...)" + // Loopnest #2: {i_outer, i_inner, j}; + // We will have to reorder i_inner and j. + auto loopnests = nest.getAllLoopNestsWritingToBuf(b.buf()); + LoopNest::reorderAxis(loopnests[1][1], loopnests[1][2]); + nest.prepareForCodegen(); + return nest.root_stmt(); +} + +static StmtPtr splitMaskReorder(Tensor b) { + constexpr int kVectorWidth = 8; + LoopNest nest({b}); + auto loops = nest.getAllLoopNestsWritingToBuf(b.buf())[1]; + nest.splitWithMask(loops[0], kVectorWidth); + loops = nest.getAllLoopNestsWritingToBuf(b.buf())[1]; + LoopNest::reorderAxis(loops[1], loops[2]); + nest.prepareForCodegen(); + return nest.root_stmt(); +} + +static void checkColReduce(StmtPtr s, BufHandle p, Tensor t) { + int M = immediateAs(p.dim(0)); + int N = immediateAs(p.dim(1)); + PaddedBuffer a(M, N); + PaddedBuffer b(N); + PaddedBuffer ref(N); + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + a(i, j) = 1.0f; + } + } + for (int i = 0; i < N; i++) { + b(i) = 0.0f; + } + for (int i = 0; i < N; i++) { + ref(i) = 76.0f; + } + SimpleIREvaluator(s, {p, t}).call({a, b}); + ExpectAllNear(b, ref, 1e-5); +} + +TEST(LoopNest, ColReduceSplitTailEvenReorder) { + constexpr int M = 76, N = 128; + auto p = colReduce(M, N); + StmtPtr s = splitTailReorder(p.second); + + std::ostringstream oss; + oss << *s; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i_outer +# CHECK-NEXT: for (int i_inner +# CHECK-NEXT: b[ +# CHECK: for (int j +# CHECK-NEXT: for (int i_inner +# CHECK-NEXT: b[ +# CHECK-NOT: for ( + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + checkColReduce(s, p.first, p.second); +} + +TEST(LoopNest, ColReduceSplitTailUnevenReorder) { + constexpr int M = 76, N = 100; + auto p = colReduce(M, N); + StmtPtr s = splitTailReorder(p.second); + + std::ostringstream oss; + oss << *s; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i_outer +# CHECK-NEXT: for (int i_inner +# CHECK-NEXT: b[ +# CHECK: for (int j +# CHECK-NEXT: for (int i_inner +# CHECK-NEXT: b[ +# CHECK: for (int i_tail +# CHECK-NEXT: b[ +# CHECK-NEXT: for (int j +# CHECK-NEXT: b[ + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + checkColReduce(s, p.first, p.second); +} + +TEST(LoopNest, ColReduceSplitMaskEvenReorder) { + constexpr int M = 76, N = 128; + auto p = colReduce(M, N); + StmtPtr s = splitMaskReorder(p.second); + checkColReduce(s, p.first, p.second); +} + +TEST(LoopNest, ColReduceSplitMaskUnevenReorder) { + constexpr int M = 76, N = 100; + auto p = colReduce(M, N); + StmtPtr s = splitMaskReorder(p.second); + checkColReduce(s, p.first, p.second); +} + +TEST(LoopNest, ReorderAxisWithMultipleConds) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // if i > 5 { + // if i < 10 { + // for (int j = 0; j < 100; j++) { + // A[i] = i * j; + // } + // } + // } + // } + BufHandle a_buf("A", {20}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto forJ = For::make(j, 0, 100, Store::make(a_buf, {i}, Mul::make(i, j))); + auto inner_cond = Cond::make(CompareSelect::make(i, 10, kLT), forJ, nullptr); + auto outer_cond = + Cond::make(CompareSelect::make(i, 5, kGT), inner_cond, nullptr); + auto forI = For::make(i, 0, 20, outer_cond); + StmtPtr par = Block::make({forI}); + LoopNest l(par, {a_buf.node()}); + LoopNest::reorderAxis(forI, forJ); + ASSERT_EQ(par, l.root_stmt()); + par = IRSimplifier::simplify(par); + + const std::string& verification_pattern = + R"IR( +# CHECK: for (int j +# CHECK-NEXT: for (int i +# CHECK-NEXT: if (i>5 +# CHECK-NEXT: if (i<10 +# CHECK-NEXT: A[i] = i * j +# CHECK-NOT: for ( + )IR"; + std::ostringstream oss; + oss << *par; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(LoopNest, VectorizeUse) { + constexpr int N = 8; + BufHandle a("a", {N}, kFloat); + Tensor b = + Compute("b", {N}, [&](const VarHandle& n) { return a.load(n) + 1.0f; }); + Tensor c = + Compute("c", {N}, [&](const VarHandle& n) { return b.load(n) + 2.0f; }); + LoopNest nest({c}, {b, c}); + auto loops = nest.getAllLoopNestsWritingToBuf(b.buf())[0]; + ASSERT_TRUE(LoopNest::vectorize(loops[0])); + loops = nest.getAllLoopNestsWritingToBuf(c.buf())[0]; + ASSERT_TRUE(LoopNest::vectorize(loops[0])); + nest.prepareForCodegen(); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + StmtPtr s = nest.root_stmt(); + std::ostringstream oss; + oss << *nest.root_stmt(); + torch::jit::testing::FileCheck().run( + R"IR( +# CHECK: c[Ramp +)IR", + oss.str()); +} + +const char* int64Loop = R"IR( +# CHECK: for (int64_t i = 0ll; i < 12ll; i++) { +# CHECK: b[i] = (a[i]) + 1ll; +# CHECK: } +)IR"; + +TEST(LoopNest, Int64Direct) { + constexpr int64_t N = 12; + BufHandle a("a", {N}, kLong); + BufHandle b("b", {N}, kLong); + VarHandle n("i", kLong); + StmtPtr s = For::make( + n, LongImm::make(0l), N, b.store({n}, a.load({n}) + LongImm::make(1l))); + s = IRSimplifier::simplify(s); + std::ostringstream oss; + oss << *s; + torch::jit::testing::FileCheck().run(int64Loop, oss.str()); +} + +TEST(LoopNest, Int64Compute) { + constexpr int64_t N = 12; + BufHandle a("a", {N}, kLong); + Tensor b = Compute("b", {N}, [&](const VarHandle& n) { + return a.load(n) + LongImm::make(1l); + }); + LoopNest nest({b}); + nest.prepareForCodegen(); + nest.simplify(); + std::ostringstream oss; + oss << *nest.root_stmt(); + torch::jit::testing::FileCheck().run(int64Loop, oss.str()); +} + +TEST(LoopNest, DistributeLoopWithAllStmtsAsPivots) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // A[i] = 0; + // for (int j = 0; j < 100; j++) { + // A[i] = A[i] + i * j; + // } + // B[i] = A[i]; + // for (int k = 0; k < 50; k++) { + // B[i] = B[i] + i * k; + // } + // } + BufHandle a_buf("A", {20}, kInt); + BufHandle b_buf("B", {20}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto initA = Store::make(a_buf, {i}, 0); + auto forJ = For::make( + j, + 0, + 100, + Store::make( + a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j)))); + auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i})); + auto forK = For::make( + k, + 0, + 50, + Store::make( + b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k)))); + auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK})); + auto par = Block::make({forI}); + + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: A[i] = 0 +# CHECK: for (int i +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[i] = +# CHECK: for (int i +# CHECK-NEXT: B[i] = A[i] +# CHECK: for (int i +# CHECK-NEXT: for (int k +# CHECK-NEXT: B[i] = +# CHECK-NOT: for ( + )IR"; + + LoopNest nest(par, {a_buf.node(), b_buf.node()}); + auto new_loops = LoopNest::distributeLoop(forI, {initA, forJ, initB}); + + std::ostringstream oss; + oss << *par; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The first loop after distribution must be same as the original For. + ASSERT_EQ(new_loops.front(), forI); +} + +TEST(LoopNest, DistributeLoopWithOneStmtAsPivot) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // A[i] = 0; + // for (int j = 0; j < 100; j++) { + // A[i] = A[i] + i * j; + // } + // B[i] = A[i]; + // for (int k = 0; k < 50; k++) { + // B[i] = B[i] + i * k; + // } + // } + BufHandle a_buf("A", {20}, kInt); + BufHandle b_buf("B", {20}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto initA = Store::make(a_buf, {i}, 0); + auto forJ = For::make( + j, + 0, + 100, + Store::make( + a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j)))); + auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i})); + auto forK = For::make( + k, + 0, + 50, + Store::make( + b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k)))); + auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK})); + auto par = Block::make({forI}); + + LoopNest nest(par, {a_buf.node(), b_buf.node()}); + auto new_loops = LoopNest::distributeLoop(forI, {forJ}); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: A[i] = 0 +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[i] = +# CHECK: for (int i +# CHECK-NEXT: B[i] = A[i] +# CHECK-NEXT: for (int k +# CHECK-NEXT: B[i] = +# CHECK-NOT: for ( + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The first loop after distribution must be same as the original For. + ASSERT_EQ(new_loops.front(), forI); +} + +TEST(LoopNest, DistributeLoopWithoutAnyPivot) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // A[i] = 0; + // for (int j = 0; j < 100; j++) { + // A[i] = A[i] + i * j; + // } + // B[i] = A[i]; + // for (int k = 0; k < 50; k++) { + // B[i] = B[i] + i * k; + // } + // } + BufHandle a_buf("A", {20}, kInt); + BufHandle b_buf("B", {20}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto initA = Store::make(a_buf, {i}, 0); + auto forJ = For::make( + j, + 0, + 100, + Store::make( + a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j)))); + auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i})); + auto forK = For::make( + k, + 0, + 50, + Store::make( + b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k)))); + auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK})); + auto par = Block::make({forI}); + + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: A[i] = 0 +# CHECK: for (int i +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[i] = +# CHECK: for (int i +# CHECK-NEXT: B[i] = A[i] +# CHECK: for (int i +# CHECK-NEXT: for (int k +# CHECK-NEXT: B[i] = +# CHECK-NOT: for ( + )IR"; + + LoopNest nest(par, {a_buf.node(), b_buf.node()}); + auto new_loops = LoopNest::distributeLoop(forI); + + std::ostringstream oss; + oss << *par; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The first loop after distribution must be same as the original For. + ASSERT_EQ(new_loops.front(), forI); +} + +TEST(LoopNest, DistributeLoopOverInnerLoops) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // A[i] = 0; + // for (int j = 0; j < 100; j++) { + // A[i] = A[i] + i * j; + // } + // B[i] = A[i]; + // for (int k = 0; k < 50; k++) { + // B[i] = B[i] + i * k; + // } + // } + BufHandle a_buf("A", {20}, kInt); + BufHandle b_buf("B", {20}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto initA = Store::make(a_buf, {i}, 0); + auto forJ = For::make( + j, + 0, + 100, + Store::make( + a_buf, {i}, Add::make(Load::make(a_buf, {i}), Mul::make(i, j)))); + auto initB = Store::make(b_buf, {i}, Load::make(a_buf, {i})); + auto forK = For::make( + k, + 0, + 50, + Store::make( + b_buf, {i}, Add::make(Load::make(b_buf, {i}), Mul::make(i, k)))); + auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK})); + auto par = Block::make({forI}); + + LoopNest nest(par, {a_buf.node(), b_buf.node()}); + auto new_loops = LoopNest::distributeLoopOverInnerLoops(forI); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: A[i] = 0 +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[i] = +# CHECK: for (int i +# CHECK-NEXT: B[i] = A[i] +# CHECK-NEXT: for (int k +# CHECK-NEXT: B[i] = +# CHECK-NOT: for ( + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The first loop after distribution must be same as the original For. + ASSERT_EQ(new_loops.front(), forI); +} + +TEST(LoopNest, DistributeLoopAndParentsWithoutAnyPivot) { + // Input IR: + // for (int m = 0; m < 50; m++) { + // for (int i = 0; i < 20; i++) { + // A[m,i] = 0; + // for (int j = 0; j < 100; j++) { + // A[m,i] = A[m,i] + i * j; + // } + // B[m,i] = A[m,i]; + // for (int k = 0; k < 50; k++) { + // B[m,i] = B[m,i] + i * k; + // } + // } + // } + BufHandle a_buf("A", {100, 100}, kInt); + BufHandle b_buf("B", {100, 100}, kInt); + VarHandle m("m", kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto initA = Store::make(a_buf, {m, i}, 0); + auto forJ = For::make( + j, + 0, + 100, + Store::make( + a_buf, + {m, i}, + Add::make(Load::make(a_buf, {m, i}), Mul::make(i, j)))); + auto initB = Store::make(b_buf, {m, i}, Load::make(a_buf, {m, i})); + auto forK = For::make( + k, + 0, + 50, + Store::make( + b_buf, + {m, i}, + Add::make(Load::make(b_buf, {m, i}), Mul::make(i, k)))); + auto forI = For::make(i, 0, 20, Block::make({initA, forJ, initB, forK})); + + { + // Check the case of distributing loop and its parents over all the + // statements in the loop. + const std::string& verification_pattern = + R"IR( +# CHECK: for (int m +# CHECK-NEXT: for (int i +# CHECK-NEXT: A[m, i] = 0 +# CHECK: for (int m +# CHECK-NEXT: for (int i +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[m, i] = +# CHECK: for (int m +# CHECK-NEXT: for (int i +# CHECK-NEXT: B[m, i] = A[m, i] +# CHECK: for (int m +# CHECK-NEXT: for (int i +# CHECK-NEXT: for (int k +# CHECK-NEXT: B[m, i] = +# CHECK-NOT: for ( + )IR"; + + auto newForI = to(Stmt::clone(forI)); + auto forM = For::make(m, 0, 50, newForI); + auto par = Block::make({forM}); + LoopNest nest(par, {a_buf.node(), b_buf.node()}); + auto newLoops = LoopNest::distributeLoopAndParents(newForI); + + std::ostringstream oss; + oss << *par; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The first loop after distribution must be same as the original For. + ASSERT_EQ(newLoops.front(), forM); + } + + { + // Check the case of distributing loop and its parents over all the inner + // loops. + const std::string& verification_pattern = + R"IR( +# CHECK: for (int m +# CHECK-NEXT: for (int i +# CHECK-NEXT: A[m, i] = 0 +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[m, i] = +# CHECK: for (int m +# CHECK-NEXT: for (int i +# CHECK-NEXT: B[m, i] = A[m, i] +# CHECK-NEXT: for (int k +# CHECK-NEXT: B[m, i] = +# CHECK-NOT: for ( + )IR"; + + auto newForI = to(Stmt::clone(forI)); + auto forM = For::make(m, 0, 50, newForI); + auto par = Block::make({forM}); + LoopNest nest(par, {a_buf.node(), b_buf.node()}); + auto newLoops = LoopNest::distributeLoopAndParentsOverInnerLoops(newForI); + + std::ostringstream oss; + oss << *par; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The first loop after distribution must be same as the original For. + ASSERT_EQ(newLoops.front(), forM); + } +} + +TEST(LoopNest, fuseLoopsSimple) { + // Input IR: + // for (int j = 0; j < 100; j++) { + // A[j] = 10 * j; + // } + // for (int k = 0; k < 100; k++) { + // B[k] = 20 * k; + // } + BufHandle a_buf("A", {100}, kInt); + BufHandle b_buf("B", {100}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); + auto forK = For::make(k, 0, 100, Store::make(b_buf, {k}, Mul::make(20, k))); + auto par = Block::make({forJ, forK}); + ForPtr fused_loop; + ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int j +# CHECK-NEXT: A[j] = +# CHECK-NEXT: B[j] = +# CHECK-NOT: for ( + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The fused loop must be the same as the first loop. + ASSERT_EQ(fused_loop, forJ); +} + +TEST(LoopNest, fuseLoopsMultiple) { + // Input IR: + // for (int i = 0; i < 100; i++) { + // A[i+100] = 20 + i; + // } + // for (int j = 0; j < 100; j++) { + // A[j] = 10 * j; + // } + // for (int k = 0; k < 100; k++) { + // B[k] = 20 * k; + // } + BufHandle a_buf("A", {200}, kInt); + BufHandle b_buf("B", {100}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forI = + For::make(i, 0, 100, Store::make(a_buf, {i + 100}, Add::make(20, i))); + auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); + auto forK = For::make(k, 0, 100, Store::make(b_buf, {k}, Mul::make(20, k))); + auto par = Block::make({forI, forJ, forK}); + ForPtr fused_loop; + ASSERT_TRUE(LoopNest::fuseLoops({forI, forJ, forK}, &fused_loop)); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: A[i + 100] = +# CHECK-NEXT: A[i] = +# CHECK-NEXT: B[i] = +# CHECK-NOT: for ( + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The fused loop must be the same as the first loop. + ASSERT_EQ(fused_loop, forI); +} + +TEST(LoopNest, fuseLoopsNested) { + // Input IR: + // for (int m = 0; m < 20; m++) { + // A[m] = 0; + // for (int j = 0; j < 100; j++) { + // A[m] = A[m] + m * j; + // } + // } + // for (int n = 0; n < 20; n++) { + // B[n] = A[n]; + // for (int k = 0; k < 50; k++) { + // B[n] = B[n] + n * k; + // } + // } + BufHandle a_buf("A", {20, 100}, kInt); + BufHandle b_buf("B", {20, 100}, kInt); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto initA = Store::make(a_buf, {m}, 0); + auto forJ = For::make( + j, + 0, + 100, + Store::make( + a_buf, {m}, Add::make(Load::make(a_buf, {m}), Mul::make(m, j)))); + auto initB = Store::make(b_buf, {n}, Load::make(a_buf, {n})); + auto forK = For::make( + k, + 0, + 50, + Store::make( + b_buf, {n}, Add::make(Load::make(b_buf, {n}), Mul::make(n, k)))); + auto forM = For::make(m, 0, 20, Block::make({initA, forJ})); + auto forN = For::make(n, 0, 20, Block::make({initB, forK})); + auto par = Block::make({forM, forN}); + ForPtr fused_loop; + ASSERT_TRUE(LoopNest::fuseLoops({forM, forN}, &fused_loop)); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int m +# CHECK-NEXT: A[m] = 0 +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[m] = +# CHECK: B[m] = A[m] +# CHECK-NEXT: for (int k +# CHECK-NEXT: B[m] = +# CHECK-NOT: for ( + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The fused loop must be the same as the first loop. + ASSERT_EQ(fused_loop, forM); +} + +TEST(LoopNest, fuseLoopsNested2D) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // for (int j = 0; j < 100; j++) { + // A[i,j] = i * j * 500; + // } + // } + // for (int m = 0; m < 20; m++) { + // for (int n = 0; n < 50; n++) { + // B[m,n] = m + n * 100; + // } + // } + BufHandle a_buf("A", {20, 100}, kInt); + BufHandle b_buf("B", {20, 100}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + auto forI = For::make( + i, + 0, + 20, + For::make( + j, + 0, + 100, + Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)))); + auto forM = For::make( + m, + 0, + 20, + For::make( + n, + 0, + 50, + Store::make(b_buf, {m, n}, Add::make(m, Mul::make(n, 100))))); + auto par = Block::make({forI, forM}); + ForPtr fused_loop; + ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[i, j] = +# CHECK: for (int n +# CHECK-NEXT: B[i, n] = +# CHECK-NOT: for ( + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The fused loop must be the same as the first loop. + ASSERT_EQ(fused_loop, forI); +} + +TEST(LoopNest, fuseLoopsNested2DInner) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // for (int j = 0; j < 100; j++) { + // A[i,j] = i * j * 500; + // } + // for (int n = 0; n < 100; n++) { + // B[i,n] = m + n * 100; + // } + // } + BufHandle a_buf("A", {20, 100}, kInt); + BufHandle b_buf("B", {20, 100}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle n("n", kInt); + auto forJ = For::make( + j, 0, 100, Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500))); + auto forN = For::make( + n, 0, 100, Store::make(b_buf, {i, n}, Add::make(i, Mul::make(n, 100)))); + auto forI = For::make(i, 0, 20, Block::make({forJ, forN})); + ForPtr fused_loop; + ASSERT_TRUE(LoopNest::fuseLoops({forJ, forN}, &fused_loop)); + + std::ostringstream oss; + oss << *forI; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[i, j] = +# CHECK-NEXT: B[i, j] = +# CHECK-NOT: for ( + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The fused loop must be the same as the first loop. + ASSERT_EQ(fused_loop, forJ); +} + +TEST(LoopNest, fuseLoopsDifferentStopBounds) { + // Input IR: + // for (int j = 0; j < 100; j++) { + // A[j] = 10 * j; + // } + // for (int k = 0; k < 50; k++) { + // B[k] = 20 * k; + // } + BufHandle a_buf("A", {100}, kInt); + BufHandle b_buf("B", {100}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); + auto forK = For::make(k, 0, 50, Store::make(b_buf, {j}, Mul::make(20, k))); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + auto par = Block::make({forJ, forK}); + ForPtr fused_loop; + ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); +} + +TEST(LoopNest, fuseLoopsDifferentStartBounds) { + // Input IR: + // for (int j = 0; j < 100; j++) { + // A[j] = 10 * j; + // } + // for (int k = 50; k < 100; k++) { + // B[k] = 20 * k; + // } + BufHandle a_buf("A", {100}, kInt); + BufHandle b_buf("B", {100}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); + auto forK = For::make(k, 50, 100, Store::make(b_buf, {j}, Mul::make(20, k))); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + auto par = Block::make({forJ, forK}); + ForPtr fused_loop; + ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); +} + +TEST(LoopNest, fuseLoopsNotContiguous) { + // Input IR: + // for (int j = 0; j < 100; j++) { + // A[j] = 10 * j; + // } + // B[0] = 0; + // for (int k = 0; k < 100; k++) { + // B[k] = 20 * k; + // } + BufHandle a_buf("A", {100}, kInt); + BufHandle b_buf("B", {100}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); + auto initB = Store::make(b_buf, {0}, 0); + auto forK = For::make(k, 0, 100, Store::make(b_buf, {j}, Mul::make(20, k))); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + auto par = Block::make({forJ, initB, forK}); + ForPtr fused_loop; + ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); +} + +TEST(LoopNest, fuseLoopsWithDifferentParents) { + // Input IR: + // for (int i = 0; i < 50; i++) { + // for (int j = 0; j < 100; j++) { + // A[i,j] = i * j; + // } + // } + // B[0] = 0; + // for (int k = 50; k < 100; k++) { + // B[k] = 20 * k; + // } + BufHandle a_buf("A", {50, 100}, kInt); + BufHandle b_buf("B", {100}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forJ = For::make(j, 0, 100, Store::make(a_buf, {i, j}, Mul::make(i, j))); + auto forI = For::make(i, 0, 50, forJ); + auto initB = Store::make(b_buf, {0}, 0); + auto forK = For::make(k, 50, 100, Store::make(b_buf, {j}, Mul::make(20, k))); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + auto par = Block::make({forI, initB, forK}); + ForPtr fused_loop; + ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); +} + +TEST(LoopNest, fuseLoopsWithVariableBounds) { + // Input IR: + // for (int j = 0; j < N; j++) { + // A[j] = 10 * j; + // } + // for (int k = 0; k < N; k++) { + // B[k] = 20 * k; + // } + BufHandle a_buf("A", {20}, kInt); + BufHandle b_buf("B", {20}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + VarHandle N("N", kInt); + auto forJ = For::make(j, 0, N, Store::make(a_buf, {j}, Mul::make(10, j))); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers) + auto forK = For::make(k, 0, N, Store::make(b_buf, {j}, Mul::make(20, k))); + auto par = Block::make({forJ, forK}); + ForPtr fused_loop; + ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int j +# CHECK-NEXT: A[j] = +# CHECK-NEXT: B[j] = +# CHECK-NOT: for ( + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The fused loop must be the same as the first loop. + ASSERT_EQ(fused_loop, forJ); +} + +TEST(LoopNest, fuseLoopsWithExprBounds) { + // Input IR: + // for (int j = 0; j < M + N; j++) { + // A[j] = 10 * j; + // } + // for (int k = 0; k < M + N; k++) { + // B[k] = 20 * k; + // } + BufHandle a_buf("A", {20}, kInt); + BufHandle b_buf("B", {20}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + VarHandle M("M", kInt); + VarHandle N("N", kInt); + auto forJ = For::make(j, 0, M + N, Store::make(a_buf, {j}, Mul::make(10, j))); + auto forK = For::make(k, 0, M + N, Store::make(b_buf, {j}, Mul::make(20, k))); + auto par = Block::make({forJ, forK}); + ForPtr fused_loop; + ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int j +# CHECK-NEXT: A[j] = +# CHECK-NEXT: B[j] = +# CHECK-NOT: for ( + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The fused loop must be the same as the first loop. + ASSERT_EQ(fused_loop, forJ); +} + +TEST(LoopNest, fuseLoopsWithDifferentExprBounds) { + // Input IR: + // for (int j = M; j < N * 2; j++) { + // A[j] = 10 * j; + // } + // for (int k = M; k < N + N; k++) { + // B[k] = 20 * k; + // } + BufHandle a_buf("A", {20}, kInt); + BufHandle b_buf("B", {20}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + VarHandle M("M", kInt); + VarHandle N("N", kInt); + auto forJ = For::make(j, M, N * 2, Store::make(a_buf, {j}, Mul::make(10, j))); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers) + auto forK = For::make(k, M, N + N, Store::make(b_buf, {j}, Mul::make(20, k))); + auto par = Block::make({forJ, forK}); + ForPtr fused_loop; + ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int j +# CHECK-NEXT: A[j] = +# CHECK-NEXT: B[j] = +# CHECK-NOT: for ( + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The fused loop must be the same as the first loop. + ASSERT_EQ(fused_loop, forJ); +} + +TEST(LoopNest, fuseLoopsWithNonOverlappingBufferAccesses) { + // Input IR: + // for (int j = 10; j < 100; j++) { + // A[j] = 10 * j; + // } + // for (int k = 10; k < 100; k++) { + // A[k+100] = 30 * k + // } + BufHandle a_buf("A", {200}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); + auto forK = + For::make(k, 10, 100, Store::make(a_buf, {k + 100}, Mul::make(30, k))); + auto par = Block::make({forJ, forK}); + + ForPtr fused_loop; + ASSERT_TRUE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int j +# CHECK-NEXT: A[j] = +# CHECK-NEXT: A[j + 100] = +# CHECK-NOT: for ( + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The fused loop must be the same as the first loop. + ASSERT_EQ(fused_loop, forJ); +} + +TEST(LoopNest, fuseLoopsWithNonOverlapping2DBufferAccesses) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // for (int j = 0; j < 100; j++) { + // A[i,j] = i * j * 500; + // } + // } + // for (int m = 0; m < 20; m++) { + // for (int n = 0; n < 50; n++) { + // A[m+20,n+100] = m + n * 100; + // } + // } + BufHandle a_buf("A", {20, 100}, kInt); + BufHandle b_buf("B", {20, 50}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + auto storeA1 = Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)); + auto forJ = For::make(j, 0, 100, storeA1); + auto forI = For::make(i, 0, 20, forJ); + auto storeA2 = + Store::make(a_buf, {m + 20, n + 100}, Add::make(m, Mul::make(n, 100))); + auto forN = For::make(n, 0, 50, storeA2); + auto forM = For::make(m, 0, 20, forN); + auto par = Block::make({forI, forM}); + + ForPtr fused_loop; + ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[i, j] = +# CHECK: for (int n +# CHECK-NEXT: A[i + 20, n + 100] = +# CHECK-NOT: for ( + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The fused loop must be the same as the first loop. + ASSERT_EQ(fused_loop, forI); +} + +TEST(LoopNest, fuseLoopsWithReductions) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // A[i] = 0 + // for (int j = 0; j < 100; j++) { + // A[i] = A[i] + B[i,j]; + // } + // } + // for (int m = 0; m < 20; m++) { + // C[m] = A[m]; + // } + BufHandle a_buf("A", {20}, kInt); + BufHandle b_buf("B", {20, 100}, kInt); + BufHandle c_buf("C", {20}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle m("m", kInt); + auto initA = Store::make(a_buf, {i}, 0); + auto sumA = Store::make( + a_buf, {i}, Add::make(Load::make(a_buf, {i}), Load::make(b_buf, {i, j}))); + auto forJ = For::make(j, 0, 100, sumA); + auto forI = For::make(i, 0, 20, Block::make({initA, forJ})); + auto forM = + For::make(m, 0, 20, Store::make(c_buf, {m}, Load::make(a_buf, {m}))); + auto par = Block::make({forI, forM}); + ForPtr fused_loop; + ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: A[i] = +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[i] = (A[i]) + +# CHECK-NOT: for ( +# CHECK: C[i] = A[i] + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The fused loop must be the same as the first loop. + ASSERT_EQ(fused_loop, forI); +} + +TEST(LoopNest, fuseLoopsWith2DReductions) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // for (int j = 0; j < 50; j++) { + // A[i,j] = 0 + // for (int k = 0; k < 100; k++) { + // A[i,j] = A[i,j] + B[i,j,k]; + // } + // } + // } + // for (int m = 0; m < 20; m++) { + // for (int n = 0; n < 40; n++) { + // C[m,n] = A[m,n]; + // } + // } + BufHandle a_buf("A", {20, 50}, kInt); + BufHandle b_buf("B", {20, 50, 100}, kInt); + BufHandle c_buf("C", {20, 40}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + auto initA = Store::make(a_buf, {i, j}, 0); + auto sumA = Store::make( + a_buf, + {i, j}, + Add::make(Load::make(a_buf, {i, j}), Load::make(b_buf, {i, j, k}))); + auto forK = For::make(k, 0, 100, sumA); + auto forJ = For::make(j, 0, 50, Block::make({initA, forK})); + auto forI = For::make(i, 0, 20, forJ); + auto storeC = Store::make(c_buf, {m, n}, Load::make(a_buf, {m, n})); + auto forM = For::make(m, 0, 20, For::make(n, 0, 40, storeC)); + auto par = Block::make({forI, forM}); + + ForPtr fused_loop; + ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[i, j] = +# CHECK-NEXT: for (int k +# CHECK-NEXT: A[i, j] = (A[i, j]) + +# CHECK: for (int n +# CHECK-NEXT: C[i, n] = A[i, n] +# CHECK-NOT: for ( + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The fused loop must be the same as the first loop. + ASSERT_EQ(fused_loop, forI); +} + +TEST(LoopNest, fuseLoopsWithComplexIndices) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // for (int j = 0; j < 20; j++) { + // A[i,j*20+j+2] = i + j; + // } + // } + // for (int m = 0; m < 20; m++) { + // for (int n = 0; n < 20; n++) { + // B[m,n] = A[m,n*20+n+2]; + // } + // } + BufHandle a_buf("A", {20, 400}, kInt); + BufHandle b_buf("B", {20, 400}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + auto writeA = Store::make(a_buf, {i, j * 20 + j + 2}, i + j); + auto forI = For::make(i, 0, 20, For::make(j, 0, 20, writeA)); + auto storeB = + Store::make(b_buf, {m, n}, Load::make(a_buf, {m, n * 20 + n + 2})); + auto forM = For::make(m, 0, 20, For::make(n, 0, 20, storeB)); + auto par = Block::make({forI, forM}); + + ForPtr fused_loop; + ASSERT_TRUE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[i, (j * 20 + j) + 2] = i + j +# CHECK: for (int n +# CHECK-NEXT: B[i, n] = A[i, (n * 20 + n) + 2] +# CHECK-NOT: for ( + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + // The fused loop must be the same as the first loop. + ASSERT_EQ(fused_loop, forI); +} + +TEST(LoopNest, fuseLoopsWithMixedLoopVarsAsIndices) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // for (int j = 0; j < 20; j++) { + // A[i,i*20+j] = i + j; + // } + // } + // for (int m = 0; m < 20; m++) { + // for (int n = 0; n < 20; n++) { + // B[m,n] = A[m,m*20+n]; // Both indices of A use m + // } + // } + BufHandle a_buf("A", {20, 500}, kInt); + BufHandle b_buf("B", {20, 500}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + auto writeA = Store::make(a_buf, {i, i * 20 + j}, i + j); + auto forI = For::make(i, 0, 20, For::make(j, 0, 20, writeA)); + auto storeB = Store::make(b_buf, {m, n}, Load::make(a_buf, {m, m * 20 + n})); + auto forM = For::make(m, 0, 20, For::make(n, 0, 20, storeB)); + auto par = Block::make({forI, forM}); + + ForPtr fused_loop; + ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); +} + +TEST(LoopNest, fuseLoopsWithTranspose) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // for (int j = 0; j < 20; j++) { + // A[i,j] = i + j; + // } + // } + // for (int m = 0; m < 20; m++) { + // for (int n = 0; n < 20; n++) { + // B[m,n] = A[n,m]; // Transpose + // } + // } + BufHandle a_buf("A", {20, 20}, kInt); + BufHandle b_buf("B", {20, 20}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + auto writeA = Store::make(a_buf, {i, j}, i + j); + auto forI = For::make(i, 0, 20, For::make(j, 0, 20, writeA)); + auto storeB = Store::make(b_buf, {m, n}, Load::make(a_buf, {n, m})); + auto forM = For::make(m, 0, 20, For::make(n, 0, 20, storeB)); + auto par = Block::make({forI, forM}); + + ForPtr fused_loop; + ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); +} + +TEST(LoopNest, fuseLoopsThatViolateDependencies1) { + // Input IR: + // for (int j = 10; j < 100; j++) { + // A[j] = 10 * j; + // } + // for (int k = 10; k < 100; k++) { + // A[k-1] = 20 * k; + // } + BufHandle a_buf("A", {100}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); + auto forK = + For::make(k, 10, 100, Store::make(a_buf, {k - 1}, Mul::make(20, k))); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + auto par = Block::make({forJ, forK}); + ForPtr fused_loop; + ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); +} + +TEST(LoopNest, fuseLoopsThatViolateDependencies2) { + // Input IR: + // for (int j = 10; j < 100; j++) { + // A[j] = 10 * j; + // } + // for (int k = 10; k < 100; k++) { + // A[k+50] = 20 * k; + // } + BufHandle a_buf("A", {150}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forJ = For::make(j, 10, 100, Store::make(a_buf, {j}, Mul::make(10, j))); + auto forK = + For::make(k, 10, 100, Store::make(a_buf, {k + 50}, Mul::make(20, k))); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + auto par = Block::make({forJ, forK}); + ForPtr fused_loop; + ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); +} + +TEST(LoopNest, fuseLoopsThatViolateDependencies3) { + // Input IR: + // for (int m = 0; m < 20; m++) { + // A[m] = 0; + // for (int j = 0; j < 100; j++) { + // A[m] = A[m] + m * j; + // } + // } + // for (int n = 0; n < 20; n++) { + // B[n] = A[n+1]; + // for (int k = 0; k < 50; k++) { + // B[n] = B[n] + n * k; + // } + // } + BufHandle a_buf("A", {25, 100}, kInt); + BufHandle b_buf("B", {20, 50}, kInt); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto initA = Store::make(a_buf, {m}, 0); + auto forJ = For::make( + j, + 0, + 100, + Store::make( + a_buf, {m}, Add::make(Load::make(a_buf, {m}), Mul::make(m, j)))); + auto initB = Store::make(b_buf, {n}, Load::make(a_buf, {n + 1})); + auto forK = For::make( + k, + 0, + 50, + Store::make( + b_buf, {n}, Add::make(Load::make(b_buf, {n}), Mul::make(n, k)))); + auto forM = For::make(m, 0, 20, Block::make({initA, forJ})); + auto forN = For::make(n, 0, 20, Block::make({initB, forK})); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + auto par = Block::make({forM, forN}); + ForPtr fused_loop; + ASSERT_FALSE(LoopNest::fuseLoops({forM, forN}, &fused_loop)); +} + +TEST(LoopNest, fuseLoopsThatViolateDependencies4) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // for (int j = 0; j < 100; j++) { + // A[i,j] = i * j * 500; + // } + // } + // for (int m = 0; m < 20; m++) { + // for (int n = 0; n < 50; n++) { + // A[m+1,n] = m + n * 100; + // } + // } + BufHandle a_buf("A", {30, 100}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + auto forI = For::make( + i, + 0, + 20, + For::make( + j, + 0, + 100, + Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500)))); + auto forM = For::make( + m, + 0, + 20, + For::make( + n, + 0, + 50, + Store::make(a_buf, {m + 1, n}, Add::make(m, Mul::make(n, 100))))); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + auto par = Block::make({forI, forM}); + ForPtr fused_loop; + ASSERT_FALSE(LoopNest::fuseLoops({forI, forM}, &fused_loop)); +} + +TEST(LoopNest, fuseLoopsThatViolateDependencies5) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // for (int j = 0; j < 100; j++) { + // A[i,j] = i * j * 500; + // } + // for (int n = 0; n < 100; n++) { + // A[i,n+1] = m + n * 100; + // } + // } + BufHandle a_buf("A", {20, 200}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle n("n", kInt); + auto forJ = For::make( + j, 0, 100, Store::make(a_buf, {i, j}, Mul::make(Mul::make(i, j), 500))); + auto forN = For::make( + n, + 0, + 100, + Store::make(a_buf, {i, n + 1}, Add::make(i, Mul::make(n, 100)))); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores,cppcoreguidelines-avoid-magic-numbers) + auto forI = For::make(i, 0, 20, Block::make({forJ, forN})); + ForPtr fused_loop; + ASSERT_FALSE(LoopNest::fuseLoops({forJ, forN}, &fused_loop)); +} + +TEST(LoopNest, fuseLoopsThatViolateDependencies6) { + // Input IR: + // for (int j = 0; j < 100; j++) { + // A[j] = 10 * j; + // } + // for (int k = 0; k < 100; k++) { + // B[k] = 20 * A[99-k]; + // } + BufHandle a_buf("A", {100}, kInt); + BufHandle b_buf("B", {100}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); + auto forK = For::make( + k, + 0, + 100, + Store::make( + b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k})))); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + auto par = Block::make({forJ, forK}); + ForPtr fused_loop; + ASSERT_FALSE(LoopNest::fuseLoops({forJ, forK}, &fused_loop)); +} + +TEST(LoopNest, fuseLoopsThatViolateDependencies7) { + // Input IR: + // for (int k = 0; k < 100; k++) { + // B[k] = 20 * A[99-k]; + // } + // for (int j = 0; j < 100; j++) { + // A[j] = 10 * j; + // } + BufHandle a_buf("A", {100}, kInt); + BufHandle b_buf("B", {100}, kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto forK = For::make( + k, + 0, + 100, + Store::make( + b_buf, {k}, Mul::make(20, Load::make(a_buf, {ExprHandle(99) - k})))); + auto forJ = For::make(j, 0, 100, Store::make(a_buf, {j}, Mul::make(10, j))); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + auto par = Block::make({forK, forJ}); + ForPtr fused_loop; + ASSERT_FALSE(LoopNest::fuseLoops({forK, forJ}, &fused_loop)); +} + +TEST(LoopNest, areLoopsPerfectlyNested) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // for (int j = 0; j < 30; j++) { + // for (int k = 0; k < 40; k++) { + // A[i,j,k] = i * j * k; + // } + // } + // } + BufHandle a_buf("A", {20, 30, 40}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k)); + auto forK = For::make(k, 0, 40, store); + auto forJ = For::make(j, 0, 30, forK); + auto forI = For::make(i, 0, 20, forJ); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + auto par = Block::make({forI}); + ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK})); + + // Specifying the loops in any other order fails. + ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forJ, forI, forK})); + ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forI, forK, forJ})); + ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forK, forJ, forI})); + + // Adding a statement to forK body should be OK. + auto init = Store::make(a_buf, {i, j}, 0); + forK->body()->insert_stmt_before(init, store); + ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK})); + + // Adding a statement in forJ body should fail this test. + forK->body()->remove_stmt(init); + forJ->body()->insert_stmt_before(init, forK); + ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK})); + + // Similarly, adding a statement in forI body should fail this test. + forJ->body()->remove_stmt(init); + forI->body()->insert_stmt_before(init, forJ); + ASSERT_FALSE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK})); +} + +TEST(LoopNest, reorderNestedLoops2D) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // for (int j = 0; j < 30; j++) { + // A[i,j] = i * j; + // } + // } + BufHandle a_buf("A", {20, 30, 40}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto store = Store::make(a_buf, {i, j}, Mul::make(i, j)); + auto forJ = For::make(j, 0, 30, store); + auto forI = For::make(i, 0, 20, forJ); + auto par = Block::make({forI}); + + auto reordered = LoopNest::reorder({forI, forJ}, {1, 0}); + + ASSERT_EQ(reordered[0], forJ); + ASSERT_EQ(reordered[1], forI); + ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forJ, forI})); + ASSERT_EQ(forJ->get_parent(), par); + ASSERT_EQ(store->get_parent(), forI->body()); +} + +TEST(LoopNest, reorderNestedLoops3D) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // for (int j = 0; j < 30; j++) { + // for (int k = 0; k < 40; k++) { + // A[i,j,k] = i * j * k; + // } + // } + // } + BufHandle a_buf("A", {20, 30, 40}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k)); + auto forK = For::make(k, 0, 40, store); + auto forJ = For::make(j, 0, 30, forK); + auto forI = For::make(i, 0, 20, forJ); + auto par = Block::make({forI}); + + auto reordered = LoopNest::reorder({forI, forJ, forK}, {2, 0, 1}); + + ASSERT_EQ(reordered[0], forK); + ASSERT_EQ(reordered[1], forI); + ASSERT_EQ(reordered[2], forJ); + ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forK, forI, forJ})); + ASSERT_EQ(forK->get_parent(), par); + ASSERT_EQ(store->get_parent(), forJ->body()); +} + +TEST(LoopNest, reorderNestedLoops4D) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // for (int j = 0; j < 30; j++) { + // for (int k = 0; k < 40; k++) { + // for (int l = 0; l < 50; l++) { + // A[i,j,k,l] = i * j * k * l * 500; + // } + // } + // } + // } + BufHandle a_buf("A", {20, 30, 40, 50}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + VarHandle l("l", kInt); + auto store = Store::make( + a_buf, + {i, j, k, l}, + Mul::make(Mul::make(Mul::make(Mul::make(i, j), k), l), 500)); + auto forL = For::make(l, 0, 50, store); + auto forK = For::make(k, 0, 40, forL); + auto forJ = For::make(j, 0, 30, forK); + auto forI = For::make(i, 0, 20, forJ); + auto par = Block::make({forI}); + + auto reordered = LoopNest::reorder({forI, forJ, forK, forL}, {2, 0, 3, 1}); + + ASSERT_EQ(reordered[0], forK); + ASSERT_EQ(reordered[1], forI); + ASSERT_EQ(reordered[2], forL); + ASSERT_EQ(reordered[3], forJ); + ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forK, forI, forL, forJ})); + ASSERT_EQ(forK->get_parent(), par); + ASSERT_EQ(store->get_parent(), forJ->body()); +} + +TEST(LoopNest, reorderTrivialPermutation) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // for (int j = 0; j < 30; j++) { + // for (int k = 0; k < 40; k++) { + // A[i,j,k] = i * j * k; + // } + // } + // } + BufHandle a_buf("A", {20, 30, 40}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k)); + auto forK = For::make(k, 0, 40, store); + auto forJ = For::make(j, 0, 30, forK); + auto forI = For::make(i, 0, 20, forJ); + auto par = Block::make({forI}); + + auto reordered = LoopNest::reorder({forI, forJ, forK}, {0, 1, 2}); + + ASSERT_EQ(reordered[0], forI); + ASSERT_EQ(reordered[1], forJ); + ASSERT_EQ(reordered[2], forK); + ASSERT_TRUE(LoopNest::areLoopsPerfectlyNested({forI, forJ, forK})); + ASSERT_EQ(forI->get_parent(), par); + ASSERT_EQ(store->get_parent(), forK->body()); +} + +TEST(LoopNest, reorderInvalidPermutations) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // for (int j = 0; j < 30; j++) { + // for (int k = 0; k < 40; k++) { + // A[i,j,k] = i * j * k; + // } + // } + // } + BufHandle a_buf("A", {20, 30, 40}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k)); + auto forK = For::make(k, 0, 40, store); + auto forJ = For::make(j, 0, 30, forK); + auto forI = For::make(i, 0, 20, forJ); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + auto par = Block::make({forI}); + + ASSERT_THROWS_WITH( + LoopNest::reorder({forI, forJ, forK}, {0, 1, 2, 3}), + "invalid permutation size"); + ASSERT_THROWS_WITH( + LoopNest::reorder({forI, forJ, forK}, {1, 2}), + "invalid permutation size"); + ASSERT_THROWS_WITH( + LoopNest::reorder({forI, forJ, forK}, {2, 1, 3}), + "invalid permutation for reorder"); + ASSERT_THROWS_WITH( + LoopNest::reorder({forI, forJ, forK}, {1, 1, 0}), + "invalid permutation for reorder"); + ASSERT_THROWS_WITH( + LoopNest::reorder({forI, forJ, forK}, {0, 0, 0}), + "invalid permutation for reorder"); +} + +TEST(LoopNest, reorderInvalidLoopNest) { + // Input IR: + // for (int i = 0; i < 20; i++) { + // for (int j = 0; j < 30; j++) { + // A[i,j] = 0 + // for (int k = 0; k < 40; k++) { + // A[i,j,k] = i * j * k; + // } + // } + // } + BufHandle a_buf("A", {20, 30, 40}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto store = Store::make(a_buf, {i, j, k}, Mul::make(Mul::make(i, j), k)); + auto forK = For::make(k, 0, 40, store); + auto forJ = For::make(j, 0, 30, forK); + auto forI = For::make(i, 0, 20, forJ); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + auto par = Block::make({forI}); + + // Specifying the loops in incorrect order fails. + ASSERT_THROWS_WITH( + LoopNest::reorder({forK, forI, forJ}, {1, 0, 2}), + "reorder is only allowed on perfectly nested loops"); + + // Adding a statement to forJ loop fails. + auto init = Store::make(a_buf, {i}, 0); + forJ->body()->insert_stmt_before(init, forK); + ASSERT_THROWS_WITH( + LoopNest::reorder({forI, forJ, forK}, {1, 0, 2}), + "reorder is only allowed on perfectly nested loops"); + + // Moving that statement to forI loop also fails. + forJ->body()->remove_stmt(init); + forI->body()->insert_stmt_before(init, forJ); + ASSERT_THROWS_WITH( + LoopNest::reorder({forI, forJ, forK}, {1, 0, 2}), + "reorder is only allowed on perfectly nested loops"); +} + +TEST(LoopNest, compressBufferSimple) { + // Input IR: + // for (int i = 0; i < 100; ++i) { + // for (int j = 0; j < 200; ++j) { + // A[i,j] = sin(i*j) + // } + // for (int j = 0; j < 199; ++j) { + // B[i,j] = A[i,j] + A[i, j+1] + // } + // } + BufHandle aBuf("A", {100, 200}, kInt); + BufHandle bBuf("B", {100, 200}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {i, j}, sin(i * j))); + auto forJ2 = For::make( + j, + 0, + 199, + Store::make( + bBuf, + {i, j}, + Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j + 1})))); + auto forI = For::make(i, 0, 100, Block::make({forJ1, forJ2})); + auto par = Block::make({forI}); + LoopNest::compressBuffer(aBuf.node(), par); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[0, j] = +# CHECK: for (int j +# CHECK-NEXT: B[i, j] = (A[0, j]) + (A[0, j + 1]) + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + ASSERT_EQ(aBuf.node()->ndim(), 2); + IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1); + IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200); +} + +TEST(LoopNest, compressBufferMultipleDims) { + // Input IR: + // for (int i = 0; i < 100; ++i) { + // for (int j = 0; j < 200; ++j) { + // A[i,j] = sin(i*j) + // B[i,j] = A[i,j] + A[i,j] + // } + // } + BufHandle aBuf("A", {100, 200}, kInt); + BufHandle bBuf("B", {100, 200}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto store1 = Store::make(aBuf, {i, j}, sin(i * j)); + auto store2 = Store::make( + bBuf, + {i, j}, + Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j}))); + auto forJ = For::make(j, 0, 200, Block::make({store1, store2})); + auto forI = For::make(i, 0, 100, forJ); + auto par = Block::make({forI}); + LoopNest::compressBuffer(aBuf.node(), par); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[0, 0] = +# CHECK-NEXT: B[i, j] = (A[0, 0]) + (A[0, 0]) + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + ASSERT_EQ(aBuf.node()->ndim(), 2); + IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1); + IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 1); +} + +TEST(LoopNest, compressBufferMultipleDims2) { + // Input IR: + // for (int i = 0; i < 100; ++i) { + // for (int j = 0; j < 200; ++j) { + // for (int k = 0; k < 300; ++k) { + // A[i,j,k] = sin(i*j*k) + // } + // for (int k = 0; k < 299; ++j) { + // B[i,j,k] = A[i,j,k] + A[i,j,k+1] + // } + // } + // } + BufHandle aBuf("A", {100, 200, 300}, kInt); + BufHandle bBuf("B", {100, 200, 300}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + auto store1 = Store::make(aBuf, {i, j, k}, sin(i * j * k)); + auto forK1 = For::make(k, 0, 300, store1); + auto store2 = Store::make( + bBuf, + {i, j, k}, + Add::make(Load::make(aBuf, {i, j, k}), Load::make(aBuf, {i, j, k + 1}))); + auto forK2 = For::make(k, 0, 299, store2); + auto forJ = For::make(j, 0, 200, Block::make({forK1, forK2})); + auto forI = For::make(i, 0, 100, forJ); + auto par = Block::make({forI}); + LoopNest::compressBuffer(aBuf.node(), par); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: for (int j +# CHECK-NEXT: for (int k +# CHECK-NEXT: A[0, 0, k] = +# CHECK: for (int k +# CHECK-NEXT: B[i, j, k] = (A[0, 0, k]) + (A[0, 0, k + 1]) + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + ASSERT_EQ(aBuf.node()->ndim(), 3); + IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1); + IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 1); + IS_IMM_WITH_VAL(Int, aBuf.node()->dim(2), 300); +} + +TEST(LoopNest, compressBufferDifferentOrderIndices) { + // Input IR: + // for (int i = 0; i < 100; ++i) { + // for (int j = 0; j < 200; ++j) { + // A[j, i] = sin(i*j) + // } + // for (int j = 0; j < 99; ++j) { + // B[i, j] = A[j, i] + A[j+1, 0] + // } + // } + BufHandle aBuf("A", {100, 200}, kInt); + BufHandle bBuf("B", {100, 200}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {j, i}, sin(i * j))); + auto forJ2 = For::make( + j, + 0, + 99, + Store::make( + bBuf, + {i, j}, + Add::make(Load::make(aBuf, {j, i}), Load::make(aBuf, {j + 1, i})))); + auto forI = For::make(i, 0, 100, Block::make({forJ1, forJ2})); + auto par = Block::make({forI}); + LoopNest::compressBuffer(aBuf.node(), par); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[j, 0] = +# CHECK: for (int j +# CHECK-NEXT: B[i, j] = (A[j, 0]) + (A[j + 1, 0]) + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + ASSERT_EQ(aBuf.node()->ndim(), 2); + IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 100); + IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 1); +} + +TEST(LoopNest, compressBufferVariableBounds) { + // Input IR: + // for (int i = 0; i < M; ++i) { + // for (int j = 0; j < N; ++j) { + // A[i,j] = sin(i*j) + // } + // for (int j = 0; j < N-1; ++j) { + // B[i,j] = A[i,j] + A[i, j+1] + // } + // } + BufHandle aBuf("A", {100, 200}, kInt); + BufHandle bBuf("B", {100, 200}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle M("M", kInt); + VarHandle N("N", kInt); + auto forJ1 = For::make(j, 0, N, Store::make(aBuf, {i, j}, sin(i * j))); + auto forJ2 = For::make( + j, + 0, + N - 1, + Store::make( + bBuf, + {i, j}, + Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j + 1})))); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + auto forI = For::make(i, 0, M, Block::make({forJ1, forJ2})); + auto par = Block::make({forI}); + LoopNest::compressBuffer(aBuf.node(), par); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[0, j] = +# CHECK: for (int j +# CHECK-NEXT: B[i, j] = (A[0, j]) + (A[0, j + 1]) + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + ASSERT_EQ(aBuf.node()->ndim(), 2); + IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1); + IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200); +} + +TEST(LoopNest, compressBufferNoCommonParentLoops) { + // Input IR: + // for (int i = 0; i < 100; ++i) { + // for (int j = 0; j < 200; ++j) { + // A[i,j] = sin(i*j) + // } + // } + // for (int i = 0; i < 100; ++i) { + // for (int j = 0; j < 199; ++j) { + // B[i,j] = A[i,j] + A[i, j+1] + // } + // } + BufHandle aBuf("A", {100, 200}, kInt); + BufHandle bBuf("B", {100, 200}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {i, j}, sin(i * j))); + auto forJ2 = For::make( + j, + 0, + 199, + Store::make( + bBuf, + {i, j}, + Add::make(Load::make(aBuf, {i, j}), Load::make(aBuf, {i, j + 1})))); + auto forI1 = For::make(i, 0, 100, forJ1); + auto forI2 = For::make(i, 0, 100, forJ2); + auto par = Block::make({forI1, forI2}); + LoopNest::compressBuffer(aBuf.node(), par); + + // There should be no change in the buffer or code. + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[i, j] = +# CHECK: for (int i +# CHECK-NEXT: for (int j +# CHECK-NEXT: B[i, j] = (A[i, j]) + (A[i, j + 1]) + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + ASSERT_EQ(aBuf.node()->ndim(), 2); + IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 100); + IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200); +} + +TEST(LoopNest, compressBufferIndicesMixed) { + // Input IR: + // for (int i = 0; i < 100; ++i) { + // for (int j = 0; j < 200; ++j) { + // A[i + j, j] = sin(i*j) + // } + // for (int j = 0; j < 199; ++j) { + // B[i,j] = A[i + j, j] + A[i + j, j+1] + // } + // } + BufHandle aBuf("A", {300, 200}, kInt); + BufHandle bBuf("B", {100, 200}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto forJ1 = For::make(j, 0, 200, Store::make(aBuf, {i + j, j}, sin(i * j))); + auto forJ2 = For::make( + j, + 0, + 199, + Store::make( + bBuf, + {i, j}, + Add::make( + Load::make(aBuf, {i + j, j}), Load::make(aBuf, {i + j, j + 1})))); + auto forI = For::make(i, 0, 100, Block::make({forJ1, forJ2})); + auto par = Block::make({forI}); + LoopNest::compressBuffer(aBuf.node(), par); + + // There should be no change in the buffer or code. + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[i + j, j] = +# CHECK: for (int j +# CHECK-NEXT: B[i, j] = (A[i + j, j]) + (A[i + j, j + 1]) + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + ASSERT_EQ(aBuf.node()->ndim(), 2); + IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 300); + IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200); +} + +TEST(LoopNest, compressMultipleBuffers) { + // Input IR: + // for (int i = 0; i < 100; ++i) { + // for (int j = 0; j < 200; ++j) { + // A[i,j] = sin(i*j) + // } + // for (int k = 0; k < 199; ++k) { + // B[i,k] = A[i,k] + A[i, k+1] + // } + // for (int m = 0; m < 50; ++m) { + // C[i,m] = B[i,m] + // } + // } + BufHandle aBuf("A", {100, 200}, kInt); + BufHandle bBuf("B", {100, 200}, kInt); + BufHandle cBuf("C", {100, 200}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + VarHandle k("k", kInt); + VarHandle m("m", kInt); + auto forJ = For::make(j, 0, 200, Store::make(aBuf, {i, j}, sin(i * j))); + auto forK = For::make( + k, + 0, + 199, + Store::make( + bBuf, + {i, k}, + Add::make(Load::make(aBuf, {i, k}), Load::make(aBuf, {i, k + 1})))); + auto forM = + For::make(m, 0, 50, Store::make(cBuf, {i, m}, Load::make(bBuf, {i, m}))); + auto forI = For::make(i, 0, 100, Block::make({forJ, forK, forM})); + auto par = Block::make({forI}); + + // This should compress all buffers A, B, and C as follows: + // A[100, 200] -> A[1, 200] + // B[100, 200] -> B[1, 200] + // C[100, 200] -> C[1, 1] + LoopNest::compressAllBuffers(par); + + std::ostringstream oss; + oss << *par; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: for (int j +# CHECK-NEXT: A[0, j] = +# CHECK: for (int k +# CHECK-NEXT: B[0, k] = (A[0, k]) + (A[0, k + 1]) +# CHECK: for (int m +# CHECK-NEXT: C[0, 0] = B[0, m] + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); + + ASSERT_EQ(aBuf.node()->ndim(), 2); + IS_IMM_WITH_VAL(Int, aBuf.node()->dim(0), 1); + IS_IMM_WITH_VAL(Int, aBuf.node()->dim(1), 200); + ASSERT_EQ(bBuf.node()->ndim(), 2); + IS_IMM_WITH_VAL(Int, bBuf.node()->dim(0), 1); + IS_IMM_WITH_VAL(Int, bBuf.node()->dim(1), 200); + ASSERT_EQ(cBuf.node()->ndim(), 2); + IS_IMM_WITH_VAL(Int, cBuf.node()->dim(0), 1); + IS_IMM_WITH_VAL(Int, cBuf.node()->dim(1), 1); +} + +TEST(LoopNest, sanitizeNames) { + std::vector dim_args; + // Let's pick names that would overlap with default index names if not + // sanitized properly: + dim_args.emplace_back(ExprHandle(alloc("i", kInt))); + dim_args.emplace_back(ExprHandle(alloc("N:2", kInt))); + // Now let's create a many dimensions so that we had to use the same letter + // for different loops + for (int i = 0; i < 10; i++) { + dim_args.emplace_back(ExprHandle(alloc("N", kInt))); + } + + // Now create two Computes with conflicting after sanitization names: + Tensor X = Compute("$X:!", dim_args, [&](const std::vector& v) { + return v[0] + v[1] + v[9] + 1; + }); + Tensor Y = Reduce( + "%X\"+", + {}, + Sum(), + [&](const std::vector& v) { return X.load(v); }, + dim_args); + + // Finally, let's verify what we got after sanitization: + LoopNest l({X, Y}); + StmtPtr s = l.root_stmt(); + LoopNest::sanitizeNames(s); + + std::ostringstream oss; + oss << *s; + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i = 0; i < i_1; i++) { +# CHECK-NEXT: for (int j = 0; j < N_2_1; j++) { +# CHECK-NEXT: for (int k = 0; k < N_9; k++) { +# CHECK-NEXT: for (int l = 0; l < N_8; l++) { +# CHECK-NEXT: for (int m = 0; m < N_7; m++) { +# CHECK-NEXT: for (int n = 0; n < N_6; n++) { +# CHECK-NEXT: for (int o = 0; o < N_5; o++) { +# CHECK-NEXT: for (int p = 0; p < N_4; p++) { +# CHECK-NEXT: for (int i1 = 0; i1 < N_3; i1++) { +# CHECK-NEXT: for (int j1 = 0; j1 < N_2; j1++) { +# CHECK-NEXT: for (int k1 = 0; k1 < N_1; k1++) { +# CHECK-NEXT: for (int l1 = 0; l1 < N; l1++) { +# CHECK-NEXT: v_X__[i, j, k, l, m, n, o, p, i1, j1, k1, l1] = ((i + j) + j1) + 1; +# CHECK: v_X___1 = int(0); +# CHECK-NEXT: for (int i_2 = 0; i_2 < i_1; i_2++) { +# CHECK-NEXT: for (int j_1 = 0; j_1 < N_2_1; j_1++) { +# CHECK-NEXT: for (int k_1 = 0; k_1 < N_9; k_1++) { +# CHECK-NEXT: for (int l_1 = 0; l_1 < N_8; l_1++) { +# CHECK-NEXT: for (int m_1 = 0; m_1 < N_7; m_1++) { +# CHECK-NEXT: for (int n_1 = 0; n_1 < N_6; n_1++) { +# CHECK-NEXT: for (int o_1 = 0; o_1 < N_5; o_1++) { +# CHECK-NEXT: for (int p_1 = 0; p_1 < N_4; p_1++) { +# CHECK-NEXT: for (int i1_1 = 0; i1_1 < N_3; i1_1++) { +# CHECK-NEXT: for (int j1_1 = 0; j1_1 < N_2; j1_1++) { +# CHECK-NEXT: for (int k1_1 = 0; k1_1 < N_1; k1_1++) { +# CHECK-NEXT: for (int l1_1 = 0; l1_1 < N; l1_1++) { +# CHECK-NEXT: v_X___1 = ReduceOp((v_X___1) + (v_X__[i_2, j_1, k_1, l_1, m_1, n_1, o_1, p_1, i1_1, j1_1, k1_1, l1_1]), reduce_args={i_2, j_1, k_1, l_1, m_1, n_1, o_1, p_1, i1_1, j1_1, k1_1, l1_1}); + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_memdependency.cpp b/test/cpp/tensorexpr/test_memdependency.cpp new file mode 100644 index 0000000000000..cac7283f2bebe --- /dev/null +++ b/test/cpp/tensorexpr/test_memdependency.cpp @@ -0,0 +1,3252 @@ +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +using namespace torch::jit::tensorexpr; + +// Test helper function used to determine if two regions of a buffer have an +// overlap. No Overlap & partial overlap is obvious. Contains means A is +// larger and fully encloses B, while ContainedOrEqual is the reverse. Equal +// ranges are ContainedOrEqual. +TEST(MemDependency, BoundOverlap) { + using namespace analysis; + + auto CB = [](int s, int e) { + return Bound(alloc(s), alloc(e)); + }; + + // Sanity check 3 overlap cases. + ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(0, 0), CB(0, 0))); + ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(0, 3), CB(2, 5))); + ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(0, 0), CB(1, 1))); + + // Partial overlap works in either order. + ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(0, 10), CB(7, 14))); + ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(7, 14), CB(0, 10))); + + // Total Overlap works when one bound encloses the other, and returns which. + ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(7, 9))); + ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 15), CB(0, 16))); + + // Total overlap works when the bounds are an identical range, returns + // ContainedOrEqual. + ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 15), CB(2, 15))); + + // Total overlap when only one end of the bound matches. + ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(2, 10))); + ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(3, 15))); + ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(0, 10), CB(0, 9))); + ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 10), CB(2, 15))); + ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(3, 15), CB(2, 15))); + + // No overlap when a < b. + ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(0, 2), CB(5, 10))); + ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(2, 2), CB(3, 3))); + ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(100, 120), CB(130, 130))); + + // No overlap when a > b. + ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(5, 10), CB(0, 2))); + ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(3, 3), CB(2, 2))); + ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(130, 130), CB(100, 120))); + + // No overlap when adjacent. + ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(0, 100), CB(101, 120))); + ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(2, 3), CB(0, 1))); + + // Partial overlap when middle bounds match. + ASSERT_EQ( + OverlapKind::PartialOverlap, boundOverlap(CB(0, 100), CB(100, 120))); + ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(0, 2), CB(2, 4))); + ASSERT_EQ( + OverlapKind::PartialOverlap, boundOverlap(CB(100, 120), CB(0, 100))); + ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(2, 3), CB(1, 2))); + + // Total overlap when one bound is single length over one end of the other. + ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(15, 15))); + ASSERT_EQ(OverlapKind::Contains, boundOverlap(CB(2, 15), CB(2, 2))); + ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(2, 2), CB(2, 15))); + ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(15, 15), CB(2, 15))); +} + +TEST(MemDependency, BoundComparison) { + using namespace analysis; + + auto CB = [](int s, int e) { + return Bound(alloc(s), alloc(e)); + }; + + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kEQ)); + ASSERT_EQ( + CmpEvalResult::True, + compareBound(CB(10, 10), CB(10, 10), CompareSelectOperation::kEQ)); + ASSERT_EQ( + CmpEvalResult::False, + compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kEQ)); + ASSERT_EQ( + CmpEvalResult::False, + compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kEQ)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kEQ)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(30, 40), CB(20, 30), CompareSelectOperation::kEQ)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kEQ)); + + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kNE)); + ASSERT_EQ( + CmpEvalResult::False, + compareBound(CB(10, 10), CB(10, 10), CompareSelectOperation::kNE)); + ASSERT_EQ( + CmpEvalResult::True, + compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kNE)); + ASSERT_EQ( + CmpEvalResult::True, + compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kNE)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kNE)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(30, 40), CB(20, 30), CompareSelectOperation::kEQ)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kNE)); + + ASSERT_EQ( + CmpEvalResult::True, + compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kLT)); + ASSERT_EQ( + CmpEvalResult::False, + compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kLT)); + ASSERT_EQ( + CmpEvalResult::False, + compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kLT)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kLT)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kLT)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kLT)); + + ASSERT_EQ( + CmpEvalResult::False, + compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kGE)); + ASSERT_EQ( + CmpEvalResult::True, + compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kGE)); + ASSERT_EQ( + CmpEvalResult::True, + compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kGE)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kGE)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kGE)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kGE)); + + ASSERT_EQ( + CmpEvalResult::False, + compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kGT)); + ASSERT_EQ( + CmpEvalResult::False, + compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kGT)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kGT)); + ASSERT_EQ( + CmpEvalResult::True, + compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kGT)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kGT)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kGT)); + + ASSERT_EQ( + CmpEvalResult::True, + compareBound(CB(10, 20), CB(30, 40), CompareSelectOperation::kLE)); + ASSERT_EQ( + CmpEvalResult::True, + compareBound(CB(30, 40), CB(40, 50), CompareSelectOperation::kLE)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(10, 100), CB(10, 100), CompareSelectOperation::kLE)); + ASSERT_EQ( + CmpEvalResult::False, + compareBound(CB(30, 40), CB(10, 20), CompareSelectOperation::kLE)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(30, 40), CB(10, 30), CompareSelectOperation::kLE)); + ASSERT_EQ( + CmpEvalResult::NotDetermined, + compareBound(CB(30, 45), CB(40, 50), CompareSelectOperation::kLE)); +} + +TEST(MemDependency, BoundOverlapSymbolic) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + VarHandle w("w", kInt); + + using namespace analysis; + + auto CB = [](ExprHandle s, ExprHandle e) { + return Bound(s.node(), e.node()); + }; + + // Sanity check cases where the start and end is symbolic but the diff is + // constant. + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + ASSERT_EQ(OverlapKind::ContainedOrEqual, boundOverlap(CB(x, x), CB(x, x))); + ASSERT_EQ( + OverlapKind::PartialOverlap, + boundOverlap(CB(x, x + 3), CB(x + 2, x + 5))); + ASSERT_EQ(OverlapKind::NoOverlap, boundOverlap(CB(x, x), CB(x + 1, x + 1))); + + // We can't infer the sign of y, so cannot tell whether adding y is larger or + // smaller than y/2. + ASSERT_EQ( + OverlapKind::PartialOverlap, + boundOverlap(CB(x, x + y), CB(x, x + y / 2))); + + // No information about this bound, have to take the most conservative option: + // there may be an overlap. + ASSERT_EQ(OverlapKind::PartialOverlap, boundOverlap(CB(x, y), CB(z, w))); + + // Math on opaque terms works. + ASSERT_EQ( + OverlapKind::ContainedOrEqual, + boundOverlap(CB(x + w, y - z), CB(x + w, y - z))); + // Even requiring simplification. + ASSERT_EQ( + OverlapKind::ContainedOrEqual, + boundOverlap(CB(x - w - w, y), CB(x - w * 2, y))); +} + +// Tests the helper function for overlap of multi dimensional indices bounds. +// This uses boundOverlap on each dimension and return the "lowest" kind of +// overlap. +TEST(MemDependency, BoundOverlapMultiDim) { + using namespace analysis; + + auto CB = [](int s, int e) { + return Bound(alloc(s), alloc(e)); + }; + + // Sanity check one dimensional cases. + ASSERT_EQ(OverlapKind::ContainedOrEqual, overlaps({CB(0, 0)}, {CB(0, 0)})); + ASSERT_EQ(OverlapKind::NoOverlap, overlaps({CB(0, 2)}, {CB(5, 10)})); + ASSERT_EQ( + OverlapKind::PartialOverlap, overlaps({CB(0, 100)}, {CB(100, 120)})); + + // Total overlap in 3 dims. + ASSERT_EQ( + OverlapKind::ContainedOrEqual, + overlaps({CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(0, 4)})); + ASSERT_EQ( + OverlapKind::ContainedOrEqual, + overlaps( + {CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(0, 10)})); + + // Total overlap in 2 dims, no overlap in another. + ASSERT_EQ( + OverlapKind::NoOverlap, + overlaps( + {CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 5), CB(5, 10)})); + + // Total overlap in 2 dims, partial overlap in another. + ASSERT_EQ( + OverlapKind::PartialOverlap, + overlaps( + {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 2), CB(0, 5), CB(5, 10)})); + // This case is most important, so verify the overlap in any dim. (dim 2) + ASSERT_EQ( + OverlapKind::PartialOverlap, + overlaps({CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 2), CB(2, 6), CB(0, 5)})); + // Dim 1. + ASSERT_EQ( + OverlapKind::PartialOverlap, + overlaps({CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(1, 3), CB(0, 5), CB(0, 5)})); + // Total overlap in 1 dim, partial in 2. + ASSERT_EQ( + OverlapKind::PartialOverlap, + overlaps( + {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(2, 6), CB(0, 5), CB(5, 10)})); + // Total overlap, partial overlap, no overlap. + ASSERT_EQ( + OverlapKind::NoOverlap, + overlaps( + {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(2, 6), CB(11, 15), CB(0, 5)})); + + // Total overlap (B) in 2 dims, total overlap (A) in another. + ASSERT_EQ( + OverlapKind::Contains, + overlaps({CB(0, 2), CB(0, 5), CB(0, 4)}, {CB(0, 2), CB(0, 3), CB(0, 4)})); + + // Total overlap (A) in 2 dims, total overlap (B) in another. + ASSERT_EQ( + OverlapKind::Contains, + overlaps( + {CB(0, 12), CB(0, 15), CB(0, 4)}, {CB(0, 2), CB(0, 3), CB(0, 14)})); + + // Total (B), No Overlap, Total (A). + ASSERT_EQ( + OverlapKind::NoOverlap, + overlaps( + {CB(0, 2), CB(0, 5), CB(0, 5)}, {CB(0, 6), CB(11, 15), CB(1, 2)})); +} + +// Test the helper we use to subtract bounds: returns the regions(s) of A which +// remain after removing the region of B. +TEST(MemDependency, BoundSubtract) { + using namespace analysis; + + auto CB = [](int s, int e) { + return Bound(alloc(s), alloc(e)); + }; + auto EQ = [](const IndexBounds& x, const IndexBounds& y) { + return indexBoundsEquals(x, y); + }; + + // One element subtract. + ASSERT_EQ(subtractBound(CB(0, 0), CB(0, 0)).size(), 0); + ASSERT_EQ(subtractBound(CB(5, 5), CB(5, 5)).size(), 0); + + // No Overlap. + ASSERT_TRUE(EQ(subtractBound(CB(5, 5), CB(2, 2)), {CB(5, 5)})); + ASSERT_TRUE(EQ(subtractBound(CB(5, 5), CB(0, 4)), {CB(5, 5)})); + + // one side overlap. + ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(4, 7)), {CB(1, 3)})); + ASSERT_TRUE(EQ(subtractBound(CB(0, 5), CB(5, 7)), {CB(0, 4)})); + ASSERT_TRUE(EQ(subtractBound(CB(4, 5), CB(1, 4)), {CB(5, 5)})); + ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(0, 4)), {CB(5, 5)})); + + // both sides overlap. + ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(0, 7)), {})); + ASSERT_TRUE(EQ(subtractBound(CB(5, 5), CB(5, 7)), {})); + + // internal overlap. + ASSERT_TRUE(EQ(subtractBound(CB(1, 5), CB(2, 3)), {CB(1, 1), CB(4, 5)})); + ASSERT_TRUE(EQ(subtractBound(CB(0, 5), CB(2, 4)), {CB(0, 1), CB(5, 5)})); +} + +TEST(MemDependency, BoundSubtractSymbolic) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + VarHandle w("w", kInt); + + using namespace analysis; + + auto CB = [](ExprHandle s, ExprHandle e) { + return Bound(s.node(), e.node()); + }; + auto EQ = [](const IndexBounds& x, const IndexBounds& y) { + return indexBoundsEquals(x, y); + }; + + // One element subtract. + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + ASSERT_TRUE(EQ(subtractBound(CB(x, x), CB(x, x)), {})); + ASSERT_TRUE(EQ(subtractBound(CB(x + 1, x + 1), CB(x + 1, x + 1)), {})); + ASSERT_TRUE(EQ(subtractBound(CB(x * 2, x * 2), CB(x * 2, x * 2)), {})); + + // Subtract constant range low. + ASSERT_TRUE( + EQ(subtractBound(CB(x, x + 10), CB(x, x + 4)), {CB(x + 5, x + 10)})); + // Subtract constant range high. + ASSERT_TRUE( + EQ(subtractBound(CB(x, x + 10), CB(x + 6, x + 12)), {CB(x, x + 5)})); + // Subtract constant range total overlap. + ASSERT_TRUE(EQ(subtractBound(CB(x, x + 10), CB(x, x + 10)), {})); + ASSERT_TRUE(EQ(subtractBound(CB(x + 2, x + 10), CB(x, x + 12)), {})); + // Subtract constant range internal. + ASSERT_TRUE( + EQ(subtractBound(CB(x, x + 10), CB(x + 3, x + 7)), + {CB(x, x + 2), CB(x + 8, x + 10)})); + + // Size is inferable but not constant, only works with a single var. + ASSERT_TRUE(EQ(subtractBound(CB(0, x), CB(0, x * 2)), {})); + ASSERT_TRUE(EQ(subtractBound(CB(0, x * 2), CB(0, x - 1)), {CB(x, x * 2)})); + + // Size is not inferable. + ASSERT_TRUE(EQ(subtractBound(CB(x, y), CB(z, w)), {CB(x, y)})); + ASSERT_TRUE(EQ(subtractBound(CB(x, y), CB(x, z)), {CB(x, y)})); + ASSERT_TRUE(EQ(subtractBound(CB(x, y), CB(0, x)), {CB(x, y)})); + ASSERT_TRUE(EQ(subtractBound(CB(x, x), CB(0, 0)), {CB(x, x)})); +} + +// Tests the helper function that does subtraction, but for multi dimensional +// indices bounds. +TEST(MemDependency, BoundSubtractMultiDim) { + using namespace analysis; + + auto CB = [](int s, int e) { + return Bound(alloc(s), alloc(e)); + }; + auto EQ = [](std::vector x, std::vector y) { + if (x.size() != y.size()) { + return false; + } + for (auto i = 0U; i < x.size(); ++i) { + if (!indexBoundsEquals(x[i], y[i])) { + return false; + } + } + return true; + }; + + // sanity check one dimension. + ASSERT_TRUE(EQ(subtractIndicesBounds({CB(0, 9)}, {CB(0, 9)}), {})); + ASSERT_TRUE(EQ(subtractIndicesBounds({CB(3, 9)}, {CB(0, 12)}), {})); + ASSERT_TRUE( + EQ(subtractIndicesBounds({CB(0, 12)}, {CB(0, 9)}), {{CB(10, 12)}})); + ASSERT_TRUE( + EQ(subtractIndicesBounds({CB(0, 12)}, {CB(3, 12)}), {{CB(0, 2)}})); + ASSERT_TRUE(EQ( + subtractIndicesBounds({CB(0, 9)}, {CB(1, 8)}), {{CB(0, 0)}, {CB(9, 9)}})); + + // Multi dim total overlap. + ASSERT_TRUE(EQ( + subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 9), CB(0, 2)}), {})); + ASSERT_TRUE(EQ( + subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 10), CB(0, 20)}), {})); + + // Mutli dim one way partial in dim 1. + ASSERT_TRUE( + EQ(subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 3), CB(0, 2)}), + {{CB(4, 9), CB(0, 2)}})); + + // Mutli dim one way partial in dim 2. + ASSERT_TRUE( + EQ(subtractIndicesBounds({CB(0, 9), CB(0, 20)}, {CB(0, 9), CB(0, 10)}), + {{CB(0, 9), CB(11, 20)}})); + + // Partial overlap in 2 dims. + ASSERT_TRUE( + EQ(subtractIndicesBounds({CB(0, 5), CB(0, 5)}, {CB(2, 8), CB(2, 8)}), + {{CB(0, 1), CB(0, 5)}, {CB(2, 5), CB(0, 1)}})); + + // Partial overlap in 3 dims. + ASSERT_TRUE( + EQ(subtractIndicesBounds( + {CB(0, 5), CB(0, 5), CB(0, 5)}, {CB(2, 8), CB(2, 8), CB(2, 8)}), + {{CB(0, 1), CB(0, 5), CB(0, 5)}, + {CB(2, 5), CB(0, 1), CB(0, 5)}, + {CB(2, 5), CB(2, 5), CB(0, 1)}})); +} + +// Tests the multi dimensional subtraction code for bounds that cannot be fully +// materialized. +TEST(MemDependency, BoundSubtractMultiDimSymbolic) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + using namespace analysis; + + auto CB = [](ExprHandle s, ExprHandle e) { + return Bound(s.node(), e.node()); + }; + + auto EQ = [](std::vector x, std::vector y) { + if (x.size() != y.size()) { + return false; + } + for (auto i = 0U; i < x.size(); ++i) { + if (!indexBoundsEquals(x[i], y[i])) { + return false; + } + } + return true; + }; + + // Cannot determine overlaps. + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + ASSERT_TRUE(EQ(subtractIndicesBounds({CB(x, x)}, {CB(0, 0)}), {{CB(x, x)}})); + + // Various total Overlaps. + ASSERT_TRUE(EQ( + subtractIndicesBounds({CB(x, x), CB(x, x)}, {CB(x, x), CB(x, x)}), {})); + ASSERT_TRUE(EQ( + subtractIndicesBounds({CB(x, y), CB(x, y)}, {CB(x, y), CB(x, y)}), {})); + ASSERT_TRUE(EQ( + subtractIndicesBounds({CB(x, x), CB(y, y)}, {CB(x, x), CB(y, y)}), {})); + ASSERT_TRUE(EQ( + subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(0, y)}), {})); + + // one-way overlap in first dim. + ASSERT_TRUE( + EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x - 5), CB(0, y)}), + {{CB(x - 4, x), CB(0, y)}})); + // second dim. + ASSERT_TRUE( + EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(5, y)}), + {{CB(0, x), CB(0, 4)}})); + + // Internal overlap in first dim. + ASSERT_TRUE( + EQ(subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(2, x - 5), CB(0, y)}), + {{CB(0, 1), CB(0, y)}, {CB(x - 4, x), CB(0, y)}})); + // second dim. + ASSERT_TRUE(EQ( + subtractIndicesBounds({CB(0, x), CB(0, y)}, {CB(0, x), CB(10, y - 10)}), + {{CB(0, x), CB(0, 9)}, {CB(0, x), CB(y - 9, y)}})); + + // Overlap in both dimensions. + ASSERT_TRUE( + EQ(subtractIndicesBounds( + {CB(0, x), CB(0, y)}, {CB(5, x - 5), CB(10, y - 10)}), + { + {CB(0, 4), CB(0, y)}, + {CB(x - 4, x), CB(0, y)}, + {CB(0, x), CB(0, 9)}, + {CB(0, x), CB(y - 9, y)}, + })); +} + +// Simple check that the analyzer does anything at all... +TEST(MemDependency, MemDependencyCheckerSimple) { + BufHandle a("A", {1}, kInt); + BufHandle b("B", {1}, kInt); + + analysis::MemDependencyChecker analyzer; + + /* + * A[0] = 3; + * B[0] = A[0] + 1; + */ + + StorePtr aStore = Store::make(a, {0}, 3); + StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {0}), 1)); + + StmtPtr stmt = Block::make({aStore, bStore}); + + stmt->accept(&analyzer); + + ASSERT_TRUE(analyzer.dependsDirectly(bStore, aStore)); + ASSERT_FALSE(analyzer.dependsIndirectly(aStore, bStore)); + // sanity check, but anything that depends directly must depend indirectly. + ASSERT_TRUE(analyzer.dependsIndirectly(bStore, aStore)); +} + +// Check that there is a difference between direct and indirect dependence. +TEST(MemDependency, MemDependencyCheckerMultiStmt) { + BufHandle a("A", {1}, kInt); + BufHandle b("B", {1}, kInt); + BufHandle c("C", {1}, kInt); + + analysis::MemDependencyChecker analyzer; + + /* + * A[0] = 3; + * B[0] = A[0]; + * C[0] = B[0] + 1; + */ + + StorePtr aStore = Store::make(a, {0}, 3); + StorePtr bStore = Store::make(b, {0}, Load::make(a, {0})); + StorePtr cStore = Store::make(c, {0}, Add::make(Load::make(b, {0}), 1)); + + StmtPtr stmt = Block::make({aStore, bStore, cStore}); + + stmt->accept(&analyzer); + + // C depends on A indirectly. + ASSERT_FALSE(analyzer.dependsDirectly(cStore, aStore)); + ASSERT_TRUE(analyzer.dependsIndirectly(cStore, aStore)); + + // C depends on B directly, which depends on A directly. + ASSERT_TRUE(analyzer.dependsDirectly(cStore, bStore)); + ASSERT_TRUE(analyzer.dependsDirectly(bStore, aStore)); + + // Dependency goes top to bottom only. + ASSERT_FALSE(analyzer.dependsIndirectly(bStore, cStore)); + ASSERT_FALSE(analyzer.dependsIndirectly(aStore, bStore)); + ASSERT_FALSE(analyzer.dependsIndirectly(aStore, cStore)); +} + +// Verify that we do filter writes that are totally overlapped by later writes. +TEST(MemDependency, MemDependencyCheckerOverlap) { + BufHandle a("A", {1}, kInt); + BufHandle b("B", {1}, kInt); + + analysis::MemDependencyChecker analyzer; + + /* + * A[0] = 3; + * A[0] = 6; + * B[0] = A[0] + 1; + */ + + StorePtr aStore = Store::make(a, {0}, 3); + StorePtr a2Store = Store::make(a, {0}, 6); + StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {0}), 1)); + + StmtPtr stmt = Block::make({aStore, a2Store, bStore}); + + stmt->accept(&analyzer); + + // B store depends on second A store but not first since it is completely + // overlapped. + ASSERT_TRUE(analyzer.dependsIndirectly(bStore, a2Store)); + ASSERT_FALSE(analyzer.dependsIndirectly(bStore, aStore)); + + // No dependency between either A store. + ASSERT_FALSE(analyzer.dependsIndirectly(aStore, a2Store)); + ASSERT_FALSE(analyzer.dependsIndirectly(a2Store, aStore)); +} + +// Verify that bounds match loop iterations, and that dependencies progress +// across loop scopes. +TEST(MemDependency, MemDependencyCheckerLoop) { + BufHandle a("A", {1}, kInt); + BufHandle b("B", {1}, kInt); + VarHandle x("x", kInt); + + using namespace analysis; + + MemDependencyChecker analyzer; + + /* + * for (int x = 0; x < 10; ++x) { + * A[x] = x; + * } + * B[0] = A[0] + 1; + */ + + StorePtr aStore = Store::make(a, {x}, x); + StmtPtr loop = For::make(x, 0, 10, aStore); + StorePtr bStore = Store::make(b, {0}, Add::make(Load::make(a, {4}), 1)); + + StmtPtr stmt = Block::make({loop, bStore}); + + stmt->accept(&analyzer); + + // Same A->B dependency. + ASSERT_TRUE(analyzer.dependsDirectly(bStore, aStore)); + + // B depends on the loop. + ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop)); + // A is in the loop but does not depend on any loop iteration. + ASSERT_FALSE(analyzer.dependsIndirectly(aStore, loop)); + + auto aStoreAccess = analyzer.accessFor(aStore); + ASSERT_NE(aStoreAccess, nullptr); + + // It should have bounds covering the range of x: 0 <= x < 10. + ASSERT_TRUE(indexBoundsEquals( + aStoreAccess->bounds(), {Bound(alloc(0), alloc(9))})); +} + +// Reductions should promote dependencies as well. +TEST(MemDependency, MemDependencyCheckerLoopReduce) { + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + + using namespace analysis; + + MemDependencyChecker analyzer; + + /* + * A[0] = 0; + * for (int x = 0; x < 10; ++x) { + * A[0] = A[x] + 1; + * } + * B[0] = A[0]; + */ + + StorePtr aInit = Store::make(a, {0}, 0); + ExprHandle reduce = Sum()(a, 1, {x}, {x}); + StorePtr aReduce = Store::make(a, {0}, reduce); + StmtPtr loop = For::make(x, 0, 10, aReduce); + StorePtr bStore = Store::make(b, {0}, Load::make(a, {0})); + + StmtPtr stmt = Block::make({aInit, loop, bStore}); + + stmt->accept(&analyzer); + + // B -> A. + ASSERT_TRUE(analyzer.dependsDirectly(bStore, aReduce)); + + // B depends indirectly on the initializer of A, since the reduction depends + // on it. + ASSERT_FALSE(analyzer.dependsDirectly(bStore, aInit)); + ASSERT_TRUE(analyzer.dependsIndirectly(bStore, aInit)); + + ASSERT_TRUE(analyzer.dependsDirectly(aReduce, aInit)); + + // B depends on the loop. + ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop)); + // A is in the loop and depends on other iterations. + ASSERT_TRUE(analyzer.dependsDirectly(aReduce, loop)); + + // The loop contents depend on the initializer too. + ASSERT_TRUE(analyzer.dependsDirectly(loop, aInit)); + + // Find loads within the reduction: + auto reduceLoads = NodeFinder::find(reduce.node()); + // Pull out the access for the load inside the loop. + for (auto load : reduceLoads) { + auto loopLoad = analyzer.accessFor(load); + // It should have 10 element long bounds. + ASSERT_TRUE(indexBoundsEquals( + loopLoad->bounds(), {Bound(alloc(0), alloc(9))})); + } +} + +// Lowering a reduction doesn't affect dependency analysis. +TEST(MemDependency, MemDependencyCheckerLoopReduceExpanded) { + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + + using namespace analysis; + + MemDependencyChecker analyzer; + + /* + * A[0] = 0; + * for (int x = 0; x < 10; ++x) { + * A[0] = A[x] + 1; + * } + * B[0] = A[0]; + */ + + StorePtr aInit = Store::make(a, {0}, 0); + ExprHandle aLoad = Load::make(a, {x}); + StorePtr aReduce = Store::make(a, {0}, Add::make(aLoad, 1)); + StmtPtr loop = For::make(x, 0, 10, aReduce); + StorePtr bStore = Store::make(b, {0}, Load::make(a, {0})); + + StmtPtr stmt = Block::make({aInit, loop, bStore}); + + stmt->accept(&analyzer); + + // B -> A. + ASSERT_TRUE(analyzer.dependsDirectly(bStore, aReduce)); + + // B depends indirectly on the initializer of A, since the reduction depends + // on it. + ASSERT_FALSE(analyzer.dependsDirectly(bStore, aInit)); + ASSERT_TRUE(analyzer.dependsIndirectly(bStore, aInit)); + + ASSERT_TRUE(analyzer.dependsDirectly(aReduce, aInit)); + + // B depends on the loop. + ASSERT_TRUE(analyzer.dependsDirectly(bStore, loop)); + // A is in the loop and depends on other iterations. + ASSERT_TRUE(analyzer.dependsDirectly(aReduce, loop)); + + // The loop contents depend on the initializer too. + ASSERT_TRUE(analyzer.dependsDirectly(loop, aInit)); + + // Pull out the access for the store inside the loop. + auto loopLoad = analyzer.accessFor(aLoad.node()); + // It should have 10 element long bounds. + ASSERT_TRUE(indexBoundsEquals( + loopLoad->bounds(), {Bound(alloc(0), alloc(9))})); +} + +// Can determine dependencies of outputs, through to inputs. +TEST(MemDependency, MemDependencyCheckerInputsOutputs) { + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + + // initialize analyzer with inputs and outputs. + analysis::MemDependencyChecker analyzer({a}, {b}); + + // Here's a Relu. + /* + * for (int x = 0; x < 10; ++x) { + * B[x] = Max(A[x], 0); + * } + */ + + ExprHandle aLoad = Load::make(a, {x}); + StorePtr bStore = Store::make(b, {x}, Max::make(aLoad, 0, true)); + StmtPtr loop = For::make(x, 0, 10, bStore); + + StmtPtr stmt = Block::make({loop}); + + stmt->accept(&analyzer); + + // Output depends indirectly on input. + ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); + // aLoad depends directly on the input A. + ASSERT_TRUE(analyzer.dependsDirectly(aLoad.node(), a.node())); + // bStore therefore depends directly on the input A. + ASSERT_TRUE(analyzer.dependsDirectly(bStore, a.node())); + // The output depends directly on the store. + ASSERT_TRUE(analyzer.dependsDirectly(b.node(), bStore)); + + // Check AccessInfo based overloads. + auto input = analyzer.input(a.node()); + auto output = analyzer.output(b.node()); + + // Output depends indirectly on input. + ASSERT_TRUE(analyzer.dependsIndirectly(output, input)); + // Not directly. + ASSERT_FALSE(analyzer.dependsDirectly(output, input)); + // Not in reverse order. + ASSERT_FALSE(analyzer.dependsIndirectly(input, output)); + + // output -> bStore -> bLoad -> input. + auto storeAccess = analyzer.accessFor(bStore); + auto loadAccess = analyzer.accessFor(aLoad.node()); + + ASSERT_TRUE(analyzer.dependsDirectly(output, storeAccess)); + ASSERT_TRUE(analyzer.dependsDirectly(loadAccess, input)); +} + +// Can tell if an output does not depend on an input. +TEST(MemDependency, MemDependencyCheckerOutputDoesntDepend) { + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + + // initialize analyzer with inputs and outputs. + analysis::MemDependencyChecker analyzer({a}, {b}); + + // Here's a dumb Relu. + /* + * for (int x = 0; x < 10; ++x) { + * B[x] = Max(x, 0); + * } + */ + + StorePtr bStore = Store::make(b, {x}, Max::make(x, 0, true)); + StmtPtr loop = For::make(x, 0, 10, bStore); + + StmtPtr stmt = Block::make({loop}); + + stmt->accept(&analyzer); + + // Output does not depend indirectly on input. + ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), a.node())); + + // The output still depends directly on the store. + ASSERT_TRUE(analyzer.dependsDirectly(b.node(), bStore)); + + // Check AccessInfo based overloads. + auto input = analyzer.input(a.node()); + auto output = analyzer.output(b.node()); + + // Output does not depend indirectly on input. + ASSERT_FALSE(analyzer.dependsIndirectly(output, input)); +} + +// Verify different loop extents produce accesses with different bounds, and +// that later accesses find dependencies that overlap their entire bound range. +TEST(MemDependency, MemDependencyCheckerLoopBounds) { + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + BufHandle c("C", {10}, kInt); + VarHandle x("x", kInt); + using namespace analysis; + + MemDependencyChecker analyzer({a}, {c}); + + // This enables using the execution order of the loops to determine if some + // loops are self dependent or not. + analyzer.allowLoopExecutionOrderAnalysis(); + + /* + * for (int x = 1; x < 10; ++x) { + * B[x] = A[x]; + * } + * for (int x = 1; x < 9; ++x) { + * B[x] = B[x] * 2; + * } + * for (int x = 3; x < 4; ++x) { + * C[x] = A[x]; + * } + * for (int x = 0; x < 10; ++x) { + * C[x] = B[x]; + * } + */ + + std::vector stmts( + {For::make(x, 1, 10, Store::make(b, {x}, Load::make(a, {x}))), + For::make( + x, 1, 9, Store::make(b, {x}, Mul::make(Load::make(b, {x}), 2))), + For::make(x, 3, 4, Store::make(c, {x}, Load::make(a, {x}))), + For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x})))}); + + StmtPtr stmt = Block::make(stmts); + + stmt->accept(&analyzer); + + auto input = analyzer.input(a.node()); + auto output = analyzer.output(c.node()); + + // sanity check Output -> Input. + ASSERT_TRUE(analyzer.dependsIndirectly(output, input)); + + // Check the For loop dependencies: + + // Last write to C depends on both writes to B since they contain the last + // write to at least one element. + ASSERT_TRUE(analyzer.dependsIndirectly(stmts[3], stmts[1])); + ASSERT_TRUE(analyzer.dependsIndirectly(stmts[3], stmts[0])); + + // The last write to C does not depend on the other write to C. + ASSERT_FALSE(analyzer.dependsIndirectly(stmts[3], stmts[2])); + + auto CB = [](int s, int e) { + return Bound(alloc(s), alloc(e)); + }; + auto EQ = [](const IndexBounds& x, const IndexBounds& y) { + return indexBoundsEquals(x, y); + }; + + /* 0. Input: A[(0, 9)] - dependents: 1 5 + * 1. Load: A[(1, 9)] - depends on: 0 - dependents: 2 + * 2. Store: B[(1, 9)] - depends on: 1 - dependents: 3 7 + * 3. Load: B[(1, 8)] - depends on: 2 - dependents: 4 + * 4. Store: B[(1, 8)] - depends on: 3 - dependents: 7 + * 5. Load: A[(3, 3)] - depends on: 0 - dependents: 6 + * 6. Store: C[(3, 3)] - depends on: 5 + * 7. Load: B[(0, 9)] - depends on: 2 4 - dependents: 8 + * 8. Store: C[(0, 9)] - depends on: 7 - dependents: 9 + * 9. Output: C[(0, 9)] - depends on: 8 + */ + + // Now let's look at the bounds of each access. + // There are 9 accesses in this Stmt, so this is exhaustive, we wont do this + // much. + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 10); + VarPtr aVar = a.node()->base_handle(); + VarPtr bVar = b.node()->base_handle(); + VarPtr cVar = c.node()->base_handle(); + + // The first access is the input A. + ASSERT_EQ(history[0]->type(), AccessType::Input); + ASSERT_EQ(history[0]->var(), aVar); + // It has the bounds of the producing Input. + ASSERT_TRUE(EQ(history[0]->bounds(), {CB(0, 9)})); + // sanity check the input we retrieved earlier matches. + ASSERT_EQ(history[0], input); + + // The second access is the load of A in the first loop. + ASSERT_EQ(history[1]->type(), AccessType::Load); + ASSERT_EQ(history[1]->var(), aVar); + // It has the bounds of the loop, i.e. start == 1. + ASSERT_TRUE(EQ(history[1]->bounds(), {CB(1, 9)})); + // It reads from A, so it should have a dependency on the last write to this + // range - with is the input. + ASSERT_EQ(history[1]->dependencies().size(), 1); + ASSERT_TRUE(history[1]->hasDependency(history[0])); + + // The third access is the store into B in the first loop. + ASSERT_EQ(history[2]->type(), AccessType::Store); + ASSERT_EQ(history[2]->var(), bVar); + // It also has the bounds of the loop, i.e. start == 1. + ASSERT_TRUE(EQ(history[2]->bounds(), {CB(1, 9)})); + // The previous load is in its RHS, so it depends on it. + ASSERT_EQ(history[2]->dependencies().size(), 1); + ASSERT_TRUE(history[2]->hasDependency(history[1])); + + // The third access is the load from B in the second loop. + ASSERT_EQ(history[3]->type(), AccessType::Load); + ASSERT_EQ(history[3]->var(), bVar); + // It has the bounds of the second loop, i.e. >= 1 < 9. + ASSERT_TRUE(EQ(history[3]->bounds(), {CB(1, 8)})); + // It reads from B in a smaller range, so should depend on the previous + // store. + ASSERT_EQ(history[3]->dependencies().size(), 1); + ASSERT_TRUE(history[3]->hasDependency(history[2])); + + // The fourth: the store to B in the second loop. + ASSERT_EQ(history[4]->type(), AccessType::Store); + ASSERT_EQ(history[4]->var(), bVar); + // It also has the bounds of the second loop. + ASSERT_TRUE(EQ(history[4]->bounds(), {CB(1, 8)})); + // The previous load is in its RHS, so it depends on it as before. + ASSERT_EQ(history[4]->dependencies().size(), 1); + ASSERT_TRUE(history[4]->hasDependency(history[3])); + + // The fifth access is the load is from the 3rd loop, and skips previous B + // accesses. + ASSERT_EQ(history[5]->type(), AccessType::Load); + ASSERT_EQ(history[5]->var(), aVar); + // It has the bounds of the third loop: >= 3 < 4. + ASSERT_TRUE(EQ(history[5]->bounds(), {CB(3, 3)})); + // It depends on the last thing to write to A, which is the A input. + ASSERT_EQ(history[5]->dependencies().size(), 1); + ASSERT_TRUE(history[5]->hasDependency(history[0])); + + // Sixth: the store into the output C. + ASSERT_EQ(history[6]->type(), AccessType::Store); + ASSERT_EQ(history[6]->var(), cVar); + // It also has the bounds of the third loop. + ASSERT_TRUE(EQ(history[6]->bounds(), {CB(3, 3)})); + // The previous load is in its RHS, so it depends on it as always. + ASSERT_EQ(history[6]->dependencies().size(), 1); + ASSERT_TRUE(history[6]->hasDependency(history[5])); + + // The seventh access is the load of B in the fourth loop. + ASSERT_EQ(history[7]->type(), AccessType::Load); + ASSERT_EQ(history[7]->var(), bVar); + // It has the bounds of the final loop, >= 0 < 10 + ASSERT_TRUE(EQ(history[7]->bounds(), {CB(0, 9)})); + // The bounds of this read are larger than the bounds of the previous write, + // so it depends on both previous Stores to B. + ASSERT_EQ(history[7]->dependencies().size(), 2); + ASSERT_TRUE(history[7]->hasDependency(history[2])); + ASSERT_TRUE(history[7]->hasDependency(history[4])); + + // Eight: the final store into the output C. + ASSERT_EQ(history[8]->type(), AccessType::Store); + ASSERT_EQ(history[8]->var(), cVar); + // It also has the bounds of the final loop. + ASSERT_TRUE(EQ(history[8]->bounds(), {CB(0, 9)})); + // The previous load is in its RHS, so it depends on it as always. + ASSERT_EQ(history[8]->dependencies().size(), 1); + ASSERT_TRUE(history[8]->hasDependency(history[7])); + + // The last access represents the output Buf. + ASSERT_EQ(history[9]->type(), AccessType::Output); + ASSERT_EQ(history[9]->var(), cVar); + // It has the bounds of the output Buf. + ASSERT_TRUE(EQ(history[9]->bounds(), {CB(0, 9)})); + // sanity check the input we retrieved earlier matches. + ASSERT_EQ(history[9], output); + // It depends on the last write to C only. + ASSERT_EQ(history[9]->dependencies().size(), 1); + ASSERT_TRUE(history[9]->hasDependency(history[8])); +} + +// Verify that we can still infer bounds when the loop var is offset. +TEST(MemDependency, MemDependencyCheckerLoopBoundsIndexShift) { + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + + using namespace analysis; + + MemDependencyChecker analyzer({a}, {b}); + + // This enables using the execution order of the loops to determine if some + // loops are self dependent or not. + analyzer.allowLoopExecutionOrderAnalysis(); + + /* + * for (int x = 1; x < 10; x++) { + * A[x] = A[x - 1]; + * } + * for (int x = 0; x < 9; x++) { + * A[x] = A[x + 1]; + * } + * for (int x = 0; x < 9; x++) { + * A[9 - x] = A[8 - x]; + * } + * for (int x = 0; x < 10; x++) { + * A[x] = A[9 - x]; + * } + * for (int x = 0; x < 10; x++) { + * B[x] = A[x]; + * } + */ + + StmtPtr stmt = Block::make( + {For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))), + For::make(x, 0, 9, Store::make(a, {x}, Load::make(a, {x + 1}))), + For::make( + x, + 0, + 9, + Store::make( + a, {ExprHandle(9) - x}, Load::make(a, {ExprHandle(8) - x}))), + For::make( + x, 0, 10, Store::make(a, {x}, Load::make(a, {ExprHandle(9) - x}))), + For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x})))}); + + stmt->accept(&analyzer); + + // Sanity check output depends on Input. + ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); + + auto CB = [](int s, int e) { + return Bound(alloc(s), alloc(e)); + }; + auto EQ = [](const IndexBounds& x, const IndexBounds& y) { + return indexBoundsEquals(x, y); + }; + + /* 0. Input: A[(0, 9)] - dependents: 1 + * 1. Load: A[(0, 8)] - depends on: 0 2 - dependents: 2 + * 2. Store: A[(1, 9)] - depends on: 1 - dependents: 1 3 + * 3. Load: A[(1, 9)] - depends on: 2 - dependents: 4 + * 4. Store: A[(0, 8)] - depends on: 3 - dependents: 5 7 + * 5. Load: A[(0, 8)] - depends on: 4 - dependents: 6 + * 6. Store: A[(1, 9)] - depends on: 5 - dependents: 7 + * 7. Load: A[(0, 9)] - depends on: 4 6 8 - dependents: 8 + * 8. Store: A[(0, 9)] - depends on: 7 - dependents: 7 9 + * 9. Load: A[(0, 9)] - depends on: 8 - dependents: 10 + * 10. Store: B[(0, 9)] - depends on: 9 - dependents: 11 + * 11. Output: B[(0, 9)] - depends on: 10 + */ + + // Now let's look at the bounds of each access. + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 12); + VarPtr aVar = a.node()->base_handle(); + VarPtr bVar = b.node()->base_handle(); + + // The first access is the input A. + ASSERT_EQ(history[0]->type(), AccessType::Input); + ASSERT_EQ(history[0]->var(), aVar); + // It has the bounds of the producing Input. + ASSERT_TRUE(EQ(history[0]->bounds(), {CB(0, 9)})); + + // The second access is the load A[x-1]. + ASSERT_EQ(history[1]->type(), AccessType::Load); + ASSERT_EQ(history[1]->var(), aVar); + // It has the bounds of the loop modified by the offset of each index, in + // this case -1. + ASSERT_TRUE(EQ(history[1]->bounds(), {CB(0, 8)})); + // It depends on the input, but also the store in the same loop, since + // different interations of the loop depend on each other. + ASSERT_EQ(history[1]->dependencies().size(), 2); + ASSERT_TRUE(history[1]->hasDependency(history[0])); + ASSERT_TRUE(history[1]->hasDependency(history[2])); + + // The third access is the Store to A[x] in the first loop. + ASSERT_EQ(history[2]->type(), AccessType::Store); + ASSERT_EQ(history[2]->var(), aVar); + // It has no offset on x, so should have the same bounds as the loop. + ASSERT_TRUE(EQ(history[2]->bounds(), {CB(1, 9)})); + + // The fourth access is the load A[x+1] in the second loop. + ASSERT_EQ(history[3]->type(), AccessType::Load); + ASSERT_EQ(history[3]->var(), aVar); + // It has the bounds of the loop (0 <= x < 9) modified by the offset of each + // index, in this case 1. + ASSERT_TRUE(EQ(history[3]->bounds(), {CB(1, 9)})); + // This load totally overlaps the previous write to A, so it depends only on + // it and not the input. + ASSERT_EQ(history[3]->dependencies().size(), 1); + ASSERT_TRUE(history[3]->hasDependency(history[2])); + + // The fifth access is the store to A[x] in the second loop. + ASSERT_EQ(history[4]->type(), AccessType::Store); + ASSERT_EQ(history[4]->var(), aVar); + // It has no offset on x, so should have the same bounds as the loop. + ASSERT_TRUE(EQ(history[4]->bounds(), {CB(0, 8)})); + + // The sixth access is the load to A[8 - x] in the third loop. + ASSERT_EQ(history[5]->type(), AccessType::Load); + ASSERT_EQ(history[5]->var(), aVar); + // It has the bounds of the loop (0 <= x < 9) modified by the offset of each + // index, in this case 8 - x. + // This access has a negative stride, which will be normalized. + ASSERT_TRUE(EQ(history[5]->bounds(), {CB(0, 8)})); + // This load totally overlaps the most recent write to A, so it depends only + // on it and not the input or the first write to A. + ASSERT_EQ(history[5]->dependencies().size(), 1); + ASSERT_TRUE(history[5]->hasDependency(history[4])); + + // The seventh access is the store to A[9 - x] in the third loop. + ASSERT_EQ(history[6]->type(), AccessType::Store); + ASSERT_EQ(history[6]->var(), aVar); + // This store has a negative stride on it's indices, but is normalized + // internally. + ASSERT_TRUE(EQ(history[6]->bounds(), {CB(1, 9)})); + + // The eighth access is the load A[9-x] in the second loop. + ASSERT_EQ(history[7]->type(), AccessType::Load); + ASSERT_EQ(history[7]->var(), aVar); + // It has the bounds of the loop (0 <= x < 9), modified by the offset 9 - x, + // which essentially traverses the loop backwards. + ASSERT_TRUE(EQ(history[7]->bounds(), {CB(0, 9)})); + // This Load has three write dependencies: + ASSERT_EQ(history[7]->dependencies().size(), 3); + // * The previous store (#6) for elements 1-9 + ASSERT_TRUE(history[7]->hasDependency(history[6])); + // * An earlier store (#4) covering element 0 + ASSERT_TRUE(history[7]->hasDependency(history[4])); + // * A future store inside this loop, since this loop modifies the buffer + // in a non distinct way (due to the load and store having different access + // strides). + ASSERT_TRUE(history[7]->hasDependency(history[8])); + + // The ninth access is the store to A[x] in the fourth loop. + ASSERT_EQ(history[8]->type(), AccessType::Store); + ASSERT_EQ(history[8]->var(), aVar); + // This store has a negative stride on it's indices, but is normalized + // internally. + ASSERT_TRUE(EQ(history[8]->bounds(), {CB(0, 9)})); + + // The tenth and 11th accesses are the copy from A[x] to B[x]. + ASSERT_EQ(history[9]->type(), AccessType::Load); + ASSERT_EQ(history[9]->var(), aVar); + ASSERT_EQ(history[10]->type(), AccessType::Store); + ASSERT_EQ(history[10]->var(), bVar); + + // The last access represents the output Buf. + ASSERT_EQ(history[11]->type(), AccessType::Output); + ASSERT_EQ(history[11]->var(), bVar); + // It has the bounds of the output Buf. + ASSERT_TRUE(EQ(history[11]->bounds(), {CB(0, 9)})); + // It depends on the last write to B only. + ASSERT_EQ(history[11]->dependencies().size(), 1); + ASSERT_TRUE(history[11]->hasDependency(history[10])); + + // ok that's enough of that. +} + +// Check many different cases of loop self dependency - when a load within a +// loop is dependent on a Store later in the same loop but in different +// iteration. This is affected by whether or not we can trust the execution +// order of the loop. +TEST(MemDependency, MemDependencyCheckerLoopSelfDependency) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + + using namespace analysis; + + // This check assumes that the Stmt has a single Store with a single Load on + // the RHS. + auto isSelfDependent = + [](const std::vector>& history) -> bool { + return history.front()->hasDependency(history.back()); + }; + + { + /* for (int y = 0; y < 10; y++) { + * A[y] = (A[y]) + 1; + * } */ + + // Not self dependent since all loop iterations use a different y. + + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + y, + 0, + 10, + Block::make({Store::make(a, {y}, Add::make(Load::make(a, {y}), 1))})); + + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int y = 0; y < 10; y++) { + * A[y + 1] = (A[y + 1]) + 1; + * } + */ + + // Not self dependent due to different y (with offset). + + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + y, + 0, + 10, + Block::make( + {Store::make(a, {y + 1}, Add::make(Load::make(a, {y + 1}), 1))})); + + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + x; + * } + */ + + // Is self dependent since all loops use a common constant element of A. + + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + x, + 0, + 10, + Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), x))})); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[0] = (B[0]) + x; + * } + */ + + // Is not self dependent because there is no store to the buffer that is + // read. + + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + x, + 0, + 10, + Block::make({Store::make(a, {0}, Add::make(Load::make(b, {0}), x))})); + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[y] = (A[y]) + x; + * } + */ + + // Is self dependent since all loops use a common symbolic element of A. + + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + x, + 0, + 10, + Block::make({Store::make(a, {y}, Add::make(Load::make(a, {y}), x))})); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x] = A[x + 1]; + * } + */ + + // In this case it depends if we are considering execution order. + + MemDependencyChecker analyzer; + + StmtPtr stmt = + For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 1}))); + stmt->accept(&analyzer); + + // With analysis of order disabled, this is self dependent since the read + // from X+1 and the write to X+1 could be in reverse order. + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x] = A[x + 1]; + * } + */ + + MemDependencyChecker analyzer; + analyzer.allowLoopExecutionOrderAnalysis(); + + StmtPtr stmt = + For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 1}))); + stmt->accept(&analyzer); + + // If order analysis is enabled, this is not dependent since the read for + // each element occurs before the write to that element. + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 1; x < 10; x++) { + * A[x] = A[x - 1]; + * } + */ + + MemDependencyChecker analyzer; + + StmtPtr stmt = + For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 1; x < 10; x++) { + * A[x] = A[x - 1]; + * } + */ + + MemDependencyChecker analyzer; + analyzer.allowLoopExecutionOrderAnalysis(); + + StmtPtr stmt = + For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))); + stmt->accept(&analyzer); + + // In this case, even with order analysis the Load is dependent on the + // Store, since the write to X occurs before the read from X. + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 9; x++) { + * A[9 - x] = A[8 - x]; + * } + */ + + // Still works if the execution order is reversed, so long as the read + // comes before the write. + + MemDependencyChecker analyzer; + analyzer.allowLoopExecutionOrderAnalysis(); + + StmtPtr stmt = For::make( + x, + 3, + 10, + Store::make( + a, {ExprHandle(9) - x}, Load::make(a, {ExprHandle(8) - x}))); + stmt->accept(&analyzer); + + // However here was can determine the A store is earlier in the order than + // the load. + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 9; x++) { + * A[8 - x] = A[9 - x]; + * } + */ + + // But not if it doesn't. + + MemDependencyChecker analyzer; + analyzer.allowLoopExecutionOrderAnalysis(); + + StmtPtr stmt = For::make( + x, + 3, + 10, + Store::make( + a, {ExprHandle(8) - x}, Load::make(a, {ExprHandle(9) - x}))); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 9; x++) { + * A[9 - x] = A[8 - x]; + * } + */ + + // And not if we're not relying on execution order. + + MemDependencyChecker analyzer; + + StmtPtr stmt = For::make( + x, + 3, + 10, + Store::make( + a, {ExprHandle(9) - x}, Load::make(a, {ExprHandle(8) - x}))); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 3; x < 10; x++) { + * A[x - 2] = A[x - 1]; + * } + */ + + // Forward order but negative indices. + + MemDependencyChecker analyzer; + analyzer.allowLoopExecutionOrderAnalysis(); + + StmtPtr stmt = + For::make(x, 3, 10, Store::make(a, {x - 2}, Load::make(a, {x - 1}))); + stmt->accept(&analyzer); + + // However here was can determine the A store is earlier in the order than + // the load. + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[x * 2]; + * } + */ + + // With an access stride. + + MemDependencyChecker analyzer; + // Execution order doesn't matter since the read and the write are totally + // distinct. + + StmtPtr stmt = + For::make(x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2}))); + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[x * 2 + 1]; + * } + */ + + // Here we can use the common stride of the accesses to determine they are + // distinct. + // Note, this is the only place (loop self dependency) we use this stride + // to avoid unnecessary dependence. + + MemDependencyChecker analyzer; + // Execution order doesn't matter since the read and the write are totally + // distinct. + + StmtPtr stmt = For::make( + x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 1}))); + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[x * 2 - 1]; + * } + */ + + // same if the read is behind the write so long as they are distinct. + + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + x, 1, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 - 1}))); + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[x * 2 + 2]; + * } + */ + + // But not if the offset is in the stride. + + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 2}))); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[x * 2 - 2]; + * } + */ + + // Works with negative offsets too. + + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + x, 1, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 - 2}))); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[x * 2 + 7]; + * } + */ + + // Detects accesses are distinct when offset is large but not a multiple + // of stride. + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 7}))); + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[x * 2 + 4]; + * } + */ + + // Works with offsets which are multiples of the stride. + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 2 + 4}))); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 6] = A[x * 6 + 5]; + * } + */ + + // detects accesses are distinct with large strides when the offset is + // within. + + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + x, 0, 10, Store::make(a, {x * 6}, Load::make(a, {x * 6 + 5}))); + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[x * 6]; + * } + */ + + // detects accesses are overlapping when stride is different but a + // multiple. + + MemDependencyChecker analyzer; + StmtPtr stmt = + For::make(x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6}))); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 4] = A[x * 2]; + * } + */ + + // still works when the read axis is the smaller stride. + + MemDependencyChecker analyzer; + StmtPtr stmt = + For::make(x, 0, 10, Store::make(a, {x * 4}, Load::make(a, {x * 2}))); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[x * 6 + 1]; + * } + */ + + // detects accesses are distinct when stride is different but a multiple + // and there is an offset. + + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6 + 1}))); + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[x * 6 + 4]; + * } + */ + + // The smaller stride determines whether there is overlap. + + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 6 + 4}))); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2 + 3] = A[x * 6]; + * } + */ + + // The smaller stride determines whether there is overlap, not the larger. + + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + x, 0, 10, Store::make(a, {x * 2 + 3}, Load::make(a, {x * 6}))); + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[x * 3 + 1]; + * } + */ + + // If they have strides with no common multiple > 1, they overlap. + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + x, 0, 10, Store::make(a, {x * 2}, Load::make(a, {x * 3 + 1}))); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x] = A[x + 10]; + * } + */ + + // If the offset is greater than the size of the loop, they can't overlap. + + MemDependencyChecker analyzer; + StmtPtr stmt = + For::make(x, 0, 10, Store::make(a, {x}, Load::make(a, {x + 10}))); + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x] = A[9 - x]; + * } + */ + + // If they have different execution orders they may overlap. + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + x, 0, 10, Store::make(a, {x}, Load::make(a, {ExprHandle(9) - x}))); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x * 2] = A[19 - x * 2]; + * } + */ + + // Or they may not, depending on their start offset and strides. + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + x, + 0, + 10, + Store::make(a, {x * 2}, Load::make(a, {ExprHandle(19) - x * 2}))); + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x / 2] = A[x / 2]; + * } + */ + + // If the stride is not monotonic, they overlap. + + MemDependencyChecker analyzer; + StmtPtr stmt = + For::make(x, 0, 10, Store::make(a, {x / 2}, Load::make(a, {x / 2}))); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x / 2] = A[x / 2] + 1; + * } + */ + + // If the stride is not monotonic, they overlap - even with an offset. + MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + x, 0, 10, Store::make(a, {x / 2}, Load::make(a, {x / 2 + 1}))); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = 0; x < 10; x++) { + * A[x % 2] = A[x % 2]; + * } + */ + + // Mod too... + + analysis::MemDependencyChecker analyzer; + StmtPtr stmt = For::make( + x, + 0, + 10, + Store::make(a, {Mod::make(x, 2)}, Load::make(a, {Mod::make(x, 2)}))); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + /* for (int x = y; x < z; x++) { + * A[x] = A[x + 1]; + * } + */ + + // Still works with symbolic loop extents. + + { + MemDependencyChecker analyzer; + StmtPtr stmt = + For::make(x, y, z, Store::make(a, {x}, Load::make(a, {x + 1}))); + stmt->accept(&analyzer); + + ASSERT_TRUE(isSelfDependent(analyzer.getHistory())); + } + + { + MemDependencyChecker analyzer; + analyzer.allowLoopExecutionOrderAnalysis(); + StmtPtr stmt = + For::make(x, y, z, Store::make(a, {x}, Load::make(a, {x + 1}))); + stmt->accept(&analyzer); + + ASSERT_FALSE(isSelfDependent(analyzer.getHistory())); + } + } +} + +// Verify that a strided access still works. +// TODO: actually this only works because of the size of the ranges, revisit +// this test after strided overlap is implemented. +TEST(MemDependency, MemDependencyCheckerLoopDistinctStrides) { + BufHandle a("A", {20}, kInt); + BufHandle b("B", {20}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + using namespace analysis; + MemDependencyChecker analyzer({a.node()}, {b.node()}); + StmtPtr stmt = Block::make( + {For::make( + x, 0, 10, Store::make(b, {x * 2 + 1}, Load::make(a, {x * 2 + 1}))), + For::make(x, 0, 10, Store::make(b, {x * 2}, Load::make(a, {x * 2}))) + + }); + stmt->accept(&analyzer); + + // Sanity check output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); + + // Output has 2 dependencies... the store in each loop. + auto outputAccess = analyzer.output(b.node()); + ASSERT_EQ(outputAccess->dependencies().size(), 2); +} + +/* TODO(nickg) - this test will fail due to the lack of stride math in Bound +TEST(MemDependency, MemDependencyCheckerLoopDistinctStrides) { + BufHandle a("A", {20}, kInt); + BufHandle b("B", {20}, kInt); + BufHandle c("C", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + { + analysis::MemDependencyChecker analyzer({a.node()}, {c.node()}); + StmtPtr stmt = Block::make( + {For::make( + x, + 0, + 10, + Store::make(b, {x * 2 + 1}, Load::make(a, {x * 2 + 1}))), + For::make( + x, 0, 10, Store::make(b, {x * 2}, Load::make(a, {x * 2}))), + For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x}))) + + }); + stmt->accept(&analyzer); + + std::cout << *stmt << "\n"; + for (auto& wi : analyzer.getHistory()) { + wi->print(); + } + } +}*/ + +// analysis on Stmts using Cond. +TEST(MemDependency, MemDependencyCheckerLoopBoundsCond) { + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + BufHandle c("C", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + using namespace analysis; + + { + /* for (int x = 0; x < 10; x++) { + * C[x] = A[x]; + * } + * if (y<5 ? 1 : 0) { + * C[0] = (B[0]) + 1; + * } else { + * C[0] = (B[1]) + 1; + * } + */ + + // Future usages may depend on accesses in both branches of a condition. + + MemDependencyChecker analyzer({a, b}, {c}); + StmtPtr stmt = Block::make( + {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), + Cond::make( + CompareSelect::make(y, 5, CompareSelectOperation::kLT), + Store::make(c, {0}, Add::make(Load::make(b, {0}), 1)), + Store::make(c, {0}, Add::make(Load::make(b, {1}), 1)))}); + + stmt->accept(&analyzer); + + // Output C should have 3 dependencies, each of the three stores. + auto outputAccess = analyzer.output(c.node()); + ASSERT_NE(outputAccess, nullptr); + ASSERT_EQ(outputAccess->dependencies().size(), 3); + + // C depends indirectly on A and B. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + } + + { + /* for (int x = 0; x < 10; x++) { + * C[x] = A[x]; + * } + * if (y<5 ? 1 : 0) { + * for (int x = 0; x < 10; x++) { + * C[x] = B[x]; + * } + * } else { + * for (int x = 0; x < 10; x++) { + * C[x] = (B[x]) + 1; + * } + * } + */ + + // Future usages may depend on accesses in both branches of a condition. + + MemDependencyChecker analyzer({a, b}, {c}); + StmtPtr stmt = Block::make( + {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), + Cond::make( + CompareSelect::make(y, 5, CompareSelectOperation::kLT), + For::make(x, 0, 10, Store::make(c, {x}, Load::make(b, {x}))), + For::make( + x, + 0, + 10, + Store::make(c, {x}, Add::make(Load::make(b, {x}), 1))))}); + + stmt->accept(&analyzer); + + // Output C should have 3 dependencies, each of the three stores. + auto outputAccess = analyzer.output(c.node()); + ASSERT_NE(outputAccess, nullptr); + ASSERT_EQ(outputAccess->dependencies().size(), 3); + + // TODO(nickg): actually since the true and false branch cover the total + // range of the first store this should have 2 dependencies, but we don't + // do that yet. + + // C depends indirectly on A and B. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + } + + { + /* for (int x = 0; x < 10; x++) { + * C[x] = A[x]; + * } + * if (y<5 ? 1 : 0) { + * for (int x = 0; x < 10; x++) { + * C[x] = (B[x]) + 1; + * } + * } + */ + + // Only has true branch. + + MemDependencyChecker analyzer({a, b}, {c}); + StmtPtr stmt = Block::make( + {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), + Cond::make( + CompareSelect::make(y, 5, CompareSelectOperation::kLT), + For::make( + x, + 0, + 10, + Store::make(c, {x}, Add::make(Load::make(b, {x}), 1))), + nullptr)}); + + stmt->accept(&analyzer); + + // Output C should have 3 dependencies, each of the three stores. + auto outputAccess = analyzer.output(c.node()); + ASSERT_NE(outputAccess, nullptr); + ASSERT_EQ(outputAccess->dependencies().size(), 2); + + // C depends indirectly on A and B. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + } + + { + /* for (int x = 0; x < 10; x++) { + * C[x] = A[x]; + * } + * if (y<5 ? 1 : 0) { + * } else { + * for (int x = 0; x < 10; x++) { + * C[x] = (B[x]) + 1; + * } + * } + */ + + // Only has false branch. + + MemDependencyChecker analyzer({a, b}, {c}); + StmtPtr stmt = Block::make( + {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), + Cond::make( + CompareSelect::make(y, 5, CompareSelectOperation::kLT), + nullptr, + For::make( + x, + 0, + 10, + Store::make(c, {x}, Add::make(Load::make(b, {x}), 1))))}); + + stmt->accept(&analyzer); + + // Output C should have 3 dependencies, each of the three stores. + auto outputAccess = analyzer.output(c.node()); + ASSERT_NE(outputAccess, nullptr); + ASSERT_EQ(outputAccess->dependencies().size(), 2); + + // C depends indirectly on A and B. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + } + + { + /* for (int x = 0; x < 10; x++) { + * C[x] = A[x]; + * } + * if (C[0]<5 ? 1 : 0) { + * C[0] = 5; + * } + */ + + // Cond's Condition depends on a previous access. + + MemDependencyChecker analyzer({a}, {c}); + StorePtr initStore = Store::make(c, {x}, Load::make(a, {x})); + ExprHandle conditionalLoad = Load::make(c, {0}); + StmtPtr stmt = Block::make( + {For::make(x, 0, 10, initStore), + Cond::make( + CompareSelect::make( + conditionalLoad, 5, CompareSelectOperation::kLT), + Store::make(c, {0}, 5), + nullptr)}); + + stmt->accept(&analyzer); + + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + + ASSERT_TRUE(analyzer.dependsDirectly(conditionalLoad.node(), initStore)); + ASSERT_FALSE(analyzer.dependsDirectly(conditionalLoad.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(conditionalLoad.node(), a.node())); + } +} + +// Stmts using IfThenElse. +TEST(MemDependency, MemDependencyCheckerIfThenElse) { + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + BufHandle c("C", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + using namespace analysis; + + { + /* for (int x = 0; x < 10; x++) { + * C[x] = A[x]; + * } + * C[0] = (y < 5 ? (B[0]) + 1 : (B[1]) + 1; + */ + + // Future usages may depend on accesses in both branches of a condition. + + MemDependencyChecker analyzer({a, b}, {c}); + StorePtr ifStore = Store::make( + c, + {0}, + IfThenElse::make( + CompareSelect::make(y, 5, CompareSelectOperation::kLT), + Add::make(Load::make(b, {0}), 1), + Add::make(Load::make(b, {1}), 1))); + StmtPtr stmt = Block::make( + {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), + ifStore}); + + stmt->accept(&analyzer); + + // Output C should have 2 dependencies, each of the two stores. + auto outputAccess = analyzer.output(c.node()); + ASSERT_NE(outputAccess, nullptr); + ASSERT_EQ(outputAccess->dependencies().size(), 2); + + // Now we need to check the Store containing the IfThenElse. + auto ifStoreAccess = analyzer.accessFor(ifStore); + + // It should have 2 dependencies. + ASSERT_EQ(ifStoreAccess->dependencies().size(), 2); + + // C depends indirectly on A and B. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + } + + { + /* for (int x = 0; x < 10; x++) { + * C[x] = A[x]; + * } + * C[0] = (y < 5 ? (B[0]) + 1 : 42; + */ + + // If the load appears in only one side of an IfThenElse the output may be + // dependent on it. + + MemDependencyChecker analyzer({a, b}, {c}); + StorePtr ifStore = Store::make( + c, + {0}, + IfThenElse::make( + CompareSelect::make(y, 5, CompareSelectOperation::kLT), + Add::make(Load::make(b, {0}), 1), + 42)); + StmtPtr stmt = Block::make( + {For::make(x, 0, 10, Store::make(c, {x}, Load::make(a, {x}))), + ifStore}); + + stmt->accept(&analyzer); + + // C depends indirectly on A and B. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + } + + { + /* for (int x = 0; x < 10; x++) { + * C[x] = (x < 5 ? B[x] : A[x]; + * } + */ + + // In this case C is dependent on both A and B. + + // TODO: in cases like this it would be possible to split the range of B + // into two bounds, one dependent on A and one dependent on B. We'd need to + // examine conditions relative to previously encountered loop variables. I'm + // uncertain if this would be helpful. + + MemDependencyChecker analyzer({a, b}, {c}); + StorePtr ifStore = Store::make( + c, + {0}, + IfThenElse::make( + CompareSelect::make(y, 5, CompareSelectOperation::kLT), + Load::make(b, {x}), + Load::make(a, {x}))); + StmtPtr stmt = Block::make({For::make(x, 0, 10, ifStore)}); + + stmt->accept(&analyzer); + + // C depends indirectly on A and B. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + } +} + +// Cutting a loop with single elem writes +TEST(MemDependency, MemDependencyCheckerCutLoop) { + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + + using namespace analysis; + + { + /* for (int x = 0; x < 10; x++) { + * B[x] = A[x]; + * } + * B[5] = 100; + */ + + // Cutting a loop with single element writes. + + MemDependencyChecker analyzer({a}, {b}); + StmtPtr stmt = Block::make( + {For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x}))), + Store::make(b, {5}, 100)}); + + stmt->accept(&analyzer); + + // Output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); + + // Output has 2 dependencies. + auto outputAccess = analyzer.output(b.node()); + ASSERT_NE(outputAccess, nullptr); + ASSERT_EQ(outputAccess->dependencies().size(), 2); + } + + { + /* for (int x = 0; x < 10; x++) { + * B[x] = A[x]; + * } + * for (int x = 4; x < 7; x++) { + * B[x] = B[x] + 3; + * } + * B[5] = 100; + * B[6] = 101; + * B[7] = 102; + */ + + // Cutting a loop with a smaller loop but then totally overlap that second + // loop with one element writes. + + MemDependencyChecker analyzer({a}, {b}); + ForPtr firstLoop = + For::make(x, 0, 10, Store::make(b, {x}, Load::make(a, {x}))); + StorePtr secondStore = + Store::make(b, {x}, Add::make(Load::make(b, {x}), 1)); + ForPtr secondLoop = For::make(x, 4, 7, secondStore); + + StmtPtr stmt = Block::make( + {firstLoop, + secondLoop, + Store::make(b, {4}, 100), + Store::make(b, {5}, 101), + Store::make(b, {6}, 102)}); + + stmt->accept(&analyzer); + + // Output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); + + // Output has 4 dependencies. + auto outputAccess = analyzer.output(b.node()); + ASSERT_NE(outputAccess, nullptr); + ASSERT_EQ(outputAccess->dependencies().size(), 4); + + // Second loop depends on first loop. + ASSERT_TRUE(analyzer.dependsDirectly(secondLoop, firstLoop)); + + // Output does not depend on second loop or store. + ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), secondLoop)); + ASSERT_FALSE(analyzer.dependsIndirectly(b.node(), secondStore)); + } +} + +// Dynamic shapes (load in indices). +TEST(MemDependency, MemDependencyCheckerDynamicShapes) { + BufHandle a("A", {100}, kInt); + BufHandle b("B", {100}, kInt); + BufHandle c("C", {100}, kInt); + VarHandle x("x", kInt); + + using namespace analysis; + + auto CB = [](ExprHandle s, ExprHandle e) { + return Bound(s.node(), e.node()); + }; + + auto EQ = [](const IndexBounds& x, const IndexBounds& y) { + return indexBoundsEquals(x, y); + }; + + { + /* for (int x = 0; x < B[0]; x++) { + * C[x] = A[x]; + * } + */ + MemDependencyChecker analyzer({a, b}, {c}); + StmtPtr stmt = Block::make({For::make( + x, 0, Load::make(b, {0}), Store::make(c, {x}, Load::make(a, {x})))}); + + stmt->accept(&analyzer); + + /* 0. Input: B[(0, 99)] - dependents: 2 + * 1. Input: A[(0, 99)] - dependents: 3 + * 2. Load: B[(0, 0)] - depends on: 0 - dependents: 3 4 + * 3. Load: A[(0, (B[0]) - 1)] - depends on: 1 2 - dependents: 4 + * 4. Store: C[(0, (B[0]) - 1)] - depends on: 2 3 - dependents: 5 + * 5. Output: C[(0, 99)] - depends on: 4 + */ + + // Output dependent on A input. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + // Also dependent on B input to determine the size of the region written. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 6); + + // The accesses in the loop depend on the load in the stop condition. + ASSERT_TRUE(history[4]->hasDependency(history[2])); + ASSERT_TRUE(history[3]->hasDependency(history[2])); + + // Make a load from B to compare against. + ExprHandle loadFromB = Load::make(b, {0}); + + ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, loadFromB - 1)})); + ASSERT_TRUE(EQ(history[4]->bounds(), {CB(0, loadFromB - 1)})); + } + + { + /* for (int x = B[0]; x < B[1]; x++) { + * C[x] = A[x]; + * } + */ + MemDependencyChecker analyzer({a, b}, {c}); + StmtPtr stmt = Block::make({For::make( + x, + Load::make(b, {0}), + Load::make(b, {1}), + Store::make(c, {x}, Load::make(a, {x})))}); + + stmt->accept(&analyzer); + + /* 0. Input: B[(0, 99)] - dependents: 2 3 + * 1. Input: A[(0, 99)] - dependents: 4 + * 2. Load: B[(0, 0)] - depends on: 0 - dependents: 4 5 + * 3. Load: B[(1, 1)] - depends on: 0 - dependents: 4 5 + * 4. Load: A[(B[0], (B[1]) - 1)] - depends on: 1 2 3 - dependents: 5 + * 5. Store: C[(B[0], (B[1]) - 1)] - depends on: 2 3 4 - dependents: 6 + * 6. Output: C[(0, 99)] - depends on: 5 + */ + + // Sanity check output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 7); + + // The accesses in the loop depend on the load in the start condition. + ASSERT_TRUE(history[5]->hasDependency(history[2])); + ASSERT_TRUE(history[4]->hasDependency(history[2])); + + // also the stop condition. + ASSERT_TRUE(history[5]->hasDependency(history[3])); + ASSERT_TRUE(history[4]->hasDependency(history[3])); + + // Make loads from B to compare against. + ExprHandle loadFromB0 = Load::make(b, {0}); + ExprHandle loadFromB1 = Load::make(b, {1}); + ASSERT_TRUE(EQ(history[4]->bounds(), {CB(loadFromB0, loadFromB1 - 1)})); + ASSERT_TRUE(EQ(history[5]->bounds(), {CB(loadFromB0, loadFromB1 - 1)})); + } + + { + /* for (int x = 0; x < 10; x++) { + * C[x] = A[B[x]]; + * } + */ + MemDependencyChecker analyzer({a, b}, {c}); + StmtPtr stmt = Block::make({For::make( + x, 0, 10, Store::make(c, {x}, Load::make(a, {Load::make(b, {x})})))}); + + stmt->accept(&analyzer); + + /* 0. Input: B[(0, 99)] - dependents: 2 + * 1. Input: A[(0, 99)] - dependents: 3 + * 2. Load: B[(0, 9)] - depends on: 0 - dependents: 3 4 + * 3. Load: A[(B[0], B[9])] - depends on: 1 2 - dependents: 4 + * 4. Store: C[(0, 9)] - depends on: 2 3 - dependents: 5 + * 5. Output: C[(0, 99)] - depends on: 4 + */ + + // Sanity check output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 6); + + // The store depends on both loads, the load of A depends on the load of B. + ASSERT_TRUE(history[4]->hasDependency(history[2])); + ASSERT_TRUE(history[4]->hasDependency(history[3])); + + ASSERT_TRUE(history[3]->hasDependency(history[2])); + + // The loads in the indices depend on the relevant input buffer. + ASSERT_TRUE(history[3]->hasDependency(history[1])); + ASSERT_TRUE(history[2]->hasDependency(history[0])); + + // The load from B has the loop bounds. + ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)})); + + // The load from A has bounds B[0] to B[9]. + ExprHandle loadFromB0 = Load::make(b, {0}); + ExprHandle loadFromB9 = Load::make(b, {9}); + ASSERT_TRUE(EQ(history[3]->bounds(), {CB(loadFromB0, loadFromB9)})); + } + + { + /* for (int x = 0; x < 10; x++) { + * C[B[x]] = A[x]; + * } + */ + MemDependencyChecker analyzer({a, b}, {c}); + StmtPtr stmt = Block::make({For::make( + x, 0, 10, Store::make(c, {Load::make(b, {x})}, Load::make(a, {x})))}); + + stmt->accept(&analyzer); + + /* 0. Input: B[(0, 99)] - dependents: 3 + * 1. Input: A[(0, 99)] - dependents: 2 + * 2. Load: A[(0, 9)] - depends on: 1 - dependents: 4 + * 3. Load: B[(0, 9)] - depends on: 0 - dependents: 4 + * 4. Store: C[(B[0], B[9])] - depends on: 2 3 - dependents: 5 + * 5. Output: C[(0, 99)] - depends on: 4 + */ + // Sanity check output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 6); + + // The store depends on both loads, neither load is dependent. + ASSERT_TRUE(history[4]->hasDependency(history[2])); + ASSERT_TRUE(history[4]->hasDependency(history[3])); + + ASSERT_FALSE(history[3]->hasDependency(history[2])); + ASSERT_FALSE(history[2]->hasDependency(history[3])); + + // The loads each depend on their relevant input. (but accesses are in a + // different order than the last case). + ASSERT_TRUE(history[3]->hasDependency(history[0])); + ASSERT_TRUE(history[2]->hasDependency(history[1])); + + // The load from B has the loop bounds. + ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, 9)})); + + // And so does the load from A. + ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)})); + } + + { + /* for (int x = 0; x < 10; x++) { + * C[B[A[x]]] = x; + * } + */ + MemDependencyChecker analyzer({a, b}, {c}); + StmtPtr stmt = Block::make({For::make( + x, 0, 10, Store::make(c, {Load::make(b, {Load::make(a, {x})})}, x))}); + + stmt->accept(&analyzer); + + /* 0. Input: B[(0, 99)] - dependents: 3 + * 1. Input: A[(0, 99)] - dependents: 2 + * 2. Load: A[(0, 9)] - depends on: 1 - dependents: 3 4 + * 3. Load: B[(A[0], A[9])] - depends on: 0 2 - dependents: 4 + * 4. Store: C[(B[A[0]], B[A[9]])] - depends on: 2 3 - dependents: 5 + * 5. Output: C[(0, 99)] - depends on: 4 + */ + + // Sanity check output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(c.node(), b.node())); + + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 6); + + // The store depends on both loads. + ASSERT_TRUE(history[4]->hasDependency(history[2])); + ASSERT_TRUE(history[4]->hasDependency(history[3])); + + // The outer load depends on the inner. + ASSERT_TRUE(history[3]->hasDependency(history[2])); + + // The loads each depend on their relevant input. (but accesses are in a + // different order than the last case). + ASSERT_TRUE(history[3]->hasDependency(history[0])); + ASSERT_TRUE(history[2]->hasDependency(history[1])); + + // The load from A has the loop bounds. + ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 9)})); + // The load from B as bounds A[0] to A[9]. + ExprHandle loadFromA0 = Load::make(a, {0}); + ExprHandle loadFromA9 = Load::make(a, {9}); + ASSERT_TRUE(EQ(history[3]->bounds(), {CB(loadFromA0, loadFromA9)})); + + // The store has bounds of B[A[0]] to B[A[9]]. + ExprHandle loadFromBA0 = Load::make(b, {loadFromA0}); + ExprHandle loadFromBA9 = Load::make(b, {loadFromA9}); + ASSERT_TRUE(EQ(history[4]->bounds(), {CB(loadFromBA0, loadFromBA9)})); + } +} + +// Verify multi dimensional bounds work. +TEST(MemDependency, MemDependencyCheckerMultiDim) { + int M = 10, N = 9, K = 12; + BufHandle a("A", {M, N, K}, kInt); + BufHandle b("B", {M, N, K}, kInt); + BufHandle c("C", {M, K}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + + using namespace analysis; + + auto CB = [](ExprHandle s, ExprHandle e) { + return Bound(s.node(), e.node()); + }; + + auto EQ = [](const IndexBounds& x, const IndexBounds& y) { + return indexBoundsEquals(x, y); + }; + + { + /* for (int x = 0; x < 10; x++) { + * for (int y = 0; y < 9; y++) { + * for (int z = 0; z < 12; z++) { + * B[x, y, z] = A[x, y, z]; + * } + * } + * } + */ + // Full range. + + MemDependencyChecker analyzer({a}, {b}); + StmtPtr stmt = Block::make({For::make( + x, + 0, + M, + For::make( + y, + 0, + N, + For::make( + z, + 0, + K, + Store::make(b, {x, y, z}, Load::make(a, {x, y, z})))))}); + + stmt->accept(&analyzer); + + // Sanity test: Output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); + + // 4 accesses: input, load, store, output. + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 4); + + // Simple chain from input to output. + ASSERT_TRUE(history[3]->hasDependency(history[2])); + ASSERT_TRUE(history[2]->hasDependency(history[1])); + ASSERT_TRUE(history[1]->hasDependency(history[0])); + + ASSERT_TRUE( + EQ(history[1]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)})); + ASSERT_TRUE( + EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)})); + } + + { + /* for (int x = 0; x < 5; x++) { + * for (int y = 0; y < 5; y++) { + * for (int z = 0; z < 5; z++) { + * B[x, y, z] = A[x, y, z]; + * } + * } + * } + */ + // Partial range. + + MemDependencyChecker analyzer({a}, {b}); + StmtPtr stmt = Block::make({For::make( + x, + 0, + 5, + For::make( + y, + 0, + 5, + For::make( + z, + 0, + 5, + Store::make(b, {x, y, z}, Load::make(a, {x, y, z})))))}); + + stmt->accept(&analyzer); + + // Sanity test: Output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); + + // 4 accesses: input, load, store, output. + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 4); + + // Simple chain from input to output. + ASSERT_TRUE(history[3]->hasDependency(history[2])); + ASSERT_TRUE(history[2]->hasDependency(history[1])); + ASSERT_TRUE(history[1]->hasDependency(history[0])); + + ASSERT_TRUE(EQ(history[1]->bounds(), {CB(0, 4), CB(0, 4), CB(0, 4)})); + ASSERT_TRUE(EQ(history[2]->bounds(), {CB(0, 4), CB(0, 4), CB(0, 4)})); + } + + { + /* for (int x = 0; x < 10; x++) { + * for (int y = 0; y < 12; y++) { + * B[x, 0, y] = A[x, 0, y]; + * } + * } + */ + + // Partial loops. + + MemDependencyChecker analyzer({a}, {b}); + StmtPtr stmt = Block::make({For::make( + x, + 0, + N, + For::make( + y, 0, K, Store::make(b, {x, 0, y}, Load::make(a, {x, 0, y}))))}); + + stmt->accept(&analyzer); + + // Sanity test: Output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); + + // 4 accesses: input, load, store, output. + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 4); + + // Simple chain from input to output. + ASSERT_TRUE(history[3]->hasDependency(history[2])); + ASSERT_TRUE(history[2]->hasDependency(history[1])); + ASSERT_TRUE(history[1]->hasDependency(history[0])); + + ASSERT_TRUE( + EQ(history[1]->bounds(), {CB(0, N - 1), CB(0, 0), CB(0, K - 1)})); + ASSERT_TRUE( + EQ(history[2]->bounds(), {CB(0, N - 1), CB(0, 0), CB(0, K - 1)})); + } + + { + /* for (int x = 0; x < 10; x++) { + * for (int y = 0; y < 100; y++) { + * for (int z = 0; z < 12; z++) { + * B[x, 0, z] = (A[x, 0, z]) + (C[x, z]); + * } + * } + * } + */ + + // Loops that don't correspond to an index, bufs with different + // dimensionality. + + MemDependencyChecker analyzer({a, c}, {b}); + StmtPtr stmt = Block::make({For::make( + x, + 0, + M, + For::make( + y, + 0, + 100, + For::make( + z, + 0, + K, + Store::make( + b, + {x, 0, z}, + Add::make( + Load::make(a, {x, 0, z}), Load::make(c, {x, z}))))))}); + + stmt->accept(&analyzer); + + // Sanity test: Output depends on both inputs. + ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), c.node())); + + // 6 accesses: 2 inputs, 2 loads, store, output. + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 6); + + // Simple chain from input to output over the A buf. + // history[0] is the C input, history[3] is the load from C. + ASSERT_TRUE(history[5]->hasDependency(history[4])); + ASSERT_TRUE(history[4]->hasDependency(history[2])); + ASSERT_TRUE(history[2]->hasDependency(history[1])); + // The store also depends on the load from the C input. + ASSERT_TRUE(history[4]->hasDependency(history[3])); + ASSERT_TRUE(history[3]->hasDependency(history[0])); + + // A Buf accesses. + ASSERT_TRUE( + EQ(history[4]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, K - 1)})); + ASSERT_TRUE( + EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, K - 1)})); + + // C buf access. + ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, M - 1), CB(0, K - 1)})); + } + + { + /* for (int x = 0; x < 9; x++) { + * for (int y = 0; y < 10; y++) { + * for (int z = 0; z < 12; z++) { + * B[x, 0, 0] = (B[x, y, z]) + (A[x, y, z]); + * } + * } + * } + */ + // Multi-dim reductions. + + MemDependencyChecker analyzer({a}, {b}); + StmtPtr stmt = Block::make({For::make( + x, + 0, + M, + For::make( + y, + 0, + N, + For::make( + z, + 0, + K, + Store::make( + b, + {x, 0, 0}, + Add::make( + Load::make(b, {x, y, z}), + Load::make(a, {x, y, z}))))))}); + + stmt->accept(&analyzer); + + // Sanity test: Output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(b.node(), a.node())); + + // 4 accesses: input, 2 loads, store, output. + auto history = analyzer.getHistory(); + ASSERT_EQ(history.size(), 5); + + // Simple chain from input to output. + ASSERT_TRUE(history[4]->hasDependency(history[3])); + ASSERT_TRUE(history[3]->hasDependency(history[2])); + ASSERT_TRUE(history[3]->hasDependency(history[1])); + ASSERT_TRUE(history[2]->hasDependency(history[0])); + + // The load from B depends on the store to B. + ASSERT_TRUE(history[1]->hasDependency(history[3])); + + ASSERT_TRUE( + EQ(history[1]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)})); + ASSERT_TRUE( + EQ(history[2]->bounds(), {CB(0, M - 1), CB(0, N - 1), CB(0, K - 1)})); + ASSERT_TRUE(EQ(history[3]->bounds(), {CB(0, M - 1), CB(0, 0), CB(0, 0)})); + } +} + +// Various tests using the external Compute/Reduce API. +TEST(MemDependency, MemDependencyCheckerComputeAPI) { + using namespace analysis; + + /* for (int m = 0; m < 4; m++) { + * for (int n = 0; n < 5; n++) { + * for (int k = 0; k < 6; k++) { + * broadcast_add[m, n, k] = (a[m, n]) + (b[n, k]); + * } + * } + * } + * for (int m_1 = 0; m_1 < 4; m_1++) { + * for (int n_1 = 0; n_1 < 5; n_1++) { + * for (int k_1 = 0; k_1 < 6; k_1++) { + * d[m_1, n_1, k_1] = (broadcast_add(m_1, n_1, k_1)) + float(1); + * } + * } + * } + */ + + // Can determine if 2 loops created by Compute are dependent. + BufHandle a_buf("a", {4, 5}, kFloat); + BufHandle b_buf("b", {5, 6}, kFloat); + Tensor c = Compute( + "broadcast_add", + {4, 5, 6}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf.load(m, n) + b_buf.load(n, k); + }); + Tensor d = Compute( + "d", + {4, 5, 6}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return c.load(m, n, k) + 1; + }); + + LoopNest l({d}, {c, d}); + + MemDependencyChecker analyzer({a_buf.node(), b_buf.node()}, {d.buf()}); + + l.root_stmt()->accept(&analyzer); + + // Sanity test: Output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), a_buf.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), b_buf.node())); + + // Second loop depends on first loop. + auto c_loop = l.getLoopStmtsFor(c)[0]; + auto d_loop = l.getLoopStmtsFor(d)[0]; + ASSERT_TRUE(analyzer.dependsDirectly(d_loop, c_loop)); +} + +TEST(MemDependency, MemDependencyCheckerComputeInline) { + using namespace analysis; + + /* for (int m = 0; m < 4; m++) { + * for (int n = 0; n < 5; n++) { + * for (int k = 0; k < 6; k++) { + * d[m, n, k] = ((a[m, n]) + (b[n, k])) + float(1); + * } + * } + * } + */ + + // Check inlining affects the number of accesses returned. + + BufHandle a_buf("a", {4, 5}, kFloat); + BufHandle b_buf("b", {5, 6}, kFloat); + Tensor c = Compute( + "broadcast_add", + {4, 5, 6}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf.load(m, n) + b_buf.load(n, k); + }); + Tensor d = Compute( + "d", + {4, 5, 6}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return c.load(m, n, k) + 1; + }); + + LoopNest l({d}, {c, d}); + l.computeInline(c.buf()); + + MemDependencyChecker analyzer({a_buf.node(), b_buf.node()}, {d.buf()}); + l.root_stmt()->accept(&analyzer); + + // Sanity test: Output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), a_buf.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), b_buf.node())); + + // broadcast_add tensor should not appear in trace at all. + for (auto& wi : analyzer.getHistory()) { + ASSERT_NE(wi->var(), c.buf()->base_handle()); + } +} + +TEST(MemDependency, MemDependencyCheckerComputeSplit) { + using namespace analysis; + // Split an axis, so the number of loops != the number of dimensions. + + BufHandle a_buf("a", {4, 5}, kFloat); + BufHandle b_buf("b", {5, 6}, kFloat); + Tensor c = Compute( + "broadcast_add", + {4, 5, 6}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf.load(m, n) + b_buf.load(n, k); + }); + + LoopNest l({c}); + + MemDependencyChecker analyzer_before({a_buf.node(), b_buf.node()}, {c.buf()}); + l.root_stmt()->accept(&analyzer_before); + + l.splitWithTail(l.getLoopStmtsFor(c)[0], 2); + + MemDependencyChecker analyzer_after({a_buf.node(), b_buf.node()}, {c.buf()}); + StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); + stmt->accept(&analyzer_after); + + // Splitting should not change accesses at all. + auto history_before = analyzer_before.getHistory(); + auto history_after = analyzer_after.getHistory(); + + ASSERT_EQ(history_before.size(), history_after.size()); + + for (size_t i = 0; i < history_before.size(); ++i) { + ASSERT_EQ(history_before[i]->type(), history_after[i]->type()); + ASSERT_EQ(history_before[i]->var(), history_after[i]->var()); + ASSERT_EQ( + history_before[i]->bounds().size(), history_after[i]->bounds().size()); + ASSERT_TRUE(indexBoundsEquals( + history_before[i]->bounds(), history_after[i]->bounds())); + ASSERT_EQ( + history_before[i]->dependencies().size(), + history_after[i]->dependencies().size()); + ASSERT_EQ( + history_before[i]->dependents().size(), + history_after[i]->dependents().size()); + } +} + +TEST(MemDependency, MemDependencyCheckerComputeReorder) { + using namespace analysis; + // Reorder an axis, so the loop order doesn't match the indexing order. + + BufHandle a_buf("a", {4, 5}, kFloat); + BufHandle b_buf("b", {5, 6}, kFloat); + Tensor c = Compute( + "broadcast_add", + {4, 5, 6}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf.load(m, n) + b_buf.load(n, k); + }); + + LoopNest l({c}); + + MemDependencyChecker analyzer_before({a_buf.node(), b_buf.node()}, {c.buf()}); + l.root_stmt()->accept(&analyzer_before); + + auto loops = l.getLoopStmtsFor(c); + l.reorderAxis(loops[0], loops[1]); + + MemDependencyChecker analyzer_after({a_buf.node(), b_buf.node()}, {c.buf()}); + StmtPtr stmt = IRSimplifier::simplify(l.root_stmt()); + stmt->accept(&analyzer_after); + + // Reordering should not change accesses at all. + auto history_before = analyzer_before.getHistory(); + auto history_after = analyzer_after.getHistory(); + + ASSERT_EQ(history_before.size(), history_after.size()); + + for (size_t i = 0; i < history_before.size(); ++i) { + ASSERT_EQ(history_before[i]->type(), history_after[i]->type()); + ASSERT_EQ(history_before[i]->var(), history_after[i]->var()); + ASSERT_EQ( + history_before[i]->bounds().size(), history_after[i]->bounds().size()); + ASSERT_TRUE(indexBoundsEquals( + history_before[i]->bounds(), history_after[i]->bounds())); + ASSERT_EQ( + history_before[i]->dependencies().size(), + history_after[i]->dependencies().size()); + ASSERT_EQ( + history_before[i]->dependents().size(), + history_after[i]->dependents().size()); + } +} + +TEST(MemDependency, MemDependencyCheckerComputeReduce) { + using namespace analysis; + /* for (int l2 = 0; l2 < 2; l2++) { + * for (int n1 = 0; n1 < 3; n1++) { + * for (int m1 = 0; m1 < 6; m1++) { + * scale[l2, n1, m1] = (b[l2, n1, m1]) * (a[l2, n1, m1]); + * } + * } + * } + * for (int l1 = 0; l1 < 2; l1++) { + * sum[l1] = float(0); + * for (int n1_1 = 0; n1_1 < 3; n1_1++) { + * for (int m1_1 = 0; m1_1 < 6; m1_1++) { + * sum[l1] = ReduceOp(sum, (sum[l1]) + (scale(l1, n1_1, m1_1)), + * out_args={l1}, reduce_args={n1, m1}); + * } + * } + * } + */ + + // Can determine dependencies of a Reduction. + + BufHandle a("a", {2, 3, 6}, kFloat); + BufHandle b("b", {2, 3, 6}, kFloat); + + Tensor c = Compute( + "scale", + {2, 3, 6}, + [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { + return b.load(l, n, m) * a.load(l, n, m); + }); + Tensor d = Reduce("sum", {2}, Sum(), c, {3, 6}); + LoopNest l({d}, {c, d}); + + MemDependencyChecker analyzer({a.node(), b.node()}, {d.buf()}); + + l.root_stmt()->accept(&analyzer); + + // Sanity test: Output depends on input. + ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(d.buf(), b.node())); + + // Second loop depends on first loop. + auto c_loop = l.getLoopStmtsFor(c)[0]; + auto d_loop = l.getLoopStmtsFor(d)[0]; + ASSERT_TRUE(analyzer.dependsDirectly(d_loop, c_loop)); + + // Reduction depends on both inputs. + auto reduces = NodeFinder::find(l.root_stmt()); + ASSERT_TRUE(analyzer.dependsIndirectly(reduces[0], a.node())); + ASSERT_TRUE(analyzer.dependsIndirectly(reduces[0], b.node())); +} + +TEST(MemDependency, MemDependencyCheckerComputeGEMM) { + int M = 1024; + int N = 1024; + int K = 2048; + using namespace analysis; + + BufHandle AP("A", {M, K}, kFloat); + BufHandle BP("B", {K, N}, kFloat); + Tensor CT = Reduce( + "gemm", + {M, N}, + Sum(), + [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { + return AP.load(m, k) * BP.load(k, n); + }, + {K}); + LoopNest loop({CT}); + + { + auto const& loops = loop.getLoopStmtsFor(CT); + ForPtr m = loops[0]; + loop.splitWithMask(m, 4); + } + { + auto const& loops = loop.getLoopStmtsFor(CT); + ForPtr n = loops[2]; + loop.splitWithMask(n, 16); + } + // mo, mi, no, ni, k -> + // mo, no, mi, ni, k + { + auto const& loops = loop.getLoopStmtsFor(CT); + ForPtr mi = loops[1]; + ForPtr no = loops[2]; + loop.reorderAxis(mi, no); + } + // mo, no, mi, ni, k -> + // mo, no, mi, k, ni + { + auto const& loops = loop.getLoopStmtsFor(CT); + ForPtr ni = loops[3]; + ForPtr k = loops[4]; + loop.reorderAxis(ni, k); + } + // mo, no, mi, k, ni -> + // mo, no, k, mi, ni + { + auto const& loops = loop.getLoopStmtsFor(CT); + ForPtr mi = loops[2]; + ForPtr k = loops[3]; + loop.reorderAxis(mi, k); + } + { + auto const& loops = loop.getLoopStmtsFor(CT); + loop.cacheAccesses(CT.buf(), "C_regs", loops[2]); + } + + MemDependencyChecker analyzer_unlowered( + loop.getInputBufs(), loop.getOutputBufs()); + + MemDependencyChecker analyzer_lowered( + loop.getInputBufs(), loop.getOutputBufs()); + + // Test both unlowered and lowered form. + { + StmtPtr stmt = IRSimplifier::simplify(loop.root_stmt()); + stmt->accept(&analyzer_unlowered); + + // Outputs depend on inputs. + ASSERT_TRUE(analyzer_unlowered.dependsIndirectly(CT.buf(), AP.node())); + ASSERT_TRUE(analyzer_unlowered.dependsIndirectly(CT.buf(), BP.node())); + + // The last write to gemm should cover the total bound of the output. + std::shared_ptr outputAccess = + analyzer_unlowered.output(CT.buf()); + // A single dependency. + ASSERT_EQ(outputAccess->dependencies().size(), 1); + + // dependencies is a set with 1 element, so can just deref begin(). + std::shared_ptr gemmStore = + outputAccess->dependencies().begin()->second; + // Check its a store. + ASSERT_EQ(gemmStore->type(), AccessType::Store); + + ASSERT_TRUE(indexBoundsEquals(outputAccess->bounds(), gemmStore->bounds())); + + // Likewise the first read from each input cover the entire range of the + // input. + auto aInput = analyzer_unlowered.input(AP.node()); + auto bInput = analyzer_unlowered.input(BP.node()); + + // A single dependent each. + ASSERT_EQ(aInput->dependents().size(), 1); + ASSERT_EQ(bInput->dependents().size(), 1); + + // They're both loads. + std::shared_ptr aLoad = aInput->dependents().begin()->second; + std::shared_ptr bLoad = bInput->dependents().begin()->second; + ASSERT_EQ(aLoad->type(), AccessType::Load); + ASSERT_EQ(bLoad->type(), AccessType::Load); + + ASSERT_TRUE(indexBoundsEquals(aInput->bounds(), aLoad->bounds())); + ASSERT_TRUE(indexBoundsEquals(bInput->bounds(), bLoad->bounds())); + } + + loop.prepareForCodegen(); + SimpleIREvaluator cg(loop.root_stmt(), {AP, BP, CT}); + + // now check lowered dependency graph. + { + StmtPtr stmt = IRSimplifier::simplify(cg.stmt()); + stmt->accept(&analyzer_lowered); + + // Lowering will change the dimensionality of all bounds due to index + // flattening and will insert Allocates and Frees. + + auto history_before = analyzer_unlowered.getHistory(); + auto history_after = analyzer_lowered.getHistory(); + + ASSERT_EQ(history_before.size() + 2, history_after.size()); + + // Filter out the alloc/free; + auto isAllocFree = [](const auto& info) { + return info->type() == AccessType::Alloc || + info->type() == AccessType::Free; + }; + history_after.erase( + std::remove_if(history_after.begin(), history_after.end(), isAllocFree), + history_after.end()); + + ASSERT_EQ(history_before.size(), history_after.size()); + + for (size_t i = 0; i < history_before.size(); ++i) { + ASSERT_EQ(history_before[i]->type(), history_after[i]->type()); + ASSERT_EQ(history_before[i]->var(), history_after[i]->var()); + + if (history_before[i]->dependencies().size() != + history_after[i]->dependencies().size()) { + // Must depend on an Alloc. + ASSERT_TRUE(std::any_of( + history_after[i]->dependencies().begin(), + history_after[i]->dependencies().end(), + [](const auto& pair) { + return pair.second->type() == AccessType::Alloc; + })); + + ASSERT_EQ( + history_before[i]->dependencies().size() + 1, + history_after[i]->dependencies().size()); + } + + if (history_before[i]->dependents().size() != + history_after[i]->dependents().size()) { + // Must depend on an Free. + ASSERT_TRUE(std::any_of( + history_after[i]->dependents().begin(), + history_after[i]->dependents().end(), + [](const auto& pair) { + return pair.second->type() == AccessType::Free; + })); + + ASSERT_EQ( + history_before[i]->dependents().size() + 1, + history_after[i]->dependents().size()); + } + + // Inputs and outputs are not flattened, only accesses. + if (history_before[i]->type() == AccessType::Input || + history_before[i]->type() == AccessType::Output) { + ASSERT_EQ( + history_before[i]->bounds().size(), + history_after[i]->bounds().size()); + ASSERT_TRUE(indexBoundsEquals( + history_before[i]->bounds(), history_after[i]->bounds())); + } else { + ASSERT_EQ(history_after[i]->bounds().size(), 1); + ExprPtr flat_bounds = alloc(1); + + for (auto& b : history_before[i]->bounds()) { + flat_bounds = + alloc(flat_bounds, alloc(b.end, alloc(1))); + + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + ASSERT_TRUE(exprEquals(b.start, history_after[i]->bounds()[0].start)); + } + + flat_bounds = IRSimplifier::simplify(flat_bounds); + ExprPtr after_bounds = IRSimplifier::simplify( + alloc(history_after[i]->bounds()[0].end, alloc(1))); + ASSERT_TRUE(exprEquals(flat_bounds, after_bounds)); + } + } + } +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_memplanning.cpp b/test/cpp/tensorexpr/test_memplanning.cpp new file mode 100644 index 0000000000000..f5ee8747650fc --- /dev/null +++ b/test/cpp/tensorexpr/test_memplanning.cpp @@ -0,0 +1,708 @@ +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +using namespace torch::jit::tensorexpr; + +extern void checkIR(StmtPtr s, const std::string& pattern); + +TEST(BufLiveRange, SingleRangeLine) { + VarHandle i("i", kInt), j("j", kInt); + BufHandle a("a", {32}, kFloat); + BufHandle b("b", {32, 32}, kFloat); + + // Construct Stmt: + // { + // for (int i = 0; i < 32; i++) { + // a[i] = 0; + // for (int j = 0; j < 32; j++) { + // a[i] = (a[i]) + (b[i, j]); + // } + // } + // } + + StorePtr aInit = Store::make(a, {i}, 0); + ExprHandle reduce = a.load({i}) + b.load({i, j}); + StorePtr aReduce = Store::make(a, {i}, reduce); + StmtPtr loop = + For::make(i, 0, 32, Block::make({aInit, For::make(j, 0, 32, aReduce)})); + + StmtPtr stmt = Block::make({loop}); + + auto range = BufLiveRange::liveRange(stmt, a.node()); + ASSERT_TRUE(std::get<0>(range) == 0); + ASSERT_TRUE(std::get<1>(range) == 0); +} + +TEST(BufLiveRange, MulRangeLine) { + VarHandle i("i", kInt); + BufHandle a("a", {32}, kFloat); + BufHandle b("b", {32}, kFloat); + + // Construct Stmt: + // { + // for (int i = 0; i < 32; i++) { + // if (i<10 ? 1 : 0) { + // a[i] = i + i; + // b[i] = i * i; + // } + // } + // for (int i = 0; i < 32; i++) { + // if (i>10 ? 1 : 0) { + // a[i] = i * i; + // b[i] = i + i; + // } + // } + // } + + StorePtr aStore_1 = Store::make(a, {i}, i + i); + StorePtr bStore_1 = Store::make(b, {i}, i * i); + StmtPtr loop_1 = For::make( + i, 0, 32, Cond::make(i < 10, Block::make({aStore_1, bStore_1}), NULL)); + + StorePtr aStore_2 = Store::make(a, {i}, i * i); + StorePtr bStore_2 = Store::make(b, {i}, i + i); + StmtPtr loop_2 = For::make( + i, 0, 32, Cond::make(i > 10, Block::make({aStore_2, bStore_2}), NULL)); + + StmtPtr stmt = Block::make({loop_1, loop_2}); + + auto range_a = BufLiveRange::liveRange(stmt, a.node()); + ASSERT_TRUE(std::get<0>(range_a) == 0); + ASSERT_TRUE(std::get<1>(range_a) == 1); + + auto range_b = BufLiveRange::liveRange(stmt, b.node()); + ASSERT_TRUE(std::get<0>(range_b) == 0); + ASSERT_TRUE(std::get<1>(range_b) == 1); +} + +TEST(MemPlanning, MemReuseWithTypeCast) { + int M = 4; + int N = 4; + int K = 4; + + BufHandle AP("A", {M, K}, kFloat); + BufHandle BP("B", {K, N}, kFloat); + + Tensor CT = Reduce( + "gemm", + {M, N}, + Sum(), + [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { + return AP.load(m, k) * BP.load(k, n); + }, + {K}); + Tensor DT = + Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return CompareSelect::make( + CT.load(m, n), 0.0f, 0.0f, CT.load(m, n), kLT); + }); + Tensor ET = + Compute("E", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return Cast::make(kQUInt8, DT.load(m, n) + DT.load(m, n)); + }); + Tensor FT = + Compute("F", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return ET.load(m, n); + }); + StmtPtr stmt = + tensorexpr::Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()}); + + // Constructed stmt: + // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2], + // E [2, 3]. The dimensions of 'gemm' and 'E' are the same but their types are + // different: 'E' type quint8 < 'gemm' type float. We'll reuse 'gemm' for 'E' + // with typecasting. + //{ + // for (int i = 0; i < 4; i++) { + // for (int i_1 = 0; i_1 < 4; i_1++) { + // gemm[i, i_1] = float(0); + // for (int i_2 = 0; i_2 < 4; i_2++) { + // gemm[i, i_1] = ReduceOp((gemm[i, i_1]) + (A[i, i_2]) * (B[i_2, + // i_1]), reduce_args={i_2}); + // } + // } + // } + // for (int i_3 = 0; i_3 < 4; i_3++) { + // for (int i_4 = 0; i_4 < 4; i_4++) { + // relu[i_3, i_4] = (gemm[i_3, i_4])<0.f ? 0.f : (gemm[i_3, i_4]); + // } + // } + // for (int i_5 = 0; i_5 < 4; i_5++) { + // for (int i_6 = 0; i_6 < 4; i_6++) { + // E[i_5, i_6] = quint8((relu[i_5, i_6]) + (relu[i_5, i_6])); + // } + // } + // for (int i_7 = 0; i_7 < 4; i_7++) { + // for (int i_8 = 0; i_8 < 4; i_8++) { + // F[i_7, i_8] = E[i_7, i_8]; + // } + // } + //} + + LoopNest l(stmt, {FT.buf()}); + l.prepareForCodegen(); + SimpleIREvaluator cg(Stmt::clone(l.root_stmt()), {AP, BP, FT}); + + checkIR(cg.stmt(), R"IR( +# CHECK: Allocate(gemm); // dtype=float, dims=[4, 4] +# CHECK: Allocate(relu); // dtype=float, dims=[4, 4] +# CHECK: Alias(E,gemm); +# CHECK: Free(relu); +# CHECK: Free(gemm))IR"); + + PaddedBuffer a_v(M, K, "a"); + PaddedBuffer b_v(K, N, "b"); + PaddedBuffer o1(M, N, "e_before"); + PaddedBuffer o2(M, N, "e_after"); + + for (const auto m : c10::irange(M)) { + for (const auto k : c10::irange(K)) { + a_v(m, k) = at::randn({1}).item().to(); + } + } + + for (const auto k : c10::irange(K)) { + for (const auto n : c10::irange(N)) { + b_v(k, n) = at::randn({1}).item().to(); + } + } + + cg.call({a_v, b_v, o1}); + +#ifdef TORCH_ENABLE_LLVM + LLVMCodeGen cg_llvm(Stmt::clone(l.root_stmt()), {AP, BP, FT}); + + checkIR(cg_llvm.stmt(), R"IR( +# CHECK: Allocate(gemm); // dtype=float, dims=[4, 4] +# CHECK: Allocate(relu); // dtype=float, dims=[4, 4] +# CHECK: Alias(E,gemm); +# CHECK: Free(relu); +# CHECK: Free(gemm))IR"); + + cg_llvm.call({a_v, b_v, o2}); + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + ExpectAllNear(o1, o2, 1e-5); +#endif +} + +TEST(MemPlanning, NoMemReuseForLargerType) { + int M = 4; + int N = 4; + int K = 4; + + BufHandle AP("A", {M, K}, kShort); + BufHandle BP("B", {K, N}, kShort); + + Tensor CT = Reduce( + "gemm", + {M, N}, + Sum(), + [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { + return AP.load(m, k) * BP.load(k, n); + }, + {K}); + auto zero = Cast::make(CT.buf()->dtype(), 0); + Tensor DT = + Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return CompareSelect::make( + CT.load(m, n), zero, zero, CT.load(m, n), kLT); + }); + Tensor ET = + Compute("E", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return Cast::make(kFloat, DT.load(m, n) + DT.load(m, n)); + }); + Tensor FT = + Compute("F", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return ET.load(m, n); + }); + StmtPtr stmt = + tensorexpr::Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()}); + + // Constructed stmt: + // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2], + // E [2, 3]. The dimensions of 'gemm' and 'E' are the same but their types are + // different: 'E' type float > 'gemm' type int16. We won't reuse 'gemm' for + // 'E'. + //{ + // for (int i = 0; i < 4; i++) { + // for (int i_1 = 0; i_1 < 4; i_1++) { + // gemm[i, i_1] = int16_t(0); + // for (int i_2 = 0; i_2 < 4; i_2++) { + // gemm[i, i_1] = ReduceOp((gemm[i, i_1]) + (A[i, i_2]) * (B[i_2, + // i_1]), reduce_args={i_2}); + // } + // } + // } + // for (int i_3 = 0; i_3 < 4; i_3++) { + // for (int i_4 = 0; i_4 < 4; i_4++) { + // relu[i_3, i_4] = (gemm[i_3, i_4]) a_v(M, K, "a"); + PaddedBuffer b_v(K, N, "b"); + PaddedBuffer o1(M, N, "e_before"); + PaddedBuffer o2(M, N, "e_after"); + + for (const auto m : c10::irange(M)) { + for (const auto k : c10::irange(K)) { + a_v(m, k) = at::randn({1}).item().to(); + } + } + + for (const auto k : c10::irange(K)) { + for (const auto n : c10::irange(N)) { + b_v(k, n) = at::randn({1}).item().to(); + } + } + + cg.call({a_v, b_v, o1}); + +#ifdef TORCH_ENABLE_LLVM + LLVMCodeGen cg_llvm(Stmt::clone(l.root_stmt()), {AP, BP, FT}); + + checkIR(cg_llvm.stmt(), R"IR( +# CHECK: Allocate(gemm); // dtype=int16_t, dims=[4, 4] +# CHECK: Allocate(relu); // dtype=int16_t, dims=[4, 4] +# CHECK: Allocate(E); // dtype=float, dims=[4, 4] +# CHECK: Free(E); +# CHECK: Free(relu); +# CHECK: Free(gemm))IR"); + + cg_llvm.call({a_v, b_v, o2}); + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + ExpectAllNear(o1, o2, 1e-5); +#endif +} + +TEST(MemPlanning, SameBufSizeMemReuse) { + int M = 1024; + int N = 1024; + int K = 2048; + + BufHandle AP("A", {M, K}, kFloat); + BufHandle BP("B", {K, N}, kFloat); + + Tensor CT = Reduce( + "gemm", + {M, N}, + Sum(), + [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { + return AP.load(m, k) * BP.load(k, n); + }, + {K}); + Tensor DT = + Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + auto zero = Cast::make(CT.buf()->dtype(), 0); + return CompareSelect::make( + CT.load(m, n), zero, zero, CT.load(m, n), kLT); + }); + Tensor ET = + Compute("add", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return DT.load(m, n) + DT.load(m, n); + }); + Tensor FT = + Compute("mul", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return ET.load(m, n) * ET.load(m, n); + }); + auto stmt = Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()}); + + // Constructed stmt: + // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2], + // add [2, 3] Buffer 'gemm' and 'add' are the same size; we'll reuse 'gemm' + // for 'add'. + //{ + // for (int M = 0; M < 1024; M++) { + // for (int N = 0; N < 1024; N++) { + // gemm[M, N] = float(0); + // for (int K = 0; K < 2048; K++) { + // gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]), + // reduce_args={K}); + // } + // } + // } + // for (int M_1 = 0; M_1 < 1024; M_1++) { + // for (int N_1 = 0; N_1 < 1024; N_1++) { + // relu[M_1, N_1] = (gemm[M_1, N_1])dtype(), 0); + return CompareSelect::make( + CT.load(m, n), zero, zero, CT.load(m, n), kLT); + }); + Tensor ET = + Compute("add", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return DT.load(m, n) + DT.load(m, n); + }); + Tensor FT = + Compute("mul", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return ET.load(m, n) * ET.load(m, n); + }); + Tensor GT = + Compute("sub", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return FT.load(m, n) - ET.load(m, n); + }); + + auto stmt = + Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt(), GT.stmt()}); + + // Constructed stmt: + // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2], + // add [2, 3], mul [3, 4] Buffer 'gemm', 'relu, ''add' and 'mul' are the same + // size; we'll reuse 'gemm' for 'add', and reuse 'relu' for 'mul' + //{ + // for (int M = 0; M < 1024; M++) { + // for (int N = 0; N < 1024; N++) { + // gemm[M, N] = float(0); + // for (int K = 0; K < 2048; K++) { + // gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]), + // reduce_args={K}); + // } + // } + // } + // for (int M_1 = 0; M_1 < 1024; M_1++) { + // for (int N_1 = 0; N_1 < 1024; N_1++) { + // relu[M_1, N_1] = (gemm[M_1, N_1])dtype(), 0); + return CompareSelect::make( + CT.load(m, n), zero, zero, CT.load(m, n), kLT); + }); + Tensor ET = + Compute("add", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return DT.load(m, n) + DT.load(m, n); + }); + Tensor FT = + Compute("mul", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return ET.load(m, n) * ET.load(m, n); + }); + Tensor GT = + Compute("sub", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return FT.load(m, n) - 1; + }); + Tensor HT = + Compute("div", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) { + return GT.load(m, n) / 2; + }); + + auto stmt = Block::make( + {CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt(), GT.stmt(), HT.stmt()}); + + // Constructed stmt: + // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2], + // add [2, 3], mul [3, 4], sub [4, 5] Buffer 'gemm', 'relu, ''add', 'mul' and + // 'sub' are the same size; we'll reuse 'gemm' for 'add', reuse 'relu' for + // 'mul', and reuse 'gemm' for 'sub'. + //{ + // for (int M = 0; M < 1024; M++) { + // for (int N = 0; N < 1024; N++) { + // gemm[M, N] = float(0); + // for (int K = 0; K < 2048; K++) { + // gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]), + // reduce_args={K}); + // } + // } + // } + // for (int M_1 = 0; M_1 < 1024; M_1++) { + // for (int N_1 = 0; N_1 < 1024; N_1++) { + // relu[M_1, N_1] = (gemm[M_1, N_1])dtype(), 0); + return CompareSelect::make( + CT.load(m, n), zero, zero, CT.load(m, n), kLT); + }); + Tensor ET = Compute( + "add", {M * 2, N * 2}, [&](const ExprHandle& em, const ExprHandle& en) { + return DT.load(em / 2, en / 2) + DT.load(em / 2, en / 2); + }); + Tensor FT = Compute( + "mul", {M * 2, N * 2}, [&](const ExprHandle& fm, const ExprHandle& fn) { + return ET.load(fm, fn) * ET.load(fm, fn); + }); + auto stmt = Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()}); + + // Constructed stmt: + // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2], + // add [2, 3] We do not reuse buffer 'gemm' for 'add' because the size of + // buffer 'gemm' is smaller. + //{ + // for (int M = 0; M < 1024; M++) { + // for (int N = 0; N < 1024; N++) { + // gemm[M, N] = float(0); + // for (int K = 0; K < 2048; K++) { + // gemm[M, N] = ReduceOp((gemm[M, N]) + (A[M, K]) * (B[K, N]), + // reduce_args={K}); + // } + // } + // } + // for (int M_1 = 0; M_1 < 1024; M_1++) { + // for (int N_1 = 0; N_1 < 1024; N_1++) { + // relu[M_1, N_1] = (gemm[M_1, N_1]) +#include +#include +#include +#include +#include + +using namespace torch::jit::tensorexpr; + +using Tensors = std::vector; +using Args = std::vector; +std::unique_ptr compile( + const Args& inputs, + const Tensors& outputs) { + LoopNest nest({outputs}); + nest.prepareForCodegen(); + nest.simplify(); + auto join = inputs; + join.insert(join.end(), outputs.begin(), outputs.end()); + return std::make_unique(nest.root_stmt(), join); +} + +TEST(Ops, Sum) { + constexpr int M = 8; + constexpr int N = 16; + std::vector testDims = {{0}, {1}, {0, 1}}; + std::vector> outputShapes = {{N}, {M}, {}}; + for (unsigned idx = 0; idx < testDims.size(); idx++) { + const auto& dims = testDims[idx]; + const auto& outShape = outputShapes[idx]; + + BufHandle a("a", {M, N}, kFloat); + std::vector outStrides = + c10::fmap(make_contiguous_strides(outShape)); + Tensor b = computeSum( + {a, dims, false}, outShape, outStrides, c10::kFloat, at::kCPU); + auto cg = compile({a}, {b}); + + auto at = at::arange(M * N, at::kFloat).view({M, N}); + auto ref = at::sum(at, dims); + auto bt = at::empty_like(ref); + + cg->call({at.data_ptr(), bt.data_ptr()}); + + ASSERT_TRUE(at::allclose(bt, ref)); + } +} + +TEST(Ops, ChannelsLastSum) { + constexpr int A = 2; + constexpr int B = 3; + constexpr int C = 4; + constexpr int D = 5; + constexpr int E = 6; + std::vector testDims = {{0}, {1}, {0, 1}}; + + std::vector> outputShapes = { + {B, C, D, E}, {A, C, D, E}, {C, D, E}}; + for (unsigned idx = 0; idx < testDims.size(); idx++) { + const auto& dims = testDims[idx]; + const auto& outShape = outputShapes[idx]; + + BufHandle a("a", {A, B, C, D, E}, kFloat); + std::vector outStrides = + c10::fmap(make_channels_last_strides(outShape)); + Tensor b = computeSum( + {a, dims, false}, outShape, outStrides, c10::kFloat, at::kCPU); + auto cg = compile({a}, {b}); + + auto at = at::arange(A * B * C * D * E, at::kFloat).view({A, B, C, D, E}); + auto ref = at::sum(at, dims); + auto bt = at::empty_like(ref); + + cg->call({at.data_ptr(), bt.data_ptr()}); + + ASSERT_TRUE(at::allclose(bt, ref)); + } +} diff --git a/test/cpp/tensorexpr/test_quantization.cpp b/test/cpp/tensorexpr/test_quantization.cpp new file mode 100644 index 0000000000000..af6b539ff33e9 --- /dev/null +++ b/test/cpp/tensorexpr/test_quantization.cpp @@ -0,0 +1,452 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "torch/csrc/jit/tensorexpr/eval.h" +#include "torch/csrc/jit/tensorexpr/ir.h" + +namespace torch { +namespace jit { + +using namespace torch::jit::tensorexpr; +using SimpleIRExprEval = ExprEval; +using namespace torch::indexing; +using namespace torch::jit::tensorexpr; + +class Quantization : public ::testing::Test { + public: + void SetUp() override { + getTEMustUseLLVMOnCPU() = false; + } +}; + +TEST_F(Quantization, QuantDequantInt8) { + const auto graph_string = R"IR( + graph(%x.1 : Float(2, 2, strides=[2, 1], device=cpu)): + %2 : int = prim::Constant[value=12]() + %3 : int = prim::Constant[value=13]() + %4 : float = prim::Constant[value=0.1]() + %q.1 : QInt8(2, 2) = aten::quantize_per_tensor(%x.1, %4, %3, %2) + %6 : Float(2, 2) = aten::dequantize(%q.1) + return (%6))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto x = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto q = at::quantize_per_tensor(x, 0.1f, 13, at::kQInt8); + auto y_expected = at::dequantize(q); + TensorExprKernel k(graph); + std::vector inputs = {x}; + StmtPtr s = k.getCodeGenStmt(); + + std::vector stack = fmap(inputs); + k.run(stack); + auto y = stack[0].toTensor(); + bool check = at::allclose(y_expected, y); + if (!check) { + std::cout << "y_expected:\n" << y_expected << std::endl; + std::cout << "y:\n" << y << std::endl; + } + TORCH_CHECK_EQ(check, 1); +} + +TEST_F(Quantization, QuantDequantUInt8) { + const auto graph_string = R"IR( + graph(%x.1 : Float(2, 2, strides=[2, 1], device=cpu)): + %2 : int = prim::Constant[value=13]() + %3 : int = prim::Constant[value=122]() + %4 : float = prim::Constant[value=0.1]() + %q.1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x.1, %4, %3, %2) + %6 : Float(2, 2) = aten::dequantize(%q.1) + return (%6))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto x = 2 * at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto q = at::quantize_per_tensor(x, 0.1f, 122, at::kQUInt8); + auto y_expected = at::dequantize(q); + TensorExprKernel k(graph); + std::vector inputs = {x}; + StmtPtr s = k.getCodeGenStmt(); + + std::vector stack = fmap(inputs); + k.run(stack); + auto y = stack[0].toTensor(); + bool check = at::allclose(y_expected, y); + if (!check) { + std::cout << "y_expected:\n" << y_expected << std::endl; + std::cout << "y:\n" << y << std::endl; + } + TORCH_CHECK_EQ(check, 1); +} + +TEST_F(Quantization, QuantDequantUInt8_NLC) { + const auto graph_string = R"IR( + graph(%x.1 : Float(1, 2, 2, strides=[4, 1, 2], device=cpu)): + %2 : int = prim::Constant[value=13]() + %3 : int = prim::Constant[value=122]() + %4 : float = prim::Constant[value=0.1]() + %q.1 : QUInt8(1, 2, 2) = aten::quantize_per_tensor(%x.1, %4, %3, %2) + %6 : Float(1, 2, 2) = aten::dequantize(%q.1) + return (%6))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto x = 2 * at::rand({1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + x.unsafeGetTensorImpl()->set_sizes_and_strides( + std::initializer_list{1, 2, 2}, {4, 1, 2}); + auto q = at::quantize_per_tensor(x, 0.1f, 122, at::kQUInt8); + auto y_expected = at::dequantize(q); + TensorExprKernel k(graph); + std::vector inputs = {x}; + StmtPtr s = k.getCodeGenStmt(); + + std::vector stack = fmap(inputs); + k.run(stack); + auto y = stack[0].toTensor(); + bool check = at::allclose(y_expected, y); + if (!check) { + std::cout << "x:\n" << x << std::endl; + std::cout << "y_expected:\n" << y_expected << std::endl; + std::cout << "y:\n" << y << std::endl; + } + TORCH_CHECK_EQ(check, 1); +} + +at::Tensor quantized_add( + at::Tensor x1, + at::Tensor x2, + double scale, + int64_t zero) { + const auto qadd_op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("quantized::add", "") + .typed(); + return qadd_op.call(x1, x2, scale, zero); +} + +TEST_F(Quantization, QuantAddDequantInt8) { + const auto graph_string = R"IR( + graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu), %x2 : Float(2, 2, strides=[2, 1], device=cpu)): + %2 : int = prim::Constant[value=12]() + %qz1 : int = prim::Constant[value=13]() + %qs1 : float = prim::Constant[value=0.1]() + %qz2 : int = prim::Constant[value=13]() + %qs2 : float = prim::Constant[value=0.1]() + %qza : int = prim::Constant[value=13]() + %qsa : float = prim::Constant[value=0.1]() + %q1 : QInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2) + %q2 : QInt8(2, 2) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2) + %qa : QInt8(2, 2) = quantized::add(%q1, %q2, %qsa, %qza) + %6 : Float(2, 2) = aten::dequantize(%qa) + return (%6))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto x2 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQInt8); + auto q2 = at::quantize_per_tensor(x2, 0.1f, 13, at::kQInt8); + auto qa = quantized_add(q1, q2, 0.1f, 13); + auto y_expected = at::dequantize(qa); + TensorExprKernel k(graph); + std::vector inputs = {x1, x2}; + StmtPtr s = k.getCodeGenStmt(); + + std::vector stack = fmap(inputs); + k.run(stack); + auto y = stack[0].toTensor(); + bool check = at::allclose(y_expected, y); + if (!check) { + std::cout << "x1:\n" << x1 << std::endl; + std::cout << "q1:\n" << q1 << std::endl; + std::cout << "x2:\n" << x2 << std::endl; + std::cout << "q2:\n" << q2 << std::endl; + std::cout << "y_expected:\n" << y_expected << std::endl; + std::cout << "y:\n" << y << std::endl; + } + TORCH_CHECK_EQ(check, 1); +} + +TEST_F(Quantization, QuantAddDequantUInt8) { + const auto graph_string = R"IR( + graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu), %x2 : Float(2, 2, strides=[2, 1], device=cpu)): + %2 : int = prim::Constant[value=13]() + %qz1 : int = prim::Constant[value=13]() + %qs1 : float = prim::Constant[value=0.1]() + %qz2 : int = prim::Constant[value=13]() + %qs2 : float = prim::Constant[value=0.1]() + %qza : int = prim::Constant[value=13]() + %qsa : float = prim::Constant[value=0.1]() + %q1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2) + %q2 : QUInt8(2, 2) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2) + %qa : QUInt8(2, 2) = quantized::add(%q1, %q2, %qsa, %qza) + %6 : Float(2, 2) = aten::dequantize(%qa) + return (%6))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto x2 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQUInt8); + auto q2 = at::quantize_per_tensor(x2, 0.1f, 13, at::kQUInt8); + auto qa = quantized_add(q1, q2, 0.1f, 13); + auto y_expected = at::dequantize(qa); + + TensorExprKernel k(graph); + std::vector inputs = {x1, x2}; + StmtPtr s = k.getCodeGenStmt(); + + std::vector stack = fmap(inputs); + k.run(stack); + auto y = stack[0].toTensor(); + bool check = at::allclose(y_expected, y); + if (!check) { + std::cout << "x1:\n" << x1 << std::endl; + std::cout << "q1:\n" << q1 << std::endl; + std::cout << "x2:\n" << x2 << std::endl; + std::cout << "q2:\n" << q2 << std::endl; + std::cout << "y_expected:\n" << y_expected << std::endl; + std::cout << "y:\n" << y << std::endl; + } + TORCH_CHECK_EQ(check, 1); +} + +TEST_F(Quantization, QuantSigmoidDequantUInt8) { + const auto graph_string = R"IR( + graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu)): + %2 : int = prim::Constant[value=13]() + %qz1 : int = prim::Constant[value=13]() + %qs1 : float = prim::Constant[value=0.1]() + %q1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2) + %qa : QUInt8(2, 2) = aten::sigmoid(%q1) + %6 : Float(2, 2) = aten::dequantize(%qa) + return (%6))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQUInt8); + auto qs = at::sigmoid(q1); + auto y_expected = at::dequantize(qs); + + TensorExprKernel k(graph); + std::vector inputs = {x1}; + StmtPtr s = k.getCodeGenStmt(); + + std::vector stack = fmap(inputs); + k.run(stack); + auto y = stack[0].toTensor(); + bool check = at::allclose(y_expected, y); + if (!check) { + std::cout << "x1:\n" << x1 << std::endl; + std::cout << "q1:\n" << q1 << std::endl; + std::cout << "qs:\n" << qs << std::endl; + std::cout << "y_expected:\n" << y_expected << std::endl; + std::cout << "y:\n" << y << std::endl; + } + TORCH_CHECK_EQ(check, 1); +} + +at::Tensor quantized_mul( + at::Tensor x1, + at::Tensor x2, + double scale, + int64_t zero) { + const auto op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("quantized::mul", "") + .typed(); + return op.call(x1, x2, scale, zero); +} + +TEST_F(Quantization, QuantMulDequantUInt8) { + const auto graph_string = R"IR( + graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu), %x2 : Float(2, 2, strides=[2, 1], device=cpu)): + %2 : int = prim::Constant[value=13]() + %qz1 : int = prim::Constant[value=13]() + %qs1 : float = prim::Constant[value=0.1]() + %qz2 : int = prim::Constant[value=13]() + %qs2 : float = prim::Constant[value=0.1]() + %qza : int = prim::Constant[value=13]() + %qsa : float = prim::Constant[value=0.1]() + %q1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2) + %q2 : QUInt8(2, 2) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2) + %qa : QUInt8(2, 2) = quantized::mul(%q1, %q2, %qsa, %qza) + %6 : Float(2, 2) = aten::dequantize(%qa) + return (%6))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto x2 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQUInt8); + auto q2 = at::quantize_per_tensor(x2, 0.1f, 13, at::kQUInt8); + auto qa = quantized_mul(q1, q2, 0.1f, 13); + auto y_expected = at::dequantize(qa); + + TensorExprKernel k(graph); + std::vector inputs = {x1, x2}; + StmtPtr s = k.getCodeGenStmt(); + + std::vector stack = fmap(inputs); + k.run(stack); + auto y = stack[0].toTensor(); + bool check = at::allclose(y_expected, y); + if (!check) { + std::cout << "x1:\n" << x1 << std::endl; + std::cout << "q1:\n" << q1 << std::endl; + std::cout << "x2:\n" << x2 << std::endl; + std::cout << "q2:\n" << q2 << std::endl; + std::cout << "y_expected:\n" << y_expected << std::endl; + std::cout << "y:\n" << y << std::endl; + } + TORCH_CHECK_EQ(check, 1); +} + +TEST_F(Quantization, QuantUpsampleNearst2dDequantUInt8) { + const auto graph_string = R"IR( + graph(%x : Float(1, 1, 4, 4, strides=[16, 16, 4, 1], device=cpu)): + %2 : int = prim::Constant[value=13]() + %4 : NoneType = prim::Constant() + %3 : int[] = prim::Constant[value=[6, 6]]() + %qz : int = prim::Constant[value=13]() + %qs : float = prim::Constant[value=0.1]() + %q : QUInt8(1, 1, 4, 4) = aten::quantize_per_tensor(%x, %qs, %qz, %2) + %qu : QUInt8(1, 1, 6, 6) = aten::upsample_nearest2d(%q, %3, %4) + %6 : Float(1, 1, 6, 6) = aten::dequantize(%qu) + return (%6))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto x = at::rand({1, 1, 4, 4}, TensorOptions(kCPU).dtype(at::kFloat)); + auto q = at::quantize_per_tensor(x, 0.1f, 13, at::kQUInt8); + auto qu = at::upsample_nearest2d(q, {6, 6}); + auto y_expected = at::dequantize(qu); + + TensorExprKernel k(graph); + std::vector inputs = {x}; + StmtPtr s = k.getCodeGenStmt(); + + std::vector stack = fmap(inputs); + k.run(stack); + auto y = stack[0].toTensor(); + bool check = at::allclose(y_expected, y); + if (!check) { + std::cout << "x:\n" << x << std::endl; + std::cout << "q:\n" << q << std::endl; + std::cout << "qu:\n" << qu << std::endl; + std::cout << "y_expected:\n" << y_expected << std::endl; + std::cout << "y:\n" << y << std::endl; + } + TORCH_CHECK_EQ(check, 1); +} + +TEST_F(Quantization, UpsampleNearst2d) { + const auto graph_string = R"IR( + graph(%x : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu)): + %4 : NoneType = prim::Constant() + %3 : int[] = prim::Constant[value=[4, 4]]() + %u : Float(1, 1, 4, 4) = aten::upsample_nearest2d(%x, %3, %4) + return (%u))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto x = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto y_expected = at::upsample_nearest2d(x, {4, 4}); + + TensorExprKernel k(graph); + std::vector inputs = {x}; + StmtPtr s = k.getCodeGenStmt(); + + std::vector stack = fmap(inputs); + k.run(stack); + auto y = stack[0].toTensor(); + bool check = at::allclose(y_expected, y); + if (!check) { + std::cout << "x:\n" << x << std::endl; + std::cout << "y_expected:\n" << y_expected << std::endl; + std::cout << "y:\n" << y << std::endl; + } + TORCH_CHECK_EQ(check, 1); +} + +at::Tensor quantized_cat( + c10::List const& xs, + int64_t dim, + double scale, + int64_t zero) { + const auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("quantized::cat", "") + .typed const&, + int64_t, + std::optional, + std::optional)>(); + return op.redispatch( + DispatchKeySet({DispatchKey::QuantizedCPU}), xs, dim, scale, zero); +} + +TEST_F(Quantization, QuantCatDequantUInt8) { + const auto graph_string = R"IR( + graph(%x : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu), %y : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu), %z : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu)): + %qdt : int = prim::Constant[value=13]() + %qxz : int = prim::Constant[value=13]() + %qxs : float = prim::Constant[value=0.1]() + %qyz : int = prim::Constant[value=16]() + %qys : float = prim::Constant[value=0.15]() + %qzz : int = prim::Constant[value=19]() + %qzs : float = prim::Constant[value=0.2]() + %qx : QUInt8(1, 1, 2, 2) = aten::quantize_per_tensor(%x, %qxs, %qxz, %qdt) + %qy : QUInt8(1, 1, 2, 2) = aten::quantize_per_tensor(%y, %qys, %qyz, %qdt) + %qz : QUInt8(1, 1, 2, 2) = aten::quantize_per_tensor(%z, %qzs, %qzz, %qdt) + %catx : Tensor[] = prim::ListConstruct(%qx, %qy, %qz) + %catd : int = prim::Constant[value=0]() + %qcat : QUInt8(3, 1, 2, 2) = quantized::cat(%catx, %catd, %qxs, %qxz) + %cat : Float(3, 1, 2, 2) = aten::dequantize(%qcat) + return (%cat))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto x = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto y = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto z = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat)); + auto qx = at::quantize_per_tensor(x, 0.1f, 13, at::kQUInt8); + auto qy = at::quantize_per_tensor(y, 0.15f, 16, at::kQUInt8); + auto qz = at::quantize_per_tensor(z, 0.2f, 19, at::kQUInt8); + auto qcat = quantized_cat({qx, qy, qz}, 0, 0.1f, 13); + auto expected = at::dequantize(qcat); + + TensorExprKernel k(graph); + std::vector inputs = {x, y, z}; + StmtPtr s = k.getCodeGenStmt(); + + std::vector stack = fmap(inputs); + k.run(stack); + auto result = stack[0].toTensor(); + bool check = at::allclose(expected, result); + if (!check) { + std::cout << "x:\n" << x << std::endl; + std::cout << "y:\n" << y << std::endl; + std::cout << "z:\n" << z << std::endl; + std::cout << "qx:\n" << qx << std::endl; + std::cout << "qy:\n" << qy << std::endl; + std::cout << "qz:\n" << qz << std::endl; + std::cout << "qcat:\n" << qcat << std::endl; + std::cout << "expected:\n" << expected << std::endl; + std::cout << "result:\n" << result << std::endl; + } + TORCH_CHECK_EQ(check, 1); +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_reductions.cpp b/test/cpp/tensorexpr/test_reductions.cpp new file mode 100644 index 0000000000000..bdc744ae4e033 --- /dev/null +++ b/test/cpp/tensorexpr/test_reductions.cpp @@ -0,0 +1,1928 @@ +#include + +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +using namespace torch::jit::tensorexpr; + +TEST(Reductions, ReduceSum0D_1) { + const int M = 10; + + BufHandle b("b", {M}, kFloat); + std::vector in(M); + for (const auto j : c10::irange(M)) { + in[j] = j; + } + + std::vector out(M, -1.f); + + Tensor c = Reduce("sum", {M}, Sum(), b, {}); + LoopNest loop({c}); + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {b, c}); + + cg.call({in, out}); + for (const auto i : c10::irange(M)) { + ASSERT_EQ(out[i], in[i]); + } +} + +TEST(Reductions, ReduceSum0D_2) { + BufHandle b("b", {}, kFloat); + std::vector in(1); + in[0] = 77.7; + + std::vector out(1, -1.f); + + Tensor c = Reduce("sum", {}, Sum(), b, {}); + LoopNest loop({c}); + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {b, c}); + + cg.call({in, out}); + ASSERT_EQ(out[0], in[0]); +} + +// Sum an array to a single value. +TEST(Reductions, ReduceSum1D) { + BufHandle b("b", {10}, kFloat); + std::vector in(10); + for (const auto j : c10::irange(10)) { + in[j] = j; + } + + std::vector out(1, -1.f); + + Tensor c = Reduce("sum", {}, Sum(), b, {10}); + LoopNest loop({c}); + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {b, c}); + + cg.call({in, out}); + ASSERT_EQ(out[0], 45); +} +// Sum a 2D tensor to a 1D tensor with dynamic shapes. +TEST(Reductions, ReduceSum2D) { + const int M = 3; + const int N = 7; + + VarHandle m("m", kInt); + VarHandle n("n", kInt); + + BufHandle b("b", {m, n}, kFloat); + std::vector in(M * N); + for (const auto i : c10::irange(M)) { + for (const auto j : c10::irange(N)) { + in[i * N + j] = j; + } + } + + std::vector out(M, -1.f); + + Tensor c = Reduce("sum", {M}, Sum(), b, {N}); + LoopNest loop({c}); + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {b, c, n, m}); + + cg.call({in, out, 5, 7}); + + float expected = 0; + for (const auto i : c10::irange(N)) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + expected += i; + } + + for (const auto i : c10::irange(M)) { + ASSERT_EQ(out[i], expected); + } +} + +// Sum a 3D tensor to both a 2D and 1D tensor, then reduce the 2D tensor flat to +// check our work. +TEST(Reductions, ReduceSum3D) { + const int M = 10; + VarHandle m("m", kInt); + + BufHandle b("b", {2, 3, m}, kFloat); + + Tensor c = Reduce("sum", {2, 3}, Sum(), b, {m}); + LoopNest loop({c}); + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {b, c, m}); + + std::vector bData(2 * 3 * M, 0); + std::vector cData(2 * 3, 6.0f); + std::vector dData(2, 1.0f); + std::vector eData(2, 1.0f); + + for (int i = 0; i < 2 * 3; ++i) { + for (const auto j : c10::irange(M)) { + bData[i * M + j] = j; + } + } + + cg.call({bData, cData, M}); + float expected = 0; + for (const auto i : c10::irange(M)) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + expected += i; + } + + for (int i = 0; i < 2 * 3; ++i) { + ASSERT_EQ(cData[i], expected); + } + + Tensor d = Reduce("sum2", {2}, Sum(), b, {3, m}); + LoopNest loop2({d}); + loop2.prepareForCodegen(); + StmtPtr s2 = loop2.root_stmt(); + s2 = IRSimplifier::simplify(s2); + + SimpleIREvaluator cg2(s2, {b, d, m}); + cg2.call({bData, dData, M}); + + // We're combining an additional dimension of 3, so the sum is 3x. + expected = expected * 3; + + for (const auto i : c10::irange(2)) { + ASSERT_EQ(dData[i], expected); + } + + // This is the same as just reducing the original result across that axis. + BufHandle c_buf(c.buf()); + Tensor e = Reduce("sum3", {2}, Sum(), c_buf, {3}); + LoopNest loop3({e}); + loop3.prepareForCodegen(); + StmtPtr s3 = loop3.root_stmt(); + s3 = IRSimplifier::simplify(s3); + + SimpleIREvaluator cg3(s3, {c, e}); + cg3.call({cData, eData}); + + for (const auto i : c10::irange(2)) { + ASSERT_EQ(eData[i], expected); + } +} + +// Sum a large (10 D) Tensor 5 dimensions in. +TEST(Reductions, ReduceSum10D) { + BufHandle in_("in_", {2, 3, 2, 3, 2, 3, 2, 3, 2, 3}, kFloat); + const int InputSize = 2 * 3 * 2 * 3 * 2 * 3 * 2 * 3 * 2 * 3; + BufHandle out_("out_", {2, 3, 2, 3, 2}, kFloat); + const int OutputSize = 2 * 3 * 2 * 3 * 2; + + std::vector in(InputSize, 1.f); + std::vector out(OutputSize, -1.f); + + Tensor c = Reduce("sum", {2, 3, 2, 3, 2}, Sum(), in_, {3, 2, 3, 2, 3}); + LoopNest loop({c}); + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {in_, c}); + + cg.call({in, out}); + + // NOLINTNEXTLINE(bugprone-integer-division) + float expected = InputSize / OutputSize; + for (const auto i : c10::irange(OutputSize)) { + ASSERT_EQ(out[i], expected); + } +} + +// Reduce via Mul rather than Add using a custom Reducer. +TEST(Reductions, ReduceProduct) { + const int M = 4; + const int N = 4; + + BufHandle b("b", {M, N}, kFloat); + std::vector in(M * N); + for (const auto i : c10::irange(M)) { + for (const auto j : c10::irange(N)) { + in[i * N + j] = 2 + j; + } + } + + std::vector out(M, -1.f); + + Reducer product( + ExprHandle(1.f), [](ExprHandle a, ExprHandle b) { return a * b; }); + + Tensor c = Reduce("product", {M}, product, b, {N}); + LoopNest loop({c}); + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {b, c}); + + cg.call({in, out}); + + float expected = 1; + for (const auto i : c10::irange(N)) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + expected *= 2 + i; + } + + for (const auto i : c10::irange(M)) { + ASSERT_EQ(out[i], expected); + } +} + +// Maximum reductions. +TEST(Reductions, ReduceMax) { + BufHandle in_("b", {10}, kFloat); + + std::vector in(10); + std::vector out(1, -1.f); + for (const auto j : c10::irange(10)) { + in[j] = j; + } + + Tensor dm1 = Reduce("max", {}, Maximum(kFloat), in_, {10}); + + LoopNest loop({dm1}); + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + SimpleIREvaluator cg(s, {in_, dm1}); + + cg.call({in, out}); + + ASSERT_EQ(out[0], 9); + + BufHandle in2_("b", {2, 5}, kFloat); + std::vector out2(2, -1.f); + + Tensor m2d = Reduce("max", {2}, Maximum(kFloat), in2_, {5}); + + LoopNest loop2({m2d}); + loop2.prepareForCodegen(); + s = loop2.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg2(s, {in2_, m2d}); + cg2.call({in, out2}); + + ASSERT_EQ(out2[0], 4); + ASSERT_EQ(out2[1], 9); +} + +// Minimum reduction, with custom initialization. +TEST(Reductions, ReduceMinCustomInitializer) { + VarHandle minInit("minInit", kFloat); + BufHandle in_("b", {10}, kFloat); + + std::vector in(10); + std::vector out(1, -1.f); + for (const auto j : c10::irange(10)) { + in[j] = 10 + j; + } + + Tensor min = Reduce( + "min", + {}, + Minimum(ExprHandle(minInit)), + [&](ParameterList& v) { return in_.load(v); }, + {10}); + + LoopNest loop({min}); + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {in_, min, minInit}); + + // Works normally (note that out data starts lower than the correct + // minimum). + cg.call({in, out, std::numeric_limits::max()}); + ASSERT_EQ(out[0], 10); + + // With an initalizer lower than the min, that's the min. + cg.call({in, out, 5.f}); + ASSERT_EQ(out[0], 5); +} + +// Example implementation of Any/All. +// TODO: this is very awkward without logical And/Or operators. +TEST(Reductions, ReduceAnyAll) { + VarHandle searchValue("searchValue", kInt); + BufHandle b("b", {4, 10}, kInt); + + Reducer anyEqSV(ExprHandle(0), [](ExprHandle a, ExprHandle b) { + return CompareSelect::make(a, 1, 1, b, kEQ); + }); + + Tensor any = Reduce( + "anyEqual", + {4}, + anyEqSV, + [&](const auto& i, const auto& j) { + return CompareSelect::make(b.load(i, j), searchValue, kEQ); + }, + {10}); + + LoopNest loop({any}); + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {b, any, searchValue}); + + std::vector in(40, 0); + std::vector out(4, 0); + + // input has 0-39 in 4 rows. + for (const auto i : c10::irange(40)) { + in[i] = i; + } + cg.call({in, out, 1}); + + // only the first row has 1 + ASSERT_EQ(out[0], 1); + ASSERT_EQ(out[1], 0); + ASSERT_EQ(out[2], 0); + ASSERT_EQ(out[3], 0); + + cg.call({in, out, 15}); + + // 15 in the 3rd row + ASSERT_EQ(out[0], 0); + ASSERT_EQ(out[1], 1); + ASSERT_EQ(out[2], 0); + ASSERT_EQ(out[3], 0); + + Reducer allGTSV(ExprHandle(1), [](ExprHandle a, ExprHandle b) { + return CompareSelect::make(a, 0, 0, b, kEQ); + }); + + Tensor allGreaterThan = Reduce( + "allGreaterThan", + {4}, + allGTSV, + [&](const auto& i, const auto& j) { + return CompareSelect::make(b.load(i, j), searchValue, kGT); + }, + {10}); + + LoopNest loop2({allGreaterThan}); + loop2.prepareForCodegen(); + s = loop2.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg2(s, {b, allGreaterThan, searchValue}); + + cg2.call({in, out, 11}); + + // 11 is in row 2. + ASSERT_EQ(out[0], 0); + ASSERT_EQ(out[1], 0); + ASSERT_EQ(out[2], 1); + ASSERT_EQ(out[3], 1); + + cg2.call({in, out, -3}); + + // All are positive. + ASSERT_EQ(out[0], 1); + ASSERT_EQ(out[1], 1); + ASSERT_EQ(out[2], 1); + ASSERT_EQ(out[3], 1); +} + +TEST(Reductions, ReduceMatmul2D) { + BufHandle tA("tA", {3, 2}, kFloat); + BufHandle tB("tB", {2, 3}, kFloat); + + std::vector tA_(6); + std::vector tB_(6); + + std::vector out(9, -1.f); + for (const auto i : c10::irange(3)) { + for (const auto j : c10::irange(2)) { + tA_[i * 2 + j] = i * 2 + j; + tB_[j * 3 + i] = i * 2 + j; + } + } + + Tensor mm = Reduce( + "mm", + {3, 3}, + Sum(), + [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) { + return tA.load(m, k) * tB.load(k, n); + }, + {2}); + + LoopNest loop({mm}); + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {tA, tB, mm}); + cg.call({tA_, tB_, out}); + + std::vector expected( + {1.f, 3.f, 5.f, 3.f, 13.f, 23.f, 5.f, 23.f, 41.f}); + + for (const auto i : c10::irange(9)) { + ASSERT_EQ(out[i], expected[i]); + } +} + +TEST(Reductions, ReduceRfactorLike) { + BufHandle in("in", {10, 10}, kFloat); + std::vector in_(100); + for (const auto i : c10::irange(100)) { + in_[i] = i; + } + std::vector in_rf_(10, -2.f); + std::vector out(1, -1.f); + + Tensor l1 = Reduce("l1", {10}, Sum(), in, {10}); + BufHandle in_rf(l1.buf()); + + Tensor l2 = Reduce("l2", {}, Sum(), in_rf, {10}); + + LoopNest loop({l1, l2}); + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {in, l1, l2}); + cg.call({in_, in_rf_, out}); + + ASSERT_EQ(out[0], 99 * 50); +} + +TEST(Reductions, ReduceAsProducer) { + const int M = 10; + VarHandle m("m", kInt); + + BufHandle a("a", {2, 3}, kFloat); + BufHandle b("b", {2, 3, m}, kFloat); + + Tensor c = Reduce("sum", {2, 3}, Sum(), b, {m}); + Tensor d = + Compute("scale", {2, 3}, [&](const VarHandle& l, const VarHandle& n) { + return c.load(l, n) * a.load(l, n); + }); + LoopNest loop({d}, {c, d}); + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {a, b, d, m}); + + std::vector aData(2 * 3, 0); + std::vector bData(2 * 3 * M, 0); + std::vector dData(2 * 3, 6.0f); + + for (int i = 0; i < 2 * 3; ++i) { + aData[i] = 6 - i; + for (const auto j : c10::irange(M)) { + bData[i * M + j] = j; + } + } + + cg.call({aData, bData, dData, M}); + float expected = 0; + for (const auto i : c10::irange(M)) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + expected += i; + } + for (int i = 0; i < 2 * 3; ++i) { + ASSERT_EQ(dData[i], expected * (6 - i)); + } +} + +TEST(Reductions, ReduceAsConsumer) { + const int M = 10; + VarHandle m("m", kInt); + + BufHandle a("a", {2, 3, m}, kFloat); + BufHandle b("b", {2, 3, m}, kFloat); + + Tensor c = Compute( + "scale", + {2, 3, m}, + [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { + return b.load(l, n, m) * a.load(l, n, m); + }); + Tensor d = Reduce("sum", {2}, Sum(), c, {3, m}); + LoopNest loop({d}, {c, d}); + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {a, b, d, m}); + + std::vector aData(2 * 3 * M, 0); + std::vector bData(2 * 3 * M, 0); + std::vector dData(2, 6.0f); + + for (int i = 0; i < 2 * 3; ++i) { + for (const auto j : c10::irange(M)) { + bData[i * M + j] = j + 1; + aData[i * M + j] = 6 - i; + } + } + + cg.call({aData, bData, dData, M}); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) + float expected[2] = {0, 0}; + for (const auto i : c10::irange(2)) { + for (const auto j : c10::irange(3)) { + for (const auto k : c10::irange(M)) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + expected[i] += (k + 1) * (6 - (i * 3 + j)); + } + } + } + + for (const auto i : c10::irange(2)) { + ASSERT_EQ(dData[i], expected[i]); + } +} + +TEST(Reductions, SplitReduceAxis) { + BufHandle in("in", {16, 8}, kFloat); + + std::vector in_(16 * 8); + for (const auto i : c10::irange(16)) { + for (const auto j : c10::irange(8)) { + in_[i * 8 + j] = i; + } + } + std::vector out(16, -1.f); + + Tensor tensor = Reduce("sum", {16}, Sum(), in, {8}); + LoopNest l({tensor}); + std::vector loops = l.getLoopStmtsFor(tensor); + LoopNest::splitWithTail(loops[1], 2); + + l.prepareForCodegen(); + + StmtPtr s = l.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {in, tensor}); + cg.call({in_, out}); + + for (const auto i : c10::irange(16)) { + ASSERT_EQ(out[i], i * 8); + } +} + +TEST(Reductions, SplitNonReduceAxis) { + BufHandle in("in", {16, 8}, kFloat); + + std::vector in_(16 * 8); + for (const auto i : c10::irange(16)) { + for (const auto j : c10::irange(8)) { + in_[i * 8 + j] = i; + } + } + std::vector out(16, -1.f); + Tensor tensor = Reduce("sum", {16}, Sum(), in, {8}); + LoopNest l({tensor}); + std::vector loops = l.getLoopStmtsFor(tensor); + LoopNest::splitWithTail(loops[0], 2); + LoopNest::splitWithTail(loops[0], 2); + + l.prepareForCodegen(); + + StmtPtr s = l.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {in, tensor}); + cg.call({in_, out}); + + for (const auto i : c10::irange(16)) { + ASSERT_EQ(out[i], i * 8); + } +} + +TEST(Reductions, ReorderedReductionInitializer) { + /* From the quip: + for k in 0..1: // blockIdx + for m in 0..128: + for n in 0..64: // threadIdx + SumOp(c(k, n), 0, a(k, m, n), {m}) + */ + + BufHandle in("in", {1, 12, 6}, kFloat); + std::vector in_(12 * 6, 1.f); + + Tensor tensor_ = Reduce("sum", {1, 12}, Sum(), in, {6}); + LoopNest l_({tensor_}); + + l_.prepareForCodegen(); + StmtPtr s_ = Stmt::clone(l_.root_stmt()); + s_ = IRSimplifier::simplify(s_); + + Tensor tensor = Reduce("sum", {1, 12}, Sum(), in, {6}); + LoopNest l({tensor}); + + auto loops = l.getLoopStmtsFor(tensor); + loops[0]->set_gpu_block_index(0); + loops[1]->set_gpu_thread_index(0); + + LoopNest::reorderAxis(loops[1], loops[2]); + + StmtPtr s = l.root_stmt(); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + s = IRSimplifier::simplify(s); + + l.prepareForCodegen(); + + s = l.root_stmt(); + s = IRSimplifier::simplify(s); + + std::vector out1(16, -1.f); + SimpleIREvaluator cg(s_, {in, tensor_}); + cg.call({in_, out1}); + + std::vector out2(16, -1.f); + SimpleIREvaluator cg2(s, {in, tensor}); + cg2.call({in_, out2}); + + for (const auto i : c10::irange(16)) { + ASSERT_EQ(out1[i], out2[i]); + } +} + +TEST(Reductions, ReduceRfactor) { + const int M = 10; + const int N = 10; + VarHandle m("m", kInt); + VarHandle n("n", kInt); + + BufHandle b("b", {m, n}, kFloat); + std::vector in(M * N); + for (int j = 0; j < M * N; ++j) { + in[j] = j; + } + + std::vector out(1, -1.f); + + Tensor c = Reduce("sum", {}, Sum(), b, {m, n}); + LoopNest loop({c}); + std::vector loops = loop.getLoopStmtsFor(c); + auto c_body = loop.getAllWritesToBuf(c.buf())[1]; + ASSERT_TRUE(loop.rfactor(c_body, loops.at(0))); + auto rc = NodeFinder::find(loop.root_stmt()); + ASSERT_EQ(rc.size(), 2); + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {b, c, m, n}); + + cg.call({in, out, M, N}); + ASSERT_EQ(out[0], 4950); +} + +TEST(Reductions, Reduce3DRfactorInner) { + const int M = 10; + const int N = 10; + const int K = 10; + VarHandle m("m", kInt); + VarHandle n("n", kInt); + VarHandle k("k", kInt); + + BufHandle b("b", {m, n, k}, kFloat); + std::vector in(M * N * K); + for (int j = 0; j < M * N * K; ++j) { + in[j] = j; + } + + std::vector out(1, -1.f); + + Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k}); + LoopNest loop({c}); + std::vector loops = loop.getLoopStmtsFor(c); + auto c_body = loop.getAllWritesToBuf(c.buf())[1]; + ASSERT_FALSE(loop.rfactor(c_body, loops.at(2))); + auto rc = NodeFinder::find(loop.root_stmt()); + ASSERT_EQ(rc.size(), 1); + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {b, c, m, n, k}); + + cg.call({in, out, M, N, K}); + ASSERT_EQ(out[0], 499500); +} + +TEST(Reductions, Reduce3DRfactorOuter) { + const int M = 10; + const int N = 10; + const int K = 10; + VarHandle m("m", kInt); + VarHandle n("n", kInt); + VarHandle k("k", kInt); + + BufHandle b("b", {m, n, k}, kFloat); + std::vector in(M * N * K); + for (int j = 0; j < M * N * K; ++j) { + in[j] = j; + } + + std::vector out(1, -1.f); + + Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k}); + LoopNest loop({c}); + std::vector loops = loop.getLoopStmtsFor(c); + auto c_body = loop.getAllWritesToBuf(c.buf())[1]; + ASSERT_TRUE(loop.rfactor(c_body, loops.at(0))); + auto rc = NodeFinder::find(loop.root_stmt()); + ASSERT_EQ(rc.size(), 2); + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {b, c, m, n, k}); + cg.call({in, out, M, N, K}); + ASSERT_EQ(out[0], 499500); +} + +TEST(Reductions, ReduceRepeatedInternalRfactor) { + BufHandle in_("in_", {2, 3, 4, 5, 6}, kFloat); + const int InputSize = 2 * 3 * 4 * 5 * 6; + + std::vector in(InputSize, 1.f); + std::vector out(1, -1.f); + std::vector ref(1, -1.f); + + Tensor c = Reduce("sum", {}, Sum(), in_, {2, 3, 4, 5, 6}); + LoopNest orig_loop({c}); + + // Try rfactoring N outer loops + for (const auto rfac_number : c10::irange(1, 5)) { + LoopNest refloop(orig_loop); + LoopNest loop(orig_loop); + refloop.prepareForCodegen(); + SimpleIREvaluator ref_cg( + IRSimplifier::simplify(refloop.root_stmt()), {in_, c}); + ref_cg.call({in, ref}); + + BufPtr tmp_buf = c.buf(); + + for (const auto idx : c10::irange(rfac_number)) { + auto reduce = loop.getAllWritesToBuf(tmp_buf)[1]; + ASSERT_TRUE(loop.rfactor( + reduce, loop.getLoopStmtsFor(tmp_buf).at(idx), &tmp_buf)); + } + + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {in_, c}); + cg.call({in, out}); + + ASSERT_EQ(ref[0], out[0]); + } +} + +// Split a reduction axis with a tail loop. +TEST(Reductions, ReduceSplitTail) { + const int M = 10; + const int N = 10; + const int K = 10; + + BufHandle b("b", {M, N, K}, kFloat); + std::vector in(M * N * K); + for (int j = 0; j < M * N * K; ++j) { + in[j] = j; + } + + for (const auto i : c10::irange(3)) { + std::vector out(M, -1.f); + + Tensor c = Reduce("sum", {M}, Sum(), b, {N, K}); + LoopNest loop({c}); + std::vector loops = loop.getLoopStmtsFor(c); + LoopNest::splitWithTail(loops[i], 8); + + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {b, c}); + + cg.call({in, out}); + ASSERT_EQ(out[0], 4950); + } +} + +// Split a reduction axis cleanly so there is no tail loop. +TEST(Reductions, ReduceSplitNoTail) { + const int M = 10; + const int N = 10; + const int K = 10; + BufHandle b("b", {M, N, K}, kFloat); + std::vector in(M * N * K); + for (int j = 0; j < M * N * K; ++j) { + in[j] = j; + } + + for (const auto i : c10::irange(3)) { + std::vector out(M, -1.f); + + Tensor c = Reduce("sum", {M}, Sum(), b, {N, K}); + LoopNest loop({c}); + std::vector loops = loop.getLoopStmtsFor(c); + LoopNest::splitWithTail(loops[i], 5); + + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {b, c}); + + cg.call({in, out}); + ASSERT_EQ(out[0], 4950); + } +} + +// Split a reduction axis with only a tail loop (the split loop will be size 0 +// and eliminated out). +TEST(Reductions, ReduceOverSplitTail) { + const int M = 10; + const int N = 10; + const int K = 10; + + BufHandle b("b", {M, N, K}, kFloat); + std::vector in(M * N * K); + for (int j = 0; j < M * N * K; ++j) { + in[j] = j; + } + + for (const auto i : c10::irange(3)) { + std::vector out(M, -1.f); + + Tensor c = Reduce("sum", {M}, Sum(), b, {N, K}); + LoopNest loop({c}); + std::vector loops = loop.getLoopStmtsFor(c); + LoopNest::splitWithTail(loops[i], 16); + + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {b, c}); + + cg.call({in, out}); + ASSERT_EQ(out[0], 4950); + } +} + +// Split a reduction axis with a mask. +TEST(Reductions, ReduceSplitMask) { + const int M = 10; + const int N = 10; + const int K = 10; + + BufHandle b("b", {M, N, K}, kFloat); + std::vector in(M * N * K); + for (int j = 0; j < M * N * K; ++j) { + in[j] = j; + } + + for (const auto i : c10::irange(3)) { + std::vector out(M, -1.f); + + Tensor c = Reduce("sum", {M}, Sum(), b, {N, K}); + LoopNest loop({c}); + std::vector loops = loop.getLoopStmtsFor(c); + LoopNest::splitWithMask(loops[i], 8); + + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {b, c}); + + cg.call({in, out}); + ASSERT_EQ(out[0], 4950); + } +} + +// Split a reduction axis cleanly not requiring a mask. +TEST(Reductions, ReduceSplitNoMask) { + const int M = 10; + const int N = 10; + const int K = 10; + BufHandle b("b", {M, N, K}, kFloat); + std::vector in(M * N * K); + for (int j = 0; j < M * N * K; ++j) { + in[j] = j; + } + + for (const auto i : c10::irange(3)) { + std::vector out(M, -1.f); + + Tensor c = Reduce("sum", {M}, Sum(), b, {N, K}); + LoopNest loop({c}); + std::vector loops = loop.getLoopStmtsFor(c); + LoopNest::splitWithMask(loops[i], 5); + + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {b, c}); + + cg.call({in, out}); + ASSERT_EQ(out[0], 4950); + } +} + +// Split a reduction axis with all logic in the mask. +TEST(Reductions, ReduceOverSplitMask) { + const int M = 10; + const int N = 10; + const int K = 10; + + BufHandle b("b", {M, N, K}, kFloat); + std::vector in(M * N * K); + for (int j = 0; j < M * N * K; ++j) { + in[j] = j; + } + + for (const auto i : c10::irange(3)) { + std::vector out(M, -1.f); + + Tensor c = Reduce("sum", {M}, Sum(), b, {N, K}); + LoopNest loop({c}); + std::vector loops = loop.getLoopStmtsFor(c); + LoopNest::splitWithMask(loops[i], 16); + + loop.prepareForCodegen(); + StmtPtr s = loop.root_stmt(); + s = IRSimplifier::simplify(s); + + SimpleIREvaluator cg(s, {b, c}); + + cg.call({in, out}); + ASSERT_EQ(out[0], 4950); + } +} + +// Test an rfactor when there are two ReduceOps in the graph due to a +// splitWithTail. +TEST(Reductions, ReduceSplitRfactor) { + const int M = 2; + const int N = 10; + const int K = 10; + const int SPLIT_FACTOR = 4; + + BufHandle b("b", {M, N, K}, kFloat); + std::vector in(M * N * K); + for (const auto m : c10::irange(M)) { + for (int j = 0; j < N * K; ++j) { + in[m * N * K + j] = j; + } + } + + std::vector out(M, -1.f); + + Tensor c = Reduce("sum", {M}, Sum(), b, {N, K}); + LoopNest loop({c}); + std::vector loops = loop.getLoopStmtsFor(c); + LoopNest::splitWithTail(loops[2], SPLIT_FACTOR); + + auto c_body = loop.getAllWritesToBuf(c.buf())[2]; + auto all_loops = loop.getAllLoopNestsWritingToBuf(c.buf()); + ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(2).size() == 3); + LoopNest::reorderAxis(all_loops[2][1], all_loops[2][2]); + all_loops = loop.getAllLoopNestsWritingToBuf(c.buf()); + ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(2).size() == 3); + ASSERT_TRUE(loop.rfactor(c_body, all_loops[2][1])); + loop.prepareForCodegen(); + loop.simplify(); + StmtPtr s = loop.root_stmt(); + + SimpleIREvaluator cg(s, {b, c}); + + cg.call({in, out}); + for ([[maybe_unused]] const auto i : c10::irange(M)) { + ASSERT_EQ(out[0], 4950); + } +} + +// Test an rfactor which ends up being eliminated since the total loop size is +// smaller than the split factor. +TEST(Reductions, ReduceOverSplitRfactor) { + const int N = 10; + const int K = 10; + const int SPLIT_FACTOR = 16; + + BufHandle b("b", {N, K}, kFloat); + std::vector in(N * K); + for (int j = 0; j < N * K; ++j) { + in[j] = j; + } + + std::vector out(1, -1.f); + + Tensor c = Reduce("sum", {}, Sum(), b, {N, K}); + LoopNest loop({c}); + std::vector loops = loop.getLoopStmtsFor(c); + ForPtr i, t; + LoopNest::splitWithTail(loops[1], SPLIT_FACTOR, &i, &t); + LoopNest::reorderAxis(loops[0], i); + + auto all_loops = loop.getAllLoopNestsWritingToBuf(c.buf()); + ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(1).size() == 3); + auto c_body = loop.getAllWritesToBuf(c.buf())[1]; + ASSERT_TRUE(loop.rfactor(c_body, all_loops[1][0])); + LoopNest::reorderAxis(all_loops[1][0], all_loops[1][2]); + + loop.prepareForCodegen(); + loop.simplify(); + StmtPtr s = loop.root_stmt(); + + SimpleIREvaluator cg(s, {b, c}); + + cg.call({in, out}); + ASSERT_EQ(out[0], 4950); + + std::ostringstream oss; + oss << *cg.stmt(); + + // Check the IR to verify the rfactored reduce is eliminated. + // TODO: The alloc free should be eliminated here since it is size 0. + /* + const std::string& verification_pattern = + R"IR( +# CHECK: Allocate(tmp_buf); // dtype=float, dims=[0] +# CHECK: sum[0] = 0.f; +# CHECK: for (int n = 0; n < 10; n++) { +# CHECK: for (int k_tail = 0; k_tail < 10; k_tail++) { +# CHECK: sum[0] = (sum[0]) + (b[k_tail + 10 * n]); +# CHECK: } +# CHECK: } +# CHECK: Free(tmp_buf);)IR"; + */ + // TODO: rfactor output is not consistent yet, will fix (@nickg). + // torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Reductions, ReduceInlineReduction) { + const int M = 4; + const int N = 5; + const int K = 6; + + BufHandle a_buf("a", {M}, kFloat); + BufHandle b_buf("b", {M, N, K}, kFloat); + + Tensor x = Reduce("x", {M}, Sum(), b_buf, {N, K}); + Tensor y = Compute( + "y", {M}, [&](const VarHandle& m) { return a_buf.load(m) + x.load(m); }); + + PaddedBuffer a_v(M); + PaddedBuffer b_v(M, N, K); + + for (const auto i : c10::irange(M)) { + a_v(i) = i * i; + } + for (const auto i : c10::irange(M)) { + for (const auto j : c10::irange(N)) { + for (const auto k : c10::irange(K)) { + b_v(i, j, k) = j * j * k; + } + } + } + + LoopNest l1({y}, {x, y}); + // Cannot inline a reduction computation + ASSERT_FALSE(l1.computeInline(x.buf())); +} + +TEST(Reductions, ReduceInlineConsumer) { + const int M = 4; + const int N = 5; + const int K = 6; + + BufHandle a_buf("a", {M, N, K}, kFloat); + BufHandle b_buf("b", {M, N, K}, kFloat); + + Tensor x = Compute( + "x", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf.load(m, n, k) + b_buf.load(m, n, k); + }); + Tensor y = Reduce("y", {M}, Sum(), x, {N, K}); + + PaddedBuffer a_v(M, N, K); + PaddedBuffer b_v(M, N, K); + + for (const auto i : c10::irange(M)) { + for (const auto j : c10::irange(N)) { + for (const auto k : c10::irange(K)) { + a_v(i, j, k) = i * i + k; + b_v(i, j, k) = j * j + k; + } + } + } + + LoopNest l1({y}, {x, y}); + LoopNest l2(l1); + l2.computeInline(x.buf()); + + l1.prepareForCodegen(); + l2.prepareForCodegen(); + + StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); + StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt()); + + SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y}); + SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y}); + + PaddedBuffer y_1(M); + PaddedBuffer y_2(M); + + eval1(a_v, b_v, y_1); + eval2(a_v, b_v, y_2); + ExpectAllNear(y_1, y_2, 1e-5); + std::ostringstream oss1, oss2; + oss1 << *stmt1; + oss2 << *stmt2; + ASSERT_GT(oss1.str().size(), oss2.str().size()); +} + +TEST(Reductions, ReduceInlineReducerInternal) { + const int M = 4; + const int N = 5; + const int K = 6; + + BufHandle a_buf("a", {M, N, K}, kFloat); + BufHandle b_buf("b", {M, N, K}, kFloat); + + Tensor x = Compute( + "x", + {M, N, K}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf.load(m, n, k) + b_buf.load(m, n, k); + }); + + Reducer minimum(ExprHandle(0.f), [&](ExprHandle a, ExprHandle b) { + return Add::make(ExprHandle(1.f), Min::make(a, b, false)); + }); + Tensor y = Reduce("y", {M}, minimum, x, {N, K}); + + PaddedBuffer a_v(M, N, K); + PaddedBuffer b_v(M, N, K); + + for (const auto i : c10::irange(M)) { + for (const auto j : c10::irange(N)) { + for (const auto k : c10::irange(K)) { + a_v(i, j, k) = i * i + k; + b_v(i, j, k) = j * j + k; + } + } + } + + LoopNest l1({y}, {x, y}); + LoopNest l2(l1); + l2.computeInline(x.buf()); + + l1.prepareForCodegen(); + l2.prepareForCodegen(); + + StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt()); + StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt()); + + SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y}); + SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y}); + + PaddedBuffer y_1(M); + PaddedBuffer y_2(M); + + eval1(a_v, b_v, y_1); + eval2(a_v, b_v, y_2); + ExpectAllNear(y_1, y_2, 1e-5); + std::ostringstream oss1, oss2; + oss1 << *stmt1; + oss2 << *stmt2; + ASSERT_GT(oss1.str().size(), oss2.str().size()); +} + +TEST(Reductions, ReductionCacheAccessesOperatorAxis) { + int L = 4; + int N = 3; + int M = 2; + + BufHandle a("a", {L, N, M}, kFloat); + BufHandle b("b", {L, N, M}, kFloat); + + Tensor c = Compute( + "scale", + {L, N, M}, + [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { + return b.load(l, n, m) * a.load(l, n, m); + }); + Tensor d = Reduce("sum", {L}, Sum(), c, {N, M}); + + Tensor e = Compute("scale", {L}, [&](const VarHandle& l) { + return b.load(0, 0, l) * d.load(l); + }); + + LoopNest l({e}, {c, d, e}); + LoopNest l_before(l); + l_before.prepareForCodegen(); + SimpleIREvaluator cg_before( + LoopNest::sanitizeNames(l_before.root_stmt()), {a, b, e}); + + StmtPtr d_loop = l.getLoopStmtsFor(d)[0]; + l.cacheAccesses(d.buf(), "d_local", d_loop); + l.prepareForCodegen(); + + StmtPtr result = + LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); + SimpleIREvaluator cg_after(result, {a, b, e}); + + std::ostringstream oss; + oss << *cg_after.stmt(); + const std::string& expected_ir = + R"IR( +#CHECK: Allocate(d_local); // dtype=float, dims=[4] +#CHECK: for (int i_2 +#CHECK: d_local[i_2] = 0.f +#CHECK: for (int +#CHECK: for (int +#CHECK: d_local[i_2] = (d_local[i_2]) + (scale[ +#CHECK: } +#CHECK: } +#CHECK: } +#CHECK: for (int i_3 +#CHECK: sum[i_3] = d_local[i_3] +#CHECK: Free(d_local); +#CHECK-NOT: d_local + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + PaddedBuffer a_v(L, M, N, "a"); + PaddedBuffer b_v(L, M, N, "b"); + PaddedBuffer c_v(L, M, N, "c"); + PaddedBuffer d_v(L, "d"); + PaddedBuffer e_before(L, "e_before"); + PaddedBuffer e_after(L, "e_after"); + + for (const auto l : c10::irange(L)) { + for (const auto m : c10::irange(M)) { + for (const auto n : c10::irange(N)) { + a_v(l, m, n) = at::randn({1}).item().to(); + b_v(l, m, n) = at::randn({1}).item().to(); + } + } + } + + cg_before.call({a_v, b_v, e_before}); + cg_after.call({a_v, b_v, e_after}); + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + ExpectAllNear(e_before, e_after, 1e-5); +} + +TEST(Reductions, ReductionCacheAccessesOuterReduceAxis) { + int L = 4; + int N = 3; + int M = 2; + + BufHandle a("a", {L, N, M}, kFloat); + BufHandle b("b", {L, N, M}, kFloat); + + Tensor c = Compute( + "scale", + {L, N, M}, + [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { + return b.load(l, n, m) * a.load(l, n, m); + }); + Tensor d = Reduce("sum", {L}, Sum(), c, {N, M}); + + Tensor e = Compute("scale", {L}, [&](const VarHandle& l) { + return b.load(0, 0, l) * d.load(l); + }); + + LoopNest l({e}, {c, d, e}); + LoopNest l_before(l); + l_before.prepareForCodegen(); + SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e}); + + StmtPtr d_loop = l.getLoopStmtsFor(d)[1]; + l.cacheAccesses(d.buf(), "d_local", d_loop); + l.prepareForCodegen(); + + StmtPtr result = + LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); + SimpleIREvaluator cg_after(result, {a, b, e}); + + std::ostringstream oss; + oss << *cg_after.stmt(); + const std::string& expected_ir = + R"IR( +#CHECK: Allocate(d_local); // dtype=float, dims=[1] +#CHECK: sum[i_1] = 0 +#CHECK: d_local[0] = sum[i_1] +#CHECK: for (int j_1 +#CHECK: for (int k_1 +#CHECK: d_local[0] = (d_local[0]) + (scale[ +#CHECK: } +#CHECK: } +#CHECK: sum[i_1] = d_local[0] +#CHECK: Free(d_local); +#CHECK-NOT: d_local + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + PaddedBuffer a_v(L, M, N, "a"); + PaddedBuffer b_v(L, M, N, "b"); + PaddedBuffer c_v(L, M, N, "c"); + PaddedBuffer d_v(L, "d"); + PaddedBuffer e_before(L, "e_before"); + PaddedBuffer e_after(L, "e_after"); + + for (const auto l : c10::irange(L)) { + for (const auto m : c10::irange(M)) { + for (const auto n : c10::irange(N)) { + a_v(l, m, n) = at::randn({1}).item().to(); + b_v(l, m, n) = at::randn({1}).item().to(); + } + } + } + + cg_before.call({a_v, b_v, e_before}); + cg_after.call({a_v, b_v, e_after}); + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + ExpectAllNear(e_before, e_after, 1e-5); +} + +TEST(Reductions, ReductionCacheAccessesInnerReduceAxis) { + int L = 4; + int N = 3; + int M = 2; + + BufHandle a("a", {L, N, M}, kFloat); + BufHandle b("b", {L, N, M}, kFloat); + + Tensor c = Compute( + "scale", + {L, N, M}, + [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { + return b.load(l, n, m) * a.load(l, n, m); + }); + Tensor d = Reduce("sum", {L}, Sum(), c, {N, M}); + + Tensor e = Compute("scale", {L}, [&](const VarHandle& l) { + return b.load(0, 0, l) * d.load(l); + }); + + LoopNest l({e}, {c, d, e}); + LoopNest l_before(l); + l_before.prepareForCodegen(); + SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e}); + + StmtPtr d_loop = l.getLoopStmtsFor(d)[2]; + l.cacheAccesses(d.buf(), "d_local", d_loop); + l.prepareForCodegen(); + + StmtPtr result = + LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); + SimpleIREvaluator cg_after(result, {a, b, e}); + + std::ostringstream oss; + oss << *cg_after.stmt(); + const std::string& expected_ir = + R"IR( +#CHECK: Allocate(d_local); // dtype=float, dims=[1] +#CHECK: sum[i_1] = 0 +#CHECK: for (int +#CHECK: d_local[0] = 0 +#CHECK: for (int +#CHECK: d_local[0] = (d_local[0]) + (scale[ +#CHECK: } +#CHECK: sum[i_1] = (sum[i_1]) + (d_local[0]) +#CHECK: } +#CHECK: Free(d_local); +#CHECK-NOT: d_local + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + PaddedBuffer a_v(L, M, N, "a"); + PaddedBuffer b_v(L, M, N, "b"); + PaddedBuffer c_v(L, M, N, "c"); + PaddedBuffer d_v(L, "d"); + PaddedBuffer e_before(L, "e_before"); + PaddedBuffer e_after(L, "e_after"); + + for (const auto l : c10::irange(L)) { + for (const auto m : c10::irange(M)) { + for (const auto n : c10::irange(N)) { + a_v(l, m, n) = at::randn({1}).item().to(); + b_v(l, m, n) = at::randn({1}).item().to(); + } + } + } + + cg_before.call({a_v, b_v, e_before}); + cg_after.call({a_v, b_v, e_after}); + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + ExpectAllNear(e_before, e_after, 1e-5); +} + +TEST(Reductions, ReductionCacheBodyAccess) { + BufHandle a("a", {24, 32, 12}, kFloat); + BufHandle b("b", {24, 32, 12}, kFloat); + + Tensor c = Compute( + "scale", + {24, 32, 12}, + [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { + return b.load(l, n, m) * a.load(l, n, m); + }); + Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12}); + + Tensor e = Compute("scale", {24}, [&](const VarHandle& l) { + return b.load(0, 0, l) * d.load(l); + }); + + LoopNest l({e}, {c, d, e}); + + StmtPtr d_loop = l.getLoopStmtsFor(d)[1]; + l.cacheAccesses(c.buf(), "scale_local", d_loop); + + l.prepareForCodegen(); + StmtPtr result = + LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); + SimpleIREvaluator cg(result, {a, b, e}); + + std::ostringstream oss; + oss << *cg.stmt(); + const std::string& expected_ir = + R"IR( +#CHECK: Allocate(scale_local); // dtype=float, dims=[1, 32, 12] +#CHECK: for (int j_1 = 0; j_1 < 32; j_1++) { +#CHECK: for (int k_1 = 0; k_1 < 12; k_1++) { +#CHECK: scale_local[k_1 + 12 * j_1] = scale[(k_1 + 12 * j_1) + 384 * i_1]; +#CHECK: sum[i_1] = (sum[i_1]) + (scale_local[k_2 + 12 * j_2]); +#CHECK: scale_1[i_2] = (b[i_2]) * (sum[i_2]); +#CHECK: Free(scale_local); + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +TEST(Reductions, ReductionCacheConsumerAccess) { + BufHandle a("a", {24, 32, 12}, kFloat); + BufHandle b("b", {24, 32, 12}, kFloat); + + Tensor c = Compute( + "scale", + {24, 32, 12}, + [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { + return b.load(l, n, m) * a.load(l, n, m); + }); + Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12}); + + Tensor e = Compute("scale", {24}, [&](const VarHandle& l) { + return b.load(0, 0, l) * d.load(l); + }); + + LoopNest l({e}, {c, d, e}); + + LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4); + + StmtPtr e_loop = l.getLoopStmtsFor(e)[1]; + l.cacheAccesses(d.buf(), "sum_local", e_loop); + l.prepareForCodegen(); + + StmtPtr result = + LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); + SimpleIREvaluator cg(result, {a, b, e}); + + std::ostringstream oss; + oss << *cg.stmt(); + const std::string& expected_ir = + R"IR( +#CHECK: Alias(sum_local,scale); +#CHECK: sum[i_1] = (sum[i_1]) + (scale[ +#CHECK: for (int j_2 = 0; j_2 < 4 +#CHECK: sum_local[j_2] = sum[j_2 + 4 * i_2]; +#CHECK: scale_1[j_3 + 4 * i_2] = (b[j_3 + 4 * i_2]) * (sum_local[j_3]); + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +TEST(Reductions, ReductionSplitCacheConsumerAccess) { + BufHandle a("a", {24, 32, 12}, kFloat); + BufHandle b("b", {24, 32, 12}, kFloat); + + Tensor c = Compute( + "scale", + {24, 32, 12}, + [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { + return b.load(l, n, m) * a.load(l, n, m); + }); + Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12}); + + Tensor e = Compute("scale", {24}, [&](const VarHandle& l) { + return b.load(0, 0, l) * d.load(l); + }); + + LoopNest l({e}, {c, d, e}); + + ForPtr inner; + + // Split outer reduction axis. + LoopNest::splitWithMask(l.getLoopStmtsFor(d)[0], 4, &inner); + + // Split reduction consumer. + LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4, &inner); + + l.cacheAccesses(d.buf(), "sum_local", inner); + l.prepareForCodegen(); + + StmtPtr result = + LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); + SimpleIREvaluator cg(result, {a, b, e}); + + // reduction changes but cache does not. + std::ostringstream oss; + oss << *cg.stmt(); + const std::string& expected_ir = + R"IR( +#CHECK: Alias(sum_local,scale); +#CHECK: sum[j_1 + 4 * i_1] = (sum[j_1 + 4 * i_1]) + (scale[((l + 12 * k_1) + 1536 * i_1) + 384 * j_1]); +#CHECK: for (int i_2 = 0; i_2 < 6 +#CHECK: for (int j_2 = 0; j_2 < 4 +#CHECK: sum_local[j_2] = sum[j_2 + 4 * i_2]; +#CHECK: for (int j_3 = 0; j_3 < 4 +#CHECK: scale_1[j_3 + 4 * i_2] = (b[j_3 + 4 * i_2]) * (sum_local[j_3]); + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +TEST(Reductions, ReductionReorderCacheConsumerAccess) { + BufHandle a("a", {24, 32, 12}, kFloat); + BufHandle b("b", {24, 32, 12}, kFloat); + + Tensor c = Compute( + "scale", + {24, 32, 12}, + [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) { + return b.load(l, n, m) * a.load(l, n, m); + }); + Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12}); + + Tensor e = Compute("scale", {24}, [&](const VarHandle& l) { + return b.load(0, 0, l) * d.load(l); + }); + + LoopNest l({e}, {c, d, e}); + + ForPtr inner; + + // reorder outer reduction axes. + auto loops = l.getLoopStmtsFor(d); + LoopNest::reorderAxis(loops[0], loops[1]); + + // Split reduction consumer. + LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4, &inner); + + l.cacheAccesses(d.buf(), "sum_local", inner); + l.prepareForCodegen(); + + StmtPtr result = + LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt())); + SimpleIREvaluator cg(result, {a, b, e}); + + // neither reduction body not cache changes. + std::ostringstream oss; + oss << *cg.stmt(); + const std::string& expected_ir = + R"IR( +#CHECK: sum[j_1] = (sum[j_1]) + (scale[(k_1 + 12 * i_2) + 384 * j_1]); +#CHECK: for (int i_3 = 0; i_3 < 6; +#CHECK: for (int j_2 = 0; j_2 < 4; +#CHECK: sum_local[j_2] = sum[j_2 + 4 * i_3]; +#CHECK: for (int j_3 = 0; j_3 < 4; +#CHECK: scale_1[j_3 + 4 * i_3] = (b[j_3 + 4 * i_3]) * (sum_local[j_3]); + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} + +TEST(Reductions, ReductionRfactorCacheTempOuter) { + const int M = 10; + const int N = 10; + const int K = 10; + VarHandle m("m", kInt); + VarHandle n("n", kInt); + VarHandle k("k", kInt); + + BufHandle b("B", {m, n, k}, kFloat); + std::vector in(M * N * K); + for (int j = 0; j < M * N * K; ++j) { + in[j] = j; + } + + std::vector out(1, -1.f); + + Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k}); + LoopNest loop({c}); + + std::vector loops = loop.getLoopStmtsFor(c); + LoopNest::reorderAxis(loops.at(0), loops.at(1)); + loops = loop.getLoopStmtsFor(c); + auto c_body = loop.getAllWritesToBuf(c.buf())[1]; + BufPtr rfac_buf; + ASSERT_TRUE(loop.rfactor(c_body, loops.at(0), &rfac_buf)); + loop.distributeLoop(loops.at(0)); + + auto all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf); + ASSERT_TRUE(all_loops.size() == 2 && all_loops.at(1).size() == 3); + LoopNest::reorderAxis(all_loops[1][0], all_loops[1][1]); + + all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf); + LoopNest::cacheAccesses(rfac_buf, "tmp", all_loops[1][1]); + loop.simplify(); + loop.prepareForCodegen(); + StmtPtr s = LoopNest::sanitizeNames(loop.root_stmt()); + SimpleIREvaluator cg(s, {b, c, m, n, k}); + + std::ostringstream oss; + oss << *cg.stmt(); + const std::string& expected_ir = + R"IR( +#CHECK: Allocate(sum_rfac); // dtype=float, dims=[n] +#CHECK: Allocate(tmp); // dtype=float, dims=[n] +#CHECK: for (int i_1 = 0; i_1 < m +#CHECK: for (int j = 0; j < n +#CHECK: tmp[j] = 0 +#CHECK: } +#CHECK: for (int j_1 = 0; j_1 < n +#CHECK: for (int k +#CHECK: tmp[j_1] = (tmp[j_1]) + (B[ +#CHECK: } +#CHECK: } +#CHECK: for (int j_2 = 0; j_2 < n +#CHECK: sum_rfac[j_2] = (sum_rfac[j_2]) + (tmp[j_2]); +#CHECK: } +#CHECK: Free(tmp); +#CHECK-NOT: tmp + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + cg.call({in, out, M, N, K}); + ASSERT_EQ(out[0], 499500); +} + +TEST(Reductions, ReductionRfactorCacheTempInner) { + const int M = 10; + const int N = 10; + const int K = 10; + VarHandle m("m", kInt); + VarHandle n("n", kInt); + VarHandle k("k", kInt); + + BufHandle b("B", {m, n, k}, kFloat); + std::vector in(M * N * K); + for (int j = 0; j < M * N * K; ++j) { + in[j] = j; + } + + std::vector out(1, -1.f); + + Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k}); + LoopNest loop({c}); + std::vector loops = loop.getLoopStmtsFor(c); + auto c_body = loop.getAllWritesToBuf(c.buf())[1]; + + LoopNest::reorderAxis(loops.at(0), loops.at(1)); + loops = loop.getLoopStmtsFor(c); + BufPtr rfac_buf; + ASSERT_TRUE(loop.rfactor(c_body, loops.at(0), &rfac_buf)); + loop.distributeLoop(loops.at(0)); + auto all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf); + ASSERT_TRUE(all_loops.size() == 2 && all_loops.at(1).size() == 3); + LoopNest::reorderAxis(all_loops[1][0], all_loops[1][1]); + + all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf); + ASSERT_TRUE(all_loops.size() == 2 && all_loops.at(1).size() == 3); + LoopNest::cacheAccesses(rfac_buf, "tmp", all_loops[1][2]); + loop.prepareForCodegen(); + loop.simplify(); + StmtPtr s = LoopNest::sanitizeNames(loop.root_stmt()); + SimpleIREvaluator cg(s, {b, c, m, n, k}); + + std::ostringstream oss; + oss << *cg.stmt(); + const std::string& expected_ir = + R"IR( +#CHECK: Allocate(sum_rfac); // dtype=float, dims=[n] +#CHECK: Allocate(tmp); // dtype=float, dims=[1] +#CHECK: for (int i_1 = 0; i_1 < m +#CHECK: for (int j = 0; j < n +#CHECK: tmp[0] = 0 +#CHECK: for (int k +#CHECK: tmp[0] = (tmp[0]) + (B[ +#CHECK: } +#CHECK: sum_rfac[j] = (sum_rfac[j]) + (tmp[0]); +#CHECK: Free(tmp); +#CHECK-NOT: tmp + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + cg.call({in, out, M, N, K}); + ASSERT_EQ(out[0], 499500); +} + +TEST(Reductions, ReductionVectorize) { + std::vector in_(8 * 8); + for (const auto i : c10::irange(8)) { + for (const auto j : c10::irange(8)) { + in_[i * 8 + j] = i; + } + } + std::vector out_before(8, -1.f); + std::vector out_after(8, -1.f); + + BufHandle in("in", {8, 8}, kFloat); + + Tensor tensor = Reduce("sum", {8}, Sum(), in, {8}); + LoopNest l_before({tensor}); + LoopNest l(l_before); + l_before.prepareForCodegen(); + SimpleIREvaluator cg_before(l_before.root_stmt(), {in, tensor}); + cg_before.call({in_, out_before}); + + ASSERT_TRUE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[0])); + + StmtPtr s = l.root_stmt(); + s = LoopNest::sanitizeNames(IRSimplifier::simplify(s)); + + std::ostringstream oss; + oss << *s; + const std::string& expected_ir = + R"IR( +#CHECK: sum[Ramp(0, 1, 8)] = Broadcast(0.f, 8); +#CHECK: for (int i = 0; i < 8; i++) { +#CHECK: sum[Ramp(0, 1, 8)] = ReduceOp((sum[Ramp(0, 1, 8)]) + (in[Ramp(i, 8, 8)]), reduce_args={i}); +#CHECK: } + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + // Vectorizing should not change result. + l.prepareForCodegen(); + s = IRSimplifier::simplify(l.root_stmt()); + SimpleIREvaluator cg_after(s, {in, tensor}); + cg_after.call({in_, out_after}); + for (const auto i : c10::irange(8)) { + ASSERT_EQ(out_before[i], out_after[i]); + } +} + +TEST(Reductions, ReductionVectorizeInner) { + BufHandle in("in", {8, 8}, kFloat); + + Tensor tensor = Reduce("sum", {8}, Sum(), in, {8}); + LoopNest l({tensor}); + + ASSERT_FALSE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[1])); +} + +TEST(Reductions, ReductionVectorizeRfactor) { + std::vector in_(8 * 8); + for (const auto i : c10::irange(8)) { + for (const auto j : c10::irange(8)) { + in_[i * 8 + j] = i; + } + } + std::vector out_before(1, -1.f); + std::vector out_after(1, -1.f); + + BufHandle in("in", {8, 8}, kFloat); + + Tensor tensor = Reduce("sum", {}, Sum(), in, {8, 8}); + + LoopNest l_before({tensor}); + LoopNest l(l_before); + l_before.prepareForCodegen(); + SimpleIREvaluator cg_before(l_before.root_stmt(), {in, tensor}); + cg_before.call({in_, out_before}); + + ASSERT_FALSE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[1])); + + // But if we rfactor this so it's not a reduce axis we can vectorize that + // loop. + std::vector loops = l.getLoopStmtsFor(tensor); + LoopNest::reorderAxis(loops[0], loops[1]); + loops = l.getLoopStmtsFor(tensor); + auto tensor_body = l.getAllWritesToBuf(tensor.buf())[1]; + BufPtr rfac_buf = nullptr; + ASSERT_TRUE(LoopNest::rfactor(tensor_body, loops.at(0), &rfac_buf)); + + LoopNest::distributeLoop(loops.at(0)); + auto rfac_loops = l.getAllLoopNestsWritingToBuf(rfac_buf); + + ASSERT_TRUE(LoopNest::vectorize(rfac_loops[1][0])); + l.simplify(); + + StmtPtr s = LoopNest::sanitizeNames(l.root_stmt()); + + std::ostringstream oss; + oss << *s; + const std::string& expected_ir = + R"IR( +#CHECK: sum = 0.f; +#CHECK: for (int i = 0; i < 8; i++) { +#CHECK: sum_rfac[i] = 0.f; +#CHECK: } +#CHECK: for (int i_1 = 0; i_1 < 8; i_1++) { +#CHECK: sum_rfac[Ramp(0, 1, 8)] = ReduceOp((sum_rfac[Ramp(0, 1, 8)]) + (in[Ramp(8 * i_1, 1, 8)]), reduce_args={i_1}); +#CHECK: } +#CHECK: for (int i_2 = 0; i_2 < 8; i_2++) { +#CHECK: sum = ReduceOp((sum) + (sum_rfac[i_2]), reduce_args={i_2}); +#CHECK: } + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); + + // Vectorizing should not change result. + l.prepareForCodegen(); + s = IRSimplifier::simplify(l.root_stmt()); + SimpleIREvaluator cg_after(s, {in, tensor}); + cg_after.call({in_, out_after}); + + ASSERT_EQ(out_before[0], out_after[0]); +} + +TEST(Reductions, InitFunction) { + constexpr int M = 32; + constexpr int N = 16; + BufHandle A("A", {M, N}, kFloat); + BufHandle B("B", {N}, kFloat); + Tensor C = Reduce( + "C", + {N}, + Sum(), + [&](const std::vector& v) { return B.load(v[0]); }, + [&](const std::vector& v) { return A.load(v[1], v[0]); }, + {M}); + LoopNest nest({C}); + nest.prepareForCodegen(); + StmtPtr s = LoopNest::sanitizeNames(IRSimplifier::simplify(nest.root_stmt())); + std::ostringstream oss; + oss << *s << "\n"; + const std::string& expected_ir = + R"IR( +#CHECK: for (int i = 0; i < 16; i++) { +#CHECK: C[i] = B[i]; +#CHECK: for (int j = 0; j < 32; j++) { +#CHECK: C[i] = (C[i]) + (A[i + 16 * j]); +#CHECK: } +#CHECK: } + )IR"; + torch::jit::testing::FileCheck().run(expected_ir, oss.str()); +} +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_registerizer.cpp b/test/cpp/tensorexpr/test_registerizer.cpp new file mode 100644 index 0000000000000..d6f5977789a9e --- /dev/null +++ b/test/cpp/tensorexpr/test_registerizer.cpp @@ -0,0 +1,3702 @@ +#include +#include "test/cpp/tensorexpr/test_base.h" + +#include "test/cpp/tensorexpr/test_utils.h" +#include "torch/csrc/jit/tensorexpr/ir_simplifier.h" +#include "torch/csrc/jit/tensorexpr/registerizer.h" + +#include + +namespace torch { +namespace jit { +using namespace torch::jit::tensorexpr; + +// Can replace a simple scalar access with a local variable. +TEST(Registerizer, RegisterizerSimple) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0}, 0), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}))}); + + /* + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + x; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 0; + * for (int x = 0; x < 10; x++) { + * A_1 = x + A_1; + * } + * A[0] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 0; +# CHECK: for (int x = 0; x < 10; x++) +# CHECK-NOT: A[ +# CHECK: A_1 = +# CHECK: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Won't do replacement of a loop access. +TEST(Registerizer, RegisterizerLoop) { + BufHandle a("A", {10}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0}, 0), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {x}, Add::make(Load::make(a, {x}), x))}))}); + + /* + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * A[x] = (A[x]) + x; + * } + */ + + // No change. + stmt = registerize(stmt); + + /* + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * A[x] = (A[x]) + x; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK-NOT: int +# CHECK: A[0] = 0; +# CHECK: for (int x = 0; x < 10; x++) +# CHECK-NOT: A_ +# CHECK: A[x] = +# CHECK-NOT: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Won't replace even if the load is a fixed scalar, since the store could +// invalidate it. +TEST(Registerizer, RegisterizerLoopFixedLoad) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0}, 0), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {x}, Add::make(Load::make(a, {0}), x))}))}); + + /* + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * A[x] = (A[0]) + x; + * } + */ + + // No change. + stmt = registerize(stmt); + + /* + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * A[x] = (A[0]) + x; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK-NOT: int +# CHECK: A[0] = 0; +# CHECK: for (int x = 0; x < 10; x++) +# CHECK-NOT: A_ +# CHECK: A[x] = +# CHECK-NOT: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// We can registerize accesses that occur entirely within inner scopes, even if +// they depend on the loop var. +TEST(Registerizer, RegisterizerLoopInternal) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make({For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {x}, Add::make(Load::make(a, {x}), x)), + Store::make(a, {x}, Add::make(Load::make(a, {x}), x))}))}); + + /* + * for (int x = 0; x < 10; x++) { + * A[x] = (A[x]) + x; + * A[x] = (A[x]) + x; + * } + */ + + stmt = registerize(stmt); + + // TODO: the order of terms in addition changes and in general depends on + // some hash value. This results in unpredictable swaps of the operands from + // random changes, which is not great. Ideally, we should ensure some + // specific order (ideally, the original one). + /* + * for (int x = 0; x < 10; x++) { + * int A_1 = A[x]; + * A_1 = x + A_1; + * A_1 = x + A_1; + * A[x] = A_1; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: for (int x = 0; x < 10; x++) +# CHECK: int A_1 = A[x]; +# CHECK: A_1 = A_1 + x; +# CHECK: A_1 = A_1 + x; +# CHECK: A[x] = A_1; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// An access can be overlapped by another read in the same Expr. In this case +// B[z] and B[y] overlap and prevent registerization of both accesses. +TEST(Registerizer, RegisterizerLoopInternalLoadOverlap) { + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + StmtPtr stmt = Block::make({For::make( + x, + 0, + 10, + Store::make(a, {x}, Add::make(Load::make(b, {y}), Load::make(b, {z}))))}); + stmt = IRSimplifier::simplify(stmt); + + /* + * for (int x = 0; x < 10; x++) { + * A[x] = (B[y]) + (B[z]); + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +TEST(Registerizer, RegisterizerLoopInternalRepeated) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {1}), x)), + Store::make(a, {0}, Add::make(Load::make(a, {1}), x))})), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {1}), x)), + Store::make(a, {0}, Add::make(Load::make(a, {1}), x))})) + + }); + + /* + * for (int x = 0; x < 10; x++) { + * A[0] = x + (A[1]); + * A[0] = x + (A[1]); + * } + * for (int x = 0; x < 10; x++) { + * A[0] = x + (A[1]); + * A[0] = x + (A[1]); + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[1]; + * int A_2 = A[0]; + * for (int x = 0; x < 10; x++) { + * A_2 = A_1 + x; + * A_2 = A_1 + x; + * } + * for (int x = 0; x < 10; x++) { + * A_2 = A_1 + x; + * A_2 = A_1 + x; + * } + * A[0] = A_2; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[1]; +# CHECK: int A_2 = A[0]; +# CHECK: for (int x = 0; x < 10; x++) +# CHECK: A_2 = A_1 + x; +# CHECK: A_2 = A_1 + x; +# CHECK: } +# CHECK: for (int x = 0; x < 10; x++) +# CHECK: A_2 = A_1 + x; +# CHECK: A_2 = A_1 + x; +# CHECK: } +# CHECK-NOT: A[1] +# CHECK: A[0] = A_2; +# CHECK-NOT: A[1] +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapLoopVar) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {x}), x)), + Store::make(a, {0}, Add::make(Load::make(a, {x}), x))})), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {x}), x)), + Store::make(a, {0}, Add::make(Load::make(a, {x}), x))})) + + }); + stmt = IRSimplifier::simplify(stmt); + + /* + * for (int x = 0; x < 10; x++) { + * A[0] = (A[x]) + x; + * A[0] = (A[x]) + x; + * } + * for (int x = 0; x < 10; x++) { + * A[0] = (A[x]) + x; + * A[0] = (A[x]) + x; + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapOther) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + StmtPtr stmt = IRSimplifier::simplify(Block::make( + {For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(x, Load::make(a, {y}))), + Store::make(a, {0}, Add::make(x, Load::make(a, {y})))})), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(x, Load::make(a, {y}))), + Store::make(a, {0}, Add::make(x, Load::make(a, {y})))})) + + })); + + /* + * for (int x = 0; x < 10; x++) { + * A[0] = (A[x]) + x; + * A[0] = (A[x]) + x; + * } + * for (int x = 0; x < 10; x++) { + * A[0] = (A[x]) + x; + * A[0] = (A[x]) + x; + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// Will registerize multiple accesses of different items of the same buffer. +TEST(Registerizer, RegisterizerMultiVar) { + BufHandle a("A", {2}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make({ + Store::make(a, {0}, 0), + Store::make(a, {1}, 0), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {0}), x)), + Store::make(a, {1}, Sub::make(Load::make(a, {1}), x))})), + }); + + /* + * A[0] = 0; + * A[1] = 0; + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + x; + * A[1] = (A[1]) - x; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 0; + * int A_2 = 0; + * for (int x = 0; x < 10; x++) { + * A_2 = x + A_2; + * A_1 = A_1 - x; + * } + * A[1] = A_2; + * A[0] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 0; +# CHECK: int A_2 = 0; +# CHECK: for (int x = 0; x < 10; x++) +# CHECK-NOT: A[ +# CHECK: A_1 = +# CHECK: A_2 = +# CHECK: A[1] = A_2 +# CHECK: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Will registerize the valid accesses while skipping invalid replacements. +TEST(Registerizer, RegisterizerVariableLoad) { + BufHandle a("A", {1}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + VarHandle x2("x", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0}, 0), + For::make(x, 0, 10, Store::make(b, {x}, x)), + For::make( + x2, + 0, + 10, + Block::make({Store::make( + a, {0}, Add::make(Load::make(a, {0}), Load::make(b, {x2})))}))}); + + /* + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * B[x] = x; + * } + * for (int x_1 = 0; x_1 < 10; x_1++) { + * A[0] = (A[0]) + (B[x_1]); + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 0; + * for (int x = 0; x < 10; x++) { + * B[x] = x; + * } + * for (int x_1 = 0; x_1 < 10; x_1++) { + * A_1 = A_1 + (B[x_1]); + * } + * A[0] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 0; +# CHECK: for (int x = 0; x < 10; x++) +# CHECK: B[x] = x +# CHECK: for (int x_1 = 0; x_1 < 10; x_1++) +# CHECK-NOT: A[ +# CHECK: A_1 = +# CHECK: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Can registerize variable accesses so long as the variable does not change. +TEST(Registerizer, RegisterizerSymbolicIndices) { + VarHandle i("i", kInt); + VarHandle N("N", kInt); + BufHandle a("A", {N}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {i}, 0), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {i}, Add::make(Load::make(a, {i}), x))}))}); + + /* + * A[i] = 0; + * for (int x = 0; x < 10; x++) { + * A[i] = (A[i]) + x; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 0; + * for (int x = 0; x < 10; x++) { + * A_1 = x + A_1; + * } + * A[i] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 0; +# CHECK: for (int x = 0; x < 10; x++) +# CHECK-NOT: A[ +# CHECK: A_1 = +# CHECK: A[i] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Can registerize accesses dependent on multiple loop vars. +TEST(Registerizer, RegisterizerMultiLoop) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0}, 0), + For::make( + x, + 0, + 10, + For::make( + y, + 0, + 10, + Block::make({Store::make( + a, + {0}, + Mul::make(Add::make(Load::make(a, {0}), x), y))})))}); + + /* + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * for (int y = 0; y < 10; y++) { + * A[0] = x * y + (A[0]) * y; + * } + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 0; + * for (int x = 0; x < 10; x++) { + * for (int y = 0; y < 10; y++) { + * A_1 = x * y + y * A_1; + * } + * } + * A[0] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 0; +# CHECK: for (int x = 0; x < 10; x++) +# CHECK: for (int y = 0; y < 10; y++) +# CHECK-NOT: A[ +# CHECK: A_1 = +# CHECK: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Can registerize correctly if scalars already exist in the program. +TEST(Registerizer, RegisterizerRepeated) { + BufHandle a("A", {2}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make({ + Store::make(a, {0}, 0), + Store::make(a, {1}, 0), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {0}), x)), + Store::make(a, {1}, Sub::make(Load::make(a, {1}), x))})), + }); + + // Registerize manually to make sure we only replace a single target. + { + registerizer::RegisterizerAnalysis analysis; + stmt->accept(&analysis); + auto candidates = analysis.getCandidates(); + ASSERT_EQ(candidates.size(), 2); + + candidates.pop_back(); + registerizer::RegisterizerReplacer replacer(candidates); + stmt = stmt->accept_mutator(&replacer); + } + + // Re-analyze and replace the second target. + { + registerizer::RegisterizerAnalysis analysis; + stmt->accept(&analysis); + auto candidates = analysis.getCandidates(); + ASSERT_EQ(candidates.size(), 1); + + registerizer::RegisterizerReplacer replacer(candidates); + stmt = stmt->accept_mutator(&replacer); + } + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 0; +# CHECK: int A_1_1 = 0; +# CHECK: for (int x = 0; x < 10; x++) +# CHECK-NOT: A[ +# CHECK: A_1 = +# CHECK: A_1_1 = +# CHECK: A[1] = A_1_1; +# CHECK: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Can registerize the load of A. +TEST(Registerizer, RegisterizerNoLoads) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0}, 0), + For::make( + x, 0, 10, Block::make({Store::make(a, {0}, Add::make(x, 1))}))}); + + /* + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * A[0] = x + 1; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 0; + * for (int x = 0; x < 10; x++) { + * A_1 = x + 1; + * } + * A[0] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 0; +# CHECK: for (int x = 0; x < 10; x++) +# CHECK-NOT: A[ +# CHECK: A_1 = +# CHECK: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Can registerize the load of A but not the store of B. +TEST(Registerizer, RegisterizerNoRepeatedStores) { + BufHandle a("A", {1}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0}, 0), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(b, {x}, Add::make(Load::make(a, {0}), x))}))}); + + /* + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * B[x] = (A[0]) + x; + * } + */ + + stmt = registerize(stmt); + + // TODO: its unnecessary to reorder the initializer of A[0], but it's not + // actually worse so lets not worry for now. + + /* + * int A_1 = 0; + * for (int x = 0; x < 10; x++) { + * B[x] = x + A_1; + * } + * A[0] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 0; +# CHECK: for (int x = 0; x < 10; x++) +# CHECK-NOT: A_ +# CHECK: B[x] = +# CHECK: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Won't registerize if there are multiple accesses which may overlap. +TEST(Registerizer, RegisterizerMultiVarOverlap) { + BufHandle a("A", {2}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make({ + Store::make(a, {0}, 0), + Store::make(a, {1}, 0), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {x}, Add::make(Load::make(a, {0}), x)), + Store::make(a, {x + 1}, Sub::make(Load::make(a, {1}), x))})), + }); + stmt = IRSimplifier::simplify(stmt); + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +TEST(Registerizer, RegisterizerAllocs) { + BufHandle a("A", {2}, kInt); + BufHandle c("C", {1}, kInt); + VarHandle x("x", kInt); + + BufHandle b("B", {Load::make(c, {0})}, kInt); + + StmtPtr stmt = Block::make( + {Allocate::make(b), + Store::make(a, {0}, Load::make(c, {0})), + Store::make(b, {0}, 0), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(b, {0}, Add::make(Load::make(b, {0}), x)), + Store::make(a, {0}, Load::make(c, {0}))})), + Free::make(b)}); + + /* + * Allocate(B, int, {C[0]}); + * A[0] = C[0]; + * B[0] = 0; + * for (int x = 0; x < 10; x++) { + * B[0] = (B[0]) + x; + * A[0] = C[0]; + * } + * Free(B); + */ + + stmt = registerize(stmt); + + /* + * int C_1 = C[0]; + * Allocate(B, int, {C_}); + * int A_1 = C_1; + * int B_1 = 0; + * for (int x = 0; x < 10; x++) { + * B_1 = B_1 + x; + * A_1 = C_1; + * } + * B[0] = B_1; + * A[0] = A_1; + * Free(B); + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int C_1 = C[0]; +# CHECK: Allocate(B +# CHECK: int A_1 = C_1; +# CHECK: int B_1 = 0; +# CHECK: for (int x = 0; x < 10; x++) +# CHECK: B_1 = +# CHECK: A_1 = C_ +# CHECK: B[0] = B_1; +# CHECK: A[0] = A_1; +# CHECK: Free(B)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Registerizer, RegisterizerNoInitializer) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make({For::make( + x, + 0, + 10, + Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}))}); + + /* + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + x; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[0]; + * for (int x = 0; x < 10; x++) { + * A_1 = x + A_1; + * } + * A[0] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[0]; +# CHECK: for (int x = 0; x < 10; x++) +# CHECK-NOT: A[ +# CHECK: A_1 = +# CHECK: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Registerizer, RegisterizerNoInitializerLoopVar) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make({For::make( + x, + 0, + 10, + Block::make({Store::make(a, {x}, Add::make(Load::make(a, {x}), x))}))}); + stmt = IRSimplifier::simplify(stmt); + + /* + * for (int x = 0; x < 10; x++) { + * A[x] = (A[x]) + x; + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +TEST(Registerizer, RegisterizerLoadThenStore) { + BufHandle a("A", {1}, kInt); + BufHandle b("B", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make({For::make( + x, + 0, + 10, + Block::make( + {Store::make(b, {0}, Add::make(Load::make(a, {0}), x)), + Store::make(a, {0}, Load::make(b, {0}))}))}); + + /* + * for (int x = 0; x < 10; x++) { + * B[0] = (A[0]) + x; + * A[0] = B[0]; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[0]; + * int B_1 = B[0]; + * for (int x = 0; x < 10; x++) { + * B_1 = x + A_1; + * A_1 = B_1; + * } + * B[0] = B_1; + * A[0] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[0]; +# CHECK: int B_1 = B[0]; +# CHECK: for (int x = 0; x < 10; x++) +# CHECK-NOT: B[ +# CHECK: B_1 = +# CHECK-NOT: A[ +# CHECK: A_1 = B_ +# CHECK: B[0] = B_ +# CHECK: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Registerizer, RegisterizerParallelized) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + LoopOptions loopOpts; + loopOpts.set_gpu_block_index(0); + StmtPtr stmt = Block::make( + {Store::make(a, {0}, 0), + For::make( + x, + 0, + 10, + Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}), + loopOpts)}); + + /* + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + x; + * } + */ + + ASSERT_THROWS_WITH( + registerize(stmt), + "Registerization must occur after parallelism flattening"); +} + +// Should be able to registerize this since the scalar would exist before the +// branch. +TEST(Registerizer, RegisterizerConditionAfter) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + StmtPtr stmt = Block::make( + {Store::make(a, {x}, Load::make(b, {x})), + Store::make(c, {x}, Load::make(a, {x})), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), + nullptr)}); + + /* + * A[x] = B[x]; + * C[x] = A[x]; + * if (x<5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = B[x]; + * C[x] = A_1; + * if (x<5 ? 1 : 0) { + * A_1 = A_1 + 1; + * } + * A[x] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = B[x]; +# CHECK: C[x] = A_1; +# CHECK: if ( +# CHECK: A_1 = A_1 + 1; +# CHECK: A[x] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Should be able to registerize this since the scalar exists in the same form +// after the branch and there is no overlap. +TEST(Registerizer, RegisterizerConditionBefore) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + StmtPtr stmt = Block::make( + {Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), + nullptr), + Store::make(a, {x}, Load::make(b, {x})), + Store::make(c, {x}, Load::make(a, {x}))}); + + /* + * if (x<5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * } + * A[x] = B[x]; + * C[x] = A[x]; + */ + + stmt = registerize(stmt); + + /* + * int A_ 1 = A[x]; + * if (x<5 ? 1 : 0) { + * A_1 = A_1 + 1; + * } + * A_1 = B[x]; + * C[x] = A_1; + * A[x] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[x]; +# CHECK: if ( +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: A_1 = B[x]; +# CHECK: C[x] = A_1; +# CHECK: A[x] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Should be able to registerize this as the combination of the two above rules. +TEST(Registerizer, RegisterizerConditionInside) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + StmtPtr stmt = Block::make( + {Store::make(a, {x}, Load::make(b, {x})), + Store::make(c, {x}, Load::make(a, {x})), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), + nullptr), + Store::make(b, {x}, Load::make(a, {x})), + Store::make(a, {x}, Load::make(c, {x}))}); + + /* + * A[x] = B[x]; + * C[x] = A[x]; + * if (x<5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * } + * B[x] = A[x]; + * A[x] = C[x]; + */ + + stmt = registerize(stmt); + + /* + * int A_1 = B[x]; + * C[x] = A_1; + * if (x<5 ? 1 : 0) { + * A_1 = A_1 + 1; + * } + * B[x] = A_1; + * A_1 = C[x]; + * A[x] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = B[x]; +# CHECK: C[x] = A_1; +# CHECK: if ( +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: B[x] = A_1; +# CHECK: A_1 = C[x]; +# CHECK: A[x] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// An example where an access is cut by an overlapping access inside a +// condition, and both sides are large enough to be registerized but cannot be +// because there is no safe place to put the initializer or finalizer. +TEST(Registerizer, RegisterizerConditionInsideOverlap1) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + StmtPtr stmt = Block::make( + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + {Store::make(a, {x}, Load::make(b, {x})), + Store::make(c, {x}, Load::make(a, {x})), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Block::make({ + Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), + Store::make(a, {0}, 3), + Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), + }), + nullptr), + Store::make(b, {x}, Load::make(a, {x})), + Store::make(a, {x}, Load::make(c, {x}))}); + + /* + * A[x] = B[x]; + * C[x] = A[x]; + * if (x<5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * A[0] = 3; + * A[x] = (A[x]) + 1; + * } + * B[x] = A[x]; + * A[x] = C[x]; + */ + + // The A[0] store overlaps, A[x] cutting the region that can be registerized + // into two groups. + // Each group has 2 loads and 2 stores however, so we could registerize it, + // but the first group would need to be finalized inside the condition block, + // the second would need to be initialized inside the condition block. There's + // no safe place to put these that's visible to the other uses in the group + // and so neither registerization is possible. + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// Same as the above, but the access group before the condition (and after the +// condition) are large enough to be registerized without needing the access +// from the loop. Registerization occurs but does not include any accesses in +// the condition, and the first group must be finalized before the Cond, the +// second initialized after it. +TEST(Registerizer, RegisterizerConditionInsideOverlap2) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + StmtPtr stmt = Block::make( + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + {Store::make(a, {x}, Load::make(b, {x})), + Store::make(a, {x}, Load::make(b, {x + 1})), + Store::make(c, {x}, Load::make(a, {x})), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Block::make({ + Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), + Store::make(a, {0}, 3), + Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), + }), + nullptr), + Store::make(b, {x}, Load::make(a, {x})), + Store::make(b, {x + 1}, Load::make(a, {x})), + Store::make(a, {x}, Load::make(c, {x}))}); + + /* + * A[x] = B[x]; + * A[x] = B[x + 1]; + * C[x] = A[x]; + * if (x<5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * A[0] = 3; + * A[x] = (A[x]) + 1; + * } + * B[x] = A[x]; + * B[x + 1] = A[x]; + * A[x] = C[x]; + */ + + stmt = registerize(stmt); + + /* + * int A_1 = B[x]; // A_1 initializer + * A_1 = B[x + 1]; // + * C[x] = A_1; // + * A[x] = A_1; // A_1 finalizer + * if (x<5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * A[0] = 3; + * A[x] = (A[x]) + 1; + * } + * int A_2 = A[x]; // A_2 initialier + * B[x] = A_2; // + * B[x + 1] = A_2; // + * A_2 = C[x]; // + * A[x] = A_2; // A_2 finalizer + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = B[x]; +# CHECK: A_1 = B[x + 1]; +# CHECK: C[x] = A_1; +# CHECK: A[x] = A_1; +# CHECK: if ( +# CHECK-NOT: A_1 = A_1 + 1; +# CHECK: A[x] = (A[x] +# CHECK: A[0] = +# CHECK: A[x] = (A[x] +# CHECK: } +# CHECK: int A_2 = A[x]; +# CHECK: B[x] = A_2; +# CHECK: B[x + 1] = A_2; +# CHECK: A_2 = C[x]; +# CHECK: A[x] = A_2;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// When accesses are within conditional blocks they are not visible to the wider +// program, because we don't know if the branch would be taken and if it isn't +// the accesses in it don't need to be valid (think size checks on the index). +// In this case the accesses cannot be registerized. +TEST(Registerizer, RegisterizerConditionHidden) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + StmtPtr stmt = Block::make( + {Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), + nullptr), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kGT), + Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), + nullptr)}); + + /* + * if (x<5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * } + * if (x>5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// But... if the same access is found in a non conditional scope, that means +// that that access is valid in the higher scope (or at least if its not it's +// the user's fault). It "unhides" the conditional accesses, allowing +// registerization to occur. +TEST(Registerizer, RegisterizerConditionUnhidden) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + StmtPtr stmt = Block::make( + {Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), + nullptr), + Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kGT), + Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), + nullptr)}); + + /* + * if (x<5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * } + * A[x] = (A[x]) + 1; <-- this is doing the unhiding. + * if (x>5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[x]; + * if (x<5 ? 1 : 0) { + * A_1 = A_1 + 1; + * } + * A_1 = A_1 + 1; + * if (x>5 ? 1 : 0) { + * A_1 = A_1 + 1; + * } + * A[x] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[x]; +# CHECK: if (x<5 +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: A_1 = A_1 + 1; +# CHECK: if (x>5 +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: A[x] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Can registerize a load that occurs in the condition of a Cond. +TEST(Registerizer, RegisterizerCondCondition) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + StmtPtr stmt = Block::make( + {Store::make(a, {x}, Load::make(b, {x})), + Store::make(c, {x}, Load::make(a, {x})), + Cond::make( + CompareSelect::make( + Load::make(a, {x}), 5, CompareSelectOperation::kLT), + Store::make(c, {x}, Add::make(Load::make(c, {x}), 1)), + nullptr)}); + + /* + * A[x] = B[x]; + * C[x] = A[x]; + * if ((A[x])<5 ? 1 : 0) { + * C[x] = (C[x]) + 1; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = B[x]; + * int C_1 = A_1; + * if (A_1<5 ? 1 : 0) { + * C_1 = C_1 + 1; + * } + * C[x] = C_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = B[x]; +# CHECK: int C_1 = A_1; +# CHECK: if (A_1<5 +# CHECK: C_1 = C_1 + 1; +# CHECK: C[x] = C_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Appearing in the condition of a Cond makes it visible to the enclosing scope, +// and so we can registerize internal usages. +TEST(Registerizer, RegisterizerCondConditionUnhidden) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + StmtPtr stmt = Block::make({Cond::make( + CompareSelect::make(Load::make(a, {x}), 5, CompareSelectOperation::kLT), + Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)), + Store::make(a, {x}, Add::make(Load::make(a, {x}), 10)))}); + + /* + * if ((A[x])<5 ? 1 : 0) { + * A[x] = (A[x]) + 1; + * } else { + * A[x] = (A[x]) + 10; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[x]; + * if (A_1<5 ? 1 : 0) { + * A_1 = A_1 + 1; + * } else { + * A_1 = A_1 + 10; + * } + * A[x] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[x]; +# CHECK: if (A_1<5 +# CHECK: A_1 = A_1 + 1; +# CHECK: } else { +# CHECK: A_1 = A_1 + 10; +# CHECK: } +# CHECK: A[x] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Conditional hiding also works for IfThenElse exprs. +TEST(Registerizer, RegisterizerIfThenElseHidden) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + StmtPtr stmt = Block::make( + {Store::make( + b, + {y}, + IfThenElse::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Add::make(Load::make(a, {x}), 1), + Add::make(Load::make(a, {x + 1}), 2))), + Store::make( + b, + {y + 1}, + IfThenElse::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Add::make(Load::make(a, {x}), 1), + Add::make(Load::make(a, {x + 1}), 2)))}); + + /* + * B[y] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2); + * B[y + 1] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2); + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// Conditional unhiding also works for IfThenElse exprs. +TEST(Registerizer, RegisterizerIfThenElseUnhidden) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + StmtPtr stmt = Block::make({ + Store::make(a, {x}, 0), + Store::make( + b, + {y}, + IfThenElse::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Add::make(Load::make(a, {x}), 1), + Add::make(Load::make(a, {x + 1}), 2))), + Store::make( + b, + {y + 1}, + IfThenElse::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Add::make(Load::make(a, {x}), 1), + Add::make(Load::make(a, {x + 1}), 2))), + }); + + /* + * A[x] = 0; + * B[y] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2); + * B[y + 1] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2); + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 0; + * B[y] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2); + * B[y + 1] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2); + * A[x] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 0; +# CHECK: B[y] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2); +# CHECK: B[y + 1] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2); +# CHECK: A[x] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Nested IfThenElse exprs can't promote to higher level scopes. +TEST(Registerizer, RegisterizerIfThenElseNested) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + BufHandle d("D", {5}, kInt); + VarHandle x("x", kInt); + + StmtPtr stmt = Block::make({Store::make( + a, + {x}, + IfThenElse::make( + CompareSelect::make(x, 3, CompareSelectOperation::kLT), + IfThenElse::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Load::make(d, {x}), + Load::make(b, {x})), + IfThenElse::make( + CompareSelect::make(x, 5, CompareSelectOperation::kEQ), + Load::make(c, {x}), + Load::make(d, {x}))))}); + + /* + * A[x] = IfThenElse(x<3 ? 1 : 0, + * IfThenElse(x==2 ? 1 : 0, D[x], B[x]), + * IfThenElse(x==5 ? 1 : 0, C[x], D[x])); + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// Cannot registerize an access completely contained within an IfThenElse +// branch, since it is not a Stmt and cannot hold variable definitions. We need +// to check that we don't promote the initializer/finalizer to the enclosing +// Block. +TEST(Registerizer, RegisterizerIfThenElseInternal) { + // Making these floats so they don't get simplified to a single access. + BufHandle a("A", {5}, kFloat); + BufHandle b("B", {5}, kFloat); + VarHandle x("x", kInt); + + StmtPtr stmt = Block::make({Store::make( + a, + {x}, + IfThenElse::make( + CompareSelect::make(x, 3, CompareSelectOperation::kLT), + Add::make(Load::make(b, {x}), Load::make(b, {x})), + Load::make(b, {x})))}); + + /* + * A[x] = IfThenElse(x<3 ? 1 : 0, (B[x]) + (B[x]), B[x]); + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); + + // If this was a Cond instead of an IfThenElse then we could registerize the + // two accesses to B[x] in the True branch. + + // Actually lets verify that. + + stmt = Block::make({Cond::make( + CompareSelect::make(x, 3, CompareSelectOperation::kLT), + Store::make(a, {x}, Add::make(Load::make(b, {x}), Load::make(b, {x}))), + Store::make(a, {x}, Load::make(b, {x})))}); + + /* + * if (x<3 ? 1 : 0) { + * A[x] = (B[x]) + (B[x]); + * } else { + * A[x] = B[x]; + * } + */ + + stmt = registerize(stmt); + + /* + * if (x<3 ? 1 : 0) { + * float B_1 = B[x]; + * A[x] = B_1 + B_1; + * } else { + * A[x] = B[x]; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK-NOT: int +# CHECK-NOT: float +# CHECK: if (x<3 +# CHECK: float B_1 = +# CHECK: A[x] = B_1 + B_1 +# CHECK: } else { +# CHECK: A[x] = B[x] +# CHECK: } +# CHECK-NOT: A[x] +# CHECK-NOT: B[x])IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Can registerize a load that occurs in the condition of an IfThenElse; +TEST(Registerizer, RegisterizerIfThenElseCondition) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + StmtPtr stmt = Block::make( + {Store::make(a, {x}, Load::make(a, {x})), + Store::make( + a, + {x}, + IfThenElse::make( + CompareSelect::make( + Load::make(a, {x}), 5, CompareSelectOperation::kLT), + Load::make(b, {0}), + Load::make(c, {0})))}); + + /* + * A[x] = A[x]; <---- just here so there are enough accesses to combine. + * A[x] = IfThenElse((A[x])<5 ? 1 : 0, B[0], C[0]); + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[x]; + * A_1 = A_1; + * A_1 = IfThenElse(A_1<5 ? 1 : 0, B[0], C[0]); + * A[x] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[x]; +# CHECK: A_1 = IfThenElse(A_1<5 ? 1 : 0, B[0], C[0]); +# CHECK: A[x] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Appearing in the condition of a Cond makes it visible to the enclosing scope, +// and so we can registerize internal usages. +TEST(Registerizer, RegisterizerIfThenElseConditionUnhidden) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + StmtPtr stmt = Block::make({Store::make( + b, + {x}, + IfThenElse::make( + CompareSelect::make( + Load::make(a, {x}), 5, CompareSelectOperation::kLT), + Add::make(Load::make(a, {x}), 1), + Add::make(Load::make(a, {x}), 10)))}); + + /* + * B[x] = IfThenElse((A[x])<5 ? 1 : 0, (A[x]) + 1, (A[x]) + 10); + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[x]; + * B[x] = IfThenElse(A_1<5 ? 1 : 0, A_1 + 1, A_1 + 10); + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[x]; +# CHECK: B[x] = IfThenElse(A_1<5 ? 1 : 0, A_1 + 1, A_1 + 10);)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Cannot promote accesses internal to IfThenElse branches even if the enclosing +// scope if conditional. +TEST(Registerizer, RegisterizerConditionBranchOnly) { + BufHandle a("A", {5}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make({For::make( + x, + 0, + 10, + Block::make({ + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Store::make( + a, + {x}, + IfThenElse::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Add::make(Load::make(a, {x}), x), + Add::make(Load::make(a, {x - 5}), x))), + Store::make( + a, + {x - 5}, + IfThenElse::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Add::make(Load::make(a, {x}), x), + Add::make(Load::make(a, {x - 5}), x)))), + }))}); + stmt = IRSimplifier::simplify(stmt); + + std::ostringstream before; + before << *stmt; + + /* for (int x = 0; x < 10; x++) { + * if (x<5 ? 1 : 0) { + * A[x] = IfThenElse(x<5 ? 1 : 0, (A[x]) + x, (A[x - 5]) + x); + * } else { + * A[x - 5] = IfThenElse(x<5 ? 1 : 0, (A[x]) + x, (A[x - 5]) + x); + * } + * } + */ + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// We can registerize an IfThenElse that appears in the condition branch of a +// Cond. This is a weird but valid thing to do. +TEST(Registerizer, RegisterizerCondIfThenElse) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + BufHandle c("C", {5}, kInt); + VarHandle x("x", kInt); + + StmtPtr stmt = Block::make({Cond::make( + CompareSelect::make( + IfThenElse::make( + CompareSelect::make( + Load::make(a, {x}), 5, CompareSelectOperation::kLT), + Load::make(a, {x}), + Load::make(b, {x})), + x, + CompareSelectOperation::kEQ), + Store::make(c, {x}, Add::make(Load::make(c, {x}), 1)), + nullptr)}); + + /* + * if ((IfThenElse((A[x])<5 ? 1 : 0, A[x], B[x]))==x ? 1 : 0) { + * C[x] = (C[x]) + 1; + * } + */ + + stmt = registerize(stmt); + + // access to A can be registerized, but not B or C + + /* + * int A_1 = A[x]; + * if ((IfThenElse(A_1<5 ? 1 : 0, A_1, B[x]))==x ? 1 : 0) { + * C[x] = (C[x]) + 1; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[x]; +# CHECK: if ((IfThenElse(A_1<5 ? 1 : 0, A_1, B[x] +# CHECK: C[x] = (C[x]) + 1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Can registerize a conditional access in the RHS of a store unhidden by it's +// LHS, and hoist it out of a loop. +TEST(Registerizer, RegisterizerIfThenElseLoop) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + StmtPtr stmt = For::make( + y, + 0, + 10, + Store::make( + a, + {x}, + IfThenElse::make( + CompareSelect::make(x, 3, CompareSelectOperation::kLT), + Load::make(a, {x}), + Load::make(b, {y})))); + + /* + * for (int y = 0; y < 10; y++) { + * A[x] = IfThenElse(x<3 ? 1 : 0, A[x], B[y]); + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[x]; + * for (int y = 0; y < 10; y++) { + * A_1 = IfThenElse(x<3 ? 1 : 0, A_1, B[y]); + * } + * A[x] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[x]; +# CHECK: for ( +# CHECK: A_1 = IfThenElse(x<3 ? 1 : 0, A_1, B[y]); +# CHECK: } +# CHECK: A[x] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Cannot registerize if the RHS overlaps the access creating visibility. +TEST(Registerizer, RegisterizerIfThenElseLoopCut) { + BufHandle a("A", {5}, kInt); + BufHandle b("B", {5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + StmtPtr stmt = Block::make({For::make( + y, + 0, + 10, + Store::make( + a, + {x}, + IfThenElse::make( + CompareSelect::make(x, 3, CompareSelectOperation::kLT), + Load::make(a, {x}), + Load::make(a, {y}))))}); + + /* + * for (int y = 0; y < 10; y++) { + * A[x] = IfThenElse(x<3 ? 1 : 0, A[x], A[y]); + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// Simple case where an access is cut by an overlapping access later in the +// program, we can registerize up until the overlap. +TEST(Registerizer, RegisterizerPartialAfter) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0}, 0), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {0}), x))})), + For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1})))}); + + /* + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + x; + * } + * for (int x = 1; x < 10; x++) { + * A[x] = A[x - 1]; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 0; + * for (int x = 0; x < 10; x++) { + * A_1 = A_1 + x; + * } + * A[0] = A_1; + * for (int x = 1; x < 10; x++) { + * A[x] = A[x - 1]; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 0; +# CHECK: for ( +# CHECK: A_1 = A_1 + x; +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: for ( +# CHECK: A[x] = A[x - 1]; +# CHECK: } +# CHECK-NOT: A)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// We can registerize an access which overlaps a previous access, the +// initializer must be inserted after the previous access. +TEST(Registerizer, RegisterizerPartialBefore) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))), + Store::make(a, {0}, 0), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}))}); + + /* + * for (int x = 1; x < 10; x++) { + * A[x] = A[x - 1]; + * } + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + x; + * } + */ + + stmt = registerize(stmt); + + /* + * for (int x = 1; x < 10; x++) { + * A[x] = A[x - 1]; + * } + * int A_1 = 0; + * for (int x = 0; x < 10; x++) { + * A_1 = A_1 + x; + * } + * A[0] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK-NOT: int +# CHECK: for ( +# CHECK: A[x] = A[x - 1]; +# CHECK: } +# CHECK: int A_1 = 0; +# CHECK: for ( +# CHECK: A_1 = A_1 + x; +# CHECK: } +# CHECK: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// The combination of the previous two tests, an access is cut by an overlapping +// access in both directions. +TEST(Registerizer, RegisterizerPartialInside) { + BufHandle a("A", {1}, kInt); + VarHandle x1("x1", kInt); + VarHandle x2("x2", kInt); + VarHandle x3("x3", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0}, 2), + For::make( + x1, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x1))), + For::make(x2, 1, 10, Store::make(a, {x2}, Load::make(a, {x2 - 1}))), + For::make( + x3, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x3)))}); + + /* + * A[0] = 2; + * for (int x1 = 0; x1 < 10; x1++) { + * A[0] = (A[0]) + x1; + * } + * for (int x2 = 1; x2 < 10; x2++) { + * A[x2] = A[x2 - 1]; + * } + * for (int x3 = 0; x3 < 10; x3++) { + * A[0] = (A[0]) + x3; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 2; + * for (int x1 = 0; x1 < 10; x1++) { + * A_1 = A_1 + x1; + * } + * A[0] = A_1; + * for (int x2 = 1; x2 < 10; x2++) { + * A[x2] = A[x2 - 1]; + * } + * int A_2 = A[0]; + * for (int x3 = 0; x3 < 10; x3++) { + * A_2 = A_2 + x3; + * } + * A[0] = A_2; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 2; +# CHECK: for ( +# CHECK: A_1 = A_1 + x1; +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: for ( +# CHECK: A[x2] = +# CHECK: } +# CHECK: int A_2 = A[0]; +# CHECK: for ( +# CHECK: A_2 = A_2 + x3; +# CHECK: } +# CHECK: A[0] = A_2;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// An element could be registerized program wide but is cut by a conditional +// access, we should break this into two scalars and write back to the buffer +// before the condition. +TEST(Registerizer, RegisterizerPartialCondition) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0}, 2), + For::make( + x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x))), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Store::make(a, {x}, Load::make(a, {x - 1})), + nullptr), + For::make( + x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x)))}); + + /* + * A[0] = 2; + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + x; + * } + * if (x<5 ? 1 : 0) { + * A[x] = A[x - 1]; + * } + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + x; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 2; + * for (int x = 0; x < 10; x++) { + * A_1 = A_1 + x; + * } + * A[0] = A_1; + * if (x<5 ? 1 : 0) { + * A[x] = A[x - 1]; + * } + * int A_2 = A[0]; + * for (int x = 0; x < 10; x++) { + * A_2 = A_2 + x; + * } + * A[0] = A_2; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 2; +# CHECK: for ( +# CHECK: A_1 = A_1 + x; +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: if ( +# CHECK: A[x] = +# CHECK: } +# CHECK: int A_2 = A[0]; +# CHECK: for ( +# CHECK: A_2 = A_2 + x; +# CHECK: } +# CHECK: A[0] = A_2;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Tests case where an access is cut by an internal conditional access which +// itself is registerized. +TEST(Registerizer, RegisterizerPartialConditionInternalCut) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0}, 1), + Store::make(a, {0}, 3), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Block::make({Store::make(a, {x}, 1), Store::make(a, {x}, 3)}), + nullptr), + Store::make(a, {0}, 4), + Store::make(a, {0}, 6)}); + + /* + * A[0] = 1; + * A[0] = 3; + * if (x<5 ? 1 : 0) { + * A[x] = 1; + * A[x] = 3; + * } + * A[0] = 4; + * A[0] = 6; + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 1; + * A_1 = 3; + * A[0] = A_1; + * if (x<5 ? 1 : 0) { + * int A_2 = 1; + * A_2 = 3; + * A[x] = A_2; + * } + * int A_3 = 4; + * A_3 = 6; + * A[0] = A_3; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 1; +# CHECK: A_1 = 3 +# CHECK: A[0] = A_1; +# CHECK: if ( +# CHECK: int A_2 = 1; +# CHECK: A_2 = 3; +# CHECK: A[x] = A_2; +# CHECK: } +# CHECK: int A_3 = 4; +# CHECK: A_3 = 6; +# CHECK: A[0] = A_3;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// First statement in condition closes outer access, but can be registerized +// with later statements. +TEST(Registerizer, RegisterizerPartialConditionInternalStart) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0}, 1), + Store::make(a, {0}, 3), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Block::make({Store::make(a, {x}, 1), Store::make(a, {x}, 3)}), + nullptr), + Store::make(a, {x}, 4), + Store::make(a, {x}, 6)}); + + /* + * A[0] = 1; + * A[0] = 3; + * if (x<5 ? 1 : 0) { + * A[x] = 1; + * A[x] = 3; + * } + * A[x] = 4; + * A[x] = 6; + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 1; + * A_1 = 3; + * A[0] = A_1; + * int A_2 = A[x]; <--- must read from the input here. + * if (x<5 ? 1 : 0) { + * A_2 = 1; + * A_2 = 3; + * } + * A_2 = 4; + * A_2 = 6; + * A[x] = A_2; + */ + + // TODO: I suppose we could refactor with a conditional initializer? + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 1; +# CHECK: A_1 = 3 +# CHECK: A[0] = A_1; +# CHECK: int A_2 = A[x]; +# CHECK: if ( +# CHECK: A_2 = 1; +# CHECK: A_2 = 3; +# CHECK: } +# CHECK: A_2 = 4; +# CHECK: A_2 = 6; +# CHECK: A[x] = A_2;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// An access cuts two open overlaps and creates four scalar variables. +TEST(Registerizer, RegisterizerPartialOverlapsTwo) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {1}, Load::make(a, {0})), + Store::make(a, {0}, Load::make(a, {1})), + Store::make(a, {0}, Load::make(a, {1})), + For::make(x, 1, 10, Store::make(a, {x}, x)), + Store::make(a, {1}, Load::make(a, {0})), + Store::make(a, {0}, Load::make(a, {1})), + Store::make(a, {0}, Load::make(a, {1}))}); + + /* + * A[1] = A[0]; + * A[0] = A[1]; + * A[0] = A[1]; + * for (int x = 1; x < 10; x++) { + * A[x] = x; + * } + * A[1] = A[0]; + * A[0] = A[1]; + * A[0] = A[1]; + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[0]; + * int A_2 = A_1; + * A_1 = A_2; + * A_1 = A_2; + * A[1] = A_2; + * A[0] = A_1; + * for (int x = 1; x < 10; x++) { + * A[x] = x; + * } + * int A_3 = A[0]; + * int A_4 = A_3; + * A_3 = A_4; + * A_3 = A_4; + * A[1] = A_4; + * A[0] = A_3; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[0]; +# CHECK: int A_2 = A_1; +# CHECK: A_1 = A_2; +# CHECK: A_1 = A_2; +# CHECK: A[1] = A_2; +# CHECK: A[0] = A_1; +# CHECK: for ( +# CHECK: A[x] = x; +# CHECK: } +# CHECK: int A_3 = A[0]; +# CHECK: int A_4 = A_3; +# CHECK: A_3 = A_4; +# CHECK: A_3 = A_4; +# CHECK: A[1] = A_4; +# CHECK: A[0] = A_3;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Nested blocks will automatically be flattened and do not provent +// registerization of enclosed accesses. +TEST(Registerizer, RegisterizerNestedBlocks) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), + Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), 2))}), + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {0}), 3)), + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {0}), 4))})})}); + + /* + * A[0] = (A[0]) + 1; + * { + * A[0] = (A[0]) + 2; + * } + * { + * A[0] = (A[0]) + 3; + * { + * A[0] = (A[0]) + 4; + * } + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[0]; + * A_1 = A_1 + 1; + * A_1 = A_1 + 2; + * A_1 = A_1 + 3; + * A_1 = A_1 + 4; + * A[0] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[0]; +# CHECK: A_1 = A_1 + 1; +# CHECK: A_1 = A_1 + 2; +# CHECK: A_1 = A_1 + 3; +# CHECK: A_1 = A_1 + 4; +# CHECK: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// The access can be registerized internally to a condition, but must ensure +// that both initializer and finalizer are within the same condition. +TEST(Registerizer, RegisterizerNestedConditions) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make({Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), + Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), + nullptr)}), + nullptr)}); + + /* + * if (x<5 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * if (x==2 ? 1 : 0) { + * + * A[0] = (A[0]) + 1; + * } + * } + */ + + stmt = registerize(stmt); + + /* + * if (x<5 ? 1 : 0) { + * int A_1 = A[0]; + * A_1 = A_1 + 1; + * if (x==2 ? 1 : 0) { + * A_1 = A_1 + 1; + * } + * A[0] = A_1; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: if (x<5 +# CHECK: int A_1 = A[0]; +# CHECK: A_1 = A_1 + 1; +# CHECK: if (x==2 +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// If an access exists outside the scope of the condition then we can lift +// nested conditional usages into the same scalar. +TEST(Registerizer, RegisterizerNestedConditionsUnhidden) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Block::make( + {Store::make(a, {1}, 1), + Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), + nullptr)}), + nullptr)}); + + /* + * A[0] = (A[0]) + 1; + * if (x<5 ? 1 : 0) { + * A[1] = 1; + * if (x==2 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * } + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = A[0]; + * A_1 = A_1 + 1; + * if (x<5 ? 1 : 0) { + * A[1] = 1; + * if (x==2 ? 1 : 0) { + * A_1 = A_1 + 1; + * } + * } + * A[0] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = A[0]; +# CHECK: A_1 = A_1 + 1; +# CHECK: if (x<5 +# CHECK: A[1] = 1; +# CHECK: if (x==2 +# CHECK: A_1 = A_1 + 1; +# CHECK: A[0] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Registerizer, RegisterizerNestedConditionsHiddenFirst) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), + nullptr), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Block::make({Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), + nullptr)}), + nullptr)}); + + /* + * if (x==2 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * } + * if (x<5 ? 1 : 0) { + * if (x==2 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * } + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); + + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + stmt = registerize(stmt); +} + +TEST(Registerizer, RegisterizerNestedConditionsHiddenSecond) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Block::make({Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), + nullptr)}), + nullptr), + Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), + nullptr)}); + + /* + * if (x<5 ? 1 : 0) { + * if (x==2 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * } + * } + * if (x==2 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); + + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + stmt = registerize(stmt); +} + +// If an access is cut by another access internal to a condition block, it still +// cuts the access. +TEST(Registerizer, RegisterizerNestedConditionsCut) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + Block::make( + {Store::make(a, {x}, 1), + Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), + nullptr)}), + nullptr)}); + + /* + * A[0] = (A[0]) + 1; + * if (x<5 ? 1 : 0) { + * A[x] = 1; + * if (x==2 ? 1 : 0) { + * + * A[0] = (A[0]) + 1; + * } + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +TEST(Registerizer, RegisterizerNestedConditionLoopHidden) { + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), + nullptr), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(b, {x}, 0), + Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)), + nullptr)}))}); + + /* + * if (x==2 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * } + * for (int x = 0; x < 10; x++) { + * B[x] = 0; <-- this is only here to prevent Loop/Cond reordering. + * if (x==2 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * } + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// Three loops and four element regions, three of which should be registerized +// at different levels of the IR. +TEST(Registerizer, RegisterizerNestedConditionThreeDeep) { + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {4}, 0), + Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kGT), + Cond::make( + CompareSelect::make(x, 3, CompareSelectOperation::kGT), + Block::make({ + Cond::make( + CompareSelect::make(x, 4, CompareSelectOperation::kGT), + Block::make({ + Store::make( + a, {1}, Add::make(Load::make(a, {1}), 1)), + Store::make( + a, {2}, Add::make(Load::make(a, {2}), 1)), + Store::make( + a, {3}, Add::make(Load::make(a, {3}), 1)), + Store::make( + a, {4}, Add::make(Load::make(a, {4}), 1)), + Store::make( + a, {1}, Add::make(Load::make(a, {1}), 1)), + }), + nullptr), + Store::make(a, {2}, Add::make(Load::make(a, {2}), 1)), + }), + nullptr), + nullptr)}); + + /* + * A[4] = 0; + * if (x>2 ? 1 : 0) { + * if (x>3 ? 1 : 0) { + * if (x>4 ? 1 : 0) { + * A[1] = (A[1]) + 1; + * A[2] = (A[2]) + 1; + * A[3] = (A[3]) + 1; + * A[4] = (A[4]) + 1; + * A[1] = (A[1]) + 1; + * } + * A[2] = (A[2]) + 1; + * } + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 0; + * if (x>2 ? 1 : 0) { + * if (x>3 ? 1 : 0) { + * int A_3 = A[2]; + * if (x>4 ? 1 : 0) { + * int A_2 = A[1]; + * A_2 = A_2 + 1; + * A_3 = A_3 + 1; + * A[3] = (A[3]) + 1; + * A_1 = A_1 + 1; + * A_2 = A_2 + 1; + * A[1] = A_2; + * } + * A_3 = A_3 + 1; + * A[2] = A_3; + * } + * } + * A[4] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 0; +# CHECK: if (x>2 ? 1 : 0) { +# CHECK: if (x>3 ? 1 : 0) { +# CHECK: int A_3 = A[2]; +# CHECK: if (x>4 ? 1 : 0) { +# CHECK: int A_2 = A[1]; +# CHECK: A_2 = A_2 + 1; +# CHECK: A_3 = A_3 + 1; +# CHECK: A[3] = (A[3]) + 1; +# CHECK: A_1 = A_1 + 1; +# CHECK: A_2 = A_2 + 1; +# CHECK: A[1] = A_2; +# CHECK: } +# CHECK: A_3 = A_3 + 1; +# CHECK: A[2] = A_3; +# CHECK: } +# CHECK: } +# CHECK: A[4] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Can replace a simple scalar access with a local variable even when that +// variable is an outer loop var. +TEST(Registerizer, RegisterizerNestedLoopSimple) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + StmtPtr stmt = Block::make({For::make( + y, + 0, + 10, + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {y}, Add::make(Load::make(a, {y}), x))})))}); + + /* + * for (int y = 0; y < 10; y++) { + * for (int x = 0; x < 10; x++) { + * A[y] = (A[y]) + x; + * } + * } + */ + + stmt = registerize(stmt); + + /* + * for (int y = 0; y < 10; y++) { + * int A_1 = A[y]; + * for (int x = 0; x < 10; x++) { + * A_1 = A_1 + x; + * } + * A[y] = A_1; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: for (int y +# CHECK: int A_1 = A[y]; +# CHECK: for (int x +# CHECK: A_1 = A_1 + x; +# CHECK: } +# CHECK: A[y] = A_1; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Test the positive case of the hiddenAccess split, where an internal +// conditional access can be hoisted up through a loop to match an existing +// access in a higher scope and the two can be registerized. +TEST(Registerizer, RegisterizerHiddenAccessYes) { + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + StmtPtr stmt = Block::make({Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Block::make( + {Store::make(a, {0}, 0), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(b, {x}, 0), + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + Cond::make( + CompareSelect::make(x, 3, CompareSelectOperation::kEQ), + For::make( + y, + 0, + 10, + Store::make( + a, {0}, Add::make(Load::make(a, {0}), 1))), + nullptr)}))}), + nullptr)}); + + /* + * if (x==2 ? 1 : 0) { + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * B[x] = 0; + * if (x==3 ? 1 : 0) { + * for (int y = 0; y < 10; y++) { + * A[0] = (A[0]) + 1; + * } + * } + * } + * } + */ + + stmt = registerize(stmt); + + /* + * if (x==2 ? 1 : 0) { + * int A_1 = 0; + * for (int x = 0; x < 10; x++) { + * B[x] = 0; + * if (x==3 ? 1 : 0) { + * for (int y = 0; y < 10; y++) { + * A_1 = A_1 + 1; + * } + * } + * } + * A[0] = A_1; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: if (x==2 +# CHECK: int A_1 = 0; +# CHECK: for (int x +# CHECK: B[x] = 0; +# CHECK: if (x==3 +# CHECK: for (int y +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: } +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Test the negative case of the hiddenAccess split, where the hoisted access is +// never unhidden at a higher scope and registerization occurs at the lower +// scope. +TEST(Registerizer, RegisterizerHiddenAccessNo) { + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + StmtPtr stmt = Block::make({Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Block::make({For::make( + x, + 0, + 10, + Block::make( + {Store::make(b, {x}, 0), + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + Cond::make( + CompareSelect::make(x, 3, CompareSelectOperation::kEQ), + For::make( + y, + 0, + 10, + Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))), + nullptr)}))}), + nullptr)}); + + /* + * if (x==2 ? 1 : 0) { + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * B[x] = 0; + * if (x==3 ? 1 : 0) { + * for (int y = 0; y < 10; y++) { + * A[0] = (A[0]) + 1; + * } + * } + * } + * } + */ + + stmt = registerize(stmt); + + /* + * if (x==2 ? 1 : 0) { + * for (int x = 0; x < 10; x++) { + * B[x] = 0; + * if (x==3 ? 1 : 0) { + * int A_1 = A[0]; + * for (int y = 0; y < 10; y++) { + * A_1 = A_1 + 1; + * } + * A[0] = A_1; + * } + * } + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: if (x==2 +# CHECK: for (int x +# CHECK: B[x] = 0; +# CHECK: if (x==3 +# CHECK: int A_1 = A[0]; +# CHECK: for (int y +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: } +# CHECK: } +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// In this case the conditional access must be hoisted by two loops, there are +// two accesses here one is unhidden and the other isnt. A[0] can be +// registerized but B[0] cannot. +TEST(Registerizer, RegisterizerHiddenAccessMultiLoop) { + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + StmtPtr stmt = Block::make({Cond::make( + CompareSelect::make(x, 2, CompareSelectOperation::kEQ), + Block::make( + {Store::make(a, {0}, 0), + For::make( + x, + 0, + 10, + For::make( + y, + 0, + 10, + Block::make({Cond::make( + CompareSelect::make(y, 3, CompareSelectOperation::kEQ), + Block::make( + {Store::make( + a, {0}, Add::make(Load::make(a, {0}), 1)), + Store::make( + b, {0}, Add::make(Load::make(b, {0}), 1))}), + nullptr)})))}), + nullptr)}); + + /* + * if (x==2 ? 1 : 0) { + * A[0] = 0; + * for (int x = 0; x < 10; x++) { + * for (int y = 0; y < 10; y++) { + * if (y==3 ? 1 : 0) { + * A[0] = (A[0]) + 1; + * B[0] = (B[0]) + 1; + * } + * } + * } + * } + */ + + stmt = registerize(stmt); + + /* + * if (x==2 ? 1 : 0) { + * int A_1 = 0; + * for (int x = 0; x < 10; x++) { + * for (int y = 0; y < 10; y++) { + * if (y==3 ? 1 : 0) { + * A_1 = A_1 + 1; + * B[0] = (B[0]) + 1; + * } + * } + * } + * A[0] = A_1; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: if (x==2 +# CHECK: int A_1 = 0; +# CHECK: for (int x +# CHECK: for (int y +# CHECK: if (y==3 +# CHECK: A_1 = A_1 + 1; +# CHECK: B[0] = (B[0]) + 1; +# CHECK: } +# CHECK: } +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Accesses are registerized inside two conditions, but the immediate parent is +// not a condition. +TEST(Registerizer, RegisterizerTwoConditionalLoops) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + For::make( + x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))), + nullptr), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kGT), + For::make( + x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))), + nullptr)}); + + /* + * if (x<5 ? 1 : 0) { + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + 1; + * } + * } + * if (x>5 ? 1 : 0) { + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + 1; + * } + * } + */ + + stmt = registerize(stmt); + + /* + * if (x<5 ? 1 : 0) { + * int A_1 = A[0]; + * for (int x = 0; x < 10; x++) { + * A_1 = A_1 + 1; + * } + * A[0] = A_1; + * } + * if (x>5 ? 1 : 0) { + * int A_2 = A[0]; + * for (int x = 0; x < 10; x++) { + * A_2 = A_2 + 1; + * } + * A[0] = A_2; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: if (x<5 +# CHECK: int A_1 = A[0]; +# CHECK: for (int x +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: } +# CHECK: if (x>5 +# CHECK: int A_2 = A[0]; +# CHECK: for (int x +# CHECK: A_2 = A_2 + 1; +# CHECK: } +# CHECK: A[0] = A_2; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Accesses are registerized inside two conditions, cut in the middle. +TEST(Registerizer, RegisterizerTwoConditionalLoopsCut) { + BufHandle a("A", {1}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kLT), + For::make( + x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))), + nullptr), + For::make(x, 0, 10, Store::make(a, {x}, 1)), + Cond::make( + CompareSelect::make(x, 5, CompareSelectOperation::kGT), + For::make( + x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))), + nullptr)}); + + /* + * if (x<5 ? 1 : 0) { + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + 1; + * } + * } + * for (int x = 0; x < 10; x++) { + * A[x] = 1; + * } + * if (x>5 ? 1 : 0) { + * for (int x = 0; x < 10; x++) { + * A[0] = (A[0]) + 1; + * } + * } + */ + + stmt = registerize(stmt); + + /* + * if (x<5 ? 1 : 0) { + * int A_1 = A[0]; + * for (int x = 0; x < 10; x++) { + * A_1 = A_1 + 1; + * } + * A[0] = A_1; + * } + * for (int x = 0; x < 10; x++) { + * A[x] = 1; + * } + * if (x>5 ? 1 : 0) { + * int A_2 = A[0]; + * for (int x = 0; x < 10; x++) { + * A_2 = A_2 + 1; + * } + * A[0] = A_2; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: if (x<5 +# CHECK: int A_1 = A[0]; +# CHECK: for (int x +# CHECK: A_1 = A_1 + 1; +# CHECK: } +# CHECK: A[0] = A_1; +# CHECK: } +# CHECK: for (int x +# CHECK: A[x] = 1; +# CHECK: if (x>5 +# CHECK: int A_2 = A[0]; +# CHECK: for (int x +# CHECK: A_2 = A_2 + 1; +# CHECK: } +# CHECK: A[0] = A_2; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// references a Let var in a local scope which cannot be hoisted out of the +// loop. +TEST(Registerizer, RegisterizerLoopLetVar) { + BufHandle a("A", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + StmtPtr stmt = IRSimplifier::simplify(Block::make({For::make( + x, + 0, + 10, + Block::make( + {Let::make(y, 30), + Store::make(a, {y}, Add::make(x, Load::make(a, {y})))}))})); + + /* + * for (int x = 0; x < 10; x++) { + * int y = 30; + * A[y] = x + (A[y]); + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// references a Let var in an outer scope that does not prevent hoisting the +// initializer. +TEST(Registerizer, RegisterizerLoopLetVarOuter) { + BufHandle a("A", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + StmtPtr stmt = Block::make( + {Let::make(y, 30), + For::make( + x, + 0, + 10, + Block::make( + {Store::make(a, {y}, Add::make(x, Load::make(a, {y})))}))}); + + /* + * int y = 30; + * for (int x = 0; x < 10; x++) { + * A[y] = x + (A[y]); + * } + */ + + stmt = registerize(stmt); + + /* + * int y = 30; + * int A_1 = A[y]; + * for (int x = 0; x < 10; x++) { + * A_1 = A_1 + x; + * } + * A[y] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int y = 30; +# CHECK: int A_1 = A[y]; +# CHECK: for (int x +# CHECK: A_1 = A_1 + x; +# CHECK: A[y] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Okay so the registerizer generally goes after index flattening, but just in +// case. Test multi index registerization. +TEST(Registerizer, RegisterizerMultiDim) { + BufHandle a("A", {3, 4, 5}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0, 1, 2}, 0), + For::make( + x, + 0, + 10, + Block::make({Store::make( + a, {0, 1, 2}, Add::make(Load::make(a, {0, 1, 2}), x))}))}); + + /* + * A[0, 1, 2] = 0; + * for (int x = 0; x < 10; x++) { + * A[0, 1, 2] = (A[0, 1, 2]) + x; + * } + */ + + stmt = registerize(stmt); + + /* + * int A_1 = 0; + * for (int x = 0; x < 10; x++) { + * A_1 = x + A_1; + * } + * A[0, 1, 2] = A_1; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: int A_1 = 0; +# CHECK: for (int x = 0; x < 10; x++) +# CHECK-NOT: A[ +# CHECK: A_1 = +# CHECK: A[0, 1, 2] = A_1;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// Wont registerize if only some dims match, but will still registerize distinct +// elements. +TEST(Registerizer, RegisterizerMultiDimPartial) { + BufHandle a("A", {3, 4, 5}, kInt); + VarHandle x("x", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0, 1, 2}, 0), + For::make( + x, + 0, + 10, + Block::make({Store::make( + a, {0, 2, 2}, Add::make(Load::make(a, {0, 1, 4}), x))}))}); + + /* + * A[0, 1, 2] = 0; + * for (int x = 0; x < 10; x++) { + * A[0, 2, 2] = (A[0, 1, 4]) + x; + * } + */ + + stmt = registerize(stmt); + + /* + * A[0, 1, 2] = 0; + * int A_1 = A[0, 1, 4]; + * int A_2 = A[0, 2, 2]; + * for (int x = 0; x < 10; x++) { + * A_2 = A_1 + x; + * } + * A[0, 2, 2] = A_2; + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: A[0, 1, 2] = 0; +# CHECK: int A_1 = A[0, 1, 4]; +# CHECK: int A_2 = A[0, 2, 2]; +# CHECK: for ( +# CHECK: A_2 = A_1 + x; +# CHECK: A[0, 2, 2] = A_2;)IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// If they could overlap across all dimensions we cannot registerize. +TEST(Registerizer, RegisterizerMultiDimOverlap) { + BufHandle a("A", {3, 4, 5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0, 1, 2}, 0), + For::make( + x, + 0, + 10, + Block::make({Store::make( + a, {0, x, 2}, Add::make(Load::make(a, {y, 2, 2}), x))}))}); + stmt = IRSimplifier::simplify(stmt); + + /* + * A[0, 1, 2] = 0; + * for (int x = 0; x < 10; x++) { + * A[0, x, 2] = (A[y, 2, 2]) + x; + * } + */ + + std::ostringstream before; + before << *stmt; + + // No change. + stmt = registerize(stmt); + + std::ostringstream after; + after << *stmt; + + ASSERT_EQ(before.str(), after.str()); +} + +// But, if one dimension is known to be distinct they do not overlap. +TEST(Registerizer, RegisterizerMultiDimPartialOverlap) { + BufHandle a("A", {3, 4, 5}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + StmtPtr stmt = Block::make( + {Store::make(a, {0, 1, 2}, 0), + For::make( + x, + 0, + 10, + Block::make({Store::make( + a, {0, x, 2}, Add::make(Load::make(a, {y, 2, 4}), x))}))}); + + /* + * A[0, 1, 2] = 0; <---- 2nd dim overlaps with store. + * for (int x = 0; x < 10; x++) { + * A[0, x, 2] = (A[y, 2, 4]) + x; <---- 3rd dim has constant diff. + * } + */ + + stmt = registerize(stmt); + + /* + * A[0, 1, 2] = 0; + * int A_1 = A[y, 2, 4]; + * for (int x = 0; x < 10; x++) { + * A[0, x, 2] = A_1 + x; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: A[0, 1, 2] = 0; +# CHECK: int A_1 = A[y, 2, 4]; +# CHECK: for ( +# CHECK: A[0, x, 2] = A_1 + x; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// A 3D reduction with different input dimensionality. +TEST(Registerizer, RegisterizerMultiDim3DReduction1) { + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10, 10}, kInt); + BufHandle c("C", {10, 10, 10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + StmtPtr stmt = For::make( + x, + 0, + 10, + For::make( + y, + 0, + 10, + For::make( + z, + 0, + 10, + Store::make( + c, + {x, y, z}, + Add::make( + Load::make(c, {x, y, z}), + Mul::make(Load::make(b, {x, y}), Load::make(a, {x}))))))); + + /* + * for (int x = 0; x < 10; x++) { + * for (int y = 0; y < 10; y++) { + * for (int z = 0; z < 10; z++) { + * C[x, y, z] = (C[x, y, z]) + (B[x, y]) * (A[x]); + * } + * } + * } + */ + + // We can registerize the A and B access since they can be hoisted before + // hitting a dependent loop var. + + stmt = registerize(stmt); + + /* + * for (int x = 0; x < 10; x++) { + * int A_1 = A[x]; + * for (int y = 0; y < 10; y++) { + * int B_1 = B[x, y]; + * for (int z = 0; z < 10; z++) { + * C[x, y, z] = A_1 * B_1 + (C[x, y, z]); + * } + * } + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: for (int x +# CHECK: int A_1 = A[x]; +# CHECK: for (int y +# CHECK: int B_1 = B[x, y]; +# CHECK: for (int z +# CHECK: C[x, y, z] = A_1 * B_1 + (C[x, y, z]); +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +// A 3D reduction with the same smaller dimensionality using different loop +// vars. +TEST(Registerizer, RegisterizerMultiDim3DReduction2) { + BufHandle a("A", {10}, kInt); + BufHandle b("B", {10}, kInt); + BufHandle c("C", {10}, kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + StmtPtr stmt = For::make( + x, + 0, + 10, + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + For::make( + y, + 0, + 10, + For::make( + z, + 0, + 10, + Store::make( + c, + {x}, + Add::make( + Load::make(c, {x}), + Mul::make(Load::make(b, {y}), Load::make(a, {x}))))))); + + /* + * for (int x = 0; x < 10; x++) { + * for (int y = 0; y < 10; y++) { + * for (int z = 0; z < 10; z++) { + * C[x] = (C[x]) + (B[y]) * (A[x]); + * } + * } + * } + */ + + // We can registerize all accesses, the A and C access can be hoisted to the + // outer loop since they depend only on it's loop var while the B can only be + // raised to the loop of y. + + stmt = registerize(stmt); + + /* + * for (int x = 0; x < 10; x++) { + * int A_1 = A[x]; + * int C_1 = C[x]; + * for (int y = 0; y < 10; y++) { + * int B_1 = B[y]; + * for (int z = 0; z < 10; z++) { + * C_1 = A_1 * B_1 + C_1; + * } + * } + * C[x] = C_1; + * } + */ + + std::ostringstream oss; + oss << *stmt; + + const std::string& verification_pattern = + R"IR( +# CHECK: for (int x +# CHECK: int A_1 = A[x]; +# CHECK: int C_1 = C[x]; +# CHECK: for (int y +# CHECK: int B_1 = B[y]; +# CHECK: for (int z +# CHECK: C_1 = A_1 * B_1 + C_1; +# CHECK: } +# CHECK: } +# CHECK: C[x] = C_1; +# CHECK: })IR"; + + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_simplify.cpp b/test/cpp/tensorexpr/test_simplify.cpp new file mode 100644 index 0000000000000..99a00d0d62c11 --- /dev/null +++ b/test/cpp/tensorexpr/test_simplify.cpp @@ -0,0 +1,5680 @@ +#include +#include + +#include +#include +#include +#include +#include + +#include + +namespace torch { +namespace jit { +using namespace torch::jit::tensorexpr; +using SimpleIRExprEval = ExprEval; + +TEST(Simplify, ConstantFoldSimple) { + ExprHandle a(2.0f); + ExprHandle b(3.0f); + ExprHandle f = (a + b); + + ExprHandle newF = IRSimplifier::simplify(f); + ASSERT_NE(newF.AsNode(), nullptr); + ASSERT_EQ(newF.AsNode()->value(), 5); + + SimpleIRExprEval eval(newF); + ASSERT_EQ(eval.value(), 5.f); +} + +TEST(Simplify, ConstantFoldTwoLayer) { + ExprHandle a(2.0f); + ExprHandle b(3.0f); + ExprHandle c(4.0f); + ExprHandle d(5.0f); + ExprHandle f = (a + b) - (c + d); + + ExprHandle newF = IRSimplifier::simplify(f); + ASSERT_NE(newF.AsNode(), nullptr); + ASSERT_EQ(newF.AsNode()->value(), -4); + + SimpleIRExprEval eval(newF); + ASSERT_EQ(eval.value(), -4.f); +} + +TEST(Simplify, ConstantFoldShifts) { + ExprHandle a(7); + ExprHandle b(2); + ExprHandle c(3); + ExprHandle f = ((a << b) << b) >> c; + + ExprHandle newF = IRSimplifier::simplify(f); + ASSERT_NE(newF.AsNode(), nullptr); + ASSERT_EQ(newF.AsNode()->value(), 14); + + SimpleIRExprEval eval(newF); + ASSERT_EQ(eval.value(), 7 << (4 - 3)); +} + +TEST(Simplify, ConstantFoldBitwise) { + ExprHandle a(59); + ExprHandle b(22); + ExprHandle c(101); + ExprHandle f = (a ^ b) & c; + + ExprHandle newF = IRSimplifier::simplify(f); + ASSERT_NE(newF.AsNode(), nullptr); + ASSERT_EQ(newF.AsNode()->value(), 37); + + SimpleIRExprEval eval(newF); + ASSERT_EQ(eval.value(), (59 ^ 22) & 101); +} + +TEST(Simplify, ConstantFoldMultiOp) { + ExprHandle a(2.0f); + ExprHandle b(3.0f); + ExprHandle c(4.0f); + ExprHandle d(5.0f); + ExprHandle e(6.0f); + ExprHandle f(7.0f); + ExprHandle fn = ((a / e) - (c + d)) * (f / b); + + ExprHandle newF = IRSimplifier::simplify(fn); + ASSERT_NE(newF.AsNode(), nullptr); + + SimpleIRExprEval eval(newF); + SimpleIRExprEval ref(fn); + + ASSERT_EQ(eval.value(), ref.value()); +} + +TEST(Simplify, ConstantFoldMinMax) { + ExprHandle a(12.0f); + ExprHandle b(15.0f); + ExprHandle c(17.0f); + + // x = max(12, min(15, 17)). + ExprHandle minHandle = Min::make(b, c, true); + ExprHandle fn = Max::make(a, minHandle, false); + + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + ASSERT_EQ(fn.dtype().scalar_type(), ScalarType::Float); + + ExprHandle newF = IRSimplifier::simplify(fn); + ASSERT_NE(newF.AsNode(), nullptr); + + SimpleIRExprEval eval(newF); + ASSERT_EQ(eval.value(), 15.f); +} + +TEST(Simplify, ConstantFoldIntrinsics) { + ExprHandle a(2.0f); + ExprHandle b(3.0f); + ExprHandle c(4.0f); + ExprHandle powHandle = Intrinsics::make(kPow, a, b); + ExprHandle sinHandle = Intrinsics::make(kSin, powHandle); + ExprHandle modHandle = Intrinsics::make(kFmod, c, sinHandle); + ExprHandle logHandle = Intrinsics::make(kLog10, modHandle); + ExprHandle rndHandle = Intrinsics::make(kRound, logHandle); + ExprHandle fn = Intrinsics::make(kAbs, rndHandle); + + ExprHandle newF = IRSimplifier::simplify(fn); + ASSERT_NE(newF.AsNode(), nullptr); + ASSERT_EQ(newF.AsNode()->value(), 1); + + SimpleIRExprEval eval(newF); + SimpleIRExprEval ref(fn); + + ASSERT_EQ(eval.value(), ref.value()); +} + +TEST(Simplify, ConstantFoldCastToBool) { + ExprHandle f = Cast::make(kBool, IntImm::make(0)); + ExprHandle newF = IRSimplifier::simplify(f); + SimpleIRExprEval eval(newF); + ASSERT_EQ(eval.value(), false); +} + +TEST(Simplify, ConstantFoldWithVar) { + { + VarHandle x("x", kInt); + ExprHandle body = x * (ExprHandle(2) + ExprHandle(4)); + + ExprHandle newF = IRSimplifier::simplify(body); + MulPtr root = newF.AsNode(); + ASSERT_NE(root, nullptr); + ASSERT_NE(to(root->lhs()), nullptr); + + SimpleIRExprEval eval(newF); + eval.bindVar(x, ExprHandle(3)); + ASSERT_EQ(eval.value(), 3 * (2 + 4)); + } + + { + VarHandle x("x", kFloat); + ExprHandle body = x * (ExprHandle(2.f) + ExprHandle(4.f)); + + ExprHandle newF = IRSimplifier::simplify(body); + MulPtr root = newF.AsNode(); + ASSERT_NE(root, nullptr); + ASSERT_NE(to(root->rhs()), nullptr); + + SimpleIRExprEval eval(newF); + eval.bindVar(x, ExprHandle(3.f)); + ASSERT_EQ(eval.value(), 3 * (2 + 4)); + } +} + +TEST(Simplify, ConditionalSelectFoldSimple) { + ExprHandle a(3.0f); + ExprHandle b(4.0f); + ExprHandle c(3.0f); + { + ExprHandle f = (a > b); + + ExprHandle newF = IRSimplifier::simplify(f); + ASSERT_NE(newF.AsNode(), nullptr); + ASSERT_EQ(newF.AsNode()->value(), 0); + + SimpleIRExprEval eval(newF); + ASSERT_EQ(eval.value(), 0); + } + { + ExprHandle f = (a < b); + + ExprHandle newF = IRSimplifier::simplify(f); + ASSERT_NE(newF.AsNode(), nullptr); + ASSERT_EQ(newF.AsNode()->value(), 1); + + SimpleIRExprEval eval(newF); + ASSERT_EQ(eval.value(), 1); + } + { + ExprHandle f = (a == c); + + ExprHandle newF = IRSimplifier::simplify(f); + ASSERT_NE(newF.AsNode(), nullptr); + ASSERT_EQ(newF.AsNode()->value(), 1); + + SimpleIRExprEval eval(newF); + ASSERT_EQ(eval.value(), 1); + } + { + ExprHandle f = (a != c); + + ExprHandle newF = IRSimplifier::simplify(f); + ASSERT_NE(newF.AsNode(), nullptr); + ASSERT_EQ(newF.AsNode()->value(), 0); + + SimpleIRExprEval eval(newF); + ASSERT_EQ(eval.value(), 0); + } +} + +TEST(Simplify, ConditionalSelectFoldTwoLayer) { + ExprHandle a(3.0f); + ExprHandle b(2.0f); + ExprHandle c(2.0f); + ExprHandle d(1.0f); + { + ExprHandle f = (a + b < c + d); + + ExprHandle newF = IRSimplifier::simplify(f); + ASSERT_NE(newF.AsNode(), nullptr); + ASSERT_EQ(newF.AsNode()->value(), 0); + + SimpleIRExprEval eval(newF); + ASSERT_EQ(eval.value(), 0); + } + { + ExprHandle f = (a + b > c + d); + + ExprHandle newF = IRSimplifier::simplify(f); + ASSERT_NE(newF.AsNode(), nullptr); + ASSERT_EQ(newF.AsNode()->value(), 1); + + SimpleIRExprEval eval(newF); + ASSERT_EQ(eval.value(), 1); + } + { + ExprHandle f = (a + d == b + c); + + ExprHandle newF = IRSimplifier::simplify(f); + ASSERT_NE(newF.AsNode(), nullptr); + ASSERT_EQ(newF.AsNode()->value(), 1); + + SimpleIRExprEval eval(newF); + ASSERT_EQ(eval.value(), 1); + } + { + ExprHandle f = (a + d != b + c); + + ExprHandle newF = IRSimplifier::simplify(f); + ASSERT_NE(newF.AsNode(), nullptr); + ASSERT_EQ(newF.AsNode()->value(), 0); + + SimpleIRExprEval eval(newF); + ASSERT_EQ(eval.value(), 0); + } +} + +TEST(Simplify, ConditionalSelectFoldWithVar) { + VarHandle x("x", kFloat); + ExprHandle f = x < 4.f; + + ExprHandle newF = IRSimplifier::simplify(f); + IntImmPtr folded = newF.AsNode(); + ASSERT_EQ(folded, nullptr); + + { + SimpleIRExprEval eval(newF); + eval.bindVar(x, ExprHandle(3.f)); + ASSERT_EQ(eval.value(), 1); + } + { + SimpleIRExprEval eval(newF); + eval.bindVar(x, ExprHandle(5.f)); + ASSERT_EQ(eval.value(), 0); + } +} + +TEST(Simplify, UnFoldableExpr) { + VarHandle x("x", kFloat); + VarHandle y("y", kFloat); + ExprHandle body = (ExprHandle(3) * x) + (ExprHandle(5) * y); + + ExprHandle newF = IRSimplifier::simplify(body); + AddPtr root = newF.AsNode(); + ASSERT_NE(root, nullptr); + ASSERT_EQ(to(root->lhs()), nullptr); + ASSERT_EQ(to(root->rhs()), nullptr); + + SimpleIRExprEval eval(newF); + eval.bindVar(x, ExprHandle(3.f)); + eval.bindVar(y, ExprHandle(2.f)); + ASSERT_EQ(eval.value(), 9 + 10); +} + +TEST(Simplify, HashSimple) { + VarHandle x("x", kFloat); + ExprHandle a(2.0f); + ExprHandle b(3.0f); + ExprHandle f = a + b * x; + + HashProvider hasher; + + auto hash_x = hasher.hash(x.node()); + auto hash_a = hasher.hash(a.node()); + auto hash_f = hasher.hash(f.node()); + + ASSERT_NE(hash_x, (size_t)0); + ASSERT_NE(hash_a, (size_t)0); + ASSERT_NE(hash_f, (size_t)0); + ASSERT_NE(hash_x, hash_a); + ASSERT_NE(hash_x, hash_f); + ASSERT_NE(hash_a, hash_f); +} + +TEST(Simplify, HashEquivalence) { + VarHandle x("x", kFloat); + VarHandle y("y", kFloat); + ExprHandle f = (x * y) + (x * y); + + AddPtr root = f.AsNode(); + ASSERT_NE(root, nullptr); + + HashProvider hasher; + auto hash_f = hasher.hash(f.node()); + auto hash_l = hasher.hash(root->lhs()); + auto hash_r = hasher.hash(root->rhs()); + + // Root not equal to either branch. + ASSERT_NE(hash_f, hash_l); + ASSERT_NE(hash_f, hash_r); + // but branches are equal. + ASSERT_EQ(hash_l, hash_r); + + // Still equivalent if separate. + ExprHandle a(2); + ExprHandle f2 = x + a / y; + ExprHandle b(2); + ExprHandle f3 = x + b / y; + ASSERT_EQ(hasher.hash(f2.node()), hasher.hash(f3.node())); + + // Not equivalent if different vars (even with same name). + VarHandle z("x", kFloat); + ExprHandle f4 = z + b / y; + ASSERT_NE(hasher.hash(f2.node()), hasher.hash(f4.node())); + + // Intrinsics sanity check. + ExprHandle f5 = Intrinsics::make(kSin, x) * Intrinsics::make(kCos, x); + ASSERT_NE(hasher.hash(f5.node()), (size_t)0); +} + +TEST(Simplify, HashEquivalenceRand) { + ExprHandle f = + Intrinsics::make(kRand, kFloat) + Intrinsics::make(kRand, kInt); + + AddPtr root = f.AsNode(); + ASSERT_NE(root, nullptr); + + HashProvider hasher; + auto hash_f = hasher.hash(f.node()); + auto hash_l = hasher.hash(root->lhs()); + auto hash_r = hasher.hash(root->rhs()); + + // Root not equal to either branch. + ASSERT_NE(hash_f, hash_l); + ASSERT_NE(hash_f, hash_r); + // and branches are NOT equal. + ASSERT_NE(hash_l, hash_r); +} + +TEST(Simplify, HashEquivalenceAfterFolding) { + VarHandle x("x", kFloat); + ExprHandle a(2.0f); + ExprHandle b(3.0f); + ExprHandle c(5.0f); + + ExprHandle f1 = ((a + b) * x); + ExprHandle f2 = (c * x); + + HashProvider hasher; + auto hash_l = hasher.hash(f1.node()); + auto hash_r = hasher.hash(f2.node()); + + // Root not equal to either branch, and branches not equal. + ASSERT_NE(hash_l, hash_r); + + ExprHandle ff1 = IRSimplifier::simplify(f1); + ExprHandle ff2 = IRSimplifier::simplify(f2); + + auto hash_l_n = hasher.hash(ff1.node()); + auto hash_r_n = hasher.hash(ff2.node()); + // but branches are now equal. + ASSERT_EQ(hash_l_n, hash_r_n); +} + +TEST(Simplify, HashDifferenceTypes) { + HashProvider hasher; + std::vector immediates; + + immediates.push_back(alloc(1)); + immediates.push_back(alloc(1)); + immediates.push_back(alloc(1)); + // NOLINTNEXTLINE(modernize-use-bool-literals) + immediates.push_back(alloc(1)); + immediates.push_back(alloc(1)); + immediates.push_back(alloc(1)); + immediates.push_back(alloc(1)); + immediates.push_back(alloc(1)); + immediates.push_back(alloc(1)); + + // Immediates of different types are not equal. + for (unsigned int i = 0; i < immediates.size(); ++i) { + for (unsigned int j = i + 1; j < immediates.size(); ++j) { + ASSERT_NE(hasher.hash(immediates[i]), hasher.hash(immediates[j])); + } + } + + // But coerced immediates are if they are the same type: + ExprHandle f1 = ExprHandle(2.f) + CharImm::make(1); + ExprHandle f2 = Cast::make(kFloat, IntImm::make(3)); + + ExprHandle ff1 = IRSimplifier::simplify(f1); + ExprHandle ff2 = IRSimplifier::simplify(f2); + + ASSERT_EQ(hasher.hash(ff1.node()), hasher.hash(ff2.node())); +} + +TEST(Simplify, HashLargeExpression) { + constexpr int N = 1024; + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + BufHandle c("C", {N}, kInt); + VarHandle i("i", kInt); + auto memcpy_stmt = For::make( + i, + 0, + N, + Store::make( + c, + {i}, + CompareSelect::make( + Load::make(a, {i}), + Load::make(b, {i}), + CompareSelectOperation::kEQ))); + + BufHandle d("D", {1}, kInt); + BufHandle e("E", {1}, kInt); + auto store_ramp_stmt = Store::make( + e, {Ramp::make(0, 1, 4)}, Load::make(d, {Ramp::make(0, 1, 4)})); + + auto if_stmt = Cond::make( + CompareSelect::make( + Load::make(a, {i}), Load::make(b, {i}), CompareSelectOperation::kGE), + memcpy_stmt, + store_ramp_stmt); + + HashProvider hasher; + auto hash_r = hasher.hash(if_stmt); + // We should not have to do any more work. + ASSERT_TRUE(hasher.cachedHash(memcpy_stmt)); + auto hash_t = hasher.hash(memcpy_stmt); + ASSERT_TRUE(hasher.cachedHash(store_ramp_stmt)); + auto hash_f = hasher.hash(store_ramp_stmt); + + // Root not equal to either branch, and branches not equal. + ASSERT_NE(hash_r, hash_t); + ASSERT_NE(hash_r, hash_f); + ASSERT_NE(hash_t, hash_f); +} + +TEST(Simplify, HashForLoopOptions) { + constexpr int N = 1024; + BufHandle a("A", {N}, kInt); + BufHandle b("B", {N}, kInt); + BufHandle c("C", {N}, kInt); + VarHandle i("i", kInt); + auto for_stmt = For::make( + i, + 0, + N, + Store::make( + c, + {i}, + CompareSelect::make( + Load::make(a, {i}), + Load::make(b, {i}), + CompareSelectOperation::kEQ))); + + HashProvider hasher; + auto hash_before = hasher.hash(for_stmt); + hasher.clearCache(); + + for_stmt->set_gpu_block_index(LoopOptions::IDX_X); + auto hash_block_idx = hasher.hash(for_stmt); + hasher.clearCache(); + + ASSERT_NE(hash_before, hash_block_idx); + + for_stmt->set_gpu_block_index(LoopOptions::IDX_UNSET); + auto hash_reset = hasher.hash(for_stmt); + hasher.clearCache(); + + ASSERT_EQ(hash_before, hash_reset); + for_stmt->set_gpu_thread_index(LoopOptions::IDX_X); + auto hash_thread_idx = hasher.hash(for_stmt); + + ASSERT_NE(hash_before, hash_thread_idx); + ASSERT_NE(hash_block_idx, hash_thread_idx); +} + +/// (2 + x) + 4 => x + 6 +TEST(Simplify, SimplifyAdd) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + VarHandle m("m", kInt); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + VarHandle n("n", kInt); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + VarHandle n_1("n_1", kInt); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + ExprHandle body = (ExprHandle(2) + x) + ExprHandle(4); + + ExprHandle simplified = IRSimplifier::simplify(body); + AddPtr root = simplified.AsNode(); + ASSERT_NE(root, nullptr); + VarPtr lhs = to(root->lhs()); + ASSERT_NE(lhs, nullptr); + ASSERT_EQ(lhs->name_hint(), "x"); + IntImmPtr rhs = to(root->rhs()); + ASSERT_NE(rhs, nullptr); + ASSERT_EQ(rhs->value(), 6.f); +} + +/// (2 - x) - 4 => -2 - x +TEST(Simplify, SimplifySub) { + VarHandle x("x", kInt); + ExprHandle body = (ExprHandle(2) - x) - ExprHandle(4); + + ExprHandle simplified = IRSimplifier::simplify(body); + SubPtr root = simplified.AsNode(); + ASSERT_NE(root, nullptr); + IntImmPtr lhs = to(root->lhs()); + ASSERT_NE(lhs, nullptr); + ASSERT_EQ(lhs->value(), -2.f); + VarPtr rhs = to(root->rhs()); + ASSERT_NE(rhs, nullptr); + ASSERT_EQ(rhs->name_hint(), "x"); +} + +/// 2 * (1 - x) - 4 => 2 * (-3 - x) +TEST(Simplify, SimplifyMultiLayer) { + VarHandle x("x", kInt); + ExprHandle body = ExprHandle(2) * ((ExprHandle(1) - x) - ExprHandle(4)); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + IS_NODE_WITH_NAME(Sub, mul->rhs(), sub); + IS_IMM_WITH_VAL(Int, sub->lhs(), -3); + IS_VAR_WITH_NAME(sub->rhs(), "x"); +} + +/// 2 * (3 * x) - (x * 4) => 2 * x +TEST(Simplify, SimplifyMultiTerm) { + VarHandle x("x", kInt); + ExprHandle body = + (ExprHandle(2) * ((ExprHandle(3) * x)) - (x * ExprHandle(4))); + + ExprHandle simplified = IRSimplifier::simplify(body); + MulPtr root = simplified.AsNode(); + ASSERT_NE(root, nullptr); + IntImmPtr lhs = to(root->lhs()); + ASSERT_NE(lhs, nullptr); + ASSERT_EQ(lhs->value(), 2); + VarPtr rhs = to(root->rhs()); + ASSERT_NE(rhs, nullptr); + ASSERT_EQ(rhs->name_hint(), "x"); +} + +/// 2 * (3 * (long)x) - (x * 4) => 2 * x +TEST(Simplify, SimplifyCasts) { + VarHandle x("x", kLong); + ExprHandle body = + (ExprHandle(2) * ((ExprHandle(3) * x)) - (x * ExprHandle(4))); + + ExprHandle simplified = IRSimplifier::simplify(body); + MulPtr root = simplified.AsNode(); + ASSERT_NE(root, nullptr); + LongImmPtr lhs = to(root->lhs()); + ASSERT_NE(lhs, nullptr); + ASSERT_EQ(lhs->value(), 2); + VarPtr rhs = to(root->rhs()); + ASSERT_NE(rhs, nullptr); + ASSERT_EQ(rhs->name_hint(), "x"); +} + +/// (x + 0) * 1 => x +TEST(Simplify, SimplifyEliminatesNoOps) { + VarHandle x("x", kInt); + ExprHandle body = (x + ExprHandle(0)) * 1; + + ExprHandle simplified = IRSimplifier::simplify(body); + VarPtr root = simplified.AsNode(); + ASSERT_NE(root, nullptr); + ASSERT_EQ(root->name_hint(), "x"); +} + +/// Cannot simplify this. +TEST(Simplify, SimplifyMultiVar) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = x * 24 + y * 34; + + ExprHandle simplified = IRSimplifier::simplify(body); + + AddPtr root = simplified.AsNode(); + ASSERT_NE(root, nullptr); + MulPtr lhs = to(root->lhs()); + ASSERT_NE(lhs, nullptr); + VarPtr varX = to(lhs->rhs()); + ASSERT_NE(varX, nullptr); + ASSERT_EQ(varX->name_hint(), "x"); + MulPtr rhs = to(root->rhs()); + ASSERT_NE(rhs, nullptr); + VarPtr varY = to(rhs->rhs()); + ASSERT_NE(varY, nullptr); + ASSERT_EQ(varY->name_hint(), "y"); +} + +// x + 2 + y => x + y + 2 +TEST(Simplify, DISABLED_SimplifyReorderings) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = x + 2 + y; + ExprHandle simplified = IRSimplifier::simplify(body); + + AddPtr root = simplified.AsNode(); + ASSERT_NE(root, nullptr); + + IS_NODE_WITH_NAME(Add, root->lhs(), rhs); + IS_VAR_WITH_NAME(rhs->lhs(), "x"); + IS_VAR_WITH_NAME(rhs->rhs(), "y"); + IS_IMM_WITH_VAL(Int, root->rhs(), 2); +} + +/// y + x * 0 => y +TEST(Simplify, SimplifyEliminatesVar) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = y + x * ExprHandle(0); + + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "y"); +} + +TEST(Simplify, SimplifyAdds) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + { + // (x + y) + (x + y) => 2 * (x + y) + ExprHandle body = (x + y) + (x + y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), root); + IS_IMM_WITH_VAL(Int, root->lhs(), 2); + IS_NODE_WITH_NAME(Add, root->rhs(), add); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_VAR_WITH_NAME(add->rhs(), "y"); + } + + { + // (x * y) + (x * y) => 2 * (x * y) + ExprHandle body = (x * y) + (x * y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), root); + IS_IMM_WITH_VAL(Int, root->lhs(), 2); + IS_NODE_WITH_NAME(Mul, root->rhs(), mul); + IS_VAR_WITH_NAME(mul->lhs(), "x"); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + } + + { + // (x - y) + (x - y) => 2 * (x - y) + ExprHandle body = (x - y) + (x - y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + + IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs); + IS_VAR_WITH_NAME(rhs->lhs(), "x"); + IS_VAR_WITH_NAME(rhs->rhs(), "y"); + } + + { + // (x + x + x + x) => 4 * x + ExprHandle body = (x + x + x + x); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), root); + IS_IMM_WITH_VAL(Int, root->lhs(), 4); + IS_VAR_WITH_NAME(root->rhs(), "x"); + } + + { + // (x + 0) => x. + ExprHandle body = x + 0; + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_VAR_WITH_NAME(simplified.node(), "x"); + } + + { + // (x + 0.f) => float(x). + ExprHandle body = x + 0.f; + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Cast, simplified.node(), cast); + ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); + IS_VAR_WITH_NAME(cast->src_value(), "x"); + } +} + +TEST(Simplify, SimplifyMuls) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + { + // (x + y) * (x + y) => (x + y) * (x + y) + // We don't attempt to simplify multiplication of polynomials since the + // result is only very rarely more efficient. + ExprHandle body = (x + y) * (x + y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_NODE_WITH_NAME(Add, mul->lhs(), lhs); + IS_VAR_WITH_NAME(lhs->lhs(), "x"); + IS_VAR_WITH_NAME(lhs->rhs(), "y"); + IS_NODE_WITH_NAME(Add, mul->rhs(), rhs); + IS_VAR_WITH_NAME(rhs->lhs(), "x"); + IS_VAR_WITH_NAME(rhs->rhs(), "y"); + } + + { + // x * y * x * y => x * x * y * y + // These get reordered only. + ExprHandle body = x * y * x * y; + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul1); + IS_NODE_WITH_NAME(Mul, mul1->lhs(), mul2); + IS_NODE_WITH_NAME(Mul, mul2->lhs(), mul3); + IS_VAR_WITH_NAME(mul1->rhs(), "y"); + IS_VAR_WITH_NAME(mul2->rhs(), "y"); + IS_VAR_WITH_NAME(mul3->lhs(), "x"); + IS_VAR_WITH_NAME(mul3->rhs(), "x"); + } + + { + // 1 * (x * 1) => x + // Ones cancel cleanly. + ExprHandle body = ExprHandle(1) * (x * ExprHandle(1)); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_VAR_WITH_NAME(simplified.node(), "x"); + } + + { + // 1.f * (x * 1.f) => x + // Even float ones cancel cleanly, but carry their type. + ExprHandle body = ExprHandle(1.f) * (x * ExprHandle(1.f)); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Cast, simplified.node(), cast); + ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); + IS_VAR_WITH_NAME(cast->src_value(), "x"); + } + + { + // 1 * (x * 1.f) => x + // One float is enough to cast the expr. + ExprHandle body = ExprHandle(1) * (x * ExprHandle(1.f)); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Cast, simplified.node(), cast); + ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); + IS_VAR_WITH_NAME(cast->src_value(), "x"); + } + + { + // 1 * (x * 0) => 0 + // Zeroes are eliminated. + ExprHandle body = ExprHandle(1) * (x * ExprHandle(0)); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // 1 * (x * 0) => 0 + // But not for Float since nan * 0 = nan. + ExprHandle body = ExprHandle(1.f) * (x * ExprHandle(0.f)); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_NODE_WITH_NAME(Cast, mul->lhs(), cast); + ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); + IS_VAR_WITH_NAME(cast->src_value(), "x"); + IS_IMM_WITH_VAL(Float, mul->rhs(), 0.0); + } + + { + // (x - y) * (x - y) => (x - y) * (x - y) + // As with Add we don't attempt simplification of this. + ExprHandle body = (x - y) * (x - y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_NODE_WITH_NAME(Sub, mul->lhs(), lhs); + IS_VAR_WITH_NAME(lhs->lhs(), "x"); + IS_VAR_WITH_NAME(lhs->rhs(), "y"); + IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs); + IS_VAR_WITH_NAME(rhs->lhs(), "x"); + IS_VAR_WITH_NAME(rhs->rhs(), "y"); + } + + { + // (x + y) * (x - y) => (x + y) * (x - y) + // Don't simplify with different ops on each side. + ExprHandle body = (x + y) * (x - y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_NODE_WITH_NAME(Add, mul->lhs(), lhs); + IS_VAR_WITH_NAME(lhs->lhs(), "x"); + IS_VAR_WITH_NAME(lhs->rhs(), "y"); + IS_NODE_WITH_NAME(Sub, mul->rhs(), rhs); + IS_VAR_WITH_NAME(rhs->lhs(), "x"); + IS_VAR_WITH_NAME(rhs->rhs(), "y"); + } + + { + // Multiply a polynomial by a term. + // - term with no scalar, poly with non-identity scalar. + // x * (y + 1) => x + x * y + ExprHandle body = x * (y + ExprHandle(1)); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Add, simplified.node(), add); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_NODE_WITH_NAME(Mul, add->rhs(), mul); + IS_VAR_WITH_NAME(mul->lhs(), "x"); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + } + + { + // Multiply a polynomial by a term. + // - term with identity scalar, poly with non-identity scalar. + // (x * 1) * (y + 1) => x + x * y + ExprHandle body = (x * ExprHandle(1)) * (y + ExprHandle(1)); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Add, simplified.node(), add); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_NODE_WITH_NAME(Mul, add->rhs(), mul); + IS_VAR_WITH_NAME(mul->lhs(), "x"); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + } + + { + // Multiply a polynomial by a term. + // - term with non-identity scalar, poly with non-identity scalar. + // (x * 2) * (y + 1) => 2 * (x + x * y) + ExprHandle body = (x * ExprHandle(2)) * (y + ExprHandle(1)); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + IS_NODE_WITH_NAME(Add, mul->rhs(), add); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_NODE_WITH_NAME(Mul, add->rhs(), mul2); + IS_VAR_WITH_NAME(mul2->lhs(), "x"); + IS_VAR_WITH_NAME(mul2->rhs(), "y"); + } + + { + // Multiply a polynomial by a term. + // - term with non-identity scalar, poly with identity scalar. + // (x * 2) * (y + 0) => 2 * (x * y) + ExprHandle body = (x * ExprHandle(2)) * (y + ExprHandle(0)); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + IS_NODE_WITH_NAME(Mul, mul->rhs(), mul2); + IS_VAR_WITH_NAME(mul2->lhs(), "x"); + IS_VAR_WITH_NAME(mul2->rhs(), "y"); + } + + { + // Multiply a polynomial by a term. + // - term with identity scalar, poly with identity scalar. + // (x * 1) * (y + 0) => x * y + ExprHandle body = (x * ExprHandle(1)) * (y + ExprHandle(0)); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_VAR_WITH_NAME(mul->lhs(), "x"); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + } + + { + // Multiply a polynomial by a term. + // - term with no scalar, poly with identity scalar. + // x * (y + 0) => x * y + ExprHandle body = x * (y + ExprHandle(0)); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_VAR_WITH_NAME(mul->lhs(), "x"); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + } +} + +// Sub an expr from itself will result in zero. +TEST(Simplify, SimplifySubs) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + { + // (x + y) - (x + y) => 0 + ExprHandle body = (x + y) - (x + y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // (x * y) - (x * y) => 0 + ExprHandle body = (x * y) - (x * y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // (x - y) - (x - y) => 0 + ExprHandle body = (x - y) - (x - y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // (x + y) - 2 * (x + y) => -1 * x - y + ExprHandle body = (x + y) - ExprHandle(2) * (x + y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Sub, simplified.node(), sub); + IS_NODE_WITH_NAME(Mul, sub->lhs(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), -1); + IS_VAR_WITH_NAME(mul->rhs(), "x"); + IS_VAR_WITH_NAME(sub->rhs(), "y"); + } + + { + // (x + y) - y => x + ExprHandle body = (x + y) - y; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "x"); + } + + { + // (x - 0) => x. + ExprHandle body = x - 0; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "x"); + } + + { + // (x - 0.f) => x. + // Simple enough to cancel in float. + ExprHandle body = x - ExprHandle(0.f); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Cast, simplified.node(), cast); + ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); + IS_VAR_WITH_NAME(cast->src_value(), "x"); + } + + { + // (x - (float)(y - y)) => x. + ExprHandle body = x - Cast::make(kFloat, y - y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Cast, simplified.node(), cast); + ASSERT_EQ(cast->dtype().scalar_type(), ScalarType::Float); + IS_VAR_WITH_NAME(cast->src_value(), "x"); + } + + { + // (x - y) - y => x - 2 * y + ExprHandle body = (x - y) - y; + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Sub, simplified.node(), sub); + IS_VAR_WITH_NAME(sub->lhs(), "x"); + IS_NODE_WITH_NAME(Mul, sub->rhs(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + } + + { + // 2 * x - x => x + ExprHandle body = (ExprHandle(2) * x) - x; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "x"); + } + + { + // x - 2 * x = -1 * x + // We don't have a unary negate, but this could be 0 -x I guess? + ExprHandle body = x - (ExprHandle(2) * x); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + + IS_IMM_WITH_VAL(Int, mul->lhs(), -1); + IS_VAR_WITH_NAME(mul->rhs(), "x"); + } + + { + // (x + y + 5) * (x - x) => 0 + // Cancelling out one side of Mul cancels both. + ExprHandle body = (x + y + 5) * (x - x); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // Cancel out opaque modulus. + ExprHandle body = (x % y + 2) - (x % y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 2); + } + + { + // Cancel out opaque modulus with a bit more going on. + ExprHandle body = (x % y + (x * 2 - x - y * 0) - x + 2) - (x % y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 2); + } + + { + // Sub where result is negative. + ExprHandle body = x - (x + 1); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), -1); + } + + { + // Sub where result is positive due to negative scalar on RHS. + ExprHandle body = x - (x - 1); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 1); + } + + { + // Term - Polynomial sub where RHS must be negated. + ExprHandle body = (x * 2) - (x * 2 + 1); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), -1); + } + + { + // Term - Polynomial sub where the result is a Term. + ExprHandle body = (y * x * 2) - (x * y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + + IS_VAR_WITH_NAME(mul->lhs(), "x"); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + } + + { + // Term - Polynomial sub where the result is a Polynomial. + ExprHandle body = (x * 2) - (x + 1); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Sub, simplified.node(), sub); + + IS_VAR_WITH_NAME(sub->lhs(), "x"); + IS_IMM_WITH_VAL(Int, sub->rhs(), 1); + } +} + +TEST(Simplify, SimplifyDiv) { + VarHandle x("x", kInt); + + { + ExprHandle body = ExprHandle(0) / x; + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + ExprHandle body = x / 1; + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_VAR_WITH_NAME(simplified.node(), "x"); + } +} + +TEST(Simplify, SimplifyDivWithLoopContext0) { + // Stmt to simplify: + // for (int i = 0; i < 100; i++) { + // A[i] = i / 100; + //} + VarHandle i("i", kInt); + BufHandle a_buf("A", {100}, kInt); + auto for_stmt = For::make(i, 0, 100, Store::make(a_buf, {i}, (i / 100))); + + const StmtPtr simplified = IRSimplifier::simplify(for_stmt); + + std::ostringstream oss; + oss << *(simplified); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: A[i] = 0; + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Simplify, SimplifyDivWithLoopContext1) { + // Stmt to simplify: + // for (const auto i : c10::irange(6)) { + // A[i] = (i + 24) / 6; + //} + VarHandle i("i", kInt); + BufHandle a_buf("A", {6}, kInt); + auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) / 6)); + + const StmtPtr simplified = IRSimplifier::simplify(for_stmt); + + std::ostringstream oss; + oss << *(simplified); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: A[i] = 4; + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Simplify, SimplifyDivWithLoopContext2) { + // Stmt to simplify: + // for (const auto i : c10::irange(5)) { + // A[i] = (i + 25) / 6; + //} + VarHandle i("i", kInt); + BufHandle a_buf("A", {5}, kInt); + auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + 25) / 6)); + + const StmtPtr simplified = IRSimplifier::simplify(for_stmt); + + std::ostringstream oss; + oss << *(simplified); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: A[i] = 4; + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Simplify, SimplifyDivWithLoopContext3) { + // Stmt to simplify: + // for (const auto i : c10::irange(6)) { + // A[i] = (i + 24) / (-6); + //} + VarHandle i("i", kInt); + BufHandle a_buf("A", {6}, kInt); + auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) / (-6))); + + const StmtPtr simplified = IRSimplifier::simplify(for_stmt); + + std::ostringstream oss; + oss << *(simplified); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NOT: A[i] = -4; + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Simplify, SimplifyDivWithLoopContext4) { + // Stmt to simplify: + // for (const auto i : c10::irange(5)) { + // A[i] = (i - 5) / 6; + //} + VarHandle i("i", kInt); + BufHandle a_buf("A", {5}, kInt); + auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + (-5)) / 6)); + + const StmtPtr simplified = IRSimplifier::simplify(for_stmt); + + std::ostringstream oss; + oss << *(simplified); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NOT: A[i] = 0; + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Simplify, SimplifyDivWithLoopContext5) { + // Stmt to simplify: + // for (const auto i : c10::irange(6)) { + // for (const auto j : c10::irange(10)) { + // A[i, j] = (i + 6*j) / 6; + // } + //} + VarHandle i("i", kInt); + VarHandle j("j", kInt); + BufHandle a_buf("A", {6, 10}, kInt); + auto for_j = For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) / 6)); + auto for_i = For::make(i, 0, 6, for_j); + + const StmtPtr simplified = IRSimplifier::simplify(for_i); + + std::ostringstream oss; + oss << *(simplified); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK: for (int j +# CHECK-NEXT: A[i, j] = j; + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Simplify, SimplifyDivWithLoopContext6) { + // Stmt to simplify: + // for (const auto i : c10::irange(6)) { + // for (int j = -1; j < 9; j++) { + // A[i, j+1] = (i + 6*j) / 6; + // } + //} + VarHandle i("i", kInt); + VarHandle j("j", kInt); + BufHandle a_buf("A", {6, 10}, kInt); + auto for_j = + For::make(j, -1, 9, Store::make(a_buf, {i, j + 1}, (i + j * 6) / 6)); + auto for_i = For::make(i, 0, 6, for_j); + + const StmtPtr simplified = IRSimplifier::simplify(for_i); + + std::ostringstream oss; + oss << *(simplified); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK: for (int j +# CHECK-NOT: A[i, j] = j; + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Simplify, SimplifyDivWithLoopContext7) { + // Stmt to simplify: + // for (const auto i : c10::irange(6)) { + // for (const auto j : c10::irange(10)) { + // A[i, j] = (i + 6*j) / (-6); + // } + //} + VarHandle i("i", kInt); + VarHandle j("j", kInt); + BufHandle a_buf("A", {6, 10}, kInt); + auto for_j = + For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) / (-6))); + auto for_i = For::make(i, 0, 6, for_j); + + const StmtPtr simplified = IRSimplifier::simplify(for_i); + + std::ostringstream oss; + oss << *(simplified); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK: for (int j +# CHECK-NOT: A[i, j] = -j; + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Simplify, SimplifyModWithLoopContext0) { + // Stmt to simplify: + // for (const auto i : c10::irange(100)) { + // A[i] = i % 100; + //} + VarHandle i("i", kInt); + BufHandle a_buf("A", {100}, kInt); + auto for_stmt = For::make(i, 0, 100, Store::make(a_buf, {i}, (i % 100))); + + const StmtPtr simplified = IRSimplifier::simplify(for_stmt); + + std::ostringstream oss; + oss << *(simplified); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: A[i] = i; + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Simplify, SimplifyModWithLoopContext1) { + // Stmt to simplify: + // for (const auto i : c10::irange(6)) { + // A[i] = (i + 24) % 6; + //} + VarHandle i("i", kInt); + BufHandle a_buf("A", {6}, kInt); + auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) % 6)); + + const StmtPtr simplified = IRSimplifier::simplify(for_stmt); + + std::ostringstream oss; + oss << *(simplified); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: A[i] = i; + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Simplify, SimplifyModWithLoopContext2) { + // Stmt to simplify: + // for (const auto i : c10::irange(5)) { + // A[i] = (i + 25) % 6; + //} + VarHandle i("i", kInt); + BufHandle a_buf("A", {5}, kInt); + auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + 25) % 6)); + + const StmtPtr simplified = IRSimplifier::simplify(for_stmt); + + std::ostringstream oss; + oss << *(simplified); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: A[i] = i + 1; + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Simplify, SimplifyModWithLoopContext3) { + // Stmt to simplify: + // for (const auto i : c10::irange(6)) { + // A[i] = (i + 24) % (-6); + //} + VarHandle i("i", kInt); + BufHandle a_buf("A", {6}, kInt); + auto for_stmt = For::make(i, 0, 6, Store::make(a_buf, {i}, (i + 24) % (-6))); + + const StmtPtr simplified = IRSimplifier::simplify(for_stmt); + + std::ostringstream oss; + oss << *(simplified); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NOT: A[i] = i; + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Simplify, SimplifyModWithLoopContext4) { + // Stmt to simplify: + // for (const auto i : c10::irange(5)) { + // A[i] = (i - 5) % 6; + //} + VarHandle i("i", kInt); + BufHandle a_buf("A", {5}, kInt); + auto for_stmt = For::make(i, 0, 5, Store::make(a_buf, {i}, (i + (-5)) % 6)); + + const StmtPtr simplified = IRSimplifier::simplify(for_stmt); + + std::ostringstream oss; + oss << *(simplified); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NOT: A[i] = i - 5; + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Simplify, SimplifyModWithLoopContext5) { + // Stmt to simplify: + // for (const auto i : c10::irange(6)) { + // for (const auto j : c10::irange(10)) { + // A[i, j] = (i + 6*j) % 6; + // } + //} + VarHandle i("i", kInt); + VarHandle j("j", kInt); + BufHandle a_buf("A", {6, 10}, kInt); + auto for_j = For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) % 6)); + auto for_i = For::make(i, 0, 6, for_j); + + const StmtPtr simplified = IRSimplifier::simplify(for_i); + + std::ostringstream oss; + oss << *(simplified); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK: for (int j +# CHECK-NEXT: A[i, j] = i; + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Simplify, SimplifyModWithLoopContext6) { + // Stmt to simplify: + // for (const auto i : c10::irange(6)) { + // for (int j = -1; j < 9; j++) { + // A[i, j+1] = (i + 6*j) % 6; + // } + //} + VarHandle i("i", kInt); + VarHandle j("j", kInt); + BufHandle a_buf("A", {6, 10}, kInt); + auto for_j = + For::make(j, -1, 9, Store::make(a_buf, {i, j + 1}, (i + j * 6) % 6)); + auto for_i = For::make(i, 0, 6, for_j); + + const StmtPtr simplified = IRSimplifier::simplify(for_i); + + std::ostringstream oss; + oss << *(simplified); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK: for (int j +# CHECK-NOT: A[i, j] = i; + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Simplify, SimplifyModWithLoopContext7) { + // Stmt to simplify: + // for (const auto i : c10::irange(6)) { + // for (const auto j : c10::irange(10)) { + // A[i, j] = (i + 6*j) % (-6); + // } + //} + VarHandle i("i", kInt); + VarHandle j("j", kInt); + BufHandle a_buf("A", {6, 10}, kInt); + auto for_j = + For::make(j, 0, 10, Store::make(a_buf, {i, j}, (i + j * 6) % (-6))); + auto for_i = For::make(i, 0, 6, for_j); + + const StmtPtr simplified = IRSimplifier::simplify(for_i); + + std::ostringstream oss; + oss << *(simplified); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK: for (int j +# CHECK-NOT: A[i, j] = i; + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + +TEST(Simplify, SimplifyMod) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + + { + // Constant folding works. + ExprHandle body = ExprHandle(10) % 8; + ExprHandle simplified = IRSimplifier::simplify(body); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + IS_IMM_WITH_VAL(Int, simplified.node(), 2); + } + + { + // x % x => 0 + ExprHandle body = x % x; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // 0 % x => 0 + ExprHandle body = ExprHandle(0) % x; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // x % 1 => 0 + ExprHandle body = x % 1; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // Doesn't change unknown mods. + // x % y => x % y + ExprHandle body = x % y; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mod, simplified.node(), mod); + IS_VAR_WITH_NAME(mod->lhs(), "x"); + IS_VAR_WITH_NAME(mod->rhs(), "y"); + } + + { + // don't touch if RHS is unknown. + // 4 % x => 4 % x + ExprHandle body = ExprHandle(4) % x; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mod, simplified.node(), mod); + IS_IMM_WITH_VAL(Int, mod->lhs(), 4); + IS_VAR_WITH_NAME(mod->rhs(), "x"); + } + + { + // don't touch if LHS is unknown. + // x % 4 => x % 4 + ExprHandle body = x % 4; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mod, simplified.node(), mod); + IS_VAR_WITH_NAME(mod->lhs(), "x"); + IS_IMM_WITH_VAL(Int, mod->rhs(), 4); + } + + { + // if LHS is a multiple of RHS, mod is zero. + // 2 * x % x => 0 + ExprHandle body = (x * 2) % x; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // true even if the multiple is not constant. + // x * y % x => 0 + ExprHandle body = (x * y) % x; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // true with multiple unknown values in LHS. + // x * y * z % x => 0 + ExprHandle body = (x * y * z) % x; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // true if the denom is compound. + // x * y * z % y * z => 0 + ExprHandle body = (x * y * z) % (y * z); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // Sanity check true with scalars that are multiples. + // 12 * x % 4 => 0 + ExprHandle body = (x * 12) % 4; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } + + { + // Sanity check not true if the smaller scalar is on LHS. + // 4 * x % 12 => 4 * x % 12 + ExprHandle body = (x * 4) % 12; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mod, simplified.node(), mod); + IS_NODE_WITH_NAME(Mul, mod->lhs(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 4); + IS_VAR_WITH_NAME(mul->rhs(), "x"); + IS_IMM_WITH_VAL(Int, mod->rhs(), 12); + } + + { + // Both scalar and symbolic in multiple. + // (6 * x * y) % (3 * x * y) => 0 + ExprHandle body = (ExprHandle(6) * x * y) % (x * y * 3); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 0); + } +} + +// Test that mixing ops together simplifies as expected. +TEST(Simplify, SimplifyMultiOp) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + { + // (x * y) + (x - y) => (x + x * y) - y + ExprHandle body = (x * y) + (x - y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Sub, simplified.node(), sub); + IS_NODE_WITH_NAME(Add, sub->lhs(), add); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_NODE_WITH_NAME(Mul, add->rhs(), mul); + IS_VAR_WITH_NAME(mul->lhs(), "x"); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + IS_VAR_WITH_NAME(sub->rhs(), "y"); + } + + { + // (x + y) - x * y => (x + y) - x * y + ExprHandle body = (x + y) - x * y; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Sub, simplified.node(), sub); + IS_NODE_WITH_NAME(Add, sub->lhs(), add); + IS_NODE_WITH_NAME(Mul, sub->rhs(), mul); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_VAR_WITH_NAME(add->rhs(), "y"); + IS_VAR_WITH_NAME(mul->lhs(), "x"); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + } + + { + // (x - y) - (x + y) => -2 * y + ExprHandle body = (x - y) - (x + y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), -2); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + } + + { + // (x - 0) + (x * 1) - (x + 0) => x + ExprHandle body = (x - 0) + (x * 1) - (x + 0); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_VAR_WITH_NAME(simplified.node(), "x"); + } + + { + // (x - 0.f) + (x * 1.f) - (x + 0.f) => float(x) + float(x) - float(x) + // Even in Float simple terms cancel out, but the variable ones cannot. + ExprHandle body = + (x - ExprHandle(0.f)) + (x * ExprHandle(1.f)) - (x + ExprHandle(0.f)); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Sub, simplified.node(), sub); + IS_NODE_WITH_NAME(Add, sub->lhs(), add); + IS_NODE_WITH_NAME(Cast, add->lhs(), cast1); + IS_VAR_WITH_NAME(cast1->src_value(), "x"); + IS_NODE_WITH_NAME(Cast, add->rhs(), cast2); + IS_VAR_WITH_NAME(cast2->src_value(), "x"); + IS_NODE_WITH_NAME(Cast, sub->rhs(), cast3); + IS_VAR_WITH_NAME(cast3->src_value(), "x"); + } +} + +// Test that chaining many ops together works as expected. +TEST(Simplify, SimplifyManyOps) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + { + // x + y + x + x + y + y + x + y + x = 4 * y + 5 * x + ExprHandle body = x + y + x + x + y + y + x + y + x; + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Add, simplified.node(), add); + + IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); + IS_IMM_WITH_VAL(Int, lhs->lhs(), 4); + IS_VAR_WITH_NAME(lhs->rhs(), "y"); + + IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); + IS_IMM_WITH_VAL(Int, rhs->lhs(), 5); + IS_VAR_WITH_NAME(rhs->rhs(), "x"); + } + + { + // x - y + x + x - y - y + x - y + x = 5 * x - 4 * y + ExprHandle body = x - y + x + x - y - y + x - y + x; + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Sub, simplified.node(), add); + + IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); + IS_IMM_WITH_VAL(Int, lhs->lhs(), 5); + IS_VAR_WITH_NAME(lhs->rhs(), "x"); + + IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); + IS_IMM_WITH_VAL(Int, rhs->lhs(), 4); + IS_VAR_WITH_NAME(rhs->rhs(), "y"); + } + + { + // x + y + x - x - y - y + x + y + x = 3 * x + ExprHandle body = x + y + x - x - y - y + x + y + x; + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 3); + IS_VAR_WITH_NAME(mul->rhs(), "x"); + } +} + +TEST(Simplify, SimplifyFactorization) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + { + // (2 * x) + (2 * y) => 2 * (x + y) + ExprHandle body = (ExprHandle(2) * x + ExprHandle(2) * y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + + IS_NODE_WITH_NAME(Add, mul->rhs(), add); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_VAR_WITH_NAME(add->rhs(), "y"); + } + + { + // Factorization when scalars have common divider. + // (2 * x) + (4 * y) => 2 * (2 * y + x) + ExprHandle body = (ExprHandle(2) * x + ExprHandle(4) * y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + + IS_NODE_WITH_NAME(Add, mul->rhs(), add); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_NODE_WITH_NAME(Mul, add->rhs(), mul2); + IS_IMM_WITH_VAL(Int, mul2->lhs(), 2); + IS_VAR_WITH_NAME(mul2->rhs(), "y"); + } + + { + // Factorization attempt without a common divider. + // (2 * x) + (5 * y) => (5 * y) + (2 * x) + ExprHandle body = (ExprHandle(2) * x + ExprHandle(5) * y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Add, simplified.node(), add); + + IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); + IS_IMM_WITH_VAL(Int, lhs->lhs(), 2); + IS_VAR_WITH_NAME(lhs->rhs(), "x"); + + IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); + IS_IMM_WITH_VAL(Int, rhs->lhs(), 5); + IS_VAR_WITH_NAME(rhs->rhs(), "y"); + } + + { + // Factorization after merging. + // (2 * x) + (4 * y) + (8 * x + 6 * y) => 10 * (x + y) + ExprHandle body = (ExprHandle(2) * x + ExprHandle(4) * y) + + (ExprHandle(8) * x + ExprHandle(6) * y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 10); + + IS_NODE_WITH_NAME(Add, mul->rhs(), add); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_VAR_WITH_NAME(add->rhs(), "y"); + } + + { + // Factorization with common divider but different signs. + // (2 * x) + (-4 * y) => 2 * (x - 2 * y) + ExprHandle body = (ExprHandle(2) * x + ExprHandle(-4) * y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + + IS_NODE_WITH_NAME(Sub, mul->rhs(), sub); + IS_VAR_WITH_NAME(sub->lhs(), "x"); + IS_NODE_WITH_NAME(Mul, sub->rhs(), mul2); + IS_IMM_WITH_VAL(Int, mul2->lhs(), 2); + IS_VAR_WITH_NAME(mul2->rhs(), "y"); + } + + { + // Factorization with all negative numbers. + // (-2 * x) + (-4 * y) => 2 * (-1 * x - 2 * y) + ExprHandle body = ExprHandle(-2) * x + ExprHandle(-4) * y; + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + + IS_NODE_WITH_NAME(Sub, mul->rhs(), sub); + IS_NODE_WITH_NAME(Mul, sub->lhs(), mul2); + IS_IMM_WITH_VAL(Int, mul2->lhs(), -1); + IS_VAR_WITH_NAME(mul2->rhs(), "x"); + IS_NODE_WITH_NAME(Mul, sub->rhs(), mul3); + IS_IMM_WITH_VAL(Int, mul3->lhs(), 2); + IS_VAR_WITH_NAME(mul3->rhs(), "y"); + } + + { + // The following test ensures that there in no infinite recursion during + // factorization when negative numbers are involved. + VarHandle a("a", kInt); + VarHandle b("b", kInt); + VarHandle c("c", kInt); + VarHandle d("d", kInt); + VarHandle e("e", kInt); + VarHandle f("f", kInt); + VarHandle g("g", kInt); + VarHandle h("h", kInt); + + ExprHandle body = a * 1024 + 0 + b * (-1) + c * (-1) + d * 1 + e * 1 + + f * 32 + g * (-1024) + h * (-32); + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR( + simplified, + "((((((d + e) + 1024 * a) + 32 * f) - b) - c) - 1024 * g) - 32 * h"); + } +} + +// (4 * x + y + z * 2) + (4 * x + y + z * 4) => 2 * (y + 3 * z + 4 * x) +TEST(Simplify, SimplifyFactorizeUneven) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + ExprHandle body = + (ExprHandle(4) * x + y + z * 2) + (ExprHandle(4) * x + y + z * 4); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), root); + IS_IMM_WITH_VAL(Int, root->lhs(), 2); + IS_NODE_WITH_NAME(Add, root->rhs(), add1); + IS_NODE_WITH_NAME(Add, add1->lhs(), add2); + + IS_VAR_WITH_NAME(add2->lhs(), "y"); + IS_NODE_WITH_NAME(Mul, add2->rhs(), zmul); + IS_NODE_WITH_NAME(Mul, add1->rhs(), xmul); + + IS_IMM_WITH_VAL(Int, xmul->lhs(), 4); + IS_VAR_WITH_NAME(xmul->rhs(), "x"); + + IS_IMM_WITH_VAL(Int, zmul->lhs(), 3); + IS_VAR_WITH_NAME(zmul->rhs(), "z"); +} + +// (x * y) + (2 * x) * (x + y) => 2 * (x * x) + 3 * (x * y) +// This is kind of a placeholder test for variable factorization. +TEST(Simplify, SimplifyDeeperTerms) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = (x * y) + (ExprHandle(2) * x) * (x + y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Add, simplified.node(), add); + + IS_NODE_WITH_NAME(Mul, add->lhs(), lhs); + IS_IMM_WITH_VAL(Int, lhs->lhs(), 2); + IS_NODE_WITH_NAME(Mul, lhs->rhs(), xxTerm); + IS_VAR_WITH_NAME(xxTerm->lhs(), "x"); + IS_VAR_WITH_NAME(xxTerm->rhs(), "x"); + + IS_NODE_WITH_NAME(Mul, add->rhs(), rhs); + IS_IMM_WITH_VAL(Int, rhs->lhs(), 3); + IS_NODE_WITH_NAME(Mul, rhs->rhs(), xyTerm); + IS_VAR_WITH_NAME(xyTerm->lhs(), "x"); + IS_VAR_WITH_NAME(xyTerm->rhs(), "y"); +} + +// Tests the difference between two less trivial expressions. +// (m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n) => 1 +TEST(Simplify, SimplifyDeeperDifference) { + VarHandle n("n", kInt); + VarHandle n_1("n_1", kInt); + VarHandle m("m", kInt); + ExprHandle body = + (m * (ExprHandle(1) * n_1) + (n + 1)) - (m * (ExprHandle(1) * n_1) + n); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_IMM_WITH_VAL(Int, simplified.node(), 1); +} + +// Test constant folding into the difference between expressions. +// 2 + char((m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n)) => 3 +TEST(Simplify, SimplifyFoldComplexDifference) { + VarHandle n("n", kInt); + VarHandle n_1("n_1", kInt); + VarHandle m("m", kInt); + ExprHandle body = + (IntImm::make(2) + + (Cast::make( + kChar, + (m * (ExprHandle(1) * n_1) + (n + 1)) - + (m * (ExprHandle(1) * n_1) + n)))); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 3); +} + +TEST(Simplify, SimplifyIfComponents) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = IfThenElse::make( + ((ExprHandle(5) - ExprHandle(4)) * x) > y, + ExprHandle(2) * x - x, + ExprHandle(2) * y - y); + + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(IfThenElse, simplified.node(), ifexpr); + + IS_NODE_WITH_NAME(CompareSelect, ifexpr->condition(), cmp); + ASSERT_EQ(cmp->compare_select_op(), kGT); + IS_VAR_WITH_NAME(cmp->lhs(), "x"); + IS_VAR_WITH_NAME(cmp->rhs(), "y"); + + IS_VAR_WITH_NAME(ifexpr->true_value(), "x"); + IS_VAR_WITH_NAME(ifexpr->false_value(), "y"); +} + +TEST(Simplify, SimplifyOpaqueTerms) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + { + // 2 * x/y * y - x/y * y => x/y * y + ExprHandle body = ((ExprHandle(2)) * (x / y) * y) - ((x / y) * y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_NODE_WITH_NAME(Div, mul->lhs(), div); + IS_VAR_WITH_NAME(div->lhs(), "x"); + IS_VAR_WITH_NAME(div->rhs(), "y"); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + } + + { + // x%y - (x%y - 1) => 1 + ExprHandle body = (x % y) - ((x % y) - 1); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_IMM_WITH_VAL(Int, simplified.node(), 1); + } +} + +TEST(Simplify, SimplifySymbolicMinMax) { + { + // Minimum with constant difference between terms. + VarHandle x("x", kInt); + ExprHandle body = Min::make(x + 3, x + 7, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Add, simplified.node(), add); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_IMM_WITH_VAL(Int, add->rhs(), 3); + } + + { + // Maximum with constant difference between terms. + VarHandle x("x", kInt); + ExprHandle body = Max::make(x + 3, x + 7, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Add, simplified.node(), add); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_IMM_WITH_VAL(Int, add->rhs(), 7); + } + + { + // Can't simplify multiples because of signedness of variable component. + // TODO: maybe we could for unsigned types? + VarHandle x("x", kInt); + ExprHandle body = Max::make(x * 3, x * 7, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE(Max, simplified.node()); + } +} + +TEST(Simplify, SimplifyNestedMax) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + + { + // Max(x + y, x + y) => x + y + ExprHandle body = Max::make(x + y, x + y, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + IS_BINOP_W_VARS(Add, simplified.node(), add, "x", "y"); + } + + { + // Max(x + y, Max(x + y, z)) => Max(x + y, z) + ExprHandle body = Max::make(x + y, Max::make(x + y, z, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Max, simplified.node(), max); + IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y"); + IS_VAR_WITH_NAME(max->rhs(), "z"); + } + + { + // Max(x + y, Max(z, x + y)) => Max(x + y, z) + ExprHandle body = Max::make(x + y, Max::make(z, x + y, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Max, simplified.node(), max); + IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y"); + IS_VAR_WITH_NAME(max->rhs(), "z"); + } + + { + // Max(Max(x + y, z), x + y) => Max(x + y, z) + ExprHandle body = Max::make(Max::make(x + y, z, true), x + y, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Max, simplified.node(), max); + IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y"); + IS_VAR_WITH_NAME(max->rhs(), "z"); + } + + { + // Max(Max(z, x + y), x + y) => Max(x + y, z) + ExprHandle body = Max::make(Max::make(z, x + y, true), x + y, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Max, simplified.node(), max); + IS_BINOP_W_VARS(Add, max->lhs(), add, "x", "y"); + IS_VAR_WITH_NAME(max->rhs(), "z"); + } + + { + // Max(Max(x, y), x) => Max(Max(x, y), x) + // Nested Max ops with different propagate_nans should not be simplified. + ExprHandle body = Max::make(Max::make(x, y, true), x, false); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Max, simplified.node(), max); + IS_BINOP_W_VARS(Max, max->lhs(), max1, "x", "y"); + ASSERT_TRUE(max1->propagate_nans()); + IS_VAR_WITH_NAME(max->rhs(), "x"); + ASSERT_FALSE(max->propagate_nans()); + } + + { + // Max(Min(x, y), Min(x, z)) => Min(Max(y, z), x) + ExprHandle body = + Max::make(Min::make(x, y, true), Min::make(x, z, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)"); + } + + { + // Max(Min(x, y), Min(z, x)) => Min(Max(y, z), x) + ExprHandle body = + Max::make(Min::make(x, y, true), Min::make(z, x, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)"); + } + + { + // Max(Min(y, x), Min(x, z)) => Min(Max(y, z), x) + ExprHandle body = + Max::make(Min::make(y, x, true), Min::make(x, z, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)"); + } + + { + // Max(Min(y, x), Min(z, x)) => Min(Max(y, z), x) + ExprHandle body = + Max::make(Min::make(y, x, true), Min::make(z, x, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR(simplified, "Min(Max(y, z, 1), x, 1)"); + } + + { + // Max(Min(y, x), Min(z, x)) => Max(Min(x, y), Min(x, z)) + // When all the ops in the pattern do not have the same propagate_nans, + // it should not be simplified. + ExprHandle body = + Max::make(Min::make(y, x, true), Min::make(z, x, false), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Max, simplified.node(), max); + IS_BINOP_W_VARS(Min, max->lhs(), min1, "x", "y"); + ASSERT_TRUE(min1->propagate_nans()); + IS_BINOP_W_VARS(Min, max->rhs(), min2, "x", "z"); + ASSERT_FALSE(min2->propagate_nans()); + ASSERT_TRUE(max->propagate_nans()); + } + + { + // Max(5, Max(x, 8)) => Max(x, 8) + ExprHandle body = Max::make(5, Max::make(x, 8, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8); + ASSERT_TRUE(max->propagate_nans()); + } + + { + // Max(8, Max(x, 5)) => Max(x, 8) + ExprHandle body = Max::make(8, Max::make(x, 5, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8); + ASSERT_TRUE(max->propagate_nans()); + } + + { + // Max(Max(x, 8), 5) => Max(x, 8) + ExprHandle body = Max::make(Max::make(x, 8, true), 5, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8); + ASSERT_TRUE(max->propagate_nans()); + } + + { + // Max(Max(x, 5), 8) => Max(x, 8) + ExprHandle body = Max::make(Max::make(x, 5, true), 8, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_BINOP_W_CONST(Max, simplified.node(), max, "x", 8); + ASSERT_TRUE(max->propagate_nans()); + } + + { + // Max(5, Max(x, Max(y, Max(z, 8)))) => Max(Max(Max(x, 8), y), z) + ExprHandle body = Max::make( + 5, Max::make(x, Max::make(y, Max::make(z, 8, true), true), true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Max, simplified.node(), max1); + IS_NODE_WITH_NAME(Max, max1->lhs(), max2); + IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); + ASSERT_TRUE(max3->propagate_nans()); + IS_VAR_WITH_NAME(max2->rhs(), "y"); + IS_VAR_WITH_NAME(max1->rhs(), "z"); + } + + { + // Max(8, Max(Max(y, Max(z, 5)), x)) => Max(Max(Max(x, 8), y), z) + ExprHandle body = Max::make( + 8, Max::make(Max::make(y, Max::make(z, 5, true), true), x, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Max, simplified.node(), max1); + IS_NODE_WITH_NAME(Max, max1->lhs(), max2); + IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); + ASSERT_TRUE(max3->propagate_nans()); + IS_VAR_WITH_NAME(max2->rhs(), "y"); + IS_VAR_WITH_NAME(max1->rhs(), "z"); + } + + { + // Max(5, Max(Max(Max(z, 8), y), x)) => Max(Max(Max(x, 8), y), z) + ExprHandle body = Max::make( + 5, Max::make(Max::make(Max::make(z, 8, true), y, true), x, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Max, simplified.node(), max1); + IS_NODE_WITH_NAME(Max, max1->lhs(), max2); + IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); + ASSERT_TRUE(max3->propagate_nans()); + IS_VAR_WITH_NAME(max2->rhs(), "y"); + IS_VAR_WITH_NAME(max1->rhs(), "z"); + } + + { + // Max(Max(x, Max(y, Max(5, z))), 8) => Max(Max(Max(x, 8), y), z) + ExprHandle body = Max::make( + Max::make(x, Max::make(y, Max::make(5, z, true), true), true), 8, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Max, simplified.node(), max1); + IS_NODE_WITH_NAME(Max, max1->lhs(), max2); + IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); + ASSERT_TRUE(max3->propagate_nans()); + IS_VAR_WITH_NAME(max2->rhs(), "y"); + IS_VAR_WITH_NAME(max1->rhs(), "z"); + } + + { + // Max(Max(Max(y, Max(8, z)), x), 5) => Max(Max(Max(x, 8), y), z) + ExprHandle body = Max::make( + Max::make(Max::make(y, Max::make(z, 8, true), true), x, true), 5, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Max, simplified.node(), max1); + IS_NODE_WITH_NAME(Max, max1->lhs(), max2); + IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); + ASSERT_TRUE(max3->propagate_nans()); + IS_VAR_WITH_NAME(max2->rhs(), "y"); + IS_VAR_WITH_NAME(max1->rhs(), "z"); + } + + { + // Max(Max(Max(Max(5, z), y), x), 8) => Max(Max(Max(x, 8), y), z) + ExprHandle body = Max::make( + Max::make(Max::make(Max::make(z, 5, true), y, true), x, true), 8, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Max, simplified.node(), max1); + IS_NODE_WITH_NAME(Max, max1->lhs(), max2); + IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); + ASSERT_TRUE(max3->propagate_nans()); + IS_VAR_WITH_NAME(max2->rhs(), "y"); + IS_VAR_WITH_NAME(max1->rhs(), "z"); + } + + { + // Max(Max(Max(Max(z, 5), y), x), 8) => Max(Max(x, Max(Max(z, 5), y)), 8) + // Do not simplify when all the Max ops do not have the same + // propagate_nans. + ExprHandle body = Max::make( + Max::make(Max::make(Max::make(z, 5, true), y, false), x, true), + 8, + false); + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR(simplified, "Max(Max(Max(Max(z, 5, 1), y, 0), x, 1), 8, 0)"); + } + + { + // Max(8, Max(Max(x, 5), Max(y, z))) => Max(Max(Max(x, 8), y), z) + ExprHandle body = Max::make( + 8, Max::make(Max::make(x, 5, true), Max::make(y, z, true), true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Max, simplified.node(), max1); + IS_NODE_WITH_NAME(Max, max1->lhs(), max2); + IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); + ASSERT_TRUE(max3->propagate_nans()); + IS_VAR_WITH_NAME(max2->rhs(), "y"); + IS_VAR_WITH_NAME(max1->rhs(), "z"); + } + + { + // Max(Max(Max(x, 5), Max(y, z)), 8) => Max(Max(Max(x, 8), y), z) + ExprHandle body = Max::make( + Max::make(Max::make(x, 5, true), Max::make(y, z, true), true), 8, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Max, simplified.node(), max1); + IS_NODE_WITH_NAME(Max, max1->lhs(), max2); + IS_BINOP_W_CONST(Max, max2->lhs(), max3, "x", 8); + ASSERT_TRUE(max3->propagate_nans()); + IS_VAR_WITH_NAME(max2->rhs(), "y"); + IS_VAR_WITH_NAME(max1->rhs(), "z"); + } +} + +TEST(Simplify, SimplifyNestedMin) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + + { + // Min(x + y, x + y) => x + y + ExprHandle body = Min::make(x + y, x + y, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + IS_BINOP_W_VARS(Add, simplified.node(), add, "x", "y"); + } + + { + // Min(x + y, Min(x + y, z)) => Min(x + y, z) + ExprHandle body = Min::make(x + y, Min::make(x + y, z, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Min, simplified.node(), min); + IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y"); + IS_VAR_WITH_NAME(min->rhs(), "z"); + } + + { + // Min(x + y, Min(z, x + y)) => Min(x + y, z) + ExprHandle body = Min::make(x + y, Min::make(z, x + y, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Min, simplified.node(), min); + IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y"); + IS_VAR_WITH_NAME(min->rhs(), "z"); + } + + { + // Min(Min(x + y, z), x + y) => Min(x + y, z) + ExprHandle body = Min::make(Min::make(x + y, z, true), x + y, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Min, simplified.node(), min); + IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y"); + IS_VAR_WITH_NAME(min->rhs(), "z"); + } + + { + // Min(Min(z, x + y), x + y) => Min(x + y, z) + ExprHandle body = Min::make(Min::make(z, x + y, true), x + y, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Min, simplified.node(), min); + IS_BINOP_W_VARS(Add, min->lhs(), add, "x", "y"); + IS_VAR_WITH_NAME(min->rhs(), "z"); + } + + { + // Min(Min(x, y), x) => Min(Min(x, y), x) + // Nested Min ops with different propagate_nans should not be simplified. + ExprHandle body = Min::make(Min::make(x, y, true), x, false); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Min, simplified.node(), min1); + IS_BINOP_W_VARS(Min, min1->lhs(), min2, "x", "y"); + ASSERT_TRUE(min2->propagate_nans()); + IS_VAR_WITH_NAME(min1->rhs(), "x"); + ASSERT_FALSE(min1->propagate_nans()); + } + + { + // Min(Max(x, y), Max(x, z)) => Max(Min(y, z), x) + ExprHandle body = + Min::make(Max::make(x, y, true), Max::make(x, z, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)"); + } + + { + // Min(Max(x, y), Max(z, x)) => Max(Min(y, z), x) + ExprHandle body = + Min::make(Max::make(x, y, true), Max::make(z, x, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)"); + } + + { + // Min(Max(y, x), Max(x, z)) => Max(Min(y, z), x) + ExprHandle body = + Min::make(Max::make(y, x, true), Max::make(x, z, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)"); + } + + { + // Min(Max(y, x), Max(z, x)) => Max(Min(y, z), x) + ExprHandle body = + Min::make(Max::make(y, x, true), Max::make(z, x, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR(simplified, "Max(Min(y, z, 1), x, 1)"); + } + + { + // Min(Max(y, x), Max(z, x)) => Min(Max(x, y), Max(x, z)) + // When all the ops in the pattern do not have the same propagate_nans, + // it should not be simplified. + ExprHandle body = + Min::make(Max::make(y, x, true), Max::make(z, x, false), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Min, simplified.node(), min); + IS_BINOP_W_VARS(Max, min->lhs(), max1, "x", "y"); + ASSERT_TRUE(max1->propagate_nans()); + IS_BINOP_W_VARS(Max, min->rhs(), max2, "x", "z"); + ASSERT_FALSE(max2->propagate_nans()); + ASSERT_TRUE(min->propagate_nans()); + } + + { + // Min(5, Min(x, 8)) => Min(x, 8) + ExprHandle body = Min::make(5, Min::make(x, 8, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5); + ASSERT_TRUE(min->propagate_nans()); + } + + { + // Min(8, Min(x, 5)) => Min(x, 8) + ExprHandle body = Min::make(8, Min::make(x, 5, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5); + ASSERT_TRUE(min->propagate_nans()); + } + + { + // Min(Min(x, 8), 5) => Min(x, 8) + ExprHandle body = Min::make(Min::make(x, 8, true), 5, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5); + ASSERT_TRUE(min->propagate_nans()); + } + + { + // Min(Min(x, 5), 8) => Min(x, 8) + ExprHandle body = Min::make(Min::make(x, 5, true), 8, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_BINOP_W_CONST(Min, simplified.node(), min, "x", 5); + ASSERT_TRUE(min->propagate_nans()); + } + + { + // Min(5, Min(x, Min(y, Min(z, 8)))) => Min(Min(Min(x, 5), y), z) + ExprHandle body = Min::make( + 5, Min::make(x, Min::make(y, Min::make(z, 8, true), true), true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Min, simplified.node(), min1); + IS_NODE_WITH_NAME(Min, min1->lhs(), min2); + IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); + ASSERT_TRUE(min3->propagate_nans()); + IS_VAR_WITH_NAME(min2->rhs(), "y"); + IS_VAR_WITH_NAME(min1->rhs(), "z"); + } + + { + // Min(5, Min(Min(y, Min(z, 8)), x)) => Min(Min(Min(x, 5), y), z) + ExprHandle body = Min::make( + 5, Min::make(Min::make(y, Min::make(z, 8, true), true), x, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Min, simplified.node(), min1); + IS_NODE_WITH_NAME(Min, min1->lhs(), min2); + IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); + ASSERT_TRUE(min3->propagate_nans()); + IS_VAR_WITH_NAME(min2->rhs(), "y"); + IS_VAR_WITH_NAME(min1->rhs(), "z"); + } + + { + // Min(5, Min(Min(Min(z, 8), y), x)) => Min(Min(Min(x, 5), y), z) + ExprHandle body = Min::make( + 5, Min::make(Min::make(Min::make(z, 8, true), y, true), x, true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Min, simplified.node(), min1); + IS_NODE_WITH_NAME(Min, min1->lhs(), min2); + IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); + ASSERT_TRUE(min3->propagate_nans()); + IS_VAR_WITH_NAME(min2->rhs(), "y"); + IS_VAR_WITH_NAME(min1->rhs(), "z"); + } + + { + // Min(Min(x, Min(y, Min(8, z))), 5) => Min(Min(Min(x, 5), y), z) + ExprHandle body = Min::make( + Min::make(x, Min::make(y, Min::make(8, z, true), true), true), 5, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Min, simplified.node(), min1); + IS_NODE_WITH_NAME(Min, min1->lhs(), min2); + IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); + ASSERT_TRUE(min3->propagate_nans()); + IS_VAR_WITH_NAME(min2->rhs(), "y"); + IS_VAR_WITH_NAME(min1->rhs(), "z"); + } + + { + // Min(Min(Min(y, Min(8, z)), x), 5) => Min(Min(Min(x, 5), y), z) + ExprHandle body = Min::make( + Min::make(Min::make(y, Min::make(z, 8, true), true), x, true), 5, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Min, simplified.node(), min1); + IS_NODE_WITH_NAME(Min, min1->lhs(), min2); + IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); + ASSERT_TRUE(min3->propagate_nans()); + IS_VAR_WITH_NAME(min2->rhs(), "y"); + IS_VAR_WITH_NAME(min1->rhs(), "z"); + } + + { + // Min(Min(Min(Min(8, z), y), x), 5) => Min(Min(Min(x, 5), y), z) + ExprHandle body = Min::make( + Min::make(Min::make(Min::make(z, 8, true), y, true), x, true), 5, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Min, simplified.node(), min1); + IS_NODE_WITH_NAME(Min, min1->lhs(), min2); + IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); + ASSERT_TRUE(min3->propagate_nans()); + IS_VAR_WITH_NAME(min2->rhs(), "y"); + IS_VAR_WITH_NAME(min1->rhs(), "z"); + } + + { + // Min(Min(Min(Min(z, 5), y), x), 8) => Min(Min(Min(Min(z, 5), y), x), 8) + // Do not simplify when all the Min ops do not have the same + // propagate_nans. + ExprHandle body = Min::make( + Min::make(Min::make(Min::make(z, 5, true), y, false), x, true), + 8, + false); + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR(simplified, "Min(Min(Min(Min(z, 5, 1), y, 0), x, 1), 8, 0)"); + } + + { + // Min(8, Min(Min(x, 5), Min(y, z))) => Min(Min(Min(x, 5), y), z) + ExprHandle body = Min::make( + 8, Min::make(Min::make(x, 5, true), Min::make(y, z, true), true), true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Min, simplified.node(), min1); + IS_NODE_WITH_NAME(Min, min1->lhs(), min2); + IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); + ASSERT_TRUE(min3->propagate_nans()); + IS_VAR_WITH_NAME(min2->rhs(), "y"); + IS_VAR_WITH_NAME(min1->rhs(), "z"); + } + + { + // Min(Min(Min(x, 5), Min(y, z)), 8) => Min(Min(Min(x, 5), y), z) + ExprHandle body = Min::make( + Min::make(Min::make(x, 5, true), Min::make(y, z, true), true), 8, true); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Min, simplified.node(), min1); + IS_NODE_WITH_NAME(Min, min1->lhs(), min2); + IS_BINOP_W_CONST(Min, min2->lhs(), min3, "x", 5); + ASSERT_TRUE(min3->propagate_nans()); + IS_VAR_WITH_NAME(min2->rhs(), "y"); + IS_VAR_WITH_NAME(min1->rhs(), "z"); + } +} + +TEST(Simplify, SimplifyWontReorderFloat) { + { + // 3 * (3 * x) - 3 * (3 * y) => 9 * (x - y) + // This is an expression we can simplify. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) - + ExprHandle(3) * (ExprHandle(3) * y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 9); + IS_NODE_WITH_NAME(Sub, mul->rhs(), sub); + IS_VAR_WITH_NAME(sub->lhs(), "x"); + IS_VAR_WITH_NAME(sub->rhs(), "y"); + } + + { + // 3 * (3 * x) - 3 * (3 * y) => 3 * (3 * x) - 3 * (3 * y). + // If the vars are floating point, ops are not associative and we can't + // reorder. + VarHandle x("x", kFloat); + VarHandle y("y", kFloat); + + ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) - + ExprHandle(3) * (ExprHandle(3) * y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Sub, simplified.node(), sub); + IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul); + IS_IMM_WITH_VAL(Float, lhsMul->lhs(), 3); + IS_NODE_WITH_NAME(Mul, lhsMul->rhs(), lhsVarMul); + IS_IMM_WITH_VAL(Float, lhsVarMul->lhs(), 3); + IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x"); + + IS_NODE_WITH_NAME(Mul, sub->rhs(), rhsMul); + IS_IMM_WITH_VAL(Float, rhsMul->lhs(), 3); + IS_NODE_WITH_NAME(Mul, rhsMul->rhs(), rhsVarMul); + IS_IMM_WITH_VAL(Float, rhsVarMul->lhs(), 3); + IS_VAR_WITH_NAME(rhsVarMul->rhs(), "y"); + } + + { + // 3 * (3 * x) - 3 * (3 * y) => 3 * (3 * x) - (9 * y). + // We will simplify subexprs if they dont reorder floating point ops. + VarHandle x("x", kDouble); + VarHandle y("y", kInt); + + ExprHandle body = ExprHandle(3) * (ExprHandle(3) * x) - + ExprHandle(3) * (ExprHandle(3) * y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Sub, simplified.node(), sub); + IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul); + IS_IMM_WITH_VAL(Double, lhsMul->lhs(), 3); + IS_NODE_WITH_NAME(Mul, lhsMul->rhs(), lhsVarMul); + IS_IMM_WITH_VAL(Double, lhsVarMul->lhs(), 3); + IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x"); + + IS_NODE_WITH_NAME_AND_CAST(Mul, sub->rhs(), rhsMul, Double); + IS_IMM_WITH_VAL(Int, rhsMul->lhs(), 9); + IS_VAR_WITH_NAME(rhsMul->rhs(), "y"); + } + + { + // Prevent reordering if FP propagated from dtypes. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + ExprHandle body = ExprHandle(3.f) * (ExprHandle(3) * x) - + ExprHandle(3) * (ExprHandle(3.f) * y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Sub, simplified.node(), sub); + IS_NODE_WITH_NAME(Mul, sub->lhs(), lhsMul); + IS_IMM_WITH_VAL(Float, lhsMul->lhs(), 3); + IS_NODE_WITH_NAME_AND_CAST(Mul, lhsMul->rhs(), lhsVarMul, Float); + IS_IMM_WITH_VAL(Int, lhsVarMul->lhs(), 3); + IS_VAR_WITH_NAME(lhsVarMul->rhs(), "x"); + + IS_NODE_WITH_NAME(Mul, sub->rhs(), rhsMul); + IS_IMM_WITH_VAL(Float, rhsMul->lhs(), 3); + IS_NODE_WITH_NAME(Mul, rhsMul->rhs(), rhsVarMul); + IS_IMM_WITH_VAL(Float, rhsVarMul->lhs(), 3); + IS_NODE_WITH_NAME(Cast, rhsVarMul->rhs(), yCast); + IS_VAR_WITH_NAME(yCast->src_value(), "y"); + } + + { + VarHandle x("x", kFloat); + VarHandle y("y", kFloat); + // x%y - (x%y - 1) => x%y - (x%y - 1). + // We wont reorder opaque ops if they are FP. + ExprHandle body = (x % y) - ((x % y) - 1); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Sub, simplified.node(), sub); + IS_NODE_WITH_NAME(Mod, sub->lhs(), lhsMod); + IS_VAR_WITH_NAME(lhsMod->lhs(), "x"); + IS_VAR_WITH_NAME(lhsMod->rhs(), "y"); + + IS_NODE_WITH_NAME(Sub, sub->rhs(), rhsSub); + IS_NODE_WITH_NAME(Mod, rhsSub->lhs(), rhsMod); + IS_VAR_WITH_NAME(rhsMod->lhs(), "x"); + IS_VAR_WITH_NAME(rhsMod->rhs(), "y"); + IS_IMM_WITH_VAL(Float, rhsSub->rhs(), 1); + } +} + +TEST(Simplify, SimplifyRoundModPattern) { + { + // (x/y)*y + x%y => x. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = ((x / y) * y) + (x % y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "x"); + } + + { + // Reverse order. + // x%y + (x/y)*y => x. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = (x % y) + ((x / y) * y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "x"); + } + + { + // Non opaque denominator. + // (x / (4+y)) * (4+y)) + (x % (y + 4)) => x. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = ((x / (ExprHandle(4) + y)) * (ExprHandle(4) + y)) + + (x % (y + ExprHandle(4))); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "x"); + } + + { + // Reverse order. + // (x % (y + 4)) + (x / (4+y)) * (4+y)) => x. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = (x % (y + ExprHandle(4))) + + ((x / (ExprHandle(4) + y)) * (ExprHandle(4) + y)); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "x"); + } + + { + // Opaque denominator. + // (x / (2/y)) * (2/y)) + (x % (2/y)) => x. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = ((x / (ExprHandle(2) / y)) * (ExprHandle(2) / y)) + + (x % (ExprHandle(2) / y)); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "x"); + } + + { + // Non opaque numerator + // ((2*x)/y * y) + ((2*x) % y) => 2 * x. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = + (((ExprHandle(2) * x) / y) * y) + ((ExprHandle(2) * x) % y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + IS_VAR_WITH_NAME(mul->rhs(), "x"); + } + + { + // Opaque numerator. + // ((x/2) / y * y) + (x/2 % y) => x / 2. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = + (((x / ExprHandle(2)) / y) * y) + ((x / ExprHandle(2)) % y); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Div, simplified.node(), div); + IS_VAR_WITH_NAME(div->lhs(), "x"); + IS_IMM_WITH_VAL(Int, div->rhs(), 2); + } + + { + // Numerator and denominator. + // ((2*x)/(2*y) * (2*y)) + ((2*x) % (2*y)) => 2 * x. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = + (((ExprHandle(2) * x) / (ExprHandle(2) * y)) * (ExprHandle(2) * y)) + + ((ExprHandle(2) * x) % (ExprHandle(2) * y)); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + IS_VAR_WITH_NAME(mul->rhs(), "x"); + } + + { + // Reverse order. + // ((2*x) % (2*y)) + ((2*x)/(2*y) * (2*y)) => 2 * x. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = ((ExprHandle(2) * x) % (ExprHandle(2) * y)) + + (((ExprHandle(2) * x) / (ExprHandle(2) * y)) * (ExprHandle(2) * y)); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + IS_VAR_WITH_NAME(mul->rhs(), "x"); + } + + { + // Negated Subtraction of Round Mod. + // (x/y) * y - (0 - x%y) => x. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = ((x / y) * y) - (ExprHandle(0) - (x % y)); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "x"); + } + + { + // Other terms are preserved. + // (x/y)*y + x%y + (y * x) => x + (y * x). + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = ((x / y) * y) + (x % y) + (y * x); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Add, simplified.node(), add); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_NODE_WITH_NAME(Mul, add->rhs(), mul); + IS_VAR_WITH_NAME(mul->lhs(), "x"); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + } + + { + // Sanity checking we wont do the optimization on floats. + VarHandle x("x", kFloat); + VarHandle y("y", kFloat); + ExprHandle body = ((x / y) * y) + (x % y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Add, simplified.node(), add); + IS_NODE_WITH_NAME(Mul, add->lhs(), roundMul); + IS_NODE_WITH_NAME(Div, roundMul->lhs(), roundDiv); + IS_VAR_WITH_NAME(roundDiv->lhs(), "x"); + IS_VAR_WITH_NAME(roundDiv->rhs(), "y"); + IS_VAR_WITH_NAME(roundMul->rhs(), "y"); + IS_NODE_WITH_NAME(Mod, add->rhs(), mod); + IS_VAR_WITH_NAME(mod->lhs(), "x"); + IS_VAR_WITH_NAME(mod->rhs(), "y"); + } + + { + // Sanity check we wont do it if the mod term doesn't match. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + ExprHandle body = ((x / y) * y) + (x % z); + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR(simplified, "(x / y) * y + x % z"); + } + + { + // Sanity check we wont do it if the div term doesn't match. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + ExprHandle body = (y * (x / z)) + (x % y); + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR(simplified, "x % y + (x / z) * y"); + } + + { + // Sanity check we wont do it if the mul term doesn't match. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + ExprHandle body = ((x / y) * z) + (x % y); + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR(simplified, "x % y + (x / y) * z"); + } +} + +TEST(Simplify, SimplifyRoundModPatternFactorization) { + { + // Full factorization. + // 2 * (x/y * y) + 2 * (x%y) => 2 * x. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = ExprHandle(2) * ((x / y) * y) + ExprHandle(2) * (x % y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + IS_VAR_WITH_NAME(mul->rhs(), "x"); + } + + { + // Partial Factorization. + // 32 * (x/8) + 4 * (x % 8) => 4 * x. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks,cppcoreguidelines-avoid-magic-numbers) + ExprHandle body = ExprHandle(32) * (x / 8) + ExprHandle(4) * (x % 8); + ExprHandle simplified = IRSimplifier::simplify(body); + + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 4); + IS_VAR_WITH_NAME(mul->rhs(), "x"); + } + + { + // Factorization requiring constant folding. + // 20 * (x / (16 / 2)) * 2 + (11 % 6) * (x % (7+1)) => 5 * x. + VarHandle x("x", kInt); + ExprHandle body = ExprHandle(40) * (x / (ExprHandle(16) / 2)) + + (ExprHandle(11) % 6) * (x % (ExprHandle(7) + 1)); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 5); + IS_VAR_WITH_NAME(mul->rhs(), "x"); + } + + { + VarHandle x("x", kInt); + ExprHandle body = (x / 5) * 10 + ExprHandle(2) * (x % 5); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + IS_VAR_WITH_NAME(mul->rhs(), "x"); + } + + { + VarHandle x("x", kInt); + ExprHandle body = (x / 10) * 0 + x % 5; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mod, simplified.node(), mod); + IS_VAR_WITH_NAME(mod->lhs(), "x"); + IS_IMM_WITH_VAL(Int, mod->rhs(), 5); + } +} + +TEST(Simplify, SimplifyRoundModPatternMultivar) { + { + // Multivar. + // (x/8) * 8 + (y/5)*5 + x%8 + y%5 => x + y. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = (x / ExprHandle(8) * ExprHandle(8)) + + (y / ExprHandle(5) * ExprHandle(5)) + (x % 8) + (y % 5); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Add, simplified.node(), add); + IS_VAR_WITH_NAME(add->lhs(), "x"); + IS_VAR_WITH_NAME(add->rhs(), "y"); + } + + { + // Find the right var. + // (y/8) * 8 x%8 + y%8 + z%8 => x%8 + y + z%8 + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + ExprHandle body = + (y / ExprHandle(8) * ExprHandle(8)) + (x % 8) + (y % 8) + (z % 8); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Add, simplified.node(), add); + IS_NODE_WITH_NAME(Add, add->lhs(), add2); + IS_NODE_WITH_NAME(Mod, add2->lhs(), xMod); + IS_VAR_WITH_NAME(xMod->lhs(), "x"); + IS_IMM_WITH_VAL(Int, xMod->rhs(), 8); + IS_VAR_WITH_NAME(add2->rhs(), "y"); + IS_NODE_WITH_NAME(Mod, add->rhs(), zMod); + IS_VAR_WITH_NAME(zMod->lhs(), "z"); + IS_IMM_WITH_VAL(Int, zMod->rhs(), 8); + } + + { + // Compound. + // (x + (z + 512 * y) % 16) + 16 * ((z + 512 * y) / 16) + // => (z + 512 * y) + x + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle z("z", kInt); + + ExprHandle body = x + (z + y * 512) % 16 + ((z + y * 512) / 16 * 16); + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR(simplified, "x + (z + 512 * y)"); + } +} + +TEST(Simplify, SimplifyModRoundModPattern) { + { + // t/7 % 9 * 7 + t % 7 => t%63 + VarHandle t("t", kInt); + ExprHandle body = (t / 7 % 9) * 7 + t % 7; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mod, simplified.node(), mod); + IS_VAR_WITH_NAME(mod->lhs(), "t"); + IS_IMM_WITH_VAL(Int, mod->rhs(), 63); + } + + { + // 2*t/7 % 9 * 7 + 2*t % 7 => 2*t % 63 + VarHandle t("t", kInt); + ExprHandle body = (ExprHandle(2) * t / 7 % 9) * 7 + ExprHandle(2) * t % 7; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mod, simplified.node(), mod); + IS_NODE_WITH_NAME(Mul, mod->lhs(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + IS_VAR_WITH_NAME(mul->rhs(), "t"); + IS_IMM_WITH_VAL(Int, mod->rhs(), 63); + } + + { + // t/x % y * x + t % x => t%(x*y) + VarHandle t("t", kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = (t / x % y) * x + t % x; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mod, simplified.node(), mod); + IS_VAR_WITH_NAME(mod->lhs(), "t"); + IS_NODE_WITH_NAME(Mul, mod->rhs(), mul); + IS_VAR_WITH_NAME(mul->lhs(), "x"); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + } + + { + // k*t/x % y * x + k*t % x => k*t%(x*y) + VarHandle t("t", kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle k("k", kInt); + ExprHandle body = (k * t / x % y) * x + k * t % x; + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR(simplified, "(k * t) % (x * y)"); + } + + { + // t/k/x % y * x + t/k % x => t/k%(x*y) + VarHandle t("t", kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle k("k", kInt); + ExprHandle body = (t / k / x % y) * x + t / k % x; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mod, simplified.node(), mod); + IS_NODE_WITH_NAME(Div, mod->lhs(), div); + IS_VAR_WITH_NAME(div->lhs(), "t"); + IS_VAR_WITH_NAME(div->rhs(), "k"); + IS_NODE_WITH_NAME(Mul, mod->rhs(), mul); + IS_VAR_WITH_NAME(mul->lhs(), "x"); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + } + + { + // Sanity checking we wont do the optimization on floats. + VarHandle x("x", kFloat); + VarHandle y("y", kFloat); + VarHandle z("z", kFloat); + ExprHandle body = ((x / y % z) * y) + (x % y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Add, simplified.node(), add); + IS_NODE_WITH_NAME(Mul, add->lhs(), mul); + IS_NODE_WITH_NAME(Mod, mul->lhs(), mod); + IS_NODE_WITH_NAME(Div, mod->lhs(), div); + IS_VAR_WITH_NAME(div->lhs(), "x"); + IS_VAR_WITH_NAME(div->rhs(), "y"); + IS_VAR_WITH_NAME(mod->rhs(), "z"); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + IS_NODE_WITH_NAME(Mod, add->rhs(), mod2); + IS_VAR_WITH_NAME(mod2->lhs(), "x"); + IS_VAR_WITH_NAME(mod2->rhs(), "y"); + } +} + +TEST(Simplify, SimplifyModRoundModPatternFactorization) { + { + // 2 * (t /7 % 9 * 7) + 2 * (t % 7) => 2 * (t % 63) + VarHandle t("t", kInt); + ExprHandle body = + ExprHandle(2) * ((t / 7 % 9) * 7) + ExprHandle(2) * (t % 7); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + IS_NODE_WITH_NAME(Mod, mul->rhs(), mod); + IS_VAR_WITH_NAME(mod->lhs(), "t"); + IS_IMM_WITH_VAL(Int, mod->rhs(), 63); + } + + { + // t /7 % 9 * 14 + 2* (t % 7) => 2* (t % 63) + VarHandle t("t", kInt); + ExprHandle body = (t / 7 % 9) * 14 + ExprHandle(2) * (t % 7); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + IS_NODE_WITH_NAME(Mod, mul->rhs(), mod); + IS_VAR_WITH_NAME(mod->lhs(), "t"); + IS_IMM_WITH_VAL(Int, mod->rhs(), 63); + } + + { + // t/14 % 9 * 7 + t/2 % 7 => t/2 % 63 + VarHandle t("t", kInt); + ExprHandle body = (t / 14 % 9) * 7 + t / 2 % 7; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mod, simplified.node(), mod); + IS_NODE_WITH_NAME(Div, mod->lhs(), div); + IS_VAR_WITH_NAME(div->lhs(), "t"); + IS_IMM_WITH_VAL(Int, div->rhs(), 2); + IS_IMM_WITH_VAL(Int, mod->rhs(), 63); + } + + { + // t/(7*3) % 9 * 7*3 + t % (7*3) => t % 189 + VarHandle t("t", kInt); + ExprHandle body = (t / (ExprHandle(7) * ExprHandle(3)) % 9) * 7 * 3 + + t % (ExprHandle(7) * ExprHandle(3)); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mod, simplified.node(), mod); + IS_VAR_WITH_NAME(mod->lhs(), "t"); + IS_IMM_WITH_VAL(Int, mod->rhs(), 189); + } + + { + // 2*(t/x % y * x) + 2*(t % x) => 2*(t%(x*y)) + VarHandle t("t", kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = + ExprHandle(2) * ((t / x % y) * x) + ExprHandle(2) * (t % x); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + IS_NODE_WITH_NAME(Mod, mul->rhs(), mod); + IS_VAR_WITH_NAME(mod->lhs(), "t"); + IS_NODE_WITH_NAME(Mul, mod->rhs(), mul2); + IS_VAR_WITH_NAME(mul2->lhs(), "x"); + IS_VAR_WITH_NAME(mul2->rhs(), "y"); + } +} + +TEST(Simplify, SimplifyModRoundModPatternMultivar) { + { + // t/7 % 9 * 7 + t % 7 + t => t % 63 + t + VarHandle t("t", kInt); + ExprHandle body = (t / 7 % 9) * 7 + t % 7 + t; + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR(simplified, "t % 63 + t"); + } + + { + // t/7 % 9 * 7 + t/8 % 9 * 8 + t % 7 + t % 8 => t % 63 + t % 72 + VarHandle t("t", kInt); + ExprHandle body = (t / 7 % 9) * 7 + (t / 8 % 9) * 8 + t % 7 + t % 8; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Add, simplified.node(), add); + IS_NODE_WITH_NAME(Mod, add->lhs(), mod1); + IS_VAR_WITH_NAME(mod1->lhs(), "t"); + IS_IMM_WITH_VAL(Int, mod1->rhs(), 63); + IS_NODE_WITH_NAME(Mod, add->rhs(), mod2); + IS_VAR_WITH_NAME(mod2->lhs(), "t"); + IS_IMM_WITH_VAL(Int, mod2->rhs(), 72); + } + + { + // k + t/x % y * x + t % x => k + t%(x*y) + VarHandle t("t", kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle k("k", kInt); + ExprHandle body = k + (t / x % y) * x + t % x; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Add, simplified.node(), add); + IS_VAR_WITH_NAME(add->lhs(), "k"); + IS_NODE_WITH_NAME(Mod, add->rhs(), mod); + IS_VAR_WITH_NAME(mod->lhs(), "t"); + IS_NODE_WITH_NAME(Mul, mod->rhs(), mul); + IS_VAR_WITH_NAME(mul->lhs(), "x"); + IS_VAR_WITH_NAME(mul->rhs(), "y"); + } + + { + // t/x % y * x + t % x + (t/k / x % y) * x + t/k % x + // => t%(x*y) + t/k % (x*y) + VarHandle t("t", kInt); + VarHandle x("x", kInt); + VarHandle y("y", kInt); + VarHandle k("k", kInt); + ExprHandle body = (t / x % y) * x + t % x + (t / k / x % y) * x + t / k % x; + ExprHandle simplified = IRSimplifier::simplify(body); + checkExprIR(simplified, "(t / k) % (x * y) + t % (x * y)"); + } + + { + // 3D: (7 * ((i0_flat / 7) % 9) + i0_flat % 7) + 63 * (i0_flat / 63) + // => io_flat + VarHandle t("io_flat", kInt); + ExprHandle body = + ExprHandle(7) * (t / 7 % 9) + t % 7 + ExprHandle(63) * (t / 63); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "io_flat"); + } + + { // 5D: i0_flat / (11 * 10 * 9 * 7) * (7 * 9 * 10 * 11) + + // (i0_flat / (10 * 9 * 7) % 11) * 7 * 9 * 10 + + // (i0_flat / (9 * 7) % 10) * 7 * 9 + + // (i0_flat / 7 % 9) * 7 + + // i0_flat % 7 => io_flat + VarHandle t("io_flat", kInt); + ExprHandle body = (t / (ExprHandle(11) * 10 * 9 * 7)) * (7 * 9 * 10 * 11) + + (t / (ExprHandle(10) * 9 * 7) % 11) * 7 * 9 * 10 + + (t / (ExprHandle(9) * 7) % 10) * 7 * 9 + (t / 7 % 9) * 7 + t % 7; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "io_flat"); + } + + { + // 3D: (m * ((i0_flat / m) % n) + i0_flat % m) + (m * n) * + // (i0_flat / (m * n)) => io_flat + VarHandle t("io_flat", kInt); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + ExprHandle body = m * (t / m % n) + t % m + (m * n) * (t / (m * n)); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "io_flat"); + } + + { // 5D: i0_flat / (k * l * n * m) * (m * n * l * k) + + // (i0_flat / (l * n * m) % k) * m * n * l + + // (i0_flat / (n * m) % l) * m * n + + // (i0_flat / m % n) * m + + // i0_flat % m => io_flat + VarHandle t("io_flat", kInt); + VarHandle m("m", kInt); + VarHandle n("n", kInt); + VarHandle l("l", kInt); + VarHandle k("k", kInt); + ExprHandle body = (t / (k * l * n * m)) * (m * n * l * k) + + (t / (l * n * m) % k) * m * n * l + (t / (n * m) % l) * m * n + + (t / m % n) * m + t % m; + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "io_flat"); + } +} + +TEST(Simplify, SimplifyDivisionScalarFactorization) { + { + // Simple factorization of numerator and denominator. + // 8x / 4y => 2x / y. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = (x * 8) / (y * 4); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Div, simplified.node(), div); + IS_NODE_WITH_NAME(Mul, div->lhs(), lhs); + IS_IMM_WITH_VAL(Int, lhs->lhs(), 2); + IS_VAR_WITH_NAME(lhs->rhs(), "x"); + IS_VAR_WITH_NAME(div->rhs(), "y"); + } + + { + // Don't change anything if we can't factorize. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = (x * 7) / (y * 4); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Div, simplified.node(), div); + IS_NODE_WITH_NAME(Mul, div->lhs(), lhs); + IS_IMM_WITH_VAL(Int, lhs->lhs(), 7); + IS_VAR_WITH_NAME(lhs->rhs(), "x"); + IS_NODE_WITH_NAME(Mul, div->rhs(), rhs); + IS_IMM_WITH_VAL(Int, rhs->lhs(), 4); + IS_VAR_WITH_NAME(rhs->rhs(), "y"); + } + + { + // Don't reorder floats. + VarHandle x("x", kFloat); + VarHandle y("y", kFloat); + ExprHandle body = (x * 8) / (y * 4); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Div, simplified.node(), div); + IS_NODE_WITH_NAME(Mul, div->lhs(), lhs); + IS_VAR_WITH_NAME(lhs->lhs(), "x"); + IS_IMM_WITH_VAL(Float, lhs->rhs(), 8.f); + IS_NODE_WITH_NAME(Mul, div->rhs(), rhs); + IS_VAR_WITH_NAME(rhs->lhs(), "y"); + IS_IMM_WITH_VAL(Float, rhs->rhs(), 4.f); + } + + { + // Sanity check we do nothing if there are only scalar parts. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = (x * 1) / (y * 1); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Div, simplified.node(), div); + IS_VAR_WITH_NAME(div->lhs(), "x"); + IS_VAR_WITH_NAME(div->rhs(), "y"); + } + + { + // Can factorize amounts of variables. + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = (x + x + x + x) / (y + y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Div, simplified.node(), div); + IS_NODE_WITH_NAME(Mul, div->lhs(), lhs); + IS_IMM_WITH_VAL(Int, lhs->lhs(), 2); + IS_VAR_WITH_NAME(lhs->rhs(), "x"); + IS_VAR_WITH_NAME(div->rhs(), "y"); + } +} + +TEST(Simplify, SimplifyConstantBranches) { + { + // If the condition is constant true then take the true_value. + // 1 ? x : y => x + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle t(1); + ExprHandle body = IfThenElse::make(t, x, y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "x"); + } + + { + // If the condition is constant false then take the false_value. + // 0 ? x : y => y + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle t(0); + ExprHandle body = IfThenElse::make(t, x, y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "y"); + } + + { + // condition is simplified before checking. + // (x-x) ? x : y => y + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = IfThenElse::make(x - x, x, y); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "y"); + } + + { + // If both branches are the same then don't do the condition. + // y ? x : x => x + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = IfThenElse::make(y, x, x); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_VAR_WITH_NAME(simplified.node(), "x"); + } + + { + // If both branches simplify to the same thing it still works. + // y ? (x + x) : (2 * x) => x + VarHandle x("x", kInt); + VarHandle y("y", kInt); + ExprHandle body = IfThenElse::make(y, x + x, ExprHandle(2) * x); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_IMM_WITH_VAL(Int, mul->lhs(), 2); + IS_VAR_WITH_NAME(mul->rhs(), "x"); + } +} + +TEST(Simplify, SimplifyConstantCond) { + { + // If the condition is constant true then take the true_value. + // 1 ? A[0] = 1 : B[0] = 1 => A[0] = 1 + BufHandle a("A", {1}, kInt); + BufHandle b("B", {1}, kInt); + ExprHandle condition(1); + StmtPtr true_val = Store::make(a, {0}, 1); + StmtPtr false_val = Store::make(b, {0}, 1); + + CondPtr body = alloc(condition.node(), true_val, false_val); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); + IS_NODE_WITH_NAME(Store, block->front(), store); + IS_VAR_WITH_NAME(store->base_handle(), "A"); + } + + { + // If the condition is constant false then take the false_value. + // 0 ? A[0] = 1 : B[0] = 1 => B[0] = 1 + BufHandle a("A", {1}, kInt); + BufHandle b("B", {1}, kInt); + ExprHandle condition(0); + StmtPtr true_val = Store::make(a, {0}, 1); + StmtPtr false_val = Store::make(b, {0}, 1); + + StmtPtr body = alloc(condition.node(), true_val, false_val); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); + IS_NODE_WITH_NAME(Store, block->front(), store); + IS_VAR_WITH_NAME(store->base_handle(), "B"); + } + + { + // condition is simplified before checking. + // (x-x) ? A[0] = 1 : B[0] = 1 => B[0] = 1 + VarHandle x("x", kInt); + BufHandle a("A", {1}, kInt); + BufHandle b("B", {1}, kInt); + ExprHandle condition(x - x); + StmtPtr true_val = Store::make(a, {0}, 1); + StmtPtr false_val = Store::make(b, {0}, 1); + + StmtPtr body = alloc(condition.node(), true_val, false_val); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); + IS_NODE_WITH_NAME(Store, block->front(), store); + IS_VAR_WITH_NAME(store->base_handle(), "B"); + } + + { + // If both branches are the same then don't do the condition. + // x ? A[0] = x : A[0] = x => A[0] = x + VarHandle x("x", kInt); + BufHandle a("A", {1}, kInt); + ExprHandle condition(x - x); + StmtPtr true_val = Store::make(a, {0}, x); + StmtPtr false_val = Store::make(a, {0}, x); + + StmtPtr body = alloc(condition.node(), true_val, false_val); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); + IS_NODE_WITH_NAME(Store, block->front(), store); + IS_VAR_WITH_NAME(store->base_handle(), "A"); + } + + { + // If both branches simplify to the same thing it still works. + // x ? (x + x) : (2 * x) => x + VarHandle x("x", kInt); + BufHandle a("A", {1}, kInt); + ExprHandle condition(x - x); + StmtPtr true_val = Store::make(a, {0}, ExprHandle(2) * x); + StmtPtr false_val = Store::make(a, {0}, x + x); + + StmtPtr body = alloc(condition.node(), true_val, false_val); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); + IS_NODE_WITH_NAME(Store, block->front(), store); + IS_VAR_WITH_NAME(store->base_handle(), "A"); + } + + { + // But not if they dont + // x ? x : (2 * x) => x ? x : (2 * x) + VarHandle x("x", kInt); + BufHandle a("A", {1}, kInt); + ExprHandle condition(x); + StmtPtr true_val = Store::make(a, {0}, x); + StmtPtr false_val = Store::make(a, {0}, ExprHandle(2) * x); + + StmtPtr body = alloc(condition.node(), true_val, false_val); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); + ASSERT_EQ(block, nullptr); + } + + { + StmtPtr cond = alloc( + ExprHandle(false).node(), + alloc(std::vector({})), + nullptr); + StmtPtr simplified = IRSimplifier::simplify(cond); + ASSERT_EQ(simplified, nullptr); + } + + { + StmtPtr cond = alloc( + ExprHandle(true).node(), + nullptr, + alloc(std::vector({}))); + StmtPtr simplified = IRSimplifier::simplify(cond); + ASSERT_EQ(simplified, nullptr); + } +} + +TEST(Simplify, SimplifyEliminateEmptyCond) { + // If the branches are empty in different ways, eliminate. + { + VarHandle x("x", kInt); + ExprHandle condition(x); + StmtPtr true_val = alloc(std::vector({})); + + StmtPtr body = alloc(condition.node(), true_val, nullptr); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); + ASSERT_NE(block, nullptr); + ASSERT_EQ(block->nstmts(), 0); + } + + { + VarHandle x("x", kInt); + ExprHandle condition(x); + StmtPtr false_val = alloc(std::vector({})); + + StmtPtr body = alloc(condition.node(), nullptr, false_val); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); + ASSERT_NE(block, nullptr); + ASSERT_EQ(block->nstmts(), 0); + } +} + +TEST(Simplify, SimplifyConstantComparisons) { + auto ComparisonTest = + [](ExprHandle a, ExprHandle b, CompareSelectOperation op, int result) { + ExprHandle body = CompareSelect::make(a, b, op); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), result); + }; + + // Equals. + ComparisonTest(2, 2, kEQ, 1); + ComparisonTest(1, 2, kEQ, 0); + ComparisonTest(2, 1, kEQ, 0); + + // Greater than. + ComparisonTest(2, 2, kGT, 0); + ComparisonTest(1, 2, kGT, 0); + ComparisonTest(2, 1, kGT, 1); + + // Greater or Equal. + ComparisonTest(2, 2, kGE, 1); + ComparisonTest(1, 2, kGE, 0); + ComparisonTest(2, 1, kGE, 1); + + // Less Than. + ComparisonTest(2, 2, kLT, 0); + ComparisonTest(1, 2, kLT, 1); + ComparisonTest(2, 1, kLT, 0); + + // Less or Equal. + ComparisonTest(2, 2, kLE, 1); + ComparisonTest(1, 2, kLE, 1); + ComparisonTest(2, 1, kLE, 0); + + // Not equal. + ComparisonTest(2, 2, kNE, 0); + ComparisonTest(1, 2, kNE, 1); + ComparisonTest(2, 1, kNE, 1); + + // With specified results: + ExprHandle body = CompareSelect::make(2, 2, 5, 42, kNE); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_IMM_WITH_VAL(Int, simplified.node(), 42); +} + +TEST(Simplify, SimplifySymbolicComparisons) { + VarHandle x("x", kInt); + VarHandle y("y", kInt); + + auto TookTrueBranch = [](ExprHandle a) { IS_IMM_WITH_VAL(Int, a.node(), 1); }; + auto TookFalseBranch = [](ExprHandle a) { + IS_IMM_WITH_VAL(Int, a.node(), 0); + }; + + // EQ + + // x == x => 1 + ExprHandle body = CompareSelect::make(x, x, kEQ); + TookTrueBranch(IRSimplifier::simplify(body)); + + // x == x+1 => 0 + body = CompareSelect::make(x, x + 1, kEQ); + TookFalseBranch(IRSimplifier::simplify(body)); + + // x == x * 2 cannot simplify since we don't know x is nonzero. + body = CompareSelect::make(x, x * 2, kEQ); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + IS_NODE(CompareSelect, IRSimplifier::simplify(body).node()); + + // x == x * 1 => 1 + body = CompareSelect::make(x, x * 1, kEQ); + TookTrueBranch(IRSimplifier::simplify(body)); + + { + // x == y => x == y + body = CompareSelect::make(x, y, kEQ); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(CompareSelect, simplified.node(), cmp); + ASSERT_EQ(cmp->compare_select_op(), kEQ); + IS_VAR_WITH_NAME(cmp->lhs(), "x"); + IS_VAR_WITH_NAME(cmp->rhs(), "y"); + } + + { + // x == 5 => x == 5 + body = CompareSelect::make(x, 5, kEQ); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(CompareSelect, simplified.node(), cmp); + ASSERT_EQ(cmp->compare_select_op(), kEQ); + IS_VAR_WITH_NAME(cmp->lhs(), "x"); + IS_IMM_WITH_VAL(Int, cmp->rhs(), 5); + } + + // GT + + // x+1 > x => 1 + body = CompareSelect::make(x + 1, x, kGT); + TookTrueBranch(IRSimplifier::simplify(body)); + + // x > x + 1 => 0 + body = CompareSelect::make(x, x + 1, kGT); + TookFalseBranch(IRSimplifier::simplify(body)); + + // x > x - 1 => 1 + body = CompareSelect::make(x, x - 1, kGT); + TookTrueBranch(IRSimplifier::simplify(body)); + + // x - 1 > x => 0 + body = CompareSelect::make(x - 1, x, kGT); + TookFalseBranch(IRSimplifier::simplify(body)); + + // x > x => 0 + body = CompareSelect::make(x, x, kGT); + TookFalseBranch(IRSimplifier::simplify(body)); + + // x * 2 > x => x * 2 > x + // since we don't know the sign of x. + body = CompareSelect::make(x * 2, x, kGT); + IS_NODE(CompareSelect, IRSimplifier::simplify(body).node()); + + // GE + + // x+1 >= x => 1 + body = CompareSelect::make(x + 1, x, kGE); + TookTrueBranch(IRSimplifier::simplify(body)); + + // x >= x + 1 => 0 + body = CompareSelect::make(x, x + 1, kGE); + TookFalseBranch(IRSimplifier::simplify(body)); + + // x >= x => 1 + body = CompareSelect::make(x, x, kGE); + TookTrueBranch(IRSimplifier::simplify(body)); + + // x * 2 >= x => x * 2 >= x + // since we don't know the sign of x. + body = CompareSelect::make(x * 2, x, kGE); + IS_NODE(CompareSelect, IRSimplifier::simplify(body).node()); + + // LT + + // x+1 < x => 0 + body = CompareSelect::make(x + 1, x, kLT); + TookFalseBranch(IRSimplifier::simplify(body)); + + // x < x + 1 => 1 + body = CompareSelect::make(x, x + 1, kLT); + TookTrueBranch(IRSimplifier::simplify(body)); + + // x < x => 0 + body = CompareSelect::make(x, x, kLT); + TookFalseBranch(IRSimplifier::simplify(body)); + + // LE + + // x+1 <= x => 0 + body = CompareSelect::make(x + 1, x, kLE); + TookFalseBranch(IRSimplifier::simplify(body)); + + // x <= x + 1 => 1 + body = CompareSelect::make(x, x + 1, kLE); + TookTrueBranch(IRSimplifier::simplify(body)); + + // x <= x => 1 + body = CompareSelect::make(x, x, kLE); + TookTrueBranch(IRSimplifier::simplify(body)); + + // NE + + // x+1 != x => 1 + body = CompareSelect::make(x + 1, x, kNE); + TookTrueBranch(IRSimplifier::simplify(body)); + + // x != x + 1 => 1 + body = CompareSelect::make(x, x + 1, kNE); + TookTrueBranch(IRSimplifier::simplify(body)); + + // x != x => 0 + body = CompareSelect::make(x, x, kNE); + TookFalseBranch(IRSimplifier::simplify(body)); +} + +TEST(Simplify, SimplifyEliminateZeroLengthFor) { + { + // Will eliminate zero loop For. + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); + VarHandle i("i", kInt); + auto body = For::make(i, 0, 0, Store::make(c, {i}, Load::make(a, {i}))); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); + ASSERT_EQ(block->nstmts(), 0); + } + + { + // still works if start is not zero. + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); + VarHandle i("i", kInt); + auto body = For::make(i, 2, 2, Store::make(c, {i}, Load::make(a, {i}))); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); + ASSERT_EQ(block->nstmts(), 0); + } + + { + // works if both terms are variable. + VarHandle x("x", kInt); + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); + VarHandle i("i", kInt); + auto body = For::make(i, x, x, Store::make(c, {i}, Load::make(a, {i}))); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); + ASSERT_EQ(block->nstmts(), 0); + } + + { + // works if one term simplifies down. + VarHandle x("x", kInt); + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); + VarHandle i("i", kInt); + auto body = For::make(i, 0, x - x, Store::make(c, {i}, Load::make(a, {i}))); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); + ASSERT_EQ(block->nstmts(), 0); + } + + { + // Sanity check does nothing if the condition is not met. + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); + VarHandle i("i", kInt); + auto body = For::make(i, 0, 3, Store::make(c, {i}, Load::make(a, {i}))); + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE(For, simplified); + } +} + +TEST(Simplify, SimplifyOneLoopFor) { + { + // Will remove the loop if the body is run once. + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); + VarHandle i("i", kInt); + auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i}))); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); + IS_NODE_WITH_NAME(Store, block->front(), store); + IS_VAR_WITH_NAME(store->base_handle(), "C"); + IS_IMM_WITH_VAL(Int, store->flat_index(), 0); + } + + { + // still works if start is not zero. + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); + VarHandle i("i", kInt); + auto body = For::make(i, 2, 3, Store::make(c, {i}, Load::make(a, {i}))); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); + IS_NODE_WITH_NAME(Store, block->front(), store); + IS_VAR_WITH_NAME(store->base_handle(), "C"); + IS_IMM_WITH_VAL(Int, store->flat_index(), 2); + } + + { + // works if both terms are variable. + VarHandle x("x", kInt); + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); + VarHandle i("i", kInt); + auto body = For::make(i, x, x + 1, Store::make(c, {i}, Load::make(a, {i}))); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); + IS_NODE_WITH_NAME(Store, block->front(), store); + IS_VAR_WITH_NAME(store->base_handle(), "C"); + IS_VAR_WITH_NAME(store->flat_index(), "x"); + } + + { + // works if one term simplifies down. + VarHandle x("x", kInt); + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); + VarHandle i("i", kInt); + auto body = + For::make(i, 0, x - x + 1, Store::make(c, {i}, Load::make(a, {i}))); + StmtPtr simplified = IRSimplifier::simplify(body); + BlockPtr block = to(simplified); + IS_NODE_WITH_NAME(Store, block->front(), store); + IS_VAR_WITH_NAME(store->base_handle(), "C"); + IS_IMM_WITH_VAL(Int, store->flat_index(), 0); + } + + { + // Sanity check does nothing if the condition is not met. + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); + VarHandle i("i", kInt); + auto body = For::make(i, 0, 3, Store::make(c, {i}, Load::make(a, {i}))); + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE(For, simplified); + } +} + +TEST(Simplify, SimplifyForWontLoseLoopOptions) { + { + // Sanity check does nothing if the condition is not met. + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); + VarHandle i("i", kInt); + LoopOptions options; + options.set_gpu_block_index(LoopOptions::IDX_W); + auto body = + For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i})), options); + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(For, simplified, for_); + LoopOptions options2 = for_->loop_options(); + ASSERT_EQ(options.gpu_block_index(), options2.gpu_block_index()); + } +} + +TEST(Simplify, SimplifyMultilevelFor) { + { + // Multiple layers of For will be simplified out. + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i}))); + auto outer = For::make(j, 0, 1, body); + StmtPtr simplified = IRSimplifier::simplify(outer); + BlockPtr block = to(simplified); + IS_NODE_WITH_NAME(Store, block->front(), store); + IS_VAR_WITH_NAME(store->base_handle(), "C"); + IS_IMM_WITH_VAL(Int, store->flat_index(), 0); + } + + { + // Will maintain an outer loop if the inner loop is eliminated. + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto body = For::make(i, 0, 1, Store::make(c, {i}, Load::make(a, {i}))); + auto outer = For::make(j, 0, 2, body); + StmtPtr simplified = IRSimplifier::simplify(outer); + ForPtr for__ = static_to(simplified); + IS_NODE_WITH_NAME(For, for__, for_); + IS_VAR_WITH_NAME(for_->var(), "j"); + IS_IMM_WITH_VAL(Int, for_->start(), 0); + IS_IMM_WITH_VAL(Int, for_->stop(), 2); + BlockPtr block = to(for_->body()); + ASSERT_NE(block, nullptr); + IS_NODE_WITH_NAME(Store, block->front(), store); + IS_VAR_WITH_NAME(store->base_handle(), "C"); + IS_IMM_WITH_VAL(Int, store->flat_index(), 0); + } + + { + // Will maintain inner loop if outer loops is eliminated. + BufHandle a("A", {4}, kInt); + BufHandle c("C", {4}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto body = For::make(i, 0, 2, Store::make(c, {i}, Load::make(a, {i}))); + auto outer = For::make(j, 0, 1, body); + StmtPtr simplified = IRSimplifier::simplify(outer); + BlockPtr block = to(simplified); + IS_NODE_WITH_NAME(For, block->front(), for_); + IS_VAR_WITH_NAME(for_->var(), "i"); + IS_IMM_WITH_VAL(Int, for_->start(), 0); + IS_IMM_WITH_VAL(Int, for_->stop(), 2); + IS_NODE_WITH_NAME(Store, for_->body()->front(), store); + IS_VAR_WITH_NAME(store->base_handle(), "C"); + IS_VAR_WITH_NAME(store->flat_index(), "i"); + } +} + +TEST(Simplify, SimplifyForCleansUp) { + { + BufHandle a("a", {1, 12, 1}, kFloat); + VarHandle x("x", kInt); + Tensor b = Compute( + "x", + {1, 12, 1}, + [](const VarHandle& i, const VarHandle& m, const VarHandle& n) { + return i + m + n; + }); + LoopNest l({b}); + l.prepareForCodegen(); + + StmtPtr body = LoopNest::sanitizeNames(l.root_stmt()); + StmtPtr simplified = IRSimplifier::simplify(body); + + BlockPtr block = to(simplified); + IS_NODE_WITH_NAME(For, block->front(), for_); + // for is over "m". + IS_VAR_WITH_NAME(for_->var(), "j"); + // x[m] = m; + IS_NODE_WITH_NAME(Store, for_->body()->front(), store); + IS_VAR_WITH_NAME(store->flat_index(), "j"); + IS_VAR_WITH_NAME(store->value(), "j"); + } +} + +TEST(Simplify, SimplifyEliminateEmptyFor) { + { + // Flatten many layers around an empty block to an empty block. + StmtPtr last = alloc(std::vector({})); + for ([[maybe_unused]] const auto i : c10::irange(11)) { + VarHandle loopVar("loopVar", kInt); + last = For::make(loopVar, 0, 10, last); + } + + StmtPtr simplified = IRSimplifier::simplify(last); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 0); + } +} + +TEST(Simplify, SimplifyFlattenBlock) { + { + // Flatten multiple blocks down to one. + // { { { stmt1, stmt2 } } } => { stmt1, stmt2 } + BufHandle a("A", {1}, kInt); + StorePtr store1 = Store::make(a, {0}, 1); + StorePtr store2 = Store::make(a, {0}, 0); + + BlockPtr block1 = alloc(std::vector({store1, store2})); + BlockPtr block2 = alloc(std::vector({block1})); + + BlockPtr enclosing = alloc(std::vector({block2})); + StmtPtr simplified = IRSimplifier::simplify(enclosing); + + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 2); + + IS_NODE_WITH_NAME(Store, block->front(), store1_); + IS_NODE_WITH_NAME(Store, block->back(), store2_); + + ASSERT_EQ(store1->value(), store1_->value()); + ASSERT_EQ(store2->value(), store2_->value()); + } + + { + // Flatten multiple sub blocks containing statements. + // { { stmt1 }, { stmt2 } } => { stmt1, stmt2 } + BufHandle a("A", {1}, kInt); + StorePtr store1 = Store::make(a, {0}, 1); + StorePtr store2 = Store::make(a, {0}, 0); + + BlockPtr block1 = alloc(std::vector({store1})); + BlockPtr block2 = alloc(std::vector({store2})); + + BlockPtr enclosing = alloc(std::vector({block1, block2})); + StmtPtr simplified = IRSimplifier::simplify(enclosing); + + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 2); + + IS_NODE_WITH_NAME(Store, block->front(), store1_); + IS_NODE_WITH_NAME(Store, block->back(), store2_); + + ASSERT_EQ(store1->value(), store1_->value()); + ASSERT_EQ(store2->value(), store2_->value()); + } + + { + // Flatten sub blocks with different depths. + // { stmt1 , { { stmt2 } } } => { stmt1, stmt2 } + BufHandle a("A", {1}, kInt); + StorePtr store1 = Store::make(a, {0}, 1); + StorePtr store2 = Store::make(a, {0}, 0); + + BlockPtr block1 = alloc(std::vector({store2})); + BlockPtr block2 = alloc(std::vector({block1})); + + BlockPtr enclosing = alloc(std::vector({store1, block2})); + StmtPtr simplified = IRSimplifier::simplify(enclosing); + + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 2); + + IS_NODE_WITH_NAME(Store, block->front(), store1_); + IS_NODE_WITH_NAME(Store, block->back(), store2_); + + ASSERT_EQ(store1->value(), store1_->value()); + ASSERT_EQ(store2->value(), store2_->value()); + } + + { + // Flatten many layers around an empty block to an empty block. + StmtPtr last = alloc(std::vector({})); + for ([[maybe_unused]] const auto i : c10::irange(11)) { + last = alloc(std::vector({last})); + } + + StmtPtr simplified = IRSimplifier::simplify(last); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 0); + } +} + +TEST(Simplify, SimplifyEliminateZeroLengthAlloc) { + { + // Simple positive case. + BufHandle b("x", {0}, kInt); + + AllocatePtr alloc_ = Allocate::make(b); + FreePtr free_ = Free::make(b); + + BlockPtr block1 = alloc(std::vector({alloc_, free_})); + ASSERT_EQ(block1->nstmts(), 2); + + StmtPtr simplified = IRSimplifier::simplify(block1); + IS_NODE_WITH_NAME(Block, simplified, block2); + ASSERT_EQ(block2->nstmts(), 0); + } + + { + // Simple negative case. + BufHandle b("x", {2}, kInt); + + AllocatePtr alloc_ = Allocate::make(b); + FreePtr free_ = Free::make(b); + + BlockPtr block1 = alloc(std::vector({alloc_, free_})); + ASSERT_EQ(block1->nstmts(), 2); + + StmtPtr simplified = IRSimplifier::simplify(block1); + IS_NODE_WITH_NAME(Block, simplified, block2); + ASSERT_EQ(block2->nstmts(), 2); + } + + { + // Finds right Alloc/Free. + BufHandle b1("x", {0}, kInt); + BufHandle b2("y", {2}, kInt); + + AllocatePtr alloc1 = Allocate::make(b1); + AllocatePtr alloc2 = Allocate::make(b2); + FreePtr free2_ = Free::make(b2); + FreePtr free1_ = Free::make(b1); + + BlockPtr block1 = + alloc(std::vector({alloc1, alloc2, free2_, free1_})); + ASSERT_EQ(block1->nstmts(), 4); + + StmtPtr simplified = IRSimplifier::simplify(block1); + IS_NODE_WITH_NAME(Block, simplified, block2); + ASSERT_EQ(block2->nstmts(), 2); + IS_NODE_WITH_NAME(Allocate, block2->stmts().front(), simplified_alloc); + IS_VAR_WITH_NAME(simplified_alloc->buffer_var(), "y"); + IS_NODE_WITH_NAME(Free, block2->stmts().back(), simplified_free); + ASSERT_EQ(simplified_alloc->buffer_var(), simplified_free->buffer_var()); + } + + { + // Dynamic shape. + VarHandle z("z", kInt); + BufHandle b1("x", {0}, kInt); + BufHandle b2("y", {z}, kInt); + + AllocatePtr alloc1 = Allocate::make(b1); + AllocatePtr alloc2 = Allocate::make(b2); + FreePtr free2_ = Free::make(b2); + FreePtr free1_ = Free::make(b1); + + BlockPtr block1 = + alloc(std::vector({alloc1, alloc2, free2_, free1_})); + ASSERT_EQ(block1->nstmts(), 4); + StmtPtr simplified = IRSimplifier::simplify(block1); + IS_NODE_WITH_NAME(Block, simplified, block2); + ASSERT_EQ(block2->nstmts(), 2); + } +} + +TEST(Simplify, DontSimplifyRand) { + { + // rand() + rand() = rand() + rand() NOT 2 * rand(). + ExprHandle body = + Intrinsics::make(kRand, kInt) + Intrinsics::make(kRand, kInt); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Add, simplified.node(), add); + IS_RAND(add->lhs()); + IS_RAND(add->rhs()); + } + + { + // rand() - rand() = rand() - rand() NOT 0. + ExprHandle body = + Intrinsics::make(kRand, kFloat) - Intrinsics::make(kRand, kFloat); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Sub, simplified.node(), sub); + IS_RAND(sub->lhs()); + IS_RAND(sub->rhs()); + } + + { + // rand() * rand() = rand() * rand(). + ExprHandle body = + Intrinsics::make(kRand, kInt) * Intrinsics::make(kRand, kInt); + ExprHandle simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Mul, simplified.node(), mul); + IS_RAND(mul->lhs()); + IS_RAND(mul->rhs()); + } +} + +TEST(Simplify, SimplifyReorderForCond) { + BufHandle a("A", {4}, kInt); + BufHandle b("B", {1}, kInt); + BufHandle c("C", {4}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + + { + // for ( if ( ... ) ) => if ( for ( ... ) ). + auto body = For::make( + i, + 0, + 4, + Cond::make( + CompareSelect::make(j, 10, CompareSelectOperation::kLT), + Store::make(c, {i}, Load::make(a, {i})), + nullptr)); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Cond, simplified, cond); + IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); + IS_NODE_WITH_NAME(For, true_block->front(), loop); + } + + { + // Can't reorder if condition is dependent on the loop var. + auto body = For::make( + i, + 0, + 4, + Cond::make( + CompareSelect::make(i, 2, CompareSelectOperation::kEQ), + Store::make(c, {i}, Load::make(a, {i})), + nullptr)); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(For, simplified, loop); + IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond); + } + + { + // Can't reorder if condition is dependent on a var that is modified inside + // the loop. + auto body = For::make( + i, + 0, + 4, + Cond::make( + CompareSelect::make( + Load::make(c, {0}), 10, CompareSelectOperation::kLT), + Store::make(c, {0}, Load::make(a, {i})), + nullptr)); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(For, simplified, loop); + IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond); + } + + { + // Condition based on buffer not referenced in body. Can reorder here. + auto body = For::make( + i, + 0, + 4, + Cond::make( + CompareSelect::make( + Load::make(b, {0}), 10, CompareSelectOperation::kLT), + Store::make(c, {0}, Load::make(a, {i})), + nullptr)); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Cond, simplified, cond); + IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); + IS_NODE_WITH_NAME(For, true_block->front(), loop); + } + + { + // Condition based on buffer read only in body. Can reorder here. + auto body = For::make( + i, + 0, + 4, + Cond::make( + CompareSelect::make( + Load::make(a, {0}), 10, CompareSelectOperation::kLT), + Store::make(c, {0}, Load::make(a, {i})), + nullptr)); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Cond, simplified, cond); + IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); + IS_NODE_WITH_NAME(For, true_block->front(), loop); + } + + { + // Condition depends on Let in the loop. Cannot reorder. + auto body = For::make( + i, + 0, + 4, + Block::make( + {Let::make(j, 3), + Cond::make( + CompareSelect::make(j, 10, CompareSelectOperation::kLT), + Store::make(c, {0}, Load::make(a, {i})), + nullptr)})); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(For, simplified, loop); + IS_NODE_WITH_NAME(Let, loop->body()->front(), let); + IS_NODE_WITH_NAME(Cond, loop->body()->back(), cond); + } + + { + // Multi level Ifs where all conditions are distinct. Move BOTH Cond + // statements outside the loop. + auto body = For::make( + i, + 0, + 4, + Cond::make( + CompareSelect::make( + Load::make(a, {0}), 10, CompareSelectOperation::kLT), + Cond::make( + CompareSelect::make(j, 10, CompareSelectOperation::kEQ), + Store::make(c, {0}, Load::make(a, {i})), + nullptr), + nullptr)); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Cond, simplified, cond); + IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); + IS_NODE_WITH_NAME(Cond, true_block->front(), cond2); + IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_block2); + IS_NODE_WITH_NAME(For, true_block2->front(), loop); + } + + { + // Multi level Ifs where the inner condition does depend on a loop var, + // reorder only the first Cond. + auto body = For::make( + i, + 0, + 4, + Cond::make( + CompareSelect::make( + Load::make(a, {0}), 10, CompareSelectOperation::kLT), + Cond::make( + CompareSelect::make(i, 3, CompareSelectOperation::kEQ), + Store::make(c, {0}, Load::make(a, {i})), + nullptr), + nullptr)); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Cond, simplified, cond); + IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); + IS_NODE_WITH_NAME(For, true_block->front(), loop); + IS_NODE_WITH_NAME(Block, loop->body(), loop_body); + IS_NODE_WITH_NAME(Cond, loop_body->front(), cond2); + } + + { + // Don't reorder if there's an else block of the Cond. + // We could, but is it much better? + auto body = For::make( + i, + 0, + 4, + Cond::make( + CompareSelect::make(j, 10, CompareSelectOperation::kLT), + Store::make(c, {0}, Load::make(a, {i})), + Store::make(c, {0}, 0))); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(For, simplified, loop); + IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond); + } + + { + // Condition uses distinct region of Tensor. + // We could reorder here wih better analysis, but we don't. Included for + // completeness. + auto body = For::make( + i, + 0, + 4, + Cond::make( + CompareSelect::make( + Load::make(c, {0}), 10, CompareSelectOperation::kLT), + Store::make(c, {1}, Load::make(a, {i})), + nullptr)); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(For, simplified, loop); + IS_NODE_WITH_NAME(Cond, loop->body()->front(), cond); + } +} + +TEST(Simplify, SimplifyFuseConditions) { + BufHandle a("A", {2}, kInt); + BufHandle b("B", {2}, kInt); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + + { + // Can fuse since the conditions are identical. + // if (A) { X }; if (A) { Y }; => if (A) { X; Y } + auto body = Block::make( + {Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {0}, i), + nullptr), + Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {1}, i), + nullptr)}); + + StmtPtr simplified = IRSimplifier::simplify(body); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 1); + IS_NODE_WITH_NAME(Cond, block->front(), cond); + IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); + ASSERT_EQ(true_stmt->nstmts(), 2); + ASSERT_EQ(cond->false_stmt(), nullptr); + } + + { + // Can't fuse, conditions are not identical in lhs (i != j). + auto body = Block::make( + {Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {0}, i), + nullptr), + Cond::make( + CompareSelect::make(j, 10, CompareSelectOperation::kLT), + Store::make(a, {1}, i), + nullptr)}); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 2); + IS_NODE_WITH_NAME(Cond, block->front(), cond1); + IS_NODE_WITH_NAME(Cond, block->back(), cond2); + + IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1); + IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2); + ASSERT_EQ(true_stmt1->nstmts(), 1); + ASSERT_EQ(true_stmt2->nstmts(), 1); + + ASSERT_EQ(cond1->false_stmt(), nullptr); + ASSERT_EQ(cond2->false_stmt(), nullptr); + } + { + // Can't fuse, conditions are not identical in rhs (10 != 11). + auto body = Block::make( + {Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {0}, i), + nullptr), + Cond::make( + CompareSelect::make(i, 11, CompareSelectOperation::kLT), + Store::make(a, {1}, i), + nullptr)}); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 2); + IS_NODE_WITH_NAME(Cond, block->front(), cond1); + IS_NODE_WITH_NAME(Cond, block->back(), cond2); + + IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1); + IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2); + ASSERT_EQ(true_stmt1->nstmts(), 1); + ASSERT_EQ(true_stmt2->nstmts(), 1); + + ASSERT_EQ(cond1->false_stmt(), nullptr); + ASSERT_EQ(cond2->false_stmt(), nullptr); + } + + { + // Can't fuse, conditions are not identical in operation (LT vs GT). + auto body = Block::make( + {Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {0}, i), + nullptr), + Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kGT), + Store::make(a, {1}, i), + nullptr)}); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 2); + IS_NODE_WITH_NAME(Cond, block->front(), cond1); + IS_NODE_WITH_NAME(Cond, block->back(), cond2); + + IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1); + IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2); + ASSERT_EQ(true_stmt1->nstmts(), 1); + ASSERT_EQ(true_stmt2->nstmts(), 1); + + ASSERT_EQ(cond1->false_stmt(), nullptr); + ASSERT_EQ(cond2->false_stmt(), nullptr); + } + + { + // Can't fuse, CompareSelect results are different. + // Actually we totally could if we normalized CompareSelect results, but + // TODO for later. + auto body = Block::make( + {Cond::make( + CompareSelect::make(i, 10, 1, 0, CompareSelectOperation::kLT), + Store::make(a, {0}, i), + nullptr), + Cond::make( + CompareSelect::make(j, 10, 2, 0, CompareSelectOperation::kLT), + Store::make(a, {1}, i), + nullptr)}); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 2); + IS_NODE_WITH_NAME(Cond, block->front(), cond1); + IS_NODE_WITH_NAME(Cond, block->back(), cond2); + + IS_NODE_WITH_NAME(Block, cond1->true_stmt(), true_stmt1); + IS_NODE_WITH_NAME(Block, cond2->true_stmt(), true_stmt2); + ASSERT_EQ(true_stmt1->nstmts(), 1); + ASSERT_EQ(true_stmt2->nstmts(), 1); + + ASSERT_EQ(cond1->false_stmt(), nullptr); + ASSERT_EQ(cond2->false_stmt(), nullptr); + } + + { + // Can fuse with false stmt only. + auto body = Block::make( + {Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + nullptr, + Store::make(a, {0}, i)), + Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + nullptr, + Store::make(a, {1}, i))}); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 1); + IS_NODE_WITH_NAME(Cond, block->front(), cond); + IS_NODE_WITH_NAME(Block, cond->false_stmt(), false_stmt); + ASSERT_EQ(false_stmt->nstmts(), 2); + ASSERT_EQ(cond->true_stmt(), nullptr); + } + + { + // Can fuse with both true and false stmt. + auto body = Block::make( + {Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {0}, i), + Store::make(b, {0}, i)), + Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {1}, i), + Store::make(b, {1}, i))}); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 1); + IS_NODE_WITH_NAME(Cond, block->front(), cond); + IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); + ASSERT_EQ(true_stmt->nstmts(), 2); + IS_NODE_WITH_NAME(Block, cond->true_stmt(), false_stmt); + ASSERT_EQ(false_stmt->nstmts(), 2); + } + + { + // Can fuse with mismatched true / false stmt existing + auto body = Block::make( + {Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {0}, i), + nullptr), + Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + nullptr, + Store::make(b, {1}, i))}); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 1); + IS_NODE_WITH_NAME(Cond, block->front(), cond); + IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); + ASSERT_EQ(true_stmt->nstmts(), 1); + IS_NODE_WITH_NAME(Block, cond->true_stmt(), false_stmt); + ASSERT_EQ(false_stmt->nstmts(), 1); + } + + { + // Can fuse partial block contents, ie when there are non fused stmts before + // and after. + // before: + // if (j < 10) { A[0] = j; } + // if (i < 10) { A[0] = i; } + // if (i < 10) { A[1] = i; } + // if (i < 11) { A[1] = j; } + // + // after: + // + // if (j < 10) { A[0] = j; } + // if (i < 10) { + // A[0] = i; + // A[1] = i; + // } + // if (i < 11) { A[1] = j; } + + auto body = Block::make({ + Cond::make( + CompareSelect::make(j, 10, CompareSelectOperation::kLT), + Store::make(a, {0}, j), + nullptr), + Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {0}, i), + nullptr), + Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {1}, i), + nullptr), + Cond::make( + CompareSelect::make(i, 11, CompareSelectOperation::kLT), + Store::make(a, {1}, j), + nullptr), + }); + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 3); + auto it = block->begin(); + it++; + IS_NODE_WITH_NAME(Cond, *it, cond); + IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); + ASSERT_EQ(true_stmt->nstmts(), 2); + ASSERT_EQ(cond->false_stmt(), nullptr); + } + + { + // Can fuse longer sequences of identical conditions. + auto body = Block::make({ + Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {0}, j), + nullptr), + Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {0}, i), + nullptr), + Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {1}, i), + nullptr), + Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {1}, j), + nullptr), + }); + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 1); + IS_NODE_WITH_NAME(Cond, block->front(), cond); + IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); + ASSERT_EQ(true_stmt->nstmts(), 4); + ASSERT_EQ(cond->false_stmt(), nullptr); + } + + { + // Can't fuse through a non condition. + auto body = Block::make({ + Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {0}, j), + nullptr), + Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {0}, i), + nullptr), + Store::make(b, {1}, i + j), + Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {1}, i), + nullptr), + Cond::make( + CompareSelect::make(i, 10, CompareSelectOperation::kLT), + Store::make(a, {1}, j), + nullptr), + }); + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 3); + IS_NODE_WITH_NAME(Cond, block->front(), cond); + IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); + ASSERT_EQ(true_stmt->nstmts(), 2); + ASSERT_EQ(cond->false_stmt(), nullptr); + + IS_NODE_WITH_NAME(Cond, block->back(), cond2); + IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt2); + ASSERT_EQ(true_stmt2->nstmts(), 2); + ASSERT_EQ(cond2->false_stmt(), nullptr); + + auto it = block->begin(); + it++; + IS_NODE_WITH_NAME(Store, *it, middle); + } + + { + // Can fuse if the conditions simplify to the same thing. + auto body = Block::make( + {Cond::make( + CompareSelect::make( + i * 2, + ExprHandle(87) % ExprHandle(11), + CompareSelectOperation::kLT), + Store::make(a, {0}, i), + nullptr), + Cond::make( + CompareSelect::make( + i * 2, + ExprHandle(300) / ExprHandle(30), + CompareSelectOperation::kLT), + Store::make(a, {1}, i), + nullptr)}); + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 1); + IS_NODE_WITH_NAME(Cond, block->front(), cond); + IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); + ASSERT_EQ(true_stmt->nstmts(), 2); + ASSERT_EQ(cond->false_stmt(), nullptr); + } + + { + // Can fuse non-CompareSelects. + // if (i) { X } if (i) { Y } => if (i) { X; Y } + auto body = Block::make( + {Cond::make(i, Store::make(a, {0}, i), nullptr), + Cond::make(i, Store::make(a, {1}, i), nullptr)}); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 1); + IS_NODE_WITH_NAME(Cond, block->front(), cond); + IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_stmt); + ASSERT_EQ(true_stmt->nstmts(), 2); + ASSERT_EQ(cond->false_stmt(), nullptr); + } + + { + // Sanity check wont fuse different non-CompareSelects. + auto body = Block::make( + {Cond::make(i, Store::make(a, {0}, i), nullptr), + Cond::make(j, Store::make(a, {1}, i), nullptr)}); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 2); + IS_NODE_WITH_NAME(Cond, block->front(), cond1); + IS_NODE_WITH_NAME(Cond, block->back(), cond2); + } + + { + // Sanity check constant condition elimination still occurs when merging is + // possible. + auto body = Block::make( + {Cond::make(1, Store::make(a, {0}, i), nullptr), + Cond::make(1, Store::make(a, {1}, i), nullptr)}); + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 2); + IS_NODE_WITH_NAME(Store, block->front(), store1); + IS_NODE_WITH_NAME(Store, block->back(), store2); + } + + { + // Sanity check for-cond reordering occurs after fusing. + auto body = For::make( + i, + 0, + 4, + Block::make( + {Cond::make( + CompareSelect::make(j, 10, CompareSelectOperation::kLT), + Store::make(a, {1}, Load::make(b, {0})), + nullptr), + Cond::make( + CompareSelect::make(j, 10, CompareSelectOperation::kLT), + Store::make(a, {2}, Load::make(b, {0})), + nullptr)})); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Cond, simplified, cond); + IS_NODE_WITH_NAME(Block, cond->true_stmt(), true_block); + IS_NODE_WITH_NAME(For, true_block->front(), loop); + } +} + +TEST(Simplify, SimplifySyncThreads) { + BufHandle a("A", {4}, kInt); + VarHandle i("i", kInt); + + { + // Merge two inner SyncThreads. + auto body = Block::make( + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + {Store::make(a, {0}, 1), + alloc(), + alloc(), + Store::make(a, {1}, 0)}); + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 3); + auto it = block->begin(); + IS_NODE(Store, *it++); + IS_NODE(SyncThreads, *it++); + IS_NODE(Store, *it++); + } + + { + // Eliminate outer SyncThreads. + auto body = Block::make( + {alloc(), Store::make(a, {1}, 0), alloc()}); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 1); + auto it = block->begin(); + IS_NODE(Store, *it); + } + + { + // Merge many inner SyncThreads. + auto body = Block::make( + {Store::make(a, {0}, 1), + alloc(), + alloc(), + alloc(), + alloc(), + alloc(), + Store::make(a, {1}, 0)}); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 3); + auto it = block->begin(); + IS_NODE(Store, *it++); + IS_NODE(SyncThreads, *it++); + IS_NODE(Store, *it++); + } + + { + // Merge multiple outer SyncThreads. + auto body = Block::make( + {alloc(), + alloc(), + Store::make(a, {1}, 0), + alloc(), + alloc(), + alloc(), + alloc()}); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 1); + auto it = block->begin(); + IS_NODE(Store, *it); + } + + { + // Merge multiple sections; + auto body = Block::make( + {Store::make(a, {0}, 1), + alloc(), + alloc(), + Store::make(a, {1}, 0), + Store::make(a, {2}, 0), + alloc(), + alloc(), + alloc(), + Store::make(a, {3}, 0)}); + + StmtPtr simplified = IRSimplifier::simplify(body); + IS_NODE_WITH_NAME(Block, simplified, block); + ASSERT_EQ(block->nstmts(), 6); + auto it = block->begin(); + IS_NODE(Store, *it++); + IS_NODE(SyncThreads, *it++); + IS_NODE(Store, *it++); + IS_NODE(Store, *it++); + IS_NODE(SyncThreads, *it++); + IS_NODE(Store, *it++); + } +} + +TEST(Simplify, SimplifyRampSubBroadcast) { + int num_lanes = 4; + ExprHandle ramp = Ramp::make(ExprHandle(0), ExprHandle(6), num_lanes); + ExprHandle broadcast = Broadcast::make(ExprHandle(-5), num_lanes); + ExprHandle simplified = IRSimplifier::simplify(ramp - broadcast); + RampPtr newRamp = simplified.AsNode(); + IS_NODE_WITH_NAME(IntImm, newRamp->base(), base); + ASSERT_EQ(base->value(), 5); + IS_NODE_WITH_NAME(IntImm, newRamp->stride(), stride); + ASSERT_EQ(stride->value(), 6); + ASSERT_EQ(newRamp->lanes(), num_lanes); +} + +TEST(Simplify, SimplifyBroadcastTermExpander) { + int num_lanes = 8; + ExprHandle bc0 = Broadcast::make(ExprHandle(0), num_lanes); + ExprHandle bc1 = Broadcast::make(ExprHandle(1), num_lanes); + ExprHandle bc2 = Broadcast::make(ExprHandle(2), num_lanes); + // NB: We need a term in the middle which isn't simplified to trigger the + // relevant path in TermExpander::mutate. The two bc1 terms are brought + // together and simplified to 2 * bc1, which then needs to make 2 multi-lane. + ExprHandle simplified = IRSimplifier::simplify(bc1 + (bc0 / bc2) + bc1); + BufHandle buf("buf", {num_lanes}, kInt); + // The result isn't fully simplified currently and thus would be brittle to + // match. Observe its value instead. + auto store = Store::make(buf, {Ramp::make(0, 1, num_lanes)}, simplified); + SimpleIREvaluator eval(store, {buf}); + std::vector output(num_lanes); + eval(output); + for (const auto i : c10::irange(num_lanes)) { + ASSERT_EQ(output[i], 2); + } +} + +TEST(Simplify, CompareSelectLoopBounds) { + constexpr int N = 8; + BufHandle b("b", {N}, kFloat); + VarHandle n("n", kInt); + VarHandle m("m", kInt); + VarHandle var_N("var_N", kInt); + VarHandle var_M("var_M", kInt); + + auto test_case_fn = [](const VarHandle& n, + const BufHandle& b, + const ExprHandle& start, + const ExprHandle& stop, + const int& cmp_val, + const CompareSelectOperation& cmp_op, + const std::string& check_string) { + StmtPtr s = For::make( + n, + start, + stop, + b.store({n}, CompareSelect::make(n, cmp_val, 0.f, 1.0f, cmp_op))); + s = IRSimplifier::simplify(s); + std::ostringstream oss; + oss << *s; + std::string target_string = "# CHECK: "; + target_string += check_string; + torch::jit::testing::FileCheck().run(target_string, oss.str()); + }; + + auto test_case_nest_loops_fn = [](const VarHandle& n, + const VarHandle& m, + const BufHandle& b, + const ExprHandle& n_start, + const ExprHandle& n_stop, + const ExprHandle& m_start, + const ExprHandle& m_stop, + const CompareSelectOperation& cmp_op, + const std::string& check_string) { + StmtPtr s = For::make( + m, + m_start, + m_stop, + b.store({n, m}, CompareSelect::make(n, m, 0.f, 1.0f, cmp_op))); + StmtPtr root_s = For::make(n, n_start, n_stop, s); + root_s = IRSimplifier::simplify(root_s); + std::ostringstream oss; + oss << *root_s; + std::string target_string = "# CHECK: "; + target_string += check_string; + torch::jit::testing::FileCheck().run(target_string, oss.str()); + }; + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n < 1 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = 1.f; + // } + test_case_fn(n, b, 1, N, 1, kLT, "b[n] = 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n <= 1 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n <= 1 ? 0.f : 1.f; + // } + test_case_fn(n, b, 1, N, 1, kLE, "b[n] = n<=1 ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n <= 0 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = 1.f; + // } + test_case_fn(n, b, 1, N, 0, kLE, "b[n] = 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n < 0 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = 1.f; + // } + test_case_fn(n, b, 1, N, 0, kLT, "b[n] = 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n < 8 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = 0.f; + // } + test_case_fn(n, b, 1, N, N, kLT, "b[n] = 0.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n <= 7 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = 0.f; + // } + test_case_fn(n, b, 1, N, N - 1, kLE, "b[n] = 0.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n <= 8 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = 0.f; + // } + test_case_fn(n, b, 1, N, N, kLE, "b[n] = 0.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n < 7 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n < 7 ? 0.f : 1.f; + // } + test_case_fn(n, b, 1, N, N - 1, kLT, "b[n] = n<7 ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n > 0 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = 0.f; + // } + test_case_fn(n, b, 1, N, 0, kGT, "b[n] = 0.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n > 1 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n > 1 ? 0.f : 1.f; + // } + test_case_fn(n, b, 1, N, 1, kGT, "b[n] = n>1 ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n >= 1 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = 0.f; + // } + test_case_fn(n, b, 1, N, 1, kGE, "b[n] = 0.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n > 7 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = 1.f; + // } + test_case_fn(n, b, 1, N, N - 1, kGT, "b[n] = 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n >= 7 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n >= 7 ? 0.f : 1.f; + // } + test_case_fn(n, b, 1, N, N - 1, kGE, "b[n] = n>=7 ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n > 5 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n > 5 ? 0.f : 1.f; + // } + test_case_fn(n, b, 1, N, 5, kGT, "b[n] = n>5 ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n >= 5 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n >= 5 ? 0.f : 1.f; + // } + test_case_fn(n, b, 1, N, 5, kGE, "b[n] = n>=5 ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n > 8 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = 1.f; + // } + test_case_fn(n, b, 1, N, N, kGT, "b[n] = 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n >= 8 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = 1.f; + // } + test_case_fn(n, b, 1, N, N, kGE, "b[n] = 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, 2)) { + // b[n] = n == 1 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, 2)) { + // b[1] = 0.f; + // } + test_case_fn(n, b, 1, 2, 1, kEQ, "b[1] = 0.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n == 1 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n == 1 ? 0.f : 1.f; + // } + test_case_fn(n, b, 1, N, 1, kEQ, "b[n] = n==1 ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n == 0 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = 1.f; + // } + test_case_fn(n, b, 1, N, 0, kEQ, "b[n] = 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n == 7 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n == 7 ? 0.f : 1.f; + // } + test_case_fn(n, b, 1, N, N - 1, kEQ, "b[n] = n==7 ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n == 8 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = 1.f; + // } + test_case_fn(n, b, 1, N, N, kEQ, "b[n] = 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n != 1 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n != 1 ? 0.f : 1.f; + // } + test_case_fn(n, b, 1, N, 1, kNE, "b[n] = n!=1 ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n != 7 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n != 7 ? 0.f : 1.f; + // } + test_case_fn(n, b, 1, N, N - 1, kNE, "b[n] = n!=7 ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n != 5 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n != 5 ? 0.f : 1.f; + // } + test_case_fn(n, b, 1, N, 5, kNE, "b[n] = n!=5 ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n != 0 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = 0.f; + // } + test_case_fn(n, b, 1, N, 0, kNE, "b[n] = 0.f;"); + + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n != 8 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = 0.f; + // } + test_case_fn(n, b, 1, N, N, kNE, "b[n] = 0.f;"); + + // Before: + // for (const auto n : c10::irange(10, 20)) { + // for(const auto m : c10::irange(30, 40)) { + // b[n, m] = (n != m) ? 0.f : 1.f; + // } + // } + // After: + // for (const auto n : c10::irange(10, 20)) { + // for(const auto m : c10::irange(30, 40)) { + // b[n, m] = 0.f; + // } + // } + test_case_nest_loops_fn(n, m, b, 10, 20, 30, 40, kNE, "b[n, m] = 0.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 10, + var_N + 20, + var_N + 30, + var_N + 40, + kNE, + "b[n, m] = 0.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 10, + var_N + 20, + var_M + 30, + var_M + 40, + kNE, + "b[n, m] = n!=m ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(30, 40)) { + // for(const auto m : c10::irange(10, 20)) { + // b[n, m] = (n != m) ? 0.f : 1.f; + // } + // } + // After: + // for (const auto n : c10::irange(30, 40)) { + // for(const auto m : c10::irange(10, 20)) { + // b[n, m] = 0.f; + // } + // } + test_case_nest_loops_fn(n, m, b, 30, 40, 10, 20, kNE, "b[n, m] = 0.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 30, + var_N + 40, + var_N + 10, + var_N + 20, + kNE, + "b[n, m] = 0.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 30, + var_N + 40, + var_M + 10, + var_M + 20, + kNE, + "b[n, m] = n!=m ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(30, 40)) { + // for(const auto m : c10::irange(10, 31)) { + // b[n, m] = (n != m) ? 0.f : 1.f; + // } + // } + // After: + // for (const auto n : c10::irange(30, 40)) { + // for(const auto m : c10::irange(10, 31)) { + // b[n, m] = (n != m) ? 0.f : 1.f; + // } + // } + test_case_nest_loops_fn( + n, m, b, 30, 40, 10, 31, kNE, "b[n, m] = n!=m ? 0.f : 1.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 30, + var_N + 40, + var_N + 10, + var_N + 31, + kNE, + "b[n, m] = n!=m ? 0.f : 1.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 30, + var_N + 40, + var_M + 10, + var_M + 31, + kNE, + "b[n, m] = n!=m ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(10, 31)) { + // for(const auto m : c10::irange(30, 40)) { + // b[n, m] = (n != m) ? 0.f : 1.f; + // } + // } + // After: + // for (const auto n : c10::irange(10, 31)) { + // for(const auto m : c10::irange(30, 40)) { + // b[n, m] = (n != m) ? 0.f : 1.f; + // } + // } + test_case_nest_loops_fn( + n, m, b, 10, 31, 30, 40, kNE, "b[n, m] = n!=m ? 0.f : 1.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 10, + var_N + 31, + var_N + 30, + var_N + 40, + kNE, + "b[n, m] = n!=m ? 0.f : 1.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 10, + var_N + 31, + var_M + 30, + var_M + 40, + kNE, + "b[n, m] = n!=m ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(10, 20)) { + // for(const auto m : c10::irange(30, 40)) { + // b[n, m] = (n < m) ? 0.f : 1.f; + // } + // } + // After: + // for (const auto n : c10::irange(10, 20)) { + // for(const auto m : c10::irange(30, 40)) { + // b[n, m] = 0.f; + // } + // } + test_case_nest_loops_fn(n, m, b, 10, 20, 30, 40, kLT, "b[n, m] = 0.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 10, + var_N + 20, + var_N + 30, + var_N + 40, + kLT, + "b[n, m] = 0.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 10, + var_N + 20, + var_M + 30, + var_M + 40, + kLT, + "b[n, m] = n m) ? 0.f : 1.f; + // } + // } + // After: + // for (const auto n : c10::irange(30, 40)) { + // for(const auto m : c10::irange(10, 20)) { + // b[n, m] = 0.f; + // } + // } + test_case_nest_loops_fn(n, m, b, 30, 40, 10, 20, kGT, "b[n, m] = 0.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 30, + var_N + 40, + var_N + 10, + var_N + 20, + kGT, + "b[n, m] = 0.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 30, + var_N + 40, + var_M + 10, + var_M + 20, + kGT, + "b[n, m] = n>m ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(10, 31)) { + // for(const auto m : c10::irange(30, 40)) { + // b[n, m] = (n > m) ? 0.f : 1.f; + // } + // } + // After: + // for (const auto n : c10::irange(10, 31)) { + // for(const auto m : c10::irange(30, 40)) { + // b[n, m] = 1.f; + // } + // } + test_case_nest_loops_fn(n, m, b, 10, 31, 30, 40, kGT, "b[n, m] = 1.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 10, + var_N + 31, + var_N + 30, + var_N + 40, + kGT, + "b[n, m] = 1.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 10, + var_N + 31, + var_M + 30, + var_M + 40, + kGT, + "b[n, m] = n>m ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(30, 40)) { + // for(const auto m : c10::irange(10, 31)) { + // b[n, m] = (n >= m) ? 0.f : 1.f; + // } + // } + // After: + // for (const auto n : c10::irange(30, 40)) { + // for(const auto m : c10::irange(10, 31)) { + // b[n, m] = 0.f; + // } + // } + test_case_nest_loops_fn(n, m, b, 30, 40, 10, 31, kGE, "b[n, m] = 0.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 30, + var_N + 40, + var_N + 10, + var_N + 31, + kGE, + "b[n, m] = 0.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 30, + var_N + 40, + var_M + 10, + var_M + 31, + kGE, + "b[n, m] = n>=m ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(10, 20)) { + // for(const auto m : c10::irange(30, 40)) { + // b[n, m] = (n >= m) ? 0.f : 1.f; + // } + // } + // After: + // for (const auto n : c10::irange(10, 20)) { + // for(const auto m : c10::irange(30, 40)) { + // b[n, m] = 1.f; + // } + // } + test_case_nest_loops_fn(n, m, b, 10, 20, 30, 40, kGE, "b[n, m] = 1.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 10, + var_N + 20, + var_N + 30, + var_N + 40, + kGE, + "b[n, m] = 1.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 10, + var_N + 20, + var_M + 30, + var_M + 40, + kGE, + "b[n, m] = n>=m ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(10, 31)) { + // for(const auto m : c10::irange(30, 40)) { + // b[n, m] = (n <= m) ? 0.f : 1.f; + // } + // } + // After: + // for (const auto n : c10::irange(10, 31)) { + // for(const auto m : c10::irange(30, 40)) { + // b[n, m] = 0.f; + // } + // } + test_case_nest_loops_fn(n, m, b, 10, 31, 30, 40, kLE, "b[n, m] = 0.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 10, + var_N + 31, + var_N + 30, + var_N + 40, + kLE, + "b[n, m] = 0.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 10, + var_N + 31, + var_M + 30, + var_M + 40, + kLE, + "b[n, m] = n<=m ? 0.f : 1.f;"); + + // Before: + // for (const auto n : c10::irange(30, 40)) { + // for(const auto m : c10::irange(10, 20)) { + // b[n, m] = (n <= m) ? 0.f : 1.f; + // } + // } + // After: + // for (const auto n : c10::irange(30, 40)) { + // for(const auto m : c10::irange(10, 20)) { + // b[n, m] = 0.f; + // } + // } + test_case_nest_loops_fn(n, m, b, 30, 40, 10, 20, kLE, "b[n, m] = 1.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 30, + var_N + 40, + var_N + 10, + var_N + 20, + kLE, + "b[n, m] = 1.f;"); + test_case_nest_loops_fn( + n, + m, + b, + var_N + 30, + var_N + 40, + var_M + 10, + var_M + 20, + kLE, + "b[n, m] = n<=m ? 0.f : 1.f;"); +} + +TEST(Simplify, CompareSelectCondAlwaysInLoopBounds) { + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = n < 1 ? 0.f : 1.f; + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = 1.f; + // } + constexpr int N = 8; + BufHandle b("b", {N}, kFloat); + VarHandle n("n", kInt); + StmtPtr s = For::make( + n, 1, N, b.store({n}, CompareSelect::make(n, 1, 0.f, 1.0f, kLT))); + s = IRSimplifier::simplify(s); + std::ostringstream oss; + oss << *s; + torch::jit::testing::FileCheck().run( + R"IR( +# CHECK: b[n] = 1.f; +)IR", + oss.str()); +} + +TEST(Simplify, IfThenCondAlwaysInLoopBounds) { + // Before: + // for (const auto n : c10::irange(1, N)) { + // b[n] = IfThenElse(n < 1 ? 1 : 0, 0.f, 1.f); + // } + // After: + // for (const auto n : c10::irange(1, N)) { + // b[n] = 1.f; + // } + constexpr int N = 8; + BufHandle b("b", {N}, kFloat); + VarHandle n("n", kInt); + StmtPtr s = + For::make(n, 1, N, b.store({n}, IfThenElse::make(n < 1, 0.f, 1.0f))); + s = IRSimplifier::simplify(s); + std::ostringstream oss; + oss << *s; + torch::jit::testing::FileCheck().run( + R"IR( +# CHECK: b[n] = 1.f; +)IR", + oss.str()); +} + +TEST(Simplify, MultiClauseCondAlwaysInLoopBounds) { + // This test mimics the unpadded region of a conv2d. We want to remove any + // conditional that is provably satisfied (or unsatisfied) by the entire loop + // range. + // Before: + // for (const auto i : c10::irange(1, 7)) { + // for (const auto j : c10::irange(1, 7)) { + // b[i, j] = IfThenElse( + // j>=7 ? 1 : (i>=7 ? 1 : (j<1 ? 1 : (i<1 ? 1 : 0))), 0.f, 1.f); + // After: + // for (const auto i : c10::irange(1, 7)) { + // for (const auto j : c10::irange(1, 7)) { + // b[i, j] = 1.f; + constexpr int N = 8; + BufHandle b("b", {N, N}, kFloat); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto csel = CompareSelect::make(i, 1, kLT); + csel = CompareSelect::make(j, 1, 1, csel, kLT); + csel = CompareSelect::make(i, N - 1, 1, csel, kGE); + csel = CompareSelect::make(j, N - 1, 1, csel, kGE); + StmtPtr s = b.store({i, j}, IfThenElse::make(csel, 0.f, 1.0f)); + s = For::make(j, 1, N - 1, s); + s = For::make(i, 1, N - 1, s); + s = IRSimplifier::simplify(s); + std::ostringstream oss; + oss << *s; + torch::jit::testing::FileCheck().run( + R"IR( +# CHECK: b[i, j] = 1.f; +)IR", + oss.str()); +} + +TEST(Simplify, DISABLED_SimplifyLoopBounds) { + // This test mimics the padded region of a conv2d. We want to adjust the + // loop bounds such that the condition will be always met. Note that this + // could be solved by peeling, and applying the range-based conditional + // simplification in the previous tests. + // Before: + // for (const auto i : c10::irange(3)) { + // for (const auto j : c10::irange(3)) { + // b[i, j] = (b[i, j]) + (IfThenElse( + // j>=7 ? 1 : (i>=7 ? 1 : (j<1 ? 1 : (i<1 ? 1 : 0))), 0.f, a[i, j])); + // After: + // for (const auto i : c10::irange(1, 3)) { + // for (const auto j : c10::irange(1, 3)) { + // b[i, j] = (b[i, j]) + 1.f; + constexpr int N = 8; + constexpr int K = 3; + BufHandle a("a", {N, N}, kFloat); + BufHandle b("b", {N, N}, kFloat); + VarHandle i("i", kInt); + VarHandle j("j", kInt); + auto csel = CompareSelect::make(i, 1, kLT); + csel = CompareSelect::make(j, 1, 1, csel, kLT); + csel = CompareSelect::make(i, N - 1, 1, csel, kGE); + csel = CompareSelect::make(j, N - 1, 1, csel, kGE); + StmtPtr s = b.store( + {i, j}, b.load({i, j}) + IfThenElse::make(csel, 0.f, a.load({i, j}))); + s = For::make(j, 0, K, s); + s = For::make(i, 0, K, s); + s = IRSimplifier::simplify(s); + std::ostringstream oss; + oss << *s; + torch::jit::testing::FileCheck().run( + R"IR( +# CHECK: for (const auto i : c10::irange(1, 3)) { +# CHECK: for (const auto j : c10::irange(1, 3)) { +# CHECK-NOT: IfThenElse +)IR", + oss.str()); +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_te_fuser_pass.cpp b/test/cpp/tensorexpr/test_te_fuser_pass.cpp new file mode 100644 index 0000000000000..56535de914e43 --- /dev/null +++ b/test/cpp/tensorexpr/test_te_fuser_pass.cpp @@ -0,0 +1,402 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { + +using namespace torch::jit::tensorexpr; + +struct WithCPUFuser { + WithCPUFuser(bool val = true) : cpuFuserEnabled(canFuseOnCPU()) { + overrideCanFuseOnCPU(val); + } + + ~WithCPUFuser() { + overrideCanFuseOnCPU(cpuFuserEnabled); + } + + bool cpuFuserEnabled; +}; + +TEST(TEFuserPass, FuserPass_1) { + WithCPUFuser cf; + const auto graph_string = R"IR( + graph(%0 : Float(128, strides=[1], device=cpu), + %1 : Float(128, strides=[1], device=cpu)): + %12 : int = prim::Constant[value=1]() + %2.1 : Float(128, strides=[1], device=cpu) = aten::mul(%0, %1) + %2 : Float(128, strides=[1], device=cpu) = aten::mul(%2.1, %1) + %3 : Float(128, strides=[1], device=cpu) = aten::add_(%2, %1, %12) + %4 : Float(128, strides=[1], device=cpu) = aten::mul(%2, %1) + %5 : Float(128, strides=[1], device=cpu) = aten::add(%2, %4, %12) + return (%5))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + + g->lint(); + FuseTensorExprs(g); + + // We should not be able to fuse across the in-place operation here. + testing::FileCheck() + .check("prim::TensorExprGroup_") + ->check("aten::add_") + ->check("prim::TensorExprGroup_") + ->run(*g); +} + +TEST(TEFuserPass, FuserPass_2) { + WithCPUFuser cf; + const auto graph_string = R"IR( + graph(%0 : Float(128, strides=[1], device=cpu), + %1 : Float(128, strides=[1], device=cpu)): + %12 : int = prim::Constant[value=1]() + %a : Float(128, strides=[1], device=cpu) = aten::mul(%0, %1) + %b : Float(128, strides=[1], device=cpu) = aten::add(%0, %1, %12) + %c : Float(128, strides=[1], device=cpu) = aten::add_(%b, %1, %12) + %d : Float(128, strides=[1], device=cpu) = aten::mul(%c, %a) + return (%d))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + + g->lint(); + FuseTensorExprs(g); + + // We should not be able to fuse across the in-place operation here. + testing::FileCheck() + .check("aten::add_") + ->check("prim::TensorExprGroup_0") + ->run(*g); +} + +TEST(TEFuserPass, FuserPass_3) { + WithCPUFuser cf; + const auto graph_string = R"IR( + graph(%x : Float(128, strides=[1], device=cpu), + %y : Float(128, strides=[1], device=cpu)): + %r : Float(128, strides=[1], device=cpu) = aten::mul(%x, %y) + return (%r))IR"; + { + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + + g->lint(); + FuseTensorExprs(g, /* min_group_size= */ 2); + + // We should not create a fusion group since its size would be too small + testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); + } + { + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + + g->lint(); + FuseTensorExprs(g, /* min_group_size= */ 1); + + // We should create a fusion group since its size is above the threshold + testing::FileCheck().check("prim::TensorExprGroup")->run(*g); + } +} + +TEST(TEFuserPass, FuserPass_0DimInput) { + WithCPUFuser cf; + const auto graph_string = R"IR( + graph(%x : Float(device=cpu), + %y : Float(device=cpu)): + %one : int = prim::Constant[value=1]() + %a : Float(device=cpu) = aten::mul(%x, %y) + %b : Float(device=cpu) = aten::add(%x, %a, %one) + return (%b))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + + g->lint(); + FuseTensorExprs(g); + + // We should fuse 0-dim tensors too + testing::FileCheck().check("prim::TensorExprGroup")->run(*g); +} + +TEST(TEFuserPass, FuserPass_UnfusibleDevice) { + WithCPUFuser cf(false); + const auto graph_string = R"IR( + graph(%x : Float(10, strides=[1], device=cpu), + %y : Float(10, strides=[1], device=cpu)): + %a : Float(10, strides=[1], device=cpu) = aten::mul(%x, %y) + return (%a))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + + g->lint(); + FuseTensorExprs(g, /* min_group_size= */ 1); + + // Test that we're not starting fusion groups from nodes with unfusible device + testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); +} + +TEST(TEFuserPass, FuserPass_UnknownShapes) { + WithCPUFuser cf; + const auto graph_string = R"IR( + graph(%x : Tensor, + %y : Tensor): + %a : Tensor = aten::mul(%x, %y) + %b : Tensor = aten::mul(%x, %a) + return (%b))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + + g->lint(); + FuseTensorExprs(g); + + // Test that we're not generating fusion groups when shapes are not known + testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); +} + +TEST(TEFuserPass, FuserPass_Multidevice) { + { + WithCPUFuser cf; + const auto graph_string = R"IR( + graph(%x : Float(10, strides=[1], device=cpu), + %y : Float(20, strides=[1], device=cpu), + %z : Float(30, strides=[1], device=cpu)): + %dim : int = prim::Constant[value=0]() + %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) + %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) + return (%cat))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + + g->lint(); + FuseTensorExprs(g, /* min_group_size= */ 1); + + // We should be able to fuse this + testing::FileCheck().check("prim::TensorExprGroup")->run(*g); + } + { + WithCPUFuser cf; + const auto graph_string = R"IR( + graph(%x : Float(10, strides=[1], device=cpu), + %y : Float(20, strides=[1], device=cuda:0), + %z : Float(30, strides=[1], device=cpu)): + %dim : int = prim::Constant[value=0]() + %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z) + %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim) + return (%cat))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + + g->lint(); + FuseTensorExprs(g, /* min_group_size= */ 1); + + // We should not fuse this aten::cat since its inputs are from different + // devices + testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); + } + { + WithCPUFuser cf; + const auto graph_string = R"IR( + graph(%x : Float(10, strides=[1], device=cpu), + %y : Float(20, strides=[1], device=cpu), + %z : Float(10, strides=[1], device=cuda:0)): + %dim : int = prim::Constant[value=0]() + %xy_list : Tensor[] = prim::ListConstruct(%x, %y) + %xy_cat : Float(30, strides=[1], device=cpu) = aten::cat(%xy_list, %dim) + %r : Float(30, strides=[1], device=cpu) = aten::mul(%xy_cat, %z) + return (%r))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + + g->lint(); + FuseTensorExprs(g, /* min_group_size= */ 2); + + // Test that we check device before merging one node (cat) into another + // (mul) + testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); + } + { + WithCPUFuser cf; + const auto graph_string = R"IR( + graph(%x : Float(10, strides=[1], device=cpu), + %y : Float(20, strides=[1], device=cpu), + %z : Float(10, strides=[1], device=cuda:0)): + %z2 : Tensor = aten::mul(%z, %z) + %dim : int = prim::Constant[value=0]() + %xy_list : Tensor[] = prim::ListConstruct(%x, %y, %z2) + %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xy_list, %dim) + return (%cat))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + + g->lint(); + FuseTensorExprs(g, /* min_group_size= */ 2); + + // Test that we check device before merging one node (mul) into another + // (cat) + testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); + } + { + WithCPUFuser cf; + const auto graph_string = R"IR( + graph(%x : Float(10, strides=[1], device=cpu), + %y : Float(20, strides=[1], device=cuda:0)): + %r : Float(10, strides=[1], device=cpu) = aten::mul(%x, %y) + return (%r))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + + g->lint(); + FuseTensorExprs(g, /* min_group_size= */ 1); + + // We should not fuse this graph since its inputs are from different devices + testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); + } + { + WithCPUFuser cf; + const auto graph_string = R"IR( + graph(%x : Float(10, strides=[1], device=cuda:0), + %y : Float(20, strides=[1], device=cuda:1), + %z : Float(20, strides=[1], device=cpu)): + %x2 : Float(10, strides=[1], device=cpu) = aten::mul(%x, %x) + %y2 : Float(10, strides=[1], device=cpu) = aten::mul(%y, %y) + %z2 : Float(10, strides=[1], device=cpu) = aten::mul(%z, %z) + return (%x2, %y2, %z2))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + + g->lint(); + FuseTensorExprs(g, /* min_group_size= */ 2); + + // We should not fuse these two computations since they use different + // devices + testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); + } +} + +TEST(TEFuserPass, FuserPass_MergeGroups) { + WithCPUFuser cf; + const auto graph_string = R"IR( + graph(%a : Float(128, strides=[1], device=cpu), + %b : Float(128, strides=[1], device=cpu)): + %x : Float(128, strides=[1], device=cpu) = aten::mul(%a, %a) + %y : Float(128, strides=[1], device=cpu) = aten::mul(%b, %b) + return (%x, %y))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + + g->lint(); + FuseTensorExprs(g, /* min_group_size= */ 1); + + // The %x and %y computations are completely independent and yet we should put + // them into a single fusion group rather than having two separate ones. + testing::FileCheck() + .check("= prim::TensorExprGroup_") + ->check_not("= prim::TensorExprGroup_") + ->run(*g); +} + +TEST(TEFuserPass, FuserPass_IgnoreUnknownShapeAtStart) { + WithCPUFuser cf; + const auto graph_string = R"IR( + graph(%x : Bool(8, strides=[1], device=cpu), + %y : Bool(8, strides=[1], device=cpu)): + %a : Bool(8, strides=[1], device=cpu) = aten::__and__(%x, %y) + %b : Tensor = aten::__or__(%a, %y) + return (%b) + )IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + g->lint(); + FuseTensorExprs(g, /* min_group_size= */ 2); + testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); +} + +TEST(TEFuserPass, FuserPass_Where) { + WithCPUFuser cf; + const auto graph_string = R"IR( + graph(%x : Float(8, strides=[1], device=cpu), + %y : Float(8, strides=[1], device=cpu), + %z : Float(8, strides=[1], device=cpu)): + %cond : Bool(8, strides=[1], device=cpu) = aten::eq(%x, %y) + %b : Float(8, strides=[1], device=cpu) = aten::where(%cond, %y, %z) + return (%b) + )IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + g->lint(); + FuseTensorExprs(g, /* min_group_size= */ 2); + testing::FileCheck().check("prim::TensorExprGroup")->run(*g); +} + +TEST(TEFuserPass, FuserPass_WhereList) { + WithCPUFuser cf; + const auto graph_string = R"IR( + graph(%x : Float(8, strides=[1], device=cpu), + %y : Float(8, strides=[1], device=cpu), + %z : Float(8, strides=[1], device=cpu)): + %cond : Bool(8, strides=[1], device=cpu) = aten::eq(%x, %y) + %b : Tensor[] = aten::where(%cond) + return (%b) + )IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + g->lint(); + FuseTensorExprs(g, /* min_group_size= */ 2); + testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g); +} + +TEST(TEFuserPass, DynamicShapeFusion) { + WithCPUFuser cf; + const auto graph_string = R"IR( + graph(%0 : Float(10, 5, strides=[5, 1], device=cpu), + %1 : Float(10, 5, strides=[5, 1], device=cpu)): + %2 : Float(10, 5, strides=[5, 1], device=cpu) = aten::mul(%0, %1) + %3 : Float(10, 5, strides=[5, 1], device=cpu) = aten::mul(%2, %1) + return (%3))IR"; + auto g = std::make_shared(); + torch::jit::parseIR(graph_string, g.get()); + + g->lint(); + FuseTensorExprs( + g, + /* min_group_size = */ 2, + /* add_composed_op = */ true, + /* fuse_to_dynamic_shapes = */ true); + Code code(g, ""); + + testing::FileCheck() + .check("prim::TensorExprDynamicGroup_") + ->check("prim::TensorExprDynamicGuard") + ->check("prim::TensorExprGroup_") + ->run(*g); + + auto run_and_compare = [&](const std::vector& inputs) { + TORCH_INTERNAL_ASSERT(inputs.size() == 2); + + auto ref = at::mul(at::mul(inputs[0], inputs[1]), inputs[1]); + + InterpreterState interp(code); + Stack stack(inputs.begin(), inputs.end()); + interp.run(stack); + at::Tensor out = pop(stack).toTensor(); + ASSERT_TRUE(at::allclose(out, ref)); + }; + + std::vector inputs = {at::rand({10, 5}), at::rand({10, 5})}; + run_and_compare(inputs); + + std::vector inputs2 = {at::rand({20, 5}), at::rand({20, 5})}; + run_and_compare(inputs2); + + std::vector inputs3 = {at::rand({25, 60}), at::rand({25, 60})}; + run_and_compare(inputs3); +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_type.cpp b/test/cpp/tensorexpr/test_type.cpp new file mode 100644 index 0000000000000..6758503f4de79 --- /dev/null +++ b/test/cpp/tensorexpr/test_type.cpp @@ -0,0 +1,202 @@ +#include + +#include "torch/csrc/jit/tensorexpr/eval.h" +#include "torch/csrc/jit/tensorexpr/ir.h" +#include "torch/csrc/jit/tensorexpr/tensor.h" + +namespace torch { +namespace jit { +using namespace torch::jit::tensorexpr; + +TEST(Type, Test01) { + { + Dtype dt1 = kInt; + ASSERT_EQ(dt1, kInt); + } + { + Dtype dt2_a(kInt, 8); + Dtype dt2_b(kInt, 4); + Dtype dt2_c(ScalarType::Int, 8); + ASSERT_EQ(dt2_a, dt2_c); + ASSERT_NE(dt2_a, dt2_b); + } + { + ASSERT_EQ(kInt, ToDtype()); + ASSERT_EQ(kFloat, ToDtype()); + ASSERT_EQ(kByte, ToDtype()); + ASSERT_EQ(kChar, ToDtype()); + ASSERT_EQ(kShort, ToDtype()); + ASSERT_EQ(kLong, ToDtype()); + ASSERT_EQ(kHalf, ToDtype()); + ASSERT_EQ(kDouble, ToDtype()); + ASSERT_EQ(kBool, ToDtype()); + } + { + Dtype int32x8(kInt, 8); + Dtype float32x8(kFloat, 8); + ASSERT_NE(int32x8, float32x8); + ASSERT_EQ(float32x8, BinaryOpDtype(int32x8, float32x8)); + ASSERT_EQ(float32x8, BinaryOpDtype(float32x8, int32x8)); + ASSERT_EQ(int32x8, BinaryOpDtype(int32x8, int32x8)); + ASSERT_EQ(float32x8, BinaryOpDtype(float32x8, float32x8)); + } +} + +TEST(Type, BitCasting) { + { + VarHandle x("x", kFloat); + ExprHandle y = bitcast(x); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + ASSERT_EQ(y.dtype(), kInt); + } + { + VarHandle x("x", kInt); + ExprHandle y = bitcast(x); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + ASSERT_EQ(y.dtype(), kFloat); + } + { + VarHandle x("x", kShort); + ExprHandle y = bitcast(x); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + ASSERT_EQ(y.dtype(), kHalf); + } + { + VarHandle x("x", kHalf); + ExprHandle y = bitcast(x); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + ASSERT_EQ(y.dtype(), kShort); + } + + constexpr int32_t ref32 = 1337; + constexpr int64_t ref64 = 1337; + constexpr float reff32 = 1337.0f; + constexpr double reff64 = 1337.0f; + using SimpleIRExprEval = ExprEval; + // this is broken + /*{ + constexpr int16_t ref16 = 1337; + at::Half k_; + at::Half* k = &k_; + *reinterpret_cast(k) = ref16; + auto a = HalfImm::make(*k); + auto b = BitCast::make(kShort, a); + SimpleIRExprEval cg(b); + ASSERT_EQ(cg.value(), ref16); + }*/ + + { + float k = raw_bitcast(ref32); + auto a = FloatImm::make(k); + auto b = BitCast::make(kInt, a); + SimpleIRExprEval cg(b); + ASSERT_EQ(cg.value(), ref32); + } + + { + double k = raw_bitcast(ref64); + auto a = DoubleImm::make(k); + auto b = BitCast::make(kLong, a); + SimpleIRExprEval cg(b); + ASSERT_EQ(cg.value(), ref64); + } + + { + int64_t k = raw_bitcast(reff64); + auto a = LongImm::make(k); + auto b = BitCast::make(kDouble, a); + SimpleIRExprEval cg(b); + ASSERT_EQ(cg.value(), reff64); + } + + { + int32_t k = raw_bitcast(reff32); + auto a = IntImm::make(k); + auto b = BitCast::make(kFloat, a); + SimpleIRExprEval cg(b); + ASSERT_EQ(cg.value(), reff32); + } + + // This segfaults :( + /*{ + VarHandle x("x", kDouble); + ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); + } + { + VarHandle x("x", kFloat); + ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); + } + { + VarHandle x("x", kLong); + ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); + } + { + VarHandle x("x", kShort); + ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); + } + { + VarHandle x("x", kInt); + ASSERT_ANY_THROW(ExprHandle y = bitcast(x)); + }*/ +} + +TEST(Type, Propagation) { + // Same types: + { + VarHandle x("x", kFloat); + VarHandle y("y", kFloat); + ExprHandle body = FloatImm::make(2.f) + + (x * FloatImm::make(3.f) + FloatImm::make(4.f) * y); + ASSERT_EQ(body.dtype(), kFloat); + } + // Int to bigger int: + { + VarHandle x("x", kShort); + VarHandle y("y", kLong); + ExprHandle body = + ShortImm::make(2.f) + (x * ShortImm::make(3) + ShortImm::make(4) * y); + ASSERT_EQ(body.dtype(), kLong); + } + // Float to bigger float: + { + VarHandle x("x", kHalf); + VarHandle y("y", kDouble); + ExprHandle body = + HalfImm::make(2.f) + (x * HalfImm::make(3) + HalfImm::make(4) * y); + ASSERT_EQ(body.dtype(), kDouble); + } + // Int to Float: + { + VarHandle x("x", kFloat); + VarHandle y("y", kInt); + ExprHandle body = + IntImm::make(2) + (x * IntImm::make(3) + IntImm::make(4) * y); + ASSERT_EQ(body.dtype(), kFloat); + } + // Smaller float, bigger Int: + { + VarHandle x("x", kHalf); + VarHandle y("y", kLong); + ExprHandle body = + HalfImm::make(2) + (x * HalfImm::make(3) + HalfImm::make(4) * y); + ASSERT_EQ(body.dtype(), kHalf); + } + // Bigger float, smaller Int: + { + VarHandle x("x", kChar); + VarHandle y("y", kDouble); + ExprHandle body = + CharImm::make(2) + (x * CharImm::make(3) + CharImm::make(4) * y); + ASSERT_EQ(body.dtype(), kDouble); + } + // Sign change char/byte upgrades to short: + { + VarHandle x("x", kChar); + VarHandle y("y", kByte); + ExprHandle body = + CharImm::make(2) + (x * CharImm::make(3) + CharImm::make(4) * y); + ASSERT_EQ(body.dtype(), kShort); + } +} +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_type_specializations.cpp b/test/cpp/tensorexpr/test_type_specializations.cpp new file mode 100644 index 0000000000000..d9756627fa74d --- /dev/null +++ b/test/cpp/tensorexpr/test_type_specializations.cpp @@ -0,0 +1,75 @@ +#include + +#include +#include +#include +#include +#include +#include +#include + +// Test that tensor type specializations are available in +// the custom passes + +namespace torch { +namespace jit { + +namespace { + +bool hasTensorTypeSpecializations(torch::jit::Block* block) { + for (Value* v : block->inputs()) { + if (hasTensorTypeSpecialization(v)) + return true; + } + for (Node* n : block->nodes()) { + for (torch::jit::Block* b : n->blocks()) { + if (hasTensorTypeSpecializations(b)) + return true; + } + for (Value* v : n->outputs()) { + if (hasTensorTypeSpecialization(v)) + return true; + } + } + return false; +} + +static bool hasSpecializations = false; +void detectTTSpecializationPass(std::shared_ptr& graph) { + GRAPH_DUMP("In detectTTSpecialization Custom Post Pass: ", graph); + hasSpecializations = hasTensorTypeSpecializations(graph->block()); +} + +} // namespace + +TEST(SpecializationsInCustomPasses, Basic) { + RegisterPass p(detectTTSpecializationPass); + hasSpecializations = false; + std::shared_ptr graph = std::make_shared(); + parseIR( + R"IR( +graph(%a.1 : Tensor, + %b.1 : Tensor): + %c.1 : Tensor = aten::mul(%a.1, %b.1) # misc/test_specializations.py:5:8 + %d.1 : Tensor = aten::mul(%c.1, %b.1) # misc/test_specializations.py:6:8 + return (%d.1) + )IR", + &*graph); + + IValue ival = IValue(torch::randn({22}, at::kCPU)); + std::vector stack = {ival, ival}; + auto run = [&](std::shared_ptr& graph, std::vector stack) { + GraphExecutor executor(graph, ""); + executor.run(stack); + return stack; + }; + run(graph, stack); + + // Profiling mode will not be run with simple executor + if (!getExecutorMode()) { + EXPECT_TRUE(hasSpecializations); + } +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_utils.h b/test/cpp/tensorexpr/test_utils.h new file mode 100644 index 0000000000000..065e513c1a645 --- /dev/null +++ b/test/cpp/tensorexpr/test_utils.h @@ -0,0 +1,78 @@ +#pragma once + +#include +#include + +#include +#include +#include + +namespace torch { +namespace jit { +using namespace torch::jit::tensorexpr; + +#define IS_NODE(T, node) \ + { \ + auto node_ = to(node); \ + ASSERT_NE(nullptr, node_); \ + } + +#define IS_NODE_WITH_NAME(T, node, name) \ + auto name = to(node); \ + ASSERT_NE(nullptr, name); + +#define IS_NODE_WITH_NAME_AND_CAST(T, node, name, Type) \ + NodePtr name = nullptr; \ + { \ + auto node_ = to(node); \ + ASSERT_NE(nullptr, node_); \ + ASSERT_EQ(node_->dtype().scalar_type(), ScalarType::Type); \ + name = to(node_->src_value()); \ + } \ + ASSERT_NE(nullptr, name); + +#define IS_IMM_WITH_VAL(T, node, val) \ + { \ + auto node_ = to(node); \ + ASSERT_NE(nullptr, node_); \ + ASSERT_EQ(node_->value(), val); \ + } + +#define IS_VAR_WITH_NAME(node, name) \ + { \ + auto node_ = to(node); \ + ASSERT_NE(nullptr, node_); \ + ASSERT_EQ(node_->name_hint(), name); \ + } + +#define IS_BINOP_W_VARS(T, node, name, v1, v2) \ + NodePtr name = nullptr; \ + { \ + name = to(node); \ + ASSERT_NE(nullptr, name); \ + IS_VAR_WITH_NAME(name->lhs(), v1); \ + IS_VAR_WITH_NAME(name->rhs(), v2); \ + } + +#define IS_BINOP_W_CONST(T, node, name, v, c) \ + NodePtr name = nullptr; \ + { \ + name = to(node); \ + ASSERT_NE(nullptr, name); \ + IS_VAR_WITH_NAME(name->lhs(), v); \ + IS_IMM_WITH_VAL(Int, name->rhs(), c); \ + } + +#define IS_RAND(node) \ + { \ + auto node_ = to(node); \ + ASSERT_NE(nullptr, node_); \ + ASSERT_EQ(node_->op_type(), kRand); \ + } + +void checkIR(StmtPtr s, const std::string& pattern); +void checkExprIR(ExprPtr e, const std::string& pattern); +void checkExprIR(const ExprHandle& e, const std::string& pattern); + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/tutorial.cpp b/test/cpp/tensorexpr/tutorial.cpp new file mode 100644 index 0000000000000..3f4c32af463b6 --- /dev/null +++ b/test/cpp/tensorexpr/tutorial.cpp @@ -0,0 +1,542 @@ +// *** Tensor Expressions *** +// +// This tutorial covers basics of NNC's tensor expressions, shows basic APIs to +// work with them, and outlines how they are used in the overall TorchScript +// compilation pipeline. This doc is permanently a "work in progress" since NNC +// is under active development and things change fast. +// +// This Tutorial's code is compiled in the standard pytorch build, and the +// executable can be found in `build/bin/tutorial_tensorexpr`. +// +// *** What is NNC *** +// +// NNC stands for Neural Net Compiler. It is a component of TorchScript JIT +// and it performs on-the-fly code generation for kernels, which are often a +// combination of multiple aten (torch) operators. +// +// When the JIT interpreter executes a torchscript model, it automatically +// extracts subgraphs from the torchscript IR graph for which specialized code +// can be JIT generated. This usually improves performance as the 'combined' +// kernel created from the subgraph could avoid unnecessary memory traffic that +// is unavoidable when the subgraph is interpreted as-is, operator by operator. +// This optimization is often referred to as 'fusion'. Relatedly, the process of +// finding and extracting subgraphs suitable for NNC code generation is done by +// a JIT pass called 'fuser'. +// +// *** What is TE *** +// +// TE stands for Tensor Expressions. TE is a commonly used approach for +// compiling kernels performing tensor (~matrix) computation. The idea behind it +// is that operators are represented as a mathematical formula describing what +// computation they do (as TEs) and then the TE engine can perform mathematical +// simplification and other optimizations using those formulas and eventually +// generate executable code that would produce the same results as the original +// sequence of operators, but more efficiently. +// +// NNC's design and implementation of TE was heavily inspired by Halide and TVM +// projects. +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace torch::jit::tensorexpr; + +#ifdef TORCH_ENABLE_LLVM + +// Helper function to print a snippet from a big multi-line string +static void printLinesToFrom(const std::string& input_str, int from, int to); + +#endif + +int main(int argc, char* argv[]) { + std::cout << "*** Structure of tensor expressions and statements ***" + << std::endl; + { + // A tensor expression is a tree of expressions. Each expression has a type, + // and that type defines what sub-expressions the current expression has. + // For instance, an expression of type 'Mul' would have a type 'kMul' and + // two subexpressions: LHS and RHS. Each of these two sub-expressions could + // also be a 'Mul' or some other expression. + // + // Let's construct a simple TE: + ExprPtr lhs = alloc(5); + ExprPtr rhs = alloc("x", kInt); + ExprPtr mul = alloc(lhs, rhs); + std::cout << "Tensor expression: " << *mul << std::endl; + // Prints: Tensor expression: 5 * x + + // Here we created an expression representing a 5*x computation, where x is + // an int variable. + + // Another, probably a more convenient, way to construct tensor expressions + // is to use so called expression handles (as opposed to raw expressions + // like we did in the previous example). Expression handles overload common + // operations and allow us to express the same semantics in a more natural + // way: + ExprHandle l = 5; + ExprHandle r = Var::make("x", kInt); + ExprHandle m = l * r; + std::cout << "Tensor expression: " << *m.node() << std::endl; + // Prints: Tensor expression: 5 * x + + // Converting from handles to raw expressions and back is easy: + ExprHandle handle = Var::make("x", kInt); + ExprPtr raw_expr_from_handle = handle.node(); + ExprPtr raw_expr = alloc("x", kInt); + ExprHandle handle_from_raw_expr = ExprHandle(raw_expr); + + // We could construct arbitrarily complex expressions using mathematical + // and logical operations, casts between various data types, and a bunch of + // intrinsics. + ExprHandle a = Var::make("a", kInt); + ExprHandle b = Var::make("b", kFloat); + ExprHandle c = Var::make("c", kFloat); + ExprHandle x = ExprHandle(5) * a + b / (sigmoid(c) - 3.0f); + std::cout << "Tensor expression: " << *x.node() << std::endl; + // Prints: Tensor expression: float(5 * a) + b / ((sigmoid(c)) - 3.f) + + // An ultimate purpose of tensor expressions is to optimize tensor + // computations, and in order to represent accesses to tensors data, there + // is a special kind of expression - a load. + // To construct a load we need two pieces: the base and the indices. The + // base of a load is a Buf expression, which could be thought of as a + // placeholder similar to Var, but with dimensions info. + // + // Let's construct a simple load: + BufHandle A("A", {64, 32}, kInt); + VarPtr i_var = alloc("i", kInt), j_var = alloc("j", kInt); + ExprHandle i(i_var), j(j_var); + ExprHandle load = Load::make(A.dtype(), A, {i, j}); + std::cout << "Tensor expression: " << *load.node() << std::endl; + // Prints: Tensor expression: A[i, j] + + // Tensor Expressions constitute Tensor Statements, which are used to + // represent computation of a given operator or a group of operators from a + // fusion group. + // + // There are three main kinds of tensor statements: + // - block + // - store + // - loop + // + // A Store represents a store to a single element of a tensor (or to a + // group of elements if it's a vectorized store). Store statements, + // similarly to Load expressions, have a base and indices, but on top of + // that they also include a value - an expression representing what needs + // to be stored at the given memory location. Let's create a Store stmt: + StmtPtr store_a = Store::make(A, {i, j}, i + j); + std::cout << "Store statement: " << *store_a << std::endl; + // Prints: Store statement: A[i, j] = i + j; + + // An operator fills the entire tensor, not just a single element, and to + // represent this we need to use For stmt: let's wrap our store stmt with + // two nested loops to represent that variables i and j need to iterate + // over some ranges. + ForPtr loop_j_a = For::make(VarHandle(j_var), 0, 32, store_a); + ForPtr loop_i_a = For::make(VarHandle(i_var), 0, 64, loop_j_a); + + std::cout << "Nested for loops: " << std::endl << *loop_i_a << std::endl; + // Prints: + // Nested for loops: + // for (const auto i : c10::irange(64)) { + // for (const auto j : c10::irange(32)) { + // A[i, j] = i + j; + // } + // } + + // A Block statement is used when we need a sequence of other statements. + // E.g. if a fusion group contains several operators, we initially define + // separate loopnest for each of them and put them all into a common block: + BufHandle B("B", {64, 32}, kInt); + StmtPtr store_b = Store::make(B, {i, j}, A.load(i, j)); + ForPtr loop_j_b = For::make(VarHandle(j_var), 0, 32, store_b); + ForPtr loop_i_b = For::make(VarHandle(i_var), 0, 64, loop_j_b); + + BlockPtr block = Block::make({loop_i_a, loop_i_b}); + std::cout << "Compound Block statement: " << std::endl + << *block << std::endl; + // Prints: + // Compound Block statement: + // { + // for (const auto i : c10::irange(64)) { + // for (const auto j : c10::irange(32)) { + // A[i, j] = i + j; + // } + // } + // for (const auto i : c10::irange(64)) { + // for (const auto j : c10::irange(32)) { + // B[i, j] = A[i, j]; + // } + // } + // } + + // Manually constructing nested loops and blocks to represent a computation + // might be laborious, and instead we can use a 'Compute' API. This API + // requires us to specify dimensions and a lambda to compute a single + // element of the resulting tensor and returns a `Tensor` structure. This + // structure is simply a pair of a buffer that was created to represent the + // result of the computation (BufPtr) and a statement representing the + // computation itself (StmtPtr). + Tensor C = + Compute("C", {64, 32}, [&](const VarHandle& i, const VarHandle& j) { + return i * j; + }); + std::cout << "Stmt produced by 'Compute' API: " << std::endl + << *C.stmt() << std::endl; + // Prints: + // Stmt produced by 'Compute' API: + // for (const auto i : c10::irange(64)) { + // for (const auto j : c10::irange(32)) { + // C[i, j] = i * j; + // } + // } + + // To construct statements to represent computations with reductions, we + // can use a 'Reduce' API - it is similar to 'Compute' but takes a couple + // of extra arguments defining how to perform the reduction. Let's define a + // simple 2D sum of C using that: + Tensor D = Reduce( + "D", + {}, + Sum(), + [&](const VarHandle& i, const VarHandle& j) { return C.load(i, j); }, + {64, 32}); + std::cout << "Stmt produced by 'Reduce' API: " << std::endl + << *D.stmt() << std::endl; + } + + std::cout << "*** Loopnests transformations ***" << std::endl; + { + // When a statement for the computation is generated, we might want to + // apply some optimizations to it. These transformations allow us to end up + // with a statement producing the same results, but more efficiently. + // + // Let's look at a couple of transformations that are used in NNC. We will + // begin with constructing a Block statement like we did before. + + Tensor C = + Compute("C", {64, 32}, [&](const VarHandle& i, const VarHandle& j) { + return i * (j + 1); + }); + BufHandle c_buf(C.buf()); + Tensor D = + Compute("D", {64, 32}, [&](const VarHandle& i, const VarHandle& j) { + return c_buf.load(i, j) - i; + }); + StmtPtr block = Block::make({C.stmt(), D.stmt()}); + std::cout << "Stmt produced by 'Compute' API: " << std::endl + << *block << std::endl; + // Prints: + // Stmt produced by 'Compute' API: + // { + // for (const auto i : c10::irange(64)) { + // for (const auto j : c10::irange(32)) { + // C[i, j] = i * (j + 1); + // } + // } + // for (const auto i_1 : c10::irange(64)) { + // for (const auto j_1 : c10::irange(32)) { + // D[i_1, j_1] = (C[i_1, j_1]) - i_1; + // } + // } + // } + + // One transformation we can apply to this computation is inlining: i.e. + // taking the expression that defines values of C and substituting a load + // from C with it. + // To do that, we first need to create a special object called LoopNest - + // all transformations are methods of this class. To create a loopnest we + // need to provide a list of output buffers and the root statement: + LoopNest nest(block, {D.buf()}); + + // We can always retrieve the Stmt back from LoopNest: + std::cout << "LoopNest root stmt: " << std::endl + << *nest.root_stmt() << std::endl; + // Prints: + // LoopNest root stmt: + // { + // for (const auto i : c10::irange(64)) { + // for (const auto j : c10::irange(32)) { + // C[i, j] = i * (j + 1); + // } + // } + // for (const auto i_1 : c10::irange(64)) { + // for (const auto j_1 : c10::irange(32)) { + // D[i_1, j_1] = (C[i_1, j_1]) - i_1; + // } + // } + // } + + // Now we can apply the inlining transformation: + nest.computeInline(C.buf()); + std::cout << "Stmt after inlining:" << std::endl + << *nest.root_stmt() << std::endl; + // Prints: + // Stmt after inlining: + // { + // for (const auto i : c10::irange(64)) { + // for (const auto j : c10::irange(32)) { + // D[i, j] = i * (j + 1) - i; + // } + // } + // } + + // We can also apply algebraic simplification to a statement: + StmtPtr simplified = IRSimplifier::simplify(nest.root_stmt()); + std::cout << "Stmt after simplification:" << std::endl + << *simplified << std::endl; + // Prints: + // Stmt after simplification: + // { + // for (const auto i : c10::irange(64)) { + // for (const auto j : c10::irange(32)) { + // D[i, j] = i * j; + // } + // } + // } + + // Many loopnest transformations are stateless and can be applied without + // creating a LoopNest object. In fact, we plan to make all transformations + // stateless. + // splitWithTail is one such transformation: it splits an iteration space + // of a given loop into two with a given factor. + ForPtr outer_loop = to(to(simplified)->stmts().front()); + LoopNest::splitWithTail(outer_loop, 13); + // Call simplifier once more to fold some arithmetic. + simplified = IRSimplifier::simplify(simplified); + std::cout << "Stmt after splitWithTail:" << std::endl + << *simplified << std::endl; + // Prints: + // Stmt after splitWithTail: + // { + // for (const auto i_outer : c10::irange(4)) { + // for (const auto i_inner : c10::irange(13)) { + // for (const auto j : c10::irange(32)) { + // D[i_inner + 13 * i_outer, j] = i_inner * j + 13 * (i_outer * j); + // } + // } + // } + // for (const auto i_tail : c10::irange(12)) { + // for (const auto j : c10::irange(32)) { + // D[i_tail + 52, j] = i_tail * j + 52 * j; + // } + // } + // } + + // NNC supports a wide range of loop nest transformations, which we are not + // listing here. Please refer to documentation in + // https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/tensorexpr/loopnest.h + // for more details. + } + + std::cout << "*** Codegen ***" << std::endl; + { + // An ultimate goal of tensor expressions is to be provide a mechanism to + // execute a given computation in the fastest possible way. So far we've + // looked at how we could describe what computation we're interested in, but + // we haven't looked at how to actually execute it. + // + // All we've been dealing with was just symbols with no actual data + // associated, in this section we would look at how we can bridge that gap. + + // Let's start by constructing a simple computation for us to work with: + BufHandle A("A", {64, 32}, kInt); + BufHandle B("B", {64, 32}, kInt); + Tensor X = + Compute("X", {64, 32}, [&](const VarHandle& i, const VarHandle& j) { + return A.load(i, j) + B.load(i, j); + }); + + // And let's lower it to a loop nest, as we did in the previous section. We + // can pass Tensor object directly: + LoopNest loopnest({X}); + std::cout << *loopnest.root_stmt() << std::endl; + // Prints: + // { + // for (const auto i : c10::irange(64)) { + // for (const auto j : c10::irange(32)) { + // X[i, j] = (A[i, j]) + (B[i, j]); + // } + // } + + // Now imagine that we have two actual tensors 64x32 that we want sum + // together, how do we pass those tensors to the computation and how do we + // carry it out? + // + // Codegen object is aimed at providing exactly that functionality. Codegen + // is an abstract class and concrete codegens are derived from it. + // Currently, we have three codegens: + // 1) Simple Evaluator, + // 2) LLVM Codegen for CPU, + // 3) CUDA Codegen. + // In this example we will be using Simple Evaluator, since it's available + // everywhere. + + // To create a codegen, we need to provide the statement - it specifies the + // computation we want to perform - and a list of placeholders and tensors + // used in the computation. The latter part is crucial since that's the only + // way the codegen could use to correlate symbols in the statement to actual + // data arrays that we will be passing when we will actually be performing + // the computation. + // + // Let's create a Simple IR Evaluator codegen for our computation: + SimpleIREvaluator ir_eval(loopnest.root_stmt(), {A, B, X}); + + // We are using the simplest codegen and in it almost no work is done at the + // construction step. Real codegens such as CUDA and LLVM perform + // compilation during that stage so that when we're about to run the + // computation everything is ready. + + // Let's now create some inputs and run our computation with them: + std::vector data_A(64 * 32, 3); // This will be the input A + std::vector data_B(64 * 32, 5); // This will be the input B + std::vector data_X(64 * 32, 0); // This will be used for the result + + // Now let's invoke our codegen to perform the computation on our data. We + // need to provide as many arguments as how many placeholders and tensors we + // passed at the codegen construction time. A position in these lists would + // define how real data arrays from the latter call (these arguments are + // referred to as 'CallArg's in our codebase) correspond to symbols + // (placeholders and tensors) used in the tensor expressions we constructed + // (these are referred to as 'BufferArg'). + // Thus, we will provide three arguments: data_A, data_B, and data_X. data_A + // contains data for the placeholder A, data_B - for the placeholder B, and + // data_X would be used for contents of tensor X. + ir_eval(data_A, data_B, data_X); + + // Let's print one of the elements from each array to verify that the + // computation did happen: + std::cout << "A[10] = " << data_A[10] << std::endl + << "B[10] = " << data_B[10] << std::endl + << "X[10] = A[10] + B[10] = " << data_X[10] << std::endl; + // Prints: + // A[10] = 3 + // B[10] = 5 + // X[10] = A[10] + B[10] = 8 + } + + std::cout << "*** Lowering TorchScript IR to TensorExpr IR ***" << std::endl; + { + // This section requires a LLVM-enabled PyTorch build, so we have to use a + // guard: +#ifdef TORCH_ENABLE_LLVM + + // Often we would like to convert a TorchScript IR to TE rather than + // construct TE IR from scratch. NNC provides an API to perform such + // lowering: it takes a TorchScript graph and returns an object that can be + // used to invoke the generated kernel. + // This API is currently used by the TorchScript JIT fuser and can also be + // used ahead of time to pre-compile parts of a model. + // + // To get familiar with this API let's first start with defining a simple + // TorchScript graph: + const auto graph_string = R"IR( + graph(%A : Float(5, 3, strides=[3, 1], device=cpu), + %B : Float(5, 3, strides=[3, 1], device=cpu)): + %AB : Float(5, 3, strides=[3, 1]) = aten::mul(%A, %B) + %one : int = prim::Constant[value=1]() + %AAB : Float(5, 3, strides=[3, 1]) = aten::mul(%A, %AB) + %AAB_plus_B: Float(5, 3, strides=[3, 1]) = aten::add(%AAB, %B, %one) + return (%AAB_plus_B))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + // This graph defines a simple computation of A*A*B + B where A and B are + // input 5x3 tensors. + + // To lower this TorchScript graph to TE, we just need to create a + // TensorExprKernel object. In its constructor it constructs the + // corresponding TE IR and compiles it for the given backend (in this + // example for CPU using LLVM compiler). + TensorExprKernel kernel(graph); + + // We can retrieve the generated TE stmt from the kernel object: + StmtPtr kernel_stmt = kernel.getCodeGenStmt(); + std::cout << "TE Stmt constructed from TorchScript: " << std::endl + << *kernel_stmt << std::endl; + // Prints: + // TE Stmt constructed from TorchScript: + // { + // for (const auto v : c10::irange(5)) { + // for (const auto _tail_tail : c10::irange(3)) { + // aten_add[_tail_tail + 3 * v] = (tA[_tail_tail + 3 * v]) * + // ((tA[_tail_tail + 3 * v]) * (tB[_tail_tail + 3 * v])) + + // (tB[_tail_tail + 3 * v]); + // } + // } + // } + + // We can also examine generated LLVM IR and assembly code: + std::cout << "Generated LLVM IR: " << std::endl; + auto ir_str = kernel.getCodeText("ir"); + printLinesToFrom(ir_str, 15, 20); + // Prints: + // Generated LLVM IR: + // %9 = bitcast float* %2 to <8 x float>* + // %10 = load <8 x float>, <8 x float>* %9 ... + // %11 = bitcast float* %5 to <8 x float>* + // %12 = load <8 x float>, <8 x float>* %11 ... + // %13 = fmul <8 x float> %10, %12 + // %14 = fmul <8 x float> %10, %13 + + std::cout << "Generated assembly: " << std::endl; + auto asm_str = kernel.getCodeText("asm"); + printLinesToFrom(asm_str, 10, 15); + // Prints: + // Generated assembly: + // vmulps %ymm1, %ymm0, %ymm2 + // vfmadd213ps %ymm1, %ymm0, %ymm2 + // vmovups %ymm2, (%rax) + // vmovss 32(%rcx), %xmm0 + // vmovss 32(%rdx), %xmm1 + // vmulss %xmm1, %xmm0, %xmm2 + + // We can also execute the generated kernel: + auto A = + at::ones({5, 3}, torch::TensorOptions(torch::kCPU).dtype(at::kFloat)) * + 2.0; + auto B = + at::ones({5, 3}, torch::TensorOptions(torch::kCPU).dtype(at::kFloat)) * + 3.0; + std::vector inputs = {A, B}; + std::vector stack = torch::fmap(inputs); + kernel.run(stack); + auto R = stack[0].toTensor(); + + // Let's print one of the elements from the result tensor to verify that the + // computation did happen and was correct: + std::cout << "R[2][2] = " << R[2][2] << std::endl; + // Prints: + // R[2][2] = 15 + // [ CPUFloatType{} ] +#endif + } + return 0; +} + +void printLinesToFrom(const std::string& input_str, int from, int to) { + std::istringstream f(input_str); + std::string s; + int idx = 0; + while (getline(f, s)) { + if (idx > from) { + std::cout << s << "\n"; + } + if (idx++ > to) { + break; + } + } +} diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp index 306a882627d4b..c7c5a6140b7a2 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp @@ -1,4 +1,5 @@ #include +<<<<<<< HEAD #include #include #include @@ -9,6 +10,10 @@ #ifdef LAE_USE_CUDA #include #endif +======= +#include +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include @@ -40,6 +45,7 @@ Tensor sgd_out_of_place( const float weight_decay, const double lr, const bool maximize) { +<<<<<<< HEAD STD_TORCH_CHECK(param.dim() == 1, "param must be 1D"); // these test the get_device() and get_device_index() methods @@ -47,6 +53,8 @@ Tensor sgd_out_of_place( STD_TORCH_CHECK(param.get_device() == -1, "CPU device index = -1"); STD_TORCH_CHECK(param.get_device_index() == -1, "CPU device index = -1"); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int64_t *param_sizes; int64_t *param_strides; aoti_torch_get_sizes(param.get(), ¶m_sizes); @@ -140,10 +148,19 @@ Tensor my_ones_like(Tensor t, StableIValue device) { const auto num_args = 6; StableIValue stack[num_args]; +<<<<<<< HEAD auto mf = aoti_torch_memory_format_contiguous_format(); stack[0] = from(t); stack[1] = from(std::optional(t.scalar_type())); // dtype +======= + int32_t t_dtype; + aoti_torch_get_dtype(t.get(), &t_dtype); + auto mf = aoti_torch_memory_format_contiguous_format(); + + stack[0] = from(t); + stack[1] = from(std::optional(t_dtype)); // dtype +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) stack[2] = from(std::nullopt); // layout stack[3] = from(std::optional(device)); // device stack[4] = from(std::optional(false)); // pin_memory @@ -267,6 +284,7 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { m.impl("is_contiguous", &boxed_is_contiguous); } +<<<<<<< HEAD Tensor my_transpose(Tensor t, int64_t dim0, int64_t dim1) { return transpose(t, dim0, dim1); @@ -552,3 +570,5 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { m.impl("test_get_current_device_index", &boxed_test_get_current_device_index); } #endif // LAE_USE_CUDA +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py index 074461d352740..a7e7e648e84d8 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py @@ -51,6 +51,7 @@ def my_abs(t) -> Tensor: return torch.ops.libtorch_agnostic.my_abs.default(t) +<<<<<<< HEAD def my_is_cpu(t) -> bool: """ Returns is_cpu on the input tensor. @@ -64,6 +65,8 @@ def my_is_cpu(t) -> bool: return torch.ops.libtorch_agnostic.my_is_cpu.default(t) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def my_ones_like(tensor, device) -> Tensor: """ Returns a new Tensor like the input tensor, but with all ones @@ -129,6 +132,7 @@ def is_contiguous(t) -> bool: Returns: is_contiguous(t) """ return torch.ops.libtorch_agnostic.is_contiguous.default(t) +<<<<<<< HEAD def my_transpose(t, dim0, dim1) -> Tensor: @@ -307,3 +311,5 @@ def my_new_zeros_dtype_variant(t) -> Tensor: Returns: New zeros tensor """ return torch.ops.libtorch_agnostic.my_new_zeros_dtype_variant.default(t) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/setup.py b/test/cpp_extensions/libtorch_agnostic_extension/setup.py index b7141a3e6fcd6..f0f7a3efc028f 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/setup.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/setup.py @@ -4,8 +4,12 @@ from setuptools import find_packages, setup +<<<<<<< HEAD import torch from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension +======= +from torch.utils.cpp_extension import BuildExtension, CppExtension +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ROOT_DIR = Path(__file__).parent @@ -36,6 +40,7 @@ def get_extension(): "cxx": ["-fdiagnostics-color=always"], } +<<<<<<< HEAD extension = CppExtension # allow including if torch.cuda.is_available(): @@ -46,6 +51,12 @@ def get_extension(): return [ extension( +======= + sources = list(CSRC_DIR.glob("**/*.cpp")) + + return [ + CppExtension( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "libtorch_agnostic._C", sources=sorted(str(s) for s in sources), py_limited_api=True, diff --git a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py index 0f471e8132a60..26e5f6729a855 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py @@ -1,11 +1,17 @@ # Owner(s): ["module: cpp"] +<<<<<<< HEAD import math +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from pathlib import Path import torch from torch.testing._internal.common_device_type import ( +<<<<<<< HEAD deviceCountAtLeast, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) instantiate_device_type_tests, onlyCPU, onlyCUDA, @@ -175,6 +181,7 @@ def _make_cuda_tensors(prior_mem): curr_mem = torch.cuda.memory_allocated(device) self.assertEqual(curr_mem, init_mem) +<<<<<<< HEAD def test_my_transpose(self, device): import libtorch_agnostic @@ -345,6 +352,8 @@ def test_my_new_zeros_dtype_variant(self, device): ref_out = t.new_zeros((2, 5), dtype=torch.float) self.assertEqual(out, ref_out, exact_device=True) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None) if __name__ == "__main__": diff --git a/test/cpp_extensions/mps_extension.mm b/test/cpp_extensions/mps_extension.mm index 30b70a76563d6..4dacc47dc9ae2 100644 --- a/test/cpp_extensions/mps_extension.mm +++ b/test/cpp_extensions/mps_extension.mm @@ -13,11 +13,14 @@ kernel void add_arrays(device const float* inA, { result[index] = inA[index] + inB[index]; } +<<<<<<< HEAD kernel void add_one(device float* data, uint index [[thread_position_in_grid]]) { data[index] += 1.0; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) )MPS_ADD_ARRAYS"); at::Tensor get_cpu_add_output(at::Tensor & cpu_input1, at::Tensor & cpu_input2) { @@ -55,6 +58,7 @@ kernel void add_one(device float* data, return mps_output; } +<<<<<<< HEAD void mps_add_one_new_encoder(const at::Tensor& input) { using namespace at::native::mps; TORCH_CHECK(input.is_mps()); @@ -82,4 +86,9 @@ void mps_add_one_new_encoder(const at::Tensor& input) { m.def("get_cpu_add_output", &get_cpu_add_output); m.def("get_mps_add_output", &get_mps_add_output); m.def("mps_add_one_new_context", &mps_add_one_new_encoder); +======= +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("get_cpu_add_output", &get_cpu_add_output); + m.def("get_mps_add_output", &get_mps_add_output); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } diff --git a/test/cpp_extensions/open_registration_extension.cpp b/test/cpp_extensions/open_registration_extension.cpp index fbd53b96234b2..ba69c7d1c3bdd 100644 --- a/test/cpp_extensions/open_registration_extension.cpp +++ b/test/cpp_extensions/open_registration_extension.cpp @@ -139,6 +139,39 @@ void fallback_with_undefined_tensor() { grad_scale, found_inf); } +<<<<<<< HEAD +======= +struct CustomAutogradFnReturnsSelf : public torch::autograd::Function { + + static at::Tensor forward(torch::autograd::AutogradContext* ctx, at::Tensor self) { + return self; + } + + static torch::autograd::variable_list backward(torch::autograd::AutogradContext* ctx, torch::autograd::variable_list grad_output) { + return {grad_output[0] * 0.5}; + } +}; + +struct CustomAutogradFnAliasing : public torch::autograd::Function { + + static at::Tensor forward(torch::autograd::AutogradContext* ctx, at::Tensor self) { + return self.view_symint(self.sym_sizes()); + } + + static torch::autograd::variable_list backward(torch::autograd::AutogradContext* ctx, torch::autograd::variable_list grad_output) { + return {grad_output[0] * 0.5}; + } +}; + +at::Tensor custom_autograd_fn_returns_self(at::Tensor x) { + return CustomAutogradFnReturnsSelf::apply(x); +} + +at::Tensor custom_autograd_fn_aliasing(at::Tensor x) { + return CustomAutogradFnAliasing::apply(x); +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Here, we're exposing a custom device object that corresponds to our custom backend. // We do this using pybind: exposing an "extension_name.custom_device()" function in python, // that's implemented in C++. @@ -149,4 +182,17 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("custom_storage_registry", &custom_storage_registry, "set custom storageImpl creat method"); m.def("custom_storageImpl_called", &custom_storageImpl_called, "check if our custom abs function was called"); m.def("fallback_with_undefined_tensor", &fallback_with_undefined_tensor, "fallback_with_undefined_tensor for privateuse1"); +<<<<<<< HEAD +======= + + // Co-opting this file to more easily test torch.compile'ing of custom autograd functions in C++ + m.def("custom_autograd_fn_returns_self", &custom_autograd_fn_returns_self); +} + +TORCH_LIBRARY(_test_funcs, m) { + m.def("custom_autograd_fn_aliasing(Tensor(a) input)-> Tensor(a)"); +} +TORCH_LIBRARY_IMPL(_test_funcs, AutogradCPU, m) { + m.impl("custom_autograd_fn_aliasing", &custom_autograd_fn_aliasing); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } diff --git a/test/cpp_extensions/open_registration_extension/README.md b/test/cpp_extensions/open_registration_extension/README.md new file mode 100644 index 0000000000000..24fec68c31835 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/README.md @@ -0,0 +1,37 @@ +# PyTorch OpenReg + +This folder contains a self-contained example of a PyTorch out-of-tree backend leveraging the "PrivateUse1" backend from core. + +## How to use + +Install as standalone with `python setup.py develop` (or install) from this folder. +You can run test via `python {PYTORCH_ROOT_PATH}/test/test_openreg.py`. + +## Design principles + +For simplicity anything that can be implemented from python is done so. +A real implementation will most likely want to call these different APIs from c++ directly. + +The current version sends everything back to python and contains enough implementation to run basic model, transfer host/device and printing. + +The codebase is split as follows: + +- `pytorch_openreg/__init__.py` + - imports torch to get core state initialized. + - imports `._aten_impl` to register our aten op implementations to torch. + - imports `.C` to load our c++ extension that registers more ops, allocator and hooks. + - renames the PrivateUse1 backend and register our python-side module. +- `pytorch_openreg/_aten_impl.py` + - Define a new `torch.Library` that registers a fallback that will be called whenever a backend kernel for PrivateUse1 is called. It contains the logic to handle all kind of native functions, computing the output metadata, allocating it and only calling into the device daemon to perform computation. +- `pytorch_openreg/_device_daemon.py` + - contains the Allocator (responsible for allocating memory on the device side and host side, as int8 buffers). + - contains `Driver`, which as user-process driver to deal with some information needed to be done in driver. + - contains `Executor`, which as device-process exector to do something related device logic. +- `pytorch_openreg/_meta_parser.py` mainly contain utilities to send objects over the wire from the user process to the device process. + - The main class there is `OpenRegTensorMeta` that contains all the metadata sent to the device which should be enough for it to populate the output Tensor. + +## Next steps + +The main next step would be to: + +- Replace the current `open_registration_extension.cpp` test in PyTorch CI with this. diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/__init__.py b/test/cpp_extensions/open_registration_extension/pytorch_openreg/__init__.py new file mode 100644 index 0000000000000..05b8955b6557b --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/__init__.py @@ -0,0 +1,122 @@ +import types + +import torch + +# Create our python implementation dict so that the C++ module +# can access it during its initialization and also register aten impls. +from ._aten_impl import impl_factory as impl_factory # noqa: F401 +from ._device_daemon import driver + + +# Load the C++ Module +import pytorch_openreg._C # isort:skip # type: ignore[import] # noqa: F401 + + +def _create_module(): + module = types.ModuleType("_OpenRegMod") + + class device: + r"""Context-manager that changes the selected device. + + Args: + device (torch.device or int): device index to select. It's a no-op if + this argument is a negative integer or ``None``. + """ + + def __init__(self, device): + self.idx = torch.accelerator._get_device_index(device, optional=True) + self.prev_idx = -1 + + def __enter__(self): + self.prev_idx = driver.exec("exchangeDevice", self.idx) + + def __exit__(self, type, value, traceback): + self.idx = driver.exec("uncheckedSetDevice", self.prev_idx) + return False + + def device_count() -> int: + return driver.exec("deviceCount") + + def is_available(): + return True + + def current_device(): + return torch.accelerator.current_device_index() + + def get_rng_state(device="openreg"): + if isinstance(device, str): + device = torch.device(device) + elif isinstance(device, int): + device = torch.device("openreg", device) + idx = device.index + if idx is None: + idx = current_device() + default_generator = pytorch_openreg._C._get_default_generator(idx) + return default_generator.get_state() + + def set_rng_state(new_state, device="openreg"): + if isinstance(device, str): + device = torch.device(device) + elif isinstance(device, int): + device = torch.device("openreg", device) + idx = device.index + if idx is None: + idx = current_device() + default_generator = pytorch_openreg._C._get_default_generator(idx) + default_generator.set_state(new_state) + + def initial_seed() -> int: + _lazy_init() + idx = current_device() + default_generator = pytorch_openreg._C._get_default_generator(idx) + return default_generator.initial_seed() + + def manual_seed(seed: int) -> None: + seed = int(seed) + + idx = current_device() + default_generator = pytorch_openreg._C._get_default_generator(idx) + default_generator.manual_seed(seed) + + def manual_seed_all(seed: int) -> None: + seed = int(seed) + + for idx in range(device_count()): + default_generator = pytorch_openreg._C._get_default_generator(idx) + default_generator.manual_seed(seed) + + def is_initialized(): + return module._initialized + + def _is_in_bad_fork(): + return False + + def _lazy_init(): + if is_initialized(): + return + pytorch_openreg._C._init() + module._initialized = True + + module.is_available = is_available # type: ignore[assignment] + + module._initialized = False # type: ignore[assignment] + module._lazy_init = _lazy_init # type: ignore[assignment] + module.is_initialized = is_initialized # type: ignore[assignment] + + module.device = device # type: ignore[assignment] + module.device_count = device_count # type: ignore[assignment] + module.current_device = current_device # type: ignore[assignment] + module.get_rng_state = get_rng_state # type: ignore[assignment] + module.set_rng_state = set_rng_state # type: ignore[assignment] + module._is_in_bad_fork = _is_in_bad_fork # type: ignore[assignment] + module.initial_seed = initial_seed # type: ignore[assignment] + module.manual_seed = manual_seed # type: ignore[assignment] + module.manual_seed_all = manual_seed_all # type: ignore[assignment] + + return module + + +# Set all the appropriate state on PyTorch +torch.utils.rename_privateuse1_backend("openreg") +torch._register_device_module("openreg", _create_module()) +torch.utils.generate_methods_for_privateuse1_backend(for_storage=True) diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py new file mode 100644 index 0000000000000..d4c49bd28d458 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py @@ -0,0 +1,186 @@ +import logging + +import torch +from torch.utils._pytree import tree_any + + +log = logging.getLogger(__name__) + +from ._device_daemon import driver +from ._meta_parser import prepare_for_sending, to_device_no_copy + + +_IMPL_REGISTRY = {} + + +def impl_factory(name): + if name in _IMPL_REGISTRY: + return _IMPL_REGISTRY[name] + + def _(*args, **kwargs): + log.info("Calling hook %s", name) + return driver.exec(name, *args, **kwargs) + + _IMPL_REGISTRY[name] = _ + return _ + + +def _openreg_kernel_fallback(op, *args, **kwargs): + def get_tensor_device(*args): + for arg in args: + if isinstance(arg, torch.Tensor) and arg.device.type == "openreg": + return arg.device + + device = get_tensor_device(*args) + if device is None: + return _kernel_fallback(op, *args, **kwargs) + + # Mimicks the DeviceGuard system we have in aten + with torch.openreg.device(device): # type: ignore[misc] + return _kernel_fallback(op, *args, **kwargs) + + +def _kernel_fallback(op, *args, **kwargs): + log.info("Calling kernel %s", op) + + op_name = None + post_process = None + if "out" in op._overloadname: + # Note that all structured native op will call here + if isinstance(kwargs["out"], tuple): + raise RuntimeError(f"out= variant {op} with tuple out= not supported") + if kwargs["out"].nelement() == 0: + # Out variant that needs a resize, convert to an out of place + # and handle generically below + orig_out = kwargs["out"] + del kwargs["out"] + if op._overloadname != "out": + raise RuntimeError( + "Cannot retranslate non-default out= variant form 0 size" + ) + op = op.overloadpacket.default + + def _post_process(): + nonlocal real_res + orig_out.set_(real_res) + real_res = orig_out + + post_process = _post_process + + else: + # No metadata update to do, just run the op on the device + op_name = op.overloadpacket._qualified_op_name + real_res = kwargs["out"] + elif not tree_any(lambda obj: isinstance(obj, torch.Tensor), (args, kwargs)): + # No Tensor argument means factory function + # They should decompose and be handled in our c++ side directly + raise RuntimeError(f"{op} not handled yet.") + elif op._schema.is_mutable or op is torch.ops.aten._copy_from.default: + # Only handle inplace ops returning their first arg + assert len(args) >= 1, f"Inplace {op} needs at least one arg" + assert len(op._schema.returns) == 1, ( + f"NYI Inplace {op} with more than one return" + ) + op_name = op.overloadpacket._qualified_op_name + real_res = args[0] + elif any(r.alias_info is not None for r in op._schema.returns): + # View ops + if op is torch.ops.aten.view.default: + return torch.ops.aten._unsafe_view(*args, **kwargs) + raise RuntimeError(f"{op} view op is not handled yet") + + if op_name is None: + # 1. Compute updated metadata + if torch.Tag.dynamic_output_shape not in op.tags: + # Usual case: run the meta op to see the output metadata + meta_args, meta_kwargs = to_device_no_copy("meta", args, kwargs) + meta_res = op(*meta_args, **meta_kwargs) + + # 2. Allocate the output + real_res, _ = to_device_no_copy("openreg", meta_res, {}) + else: + # Slow version for data-dependent functions: + # Run the op on the device just to get the output shape + args_, kwargs_ = prepare_for_sending(args, kwargs) + shape = driver.exec( + "get_op_output_shape", + op.overloadpacket._qualified_op_name, + args_, + kwargs_, + ) + + # 2. Allocate the output + real_res = args[0].new(shape) + + # 3. Move to out variant + kwargs["out"] = real_res + # Let overload resolution find the out= overload + op_name = op.overloadpacket._qualified_op_name + + # 4. Run the compute and populate the output on the device + args, kwargs = prepare_for_sending(args, kwargs) + driver.exec("run_op", op_name, args, kwargs) + + if post_process is not None: + post_process() + + return real_res + + +def copy_from_device(from_): + with torch.openreg.device(from_.device): # type: ignore[misc] + args, _ = prepare_for_sending((from_,), {}) + return driver.exec("send_data", *args) + + +def copy_from_host_to_device(from_, to_): + with torch.openreg.device(to_.device): # type: ignore[misc] + args, _ = prepare_for_sending((to_,), {}) + driver.exec("recv_data", from_, *args) + return to_ + + +def _copy_from(from_, to_): + if from_.device.type == to_.device.type: + assert from_.device.type == "openreg" + if from_.device.index == to_.device.index: + op = torch.ops.aten.copy_.default + return _openreg_kernel_fallback(op, to_, from_) + else: + host_mem = copy_from_device(from_) + return copy_from_host_to_device(host_mem, to_) + elif from_.device.type == "openreg": + host_mem = copy_from_device(from_) + return to_.copy_(host_mem) + elif to_.device.type == "openreg": + return copy_from_host_to_device(from_, to_) + else: + raise RuntimeError("Should not happen") + + +def _set_source_tensor(ten1, ten2): + return torch.ops.aten.set_.source_Storage_storage_offset( + ten1, + ten2.untyped_storage(), + ten2.storage_offset(), + ten2.size(), + ten2.stride(), + ) + + +def _local_scalar_dense(ten): + host_mem = copy_from_device(ten) + return host_mem.item() + + +_openreg_lib = torch.library.Library("_", "IMPL") +_openreg_lib.fallback(_openreg_kernel_fallback, dispatch_key="PrivateUse1") + +_openreg_lib_aten = torch.library.Library("aten", "IMPL") +_openreg_lib_aten.impl("_copy_from", _copy_from, dispatch_key="PrivateUse1") +_openreg_lib_aten.impl( + "set_.source_Tensor", _set_source_tensor, dispatch_key="PrivateUse1" +) +_openreg_lib_aten.impl( + "_local_scalar_dense", _local_scalar_dense, dispatch_key="PrivateUse1" +) diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_device_daemon.py b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_device_daemon.py new file mode 100644 index 0000000000000..d339869635001 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_device_daemon.py @@ -0,0 +1,391 @@ +import ctypes +import logging +import threading +import time + +import torch + +from ._meta_parser import ( + OpenRegTensorData, + receive_after_sending, + safe_str, + validate_send_queue_args, +) + + +log = logging.getLogger(__name__) +mp_context = torch.multiprocessing.get_context("spawn") + +# Constant properties of our device +NUM_DEVICES = 2 + + +# Our allocator +class Allocator: + def __init__(self): + self.allocated = {} + + def malloc(self, size): + mem = ctypes.create_string_buffer(size) + ptr = ctypes.addressof(mem) + self.allocated[ptr] = (size, mem) + return ptr + + def free(self, ptr): + if ptr not in self.allocated: + return False + else: + del self.allocated[ptr] + return True + + +class HostAllocator(Allocator): + def is_pinned_ptr(self, ptr): + return ptr in self.allocated or any( + ptr_ <= ptr and ptr < ptr_ + size + for ptr_, (size, _) in self.allocated.items() + ) + + +class DeviceAllocator(Allocator): + def tensor_from_meta(self, meta): + def create_tensor_from_data_ptr(ptr, size): + storage = torch._C._construct_storage_from_data_pointer( + ptr, torch.device("cpu"), size + ) + return torch.Tensor(storage) + + found_base = None + # Usual case, we're receiving a known Tensor + if meta.data_ptr in self.allocated: + found_base = create_tensor_from_data_ptr( + meta.data_ptr, self.allocated[meta.data_ptr][0] + ) + + # Might be a rewrap of another storage at a different offset + # Slow path to try and find the corresponding storage + if found_base is None: + for tag, (size, _) in self.allocated.items(): + # t is always a 1D uint8 storage! + if meta.data_ptr > tag and meta.data_ptr < tag + size: + # Blame @ngimel for this + slice_size = size - (meta.data_ptr - tag) + found_base = create_tensor_from_data_ptr(meta.data_ptr, slice_size) + + # Might be an empty tensor + if found_base is None and meta.nelem_in_bytes == 0: + found_base = torch.tensor((), dtype=torch.uint8) + + # This pointer is not allocated here, segfault ! + if found_base is None: + log.info("Currently allocated blocks:\n %s", safe_str(self.allocated)) + log.info("Trying to access %s", meta) + raise RuntimeError("SEGFAULT!") + + # Raw 1d uint8 data + raw = found_base + # Reinterpret cast in the right dtype + as_dtype = raw.view(dtype=meta.dtype) + # View to the right shape/stride/offset + view = as_dtype.as_strided(meta.size, meta.stride, meta.storage_offset) + return view + + +def register(registry): + def func(fn): + registry[fn.__name__] = fn + return fn + + return func + + +class Driver: + def __init__(self, num_devices): + super().__init__() + self.num_devices = num_devices + self.is_initialized = False + + # State of our driver + self.curr_device_idx = 0 + self.curr_streams = {} + + # Allocated memory belongs to which device + self.memory_belong = {} + self.host_allocator = HostAllocator() + self.event_belong = {} + + self.rlock = threading.RLock() + + def _lazy_init(self): + if self.is_initialized: + return + self.devices = [] + + for i in range(self.num_devices): + req_queue = mp_context.Queue() + ans_queue = mp_context.Queue() + runner = mp_context.Process( + target=_Executor(i).run_forever, + args=(req_queue, ans_queue), + daemon=True, + ) + runner.start() + self.devices.append((req_queue, ans_queue, runner)) + + self.is_initialized = True + + def exec(self, cmd, *args): + with self.rlock: + log.info("Main process launched: %s(*%s)", cmd, safe_str(args)) + + if cmd in Driver.registry: + res = Driver.registry[cmd](self, *args) + else: + res = self.run_on_executor(self.curr_device_idx, cmd, *args) + + log.info("Main process result for %s received: %s", cmd, safe_str(res)) + if res == "ERROR": + raise RuntimeError(f"Error in daemon while executing {cmd}, see logs") + else: + return res + + def run_on_executor(self, device_idx, cmd, *args): + self._lazy_init() + req_queue, ans_queue, _ = self.devices[device_idx] + stream = self.getStream(device_idx) + validate_send_queue_args(cmd, args) + req_queue.put((stream, cmd) + args) + return ans_queue.get() + + registry = {} + + @register(registry) + def hasPrimaryContext(self, device_idx): + return device_idx >= 0 and device_idx < self.num_devices + + @register(registry) + def deviceCount(self, *args): + assert len(args) == 0 + return self.num_devices + + @register(registry) + def getDevice(self): + return self.curr_device_idx + + @register(registry) + def setDevice(self, device_idx): + assert device_idx >= 0 and device_idx < self.num_devices + self.curr_device_idx = device_idx + + @register(registry) + def uncheckedSetDevice(self, *args): + assert len(args) == 1 + self.curr_device_idx = int(args[0]) + + @register(registry) + def exchangeDevice(self, *args): + assert len(args) == 1 + res = self.curr_device_idx + self.curr_device_idx = int(args[0]) + return res + + @register(registry) + def malloc(self, size): + ptr = self.run_on_executor(self.curr_device_idx, "malloc", size) + self.memory_belong[ptr] = self.curr_device_idx + return ptr + + @register(registry) + def free(self, ptr): + device_idx = self.memory_belong.pop(ptr, None) + if device_idx is None: + return False + return self.run_on_executor(device_idx, "free", ptr) + + @register(registry) + def isPinnedPtr(self, ptr): + return self.host_allocator.is_pinned_ptr(ptr) + + @register(registry) + def hostMalloc(self, size): + return self.host_allocator.malloc(size) + + @register(registry) + def hostFree(self, ptr): + return self.host_allocator.free(ptr) + + @register(registry) + def getNewStream(self, device_idx, priority): + return self.run_on_executor(device_idx, "getNewStream", priority) + + @register(registry) + def queryStream(self, stream): + return self.run_on_executor( + stream.device_index, "queryStream", stream.stream_id + ) + + @register(registry) + def getStream(self, device_idx): + return self.curr_streams.get(device_idx, 0) + + @register(registry) + def exchangeStream(self, stream): + stream_id = self.curr_streams.get(stream.device_index, 0) + self.curr_streams[stream.device_index] = stream.stream_id + return stream_id + + @register(registry) + def synchronizeStream(self, stream): + self.run_on_executor(stream.device_index, "synchronizeStream", stream.stream_id) + + @register(registry) + def record(self, event, stream, device_index, flags): + event_ptr = ctypes.cast(event, ctypes.POINTER(ctypes.c_int64)) + # Create event if needed + if event_ptr.contents.value == 0: + event_ptr.contents.value = self.run_on_executor( + stream.device_index, "eventCreateWithFlags", flags + ) + self.event_belong[event_ptr.contents.value] = stream.device_index + + # Record event + self.run_on_executor( + stream.device_index, + "eventRecord", + event_ptr.contents.value, + stream.stream_id, + ) + + @register(registry) + def destroyEvent(self, event, device_index): + self.run_on_executor(device_index, "eventDestroy", event) + self.event_belong.pop(event) + + @register(registry) + def synchronizeEvent(self, event): + self.run_on_executor(self.event_belong[event], "eventSynchronize", event) + + @register(registry) + def queryEvent(self, event): + return self.run_on_executor(self.event_belong[event], "eventQuery", event) + + @register(registry) + def elapsedTime(self, e1, e2, device_index): + return self.run_on_executor(device_index, "eventElapsedTime", e1, e2) + + @register(registry) + def block(self, event, stream): + self.run_on_executor(stream.device_index, "block", event, stream.stream_id) + + +class _Executor: + def __init__(self, id): + self.id = id + self.allocator = DeviceAllocator() + self.stream = 0 + self.event_incr_id = 0 + self.events = {} + + def run_forever(self, req_queue, ans_queue): + # Serve all requests + while True: + # Ignore stream since cpu backend doesn't support asynchronous execution + _, cmd, *args = req_queue.get() + log.info("Worker executing: %s", cmd) + if cmd in _Executor.registry: + res = _Executor.registry[cmd](self, *args) + else: + log.warning("Bad command in worker") + res = "ERROR" + + log.info("Worker answering to: %s", cmd) + ans_queue.put(res) + + registry = {} + + @register(registry) + def malloc(self, size): + return self.allocator.malloc(size) + + @register(registry) + def free(self, ptr): + return self.allocator.free(ptr) + + def _run_op(self, op_name, args, kwargs): + op, _ = torch._C._jit_get_operation(op_name) + args, kwargs = receive_after_sending(self.allocator, args, kwargs) + return op(*args, **kwargs) + + @register(registry) + def run_op(self, op_name, args, kwargs): + self._run_op(op_name, args, kwargs) + + @register(registry) + def get_op_output_shape(self, op_name, args, kwargs): + return self._run_op(op_name, args, kwargs).size() + + @register(registry) + def send_data(self, *args): + assert len(args) == 1 + return OpenRegTensorData.from_meta(self.allocator, args[0]) + + @register(registry) + def recv_data(self, host_tensor, dev_mem): + dev_tensor = OpenRegTensorData.from_meta(self.allocator, dev_mem) + dev_tensor.copy_(host_tensor) + + @register(registry) + def getNewStream(self, priority): + self.stream += 1 + return self.stream + + @register(registry) + def queryStream(self, stream): + return True + + @register(registry) + def synchronizeStream(self, stream): + # no-op + pass + + @register(registry) + def eventCreateWithFlags(self, flags): + self.event_incr_id += 1 + self.events[self.event_incr_id] = [flags, None] + return self.event_incr_id + + @register(registry) + def eventRecord(self, event, stream): + # Only flags == 1 enables timing + if self.events[event][0] == 1: + self.events[event][1] = time.time() * 1000 + return 0 + + @register(registry) + def eventDestroy(self, event): + self.events.pop(event) + + @register(registry) + def eventSynchronize(self, event): + assert self.events.get(event) is not None + return 0 + + @register(registry) + def eventQuery(self, event): + assert self.events.get(event) is not None + return True + + @register(registry) + def eventElapsedTime(self, e1, e2): + time_1 = self.events[e1][1] + time_2 = self.events[e2][1] + assert time_1 is not None and time_2 is not None + return time_2 - time_1 + + @register(registry) + def block(self, event, stream): + # no-op + pass + + +driver = Driver(NUM_DEVICES) diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_meta_parser.py b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_meta_parser.py new file mode 100644 index 0000000000000..0f54f2ec4df00 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_meta_parser.py @@ -0,0 +1,103 @@ +import pprint + +import torch +from torch.utils._pytree import tree_map, tree_map_only + + +class OpenRegTensorMeta: + def __init__(self, tensor, checked=True): + if checked and not tensor.device.type == "openreg": + raise RuntimeError( + "Creating OpenRegTensorMeta is only for Tensors on openreg device" + ) + self.data_ptr = tensor.untyped_storage().data_ptr() + self.size = tensor.size() + self.stride = tensor.stride() + self.storage_offset = tensor.storage_offset() + self.dtype = tensor.dtype + self.nelem_in_bytes = tensor.nelement() * tensor.element_size() + + def __repr__(self): + return ( + f"OpenRegTensorMeta({self.data_ptr=}, {self.size=}, {self.stride=}, " + f"{self.storage_offset=}, {self.dtype=}, {self.nelem_in_bytes=})" + ) + + +class OpenRegTensorData(torch.Tensor): + @staticmethod + def from_meta(allocator, tensor_meta): + return OpenRegTensorData(allocator.tensor_from_meta(tensor_meta)) + + +VALID_QUEUE_TYPES_IN = {torch.Tensor, int, float} + +VALID_QUEUE_TYPES_OUT = {OpenRegTensorMeta, int, float, str} + + +def safe_str(args): + def convert(obj): + if isinstance(obj, torch.Tensor): + return str(OpenRegTensorMeta(obj, checked=False)) + else: + return obj + + new_args = tree_map(convert, args) + return pprint.pformat(new_args) + + +def validate_send_queue_args(cmd, args): + def check(obj): + if type(obj) not in VALID_QUEUE_TYPES_OUT: + if ( + cmd == "recv_data" + and type(obj) in [torch.Tensor, OpenRegTensorData] + and obj.device.type == "cpu" + ): + # Only HtoD copy command can send cpu Tensors over + return + raise RuntimeError( + f"Trying to send invalid object through queue: {type(obj)}" + ) + + tree_map(check, args) + + +def prepare_for_sending(args, kwargs): + def convert(obj): + if type(obj) not in VALID_QUEUE_TYPES_IN: + raise RuntimeError( + f"Cannot send object of type {type(obj)} over openreg device pipe." + ) + + if isinstance(obj, torch.Tensor): + return OpenRegTensorMeta(obj) + else: + return obj + + return tree_map(convert, (args, kwargs)) + + +def receive_after_sending(allocator, args, kwargs): + def convert(obj): + if type(obj) not in VALID_QUEUE_TYPES_OUT: + raise RuntimeError( + f"Received invalid object of type {type(obj)} over openreg device pipe." + ) + + if isinstance(obj, OpenRegTensorMeta): + return allocator.tensor_from_meta(obj) + else: + return obj + + return tree_map(convert, (args, kwargs)) + + +def to_device_no_copy(device, args, kwargs): + def safe_to(t): + if device == "meta": + return t.to(device=device) + else: + return torch.empty_like(t, device=device) + + return tree_map_only(torch.Tensor, safe_to, (args, kwargs)) diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/Module.cpp b/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/Module.cpp new file mode 100644 index 0000000000000..4580629454b76 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/Module.cpp @@ -0,0 +1,51 @@ +#include "OpenReg.h" + +#include + +#include +#include +#include +#include + +static PyObject* _initExtension(PyObject* self, PyObject* noargs) { + HANDLE_TH_ERRORS + + at::globalContext().lazyInitDevice(c10::DeviceType::PrivateUse1); + + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +static PyObject* _getDefaultGenerator(PyObject* self, PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK( + THPUtils_checkLong(arg), + "_get_default_generator expects an int, but got ", + THPUtils_typename(arg)); + auto idx = static_cast(THPUtils_unpackLong(arg)); + + return THPGenerator_initDefaultGenerator( + at::globalContext().defaultGenerator( + c10::Device(c10::DeviceType::PrivateUse1, idx))); + + END_HANDLE_TH_ERRORS +} + +static PyMethodDef methods[] = { + {"_init", _initExtension, METH_NOARGS, nullptr}, + {"_get_default_generator", _getDefaultGenerator, METH_O, nullptr}, + {nullptr, nullptr, 0, nullptr} +}; + +static struct PyModuleDef openreg_C_module = + {PyModuleDef_HEAD_INIT, "pytorch_openreg._C", nullptr, -1, methods}; + +PyMODINIT_FUNC PyInit__C(void) { + PyObject* mod = PyModule_Create(&openreg_C_module); + + py::object openreg_mod = py::module_::import("pytorch_openreg"); + // Only borrowed from the python side! + openreg::set_impl_factory(openreg_mod.attr("impl_factory").ptr()); + + return mod; +} diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenReg.h b/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenReg.h new file mode 100644 index 0000000000000..a04248f2e5029 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenReg.h @@ -0,0 +1,50 @@ +#pragma once + +#include + +namespace openreg { + +using openreg_ptr_t = uint64_t; + +void set_impl_factory(PyObject* factory); +py::function get_method(const char* name); + +static constexpr char kFreeMethod[] = "free"; +static constexpr char kHostFreeMethod[] = "hostFree"; + +template +static void ReportAndDelete(void* ptr) { + if (!ptr || !Py_IsInitialized()) { + return; + } + + py::gil_scoped_acquire acquire; + + PyObject *type = nullptr, *value = nullptr, *traceback = nullptr; + // Always stash, this will be a no-op if there is no error + PyErr_Fetch(&type, &value, &traceback); + + TORCH_CHECK( + get_method(name)(reinterpret_cast(ptr)).cast(), + "Failed to free memory pointer at ", + ptr); + + // If that user code raised an error, just print it without raising it + if (PyErr_Occurred()) { + PyErr_Print(); + } + + // Restore the original error + PyErr_Restore(type, value, traceback); +} + +#define REGISTER_PRIVATEUSE1_SERIALIZATION( \ + FOR_SERIALIZATION, FOR_DESERIALIZATION) \ + static int register_serialization() { \ + torch::jit::TensorBackendMetaRegistry( \ + c10::DeviceType::PrivateUse1, FOR_SERIALIZATION, FOR_DESERIALIZATION); \ + return 0; \ + } \ + static const int _temp = register_serialization(); + +} // namespace openreg diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegHooks.cpp b/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegHooks.cpp new file mode 100644 index 0000000000000..a87b378fb95c8 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegHooks.cpp @@ -0,0 +1,350 @@ +#include "OpenReg.h" + +#include +#include +#include + +#include +#include +#include + +namespace openreg { +namespace { + +// Python factory function where real implementations can be found +PyObject* py_factory; + +struct HostAllocator final : at::Allocator { + HostAllocator() = default; + + at::DataPtr allocate(size_t nbytes) override { + py::gil_scoped_acquire acquire; + void* data = nullptr; + if (nbytes > 0) { + data = reinterpret_cast( + get_method("hostMalloc")(nbytes).cast()); + TORCH_CHECK(data, "Failed to allocator ", nbytes, " bytes on host."); + } + return {data, data, &ReportAndDelete, at::Device(at::kCPU)}; + } + + at::DeleterFnPtr raw_deleter() const override { + return &ReportAndDelete; + } + + void copy_data(void* dest, const void* src, std::size_t count) const final { + py::gil_scoped_acquire acquire; + get_method("hostCopyData")( + reinterpret_cast(dest), + reinterpret_cast(src), + count); + } +}; + +static HostAllocator global_host_alloc; + +static c10::DeviceIndex device_count() { + py::gil_scoped_acquire acquire; + return get_method("deviceCount")().cast(); +} + +static c10::DeviceIndex current_device_idx() { + py::gil_scoped_acquire acquire; + return get_method("getDevice")().cast(); +} + +class OpenRegGeneratorImpl : public at::CPUGeneratorImpl { + public: + OpenRegGeneratorImpl(c10::DeviceIndex device_index) { + device_ = c10::Device(c10::DeviceType::PrivateUse1, device_index); + key_set_ = c10::DispatchKeySet(c10::DispatchKey::PrivateUse1); + } + ~OpenRegGeneratorImpl() override = default; +}; + +static at::Generator make_openreg_generator(c10::DeviceIndex device_index) { + return at::make_generator(device_index); +} + +// Default, global generators, one per device. +static std::vector default_generators; + +struct OpenRegHooksInterface : public at::PrivateUse1HooksInterface { + OpenRegHooksInterface() {}; + ~OpenRegHooksInterface() override = default; + + bool hasPrimaryContext(c10::DeviceIndex device_index) const override { + py::gil_scoped_acquire acquire; + return get_method("hasPrimaryContext")(device_index).cast(); + } + + at::Allocator* getPinnedMemoryAllocator() const override { + return &global_host_alloc; + } + + bool isPinnedPtr(const void* data) const override { + py::gil_scoped_acquire acquire; + return get_method("isPinnedPtr")(reinterpret_cast(data)) + .cast(); + } + + const at::Generator& getDefaultGenerator( + c10::DeviceIndex device_index) const override { + static bool flag [[maybe_unused]] = []() { + auto deivce_nums = device_count(); + default_generators.resize(deivce_nums); + for (auto i = 0; i < deivce_nums; i++) { + default_generators[i] = make_openreg_generator(i); + default_generators[i].seed(); + } + return true; + }(); + + c10::DeviceIndex idx = device_index; + if (idx == -1) { + idx = current_device_idx(); + } else { + TORCH_CHECK(idx >= 0 && idx < device_count()); + } + return default_generators[idx]; + } + + at::Generator getNewGenerator(c10::DeviceIndex device_index) const override { + return make_openreg_generator(device_index); + } +}; + +static bool register_hook_flag [[maybe_unused]] = []() { + at::RegisterPrivateUse1HooksInterface(new OpenRegHooksInterface()); + + return true; +}(); + +// Device guard registration +struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface { + static constexpr c10::DeviceType static_type = c10::DeviceType::PrivateUse1; + + OpenRegGuardImpl() = default; + explicit OpenRegGuardImpl(c10::DeviceType t) { + TORCH_INTERNAL_ASSERT(t == static_type); + } + + /** + * Return the type of device managed by this guard implementation. + */ + c10::DeviceType type() const override { + return static_type; + } + + /** + * Set the current device to Device, and return the previous c10::Device. + */ + c10::Device exchangeDevice(c10::Device d) const override { + TORCH_INTERNAL_ASSERT(d.is_privateuseone()); + py::gil_scoped_acquire acquire; + auto old_device_index = + get_method("exchangeDevice")(d.index()).cast(); + return c10::Device(static_type, old_device_index); + } + + /** + * Get the current device. + */ + c10::Device getDevice() const override { + return c10::Device(static_type, current_device_idx()); + } + + /** + * Set the current device to c10::Device. + */ + void setDevice(c10::Device d) const override { + TORCH_INTERNAL_ASSERT(d.is_privateuseone()); + py::gil_scoped_acquire acquire; + auto device = get_method("setDevice")(d.index()); + } + + /** + * Set the current device to c10::Device, without checking for errors + * (so, e.g., this can be called from a destructor). + */ + void uncheckedSetDevice(c10::Device d) const noexcept override { + py::gil_scoped_acquire acquire; + auto device = get_method("uncheckedSetDevice")(d.index()); + } + + /** + * Get the current stream for a given device. + */ + c10::Stream getStream(c10::Device d) const noexcept override { + py::gil_scoped_acquire acquire; + auto stream_id = get_method("getStream")(d.index()).cast(); + return c10::Stream(c10::Stream::UNSAFE, d, stream_id); + } + + /** + * Get the default stream for a given device. + */ + c10::Stream getDefaultStream(c10::Device d) const override { + py::gil_scoped_acquire acquire; + return get_method("getDefaultStream")(d.index()).cast(); + } + + /** + * Get a stream from the global pool for a given device. + */ + c10::Stream getStreamFromGlobalPool( + c10::Device d, + bool isHighPriority = false) const override { + py::gil_scoped_acquire acquire; + return get_method("getStreamFromGlobalPool")(d.index(), isHighPriority) + .cast(); + } + + /** + * Return a new stream for a given device and priority. The stream will be + * copied and shared around, device backend should be able to correctly handle + * the lifetime of the stream. + */ + c10::Stream getNewStream(c10::Device d, int priority = 0) const override { + py::gil_scoped_acquire acquire; + auto stream_id = + get_method("getNewStream")(d.index(), priority).cast(); + return c10::Stream(c10::Stream::UNSAFE, d, stream_id); + } + + /** + * Set a stream to be the thread local current stream for its device. + * Return the previous stream for that device. You are NOT required + * to set the current device to match the device of this stream. + */ + c10::Stream exchangeStream(c10::Stream s) const noexcept override { + py::gil_scoped_acquire acquire; + auto stream_id = get_method("exchangeStream")(s).cast(); + return c10::Stream(c10::Stream::UNSAFE, s.device(), stream_id); + } + + /** + * Destroys the given event. + */ + void destroyEvent(void* event, const c10::DeviceIndex device_index) + const noexcept override { + py::gil_scoped_acquire acquire; + get_method("destroyEvent")((int64_t)event, device_index); + } + + /** + * Increments the event's version and enqueues a job with this version + * in the stream's work queue. When the stream process that job + * it notifies all streams waiting on / blocked by that version of the + * event to continue and marks that version as recorded. + * */ + void record( + void** event, + const c10::Stream& stream, + const c10::DeviceIndex device_index, + const c10::EventFlag flag) const override { + py::gil_scoped_acquire acquire; + get_method("record")((int64_t)event, stream, device_index, (int64_t)flag); + } + + /** + * Does nothing if the event has not been scheduled to be recorded. + * If the event was previously enqueued to be recorded, a command + * to wait for the version of the event that exists at the time of this call + * is inserted in the stream's work queue. + * When the stream reaches this command it will stop processing + * additional commands until that version of the event is marked as recorded. + */ + void block(void* event, const c10::Stream& stream) const override { + py::gil_scoped_acquire acquire; + get_method("block")((int64_t)event, stream); + } + + /** + * Returns true if (and only if) + * (1) the event has never been scheduled to be recorded + * (2) the current version is marked as recorded. + * Returns false otherwise. + */ + bool queryEvent(void* event) const override { + py::gil_scoped_acquire acquire; + return get_method("queryEvent")((int64_t)event).cast(); + } + + /** + * Get the number of devices. WARNING: This is REQUIRED to not raise + * an exception. If there is some sort of problem, e.g., driver error, + * you should report that there are zero available devices. + */ + c10::DeviceIndex deviceCount() const noexcept override { + return device_count(); + } + /** + * Return true if all the work previously enqueued on the stream for + * asynchronous execution has completed running on the device. + */ + bool queryStream(const c10::Stream& stream) const override { + py::gil_scoped_acquire acquire; + return get_method("queryStream")(stream).cast(); + } + + /** + * Wait (by blocking the calling thread) until all the work previously + * enqueued on the stream has completed running on the device. + */ + virtual void synchronizeStream(const c10::Stream& stream) const override { + py::gil_scoped_acquire acquire; + get_method("synchronizeStream")(stream); + } + + /** + * Wait (by blocking the calling thread) until all the work previously + * recorded on the event has completed running on the device. + */ + void synchronizeEvent(void* event) const override { + py::gil_scoped_acquire acquire; + get_method("synchronizeEvent")((int64_t)event); + } + + /** + * Ensure the caching allocator (if any) is aware that the given DataPtr is + * being used on the given stream, and that it should thus avoid recycling the + * DataPtr until all work on that stream is done. + */ + void recordDataPtrOnStream( + const c10::DataPtr& data_ptr, + const c10::Stream& stream) const override { + py::gil_scoped_acquire acquire; + get_method("recordDataPtrOnStream")(data_ptr, stream); + } + + /** + * Fetch the elapsed time between two recorded events. + */ + double elapsedTime( + void* event1, + void* event2, + const c10::DeviceIndex device_index) const override { + py::gil_scoped_acquire acquire; + return get_method("elapsedTime")( + (int64_t)event1, (int64_t)event2, device_index) + .cast(); + } +}; + +// Register our device guard +C10_REGISTER_GUARD_IMPL(PrivateUse1, OpenRegGuardImpl); + +} // namespace + +// Setter for the python dictionary with implementations +void set_impl_factory(PyObject* factory) { + py_factory = factory; +} + +py::function get_method(const char* name) { + auto factory = py::cast(py_factory); + return factory(name); +} + +} // namespace openreg diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegMem.cpp b/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegMem.cpp new file mode 100644 index 0000000000000..9289ec7b62db2 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegMem.cpp @@ -0,0 +1,364 @@ +#include "OpenReg.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +namespace openreg { +namespace { + +struct OpenRegAllocator final : at::Allocator { + OpenRegAllocator() = default; + + at::DataPtr allocate(size_t nbytes) override { + py::gil_scoped_acquire acquire; + auto curr_device_idx = get_method("getDevice")().cast(); + auto curr_device = + c10::Device(c10::DeviceType::PrivateUse1, curr_device_idx); + void* data = nullptr; + if (nbytes > 0) { + data = reinterpret_cast( + get_method("malloc")(nbytes).cast()); + TORCH_CHECK( + data, "Failed to allocator ", nbytes, " bytes on openreg device."); + } + return {data, data, &ReportAndDelete, curr_device}; + } + + at::DeleterFnPtr raw_deleter() const override { + return &ReportAndDelete; + } + + void copy_data(void* dest, const void* src, std::size_t count) const final { + py::gil_scoped_acquire acquire; + get_method("copy_data")( + reinterpret_cast(dest), + reinterpret_cast(src), + count); + } +}; + +static OpenRegAllocator global_openreg_alloc; +REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &global_openreg_alloc); + +// Empty op needs C++ code and cannot be handled by python side fallback +at::Tensor empty_openreg( + c10::IntArrayRef size, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt, + std::optional memory_format_opt) { + const auto device = c10::device_or_default(device_opt); + const auto dtype = c10::dtype_or_default(dtype_opt); + TORCH_CHECK(device.is_privateuseone()); + TORCH_CHECK( + c10::layout_or_default(layout_opt) == c10::Layout::Strided, + "Non strided layout not supported"); + TORCH_CHECK( + !c10::pinned_memory_or_default(pin_memory_opt), + "Pin memory can only be on CPU"); + const c10::DeviceGuard device_guard(device); + constexpr c10::DispatchKeySet pu1_dks(c10::DispatchKey::PrivateUse1); + return at::detail::empty_generic( + size, &global_openreg_alloc, pu1_dks, dtype, memory_format_opt); +} + +at::Tensor empty_strided_openreg( + c10::IntArrayRef size, + c10::IntArrayRef stride, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt) { + const auto device = c10::device_or_default(device_opt); + const auto dtype = c10::dtype_or_default(dtype_opt); + TORCH_CHECK(device.is_privateuseone()); + TORCH_CHECK( + c10::layout_or_default(layout_opt) == c10::Layout::Strided, + "Non strided layout not supported"); + TORCH_CHECK( + !c10::pinned_memory_or_default(pin_memory_opt), + "Pin memory can only be on CPU"); + const c10::DeviceGuard device_guard(device); + constexpr c10::DispatchKeySet pu1_dks(c10::DispatchKey::PrivateUse1); + return at::detail::empty_strided_generic( + size, stride, &global_openreg_alloc, pu1_dks, dtype); +} + +at::Tensor as_strided_openreg( + const at::Tensor& self, + c10::IntArrayRef size, + c10::IntArrayRef stride, + std::optional storage_offset_) { + // Metadata-only change so we re-use the cpu impl + return at::cpu::as_strided(self, size, stride, storage_offset_); +} + +const at::Tensor& resize__openreg( + const at::Tensor& self, + c10::SymIntArrayRef size, + ::std::optional memory_format) { + return at::native::resize_( + self, C10_AS_INTARRAYREF_SLOW(size), memory_format); +} + +at::Tensor& set_source_Storage_storage_offsetset_openreg( + at::Tensor& result, + at::Storage storage, + int64_t storage_offset, + c10::IntArrayRef size, + c10::IntArrayRef stride) { + return at::cpu::set_(result, storage, storage_offset, size, stride); +} + +std::tuple +custom_scaled_dot_product_fused_attention_overrideable( + const at::Tensor & query, + const at::Tensor & key, + const at::Tensor & value, + const std::optional & attn_bias, + double dropout_p, + bool is_causal, + bool return_debug_mask, + std::optional scale) { + const int64_t batch_size = query.size(0); + const int64_t num_heads = query.size(1); + const int64_t head_dim_v = value.size(3); + const int64_t max_seqlen_q = query.size(2); + const int64_t max_seqlen_kv = key.size(2); + + auto opts = query.options(); + auto output = at::empty({batch_size, num_heads, max_seqlen_q, head_dim_v}, opts); + auto logsumexp = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); + auto debug_attn_mask = at::empty({batch_size, num_heads, max_seqlen_q, max_seqlen_kv}, + opts.dtype(at::kFloat)); + auto philox_seed = at::empty({}, at::dtype(at::kLong)); + auto philox_offset = at::empty({}, at::dtype(at::kLong)); + + return std::make_tuple(output, logsumexp, at::Tensor(), at::Tensor(), max_seqlen_q, max_seqlen_kv, philox_seed, philox_offset, debug_attn_mask); +} + +std::tuple +custom_scaled_dot_product_fused_attention_overrideable_backward( + const at::Tensor & grad_out, + const at::Tensor & query, + const at::Tensor & key, + const at::Tensor & value, + const at::Tensor & attn_bias, + std::array grad_input_mask, + const at::Tensor & out, + const at::Tensor & logsumexp, + const at::Tensor & cum_seq_q, + const at::Tensor & cum_seq_k, + int64_t max_q, + int64_t max_k, + double dropout_p, + bool is_causal, + const at::Tensor & philox_seed, + const at::Tensor & philox_offset, + std::optional scale) { + return std::tuple( + at::empty_like(query), + at::empty_like(key), + at::empty_like(value), + at::empty_like(attn_bias)); +} +} + +// Using the simplest way to obtain continuous Tensor data and process it. +// This is a demo for using operand API, and you can add more complex logic +// for input and output tensor based on your custom device kernel. +void abs_kernel(at::TensorIteratorBase& iter) { + // Abs only have a input tensor and a output tensor. + auto& output_operand = iter.operand(0); + auto& input_operand = iter.operand(1); + auto& output_tensor_base = output_operand.tensor_base(); + auto& input_tensor_base = input_operand.tensor_base(); + TORCH_CHECK(!input_operand.original_tensor_base().defined(), + "input original tensor is defined."); + TORCH_CHECK(!output_operand.original_tensor_base().defined(), + "output original tensor is defined."); + // For easy test, only accept contiguous input tensor for calculate. + auto memory_format = input_tensor_base.suggest_memory_format(); + TORCH_CHECK(input_tensor_base.is_contiguous(memory_format), + "Input tensor need be contiguous."); + // Add necessary restrictions to ensure the security of the demo. + TORCH_CHECK(input_tensor_base.sizes() == output_tensor_base.sizes(), + "Intput and output tensor size are not equal."); + // Common dtype is calculate in TensorIteratorBase. + TORCH_CHECK(iter.common_dtype() == at::ScalarType::Float, + "Only support float type.") + // Using for loop for abs calculate. + auto abs_function = [](float* output_ptr, const float* input_ptr, + const int64_t NUM) { + for (int64_t i = 0; i < NUM; ++i) { + *(output_ptr + i) = std::abs(*(input_ptr + i)); + } + }; + // To simplify the logic of the test demo code, + // we only use contiguous tensor to calculate on device side. + // And using input tensor memory format. + if (iter.is_contiguous()) { + // Add for will_resize flag check. You can convert to differernt + // tensor memory format when will_resize is True. + // If TensorIteratorConfig resize_outputs_ flag is true, and there are two + // situations: + // 1) Out tensor is undefined, and TensorIterator set will_resize to true; + // 2) Out tensor is defined and tensor size is not equal to input tensor size; + // TensorIterator set will_resize to true, and call set_output_raw_strided + // to resize output tensor. + // When output operand will_resize flag is ture, dummy + // device can convert tensor to dummy device preferred memory format. + // Here we don't convert tensor memory format, because it will become complex + // when dummy device want keep same memory format for training network. + TORCH_CHECK(output_operand.will_resize, + "output operand will_resize flag need be True."); + abs_function((float*)iter.data_ptr(0), (float*)iter.data_ptr(1), iter.numel()); + } else { + // Stride copy is not support for foo device, using cpu device instead. + // For abs op, the last situation is: output tensor is not contiguous with + // operand will_resize is False. + TORCH_CHECK(!output_operand.will_resize, "output operand will_resize is True."); + // Get a contiguous tensor with input memory format. + at::Tensor output = at::empty(output_tensor_base.sizes(), + input_tensor_base.options() + .memory_format(memory_format)); + // For structured op which inheried from TensorIteratorBase, maybe you need to + // call set_output_raw_strided function to update output stored in op sturctured. + // abs op is no need to do this. + output_operand.exchange_tensor(c10::MaybeOwned::owned(std::in_place, output)); + abs_function((float*)output_operand.tensor_base().mutable_data_ptr(), + (float*)iter.data_ptr(1), iter.numel()); + // Copy tensor base to original tensor base, and keep same scalar type and + // stride with cpu and gpu. + if (output_operand.original_tensor_base().defined() && + !output_operand.original_tensor_base().is_same(output_operand.tensor_base())) { + output_operand.original_tensor().copy_(output_operand.tensor()); + output_operand.restore_original_tensor(); + } + } +} + +int64_t _fused_sdp_choice_privateuse1( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const std::optional& attn_mask, + double dropout_p, + bool is_causal, + std::optional scale, + bool enable_gqa) { + auto backend = sdp::SDPBackend::overrideable; + return static_cast(backend); +} + +void quantize_tensor_per_tensor_affine_privateuse1( + const at::Tensor& rtensor, + at::Tensor& qtensor, + double scale, + int64_t zero_point) { + // Just test the process, so do nothing +} + +/* Notes: + * + * OpenReg is currently designed to simulate device memory through multiple + * subprocesses on purpose to ensure we don't mistakenly poke at the "device's + * memory" from the main process. And be able to simulate the same thing that + * happens with other accelerators: any metadata-only change is cpu-only + * (main process), any data change must go through to the device (other process) + * and any data transfer between the two is expensive (serializing the whole + * Tensor). + * + * Currently, for the efficiency of IPC, most operations are to pass the Tensor + * metadata, and only a small number of operations involving copy will serialize + * and pass the Tensor body by custom pickler provided by torch.multiprocess. + * + * Therefore, in principle, only operations related to Metadata modification can + * be directly implemented at the C++ level and registered in PrivateUse1; but + * if memory access is involved, the relevant operations must be implemented at + * the Python level, otherwise invalid memory access will result. + */ + +TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { + m.impl("empty.memory_format", empty_openreg); + m.impl("empty_strided", empty_strided_openreg); + m.impl("as_strided", as_strided_openreg); + m.impl("resize_", resize__openreg); + m.impl("set_.source_Storage", at::native::set_); + m.impl("set_.source_Storage_storage_offset", set_source_Storage_storage_offsetset_openreg); + m.impl("quantize_per_tensor", at::native::quantize_per_tensor); + m.impl("_fused_sdp_choice", &_fused_sdp_choice_privateuse1); + m.impl("_scaled_dot_product_fused_attention_overrideable", &custom_scaled_dot_product_fused_attention_overrideable); + m.impl("_scaled_dot_product_fused_attention_overrideable_backward", &custom_scaled_dot_product_fused_attention_overrideable_backward); +} + +struct OpenRegBackendMeta : public c10::BackendMeta { + OpenRegBackendMeta(int version_number, int format_number) + : version_number_(version_number), format_number_(format_number) {} + + int version_number_{-1}; + int format_number_{-1}; +}; + +void for_serialization( + const at::Tensor& t, + std::unordered_map& m) { + auto meta_ptr = t.unsafeGetTensorImpl()->get_backend_meta(); + + if (meta_ptr != nullptr) { + auto o_meta_ptr = dynamic_cast(meta_ptr); + if (o_meta_ptr->version_number_ == 1) { + m["version_number"] = true; + } + if (o_meta_ptr->format_number_ == 29) { + m["format_number"] = true; + } + } +} + +void for_deserialization( + const at::Tensor& t, + std::unordered_map& m) { + int version_number{-1}; + int format_number{-1}; + + if (m.find("version_number") != m.end()) { + version_number = 1; + } + if (m.find("format_number") != m.end()) { + format_number = 29; + } + + c10::intrusive_ptr meta{std::unique_ptr( + new OpenRegBackendMeta(version_number, format_number))}; + t.unsafeGetTensorImpl()->set_backend_meta(meta); +} + +REGISTER_PRIVATEUSE1_SERIALIZATION(&for_serialization, &for_deserialization) +} // namespace openreg + +namespace at::native { +REGISTER_PRIVATEUSE1_DISPATCH(abs_stub, &openreg::abs_kernel); +REGISTER_PRIVATEUSE1_DISPATCH( + quantize_tensor_per_tensor_affine_stub, + &openreg::quantize_tensor_per_tensor_affine_privateuse1); +REGISTER_PRIVATEUSE1_DISPATCH( + _fused_sdp_choice_stub, + &openreg::_fused_sdp_choice_privateuse1); +} // namespace at::native diff --git a/test/cpp_extensions/open_registration_extension/setup.py b/test/cpp_extensions/open_registration_extension/setup.py new file mode 100644 index 0000000000000..fa8c1308c6c52 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/setup.py @@ -0,0 +1,78 @@ +import distutils.command.clean +import os +import platform +import shutil +import sys +from pathlib import Path + +from setuptools import find_packages, setup + +from torch.utils.cpp_extension import BuildExtension, CppExtension + + +PACKAGE_NAME = "pytorch_openreg" +version = 1.0 + +ROOT_DIR = Path(__file__).absolute().parent +CSRS_DIR = ROOT_DIR / "pytorch_openreg/csrc" + + +class clean(distutils.command.clean.clean): + def run(self): + # Run default behavior first + distutils.command.clean.clean.run(self) + + # Remove pytorch_openreg extension + for path in (ROOT_DIR / "pytorch_openreg").glob("**/*.so"): + path.unlink() + # Remove build directory + build_dirs = [ + ROOT_DIR / "build", + ] + for path in build_dirs: + if path.exists(): + shutil.rmtree(str(path), ignore_errors=True) + + +if __name__ == "__main__": + if sys.platform == "win32": + vc_version = os.getenv("VCToolsVersion", "") + if vc_version.startswith("14.16."): + CXX_FLAGS = ["/sdl"] + else: + CXX_FLAGS = ["/sdl", "/permissive-"] + elif platform.machine() == "s390x": + # no -Werror on s390x due to newer compiler + CXX_FLAGS = {"cxx": ["-g", "-Wall"]} + else: + CXX_FLAGS = {"cxx": ["-g", "-Wall", "-Werror"]} + + sources = list(CSRS_DIR.glob("*.cpp")) + + # Note that we always compile with debug info + ext_modules = [ + CppExtension( + name="pytorch_openreg._C", + sources=sorted(str(s) for s in sources), + include_dirs=[CSRS_DIR], + extra_compile_args=CXX_FLAGS, + ) + ] + + setup( + name=PACKAGE_NAME, + version=version, + author="PyTorch Core Team", + description="Example for PyTorch out of tree registration", + packages=find_packages(exclude=("test",)), + package_data={PACKAGE_NAME: ["*.dll", "*.dylib", "*.so"]}, + install_requires=[ + "torch", + ], + ext_modules=ext_modules, + python_requires=">=3.8", + cmdclass={ + "build_ext": BuildExtension.with_options(no_python_abi_suffix=True), + "clean": clean, + }, + ) diff --git a/test/custom_operator/my_custom_ops.py b/test/custom_operator/my_custom_ops.py index 0eedcb49c2c5b..4159a1d2e64e7 100644 --- a/test/custom_operator/my_custom_ops.py +++ b/test/custom_operator/my_custom_ops.py @@ -6,7 +6,11 @@ torch.ops.load_library(get_custom_op_library_path()) +<<<<<<< HEAD @torch.library.register_fake("custom::nonzero") +======= +@torch.library.impl_abstract("custom::nonzero") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def nonzero_abstract(x): n = x.dim() ctx = torch.library.get_ctx() diff --git a/test/custom_operator/my_custom_ops2.py b/test/custom_operator/my_custom_ops2.py index 2a7f4b825f478..54be4a9404ad2 100644 --- a/test/custom_operator/my_custom_ops2.py +++ b/test/custom_operator/my_custom_ops2.py @@ -6,6 +6,10 @@ torch.ops.load_library(get_custom_op_library_path()) +<<<<<<< HEAD @torch.library.register_fake("custom::sin") +======= +@torch.library.impl_abstract("custom::sin") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def sin_abstract(x): return torch.empty_like(x) diff --git a/test/custom_operator/pointwise.py b/test/custom_operator/pointwise.py index 53335fdb02677..df8fd277625a0 100644 --- a/test/custom_operator/pointwise.py +++ b/test/custom_operator/pointwise.py @@ -8,12 +8,20 @@ # NB: The impl_abstract_pystub for cos actually # specifies it should live in the my_custom_ops2 module. +<<<<<<< HEAD @torch.library.register_fake("custom::cos") +======= +@torch.library.impl_abstract("custom::cos") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def cos_abstract(x): return torch.empty_like(x) # NB: There is no impl_abstract_pystub for tan +<<<<<<< HEAD @torch.library.register_fake("custom::tan") +======= +@torch.library.impl_abstract("custom::tan") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def tan_abstract(x): return torch.empty_like(x) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_comm.py b/test/distributed/_composable/fsdp/test_fully_shard_comm.py index c52c1e539ff6d..115f09aa19949 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_comm.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_comm.py @@ -6,7 +6,10 @@ import os import tempfile from typing import Callable, Optional, Union +<<<<<<< HEAD from unittest.mock import MagicMock +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch.distributed as dist @@ -20,12 +23,18 @@ MixedPrecisionPolicy, OffloadPolicy, ) +<<<<<<< HEAD from torch.distributed.fsdp._fully_shard._fsdp_api import AllGather from torch.distributed.fsdp._fully_shard._fsdp_collectives import ( _div_if_needed, _get_gradient_divide_factors, DefaultAllGather, DefaultReduceScatter, +======= +from torch.distributed.fsdp._fully_shard._fsdp_collectives import ( + _div_if_needed, + _get_gradient_divide_factors, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) foreach_all_gather, foreach_all_gather_copy_out, foreach_reduce, @@ -166,7 +175,10 @@ def _test_all_gather( all_gather_stream, ): def all_gather(fsdp_param_group: FSDPParamGroup, group: dist.ProcessGroup): +<<<<<<< HEAD all_gather_comm = DefaultAllGather() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) all_gather_result = foreach_all_gather( fsdp_param_group.fsdp_params, group, @@ -174,7 +186,10 @@ def all_gather(fsdp_param_group: FSDPParamGroup, group: dist.ProcessGroup): all_gather_copy_in_stream=all_gather_copy_in_stream, all_gather_stream=all_gather_stream, device=self.device, +<<<<<<< HEAD all_gather_comm=all_gather_comm, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) foreach_all_gather_copy_out(all_gather_result, fsdp_params, group) # Transition to unsharded state to register unsharded parameters @@ -267,7 +282,10 @@ def _test_reduce_scatter( group = fsdp_param_group.mesh_info.shard_process_group self.assertEqual(group.size(), self.world_size) all_reduce_stream = device_module.Stream() +<<<<<<< HEAD comm = DefaultReduceScatter() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ( _, _, @@ -280,7 +298,10 @@ def _test_reduce_scatter( unsharded_grads, group, reduce_scatter_stream, +<<<<<<< HEAD comm, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) orig_dtype=orig_params[0].dtype, reduce_dtype=reduce_scatter_dtype, device=self.device, @@ -419,10 +440,13 @@ def test_set_reduce_scatter_divide_factor(self): {"divide_factor": [self.world_size * 2, self.world_size]}, self._test_set_reduce_scatter_divide_factor, ) +<<<<<<< HEAD self.run_subtests( {"divide_factor": [self.world_size]}, self._test_set_reduce_scatter_divide_factor_mixed_prevision, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _test_set_reduce_scatter_divide_factor(self, divide_factor: float): torch.manual_seed(42) @@ -455,6 +479,7 @@ def _test_set_reduce_scatter_divide_factor(self, divide_factor: float): self.assertEqual(ref_loss, loss) check_sharded_parity(self, ref_model, model) +<<<<<<< HEAD def _test_set_reduce_scatter_divide_factor_mixed_prevision( self, divide_factor: float ): @@ -504,6 +529,8 @@ def _test_set_reduce_scatter_divide_factor_mixed_prevision( self.assertEqual(ref_loss, loss) check_sharded_parity(self, ref_model, model) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skip_if_lt_x_gpu(2) def test_set_reshard_after_forward(self): """ @@ -1415,6 +1442,7 @@ def test_fully_shard_alloc_from_pg(self): with open(self.nccl_log_dir.name + "/nccl_log") as f: self.assertRegex(f.read(), self.MEMORY_REGISTER_RE) +<<<<<<< HEAD @skip_if_lt_x_gpu(2) def test_exception_when_used_together_with_comm_hooks(self): model = nn.Linear(16, 16) @@ -1431,6 +1459,8 @@ def test_exception_when_used_together_with_comm_hooks(self): with self.assertRaises(AssertionError): model.set_allocate_memory_from_process_group_for_comm(True) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestFullyShardForceSumReduction(FSDPTest): # The messages might change when we move to a different NCCL version. @@ -1570,6 +1600,7 @@ def test_fully_shard_force_sum_both_reductions(self): self.assertRegex(logs, all_reduce_sum_re) +<<<<<<< HEAD class TestFullyShardReduceOpWorldSize1(FSDPTest): @property def world_size(self) -> int: @@ -1619,5 +1650,7 @@ def test_size1_reduceop(self): self.assertEqual(all_reduce_op, ReduceOp.SUM) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/distributed/_composable/fsdp/test_fully_shard_compile.py b/test/distributed/_composable/fsdp/test_fully_shard_compile.py index b64d4107ee0ca..933f9c0715bb4 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_compile.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_compile.py @@ -299,6 +299,7 @@ def _check_count(copy_count, resize_count): def _reinplace_all_gather_with_optional_checks(self, fwd_fullgraph): def _run_with_checks(graph, orig_fn): +<<<<<<< HEAD if self.world_size > 1: self.assertGreater( _count_op_in_graph( @@ -313,6 +314,14 @@ def _run_with_checks(graph, orig_fn): ), 0, ) +======= + self.assertGreater( + _count_op_in_graph( + graph, torch.ops._c10d_functional.all_gather_into_tensor.default + ), + 0, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) orig_fn(graph) @@ -323,6 +332,7 @@ def _run_with_checks(graph, orig_fn): 0, ) +<<<<<<< HEAD if self.world_size > 1: self.assertGreater( _count_op_in_graph( @@ -339,6 +349,14 @@ def _run_with_checks(graph, orig_fn): ), 0, ) +======= + self.assertGreater( + _count_op_in_graph( + graph, torch.ops._c10d_functional.all_gather_into_tensor_out.default + ), + 0, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if fwd_fullgraph: return mock.patch.object( @@ -566,8 +584,12 @@ def test_compiled(): Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround. Developer debug context: call_method TensorVariable() backward () {} +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0123.html""", # noqa: B950 +======= +""", # noqa: B950 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) else: self.assertGreater(len(counters["graph_break"]), 1) @@ -1110,7 +1132,10 @@ def _test_transformer_backend_inductor_fullgraph_True(self): pass file_check.run(bwd_code) +<<<<<<< HEAD @unittest.skip('"Traceable FSDP2" is not being maintained anymore.') +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skipIfRocm @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") # TODO: native_dropout causes CUDA IMA error, need to figure out why @@ -1118,7 +1143,10 @@ def _test_transformer_backend_inductor_fullgraph_True(self): def test_transformer_backend_inductor_fullgraph_True(self): self._test_transformer_backend_inductor_fullgraph_True() +<<<<<<< HEAD @unittest.skip('"Traceable FSDP2" is not being maintained anymore.') +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skipIfRocm @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") # TODO: native_dropout causes CUDA IMA error, need to figure out why diff --git a/test/distributed/_composable/fsdp/test_fully_shard_grad_scaler.py b/test/distributed/_composable/fsdp/test_fully_shard_grad_scaler.py index 0ce32057ffbe0..dbc105b60689f 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_grad_scaler.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_grad_scaler.py @@ -73,7 +73,11 @@ def _test_gradient_scaler(self, has_inf: bool, test_2d: bool): opt.param_groups[0]["params"][0].grad._local_tensor[0, 0].fill_( float("inf") ) +<<<<<<< HEAD initial_grad = opt.param_groups[0]["params"][0].grad.to_local().clone() +======= + inital_grad = opt.param_groups[0]["params"][0].grad.to_local().clone() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) scaler.unscale_(opt) for found_inf in scaler._per_optimizer_states[id(opt)][ @@ -85,7 +89,11 @@ def _test_gradient_scaler(self, has_inf: bool, test_2d: bool): OptState.UNSCALED.value, ) unscaled_grad = opt.param_groups[0]["params"][0].grad.to_local().clone() +<<<<<<< HEAD self.assertEqual(unscaled_grad, initial_grad * inv_scale) +======= + self.assertEqual(unscaled_grad, inital_grad * inv_scale) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) initial_scale = scaler.get_scale() initial_state = copy.copy(opt.state) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_logging.py b/test/distributed/_composable/fsdp/test_fully_shard_logging.py index c9450a2b8f475..f5520562e08d9 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_logging.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_logging.py @@ -6,9 +6,17 @@ import torch.distributed as dist from torch._dynamo.test_case import run_tests from torch.testing._internal.common_distributed import skip_if_lt_x_gpu +<<<<<<< HEAD from torch.testing._internal.logging_utils import LoggingTestCase +======= +from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.logging_utils import LoggingTestCase + + +requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) requires_distributed = functools.partial( unittest.skipIf, not dist.is_available(), "requires distributed" ) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_training.py b/test/distributed/_composable/fsdp/test_fully_shard_training.py index 3991fda639108..ef6837cd318d2 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_training.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_training.py @@ -335,7 +335,11 @@ def test_train_parity_multi_group(self): self.run_subtests( { "reshard_after_forward": [True, False, 2], +<<<<<<< HEAD "test_device_type": [device_type.type], +======= + "device_type": [device_type.type], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "offload_policy": [OffloadPolicy()], "delay_after_forward": [False, True], "delay_before_all_gather": [False, True], @@ -360,7 +364,11 @@ def test_train_parity_multi_group_cpu_offload_eager(self): CPUOffloadPolicy(pin_memory=True), CPUOffloadPolicy(pin_memory=False), ], +<<<<<<< HEAD "test_device_type": [device_type.type], +======= + "device_type": [device_type.type], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "delay_after_forward": [False, True], "delay_before_all_gather": [False, True], "delay_before_reduce_scatter": [False, True], @@ -381,7 +389,11 @@ def test_train_parity_multi_group_unshard_async_op(self): self.run_subtests( { "reshard_after_forward": [True], +<<<<<<< HEAD "test_device_type": [device_type.type], +======= + "device_type": [device_type.type], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "offload_policy": [OffloadPolicy()], "delay_after_forward": [False, True], "delay_before_all_gather": [False, True], @@ -396,7 +408,11 @@ def _test_train_parity_multi_group( self, reshard_after_forward: Union[bool, int], offload_policy: OffloadPolicy, +<<<<<<< HEAD test_device_type: str, +======= + device_type: str, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) delay_after_forward: bool, delay_before_all_gather: bool, delay_before_reduce_scatter: bool, @@ -412,7 +428,11 @@ def _test_train_parity_multi_group( in (2, 3) ): return +<<<<<<< HEAD assert test_device_type in ("cuda", "hpu", "xpu", "cpu"), f"{test_device_type}" +======= + assert device_type in ("cuda", "hpu", "xpu", "cpu"), f"{device_type}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.manual_seed(42) vocab_size = 1024 model_args = ModelArgs( @@ -424,7 +444,11 @@ def _test_train_parity_multi_group( ) model = Transformer(model_args) ref_model = copy.deepcopy(model) +<<<<<<< HEAD if test_device_type == device_type.type: +======= + if device_type == device_type: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) replicate( ref_model.to(device_type), device_ids=[self.rank], @@ -433,7 +457,11 @@ def _test_train_parity_multi_group( gloo_pg = dist.new_group(backend="gloo") replicate(ref_model, process_group=gloo_pg) ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2) +<<<<<<< HEAD mesh = init_device_mesh(test_device_type, (self.world_size,)) +======= + mesh = init_device_mesh(device_type, (self.world_size,)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fully_shard_fn = functools.partial( fully_shard, mesh=mesh, @@ -483,12 +511,20 @@ def delayed_reduce_scatter(*args, **kwargs): _optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) losses.append(_model(inp).sum()) if _model is model and delay_after_forward: +<<<<<<< HEAD torch.get_device_module(test_device_type)._sleep( +======= + torch.get_device_module(device_type)._sleep( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int(delay_in_ms * get_cycles_per_ms()) ) losses[-1].backward() if _model is model and delay_before_optim: +<<<<<<< HEAD torch.get_device_module(test_device_type)._sleep( +======= + torch.get_device_module(device_type)._sleep( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int(delay_in_ms * get_cycles_per_ms()) ) _optim.step() @@ -1311,7 +1347,11 @@ def _test_3d_mlp_with_nd_mesh( use_activation_checkpointing, reshard_after_forward=reshard_after_forward, ) +<<<<<<< HEAD # Checking parameters match orig model is critical to validate .full_tensor correctly replicates the +======= + # Checking paramters match orig model is critical to validate .full_tensor correctly replicates the +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # strided-sharded layers. for ref_p, p in zip(ref_model.parameters(), model.parameters()): self.assertIsInstance(p, DTensor) @@ -1360,10 +1400,13 @@ def test_train_parity_hsdp(self): "use_activation_checkpointing": [False, True], "mlp_dim": [3, 16, 17], "sync_gradients_at_last_batch": [True, False], +<<<<<<< HEAD "offload_policy": [ CPUOffloadPolicy(pin_memory=True), CPUOffloadPolicy(pin_memory=False), ], +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, functools.partial(self._test_train_parity_hsdp, global_mesh), ) @@ -1375,7 +1418,10 @@ def _test_train_parity_hsdp( use_activation_checkpointing: bool, mlp_dim: int, sync_gradients_at_last_batch: bool, +<<<<<<< HEAD offload_policy: CPUOffloadPolicy, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): torch.manual_seed(42) model = nn.Sequential( @@ -1394,6 +1440,7 @@ def _test_train_parity_hsdp( if use_activation_checkpointing: checkpoint(mlp) fully_shard( +<<<<<<< HEAD mlp, mesh=global_mesh, reshard_after_forward=reshard_after_forward, @@ -1404,6 +1451,12 @@ def _test_train_parity_hsdp( mesh=global_mesh, reshard_after_forward=reshard_after_forward, offload_policy=offload_policy, +======= + mlp, mesh=global_mesh, reshard_after_forward=reshard_after_forward + ) + fully_shard( + model, mesh=global_mesh, reshard_after_forward=reshard_after_forward +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) optim = torch.optim.Adam(model.parameters(), lr=1e-2) check_sharded_parity(self, ref_model, model) @@ -1478,6 +1531,7 @@ def forward(self, imgs: torch.Tensor) -> torch.Tensor: check_sharded_parity(self, ref_model, model) +<<<<<<< HEAD class TestFullyShardWorldSize1(FSDPTest): @property def world_size(self) -> int: @@ -1543,5 +1597,7 @@ def _shard_placement_fn(param: nn.Parameter) -> Optional[Shard]: self.assertEqual(losses[0], losses[1]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/distributed/_composable/test_composability/test_2d_composability.py b/test/distributed/_composable/test_composability/test_2d_composability.py index bcaf06ea947a0..5df4863e052f1 100644 --- a/test/distributed/_composable/test_composability/test_2d_composability.py +++ b/test/distributed/_composable/test_composability/test_2d_composability.py @@ -277,19 +277,32 @@ def test_tp_with_fsdp_offloading(self): loss = model(inp).sum() fwd_comm_counts = fwd_comm_mode.get_comm_counts() +<<<<<<< HEAD self.assertEqual(len(fwd_comm_counts), 1) self.assertEqual(fwd_comm_counts[funcol.all_reduce], num_mlps) self.assertEqual(fwd_comm_counts[c10d_ops._allgather_base_], 0) +======= + self.assertEqual(len(fwd_comm_counts), 2) + self.assertEqual(fwd_comm_counts[funcol.all_reduce], num_mlps) + self.assertEqual(fwd_comm_counts[c10d_ops._allgather_base_], num_mlps) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ref_loss = ref_model(inp).sum() self.assertEqual(loss, ref_loss) with CommDebugMode() as bwd_comm_mode: loss.backward() bwd_comm_counts = bwd_comm_mode.get_comm_counts() +<<<<<<< HEAD self.assertEqual(len(bwd_comm_counts), 2) # First MLP's input gradient does not need to be all-reduced self.assertEqual(bwd_comm_counts[funcol.all_reduce], num_mlps - 1) self.assertEqual(bwd_comm_counts[c10d_ops._allgather_base_], 0) +======= + self.assertEqual(len(bwd_comm_counts), 3) + # First MLP's input gradient does not need to be all-reduced + self.assertEqual(bwd_comm_counts[funcol.all_reduce], num_mlps - 1) + self.assertEqual(bwd_comm_counts[c10d_ops._allgather_base_], num_mlps) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(bwd_comm_counts[c10d_ops._reduce_scatter_base_], num_mlps) ref_loss.backward() @@ -556,6 +569,24 @@ def _compare_params(self, m1, m2): @with_comms @skip_if_lt_x_gpu(4) +<<<<<<< HEAD +======= + def test_raise_invalid_tp_composition(self): + with self.assertRaisesRegex( + RuntimeError, r"Found TP device_mesh on the \d dimension of its parent mesh" + ): + mesh_2d = init_device_mesh( + self.device_type, (2, self.world_size // 2), mesh_dim_names=("tp", "dp") + ) + parallelize_plan = { + "net1": ColwiseParallel(), + "net2": RowwiseParallel(), + } + parallelize_module(SimpleModel().cuda(), mesh_2d["tp"], parallelize_plan) + + @with_comms + @skip_if_lt_x_gpu(4) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_2d_fsdp_state_enable_extension(self): mesh_2d = init_device_mesh( self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp") diff --git a/test/distributed/_composable/test_replicate.py b/test/distributed/_composable/test_replicate.py index a793fe2fed4cc..941f4ff7bda1c 100644 --- a/test/distributed/_composable/test_replicate.py +++ b/test/distributed/_composable/test_replicate.py @@ -69,7 +69,11 @@ def test_replicate_single_module_save_load(self): def test_replicate_non_root_multiple_save_load(self): """ +<<<<<<< HEAD Tests the replicate() on multiple submodules matches +======= + Tests tha replicate() on multiple submodules matches +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) local module state_dict. """ self._init_pg() diff --git a/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py b/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py index f62e4d29617d0..6a6903eb1f845 100644 --- a/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py +++ b/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py @@ -1733,7 +1733,11 @@ def test_sharded_tensor_to_cpu(self): self.assertEqual(remote_device_before.rank(), remote_device_after.rank()) self.assertEqual(str(remote_device_after.device()), "cpu") +<<<<<<< HEAD # ensure metadata also get changed to CPU +======= + # ensure metdata also get changed to CPU +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) metas = new_st.metadata().shards_metadata for meta in metas: self.assertEqual(str(meta.placement.device()), "cpu") @@ -1764,7 +1768,11 @@ def test_sharded_tensor_to_cpu(self): self.assertEqual(remote_device_before.rank(), remote_device_after.rank()) self.assertEqual(str(remote_device_after.device()), "cpu") +<<<<<<< HEAD # ensure metadata also get changed to CPU +======= + # ensure metdata also get changed to CPU +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) metas = new_st.metadata().shards_metadata for meta in metas: self.assertEqual(str(meta.placement.device()), "cpu") @@ -1820,7 +1828,11 @@ def test_sharded_tensor_to_cuda(self): self.assertEqual(str(remote_device_before.device().type), "cpu") self.assertEqual(str(remote_device_after.device().type), "cuda") +<<<<<<< HEAD # ensure metadata also get changed to GPU +======= + # ensure metdata also get changed to GPU +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) metas = new_st_gpu.metadata().shards_metadata for meta in metas: self.assertEqual(str(meta.placement.device().type), "cuda") diff --git a/test/distributed/_test_template.py b/test/distributed/_test_template.py index 517a4cf97f6e8..6b59c880356b3 100644 --- a/test/distributed/_test_template.py +++ b/test/distributed/_test_template.py @@ -1,10 +1,18 @@ # Owner(s): ["oncall: distributed"] +<<<<<<< HEAD from torch.testing._internal.common_distributed import MultiProcContinuousTest from torch.testing._internal.common_utils import run_tests class TestTemplate(MultiProcContinuousTest): +======= +from torch.testing._internal.common_distributed import MultiProcContinousTest +from torch.testing._internal.common_utils import run_tests + + +class TestTemplate(MultiProcContinousTest): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def testABC(self): print(f"rank {self.rank} of {self.world_size} testing ABC") diff --git a/test/distributed/_tools/test_fsdp2_mem_tracker.py b/test/distributed/_tools/test_fsdp2_mem_tracker.py index 05e7a9640da33..70488b7e49936 100644 --- a/test/distributed/_tools/test_fsdp2_mem_tracker.py +++ b/test/distributed/_tools/test_fsdp2_mem_tracker.py @@ -37,16 +37,26 @@ def _init_cublas_workspace(dev: torch.device): def _reset_mem_stats(dev: torch.device): +<<<<<<< HEAD mod = torch.get_device_module(dev) mod.empty_cache() mod.reset_accumulated_memory_stats(dev) mod.reset_peak_memory_stats(dev) +======= + torch.cuda.empty_cache() + torch.cuda.reset_accumulated_memory_stats(dev) + torch.cuda.reset_peak_memory_stats(dev) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestTrackerFullyShard1DTrainingCore(FSDPTest): @property def world_size(self) -> int: +<<<<<<< HEAD return min(4, torch.accelerator.device_count()) +======= + return min(4, torch.cuda.device_count()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skip_if_lt_x_gpu(2) def test_tracker_multi_group_eager(self): @@ -78,6 +88,7 @@ def _test_tracker_multi_group( mp_policy: MixedPrecisionPolicy, ): debug = False +<<<<<<< HEAD dev = torch.device(torch.accelerator.current_device_index()) _init_cublas_workspace(dev) gc.collect() @@ -85,11 +96,23 @@ def _test_tracker_multi_group( mod = torch.get_device_module(dev) mem_stats = mod.memory_stats(dev) pre_acc_active = mem_stats["active_bytes.all.current"] +======= + dev = torch.device(torch.cuda.current_device()) + _init_cublas_workspace(dev) + gc.collect() + _reset_mem_stats(dev) + mem_stats = torch.cuda.memory_stats(dev) + pre_cuda_active = mem_stats["active_bytes.all.current"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.manual_seed(42) lin_dim, bsz = 2048, 8192 with torch.device(dev): model = nn.Sequential(*[MLP(dim=lin_dim, device=dev) for _ in range(4)]) +<<<<<<< HEAD mesh = init_device_mesh(dev.type, (self.world_size,)) +======= + mesh = init_device_mesh("cuda", (self.world_size,)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fully_shard_fn = functools.partial( fully_shard, mesh=mesh, @@ -112,6 +135,7 @@ def _test_tracker_multi_group( optim.zero_grad() if iter_idx == 0: fmt.reset_mod_stats() +<<<<<<< HEAD mem_stats = mod.memory_stats() tracker_max = fmt.get_tracker_snapshot("peak")[dev]["Total"] acc_max = mem_stats["active_bytes.all.peak"] - pre_acc_active @@ -120,11 +144,23 @@ def _test_tracker_multi_group( print( f"Accuracy: {accuracy} Tracker Max:{tracker_max} Accelerator Max:{acc_max}" ) +======= + mem_stats = torch.cuda.memory_stats() + tracker_max = fmt.get_tracker_snapshot("peak")[dev]["Total"] + cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active + accuracy = tracker_max / cuda_max + if self.rank == 0 and debug: + print(f"Accuracy: {accuracy} Tracker Max:{tracker_max} CUDA Max:{cuda_max}") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertAlmostEqual( accuracy, 1.0, delta=0.1, +<<<<<<< HEAD msg=f"Tracker Max:{tracker_max} Accelerator Max:{acc_max}", +======= + msg=f"Tracker Max:{tracker_max} CUDA Max:{cuda_max}", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) del model del inp @@ -133,6 +169,7 @@ def _test_tracker_multi_group( @skip_if_lt_x_gpu(2) def test_tracker_non_root_forward_backward(self): """ +<<<<<<< HEAD Tests tracker accuracy when running forward/backward through a non-root. """ debug = False @@ -143,6 +180,17 @@ def test_tracker_non_root_forward_backward(self): mod = torch.get_device_module(dev) mem_stats = mod.memory_stats(dev) pre_acc_active = mem_stats["active_bytes.all.current"] +======= + Tests tracker accracy when running forward/backward through a non-root. + """ + debug = False + dev = torch.device(torch.cuda.current_device()) + _init_cublas_workspace(dev) + gc.collect() + _reset_mem_stats(dev) + mem_stats = torch.cuda.memory_stats(dev) + pre_cuda_active = mem_stats["active_bytes.all.current"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.manual_seed(42) lin_dim, bsz = 2048, 8 model = nn.Sequential(*[MLP(lin_dim, dev) for _ in range(3)]) @@ -162,6 +210,7 @@ def test_tracker_non_root_forward_backward(self): optim.zero_grad() if iter_idx == 0: fmt.reset_mod_stats() +<<<<<<< HEAD mem_stats = mod.memory_stats() tracker_max = fmt.get_tracker_snapshot("peak")[dev]["Total"] acc_max = mem_stats["active_bytes.all.peak"] - pre_acc_active @@ -170,11 +219,23 @@ def test_tracker_non_root_forward_backward(self): print( f"Accuracy: {accuracy} Tracker Max:{tracker_max} Accelerator Max:{acc_max}" ) +======= + mem_stats = torch.cuda.memory_stats() + tracker_max = fmt.get_tracker_snapshot("peak")[dev]["Total"] + cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active + accuracy = tracker_max / cuda_max + if self.rank == 0 and debug: + print(f"Accuracy: {accuracy} Tracker Max:{tracker_max} CUDA Max:{cuda_max}") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertAlmostEqual( accuracy, 1.0, delta=0.1, +<<<<<<< HEAD msg=f"Tracker Max:{tracker_max} Accelerator Max:{acc_max}", +======= + msg=f"Tracker Max:{tracker_max} CUDA Max:{cuda_max}", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) del inp del model @@ -184,7 +245,11 @@ def test_tracker_non_root_forward_backward(self): class TestTrackerFullyShard1DTrainingCompose(FSDPTest): @property def world_size(self) -> int: +<<<<<<< HEAD return min(torch.accelerator.device_count(), 4) +======= + return min(torch.cuda.device_count(), 4) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skip_if_lt_x_gpu(2) def test_tracker_with_activation_checkpointing(self): @@ -204,6 +269,7 @@ def _test_tracker_with_activation_checkpointing( ): assert checkpoint_impl in ("composable", "wrapper") debug = False +<<<<<<< HEAD dev = torch.device(torch.accelerator.current_device_index()) _init_cublas_workspace(dev) gc.collect() @@ -211,6 +277,14 @@ def _test_tracker_with_activation_checkpointing( mod = torch.get_device_module(dev) mem_stats = mod.memory_stats(dev) pre_acc_active = mem_stats["active_bytes.all.current"] +======= + dev = torch.device(torch.cuda.current_device()) + _init_cublas_workspace(dev) + gc.collect() + _reset_mem_stats(dev) + mem_stats = torch.cuda.memory_stats(dev) + pre_cuda_active = mem_stats["active_bytes.all.current"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.manual_seed(42) vocab_size = 8192 bsz, seq_len = 16, 512 @@ -257,6 +331,7 @@ def _test_tracker_with_activation_checkpointing( optim.zero_grad() if iter_idx == 0: fmt.reset_mod_stats() +<<<<<<< HEAD mem_stats = mod.memory_stats() tracker_max = fmt.get_tracker_snapshot("peak")[dev]["Total"] acc_max = mem_stats["active_bytes.all.peak"] - pre_acc_active @@ -265,11 +340,23 @@ def _test_tracker_with_activation_checkpointing( print( f"Accuracy: {accuracy} Tracker Max:{tracker_max} Accelerator Max:{acc_max}" ) +======= + mem_stats = torch.cuda.memory_stats() + tracker_max = fmt.get_tracker_snapshot("peak")[dev]["Total"] + cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active + accuracy = tracker_max / cuda_max + if self.rank == 0 and debug: + print(f"Accuracy: {accuracy} Tracker Max:{tracker_max} CUDA Max:{cuda_max}") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertAlmostEqual( accuracy, 1.0, delta=0.1, +<<<<<<< HEAD msg=f"Tracker Max:{tracker_max} Accelerator Max:{acc_max}", +======= + msg=f"Tracker Max:{tracker_max} CUDA Max:{cuda_max}", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) del inp del model diff --git a/test/distributed/_tools/test_mem_tracker.py b/test/distributed/_tools/test_mem_tracker.py index 4b4068227d553..6d796fa574c42 100644 --- a/test/distributed/_tools/test_mem_tracker.py +++ b/test/distributed/_tools/test_mem_tracker.py @@ -5,12 +5,19 @@ import torch import torch.nn as nn from torch.distributed._tools.mem_tracker import MemTracker +<<<<<<< HEAD +======= +from torch.testing._internal.common_cuda import TEST_CUDA +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_utils import ( run_tests, skipIfRocm, skipIfTorchDynamo, +<<<<<<< HEAD TEST_CUDA, TEST_XPU, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TestCase, ) from torch.utils.checkpoint import checkpoint @@ -25,6 +32,7 @@ def _init_cublas_workspace(self, dev: torch.device): del inp def _reset_mem_stats(self, dev: torch.device): +<<<<<<< HEAD mod = torch.get_device_module(dev) mod.empty_cache() mod.reset_accumulated_memory_stats(dev) @@ -36,11 +44,22 @@ def _reset_mem_stats(self, dev: torch.device): ) @skipIfRocm() def test_accelerator_tracker_equivalence( +======= + torch.cuda.empty_cache() + torch.cuda.reset_accumulated_memory_stats(dev) + torch.cuda.reset_peak_memory_stats(dev) + + @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653") + @unittest.skipIf(not TEST_CUDA, "CUDA not available") + @skipIfRocm() + def test_cuda_tracker_equivalence( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self, ): """ Tests that the tracker correctly calculates the peak memory. """ +<<<<<<< HEAD dev = torch.device(torch.accelerator.current_device_index()) self._init_cublas_workspace(dev) gc.collect(1) @@ -48,6 +67,14 @@ def test_accelerator_tracker_equivalence( mod = torch.get_device_module(dev) mem_stats = mod.memory_stats(dev) pre_acc_active = mem_stats["active_bytes.all.current"] +======= + dev = torch.device(torch.cuda.current_device()) + self._init_cublas_workspace(dev) + gc.collect(1) + self._reset_mem_stats(dev) + mem_stats = torch.cuda.memory_stats(dev) + pre_cuda_active = mem_stats["active_bytes.all.current"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bsz, n_layers, dim, dtype = 16, 4, 512, torch.bfloat16 class DummyModel(nn.Module): @@ -79,6 +106,7 @@ def forward(self, x): # Check for accuracy of peak memory tracker_max = mt.get_tracker_snapshot("peak")[dev]["Total"] +<<<<<<< HEAD mem_stats = mod.memory_stats(dev) acc_max = mem_stats["active_bytes.all.peak"] - pre_acc_active accuracy = tracker_max / acc_max @@ -88,12 +116,22 @@ def forward(self, x): @unittest.skipIf( not TEST_CUDA and not TEST_XPU, "Neither CUDA or XPU is not available" ) +======= + mem_stats = torch.cuda.memory_stats(dev) + cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active + accuracy = tracker_max / cuda_max + self.assertAlmostEqual(accuracy, 1.0, delta=0.1) + + @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653") + @unittest.skipIf(not TEST_CUDA, "CUDA not available") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_tracker_with_activation_checkpointing( self, ): """ Tests that the tracker correctly computes the peak memory during activation checkpointing. """ +<<<<<<< HEAD dev = torch.device(torch.accelerator.current_device_index()) self._init_cublas_workspace(dev) gc.collect(1) @@ -101,6 +139,14 @@ def test_tracker_with_activation_checkpointing( mod = torch.get_device_module(dev) mem_stats = mod.memory_stats(dev) pre_acc_active = mem_stats["active_bytes.all.current"] +======= + dev = torch.device(torch.cuda.current_device()) + self._init_cublas_workspace(dev) + gc.collect(1) + self._reset_mem_stats(dev) + mem_stats = torch.cuda.memory_stats(dev) + pre_cuda_active = mem_stats["active_bytes.all.current"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bsz, n_layers, dim, dtype = 128, 4, 1024, torch.float16 @@ -152,9 +198,15 @@ def forward(self, x): # Check for accuracy of peak memory tracker_max = mt.get_tracker_snapshot("peak")[dev]["Total"] +<<<<<<< HEAD mem_stats = mod.memory_stats(dev) acc_max = mem_stats["active_bytes.all.peak"] - pre_acc_active accuracy = tracker_max / acc_max +======= + mem_stats = torch.cuda.memory_stats(dev) + cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active + accuracy = tracker_max / cuda_max +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertAlmostEqual(accuracy, 1.0, delta=0.1) @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653") diff --git a/test/distributed/_tools/test_memory_tracker.py b/test/distributed/_tools/test_memory_tracker.py index 63366033629ff..e27b6aa1cc85d 100644 --- a/test/distributed/_tools/test_memory_tracker.py +++ b/test/distributed/_tools/test_memory_tracker.py @@ -5,18 +5,31 @@ import torch import torch.nn as nn from torch.distributed._tools import MemoryTracker +<<<<<<< HEAD from torch.testing._internal.common_utils import run_tests, TestCase class TestMemoryTracker(TestCase): @unittest.skipIf(not torch.accelerator.is_available(), "no accelerator") +======= +from torch.testing._internal.common_cuda import TEST_CUDA +from torch.testing._internal.common_utils import run_tests, TEST_XPU, TestCase + + +class TestMemoryTracker(TestCase): + @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "no cuda/xpu") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_local_model(self): """ Minimal test case to check the memory tracker can collect the expected memory stats at operator level, as well as can print the summary result without crash. """ +<<<<<<< HEAD device = torch.accelerator.current_accelerator() +======= + device = "cuda" if TEST_CUDA else "xpu" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Create a model with a hierarchy of modules torch.manual_seed(0) model = nn.Sequential( @@ -34,9 +47,15 @@ def test_local_model(self): tracker = MemoryTracker() tracker.start_monitor(model) +<<<<<<< HEAD x = torch.randn(size=(2, 3, 224, 224), device=device) # torch.LongTensor expects cpu device type, not gpu device type in # constructor, so calling .to() outside constructor here. +======= + x = torch.randn(size=(2, 3, 224, 224), device=torch.device(device)) + # torch.LongTensor expects cpu device type, not device type in + # constructor, so calling .to(device) outside constructor here. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) target = torch.LongTensor([0, 1]).to(device) criterion = nn.CrossEntropyLoss() criterion(model(x), target).backward() diff --git a/test/distributed/_tools/test_sac_ilp.py b/test/distributed/_tools/test_sac_ilp.py index bd9c8d3a8136a..3eab972410c65 100644 --- a/test/distributed/_tools/test_sac_ilp.py +++ b/test/distributed/_tools/test_sac_ilp.py @@ -211,7 +211,11 @@ def test_sac_ilp_case3(self): class TestOptimalCheckpointingPolicy(TestCase): +<<<<<<< HEAD # tests are adapted from tests in xformers +======= + # tests are adpated from tests in xformers +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # https://github.com/facebookresearch/xformers/blob/c6c0ac31f1b08542a0bc27278c6ed10f825f6963/tests/test_checkpoint.py#L222 def setUp(self): super().setUp() diff --git a/test/distributed/checkpoint/e2e/test_e2e_save_and_load.py b/test/distributed/checkpoint/e2e/test_e2e_save_and_load.py index e1b1041875afb..f38875e8c7671 100644 --- a/test/distributed/checkpoint/e2e/test_e2e_save_and_load.py +++ b/test/distributed/checkpoint/e2e/test_e2e_save_and_load.py @@ -1,7 +1,10 @@ # Owner(s): ["oncall: distributed"] import time +<<<<<<< HEAD from concurrent.futures import Future +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from dataclasses import dataclass, field from enum import auto, Enum from functools import partial @@ -14,7 +17,10 @@ import torch.distributed.checkpoint.state_dict_saver as saver import torch.nn as nn import torch.nn.functional as F +<<<<<<< HEAD from torch.distributed.checkpoint.staging import DefaultStager, StagingOptions +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.distributed.checkpoint.state_dict import ( _patch_model_state_dict, _patch_optimizer_state_dict, @@ -24,10 +30,14 @@ set_state_dict, ) from torch.distributed.checkpoint.state_dict_loader import _load_state_dict_from_keys +<<<<<<< HEAD from torch.distributed.checkpoint.state_dict_saver import ( AsyncCheckpointerType, AsyncSaveResponse, ) +======= +from torch.distributed.checkpoint.state_dict_saver import AsyncCheckpointerType +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.distributed.checkpoint.stateful import Stateful from torch.distributed.checkpoint.utils import CheckpointException from torch.distributed.device_mesh import init_device_mesh @@ -221,6 +231,7 @@ def test_e2e(self, compile, model_type): @skip_if_lt_x_gpu(4) @with_temp_dir @parametrize( +<<<<<<< HEAD "cache_staged_state_dict, async_checkpointer_type, zoc", [ (False, AsyncCheckpointerType.THREAD, False), @@ -234,13 +245,27 @@ def test_e2e(self, compile, model_type): def test_e2e_async_cached( self, cache_staged_state_dict, async_checkpointer_type, zoc ): +======= + "cache_staged_state_dict, async_checkpointer_type", + [ + (False, AsyncCheckpointerType.THREAD), + (True, AsyncCheckpointerType.THREAD), + (False, AsyncCheckpointerType.PROCESS), + (True, AsyncCheckpointerType.PROCESS), + ], + ) + def test_e2e_async_cached(self, cache_staged_state_dict, async_checkpointer_type): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._run_e2e_test( compile=False, model_type=ModelType.FSDP, async_op=True, cache_staged_state_dict=cache_staged_state_dict, async_checkpointer_type=async_checkpointer_type, +<<<<<<< HEAD zoc=zoc, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def _run_e2e_test( @@ -250,7 +275,10 @@ def _run_e2e_test( async_op=False, cache_staged_state_dict=False, async_checkpointer_type=None, +<<<<<<< HEAD zoc=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): model, optim = self._create_model(compile, ModelType.NONE) _train(model, optim, train_steps=2) @@ -270,6 +298,7 @@ def _run_e2e_test( writer = DCP.FileSystemWriter( self.temp_dir, cache_staged_state_dict=cache_staged_state_dict ) +<<<<<<< HEAD stager = None if not cache_staged_state_dict: use_shared_memory = ( @@ -283,6 +312,9 @@ def _run_e2e_test( ) stager = DefaultStager(staging_options) async_save_response_or_future = saver.async_save( +======= + f = saver.async_save( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sd, storage_writer=writer, async_checkpointer_type=( @@ -290,6 +322,7 @@ def _run_e2e_test( if async_checkpointer_type else AsyncCheckpointerType.THREAD ), +<<<<<<< HEAD async_stager=stager, ) if isinstance(async_save_response_or_future, Future): @@ -304,6 +337,15 @@ def _run_e2e_test( print(f"still waiting... {time.monotonic() - t}") save_future.result() +======= + ) + t = time.monotonic() + while not f.done(): + time.sleep(1) + print(f"still waiting... {time.monotonic() - t}") + + f.result() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: DCP.save(sd, checkpoint_id=self.temp_dir) diff --git a/test/distributed/checkpoint/e2e/test_fsdp_ep.py b/test/distributed/checkpoint/e2e/test_fsdp_ep.py index 51d4b3e995372..319fb9a67270e 100644 --- a/test/distributed/checkpoint/e2e/test_fsdp_ep.py +++ b/test/distributed/checkpoint/e2e/test_fsdp_ep.py @@ -72,9 +72,15 @@ def test_e2e(self): mesh_fsdp_tp = init_device_mesh( self.device_type, (2, 4), mesh_dim_names=("dp", "tp") ) +<<<<<<< HEAD # TODO: we are using an internal API atm. Change to a public API once it is ready. mesh_fsdp_ep = _mesh_resources.create_sub_mesh(mesh_fsdp_tp, ("dp",), [(0,)]) del _mesh_resources.child_to_root_mapping[mesh_fsdp_ep] +======= + # TODO: we are using an internal API atm. Change to a publich API once it is ready. + mesh_fsdp_ep = _mesh_resources.create_child_mesh(mesh_fsdp_tp, ("dp",)) + del _mesh_resources.child_to_parent_mapping[mesh_fsdp_ep] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mesh_fsdp = init_device_mesh(self.device_type, (8,)) for i, l in enumerate(model.second.ep_layers): diff --git a/test/distributed/checkpoint/test_checkpoint.py b/test/distributed/checkpoint/test_checkpoint.py index 66911327327d3..cc3c306771a23 100644 --- a/test/distributed/checkpoint/test_checkpoint.py +++ b/test/distributed/checkpoint/test_checkpoint.py @@ -2,7 +2,11 @@ import os import sys +<<<<<<< HEAD from typing import Any, cast, Optional, Union +======= +from typing import cast, Optional, Union +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch.distributed as dist @@ -170,9 +174,13 @@ def __init__(self, fail_conf): def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None: return +<<<<<<< HEAD def set_up_storage_writer( self, is_coordinator: bool, *args: Any, **kwargs: Any ) -> None: +======= + def set_up_storage_writer(self, is_coordinator: bool) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._fail_rank("fail_set_up_storage_writer") def prepare_local_plan(self, plan: SavePlan) -> SavePlan: @@ -358,18 +366,26 @@ def test_load_error_handling(self) -> None: self._test_load(state_dict) self._test_load(state_dict, fail_set_up_storage_reader=[0]) self._test_load(state_dict, fail_prepare_global_plan=[0]) +<<<<<<< HEAD self._test_load(state_dict, fail_read_metadata=[0], ignore_exception_type=True) +======= + self._test_load(state_dict, fail_read_metadata=[0]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._test_load(state_dict, fail_prepare_local_plan=[1]) self._test_load(state_dict, fail_read_data=[3]) self._test_load(state_dict, fail_read_data_async=[1]) self._test_load(state_dict, coordinator=3, fail_set_up_storage_reader=[0]) +<<<<<<< HEAD self._test_load( state_dict, coordinator=1, fail_read_metadata=[3], ignore_exception_type=True, ) +======= + self._test_load(state_dict, coordinator=1, fail_read_metadata=[3]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._test_load(state_dict, coordinator=2, fail_read_data=[0]) self._test_load(state_dict, coordinator=3, fail_read_data_async=[2]) self._test_load(state_dict, coordinator=1, fail_prepare_global_plan=[1]) @@ -378,7 +394,11 @@ def test_load_error_handling_no_dist(self) -> None: state_dict = {"replicated": torch.rand(10, 10), "bytes": [1, 2, 3, 4]} self._test_load(state_dict) self._test_load(state_dict, fail_set_up_storage_reader=[0]) +<<<<<<< HEAD self._test_load(state_dict, fail_read_metadata=[0], ignore_exception_type=True) +======= + self._test_load(state_dict, fail_read_metadata=[0]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._test_load(state_dict, fail_prepare_local_plan=[0]) self._test_load(state_dict, fail_prepare_global_plan=[0]) self._test_load(state_dict, fail_read_data=[0]) diff --git a/test/distributed/checkpoint/test_fsspec.py b/test/distributed/checkpoint/test_fsspec.py index 9d69d6d386a7e..aaceb3a07b5cd 100644 --- a/test/distributed/checkpoint/test_fsspec.py +++ b/test/distributed/checkpoint/test_fsspec.py @@ -18,10 +18,14 @@ from torch.distributed.checkpoint.utils import CheckpointException from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType +<<<<<<< HEAD from torch.testing._internal.common_distributed import ( requires_accelerator_dist_backend, skip_if_lt_x_gpu, ) +======= +from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_utils import run_tests, TestCase from torch.testing._internal.distributed._shard.sharded_tensor import ( ShardedTensorTestBase, @@ -29,10 +33,13 @@ ) +<<<<<<< HEAD device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" BACKEND = torch.distributed.get_default_backend_for_device(device_type) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def with_temp_dir( func: Optional[Callable] = None, ) -> Optional[Callable]: @@ -82,14 +89,24 @@ class TestFSSpec(ShardedTensorTestBase): def world_size(self) -> int: return 2 +<<<<<<< HEAD @with_comms(backend=BACKEND, init_rpc=False) @requires_accelerator_dist_backend() @skip_if_lt_x_gpu(2) +======= + @with_comms(init_rpc=False) + @skip_if_lt_x_gpu(2) + @requires_nccl() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @with_temp_dir def test_fsspec(self): CHECKPOINT_DIR = self.temp_dir +<<<<<<< HEAD model = FSDP(MyTestModule().to(device_type)) +======= + model = FSDP(MyTestModule().cuda()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) optim = torch.optim.Adam(model.parameters(), lr=0.1) model(torch.rand(8, 8, device=dist.get_rank())).sum().backward() optim.step() @@ -106,7 +123,11 @@ def test_fsspec(self): planner=dcp.DefaultSavePlanner(), ) +<<<<<<< HEAD model_2 = FSDP(MyTestModule().to(device_type)) +======= + model_2 = FSDP(MyTestModule().cuda()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) optim_2 = torch.optim.Adam(model_2.parameters(), lr=0.1) with FSDP.summon_full_params(model): @@ -156,9 +177,15 @@ def opt_at(opt, idx): opt_at(optim, 0)["exp_avg_sq"], opt_at(optim_2, 0)["exp_avg_sq"] ) +<<<<<<< HEAD @with_comms(backend=BACKEND, init_rpc=False) @requires_accelerator_dist_backend() @skip_if_lt_x_gpu(2) +======= + @with_comms(init_rpc=False) + @skip_if_lt_x_gpu(2) + @requires_nccl() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @with_temp_dir def test_overwrite(self): t1, t2 = torch.randn(10), torch.randn(10) diff --git a/test/distributed/checkpoint/test_hf_safetensor_e2e.py b/test/distributed/checkpoint/test_hf_safetensor_e2e.py index 9fbe2c47db039..089a19faa681b 100644 --- a/test/distributed/checkpoint/test_hf_safetensor_e2e.py +++ b/test/distributed/checkpoint/test_hf_safetensor_e2e.py @@ -1,6 +1,7 @@ # Owner(s): ["oncall: distributed checkpointing"] import importlib +<<<<<<< HEAD import json import os @@ -13,6 +14,14 @@ from torch.distributed.checkpoint.state_dict_loader import _load_state_dict_from_keys from torch.distributed.device_mesh import init_device_mesh from torch.distributed.tensor import distribute_tensor, DTensor, Replicate, Shard, zeros +======= + +import torch +import torch.distributed.checkpoint as dist_cp +from torch.distributed.checkpoint.state_dict_loader import _load_state_dict_from_keys +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.tensor import distribute_tensor, Replicate, Shard, zeros +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, run_tests, @@ -121,6 +130,7 @@ def test_load_into_empty_dict(self) -> None: torch.equal(state_dict_to_save[key], state_dict_loaded[key]) ) +<<<<<<< HEAD @with_temp_dir def test_load_with_multiple_threads(self) -> None: if importlib.util.find_spec("safetensors") is None: @@ -327,6 +337,8 @@ def test_consolidate_to_one_file(self) -> None: dist.barrier() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ONE_D_PLACEMENTS = [ [Shard(0)], @@ -381,7 +393,11 @@ def test_1d_to_1d_reshard_placement_change(self) -> None: state_dict=state_dict_to_save, storage_writer=dist_cp.HuggingFaceStorageWriter( path=CHECKPOINT_DIR, +<<<<<<< HEAD save_distributed=True, +======= + save_sharded=True, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), ) @@ -439,7 +455,11 @@ def test_2d_to_2d_reshard_placement_change(self) -> None: dist_cp.save( state_dict=state_dict_to_save, storage_writer=dist_cp.HuggingFaceStorageWriter( +<<<<<<< HEAD path=CHECKPOINT_DIR, save_distributed=True +======= + path=CHECKPOINT_DIR, save_sharded=True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), planner=dist_cp.DefaultSavePlanner(), ) @@ -494,7 +514,11 @@ def test_1d_to_2d_reshard_mesh_change(self) -> None: dist_cp.save( state_dict=state_dict_to_save, storage_writer=dist_cp.HuggingFaceStorageWriter( +<<<<<<< HEAD path=CHECKPOINT_DIR, save_distributed=True +======= + path=CHECKPOINT_DIR, save_sharded=True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), ) @@ -545,7 +569,11 @@ def test_2d_to_1d_reshard_mesh_change(self) -> None: dist_cp.save( state_dict=state_dict_to_save, storage_writer=dist_cp.HuggingFaceStorageWriter( +<<<<<<< HEAD path=CHECKPOINT_DIR, save_distributed=True +======= + path=CHECKPOINT_DIR, save_sharded=True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), planner=dist_cp.DefaultSavePlanner(), ) @@ -595,7 +623,11 @@ def test_dtensor_checkpoint_resharding_with_empty_shard(self): dist_cp.save( state_dict=ref_state_dict, storage_writer=dist_cp.HuggingFaceStorageWriter( +<<<<<<< HEAD path=self.temp_dir, save_distributed=True +======= + path=self.temp_dir, save_sharded=True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), ) diff --git a/test/distributed/checkpoint/test_hf_storage.py b/test/distributed/checkpoint/test_hf_storage.py index 81558db13a69f..91138d68af8bc 100644 --- a/test/distributed/checkpoint/test_hf_storage.py +++ b/test/distributed/checkpoint/test_hf_storage.py @@ -2,16 +2,24 @@ import json import os +<<<<<<< HEAD +======= +import pathlib +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import sys import tempfile from unittest.mock import MagicMock import torch from torch.distributed.checkpoint import DefaultLoadPlanner +<<<<<<< HEAD from torch.distributed.checkpoint._hf_utils import ( _HFStorageInfo, NUM_BYTES_FOR_HEADER_LEN, ) +======= +from torch.distributed.checkpoint._hf_utils import _HFStorageInfo +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.distributed.checkpoint.default_planner import DefaultSavePlanner from torch.distributed.checkpoint.filesystem import _StorageInfo, FileSystem from torch.distributed.checkpoint.hf_storage import ( @@ -108,7 +116,11 @@ def test_write_data_with_sharding(self) -> None: with tempfile.TemporaryDirectory() as path: writer = HuggingFaceStorageWriter( path=path, +<<<<<<< HEAD save_distributed=True, +======= + save_sharded=True, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) writer.fs = FileSystem() @@ -162,6 +174,7 @@ def test_write_data_with_sharding(self) -> None: ) def test_read_data_hf(self) -> None: +<<<<<<< HEAD tensor_0 = torch.tensor([1.0, 2.0, 3.0, 4.0]) mock_safe_open = MagicMock() @@ -171,14 +184,35 @@ def test_read_data_hf(self) -> None: sys.modules["safetensors"] = MagicMock() sys.modules["safetensors"].safe_open = mock_safe_open +======= + mock_safetensors = MagicMock() + sys.modules["safetensors"] = mock_safetensors + + # Create test tensors + tensor_0 = torch.tensor([1.0, 2.0, 3.0, 4.0]) + + # Mock the deserialize function to return our test tensors + # The format matches what's expected in the read_data method + mock_safetensors.deserialize.return_value = [ + ( + "tensor_0", + {"data": tensor_0.numpy().tobytes(), "dtype": "F32", "shape": [4]}, + ), + ] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with tempfile.TemporaryDirectory() as path: # Create the reader reader = HuggingFaceStorageReader(path=path) +<<<<<<< HEAD +======= + reader.fs = FileSystem() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Create test file file_name = "model-00001-of-00001.safetensors" file_path = os.path.join(path, file_name) +<<<<<<< HEAD with open(file_path, "wb") as f: # write metadata the same way it would be in safetensors file @@ -201,6 +235,9 @@ def test_read_data_hf(self) -> None: f.write(metadata_bytes) f.write(tensor_0.numpy().tobytes()) +======= + pathlib.Path(file_path).touch() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Set up storage data with _StorageInfo objects storage_data = { @@ -208,6 +245,11 @@ def test_read_data_hf(self) -> None: fqn="tensor_0", offset=torch.Size([0]), index=None ): _HFStorageInfo( file_path, +<<<<<<< HEAD +======= + 0, + tensor_0.numel() * tensor_0.element_size(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tensor_0.shape, tensor_0.dtype, ), @@ -260,15 +302,22 @@ def test_read_data_hf(self) -> None: ), ) +<<<<<<< HEAD +======= + # Call read_data +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) future = reader.read_data(load_plan, load_planner) future.wait() # Verify results - the target tensors should now contain the values from our test tensor self.assertTrue(torch.equal(state_dict["tensor_0"], tensor_0)) +<<<<<<< HEAD mock_safe_open.assert_called_once_with(filename=file_path, framework="pt") mock_context.__enter__.return_value.get_slice.assert_called_with("tensor_0") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_write_metadata_hf(self) -> None: mock_module = MagicMock() sys.modules["huggingface_hub"] = mock_module @@ -322,6 +371,7 @@ def test_write_metadata_hf(self) -> None: self.assertEqual(metadata, expected_metadata) def test_read_metadata_hf(self): +<<<<<<< HEAD mock_safe_open = MagicMock() mock_context = MagicMock() @@ -343,11 +393,14 @@ def test_read_metadata_hf(self): sys.modules["safetensors"] = mock_safetensors sys.modules["safetensors.torch"] = mock_safetensors.torch +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with tempfile.TemporaryDirectory() as path: reader = HuggingFaceStorageReader(path=path) key = "tensor_0" file_name = "test.safetensors" +<<<<<<< HEAD file_path = os.path.join(path, file_name) # Create an empty file so fs.ls can find it @@ -366,6 +419,25 @@ def test_read_metadata_hf(self): # Verify that safe_open was called with our file path mock_safe_open.assert_called_once_with(file_path, framework="pt") +======= + with open(os.path.join(path, file_name), "wb") as f: + # write metadata the same way it would be in safetensors file + metadata_contents = json.dumps( + { + "tensor_0": { + "dtype": "F32", + "shape": [5, 10], + "data_offsets": [0, 200], + } + } + ) + metadata_bytes = metadata_contents.encode("utf-8") + + f.write(len(metadata_bytes).to_bytes(8, byteorder="little")) + f.write(metadata_bytes) + + metadata = reader.read_metadata() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual( metadata.state_dict_metadata, @@ -381,7 +453,10 @@ def test_read_metadata_hf(self): ), }, ) +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual( metadata.storage_data, { @@ -389,6 +464,11 @@ def test_read_metadata_hf(self): fqn=key, offset=torch.Size([0, 0]), index=None ): _HFStorageInfo( os.path.join(path, file_name), +<<<<<<< HEAD +======= + 0, + 200, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.Size([5, 10]), torch.float32, ) diff --git a/test/distributed/checkpoint/test_planner.py b/test/distributed/checkpoint/test_planner.py index edf043301ed28..40451c452a124 100644 --- a/test/distributed/checkpoint/test_planner.py +++ b/test/distributed/checkpoint/test_planner.py @@ -24,7 +24,10 @@ DefaultLoadPlanner, DefaultSavePlanner, ) +<<<<<<< HEAD from torch.distributed.checkpoint.filesystem import CURRENT_DCP_VERSION +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.distributed.checkpoint.metadata import ( BytesStorageMetadata, ChunkStorageMetadata, @@ -594,6 +597,7 @@ def test_load_different_sizes_throws(self): planner=DefaultLoadPlanner(), ) +<<<<<<< HEAD @with_temp_dir def test_version_key_in_planner_data(self): original_module = nn.Linear(2, 2) @@ -610,6 +614,8 @@ def test_version_key_in_planner_data(self): self.assertEqual(planner.metadata.version, CURRENT_DCP_VERSION) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/distributed/checkpoint/test_state_dict.py b/test/distributed/checkpoint/test_state_dict.py index a42215e0ea0d6..8d824a08c025d 100644 --- a/test/distributed/checkpoint/test_state_dict.py +++ b/test/distributed/checkpoint/test_state_dict.py @@ -62,9 +62,12 @@ from torch.utils._pytree import tree_all, tree_all_only +<<<<<<< HEAD device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) sys.exit(0) @@ -82,7 +85,11 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin): @property def world_size(self) -> int: +<<<<<<< HEAD return min(4, torch.accelerator.device_count()) +======= + return min(4, torch.cuda.device_count()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _test_save_load( self, @@ -104,7 +111,11 @@ def _test_save_load( for d_optim in _dist_optim: d_optim.zero_grad() +<<<<<<< HEAD batch = torch.rand(8, 100, device=device_type) +======= + batch = torch.rand(8, 100, device="cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) model(batch).sum().backward() dist_model(batch).sum().backward() @@ -112,7 +123,11 @@ def _test_save_load( for d_optim in _dist_optim: d_optim.step() +<<<<<<< HEAD # We need to ensure gradients don't exist, this the invariant of using DSD. +======= + # We need to ensure gradients don't exist, this the invarient of using DSD. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) optim.zero_grad() # Get the state_dict, and compare the result @@ -138,7 +153,11 @@ def _test_save_load( # We won't be able to load the partial state_dict back. return # Since we already have the state_dict saved before, no need to call DCP. +<<<<<<< HEAD # We can directly load them back. This assert is to ensure that optimizer +======= + # We can directly load them back. This asser is to ensure that optimizer +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # state storage are initialized. # self.assertEqual(len(curr_dist_osd[STATE]), len(dist_osd[STATE])) set_model_state_dict( @@ -191,9 +210,15 @@ def _test_fsdp( def init_model_optim(): if use_dtensor: +<<<<<<< HEAD device_mesh = init_device_mesh(device_type, (self.world_size,)) orig_model = CompositeParamModel(device=torch.device(device_type)) +======= + device_mesh = init_device_mesh("cuda", (self.world_size,)) + + orig_model = CompositeParamModel(device=torch.device("cuda")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4, foreach=True) copy_optim = optimizer_class(orig_model.parameters(), lr=1e-4, foreach=True) if wrapping: @@ -201,7 +226,11 @@ def init_model_optim(): else: strategy = {UnitModule} if use_dtensor: +<<<<<<< HEAD device_mesh = init_device_mesh(device_type, (self.world_size,)) +======= + device_mesh = init_device_mesh("cuda", (self.world_size,)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dist_model = FSDP( copy.deepcopy(orig_model), auto_wrap_policy=ModuleWrapPolicy(strategy), @@ -261,7 +290,11 @@ def _test_fsdp2( foreach: bool = True, ): def init_model_optim(): +<<<<<<< HEAD orig_model = CompositeParamModel(device=torch.device(device_type)) +======= + orig_model = CompositeParamModel(device=torch.device("cuda")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) orig_optim = optimizer_class( orig_model.parameters(), lr=1e-4, foreach=foreach ) @@ -298,7 +331,11 @@ def test_fsdp2(self) -> None: def _test_ddp(self, use_composable: bool, optimizer_class: type[Optimizer]) -> None: def init_model_optim(): +<<<<<<< HEAD orig_model = CompositeParamModel(device=torch.device(device_type)) +======= + orig_model = CompositeParamModel(device=torch.device("cuda")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4) copy_optim = optimizer_class(orig_model.parameters(), lr=1e-4) if use_composable: @@ -332,7 +369,11 @@ def _test_fsdp_ddp( test_frozen: bool = False, ) -> None: def init_model_optim(): +<<<<<<< HEAD orig_model = CompositeParamModel(device=torch.device(device_type)) +======= + orig_model = CompositeParamModel(device=torch.device("cuda")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if test_frozen: for param in chain( orig_model.u1.parameters(), orig_model.u2.parameters() @@ -373,7 +414,11 @@ def test_fsdp_ddp(self) -> None: def _test_single_gpu(self, optimizer_class: type[Optimizer]) -> None: def init_model_optim(): +<<<<<<< HEAD orig_model = CompositeParamModel(device=torch.device(device_type)) +======= + orig_model = CompositeParamModel(device=torch.device("cuda")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4) copy_optim = optimizer_class(orig_model.parameters(), lr=1e-4) model_copy = copy.deepcopy(orig_model) @@ -388,7 +433,11 @@ def test_single_gpu(self) -> None: self._test_single_gpu(torch.optim.AdamW) def _test_strict(self, parallelism: str) -> None: +<<<<<<< HEAD model = CompositeParamModel(device=torch.device(device_type)) +======= + model = CompositeParamModel(device=torch.device("cuda")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if parallelism == "DDP": model = DDP(model) else: @@ -425,8 +474,13 @@ def test_strict(self) -> None: def _test_cpu_offload_full_state_dict( self, optimizer_class: type[Optimizer] ) -> None: +<<<<<<< HEAD orig_model = CompositeParamModel(device=torch.device(device_type)) device_mesh = init_device_mesh(device_type, (self.world_size,)) +======= + orig_model = CompositeParamModel(device=torch.device("cuda")) + device_mesh = init_device_mesh("cuda", (self.world_size,)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dist_model = FSDP( copy.deepcopy(orig_model), auto_wrap_policy=ModuleWrapPolicy({UnitModule}), @@ -502,7 +556,11 @@ def test_cpu_offload_full_state_dict(self) -> None: @skip_if_lt_x_gpu(1) def test_activation_ckpt_fqns_ddp(self) -> None: """Tests that activation checkpointing prefixes are removed from module names""" +<<<<<<< HEAD model = CompositeParamModel(device=torch.device(device_type)) +======= + model = CompositeParamModel(device=torch.device("cuda")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) original_keys = get_model_state_dict(model).keys() apply_activation_checkpointing(model) @@ -521,7 +579,11 @@ def test_activation_ckpt_fqns_fsdp1(self) -> None: def _test_activation_ckpt_fqns_fsdp1(self, use_orig_params: bool) -> None: """Tests that activation checkpointing prefixes are removed from module names""" +<<<<<<< HEAD model = CompositeParamModel(device=torch.device(device_type)) +======= + model = CompositeParamModel(device=torch.device("cuda")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) original_keys = get_model_state_dict(model).keys() apply_activation_checkpointing(model) @@ -532,7 +594,11 @@ def _test_activation_ckpt_fqns_fsdp1(self, use_orig_params: bool) -> None: @skip_if_lt_x_gpu(1) def test_extra_state(self) -> None: +<<<<<<< HEAD model = CompositeParamModel(device=torch.device(device_type)) +======= + model = CompositeParamModel(device=torch.device("cuda")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_extra_state(self): return "MyState" @@ -550,21 +616,35 @@ def set_extra_state(self, state): @skip_if_lt_x_gpu(1) def test_non_persistent_buffers(self) -> None: +<<<<<<< HEAD model = CompositeParamModel(device=torch.device(device_type)) model.register_buffer( "dont_save_me", torch.rand(100, device=device_type), persistent=False +======= + model = CompositeParamModel(device=torch.device("cuda")) + model.register_buffer( + "dont_save_me", torch.rand(100, device="cuda"), persistent=False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) target_model = copy.deepcopy(model) set_model_state_dict(target_model, get_model_state_dict(target_model)) self.assertEqual(model.state_dict(), get_model_state_dict(target_model)) def _test_broadcast_from_rank0(self, wrapper) -> None: +<<<<<<< HEAD model = CompositeParamModel(device=torch.device(device_type)) +======= + model = CompositeParamModel(device=torch.device("cuda")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) optim = torch.optim.Adam(model.parameters()) fsdp_model = wrapper(copy.deepcopy(model)) fsdp_optim = torch.optim.Adam(fsdp_model.parameters()) +<<<<<<< HEAD batch = torch.rand(8, 100, device=device_type) +======= + batch = torch.rand(8, 100, device="cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) model(batch).sum().backward() optim.step() states, optim_states = get_state_dict(model, optim) @@ -634,8 +714,13 @@ def check(equal): @with_comms @skip_if_lt_x_gpu(4) def test_broadcast_from_rank0(self) -> None: +<<<<<<< HEAD device_mesh = init_device_mesh(device_type, (self.world_size,)) hsdp_device_mesh = init_device_mesh(device_type, (2, self.world_size // 2)) +======= + device_mesh = init_device_mesh("cuda", (self.world_size,)) + hsdp_device_mesh = init_device_mesh("cuda", (2, self.world_size // 2)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.run_subtests( { "wrapper": [ @@ -657,8 +742,13 @@ def test_fsdp_root_not_initialized(self) -> None: # This test verifies that FSDP root is not initialized but we should # still be able to get the state_dict without errors because # fsdp_model.state_dict() will trigger the FSDP initialization. +<<<<<<< HEAD device_mesh = init_device_mesh(device_type, (self.world_size,)) model = CompositeParamModel(device=torch.device(device_type)) +======= + device_mesh = init_device_mesh("cuda", (self.world_size,)) + model = CompositeParamModel(device=torch.device("cuda")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fsdp_model = FSDP(copy.deepcopy(model), device_mesh=device_mesh) fsdp_optim = torch.optim.Adam(fsdp_model.parameters()) get_model_state_dict(fsdp_model) @@ -671,9 +761,16 @@ def test_optim_state_dict_param_matching(self) -> None: # "initial_lr" is added to optim_state_dict, but not to the new optim # We test whether "initial_lr" appear in optim after # set_optimizer_state_dict. +<<<<<<< HEAD torch.manual_seed(0) model = nn.Sequential( *[nn.Linear(4, 4, device=device_type, bias=False) for _ in range(2)] +======= + device = "cuda" + torch.manual_seed(0) + model = nn.Sequential( + *[nn.Linear(4, 4, device=device, bias=False) for _ in range(2)] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) for layer in model: fully_shard(layer) @@ -707,11 +804,19 @@ def test_optim_state_dict_param_matching(self) -> None: @with_comms @skip_if_lt_x_gpu(2) def test_flattened_osd(self) -> None: +<<<<<<< HEAD device_mesh = init_device_mesh(device_type, (self.world_size,)) model = CompositeParamModel(device=torch.device(device_type)) fsdp_model = fully_shard(copy.deepcopy(model), mesh=device_mesh) fsdp_optim = torch.optim.AdamW(fsdp_model.parameters()) batch = torch.rand(8, 100, device=device_type) +======= + device_mesh = init_device_mesh("cuda", (self.world_size,)) + model = CompositeParamModel(device=torch.device("cuda")) + fsdp_model = fully_shard(copy.deepcopy(model), mesh=device_mesh) + fsdp_optim = torch.optim.AdamW(fsdp_model.parameters()) + batch = torch.rand(8, 100, device="cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fsdp_model(batch).sum().backward() fsdp_optim.step() fsdp_optim.zero_grad() @@ -732,7 +837,11 @@ def test_flattened_osd(self) -> None: self.assertEqual(fsdp_optim.state_dict(), fsdp_optim2.state_dict()) def _test_deprecate_partial(self) -> None: +<<<<<<< HEAD model = CompositeParamModel(device=torch.device(device_type)) +======= + model = CompositeParamModel(device=torch.device("cuda")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) model_state_dict1 = get_model_state_dict(model) model_state_dict1 = copy.deepcopy(model_state_dict1) @@ -785,8 +894,13 @@ def _test_deprecate_partial(self) -> None: self.assertEqual(model.l.bias, model_state_dict1["l.bias"]) def _test_deprecate_fsdp_api(self) -> None: +<<<<<<< HEAD device_mesh = init_device_mesh(device_type, (self.world_size,)) model = CompositeParamModel(device=torch.device(device_type)) +======= + device_mesh = init_device_mesh("cuda", (self.world_size,)) + model = CompositeParamModel(device=torch.device("cuda")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fsdp_model = FSDP(copy.deepcopy(model), device_mesh=device_mesh) with self.assertWarnsRegex( FutureWarning, @@ -825,8 +939,13 @@ def forward(self, input): return output def init_model_optim(): +<<<<<<< HEAD device_mesh = init_device_mesh(device_type, (self.world_size,)) orig_model = TiedEmbeddingModel(10000, 300).to(torch.device(device_type)) +======= + device_mesh = init_device_mesh("cuda", (self.world_size,)) + orig_model = TiedEmbeddingModel(10000, 300).to(torch.device("cuda")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) orig_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-4) copy_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-4) dist_model = FSDP(copy.deepcopy(orig_model), device_mesh=device_mesh) @@ -907,12 +1026,17 @@ def test_setting_meta_device_model_broadcasting_and_memory(self) -> None: self.assertEqual(cpu_model_value, meta_model_value) # Memory allocated and reserved are lower due to the change at _distribute_tensors # from view to clone. This test would fail if with view due to higher memory cost. +<<<<<<< HEAD memory_allocated = ( torch.get_device_module(device_type).memory_allocated(0) / 1024 / 1024 ) memory_reserved = ( torch.get_device_module(device_type).memory_reserved(0) / 1024 / 1024 ) +======= + memory_allocated = torch.cuda.memory_allocated(0) / 1024 / 1024 + memory_reserved = torch.cuda.memory_reserved(0) / 1024 / 1024 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue(memory_allocated <= 384) self.assertTrue(memory_reserved <= 768) @@ -948,11 +1072,19 @@ def test_multi_device_load_model_state_dict(self) -> None: meta_submodel = nn.Linear(4, 4, bias=False) with torch.device("cpu"): cpu_submodel = nn.Linear(4, 4, bias=False) +<<<<<<< HEAD with torch.device(device_type): acc_submodel = nn.Linear(4, 4, bias=False) two_device_model_with_meta = nn.Sequential(meta_submodel, acc_submodel) two_device_model_without_meta = nn.Sequential(cpu_submodel, acc_submodel) +======= + with torch.device("cuda"): + cuda_submodel = nn.Linear(4, 4, bias=False) + + two_device_model_with_meta = nn.Sequential(meta_submodel, cuda_submodel) + two_device_model_without_meta = nn.Sequential(cpu_submodel, cuda_submodel) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with torch.device("cpu"): model_to_set = nn.Sequential( @@ -980,7 +1112,11 @@ def test_multi_device_load_model_state_dict(self) -> None: def test_state_dict_with_hook_on_keys(self) -> None: with torch.device("meta"): metamodel = FusionEmbedding(4, 4, 4) +<<<<<<< HEAD with torch.device(device_type): +======= + with torch.device("cuda"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) gpumodel = FusionEmbeddingWithHook(4, 4, 4) gpumodel_state_dict = get_model_state_dict(gpumodel) with self.assertRaisesRegex(RuntimeError, "Missing key"): @@ -1001,8 +1137,13 @@ def __init__(self): def forward(self, x): return self.fc1(self.fc(x)) +<<<<<<< HEAD device_mesh = init_device_mesh(device_type, (self.world_size,)) model = TestModel().to(device_type) +======= + device_mesh = init_device_mesh("cuda", (self.world_size,)) + model = TestModel().cuda() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) parallelize_module( model, device_mesh, @@ -1020,7 +1161,11 @@ def _test_multi( optim = torch.optim.AdamW(**optim_kwargs) optim.zero_grad() +<<<<<<< HEAD model(torch.randn(64, 64, device=device_type)).sum().backward() +======= + model(torch.randn(64, 64).cuda()).sum().backward() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) optim.step() optim.zero_grad() @@ -1073,7 +1218,11 @@ def setUp(self) -> None: @skip_if_lt_x_gpu(1) def test_no_dist(self) -> None: +<<<<<<< HEAD model = CompositeParamModel(device=torch.device(device_type)) +======= + model = CompositeParamModel(device=torch.device("cuda")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) optim = torch.optim.AdamW(model.parameters(), lr=1e-4) self.assertFalse(dist.is_initialized()) diff --git a/test/distributed/checkpoint/test_state_dict_stager.py b/test/distributed/checkpoint/test_state_dict_stager.py index 8134472f52d5c..36b73838c71a8 100644 --- a/test/distributed/checkpoint/test_state_dict_stager.py +++ b/test/distributed/checkpoint/test_state_dict_stager.py @@ -1,6 +1,7 @@ # Owner(s): ["oncall: distributed"] import dataclasses +<<<<<<< HEAD import os import tempfile from datetime import timedelta @@ -18,6 +19,14 @@ from torch.distributed.checkpoint._state_dict_stager import StateDictStager from torch.distributed.checkpoint.staging import _ReplicationStager from torch.distributed.tensor import DeviceMesh, distribute_tensor +======= + +import torch +import torch.distributed as dist +from torch.distributed._tensor import DTensor +from torch.distributed._tensor.placement_types import Shard +from torch.distributed.checkpoint._state_dict_stager import StateDictStager +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu from torch.testing._internal.common_utils import requires_cuda, run_tests, TestCase from torch.testing._internal.distributed._tensor.common_dtensor import ( @@ -826,6 +835,7 @@ def test_dtensor(self): ) ) self.assertEqual(cpu_state_dict["dtensor"]._spec, dtensor._spec) +<<<<<<< HEAD self.assertEqual(cpu_state_dict["dtensor"].size(), dtensor.size()) @@ -1345,6 +1355,8 @@ def test_replication_persistence(self): # Clean up stager.close() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": diff --git a/test/distributed/checkpoint/test_state_dict_utils.py b/test/distributed/checkpoint/test_state_dict_utils.py index 76e9aeb9e3302..db8963ca7fe18 100644 --- a/test/distributed/checkpoint/test_state_dict_utils.py +++ b/test/distributed/checkpoint/test_state_dict_utils.py @@ -220,6 +220,7 @@ def _verify(cpu_state_dict): self.assertEqual(cpu_state_dict["step"], 7) self.assertEqual(cpu_state_dict["nested"], {"list": [1, 2, 3, 4]}) +<<<<<<< HEAD def _verify_weakref_finalize(cpu_state_dict): import gc @@ -227,6 +228,8 @@ def _verify_weakref_finalize(cpu_state_dict): del cpu_state_dict gc.collect() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cpu_state_dict = _create_cpu_state_dict(state_dict) _verify(cpu_state_dict) cpu_state_dict = _create_cpu_state_dict(state_dict, pin_memory=True) @@ -237,7 +240,10 @@ def _verify_weakref_finalize(cpu_state_dict): state_dict, share_memory=True, pin_memory=True ) _verify(cpu_state_dict) +<<<<<<< HEAD _verify_weakref_finalize(cpu_state_dict) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @with_comms @skip_if_lt_x_gpu(2) diff --git a/test/distributed/elastic/rendezvous/api_test.py b/test/distributed/elastic/rendezvous/api_test.py index 938353a9ffa09..f225c922aad94 100644 --- a/test/distributed/elastic/rendezvous/api_test.py +++ b/test/distributed/elastic/rendezvous/api_test.py @@ -140,6 +140,7 @@ def test_get_as_bool_returns_false_if_value_represents_false(self) -> None: self.assertFalse(params.get_as_bool("dummy_param")) def test_get_as_bool_raises_error_if_value_is_invalid(self) -> None: +<<<<<<< HEAD for value in [ "01", "Flse", # codespell:ignore @@ -151,6 +152,9 @@ def test_get_as_bool_raises_error_if_value_is_invalid(self) -> None: 2, -1, ]: +======= + for value in ["01", "Flse", "Ture", "g", "4", "_", "truefalse", 2, -1]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with self.subTest(value=value): self._kwargs["dummy_param"] = value diff --git a/test/distributed/elastic/test_control_plane.py b/test/distributed/elastic/test_control_plane.py index 9b31cf3b1755b..fcc074ca20142 100644 --- a/test/distributed/elastic/test_control_plane.py +++ b/test/distributed/elastic/test_control_plane.py @@ -71,9 +71,15 @@ def test_worker_server(self) -> None: self.assertEqual(resp.status, 200) self.assertIn("ping", json.loads(resp.data)) +<<<<<<< HEAD resp = pool.request("POST", "/handler/nonexistent") self.assertEqual(resp.status, 404) self.assertIn(b"Handler nonexistent not found:", resp.data) +======= + resp = pool.request("POST", "/handler/nonexistant") + self.assertEqual(resp.status, 404) + self.assertIn(b"Handler nonexistant not found:", resp.data) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @requires_cuda def test_dump_nccl_trace_pickle(self) -> None: @@ -207,8 +213,13 @@ def set_status(self, status: int) -> None: def test_get_handler_nonexistant(self) -> None: from torch._C._distributed_c10d import _get_handler +<<<<<<< HEAD with self.assertRaisesRegex(ValueError, "Failed to find handler nonexistent"): _get_handler("nonexistent") +======= + with self.assertRaisesRegex(ValueError, "Failed to find handler nonexistant"): + _get_handler("nonexistant") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_get_handler_names(self) -> None: from torch._C._distributed_c10d import _get_handler_names diff --git a/test/distributed/fsdp/test_distributed_checkpoint.py b/test/distributed/fsdp/test_distributed_checkpoint.py index c80602c5d50f3..418a3a0604347 100644 --- a/test/distributed/fsdp/test_distributed_checkpoint.py +++ b/test/distributed/fsdp/test_distributed_checkpoint.py @@ -31,10 +31,17 @@ sys.exit(0) +<<<<<<< HEAD _DISTRIBUTED_STATE_DICT_IMPLS = ( StateDictType.LOCAL_STATE_DICT, StateDictType.SHARDED_STATE_DICT, ) +======= +_DISTRIBUTED_STATE_DICT_IMPLS = { + StateDictType.LOCAL_STATE_DICT, + StateDictType.SHARDED_STATE_DICT, +} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestDistributedCheckpoint(FSDPTest): diff --git a/test/distributed/fsdp/test_fsdp_comm_hooks.py b/test/distributed/fsdp/test_fsdp_comm_hooks.py index 624e74d373686..46fd1bbb64afa 100644 --- a/test/distributed/fsdp/test_fsdp_comm_hooks.py +++ b/test/distributed/fsdp/test_fsdp_comm_hooks.py @@ -13,7 +13,11 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy from torch.distributed.fsdp.wrap import ModuleWrapPolicy from torch.testing._internal.common_distributed import ( +<<<<<<< HEAD requires_accelerator_dist_backend, +======= + requires_nccl, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) requires_nccl_version, skip_but_pass_in_sandcastle_if, skip_if_lt_x_gpu, @@ -30,6 +34,7 @@ print("Distributed not available, skipping tests", file=sys.stderr) sys.exit(0) +<<<<<<< HEAD device_type = ( acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu" ) @@ -40,12 +45,23 @@ and (torch.version.cuda is not None or torch.version.hip is not None) ) or torch.xpu.is_available() +======= +# bfloat16 is only supported by CUDA 11+ +BFLOAT16_AVAILABLE = torch.cuda.is_available() and ( + torch.version.cuda is not None or torch.version.hip is not None +) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class Net(nn.Module): def __init__(self, has_wrapping, sharding_strategy, mixed_precision=None): # to ensure determinism torch.manual_seed(0) +<<<<<<< HEAD torch.get_device_module(device_type).manual_seed(0) +======= + torch.cuda.manual_seed(0) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__() if has_wrapping: @@ -55,12 +71,20 @@ def __init__(self, has_wrapping, sharding_strategy, mixed_precision=None): nn.ReLU(), FSDP( nn.Linear(16, 8), +<<<<<<< HEAD device_id=torch.accelerator.current_device_index(), +======= + device_id=torch.cuda.current_device(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sharding_strategy=sharding_strategy, mixed_precision=mixed_precision, ), ), +<<<<<<< HEAD device_id=torch.accelerator.current_device_index(), +======= + device_id=torch.cuda.current_device(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sharding_strategy=sharding_strategy, mixed_precision=mixed_precision, ) @@ -139,11 +163,19 @@ def test_default_communication_hook_behavior( """ out_dim = self.world_size net = torch.nn.Linear(1, out_dim, bias=False) +<<<<<<< HEAD inpt = torch.tensor([self.rank]).float().to(self.rank) net_default_hook = FSDP( net, device_id=torch.accelerator.current_device_index(), +======= + inpt = torch.tensor([self.rank]).float().cuda(self.rank) + + net_default_hook = FSDP( + net, + device_id=torch.cuda.current_device(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sharding_strategy=sharding_strategy, ).to(self.rank) @@ -177,10 +209,17 @@ def _get_submodules(self, fsdp_net): ] def _init_model(self, core, sharding_strategy, mixed_precision=None): +<<<<<<< HEAD device = torch.device(device_type) return FSDP( core, device_id=torch.accelerator.current_device_index(), +======= + device = torch.device("cuda") + return FSDP( + core, + device_id=torch.cuda.current_device(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sharding_strategy=sharding_strategy, mixed_precision=mixed_precision, ).to(device) @@ -282,7 +321,11 @@ def test_registering_hook_hybrid_strategy(self): ShardingStrategy.HYBRID_SHARD, ShardingStrategy._HYBRID_SHARD_ZERO2, ): +<<<<<<< HEAD model = Net(False, None, None).to(device=device_type) +======= + model = Net(False, None, None).cuda() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fsdp_model = FSDP( model, auto_wrap_policy=ModuleWrapPolicy({nn.Linear}), @@ -342,7 +385,11 @@ def _check_low_precision_hook( ): # keep everything deterministic for input data torch.manual_seed(0) +<<<<<<< HEAD torch.get_device_module(device_type).manual_seed(0) +======= + torch.cuda.manual_seed(0) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fsdp_with_hook = self._init_model( Net(has_wrapping=has_wrapping, sharding_strategy=sharding_strategy), @@ -364,7 +411,11 @@ def _check_low_precision_hook( optim_hook = torch.optim.SGD(fsdp_with_hook.parameters(), lr=0.1) optim_mp = torch.optim.SGD(fsdp_with_mp.parameters(), lr=0.1) +<<<<<<< HEAD in_data = torch.rand(16, 8).to(device=device_type) +======= + in_data = torch.rand(16, 8).cuda() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fsdp_with_hook.train() fsdp_with_mp.train() loss_hook = fsdp_with_hook(in_data).sum() @@ -383,7 +434,11 @@ def _check_low_precision_hook( ): self.assertEqual(hook_param.grad, mp_param.grad) +<<<<<<< HEAD @requires_accelerator_dist_backend(["nccl", "xccl"]) +======= + @requires_nccl() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skip_if_lt_x_gpu(2) @parametrize("has_wrapping", [True, False]) @parametrize( @@ -404,11 +459,19 @@ def test_fp16_hook( state, hook, sharding_strategy, torch.float16, has_wrapping ) +<<<<<<< HEAD @requires_accelerator_dist_backend(["nccl", "xccl"]) @requires_nccl_version((2, 10), "Need NCCL 2.10+ for BF16_COMPRESS") @skip_but_pass_in_sandcastle_if( not BFLOAT16_AVAILABLE, "BFloat16 is only supported by CUDA 11+ or XPU", +======= + @requires_nccl() + @requires_nccl_version((2, 10), "Need NCCL 2.10+ for BF16_COMPRESS") + @skip_but_pass_in_sandcastle_if( + not BFLOAT16_AVAILABLE, + "BFloat16 is only supported by CUDA 11+", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @skip_if_lt_x_gpu(2) @parametrize("has_wrapping", [True, False]) diff --git a/test/distributed/fsdp/test_fsdp_flatten_params.py b/test/distributed/fsdp/test_fsdp_flatten_params.py index 12e432f214f30..89a3fd06f1fb3 100644 --- a/test/distributed/fsdp/test_fsdp_flatten_params.py +++ b/test/distributed/fsdp/test_fsdp_flatten_params.py @@ -44,11 +44,16 @@ def world_size(self) -> int: return 1 def _get_default_config(self): +<<<<<<< HEAD device_type = ( acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" ) return { "device": torch.device(device_type), +======= + return { + "device": torch.device("cuda"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "sharding_strategy": HandleShardingStrategy.FULL_SHARD, "offload_params": False, "mp_param_dtype": None, @@ -650,6 +655,7 @@ def test_flat_param_shard_metadata_with_memory_format(self, memory_format): ), ) +<<<<<<< HEAD @skip_if_lt_x_gpu(1) def test_writeback_orig_params_no_shard(self): class EmbeddingModel(nn.Module): @@ -680,6 +686,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: out = fsdp_model(x) self.assertEqual(out.shape, torch.Size([])) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) instantiate_parametrized_tests(TestFlattenParams) diff --git a/test/distributed/fsdp/test_fsdp_freezing_weights.py b/test/distributed/fsdp/test_fsdp_freezing_weights.py index ad318a6bf7520..03d27a42fad19 100644 --- a/test/distributed/fsdp/test_fsdp_freezing_weights.py +++ b/test/distributed/fsdp/test_fsdp_freezing_weights.py @@ -31,8 +31,11 @@ ) sys.exit(0) +<<<<<<< HEAD device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class Model(nn.Module): def __init__( @@ -49,6 +52,10 @@ def __init__( nn.AdaptiveAvgPool2d(output_size=(1, 1)), nn.Flatten(), ) +<<<<<<< HEAD +======= + self.device = torch.cuda.current_device() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.head = nn.Linear(64, 10) if with_fsdp and freeze_after_wrap_fsdp: self.fsdp_wrap(fsdp_kwargs) @@ -146,7 +153,11 @@ def _dist_train( forward_prefetch, ): torch.manual_seed(0) +<<<<<<< HEAD batch = torch.randn(size=(2, 3, 224, 224)).to(device_type) +======= + batch = torch.randn(size=(2, 3, 224, 224)).cuda() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fsdp_kwargs = { "device_id": self.rank, @@ -165,7 +176,11 @@ def _dist_train( disable_autograd, fsdp_kwargs, ) +<<<<<<< HEAD model = model.to(device_type) +======= + model = model.cuda() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # freezing the trunk using requires_grad. if freezing_method == FreezingMethod.RequiresGrad: @@ -179,7 +194,11 @@ def _dist_train( else: model = DistributedDataParallel(model, **ddp_kwargs) +<<<<<<< HEAD target = torch.tensor([0, 1], dtype=torch.long).to(device_type) +======= + target = torch.tensor([0, 1], dtype=torch.long).cuda() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9) diff --git a/test/distributed/fsdp/test_fsdp_hybrid_shard.py b/test/distributed/fsdp/test_fsdp_hybrid_shard.py index 26a05bbc41714..d336102a367c0 100644 --- a/test/distributed/fsdp/test_fsdp_hybrid_shard.py +++ b/test/distributed/fsdp/test_fsdp_hybrid_shard.py @@ -49,8 +49,11 @@ ) sys.exit(0) +<<<<<<< HEAD device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @contextlib.contextmanager def patch_allreduce(new_allreduce): @@ -99,7 +102,11 @@ class ShardingStrategyMode(Enum): class TestFSDPHybridShard(FSDPTest): @property def world_size(self): +<<<<<<< HEAD return max(torch.accelerator.device_count(), 2) +======= + return max(torch.cuda.device_count(), 2) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @property def process_group(self): @@ -107,7 +114,11 @@ def process_group(self): @skip_if_lt_x_gpu(2) def test_raises_manual_wrap_hybrid_shard_when_none_policy(self): +<<<<<<< HEAD model = MyModel().to(device_type) +======= + model = MyModel().cuda() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) err_ctx = self.assertRaisesRegex( ValueError, "requires explicit specification of process group or device_mesh.", @@ -121,8 +132,13 @@ def test_raises_manual_wrap_hybrid_shard_when_none_policy(self): @skip_if_lt_x_gpu(4) def test_hsdp_save_load_state_dict(self): +<<<<<<< HEAD model = MyModel().to(device_type) num_node_devices = torch.accelerator.device_count() +======= + model = MyModel().cuda() + num_node_devices = torch.cuda.device_count() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shard_rank_lists = ( list(range(0, num_node_devices // 2)), list(range(num_node_devices // 2, num_node_devices)), @@ -163,7 +179,11 @@ def test_hsdp_save_load_state_dict(self): msd = model.state_dict() osd = FSDP.optim_state_dict(model, optim) +<<<<<<< HEAD load_model = fsdp_ctor(MyModel().to(device_type)) +======= + load_model = fsdp_ctor(MyModel().cuda()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) load_optim = torch.optim.AdamW(load_model.parameters()) with FSDP.state_dict_type(load_model, StateDictType.SHARDED_STATE_DICT): load_model.load_state_dict(msd) @@ -172,8 +192,13 @@ def test_hsdp_save_load_state_dict(self): @skip_if_lt_x_gpu(4) def test_hsdp_sync_module_state(self): +<<<<<<< HEAD model = MyModel().to(device_type) num_node_devices = torch.accelerator.device_count() +======= + model = MyModel().cuda() + num_node_devices = torch.cuda.device_count() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shard_rank_lists = ( list(range(0, num_node_devices // 2)), list(range(num_node_devices // 2, num_node_devices)), @@ -216,7 +241,11 @@ def test_hsdp_sync_module_state(self): @skip_if_lt_x_gpu(2) def test_invalid_pg_specification_raises(self): pol = ModuleWrapPolicy({nn.Linear}) +<<<<<<< HEAD model = MyModel().to(device_type) +======= + model = MyModel().cuda() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with self.assertRaisesRegex( ValueError, "Expected process_group to be passed in" ): @@ -262,7 +291,11 @@ def _test_fsdp_hybrid_shard_basic_setup( use_device_mesh: bool, ): if use_device_mesh: +<<<<<<< HEAD device_mesh = init_device_mesh(device_type, (1, self.world_size)) +======= + device_mesh = init_device_mesh("cuda", (1, self.world_size)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: device_mesh = None hsdp_model = self._init_hsdp_model( @@ -318,7 +351,11 @@ def patched_collective(orig_collective, counter, *args, **kwargs): patch_allreduce(patched_allreduce), patch_reduce_scatter(patched_reduce_scatter), ): +<<<<<<< HEAD inp = hsdp_model.get_input(device=torch.accelerator.current_device_index()) +======= + inp = hsdp_model.get_input(device=torch.cuda.current_device()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out = hsdp_model(inp[0], inp[1]) loss = hsdp_model.get_loss(inp, out) loss.backward() @@ -367,7 +404,11 @@ def _test_fsdp_hybrid_shard_parity( hsdp_optim = torch.optim.Adam(hsdp_model.parameters(), lr=1e-2) torch.manual_seed(global_pg.rank() + 1) for _ in range(5): +<<<<<<< HEAD inp = fsdp_model.module.get_input(torch.device(device_type)) +======= + inp = fsdp_model.module.get_input(torch.device("cuda")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) losses: list[torch.Tensor] = [] for model, optim in ((fsdp_model, fsdp_optim), (hsdp_model, hsdp_optim)): optim.zero_grad() @@ -383,7 +424,11 @@ def _init_fsdp_model(self, use_orig_params: bool) -> nn.Module: ) hsdp_kwargs = { "auto_wrap_policy": auto_wrap_policy, +<<<<<<< HEAD "device_id": torch.accelerator.current_device_index(), +======= + "device_id": torch.cuda.current_device(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "use_orig_params": use_orig_params, } fsdp_model = TransformerWithSharedParams.init( @@ -410,7 +455,11 @@ def _init_hsdp_model( {TransformerEncoderLayer, TransformerDecoderLayer}, ) hsdp_kwargs = { +<<<<<<< HEAD "device_id": torch.accelerator.current_device_index(), +======= + "device_id": torch.cuda.current_device(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "auto_wrap_policy": auto_wrap_policy, "sharding_strategy": hsdp_sharding_strategy, "use_orig_params": use_orig_params, @@ -437,7 +486,11 @@ def _init_hsdp_model( # Use `FULL_SHARD` for the embedding and output projection hsdp_model = FSDP( model, +<<<<<<< HEAD device_id=torch.accelerator.current_device_index(), +======= + device_id=torch.cuda.current_device(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sharding_strategy=ShardingStrategy.FULL_SHARD, use_orig_params=use_orig_params, ) diff --git a/test/distributed/fsdp/test_fsdp_ignored_modules.py b/test/distributed/fsdp/test_fsdp_ignored_modules.py index d8974327ea5dd..6b6f8c0dd41a3 100644 --- a/test/distributed/fsdp/test_fsdp_ignored_modules.py +++ b/test/distributed/fsdp/test_fsdp_ignored_modules.py @@ -36,8 +36,11 @@ ) sys.exit(0) +<<<<<<< HEAD device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class Model(torch.nn.Module): def __init__(self) -> None: @@ -96,9 +99,15 @@ def __init__(self, num_ignored: int) -> None: class TestFSDPIgnoredModules(FSDPTest): @property def world_size(self): +<<<<<<< HEAD return min(torch.accelerator.device_count(), 2) def _train_model(self, model, optim, num_iters, device=torch.device(device_type)): +======= + return min(torch.cuda.device_count(), 2) + + def _train_model(self, model, optim, num_iters, device=torch.device("cuda")): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for _ in range(num_iters): module = model.module if isinstance(model, FSDP) else model inp = module.get_input(device) @@ -200,7 +209,11 @@ def _test_ignored_modules_nested(self, use_orig_params: bool, ignore_modules: bo # Initialize an FSDP-wrapped nested model that first wraps the nested # sequential's second linear layer (`layer1[1]`) and then wraps the # overall model while ignoring the nested sequential (`layer1`) +<<<<<<< HEAD model = Model().to(device_type) +======= + model = Model().cuda() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fsdp_fn = functools.partial(FSDP, use_orig_params=use_orig_params) model.layer1[1] = fsdp_fn(model.layer1[1]) if ignore_modules: @@ -248,7 +261,11 @@ def test_ignored_states_auto_wrap(self): ) def _test_ignored_states_auto_wrap(self, policy, ignore_bias: bool): +<<<<<<< HEAD model = Model().to(device_type) +======= + model = Model().cuda() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ignored_states = [model.layer1[1].weight] if ignore_bias: ignored_states.append(model.layer1[1].bias) @@ -287,7 +304,11 @@ def _test_ignored_states_auto_wrap(self, policy, ignore_bias: bool): def test_ignored_modules_invalid(self): """Tests that passing an FSDP module as an ignored module or the top-level module itself errors.""" +<<<<<<< HEAD model = Model().to(device_type) +======= + model = Model().cuda() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) wrap_cls = FSDP model.layer1 = wrap_cls(model.layer1) # Passing an FSDP module as an ignored module should error @@ -304,7 +325,11 @@ def test_ignored_modules_invalid(self): ): # FSDP does not allow to wrap the same model twice, so create # a new local model here. +<<<<<<< HEAD new_model = Model().to(device_type) +======= + new_model = Model().cuda() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) wrap_cls(new_model, ignored_modules=[new_model]) @skip_if_lt_x_gpu(2) @@ -336,7 +361,11 @@ def _test_diff_ignored_modules_across_ranks( # we wrap `layer3` with FSDP, where `layer3` is registered as a module # after `layer1`, which has the variable number of ignored modules wrap_cls = FSDP +<<<<<<< HEAD model = ModelWithIgnoredModules(num_ignored=self.rank + 1).to(device_type) +======= + model = ModelWithIgnoredModules(num_ignored=self.rank + 1).cuda() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) layer1_ignored_modules = [ m for m in model.layer1.modules() if isinstance(m, IgnoredModule) ] @@ -372,7 +401,11 @@ def _test_diff_ignored_modules_across_ranks( @skip_if_lt_x_gpu(2) @parametrize("ignore_modules", [True, False]) def test_ignored_modules_not_under_wrapped_root(self, ignore_modules: bool): +<<<<<<< HEAD model = Model().to(device_type) +======= + model = Model().cuda() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ignored_modules = list(model.layer1.children())[1:] ignore_kwargs = ( @@ -411,7 +444,11 @@ def test_ignored_states_check(self): ) def _test_ignored_states_check(self, ignore_modules: bool): +<<<<<<< HEAD model = Model().to(device_type) +======= + model = Model().cuda() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ignored_modules = list(model.layer1.children())[1:] ignored_params = {p for m in ignored_modules for p in m.parameters()} ignored_states = ignored_params.union(set(ignored_modules)) diff --git a/test/distributed/fsdp/test_fsdp_memory.py b/test/distributed/fsdp/test_fsdp_memory.py index 93391f01b376d..dcc0cf0343f5d 100644 --- a/test/distributed/fsdp/test_fsdp_memory.py +++ b/test/distributed/fsdp/test_fsdp_memory.py @@ -14,7 +14,10 @@ instantiate_parametrized_tests, parametrize, run_tests, +<<<<<<< HEAD TEST_CUDA, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TEST_HPU, TEST_WITH_DEV_DBG_ASAN, ) @@ -32,6 +35,7 @@ ) sys.exit(0) +<<<<<<< HEAD device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" @@ -40,6 +44,13 @@ def get_cur_mem(rank, result, prefix): if TEST_CUDA: torch._C._cuda_clearCublasWorkspaces() result[prefix] = round(torch.accelerator.memory_allocated() / 1024 / 1024) +======= + +def get_cur_mem(rank, result, prefix): + """Collect memory allocated values in a result dict in MB""" + torch._C._cuda_clearCublasWorkspaces() + result[prefix] = round(torch.cuda.memory_allocated() / 1024 / 1024) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class Model(nn.Module): @@ -114,14 +125,22 @@ def world_size(self): def _dist_train(self, with_checkpoint, expected, model_hidden_dim, iterations): gpu_id = self.rank +<<<<<<< HEAD batch = torch.randn(size=(2, 3, 224, 224)).to(device_type) +======= + batch = torch.randn(size=(2, 3, 224, 224)).cuda() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) model = create_model( with_fsdp=True, with_checkpoint=with_checkpoint, model_hidden_dim=model_hidden_dim, ) +<<<<<<< HEAD model = model.to(device_type) +======= + model = model.cuda() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) model = FSDP(model) # We enable momentum so that after the first iteration, the optimizer state is added @@ -137,7 +156,11 @@ def _dist_train(self, with_checkpoint, expected, model_hidden_dim, iterations): get_cur_mem(gpu_id, results, f"iter {iteration}: after fwd") out = sum(o.sum() for o in out[0]) +<<<<<<< HEAD fake_loss = criterion(out, torch.tensor(0.0).to(device_type)) +======= + fake_loss = criterion(out, torch.tensor(0.0).cuda()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) get_cur_mem(gpu_id, results, f"iter {iteration}: after loss") fake_loss.backward() @@ -162,7 +185,11 @@ def cmp(results, expected): output = cmp(results, expected) self.assertEqual(output, "") +<<<<<<< HEAD @unittest.skipIf(TEST_HPU, "Memory will be different for CUDA and HPU, skipping") +======= + @unittest.skipIf(TEST_HPU, "Memory will be differnt for CUDA and HPU, skipping") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skip_if_lt_x_gpu(2) @parametrize("ckpt", ["no_ckpt", "ckpt"]) def test_fsdp_memory(self, ckpt): @@ -171,8 +198,13 @@ def test_fsdp_memory(self, ckpt): model = create_model( with_fsdp=False, with_checkpoint=False, model_hidden_dim=model_hidden_dim +<<<<<<< HEAD ).to(device_type) model_size_mb = round(torch.accelerator.memory_allocated() / 1024 / 1024) +======= + ).cuda() + model_size_mb = round(torch.cuda.memory_allocated() / 1024 / 1024) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) del model sharded_model_size_mb = int(model_size_mb / self.world_size) diff --git a/test/distributed/fsdp/test_fsdp_meta.py b/test/distributed/fsdp/test_fsdp_meta.py index d3b0079a24adc..1586983f00e06 100644 --- a/test/distributed/fsdp/test_fsdp_meta.py +++ b/test/distributed/fsdp/test_fsdp_meta.py @@ -43,8 +43,11 @@ ) sys.exit(0) +<<<<<<< HEAD device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _reset_params_if_meta(is_meta: bool, model: nn.Module): # For torchdistX init, we don't need to call reset_params, as @@ -119,7 +122,11 @@ def _init_with_reset_params(module: nn.Module): ) ) if has_meta_states: +<<<<<<< HEAD device = torch.device(device_type, torch.accelerator.current_device_index()) +======= + device = torch.device("cuda", torch.cuda.current_device()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) module.to_empty(device=device, recurse=False) module.reset_parameters() @@ -166,13 +173,21 @@ def _test_simple_model_with_meta_device(self, meta_module_fn, init_fn=None): # Test to make sure it is the same model parameters as regular FSDP # approach. +<<<<<<< HEAD regular = MyModel(device=device_type) +======= + regular = MyModel(device="cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _reset_params_if_meta(is_meta, regular) fsdp_regular = FSDP(regular, auto_wrap_policy=always_wrap) regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3) self._compare_fsdp(fsdp_meta, fsdp_regular) +<<<<<<< HEAD inp = torch.randn(10, 2, device=device_type) +======= + inp = torch.randn(10, 2, device="cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fsdp_meta(inp).sum().backward() fsdp_regular(inp).sum().backward() meta_opt.step() @@ -184,7 +199,11 @@ def _test_simple_model_with_meta_device(self, meta_module_fn, init_fn=None): model = meta_module_fn() fsdp_meta = FSDP(model, param_init_fn=init_fn) meta_opt = torch.optim.SGD(fsdp_meta.parameters(), lr=1e-3) +<<<<<<< HEAD regular = MyModel(device=device_type) +======= + regular = MyModel(device="cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _reset_params_if_meta(is_meta, regular) fsdp_regular = FSDP(regular, auto_wrap_policy=always_wrap) regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3) @@ -219,7 +238,11 @@ def meta_module_fn(): ) def test_simple_model_with_torchdistX_default_init(self): def meta_module_fn(): +<<<<<<< HEAD return deferred_init.deferred_init(MyModel, device=device_type) +======= + return deferred_init.deferred_init(MyModel, device="cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._test_simple_model_with_meta_device(meta_module_fn) @@ -230,7 +253,11 @@ def meta_module_fn(): ) def test_simple_model_with_torchdistX_init_fn(self): def meta_module_fn(): +<<<<<<< HEAD return deferred_init.deferred_init(MyModel, device=device_type) +======= + return deferred_init.deferred_init(MyModel, device="cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._test_simple_model_with_meta_device( meta_module_fn, init_fn=_init_with_torchdistX @@ -250,7 +277,11 @@ def _test_nested_model_with_meta_device( param_init_fn=init_fn, ) meta_opt = torch.optim.SGD(fsdp_meta.parameters(), lr=1e-3) +<<<<<<< HEAD module_regular = NestedModel(device=device_type) +======= + module_regular = NestedModel(device="cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _reset_params_if_meta(is_meta, module_regular) fsdp_regular = FSDP( module_regular, @@ -271,7 +302,11 @@ def _test_nested_model_with_meta_device( # Init and reset parameters before wrapping so that reset_params # matches up with meta device's initialization. +<<<<<<< HEAD module_regular = NestedModel(device=device_type) +======= + module_regular = NestedModel(device="cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _reset_params_if_meta(is_meta, module_regular) with enable_wrap(wrapper_cls=FSDP): module_regular.lin1 = wrap(module_regular.lin1) @@ -281,7 +316,11 @@ def _test_nested_model_with_meta_device( # Compare it before training self._compare_fsdp(fsdp_meta, fsdp_regular) +<<<<<<< HEAD inp = torch.randn(10, 2, device=device_type) +======= + inp = torch.randn(10, 2, device="cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fsdp_meta(inp).sum().backward() fsdp_regular(inp).sum().backward() meta_opt.step() @@ -319,7 +358,11 @@ def meta_module_fn(): @parametrize("auto_wrap", [True, False]) def test_nested_model_with_torchdistX_default_init(self, auto_wrap): def meta_module_fn(): +<<<<<<< HEAD return deferred_init.deferred_init(NestedModel, device=device_type) +======= + return deferred_init.deferred_init(NestedModel, device="cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._test_nested_model_with_meta_device( auto_wrap=auto_wrap, meta_module_fn=meta_module_fn @@ -333,7 +376,11 @@ def meta_module_fn(): @parametrize("auto_wrap", [True, False]) def test_nested_model_with_torchdistX_init_fn(self, auto_wrap): def meta_module_fn(): +<<<<<<< HEAD return deferred_init.deferred_init(NestedModel, device=device_type) +======= + return deferred_init.deferred_init(NestedModel, device="cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._test_nested_model_with_meta_device( auto_wrap=auto_wrap, @@ -353,7 +400,11 @@ def _test_bad_arg(self, meta_module_fn): ) def test_bad_arg_torchdistx(self): def meta_module_fn(): +<<<<<<< HEAD return deferred_init.deferred_init(NestedModel, device_type) +======= + return deferred_init.deferred_init(NestedModel, "cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._test_bad_arg(meta_module_fn) @@ -403,7 +454,11 @@ def _param_init_fn(module: nn.Module) -> None: # TODO: `module.to_empty()` is not generally correct for meta # device initialization. # https://github.com/pytorch/pytorch/issues/90465 +<<<<<<< HEAD module.to_empty(device=torch.device(device_type)) +======= + module.to_empty(device=torch.device("cuda")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) module.apply(model._module_init_fn) model = Model() @@ -416,7 +471,11 @@ def _param_init_fn(module: nn.Module) -> None: param_dtype=torch.float32, reduce_dtype=torch.float16 ), param_init_fn=_param_init_fn, +<<<<<<< HEAD device_id=torch.accelerator.current_device_index(), +======= + device_id=torch.cuda.current_device(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) diff --git a/test/distributed/fsdp/test_fsdp_misc.py b/test/distributed/fsdp/test_fsdp_misc.py index 45c1668dfb2e2..8aa0ca78235e9 100644 --- a/test/distributed/fsdp/test_fsdp_misc.py +++ b/test/distributed/fsdp/test_fsdp_misc.py @@ -60,10 +60,13 @@ ) sys.exit(0) +<<<<<<< HEAD device_type = ( acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu" ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class MyModel(nn.Module): def __init__(self) -> None: @@ -97,9 +100,15 @@ def test_fsdp_device_id(self, use_index): without specifying a device ID (i.e. ``torch.device("cuda")``) warns """ dev_id = ( +<<<<<<< HEAD torch.accelerator.current_device_index() if use_index else torch.device(device_type, torch.accelerator.current_device_index()) +======= + torch.cuda.current_device() + if use_index + else torch.device("cuda", torch.cuda.current_device()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def _check_device_matches(module, device_id): @@ -112,7 +121,11 @@ def _check_device_matches(module, device_id): self.assertEqual(1, len(devices)) found_device = devices.pop() if use_index and not isinstance(device_id, torch.device): +<<<<<<< HEAD device = torch.device(device_type, device_id) +======= + device = torch.device("cuda", device_id) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: device = device_id self.assertEqual(found_device, device) @@ -144,11 +157,18 @@ def _check_device_matches(module, device_id): self.process_group, FSDPInitMode.RECURSIVE, DEVICEInitMode.DEVICE_BEFORE, +<<<<<<< HEAD fsdp_kwargs={"device_id": torch.device(device_type)}, ) _check_device_matches( nested_wrapped_module, torch.device(device_type, torch.accelerator.current_device_index()), +======= + fsdp_kwargs={"device_id": torch.device("cuda")}, + ) + _check_device_matches( + nested_wrapped_module, torch.device("cuda", torch.cuda.current_device()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @skip_if_lt_x_gpu(2) @@ -183,8 +203,13 @@ def forward(self, x, y): loss = torch.nn.functional.cross_entropy(output, y) return loss +<<<<<<< HEAD model = Mnist().to(device=device_type) model1 = Mnist().to(device=device_type) +======= + model = Mnist().cuda() + model1 = Mnist().cuda() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) model1.load_state_dict(model.state_dict()) fsdp_model = FSDP( model, @@ -202,17 +227,30 @@ def forward(self, x, y): seed = self.rank + 20231010 torch.manual_seed(seed) +<<<<<<< HEAD torch.get_device_module(device_type).manual_seed(seed) +======= + torch.cuda.manual_seed(seed) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) losses = [] grads = [] for i in range(5): +<<<<<<< HEAD x = torch.randn(8, 1, 28, 28, device=device_type).requires_grad_() y = torch.randint(low=0, high=9, size=(8,), device=device_type) for model, opt in ((fsdp_model, fsdp_opt), (ddp_model, ddp_opt)): seed = self.rank + i torch.manual_seed(seed) torch.get_device_module(device_type).manual_seed(seed) +======= + x = torch.randn(8, 1, 28, 28, device="cuda").requires_grad_() + y = torch.randint(low=0, high=9, size=(8,), device="cuda") + for model, opt in ((fsdp_model, fsdp_opt), (ddp_model, ddp_opt)): + seed = self.rank + i + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) loss = model(x, y).sum() losses.append(loss) loss.backward() @@ -228,8 +266,13 @@ def forward(self, x, y): fsdp_model.eval() ddp_model.eval() for _ in range(5): +<<<<<<< HEAD x = torch.randn(8, 1, 28, 28, device=device_type).requires_grad_() y = torch.randint(low=0, high=9, size=(8,), device=device_type) +======= + x = torch.randn(8, 1, 28, 28, device="cuda").requires_grad_() + y = torch.randint(low=0, high=9, size=(8,), device="cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fsdp_loss = fsdp_model(x, y) ddp_loss = ddp_model(x, y) assert torch.allclose(fsdp_loss, ddp_loss) @@ -237,12 +280,21 @@ def forward(self, x, y): fsdp_model.train() ddp_model.train() for i in range(5): +<<<<<<< HEAD x = torch.randn(8, 1, 28, 28, device=device_type).requires_grad_() y = torch.randint(low=0, high=9, size=(8,), device=device_type) for model, opt in ((fsdp_model, fsdp_opt), (ddp_model, ddp_opt)): seed = self.rank + i torch.manual_seed(seed) torch.get_device_module(device_type).manual_seed(seed) +======= + x = torch.randn(8, 1, 28, 28, device="cuda").requires_grad_() + y = torch.randint(low=0, high=9, size=(8,), device="cuda") + for model, opt in ((fsdp_model, fsdp_opt), (ddp_model, ddp_opt)): + seed = self.rank + i + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) loss = model(x, y).sum() losses.append(loss) loss.backward() @@ -277,12 +329,21 @@ def forward(self, x, y): return out1 fsdp = FSDP( +<<<<<<< HEAD MyModel().to(device=device_type), sharding_strategy=sharding_strategy, auto_wrap_policy=always_wrap_policy, ) x = torch.randn(10, 10, device=device_type) y = torch.randn(10, 10, device=device_type) +======= + MyModel().cuda(), + sharding_strategy=sharding_strategy, + auto_wrap_policy=always_wrap_policy, + ) + x = torch.randn(10, 10, device="cuda") + y = torch.randn(10, 10, device="cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for _ in range(4): if use_second_layer: a, _ = fsdp(x, y) @@ -341,7 +402,11 @@ def _check_equal(local, fsdp): torch.testing.assert_close(p1, p2) fsdp_ctor = functools.partial(FSDP, sharding_strategy=sharding_strategy) +<<<<<<< HEAD m = MyModule().to(device=device_type) +======= + m = MyModule().cuda() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) m_local = deepcopy(m) local_m = m_local prev_params = [p.clone() for p in m_local.parameters()] @@ -354,7 +419,11 @@ def _check_equal(local, fsdp): opt_local = torch.optim.SGD(local_m.parameters(), lr=1e-3) for i in range(6): +<<<<<<< HEAD t = torch.ones(4, device=device_type) +======= + t = torch.ones(4, device="cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) a, b = m(t) local_a, local_b = local_m(t) if i < 2: @@ -390,7 +459,11 @@ def _check_equal(local, fsdp): @skip_if_lt_x_gpu(2) def test_fsdp_optim_overlap_no_use_orig_params_error(self): fsdp_overlap = FSDP( +<<<<<<< HEAD MyModel().to(device=device_type), +======= + MyModel().cuda(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto_wrap_policy=always_wrap_policy, use_orig_params=False, ) @@ -403,7 +476,11 @@ def test_fsdp_optim_overlap_no_use_orig_params_error(self): register_hook=False, ) +<<<<<<< HEAD inp = torch.randn(10, 10, device=device_type) +======= + inp = torch.randn(10, 10, device="cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with self.assertRaisesRegex( RuntimeError, "only supported with use_orig_params=True" ): @@ -414,16 +491,27 @@ def test_fsdp_optimizer_overlap(self): torch.manual_seed(0) for cpu_offload in [True, False]: offload = CPUOffload(offload_params=cpu_offload) +<<<<<<< HEAD model = MyModel().to(device=device_type) model_overlap = deepcopy(model) fsdp = FSDP( model.to(device=device_type), +======= + model = MyModel().cuda() + model_overlap = deepcopy(model) + fsdp = FSDP( + model.cuda(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto_wrap_policy=always_wrap_policy, use_orig_params=True, cpu_offload=offload, ) fsdp_overlap = FSDP( +<<<<<<< HEAD model_overlap.to(device=device_type), +======= + model_overlap.cuda(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto_wrap_policy=always_wrap_policy, use_orig_params=True, cpu_offload=offload, @@ -450,7 +538,11 @@ def test_fsdp_optimizer_overlap(self): ] for i in range(6): +<<<<<<< HEAD inp = torch.randn(2, 2, device=device_type) +======= + inp = torch.randn(2, 2, device="cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with torch.no_grad(): inp_clone = inp.clone() fsdp(inp, inp).sum().backward() @@ -551,7 +643,11 @@ def test_fsdp_cpu_init_stays_on_cpu(self): """Tests that passing a CPU module to FSDP preserves that the wrapped module is on CPU after FSDP initialization, albeit after logging a warning, and that FSDP moves CPU input to GPU before the forward.""" +<<<<<<< HEAD torch.accelerator.set_device_index(self.rank) +======= + torch.cuda.set_device(self.rank) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) regex = "passed-in `module` is on CPU" context = self.assertWarnsRegex( expected_warning=UserWarning, expected_regex=regex @@ -566,7 +662,11 @@ def test_fsdp_cpu_init_stays_on_cpu(self): devices = {p.device for p in fsdp_model.parameters()} self.assertEqual(1, len(devices)) self.assertEqual(torch.device("cpu"), devices.pop()) +<<<<<<< HEAD fsdp_model = fsdp_model.to(device=device_type) +======= + fsdp_model = fsdp_model.cuda() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Ensure fwd + backward can be performed after moving to CUDA. # CPU input also tests that input is correctly moved to appropriate # CUDA device. @@ -611,19 +711,31 @@ def init_nested_wrapped_module(): nested_wrapped_module, self.process_group, auto_wrap_policy=ModuleWrapPolicy({nn.Linear}), +<<<<<<< HEAD device_id=torch.accelerator.current_device_index(), +======= + device_id=torch.cuda.current_device(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sync_module_states=True, ) # Each rank's buffers should be 0s since rank 0 is the source, and they # should be on GPU since we specified `device_id` self.assertEqual( nested_wrapped_module.buf.device, +<<<<<<< HEAD torch.device(device_type, torch.accelerator.current_device_index()), +======= + torch.device("cuda", torch.cuda.current_device()), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) self.assertEqual(nested_wrapped_module.buf, torch.zeros((2, 2))) self.assertEqual( nested_wrapped_module.module.module[0].buf.device, +<<<<<<< HEAD torch.device(device_type, torch.accelerator.current_device_index()), +======= + torch.device("cuda", torch.cuda.current_device()), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) self.assertEqual( nested_wrapped_module.module.module[0].buf, torch.zeros((3, 2)) @@ -649,9 +761,15 @@ def __init__(self) -> None: def forward(self, x): return x +<<<<<<< HEAD m = MyModule().to(device=device_type) m = FSDP(m) t = torch.ones(1, device=device_type, requires_grad=True) +======= + m = MyModule().cuda() + m = FSDP(m) + t = torch.ones(1, device="cuda", requires_grad=True) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MyOutputType = namedtuple( "MyOutputType", ["a", "b", "c", "d"], defaults=(t, t, t, t) @@ -688,7 +806,11 @@ def _test_device_id_auto_wrap(self, use_callable: bool): auto_wrap_policy = ModuleWrapPolicy(module_classes) fsdp_kwargs = { "auto_wrap_policy": auto_wrap_policy, +<<<<<<< HEAD "device_id": torch.accelerator.current_device_index(), +======= + "device_id": torch.cuda.current_device(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } fsdp_model = TransformerWithSharedParams.init( self.process_group, @@ -699,7 +821,11 @@ def _test_device_id_auto_wrap(self, use_callable: bool): for fsdp_module in FSDP.fsdp_modules(fsdp_model): self.assertEqual( fsdp_module.compute_device, +<<<<<<< HEAD torch.device(device_type, torch.accelerator.current_device_index()), +======= + torch.device("cuda", torch.cuda.current_device()), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @skip_if_lt_x_gpu(2) @@ -734,7 +860,11 @@ def forward(self, x): model, auto_wrap_policy=auto_wrap_policy, cpu_offload=CPUOffload(offload_params=True), +<<<<<<< HEAD device_id=torch.accelerator.current_device_index(), +======= + device_id=torch.cuda.current_device(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) use_orig_params=use_orig_params, ) cpu_device = torch.device("cpu") @@ -747,6 +877,7 @@ def test_module_device_mismatches_device_id(self): module that does not match the GPU device ID raises an error.""" # TODO: override FSDP MT Thread _run to set this instead of here for # every test. +<<<<<<< HEAD torch.accelerator.set_device_index(self.rank) context = ( @@ -757,6 +888,14 @@ def test_module_device_mismatches_device_id(self): else nullcontext() ) +======= + torch.cuda.set_device(self.rank) + context = ( + self.assertRaisesRegex(ValueError, f"cuda:{self.rank} vs cuda:0") + if self.rank != 0 + else nullcontext() + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with context: NestedWrappedModule.init( self.process_group, @@ -773,11 +912,16 @@ def test_cpu_gpu_module(self): """Tests a CPU + GPU module supported if device_id is passed in, errors if device_id is not. """ +<<<<<<< HEAD torch.accelerator.set_device_index(self.rank) +======= + torch.cuda.set_device(self.rank) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class CPUGPUModule(nn.Module): def __init__(self) -> None: super().__init__() +<<<<<<< HEAD self.a = nn.Linear(1, 1).to(device=device_type) self.b = nn.Linear(1, 1) @@ -787,6 +931,15 @@ def __init__(self) -> None: self.assertEqual( param.device, torch.device(torch.accelerator.current_device_index()) ) +======= + self.a = nn.Linear(1, 1).cuda() + self.b = nn.Linear(1, 1) + + cpu_gpu = CPUGPUModule() + fsdp = FSDP(cpu_gpu, device_id=torch.cuda.current_device()) + for param in fsdp.parameters(): + self.assertEqual(param.device, torch.device(torch.cuda.current_device())) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # without device_id, we hit an error with self.assertRaisesRegex(RuntimeError, "please pass in device_id"): @@ -794,7 +947,11 @@ def __init__(self) -> None: @skip_if_lt_x_gpu(2) def test_fsdp_ignored_module_meta(self): +<<<<<<< HEAD torch.accelerator.set_device_index(self.rank) +======= + torch.cuda.set_device(self.rank) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class CPUGPUModule(nn.Module): def __init__(self) -> None: @@ -813,11 +970,19 @@ def __init__(self) -> None: m = CPUGPUModule() m = FSDP( m, +<<<<<<< HEAD device_id=torch.accelerator.current_device_index(), ignored_modules=[m.a], use_orig_params=True, param_init_fn=lambda m: m.to_empty( device=torch.accelerator.current_device_index(), recurse=False +======= + device_id=torch.cuda.current_device(), + ignored_modules=[m.a], + use_orig_params=True, + param_init_fn=lambda m: m.to_empty( + device=torch.cuda.current_device(), recurse=False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), ) self.assertEqual(meta_device, next(m.a.parameters()).device) @@ -865,11 +1030,16 @@ def test_no_params(self): """ # TODO: override FSDP MT Thread _run to set this instead of here for # every test. +<<<<<<< HEAD torch.accelerator.set_device_index(self.rank) +======= + torch.cuda.set_device(self.rank) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Test CPU no_params = nn.ReLU() FSDP(no_params) # Test CUDA +<<<<<<< HEAD no_params = nn.ReLU().to(device=device_type) FSDP(no_params) # Test CPU + device_id @@ -879,6 +1049,17 @@ def test_no_params(self): # inconsistency between compute_device and device_id, since compute_device # is computed as torch.cuda.current_device when there are no params. no_params = nn.ReLU().to(device=device_type) +======= + no_params = nn.ReLU().cuda() + FSDP(no_params) + # Test CPU + device_id + no_params = nn.ReLU() + FSDP(no_params, device_id=torch.cuda.current_device()) + # For modules with no params, wrong device_id will raise error about + # inconsistency between compute_device and device_id, since compute_device + # is computed as torch.cuda.current_device when there are no params. + no_params = nn.ReLU().cuda() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) context = ( ( self.assertRaisesRegex( @@ -903,11 +1084,19 @@ def __init__(self, rank): super().__init__() # Seed via rank to make model different across ranks torch.manual_seed(rank) +<<<<<<< HEAD torch.get_device_module(device_type).manual_seed(rank) self.lin = nn.Linear(10, 10, bias=False) self.buffer = nn.Buffer(torch.ones(1) * rank) m = MyModel(self.rank).to(device=device_type) +======= + torch.cuda.manual_seed(rank) + self.lin = nn.Linear(10, 10, bias=False) + self.buffer = nn.Buffer(torch.ones(1) * rank) + + m = MyModel(self.rank).cuda() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _assert_module_states( m, process_group=self.process_group, assert_fn=self.assertNotEqual ) @@ -924,11 +1113,15 @@ def __init__(self, rank): m, process_group=self.process_group, assert_fn=self.assertNotEqual ) # Passing sync_module_states into FSDP makes model the same during init. +<<<<<<< HEAD fsdp = FSDP( m, device_id=torch.accelerator.current_device_index(), sync_module_states=True, ) +======= + fsdp = FSDP(m, device_id=torch.cuda.current_device(), sync_module_states=True) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with fsdp.summon_full_params(fsdp): _assert_module_states( fsdp, process_group=self.process_group, assert_fn=self.assertEqual @@ -983,7 +1176,11 @@ def _test_homogeneous_attributes(self, attr_name_and_values: tuple[str, Any, Any with self.assertRaisesRegex( ValueError, f"Expects one homogeneous value for {attr_name}" ): +<<<<<<< HEAD inp = fsdp_model.module.get_input(torch.device(device_type)) +======= + inp = fsdp_model.module.get_input(torch.device("cuda")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fsdp_model(*inp) @skip_if_lt_x_gpu(2) @@ -991,7 +1188,11 @@ def test_fsdp_unsupported_module_cls(self): regex = r"FSDP will not all-gather parameters for containers that do not implement forward" model = nn.ModuleList([MLP(8, torch.device("cpu")) for _ in range(3)]) with self.assertWarnsRegex(UserWarning, regex): +<<<<<<< HEAD FSDP(model, device_id=device_type) +======= + FSDP(model, device_id="cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) model = nn.ModuleDict( {"1": MLP(8, torch.device("cpu")), "2": MLP(8, torch.device("cpu"))} ) @@ -1015,10 +1216,14 @@ def test_world_size_1_sharding_strategy_warning(self): # warning with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") # trigger all warnings +<<<<<<< HEAD FSDP( nn.Linear(3, 3).to(device=device_type), sharding_strategy=ShardingStrategy.NO_SHARD, ) +======= + FSDP(nn.Linear(3, 3).cuda(), sharding_strategy=ShardingStrategy.NO_SHARD) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for warning in w: self.assertTrue( warning.category != UserWarning @@ -1032,20 +1237,30 @@ def test_world_size_1_sharding_strategy_warning(self): warning_prefix + " " + str(ShardingStrategy.FULL_SHARD) + warning_suffix ) with self.assertWarnsRegex(UserWarning, expected_regex_full_shard): +<<<<<<< HEAD FSDP( nn.Linear(3, 3).to(device=device_type), sharding_strategy=ShardingStrategy.FULL_SHARD, ) with self.assertWarnsRegex(UserWarning, expected_regex_full_shard): FSDP(nn.Linear(3, 3).to(device=device_type)) +======= + FSDP(nn.Linear(3, 3).cuda(), sharding_strategy=ShardingStrategy.FULL_SHARD) + with self.assertWarnsRegex(UserWarning, expected_regex_full_shard): + FSDP(nn.Linear(3, 3).cuda()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # - Pass `SHARD_GRAD_OP` expected_regex_shard_grad_op = ( warning_prefix + " " + str(ShardingStrategy.SHARD_GRAD_OP) + warning_suffix ) with self.assertWarnsRegex(UserWarning, expected_regex_shard_grad_op): FSDP( +<<<<<<< HEAD nn.Linear(3, 3).to(device=device_type), sharding_strategy=ShardingStrategy.SHARD_GRAD_OP, +======= + nn.Linear(3, 3).cuda(), sharding_strategy=ShardingStrategy.SHARD_GRAD_OP +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @skip_if_lt_x_gpu(1) @@ -1069,7 +1284,11 @@ def test_training_device_mismatch_errors(self): # Incorrectly moving from CPU -> GPU model = torch.nn.Linear(10, 10) fsdp_model = FSDP(model, cpu_offload=CPUOffload(offload_params=True)) +<<<<<<< HEAD fsdp_model.to(torch.device(device_type)) +======= + fsdp_model.to(torch.device("cuda")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inp = torch.randn((2, 10)) with self.assertRaisesRegex( RuntimeError, @@ -1110,16 +1329,26 @@ def __setattr__(self, name: str, value: Any) -> None: # Construct FSDP module without changing any environment variables and # run forward, which triggers both unsharded and sharded view setting +<<<<<<< HEAD module = SetattrLinear(5, 5, torch.device(device_type)) fsdp_module = FSDP(module, use_orig_params=use_orig_params) inp = torch.randn((8, 5), device=torch.device(device_type)) +======= + module = SetattrLinear(5, 5, torch.device("cuda")) + fsdp_module = FSDP(module, use_orig_params=use_orig_params) + inp = torch.randn((8, 5), device=torch.device("cuda")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) called_setattr_override = False fsdp_module(inp) self.assertTrue(called_setattr_override) # Repeat with unsafe setattr explicitly enabled os.environ[_FSDP_USE_UNSAFE_SETATTR] = "1" +<<<<<<< HEAD module = SetattrLinear(5, 5, torch.device(device_type)) +======= + module = SetattrLinear(5, 5, torch.device("cuda")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fsdp_module = FSDP(module, use_orig_params=use_orig_params) called_setattr_override = False fsdp_module(inp) @@ -1127,7 +1356,11 @@ def __setattr__(self, name: str, value: Any) -> None: # Repeat with unsafe setattr explicitly disabled os.environ[_FSDP_USE_UNSAFE_SETATTR] = "0" +<<<<<<< HEAD module = SetattrLinear(5, 5, torch.device(device_type)) +======= + module = SetattrLinear(5, 5, torch.device("cuda")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fsdp_module = FSDP(module, use_orig_params=use_orig_params) called_setattr_override = False fsdp_module(inp) diff --git a/test/distributed/fsdp/test_fsdp_uneven.py b/test/distributed/fsdp/test_fsdp_uneven.py index d0094ce1de71f..393618af78099 100644 --- a/test/distributed/fsdp/test_fsdp_uneven.py +++ b/test/distributed/fsdp/test_fsdp_uneven.py @@ -45,7 +45,11 @@ def _get_ref_results(self, device, model, input, my_lr): def test_one_iteration(self, device): """Test FSDP with uneven divide of parameter shards.""" model = Linear(3, 3, bias=False) +<<<<<<< HEAD input = torch.rand(self.world_size, 3) +======= + input = torch.rand(8, 3) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) my_lr = 0.1 ref_forward_output_my_rank, ref_weight_out = self._get_ref_results( diff --git a/test/distributed/optim/test_zero_redundancy_optimizer.py b/test/distributed/optim/test_zero_redundancy_optimizer.py index 5c90ad8be144e..0a28b69bfe732 100644 --- a/test/distributed/optim/test_zero_redundancy_optimizer.py +++ b/test/distributed/optim/test_zero_redundancy_optimizer.py @@ -1165,23 +1165,45 @@ def closure_ddp(): # Increased tolerances are needed to pass when using TF32 # See: https://github.com/pytorch/pytorch/issues/67764 +<<<<<<< HEAD torch.testing.assert_close( local_loss.cpu(), ddp_loss.cpu(), rtol=1e-03, atol=1e-08, msg="Losses differ between local optimizer and ZeRO", +======= + ( + torch.testing.assert_close( + local_loss.cpu(), + ddp_loss.cpu(), + rtol=1e-03, + atol=1e-08, + ), + "Losses differ between local optimizer and ZeRO", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) for local_p, ddp_p in zip( local_model.parameters(), ddp_model.parameters() ): +<<<<<<< HEAD torch.testing.assert_close( local_p.cpu(), ddp_p.cpu(), rtol=1e-03, atol=1e-04, msg="Models differ after a step", +======= + ( + torch.testing.assert_close( + local_p.cpu(), + ddp_p.cpu(), + rtol=1e-03, + atol=1e-04, + ), + "Models differ after a step", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @skipIfHpu diff --git a/test/distributed/pipelining/model_registry.py b/test/distributed/pipelining/model_registry.py index 347dad6fb766c..b2d031fbf3ffa 100644 --- a/test/distributed/pipelining/model_registry.py +++ b/test/distributed/pipelining/model_registry.py @@ -211,10 +211,17 @@ def __init__(self, d_hid: int): self.fc2_weight = torch.nn.Parameter(torch.randn(d_hid, d_hid)) self.fc2_bias = torch.nn.Parameter(torch.randn(d_hid)) +<<<<<<< HEAD torch.nn.init.uniform_(self.fc1_weight, -0.001, 0.001) torch.nn.init.uniform_(self.fc2_weight, -0.001, 0.001) torch.nn.init.uniform_(self.fc1_bias, -0.001, 0.001) torch.nn.init.uniform_(self.fc2_bias, -0.001, 0.001) +======= + torch.nn.init.uniform_(self.fc1_weight, -0.01, 0.01) + torch.nn.init.uniform_(self.fc2_weight, -0.01, 0.01) + torch.nn.init.uniform_(self.fc1_bias, -0.01, 0.01) + torch.nn.init.uniform_(self.fc2_bias, -0.01, 0.01) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.cached_context = {} self.cached_context["fc1"] = [] diff --git a/test/distributed/pipelining/schedule_registry.py b/test/distributed/pipelining/schedule_registry.py index 9b401193a1720..289a98b0099af 100644 --- a/test/distributed/pipelining/schedule_registry.py +++ b/test/distributed/pipelining/schedule_registry.py @@ -45,7 +45,11 @@ def __init__( ) # Go through one microbatch +<<<<<<< HEAD # Note(whc) - it might be easier to work with this schedules by writing them as a list of +======= + # Note(whc) - it might be easier to work with thes schedules by writing them as a list of +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # ["0F0", ...] and then parsing them in the test infra to turn them into actions. self.pipeline_order = { 0: [ diff --git a/test/distributed/pipelining/test_backward.py b/test/distributed/pipelining/test_backward.py index b46a97d02c29e..1afb6e1faa2b3 100644 --- a/test/distributed/pipelining/test_backward.py +++ b/test/distributed/pipelining/test_backward.py @@ -10,10 +10,14 @@ stage_backward_input, stage_backward_weight, ) +<<<<<<< HEAD from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, skipXPUIf, ) +======= +from torch.testing._internal.common_device_type import instantiate_device_type_tests +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_utils import run_tests, TestCase @@ -22,7 +26,10 @@ class StageBackwardTests(TestCase): +<<<<<<< HEAD @skipXPUIf(True, "https://github.com/intel/torch-xpu-ops/issues/1682") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_stage_backward(self, device): # MLP as a stage module mod = MLPModule(d_hid).to(device) @@ -97,7 +104,10 @@ def test_stage_backward_input(self, device): # Check that the weight gradients were not updated self.assertEqual(p.grad, None) +<<<<<<< HEAD @skipXPUIf(True, "https://github.com/intel/torch-xpu-ops/issues/1682") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_stage_backward_weight(self, device): # MLP as a stage module mod = MLPModule(d_hid).to(device) @@ -138,7 +148,10 @@ def test_stage_backward_weight(self, device): print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}") raise +<<<<<<< HEAD @skipXPUIf(True, "https://github.com/intel/torch-xpu-ops/issues/1682") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_stage_backward_weight_multiple_iters(self, device): # MLP as a stage module mod = MLPModule(d_hid).to(device) @@ -189,6 +202,7 @@ def test_stage_backward_weight_multiple_iters(self, device): print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}") raise +<<<<<<< HEAD def test_stage_backward_weight_grad_validation(self, device): test_cases = [ ( @@ -232,6 +246,11 @@ def test_stage_backward_weight_grad_validation(self, device): instantiate_device_type_tests( StageBackwardTests, globals(), only_for=devices, allow_xpu=True ) +======= + +devices = ["cpu", "cuda", "hpu", "xpu"] +instantiate_device_type_tests(StageBackwardTests, globals(), only_for=devices) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/distributed/pipelining/test_microbatch.py b/test/distributed/pipelining/test_microbatch.py index 99bb0fddaa21c..f91a49032cd4c 100644 --- a/test/distributed/pipelining/test_microbatch.py +++ b/test/distributed/pipelining/test_microbatch.py @@ -9,10 +9,14 @@ split_args_kwargs_into_chunks, TensorChunkSpec, ) +<<<<<<< HEAD from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, skipXPUIf, ) +======= +from torch.testing._internal.common_device_type import instantiate_device_type_tests +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_utils import run_tests, TestCase @@ -59,7 +63,10 @@ def test_split_and_merge(self): torch.testing.assert_close(merged_kwargs, kwargs) print("Microbatch test passed") +<<<<<<< HEAD @skipXPUIf(True, "https://github.com/intel/torch-xpu-ops/issues/1682") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_chunk_spec(self, device): mod = ModelWithKwargs().to(device) batch_size = ModelWithKwargs.DEFAULT_BATCH_SIZE @@ -88,15 +95,22 @@ def test_chunk_spec(self, device): ref = mod(x, y) out = pipe(x, y)[0] +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.testing.assert_close(out, ref) print(f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref)}") devices = ["cpu", "cuda", "hpu", "xpu"] +<<<<<<< HEAD instantiate_device_type_tests( MicrobatchTests, globals(), only_for=devices, allow_xpu=True ) +======= +instantiate_device_type_tests(MicrobatchTests, globals(), only_for=devices) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/distributed/pipelining/test_schedule.py b/test/distributed/pipelining/test_schedule.py index dabf3d78a6f13..e4410ce825c07 100644 --- a/test/distributed/pipelining/test_schedule.py +++ b/test/distributed/pipelining/test_schedule.py @@ -10,12 +10,18 @@ import torch from torch.distributed.pipelining import ( Schedule1F1B, +<<<<<<< HEAD ScheduleDualPipeV, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ScheduleGPipe, ScheduleInterleaved1F1B, ScheduleInterleavedZeroBubble, ScheduleLoopedBFS, +<<<<<<< HEAD ScheduleZBVZeroBubble, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) from torch.distributed.pipelining._utils import generate_stage_to_rank_mapping from torch.distributed.pipelining.schedules import ( @@ -40,7 +46,11 @@ W, ) from torch.distributed.pipelining.stage import _PipelineStageBase, PipelineStage +<<<<<<< HEAD from torch.testing._internal.common_distributed import requires_accelerator_dist_backend +======= +from torch.testing._internal.common_distributed import requires_nccl +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_utils import ( check_leaked_tensors, instantiate_parametrized_tests, @@ -53,7 +63,10 @@ ARTIFACTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "artifacts") +<<<<<<< HEAD device = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) logger = logging.getLogger(__name__) torch.manual_seed(0) @@ -351,6 +364,7 @@ def stage_to_rank(stage): num_stages=num_stages, ) +<<<<<<< HEAD @parametrize( "ScheduleClass", [ScheduleDualPipeV, ScheduleZBVZeroBubble], @@ -392,10 +406,13 @@ def test_pipeline_order_for_v_schedules(self, ScheduleClass): schedule.pipeline_order, group_size, num_stages, num_microbatches ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) instantiate_parametrized_tests(TestSchedulePlan) +<<<<<<< HEAD class TestScheduleCsv(TestCase): @parametrize( "ScheduleClass,csv_name", @@ -436,6 +453,8 @@ def test_csv_compare(self, ScheduleClass, csv_name): instantiate_parametrized_tests(TestScheduleCsv) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestScheduleLowering(TestCase): """Tests lowering passes that convert simple compute-only (FBW) schedules into compute+comms schedules""" @@ -741,7 +760,11 @@ def _dump_csv(pipeline_order_with_comms, filename: str): # print(_format_pipeline_order(simulated_schedule)) self.assertEqual(num_steps, 113) +<<<<<<< HEAD @requires_accelerator_dist_backend(["nccl", "xccl"]) +======= + @requires_nccl() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_grad_with_v_schedule(self): """ We have a special case for V schedules where 2 adjacent stages are on the same rank. @@ -761,6 +784,10 @@ def test_grad_with_v_schedule(self): d_hid = 512 batch_size = 256 n_stages = 2 +<<<<<<< HEAD +======= + device = "cuda" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) full_mod = MultiMLP(d_hid, n_layers=n_stages) full_mod.to(device) @@ -804,7 +831,11 @@ def test_grad_with_v_schedule(self): loss_fn=loss_fn, scale_grads=False, ) +<<<<<<< HEAD schedule._prepare_schedule_with_comms( +======= + schedule._load_actions( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) { 0: self._parse_actions( [ @@ -859,7 +890,11 @@ def test_grad_with_v_schedule(self): torch.distributed.destroy_process_group() +<<<<<<< HEAD @requires_accelerator_dist_backend(["nccl", "xccl"]) +======= + @requires_nccl() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_grad_with_split_b_w(self): """ Ensure that separate dInput and dWeight computations are correctly executed. @@ -872,6 +907,10 @@ def test_grad_with_split_b_w(self): d_hid = 512 batch_size = 256 n_stages = 1 +<<<<<<< HEAD +======= + device = "cuda" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) full_mod = MultiMLP(d_hid, n_layers=n_stages) full_mod.to(device) @@ -914,7 +953,11 @@ def test_grad_with_split_b_w(self): num_microbatches, loss_fn=loss_fn, ) +<<<<<<< HEAD schedule._prepare_schedule_with_comms( +======= + schedule._load_actions( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) { 0: self._parse_actions( [ diff --git a/test/distributed/pipelining/test_schedule_multiproc.py b/test/distributed/pipelining/test_schedule_multiproc.py index 9ba12c3d69965..29897f6c4db02 100644 --- a/test/distributed/pipelining/test_schedule_multiproc.py +++ b/test/distributed/pipelining/test_schedule_multiproc.py @@ -3,7 +3,10 @@ import copy import logging import tempfile +<<<<<<< HEAD from dataclasses import dataclass +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from model_registry import ModelWithKwargs, MultiMLP, MultiMLPKwargs, MultiMLPWithDw from schedule_registry import ( @@ -20,7 +23,10 @@ pipeline, PipelineStage, Schedule1F1B, +<<<<<<< HEAD ScheduleDualPipeV, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ScheduleGPipe, ScheduleInterleaved1F1B, ScheduleInterleavedZeroBubble, @@ -28,10 +34,17 @@ ScheduleZBVZeroBubble, ) from torch.distributed.pipelining.schedules import _PipelineScheduleRuntime +<<<<<<< HEAD from torch.nn.modules.loss import MSELoss from torch.testing._internal.common_distributed import ( MultiProcContinuousTest, requires_accelerator_dist_backend, +======= +from torch.testing._internal.common_cuda import TEST_MULTIGPU +from torch.testing._internal.common_distributed import ( + MultiProcContinousTest, + requires_nccl, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) from torch.testing._internal.common_utils import ( check_leaked_tensors, @@ -45,6 +58,7 @@ logger = logging.getLogger(__name__) d_hid = 512 +<<<<<<< HEAD batch_size = 64 torch.manual_seed(0) device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" @@ -202,16 +216,30 @@ def zero_gradients(stage_modules): class ScheduleTest(MultiProcContinuousTest): world_size = 4 +======= +batch_size = 256 +torch.manual_seed(0) +device_type = "cuda" + + +class ScheduleTest(MultiProcContinousTest): + world_size = 2 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @classmethod def backend_str(cls) -> str: # Testing with NCCL backend +<<<<<<< HEAD return backend +======= + return "nccl" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @property def device(self) -> torch.device: return torch.device(device_type, self.rank) +<<<<<<< HEAD @property def config(self) -> PipelineTestConfig: """Lazily create and return the pipeline test configuration.""" @@ -236,6 +264,40 @@ def test_forward_only(self, ScheduleClass): # Run forward-only schedule out = None +======= + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @parametrize("ScheduleClass", [_ScheduleForwardOnly]) + def test_forward_only(self, ScheduleClass): + mod = MultiMLP(d_hid, n_layers=self.world_size) + mod.to(self.device) + + mod_ref = copy.deepcopy(mod) + + x = torch.randn(batch_size, d_hid, device=self.device) + x_clone = x.clone() + + num_microbatches = 2 * self.world_size + x_mb = x.chunk(num_microbatches)[0] + + # Create a pipeline + split_spec = mod.split_spec if hasattr(mod, "split_spec") else None + pipe = pipeline( + mod, + mb_args=(x_mb,), + split_spec=split_spec, + ) + + stage = pipe.build_stage( + self.rank, + self.device, + ) + + # Attach to a schedule + schedule = ScheduleClass(stage, num_microbatches, scale_grads=False) + + # Run +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) num_iters = 20 for _ in range(num_iters): if self.rank == 0: @@ -247,6 +309,7 @@ def test_forward_only(self, ScheduleClass): else: schedule.step() +<<<<<<< HEAD # Validate pipelined output matches reference model if self.rank == self.world_size - 1: for _ in range(num_iters): @@ -339,6 +402,43 @@ def test_multi_iter(self, ScheduleClass): mod, _, x, target, loss_fn = setup_models_and_data(self.config) chunks = 4 stage, _, _ = create_single_stage_pipeline(self.config, mod, x, chunks) +======= + # Validate pipelined output is the same as reference model + if self.rank == self.world_size - 1: + for _ in range(num_iters): + x_clone = mod_ref(x_clone) + + torch.testing.assert_close(x_clone, out) + + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) + def test_multi_iter(self, ScheduleClass): + mod = MultiMLP(d_hid, n_layers=self.world_size) + mod.to(self.device) + + x = torch.randn(batch_size, d_hid, device=self.device) + target = torch.randn(batch_size, d_hid, device=self.device) + loss_fn = torch.nn.MSELoss(reduction="sum") + + chunks = 4 + x_mb = x.chunk(chunks)[0] + + # Create a pipeline + split_spec = mod.split_spec if hasattr(mod, "split_spec") else None + pipe = pipeline( + mod, + mb_args=(x_mb,), + split_spec=split_spec, + ) + + stage = pipe.build_stage( + self.rank, + self.device, + ) + + # Attach to a schedule +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) schedule = ScheduleClass(stage, chunks, loss_fn=loss_fn, scale_grads=False) # Run @@ -351,6 +451,7 @@ def test_multi_iter(self, ScheduleClass): else: schedule.step() +<<<<<<< HEAD dist.barrier(device_ids=[self.rank]) @requires_accelerator_dist_backend(["nccl", "xccl"]) @@ -360,6 +461,19 @@ def test_multi_iter(self, ScheduleClass): @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) def test_kwargs_with_tracer(self, ScheduleClass): mod = ModelWithKwargs(d_hid, splits=self.world_size) +======= + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) + def test_kwargs_with_tracer(self, ScheduleClass): + # Model has two stages only, thus limiting group size to 2 + group_size = 2 + group = dist.new_group(list(range(group_size))) + if self.rank >= group_size: + return + + mod = ModelWithKwargs(d_hid) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mod.to(self.device) x = torch.randn(batch_size, d_hid, device=self.device) @@ -380,31 +494,50 @@ def test_kwargs_with_tracer(self, ScheduleClass): stage = pipe.build_stage( self.rank, self.device, +<<<<<<< HEAD +======= + group=group, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # Attach to a schedule schedule = ScheduleClass(stage, chunks, loss_fn=loss_fn, scale_grads=False) # Run +<<<<<<< HEAD out = None losses = [] if self.rank == 0: schedule.step(x, y=y) elif self.rank == self.world_size - 1: +======= + if self.rank == 0: + schedule.step(x, y=y) + elif self.rank == group_size - 1: + losses = [] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out = schedule.step(target=target, losses=losses) else: schedule.step() +<<<<<<< HEAD dist.barrier(device_ids=[self.rank]) # Last rank checks result if self.rank == self.world_size - 1: +======= + # dist.barrier() + + # Last rank checks result + if self.rank == group_size - 1: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ref_out = mod(x, y=y) ref_loss = loss_fn(ref_out, target) pipe_loss = sum(losses) torch.testing.assert_close(out, ref_out, rtol=1e-2, atol=5e-3) torch.testing.assert_close(pipe_loss, ref_loss) +<<<<<<< HEAD @requires_accelerator_dist_backend(["nccl", "xccl"]) @skip_but_pass_in_sandcastle_if( not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" @@ -431,10 +564,63 @@ def test_grad_with_tracer(self, ScheduleClass): if self.rank == 0: schedule.step(x) elif self.rank == self.world_size - 1: +======= + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) + def test_grad_with_tracer(self, ScheduleClass): + mod = MultiMLP(d_hid, n_layers=self.world_size) + mod.to(self.device) + + ref_mod = copy.deepcopy(mod) + x = torch.randn(batch_size, d_hid, device=self.device) + with torch.no_grad(): + y = ref_mod(x) + # Add a small perturbation + target = y + torch.randn(batch_size, d_hid, device=self.device) + + loss_fn = torch.nn.MSELoss(reduction="sum") + + # Run reference + for _ in range(2): + ref_mod.zero_grad() + ref_out = ref_mod(x) + ref_loss = loss_fn(ref_out, target) + ref_loss.backward() + + # Create a pipeline + chunks = 2 * self.world_size + x_mb = x.chunk(chunks)[0] + split_spec = mod.split_spec if hasattr(mod, "split_spec") else None + pipe = pipeline( + mod, + mb_args=(x_mb,), + split_spec=split_spec, + ) + + stage = pipe.build_stage( + self.rank, + self.device, + ) + + # Attach to a schedule + schedule = ScheduleClass(stage, chunks, loss_fn=loss_fn, scale_grads=False) + + # Run + stage_module = pipe.get_stage_module(self.rank) + for _ in range(2): + # Zero gradients + stage_module.zero_grad() + if self.rank == 0: + schedule.step(x) + elif self.rank == self.world_size - 1: + losses = [] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out = schedule.step(target=target, losses=losses) else: schedule.step() +<<<<<<< HEAD dist.barrier(device_ids=[self.rank]) # Last rank checks result @@ -488,10 +674,92 @@ def test_grad_with_manual(self, ScheduleClass, shape_inference): if self.rank == 0: schedule.step(x) elif self.rank == self.world_size - 1: +======= + dist.barrier() + + # Last rank checks result + if self.rank == self.world_size - 1: + # Check output + torch.testing.assert_close(out, ref_out) + # Check loss + # Since the reduction used in the loss function above is "sum", we use + # "sum" here to reduce microbatch losses into a single value too. + pipe_loss = sum(losses) + torch.testing.assert_close(pipe_loss, ref_loss) + + # Every rank checks gradients + for name, p in stage_module.named_parameters(): + ref_p = ref_mod.get_parameter(name) + try: + torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5) + except AssertionError: + print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}") + raise + + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) + @parametrize("shape_inference", [True, False]) + def test_grad_with_manual(self, ScheduleClass, shape_inference): + full_mod = MultiMLP(d_hid, n_layers=self.world_size) + full_mod.to(self.device) + + ref_mod = copy.deepcopy(full_mod) + x = torch.randn(batch_size, d_hid, device=self.device) + with torch.no_grad(): + y = ref_mod(x) + # Add a small perturbation + target = y + torch.randn(batch_size, d_hid, device=self.device) + + loss_fn = torch.nn.MSELoss(reduction="sum") + + # Run reference + for _ in range(2): + ref_mod.zero_grad() + ref_out = ref_mod(x) + ref_loss = loss_fn(ref_out, target) + ref_loss.backward() + + # Get a submodule, e.g. `layers.0` or `layers.1` + submod_name = f"layers.{self.rank}" + stage_module = full_mod.get_submodule(submod_name) + chunks = 2 * self.world_size + + if shape_inference: + input_args = None + output_args = None + else: + input_args = (x.chunk(chunks)[0],) + with torch.no_grad(): + output_args = stage_module(*input_args) + + # Create a pipeline stage to wrap that submodule + stage = PipelineStage( + stage_module, + self.rank, + self.world_size, + self.device, + input_args=input_args, + output_args=output_args, + ) + + # Attach to a schedule + schedule = ScheduleClass(stage, chunks, loss_fn=loss_fn, scale_grads=False) + + # Run + for _ in range(2): + # Zero gradients + stage_module.zero_grad() + if self.rank == 0: + schedule.step(x) + elif self.rank == self.world_size - 1: + losses = [] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out = schedule.step(target=target, losses=losses) else: schedule.step() +<<<<<<< HEAD dist.barrier(device_ids=[self.rank]) # Last rank checks result @@ -507,6 +775,32 @@ def test_grad_with_manual(self, ScheduleClass, shape_inference): @skip_but_pass_in_sandcastle_if( not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" ) +======= + dist.barrier() + + # Last rank checks result + if self.rank == self.world_size - 1: + # Check output + torch.testing.assert_close(out, ref_out) + # Check loss + # Since the reduction used in the loss function above is "sum", we use + # "sum" here to reduce microbatch losses into a single value too. + pipe_loss = sum(losses) + torch.testing.assert_close(pipe_loss, ref_loss) + + # Every rank checks gradients + ref_submod = ref_mod.get_submodule(submod_name) + for name, p in stage_module.named_parameters(): + ref_p = ref_submod.get_parameter(name) + try: + torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5) + except AssertionError: + print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}") + raise + + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize( "ScheduleClass", [ @@ -519,6 +813,7 @@ def test_grad_with_manual(self, ScheduleClass, shape_inference): def test_grad_with_manual_interleaved(self, ScheduleClass, use_new_runtime): stages_per_rank = 2 n_stages = stages_per_rank * self.world_size +<<<<<<< HEAD mod, ref_mod, x, target, loss_fn = setup_models_and_data( self.config, n_layers=n_stages ) @@ -532,11 +827,43 @@ def test_grad_with_manual_interleaved(self, ScheduleClass, use_new_runtime): ) print(f"Rank {self.rank} stages: {[stage.stage_index for stage in stages]}") +======= + full_mod = MultiMLP(d_hid, n_layers=n_stages) + full_mod.to(self.device) + + ref_mod = copy.deepcopy(full_mod) + x = torch.randn(batch_size, d_hid, device=self.device) + with torch.no_grad(): + y = ref_mod(x) + # Add a small perturbation + target = y + torch.randn(batch_size, d_hid, device=self.device) + + loss_fn = torch.nn.MSELoss(reduction="sum") + + # Run reference + for _ in range(2): + ref_mod.zero_grad() + ref_out = ref_mod(x) + ref_loss = loss_fn(ref_out, target) + ref_loss.backward() + + # Get a submodule, e.g. `layers.0` or `layers.1` + stage_indices = [ + self.rank + i * self.world_size for i in range(stages_per_rank) + ] + print(f"Rank {self.rank} stages: {stage_indices}") + submod_names = [f"layers.{i}" for i in stage_indices] + stage_modules = [ + full_mod.get_submodule(submod_name) for submod_name in submod_names + ] + # Create a pipeline stage to wrap that submodule +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) num_microbatches = ( ScheduleClass.num_microbatches if hasattr(ScheduleClass, "num_microbatches") else 2 * self.world_size ) +<<<<<<< HEAD # Create schedule schedule = ScheduleClass( @@ -554,11 +881,43 @@ def test_grad_with_manual_interleaved(self, ScheduleClass, use_new_runtime): # Test CSV round-trip for compute_comms schedule schedule = _PipelineScheduleRuntime( stages, num_microbatches, loss_fn=loss_fn, scale_grads=False +======= + stages = [ + PipelineStage( + stage_module, + stage_idx, + n_stages, + self.device, + ) + for stage_module, stage_idx in zip(stage_modules, stage_indices) + ] + + # Attach to a schedule + schedule = ScheduleClass( + stages, num_microbatches, loss_fn=loss_fn, scale_grads=False + ) + if use_new_runtime: + old_schedule = schedule + tmp_schedule = _PipelineScheduleRuntime( + stages, + num_microbatches, + loss_fn=loss_fn, + scale_grads=False, + ) + tmp_schedule._load_actions(old_schedule.pipeline_order) + # test that csv round-trip works for compute_comms schedule + schedule = _PipelineScheduleRuntime( + stages, + num_microbatches, + loss_fn=loss_fn, + scale_grads=False, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) with tempfile.NamedTemporaryFile() as f: tmp_schedule._dump_csv(f.name) f.seek(0) schedule._load_csv(f.name, format="compute_comms") +<<<<<<< HEAD one_more_schedule = _PipelineScheduleRuntime( stages, num_microbatches, loss_fn=loss_fn, scale_grads=False @@ -571,11 +930,33 @@ def test_grad_with_manual_interleaved(self, ScheduleClass, use_new_runtime): self.assertEqual( len(schedule.pipeline_order_with_comms), len(one_more_schedule.pipeline_order_with_comms), +======= + one_more_schedule = _PipelineScheduleRuntime( + stages, + num_microbatches, + loss_fn=loss_fn, + scale_grads=False, + ) + one_more_schedule._load_actions( + schedule.pipeline_order_with_comms, format="compute_comms" + ) + self.assertEqual( + len(schedule.pipeline_order_with_comms), + len( + one_more_schedule.pipeline_order_with_comms, + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) for rank in schedule.pipeline_order_with_comms: self.assertEqual( len(schedule.pipeline_order_with_comms[rank]), +<<<<<<< HEAD len(one_more_schedule.pipeline_order_with_comms[rank]), +======= + len( + one_more_schedule.pipeline_order_with_comms[rank], + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) for a, b in zip( schedule.pipeline_order_with_comms[rank], @@ -583,6 +964,7 @@ def test_grad_with_manual_interleaved(self, ScheduleClass, use_new_runtime): ): self.assertEqual(a, b) +<<<<<<< HEAD # Run pipeline with tensor leak checking out = None losses = [] @@ -596,6 +978,21 @@ def test_grad_with_manual_interleaved(self, ScheduleClass, use_new_runtime): else: schedule.step() +======= + # Run + with check_leaked_tensors() as garbage_tensors: + for _ in range(2): + # Zero gradients + for stage_module in stage_modules: + stage_module.zero_grad() + if self.rank == 0: + schedule.step(x) + elif self.rank == self.world_size - 1: + losses = [] + out = schedule.step(target=target, losses=losses) + else: + schedule.step() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual( len(garbage_tensors), 0, @@ -603,6 +1000,7 @@ def test_grad_with_manual_interleaved(self, ScheduleClass, use_new_runtime): ) dist.barrier() +<<<<<<< HEAD # Verify results if self.rank == self.world_size - 1: torch.testing.assert_close(out, ref_out) @@ -619,10 +1017,342 @@ def test_grad_with_manual_interleaved(self, ScheduleClass, use_new_runtime): @skip_but_pass_in_sandcastle_if( not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" ) +======= + # Last rank checks result + if self.rank == self.world_size - 1: + # Check output + torch.testing.assert_close(out, ref_out) + # Check loss + # Since the reduction used in the loss function above is "sum", we use + # "sum" here to reduce microbatch losses into a single value too. + pipe_loss = sum(losses) + torch.testing.assert_close(pipe_loss, ref_loss) + + # Every rank checks gradients + for stage_module, submod_name in zip(stage_modules, submod_names): + # Get corresponding submodule from reference model + ref_submod = ref_mod.get_submodule(submod_name) + # Check gradients per parameter + for name, p in stage_module.named_parameters(): + ref_p = ref_submod.get_parameter(name) + try: + torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=1e-3) + except AssertionError: + print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}") + raise + + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @parametrize("ScheduleClass", [ScheduleWithW, ScheduleInterleavedZeroBubble]) + def test_schedule_with_native_zero_bubble(self, ScheduleClass): + print(ScheduleClass) + if ScheduleClass is ScheduleInterleavedZeroBubble: + n_stages = 4 + num_microbatches = 2 * n_stages + rank_stages = { + 0: [0, 2], + 1: [1, 3], + } + else: + n_stages = ScheduleClass.n_stages + num_microbatches = ScheduleClass.num_microbatches + rank_stages = ScheduleClass.rank_stages + + num_steps = 4 + full_mod = MultiMLP(d_hid, n_layers=n_stages) + full_mod.to(self.device) + + ref_mod = copy.deepcopy(full_mod) + x = torch.randn(batch_size, d_hid, device=self.device) + # x = torch.randn(batch_size, d_hid, device=self.device, requires_grad=True) + with torch.no_grad(): + y = ref_mod(x) + # Add a small perturbation + target = y + torch.randn(batch_size, d_hid, device=self.device) + + loss_fn = torch.nn.MSELoss(reduction="sum") + + # Create a pipeline stage to wrap that submodule + stage_indices = rank_stages[self.rank] + print(f"Rank {self.rank} stages: {stage_indices}") + submod_names = [f"layers.{i}" for i in stage_indices] + stage_modules = [ + full_mod.get_submodule(submod_name) for submod_name in submod_names + ] + stages = [ + PipelineStage( + stage_module, + stage_idx, + n_stages, + self.device, + ) + for stage_module, stage_idx in zip(stage_modules, rank_stages[self.rank]) + ] + + # We set scale_grads=False since we use a loss function that sums instead of mean-reduces + # (note: normally we recommend using mean-reduce loss functions, but we preserve at least one test case + # using sum scaling for completeness) + schedule = ScheduleClass( + stages, num_microbatches, loss_fn=loss_fn, scale_grads=False + ) + + # Run reference + ref_x = x.detach().clone().requires_grad_(x.requires_grad) + torch.testing.assert_close(x, ref_x) + for _ in range(num_steps): + ref_out = ref_mod(ref_x) + ref_loss = loss_fn(ref_out, target) + ref_loss.backward() + + with check_leaked_tensors() as garbage_tensors: + # Run pipelined stages + for _ in range(num_steps): + if self.rank == 0: + schedule.step(x) + elif self.rank == self.world_size - 1: + losses = [] + schedule.step(target=target, losses=losses) + else: + schedule.step() + self.assertEqual( + len(garbage_tensors), + 0, + "Found leaked tensors, check logs above for debug info", + ) + + # Every rank checks parameters compared with the reference model + for stage_module, submod_name in zip(stage_modules, submod_names): + # Get corresponding submodule from reference model + ref_submod = ref_mod.get_submodule(submod_name) + # Check gradients per parameter + for name, p in stage_module.named_parameters(): + ref_p = ref_submod.get_parameter(name) + try: + torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5) + except AssertionError: + print( + f"Parameter test failed for {submod_name}.{name}: {p.grad} vs {ref_p.grad}" + ) + raise + + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @parametrize( + "ScheduleClass", + [ + ScheduleWithReorderedB, + ], + ) + def test_pipeline_schedule_runtime_custom_sched(self, ScheduleClass): + n_stages = 2 + num_microbatches = 2 + stages_per_rank = 1 + full_mod = MultiMLP(d_hid, n_layers=n_stages) + full_mod.to(self.device) + + ref_mod = copy.deepcopy(full_mod) + x = torch.randn(batch_size, d_hid, device=self.device) + with torch.no_grad(): + y = ref_mod(x) + # Add a small perturbation + target = y + torch.randn(batch_size, d_hid, device=self.device) + + loss_fn = torch.nn.MSELoss(reduction="sum") + + # Run reference + for _ in range(2): + ref_mod.zero_grad() + ref_out = ref_mod(x) + ref_loss = loss_fn(ref_out, target) + ref_loss.backward() + + # Get a submodule, e.g. `layers.0` or `layers.1` + stage_indices = [ + self.rank + i * self.world_size for i in range(stages_per_rank) + ] + print(f"Rank {self.rank} stages: {stage_indices}") + submod_names = [f"layers.{i}" for i in stage_indices] + stage_modules = [ + full_mod.get_submodule(submod_name) for submod_name in submod_names + ] + # Create a pipeline stage to wrap that submodule + num_microbatches = ( + ScheduleClass.num_microbatches + if hasattr(ScheduleClass, "num_microbatches") + else 8 + ) + stages = [ + PipelineStage( + stage_module, + stage_idx, + n_stages, + self.device, + ) + for stage_module, stage_idx in zip(stage_modules, stage_indices) + ] + + # Attach to a schedule + schedule = ScheduleClass( + stages, num_microbatches, loss_fn=loss_fn, scale_grads=False + ) + assert isinstance(schedule, _PipelineScheduleRuntime) + + # Run + with check_leaked_tensors() as garbage_tensors: + for _ in range(2): + # Zero gradients + for stage_module in stage_modules: + stage_module.zero_grad() + if self.rank == 0: + schedule.step(x) + elif self.rank == self.world_size - 1: + losses = [] + out = schedule.step(target=target, losses=losses) + else: + schedule.step() + self.assertEqual( + len(garbage_tensors), + 0, + "Found leaked tensors, check logs above for debug info", + ) + dist.barrier() + + # Last rank checks result + if self.rank == self.world_size - 1: + # Check output + torch.testing.assert_close(out, ref_out) + # Check loss + # Since the reduction used in the loss function above is "sum", we use + # "sum" here to reduce microbatch losses into a single value too. + pipe_loss = sum(losses) + torch.testing.assert_close(pipe_loss, ref_loss) + + # Every rank checks gradients + for stage_module, submod_name in zip(stage_modules, submod_names): + # Get corresponding submodule from reference model + ref_submod = ref_mod.get_submodule(submod_name) + # Check gradients per parameter + for name, p in stage_module.named_parameters(): + ref_p = ref_submod.get_parameter(name) + try: + torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5) + except AssertionError: + print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}") + raise + + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @parametrize( + "schedule_class", [ScheduleVShaped, ScheduleUnbalanced, ScheduleZBVZeroBubble] + ) + @parametrize("use_new_runtime", [False, True]) + def test_non_symmetric_stage_ids(self, schedule_class, use_new_runtime): + if schedule_class is ScheduleZBVZeroBubble: + n_stages = 4 + rank_stages = { + 0: [0, 3], + 1: [1, 2], + } + else: + n_stages = schedule_class.n_stages + rank_stages = schedule_class.rank_stages + full_mod = MultiMLP(d_hid, n_layers=n_stages) + full_mod.to(self.device) + + ref_mod = copy.deepcopy(full_mod) + x = torch.randn(batch_size, d_hid, device=self.device) + with torch.no_grad(): + y = ref_mod(x) + # Add a small perturbation + target = y + torch.randn(batch_size, d_hid, device=self.device) + + loss_fn = torch.nn.MSELoss(reduction="sum") + + # Run reference + for _ in range(2): + ref_mod.zero_grad() + ref_out = ref_mod(x) + ref_loss = loss_fn(ref_out, target) + ref_loss.backward() + + # Create a pipeline stage to wrap that submodule + num_microbatches = 1 + stage_indices = rank_stages[self.rank] + print(f"Rank {self.rank} stages: {stage_indices}") + submod_names = [f"layers.{i}" for i in stage_indices] + stage_modules = [ + full_mod.get_submodule(submod_name) for submod_name in submod_names + ] + stages = [ + PipelineStage( + stage_module, + stage_idx, + n_stages, + self.device, + ) + for stage_module, stage_idx in zip(stage_modules, rank_stages[self.rank]) + ] + + schedule = schedule_class( + stages, + num_microbatches, + loss_fn=loss_fn, + scale_grads=False, + ) + if use_new_runtime: + old_schedule = schedule + schedule = _PipelineScheduleRuntime( + stages, + num_microbatches, + loss_fn=loss_fn, + ) + schedule._load_actions(old_schedule.pipeline_order) + + # Run + # TODO how to better specify .step() when first and last stage are on rank 0... + for _ in range(2): + # Zero gradients + for stage_module in stage_modules: + stage_module.zero_grad() + if self.rank == 0: + losses = [] + out = schedule.step(x, target=target, losses=losses) + else: + schedule.step() + + dist.barrier() + + # Last rank checks result + if self.rank == 0: + # Check output + torch.testing.assert_close(out, ref_out) + # Check loss + # Since the reduction used in the loss function above is "sum", we use + # "sum" here to reduce microbatch losses into a single value too. + pipe_loss = sum(losses) + torch.testing.assert_close(pipe_loss, ref_loss) + + # Every rank checks gradients + for stage_module, submod_name in zip(stage_modules, submod_names): + # Get corresponding submodule from reference model + ref_submod = ref_mod.get_submodule(submod_name) + # Check gradients per parameter + for name, p in stage_module.named_parameters(): + ref_p = ref_submod.get_parameter(name) + try: + torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5) + except AssertionError: + print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}") + raise + + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize("ScheduleClass", [ScheduleInterleavedZeroBubble]) def test_schedule_with_weight_update_mlp_e2e(self, ScheduleClass): stages_per_rank = 2 n_stages = stages_per_rank * self.world_size +<<<<<<< HEAD full_mod, ref_mod, x, target, _ = setup_models_and_data( self.config, n_layers=n_stages, model_class=MultiMLPWithDw ) @@ -636,6 +1366,44 @@ def test_schedule_with_weight_update_mlp_e2e(self, ScheduleClass): stages, stage_modules, submod_names = create_multi_stage_pipeline( self.config, full_mod, stages_per_rank, n_stages ) +======= + full_mod = MultiMLPWithDw(d_hid, n_layers=n_stages) + full_mod.to(self.device) + + ref_mod = copy.deepcopy(full_mod) + x = torch.randn(batch_size, d_hid, device=self.device) + with torch.no_grad(): + y = ref_mod(x) + # Add a small perturbation + target = y + torch.randn(batch_size, d_hid, device=self.device) + + ref_loss_fn = torch.nn.MSELoss(reduction="sum") + full_loss_fn = torch.nn.MSELoss(reduction="sum") + + full_mod.toggle() + + # Get a submodule, e.g. `layers.0` or `layers.1` + stage_indices = [ + self.rank + i * self.world_size for i in range(stages_per_rank) + ] + submod_names = [f"layers.{i}" for i in stage_indices] + stage_modules = [ + full_mod.get_submodule(submod_name) for submod_name in submod_names + ] + + # Run reference + for _ in range(2): + ref_stage_modules = [ + ref_mod.get_submodule(submod_name) for submod_name in submod_names + ] + for stage_module in ref_stage_modules: + stage_module.zero_grad() + + ref_mod.zero_grad() + ref_out = ref_mod(x) + ref_loss = ref_loss_fn(ref_out, target) + ref_loss.backward() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class CustomState: def __init__(self, stage_module, stage_idx, rank): @@ -646,6 +1414,10 @@ def __init__(self, stage_module, stage_idx, rank): def dw_builder(self): def dw_runner(): +<<<<<<< HEAD +======= + # This inner function would be called by PipelineStage during `backward_weight_one_chunk` +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.i += 1 print( f"[Rank {self.rank}] dw_count={self.i} stage={self.stage_idx}" @@ -654,6 +1426,7 @@ def dw_runner(): return dw_runner +<<<<<<< HEAD # Create custom states and rebuild stages with dw_builder cs = {} stage_indices = [ @@ -662,6 +1435,14 @@ def dw_runner(): for stage_module, stage_idx in zip(stage_modules, stage_indices): cs[stage_idx] = CustomState(stage_module, stage_idx, self.rank) +======= + cs = {} + for stage_module, stage_idx in zip(stage_modules, stage_indices): + cs[stage_idx] = CustomState(stage_module, stage_idx, self.rank) + + # Create a pipeline stage to wrap that submodule + chunks = 2 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) stages = [ PipelineStage( stage_module, @@ -673,6 +1454,7 @@ def dw_runner(): for stage_module, stage_idx in zip(stage_modules, stage_indices) ] +<<<<<<< HEAD schedule = ScheduleClass(stages, 2, loss_fn=loss_fn) # Run pipeline @@ -683,10 +1465,26 @@ def dw_runner(): if self.rank == 0: schedule.step(x) elif self.rank == self.world_size - 1: +======= + # Attach to a schedule + schedule = ScheduleClass( + stages, chunks, loss_fn=full_loss_fn, scale_grads=False + ) + + for _ in range(2): + # Zero gradients + for stage_module in stage_modules: + stage_module.zero_grad() + if self.rank == 0: + schedule.step(x) + elif self.rank == self.world_size - 1: + losses = [] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out = schedule.step(target=target, losses=losses) else: schedule.step() +<<<<<<< HEAD dist.barrier(device_ids=[self.rank]) # Verify results @@ -758,6 +1556,31 @@ def test_v_shape_schedules(self, schedule_class, use_new_runtime): @skip_but_pass_in_sandcastle_if( not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" ) +======= + dist.barrier() + # Last rank checks result + if self.rank == self.world_size - 1: + # Check output + torch.testing.assert_close(out, ref_out) + + # Check loss + # Since the reduction used in the loss function above is "sum", we use + # "sum" here to reduce microbatch losses into a single value too. + pipe_loss = sum(losses) + torch.testing.assert_close(pipe_loss, ref_loss) + + # Every rank checks gradients + for stage_module, submod_name in zip(stage_modules, submod_names): + # Get corresponding submodule from reference model + ref_submod = ref_mod.get_submodule(submod_name) + # Check gradients per parameter + for name, p in stage_module.named_parameters(): + ref_p = ref_submod.get_parameter(name) + torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5) + + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize( "ScheduleClass", [ScheduleInterleavedZeroBubble, ScheduleInterleaved1F1B], @@ -765,6 +1588,7 @@ def test_v_shape_schedules(self, schedule_class, use_new_runtime): def test_zero_bubble_with_model_kwargs(self, ScheduleClass): stages_per_rank = 2 n_stages = stages_per_rank * self.world_size +<<<<<<< HEAD mod, ref_mod, x, target, loss_fn = setup_models_and_data( self.config, n_layers=n_stages, model_class=MultiMLPKwargs ) @@ -780,6 +1604,55 @@ def test_zero_bubble_with_model_kwargs(self, ScheduleClass): self.config, mod, stages_per_rank, n_stages ) +======= + full_mod = MultiMLPKwargs(d_hid, n_layers=n_stages) + full_mod.to(self.device) + + ref_mod = copy.deepcopy(full_mod) + x = torch.randn(batch_size, d_hid, device=self.device) + unused_kwarg = torch.tensor([1.0], device=self.device) + + with torch.no_grad(): + y = ref_mod(x) + # Add a small perturbation + target = y + torch.randn(batch_size, d_hid, device=self.device) + + loss_fn = torch.nn.MSELoss(reduction="sum") + + # Get a submodule, e.g. `layers.0` or `layers.1` + stage_indices = [ + self.rank + i * self.world_size for i in range(stages_per_rank) + ] + submod_names = [f"layers.{i}" for i in stage_indices] + stage_modules = [ + full_mod.get_submodule(submod_name) for submod_name in submod_names + ] + # Run reference + for _ in range(2): + ref_stage_modules = [ + ref_mod.get_submodule(submod_name) for submod_name in submod_names + ] + for stage_module in ref_stage_modules: + stage_module.zero_grad() + + ref_mod.zero_grad() + ref_out = ref_mod(x, unused_kwarg=unused_kwarg) + ref_loss = loss_fn(ref_out, target) + ref_loss.backward() + + # Create a pipeline stage to wrap that submodule + stages = [ + PipelineStage( + stage_module, + stage_idx, + n_stages, + self.device, + ) + for stage_module, stage_idx in zip(stage_modules, stage_indices) + ] + + # Attach to a schedule +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) num_microbatches = ( ScheduleClass.num_microbatches if hasattr(ScheduleClass, "num_microbatches") @@ -789,11 +1662,18 @@ def test_zero_bubble_with_model_kwargs(self, ScheduleClass): stages, num_microbatches, loss_fn=loss_fn, scale_grads=False ) +<<<<<<< HEAD # Run pipeline with kwargs out = None losses = [] for _ in range(2): zero_gradients(stage_modules) +======= + for _ in range(2): + # Zero gradients + for stage_module in stage_modules: + stage_module.zero_grad() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.rank == 0: schedule.step( x, @@ -802,11 +1682,16 @@ def test_zero_bubble_with_model_kwargs(self, ScheduleClass): .expand(num_microbatches, -1), ) elif self.rank == self.world_size - 1: +<<<<<<< HEAD +======= + losses = [] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out = schedule.step(target=target, losses=losses) else: schedule.step() dist.barrier() +<<<<<<< HEAD # Verify results if self.rank == self.world_size - 1: @@ -818,11 +1703,37 @@ def test_zero_bubble_with_model_kwargs(self, ScheduleClass): check_gradients( self.config, stage_modules, ref_mod, submod_names, rtol=3e-5, atol=5e-3 ) +======= + # Last rank checks result + if self.rank == self.world_size - 1: + # Check output + torch.testing.assert_close(out, ref_out) + + # Check loss + pipe_loss = sum(losses) + torch.testing.assert_close(pipe_loss, ref_loss) + + # Every rank checks gradients + for stage_module, submod_name in zip(stage_modules, submod_names): + # Get corresponding submodule from reference model + ref_submod = ref_mod.get_submodule(submod_name) + # Check gradients per parameter + for name, p in stage_module.named_parameters(): + ref_p = ref_submod.get_parameter(name) + try: + torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=5e-3) + except AssertionError: + print( + f"Gradient test failed for {name}: {p.grad=} vs {ref_p.grad=}" + ) + raise +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) instantiate_parametrized_tests(ScheduleTest) +<<<<<<< HEAD class CustomSchedulesTest(MultiProcContinuousTest): """ These schedules are from the ScheduleRegistry and require world_size == 2 @@ -1025,5 +1936,7 @@ def test_schedule_with_native_zero_bubble(self, ScheduleClass): instantiate_parametrized_tests(CustomSchedulesTest) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/distributed/pipelining/test_stage.py b/test/distributed/pipelining/test_stage.py index 12c8d62037357..9ffd7cd92fbd0 100644 --- a/test/distributed/pipelining/test_stage.py +++ b/test/distributed/pipelining/test_stage.py @@ -14,15 +14,27 @@ ScheduleGPipe, ) from torch.distributed.pipelining._utils import PipeliningShapeError +<<<<<<< HEAD from torch.testing._internal.common_distributed import ( MultiProcContinuousTest, MultiProcessTestCase, requires_accelerator_dist_backend, +======= +from torch.testing._internal.common_cuda import TEST_MULTIGPU +from torch.testing._internal.common_distributed import ( + MultiProcContinousTest, + MultiProcessTestCase, + requires_nccl, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, run_tests, +<<<<<<< HEAD +======= + skip_but_pass_in_sandcastle, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) skip_but_pass_in_sandcastle_if, ) from torch.utils._pytree import tree_map_only @@ -32,9 +44,13 @@ batch_size = 256 chunks = 4 +<<<<<<< HEAD device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" backend = dist.get_default_backend_for_device(device_type) TEST_MULTIACCELERATOR = torch.accelerator.device_count() >= 2 +======= +device_type = "cuda" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.manual_seed(0) @@ -63,11 +79,19 @@ def f(x): return flatten_hook +<<<<<<< HEAD class StageTest(MultiProcContinuousTest): @classmethod def backend_str(cls) -> str: # Testing with NCCL backend return backend +======= +class StageTest(MultiProcContinousTest): + @classmethod + def backend_str(cls) -> str: + # Testing with NCCL backend + return "nccl" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @classmethod def device_type(cls) -> str: @@ -77,10 +101,15 @@ def device_type(cls) -> str: def device(self) -> torch.device: return torch.device(device_type, self.rank) +<<<<<<< HEAD @requires_accelerator_dist_backend(["nccl", "xccl"]) @skip_but_pass_in_sandcastle_if( not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" ) +======= + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize("ModelClass", [ExampleCode, MultiMLP]) def test_tracer(self, ModelClass): mod = ModelClass(d_hid, self.world_size) @@ -123,10 +152,15 @@ def _run_step(x): old_keys = mod.state_dict().keys() assert all(k in old_keys for k in submod_keys) +<<<<<<< HEAD @requires_accelerator_dist_backend(["nccl", "xccl"]) @skip_but_pass_in_sandcastle_if( not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" ) +======= + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize("ModelClass", [ModelWithKwargs]) def test_tracer_kwargs(self, ModelClass): mod = ModelClass(d_hid, self.world_size) @@ -174,10 +208,15 @@ def test_tracer_kwargs(self, ModelClass): old_keys = mod.state_dict().keys() assert all(k in old_keys for k in submod_keys) +<<<<<<< HEAD @requires_accelerator_dist_backend(["nccl", "xccl"]) @skip_but_pass_in_sandcastle_if( not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" ) +======= + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_manual(self): full_mod = MultiMLP(d_hid, n_layers=self.world_size) full_mod.to(self.device) @@ -208,10 +247,15 @@ def _run_step(x): ref_out = full_mod(x) torch.testing.assert_close(out, ref_out) +<<<<<<< HEAD @requires_accelerator_dist_backend(["nccl", "xccl"]) @skip_but_pass_in_sandcastle_if( not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" ) +======= + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_custom_dw_with_fb_schedule(self): """Tests that separate weight grad function 'dw_runner' gets run under a schedule that's only aware of F/B.""" full_mod = MultiMLP(d_hid, n_layers=self.world_size) @@ -270,10 +314,15 @@ def _run_step(x): ref_out = full_mod(x) torch.testing.assert_close(out, ref_out) +<<<<<<< HEAD @requires_accelerator_dist_backend(["nccl", "xccl"]) @skip_but_pass_in_sandcastle_if( not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" ) +======= + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_output_chunks_memory_usage(self): """Test that output_chunks doesn't store memory for non-first stages.""" full_mod = MultiMLP(d_hid, n_layers=self.world_size) @@ -357,17 +406,26 @@ def tearDown(self): def init_pg(self): store = dist.FileStore(self.file_name, self.world_size) dist.init_process_group( +<<<<<<< HEAD backend=backend, +======= + backend="nccl", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) store=store, rank=self.rank, world_size=self.world_size, device_id=self.device, ) +<<<<<<< HEAD @requires_accelerator_dist_backend(["nccl", "xccl"]) @skip_but_pass_in_sandcastle_if( not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" ) +======= + @requires_nccl() + @skip_but_pass_in_sandcastle("Flaky in CI") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_shape_prop_mismatch(self): """Tests shape prop errors are raised""" self.init_pg() @@ -414,10 +472,15 @@ def _run_step(x): with self.assertRaisesRegex(PipeliningShapeError, "dtype mismatch"): _run_step(x) +<<<<<<< HEAD @requires_accelerator_dist_backend(["nccl", "xccl"]) @skip_but_pass_in_sandcastle_if( not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" ) +======= + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_custom_dw_errors(self): """Tests expected errors are raised""" self.init_pg() @@ -433,7 +496,10 @@ def test_custom_dw_errors(self): self.device, dw_builder=lambda: None, ) +<<<<<<< HEAD stage_with_dw_builder._has_backward = True +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with self.assertRaisesRegex(AssertionError, "backward_one_chunk"): stage_with_dw_builder.backward_weight_one_chunk(bwd_chunk_id=0) diff --git a/test/distributed/pipelining/test_transformer.py b/test/distributed/pipelining/test_transformer.py index 20e830547de7b..2f740ff476b23 100644 --- a/test/distributed/pipelining/test_transformer.py +++ b/test/distributed/pipelining/test_transformer.py @@ -73,9 +73,13 @@ def get_layers(module): devices = ["cpu", "cuda", "hpu", "xpu"] +<<<<<<< HEAD instantiate_device_type_tests( TransformerTests, globals(), only_for=devices, allow_xpu=True ) +======= +instantiate_device_type_tests(TransformerTests, globals(), only_for=devices) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/distributed/pipelining/test_unflatten.py b/test/distributed/pipelining/test_unflatten.py index 0493f39b16cb8..742f22e5dd2e0 100644 --- a/test/distributed/pipelining/test_unflatten.py +++ b/test/distributed/pipelining/test_unflatten.py @@ -73,9 +73,13 @@ def test_unflatten(self, device): devices = ["cpu", "cuda", "hpu", "xpu"] +<<<<<<< HEAD instantiate_device_type_tests( UnflattenTests, globals(), only_for=devices, allow_xpu=True ) +======= +instantiate_device_type_tests(UnflattenTests, globals(), only_for=devices) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/distributed/rpc/test_faulty_agent.py b/test/distributed/rpc/test_faulty_agent.py index f9e9db18cce50..d0f4e615ba23f 100644 --- a/test/distributed/rpc/test_faulty_agent.py +++ b/test/distributed/rpc/test_faulty_agent.py @@ -22,7 +22,11 @@ # On CircleCI these tests are already run on CPU jobs, thus to save resources do +<<<<<<< HEAD # not run them on GPU jobs, since they wouldn't provide additional test signal. +======= +# not run them on GPU jobs, since thet wouldn't provide additional test signal. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not (IS_CI and torch.cuda.is_available()): globals().update( generate_tests( diff --git a/test/distributed/rpc/test_tensorpipe_agent.py b/test/distributed/rpc/test_tensorpipe_agent.py index e21460ba04c82..24ae29be9d074 100644 --- a/test/distributed/rpc/test_tensorpipe_agent.py +++ b/test/distributed/rpc/test_tensorpipe_agent.py @@ -23,7 +23,11 @@ # On CircleCI these tests are already run on CPU jobs, thus to save resources do +<<<<<<< HEAD # not run them on GPU jobs, since they wouldn't provide additional test signal. +======= +# not run them on GPU jobs, since thet wouldn't provide additional test signal. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not (IS_CI and torch.cuda.is_available()): globals().update( generate_tests( diff --git a/test/distributed/tensor/debug/test_comm_mode_features.py b/test/distributed/tensor/debug/test_comm_mode_features.py index 86b3849fda69a..702ac7f596912 100644 --- a/test/distributed/tensor/debug/test_comm_mode_features.py +++ b/test/distributed/tensor/debug/test_comm_mode_features.py @@ -11,7 +11,11 @@ parallelize_module, RowwiseParallel, ) +<<<<<<< HEAD from torch.testing._internal.common_utils import run_tests, skipIfHpu, TEST_XPU, xfailIf +======= +from torch.testing._internal.common_utils import run_tests, skipIfHpu +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, MLPModule, @@ -144,10 +148,17 @@ def test_MLPStacked_distributed_sharding_display(self): model2 = MLPStacked(self.device_type) parallelize_plan = { +<<<<<<< HEAD "layers.0.net1": ColwiseParallel(), "layers.0.net2": RowwiseParallel(), "layers.1.net1": ColwiseParallel(), "layers.1.net2": RowwiseParallel(), +======= + "MLPStacked.layers.0.net1": ColwiseParallel(), + "MLPStacked.layers.0.net2": RowwiseParallel(), + "MLPStacked.layers.1.net1": ColwiseParallel(), + "MLPStacked.layers.1.net2": RowwiseParallel(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } model2 = parallelize_module(model2, device_mesh, parallelize_plan) @@ -221,7 +232,10 @@ def test_MLP_module_tracing(self): @skipIfHpu @skip_unless_torch_gpu +<<<<<<< HEAD @xfailIf(TEST_XPU) # https://github.com/intel/torch-xpu-ops/issues/1555 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @with_comms def test_transformer_module_tracing(self, is_seq_parallel=False): """ diff --git a/test/distributed/tensor/experimental/test_local_map.py b/test/distributed/tensor/experimental/test_local_map.py index dad23226363ed..c14584ac5af29 100644 --- a/test/distributed/tensor/experimental/test_local_map.py +++ b/test/distributed/tensor/experimental/test_local_map.py @@ -1,5 +1,9 @@ # Copyright (c) Meta Platforms, Inc. and affiliates # Owner(s): ["oncall: distributed"] +<<<<<<< HEAD +======= +from functools import partial +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch.distributed._functional_collectives as funcol @@ -13,7 +17,10 @@ ) from torch.distributed.tensor.debug import CommDebugMode from torch.distributed.tensor.experimental import local_map +<<<<<<< HEAD from torch.testing._internal.common_distributed import skip_if_lt_x_gpu +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, @@ -49,7 +56,12 @@ def mm_allreduce_forward(device_mesh, A, B): return funcol.all_reduce(partial_sum_tensor, "sum", device_mesh).wait() +<<<<<<< HEAD @local_map( +======= +@partial( + local_map, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out_placements=replicate, in_placements=(None, col_wise, row_wise), ) @@ -88,7 +100,11 @@ def test_local_map_correctness(self): ) # row-wisely sharded W tensor # Test 1: use the function returned from calling local_map +<<<<<<< HEAD # get the function wrapped with DTensor/Tensor conversion +======= + # get the function wrapped with DTensor/Tensor convertion +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # mm_allreduce_forward is a function that applies to Tensors with manual collective # local_mm_allreduce_forward is the function that does the same but applies to # DTensors' `_local_tensor`. @@ -384,6 +400,7 @@ def test_local_map_with_grad_placement(self): ) self.assertEqual(W_dt.grad.full_tensor(), W.grad) +<<<<<<< HEAD @skip_if_lt_x_gpu(4) @with_comms def test_multi_mesh_inputs(self): @@ -426,6 +443,8 @@ def test_multi_mesh_inputs(self): # output lives in mesh_2d self.assertEqual(Y_dt.device_mesh, mesh_2d) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/distributed/tensor/experimental/test_tp_transform.py b/test/distributed/tensor/experimental/test_tp_transform.py index 2f52d9c18b2bc..96771047ef132 100644 --- a/test/distributed/tensor/experimental/test_tp_transform.py +++ b/test/distributed/tensor/experimental/test_tp_transform.py @@ -85,7 +85,11 @@ def test_tp_transform_with_uncovered_op(self): with torch.no_grad(): tp_res = tp_model(*inputs) self.assertEqual(res, tp_res) +<<<<<<< HEAD # Expect all_gather to be inserted to distributed sharded fc results +======= + # Expect all_gather to be inserted to distributed sharded fc resutls +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assert_has_c10d_ops( tp_exported_program.graph_module, { diff --git a/test/distributed/tensor/parallel/test_parallelize_api.py b/test/distributed/tensor/parallel/test_parallelize_api.py index 2ef70f1a447e3..c5a42abdb0cbd 100644 --- a/test/distributed/tensor/parallel/test_parallelize_api.py +++ b/test/distributed/tensor/parallel/test_parallelize_api.py @@ -33,7 +33,11 @@ def forward(self, x): class TensorParallelAPITests(DTensorTestBase): @property def world_size(self): +<<<<<<< HEAD gpu_num = torch.accelerator.device_count() +======= + gpu_num = torch.cuda.device_count() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return gpu_num if gpu_num % 2 == 0 and gpu_num > 4 else 4 def _compare_params( @@ -333,6 +337,7 @@ def test_parallelize_module_multi_wildcard(self): self._compare_module(model, model_tp, inp_size, rank0_only=False) @with_comms +<<<<<<< HEAD def test_parallelize_module_with_root_module(self): inp_size = [16, 10] model = MLPModule(self.device_type) @@ -376,6 +381,8 @@ def test_parallelize_module_with_no_match(self): self._compare_module(model, model_tp, inp_size, rank0_only=False) @with_comms +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_under_devicemesh_context(self): # test ColwiseParallel inp_size = [8, 10] @@ -400,8 +407,12 @@ def test_empty_plan(self): # Call parallelize_module with empty plan. # Goal is not to crash. device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +<<<<<<< HEAD with self.assertWarns(UserWarning): parallelize_module(model, device_mesh) +======= + parallelize_module(model, device_mesh) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": diff --git a/test/distributed/tensor/parallel/test_tp_examples.py b/test/distributed/tensor/parallel/test_tp_examples.py index 49d3d6a0c52d6..63436781861e2 100644 --- a/test/distributed/tensor/parallel/test_tp_examples.py +++ b/test/distributed/tensor/parallel/test_tp_examples.py @@ -27,7 +27,12 @@ RowwiseParallel, ) from torch.distributed.tensor.parallel.input_reshard import input_reshard +<<<<<<< HEAD from torch.testing._internal.common_device_type import skipXPUIf +======= +from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FUSED_ATTENTION +from torch.testing._internal.common_device_type import skipIf +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, @@ -282,7 +287,10 @@ def _thaw_params(thaw_params, model, model_tp): @skip_unless_torch_gpu @parametrize("is_seq_parallel", [True, False]) @parametrize("dtype", [torch.float64, torch.float32]) +<<<<<<< HEAD @skipXPUIf(True, "https://github.com/intel/torch-xpu-ops/issues/1555") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_transformer_training(self, is_seq_parallel, dtype: torch.dtype): EXP_BASE_CC = ExpCommCounts( fwd={all_reduce: 6, all_gather: 1}, bwd={all_reduce: 9} @@ -414,7 +422,11 @@ def test_transformer_training(self, is_seq_parallel, dtype: torch.dtype): + f"{str(dtype).split('.')[-1]}_" + f"thaw_{'__'.join(sorted({n.rpartition('.')[0].replace('.', '_') for n in thaw})) if thaw else 'all'}", ) +<<<<<<< HEAD @skipXPUIf(True, "https://github.com/intel/torch-xpu-ops/issues/1555") +======= + @skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_transformer_req_grad(self, thaw_params, is_seq_parallel, dtype, exp_cnts): # Sample a subset of `requires_grad` patterns diff --git a/test/distributed/tensor/parallel/test_tp_random_state.py b/test/distributed/tensor/parallel/test_tp_random_state.py index 490210517f517..7db0fc966447c 100644 --- a/test/distributed/tensor/parallel/test_tp_random_state.py +++ b/test/distributed/tensor/parallel/test_tp_random_state.py @@ -66,7 +66,11 @@ def test_model_init(self): # in the following way: # - within a tensor parallel group, the RNG is set with the same seed # - across data parallel groups, the RNG is set with different seeds +<<<<<<< HEAD torch.get_device_module(self.device_type).manual_seed(0) +======= + torch.cuda.manual_seed(dp_rank) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # disable/enable parallel RNG feature if random._rng_tracker: @@ -118,10 +122,21 @@ def tp_weights_assert(tensor1, tensor2): # compare local shards across TP groups def dp_weights_assert(tensor1, tensor2): +<<<<<<< HEAD # local weights shall be initialized the same across TP groups, # and it doesn't matter whether DTensor's RNG infra is activated since all spmd ranks # started with the same seed. self.assertEqual(tensor1, tensor2) +======= + if enable_distribute_flag: + # local weights shall be initialized the same across TP groups + self.assertEqual(tensor1, tensor2) + else: + # without the parallel RNG, weight initialization violates the TP setup: + # local weights are initialized differently across TP groups due to different + # random seeds set in data loading. + self.assertNotEqual(tensor1, tensor2) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.check_gathered_tensors( dp_rank, dp_size, tensor_gather, dp_weights_assert diff --git a/test/distributed/tensor/test_api.py b/test/distributed/tensor/test_api.py index a4efd6d5b6bed..ad1b0dc4032b3 100644 --- a/test/distributed/tensor/test_api.py +++ b/test/distributed/tensor/test_api.py @@ -48,7 +48,11 @@ def world_size(self) -> int: def test_distribute_tensor_rank(self): comm_mode = CommDebugMode() +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shard_spec = [Shard(0)] for requires_grad in [True, False]: @@ -134,7 +138,11 @@ def test_distribute_tensor_errors(self): @with_comms def test_distribute_tensor_uneven_sharding(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) input_sizes_and_shard_dims = [ ((self.world_size * 3 + 1, 3, 3), 0), ((self.world_size * 3 + 2, 3, 3), 0), @@ -156,7 +164,11 @@ def test_distribute_tensor_uneven_sharding(self): @with_comms def test_distribute_module(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # fully shard all linear modules on dim 0 module_to_shard = MyModel(5 * self.world_size, 20, device=self.device_type) shard_spec = [Shard(0)] @@ -219,7 +231,11 @@ def shard_fn(name, module, device_mesh): @with_comms def test_distribute_module_input_fn_output_fn(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # fully replicate all linear modules module_to_replicate = MyModel(20, 1, device=self.device_type) @@ -264,7 +280,11 @@ def replicate_input_fn(mod, inputs, device_mesh): @with_comms def test_distribute_module_input_fn_output_fn_warning(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # fully replicate all linear modules module_to_replicate = MyModel(20, 1, device=self.device_type) @@ -292,7 +312,11 @@ def output_fn(outputs, device_mesh): @with_comms def test_distribute_module_casting(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # check DTensor casting dt = DTensor.from_local(torch.rand(10), device_mesh, [Replicate()]) @@ -335,7 +359,11 @@ def test_distribute_module_casting(self): def test_distribute_module_meta(self): # If the model is too big, the user may first the create entire model on the meta device and then initialize # it on the device in the partition function. +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # fully shard all parameters on dim 0 module_to_shard = MyModel(5 * self.world_size, 20, device="meta") diff --git a/test/distributed/tensor/test_attention.py b/test/distributed/tensor/test_attention.py index a2543d443e4fe..5b8f398314390 100644 --- a/test/distributed/tensor/test_attention.py +++ b/test/distributed/tensor/test_attention.py @@ -1,16 +1,24 @@ # Copyright (c) Meta Platforms, Inc. and affiliates # Owner(s): ["oncall: distributed"] +<<<<<<< HEAD import functools import itertools import random import unittest from typing import Union +======= +import unittest +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch.distributed as dist import torch.nn.functional as F +<<<<<<< HEAD from torch import nn, Tensor from torch.distributed.device_mesh import init_device_mesh +======= +from torch import nn +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.distributed.tensor import DeviceMesh from torch.distributed.tensor.debug import CommDebugMode from torch.distributed.tensor.experimental._attention import ( @@ -26,11 +34,14 @@ ) from torch.distributed.tensor.parallel import parallelize_module from torch.nn.attention import sdpa_kernel, SDPBackend +<<<<<<< HEAD from torch.nn.attention.flex_attention import ( _mask_mod_signature, create_block_mask, flex_attention, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_cuda import ( PLATFORM_SUPPORTS_CUDNN_ATTENTION, PLATFORM_SUPPORTS_FLASH_ATTENTION, @@ -447,6 +458,7 @@ def _test_ring_attention_custom_transformer(self, rotater: _RotateMethod) -> Non ) +<<<<<<< HEAD # Compile the flex_attention function compiled_flex_attention = torch.compile(flex_attention, dynamic=False, fullgraph=True) compiled_create_block_mask = torch.compile( @@ -745,5 +757,7 @@ def test_ring_flex_attention_document_mask(self) -> None: test_func() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/distributed/tensor/test_common_rules.py b/test/distributed/tensor/test_common_rules.py index 3450f8faa2b5c..abcd910b1c686 100644 --- a/test/distributed/tensor/test_common_rules.py +++ b/test/distributed/tensor/test_common_rules.py @@ -8,17 +8,31 @@ from torch.distributed.tensor._ops._common_rules import einop_rule, pointwise_rule from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( +<<<<<<< HEAD DTensorContinuousTestBase, +======= + DTensorTestBase, + with_comms, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) aten = torch.ops.aten +<<<<<<< HEAD class CommonRulesTest(DTensorContinuousTestBase): # hard code world size to 4 as we need to test # at least with 2d mesh world_size = 4 +======= +class CommonRulesTest(DTensorTestBase): + @property + def world_size(self) -> int: + # hard code world size to 4 as we need to test + # at least with 2d mesh + return 4 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _gen_tensor_meta(self, shape): empty_tensor = torch.empty(shape) @@ -28,9 +42,16 @@ def _gen_tensor_meta(self, shape): empty_tensor.dtype, ) +<<<<<<< HEAD def test_einop_basic_propagation(self): # plain einsum, mm mesh = DeviceMesh(self.device_type(), torch.arange(self.world_size)) +======= + @with_comms + def test_einop_basic_propagation(self): + # plain einsum, mm + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mm_call = aten.mm.default # propagate col-wise sharding @@ -81,8 +102,14 @@ def test_einop_basic_propagation(self): self.assertIsNotNone(output_spec) self.assertTrue(output_spec.placements[0].is_partial()) +<<<<<<< HEAD def test_einop_pointwise_propagation(self): mesh = DeviceMesh(self.device_type(), torch.arange(self.world_size)) +======= + @with_comms + def test_einop_pointwise_propagation(self): + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) add_call = aten.add.Tensor # addition @@ -132,12 +159,20 @@ def test_einop_pointwise_propagation(self): self.assertIsNotNone(output_spec) self.assertEqual(output_spec.dim_map, [0, -1, -1]) +<<<<<<< HEAD +======= + @with_comms +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_einop_merge_sharding(self): # 2d mesh einop merge sharding mesh_shape = torch.arange(self.world_size).reshape( self.world_size // 2, self.world_size // 2 ) +<<<<<<< HEAD mesh = DeviceMesh(self.device_type(), mesh_shape) +======= + mesh = DeviceMesh(self.device_type, mesh_shape) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mm_call = aten.mm.default @@ -157,11 +192,19 @@ def test_einop_merge_sharding(self): self.assertIsNotNone(output_spec) self.assertEqual(output_spec.dim_map, [0, 1]) +<<<<<<< HEAD +======= + @with_comms +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_einop_linearity(self): mesh_shape = torch.arange(self.world_size).reshape( self.world_size // 2, self.world_size // 2 ) +<<<<<<< HEAD mesh = DeviceMesh(self.device_type(), mesh_shape) +======= + mesh = DeviceMesh(self.device_type, mesh_shape) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mm_call = aten.mm.default @@ -224,10 +267,18 @@ def test_einop_linearity(self): # mat2 mesh dim 1 should become partial now! self.assertTrue(mat2_spec.placements[1].is_partial()) +<<<<<<< HEAD def test_einop_multi_sharding_on_mesh_dim(self): # einop prop with multi sharding on same mesh dim mesh_shape = torch.arange(self.world_size) mesh = DeviceMesh(self.device_type(), mesh_shape) +======= + @with_comms + def test_einop_multi_sharding_on_mesh_dim(self): + # einop prop with multi sharding on same mesh dim + mesh_shape = torch.arange(self.world_size) + mesh = DeviceMesh(self.device_type, mesh_shape) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mm_call = aten.mm.default mat1, mat2 = [0, -1], [0, -1] @@ -252,11 +303,19 @@ def test_einop_multi_sharding_on_mesh_dim(self): self.assertEqual(schema_suggestion.args_schema[0].dim_map, [0, -1]) self.assertEqual(schema_suggestion.args_schema[1].dim_map, [-1, -1]) +<<<<<<< HEAD +======= + @with_comms +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_einop_errors(self): mesh_shape = torch.arange(self.world_size).reshape( self.world_size // 2, self.world_size // 2 ) +<<<<<<< HEAD mesh = DeviceMesh(self.device_type(), mesh_shape) +======= + mesh = DeviceMesh(self.device_type, mesh_shape) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) add_call = aten.add.Tensor mat1, mat2 = [0, -1], [1, -1] @@ -272,8 +331,14 @@ def test_einop_errors(self): with self.assertRaisesRegex(RuntimeError, "sharded two different ways:"): einop_rule("ij,ij->ij", OpSchema(add_call, (mat1_spec, mat2_spec), {})) +<<<<<<< HEAD def test_pointwise_rules_broadcasting(self): mesh = DeviceMesh(self.device_type(), torch.arange(self.world_size)) +======= + @with_comms + def test_pointwise_rules_broadcasting(self): + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) where_call = aten.where.self inp1, inp2, inp3 = [0], [], [-1, -1] @@ -297,8 +362,14 @@ def test_pointwise_rules_broadcasting(self): self.assertIsNotNone(output_spec) self.assertEqual(output_spec.dim_map, [-1, 0]) +<<<<<<< HEAD def test_pointwise_rules_suggestion(self): mesh = DeviceMesh(self.device_type(), torch.arange(self.world_size)) +======= + @with_comms + def test_pointwise_rules_suggestion(self): + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) lerp_call = aten.lerp.Scalar # propagate point-wise sharding @@ -324,12 +395,20 @@ def test_pointwise_rules_suggestion(self): self.assertEqual(len(schema_suggestion.args_schema), 3) self.assertEqual(schema_suggestion.args_schema[2], -1) +<<<<<<< HEAD +======= + @with_comms +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_pointwise_multi_sharding_on_mesh_dim(self): # 2d mesh pointwise sharding mesh_shape = torch.arange(self.world_size).reshape( self.world_size // 2, self.world_size // 2 ) +<<<<<<< HEAD mesh = DeviceMesh(self.device_type(), mesh_shape) +======= + mesh = DeviceMesh(self.device_type, mesh_shape) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) add_call = aten.add.Tensor @@ -369,12 +448,20 @@ def test_pointwise_multi_sharding_on_mesh_dim(self): self.assertEqual(schema_suggestion.args_schema[0].dim_map, [-1, -1, -1, 1]) self.assertEqual(schema_suggestion.args_schema[1].dim_map, mat2) +<<<<<<< HEAD +======= + @with_comms +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_pointwise_enforce_sharding_multi_sharding_on_mesh_dim(self): # 2d mesh pointwise sharding mesh_shape = torch.arange(self.world_size).reshape( self.world_size // 2, self.world_size // 2 ) +<<<<<<< HEAD mesh = DeviceMesh(self.device_type(), mesh_shape) +======= + mesh = DeviceMesh(self.device_type, mesh_shape) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) add_call = aten.add_.Tensor diff --git a/test/distributed/tensor/test_convolution_ops.py b/test/distributed/tensor/test_convolution_ops.py index d249a6d2ff772..9d325a7f3db75 100644 --- a/test/distributed/tensor/test_convolution_ops.py +++ b/test/distributed/tensor/test_convolution_ops.py @@ -5,7 +5,11 @@ import torch import torch.nn as nn +<<<<<<< HEAD from torch.distributed import DeviceMesh +======= +from torch.distributed import DeviceMesh, init_device_mesh +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.distributed.tensor import ( distribute_module, distribute_tensor, @@ -48,7 +52,11 @@ def world_size(self) -> int: @with_comms def test_downsampling_convolution(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shard_spec = [Shard(3)] input_list = torch.rand(ITER_TIME, 7, 3, 512, 1024) @@ -118,7 +126,11 @@ def test_downsampling_convolution(self): @with_comms @skip_if_lt_x_gpu(2) def test_depthwise_convolution(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shard_spec = [Shard(3)] input_list = torch.rand(ITER_TIME, 7, 256, 128, 256) @@ -186,7 +198,13 @@ def test_depthwise_convolution(self): @with_comms @skip_if_lt_x_gpu(2) def test_conv_backward_none_grad_inp(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = init_device_mesh( + device_type=self.device_type, mesh_shape=(self.world_size,) + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) conv = nn.Conv2d(64, 64, 3, padding=1).train() x = torch.randn(1, 64, 32, 32) x_dt = DTensor.from_local(x, device_mesh, [Replicate()]) diff --git a/test/distributed/tensor/test_dtensor.py b/test/distributed/tensor/test_dtensor.py index f5ddb1a4222c6..1ec4d78f98526 100644 --- a/test/distributed/tensor/test_dtensor.py +++ b/test/distributed/tensor/test_dtensor.py @@ -11,6 +11,10 @@ import torch import torch.nn.functional as F from torch.distributed._functional_collectives import AsyncCollectiveTensor +<<<<<<< HEAD +======= +from torch.distributed.device_mesh import init_device_mesh +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.distributed.tensor import ( DeviceMesh, distribute_tensor, @@ -60,7 +64,11 @@ def reset_parameters(self, *args, **kwargs): class DTensorTest(DTensorTestBase): @with_comms def test_dtensor_constructor(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) placements = [Shard(0)] local_tensor = torch.randn(3, 3, requires_grad=True) @@ -148,7 +156,11 @@ def test_modules_w_meta_dtensor(self): @with_comms def test_dtensor_stride(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shard0_spec = [Shard(0)] local_tensor = torch.randn(4, 8) dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard0_spec) @@ -171,7 +183,11 @@ def test_dtensor_stride(self): @with_comms def test_from_local(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) placements = [Shard(0)] local_tensor = torch.randn(3, 3) sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements) @@ -208,7 +224,12 @@ def test_from_local(self): @with_comms def test_from_local_uneven_sharding(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + mesh_shape = (self.world_size,) + device_mesh = init_device_mesh(self.device_type, mesh_shape) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uneven_dim0_size = self.world_size + 1 global_tensor = torch.randn(uneven_dim0_size, 2) @@ -233,7 +254,12 @@ def test_from_local_uneven_sharding(self): @with_comms def test_from_local_uneven_sharding_raise_error(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + mesh_shape = (self.world_size,) + device_mesh = init_device_mesh(self.device_type, mesh_shape) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uneven_dim0_size = self.world_size + 1 global_tensor = torch.randn(uneven_dim0_size, 2) @@ -267,7 +293,11 @@ def test_from_local_uneven_sharding_raise_error(self): @with_comms def test_from_local_negative_dim(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) placements = [Shard(-1)] local_tensor = torch.randn(3, 3) sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements) @@ -275,7 +305,11 @@ def test_from_local_negative_dim(self): @with_comms def test_to_local(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) placements = (Shard(0),) local_tensor_with_grad = torch.randn( 3, 3, device=self.device_type, requires_grad=True @@ -335,7 +369,11 @@ def test_to_local(self): @with_comms def test_to_local_grad_hint(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) placements = (Shard(0),) global_tensor = torch.ones(8, 3, requires_grad=True) @@ -360,7 +398,11 @@ def test_to_local_grad_hint(self): @with_comms def test_full_tensor_sync(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) placements = (Shard(0),) global_tensor = torch.ones(8, 3, requires_grad=True) @@ -371,7 +413,11 @@ def test_full_tensor_sync(self): @with_comms def test_full_tensor_grad_hint(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) placements = (Shard(0),) global_tensor = torch.ones(8, 3, requires_grad=True) @@ -384,7 +430,11 @@ def test_full_tensor_grad_hint(self): @with_comms def test_dtensor_new_empty_strided(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) local_tensor = torch.randn(8, 8, requires_grad=True, device=self.device_type) my_dtensor = distribute_tensor(local_tensor, device_mesh, [Shard(0)]) new_strided_dtensor = my_dtensor.new_empty_strided( @@ -410,7 +460,11 @@ def test_dtensor_async_output(self): # Tests that if the output of some dtensor operations isn't used in any compute, # the output should be an AsyncCollectiveTensor (representing the fact that # we haven't synced the collective yet). +<<<<<<< HEAD mesh = self.build_device_mesh() +======= + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def fn(dt): dt_out_redistribute = dt.redistribute(mesh, [Replicate()], async_op=True) @@ -435,7 +489,11 @@ def fn(dt): self.assertEqual(type(out_view), AsyncCollectiveTensor) self.assertFalse(out.completed) +<<<<<<< HEAD # Use the data, requiring a sync +======= + # Use the daa, requiring a sync +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ref = torch.ones((4, 2), device=self.device_type) + 1 ref = ref.view(-1) out_data = out_view + 1 @@ -450,7 +508,11 @@ def fn(dt): @with_comms def test_from_local_then_to_local(self): # this test ensure end to end from torch.Tensor -> dist tensor -> torch.Tensor works +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) placements = [Shard(0)] # step 1. construct from construct local tensor @@ -482,7 +544,11 @@ def test_from_local_then_to_local(self): @with_comms def test_dtensor_spec_read_only_after_set(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) placements = [Shard(0)] local_tensor = torch.randn(3, 3) sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements) @@ -494,7 +560,11 @@ def test_dtensor_spec_read_only_after_set(self): @with_comms def test_dtensor_spec_hash(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) placements = [Shard(0)] local_tensor = torch.randn(3, 3) local_tensor2 = torch.randn(3, 3) @@ -514,7 +584,11 @@ def test_dtensor_spec_hash(self): @with_comms def test_dtensor_properties(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) placements = [Shard(0)] local_tensor = torch.randn(3, 3) sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements) @@ -568,7 +642,11 @@ def test_dtensor_save_load_import(self): @with_comms def test_shard_tensor(self): ws = self.world_size +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(ws))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) full_tensor = torch.arange(ws * ws).reshape(ws, ws) # Shard by row @@ -619,7 +697,11 @@ def sub_mesh_assert_equal(self, mesh, exp_in_mesh, exp_out_of_mesh, tensor): @with_comms def test_dtensor_device_mesh_device_conversion(self): # construct a cuda device mesh +<<<<<<< HEAD mesh = self.build_device_mesh() +======= + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # construct from a cpu local tensor with cuda device mesh # should automatically convert the dist tensor to cuda @@ -631,14 +713,22 @@ def test_dtensor_device_mesh_device_conversion(self): @with_comms def test_dtensor_api_device_mesh_context_manager(self): +<<<<<<< HEAD with self.build_device_mesh() as mesh: +======= + with DeviceMesh(self.device_type, list(range(self.world_size))) as mesh: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) placements = [Shard(0)] local_tensor = torch.randn(3, 3) sharded_tensor = DTensor.from_local( local_tensor, device_mesh=mesh, placements=placements ) +<<<<<<< HEAD with self.build_device_mesh(): +======= + with DeviceMesh(self.device_type, list(range(self.world_size))): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) placements = [Shard(0)] local_tensor = torch.randn(3, 3) sharded_tensor = DTensor.from_local(local_tensor, placements=placements) @@ -648,7 +738,11 @@ def test_dtensor_api_device_mesh_context_manager(self): replica_tensor.size(), torch.Size([3 * self.world_size, 3]) ) +<<<<<<< HEAD with self.build_device_mesh(): +======= + with DeviceMesh(self.device_type, torch.arange(self.world_size)): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) placements = [Shard(0)] global_shape = torch.Size([3 * self.world_size, 3]) global_tensor = torch.randn(global_shape) @@ -834,7 +928,11 @@ def test_redistribute_sub_mesh(self): @with_comms def test_implicit_replication(self): +<<<<<<< HEAD mesh = self.build_device_mesh() +======= + mesh = init_device_mesh(self.device_type, (self.world_size,)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) local_tensor1 = torch.ones(4, 3) sharded_dtensor = DTensor.from_local(local_tensor1, mesh, [Shard(0)]) @@ -849,6 +947,7 @@ def test_implicit_replication(self): self.assertEqual(local_shard, torch.ones(4, 3) + torch.ones(3)) @with_comms +<<<<<<< HEAD def test_vmap_embedding(self): mesh = self.build_device_mesh() batch_size, seq_len = 2, 6 @@ -875,6 +974,10 @@ def test_vmap_embedding(self): @with_comms def test_auto_implicit_replication(self): mesh = self.build_device_mesh() +======= + def test_auto_implicit_replication(self): + mesh = init_device_mesh(self.device_type, (self.world_size,)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) local_tensor = torch.ones(self.world_size, 3, device=self.device_type) sharded_dtensor = DTensor.from_local(local_tensor, mesh, [Shard(0)]) @@ -900,7 +1003,11 @@ def add_scalar_tensor_with_dtensor(): @with_comms def test_metadata_consistency_check(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) placements = [Shard(0)] # Create a local tensor with specific metadata and check dtype change @@ -962,7 +1069,11 @@ def _create_tensor(self, size): @with_comms def test_split_tensor_1D(self) -> None: +<<<<<<< HEAD mesh = self.build_device_mesh() +======= + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shard_placement = Shard(0) for size in range(8): diff --git a/test/distributed/tensor/test_dtensor_compile.py b/test/distributed/tensor/test_dtensor_compile.py index 15e3daf6b9413..3e43b164a0509 100644 --- a/test/distributed/tensor/test_dtensor_compile.py +++ b/test/distributed/tensor/test_dtensor_compile.py @@ -20,6 +20,7 @@ ) from torch.distributed.device_mesh import init_device_mesh from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +<<<<<<< HEAD from torch.distributed.tensor import ( DeviceMesh, distribute_module, @@ -33,6 +34,12 @@ from torch.distributed.tensor.parallel import ( ColwiseParallel, loss_parallel, +======= +from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor.parallel import ( + ColwiseParallel, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) parallelize_module, PrepareModuleInput, PrepareModuleOutput, @@ -96,6 +103,7 @@ def extract_graph(fx_g, _, graph_cell): ) +<<<<<<< HEAD def _apply_sharding(mod: nn.Module, shard_dim: int, device_mesh: DeviceMesh): """ Shards on the given dimension if possible, else replicate @@ -123,6 +131,8 @@ def shard_module_params(name, module, device_mesh): return sharded_mod +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestDTensorCompile(torch._dynamo.test_case.TestCase): def setUp(self): super( @@ -202,8 +212,11 @@ def forward(self, b_buffer, x): return (view_as_1,)""", # noqa: B950 ) +<<<<<<< HEAD # During tracing, sharding propagation cache is skipped, so an extra dry run for # add is performed in _propagate_tensor_meta_non_cached, hence add_1 instead of add +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertExpectedInline( str(ep.run_decompositions({}).graph_module.code).strip(), """\ @@ -211,8 +224,13 @@ def forward(self, b_parametrizations_buffer_original0, x): _assert_tensor_metadata = torch.ops.aten._assert_tensor_metadata.default(x, None, None, torch.float64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata = None _to_copy = torch.ops.aten._to_copy.default(x, dtype = torch.float64, layout = torch.strided, device = device(type='cuda', index=0)); x = None view = torch.ops.aten.view.default(_to_copy, [4, 4]); _to_copy = None +<<<<<<< HEAD add_1 = torch.ops.aten.add.Tensor(b_parametrizations_buffer_original0, view); b_parametrizations_buffer_original0 = view = None view_1 = torch.ops.aten.view.default(add_1, [4, 4]); add_1 = None +======= + add = torch.ops.aten.add.Tensor(b_parametrizations_buffer_original0, view); b_parametrizations_buffer_original0 = view = None + view_1 = torch.ops.aten.view.default(add, [4, 4]); add = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return (view_1,)""", # noqa: B950 ) @@ -258,7 +276,11 @@ def fn(x: DeviceMesh): group1 = x.get_group(mesh_dim=1) return size, coord, group0, group1 +<<<<<<< HEAD # Can't be fullgraph=True because ProcessGroup is not reconstructible in dynamo +======= + # Cant be fullgraph=True because ProcessGroup is not reconstructible in dynamo +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) compiled_fn = torch.compile(backend="aot_eager")(fn) mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).unsqueeze(1)) @@ -306,9 +328,13 @@ def fn(x): .to_local()[0] ) +<<<<<<< HEAD x = DTensor.from_local( torch.rand(4, 4, requires_grad=True), mesh, [Shard(0)], run_check=False ) +======= + x = DTensor.from_local(torch.rand(4, 4), mesh, [Shard(0)], run_check=False) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch._dynamo.mark_dynamic(x, 0) ref = fn(x) @@ -316,6 +342,7 @@ def fn(x): res = opt_fn(x) self.assertEqual(res, ref) +<<<<<<< HEAD @skipIfHpu def test_dtensor_dynamic_slice(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) @@ -382,6 +409,8 @@ def fn(x, y): res = opt_fn(x, y) self.assertEqual(res, ref) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_dtensor_attribute_access_on_intermediate(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) @@ -1215,6 +1244,7 @@ def fn(x, y): self.assertEqual(x_ref.grad, x.grad) self.assertEqual(y_ref.grad, y.grad) +<<<<<<< HEAD @with_comms def test_compile_embedding_redistribute(self): mesh = self.build_device_mesh() @@ -1238,6 +1268,8 @@ def forward(self, x): output = sharded_net(replicated_inp) self.assertEqual(output.full_tensor(), ref_out) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/distributed/tensor/test_dtensor_ops.py b/test/distributed/tensor/test_dtensor_ops.py index 8c650f6b0ce02..01aae3124acc1 100644 --- a/test/distributed/tensor/test_dtensor_ops.py +++ b/test/distributed/tensor/test_dtensor_ops.py @@ -103,6 +103,10 @@ def wrapped(fn): xfail("arange"), xfail("argmax"), xfail("argmin"), +<<<<<<< HEAD +======= + xfail("argsort"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) xfail("as_strided"), xfail("as_strided", "partial_views"), xfail("as_strided_copy"), @@ -118,8 +122,17 @@ def wrapped(fn): xfail("cholesky_inverse"), xfail("cholesky_solve"), xfail("chunk"), +<<<<<<< HEAD xfail("combinations"), xfail("complex"), +======= + xfail("clamp"), + xfail("clamp_max"), + xfail("clamp_min"), + xfail("combinations"), + xfail("complex"), + xfail("constant_pad_nd"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) xfail("count_nonzero"), xfail("cross"), xfail("cummax"), @@ -160,11 +173,19 @@ def wrapped(fn): xfail("frexp"), xfail("full"), xfail("full_like"), +<<<<<<< HEAD +======= + xfail("gather"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) xfail("geometric"), xfail("geqrf"), xfail("grid_sampler_2d"), xfail("gradient"), xfail("heaviside"), +<<<<<<< HEAD +======= + xfail("histc"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) xfail("histogram"), xfail("histogramdd"), xfail("index_add"), @@ -233,6 +254,10 @@ def wrapped(fn): xfail("median"), xfail("min", "reduction_with_dim"), xfail("mode"), +<<<<<<< HEAD +======= + xfail("msort"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) xfail("multinomial"), xfail("mv"), xfail("max_pool2d_with_indices_backward", ""), @@ -241,6 +266,10 @@ def wrapped(fn): xfail("nanquantile"), xfail("nansum"), xfail("native_batch_norm"), +<<<<<<< HEAD +======= + xfail("native_dropout_backward"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) xfail("narrow_copy"), xfail("ne"), xfail("new_empty"), @@ -289,6 +318,10 @@ def wrapped(fn): xfail("nn.functional.interpolate", "nearest"), xfail("nn.functional.interpolate", "nearest-exact"), xfail("nn.functional.leaky_relu"), +<<<<<<< HEAD +======= + xfail("nn.functional.linear"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) xfail("nn.functional.local_response_norm"), xfail("nn.functional.logsigmoid"), xfail("nn.functional.margin_ranking_loss"), @@ -304,8 +337,16 @@ def wrapped(fn): xfail("nn.functional.mish"), xfail("nn.functional.mse_loss"), xfail("nn.functional.multi_margin_loss"), +<<<<<<< HEAD + xfail("nn.functional.multilabel_margin_loss"), + xfail("nn.functional.multilabel_soft_margin_loss"), +======= + xfail("nn.functional.multi_head_attention_forward"), xfail("nn.functional.multilabel_margin_loss"), xfail("nn.functional.multilabel_soft_margin_loss"), + xfail("nn.functional.normalize"), + xfail("nn.functional.pad", "constant"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) xfail("nn.functional.pad", "reflect"), xfail("nn.functional.pad", "replicate"), xfail("nn.functional.pad", "replicate_negative"), @@ -354,6 +395,10 @@ def wrapped(fn): xfail("rot90"), xfail("rsub"), xfail("scalar_tensor"), +<<<<<<< HEAD +======= + xfail("scatter_add"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) xfail("scatter_reduce", "amax"), xfail("scatter_reduce", "amin"), xfail("scatter_reduce", "mean"), @@ -361,6 +406,10 @@ def wrapped(fn): xfail("scatter_reduce", "sum"), xfail("searchsorted"), xfail("select_scatter"), +<<<<<<< HEAD +======= + xfail("sort"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) xfail("sparse.sampled_addmm"), xfail("sparse.mm", "reduce"), xfail("special.airy_ai"), @@ -370,8 +419,11 @@ def wrapped(fn): xfail("special.bessel_y1"), xfail("special.chebyshev_polynomial_t"), xfail("special.chebyshev_polynomial_u"), +<<<<<<< HEAD xfail("special.chebyshev_polynomial_v"), xfail("special.chebyshev_polynomial_w"), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) xfail("special.entr"), xfail("special.erfcx"), xfail("special.hermite_polynomial_h"), @@ -380,7 +432,10 @@ def wrapped(fn): xfail("special.i1"), xfail("special.i1e"), xfail("special.laguerre_polynomial_l"), +<<<<<<< HEAD xfail("special.legendre_polynomial_p"), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) xfail("special.log_ndtr"), xfail("special.modified_bessel_i0"), xfail("special.modified_bessel_i1"), @@ -389,10 +444,13 @@ def wrapped(fn): xfail("special.ndtri"), xfail("special.scaled_modified_bessel_k0"), xfail("special.scaled_modified_bessel_k1"), +<<<<<<< HEAD xfail("special.shifted_chebyshev_polynomial_t"), xfail("special.shifted_chebyshev_polynomial_u"), xfail("special.shifted_chebyshev_polynomial_v"), xfail("special.shifted_chebyshev_polynomial_w"), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) xfail("special.spherical_bessel_j0"), xfail("special.xlog1py"), xfail("special.zeta"), @@ -507,17 +565,31 @@ class TestDTensorOps(DTensorOpTestBase): def world_size(self) -> int: return OP_DB_WORLD_SIZE +<<<<<<< HEAD def run_opinfo_test( self, dtype, op, requires_grad=True, sample_inputs_filter=lambda s: True ): +======= + # only allow float dytpe for now, we can relax this constraint + # when feel necessary later (i.e when adding quantization support). + @suppress_warnings + @ops(op_db, allowed_dtypes=(torch.float,)) + @skipOps("TestDTensorOps", "test_dtensor_op_db", dtensor_fails) + def test_dtensor_op_db(self, dtype, op): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.mesh = DeviceMesh(DEVICE_TYPE, torch.arange(self.world_size)) # test each op with dist tensor inputs and normal inputs def test(): +<<<<<<< HEAD samples = op.sample_inputs(DEVICE_TYPE, dtype, requires_grad=requires_grad) for sample_input in samples: if not sample_inputs_filter(sample_input): continue +======= + samples = op.sample_inputs(DEVICE_TYPE, dtype, requires_grad=True) + for sample_input in samples: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) args = [sample_input.input] + list(sample_input.args) kwargs = sample_input.kwargs @@ -530,6 +602,7 @@ def test(): self.check_dtensor_func(test, op) +<<<<<<< HEAD # only allow float dytpe for now, we can relax this constraint # when feel necessary later (i.e when adding quantization support). @suppress_warnings @@ -538,6 +611,8 @@ def test(): def test_dtensor_op_db(self, dtype, op): self.run_opinfo_test(dtype, op) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def assert_ref_dtensor_equal(self, dtensor_rs, rs): flat_dtensor_rs = pytree.tree_leaves(dtensor_rs) flat_rs = pytree.tree_leaves(rs) @@ -651,6 +726,7 @@ def check_dtensor_func(self, test_func, opinfo, dry_run=False): else: print(f"xfail('{opinfo.name}'),") +<<<<<<< HEAD def test_one_hot(self): ops = [op for op in op_db if op.name == "nn.functional.one_hot"] assert len(ops) == 1 @@ -663,6 +739,8 @@ def test_one_hot(self): sample_inputs_filter=lambda s: s.kwargs["num_classes"] != -1, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # only instantiate tests for DEVICE_TYPE alone (i.e. either CPU or GPU) instantiate_device_type_tests(TestDTensorOps, globals(), only_for=(DEVICE_TYPE,)) diff --git a/test/distributed/tensor/test_embedding_ops.py b/test/distributed/tensor/test_embedding_ops.py index eabd4a55470e1..fa38371c6da71 100644 --- a/test/distributed/tensor/test_embedding_ops.py +++ b/test/distributed/tensor/test_embedding_ops.py @@ -193,7 +193,11 @@ def test_multiple_embeddings_rowwise(self): from torch.distributed.tensor._ops._embedding_ops import _MaskPartial +<<<<<<< HEAD # case 1: two embeddings with the same shape, thus sharing the underlying _MaskPartial +======= + # case 1: two embeddings with the same shape, thus sharing the underying _MaskPartial +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # and MaskBuffer, because of cache hit from sharding propagation emb1 = torch.nn.Embedding(10, 23, device=self.device_type) diff --git a/test/distributed/tensor/test_experimental_ops.py b/test/distributed/tensor/test_experimental_ops.py index ec4229a47b19c..0c120f56f7bbc 100644 --- a/test/distributed/tensor/test_experimental_ops.py +++ b/test/distributed/tensor/test_experimental_ops.py @@ -4,7 +4,11 @@ import torch import torch.distributed as dist +<<<<<<< HEAD from torch.distributed.tensor import distribute_tensor, Replicate +======= +from torch.distributed.tensor import DeviceMesh, distribute_tensor, Replicate +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, @@ -24,7 +28,11 @@ def world_size(self) -> int: @with_comms def test_slice(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shard_spec = [Replicate()] input_list = torch.rand(ITER_TIME, 1024, 10) @@ -76,7 +84,11 @@ def test_slice(self): @with_comms def test_bernoulli(self): rank = dist.get_rank() +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shard_spec = [Replicate()] input_list = torch.rand(ITER_TIME, 1024, 10) @@ -138,7 +150,11 @@ def test_bernoulli(self): @with_comms def test_nll(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shard_spec = [Replicate()] pred_list = torch.rand(ITER_TIME, 1024, 10) diff --git a/test/distributed/tensor/test_init.py b/test/distributed/tensor/test_init.py index d08b7e0fda4a1..3ed889fbcfa84 100644 --- a/test/distributed/tensor/test_init.py +++ b/test/distributed/tensor/test_init.py @@ -37,7 +37,11 @@ def world_size(self): def _run_init_op(self, init_op, dist_init_op, eq_op, *args, **kwargs): # 1d mesh test +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) placements_list = [[Shard(0)], [Shard(1)], [Shard(2)], [Replicate()]] # even sharding @@ -131,8 +135,13 @@ def test_zeros(self): @with_comms def test_zeros_full_mesh(self): +<<<<<<< HEAD # construct a gpu device 1d mesh mesh = self.build_device_mesh() +======= + # construct a cuda device 1d mesh + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) placements = [Shard(0)] size = [32, 3] dist_tensor = zeros(size, device_mesh=mesh, placements=placements) @@ -157,7 +166,11 @@ def test_zeros_full_mesh(self): self.assertEqual(local_tensor.size(), torch.Size([7, 3])) self.assertEqual(torch.zeros(7, 3), local_tensor) +<<<<<<< HEAD # construct a gpu device mesh with 2d: shard, replicate +======= + # construct a cuda device mesh with 2d: shard, replicate +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 2)) placements = [Shard(0), Replicate()] size = [32, 4] @@ -168,7 +181,11 @@ def test_zeros_full_mesh(self): self.assertEqual(local_tensor.size(), torch.Size([16, 4])) self.assertEqual(local_tensor, torch.zeros([16, 4])) +<<<<<<< HEAD # construct a gpu device mesh with 2d: shard, shard +======= + # construct a cuda device mesh with 2d: shard, shard +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) placements = [Shard(0), Shard(1)] size = [32, 4] dist_tensor = zeros(size, device_mesh=mesh, placements=placements) @@ -197,7 +214,11 @@ def test_zeros_full_mesh(self): @with_comms def test_zeros_submesh(self): # default world_size is 4 +<<<<<<< HEAD # construct a gpu device 1d mesh, with no sub pg initialized +======= + # construct a cuda device 1d mesh, with no sub pg initialized +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sub_mesh_list = [0, 3] mesh = DeviceMesh(self.device_type, sub_mesh_list) placements = [Shard(0)] @@ -213,7 +234,11 @@ def test_zeros_submesh(self): self.assertEqual(local_tensor.size(), torch.Size([0])) self.assertEqual(local_tensor, torch.zeros(0)) +<<<<<<< HEAD # construct a gpu device 1d mesh: unevenly, with subpg initialized +======= + # construct a cuda device 1d mesh: unevenly, with subpg initialized +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sub_mesh_list = [0, 1, 3] mesh = DeviceMesh(self.device_type, sub_mesh_list) placements = [Shard(0)] @@ -233,7 +258,11 @@ def test_zeros_submesh(self): self.assertEqual(local_tensor.size(), torch.Size([0])) self.assertEqual(local_tensor, torch.tensor([])) +<<<<<<< HEAD # construct a gpu device 2d mesh, with no subpg initialized +======= + # construct a cuda device 2d mesh, with no subpg initialized +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sub_mesh_list = [[0], [3]] mesh = DeviceMesh(self.device_type, sub_mesh_list) placements = [Shard(0), Shard(1)] diff --git a/test/distributed/tensor/test_math_ops.py b/test/distributed/tensor/test_math_ops.py index 0dc2f15fe69a7..00c305fb26423 100644 --- a/test/distributed/tensor/test_math_ops.py +++ b/test/distributed/tensor/test_math_ops.py @@ -271,6 +271,7 @@ def test_layer_norm_fwd(self): norm_shape_idx_list = list(range(x.ndim)) shard_dims = [-1, 0, 1, 2] elementwise_affine_list = [False, True] +<<<<<<< HEAD # Test RMSNorm as well if CUDA norm_types = [torch.nn.LayerNorm] @@ -287,6 +288,16 @@ def test_layer_norm_fwd(self): for norm_type, shard_dim, norm_idx, elementwise_affine in test_config_list: normalized_shape = x.shape[norm_idx:] layer_norm = norm_type( +======= + test_config_list = list( + itertools.product(shard_dims, norm_shape_idx_list, elementwise_affine_list) + ) + + # normalized shape is a torch.Size object + for shard_dim, norm_idx, elementwise_affine in test_config_list: + normalized_shape = x.shape[norm_idx:] + layer_norm = torch.nn.LayerNorm( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) normalized_shape, elementwise_affine=elementwise_affine, device=self.device_type, @@ -295,7 +306,10 @@ def test_layer_norm_fwd(self): def _replicate_fn(name, module, device_mesh): for name, param in module.named_parameters(): +<<<<<<< HEAD # RMSNorm only has weight, LayerNorm has both weight and bias +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if name in ["weight", "bias"]: param_dist = torch.nn.Parameter( distribute_tensor(param, device_mesh, [Replicate()]) @@ -316,7 +330,11 @@ def _replicate_fn(name, module, device_mesh): self.assertLessEqual( comm_mode.get_total_counts(), 1, # TODO: This should be 0! +<<<<<<< HEAD f"comm count={comm_mode.get_total_counts()}, norm_type={norm_type.__name__}, " +======= + f"comm count={comm_mode.get_total_counts()}, " +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}", ) @@ -338,6 +356,7 @@ def test_layer_norm_bwd(self): norm_shape_idx_list = list(range(3)) shard_dims = [0, 1, 2] elementwise_affine_list = [False, True] +<<<<<<< HEAD # Test both LayerNorm and RMSNorm (if CUDA) norm_types = [torch.nn.LayerNorm] @@ -352,6 +371,14 @@ def test_layer_norm_bwd(self): # normalized shape is a torch.Size object for norm_type, shard_dim, norm_idx, elementwise_affine in test_config_list: +======= + test_config_list = list( + itertools.product(shard_dims, norm_shape_idx_list, elementwise_affine_list) + ) + + # normalized shape is a torch.Size object + for shard_dim, norm_idx, elementwise_affine in test_config_list: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x = torch.rand( batch, sentence_length, @@ -360,7 +387,11 @@ def test_layer_norm_bwd(self): requires_grad=True, ) normalized_shape = x.shape[norm_idx:] +<<<<<<< HEAD layer_norm = norm_type( +======= + layer_norm = torch.nn.LayerNorm( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) normalized_shape, elementwise_affine=elementwise_affine, device=self.device_type, @@ -381,11 +412,17 @@ def _replicate_fn(name, module, device_mesh): self.assertEqual( layer_norm_local.weight, layer_norm_dist.weight.full_tensor() ) +<<<<<<< HEAD # RMSNorm doesn't have bias if hasattr(layer_norm_local, "bias"): self.assertEqual( layer_norm_local.bias, layer_norm_dist.bias.full_tensor() ) +======= + self.assertEqual( + layer_norm_local.bias, layer_norm_dist.bias.full_tensor() + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x_local = x.detach().clone().requires_grad_(True) x_dist = distribute_tensor(x, device_mesh, [Shard(shard_dim)]) @@ -403,7 +440,11 @@ def _replicate_fn(name, module, device_mesh): self.assertEqual( sum(comm_mode.comm_module_counts["Global"]["forward"].values()), expected_fwd_comm, +<<<<<<< HEAD f"comm count={comm_mode.get_total_counts()}, norm_type={norm_type.__name__}, " +======= + f"comm count={comm_mode.get_total_counts()}, " +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}", ) @@ -417,7 +458,11 @@ def _replicate_fn(name, module, device_mesh): self.assertEqual( sum(comm_mode.comm_module_counts["Global"]["backward"].values()), expected_bwd_comm, +<<<<<<< HEAD f"comm count={comm_mode.get_total_counts()}, norm_type={norm_type.__name__}, " +======= + f"comm count={comm_mode.get_total_counts()}, " +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}", ) @@ -431,22 +476,36 @@ def _replicate_fn(name, module, device_mesh): is_tensor_partial(layer_norm_dist.weight.grad._spec), needs_reduction, ) +<<<<<<< HEAD # RMSNorm doesn't have bias if hasattr(layer_norm_dist, "bias"): self.assertEqual( is_tensor_partial(layer_norm_dist.bias.grad._spec), needs_reduction, ) +======= + self.assertEqual( + is_tensor_partial(layer_norm_dist.bias.grad._spec), + needs_reduction, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual( layer_norm_local.weight.grad, layer_norm_dist.weight.grad.full_tensor(), ) +<<<<<<< HEAD # RMSNorm doesn't have bias if hasattr(layer_norm_local, "bias"): self.assertEqual( layer_norm_local.bias.grad, layer_norm_dist.bias.grad.full_tensor(), ) +======= + self.assertEqual( + layer_norm_local.bias.grad, + layer_norm_dist.bias.grad.full_tensor(), + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(x_local.grad, x_dist.grad.full_tensor()) @@ -455,6 +514,7 @@ def test_layer_norm_bwd_req_grad(self): device_mesh = self.build_device_mesh() batch, seq_len, embedding_dim, vocab_size = 8, 8, 10, 32 +<<<<<<< HEAD # Test both LayerNorm and RMSNorm (if CUDA) norm_types = [torch.nn.LayerNorm] if self.device_type == "cuda" and hasattr(torch.nn, "RMSNorm"): @@ -463,6 +523,10 @@ def test_layer_norm_bwd_req_grad(self): # build our subtest configurations and filter out invalid ones class SubTest(NamedTuple): norm_type: type +======= + # build our subtest configurations and filter out invalid ones + class SubTest(NamedTuple): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) multidim_norm: bool elementwise_affine: bool emb_req_grad: bool @@ -472,24 +536,35 @@ class SubTest(NamedTuple): subtest_fails = {} valid_filter = ( # noqa: E731 lambda cfg: ( +<<<<<<< HEAD not (cfg.ln_req_grad and not cfg.elementwise_affine) and any(cfg[3:]) +======= + not (cfg.ln_req_grad and not cfg.elementwise_affine) and any(cfg[2:]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) ) subtest_cfgs = list( filter( valid_filter, +<<<<<<< HEAD [ SubTest(norm_type, *cfg) for norm_type in norm_types for cfg in itertools.product(*(((False, True),) * 5)) ], +======= + [SubTest(*cfg) for cfg in itertools.product(*(((False, True),) * 5))], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) ) for subtest_cfg in subtest_cfgs: try: ( +<<<<<<< HEAD norm_type, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) multidim_norm, elementwise_affine, emb_req_grad, @@ -507,7 +582,11 @@ def __init__(self): self.preln_embeddings = torch.nn.Embedding( vocab_size, embedding_dim ) +<<<<<<< HEAD self.layer_norm = norm_type( +======= + self.layer_norm = torch.nn.LayerNorm( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) normalized_shape, elementwise_affine=elementwise_affine ) self.postln_linear = torch.nn.Linear( @@ -724,7 +803,11 @@ def test_foreach_add_different_mesh(self): self.assertEqual(out0.device_mesh, mesh_x) self.assertEqual(out1.device_mesh, mesh_y) +<<<<<<< HEAD with self.assertRaisesRegex(RuntimeError, "Sharding propagation failed"): +======= + with self.assertRaisesRegex(ValueError, "computation across different mesh"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.ops.aten._foreach_add( [replica_inp00, replica_inp01], [replica_inp10, replica_inp11] ) @@ -793,6 +876,7 @@ def test_cumsum(self): self.assertTrue(output_dtensor.placements[0].is_shard(shard_dim)) self.assertEqual(output_dtensor.full_tensor(), output) +<<<<<<< HEAD @with_comms def test_conj_complex_dtensor(self): mesh = self.build_device_mesh() @@ -881,6 +965,8 @@ def test_histc(self): out_full = out_dt.full_tensor() self.assertEqual(global_bins, out_full) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/distributed/tensor/test_matrix_ops.py b/test/distributed/tensor/test_matrix_ops.py index f467d1175db1b..8b62c867d3dbe 100644 --- a/test/distributed/tensor/test_matrix_ops.py +++ b/test/distributed/tensor/test_matrix_ops.py @@ -7,7 +7,11 @@ import torch import torch.nn.functional as F +<<<<<<< HEAD from torch.distributed import init_device_mesh +======= +from torch.distributed import DeviceMesh, init_device_mesh +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.distributed.tensor import ( distribute_tensor, DTensor, @@ -19,12 +23,16 @@ from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8, SM90OrLater from torch.testing._internal.common_device_type import E4M3_MAX_POS, e4m3_type +<<<<<<< HEAD from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, run_tests, TEST_WITH_ROCM, ) +======= +from torch.testing._internal.common_utils import run_tests, TEST_WITH_ROCM +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, skip_unless_torch_gpu, @@ -52,7 +60,11 @@ def scale_for_fp8( class DistMatrixOpsTest(DTensorTestBase): @with_comms def test_addmm(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shard_spec = [Shard(0)] replica_spec = [Replicate()] @@ -69,7 +81,11 @@ def test_addmm(self): @with_comms def test_addmm_empty_operand(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shard_spec = [Shard(0)] replica_spec = [Replicate()] @@ -86,7 +102,11 @@ def test_addmm_empty_operand(self): @with_comms def test_addmm_auto_redistribute(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shard0_spec = [Shard(0)] shard1_spec = [Shard(1)] replica_spec = [Replicate()] @@ -117,7 +137,11 @@ def test_addmm_auto_redistribute(self): @with_comms def test_mm(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shard0_spec = Shard(0) shard1_spec = Shard(1) replica_spec = Replicate() @@ -152,7 +176,11 @@ def test_placement_comb( "FP8 is only supported on H100+, SM 8.9 and MI300+ devices", ) def test_scaled_mm(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shrd0 = Shard(0) shrd1 = Shard(1) repl = Replicate() @@ -222,7 +250,11 @@ def test_scaled_mm(self): @with_comms def test_matmul(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dim = 128 x = torch.randn(8, dim) A = torch.randn(dim, dim) @@ -241,7 +273,11 @@ def test_matmul(self): @with_comms def test_t(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shard_spec = [Shard(0)] tensor_to_transpose = torch.randn(12, 8, requires_grad=True) @@ -255,7 +291,11 @@ def test_t(self): @with_comms def test_t_partial(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) a = torch.randn(12, 8) b = torch.randn(8, 4) @@ -280,7 +320,11 @@ def test_t_partial(self): @with_comms @skip_unless_torch_gpu def test_baddbmm(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tensor = torch.rand(4, 4, 8, device=self.device_type, requires_grad=True) batch_1 = torch.rand(4, 4, 8, device=self.device_type, requires_grad=True) batch_2 = torch.rand(4, 8, 8, device=self.device_type, requires_grad=True) @@ -344,7 +388,11 @@ def test_placement_comb( @with_comms def test_bmm(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mat1 = torch.rand(4, 8, 4, device=self.device_type, requires_grad=True) mat2 = torch.rand(4, 4, 8, device=self.device_type, requires_grad=True) local_result = torch.bmm(mat1, mat2) @@ -389,7 +437,11 @@ def test_placement_comb( @with_comms @skip_unless_torch_gpu def test_scaled_dot_product_attention(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) comm_mode = CommDebugMode() # bsz, n_heads, slen, head_dim query = torch.rand( @@ -411,6 +463,13 @@ def test_scaled_dot_product_attention(self): requires_grad=True, ) +<<<<<<< HEAD +======= + dist_query = distribute_tensor(query, device_mesh, [Shard(1)]) + dist_key = distribute_tensor(key, device_mesh, [Shard(1)]) + dist_value = distribute_tensor(value, device_mesh, [Shard(1)]) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.nn.attention import sdpa_kernel, SDPBackend available_backends = [] @@ -427,6 +486,7 @@ def test_scaled_dot_product_attention(self): if torch.backends.cuda.can_use_efficient_attention(params, debug=False): available_backends.append(SDPBackend.EFFICIENT_ATTENTION) +<<<<<<< HEAD placement_specs = [(Replicate(),), (Shard(0),), (Shard(1),)] for backend, input_placements in itertools.product( available_backends, placement_specs @@ -434,6 +494,9 @@ def test_scaled_dot_product_attention(self): dist_query = distribute_tensor(query, device_mesh, input_placements) dist_key = distribute_tensor(key, device_mesh, input_placements) dist_value = distribute_tensor(value, device_mesh, input_placements) +======= + for backend in available_backends: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with sdpa_kernel(backends=[backend]): out = F.scaled_dot_product_attention( query, key, value, dropout_p=dropout_p, is_causal=is_causal @@ -447,13 +510,18 @@ def test_scaled_dot_product_attention(self): is_causal=is_causal, ) self.assertEqual(comm_mode.get_total_counts(), 0) +<<<<<<< HEAD self.assertEqual(dist_out.placements, input_placements) +======= + self.assertTrue(dist_out.placements[0].is_shard(dim=1)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(dist_out.full_tensor(), out) out.sum().backward() with comm_mode: dist_out.sum().backward() self.assertEqual(comm_mode.get_total_counts(), 0) +<<<<<<< HEAD self.assertEqual(dist_query.grad.placements, input_placements) self.assertEqual(dist_query.grad.full_tensor(), query.grad) self.assertEqual(dist_key.grad.placements, input_placements) @@ -463,6 +531,14 @@ def test_scaled_dot_product_attention(self): query.grad.zero_() key.grad.zero_() value.grad.zero_() +======= + self.assertTrue(dist_query.grad.placements[0].is_shard(dim=1)) + self.assertEqual(dist_query.grad.full_tensor(), query.grad) + self.assertTrue(dist_key.grad.placements[0].is_shard(dim=1)) + self.assertEqual(dist_key.grad.full_tensor(), key.grad) + self.assertTrue(dist_value.grad.placements[0].is_shard(dim=1)) + self.assertEqual(dist_value.grad.full_tensor(), value.grad) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skip_unless_torch_gpu @with_comms() @@ -497,7 +573,11 @@ def test_tensordot_shampoo(self): """ Create a simple test for Shampoo's use case. """ +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = init_device_mesh(self.device_type, (self.world_size,)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) local_a = torch.randn(4, 4) local_b = torch.randn(4, 15) @@ -518,6 +598,7 @@ def test_tensordot_shampoo(self): @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") @with_comms @skip_unless_torch_gpu +<<<<<<< HEAD @parametrize( "kwargs", [ @@ -570,26 +651,57 @@ def test_grouped_mm(self, kwargs): device=self.device_type, dtype=dtype, requires_grad=True, +======= + def test_grouped_mm(self): + # TODO: torch._grouped_mm can take inputs of dimension (2D, 3D) x (2D, 3D) + # Here we only test the 2D x 3D Tensor Parallel use case in an MoE layer. + # More tests need to be added. + device_mesh = init_device_mesh(self.device_type, (self.world_size,)) + comm_mode = CommDebugMode() + dtype = torch.bfloat16 + + inp = torch.rand( + 64, 16, device=self.device_type, dtype=dtype, requires_grad=True + ) + w1 = torch.rand( + 2, 16, 32, device=self.device_type, dtype=dtype, requires_grad=True + ) + w2 = torch.rand( + 2, 32, 16, device=self.device_type, dtype=dtype, requires_grad=True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) offs = torch.tensor([16, 64], device=self.device_type, dtype=torch.int32) h = torch._grouped_mm(inp, w1, offs=offs) out = torch._grouped_mm(h, w2, offs=offs) +<<<<<<< HEAD dist_inp = distribute_tensor(inp, device_mesh, kwargs["inp_placements"]) # colwise sharded dist_w1 = distribute_tensor(w1, device_mesh, kwargs["w1_placements"]) # rowwise sharded dist_w2 = distribute_tensor(w2, device_mesh, kwargs["w2_placements"]) +======= + dist_inp = distribute_tensor(inp, device_mesh, [Replicate()]) + # colwise sharded + dist_w1 = distribute_tensor(w1, device_mesh, [Shard(2)]) + # rowwise sharded + dist_w2 = distribute_tensor(w2, device_mesh, [Shard(1)]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dist_offs = distribute_tensor(offs, device_mesh, [Replicate()]) with comm_mode: dist_h = torch._grouped_mm(dist_inp, dist_w1, offs=dist_offs) dist_out = torch._grouped_mm(dist_h, dist_w2, offs=dist_offs) +<<<<<<< HEAD self.assertEqual( comm_mode.get_total_counts(), kwargs["expected_comm_counts_fwd"] ) self.assertEqual(dist_out.placements, kwargs["expected_out_placements"]) +======= + self.assertEqual(comm_mode.get_total_counts(), 0) + self.assertTrue(dist_out.placements[0].is_partial()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(dist_out.full_tensor(), out) out_grad = torch.ones_like(out) @@ -600,19 +712,29 @@ def test_grouped_mm(self, kwargs): with comm_mode: dist_out.backward(dist_out_grad) +<<<<<<< HEAD self.assertEqual( comm_mode.get_total_counts(), kwargs["expected_comm_counts_bwd"] ) self.assertEqual( comm_mode.get_comm_counts()[funcol.all_gather_into_tensor], kwargs["expected_comm_counts_bwd"], +======= + self.assertEqual(comm_mode.get_total_counts(), 1) + self.assertEqual( + comm_mode.get_comm_counts()[funcol.all_gather_into_tensor], + 1, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) self.assertEqual(dist_inp.grad.full_tensor(), inp.grad) self.assertEqual(dist_w1.grad.full_tensor(), w1.grad) self.assertEqual(dist_w2.grad.full_tensor(), w2.grad) +<<<<<<< HEAD instantiate_parametrized_tests(DistMatrixOpsTest) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/distributed/tensor/test_op_strategy.py b/test/distributed/tensor/test_op_strategy.py index 8e97d80e95430..e6e803ec8d483 100644 --- a/test/distributed/tensor/test_op_strategy.py +++ b/test/distributed/tensor/test_op_strategy.py @@ -1,5 +1,6 @@ # Owner(s): ["oncall: distributed"] +<<<<<<< HEAD import itertools import random from contextlib import contextmanager @@ -26,10 +27,20 @@ OpStrategy, RuntimeSchemaInfo, ) +======= +from itertools import chain + +import torch +from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard +from torch.distributed.tensor._collective_utils import redistribute_cost +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._op_schema import OpSchema, OpSpec, OpStrategy +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.distributed.tensor._ops._einsum_strategy import ( EinsumDims, gen_einsum_strategies, ) +<<<<<<< HEAD from torch.distributed.tensor._ops.utils import ( register_op_strategy, replicate_op_strategy, @@ -51,6 +62,10 @@ def extract_tensor_meta(t) -> TensorMeta: return TensorMeta(t.shape, t.stride(), t.dtype) +======= +from torch.testing._internal.common_utils import run_tests, TestCase +from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestEinsumDims(TestCase): @@ -103,7 +118,11 @@ def test_free_dims(self): self.assertEqual(edims.lhs_out_only_dims, ["c"]) self.assertEqual(edims.rhs_out_only_dims, []) +<<<<<<< HEAD equation = "abd,bf->abfd" # codespell:ignore +======= + equation = "abd,bf->abfd" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) input_dims, output_dim = EinsumDims.parse_equation(equation) edims = EinsumDims.parse_dims(input_dims, output_dim) @@ -136,6 +155,7 @@ def test_bmm_1d_mesh(self): all_strats = gen_einsum_strategies("bmk,bkn->bmn", mesh) self.assertEqual(len(all_strats.strategies), 5) +<<<<<<< HEAD def test_bmm_diffinndim_2d_mesh(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 2)) all_strats = gen_einsum_strategies("bmk,kn->bmn", mesh) @@ -146,6 +166,8 @@ def test_bmm_diffoutndim_2d_mesh(self): all_strats = gen_einsum_strategies("bmk,k->bm", mesh) self.assertEqual(len(all_strats.strategies), 16) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_bmm_2d_mesh(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 2)) @@ -169,6 +191,12 @@ def test_linearity_1d_mesh(self): class TestCostModel(DTensorOpTestBase): +<<<<<<< HEAD +======= + def _extract_tensor_meta(self, t) -> TensorMeta: + return TensorMeta(t.shape, t.stride(), t.dtype) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @property def world_size(self) -> int: return 4 @@ -180,7 +208,11 @@ def test_redistribute_cost_mesh_1d(self): partial_placement = (Partial(),) global_tensor = torch.randn(10, 10) +<<<<<<< HEAD global_tensor_meta = extract_tensor_meta(global_tensor) +======= + global_tensor_meta = self._extract_tensor_meta(global_tensor) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # shard spec shard_spec = DTensorSpec(mesh_1d, shard_placement, global_tensor_meta) @@ -215,9 +247,15 @@ def test_redistribute_cost_latency(self): partial_placement = (Partial(),) shard1_placement = (Shard(1),) +<<<<<<< HEAD shard0_tensor_meta = extract_tensor_meta(torch.randn(8)) partial_tensor_meta = extract_tensor_meta(torch.randn(50, 6)) shard1_tensor_meta = extract_tensor_meta(torch.randn(6, 8)) +======= + shard0_tensor_meta = self._extract_tensor_meta(torch.randn(8)) + partial_tensor_meta = self._extract_tensor_meta(torch.randn(50, 6)) + shard1_tensor_meta = self._extract_tensor_meta(torch.randn(6, 8)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # shard spec shard0_spec = DTensorSpec(mesh, shard0_placement, shard0_tensor_meta) @@ -261,7 +299,11 @@ def test_redistribute_cost_mesh_2d(self): partial_placement = (Partial(), Partial()) global_tensor = torch.randn(8, 8) +<<<<<<< HEAD global_tensor_meta = extract_tensor_meta(global_tensor) +======= + global_tensor_meta = self._extract_tensor_meta(global_tensor) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # shard spec shard_spec = DTensorSpec(mesh_2d, shard_placement, global_tensor_meta) @@ -290,8 +332,13 @@ def test_mm_strategies(self): mesh = self.build_device_mesh() lhs_tensor = torch.randn(6, 8) rhs_tensor = torch.randn(8, 12) +<<<<<<< HEAD lhs_tensor_meta = extract_tensor_meta(lhs_tensor) rhs_tensor_meta = extract_tensor_meta(rhs_tensor) +======= + lhs_tensor_meta = self._extract_tensor_meta(lhs_tensor) + rhs_tensor_meta = self._extract_tensor_meta(rhs_tensor) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mm_combs = ( (Shard(0), Replicate()), @@ -336,8 +383,13 @@ def test_bmm_strategies(self): mesh = self.build_device_mesh() lhs_tensor = torch.randn(8, 6, 8) rhs_tensor = torch.randn(8, 8, 12) +<<<<<<< HEAD lhs_tensor_meta = extract_tensor_meta(lhs_tensor) rhs_tensor_meta = extract_tensor_meta(rhs_tensor) +======= + lhs_tensor_meta = self._extract_tensor_meta(lhs_tensor) + rhs_tensor_meta = self._extract_tensor_meta(rhs_tensor) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bmm_combs = ( (Shard(0), Shard(0)), @@ -378,6 +430,7 @@ def test_bmm_strategies(self): self.assertFalse(output_sharding.needs_redistribute) +<<<<<<< HEAD # -------------Test op strategy registration------------- # custom op without List[Tensor] as input # reference: https://docs.pytorch.org/docs/stable/library.html#torch.library.register_autograd @@ -644,5 +697,7 @@ def test_call_with_different_nontensor_args(self): self.assertEqual(out1.full_tensor(), out2.full_tensor()) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/distributed/tensor/test_optimizers.py b/test/distributed/tensor/test_optimizers.py index c876f28e165b3..62587facdbcb6 100644 --- a/test/distributed/tensor/test_optimizers.py +++ b/test/distributed/tensor/test_optimizers.py @@ -5,10 +5,17 @@ import torch import torch.nn as nn from torch.distributed.tensor import ( +<<<<<<< HEAD distribute_module, distribute_tensor, DTensor, init_device_mesh, +======= + DeviceMesh, + distribute_module, + distribute_tensor, + DTensor, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Replicate, Shard, ) @@ -88,7 +95,11 @@ def test_optimizer_foreach_supported_types_include_DTensor(self): @with_comms def test_adam_1d_sharding(self): +<<<<<<< HEAD mesh = self.build_device_mesh() +======= + mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # lr as a Tensor is not supported for capturable=False and foreach=True adam_float_lr_configs = [ @@ -147,7 +158,11 @@ def test_adam_1d_sharding(self): @with_comms def test_adamw_1d_sharding(self): +<<<<<<< HEAD mesh = self.build_device_mesh() +======= + mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # lr as a Tensor is not supported for capturable=False and foreach=True adamw_float_lr_configs = [ @@ -223,7 +238,11 @@ def test_adamw_1d_sharding(self): @with_comms def test_sgd_1d_sharding(self): +<<<<<<< HEAD mesh = self.build_device_mesh() +======= + mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sgd_configs = [ {"lr": 0.1, "foreach": False}, @@ -263,7 +282,11 @@ def test_sgd_1d_sharding(self): @with_comms def test_adagrad_1d_sharding(self): +<<<<<<< HEAD mesh = self.build_device_mesh() +======= + mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) adagrad_configs = [ {"lr": 0.1, "foreach": False}, @@ -319,7 +342,11 @@ def test_adagrad_1d_sharding(self): @with_comms def test_RMSprop_1d_sharding(self): +<<<<<<< HEAD mesh = self.build_device_mesh() +======= + mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) RMSprop_configs = [ {"lr": 0.1, "foreach": False}, @@ -386,7 +413,11 @@ def test_RMSprop_1d_sharding(self): @with_comms def test_adadelta_1d_sharding(self): +<<<<<<< HEAD mesh = self.build_device_mesh() +======= + mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) adadelta_configs = [ {"lr": 0.1, "foreach": False}, @@ -430,7 +461,11 @@ def test_adadelta_1d_sharding(self): @with_comms def test_nadam_1d_sharding(self): +<<<<<<< HEAD mesh = self.build_device_mesh() +======= + mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) nadam_configs = [ {"lr": 0.1, "foreach": False}, @@ -467,7 +502,11 @@ def test_nadam_1d_sharding(self): @with_comms def test_radam_1d_sharding(self): +<<<<<<< HEAD mesh = self.build_device_mesh() +======= + mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) radam_configs = [ {"lr": 0.1, "foreach": False}, @@ -507,7 +546,11 @@ def test_radam_1d_sharding(self): @with_comms def test_adamax_1d_sharding(self): +<<<<<<< HEAD mesh = self.build_device_mesh() +======= + mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) adamax_configs = [ {"lr": 0.1, "foreach": False}, @@ -551,7 +594,11 @@ def test_adamax_1d_sharding(self): @with_comms def test_asgd_1d_sharding(self): +<<<<<<< HEAD mesh = self.build_device_mesh() +======= + mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) asgd_configs = [ {"lr": 0.1, "foreach": False}, @@ -606,6 +653,7 @@ def test_asgd_1d_sharding(self): mesh, mod, opt, dist_mod, dist_opt, inp, atol=1.3e-5, rtol=1e-4 ) +<<<<<<< HEAD @with_comms def test_admaw_fused_across_meshes(self): mesh_shape = (2, self.world_size // 2) @@ -715,6 +763,8 @@ def _input_fn_2d(mod, inputs, device_mesh): inp = torch.ones(8, 10, device=self.device_type) self._assert_optimizer(None, mod, opt, mod_copy, dist_opt, inp) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/distributed/tensor/test_pointwise_ops.py b/test/distributed/tensor/test_pointwise_ops.py index 28dd1ac9def51..1671be8d1aaa8 100644 --- a/test/distributed/tensor/test_pointwise_ops.py +++ b/test/distributed/tensor/test_pointwise_ops.py @@ -17,7 +17,10 @@ Replicate, Shard, ) +<<<<<<< HEAD from torch.distributed.tensor.debug import CommDebugMode +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorOpTestBase, @@ -148,6 +151,17 @@ def test_partial_add(self): d_3 = d_1 + d_2 self.assertTrue(d_3._spec.placements[0].is_partial()) +<<<<<<< HEAD +======= + def test_partial_mul(self): + device_mesh = self.build_device_mesh() + d_1 = DTensor.from_local(torch.ones(2, 2), device_mesh, [Partial()]) + d_2 = DTensor.from_local(torch.ones(2, 2), device_mesh, [Partial()]) + d_3 = d_1 * d_2 + self.assertTrue(d_3._spec.placements[0].is_replicate()) + self.assertEqual(d_3.to_local(), torch.ones(2, 2) * (self.world_size**2)) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_activations(self): device_mesh = self.build_device_mesh() self._run_sharded_elementwise_ops( @@ -275,6 +289,7 @@ def test_mul_out(self): self.assertEqual(input_tensor, dtensor.to_local()) self.assertEqual(expected, dt.to_local()) +<<<<<<< HEAD def test_mul_partial(self): # we only test the partial behavior for mul op as other placement # behaviors should be well tested in test_dtensor_ops.py @@ -331,6 +346,8 @@ def test_mul_partial(self): self.assertEqual(z.placements, (Replicate(),)) self.assertEqual(z.to_local(), input) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/distributed/tensor/test_random_ops.py b/test/distributed/tensor/test_random_ops.py index 2cf9916c7d67a..b3b398fb7e413 100644 --- a/test/distributed/tensor/test_random_ops.py +++ b/test/distributed/tensor/test_random_ops.py @@ -33,18 +33,25 @@ ) +<<<<<<< HEAD def get_generator_seed_for_device_type(device_type: str) -> int: device_module = torch.get_device_module(device_type) return device_module.get_rng_state()[:8].view(torch.int64).item() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class DistTensorRandomInitTest(DTensorTestBase): def _run_init_op(self, init_op, *args, **kwargs): device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] input_size = (8, 4) +<<<<<<< HEAD # NOTE: currently random initialization on gpu device has different +======= + # NOTE: currently random initialization on cuda device has different +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # behavior from other devices. Unify the test once the behavior is unified. if not is_rng_supported_mesh(device_mesh): input_tensor = torch.randn(*input_size, device=self.device_type) @@ -94,6 +101,7 @@ def test_init_ops(self): @with_comms @skip_if_lt_x_gpu(4) +<<<<<<< HEAD def test_init_with_user_generator(self): device_mesh = self.build_device_mesh() torch.manual_seed(42) @@ -127,6 +135,16 @@ def test_meta_tensor_init(self): # Note: this behavior changed, and now the guideline is to set the same RNG seed on all SPMD ranks. torch.get_device_module(self.device_type).manual_seed(0) +======= + def test_meta_tensor_init(self): + # test suite sets each rank's seed to the same value but in actual + # execution the default random seed will be different (a random value). + # The DTensor random ops will use the same random seed even though the + # torch random generator keeps different seeds on ranks. This ensures + # that Replicate DTensor will have the same initialized results + # across ranks. + torch.cuda.manual_seed(self.rank) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) size = [1024, 2048] meta_dtensor = distribute_tensor( @@ -145,7 +163,11 @@ def test_meta_tensor_init(self): self.assertTrue(random._rng_tracker.distribute_region_enabled) # allgather the local tensors +<<<<<<< HEAD gathered_local_tensors = funcol.all_gather_tensor( +======= + local_tensor = funcol.all_gather_tensor( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dtensor.to_local(), gather_dim=0, group=(device_mesh, 0) ) @@ -156,8 +178,12 @@ def test_meta_tensor_init(self): # other rank should have an identical local tensor other_slice = slice(1024 * other_rank, 1024 * other_rank + 1024) self.assertEqual( +<<<<<<< HEAD gathered_local_tensors[self_slice, :], gathered_local_tensors[other_slice, :], +======= + local_tensor[self_slice, :], local_tensor[other_slice, :] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # Test 2: disable the distribute region for RNG @@ -176,11 +202,19 @@ def test_meta_tensor_init(self): # compare with local tensors from other ranks for other_rank in range(self.world_size): +<<<<<<< HEAD # the RNG result on each rank are the same even without the help of DTensor's RNG infra, # since the default RNG is the same across ranks. if self.rank != other_rank: other_slice = slice(1024 * other_rank, 1024 * other_rank + 1024) self.assertEqual( +======= + # the RNG result on each rank differs even they're supposed + # to be replicated + if self.rank != other_rank: + other_slice = slice(1024 * other_rank, 1024 * other_rank + 1024) + self.assertNotEqual( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) local_tensor[self_slice, :], local_tensor[other_slice, :] ) @@ -306,12 +340,16 @@ def test_rng_tracker_init(self): # seed synchronization only happens after `manual_seed` or the first DTensor # random op call dt.uniform_(0, 1) +<<<<<<< HEAD # We do not maintain the copy of the seed in dtensor, but we do mutate the global rng state # since we now always pull it fresh from the local device generator self.assertEqual( seed_from_rank_0, get_generator_seed_for_device_type(self.device_type) ) +======= + self.assertEqual(seed_from_rank_0, random._rng_tracker.get_seed("parallel-rng")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @with_comms @skip_unless_torch_gpu @@ -330,6 +368,7 @@ def test_manual_seed(self): manual_seed(self.rank, device_mesh) # RNG tracker should already be initialized self.assertTrue(random._rng_tracker is not None) +<<<<<<< HEAD self.assertEqual( self.rank, get_generator_seed_for_device_type(self.device_type) ) @@ -337,6 +376,13 @@ def test_manual_seed(self): # Test 2: set same seed on different ranks manual_seed(1234, device_mesh) self.assertEqual(1234, get_generator_seed_for_device_type(self.device_type)) +======= + self.assertEqual(self.rank, random._rng_tracker.get_seed("parallel-rng")) + + # Test 2: set same seed on different ranks + manual_seed(1234, device_mesh) + self.assertEqual(1234, random._rng_tracker.get_seed("parallel-rng")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(comm_mode.get_total_counts(), 0) @@ -369,10 +415,14 @@ def test_pipeline_parallel_manual_seed(self): # set the seed for each pipeline stage to 123 + pp_rank manual_seed(123 + pp_rank, spmd_mesh) +<<<<<<< HEAD # dtensor no longer stores a copy of the seed, but it mutates the device's generator so we can check that self.assertEqual( 123 + pp_rank, get_generator_seed_for_device_type(self.device_type) ) +======= + self.assertEqual(123 + pp_rank, random._rng_tracker.get_seed("parallel-rng")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # mimic initializing a model weight sharded on the SPMD mesh spmd_dtensor = torch.distributed.tensor.ones( @@ -457,15 +507,23 @@ def test_deterministic_rand_1d(self): self_slice = slice(4 * self.rank, 4 * self.rank + 4) for other_rank in range(self.world_size): if self.rank != other_rank: +<<<<<<< HEAD # other rank should have a different local tensor for shard placement +======= + # other rank should have an identical local tensor +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) other_slice = slice(4 * other_rank, 4 * other_rank + 4) self.assertNotEqual( local_tensor[self_slice, :], local_tensor[other_slice, :], ) +<<<<<<< HEAD # we should set manual seed to the same value on all SPMD ranks torch.manual_seed(0) +======= + torch.manual_seed(self.rank) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dtensor = fn(size, device_mesh=device_mesh, placements=[Replicate()]) local_tensor = funcol.all_gather_tensor( dtensor.to_local(), gather_dim=0, group=(device_mesh, 0) @@ -475,7 +533,11 @@ def test_deterministic_rand_1d(self): self_slice = slice(4 * self.rank, 4 * self.rank + 4) for other_rank in range(self.world_size): if self.rank != other_rank: +<<<<<<< HEAD # other rank should have an identical local tensor for replicate placement +======= + # other rank should have an identical local tensor +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) other_slice = slice(4 * other_rank, 4 * other_rank + 4) self.assertEqual( local_tensor[self_slice, :], @@ -592,8 +654,13 @@ class DistTensorRandomOpsTest3D(DTensorTestBase): def world_size(self): return 8 +<<<<<<< HEAD @skip_if_lt_x_gpu(8) @with_comms +======= + @with_comms + @skip_if_lt_x_gpu(8) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_hsdp_tp_model_meta_init(self): # initialize the 3-d device mesh global_mesh = init_device_mesh( diff --git a/test/distributed/tensor/test_redistribute.py b/test/distributed/tensor/test_redistribute.py index fe07b0dd6a241..ca9489e94df02 100644 --- a/test/distributed/tensor/test_redistribute.py +++ b/test/distributed/tensor/test_redistribute.py @@ -15,6 +15,7 @@ ) from torch.distributed.tensor._collective_utils import shard_dim_alltoall from torch.distributed.tensor.debug import CommDebugMode +<<<<<<< HEAD from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, @@ -22,6 +23,9 @@ TEST_CUDA, TEST_HPU, ) +======= +from torch.testing._internal.common_utils import run_tests, TEST_CUDA, TEST_HPU +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, with_comms, @@ -37,10 +41,16 @@ def world_size(self): return 4 @with_comms +<<<<<<< HEAD @parametrize("dtype", [torch.float32, torch.cfloat]) def test_shard_to_replicate_forward_backward(self, dtype): # 1) test shard -> replicate forward device_mesh = self.build_device_mesh() +======= + def test_shard_to_replicate_forward_backward(self): + # 1) test shard -> replicate forward + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) replica_spec = [Replicate()] input_sizes_and_shard_dim = [ @@ -56,7 +66,11 @@ def test_shard_to_replicate_forward_backward(self, dtype): for input_size, shard_dim in input_sizes_and_shard_dim: shard_spec = [Shard(shard_dim)] expected_tensor = torch.randn( +<<<<<<< HEAD input_size, device=self.device_type, requires_grad=True, dtype=dtype +======= + input_size, device=self.device_type, requires_grad=True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) dtensor = distribute_tensor(expected_tensor, device_mesh, shard_spec) with comm_mode: @@ -75,14 +89,22 @@ def test_shard_to_replicate_forward_backward(self, dtype): grad_input = dtensor.grad self.assertEqual(grad_input.placements, shard_spec) self.assertEqual( +<<<<<<< HEAD grad_input.to_local(), torch.ones(dtensor.to_local().size(), dtype=dtype), +======= + grad_input.to_local(), torch.ones(dtensor.to_local().size()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) self.assertEqual(comm_mode.get_total_counts(), 0) @with_comms def test_replicate_to_replicate_forward_backward(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) replica_spec = [Replicate()] local_tensor = torch.randn(12, 3, device=self.device_type, requires_grad=True) @@ -109,6 +131,7 @@ def test_replicate_to_replicate_forward_backward(self): self.assertEqual(comm_mode.get_total_counts(), 0) @with_comms +<<<<<<< HEAD @parametrize("dtype", [torch.float32, torch.cfloat]) def test_replicate_to_local_partial_grad(self, dtype): device_mesh = self.build_device_mesh() @@ -116,6 +139,12 @@ def test_replicate_to_local_partial_grad(self, dtype): local_tensor = torch.randn( 12, 3, device=self.device_type, requires_grad=True, dtype=dtype ) +======= + def test_replicate_to_local_partial_grad(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + replica_spec = [Replicate()] + local_tensor = torch.randn(12, 3, device=self.device_type, requires_grad=True) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) replica_tensor = distribute_tensor(local_tensor, device_mesh, replica_spec) @@ -132,7 +161,11 @@ def test_replicate_to_local_partial_grad(self, dtype): @with_comms def test_replicate_to_shard_forward_backward(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) replica_spec = [Replicate()] input_sizes_and_shard_dim = [ @@ -179,16 +212,25 @@ def test_replicate_to_shard_forward_backward(self): ) @with_comms +<<<<<<< HEAD @parametrize("dtype", [torch.float32, torch.cfloat]) def test_partial_to_replicate_forward_backward(self, dtype): +======= + def test_partial_to_replicate_forward_backward(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Although we don't allow user to reshard to produce a partial # placement (i.e. user can't reshard to partial), we do allow # replicate to partial internally, and also partial to replicate # backward should work as expected +<<<<<<< HEAD device_mesh = self.build_device_mesh() partial_local = torch.ones( 12, 3, device=self.device_type, requires_grad=True, dtype=dtype ) +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + partial_local = torch.ones(12, 3, device=self.device_type, requires_grad=True) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) partial_spec = [Partial()] replica_spec = [Replicate()] @@ -213,14 +255,22 @@ def test_partial_to_replicate_forward_backward(self, dtype): global_partial_tensor.backward(torch.ones_like(global_partial_tensor)) self.assertIsNotNone(partial_local.grad) self.assertEqual(partial_local.grad.size(), partial_local.size()) +<<<<<<< HEAD self.assertEqual( partial_local.grad, torch.ones_like(partial_local, dtype=dtype) ) +======= + self.assertEqual(partial_local.grad, torch.ones_like(partial_local)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(comm_mode.get_total_counts(), 0) @with_comms def test_replicate_to_replicate_forward_backward_datatype_conversion(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) replica_spec = [Replicate()] forward_datatypes = [ @@ -277,7 +327,11 @@ def test_replicate_to_replicate_forward_backward_datatype_conversion(self): @with_comms def test_shard_to_replicate_forward_backward_datatype_conversion(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) replica_spec = [Replicate()] shard_dim_and_input_sizes = [ @@ -349,7 +403,11 @@ def test_shard_to_replicate_forward_backward_datatype_conversion(self): @with_comms def test_replicate_to_partial(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) local_tensor = torch.randn(12, 3, device=self.device_type, requires_grad=True) partial_spec = Partial() replica_spec = Replicate() @@ -396,9 +454,14 @@ def test_replicate_to_partial(self): self.assertEqual(comm_mode.get_total_counts(), 0) @with_comms +<<<<<<< HEAD @parametrize("dtype", [torch.float32, torch.cfloat]) def test_partial_to_shard(self, dtype): device_mesh = self.build_device_mesh() +======= + def test_partial_to_shard(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) partial_spec = [Partial()] my_rank = device_mesh.get_rank() @@ -416,7 +479,11 @@ def test_partial_to_shard(self, dtype): for input_size, shard_dim in input_sizes_and_shard_dim: shard_spec = [Shard(shard_dim)] +<<<<<<< HEAD partial_local = torch.ones(input_size, device=self.device_type, dtype=dtype) +======= + partial_local = torch.ones(input_size, device=self.device_type) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) partial_tensor = DTensor.from_local( partial_local, device_mesh, partial_spec, run_check=False ) @@ -445,7 +512,11 @@ def test_partial_to_shard(self, dtype): self.assertEqual(scatter_shard_tensor.placements, shard_spec) self.assertEqual( scatter_shard_tensor.to_local(), +<<<<<<< HEAD torch.ones(local_shape, dtype=dtype) * self.world_size, +======= + torch.ones(local_shape) * self.world_size, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) self.assertEqual( comm_mode.get_comm_counts()[funcol.reduce_scatter_tensor], 1 @@ -453,7 +524,11 @@ def test_partial_to_shard(self, dtype): @with_comms def test_redistribute_negative_shard_dim(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) local_tensor = torch.randn(12, 3, device=self.device_type, requires_grad=True) shard_spec = [Shard(1)] shard_minus_spec = [Shard(-1)] @@ -488,6 +563,7 @@ def test_redistribute_uneven_sharding(self): self.assertEqual(dt_full_tensor, input_tensor) @with_comms +<<<<<<< HEAD @parametrize("dtype", [torch.float32, torch.cfloat]) def test_redistribute_shard_dim_change(self, dtype): # test 1d device mesh @@ -503,6 +579,22 @@ def test_redistribute_shard_dim_change(self, dtype): torch.randn((5, 8), device=self.device_type, dtype=dtype), # uneven case 3 torch.randn((5, 5), device=self.device_type, dtype=dtype), +======= + def test_redistribute_shard_dim_change(self): + # test 1d device mesh + mesh_1d = DeviceMesh(self.device_type, torch.arange(self.world_size)) + data_to_test = [ + # evenly sharded case + torch.randn((8, 8), device=self.device_type), + # 3d or more dims + torch.randn((8, 8, 8), device=self.device_type), + # uneven case 1 + torch.randn((8, 5), device=self.device_type), + # uneven case 2 + torch.randn((5, 8), device=self.device_type), + # uneven case 3 + torch.randn((5, 5), device=self.device_type), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] sharding_src_dst_pairs = [([Shard(0)], [Shard(1)]), ([Shard(1)], [Shard(0)])] @@ -538,6 +630,7 @@ def test_redistribute_shard_dim_change(self, dtype): ) data_to_test_2d = [ # evenly sharded case +<<<<<<< HEAD torch.randn((8, 8), device=self.device_type, dtype=dtype), # 3d or more dims torch.randn((8, 8, 8), device=self.device_type, dtype=dtype), @@ -547,6 +640,17 @@ def test_redistribute_shard_dim_change(self, dtype): torch.randn((5, 8), device=self.device_type, dtype=dtype), # uneven case 3 torch.randn((5, 5), device=self.device_type, dtype=dtype), +======= + torch.randn((8, 8), device=self.device_type), + # 3d or more dims + torch.randn((8, 8, 8), device=self.device_type), + # uneven case 1 + torch.randn((8, 5), device=self.device_type), + # uneven case 2 + torch.randn((5, 8), device=self.device_type), + # uneven case 3 + torch.randn((5, 5), device=self.device_type), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] sharding_src_dst_pairs_2d = [ ([Shard(0), Shard(1)], [Shard(0), Shard(0)]), @@ -586,11 +690,18 @@ def test_redistribute_shard_dim_change(self, dtype): self.assertEqual(local_out_dt, local_expected_dt) @with_comms +<<<<<<< HEAD @parametrize("dtype", [torch.float32, torch.cfloat]) def test_shard_dim_alltoall(self, dtype): # init 2d mesh here so we can test when group_rank != global_rank mesh = init_device_mesh(self.device_type, (2, 2)) tensor = torch.randn(12, self.world_size, device=self.device_type, dtype=dtype) +======= + def test_shard_dim_alltoall(self): + # init 2d mesh here so we can test when group_rank != global_rank + mesh = init_device_mesh(self.device_type, (2, 2)) + tensor = torch.randn(12, self.world_size, device=self.device_type) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) new_tensor = shard_dim_alltoall(tensor, 0, 1, mesh, 0) meta_tensor = torch.randn(12, self.world_size, device="meta") @@ -600,9 +711,12 @@ def test_shard_dim_alltoall(self, dtype): self.assertEqual(new_tensor.stride(), new_meta_tensor.stride()) +<<<<<<< HEAD instantiate_parametrized_tests(RedistributeTest) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class MultiDimRedistributeTest(DTensorTestBase): @property def world_size(self) -> int: @@ -635,7 +749,11 @@ def test_multi_dim_mesh(self): dt = distribute_tensor(full_tensor, device_mesh, repl_inputs) if repl_inputs != inputs: +<<<<<<< HEAD # create a new DTensor reinterpreting some of the replicated entries as "Partial" +======= + # create a new DTensor reinterpreting some of the replicated entires as "Partial" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dt = DTensor.from_local( dt.to_local(), device_mesh, inputs, run_check=False ) diff --git a/test/distributed/tensor/test_tensor_ops.py b/test/distributed/tensor/test_tensor_ops.py index 0e75748be8a31..9cf24480a291c 100644 --- a/test/distributed/tensor/test_tensor_ops.py +++ b/test/distributed/tensor/test_tensor_ops.py @@ -6,7 +6,10 @@ DeviceMesh, distribute_tensor, DTensor, +<<<<<<< HEAD init_device_mesh, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Partial, Replicate, Shard, @@ -25,7 +28,11 @@ class DistTensorOpsTest(DTensorTestBase): @with_comms def test_aten_contiguous(self): # this op not covered by dtensor_ops +<<<<<<< HEAD mesh = self.build_device_mesh() +======= + mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._test_op( mesh, lambda x: torch.ops.aten.contiguous(x), @@ -34,7 +41,11 @@ def test_aten_contiguous(self): @with_comms def test_detach(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shard_spec = [Shard(0)] tensor_to_detach = torch.randn(12, 8, requires_grad=True) @@ -44,7 +55,11 @@ def test_detach(self): @with_comms def test_clone(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) specs = [[Replicate()], [Shard(0)]] tensor_to_clone = torch.randn(12, 8, requires_grad=True) for spec in specs: @@ -54,6 +69,7 @@ def test_clone(self): self.assertEqual(cloned_mat.to_local(), mat.to_local()) @with_comms +<<<<<<< HEAD def test_copy_(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -96,6 +112,10 @@ def test_copy_(self): @with_comms def test_contiguous(self): device_mesh = self.build_device_mesh() +======= + def test_contiguous(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tensor = torch.rand(3, 5, 6, requires_grad=True) sharding = [Shard(0)] dist_tensor = DTensor.from_local(tensor, device_mesh, sharding) @@ -121,7 +141,11 @@ def test_contiguous(self): @with_comms def test_inplace_op(self): +<<<<<<< HEAD mesh = self.build_device_mesh() +======= + mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) input_tensor = torch.randn((12, 3), device=self.device_type) dt_to_add = distribute_tensor(input_tensor, mesh, [Shard(0)]) dt_to_mul = dt_to_add.clone() @@ -148,7 +172,11 @@ def test_inplace_op(self): @with_comms def test_op_out_variant(self): +<<<<<<< HEAD mesh = self.build_device_mesh() +======= + mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) input_tensor = torch.randn((12, 3), device=self.device_type) sharded_dt_input = distribute_tensor(input_tensor, mesh, [Shard(0)]) expected_dt = sharded_dt_input.clone() + 3 @@ -169,7 +197,11 @@ def test_op_out_variant(self): @with_comms def test_empty_like(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shard_spec = [Shard(0)] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -180,7 +212,11 @@ def test_empty_like(self): @with_comms def test_fill_inplace(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shard_spec = [Shard(0)] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -192,7 +228,11 @@ def test_fill_inplace(self): @with_comms def test_full_like(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shard_spec = [Shard(0)] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -203,7 +243,11 @@ def test_full_like(self): @with_comms def test_ones_like(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shard_spec = [Shard(0)] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -214,7 +258,11 @@ def test_ones_like(self): @with_comms def test_ones_like_partial_sum(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shard_spec = [Partial()] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -227,7 +275,11 @@ def test_ones_like_partial_sum(self): @with_comms def test_fill_inplace_partial_sum(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shard_spec = [Partial()] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -243,7 +295,11 @@ def test_fill_inplace_partial_sum(self): @with_comms def test_zeros_like_partial_sum(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shard_spec = [Partial()] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -256,7 +312,11 @@ def test_zeros_like_partial_sum(self): @with_comms def test_zero_inplace(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shard_spec = [Shard(0)] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -268,7 +328,11 @@ def test_zero_inplace(self): @with_comms def test_zeros_like(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shard_spec = [Shard(0)] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -320,7 +384,11 @@ def test_stack(self): @with_comms def test_equal(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shard_spec = [Shard(0)] input_tensor_1 = torch.ones(4, 4) @@ -370,7 +438,11 @@ def _test_op(self, mesh, op_call, *args, **kwargs): @with_comms def test_new_full(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) comm_mode = CommDebugMode() global_tensor = torch.randn(12, 8) @@ -397,7 +469,11 @@ def test_new_full(self): @with_comms def test_new_empty_strided(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) comm_mode = CommDebugMode() shard_dim = 1 @@ -442,7 +518,11 @@ def test_new_empty_strided(self): @with_comms def test_scatter(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) comm_mode = CommDebugMode() # case 1 all replicate: input replicated, index/src replicated, output replicated @@ -476,7 +556,11 @@ def test_scatter(self): @with_comms def test_gather(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) comm_mode = CommDebugMode() # case 1 all replicate: input replicated, index replicated, output replicated @@ -527,7 +611,11 @@ def test_gather(self): @with_comms def test_index(self): meshes = [ +<<<<<<< HEAD self.build_device_mesh(), # 1D mesh +======= + DeviceMesh(self.device_type, list(range(self.world_size))), # 1D mesh +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO(@azzolini): un-comment when DTensorConverter supports N-D mesh # DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, -1)), # 2D mesh ] @@ -634,6 +722,7 @@ def test_index(self): ) @with_comms +<<<<<<< HEAD def test_index_put_scalar(self): device_mesh = init_device_mesh(self.device_type, (2, self.world_size // 2)) global_input = torch.randn(2, 4, 8, device=self.device_type) @@ -678,6 +767,10 @@ def test_index_put_tensor(self): @with_comms def test_where_type_promotion(self): mesh = self.build_device_mesh() # 1D mesh +======= + def test_where_type_promotion(self): + mesh = DeviceMesh(self.device_type, list(range(self.world_size))) # 1D mesh +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) specs = [[Shard(0)], [Replicate()]] for spec in specs: @@ -689,7 +782,11 @@ def test_where_type_promotion(self): @with_comms def test_dtensor_dtype_conversion(self): +<<<<<<< HEAD device_mesh = self.build_device_mesh() +======= + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shard_spec = [Shard(0)] # by default we start from bf16 dtype local_tenor = torch.randn(2, 8, dtype=torch.bfloat16) @@ -723,7 +820,11 @@ def test_dtensor_dtype_conversion(self): @with_comms def test_slice(self): +<<<<<<< HEAD mesh = self.build_device_mesh() # 1D mesh +======= + mesh = DeviceMesh(self.device_type, list(range(self.world_size))) # 1D mesh +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) comm_mode = CommDebugMode() shard_spec = [Shard(1)] @@ -748,6 +849,7 @@ def test_slice(self): self.assertEqual(sharded_out.full_tensor(), global_out) self.assertEqual(sharded_dtensor.grad.full_tensor(), global_tensor.grad) +<<<<<<< HEAD @with_comms def test_split_on_partial(self): self.run_subtests( @@ -776,6 +878,8 @@ def _test_split_on_partial(self, reduce_op: str, split_size: int, split_dim: int dim=split_dim, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/distributed/tensor/test_utils.py b/test/distributed/tensor/test_utils.py index dbfbac12223bb..1581f0a963ec9 100644 --- a/test/distributed/tensor/test_utils.py +++ b/test/distributed/tensor/test_utils.py @@ -179,7 +179,11 @@ def test_compute_global_tensor_shape_1D_invalid_shape(self): ) with self.assertRaisesRegex( RuntimeError, +<<<<<<< HEAD "Non-sharded dimensions should have identical size across ranks.", +======= + "Non-sharded dimentions should have identical size across ranks.", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): _ = compute_global_tensor_shape( local_shape, diff --git a/test/distributed/tensor/test_view_ops.py b/test/distributed/tensor/test_view_ops.py index 815b588a7ded7..1dd80070abe73 100644 --- a/test/distributed/tensor/test_view_ops.py +++ b/test/distributed/tensor/test_view_ops.py @@ -10,9 +10,13 @@ from torch.distributed.tensor import ( DeviceMesh, distribute_tensor, +<<<<<<< HEAD DTensor, init_device_mesh, Partial, +======= + init_device_mesh, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Replicate, Shard, ) @@ -27,7 +31,11 @@ view_groups, ) from torch.distributed.tensor.debug import CommDebugMode +<<<<<<< HEAD from torch.distributed.tensor.placement_types import _StridedShard, Placement +======= +from torch.distributed.tensor.placement_types import Placement +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, @@ -170,6 +178,7 @@ def call_dt_test(self, op, args, kwargs, device_mesh: DeviceMesh): *(device_mesh.ndim * [sharding_choices]) ) +<<<<<<< HEAD outer_mesh = device_mesh["outer"] inner_mesh = device_mesh["inner"] inner_mesh_size = inner_mesh.size() @@ -198,6 +207,10 @@ def call_dt_test(self, op, args, kwargs, device_mesh: DeviceMesh): ) else: in_dt = distribute_tensor(args[0], device_mesh, in_shard) +======= + for in_shard in all_sharding_choices: + in_dt = distribute_tensor(args[0], device_mesh, in_shard) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) comm_mode = CommDebugMode() with comm_mode: @@ -228,13 +241,20 @@ def test_illegal_views(self): shard.view(-1) shard = dtensor.redistribute(device_mesh=device_mesh, placements=[Shard(dim=1)]) +<<<<<<< HEAD with self.assertRaisesRegex(RuntimeError, "Sharding propagation failed"): +======= + with self.assertRaisesRegex( + RuntimeError, "Attempted to flatten sharded dimension" + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shard.view(-1) # 8 is the uneven case since mesh dim is 6 tensor = torch.randn((8, 256)) dtensor = distribute_tensor(tensor, device_mesh, [Replicate()]) shard = dtensor.redistribute(device_mesh=device_mesh, placements=[Shard(dim=0)]) +<<<<<<< HEAD with self.assertRaisesRegex(RuntimeError, "Sharding propagation failed"): shard.view(-1) @@ -251,6 +271,17 @@ def test_view_ops(self): mesh_shape = (dist.get_world_size() // 2, 2) self.device_mesh = init_device_mesh( self.device_type, mesh_shape=mesh_shape, mesh_dim_names=("outer", "inner") +======= + with self.assertRaisesRegex( + RuntimeError, "Attempted to flatten unevenly sharded dimension" + ): + shard.view(-1) + + @with_comms + def test_view_ops(self): + self.device_mesh = DeviceMesh( + self.device_type, torch.arange(dist.get_world_size()).view(-1, 2) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) self.dimmap_test(torch.atleast_1d, (randn(()),), (Singleton(),)) self.dimmap_test(torch.atleast_1d, (randn(24),), (InputDim(0),)) @@ -475,6 +506,10 @@ def test_view_ops(self): (randn(42, 24, 36), 1), (InputDim(0), Singleton(), InputDim(1), InputDim(2)), ) +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.dimmap_test( Tensor.view, (randn(6, 12, 24), 72, 24), @@ -641,6 +676,7 @@ def test_view_redistribution(self): mesh = init_device_mesh(self.device_type, (self.world_size,)) dtensor_x = distribute_tensor(x, mesh, (Shard(0),)) +<<<<<<< HEAD with self.assertRaisesRegex(RuntimeError, "Sharding propagation failed"): dtensor_x.view(-1, 8) @@ -663,6 +699,13 @@ def test_squeeze_(self): ) self.assertEqual(dist_x.placements, [Partial(), Shard(0)]) +======= + with self.assertRaisesRegex( + RuntimeError, "Attempted to flatten unevenly sharded dimension" + ): + dtensor_x.view(-1, 8) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/distributed/tensor/test_xla_integration.py b/test/distributed/tensor/test_xla_integration.py index e39931e1f1830..fa1eed7d96c7e 100644 --- a/test/distributed/tensor/test_xla_integration.py +++ b/test/distributed/tensor/test_xla_integration.py @@ -150,7 +150,11 @@ def text_xla_distribute_module(self): def shard_params(mod_name, mod, mesh): shard_spec = [Shard(0)] +<<<<<<< HEAD # annotate fc1 and fc2 +======= + # annoate fc1 and fc2 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(mod, nn.Linear): for _, param in mod.named_parameters(): # annotate the parameter tensors directly diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py index 1857feffd9394..27c9c8b9edd24 100644 --- a/test/distributed/test_c10d_common.py +++ b/test/distributed/test_c10d_common.py @@ -293,6 +293,7 @@ def forward(self, x): return self.conv3(x) +<<<<<<< HEAD # A model involving FFTs, used to test DDP with complex tensors class FFTModel(nn.Module): def __init__(self, hin, win, n_features): @@ -310,6 +311,8 @@ def forward(self, x): return x +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class Task(nn.Module): def __init__(self) -> None: super().__init__() diff --git a/test/distributed/test_c10d_functional_native.py b/test/distributed/test_c10d_functional_native.py index bafc781b591c6..4e2661004f1c9 100644 --- a/test/distributed/test_c10d_functional_native.py +++ b/test/distributed/test_c10d_functional_native.py @@ -1,6 +1,9 @@ # Owner(s): ["module: c10d"] import gc +<<<<<<< HEAD import re +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import threading import unittest from datetime import timedelta @@ -10,7 +13,11 @@ import torch.distributed as dist import torch.distributed._functional_collectives as funcol from torch._C import FileCheck +<<<<<<< HEAD from torch._inductor.utils import fresh_cache, run_and_get_code, run_and_get_triton_code +======= +from torch._inductor.utils import fresh_cache, run_and_get_triton_code +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.distributed._functional_collectives import ( all_gather_into_tensor_coalesced, all_gather_tensor, @@ -78,6 +85,13 @@ def device(self) -> torch.device: return torch.device(f"cuda:{self.rank}") def _init_process_group(self) -> None: +<<<<<<< HEAD +======= + # Allow testing aoti after torch.compile + torch._inductor.config.triton.store_cubin = True + torch._inductor.config.debug = True + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.cuda.set_device(self.device) store = dist.FileStore(self.file_name, self.world_size) dist.init_process_group( @@ -489,7 +503,11 @@ def run(self): try: func(arg) compiled(arg) +<<<<<<< HEAD except BaseException as exc: # noqa: B036 +======= + except BaseException as exc: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.exc = exc def join(self): @@ -710,6 +728,7 @@ def test_collectives(self) -> None: self.assertEqual(pg.dels, 4) +<<<<<<< HEAD def find_buffer_assignments(code): pattern = r"buf(\d+) = empty_strided_" matches = re.finditer(pattern, code) @@ -774,6 +793,14 @@ def test_inductor_all_reduce_cpu(self): class CompileTest(TestCase): def setUp(self): super().setUp() +======= +class CompileTest(TestCase): + def setUp(self): + super().setUp() + # Allow testing aoti after torch.compile + torch._inductor.config.triton.store_cubin = True + torch._inductor.config.debug = True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.rank = 0 self.world_size = 2 @@ -807,6 +834,7 @@ def func(arg: torch.Tensor) -> torch.Tensor: compiled = torch.compile(func) code = run_and_get_triton_code(compiled, arg) +<<<<<<< HEAD buf0, buf1 = find_buffer_assignments(code) ( FileCheck() @@ -834,6 +862,24 @@ def func(arg: torch.Tensor) -> torch.Tensor: "aoti_torch_cpu__c10d_functional_all_reduce(buf" ).check_count("aoti_torch_delete_tensor_object(buf", 4).run(code) +======= + ( + FileCheck() + .check("buf0 = empty") + .check("buf7 = empty") + # Expect in-place with inductor allocated buf + .check("torch.ops._c10d_functional.all_reduce_.default(buf0") + .check("torch.ops._c10d_functional.wait_tensor.default(buf0") + # Expect no in-place with graph input (buf5 is a clone) + .check("torch.ops._c10d_functional.all_reduce_.default(buf7") + .check("torch.ops._c10d_functional.wait_tensor.default(buf7") + # Expect no extra copy on return + .check("return (buf0, buf7, )") + .run(code) + ) + assert "= torch.ops._c10d_functional.wait_tensor.default" not in code + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Test aoti AOTIRunnerUtil.run(func, (arg,)) torch.cuda.synchronize() @@ -854,6 +900,7 @@ def func(args: list[torch.Tensor]) -> torch.Tensor: args = [torch.rand(4, 4, device="cuda") for _ in range(2)] compiled = torch.compile(func) code = run_and_get_triton_code(compiled, args) +<<<<<<< HEAD buf0, buf1, buf2, buf3 = find_buffer_assignments(code) ( FileCheck() @@ -875,6 +922,28 @@ def func(args: list[torch.Tensor]) -> torch.Tensor: .check(f"torch.ops._c10d_functional.wait_tensor.default({buf3}") # Expect no extra copy on return .check(f"return ({buf0}, {buf2}, {buf1}, {buf3}, )") +======= + ( + FileCheck() + .check("buf0 = empty") + .check("buf5 = empty") + .check("buf1 = empty") + .check("buf6 = empty") + # Expect in-place with inductor allocated buf + .check( + "torch.ops._c10d_functional.all_reduce_coalesced_.default([buf0, buf1]" + ) + # Expect no in-place with graph input (buf5, buf6 are clones) + .check( + "torch.ops._c10d_functional.all_reduce_coalesced_.default([buf5, buf6]" + ) + .check("torch.ops._c10d_functional.wait_tensor.default(buf0") + .check("torch.ops._c10d_functional.wait_tensor.default(buf1") + .check("torch.ops._c10d_functional.wait_tensor.default(buf5") + .check("torch.ops._c10d_functional.wait_tensor.default(buf6") + # Expect no extra copy on return + .check("return (buf0, buf1, buf5, buf6, )") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .run(code) ) assert "= torch.ops._c10d_functional.wait_tensor.default" not in code @@ -896,6 +965,7 @@ def func(arg: torch.Tensor) -> torch.Tensor: compiled = torch.compile(func) code = run_and_get_triton_code(compiled, arg) +<<<<<<< HEAD (buf0,) = find_buffer_assignments(code) ( FileCheck() @@ -905,6 +975,16 @@ def func(arg: torch.Tensor) -> torch.Tensor: .check(f"torch.ops._c10d_functional.all_reduce_.default({buf0}") .check(f"torch.ops._c10d_functional.wait_tensor.default({buf0}") .check(f"return ({buf0}") +======= + ( + FileCheck() + .check("buf0 = empty") + # We always call .contiguous() on the input to all_reduce_, + # so input will not be a view anymore. + .check("torch.ops._c10d_functional.all_reduce_.default(buf0") + .check("torch.ops._c10d_functional.wait_tensor.default(buf0") + .check("return (buf0") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .run(code) ) @@ -951,6 +1031,7 @@ def func(arg: torch.Tensor) -> torch.Tensor: arg = torch.rand(4, 4, device="cuda") compiled = torch.compile(func) code = run_and_get_triton_code(compiled, arg) +<<<<<<< HEAD buf0, buf1 = find_buffer_assignments(code) ( FileCheck() @@ -966,6 +1047,22 @@ def func(arg: torch.Tensor) -> torch.Tensor: .check(f"extern_kernels.mm(arg0_1, {buf1}, out=buf8") # Expect no extra copy on return .check(f"return ({buf1}, buf8, )") +======= + ( + FileCheck() + # Expect allocation + .check("buf0 = empty") + .check("torch.ops._c10d_functional.all_reduce_.default(buf0") + .check("torch.ops._c10d_functional.wait_tensor.default(buf0") + # Expect allocation + .check("buf7 = empty") + .check("extern_kernels.mm(arg0_1, buf0, out=buf7") + # Expect buf0 to be reused + .check("buf8 = buf0; del buf0 # reuse") + .check("extern_kernels.mm(arg0_1, buf7, out=buf8") + # Expect no extra copy on return + .check("return (buf7, buf8, )") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .run(code) ) assert "= torch.ops._c10d_functional.wait_tensor.default" not in code @@ -1180,20 +1277,35 @@ def func(arg: torch.Tensor) -> torch.Tensor: compiled = torch.compile(func) code = run_and_get_triton_code(compiled, arg) +<<<<<<< HEAD buf0, buf1 = find_buffer_assignments(code) ( FileCheck() .check(f"{buf0} = empty") .check(f"buf1 = {buf0}") .check(f"{buf1} = empty") +======= + ( + FileCheck() + .check("buf0 = empty") + .check("buf1 = buf0") + .check("buf8 = empty") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Expect in-place with inductor allocated buf .check("torch.ops._c10d_functional.broadcast_.default(buf1") .check("torch.ops._c10d_functional.wait_tensor.default(buf1") # Expect no in-place with graph input (buf5 is a clone) +<<<<<<< HEAD .check(f"torch.ops._c10d_functional.broadcast_.default({buf1}") .check(f"torch.ops._c10d_functional.wait_tensor.default({buf1}") # Expect no extra copy on return .check(f"return (buf1, {buf1}, )") +======= + .check("torch.ops._c10d_functional.broadcast_.default(buf8") + .check("torch.ops._c10d_functional.wait_tensor.default(buf8") + # Expect no extra copy on return + .check("return (buf1, buf8, )") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .run(code) ) diff --git a/test/distributed/test_c10d_gloo.py b/test/distributed/test_c10d_gloo.py index 0b265e65cf57c..9ec39860c0070 100644 --- a/test/distributed/test_c10d_gloo.py +++ b/test/distributed/test_c10d_gloo.py @@ -25,7 +25,10 @@ import test_c10d_common from test_c10d_common import ( +<<<<<<< HEAD FFTModel, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) gpus_for_rank, LOOPBACK, ModuleForDdpCommHook, @@ -135,6 +138,7 @@ def simple_reduce_tests(rank, world_size): ), ) +<<<<<<< HEAD # Extend tests for cfloat dtype tests.extend( ( @@ -161,6 +165,8 @@ def simple_reduce_tests(rank, world_size): ), ) ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return tests @@ -400,6 +406,7 @@ def broadcast(xs, rootRank, rootTensor): torch.tensor([i * num + j], dtype=torch.float32), output[1] ) +<<<<<<< HEAD # Run with 1 input tensor of cfloat dtype x = fn(torch.tensor([complex(self.rank, self.rank)], dtype=torch.cfloat)) output = broadcast([x], i, 0) @@ -407,6 +414,8 @@ def broadcast(xs, rootRank, rootTensor): torch.tensor([complex(i, i)], dtype=torch.cfloat), output[0] ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Test overloaded convenience function x = torch.tensor([self.rank + 1.0]) fut = pg.broadcast(x, root=0).get_future() @@ -477,6 +486,7 @@ def test_allreduce_checks(self): opts = c10d.AllreduceOptions() pg.allreduce([t1, t3], opts) +<<<<<<< HEAD @requires_gloo() def test_allreduce_op_timeout(self): store = c10d.FileStore(self.file_name, self.world_size) @@ -505,6 +515,8 @@ def test_allreduce_overall_timeout(self): with self.assertRaisesRegex(RuntimeError, "Timed out waiting 1ms"): pg.allreduce([t1]).wait() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _test_allreduce_basics(self, fn): store = c10d.FileStore(self.file_name, self.world_size) pg = self._create_process_group_gloo( @@ -1189,7 +1201,11 @@ def test_gather_basics_cuda(self): @requires_gloo() def test_gather_noncontiguous_input(self): # Take a column of 2D tensor, such that memory is not dense +<<<<<<< HEAD self._test_gather_basics(lambda t: t.expand(2, 2).tril().contiguous()[:, 0]) +======= + self._test_gather_basics(lambda t: t.expand(2, 2).contiguous()[:, 0]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _test_gather_stress(self, inputs, fn): store = c10d.FileStore(self.file_name, self.world_size) @@ -1323,7 +1339,11 @@ def test_allgather_basics_cuda(self): @requires_gloo() def test_allgather_noncontiguous_input(self): # Take a column of 2D tensor, such that memory is not dense +<<<<<<< HEAD self._test_allgather_basics(lambda t: t.expand(2, 2).tril().contiguous()[:, 0]) +======= + self._test_allgather_basics(lambda t: t.expand(2, 2).contiguous()[:, 0]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @requires_gloo() def test_allgather_inference_mode(self): @@ -1624,6 +1644,7 @@ def test_barrier_implies_wait(self): for i, tensor in enumerate(tensors): self.assertEqual(torch.full(size, float(i * self.world_size)), tensor) +<<<<<<< HEAD @skip_if_lt_x_gpu(2) @requires_gloo() @skipIfRocm @@ -1655,6 +1676,8 @@ def test_send_recv_complex(self): pg.recv([recv_tensor], 0, 0).wait() self.assertEqual(send_tensor, recv_tensor) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class DistributedDataParallelTest( test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase @@ -2320,6 +2343,7 @@ def div_by_world_size(fut): self._run_and_verify_sparse_gradients(vanilla_model, ddp_model) +<<<<<<< HEAD @requires_gloo() def test_ddp_complex_params(self): process_group = self._get_process_group() @@ -2338,6 +2362,8 @@ def test_ddp_complex_params(self): loss.backward() optimizer.step() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class ReducerModule(nn.Module): def __init__(self) -> None: @@ -2517,7 +2543,11 @@ def tearDown(self) -> None: def _verify_trace(self, t, is_json): ver = t["version"] +<<<<<<< HEAD self.assertEqual(ver, "2.10") +======= + self.assertEqual(ver, "2.9") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pg_config = t["pg_config"] self.assertEqual(len(pg_config), 1) default_pg_info = pg_config["0"] diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 2ac332f65fd06..96dc9f4dc11bb 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -29,6 +29,7 @@ import test_c10d_common +<<<<<<< HEAD from test_c10d_common import ( ConvNet, DoubleGpuNet, @@ -36,6 +37,9 @@ gpus_for_rank, ModuleForDdpCommHook, ) +======= +from test_c10d_common import ConvNet, DoubleGpuNet, gpus_for_rank, ModuleForDdpCommHook +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch.distributed as dist import torch.distributed.algorithms.ddp_comm_hooks.default_hooks as default @@ -607,7 +611,11 @@ def test_nan_check(self): def _helper_test_extra_cuda_context_by_nvml(self): """ +<<<<<<< HEAD A helper for `test_extra_cuda_context`, if pynvml is available. +======= + A helper for `test_extra_cuda_context`, if pynvml is avaiable. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pynvml provides python bindings for NVIDIA NVML functionalities. Here we are interested in: nvmlDeviceGetComputeRunningProcesses """ @@ -640,11 +648,26 @@ def _helper_test_extra_cuda_context_by_nvml(self): def _helper_test_extra_cuda_context_by_memory(self): """ +<<<<<<< HEAD A helper for `test_extra_cuda_context`, if pynvml is NOT available. +======= + A helper for `test_extra_cuda_context`, if pynvml is NOT avaiable. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) If extra context is created, it would manifest into device 0's memory usage. """ device = torch.device(f"cuda:{self.rank:d}") x = torch.empty((1,), device=device) +<<<<<<< HEAD +======= + + # We need this barrier to ensure that all nodes have completed init_process_group + # If rank=0 gets a mem snapshot before other nodes have finished init_process_group, + # then we artificially see a bump in memory usage. As per the following comment, + # we are going to be moving away from this function: + # https://github.com/pytorch/pytorch/pull/154174#discussion_r2105065931 + c10d.barrier() + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Rank 0 takes a snapshot before collective -- this snapshot should have # included rank 0's own context. if self.rank == 0: @@ -1118,7 +1141,11 @@ def test_non_blocking_with_eager_init(self): os.environ["TORCH_NCCL_NONBLOCKING_TIMEOUT"] = "100" store = c10d.FileStore(self.file_name, self.world_size) device = torch.device(f"cuda:{self.rank}") +<<<<<<< HEAD # bound device to trigger eager init mode +======= + # bound device to triger eager init mode +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pg = self._create_process_group_nccl(store, self.opts(), device_id=device) backend = pg._get_backend(torch.device(device)) self.assertEqual(backend.comm_split_count(), 0) @@ -1223,6 +1250,7 @@ def test_init_with_idx(self): ) dist.all_reduce(torch.empty(1, device=torch.device("cuda", device_idx))) +<<<<<<< HEAD @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") def test_block_current_stream(self): @@ -1238,6 +1266,8 @@ def test_block_current_stream(self): work.wait() torch.cuda.synchronize() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class DistributedDataParallelTest( test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase @@ -2558,6 +2588,28 @@ def test_channels_last_contig(self): @requires_nccl() @skip_if_lt_x_gpu(2) def test_ddp_complex_params(self): +<<<<<<< HEAD +======= + class FFTModel(nn.Module): + def __init__(self, hin, win, n_features): + super().__init__() + self.hin = hin + self.win = win + self.weight = nn.Parameter( + torch.ones( + (n_features, n_features, hin, win // 2 + 1), dtype=torch.cfloat + ) + ) + + def forward(self, x): + xc = torch.fft.rfft2( + x, s=(self.hin, self.win), dim=(-2, -1), norm="ortho" + ) + xcw = torch.einsum("nchw,cohw->nohw", xc, self.weight) + x = torch.fft.irfft2(xcw, dim=(-2, -1), norm="ortho") + return x + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) process_group = self._get_process_group() device_id = gpus_for_rank(self.world_size)[self.rank][0] N, C, H, W = 1, 16, 64, 64 @@ -2839,6 +2891,7 @@ def _reduce_timeout(self): os.environ["TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC"] = "1000" @requires_nccl() +<<<<<<< HEAD @skip_if_lt_x_gpu(3) @skip_if_rocm_multiprocess def test_send_recv_non_dense_tensor(self): @@ -2858,6 +2911,8 @@ def test_send_recv_non_dense_tensor(self): dist.recv(block, src=0) @requires_nccl() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking") @skip_if_lt_x_gpu(3) @skip_if_rocm_multiprocess @@ -3001,7 +3056,11 @@ def assert_fut_success(fut): time.sleep(4) self.assertEqual(process_group.get_error(), ErrorType.REMOTE_ERROR) +<<<<<<< HEAD # Mimicking all ranks sensing the timeout, abort +======= + # Mimicing all ranks sensing the timeout, abort +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) process_group.abort() if prev_nccl_async_error_handling is not None: @@ -3105,7 +3164,11 @@ def test_invalid_nccl_blocking_wait_env(self): self._run_invalid_nccl_blocking_wait_env("4294967295") +<<<<<<< HEAD class NcclUserBufferRegistrationTest(MultiProcessTestCase): +======= +class NcclRegistrationTest(MultiProcessTestCase): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def setUp(self): super().setUp() # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests @@ -3133,6 +3196,7 @@ def tearDown(self): @requires_multicast_support() def test_nccl_user_buffer_registration(self): store = c10d.FileStore(self.file_name, self.world_size) +<<<<<<< HEAD device = torch.device(f"cuda:{self.rank}") c10d.init_process_group( backend="nccl", @@ -3141,6 +3205,12 @@ def test_nccl_user_buffer_registration(self): store=store, device_id=device, ) +======= + c10d.init_process_group( + backend="nccl", rank=self.rank, world_size=self.world_size, store=store + ) + device = torch.device(f"cuda:{self.rank}") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.cuda.set_device(self.rank) pg = c10d.distributed_c10d._get_default_group() backend = pg._get_backend(torch.device(device)) @@ -3182,6 +3252,7 @@ def test_nccl_user_buffer_registration(self): @requires_multicast_support() def test_nccl_window_registration(self): store = c10d.FileStore(self.file_name, self.world_size) +<<<<<<< HEAD device = torch.device(f"cuda:{self.rank}") with torch.cuda.device(device): # Eager init the nccl comm so that we don't implicitly create one during register_mem_pool @@ -3224,12 +3295,47 @@ def test_nccl_window_registration(self): # clean up memory del tensor, pool +======= + c10d.init_process_group( + backend="nccl", rank=self.rank, world_size=self.world_size, store=store + ) + device = torch.device(f"cuda:{self.rank}") + torch.cuda.set_device(self.rank) + pg = c10d.distributed_c10d._get_default_group() + backend = pg._get_backend(torch.device(device)) + + # Use NCCL memory allocator + # enable symmetric memory usage in NCCL + pool = torch.cuda.MemPool(backend.mem_allocator, symm_mem=True) + + # allocate memory with ncclMemAlloc + # note: symmetric kernels are not available for dtypes like torch.int64 + with torch.cuda.use_mem_pool(pool): + tensor = torch.arange(1024 * 1024 * 2, device=device, dtype=torch.float32) + + # register buffers to NCCL + backend.register_mem_pool(pool) + + # allreduce now should use NVIDIA Switches + pg.allreduce(tensor).wait() + torch.cuda.synchronize(device=device) + + # de-register buffers from NCCL + backend.deregister_mem_pool(pool) + + # clean up memory + del tensor, pool +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with open(os.environ["NCCL_DEBUG_FILE"]) as f: nccl_debug_file_content = f.read() # if buffers were registered and symmetric kernels ran, NCCL_DEBUG # should show successful registration in debug output +<<<<<<< HEAD self.assertRegex(nccl_debug_file_content, "Symmetric") +======= + self.assertRegex(nccl_debug_file_content, "[Symmetric]") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase): @@ -3762,6 +3868,7 @@ def test_allgather_base(self): self.assertEqual(output_tensor, tensor) @requires_nccl() +<<<<<<< HEAD @skip_if_lt_x_gpu(2) def test_allgather_noncontig(self): store = dist.FileStore(self.file_name, self.world_size) @@ -3783,6 +3890,8 @@ def test_allgather_noncontig(self): self.assertEqual(o, tensor) @requires_nccl() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skip_if_lt_x_gpu(1) @parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) def test_allgather_float8(self, float8_dtype): @@ -4335,7 +4444,11 @@ def local_device(self): def _join_processes(self, fn): # We need to patch sys.exit() as skip_if will use sys.exit() and +<<<<<<< HEAD # the exit code from the this process will not be caught. +======= + # the exit code from the this process will not be catched. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with mock.patch("sys.exit"): fn() super()._join_processes(fn) @@ -4394,12 +4507,19 @@ def started_or_scheduled(self, timing_enabled): class NCCLTraceTest(NCCLTraceTestBase): def _verify_trace(self, t, include_collectives, timing_enabled, is_json): ver = t["version"] +<<<<<<< HEAD self.assertEqual(ver, "2.10") comm_lib_version = t["comm_lib_version"] torch_comm_lib_version = torch.cuda.nccl.version() self.assertEqual( comm_lib_version, ".".join(str(v) for v in torch_comm_lib_version) ) +======= + self.assertEqual(ver, "2.9") + nccl_version = t["nccl_version"] + torch_nccl_version = torch.cuda.nccl.version() + self.assertEqual(nccl_version, ".".join(str(v) for v in torch_nccl_version)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pg_config = t["pg_config"] self.assertEqual(len(pg_config), 1) default_pg_info = pg_config["0"] diff --git a/test/distributed/test_c10d_ops_nccl.py b/test/distributed/test_c10d_ops_nccl.py index 9c22cf116d589..e0560399cf85d 100644 --- a/test/distributed/test_c10d_ops_nccl.py +++ b/test/distributed/test_c10d_ops_nccl.py @@ -25,7 +25,11 @@ from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_distributed import ( init_multigpu_helper, +<<<<<<< HEAD MultiProcContinuousTest, +======= + MultiProcContinousTest, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) requires_nccl, requires_nccl_version, sm_is_or_higher_than, @@ -45,7 +49,11 @@ sys.exit(0) +<<<<<<< HEAD class ProcessGroupNCCLOpTest(MultiProcContinuousTest): +======= +class ProcessGroupNCCLOpTest(MultiProcContinousTest): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @classmethod def backend_str(cls) -> str: return "nccl" @@ -936,6 +944,7 @@ def test_reduce_scatter_float8(self): ) torch.testing.assert_close(output_tensor, expected) +<<<<<<< HEAD @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") def test_reduce_scatter_bfloat16(self): device = torch.device("cuda", self.rank_to_GPU[self.rank][0]) @@ -955,6 +964,8 @@ def test_reduce_scatter_bfloat16(self): ) torch.testing.assert_close(output_tensor, expected) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") def test_barrier(self): diff --git a/test/distributed/test_c10d_pypg.py b/test/distributed/test_c10d_pypg.py index 65faf2075daa6..735bbe66b8dd4 100644 --- a/test/distributed/test_c10d_pypg.py +++ b/test/distributed/test_c10d_pypg.py @@ -1,7 +1,10 @@ # Owner(s): ["oncall: distributed"] +<<<<<<< HEAD import time import unittest +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import weakref import test_c10d_common @@ -12,7 +15,10 @@ from torch._C._distributed_c10d import _create_work_from_future from torch.futures import Future from torch.nn.parallel import DistributedDataParallel as DDP +<<<<<<< HEAD from torch.testing._internal.common_cuda import TEST_CUDA +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_distributed import MultiThreadedTestCase from torch.testing._internal.common_utils import run_tests, TestCase @@ -181,6 +187,7 @@ def use_wrapper(self): return True +<<<<<<< HEAD class BlockWork(dist._Work): """ Dummy work that is used to test blocking the current stream. @@ -194,6 +201,8 @@ def get_future(self): return self.future_ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestPyProcessGroup(TestCase): def test_attr_overrides(self): pg = DummyAttrProcessGroup(0, 1) @@ -213,6 +222,7 @@ def test_abort_shutdown(self) -> None: pg.abort() pg.shutdown() +<<<<<<< HEAD @unittest.skipIf(not TEST_CUDA, "no cuda/xpu") def test_block_current_stream(self) -> None: torch.cuda.synchronize() @@ -271,6 +281,8 @@ def test_block_current_stream_use_after_free(self) -> None: stream.synchronize() self.assertTrue(event.query()) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/distributed/test_collective_utils.py b/test/distributed/test_collective_utils.py index 791aafa5a3a6b..f02c1924559ca 100644 --- a/test/distributed/test_collective_utils.py +++ b/test/distributed/test_collective_utils.py @@ -2,6 +2,7 @@ from unittest import mock +<<<<<<< HEAD import torch import torch.distributed as c10d from torch.distributed.collective_utils import ( @@ -21,6 +22,12 @@ TestCase, ) from torch.testing._internal.distributed.fake_pg import FakeStore +======= +import torch.distributed as c10d +from torch.distributed.collective_utils import all_gather, broadcast +from torch.testing._internal.common_distributed import MultiProcessTestCase +from torch.testing._internal.common_utils import run_tests +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestCollectiveUtils(MultiProcessTestCase): @@ -131,6 +138,7 @@ def test_all_gather_result_raises_exceptions_from_func( with self.assertRaisesRegex(Exception, expected_exception): all_gather(data_or_fn=func) +<<<<<<< HEAD @parametrize("device", ["cpu", "cuda"]) def test_check_rng_sync( self, @@ -211,6 +219,8 @@ def test_summarize_ranks(self): instantiate_parametrized_tests(TestCollectiveUtils) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/distributed/test_composability.py b/test/distributed/test_composability.py index b87e85a9a458a..61c41d905853a 100644 --- a/test/distributed/test_composability.py +++ b/test/distributed/test_composability.py @@ -19,7 +19,11 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_distributed import ( +<<<<<<< HEAD MultiProcContinuousTest, +======= + MultiProcContinousTest, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) requires_nccl, skip_if_lt_x_gpu, ) @@ -91,7 +95,11 @@ def loss_fn(y, target, scale=1e-4): return torch.nn.functional.cross_entropy(y, target) * scale +<<<<<<< HEAD class ComposabilityTest(MultiProcContinuousTest): +======= +class ComposabilityTest(MultiProcContinousTest): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @classmethod def backend_str(cls) -> str: # Testing with NCCL backend diff --git a/test/distributed/test_compute_comm_reordering.py b/test/distributed/test_compute_comm_reordering.py index 986fc2a0247d5..a101a7fe00aee 100644 --- a/test/distributed/test_compute_comm_reordering.py +++ b/test/distributed/test_compute_comm_reordering.py @@ -179,10 +179,14 @@ def func(a): .check("extern_kernels.mm") .check("triton_poi_fused_relu") .check("torch.ops._c10d_functional.all_reduce_.default") +<<<<<<< HEAD .check_same("buf0") # mm not use buf prior to wait_tensor .check("extern_kernels.mm") .check_not("buf0") +======= + .check("extern_kernels.mm") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .check("torch.ops._c10d_functional.wait_tensor.default") .check("extern_kernels.mm") .run(code) @@ -259,11 +263,14 @@ def func(a, *, tag, ranks, group_size): "reorder_compute_for_overlap", ], ) +<<<<<<< HEAD @patch.object( torch._inductor.config, "runtime_estimations_mms_benchmark", False, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_reorder_compute_for_overlap(self): def func(a, *, tag, ranks, group_size): ar = _functional_collectives.all_reduce(a, "sum", ranks, tag) diff --git a/test/distributed/test_device_mesh.py b/test/distributed/test_device_mesh.py index 5672171d0be4d..a4e55b47bad78 100644 --- a/test/distributed/test_device_mesh.py +++ b/test/distributed/test_device_mesh.py @@ -5,7 +5,10 @@ import torch import torch.distributed as dist import torch.distributed._functional_collectives as funcol +<<<<<<< HEAD from torch._C._distributed_c10d import Backend as C10dBackend +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._subclasses.fake_tensor import FakeTensorMode from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh from torch.distributed.distributed_c10d import ( @@ -31,7 +34,11 @@ DTensorTestBase, with_comms, ) +<<<<<<< HEAD from torch.testing._internal.distributed.fake_pg import FakeProcessGroup, FakeStore +======= +from torch.testing._internal.distributed.fake_pg import FakeStore +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.utils._typing_utils import not_none @@ -579,6 +586,7 @@ def test_raises_mesh_shape_mesh_dim_names_mismatch(self): mesh_dim_names=["dp", "tp"], ) +<<<<<<< HEAD def _test_backend_override_argument_dict_with_idx_and_backend(self): opts = FakeProcessGroup.Options() opts.fake_option = 42 @@ -688,6 +696,8 @@ def test_backend_override_argument_errors(self): backend_override={42: "bar"}, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestDeviceMeshGetItem(DTensorTestBase): @property @@ -776,10 +786,13 @@ def test_get_item_3d(self): self.assertEqual(hsdp_mesh_2.mesh.tolist(), hsdp_group[hsdp_group_idx]) self.assertEqual(hsdp_mesh_1, hsdp_mesh_2) +<<<<<<< HEAD # Test slicing out 1D mesh from a sub-2D mesh. shard_mesh = hsdp_mesh_2["Shard"] self.assertEqual(shard_mesh.mesh.tolist(), shard_group[shard_group_idx]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @with_comms def test_cache_and_reuse_submesh_slice_result(self): mesh = init_device_mesh(self.device_type, (2, 4), mesh_dim_names=("dp", "tp")) diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py index d3436bbe47548..f60ad171dc894 100644 --- a/test/distributed/test_dynamo_distributed.py +++ b/test/distributed/test_dynamo_distributed.py @@ -1814,7 +1814,11 @@ def test_fsdp_skip_guards(self): Note: comptime prints the guards before the time they get installed or not installed, so in both cases (skip or no skip) the same guards get printed. The difference is that in the skip case, they show up +<<<<<<< HEAD with a special 'guard source' which will cause them to not be installed. So all we check for is the expected +======= + with a special 'guard source' which will cuase them to not be installed. So all we check for is the expected +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) guard source 'local_fsdp_module'. """ global GUARDS_FILE @@ -1871,7 +1875,11 @@ def _(ctx): def test_fsdp_skip_register_attr_or_module(self): """ +<<<<<<< HEAD ensure FSDP module is not registered as attributes +======= + ensure FSDP module is not registered as attrbutes +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) in the fx graph see `not source.guard_source().is_fsdp_module()` before calling `register_attr_or_module` diff --git a/test/distributed/test_fake_pg.py b/test/distributed/test_fake_pg.py index 0214680ba5e0b..187966497a20f 100644 --- a/test/distributed/test_fake_pg.py +++ b/test/distributed/test_fake_pg.py @@ -40,14 +40,24 @@ def tearDown(self): pass def test_all_reduce(self): +<<<<<<< HEAD dist.init_process_group(backend="fake", rank=1, world_size=2) +======= + store = FakeStore() + dist.init_process_group(backend="fake", rank=1, world_size=2, store=store) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) output = torch.ones(3, 3) * dist.get_rank() dist.all_reduce(output) self.assertEqual(tuple(output.shape), (3, 3)) def test_allgather(self): +<<<<<<< HEAD dist.init_process_group(backend="fake", rank=1, world_size=2) +======= + store = FakeStore() + dist.init_process_group(backend="fake", rank=1, world_size=2, store=store) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) input_tensor = torch.ones(3, 3) * dist.get_rank() output_tensors = [torch.empty_like(input_tensor) for _ in range(2)] @@ -104,7 +114,12 @@ def allgather_fn(tensor): FileCheck().check("all_gather").check("wait_tensor").run(str(gm.graph)) def test_broadcast(self): +<<<<<<< HEAD dist.init_process_group(backend="fake", rank=0, world_size=2) +======= + store = FakeStore() + dist.init_process_group(backend="fake", rank=0, world_size=2, store=store) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # src == rank output = torch.ones(3, 3) diff --git a/test/distributed/test_functional_api.py b/test/distributed/test_functional_api.py index b5522fe2bef06..765d81e668620 100644 --- a/test/distributed/test_functional_api.py +++ b/test/distributed/test_functional_api.py @@ -13,6 +13,10 @@ from torch._inductor.utils import run_and_get_code from torch.testing import FileCheck from torch.testing._internal.common_device_type import instantiate_device_type_tests +<<<<<<< HEAD +======= +from torch.testing._internal.distributed.fake_pg import FakeStore +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.inductor_utils import HAS_GPU @@ -23,7 +27,11 @@ from torch.testing._internal.common_distributed import ( DistributedTestBase, MultiThreadedTestCase, +<<<<<<< HEAD requires_accelerator_dist_backend, +======= + requires_nccl, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TEST_SKIPS, ) from torch.testing._internal.common_utils import ( @@ -33,7 +41,10 @@ skipIfHpu, TEST_CUDA, TEST_HPU, +<<<<<<< HEAD TEST_XPU, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TestCase, ) @@ -64,9 +75,12 @@ if TEST_HPU: devices.append("hpu") DEVICE = "hpu" +<<<<<<< HEAD elif TEST_XPU: devices.append("xpu") DEVICE = "xpu" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif TEST_CUDA: devices.append("cuda") @@ -272,10 +286,17 @@ def setUp(self): @parametrize("device", devices) def test_broadcast(self, device): +<<<<<<< HEAD if device != "cpu": if torch.accelerator.device_count() < self.world_size: self.skipTest("Not enough accelerator devices") torch.accelerator.set_device_index(dist.get_rank()) +======= + if device == "cuda": + if torch.cuda.device_count() < self.world_size: + self.skipTest("Not enough CUDA devices") + torch.cuda.set_device(dist.get_rank()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if dist.get_rank() == 0: tensor = torch.ones([4], device=device) @@ -288,10 +309,17 @@ def test_broadcast(self, device): @parametrize("device", devices) def test_all_reduce_eager(self, device): +<<<<<<< HEAD if device != "cpu": if torch.accelerator.device_count() < self.world_size: self.skipTest("Not enough accelerator devices") torch.accelerator.set_device_index(dist.get_rank()) +======= + if device == "cuda": + if torch.cuda.device_count() < self.world_size: + self.skipTest("Not enough CUDA devices") + torch.cuda.set_device(dist.get_rank()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tensor = torch.ones([4], device=device) mesh = dt.DeviceMesh(device, torch.arange(4)) @@ -305,10 +333,17 @@ def test_all_reduce_eager(self, device): @parametrize("device", devices) def test_all_reduce_coalesced_eager(self, device): +<<<<<<< HEAD if device != "cpu": if torch.accelerator.device_count() < self.world_size: self.skipTest("Not enough accelerator devices") torch.accelerator.set_device_index(dist.get_rank()) +======= + if device == "cuda": + if torch.cuda.device_count() < self.world_size: + self.skipTest("Not enough CUDA devices") + torch.cuda.set_device(dist.get_rank()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) t0 = torch.ones([4], device=device) t1 = torch.ones([6], device=device) + 2 @@ -320,10 +355,17 @@ def test_all_reduce_coalesced_eager(self, device): @parametrize("device", devices) def test_all_gather_tensor(self, device): +<<<<<<< HEAD if device != "cpu": if torch.accelerator.device_count() < self.world_size: self.skipTest("Not enough accelerator devices") torch.accelerator.set_device_index(dist.get_rank()) +======= + if device == "cuda": + if torch.cuda.device_count() < self.world_size: + self.skipTest("Not enough CUDA devices") + torch.cuda.set_device(dist.get_rank()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # testing 1d/2d mesh mesh_1d = dt.DeviceMesh(device, torch.arange(self.world_size)) @@ -342,10 +384,17 @@ def test_all_gather_tensor(self, device): @parametrize("device", devices) def test_all_gather_into_tensor_coalesced(self, device): +<<<<<<< HEAD if device != "cpu": if torch.accelerator.device_count() < self.world_size: self.skipTest("Not enough accelerator devices") torch.accelerator.set_device_index(dist.get_rank()) +======= + if device == "cuda": + if torch.cuda.device_count() < self.world_size: + self.skipTest("Not enough CUDA devices") + torch.cuda.set_device(dist.get_rank()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tensors = [torch.ones([4], device=device), torch.ones([4], device=device) + 1] mesh = dt.DeviceMesh(device, torch.arange(4)) @@ -359,10 +408,17 @@ def test_all_gather_into_tensor_coalesced(self, device): @parametrize("device", devices) def test_reduce_scatter_tensor(self, device): +<<<<<<< HEAD if device != "cpu": if torch.accelerator.device_count() < self.world_size: self.skipTest("Not enough accelerator devices") torch.accelerator.set_device_index(dist.get_rank()) +======= + if device == "cuda": + if torch.cuda.device_count() < self.world_size: + self.skipTest("Not enough CUDA devices") + torch.cuda.set_device(dist.get_rank()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # testing 1d/2d mesh mesh_1d = dt.DeviceMesh(device, torch.arange(self.world_size)) @@ -383,10 +439,17 @@ def test_reduce_scatter_tensor(self, device): @parametrize("device", devices) def test_reduce_scatter_into_tensor_coalesced(self, device): +<<<<<<< HEAD if device != "cpu": if torch.accelerator.device_count() < self.world_size: self.skipTest("Not enough accelerator devices") torch.accelerator.set_device_index(dist.get_rank()) +======= + if device == "cuda": + if torch.cuda.device_count() < self.world_size: + self.skipTest("Not enough CUDA devices") + torch.cuda.set_device(dist.get_rank()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tensors = [ torch.ones([4], dtype=torch.int64, device=device), torch.ones([4], dtype=torch.int64, device=device) + 1, @@ -430,10 +493,18 @@ def setUp(self): # so create a fake_pg. self.rank = 0 self.world_size = 2 +<<<<<<< HEAD +======= + store = FakeStore() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dist.init_process_group( backend="fake", world_size=self.world_size, rank=self.rank, +<<<<<<< HEAD +======= + store=store, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def tearDown(self): @@ -475,17 +546,29 @@ def allred_mesh_dim(input): # And then set the BACKEND variable appropriately. if TEST_HPU: BACKEND = dist.Backend.HCCL +<<<<<<< HEAD elif TEST_XPU: BACKEND = dist.Backend.XCCL +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # allows you to check for multiple accelerator irrespective of device type # to add new device types to this check simply follow the same format # and append an elif with the conditional and appropriate device count function for your new device def exit_if_lt_x_accelerators(x): +<<<<<<< HEAD if torch.accelerator.is_available(): if torch.accelerator.device_count() < x: sys.exit(TEST_SKIPS[f"multi-accelerator-{x}"].exit_code) +======= + if TEST_CUDA: + if torch.cuda.device_count() < x: + sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code) + elif TEST_HPU: + if torch.hpu.device_count() < x: + sys.exit(TEST_SKIPS[f"multi-hpu-{x}"].exit_code) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def with_comms(func=None): @@ -494,9 +577,13 @@ def with_comms(func=None): @wraps(func) def wrapper(self, *args, **kwargs): +<<<<<<< HEAD if ( BACKEND == dist.Backend.NCCL or BACKEND == dist.Backend.XCCL ) and torch.accelerator.device_count() < self.world_size: +======= + if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) kwargs["device"] = DEVICE @@ -574,7 +661,11 @@ def test_all_to_all_single_split_sizes_none(self, device): self.assertEqual(y, expected) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") +<<<<<<< HEAD @requires_accelerator_dist_backend(["nccl", "xccl"]) +======= + @requires_nccl() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @with_comms() def test_tracing(self, device): def allreduce(t, pg): @@ -595,12 +686,20 @@ def allreduce(t, pg): backend="fake", rank=0, world_size=8, +<<<<<<< HEAD +======= + store=FakeStore(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) allreduce(torch.randn(8, device=device), pg=dist.group.WORLD) dist.destroy_process_group() @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") +<<<<<<< HEAD @requires_accelerator_dist_backend(["nccl", "xccl"]) +======= + @requires_nccl() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @with_comms() def test_tracing_with_dce_code(self, device): if self.world_size > 2: @@ -819,6 +918,7 @@ def test_all_to_all_single(self, device) -> None: # Update the supported devices in DEVICE instantiate_device_type_tests( +<<<<<<< HEAD TestCollectivesWithDistributedBackend, globals(), only_for=DEVICE, allow_xpu=True ) instantiate_device_type_tests( @@ -832,6 +932,15 @@ def test_all_to_all_single(self, device) -> None: globals(), only_for=DEVICE, allow_xpu=True, +======= + TestCollectivesWithDistributedBackend, globals(), only_for=DEVICE +) +instantiate_device_type_tests( + TestDistributedBackendCollectivesWithWorldSize4, globals(), only_for=DEVICE +) +instantiate_device_type_tests( + TestFunctionalAutogradWithDistributedBackend, globals(), only_for=DEVICE +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if __name__ == "__main__": diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py index ca729fd50b0af..dfcc320109b44 100644 --- a/test/distributed/test_inductor_collectives.py +++ b/test/distributed/test_inductor_collectives.py @@ -2,7 +2,11 @@ import datetime import functools import unittest +<<<<<<< HEAD from collections import Counter +======= +from collections import defaultdict +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from typing import Optional from unittest.mock import patch @@ -19,6 +23,7 @@ from torch._inductor.comms import ( _reorder_communication_preserving_peak_memory_internal, ReorderInfo, +<<<<<<< HEAD sink_waits_iterative, ) from torch._inductor.compile_fx import compile_fx as inductor_compile_fx @@ -32,11 +37,22 @@ from torch.distributed.distributed_c10d import GroupMember from torch.fx.experimental.proxy_tensor import make_fx from torch.testing._internal.common_cuda import SM80OrLater +======= +) +from torch._inductor.compile_fx import compile_fx as inductor_compile_fx +from torch._inductor.scheduler import BaseSchedulerNode +from torch._inductor.utils import run_and_get_triton_code +from torch.distributed.distributed_c10d import GroupMember +from torch.fx.experimental.proxy_tensor import make_fx +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_distributed import ( _dynamo_dist_per_rank_init, DynamoDistributedMultiProcTestCase, DynamoDistributedSingleProcTestCase, +<<<<<<< HEAD MultiProcessTestCase, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) requires_nccl, skip_if_lt_x_gpu, ) @@ -418,6 +434,7 @@ def forward(self, x, world_size, tag, ranks, group_size): @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) +<<<<<<< HEAD def test_allgather_scalar_tensor_input(self): def func(tensor, world_size): tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] @@ -433,6 +450,8 @@ def func(tensor, world_size): @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_allgather_contiguous_input(self): class Model(torch.nn.Module): def __init__(self, *args, **kwargs) -> None: @@ -671,7 +690,11 @@ def alltoall_autograd( class TrackingMode(TorchDispatchMode): def __init__(self): super().__init__() +<<<<<<< HEAD self.ops_counter = Counter() +======= + self.ops_counter = defaultdict(int) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __torch_dispatch__(self, func, types, args=(), kwargs=None): if kwargs is None: @@ -845,7 +868,11 @@ def func(inp, *, tag, ranks, group_size): compiled = torch.compile(func) out = compiled(inputs, **self.get_world_trs()) code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs()) +<<<<<<< HEAD # NOTE: Make sure we are not unnecessarily copying the outputs of +======= + # NOTE: Make sure we are not unneccessarily copying the outputs of +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # wait_tensors before they are returned from the graph. ( FileCheck() @@ -912,7 +939,11 @@ def func(inp, *, tag, ranks, group_size): compiled = torch.compile(func) code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs()) +<<<<<<< HEAD # NOTE: Make sure we are not unnecessarily copying the outputs of +======= + # NOTE: Make sure we are not unneccessarily copying the outputs of +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # wait_tensors before they are returned from the graph. ( FileCheck() @@ -1377,7 +1408,11 @@ def func(inp, *, tag, ranks, group_size): compiled = torch.compile(func) code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs()) +<<<<<<< HEAD # NOTE: Make sure we are not unnecessarily copying the outputs of +======= + # NOTE: Make sure we are not unneccessarily copying the outputs of +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # wait_tensors before they are returned from the graph. ( FileCheck() @@ -1424,7 +1459,11 @@ def func(inp, *, tag, ranks, group_size): compiled = torch.compile(func) code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs()) # NOTE: The first return value should be the output of the first wait_tensor. +<<<<<<< HEAD # We want to make sure no unnecessary copy is made. +======= + # We want to make sure no unneccessary copy is made. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ( FileCheck() .check("buf0 = empty_strided") @@ -1495,7 +1534,11 @@ def _reorder_communication_preserving_peak_memory( compiled = torch.compile(func) code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs()) # NOTE: The first return value should be the output of the first wait_tensor. +<<<<<<< HEAD # We want to make sure no unnecessary copy is made. +======= + # We want to make sure no unneccessary copy is made. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ( FileCheck() .check("buf0 = empty_strided") @@ -1523,6 +1566,7 @@ def _reorder_communication_preserving_peak_memory( self.assertEqual(len(node_stats), 1) for stats in node_stats.values(): self.assertEqual(stats.initial_exposed, 0) +<<<<<<< HEAD self.assertEqual(stats.limiting_factor, "None") self.assertEqual(stats.moves, 0) @@ -2054,6 +2098,11 @@ def test_sync_decision_cross_ranks(self): saved_values = _sync_decision_cross_ranks(test_graph, saved_values) self.assertEqual(saved_values, [wt1]) +======= + self.assertEqual(stats.limiting_factor, "data dependency") + self.assertEqual(stats.moves, 0) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/distributed/test_multi_threaded_pg.py b/test/distributed/test_multi_threaded_pg.py index 7ca6d25ad1c97..754f3c5e9bb38 100644 --- a/test/distributed/test_multi_threaded_pg.py +++ b/test/distributed/test_multi_threaded_pg.py @@ -25,8 +25,11 @@ from torch.testing._internal.common_utils import IS_SANDCASTLE, run_tests, TestCase +<<<<<<< HEAD device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DEFAULT_WORLD_SIZE = 4 @@ -332,7 +335,11 @@ def backward(ctx, grad_output): return grad_output * result x = torch.tensor( +<<<<<<< HEAD [dist.get_rank()], dtype=torch.float, device=device_type, requires_grad=True +======= + [dist.get_rank()], dtype=torch.float, device="cuda", requires_grad=True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) x = MyFunc.apply(x) x.sum().backward() diff --git a/test/distributed/test_nccl.py b/test/distributed/test_nccl.py index 49d72b8b4edd8..aa139e1524b6d 100644 --- a/test/distributed/test_nccl.py +++ b/test/distributed/test_nccl.py @@ -13,15 +13,22 @@ dtypes, instantiate_device_type_tests, ) +<<<<<<< HEAD from torch.testing._internal.common_distributed import ( MultiProcContinuousTest, skip_if_lt_x_gpu, ) +======= +from torch.testing._internal.common_distributed import MultiProcContinousTest +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_utils import ( IS_WINDOWS, load_tests, NoTest, +<<<<<<< HEAD requires_cuda_p2p_access, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) run_tests, skip_but_pass_in_sandcastle_if, TEST_WITH_ROCM, @@ -245,6 +252,7 @@ def test_reduce_scatter(self, device, dtype): self.assertEqual(outputs[i], expected[i]) +<<<<<<< HEAD @requires_cuda_p2p_access() class NCCLSymmetricMemoryTest(MultiProcContinuousTest): @property @@ -260,6 +268,26 @@ def test_nccl_symmem_alloc(self): # Need this all_reduce to initialize NCCL communicator. Otherwise, the # test will hang. TODO: investigate how NCCLSymmetricMemory can # initialize NCCL communicator. +======= +device_type = "cuda" +device_module = torch.get_device_module(device_type) + + +class NCCLSymmetricMemoryTest(MultiProcContinousTest): + def _init_device(self) -> None: + # TODO: relieve this (seems to hang if without) + device_module.set_device(self.device) + + @property + def device(self) -> torch.device: + return torch.device(device_type, self.rank) + + # To run this test, one needs to TORCH_SYMMMEM=NCCL when running the test. + @skip_but_pass_in_sandcastle_if(TEST_WITH_ROCM, "Skip NCCL tests for ROCm") + @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows") + def test_nccl_symmem_alloc(self): + self._init_device() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) c10d.all_reduce(torch.ones(1, device=self.device)) group_name = c10d.group.WORLD.group_name symm_mem.enable_symm_mem_for_group(group_name) diff --git a/test/distributed/test_nvshmem.py b/test/distributed/test_nvshmem.py index a51f7e35eb33e..6a35992a816f3 100644 --- a/test/distributed/test_nvshmem.py +++ b/test/distributed/test_nvshmem.py @@ -1,12 +1,17 @@ # Owner(s): ["oncall: distributed"] # To run: +<<<<<<< HEAD # python test/distributed/test_nvshmem.py +======= +# TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem.py +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch.distributed as dist import torch.distributed._symmetric_memory as symm_mem +<<<<<<< HEAD from torch.distributed.device_mesh import init_device_mesh from torch.testing._internal.common_distributed import ( MultiProcContinuousTest, @@ -16,10 +21,22 @@ instantiate_parametrized_tests, parametrize, requires_cuda_p2p_access, +======= +import torch.distributed._symmetric_memory._nvshmem_triton as nvshmem +from torch._inductor.runtime.triton_compat import tl, triton +from torch.testing._internal.common_distributed import MultiProcContinousTest +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) run_tests, skip_but_pass_in_sandcastle_if, skipIfRocm, ) +<<<<<<< HEAD +======= +from torch.testing._internal.inductor_utils import requires_triton +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Decorator @@ -35,6 +52,7 @@ def requires_nvshmem(): device_module = torch.get_device_module(device_type) +<<<<<<< HEAD @requires_nvshmem() @requires_cuda_p2p_access() class NVSHMEMSymmetricMemoryTest(MultiProcContinuousTest): @@ -43,6 +61,16 @@ def _init_device(self) -> None: device_module.set_device(self.device) # Set NVSHMEM as SymmMem backend symm_mem.set_backend("NVSHMEM") +======= +@instantiate_parametrized_tests +@requires_nvshmem() +class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest): + def _init_device(self) -> None: + # TODO: relieve this (seems to hang if without) + device_module.set_device(self.device) + # NOTE: required for nvshmem allocation + torch.empty(1, device=self.device) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @property def device(self) -> torch.device: @@ -68,6 +96,7 @@ def foo(): symm_mem.rendezvous(out, group=group_name) @skipIfRocm +<<<<<<< HEAD def test_alloc_without_device_context(self) -> None: # Set NVSHMEM as SymmMem backend symm_mem.set_backend("NVSHMEM") @@ -185,6 +214,8 @@ def test_get_remote_tensor(self) -> None: self.assertEqual(y, expected) @skipIfRocm +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_nvshmem_put(self) -> None: self._init_device() group_name = dist.group.WORLD.group_name @@ -210,6 +241,7 @@ def test_nvshmem_put(self) -> None: dist.barrier() @skipIfRocm +<<<<<<< HEAD def test_nvshmem_get(self) -> None: self._init_device() group_name = dist.group.WORLD.group_name @@ -248,6 +280,8 @@ def device(self) -> torch.device: return torch.device(device_type, self.rank) @skipIfRocm +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_nvshmem_all_to_all(self) -> None: self._init_device() @@ -295,6 +329,7 @@ def test_all_to_all_vdev(self) -> None: overflow_factor = self.world_size # worst case: one rank receives all data max_out_numel = max_inp_numel * overflow_factor +<<<<<<< HEAD inp = symm_mem.empty(max_inp_numel, dtype=dtype, device=self.device).copy_( torch.randn(max_inp_numel, dtype=dtype, device=self.device) ) @@ -320,12 +355,36 @@ def test_all_to_all_vdev(self) -> None: # Check output splits (row 1) torch.testing.assert_close(out_splits_offsets[0], out_splits) +======= + inp = symm_mem.empty(max_inp_numel, dtype=dtype, device=self.device).fill_( + self.rank + ) + out = symm_mem.empty(max_out_numel, dtype=dtype, device=self.device).fill_(-1) + in_out_splits = symm_mem.empty( + (3, self.world_size), dtype=torch.int64, device=self.device + ) + # Row 0 is input splits + in_out_splits[0].copy_(inp_splits) + + torch.ops.symm_mem.all_to_all_vdev(inp, out, in_out_splits, group_name) + + # Check input splits (row 0) -- should not change + torch.testing.assert_close(in_out_splits[0], inp_splits) + + # Check output splits (row 1) + torch.testing.assert_close(in_out_splits[1], out_splits) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Check output offsets (row 2) out_offsets = torch.cumsum(out_splits, dim=0) # inclusive scan # output offsets from `all_to_all_vdev` is exclusive scan +<<<<<<< HEAD self.assertEqual(out_splits_offsets[1][0], 0) torch.testing.assert_close(out_splits_offsets[1][1:], out_offsets[:-1]) +======= + self.assertEqual(in_out_splits[2][0], 0) + torch.testing.assert_close(in_out_splits[2][1:], out_offsets[:-1]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Check data expected = torch.empty(out_numel, dtype=dtype, device=self.device) @@ -368,6 +427,7 @@ def test_all_to_all_vdev_2d(self, align: int) -> None: overflow_factor = self.world_size # worst case: one rank receives all data max_out_numel = max_inp_numel * overflow_factor +<<<<<<< HEAD inp = symm_mem.empty(max_inp_numel, dtype=dtype, device=self.device).copy_( torch.randn(max_inp_numel, dtype=dtype, device=self.device) ) @@ -392,6 +452,28 @@ def test_all_to_all_vdev_2d(self, align: int) -> None: # Check input splits (row 0) -- should not change torch.testing.assert_close(in_splits, inp_splits) +======= + inp = symm_mem.empty(max_inp_numel, dtype=dtype, device=self.device).fill_( + self.rank + ) + out = symm_mem.empty(max_out_numel, dtype=dtype, device=self.device).fill_(-1) + # 3 rows: input splits, output splits, output offsets + # Initiallizing all values to -1 to check if they are updated + in_out_splits = symm_mem.empty( + (3, nsplits), dtype=torch.int64, device=self.device + ).fill_(-1) + # Row 0 is input splits + in_out_splits[0].copy_(inp_splits) + + torch.ops.symm_mem.all_to_all_vdev_2d( + inp, out, in_out_splits, group_name, major_align=align + ) + received_out_splits = in_out_splits[1] + received_out_offsets = in_out_splits[2] + + # Check input splits (row 0) -- should not change + torch.testing.assert_close(in_out_splits[0], inp_splits) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Check output splits (row 1) torch.testing.assert_close(received_out_splits, out_splits_t.reshape(-1)) @@ -445,6 +527,7 @@ def test_all_to_all_vdev_2d(self, align: int) -> None: torch.testing.assert_close(received_chunk, chunk) @skipIfRocm +<<<<<<< HEAD def test_all_to_all_vdev_2d_offset(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() @@ -702,6 +785,709 @@ def test_dispatch_combine_subgroup(self) -> None: ) subgroup = dm.get_group("ep") dispatch_then_combine(self.device, align=8, group=subgroup) +======= + @requires_triton() + def test_triton_put(self) -> None: + # A Triton kernel that calls nvshmem device side API + @triton.jit + def put_kernel( + dst_ptr, + src_ptr, + numel: tl.constexpr, + peer: tl.constexpr, + ): + nvshmem.putmem_block(dst_ptr, src_ptr, numel, peer) + + torch.manual_seed(42 + self.rank) + self._init_device() + + # Enable NVSHMEM for Triton + nvshmem_lib = nvshmem.enable_triton() + + group_name = dist.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + rank = self.rank + + msg_size_bytes = 8 + dtype = torch.int8 + numel = msg_size_bytes // dtype.itemsize + + val = 5 + inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val) + out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) + inp_hdl = symm_mem.rendezvous(inp, group=group_name) + out_hdl = symm_mem.rendezvous(out, group=group_name) + + peer = 1 - rank + if rank == 0: + dst_ptr = out_hdl.buffer_ptrs[rank] + src_ptr = inp_hdl.buffer_ptrs[rank] + put_kernel[(1, 1, 1)]( + dst_ptr, + src_ptr, + numel=numel, + peer=peer, + extern_libs=nvshmem_lib, + ) + + dist.barrier() + if rank == 1: + torch.testing.assert_close( + out, val * torch.ones(numel, dtype=dtype, device=self.device) + ) + + @skipIfRocm + @requires_triton() + def test_triton_get(self) -> None: + # A Triton kernel that calls nvshmem device side API for GET + @triton.jit + def get_kernel( + dst_ptr, + src_ptr, + numel: tl.constexpr, + peer: tl.constexpr, + ): + nvshmem.getmem_block(dst_ptr, src_ptr, numel, peer) + + torch.manual_seed(42 + self.rank) + self._init_device() + + nvshmem_lib = nvshmem.enable_triton() + group_name = dist.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + rank = self.rank + msg_size_bytes = 8 + dtype = torch.int8 + numel = msg_size_bytes // dtype.itemsize + val = 7 + inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_( + val if rank == 0 else -1 + ) + out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) + inp_hdl = symm_mem.rendezvous(inp, group=group_name) + out_hdl = symm_mem.rendezvous(out, group=group_name) + dist.barrier() + peer = 1 - rank + if rank == 1: + # Rank 1 gets data from rank 0 + dst_ptr = out_hdl.buffer_ptrs[rank] + src_ptr = inp_hdl.buffer_ptrs[rank] + get_kernel[(1, 1, 1)]( + dst_ptr, + src_ptr, + numel=numel, + peer=peer, + extern_libs=nvshmem_lib, + ) + if rank == 1: + torch.testing.assert_close( + out, val * torch.ones(numel, dtype=dtype, device=self.device) + ) + + @skipIfRocm + @requires_triton() + def test_triton_get_ring(self) -> None: + # A Triton kernel that calls nvshmem device side API for GET + # with ring topology + @triton.jit + def get_kernel( + dst_ptr, + src_ptr, + numel: tl.constexpr, + peer: tl.constexpr, + ): + nvshmem.getmem_block(dst_ptr, src_ptr, numel, peer) + + torch.manual_seed(42 + self.rank) + self._init_device() + + nvshmem_lib = nvshmem.enable_triton() + group_name = dist.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + rank = self.rank + world_size = dist.get_world_size() + msg_size_bytes = 8 + dtype = torch.int8 + numel = msg_size_bytes // dtype.itemsize + + # Each rank fills its input buffer with its own rank value + inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(rank) + out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) + inp_hdl = symm_mem.rendezvous(inp, group=group_name) + out_hdl = symm_mem.rendezvous(out, group=group_name) + dist.barrier() + + # Ring topology: each rank gets data from the rank to its left + # rank 0 gets from rank (world_size-1), rank 1 gets from rank 0, etc. + peer = (rank - 1) % world_size + + # All ranks execute the get operation + dst_ptr = out_hdl.buffer_ptrs[rank] + src_ptr = inp_hdl.buffer_ptrs[rank] + get_kernel[(1, 1, 1)]( + dst_ptr, + src_ptr, + numel=numel, + peer=peer, + extern_libs=nvshmem_lib, + ) + + expected_value = peer + torch.testing.assert_close( + out, expected_value * torch.ones(numel, dtype=dtype, device=self.device) + ) + + @skipIfRocm + @requires_triton() + def test_triton_put_signal_set(self) -> None: + # A Triton kernel that calls nvshmem device side API for PUT with SIGNAL + @triton.jit + def put_signal_kernel( + dst_ptr, + src_ptr, + numel: tl.constexpr, + sig_ptr, + signal_val: tl.constexpr, + sig_op: tl.constexpr, + peer: tl.constexpr, + ): + nvshmem.putmem_signal_block( + dst_ptr, src_ptr, numel, sig_ptr, signal_val, sig_op, peer + ) + + torch.manual_seed(42 + self.rank) + self._init_device() + + nvshmem_lib = nvshmem.enable_triton() + + group_name = dist.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + rank = self.rank + + msg_size_bytes = 8 + dtype = torch.int8 + numel = msg_size_bytes // dtype.itemsize + + # Data buffers + val = 11 + inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val) + out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) + inp_hdl = symm_mem.rendezvous(inp, group=group_name) + out_hdl = symm_mem.rendezvous(out, group=group_name) + + # Use the signal pad attached to the output symmetric memory handle + # as the flag buffer for signaling completion. + flag = out_hdl.get_signal_pad(rank, (1,), dtype=torch.int64).fill_(0) + + peer = 1 - rank + NVSHMEM_SIGNAL_SET = 0 # value defined by NVSHMEM for atomic set + SIGNAL_VAL = 1 # Signal completion value + NVSHMEM_CMP_EQ = 0 # compare equal for signal wait until + + # Kernel for waiting on the signal locally (Rank 1). + @triton.jit + def signal_wait_until_kernel( + sig_ptr, cmp_op: tl.constexpr, cmp_val: tl.constexpr + ): + nvshmem.signal_wait_until(sig_ptr, cmp_op, cmp_val) + + if rank == 0: + # Rank 0 puts into Rank 1 + dst_ptr = out_hdl.buffer_ptrs[peer] + src_ptr = inp_hdl.buffer_ptrs[rank] + sig_ptr = out_hdl.signal_pad_ptrs[peer] + put_signal_kernel[(1, 1, 1)]( + dst_ptr, + src_ptr, + numel=numel, + sig_ptr=sig_ptr, + signal_val=SIGNAL_VAL, + sig_op=NVSHMEM_SIGNAL_SET, + peer=peer, + extern_libs=nvshmem_lib, + ) + + if rank == 1: + # Wait until signal flag is set by Rank 0 + sig_ptr_local = out_hdl.signal_pad_ptrs[rank] + signal_wait_until_kernel[(1,)]( + sig_ptr_local, + cmp_op=NVSHMEM_CMP_EQ, + cmp_val=SIGNAL_VAL, + extern_libs=nvshmem_lib, + ) + # After wait completes, verify data and flag contents + torch.testing.assert_close( + out, val * torch.ones(numel, dtype=dtype, device=self.device) + ) + torch.testing.assert_close( + flag, torch.tensor([SIGNAL_VAL], dtype=torch.int64, device=self.device) + ) + + @skipIfRocm + @requires_triton() + def test_triton_put_signal_add(self) -> None: + # A Triton kernel that calls nvshmem device side API for PUT with SIGNAL + @triton.jit + def put_signal_kernel( + dst_ptr, + src_ptr, + numel: tl.constexpr, + sig_ptr, + signal_val: tl.constexpr, + sig_op: tl.constexpr, + peer: tl.constexpr, + ): + nvshmem.putmem_signal_block( + dst_ptr, src_ptr, numel, sig_ptr, signal_val, sig_op, peer + ) + + torch.manual_seed(42 + self.rank) + self._init_device() + + nvshmem_lib = nvshmem.enable_triton() + + group_name = dist.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + rank = self.rank + + msg_size_bytes = 8 + dtype = torch.int8 + numel = msg_size_bytes // dtype.itemsize + + # Data buffers + val = 11 + inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val) + out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) + inp_hdl = symm_mem.rendezvous(inp, group=group_name) + out_hdl = symm_mem.rendezvous(out, group=group_name) + + # Use the signal pad attached to the output symmetric memory handle + # as the flag buffer for signaling completion. + flag = out_hdl.get_signal_pad(rank, (1,), dtype=torch.int64).fill_(0) + + peer = 1 - rank + NVSHMEM_SIGNAL_ADD = 5 # atomic add operation + SIGNAL_VAL = 16 # val + NVSHMEM_SIGNAL_ADD + NVSHMEM_CMP_EQ = 0 + + @triton.jit + def signal_wait_until_kernel( + sig_ptr, cmp_op: tl.constexpr, cmp_val: tl.constexpr + ): + nvshmem.signal_wait_until(sig_ptr, cmp_op, cmp_val) + + if rank == 0: + # Rank 0 puts into Rank 1 + dst_ptr = out_hdl.buffer_ptrs[peer] + src_ptr = inp_hdl.buffer_ptrs[rank] + sig_ptr = out_hdl.signal_pad_ptrs[peer] + put_signal_kernel[(1, 1, 1)]( + dst_ptr, + src_ptr, + numel=numel, + sig_ptr=sig_ptr, + signal_val=SIGNAL_VAL, + sig_op=NVSHMEM_SIGNAL_ADD, + peer=peer, + extern_libs=nvshmem_lib, + ) + + if rank == 1: + sig_ptr_local = out_hdl.signal_pad_ptrs[rank] + signal_wait_until_kernel[(1, 1, 1)]( + sig_ptr_local, + cmp_op=NVSHMEM_CMP_EQ, + cmp_val=SIGNAL_VAL, + extern_libs=nvshmem_lib, + ) + torch.testing.assert_close( + out, val * torch.ones(numel, dtype=dtype, device=self.device) + ) + torch.testing.assert_close( + flag, torch.tensor([SIGNAL_VAL], dtype=torch.int64, device=self.device) + ) + + @skipIfRocm + @requires_triton() + def test_triton_wait_until(self) -> None: + # A Triton kernel that calls nvshmem device side API for PUT + @triton.jit + def put_kernel( + dst_ptr, + src_ptr, + numel: tl.constexpr, + peer: tl.constexpr, + ): + nvshmem.putmem_block(dst_ptr, src_ptr, numel, peer) + + # A Triton kernel that calls nvshmem device side API for WAIT_UNTIL + @triton.jit + def wait_until_kernel( + ivar_ptr, + cmp_op: tl.constexpr, + cmp_val: tl.constexpr, + ): + nvshmem.wait_until(ivar_ptr, cmp_op, cmp_val) + + torch.manual_seed(42 + self.rank) + self._init_device() + nvshmem_lib = nvshmem.enable_triton() + group_name = dist.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + rank = self.rank + + # Data buffers + msg_size_bytes = 8 + dtype = torch.int8 + numel = msg_size_bytes // dtype.itemsize + val = 13 + flag_val = 21 + inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val) + out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) + inp_hdl = symm_mem.rendezvous(inp, group=group_name) + out_hdl = symm_mem.rendezvous(out, group=group_name) + dist.barrier() + + peer = 1 - rank + NVSHMEM_CMP_EQ = 0 # from nvshmem.h + + if rank == 0: + # Rank 0 waits for the flag to be set by Rank 1, then checks the data + ivar_ptr = out_hdl.signal_pad_ptrs[rank] + wait_until_kernel[(1, 1, 1)]( + ivar_ptr, + cmp_op=NVSHMEM_CMP_EQ, + cmp_val=flag_val, + extern_libs=nvshmem_lib, + ) + torch.testing.assert_close( + out, val * torch.ones(numel, dtype=dtype, device=self.device) + ) + + if rank == 1: + # Rank 1 puts data into Rank 0's output buffer + dst_ptr = out_hdl.buffer_ptrs[rank] + src_ptr = inp_hdl.buffer_ptrs[rank] + put_kernel[(1, 1, 1)]( + dst_ptr, + src_ptr, + numel=numel, + peer=peer, + extern_libs=nvshmem_lib, + ) + # Rank 1 sets the flag on Rank 0 + # We use a temporary tensor for the value to put. + flag_update_val = torch.tensor( + [flag_val], dtype=torch.int64, device=self.device + ) + dst_ptr = out_hdl.signal_pad_ptrs[rank] + src_ptr = flag_update_val.data_ptr() + put_kernel[(1, 1, 1)]( + dst_ptr, + src_ptr, + numel=1, + peer=peer, + extern_libs=nvshmem_lib, + ) + + @skipIfRocm + @requires_triton() + def test_triton_signal_wait_until(self) -> None: + # A Triton kernel that waits on a signal variable until it meets the compare condition. + @triton.jit + def signal_wait_until_kernel( + sig_ptr, + cmp_op: tl.constexpr, + cmp_val: tl.constexpr, + ): + nvshmem.signal_wait_until(sig_ptr, cmp_op, cmp_val) + + # A Triton kernel for the producer that puts data and then signals completion. + @triton.jit + def put_and_signal_kernel( + dst_ptr, + src_ptr, + numel: tl.constexpr, + sig_ptr, + signal_val: tl.constexpr, + sig_op: tl.constexpr, + peer: tl.constexpr, + ): + nvshmem.putmem_signal_block( + dst_ptr, src_ptr, numel, sig_ptr, signal_val, sig_op, peer + ) + + self._init_device() + # Enable NVSHMEM for Triton + nvshmem_lib = nvshmem.enable_triton() + group_name = dist.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + rank = self.rank + peer = 1 - rank + + # NVSHMEM constants from documentation + NVSHMEM_CMP_EQ = 0 # equal comparison + NVSHMEM_SIGNAL_SET = 0 # atomic set operation + + # Message configuration + msg_size_bytes = 8 + dtype = torch.int8 + numel = msg_size_bytes // dtype.itemsize + val_to_put = 123 # arbitrary test value + COMPLETION_FLAG_VAL = 1 + + # Producer (rank 0) prepares the data to send + inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val_to_put) + inp_hdl = symm_mem.rendezvous(inp, group=group_name) + # Consumer (rank 1) prepares the destination buffer + out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) + out_hdl = symm_mem.rendezvous(out, group=group_name) + # Use the signal pad for synchronization, as in previous tests + flag_dtype = torch.int64 + flag = out_hdl.get_signal_pad(rank, (1,), dtype=flag_dtype).fill_(0) + # Ensure setup is complete on all ranks before proceeding + dist.barrier() + + if rank == 0: + # Producer (rank 0): Puts data into rank 1's `out` buffer and then sets the flag + dst_ptr = out_hdl.buffer_ptrs[peer] + src_ptr = inp_hdl.buffer_ptrs[rank] + sig_ptr = out_hdl.signal_pad_ptrs[peer] + put_and_signal_kernel[(1, 1, 1)]( + dst_ptr, + src_ptr, + numel, + sig_ptr, + signal_val=COMPLETION_FLAG_VAL, + sig_op=NVSHMEM_SIGNAL_SET, + peer=peer, + extern_libs=nvshmem_lib, + ) + elif rank == 1: + # Consumer (rank 1): Waits on the signal variable using `signal_wait_until`. + sig_ptr = out_hdl.signal_pad_ptrs[rank] + signal_wait_until_kernel[(1, 1, 1)]( + sig_ptr, + cmp_op=NVSHMEM_CMP_EQ, + cmp_val=COMPLETION_FLAG_VAL, + extern_libs=nvshmem_lib, + ) + # After the wait returns, verify data and flag + torch.testing.assert_close( + out, val_to_put * torch.ones(numel, dtype=dtype, device=self.device) + ) + torch.testing.assert_close( + flag, + torch.tensor( + [COMPLETION_FLAG_VAL], dtype=flag_dtype, device=self.device + ), + ) + # Final barrier to ensure the test does not exit before assertions complete + dist.barrier() + + @skipIfRocm + @requires_triton() + def test_triton_fence(self) -> None: + """ + Rank 0 performs two put operations into Rank 1's buffers with a fence + between them, followed by another fence and a flag update. Rank 1 waits + for the flag, then verifies that both destination buffers contain the + expected values. The flag is transferred after the final fence, so + its arrival implies that both preceding puts have been delivered in + order. + """ + + # Triton kernel that issues two ordered puts separated by fences and + # finally writes the completion flag. + @triton.jit + def put_with_fence_kernel( + dst_ptr1, + dst_ptr2, + src_ptr1, + src_ptr2, + flag_ptr, + flag_src_ptr, + numel: tl.constexpr, + peer: tl.constexpr, + ): + # First put + nvshmem.putmem_block(dst_ptr1, src_ptr1, numel, peer) + # Ensure the first put is ordered before the next. + nvshmem.fence() + # Second put + nvshmem.putmem_block(dst_ptr2, src_ptr2, numel, peer) + # Order the second put before flag update. + nvshmem.fence() + # Write the flag (single int64) to signal completion. + nvshmem.putmem_block(flag_ptr, flag_src_ptr, 1, peer) + + # Kernel for Rank 1 to wait until the flag becomes the expected value. + @triton.jit + def wait_until_kernel( + ivar_ptr, + cmp_op: tl.constexpr, + cmp_val: tl.constexpr, + ): + nvshmem.wait_until(ivar_ptr, cmp_op, cmp_val) + + torch.manual_seed(42 + self.rank) + self._init_device() + nvshmem_lib = nvshmem.enable_triton() + group_name = dist.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + rank = self.rank + peer = 1 - rank + # Message configuration + msg_size_bytes = 8 + dtype = torch.int8 + numel = msg_size_bytes // dtype.itemsize + val1 = 10 + val2 = 20 + flag_val = 1 + # Symmetric buffers + inp1 = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val1) + inp2 = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val2) + out1 = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) + out2 = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) + inp1_hdl = symm_mem.rendezvous(inp1, group=group_name) + inp2_hdl = symm_mem.rendezvous(inp2, group=group_name) + out1_hdl = symm_mem.rendezvous(out1, group=group_name) + out2_hdl = symm_mem.rendezvous(out2, group=group_name) + + # Flag buffer resides in the signal pad of out2. + flag = out2_hdl.get_signal_pad(rank, (1,), dtype=torch.int64).fill_(0) + flag_update_val = torch.tensor( + [flag_val], dtype=torch.int64, device=self.device + ) + NVSHMEM_CMP_EQ = 0 # compare equal + dist.barrier() + + if rank == 0: + dst_ptr1 = out1_hdl.buffer_ptrs[rank] + dst_ptr2 = out2_hdl.buffer_ptrs[rank] + src_ptr1 = inp1_hdl.buffer_ptrs[rank] + src_ptr2 = inp2_hdl.buffer_ptrs[rank] + flag_ptr = out2_hdl.signal_pad_ptrs[rank] + flag_src_ptr = flag_update_val.data_ptr() + + put_with_fence_kernel[(1, 1, 1)]( + dst_ptr1, + dst_ptr2, + src_ptr1, + src_ptr2, + flag_ptr, + flag_src_ptr, + numel, + peer=peer, + extern_libs=nvshmem_lib, + ) + elif rank == 1: + # Wait until flag is set by Rank 0. + ivar_ptr = out2_hdl.signal_pad_ptrs[rank] + wait_until_kernel[(1, 1, 1)]( + ivar_ptr, + cmp_op=NVSHMEM_CMP_EQ, + cmp_val=flag_val, + extern_libs=nvshmem_lib, + ) + + # Verify ordered data arrival. + torch.testing.assert_close( + out1, val1 * torch.ones(numel, dtype=dtype, device=self.device) + ) + torch.testing.assert_close( + out2, val2 * torch.ones(numel, dtype=dtype, device=self.device) + ) + torch.testing.assert_close( + flag, torch.tensor([flag_val], dtype=torch.int64, device=self.device) + ) + dist.barrier() + + @skipIfRocm + @requires_triton() + def test_triton_quiet(self) -> None: + # A Triton kernel that uses nvshmem_quiet to ensure completion + @triton.jit + def put_with_quiet_kernel( + dst_ptr, + src_ptr, + flag_dst_ptr, + flag_src_ptr, + numel: tl.constexpr, + peer: tl.constexpr, + ): + # Put data + nvshmem.putmem_block(dst_ptr, src_ptr, numel, peer) + # Call quiet to ensure put is complete + nvshmem.quiet() + # Only after quiet, set the completion flag + # This ensures the data put is complete before flag is set + nvshmem.putmem_block(flag_dst_ptr, flag_src_ptr, 1, peer) + + torch.manual_seed(42 + self.rank) + self._init_device() + # Enable NVSHMEM for Triton + nvshmem_lib = nvshmem.enable_triton() + group_name = dist.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + rank = self.rank + msg_size_bytes = 8 + dtype = torch.int8 + numel = msg_size_bytes // dtype.itemsize + # Data buffers + val = 15 + inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val) + out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) + inp_hdl = symm_mem.rendezvous(inp, group=group_name) + out_hdl = symm_mem.rendezvous(out, group=group_name) + # Use signal pad as completion flag + flag_val = 42 + peer = 1 - rank + NVSHMEM_CMP_EQ = 0 + + @triton.jit + def wait_until_kernel( + ivar_ptr, + cmp_op: tl.constexpr, + cmp_val: tl.constexpr, + ): + nvshmem.wait_until(ivar_ptr, cmp_op, cmp_val) + + dist.barrier() + if rank == 0: + # Rank 0 waits for flag from Rank 1 + ivar_ptr = out_hdl.signal_pad_ptrs[rank] + wait_until_kernel[(1, 1, 1)]( + ivar_ptr, + cmp_op=NVSHMEM_CMP_EQ, + cmp_val=flag_val, + extern_libs=nvshmem_lib, + ) + # After flag is set, data should be complete due to quiet + torch.testing.assert_close( + out, val * torch.ones(numel, dtype=dtype, device=self.device) + ) + if rank == 1: + # Rank 1 puts data and flag with quiet in between + dst_ptr = out_hdl.buffer_ptrs[rank] + src_ptr = inp_hdl.buffer_ptrs[rank] + flag_dst_ptr = out_hdl.signal_pad_ptrs[rank] + # Create a tensor for the flag value + flag_update_val = torch.tensor( + [flag_val], dtype=torch.int64, device=self.device + ) + flag_src_ptr = flag_update_val.data_ptr() + put_with_quiet_kernel[(1, 1, 1)]( + dst_ptr, + src_ptr, + flag_dst_ptr, + flag_src_ptr, + numel=numel, + peer=peer, + extern_libs=nvshmem_lib, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": diff --git a/test/distributed/test_pg_wrapper.py b/test/distributed/test_pg_wrapper.py index c1fbf05e60a1c..1bc070c2addd4 100644 --- a/test/distributed/test_pg_wrapper.py +++ b/test/distributed/test_pg_wrapper.py @@ -376,7 +376,11 @@ def patched_isinstance(obj, clazz): ): self._create_wrapper_pg(with_new_group=True) # nothing to assert, isinstance(pg, _ProcessGroupWrapper) +<<<<<<< HEAD # should never be invoked since it is proceeded by +======= + # should never be invoked since it is preceeded by +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # _GLOO_AVAILABLE check, this test will fail on # an unexpected NameError if not. diff --git a/test/distributed/test_store.py b/test/distributed/test_store.py index 870805eec75e8..8b8af49f39275 100644 --- a/test/distributed/test_store.py +++ b/test/distributed/test_store.py @@ -837,9 +837,15 @@ def test_tcp_store_timeout_set(self): # not respected, it will take much longer to timeout. start = time.time() with self.assertRaisesRegex( +<<<<<<< HEAD DistStoreError, "wait timeout after 100ms, keys: /nonexistent key" ): store0.get("nonexistent key") +======= + DistStoreError, "wait timeout after 100ms, keys: /nonexistant key" + ): + store0.get("nonexistant key") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) end = time.time() time_diff = end - start @@ -1066,7 +1072,11 @@ def run(rank, my_store): wait_for_workers=False, ) +<<<<<<< HEAD threads = [] +======= + ths = [] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for i in range(2): t = threading.Thread( target=run, @@ -1076,16 +1086,26 @@ def run(rank, my_store): ), ) t.start() +<<<<<<< HEAD threads.append(t) +======= + ths.append(t) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def handler(a, b): pass signal.signal(signal.SIGUSR1, handler) time.sleep(1) +<<<<<<< HEAD signal.pthread_kill(threads[1].ident, signal.SIGUSR1) for t in threads: +======= + signal.pthread_kill(ths[1].ident, signal.SIGUSR1) + + for t in ths: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) t.join() self.assertTrue(rank_res[0], "rank0") self.assertTrue(rank_res[1], "rank1") diff --git a/test/distributed/test_symmetric_memory.py b/test/distributed/test_symmetric_memory.py index eeeb24bec307b..1ada946147628 100644 --- a/test/distributed/test_symmetric_memory.py +++ b/test/distributed/test_symmetric_memory.py @@ -2,7 +2,10 @@ import itertools import os +<<<<<<< HEAD import random +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from contextlib import nullcontext from unittest import skip, skipIf @@ -25,7 +28,11 @@ from torch.testing._internal.common_cuda import _get_torch_cuda_version, SM90OrLater from torch.testing._internal.common_device_type import e4m3_type from torch.testing._internal.common_distributed import ( +<<<<<<< HEAD MultiProcContinuousTest, +======= + MultiProcContinousTest, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MultiProcessTestCase, requires_multicast_support, skip_if_lt_x_gpu, @@ -35,9 +42,15 @@ MI300_ARCH, parametrize, requires_cuda, +<<<<<<< HEAD requires_cuda_p2p_access, run_tests, runOnRocmArch, +======= + run_tests, + runOnRocmArch, + skip_but_pass_in_sandcastle_if, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) skipIfRocm, TEST_WITH_ROCM, TestCase, @@ -51,15 +64,48 @@ device_module = torch.get_device_module(device_type) +<<<<<<< HEAD @instantiate_parametrized_tests @requires_cuda_p2p_access() class SymmetricMemoryTest(MultiProcContinuousTest): +======= +def requires_cuda_p2p_access(): + cuda_p2p_access_available = ( + torch.cuda.is_available() + and torch.cuda.get_device_capability() >= (8, 0) + and torch.cuda.device_count() >= 2 + ) + num_devices = torch.cuda.device_count() + for i in range(num_devices - 1): + for j in range(i + 1, num_devices): + if not torch.cuda.can_device_access_peer(i, j): + cuda_p2p_access_available = False + break + if not cuda_p2p_access_available: + break + + return skip_but_pass_in_sandcastle_if( + not cuda_p2p_access_available, + "cuda p2p access is not available", + ) + + +@instantiate_parametrized_tests +@requires_cuda_p2p_access() +class SymmetricMemoryTest(MultiProcContinousTest): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @property def device(self) -> torch.device: return torch.device(device_type, self.rank) +<<<<<<< HEAD def _init_process(self): torch.cuda.set_device(self.device) +======= + def _init_process(self, set_device: bool = True): + if set_device: + torch.cuda.set_device(self.device) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.manual_seed(42 + self.rank) def test_has_multicast_support(self) -> None: @@ -69,6 +115,7 @@ def test_has_multicast_support(self) -> None: @skipIfRocm @skip_if_lt_x_gpu(2) +<<<<<<< HEAD def test_get_backend(self) -> None: backend = symm_mem.get_backend(torch.device("cuda")) self.assertIsNotNone(backend) @@ -77,6 +124,8 @@ def test_get_backend(self) -> None: @skipIfRocm @skip_if_lt_x_gpu(2) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_cuda_nvlink_connectivity_detection(self) -> None: from torch._C._distributed_c10d import _detect_dma_connectivity @@ -92,6 +141,89 @@ def test_large_alloc(self) -> None: t = symm_mem.empty(2 * 1024**3, dtype=torch.uint8, device="cuda") self.assertEqual(t.numel() * t.element_size(), 2 * 1024**3) +<<<<<<< HEAD +======= + def _get_test_alloc_args(self): + shape = (64, 64) + stride = (64, 1) + dtype = torch.float32 + device = self.device + group_name = "0" + return (shape, stride, dtype, device, group_name) + + def _verify_symmetric_memory(self, symm_mem_hdl): + self.assertEqual(symm_mem_hdl.world_size, self.world_size) + + buf = symm_mem_hdl.get_buffer( + 0, (symm_mem_hdl.buffer_size // 4,), torch.float32 + ) + self.assertEqual(buf.storage_offset(), 0) + self.assertEqual(buf.untyped_storage().size(), symm_mem_hdl.buffer_size) + + if symm_mem_hdl.rank == 0: + symm_mem_hdl.wait_signal(src_rank=1) + self.assertTrue(buf.eq(42).all()) + else: + buf.fill_(42) + symm_mem_hdl.put_signal(dst_rank=0) + + symm_mem_hdl.barrier() + + if symm_mem_hdl.rank == 0: + symm_mem_hdl.barrier() + self.assertTrue(buf.eq(43).all()) + else: + buf.fill_(43) + symm_mem_hdl.barrier() + + symm_mem_hdl.barrier() + + @runOnRocmArch(MI300_ARCH) + @skip_if_lt_x_gpu(2) + @parametrize("set_device", [True, False]) + def test_empty_strided_p2p(self, set_device: bool) -> None: + self._init_process(set_device) + enable_symm_mem_for_group(dist.group.WORLD.group_name) + + alloc_args = self._get_test_alloc_args() + + t = torch.empty((64, 64), device=self.device) + self.assertIsNone(_SymmetricMemory.rendezvous(t)) + + t = _SymmetricMemory.empty_strided_p2p(*alloc_args) + symm_mem_hdl = _SymmetricMemory.rendezvous(t) + + del t + self._verify_symmetric_memory(symm_mem_hdl) + + @skipIfRocm # started failing during ROCm 6.4 CI upgrade + @skip_if_lt_x_gpu(2) + @parametrize("set_device", [True, False]) + def test_empty_strided_p2p_persistent(self, set_device: bool) -> None: + self._init_process(set_device) + enable_symm_mem_for_group(dist.group.WORLD.group_name) + + alloc_args = self._get_test_alloc_args() + + t = _SymmetricMemory.empty_strided_p2p(*alloc_args, alloc_id=42) + data_ptr = t.data_ptr() + + # Verify that persistent allocation would fail if there's an active + # allocation with the same alloc_id. + with self.assertRaises(RuntimeError): + _SymmetricMemory.empty_strided_p2p(*alloc_args, alloc_id=42) + + # Verify that persistent allocation would succeed in lieu of activate + # allocations with the same alloc_id, and the returned tensor would + # have the same data pointer. + del t + t = _SymmetricMemory.empty_strided_p2p(*alloc_args, alloc_id=42) + self.assertEqual(t.data_ptr(), data_ptr) + + symm_mem_hdl = _SymmetricMemory.rendezvous(t) + self._verify_symmetric_memory(symm_mem_hdl) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @runOnRocmArch(MI300_ARCH) @skip_if_lt_x_gpu(2) def test_get_signal_pad(self) -> None: @@ -154,6 +286,7 @@ def test_allow_overlapping_devices(self) -> None: @runOnRocmArch(MI300_ARCH) @skip_if_lt_x_gpu(2) +<<<<<<< HEAD @parametrize("symm_mem_input", [True, False]) def test_low_contention_all_gather(self, symm_mem_input: bool) -> None: self._init_process() @@ -272,6 +405,8 @@ def _init_process(self): @runOnRocmArch(MI300_ARCH) @skip_if_lt_x_gpu(2) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize("gather_dim", [0, 1]) def test_fused_all_gather_matmul(self, gather_dim: int) -> None: self._init_process() @@ -558,7 +693,11 @@ def test_fused_scaled_matmul_reduce_scatter( ) assert outputs[0].stride() == outputs[1].stride() +<<<<<<< HEAD self.assertEqual(outputs[0], outputs[1]) +======= + assert torch.allclose(outputs[0], outputs[1]), (outputs[0], outputs[1]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @runOnRocmArch(MI300_ARCH) @parametrize("dim", [0, 1, 2]) @@ -573,6 +712,7 @@ def test_optimal_layout(self, dim: int) -> None: self.assertTrue(x.movedim(dim, 0).is_contiguous()) self.assertTrue(torch.allclose(x, t)) +<<<<<<< HEAD # [READ ME FIRST] # The `SymmMemEmptySetDeviceTest` suite parameterizes whether user sets the @@ -697,11 +837,114 @@ def test_empty_strided_p2p_persistent(self, set_device: bool) -> None: symm_mem_hdl = _SymmetricMemory.rendezvous(t) self._verify_symmetric_memory(symm_mem_hdl) +======= + @runOnRocmArch(MI300_ARCH) + @skip_if_lt_x_gpu(2) + @parametrize("symm_mem_input", [True, False]) + def test_low_contention_all_gather(self, symm_mem_input: bool) -> None: + self._init_process() + + if symm_mem_input: + t = _SymmetricMemory.empty_strided_p2p( + size=(64, 64), + stride=(64, 1), + dtype=torch.float32, + device=self.device, + group_name="0", + ).fill_(self.rank) + else: + t = torch.full((64, 64), self.rank, dtype=torch.float32, device=self.device) + + res = torch.ops.symm_mem._low_contention_all_gather(t, "0") + res = torch.ops._c10d_functional.wait_tensor(res) + self.assertEqual(res.shape, (64 * self.world_size, 64)) + + chunks = res.chunk(self.world_size) + for r in range(self.world_size): + self.assertTrue(chunks[r].eq(r).all()) + + @runOnRocmArch(MI300_ARCH) + @skip_if_lt_x_gpu(2) + @parametrize("reduce_op", ["sum", "avg"]) + @parametrize("symm_mem_input", [True, False]) + def test_low_contention_reduce_scatter( + self, reduce_op: str, symm_mem_input: bool + ) -> None: + self._init_process() + + if symm_mem_input: + t = _SymmetricMemory.empty_strided_p2p( + size=(64, 64), + stride=(64, 1), + dtype=torch.float32, + device=self.device, + group_name="0", + ) + else: + t = torch.empty((64, 64), dtype=torch.float32, device=self.device) + + chunks = t.chunk(self.world_size) + for r in range(self.world_size): + chunks[r].fill_(r) + + res = torch.ops.symm_mem._low_contention_reduce_scatter(t, reduce_op, "0") + res = torch.ops._c10d_functional.wait_tensor(res) + self.assertEqual(res.shape, (64 // self.world_size, 64)) + + if reduce_op == "sum": + expect = self.rank * self.world_size + elif reduce_op == "avg": + expect = self.rank + else: + raise AssertionError(f"Unexpected reduce_op: {reduce_op}") + self.assertTrue(res.eq(expect).all()) + + @runOnRocmArch(MI300_ARCH) + @skip_if_lt_x_gpu(4) + def test_subgroup(self) -> None: + self._init_process() + + ranks = list(range(self.world_size)) + subgroup_0 = dist.new_group(ranks[: len(ranks) // 2]) + subgroup_1 = dist.new_group(ranks[len(ranks) // 2 :]) + + world = dist.group.WORLD + subgroup = subgroup_0 if world.rank() < world.size() // 2 else subgroup_1 + + t = symm_mem.empty(64, device="cuda") + symm_mem_world = symm_mem.rendezvous(t, group=world) + symm_mem_subgroup = symm_mem.rendezvous(t, group=subgroup) + + self.assertEqual(symm_mem_world.world_size, world.size()) + self.assertEqual(symm_mem_world.rank, world.rank()) + self.assertEqual(symm_mem_subgroup.world_size, world.size() // 2) + self.assertEqual(symm_mem_subgroup.rank, world.rank() % subgroup.size()) + + t.fill_(world.rank()) + symm_mem_world.barrier() + + # Observe a peer buffer via the world group + peer_rank = (world.rank() + 1) % world.size() + buf = symm_mem_world.get_buffer(peer_rank, (64,), torch.float32) + self.assertTrue(buf.eq(peer_rank).all()) + + # Observe a peer buffer via the subgroup + peer_rank = (subgroup.rank() + 1) % subgroup.size() + buf = symm_mem_subgroup.get_buffer(peer_rank, (64,), torch.float32) + if world.rank() < world.size() // 2: + self.assertTrue(buf.eq(peer_rank).all()) + else: + self.assertTrue(buf.eq(peer_rank + world.size() // 2).all()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This Test class is used to test the error handling of SymmetricMemory APIs. # Since a process restart is often needed after each test, we use the +<<<<<<< HEAD # MultiProcessTestCase instead of MultiProcContinuousTest. +======= +# MultiProcessTestCase instead of MultiProcContinousTest. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @requires_cuda_p2p_access() class SymmMemNegativeTest(MultiProcessTestCase): def setUp(self) -> None: @@ -729,7 +972,11 @@ def _init_process(self): # These timeout tests are skipped on ROCm because timeout calls trap(), which # is handled differently inside hip runtime. It collects gpu coredump and causes +<<<<<<< HEAD # the linux kernel to create a core dump of the host application. The functionality +======= + # the linux kernel to create a core dump of the host application. The funcitonality +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # is there, meaning timeout is happening correctly. However, there isn't a nice way # to test it as the current executing thread will coredump and exit. @skipIfRocm @@ -755,7 +1002,11 @@ def test_barrier_timeout(self) -> None: # These timeout tests are skipped on ROCm because timeout calls trap(), which # is handled differently inside hip runtime. It collects gpu coredump and causes +<<<<<<< HEAD # the linux kernel to create a core dump of the host application. The functionality +======= + # the linux kernel to create a core dump of the host application. The funcitonality +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # is there, meaning timeout is happening correctly. However, there isn't a nice way # to test it as the current executing thread will coredump and exit. @skipIfRocm @@ -784,7 +1035,11 @@ def test_put_signal_timeout(self) -> None: # These timeout tests are skipped on ROCm because timeout calls trap(), which # is handled differently inside hip runtime. It collects gpu coredump and causes +<<<<<<< HEAD # the linux kernel to create a core dump of the host application. The functionality +======= + # the linux kernel to create a core dump of the host application. The funcitonality +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # is there, meaning timeout is happening correctly. However, there isn't a nice way # to test it as the current executing thread will coredump and exit. @skipIfRocm @@ -811,7 +1066,11 @@ def test_wait_signal_timeout(self) -> None: @instantiate_parametrized_tests @requires_cuda_p2p_access() +<<<<<<< HEAD class SymmMemCollectiveTest(MultiProcContinuousTest): +======= +class SymmMemCollectiveTest(MultiProcContinousTest): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @property def device(self) -> torch.device: return torch.device(device_type, self.rank) @@ -1058,7 +1317,11 @@ def test_multimem_all_gather(self, align_bytes: int) -> None: @instantiate_parametrized_tests @requires_cuda_p2p_access() +<<<<<<< HEAD class LoweringTest(MultiProcContinuousTest): +======= +class LoweringTest(MultiProcContinousTest): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _init_process(self) -> None: torch.cuda.set_device(self.device) enable_symm_mem_for_group(dist.group.WORLD.group_name) diff --git a/test/distributions/test_distributions.py b/test/distributions/test_distributions.py index 7cb8cc678136f..571721a5783e7 100644 --- a/test/distributions/test_distributions.py +++ b/test/distributions/test_distributions.py @@ -1805,6 +1805,7 @@ def test_zero_excluded_binomial(self): assert (vals == 0.0).sum() > 4000 assert (vals == 1.0).sum() > 4000 +<<<<<<< HEAD def test_torch_binomial_dtype_errors(self): dtypes = [torch.int, torch.long, torch.short] @@ -1828,6 +1829,8 @@ def test_torch_binomial_dtype_errors(self): ): torch.binomial(total_count, total_prob) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @set_default_dtype(torch.double) def test_multinomial_1d(self): total_count = 10 @@ -6369,7 +6372,11 @@ def test_lazy_logits_initialization(self): except NotImplementedError: pass self.assertNotIn("probs", dist.__dict__, msg=message) +<<<<<<< HEAD _ = (dist.batch_shape, dist.event_shape) +======= + dist.batch_shape, dist.event_shape +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertNotIn("probs", dist.__dict__, msg=message) def test_lazy_probs_initialization(self): @@ -6386,7 +6393,11 @@ def test_lazy_probs_initialization(self): except NotImplementedError: pass self.assertNotIn("logits", dist.__dict__, msg=message) +<<<<<<< HEAD _ = (dist.batch_shape, dist.event_shape) +======= + dist.batch_shape, dist.event_shape +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertNotIn("logits", dist.__dict__, msg=message) diff --git a/test/dynamo/cpython/3_13/list_tests.diff b/test/dynamo/cpython/3_13/list_tests.diff index 1a5c63a9142dc..6f333a8c90514 100644 --- a/test/dynamo/cpython/3_13/list_tests.diff +++ b/test/dynamo/cpython/3_13/list_tests.diff @@ -1,17 +1,27 @@ diff --git a/test/dynamo/cpython/3_13/list_tests.py b/test/dynamo/cpython/3_13/list_tests.py +<<<<<<< HEAD index dbc5ef4f9f2..af717703053 100644 --- a/test/dynamo/cpython/3_13/list_tests.py +++ b/test/dynamo/cpython/3_13/list_tests.py @@ -1,3 +1,56 @@ +======= +index dbc5ef4f9f2..2b9f3b9311f 100644 +--- a/test/dynamo/cpython/3_13/list_tests.py ++++ b/test/dynamo/cpython/3_13/list_tests.py +@@ -1,3 +1,53 @@ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +<<<<<<< HEAD +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/list_tests.py + +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +import sys +import torch +import torch._dynamo.test_case @@ -59,6 +69,7 @@ index dbc5ef4f9f2..af717703053 100644 """ Tests common to list and UserList.UserList """ +<<<<<<< HEAD @@ -5,7 +58,7 @@ Tests common to list and UserList.UserList import sys from functools import cmp_to_key @@ -157,3 +168,14 @@ index dbc5ef4f9f2..af717703053 100644 a = self.type2test() a[:] = [EvilCmp(a) for _ in range(100)] # This used to seg fault before patch #1005778 +======= +@@ -5,7 +55,7 @@ Tests common to list and UserList.UserList + import sys + from functools import cmp_to_key + +-from test import seq_tests ++import seq_tests + from test.support import ALWAYS_EQ, NEVER_EQ, get_c_recursion_limit + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/test/dynamo/cpython/3_13/list_tests.py b/test/dynamo/cpython/3_13/list_tests.py index 21e85eef179fd..bbbf67db4d956 100644 --- a/test/dynamo/cpython/3_13/list_tests.py +++ b/test/dynamo/cpython/3_13/list_tests.py @@ -4,9 +4,12 @@ # ruff: noqa # flake8: noqa +<<<<<<< HEAD # Test copied from # https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/list_tests.py +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import sys import torch import torch._dynamo.test_case @@ -172,6 +175,13 @@ def test_setitem(self): a[-1] = 9 self.assertEqual(a, self.type2test([5,6,7,8,9])) +<<<<<<< HEAD +======= + msg = "list indices must be integers or slices" + with self.assertRaisesRegex(TypeError, msg): + a['a'] = "python" + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_delitem(self): a = self.type2test([0, 1]) del a[1] @@ -319,6 +329,7 @@ def test_extend(self): self.assertRaises(TypeError, a.extend) # overflow test. issue1621 +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class CustomIter: def __iter__(self): @@ -327,6 +338,15 @@ def __next__(self): raise StopIteration def __length_hint__(self): return sys.maxsize +======= + class CustomIter: + def __iter__(self): + return self + def __next__(self): + raise StopIteration + def __length_hint__(self): + return sys.maxsize +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) a = self.type2test([1,2,3,4]) a.extend(CustomIter()) self.assertEqual(a, [1,2,3,4]) @@ -387,6 +407,7 @@ def test_remove(self): a = self.type2test([NEVER_EQ]) self.assertRaises(ValueError, a.remove, ALWAYS_EQ) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class BadExc(Exception): pass @@ -396,14 +417,30 @@ def __eq__(self, other): if other == 2: raise BadExc() return False +======= + class BadExc(Exception): + pass + + class BadCmp: + def __eq__(self, other): + if other == 2: + raise BadExc() + return False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) a = self.type2test([0, 1, 2, 3]) self.assertRaises(BadExc, a.remove, BadCmp()) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class BadCmp2: def __eq__(self, other): raise BadExc() +======= + class BadCmp2: + def __eq__(self, other): + raise BadExc() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) d = self.type2test('abcdefghcij') d.remove('c') @@ -428,6 +465,7 @@ def test_index(self): self.assertRaises(ValueError, a.index, 2, 0, 4) self.assertEqual(a, self.type2test([-2, -1, 0, 1, 2])) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): # Test modifying the list during index's iteration class EvilCmp: @@ -436,6 +474,15 @@ def __init__(self, victim): def __eq__(self, other): del self.victim[:] return False +======= + # Test modifying the list during index's iteration + class EvilCmp: + def __init__(self, victim): + self.victim = victim + def __eq__(self, other): + del self.victim[:] + return False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) a = self.type2test() a[:] = [EvilCmp(a) for _ in range(100)] # This used to seg fault before patch #1005778 diff --git a/test/dynamo/cpython/3_13/mapping_tests.diff b/test/dynamo/cpython/3_13/mapping_tests.diff index c376ddf725ae5..c54ac6cda7de8 100644 --- a/test/dynamo/cpython/3_13/mapping_tests.diff +++ b/test/dynamo/cpython/3_13/mapping_tests.diff @@ -1,17 +1,27 @@ diff --git a/test/dynamo/cpython/3_13/mapping_tests.py b/test/dynamo/cpython/3_13/mapping_tests.py +<<<<<<< HEAD index ed89a81a6ea..b19cec7cb23 100644 --- a/test/dynamo/cpython/3_13/mapping_tests.py +++ b/test/dynamo/cpython/3_13/mapping_tests.py @@ -1,10 +1,64 @@ +======= +index ed89a81a6ea..eed59a68e94 100644 +--- a/test/dynamo/cpython/3_13/mapping_tests.py ++++ b/test/dynamo/cpython/3_13/mapping_tests.py +@@ -1,10 +1,61 @@ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +<<<<<<< HEAD +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/mapping_tests.py + +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +import sys +import torch +import torch._dynamo.test_case @@ -61,12 +71,18 @@ index ed89a81a6ea..b19cec7cb23 100644 import unittest import collections from test.support import get_c_recursion_limit +<<<<<<< HEAD +======= + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class BasicTestMappingProtocol(unittest.TestCase): +class BasicTestMappingProtocol(__TestCase): # This base class can be used to check that an object conforms to the # mapping protocol +<<<<<<< HEAD @@ -196,70 +250,76 @@ class BasicTestMappingProtocol(unittest.TestCase): self.assertRaises((TypeError, AttributeError), d.update, 42) @@ -418,3 +434,6 @@ index ed89a81a6ea..b19cec7cb23 100644 d = self._empty_mapping() x = BadHash() +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/test/dynamo/cpython/3_13/mapping_tests.py b/test/dynamo/cpython/3_13/mapping_tests.py index 88c97899ae3eb..fece8aca6b134 100644 --- a/test/dynamo/cpython/3_13/mapping_tests.py +++ b/test/dynamo/cpython/3_13/mapping_tests.py @@ -4,9 +4,12 @@ # ruff: noqa # flake8: noqa +<<<<<<< HEAD # Test copied from # https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/mapping_tests.py +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import sys import torch import torch._dynamo.test_case @@ -250,6 +253,7 @@ def test_update(self): self.assertRaises((TypeError, AttributeError), d.update, 42) outerself = self +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class SimpleUserDict: def __init__(self): @@ -258,12 +262,22 @@ def keys(self): return self.d.keys() def __getitem__(self, i): return self.d[i] +======= + class SimpleUserDict: + def __init__(self): + self.d = outerself.reference + def keys(self): + return self.d.keys() + def __getitem__(self, i): + return self.d[i] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) d.clear() d.update(SimpleUserDict()) i1 = sorted(d.items()) i2 = sorted(self.reference.items()) self.assertEqual(i1, i2) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Exc(Exception): pass @@ -272,10 +286,19 @@ class Exc(Exception): pass class FailingUserDict: def keys(self): raise Exc +======= + class Exc(Exception): pass + + d = self._empty_mapping() + class FailingUserDict: + def keys(self): + raise Exc +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertRaises(Exc, d.update, FailingUserDict()) d.clear() +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class FailingUserDict: def keys(self): @@ -320,6 +343,49 @@ def __iter__(self): return self def __next__(self): raise Exc() +======= + class FailingUserDict: + def keys(self): + class BogonIter: + def __init__(self): + self.i = 1 + def __iter__(self): + return self + def __next__(self): + if self.i: + self.i = 0 + return 'a' + raise Exc + return BogonIter() + def __getitem__(self, key): + return key + self.assertRaises(Exc, d.update, FailingUserDict()) + + class FailingUserDict: + def keys(self): + class BogonIter: + def __init__(self): + self.i = ord('a') + def __iter__(self): + return self + def __next__(self): + if self.i <= ord('z'): + rtn = chr(self.i) + self.i += 1 + return rtn + raise StopIteration + return BogonIter() + def __getitem__(self, key): + raise Exc + self.assertRaises(Exc, d.update, FailingUserDict()) + + d = self._empty_mapping() + class badseq(object): + def __iter__(self): + return self + def __next__(self): + raise Exc() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertRaises(Exc, d.update, badseq()) @@ -469,6 +535,7 @@ def test_update(self): d.update(self._full_mapping({1:2, 3:4, 5:6}).items()) self.assertEqual(d, {1:2, 2:4, 3:4, 5:6}) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class SimpleUserDict: def __init__(self): @@ -477,6 +544,15 @@ def keys(self): return self.d.keys() def __getitem__(self, i): return self.d[i] +======= + class SimpleUserDict: + def __init__(self): + self.d = {1:1, 2:2, 3:3} + def keys(self): + return self.d.keys() + def __getitem__(self, i): + return self.d[i] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) d.clear() d.update(SimpleUserDict()) self.assertEqual(d, {1:1, 2:2, 3:3}) @@ -492,22 +568,33 @@ def g(): yield 1 self.assertEqual(d.fromkeys(g()), {1:None}) self.assertRaises(TypeError, {}.fromkeys, 3) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class dictlike(self.type2test): pass +======= + class dictlike(self.type2test): pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(dictlike.fromkeys('a'), {'a':None}) self.assertEqual(dictlike().fromkeys('a'), {'a':None}) self.assertTrue(dictlike.fromkeys('a').__class__ is dictlike) self.assertTrue(dictlike().fromkeys('a').__class__ is dictlike) self.assertTrue(type(dictlike.fromkeys('a')) is dictlike) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class mydict(self.type2test): def __new__(cls): return collections.UserDict() +======= + class mydict(self.type2test): + def __new__(cls): + return collections.UserDict() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ud = mydict.fromkeys('ab') self.assertEqual(ud, {'a':None, 'b':None}) self.assertIsInstance(ud, collections.UserDict) self.assertRaises(TypeError, dict.fromkeys) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Exc(Exception): pass @@ -530,6 +617,27 @@ def __next__(self): class baddict2(self.type2test): def __setitem__(self, key, value): raise Exc() +======= + class Exc(Exception): pass + + class baddict1(self.type2test): + def __init__(self, *args, **kwargs): + raise Exc() + + self.assertRaises(Exc, baddict1.fromkeys, [1]) + + class BadSeq(object): + def __iter__(self): + return self + def __next__(self): + raise Exc() + + self.assertRaises(Exc, self.type2test.fromkeys, BadSeq()) + + class baddict2(self.type2test): + def __setitem__(self, key, value): + raise Exc() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertRaises(Exc, baddict2.fromkeys, [1]) @@ -603,6 +711,7 @@ class TestHashMappingProtocol(TestMappingProtocol): def test_getitem(self): TestMappingProtocol.test_getitem(self) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Exc(Exception): pass @@ -611,11 +720,21 @@ def __eq__(self, other): raise Exc() def __hash__(self): return 24 +======= + class Exc(Exception): pass + + class BadEq(object): + def __eq__(self, other): + raise Exc() + def __hash__(self): + return 24 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) d = self._empty_mapping() d[BadEq()] = 42 self.assertRaises(KeyError, d.__getitem__, 23) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class BadHash(object): fail = False @@ -624,6 +743,15 @@ def __hash__(self): raise Exc() else: return 42 +======= + class BadHash(object): + fail = False + def __hash__(self): + if self.fail: + raise Exc() + else: + return 42 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) d = self._empty_mapping() x = BadHash() @@ -633,10 +761,16 @@ def __hash__(self): def test_fromkeys(self): TestMappingProtocol.test_fromkeys(self) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class mydict(self.type2test): def __new__(cls): return collections.UserDict() +======= + class mydict(self.type2test): + def __new__(cls): + return collections.UserDict() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ud = mydict.fromkeys('ab') self.assertEqual(ud, {'a':None, 'b':None}) self.assertIsInstance(ud, collections.UserDict) @@ -644,6 +778,7 @@ def __new__(cls): def test_pop(self): TestMappingProtocol.test_pop(self) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Exc(Exception): pass @@ -654,6 +789,17 @@ def __hash__(self): raise Exc() else: return 42 +======= + class Exc(Exception): pass + + class BadHash(object): + fail = False + def __hash__(self): + if self.fail: + raise Exc() + else: + return 42 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) d = self._empty_mapping() x = BadHash() @@ -683,12 +829,20 @@ def test_repr(self): d[1] = d self.assertEqual(repr(d), '{1: {...}}') +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Exc(Exception): pass class BadRepr(object): def __repr__(self): raise Exc() +======= + class Exc(Exception): pass + + class BadRepr(object): + def __repr__(self): + raise Exc() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) d = self._full_mapping({1: BadRepr()}) self.assertRaises(Exc, repr, d) @@ -706,6 +860,7 @@ def test_eq(self): self.assertEqual(self._full_mapping({1: 2}), self._full_mapping({1: 2})) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Exc(Exception): pass @@ -714,6 +869,15 @@ def __eq__(self, other): raise Exc() def __hash__(self): return 1 +======= + class Exc(Exception): pass + + class BadCmp(object): + def __eq__(self, other): + raise Exc() + def __hash__(self): + return 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) d1 = self._full_mapping({BadCmp(): 1}) d2 = self._full_mapping({1: 1}) @@ -723,6 +887,7 @@ def __hash__(self): def test_setdefault(self): TestMappingProtocol.test_setdefault(self) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Exc(Exception): pass @@ -733,6 +898,17 @@ def __hash__(self): raise Exc() else: return 42 +======= + class Exc(Exception): pass + + class BadHash(object): + fail = False + def __hash__(self): + if self.fail: + raise Exc() + else: + return 42 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) d = self._empty_mapping() x = BadHash() diff --git a/test/dynamo/cpython/3_13/seq_tests.diff b/test/dynamo/cpython/3_13/seq_tests.diff index d5e6f92a07689..388c037c6aed4 100644 --- a/test/dynamo/cpython/3_13/seq_tests.diff +++ b/test/dynamo/cpython/3_13/seq_tests.diff @@ -1,17 +1,27 @@ diff --git a/test/dynamo/cpython/3_13/seq_tests.py b/test/dynamo/cpython/3_13/seq_tests.py +<<<<<<< HEAD index 719c9434a16..290e57c04a0 100644 --- a/test/dynamo/cpython/3_13/seq_tests.py +++ b/test/dynamo/cpython/3_13/seq_tests.py @@ -1,3 +1,57 @@ +======= +index 719c9434a16..4325892276d 100644 +--- a/test/dynamo/cpython/3_13/seq_tests.py ++++ b/test/dynamo/cpython/3_13/seq_tests.py +@@ -1,3 +1,54 @@ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +<<<<<<< HEAD +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/seq_tests.py + +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +import sys +import torch +import torch._dynamo.test_case @@ -60,14 +70,22 @@ index 719c9434a16..290e57c04a0 100644 """ Tests common to tuple, list and UserList.UserList """ +<<<<<<< HEAD @@ -95,7 +149,7 @@ class LyingList(list): def __iter__(self): yield 1 +======= +@@ -95,7 +146,7 @@ class LyingList(list): + def __iter__(self): + yield 1 + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class CommonTest(unittest.TestCase): +class CommonTest(__TestCase): # The type to be tested type2test = None +<<<<<<< HEAD @@ -115,13 +169,14 @@ class CommonTest(unittest.TestCase): uu2 = self.type2test(u2) @@ -181,3 +199,6 @@ index 719c9434a16..290e57c04a0 100644 a = self.type2test([0, 1, 2, 3]) self.assertRaises(BadExc, a.index, BadCmp()) +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/test/dynamo/cpython/3_13/seq_tests.py b/test/dynamo/cpython/3_13/seq_tests.py index 11d59c847326c..7152b2d051de0 100644 --- a/test/dynamo/cpython/3_13/seq_tests.py +++ b/test/dynamo/cpython/3_13/seq_tests.py @@ -4,9 +4,12 @@ # ruff: noqa # flake8: noqa +<<<<<<< HEAD # Test copied from # https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/seq_tests.py +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import sys import torch import torch._dynamo.test_case @@ -169,6 +172,7 @@ def test_constructors(self): uu2 = self.type2test(u2) v = self.type2test(tuple(u)) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class OtherSeq: def __init__(self, initseq): @@ -177,6 +181,15 @@ def __len__(self): return len(self.__data) def __getitem__(self, i): return self.__data[i] +======= + class OtherSeq: + def __init__(self, initseq): + self.__data = initseq + def __len__(self): + return len(self.__data) + def __getitem__(self, i): + return self.__data[i] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) s = OtherSeq(u0) v0 = self.type2test(s) self.assertEqual(len(v0), len(s)) @@ -294,12 +307,20 @@ def test_contains_order(self): # Sequences must test in-order. If a rich comparison has side # effects, these will be visible to tests against later members. # In this test, the "side effect" is a short-circuiting raise. +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class DoNotTestEq(Exception): pass class StopCompares: def __eq__(self, other): raise DoNotTestEq +======= + class DoNotTestEq(Exception): + pass + class StopCompares: + def __eq__(self, other): + raise DoNotTestEq +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) checkfirst = self.type2test([1, StopCompares()]) self.assertIn(1, checkfirst) @@ -339,9 +360,14 @@ def test_addmul(self): self.assertEqual(u2+u2+u2, u2*3) self.assertEqual(u2+u2+u2, 3*u2) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class subclass(self.type2test): pass +======= + class subclass(self.type2test): + pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) u3 = subclass([0, 1]) self.assertEqual(u3, u3*1) self.assertIsNot(u3, u3*1) @@ -368,10 +394,16 @@ def test_imul(self): def test_getitemoverwriteiter(self): # Verify that __getitem__ overrides are not recognized by __iter__ +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class T(self.type2test): def __getitem__(self, key): return str(key) + '!!!' +======= + class T(self.type2test): + def __getitem__(self, key): + return str(key) + '!!!' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(next(iter(T((1,2)))), 1) def test_repeat(self): @@ -419,6 +451,7 @@ def test_count(self): self.assertRaises(TypeError, a.count) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class BadExc(Exception): pass @@ -428,6 +461,16 @@ def __eq__(self, other): if other == 2: raise BadExc() return False +======= + class BadExc(Exception): + pass + + class BadCmp: + def __eq__(self, other): + if other == 2: + raise BadExc() + return False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertRaises(BadExc, a.count, BadCmp()) @@ -453,6 +496,7 @@ def test_index(self): self.assertRaises(TypeError, u.index) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class BadExc(Exception): pass @@ -462,6 +506,16 @@ def __eq__(self, other): if other == 2: raise BadExc() return False +======= + class BadExc(Exception): + pass + + class BadCmp: + def __eq__(self, other): + if other == 2: + raise BadExc() + return False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) a = self.type2test([0, 1, 2, 3]) self.assertRaises(BadExc, a.index, BadCmp()) diff --git a/test/dynamo/cpython/3_13/test_cmath.diff b/test/dynamo/cpython/3_13/test_cmath.diff index deb03570db1cd..217151a68326d 100644 --- a/test/dynamo/cpython/3_13/test_cmath.diff +++ b/test/dynamo/cpython/3_13/test_cmath.diff @@ -1,17 +1,27 @@ diff --git a/test/dynamo/cpython/3_13/test_cmath.py b/test/dynamo/cpython/3_13/test_cmath.py +<<<<<<< HEAD index a96a5780b31..d00dfca8a17 100644 --- a/test/dynamo/cpython/3_13/test_cmath.py +++ b/test/dynamo/cpython/3_13/test_cmath.py @@ -1,5 +1,58 @@ +======= +index a96a5780b31..883e87a0733 100644 +--- a/test/dynamo/cpython/3_13/test_cmath.py ++++ b/test/dynamo/cpython/3_13/test_cmath.py +@@ -1,5 +1,55 @@ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +<<<<<<< HEAD +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_cmath.py + +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +import sys +import torch +import torch._dynamo.test_case @@ -62,19 +72,33 @@ index a96a5780b31..d00dfca8a17 100644 from test.test_math import parse_testfile, test_file import test.test_math as test_math import unittest +<<<<<<< HEAD @@ -50,7 +103,7 @@ complex_nans = [complex(x, y) for x, y in [ (INF, NAN) ]] +======= +@@ -50,7 +100,7 @@ complex_nans = [complex(x, y) for x, y in [ + (INF, NAN) + ]] + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class CMathTests(ComplexesAreIdenticalMixin, unittest.TestCase): +class CMathTests(__TestCase): # list of all functions in cmath test_functions = [getattr(cmath, fname) for fname in [ 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh', +<<<<<<< HEAD @@ -66,6 +119,39 @@ class CMathTests(ComplexesAreIdenticalMixin, unittest.TestCase): def tearDown(self): self.test_values.close() +======= +@@ -66,6 +116,39 @@ class CMathTests(ComplexesAreIdenticalMixin, unittest.TestCase): + def tearDown(self): + self.test_values.close() + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) + def assertFloatIdentical(self, x, y): + """Fail unless floats x and y are identical, in the sense that: + (1) both x and y are nans, or @@ -111,6 +135,7 @@ index a96a5780b31..d00dfca8a17 100644 def rAssertAlmostEqual(self, a, b, rel_err = 2e-15, abs_err = 5e-323, msg=None): """Fail if the two floating-point numbers are not almost equal. +<<<<<<< HEAD @@ -165,38 +251,39 @@ class CMathTests(ComplexesAreIdenticalMixin, unittest.TestCase): # end up being passed to the cmath functions @@ -185,6 +210,11 @@ index a96a5780b31..d00dfca8a17 100644 @@ -590,4 +677,4 @@ class IsCloseTests(test_math.IsCloseTests): +======= +@@ -590,4 +673,4 @@ class IsCloseTests(test_math.IsCloseTests): + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": - unittest.main() + run_tests() diff --git a/test/dynamo/cpython/3_13/test_cmath.py b/test/dynamo/cpython/3_13/test_cmath.py index 95cb84121f9c3..506f127a0de65 100644 --- a/test/dynamo/cpython/3_13/test_cmath.py +++ b/test/dynamo/cpython/3_13/test_cmath.py @@ -4,9 +4,12 @@ # ruff: noqa # flake8: noqa +<<<<<<< HEAD # Test copied from # https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_cmath.py +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import sys import torch import torch._dynamo.test_case @@ -251,6 +254,7 @@ def test_user_object(self): # end up being passed to the cmath functions # usual case: new-style class implementing __complex__ +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class MyComplex: def __init__(self, value): @@ -284,6 +288,40 @@ def __complex__(self): class JustFloat: def __float__(self): return flt_arg +======= + class MyComplex: + def __init__(self, value): + self.value = value + def __complex__(self): + return self.value + + # classes for which __complex__ raises an exception + class SomeException(Exception): + pass + class MyComplexException: + def __complex__(self): + raise SomeException + + # some classes not providing __float__ or __complex__ + class NeitherComplexNorFloat(object): + pass + class Index: + def __int__(self): return 2 + def __index__(self): return 2 + class MyInt: + def __int__(self): return 2 + + # other possible combinations of __float__ and __complex__ + # that should work + class FloatAndComplex: + def __float__(self): + return flt_arg + def __complex__(self): + return cx_arg + class JustFloat: + def __float__(self): + return flt_arg +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for f in self.test_functions: # usual usage diff --git a/test/dynamo/cpython/3_13/test_complex.diff b/test/dynamo/cpython/3_13/test_complex.diff index 2a7042b9c0a6f..446466ebb8919 100644 --- a/test/dynamo/cpython/3_13/test_complex.diff +++ b/test/dynamo/cpython/3_13/test_complex.diff @@ -1,17 +1,27 @@ diff --git a/test/dynamo/cpython/3_13/test_complex.py b/test/dynamo/cpython/3_13/test_complex.py +<<<<<<< HEAD index 6ff1a8ab29d..1572433c5ae 100644 --- a/test/dynamo/cpython/3_13/test_complex.py +++ b/test/dynamo/cpython/3_13/test_complex.py @@ -1,16 +1,147 @@ +======= +index 6ff1a8ab29d..ab5bd3dab62 100644 +--- a/test/dynamo/cpython/3_13/test_complex.py ++++ b/test/dynamo/cpython/3_13/test_complex.py +@@ -1,16 +1,143 @@ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +<<<<<<< HEAD +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_complex.py + +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +import sys +import torch +import torch._dynamo.test_case @@ -19,7 +29,10 @@ index 6ff1a8ab29d..1572433c5ae 100644 +from torch._dynamo.test_case import CPythonTestCase +from torch.testing._internal.common_utils import ( + run_tests, +<<<<<<< HEAD + slowTest, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) + xfailIfTorchDynamo, +) + @@ -43,7 +56,11 @@ index 6ff1a8ab29d..1572433c5ae 100644 + "test.test_iter", + "test.typinganndata.ann_module", ) +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +class RedirectImportFinder(importlib.abc.MetaPathFinder): + def find_spec(self, fullname, path, target=None): + # Check if the import is the problematic one @@ -74,7 +91,11 @@ index 6ff1a8ab29d..1572433c5ae 100644 from math import isnan, copysign +import math import operator +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +VALID_UNDERSCORE_LITERALS = [ + '0_0_0', + '4_2', @@ -155,10 +176,17 @@ index 6ff1a8ab29d..1572433c5ae 100644 INF = float("inf") NAN = float("nan") DBL_MAX = sys.float_info.max +<<<<<<< HEAD @@ -45,7 +176,40 @@ class WithComplex: def __complex__(self): return self.value +======= +@@ -45,7 +172,40 @@ class WithComplex: + def __complex__(self): + return self.value + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase): +class ComplexTest(__TestCase): + @@ -194,6 +222,7 @@ index 6ff1a8ab29d..1572433c5ae 100644 + """ + self.assertFloatIdentical(x.real, y.real) + self.assertFloatIdentical(x.imag, y.imag) +<<<<<<< HEAD def assertAlmostEqual(self, a, b): if isinstance(a, complex): @@ -201,6 +230,15 @@ index 6ff1a8ab29d..1572433c5ae 100644 # check that relative difference < eps self.assertTrue(abs((x-y)/y) < eps) +======= + + def assertAlmostEqual(self, a, b): + if isinstance(a, complex): +@@ -74,6 +234,29 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase): + # check that relative difference < eps + self.assertTrue(abs((x-y)/y) < eps) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) + def assertFloatsAreIdentical(self, x, y): + """assert that floats x and y are identical, in the sense that: + (1) both x and y are nans, or @@ -227,6 +265,7 @@ index 6ff1a8ab29d..1572433c5ae 100644 def assertClose(self, x, y, eps=1e-9): """Return true iff complexes x and y "are close".""" self.assertCloseAbs(x.real, y.real, eps) +<<<<<<< HEAD @@ -93,6 +280,7 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase): q = z.__truediv__(y) self.assertClose(q, x) @@ -323,6 +362,11 @@ index 6ff1a8ab29d..1572433c5ae 100644 @@ -855,4 +1049,4 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase): +======= +@@ -855,4 +1038,4 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase): + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": - unittest.main() + run_tests() diff --git a/test/dynamo/cpython/3_13/test_complex.py b/test/dynamo/cpython/3_13/test_complex.py index 6921c1da6ec4c..140233ddc15b9 100644 --- a/test/dynamo/cpython/3_13/test_complex.py +++ b/test/dynamo/cpython/3_13/test_complex.py @@ -4,9 +4,12 @@ # ruff: noqa # flake8: noqa +<<<<<<< HEAD # Test copied from # https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_complex.py +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import sys import torch import torch._dynamo.test_case @@ -14,7 +17,10 @@ from torch._dynamo.test_case import CPythonTestCase from torch.testing._internal.common_utils import ( run_tests, +<<<<<<< HEAD slowTest, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) xfailIfTorchDynamo, ) @@ -280,7 +286,10 @@ def check_div(self, x, y): q = z.__truediv__(y) self.assertClose(q, x) +<<<<<<< HEAD @slowTest +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_truediv(self): simple_real = [float(i) for i in range(-5, 6)] simple_complex = [complex(x, y) for x in simple_real for y in simple_real] @@ -526,10 +535,14 @@ def test_pow_with_small_integer_exponents(self): def test_boolcontext(self): for i in range(100): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): r1 = random() r2 = random() self.assertTrue(complex(r1 + 1e-6, r2 + 1e-6)) +======= + self.assertTrue(complex(random() + 1e-6, random() + 1e-6)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue(not complex(0.0, 0.0)) self.assertTrue(1j) @@ -622,6 +635,7 @@ def check(z, x, y): self.assertRaises(TypeError, complex, WithComplex(1), object()) self.assertRaises(TypeError, complex, WithComplex(None), object()) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class EvilExc(Exception): pass @@ -629,6 +643,14 @@ class EvilExc(Exception): class evilcomplex: def __complex__(self): raise EvilExc +======= + class EvilExc(Exception): + pass + + class evilcomplex: + def __complex__(self): + raise EvilExc +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertRaises(EvilExc, complex, evilcomplex()) @@ -652,15 +674,22 @@ def __complex__(self): self.assertRaises(TypeError, complex, WithIndex(None), 1.5) self.assertRaises(TypeError, complex, 1.5, WithIndex(None)) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class MyInt: def __int__(self): return 42 +======= + class MyInt: + def __int__(self): + return 42 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertRaises(TypeError, complex, MyInt()) self.assertRaises(TypeError, complex, MyInt(), 1.5) self.assertRaises(TypeError, complex, 1.5, MyInt()) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class complex0(complex): """Test usage of __complex__() when inheriting from 'complex'""" @@ -679,6 +708,25 @@ class complex2(complex): complex is returned""" def __complex__(self): return None +======= + class complex0(complex): + """Test usage of __complex__() when inheriting from 'complex'""" + def __complex__(self): + return 42j + + class complex1(complex): + """Test usage of __complex__() with a __new__() method""" + def __new__(self, value=0j): + return complex.__new__(self, 2*value) + def __complex__(self): + return self + + class complex2(complex): + """Make sure that __complex__() calls fail if anything other than a + complex is returned""" + def __complex__(self): + return None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) check(complex(complex0(1j)), 0.0, 42.0) with self.assertWarns(DeprecationWarning): diff --git a/test/dynamo/cpython/3_13/test_dict.diff b/test/dynamo/cpython/3_13/test_dict.diff index d8e24851409a9..66398efe92f24 100644 --- a/test/dynamo/cpython/3_13/test_dict.diff +++ b/test/dynamo/cpython/3_13/test_dict.diff @@ -1,17 +1,27 @@ diff --git a/test/dynamo/cpython/3_13/test_dict.py b/test/dynamo/cpython/3_13/test_dict.py +<<<<<<< HEAD index 4729132c5a5..6ecf111c1e3 100644 --- a/test/dynamo/cpython/3_13/test_dict.py +++ b/test/dynamo/cpython/3_13/test_dict.py @@ -1,3 +1,60 @@ +======= +index 4729132c5a5..14f829c1715 100644 +--- a/test/dynamo/cpython/3_13/test_dict.py ++++ b/test/dynamo/cpython/3_13/test_dict.py +@@ -1,3 +1,57 @@ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +<<<<<<< HEAD +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_dict.py + +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +import sys +import torch +import torch._dynamo.test_case @@ -63,6 +73,7 @@ index 4729132c5a5..6ecf111c1e3 100644 import collections import collections.abc import gc +<<<<<<< HEAD @@ -11,11 +68,12 @@ from test import support from test.support import import_helper, get_c_recursion_limit @@ -255,10 +266,26 @@ index 4729132c5a5..6ecf111c1e3 100644 self.assertRaises(ValueError, {}.update, [(1, 2, 3)]) +======= +@@ -11,7 +65,7 @@ from test import support + from test.support import import_helper, get_c_recursion_limit + + +-class DictTest(unittest.TestCase): ++class DictTest(__TestCase): + + def test_invalid_keyword_arguments(self): + class Custom(dict): +@@ -265,6 +319,7 @@ class DictTest(unittest.TestCase): + + self.assertRaises(ValueError, {}.update, [(1, 2, 3)]) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) + @unittest.skip("test hangs") def test_fromkeys(self): self.assertEqual(dict.fromkeys('abc'), {'a':None, 'b':None, 'c':None}) d = {} +<<<<<<< HEAD @@ -276,38 +346,43 @@ class DictTest(unittest.TestCase): yield 1 self.assertEqual(d.fromkeys(g()), {1:None}) @@ -441,6 +468,9 @@ index 4729132c5a5..6ecf111c1e3 100644 # 5 items y = {hashed1: 5, 0: 0, 1: 1, 2: 2, 3: 3} @@ -477,7 +559,7 @@ class DictTest(unittest.TestCase): +======= +@@ -477,7 +532,7 @@ class DictTest(unittest.TestCase): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for copymode in -1, +1: # -1: b has same structure as a # +1: b is a.copy() @@ -449,6 +479,7 @@ index 4729132c5a5..6ecf111c1e3 100644 size = 2**log2size a = {} b = {} +<<<<<<< HEAD @@ -517,15 +599,16 @@ class DictTest(unittest.TestCase): self.assertRaises(TypeError, d.pop) @@ -694,6 +725,12 @@ index 4729132c5a5..6ecf111c1e3 100644 pass self._tracked(MyDict()) +======= +@@ -1006,18 +1061,6 @@ class DictTest(unittest.TestCase): + pass + self._tracked(MyDict()) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - @support.cpython_only - def test_track_lazy_instance_dicts(self): - class C: @@ -707,6 +744,7 @@ index 4729132c5a5..6ecf111c1e3 100644 - self._tracked(d) - def make_shared_key_dict(self, n): +<<<<<<< HEAD - class C: - pass + with torch._dynamo.error_on_graph_break(False): @@ -1103,6 +1141,22 @@ index 4729132c5a5..6ecf111c1e3 100644 @@ -1666,4 +1773,4 @@ class SubclassMappingTests(mapping_tests.BasicTestMappingProtocol): +======= + class C: + pass +@@ -1622,7 +1665,7 @@ class DictTest(unittest.TestCase): + self.assertGreaterEqual(eq_count, 1) + + +-class CAPITest(unittest.TestCase): ++class CAPITest(__TestCase): + + # Test _PyDict_GetItem_KnownHash() + @support.cpython_only +@@ -1666,4 +1709,4 @@ class SubclassMappingTests(mapping_tests.BasicTestMappingProtocol): + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": - unittest.main() + run_tests() diff --git a/test/dynamo/cpython/3_13/test_dict.py b/test/dynamo/cpython/3_13/test_dict.py index 4a4f170ad9727..d475cbc712c39 100644 --- a/test/dynamo/cpython/3_13/test_dict.py +++ b/test/dynamo/cpython/3_13/test_dict.py @@ -4,9 +4,12 @@ # ruff: noqa # flake8: noqa +<<<<<<< HEAD # Test copied from # https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_dict.py +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import sys import torch import torch._dynamo.test_case @@ -71,9 +74,14 @@ def find_spec(self, fullname, path, target=None): class DictTest(__TestCase): def test_invalid_keyword_arguments(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Custom(dict): pass +======= + class Custom(dict): + pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for invalid in {1 : 2}, Custom({1 : 2}): with self.assertRaises(TypeError): dict(**invalid) @@ -166,9 +174,14 @@ def test_items(self): def test_views_mapping(self): mappingproxy = type(type.__dict__) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Dict(dict): pass +======= + class Dict(dict): + pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for cls in [dict, Dict]: d = cls() m1 = d.keys().mapping @@ -216,17 +229,26 @@ def test_getitem(self): self.assertRaises(TypeError, d.__getitem__) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class BadEq(object): def __eq__(self, other): raise Exc() def __hash__(self): return 24 +======= + class BadEq(object): + def __eq__(self, other): + raise Exc() + def __hash__(self): + return 24 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) d = {} d[BadEq()] = 42 self.assertRaises(KeyError, d.__getitem__, 23) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Exc(Exception): pass @@ -237,6 +259,17 @@ def __hash__(self): raise Exc() else: return 42 +======= + class Exc(Exception): pass + + class BadHash(object): + fail = False + def __hash__(self): + if self.fail: + raise Exc() + else: + return 42 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x = BadHash() d[x] = 42 @@ -262,6 +295,7 @@ def test_update(self): self.assertRaises((TypeError, AttributeError), d.update, None) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class SimpleUserDict: def __init__(self): @@ -270,10 +304,20 @@ def keys(self): return self.d.keys() def __getitem__(self, i): return self.d[i] +======= + class SimpleUserDict: + def __init__(self): + self.d = {1:1, 2:2, 3:3} + def keys(self): + return self.d.keys() + def __getitem__(self, i): + return self.d[i] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) d.clear() d.update(SimpleUserDict()) self.assertEqual(d, {1:1, 2:2, 3:3}) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Exc(Exception): pass @@ -329,6 +373,56 @@ def __iter__(self): return self def __next__(self): raise Exc() +======= + class Exc(Exception): pass + + d.clear() + class FailingUserDict: + def keys(self): + raise Exc + self.assertRaises(Exc, d.update, FailingUserDict()) + + class FailingUserDict: + def keys(self): + class BogonIter: + def __init__(self): + self.i = 1 + def __iter__(self): + return self + def __next__(self): + if self.i: + self.i = 0 + return 'a' + raise Exc + return BogonIter() + def __getitem__(self, key): + return key + self.assertRaises(Exc, d.update, FailingUserDict()) + + class FailingUserDict: + def keys(self): + class BogonIter: + def __init__(self): + self.i = ord('a') + def __iter__(self): + return self + def __next__(self): + if self.i <= ord('z'): + rtn = chr(self.i) + self.i += 1 + return rtn + raise StopIteration + return BogonIter() + def __getitem__(self, key): + raise Exc + self.assertRaises(Exc, d.update, FailingUserDict()) + + class badseq(object): + def __iter__(self): + return self + def __next__(self): + raise Exc() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertRaises(Exc, {}.update, badseq()) @@ -346,21 +440,32 @@ def g(): yield 1 self.assertEqual(d.fromkeys(g()), {1:None}) self.assertRaises(TypeError, {}.fromkeys, 3) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class dictlike(dict): pass +======= + class dictlike(dict): pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(dictlike.fromkeys('a'), {'a':None}) self.assertEqual(dictlike().fromkeys('a'), {'a':None}) self.assertIsInstance(dictlike.fromkeys('a'), dictlike) self.assertIsInstance(dictlike().fromkeys('a'), dictlike) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class mydict(dict): def __new__(cls): return collections.UserDict() +======= + class mydict(dict): + def __new__(cls): + return collections.UserDict() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ud = mydict.fromkeys('ab') self.assertEqual(ud, {'a':None, 'b':None}) self.assertIsInstance(ud, collections.UserDict) self.assertRaises(TypeError, dict.fromkeys) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Exc(Exception): pass @@ -383,6 +488,27 @@ def __next__(self): class baddict2(dict): def __setitem__(self, key, value): raise Exc() +======= + class Exc(Exception): pass + + class baddict1(dict): + def __init__(self): + raise Exc() + + self.assertRaises(Exc, baddict1.fromkeys, [1]) + + class BadSeq(object): + def __iter__(self): + return self + def __next__(self): + raise Exc() + + self.assertRaises(Exc, dict.fromkeys, BadSeq()) + + class baddict2(dict): + def __setitem__(self, key, value): + raise Exc() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertRaises(Exc, baddict2.fromkeys, [1]) @@ -398,20 +524,32 @@ def __setitem__(self, key, value): self.assertEqual(dict.fromkeys(d, 0), res) # test fast path when object's constructor returns large non-empty dict +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class baddict3(dict): def __new__(cls): return d +======= + class baddict3(dict): + def __new__(cls): + return d +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) d = {i : i for i in range(1000)} res = d.copy() res.update(a=None, b=None, c=None) self.assertEqual(baddict3.fromkeys({"a", "b", "c"}), res) # test slow path when object is a proper subclass of dict +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class baddict4(dict): def __init__(self): dict.__init__(self, d) +======= + class baddict4(dict): + def __init__(self): + dict.__init__(self, d) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) d = {i : i for i in range(1000)} res = d.copy() res.update(a=None, b=None, c=None) @@ -447,9 +585,14 @@ def test_copy_fuzz(self): self.assertEqual(len(d2), len(d) + 1) def test_copy_maintains_tracking(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class A: pass +======= + class A: + pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) key = A() @@ -494,6 +637,7 @@ def test_setdefault(self): self.assertEqual(len(d['key']), 2) self.assertRaises(TypeError, d.setdefault) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Exc(Exception): pass @@ -505,6 +649,17 @@ def __hash__(self): raise Exc() else: return 42 +======= + class Exc(Exception): pass + + class BadHash(object): + fail = False + def __hash__(self): + if self.fail: + raise Exc() + else: + return 42 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x = BadHash() d[x] = 42 @@ -513,6 +668,7 @@ def __hash__(self): def test_setdefault_atomic(self): # Issue #13521: setdefault() calls __hash__ and __eq__ only once. +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Hashed(object): def __init__(self): @@ -524,6 +680,18 @@ def __hash__(self): def __eq__(self, other): self.eq_count += 1 return id(self) == id(other) +======= + class Hashed(object): + def __init__(self): + self.hash_count = 0 + self.eq_count = 0 + def __hash__(self): + self.hash_count += 1 + return 42 + def __eq__(self, other): + self.eq_count += 1 + return id(self) == id(other) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) hashed1 = Hashed() y = {hashed1: 5} hashed2 = Hashed() @@ -533,6 +701,7 @@ def __eq__(self, other): self.assertEqual(hashed1.eq_count + hashed2.eq_count, 1) def test_setitem_atomic_at_resize(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Hashed(object): def __init__(self): @@ -544,6 +713,18 @@ def __hash__(self): def __eq__(self, other): self.eq_count += 1 return id(self) == id(other) +======= + class Hashed(object): + def __init__(self): + self.hash_count = 0 + self.eq_count = 0 + def __hash__(self): + self.hash_count += 1 + return 42 + def __eq__(self, other): + self.eq_count += 1 + return id(self) == id(other) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) hashed1 = Hashed() # 5 items y = {hashed1: 5, 0: 0, 1: 1, 2: 2, 3: 3} @@ -599,6 +780,7 @@ def test_pop(self): self.assertRaises(TypeError, d.pop) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Exc(Exception): pass @@ -609,6 +791,17 @@ def __hash__(self): raise Exc() else: return 42 +======= + class Exc(Exception): pass + + class BadHash(object): + fail = False + def __hash__(self): + if self.fail: + raise Exc() + else: + return 42 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x = BadHash() d[x] = 42 @@ -652,6 +845,7 @@ def test_mutating_iteration_delete_over_items(self): def test_mutating_lookup(self): # changing dict during a lookup (issue #14417) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class NastyKey: mutate_dict = None @@ -669,6 +863,24 @@ def __eq__(self, other): NastyKey.mutate_dict = None del mydict[key] return self.value == other.value +======= + class NastyKey: + mutate_dict = None + + def __init__(self, value): + self.value = value + + def __hash__(self): + # hash collision! + return 1 + + def __eq__(self, other): + if NastyKey.mutate_dict: + mydict, key = NastyKey.mutate_dict + NastyKey.mutate_dict = None + del mydict[key] + return self.value == other.value +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) key1 = NastyKey(1) key2 = NastyKey(2) @@ -686,12 +898,20 @@ def test_repr(self): d[1] = d self.assertEqual(repr(d), '{1: {...}}') +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Exc(Exception): pass class BadRepr(object): def __repr__(self): raise Exc() +======= + class Exc(Exception): pass + + class BadRepr(object): + def __repr__(self): + raise Exc() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) d = {1: BadRepr()} self.assertRaises(Exc, repr, d) @@ -706,6 +926,7 @@ def test_eq(self): self.assertEqual({}, {}) self.assertEqual({1: 2}, {1: 2}) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Exc(Exception): pass @@ -714,6 +935,15 @@ def __eq__(self, other): raise Exc() def __hash__(self): return 1 +======= + class Exc(Exception): pass + + class BadCmp(object): + def __eq__(self, other): + raise Exc() + def __hash__(self): + return 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) d1 = {BadCmp(): 1} d2 = {1: 1} @@ -770,10 +1000,16 @@ def helper_keys_contained(self, fn): self.assertFalse(larger == larger3) def test_errors_in_view_containment_check(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class C: def __eq__(self, other): raise RuntimeError +======= + class C: + def __eq__(self, other): + raise RuntimeError +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) d1 = {1: C()} d2 = {1: C()} @@ -853,10 +1089,16 @@ def test_missing(self): # (E) subclass defines __missing__ method raising RuntimeError # (F) subclass sets __missing__ instance variable (no effect) # (G) subclass doesn't define __missing__ at all +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class D(dict): def __missing__(self, key): return 42 +======= + class D(dict): + def __missing__(self, key): + return 42 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) d = D({1: 2, 3: 4}) self.assertEqual(d[1], 2) self.assertEqual(d[3], 4) @@ -864,28 +1106,46 @@ def __missing__(self, key): self.assertNotIn(2, d.keys()) self.assertEqual(d[2], 42) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class E(dict): def __missing__(self, key): raise RuntimeError(key) +======= + class E(dict): + def __missing__(self, key): + raise RuntimeError(key) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) e = E() with self.assertRaises(RuntimeError) as c: e[42] self.assertEqual(c.exception.args, (42,)) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class F(dict): def __init__(self): # An instance variable __missing__ should have no effect self.__missing__ = lambda key: None +======= + class F(dict): + def __init__(self): + # An instance variable __missing__ should have no effect + self.__missing__ = lambda key: None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f = F() with self.assertRaises(KeyError) as c: f[42] self.assertEqual(c.exception.args, (42,)) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class G(dict): pass +======= + class G(dict): + pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) g = G() with self.assertRaises(KeyError) as c: g[42] @@ -900,6 +1160,7 @@ def test_tuple_keyerror(self): def test_bad_key(self): # Dictionary lookups should fail if __eq__() raises an exception. +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class CustomException(Exception): pass @@ -912,6 +1173,19 @@ def __eq__(self, other): if isinstance(other, self.__class__): raise CustomException return other +======= + class CustomException(Exception): + pass + + class BadDictKey: + def __hash__(self): + return hash(self.__class__) + + def __eq__(self, other): + if isinstance(other, self.__class__): + raise CustomException + return other +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) d = {} x1 = BadDictKey() @@ -947,6 +1221,7 @@ def test_resize2(self): # Another dict resizing bug (SF bug #1456209). # This caused Segmentation faults or Illegal instructions. +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class X(object): def __hash__(self): @@ -955,6 +1230,15 @@ def __eq__(self, other): if resizing: d.clear() return False +======= + class X(object): + def __hash__(self): + return 5 + def __eq__(self, other): + if resizing: + d.clear() + return False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) d = {} resizing = False d[X()] = 1 @@ -977,9 +1261,14 @@ def test_empty_presized_dict_in_freelist(self): def test_container_iterator(self): # Bug #3680: tp_traverse was not implemented for dictiter and # dictview objects. +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class C(object): pass +======= + class C(object): + pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) views = (dict.items, dict.values, dict.keys) for v in views: obj = C() @@ -1032,10 +1321,15 @@ def test_track_literals(self): @support.cpython_only def test_track_dynamic(self): # Test GC-optimization of dynamically-created dicts +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class MyObject(object): pass +======= + class MyObject(object): + pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x, y, z, w, o = 1.5, "a", (1, object()), [], MyObject() d = dict() @@ -1103,9 +1397,14 @@ class MyDict(dict): self._tracked(MyDict()) def make_shared_key_dict(self, n): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class C: pass +======= + class C: + pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dicts = [] for i in range(n): @@ -1194,6 +1493,7 @@ def test_splittable_popitem(self): @support.cpython_only def test_splittable_update(self): """dict.update(other) must preserve order in other.""" +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class C: def __init__(self, order): @@ -1201,6 +1501,14 @@ def __init__(self, order): self.a, self.b, self.c = 1, 2, 3 else: self.c, self.b, self.a = 1, 2, 3 +======= + class C: + def __init__(self, order): + if order: + self.a, self.b, self.c = 1, 2, 3 + else: + self.c, self.b, self.a = 1, 2, 3 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) o = C(True) o = C(False) # o.__dict__ has reversed order. self.assertEqual(list(o.__dict__), ["c", "b", "a"]) @@ -1212,9 +1520,14 @@ def __init__(self, order): @support.cpython_only def test_splittable_to_generic_combinedtable(self): """split table must be correctly resized and converted to generic combined table""" +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class C: pass +======= + class C: + pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) a = C() a.x = 1 @@ -1336,6 +1649,7 @@ def test_reversevaluesiterator_pickling(self): self.assertEqual(sorted(values), sorted(data.values())) def test_instance_dict_getattr_str_subclass(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Foo: def __init__(self, msg): @@ -1344,12 +1658,24 @@ def __init__(self, msg): with torch._dynamo.error_on_graph_break(False): class _str(str): pass +======= + class Foo: + def __init__(self, msg): + self.msg = msg + f = Foo('123') + class _str(str): + pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(f.msg, getattr(f, _str('msg'))) self.assertEqual(f.msg, f.__dict__[_str('msg')]) def test_object_set_item_single_instance_non_str_key(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Foo: pass +======= + class Foo: pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f = Foo() f.__dict__[1] = 1 f.a = 'a' @@ -1359,10 +1685,16 @@ def check_reentrant_insertion(self, mutate): # This object will trigger mutation of the dict when replaced # by another value. Note this relies on refcounting: the test # won't achieve its purpose on fully-GCed Python implementations. +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Mutating: def __del__(self): mutate(d) +======= + class Mutating: + def __del__(self): + mutate(d) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) d = {k: Mutating() for k in 'abcdefghijklmnopqr'} for k in list(d): @@ -1385,6 +1717,7 @@ def mutate(d): self.check_reentrant_insertion(mutate) def test_merge_and_mutate(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class X: def __hash__(self): @@ -1393,6 +1726,15 @@ def __hash__(self): def __eq__(self, o): other.clear() return False +======= + class X: + def __hash__(self): + return 0 + + def __eq__(self, o): + other.clear() + return False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) l = [(i,0) for i in range(1, 1337)] other = dict(l) @@ -1408,6 +1750,7 @@ def test_free_after_iterating(self): def test_equal_operator_modifying_operand(self): # test fix for seg fault reported in bpo-27945 part 3. +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class X(): def __del__(self): @@ -1419,17 +1762,36 @@ def __eq__(self, other): def __hash__(self): return 13 +======= + class X(): + def __del__(self): + dict_b.clear() + + def __eq__(self, other): + dict_a.clear() + return True + + def __hash__(self): + return 13 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dict_a = {X(): 0} dict_b = {X(): X()} self.assertTrue(dict_a == dict_b) # test fix for seg fault reported in bpo-38588 part 1. +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Y: def __eq__(self, other): dict_d.clear() return True +======= + class Y: + def __eq__(self, other): + dict_d.clear() + return True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dict_c = {0: Y()} dict_d = {0: set()} @@ -1437,6 +1799,7 @@ def __eq__(self, other): def test_fromkeys_operator_modifying_dict_operand(self): # test fix for seg fault reported in issue 27945 part 4a. +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class X(int): def __hash__(self): @@ -1446,6 +1809,16 @@ def __eq__(self, other): if len(d) > 1: d.clear() return False +======= + class X(int): + def __hash__(self): + return 13 + + def __eq__(self, other): + if len(d) > 1: + d.clear() + return False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) d = {} # this is required to exist so that d can be constructed! d = {X(1): 1, X(2): 2} @@ -1456,6 +1829,7 @@ def __eq__(self, other): def test_fromkeys_operator_modifying_set_operand(self): # test fix for seg fault reported in issue 27945 part 4b. +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class X(int): def __hash__(self): @@ -1465,6 +1839,16 @@ def __eq__(self, other): if len(d) > 1: d.clear() return False +======= + class X(int): + def __hash__(self): + return 13 + + def __eq__(self, other): + if len(d) > 1: + d.clear() + return False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) d = {} # this is required to exist so that d can be constructed! d = {X(1), X(2)} @@ -1474,17 +1858,25 @@ def __eq__(self, other): pass def test_dictitems_contains_use_after_free(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class X: def __eq__(self, other): d.clear() return NotImplemented +======= + class X: + def __eq__(self, other): + d.clear() + return NotImplemented +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) d = {0: set()} (0, X()) in d.items() def test_dict_contain_use_after_free(self): # bpo-40489 +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class S(str): def __eq__(self, other): @@ -1493,25 +1885,47 @@ def __eq__(self, other): def __hash__(self): return hash('test') +======= + class S(str): + def __eq__(self, other): + d.clear() + return NotImplemented + + def __hash__(self): + return hash('test') +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) d = {S(): 'value'} self.assertFalse('test' in d) def test_init_use_after_free(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class X: def __hash__(self): pair[:] = [] return 13 +======= + class X: + def __hash__(self): + pair[:] = [] + return 13 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pair = [X(), 123] dict([pair]) def test_oob_indexing_dictiter_iternextitem(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class X(int): def __del__(self): d.clear() +======= + class X(int): + def __del__(self): + d.clear() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) d = {i: X(i) for i in range(8)} @@ -1545,11 +1959,18 @@ def test_reverse_iterator_for_empty_dict(self): self.assertEqual(list(reversed(dict().keys())), []) def test_reverse_iterator_for_shared_shared_dicts(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class A: def __init__(self, x, y): if x: self.x = x if y: self.y = y +======= + class A: + def __init__(self, x, y): + if x: self.x = x + if y: self.y = y +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(list(reversed(A(1, 2).__dict__)), ['y', 'x']) self.assertEqual(list(reversed(A(1, 0).__dict__)), ['x']) @@ -1565,15 +1986,21 @@ def test_dict_copy_order(self): self.assertEqual(list(copy.items()), expected) # dict subclass doesn't override __iter__ +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class CustomDict(dict): pass +======= + class CustomDict(dict): + pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pairs = [('a', 1), ('b', 2), ('c', 3)] d = CustomDict(pairs) self.assertEqual(pairs, list(dict(d).items())) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class CustomReversedDict(dict): def keys(self): @@ -1583,6 +2010,16 @@ def keys(self): def items(self): return reversed(dict.items(self)) +======= + class CustomReversedDict(dict): + def keys(self): + return reversed(list(dict.keys(self))) + + __iter__ = keys + + def items(self): + return reversed(dict.items(self)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) d = CustomReversedDict(pairs) self.assertEqual(pairs[::-1], list(dict(d).items())) @@ -1607,6 +2044,7 @@ def test_dict_items_result_gc_reversed(self): self.assertTrue(gc.is_tracked(next(it))) def test_store_evilattr(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class EvilAttr: def __init__(self, d): @@ -1619,6 +2057,19 @@ def __del__(self): class Obj: pass +======= + class EvilAttr: + def __init__(self, d): + self.d = d + + def __del__(self): + if 'attr' in self.d: + del self.d['attr'] + gc.collect() + + class Obj: + pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) obj = Obj() obj.__dict__ = {} @@ -1630,6 +2081,7 @@ def test_str_nonstr(self): # `str` keys. Make sure the unoptimized path is used when a non-`str` # key appears. +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class StrSub(str): pass @@ -1647,6 +2099,23 @@ def __eq__(self, other): eq_count += 1 return True return False +======= + class StrSub(str): + pass + + eq_count = 0 + # This class compares equal to the string 'key3' + class Key3: + def __hash__(self): + return hash('key3') + + def __eq__(self, other): + nonlocal eq_count + if isinstance(other, Key3) or isinstance(other, str) and other == 'key3': + eq_count += 1 + return True + return False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) key3_1 = StrSub('key3') key3_2 = Key3() @@ -1746,6 +2215,7 @@ def test_getitem_knownhash(self): # key does not exist self.assertRaises(KeyError, dict_getitem_knownhash, {}, 1, hash(1)) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Exc(Exception): pass class BadEq: @@ -1753,6 +2223,14 @@ def __eq__(self, other): raise Exc def __hash__(self): return 7 +======= + class Exc(Exception): pass + class BadEq: + def __eq__(self, other): + raise Exc + def __hash__(self): + return 7 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) k1, k2 = BadEq(), BadEq() d = {k1: 1} diff --git a/test/dynamo/cpython/3_13/test_float.diff b/test/dynamo/cpython/3_13/test_float.diff index 3e1d08e8fe60a..56d2213353c37 100644 --- a/test/dynamo/cpython/3_13/test_float.diff +++ b/test/dynamo/cpython/3_13/test_float.diff @@ -1,17 +1,27 @@ diff --git a/test/dynamo/cpython/3_13/test_float.py b/test/dynamo/cpython/3_13/test_float.py +<<<<<<< HEAD index 97f951f1299..da82bd190c3 100644 --- a/test/dynamo/cpython/3_13/test_float.py +++ b/test/dynamo/cpython/3_13/test_float.py @@ -1,3 +1,57 @@ +======= +index 97f951f1299..ce2c46777e0 100644 +--- a/test/dynamo/cpython/3_13/test_float.py ++++ b/test/dynamo/cpython/3_13/test_float.py +@@ -1,3 +1,54 @@ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +<<<<<<< HEAD +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_float.py + +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +import sys +import torch +import torch._dynamo.test_case @@ -60,9 +70,15 @@ index 97f951f1299..da82bd190c3 100644 import fractions import operator import os +<<<<<<< HEAD @@ -8,11 +62,84 @@ import time import unittest +======= +@@ -8,11 +59,84 @@ import time + import unittest + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from test import support -from test.support.testcase import FloatsAreIdenticalMixin -from test.support.numbers import ( @@ -149,6 +165,7 @@ index 97f951f1299..da82bd190c3 100644 + from math import isinf, isnan, copysign, ldexp import math +<<<<<<< HEAD @@ -35,7 +162,7 @@ class FloatSubclass(float): class OtherFloatSubclass(float): @@ -338,12 +355,28 @@ index 97f951f1299..da82bd190c3 100644 self.assertEqual(hash(value), object.__hash__(value)) +======= + +@@ -35,7 +159,7 @@ class FloatSubclass(float): + class OtherFloatSubclass(float): + pass + +-class GeneralFloatCases(unittest.TestCase): ++class GeneralFloatCases(__TestCase): + + def test_float(self): + self.assertEqual(float(3.14), 3.14) +@@ -620,7 +744,7 @@ class GeneralFloatCases(unittest.TestCase): + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipUnless(hasattr(float, "__getformat__"), "requires __getformat__") -class FormatFunctionsTestCase(unittest.TestCase): +class FormatFunctionsTestCase(__TestCase): def test_getformat(self): self.assertIn(float.__getformat__('double'), ['unknown', 'IEEE, big-endian', 'IEEE, little-endian']) +<<<<<<< HEAD @@ -645,7 +782,7 @@ LE_FLOAT_NAN = bytes(reversed(BE_FLOAT_NAN)) # is accident (today). # let's also try to guarantee that -0.0 and 0.0 don't get confused. @@ -366,14 +399,44 @@ index 97f951f1299..da82bd190c3 100644 self.assertEqual(format(-123.34, '00.10e'), '-1.2334000000e+02') self.assertEqual(format(-123.34, '00.10g'), '-123.34') +======= +@@ -645,7 +769,7 @@ LE_FLOAT_NAN = bytes(reversed(BE_FLOAT_NAN)) + # is accident (today). + # let's also try to guarantee that -0.0 and 0.0 don't get confused. + +-class IEEEFormatTestCase(unittest.TestCase): ++class IEEEFormatTestCase(__TestCase): + + @support.requires_IEEE_754 + def test_double_specials_do_unpack(self): +@@ -670,7 +794,7 @@ class IEEEFormatTestCase(unittest.TestCase): + self.assertEqual(struct.pack(">>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class ReprTestCase(unittest.TestCase): +class ReprTestCase(__TestCase): def test_repr(self): with open(os.path.join(os.path.split(__file__)[0], 'mathdata', +<<<<<<< HEAD @@ -832,7 +969,29 @@ class ReprTestCase(unittest.TestCase): self.assertEqual(repr(float(negs)), str(float(negs))) +======= +@@ -832,7 +956,29 @@ class ReprTestCase(unittest.TestCase): + self.assertEqual(repr(float(negs)), str(float(negs))) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @support.requires_IEEE_754 -class RoundTestCase(unittest.TestCase, FloatsAreIdenticalMixin): +class RoundTestCase(__TestCase): @@ -399,11 +462,19 @@ index 97f951f1299..da82bd190c3 100644 + else: + msg += ': zeros have different signs' + self.fail(msg.format(x, y)) +<<<<<<< HEAD def test_inf_nan(self): self.assertRaises(OverflowError, round, INF) @@ -955,7 +1114,7 @@ class RoundTestCase(unittest.TestCase, FloatsAreIdenticalMixin): +======= + + def test_inf_nan(self): + self.assertRaises(OverflowError, round, INF) +@@ -955,7 +1101,7 @@ class RoundTestCase(unittest.TestCase, FloatsAreIdenticalMixin): + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Beginning with Python 2.6 float has cross platform compatible # ways to create and represent inf and nan -class InfNanTest(unittest.TestCase): @@ -411,8 +482,13 @@ index 97f951f1299..da82bd190c3 100644 def test_inf_from_str(self): self.assertTrue(isinf(float("inf"))) self.assertTrue(isinf(float("+inf"))) +<<<<<<< HEAD @@ -1056,12 +1215,35 @@ class InfNanTest(unittest.TestCase): +======= +@@ -1056,12 +1202,35 @@ class InfNanTest(unittest.TestCase): + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fromHex = float.fromhex toHex = float.hex -class HexFloatTestCase(FloatsAreIdenticalMixin, unittest.TestCase): @@ -421,7 +497,11 @@ index 97f951f1299..da82bd190c3 100644 MIN = fromHex('0x1p-1022') # min normal TINY = fromHex('0x0.0000000000001p-1022') # min subnormal EPS = fromHex('0x0.0000000000001p0') # diff between 1.0 and next float up +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) + def assertFloatsAreIdentical(self, x, y): + """assert that floats x and y are identical, in the sense that: + (1) both x and y are nans, or @@ -447,6 +527,7 @@ index 97f951f1299..da82bd190c3 100644 + def identical(self, x, y): self.assertFloatsAreIdentical(x, y) +<<<<<<< HEAD @@ -1482,17 +1664,19 @@ class HexFloatTestCase(FloatsAreIdenticalMixin, unittest.TestCase): self.identical(x, fromHex(toHex(x))) @@ -478,6 +559,13 @@ index 97f951f1299..da82bd190c3 100644 self.assertEqual(getattr(f, 'foo', 'none'), 'bar') +======= + +@@ -1500,5 +1669,5 @@ class HexFloatTestCase(FloatsAreIdenticalMixin, unittest.TestCase): + self.assertEqual(getattr(f, 'foo', 'none'), 'bar') + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -if __name__ == '__main__': - unittest.main() +if __name__ == "__main__": diff --git a/test/dynamo/cpython/3_13/test_float.py b/test/dynamo/cpython/3_13/test_float.py index efc387023a4ae..da0da62739f34 100644 --- a/test/dynamo/cpython/3_13/test_float.py +++ b/test/dynamo/cpython/3_13/test_float.py @@ -4,9 +4,12 @@ # ruff: noqa # flake8: noqa +<<<<<<< HEAD # Test copied from # https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_float.py +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import sys import torch import torch._dynamo.test_case @@ -222,10 +225,16 @@ def test_underscores(self): def test_non_numeric_input_types(self): # Test possible non-numeric types for the argument x, including # subclasses of the explicitly documented accepted types. +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class CustomStr(str): pass class CustomBytes(bytes): pass class CustomByteArray(bytearray): pass +======= + class CustomStr(str): pass + class CustomBytes(bytes): pass + class CustomByteArray(bytearray): pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) factories = [ bytes, @@ -312,6 +321,7 @@ def test_float_with_comma(self): def test_floatconversion(self): # Make sure that calls to __float__() work properly +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Foo1(object): def __float__(self): @@ -337,6 +347,32 @@ def __float__(self): class FooStr(str): def __float__(self): return float(str(self)) + 1 +======= + class Foo1(object): + def __float__(self): + return 42. + + class Foo2(float): + def __float__(self): + return 42. + + class Foo3(float): + def __new__(cls, value=0.): + return float.__new__(cls, 2*value) + + def __float__(self): + return self + + class Foo4(float): + def __float__(self): + return 42 + + # Issue 5759: __float__ not called on str subclasses (though it is on + # unicode subclasses). + class FooStr(str): + def __float__(self): + return float(str(self)) + 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(float(Foo1()), 42.) self.assertEqual(float(Foo2()), 42.) @@ -345,6 +381,7 @@ def __float__(self): self.assertRaises(TypeError, float, Foo4(42)) self.assertEqual(float(FooStr('8')), 9.) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Foo5: def __float__(self): @@ -356,6 +393,17 @@ def __float__(self): class F: def __float__(self): return OtherFloatSubclass(42.) +======= + class Foo5: + def __float__(self): + return "" + self.assertRaises(TypeError, time.sleep, Foo5()) + + # Issue #24731 + class F: + def __float__(self): + return OtherFloatSubclass(42.) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with self.assertWarns(DeprecationWarning): self.assertEqual(float(F()), 42.) with self.assertWarns(DeprecationWarning): @@ -365,20 +413,34 @@ def __float__(self): with self.assertWarns(DeprecationWarning): self.assertIs(type(FloatSubclass(F())), FloatSubclass) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class MyIndex: def __init__(self, value): self.value = value def __index__(self): return self.value +======= + class MyIndex: + def __init__(self, value): + self.value = value + def __index__(self): + return self.value +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(float(MyIndex(42)), 42.0) self.assertRaises(OverflowError, float, MyIndex(2**2000)) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class MyInt: def __int__(self): return 42 +======= + class MyInt: + def __int__(self): + return 42 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertRaises(TypeError, float, MyInt()) @@ -387,30 +449,49 @@ def test_keyword_args(self): float(x='3.14') def test_keywords_in_subclass(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class subclass(float): pass +======= + class subclass(float): + pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) u = subclass(2.5) self.assertIs(type(u), subclass) self.assertEqual(float(u), 2.5) with self.assertRaises(TypeError): subclass(x=0) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class subclass_with_init(float): def __init__(self, arg, newarg=None): self.newarg = newarg +======= + class subclass_with_init(float): + def __init__(self, arg, newarg=None): + self.newarg = newarg +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) u = subclass_with_init(2.5, newarg=3) self.assertIs(type(u), subclass_with_init) self.assertEqual(float(u), 2.5) self.assertEqual(u.newarg, 3) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class subclass_with_new(float): def __new__(cls, arg, newarg=None): self = super().__new__(cls, arg) self.newarg = newarg return self +======= + class subclass_with_new(float): + def __new__(cls, arg, newarg=None): + self = super().__new__(cls, arg) + self.newarg = newarg + return self +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) u = subclass_with_new(2.5, newarg=3) self.assertIs(type(u), subclass_with_new) self.assertEqual(float(u), 2.5) @@ -746,12 +827,20 @@ def test_hash(self): def test_hash_nan(self): value = float('nan') self.assertEqual(hash(value), object.__hash__(value)) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class H: def __hash__(self): return 42 class F(float, H): pass +======= + class H: + def __hash__(self): + return 42 + class F(float, H): + pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) value = F('nan') self.assertEqual(hash(value), object.__hash__(value)) @@ -1664,19 +1753,31 @@ def roundtrip(x): self.identical(x, fromHex(toHex(x))) def test_subclass(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class F(float): def __new__(cls, value): return float.__new__(cls, value + 1) +======= + class F(float): + def __new__(cls, value): + return float.__new__(cls, value + 1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f = F.fromhex((1.5).hex()) self.assertIs(type(f), F) self.assertEqual(f, 2.5) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class F2(float): def __init__(self, value): self.foo = 'bar' +======= + class F2(float): + def __init__(self, value): + self.foo = 'bar' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f = F2.fromhex((1.5).hex()) self.assertIs(type(f), F2) diff --git a/test/dynamo/cpython/3_13/test_int.diff b/test/dynamo/cpython/3_13/test_int.diff index 20ab3ed2f58bf..829db2c981223 100644 --- a/test/dynamo/cpython/3_13/test_int.diff +++ b/test/dynamo/cpython/3_13/test_int.diff @@ -1,17 +1,27 @@ diff --git a/test/dynamo/cpython/3_13/test_int.py b/test/dynamo/cpython/3_13/test_int.py +<<<<<<< HEAD index 48825f46911..731680d82a0 100644 --- a/test/dynamo/cpython/3_13/test_int.py +++ b/test/dynamo/cpython/3_13/test_int.py @@ -1,13 +1,140 @@ +======= +index 48825f46911..ac7aeacbc01 100644 +--- a/test/dynamo/cpython/3_13/test_int.py ++++ b/test/dynamo/cpython/3_13/test_int.py +@@ -1,13 +1,137 @@ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +<<<<<<< HEAD +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_int.py + +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +import sys +import torch +import torch._dynamo.test_case @@ -59,7 +69,11 @@ index 48825f46911..731680d82a0 100644 + import sys import time +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import unittest from unittest import mock from test import support @@ -144,6 +158,7 @@ index 48825f46911..731680d82a0 100644 + '(1+1.5_j_)', + '(1+1.5_j)', +] +<<<<<<< HEAD try: import _pylong @@ -454,14 +469,48 @@ index 48825f46911..731680d82a0 100644 int_class = IntSubclass +======= + + try: + import _pylong +@@ -38,7 +162,7 @@ L = [ + class IntSubclass(int): + pass + +-class IntTestCases(unittest.TestCase): ++class IntTestCases(__TestCase): + + def test_basic(self): + self.assertEqual(int(314), 314) +@@ -607,7 +731,7 @@ class IntTestCases(unittest.TestCase): + self.assertEqual(int('1_2_3_4_5_6_7', 32), 1144132807) + + +-class IntStrDigitLimitsTests(unittest.TestCase): ++class IntStrDigitLimitsTests(__TestCase): + + int_class = int # Override this in subclasses to reuse the suite. + +@@ -818,7 +942,7 @@ class IntSubclassStrDigitLimitsTests(IntStrDigitLimitsTests): + int_class = IntSubclass + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class PyLongModuleTests(unittest.TestCase): +class PyLongModuleTests(__TestCase): # Tests of the functions in _pylong.py. Those get used when the # number of digits in the input values are large enough. +<<<<<<< HEAD @@ -922,4 +1068,4 @@ class PyLongModuleTests(unittest.TestCase): bits <<= 1 +======= + +@@ -922,4 +1046,4 @@ class PyLongModuleTests(unittest.TestCase): + bits <<= 1 + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": - unittest.main() + run_tests() diff --git a/test/dynamo/cpython/3_13/test_int.py b/test/dynamo/cpython/3_13/test_int.py index b0f8fe49d1b94..bacff85622251 100644 --- a/test/dynamo/cpython/3_13/test_int.py +++ b/test/dynamo/cpython/3_13/test_int.py @@ -4,15 +4,22 @@ # ruff: noqa # flake8: noqa +<<<<<<< HEAD # Test copied from # https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_int.py +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import sys import torch import torch._dynamo.test_case import unittest from torch._dynamo.test_case import CPythonTestCase +<<<<<<< HEAD from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo +======= +from torch.testing._internal.common_utils import run_tests +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __TestCase = CPythonTestCase @@ -436,6 +443,7 @@ def test_int_base_bad_types(self): int('0', 5.0) def test_int_base_indexable(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): with torch._dynamo.error_on_graph_break(False): class MyIndexable(object): @@ -443,6 +451,13 @@ def __init__(self, value): self.value = value def __index__(self): return self.value +======= + class MyIndexable(object): + def __init__(self, value): + self.value = value + def __index__(self): + return self.value +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Check out of range bases. for base in 2**100, -2**100, 1, 37: @@ -457,11 +472,17 @@ def __index__(self): def test_non_numeric_input_types(self): # Test possible non-numeric types for the argument x, including # subclasses of the explicitly documented accepted types. +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class CustomStr(str): pass class CustomBytes(bytes): pass class CustomByteArray(bytearray): pass +======= + class CustomStr(str): pass + class CustomBytes(bytes): pass + class CustomByteArray(bytearray): pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) factories = [ bytes, @@ -503,6 +524,7 @@ def test_string_float(self): def test_intconversion(self): # Test __int__() +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class ClassicMissingMethods: pass @@ -543,11 +565,47 @@ def __trunc__(self): class ExceptionalTrunc(base): def __trunc__(self): 1 / 0 +======= + class ClassicMissingMethods: + pass + self.assertRaises(TypeError, int, ClassicMissingMethods()) + + class MissingMethods(object): + pass + self.assertRaises(TypeError, int, MissingMethods()) + + class Foo0: + def __int__(self): + return 42 + + self.assertEqual(int(Foo0()), 42) + + class Classic: + pass + for base in (object, Classic): + class IntOverridesTrunc(base): + def __int__(self): + return 42 + def __trunc__(self): + return -12 + self.assertEqual(int(IntOverridesTrunc()), 42) + + class JustTrunc(base): + def __trunc__(self): + return 42 + with self.assertWarns(DeprecationWarning): + self.assertEqual(int(JustTrunc()), 42) + + class ExceptionalTrunc(base): + def __trunc__(self): + 1 / 0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with self.assertRaises(ZeroDivisionError), \ self.assertWarns(DeprecationWarning): int(ExceptionalTrunc()) for trunc_result_base in (object, Classic): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Index(trunc_result_base): def __index__(self): @@ -579,6 +637,36 @@ def __trunc__(self): class TruncReturnsNonIntegral(base): def __trunc__(self): return NonIntegral() +======= + class Index(trunc_result_base): + def __index__(self): + return 42 + + class TruncReturnsNonInt(base): + def __trunc__(self): + return Index() + with self.assertWarns(DeprecationWarning): + self.assertEqual(int(TruncReturnsNonInt()), 42) + + class Intable(trunc_result_base): + def __int__(self): + return 42 + + class TruncReturnsNonIndex(base): + def __trunc__(self): + return Intable() + with self.assertWarns(DeprecationWarning): + self.assertEqual(int(TruncReturnsNonInt()), 42) + + class NonIntegral(trunc_result_base): + def __trunc__(self): + # Check that we avoid infinite recursion. + return NonIntegral() + + class TruncReturnsNonIntegral(base): + def __trunc__(self): + return NonIntegral() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: with self.assertWarns(DeprecationWarning): int(TruncReturnsNonIntegral()) @@ -590,6 +678,7 @@ def __trunc__(self): self.fail("Failed to raise TypeError with %s" % ((base, trunc_result_base),)) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): # Regression test for bugs.python.org/issue16060. class BadInt(trunc_result_base): @@ -599,12 +688,23 @@ def __int__(self): class TruncReturnsBadInt(base): def __trunc__(self): return BadInt() +======= + # Regression test for bugs.python.org/issue16060. + class BadInt(trunc_result_base): + def __int__(self): + return 42.0 + + class TruncReturnsBadInt(base): + def __trunc__(self): + return BadInt() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with self.assertRaises(TypeError), \ self.assertWarns(DeprecationWarning): int(TruncReturnsBadInt()) def test_int_subclass_with_index(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class MyIndex(int): def __index__(self): @@ -613,6 +713,15 @@ def __index__(self): class BadIndex(int): def __index__(self): return 42.0 +======= + class MyIndex(int): + def __index__(self): + return 42 + + class BadIndex(int): + def __index__(self): + return 42.0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) my_int = MyIndex(7) self.assertEqual(my_int, 7) @@ -621,6 +730,7 @@ def __index__(self): self.assertEqual(int(BadIndex()), 0) def test_int_subclass_with_int(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class MyInt(int): def __int__(self): @@ -629,6 +739,15 @@ def __int__(self): class BadInt(int): def __int__(self): return 42.0 +======= + class MyInt(int): + def __int__(self): + return 42 + + class BadInt(int): + def __int__(self): + return 42.0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) my_int = MyInt(7) self.assertEqual(my_int, 7) @@ -639,6 +758,7 @@ def __int__(self): self.assertRaises(TypeError, int, my_int) def test_int_returns_int_subclass(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class BadIndex: def __index__(self): @@ -667,6 +787,35 @@ def __trunc__(self): class TruncReturnsIntSubclass: def __trunc__(self): return True +======= + class BadIndex: + def __index__(self): + return True + + class BadIndex2(int): + def __index__(self): + return True + + class BadInt: + def __int__(self): + return True + + class BadInt2(int): + def __int__(self): + return True + + class TruncReturnsBadIndex: + def __trunc__(self): + return BadIndex() + + class TruncReturnsBadInt: + def __trunc__(self): + return BadInt() + + class TruncReturnsIntSubclass: + def __trunc__(self): + return True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bad_int = BadIndex() with self.assertWarns(DeprecationWarning): @@ -711,7 +860,10 @@ def __trunc__(self): self.assertEqual(n, 1) self.assertIs(type(n), IntSubclass) +<<<<<<< HEAD @skipIfTorchDynamo("flaky under dynamo") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_error_message(self): def check(s, base=None): with self.assertRaises(ValueError, diff --git a/test/dynamo/cpython/3_13/test_int_literal.diff b/test/dynamo/cpython/3_13/test_int_literal.diff index 65d7645590431..e569ad0831b30 100644 --- a/test/dynamo/cpython/3_13/test_int_literal.diff +++ b/test/dynamo/cpython/3_13/test_int_literal.diff @@ -1,17 +1,27 @@ diff --git a/test/dynamo/cpython/3_13/test_int_literal.py b/test/dynamo/cpython/3_13/test_int_literal.py +<<<<<<< HEAD index bf725710d55..311b8713a36 100644 --- a/test/dynamo/cpython/3_13/test_int_literal.py +++ b/test/dynamo/cpython/3_13/test_int_literal.py @@ -1,3 +1,57 @@ +======= +index bf725710d55..831d03666fb 100644 +--- a/test/dynamo/cpython/3_13/test_int_literal.py ++++ b/test/dynamo/cpython/3_13/test_int_literal.py +@@ -1,3 +1,54 @@ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +<<<<<<< HEAD +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_int_literal.py + +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +import sys +import torch +import torch._dynamo.test_case @@ -60,7 +70,11 @@ index bf725710d55..311b8713a36 100644 """Test correct treatment of hex/oct constants. This is complex because of changes due to PEP 237. +<<<<<<< HEAD @@ -5,7 +59,7 @@ This is complex because of changes due to PEP 237. +======= +@@ -5,7 +56,7 @@ This is complex because of changes due to PEP 237. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import unittest @@ -69,7 +83,11 @@ index bf725710d55..311b8713a36 100644 def test_hex_baseline(self): # A few upper/lowercase tests +<<<<<<< HEAD @@ -140,4 +194,4 @@ class TestHexOctBin(unittest.TestCase): +======= +@@ -140,4 +191,4 @@ class TestHexOctBin(unittest.TestCase): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(-0b1111111111111111111111111111111111111111111111111111111111111111, -18446744073709551615) if __name__ == "__main__": diff --git a/test/dynamo/cpython/3_13/test_int_literal.py b/test/dynamo/cpython/3_13/test_int_literal.py index 311b8713a36cc..ac0bdece1b96d 100644 --- a/test/dynamo/cpython/3_13/test_int_literal.py +++ b/test/dynamo/cpython/3_13/test_int_literal.py @@ -4,9 +4,12 @@ # ruff: noqa # flake8: noqa +<<<<<<< HEAD # Test copied from # https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_int_literal.py +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import sys import torch import torch._dynamo.test_case diff --git a/test/dynamo/cpython/3_13/test_iter.diff b/test/dynamo/cpython/3_13/test_iter.diff index 18bdcdfb3df82..20b257fcd6823 100644 --- a/test/dynamo/cpython/3_13/test_iter.diff +++ b/test/dynamo/cpython/3_13/test_iter.diff @@ -1,17 +1,27 @@ diff --git a/test/dynamo/cpython/3_13/test_iter.py b/test/dynamo/cpython/3_13/test_iter.py +<<<<<<< HEAD index 1b9f3cf7624..6560c7423a6 100644 --- a/test/dynamo/cpython/3_13/test_iter.py +++ b/test/dynamo/cpython/3_13/test_iter.py @@ -1,3 +1,60 @@ +======= +index 1b9f3cf7624..d0c68f4314c 100644 +--- a/test/dynamo/cpython/3_13/test_iter.py ++++ b/test/dynamo/cpython/3_13/test_iter.py +@@ -1,3 +1,57 @@ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +<<<<<<< HEAD +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_iter.py + +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +import sys +import torch +import torch._dynamo.test_case @@ -61,6 +71,7 @@ index 1b9f3cf7624..6560c7423a6 100644 +# ======= END DYNAMO PATCH ======= + # Test iterators. +<<<<<<< HEAD import sys @@ -104,12 +161,10 @@ class EmptyIterClass: @@ -231,11 +242,28 @@ index 1b9f3cf7624..6560c7423a6 100644 @@ -635,6 +694,7 @@ class TestCase(unittest.TestCase): pass +======= + + import sys +@@ -104,7 +158,7 @@ class EmptyIterClass: + + # Main test suite + +-class TestCase(unittest.TestCase): ++class TestCase(__TestCase): + + # Helper to check that an iterator returns a given sequence + def check_iterator(self, it, seq, pickle=True): +@@ -635,6 +689,7 @@ class TestCase(unittest.TestCase): + pass + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Test zip()'s use of iterators. + @skipIfTorchDynamo("infinite loop") def test_builtin_zip(self): self.assertEqual(list(zip()), []) self.assertEqual(list(zip(*[])), []) +<<<<<<< HEAD @@ -653,17 +713,18 @@ class TestCase(unittest.TestCase): self.assertEqual(list(d.items()), list(zip(d, d.values()))) @@ -429,6 +457,11 @@ index 1b9f3cf7624..6560c7423a6 100644 @@ -1187,4 +1253,4 @@ class TestCase(unittest.TestCase): +======= +@@ -1187,4 +1242,4 @@ class TestCase(unittest.TestCase): + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": - unittest.main() + run_tests() diff --git a/test/dynamo/cpython/3_13/test_iter.py b/test/dynamo/cpython/3_13/test_iter.py index 8e6240d99ce6d..a66134a489f9e 100644 --- a/test/dynamo/cpython/3_13/test_iter.py +++ b/test/dynamo/cpython/3_13/test_iter.py @@ -4,9 +4,12 @@ # ruff: noqa # flake8: noqa +<<<<<<< HEAD # Test copied from # https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_iter.py +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import sys import torch import torch._dynamo.test_case @@ -165,6 +168,11 @@ class TestCase(__TestCase): # Helper to check that an iterator returns a given sequence def check_iterator(self, it, seq, pickle=True): +<<<<<<< HEAD +======= + if pickle: + self.check_pickle(it, seq) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) res = [] while 1: try: @@ -176,6 +184,11 @@ def check_iterator(self, it, seq, pickle=True): # Helper to check that a for loop generates a given sequence def check_for_loop(self, expr, seq, pickle=True): +<<<<<<< HEAD +======= + if pickle: + self.check_pickle(iter(expr), seq) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) res = [] for val in expr: res.append(val) @@ -314,6 +327,7 @@ def test_reduce_mutating_builtins_iter(self): def run(builtin_name, item, sentinel=None): it = iter(item) if sentinel is None else iter(item, sentinel) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class CustomStr: def __init__(self, name, iterator): @@ -328,6 +342,21 @@ def __eq__(self, other): # the pointers after this call list(self.iterator) return other == self.name +======= + class CustomStr: + def __init__(self, name, iterator): + self.name = name + self.iterator = iterator + def __hash__(self): + return hash(self.name) + def __eq__(self, other): + # Here we exhaust our iterator, possibly changing + # its `it_seq` pointer to NULL + # The `__reduce__` call should correctly get + # the pointers after this call + list(self.iterator) + return other == self.name +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # del is required here # to not prematurely call __eq__ from @@ -377,10 +406,16 @@ def __eq__(self, other): # Test a new_style class with __iter__ but no next() method def test_new_style_iter_class(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class IterClass(object): def __iter__(self): return self +======= + class IterClass(object): + def __iter__(self): + return self +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertRaises(TypeError, iter, IterClass()) # Test two-argument iter() with callable instance @@ -449,12 +484,20 @@ def spam(state=[0]): # Test exception propagation through sequence iterator def test_exception_sequence(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class MySequenceClass(SequenceClass): def __getitem__(self, i): if i == 10: raise RuntimeError return SequenceClass.__getitem__(self, i) +======= + class MySequenceClass(SequenceClass): + def __getitem__(self, i): + if i == 10: + raise RuntimeError + return SequenceClass.__getitem__(self, i) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) res = [] try: for x in MySequenceClass(20): @@ -466,12 +509,20 @@ def __getitem__(self, i): # Test for StopIteration from __getitem__ def test_stop_sequence(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class MySequenceClass(SequenceClass): def __getitem__(self, i): if i == 10: raise StopIteration return SequenceClass.__getitem__(self, i) +======= + class MySequenceClass(SequenceClass): + def __getitem__(self, i): + if i == 10: + raise StopIteration + return SequenceClass.__getitem__(self, i) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.check_for_loop(MySequenceClass(20), list(range(10)), pickle=False) # Test a big range @@ -598,6 +649,7 @@ def test_builtin_filter(self): self.assertRaises(TypeError, filter, None, list) self.assertRaises(TypeError, filter, None, 42) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Boolean: def __init__(self, truth): @@ -626,6 +678,34 @@ def __next__(self): else: raise StopIteration return SeqIter(self.vals) +======= + class Boolean: + def __init__(self, truth): + self.truth = truth + def __bool__(self): + return self.truth + bTrue = Boolean(True) + bFalse = Boolean(False) + + class Seq: + def __init__(self, *args): + self.vals = args + def __iter__(self): + class SeqIter: + def __init__(self, vals): + self.vals = vals + self.i = 0 + def __iter__(self): + return self + def __next__(self): + i = self.i + self.i = i + 1 + if i < len(self.vals): + return self.vals[i] + else: + raise StopIteration + return SeqIter(self.vals) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) seq = Seq(*([bTrue, bFalse] * 25)) self.assertEqual(list(filter(lambda x: not x, seq)), [bFalse]*25) @@ -713,6 +793,7 @@ def test_builtin_zip(self): self.assertEqual(list(d.items()), list(zip(d, d.values()))) # Generate all ints starting at constructor arg. +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class IntsFrom: def __init__(self, start): @@ -725,6 +806,19 @@ def __next__(self): i = self.i self.i = i+1 return i +======= + class IntsFrom: + def __init__(self, start): + self.i = start + + def __iter__(self): + return self + + def __next__(self): + i = self.i + self.i = i+1 + return i +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f = open(TESTFN, "w", encoding="utf-8") try: @@ -747,6 +841,7 @@ def __next__(self): self.assertEqual(list(zip(range(5))), [(i,) for i in range(5)]) # Classes that lie about their lengths. +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class NoGuessLen5: def __getitem__(self, i): @@ -761,6 +856,21 @@ def __len__(self): class Guess30Len5(NoGuessLen5): def __len__(self): return 30 +======= + class NoGuessLen5: + def __getitem__(self, i): + if i >= 5: + raise IndexError + return i + + class Guess3Len5(NoGuessLen5): + def __len__(self): + return 3 + + class Guess30Len5(NoGuessLen5): + def __len__(self): + return 30 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def lzip(*args): return list(zip(*args)) @@ -780,6 +890,7 @@ def test_unicode_join_endcase(self): # This class inserts a Unicode object into its argument's natural # iteration, in the 3rd position. +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class OhPhooey: def __init__(self, seq): @@ -795,6 +906,22 @@ def __next__(self): if i == 2: return "fooled you!" return next(self.it) +======= + class OhPhooey: + def __init__(self, seq): + self.it = iter(seq) + self.i = 0 + + def __iter__(self): + return self + + def __next__(self): + i = self.i + self.i = i+1 + if i == 2: + return "fooled you!" + return next(self.it) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f = open(TESTFN, "w", encoding="utf-8") try: @@ -958,6 +1085,7 @@ def test_writelines(self): f.writelines({}) # Try a big chunk too. +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Iterator: def __init__(self, start, finish): @@ -982,6 +1110,31 @@ def __init__(self, start, finish): def __iter__(self): return Iterator(self.start, self.finish) +======= + class Iterator: + def __init__(self, start, finish): + self.start = start + self.finish = finish + self.i = self.start + + def __next__(self): + if self.i >= self.finish: + raise StopIteration + result = str(self.i) + '\n' + self.i += 1 + return result + + def __iter__(self): + return self + + class Whatever: + def __init__(self, start, finish): + self.start = start + self.finish = finish + + def __iter__(self): + return Iterator(self.start, self.finish) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f.writelines(Whatever(6, 6+2000)) f.close() @@ -1054,6 +1207,7 @@ def test_unpack_iter(self): @cpython_only def test_ref_counting_behavior(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class C(object): count = 0 @@ -1064,6 +1218,17 @@ def __del__(self): cls = self.__class__ assert cls.count > 0 cls.count -= 1 +======= + class C(object): + count = 0 + def __new__(cls): + cls.count += 1 + return object.__new__(cls) + def __del__(self): + cls = self.__class__ + assert cls.count > 0 + cls.count -= 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x = C() self.assertEqual(C.count, 1) del x @@ -1154,6 +1319,7 @@ def test_sinkstate_enumerate(self): def test_3720(self): # Avoid a crash, when an iterator deletes its next() method. +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class BadIterator(object): def __iter__(self): @@ -1161,6 +1327,14 @@ def __iter__(self): def __next__(self): del BadIterator.__next__ return 1 +======= + class BadIterator(object): + def __iter__(self): + return self + def __next__(self): + del BadIterator.__next__ + return 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: for i in BadIterator() : diff --git a/test/dynamo/cpython/3_13/test_list.diff b/test/dynamo/cpython/3_13/test_list.diff index 7b0a90735d87c..d0e489c36bc53 100644 --- a/test/dynamo/cpython/3_13/test_list.diff +++ b/test/dynamo/cpython/3_13/test_list.diff @@ -1,17 +1,27 @@ diff --git a/test/dynamo/cpython/3_13/test_list.py b/test/dynamo/cpython/3_13/test_list.py +<<<<<<< HEAD index 23ef902aa0b..b9afb1ef26e 100644 --- a/test/dynamo/cpython/3_13/test_list.py +++ b/test/dynamo/cpython/3_13/test_list.py @@ -1,6 +1,60 @@ +======= +index 23ef902aa0b..30e69ff75bd 100644 +--- a/test/dynamo/cpython/3_13/test_list.py ++++ b/test/dynamo/cpython/3_13/test_list.py +@@ -1,6 +1,57 @@ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +<<<<<<< HEAD +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_list.py + +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +import sys +import torch +import torch._dynamo.test_case @@ -64,6 +74,7 @@ index 23ef902aa0b..b9afb1ef26e 100644 from test.support import cpython_only from test.support.script_helper import assert_python_ok import pickle +<<<<<<< HEAD @@ -36,7 +90,7 @@ class ListTest(list_tests.CommonTest): # earlier due to a newlib bug. See the following mailing list # thread for the details: @@ -255,13 +266,25 @@ index 23ef902aa0b..b9afb1ef26e 100644 a.append(4) self.assertEqual(list(it), []) +======= +@@ -324,6 +375,7 @@ class ListTest(list_tests.CommonTest): + a.append(4) + self.assertEqual(list(it), []) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) + @unittest.skip("Fails on python <=3.13.2 and passes on >=3.13.3") def test_deopt_from_append_list(self): # gh-132011: it used to crash, because # of `CALL_LIST_APPEND` specialization failure. +<<<<<<< HEAD @@ -345,4 +410,4 @@ class ListTest(list_tests.CommonTest): self.assertEqual(rc, 0) +======= +@@ -345,4 +397,4 @@ class ListTest(list_tests.CommonTest): + self.assertEqual(rc, 0) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": - unittest.main() + run_tests() diff --git a/test/dynamo/cpython/3_13/test_list.py b/test/dynamo/cpython/3_13/test_list.py index 7f91b7b840804..ab76e949e9395 100644 --- a/test/dynamo/cpython/3_13/test_list.py +++ b/test/dynamo/cpython/3_13/test_list.py @@ -4,9 +4,12 @@ # ruff: noqa # flake8: noqa +<<<<<<< HEAD # Test copied from # https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_list.py +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import sys import torch import torch._dynamo.test_case @@ -101,31 +104,51 @@ def test_keyword_args(self): list(sequence=[]) def test_keywords_in_subclass(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class subclass(list): pass +======= + class subclass(list): + pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) u = subclass([1, 2]) self.assertIs(type(u), subclass) self.assertEqual(list(u), [1, 2]) with self.assertRaises(TypeError): subclass(sequence=()) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class subclass_with_init(list): def __init__(self, seq, newarg=None): super().__init__(seq) self.newarg = newarg +======= + class subclass_with_init(list): + def __init__(self, seq, newarg=None): + super().__init__(seq) + self.newarg = newarg +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) u = subclass_with_init([1, 2], newarg=3) self.assertIs(type(u), subclass_with_init) self.assertEqual(list(u), [1, 2]) self.assertEqual(u.newarg, 3) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class subclass_with_new(list): def __new__(cls, seq, newarg=None): self = super().__new__(cls, seq) self.newarg = newarg return self +======= + class subclass_with_new(list): + def __new__(cls, seq, newarg=None): + self = super().__new__(cls, seq) + self.newarg = newarg + return self +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) u = subclass_with_new([1, 2], newarg=3) self.assertIs(type(u), subclass_with_new) self.assertEqual(list(u), [1, 2]) @@ -172,6 +195,7 @@ def test_list_resize_overflow(self): lst *= size def test_repr_mutate(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Obj: @staticmethod @@ -181,6 +205,16 @@ def __repr__(): except IndexError: pass return 'obj' +======= + class Obj: + @staticmethod + def __repr__(): + try: + mylist.pop() + except IndexError: + pass + return 'obj' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mylist = [Obj() for _ in range(5)] self.assertEqual(repr(mylist), '[obj, obj, obj]') @@ -276,13 +310,18 @@ def test_no_comdat_folding(self): # Issue 8847: In the PGO build, the MSVC linker's COMDAT folding # optimization causes failures in code that relies on distinct # function addresses. +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class L(list): pass +======= + class L(list): pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with self.assertRaises(TypeError): (3,) + L([1,2]) def test_equal_operator_modifying_operand(self): # test fix for seg fault reported in bpo-38588 part 2. +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class X: def __eq__(self,other) : @@ -298,6 +337,22 @@ class Z: def __eq__(self, other): list3.clear() return NotImplemented +======= + class X: + def __eq__(self,other) : + list2.clear() + return NotImplemented + + class Y: + def __eq__(self, other): + list1.clear() + return NotImplemented + + class Z: + def __eq__(self, other): + list3.clear() + return NotImplemented +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) list1 = [X()] list2 = [Y()] @@ -308,18 +363,27 @@ def __eq__(self, other): self.assertFalse(list3 == list4) def test_lt_operator_modifying_operand(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): # See gh-120298 class evil: def __lt__(self, other): other.clear() return NotImplemented +======= + # See gh-120298 + class evil: + def __lt__(self, other): + other.clear() + return NotImplemented +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) a = [[evil()]] with self.assertRaises(TypeError): a[0] < a def test_list_index_modifing_operand(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): # See gh-120384 class evil: @@ -328,6 +392,15 @@ def __init__(self, lst): def __iter__(self): yield from self.lst self.lst.clear() +======= + # See gh-120384 + class evil: + def __init__(self, lst): + self.lst = lst + def __iter__(self): + yield from self.lst + self.lst.clear() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) lst = list(range(5)) operand = evil(lst) @@ -346,21 +419,35 @@ def test_count_index_remove_crashes(self): # bpo-38610: The count(), index(), and remove() methods were not # holding strong references to list elements while calling # PyObject_RichCompareBool(). +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class X: def __eq__(self, other): lst.clear() return NotImplemented +======= + class X: + def __eq__(self, other): + lst.clear() + return NotImplemented +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) lst = [X()] with self.assertRaises(ValueError): lst.index(lst) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class L(list): def __eq__(self, other): str(other) return NotImplemented +======= + class L(list): + def __eq__(self, other): + str(other) + return NotImplemented +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) lst = L([X()]) lst.count(lst) diff --git a/test/dynamo/cpython/3_13/test_math.diff b/test/dynamo/cpython/3_13/test_math.diff index 058477820c63d..8f38e4df9a645 100644 --- a/test/dynamo/cpython/3_13/test_math.diff +++ b/test/dynamo/cpython/3_13/test_math.diff @@ -1,17 +1,27 @@ diff --git a/test/dynamo/cpython/3_13/test_math.py b/test/dynamo/cpython/3_13/test_math.py +<<<<<<< HEAD index 5ee3055c871..5402cdc4a6c 100644 --- a/test/dynamo/cpython/3_13/test_math.py +++ b/test/dynamo/cpython/3_13/test_math.py @@ -1,3 +1,61 @@ +======= +index 5ee3055c871..51773d5f478 100644 +--- a/test/dynamo/cpython/3_13/test_math.py ++++ b/test/dynamo/cpython/3_13/test_math.py +@@ -1,3 +1,58 @@ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +<<<<<<< HEAD +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_math.py + +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +import sys +import torch +import torch._dynamo.test_case @@ -63,6 +73,7 @@ index 5ee3055c871..5402cdc4a6c 100644 + # Python test set -- math module # XXXX Should not do tests around zero only +<<<<<<< HEAD @@ -242,7 +300,7 @@ class BadDescr: def __get__(self, obj, objtype=None): @@ -105,10 +116,27 @@ index 5ee3055c871..5402cdc4a6c 100644 self.ftest('fabs(0)', math.fabs(0), 0) self.ftest('fabs(1)', math.fabs(1), 1) +======= + +@@ -242,7 +297,7 @@ class BadDescr: + def __get__(self, obj, objtype=None): + raise ValueError + +-class MathTests(unittest.TestCase): ++class MathTests(__TestCase): + + def ftest(self, name, got, expected, ulp_tol=5, abs_tol=0.0): + """Compare arguments expected and got, as floats, if either +@@ -533,6 +588,7 @@ class MathTests(unittest.TestCase): + self.ftest('fabs(0)', math.fabs(0), 0) + self.ftest('fabs(1)', math.fabs(1), 1) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) + @skipIfTorchDynamo("infinite loop") def testFactorial(self): self.assertEqual(math.factorial(0), 1) total = 1 +<<<<<<< HEAD @@ -573,16 +633,17 @@ class MathTests(unittest.TestCase): #self.assertEqual(math.ceil(NINF), NINF) #self.assertTrue(math.isnan(math.floor(NAN))) @@ -165,10 +193,17 @@ index 5ee3055c871..5402cdc4a6c 100644 with self.assertRaises(ValueError): math.dist([1, 2], [3, 4, 5]) +======= +@@ -1072,6 +1128,7 @@ class MathTests(unittest.TestCase): + with self.assertRaises(ValueError): + math.dist([1, 2], [3, 4, 5]) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) + @slowTest def testIsqrt(self): # Test a variety of inputs, large and small. test_values = ( +<<<<<<< HEAD @@ -1101,12 +1165,13 @@ class MathTests(unittest.TestCase): self.assertIs(type(s), int) self.assertEqual(s, 0) @@ -192,6 +227,12 @@ index 5ee3055c871..5402cdc4a6c 100644 self.assertEqual(math.ldexp(NINF, n), NINF) self.assertTrue(math.isnan(math.ldexp(NAN, n))) +======= +@@ -1202,12 +1259,6 @@ class MathTests(unittest.TestCase): + self.assertEqual(math.ldexp(NINF, n), NINF) + self.assertTrue(math.isnan(math.ldexp(NAN, n))) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - @requires_IEEE_754 - def testLdexp_denormal(self): - # Denormal output incorrectly rounded (truncated) @@ -201,22 +242,37 @@ index 5ee3055c871..5402cdc4a6c 100644 def testLog(self): self.assertRaises(TypeError, math.log) self.assertRaises(TypeError, math.log, 1, 2, 3) +<<<<<<< HEAD @@ -1233,6 +1292,7 @@ class MathTests(unittest.TestCase): self.assertRaises(ValueError, math.log1p, -1) self.assertEqual(math.log1p(INF), INF) +======= +@@ -1233,6 +1284,7 @@ class MathTests(unittest.TestCase): + self.assertRaises(ValueError, math.log1p, -1) + self.assertEqual(math.log1p(INF), INF) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) + @skipIfTorchDynamo("Infinite loop") @requires_IEEE_754 def testLog2(self): self.assertRaises(TypeError, math.log2) +<<<<<<< HEAD @@ -1251,6 +1311,7 @@ class MathTests(unittest.TestCase): self.assertRaises(ValueError, math.log2, NINF) self.assertTrue(math.isnan(math.log2(NAN))) +======= +@@ -1251,6 +1303,7 @@ class MathTests(unittest.TestCase): + self.assertRaises(ValueError, math.log2, NINF) + self.assertTrue(math.isnan(math.log2(NAN))) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) + @skipIfTorchDynamo("Infinite loop") @requires_IEEE_754 # log2() is not accurate enough on Mac OS X Tiger (10.4) @support.requires_mac_ver(10, 5) +<<<<<<< HEAD @@ -1332,17 +1393,18 @@ class MathTests(unittest.TestCase): with self.assertRaises(RuntimeError): sumprod(raise_after(5), range(10)) @@ -334,26 +390,56 @@ index 5ee3055c871..5402cdc4a6c 100644 self.assertEqual(type(prod([1, decimal.Decimal(2.0), 3, 4, 5, 6])), decimal.Decimal) +======= +@@ -1332,7 +1385,7 @@ class MathTests(unittest.TestCase): + with self.assertRaises(RuntimeError): + sumprod(raise_after(5), range(10)) + +- from test.test_iter import BasicIterClass ++ from test_iter import BasicIterClass + + self.assertEqual(sumprod(BasicIterClass(1), [1]), 0) + self.assertEqual(sumprod([1], BasicIterClass(1)), 0) +@@ -2252,6 +2305,7 @@ class MathTests(unittest.TestCase): + self.assertEqual(type(prod([1, decimal.Decimal(2.0), 3, 4, 5, 6])), + decimal.Decimal) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) + @skipIfTorchDynamo("Infinite loop") def testPerm(self): perm = math.perm factorial = math.factorial +<<<<<<< HEAD @@ -2316,6 +2382,7 @@ class MathTests(unittest.TestCase): self.assertIs(type(perm(IntSubclass(5), IntSubclass(k))), int) self.assertIs(type(perm(MyIndexable(5), MyIndexable(k))), int) +======= +@@ -2316,6 +2370,7 @@ class MathTests(unittest.TestCase): + self.assertIs(type(perm(IntSubclass(5), IntSubclass(k))), int) + self.assertIs(type(perm(MyIndexable(5), MyIndexable(k))), int) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) + @skipIfTorchDynamo("infinite loop") def testComb(self): comb = math.comb factorial = math.factorial +<<<<<<< HEAD @@ -2446,6 +2513,7 @@ class MathTests(unittest.TestCase): math.nextafter(1.0, INF, steps=-1) +======= +@@ -2446,6 +2501,7 @@ class MathTests(unittest.TestCase): + math.nextafter(1.0, INF, steps=-1) + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) + @unittest.skip("flaky test under torch dynamo") # works on pytest and crashes on unittest @requires_IEEE_754 def test_ulp(self): self.assertEqual(math.ulp(1.0), sys.float_info.epsilon) +<<<<<<< HEAD @@ -2472,10 +2540,11 @@ class MathTests(unittest.TestCase): def test_issue39871(self): # A SystemError should not be raised if the first arg to atan2(), @@ -389,6 +475,27 @@ index 5ee3055c871..5402cdc4a6c 100644 def test_fma_nan_results(self): @@ -2719,8 +2788,7 @@ class FMATests(unittest.TestCase): +======= +@@ -2508,7 +2564,7 @@ class MathTests(unittest.TestCase): + self.assertEqual(math.copysign(1.0, x), math.copysign(1.0, y)) + + +-class IsCloseTests(unittest.TestCase): ++class IsCloseTests(__TestCase): + isclose = math.isclose # subclasses should override this + + def assertIsClose(self, a, b, *args, **kwargs): +@@ -2631,7 +2687,7 @@ class IsCloseTests(unittest.TestCase): + self.assertAllNotClose(fraction_examples, rel_tol=1e-9) + + +-class FMATests(unittest.TestCase): ++class FMATests(__TestCase): + """ Tests for math.fma. """ + + def test_fma_nan_results(self): +@@ -2719,8 +2775,7 @@ class FMATests(unittest.TestCase): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # properly: it doesn't use the right sign when the result is zero. @unittest.skipIf( sys.platform.startswith(("freebsd", "wasi", "netbsd", "emscripten")) @@ -398,10 +505,17 @@ index 5ee3055c871..5402cdc4a6c 100644 f"this platform doesn't implement IEE 754-2008 properly") def test_fma_zero_result(self): nonnegative_finites = [0.0, 1e-300, 2.3, 1e300] +<<<<<<< HEAD @@ -2879,10 +2947,5 @@ class FMATests(unittest.TestCase): ) +======= +@@ -2879,10 +2934,5 @@ class FMATests(unittest.TestCase): + ) + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -def load_tests(loader, tests, pattern): - from doctest import DocFileSuite - tests.addTest(DocFileSuite(os.path.join("mathdata", "ieee754.txt"))) diff --git a/test/dynamo/cpython/3_13/test_math.py b/test/dynamo/cpython/3_13/test_math.py index d9f6b5fd1d94c..d3096cb338bc8 100644 --- a/test/dynamo/cpython/3_13/test_math.py +++ b/test/dynamo/cpython/3_13/test_math.py @@ -4,9 +4,12 @@ # ruff: noqa # flake8: noqa +<<<<<<< HEAD # Test copied from # https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_math.py +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import sys import torch import torch._dynamo.test_case @@ -475,6 +478,7 @@ def testCeil(self): #self.assertEqual(math.ceil(NINF), NINF) #self.assertTrue(math.isnan(math.ceil(NAN))) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class TestCeil: def __ceil__(self): @@ -486,6 +490,18 @@ class TestNoCeil: pass class TestBadCeil: __ceil__ = BadDescr() +======= + class TestCeil: + def __ceil__(self): + return 42 + class FloatCeil(float): + def __ceil__(self): + return 42 + class TestNoCeil: + pass + class TestBadCeil: + __ceil__ = BadDescr() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(math.ceil(TestCeil()), 42) self.assertEqual(math.ceil(FloatCeil()), 42) self.assertEqual(math.ceil(FloatLike(42.5)), 43) @@ -633,6 +649,7 @@ def testFloor(self): #self.assertEqual(math.ceil(NINF), NINF) #self.assertTrue(math.isnan(math.floor(NAN))) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class TestFloor: def __floor__(self): @@ -644,6 +661,18 @@ class TestNoFloor: pass class TestBadFloor: __floor__ = BadDescr() +======= + class TestFloor: + def __floor__(self): + return 42 + class FloatFloor(float): + def __floor__(self): + return 42 + class TestNoFloor: + pass + class TestBadFloor: + __floor__ = BadDescr() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(math.floor(TestFloor()), 42) self.assertEqual(math.floor(FloatFloor()), 42) self.assertEqual(math.floor(FloatLike(41.9)), 41) @@ -1056,9 +1085,14 @@ def testDist(self): ) # Verify tuple subclasses are allowed +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class T(tuple): pass +======= + class T(tuple): + pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(dist(T((1, 2, 3)), ((4, 2, -1))), 5.0) # Test handling of bad arguments @@ -1090,9 +1124,14 @@ class T(tuple): with self.assertRaises(TypeError): dist([1], 2) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class BadFloat: __float__ = BadDescr() +======= + class BadFloat: + __float__ = BadDescr() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with self.assertRaises(ValueError): dist([1], [BadFloat()]) @@ -1165,6 +1204,7 @@ def testIsqrt(self): self.assertIs(type(s), int) self.assertEqual(s, 0) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class IntegerLike(object): def __init__(self, value): @@ -1172,6 +1212,14 @@ def __init__(self, value): def __index__(self): return self.value +======= + class IntegerLike(object): + def __init__(self, value): + self.value = value + + def __index__(self): + return self.value +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) s = math.isqrt(IntegerLike(1729)) self.assertIs(type(s), int) @@ -1399,12 +1447,20 @@ def raise_after(n): self.assertEqual(sumprod([1], BasicIterClass(1)), 0) # Error in multiplication +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class BadMultiply: def __mul__(self, other): raise RuntimeError def __rmul__(self, other): raise RuntimeError +======= + class BadMultiply: + def __mul__(self, other): + raise RuntimeError + def __rmul__(self, other): + raise RuntimeError +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with self.assertRaises(RuntimeError): sumprod([10, BadMultiply(), 30], [1, 2, 3]) with self.assertRaises(RuntimeError): @@ -1449,6 +1505,7 @@ def test_sumprod_stress(self): Decimal = decimal.Decimal Fraction = fractions.Fraction +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Int(int): def __add__(self, other): @@ -1469,6 +1526,27 @@ def __mul__(self, other): __rmul__ = __mul__ def __repr__(self): return f'Flt({int(self)})' +======= + class Int(int): + def __add__(self, other): + return Int(int(self) + int(other)) + def __mul__(self, other): + return Int(int(self) * int(other)) + __radd__ = __add__ + __rmul__ = __mul__ + def __repr__(self): + return f'Int({int(self)})' + + class Flt(float): + def __add__(self, other): + return Int(int(self) + int(other)) + def __mul__(self, other): + return Int(int(self) * int(other)) + __radd__ = __add__ + __rmul__ = __mul__ + def __repr__(self): + return f'Flt({int(self)})' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def baseline_sumprod(p, q): """This defines the target behavior including exceptions and special values. @@ -1988,6 +2066,7 @@ def test_trunc(self): self.assertEqual(math.trunc(-0.999999), -0) self.assertEqual(math.trunc(-100.999), -100) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class TestTrunc: def __trunc__(self): @@ -1999,6 +2078,18 @@ class TestNoTrunc: pass class TestBadTrunc: __trunc__ = BadDescr() +======= + class TestTrunc: + def __trunc__(self): + return 23 + class FloatTrunc(float): + def __trunc__(self): + return 23 + class TestNoTrunc: + pass + class TestBadTrunc: + __trunc__ = BadDescr() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(math.trunc(TestTrunc()), 23) self.assertEqual(math.trunc(FloatTrunc()), 23) @@ -2231,10 +2322,16 @@ def test_prod(self): self.assertEqual(prod([1., F(3, 2)]), 1.5) # Error in multiplication +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class BadMultiply: def __rmul__(self, other): raise RuntimeError +======= + class BadMultiply: + def __rmul__(self, other): + raise RuntimeError +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with self.assertRaises(RuntimeError): prod([10., BadMultiply()]) @@ -2540,11 +2637,18 @@ def test_ulp(self): def test_issue39871(self): # A SystemError should not be raised if the first arg to atan2(), # copysign(), or remainder() cannot be converted to a float. +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class F: def __float__(self): self.converted = True 1/0 +======= + class F: + def __float__(self): + self.converted = True + 1/0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for func in math.atan2, math.copysign, math.remainder: y = F() with self.assertRaises(TypeError): diff --git a/test/dynamo/cpython/3_13/test_ordered_dict.diff b/test/dynamo/cpython/3_13/test_ordered_dict.diff index 1df02fabdfd27..83ceda25e45e6 100644 --- a/test/dynamo/cpython/3_13/test_ordered_dict.diff +++ b/test/dynamo/cpython/3_13/test_ordered_dict.diff @@ -1,17 +1,27 @@ diff --git a/test/dynamo/cpython/3_13/test_ordered_dict.py b/test/dynamo/cpython/3_13/test_ordered_dict.py +<<<<<<< HEAD index a9b6a84996e..efc4288d1a4 100644 --- a/test/dynamo/cpython/3_13/test_ordered_dict.py +++ b/test/dynamo/cpython/3_13/test_ordered_dict.py @@ -1,3 +1,60 @@ +======= +index a9b6a84996e..b77eff70414 100644 +--- a/test/dynamo/cpython/3_13/test_ordered_dict.py ++++ b/test/dynamo/cpython/3_13/test_ordered_dict.py +@@ -1,3 +1,57 @@ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +<<<<<<< HEAD +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_ordered_dict.py + +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +import sys +import torch +import torch._dynamo.test_case @@ -63,6 +73,7 @@ index a9b6a84996e..efc4288d1a4 100644 import builtins import contextlib import copy +<<<<<<< HEAD @@ -113,13 +170,14 @@ class OrderedDictTests: def test_init_calls(self): @@ -332,11 +343,41 @@ index a9b6a84996e..efc4288d1a4 100644 TODEL = Key() dict1 = self.OrderedDict(dict.fromkeys((0, TODEL, 4.2))) @@ -878,7 +951,7 @@ class CPythonOrderedDictSideEffects: +======= +@@ -760,7 +814,7 @@ class _TriggerSideEffectOnEqual: + def side_effect(self): + raise NotImplementedError + +-class PurePythonOrderedDictTests(OrderedDictTests, unittest.TestCase): ++class PurePythonOrderedDictTests(OrderedDictTests, __TestCase): + + module = py_coll + OrderedDict = py_coll.OrderedDict +@@ -781,7 +835,7 @@ class PurePythonOrderedDictTests(OrderedDictTests, unittest.TestCase): + self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2))) + + +-class CPythonBuiltinDictTests(unittest.TestCase): ++class CPythonBuiltinDictTests(__TestCase): + """Builtin dict preserves insertion order. + + Reuse some of tests in OrderedDict selectively. +@@ -800,6 +854,7 @@ for method in ( + del method + + ++ + class CPythonOrderedDictSideEffects: + + def check_runtime_error_issue119004(self, dict1, dict2): +@@ -878,7 +933,7 @@ class CPythonOrderedDictSideEffects: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipUnless(c_coll, 'requires the C version of the collections module') class CPythonOrderedDictTests(OrderedDictTests, CPythonOrderedDictSideEffects, - unittest.TestCase): + __TestCase): +<<<<<<< HEAD module = c_coll OrderedDict = c_coll.OrderedDict @@ -359,34 +400,80 @@ index a9b6a84996e..efc4288d1a4 100644 module = c_coll class OrderedDict(c_coll.OrderedDict): @@ -1008,6 +1081,7 @@ class PurePythonGeneralMappingTests(mapping_tests.BasicTestMappingProtocol): +======= + + module = c_coll + OrderedDict = c_coll.OrderedDict +@@ -986,7 +1041,7 @@ class CPythonOrderedDictSubclassTests(CPythonOrderedDictTests): + pass + + +-class PurePythonOrderedDictWithSlotsCopyingTests(unittest.TestCase): ++class PurePythonOrderedDictWithSlotsCopyingTests(__TestCase): + + module = py_coll + class OrderedDict(py_coll.OrderedDict): +@@ -995,7 +1050,7 @@ class PurePythonOrderedDictWithSlotsCopyingTests(unittest.TestCase): + + + @unittest.skipUnless(c_coll, 'requires the C version of the collections module') +-class CPythonOrderedDictWithSlotsCopyingTests(unittest.TestCase): ++class CPythonOrderedDictWithSlotsCopyingTests(__TestCase): + + module = c_coll + class OrderedDict(c_coll.OrderedDict): +@@ -1008,6 +1063,7 @@ class PurePythonGeneralMappingTests(mapping_tests.BasicTestMappingProtocol): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @classmethod def setUpClass(cls): cls.type2test = py_coll.OrderedDict + super().setUpClass() +<<<<<<< HEAD def test_popitem(self): d = self._empty_mapping() @@ -1020,6 +1094,7 @@ class CPythonGeneralMappingTests(mapping_tests.BasicTestMappingProtocol): +======= + + def test_popitem(self): + d = self._empty_mapping() +@@ -1020,6 +1076,7 @@ class CPythonGeneralMappingTests(mapping_tests.BasicTestMappingProtocol): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @classmethod def setUpClass(cls): cls.type2test = c_coll.OrderedDict + super().setUpClass() +<<<<<<< HEAD def test_popitem(self): d = self._empty_mapping() @@ -1033,6 +1108,7 @@ class PurePythonSubclassMappingTests(mapping_tests.BasicTestMappingProtocol): +======= + + def test_popitem(self): + d = self._empty_mapping() +@@ -1033,6 +1090,7 @@ class PurePythonSubclassMappingTests(mapping_tests.BasicTestMappingProtocol): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class MyOrderedDict(py_coll.OrderedDict): pass cls.type2test = MyOrderedDict + super().setUpClass() +<<<<<<< HEAD def test_popitem(self): d = self._empty_mapping() @@ -1047,6 +1123,7 @@ class CPythonSubclassMappingTests(mapping_tests.BasicTestMappingProtocol): +======= + + def test_popitem(self): + d = self._empty_mapping() +@@ -1047,6 +1105,7 @@ class CPythonSubclassMappingTests(mapping_tests.BasicTestMappingProtocol): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class MyOrderedDict(c_coll.OrderedDict): pass cls.type2test = MyOrderedDict + super().setUpClass() +<<<<<<< HEAD def test_popitem(self): d = self._empty_mapping() @@ -405,14 +492,39 @@ index a9b6a84996e..efc4288d1a4 100644 -class CSimpleLRUCacheTests(SimpleLRUCacheTests, unittest.TestCase): +class CSimpleLRUCacheTests(SimpleLRUCacheTests, __TestCase): +======= + + def test_popitem(self): + d = self._empty_mapping() +@@ -1120,21 +1179,22 @@ class SimpleLRUCacheTests: + self.assertEqual(list(c), [1, 3, 2]) + + +-class PySimpleLRUCacheTests(SimpleLRUCacheTests, unittest.TestCase): ++class PySimpleLRUCacheTests(SimpleLRUCacheTests, __TestCase): + + class type2test(SimpleLRUCache, py_coll.OrderedDict): + pass + + + @unittest.skipUnless(c_coll, 'requires the C version of the collections module') +-class CSimpleLRUCacheTests(SimpleLRUCacheTests, unittest.TestCase): ++class CSimpleLRUCacheTests(SimpleLRUCacheTests, __TestCase): + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @classmethod def setUpClass(cls): class type2test(SimpleLRUCache, c_coll.OrderedDict): pass cls.type2test = type2test + super().setUpClass() +<<<<<<< HEAD +======= + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": - unittest.main() + run_tests() diff --git a/test/dynamo/cpython/3_13/test_ordered_dict.py b/test/dynamo/cpython/3_13/test_ordered_dict.py index 56a8662de1335..fec0b68ef668f 100644 --- a/test/dynamo/cpython/3_13/test_ordered_dict.py +++ b/test/dynamo/cpython/3_13/test_ordered_dict.py @@ -4,9 +4,12 @@ # ruff: noqa # flake8: noqa +<<<<<<< HEAD # Test copied from # https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_ordered_dict.py +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import sys import torch import torch._dynamo.test_case @@ -170,6 +173,7 @@ def test_update(self): def test_init_calls(self): calls = [] +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Spam: def keys(self): @@ -178,6 +182,15 @@ def keys(self): def items(self): calls.append('items') return () +======= + class Spam: + def keys(self): + calls.append('keys') + return () + def items(self): + calls.append('items') + return () +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.OrderedDict(Spam()) self.assertEqual(calls, ['keys']) @@ -187,10 +200,16 @@ def test_overridden_init(self): # a consistent internal state is created in __new__ # rather than __init__. OrderedDict = self.OrderedDict +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class ODNI(OrderedDict): def __init__(*args, **kwargs): pass +======= + class ODNI(OrderedDict): + def __init__(*args, **kwargs): + pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) od = ODNI() od['a'] = 1 # This used to fail because __init__ was bypassed @@ -326,10 +345,16 @@ def test_pop(self): self.assertEqual(od.pop(k, 12345), 12345) # make sure pop still works when __missing__ is defined +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Missing(OrderedDict): def __missing__(self, key): return 0 +======= + class Missing(OrderedDict): + def __missing__(self, key): + return 0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) m = Missing(a=1) self.assertEqual(m.pop('b', 5), 5) self.assertEqual(m.pop('a', 6), 1) @@ -476,10 +501,16 @@ def test_setdefault(self): self.assertEqual(od.setdefault('g', default=9), 9) # make sure setdefault still works when __missing__ is defined +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Missing(OrderedDict): def __missing__(self, key): return 0 +======= + class Missing(OrderedDict): + def __missing__(self, key): + return 0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(Missing().setdefault(5, 9), 9) def test_reinsert(self): @@ -545,10 +576,16 @@ def test_views(self): def test_override_update(self): OrderedDict = self.OrderedDict # Verify that subclasses can override update() without breaking __init__() +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class MyOD(OrderedDict): def update(self, *args, **kwds): raise Exception() +======= + class MyOD(OrderedDict): + def update(self, *args, **kwds): + raise Exception() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) items = [('a', 1), ('c', 3), ('b', 2)] self.assertEqual(list(MyOD(items).items()), items) @@ -569,10 +606,16 @@ def test_highly_nested_subclass(self): # should not crash Python. OrderedDict = self.OrderedDict deleted = [] +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class MyOD(OrderedDict): def __del__(self): deleted.append(self.i) +======= + class MyOD(OrderedDict): + def __del__(self): + deleted.append(self.i) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) obj = None for i in range(100): obj = MyOD([(None, obj)]) @@ -584,6 +627,7 @@ def __del__(self): def test_delitem_hash_collision(self): OrderedDict = self.OrderedDict +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Key: def __init__(self, hash): @@ -598,6 +642,21 @@ def __eq__(self, other): return False def __repr__(self): return self.value +======= + class Key: + def __init__(self, hash): + self._hash = hash + self.value = str(id(self)) + def __hash__(self): + return self._hash + def __eq__(self, other): + try: + return self.value == other.value + except AttributeError: + return False + def __repr__(self): + return self.value +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def blocking_hash(hash): # See the collision-handling in lookdict (in Objects/dictobject.c). @@ -624,10 +683,16 @@ def blocking_hash(hash): def test_issue24347(self): OrderedDict = self.OrderedDict +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Key: def __hash__(self): return randrange(100000) +======= + class Key: + def __hash__(self): + return randrange(100000) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) od = OrderedDict() for i in range(100): @@ -647,10 +712,16 @@ def __hash__(self): def test_issue24348(self): OrderedDict = self.OrderedDict +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Key: def __hash__(self): return 1 +======= + class Key: + def __hash__(self): + return 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) od = OrderedDict() od[Key()] = 0 @@ -832,10 +903,16 @@ class PurePythonOrderedDictTests(OrderedDictTests, __TestCase): OrderedDict = py_coll.OrderedDict def test_issue119004_attribute_error(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Key(_TriggerSideEffectOnEqual): def side_effect(self): del dict1[TODEL] +======= + class Key(_TriggerSideEffectOnEqual): + def side_effect(self): + del dict1[TODEL] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TODEL = Key() dict1 = self.OrderedDict(dict.fromkeys((0, TODEL, 4.2))) @@ -875,10 +952,16 @@ def check_runtime_error_issue119004(self, dict1, dict2): self.assertRaisesRegex(RuntimeError, msg, operator.eq, dict1, dict2) def test_issue119004_change_size_by_clear(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Key(_TriggerSideEffectOnEqual): def side_effect(self): dict1.clear() +======= + class Key(_TriggerSideEffectOnEqual): + def side_effect(self): + dict1.clear() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dict1 = self.OrderedDict(dict.fromkeys((0, Key(), 4.2))) dict2 = self.OrderedDict(dict.fromkeys((0, Key(), 4.2))) @@ -888,10 +971,16 @@ def side_effect(self): self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2))) def test_issue119004_change_size_by_delete_key(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Key(_TriggerSideEffectOnEqual): def side_effect(self): del dict1[TODEL] +======= + class Key(_TriggerSideEffectOnEqual): + def side_effect(self): + del dict1[TODEL] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TODEL = Key() dict1 = self.OrderedDict(dict.fromkeys((0, TODEL, 4.2))) @@ -902,11 +991,18 @@ def side_effect(self): self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2))) def test_issue119004_change_linked_list_by_clear(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Key(_TriggerSideEffectOnEqual): def side_effect(self): dict1.clear() dict1['a'] = dict1['b'] = 'c' +======= + class Key(_TriggerSideEffectOnEqual): + def side_effect(self): + dict1.clear() + dict1['a'] = dict1['b'] = 'c' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dict1 = self.OrderedDict(dict.fromkeys((0, Key(), 4.2))) dict2 = self.OrderedDict(dict.fromkeys((0, Key(), 4.2))) @@ -916,11 +1012,18 @@ def side_effect(self): self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2))) def test_issue119004_change_linked_list_by_delete_key(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Key(_TriggerSideEffectOnEqual): def side_effect(self): del dict1[TODEL] dict1['a'] = 'c' +======= + class Key(_TriggerSideEffectOnEqual): + def side_effect(self): + del dict1[TODEL] + dict1['a'] = 'c' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TODEL = Key() dict1 = self.OrderedDict(dict.fromkeys((0, TODEL, 4.2))) @@ -931,11 +1034,18 @@ def side_effect(self): self.assertDictEqual(dict2, dict.fromkeys((0, Key(), 4.2))) def test_issue119004_change_size_by_delete_key_in_dict_eq(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Key(_TriggerSideEffectOnEqual): trigger = 0 def side_effect(self): del dict1[TODEL] +======= + class Key(_TriggerSideEffectOnEqual): + trigger = 0 + def side_effect(self): + del dict1[TODEL] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TODEL = Key() dict1 = self.OrderedDict(dict.fromkeys((0, TODEL, 4.2))) diff --git a/test/dynamo/cpython/3_13/test_set.diff b/test/dynamo/cpython/3_13/test_set.diff index 77dce156a1e12..66ae3841f35a0 100644 --- a/test/dynamo/cpython/3_13/test_set.diff +++ b/test/dynamo/cpython/3_13/test_set.diff @@ -1,17 +1,27 @@ diff --git a/test/dynamo/cpython/3_13/test_set.py b/test/dynamo/cpython/3_13/test_set.py +<<<<<<< HEAD index d9102eb98a5..c8ee5ca451f 100644 --- a/test/dynamo/cpython/3_13/test_set.py +++ b/test/dynamo/cpython/3_13/test_set.py @@ -1,3 +1,56 @@ +======= +index d9102eb98a5..0b8e99a04c4 100644 +--- a/test/dynamo/cpython/3_13/test_set.py ++++ b/test/dynamo/cpython/3_13/test_set.py +@@ -1,3 +1,53 @@ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +<<<<<<< HEAD +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_set.py + +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +import sys +import torch +import torch._dynamo.test_case @@ -59,6 +69,7 @@ index d9102eb98a5..c8ee5ca451f 100644 import unittest from test import support from test.support import warnings_helper +<<<<<<< HEAD @@ -38,7 +91,7 @@ class HashCountingInt(int): self.hash_count += 1 return int.__hash__(self) @@ -69,10 +80,23 @@ index d9102eb98a5..c8ee5ca451f 100644 def setUp(self): @@ -47,6 +100,7 @@ class TestJointOps: +======= +@@ -38,7 +88,7 @@ class HashCountingInt(int): + self.hash_count += 1 + return int.__hash__(self) + +-class TestJointOps: ++class _TestJointOps: + # Tests common to both set and frozenset + + def setUp(self): +@@ -47,6 +97,7 @@ class TestJointOps: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.letters = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ' self.s = self.thetype(word) self.d = dict.fromkeys(word) + super().setUp() +<<<<<<< HEAD def test_new_or_init(self): self.assertRaises(TypeError, self.thetype, [], 2) @@ -140,10 +164,20 @@ index d9102eb98a5..c8ee5ca451f 100644 def test_free_after_iterating(self): support.check_free_after_iterating(self, iter, self.thetype) +======= + + def test_new_or_init(self): + self.assertRaises(TypeError, self.thetype, [], 2) +@@ -355,7 +406,7 @@ class TestJointOps: + def test_free_after_iterating(self): + support.check_free_after_iterating(self, iter, self.thetype) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestSet(TestJointOps, unittest.TestCase): +class TestSet(_TestJointOps, __TestCase): thetype = set basetype = set +<<<<<<< HEAD @@ -600,19 +658,20 @@ class TestSet(TestJointOps, unittest.TestCase): self.assertRaises(ReferenceError, str, p) @@ -226,10 +260,18 @@ index d9102eb98a5..c8ee5ca451f 100644 subclass_with_new([1, 2], newarg=3) +======= + +@@ -675,7 +726,7 @@ class TestSetSubclass(TestSet): + subclass_with_new([1, 2], newarg=3) + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestFrozenSet(TestJointOps, unittest.TestCase): +class TestFrozenSet(_TestJointOps, __TestCase): thetype = frozenset basetype = frozenset +<<<<<<< HEAD @@ -756,27 +818,30 @@ class TestFrozenSetSubclass(TestFrozenSet): basetype = frozenset @@ -276,6 +318,13 @@ index d9102eb98a5..c8ee5ca451f 100644 class SetSubclassWithSlots(set): __slots__ = ('x', 'y', '__dict__') +======= + +@@ -811,10 +862,17 @@ class TestFrozenSetSubclass(TestFrozenSet): + class SetSubclassWithSlots(set): + __slots__ = ('x', 'y', '__dict__') + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestSetSubclassWithSlots(unittest.TestCase): +class TestSetSubclassWithSlots(__TestCase): thetype = SetSubclassWithSlots @@ -290,6 +339,7 @@ index d9102eb98a5..c8ee5ca451f 100644 + self.s = self.thetype(word) + self.d = dict.fromkeys(word) + super().setUp() +<<<<<<< HEAD class FrozenSetSubclassWithSlots(frozenset): __slots__ = ('x', 'y', '__dict__') @@ -306,29 +356,62 @@ index d9102eb98a5..c8ee5ca451f 100644 #------------------------------------------------------------------------------ +======= + + class FrozenSetSubclassWithSlots(frozenset): + __slots__ = ('x', 'y', '__dict__') +@@ -828,7 +886,7 @@ empty_set = set() + + #============================================================================== + +-class TestBasicOps: ++class _TestBasicOps: + + def test_repr(self): + if self.repr is not None: +@@ -934,7 +992,7 @@ class TestBasicOps: + + #------------------------------------------------------------------------------ + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestBasicOpsEmpty(TestBasicOps, unittest.TestCase): +class TestBasicOpsEmpty(_TestBasicOps, __TestCase): def setUp(self): self.case = "empty set" self.values = [] +<<<<<<< HEAD @@ -942,10 +1014,11 @@ class TestBasicOpsEmpty(TestBasicOps, unittest.TestCase): +======= +@@ -942,10 +1000,11 @@ class TestBasicOpsEmpty(TestBasicOps, unittest.TestCase): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.dup = set(self.values) self.length = 0 self.repr = "set()" + super().setUp() +<<<<<<< HEAD #------------------------------------------------------------------------------ +======= + + #------------------------------------------------------------------------------ + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestBasicOpsSingleton(TestBasicOps, unittest.TestCase): +class TestBasicOpsSingleton(_TestBasicOps, __TestCase): def setUp(self): self.case = "unit set (number)" self.values = [3] +<<<<<<< HEAD @@ -953,6 +1026,7 @@ class TestBasicOpsSingleton(TestBasicOps, unittest.TestCase): +======= +@@ -953,6 +1012,7 @@ class TestBasicOpsSingleton(TestBasicOps, unittest.TestCase): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.dup = set(self.values) self.length = 1 self.repr = "{3}" + super().setUp() +<<<<<<< HEAD def test_in(self): self.assertIn(3, self.set) @@ -336,16 +419,30 @@ index d9102eb98a5..c8ee5ca451f 100644 #------------------------------------------------------------------------------ +======= + + def test_in(self): + self.assertIn(3, self.set) +@@ -962,7 +1022,7 @@ class TestBasicOpsSingleton(TestBasicOps, unittest.TestCase): + + #------------------------------------------------------------------------------ + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestBasicOpsTuple(TestBasicOps, unittest.TestCase): +class TestBasicOpsTuple(_TestBasicOps, __TestCase): def setUp(self): self.case = "unit set (tuple)" self.values = [(0, "zero")] +<<<<<<< HEAD @@ -970,6 +1044,7 @@ class TestBasicOpsTuple(TestBasicOps, unittest.TestCase): +======= +@@ -970,6 +1030,7 @@ class TestBasicOpsTuple(TestBasicOps, unittest.TestCase): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.dup = set(self.values) self.length = 1 self.repr = "{(0, 'zero')}" + super().setUp() +<<<<<<< HEAD def test_in(self): self.assertIn((0, "zero"), self.set) @@ -353,19 +450,38 @@ index d9102eb98a5..c8ee5ca451f 100644 #------------------------------------------------------------------------------ +======= + + def test_in(self): + self.assertIn((0, "zero"), self.set) +@@ -979,7 +1040,7 @@ class TestBasicOpsTuple(TestBasicOps, unittest.TestCase): + + #------------------------------------------------------------------------------ + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestBasicOpsTriple(TestBasicOps, unittest.TestCase): +class TestBasicOpsTriple(_TestBasicOps, __TestCase): def setUp(self): self.case = "triple set" self.values = [0, "zero", operator.add] +<<<<<<< HEAD @@ -987,36 +1062,39 @@ class TestBasicOpsTriple(TestBasicOps, unittest.TestCase): +======= +@@ -987,36 +1048,39 @@ class TestBasicOpsTriple(TestBasicOps, unittest.TestCase): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.dup = set(self.values) self.length = 3 self.repr = None + super().setUp() +<<<<<<< HEAD #------------------------------------------------------------------------------ +======= + + #------------------------------------------------------------------------------ + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestBasicOpsString(TestBasicOps, unittest.TestCase): +class TestBasicOpsString(_TestBasicOps, __TestCase): def setUp(self): @@ -375,12 +491,21 @@ index d9102eb98a5..c8ee5ca451f 100644 self.dup = set(self.values) self.length = 3 + super().setUp() +<<<<<<< HEAD def test_repr(self): self.check_repr_against_values() #------------------------------------------------------------------------------ +======= + + def test_repr(self): + self.check_repr_against_values() + + #------------------------------------------------------------------------------ + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestBasicOpsBytes(TestBasicOps, unittest.TestCase): +class TestBasicOpsBytes(_TestBasicOps, __TestCase): def setUp(self): @@ -390,22 +515,36 @@ index d9102eb98a5..c8ee5ca451f 100644 self.dup = set(self.values) self.length = 3 + super().setUp() +<<<<<<< HEAD def test_repr(self): self.check_repr_against_values() #------------------------------------------------------------------------------ +======= + + def test_repr(self): + self.check_repr_against_values() + + #------------------------------------------------------------------------------ + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestBasicOpsMixedStringBytes(TestBasicOps, unittest.TestCase): +class TestBasicOpsMixedStringBytes(_TestBasicOps, __TestCase): def setUp(self): self.enterContext(warnings_helper.check_warnings()) warnings.simplefilter('ignore', BytesWarning) +<<<<<<< HEAD @@ -1025,6 +1103,7 @@ class TestBasicOpsMixedStringBytes(TestBasicOps, unittest.TestCase): +======= +@@ -1025,6 +1089,7 @@ class TestBasicOpsMixedStringBytes(TestBasicOps, unittest.TestCase): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.set = set(self.values) self.dup = set(self.values) self.length = 4 + super().setUp() +<<<<<<< HEAD def test_repr(self): self.check_repr_against_values() @@ -422,20 +561,46 @@ index d9102eb98a5..c8ee5ca451f 100644 #============================================================================== +======= + + def test_repr(self): + self.check_repr_against_values() +@@ -1038,7 +1103,7 @@ def baditer(): + def gooditer(): + yield True + +-class TestExceptionPropagation(unittest.TestCase): ++class TestExceptionPropagation(__TestCase): + """SF 628246: Set constructor should not trap iterator TypeErrors""" + + def test_instanceWithException(self): +@@ -1065,7 +1130,7 @@ class TestExceptionPropagation(unittest.TestCase): + + #============================================================================== + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestSetOfSets(unittest.TestCase): +class TestSetOfSets(__TestCase): def test_constructor(self): inner = frozenset([1]) outer = set([inner]) +<<<<<<< HEAD @@ -1078,9 +1157,10 @@ class TestSetOfSets(unittest.TestCase): #============================================================================== +======= +@@ -1078,9 +1143,10 @@ class TestSetOfSets(unittest.TestCase): + + #============================================================================== + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestBinaryOps(unittest.TestCase): +class TestBinaryOps(__TestCase): def setUp(self): self.set = set((2, 4, 6)) + super().setUp() +<<<<<<< HEAD def test_eq(self): # SF bug 643115 self.assertEqual(self.set, set({2:1,4:3,6:5})) @@ -443,11 +608,21 @@ index d9102eb98a5..c8ee5ca451f 100644 #============================================================================== +======= + + def test_eq(self): # SF bug 643115 + self.assertEqual(self.set, set({2:1,4:3,6:5})) +@@ -1151,9 +1217,10 @@ class TestBinaryOps(unittest.TestCase): + + #============================================================================== + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestUpdateOps(unittest.TestCase): +class TestUpdateOps(__TestCase): def setUp(self): self.set = set((2, 4, 6)) + super().setUp() +<<<<<<< HEAD def test_union_subset(self): self.set |= set([2]) @@ -455,12 +630,22 @@ index d9102eb98a5..c8ee5ca451f 100644 #============================================================================== +======= + + def test_union_subset(self): + self.set |= set([2]) +@@ -1237,10 +1304,11 @@ class TestUpdateOps(unittest.TestCase): + + #============================================================================== + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestMutate(unittest.TestCase): +class TestMutate(__TestCase): def setUp(self): self.values = ["a", "b", "c"] self.set = set(self.values) + super().setUp() +<<<<<<< HEAD def test_add_present(self): self.set.add("c") @@ -474,6 +659,21 @@ index d9102eb98a5..c8ee5ca451f 100644 case2method = {"<=": "issubset", ">=": "issuperset", @@ -1334,22 +1416,22 @@ class TestSubsets: +======= + + def test_add_present(self): + self.set.add("c") +@@ -1311,7 +1379,7 @@ class TestMutate(unittest.TestCase): + + #============================================================================== + +-class TestSubsets: ++class _TestSubsets: + + case2method = {"<=": "issubset", + ">=": "issuperset", +@@ -1334,22 +1402,22 @@ class TestSubsets: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) result = eval("x" + case + "y", locals()) self.assertEqual(result, expected) # Test the "friendly" method-name spelling, if one exists. @@ -483,7 +683,11 @@ index d9102eb98a5..c8ee5ca451f 100644 + method = getattr(x, _TestSubsets.case2method[case]) result = method(y) self.assertEqual(result, expected) +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Now do the same for the operands reversed. - rcase = TestSubsets.reverse[case] + rcase = _TestSubsets.reverse[case] @@ -496,48 +700,81 @@ index d9102eb98a5..c8ee5ca451f 100644 result = method(x) self.assertEqual(result, expected) #------------------------------------------------------------------------------ +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestSubsetEqualEmpty(TestSubsets, unittest.TestCase): +class TestSubsetEqualEmpty(_TestSubsets, __TestCase): left = set() right = set() name = "both empty" +<<<<<<< HEAD @@ -1357,7 +1439,7 @@ class TestSubsetEqualEmpty(TestSubsets, unittest.TestCase): #------------------------------------------------------------------------------ +======= +@@ -1357,7 +1425,7 @@ class TestSubsetEqualEmpty(TestSubsets, unittest.TestCase): + + #------------------------------------------------------------------------------ + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestSubsetEqualNonEmpty(TestSubsets, unittest.TestCase): +class TestSubsetEqualNonEmpty(_TestSubsets, __TestCase): left = set([1, 2]) right = set([1, 2]) name = "equal pair" +<<<<<<< HEAD @@ -1365,7 +1447,7 @@ class TestSubsetEqualNonEmpty(TestSubsets, unittest.TestCase): #------------------------------------------------------------------------------ +======= +@@ -1365,7 +1433,7 @@ class TestSubsetEqualNonEmpty(TestSubsets, unittest.TestCase): + + #------------------------------------------------------------------------------ + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestSubsetEmptyNonEmpty(TestSubsets, unittest.TestCase): +class TestSubsetEmptyNonEmpty(_TestSubsets, __TestCase): left = set() right = set([1, 2]) name = "one empty, one non-empty" +<<<<<<< HEAD @@ -1373,7 +1455,7 @@ class TestSubsetEmptyNonEmpty(TestSubsets, unittest.TestCase): #------------------------------------------------------------------------------ +======= +@@ -1373,7 +1441,7 @@ class TestSubsetEmptyNonEmpty(TestSubsets, unittest.TestCase): + + #------------------------------------------------------------------------------ + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestSubsetPartial(TestSubsets, unittest.TestCase): +class TestSubsetPartial(_TestSubsets, __TestCase): left = set([1]) right = set([1, 2]) name = "one a non-empty proper subset of other" +<<<<<<< HEAD @@ -1381,7 +1463,7 @@ class TestSubsetPartial(TestSubsets, unittest.TestCase): #------------------------------------------------------------------------------ +======= +@@ -1381,7 +1449,7 @@ class TestSubsetPartial(TestSubsets, unittest.TestCase): + + #------------------------------------------------------------------------------ + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestSubsetNonOverlap(TestSubsets, unittest.TestCase): +class TestSubsetNonOverlap(_TestSubsets, __TestCase): left = set([1]) right = set([2]) name = "neither empty, neither contains" +<<<<<<< HEAD @@ -1389,7 +1471,7 @@ class TestSubsetNonOverlap(TestSubsets, unittest.TestCase): #============================================================================== @@ -551,6 +788,21 @@ index d9102eb98a5..c8ee5ca451f 100644 #------------------------------------------------------------------------------ +======= +@@ -1389,7 +1457,7 @@ class TestSubsetNonOverlap(TestSubsets, unittest.TestCase): + + #============================================================================== + +-class TestOnlySetsInBinaryOps: ++class _TestOnlySetsInBinaryOps: + + def test_eq_ne(self): + # Unlike the others, this is testing that == and != *are* allowed. +@@ -1505,47 +1573,52 @@ class TestOnlySetsInBinaryOps: + + #------------------------------------------------------------------------------ + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestOnlySetsNumeric(TestOnlySetsInBinaryOps, unittest.TestCase): +class TestOnlySetsNumeric(_TestOnlySetsInBinaryOps, __TestCase): def setUp(self): @@ -558,9 +810,15 @@ index d9102eb98a5..c8ee5ca451f 100644 self.other = 19 self.otherIsIterable = False + super().setUp() +<<<<<<< HEAD #------------------------------------------------------------------------------ +======= + + #------------------------------------------------------------------------------ + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestOnlySetsDict(TestOnlySetsInBinaryOps, unittest.TestCase): +class TestOnlySetsDict(_TestOnlySetsInBinaryOps, __TestCase): def setUp(self): @@ -568,9 +826,15 @@ index d9102eb98a5..c8ee5ca451f 100644 self.other = {1:2, 3:4} self.otherIsIterable = True + super().setUp() +<<<<<<< HEAD #------------------------------------------------------------------------------ +======= + + #------------------------------------------------------------------------------ + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestOnlySetsOperator(TestOnlySetsInBinaryOps, unittest.TestCase): +class TestOnlySetsOperator(_TestOnlySetsInBinaryOps, __TestCase): def setUp(self): @@ -578,9 +842,15 @@ index d9102eb98a5..c8ee5ca451f 100644 self.other = operator.add self.otherIsIterable = False + super().setUp() +<<<<<<< HEAD #------------------------------------------------------------------------------ +======= + + #------------------------------------------------------------------------------ + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestOnlySetsTuple(TestOnlySetsInBinaryOps, unittest.TestCase): +class TestOnlySetsTuple(_TestOnlySetsInBinaryOps, __TestCase): def setUp(self): @@ -588,9 +858,15 @@ index d9102eb98a5..c8ee5ca451f 100644 self.other = (2, 4, 6) self.otherIsIterable = True + super().setUp() +<<<<<<< HEAD #------------------------------------------------------------------------------ +======= + + #------------------------------------------------------------------------------ + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestOnlySetsString(TestOnlySetsInBinaryOps, unittest.TestCase): +class TestOnlySetsString(_TestOnlySetsInBinaryOps, __TestCase): def setUp(self): @@ -598,19 +874,30 @@ index d9102eb98a5..c8ee5ca451f 100644 self.other = 'abc' self.otherIsIterable = True + super().setUp() +<<<<<<< HEAD #------------------------------------------------------------------------------ +======= + + #------------------------------------------------------------------------------ + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestOnlySetsGenerator(TestOnlySetsInBinaryOps, unittest.TestCase): +class TestOnlySetsGenerator(_TestOnlySetsInBinaryOps, __TestCase): def setUp(self): def gen(): for i in range(0, 10, 2): +<<<<<<< HEAD @@ -1553,10 +1640,11 @@ class TestOnlySetsGenerator(TestOnlySetsInBinaryOps, unittest.TestCase): +======= +@@ -1553,10 +1626,11 @@ class TestOnlySetsGenerator(TestOnlySetsInBinaryOps, unittest.TestCase): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.set = set((1, 2, 3)) self.other = gen() self.otherIsIterable = True + super().setUp() +<<<<<<< HEAD #============================================================================== @@ -623,52 +910,97 @@ index d9102eb98a5..c8ee5ca451f 100644 #------------------------------------------------------------------------------ +======= + + #============================================================================== + +-class TestCopying: ++class _TestCopying: + + def test_copy(self): + dup = self.set.copy() +@@ -1577,40 +1651,46 @@ class TestCopying: + + #------------------------------------------------------------------------------ + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestCopyingEmpty(TestCopying, unittest.TestCase): +class TestCopyingEmpty(_TestCopying, __TestCase): def setUp(self): self.set = set() + super().setUp() +<<<<<<< HEAD #------------------------------------------------------------------------------ +======= + + #------------------------------------------------------------------------------ + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestCopyingSingleton(TestCopying, unittest.TestCase): +class TestCopyingSingleton(_TestCopying, __TestCase): def setUp(self): self.set = set(["hello"]) + super().setUp() +<<<<<<< HEAD #------------------------------------------------------------------------------ +======= + + #------------------------------------------------------------------------------ + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestCopyingTriple(TestCopying, unittest.TestCase): +class TestCopyingTriple(_TestCopying, __TestCase): def setUp(self): self.set = set(["zero", 0, None]) + super().setUp() +<<<<<<< HEAD #------------------------------------------------------------------------------ +======= + + #------------------------------------------------------------------------------ + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestCopyingTuple(TestCopying, unittest.TestCase): +class TestCopyingTuple(_TestCopying, __TestCase): def setUp(self): self.set = set([(1, 2)]) + super().setUp() +<<<<<<< HEAD #------------------------------------------------------------------------------ +======= + + #------------------------------------------------------------------------------ + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestCopyingNested(TestCopying, unittest.TestCase): +class TestCopyingNested(_TestCopying, __TestCase): def setUp(self): self.set = set([((1, 2), (3, 4))]) + super().setUp() +<<<<<<< HEAD #============================================================================== +======= + + #============================================================================== + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestIdentities(unittest.TestCase): +class TestIdentities(__TestCase): def setUp(self): self.a = set('abracadabra') self.b = set('alacazam') + super().setUp() +<<<<<<< HEAD def test_binopsVsSubsets(self): a, b = self.a, self.b @@ -685,11 +1017,30 @@ index d9102eb98a5..c8ee5ca451f 100644 def __hash__(self): return 0 +======= + + def test_binopsVsSubsets(self): + a, b = self.a, self.b +@@ -1727,7 +1807,7 @@ def L(seqn): + 'Test multiple tiers of iterators' + return chain(map(lambda x:x, R(Ig(G(seqn))))) + +-class TestVariousIteratorArgs(unittest.TestCase): ++class TestVariousIteratorArgs(__TestCase): + + def test_constructor(self): + for cons in (set, frozenset): +@@ -1785,7 +1865,7 @@ class bad_dict_clear: + def __hash__(self): + return 0 + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestWeirdBugs(unittest.TestCase): +class TestWeirdBugs(__TestCase): def test_8420_set_merge(self): # This used to segfault global be_bad, set2, dict2 +<<<<<<< HEAD @@ -1813,12 +1907,13 @@ class TestWeirdBugs(unittest.TestCase): list(si) @@ -761,25 +1112,62 @@ index d9102eb98a5..c8ee5ca451f 100644 self.check_set_op_does_not_crash(f3) +======= +@@ -1826,7 +1906,7 @@ class TestWeirdBugs(unittest.TestCase): + s.update(other) + + +-class TestOperationsMutating: ++class _TestOperationsMutating: + """Regression test for bpo-46615""" + + constructor1 = None +@@ -1862,7 +1942,7 @@ class TestOperationsMutating: + self.assertIn("changed size during iteration", str(e)) + + +-class TestBinaryOpsMutating(TestOperationsMutating): ++class _TestBinaryOpsMutating(_TestOperationsMutating): + + def test_eq_with_mutation(self): + self.check_set_op_does_not_crash(lambda a, b: a == b) +@@ -1933,24 +2013,24 @@ class TestBinaryOpsMutating(TestOperationsMutating): + self.check_set_op_does_not_crash(f3) + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestBinaryOpsMutating_Set_Set(TestBinaryOpsMutating, unittest.TestCase): +class TestBinaryOpsMutating_Set_Set(_TestBinaryOpsMutating, __TestCase): constructor1 = set constructor2 = set +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestBinaryOpsMutating_Subclass_Subclass(TestBinaryOpsMutating, unittest.TestCase): +class TestBinaryOpsMutating_Subclass_Subclass(_TestBinaryOpsMutating, __TestCase): constructor1 = SetSubclass constructor2 = SetSubclass +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestBinaryOpsMutating_Set_Subclass(TestBinaryOpsMutating, unittest.TestCase): +class TestBinaryOpsMutating_Set_Subclass(_TestBinaryOpsMutating, __TestCase): constructor1 = set constructor2 = SetSubclass +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestBinaryOpsMutating_Subclass_Set(TestBinaryOpsMutating, unittest.TestCase): +class TestBinaryOpsMutating_Subclass_Set(_TestBinaryOpsMutating, __TestCase): constructor1 = SetSubclass constructor2 = set +<<<<<<< HEAD -class TestMethodsMutating(TestOperationsMutating): @@ -791,35 +1179,69 @@ index d9102eb98a5..c8ee5ca451f 100644 self.check_set_op_does_not_crash(set.update) +======= + + +-class TestMethodsMutating(TestOperationsMutating): ++class _TestMethodsMutating(_TestOperationsMutating): + + def test_issubset_with_mutation(self): + self.check_set_op_does_not_crash(set.issubset) +@@ -1986,27 +2066,27 @@ class TestMethodsMutating(TestOperationsMutating): + self.check_set_op_does_not_crash(set.update) + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestMethodsMutating_Set_Set(TestMethodsMutating, unittest.TestCase): +class TestMethodsMutating_Set_Set(_TestMethodsMutating, __TestCase): constructor1 = set constructor2 = set +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestMethodsMutating_Subclass_Subclass(TestMethodsMutating, unittest.TestCase): +class TestMethodsMutating_Subclass_Subclass(_TestMethodsMutating, __TestCase): constructor1 = SetSubclass constructor2 = SetSubclass +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestMethodsMutating_Set_Subclass(TestMethodsMutating, unittest.TestCase): +class TestMethodsMutating_Set_Subclass(_TestMethodsMutating, __TestCase): constructor1 = set constructor2 = SetSubclass +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestMethodsMutating_Subclass_Set(TestMethodsMutating, unittest.TestCase): +class TestMethodsMutating_Subclass_Set(_TestMethodsMutating, __TestCase): constructor1 = SetSubclass constructor2 = set +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestMethodsMutating_Set_Dict(TestMethodsMutating, unittest.TestCase): +class TestMethodsMutating_Set_Dict(_TestMethodsMutating, __TestCase): constructor1 = set constructor2 = dict.fromkeys +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestMethodsMutating_Set_List(TestMethodsMutating, unittest.TestCase): +class TestMethodsMutating_Set_List(_TestMethodsMutating, __TestCase): constructor1 = set constructor2 = list +<<<<<<< HEAD @@ -2068,7 +2164,7 @@ def faces(G): return f @@ -833,6 +1255,21 @@ index d9102eb98a5..c8ee5ca451f 100644 @@ -2118,4 +2214,4 @@ class TestGraphs(unittest.TestCase): #============================================================================== +======= + +@@ -2068,7 +2148,7 @@ def faces(G): + return f + + +-class TestGraphs(unittest.TestCase): ++class TestGraphs(__TestCase): + + def test_cube(self): + +@@ -2118,4 +2198,4 @@ class TestGraphs(unittest.TestCase): + #============================================================================== + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": - unittest.main() + run_tests() diff --git a/test/dynamo/cpython/3_13/test_set.py b/test/dynamo/cpython/3_13/test_set.py index 1d80fccca5b13..cdc7b61e2e9a1 100644 --- a/test/dynamo/cpython/3_13/test_set.py +++ b/test/dynamo/cpython/3_13/test_set.py @@ -4,9 +4,12 @@ # ruff: noqa # flake8: noqa +<<<<<<< HEAD # Test copied from # https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_set.py +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import sys import torch import torch._dynamo.test_case @@ -315,6 +318,7 @@ def test_iterator_pickling(self): self.assertEqual(self.thetype(it), data - self.thetype((drop,))) def test_deepcopy(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Tracer: def __init__(self, value): @@ -323,6 +327,15 @@ def __hash__(self): return self.value def __deepcopy__(self, memo=None): return Tracer(self.value + 1) +======= + class Tracer: + def __init__(self, value): + self.value = value + def __hash__(self): + return self.value + def __deepcopy__(self, memo=None): + return Tracer(self.value + 1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) t = Tracer(10) s = self.thetype([t]) dup = copy.deepcopy(s) @@ -334,9 +347,14 @@ def __deepcopy__(self, memo=None): def test_gc(self): # Create a nest of cycles to exercise overall ref count check +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class A: pass +======= + class A: + pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) s = set(A() for i in range(1000)) for elem in s: elem.cycle = s @@ -345,10 +363,16 @@ class A: def test_subclass_with_custom_hash(self): # Bug #1257731 +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class H(self.thetype): def __hash__(self): return int(id(self) & 0x7fffffff) +======= + class H(self.thetype): + def __hash__(self): + return int(id(self) & 0x7fffffff) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) s=H() f=set() f.add(s) @@ -399,9 +423,14 @@ def test_do_not_rehash_dict_keys(self): def test_container_iterator(self): # Bug #3680: tp_traverse was not implemented for set iterator object +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class C(object): pass +======= + class C(object): + pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) obj = C() ref = weakref.ref(obj) container = set([obj, 1]) @@ -658,6 +687,7 @@ def test_weakref(self): self.assertRaises(ReferenceError, str, p) def test_rich_compare(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class TestRichSetCompare: def __gt__(self, some_set): @@ -672,6 +702,21 @@ def __ge__(self, some_set): def __le__(self, some_set): self.le_called = True return False +======= + class TestRichSetCompare: + def __gt__(self, some_set): + self.gt_called = True + return False + def __lt__(self, some_set): + self.lt_called = True + return False + def __ge__(self, some_set): + self.ge_called = True + return False + def __le__(self, some_set): + self.le_called = True + return False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This first tries the builtin rich set comparison, which doesn't know # how to handle the custom object. Upon returning NotImplemented, the @@ -703,31 +748,51 @@ class TestSetSubclass(TestSet): basetype = set def test_keywords_in_subclass(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class subclass(set): pass +======= + class subclass(set): + pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) u = subclass([1, 2]) self.assertIs(type(u), subclass) self.assertEqual(set(u), {1, 2}) with self.assertRaises(TypeError): subclass(sequence=()) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class subclass_with_init(set): def __init__(self, arg, newarg=None): super().__init__(arg) self.newarg = newarg +======= + class subclass_with_init(set): + def __init__(self, arg, newarg=None): + super().__init__(arg) + self.newarg = newarg +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) u = subclass_with_init([1, 2], newarg=3) self.assertIs(type(u), subclass_with_init) self.assertEqual(set(u), {1, 2}) self.assertEqual(u.newarg, 3) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class subclass_with_new(set): def __new__(cls, arg, newarg=None): self = super().__new__(cls, arg) self.newarg = newarg return self +======= + class subclass_with_new(set): + def __new__(cls, arg, newarg=None): + self = super().__new__(cls, arg) + self.newarg = newarg + return self +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) u = subclass_with_new([1, 2]) self.assertIs(type(u), subclass_with_new) self.assertEqual(set(u), {1, 2}) @@ -818,30 +883,49 @@ class TestFrozenSetSubclass(TestFrozenSet): basetype = frozenset def test_keywords_in_subclass(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class subclass(frozenset): pass +======= + class subclass(frozenset): + pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) u = subclass([1, 2]) self.assertIs(type(u), subclass) self.assertEqual(set(u), {1, 2}) with self.assertRaises(TypeError): subclass(sequence=()) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class subclass_with_init(frozenset): def __init__(self, arg, newarg=None): self.newarg = newarg +======= + class subclass_with_init(frozenset): + def __init__(self, arg, newarg=None): + self.newarg = newarg +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) u = subclass_with_init([1, 2], newarg=3) self.assertIs(type(u), subclass_with_init) self.assertEqual(set(u), {1, 2}) self.assertEqual(u.newarg, 3) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class subclass_with_new(frozenset): def __new__(cls, arg, newarg=None): self = super().__new__(cls, arg) self.newarg = newarg return self +======= + class subclass_with_new(frozenset): + def __new__(cls, arg, newarg=None): + self = super().__new__(cls, arg) + self.newarg = newarg + return self +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) u = subclass_with_new([1, 2], newarg=3) self.assertIs(type(u), subclass_with_new) self.assertEqual(set(u), {1, 2}) @@ -1907,6 +1991,7 @@ def test_iter_and_mutate(self): list(si) def test_merge_and_mutate(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class X: def __hash__(self): @@ -1914,6 +1999,14 @@ def __hash__(self): def __eq__(self, o): other.clear() return False +======= + class X: + def __hash__(self): + return hash(0) + def __eq__(self, o): + other.clear() + return False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) other = set() other = {X() for i in range(10)} @@ -1928,6 +2021,7 @@ class _TestOperationsMutating: constructor2 = None def make_sets_of_bad_objects(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Bad: def __eq__(self, other): @@ -1940,6 +2034,19 @@ def __eq__(self, other): return bool(randrange(2)) def __hash__(self): return randrange(2) +======= + class Bad: + def __eq__(self, other): + if not enabled: + return False + if randrange(20) == 0: + set1.clear() + if randrange(20) == 0: + set2.clear() + return bool(randrange(2)) + def __hash__(self): + return randrange(2) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Don't behave poorly during construction. enabled = False set1 = self.constructor1(Bad() for _ in range(randrange(50))) diff --git a/test/dynamo/cpython/3_13/test_sort.diff b/test/dynamo/cpython/3_13/test_sort.diff index 2e719655d9dfa..28fdf1c454cf4 100644 --- a/test/dynamo/cpython/3_13/test_sort.diff +++ b/test/dynamo/cpython/3_13/test_sort.diff @@ -1,17 +1,27 @@ diff --git a/test/dynamo/cpython/3_13/test_sort.py b/test/dynamo/cpython/3_13/test_sort.py +<<<<<<< HEAD index 2a7cfb7affa..4805f1fcceb 100644 --- a/test/dynamo/cpython/3_13/test_sort.py +++ b/test/dynamo/cpython/3_13/test_sort.py @@ -1,3 +1,57 @@ +======= +index 2a7cfb7affa..d661ae544b9 100644 +--- a/test/dynamo/cpython/3_13/test_sort.py ++++ b/test/dynamo/cpython/3_13/test_sort.py +@@ -1,3 +1,54 @@ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +<<<<<<< HEAD +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_sort.py + +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +import sys +import torch +import torch._dynamo.test_case @@ -60,15 +70,23 @@ index 2a7cfb7affa..4805f1fcceb 100644 from test import support import random import unittest +<<<<<<< HEAD @@ -39,7 +93,7 @@ def check(tag, expected, raw, compare=None): nerrors += 1 return +======= +@@ -39,7 +90,7 @@ def check(tag, expected, raw, compare=None): + nerrors += 1 + return + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestBase(unittest.TestCase): +class TestBase(__TestCase): def testStressfully(self): # Try a variety of sizes at and around powers of 2, and at powers of 10. sizes = [0] +<<<<<<< HEAD @@ -48,32 +102,33 @@ class TestBase(unittest.TestCase): sizes.extend(range(n-1, n+2)) sizes.extend([10, 100, 1000]) @@ -210,11 +228,36 @@ index 2a7cfb7affa..4805f1fcceb 100644 self.assertIs(opt, ref) #note: not assertEqual! We want to ensure *identical* behavior. +======= +@@ -151,7 +202,7 @@ class TestBase(unittest.TestCase): + self.assertEqual(forced, native) + #============================================================================== + +-class TestBugs(unittest.TestCase): ++class TestBugs(__TestCase): + + def test_bug453523(self): + # bug 453523 -- list.sort() crasher. +@@ -188,7 +239,7 @@ class TestBugs(unittest.TestCase): + + #============================================================================== + +-class TestDecorateSortUndecorate(unittest.TestCase): ++class TestDecorateSortUndecorate(__TestCase): + + def test_decorated(self): + data = 'The quick Brown fox Jumped over The lazy Dog'.split() +@@ -309,7 +360,7 @@ def check_against_PyObject_RichCompareBool(self, L): + self.assertIs(opt, ref) + #note: not assertEqual! We want to ensure *identical* behavior. + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -class TestOptimizedCompares(unittest.TestCase): +class TestOptimizedCompares(__TestCase): def test_safe_object_compare(self): heterogeneous_lists = [[0, 'foo'], [0.0, 'foo'], +<<<<<<< HEAD @@ -331,17 +389,18 @@ class TestOptimizedCompares(unittest.TestCase): # This test is by ppperry. It ensures that unsafe_object_compare is # verifying ms->key_richcompare == tp->richcompare before comparing. @@ -260,6 +303,11 @@ index 2a7cfb7affa..4805f1fcceb 100644 @@ -408,4 +468,4 @@ class TestOptimizedCompares(unittest.TestCase): #============================================================================== +======= +@@ -408,4 +459,4 @@ class TestOptimizedCompares(unittest.TestCase): + #============================================================================== + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": - unittest.main() + run_tests() diff --git a/test/dynamo/cpython/3_13/test_sort.py b/test/dynamo/cpython/3_13/test_sort.py index ab9f094cab1b3..3fdf07d067e76 100644 --- a/test/dynamo/cpython/3_13/test_sort.py +++ b/test/dynamo/cpython/3_13/test_sort.py @@ -4,9 +4,12 @@ # ruff: noqa # flake8: noqa +<<<<<<< HEAD # Test copied from # https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_sort.py +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import sys import torch import torch._dynamo.test_case @@ -102,6 +105,7 @@ def testStressfully(self): sizes.extend(range(n-1, n+2)) sizes.extend([10, 100, 1000]) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class Complains(object): maybe_complain = True @@ -129,6 +133,34 @@ def __lt__(self, other): def __repr__(self): return "Stable(%d, %d)" % (self.key, self.index) +======= + class Complains(object): + maybe_complain = True + + def __init__(self, i): + self.i = i + + def __lt__(self, other): + if Complains.maybe_complain and random.random() < 0.001: + if verbose: + print(" complaining at", self, other) + raise RuntimeError + return self.i < other.i + + def __repr__(self): + return "Complains(%d)" % self.i + + class Stable(object): + def __init__(self, key, i): + self.key = key + self.index = i + + def __lt__(self, other): + return self.key < other.key + + def __repr__(self): + return "Stable(%d, %d)" % (self.key, self.index) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for n in sizes: x = list(range(n)) @@ -213,6 +245,7 @@ def test_bug453523(self): # If this fails, the most likely outcome is a core dump. # Mutations during a list sort should raise a ValueError. +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class C: def __lt__(self, other): @@ -221,6 +254,15 @@ def __lt__(self, other): else: L.append(3) return random.random() < 0.5 +======= + class C: + def __lt__(self, other): + if L and random.random() < 0.75: + L.pop() + else: + L.append(3) + return random.random() < 0.5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) L = [C() for i in range(50)] self.assertRaises(ValueError, L.sort) @@ -284,6 +326,7 @@ def k(x): def test_key_with_mutating_del(self): data = list(range(10)) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class SortKiller(object): def __init__(self, x): @@ -293,11 +336,22 @@ def __del__(self): data[:] = range(20) def __lt__(self, other): return id(self) < id(other) +======= + class SortKiller(object): + def __init__(self, x): + pass + def __del__(self): + del data[:] + data[:] = range(20) + def __lt__(self, other): + return id(self) < id(other) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertRaises(ValueError, data.sort, key=SortKiller) def test_key_with_mutating_del_and_exception(self): data = list(range(10)) ## dup = data[:] +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class SortKiller(object): def __init__(self, x): @@ -306,6 +360,15 @@ def __init__(self, x): def __del__(self): del data[:] data[:] = list(range(20)) +======= + class SortKiller(object): + def __init__(self, x): + if x > 2: + raise RuntimeError + def __del__(self): + del data[:] + data[:] = list(range(20)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertRaises(RuntimeError, data.sort, key=SortKiller) ## major honking subtlety: we *can't* do: ## @@ -389,6 +452,7 @@ def test_unsafe_object_compare(self): # This test is by ppperry. It ensures that unsafe_object_compare is # verifying ms->key_richcompare == tp->richcompare before comparing. +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class WackyComparator(int): def __lt__(self, other): @@ -401,6 +465,19 @@ class WackyList1(list): class WackyList2(list): def __lt__(self, other): raise ValueError +======= + class WackyComparator(int): + def __lt__(self, other): + elem.__class__ = WackyList2 + return int.__lt__(self, other) + + class WackyList1(list): + pass + + class WackyList2(list): + def __lt__(self, other): + raise ValueError +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) L = [WackyList1([WackyComparator(i), i]) for i in range(10)] elem = L[-1] @@ -414,10 +491,16 @@ def __lt__(self, other): # The following test is also by ppperry. It ensures that # unsafe_object_compare handles Py_NotImplemented appropriately. +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class PointlessComparator: def __lt__(self, other): return NotImplemented +======= + class PointlessComparator: + def __lt__(self, other): + return NotImplemented +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) L = [PointlessComparator(), PointlessComparator()] self.assertRaises(TypeError, L.sort) self.assertRaises(TypeError, [(x,) for x in L].sort) diff --git a/test/dynamo/cpython/3_13/test_tuple.diff b/test/dynamo/cpython/3_13/test_tuple.diff index d7ae3af2a2c82..86830a95bd054 100644 --- a/test/dynamo/cpython/3_13/test_tuple.diff +++ b/test/dynamo/cpython/3_13/test_tuple.diff @@ -1,8 +1,15 @@ diff --git a/test/dynamo/cpython/3_13/test_tuple.py b/test/dynamo/cpython/3_13/test_tuple.py +<<<<<<< HEAD index 9ce80c5e8ea..1080e85e31a 100644 --- a/test/dynamo/cpython/3_13/test_tuple.py +++ b/test/dynamo/cpython/3_13/test_tuple.py @@ -1,4 +1,58 @@ +======= +index 9ce80c5e8ea..e52c0cbc140 100644 +--- a/test/dynamo/cpython/3_13/test_tuple.py ++++ b/test/dynamo/cpython/3_13/test_tuple.py +@@ -1,4 +1,55 @@ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -from test import support, seq_tests +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] @@ -10,9 +17,12 @@ index 9ce80c5e8ea..1080e85e31a 100644 +# ruff: noqa +# flake8: noqa + +<<<<<<< HEAD +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_tuple.py + +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +import sys +import torch +import torch._dynamo.test_case @@ -60,6 +70,7 @@ index 9ce80c5e8ea..1080e85e31a 100644 +from test import support +import seq_tests import unittest +<<<<<<< HEAD import gc @@ -43,27 +97,30 @@ class TupleTest(seq_tests.CommonTest): @@ -128,6 +139,13 @@ index 9ce80c5e8ea..1080e85e31a 100644 @@ -510,4 +569,4 @@ class TupleTest(seq_tests.CommonTest): # pileup 262,143 mean 8.0 coll 262,143 z +92683.6 +======= + + import gc +@@ -510,4 +561,4 @@ class TupleTest(seq_tests.CommonTest): + # pileup 262,143 mean 8.0 coll 262,143 z +92683.6 + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": - unittest.main() + run_tests() diff --git a/test/dynamo/cpython/3_13/test_tuple.py b/test/dynamo/cpython/3_13/test_tuple.py index 914e3443f2874..cce32f814c5b1 100644 --- a/test/dynamo/cpython/3_13/test_tuple.py +++ b/test/dynamo/cpython/3_13/test_tuple.py @@ -4,9 +4,12 @@ # ruff: noqa # flake8: noqa +<<<<<<< HEAD # Test copied from # https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_tuple.py +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import sys import torch import torch._dynamo.test_case @@ -97,30 +100,49 @@ def test_keyword_args(self): tuple(sequence=()) def test_keywords_in_subclass(self): +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class subclass(tuple): pass +======= + class subclass(tuple): + pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) u = subclass([1, 2]) self.assertIs(type(u), subclass) self.assertEqual(list(u), [1, 2]) with self.assertRaises(TypeError): subclass(sequence=()) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class subclass_with_init(tuple): def __init__(self, arg, newarg=None): self.newarg = newarg +======= + class subclass_with_init(tuple): + def __init__(self, arg, newarg=None): + self.newarg = newarg +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) u = subclass_with_init([1, 2], newarg=3) self.assertIs(type(u), subclass_with_init) self.assertEqual(list(u), [1, 2]) self.assertEqual(u.newarg, 3) +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class subclass_with_new(tuple): def __new__(cls, arg, newarg=None): self = super().__new__(cls, arg) self.newarg = newarg return self +======= + class subclass_with_new(tuple): + def __new__(cls, arg, newarg=None): + self = super().__new__(cls, arg) + self.newarg = newarg + return self +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) u = subclass_with_new([1, 2], newarg=3) self.assertIs(type(u), subclass_with_new) self.assertEqual(list(u), [1, 2]) @@ -408,9 +430,14 @@ def test_track_dynamic(self): @support.cpython_only def test_track_subtypes(self): # Tuple subtypes must always be tracked +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class MyTuple(tuple): pass +======= + class MyTuple(tuple): + pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.check_track_dynamic(MyTuple, True) @support.cpython_only @@ -462,8 +489,12 @@ def test_no_comdat_folding(self): # Issue 8847: In the PGO build, the MSVC linker's COMDAT folding # optimization causes failures in code that relies on distinct # function addresses. +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class T(tuple): pass +======= + class T(tuple): pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with self.assertRaises(TypeError): [3,] + T((1,2)) diff --git a/test/dynamo/cpython/3_13/test_userdict.diff b/test/dynamo/cpython/3_13/test_userdict.diff index 8b8101ae9091d..191ef756deb40 100644 --- a/test/dynamo/cpython/3_13/test_userdict.diff +++ b/test/dynamo/cpython/3_13/test_userdict.diff @@ -1,17 +1,27 @@ diff --git a/test/dynamo/cpython/3_13/test_userdict.py b/test/dynamo/cpython/3_13/test_userdict.py +<<<<<<< HEAD index 61e79f553e8..75b789633ed 100644 --- a/test/dynamo/cpython/3_13/test_userdict.py +++ b/test/dynamo/cpython/3_13/test_userdict.py @@ -1,3 +1,57 @@ +======= +index 61e79f553e8..c953390355e 100644 +--- a/test/dynamo/cpython/3_13/test_userdict.py ++++ b/test/dynamo/cpython/3_13/test_userdict.py +@@ -1,3 +1,54 @@ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +<<<<<<< HEAD +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_userdict.py + +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +import sys +import torch +import torch._dynamo.test_case @@ -60,7 +70,11 @@ index 61e79f553e8..75b789633ed 100644 # Check every path through every method of UserDict from test import mapping_tests, support +<<<<<<< HEAD @@ -215,10 +269,10 @@ class UserDictTest(mapping_tests.TestHashMappingProtocol): +======= +@@ -215,10 +266,10 @@ class UserDictTest(mapping_tests.TestHashMappingProtocol): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Decorate existing test with recursion limit, because # the test is for C structure, but `UserDict` is a Python structure. diff --git a/test/dynamo/cpython/3_13/test_userdict.py b/test/dynamo/cpython/3_13/test_userdict.py index 75b789633edf0..5ea65afc57be8 100644 --- a/test/dynamo/cpython/3_13/test_userdict.py +++ b/test/dynamo/cpython/3_13/test_userdict.py @@ -4,9 +4,12 @@ # ruff: noqa # flake8: noqa +<<<<<<< HEAD # Test copied from # https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_userdict.py +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import sys import torch import torch._dynamo.test_case diff --git a/test/dynamo/cpython/3_13/test_userlist.diff b/test/dynamo/cpython/3_13/test_userlist.diff index 77e951de5fad6..cdaea0820f4fd 100644 --- a/test/dynamo/cpython/3_13/test_userlist.diff +++ b/test/dynamo/cpython/3_13/test_userlist.diff @@ -1,17 +1,27 @@ diff --git a/test/dynamo/cpython/3_13/test_userlist.py b/test/dynamo/cpython/3_13/test_userlist.py +<<<<<<< HEAD index 312702c8e39..d3d8dbf394a 100644 --- a/test/dynamo/cpython/3_13/test_userlist.py +++ b/test/dynamo/cpython/3_13/test_userlist.py @@ -1,7 +1,61 @@ +======= +index 312702c8e39..a4532922f5d 100644 +--- a/test/dynamo/cpython/3_13/test_userlist.py ++++ b/test/dynamo/cpython/3_13/test_userlist.py +@@ -1,7 +1,58 @@ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +<<<<<<< HEAD +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_userlist.py + +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) +import sys +import torch +import torch._dynamo.test_case @@ -58,12 +68,17 @@ index 312702c8e39..d3d8dbf394a 100644 +# ======= END DYNAMO PATCH ======= + # Check every path through every method of UserList +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from collections import UserList -from test import list_tests +import list_tests import unittest from test import support +<<<<<<< HEAD @@ -56,9 +110,10 @@ class UserListTest(list_tests.CommonTest): @@ -81,6 +96,11 @@ index 312702c8e39..d3d8dbf394a 100644 def test_userlist_copy(self): @@ -69,9 +124,9 @@ class UserListTest(list_tests.CommonTest): +======= + +@@ -69,9 +120,9 @@ class UserListTest(list_tests.CommonTest): + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Decorate existing test with recursion limit, because # the test is for C structure, but `UserList` is a Python structure. - test_repr_deep = support.infinite_recursion(25)( @@ -89,7 +109,11 @@ index 312702c8e39..d3d8dbf394a 100644 + # test_repr_deep = support.infinite_recursion(25)( + # list_tests.CommonTest.test_repr_deep, + # ) +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": - unittest.main() + run_tests() diff --git a/test/dynamo/cpython/3_13/test_userlist.py b/test/dynamo/cpython/3_13/test_userlist.py index 9bd988c458836..31fa162deb262 100644 --- a/test/dynamo/cpython/3_13/test_userlist.py +++ b/test/dynamo/cpython/3_13/test_userlist.py @@ -4,9 +4,12 @@ # ruff: noqa # flake8: noqa +<<<<<<< HEAD # Test copied from # https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_userlist.py +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import sys import torch import torch._dynamo.test_case @@ -110,10 +113,16 @@ def test_mixedadd(self): def test_getitemoverwriteiter(self): # Verify that __getitem__ overrides *are* recognized by __iter__ +<<<<<<< HEAD with torch._dynamo.error_on_graph_break(False): class T(self.type2test): def __getitem__(self, key): return str(key) + '!!!' +======= + class T(self.type2test): + def __getitem__(self, key): + return str(key) + '!!!' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(next(iter(T((1,2)))), "0!!!") def test_userlist_copy(self): diff --git a/test/dynamo/test_activation_checkpointing.py b/test/dynamo/test_activation_checkpointing.py index 8fe89e84546b5..f15e94c8332d2 100644 --- a/test/dynamo/test_activation_checkpointing.py +++ b/test/dynamo/test_activation_checkpointing.py @@ -17,10 +17,20 @@ from torch._dynamo.backends.common import aot_autograd from torch._dynamo.testing import CompileCounterWithBackend from torch._higher_order_ops.wrap import tag_activation_checkpoint +<<<<<<< HEAD from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu, skipIfRocm from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON from torch.testing._internal.triton_utils import requires_cuda_and_triton +======= +from torch.testing._internal.common_cuda import ( + PLATFORM_SUPPORTS_CUDNN_ATTENTION, + SM90OrLater, +) +from torch.testing._internal.common_device_type import instantiate_device_type_tests +from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu, skipIfRocm +from torch.testing._internal.inductor_utils import HAS_CUDA +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.two_tensor import TwoTensor from torch.utils.checkpoint import ( checkpoint, @@ -29,6 +39,7 @@ ) +<<<<<<< HEAD if HAS_CUDA_AND_TRITON: import triton from triton import language as tl @@ -49,6 +60,9 @@ def add_one_kernel( tl.store(out_ptr + offsets, output, mask=mask) +======= +requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) requires_distributed = functools.partial( unittest.skipIf, not dist.is_available(), "requires distributed" ) @@ -203,6 +217,7 @@ def _compare_orig_and_checkpointed_fns( # The original version and the checkpointed version of the same function # should produce the same outputs and the same gradients under torch.compile. +<<<<<<< HEAD def clone_args(args): cloned_args = [] for arg in args: @@ -272,6 +287,48 @@ def runtime_wrapper(*runtime_args): return runtime_wrapper run(export_compiler) +======= + # Run original version + cloned_args_orig_fn = [] + for arg in args: + cloned_args_orig_fn.append( + arg.detach().clone().requires_grad_(arg.requires_grad) + ) + torch.manual_seed(0) + compiled_orig_fn = torch.compile( + orig_fn, fullgraph=fullgraph, backend="inductor" + ) + result_orig_fn = compiled_orig_fn(*cloned_args_orig_fn) + result_orig_fn.sum().backward() + + # Run checkpointed version + cloned_args_checkpointed_fn = [] + for arg in args: + cloned_args_checkpointed_fn.append( + arg.detach().clone().requires_grad_(arg.requires_grad) + ) + torch.manual_seed(0) + compiled_checkpointed_fn = torch.compile( + checkpointed_fn, fullgraph=fullgraph, backend="inductor" + ) + result_checkpointed_fn = compiled_checkpointed_fn(*cloned_args_checkpointed_fn) + result_checkpointed_fn.sum().backward() + + # Check that outputs and gradients are equal + self.assertEqual( + result_orig_fn, + result_checkpointed_fn, + msg="Output mismatch between the original version and the checkpointed version of the same function", + ) + for cloned_arg_orig_fn, cloned_arg_checkpointed_fn in zip( + cloned_args_orig_fn, cloned_args_checkpointed_fn + ): + self.assertEqual( + cloned_arg_orig_fn.grad, + cloned_arg_checkpointed_fn.grad, + msg="Gradient mismatch between the original version and the checkpointed version of the same function", + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_tags_function(self, device): def gn(x, y): @@ -292,7 +349,11 @@ def fn(x, y): backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) self._validate(fn, backend, x, y) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_tags_function_via_global_checkpoint(self, device): def gn(x, y): return torch.sigmoid(torch.matmul(x, y)) @@ -311,7 +372,11 @@ def fn(x, y): backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) self._validate(fn, backend, x, y) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_tags_function_with_kwargs(self, device): def gn(x, y): return torch.sigmoid(torch.matmul(x, y)) @@ -331,7 +396,11 @@ def fn(x, y): backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) self._validate(fn, backend, x, y) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_tags_sequential_layers(self, device): def gn(x): x = x.cos() @@ -356,7 +425,11 @@ def fn(x): backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) self._validate(fn, backend, x) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_tags_multiple_checkpoints(self, device): def gn(x, y): return torch.sigmoid(torch.matmul(x, y)) @@ -378,7 +451,11 @@ def fn(x, y): backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) self._validate(fn, backend, x, y) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_tags_module(self, device): class MockModule(torch.nn.Module): def __init__(self) -> None: @@ -406,7 +483,11 @@ def fn(x): backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) self._validate(fn, backend, x) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_tags_decomps(self, device): # Ensures that tags are passed on through decompositions as well class MockModule(torch.nn.Module): @@ -441,7 +522,11 @@ def fn(x): ) self._validate(fn, backend, x) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._inductor.config.patch(fallback_random=True) def test_tags_recomputed_rand(self, device): def gn(x, y): @@ -465,7 +550,11 @@ def fn(x, y): backend = "inductor" self._validate(fn, backend, x, y) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._inductor.config.patch(fallback_random=True) def test_tags_rand(self, device): def gn(x, y): @@ -492,7 +581,11 @@ def fn(x, y): backend = "inductor" self._validate(fn, backend, x, y) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._inductor.config.patch(fallback_random=True) def test_tags_dropout(self, device): # Figure out a way to test the number of inductor_random calls @@ -600,7 +693,11 @@ def _factory_fn(): Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no_primal}.""", ) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_fallback(self, device): def gn(x, y): torch._dynamo.graph_break() @@ -628,7 +725,11 @@ def fn(x, y): self.assertEqual(cnt.op_count, 2) self.assertEqual(len(cnt.graphs), 2) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_kwargs(self, device): def gn(x, y, z=None): a = torch.matmul(x, y) @@ -662,7 +763,11 @@ def fn(x, y, z): body_function = getattr(cnt.graphs[0], wrap_node.args[0].name) self.assertEqual(op_count(body_function), 2) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_symints_location(self, device): def gn(x, y): return torch.matmul(x, torch.nn.functional.dropout(y, 0.5)) @@ -692,7 +797,11 @@ def fn(x, y): wrap_node = find_first_node(cnt.graphs[0], tag_activation_checkpoint) self.assertEqual(len(wrap_node.args), 3) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") def test_compile_selective_checkpoint_must_recompute(self, device): def context_fn_must_recompute_mm(): @@ -759,7 +868,11 @@ def fn(x): ), ) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") def test_compile_selective_checkpoint_must_not_recompute_gemm(self, device): def selective_checkpointing_context_fn(): @@ -806,6 +919,7 @@ def fn(x, y): self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) +<<<<<<< HEAD @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") def test_compile_selective_checkpoint_triton_kernel(self, device): @@ -874,6 +988,9 @@ def fn(x, y): self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") def test_compile_selective_checkpoint_tensor_subclass(self, device): def selective_checkpointing_context_fn(): @@ -923,7 +1040,11 @@ def fn(x, y): self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") def test_compile_selective_checkpoint_custom_rule(self, device): def _get_custom_policy(meta): @@ -988,7 +1109,11 @@ def fn(x, y): self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") def test_compile_selective_checkpoint_partial_ctx_fn(self, device): def selective_checkpointing_context_fn(no_recompute_list): @@ -1034,7 +1159,11 @@ def fn(x, y): self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") def test_compile_selective_checkpoint_outplace_op(self, device): def selective_checkpointing_context_fn(): @@ -1079,7 +1208,11 @@ def fn(x, y): self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") def test_compile_selective_checkpoint_list_ops(self, device): def selective_checkpointing_context_fn(): @@ -1127,7 +1260,11 @@ def fn(x, y): "In-place op support in selective checkpointing + torch.compile " "requires TorchDispatchMode + torch.compile work to complete" ) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_compile_selective_checkpoint_inplace_op(self, device): def selective_checkpointing_context_fn(): no_recompute_list = [ @@ -1173,7 +1310,11 @@ def fn(x, y): self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") @torch._inductor.config.patch(fallback_random=True) def test_compile_selective_checkpoint_random_op(self, device): @@ -1233,7 +1374,11 @@ def fn(x): self._validate(fn, backend, x, skip_check=not preserve_rng_state) self._compare_orig_and_checkpointed_fns(gn, fn, x) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") def test_compile_selective_checkpoint_invalid_context(self): def gn(x, y): @@ -1271,7 +1416,11 @@ def fn(x, y): ): self._validate(fn, backend, x, y) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True) def test_compile_selective_checkpoint_parametrization(self): def sac_policy(): @@ -1365,7 +1514,11 @@ def reset_parameters(self): self.assertEqual(input.grad, input_compiled.grad) @skipIfRocm +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_autocast_flash_attention(self, device): def fn(primals_1, primals_2, primals_3): return torch.ops.aten._scaled_dot_product_efficient_attention.default( @@ -1389,7 +1542,11 @@ def gn(*args): res = opt_gn(*args) self.assertEqual(ref, res) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_error_msg(self, device): class MockModule(torch.nn.Module): def __init__(self) -> None: @@ -1413,7 +1570,11 @@ def fn(x): ): opt_fn(x) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_list_inputs(self, device): class MockModule(torch.nn.Module): def __init__(self) -> None: @@ -1438,7 +1599,11 @@ def fn(x, ys): res = opt_fn(x, [y, z]) self.assertEqual(ref, res) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_pattern_matcher(self, device): # Check that the sdpa op is recomputed in the backward graph # tests percolate_tags @@ -1480,15 +1645,25 @@ def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs): opt_fn = torch.compile(fn, backend=backend, fullgraph=True) opt_fn(*args1).sum().backward() +<<<<<<< HEAD fwd_graph = aot_graphs[0] op1 = torch.ops.aten._scaled_dot_product_flash_attention.default op2 = torch.ops.aten._scaled_dot_product_cudnn_attention.default +======= + if PLATFORM_SUPPORTS_CUDNN_ATTENTION and SM90OrLater: + op = torch.ops.aten._scaled_dot_product_cudnn_attention.default + else: + op = torch.ops.aten._scaled_dot_product_flash_attention.default + + fwd_graph = aot_graphs[0] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue( count_ops( fwd_graph, [], freq=1, +<<<<<<< HEAD op=op1, ) or count_ops( @@ -1498,6 +1673,12 @@ def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs): op=op2, ) ) +======= + op=op, + ) + ) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bwd_graph = aot_graphs[1] # Check that sin is not recomputed in the backward graph - checks percolate tags self.assertTrue(count_ops(bwd_graph, [], freq=0, op=torch.ops.aten.sin.default)) @@ -1507,6 +1688,7 @@ def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs): bwd_graph, [], freq=1, +<<<<<<< HEAD op=op1, ) or count_ops( @@ -1514,11 +1696,18 @@ def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs): [], freq=1, op=op2, +======= + op=op, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) ) @requires_distributed() +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_distributed_utils_checkpoint_wrapper(self): from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( checkpoint_wrapper as dist_checkpoint_wrapper, @@ -1544,7 +1733,11 @@ def forward(self, x): self.assertEqual(ref, res) @requires_distributed() +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True) def test_dynamo_does_not_trace_getattr_as_top_frame(self): # inline_inbuilt_nn_modules is a proxy to emulate what FSDP tests do. diff --git a/test/dynamo/test_aot_autograd.py b/test/dynamo/test_aot_autograd.py index 8b2740596a72d..aac581bc7321d 100644 --- a/test/dynamo/test_aot_autograd.py +++ b/test/dynamo/test_aot_autograd.py @@ -1213,7 +1213,11 @@ def fn(x): @torch._functorch.config.patch(donated_buffer=True) def test_donated_buffer1(self): +<<<<<<< HEAD logger_name = "torch._functorch._aot_autograd.graph_compile" +======= + logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch.compile() def relu(x): @@ -1233,9 +1237,15 @@ def relu(x): @torch._functorch.config.patch("donated_buffer", True) def test_donated_buffer2(self): +<<<<<<< HEAD logger_name = "torch._functorch._aot_autograd.graph_compile" # we will reuse the graph for g across f1 and f2 +======= + logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers" + + # we will re-use the graph for g across f1 and f2 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch.compile() def g(activation, param2): return torch.matmul(activation, param2) @@ -1255,9 +1265,15 @@ def f(inp, param1, param2): @torch._functorch.config.patch("donated_buffer", True) def test_donated_buffer3(self): +<<<<<<< HEAD logger_name = "torch._functorch._aot_autograd.graph_compile" # we will reuse the graph for g across f1 and f2 +======= + logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers" + + # we will re-use the graph for g across f1 and f2 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch.compile() def g(activation, param2): return torch.matmul(activation, param2) @@ -1278,7 +1294,11 @@ def f(inp, param1, param2): @torch._functorch.config.patch("donated_buffer", True) def test_donated_buffer4(self): +<<<<<<< HEAD logger_name = "torch._functorch._aot_autograd.graph_compile" +======= + logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class Mod(torch.nn.Module): def __init__(self) -> None: @@ -1309,7 +1329,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @torch._functorch.config.patch("donated_buffer", True) def test_donated_buffer5(self): +<<<<<<< HEAD logger_name = "torch._functorch._aot_autograd.graph_compile" +======= + logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch.compile() def f(x, z): @@ -1339,7 +1363,10 @@ def f(x, z): FileCheck().check("bw_donated_idxs=[1]").run("\n".join(captured.output)) @torch._functorch.config.patch("donated_buffer", True) +<<<<<<< HEAD @torch._dynamo.config.patch("graph_break_on_nn_param_ctor", False) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_donated_buffer6(self): if is_dynamic_shape_test(self._testMethodName): # parameters should not be dynamic shape @@ -1347,7 +1374,11 @@ def test_donated_buffer6(self): # SymNodeVariable() is not a constant return +<<<<<<< HEAD logger_name = "torch._functorch._aot_autograd.graph_compile" +======= + logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def fn(x): p = torch.nn.Parameter(x + 123) diff --git a/test/dynamo/test_aot_autograd_cache.py b/test/dynamo/test_aot_autograd_cache.py index 04af76c90c529..65548877554c6 100644 --- a/test/dynamo/test_aot_autograd_cache.py +++ b/test/dynamo/test_aot_autograd_cache.py @@ -37,7 +37,11 @@ skipIfWindows, ) from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, requires_triton +<<<<<<< HEAD from torch.testing._internal.triton_utils import requires_cuda_and_triton +======= +from torch.testing._internal.triton_utils import requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.two_tensor import TwoTensor @@ -175,6 +179,7 @@ def fn(x, y): if hasattr(a, "_dynamo_weak_dynamic_indices"): del a._dynamo_weak_dynamic_indices self.assertEqual(eager_result, compiled_result) +<<<<<<< HEAD if functorch_config.bundled_autograd_cache: self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0) self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) @@ -183,6 +188,11 @@ def fn(x, y): self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2) self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 0) +======= + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) + self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 0) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1) @@ -195,10 +205,14 @@ def fn(x, y): autotune_expect = 2 if device == GPU_TYPE else 0 +<<<<<<< HEAD if functorch_config.bundled_autograd_cache: self.assertEqual(len(cache_info.inductor_artifacts), 0) else: self.assertEqual(len(cache_info.inductor_artifacts), 2) +======= + self.assertEqual(len(cache_info.inductor_artifacts), 2) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(len(cache_info.autotune_artifacts), autotune_expect) self.assertEqual(len(cache_info.aot_autograd_artifacts), 1) self.assertEqual(len(cache_info.pgo_artifacts), 0) @@ -216,6 +230,7 @@ def fn(x, y): compiled_result.sum().backward() if hasattr(a, "_dynamo_weak_dynamic_indices"): del a._dynamo_weak_dynamic_indices +<<<<<<< HEAD if functorch_config.bundled_autograd_cache: self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0) self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) @@ -224,6 +239,11 @@ def fn(x, y): self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 4) self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 0) +======= + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 4) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) + self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 0) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2) self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 2) @@ -236,10 +256,15 @@ def fn(x, y): # Hot load and hit with fresh_cache(): cache_info = torch.compiler.load_cache_artifacts(artifact_bytes) +<<<<<<< HEAD if functorch_config.bundled_autograd_cache: self.assertEqual(len(cache_info.inductor_artifacts), 0) else: self.assertEqual(len(cache_info.inductor_artifacts), 2) +======= + + self.assertEqual(len(cache_info.inductor_artifacts), 2) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(len(cache_info.autotune_artifacts), autotune_expect) self.assertEqual(len(cache_info.aot_autograd_artifacts), 1) self.assertEqual(len(cache_info.pgo_artifacts), 0) @@ -250,12 +275,17 @@ def fn(x, y): if hasattr(a, "_dynamo_weak_dynamic_indices"): del a._dynamo_weak_dynamic_indices self.assertEqual(eager_result, compiled_result) +<<<<<<< HEAD if functorch_config.bundled_autograd_cache: self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0) self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) else: self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 4) self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 2) +======= + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 4) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 2) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 2) self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2) self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1) @@ -295,6 +325,7 @@ def fn(x, y): @inductor_config.patch("fx_graph_remote_cache", False) @inductor_config.patch("fx_graph_cache", True) @functorch_config.patch({"enable_autograd_cache": True}) +<<<<<<< HEAD def test_vmap(self): """ make @@ -345,6 +376,8 @@ def fn(x, y): @inductor_config.patch("fx_graph_remote_cache", False) @inductor_config.patch("fx_graph_cache", True) @functorch_config.patch({"enable_autograd_cache": True}) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_multi_graph_specialization(self): """ Verify multi graph specializations all cache hit @@ -486,6 +519,7 @@ def fn(x, y): @inductor_config.patch("fx_graph_remote_cache", False) @inductor_config.patch("fx_graph_cache", True) +<<<<<<< HEAD @functorch_config.patch({"enable_autograd_cache": True}) @functorch_config.patch({"strict_autograd_cache": True}) @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") @@ -520,6 +554,16 @@ def fn(x, y): {"enable_autograd_cache": True, "view_replay_for_aliased_outputs": True} ) def test_view_replay(self): +======= + @functorch_config.patch( + {"enable_autograd_cache": True, "view_replay_for_aliased_outputs": True} + ) + def test_view_replay_bypass(self): + """ + Shoud bypass when view replay is turned on + """ + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def fn(a): tmp = a.detach() a.mul_(2) @@ -527,6 +571,7 @@ def fn(a): with torch.autograd._force_original_view_tracking(True): compiled_fn = torch.compile(fn) +<<<<<<< HEAD def run_and_check(miss, hit, bypass): self._clear_dynamo_and_codecache() @@ -546,6 +591,12 @@ def run_and_check(miss, hit, bypass): run_and_check(miss=1, hit=0, bypass=0) run_and_check(miss=1, hit=1, bypass=0) run_and_check(miss=1, hit=2, bypass=0) +======= + compiled_fn(torch.rand(2, 3)) + + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) + self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @inductor_config.patch("fx_graph_remote_cache", False) @inductor_config.patch("fx_graph_cache", True) @@ -751,7 +802,11 @@ def fn(a, b): self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1) self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @inductor_config.patch("fx_graph_remote_cache", False) @inductor_config.patch("fx_graph_cache", True) @functorch_config.patch({"enable_autograd_cache": True}) @@ -807,7 +862,11 @@ def backward(ctx, grad_output): self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @inductor_config.patch("fx_graph_remote_cache", False) @inductor_config.patch("fx_graph_cache", True) @functorch_config.patch({"enable_autograd_cache": True}) @@ -849,7 +908,12 @@ def fn(a): self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1) self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda + @requires_triton() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @inductor_config.patch("fx_graph_remote_cache", False) @inductor_config.patch("fx_graph_cache", True) @functorch_config.patch({"enable_autograd_cache": True}) @@ -882,7 +946,11 @@ def backward(ctx, grad_output): def fn(a): return MyAutogradFunction.apply(a) +<<<<<<< HEAD a = torch.randn(5, device=GPU_TYPE, requires_grad=True) +======= + a = torch.randn(5, device="cuda", requires_grad=True) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) a2 = a.clone().detach_().requires_grad_(True) compiled_fn = torch.compile(fn, backend="inductor") result = compiled_fn(a) @@ -902,6 +970,7 @@ def fn(a): self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1) self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0) +<<<<<<< HEAD @requires_cuda_and_triton @inductor_config.patch("fx_graph_remote_cache", False) @inductor_config.patch("fx_graph_cache", True) @@ -1110,6 +1179,8 @@ def my_triton_op2(x: torch.Tensor) -> torch.Tensor: # noqa: F811 self.assertEqual(fn(a2), result) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @inductor_config.patch("fx_graph_remote_cache", False) @inductor_config.patch({"fx_graph_cache": True}) @functorch_config.patch({"enable_autograd_cache": True}) @@ -1528,7 +1599,11 @@ def f(): result = f() self.assertEqual(result[0].device, torch.device("cuda:1")) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @inductor_config.patch("fx_graph_cache", True) @inductor_config.patch("fx_graph_remote_cache", False) @functorch_config.patch({"enable_autograd_cache": True}) @@ -1814,6 +1889,7 @@ def fn(x, y): self.assertEqual(eager_result, compiled_result) self.assertEqual(expected_grads[0], actual_grads[0]) self.assertEqual(expected_grads[1], actual_grads[1]) +<<<<<<< HEAD if functorch_config.bundled_autograd_cache: self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0) self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) @@ -1822,6 +1898,11 @@ def fn(x, y): self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 3) self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 0) +======= + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 3) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) + self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 0) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1) @@ -1835,10 +1916,14 @@ def fn(x, y): autotune_expect = 2 if device == GPU_TYPE else 0 +<<<<<<< HEAD if functorch_config.bundled_autograd_cache: self.assertEqual(len(cache_info.inductor_artifacts), 0) else: self.assertEqual(len(cache_info.inductor_artifacts), 3) +======= + self.assertEqual(len(cache_info.inductor_artifacts), 3) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(len(cache_info.autotune_artifacts), autotune_expect) self.assertEqual(len(cache_info.aot_autograd_artifacts), 1) self.assertEqual(len(cache_info.pgo_artifacts), 0) @@ -1852,10 +1937,14 @@ def fn(x, y): with fresh_cache(): cache_info = torch.compiler.load_cache_artifacts(artifact_bytes) +<<<<<<< HEAD if functorch_config.bundled_autograd_cache: self.assertEqual(len(cache_info.inductor_artifacts), 0) else: self.assertEqual(len(cache_info.inductor_artifacts), 3) +======= + self.assertEqual(len(cache_info.inductor_artifacts), 3) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(len(cache_info.autotune_artifacts), autotune_expect) self.assertEqual(len(cache_info.aot_autograd_artifacts), 1) self.assertEqual(len(cache_info.pgo_artifacts), 0) @@ -1879,6 +1968,7 @@ def fn(x, y): if i == 0: # initial compile +<<<<<<< HEAD if functorch_config.bundled_autograd_cache: self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0) self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) @@ -1888,6 +1978,13 @@ def fn(x, y): self.assertEqual( counters["inductor"]["fxgraph_lookup_write_file"], 3 ) +======= + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 3) + self.assertEqual( + counters["inductor"]["fxgraph_lookup_write_file"], 3 + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 0) self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1) self.assertEqual( diff --git a/test/dynamo/test_autograd_function.py b/test/dynamo/test_autograd_function.py index 326a1e627b3f4..68c15ea480941 100644 --- a/test/dynamo/test_autograd_function.py +++ b/test/dynamo/test_autograd_function.py @@ -8,6 +8,7 @@ import torch._dynamo.test_case import torch._dynamo.testing import torch._dynamo.utils +<<<<<<< HEAD from torch.testing._internal.triton_utils import HAS_GPU, requires_gpu @@ -16,6 +17,12 @@ ) if HAS_GPU: +======= +from torch.testing._internal.triton_utils import HAS_CUDA, requires_cuda + + +if HAS_CUDA: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import triton from torch.testing._internal.triton_utils import add_kernel @@ -508,13 +515,21 @@ def test_amp_custom_fwd_bwd(self): class MyMM(torch.autograd.Function): @staticmethod +<<<<<<< HEAD @torch.amp.custom_fwd(device_type=device_type) +======= + @torch.amp.custom_fwd(device_type="cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def forward(ctx, a, b): ctx.save_for_backward(a, b) return a.mm(b) @staticmethod +<<<<<<< HEAD @torch.amp.custom_bwd(device_type=device_type) +======= + @torch.amp.custom_bwd(device_type="cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def backward(ctx, grad): a, b = ctx.saved_tensors return grad.mm(b.t()), a.t().mm(grad) @@ -1433,7 +1448,11 @@ def backward(ctx, grad_output, grad_dx): result = grad_output * dx + grad_dx * 6 * x # Intentionally return a wrong value to test if the backward is triggered twice. # Since if the first MyCube.apply returns values w/o requires_grad=True, +<<<<<<< HEAD # this backward would be only triggered once (the first MyCube.apply call), +======= + # this backward would be only triggered once (the first MyCube.appy call), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # as the second MyCube.apply is inlined by Dynamo and the corresponding backward # would be generated by autograd engine. return result * 0.5 @@ -1477,7 +1496,11 @@ def fn(): self.assertEqual(cnt.frame_count, 1) self.assertEqual(cnt.op_count, 1) +<<<<<<< HEAD @requires_gpu +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_triton_kernel_basic(self): class Add(torch.autograd.Function): @staticmethod @@ -1501,14 +1524,23 @@ def f(x, y): z = Add.apply(x, y) return z +<<<<<<< HEAD x = torch.randn(10, device=device_type, requires_grad=True) y = torch.randn(10, device=device_type, requires_grad=True) +======= + x = torch.randn(10, device="cuda", requires_grad=True) + y = torch.randn(10, device="cuda", requires_grad=True) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) z = f(x, y) loss = z.sum() loss.backward() self.assertEqual(x + y, z) +<<<<<<< HEAD @requires_gpu +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_triton_kernel_multiple_out(self): class Add(torch.autograd.Function): @staticmethod @@ -1536,8 +1568,13 @@ def f(x, y): z = Add.apply(x, y) return z +<<<<<<< HEAD x = torch.randn(10, device=device_type, requires_grad=True) y = torch.randn(10, device=device_type, requires_grad=True) +======= + x = torch.randn(10, device="cuda", requires_grad=True) + y = torch.randn(10, device="cuda", requires_grad=True) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) z, _ = f(x, y) loss = z.sum() loss.backward() diff --git a/test/dynamo/test_backends.py b/test/dynamo/test_backends.py index be1470c08e794..865eb7b32b5db 100644 --- a/test/dynamo/test_backends.py +++ b/test/dynamo/test_backends.py @@ -8,6 +8,10 @@ import torch._dynamo.backends import torch._dynamo.test_case from torch._dynamo.backends.debugging import ExplainWithBackend +<<<<<<< HEAD +======= +from torch._dynamo.backends.onnxrt import has_onnxruntime +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._dynamo.backends.tvm import has_tvm from torch._dynamo.testing import same from torch.fx._lazy_graph_module import _force_skip_lazy_graph_module @@ -16,7 +20,14 @@ onlyHPU, ) from torch.testing._internal.common_utils import skipIfHpu +<<<<<<< HEAD from torch.testing._internal.triton_utils import requires_cuda_and_triton +======= +from torch.testing._internal.inductor_utils import HAS_CUDA + + +requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class Seq(torch.nn.Module): @@ -130,10 +141,21 @@ def test_aot_eager_decomp_partition(self, device): def test_aot_ts(self, device): self._check_backend_works("aot_ts", device) +<<<<<<< HEAD @requires_cuda_and_triton def test_aot_cudagraphs(self, device): self._check_backend_works("cudagraphs", device) +======= + @requires_cuda + def test_aot_cudagraphs(self, device): + self._check_backend_works("cudagraphs", device) + + @unittest.skipIf(not has_onnxruntime(), "requires onnxruntime") + def test_onnxrt(self, device): + self._check_backend_works("onnxrt", device) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(not has_tvm(), "requires tvm") def test_tvm(self, device): self._check_backend_works("tvm", device) diff --git a/test/dynamo/test_base_hop.py b/test/dynamo/test_base_hop.py index 607b502351aaf..a8178dfdfcc90 100644 --- a/test/dynamo/test_base_hop.py +++ b/test/dynamo/test_base_hop.py @@ -1,4 +1,8 @@ # Owner(s): ["module: dynamo"] +<<<<<<< HEAD +======= +import unittest +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import unittest.mock as mock import torch @@ -12,6 +16,13 @@ ) from torch._higher_order_ops.schema import find_hop_schema from torch.testing._internal.common_utils import instantiate_parametrized_tests +<<<<<<< HEAD +======= +from torch.testing._internal.inductor_utils import HAS_CUDA + + +requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def normalize_graph(gm): diff --git a/test/dynamo/test_bytecode_utils.py b/test/dynamo/test_bytecode_utils.py index ea5ec7b55a4fd..330a6878ce50e 100644 --- a/test/dynamo/test_bytecode_utils.py +++ b/test/dynamo/test_bytecode_utils.py @@ -284,7 +284,11 @@ def fn(): def nothing(*args): pass +<<<<<<< HEAD code, _ = bytecode_transformation.transform_code_object(fn.__code__, nothing) +======= + code = bytecode_transformation.transform_code_object(fn.__code__, nothing) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(code.co_exceptiontable, fn.__code__.co_exceptiontable) @skipIfNotPy311 @@ -300,7 +304,11 @@ def fn(): def nothing(*args): pass +<<<<<<< HEAD code, _ = bytecode_transformation.transform_code_object(fn.__code__, nothing) +======= + code = bytecode_transformation.transform_code_object(fn.__code__, nothing) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(code.co_exceptiontable, fn.__code__.co_exceptiontable) @skipIfNotPy311 @@ -544,6 +552,7 @@ def fn(x): self.assertEqual(fn(torch.ones(3)), torch.ones(3) + 1) +<<<<<<< HEAD # https://github.com/pytorch/pytorch/issues/160471 def test_extended_args_starts_line(self): # NOTE: need to LOAD_CONST i before LOAD_FAST x @@ -572,6 +581,8 @@ def transformations(instructions, _): bytecode_transformation.transform_code_object(fn.__code__, transformations) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class BytecodeHookTests(torch._dynamo.test_case.TestCase): def test_bytecode_hook(self): diff --git a/test/dynamo/test_callback.py b/test/dynamo/test_callback.py index e516364626314..3998a27ac68ad 100644 --- a/test/dynamo/test_callback.py +++ b/test/dynamo/test_callback.py @@ -8,7 +8,11 @@ from torch._dynamo.test_case import run_tests, TestCase from torch._guards import CompileId from torch.testing._internal.common_utils import TEST_WITH_ROCM +<<<<<<< HEAD from torch.testing._internal.triton_utils import requires_cuda_and_triton +======= +from torch.testing._internal.inductor_utils import HAS_CUDA +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class CallbackTests(TestCase): @@ -61,7 +65,11 @@ def test_counter_assertion(self) -> None: @unittest.skipIf( TEST_WITH_ROCM, "ROCm outputs a different number of autotuning logs" ) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @unittest.skipIf(not HAS_CUDA, "requires triton") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._inductor.config.patch(force_disable_caches=True) def test_triggers(self) -> None: torch._dynamo.reset() @@ -106,8 +114,17 @@ def forward(self, x): end=CallbackArgs(callback_trigger=, compile_id='1/0') start=CallbackArgs(callback_trigger=, compile_id='1/0') end=CallbackArgs(callback_trigger=, compile_id='1/0') +<<<<<<< HEAD start=CallbackArgs(callback_trigger=, compile_id='0/0') end=CallbackArgs(callback_trigger=, compile_id='0/0')""", # noqa: B950 +======= +start=CallbackArgs(callback_trigger=, compile_id='1/0') +end=CallbackArgs(callback_trigger=, compile_id='1/0') +start=CallbackArgs(callback_trigger=, compile_id='0/0') +end=CallbackArgs(callback_trigger=, compile_id='0/0') +start=CallbackArgs(callback_trigger=, compile_id='0/0') +end=CallbackArgs(callback_trigger=, compile_id='0/0')""", # noqa: B950 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) order.clear() diff --git a/test/dynamo/test_compiler_bisector.py b/test/dynamo/test_compiler_bisector.py index 161f9674cd4a1..010df6c291724 100644 --- a/test/dynamo/test_compiler_bisector.py +++ b/test/dynamo/test_compiler_bisector.py @@ -1,5 +1,9 @@ # Owner(s): ["module: dynamo"] +<<<<<<< HEAD +======= +import unittest +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from contextlib import contextmanager from importlib import import_module @@ -10,18 +14,30 @@ from torch._inductor.compiler_bisector import CompilerBisector from torch._inductor.test_case import TestCase from torch.library import _scoped_library, Library +<<<<<<< HEAD from torch.testing._internal.triton_utils import requires_cuda_and_triton +======= +from torch.testing._internal.inductor_utils import HAS_CUDA +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) aten = torch.ops.aten +<<<<<<< HEAD +======= +requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f32 = torch.float32 i64 = torch.int64 i32 = torch.int32 +<<<<<<< HEAD @requires_cuda_and_triton +======= +@requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestCompilerBisector(TestCase): test_ns = "_test_bisector" diff --git a/test/dynamo/test_ctx_manager.py b/test/dynamo/test_ctx_manager.py index 5e188e76dc56e..0baf124f49f3f 100644 --- a/test/dynamo/test_ctx_manager.py +++ b/test/dynamo/test_ctx_manager.py @@ -388,7 +388,11 @@ def fn(x, s0, s1): ref1 = fn(x, s1, s1) res1 = opt_fn(x, s1, s1) +<<<<<<< HEAD # We have a re-compilation because of changing inputs +======= + # We have a re-compilation because of chaning inputs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(cnts.frame_count, 2) self.assertEqual(ref1, res1) @@ -403,7 +407,11 @@ def fn(x, s0, s1): ref0 = fn(x, s0, s1) res0 = opt_fn(x, s0, s1) +<<<<<<< HEAD # We have a re-compilation because of changing inputs +======= + # We have a re-compilation because of chaning inputs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(cnts.frame_count, 2) self.assertEqual(ref0, res0) @@ -1742,6 +1750,7 @@ def f(x): opt_f = torch.compile(f, backend="eager") opt_f(torch.randn(2, 2)) +<<<<<<< HEAD # Regression test to make sure dynamo won't crash on these kwargs. def test_sdpa_kernel_ctx_manager_kwargs(self): backends = [torch.nn.attention.SDPBackend.MATH] @@ -1819,6 +1828,8 @@ def f(x): opt_f = torch.compile(f, backend="eager") opt_f(torch.randn(2, 2)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_torch_profiler_use_after_with_block(self): counters.clear() diff --git a/test/dynamo/test_debug_utils.py b/test/dynamo/test_debug_utils.py index eae4d06d98904..828bfa0f7a531 100644 --- a/test/dynamo/test_debug_utils.py +++ b/test/dynamo/test_debug_utils.py @@ -1,6 +1,10 @@ # Owner(s): ["module: dynamo"] import os +<<<<<<< HEAD +======= +import unittest +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from unittest.mock import patch import torch @@ -9,8 +13,16 @@ from torch._dynamo.debug_utils import aot_graph_input_parser, generate_env_vars_string from torch._dynamo.test_case import TestCase from torch.testing._internal.common_device_type import instantiate_device_type_tests +<<<<<<< HEAD +======= +from torch.testing._internal.inductor_utils import HAS_CUDA + + +requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f32 = torch.float32 i64 = torch.int64 i32 = torch.int32 diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py index 6af25a385c2f6..601dc4934cb45 100644 --- a/test/dynamo/test_decorators.py +++ b/test/dynamo/test_decorators.py @@ -8,9 +8,14 @@ import torch import torch._dynamo.test_case import torch._dynamo.testing +<<<<<<< HEAD from torch._dynamo.exc import IncorrectUsage, Unsupported from torch._dynamo.utils import counters from torch.testing._internal.common_utils import skipIfWindows +======= +from torch._dynamo.exc import IncorrectUsage +from torch._dynamo.utils import counters +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def my_custom_function(x): @@ -515,6 +520,7 @@ def fn(x, s): fn(x, State(41)) self.assertEqual(cnts.frame_count, 2) +<<<<<<< HEAD def test_nonstrict_trace_int_and_float_output(self): @torch._dynamo.nonstrict_trace def trace_me(x): @@ -532,6 +538,8 @@ def fn(x): res = opt_fn(x) self.assertEqual(ref, res) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_nonstrict_trace_tuple_and_sym_int_output(self): @torch._dynamo.nonstrict_trace def trace_me(x): @@ -689,7 +697,17 @@ def fn(p): fn(p) self.assertFalse(True) # must raise error before this except torch._dynamo.exc.Unsupported as e: +<<<<<<< HEAD self.assertIn("Invalid input type for nonstrict_trace-ed function", str(e)) +======= + msg = """ +For `nonstrict_trace`-ed function, the only allowed input types are basic types (e.g., torch.Tensor, int, float) or pytree containers of those. Here you are calling the function with arguments that contain a value of type .Point>, please use one of the following to register the type with pytree: + * `torch.utils._pytree.register_constant` + * `torch.utils._pytree.register_dataclass` + * `torch.utils._pytree.register_pytree_node` +""" # NOQA: B950 + self.assertIn(msg, str(e)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_nonstrict_trace_nested_custom_class_error(self): class Point: @@ -735,6 +753,7 @@ def fn(x, y): fn(torch.ones(10), torch.ones(1)) self.assertFalse(True) # must raise error before this except torch._dynamo.exc.Unsupported as e: +<<<<<<< HEAD self.assertIn("Invalid input type for nonstrict_trace-ed function", str(e)) def test_nonstrict_trace_custom_class_output_error(self): @@ -764,6 +783,15 @@ def fn(x): self.assertIn( "Unsupported output type for nonstrict_trace-ed function", str(e) ) +======= + msg = """ +For `nonstrict_trace`-ed function, the only allowed input types are basic types (e.g., torch.Tensor, int, float) or pytree containers of those. Here you are calling the function with arguments that contain a value of type .Point>, please use one of the following to register the type with pytree: + * `torch.utils._pytree.register_constant` + * `torch.utils._pytree.register_dataclass` + * `torch.utils._pytree.register_pytree_node` +""" # NOQA: B950 + self.assertIn(msg, str(e)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_nonstrict_newly_constructed_trace_register_constant_type_error(self): class State: @@ -800,10 +828,19 @@ def fn(x): fn(x) self.assertFalse(True) # must raise error before this except torch._dynamo.exc.Unsupported as e: +<<<<<<< HEAD self.assertIn( "Input marked with `pytree.register_constant` constructed in the `torch.compile` region", str(e), ) +======= + msg = """ +You are calling a `nonstrict_trace`-ed function with an input that contains an object of type .State>, which was marked with `pytree.register_constant`. However, the object was constructed _inside_ the `torch.compile` region. + +Please construct the object _outside_ the `torch.compile` region, or submit an issue to GitHub. +""" # NOQA: B950 + self.assertIn(msg, str(e)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_nonstrict_trace_object_in_context_error(self): class Point: @@ -846,9 +883,23 @@ def fn(x, y): fn(x, y) self.assertFalse(True) # must raise error before this except torch._dynamo.exc.Unsupported as e: +<<<<<<< HEAD self.assertIn( "Invalid use of pytree_flatten with nonstrict_trace-ed function", str(e) ) +======= + msg = """ +You are calling a `nonstrict_trace`-ed function where one one of the inputs has been registered with a `pytree_flatten` that puts an object of type .Point> into the context. + +Please consider modifying that `pytree_flatten` to avoid putting the object into context, and apply one of the following to .Point> + * `torch.utils._pytree.register_constant` + * `torch.utils._pytree.register_dataclass` + * `torch.utils._pytree.register_pytree_node` + +If the above doesn't work, please subtmit an issue to GitHub. +""" # NOQA: B950 + self.assertIn(msg, str(e)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_graph_break(self): cnts = torch._dynamo.testing.CompileCounter() @@ -893,9 +944,12 @@ def gn(x): self.assertEqual(gn(inp), inp + 3) self.assertEqual(cnts.frame_count, 1) +<<<<<<< HEAD @skipIfWindows( msg="TODO: (xuhancn), confirm if torch.compiler.disable work on Windows." ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_disable_recursive_false(self): def fn2(x): return x + 1 @@ -1067,10 +1121,18 @@ def fn3(x): self.assertEqual(cnts.frame_count, 2) self.assertEqual(cnts.op_count, 4) +<<<<<<< HEAD with self.assertRaisesRegex( Unsupported, r"Skip calling `torch.compiler.disable\(\)`d function" ): fn3(torch.randn(4, 5)) +======= + try: + fn3(torch.randn(4, 5)) + self.assertFalse(True) + except torch._dynamo.exc.Unsupported as e: + self.assertIn("Skip calling `torch.compiler.disable()`d function", str(e)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_disable_optimize(self): cnt = torch._dynamo.testing.CompileCounter() @@ -1720,6 +1782,7 @@ def f4(x): ): f4(torch.randn(3)) +<<<<<<< HEAD def test_error_on_graph_break(self): cnts = torch._dynamo.testing.CompileCounter() @@ -2044,6 +2107,8 @@ def outer_f2(x): with self.assertRaises(Unsupported): outer_f2(inp) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_dicts.py b/test/dynamo/test_dicts.py index 3b1c9315336e1..d2c971ee48154 100644 --- a/test/dynamo/test_dicts.py +++ b/test/dynamo/test_dicts.py @@ -3,11 +3,18 @@ # ruff: noqa: TRY002 import itertools +<<<<<<< HEAD import operator import types import unittest import weakref from collections import defaultdict, namedtuple, OrderedDict, UserDict +======= +import types +import unittest +import weakref +from collections import defaultdict, namedtuple, OrderedDict +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from typing import Any import torch @@ -18,6 +25,7 @@ import torch.utils.checkpoint from torch._dynamo.testing import same from torch._dynamo.utils import dict_items +<<<<<<< HEAD from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, make_dynamo_test, @@ -25,16 +33,21 @@ parametrize, ) from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class SimpleDict(dict): pass +<<<<<<< HEAD class DummyUserDict(UserDict): pass +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class DictTests(torch._dynamo.test_case.TestCase): def test_dict_subclass_instantiation(self): def fn(x): @@ -792,6 +805,7 @@ def fn(x): x = torch.randn(4) self.assertEqual(fn(x), opt_fn(x)) +<<<<<<< HEAD def test_construct_user_dict_and_return(self): def fn(x): return DummyUserDict({"a": x + 1}) @@ -803,6 +817,8 @@ def fn(x): opt_fn = torch.compile(fn, backend="eager", fullgraph=True) self.assertEqual(res["a"], opt_fn(x)["a"]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_fn_id(self): def fn(x, f): d = {id(f): 3} @@ -922,7 +938,11 @@ def test_mapping_proxy_existing_local_mutation(self): def fn(x): # Dynamo should not cause a graph break here because it knows that +<<<<<<< HEAD # the existing proxy can't point to this new dict +======= + # the existing proxy cant point to this new dict +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) other_dict = {} other_dict["d"] = 4 y = torch.sin(x * mp["c"]) @@ -946,6 +966,7 @@ def fn(x): self.assertEqual(["b", "c", "a"], list(opt_fn(x).keys())) self.assertEqual(fn(x), opt_fn(x)) +<<<<<<< HEAD def test_mapping_proxy_ban_muation_on_dict_realization(self): def fn(x): class Foo: @@ -965,6 +986,8 @@ class Foo: self.assertEqual(ref, res) self.assertEqual(foo1.bar, foo2.bar) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_overridden_get_item(self): class MyDict(dict): def __init__(self, *args, **kwargs): @@ -1014,10 +1037,19 @@ def fn(b: Any): a = {"one": torch.ones(1)} return a | b +<<<<<<< HEAD from torch._dynamo.exc import Unsupported for arg in args: with self.assertRaisesRegex(Unsupported, "Observed exception"): +======= + from torch._dynamo.exc import InternalTorchDynamoError + + for arg in args: + with self.assertRaisesRegex( + InternalTorchDynamoError, "unsupported operand type" + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _ = fn(arg) def test_builtin_or_with_diff_keys(self): @@ -1060,6 +1092,7 @@ def f(x): self.assertEqual(ref, res) +<<<<<<< HEAD @parametrize("op", ["or_", "and_", "xor", "sub"]) def test_dict_keys_binop(self, op): op = getattr(operator, op) @@ -1688,6 +1721,8 @@ def test_move_to_end(self): p.move_to_end("a") self.assertEqual(list(p.keys()), list("bc")) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_error_messages.py b/test/dynamo/test_error_messages.py index 847f3a6fd2166..b65f2c9e6787d 100644 --- a/test/dynamo/test_error_messages.py +++ b/test/dynamo/test_error_messages.py @@ -4,7 +4,10 @@ import re import traceback import unittest +<<<<<<< HEAD import unittest.mock +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import warnings from functools import lru_cache @@ -13,7 +16,11 @@ import torch._dynamo.config import torch._dynamo.test_case import torch.utils._pytree as python_pytree +<<<<<<< HEAD from torch._dynamo.exc import ResumePrologueTracingError, Unsupported +======= +from torch._dynamo.exc import Unsupported +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._dynamo.testing import skipIfNotPy312 from torch._dynamo.utils import counters from torch.testing._internal.common_utils import ( @@ -47,7 +54,11 @@ def __exit__(self, exc_type, exc_value, traceback): pass +<<<<<<< HEAD class ErrorMessagesTest(LoggingTestCase): +======= +class GraphBreakMessagesTest(LoggingTestCase): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_dynamic_shape_operator(self): def fn(): return torch.nonzero(torch.rand([10, 10])) @@ -62,7 +73,10 @@ def fn(): Developer debug context: aten.nonzero.default +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0036.html +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from user code: File "test_error_messages.py", line N, in fn @@ -84,7 +98,10 @@ def fn(): Developer debug context: aten.linalg_lstsq.default +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0037.html +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from user code: File "test_error_messages.py", line N, in fn @@ -107,7 +124,10 @@ def fn(x): Developer debug context: call_method TensorVariable() item () {} +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0124.html +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from user code: File "test_error_messages.py", line N, in fn @@ -131,7 +151,10 @@ def fn(x): Developer debug context: aten.equal.default +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0033.html +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from user code: File "test_error_messages.py", line N, in fn @@ -159,7 +182,10 @@ def fn(lst): Developer debug context: TensorVariable() +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0207.html +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from user code: File "test_error_messages.py", line N, in fn @@ -181,11 +207,17 @@ def fn(it): Hint: Avoid calling `zip.__iter__` in your code. Hint: Please report an issue to PyTorch. Hint: Dynamo does not fully support tracing builtin iterators (e.g. `map`, `zip`, `enumerate`) passed in from uncompiled to compiled regions (e.g. `torch.compile(fn)(enumerate(...))`). This can happen unintentionally if a previous graph break happens with a builtin iterator in the local scope. +<<<<<<< HEAD Hint: List/dict comprehensions in Python <= 3.11 result in implicit function calls, which Dynamo cannot trace as a top level frame. Possible workarounds are (1) use a loop instead of a comprehension, (2) fix any graph breaks in the function above the comprehension, (3) wrap the comprehension in a function, or (4) use Python 3.12+. Developer debug context: call_method UserDefinedObjectVariable(zip) __iter__ [] {} For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0156.html +======= + + Developer debug context: call_method UserDefinedObjectVariable(zip) __iter__ () {} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from user code: File "test_error_messages.py", line N, in fn @@ -210,11 +242,17 @@ def fn(x, items): Hint: Please report an issue to PyTorch. Hint: Consider moving the creation of dict view object (e.g. `dict.keys()`, `dict.items()`,) to the compiled region, instead of passing it as an input to the compiled region. Hint: Dynamo does not fully support tracing builtin iterators (e.g. `map`, `zip`, `enumerate`) passed in from uncompiled to compiled regions (e.g. `torch.compile(fn)(enumerate(...))`). This can happen unintentionally if a previous graph break happens with a builtin iterator in the local scope. +<<<<<<< HEAD Hint: List/dict comprehensions in Python <= 3.11 result in implicit function calls, which Dynamo cannot trace as a top level frame. Possible workarounds are (1) use a loop instead of a comprehension, (2) fix any graph breaks in the function above the comprehension, (3) wrap the comprehension in a function, or (4) use Python 3.12+. Developer debug context: call_method UserDefinedObjectVariable(dict_items) __iter__ [] {} For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0156.html +======= + + Developer debug context: call_method UserDefinedObjectVariable(dict_items) __iter__ () {} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from user code: File "test_error_messages.py", line N, in fn @@ -238,7 +276,10 @@ def fn(it): Developer debug context: call_function UserDefinedObjectVariable(zip) [] {} +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0147.html +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from user code: File "test_error_messages.py", line N, in fn @@ -262,7 +303,10 @@ def fn(obj): Developer debug context: Attempted SETUP_WITH/BEFORE_WITH on ConstantVariable(int: 3) +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0142.html +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from user code: File "test_error_messages.py", line N, in fn @@ -290,10 +334,14 @@ def fn(x): Exception:test Traceback: File "test_error_messages.py", line N, in fn +<<<<<<< HEAD return x + 1 For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0219.html""", +======= + return x + 1""", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def test_unsupported_builtin(self): @@ -312,7 +360,10 @@ def fn(): Developer debug context: builtin print [] False +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0059.html +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from user code: File "test_error_messages.py", line N, in fn @@ -338,7 +389,10 @@ def post_munge(s): Developer debug context: module: unittest.case, qualname: skip, skip reason: +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from user code: File "test_error_messages.py", line N, in fn @@ -360,7 +414,10 @@ def fn(): Developer debug context: module: torch._dynamo.decorators, qualname: disable, skip reason: +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from user code: File "test_error_messages.py", line N, in fn @@ -389,7 +446,10 @@ def post_munge(s): Developer debug context: qualname: skip, name: skip, filename: `case.py`, skip reason: skipped according trace_rules.lookup unittest +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0008.html +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from user code: File "test_error_messages.py", line N, in fn @@ -411,7 +471,10 @@ def fn(): Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from user code: File "test_error_messages.py", line N, in fn @@ -432,7 +495,10 @@ def fn(): Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{'msg': ConstantVariable(str: 'test graph break')}` +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from user code: File "test_error_messages.py", line N, in fn @@ -454,7 +520,10 @@ def fn(): Developer debug context: module: _warnings, qualname: warn, skip reason: +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from user code: File "test_error_messages.py", line N, in fn @@ -482,8 +551,12 @@ def fn(x): Hint: Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py Developer debug context: module: optree._C, qualname: PyCapsule.flatten, skip reason: +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html""", +======= +""", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @scoped_load_inline @@ -519,6 +592,7 @@ def f(x): first_graph_break = next(iter(counters["graph_break"].keys())) first_graph_break = re.sub(r"mylib(_v\d+)?", "mylib", first_graph_break) +<<<<<<< HEAD # HACK: this patches around the fact that PyBind11 improperly sets the # __qualname__ attribute on functions and methods; see # https://github.com/pybind/pybind11/issues/5774. This should be removed if @@ -526,6 +600,8 @@ def f(x): first_graph_break = re.sub( r"pybind11_detail_function_record_v[^ .]+", "PyCapsule", first_graph_break ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertExpectedInline( first_graph_break, @@ -536,8 +612,12 @@ def f(x): Hint: If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use `torch.compiler.allow_in_graph`. Developer debug context: module: mylib, qualname: PyCapsule.foobar, skip reason: +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html""", +======= +""", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) cpp_source = """ @@ -589,7 +669,10 @@ def fn(x, y): Developer debug context: SliceVariable start: ConstantVariable(NoneType: None), stop: TensorVariable(), step: ConstantVariable(NoneType: None) +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0038.html +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from user code: File "test_error_messages.py", line N, in fn @@ -609,9 +692,14 @@ def fn(): Hint: Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled. Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues. +<<<<<<< HEAD Developer debug context: raised exception RuntimeError([ConstantVariable(str: 'test')]) For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0088.html +======= + Developer debug context: raised exception ExceptionVariable() + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from user code: File "test_error_messages.py", line N, in fn @@ -637,7 +725,10 @@ def fn(mod): Developer debug context: Foo +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0119.html +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from user code: File "test_error_messages.py", line N, in fn @@ -666,7 +757,10 @@ def fn(mod, x): Developer debug context: nn.Module subclass: Foo, name: attr, attribute type: module +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0161.html +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from user code: File "test_error_messages.py", line N, in fn @@ -696,7 +790,10 @@ def fn(): Developer debug context: Active generic context managers: [GenericContextWrappingVariable(GenericCtxMgr), GenericContextWrappingVariable(GenericCtxMgr)] +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0066.html +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from user code: File "test_error_messages.py", line N, in fn @@ -711,8 +808,12 @@ def fn(): Hint: Remove the `torch._dynamo.graph_break()` call. Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html""", +======= +""", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def test_load_build_class(self): @@ -733,7 +834,10 @@ class Foo: Developer debug context: +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0075.html +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from user code: File "test_error_messages.py", line N, in fn @@ -766,7 +870,10 @@ def post_munge(s): Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues. Developer debug context: GET_AITER with args (, Instruction(GET_AITER) +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0082.html +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from user code: File "test_error_messages.py", line N, in fn @@ -790,14 +897,23 @@ def post_munge(s): lambda: torch.compile(fn, backend="eager", fullgraph=True)(), """\ Reconstruction failure +<<<<<<< HEAD Explanation: Dynamo has no bytecode reconstruction implemented for sourceless variable UserMethodVariable(.Foo.meth at 0xmem_addr>, UserDefinedObjectVariable(Foo)). +======= + Explanation: Dynamo has no bytecode reconstruction implemented for sourceless variable UserMethodVariable(.Foo.meth at 0xmem_addr>, UserDefinedObjectVariable(Foo)). +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Hint: If Dynamo is attempting to trace a return statement and your code is attempting to return a variable that Dynamo cannot reconstruct, then remove it from the return statement. Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one. Hint: Report an issue to PyTorch if you need reconstrtuction support. Note that objects that don't have reconstruction rules may be fundamentally unreconstructable. +<<<<<<< HEAD Developer debug context: UserMethodVariable(.Foo.meth at 0xmem_addr>, UserDefinedObjectVariable(Foo)) For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0092.html +======= + Developer debug context: UserMethodVariable(.Foo.meth at 0xmem_addr>, UserDefinedObjectVariable(Foo)) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from user code: File "test_error_messages.py", line N, in fn @@ -833,7 +949,10 @@ def post_munge(s): Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) User code traceback: File "test_error_messages.py", line N, in test_reconstruction_failure_gb torch.compile(fn, backend="eager")() @@ -846,14 +965,23 @@ def post_munge(s): post_munge(munge_exc(records[1].exc_info[1], suppress_suffix=True, skip=0)), """\ Reconstruction failure +<<<<<<< HEAD Explanation: Dynamo has no bytecode reconstruction implemented for sourceless variable UserMethodVariable(.Foo.meth at 0xmem_addr>, UserDefinedObjectVariable(Foo)). +======= + Explanation: Dynamo has no bytecode reconstruction implemented for sourceless variable UserMethodVariable(.Foo.meth at 0xmem_addr>, UserDefinedObjectVariable(Foo)). +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Hint: If Dynamo is attempting to trace a return statement and your code is attempting to return a variable that Dynamo cannot reconstruct, then remove it from the return statement. Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one. Hint: Report an issue to PyTorch if you need reconstrtuction support. Note that objects that don't have reconstruction rules may be fundamentally unreconstructable. +<<<<<<< HEAD Developer debug context: UserMethodVariable(.Foo.meth at 0xmem_addr>, UserDefinedObjectVariable(Foo)) For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0092.html +======= + Developer debug context: UserMethodVariable(.Foo.meth at 0xmem_addr>, UserDefinedObjectVariable(Foo)) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from user code: File "test_error_messages.py", line N, in fn @@ -882,7 +1010,10 @@ def fn(x): Developer debug context: +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0087.html +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from user code: File "test_error_messages.py", line N, in fn @@ -906,7 +1037,10 @@ def fn(x): Developer debug context: attempted to jump with TensorVariable() +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0170.html +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from user code: File "test_error_messages.py", line N, in fn @@ -973,7 +1107,10 @@ def fn(x): Developer debug context: value: ConstantVariable(bool: False) +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0034.html +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from user code: File "test_error_messages.py", line N, in fn @@ -1017,7 +1154,10 @@ def gn(): Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from user code: File "test_error_messages.py", line N, in fn @@ -1070,7 +1210,10 @@ def gn(): Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from user code: File "test_error_messages.py", line N, in fn @@ -1106,7 +1249,10 @@ def hn(x): Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) User code traceback: File "test_error_messages.py", line N, in test_nested_compile_user_frames torch.compile(fn, backend="eager")(torch.randn(3)) @@ -1220,7 +1366,10 @@ def f3(x): Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) User code traceback: File "test_error_messages.py", line N, in test_graph_break_traceback_collapsed_resume_frames f1(torch.randn(3)) @@ -1305,12 +1454,20 @@ def post_munge(s): lambda: outer(f, torch.randn(3)), """\ Skip calling `torch.compiler.disable()`d function +<<<<<<< HEAD Explanation: Skip calling function `.f at 0xmem_addr>` since it was wrapped with `torch.compiler.disable` (reason: None) Hint: Remove the `torch.compiler.disable` call Developer debug context: .f at 0xmem_addr> For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0098.html +======= + Explanation: Skip calling function `.f at 0xmem_addr>` since it was wrapped with `torch.compiler.disable` (reason: None) + Hint: Remove the `torch.compiler.disable` call + + Developer debug context: .f at 0xmem_addr> + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from user code: File "test_error_messages.py", line N, in outer @@ -1327,12 +1484,20 @@ def g(x): lambda: outer(g, torch.randn(3)), """\ Skip calling `torch.compiler.disable()`d function +<<<<<<< HEAD Explanation: Skip calling function `.g at 0xmem_addr>` since it was wrapped with `torch.compiler.disable` (reason: test message) Hint: Remove the `torch.compiler.disable` call Developer debug context: .g at 0xmem_addr> For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0098.html +======= + Explanation: Skip calling function `.g at 0xmem_addr>` since it was wrapped with `torch.compiler.disable` (reason: test message) + Hint: Remove the `torch.compiler.disable` call + + Developer debug context: .g at 0xmem_addr> + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from user code: File "test_error_messages.py", line N, in outer @@ -1358,7 +1523,10 @@ def forward(self, x): Developer debug context: source: LocalSource(local_name='fn', is_input=True, dynamism=None, is_derefed_cell_contents=False) +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0148.html +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from user code: File "test_error_messages.py", line N, in outer @@ -1366,6 +1534,7 @@ def forward(self, x): post_munge=post_munge, ) +<<<<<<< HEAD # Test that errors while tracing resume function prologues do not get suppressed def test_graph_break_in_buggy_resume_prologue(self): import torch._dynamo.bytecode_transformation as bt @@ -1408,6 +1577,8 @@ def bad_clean_and_assemble_instructions(instructions, *args): ): fn(torch.randn(3)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_exc.py b/test/dynamo/test_exc.py index ad56417ed568d..26fcad604f323 100644 --- a/test/dynamo/test_exc.py +++ b/test/dynamo/test_exc.py @@ -43,7 +43,10 @@ def fn001(x): Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from user code: File "test_exc.py", line N, in fn001 @@ -183,7 +186,10 @@ def fn001(x): Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) User code traceback: File "test_exc.py", line N, in test_graph_break_log torch.compile(fn001, backend="eager")(torch.randn(1)) @@ -253,6 +259,7 @@ def fn(x, shape): Model: ==> L['shape'][0]: 0 +<<<<<<< HEAD ==> L['shape'][1]: 0 ==> L['shape'][2]: 0 ==> L['x'].size()[0]: 3 @@ -260,6 +267,15 @@ def fn(x, shape): ==> L['x'].stride()[0]: 1 ==> s3: 0 ==> s52: 0 +======= + ==> L['shape'][1]: 1 + ==> L['shape'][2]: 1 + ==> L['x'].size()[0]: 3 + ==> L['x'].storage_offset(): 0 + ==> L['x'].stride()[0]: 1 + ==> s3: 1 + ==> s52: 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ==> s77: 3 ==> s86: 0 @@ -317,16 +333,27 @@ def fn(x, shape): %split : [num_users=3] = call_method[target=split](args = (%l_x_, (%l_shape_0_, %l_shape_1_, %l_shape_2_)), kwargs = {}) Model: +<<<<<<< HEAD ==> L['shape'][0]: 0 ==> L['shape'][1]: 0 +======= + ==> L['shape'][0]: 1 + ==> L['shape'][1]: 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ==> L['shape'][2]: 0 ==> L['x'].size()[0]: 3 ==> L['x'].storage_offset(): 0 ==> L['x'].stride()[0]: 1 ==> s3: 0 +<<<<<<< HEAD ==> s52: 0 ==> s77: 3 ==> s86: 0 +======= + ==> s52: 1 + ==> s77: 3 + ==> s86: 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Assertions: ==> (== 0 L['x'].storage_offset()) diff --git a/test/dynamo/test_exceptions.py b/test/dynamo/test_exceptions.py index 43fdc335b8c20..064e7bd45fbbd 100644 --- a/test/dynamo/test_exceptions.py +++ b/test/dynamo/test_exceptions.py @@ -136,6 +136,7 @@ def fn(x): res = opt_fn(x) self.assertEqual(ref, res) +<<<<<<< HEAD def test_exception_with_vars(self): def fn(x): try: @@ -150,6 +151,8 @@ def fn(x): res = opt_fn(x) self.assertEqual(ref, res) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_autocast_with_exception(self): class Optimizer(torch.autograd.Function): @staticmethod @@ -186,7 +189,11 @@ def test_propagate_exception_inside_ctx_manager(self): def cm(): try: yield +<<<<<<< HEAD except BaseException: # noqa: B036 +======= + except BaseException: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise ValueError # noqa: B904 @contextlib.contextmanager @@ -264,7 +271,11 @@ def ctx(): for x, y in args: try: fn(x, y) +<<<<<<< HEAD except BaseException: # noqa: B036 +======= + except BaseException: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) new_exc = sys.exc_info() fix_exc_context(frame_exc[1], new_exc[1], prev_exc[1]) prev_exc = new_exc @@ -272,7 +283,11 @@ def ctx(): try: fixed_ctx = prev_exc[1].__context__ raise prev_exc[1] +<<<<<<< HEAD except BaseException: # noqa: B036 +======= + except BaseException: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) prev_exc[1].__context__ = fixed_ctx raise @@ -306,7 +321,11 @@ def fn(x): x = torch.randn(4) fn(x) +<<<<<<< HEAD # Can't use fullgraph=True because RERAISE is not supported +======= + # Cant use fullgraph=True because RERAISE is not supported +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) opt_fn = torch.compile(fn, backend="eager") opt_fn(x) @@ -763,7 +782,11 @@ def fn(t): raise GeneratorExit except Exception: return t.sin() +<<<<<<< HEAD except BaseException: # noqa: B036 +======= + except BaseException: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return t.cos() t = torch.randn(2) diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index 53c9e2b79f381..9bf139d7ae839 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -368,6 +368,7 @@ def func(x): self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) +<<<<<<< HEAD def test_immutable_list_dict(self): class M(torch.nn.Module): def forward(self, x1, x2): @@ -387,6 +388,8 @@ def forward(self, x1, x2): res = torch.compile(ep.module(), dynamic=True, fullgraph=True)(x1, x2) self.assertTrue(torch._dynamo.utils.same(res, M()(x1, x2))) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_dupes(self): inp = torch.tensor([0.1, 0.1]) @@ -3536,7 +3539,11 @@ def forward(self, pred, x): [3, 3, 4, 5], [true_graph, true_graph, false_graph, false_graph], [true_guard_code, true_guard_code, false_guard_code, false_guard_code], +<<<<<<< HEAD # Outer shape env should have no guards in it because we never specialize on the outer symbool. +======= + # Outter shape env should have no guards in it because we never specialize on the outter symbool. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) [[], [], [], []], ) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 5b8aa5c61e405..90e946320fa38 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -11,7 +11,10 @@ import operator import random import sys +<<<<<<< HEAD import types +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import typing import unittest from dataclasses import dataclass, field @@ -31,7 +34,11 @@ EagerAndRecordGraphs, normalize_gm, ) +<<<<<<< HEAD from torch._dynamo.utils import ifdynstaticdefault, range_iterator, same +======= +from torch._dynamo.utils import ifdynstaticdefault, same +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._dynamo.variables import ConstantVariable, SkipFunctionVariable from torch._dynamo.variables.lists import RangeVariable from torch.nn import functional as F @@ -268,6 +275,7 @@ def test_itertools_product(a, b): v = v + x * i return v +<<<<<<< HEAD def test_itertools_product_args(self): @torch.compile(backend="eager", fullgraph=True) def fn(*args, **kwargs): @@ -316,6 +324,8 @@ def test_itertools_filterfalse_basic(a, b): a += x return a +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @make_test def test_itertools_chain(a, b): v = a @@ -568,11 +578,14 @@ def test_tuple2(a, b): args = [a, b] return sub(*args) +<<<<<<< HEAD @make_test def test_tuple_map(a, b): t = tuple(map(torch.sin, [a, b])) return t[0] + t[1] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_size_tuple_add(self): def fn(): size = torch.Size([]) @@ -1294,7 +1307,11 @@ def test_module_constant(x, y): @make_test def test_inline_softmax(x, y): +<<<<<<< HEAD # This is common in some huggingface models +======= + # This is common in sme huggingface models +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return torch.nn.Softmax(dim=-1)(x + y * 2) @make_test @@ -1757,6 +1774,7 @@ def test_tuple_contains(a, b): return a - b @make_test +<<<<<<< HEAD def test_set_in_frozenset(x): var = set("abc") other = set([frozenset("abc")]) @@ -1764,6 +1782,48 @@ def test_set_in_frozenset(x): return x + 1 else: return x - 1 +======= + def test_set_invalid_ConstantVariable_op(a, b): + s = set({"banana", "apple", "orange"}) + try: + s - 1 + except TypeError: + return a + b + except Exception: + return a - b + else: + return a * b + + @make_test + def test_set_pop_raise_KeyError(a, b): + s = set() + try: + s.pop() + except KeyError: + return a + b + except Exception: + return a - b + else: + return a * b + + @make_test + def test_set_issubset(a, b): + vals1 = {"a", "b", "c"} + vals2 = {"b", "c"} + vals3 = {"b", "e", "f"} + if vals2.issubset(vals1) and not vals2.issubset(vals3): + return a + b + return a - b + + @make_test + def test_set_issuperset(a, b): + vals1 = {"a", "b", "c"} + vals2 = {"b", "c"} + vals3 = {"b", "e", "f"} + if vals1.issuperset(vals2) and not vals1.issuperset(vals3): + return a + b + return a - b +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @make_test def test_set_update_bytecode(x): @@ -1783,6 +1843,184 @@ def test_set_update_list_with_duplicated_items(x): else: return x - 1 +<<<<<<< HEAD +======= + @make_test + def test_set_contains(a, b): + vals = set(["a", "b", "c"]) + if "a" in vals: + x = a + b + else: + x = a - b + if "d" in vals: + y = a + b + else: + y = a - b + return x, y + + def test_set_isdisjoint(self): + x = {"apple", "banana", "cherry"} + y = {"google", "microsoft", "apple"} + + def fn(a): + if x.isdisjoint(y): + return a + 1 + else: + return a - 1 + + test = make_test(fn) + test(self) + + @make_test + def test_set_intersection(a, b): + set1 = {"apple", "banana", "cherry"} + set2 = {"google", "microsoft", "apple"} + set3 = {"shoes", "flipflops", "apple"} + intersection_set = set1.intersection(set2, set3) + if "apple" in intersection_set: + x = a + b + else: + x = a - b + if "banana" in intersection_set: + y = a + b + else: + y = a - b + if "shoes" in intersection_set: + z = a + b + else: + z = a - b + return x, y, z + + @make_test + def test_set_intersection_update(a, b): + set1 = {"apple", "banana", "cherry"} + set2 = {"google", "microsoft", "apple"} + set3 = {"shoes", "flipflops", "apple"} + set1.intersection_update(set2, set3) + if "apple" in set1: + x = a + b + else: + x = a - b + if "banana" in set1: + y = a + b + else: + y = a - b + if "shoes" in set1: + z = a + b + else: + z = a - b + return x, y, z + + @parametrize("_type", [set]) + def test_set_union(self, _type): + @make_test + def fn(a, b): + set1 = _type({"apple", "banana", "cherry"}) + set2 = _type({"google", "microsoft", "apple"}) + set3 = _type({"shoes", "flipflops", "sneakers"}) + union_set = set1.union(set2, set3) + if "apple" in union_set: + x = a + b + else: + x = a - b + if "banana" in union_set: + y = a + b + else: + y = a - b + if "shoes" in union_set: + z = a + b + else: + z = a - b + return x, y, z + + fn(self) + + @parametrize( + "fn_name", ["add", "symmetric_difference", "symmetric_difference_update"] + ) + def test_set_raise_TypeError(self, fn_name): + @make_test + def fn(a, b): + set1 = {"apple", "banana", "cherry"} + try: + getattr(set1, fn_name)() + except TypeError: + return a + b + return a - b + + fn(self) + + @make_test + def test_set_difference(a, b): + set1 = {"apple", "banana", "cherry"} + set2 = {"google", "microsoft", "apple"} + set3 = {"shoes", "flipflops", "sneakers"} + difference_set = set1.difference(set2, set3) + if "apple" in difference_set: + x = a + b + else: + x = a - b + if "banana" in difference_set: + y = a + b + else: + y = a - b + if "shoes" in difference_set: + z = a + b + else: + z = a - b + return x, y, z + + @make_test + def test_set_difference_update(a, b): + set1 = {"apple", "banana", "cherry"} + set2 = {"google", "microsoft", "apple"} + set3 = {"shoes", "flipflops", "sneakers"} + set1.difference_update(set2, set3) + if "apple" in set1: + x = a + b + else: + x = a - b + if "banana" in set1: + y = a + b + else: + y = a - b + if "shoes" in set1: + z = a + b + else: + z = a - b + return x, y, z + + @make_test + def test_set_symmetric_difference(a, b): + set1 = {"apple", "banana", "cherry"} + set2 = {"google", "microsoft", "apple"} + symmetric_diff_set = set1.difference(set2) + if "apple" in symmetric_diff_set: + x = a + b + else: + x = a - b + if "banana" in symmetric_diff_set: + y = a + b + else: + y = a - b + return x, y + + @make_test + def test_set_symmetric_difference_update(a, b): + set1 = {"apple", "banana", "cherry"} + set2 = {"google", "microsoft", "apple"} + set1.difference(set2) + if "apple" in set1: + x = a + b + else: + x = a - b + if "banana" in set1: + y = a + b + else: + y = a - b + return x, y + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_set_keys_view(self): from collections.abc import KeysView @@ -1815,6 +2053,26 @@ def fn(x): x = torch.rand(4) self.assertEqual(fn(x), opt_fn(x)) +<<<<<<< HEAD +======= + @parametrize("method", ["add", "__contains__"]) + def test_set_raise_TypeError_on_unshashable_obj(self, method): + @make_test + def fn(a, b): + s = set({1, 2, 3, 4}) + try: + m = getattr(s, method) + m([[]]) + except TypeError: + return a + b + except Exception: + return a - b + else: + return a * b + + fn(self) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_constant_set(self): s = set([1, 2]) @@ -2026,6 +2284,7 @@ def test_namedtuple_defaults(a, b): tmp = mytuple(a, xy=b) return mytuple(tmp.x, tmp[1], tmp.xy + b) +<<<<<<< HEAD @make_test def test_namedtuple_replace(a, b): mytuple = collections.namedtuple("mytuple", ["x", "y"]) @@ -2041,6 +2300,8 @@ def test_namedtuple_fields(a, b): else: return a - b +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class MyNamedTuple(NamedTuple): first: torch.Tensor second: torch.Tensor @@ -2822,6 +3083,7 @@ def fn(x, a, b): opt_fn = torch.compile(fullgraph=True, backend="eager")(fn) self.assertEqual(opt_fn(x, a, b), fn(x, a, b)) +<<<<<<< HEAD def test_list_setitem(self): def fn(a: int): some_array = [1, 2, 3] @@ -2842,6 +3104,8 @@ def fn(a: int): self.assertEqual(opt_fn(0), fn(0)) self.assertEqual(opt_fn(1), fn(1)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_pow_int(self): def fn(a, b): return torch.pow(a, b) @@ -3491,6 +3755,7 @@ def gen_random_range_args(self): args[2] = 1 return args +<<<<<<< HEAD def test_range_iterator_graph_break(self): @torch.compile(backend="eager") def fn(x): @@ -3536,6 +3801,8 @@ def test_range_iterator_2(a, b): return a + b return a - b +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_range_length(self): def test(*args, expected=None): r = range(*args) @@ -4036,6 +4303,7 @@ def fn(x): opt_fn = torch.compile(fn, backend="eager", fullgraph=True) self.assertEqual(fn(x), opt_fn(x)) +<<<<<<< HEAD def test_torch_get_device_module(self): def f1(): mod1 = torch.get_device_module() @@ -4114,6 +4382,8 @@ def f(): finally: torch = old_torch +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def udf_mul(x, y): return x * y @@ -4207,7 +4477,10 @@ def func(): self.assertEqual(cnts.frame_count, 3) self.assertEqual(cnts.op_count, 6) +<<<<<<< HEAD @torch._dynamo.config.patch(assume_dunder_attributes_remain_unchanged=False) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_meth_default_tensor_args(self): """ Tests that we indeed reference (and mutate) "the one" default tensor arg @@ -4408,6 +4681,7 @@ def fn(a, b): fn(self) +<<<<<<< HEAD @parametrize( "method_name", [ @@ -4435,6 +4709,8 @@ def fn(a, b): fn(self) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_frozenset_construction(self): def fn(x): s = frozenset({x}) @@ -5144,6 +5420,7 @@ def __getattribute__(self, name): with self.assertRaises(Unsupported): a.call_function(None, [], {}) +<<<<<<< HEAD def test_inspect_method_source(self): class Mod(torch.nn.Module): def __init__(self): @@ -5167,6 +5444,8 @@ def fn(x): res = opt_fn(x) self.assertEqual(ref, res) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) instantiate_parametrized_tests(FunctionTests) instantiate_parametrized_tests(DefaultsTests) diff --git a/test/dynamo/test_generator.py b/test/dynamo/test_generator.py index cfb3241d712d1..c87ad7b0fe3a3 100644 --- a/test/dynamo/test_generator.py +++ b/test/dynamo/test_generator.py @@ -22,13 +22,19 @@ def setUp(self): super().setUp() self._old = torch._dynamo.config.enable_faithful_generator_behavior torch._dynamo.config.enable_faithful_generator_behavior = True +<<<<<<< HEAD self._unittest_old = torch._dynamo.config.enable_trace_unittest torch._dynamo.config.enable_trace_unittest = True +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def tearDown(self): super().tearDown() torch._dynamo.config.enable_faithful_generator_behavior = self._old +<<<<<<< HEAD torch._dynamo.config.enable_trace_unittest = self._unittest_old +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _compile_check(self, fn, args=None, fullgraph=True): eager = EagerAndRecordGraphs() @@ -355,7 +361,11 @@ def fn(t, ctx): ctx = whoo() next(ctx) with self.assertRaisesRegex( +<<<<<<< HEAD Unsupported, "Detected a method call to a user-defined generator object." +======= + Unsupported, "Generator as graph argument is not supported" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): fn(t, ctx) @@ -374,7 +384,11 @@ def fn(t, ctx): ctx = whoo(t) next(ctx) with self.assertRaisesRegex( +<<<<<<< HEAD Unsupported, "Detected a method call to a user-defined generator object." +======= + Unsupported, "Generator as graph argument is not supported" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): fn(t, ctx) @@ -395,7 +409,11 @@ def fn(t, ctx): t = torch.randn(2) ctx = whoo() with self.assertRaisesRegex( +<<<<<<< HEAD Unsupported, "Detected a method call to a user-defined generator object." +======= + Unsupported, "Generator as graph argument is not supported" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): fn(t, ctx) @@ -413,8 +431,12 @@ def fn(t, ctx): t = torch.randn(2) ctx = whoo(t) with self.assertRaisesRegex( +<<<<<<< HEAD Unsupported, "Detected a method call to a user-defined generator object.", +======= + Unsupported, "Generator as graph argument is not supported" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): fn(t, ctx) @@ -890,6 +912,7 @@ def f(x): torch.compile(f, backend="eager", fullgraph=True)(torch.ones(3)), ) +<<<<<<< HEAD @make_dynamo_test def test_generator___contains__(self): def whoo(): @@ -921,6 +944,8 @@ def whoo(): self.assertRaises(StopIteration, next, g) self.assertFalse(3 in whoo()) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestGeneratorSend(GeneratorTestsBase): def test_send(self): @@ -1515,6 +1540,7 @@ def fn(t): self._compile_check(fn) +<<<<<<< HEAD def test_return_const_value_in_except_and_finally(self): def whoo(): try: @@ -1585,6 +1611,8 @@ def fn(t): self._compile_check(fn) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) instantiate_parametrized_tests(GeneratorTests) instantiate_parametrized_tests(TestGeneratorSend) diff --git a/test/dynamo/test_graph_deduplication.py b/test/dynamo/test_graph_deduplication.py index 004aee88a8633..6c7ce1acc19f4 100644 --- a/test/dynamo/test_graph_deduplication.py +++ b/test/dynamo/test_graph_deduplication.py @@ -4,16 +4,23 @@ import torch import torch.fx +<<<<<<< HEAD from torch._dynamo.graph_deduplication import apply_graph_deduplication from torch._dynamo.graph_utils import _detect_cycles from torch._dynamo.output_graph import FakeRootModule +======= +from torch._dynamo.graph_utils import _detect_cycles +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._dynamo.test_case import TestCase from torch._dynamo.testing import ( AotEagerAndRecordGraphs, extract_graph_and_tracker, normalize_gm, ) +<<<<<<< HEAD from torch.compiler import allow_in_graph +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.utils._ordered_set import OrderedSet @@ -1109,6 +1116,7 @@ def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor): """, ) +<<<<<<< HEAD def test_tuple_return(self): @allow_in_graph def tuple_return(x, y): @@ -1224,6 +1232,8 @@ def fn(x0, x1, x2, y0, y1, y2): fn_opt(*args) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_graph_region_tracker.py b/test/dynamo/test_graph_region_tracker.py index ce456596fd55e..be851f2dfaecc 100644 --- a/test/dynamo/test_graph_region_tracker.py +++ b/test/dynamo/test_graph_region_tracker.py @@ -1,5 +1,9 @@ # Owner(s): ["module: dynamo"] import contextlib +<<<<<<< HEAD +======= +import os +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch.fx @@ -8,6 +12,31 @@ from torch.utils._pytree import tree_map +<<<<<<< HEAD +======= +def get_nodes_by_name(graph, names): + nodes = [] + for node in graph.nodes: + if node.name in names: + nodes.append(node) + + return nodes + + +unique_ind = 0 + + +def track_same_nodes(names, graph, region_tracker): + global unique_ind + unique_ind += 1 + # find nodes in graph with names and track them + # as if they were at the same code location + nodes = get_nodes_by_name(graph, names) + for node in nodes: + region_tracker.track_node("x", unique_ind, node) + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class GraphRegionTrackerTests(TestCase): def setUp(self): self.exit_stack = contextlib.ExitStack() @@ -195,6 +224,24 @@ def fn(x, y, z): ) def test_mismatched_global_state(self): +<<<<<<< HEAD +======= + @contextlib.contextmanager + def _hip_allow_tf32(): + # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new + # and only for MI300+ + hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None) + os.environ["HIPBLASLT_ALLOW_TF32"] = "1" + + try: + yield + finally: + if hip_allow_tf32 is not None: + os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32 + else: + del os.environ["HIPBLASLT_ALLOW_TF32"] + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def inner_fn(x, y): x1 = x * 1 y1 = y + 1 @@ -235,6 +282,7 @@ def set_default_dtype_bfloat16(): def reset_default_dtype(): torch.set_default_dtype(old_dtype) +<<<<<<< HEAD for ctx in [ lambda: torch.set_grad_enabled(False), torch.autograd.grad_mode.inference_mode, @@ -258,6 +306,33 @@ def reset_default_dtype(): """[[['x1_2', 'y1_2', 'sum_3', 'o0'], ['x1_3', 'y1_3', 'sum_4', 'o2']], \ [['x1', 'y1', 'sum_1', 'o4'], ['x1_1', 'y1_1', 'sum_2', 'o5']]]""", ) +======= + tf32_ctx = _hip_allow_tf32 if torch.version.hip else contextlib.nullcontext + with tf32_ctx(): + for ctx in [ + lambda: torch.set_grad_enabled(False), + torch.autograd.grad_mode.inference_mode, + lambda: torch.autograd.graph.disable_saved_tensors_hooks( + "This is not supported" + ), + # lambda: torch.set_num_threads(2), : Unsupported + (set_default_dtype_bfloat16, reset_default_dtype), + ( + lambda: torch.use_deterministic_algorithms(True), + lambda: torch.use_deterministic_algorithms(False), + ), + # (lambda: torch.use_deterministic_algorithms(True, warn_only=True), + # lambda: torch.use_deterministic_algorithms(False)), : Unsupported + create_toggle_fns("allow_bf16_reduced_precision_reduction"), + create_toggle_fns("allow_fp16_reduced_precision_reduction"), + create_toggle_fns("allow_tf32"), + ]: + self.assertExpectedInline( + self.get_result(fn, torch.rand(10, 10), torch.ones(10, 20), ctx), + """[[['x1_2', 'y1_2', 'sum_3', 'o0'], ['x1_3', 'y1_3', 'sum_4', 'o2']], \ +[['x1', 'y1', 'sum_1', 'o4'], ['x1_1', 'y1_1', 'sum_2', 'o5']]]""", + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_mutation_tracking_simple(self): def fn(x, y, z): @@ -330,6 +405,7 @@ def fn(x, y): """[[['y', 'o1'], ['y_1', 'o2'], ['y_2', 'o3']]]""", ) +<<<<<<< HEAD def test_region_sorting(self): from torch._dynamo.graph_region_tracker import _sort_with_ref_region @@ -367,6 +443,8 @@ def fn(x, y): key = next(iter(tracker.node_to_duplicates.keys())) tracker.track_node(None, key) # this will fail if the node is added again +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_guard_manager.py b/test/dynamo/test_guard_manager.py index c4ad29f69b438..b04bb989a50ba 100644 --- a/test/dynamo/test_guard_manager.py +++ b/test/dynamo/test_guard_manager.py @@ -1,7 +1,11 @@ # Owner(s): ["module: dynamo"] +<<<<<<< HEAD import abc import functools import inspect +======= +import functools +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import unittest import weakref @@ -70,8 +74,12 @@ def less_match_verbose_code_parts(expected): class GuardManagerTests(torch._dynamo.test_case.TestCase): def test_global_state_guard(self): +<<<<<<< HEAD root = RootGuardManager() guard = guards.GLOBAL_STATE(root, ["global_state_check"]) +======= + guard = guards.GLOBAL_STATE(["global_state_check"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue(guard(None)) with set_default_dtype(torch.double): self.assertFalse(guard(None)) @@ -112,9 +120,13 @@ def test_global_state_reason(self): self.assertEqual(guards.reason(), "grad_mode ") def test_python_lambda_leaf_guard(self): +<<<<<<< HEAD root = RootGuardManager() const_guard = guards.LAMBDA_GUARD( root, +======= + const_guard = guards.LAMBDA_GUARD( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) functools.partial(equals_match, expected=5), equals_match_verbose_code_parts(5), ) @@ -123,16 +135,25 @@ def test_python_lambda_leaf_guard(self): self.assertFalse(const_guard("foo")) def test_type_guard(self): +<<<<<<< HEAD root = RootGuardManager() foo = 4 guard = guards.TYPE_MATCH(root, id_type(foo), ["type(x) == int"]) +======= + foo = 4 + guard = guards.TYPE_MATCH(id_type(foo), ["type(x) == int"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue(guard(5)) self.assertTrue(guard(4)) self.assertFalse(guard("foo")) foo = {"a": 1} +<<<<<<< HEAD guard = guards.TYPE_MATCH(root, id_type(foo), ["type(x) == dict"]) +======= + guard = guards.TYPE_MATCH(id_type(foo), ["type(x) == dict"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue(guard(foo)) self.assertTrue(guard({})) self.assertFalse(guard(5)) @@ -145,32 +166,50 @@ def __init__(self, x, y): foo = Foo(1, 2) +<<<<<<< HEAD guard = guards.TYPE_MATCH(root, id_type(foo), ["type(x) == Foo"]) +======= + guard = guards.TYPE_MATCH(id_type(foo), ["type(x) == Foo"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue(guard(foo)) self.assertFalse(guard({})) self.assertFalse(guard(5)) self.assertFalse(guard("foo")) def test_id_guard(self): +<<<<<<< HEAD root = RootGuardManager() foo = 4 guard = guards.ID_MATCH(root, id(foo), ["id(x) == id(foo)"]) +======= + foo = 4 + guard = guards.ID_MATCH(id(foo), ["id(x) == id(foo)"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue(guard(foo)) self.assertFalse(guard(5)) self.assertFalse(guard("foo")) foo = {"a": 1} +<<<<<<< HEAD guard = guards.ID_MATCH(root, id(foo), ["id(x) == id(foo)"]) +======= + guard = guards.ID_MATCH(id(foo), ["id(x) == id(foo)"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue(guard(foo)) self.assertFalse(guard({"a": 1})) self.assertFalse(guard({})) self.assertFalse(guard(5)) def test_equals_guard(self): +<<<<<<< HEAD root = RootGuardManager() foo = 4 guard = guards.EQUALS_MATCH(root, foo, ["x == 4"]) +======= + foo = 4 + guard = guards.EQUALS_MATCH(foo, ["x == 4"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue(guard(4)) self.assertFalse(guard(5)) @@ -178,7 +217,11 @@ def test_equals_guard(self): # tuple foo = (1, 2, 3) +<<<<<<< HEAD guard = guards.EQUALS_MATCH(root, foo, ["x == foo"]) +======= + guard = guards.EQUALS_MATCH(foo, ["x == foo"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue(guard(foo)) self.assertTrue(guard((1, 2, 3))) self.assertFalse(guard((1, 2, 3, 4))) @@ -186,22 +229,35 @@ def test_equals_guard(self): # list foo = [1, 2, 3] +<<<<<<< HEAD guard = guards.EQUALS_MATCH(root, foo, ["x == foo"]) +======= + guard = guards.EQUALS_MATCH(foo, ["x == foo"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue(guard(foo)) self.assertTrue(guard([1, 2, 3])) self.assertFalse(guard([1, 2, 3, 4])) # type foo = int +<<<<<<< HEAD guard = guards.EQUALS_MATCH(root, foo, ["x == foo"]) +======= + guard = guards.EQUALS_MATCH(foo, ["x == foo"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue(guard(foo)) self.assertTrue(guard(int)) self.assertFalse(guard(float)) def test_default_device_guard(self): +<<<<<<< HEAD root = RootGuardManager() foo = 1 guard = guards.DEFAULT_DEVICE(root, ["cpu device"]) +======= + foo = 1 + guard = guards.DEFAULT_DEVICE(["cpu device"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue(guard(foo)) try: @@ -211,15 +267,23 @@ def test_default_device_guard(self): torch.set_default_device(None) def test_length_check_guard(self): +<<<<<<< HEAD root = RootGuardManager() foo = [1, 2, 3] guard = guards.LENGTH_CHECK(root, len(foo), ["len(x) == len(foo)"]) +======= + foo = [1, 2, 3] + guard = guards.LENGTH_CHECK(len(foo), ["len(x) == len(foo)"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue(guard(foo)) self.assertFalse(guard([])) def test_no_hasattr_guard(self): +<<<<<<< HEAD root = RootGuardManager() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class Bar: def __init__(self) -> None: self.bar = 2 @@ -232,7 +296,11 @@ def __init__(self) -> None: foo = Foo() +<<<<<<< HEAD guard = guards.NO_HASATTR(root, "foo", ["hasattr(x, 'foo') == False"]) +======= + guard = guards.NO_HASATTR("foo", ["hasattr(x, 'foo') == False"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue(guard(bar)) self.assertFalse(guard(foo)) @@ -270,9 +338,14 @@ def __init__(self, x, y): self.assertFalse(guard_manager.check(f_locals_unaliased)) def test_dict_version_guard(self): +<<<<<<< HEAD root = RootGuardManager() foo = {"a": 1, "b": 2} guard = guards.DICT_VERSION(root, foo, ["x.version == foo.version"]) +======= + foo = {"a": 1, "b": 2} + guard = guards.DICT_VERSION(foo, ["x.version == foo.version"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue(guard(foo)) self.assertFalse(guard(dict(foo))) @@ -282,9 +355,14 @@ def test_dict_version_guard(self): self.assertFalse(guard({})) def test_dynamic_indices_guard(self): +<<<<<<< HEAD root = RootGuardManager() guard1 = guards.DYNAMIC_INDICES(root, set(), ["x.size(0) == y.size(0)"]) guard2 = guards.DYNAMIC_INDICES(root, set({0, 1}), ["x.size(0) == y.size(0)"]) +======= + guard1 = guards.DYNAMIC_INDICES(set(), ["x.size(0) == y.size(0)"]) + guard2 = guards.DYNAMIC_INDICES(set({0, 1}), ["x.size(0) == y.size(0)"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x = torch.randn(4) self.assertTrue(guard1(x)) @@ -382,20 +460,32 @@ def __init__(self, x, y, z): self.assertFalse(guard_manager.check_verbose(f_locals_unaliased).result) def test_weakref_alive_guard(self): +<<<<<<< HEAD root = RootGuardManager() x = torch.rand(3, 4) weakref_x = weakref.ref(x) guard = guards.NOT_NONE(root, ["weakref_x is not None"]) +======= + x = torch.rand(3, 4) + weakref_x = weakref.ref(x) + + guard = guards.NOT_NONE(["weakref_x is not None"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue(guard(weakref_x())) del x self.assertFalse(guard(weakref_x())) @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") def test_call_function_no_args_guard(self): +<<<<<<< HEAD root = RootGuardManager() x = torch.cuda.current_device() guard = guards.EQUALS_MATCH(root, x, [0]) +======= + x = torch.cuda.current_device() + guard = guards.EQUALS_MATCH(x, [0]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue(guard(0)) self.assertFalse(guard(1)) self.assertFalse(guard(2)) @@ -713,16 +803,25 @@ def fn(x): self.assertTrue("Test" in debug_info.verbose_code_parts[0]) def test_dict_contains_guard(self): +<<<<<<< HEAD root = RootGuardManager() foo = {"a": 1, "b": 2} guard = guards.DICT_CONTAINS(root, True, "a", ["has a"]) +======= + foo = {"a": 1, "b": 2} + guard = guards.DICT_CONTAINS(True, "a", ["has a"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue(guard(foo)) self.assertTrue(guard({"a": 1, "b": 2})) self.assertFalse(guard({"b": 2, "c": 3})) self.assertFalse(guard({})) +<<<<<<< HEAD guard = guards.DICT_CONTAINS(root, False, "c", ["not has c"]) +======= + guard = guards.DICT_CONTAINS(False, "c", ["not has c"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue(guard(foo)) self.assertTrue(guard({"a": 1, "b": 2})) self.assertFalse(guard({"b": 2, "c": 3})) @@ -813,7 +912,11 @@ def test_clone(self): except ImportError: from utils import install_guard_manager_testing_hook +<<<<<<< HEAD def hook(guard_wrapper, f_locals, builder): +======= + def hook(guard_wrapper, f_locals): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) root = guard_wrapper.root # Check full cloning works as expected @@ -853,7 +956,11 @@ def test_diff_guard_manager(self): from utils import install_guard_manager_testing_hook counter = 0 +<<<<<<< HEAD def hook(guard_wrapper, f_locals, builder): +======= + def hook(guard_wrapper, f_locals): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) nonlocal counter root = guard_wrapper.root diff_guard_root = guard_wrapper.diff_guard_root @@ -882,9 +989,14 @@ def hook(guard_wrapper, f_locals, builder): counter += 1 class Bar: +<<<<<<< HEAD def __init__(self): self.x = 4 self.y = torch.randn(4) +======= + x = 4 + y = torch.randn(4) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bar = Bar() @@ -901,6 +1013,7 @@ def fn(x, foo, bar): opt_fn(x, foo, bar) +<<<<<<< HEAD class TypePropagationTests(torch._dynamo.test_case.TestCase): @torch._dynamo.config.patch(skip_tensor_guards_with_matching_dict_tags=True) def test_basic_types(self): @@ -1360,6 +1473,8 @@ def max_size_test(guard_wrapper, f_locals, builder): opt_fn(x) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_guard_serialization.py b/test/dynamo/test_guard_serialization.py index e826492089f63..e2fdfeed11761 100644 --- a/test/dynamo/test_guard_serialization.py +++ b/test/dynamo/test_guard_serialization.py @@ -235,7 +235,10 @@ def __hash__(self): pytree.register_constant(CustomConstantType) +<<<<<<< HEAD @torch._dynamo.config.patch({"strict_precompile": True}) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestGuardSerialization(torch._inductor.test_case.TestCase): def test_function_locals(self): def foo(x): @@ -255,14 +258,21 @@ def _tracefunc(self, frame, event, arg): self._frame_state = _FrameState( f_locals=dict(frame.f_locals), +<<<<<<< HEAD f_globals=frame.f_globals, +======= + f_globals=dict(frame.f_globals), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f_code=frame.f_code, f_builtins=frame.f_builtins, ) def _test_serialization(self, guard_type, fn, *args, **kwargs): # kwargs might contain a callable that generates kwargs +<<<<<<< HEAD torch._dynamo.reset() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kwarg_gen_fn = kwargs.get("_gen_fn", None) if kwarg_gen_fn is not None: kwargs = kwarg_gen_fn() @@ -307,9 +317,12 @@ def transform(instructions: list, code_options: dict[str, object]): nonlocal ref_gm nonlocal loaded_gm +<<<<<<< HEAD torch._dynamo.convert_frame.initial_global_state = ( torch._C._dynamo.guards.GlobalStateGuard() ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tracer = InstructionTranslator( instructions, self._frame_state.f_code, @@ -338,36 +351,54 @@ def transform(instructions: list, code_options: dict[str, object]): ): tracer.run() +<<<<<<< HEAD ref_gm = CheckFunctionManager( self._frame_state.f_code, tracer.output, guard_filter_fn=guard_filter_fn, ).guard_manager +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) check_fn_manager = CheckFunctionManager( self._frame_state.f_code, tracer.output, guard_filter_fn=guard_filter_fn, +<<<<<<< HEAD save_guards=True, ) guards_state = check_fn_manager.guards_state self._cached_guards_state = guards_state self._cached_f_code = self._frame_state.f_code +======= + guards_serialization_mode="save", + ) + ref_gm = check_fn_manager.guard_manager + guards_state = check_fn_manager.guards_state +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertIsNotNone(guards_state) guards_state = pickle.loads(guards_state) check_fn_manager = CheckFunctionManager( self._frame_state.f_code, guards_state.output_graph, +<<<<<<< HEAD shape_code_parts=guards_state.shape_code_parts, runtime_global_scope=self._frame_state.f_globals, +======= + guards_serialization_mode="load", + shape_code_parts=guards_state.shape_code_parts, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) loaded_gm = check_fn_manager.guard_manager try: transform_code_object(self._frame_state.f_code, transform) finally: +<<<<<<< HEAD torch._dynamo.convert_frame.initial_global_state = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._frame_state = None self.assertIsNotNone(ref_gm) @@ -879,6 +910,7 @@ def fn(x): ): self._test_serialization("ID_MATCH", fn, torch.randn(3)) +<<<<<<< HEAD @torch._dynamo.config.patch(caching_precompile=True) def test_id_match_with_config(self): def fn(x): @@ -896,6 +928,8 @@ def fn(x): ref, loaded = self._test_serialization("FUNCTION_MATCH", fn, torch.randn(3)) self._test_check_fn(ref, loaded, {"x": torch.randn(3)}, True) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_dispatch_key_set_match(self): def fn(x, dks): if dks.has("CPU"): @@ -1067,10 +1101,17 @@ def fn(x, x_): return x + x_ x = torch.randn(3, 2) +<<<<<<< HEAD ref, loaded = self._test_serialization("DUPLICATE_INPUT", fn, x, x) self._test_check_fn(ref, loaded, {"x": x, "x_": x}, True) self._test_check_fn(ref, loaded, {"x": x, "x_": torch.randn(3, 2)}, False) +======= + with self.assertRaisesRegex( + PackageError, "DUPLICATE_INPUT guard cannot be serialized" + ): + self._test_serialization("DUPLICATE_INPUT", fn, x, x) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_weakref_alive(self): mod = torch.nn.Linear(10, 10, bias=False) @@ -1168,6 +1209,7 @@ def fn(x): with torch.enable_grad(): self._test_check_fn(ref, loaded, {"x": x}, True) +<<<<<<< HEAD def test_grad_mode_loading(self): def fn(x): return x + 1 @@ -1186,6 +1228,8 @@ def fn(x): loaded = check_fn_manager.guard_manager self._test_check_fn(ref, loaded, {"x": x}, False) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_deterministic_algorithms(self): def fn(x): return x + 1 @@ -1301,6 +1345,7 @@ def fn(x): self._test_check_fn(ref, loaded, {"x": torch.randn(3, 11, 2)}, False) self._test_check_fn(ref, loaded, {"x": torch.randn(3, 2, 2)}, False) +<<<<<<< HEAD def test_builtin_match(self): def fn(x): # usage of getattr() here installs a BUILTIN_MATCH guard @@ -1346,6 +1391,8 @@ def forward(self, x): ref, loaded = self._test_serialization("TENSOR_MATCH", m, torch.randn(3, 2)) self._test_check_fn(ref, loaded, {"self": m, "x": torch.randn(3, 2)}, True) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index 9f093d4dc0cea..256cbbb1061be 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -38,8 +38,16 @@ xfailIfTorchDynamo, ) from torch.testing._internal.hop_db import hop_db +<<<<<<< HEAD from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test from torch.testing._internal.triton_utils import requires_cuda_and_triton +======= +from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test + + +requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def count_ops(gm, args, freq, op): @@ -1183,7 +1191,11 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): pred = a.sum() > 0 with self.assertRaisesRegex( NotImplementedError, +<<<<<<< HEAD "no rule registered for HigherOrderOperator cond and mode .*MyMode", +======= + "no rule registered for HOP cond and mode .*MyMode", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): with MyMode(): res = cond_op(pred, torch.sin, torch.cos, (a,)) @@ -2130,7 +2142,11 @@ def false_fn(x): and node.target == torch.ops.higher_order.cond ): _, _, _, operands = node.args +<<<<<<< HEAD # Since we compile with dynamic, each branch takes 4 inputs (buffer, x, z, s1) +======= + # Since we compile wit dynamic, each branch takes 4 inputs (buffer, x, z, s1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(len(operands), 4) if node.op == "get_attr": if str(node.target) in ("cond_true_0, cond_false_0"): @@ -2608,17 +2624,36 @@ def f(x): f, default_args_generator((x,)), arg_count, expected_opcount=3 ) +<<<<<<< HEAD def test_support_float_in_output(self): counters.clear() cnt = CompileCounter() @torch.compile(backend=cnt, fullgraph=True) +======= + def test_fallback_on_python_primitives_output(self): + counters.clear() + cnt = CompileCounter() + + @torch.compile(backend=cnt) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def f(x): return wrap(lambda x: [1, torch.sin(x), 2.0], x) x = torch.randn(3) result = f(x) self.assertEqual(result, [1, torch.sin(x), 2.0]) +<<<<<<< HEAD +======= + self.assertEqual(cnt.frame_count, 0) + assert_dict_matches_regex( + self, + dict(counters["graph_break"]), + { + ".*HigherOrderOperator body's output must consist of tensors or ints only but got": 1 + }, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_nested_tuple_output(self): def f(x): @@ -3076,6 +3111,7 @@ def forward(self, L_a_ : torch.SymInt, L_b_ : torch.SymInt, L_c_ : torch.SymInt, b = torch.arange(l_b_) c = torch.arange(l_c_) d = torch.arange(l_d_) +<<<<<<< HEAD lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None _vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(l_d_, 'error'); _vmap_increment_nesting = None child = torch._functorch.predispatch._add_batch_dim(d, 0, 1); d = None @@ -3099,6 +3135,31 @@ def forward(self, L_a_ : torch.SymInt, L_b_ : torch.SymInt, L_c_ : torch.SymInt, _vmap_decrement_nesting_2 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_2 = None _remove_batch_dim_3 = torch._functorch.predispatch._remove_batch_dim(batched_outputs_3, 1, l_d_, 0); batched_outputs_3 = l_d_ = None _vmap_decrement_nesting_3 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_3 = None +======= + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(l_d_, 'error'); _vmap_increment_nesting = None + child = torch._C._functorch._add_batch_dim(d, 0, 1); d = None + lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_1 = None + _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(l_c_, 'error'); _vmap_increment_nesting_1 = None + child_1 = torch._C._functorch._add_batch_dim(c, 0, 2); c = None + lazy_load_decompositions_2 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_2 = None + _vmap_increment_nesting_2 = torch._C._functorch._vmap_increment_nesting(l_b_, 'error'); _vmap_increment_nesting_2 = None + child_2 = torch._C._functorch._add_batch_dim(b, 0, 3); b = None + lazy_load_decompositions_3 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_3 = None + _vmap_increment_nesting_3 = torch._C._functorch._vmap_increment_nesting(l_a_, 'error'); _vmap_increment_nesting_3 = None + _add_batch_dim_3 = torch._C._functorch._add_batch_dim(a, 0, 4); a = None + add = _add_batch_dim_3 + child_2; _add_batch_dim_3 = child_2 = None + add_1 = add + child_1; add = child_1 = None + batched_outputs = add_1 + child; add_1 = child = None + batched_outputs_1 = torch._C._functorch._remove_batch_dim(batched_outputs, 4, l_a_, 0); batched_outputs = l_a_ = None + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None + batched_outputs_2 = torch._C._functorch._remove_batch_dim(batched_outputs_1, 3, l_b_, 0); batched_outputs_1 = l_b_ = None + _vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None + batched_outputs_3 = torch._C._functorch._remove_batch_dim(batched_outputs_2, 2, l_c_, 0); batched_outputs_2 = l_c_ = None + _vmap_decrement_nesting_2 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_2 = None + _remove_batch_dim_3 = torch._C._functorch._remove_batch_dim(batched_outputs_3, 1, l_d_, 0); batched_outputs_3 = l_d_ = None + _vmap_decrement_nesting_3 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_3 = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return (_remove_batch_dim_3,)""", # noqa: B950 ) @@ -3731,11 +3792,19 @@ def forward(self, L_x_: "f32[4, 3]"): child: "f32[12, 4, 3]" = chunk.view(12, 4, 3); chunk = None +<<<<<<< HEAD lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None _vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None child_1: "f32[4, 3]" = torch._functorch.predispatch._add_batch_dim(child, 0, 1); child = None +======= + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None + + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None + + child_1: "f32[4, 3]" = torch._C._functorch._add_batch_dim(child, 0, 1); child = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None @@ -3778,18 +3847,32 @@ def forward(self, L_x_: "f32[4, 3]"): basis: "f32[12, 4, 3]" = chunk_1.view(12, 4, 3); chunk_1 = None +<<<<<<< HEAD lazy_load_decompositions_1 = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions_1 = None _vmap_increment_nesting_1 = torch._functorch.predispatch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting_1 = None _add_batch_dim_1: "f32[4, 3]" = torch._functorch.predispatch._add_batch_dim(basis, 0, 3); basis = None +======= + lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_1 = None + + _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting_1 = None + + _add_batch_dim_1: "f32[4, 3]" = torch._C._functorch._add_batch_dim(basis, 0, 3); basis = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _autograd_grad = torch._functorch.eager_transforms._autograd_grad([primals_out], [diff_primals], [_add_batch_dim_1], retain_graph = True, create_graph = True); primals_out = diff_primals = _add_batch_dim_1 = None batched_outputs: "f32[4, 3]" = _autograd_grad[0]; _autograd_grad = None +<<<<<<< HEAD chunked_result: "f32[12, 4, 3]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs, 3, 12, 0); batched_outputs = None _vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +======= + chunked_result: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 3, 12, 0); batched_outputs = None + + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) split = chunked_result.split((12,), dim = 0); chunked_result = None split_1: "f32[12, 4, 3]" = split[0]; split = None @@ -3808,9 +3891,15 @@ def forward(self, L_x_: "f32[4, 3]"): _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None +<<<<<<< HEAD results_1: "f32[12, 4, 3, 4, 3]" = torch._functorch.predispatch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None _vmap_decrement_nesting_1 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None +======= + results_1: "f32[12, 4, 3, 4, 3]" = torch._C._functorch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None + + _vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) movedim: "f32[4, 3, 4, 3, 12]" = results_1.movedim(0, -1); results_1 = None split_2 = movedim.split((12,), dim = -1); movedim = None @@ -3859,11 +3948,19 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): child: "f32[12, 3, 4]" = chunk.view(12, 3, 4); chunk = None +<<<<<<< HEAD lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None _vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None child_1: "f32[3, 4]" = torch._functorch.predispatch._add_batch_dim(child, 0, 1); child = None +======= + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None + + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None + + child_1: "f32[3, 4]" = torch._C._functorch._add_batch_dim(child, 0, 1); child = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None @@ -3908,18 +4005,32 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): basis: "f32[12, 4, 3]" = chunk_1.view(12, 4, 3); chunk_1 = None +<<<<<<< HEAD lazy_load_decompositions_1 = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions_1 = None _vmap_increment_nesting_1 = torch._functorch.predispatch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting_1 = None _add_batch_dim_1: "f32[4, 3]" = torch._functorch.predispatch._add_batch_dim(basis, 0, 3); basis = None +======= + lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_1 = None + + _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting_1 = None + + _add_batch_dim_1: "f32[4, 3]" = torch._C._functorch._add_batch_dim(basis, 0, 3); basis = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _autograd_grad = torch._functorch.eager_transforms._autograd_grad([primals_out], [child_4], [_add_batch_dim_1], retain_graph = True, create_graph = True); primals_out = child_4 = _add_batch_dim_1 = None child_5: "f32[3, 4]" = _autograd_grad[0]; _autograd_grad = None +<<<<<<< HEAD child_6: "f32[12, 3, 4]" = torch._functorch.predispatch._remove_batch_dim(child_5, 3, 12, 0); child_5 = None _vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +======= + child_6: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(child_5, 3, 12, 0); child_5 = None + + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) split = child_6.split((12,), dim = 0); child_6 = None split_1: "f32[12, 3, 4]" = split[0]; split = None @@ -3939,9 +4050,15 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None +<<<<<<< HEAD child_10: "f32[12, 4, 3, 3, 4]" = torch._functorch.predispatch._remove_batch_dim(child_9, 1, 12, 0); child_9 = None _vmap_decrement_nesting_1 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None +======= + child_10: "f32[12, 4, 3, 3, 4]" = torch._C._functorch._remove_batch_dim(child_9, 1, 12, 0); child_9 = None + + _vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) movedim: "f32[4, 3, 3, 4, 12]" = child_10.movedim(0, -1); child_10 = None split_2 = movedim.split((12,), dim = -1); movedim = None @@ -4006,18 +4123,32 @@ def forward(self, L_x_: "f32[4, 3]"): basis: "f32[12, 4, 3]" = chunk.view(12, 4, 3); chunk = None +<<<<<<< HEAD lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None _vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None _add_batch_dim: "f32[4, 3]" = torch._functorch.predispatch._add_batch_dim(basis, 0, 1); basis = None +======= + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None + + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None + + _add_batch_dim: "f32[4, 3]" = torch._C._functorch._add_batch_dim(basis, 0, 1); basis = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _autograd_grad = torch._functorch.eager_transforms._autograd_grad([primals_out], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); primals_out = diff_primals = _add_batch_dim = None batched_outputs: "f32[4, 3]" = _autograd_grad[0]; _autograd_grad = None +<<<<<<< HEAD chunked_result: "f32[12, 4, 3]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None _vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +======= + chunked_result: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None + + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) split = chunked_result.split((12,), dim = 0); chunked_result = None split_1: "f32[12, 4, 3]" = split[0]; split = None @@ -4084,18 +4215,32 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): basis: "f32[12, 3, 4]" = chunk.view(12, 3, 4); chunk = None +<<<<<<< HEAD lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None _vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None _add_batch_dim: "f32[3, 4]" = torch._functorch.predispatch._add_batch_dim(basis, 0, 1); basis = None +======= + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None + + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None + + _add_batch_dim: "f32[3, 4]" = torch._C._functorch._add_batch_dim(basis, 0, 1); basis = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _autograd_grad = torch._functorch.eager_transforms._autograd_grad([primals_out], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); primals_out = diff_primals = _add_batch_dim = None batched_outputs: "f32[3, 4]" = _autograd_grad[0]; _autograd_grad = None +<<<<<<< HEAD chunked_result: "f32[12, 3, 4]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None _vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +======= + chunked_result: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None + + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) split = chunked_result.split((12,), dim = 0); chunked_result = None split_1: "f32[12, 3, 4]" = split[0]; split = None @@ -4164,18 +4309,32 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): basis: "f32[12, 3, 4]" = chunk.view(12, 3, 4); chunk = None +<<<<<<< HEAD lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None _vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None _add_batch_dim: "f32[3, 4]" = torch._functorch.predispatch._add_batch_dim(basis, 0, 1); basis = None +======= + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None + + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None + + _add_batch_dim: "f32[3, 4]" = torch._C._functorch._add_batch_dim(basis, 0, 1); basis = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _autograd_grad = torch._functorch.eager_transforms._autograd_grad([primals_out], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); primals_out = diff_primals = _add_batch_dim = None batched_outputs: "f32[3, 4]" = _autograd_grad[0]; _autograd_grad = None +<<<<<<< HEAD chunked_result: "f32[12, 3, 4]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None _vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +======= + chunked_result: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None + + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) split = chunked_result.split((12,), dim = 0); chunked_result = None split_1: "f32[12, 3, 4]" = split[0]; split = None @@ -4472,6 +4631,7 @@ def wrapper_fn(model, params, buffers, inputs): if torch._dynamo.config.inline_inbuilt_nn_modules: expected = """\ class GraphModule(torch.nn.Module): +<<<<<<< HEAD def forward(self, L_inputs_: "f32[1, 1]", L_model_modules_l1_parameters_weight_: "f32[1, 1]", L_model_modules_l1_parameters_bias_: "f32[1]", L_model_buffers_buffer_: "f32[1]"): l_inputs_ = L_inputs_ l_model_modules_l1_parameters_weight_ = L_model_modules_l1_parameters_weight_ @@ -4479,6 +4639,17 @@ def forward(self, L_inputs_: "f32[1, 1]", L_model_modules_l1_parameters_weight_: l_model_buffers_buffer_ = L_model_buffers_buffer_ linear: "f32[1, 1]" = torch._C._nn.linear(l_inputs_, l_model_modules_l1_parameters_weight_, l_model_modules_l1_parameters_bias_); l_inputs_ = l_model_modules_l1_parameters_weight_ = l_model_modules_l1_parameters_bias_ = None add: "f32[1, 1]" = linear + l_model_buffers_buffer_; linear = l_model_buffers_buffer_ = None +======= + def forward(self, L_params_l1_weight_: "f32[1, 1]", L_params_l1_bias_: "f32[1]", L_buffers_buffer_: "f32[1]", L_inputs_: "f32[1, 1]"): + l_params_l1_weight_ = L_params_l1_weight_ + l_params_l1_bias_ = L_params_l1_bias_ + l_buffers_buffer_ = L_buffers_buffer_ + l_inputs_ = L_inputs_ + + linear: "f32[1, 1]" = torch._C._nn.linear(l_inputs_, l_params_l1_weight_, l_params_l1_bias_); l_inputs_ = l_params_l1_weight_ = l_params_l1_bias_ = None + + add: "f32[1, 1]" = linear + l_buffers_buffer_; linear = l_buffers_buffer_ = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return (add,) """ # We found Windows/Linux have some empty line difference, empty_line_normalizer will help fix it. @@ -5221,11 +5392,19 @@ def forward(self, L_x_: "f32[4, 3]"): child: "f32[12, 4, 3]" = chunk.view(12, 4, 3); chunk = None +<<<<<<< HEAD lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None _vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None child_1: "f32[4, 3]" = torch._functorch.predispatch._add_batch_dim(child, 0, 1); child = None +======= + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None + + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None + + child_1: "f32[4, 3]" = torch._C._functorch._add_batch_dim(child, 0, 1); child = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None @@ -5251,9 +5430,15 @@ def forward(self, L_x_: "f32[4, 3]"): _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None +<<<<<<< HEAD results: "f32[12, 4, 3]" = torch._functorch.predispatch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None _vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +======= + results: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None + + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) movedim: "f32[4, 3, 12]" = results.movedim(0, -1); results = None split = movedim.split((12,), dim = -1); movedim = None @@ -5302,11 +5487,19 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): child: "f32[12, 3, 4]" = chunk.view(12, 3, 4); chunk = None +<<<<<<< HEAD lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None _vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None child_1: "f32[3, 4]" = torch._functorch.predispatch._add_batch_dim(child, 0, 1); child = None +======= + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None + + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None + + child_1: "f32[3, 4]" = torch._C._functorch._add_batch_dim(child, 0, 1); child = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None @@ -5333,9 +5526,15 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None +<<<<<<< HEAD results: "f32[12, 3, 4]" = torch._functorch.predispatch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None _vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +======= + results: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None + + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) movedim: "f32[3, 4, 12]" = results.movedim(0, -1); results = None split = movedim.split((12,), dim = -1); movedim = None @@ -5384,11 +5583,19 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): child: "f32[12, 3, 4]" = chunk.view(12, 3, 4); chunk = None +<<<<<<< HEAD lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None _vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None child_1: "f32[3, 4]" = torch._functorch.predispatch._add_batch_dim(child, 0, 1); child = None +======= + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None + + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None + + child_1: "f32[3, 4]" = torch._C._functorch._add_batch_dim(child, 0, 1); child = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None @@ -5417,10 +5624,17 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None +<<<<<<< HEAD results: "f32[12, 3, 4]" = torch._functorch.predispatch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None aux_2: "f32[12, 4, 3]" = torch._functorch.predispatch._remove_batch_dim(aux_1, 1, 12, 0); aux_1 = None _vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +======= + results: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None + aux_2: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(aux_1, 1, 12, 0); aux_1 = None + + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) aux_3: "f32[4, 3]" = aux_2[0]; aux_2 = None @@ -5471,11 +5685,19 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): child: "f32[12, 4, 3]" = chunk.view(12, 4, 3); chunk = None +<<<<<<< HEAD lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None _vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(12, 'same'); _vmap_increment_nesting = None child_1: "f32[4, 3]" = torch._functorch.predispatch._add_batch_dim(child, 0, 1); child = None +======= + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None + + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'same'); _vmap_increment_nesting = None + + child_1: "f32[4, 3]" = torch._C._functorch._add_batch_dim(child, 0, 1); child = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None @@ -5509,10 +5731,17 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None +<<<<<<< HEAD child_8: "f32[12, 3, 4]" = torch._functorch.predispatch._remove_batch_dim(child_6, 1, 12, 0); child_6 = None child_9: "f32[12, 4, 3]" = torch._functorch.predispatch._remove_batch_dim(child_7, 1, 12, 0); child_7 = None _vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +======= + child_8: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(child_6, 1, 12, 0); child_6 = None + child_9: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(child_7, 1, 12, 0); child_7 = None + + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) movedim: "f32[3, 4, 12]" = child_8.movedim(0, -1); child_8 = None split = movedim.split((12,), dim = -1); movedim = None @@ -6252,19 +6481,33 @@ class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[3, 3, 3]"): l_x_ = L_x_ +<<<<<<< HEAD lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None _vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None _add_batch_dim: "f32[3, 3]" = torch._functorch.predispatch._add_batch_dim(l_x_, 0, 1); l_x_ = None +======= + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None + + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None + + _add_batch_dim: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sum_1: "f32[3]" = _add_batch_dim.sum(0) sum_2: "f32[3]" = _add_batch_dim.sum(1); _add_batch_dim = None batched_outputs: "f32[3]" = sum_1 + sum_2; sum_1 = sum_2 = None +<<<<<<< HEAD _remove_batch_dim: "f32[3, 3]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None _vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +======= + _remove_batch_dim: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None + + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return (_remove_batch_dim,) """, ) @@ -6290,20 +6533,34 @@ class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[3, 3, 3]"): l_x_ = L_x_ +<<<<<<< HEAD lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None _vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None _add_batch_dim: "f32[3, 3]" = torch._functorch.predispatch._add_batch_dim(l_x_, 0, 1); l_x_ = None +======= + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None + + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None + + _add_batch_dim: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sum_1: "f32[3]" = _add_batch_dim.sum(0) sum_2: "f32[3]" = _add_batch_dim.sum(1); _add_batch_dim = None add: "f32[3]" = sum_1 + sum_2; sum_1 = sum_2 = None batched_outputs: "f32[3]" = add + 3; add = None +<<<<<<< HEAD _remove_batch_dim: "f32[3, 3]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None _vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +======= + _remove_batch_dim: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None + + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return (_remove_batch_dim,) """, ) @@ -6330,20 +6587,34 @@ def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3]"): l_x_ = L_x_ l_y_ = L_y_ +<<<<<<< HEAD lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None _vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None _add_batch_dim: "f32[3, 3]" = torch._functorch.predispatch._add_batch_dim(l_x_, 0, 1); l_x_ = None +======= + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None + + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None + + _add_batch_dim: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sum_1: "f32[3]" = _add_batch_dim.sum(0) sum_2: "f32[3]" = _add_batch_dim.sum(1); _add_batch_dim = None add: "f32[3]" = sum_1 + sum_2; sum_1 = sum_2 = None batched_outputs: "f32[3, 3]" = add + l_y_; add = l_y_ = None +<<<<<<< HEAD _remove_batch_dim: "f32[3, 3, 3]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None _vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +======= + _remove_batch_dim: "f32[3, 3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None + + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return (_remove_batch_dim,) """, ) @@ -6371,21 +6642,36 @@ def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3]"): l_x_ = L_x_ l_y_ = L_y_ +<<<<<<< HEAD lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None _vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None _add_batch_dim: "f32[3, 3]" = torch._functorch.predispatch._add_batch_dim(l_x_, 0, 1); l_x_ = None _add_batch_dim_1: "f32[3]" = torch._functorch.predispatch._add_batch_dim(l_y_, 1, 1); l_y_ = None +======= + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None + + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None + + _add_batch_dim: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None + _add_batch_dim_1: "f32[3]" = torch._C._functorch._add_batch_dim(l_y_, 1, 1); l_y_ = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sum_1: "f32[3]" = _add_batch_dim.sum(0) sum_2: "f32[3]" = _add_batch_dim.sum(1); _add_batch_dim = None add: "f32[3]" = sum_1 + sum_2; sum_1 = sum_2 = None batched_outputs: "f32[3]" = add + _add_batch_dim_1; add = _add_batch_dim_1 = None +<<<<<<< HEAD _remove_batch_dim: "f32[3, 3]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None _vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +======= + _remove_batch_dim: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None + + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return (_remove_batch_dim,) """, ) @@ -6415,21 +6701,36 @@ def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3]"): l_x_ = L_x_ l_y_ = L_y_ +<<<<<<< HEAD lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None _vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None _add_batch_dim: "f32[3, 3]" = torch._functorch.predispatch._add_batch_dim(l_x_, 0, 1); l_x_ = None _add_batch_dim_1: "f32[3]" = torch._functorch.predispatch._add_batch_dim(l_y_, 1, 1); l_y_ = None +======= + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None + + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None + + _add_batch_dim: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None + _add_batch_dim_1: "f32[3]" = torch._C._functorch._add_batch_dim(l_y_, 1, 1); l_y_ = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sum_1: "f32[3]" = _add_batch_dim.sum(0) sum_2: "f32[3]" = _add_batch_dim.sum(1); _add_batch_dim = None add: "f32[3]" = sum_1 + sum_2; sum_1 = sum_2 = None batched_outputs: "f32[3]" = add + _add_batch_dim_1; add = _add_batch_dim_1 = None +<<<<<<< HEAD _remove_batch_dim: "f32[3, 3]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None _vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +======= + _remove_batch_dim: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None + + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return (_remove_batch_dim,) """, ) @@ -6455,6 +6756,7 @@ def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3, 3]"): l_x_ = L_x_ l_y_ = L_y_ +<<<<<<< HEAD lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None _vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None @@ -6478,6 +6780,31 @@ def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3, 3]"): _remove_batch_dim_1: "f32[3, 3, 3]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs_1, 1, 3, 0); batched_outputs_1 = None _vmap_decrement_nesting_1 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None +======= + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None + + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None + + child: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None + child_1: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_y_, 0, 1); l_y_ = None + + lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_1 = None + + _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting_1 = None + + _add_batch_dim_2: "f32[3]" = torch._C._functorch._add_batch_dim(child, 1, 2); child = None + _add_batch_dim_3: "f32[3]" = torch._C._functorch._add_batch_dim(child_1, 1, 2); child_1 = None + + batched_outputs: "f32[3]" = _add_batch_dim_2 + _add_batch_dim_3; _add_batch_dim_2 = _add_batch_dim_3 = None + + batched_outputs_1: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 2, 3, 0); batched_outputs = None + + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None + + _remove_batch_dim_1: "f32[3, 3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs_1, 1, 3, 0); batched_outputs_1 = None + + _vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return (_remove_batch_dim_1,) """, ) @@ -6504,6 +6831,7 @@ def forward(self, L_y_: "f32[5, 3]", L_x_: "f32[2, 3]"): l_y_ = L_y_ l_x_ = L_x_ +<<<<<<< HEAD lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None _vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(5, 'error'); _vmap_increment_nesting = None @@ -6525,6 +6853,29 @@ def forward(self, L_y_: "f32[5, 3]", L_x_: "f32[2, 3]"): _remove_batch_dim_1: "f32[5, 3, 2, 3]" = torch._functorch.predispatch._remove_batch_dim(batched_outputs_1, 1, 5, 0); batched_outputs_1 = None _vmap_decrement_nesting_1 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None +======= + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None + + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(5, 'error'); _vmap_increment_nesting = None + + child: "f32[3]" = torch._C._functorch._add_batch_dim(l_y_, 0, 1); l_y_ = None + + lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_1 = None + + _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting_1 = None + + _add_batch_dim_1: "f32[]" = torch._C._functorch._add_batch_dim(child, 0, 2); child = None + + batched_outputs: "f32[2, 3]" = l_x_ * _add_batch_dim_1; l_x_ = _add_batch_dim_1 = None + + batched_outputs_1: "f32[3, 2, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 2, 3, 0); batched_outputs = None + + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None + + _remove_batch_dim_1: "f32[5, 3, 2, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs_1, 1, 5, 0); batched_outputs_1 = None + + _vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return (_remove_batch_dim_1,) """, ) @@ -6549,19 +6900,34 @@ class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[2, 4, 3]"): l_x_ = L_x_ +<<<<<<< HEAD lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None _vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None _add_batch_dim: "f32[4, 3]" = torch._functorch.predispatch._add_batch_dim(l_x_, 0, 1); l_x_ = None +======= + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None + + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None + + _add_batch_dim: "f32[4, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) child: "f32[3]" = _add_batch_dim.sum(0) child_1: "f32[4]" = _add_batch_dim.sum(1); _add_batch_dim = None +<<<<<<< HEAD _remove_batch_dim: "f32[2, 3]" = torch._functorch.predispatch._remove_batch_dim(child, 1, 2, 0); child = None _remove_batch_dim_1: "f32[2, 4]" = torch._functorch.predispatch._remove_batch_dim(child_1, 1, 2, 0); child_1 = None _vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +======= + _remove_batch_dim: "f32[2, 3]" = torch._C._functorch._remove_batch_dim(child, 1, 2, 0); child = None + _remove_batch_dim_1: "f32[2, 4]" = torch._C._functorch._remove_batch_dim(child_1, 1, 2, 0); child_1 = None + + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return (_remove_batch_dim, _remove_batch_dim_1) """, ) @@ -6586,19 +6952,34 @@ class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[2, 4, 3]"): l_x_ = L_x_ +<<<<<<< HEAD lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None _vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None _add_batch_dim: "f32[4, 3]" = torch._functorch.predispatch._add_batch_dim(l_x_, 0, 1); l_x_ = None +======= + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None + + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None + + _add_batch_dim: "f32[4, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) child: "f32[3]" = _add_batch_dim.sum(0) child_1: "f32[4]" = _add_batch_dim.sum(1); _add_batch_dim = None +<<<<<<< HEAD _remove_batch_dim: "f32[3, 2]" = torch._functorch.predispatch._remove_batch_dim(child, 1, 2, 1); child = None _remove_batch_dim_1: "f32[2, 4]" = torch._functorch.predispatch._remove_batch_dim(child_1, 1, 2, 0); child_1 = None _vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +======= + _remove_batch_dim: "f32[3, 2]" = torch._C._functorch._remove_batch_dim(child, 1, 2, 1); child = None + _remove_batch_dim_1: "f32[2, 4]" = torch._C._functorch._remove_batch_dim(child_1, 1, 2, 0); child_1 = None + + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return (_remove_batch_dim, _remove_batch_dim_1) """, ) @@ -6624,19 +7005,34 @@ class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[2, 4, 3]"): l_x_ = L_x_ +<<<<<<< HEAD lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None _vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None _add_batch_dim: "f32[4, 3]" = torch._functorch.predispatch._add_batch_dim(l_x_, 0, 1); l_x_ = None +======= + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None + + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None + + _add_batch_dim: "f32[4, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) child: "f32[3]" = _add_batch_dim.sum(0) child_1: "f32[4]" = _add_batch_dim.sum(1); _add_batch_dim = None +<<<<<<< HEAD _remove_batch_dim: "f32[3, 2]" = torch._functorch.predispatch._remove_batch_dim(child, 1, 2, 1); child = None _remove_batch_dim_1: "f32[2, 4]" = torch._functorch.predispatch._remove_batch_dim(child_1, 1, 2, 0); child_1 = None _vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +======= + _remove_batch_dim: "f32[3, 2]" = torch._C._functorch._remove_batch_dim(child, 1, 2, 1); child = None + _remove_batch_dim_1: "f32[2, 4]" = torch._C._functorch._remove_batch_dim(child_1, 1, 2, 0); child_1 = None + + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return (_remove_batch_dim, _remove_batch_dim_1) """, ) @@ -6834,7 +7230,11 @@ def _validate(self, fn, backend, *args, skip_check=False, fullgraph=True): for arg, cloned_arg in zip(args, cloned_args): self.assertEqual(arg.grad, cloned_arg.grad) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._functorch.config.patch(functionalize_rng_ops=True) def test_function(self): def gn(x, y): @@ -6853,7 +7253,11 @@ def fn(x, y): backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) self._validate(fn, backend, x, y) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._functorch.config.patch(functionalize_rng_ops=True) def test_function_with_kwargs(self): def gn(x, y): @@ -6876,7 +7280,11 @@ def fn(x, y): backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) self._validate(fn, backend, x, y) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._functorch.config.patch(functionalize_rng_ops=True) def test_dropout(self): def gn(x, y): @@ -6902,7 +7310,11 @@ def fn(x, y): fn, backend, x, y, skip_check=True ) # dropout decomp is known to diverge with eager +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._functorch.config.patch(functionalize_rng_ops=True) def test_dropout_inductor(self): def gn(x, y): @@ -6921,7 +7333,11 @@ def fn(x, y): fn, backend, x, y, skip_check=True ) # dropout decomp is known to diverge with eager +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._functorch.config.patch(functionalize_rng_ops=True) def test_fallback(self): def gn(x, y): @@ -6952,7 +7368,11 @@ def fn(x, y): self.assertEqual(cnt.op_count, 2) self.assertEqual(len(backend.graphs), 2) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._functorch.config.patch(functionalize_rng_ops=True) def test_module(self): class MockModule(torch.nn.Module): @@ -7095,6 +7515,7 @@ def test_non_aliasing_util(self): ): _assert_tensors_nonaliasing(a, a) +<<<<<<< HEAD def test_flop_counter_for_cond(self): from torch.utils.flop_counter import FlopCounterMode @@ -7192,6 +7613,8 @@ def false_branch(x): }, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) xfail_hops_compile = { # aot_eager @@ -7205,7 +7628,11 @@ def false_branch(x): class TestHigherOrderOpsOpInfo(torch._dynamo.test_case.TestCase): +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize("backend", ("aot_eager", "inductor")) @ops( list(filter(lambda op: op.name not in xfail_hops_compile, hop_db)), diff --git a/test/dynamo/test_hooks.py b/test/dynamo/test_hooks.py index 3f3a3bd7f6537..1f8e5bdedc83c 100644 --- a/test/dynamo/test_hooks.py +++ b/test/dynamo/test_hooks.py @@ -746,7 +746,11 @@ def test_fn(fn): if cnts: self.assertEqual(cnts.frame_count, 1) # These same exact assertions run on both eager and compiled +<<<<<<< HEAD # X goes to x*2 because of mul_ +======= + # X goes to x*2 becaue of mul_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(x, torch.tensor([0.5, 0.5, 0.5]) * 2) # This test proves grad aliasing works - self.assertEqual(x.grad, b * 5) diff --git a/test/dynamo/test_inline_and_install.py b/test/dynamo/test_inline_and_install.py index b38b96ccc3e9e..f3733423d7359 100644 --- a/test/dynamo/test_inline_and_install.py +++ b/test/dynamo/test_inline_and_install.py @@ -57,7 +57,11 @@ def make_dynamic_cls(cls): ) +<<<<<<< HEAD # These tests do string comparison on the graphs, and since buffers are now inlined, they +======= +# These tests do string comparisson on the graphs, and since buffers are now inlined, they +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # are named different, resulting in failure unittest.expectedFailure( InlineAndInstallExportTests.test_param_buffer_safe_from_mutation_simple_inline_and_install # noqa: F821 diff --git a/test/dynamo/test_logging.py b/test/dynamo/test_logging.py index 2a83b28b50a9c..c9ae174f58771 100644 --- a/test/dynamo/test_logging.py +++ b/test/dynamo/test_logging.py @@ -21,6 +21,7 @@ from torch.testing._internal.common_cuda import SM90OrLater from torch.testing._internal.common_utils import ( find_free_port, +<<<<<<< HEAD IS_WINDOWS, munge_exc, skipIfTorchDynamo, @@ -32,11 +33,19 @@ HAS_CUDA_AND_TRITON, HAS_XPU_AND_TRITON, ) +======= + munge_exc, + skipIfTorchDynamo, + xfailIfS390X, +) +from torch.testing._internal.inductor_utils import HAS_CUDA +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.logging_utils import ( LoggingTestCase, make_logging_test, make_settings_test, ) +<<<<<<< HEAD from torch.testing._internal.triton_utils import requires_cuda_and_triton @@ -44,14 +53,22 @@ HAS_CUDA_AND_TRITON or HAS_XPU_AND_TRITON, "requires cuda or xpu with triton" ) +======= + + +requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) requires_distributed = functools.partial( unittest.skipIf, not dist.is_available(), "requires distributed" ) +<<<<<<< HEAD device_type = ( acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu" ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def munge_shape_guards(s: str) -> str: SHAPE_GUARD_REGEX = ( @@ -86,7 +103,11 @@ def inductor_error_fn(a): def inductor_schedule_fn(a): +<<<<<<< HEAD output = a.add(torch.ones(1000, 1000, device=device_type)) +======= + output = a.add(torch.ones(1000, 1000, device="cuda")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return output @@ -123,6 +144,7 @@ class LoggingTests(LoggingTestCase): test_output_code = multi_record_test(3, output_code=True) test_aot_graphs = multi_record_test(3, aot_graphs=True) +<<<<<<< HEAD @requires_gpu @make_logging_test(schedule=True) def test_schedule(self, records): @@ -144,6 +166,29 @@ def test_fusion(self, records): def test_cudagraphs(self, records): fn_opt = torch.compile(mode="reduce-overhead")(inductor_schedule_fn) fn_opt(torch.ones(1000, 1000, device=device_type)) +======= + @requires_cuda + @make_logging_test(schedule=True) + def test_schedule(self, records): + fn_opt = torch.compile(inductor_schedule_fn, backend="inductor") + fn_opt(torch.ones(1000, 1000, device="cuda")) + self.assertGreater(len(records), 0) + self.assertLess(len(records), 5) + + @requires_cuda + @make_logging_test(fusion=True) + def test_fusion(self, records): + fn_opt = torch.compile(inductor_schedule_fn, backend="inductor") + fn_opt(torch.ones(1000, 1000, device="cuda")) + self.assertGreater(len(records), 0) + self.assertLess(len(records), 8) + + @requires_cuda + @make_logging_test(cudagraphs=True) + def test_cudagraphs(self, records): + fn_opt = torch.compile(mode="reduce-overhead")(inductor_schedule_fn) + fn_opt(torch.ones(1000, 1000, device="cuda")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertGreater(len(records), 0) self.assertLess(len(records), 8) @@ -252,7 +297,11 @@ def throw(x): exitstack.close() @requires_distributed() +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @make_logging_test(ddp_graphs=True) def test_ddp_graphs(self, records): class ToyModel(torch.nn.Module): @@ -530,7 +579,11 @@ def test_invalid_artifact_flag_error_msg(self): "import torch", env=env, ) +<<<<<<< HEAD lines = stderr.decode().split("\r\n" if IS_WINDOWS else "\n") +======= + lines = stderr.decode().split("\n") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This is a sanity assert that our error is not spammy. # As of this test creation this was 18. # See this issue for the purpose o this test: @@ -546,7 +599,10 @@ def test_invalid_artifact_flag_error_msg(self): self.assertEqual(lines[-4], "Valid settings:") @requires_distributed() +<<<<<<< HEAD @skipIfWindows(msg="TODO: (xuhancn), Can't reproduce locally") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_distributed_rank_logging(self): env = dict(os.environ) env["TORCH_LOGS"] = "dynamo" @@ -726,10 +782,17 @@ def f(x, y, z): self.assertExpectedInline( munge_shape_guards(record.getMessage()), """\ +<<<<<<< HEAD +- __SHAPE_GUARD__: L['x'].size()[0] == 2*L['y'].size()[0] # return x + torch.cat([y, z]) # #:# in # #:# in # +- __SHAPE_GUARD__: L['z'].size()[0] == L['y'].size()[0] # duck sizing added this equality because these variables had the same size 3 (to avoid this specialization, set torch.fx.experimental._config.use_duck_shape = False) +- __SHAPE_GUARD__: ((2*L['y'].size()[0]) % 3) == 0 # if x.size(0) % 3 == 0: # #:# in # #:# in # +- __SHAPE_GUARD__: 2 <= L['y'].size()[0] # return x + torch.cat([y, z]) # #:# in # (user code shown is first use of this value--the guard itself is not due user code but due to 0/1 specialization in the framework; to avoid specialization try torch._dynamo.mark_unbacked(tensor, dim))""", # noqa: B950 +======= ++- __SHAPE_GUARD__: L['x'].size()[0] == 2*L['z'].size()[0] # return x + torch.cat([y, z]) # #:# in # #:# in # ++- __SHAPE_GUARD__: L['y'].size()[0] == L['z'].size()[0] # duck sizing added this equality because these variables had the same size 3 (to avoid this specialization, set torch.fx.experimental._config.use_duck_shape = False) ++- __SHAPE_GUARD__: ((2*L['z'].size()[0]) % 3) == 0 # if x.size(0) % 3 == 0: # #:# in # #:# in # ++- __SHAPE_GUARD__: 2 <= L['z'].size()[0] # return x + torch.cat([y, z]) # #:# in # (user code shown is first use of this value--the guard itself is not due user code but due to 0/1 specialization in the framework; to avoid specialization try torch._dynamo.mark_unbacked(tensor, dim))""", # noqa: B950 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @make_logging_test(guards=True) @@ -779,11 +842,18 @@ def fn(x): self.assertGreater(len(records), 0) self.assertLess(len(records), 4) +<<<<<<< HEAD @xfailIf(TEST_XPU) # https://github.com/pytorch/pytorch/issues/157778 @make_logging_test(perf_hints=True) @requires_gpu def test_optimizer_non_static_param(self, records): params = [torch.randn(10, 10, device=device_type) for _ in range(2)] +======= + @make_logging_test(perf_hints=True) + @requires_cuda + def test_optimizer_non_static_param(self, records): + params = [torch.randn(10, 10, device="cuda") for _ in range(2)] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for param in params: param.grad = torch.zeros_like(param) opt = torch.optim.Adam(params) @@ -793,7 +863,11 @@ def test_optimizer_non_static_param(self, records): self.assertLess(len(records), 3) @make_logging_test(autotuning=True) +<<<<<<< HEAD @requires_gpu +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(not SM90OrLater, "requires H100+ GPU") def test_autotuning(self, records): with torch._inductor.utils.fresh_cache(): @@ -802,10 +876,14 @@ def f(a, b): return torch.mm(a, b) f = torch.compile(f, mode="max-autotune-no-cudagraphs") +<<<<<<< HEAD f( torch.randn(10, 10, device=device_type), torch.randn(10, 10, device=device_type), ) +======= + f(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertGreater(len(records), 0) self.assertLess(len(records), 40) @@ -855,6 +933,11 @@ def fn(a): len([r for r in records if "return a + 1" in r.getMessage()]), 0 ) +<<<<<<< HEAD +======= + # there are some additional deprecation warnings in stderr, probably due to newer dependencies used on s390x + @xfailIfS390X +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_logs_out(self): import tempfile @@ -942,7 +1025,10 @@ def bar(): "aot_graphs", "aot_graphs_effects", "pre_grad_graphs", +<<<<<<< HEAD "joint_graph_passes", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "post_grad_graphs", "inductor_metrics", "ir_pre_fusion", diff --git a/test/dynamo/test_metrics_context.py b/test/dynamo/test_metrics_context.py index 3a8657003cd19..52fb9c9c3b00b 100644 --- a/test/dynamo/test_metrics_context.py +++ b/test/dynamo/test_metrics_context.py @@ -64,7 +64,11 @@ def test_set_disallow_overwrite(self): def test_update_disallow_overwrite(self): """ +<<<<<<< HEAD Validate update won't overwrite. +======= + Validate update won't overwite. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ with MetricsContext(self._on_exit) as context: context.update({"m1": 1, "m2": 2}) @@ -73,7 +77,11 @@ def test_update_disallow_overwrite(self): def test_update_allow_overwrite(self): """ +<<<<<<< HEAD Validate update will overwrite when given param. +======= + Validate update will overwite when given param. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ with MetricsContext(self._on_exit) as context: context.update({"m1": 1, "m2": 2}) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 0a3891e2dc146..a38493ec5d79b 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -1,7 +1,10 @@ # Owner(s): ["module: dynamo"] # ruff: noqa: F841 import abc +<<<<<<< HEAD import builtins +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import collections import collections.abc import copy @@ -17,13 +20,19 @@ import math import operator import os +<<<<<<< HEAD import pickle +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import random import sys import tempfile import threading import traceback +<<<<<<< HEAD import types +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import typing import unittest import unittest.mock as mock @@ -57,7 +66,10 @@ ) from torch._dynamo.utils import call_size, counters, ifdynstaticdefault from torch._dynamo.variables import builder +<<<<<<< HEAD from torch._inductor.codecache import WritableTempFile +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.utils import fresh_cache, run_and_get_code from torch.ao.quantization import MinMaxObserver from torch.ao.quantization.fake_quantize import FakeQuantize @@ -87,23 +99,32 @@ ) from torch.testing._internal.common_utils import ( freeze_rng_state, +<<<<<<< HEAD instantiate_parametrized_tests, IS_FBCODE, parametrize, +======= + IS_FBCODE, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) scoped_load_inline, set_default_dtype, skipIfHpu, skipIfNNModuleInlined, skipIfWindows, +<<<<<<< HEAD subtest, TEST_HPU, TEST_XPU, +======= + TEST_HPU, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) wrapDeterministicFlagAPITest, ) from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.logging_utils import logs_to_string +<<<<<<< HEAD pytree_modules = { "python": python_pytree, } @@ -119,6 +140,13 @@ [subtest(module, name=name) for name, module in pytree_modules.items()], ) +======= +if python_pytree._cxx_pytree_dynamo_traceable: + import torch.utils._cxx_pytree as cxx_pytree +else: + cxx_pytree = None + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MyTuple = collections.namedtuple("MyTuple", ["a", "b", "ab"]) T = typing.TypeVar("T") @@ -1719,17 +1747,27 @@ def fn(packed): if hasattr(packed, "b"): b = packed.b + 1 c = packed[2] +<<<<<<< HEAD d = len(packed._fields) return a + b + c + d +======= + return a + b + c +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) v1 = torch.Tensor([1]) v2 = torch.Tensor([2]) v3 = torch.Tensor([3]) cnts = torch._dynamo.testing.CompileCounter() opt_fn = torch.compile(fn, backend=cnts) +<<<<<<< HEAD self.assertEqual(opt_fn(MyTuple(v1, v2, v3))[0], 10) self.assertEqual(cnts.frame_count, 1) self.assertEqual(cnts.op_count, 4) +======= + self.assertEqual(opt_fn(MyTuple(v1, v2, v3))[0], 7) + self.assertEqual(cnts.frame_count, 1) + self.assertEqual(cnts.op_count, 3) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_namedtuple3(self): def fn(x, packed): @@ -1976,6 +2014,7 @@ def fn(a, b): self.assertEqual(exp, act) +<<<<<<< HEAD def test_class_binop(self): class Foo: def __init__(self, x): @@ -2001,6 +2040,8 @@ def fn(a, b): opt_fn = torch.compile(fn, backend=cnts, fullgraph=True) self.assertRaises(torch._dynamo.exc.Unsupported, opt_fn, a, b) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_user_getattr1(self): class MyConfig(dict): def __getattr__(self, name): @@ -3318,7 +3359,11 @@ def fn(m, x): def test_global_state_guard_serialization(self): GlobalStateGuard = torch._C._dynamo.guards.GlobalStateGuard guards = GlobalStateGuard() +<<<<<<< HEAD serialized_guards = guards.__getstate__() +======= + serialized_guards = guards.dump() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) json_guards = json.loads(serialized_guards) samples = [] @@ -3340,17 +3385,28 @@ def test_global_state_guard_serialization(self): samples.append(new_dict) for sample in samples: +<<<<<<< HEAD guards.__setstate__(json.dumps(sample)) self.assertFalse(guards.check()) guards.__setstate__(json.dumps(json_guards)) +======= + guards.load(json.dumps(sample)) + self.assertFalse(guards.check()) + + guards.load(json.dumps(json_guards)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue(guards.check()) # Test on autocast states. def _test_autocast(dtype): with torch.autocast("cpu", dtype): guards = GlobalStateGuard() +<<<<<<< HEAD serialized_guards = guards.__getstate__() +======= + serialized_guards = guards.dump() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) json_guards = json.loads(serialized_guards) for i, enabled in enumerate(json_guards["autocast_state"]["enabled"]): @@ -3359,7 +3415,11 @@ def _test_autocast(dtype): type(json_guards["autocast_state"]["dtype"][i]), int ) json_guards["autocast_state"]["dtype"][i] += 1 +<<<<<<< HEAD guards.__setstate__(json.dumps(json_guards)) +======= + guards.load(json.dumps(json_guards)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertFalse(guards.check()) _test_autocast(torch.float16) @@ -4095,7 +4155,11 @@ def test_write_to_cells_with_name_shadowing(self): y = x def make_x_get_set(): +<<<<<<< HEAD # NOTE: this `x` is a different cell object than the outer `x`. +======= + # NOTE: this `x` is a different cell object than the outter `x`. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x = y def set_x(v): @@ -4887,7 +4951,11 @@ def fn(x, y): self.assertEqual(cnts.frame_count, 2) def test_id_guarded_object(self): +<<<<<<< HEAD class UserDefinedObject: +======= + class UDO: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch.compile(backend="eager") def call(self, x, ref_id): self_id = id(self) @@ -4900,11 +4968,19 @@ def call(self, x, ref_id): # Make sure we do recompile when id(self) is executed on # different self objects. x = torch.ones(2) +<<<<<<< HEAD obj1 = UserDefinedObject() obj1_id = id(obj1) self.assertEqual(obj1.call(x, obj1_id), torch.ones(2)) obj2 = UserDefinedObject() +======= + obj1 = UDO() + obj1_id = id(obj1) + self.assertEqual(obj1.call(x, obj1_id), torch.ones(2)) + + obj2 = UDO() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # if we do not install ID_MATCH: ___check_obj_id(L['self'], xxx) this fails. self.assertEqual(obj2.call(x, obj1_id), torch.zeros(2)) @@ -5183,9 +5259,12 @@ def fn(sample): self.assertTrue(same(ref, res)) +<<<<<<< HEAD @skipIfWindows( msg="TODO(xuhancn): confirm, AssertionError: tensor([0.0290, 0.4019, 0.2598, 0.3666]) is not None" ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_release_input_memory(self): x = torch.rand([4]) x_ref = weakref.ref(x) @@ -5201,9 +5280,12 @@ def foo(x): del x self.assertIs(x_ref(), None) +<<<<<<< HEAD @skipIfWindows( msg="TODO: (xuhancn) conform, AssertionError: Linear(in_features=10, out_features=10, bias=True) is not None" ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_release_module_memory(self): mod = torch.nn.Linear(10, 10) x = torch.rand([10, 10]) @@ -5235,7 +5317,10 @@ def foo(mod, x): self.assertIsNone(mod_ref(), None) self.assertIsNone(mod_weight_ref(), None) +<<<<<<< HEAD @skipIfWindows(msg="TODO: (xuhancn) conform, AssertionError: False is not true") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_release_scope_memory(self): def inner(y): y @@ -6405,6 +6490,7 @@ def func(x, y): self.assertTrue(same(ref, res)) self.assertTrue(same(x, x1)) +<<<<<<< HEAD def test_inference_mode_param(self): def fn(x): p = torch.nn.Parameter(x, requires_grad=False) @@ -6418,6 +6504,8 @@ def fn(x): res = opt_fn(x) self.assertEqual(ref, res) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_if_cond_nn_mod1(self): class MockModule(torch.nn.Module): def __init__(self, output_relu=True): @@ -6886,7 +6974,11 @@ def fn(x): # assign fstring to a variable causes the fstring to be used, # which realizes the variable tracker. f_str = f"{x.shape[0]}" +<<<<<<< HEAD return x.sin(), f_str +======= + return x.sin() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) guard_failure = None @@ -6968,7 +7060,11 @@ def guard_failures(failure): self.assertTrue(guard_failure is not None) self.assertIn("""tensor 'rank' size mismatch at index 0""", guard_failure[0]) +<<<<<<< HEAD @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "Test requires CUDA or XPU.") +======= + @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_symint_as_device_kwarg_non_strict_export(self): class Mod(torch.nn.Module): def forward(self, x): @@ -7272,6 +7368,7 @@ def injected(x): with torch.compiler.set_stance("fail_on_recompile"): self.assertEqual(compiled_fn(*args), injected(*args)) +<<<<<<< HEAD def test_fail_on_recompile_error_message(self): from torch._C._dynamo.eval_frame import ( _load_precompile_entry, @@ -7307,6 +7404,8 @@ def injected_bool(x: bool): finally: _reset_precompile_entries(fn.__code__) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_shape_and_tuple_equality(self): def fn(x, y, t): z = x * y @@ -8421,6 +8520,7 @@ def write_state(state): def fn(x): return x + 1 +<<<<<<< HEAD initial_state = read_state() y = torch.randn(10) try: @@ -8439,6 +8539,45 @@ def fn(x): assert cnt == len(initial_state) finally: write_state(initial_state) +======= + import contextlib + + @contextlib.contextmanager + def _hip_allow_tf32(): + # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new + # and only for MI300+ + hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None) + os.environ["HIPBLASLT_ALLOW_TF32"] = "1" + + try: + yield + finally: + if hip_allow_tf32 is not None: + os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32 + else: + del os.environ["HIPBLASLT_ALLOW_TF32"] + + tf32_ctx = _hip_allow_tf32 if torch.version.hip else contextlib.nullcontext + with tf32_ctx(): + initial_state = read_state() + y = torch.randn(10) + try: + for round in range(3): + for i in range(len(initial_state)): + new_state = [False] * len(initial_state) + new_state[i] = True + write_state(new_state) + assert read_state() == new_state + last_state.clear() + fn(y) + assert last_state == new_state + if round == 0: + assert cnt == i + 1 + else: + assert cnt == len(initial_state) + finally: + write_state(initial_state) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_grad_state_mutated(self): prior = torch.is_grad_enabled() @@ -8556,6 +8695,7 @@ def global_context_capture_fn(frame_summary): self.assertEqual(seen_frames[0].name, "fn") self.assertEqual(seen_frames[0].line, "r, r2 = uwu_inline_me(x, y, z)") +<<<<<<< HEAD def test_fullgraph_capture(self): from torch._dynamo.convert_frame import ( FrameInfo, @@ -8604,6 +8744,8 @@ def foo(x): )(x), ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_torch_guards_stack_frame_register_inlining_deep(self): x = torch.tensor([0.5, 0.5]) y = torch.tensor([0.75, 0.75, 0.75, 0.75]) @@ -8641,6 +8783,7 @@ def global_context_capture_fn(frame_summary): self.assertEqual(seen_frames[1].name, "uwu_inline_me") self.assertEqual(seen_frames[2].line, "r2 = uwu_inline_me_deep(y, z)") +<<<<<<< HEAD def test_recompile_on_disable_1(self): # fix https://github.com/pytorch/pytorch/issues/157399 @torch.compile(backend="eager") @@ -8677,6 +8820,8 @@ def fn1(y): # there will be a resume function here return f(x) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_error_on_recompile(self): @torch.compile(backend="eager") def fn(a, b): @@ -8826,7 +8971,11 @@ def test_guards_cse_pass_single(self): ), testcase(expr="f(m.n[0], '1').x.y.z", expected="f(_var3, '1').x.y.z"), testcase(expr="f(m.n[0], '2').x.y.z", expected="f(_var3, '2').x.y.z"), +<<<<<<< HEAD # The whole expression gets CSE-d, as well as all of its sub-expressions. +======= + # The whole expressiong gets CSE-d, as well as all of its sub-expressions. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) testcase( expr="self.g(a, b).k", preface=["_var4 = self.g", "_var5 = _var4(a, b)", "_var6 = _var5.k"], @@ -9101,6 +9250,74 @@ def fn(): opt = torch.compile(fn, backend="eager") opt() +<<<<<<< HEAD +======= + def test_tracing_py_tree(self): + def fn(xs): + flat_xs, spec = python_pytree.tree_flatten(xs) + res = [x.clone() for x in flat_xs] + return python_pytree.tree_unflatten(res, spec) + + xs = [torch.tensor(i) for i in range(3)] + + counter = CompileCounter() + torch.compile(fn, backend=counter, fullgraph=True)(xs) + self.assertEqual(counter.frame_count, 1) + self.assertEqual(counter.op_count, 3) + + def test_tracing_nested_py_tree(self): + def fn(xs): + flat_xs, spec = python_pytree.tree_flatten(xs) + res = [x.clone() for x in flat_xs] + return python_pytree.tree_unflatten(res, spec) + + xs = [torch.tensor(i) for i in range(3)] + xsl = [xs, xs, xs, xs] + + counter = CompileCounter() + comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl) + real_out = fn(xsl) + self.assertEqual(comp_out, real_out) + self.assertEqual(counter.frame_count, 1) + self.assertEqual(counter.op_count, 12) + + def test_tracing_nested_py_tree_tuples(self): + def fn(xs): + flat_xs, spec = python_pytree.tree_flatten(xs) + res = [x.clone() for x in flat_xs] + return python_pytree.tree_unflatten(res, spec) + + xs = [torch.tensor(i) for i in range(3)] + xsl = (xs, xs, xs, xs) + + counter = CompileCounter() + comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl) + real_out = fn(xsl) + self.assertEqual(comp_out, real_out) + self.assertEqual(counter.frame_count, 1) + self.assertEqual(counter.op_count, 12) + + def test_tracing_nested_py_tree_dicts(self): + def fn(xs): + flat_xs, spec = python_pytree.tree_flatten(xs) + res = [x.clone() for x in flat_xs] + return python_pytree.tree_unflatten(res, spec) + + xs = [torch.tensor(i) for i in range(3)] + xsl = { + "a": xs, + "b": xs, + "c": xs, + } + + counter = CompileCounter() + comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl) + real_out = fn(xsl) + self.assertEqual(comp_out, real_out) + self.assertEqual(counter.frame_count, 1) + self.assertEqual(counter.op_count, 9) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_dynamic_one_hot(self): def fn(x): x = x + 1 @@ -9117,6 +9334,31 @@ def fn(x): self.assertEqual(counter.frame_count, 2) self.assertEqual(counter.op_count, 2) +<<<<<<< HEAD +======= + def test_tracing_nested_py_tree_mixed_all(self): + def fn(xs): + flat_xs, spec = python_pytree.tree_flatten(xs) + res = [x.clone() for x in flat_xs] + return python_pytree.tree_unflatten(res, spec) + + xs = [torch.tensor(i) for i in range(3)] + xsa = (xs, xs) + xsb = {"aa": xsa, "ab": xs} + xsl = { + "a": xs, + "b": xsa, + "c": xsb, + } + + counter = CompileCounter() + comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl) + real_out = fn(xsl) + self.assertEqual(comp_out, real_out) + self.assertEqual(counter.frame_count, 1) + self.assertEqual(counter.op_count, 18) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_any_all_symnode(self): cnt = CompileCounter() @@ -9143,6 +9385,49 @@ def fn(x): self.assertEqual(fn(y3), y3 - 3) self.assertEqual(cnt.frame_count, 2) +<<<<<<< HEAD +======= + def test_tracing_py_tree_tensor_subclass(self): + from torch.testing._internal.two_tensor import TwoTensor + from torch.utils.checkpoint import checkpoint + + def fn(xs): + nested_xs = [[xs]] + flat_xs, spec = python_pytree.tree_flatten(xs) + return flat_xs[0].clone() + + # use checkpoint to trigger a "sourceless" tensor subclass + def checkpoint_fn(xs): + return checkpoint(fn, xs, use_reentrant=True) + + xs = TwoTensor(torch.ones(2, 2), torch.ones(2, 2)) + + counter = CompileCounter() + torch.compile(checkpoint_fn, backend=counter, fullgraph=True)(xs) + self.assertEqual(counter.frame_count, 1) + self.assertEqual(counter.op_count, 2) + + def test_tracing_tree_map_only(self): + def fn(xs): + def mapper(x): + return x.clone() + + y = python_pytree.tree_map_only(torch.Tensor, mapper, xs) + return y + + xs = [torch.tensor(i) for i in range(3)] + ["hi"] + xsa = (xs, xs) + xsb = {"aa": xsa, "ab": xs} + + counter = CompileCounter() + comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsb) + real_out = fn(xsb) + + self.assertEqual(comp_out, real_out) + self.assertEqual(counter.frame_count, 1) + self.assertEqual(counter.op_count, 9) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._dynamo.config.patch( capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True ) @@ -9322,6 +9607,34 @@ def foo(x, y): self.assertEqual(counter.frame_count, 1) self.assertEqual(result, eager_result) +<<<<<<< HEAD +======= + def test_input_set_graph_break(self): + def foo(x): + return x.pop() * x.pop() + + x = torch.randn(10, 10) + y = torch.randn(10, 10) + + counter = CompileCounter() + + inp = {x, x, x, x, y, y} + foo = torch.compile(foo, backend=counter, fullgraph=True) + + # There's a lot of stuff about sets that cannot work without a good deal of exertion on our part. + # Specifically, getting a set as input won't ever work with how GetItemSource works (Can't arbitrary access set contents) + # and so the guard story for the objects passed into input just isn't there atm. + with self.assertRaisesRegex( + torch._dynamo.exc.Unsupported, + "Unsupported method call", + ): + foo(inp) + + foo = torch.compile(foo, backend=counter, fullgraph=False) + foo(inp) + self.assertEqual(counter.frame_count, 1) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_reconstruct_set_across_graph_break(self): def foo(x, y): setty = set() @@ -10513,6 +10826,7 @@ def fn(x, y): self.assertEqual(actual, expected) +<<<<<<< HEAD def test_frozen_dataclass_attr_access(self): @dataclasses.dataclass(frozen=True) class TestDataClass: @@ -10584,6 +10898,140 @@ def fn(x, y): actual = fn_opt(*inps) expected = fn(*inps) self.assertEqual(actual, expected) +======= + def test_pytree_tree_leaves(self): + implemtations = [("python", python_pytree)] + if cxx_pytree is not None: + implemtations.append(("cxx", cxx_pytree)) + + for name, module in implemtations: + with self.subTest(f"pytree implement: {name}"): + + def fn(x): + tree = { + "a": [x, x - 1], + "b": x + 2, + "c": ( + x, + 3.0, + collections.deque([0.0, -x, 1, 2], maxlen=3), + ), + "d": collections.OrderedDict( + { + "e": torch.return_types.qr((2 * x, None)), + "f": MyTuple(x, x + 1, torch.zeros(4, 3)), + }, + ), + } + leaves = module.tree_leaves(tree) + return leaves + + x = torch.randn(3, 2) + expected = fn(x) + fn_opt = torch.compile(fullgraph=True)(fn) + actual = fn_opt(x) + + self.assertEqual(actual, expected) + + def test_pytree_tree_flatten_unflatten(self): + implemtations = [("python", python_pytree)] + if cxx_pytree is not None: + implemtations.append(("cxx", cxx_pytree)) + + for name, module in implemtations: + with self.subTest(f"pytree implement: {name}"): + + def fn(x, y): + tree = { + "a": [x, x - 1], + "b": x + 2, + "c": ( + x, + 3.0, + collections.deque([0.0, -x, 1, 2], maxlen=3), + ), + "d": collections.OrderedDict( + { + "e": torch.return_types.qr((2 * x, None)), + "f": MyTuple(x, x + 1, torch.zeros(4, 3)), + }, + ), + } + leaves, treespec = module.tree_flatten(tree) + new_leaves = [ + x - 1, + y, + x * y, + 3.0, + y - 2, + 1, + torch.zeros(2, 2), + 2 * y, + -y, + x + y, + x - y, + torch.ones(3, 2), + 1, + ] + new_tree = module.tree_unflatten(new_leaves, treespec) + return leaves, new_tree + + x = torch.randn(3, 2) + y = torch.randn(3, 2) + expected = fn(x, y) + fn_opt = torch.compile(fullgraph=True)(fn) + actual = fn_opt(x, y) + + self.assertEqual(actual, expected) + + def test_pytree_tree_map(self): + implemtations = [("python", python_pytree)] + if cxx_pytree is not None: + implemtations.append(("cxx", cxx_pytree)) + + for name, module in implemtations: + with self.subTest(f"pytree implement: {name}"): + + def fn(x, y): + tree1 = { + "a": [x, x - 1], + "b": x + 2, + "c": ( + x, + 3.0, + collections.deque([0.0, -x, 1, 2], maxlen=3), + ), + "d": collections.OrderedDict( + { + "e": torch.return_types.qr((2 * x, None)), + "f": MyTuple(x, x + 1, torch.zeros(4, 3)), + }, + ), + } + tree2 = collections.OrderedDict( + [ + ("c", (y, 3.0, collections.deque([1, -y, 10.0]))), + ("a", [y, y + 1]), + ("b", y + 2), + ( + "d", + { + "f": MyTuple(torch.ones(4, 3), -y, y + 1), + "e": torch.return_types.qr((2 * y, None)), + }, + ), + ], + ) + return module.tree_map(lambda u, v: (u, v), tree1, tree2) + + x = torch.randn(3, 2) + y = torch.randn(3, 2) + expected = fn(x, y) + fn_opt = torch.compile(fullgraph=True)(fn) + actual = fn_opt(x, y) + + self.assertEqual(actual, expected) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_shape_env_no_recording(self): main = ShapeEnv(should_record_events=False) @@ -10637,8 +11085,13 @@ def test_shape_env_equal_constructor(self): ShapeEnv not equal: field values don't match: ==> settings: values don't match. +<<<<<<< HEAD > Left: ShapeEnvSettings(allow_scalar_outputs=False, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, trace_asserts=False) > Right: ShapeEnvSettings(allow_scalar_outputs=True, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, trace_asserts=False) +======= + > Left: ShapeEnvSettings(allow_scalar_outputs=False, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, allow_complex_guards_as_runtime_asserts=False, trace_asserts=False) + > Right: ShapeEnvSettings(allow_scalar_outputs=True, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, allow_complex_guards_as_runtime_asserts=False, trace_asserts=False) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """, ) self._replay_and_check(main) @@ -11104,7 +11557,11 @@ def EEE(): def fn(): return 3 """ +<<<<<<< HEAD with WritableTempFile(mode="w") as f: +======= + with tempfile.NamedTemporaryFile(mode="w") as f: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f.write(src) f.flush() from torch._dynamo.funcname_cache import get_funcname @@ -11497,7 +11954,10 @@ def fn(x, const): self.assertIs(c1[1], c2[0]) @torch._dynamo.config.patch(inline_inbuilt_nn_modules=False) +<<<<<<< HEAD @skipIfWindows(msg="TODO: (xuhancn) conform, AssertionError: False is not true") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_dynamo_cache_invalidate(self): DeletedGuardManagerWrapper = torch._dynamo.guards.DeletedGuardManagerWrapper @@ -11692,7 +12152,11 @@ def fn(x, y): # Ensure that the generated graph returns only one output. We want the # add_ on the grad to be part of the graph itself, so that inductor can +<<<<<<< HEAD # theoretically move the add_ and resulting copy_ nodes at the right +======= + # theoretically move the add_ and resutling copy_ nodes at the right +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # place to free memory. self.assertEqual(len(list(cnt.graphs[0].graph.nodes)[-1].all_input_nodes), 1) self.assertEqual(z, ref_y) @@ -11805,6 +12269,7 @@ def fn(x, d): with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): fn(torch.randn(4), d) +<<<<<<< HEAD def test_hash_hop(self): associative_scan = importlib.import_module( "torch._higher_order_ops.associative_scan" @@ -11818,6 +12283,8 @@ def fn(y, s): fn(torch.ones(2, 2, device="cpu"), associative_scan.AssociativeScanOp()) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_iter_type(self): @torch.compile(fullgraph=True) def fn(y): @@ -12267,7 +12734,11 @@ def __init__(self, x): self.ne_called = False def __ne__(self, other): +<<<<<<< HEAD # ne_called attr is later checked to ensure that overridden +======= + # ne_called attr is later checked to ensure that overrideen +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # `__ne__` is traced self.ne_called = True return not self.__eq__(other) @@ -12588,6 +13059,7 @@ def f(x): res = opt_f(x) self.assertEqual(ref, res) +<<<<<<< HEAD def test_builtin_complex(self): def f(x): c = ( @@ -12870,6 +13342,8 @@ def mapper(x): self.assertEqual(counter.frame_count, 1) self.assertEqual(counter.op_count, 9) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestTracer(JitTestCase): def test_jit_save(self): @@ -13038,7 +13512,11 @@ def forward(self, query, key, value): def test_torch_device_is_available(self, device): def fn(x): +<<<<<<< HEAD if torch.accelerator.is_available(): +======= + if TEST_HPU or TEST_CUDA: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return x + 1 else: return x - 1 @@ -13141,6 +13619,7 @@ def f(rank): def test_cuda_set_device(self, device): def fn(): a = torch.ones(2, device=device) +<<<<<<< HEAD torch.get_device_module(device).set_device(1) return a + 1 @@ -13158,6 +13637,29 @@ def test_torch_device_python_type(self, device): ("cpu", "cpu", None), (device, device_type, 0), ]: +======= + torch.cuda.set_device(1) + return a + 1 + + with torch.cuda.device(0): + counter = CompileCounter() + opt_fn = torch.compile(fn, backend=counter) + res = opt_fn() + self.assertEqual(res.device.type, "cuda") + self.assertEqual(res.device.index, 0) + self.assertEqual(counter.frame_count, 2) + + def test_torch_device_python_type(self): + for device, device_type, index in [ + ("cpu", "cpu", None), + ("cuda:0", "cuda", 0), + ("hpu:0", "hpu", 0), + ]: + if (device == "cuda:0" and not TEST_CUDA) or ( + device == "hpu:0" and not TEST_HPU + ): + continue +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def fn(target): target_device = target.device @@ -13251,6 +13753,7 @@ def forward(self, input): # RuntimeError: value cannot be converted to type at::Half without overflow +<<<<<<< HEAD instantiate_parametrized_tests(MiscTestsPyTree) devices = ("cuda", "hpu", "xpu") @@ -13259,6 +13762,10 @@ def forward(self, input): ) +======= +devices = ("cuda", "hpu") +instantiate_device_type_tests(MiscTestsDevice, globals(), only_for=devices) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_model_output.py b/test/dynamo/test_model_output.py index d2833e1a7195a..db4dd10566482 100644 --- a/test/dynamo/test_model_output.py +++ b/test/dynamo/test_model_output.py @@ -7,7 +7,11 @@ import torch._dynamo.testing from torch._dynamo.testing import same from torch.testing._internal.common_device_type import instantiate_device_type_tests +<<<<<<< HEAD from torch.testing._internal.common_utils import TestCase +======= +from torch.testing._internal.common_utils import TEST_HPU, TestCase +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: @@ -359,11 +363,19 @@ def forward( ) +<<<<<<< HEAD devices = ["cpu", "cuda", "xpu", "hpu"] instantiate_device_type_tests( TestModelOutputBert, globals(), only_for=devices, allow_xpu=True ) +======= +devices = ["cpu", "cuda"] +if TEST_HPU: + devices.append("hpu") + +instantiate_device_type_tests(TestModelOutputBert, globals(), only_for=devices) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index 818e5a85aa26d..f371de92d468a 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -11,20 +11,28 @@ _pop_torch_function_stack, _push_on_torch_function_stack, ) +<<<<<<< HEAD from torch._dynamo.utils import counters from torch.overrides import _get_current_function_mode_stack, BaseTorchFunctionMode from torch.testing._internal.common_utils import skipIfXpu from torch.testing._internal.inductor_utils import GPU_TYPE from torch.testing._internal.triton_utils import requires_gpu +======= +from torch.overrides import _get_current_function_mode_stack, BaseTorchFunctionMode +from torch.testing._internal.triton_utils import requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.utils._device import DeviceContext from torch.utils._python_dispatch import TorchDispatchMode +<<<<<<< HEAD device_type = ( acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu" ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestMode(BaseTorchFunctionMode): def __torch_function__(self, func, types, args, kwargs=None): if not kwargs: @@ -36,6 +44,7 @@ def __torch_function__(self, func, types, args, kwargs=None): return super().__torch_function__(func, types, args, kwargs) +<<<<<<< HEAD class HopDetectionError(Exception): pass @@ -53,6 +62,8 @@ def __torch_function__(self, func, types, args, kwargs=None): return super().__torch_function__(func, types, args, kwargs) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TorchDispatchModeTests(torch._dynamo.test_case.TestCase): @classmethod def setUpClass(cls): @@ -62,6 +73,7 @@ def setUpClass(cls): def tearDownClass(cls): super().tearDownClass() +<<<<<<< HEAD def test_torch_dispatch_ignore_compile_internals(self): counters.clear() from torch.utils._python_dispatch import TorchDispatchMode @@ -110,6 +122,8 @@ def g(x): self.assertEqual(counters["frames"]["total"], 1) self.assertEqual(counters["frames"]["ok"], 1) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_skip_torch_dispatch_modes(self): class RewriteAddToMul(TorchDispatchMode): def __torch_dispatch__(self, func, types, args=(), kwargs=None): @@ -300,7 +314,11 @@ def fn(x): self.assertRaisesRegex( torch._dynamo.exc.Unsupported, +<<<<<<< HEAD "Attempted to pop from empty torch function mode stack", +======= + "Popping from an empty torch function mode stack", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) lambda: fn(torch.ones(2, 2)), ) @@ -572,6 +590,7 @@ def fn(x, y): # Needs larger cache size since we recompile for each op @patch.object(torch._dynamo.config, "recompile_limit", 48) def test_builtin_equivalent_funcs(self): +<<<<<<< HEAD from torch._dynamo.variables.builtin import ( BUILTIN_TO_TENSOR_FN_MAP, BUILTIN_TO_TENSOR_RFN_MAP, @@ -579,6 +598,13 @@ def test_builtin_equivalent_funcs(self): from torch._dynamo.variables.torch_function import ( bin_int_ops, bin_ops, +======= + from torch._dynamo.variables.torch_function import ( + bin_int_ops, + bin_ops, + BUILTIN_TO_TENSOR_FN_MAP, + BUILTIN_TO_TENSOR_RFN_MAP, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tensor_and_int_ops, un_int_ops, un_ops, @@ -688,12 +714,20 @@ def func(a): func(torch.randn(3)) +<<<<<<< HEAD @requires_gpu +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_flex_attention(self): import torch from torch.nn.attention.flex_attention import create_block_mask, flex_attention +<<<<<<< HEAD torch.set_default_device(device_type) +======= + torch.set_default_device("cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) flex_attention = torch.compile(flex_attention, dynamic=False) @@ -703,9 +737,13 @@ def prefix_lm(b, h, q, kv): return prefix_lengths[b] >= kv # This runs in fullgraph already +<<<<<<< HEAD create_block_mask( prefix_lm, 8, None, 512, 512, _compile=True, device=device_type ) +======= + create_block_mask(prefix_lm, 8, None, 512, 512, _compile=True) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_register_hook(self): import functools @@ -728,6 +766,7 @@ def forward(self, x): with torch.device("cpu"): torch.compile(mod, fullgraph=True)(x) +<<<<<<< HEAD @requires_gpu @skipIfXpu(msg="XPU does not support flex attention") def test_hop(self): @@ -773,6 +812,8 @@ def test_hop_eager(self): torch.ones(2, 2, 2, 2), ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index 7cac7eca72394..46da14eb679c5 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -699,6 +699,7 @@ def forward(self, x, y): return self.layer(x, y=y) +<<<<<<< HEAD class LazyModuleBadInferParams(LazyModuleMixin, torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -710,6 +711,8 @@ def forward(self, x, y): return self.layer(x, y=y) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class LazyParentModule(LazyModuleMixin, torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1665,6 +1668,7 @@ def test_lazy_module_kwargs(self): exp_res = m(x, y) self.assertTrue(torch.allclose(exp_res, opt_m(x, y))) +<<<<<<< HEAD def test_lazy_module_bad_params(self): m = LazyModuleBadInferParams() x = [torch.rand([5, 5])] * 3 @@ -1691,6 +1695,8 @@ def m(x, y): with self.assertRaises(AttributeError): exp_res = opt_m(x, y) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # RuntimeError: SymIntArrayRef expected to contain only concrete integers @expectedFailureDynamic def test_lazy_module_speculation_log_divergence(self): @@ -2024,7 +2030,11 @@ def forward(self, x): # Check order of _modules def fn(x): for idx, p in enumerate(mod.modules()): +<<<<<<< HEAD # Something silly to force dependency on the order +======= + # Something silly to force depedency on the order +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x += coeffs_for_mod[p] * coeffs[idx] for idx, p in enumerate(mod.named_modules()): x += coeffs_for_mod[p[1]] * coeffs[idx] @@ -3422,6 +3432,7 @@ def forward(self, x): compiled_mod = torch.compile(mod, backend="eager") compiled_mod(x) +<<<<<<< HEAD def test_trace_delattr(self): TMP_PREFIX = "_tmp_" @@ -3479,6 +3490,11 @@ def forward(self, x): instantiate_device_type_tests( NNModuleTestsDevice, globals(), only_for=devices, allow_xpu=True ) +======= + +devices = ["cuda", "hpu"] +instantiate_device_type_tests(NNModuleTestsDevice, globals(), only_for=devices) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_package.py b/test/dynamo/test_package.py index 96a726ad66808..51f397c458189 100644 --- a/test/dynamo/test_package.py +++ b/test/dynamo/test_package.py @@ -1,9 +1,13 @@ # Owner(s): ["module: dynamo"] +<<<<<<< HEAD import importlib import os import sys import tempfile +======= +import os +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import unittest import torch @@ -12,15 +16,21 @@ import torch._inductor.test_case import torch.onnx.operators import torch.utils.cpp_extension +<<<<<<< HEAD from torch._dynamo.package import CompilePackage, DiskDynamoStore, DynamoCache from torch._dynamo.precompile_context import PrecompileContext from torch._dynamo.testing import reduce_to_scalar_loss from torch._functorch import config as functorch_config from torch._inductor.mock_cache import global_stats, PatchCaches, Stats +======= +from torch._dynamo.package import CompilePackage, DynamoStore +from torch._functorch import config as functorch_config +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.runtime.runtime_utils import cache_dir from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, +<<<<<<< HEAD skipIfRocm, skipIfXpu, ) @@ -36,6 +46,13 @@ def compute_loss_helper(x): @functorch_config.patch("bundled_autograd_cache", True) @torch._dynamo.config.patch({"strict_precompile": True}) +======= +) +from torch.testing._internal.inductor_utils import HAS_CUDA + + +@functorch_config.patch("bundled_autograd_cache", True) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @instantiate_parametrized_tests class TestPackage(torch._inductor.test_case.TestCase): def path(self): @@ -43,6 +60,7 @@ def path(self): os.makedirs(path, exist_ok=True) return path +<<<<<<< HEAD def setUp(self): super().setUp() torch._dynamo.reset() @@ -104,6 +122,14 @@ def test_basic_fn(self, backend, device): raise unittest.SkipTest("Requires XPU/Triton") ctx = DiskDynamoStore() +======= + @parametrize("backend", ("eager", "inductor")) + @parametrize("device", ("cpu", "cuda")) + def test_basic_fn(self, backend, device): + if device == "cuda" and not HAS_CUDA: + raise unittest.SkipTest("Requires CUDA/Triton") + ctx = DynamoStore() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def fn(x): return x + 1 @@ -140,6 +166,7 @@ def fn(x): self.assertEqual(expected, compiled_fn(*args)) @parametrize("backend", ("eager", "inductor")) +<<<<<<< HEAD @parametrize("device", ("cpu", "cuda", "xpu")) def test_lazy_backward(self, backend, device): if device == "cuda" and not HAS_CUDA_AND_TRITON: @@ -195,6 +222,14 @@ def test_graph_break_bomb(self, backend, device): raise unittest.SkipTest("Requires XPU/Triton") ctx = DiskDynamoStore() +======= + @parametrize("device", ("cpu", "cuda")) + def test_graph_break_bomb(self, backend, device): + if device == "cuda" and not HAS_CUDA: + raise unittest.SkipTest("Requires CUDA/Triton") + + ctx = DynamoStore() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def fn(x, l, r): if l > r: @@ -251,6 +286,7 @@ def guard_filter_fn(guards): compiled_fn(torch.tensor(N), 0, N - 1) @parametrize("backend", ("eager", "inductor")) +<<<<<<< HEAD @parametrize("device", ("cpu", "cuda", "xpu")) def test_dynamic_shape(self, backend, device): if device == "cuda" and not HAS_CUDA_AND_TRITON: @@ -259,6 +295,13 @@ def test_dynamic_shape(self, backend, device): raise unittest.SkipTest("Requires XPU/Triton") ctx = DiskDynamoStore() +======= + @parametrize("device", ("cpu", "cuda")) + def test_dynamic_shape(self, backend, device): + if device == "cuda" and not HAS_CUDA: + raise unittest.SkipTest("Requires CUDA/Triton") + ctx = DynamoStore() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def fn(x): return x + x.shape[0] @@ -300,6 +343,7 @@ def fn(x): ): compiled_fn(*args2) +<<<<<<< HEAD def test_file_change(self): ctx = DiskDynamoStore() @@ -632,6 +676,8 @@ def foo(set_of_x): compiled_fn(*args) self._save_and_reload(expected_backends=1, expected_dynamo=1) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_pgo.py b/test/dynamo/test_pgo.py index ce2fda1387291..e59d5d95eb1c7 100644 --- a/test/dynamo/test_pgo.py +++ b/test/dynamo/test_pgo.py @@ -12,9 +12,13 @@ import torch.compiler.config import torch.nested from torch._dynamo.testing import CompileCounter +<<<<<<< HEAD from torch._inductor.cpp_builder import normalize_path_separator from torch._inductor.utils import clear_caches, fresh_cache from torch.testing._internal.common_utils import IS_WINDOWS +======= +from torch._inductor.utils import clear_caches, fresh_cache +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class PgoTest(torch._dynamo.test_case.TestCase): @@ -57,10 +61,13 @@ def f(x): f(torch.randn(2, 6)) self.assertEqual(cnts.frame_count, 1) +<<<<<<< HEAD @torch._dynamo.config.patch( force_parameter_static_shapes=False, force_nn_module_property_static_shapes=False, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_whitelist_suggestion(self): cnts = CompileCounter() @@ -122,6 +129,7 @@ def check_whitelist(sources_): f(torch.randn(8, 8), torch.randn(8)) self.assertEqual(cnts.frame_count, 1) +<<<<<<< HEAD def test_no_empty_graph_allowlist(self): @torch._dynamo.disable def g(x): @@ -145,6 +153,8 @@ def f1(x): f1(torch.randn(8)) self.assertEqual(torch._dynamo.pgo._LOGGED_DYNAMIC_ALLOWLIST, True) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_pgo_dynamic_false(self): @torch.compile(backend="eager", dynamic=False) class Foo(torch.nn.Module): @@ -169,6 +179,7 @@ def forward(self, x, y): def test_whitelist_ints_floats(self): @torch.compile(backend="eager", fullgraph=True) class Bar(torch.nn.Module): +<<<<<<< HEAD def __init__(self, c, d): super().__init__() self.c = c @@ -183,6 +194,18 @@ def forward(self, x, y, z): f = Bar(1.0, 2) f(2, 1.0, 2.0) f.d = 3 +======= + def __init__(self, c): + super().__init__() + self.c = c + + def forward(self, x, y, z): + if self.c == 1.0: + return x + y + torch.tensor([z]) + + f = Bar(1.0) + f(2, 1.0, 2.0) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f(3, 1.2, 2.0) state = torch._dynamo.pgo.render_code_state(torch._dynamo.pgo.get_code_state()) whitelist = re.search(r'TORCH_COMPILE_DYNAMIC_SOURCES="(.*)"', state).group(1) @@ -193,7 +216,10 @@ def forward(self, x, y, z): ) # ephemeral FloatTensor source self.assertTrue("L['z']" not in whitelist) # static float self.assertTrue("L['self'].c" not in whitelist) # static float property +<<<<<<< HEAD self.assertTrue("L['self'].d" in whitelist) # dynamic int property +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_pgo_dynamic_params(self): cnts = CompileCounter() @@ -223,6 +249,7 @@ def run(): self.assertEqual(cnts.frame_count, 3) # parameter static shapes are forced static, so we recompile once +<<<<<<< HEAD with torch._dynamo.config.patch( force_parameter_static_shapes=False, force_nn_module_property_static_shapes=False, @@ -233,6 +260,16 @@ def run(): # because flags were flipped, params were included in PGO run() self.assertEqual(cnts.frame_count, 1) +======= + run() + self.assertEqual(cnts.frame_count, 2) + + # flags are flipped, PGO records dynamism, so params are dynamically compiled to start + torch._dynamo.config.force_parameter_static_shapes = False + torch._dynamo.config.force_nn_module_property_static_shapes = False + run() + self.assertEqual(cnts.frame_count, 1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_njt(self): cnts = CompileCounter() @@ -353,9 +390,14 @@ def func(x): temp_dir1 = tempfile.TemporaryDirectory() temp_dir2 = tempfile.TemporaryDirectory() +<<<<<<< HEAD # We need normalize_path_separator for Windows file path. path1 = normalize_path_separator(os.path.join(temp_dir1.name, "example.py")) path2 = normalize_path_separator(os.path.join(temp_dir2.name, "example.py")) +======= + path1 = os.path.join(temp_dir1.name, "example.py") + path2 = os.path.join(temp_dir2.name, "example.py") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cnts = CompileCounter() assert path1 != path2 @@ -373,11 +415,15 @@ def write_load_and_run(path): write_load_and_run(path1) self.assertEqual(cnts.frame_count, 2) state = torch._dynamo.pgo.render_code_state(torch._dynamo.pgo.get_code_state()) +<<<<<<< HEAD # Windows can't create unification temp path: # hash(a18a3259)C:/Users/Xuhan/AppData/Local/Temp/tmpx3hfkuqa/example.py # Skip hash check self.assertTrue("hash" if IS_WINDOWS else "hash(390fe689)" in state) +======= + self.assertTrue("hash(390fe689)" in state) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue("/example.py:4:func:" in state) self.assertTrue(" L['x']: tensor size=[?] stride=[1]" in state) # We should compile this only once due to PGO. @@ -385,6 +431,7 @@ def write_load_and_run(path): write_load_and_run(path2) self.assertEqual(cnts.frame_count, 1) +<<<<<<< HEAD @torch._dynamo.config.patch( automatic_dynamic_remote_pgo=True, automatic_dynamic_local_pgo=False ) @@ -497,6 +544,8 @@ def f(ints, t_scalar, tensors): merge_pgo_entry(t1, t2) self.assertEqual(t2.size, (auto_dynamic, auto_dynamic)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_precompile_context.py b/test/dynamo/test_precompile_context.py index b509adf281129..870b449a11946 100644 --- a/test/dynamo/test_precompile_context.py +++ b/test/dynamo/test_precompile_context.py @@ -1,16 +1,23 @@ # Owner(s): ["module: dynamo"] +<<<<<<< HEAD import pickle +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch._dynamo import torch._dynamo.test_case import torch._functorch +<<<<<<< HEAD from torch._dynamo.precompile_context import ( EditablePrecompileCacheArtifact, PrecompileCacheArtifact, PrecompileContext, ) +======= +from torch._dynamo.precompile_context import PrecompileContext +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._functorch import config as functorch_config from torch._functorch._aot_autograd.autograd_cache import ( BundledAOTAutogradCacheArtifact, @@ -20,8 +27,13 @@ @functorch_config.patch({"enable_autograd_cache": True}) +<<<<<<< HEAD @torch._dynamo.config.patch( {"caching_precompile": True} +======= +@functorch_config.patch( + {"bundled_autograd_cache": True} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # Requires bundledaotautograd cache for now class PrecompileContextTests(InductorTestCase): def setUp(self): @@ -47,9 +59,16 @@ def simple_function(x): x = torch.randn(10, device=GPU_TYPE, requires_grad=True) result = compiled_fn(x) result.sum().backward() +<<<<<<< HEAD self.assertEqual(len(PrecompileContext._new_cache_artifacts_by_key), 2) self.assertEqual(len(PrecompileContext._new_cache_artifacts), 0) +======= + # Check that PrecompileContext._new_cache_artifacts_by_key has length 1 + self.assertEqual(len(PrecompileContext._new_cache_artifacts_by_key), 1) + + self.assertEqual(len(PrecompileContext._new_cache_artifacts), 0) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) result = PrecompileContext.serialize() assert result is not None serialized, cache_info = result @@ -82,6 +101,7 @@ def simple_function(x): x = torch.randn(10, device=GPU_TYPE, requires_grad=True) result = compiled_fn(x) result.sum().backward() +<<<<<<< HEAD self.assertEqual(len(PrecompileContext._new_cache_artifacts_by_key), 2) for key in PrecompileContext._new_cache_artifacts_by_key.keys(): result = PrecompileContext.serialize_artifact_by_key(key) @@ -123,6 +143,13 @@ def edit_fn(x): PrecompileContext.edit_artifact(key, edit_fn) +======= + # Check that PrecompileContext._new_cache_artifacts_by_key has length 1 + # TODO: the key right now is the AOTAutogradCacheKey, but will be backend_id once + # we have torch._dynamo.package implemented + self.assertEqual(len(PrecompileContext._new_cache_artifacts_by_key), 1) + key = next(iter(PrecompileContext._new_cache_artifacts_by_key.keys())) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) result = PrecompileContext.serialize_artifact_by_key(key) assert isinstance(result, BundledAOTAutogradCacheArtifact) self.assertEqual(result.key, key) @@ -130,6 +157,7 @@ def edit_fn(x): self.assertEqual(len(PrecompileContext._new_cache_artifacts), 0) result = PrecompileContext.serialize() assert result is not None +<<<<<<< HEAD artifacts, cache_info = result self.assertEqual(len(cache_info.precompile_aot_autograd_artifacts), 1) @@ -146,6 +174,11 @@ def edit_fn(x): len(PrecompileContext._new_cache_artifacts["precompile_aot_autograd"]), 0 ) +======= + _, cache_info = result + self.assertEqual(len(cache_info.precompile_aot_autograd_artifacts), 1) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_profiler.py b/test/dynamo/test_profiler.py index 61dc63ed2d5c6..61cd7e8b3a6a8 100644 --- a/test/dynamo/test_profiler.py +++ b/test/dynamo/test_profiler.py @@ -181,7 +181,11 @@ def fn(x, y): torch.randn(10, 15), ) +<<<<<<< HEAD annotations = [e.name for e in prof.events() if "Torch-Compiled" in e.name] +======= + annotations = [e.name for e in prof.events() if "Compiled" in e.name] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual( annotations, [ diff --git a/test/dynamo/test_python_dispatcher.py b/test/dynamo/test_python_dispatcher.py index d74077a5be4ce..c984a5e4e2c31 100644 --- a/test/dynamo/test_python_dispatcher.py +++ b/test/dynamo/test_python_dispatcher.py @@ -5,12 +5,15 @@ import torch._dynamo.test_case from torch._dynamo.testing import CompileCounter, EagerAndRecordGraphs, normalize_gm from torch.testing._internal.common_cuda import TEST_CUDA +<<<<<<< HEAD from torch.testing._internal.common_utils import TEST_XPU device_type = ( acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu" ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class PythonDispatcherTests(torch._dynamo.test_case.TestCase): @@ -80,7 +83,11 @@ def forward(self, L_x_: "f32[2, 3]"): """, # NOQA: B950 ) +<<<<<<< HEAD @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "requires cuda or xpu") +======= + @unittest.skipIf(not TEST_CUDA, "requires cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_dispatch_key_set_guard(self): counter = CompileCounter() @@ -102,7 +109,11 @@ def fn(x, dks): # No recompile since the dispatch key set is the same though the tensor is different. self.assertEqual(counter.frame_count, 1) +<<<<<<< HEAD x3 = torch.randn(2, 3, device=device_type) +======= + x3 = torch.randn(2, 3, device="cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dks3 = torch._C._dispatch_keys(x3) self.assertEqual(fn(x3, dks3), torch.sin(x3 - 1)) # Re-compile since the dispatch key set is different. diff --git a/test/dynamo/test_recompile_ux.py b/test/dynamo/test_recompile_ux.py index f945039b55d1c..408c2c32606d5 100644 --- a/test/dynamo/test_recompile_ux.py +++ b/test/dynamo/test_recompile_ux.py @@ -12,11 +12,14 @@ from torch.testing._internal.logging_utils import kwargs_to_settings, log_settings +<<<<<<< HEAD device_type = ( acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu" ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class RecompileUxTests(torch._dynamo.test_case.TestCase): # TODO(whc) dynamo actually recompiles one more time than the cache limit cache_limit = 1 @@ -106,10 +109,14 @@ def model(input): .startswith("torch._dynamo hit config.recompile_limit") ) +<<<<<<< HEAD @unittest.skipIf( not torch.cuda.is_available() and not torch.xpu.is_available(), "requires cuda or xpu", ) +======= + @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_nvfuser_guards(self): # we may want to model dynamo's guards sufficiently after nvfuser's ProfilingExecutor guards # such that we ensure dynamo is in charge of all the recompilations at the top level, @@ -117,11 +124,19 @@ def test_nvfuser_guards(self): def func(a, b, c): return a + b * c +<<<<<<< HEAD a = torch.rand(3, 4, 5, device=device_type) b = torch.rand(3, 4, 5, device=device_type) b_v = torch.rand(3, 5, 4, device=device_type).view(3, 4, 5) b_p = torch.rand(3, 5, 4, device=device_type).permute(0, 2, 1) c = torch.rand(3, 4, 5, device=device_type) +======= + a = torch.rand(3, 4, 5, device="cuda") + b = torch.rand(3, 4, 5, device="cuda") + b_v = torch.rand(3, 5, 4, device="cuda").view(3, 4, 5) + b_p = torch.rand(3, 5, 4, device="cuda").permute(0, 2, 1) + c = torch.rand(3, 4, 5, device="cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) compile_counter = torch._dynamo.testing.CompileCounter() with torch._dynamo.config.patch("recompile_limit", 2): diff --git a/test/dynamo/test_recompiles.py b/test/dynamo/test_recompiles.py index 825d2e5d674a9..80a0990c00081 100644 --- a/test/dynamo/test_recompiles.py +++ b/test/dynamo/test_recompiles.py @@ -4,6 +4,7 @@ import torch import torch._dynamo.test_case import torch._dynamo.testing +<<<<<<< HEAD from torch._dynamo import config as dc @@ -56,6 +57,11 @@ def forward(self, x): mod = Mod() mod(torch.randn(2, 2)) +======= + + +class RecompileTests(torch._dynamo.test_case.TestCase): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_automatic_dynamic_reduce_recompiles(self): # Test the counterfactual, lots of recompiles without this config def foo(x, y): @@ -548,6 +554,7 @@ def f(x, foo): f(x, foo1) self.assertEqual(counter.frame_count, 2) +<<<<<<< HEAD def test_no_recompile_over_unused_objects(self): # This is a regression test case that imitates # https://github.com/city96/ComfyUI-GGUF/blob/47bec6147569a138dd30ad3e14f190a36a3be456/ops.py#L169-L182 @@ -571,6 +578,8 @@ def apply_patches(f, x, keys): apply_patches(f, x, [("c", 3), ("d", 4)]) self.assertEqual(counter.frame_count, 1) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_reconstruct.py b/test/dynamo/test_reconstruct.py index 9f3d41964195d..aa83d649773be 100644 --- a/test/dynamo/test_reconstruct.py +++ b/test/dynamo/test_reconstruct.py @@ -7,7 +7,11 @@ import torch import torch._dynamo.test_case from torch.testing._internal.common_utils import IS_FBCODE +<<<<<<< HEAD from torch.testing._internal.inductor_utils import GPU_TYPE, requires_triton +======= +from torch.testing._internal.inductor_utils import requires_triton +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.utils._triton import ( has_triton_experimental_host_tma, has_triton_tensor_descriptor_host_tma, @@ -82,6 +86,10 @@ def f(d, t): opt_f(d_opt, t) self.assertEqual(d, d_opt) +<<<<<<< HEAD +======= + @unittest.expectedFailure +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_ConstDict_popitem_reconstruct(self): """ If something is pop'ed from the dict, we reconstruct everything @@ -420,7 +428,11 @@ def create_tma(tensor): ) return tensor + 1, tma +<<<<<<< HEAD x = torch.randn(128, 128, device=GPU_TYPE) +======= + x = torch.randn(128, 128, device="cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ref = create_tma(x) res = torch.compile(create_tma, backend="eager")(x) @@ -441,7 +453,11 @@ def create_tma(tensor): ) return tensor + 1, tma +<<<<<<< HEAD x = torch.randn(128, 128, device=GPU_TYPE) +======= + x = torch.randn(128, 128, device="cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ref = create_tma(x) res = torch.compile(create_tma, backend="eager")(x) diff --git a/test/dynamo/test_reorder_logs.py b/test/dynamo/test_reorder_logs.py index be6bf8085af27..70da2bbc6dc33 100644 --- a/test/dynamo/test_reorder_logs.py +++ b/test/dynamo/test_reorder_logs.py @@ -210,8 +210,12 @@ def f(x): Hint: Set `torch._dynamo.config.capture_scalar_outputs = True` or `export TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` to include these operations in the captured graph. Developer debug context: call_method TensorVariable() item () {} +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0124.html""", # noqa: B950 +======= +""", # noqa: B950 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 5e23e818f8eb0..61ec1a390d4e4 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -42,6 +42,7 @@ import torch.library import torch.utils._pytree as pytree from torch import nn +<<<<<<< HEAD from torch._dynamo.backends.debugging import ExplainWithBackend from torch._dynamo.debug_utils import same_two_models from torch._dynamo.testing import ( @@ -51,13 +52,20 @@ skipIfNotPy312, skipIfPy312, ) +======= +from torch._dynamo.debug_utils import same_two_models +from torch._dynamo.testing import CompileCounter, rand_strided, same, skipIfPy312 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.utils import fresh_cache from torch.nn import functional as F from torch.profiler import profile, ProfilerActivity from torch.testing._internal.common_cuda import ( PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_FP8, +<<<<<<< HEAD SM70OrLater, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TEST_CUDA, ) from torch.testing._internal.common_device_type import instantiate_device_type_tests @@ -66,7 +74,10 @@ parametrize, serialTest, skipIfHpu, +<<<<<<< HEAD skipIfRocm, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) skipIfWindows, TEST_WITH_ROCM, ) @@ -994,7 +1005,11 @@ def tearDown(self) -> None: self.exit_stack.close() super().tearDown() +<<<<<<< HEAD def guard_manager_clone_hook_fn(self, guard_manager_wrapper, f_locals, builder): +======= + def guard_manager_clone_hook_fn(self, guard_manager_wrapper, f_locals): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) root = guard_manager_wrapper.root cloned_root = root.clone_manager(lambda x: True) cloned_wrapper = torch._dynamo.guards.GuardManagerWrapper(cloned_root) @@ -2044,6 +2059,7 @@ def fn(x): ref0 = fn(x) ref1 = fn(x) +<<<<<<< HEAD opt_fn = torch.compile(fn, backend="eager") # Especially for internal usage, there are many calls to random functions # on first compile, e.g., from various library initializations. Run once @@ -2051,6 +2067,10 @@ def fn(x): opt_fn(x) random.seed(0) +======= + random.seed(0) + opt_fn = torch.compile(fn, backend="eager") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) res0 = opt_fn(x) res1 = opt_fn(x) @@ -3839,9 +3859,12 @@ def f(x): self.assertEqual(f(torch.ones(8, 4)), gm(torch.ones(8, 4))) +<<<<<<< HEAD @skipIfWindows( msg="TODO: (xuhancn) fix, AssertionError: tensor([[0.1000, 0.1000, 0.1000, ..., 0.1000, 0.1000, 0.1000]," ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_optim_state_references_cleared(self): model = torch.nn.Linear(2048, 2048, bias=False) x = torch.ones(2048) @@ -3957,7 +3980,11 @@ def randint_fn(high, size, out): opt_model(17, (12,), out2) @requires_cuda +<<<<<<< HEAD @serialTest() +======= + @serialTest +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_mem_leak_guards(self): def gn(x0, x): return x0 * x @@ -4186,6 +4213,7 @@ def fn(x, l): torch.compile(fn, backend=counter)(torch.randn([2, 2]), []) self.assertEqual(counter.frame_count, 1) +<<<<<<< HEAD def test_get_type_hints(self): class Foo: pass @@ -4201,6 +4229,8 @@ def fn(x): res = opt_fn(x) self.assertEqual(ref, res) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_graph_break_on_jit_isinstance(self): @torch.compile(backend="eager") def fn(x): @@ -4493,6 +4523,7 @@ def func3(x, y): # frame_count should stay at 1. self.assertEqual(cnt.frame_count, 1) +<<<<<<< HEAD def test_tensor_set_data_mismatched_dtype(self): def func(x, y): x.data = y.to(dtype=torch.bfloat16) @@ -4507,6 +4538,8 @@ def func(x, y): self.assertEqual(x1.data, x2.data) self.assertEqual(y1, y2) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_user_ctor_ctx_manager(self): class UserCtxManager: def __enter__(self): @@ -5006,6 +5039,7 @@ def fn(x_weak, weight, y): res = opt_fn(x_weak, weight, y) self.assertEqual(ref, res) +<<<<<<< HEAD # https://github.com/pytorch/pytorch/issues/159258 def test_weakref_proxy(self): class DummyTrainer: @@ -5027,6 +5061,8 @@ def foo(self): compiled_foo = torch.compile(model.foo, backend="eager", fullgraph=True) self.assertEqual(compiled_foo(), x) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_weakref_reconstruct(self): def fn(x_weak, weight, y): y = torch.sin(y) @@ -5088,7 +5124,10 @@ def fn(x_weak, y): # any behavior that depends on deallocation order. We do guarantee "eventual consistency", # that is, after the torch.compile'd function is finished running (including any graph breaks), # refcount semantics will match eager's. +<<<<<<< HEAD @skipIfWindows(msg="TODO: (xuhancn) fix, AssertionError: False is not true") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_weakref_callback(self): called1 = False @@ -5893,10 +5932,13 @@ def f(x): torch.view_as_real(out_test).sum().backward() self.assertEqual(x_ref.grad, x_test.grad) +<<<<<<< HEAD @unittest.skipIf( not SM70OrLater, "Triton only supports devices of CUDA capability >= 7.0", ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_add_complex_conj(self): def f(x): return x + x.conj() @@ -6217,7 +6259,11 @@ def f(x, param): self.assertEqual(out_ref, out_test) @requires_cuda +<<<<<<< HEAD # This test will fail as flip in combination with particular input lengths +======= + # This test will fail as flip in combination with particular input lenghts +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # produces weird results. # This is under investigations in # https://github.com/pytorch/pytorch/issues/131805 @@ -6561,7 +6607,10 @@ def fn(x, y): self.assertEqual(ref, res) @skipIfPy312 # listcomp bytecode is optimized +<<<<<<< HEAD @skipIfWindows(msg="TODO: (xuhancn) fix, AssertionError: Scalars are not equal!") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_listcomp(self): class Module(torch.nn.Module): def __init__(self): @@ -6961,8 +7010,11 @@ def fn(x): torch._dynamo.utils.clear_compilation_metrics() +<<<<<<< HEAD # https://github.com/pytorch/pytorch/issues/156580 @serialTest() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_dont_dce_rand(self): # https://github.com/pytorch/pytorch/issues/143431 def f(image_latent): @@ -7024,6 +7076,7 @@ def f(x, c): c = "foobar" self.assertEqual(f(x, c), opt_f(x, c)) +<<<<<<< HEAD def test_nn_param_freevar_codegen(self): class Model2(nn.Module): def __init__(self) -> None: @@ -7056,6 +7109,8 @@ def wrapper(*args, **kwargs): v2 = jit_func(input_tensor) self.assertEqual(v1, v2) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_amp_foreach_fake_impl(self): inv_scale = torch.full((1,), 0.25) found_inf = torch.full((1,), 0.0) @@ -7071,6 +7126,7 @@ def f(): res = torch.compile(f, backend="aot_eager")() self.assertEqual(ref, res) +<<<<<<< HEAD def test_deleted_compile_wrapper_segfault(self): def fn(x): return x + 1 @@ -7083,6 +7139,8 @@ def fn(x): opt_fn = torch.compile(fn, backend="eager") opt_fn(torch.randn(3)) # possible segfault due to first opt_fn deletion +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_delete_local_error(self): @torch.compile(backend="eager", fullgraph=True) def fn(x): @@ -7094,6 +7152,7 @@ def fn(x): with self.assertRaises(torch._dynamo.exc.Unsupported): fn(torch.ones(3)) +<<<<<<< HEAD def test_nanmean_out(self): def f(x, out): torch.nanmean(x, out=out) @@ -7259,6 +7318,8 @@ def fn(): ) self.assertEqual(explain_output.break_reasons[0].reason, expected_msg) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class ReproTestsDevice(torch._dynamo.test_case.TestCase): def test_sub_alpha_scalar_repro(self, device): @@ -7500,7 +7561,10 @@ def f(x, s0, s1, s2): out = f_compiled(x, s0, s1, s2) self.assertEqual(out_ref, out) +<<<<<<< HEAD @skipIfRocm +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "requires gpu with fp8 support") @requires_cuda def test_partitioner_saves_weights_for_bw(self): @@ -7637,7 +7701,11 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): # *are* saved for backward, and become back inputs. # The easier-to-test thing I'm checking for here is that the recompute # on primals_2 happens in the backward. With the recompute, +<<<<<<< HEAD # there are 5 _to_copy ops in the backward. Without it, there are 4 +======= + # there are 5 _to_copy ops in the backwrad. Without it, there are 4 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # (aka if you set torch._functorch.config.treat_parameters_as_free_to_save = False) self.assertEqual(mode.ops_counter[torch.ops.aten._to_copy.default], 5) @@ -7750,6 +7818,7 @@ def f(x): with mock.patch("torch.cuda.is_initialized", lambda: False): self.assertEqual(f(inp), inp + 2) +<<<<<<< HEAD def test_named_tuple_vt_clone(self): # https://github.com/pytorch/pytorch/issues/157945 class SVDCompressor(nn.Module): @@ -7863,6 +7932,8 @@ def unsafe_grad(y): unsafe_grad(y) # should not warn self.assertEqual(len(w), 1) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) instantiate_parametrized_tests(ReproTests) diff --git a/test/dynamo/test_skip_guard_eval_unsafe.py b/test/dynamo/test_skip_guard_eval_unsafe.py index dc7d74bc3629d..5657c6f7d7ece 100644 --- a/test/dynamo/test_skip_guard_eval_unsafe.py +++ b/test/dynamo/test_skip_guard_eval_unsafe.py @@ -54,9 +54,14 @@ def fn(x, y): def test_post_recompile(self): class Foo: +<<<<<<< HEAD def __init__(self): self.a = 4 self.b = 5 +======= + a = 4 + b = 5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) foo = Foo() diff --git a/test/dynamo/test_structured_trace.py b/test/dynamo/test_structured_trace.py index 89c14961a3a75..dba4fa0fa4603 100644 --- a/test/dynamo/test_structured_trace.py +++ b/test/dynamo/test_structured_trace.py @@ -10,7 +10,10 @@ import subprocess import tempfile import unittest.mock +<<<<<<< HEAD from contextlib import contextmanager +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch._dynamo.test_case @@ -22,6 +25,7 @@ from torch._logging._internal import TorchLogsFormatter from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing._internal.common_utils import find_free_port +<<<<<<< HEAD from torch.testing._internal.triton_utils import requires_cuda_and_triton @@ -30,6 +34,14 @@ HAS_TLPARSE = shutil.which("tlparse") is not None requires_tlparse = unittest.skipUnless(HAS_TLPARSE, "requires tlparse") +======= +from torch.testing._internal.inductor_utils import HAS_CUDA + + +HAS_TLPARSE = shutil.which("tlparse") is not None +requires_tlparse = unittest.skipUnless(HAS_TLPARSE, "requires tlparse") +requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) requires_distributed = functools.partial( unittest.skipIf, not dist.is_available(), "requires distributed" ) @@ -236,7 +248,11 @@ def test_compile_id_serialization_deserialization(self): with self.assertRaises(ValueError): torch._guards.CompileId.from_string(bad_cid) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_schedule(self): fn_opt = torch.compile(inductor_schedule_fn, backend="inductor") fn_opt(torch.ones(1000, 1000, device="cuda")) @@ -245,23 +261,37 @@ def test_schedule(self): """\ {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +<<<<<<< HEAD {"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1000, 1000], "dynamo_hint_overrides": {}, "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +======= +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1000, 1000], "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_a_": [1000, 1000], "ones": [1000, 1000], "output": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +<<<<<<< HEAD {"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_joint_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_joint_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +======= +{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "triton_kernel_info", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +<<<<<<< HEAD +======= +{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"compilation_metrics_runtime": "METRICS", "frame_id": 0, "frame_compile_id": 0} @@ -270,7 +300,11 @@ def test_schedule(self): self.assertParses() +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_cudagraphs(self): fn_opt = torch.compile(mode="reduce-overhead")(inductor_schedule_fn) fn_opt(torch.ones(1000, 1000, device="cuda")) @@ -279,23 +313,37 @@ def test_cudagraphs(self): """\ {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +<<<<<<< HEAD {"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1000, 1000], "dynamo_hint_overrides": {}, "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +======= +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1000, 1000], "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_a_": [1000, 1000], "ones": [1000, 1000], "output": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +<<<<<<< HEAD {"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_joint_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_joint_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +======= +{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "triton_kernel_info", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +<<<<<<< HEAD +======= +{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"compilation_metrics_runtime": "METRICS", "frame_id": 0, "frame_compile_id": 0} @@ -318,6 +366,7 @@ def fn(x, y): """\ {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +<<<<<<< HEAD {"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "dynamo_hint_overrides": {}, "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_storage": {"id": 1, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} @@ -332,33 +381,64 @@ def fn(x, y): {"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_joint_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_joint_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +======= +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['y']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 1, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 1, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "stride": [1000, 1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 1, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"dynamo_output_graph": {"sizes": {"l_y_": [1000, 1000], "l_x_": [1000, 1000], "add": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +<<<<<<< HEAD +======= +{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"artifact": {"name": "recompile_reasons", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} {"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +<<<<<<< HEAD {"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "dynamo_hint_overrides": {}, "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +======= +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} {"create_symbol": {"symbol": "s48", "val": "1", "vr": "[-int_oo, int_oo]", "source": "L['y']", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_x_": [1000, 1000], "add": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +<<<<<<< HEAD {"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_joint_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_joint_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +======= +{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +<<<<<<< HEAD +======= +{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 1, "attempt": 0} """, # noqa: B950 @@ -375,22 +455,36 @@ def test_example_fn(self): """\ {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +<<<<<<< HEAD {"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "dynamo_hint_overrides": {}, "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +======= +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_a_": [1000, 1000], "ones": [1000, 1000], "output": [1000, 1000], "ones_1": [1000, 1000], "output_1": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +<<<<<<< HEAD {"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_joint_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_joint_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +======= +{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +<<<<<<< HEAD +======= +{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} """, # noqa: B950 @@ -411,33 +505,56 @@ def test_example_training_fn(self): """\ {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +<<<<<<< HEAD {"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "dynamo_hint_overrides": {}, "is_leaf": true, "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"artifact": {"name": "dynamo_graph_break_reason", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 1} {"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "dynamo_hint_overrides": {}, "is_leaf": true, "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 1} +======= +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"artifact": {"name": "dynamo_graph_break_reason", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 1} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 1} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 1} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 1} {"dynamo_start": {"stack": "STACK"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} {"artifact": {"name": "dynamo_graph_break_reason", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 1, "frame_compile_id": 0, "attempt": 1} +<<<<<<< HEAD {"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "dynamo_hint_overrides": {}, "is_leaf": true, "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 1} +======= +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 1} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['___stack1']"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 1} {"dynamo_cpp_guards_str": {}, "frame_id": 1, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 1, "frame_compile_id": 0, "attempt": 1} {"dynamo_start": {"stack": "STACK"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 0} {"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 2, "frame_compile_id": 0, "attempt": 0} +<<<<<<< HEAD {"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "dynamo_hint_overrides": {}, "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 0} {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['___stack0']"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 0} {"artifact": {"name": "dynamo_graph_break_reason", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1} {"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "dynamo_hint_overrides": {}, "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1} +======= +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['___stack0']"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 0} +{"artifact": {"name": "dynamo_graph_break_reason", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['___stack0']"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1} {"dynamo_output_graph": {"sizes": {"l_stack0_": [1000, 1000], "ones": [1000, 1000], "output": [1000, 1000], "sum_1": []}}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +<<<<<<< HEAD {"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"aot_joint_graph": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} @@ -448,6 +565,10 @@ def test_example_training_fn(self): {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +<<<<<<< HEAD +======= +{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"dynamo_cpp_guards_str": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 2, "frame_compile_id": 0, "attempt": 1} {"dynamo_start": {"stack": "STACK"}, "frame_id": 3, "frame_compile_id": 0, "attempt": 0} @@ -460,7 +581,11 @@ def test_example_training_fn(self): {"bwd_compilation_metrics": "METRICS", "frame_id": 2, "frame_compile_id": 0, "attempt": 1} {"dynamo_start": {"stack": "STACK"}, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} {"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} +<<<<<<< HEAD {"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "dynamo_hint_overrides": {}, "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} +======= +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['output']"}, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} {"compilation_metrics": "METRICS", "frame_id": 4, "frame_compile_id": 0, "attempt": 0} """, # noqa: B950 @@ -480,7 +605,11 @@ def test_dynamo_error(self): """\ {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +<<<<<<< HEAD {"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "dynamo_hint_overrides": {}, "is_leaf": true, "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +======= +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"artifact": {"name": "dynamo_error", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} @@ -514,12 +643,19 @@ def throw(x): """\ {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +<<<<<<< HEAD {"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "dynamo_hint_overrides": {}, "is_leaf": true, "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +======= +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_a_": [1000, 1000], "output": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +<<<<<<< HEAD {"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"aot_joint_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -536,7 +672,11 @@ def throw(x): self.assertParses() @requires_distributed() +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_ddp_graphs(self): class ToyModel(torch.nn.Module): def __init__(self) -> None: @@ -625,7 +765,11 @@ def forward(self, x): {"dynamo_start": {"stack": "STACK"}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"artifact": {"name": "dynamo_graph_break_reason", "encoding": "string"}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"describe_storage": {"id": 0, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 1} +<<<<<<< HEAD {"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "dynamo_hint_overrides": {}, "is_leaf": true, "stride": [1024, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 1} +======= +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "is_leaf": true, "stride": [1024, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 1} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['args'][0]"}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 1} {"dynamo_cpp_guards_str": {}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 1} @@ -641,6 +785,7 @@ def forward(self, x): {"compilation_metrics": "METRICS", "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0} {"dynamo_start": {"stack": "STACK"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} {"describe_storage": {"id": 0, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} +<<<<<<< HEAD {"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "dynamo_hint_overrides": {}, "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1024, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['self']._modules['layers']._modules['0']._parameters['weight']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} {"describe_storage": {"id": 1, "describer_id": "ID", "size": 4096}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} @@ -654,12 +799,28 @@ def forward(self, x): {"describe_source": {"describer_id": "ID", "id": 8, "source": "L['self']._modules['layers']._modules['1']._parameters['weight']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} {"describe_storage": {"id": 4, "describer_id": "ID", "size": 4096}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} {"describe_tensor": {"id": 9, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024], "dynamo_hint_overrides": {}, "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1], "storage": 4, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} +======= +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1024, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['self']._modules['layers']._modules['0']._parameters['weight']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 1, "describer_id": "ID", "size": 4096}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 1, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024], "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 1, "source": "L['self']._modules['layers']._modules['0']._parameters['bias']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 2, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 2, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "is_leaf": true, "stride": [1024, 1], "storage": 2, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 2, "source": "L['x']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 3, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 8, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1024, 1], "storage": 3, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 8, "source": "L['self']._modules['layers']._modules['1']._parameters['weight']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 4, "describer_id": "ID", "size": 4096}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 9, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024], "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1], "storage": 4, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"describe_source": {"describer_id": "ID", "id": 9, "source": "L['self']._modules['layers']._modules['1']._parameters['bias']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_self_modules_layers_modules_0_parameters_weight_": [1024, 1024], "l_self_modules_layers_modules_0_parameters_bias_": [1024], "l_x_": [1024, 1024], "l_self_modules_layers_modules_1_parameters_weight_": [1024, 1024], "l_self_modules_layers_modules_1_parameters_bias_": [1024], "input_1": [1024, 1024], "input_2": [1024, 1024]}}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"optimize_ddp_split_graph": {}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"optimize_ddp_split_child": {"name": "submod_0"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"optimize_ddp_split_child": {"name": "submod_1"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"describe_storage": {"id": 0, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} +<<<<<<< HEAD {"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "dynamo_hint_overrides": {}, "is_leaf": true, "stride": [1024, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} {"describe_storage": {"id": 1, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} @@ -671,6 +832,18 @@ def forward(self, x): {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aotautograd_cache_bypass", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +======= +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "is_leaf": true, "stride": [1024, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 1, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 1, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1024, 1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 1, "source": "L['self']._modules['layers']._modules['0']._parameters['weight']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 2, "describer_id": "ID", "size": 4096}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 2, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024], "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1], "storage": 2, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 2, "source": "L['self']._modules['layers']._modules['0']._parameters['bias']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} +{"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"aot_joint_graph": {}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -681,6 +854,7 @@ def forward(self, x): {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +<<<<<<< HEAD {"describe_storage": {"id": 16, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} {"describe_tensor": {"id": 29, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "dynamo_hint_overrides": {}, "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1024, 1], "storage": 16, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} {"describe_source": {"describer_id": "ID", "id": 29, "source": "L['self']._modules['layers']._modules['1']._parameters['weight']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} @@ -690,6 +864,17 @@ def forward(self, x): {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aotautograd_cache_bypass", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +======= +{"artifact": {"name": "aotautograd_cache_bypass", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"describe_storage": {"id": 16, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 29, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1024, 1], "storage": 16, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 29, "source": "L['self']._modules['layers']._modules['1']._parameters['weight']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 17, "describer_id": "ID", "size": 4096}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 30, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024], "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1], "storage": 17, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 30, "source": "L['self']._modules['layers']._modules['1']._parameters['bias']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} +{"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"aot_joint_graph": {}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -700,6 +885,10 @@ def forward(self, x): {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +<<<<<<< HEAD +======= +{"artifact": {"name": "aotautograd_cache_bypass", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"dynamo_cpp_guards_str": {}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} """, # noqa: B950 @@ -725,22 +914,36 @@ def fn(x): {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 1} {"dynamo_start": {"stack": "STACK"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} {"describe_storage": {"id": 0, "describer_id": "ID", "size": 4}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +<<<<<<< HEAD {"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1], "dynamo_hint_overrides": {}, "is_leaf": true, "stride": [1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +======= +{"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1], "is_leaf": true, "stride": [1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_x_": [1], "add": [1]}}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +<<<<<<< HEAD {"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_inference_graph": {}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_joint_graph", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_joint_graph", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +======= +{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"aot_inference_graph": {}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +<<<<<<< HEAD +======= +{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"dynamo_cpp_guards_str": {}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 1, "frame_compile_id": 0, "attempt": 0} """, # noqa: B950 @@ -766,10 +969,17 @@ def fn(a, b): """\ {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_storage": {"id": 0, "describer_id": "ID", "size": 800}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +<<<<<<< HEAD {"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [10, 20], "dynamo_hint_overrides": {}, "is_leaf": true, "stride": [20, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_storage": {"id": 1, "describer_id": "ID", "size": 2400}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_tensor": {"id": 1, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [20, 30], "dynamo_hint_overrides": {}, "is_leaf": true, "stride": [30, 1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +======= +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [10, 20], "is_leaf": true, "stride": [20, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 1, "describer_id": "ID", "size": 2400}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 1, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [20, 30], "is_leaf": true, "stride": [30, 1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"describe_source": {"describer_id": "ID", "id": 1, "source": "L['b']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_a_": [10, 20], "l_b_": [20, 30], "matmul": [10, 30]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -777,12 +987,20 @@ def fn(a, b): {"artifact": {"name": "recompile_reasons", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} {"describe_storage": {"id": 0, "describer_id": "ID", "size": 200}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +<<<<<<< HEAD {"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [5, 10], "dynamo_hint_overrides": {}, "is_leaf": true, "stride": [10, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +======= +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [5, 10], "is_leaf": true, "stride": [10, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} {"create_symbol": {"symbol": "s97", "val": "5", "vr": "[2, int_oo]", "source": "L['a'].size()[0]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} {"create_symbol": {"symbol": "s98", "val": "10", "vr": "[2, int_oo]", "source": "L['a'].size()[1]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} {"describe_storage": {"id": 1, "describer_id": "ID", "size": 600}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +<<<<<<< HEAD {"describe_tensor": {"id": 1, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [10, 15], "dynamo_hint_overrides": {}, "is_leaf": true, "stride": [15, 1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +======= +{"describe_tensor": {"id": 1, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [10, 15], "is_leaf": true, "stride": [15, 1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"describe_source": {"describer_id": "ID", "id": 1, "source": "L['b']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} {"create_symbol": {"symbol": "s52", "val": "10", "vr": "[2, int_oo]", "source": "L['b'].size()[0]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} {"create_symbol": {"symbol": "s20", "val": "15", "vr": "[2, int_oo]", "source": "L['b'].size()[1]", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} @@ -818,7 +1036,11 @@ def inner(x, ys, zs): """\ {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_storage": {"id": 0, "describer_id": "ID", "size": 4}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +<<<<<<< HEAD {"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1], "dynamo_hint_overrides": {}, "is_leaf": true, "stride": [1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +======= +{"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1], "is_leaf": true, "stride": [1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_x_": [1], "x": [1]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -826,7 +1048,11 @@ def inner(x, ys, zs): {"artifact": {"name": "recompile_reasons", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} {"describe_storage": {"id": 0, "describer_id": "ID", "size": 4}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +<<<<<<< HEAD {"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1], "dynamo_hint_overrides": {}, "is_leaf": true, "stride": [1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +======= +{"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1], "is_leaf": true, "stride": [1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_x_": [1], "x": [1]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} @@ -856,10 +1082,17 @@ def forward(self, x, y): return add {"describe_storage": {"id": 0, "describer_id": "ID", "size": 12}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +<<<<<<< HEAD {"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [3], "dynamo_hint_overrides": {}, "is_leaf": true, "stride": [1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_storage": {"id": 1, "describer_id": "ID", "size": 12}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_tensor": {"id": 1, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [3], "dynamo_hint_overrides": {}, "is_leaf": true, "stride": [1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +======= +{"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [3], "is_leaf": true, "stride": [1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 1, "describer_id": "ID", "size": 12}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 1, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [3], "is_leaf": true, "stride": [1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"describe_source": {"describer_id": "ID", "id": 1, "source": "L['y']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_x_": [3], "l_y_": [3], "add": [3]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -886,27 +1119,45 @@ def fn(a): """\ {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_storage": {"id": 0, "describer_id": "ID", "size": 4}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +<<<<<<< HEAD {"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1], "dynamo_hint_overrides": {}, "is_leaf": true, "stride": [1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +======= +{"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1], "is_leaf": true, "stride": [1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_a_": [1], "sin": [1]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +<<<<<<< HEAD {"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_joint_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_joint_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +======= +{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +<<<<<<< HEAD +======= +{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_storage": {"id": 0, "describer_id": "ID", "size": 4}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +<<<<<<< HEAD {"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1], "dynamo_hint_overrides": {}, "is_leaf": true, "stride": [1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +======= +{"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1], "is_leaf": true, "stride": [1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_a_": [1], "sin": [1]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -916,8 +1167,11 @@ def fn(a): {"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +<<<<<<< HEAD {"artifact": {"name": "inductor_provenance_tracking_node_mappings", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"artifact": {"name": "inductor_provenance_tracking_kernel_stack_traces", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"artifact": {"name": "fx_graph_cache_hit", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aotautograd_cache_hit", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -1059,10 +1313,17 @@ def backward(ctx, gO): '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 0, "frame_id": 9, "frame_compile_id": 0, "attempt": 0}', '{"dynamo_start": {"stack": "STACK"}, "frame_id": 1, "frame_compile_id": 1, "attempt": 0}', '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 6, "frame_compile_id": 1, "attempt": 0}', +<<<<<<< HEAD '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 9, "frame_compile_id": 1, "attempt": 0}', '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 13, "frame_compile_id": 0, "attempt": 0}', '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 12, "frame_compile_id": 1, "attempt": 0}', '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 13, "frame_compile_id": 1, "attempt": 0}', +======= + '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 7, "frame_compile_id": 1, "attempt": 0}', + '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 11, "frame_compile_id": 0, "attempt": 0}', + '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 10, "frame_compile_id": 1, "attempt": 0}', + '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 11, "frame_compile_id": 1, "attempt": 0}', +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] logs = self.buffer.getvalue() self.assertTrue(all(event in logs for event in expected)) @@ -1087,6 +1348,7 @@ def test_compiled_autograd_chromium(self): logs = self.buffer.getvalue() self.assertTrue(all(event in logs for event in expected)) +<<<<<<< HEAD def test_recompile_user_contexts(self): # test that user_context is called only once per recompile num_calls = 0 @@ -1533,6 +1795,8 @@ def fn(x): ) self.assertParses() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index 9d60cbe81c970..b57ff164a7cdd 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -31,7 +31,11 @@ parametrize, subtest, ) +<<<<<<< HEAD from torch.testing._internal.triton_utils import requires_cuda_and_triton +======= +from torch.testing._internal.inductor_utils import HAS_CUDA +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.two_tensor import TwoTensor from torch.utils._python_dispatch import return_and_correct_aliasing @@ -145,6 +149,11 @@ def mk_subclass_dense_subclass_dense(): VIEW_TEST_CASES = {k: v for v, k in get_view_test_cases()} +<<<<<<< HEAD +======= +requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) compile_full_eager = torch.compile(backend="eager", fullgraph=True) @@ -1366,7 +1375,11 @@ def forward(self, L_x_: "f32[3, 4]"): ) self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0])) +<<<<<<< HEAD # Cannot reuse the version from AOTAutograd, since that uses python functional tensors. +======= + # Cannot re-use the version from AOTAutograd, since that uses python functional tensors. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def to_fun(x): x_functional = torch._to_functional_tensor(x) torch._mirror_autograd_meta_to(x, x_functional) @@ -2015,7 +2028,11 @@ def forward(self): exp_frame_count=[1, 1, 2, 2], exp_shape_env_guards=[ [], +<<<<<<< HEAD # s0 is specialized and guarded in outer shape_env when dynamo checks the guards +======= + # s0 is specialized and guarded in outter shape_env when dynamo checks the guards +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ["Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)"], [ "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)", @@ -2037,7 +2054,11 @@ def forward(self): exp_frame_count=[1, 1, 2, 2], exp_shape_env_guards=[ [], +<<<<<<< HEAD # s0 is specialized and guarded in outer shape_env when dynamo checks the guards +======= + # s0 is specialized and guarded in outter shape_env when dynamo checks the guards +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ["Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)"], [ "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)", @@ -2244,6 +2265,7 @@ def f(tt): fw, bw = self._compile_check(f, [(tt,)], dynamic=True, call_backward=True) self.assertExpectedInline( +<<<<<<< HEAD normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)), """\ class GraphModule(torch.nn.Module): @@ -2269,10 +2291,20 @@ def forward( primals_5, # SavedForBackwardsAOTOutput(idx=1) primals_7, # SavedForBackwardsAOTOutput(idx=2) ) +======= + normalize_gm(fw[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"): + mul: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(primals_3, primals_1); primals_3 = None + mul_3: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(primals_4, primals_1); primals_4 = None + return (mul, mul_3, primals_5, primals_7, primals_7, primals_1, primals_5, primals_7) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """, # noqa: B950 ) self.assertExpectedInline( +<<<<<<< HEAD normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)), """\ class GraphModule(torch.nn.Module): @@ -2295,6 +2327,15 @@ def forward( primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1) primals_7, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0) ) +======= + normalize_gm(bw[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_1: "Sym(s47)", primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s47, s16]", tangents_2: "f32[s47, s16]"): + mul_8: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1); tangents_1 = None + mul_9: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(tangents_2, primals_1); tangents_2 = primals_1 = None + return (None, None, mul_8, mul_9, primals_5, primals_7, primals_7) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """, # noqa: B950 ) @@ -2310,6 +2351,7 @@ def f(tt): fw, bw = self._compile_check(f, [(tt,)], dynamic=True, call_backward=True) self.assertExpectedInline( +<<<<<<< HEAD normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)), """\ class GraphModule(torch.nn.Module): @@ -2323,11 +2365,18 @@ def forward( primals_6: "Sym(s16)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=1) primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0) ): +======= + normalize_gm(fw[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None view: "f32[s16, s47]" = torch.ops.aten.view.default(clone, [primals_2, primals_1]); clone = None view_1: "f32[s16, s47]" = torch.ops.aten.view.default(clone_1, [primals_2, primals_1]); clone_1 = primals_1 = None +<<<<<<< HEAD return ( view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='a') view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='b') @@ -2337,10 +2386,14 @@ def forward( primals_5, # SavedForBackwardsAOTOutput(idx=0) primals_7, # SavedForBackwardsAOTOutput(idx=1) ) +======= + return (view, view_1, primals_2, primals_5, primals_5, primals_5, primals_7) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """, # noqa: B950 ) self.assertExpectedInline( +<<<<<<< HEAD normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)), """\ class GraphModule(torch.nn.Module): @@ -2362,6 +2415,15 @@ def forward( primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1) primals_7, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0) ) +======= + normalize_gm(bw[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s16, s47]", tangents_2: "f32[s16, s47]"): + view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None + view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None + return (None, None, view_2, view_3, primals_5, primals_7, primals_7) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """, # noqa: B950 ) @@ -2379,6 +2441,7 @@ def f(tt, a, b): fw, bw = self._compile_check(f, [(tt, a, b)], dynamic=True, call_backward=True) self.assertExpectedInline( +<<<<<<< HEAD normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)), """\ class GraphModule(torch.nn.Module): @@ -2392,6 +2455,12 @@ def forward( primals_6: "Sym(s98)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=1) primals_7: "Sym(s98)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0) ): +======= + normalize_gm(fw[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_1: "Sym(s97)", primals_2: "Sym(s98)", primals_3: "f32[s97, s98]", primals_4: "f32[s97, s98]", primals_5: "Sym(s97)", primals_6: "Sym(s98)", primals_7: "Sym(s98)"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mul: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(primals_3, primals_1); primals_3 = None mul_3: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(primals_4, primals_1); primals_4 = None mul_8: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul, primals_2); mul = None @@ -2400,6 +2469,7 @@ def forward( mul_19: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_11, primals_1); mul_11 = None mul_24: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_16, primals_2); mul_16 = None mul_27: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_19, primals_2); mul_19 = None +<<<<<<< HEAD return ( mul_24, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='a') mul_27, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='b') @@ -2411,10 +2481,14 @@ def forward( primals_5, # SavedForBackwardsAOTOutput(idx=2) primals_7, # SavedForBackwardsAOTOutput(idx=3) ) +======= + return (mul_24, mul_27, primals_5, primals_7, primals_7, primals_1, primals_2, primals_5, primals_7) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """, # noqa: B950 ) self.assertExpectedInline( +<<<<<<< HEAD normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)), """\ class GraphModule(torch.nn.Module): @@ -2427,6 +2501,12 @@ def forward( tangents_1: "f32[s97, s98]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='a') tangents_2: "f32[s97, s98]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='b') ): +======= + normalize_gm(bw[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_1: "Sym(s97)", primals_2: "Sym(s98)", primals_5: "Sym(s97)", primals_7: "Sym(s98)", tangents_1: "f32[s97, s98]", tangents_2: "f32[s97, s98]"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mul_32: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(tangents_1, primals_2); tangents_1 = None mul_33: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(tangents_2, primals_2); tangents_2 = None mul_34: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_32, primals_1); mul_32 = None @@ -2435,6 +2515,7 @@ def forward( mul_37: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_35, primals_2); mul_35 = primals_2 = None mul_38: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_36, primals_1); mul_36 = None mul_39: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_37, primals_1); mul_37 = primals_1 = None +<<<<<<< HEAD return ( None, # None None, # None @@ -2444,6 +2525,9 @@ def forward( primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1) primals_7, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0) ) +======= + return (None, None, mul_38, mul_39, primals_5, primals_7, primals_7) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """, # noqa: B950 ) @@ -2459,6 +2543,7 @@ def f(tt): fw, bw = self._compile_check(f, [(tt,)], dynamic=True, call_backward=True) self.assertExpectedInline( +<<<<<<< HEAD normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)), """\ class GraphModule(torch.nn.Module): @@ -2472,11 +2557,18 @@ def forward( primals_6: "Sym(s16)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=1) primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0) ): +======= + normalize_gm(fw[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None view: "f32[s47, s16]" = torch.ops.aten.view.default(clone, [primals_1, primals_2]); clone = None view_1: "f32[s47, s16]" = torch.ops.aten.view.default(clone_1, [primals_1, primals_2]); clone_1 = primals_1 = primals_2 = None +<<<<<<< HEAD return ( view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='a') view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='b') @@ -2486,10 +2578,14 @@ def forward( primals_5, # SavedForBackwardsAOTOutput(idx=0) primals_7, # SavedForBackwardsAOTOutput(idx=1) ) +======= + return (view, view_1, primals_5, primals_7, primals_7, primals_5, primals_7) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """, # noqa: B950 ) self.assertExpectedInline( +<<<<<<< HEAD normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)), """\ class GraphModule(torch.nn.Module): @@ -2511,6 +2607,15 @@ def forward( primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1) primals_7, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0) ) +======= + normalize_gm(bw[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s47, s16]", tangents_2: "f32[s47, s16]"): + view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None + view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None + return (None, None, view_2, view_3, primals_5, primals_7, primals_7) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """, # noqa: B950 ) @@ -2526,6 +2631,7 @@ def f(tt): fw, bw = self._compile_check(f, [(tt,)], dynamic=True, call_backward=True) self.assertExpectedInline( +<<<<<<< HEAD normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)), """\ class GraphModule(torch.nn.Module): @@ -2539,12 +2645,19 @@ def forward( primals_6: "Sym(s16)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=1) primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0) ): +======= + normalize_gm(fw[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None mul_6: "Sym(s16*s47)" = primals_1 * primals_2; primals_1 = primals_2 = None view: "f32[s16*s47]" = torch.ops.aten.view.default(clone, [mul_6]); clone = None view_1: "f32[s16*s47]" = torch.ops.aten.view.default(clone_1, [mul_6]); clone_1 = None +<<<<<<< HEAD return ( view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='a') view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='b') @@ -2552,10 +2665,14 @@ def forward( primals_5, # SavedForBackwardsAOTOutput(idx=0) primals_7, # SavedForBackwardsAOTOutput(idx=1) ) +======= + return (view, view_1, mul_6, primals_5, primals_7) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """, # noqa: B950 ) self.assertExpectedInline( +<<<<<<< HEAD normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)), """\ class GraphModule(torch.nn.Module): @@ -2577,6 +2694,15 @@ def forward( primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1) primals_7, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0) ) +======= + normalize_gm(bw[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s16*s47]", tangents_2: "f32[s16*s47]"): + view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None + view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None + return (None, None, view_2, view_3, primals_5, primals_7, primals_7) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """, # noqa: B950 ) @@ -2592,6 +2718,7 @@ def f(tt): fw, bw = self._compile_check(f, [(tt,)], dynamic=True, call_backward=True) self.assertExpectedInline( +<<<<<<< HEAD normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)), """\ class GraphModule(torch.nn.Module): @@ -2605,12 +2732,19 @@ def forward( primals_6: "Sym(s16)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=1) primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0) ): +======= + normalize_gm(fw[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None mul_6: "Sym(s16*s47)" = primals_1 * primals_2; primals_1 = primals_2 = None view: "f32[s16*s47]" = torch.ops.aten.view.default(clone, [mul_6]) view_1: "f32[s16*s47]" = torch.ops.aten.view.default(clone_1, [mul_6]); clone_1 = None +<<<<<<< HEAD return ( clone, # PlainAOTOutput(idx=0) view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='a') @@ -2619,10 +2753,14 @@ def forward( primals_5, # SavedForBackwardsAOTOutput(idx=0) primals_7, # SavedForBackwardsAOTOutput(idx=1) ) +======= + return (clone, view, view_1, mul_6, primals_5, primals_7) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """, # noqa: B950 ) self.assertExpectedInline( +<<<<<<< HEAD normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)), """\ class GraphModule(torch.nn.Module): @@ -2644,6 +2782,15 @@ def forward( primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1) primals_7, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0) ) +======= + normalize_gm(bw[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s16*s47]", tangents_2: "f32[s16*s47]"): + view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None + view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None + return (None, None, view_2, view_3, primals_5, primals_7, primals_7) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """, # noqa: B950 ) @@ -2704,6 +2851,7 @@ def f(tt): ) self.assertExpectedInline( +<<<<<<< HEAD normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)), """\ class GraphModule(torch.nn.Module): @@ -2712,21 +2860,32 @@ def forward( primals_1: "f32[3, 4]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=0), attr='a') primals_2: "f32[3, 4]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=0), attr='b') ): +======= + normalize_gm(fw[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_1: "f32[3, 4]", primals_2: "f32[3, 4]"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) clone: "f32[3, 4]" = torch.ops.aten.clone.default(primals_1); primals_1 = None clone_1: "f32[3, 4]" = torch.ops.aten.clone.default(primals_2); primals_2 = None view: "f32[12]" = torch.ops.aten.view.default(clone, [-1]) view_1: "f32[12]" = torch.ops.aten.view.default(clone_1, [-1]) +<<<<<<< HEAD return ( clone, # PlainAOTOutput(idx=0) view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='a') view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='b') clone_1, # PlainAOTOutput(idx=2) ) +======= + return (clone, view, view_1, clone_1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """, # noqa: B950 ) self.assertExpectedInline( +<<<<<<< HEAD normalize_gm(fw[1].print_readable(print_output=False, expanded_def=True)), """\ class GraphModule(torch.nn.Module): @@ -2738,12 +2897,19 @@ def forward( primals_4: "Sym(s16)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=1), idx=1) primals_5: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=1), idx=0) ): +======= + normalize_gm(fw[1].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_1: "Sym(s16)", primals_2: "f32[3, s16]", primals_3: "f32[3, s16]", primals_4: "Sym(s16)", primals_5: "Sym(s16)"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) clone: "f32[3, s16]" = torch.ops.aten.clone.default(primals_2); primals_2 = None clone_1: "f32[3, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None view: "f32[3*s16]" = torch.ops.aten.view.default(clone, [-1]) sym_size_int_2: "Sym(3*s16)" = torch.ops.aten.sym_size.int(view, 0) view_1: "f32[3*s16]" = torch.ops.aten.view.default(clone_1, [-1]) +<<<<<<< HEAD return ( clone, # PlainAOTOutput(idx=0) view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='a') @@ -2752,10 +2918,14 @@ def forward( clone_1, # PlainAOTOutput(idx=2) primals_5, # SavedForBackwardsAOTOutput(idx=0) ) +======= + return (clone, view, view_1, sym_size_int_2, clone_1, primals_5) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """, # noqa: B950 ) self.assertExpectedInline( +<<<<<<< HEAD normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)), """\ class GraphModule(torch.nn.Module): @@ -2770,10 +2940,20 @@ def forward( view_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=0)), attr='a') view_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=0)), attr='b') ) +======= + normalize_gm(bw[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, tangents_1: "f32[12]", tangents_2: "f32[12]"): + view_2: "f32[3, 4]" = torch.ops.aten.view.default(tangents_1, [3, 4]); tangents_1 = None + view_3: "f32[3, 4]" = torch.ops.aten.view.default(tangents_2, [3, 4]); tangents_2 = None + return (view_2, view_3) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """, # noqa: B950 ) self.assertExpectedInline( +<<<<<<< HEAD normalize_gm(bw[1].print_readable(print_output=False, expanded_def=True)), """\ class GraphModule(torch.nn.Module): @@ -2792,6 +2972,15 @@ def forward( primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), idx=1) primals_5, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), idx=0) ) +======= + normalize_gm(bw[1].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_5: "Sym(s16)", tangents_1: "f32[3*s16]", tangents_2: "f32[3*s16]"): + view_2: "f32[3, s16]" = torch.ops.aten.view.default(tangents_1, [3, primals_5]); tangents_1 = None + view_3: "f32[3, s16]" = torch.ops.aten.view.default(tangents_2, [3, primals_5]); tangents_2 = None + return (None, view_2, view_3, primals_5, primals_5) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """, # noqa: B950 ) @@ -2815,6 +3004,7 @@ def f(tt): ) self.assertExpectedInline( +<<<<<<< HEAD normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)), """\ class GraphModule(torch.nn.Module): @@ -2826,12 +3016,19 @@ def forward( primals_4: "Sym(s16)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=1), idx=1) primals_5: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=1), idx=0) ): +======= + normalize_gm(fw[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_1: "Sym(s16)", primals_2: "f32[3, s16]", primals_3: "f32[3, s16]", primals_4: "Sym(s16)", primals_5: "Sym(s16)"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) clone: "f32[3, s16]" = torch.ops.aten.clone.default(primals_2); primals_2 = None clone_1: "f32[3, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None view: "f32[3*s16]" = torch.ops.aten.view.default(clone, [-1]) sym_size_int_2: "Sym(3*s16)" = torch.ops.aten.sym_size.int(view, 0) view_1: "f32[3*s16]" = torch.ops.aten.view.default(clone_1, [-1]) +<<<<<<< HEAD return ( clone, # PlainAOTOutput(idx=0) view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='a') @@ -2840,10 +3037,14 @@ def forward( clone_1, # PlainAOTOutput(idx=2) primals_5, # SavedForBackwardsAOTOutput(idx=0) ) +======= + return (clone, view, view_1, sym_size_int_2, clone_1, primals_5) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """, # noqa: B950 ) self.assertExpectedInline( +<<<<<<< HEAD normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)), """\ class GraphModule(torch.nn.Module): @@ -2862,6 +3063,15 @@ def forward( primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), idx=1) primals_5, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), idx=0) ) +======= + normalize_gm(bw[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_5: "Sym(s16)", tangents_1: "f32[3*s16]", tangents_2: "f32[3*s16]"): + view_2: "f32[3, s16]" = torch.ops.aten.view.default(tangents_1, [3, primals_5]); tangents_1 = None + view_3: "f32[3, s16]" = torch.ops.aten.view.default(tangents_2, [3, primals_5]); tangents_2 = None + return (None, view_2, view_3, primals_5, primals_5) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """, # noqa: B950 ) @@ -2877,6 +3087,7 @@ def f(tt): fw, bw = self._compile_check(f, [(tt,)], dynamic=True, call_backward=True) self.assertExpectedInline( +<<<<<<< HEAD normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)), """\ class GraphModule(torch.nn.Module): @@ -2885,19 +3096,30 @@ def forward( primals_1: "f32[24]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=0), attr='a') primals_2: "f32[24]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=0), attr='b') ): +======= + normalize_gm(fw[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_1: "f32[24]", primals_2: "f32[24]"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) clone: "f32[24]" = torch.ops.aten.clone.default(primals_1); primals_1 = None clone_1: "f32[24]" = torch.ops.aten.clone.default(primals_2); primals_2 = None view: "f32[3, 2, 4]" = torch.ops.aten.view.default(clone, [3, 2, 4]); clone = None view_1: "f32[3, 2, 4]" = torch.ops.aten.view.default(clone_1, [3, 2, 4]); clone_1 = None +<<<<<<< HEAD return ( view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='a') view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='b') ) +======= + return (view, view_1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """, # noqa: B950 ) self.assertExpectedInline( +<<<<<<< HEAD normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)), """\ class GraphModule(torch.nn.Module): @@ -2912,6 +3134,15 @@ def forward( view_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=0)), attr='a') view_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=0)), attr='b') ) +======= + normalize_gm(bw[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, tangents_1: "f32[3, 2, 4]", tangents_2: "f32[3, 2, 4]"): + view_2: "f32[24]" = torch.ops.aten.view.default(tangents_1, [24]); tangents_1 = None + view_3: "f32[24]" = torch.ops.aten.view.default(tangents_2, [24]); tangents_2 = None + return (view_2, view_3) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """, # noqa: B950 ) @@ -3038,6 +3269,7 @@ def f(nt): fw, bw = self._compile_check(f, [(nt,)], dynamic=True, call_backward=True) self.assertExpectedInline( +<<<<<<< HEAD normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)), """\ class GraphModule(torch.nn.Module): @@ -3069,10 +3301,21 @@ def forward( primals_8, # SavedForBackwardsAOTOutput(idx=1) primals_10, # SavedForBackwardsAOTOutput(idx=2) ) +======= + normalize_gm(fw[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_1: "Sym(s51)", primals_2: "Sym(s71)", primals_3: "Sym(s55)", primals_4: "f64[s64, s55]", primals_5: "i64[s51 + 1]", primals_6: "f32[s0, 0]", primals_7: "f32[s83, 0]", primals_8: "Sym(s51)", primals_9: "Sym(s55)", primals_10: "Sym(s55)"): + clone: "f64[s64, s55]" = torch.ops.aten.clone.default(primals_4); primals_4 = None + + mul: "f64[s64, s55]" = torch.ops.aten.mul.Tensor(clone, primals_1); clone = None + return (mul, primals_5, primals_6, primals_7, primals_8, primals_10, primals_10, primals_1, primals_8, primals_10) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """, # noqa: B950 ) self.assertExpectedInline( +<<<<<<< HEAD normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)), """\ class GraphModule(torch.nn.Module): @@ -3099,6 +3342,14 @@ def forward( primals_10, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=2) primals_10, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=1) ) +======= + normalize_gm(bw[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_1: "Sym(s51)", primals_8: "Sym(s51)", primals_10: "Sym(s55)", tangents_1: "f64[s64, s55]", tangents_2: "i64[s51 + 1]", tangents_3: "f32[s0, 0]", tangents_4: "f32[s83, 0]"): + mul_1: "f64[s64, s55]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1); tangents_1 = primals_1 = None + return (None, None, None, mul_1, tangents_2, tangents_3, tangents_4, primals_8, primals_10, primals_10) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """, # noqa: B950 ) @@ -3114,6 +3365,7 @@ def f(nt): fw, bw = self._compile_check(f, [(nt,)], dynamic=True, call_backward=True) self.assertExpectedInline( +<<<<<<< HEAD normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)), """\ class GraphModule(torch.nn.Module): @@ -3130,10 +3382,17 @@ def forward( primals_9: "Sym(s55)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=2) primals_10: "Sym(s55)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=3), idx=1) ): +======= + normalize_gm(fw[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_1: "Sym(s51)", primals_2: "Sym(s71)", primals_3: "Sym(s55)", primals_4: "f64[s64, s55]", primals_5: "i64[s51 + 1]", primals_6: "f32[s0, 0]", primals_7: "f32[s83, 0]", primals_8: "Sym(s51)", primals_9: "Sym(s55)", primals_10: "Sym(s55)"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) clone: "f64[s64, s55]" = torch.ops.aten.clone.default(primals_4); primals_4 = None cat: "f64[s64, 2*s55]" = torch.ops.aten.cat.default([clone, clone], 1); clone = None add_2: "Sym(2*s55)" = primals_10 + primals_10 +<<<<<<< HEAD return ( cat, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_values') primals_5, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_offsets') @@ -3146,10 +3405,14 @@ def forward( primals_10, # SavedForBackwardsAOTOutput(idx=1) add_2, # SavedForBackwardsAOTOutput(idx=2) ) +======= + return (cat, primals_5, primals_6, primals_7, primals_8, add_2, add_2, primals_8, primals_10, add_2) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """, # noqa: B950 ) self.assertExpectedInline( +<<<<<<< HEAD normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)), """\ class GraphModule(torch.nn.Module): @@ -3163,10 +3426,17 @@ def forward( tangents_3: "f32[s0, 0]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_min_seqlen_tensor') tangents_4: "f32[s83, 0]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_max_seqlen_tensor') ): +======= + normalize_gm(bw[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_8: "Sym(s51)", primals_10: "Sym(s55)", add_2: "Sym(2*s55)", tangents_1: "f64[s64, 2*s55]", tangents_2: "i64[s51 + 1]", tangents_3: "f32[s0, 0]", tangents_4: "f32[s83, 0]"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) slice_1: "f64[s64, s55]" = torch.ops.aten.slice.Tensor(tangents_1, 1, 0, primals_10) slice_2: "f64[s64, s55]" = torch.ops.aten.slice.Tensor(tangents_1, 1, primals_10, add_2); tangents_1 = add_2 = None add_4: "f64[s64, s55]" = torch.ops.aten.add.Tensor(slice_1, slice_2); slice_1 = slice_2 = None +<<<<<<< HEAD return ( None, # None None, # None @@ -3179,6 +3449,9 @@ def forward( primals_10, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=2) primals_10, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=1) ) +======= + return (None, None, None, add_4, tangents_2, tangents_3, tangents_4, primals_8, primals_10, primals_10) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """, # noqa: B950 ) @@ -3203,6 +3476,7 @@ def f(nt): ) self.assertExpectedInline( +<<<<<<< HEAD normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)), """\ class (torch.nn.Module): @@ -3219,6 +3493,12 @@ def forward( arg8_1: "Sym(s55)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=2) arg9_1: "Sym(s55)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=3), idx=1) ): +======= + normalize_gm(fw[0].print_readable(print_output=False)), + """\ +class (torch.nn.Module): + def forward(self, arg0_1: "Sym(s51)", arg1_1: "Sym(s71)", arg2_1: "Sym(s55)", arg3_1: "f64[9, s55]", arg4_1: "i64[s51 + 1]", arg5_1: "f32[s0, 0]", arg6_1: "f32[s83, 0]", arg7_1: "Sym(s51)", arg8_1: "Sym(s55)", arg9_1: "Sym(s55)"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) randn: "f64[2, 5]" = torch.ops.aten.randn.default([2, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False) randn_1: "f64[3, 5]" = torch.ops.aten.randn.default([3, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False) randn_2: "f64[4, 5]" = torch.ops.aten.randn.default([4, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False) @@ -3239,6 +3519,7 @@ def forward( sym_size_int: "Sym(s55 + 5)" = torch.ops.aten.sym_size.int(cat_2, 1); cat_2 = None sym_stride_int: "Sym(s55 + 5)" = torch.ops.aten.sym_stride.int(mul, 0) +<<<<<<< HEAD return ( mul, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_values') cat_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_offsets') @@ -3247,6 +3528,9 @@ def forward( sym_size_int, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=2) sym_stride_int, # SubclassStrideAOTOutput(base=PlainAOTOutput(idx=0), idx=1) ) +======= + return (mul, cat_1, zeros_1, zeros_2, sym_size_int, sym_stride_int) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """, # noqa: B950 ) @@ -3457,7 +3741,11 @@ def forward(self, s71: "Sym(s71)", L_nt_: "f64[3, s71, 5]"): # triggers the eager logic to run, updating the counter and registry. # # Notably however, compile differs in two ways from eager: +<<<<<<< HEAD # (1) The order in which the offsets are assigned ids is different +======= + # (1) The order in which the offsets are assigned ids is differnet +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # the registry would be set in the order the offsets are returned # which is not necessarily the same order as they were constructed. # (2) If a NestedTensor is not returned, then the AOTAutograd wrapping @@ -3796,7 +4084,11 @@ def fn1(nt1, nt2): def test_basic_autograd(self): self._test_autograd("aot_eager") +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_basic_autograd_inductor(self): self._test_autograd("inductor") diff --git a/test/dynamo/test_subgraphs.py b/test/dynamo/test_subgraphs.py index 35036fd1de3fa..b91ba49ac16d4 100644 --- a/test/dynamo/test_subgraphs.py +++ b/test/dynamo/test_subgraphs.py @@ -401,7 +401,11 @@ def fn(a, b): y = torch.randn(3) self.assertEqual(opt_fn(x, y), fn(x, y)) self.assertEqual(opt_fn(x, x), fn(x, x)) +<<<<<<< HEAD # NB: This COULD validly be 2, but we don't test disjointedness in the +======= + # NB: This COULD validly be 2, but we don't test disjointness in the +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # guards for when x and y didn't duck size together, so we end up # with a generic graph that also works when x and y happen to duck # size together. diff --git a/test/dynamo/test_trace_rules.py b/test/dynamo/test_trace_rules.py index 9bfccd94b1f7e..aa5df42d59026 100644 --- a/test/dynamo/test_trace_rules.py +++ b/test/dynamo/test_trace_rules.py @@ -126,7 +126,11 @@ def gen_allowed_objs_and_ids(record=False, c_binding_only=True) -> AllowedObject torch_name_rule_map = {} # In some platforms, these functions were loaded as classes instead of functions. +<<<<<<< HEAD # To mitigate these weird cases, we need this special check. +======= + # To mitigate these weired cases, we need this special check. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def is_special_functions(obj): return hashable(obj) and obj in { torch._C._cuda_isCurrentStreamCapturing, diff --git a/test/dynamo/test_unspec.py b/test/dynamo/test_unspec.py index 91862e6d3eb00..12698f076c33c 100644 --- a/test/dynamo/test_unspec.py +++ b/test/dynamo/test_unspec.py @@ -132,9 +132,12 @@ def fn(shape): res1 = fn(shape) cnts = torch._dynamo.testing.CompileCounter() opt_fn = torch.compile(fn, backend=cnts) +<<<<<<< HEAD # Especially for internal: before resetting the seed, first shake out any rng # calls that occur on compile, e.g., as a result of some module initializations. opt_fn(shape) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) random.seed(1) res2 = opt_fn(shape) @@ -154,9 +157,12 @@ def fn(x): res1 = fn(x) cnts = torch._dynamo.testing.CompileCounter() opt_fn = torch.compile(fn, backend=cnts) +<<<<<<< HEAD # Especially for internal: before resetting the seed, first shake out any rng # calls that occur on compile, e.g., as a result of some module initializations. opt_fn(x) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) random.seed(1) res2 = opt_fn(x) self.assertTrue(same(res1, res2)) @@ -182,9 +188,12 @@ def fn(x): res1 = fn(x) cnts = torch._dynamo.testing.CompileCounter() opt_fn = torch.compile(fn, backend=cnts) +<<<<<<< HEAD # Especially for internal: before resetting the seed, first shake out any rng # calls that occur on compile, e.g., as a result of some module initializations. opt_fn(x) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) random.seed(1) res2 = opt_fn(x) self.assertTrue(same(res1, res2)) @@ -215,9 +224,12 @@ def fn(x): random.seed(1) res1 = fn(x) opt_fn = torch.compile(fn, backend="eager") +<<<<<<< HEAD # Especially for internal: before resetting the seed, first shake out any rng # calls that occur on compile, e.g., as a result of some module initializations. opt_fn(x) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) random.seed(1) res2 = opt_fn(x) self.assertTrue(same(res1, res2)) @@ -244,9 +256,12 @@ def fn(x, rand2): random.seed(0) y_1, rand2_1, rand3_1 = fn(inp, random.Random(12)) state_1 = random.getstate() +<<<<<<< HEAD # Especially for internal: before resetting the seed, first shake out any rng # calls that occur on compile, e.g., as a result of some module initializations. opt_fn(inp, random.Random(12)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) random.seed(0) y_2, rand2_2, rand3_2 = opt_fn(inp, random.Random(12)) state_2 = random.getstate() @@ -714,6 +729,7 @@ def fn(x, y): self.assertEqual(fn_opt(x, y3), fn(x, y3)) self.assertEqual(cnt.frame_count, 1) +<<<<<<< HEAD @torch._dynamo.config.patch(capture_scalar_outputs=True) def test_tensorfiy_python_scalars_1(self): @torch.compile(backend="aot_eager") @@ -748,6 +764,8 @@ def f(x): x = torch.tensor([finfo_float16.max], dtype=torch.float16) self.assertEqual(f(x), x.item() * 101 * torch.tensor([1], dtype=torch.float32)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._dynamo.config.patch(specialize_float=False, assume_static_by_default=False) def test_unspec_float_input_f64(self): cnts = torch._dynamo.testing.CompileCounter() diff --git a/test/dynamo/test_utils.py b/test/dynamo/test_utils.py index b7166c5ce6d1b..9ba7c50a7eeeb 100644 --- a/test/dynamo/test_utils.py +++ b/test/dynamo/test_utils.py @@ -12,9 +12,12 @@ from torch._inductor.test_case import TestCase +<<<<<<< HEAD _IS_WINDOWS = sys.platform == "win32" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestUtils(TestCase): def test_nan(self): a = torch.Tensor([float("nan")]) @@ -246,6 +249,7 @@ def add(x, y): utils.reset_frame_count() torch._logging._internal.structured_logging_overhead.clear() +<<<<<<< HEAD @dynamo_config.patch({"log_compilation_metrics": True}) @inductor_config.patch({"force_disable_caches": True}) def test_stack_trace(self): @@ -350,6 +354,8 @@ def backward(grad_output): "'Dynamo does not know how to trace builtin operator `print`'", ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dynamo_config.patch( { "log_compilation_metrics": True, @@ -407,6 +413,7 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): 'additional_fake_tensor_prop': [0.0, 0.0], 'aot_collect_metadata': [0.0], 'aot_trace_joint_graph': [0.0], +<<<<<<< HEAD 'backward._backward_impl': [0.0], 'build_guards': [0.0], 'bytecode_tracing': [0.0], @@ -438,6 +445,8 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): 'additional_fake_tensor_prop': [0.0, 0.0], 'aot_collect_metadata': [0.0], 'aot_trace_joint_graph': [0.0], +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) 'async_compile.wait': [0.0, 0.0], 'backward._backward_impl': [0.0], 'build_guards': [0.0], @@ -462,6 +471,7 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): {'_recursive_joint_graph_passes': 0.0, '_recursive_post_grad_passes': 0.0, '_recursive_pre_grad_passes': 0.0, +<<<<<<< HEAD 'backend_compile': 0.0, 'code_gen': 0.0, 'entire_backward_compile': 0.0, @@ -474,6 +484,8 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): {'_recursive_joint_graph_passes': 0.0, '_recursive_post_grad_passes': 0.0, '_recursive_pre_grad_passes': 0.0, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) 'async_compile.wait': 0.0, 'backend_compile': 0.0, 'code_gen': 0.0, @@ -500,9 +512,12 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): e.cuda_version = None e.triton_version = None e.python_version = None +<<<<<<< HEAD e.stack_trace = None e.graph_node_shapes = None e.exception_stack_trace = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # First event is for the forward. Formatting makes reading diffs # much easier. @@ -541,7 +556,10 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): 'dynamo_time_before_restart_s': 0.0, 'end_time_us': 100, 'entire_frame_compile_time_s': 0.0, +<<<<<<< HEAD 'exception_stack_trace': None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) 'fail_reason': None, 'fail_type': None, 'fail_user_frame_filename': None, @@ -550,7 +568,10 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): 'gc_time_us': 0, 'graph_input_count': 1, 'graph_node_count': 3, +<<<<<<< HEAD 'graph_node_shapes': None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) 'graph_op_count': 1, 'guard_count': 9, 'has_guarded_code': True, @@ -563,7 +584,10 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): 'inductor_fx_remote_cache_hit_keys': None, 'inductor_fx_remote_cache_miss_count': None, 'inductor_fx_remote_cache_miss_keys': None, +<<<<<<< HEAD 'inline_inbuilt_nn_modules_candidate': False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) 'is_forward': True, 'is_runtime': False, 'joint_graph_pass_time_us': 0, @@ -577,7 +601,10 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): 'pre_grad_pass_time_us': 0, 'python_version': None, 'recompile_reason': None, +<<<<<<< HEAD 'recompile_user_contexts': None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) 'remote_cache_time_saved_s': None, 'remote_cache_version': None, 'remote_fx_graph_cache_get_time_ms': None, @@ -589,6 +616,7 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): 'runtime_triton_autotune_time_us': None, 'shape_env_guard_count': 0, 'specialize_float': False, +<<<<<<< HEAD 'stack_trace': None, 'start_time': 0.0001, 'start_time_us': 100, @@ -675,6 +703,8 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): 'shape_env_guard_count': 0, 'specialize_float': False, 'stack_trace': None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) 'start_time': 0.0001, 'start_time_us': 100, 'structured_logging_overhead_s': 0.0, @@ -710,8 +740,13 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): 'compile_id': '1/0', 'compile_time_autotune_time_us': None, 'compliant_custom_ops': None, +<<<<<<< HEAD 'config_inline_inbuilt_nn_modules': False, 'config_suppress_errors': False, +======= + 'config_inline_inbuilt_nn_modules': None, + 'config_suppress_errors': None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) 'cuda_version': None, 'cudagraph_skip_reason': None, 'distributed_ephemeral_timeout_us': None, @@ -722,7 +757,10 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): 'dynamo_time_before_restart_s': None, 'end_time_us': 100, 'entire_frame_compile_time_s': None, +<<<<<<< HEAD 'exception_stack_trace': None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) 'fail_reason': None, 'fail_type': None, 'fail_user_frame_filename': None, @@ -731,7 +769,10 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): 'gc_time_us': None, 'graph_input_count': None, 'graph_node_count': None, +<<<<<<< HEAD 'graph_node_shapes': None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) 'graph_op_count': None, 'guard_count': None, 'has_guarded_code': None, @@ -744,7 +785,10 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): 'inductor_fx_remote_cache_hit_keys': None, 'inductor_fx_remote_cache_miss_count': None, 'inductor_fx_remote_cache_miss_keys': None, +<<<<<<< HEAD 'inline_inbuilt_nn_modules_candidate': False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) 'is_forward': False, 'is_runtime': False, 'joint_graph_pass_time_us': None, @@ -758,7 +802,10 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): 'pre_grad_pass_time_us': None, 'python_version': None, 'recompile_reason': None, +<<<<<<< HEAD 'recompile_user_contexts': None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) 'remote_cache_time_saved_s': None, 'remote_cache_version': None, 'remote_fx_graph_cache_get_time_ms': None, @@ -770,6 +817,7 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): 'runtime_triton_autotune_time_us': None, 'shape_env_guard_count': None, 'specialize_float': None, +<<<<<<< HEAD 'stack_trace': None, 'start_time': 0.0001, 'start_time_us': 100, @@ -856,6 +904,8 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): 'shape_env_guard_count': None, 'specialize_float': None, 'stack_trace': None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) 'start_time': 0.0001, 'start_time_us': 100, 'structured_logging_overhead_s': 0.0, diff --git a/test/dynamo/test_view.py b/test/dynamo/test_view.py index 03b9ac5a9f81a..ae74914796736 100644 --- a/test/dynamo/test_view.py +++ b/test/dynamo/test_view.py @@ -33,6 +33,7 @@ def f(t, _n): t = torch.tensor([2, 4], dtype=torch.int32) f(t, 8) +<<<<<<< HEAD def test_view_with_tensor_shape_params(self): # Test for issue #156720: aten.view.default with tensor shape parameters class TestModel(torch.nn.Module): @@ -113,6 +114,8 @@ def test_fn(x, shape_params): torch.testing.assert_close(result, expected) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.test_abs b/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.test_abs new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.test_abs_overflows b/test/dynamo_expected_failures/CPython313-test_cmath-CMathTests.test_abs_overflows new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_abs b/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_abs new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_add b/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_add new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_boolcontext b/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_boolcontext new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_conjugate b/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_conjugate new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_constructor_from_string b/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_constructor_from_string new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_constructor_negative_nans_from_string b/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_constructor_negative_nans_from_string new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_getnewargs b/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_getnewargs new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_mul b/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_mul new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_negative_zero_repr_str b/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_negative_zero_repr_str new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_overflow b/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_overflow new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_pow b/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_pow new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_pow_with_small_integer_exponents b/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_pow_with_small_integer_exponents new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_repr_str b/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_repr_str new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_sub b/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_sub new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_truediv_zero_division b/test/dynamo_expected_failures/CPython313-test_complex-ComplexTest.test_truediv_zero_division new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_bad_key b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_bad_key new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_clear b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_clear new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_constructor b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_constructor new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_contains b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_contains new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_copy b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_copy new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_copy_maintains_tracking b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_copy_maintains_tracking new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_copy_noncompact b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_copy_noncompact new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_dict_contain_use_after_free b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_dict_contain_use_after_free new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_dict_copy_order b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_dict_copy_order new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_dictitems_contains_use_after_free b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_dictitems_contains_use_after_free new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_dictview_set_operations_on_keys b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_dictview_set_operations_on_keys new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_empty_presized_dict_in_freelist b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_empty_presized_dict_in_freelist new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_eq b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_eq new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_equal_operator_modifying_operand b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_equal_operator_modifying_operand new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_errors_in_view_containment_check b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_errors_in_view_containment_check new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_fromkeys_operator_modifying_dict_operand b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_fromkeys_operator_modifying_dict_operand new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_fromkeys_operator_modifying_set_operand b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_fromkeys_operator_modifying_set_operand new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_get b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_get new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_getitem b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_getitem new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_init_use_after_free b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_init_use_after_free new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_instance_dict_getattr_str_subclass b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_instance_dict_getattr_str_subclass new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_invalid_keyword_arguments b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_invalid_keyword_arguments new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_merge_and_mutate b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_merge_and_mutate new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_missing b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_missing new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_mutating_lookup b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_mutating_lookup new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_object_set_item_single_instance_non_str_key b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_object_set_item_single_instance_non_str_key new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_oob_indexing_dictiter_iternextitem b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_oob_indexing_dictiter_iternextitem new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_pop b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_pop new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_popitem b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_popitem new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_reentrant_insertion b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_reentrant_insertion new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_resize2 b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_resize2 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_reverse_iterator_for_empty_dict b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_reverse_iterator_for_empty_dict new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_reverse_iterator_for_shared_shared_dicts b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_reverse_iterator_for_shared_shared_dicts new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_setdefault b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_setdefault new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_setdefault_atomic b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_setdefault_atomic new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_setitem_atomic_at_resize b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_setitem_atomic_at_resize new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_splittable_del b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_splittable_del new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_splittable_pop b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_splittable_pop new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_splittable_pop_pending b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_splittable_pop_pending new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_splittable_popitem b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_splittable_popitem new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_splittable_setdefault b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_splittable_setdefault new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_splittable_to_generic_combinedtable b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_splittable_to_generic_combinedtable new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_splittable_update b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_splittable_update new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_store_evilattr b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_store_evilattr new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_str_nonstr b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_str_nonstr new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_views_mapping b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_views_mapping new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_constructor b/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_constructor new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_get b/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_get new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_getitem b/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_getitem new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_items b/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_items new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_keys b/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_keys new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_popitem b/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_popitem new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_read b/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_read new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_setdefault b/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_setdefault new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_values b/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_values new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_write b/test/dynamo_expected_failures/CPython313-test_dict-GeneralMappingTests.test_write new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_constructor b/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_constructor new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_get b/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_get new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_getitem b/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_getitem new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_items b/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_items new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_keys b/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_keys new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_popitem b/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_popitem new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_setdefault b/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_setdefault new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_values b/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_values new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_write b/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_write new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_float_ceil b/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_float_ceil new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_float_floor b/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_float_floor new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_float_mod b/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_float_mod new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_float_pow b/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_float_pow new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_floatconversion b/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_floatconversion new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_keywords_in_subclass b/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_keywords_in_subclass new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_non_numeric_input_types b/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_non_numeric_input_types new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_float-HexFloatTestCase.test_from_hex b/test/dynamo_expected_failures/CPython313-test_float-HexFloatTestCase.test_from_hex new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_float-HexFloatTestCase.test_invalid_inputs b/test/dynamo_expected_failures/CPython313-test_float-HexFloatTestCase.test_invalid_inputs new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_float-HexFloatTestCase.test_subclass b/test/dynamo_expected_failures/CPython313-test_float-HexFloatTestCase.test_subclass new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_float-HexFloatTestCase.test_whitespace b/test/dynamo_expected_failures/CPython313-test_float-HexFloatTestCase.test_whitespace new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_float-IEEEFormatTestCase.test_double_specials_do_unpack b/test/dynamo_expected_failures/CPython313-test_float-IEEEFormatTestCase.test_double_specials_do_unpack new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_float-IEEEFormatTestCase.test_float_specials_do_unpack b/test/dynamo_expected_failures/CPython313-test_float-IEEEFormatTestCase.test_float_specials_do_unpack new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_float-IEEEFormatTestCase.test_serialized_float_rounding b/test/dynamo_expected_failures/CPython313-test_float-IEEEFormatTestCase.test_serialized_float_rounding new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_float-InfNanTest.test_inf_from_str b/test/dynamo_expected_failures/CPython313-test_float-InfNanTest.test_inf_from_str new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_float-InfNanTest.test_nan_from_str b/test/dynamo_expected_failures/CPython313-test_float-InfNanTest.test_nan_from_str new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_float-RoundTestCase.test_inf_nan b/test/dynamo_expected_failures/CPython313-test_float-RoundTestCase.test_inf_nan new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_float-RoundTestCase.test_overflow b/test/dynamo_expected_failures/CPython313-test_float-RoundTestCase.test_overflow new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntStrDigitLimitsTests.test_int_from_other_bases b/test/dynamo_expected_failures/CPython313-test_int-IntStrDigitLimitsTests.test_int_from_other_bases new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntStrDigitLimitsTests.test_max_str_digits b/test/dynamo_expected_failures/CPython313-test_int-IntStrDigitLimitsTests.test_max_str_digits new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntStrDigitLimitsTests.test_max_str_digits_edge_cases b/test/dynamo_expected_failures/CPython313-test_int-IntStrDigitLimitsTests.test_max_str_digits_edge_cases new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntStrDigitLimitsTests.test_power_of_two_bases_unlimited b/test/dynamo_expected_failures/CPython313-test_int-IntStrDigitLimitsTests.test_power_of_two_bases_unlimited new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntStrDigitLimitsTests.test_sign_not_counted b/test/dynamo_expected_failures/CPython313-test_int-IntStrDigitLimitsTests.test_sign_not_counted new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntStrDigitLimitsTests.test_underscores_ignored b/test/dynamo_expected_failures/CPython313-test_int-IntStrDigitLimitsTests.test_underscores_ignored new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_basic b/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_basic new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_int_base_indexable b/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_int_base_indexable new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_int_base_limits b/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_int_base_limits new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_int_returns_int_subclass b/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_int_returns_int_subclass new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_int_subclass_with_index b/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_int_subclass_with_index new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_int_subclass_with_int b/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_int_subclass_with_int new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_intconversion b/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_intconversion new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_invalid_signs b/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_invalid_signs new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_non_numeric_input_types b/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_non_numeric_input_types new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_string_float b/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_string_float new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_int-PyLongModuleTests.test_pylong_str_to_int b/test/dynamo_expected_failures/CPython313-test_int-PyLongModuleTests.test_pylong_str_to_int new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_3720 b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_3720 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_exception_function b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_exception_function new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_exception_sequence b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_exception_sequence new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_basic b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_basic new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_big_range b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_big_range new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_callable b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_callable new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_class_for b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_class_for new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_class_iter b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_class_iter new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_dict b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_dict new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_empty b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_empty new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_for_loop b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_for_loop new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_function b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_function new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_function_stop b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_function_stop new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_independence b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_independence new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_range b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_range new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_string b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_string new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_tuple b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_tuple new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_mutating_seq_class_exhausted_iter b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_mutating_seq_class_exhausted_iter new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_nested_comprehensions_for b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_nested_comprehensions_for new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_nested_comprehensions_iter b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_nested_comprehensions_iter new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_new_style_iter_class b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_new_style_iter_class new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_ref_counting_behavior b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_ref_counting_behavior new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_seq_class_for b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_seq_class_for new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_seq_class_iter b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_seq_class_iter new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_callable b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_callable new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_dict b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_dict new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_enumerate b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_enumerate new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_list b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_list new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_range b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_range new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_sequence b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_sequence new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_string b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_string new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_tuple b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_tuple new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_yield b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_yield new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_stop_sequence b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_stop_sequence new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_unicode_join_endcase b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_unicode_join_endcase new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_addmul b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_addmul new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_append b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_append new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_basic b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_basic new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_clear b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_clear new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_constructors b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_constructors new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_contains b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_contains new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_contains_order b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_contains_order new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_copy b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_copy new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_count b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_count new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_count_index_remove_crashes b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_count_index_remove_crashes new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_delitem b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_delitem new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_delslice b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_delslice new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_empty_slice b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_empty_slice new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_equal_operator_modifying_operand b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_equal_operator_modifying_operand new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_extend b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_extend new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_extendedslicing b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_extendedslicing new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_getitem b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_getitem new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_getitemoverwriteiter b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_getitemoverwriteiter new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_getslice b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_getslice new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_iadd b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_iadd new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_insert b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_insert new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_keywords_in_subclass b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_keywords_in_subclass new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_list_index_modifing_operand b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_list_index_modifing_operand new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_list_resize_overflow b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_list_resize_overflow new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_lt_operator_modifying_operand b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_lt_operator_modifying_operand new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_no_comdat_folding b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_no_comdat_folding new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_overflow b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_overflow new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_pop b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_pop new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_repr_mutate b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_repr_mutate new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_reverse b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_reverse new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_set_subscript b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_set_subscript new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_setitem b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_setitem new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_setslice b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_setslice new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_slice b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_slice new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_slice_assign_iterator b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_slice_assign_iterator new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_step_overflow b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_step_overflow new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_subscript b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_subscript new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_tier2_invalidates_iterator b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_tier2_invalidates_iterator new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-IsCloseTests.test_negative_tolerances b/test/dynamo_expected_failures/CPython313-test_math-IsCloseTests.test_negative_tolerances new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAcos b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAcos new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAcosh b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAcosh new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAsin b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAsin new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAsinh b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAsinh new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAtan b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAtan new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAtan2 b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAtan2 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAtanh b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAtanh new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCbrt b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCbrt new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCeil b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCeil new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCopysign b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCopysign new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCos b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCos new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCosh b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCosh new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testDegrees b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testDegrees new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testExp b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testExp new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testExp2 b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testExp2 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFabs b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFabs new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFactorialHugeInputs b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFactorialHugeInputs new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFloor b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFloor new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFmod b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFmod new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFrexp b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFrexp new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLdexp b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLdexp new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLog b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLog new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLog10 b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLog10 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLog1p b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLog1p new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testModf b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testModf new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testPow b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testPow new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testRadians b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testRadians new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testSin b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testSin new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testSinh b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testSinh new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testSqrt b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testSqrt new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testTan b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testTan new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testTanh b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testTanh new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_exceptions b/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_exceptions new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_input_exceptions b/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_input_exceptions new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_issue39871 b/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_issue39871 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_math_dist_leak b/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_math_dist_leak new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_nextafter b/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_nextafter new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_sumprod_stress b/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_sumprod_stress new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_trunc b/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_trunc new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_ulp b/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_ulp new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonBuiltinDictTests.test_delitem_hash_collision b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonBuiltinDictTests.test_delitem_hash_collision new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonBuiltinDictTests.test_highly_nested_subclass b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonBuiltinDictTests.test_highly_nested_subclass new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonBuiltinDictTests.test_override_update b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonBuiltinDictTests.test_override_update new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonBuiltinDictTests.test_reinsert b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonBuiltinDictTests.test_reinsert new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonBuiltinDictTests.test_setitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonBuiltinDictTests.test_setitem new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_get b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_get new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_getitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_getitem new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_items b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_items new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_keys b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_keys new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_popitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_popitem new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_read b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_read new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_setdefault b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_setdefault new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_values b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_values new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_write b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonGeneralMappingTests.test_write new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_delitem_hash_collision b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_delitem_hash_collision new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_highly_nested_subclass b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_highly_nested_subclass new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_init_calls b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_init_calls new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_issue119004_change_linked_list_by_clear b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_issue119004_change_linked_list_by_clear new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_issue119004_change_linked_list_by_delete_key b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_issue119004_change_linked_list_by_delete_key new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_issue119004_change_size_by_clear b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_issue119004_change_size_by_clear new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_issue119004_change_size_by_delete_key b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_issue119004_change_size_by_delete_key new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_issue119004_change_size_by_delete_key_in_dict_eq b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_issue119004_change_size_by_delete_key_in_dict_eq new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_issue24347 b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_issue24347 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_issue24348 b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_issue24348 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_issue24667 b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_issue24667 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_iterators_empty b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_iterators_empty new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_overridden_init b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_overridden_init new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_override_update b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_override_update new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_popitem_last b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_popitem_last new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_reinsert b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_reinsert new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_setitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_setitem new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_delitem_hash_collision b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_delitem_hash_collision new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_fromkeys b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_fromkeys new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_highly_nested_subclass b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_highly_nested_subclass new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_init_calls b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_init_calls new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_issue119004_change_linked_list_by_clear b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_issue119004_change_linked_list_by_clear new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_issue119004_change_linked_list_by_delete_key b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_issue119004_change_linked_list_by_delete_key new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_issue119004_change_size_by_clear b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_issue119004_change_size_by_clear new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_issue119004_change_size_by_delete_key b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_issue119004_change_size_by_delete_key new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_issue119004_change_size_by_delete_key_in_dict_eq b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_issue119004_change_size_by_delete_key_in_dict_eq new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_issue24347 b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_issue24347 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_issue24348 b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_issue24348 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_issue24667 b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_issue24667 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_iterators_empty b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_iterators_empty new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_merge_operator b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_merge_operator new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_move_to_end b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_move_to_end new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_move_to_end_issue25406 b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_move_to_end_issue25406 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_overridden_init b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_overridden_init new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_override_update b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_override_update new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_popitem_last b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_popitem_last new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_reinsert b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_reinsert new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_setitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_setitem new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_constructor b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_constructor new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_get b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_get new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_getitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_getitem new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_items b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_items new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_keys b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_keys new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_popitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_popitem new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_setdefault b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_setdefault new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_values b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_values new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_write b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_write new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CSimpleLRUCacheTests.test_change_order_on_get b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CSimpleLRUCacheTests.test_change_order_on_get new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CSimpleLRUCacheTests.test_pop b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CSimpleLRUCacheTests.test_pop new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonGeneralMappingTests.test_getitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonGeneralMappingTests.test_getitem new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_delitem_hash_collision b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_delitem_hash_collision new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_highly_nested_subclass b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_highly_nested_subclass new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_init_calls b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_init_calls new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_issue119004_attribute_error b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_issue119004_attribute_error new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_issue24347 b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_issue24347 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_issue24348 b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_issue24348 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_overridden_init b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_overridden_init new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_override_update b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictSubclassTests.test_override_update new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_delitem_hash_collision b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_delitem_hash_collision new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_highly_nested_subclass b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_highly_nested_subclass new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_init_calls b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_init_calls new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_issue119004_attribute_error b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_issue119004_attribute_error new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_issue24347 b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_issue24347 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_issue24348 b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_issue24348 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_overridden_init b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_overridden_init new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_override_update b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonOrderedDictTests.test_override_update new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonSubclassMappingTests.test_getitem b/test/dynamo_expected_failures/CPython313-test_ordered_dict-PurePythonSubclassMappingTests.test_getitem new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_copy b/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_copy new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_difference b/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_difference new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_difference_rev b/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_difference_rev new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_intersection b/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_intersection new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_isdisjoint b/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_isdisjoint new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_symmetric_difference b/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_symmetric_difference new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_union b/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_empty_union new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_equivalent_equality b/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_equivalent_equality new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_intersection_empty b/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_intersection_empty new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_length b/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_length new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_difference b/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_difference new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_equality b/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_equality new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_intersection b/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_intersection new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_isdisjoint b/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_isdisjoint new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_symmetric_difference b/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_symmetric_difference new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_union b/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_self_union new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_union_empty b/test/dynamo_expected_failures/CPython313-test_set-TestBasicOpsTriple.test_union_empty new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Set.test_and_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Set.test_and_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Set.test_eq_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Set.test_eq_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Set.test_ge_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Set.test_ge_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Set.test_gt_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Set.test_gt_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Set.test_iadd_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Set.test_iadd_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Set.test_ior_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Set.test_ior_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Set.test_isub_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Set.test_isub_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Set.test_iteration_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Set.test_iteration_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Set.test_ixor_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Set.test_ixor_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Set.test_le_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Set.test_le_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Set.test_lt_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Set.test_lt_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Set.test_ne_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Set.test_ne_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Set.test_or_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Set.test_or_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Set.test_sub_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Set.test_sub_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Set.test_xor_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Set.test_xor_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Subclass.test_and_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Subclass.test_and_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Subclass.test_eq_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Subclass.test_eq_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Subclass.test_ge_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Subclass.test_ge_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Subclass.test_gt_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Subclass.test_gt_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Subclass.test_iadd_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Subclass.test_iadd_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Subclass.test_ior_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Subclass.test_ior_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Subclass.test_isub_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Subclass.test_isub_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Subclass.test_iteration_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Subclass.test_iteration_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Subclass.test_ixor_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Subclass.test_ixor_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Subclass.test_le_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Subclass.test_le_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Subclass.test_lt_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Subclass.test_lt_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Subclass.test_ne_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Subclass.test_ne_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Subclass.test_or_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Subclass.test_or_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Subclass.test_sub_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Subclass.test_sub_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Subclass.test_xor_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Set_Subclass.test_xor_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Set.test_and_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Set.test_and_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Set.test_eq_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Set.test_eq_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Set.test_ge_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Set.test_ge_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Set.test_gt_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Set.test_gt_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Set.test_iadd_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Set.test_iadd_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Set.test_ior_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Set.test_ior_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Set.test_isub_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Set.test_isub_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Set.test_iteration_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Set.test_iteration_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Set.test_ixor_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Set.test_ixor_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Set.test_le_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Set.test_le_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Set.test_lt_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Set.test_lt_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Set.test_ne_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Set.test_ne_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Set.test_or_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Set.test_or_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Set.test_sub_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Set.test_sub_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Set.test_xor_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Set.test_xor_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Subclass.test_and_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Subclass.test_and_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Subclass.test_eq_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Subclass.test_eq_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Subclass.test_ge_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Subclass.test_ge_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Subclass.test_gt_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Subclass.test_gt_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Subclass.test_iadd_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Subclass.test_iadd_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Subclass.test_ior_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Subclass.test_ior_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Subclass.test_isub_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Subclass.test_isub_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Subclass.test_iteration_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Subclass.test_iteration_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Subclass.test_ixor_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Subclass.test_ixor_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Subclass.test_le_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Subclass.test_le_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Subclass.test_lt_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Subclass.test_lt_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Subclass.test_ne_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Subclass.test_ne_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Subclass.test_or_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Subclass.test_or_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Subclass.test_sub_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Subclass.test_sub_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Subclass.test_xor_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestBinaryOpsMutating_Subclass_Subclass.test_xor_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_and b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_and new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_container_iterator b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_container_iterator new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_contains b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_contains new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_deepcopy b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_deepcopy new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_difference b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_difference new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_equality b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_equality new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_gc b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_gc new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_or b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_or new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_setOfFrozensets b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_setOfFrozensets new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_sub b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_sub new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_subclass_with_custom_hash b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_subclass_with_custom_hash new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_symmetric_difference b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_symmetric_difference new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_union b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_union new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_uniquification b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_uniquification new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_xor b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSet.test_xor new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_and b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_and new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_constructor_identity b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_constructor_identity new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_container_iterator b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_container_iterator new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_contains b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_contains new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_deepcopy b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_deepcopy new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_difference b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_difference new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_equality b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_equality new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_gc b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_gc new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_init b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_init new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_isdisjoint b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_isdisjoint new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_keywords_in_subclass b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_keywords_in_subclass new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_len b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_len new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_nested_empty_constructor b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_nested_empty_constructor new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_or b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_or new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_setOfFrozensets b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_setOfFrozensets new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_sub b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_sub new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_sub_and_super b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_sub_and_super new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_subclass_with_custom_hash b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_subclass_with_custom_hash new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_symmetric_difference b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_symmetric_difference new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_union b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_union new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_uniquification b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_uniquification new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_xor b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_xor new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestGraphs.test_cube b/test/dynamo_expected_failures/CPython313-test_set-TestGraphs.test_cube new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Dict.test_difference_update_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Dict.test_difference_update_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Dict.test_difference_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Dict.test_difference_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Dict.test_intersection_update_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Dict.test_intersection_update_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Dict.test_intersection_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Dict.test_intersection_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Dict.test_isdisjoint_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Dict.test_isdisjoint_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Dict.test_issubset_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Dict.test_issubset_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Dict.test_issuperset_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Dict.test_issuperset_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Dict.test_symmetric_difference_update_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Dict.test_symmetric_difference_update_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Dict.test_symmetric_difference_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Dict.test_symmetric_difference_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Dict.test_union_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Dict.test_union_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Dict.test_update_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Dict.test_update_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_List.test_difference_update_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_List.test_difference_update_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_List.test_difference_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_List.test_difference_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_List.test_intersection_update_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_List.test_intersection_update_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_List.test_intersection_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_List.test_intersection_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_List.test_isdisjoint_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_List.test_isdisjoint_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_List.test_issubset_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_List.test_issubset_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_List.test_issuperset_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_List.test_issuperset_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_List.test_symmetric_difference_update_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_List.test_symmetric_difference_update_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_List.test_symmetric_difference_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_List.test_symmetric_difference_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_List.test_union_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_List.test_union_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_List.test_update_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_List.test_update_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Set.test_difference_update_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Set.test_difference_update_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Set.test_difference_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Set.test_difference_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Set.test_intersection_update_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Set.test_intersection_update_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Set.test_intersection_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Set.test_intersection_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Set.test_isdisjoint_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Set.test_isdisjoint_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Set.test_issubset_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Set.test_issubset_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Set.test_issuperset_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Set.test_issuperset_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Set.test_symmetric_difference_update_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Set.test_symmetric_difference_update_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Set.test_symmetric_difference_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Set.test_symmetric_difference_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Set.test_union_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Set.test_union_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Set.test_update_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Set.test_update_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Subclass.test_difference_update_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Subclass.test_difference_update_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Subclass.test_difference_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Subclass.test_difference_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Subclass.test_intersection_update_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Subclass.test_intersection_update_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Subclass.test_intersection_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Subclass.test_intersection_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Subclass.test_isdisjoint_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Subclass.test_isdisjoint_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Subclass.test_issubset_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Subclass.test_issubset_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Subclass.test_issuperset_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Subclass.test_issuperset_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Subclass.test_symmetric_difference_update_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Subclass.test_symmetric_difference_update_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Subclass.test_symmetric_difference_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Subclass.test_symmetric_difference_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Subclass.test_union_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Subclass.test_union_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Subclass.test_update_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Set_Subclass.test_update_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Set.test_difference_update_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Set.test_difference_update_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Set.test_difference_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Set.test_difference_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Set.test_intersection_update_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Set.test_intersection_update_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Set.test_intersection_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Set.test_intersection_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Set.test_isdisjoint_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Set.test_isdisjoint_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Set.test_issubset_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Set.test_issubset_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Set.test_issuperset_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Set.test_issuperset_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Set.test_symmetric_difference_update_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Set.test_symmetric_difference_update_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Set.test_symmetric_difference_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Set.test_symmetric_difference_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Set.test_union_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Set.test_union_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Set.test_update_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Set.test_update_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Subclass.test_difference_update_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Subclass.test_difference_update_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Subclass.test_difference_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Subclass.test_difference_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Subclass.test_intersection_update_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Subclass.test_intersection_update_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Subclass.test_intersection_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Subclass.test_intersection_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Subclass.test_isdisjoint_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Subclass.test_isdisjoint_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Subclass.test_issubset_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Subclass.test_issubset_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Subclass.test_issuperset_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Subclass.test_issuperset_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Subclass.test_symmetric_difference_update_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Subclass.test_symmetric_difference_update_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Subclass.test_symmetric_difference_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Subclass.test_symmetric_difference_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Subclass.test_union_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Subclass.test_union_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Subclass.test_update_with_mutation b/test/dynamo_expected_failures/CPython313-test_set-TestMethodsMutating_Subclass_Subclass.test_update_with_mutation new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsDict.test_union b/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsDict.test_union new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsDict.test_update_operator b/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsDict.test_update_operator new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsGenerator.test_difference_update_operator b/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsGenerator.test_difference_update_operator new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsGenerator.test_intersection_update_operator b/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsGenerator.test_intersection_update_operator new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsGenerator.test_sym_difference_update_operator b/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsGenerator.test_sym_difference_update_operator new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsGenerator.test_update_operator b/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsGenerator.test_update_operator new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsNumeric.test_union b/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsNumeric.test_union new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsNumeric.test_update b/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsNumeric.test_update new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsNumeric.test_update_operator b/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsNumeric.test_update_operator new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsOperator.test_intersection b/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsOperator.test_intersection new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsOperator.test_intersection_update b/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsOperator.test_intersection_update new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsOperator.test_sym_difference b/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsOperator.test_sym_difference new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsOperator.test_sym_difference_update b/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsOperator.test_sym_difference_update new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsOperator.test_union b/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsOperator.test_union new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsOperator.test_update b/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsOperator.test_update new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsOperator.test_update_operator b/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsOperator.test_update_operator new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsString.test_difference b/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsString.test_difference new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsString.test_difference_update b/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsString.test_difference_update new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsString.test_intersection b/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsString.test_intersection new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsString.test_intersection_update b/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsString.test_intersection_update new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsString.test_sym_difference b/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsString.test_sym_difference new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsString.test_sym_difference_update b/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsString.test_sym_difference_update new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsString.test_union b/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsString.test_union new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsString.test_update_operator b/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsString.test_update_operator new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsTuple.test_union b/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsTuple.test_union new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsTuple.test_update_operator b/test/dynamo_expected_failures/CPython313-test_set-TestOnlySetsTuple.test_update_operator new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_container_iterator b/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_container_iterator new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_contains b/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_contains new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_deepcopy b/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_deepcopy new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_difference b/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_difference new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_difference_update b/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_difference_update new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_gc b/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_gc new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_intersection_update b/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_intersection_update new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_or b/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_or new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_remove_keyerror_unpacking b/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_remove_keyerror_unpacking new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_rich_compare b/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_rich_compare new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_setOfFrozensets b/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_setOfFrozensets new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_set_literal_evaluation_order b/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_set_literal_evaluation_order new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_subclass_with_custom_hash b/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_subclass_with_custom_hash new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_symmetric_difference b/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_symmetric_difference new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_symmetric_difference_update b/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_symmetric_difference_update new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_union b/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_union new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_uniquification b/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_uniquification new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_update b/test/dynamo_expected_failures/CPython313-test_set-TestSet.test_update new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetOfSets.test_constructor b/test/dynamo_expected_failures/CPython313-test_set-TestSetOfSets.test_constructor new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_add b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_add new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_and b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_and new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_clear b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_clear new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_constructor_identity b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_constructor_identity new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_container_iterator b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_container_iterator new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_contains b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_contains new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_deepcopy b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_deepcopy new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_difference b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_difference new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_difference_update b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_difference_update new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_equality b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_equality new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_gc b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_gc new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_iand b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_iand new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_init b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_init new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_inplace_on_self b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_inplace_on_self new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_intersection_update b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_intersection_update new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_ior b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_ior new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_isdisjoint b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_isdisjoint new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_isub b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_isub new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_ixor b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_ixor new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_keywords_in_subclass b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_keywords_in_subclass new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_len b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_len new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_or b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_or new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_pop b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_pop new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_remove_keyerror_set b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_remove_keyerror_set new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_remove_keyerror_unpacking b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_remove_keyerror_unpacking new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_rich_compare b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_rich_compare new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_setOfFrozensets b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_setOfFrozensets new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_set_literal_evaluation_order b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_set_literal_evaluation_order new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_sub b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_sub new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_sub_and_super b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_sub_and_super new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_subclass_with_custom_hash b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_subclass_with_custom_hash new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_symmetric_difference b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_symmetric_difference new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_symmetric_difference_update b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_symmetric_difference_update new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_union b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_union new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_uniquification b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_uniquification new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_update b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_update new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_xor b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_xor new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestVariousIteratorArgs.test_constructor b/test/dynamo_expected_failures/CPython313-test_set-TestVariousIteratorArgs.test_constructor new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestWeirdBugs.test_merge_and_mutate b/test/dynamo_expected_failures/CPython313-test_set-TestWeirdBugs.test_merge_and_mutate new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_sort-TestBase.testStressfully b/test/dynamo_expected_failures/CPython313-test_sort-TestBase.testStressfully new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_sort-TestBase.test_small_stability b/test/dynamo_expected_failures/CPython313-test_sort-TestBase.test_small_stability new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_sort-TestBugs.test_bug453523 b/test/dynamo_expected_failures/CPython313-test_sort-TestBugs.test_bug453523 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_sort-TestDecorateSortUndecorate.test_key_with_exception b/test/dynamo_expected_failures/CPython313-test_sort-TestDecorateSortUndecorate.test_key_with_exception new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_sort-TestDecorateSortUndecorate.test_key_with_mutating_del b/test/dynamo_expected_failures/CPython313-test_sort-TestDecorateSortUndecorate.test_key_with_mutating_del new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_sort-TestDecorateSortUndecorate.test_key_with_mutating_del_and_exception b/test/dynamo_expected_failures/CPython313-test_sort-TestDecorateSortUndecorate.test_key_with_mutating_del_and_exception new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_sort-TestOptimizedCompares.test_none_in_tuples b/test/dynamo_expected_failures/CPython313-test_sort-TestOptimizedCompares.test_none_in_tuples new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_sort-TestOptimizedCompares.test_unsafe_object_compare b/test/dynamo_expected_failures/CPython313-test_sort-TestOptimizedCompares.test_unsafe_object_compare new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_addmul b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_addmul new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_constructors b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_constructors new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_contains b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_contains new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_contains_order b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_contains_order new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_count b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_count new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_getitem b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_getitem new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_getitemoverwriteiter b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_getitemoverwriteiter new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_getslice b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_getslice new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_iadd b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_iadd new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_imul b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_imul new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_keywords_in_subclass b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_keywords_in_subclass new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_no_comdat_folding b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_no_comdat_folding new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_subscript b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_subscript new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_track_subtypes b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_track_subtypes new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_tupleresizebug b/test/dynamo_expected_failures/CPython313-test_tuple-TupleTest.test_tupleresizebug new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_constructor b/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_constructor new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_eq b/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_eq new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_get b/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_get new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_keys b/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_keys new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_len b/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_len new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_mutatingiteration b/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_mutatingiteration new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_read b/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_read new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_add_specials b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_add_specials new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_addmul b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_addmul new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_contains_order b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_contains_order new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_delslice b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_delslice new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_exhausted_iterator b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_exhausted_iterator new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_extendedslicing b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_extendedslicing new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_getitemoverwriteiter b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_getitemoverwriteiter new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_getslice b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_getslice new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_imul b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_imul new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_len b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_len new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_minmax b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_minmax new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_mixedadd b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_mixedadd new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_mixedcmp b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_mixedcmp new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_radd_specials b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_radd_specials new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_slice b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_slice new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_slice_assign_iterator b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_slice_assign_iterator new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_slice_type b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_slice_type new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_truth b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_truth new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestBaseSparsifier.test_state_dict b/test/dynamo_expected_failures/TestBaseSparsifier.test_state_dict new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestLinalgCPU.test_slogdet_errors_and_warnings_cpu_complex128 b/test/dynamo_expected_failures/TestLinalgCPU.test_slogdet_errors_and_warnings_cpu_complex128 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestLinalgCPU.test_slogdet_errors_and_warnings_cpu_complex64 b/test/dynamo_expected_failures/TestLinalgCPU.test_slogdet_errors_and_warnings_cpu_complex64 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestLinalgCPU.test_slogdet_errors_and_warnings_cpu_float32 b/test/dynamo_expected_failures/TestLinalgCPU.test_slogdet_errors_and_warnings_cpu_float32 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestLinalgCPU.test_slogdet_errors_and_warnings_cpu_float64 b/test/dynamo_expected_failures/TestLinalgCPU.test_slogdet_errors_and_warnings_cpu_float64 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_expected_failures/TestVmapAPI.test_fallback_warns_when_warnings_are_enabled b/test/dynamo_expected_failures/TestVmapAPI.test_fallback_warns_when_warnings_are_enabled new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/dynamo_skips/TestOpenMP_ParallelFor.test_one_thread b/test/dynamo_skips/TestOpenMP_ParallelFor.test_one_thread new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index c650b102bf1a7..ce698be668f91 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -75,7 +75,10 @@ aten::_ctc_loss.out aten::_ctc_loss_backward aten::_ctc_loss_backward.Tensor aten::_ctc_loss_backward.out +<<<<<<< HEAD aten::_cudnn_attention_backward +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) aten::_cudnn_attention_forward aten::_cudnn_ctc_loss aten::_cudnn_ctc_loss.Tensor @@ -375,6 +378,10 @@ aten::_fused_adamw_.tensor_lr aten::_fused_moving_avg_obs_fq_helper aten::_fused_moving_avg_obs_fq_helper.out aten::_fused_moving_avg_obs_fq_helper_functional +<<<<<<< HEAD +======= +aten::_fused_rms_norm +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) aten::_fused_sdp_choice aten::_fused_sgd aten::_fused_sgd.out @@ -853,8 +860,11 @@ aten::hann_window.periodic aten::hann_window.periodic_out aten::hardshrink_backward aten::hardshrink_backward.grad_input +<<<<<<< HEAD aten::hash_tensor aten::hash_tensor.out +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) aten::histc aten::histc.out aten::histogram.bin_ct diff --git a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect index 8dbe257ec3ae6..7dcb2a5fc4377 100644 --- a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect +++ b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect @@ -23,12 +23,20 @@ torch.fx.graph.Graph.node_copy(self, node: torch.fx.node.Node, arg_transform: Ca torch.fx.graph.Graph.output(self, result: 'Argument', type_expr: Optional[Any] = None) torch.fx.graph.Graph.placeholder(self, name: str, type_expr: Optional[Any] = None, default_value: Any) -> torch.fx.node.Node torch.fx.graph.Graph.print_tabular(self) +<<<<<<< HEAD torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False) -> torch.fx.graph.PythonCode +======= +torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False) -> torch.fx.graph.PythonCode +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.fx.graph_module.GraphModule.__init__(self, root: Union[torch.nn.modules.module.Module, Dict[str, Any]], graph: torch.fx.graph.Graph, class_name: str = 'GraphModule') torch.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.modules.module.Module) -> bool torch.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None torch.fx.graph_module.GraphModule.delete_submodule(self, target: str) -> bool torch.fx.graph_module.GraphModule.recompile(self) -> torch.fx.graph.PythonCode +<<<<<<< HEAD +======= +torch.fx.graph_module.reduce_deploy_graph_module(importer: Callable, body: Dict[Any, Any], import_block: str) -> torch.nn.modules.module.Module +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.fx.graph_module.reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.modules.module.Module torch.fx.graph_module.reduce_package_graph_module(importer: Callable, body: Dict[Any, Any], generated_module_name: str) -> torch.nn.modules.module.Module torch.fx.interpreter.Interpreter.__init__(self, module: torch.nn.modules.module.Module, garbage_collect_values: bool = True, graph: Optional[torch.fx.graph.Graph] = None) @@ -52,7 +60,11 @@ torch.fx.interpreter.Transformer.placeholder(self, target: 'Target', args: Tuple torch.fx.interpreter.Transformer.transform(self) -> torch.fx.graph_module.GraphModule torch.fx.node.Node.__init__(self, graph: 'Graph', name: str, op: str, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Argument], return_type: Optional[Any] = None) -> None torch.fx.node.Node.append(self, x: 'Node') -> None +<<<<<<< HEAD torch.fx.node.Node.format_node(self, placeholder_names: Optional[List[str]] = None, maybe_return_typename: Optional[List[str]] = None, include_tensor_metadata: bool = False) -> Optional[str] +======= +torch.fx.node.Node.format_node(self, placeholder_names: Optional[List[str]] = None, maybe_return_typename: Optional[List[str]] = None) -> Optional[str] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.fx.node.Node.insert_arg(self, idx: int, arg: torch.fx.node.Argument) -> None torch.fx.node.Node.prepend(self, x: 'Node') -> None torch.fx.node.Node.replace_all_uses_with(self, replace_with: 'Node', delete_user_cb: Callable[[Node], bool] = >, propagate_meta: bool = False) -> List[Node] diff --git a/test/export/test_converter.py b/test/export/test_converter.py index e739e5c346677..6490939cdbc7e 100644 --- a/test/export/test_converter.py +++ b/test/export/test_converter.py @@ -700,7 +700,11 @@ def forward(self, x: torch.Tensor): else: return self.w + self.m2(x) +<<<<<<< HEAD # Super nested, parameters need to be lifted +======= + # Super nested, parameters neeed to lifted +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # multiple times. class SuperNestedM(torch.nn.Module): def __init__(self) -> None: @@ -755,7 +759,11 @@ def forward(self, x: torch.Tensor): else: return self.linear(self.m2(x)) +<<<<<<< HEAD # Super nested, parameters need to be lifted +======= + # Super nested, parameters neeed to lifted +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # multiple times. class SuperNestedM1(torch.nn.Module): def __init__(self, dim: int) -> None: @@ -771,7 +779,11 @@ def forward(self, x: torch.Tensor): return self.linear(self.m2(x)) # Super nested, even the input needs to be +<<<<<<< HEAD # lifted recursively due to value propagation optimization. +======= + # lifted recursively due to value propogation optimiztaion. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class SuperNestedM2(torch.nn.Module): def __init__(self, dim: int) -> None: super().__init__() @@ -911,7 +923,11 @@ def foo_impl(x): return x + x # Meta function of the custom op. +<<<<<<< HEAD @torch.library.register_fake( +======= + @torch.library.impl_abstract( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "mylib::foo", lib=lib, ) @@ -1448,11 +1464,15 @@ def fuse_model(self): ep_out, _ = pytree.tree_flatten(ep.module()(*inp)) self._check_tensor_list_equal(orig_out, ep_out) +<<<<<<< HEAD # qnnpack/xnnpack not supported on s390x. # it is required by # torch.ops.prepacked.linear_clamp_prepack # and # torch.ops.prepacked.linear_clamp_run +======= + # qnnpack not supported on s390x +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @xfailIfS390X def test_ts2ep_convert_quantized_model_with_opcontext(self): class M(torch.nn.Module): @@ -1471,12 +1491,15 @@ def forward(self, x): inp = (torch.randn(1, 10),) self._check_equal_ts_ep_converter(m, inp, ["script"]) +<<<<<<< HEAD # qnnpack/xnnpack not supported on s390x. # it is required by # torch.ops.prepacked.linear_clamp_prepack # and # torch.ops.prepacked.linear_clamp_run @xfailIfS390X +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_ts2ep_convert_quantized_model_with_opcontext_and_constant(self): class M(torch.nn.Module): def __init__(self, linear_op): diff --git a/test/export/test_draft_export.py b/test/export/test_draft_export.py index 7f7148273ad70..29c16c33c8b35 100644 --- a/test/export/test_draft_export.py +++ b/test/export/test_draft_export.py @@ -1,6 +1,9 @@ # Owner(s): ["oncall: export"] import copy +<<<<<<< HEAD import re +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import tempfile import unittest @@ -408,12 +411,16 @@ def forward(self, a): inp = (torch.ones(3, 3),) +<<<<<<< HEAD ep = draft_export( M(), inp, dynamic_shapes={"a": {0: Dim("a0")}}, prefer_deferred_runtime_asserts_over_guards=True, ) +======= + ep = draft_export(M(), inp, dynamic_shapes={"a": {0: Dim("a0")}}) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) report = ep._report self.assertEqual(len(report.failures), 1) @@ -423,11 +430,15 @@ def forward(self, a): self.assertEqual(ep.module()(*inp), M()(*inp)) inp = (torch.randn(4, 3),) +<<<<<<< HEAD with self.assertRaisesRegex( AssertionError, re.escape("Guard failed: a.size()[0] <= 3"), ): # expected <= 3, but got 4 +======= + with self.assertRaises(RuntimeError): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ep.module()(*inp) def test_side_effect1(self): diff --git a/test/export/test_experimental.py b/test/export/test_experimental.py index 871dc813a687f..d62c87233425e 100644 --- a/test/export/test_experimental.py +++ b/test/export/test_experimental.py @@ -1,6 +1,9 @@ # Owner(s): ["oncall: export"] # flake8: noqa +<<<<<<< HEAD import copy +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import types import unittest from typing import Dict, List, Tuple @@ -10,6 +13,10 @@ from torch._dynamo.test_case import run_tests, TestCase from torch._functorch.aot_autograd import aot_export_module from torch.export import export, export_for_training +<<<<<<< HEAD +======= +from torch.export._trace import _convert_ts_to_export_experimental +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.export.experimental import _export_forward_backward, _sticky_export from torch.export.graph_signature import OutputKind from torch.testing import FileCheck @@ -17,6 +24,96 @@ @unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported") class TestExperiment(TestCase): +<<<<<<< HEAD +======= + def test_torchscript_module_export(self): + class M(torch.nn.Module): + def forward(self, x): + return x.cos() + x.sin() + + model_to_trace = M() + inps = (torch.randn(4, 4),) + traced_module_by_torchscript = torch.jit.trace(M(), example_inputs=inps) + + exported_module = _convert_ts_to_export_experimental( + traced_module_by_torchscript, inps + ) + + self.assertTrue(torch.allclose(exported_module(*inps), model_to_trace(*inps))) + + def test_torchscript_module_export_single_input(self): + class M(torch.nn.Module): + def forward(self, x): + return x.cos() + x.sin() + + model_to_trace = M() + inps = torch.randn(4, 4) + traced_module_by_torchscript = torch.jit.trace(M(), example_inputs=inps) + + exported_module = _convert_ts_to_export_experimental( + traced_module_by_torchscript, inps + ) + + self.assertTrue(torch.allclose(exported_module(inps), model_to_trace(inps))) + + def test_torchscript_module_export_various_inputs_with_annotated_input_names(self): + def _check_equality_and_annotations(m_func, inps): + # Original module. + model_to_trace = m_func() + + # ExportedProgram from TorchScript module. + traced_module_by_torchscript = torch.jit.trace( + m_func(), example_inputs=inps + ) + exported_module = _convert_ts_to_export_experimental( + traced_module_by_torchscript, inps + ) + + # ExportedProgram from original module. + original_exported_module = torch.export.export_for_training( + m_func(), inps, strict=True + ) + + # Check whether input annotations are the same as tracing the original module. + orig_ph_name_list = [ + n.name + for n in original_exported_module.graph.nodes + if n.op == "placeholder" + ] + ph_name_list = [ + n.name for n in exported_module.graph.nodes if n.op == "placeholder" + ] + self.assertEqual(orig_ph_name_list, ph_name_list) + + # Check results equality. + self.assertTrue( + torch.allclose(exported_module(*inps), model_to_trace(*inps)) + ) + + # Tuple + class MTuple(torch.nn.Module): + def forward(self, x: Tuple[torch.Tensor]): + return x[0] + x[1] + + _check_equality_and_annotations(MTuple, ((torch.randn(4), torch.randn(4)),)) + + # List + class MList(torch.nn.Module): + def forward(self, x: List[torch.Tensor]): + return x[0] + x[1] + + _check_equality_and_annotations(MList, ([torch.randn(4), torch.randn(4)],)) + + # Dict + class MDict(torch.nn.Module): + def forward(self, x: Dict[str, torch.Tensor]): + return x["0"] + x["1"] + + _check_equality_and_annotations( + MDict, ({"0": torch.randn(4), "1": torch.randn(4)},) + ) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_joint_basic(self) -> None: class Module(torch.nn.Module): def __init__(self) -> None: @@ -319,8 +416,15 @@ def forward(self, x): x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) linear_weight = self.linear.weight linear_bias = self.linear.bias +<<<<<<< HEAD _guards_fn = self._guards_fn(x); _guards_fn = None linear = torch.ops.aten.linear.default(x, linear_weight, linear_bias); x = linear_weight = linear_bias = None +======= + sym_size_int_2 = torch.ops.aten.sym_size.int(x, 1) + linear = torch.ops.aten.linear.default(x, linear_weight, linear_bias); x = linear_weight = linear_bias = None + eq = sym_size_int_2 == 4; sym_size_int_2 = None + _assert_scalar_default = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(s27, 4) on node 'eq'"); eq = _assert_scalar_default = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return pytree.tree_unflatten((linear,), self._out_spec)""", ) @@ -353,6 +457,7 @@ def generate(self, *, input_tensor, input_tensor2): res2 = p.generate(input_tensor=inp, input_tensor2=inp2) self.assertTrue(torch.allclose(res, res2)) +<<<<<<< HEAD def test_export_add_in_out_info(self): class Foo(torch.nn.Module): def forward(self, dct, lst, bleh): @@ -409,6 +514,8 @@ def forward(self, x): self.assertEqual(res_export, res_eager) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/export/test_export.py b/test/export/test_export.py index 664436a23ee4a..33ee8cfff7c42 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -7,6 +7,7 @@ import logging import math import operator +<<<<<<< HEAD import os import re import traceback @@ -14,6 +15,12 @@ import warnings import weakref from contextlib import contextmanager, nullcontext +======= +import re +import unittest +import warnings +from contextlib import contextmanager +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from dataclasses import dataclass from re import escape from typing import Dict, List, Union @@ -25,8 +32,12 @@ import torch.utils._pytree as pytree from functorch.experimental.control_flow import cond, map from torch import Tensor +<<<<<<< HEAD from torch._decomp import decomposition_table, get_decompositions from torch._dynamo._trace_wrapped_higher_order_op import mod_index +======= +from torch._decomp import decomposition_table +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._dynamo.test_case import TestCase from torch._dynamo.testing import normalize_gm from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse @@ -40,7 +51,10 @@ from torch._higher_order_ops.associative_scan import associative_scan from torch._higher_order_ops.hints_wrap import hints_wrapper from torch._higher_order_ops.scan import scan +<<<<<<< HEAD from torch._higher_order_ops.while_loop import while_loop +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.compile_fx import split_const_gm from torch._subclasses import FakeTensorMode from torch.export import ( @@ -62,7 +76,10 @@ OutputSpec, TensorArgument, ) +<<<<<<< HEAD from torch.export.passes import move_to_device_pass +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.experimental.symbolic_shapes import ShapeEnv from torch.testing import FileCheck @@ -89,7 +106,11 @@ ) from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU from torch.testing._internal.torchbind_impls import load_torchbind_test_lib +<<<<<<< HEAD from torch.testing._internal.triton_utils import requires_cuda_and_triton, requires_gpu +======= +from torch.testing._internal.triton_utils import requires_cuda, requires_gpu +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.two_tensor import TwoTensor from torch.utils._pytree import ( LeafSpec, @@ -110,7 +131,11 @@ from torch._library import capture_triton try: +<<<<<<< HEAD from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor +======= + from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) HAS_TORCHREC = True except ImportError: @@ -150,7 +175,11 @@ @torch.library.impl("testlib::returns_tensor_symint", "cpu") +<<<<<<< HEAD @torch.library.register_fake("testlib::returns_tensor_symint") +======= +@torch.library.impl_abstract("testlib::returns_tensor_symint") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def returns_tensor_symint_impl(x): return x, x.shape[0] @@ -163,7 +192,11 @@ def foo_impl(x, z): return x, z, x + z +<<<<<<< HEAD @torch.library.register_fake("testlib::foo") +======= +@torch.library.impl_abstract("testlib::foo") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def foo_abstract(x, z): return x, z, x + z @@ -255,10 +288,13 @@ def is_training_ir_test(test_name): ) +<<<<<<< HEAD def is_training_ir_strict_test(test_name): return test_name.endswith(TRAINING_IR_DECOMP_STRICT_SUFFIX) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def is_cpp_runtime_test(test_name): return test_name.endswith(CPP_RUNTIME_STRICT_SUFFIX) or test_name.endswith( CPP_RUNTIME_NONSTRICT_SUFFIX @@ -329,6 +365,7 @@ def forward(self, *args): dynamic_shapes=dynamic_shapes, ) +<<<<<<< HEAD def test_no_grad_param_inplace(self): class Foo(torch.nn.Module): def __init__(self): @@ -375,6 +412,8 @@ def forward(self, x): self.assertTrue(torch.allclose(res, res_export)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_export_slice_unbacked_dim1(self): class MySlice(torch.nn.Module): def forward(self, x, seq_len): @@ -518,6 +557,7 @@ def _test_export_same_as_eager(self, f, args, kwargs=None): # ) def _check_dynamic_shapes_specs_and_shapes( +<<<<<<< HEAD self, model, inputs, @@ -525,6 +565,9 @@ def _check_dynamic_shapes_specs_and_shapes( passing_shapes, failing_shapes, test_serdes=False, +======= + self, model, inputs, specs, passing_shapes, failing_shapes, test_serdes=False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): from torch._export.serde.dynamic_shapes import ( _dump_dynamic_shapes, @@ -549,7 +592,11 @@ def _is_tensor_leaf(x): eps = [ep] if test_serdes: # test dynamic shapes serialization +<<<<<<< HEAD # test that behavior remains the same when exporting with Ser/Des specs: +======= + # test that behavior remains the same when exporting with ser/des specs: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # serialize + deserialize original specs, and export. ep_serdes = export( model, @@ -566,7 +613,11 @@ def _is_tensor_leaf(x): ep.module()(*test_inputs) for shapes in failing_shapes: test_inputs = _construct_inputs(shapes) +<<<<<<< HEAD with self.assertRaisesRegex(AssertionError, "Guard failed"): +======= + with self.assertRaises(RuntimeError): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ep.module()(*test_inputs) def test_basic(self): @@ -621,6 +672,7 @@ def forward(self, x): self.assertEqual(counter, 1) +<<<<<<< HEAD @testing.expectedFailureSerDer # can't serialize functorch ops @testing.expectedFailureSerDerNonStrict # can't serialize functorch ops def test_vmap_to_assert(self): @@ -637,6 +689,8 @@ def forward(self, x, y): eager = VmapToAssert()(torch.ones(4, 4, 4, 4), torch.ones(4, 4, 4, 4)) self.assertEqual(exported, eager) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_from_node_metadata_export(self): class Foo(torch.nn.Module): def __init__(self) -> None: @@ -655,6 +709,7 @@ def example_inputs(self): f = Foo() inputs = (torch.randn(1, 3, 5, 5),) +<<<<<<< HEAD ep = export(f, inputs) graph_id = id(ep.graph) gm = ep.module() @@ -662,6 +717,13 @@ def example_inputs(self): for node in gm.graph.nodes: if node.op in ("placeholder", "output", "call_module"): +======= + gm = export(f, inputs).module() + from torch.fx.traceback import NodeSourceAction + + for node in gm.graph.nodes: + if node.op in ("placeholder", "output"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue if "weight" in node.name or "bias" in node.name: self.assertTrue( @@ -672,9 +734,12 @@ def example_inputs(self): node.meta["from_node"][-1].action == [NodeSourceAction.CREATE, NodeSourceAction.REPLACE] ) +<<<<<<< HEAD self.assertEqual( node.meta["from_node"][-1].from_node[-1].graph_id, graph_id ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: self.assertTrue( node.meta["from_node"][-1].pass_name == "ExportedProgram.module()" @@ -682,6 +747,7 @@ def example_inputs(self): self.assertTrue( node.meta["from_node"][-1].action == [NodeSourceAction.CREATE] ) +<<<<<<< HEAD self.assertEqual(node.meta["from_node"][-1].graph_id, graph_id) ## re-export @@ -693,6 +759,15 @@ def example_inputs(self): if node.op in ("placeholder", "output", "call_module"): continue +======= + + ## re-export + gm2 = export(gm, inputs).module() + + for node in gm2.graph.nodes: + if node.op in ("placeholder", "output"): + continue +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if "weight" in node.name or "bias" in node.name: self.assertTrue( node.meta["from_node"][-1].pass_name @@ -702,9 +777,12 @@ def example_inputs(self): node.meta["from_node"][-1].action == [NodeSourceAction.CREATE, NodeSourceAction.REPLACE] ) +<<<<<<< HEAD self.assertEqual( node.meta["from_node"][-1].from_node[-1].graph_id, graph_id ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: self.assertTrue( node.meta["from_node"][-1].pass_name == "ExportedProgram.module()" @@ -712,7 +790,10 @@ def example_inputs(self): self.assertTrue( node.meta["from_node"][-1].action == [NodeSourceAction.CREATE] ) +<<<<<<< HEAD self.assertEqual(node.meta["from_node"][-1].graph_id, graph_id) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_bincount(self): class M(torch.nn.Module): @@ -953,8 +1034,12 @@ def forward(self, x): """\ graph(): %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0] +<<<<<<< HEAD %x : [num_users=2] = placeholder[target=x] %_guards_fn : [num_users=0] = call_module[target=_guards_fn](args = (%x,), kwargs = {}) +======= + %x : [num_users=1] = placeholder[target=x] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %lifted_tensor_0), kwargs = {}) return (add,)""", ) @@ -1006,6 +1091,10 @@ def forward(self, x): ep = export(f, args, strict=False) self.assertEqual(ep.module()(*args), f(*args)) +<<<<<<< HEAD +======= + @testing.expectedFailureCppSerDes # Cpp serder seems to fail parsing complicated guards +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_export_statically_known_true(self): class Foo(torch.nn.Module): def forward(self, x, y): @@ -1018,6 +1107,7 @@ def forward(self, x, y): (torch.export.Dim.DYNAMIC, torch.export.Dim.DYNAMIC), ) +<<<<<<< HEAD m = Foo() inp = (torch.randn(4, 4), torch.randn(4, 4)) ep = export( @@ -1030,6 +1120,15 @@ def forward(self, x, y): self.assertTrue(torch.allclose(ep.module()(*inp), m(*inp))) FileCheck().check_count("torch.ops.aten.slice.Tensor", 1, exactly=True).run( +======= + ep = export( + Foo(), + (torch.randn(4, 4), torch.randn(4, 4)), + dynamic_shapes=dynamic_shapes, + strict=False, + ) + FileCheck().check_count("torch.ops.aten.slice.Tensor", 2, exactly=True).run( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) str(ep.graph) ) FileCheck().check_count("operator.sub", 1, exactly=True).run(str(ep.graph)) @@ -1346,6 +1445,7 @@ def forward(self, x: torch.Tensor, as_tuple: bool) -> torch.Tensor: for vr_upper in vr_upper_bounds: self.assertEqual(vr_upper, 1) +<<<<<<< HEAD def test_detect_leak_strict(self): class Foo(torch.nn.Module): def __init__(self): @@ -1388,6 +1488,8 @@ def update(self): ): ref(torch.randn(4, 4), torch.randn(4, 4)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_mask_nonzero_static(self): class TestModule(torch.nn.Module): def forward(self, seq_embeddings, mask, exp): @@ -1518,11 +1620,15 @@ def forward(self, x, ys, zs, c): {"a": torch.zeros(5), "b": torch.ones(5)}, torch.ones(4), ) +<<<<<<< HEAD with self.assertRaisesRegex( AssertionError, escape("Guard failed: ys[0].size()[0] == x.size()[0]"), ): # expected 6, but got 5 +======= + with self.assertRaisesRegex(RuntimeError, "to be equal to 6, but got 5"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ep_ns.module()(*bad_runtime_inp1) bad_runtime_inp2 = ( @@ -1532,10 +1638,16 @@ def forward(self, x, ys, zs, c): torch.ones(6), ) with self.assertRaisesRegex( +<<<<<<< HEAD AssertionError, escape("Guard failed: c.size()[0] == 4"), ): # expected 4, but got 6 +======= + RuntimeError, + escape("Expected input at *args[3].shape[0] to be equal to 4, but got 6"), + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ep_ns.module()(*bad_runtime_inp2) good_runtime_inp = ( @@ -1663,9 +1775,12 @@ def false_fn(x): torch.export.export(M(), (torch.randn(7),), strict=strict) def test_cond_branches_return_constant_int(self): +<<<<<<< HEAD if "cpp_runtime_nonstrict" in self.id(): self.skipTest("TODO Unexpected success in OSS but not in fbcode.") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class M(torch.nn.Module): def forward(self, x): idx = torch.cond(x.sum() > 3, lambda: 0, lambda: 1, tuple()) @@ -1683,8 +1798,11 @@ def forward(self, x): x: "f32[3, 3]"; x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) +<<<<<<< HEAD _guards_fn = self._guards_fn(x); _guards_fn = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sum_1: "f32[]" = torch.ops.aten.sum.default(x) gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 3); sum_1 = None @@ -1713,6 +1831,7 @@ def forward(self): ) self.assertEqual(m(*args), ep.module()(*args)) +<<<<<<< HEAD @testing.expectedFailureCppRuntimeNonStrict def test_cond_access_identical_symint_closure(self): class Example2(torch.nn.Module): @@ -1747,6 +1866,8 @@ def forward(self, x, trigger, target): self.assertEqual(m(*args), ep.module()(*args)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_cond_branches_return_same_int(self): class M(torch.nn.Module): def forward(self, x): @@ -1770,8 +1891,11 @@ def forward(self, x): x: "f32[3, 3]"; x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) +<<<<<<< HEAD _guards_fn = self._guards_fn(x); _guards_fn = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sum_1: "f32[]" = torch.ops.aten.sum.default(x) gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 3); sum_1 = None @@ -1939,6 +2063,7 @@ def forward(self, x): ): export(M(), (torch.randn(2, 3),), strict=False) +<<<<<<< HEAD @torch._dynamo.config.patch(capture_scalar_outputs=True) def test_while_loop_tensor_constant_idx(self): def while_loop_decomp(x, y0): @@ -1969,6 +2094,8 @@ def forward(self, x, y0): out = ep.module()(x, y0) self.assertEqual(exp_out, out) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_malformed_fqn_from_source_name(self): # See https://github.com/pytorch/pytorch/issues/141939 from types import MethodType @@ -2027,11 +2154,16 @@ def annotate_split_points(mod: torch.nn.Module, spec): for problem in [Problem1, Problem2]: m = problem() m(torch.rand(64, 64)) +<<<<<<< HEAD # simplified torch.distributed.pipeline code +======= + # simpified torch.distributed.pipeline code +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) annotate_split_points(m, {"blocks.1": 1, "blocks.3": 1}) gm = export(m, (torch.rand(64, 64),)) torch.export.unflatten(gm) +<<<<<<< HEAD def test_unflatten_closure(self): class Dummy(torch.nn.Module): def forward(self, fn, x): @@ -2080,6 +2212,8 @@ def forward(self, add): return add_5""", ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_state_primitives(self): class M(torch.nn.Module): def __init__(self) -> None: @@ -2454,9 +2588,12 @@ def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ep = export(model, inputs) def test_subclasses_parameterization(self): +<<<<<<< HEAD if "cpp_runtime_nonstrict" in self.id(): self.skipTest("TODO Unexpected success in OSS but not in fbcode.") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class Foo(torch.nn.Module): def __init__(self): super().__init__() @@ -2509,7 +2646,10 @@ def forward(self, x): self.assertEqual(res, ref_out) +<<<<<<< HEAD @testing.expectedFailureCppRuntimeNonStrict +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_subclasses_parameterization_nested(self): class Foo(torch.nn.Module): def __init__(self): @@ -2582,6 +2722,7 @@ def forward(self, x): res = ep.module()(ref_x) self.assertEqual(res, ref_out) +<<<<<<< HEAD @testing.expectedFailureSerDer # can't serialize functorch ops @testing.expectedFailureSerDerNonStrict # can't serialize functorch ops @testing.expectedFailureCppRuntime @@ -2643,6 +2784,8 @@ def forward(self, x, y): @testing.expectedFailureLegacyExportNonStrict # Old export doesn't work with subclasses @testing.expectedFailureLegacyExportStrict # Old export doesn't work with subclasses +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_subclass_nested_attr_access(self): class Foo(torch.nn.Module): def __init__(self): @@ -3035,6 +3178,7 @@ def forward(self, x, y, z): ): export(Foo(), inputs, dynamic_shapes=shapes) +<<<<<<< HEAD def test_issue_157289(self): class MyModule(torch.nn.Module): def __init__(self): @@ -3081,6 +3225,8 @@ def forward(self, causal_mask, fill_value): return (slice_scatter,)""", ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_dim_dynamic_specialization(self): class Foo(torch.nn.Module): def forward(self, x): @@ -3140,6 +3286,7 @@ def forward(self, x, y): ep = export(Foo(), inputs, dynamic_shapes=shapes) ep.module()(torch.randn(8, 5), torch.randn(8, 5)) with self.assertRaisesRegex( +<<<<<<< HEAD AssertionError, escape("Guard failed: x.size()[0] >= 4"), ): @@ -3156,6 +3303,18 @@ def forward(self, x, y): escape("Guard failed: x.size()[1] <= 32"), ): # expected <= 32, but got 33 +======= + RuntimeError, "Expected input at .* to be >= 4, but got 3" + ): + ep.module()(torch.randn(3, 5), torch.randn(3, 5)) + with self.assertRaisesRegex( + RuntimeError, "Expected input at .* to be <= 16, but got 17" + ): + ep.module()(torch.randn(17, 5), torch.randn(17, 5)) + with self.assertRaisesRegex( + RuntimeError, "Expected input at .* to be <= 32, but got 33" + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ep.module()(torch.randn(9, 33), torch.randn(9, 33)) def test_dim_hint_range_violations(self): @@ -3410,12 +3569,20 @@ def forward(self, x): actual_torch_fns = [] for mod in gm.modules(): +<<<<<<< HEAD if hasattr(mod, "graph"): for node in mod.graph.nodes: if node.name in {"sin", "cos"}: torch_fn = node.meta.get("torch_fn") print(torch_fn) actual_torch_fns.append(torch_fn) +======= + for node in mod.graph.nodes: + if node.name in {"sin", "cos"}: + torch_fn = node.meta.get("torch_fn") + print(torch_fn) + actual_torch_fns.append(torch_fn) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) exp_torch_fns = [ ("cos_1", "method_descriptor.cos"), ("sin_1", "method_descriptor.sin"), @@ -3588,10 +3755,16 @@ def forward(self, x, y): dynamic_shapes=({0: dimx}, {0: dimy}), ) with self.assertRaisesRegex( +<<<<<<< HEAD AssertionError, escape("Guard failed: x.size()[0] == -1 + y.size()[0]"), ): # expected 5, but got 6 +======= + RuntimeError, + "Expected input.*shape.*to be equal to 5, but got 6", + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ep.module()(torch.randn(4), torch.randn(6)) self.assertEqual(ep.module()(torch.randn(4), torch.randn(5)).size()[0], 4) @@ -3650,6 +3823,7 @@ def forward(self, z, y): dynamic_shapes=({0: dimz}, {0: dimy}), ) with self.assertRaisesRegex( +<<<<<<< HEAD AssertionError, escape("Guard failed: z.size()[0] <= 7"), ): @@ -3660,6 +3834,15 @@ def forward(self, z, y): escape("Guard failed: -1 + 2 * z.size()[0] == y.size()[0]"), ): # expected 9, but got 8 +======= + RuntimeError, "Expected input.*shape.*to be <= 7, but got 8" + ): + ep.module()(torch.randn(8), torch.randn(15)) + with self.assertRaisesRegex( + RuntimeError, + "Expected input.*shape.*to be equal to 9, but got 8", + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ep.module()(torch.randn(5), torch.randn(8)) self.assertEqual(ep.module()(torch.randn(5), torch.randn(9)).size()[0], 4) @@ -3695,18 +3878,31 @@ def forward(self, w): dynamic_shapes=({0: dimw},), ) with self.assertRaisesRegex( +<<<<<<< HEAD AssertionError, escape("Guard failed: w.size()[0] % 2 == 0"), ): # expected 2*..., got 9 +======= + RuntimeError, + "Expected input.*shape.*= 9 to be " + "of the form 2\\*s92, where s92 is an integer", + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ep.module()(torch.randn(9)) self.assertEqual(ep.module()(torch.randn(8)).size()[0], 4) with self.assertRaisesRegex( +<<<<<<< HEAD AssertionError, escape("Guard failed: w.size()[0] <= 12"), ): # expected <= 12, but got 14 +======= + RuntimeError, + "Expected input.*shape.*to be <= 12, but got 14", + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ep.module()(torch.randn(14)) def test_derived_dim_repeat_derived(self): @@ -3744,10 +3940,16 @@ def forward(self, x, y, z): dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimz}), ) with self.assertRaisesRegex( +<<<<<<< HEAD AssertionError, escape("Guard failed: z.size()[0] >= 6"), ): # expected 8, but got 5 +======= + RuntimeError, + "Expected input.*shape.*to be equal to 8, but got 5", + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ep.module()(torch.randn(6), torch.randn(7), torch.randn(5)) self.assertEqual( @@ -3780,10 +3982,16 @@ def forward(self, x, y, z, x1, x2): dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimz}, {0: dimx1}, {0: dimx2}), ) with self.assertRaisesRegex( +<<<<<<< HEAD AssertionError, escape("Guard failed: x2.size()[0] == x.size()[0]"), ): # expected 6, but got 5 +======= + RuntimeError, + "Expected input.*shape.*to be equal to 6, but got 5", + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ep.module()( torch.randn(6), torch.randn(7), @@ -3809,10 +4017,16 @@ def forward(self, x, y, z, x1, x2): dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimz}, {0: dimx1}, {0: dimx2}), ) with self.assertRaisesRegex( +<<<<<<< HEAD AssertionError, escape("Guard failed: x2.size()[0] == x.size()[0]"), ): # expected 6, but got 5 +======= + RuntimeError, + "Expected input.*shape.*to be equal to 6, but got 5", + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ep.module()( torch.randn(6), torch.randn(7), @@ -4199,6 +4413,7 @@ def forward(self, x): inp = torch.randn(3, 3) self.assertTrue(torch.allclose(ep.module()(inp)[0], inp + 1)) +<<<<<<< HEAD def test_set_grad_as_side_effect(self): class Foo(torch.nn.Module): def forward(self, x): @@ -4210,6 +4425,8 @@ def forward(self, x): after = torch.is_grad_enabled() self.assertEqual(before, after) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_derived_dim_out_of_order_simplified(self): _dimz = torch.export.Dim("_dimz", min=6, max=8) dimy = _dimz - 1 @@ -4248,10 +4465,16 @@ def forward(self, x, y, z): dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimz}), ) with self.assertRaisesRegex( +<<<<<<< HEAD AssertionError, escape("Guard failed: z.size()[0] >= 6"), ): # expected 8, but got 5 +======= + RuntimeError, + "Expected input.*shape.*to be equal to 8, but got 5", + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ep.module()(torch.randn(6), torch.randn(7), torch.randn(5)) self.assertEqual( @@ -4286,7 +4509,10 @@ def forward(self, x): x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) linear_weight = self.linear.weight linear_bias = self.linear.bias +<<<<<<< HEAD _guards_fn = self._guards_fn(x); _guards_fn = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) linear = torch.ops.aten.linear.default(x, linear_weight, linear_bias); x = linear_weight = linear_bias = None return pytree.tree_unflatten((linear,), self._out_spec)""", ) @@ -4327,7 +4553,10 @@ def forward(self, b_buffer, x): def forward(self, x): x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) buffer = self.buffer +<<<<<<< HEAD _guards_fn = self._guards_fn(x); _guards_fn = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) add_ = torch.ops.aten.add_.Tensor(x, 5); x = None add__1 = torch.ops.aten.add_.Tensor(buffer, 5); buffer = None add = torch.ops.aten.add.Tensor(add_, add__1); add_ = add__1 = None @@ -4422,6 +4651,7 @@ def forward(self, container): ) ) +<<<<<<< HEAD def test_function_holding_tensor(self): global_storage = [] @@ -4602,6 +4832,8 @@ def forward(self, x, y): with self.assertWarnsRegex(UserWarning, warn_re): ep = export(lc, (torch.randn(4, 4), torch.randn(4, 4)), strict=False) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_export_for_training_run_decomp(self): class Foo(torch.nn.Module): def __init__(self) -> None: @@ -4649,10 +4881,16 @@ def forward(self, x, y, y1, z): dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimy}, {0: dimz}), ) with self.assertRaisesRegex( +<<<<<<< HEAD AssertionError, escape("Guard failed: y1.size()[0] == y.size()[0]"), ): # expected 7, but got 5 +======= + RuntimeError, + "Expected input.*shape.*to be equal to 7, but got 5", + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ep.module()( torch.randn(6), torch.randn(7), @@ -4704,9 +4942,14 @@ def forward(self, x, y, z): ep = export(foo, inputs, dynamic_shapes=dynamic_shapes) self.assertEqual(foo(*inputs), ep.module()(*inputs)) for wrong_inputs in wrong_shape_inputs: +<<<<<<< HEAD with self.assertRaisesRegex(AssertionError, "Guard failed"): with self.assertRaises(RuntimeError): ep.module()(*wrong_inputs) +======= + with self.assertRaises(RuntimeError): + ep.module()(*wrong_inputs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # check range_constraints - static dims shouldn't be present ep = export(foo, inputs, dynamic_shapes=((dx, None), (dy, 4), (dz, 3))) @@ -4742,10 +4985,15 @@ def forward(self, x): ep.module()(torch.randn(1, 2)) ep.module()(torch.randn(2, 2)) with self.assertRaisesRegex( +<<<<<<< HEAD AssertionError, escape("Guard failed: x.size()[0] <= 2"), ): # expected <= 2, but got 3 +======= + RuntimeError, "Expected input at .* to be <= 2, but got 3" + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ep.module()(torch.randn(3, 2)) vr = list(ep.range_constraints.values())[0] self.assertEqual(vr.lower, 1) @@ -4762,12 +5010,16 @@ def forward(self, x, y): (torch.randn(2, 2), torch.randn(3, 2)), dynamic_shapes=({0: dx, 1: None}, {0: dx + 1, 1: None}), ) +<<<<<<< HEAD with self.assertRaisesRegex( AssertionError, escape("Guard failed: -1 + y.size()[0] != 1"), ): # TODO: this should not error? ep.module()(torch.randn(1, 2), torch.randn(2, 2)) +======= + ep.module()(torch.randn(1, 2), torch.randn(2, 2)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) range_lower_bounds = sorted(vr.lower for vr in ep.range_constraints.values()) range_upper_bounds = sorted(vr.upper for vr in ep.range_constraints.values()) self.assertEqual(range_lower_bounds, [1, 2]) @@ -4814,6 +5066,7 @@ def forward(self, x, mask): self.assertTrue(torch.allclose(ref[0], actual[0])) self.assertTrue(torch.allclose(ref[1], actual[1])) +<<<<<<< HEAD @torch._dynamo.config.patch(capture_scalar_outputs=True) def test_layer_norm_unbacked_normalized_shape(self): class MyModel(torch.nn.Module): @@ -4851,6 +5104,8 @@ def forward(self, x, repeat): exported = export(model, inputs).module() self.assertEqual(model(*inputs), exported(*inputs)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_dynamic_shapes_builder_basic(self): class M(torch.nn.Module): def forward(self, x, y, z): @@ -4957,7 +5212,11 @@ def forward(self, x, y, z): self.assertEqual(got_shapes, expected_shapes) def expect_error(bad_args, run_time_msg, compile_time_msg): +<<<<<<< HEAD with self.assertRaisesRegex(AssertionError, run_time_msg): +======= + with self.assertRaisesRegex(RuntimeError, run_time_msg): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ep.module()(*bad_args) additional_inputs = torch.export.AdditionalInputs() @@ -4969,27 +5228,39 @@ def expect_error(bad_args, run_time_msg, compile_time_msg): expect_error( # 4->2, 4->2, 3->3 bad_args=(torch.randn(2), [torch.randn(2)], {"k": torch.randn(3)}), +<<<<<<< HEAD run_time_msg=escape( "Guard failed: x.size()[0] >= 3" ), # expected >= 3, but got 2 +======= + run_time_msg="Expected input.*to be >= 3, but got 2", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) compile_time_msg="Expected input.*to be >= 3, but got 2", ) expect_error( # 4->6, 4->7, 3->3 bad_args=(torch.randn(6), [torch.randn(7)], {"k": torch.randn(3)}), +<<<<<<< HEAD run_time_msg=escape( "Guard failed: y[0].size()[0] == x.size()[0]" ), # expected 6, but got 7 +======= + run_time_msg="Expected input.*to be equal to 6, but got 7", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) compile_time_msg="Expected input.*to be equal to 6, but got 7", ) expect_error( # 4->5, 4->5, 3->4 bad_args=(torch.randn(5), [torch.randn(5)], {"k": torch.randn(4)}), +<<<<<<< HEAD run_time_msg=escape( "Guard failed: z['k'].size()[0] == 3" ), # expected 3, but got 4 +======= + run_time_msg="Expected input.*to be equal to 3, but got 4", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) compile_time_msg=r"You marked.*but your code specialized it to be a constant.*If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO", ) @@ -5472,9 +5743,12 @@ def forward(self, x, offsets_t, fixes): ) def test_simple_unbacked_view(self): +<<<<<<< HEAD if "cpp_runtime_nonstrict" in self.id(): self.skipTest("TODO Unexpected success in OSS but not in fbcode.") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class Foo(torch.nn.Module): def forward(self, x): u0 = x.item() @@ -5584,6 +5858,10 @@ def forward(self, x): # There should be nonzero view nodes in the graph self.assertTrue(view_count > 0) +<<<<<<< HEAD +======= + @testing.expectedFailureCppSerDes # cpp ser/der not handling complicated symbols +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_solver_unsupported_sympy_function(self): # repro of https://github.com/pytorch/pytorch/issues/131897 @@ -5638,6 +5916,7 @@ def forward(self, x, y): self.assertTrue(torch.allclose(ep.module()(x, y), model(x, y))) x2 = torch.arange(4).reshape((2, 2)) y2 = torch.arange(9).reshape((3, 3)) +<<<<<<< HEAD with self.assertRaisesRegex( AssertionError, ( @@ -5650,6 +5929,9 @@ def forward(self, x, y): ): # TODO: this should not error? self.assertTrue(torch.allclose(ep.module()(x2, y2), model(x2, y2))) +======= + self.assertTrue(torch.allclose(ep.module()(x2, y2), model(x2, y2))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_export_max_nonstrict(self): class FooMax(torch.nn.Module): @@ -5773,11 +6055,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: dim0_x = torch.export.Dim("dim0_x", min=3) dim1_x = torch.export.Dim("dim1_x", max=8000) dynamic_shapes = {"x": (dim0_x, dim1_x)} +<<<<<<< HEAD em = torch.export.export( m, (a,), dynamic_shapes=dynamic_shapes, prefer_deferred_runtime_asserts_over_guards=True, +======= + em = torch.export._trace._export( + m, + (a,), + dynamic_shapes=dynamic_shapes, + allow_complex_guards_as_runtime_asserts=True, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) em.module()(torch.randn(4, 3)) with self.assertRaisesRegex( @@ -5792,10 +6082,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: em = torch.export.export(m, (a,), dynamic_shapes=dynamic_shapes) x = torch.randn(3, 5) with self.assertRaisesRegex( +<<<<<<< HEAD AssertionError, escape("Guard failed: 3 * x.size()[1] % 2 == 0"), ): # expected 2*..., but got 5 +======= + RuntimeError, + "Expected.*shape\\[1\\] = 5 to be of the form 2\\*s33, where s33 is an integer", + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) em.module()(x) def test_dont_duck_size_for_auto_dynamic(self): @@ -5818,9 +6114,12 @@ def forward(self, x, y): ep.module()(torch.randn(6, 3), torch.randn(7, 4)) def test_map(self): +<<<<<<< HEAD if "cpp_runtime_nonstrict" in self.id(): self.skipTest("TODO Unexpected success in OSS but not in fbcode.") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class Module(torch.nn.Module): def forward(self, xs, y, z): def body(x, y, z): @@ -6075,6 +6374,7 @@ def forward(self, x, y): self.assertTrue(torch.allclose(er, r)) +<<<<<<< HEAD @testing.expectedFailureSerDerNonStrict @testing.expectedFailureCppRuntimeNonStrict def test_more_multidimensional_slicing(self): @@ -6233,6 +6533,8 @@ def forward(self, t, x): test(M_slice_None_Ellipsis_int(), G_slice_None_Ellipsis_int()) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_sequential_slicing(self): # See https://github.com/pytorch/pytorch/issues/137455 @@ -6792,9 +7094,13 @@ def forward(self, kjt) -> torch.Tensor: efoo = torch.export.export( foo, inputs, +<<<<<<< HEAD dynamic_shapes={ "kjt": [{0: dim}, None, {0: dim}, {0: dim_plus_one}, None, None] }, +======= + dynamic_shapes={"kjt": [{0: dim}, None, {0: dim}, {0: dim_plus_one}]}, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) self.assertEqual( [out.shape for out in efoo.module()(*inputs)], @@ -7895,7 +8201,10 @@ def forward(self, x): bn_running_mean = self.bn.running_mean bn_running_var = self.bn.running_var bn_num_batches_tracked = self.bn.num_batches_tracked; bn_num_batches_tracked = None +<<<<<<< HEAD _guards_fn = self._guards_fn(x); _guards_fn = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) conv2d = torch.ops.aten.conv2d.default(x, conv_weight, conv_bias); x = conv_weight = conv_bias = None batch_norm = torch.ops.aten.batch_norm.default(conv2d, bn_weight, bn_bias, bn_running_mean, bn_running_var, False, 0.1, 1e-05, True); conv2d = bn_weight = bn_bias = bn_running_mean = bn_running_var = None return pytree.tree_unflatten((batch_norm,), self._out_spec)""", @@ -7915,7 +8224,10 @@ def forward(self, x): bn_running_mean = self.bn.running_mean bn_running_var = self.bn.running_var bn_num_batches_tracked = self.bn.num_batches_tracked +<<<<<<< HEAD _guards_fn = self._guards_fn(x); _guards_fn = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) conv2d = torch.ops.aten.conv2d.default(x, conv_weight, conv_bias); x = conv_weight = conv_bias = None add_ = torch.ops.aten.add_.Tensor(bn_num_batches_tracked, 1); bn_num_batches_tracked = add_ = None batch_norm = torch.ops.aten.batch_norm.default(conv2d, bn_weight, bn_bias, bn_running_mean, bn_running_var, True, 0.1, 1e-05, True); conv2d = bn_weight = bn_bias = bn_running_mean = bn_running_var = None @@ -8151,6 +8463,7 @@ def forward(self, inputs): ]: self.assertFalse(hasattr(tensor, attr)) +<<<<<<< HEAD @testing.expectedFailureCppRuntime def test_while_loop_index_assertions(self): from torch._higher_order_ops import while_loop @@ -8214,6 +8527,8 @@ def body_fn(idx, x): ): ep.graph_module.while_loop_body_graph_0(torch.tensor([5]), torch.zeros(1)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_constrain_decomp(self) -> None: class M(torch.nn.Module): def __init__(self) -> None: @@ -8471,9 +8786,12 @@ def forward(self, x, m): self.assertEqual(ref_out, ep.module()(ref_x, mod)) def test_unbacked_noncontig_lin(self): +<<<<<<< HEAD if "cpp_runtime_nonstrict" in self.id(): self.skipTest("TODO Unexpected success in OSS but not in fbcode.") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class Foo(torch.nn.Module): def __init__(self): super().__init__() @@ -8509,20 +8827,32 @@ def forward(self, x, y): ) ) with self.assertRaisesRegex( +<<<<<<< HEAD AssertionError, escape("Guard failed: y == 5"), ): # expected 5, but got 6 +======= + RuntimeError, + escape("Expected input at *args[1] to be equal to 5, but got 6"), + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _ = exported.module()(torch.ones(8, 5), 6) exported = torch.export.export( foo, (tensor_inp, 5.0), dynamic_shapes=dynamic_shapes ) with self.assertRaisesRegex( +<<<<<<< HEAD AssertionError, escape("Guard failed: y == 5.0"), ): # expected 5.0, but got 6.0 +======= + RuntimeError, + escape("Expected input at *args[1] to be equal to 5.0, but got 6.0"), + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _ = exported.module()(torch.ones(7, 5), 6.0) def test_runtime_assert_for_prm_str(self): @@ -8534,17 +8864,25 @@ def forward(self, a, b, mode): inps = (torch.randn(4, 4), torch.randn(4), "trunc") exported = export(foo, inps) with self.assertRaisesRegex( +<<<<<<< HEAD AssertionError, escape("Guard failed: mode == 'trunc'"), ): # expected 'trunc', but got 'floor' +======= + RuntimeError, "to be equal to trunc, but got floor" + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _ = exported.module()(torch.randn(4, 4), torch.randn(4), "floor") self.assertTrue(torch.allclose(exported.module()(*inps), foo(*inps))) def test_sym_or_sym_and(self): +<<<<<<< HEAD if "cpp_runtime_nonstrict" in self.id(): self.skipTest("TODO Unexpected success in OSS but not in fbcode.") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.fx.experimental.symbolic_shapes import sym_and, sym_or class Foo(torch.nn.Module): @@ -8664,12 +9002,18 @@ def forward(self, x): dim0_x = torch.export.Dim("dim0_x") exported = torch.export.export(Foo(), (inp,), dynamic_shapes=({0: dim0_x},)) reexported = torch.export.export(exported.module(), (inp,)) +<<<<<<< HEAD with self.assertRaisesRegex( AssertionError, escape("Guard failed: x.size()[0] == 5"), ): # expected 5, but got 7 +======= + with self.assertRaisesRegex( + RuntimeError, "shape\[0\] to be equal to 5, but got 7" + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) reexported.module()(torch.ones(7, 5)) reexported = torch.export.export( @@ -8687,10 +9031,16 @@ def forward(self, x): Foo(), (inp,), dynamic_shapes={"x": {0: dim0_x_v2}} ) with self.assertRaisesRegex( +<<<<<<< HEAD AssertionError, escape("Guard failed: x.size()[0] >= 3"), ): # expected >= 3, but got 2 +======= + RuntimeError, + escape("Expected input at *args[0].shape[0] to be >= 3, but got 2"), + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.export.export(exported_v2.module(), (torch.randn(2, 2),)) def test_export_cond_symbool_pred(self): @@ -8724,7 +9074,11 @@ def false_fn(x): str(schema), """cond(SymBool pred, GraphModule true_fn, GraphModule false_fn, Tensor[2] operands) -> Tensor[1]""", ) +<<<<<<< HEAD # serdes deserializes tuple as list +======= + # serdes deserailizes tuple as list +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if need_serdes_test(self._testMethodName): self.assertExpectedInline( ep.graph_module.code.strip(), @@ -8770,6 +9124,7 @@ def forward(self, b_a_buffer, x): torch.allclose(ep.module()(torch.ones(6, 4)), Foo()(torch.ones(6, 4))) ) +<<<<<<< HEAD def test_ccode_python_mod(self): import sympy @@ -8793,6 +9148,8 @@ def forward(self, xs): """(u0 % u1) < 0 ? u0 % u1 + abs(u1) : u0 % u1""", ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_aten_lift_fresh_copy(self): class M(torch.nn.Module): def forward(self, x): @@ -8835,7 +9192,11 @@ def forward(self, x): len([node for node in gm.graph.nodes if node.op == "placeholder"]), 1 ) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @testing.expectedFailureCppRuntime def test_export_associative_scan_symbol_dim(self): device = torch.device("cuda") @@ -8860,7 +9221,11 @@ def forward(self, x): module_out = Foo()(xs) self.assertTrue(torch.allclose(ep.module()(xs), module_out)) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @testing.expectedFailureCppRuntime def test_export_associative_scan_symbol_scandim(self): device = torch.device("cuda") @@ -8885,11 +9250,16 @@ def forward(self, x): module_out = Foo()(xs) self.assertTrue(torch.allclose(ep.module()(xs), module_out)) +<<<<<<< HEAD @requires_cuda_and_triton def test_export_associative_scan_lifted_buffers(self): if "cpp_runtime_nonstrict" in self.id(): self.skipTest("TODO Unexpected success in OSS but not in fbcode.") +======= + @requires_cuda + def test_export_associative_scan_lifted_buffers(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device = torch.device("cuda") combine_mode = "pointwise" @@ -9368,7 +9738,11 @@ def _decompose_linear_custom(x, weight, bias): self.assertExpectedInline( str(ep_decompose_linear.graph_module.code).strip(), """\ +<<<<<<< HEAD def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_linear_weight, c_linear_bias, x, y): +======= +def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_linear_bias, c_linear_weight, x, y): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) conv2d = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None conv1d = torch.ops.aten.conv1d.default(y, p_conv1d_weight, p_conv1d_bias); y = p_conv1d_weight = p_conv1d_bias = None permute = torch.ops.aten.permute.default(c_linear_weight, [1, 0]); c_linear_weight = None @@ -9436,6 +9810,13 @@ def forward(self, x): inp = torch.randn(2) self.assertTrue(torch.allclose(ep.module()(inp), torch.nonzero(inp))) +<<<<<<< HEAD +======= + # TODO(pianpwk) blocker: https://github.com/pytorch/pytorch/issues/151809 + @testing.expectedFailureSerDer + @testing.expectedFailureSerDerNonStrict + @testing.expectedFailureCppSerDes +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_redundant_asserts(self): class Foo(torch.nn.Module): def forward(self, x): @@ -9485,10 +9866,15 @@ def forward(self, a, b): dynamic_shapes=(None, None), ) with self.assertRaisesRegex( +<<<<<<< HEAD AssertionError, escape("Guard failed: b.size()[0] == 4"), ): # expected 4, but got 7 +======= + RuntimeError, "shape\[0\] to be equal to 4, but got 7" + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ep_v2.module()(*test_inp) def test_constant_output(self): @@ -9568,11 +9954,15 @@ def dynamify_inp(x): ep = torch.export.export(foo, inp, dynamic_shapes=dynamic_shapes) test_inp = ((torch.randn(4, 4), torch.randn(2, 4)), torch.randn(4, 4)) +<<<<<<< HEAD with self.assertRaisesRegex( AssertionError, escape("Guard failed: a[1].size()[0] >= 3"), ): # expected >= 3, but got 2 +======= + with self.assertRaisesRegex(RuntimeError, "shape\[0\] to be >= 3, but got 2"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ep.module()(*test_inp) def test_nested_module(self): @@ -9774,6 +10164,7 @@ def forward(self, x): ).module() with self.assertRaisesRegex( +<<<<<<< HEAD AssertionError, escape("Guard failed: x.size()[0] >= 3"), ): @@ -9785,6 +10176,15 @@ def forward(self, x): escape("Guard failed: x.size()[0] >= 3"), ): # expected >= 3, got 2 +======= + RuntimeError, escape("Expected input at *args[0].shape[0]") + ): + gm(torch.randn(2, 2)) + + with self.assertRaisesRegex( + RuntimeError, escape("Expected input at *args[0].shape[0]") + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) export(gm, (torch.randn(2, 2),)) ep = export( @@ -9892,7 +10292,11 @@ def forward(self, x): x = torch.rand(5, 2, 2) model = Model() +<<<<<<< HEAD # Manually set the fake_device of fake tensors. +======= + # Manualy set the fake_device of fake tensors. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x.fake_device = torch.device("cuda:0") for n, p in model.named_parameters(): p.fake_device = torch.device("cuda:0") @@ -11603,11 +12007,15 @@ def forward(self, x, y): ep = export(M(), (4, 5)) self.assertEqual(ep.module()(4, 5), 20) +<<<<<<< HEAD with self.assertRaisesRegex( AssertionError, escape("Guard failed: x == 4"), ): # expected 4, but got 3 +======= + with self.assertRaisesRegex(RuntimeError, r"to be equal to 4, but got 3"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(ep.module()(3, 6), 18) ep = export(M(), (4, 5), dynamic_shapes={"x": Dim.DYNAMIC, "y": Dim.AUTO}) @@ -11620,11 +12028,15 @@ def forward(self, x, y): ep = export(M(), (5, 5), dynamic_shapes={"x": None, "y": Dim.AUTO}) self.assertEqual(ep.module()(5, 6), 30) +<<<<<<< HEAD with self.assertRaisesRegex( AssertionError, escape("Guard failed: x == 5"), ): # expected 5, but got 3 +======= + with self.assertRaisesRegex(RuntimeError, r"to be equal to 5, but got 3"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(ep.module()(3, 5), 18) class M(torch.nn.Module): @@ -11640,6 +12052,10 @@ def forward(self, x, y): self.assertTrue(torch.allclose(ep.module()(*inp), M()(*inp))) @testing.expectedFailureCppRuntime +<<<<<<< HEAD +======= + @testing.expectedFailureRetraceabilityNonStrict # no runtime asserts added for assert x == 3 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_symint_input_specialization(self): class M(torch.nn.Module): def forward(self, x, y): @@ -11664,6 +12080,7 @@ def forward(self, x, y): inp, dynamic_shapes=(Dim.AUTO, None), ) +<<<<<<< HEAD with self.assertRaisesRegex( AssertionError, escape("Guard failed: x == 3"), @@ -11672,6 +12089,13 @@ def forward(self, x, y): ep.module()(4, torch.randn(4, 4)) @testing.expectedFailureCppRuntime +======= + with self.assertRaisesRegex(RuntimeError, "to be equal to 3, but got 4"): + ep.module()(4, torch.randn(4, 4)) + + @testing.expectedFailureCppRuntime + @testing.expectedFailureRetraceabilityNonStrict # no runtime asserts added for assert x == 3 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_symint_input_ranges(self): class M(torch.nn.Module): def forward(self, x, y): @@ -11685,6 +12109,7 @@ def forward(self, x, y): ) ep.module()(4, torch.randn(4, 4)) +<<<<<<< HEAD with self.assertRaisesRegex( AssertionError, escape("Guard failed: x <= 10"), @@ -11696,6 +12121,11 @@ def forward(self, x, y): escape("Guard failed: x >= 3"), ): # expected >= 3, but got 2 +======= + with self.assertRaisesRegex(RuntimeError, "to be <= 10, but got 16"): + ep.module()(16, torch.randn(4, 4)) + with self.assertRaisesRegex(RuntimeError, "to be >= 3, but got 2"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ep.module()(2, torch.randn(4, 4)) # While tracing the range was found to be a subset of the original range @@ -12376,7 +12806,10 @@ def test(m, expected_graph, expected_fqns, expected_duplicates): [ fqn for fqn, _ in unflattened.named_modules(remove_duplicate=False) +<<<<<<< HEAD if fqn != "_guards_fn" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] ), expected_fqns, @@ -12527,9 +12960,13 @@ def forward(self, x): return x inp = torch.randn(4, 4) +<<<<<<< HEAD gm = torch.fx.experimental.proxy_tensor.make_fx( Foo(), record_stack_traces=True )( +======= + gm = torch.fx.experimental.proxy_tensor.make_fx(Foo(), stack_trace=True)( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inp, ) @@ -12570,6 +13007,7 @@ def forward(self, x): ) ) +<<<<<<< HEAD def test_filter_traceback_frames(self): class TestTracer(torch.fx.Tracer): def __init__(self) -> None: @@ -12594,6 +13032,8 @@ def forward(self, x): trace_x = [node for node in graph.nodes if node.name == "x"][0].stack_trace self.assertTrue(re.search(r"proxy.py.*in create_node\n", trace_x)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @testing.expectedFailureSerDerNonStrict # register_constant needs to handle serialization @testing.expectedFailureSerDer # register_constant needs to handle serialization def test_register_constant(self): @@ -13612,6 +14052,7 @@ def forward(self, x): ): _ = export(Foo(), (torch.randn(4, 4),), strict=False) +<<<<<<< HEAD def test_vmap_custom_autograd_function(self): from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex @@ -13658,6 +14099,8 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor: ) self.assertEqual(m(idxs), ep.module()(idxs)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_unbacked_deferred_runtime_retrace(self): class Foo(torch.nn.Module): def forward(self, x, y): @@ -13744,7 +14187,11 @@ def forward(self, x): def test_disable_forced_specializations_ok(self): # check that we don't force specialization, and defer to runtime asserts +<<<<<<< HEAD # with prefer_deferred_runtime_asserts_over_guards=True to successfully export +======= + # with allow_complex_guards_as_runtime_asserts=True to successfully export +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # case 1: modulo guards from torch.export import dims @@ -13754,11 +14201,19 @@ def forward(self, x): inputs = (torch.randn(10, 72),) dx, dy = dims("dx", "dy") +<<<<<<< HEAD ep = torch.export.export( Mod4Reshape(), inputs, dynamic_shapes={"x": (dx, dy)}, prefer_deferred_runtime_asserts_over_guards=True, +======= + ep = torch.export._trace._export( + Mod4Reshape(), + inputs, + dynamic_shapes={"x": (dx, dy)}, + allow_complex_guards_as_runtime_asserts=True, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) out1 = ep.module()(torch.randn(8, 7)) self.assertEqual(out1.shape, torch.ones(7, 4, 2).shape) @@ -13785,6 +14240,7 @@ def forward(self, x, y, z): "y": [Dim(f"dy{i}", min=2) for i in range(2)], "z": [Dim(f"dz{i}", min=4) for i in range(1)], } +<<<<<<< HEAD for private_api in (True, False): if private_api: @@ -13818,6 +14274,24 @@ def forward(self, x, y, z): ): # expected 40, but got 20 ep.module()(torch.randn(5, 8), torch.randn(4, 5), torch.randn(30)) +======= + ep = torch.export._trace._export( + FreeReshape(), + inputs, + dynamic_shapes=dynamic_shapes, + allow_complex_guards_as_runtime_asserts=True, + ) + ep = export(FreeReshape(), inputs, dynamic_shapes=dynamic_shapes) + out1 = ep.module()(torch.randn(48, 1), torch.randn(4, 12), torch.randn(48)) + self.assertEqual(out1.shape, torch.ones(48).shape) + out2 = ep.module()(torch.randn(5, 8), torch.randn(4, 10), torch.randn(40)) + self.assertEqual(out2.shape, torch.ones(40).shape) + with self.assertRaisesRegex( + RuntimeError, + r"Runtime assertion failed for expression Eq\((.*)\) on node '.*'", + ): # fail only at runtime + ep.module()(torch.randn(5, 8), torch.randn(4, 5), torch.randn(30)) # fail +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # case 3: 3d reshape (previously failing with different issue) class Reshape3d(torch.nn.Module): @@ -13832,11 +14306,19 @@ def forward(self, x, y): "x": (Dim("dx0", min=2), Dim("dx1", min=2), Dim("dx2", min=2)), "y": (Dim("dy", min=8),), } +<<<<<<< HEAD ep = torch.export.export( Reshape3d(), inputs, dynamic_shapes=dynamic_shapes, prefer_deferred_runtime_asserts_over_guards=True, +======= + ep = torch.export._trace._export( + Reshape3d(), + inputs, + dynamic_shapes=dynamic_shapes, + allow_complex_guards_as_runtime_asserts=True, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) out1 = ep.module()(torch.randn(9, 7, 2), torch.randn(126)) self.assertEqual(out1.shape, torch.ones(126).shape) @@ -13897,6 +14379,7 @@ def forward(self, x, y): self.assertFalse(placeholders[1].meta["val"].requires_grad) self.assertTrue(placeholders[2].meta["val"].requires_grad) +<<<<<<< HEAD def test_expand_copy_export_handles_implicit_true(self): class ExpandModel(torch.nn.Module): def __init__(self): @@ -13917,6 +14400,9 @@ def test_unbacked_expand(self): if "cpp_runtime_nonstrict" in self.id(): self.skipTest("TODO Unexpected success in OSS but not in fbcode.") +======= + def test_unbacked_expand(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class Foo(torch.nn.Module): def forward(self, xs): u0, u1, u2 = xs.tolist() @@ -13974,11 +14460,19 @@ def forward(self, x): model = Model() x = torch.rand(1024, 20, 16) dynamic_shapes = {"x": {0: Dim("batch")}} +<<<<<<< HEAD ep = torch.export.export( model, (x,), dynamic_shapes=dynamic_shapes, prefer_deferred_runtime_asserts_over_guards=True, +======= + ep = torch.export._trace._export( + model, + (x,), + dynamic_shapes=dynamic_shapes, + allow_complex_guards_as_runtime_asserts=True, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) with self.assertRaisesRegex( RuntimeError, @@ -14051,11 +14545,19 @@ def forward(self, x, y): inputs = (torch.randn(6), torch.randn(12)) dynamic_shapes = {"x": [Dim("dx", min=4)], "y": [Dim("dy", min=4)]} +<<<<<<< HEAD ep = torch.export.export( Foo(), inputs, dynamic_shapes=dynamic_shapes, prefer_deferred_runtime_asserts_over_guards=True, +======= + ep = torch.export._trace._export( + Foo(), + inputs, + dynamic_shapes=dynamic_shapes, + allow_complex_guards_as_runtime_asserts=True, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # check forward pass out0, out1 = ep.module()(torch.randn(9), torch.randn(27)) @@ -14090,7 +14592,11 @@ def forward(self, x, y): Foo(), inputs, dynamic_shapes=dynamic_shapes, +<<<<<<< HEAD prefer_deferred_runtime_asserts_over_guards=True, +======= + allow_complex_guards_as_runtime_asserts=True, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ).run_decompositions() self.assertEqual( @@ -14320,6 +14826,12 @@ def forward(self, x, y): ): ep.module()(torch.randn(10), torch.tensor(2)) +<<<<<<< HEAD +======= + @testing.expectedFailureCppSerDes # TODO: When we deserialize we somehow hardcode sympy.lower to 2 + @testing.expectedFailureSerDerNonStrict + @testing.expectedFailureSerDer +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch.fx.experimental._config.patch(backed_size_oblivious=True) def test_baddbmm(self): class M(torch.nn.Module): @@ -14344,7 +14856,11 @@ def forward(self, x): self.assertTrue(torch.allclose(m(x2), ep.module()(x2))) self.assertTrue(torch.allclose(m(x1), ep.module()(x1))) +<<<<<<< HEAD @testing.expectedFailureSerDerNonStrict # constructor is not serialized today +======= + @testing.expectedFailureSerDerNonStrict # construtor is not serialized today +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @testing.expectedFailureSerDer # constructor is not serialized today @testing.expectedFailureRetraceability # dynamo doesn't work with FlatApply op def test_capture_subclass_constructor(self): @@ -14455,7 +14971,14 @@ def __init__(self): def forward(self, x): return x.cos() +<<<<<<< HEAD export(Foo(), (torch.randn(4, 4),)) +======= + with self.assertRaisesRegex( + RuntimeError, "TestExport.test_capture_subclass_wrong..Foo" + ): + export(Foo(), (torch.randn(4, 4),)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_capture_subclass_constructor_torch_ir(self): class Foo(torch.nn.Module): @@ -14499,11 +15022,19 @@ def forward(self, x, y): inputs = (torch.randn(5), torch.randn(3)) shapes = {"x": (Dim("dx"),), "y": (Dim("dy"),)} +<<<<<<< HEAD ep = torch.export.export( Foo(), inputs, dynamic_shapes=shapes, prefer_deferred_runtime_asserts_over_guards=True, +======= + ep = torch.export._trace._export( + Foo(), + inputs, + dynamic_shapes=shapes, + allow_complex_guards_as_runtime_asserts=True, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # count 2 pow nodes, 2 sym_size.int nodes self.assertEqual( @@ -15300,6 +15831,7 @@ class ModConstraint(torch.nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: return x.view(x.shape[0] - 1, -1) +<<<<<<< HEAD for private_api in (True, False): if private_api: ep = torch.export.export( @@ -15338,6 +15870,23 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ): # expected 3*..., but got 8 ep.module()(torch.randn(4, 2)) +======= + ep = export( + ModConstraint(), + (torch.randn(3, 4),), + dynamic_shapes={ + "x": (dynamic, dynamic), + }, + ) + ep.module()(torch.randn(5, 8)) + num_asserts = [ + node.target == torch.ops.aten._assert_scalar.default + for node in ep.graph.nodes + ].count(True) + self.assertEqual(num_asserts, 2) + with self.assertRaises(RuntimeError): + ep.module()(torch.randn(4, 2)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @testing.expectedFailureSerDer # T195866111 @testing.expectedFailureSerDerNonStrict @@ -15530,6 +16079,7 @@ def fn(x): self.assertEqual(x.sin(), ep.module()(x)) pytree._deregister_pytree_node(torch.FunctionSchema) +<<<<<<< HEAD @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") def test_exception(self): class Model(torch.nn.Module): @@ -15575,6 +16125,8 @@ def forward(self, x): strict=False, ).module() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_export_for_training_with_state_dict_hooks(self): def _state_dict_pre_hook(mod, prefix, keep_vars): mod._buffers["test"] = torch.Tensor([1]) @@ -15708,10 +16260,19 @@ def forward(self, x): @contextmanager def distributed_env(self, world_size): try: +<<<<<<< HEAD +======= + from torch.testing._internal.distributed.fake_pg import FakeStore + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.distributed.init_process_group( backend="fake", world_size=world_size, rank=0, +<<<<<<< HEAD +======= + store=FakeStore(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) yield @@ -15868,6 +16429,7 @@ def forward(self, args_0): return (abs_1,)""", ) +<<<<<<< HEAD def test_sdpa_gqa(self): from torch.nn.attention import sdpa_kernel, SDPBackend @@ -15971,6 +16533,8 @@ def forward(self, q, k, v): ): export(Foo(), (torch.randn(1, 33, 256, 128), k, v)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support") class TestOneOffModelExportResult(TestCase): @@ -16099,10 +16663,16 @@ def forward(self, x, y): self.assertEqual(res[1], 5) with self.assertRaisesRegex( +<<<<<<< HEAD AssertionError, escape("Guard failed: y == 5"), ): # expected 5, but got 20 +======= + RuntimeError, + escape("Expected input at *args[1] to be equal to 5, but got 20"), + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) res = ep.module()(torch.tensor(4), 20) class F(torch.nn.Module): @@ -16496,6 +17066,7 @@ def forward(self, x): len(list(new_ep.graph.nodes)[-1].args[0]), len(signature.output_specs) ) +<<<<<<< HEAD @requires_cuda_and_triton def test_assert_tensor_metadata_device_index(self): class N(torch.nn.Module): @@ -16582,6 +17153,8 @@ def forward(self, x): exported_param_names = [name for name, _ in gm.named_parameters()] self.assertEqual(original_param_names, exported_param_names) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support") class TestExportCustomClass(TorchTestCase): @@ -16643,6 +17216,7 @@ def forward(self, x): arg = node.args[0] self.assertTrue(arg.op == "placeholder") +<<<<<<< HEAD def test_int_lift_constant(self): class M(torch.nn.Module): def forward(self, a, x): @@ -16654,6 +17228,8 @@ def forward(self, a, x): inp = (3, torch.randn(4)) self.assertTrue(torch.allclose(M()(*inp), ep.module()(*inp))) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_export_script_module(self): class Add(torch.nn.Module): def forward(self, x, y): @@ -16747,6 +17323,7 @@ def forward(self, x, ranks): MyModel(), inps, dynamic_shapes=spec, strict=True ).run_decompositions({}) +<<<<<<< HEAD def test_unbacked_contiguous(self): class MyModel(torch.nn.Module): def forward(self, x, mask): @@ -16861,6 +17438,8 @@ def forward(self, y): str(ep.graph) ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/export/test_passes.py b/test/export/test_passes.py index 4ae4d45498e93..14bca4a35d94c 100644 --- a/test/export/test_passes.py +++ b/test/export/test_passes.py @@ -411,10 +411,16 @@ def forward(self, x): ) with self.assertRaisesRegex( +<<<<<<< HEAD AssertionError, escape("Guard failed: x.size()[1] <= 6"), ): # expected <= 6, but got 7 +======= + RuntimeError, + escape("Expected input at *args[0].shape[1] to be <= 6, but got 7"), + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ep.module()(torch.zeros(2, 7, 3)) self.assertEqual( @@ -443,6 +449,7 @@ def forward(self, x, y): ) with self.assertRaisesRegex( +<<<<<<< HEAD AssertionError, escape("Guard failed: x.size()[1] <= 6"), ): @@ -454,6 +461,17 @@ def forward(self, x, y): escape("Guard failed: y.size()[0] >= 3"), ): # expected >= 3, but got 2 +======= + RuntimeError, + escape("Expected input at *args[0].shape[1] to be <= 6, but got 7"), + ): + ep.module()(torch.zeros(4, 7, 3), torch.ones(5, 5, 5)) + + with self.assertRaisesRegex( + RuntimeError, + escape("Expected input at *args[1].shape[0] to be >= 3, but got 2"), + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ep.module()(torch.zeros(4, 2, 3), torch.ones(2, 5, 5)) def test_runtime_assert_some_dims_not_specified(self) -> None: @@ -478,18 +496,30 @@ def forward(self, x, y): ) with self.assertRaisesRegex( +<<<<<<< HEAD AssertionError, escape("Guard failed: x.size()[1] <= 6"), ): # expected <= 6, but got 7 +======= + RuntimeError, + escape("Expected input at *args[0].shape[1] to be <= 6, but got 7"), + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ep.module()(torch.zeros(4, 7, 3), torch.ones(5, 5, 5)) # y is specialized to 5 with self.assertRaisesRegex( +<<<<<<< HEAD AssertionError, escape("Guard failed: y.size()[0] == 5"), ): # expected 5, but got 2 +======= + RuntimeError, + escape("Expected input at *args[1].shape[0] to be equal to 5, but got 2"), + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ep.module()(torch.zeros(4, 2, 3), torch.ones(2, 5, 5)) # Since we didn't insert the constraint for x[1] >= 2, it should work for case where x[1] == 1 @@ -514,19 +544,29 @@ def forward(self, x, y): M(), (x, y), dynamic_shapes={"x": None, "y": {1: dim1_y}}, strict=True ) +<<<<<<< HEAD with self.assertRaisesRegex( AssertionError, escape("Guard failed: x.size()[1] == 2"), ): # expected 2, but got 7 +======= + with self.assertRaisesRegex(RuntimeError, escape("shape[1] to be equal to 2")): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ep.module()(torch.zeros(4, 7, 3), torch.ones(5, 5, 5)) # y is specialized to 5 with self.assertRaisesRegex( +<<<<<<< HEAD AssertionError, escape("Guard failed: y.size()[0] == 5"), ): # expected 5, but got 2 +======= + RuntimeError, + escape("Expected input at *args[1].shape[0] to be equal to 5, but got 2"), + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ep.module()(torch.zeros(4, 2, 3), torch.ones(2, 5, 5)) # Since we didn't insert the constraint for x[1] >= 2, it should work for case where x[1] == 1 @@ -813,7 +853,10 @@ def test_predispatch_set_grad(self): """\ def forward(self, x): x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) +<<<<<<< HEAD _guards_fn = self._guards_fn(x); _guards_fn = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) add = torch.ops.aten.add.Tensor(x, 1); x = None sin = torch.ops.aten.sin.default(add); add = None sum_1 = torch.ops.aten.sum.default(sin); sin = None @@ -833,7 +876,10 @@ def forward(self, x): """\ def forward(self, x): x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) +<<<<<<< HEAD _guards_fn = self._guards_fn(x); _guards_fn = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) add = torch.ops.aten.add.Tensor(x, 1); x = None sin = torch.ops.aten.sin.default(add); add = None sum_1 = torch.ops.aten.sum.default(sin); sin = None @@ -853,7 +899,10 @@ def forward(self, x): """\ def forward(self, x): x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) +<<<<<<< HEAD _guards_fn = self._guards_fn(x); _guards_fn = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) add = torch.ops.aten.add.Tensor(x, 1); x = None sin = torch.ops.aten.sin.default(add); add = None sum_1 = torch.ops.aten.sum.default(sin); sin = None @@ -873,7 +922,10 @@ def forward(self, x): """\ def forward(self, x): x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) +<<<<<<< HEAD _guards_fn = self._guards_fn(x); _guards_fn = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) add = torch.ops.aten.add.Tensor(x, 1); x = None submod_5 = self.submod_1 sum_1 = torch.ops.higher_order.wrap_with_set_grad_enabled(True, submod_5, add); submod_5 = add = None @@ -894,7 +946,10 @@ def forward(self, x): """\ def forward(self, x): x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) +<<<<<<< HEAD _guards_fn = self._guards_fn(x); _guards_fn = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) add = torch.ops.aten.add.Tensor(x, 1); x = None sin = torch.ops.aten.sin.default(add) sum_1 = torch.ops.aten.sum.default(sin); sin = None @@ -920,7 +975,10 @@ def forward(self, x): """\ def forward(self, x): x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) +<<<<<<< HEAD _guards_fn = self._guards_fn(x); _guards_fn = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) add = torch.ops.aten.add.Tensor(x, 1); x = None submod_5 = self.submod_1 wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(True, submod_5, add); submod_5 = add = None @@ -956,7 +1014,10 @@ def test_sequential_split_graph(self): """\ def forward(self, x1, x2): x1, x2, = fx_pytree.tree_flatten_spec(([x1, x2], {}), self._in_spec) +<<<<<<< HEAD submod_0 = self.submod_0(x1, x2); submod_0 = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) submod_1 = self.submod_1(x1, x2); x1 = x2 = None getitem = submod_1[0] getitem_1 = submod_1[1]; submod_1 = None @@ -1012,7 +1073,10 @@ def test_predispatch_autocast_and_set_grad(self): """\ def forward(self, x): x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) +<<<<<<< HEAD _guards_fn = self._guards_fn(x); _guards_fn = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) submod_3 = self.submod_3 add = torch.ops.aten.add.Tensor(x, 1); x = None sin = torch.ops.higher_order.wrap_with_set_grad_enabled(True, submod_3, add); submod_3 = add = None @@ -1051,7 +1115,10 @@ def test_predispatch_autocast(self): """\ def forward(self, x): x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) +<<<<<<< HEAD _guards_fn = self._guards_fn(x); _guards_fn = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) add = torch.ops.aten.add.Tensor(x, 1); x = None submod_3 = self.submod_1 add_1 = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_3, add); submod_3 = add = None @@ -1084,7 +1151,10 @@ def forward(self, add): """\ def forward(self, x): x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) +<<<<<<< HEAD _guards_fn = self._guards_fn(x); _guards_fn = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) add = torch.ops.aten.add.Tensor(x, 1); x = None submod_4 = self.submod_1 sum_1 = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_4, add); submod_4 = add = None @@ -1135,7 +1205,10 @@ def forward(self, add_1): """\ def forward(self, x): x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) +<<<<<<< HEAD _guards_fn = self._guards_fn(x); _guards_fn = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) add = torch.ops.aten.add.Tensor(x, 1); x = None submod_4 = self.submod_1 wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_4, add); submod_4 = add = None @@ -1193,7 +1266,10 @@ def forward(self, add_1, add_2): """\ def forward(self, x): x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) +<<<<<<< HEAD _guards_fn = self._guards_fn(x); _guards_fn = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) add = torch.ops.aten.add.Tensor(x, 1); x = None submod_4 = self.submod_1 sum_1 = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_4, add); submod_4 = add = None @@ -1235,7 +1311,10 @@ def test_inline_(self): ) after_inline_str = new_gm.print_readable(print_output=False) self.assertEqual(before_str, after_inline_str) +<<<<<<< HEAD new_gm._guards_fn = gm._guards_fn +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(gm(*args), new_gm(*args)) def test_remove_auto_functionalized_pass(self) -> None: @@ -1326,6 +1405,7 @@ def forward(self, x): ) @unittest.skipIf(not TEST_CUDA, "requires cuda") +<<<<<<< HEAD def test_move_device_to(self): class M(torch.nn.Module): def forward(self, x): @@ -1369,6 +1449,8 @@ def forward(self, arg0_1): ) @unittest.skipIf(not TEST_CUDA, "requires cuda") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_move_to_device_pass(self): class Model(torch.nn.Module): def __init__(self, size=4, h_dim=10): @@ -1404,6 +1486,7 @@ def forward(self, x): outputs = gm(*test_inputs) self.assertEqual(outputs.device, torch.device("cuda:0")) +<<<<<<< HEAD @unittest.skipIf(not TEST_CUDA, "requires cuda") def test_move_device_example_inputs(self): class Model(torch.nn.Module): @@ -1436,6 +1519,8 @@ def forward(self, x, y, z): self.assertEqual(ep_cuda.example_inputs[0][1].device, torch.device("cuda:0")) self.assertEqual(ep_cuda.example_inputs[1]["z"].device, torch.device("cuda:0")) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_constant_folding_pass(self): from torch.ao.quantization.observer import MappingType, PerGroup, PerToken from torch.ao.quantization.pt2e._affine_quantization import ( diff --git a/test/export/test_schema.py b/test/export/test_schema.py index f184fead8b413..0407901a5aa66 100644 --- a/test/export/test_schema.py +++ b/test/export/test_schema.py @@ -404,6 +404,7 @@ def test_schema_check(self): next_version, _ = check(commit) self.assertEqual(next_version, [4, 1]) +<<<<<<< HEAD def test_schema_comparison(self): import torch._export.serde.schema as schema @@ -460,6 +461,8 @@ def test_schema_comparison(self): self.assertEqual(sig, sig_same) self.assertNotEqual(sig, sig_diff) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index faef9b455a0ee..b6595079707a1 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -14,6 +14,7 @@ from pathlib import Path from typing import NamedTuple +<<<<<<< HEAD from torch.testing._internal.inductor_utils import HAS_GPU @@ -24,6 +25,8 @@ from torch.library import wrap_triton from torch.utils._triton import has_triton +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch._dynamo as torchdynamo import torch._export.serde.schema as schema @@ -31,9 +34,13 @@ import torch.utils._pytree as pytree from torch._export.db.case import ExportCase, SupportLevel from torch._export.db.examples import all_examples +<<<<<<< HEAD from torch._export.serde.schema import ArgumentKind from torch._export.serde.serialize import ( _dict_to_dataclass, +======= +from torch._export.serde.serialize import ( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _to_json_bytes, canonicalize, deserialize, @@ -292,6 +299,7 @@ def forward(self, x): actual_out = loaded_ep.module()(*inp) self.assertEqual(exp_out, actual_out) +<<<<<<< HEAD def test_serialize_param_mutation(self): class Foo(torch.nn.Module): def __init__(self): @@ -311,6 +319,8 @@ def forward(self, x): val = loaded_ep.graph_signature.parameters_to_mutate self.assertEqual({"div": "parameter"}, val) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_serialize_constant_outputs(self): class MyModule(torch.nn.Module): def __init__(self) -> None: @@ -515,12 +525,19 @@ def forward(self, x): self.assertNotIn(name, seen) seen.add(name) +<<<<<<< HEAD def test_nonfinite_inputs(self) -> None: class Module(torch.nn.Module): def forward(self, x): x = torch.ops.aten.add.Scalar(x, math.inf) x = torch.ops.aten.add.Scalar(x, -math.inf) return torch.ops.aten.add.Scalar(x, math.nan) +======= + def test_infinity_inputs(self) -> None: + class Module(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.add.Scalar(x, math.inf) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fn = Module() ep = torch.export.export( @@ -593,6 +610,7 @@ def forward(self, x): serialized.exported_program.range_constraints[symint.name].max_val, 3 ) +<<<<<<< HEAD @unittest.skipIf( not torch.cuda.is_available() or not has_triton(), "requires cuda and triton" ) @@ -705,6 +723,8 @@ def forward(self, x, y): serialized.example_inputs, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_kwargs_default(self) -> None: """ Tests that the kwargs default values are serialized even if they are not @@ -759,6 +779,7 @@ def forward(self, x): if "aten.sum.dim_IntList" in node.target: self.assertEqual(node.inputs[1].arg.type, "as_ints") +<<<<<<< HEAD def test_empty_constant(self) -> None: class M(torch.nn.Module): def __init__(self): @@ -946,6 +967,8 @@ def forward(self, x): loaded_ep = load(buffer) self.assertEqual(m(*sample_inputs), loaded_ep.module()(*sample_inputs)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test") @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support") @@ -1125,7 +1148,11 @@ def test_optional_tuple(self): ) @torch.library.impl("mylib::foo", "cpu", lib=lib) +<<<<<<< HEAD @torch.library.register_fake("mylib::foo") +======= + @torch.library.impl_abstract("mylib::foo") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def foo_impl(a, b, c): res2 = None if c is not None: @@ -1214,21 +1241,33 @@ def test_auto_functionalize(self): ) @torch.library.impl("mylib::foo1", "cpu", lib=lib) +<<<<<<< HEAD @torch.library.register_fake("mylib::foo1") +======= + @torch.library.impl_abstract("mylib::foo1") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def foo1_impl(x, y, z, w, n): x.add_(y[0] + w) z.add_(y[1] + n) return n + n @torch.library.impl("mylib::foo2", "cpu", lib=lib) +<<<<<<< HEAD @torch.library.register_fake("mylib::foo2") +======= + @torch.library.impl_abstract("mylib::foo2") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def foo2_impl(x, y, z, w, n): x.add_(y[0] + w) z.add_(y[1] + n) return (n + n, n * n) @torch.library.impl("mylib::foo3", "cpu", lib=lib) +<<<<<<< HEAD @torch.library.register_fake("mylib::foo3") +======= + @torch.library.impl_abstract("mylib::foo3") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def foo3_impl(x, y, z, w, n): x.add_(y[0] + w) z.add_(y[1] + n) @@ -1771,6 +1810,7 @@ def forward(self, x): inputs = (torch.ones(2, 3),) self.check_graph(m, inputs, strict=False) +<<<<<<< HEAD def test_forward_compatibility(self): self.assertEqual( schema.TensorArgument( @@ -1785,6 +1825,8 @@ def test_forward_compatibility(self): ), ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) instantiate_parametrized_tests(TestDeserialize) @@ -1917,6 +1959,7 @@ def forward(self, x): f.seek(0) file_prefix = f.name.split("/")[2].split(".")[0] +<<<<<<< HEAD # Create a new file and copy things over, but modify the # archive version with tempfile.NamedTemporaryFile(suffix=".pt2") as fnew: @@ -1932,6 +1975,14 @@ def forward(self, x): f.seek(0) load(fnew.name) +======= + # Modify the version + with zipfile.ZipFile(f, "a") as zipf: + zipf.writestr(f"{file_prefix}/{ARCHIVE_VERSION_PATH}", "-1") + + f.seek(0) + load(f.name) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_save_constants(self): class Foo(torch.nn.Module): @@ -2063,7 +2114,10 @@ def forward(self, obj_attr, x): def forward(self, x): x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) attr = self.attr +<<<<<<< HEAD _guards_fn = self._guards_fn(x); _guards_fn = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) takes_foo = torch.ops._TorchScriptTesting.takes_foo.default(attr, x); attr = None add = torch.ops.aten.add.Tensor(x, takes_foo); x = takes_foo = None return pytree.tree_unflatten((add,), self._out_spec)""", @@ -2170,6 +2224,7 @@ def forward(self, x): self.assertTrue(node.meta["custom"]["quantization_tag"] == "foo") self.assertEqual(counter, 1) +<<<<<<< HEAD def test_unbacked_range_serdes(self): class Foo(torch.nn.Module): def forward(self, x, y): @@ -2224,6 +2279,8 @@ def forward(self, x, y, z): s0 = next(iter(ep.graph.nodes)).meta["val"].size(0) self.assertEqual(shape_env.var_to_range[s0.node.expr].lower, 0) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/export/test_torchbind.py b/test/export/test_torchbind.py index f45775f09f29a..176f563730685 100644 --- a/test/export/test_torchbind.py +++ b/test/export/test_torchbind.py @@ -2,6 +2,10 @@ # ruff: noqa: F841 import copy +<<<<<<< HEAD +======= +import unittest +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch.utils._pytree as pytree @@ -24,7 +28,10 @@ _empty_tensor_queue, init_torchbind_implementations, ) +<<<<<<< HEAD from torch.testing._internal.triton_utils import requires_cuda_and_triton +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _assertEqualSkipScriptObject(test_case, exp, actual): @@ -185,7 +192,10 @@ def forward(self, x, n): def forward(self, x, n): x, n, = fx_pytree.tree_flatten_spec(([x, n], {}), self._in_spec) attr = self.attr +<<<<<<< HEAD _guards_fn = self._guards_fn(x, n); n = _guards_fn = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) call_torchbind = torch.ops.higher_order.call_torchbind(attr, 'add_tensor', x); attr = None add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None return pytree.tree_unflatten((add,), self._out_spec)""", @@ -233,7 +243,10 @@ def forward(self, x): def forward(self, x): x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) attr = self.attr +<<<<<<< HEAD _guards_fn = self._guards_fn(x); _guards_fn = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) call_torchbind = torch.ops.higher_order.call_torchbind(attr, 'add_tensor', x); attr = None add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None return pytree.tree_unflatten((add,), self._out_spec)""", @@ -268,7 +281,10 @@ def forward(self, x): def forward(self, x): x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) attr = self.attr +<<<<<<< HEAD _guards_fn = self._guards_fn(x); _guards_fn = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, x); attr = None add = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None return pytree.tree_unflatten((add,), self._out_spec)""", @@ -303,7 +319,10 @@ def forward(self, x, cc): """\ def forward(self, x, cc): x, cc, = fx_pytree.tree_flatten_spec(([x, cc], {}), self._in_spec) +<<<<<<< HEAD _guards_fn = self._guards_fn(x, cc); _guards_fn = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) call_torchbind = torch.ops.higher_order.call_torchbind(cc, 'add_tensor', x); cc = None add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None return pytree.tree_unflatten((add,), self._out_spec)""", @@ -366,7 +385,10 @@ def forward(self, x, cc): """\ def forward(self, x, cc): x, cc, = fx_pytree.tree_flatten_spec(([x, cc], {}), self._in_spec) +<<<<<<< HEAD _guards_fn = self._guards_fn(x, cc); _guards_fn = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(cc, x); cc = None add = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None return pytree.tree_unflatten((add,), self._out_spec)""", @@ -411,7 +433,14 @@ def forward(self, x): F1(), (torch.ones(2, 3),), strict=False, pre_dispatch=pre_dispatch ) +<<<<<<< HEAD def test_torchbind_register_attr_at_runtime_error(self): +======= + # TODO(pianpwk): look into this + @unittest.expectedFailure + @parametrize("pre_dispatch", [True, False]) + def test_torchbind_input_and_alias(self, pre_dispatch): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # alias as model attribute class F3(torch.nn.Module): def forward(self, x, foo): @@ -419,6 +448,7 @@ def forward(self, x, foo): return x + self.foo.add_tensor(x) foo = torch.classes._TorchScriptTesting._Foo(10, 20) +<<<<<<< HEAD with self.assertRaisesRegex( ValueError, "following attrs were created in the model" ): @@ -438,6 +468,10 @@ def forward(self, x): foo = torch.classes._TorchScriptTesting._Foo(10, 20) self._test_export_same_as_eager( F3(foo), (torch.ones(2, 3),), strict=False, pre_dispatch=pre_dispatch +======= + self._test_export_same_as_eager( + F3(), (torch.ones(2, 3), foo), strict=False, pre_dispatch=pre_dispatch +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @parametrize("pre_dispatch", [True, False]) @@ -462,7 +496,10 @@ def forward(self, x): def forward(self, x): x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) attr = self.attr +<<<<<<< HEAD _guards_fn = self._guards_fn(x); _guards_fn = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) takes_foo_default_1 = torch.ops._TorchScriptTesting.takes_foo.default(attr, x) takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, takes_foo_default_1); attr = takes_foo_default_1 = None add = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None @@ -505,7 +542,10 @@ def forward(self, x): def forward(self, x): x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) attr = self.attr +<<<<<<< HEAD _guards_fn = self._guards_fn(x); _guards_fn = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) takes_foo_list_return_default = torch.ops._TorchScriptTesting.takes_foo_list_return.default(attr, x) getitem_2 = takes_foo_list_return_default[0] getitem_3 = takes_foo_list_return_default[1] @@ -558,7 +598,10 @@ def forward(self, x): def forward(self, x): x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) attr = self.attr +<<<<<<< HEAD _guards_fn = self._guards_fn(x); _guards_fn = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) takes_foo_tuple_return_default = torch.ops._TorchScriptTesting.takes_foo_tuple_return.default(attr, x) getitem_1 = takes_foo_tuple_return_default[0] getitem_2 = takes_foo_tuple_return_default[1]; takes_foo_tuple_return_default = None @@ -834,8 +877,13 @@ def test_identifying_torchbind_ops(self): self.assertFalse(op._has_torchbind_op_overload) def test_torchbind_op_register_fallthrough(self): +<<<<<<< HEAD TEST_DISPATCH_KEY = torch._C.DispatchKey.AutogradCPU TEST_DISPATCH_KEY_STR = "AutogradCPU" +======= + TEST_DISPATCH_KEY = torch._C.DispatchKey.AutocastCPU + TEST_DISPATCH_KEY_STR = "AutocastCPU" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for op_packet in self.torch_bind_ops: op = op_packet.default @@ -966,19 +1014,31 @@ def forward(self, tq, x): with torch.library._scoped_library(ns, "FRAGMENT") as lib: for op in ops: lib.impl( +<<<<<<< HEAD op.__name__, torch.library.fallthrough_kernel, "AutogradCPU" +======= + op.__name__, torch.library.fallthrough_kernel, "AutocastCUDA" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) gm = make_fx(mod, tracing_mode="fake")(tq1, x) else: for op in ops: +<<<<<<< HEAD op.default.py_impl(torch._C.DispatchKey.AutogradCPU)( +======= + op.default.py_impl(torch._C.DispatchKey.AutocastCUDA)( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.library.fallthrough_kernel ) gm = make_fx(mod, tracing_mode="fake")(tq1, x) for op in ops: op.default._dispatch_cache.clear() +<<<<<<< HEAD del op.default.py_kernels[torch._C.DispatchKey.AutogradCPU] +======= + del op.default.py_kernels[torch._C.DispatchKey.AutocastCUDA] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertExpectedInline( gm.code.strip(), @@ -1073,7 +1133,10 @@ def forward(self, tq: torch.ScriptObject, x: torch.Tensor) -> None: """\ def forward(self, tq, x): tq, x, = fx_pytree.tree_flatten_spec(([tq, x], {}), self._in_spec) +<<<<<<< HEAD _guards_fn = self._guards_fn(tq, x); _guards_fn = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) queue_push_default = torch.ops._TorchScriptTesting.queue_push.default(tq, x); x = queue_push_default = None return pytree.tree_unflatten((tq,), self._out_spec)""", ) @@ -1159,7 +1222,11 @@ def __obj_unflatten__(cls, flattened_ctx): def tearDown(self): torch._dynamo.reset() +<<<<<<< HEAD @parametrize("backend", ["eager", "aot_eager", "inductor"]) +======= + @parametrize("backend", ["eager", "aot_eager"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_compile_script_object_input(self, backend): if backend == "eager": backend = EagerAndRecordGraphs() @@ -1223,7 +1290,11 @@ def forward(self, L_tq_ : torch.ScriptObject, L_x_ : torch.Tensor): return (x_sin,)""", ) +<<<<<<< HEAD @parametrize("backend", ["eager", "aot_eager", "inductor"]) +======= + @parametrize("backend", ["eager", "aot_eager"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_compile_script_object_input_guards(self, backend): class Model(torch.nn.Module): def __init__(self) -> None: @@ -1297,7 +1368,11 @@ def forward(self, tq, x): self.assertEqual(cnt.frame_count, 1) tq2 = _empty_tensor_queue() +<<<<<<< HEAD # make first tensor's second dim dynamic +======= + # make first tensor's secon dim dynamic +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tq2.push(torch.randn(2, 4, requires_grad=False)) torch.compile(mod, backend=cnt)(tq2, x) self.assertEqual(cnt.frame_count, 2) @@ -1308,7 +1383,11 @@ def forward(self, tq, x): torch.compile(mod, backend=cnt)(tq3, x) self.assertEqual(cnt.frame_count, 2) +<<<<<<< HEAD @parametrize("backend", ["eager", "aot_eager", "inductor"]) +======= + @parametrize("backend", ["eager", "aot_eager"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_compile_error_on_input_aliasing_contents(self, backend): if backend == "eager": backend = EagerAndRecordGraphs() @@ -1339,7 +1418,11 @@ def setattr_f(tq): return tq with self.assertRaisesRegex( +<<<<<<< HEAD RuntimeError, "Weird method call on TorchScript object" +======= + RuntimeError, "call method __setattr__ on script object is not safe" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): torch.compile(setattr_f, backend=backend)(_empty_tensor_queue()) @@ -1352,11 +1435,19 @@ def setattr_f(tq): return tq._not_defined_attr with self.assertRaisesRegex( +<<<<<<< HEAD RuntimeError, "FakeScriptObject missing method implementation" ): torch.compile(setattr_f, backend=backend)(_empty_tensor_queue()) @parametrize("backend", ["eager", "aot_eager", "inductor"]) +======= + RuntimeError, "doesn't define method _not_defined_attr" + ): + torch.compile(setattr_f, backend=backend)(_empty_tensor_queue()) + + @parametrize("backend", ["eager", "aot_eager"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_compile_body_aliasing_contents(self, backend): if backend == "eager": backend = EagerAndRecordGraphs() @@ -1392,7 +1483,11 @@ def forward(self, L_x_ : torch.Tensor, L_tq_ : torch.ScriptObject): return (sub, add)""", ) +<<<<<<< HEAD @parametrize("backend", ["eager", "aot_eager", "inductor"]) +======= + @parametrize("backend", ["eager", "aot_eager"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_compile_tensor_op_in_tensor_flatten(self, backend): test_obj = torch.classes._TorchScriptTesting._FlattenWithTensorOp( torch.randn(3, 2) @@ -1400,6 +1495,7 @@ def test_compile_tensor_op_in_tensor_flatten(self, backend): class TestMod(torch.nn.Module): def forward(self, obj, x): +<<<<<<< HEAD return obj.get() + x + obj.get().size(0) mod = TestMod() @@ -1409,11 +1505,21 @@ def forward(self, obj, x): compiled_out = torch.compile(mod, backend=backend, fullgraph=True)(test_obj, x) ep = torch.export.export_for_training( mod, (test_obj, x), strict=False +======= + return obj.get() + x + + mod = TestMod() + + torch.compile(mod, backend=backend, fullgraph=True)(test_obj, torch.randn(3, 1)) + ep = torch.export.export_for_training( + mod, (test_obj, torch.randn(3, 1)), strict=False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ).run_decompositions({}) self.assertExpectedInline( ep.graph_module.code.strip(), """\ def forward(self, token, obj, x): +<<<<<<< HEAD with_effects = torch.ops.higher_order.with_effects(token, torch.ops.higher_order.call_torchbind, obj, 'get'); token = None getitem = with_effects[0] getitem_1 = with_effects[1]; with_effects = None @@ -1427,6 +1533,16 @@ def forward(self, token, obj, x): self.assertEqual(eager_out, ep.module()(test_obj, x)) @parametrize("backend", ["eager", "aot_eager", "inductor"]) +======= + with_effects = torch.ops.higher_order.with_effects(token, torch.ops.higher_order.call_torchbind, obj, 'get'); token = obj = None + getitem = with_effects[0] + getitem_1 = with_effects[1]; with_effects = None + add_3 = torch.ops.aten.add.Tensor(getitem_1, x); getitem_1 = x = None + return (getitem, add_3)""", # noqa: B950 + ) + + @parametrize("backend", ["eager", "aot_eager"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_compile_error_on_non_fakified_method(self, backend): if backend == "eager": backend = EagerAndRecordGraphs() @@ -1443,11 +1559,19 @@ def f(tq, x): x = torch.randn(2, 3) with self.assertRaisesRegex( +<<<<<<< HEAD RuntimeError, "FakeScriptObject missing method implementation" ): torch.compile(f, backend=backend)(_empty_tensor_queue(), x) @parametrize("backend", ["eager", "aot_eager", "inductor"]) +======= + RuntimeError, "FakeScriptObject doesn't define method" + ): + torch.compile(f, backend=backend)(_empty_tensor_queue(), x) + + @parametrize("backend", ["eager", "aot_eager"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_compile_obj_as_hop_input(self, backend): def f(tq, x): def fn(tq, x): @@ -1463,7 +1587,11 @@ def fn(tq, x): torch.compile(f, backend=backend)(_empty_tensor_queue(), x), ) +<<<<<<< HEAD @parametrize("backend", ["eager", "aot_eager", "inductor"]) +======= + @parametrize("backend", ["eager", "aot_eager"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_compile_obj_closure(self, backend): def f(x): def inner_f(x): @@ -1478,7 +1606,11 @@ def inner_f(x): x = torch.randn(3, 2) _assertEqualScriptObject(self, f(x), opt_f(x)) +<<<<<<< HEAD @parametrize("backend", ["eager", "aot_eager", "inductor"]) +======= + @parametrize("backend", ["eager", "aot_eager"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_compile_global_obj(self, backend): global _TENSOR_QUEUE_GLOBAL_TEST _TENSOR_QUEUE_GLOBAL_TEST = _empty_tensor_queue() @@ -1514,7 +1646,11 @@ def f(tq, x): ) self.assertEqual(cnt.frame_count, 4) +<<<<<<< HEAD @parametrize("backend", ["eager", "aot_eager", "inductor"]) +======= + @parametrize("backend", ["eager", "aot_eager"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_compile_obj_attributes(self, backend): if backend == "eager": backend = EagerAndRecordGraphs() @@ -1546,7 +1682,11 @@ def forward(self, L_self_tq : torch.ScriptObject, L_x_ : torch.Tensor): return (call_torchbind_1,)""", ) +<<<<<<< HEAD @parametrize("backend", ["eager", "aot_eager", "inductor"]) +======= + @parametrize("backend", ["eager", "aot_eager"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_compile_obj_torchbind_op(self, backend): def f(tq, x): torch.ops._TorchScriptTesting.queue_push(tq, x.cos()) @@ -1561,6 +1701,7 @@ def f(tq, x): self, f(_empty_tensor_queue(), x), opt_f(_empty_tensor_queue(), x) ) +<<<<<<< HEAD @requires_cuda_and_triton @parametrize("device", ["cpu", "cuda"]) @parametrize("backend", ["eager", "aot_eager", "inductor"]) @@ -1599,6 +1740,8 @@ def forward(self, x, tq): self, ep.module()(x, _empty_tensor_queue()), mod(x, _empty_tensor_queue()) ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skipIfTorchDynamo("torchbind not supported with dynamo yet") class TestRegisterFakeClass(TestCase): diff --git a/test/export/test_unflatten.py b/test/export/test_unflatten.py index 5e1872c249ed7..d003cbed01c66 100644 --- a/test/export/test_unflatten.py +++ b/test/export/test_unflatten.py @@ -178,6 +178,7 @@ def forward(self, x): id(getattr(unflattened_module.sub_net, "2")), ) +<<<<<<< HEAD def test_assert_tensor_metadata_stack(self): class N(torch.nn.Module): def __init__(self): @@ -211,6 +212,8 @@ def forward(self, x, y): inp = (torch.randn(3), torch.randn(3)) self.assertTrue(torch.allclose(uep(*inp), m(*inp))) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test") @skipIfTorchDynamo("Non strict mode is not meant to run with dynamo") def test_unflatten_preserve_signature(self): @@ -359,10 +362,16 @@ def forward(self, x): export_module = torch.export.export(Mod(), (torch.randn((2, 3)),), strict=True) with self.assertRaisesRegex( +<<<<<<< HEAD AssertionError, escape("Guard failed: x.size()[0] == 2"), ): # expected 2, but got 6 +======= + RuntimeError, + escape("Expected input at *args[0].shape[0] to be equal to 2, but got 6"), + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) export_module.module()(torch.randn(6, 6)) unflattened = unflatten(export_module) @@ -665,6 +674,11 @@ def forward(self, x): export_module.module(), unflattened, (torch.randn((2, 3)),) ) +<<<<<<< HEAD +======= + # skip connection is not supported yet + @unittest.expectedFailure +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_unflatten_skipped_call_module(self): class C(torch.nn.Module): def __init__(self): @@ -702,6 +716,7 @@ def forward(self, x): # The call chain looks like this: # A -> B -> C -> A.d ep = torch.export.export(a, (torch.randn(3),), strict=False) +<<<<<<< HEAD ufm = unflatten(ep) self.assertExpectedInline( str(ufm.graph_module.code).strip(), @@ -726,6 +741,9 @@ def forward(self, x): sin = torch.ops.aten.sin.default(cos); cos = None return sin""", ) +======= + unflatten(ep) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_nested_leaf_non_strict(self): class Leaf(torch.nn.Module): @@ -934,7 +952,11 @@ def forward(self, x, y): fn_count_sym_size = lambda graph: [node.target for node in graph.nodes].count( torch.ops.aten.sym_size.int ) +<<<<<<< HEAD self.assertEqual(fn_count_sym_size(unflat.graph), 1) +======= + self.assertEqual(fn_count_sym_size(unflat.graph), 3) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(fn_count_sym_size(unflat.m1.graph), 1) self.assertEqual(fn_count_sym_size(unflat.m2.graph), 0) diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 5a962dfa57c05..c2c325d7b070a 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -139,8 +139,11 @@ # These ops are defined in torch/csrc/distributed/c10d/Ops.cpp # TODO: add back restriction when c10d ops can be exported ("c10d::.*", datetime.date(9999, 1, 1)), +<<<<<<< HEAD # Previously MPS_only did not support backward ("aten::_fused_rms_norm", datetime.date(2025, 12, 30)), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] ALLOW_LIST_COMPILED = [ diff --git a/test/functorch/attn_ft.py b/test/functorch/attn_ft.py index 7038ded094904..374e13e802efa 100644 --- a/test/functorch/attn_ft.py +++ b/test/functorch/attn_ft.py @@ -126,7 +126,11 @@ def forward( if self.position_embedding_type == "relative_key": # these were einsum ops in the positional code because they are not easy to fit to existing matmul operators +<<<<<<< HEAD # even though they are degenerate matmuls +======= + # eventhough they are degenerate matmuls +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) relative_position_scores = (q * positional_embedding).sum(features) attention_scores = attention_scores + relative_position_scores elif self.position_embedding_type == "relative_key_query": diff --git a/test/functorch/test_ac.py b/test/functorch/test_ac.py index fde84b6683edf..bd7ad89e89930 100644 --- a/test/functorch/test_ac.py +++ b/test/functorch/test_ac.py @@ -6,7 +6,11 @@ import torch import torch._functorch.config as config from torch.testing._internal.common_utils import run_tests, TEST_WITH_ROCM, TestCase +<<<<<<< HEAD from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON +======= +from torch.testing._internal.inductor_utils import HAS_CUDA +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.utils._triton import has_triton from torch.utils.checkpoint import checkpoint from torch.utils.flop_counter import FlopCounterMode, register_flop_formula @@ -405,5 +409,9 @@ def call(): if __name__ == "__main__": # I'm using the cuda memory allocator to verify memory allocations +<<<<<<< HEAD if HAS_CUDA_AND_TRITON and not TEST_WITH_ROCM: +======= + if HAS_CUDA and not TEST_WITH_ROCM: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) run_tests() diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 7f365b3891763..2263435df0f88 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -51,14 +51,20 @@ from torch._dynamo.utils import counters from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache from torch._functorch.aot_autograd import ( +<<<<<<< HEAD _aot_export_function, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) aot_export_joint_simple, aot_export_module, SerializableAOTDispatchCompiler, ) from torch._higher_order_ops.out_dtype import out_dtype from torch._inductor.codecache import compiled_fx_graph_hash +<<<<<<< HEAD from torch._inductor.custom_graph_pass import CustomPartitionerFn +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.output_code import MockFXGraphCacheOutput from torch._subclasses.fake_tensor import DynamicOutputShapeException, FakeTensorMode from torch.fx.experimental.proxy_tensor import is_sym_node @@ -99,7 +105,10 @@ ) from torch.testing._internal.subclasses import WrapperSubclass from torch.testing._internal.two_tensor import TwoTensor, TwoTensorMode +<<<<<<< HEAD from torch.utils._python_dispatch import TorchDispatchMode +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) USE_TORCHVISION = False @@ -693,6 +702,7 @@ def f(a, b): ] self.verify_aot_autograd(f, inp, keep_inp_mutations=True) +<<<<<<< HEAD def _compile_autocast(self, device, *, forward_autocast): with torch.library._scoped_library("mylib", "FRAGMENT") as m: m.define("foo(Tensor x) -> Tensor") @@ -767,6 +777,8 @@ def test_backward_pass_autocast_custom(self): self.assertEqual(out, torch.zeros_like(out)) self.assertEqual(grad, torch.ones_like(grad)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skipIfDynamoInput( "Test doesn't make sense with dynamo, which changes order of mutations" ) @@ -993,6 +1005,7 @@ def f(x): ): new_out.sum().backward() +<<<<<<< HEAD def test_nested_subclasses_non_homogenous(self): def f(x): x_elem = x.elem @@ -1097,6 +1110,8 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): self.assertTrue(isinstance(aa2.grad, ConstantExtraMetadataTensor)) self.assertTrue(isinstance(aa2.grad.elem, TwoTensor)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case") def test_custom_tensor_metadata(self): def f(x): @@ -2478,7 +2493,11 @@ def f(a, b): return a.mul(3), b.mul(4) inp = [ +<<<<<<< HEAD # First inp doesn't require grad, but we switch it on +======= + # First inp doesnt require grad, but we switch it on +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.ones(3, 3, requires_grad=False), torch.ones(3, 3, requires_grad=True), ] @@ -4838,6 +4857,7 @@ def f(x, y): inps = [torch.randn(2, 2), torch.ones(2)] gm, _ = aot_export_module(M(), inps, trace_joint=False, pre_dispatch=True) self.assertExpectedInline( +<<<<<<< HEAD normalize_gm(gm.print_readable(False, expanded_def=True)), """\ class (torch.nn.Module): @@ -4846,6 +4866,12 @@ def forward( arg0_1: "f32[2, 2]", # PlainAOTInput(idx=0) arg1_1: "f32[2]", # PlainAOTInput(idx=1) ): +======= + normalize_gm(gm.print_readable(False)), + """\ +class (torch.nn.Module): + def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2]"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sum_1: "f32[]" = torch.ops.aten.sum.default(arg0_1) gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None @@ -4856,10 +4882,14 @@ def forward( add: "f32[2, 2]" = torch.ops.aten.add.Tensor(getitem, 3) add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(getitem, 4); getitem = None +<<<<<<< HEAD return ( add, # PlainAOTOutput(idx=0) add_1, # PlainAOTOutput(idx=1) ) +======= + return (add, add_1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class true_graph_0(torch.nn.Module): def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2]"): @@ -4933,6 +4963,7 @@ def f(x, y): inps = [torch.randn(2, 2), torch.ones(2)] gm, _ = aot_export_module(M(), inps, trace_joint=False, pre_dispatch=True) self.assertExpectedInline( +<<<<<<< HEAD normalize_gm(gm.print_readable(False, expanded_def=True)), """\ class (torch.nn.Module): @@ -4941,6 +4972,12 @@ def forward( arg0_1: "f32[2, 2]", # PlainAOTInput(idx=0) arg1_1: "f32[2]", # PlainAOTInput(idx=1) ): +======= + normalize_gm(gm.print_readable(False)), + """\ +class (torch.nn.Module): + def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2]"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cos: "f32[2, 2]" = torch.ops.aten.cos.default(arg0_1); arg0_1 = None _set_grad_enabled = torch._C._set_grad_enabled(True); _set_grad_enabled = None @@ -4951,9 +4988,13 @@ def forward( sum_1: "f32[]" = torch.ops.aten.sum.default(getitem_2); getitem_2 = None add: "f32[2, 2]" = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None +<<<<<<< HEAD return ( add, # PlainAOTOutput(idx=0) ) +======= + return (add,) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class body_graph_0(torch.nn.Module): def forward(self, arg0_1: "f32[2]", arg1_1: "f32[2]"): @@ -5112,6 +5153,7 @@ def forward(self, x): for node in fx_g.graph.nodes: node.meta.pop("stack_trace", None) self.assertExpectedInline( +<<<<<<< HEAD fx_g.print_readable(print_output=False, expanded_def=True), """\ class (torch.nn.Module): @@ -5126,6 +5168,12 @@ def forward( arg6_1: "i64[]", arg7_1: "f32[1, 1, 3, 3]", ): +======= + fx_g.print_readable(print_output=False), + """\ +class (torch.nn.Module): + def forward(self, arg0_1: "f32[3, 1, 1, 1]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]", arg4_1: "f32[3]", arg5_1: "f32[3]", arg6_1: "i64[]", arg7_1: "f32[1, 1, 3, 3]"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # No stacktrace found for following nodes convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(arg7_1, arg0_1, arg1_1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); arg1_1 = None add: "i64[]" = torch.ops.aten.add.Tensor(arg6_1, 1); arg6_1 = None @@ -5207,6 +5255,7 @@ def forward( for node in fx_g_inference.graph.nodes: node.meta.pop("stack_trace", None) self.assertExpectedInline( +<<<<<<< HEAD fx_g_inference.print_readable(print_output=False, expanded_def=True), """\ class (torch.nn.Module): @@ -5221,6 +5270,12 @@ def forward( arg6_1: "i64[]", # PlainAOTInput(idx=6) arg7_1: "f32[1, 1, 3, 3]", # PlainAOTInput(idx=7) ): +======= + fx_g_inference.print_readable(print_output=False), + """\ +class (torch.nn.Module): + def forward(self, arg0_1: "f32[3, 1, 1, 1]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]", arg4_1: "f32[3]", arg5_1: "f32[3]", arg6_1: "i64[]", arg7_1: "f32[1, 1, 3, 3]"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # No stacktrace found for following nodes convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(arg7_1, arg0_1, arg1_1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); arg7_1 = arg0_1 = arg1_1 = None add: "i64[]" = torch.ops.aten.add.Tensor(arg6_1, 1); arg6_1 = None @@ -5233,6 +5288,7 @@ def forward( detach: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu); relu = None detach_1: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach); detach = None detach_2: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_1); detach_1 = None +<<<<<<< HEAD return ( getitem_3, # InputMutationAOTOutput(mutated_input=PlainAOTInput(idx=4)) getitem_4, # InputMutationAOTOutput(mutated_input=PlainAOTInput(idx=5)) @@ -5240,6 +5296,9 @@ def forward( sum_1, # PlainAOTOutput(idx=0) detach_2, # PlainAOTOutput(idx=1) ) +======= + return (getitem_3, getitem_4, add, sum_1, detach_2) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """, # noqa: B950 ) # Some important characteristics of the exported graph below: @@ -5365,6 +5424,7 @@ def forward(self, x): mod = M() inp = torch.randn(2, requires_grad=True) +<<<<<<< HEAD gm, _ = aot_export_module(mod, [inp], trace_joint=False) self.assertExpectedInline( str(gm.graph).strip(), @@ -5374,6 +5434,13 @@ def forward(self, x): %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, 4), kwargs = {}) return (add, add)""", ) +======= + with self.assertRaisesRegex( + RuntimeError, + "Found a graph input that requires gradients, and received a mutation", + ): + aot_export_module(mod, [inp], trace_joint=False) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_aot_export_input_mutation_on_parameter_banned(self): def fn(p, x): @@ -5384,6 +5451,7 @@ def fn(p, x): inp = torch.randn(2) with self.assertRaisesRegex( RuntimeError, +<<<<<<< HEAD "aot_export_joint_simple does not support input mutations. ViewAndMutationMeta", ): aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False) @@ -5404,6 +5472,13 @@ def fn(p, x): %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %arg1_1), kwargs = {}) return (mul, add)""", ) +======= + "Found a graph input that requires gradients, and received a mutation", + ): + aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False) + aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True) + aot_export_module(mod, [inp], trace_joint=False) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_aot_export_synthetic_bases_banned(self): def fn(p, x, y): @@ -5554,6 +5629,7 @@ def forward(self): return (full_1,)""", # noqa: B950 ) +<<<<<<< HEAD def test_aot_export_input_mutation(self): def f(x, buf): buf.add_(1) @@ -5585,6 +5661,8 @@ def forward(self, primals, tangents): return pytree.tree_unflatten([mul, mul_1, None], self._out_spec)""", ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestPartitioning(AOTTestCase): @unittest.skipIf(not USE_NETWORKX, "networkx not available") @@ -5689,6 +5767,7 @@ def forward(self, primals_1, tangents_1): ) @unittest.skipIf(not USE_NETWORKX, "networkx not available") +<<<<<<< HEAD def test_custom_partitioner_fn(self): class MyCustomPartitionerFn(CustomPartitionerFn): def __init__(self): @@ -5732,6 +5811,8 @@ def forward(self, primals_1, tangents_1): ) @unittest.skipIf(not USE_NETWORKX, "networkx not available") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_min_cut_partitioner_save_shape(self): def f(x): s = x.sum(dim=1) @@ -5877,7 +5958,11 @@ def f(a, b, c, d): _, fw_graph_out_nodes = get_ins_outs(fw_graph) self.assertEqual( # fw outputs include b.size() which expands to 2 symints, +<<<<<<< HEAD # then 4 tensors (transposes of matrices used for mm) are saved +======= + # then 4 tensors (transposes of matricies used for mm) are saved +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # finally 3 symints are saved [False, True, True, False, False] + [False] * 4 + [True] * 3, [is_sym_node(n) for n in fw_graph_out_nodes], @@ -6207,7 +6292,11 @@ def f(a, b): self.assertEqual(b_test.a, b_ref.a) self.assertEqual(b_test.b, b_ref.b) +<<<<<<< HEAD # NOTE: we need to use b in our gradient compute. Otherwise we will need to recompile the backward. +======= + # NOTE: we need to use b in our gradient compute. Otherwise we will need to recompile teh backward. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) (b_ref * out_ref).sum().backward() (b_test * out_test).sum().backward() # Both grad_inputs are TwoTensor @@ -7676,7 +7765,11 @@ def test_saved_tensors_hooks_donated_buffers(self): "pack_hash", "unpack_hash", ) +<<<<<<< HEAD logger_name = "torch._functorch._aot_autograd.graph_compile" +======= + logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class SAF(torch.autograd.Function): @staticmethod @@ -8157,6 +8250,7 @@ def _inps(): self.assertEqual(ref_inps_after_fw, inps_after_fw) self.assertEqual(ref_inps_after_bw, inps_after_bw) +<<<<<<< HEAD def test_mutation_of_input_in_fw_and_bw(self): class AF(torch.autograd.Function): @staticmethod @@ -8204,6 +8298,8 @@ def sc_inps(): y.sum().backward() self.assertEqual(ref, inplace) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class MockFXGraphCache: """ @@ -8292,6 +8388,10 @@ def run_autograd( { "enable_autograd_cache": True, "strict_autograd_cache": True, +<<<<<<< HEAD +======= + "view_replay_for_aliased_outputs": False, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } ) @torch._inductor.config.patch("fx_graph_cache", True) diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index 81aa26c2be8a7..5015ad069ff85 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -7,7 +7,11 @@ import torch.utils._pytree as pytree from functorch.experimental import control_flow from functorch.experimental.control_flow import cond +<<<<<<< HEAD from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm +======= +from torch._dynamo.testing import normalize_gm +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._higher_order_ops.associative_scan import ( _fake_associative_scan, associative_scan, @@ -144,7 +148,11 @@ def complex_pointwise(x, y): } def non_pointwise(x: torch.Tensor, y: torch.Tensor): +<<<<<<< HEAD W = torch.arange(4, dtype=torch.float, device=x.device).view(2, 2) +======= + W = torch.diag(torch.ones(2, device=x.device)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return x @ W + y @ W def RNN(x: torch.Tensor, y: torch.Tensor): @@ -395,6 +403,7 @@ def body_fn(a, b, c1, c2, c3, c0, u0, x): ([torch.randn(3, 3)], {"x": torch.randn(3, 3), "y": torch.randn(3, 3)}), ), ), +<<<<<<< HEAD "int_carry": (int_carry, (torch.randn(2, 3),)), "pytree_int_carry": ( pytree_int_carry, @@ -403,6 +412,16 @@ def body_fn(a, b, c1, c2, c3, c0, u0, x): "const_and_symint_output": ( const_and_symint_output, (torch.randn(2, 3),), +======= + "int_carry": (int_carry, (torch.randn(2, 3, requires_grad=True),)), + "pytree_int_carry": ( + pytree_int_carry, + (torch.randn(2, 3, requires_grad=True),), + ), + "const_and_symint_output": ( + const_and_symint_output, + (torch.randn(2, 3, requires_grad=True),), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), } @@ -737,7 +756,12 @@ def forward(self, pred_1, x_1): getitem_1 = cond_1[0]; getitem_1 = None getitem_2 = cond_1[1] getitem_3 = cond_1[2]; getitem_3 = None +<<<<<<< HEAD getitem_4 = cond_1[3]; cond_1 = getitem_4 = None +======= + getitem_4 = cond_1[3]; getitem_4 = None + getitem_5 = cond_1[4]; cond_1 = getitem_5 = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return (getitem_2,)""", # noqa: B950 ) @@ -853,7 +877,14 @@ def forward(self, pred_1, a_1, b_1, c_1): cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (a_1, b_1, sym_size_int, sym_size_int_1, c_1, sym_size_int_2, ones_like)); pred_1 = true_graph_1 = false_graph_1 = a_1 = b_1 = sym_size_int = sym_size_int_1 = c_1 = sym_size_int_2 = ones_like = None getitem_1 = cond_1[0] getitem_2 = cond_1[1] +<<<<<<< HEAD getitem_3 = cond_1[2]; cond_1 = getitem_3 = None +======= + getitem_3 = cond_1[2]; getitem_3 = None + getitem_4 = cond_1[3]; getitem_4 = None + getitem_5 = cond_1[4]; getitem_5 = None + getitem_6 = cond_1[5]; cond_1 = getitem_6 = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return (getitem_1, getitem_2)""", # noqa: B950 ) # Forward @@ -873,7 +904,11 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1): clone = torch.ops.aten.clone.default(arg6_1) clone_1 = torch.ops.aten.clone.default(arg6_1); arg6_1 = None zeros_like = torch.ops.aten.zeros_like.default(arg4_1, pin_memory = False); arg4_1 = None +<<<<<<< HEAD return [clone, clone_1, zeros_like]""", +======= + return [clone, clone_1, None, None, zeros_like, None]""", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def test_cond_autograd_pytree_input(self): @@ -1298,11 +1333,21 @@ def _extract_tensor_metadata_except_requires_grad(arg): return cond_outputs, cond_inputs +<<<<<<< HEAD +======= + # TODO: The compile_mode = `compile_dynamic_shape` raises the Error + # torch._inductor.exc.LoweringException: NotImplementedError: get_size() is not + # implemented by ! +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skipIfTorchDynamo("don't test compile on compile") @unittest.skipIf(not SM70OrLater, "triton") @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") @parametrize("compile_mode", ["compile_dynamic_shape"]) @parametrize("scalar", [False]) +<<<<<<< HEAD +======= + @unittest.expectedFailure +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_cond_autograd_zeros_unused_branch_complex_compile_fail( self, compile_mode, scalar ): @@ -1401,6 +1446,10 @@ def f(x, y): f, (torch.ones(3, 4, 5), torch.ones(4, 4, 5)), torch.ones(5) ) +<<<<<<< HEAD +======= + @torch._dynamo.config.patch(capture_scalar_outputs=True) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_map_illegal_outputs(self): def f(x, y): return x.item() @@ -2004,7 +2053,11 @@ def test_scan_complex_pytree(self, reverse, compile_mode, device, autograd): if autograd: self.check_autograd(result, expected_result, (init, inp)) +<<<<<<< HEAD # TODO: Does not work because of the usage of vmap within associative_scan +======= + # TODO: Does not work because of the usage of vmap witin associative_scan +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # The paT206899919 rameterization is commented out for the moment and the test is marked with expected fail # Fails with: AssertionError: scan is not an OpOverload @skipIfRocm(msg="Unsupported on ROCM yet") @@ -2737,6 +2790,11 @@ def fct_pointwise_different_carry(x, y): @skipIfNoDynamoSupport @skipIfCrossRef # Arg order changes with crossref def test_scan_pytree_output(self): +<<<<<<< HEAD +======= + from torch._dynamo.testing import EagerAndRecordGraphs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x = torch.randn(3, 10, 2, device=torch.device("cpu")) init = torch.randn(1, 10, 2, device=torch.device("cpu")) @@ -2947,6 +3005,7 @@ def RNN(x: torch.Tensor, y: torch.Tensor): params, ) +<<<<<<< HEAD @requires_cuda @skipIfTorchDynamo("not a dynamo test") @unittest.skipIf(not SM70OrLater, "triton") @@ -3110,6 +3169,8 @@ def run_test_and_get_grads_loss(model, initial_hs, inputs): compiled_loss, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(not SM70OrLater, "triton") @requires_cuda @parametrize("reverse", [False, True]) @@ -3447,6 +3508,11 @@ def f(fct, init, xs): @skipIfNoDynamoSupport @skipIfCrossRef # Arg order changes with crossref def test_scan_simple_graph(self): +<<<<<<< HEAD +======= + from torch._dynamo.testing import EagerAndRecordGraphs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x = torch.randn(3, 10, 2, device=torch.device("cpu")) init = torch.randn(1, 10, 2, device=torch.device("cpu")) @@ -3717,6 +3783,7 @@ def setUp(self): torch._dynamo.reset() super().setUp() +<<<<<<< HEAD def _check_autograd(self, result, result_exp, autograd_param): grad_param = [p for p in autograd_param if p.requires_grad] @@ -3740,15 +3807,21 @@ def _check_autograd(self, result, result_exp, autograd_param): self.assertEqual(grads, expected_grads, atol=6e-05, rtol=6e-06) def _run_test(self, model, model_fake, inputs, autograd_param=None): +======= + def _run_test(self, model, model_fake, inputs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) result = model(inputs) result_exp = model_fake(inputs) self.assertEqual(result, result_exp) +<<<<<<< HEAD if autograd_param is not None and any( par.requires_grad for par in autograd_param ): self._check_autograd(result, result_exp, autograd_param) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Return the result of the functions under test for further investigations return result @@ -3763,7 +3836,10 @@ def _prepare_fake_kwargs(self, original_kwargs): @parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"]) @parametrize("combine_mode", ["pointwise", "generic"]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) +<<<<<<< HEAD @parametrize("autograd", [False, True]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Skipping the combination of combine_mode=pointwise and device=cpu # as the current implementation of pointwise does only support CUDA device # Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape @@ -3779,6 +3855,7 @@ def _prepare_fake_kwargs(self, original_kwargs): ) ), ) +<<<<<<< HEAD # # Skipping this combination as there is a CPP compilation failure that # # may be unrelated to associative_scan itself. There is a dedicated tests for # # this case below. @@ -3795,6 +3872,12 @@ def test_associative_scan_compile( self, combine_mode, reverse, compile_mode, device, autograd ): x = torch.randn(3, 10, 2, device=device, requires_grad=autograd) +======= + def test_associative_scan_compile( + self, combine_mode, reverse, compile_mode, device + ): + x = torch.randn(3, 10, 2, device=device) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kwargs = { "dim": 0, "reverse": reverse, @@ -3806,7 +3889,10 @@ def test_associative_scan_compile( model=AssociativeScanModels.Simple(**kwargs), model_fake=AssociativeScanModels.Simple(**kwargs_fake), inputs=x, +<<<<<<< HEAD autograd_param=None if not autograd else (x,), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if not reverse: @@ -3816,9 +3902,13 @@ def test_associative_scan_compile( self.assertEqual(results, results_torch) # Jax Examples +<<<<<<< HEAD x = torch.arange( 0, 4, device=device, dtype=torch.float32, requires_grad=autograd ) +======= + x = torch.arange(0, 4, device=device) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kwargs = { "dim": 0, "reverse": reverse, @@ -3831,6 +3921,7 @@ def test_associative_scan_compile( model=AssociativeScanModels.CombineFn(**kwargs), model_fake=AssociativeScanModels.CombineFn(**kwargs_fake), inputs=x, +<<<<<<< HEAD autograd_param=None if not autograd else (x,), ) @@ -3838,6 +3929,14 @@ def test_associative_scan_compile( results_torch = torch.tensor([0.0, 1.0, 3.0, 6.0], dtype=torch.float32) else: results_torch = torch.tensor([6.0, 6.0, 5.0, 3.0], dtype=torch.float32) +======= + ) + + if not reverse: + results_torch = torch.tensor([0.0, 1.0, 3.0, 6.0], dtype=torch.int64) + else: + results_torch = torch.tensor([6.0, 6.0, 5.0, 3.0], dtype=torch.int64) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(result, results_torch) @@ -3847,7 +3946,10 @@ def test_associative_scan_compile( @parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"]) @parametrize("combine_mode", ["pointwise", "generic"]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) +<<<<<<< HEAD @parametrize("autograd", [False, True]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Skipping the combination of combine_mode=pointwise and device=cpu # as the current implementation of pointwise does only support CUDA device # Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape @@ -3863,9 +3965,13 @@ def test_associative_scan_compile( ) ), ) +<<<<<<< HEAD def test_associative_scan_dim( self, combine_mode, compile_mode, reverse, device, autograd ): +======= + def test_associative_scan_dim(self, combine_mode, compile_mode, reverse, device): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import random random.seed(1234) @@ -3876,7 +3982,11 @@ def test_associative_scan_dim( torch._dynamo.reset() shapes = [random.randint(1, 9) for _ in range(num_dim)] rnd_scan_dim = random.randint(0, num_dim - 1) +<<<<<<< HEAD x = torch.randn(*shapes, device=device, requires_grad=autograd) +======= + x = torch.randn(*shapes, device=device) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kwargs = { "dim": rnd_scan_dim, @@ -3889,7 +3999,10 @@ def test_associative_scan_dim( model=AssociativeScanModels.Simple(**kwargs), model_fake=AssociativeScanModels.Simple(**kwargs_fake), inputs=x, +<<<<<<< HEAD autograd_param=None if not autograd else (x,), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if not reverse: @@ -3928,7 +4041,10 @@ def test_associative_scan_dim_shape_failure(self, compile_mode, combine_mode): @parametrize("combine_mode", ["pointwise", "generic"]) @parametrize("reverse", [False, True]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) +<<<<<<< HEAD @parametrize("autograd", [False, True]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Skipping the combination of combine_mode=pointwise and device=cpu # as the current implementation of pointwise does only support CUDA device # Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape @@ -3944,11 +4060,17 @@ def test_associative_scan_dim_shape_failure(self, compile_mode, combine_mode): ) ), ) +<<<<<<< HEAD def test_associative_scan_tuple( self, compile_mode, combine_mode, reverse, device, autograd ): x = torch.randn(3, 2, 2, device=device, requires_grad=autograd) y = torch.randn(3, 2, 2, device=device, requires_grad=autograd) +======= + def test_associative_scan_tuple(self, compile_mode, combine_mode, reverse, device): + x = torch.randn(3, 2, 2, device=device) + y = torch.randn(3, 2, 2, device=device) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inp = (x, y) kwargs = { @@ -3963,12 +4085,16 @@ def test_associative_scan_tuple( model=AssociativeScanModels.CombineFn(**kwargs), model_fake=AssociativeScanModels.CombineFn(**kwargs_fake), inputs=inp, +<<<<<<< HEAD autograd_param=None if not autograd else inp, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @unittest.skipIf(not SM70OrLater, "triton") @requires_cuda @parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"]) +<<<<<<< HEAD @parametrize("reverse", [False, True]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) @parametrize("autograd", [False, True]) @@ -3976,6 +4102,15 @@ def test_associative_scan_expand_in_combine_fn( self, compile_mode, reverse, device, autograd ): x = torch.randn(3, 2, 2, device=device, requires_grad=autograd) +======= + @parametrize("combine_mode", ["pointwise", "generic"]) + @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_associative_scan_expand_in_combine_fn( + self, compile_mode, combine_mode, reverse, device + ): + x = torch.randn(3, 2, 2, device=device) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def combine_fn(x, y): return x * torch.sum(y, -1).expand(x.shape) @@ -3992,7 +4127,10 @@ def combine_fn(x, y): model=AssociativeScanModels.CombineFn(**kwargs), model_fake=AssociativeScanModels.CombineFn(**kwargs_fake), inputs=x, +<<<<<<< HEAD autograd_param=None if not autograd else (x,), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @unittest.skipIf(not SM70OrLater, "triton") @@ -4000,6 +4138,7 @@ def combine_fn(x, y): @parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"]) @parametrize("reverse", [False, True]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) +<<<<<<< HEAD @parametrize("autograd", [False, True]) def test_associative_scan_non_contiguous_tensor( self, compile_mode, reverse, device, autograd @@ -4009,6 +4148,12 @@ def test_associative_scan_non_contiguous_tensor( .view(10, 3) .t() ) +======= + def test_associative_scan_non_contiguous_tensor( + self, compile_mode, reverse, device + ): + x = torch.arange(30, device=device).view(10, 3).t() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert not x.is_contiguous() kwargs = { @@ -4023,7 +4168,10 @@ def test_associative_scan_non_contiguous_tensor( model=AssociativeScanModels.CombineFn(**kwargs), model_fake=AssociativeScanModels.CombineFn(**kwargs_fake), inputs=x, +<<<<<<< HEAD autograd_param=None if not autograd else (x,), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @unittest.skipIf(not SM70OrLater, "triton") @@ -4032,7 +4180,10 @@ def test_associative_scan_non_contiguous_tensor( @parametrize("combine_mode", ["pointwise", "generic"]) @parametrize("reverse", [False, True]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) +<<<<<<< HEAD @parametrize("autograd", [False, True]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Skipping the combination of combine_mode=pointwise and device=cpu # as the current implementation of pointwise does only support CUDA device # Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape @@ -4049,11 +4200,19 @@ def test_associative_scan_non_contiguous_tensor( ), ) def test_associative_scan_complex_pytree( +<<<<<<< HEAD self, compile_mode, combine_mode, reverse, device, autograd ): x = torch.randn(3, 2, 2, device=device, requires_grad=autograd) y = torch.randn(3, 2, 2, device=device, requires_grad=autograd) z = torch.randn(3, 2, 2, device=device, requires_grad=autograd) +======= + self, compile_mode, combine_mode, reverse, device + ): + x = torch.randn(3, 2, 2, device=device) + y = torch.randn(3, 2, 2, device=device) + z = torch.randn(3, 2, 2, device=device) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inp = {"i": x, "j": ([y], [{"o": z}])} kwargs = { @@ -4068,13 +4227,21 @@ def test_associative_scan_complex_pytree( model=AssociativeScanModels.CombineFn(**kwargs), model_fake=AssociativeScanModels.CombineFn(**kwargs_fake), inputs=inp, +<<<<<<< HEAD autograd_param=None if not autograd else (x, y, z), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @skipIfTorchDynamo("don't test compile on compile") @skipIfNoDynamoSupport @skipIfCrossRef # Arg order changes with crossref def test_associative_scan_pytree_output(self): +<<<<<<< HEAD +======= + from torch._dynamo.testing import EagerAndRecordGraphs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x = ( ( torch.randn(3, 10, 2, device=torch.device("cpu")), @@ -4122,6 +4289,7 @@ def forward(self, L_xs_0_0_: "f32[3, 10, 2]", L_xs_0_1_0_: "f32[3, 10, 2]", L_xs child_4: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_4, 0, 1, None, 2) child_5: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_5, 0, 1, None, 2) +<<<<<<< HEAD lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions = None _vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(1, 'error'); _vmap_increment_nesting = None @@ -4132,22 +4300,43 @@ def forward(self, L_xs_0_0_: "f32[3, 10, 2]", L_xs_0_1_0_: "f32[3, 10, 2]", L_xs _add_batch_dim_3: "f32[10, 2]" = torch._functorch.predispatch._add_batch_dim(child_3, 0, 1); child_3 = _add_batch_dim_3 = None _add_batch_dim_4: "f32[10, 2]" = torch._functorch.predispatch._add_batch_dim(child_4, 0, 1); child_4 = _add_batch_dim_4 = None _add_batch_dim_5: "f32[10, 2]" = torch._functorch.predispatch._add_batch_dim(child_5, 0, 1); child_5 = None +======= + lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None + + _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(1, 'error'); _vmap_increment_nesting = None + + _add_batch_dim: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child, 0, 1); child = None + _add_batch_dim_1: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child_1, 0, 1); child_1 = None + _add_batch_dim_2: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child_2, 0, 1); child_2 = _add_batch_dim_2 = None + _add_batch_dim_3: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child_3, 0, 1); child_3 = _add_batch_dim_3 = None + _add_batch_dim_4: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child_4, 0, 1); child_4 = _add_batch_dim_4 = None + _add_batch_dim_5: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child_5, 0, 1); child_5 = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) a: "f32[10, 2]" = _add_batch_dim + _add_batch_dim_5; _add_batch_dim = None b: "f32[10, 2]" = _add_batch_dim_1 - _add_batch_dim_5; _add_batch_dim_1 = _add_batch_dim_5 = None child_6: "f32[10, 2]" = a - b +<<<<<<< HEAD child_7: "f32[1, 10, 2]" = torch._functorch.predispatch._remove_batch_dim(a, 1, 1, 0); a = None child_8: "f32[1, 10, 2]" = torch._functorch.predispatch._remove_batch_dim(b, 1, 1, 0); b = None child_9: "f32[1, 10, 2]" = torch._functorch.predispatch._remove_batch_dim(child_6, 1, 1, 0); child_6 = None _vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +======= + child_7: "f32[1, 10, 2]" = torch._C._functorch._remove_batch_dim(a, 1, 1, 0); a = None + child_8: "f32[1, 10, 2]" = torch._C._functorch._remove_batch_dim(b, 1, 1, 0); b = None + child_9: "f32[1, 10, 2]" = torch._C._functorch._remove_batch_dim(child_6, 1, 1, 0); child_6 = None + + _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) child_10: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_3, 0, 2, None, 2) child_11: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_4, 0, 2, None, 2) child_12: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_5, 0, 2, None, 2) +<<<<<<< HEAD lazy_load_decompositions_1 = torch._functorch.predispatch.lazy_load_decompositions(); lazy_load_decompositions_1 = None _vmap_increment_nesting_1 = torch._functorch.predispatch._vmap_increment_nesting(1, 'error'); _vmap_increment_nesting_1 = None @@ -4158,17 +4347,37 @@ def forward(self, L_xs_0_0_: "f32[3, 10, 2]", L_xs_0_1_0_: "f32[3, 10, 2]", L_xs _add_batch_dim_9: "f32[10, 2]" = torch._functorch.predispatch._add_batch_dim(child_10, 0, 1); child_10 = _add_batch_dim_9 = None _add_batch_dim_10: "f32[10, 2]" = torch._functorch.predispatch._add_batch_dim(child_11, 0, 1); child_11 = _add_batch_dim_10 = None _add_batch_dim_11: "f32[10, 2]" = torch._functorch.predispatch._add_batch_dim(child_12, 0, 1); child_12 = None +======= + lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_1 = None + + _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(1, 'error'); _vmap_increment_nesting_1 = None + + _add_batch_dim_6: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child_7, 0, 1) + _add_batch_dim_7: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child_8, 0, 1) + _add_batch_dim_8: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child_9, 0, 1); _add_batch_dim_8 = None + _add_batch_dim_9: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child_10, 0, 1); child_10 = _add_batch_dim_9 = None + _add_batch_dim_10: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child_11, 0, 1); child_11 = _add_batch_dim_10 = None + _add_batch_dim_11: "f32[10, 2]" = torch._C._functorch._add_batch_dim(child_12, 0, 1); child_12 = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) a_1: "f32[10, 2]" = _add_batch_dim_6 + _add_batch_dim_11; _add_batch_dim_6 = None b_1: "f32[10, 2]" = _add_batch_dim_7 - _add_batch_dim_11; _add_batch_dim_7 = _add_batch_dim_11 = None child_13: "f32[10, 2]" = a_1 - b_1 +<<<<<<< HEAD child_14: "f32[1, 10, 2]" = torch._functorch.predispatch._remove_batch_dim(a_1, 1, 1, 0); a_1 = None child_15: "f32[1, 10, 2]" = torch._functorch.predispatch._remove_batch_dim(b_1, 1, 1, 0); b_1 = None child_16: "f32[1, 10, 2]" = torch._functorch.predispatch._remove_batch_dim(child_13, 1, 1, 0); child_13 = None _vmap_decrement_nesting_1 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None +======= + child_14: "f32[1, 10, 2]" = torch._C._functorch._remove_batch_dim(a_1, 1, 1, 0); a_1 = None + child_15: "f32[1, 10, 2]" = torch._C._functorch._remove_batch_dim(b_1, 1, 1, 0); b_1 = None + child_16: "f32[1, 10, 2]" = torch._C._functorch._remove_batch_dim(child_13, 1, 1, 0); child_13 = None + + _vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) slice_10: "f32[1, 10, 2]" = torch.ops.aten.slice(elem_3, 0, 0, 1); elem_3 = None cat: "f32[2, 10, 2]" = torch.cat([slice_10, child_14], dim = 0); slice_10 = child_14 = None @@ -4218,7 +4427,10 @@ def forward(self, L_xs_0_0_: "f32[3, 10, 2]", L_xs_0_1_0_: "f32[3, 10, 2]", L_xs @parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"]) @parametrize("reverse", [False, True]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) +<<<<<<< HEAD @parametrize("autograd", [False, True]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Skipping the combination of combine_mode=pointwise and device=cpu # as the current implementation of pointwise does only support CUDA device # Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape @@ -4235,7 +4447,11 @@ def forward(self, L_xs_0_0_: "f32[3, 10, 2]", L_xs_0_1_0_: "f32[3, 10, 2]", L_xs ), ) def test_associative_scan_downstream_scan_matmul( +<<<<<<< HEAD self, combine_mode, compile_mode, reverse, device, autograd +======= + self, combine_mode, compile_mode, reverse, device +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): def first_chain_fct(scan_fct, inp, **kwargs): o = scan_fct(get_scan_combine_fn("add", True), inp, **kwargs) @@ -4245,7 +4461,11 @@ def second_chain_fct(scan_fct, inp, **kwargs): W = torch.ones(2, 5, device=device) return inp @ W +<<<<<<< HEAD inp = torch.randn(3, 10, 2, device=device, requires_grad=autograd) +======= + inp = torch.randn(3, 10, 2, device=device) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kwargs = { "dim": 1, "reverse": reverse, @@ -4258,7 +4478,10 @@ def second_chain_fct(scan_fct, inp, **kwargs): model=AssociativeScanModels.ChainFn(**kwargs), model_fake=AssociativeScanModels.ChainFn(**kwargs_fake), inputs=inp, +<<<<<<< HEAD autograd_param=None if not autograd else (inp,), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @unittest.skipIf(not SM70OrLater, "triton") @@ -4267,7 +4490,10 @@ def second_chain_fct(scan_fct, inp, **kwargs): @parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"]) @parametrize("reverse", [False, True]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) +<<<<<<< HEAD @parametrize("autograd", [False, True]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Skipping the combination of combine_mode=pointwise and device=cpu # as the current implementation of pointwise does only support CUDA device # Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape @@ -4284,7 +4510,11 @@ def second_chain_fct(scan_fct, inp, **kwargs): ), ) def test_associative_scan_downstream_scan_scan( +<<<<<<< HEAD self, combine_mode, compile_mode, reverse, device, autograd +======= + self, combine_mode, compile_mode, reverse, device +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): def first_chain_fct(scan_fct, inp, **kwargs): o1 = scan_fct(get_scan_combine_fn("add", True), inp, **kwargs) @@ -4294,7 +4524,11 @@ def second_chain_fct(scan_fct, inp, **kwargs): o2 = scan_fct(get_scan_combine_fn("add", True), inp, **kwargs) return o2 +<<<<<<< HEAD inp = torch.randn(3, 10, 2, device=device, requires_grad=autograd) +======= + inp = torch.randn(3, 10, 2, device=device) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kwargs = { "dim": 1, @@ -4308,7 +4542,10 @@ def second_chain_fct(scan_fct, inp, **kwargs): model=AssociativeScanModels.ChainFn(**kwargs), model_fake=AssociativeScanModels.ChainFn(**kwargs_fake), inputs=inp, +<<<<<<< HEAD autograd_param=None if not autograd else (inp,), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @unittest.skipIf(not SM70OrLater, "triton") @@ -4318,7 +4555,10 @@ def second_chain_fct(scan_fct, inp, **kwargs): @parametrize("reverse_first", [False, True]) @parametrize("same_direction", [False, True]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) +<<<<<<< HEAD @parametrize("autograd", [False, True]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Skipping the combination of combine_mode=pointwise and device=cpu # as the current implementation of pointwise does only support CUDA device # Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape @@ -4334,6 +4574,7 @@ def second_chain_fct(scan_fct, inp, **kwargs): ) ), ) +<<<<<<< HEAD # Skipping the autograd=True because # associative_scan does currently not support gradients for lifted parameters @decorateIf( @@ -4348,6 +4589,10 @@ def test_associative_scan_downstream_scan_scan_different_dim( same_direction, device, autograd, +======= + def test_associative_scan_downstream_scan_scan_different_dim( + self, combine_mode, compile_mode, reverse_first, same_direction, device +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): reverse_second = reverse_first if same_direction else not reverse_first @@ -4359,7 +4604,11 @@ def second_chain_fct(scan_fct, inp, **kwargs): o2 = scan_fct(get_scan_combine_fn("add", True), inp, **kwargs) return o2 +<<<<<<< HEAD inp = torch.randn(3, 10, 2, device=device, requires_grad=autograd) +======= + inp = torch.randn(3, 10, 2, device=device) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kwargs = { "dim": [1, 0], @@ -4373,10 +4622,16 @@ def second_chain_fct(scan_fct, inp, **kwargs): model=AssociativeScanModels.ChainFn(**kwargs), model_fake=AssociativeScanModels.ChainFn(**kwargs_fake), inputs=inp, +<<<<<<< HEAD autograd_param=None if not autograd else (inp,), ) # TODO: Does not work because of the usage of vmap within associative_scan +======= + ) + + # TODO: Does not work because of the usage of vmap witin associative_scan +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO: Re-enable additional parameters again once this issues has been resolved @unittest.skipIf(not SM70OrLater, "triton") @requires_cuda @@ -4432,9 +4687,14 @@ def second_nested_fct(x, y): @parametrize("loop_type", ["for"]) @parametrize("reverse", [False, True]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) +<<<<<<< HEAD @parametrize("autograd", [False, True]) def test_associative_scan_loop_in_combine_fn( self, compile_mode, loop_type, reverse, device, autograd +======= + def test_associative_scan_loop_in_combine_fn( + self, compile_mode, loop_type, reverse, device +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): def combine_fn(x, y): cnt = torch.zeros_like(y[0, :]) @@ -4459,7 +4719,11 @@ def body_fn(ind, loop_val): cnt += torch.abs(y[ind]) return x * cnt +<<<<<<< HEAD inp = torch.randn(3, 10, 1, device=device, requires_grad=autograd) * 2 +======= + inp = torch.randn(3, 10, 1, device=device) * 2 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kwargs = { "dim": 0, @@ -4473,10 +4737,16 @@ def body_fn(ind, loop_val): model=AssociativeScanModels.CombineFn(**kwargs), model_fake=AssociativeScanModels.CombineFn(**kwargs_fake), inputs=inp, +<<<<<<< HEAD autograd_param=None if not autograd else (inp,), ) # TODO: Does not work because of the usage of vmap within associative_scan +======= + ) + + # TODO: Does not work because of the usage of vmap witin associative_scan +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO: Re-enable additional parameters again once this issues has been resolved @unittest.skipIf(not SM70OrLater, "triton") @requires_cuda @@ -4518,7 +4788,10 @@ def body_fn(ind, loop_val): @parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"]) @parametrize("reverse", [False, True]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) +<<<<<<< HEAD @parametrize("autograd", [False, True]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Skipping the combination of compile_mode=compile_dynamic_shape # as the current implementation does not support lifted arguments @decorateIf( @@ -4529,14 +4802,22 @@ def body_fn(ind, loop_val): or torch.version.hip ), ) +<<<<<<< HEAD def test_associative_scan_cond_in_combine_fn( self, compile_mode, reverse, device, autograd ): +======= + def test_associative_scan_cond_in_combine_fn(self, compile_mode, reverse, device): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def combine_fn(x, y): val = cond(torch.sum(y) > 0.0, lambda y: y.clone(), lambda y: 1.0 - y, (y,)) return x * val +<<<<<<< HEAD inp = torch.randn(3, 10, 1, device=device, requires_grad=autograd) +======= + inp = torch.randn(3, 10, 1, device=device) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kwargs = { "dim": 0, @@ -4550,10 +4831,16 @@ def combine_fn(x, y): model=AssociativeScanModels.CombineFn(**kwargs), model_fake=AssociativeScanModels.CombineFn(**kwargs_fake), inputs=inp, +<<<<<<< HEAD autograd_param=None if not autograd else (inp,), ) # TODO: Does not work because of the usage of vmap within associative_scan +======= + ) + + # TODO: Does not work because of the usage of vmap witin associative_scan +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO: Re-enable additional parameters again once this issues has been resolved @unittest.skipIf(not SM70OrLater, "triton") @requires_cuda @@ -4592,10 +4879,14 @@ def body(x, y): @parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"]) @parametrize("reverse", [False, True]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) +<<<<<<< HEAD @parametrize("autograd", [False, True]) def test_associative_scan_vmap_in_combine_fn( self, compile_mode, reverse, device, autograd ): +======= + def test_associative_scan_vmap_in_combine_fn(self, compile_mode, reverse, device): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def combine_fn(x, y): def body(x): return x**2 @@ -4604,7 +4895,11 @@ def body(x): y_new = mapped_body(y) return x + y_new +<<<<<<< HEAD inp = torch.randn(3, 10, 2, device=device, requires_grad=autograd) +======= + inp = torch.randn(3, 10, 2, device=device) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kwargs = { "dim": 0, @@ -4618,7 +4913,10 @@ def body(x): model=AssociativeScanModels.CombineFn(**kwargs), model_fake=AssociativeScanModels.CombineFn(**kwargs_fake), inputs=inp, +<<<<<<< HEAD autograd_param=None if not autograd else (inp,), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @unittest.skipIf(not SM70OrLater, "triton") @@ -4626,7 +4924,10 @@ def body(x): @parametrize("reverse", [False, True]) @parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) +<<<<<<< HEAD @parametrize("autograd", [False, True]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Skipping the combination of associative_scan and device=cpu # as the current implementation of pointwise does only support CUDA device @decorateIf( @@ -4634,9 +4935,15 @@ def body(x): lambda params: (params["device"] == torch.device("cpu")), ) def test_associative_scan_non_pointwise_generic( +<<<<<<< HEAD self, reverse, compile_mode, device, autograd ): x = torch.randn(3, 10, 2, device=device, requires_grad=autograd) +======= + self, reverse, compile_mode, device + ): + x = torch.randn(3, 10, 2, device=device) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kwargs = { "dim": 0, @@ -4650,7 +4957,10 @@ def test_associative_scan_non_pointwise_generic( model=AssociativeScanModels.CombineFn(**kwargs), model_fake=AssociativeScanModels.CombineFn(**kwargs_fake), inputs=x, +<<<<<<< HEAD autograd_param=None if not autograd else (x,), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @skipIfRocm(msg="Unsupported on ROCM yet") @@ -4660,7 +4970,10 @@ def test_associative_scan_non_pointwise_generic( @parametrize("combine_mode", ["pointwise", "generic"]) @parametrize("reverse", [False, True]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) +<<<<<<< HEAD @parametrize("autograd", [False, True]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Skipping the combination of combine_mode=pointwise and device=cpu # as the current implementation of pointwise does only support CUDA device # Skipping the combination of combine_mode=pointwise and compile_mode=compile_dynamic_shape @@ -4677,14 +4990,24 @@ def test_associative_scan_non_pointwise_generic( ), ) def test_associative_scan_binary_operator( +<<<<<<< HEAD self, compile_mode, combine_mode, reverse, device, autograd +======= + self, compile_mode, combine_mode, reverse, device +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): state_dim = 20 timesteps = 10 projected_inputs = torch.randn( +<<<<<<< HEAD timesteps, state_dim, device=device, requires_grad=autograd ) A = torch.randn(state_dim, device=device, requires_grad=autograd) +======= + timesteps, state_dim, requires_grad=True, device=device + ) + A = torch.randn(state_dim, requires_grad=True, device=device) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elements = (A.repeat((timesteps, 1)), projected_inputs) kwargs = { @@ -4699,7 +5022,10 @@ def test_associative_scan_binary_operator( model=AssociativeScanModels.CombineFn(**kwargs), model_fake=AssociativeScanModels.CombineFn(**kwargs_fake), inputs=elements, +<<<<<<< HEAD autograd_param=None if not autograd else elements, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @skipIfRocm(msg="Unsupported on ROCM yet") @@ -4781,7 +5107,10 @@ def test_associative_scan_different_input_size_wrong_dim(self): @parametrize("combine_mode", ["pointwise", "generic"]) @parametrize("reverse", [False, True]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) +<<<<<<< HEAD @parametrize("autograd", [False, True]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Skipping the combine_mode=pointwise # as the current implementation of associative_scan lowering # does not support lifted arguments @@ -4790,9 +5119,15 @@ def test_associative_scan_different_input_size_wrong_dim(self): lambda params: (params["combine_mode"] == "pointwise"), ) def test_associative_scan_freevars_simple( +<<<<<<< HEAD self, compile_mode, combine_mode, reverse, device, autograd ): H = torch.rand(2, device=device, requires_grad=autograd) +======= + self, compile_mode, combine_mode, reverse, device + ): + H = torch.rand(2, device=device) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def fct_freevars1(x: torch.Tensor, y: torch.Tensor): return x * H + y * 2 @@ -4800,13 +5135,22 @@ def fct_freevars1(x: torch.Tensor, y: torch.Tensor): def fct_freevars2(x: torch.Tensor, y: torch.Tensor): return x * H + y * H +<<<<<<< HEAD H1 = torch.rand(1, device=device, requires_grad=autograd) H2 = torch.rand(1, device=device, requires_grad=autograd) +======= + H1 = torch.rand(1, device=device) + H2 = torch.rand(1, device=device) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def fct_freevars3(x: torch.Tensor, y: torch.Tensor): return x * H1 + y * H2 +<<<<<<< HEAD inp = torch.randn(3, 2, 2, device=device, requires_grad=autograd) +======= + inp = torch.randn(3, 2, 2, device=device) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for fct, param in [ (fct_freevars1, (H,)), @@ -4825,7 +5169,10 @@ def fct_freevars3(x: torch.Tensor, y: torch.Tensor): model=AssociativeScanModels.CombineFn(**kwargs), model_fake=AssociativeScanModels.CombineFn(**kwargs_fake), inputs=inp, +<<<<<<< HEAD autograd_param=None if not autograd else (inp, *param), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @unittest.skipIf(not SM70OrLater, "triton") @@ -4834,7 +5181,10 @@ def fct_freevars3(x: torch.Tensor, y: torch.Tensor): @parametrize("combine_mode", ["pointwise", "generic"]) @parametrize("reverse", [False, True]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) +<<<<<<< HEAD @parametrize("autograd", [False, True]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Skipping the combine_mode=pointwise # as the current implementation of associative_scan lowering # does not support lifted arguments @@ -4843,10 +5193,17 @@ def fct_freevars3(x: torch.Tensor, y: torch.Tensor): lambda params: (params["combine_mode"] == "pointwise"), ) def test_associative_scan_freevars_nested( +<<<<<<< HEAD self, compile_mode, combine_mode, reverse, device, autograd ): H1 = torch.rand(4, 5, device=device, requires_grad=autograd) H2 = torch.rand(4, 1, device=device, requires_grad=autograd) +======= + self, compile_mode, combine_mode, reverse, device + ): + H1 = torch.rand(4, 5, device=device) + H2 = torch.rand(4, 1, device=device) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def fct_nested_outside(x: torch.Tensor, y: torch.Tensor): def inner(xi): @@ -4862,10 +5219,19 @@ def inner(xi): ret = inner(y) return x + ret * H1 +<<<<<<< HEAD +======= + H1_i = torch.rand(4, 5, device=device) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO: Using random tensors in the `combine_fn` triggers the vmap randomness error: # RuntimeError: vmap: called random operation while in randomness error mode. # Please either use the 'same' or 'different' randomness flags on vmap or perform the randomness operation out of vmap def fct_nested_inside(x: torch.Tensor, y: torch.Tensor): +<<<<<<< HEAD +======= + # H2_i = torch.rand(4, 1, device=device) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) H2_i = torch.ones(4, 1, device=device) * 42 def inner(xi): @@ -4875,6 +5241,10 @@ def inner(xi): return x + ret * H1 def fct_nested_inside_fake(x: torch.Tensor, y: torch.Tensor): +<<<<<<< HEAD +======= + # H2_i = torch.rand(4, 1, device=device) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) H2_i = torch.ones(4, 1, device=device) * 42 def inner(xi): @@ -4883,11 +5253,19 @@ def inner(xi): ret = inner(y) return x + ret * H1 +<<<<<<< HEAD inp = torch.randn(3, 4, 5, device=device, requires_grad=autograd) for fct, fct_fake, param in [ (fct_nested_outside, fct_nested_outside_fake, (H1, H2)), (fct_nested_inside, fct_nested_inside_fake, ()), +======= + inp = torch.randn(3, 4, 5, device=device) + + for fct, fct_fake, param in [ + (fct_nested_outside, fct_nested_outside_fake, (H1, H2)), + (fct_nested_inside, fct_nested_inside_fake, (H1_i,)), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ]: kwargs = { "dim": 0, @@ -4902,7 +5280,10 @@ def inner(xi): model=AssociativeScanModels.CombineFn(**kwargs), model_fake=AssociativeScanModels.CombineFn(**kwargs_fake), inputs=inp, +<<<<<<< HEAD autograd_param=None if not autograd else (inp, *param), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @unittest.skipIf(not SM70OrLater, "triton") @@ -4911,7 +5292,10 @@ def inner(xi): @parametrize("combine_mode", ["pointwise", "generic"]) @parametrize("reverse", [False, True]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) +<<<<<<< HEAD @parametrize("autograd", [False, True]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Skipping the combine_mode=pointwise # as the current implementation of associative_scan lowering # does not support lifted arguments @@ -4920,7 +5304,11 @@ def inner(xi): lambda params: (params["combine_mode"] == "pointwise"), ) def test_associative_scan_freevars_fct( +<<<<<<< HEAD self, compile_mode, combine_mode, reverse, device, autograd +======= + self, compile_mode, combine_mode, reverse, device +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): def additional_fct_no_add_inp(x, y): return x * y @@ -4929,7 +5317,11 @@ def fct_nested_outside(x: torch.Tensor, y: torch.Tensor): ret = additional_fct_no_add_inp(y, y) return x + ret +<<<<<<< HEAD inp = torch.randn(3, 4, 5, device=device, requires_grad=autograd) +======= + inp = torch.randn(3, 4, 5, device=device) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kwargs = { "dim": 0, @@ -4943,7 +5335,10 @@ def fct_nested_outside(x: torch.Tensor, y: torch.Tensor): model=AssociativeScanModels.CombineFn(**kwargs), model_fake=AssociativeScanModels.CombineFn(**kwargs_fake), inputs=inp, +<<<<<<< HEAD autograd_param=None if not autograd else (inp,), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @unittest.skipIf(not SM70OrLater, "triton") @@ -4951,10 +5346,14 @@ def fct_nested_outside(x: torch.Tensor, y: torch.Tensor): @parametrize("compile_mode", ["none", "eager", "compile", "compile_dynamic_shape"]) @parametrize("reverse", [False, True]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) +<<<<<<< HEAD @parametrize("autograd", [False, True]) def test_associative_scan_freevars_fct_generic( self, compile_mode, reverse, device, autograd ): +======= + def test_associative_scan_freevars_fct_generic(self, compile_mode, reverse, device): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def additional_fct_no_add_inp(x, y): return x * y @@ -4968,7 +5367,11 @@ def fct_nested_outside_fake(x: torch.Tensor, y: torch.Tensor): ret = _fake_associative_scan(additional_fct_no_add_inp, y, 1) return x + ret +<<<<<<< HEAD inp = torch.randn(3, 4, 5, device=device, requires_grad=autograd) +======= + inp = torch.randn(3, 4, 5, device=device) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kwargs = { "dim": 0, @@ -4983,7 +5386,10 @@ def fct_nested_outside_fake(x: torch.Tensor, y: torch.Tensor): model=AssociativeScanModels.CombineFn(**kwargs), model_fake=AssociativeScanModels.CombineFn(**kwargs_fake), inputs=inp, +<<<<<<< HEAD autograd_param=None if not autograd else (inp,), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @unittest.skipIf(not SM70OrLater, "triton") @@ -4992,7 +5398,10 @@ def fct_nested_outside_fake(x: torch.Tensor, y: torch.Tensor): @parametrize("combine_mode", ["pointwise", "generic"]) @parametrize("reverse", [False, True]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) +<<<<<<< HEAD @parametrize("autograd", [False, True]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Skipping the combine_mode=pointwise # as the current implementation of associative_scan lowering # does not support lifted arguments @@ -5001,7 +5410,11 @@ def fct_nested_outside_fake(x: torch.Tensor, y: torch.Tensor): lambda params: (params["combine_mode"] == "pointwise"), ) def test_associative_scan_freevars_shape_check( +<<<<<<< HEAD self, compile_mode, combine_mode, reverse, device, autograd +======= + self, compile_mode, combine_mode, reverse, device +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): H = torch.eye(2, device=device, requires_grad=True) @@ -5022,7 +5435,10 @@ def fct_freevars(x: torch.Tensor, y: torch.Tensor): model=AssociativeScanModels.CombineFn(**kwargs), model_fake=AssociativeScanModels.CombineFn(**kwargs_fake), inputs=inp, +<<<<<<< HEAD autograd_param=None if not autograd else (inp,), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @unittest.skipIf(not SM70OrLater, "triton") @@ -5031,7 +5447,10 @@ def fct_freevars(x: torch.Tensor, y: torch.Tensor): @parametrize("reverse", [False, True]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) @parametrize("combine_mode", ["pointwise", "generic"]) +<<<<<<< HEAD @parametrize("autograd", [False, True]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Skipping the combine_mode=pointwise # as the current implementation of associative_scan lowering # does not support lifted arguments @@ -5040,11 +5459,19 @@ def fct_freevars(x: torch.Tensor, y: torch.Tensor): lambda params: (params["combine_mode"] == "pointwise"), ) def test_associative_scan_freevars_pytree( +<<<<<<< HEAD self, compile_mode, combine_mode, reverse, device, autograd ): xf = torch.randn(2, 2, device=device, requires_grad=autograd) yf = torch.randn(2, 2, device=device, requires_grad=autograd) zf = torch.randn(2, 2, device=device, requires_grad=autograd) +======= + self, compile_mode, combine_mode, reverse, device + ): + xf = torch.randn(2, 2, device=device, requires_grad=True) + yf = torch.randn(2, 2, device=device, requires_grad=True) + zf = torch.randn(2, 2, device=device, requires_grad=True) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inpf = {"i": xf, "j": ([yf], [{"o": zf}])} def fct_pointwise(x, y): @@ -5061,9 +5488,15 @@ def fct_pointwise(x, y): ), } +<<<<<<< HEAD x = torch.randn(3, 2, 2, device=device, requires_grad=autograd) y = torch.randn(3, 2, 2, device=device, requires_grad=autograd) z = torch.randn(3, 2, 2, device=device, requires_grad=autograd) +======= + x = torch.randn(3, 2, 2, device=device, requires_grad=True) + y = torch.randn(3, 2, 2, device=device, requires_grad=True) + z = torch.randn(3, 2, 2, device=device, requires_grad=True) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inp = {"i": x, "j": ([y], [{"o": z}])} kwargs = { @@ -5078,7 +5511,10 @@ def fct_pointwise(x, y): model=AssociativeScanModels.CombineFn(**kwargs), model_fake=AssociativeScanModels.CombineFn(**kwargs_fake), inputs=inp, +<<<<<<< HEAD autograd_param=None if not autograd else (*pytree.tree_leaves(inp),), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @unittest.skipIf(not SM70OrLater, "triton") @@ -5275,6 +5711,11 @@ def f(x, y): @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo") @skipIfCrossRef # Arg order changes with crossref def test_cond_simple_with_linear_compile_check_graph(self): +<<<<<<< HEAD +======= + from torch._dynamo.testing import EagerAndRecordGraphs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def true_fn(x): return x.sin() @@ -5322,7 +5763,11 @@ def forward(self, L_ctx_saved_tensors_0_: "f32[4]", L_ctx_pred: "b8[]", L_args_1 return (getitem,) class cond_true_0(torch.nn.Module): +<<<<<<< HEAD def forward(self, l_args_1_: "f32[4]", l_ctx_saved_tensors_0_: "f32[4]"): +======= + def forward(self, l_args_1_, l_ctx_saved_tensors_0_): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) l_args_1__1 = l_args_1_ l_ctx_saved_tensors_0__1 = l_ctx_saved_tensors_0_ @@ -5334,7 +5779,11 @@ def forward(self, l_args_1_: "f32[4]", l_ctx_saved_tensors_0_: "f32[4]"): return (mul,) class cond_false_0(torch.nn.Module): +<<<<<<< HEAD def forward(self, l_args_1_: "f32[4]", l_ctx_saved_tensors_0_: "f32[4]"): +======= + def forward(self, l_args_1_, l_ctx_saved_tensors_0_): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) l_args_1__1 = l_args_1_ l_ctx_saved_tensors_0__1 = l_ctx_saved_tensors_0_ @@ -5419,6 +5868,11 @@ def forward(self, arg0_1, arg1_1, arg2_1): def test_while_loop_pytree_carry(self): fn, inp = WHILE_LOOP_TESTS["simple_with_pytree_carry"] +<<<<<<< HEAD +======= + from torch._dynamo.testing import EagerAndRecordGraphs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) backend = EagerAndRecordGraphs() expected_res = fn(*inp) compiled_res = torch.compile(fn, backend=backend)(*inp) @@ -5572,7 +6026,11 @@ def forward(self, arg0_1): ) @parametrize("func_type", ["no", "cpp", "python", "functorch"]) +<<<<<<< HEAD # - "simple_with_linear" and "nested_with_linear" doesn't work because parameters and buffers +======= + # - "simple_with_linear" and "nested_with_linear" doesn't work becaue parameters and buffers +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # are not inputs so they're not wrapped by functionalization and tracing. # # - make_fx tracing mode "real" fails for "int_carry", "pytree_int_carry" and "const_and_symint_output" @@ -5625,12 +6083,18 @@ def test_while_loop_compile(self, backend, while_loop_test): @skipIfCrossRef # Arg order changes with cross ref def test_while_loop_simple_with_linear_compile_check_graph(self): fn, inp = WHILE_LOOP_TESTS["simple_with_linear"] +<<<<<<< HEAD +======= + from torch._dynamo.testing import EagerAndRecordGraphs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) backend = EagerAndRecordGraphs() torch.compile(fn, backend=backend)(*inp) self.assertEqual(len(backend.graphs), 1) gm = backend.graphs[0] if torch._dynamo.config.inline_inbuilt_nn_modules: self.assertExpectedInline( +<<<<<<< HEAD normalize_gm(gm.print_readable(print_output=False)), """\ class GraphModule(torch.nn.Module): @@ -5660,6 +6124,71 @@ def forward(self, child_2: "i64[]", child_3: "f32[2, 2]", l_self_buffers_dec__co child_4: "f32[2, 2]" = torch._C._nn.linear(child_3, l_self_modules_linear_parameters_weight__body_fn, l_self_modules_linear_parameters_bias__body_fn); child_3 = l_self_modules_linear_parameters_weight__body_fn = l_self_modules_linear_parameters_bias__body_fn = None return (child, child_4) """, # noqa: B950 +======= + gm.code.strip(), + """\ +def forward(self, L_iter_ : torch.Tensor, L_x_ : torch.Tensor, L_self_buffers_dec_ : torch.Tensor, L_self_modules_linear_parameters_weight_ : torch.nn.parameter.Parameter, L_self_modules_linear_parameters_bias_ : torch.nn.parameter.Parameter): + l_iter_ = L_iter_ + l_x_ = L_x_ + l_self_buffers_dec_ = L_self_buffers_dec_ + l_self_modules_linear_parameters_weight_ = L_self_modules_linear_parameters_weight_ + l_self_modules_linear_parameters_bias_ = L_self_modules_linear_parameters_bias_ + cond_fn_0 = self.cond_fn_0 + body_fn_0 = self.body_fn_0 + while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (l_iter_, l_x_), (l_self_buffers_dec_, l_self_modules_linear_parameters_bias_, l_self_modules_linear_parameters_weight_)); cond_fn_0 = body_fn_0 = l_iter_ = l_x_ = l_self_buffers_dec_ = l_self_modules_linear_parameters_bias_ = l_self_modules_linear_parameters_weight_ = None + getitem = while_loop[0] + getitem_1 = while_loop[1]; while_loop = None + return (getitem, getitem_1)""", # noqa: B950 + ) + self.assertExpectedInline( + gm.cond_fn_0.code.strip(), + """\ +def forward(self, l_iter_ : torch.Tensor, l_x_ : torch.Tensor, l_self_buffers_dec__cond_fn, l_self_modules_linear_parameters_bias__body_fn, l_self_modules_linear_parameters_weight__body_fn): + sub = l_iter_ - l_self_buffers_dec__cond_fn; l_iter_ = l_self_buffers_dec__cond_fn = None + gt = sub > 0; sub = None + return gt""", # noqa: B950 + ) + self.assertExpectedInline( + gm.body_fn_0.code.strip(), + """\ +def forward(self, l_iter_ : torch.Tensor, l_x_ : torch.Tensor, l_self_buffers_dec__cond_fn, l_self_modules_linear_parameters_bias__body_fn, l_self_modules_linear_parameters_weight__body_fn): + child = l_iter_ - 1; l_iter_ = None + child_1 = torch._C._nn.linear(l_x_, l_self_modules_linear_parameters_weight__body_fn, l_self_modules_linear_parameters_bias__body_fn); l_x_ = l_self_modules_linear_parameters_weight__body_fn = l_self_modules_linear_parameters_bias__body_fn = None + return (child, child_1)""", # noqa: B950 + ) + else: + self.assertExpectedInline( + gm.code.strip(), + """\ +def forward(self, L_iter_ : torch.Tensor, L_x_ : torch.Tensor): + l_iter_ = L_iter_ + l_x_ = L_x_ + l__self___dec = self.L__self___dec + l__self___linear_weight = self.L__self___linear_weight + l__self___linear_bias = self.L__self___linear_bias + cond_fn_0 = self.cond_fn_0 + body_fn_0 = self.body_fn_0 + while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (l_iter_, l_x_), (l__self___dec, l__self___linear_bias, l__self___linear_weight)); cond_fn_0 = body_fn_0 = l_iter_ = l_x_ = l__self___dec = l__self___linear_bias = l__self___linear_weight = None + getitem = while_loop[0] + getitem_1 = while_loop[1]; while_loop = None + return (getitem, getitem_1)""", # noqa: B950 + ) + self.assertExpectedInline( + gm.cond_fn_0.code.strip(), + """\ +def forward(self, l_iter_, l_x_, l__self___dec_cond_fn, l__self___linear_bias_body_fn, l__self___linear_weight_body_fn): + sub = l_iter_ - l__self___dec_cond_fn; l_iter_ = l__self___dec_cond_fn = None + gt = sub > 0; sub = None + return gt""", # noqa: B950 + ) + self.assertExpectedInline( + gm.body_fn_0.code.strip(), + """\ +def forward(self, l_iter_, l_x_, l__self___dec_cond_fn, l__self___linear_bias_body_fn, l__self___linear_weight_body_fn): + child = l_iter_ - 1; l_iter_ = None + child_1 = torch._C._nn.linear(l_x_, l__self___linear_weight_body_fn, l__self___linear_bias_body_fn); l_x_ = l__self___linear_weight_body_fn = l__self___linear_bias_body_fn = None + return (child, child_1)""", # noqa: B950 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def test_while_loop_nested2_traced(self): @@ -7652,6 +8181,11 @@ def forward(self, inp: torch.Tensor, tmp: torch.Tensor) -> torch.Tensor: ): out = torch.compile(Mod(), backend="inductor")(inp, tmp) +<<<<<<< HEAD +======= + from torch._dynamo.testing import EagerAndRecordGraphs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) backend = EagerAndRecordGraphs() out = torch.compile(Mod(), backend=backend)(inp, tmp) self.assertExpectedInline( @@ -7673,9 +8207,16 @@ def forward(self, l_inp_, l_tmp_): ) self.assertEqual(out, f(inp, tmp)) +<<<<<<< HEAD @skipIfCrossRef # Args get renamed to r in crossref mode @parametrize("requires_grad", [True, False]) def test_cond_symint_operands(self, requires_grad): +======= + @parametrize("requires_grad", [True, False]) + def test_cond_symint_operands(self, requires_grad): + from torch._dynamo.testing import EagerAndRecordGraphs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) backend = EagerAndRecordGraphs() class Mod(torch.nn.Module): @@ -7839,6 +8380,11 @@ def f(init, xs): @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo") def test_scan_pytree_closure(self): +<<<<<<< HEAD +======= + from torch._dynamo.testing import EagerAndRecordGraphs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) param_buffer = ({"param": torch.randn(3, 3)}, (torch.randn(3),)) def add(carry, x): @@ -7893,7 +8439,16 @@ def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor, L_add_closure_0_ self.assertEqual(compiled_out, exp_out) @skipIfTorchDynamo("Skip because we're testing export") +<<<<<<< HEAD @parametrize("strict", [True, False]) +======= + # TODO: we cannot turn on strict=True yet because torch._check for out_it > 0 is + # removed from the graph in dynamo and in non-strict export's graph capturing + # step, we re-run the traced graph module to get graph captured result. + # Since torch._check is removed from graph, we end up getting a data-dependent + # error when we call torch.ones(out_it * 2). + @parametrize("strict", [False]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize("dynamic", [True, False]) def test_while_loop_op_int_carry_export(self, strict, dynamic): m, args = WHILE_LOOP_TESTS["int_carry"] @@ -7908,8 +8463,11 @@ def forward(self, x): x: "f32[s77, 3]"; x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) +<<<<<<< HEAD _guards_fn = self._guards_fn(x); _guards_fn = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sym_size_int_1: "Sym(s77)" = torch.ops.aten.sym_size.int(x, 0) while_loop_cond_graph_0 = self.while_loop_cond_graph_0 @@ -7938,9 +8496,14 @@ def forward(self, x): class while_loop_cond_graph_0(torch.nn.Module): def forward(self, it_1: "Sym(u0)", x_1: "f32[s77, 3]"): +<<<<<<< HEAD sym_size_int_1: "Sym(s77)" = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None lt: "Sym(u0 < s77)" = it_1 < sym_size_int_1; it_1 = sym_size_int_1 = None +======= + sym_size_int: "Sym(s77)" = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None + lt: "Sym(u0 < s77)" = it_1 < sym_size_int; it_1 = sym_size_int = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return lt class while_loop_body_graph_0(torch.nn.Module): @@ -7959,6 +8522,11 @@ def forward(self, it_1: "Sym(u0)", x_1: "f32[s77, 3]"): @parametrize("dynamic", [True, False]) @parametrize("backend", ["eager", "aot_eager"]) def test_while_loop_op_int_carry_compile(self, dynamic, backend): +<<<<<<< HEAD +======= + from torch._dynamo.testing import EagerAndRecordGraphs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) m, args = WHILE_LOOP_TESTS["int_carry"] if backend == "eager": backend = EagerAndRecordGraphs() @@ -7980,6 +8548,7 @@ def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", L_x_: "f32[s77, s27]"): body_fn_0 = self.body_fn_0 while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (0, l_x_), (s27, s77)); cond_fn_0 = body_fn_0 = l_x_ = s27 = None +<<<<<<< HEAD getitem_4: "Sym(u2)" = while_loop[0] ge: "Sym(u2 >= 1)" = getitem_4 >= 1 @@ -8009,12 +8578,44 @@ def forward(self, unbacked_symint: "Sym(u0)", child: "f32[s77, s27]", s27: "Sym( s77_1 = s77 size = child.size(); child = None +======= + getitem_4: "Sym(u1)" = while_loop[0] + + ge: "Sym(u1 >= 1)" = getitem_4 >= 1 + _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u1 >= 1 on node 'ge'"); ge = _assert_scalar_default = None + + gt_1: "Sym(u1 > 0)" = getitem_4 > 0 + _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 0 < u1 on node 'gt_1'"); gt_1 = _assert_scalar_default_1 = None + + out_x: "f32[s77, s27]" = while_loop[1]; while_loop = None + + gt: "Sym(u1 > 0)" = getitem_4 > 0 + _check = torch._check(gt); gt = _check = None + + add: "Sym(u1 + 1)" = getitem_4 + 1 + + add_1: "f32[s77, s27]" = getitem_4 + out_x; out_x = None + + lt: "Sym(u1 < s77)" = getitem_4 < s77; s77 = None + + mul: "Sym(2*u1)" = getitem_4 * 2; getitem_4 = None + ones: "f32[2*u1]" = torch.ones(mul); mul = None + return (add, add_1, lt, ones) + + class cond_fn_0(torch.nn.Module): + def forward(self, unbacked_symint: "Sym(u0)", l_x_: "f32[s77, s27]", s27, s77): + s27_1 = s27 + s77_1 = s77 + + size = l_x_.size(); l_x_ = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) getitem: "Sym(s77)" = size[0] getitem_1: "Sym(s27)" = size[1]; size = getitem_1 = None lt: "Sym(u0 < s77)" = unbacked_symint < getitem; unbacked_symint = getitem = None return lt class body_fn_0(torch.nn.Module): +<<<<<<< HEAD def forward(self, unbacked_symint_0: "Sym(u1)", child_1: "f32[s77, s27]", s27: "Sym(s27)", s77: "Sym(s77)"): s27_1 = s27 s77_1 = s77 @@ -8036,6 +8637,29 @@ def forward(self, unbacked_symint_0: "Sym(u1)", child_1: "f32[s77, s27]", s27: " copy_: "f32[s27]" = select.copy_(add); select = add = copy_ = None add_1: "Sym(u1 + 1)" = unbacked_symint_0 + 1; unbacked_symint_0 = None +======= + def forward(self, unbacked_symint: "Sym(u0)", l_x_: "f32[s77, s27]", s27, s77): + s27_1 = s27 + s77_1 = s77 + + x_clone: "f32[s77, s27]" = l_x_.clone() + + ge: "Sym(u0 >= 0)" = unbacked_symint >= 0 + _check = torch._check(ge); ge = _check = None + + size = l_x_.size(); l_x_ = None + getitem: "Sym(s77)" = size[0] + getitem_1: "Sym(s27)" = size[1]; size = getitem_1 = None + lt: "Sym(u0 < s77)" = unbacked_symint < getitem; getitem = None + _check_1 = torch._check(lt); lt = _check_1 = None + + select: "f32[s27]" = x_clone.select(0, unbacked_symint) + select_1: "f32[s27]" = x_clone.select(0, unbacked_symint) + add: "f32[s27]" = select_1 + unbacked_symint; select_1 = None + copy_: "f32[s27]" = select.copy_(add); select = add = copy_ = None + + add_1: "Sym(u0 + 1)" = unbacked_symint + 1; unbacked_symint = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return (add_1, x_clone) """, # noqa: B950 ) @@ -8058,8 +8682,11 @@ def forward(self, t): t: "f32[2, 3]"; t, = fx_pytree.tree_flatten_spec(([t], {}), self._in_spec) +<<<<<<< HEAD _guards_fn = self._guards_fn(t); _guards_fn = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sum_1: "f32[]" = torch.ops.aten.sum.default(t) _assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(sum_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default = None to: "i64[]" = torch.ops.aten.to.dtype(sum_1, torch.int64); sum_1 = None @@ -8120,6 +8747,11 @@ def forward(self, a_1: "Sym(u1)", b_1: "Sym(u2)", c1_1: "Sym(u3)", c2_1: "Sym(u4 @parametrize("backend", ["eager", "aot_eager"]) @torch._dynamo.config.patch(capture_scalar_outputs=True) def test_while_loop_op_constant_and_symint_output_compile(self, dynamic, backend): +<<<<<<< HEAD +======= + from torch._dynamo.testing import EagerAndRecordGraphs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) m, args = WHILE_LOOP_TESTS["const_and_symint_output"] if backend == "eager": backend = EagerAndRecordGraphs() @@ -8141,6 +8773,7 @@ def forward(self, L_t_: "f32[2, 3]"): sum_1: "f32[]" = l_t_.sum() to: "i64[]" = sum_1.to(torch.int64); sum_1 = None item: "Sym(u0)" = to.item(); to = None +<<<<<<< HEAD sin: "f32[2, 3]" = l_t_.sin() cond_fn_0 = self.cond_fn_0 @@ -8165,6 +8798,32 @@ def forward(self, L_t_: "f32[2, 3]"): add_5: "Sym(u20 + 1)" = getitem_13 + 1 add_6: "Sym(u21 + 1)" = getitem_14 + 1 add_7: "f32[2, 3]" = child + 1 +======= + child: "f32[2, 3]" = l_t_.sin() + + cond_fn_0 = self.cond_fn_0 + body_fn_0 = self.body_fn_0 + while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (2, 3, 1, 1, 1, 3, item, child), ()); cond_fn_0 = body_fn_0 = item = child = None + + getitem_8: "Sym(u8)" = while_loop[0] + getitem_9: "Sym(u9)" = while_loop[1] + getitem_10: "Sym(u10)" = while_loop[2] + getitem_11: "Sym(u11)" = while_loop[3] + getitem_12: "Sym(u12)" = while_loop[4] + getitem_13: "Sym(u13)" = while_loop[5] + getitem_14: "Sym(u14)" = while_loop[6] + + child_1: "f32[2, 3]" = while_loop[7]; while_loop = None + + add: "Sym(u8 + 1)" = getitem_8 + 1 + add_1: "Sym(u9 + 1)" = getitem_9 + 1 + add_2: "Sym(u10 + 1)" = getitem_10 + 1 + add_3: "Sym(u11 + 1)" = getitem_11 + 1 + add_4: "Sym(u12 + 1)" = getitem_12 + 1 + add_5: "Sym(u13 + 1)" = getitem_13 + 1 + add_6: "Sym(u14 + 1)" = getitem_14 + 1 + add_7: "f32[2, 3]" = child_1 + 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) add_8: "f32[2, 3]" = getitem_8 + l_t_; getitem_8 = None add_9: "f32[2, 3]" = getitem_9 + l_t_; getitem_9 = None @@ -8173,7 +8832,11 @@ def forward(self, L_t_: "f32[2, 3]"): add_12: "f32[2, 3]" = getitem_12 + l_t_; getitem_12 = None add_13: "f32[2, 3]" = getitem_13 + l_t_; getitem_13 = None add_14: "f32[2, 3]" = getitem_14 + l_t_; getitem_14 = None +<<<<<<< HEAD add_15: "f32[2, 3]" = child + l_t_; child = l_t_ = None +======= + add_15: "f32[2, 3]" = child_1 + l_t_; child_1 = l_t_ = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return (add, add_1, add_2, add_3, add_4, add_5, add_6, add_7, add_8, add_9, add_10, add_11, add_12, add_13, add_14, add_15) class cond_fn_0(torch.nn.Module): @@ -8185,10 +8848,17 @@ def forward(self, unbacked_symint: "Sym(u1)", unbacked_symint_0: "Sym(u2)", unba return lt class body_fn_0(torch.nn.Module): +<<<<<<< HEAD def forward(self, unbacked_symint_6: "Sym(u8)", unbacked_symint_7: "Sym(u9)", unbacked_symint_8: "Sym(u10)", unbacked_symint_9: "Sym(u11)", unbacked_symint_10: "Sym(u12)", unbacked_symint_11: "Sym(u13)", unbacked_symint_12: "Sym(u14)", child_1: "f32[2, 3]"): add: "Sym(u14 + 1)" = unbacked_symint_12 + 1; unbacked_symint_12 = None child: "f32[2, 3]" = child_1 + 1; child_1 = None return (unbacked_symint_7, unbacked_symint_8, unbacked_symint_9, unbacked_symint_10, unbacked_symint_6, 0, add, child) +======= + def forward(self, unbacked_symint: "Sym(u1)", unbacked_symint_0: "Sym(u2)", unbacked_symint_1: "Sym(u3)", unbacked_symint_2: "Sym(u4)", unbacked_symint_3: "Sym(u5)", unbacked_symint_4: "Sym(u6)", unbacked_symint_5: "Sym(u7)", child: "f32[2, 3]"): + add: "Sym(u7 + 1)" = unbacked_symint_5 + 1; unbacked_symint_5 = None + child_1: "f32[2, 3]" = child + 1; child = None + return (unbacked_symint_0, unbacked_symint_1, unbacked_symint_2, unbacked_symint_3, unbacked_symint, 0, add, child_1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """, # noqa: B950 ) @@ -8199,7 +8869,11 @@ def test_while_loop_op_pytree_int_carry_export(self, strict, dynamic): m, args = WHILE_LOOP_TESTS["pytree_int_carry"] dynamic_shapes = {"x": {0: torch.export.Dim("dim_x")}} if dynamic else None ep = self._check_export(m, args, strict=strict, dynamic_shapes=dynamic_shapes) +<<<<<<< HEAD if strict and dynamic and not TEST_WITH_CROSSREF: +======= + if strict and dynamic: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertExpectedInline( normalize_gm(ep.module().print_readable(print_output=False)), """\ @@ -8208,8 +8882,11 @@ def forward(self, x): x: "f32[s77, 3]"; x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) +<<<<<<< HEAD _guards_fn = self._guards_fn(x); _guards_fn = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sym_size_int_1: "Sym(s77)" = torch.ops.aten.sym_size.int(x, 0) sin: "f32[s77, 3]" = torch.ops.aten.sin.default(x); x = None @@ -8218,6 +8895,7 @@ def forward(self, x): while_loop_body_graph_0 = self.while_loop_body_graph_0 while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (sym_size_int_1, 3, 2, 2, 3, sin), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = sym_size_int_1 = sin = None +<<<<<<< HEAD getitem_6: "Sym(u10)" = while_loop[0] getitem_7: "Sym(u11)" = while_loop[1] getitem_8: "Sym(u12)" = while_loop[2] @@ -8229,6 +8907,19 @@ def forward(self, x): add: "Sym(u12 + 1)" = getitem_8 + 1 add_1: "Sym(u13 + 1)" = getitem_9 + 1 add_2: "Sym(u14 + 1)" = getitem_10 + 1 +======= + getitem_6: "Sym(u5)" = while_loop[0] + getitem_7: "Sym(u6)" = while_loop[1] + getitem_8: "Sym(u7)" = while_loop[2] + getitem_9: "Sym(u8)" = while_loop[3] + getitem_10: "Sym(u9)" = while_loop[4] + + getitem_5: "f32[s77, 3]" = while_loop[5]; while_loop = None + + add: "Sym(u7 + 1)" = getitem_8 + 1 + add_1: "Sym(u8 + 1)" = getitem_9 + 1 + add_2: "Sym(u9 + 1)" = getitem_10 + 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) add_3: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_8); getitem_8 = None add_4: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_9); getitem_9 = None @@ -8236,6 +8927,7 @@ def forward(self, x): return pytree.tree_unflatten((getitem_6, getitem_7, add, add_1, add_2, add_3, add_4, add_5, getitem_5), self._out_spec) class while_loop_cond_graph_0(torch.nn.Module): +<<<<<<< HEAD def forward(self, arg0_1: "Sym(u20)", arg1_1: "Sym(u21)", arg2_1: "Sym(u22)", arg3_1: "Sym(u23)", arg4_1: "Sym(u24)", arg5_1: "f32[s77, 3]"): mul: "Sym(u22*u23)" = arg2_1 * arg3_1; arg2_1 = arg3_1 = None mul_1: "Sym(u22*u23*u24)" = mul * arg4_1; mul = arg4_1 = None @@ -8251,6 +8943,23 @@ def forward(self, arg0_1: "Sym(u20)", arg1_1: "Sym(u21)", arg2_1: "Sym(u22)", ar add_2: "Sym(u22 + 1)" = arg2_1 + 1; arg2_1 = None add_3: "Sym(u23 + 1)" = arg3_1 + 1; arg3_1 = None add_4: "Sym(u24 + 1)" = arg4_1 + 1; arg4_1 = None +======= + def forward(self, arg0_1: "Sym(u15)", arg1_1: "Sym(u16)", arg2_1: "Sym(u17)", arg3_1: "Sym(u18)", arg4_1: "Sym(u19)", arg5_1: "f32[s77, 3]"): + mul: "Sym(u17*u18)" = arg2_1 * arg3_1; arg2_1 = arg3_1 = None + mul_1: "Sym(u17*u18*u19)" = mul * arg4_1; mul = arg4_1 = None + mul_2: "Sym(u15*u16)" = arg0_1 * arg1_1; arg0_1 = arg1_1 = None + lt: "Sym(u17*u18*u19 < u15*u16)" = mul_1 < mul_2; mul_1 = mul_2 = None + return lt + + class while_loop_body_graph_0(torch.nn.Module): + def forward(self, arg0_1: "Sym(u15)", arg1_1: "Sym(u16)", arg2_1: "Sym(u17)", arg3_1: "Sym(u18)", arg4_1: "Sym(u19)", arg5_1: "f32[s77, 3]"): + add: "Sym(u15 + 1)" = arg0_1 + 1; arg0_1 = None + add_1: "Sym(u16 + 1)" = arg1_1 + 1; arg1_1 = None + + add_2: "Sym(u17 + 1)" = arg2_1 + 1; arg2_1 = None + add_3: "Sym(u18 + 1)" = arg3_1 + 1; arg3_1 = None + add_4: "Sym(u19 + 1)" = arg4_1 + 1; arg4_1 = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) add_5: "f32[s77, 3]" = torch.ops.aten.add.Tensor(arg5_1, 1); arg5_1 = None return (add, add_1, add_2, add_3, add_4, add_5) @@ -8260,7 +8969,14 @@ def forward(self, arg0_1: "Sym(u20)", arg1_1: "Sym(u21)", arg2_1: "Sym(u22)", ar @skipIfTorchDynamo("Graph is not captured correctly when test with dynamo") @parametrize("dynamic", [True, False]) @parametrize("backend", ["eager", "aot_eager"]) +<<<<<<< HEAD + def test_while_loop_op_pytree_int_carry_compile(self, dynamic, backend): +======= + @torch._dynamo.config.patch(capture_scalar_outputs=True) def test_while_loop_op_pytree_int_carry_compile(self, dynamic, backend): + from torch._dynamo.testing import EagerAndRecordGraphs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) m, args = WHILE_LOOP_TESTS["pytree_int_carry"] if backend == "eager": backend = EagerAndRecordGraphs() @@ -8284,6 +9000,7 @@ def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", L_x_: "f32[s77, s27]"): body_fn_0 = self.body_fn_0 while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (s77, s27, 2, 2, 3, child), (s27, s77)); cond_fn_0 = body_fn_0 = s77 = s27 = child = None +<<<<<<< HEAD getitem_10: "Sym(u10)" = while_loop[0] getitem_11: "Sym(u11)" = while_loop[1] getitem_12: "Sym(u12)" = while_loop[2] @@ -8295,6 +9012,19 @@ def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", L_x_: "f32[s77, s27]"): add: "Sym(u12 + 1)" = getitem_12 + 1 add_1: "Sym(u13 + 1)" = getitem_13 + 1 add_2: "Sym(u14 + 1)" = getitem_14 + 1 +======= + getitem_10: "Sym(u5)" = while_loop[0] + getitem_11: "Sym(u6)" = while_loop[1] + getitem_12: "Sym(u7)" = while_loop[2] + getitem_13: "Sym(u8)" = while_loop[3] + getitem_14: "Sym(u9)" = while_loop[4] + + out_x: "f32[s77, s27]" = while_loop[5]; while_loop = None + + add: "Sym(u7 + 1)" = getitem_12 + 1 + add_1: "Sym(u8 + 1)" = getitem_13 + 1 + add_2: "Sym(u9 + 1)" = getitem_14 + 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) add_3: "f32[s77, s27]" = getitem_12 + out_x; getitem_12 = None add_4: "f32[s77, s27]" = getitem_13 + out_x; getitem_13 = None @@ -8302,7 +9032,11 @@ def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", L_x_: "f32[s77, s27]"): return (getitem_10, getitem_11, add, add_1, add_2, add_3, add_4, add_5, out_x) class cond_fn_0(torch.nn.Module): +<<<<<<< HEAD def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unbacked_symint_1: "Sym(u2)", unbacked_symint_2: "Sym(u3)", unbacked_symint_3: "Sym(u4)", child_1: "f32[s77, s27]", s27: "Sym(s27)", s77: "Sym(s77)"): +======= + def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unbacked_symint_1: "Sym(u2)", unbacked_symint_2: "Sym(u3)", unbacked_symint_3: "Sym(u4)", child: "f32[s77, s27]", s27, s77): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) s27_1 = s27 s77_1 = s77 @@ -8313,6 +9047,7 @@ def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unba return lt class body_fn_0(torch.nn.Module): +<<<<<<< HEAD def forward(self, unbacked_symint_4: "Sym(u5)", unbacked_symint_5: "Sym(u6)", unbacked_symint_6: "Sym(u7)", unbacked_symint_7: "Sym(u8)", unbacked_symint_8: "Sym(u9)", child_2: "f32[s77, s27]", s27: "Sym(s27)", s77: "Sym(s77)"): s27_1 = s27 s77_1 = s77 @@ -8504,6 +9239,21 @@ def forward(self, arg0_1: "i64[]", arg1_1: "f32[3, 3]", arg2_1: "f32[3]", arg3_1 add_10: "f32[3]" = torch.ops.aten.add.Tensor(view, arg2_1); view = arg2_1 = None add_11: "f32[3, 3]" = torch.ops.aten.add.Tensor(t_4, arg3_1); t_4 = arg3_1 = None return (add_9, add_8, add_10, add_11) +======= + def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unbacked_symint_1: "Sym(u2)", unbacked_symint_2: "Sym(u3)", unbacked_symint_3: "Sym(u4)", child: "f32[s77, s27]", s27, s77): + s27_1 = s27 + s77_1 = s77 + + add: "Sym(u0 + 1)" = unbacked_symint + 1; unbacked_symint = None + add_1: "Sym(u1 + 1)" = unbacked_symint_0 + 1; unbacked_symint_0 = None + + add_2: "Sym(u2 + 1)" = unbacked_symint_1 + 1; unbacked_symint_1 = None + add_3: "Sym(u3 + 1)" = unbacked_symint_2 + 1; unbacked_symint_2 = None + add_4: "Sym(u4 + 1)" = unbacked_symint_3 + 1; unbacked_symint_3 = None + + child_1: "f32[s77, s27]" = child + 1; child = None + return (add, add_1, add_2, add_3, add_4, child_1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """, # noqa: B950 ) @@ -8575,6 +9325,11 @@ def mutate_f(x): @skipIfTorchDynamo("Graph is not captured correctly when test with dynamo") def test_while_loop_unbacked_bindings(self): +<<<<<<< HEAD +======= + from torch._dynamo.testing import EagerAndRecordGraphs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) m, args = WHILE_LOOP_TESTS["pytree_int_carry"] backend = EagerAndRecordGraphs() self._check_compile(m, args, dynamic=True, backend=backend) @@ -8599,6 +9354,10 @@ def _check_export_ret_graph_str(self, fn, args, dynamic_shapes=None) -> str: return normalize_gm(non_strict_ep.module().print_readable(print_output=False)) @skipIfTorchDynamo("Skip because dynamo cannot trace torch.export.") +<<<<<<< HEAD +======= + @torch._dynamo.config.patch(capture_scalar_outputs=True) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_cond_eager_run_with_item(self): class M(torch.nn.Module): def forward(self, a, b1, b2, c): @@ -8629,8 +9388,11 @@ def forward(self, a, b1, b2, c): a: "b8[]"; b1: "i64[1]"; b2: "i64[1]"; c: "f32[10]"; a, b1, b2, c, = fx_pytree.tree_flatten_spec(([a, b1, b2, c], {}), self._in_spec) +<<<<<<< HEAD _guards_fn = self._guards_fn(a, b1, b2, c); _guards_fn = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) true_graph_0 = self.true_graph_0 false_graph_0 = self.false_graph_0 cond = torch.ops.higher_order.cond(a, true_graph_0, false_graph_0, (c, b1, b2)); a = true_graph_0 = false_graph_0 = c = b1 = b2 = None @@ -8655,6 +9417,7 @@ def forward(self, c: "f32[10]", b1: "i64[1]", b2: "i64[1]"): """, # noqa: B950 ) +<<<<<<< HEAD def test_cond_merge_graph_preserves_ph_meta(self): class M(torch.nn.Module): def forward(self, x, y, z): @@ -8680,6 +9443,8 @@ def false_fn(x): for ph in subgm.graph.find_nodes(op="placeholder"): self.assertTrue("example_value" in ph.meta) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skipIfTorchDynamo("Skip because dynamo cannot trace torch.export.") def test_cond_symint_closure(self): from torch.export import Dim @@ -8713,8 +9478,11 @@ def forward(self, x, y, z): x: "f32[s68, 3]"; y: "f32[s17]"; z: "f32[s68, 3]"; x, y, z, = fx_pytree.tree_flatten_spec(([x, y, z], {}), self._in_spec) +<<<<<<< HEAD _guards_fn = self._guards_fn(x, y, z); _guards_fn = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sym_size_int_4: "Sym(s17)" = torch.ops.aten.sym_size.int(y, 0); y = None sym_size_int_5: "Sym(s68)" = torch.ops.aten.sym_size.int(z, 0) @@ -8862,6 +9630,11 @@ def _inner(case): @parametrize("dynamic", [True, False]) @parametrize("backend", ["eager", "aot_eager"]) def test_cond_mismatched_branch_output(self, dynamic, backend): +<<<<<<< HEAD +======= + from torch._dynamo.testing import EagerAndRecordGraphs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class M(torch.nn.Module): def forward(self, x, y, z): a = y.shape[0] @@ -8921,7 +9694,11 @@ def forward(self, s17: "Sym(s17)", s94: "Sym(s94)", L_y_: "f32[s17, s94]", L_z_: return (sub,) class cond_true_0(torch.nn.Module): +<<<<<<< HEAD def forward(self, l_x_: "f32[s17, s94]", s94: "Sym(s94)", s17_true_branch: "Sym(s17)", getitem_2_false_branch: "Sym(s17)", l_z__false_branch: "f32[s17, s94]"): +======= + def forward(self, l_x_, s94, s17_true_branch, getitem_2_false_branch, l_z__false_branch): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) l_x__1 = l_x_ s94_1 = s94 @@ -8931,7 +9708,11 @@ def forward(self, l_x_: "f32[s17, s94]", s94: "Sym(s94)", s17_true_branch: "Sym( return (clone,) class cond_false_0(torch.nn.Module): +<<<<<<< HEAD def forward(self, l_x_: "f32[s17, s94]", s94: "Sym(s94)", s17_true_branch: "Sym(s17)", getitem_2_false_branch: "Sym(s17)", l_z__false_branch: "f32[s17, s94]"): +======= + def forward(self, l_x_, s94, s17_true_branch, getitem_2_false_branch, l_z__false_branch): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) l_x__1 = l_x_ s94_1 = s94 @@ -9094,6 +9875,25 @@ def test_function_schema_gen(self): self.assertEqual(schema2.parse(str(schema2)), schema2) self.assertEqual(schema3.parse(str(schema3)), schema3) +<<<<<<< HEAD +======= + def test_while_loop_schema_gen(self): + fn, inp = WHILE_LOOP_TESTS["simple_with_linear"] + graph = make_fx(fn)(*inp).graph + while_loop_node = next( + node + for node in graph.nodes + if node.op == "call_function" + and node.target is torch.ops.higher_order.while_loop + ) + schema = torch._library.utils.hop_schema_from_fx_node(while_loop_node) + self.assertExpectedInline( + str(schema), + """while_loop(GraphModule cond_fn, GraphModule body_fn, Tensor[2] carried_inputs, Tensor[3] additional_inputs) -> Tensor[2]""", # noqa: B950 + ) + self.assertEqual(schema.parse(str(schema)), schema) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_schema_tree_spec(self): schema_gen = HopSchemaGenerator(torch.ops.higher_order.cond) args = (torch.randn(3, 4), torch.randn(2, 3)) @@ -9110,6 +9910,7 @@ def test_schema_tree_spec(self): str(flat_schema), """cond(Tensor tuple_args0, Tensor tuple_args1) -> ()""" ) +<<<<<<< HEAD def test_cond_gen_schema_tensor_inputs(self): schema = torch.ops.higher_order.cond.gen_schema( torch.tensor(True), @@ -9306,6 +10107,8 @@ def body_fn(x, y, z, c): """while_loop(Any cond_fn, Any body_fn, Tensor(a2!) carried_input0, Tensor(a3!) carried_input1, Tensor(a4!) carried_input2, Tensor(a5!) additional_input0) -> (Tensor, Tensor, Tensor)""", # noqa: B950 ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) instantiate_parametrized_tests(TestHopSchema) instantiate_parametrized_tests(TestControlFlowTraced) diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index 78e64278cb1e2..0f8ba716ec57d 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -399,6 +399,7 @@ def is_inplace(op, variant): "as_strided_copy", } +<<<<<<< HEAD bool_unsupported_ordered_ops = { "topk", "argmin", @@ -431,6 +432,8 @@ def is_inplace(op, variant): filter(lambda op: op.name in complex_unsupported_ordered_ops, op_db) ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(TEST_WITH_ASAN, "tests time out with asan, are probably redundant") @unMarkDynamoStrictTest @@ -468,6 +471,7 @@ class TestOperators(TestCase): ), # Works on ROCm xfail("torch.ops.aten._flash_attention_forward"), xfail("torch.ops.aten._efficient_attention_forward"), +<<<<<<< HEAD # RuntimeError: Expected contiguous tensor, but got # non-contiguous tensor for argument #2 'grad_output' decorate( @@ -475,6 +479,8 @@ class TestOperators(TestCase): decorator=expectedFailureIf(TEST_WITH_ROCM), device_type="cuda", ), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } ), ) @@ -2400,6 +2406,7 @@ def fn(input, weight, bias): skip("sparse.sampled_addmm", ""), skip("sparse.mm", "reduce"), skip("native_layer_norm", "", device_type="cpu"), +<<<<<<< HEAD # RuntimeError: Expected contiguous tensor, but got # non-contiguous tensor for argument #2 'grad_output' decorate( @@ -2407,6 +2414,8 @@ def fn(input, weight, bias): decorator=expectedFailureIf(TEST_WITH_ROCM), device_type="cuda", ), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, ) @opsToleranceOverride( @@ -3000,6 +3009,7 @@ def func(x): actual_fn(torch.ones_like(actual_o)), ) +<<<<<<< HEAD @ops(bool_ordered_op_db, dtypes=[torch.bool]) def test_ordered_bool_raises(self, device, dtype, op): # Generate sample inputs for the op @@ -3033,6 +3043,8 @@ def test_ordered_complex_raises(self, device, dtype, op): **sample_input.kwargs, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) only_for = ("cpu", "cuda") instantiate_device_type_tests(TestOperators, globals(), only_for=only_for) diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index 0f893201733d3..c34e7ae805467 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -734,7 +734,10 @@ def test_fallback_does_not_warn_by_default(self): # warning, not a warning from the vmap fallback path. self.assertEqual(len(wa), 1) +<<<<<<< HEAD @skipIfTorchDynamo("Flaky test") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.expectedFailure def test_fallback_warns_when_warnings_are_enabled(self): # NB: One day we will implement a batching rule for torch.atan2. @@ -4152,7 +4155,11 @@ def test(): with subtest_ctx(self), skip_xfail_ctx(self): args = (sample_input.input,) + sample_input.args if not any(isinstance(arg, torch.Tensor) for arg in args): +<<<<<<< HEAD # At least one tensor required for vmap. +======= + # Atleast one tensor required for vmap. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue kwargs = sample_input.kwargs is_batch_norm_and_training = is_batch_norm_training(op.name, kwargs) @@ -4230,7 +4237,11 @@ def sample_vmap_out_dim_numpy_split_copy_with_int( xfail("as_strided_copy"), xfail( "as_strided_scatter" +<<<<<<< HEAD ), # no batching rule implemented, default doesn't work +======= + ), # no batching rule implemented, default doesnt work +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) skip( "new_empty_strided" ), # empty tensor data is garbage so it's hard to make comparisons with it @@ -4534,6 +4545,7 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail("clamp_min", ""), xfail("sparse.sampled_addmm"), xfail("sparse.mm", "reduce"), +<<<<<<< HEAD xfail("special.chebyshev_polynomial_t"), xfail("special.chebyshev_polynomial_v"), xfail("special.chebyshev_polynomial_u"), @@ -4542,13 +4554,19 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail("special.shifted_chebyshev_polynomial_v"), xfail("special.shifted_chebyshev_polynomial_u"), xfail("special.shifted_chebyshev_polynomial_w"), +======= + xfail("special.chebyshev_polynomial_u"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) xfail("_segment_reduce", "offsets"), xfail("index_reduce", "prod"), xfail("index_reduce", "mean"), xfail("index_reduce", "amin"), xfail("index_reduce", "amax"), xfail("special.laguerre_polynomial_l"), +<<<<<<< HEAD xfail("special.legendre_polynomial_p"), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) xfail("special.hermite_polynomial_h"), xfail("jiterator_binary", device_type="cuda"), xfail("jiterator_4inputs_with_extra_args", device_type="cuda"), @@ -4556,6 +4574,10 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail("lu_solve", ""), xfail("special.hermite_polynomial_he"), xfail("nn.functional.dropout3d", ""), +<<<<<<< HEAD +======= + xfail("special.chebyshev_polynomial_t"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) xfail("as_strided_scatter", ""), xfail("equal", ""), xfail("linalg.lu", ""), diff --git a/test/functorch/test_vmap_registrations.py b/test/functorch/test_vmap_registrations.py index adb66ac4d9709..516914846e10e 100644 --- a/test/functorch/test_vmap_registrations.py +++ b/test/functorch/test_vmap_registrations.py @@ -208,7 +208,10 @@ "aten::subtract_.Scalar", "aten::subtract_.Tensor", "aten::svd.U", +<<<<<<< HEAD "aten::sym_is_contiguous", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "aten::sym_size.int", "aten::sym_stride.int", "aten::sym_numel", diff --git a/test/fx/test_dce_pass.py b/test/fx/test_dce_pass.py index 7fd3a6dbb0041..dcd7075cd1d58 100644 --- a/test/fx/test_dce_pass.py +++ b/test/fx/test_dce_pass.py @@ -238,8 +238,12 @@ def forward(self, a: torch.Tensor) -> torch.Tensor: def test_impure_random(self): """ +<<<<<<< HEAD Test that DCE doesn't remove call_function for torch.rand and other random functions. Tests both FX tracing and AOT compilation (issue #151524). +======= + Test that DCE doesn't remove call_function for torch.rand. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ class TestModule(torch.nn.Module): @@ -247,6 +251,7 @@ def forward(self, a: torch.Tensor) -> torch.Tensor: x = torch.rand([10]) # noqa: F841 return a * 2 +<<<<<<< HEAD # Test FX tracing + DCE self._run_dce_and_test(TestModule(), expect_dce_changes=False) @@ -304,6 +309,11 @@ def count_random_ops(): compiled_result = torch.compile(model, backend=aot_backend)(torch.tensor([1.0])) self.assertEqual(eager_result, compiled_result) +======= + # %torch.rand should not be removed because it has side effects. + self._run_dce_and_test(TestModule(), expect_dce_changes=False) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_impure_kwargs(self): """ Test that DCE doesn't remove call_function nodes with side effects on kwargs. @@ -338,6 +348,11 @@ def test_keep_collectives(self): Test that DCE doesn't remote collective ops even the results are not used. """ +<<<<<<< HEAD +======= + from torch.testing._internal.distributed.fake_pg import FakeStore + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestModule(torch.nn.Module): def forward( self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor @@ -352,6 +367,10 @@ def forward( backend="fake", world_size=2, rank=0, +<<<<<<< HEAD +======= + store=FakeStore(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # collective nodes should not be removed because they have side effects. self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=False) @@ -363,6 +382,11 @@ def test_keep_collectives_no_overload(self): Test that DCE doesn't remote collective ops (no overload version) even the results are not used. """ +<<<<<<< HEAD +======= + from torch.testing._internal.distributed.fake_pg import FakeStore + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestModule(torch.nn.Module): def forward( self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor @@ -377,6 +401,10 @@ def forward( backend="fake", world_size=2, rank=0, +<<<<<<< HEAD +======= + store=FakeStore(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # collective nodes should not be removed because they have side effects. self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=False) diff --git a/test/fx/test_fx_traceback.py b/test/fx/test_fx_traceback.py index 05369d17078ba..be9f34d4e33d9 100644 --- a/test/fx/test_fx_traceback.py +++ b/test/fx/test_fx_traceback.py @@ -2,7 +2,10 @@ import torch from torch._inductor.compile_fx import aot_export_module +<<<<<<< HEAD from torch.export import default_decompositions +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.fx.traceback import get_graph_provenance_json, NodeSource, NodeSourceAction from torch.testing._internal.common_utils import TestCase @@ -32,8 +35,11 @@ def test_node_source(self): dummy_source_dict, ) +<<<<<<< HEAD self.assertEqual(node_source, NodeSource._from_dict(node_source.to_dict())) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Dummy node node = torch.fx.Node( graph=torch.fx.Graph(), @@ -67,6 +73,7 @@ def test_node_source(self): }, ) +<<<<<<< HEAD # Test two node sources are same node_source1 = NodeSource( node=None, pass_name="test_pass", action=NodeSourceAction.CREATE @@ -123,6 +130,8 @@ def test_node_source(self): self.assertNotEqual(node_source_replace, node_source_create) self.assertNotEqual(hash(node_source_replace), hash(node_source_create)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_graph_provenance(self): def check_node_source(node_source_dict, name, pass_name, action): self.assertEqual(node_source_dict["name"], name) @@ -154,6 +163,7 @@ def forward(self, x): model = Model() example_inputs = (torch.randn(8, 10),) ep = torch.export.export(model, example_inputs, strict=True) +<<<<<<< HEAD decomposed_ep = ep.run_decompositions(default_decompositions()) # node decomposed from same ancestor node should have same from_node info @@ -205,6 +215,8 @@ def forward(self, x): node_name_to_from_node[node_name_2], ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) gm = ep.module() provenance = get_graph_provenance_json(gm.graph) self.assertEqual( diff --git a/test/fx/test_fx_xform_observer.py b/test/fx/test_fx_xform_observer.py index d9dcb8504ba7b..36208d6e2a196 100644 --- a/test/fx/test_fx_xform_observer.py +++ b/test/fx/test_fx_xform_observer.py @@ -55,7 +55,11 @@ def replacement(x): ) ) +<<<<<<< HEAD @torch._inductor.config.patch("trace.provenance_tracking_level", 1) +======= + @torch._inductor.config.patch("trace.enabled", True) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_graph_transform_observer_node_tracking(self): class M(torch.nn.Module): def forward(self, x): @@ -156,7 +160,11 @@ def forward(self, x): [NodeSourceAction.REPLACE, NodeSourceAction.CREATE], ) +<<<<<<< HEAD @torch._inductor.config.patch("trace.provenance_tracking_level", 1) +======= + @torch._inductor.config.patch("trace.enabled", True) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_graph_transform_observer_deepcopy(self): class SimpleLinearModel(torch.nn.Module): def forward(self, x): @@ -179,6 +187,7 @@ def forward(self, x): self.assertEqual(len(gm2._erase_node_hooks), 0) self.assertEqual(len(gm2._deepcopy_hooks), 0) +<<<<<<< HEAD @torch._inductor.config.patch("trace.provenance_tracking_level", 1) def test_graph_transform_observer_replace(self): # the node sohuld should not be duplicated @@ -205,6 +214,8 @@ def forward(self, x): self.assertEqual(new_node.meta["from_node"][0].name, "add") self.assertEqual(new_node.meta["from_node"][0].pass_name, "test") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": raise RuntimeError( diff --git a/test/fx/test_lazy_graph_module.py b/test/fx/test_lazy_graph_module.py index a17bcb9151def..6d51d92a30d58 100644 --- a/test/fx/test_lazy_graph_module.py +++ b/test/fx/test_lazy_graph_module.py @@ -69,7 +69,11 @@ def f(x): def test_needs_recompile(self): """ +<<<<<<< HEAD Make sure needs_recompile() return the correct state. +======= + Make sure needs_recompile() return the corrent state. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ def f(x): @@ -141,7 +145,11 @@ def f(x): self.assertTrue(isinstance(gm2, _LazyGraphModule)) self.assertTrue(gm2._needs_recompile()) +<<<<<<< HEAD # make_fx will cal forward method of gm. That clears the _needs_recompile() +======= + # make_fx will cal foward method of gm. That clears the _needs_recompile() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # flag. self.assertFalse(gm._needs_recompile()) @@ -175,7 +183,11 @@ def f(x): def test_save_lazy_foward(self): """ +<<<<<<< HEAD Save the lazy forward method and call it repeatedly. Make sure we +======= + Save the lazy forward method and call it repeatly. Make sure we +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) don't recompile for each such call. """ diff --git a/test/fx/test_partitioner_order.py b/test/fx/test_partitioner_order.py index f4c3ef072f9a6..b1e6100079211 100644 --- a/test/fx/test_partitioner_order.py +++ b/test/fx/test_partitioner_order.py @@ -24,7 +24,10 @@ def __init__(self, graph_module: torch.fx.GraphModule): ) +<<<<<<< HEAD # original graph node order is: ['x', 'add', 'add_1', 'output'] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class AddModule(torch.nn.Module): def forward(self, x): y = torch.add(x, x) @@ -33,6 +36,7 @@ def forward(self, x): class TestPartitionerOrder(TestCase): +<<<<<<< HEAD # partitoner test to check graph node order remains the same with the original graph after partitioning def test_partitioner_graph_node_order(self): m = AddModule() @@ -50,6 +54,15 @@ def test_partitioner_multiple_runs_order(self): partitions = DummyPartitioner(traced_m).propose_partitions() partition_nodes = [list(partition.nodes) for partition in partitions] node_order = [n.name for n in partition_nodes[0]] +======= + # partitoner test to check graph node order + def test_partitioner_order(self): + m = AddModule() + traced_m = torch.fx.symbolic_trace(m) + partions = DummyPartitioner(traced_m).propose_partitions() + partion_nodes = [list(partition.nodes) for partition in partions] + node_order = [n.name for n in partion_nodes[0]] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for _ in range(10): traced_m = torch.fx.symbolic_trace(m) new_partion = DummyPartitioner(traced_m).propose_partitions() diff --git a/test/fx/test_pass_infra.py b/test/fx/test_pass_infra.py index 47531e15040eb..97cc813566b89 100644 --- a/test/fx/test_pass_infra.py +++ b/test/fx/test_pass_infra.py @@ -131,7 +131,11 @@ def check_bad_args(graph_module, i): def test_topological_sort(self): """ +<<<<<<< HEAD Tests that passes are correctly ordered based on constraints. +======= + Tests that passes are correctly ordered based on contraints. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ def pass0(x): diff --git a/test/higher_order_ops/test_invoke_quant.py b/test/higher_order_ops/test_invoke_quant.py index 7796a9e4a1685..2d894662458ab 100644 --- a/test/higher_order_ops/test_invoke_quant.py +++ b/test/higher_order_ops/test_invoke_quant.py @@ -186,7 +186,11 @@ def quant_matching(match: Match, *args, **kwargs): @skipIfXpu( msg="MM Triton template fusion for XPU not work because the fusion" +<<<<<<< HEAD " can not speedup, unskip until #146568 fixed." +======= + " can not speedup, unskip untill #146568 fixed." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @requires_gpu() @config.patch(prologue_fusion=True) diff --git a/test/higher_order_ops/test_invoke_subgraph.py b/test/higher_order_ops/test_invoke_subgraph.py index 34d8e41d8978e..7fde9a60cfc83 100644 --- a/test/higher_order_ops/test_invoke_subgraph.py +++ b/test/higher_order_ops/test_invoke_subgraph.py @@ -21,7 +21,10 @@ normalize_gm, ) from torch._higher_order_ops.schema import find_hop_schema +<<<<<<< HEAD from torch._inductor import config as inductor_config +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.pattern_matcher import ( CallFunctionVarArgs, PatternMatcherPass, @@ -34,7 +37,11 @@ TestCase, ) from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU +<<<<<<< HEAD from torch.testing._internal.triton_utils import requires_cuda_and_triton, requires_gpu +======= +from torch.testing._internal.triton_utils import requires_cuda, requires_gpu +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) nested_compile_region = torch.compiler.nested_compile_region @@ -253,6 +260,13 @@ def fn(mod, x, y): y_clone = y.detach().clone().requires_grad_(True) backend = EagerAndRecordGraphs() with ( +<<<<<<< HEAD +======= + mock.patch( + "torch._dynamo.variables.higher_order_ops.InvokeSubgraphHigherOrderVariable.supports_input_mutation", + True, + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.no_grad(), ): res = torch.compile(fn, backend=backend, fullgraph=True)( @@ -329,8 +343,19 @@ def fn(mod, x, y): x_clone = x.detach().clone().requires_grad_(True) y_clone = y.detach().clone().requires_grad_(True) backend = AotEagerAndRecordGraphs() +<<<<<<< HEAD res = torch.compile(fn, backend=backend, fullgraph=True)(mod, x_clone, y_clone) res.sum().backward() +======= + with mock.patch( + "torch._dynamo.variables.higher_order_ops.InvokeSubgraphHigherOrderVariable.supports_input_mutation", + True, + ): + res = torch.compile(fn, backend=backend, fullgraph=True)( + mod, x_clone, y_clone + ) + res.sum().backward() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(len(backend.fw_graphs), 1) self.assertEqual(len(backend.bw_graphs), 1) self.assertEqual(ref, res) @@ -424,23 +449,49 @@ def fn(mod, x, y): x_clone = x.detach().clone().requires_grad_(True) y_clone = y.detach().clone().requires_grad_(True) +<<<<<<< HEAD with torch.no_grad(): res = torch.compile(fn, fullgraph=True)(mod, x_clone, y_clone) +======= + with mock.patch( + "torch._dynamo.variables.higher_order_ops.InvokeSubgraphHigherOrderVariable.supports_input_mutation", + True, + ): + with torch.no_grad(): + res = torch.compile(fn, fullgraph=True)(mod, x_clone, y_clone) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(ref, res) self.assertEqual(mod_ref.buf, mod.buf) mod = Mod() x_clone = x.detach().clone().requires_grad_(True) y_clone = y.detach().clone().requires_grad_(True) +<<<<<<< HEAD with torch.inference_mode(): res = torch.compile(fn, fullgraph=True)(mod, x_clone, y_clone) +======= + with mock.patch( + "torch._dynamo.variables.higher_order_ops.InvokeSubgraphHigherOrderVariable.supports_input_mutation", + True, + ): + with torch.inference_mode(): + res = torch.compile(fn, fullgraph=True)(mod, x_clone, y_clone) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(ref, res) self.assertEqual(mod_ref.buf, mod.buf) mod = Mod() x_clone = x.detach().clone().requires_grad_(False) y_clone = y.detach().clone().requires_grad_(False) +<<<<<<< HEAD res = torch.compile(fn, fullgraph=True)(mod, x_clone, y_clone) +======= + with mock.patch( + "torch._dynamo.variables.higher_order_ops.InvokeSubgraphHigherOrderVariable.supports_input_mutation", + True, + ): + res = torch.compile(fn, fullgraph=True)(mod, x_clone, y_clone) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(ref, res) self.assertEqual(mod_ref.buf, mod.buf) @@ -466,7 +517,15 @@ def fn(mod, x, y): RuntimeError, "does not currently support training with in-place input or buffer mutations", ): +<<<<<<< HEAD torch.compile(fn, backend="inductor", fullgraph=True)(mod, x, y) +======= + with mock.patch( + "torch._dynamo.variables.higher_order_ops.InvokeSubgraphHigherOrderVariable.supports_input_mutation", + True, + ): + torch.compile(fn, backend="inductor", fullgraph=True)(mod, x, y) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_list(self): @nested_compile_region @@ -556,7 +615,11 @@ def fn(x): self.assertEqual(ref, res) self.assertEqual(x.grad, x_clone.grad) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_sdpa(self): @nested_compile_region def gn(q, k, v): @@ -620,7 +683,10 @@ def fn(x, y): self.assertEqual(ref, res) res.sum().backward() +<<<<<<< HEAD @inductor_config.patch("fx_graph_cache", False) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_dropout_checks_joint_graph(self): # `dropout` tests that joint graph passes (not just partitioner) is ran # on the hop graphs. Inductor rng functionalization happens in the joint @@ -677,9 +743,15 @@ def forward(self, primals_0: "f32[8]"): sin: "f32[8]" = torch.ops.aten.sin.default(primals_0) inductor_seeds_default: "i64[1]" = torch.ops.prims.inductor_seeds.default(1, device(type='cpu')) +<<<<<<< HEAD inductor_lookup_seed_default: "i64[]" = torch.ops.prims.inductor_lookup_seed.default(inductor_seeds_default, 0); inductor_seeds_default = None inductor_random_default: "f32[8]" = torch.ops.prims.inductor_random.default([8], inductor_lookup_seed_default, 'rand'); inductor_lookup_seed_default = None +======= + inductor_lookup_seed_default: "i64[]" = torch.ops.prims.inductor_lookup_seed.default(inductor_seeds_default, 0); inductor_seeds_default = None + inductor_random_default: "f32[8]" = torch.ops.prims.inductor_random.default([8], inductor_lookup_seed_default, 'rand'); inductor_lookup_seed_default = None + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) gt: "b8[8]" = torch.ops.aten.gt.Scalar(inductor_random_default, 0.5); inductor_random_default = None mul: "f32[8]" = torch.ops.aten.mul.Tensor(gt, sin); sin = None mul_1: "f32[8]" = torch.ops.aten.mul.Tensor(mul, 2.0); mul = None @@ -692,7 +764,10 @@ def forward(self, primals_0: "f32[8]"): """, ) +<<<<<<< HEAD @inductor_config.patch("fx_graph_cache", False) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_dropout_checks_joint_graph_inference(self): # Checks that joint graph results in inductor seeds for just the inference graph @nested_compile_region @@ -722,9 +797,15 @@ def forward(self, arg0_1: "f32[8]"): class repeated_subgraph0(torch.nn.Module): def forward(self, arg0_1: "f32[8]"): inductor_seeds_default: "i64[1]" = torch.ops.prims.inductor_seeds.default(1, device(type='cpu')) +<<<<<<< HEAD inductor_lookup_seed_default: "i64[]" = torch.ops.prims.inductor_lookup_seed.default(inductor_seeds_default, 0); inductor_seeds_default = None inductor_random_default: "f32[8]" = torch.ops.prims.inductor_random.default([8], inductor_lookup_seed_default, 'rand'); inductor_lookup_seed_default = None +======= + inductor_lookup_seed_default: "i64[]" = torch.ops.prims.inductor_lookup_seed.default(inductor_seeds_default, 0); inductor_seeds_default = None + inductor_random_default: "f32[8]" = torch.ops.prims.inductor_random.default([8], inductor_lookup_seed_default, 'rand'); inductor_lookup_seed_default = None + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) gt: "b8[8]" = torch.ops.aten.gt.Scalar(inductor_random_default, 0.5); inductor_random_default = None sin: "f32[8]" = torch.ops.aten.sin.default(arg0_1); arg0_1 = None mul: "f32[8]" = torch.ops.aten.mul.Tensor(gt, sin); gt = sin = None @@ -920,7 +1001,10 @@ def forward(self, a: "f32[8]", l_y_: "f32[8]"): """, ) +<<<<<<< HEAD @inductor_config.patch("fx_graph_cache", False) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_view_to_reshape(self): @nested_compile_region def gn(x): @@ -1033,6 +1117,7 @@ def fn(x, y): opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) +<<<<<<< HEAD x_clone = x.clone() self.assertEqual(opt_fn(x, y), fn(x_clone, y)) @@ -1115,6 +1200,19 @@ def _mock_invoke_subgraph(mode, subgraph, identifier, *operands): exp_out = fn(x_clone, y) self.assertEqual(exp_out, out) self.assertEqual(x_clone, x) +======= + with self.assertRaisesRegex( + RuntimeError, + "torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph", + ) as cm: + opt_fn(x, y) + + cause = cm.exception.__cause__ + self.assertIsInstance(cause, torch._dynamo.exc.Unsupported) + self.assertTrue( + "Encountered input mutation during higher order op tracing" in str(cause) + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_input_mutation_inference_mode(self): @nested_compile_region @@ -1133,10 +1231,23 @@ def fn(x, y): with self.assertRaisesRegex( RuntimeError, +<<<<<<< HEAD "Inplace update to inference tensor outside InferenceMode is not allowed", ): opt_fn(x, y) +======= + "torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph", + ) as cm: + opt_fn(x, y) + + cause = cm.exception.__cause__ + self.assertIsInstance(cause, torch._dynamo.exc.Unsupported) + self.assertTrue( + "Encountered input mutation during higher order op tracing" in str(cause) + ) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_simple_module(self): mod = torch.nn.Linear(8, 8) @@ -1195,11 +1306,25 @@ def fn(x, y): opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) with self.assertRaisesRegex( +<<<<<<< HEAD torch._dynamo.exc.UncapturedHigherOrderOpError, "Encountered aliasing during higher order op tracing", ): opt_fn(x, y) +======= + RuntimeError, + "torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph", + ) as cm: + opt_fn(x, y) + + cause = cm.exception.__cause__ + self.assertIsInstance(cause, torch._dynamo.exc.Unsupported) + self.assertTrue( + "Encountered aliasing during higher order op tracing" in str(cause) + ) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_input_input_aliasing(self): @nested_compile_region def gn(x, y): @@ -1213,11 +1338,25 @@ def fn(x): opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) with self.assertRaisesRegex( +<<<<<<< HEAD torch._dynamo.exc.UncapturedHigherOrderOpError, "Encountered aliasing during higher order op tracing", ): opt_fn(x) +======= + RuntimeError, + "torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph", + ) as cm: + opt_fn(x) + + cause = cm.exception.__cause__ + self.assertIsInstance(cause, torch._dynamo.exc.Unsupported) + self.assertTrue( + "Encountered aliasing during higher order op tracing" in str(cause) + ) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_output_output_aliasing(self): @nested_compile_region def gn(x): @@ -1232,11 +1371,25 @@ def fn(x): opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) with self.assertRaisesRegex( +<<<<<<< HEAD torch._dynamo.exc.UncapturedHigherOrderOpError, "Encountered aliasing during higher order op tracing", ): opt_fn(x) +======= + RuntimeError, + "torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph", + ) as cm: + opt_fn(x) + + cause = cm.exception.__cause__ + self.assertIsInstance(cause, torch._dynamo.exc.Unsupported) + self.assertTrue( + "Encountered aliasing during higher order op tracing" in str(cause) + ) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_mod_attr_aliasing(self): class MutateParam(torch.nn.Module): def __init__(self): @@ -1258,12 +1411,30 @@ def fn(x, y): x = torch.randn(8, requires_grad=False) y = torch.randn(8, requires_grad=False) +<<<<<<< HEAD opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) compiled_out = opt_fn(x, y) # reset constant attr mod.a = torch.ones(8) self.assertEqual(compiled_out, fn(x, y)) +======= + fn(x, y) + + opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) + + with self.assertRaisesRegex( + RuntimeError, + "torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph", + ) as cm: + opt_fn(x, y) + + cause = cm.exception.__cause__ + self.assertIsInstance(cause, torch._dynamo.exc.Unsupported) + self.assertTrue( + "Encountered input mutation during higher order op tracing" in str(cause) + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_redundant_compile_region(self): @nested_compile_region @@ -1429,7 +1600,11 @@ def forward(self, l_x_: "f32[8, 8]"): """, ) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_return_none(self): from torch.nn import functional as F @@ -1498,7 +1673,11 @@ def forward(self, L_x_: "f32[8, 8]"): subgraph_0 = self.subgraph_0 invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_); subgraph_0 = l_x_ = None getitem: "f32[8, 8]" = invoke_subgraph[0] +<<<<<<< HEAD getitem_1: "f32[8, 8]" = invoke_subgraph[1]; invoke_subgraph = None +======= + getitem_1: "f32[8, 8]" = invoke_subgraph[2]; invoke_subgraph = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) add: "f32[8, 8]" = getitem + getitem_1; getitem = getitem_1 = None return (add,) @@ -1507,7 +1686,11 @@ class subgraph_0(torch.nn.Module): def forward(self, l_x_: "f32[8, 8]"): child: "f32[8, 8]" = l_x_ * 2 child_1: "f32[8, 8]" = l_x_ * 3; l_x_ = None +<<<<<<< HEAD return (child, child_1) +======= + return (child, None, child_1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """, ) @@ -1520,16 +1703,26 @@ def forward(self, primals_1: "f32[8, 8]"): invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_0, 'partitioned_fw_subgraph_0_0', primals_1); partitioned_fw_subgraph_0_0 = primals_1 = None getitem: "f32[8, 8]" = invoke_subgraph_2[0] +<<<<<<< HEAD getitem_1: "f32[8, 8]" = invoke_subgraph_2[1]; invoke_subgraph_2 = None add: "f32[8, 8]" = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None +======= + getitem_2: "f32[8, 8]" = invoke_subgraph_2[2]; invoke_subgraph_2 = None + + add: "f32[8, 8]" = torch.ops.aten.add.Tensor(getitem, getitem_2); getitem = getitem_2 = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return (add,) class partitioned_fw_subgraph_0_0(torch.nn.Module): def forward(self, primals_0: "f32[8, 8]"): mul: "f32[8, 8]" = torch.ops.aten.mul.Tensor(primals_0, 2) mul_1: "f32[8, 8]" = torch.ops.aten.mul.Tensor(primals_0, 3); primals_0 = None +<<<<<<< HEAD return (mul, mul_1) +======= + return (mul, None, mul_1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """, ) @@ -1541,8 +1734,13 @@ def forward(self, tangents_1: "f32[8, 8]"): partitioned_bw_subgraph_0_0 = self.partitioned_bw_subgraph_0_0 invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(partitioned_bw_subgraph_0_0, 'partitioned_bw_subgraph_0_0', tangents_1, tangents_1); partitioned_bw_subgraph_0_0 = tangents_1 = None +<<<<<<< HEAD getitem_2: "f32[8, 8]" = invoke_subgraph_3[0]; invoke_subgraph_3 = None return (getitem_2,) +======= + getitem_3: "f32[8, 8]" = invoke_subgraph_3[0]; invoke_subgraph_3 = None + return (getitem_3,) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class partitioned_bw_subgraph_0_0(torch.nn.Module): def forward(self, tangents_0: "f32[8, 8]", tangents_1: "f32[8, 8]"): @@ -1772,6 +1970,7 @@ def fn(x): res = torch.compile(fn, backend="inductor", fullgraph=True)(x_clone) self.assertEqual(ref, res) +<<<<<<< HEAD @torch._inductor.config.patch(fallback_random=True) def test_ac_rng(self): def fn1(x): @@ -1838,6 +2037,8 @@ def fn(q, k, v): )(q, k, v) res.sum().backward() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_fake_tensor_checking(self): @nested_compile_region def gn(x): @@ -1888,6 +2089,7 @@ def forward(self, l_y_: "f32[16, 16]"): """, ) +<<<<<<< HEAD def test_return_size(self): def run(dynamic): torch.compiler.reset() @@ -1919,6 +2121,8 @@ def fn(x): run(dynamic=True) run(dynamic=False) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_different_symint(self): """ Tests check that the same subgraph called with different symints use different graphs @@ -2123,7 +2327,11 @@ def fn(x, y): # NOTE THAT THIS TEST DOES NOT REALLY WORK # We wanted one invoke_subgraph called twice, but because of +<<<<<<< HEAD # constant_args_idx changing in the graph, the graph equivalence fails +======= + # constant_args_idx changing in the grpah, the graph equivalence fails +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not TEST_WITH_CROSSREF: self.assertExpectedInline( diff --git a/test/higher_order_ops/test_with_effects.py b/test/higher_order_ops/test_with_effects.py index 67facfb127d8e..2658a1b840081 100644 --- a/test/higher_order_ops/test_with_effects.py +++ b/test/higher_order_ops/test_with_effects.py @@ -328,7 +328,11 @@ def record_scalar_tensor(x, prefix): return # Meta function of the custom op +<<<<<<< HEAD @torch.library.register_fake( +======= + @torch.library.impl_abstract( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "mylib::record_scalar_tensor", lib=lib, ) diff --git a/test/inductor/custom_ops.cpp b/test/inductor/custom_ops.cpp index ade7695a10d02..1b9a5df656880 100644 --- a/test/inductor/custom_ops.cpp +++ b/test/inductor/custom_ops.cpp @@ -1,7 +1,12 @@ #include // @manual=fbcode//caffe2:libtorch +<<<<<<< HEAD #include // @manual #include // @manual +======= +#include +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include diff --git a/test/inductor/indirect_assert_helper.py b/test/inductor/indirect_assert_helper.py index 6d1bc2b608fba..c34f427326360 100644 --- a/test/inductor/indirect_assert_helper.py +++ b/test/inductor/indirect_assert_helper.py @@ -73,6 +73,7 @@ def lower2(x): shape = (y.numel(),) + x.shape[2:] z = torch.randn(shape, device=GPU_TYPE) fn(x, y, z) +<<<<<<< HEAD # On Windows, Python will optimize away a function call if its updated value is not used. # Touch the memory of x so that the fn(x, y, z) will not be optimized away print(x) @@ -80,3 +81,9 @@ def lower2(x): print(fn(x)) else: print(fn(x, y)) +======= + elif fn_name in ("upper1", "upper2", "lower1", "lower2"): + fn(x) + else: + fn(x, y) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 917a914a5359e..3a743130156b1 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -20,6 +20,7 @@ from torch._dynamo.testing import rand_strided, same from torch._dynamo.utils import counters from torch._inductor import config +<<<<<<< HEAD from torch._inductor.codecache import WritableTempFile from torch._inductor.cpp_builder import normalize_path_separator from torch._inductor.package import package_aoti @@ -31,6 +32,12 @@ run_and_get_cpp_code, ) from torch._library import capture_triton +======= +from torch._inductor.package import package_aoti +from torch._inductor.runtime.runtime_utils import cache_dir +from torch._inductor.test_case import TestCase +from torch._inductor.utils import is_big_gpu, run_and_get_cpp_code +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._utils_internal import full_aoti_runtime_assert from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer @@ -39,12 +46,18 @@ from torch.testing import FileCheck from torch.testing._internal import common_utils from torch.testing._internal.common_cuda import ( +<<<<<<< HEAD _get_torch_cuda_version, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_FP8, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, SM80OrLater, +<<<<<<< HEAD tf32_on_and_off, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) from torch.testing._internal.common_device_type import ( _has_sufficient_memory, @@ -61,6 +74,7 @@ IS_FBCODE, IS_MACOS, IS_WINDOWS, +<<<<<<< HEAD MACOS_VERSION, MI300_ARCH, parametrize, @@ -71,6 +85,11 @@ skipIfWindows, skipIfXpu, TEST_MPS, +======= + parametrize, + skipIfRocm, + skipIfXpu, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TEST_WITH_ROCM, ) from torch.testing._internal.custom_tensor import CustomTensorPlainOut @@ -95,7 +114,10 @@ add_kernel_autotuned_weird_param_order, add_kernel_on_device_tma_new_api, add_kernel_on_device_tma_old_api, +<<<<<<< HEAD add_kernel_with_boolean_param, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) add_kernel_with_none_param_and_equal_to_1_arg, add_kernel_with_optional_param, add_kernel_with_scaling, @@ -156,6 +178,7 @@ raise +<<<<<<< HEAD def get_module_ext_type(): if IS_WINDOWS: return "pyd" @@ -169,6 +192,11 @@ class AOTInductorTestsTemplate: @common_utils.parametrize("embed_kernel_binary", [False, True]) @common_utils.parametrize("max_autotune", [False, True]) @skipIfRocmArch(MI300_ARCH) +======= +class AOTInductorTestsTemplate: + @common_utils.parametrize("embed_kernel_binary", [False, True]) + @common_utils.parametrize("max_autotune", [False, True]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_simple(self, embed_kernel_binary, max_autotune): if self.device == "cpu" and IS_MACOS and max_autotune: raise unittest.SkipTest("max_autotune not supported on macos") @@ -197,9 +225,13 @@ def forward(self, x, y): _, code = run_and_get_cpp_code( AOTIRunnerUtil.compile, model, example_inputs ) +<<<<<<< HEAD if self.device == "mps": FileCheck().check("getKernelFunction(").run(code) elif self.device == GPU_TYPE: +======= + if self.device == GPU_TYPE: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) FileCheck().check("launchKernel(").run(code) if config.aot_inductor.embed_kernel_binary: # Not expect to see launchKernel("CUBIN_FILE_NAME" @@ -210,6 +242,7 @@ def forward(self, x, y): model, example_inputs, "AOTInductorModelRunMinimalArrayrefInterface(", 1 ) +<<<<<<< HEAD def test_triton_kernel_bool_param(self): if self.device != GPU_TYPE or self.device == "mps": raise unittest.SkipTest("requires GPU") @@ -230,10 +263,13 @@ def forward(self, x): inputs = (torch.randn(4, device=self.device),) self.check_model(Model(), inputs) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf( IS_FBCODE, "toolchain doesn't support ptx to fatbin", ) +<<<<<<< HEAD @skipIfMPS @skipIfRocm # Skip embed_kernel_binary == True for now as it shows random @@ -242,6 +278,10 @@ def forward(self, x): @unittest.skipIf( _get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+" ) +======= + @skipIfRocm + @common_utils.parametrize("embed_kernel_binary", [True, False]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_simple_multi_arch(self, embed_kernel_binary): if self.device != GPU_TYPE: raise unittest.SkipTest("requires GPU_TYPE") @@ -316,11 +356,15 @@ def forward(self, x, y): torch.randn(10, 10, device=self.device), torch.randn(10, 10, device=self.device), ) +<<<<<<< HEAD expected_path = normalize_path_separator( os.path.join( tempfile.mkdtemp(dir=cache_dir()), f"model.{get_module_ext_type()}" ) ) +======= + expected_path = os.path.join(tempfile.mkdtemp(dir=cache_dir()), "model.so") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) actual_path = AOTIRunnerUtil.legacy_compile( model, example_inputs, options={"aot_inductor.output_path": expected_path} ) @@ -483,6 +527,7 @@ def forward(self, y): ep, inductor_configs={"aot_inductor.use_runtime_constant_folding": True} ) +<<<<<<< HEAD @unittest.skipIf( TEST_MPS and MACOS_VERSION < 14.0, "Compilation error", @@ -516,6 +561,8 @@ def forward(self, y): }, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @common_utils.parametrize("dynamic", [False, True]) @common_utils.parametrize("tma_version", ["new", "old"]) def test_triton_kernel_on_device_tma(self, dynamic, tma_version): @@ -568,7 +615,11 @@ def forward(self, a, b): triton.set_allocator( lambda size, align, stream: torch.empty( +<<<<<<< HEAD size, dtype=torch.int8, device=GPU_TYPE +======= + size, dtype=torch.int8, device="cuda" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) ) @@ -768,7 +819,10 @@ def forward(self, y): IS_FBCODE, "Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used", ) +<<<<<<< HEAD @tf32_on_and_off(0.005) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_deconv_freezing(self): dtypes = [torch.float] if torch._C._has_mkldnn and torch.ops.mkldnn._is_mkldnn_bf16_supported(): @@ -844,10 +898,13 @@ def forward(self, a, b): inp = (torch.ones(3, device=self.device), torch.ones(3, device=self.device)) self.check_model(M(), inp) +<<<<<<< HEAD @unittest.skipIf( TEST_MPS and MACOS_VERSION < 14.0, "MPS BFloat16 is only supported on MacOS 14+", ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_empty_cat_dtype_promotion(self): class Foo(torch.nn.Module): def forward(self, x, y): @@ -1170,7 +1227,10 @@ def forward(self, x, y): example_inputs = (x, y) self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes) +<<<<<<< HEAD @skipIfWindows(msg="TODO: (xuhancn) confirm, Crash: access violation") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_large_dynamic_dim(self): class Model(torch.nn.Module): def __init__(self) -> None: @@ -1413,7 +1473,10 @@ def forward(self, a, b): dynamic_shapes=dynamic_shapes, ) +<<<<<<< HEAD @skipIfWindows(msg="TODO: (xuhancn) confirm, Crash: access violation") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_foreach_multiple_dynamic(self): class Model(torch.nn.Module): def __init__(self) -> None: @@ -1510,6 +1573,7 @@ def forward(self, x): self.check_model(Model(self.device), example_inputs) @skipIfNoFBGEMM +<<<<<<< HEAD def test_quantized_linear_bias_none(self): class Model(torch.nn.Module): def __init__(self, device): @@ -1526,6 +1590,8 @@ def forward(self, x): self.check_model(Model(self.device), example_inputs) @skipIfNoFBGEMM +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_quanatized_int8_linear(self): class Model(torch.nn.Module): def __init__(self, device): @@ -1595,10 +1661,13 @@ def forward(self, x, y): ) self.check_model(Repro(), example_inputs) +<<<<<<< HEAD @unittest.skipIf( TEST_MPS and MACOS_VERSION < 14.0, "bfloat16 is only supported on MacOS 14+", ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_size_with_unbacked_add_expr(self): # Tests AOTI autotuning to make sure the correct input tensor sizes # are generated for sizes that include an expr such as s0 + u0. @@ -1857,7 +1926,11 @@ def forward(self, x): Foo(user_float_feature_idx, self.device), example_inputs, strict=False ).run_decompositions() gm = ep.module() +<<<<<<< HEAD self.check_model(gm.to(self.device), example_inputs) +======= + self.check_model(gm, example_inputs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_large_grid(self): if self.device != GPU_TYPE: @@ -2060,7 +2133,10 @@ def test_cond_unbacked_symint_closure(self, dynamic): dynamic_shapes=dynamic_shapes, ) +<<<<<<< HEAD @skipIfWindows(msg="TODO: (xuhancn) confirm, Crash: access violation") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @common_utils.parametrize("dynamic", [False, True]) def test_cond_mismatched_branch_output(self, dynamic): inputs = ( @@ -2073,7 +2149,11 @@ def test_cond_mismatched_branch_output(self, dynamic): # Note the minimum has to be 4 because the model # is slicing over the first dim with [2:], if first # dim is 2 or 3, the slicing will be 0/1 specialized, +<<<<<<< HEAD # causing a constraint violation error. +======= + # causing a constraint violation eror. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dim0_a = Dim("s0", min=4, max=1024) dim0_b = Dim("s1", min=4, max=1024) dynamic_shapes = { @@ -2172,6 +2252,7 @@ def test_while_loop_with_outer_code(self): dynamic_shapes=dynamic_shapes, ) +<<<<<<< HEAD # mps doesn't support float64 @skipIfMPS def test_while_loop_with_parameters(self): @@ -2185,6 +2266,10 @@ def test_while_loop_with_parameters(self): device=self.device, ), ) +======= + def test_while_loop_with_parameters(self): + inputs = (torch.randn((10, 20), device=self.device),) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dim0_a = Dim("s0", min=2, max=1024) dynamic_shapes = { "c": {}, @@ -2438,9 +2523,12 @@ def forward(self, x): example_inputs = (torch.randn(10, device=self.device),) self.check_model(Model(self.device), example_inputs) +<<<<<<< HEAD @skipIfWindows( msg="OpenMP crashed application on windows" ) # TODO: (xuhancn) need to root cause and fix. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_buffer_mutation_3(self): class KVCache(torch.nn.Module): def __init__( @@ -2506,6 +2594,10 @@ def forward(self, x): torch._export.aot_compile(Model(), example_inputs) @skipCUDAIf(True, "Test for x86 backend") +<<<<<<< HEAD +======= + @skipIfXpu +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(IS_FBCODE, "Need newer ideep") def test_buffer_mutation_and_force_mmap_weights(self): class Model(nn.Module): @@ -2539,7 +2631,10 @@ def forward(self, x): self.check_model(converted_model, example_inputs) +<<<<<<< HEAD @skipIfMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_fallback_mem_leak_fix(self): if self.device != GPU_TYPE: raise unittest.SkipTest("requires GPU") @@ -2584,7 +2679,10 @@ def forward(self, x, y, idx): torch.testing.assert_close(actual, expected) @requires_multigpu() +<<<<<<< HEAD @skipIfMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_replicate_on_devices(self): if self.device != GPU_TYPE: raise unittest.SkipTest("requires GPU") @@ -2624,7 +2722,10 @@ def forward(self, x, y): self.assertTrue(same(result_cpu, result_gpu.cpu())) @requires_multigpu() +<<<<<<< HEAD @skipIfMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_on_gpu_device1(self): if self.device != GPU_TYPE: raise unittest.SkipTest("requires GPU") @@ -2774,11 +2875,15 @@ def forward(self, x, y): model, example_inputs, atol=1e-4, rtol=1e-4 ) # 1e-4 is the tol value used in pytorch/torch/_dynamo/utils.py +<<<<<<< HEAD if self.device == "mps": self.code_check_count( model, example_inputs, '.getKernelFunction("generated_kernel")', 1 ) elif self.device == GPU_TYPE: +======= + if self.device == GPU_TYPE: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.code_check_count( model, example_inputs, "triton_poi_fused_sin_0 = loadKernel(", 1 ) @@ -3177,9 +3282,16 @@ def forward(self, x): # Call eval() here so that batch_norm won't update the running stats # Use float64 to avoid numeric difference failure +<<<<<<< HEAD dtype = torch.float32 if self.device == "mps" else torch.float64 model = Model().to(device=self.device, dtype=dtype).eval() example_inputs = (torch.randn(4, 3, 64, 64, device=self.device, dtype=dtype),) +======= + model = Model().to(device=self.device, dtype=torch.float64).eval() + example_inputs = ( + torch.randn(4, 3, 64, 64, device=self.device, dtype=torch.float64), + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.check_model(model, example_inputs) def test_triton_next_power_of_2(self): @@ -3240,7 +3352,10 @@ def forward(self, a, b, ranks): torch._dynamo.mark_dynamic(example_inputs[1], 0) self.check_model(Model(), example_inputs) +<<<<<<< HEAD @skipIfMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @common_utils.parametrize("grid_type", [1, 2, 3]) @common_utils.parametrize("num_dims", [1, 2]) @common_utils.parametrize("dynamic", [False, True]) @@ -4232,6 +4347,10 @@ def forward(self, x, y): self.check_model(Model(), example_inputs) +<<<<<<< HEAD +======= + # @skipIfXpu(msg="torch.xpu.memory_allocated not supported yet") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_triton_kernel_reinterpret_view_mem_leak(self): # Check for memory leak when using user-defined Triton Kernel + AOTI. if self.device != GPU_TYPE: @@ -4271,7 +4390,10 @@ def forward(self, x, y): expected = Model()(*example_inputs) torch.testing.assert_close(actual, expected) +<<<<<<< HEAD @skipIfMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._dynamo.config.patch(capture_scalar_outputs=True) @common_utils.parametrize("dynamic", [False, True]) @common_utils.parametrize("autotuning", [False, True]) @@ -4363,7 +4485,11 @@ def test_aoti_runtime_asserts(self): def foo(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: return a[: b.item()] +<<<<<<< HEAD @torch.library.register_fake("mylib::foo", lib=lib) +======= + @torch.library.impl_abstract("mylib::foo", lib=lib) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def foo_fake_impl(a, b): ctx = torch.library.get_ctx() u = ctx.new_dynamic_size() @@ -4427,7 +4553,10 @@ def forward(self, x): with self.assertRaisesRegex(Exception, "run_func_(.*) API call failed "): optimized(*input2) +<<<<<<< HEAD @skipIfWindows(msg="TODO: (xuhancn) confirm, Crash: access violation") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_index_put_with_none_index(self): # index_put falls back in the deterministic mode with DeterministicGuard(True): @@ -4452,6 +4581,7 @@ def forward(self, x, i1, i2, y): @patch.dict(os.environ, {"AOTI_RUNTIME_CHECK_INPUTS": "1"}) def test_runtime_checks(self): class Model(torch.nn.Module): +<<<<<<< HEAD def forward(self, inputs): return list(inputs.values()) @@ -4459,6 +4589,26 @@ def forward(self, inputs): dtypes = [ torch.float16, torch.float32, +======= + def __init__(self) -> None: + super().__init__() + + if SM80OrLater: + + def forward(self, x0, x1, x2, x3, x4, x5, x6, x7, x8, x9): + return (x0, x1, x2, x3, x4, x5, x6, x7, x8, x9) + + else: + + def forward(self, x0, x1, x2, x4, x5, x6, x7, x8, x9): + return (x0, x1, x2, x4, x5, x6, x7, x8, x9) + + inputs = [] + dtypes = [ + torch.float16, + torch.float32, + torch.float64, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.bool, torch.int8, torch.int16, @@ -4466,6 +4616,7 @@ def forward(self, inputs): torch.int64, torch.uint8, ] +<<<<<<< HEAD if not TEST_MPS: dtypes.append(torch.float64) @@ -4476,11 +4627,18 @@ def forward(self, inputs): inputs[f"x_{str(dtype)}"] = torch.ones( 4, 8, 10, dtype=dtype, device=self.device ) +======= + if SM80OrLater: + dtypes.append(torch.bfloat16) + for dtype in dtypes: + inputs.append(torch.ones(4, 8, 10, dtype=dtype, device=self.device)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dim0 = Dim("s0", min=2, max=1024) dim1 = Dim("s1", min=2, max=512) dim2 = Dim("s2", min=2, max=128) dynamic_shapes = { +<<<<<<< HEAD "x_torch.float16": {0: dim0}, "x_torch.float32": {0: dim0}, "x_torch.bool": {1: dim1}, @@ -4498,10 +4656,28 @@ def forward(self, inputs): m = Model() inputs = (inputs,) dynamic_shapes = (dynamic_shapes,) +======= + "x0": {0: dim0}, + "x1": {0: dim0}, + "x2": {0: dim0}, + "x4": {1: dim1}, + "x5": {1: dim1}, + "x6": {}, + "x7": {2: dim2}, + "x8": {2: dim2}, + "x9": {2: dim2}, + } + if SM80OrLater: + dynamic_shapes["x3"] = {1: dim1} + + m = Model() + inputs = tuple(inputs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with torch.no_grad(): so_path = AOTIRunnerUtil.legacy_compile( m, inputs, dynamic_shapes=dynamic_shapes ) +<<<<<<< HEAD # Expected results for the following checks: # ("unmatched dtype", "unmatched dim value at", "dim value is too", "unmatched stride value at") @@ -4515,26 +4691,50 @@ def forward(self, inputs): # 9 dynamic dims expected_results = (9, 19, 16, 19) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with open(os.path.splitext(so_path)[0] + ".cpp") as cpp: src_code = cpp.read() FileCheck().check_count( "unmatched dtype", +<<<<<<< HEAD expected_results[0], +======= + 10 if SM80OrLater else 9, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) exactly=True, ).run(src_code) FileCheck().check_count( "unmatched dim value at", +<<<<<<< HEAD expected_results[1], +======= + 21 + if SM80OrLater + else 19, # we have 9 dynamic dims for which we generate different checks +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) exactly=True, ).run(src_code) FileCheck().check_count( "dim value is too", +<<<<<<< HEAD expected_results[2], +======= + 18 + if SM80OrLater + else 16, # we have 9 dynamic dims for which we generate two checks +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) exactly=True, ).run(src_code) FileCheck().check_count( "unmatched stride value at", +<<<<<<< HEAD expected_results[3], +======= + 21 + if SM80OrLater + else 19, # we have 9 symbolic strides for which we don't generate checks +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) exactly=True, ).run(src_code) @@ -4708,7 +4908,11 @@ def forward(self, x): self.assertTrue(result[0].data_ptr() != result[1].data_ptr()) def test_multiple_output_alias(self): +<<<<<<< HEAD # Test when multiple outputs alias the same tensor +======= + # Test when mutliple outputs alias the same tensor +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class Model(torch.nn.Module): def forward(self, x): squared = x * x @@ -4798,10 +5002,13 @@ def forward(self, w, i, o): ) self.check_model(Model(), example_inputs) +<<<<<<< HEAD @unittest.skipIf( TEST_MPS and MACOS_VERSION < 14.0, "FFT operations are only supported on MacOS 14+", ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_fft_c2c(self): class Model(torch.nn.Module): def forward(self, x): @@ -4886,10 +5093,14 @@ def forward(self, values, offsets): ) self.assertTrue(same(model(*example_input), actual)) +<<<<<<< HEAD # Temporarily skipping test as pytorch/cpuinfo not able to retrieve cache size for # AMD EPYC 9575F 64-Core Processor CPU in gfx942 VM Runners @common_utils.parametrize("max_autotune", [True, False]) @skipIfRocmArch(MI300_ARCH) +======= + @common_utils.parametrize("max_autotune", [True, False]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_misc_1(self, max_autotune): if self.device == "cpu" and IS_MACOS and max_autotune: raise unittest.SkipTest("max_autotune not supported on macos") @@ -4971,6 +5182,7 @@ def forward(self, a): a = torch.randn(batch, M, K, device=self.device) example_inputs = (a,) +<<<<<<< HEAD if self.device == "mps": kernel_calls = [("aoti_torch_mps_addmm_out", 2)] elif self.device == GPU_TYPE: @@ -4980,6 +5192,18 @@ def forward(self, a): ] else: kernel_calls = [("aoti_torch_cpu_addmm_out", 2)] +======= + kernel_calls = ( + [ + ("triton_poi_fused_0", 1), + (f"aoti_torch_{GPU_TYPE}_addmm_out", 2), + ] + if self.device == GPU_TYPE + else [ + ("aoti_torch_cpu_addmm_out", 2), + ] + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # test default debug printing all tensor values codegen with config.patch({"aot_inductor.debug_intermediate_value_printer": "2"}): @@ -5068,6 +5292,7 @@ def forward(self, a): _, code = run_and_get_cpp_code( AOTIRunnerUtil.compile, model, example_inputs ) +<<<<<<< HEAD shim_fn_codes = f'RAIIAtenRecordFunctionHandle .*\\("{kernel_calls}"' if enable_kernel_profile: FileCheck().check_regex(shim_fn_codes).run(code) @@ -5075,6 +5300,15 @@ def forward(self, a): FileCheck().check_not("RAIIAtenRecordFunctionHandle").run(code) self.check_model(Model(N, K, self.device), example_inputs) +======= + shim_fn_codes = ( + f'RECORD_FUNCTION("{kernel_calls}", c10::ArrayRef());' + ) + if enable_kernel_profile: + FileCheck().check(shim_fn_codes).run(code) + else: + FileCheck().check_not(shim_fn_codes).run(code) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_aoti_debug_printer_user_defined_triton_kernel(self): if self.device != GPU_TYPE: @@ -5176,7 +5410,11 @@ def forward(self, x): expected_scalar_args = [ "triton_poi_fused_zeros_like_0_xnumel", +<<<<<<< HEAD "triton_poi_fused_ones_1_xnumel", +======= + "triton_poi_fused_1_xnumel", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "std::max(static_cast(512L), static_cast(u0))", ] @@ -5272,9 +5510,15 @@ def forward(self, a, b, c): return z example_inputs = ( +<<<<<<< HEAD torch.randn(10, 20, device=GPU_TYPE), torch.randn(20, 30, device=GPU_TYPE), torch.randn(10, 30, device=GPU_TYPE), +======= + torch.randn(10, 20, device="cuda"), + torch.randn(20, 30, device="cuda"), + torch.randn(10, 30, device="cuda"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) model = Model() kernel_calls = [ @@ -5490,6 +5734,7 @@ def sin_triton(x, out): self.check_model(sin_triton, none_inputs) self.check_model(sin_triton, not_none_inputs) +<<<<<<< HEAD @skipIfRocm # RoCM does not support the config block size in test suite. def test_autotune_int64_user_defined_triton_kernel(self): if self.device != GPU_TYPE: @@ -5555,6 +5800,8 @@ def forward(self, x): @skipIfWindows( msg="OpenMP crashed application on windows" ) # TODO: (xuhancn) need to root cause and fix. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_issue_140766(self): class Model(torch.nn.Module): def __init__(self): @@ -5595,6 +5842,7 @@ def forward_block(self, x): example_inputs = (torch.randn(2, 128, 4096, device=self.device),) self.check_model(Model(), example_inputs, dynamic_shapes={"x": {0: bs}}) +<<<<<<< HEAD @requires_gpu def test_d2h_copy(self): # device to copy host should always have the same stride @@ -5632,6 +5880,8 @@ def forward(self, x): all_ops = [event.key for event in prof.key_averages()] self.assertTrue(not any("aten::contiguous" in op for op in all_ops)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_so_without_weight(self): class Model(torch.nn.Module): def __init__(self, n, k, device): @@ -5739,7 +5989,11 @@ def forward(self, a): example_inputs=example_inputs, ) +<<<<<<< HEAD with WritableTempFile(suffix=".pt2") as f: +======= + with tempfile.NamedTemporaryFile(suffix=".pt2") as f: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) package_path = package_aoti( f.name, {"model": aoti_files}, @@ -5872,6 +6126,7 @@ def runner_call(*args, **kwargs): ) self.assertEqual(new_expected, new_output) +<<<<<<< HEAD def test_update_constant_buffer_simple(self): class Model(torch.nn.Module): def __init__(self, device): @@ -5919,6 +6174,8 @@ def runner_call(*args, **kwargs): output = runner_call(test_inputs) self.assertEqual(expected, output) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_update_inactive_constant_buffer(self): class Model(torch.nn.Module): def __init__(self, n, k, device): @@ -6151,6 +6408,7 @@ def runner_call(*args, **kwargs): ) self.assertEqual(new_expected, new_output) +<<<<<<< HEAD new_weights = { "L__self___weight": torch.randn(N, K, device=self.device), "L__self___bias": torch.randn(N, device=self.device), @@ -6176,6 +6434,8 @@ def runner_call(*args, **kwargs): with self.assertRaises(AssertionError): torch.testing.assert_close(new_expected, new_output) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_cond_share_predicte(self): class Model(torch.nn.Module): def forward(self, predicate, x): @@ -6289,17 +6549,24 @@ def forward(self, x, y): ) @unittest.skipIf(IS_FBCODE, "Not runnable in fbcode") +<<<<<<< HEAD @unittest.skipIf( TEST_MPS and MACOS_VERSION < 14.0, "FFT operations are only supported on MacOS 14+", ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_stft(self): N_FFT = 400 HOP_LENGTH = 160 class Model(torch.nn.Module): def forward(self, x): +<<<<<<< HEAD window = torch.hann_window(N_FFT, device=x.device) +======= + window = torch.hann_window(N_FFT).to(x.device) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) stft = torch.stft( x, N_FFT, HOP_LENGTH, window=window, return_complex=True ) @@ -6310,7 +6577,10 @@ def forward(self, x): example_inputs = (torch.randn(500, device=self.device),) self.check_model(model, example_inputs) +<<<<<<< HEAD @skipIfXpu +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_conv3d(self): if self.device != GPU_TYPE or not is_big_gpu(): raise unittest.SkipTest("requires modern GPU to run max-autotune") @@ -6367,6 +6637,12 @@ def forward( dynamic_shapes=dynamic_shapes, ) +<<<<<<< HEAD +======= + @skipIfXpu( + msg="The operator 'aten::_int_mm' is not currently implemented for the XPU device" + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test__int_mm(self): class Model(torch.nn.Module): def __init__(self) -> None: @@ -6381,7 +6657,10 @@ def forward(self, x, y): ) self.check_model(Model(), example_inputs) +<<<<<<< HEAD @skipIfMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skipIfXpu( msg="aten::convert_weight_to_int4pack is not currently implemented for XPU" ) @@ -6516,6 +6795,7 @@ def forward(self, x): rtol=1e-3, ) +<<<<<<< HEAD @runOnRocm def test_rocm_triton_autotuning(self): if self.device != GPU_TYPE: @@ -6558,6 +6838,8 @@ def forward(self, x, y, m): ): torch._export.aot_compile(Model(), (x, y, m)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skipIfRocm # RoCM does not support the config block size in test suite. def test_triton_autotuning(self): if self.device != GPU_TYPE: @@ -6749,6 +7031,7 @@ def forward(self, x): } self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes) +<<<<<<< HEAD def test_boolean_indexing(self): class Model(torch.nn.Module): def forward(self, x, y, z, x1, z1): @@ -6842,6 +7125,8 @@ def forward( ) self.check_model(m, example_inputs) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_with_cudagraphs(self): if self.device != "cuda": raise unittest.SkipTest("requires CUDA") @@ -6914,6 +7199,7 @@ def wrapped(**kwargs): # compare against eager self.assertEqual(optimized(**model_kwargs), model(**model_kwargs)) +<<<<<<< HEAD def test_custom_op_in_subgraph(self): with torch.library._scoped_library("mylib", "FRAGMENT") as lib: torch.library.define( @@ -6957,6 +7243,8 @@ def forward(self, x): M(), list_example_inputs, dynamic_shapes=({0: Dim.DYNAMIC},) ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_clamp_decomposition(self): class Model1(torch.nn.Module): def forward(self, x): @@ -6973,6 +7261,7 @@ def forward(self, x): # the output should have int type self.check_model(Model2(), (x,)) +<<<<<<< HEAD def test_upper_bound_i64(self): class Model(torch.nn.Module): def forward(self, x, y): @@ -7001,6 +7290,8 @@ def forward(self, x, y): # this test is mostly checking to ensure there's no IMA. m(*inp) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_using_model_name_for_files(self): class Model(torch.nn.Module): def __init__(self) -> None: @@ -7027,14 +7318,21 @@ def forward(self, x, y): with zipfile.ZipFile(package_path, "r") as zip_ref: all_files = zip_ref.namelist() base_dir = "test_model.wrapper/data/aotinductor/model/test_model" +<<<<<<< HEAD ext_type = get_module_ext_type() self.assertTrue(f"{base_dir}.wrapper.cpp" in all_files) self.assertTrue(f"{base_dir}.kernel.cpp" in all_files) self.assertTrue(f"{base_dir}.wrapper.{ext_type}" in all_files) +======= + self.assertTrue(f"{base_dir}.wrapper.cpp" in all_files) + self.assertTrue(f"{base_dir}.kernel.cpp" in all_files) + self.assertTrue(f"{base_dir}.wrapper.so" in all_files) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) aot_inductor_module = torch._inductor.aoti_load_package(package_path) self.assertEqual(aot_inductor_module(*example_inputs), model(*example_inputs)) +<<<<<<< HEAD def test_copy_non_blocking_is_pinned(self): if self.device == "cpu" or self.device == "mps": raise unittest.SkipTest("only matters for device-to-cpu copy") @@ -7093,6 +7391,8 @@ def forward(self, x): "RAIIAtenTensorHandle buf0(buf0_handle_restrided);" ).run(code) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class AOTInductorLoggingTest(LoggingTestCase): @make_logging_test(dynamic=logging.DEBUG) @@ -7111,6 +7411,7 @@ def forward(self, x): torch._inductor.aot_compile(ep.module(), inputs) self.assertEqual([r.msg == "create_env" for r in records].count(True), 1) +<<<<<<< HEAD @make_logging_test(dynamic=logging.DEBUG) def test_shape_env_reuse_zero_consts_use_consts_asm_false(self, records): # make sure ShapeEnv is only created once and reused afterwards @@ -7175,6 +7476,8 @@ def test_compile_standalone_package_cpp_false_raises(self): with self.assertRaises(RuntimeError): maybe_aoti_standalone_config(patches) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate) @@ -7186,6 +7489,7 @@ def fail_cpu(is_skip=False): ) +<<<<<<< HEAD def fail_mps(is_skip=False): return TestFailure( ("mps",), @@ -7193,6 +7497,8 @@ def fail_mps(is_skip=False): ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def fail_gpu(suffixes: tuple[str, ...], is_skip=False): return TestFailure( suffixes, @@ -7211,13 +7517,17 @@ def fail_gpu(suffixes: tuple[str, ...], is_skip=False): # quantized unsupported for GPU "test_quantized_linear": fail_gpu(("cuda", "xpu")), "test_quanatized_int8_linear": fail_gpu(("cuda", "xpu")), +<<<<<<< HEAD "test_quantized_linear_bias_none": fail_gpu(("cuda", "xpu")), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # No scaled_dot_product_efficient_attention implementation for XPU yet. "test_scaled_dot_product_efficient_attention": fail_gpu(("xpu",)), # No fft implementation for XPU yet. "test_fft_c2c": fail_gpu(("xpu",), is_skip=True), } +<<<<<<< HEAD MPS_TEST_FAILURES = { # aten::_embedding_bag is not currently implemented for the MPS device. "test_embedding_bag": fail_mps(), @@ -7290,6 +7600,8 @@ def fail_gpu(suffixes: tuple[str, ...], is_skip=False): "test_autotune_int64_user_defined_triton_kernel": fail_mps(), } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class AOTInductorTestABICompatibleCpu(TestCase): device = "cpu" @@ -7327,6 +7639,7 @@ class AOTInductorTestABICompatibleGpu(TestCase): GPU_TEST_FAILURES, ) +<<<<<<< HEAD @unittest.skipIf(not torch.backends.mps.is_available(), "No MPS backend available") class AOTInductorTestABICompatibleMps(TestCase): @@ -7347,6 +7660,8 @@ class AOTInductorTestABICompatibleMps(TestCase): ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_aot_inductor_arrayref.py b/test/inductor/test_aot_inductor_arrayref.py index 492ad9c23c5c7..8ef48686dc258 100644 --- a/test/inductor/test_aot_inductor_arrayref.py +++ b/test/inductor/test_aot_inductor_arrayref.py @@ -70,7 +70,10 @@ def fail_minimal_arrayref_interface(is_skip=False): "test_cond_with_multiple_outputs": fail_minimal_arrayref_interface(), "test_cond_with_parameters": fail_minimal_arrayref_interface(), "test_cond_with_reinterpret_view_inputs_outputs": fail_minimal_arrayref_interface(), +<<<<<<< HEAD "test_custom_op_in_subgraph": fail_minimal_arrayref_interface(), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "test_cond_share_predicte": fail_stack_allocation(is_skip=True), "test_cond_unbacked_symint_closure_dynamic_True": fail_minimal_arrayref_interface(), "test_while_loop_with_unbacked_symint_closure_dynamic_True": fail_minimal_arrayref_interface(), @@ -85,7 +88,10 @@ def fail_minimal_arrayref_interface(is_skip=False): "test_while_loop_with_pytree_inputs": fail_stack_allocation(), # FIXME: failed with Segfault while exiting the Python runtime "test_duplicate_constant_folding": fail_stack_allocation(is_skip=True), +<<<<<<< HEAD "test_aot_inductor_consts_cpp_build": fail_stack_allocation(is_skip=True), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "test_stride_with_unbacked_expr": fail_minimal_arrayref_interface(is_skip=True), # TODO: use of deleted function RAIIAtenTensorHandle "test_dup_unbacked_sym_decl": fail_minimal_arrayref_interface(is_skip=True), diff --git a/test/inductor/test_aot_inductor_custom_ops.py b/test/inductor/test_aot_inductor_custom_ops.py index 0b4f508477ac4..93d71a3e828a2 100644 --- a/test/inductor/test_aot_inductor_custom_ops.py +++ b/test/inductor/test_aot_inductor_custom_ops.py @@ -1,5 +1,9 @@ # Owner(s): ["module: inductor"] +<<<<<<< HEAD # This test requires libaoti_custom_ops.so to be built, which happens when BUILD_TEST = 1 +======= +# This test requires libaoti_custom_ops.so to be built, which happnes when BUILD_TEST = 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import logging import os import sys @@ -24,7 +28,11 @@ skipIfXpu, ) from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test +<<<<<<< HEAD from torch.testing._internal.triton_utils import HAS_CUDA_AND_TRITON +======= +from torch.testing._internal.triton_utils import HAS_CUDA +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.utils._python_dispatch import TorchDispatchMode @@ -416,7 +424,10 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): @skipIfXpu @skipIfRocm +<<<<<<< HEAD @unittest.skipIf(IS_FBCODE, "unable to find library -laoti_custom_ops") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_custom_op_square(self) -> None: class Model(torch.nn.Module): def forward(self, x): @@ -512,7 +523,10 @@ def fail_cuda(is_skip=False): # quantized unsupported for GPU "test_quantized_linear": fail_cuda(), "test_quanatized_int8_linear": fail_cuda(), +<<<<<<< HEAD "test_quantized_linear_bias_none": fail_cuda(), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } @@ -556,5 +570,9 @@ class AOTInductorTestABICompatibleCuda(AOTICustomOpTestCase): from torch._inductor.test_case import run_tests # cpp_extension N/A in fbcode +<<<<<<< HEAD if HAS_CUDA_AND_TRITON or sys.platform == "darwin": +======= + if HAS_CUDA or sys.platform == "darwin": +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) run_tests(needs="filelock") diff --git a/test/inductor/test_aot_inductor_package.py b/test/inductor/test_aot_inductor_package.py index 0eb1057c802eb..c3561a5d57a64 100644 --- a/test/inductor/test_aot_inductor_package.py +++ b/test/inductor/test_aot_inductor_package.py @@ -15,6 +15,7 @@ from parameterized import parameterized_class import torch +<<<<<<< HEAD import torch._inductor.config from torch._inductor.codecache import get_kernel_bin_format from torch._inductor.package import load_package, package_aoti @@ -28,6 +29,14 @@ load_weights_to_pt2_contents, ) from torch.testing._internal.common_cuda import _get_torch_cuda_version +======= +from torch._inductor.codecache import get_kernel_bin_format +from torch._inductor.package import AOTICompiledModel, load_package, package_aoti +from torch._inductor.test_case import TestCase +from torch._inductor.utils import fresh_cache +from torch.export import Dim +from torch.export.pt2_archive._package import load_pt2, load_weights_to_pt2_contents +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_utils import ( IS_FBCODE, skipIfRocm, @@ -133,6 +142,7 @@ def check_model( self.assertEqual(actual, expected, atol=atol, rtol=rtol) return compiled_model +<<<<<<< HEAD def check_package_cpp_only(self: TestCase) -> None: """ Check if cmake and make are available. @@ -210,6 +220,8 @@ def cmake_compile(self, model, example_inputs, options, tmp_dir): subprocess.run(["make"], cwd=build_path, check=True) return build_path, tmp_path +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_add(self): class Model(torch.nn.Module): def forward(self, x, y): @@ -224,7 +236,11 @@ def forward(self, x, y): def test_remove_intermediate_files(self): # For CUDA, generated cpp files contain absolute path to the generated cubin files. # With the package artifact, that cubin path should be overridden at the run time, +<<<<<<< HEAD # so removing those intermediate files in this test to verify that. +======= + # so removing those intermeidate files in this test to verify that. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class Model(torch.nn.Module): def forward(self, x, y): return x + y @@ -271,12 +287,23 @@ def forward(self, x, y): self.check_model(Model(), example_inputs) @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode") +<<<<<<< HEAD @unittest.skipIf( _get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+" ) @skipIfXpu # build system may be different def test_compile_after_package(self): self.check_package_cpp_only() +======= + @skipIfXpu # build system may be different + def test_compile_after_package(self): + if not self.package_cpp_only: + raise unittest.SkipTest("Only meant to test cpp package") + if shutil.which("cmake") is None: + raise unittest.SkipTest("cmake is not available") + if shutil.which("make") is None: + raise unittest.SkipTest("make is not available") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class Model(torch.nn.Module): def __init__(self) -> None: @@ -299,6 +326,7 @@ def forward(self, x, y): # Require kernels to be compiled into .o files "aot_inductor.embed_kernel_binary": True, } +<<<<<<< HEAD with ( tempfile.TemporaryDirectory() as tmp_dir, ): @@ -306,12 +334,44 @@ def forward(self, x, y): model, example_inputs, options, tmp_dir ) +======= + ep = torch.export.export(model, example_inputs, strict=True) + package_path = torch._inductor.aoti_compile_and_package( + ep, inductor_configs=options + ) + with ( + tempfile.TemporaryDirectory() as tmp_dir, + zipfile.ZipFile(package_path, "r") as zip_ref, + ): + filenames = zip_ref.namelist() + prefix = filenames[0].split("/")[0] + zip_ref.extractall(tmp_dir) + tmp_path = Path(tmp_dir) / prefix / "data" / "aotinductor" / "model" + self.assertTrue(tmp_path.exists()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.device == GPU_TYPE: kernel_bin = get_kernel_bin_format(self.device) self.assertTrue(not list(tmp_path.glob(f"*.{kernel_bin}"))) # Check if .cubin.o files exist and use unique kernel names self.assertTrue(list(tmp_path.glob(f"triton_*.{kernel_bin}.o"))) +<<<<<<< HEAD +======= + build_path = tmp_path / "build" + self.assertTrue(not build_path.exists()) + + # Create a build directory to run cmake + build_path.mkdir() + custom_env = os.environ.copy() + custom_env["CMAKE_PREFIX_PATH"] = str(Path(torch.__file__).parent) + subprocess.run( + ["cmake", ".."], + cwd=build_path, + env=custom_env, + ) + subprocess.run(["make"], cwd=build_path) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Check if the .so file was build successfully so_path = build_path / "libaoti_model.so" self.assertTrue(so_path.exists()) @@ -319,16 +379,28 @@ def forward(self, x, y): actual = optimized(*example_inputs) self.assertTrue(torch.allclose(actual, expected)) +<<<<<<< HEAD @unittest.skipIf( _get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+" ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode") @skipIfRocm # doesn't support multi-arch binary @skipIfXpu # doesn't support multi-arch binary def test_compile_after_package_multi_arch(self): if self.device != GPU_TYPE: raise unittest.SkipTest("Only meant to test GPU_TYPE") +<<<<<<< HEAD self.check_package_cpp_only() +======= + if not self.package_cpp_only: + raise unittest.SkipTest("Only meant to test cpp package") + if shutil.which("cmake") is None: + raise unittest.SkipTest("cmake is not available") + if shutil.which("make") is None: + raise unittest.SkipTest("make is not available") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class Model(torch.nn.Module): def __init__(self) -> None: @@ -348,17 +420,49 @@ def forward(self, x, y): options = { "aot_inductor.package_cpp_only": self.package_cpp_only, +<<<<<<< HEAD # Expect kernel to be embedded in the final binary. +======= + # Expect kernel to be embeded in the final binary. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # We will make it the default behavior for the standalone mode. "aot_inductor.emit_multi_arch_kernel": True, "aot_inductor.embed_kernel_binary": True, } +<<<<<<< HEAD with ( tempfile.TemporaryDirectory() as tmp_dir, ): build_path, _ = self.cmake_compile( model, example_inputs, options, tmp_dir ) +======= + ep = torch.export.export(model, example_inputs) + package_path = torch._inductor.aoti_compile_and_package( + ep, inductor_configs=options + ) + with ( + tempfile.TemporaryDirectory() as tmp_dir, + zipfile.ZipFile(package_path, "r") as zip_ref, + ): + filenames = zip_ref.namelist() + prefix = filenames[0].split("/")[0] + zip_ref.extractall(tmp_dir) + tmp_path = Path(tmp_dir) / prefix / "data" / "aotinductor" / "model" + self.assertTrue(tmp_path.exists()) + # Create a build directory to run cmake + build_path = tmp_path / "build" + build_path.mkdir() + custom_env = os.environ.copy() + custom_env["CMAKE_PREFIX_PATH"] = str(Path(torch.__file__).parent) + subprocess.run( + ["cmake", ".."], + cwd=build_path, + env=custom_env, + ) + subprocess.run(["make"], cwd=build_path) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Check if the .so file was build successfully so_path = build_path / "libaoti_model.so" self.assertTrue(so_path.exists()) @@ -366,6 +470,7 @@ def forward(self, x, y): actual = optimized(*example_inputs) self.assertTrue(torch.allclose(actual, expected)) +<<<<<<< HEAD @unittest.skipIf( _get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+" ) @@ -565,6 +670,8 @@ def default(*args, **kwargs): true_res = next(iter(tensor_model.parameters())) self.assertEqual(expected_res, true_res) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_metadata(self): class Model(torch.nn.Module): def __init__(self) -> None: @@ -1024,6 +1131,7 @@ def forward(self, x): loaded1 = pt2_contents.aoti_runners["model"] self.assertEqual(loaded1(x), bar1(x)) +<<<<<<< HEAD def test_loading_wrong_model(self): class Model(torch.nn.Module): def forward(self, x): @@ -1040,6 +1148,8 @@ def forward(self, x): ): load_package(package_path, model_name="forward") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_aot_inductor_utils.py b/test/inductor/test_aot_inductor_utils.py index 50edf7b695ad8..59aab6e0a1393 100644 --- a/test/inductor/test_aot_inductor_utils.py +++ b/test/inductor/test_aot_inductor_utils.py @@ -102,8 +102,11 @@ def legacy_load_runner(device, so_path: str) -> "AOTIModelContainerRunner": return torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1) elif device == "xpu": return torch._C._aoti.AOTIModelContainerRunnerXpu(so_path, 1, device) +<<<<<<< HEAD elif device == "mps": return torch._C._aoti.AOTIModelContainerRunnerMps(so_path, 1) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: return torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device) @@ -148,7 +151,11 @@ def legacy_run( @staticmethod def compile( model: Union[torch.nn.Module, types.FunctionType], +<<<<<<< HEAD example_inputs: tuple[torch.Tensor, ...], +======= + example_inputs: list[torch.Tensor], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inductor_configs: Optional[dict[str, Any]] = None, dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None, ): @@ -159,11 +166,15 @@ def compile( with torch.no_grad(): # strict=False needs extra migration work ep = torch.export.export( +<<<<<<< HEAD model, example_inputs, dynamic_shapes=dynamic_shapes, strict=True, prefer_deferred_runtime_asserts_over_guards=True, +======= + model, example_inputs, dynamic_shapes=dynamic_shapes, strict=True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) package_path = torch._inductor.aoti_compile_and_package( ep, inductor_configs=inductor_configs @@ -173,7 +184,11 @@ def compile( @staticmethod def run( model: Union[torch.nn.Module, types.FunctionType], +<<<<<<< HEAD example_inputs: tuple[torch.Tensor, ...], +======= + example_inputs: list[torch.Tensor], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inductor_configs: Optional[dict[str, Any]] = None, dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None, ): @@ -189,7 +204,11 @@ def run( @staticmethod def run_multiple( model: Union[torch.nn.Module, types.FunctionType], +<<<<<<< HEAD list_example_inputs: list[tuple[torch.Tensor, ...]], +======= + list_example_inputs: list[list[torch.Tensor]], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inductor_configs: Optional[dict[str, Any]] = None, dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None, ): @@ -228,7 +247,11 @@ def check_model( if not isinstance(model, types.FunctionType): model = model.to(self.device) +<<<<<<< HEAD # For non mixed device inputs with default "cpu",set the device manually. +======= + # For non mixed device inputs with default "cpu",set the device manully. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if all( t.device.type == "cpu" for t in example_inputs diff --git a/test/inductor/test_async_compile.py b/test/inductor/test_async_compile.py index cc94c4c95e01a..5ec5ee0c66b1b 100644 --- a/test/inductor/test_async_compile.py +++ b/test/inductor/test_async_compile.py @@ -1,4 +1,5 @@ # Owner(s): ["module: inductor"] +<<<<<<< HEAD from unittest.mock import patch import torch @@ -9,6 +10,11 @@ from torch._inductor.runtime.triton_heuristics import ( generate_lookup_hash_from_source_code, ) +======= +import torch +from torch._inductor import config +from torch._inductor.async_compile import AsyncCompile, shutdown_compile_workers +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.test_case import run_tests, TestCase from torch._inductor.utils import fresh_cache from torch.testing._internal.common_utils import ( @@ -36,12 +42,18 @@ def fn(x, y): with config.patch("worker_start_method", method): shutdown_compile_workers() +<<<<<<< HEAD AsyncCompile.wait_pool_ready() +======= + pool = AsyncCompile.process_pool() + pool.ready_future.result(timeout=120) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with fresh_cache(): compiled_fn = torch.compile(fn) self.assertEqual(fn(x, y), compiled_fn(x, y)) +<<<<<<< HEAD @requires_gpu() @requires_triton() def test_bad_kernel(self): @@ -153,6 +165,8 @@ def triton_fused_fake_name(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.cons self.assertEqual(args[1].num_warps, autotune_config.num_warps) self.assertEqual(args[1].num_stages, autotune_config.num_stages) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/inductor/test_auto_functionalize.py b/test/inductor/test_auto_functionalize.py index 6025c90cdb4a2..c3402cf643b0d 100644 --- a/test/inductor/test_auto_functionalize.py +++ b/test/inductor/test_auto_functionalize.py @@ -185,6 +185,7 @@ def f(x, y, z, n): post_grad_graphs, """\ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): +<<<<<<< HEAD foo_default = torch.ops.mylib.foo.default(arg0_1, [arg3_1, arg4_1], arg1_1, 2, arg2_1); arg0_1 = arg3_1 = arg4_1 = arg1_1 = arg2_1 = foo_default = None return ()""", # noqa: B950 ignore_comments=True, @@ -193,6 +194,11 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3 # stack trace should be in post_grad_graph self.assertTrue( "code: torch.ops.mylib.foo(x, y, z, 2, n)" in post_grad_graphs, +======= + # No stacktrace found for following nodes + foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg4_1 = arg1_1 = arg0_1 = foo_default = None + return ()""", # noqa: B950 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) @@ -216,7 +222,11 @@ def foo_impl(x, y, z, w, n): z.add_(y[1] + n) return y[0] + w, y[1] + n +<<<<<<< HEAD @torch.library.register_fake("mylib::foo", lib=lib) +======= + @torch.library.impl_abstract("mylib::foo", lib=lib) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def foo_abstract(x, y, z, w, n): return y[0] + w, y[1] + n @@ -246,7 +256,11 @@ def f(x, y, z, n): post_grad_graphs, """\ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): +<<<<<<< HEAD foo_default = torch.ops.mylib.foo.default(arg0_1, [arg3_1, arg4_1], arg1_1, 2, arg2_1); arg0_1 = arg3_1 = arg4_1 = arg1_1 = arg2_1 = None +======= + foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg4_1 = arg1_1 = arg0_1 = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) getitem_4: "f32[3][1]cpu" = foo_default[0] getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None return (getitem_4, getitem_5)""", # noqa: B950 @@ -333,6 +347,7 @@ def f(x, y, z, n): post_grad_graphs, """\ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"): +<<<<<<< HEAD foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg0_1, 2, arg1_1); arg2_1 = \ arg3_1 = arg0_1 = arg1_1 = foo_default = None return ()""", @@ -342,6 +357,12 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3 # stack trace should be in post_grad_graph self.assertTrue( "code: torch.ops.mylib.foo(x, y, z, 2, n)" in post_grad_graphs, +======= + # No stacktrace found for following nodes + foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1); \ +arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None + return ()""", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) @@ -414,10 +435,17 @@ def f(x, y, z, n): self.assertExpectedInline( post_grad_graphs, """\ +<<<<<<< HEAD def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu", arg2_1: "f32[s77][1]cpu", arg3_1: "f32[s77][1]cpu", arg4_1: "f32[s77][1]cpu", arg5_1: "f32[s77][1]cpu"): foo_default = torch.ops.mylib.foo.default(arg1_1, [arg4_1, arg5_1], arg2_1, 2, arg3_1); arg4_1 = arg5_1 = arg3_1 = foo_default = None copy_: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None copy__1: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None +======= +def forward(self, arg0_1: "Sym(s72)", arg1_1: "f32[s72][1]cpu", arg2_1: "f32[s72][1]cpu", arg3_1: "f32[s72][1]cpu", arg4_1: "f32[s72][1]cpu", arg5_1: "f32[s72][1]cpu"): + foo_default = torch.ops.mylib.foo.default(arg3_1, [arg4_1, arg5_1], arg2_1, 2, arg1_1); arg4_1 = arg5_1 = arg1_1 = foo_default = None + copy_: "f32[s72][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy_ = None + copy__1: "f32[s72][1]cpu" = torch.ops.aten.copy_.default(arg3_1, arg3_1); arg3_1 = copy__1 = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ()""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, @@ -427,9 +455,15 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu", arg2_1: "f32[s77 post_grad_graphs, """\ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): +<<<<<<< HEAD foo_default = torch.ops.mylib.foo.default(arg0_1, [arg3_1, arg4_1], arg1_1, 2, arg2_1); arg3_1 = arg4_1 = arg2_1 = foo_default = None copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy__1 = None +======= + foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg3_1 = arg4_1 = arg0_1 = foo_default = None + copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None + copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ()""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, @@ -443,7 +477,11 @@ def run_aot_eager(self, f, orig_args, _dynamic=False): aot_eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) log_stream, ctx = logs_to_string( +<<<<<<< HEAD "torch._functorch._aot_autograd.graph_capture", "aot_graphs" +======= + "torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) result = None @@ -493,7 +531,11 @@ def foo_impl(x, y, z, w, n): z.add_(y[1] + n) return y[0] + w, y[1] + n +<<<<<<< HEAD @torch.library.register_fake("mylib::foo", lib=lib) +======= + @torch.library.impl_abstract("mylib::foo", lib=lib) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def foo_abstract(x, y, z, w, n): return y[0] + w, y[1] + n @@ -521,11 +563,19 @@ def f(x, y, z, n): post_grad_graphs, """\ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): +<<<<<<< HEAD foo_default = torch.ops.mylib.foo.default(arg0_1, [arg3_1, arg4_1], arg1_1, 2, arg2_1); arg3_1 = arg4_1 = arg2_1 = None getitem_4: "f32[3][1]cpu" = foo_default[0] getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy__1 = None +======= + foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg3_1 = arg4_1 = arg0_1 = None + getitem_4: "f32[3][1]cpu" = foo_default[0] + getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None + copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None + copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return (getitem_4, getitem_5)""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, @@ -579,6 +629,7 @@ def f(x, y): self.assertExpectedInline( graph_aot, """\ +<<<<<<< HEAD def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu", arg2_1: "f32[s77][1]cpu"): auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg1_1, arg2_1]) getitem_1: "f32[s77][1]cpu" = auto_functionalized_v2[1] @@ -586,6 +637,15 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu", arg2_1: "f32[s77 add: "f32[s77][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2) copy_: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = copy_ = None copy__1: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_2); arg2_1 = getitem_2 = copy__1 = None +======= +def forward(self, arg0_1: "Sym(s17)", arg1_1: "f32[s17][1]cpu", arg2_1: "f32[s17][1]cpu"): + auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg2_1, arg1_1]) + getitem_1: "f32[s17][1]cpu" = auto_functionalized_v2[1] + getitem_2: "f32[s17][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None + add: "f32[s17][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2) + copy_: "f32[s17][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy_ = None + copy__1: "f32[s17][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_1); arg2_1 = getitem_1 = copy__1 = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return (add,)""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, @@ -595,12 +655,21 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu", arg2_1: "f32[s77 graph_aot, """\ def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"): +<<<<<<< HEAD auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg0_1, arg1_1]) getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1] getitem_2: "f32[2][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None add: "f32[2][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2) copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = getitem_1 = copy_ = None copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy__1 = None +======= + auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg1_1, arg0_1]) + getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1] + getitem_2: "f32[2][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None + add: "f32[2][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2) + copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_2); arg0_1 = getitem_2 = copy_ = None + copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = copy__1 = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return (add,)""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, @@ -611,11 +680,19 @@ def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"): self.assertExpectedInline( graph_inductor, """\ +<<<<<<< HEAD def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu", arg2_1: "f32[s77][1]cpu"): foo_default = torch.ops.mylib.foo.default(arg1_1, arg2_1); foo_default = None add: "f32[s77][1]cpu" = torch.ops.aten.add.Tensor(arg1_1, arg2_1) copy_: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None copy__1: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None +======= +def forward(self, arg0_1: "Sym(s17)", arg1_1: "f32[s17][1]cpu", arg2_1: "f32[s17][1]cpu"): + foo_default = torch.ops.mylib.foo.default(arg2_1, arg1_1); foo_default = None + add: "f32[s17][1]cpu" = torch.ops.aten.add.Tensor(arg2_1, arg1_1) + copy_: "f32[s17][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None + copy__1: "f32[s17][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return (add,)""", ignore_comments=True, ignore_empty_lines=True, @@ -625,8 +702,13 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu", arg2_1: "f32[s77 graph_inductor, """\ def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"): +<<<<<<< HEAD foo_default = torch.ops.mylib.foo.default(arg0_1, arg1_1); foo_default = None add: "f32[2][1]cpu" = torch.ops.aten.add.Tensor(arg0_1, arg1_1) +======= + foo_default = torch.ops.mylib.foo.default(arg1_1, arg0_1); foo_default = None + add: "f32[2][1]cpu" = torch.ops.aten.add.Tensor(arg1_1, arg0_1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy__1 = None return (add,)""", @@ -841,11 +923,19 @@ def f(x, y): graph_aot, """\ def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"): +<<<<<<< HEAD auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg0_1, arg1_1]) getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1] getitem_2: "f32[2][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = getitem_1 = copy_ = None copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy__1 = None +======= + auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg1_1, arg0_1]) + getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1] + getitem_2: "f32[2][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None + copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_2); arg0_1 = getitem_2 = copy_ = None + copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = copy__1 = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ()""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, @@ -857,7 +947,11 @@ def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"): graph_inductor, """\ def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"): +<<<<<<< HEAD foo_default = torch.ops.mylib.foo.default(arg0_1, arg1_1); foo_default = None +======= + foo_default = torch.ops.mylib.foo.default(arg1_1, arg0_1); foo_default = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy__1 = None return ()""", # noqa: B950 @@ -977,8 +1071,13 @@ def f(x, y, z, n): post_grad_graphs, """\ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"): +<<<<<<< HEAD foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg0_1, 2, arg1_1); arg2_1 = arg3_1 = arg1_1 = foo_default = None copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None +======= + foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg0_1 = foo_default = None + copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ()""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, @@ -1414,7 +1513,11 @@ def test_round_trip(base, tensor): test_round_trip(t, f[1]) test_round_trip(t, f[2]) +<<<<<<< HEAD # example where slice won't work +======= + # example where slice wont work +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # selection t = torch.ones(10) @@ -1576,7 +1679,11 @@ def forward(self, arg0_1: "f32[2][1]cpu"): def test_alias2_dynamic(self): self.test_alias2(_dynamic=True) +<<<<<<< HEAD # Test that the view regeneration optimizations do not result in recompilations. By comparing re-compilation in eager backend +======= + # Test that the view regenration optimizations do not result in recompilations. By comparing re-compilation in eager backend +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # with recompilation in inductor backend. @torch.fx.experimental._config.patch(use_duck_shape=False) def test_recompile(self): diff --git a/test/inductor/test_benchmark_fusion.py b/test/inductor/test_benchmark_fusion.py index 56310adc977d3..2195f1e9eb058 100644 --- a/test/inductor/test_benchmark_fusion.py +++ b/test/inductor/test_benchmark_fusion.py @@ -13,7 +13,11 @@ from torch.testing._internal.inductor_utils import ( get_func_call, HAS_CPU, +<<<<<<< HEAD HAS_CUDA_AND_TRITON, +======= + HAS_CUDA, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) IS_BIG_GPU, ) @@ -197,7 +201,11 @@ def f(x): self.common(f, (x,)) +<<<<<<< HEAD if HAS_CUDA_AND_TRITON: +======= +if HAS_CUDA: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class BenchmarkFusionCudaTest(TestCase): common = check_model_cuda @@ -294,7 +302,11 @@ def test_equivalent_template_code(self): for out_code in [code, code2]: FileCheck().check(get_func_call()).check_count( "empty_strided", 1, exactly=True +<<<<<<< HEAD ).check("triton_tem_fused_addmm_relu_t_0").check_count( +======= + ).check("triton_tem_fused_addmm_relu_0").check_count( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ".reset()" if config.cpp_wrapper else "del", 3, exactly=True ).check("" if config.cpp_wrapper else "return").run(out_code[0]) @@ -347,5 +359,9 @@ class BenchmarkFusionCpuTest(TestCase): if __name__ == "__main__": from torch._inductor.test_case import run_tests +<<<<<<< HEAD if HAS_CPU or HAS_CUDA_AND_TRITON: +======= + if HAS_CPU or HAS_CUDA: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) run_tests() diff --git a/test/inductor/test_ck_backend.py b/test/inductor/test_ck_backend.py index f73a47e45a57a..9f03586984785 100644 --- a/test/inductor/test_ck_backend.py +++ b/test/inductor/test_ck_backend.py @@ -1,5 +1,8 @@ # Owner(s): ["module: inductor"] +<<<<<<< HEAD import functools +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import logging import os import unittest @@ -13,11 +16,15 @@ import torch from torch._inductor import config from torch._inductor.test_case import run_tests, TestCase +<<<<<<< HEAD from torch.testing._internal.common_cuda import tf32_off +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, ) +<<<<<<< HEAD from torch.testing._internal.inductor_utils import ( _quantize_rowwise, _quantize_tensorwise, @@ -27,12 +34,28 @@ if HAS_CUDA_AND_TRITON: +======= +from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA + + +try: + from .test_fp8 import _quantize_rowwise, _quantize_tensorwise +except ImportError: + from test_fp8 import _quantize_rowwise, _quantize_tensorwise + + +torch.set_float32_matmul_precision("high") +if HAS_CUDA: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.cuda.memory._set_allocator_settings("expandable_segments:False") log = logging.getLogger(__name__) +<<<<<<< HEAD @functools.lru_cache(None) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _get_path_without_sccache() -> str: """ Get the PATH environment variable without sccache. @@ -42,12 +65,15 @@ def _get_path_without_sccache() -> str: return ":".join(path_envs) +<<<<<<< HEAD _test_env = { "PATH": _get_path_without_sccache(), "DISABLE_SCCACHE": "1", } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @instantiate_parametrized_tests class TestCKBackend(TestCase): def setUp(self): @@ -78,7 +104,11 @@ def setUp(self): ) @unittest.skipIf(not torch.version.hip, "ROCM only") +<<<<<<< HEAD @unittest.mock.patch.dict(os.environ, _test_env) +======= + @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize("max_autotune_gemm_backends", ("CK", "CKTILE", "ATen,Triton,CK")) @parametrize("autotune_in_subproc", (True, False)) @parametrize("use_aoti", (True, False)) @@ -89,6 +119,11 @@ def test_max_autotune_precompile_matmul( Make sure autotuning mm doesn't crash. """ +<<<<<<< HEAD +======= + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def mm(a, b): return a @ b @@ -99,6 +134,7 @@ def mm(a, b): assert "rocm" in dir(config) +<<<<<<< HEAD with ( config.patch( { @@ -112,6 +148,18 @@ def mm(a, b): } ), tf32_off(), +======= + with config.patch( + { + "max_autotune": True, + "autotune_in_subproc": autotune_in_subproc, + "max_autotune_gemm_backends": max_autotune_gemm_backends, + "compile_threads": 16, + "rocm.ck_max_profiling_configs": 8, + "rocm.ck_tile_max_profiling_configs": 8, + "rocm.ck_dir": self.ck_dir, + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): if use_aoti: Y_compiled = AOTIRunnerUtil.run( @@ -130,7 +178,11 @@ def compiled_mm(x, w): torch.testing.assert_close(Y_compiled, Y) @unittest.skipIf(not torch.version.hip, "ROCM only") +<<<<<<< HEAD @unittest.mock.patch.dict(os.environ, _test_env) +======= + @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize("max_autotune_gemm_backends", ("CK",)) @parametrize("autotune_in_subproc", (True,)) def test_max_autotune_precompile_matmul_dynamic( @@ -140,6 +192,11 @@ def test_max_autotune_precompile_matmul_dynamic( Test matmul with dynamic shapes """ +<<<<<<< HEAD +======= + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tensor_options = {"device": "cuda", "dtype": torch.bfloat16} a = torch.randn(2240, 256, **tensor_options) @@ -149,6 +206,7 @@ def test_max_autotune_precompile_matmul_dynamic( assert "rocm" in dir(config) +<<<<<<< HEAD with ( config.patch( { @@ -162,6 +220,18 @@ def test_max_autotune_precompile_matmul_dynamic( } ), tf32_off(), +======= + with config.patch( + { + "max_autotune": True, + "autotune_in_subproc": autotune_in_subproc, + "max_autotune_gemm_backends": max_autotune_gemm_backends, + "compile_threads": 16, + "rocm.ck_max_profiling_configs": 8, + "rocm.ck_tile_max_profiling_configs": 8, + "rocm.ck_dir": self.ck_dir, + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): @torch.compile(dynamic=True) @@ -178,13 +248,22 @@ def compiled_mm(a, b): torch.testing.assert_close(Y1_compiled, Y1) @unittest.skipIf(not torch.version.hip, "ROCM only") +<<<<<<< HEAD @unittest.mock.patch.dict(os.environ, _test_env) +======= + @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize("max_autotune_gemm_backends", ("CK", "ATen,Triton,CK")) def test_max_autotune_precompile_preselected(self, max_autotune_gemm_backends): """ End to end test for picking preselected ck instances """ +<<<<<<< HEAD +======= + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def mm(a, b): return a @ b @@ -195,6 +274,7 @@ def mm(a, b): assert "rocm" in dir(config) +<<<<<<< HEAD with ( config.patch( { @@ -207,12 +287,24 @@ def mm(a, b): } ), tf32_off(), +======= + with config.patch( + { + "max_autotune": True, + "autotune_in_subproc": True, + "max_autotune_gemm_backends": max_autotune_gemm_backends, + "compile_threads": 12, + "rocm.ck_dir": self.ck_dir, + "rocm.use_preselected_instances": True, + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): Y_compiled = torch.compile(mm, dynamic=False)(a, b) Y = mm(a, b) torch.testing.assert_close(Y_compiled, Y) @unittest.skipIf(not torch.version.hip, "ROCM only") +<<<<<<< HEAD @unittest.mock.patch.dict(os.environ, _test_env) @parametrize("max_autotune_gemm_backends", ("Aten,CK",)) def test_max_autotune_precompile_non_contiguous(self, max_autotune_gemm_backends): @@ -220,6 +312,17 @@ def test_max_autotune_precompile_non_contiguous(self, max_autotune_gemm_backends Make sure the matmul with non-contiguous inputs can fallback """ +======= + @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) + @parametrize("max_autotune_gemm_backends", ("CK", "ATen,Triton,CK")) + def test_max_autotune_precompile_non_contiguous(self, max_autotune_gemm_backends): + """ + Make sure the ck template can work with non-contiguous inputs + """ + + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tensor_options = {"device": "cuda", "dtype": torch.float16} a = torch.empty_strided((50257, 32768), (1, 50304), **tensor_options) @@ -227,6 +330,7 @@ def test_max_autotune_precompile_non_contiguous(self, max_autotune_gemm_backends assert "rocm" in dir(config) +<<<<<<< HEAD with ( config.patch( { @@ -240,6 +344,18 @@ def test_max_autotune_precompile_non_contiguous(self, max_autotune_gemm_backends } ), tf32_off(), +======= + with config.patch( + { + "max_autotune": True, + "autotune_in_subproc": True, + "max_autotune_gemm_backends": max_autotune_gemm_backends, + "compile_threads": 16, + "rocm.ck_dir": self.ck_dir, + "rocm.ck_max_profiling_configs": 8, + "rocm.ck_tile_max_profiling_configs": 8, + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): @torch.compile(dynamic=False) @@ -248,6 +364,7 @@ def mm(a, b): Y_compiled = mm(a, b) Y_eager = a @ b +<<<<<<< HEAD torch.testing.assert_close(Y_compiled, Y_eager, equal_nan=True) @unittest.skipIf(not torch.version.hip, "ROCM only") @@ -255,6 +372,17 @@ def mm(a, b): @parametrize("max_autotune_gemm_backends", ("CK", "ATen,Triton,CK")) @parametrize("x_shape", ([4096, 2048], [2048], [4096, 1])) def test_max_autotune_addmm(self, max_autotune_gemm_backends, x_shape): +======= + torch.testing.assert_close(Y_compiled, Y_eager) + + @unittest.skipIf(not torch.version.hip, "ROCM only") + @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) + @parametrize("max_autotune_gemm_backends", ("CK", "ATen,Triton,CK")) + @parametrize("x_shape", ([4096, 2048], [2048], [4096, 1])) + def test_max_autotune_addmm(self, max_autotune_gemm_backends, x_shape): + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) m, k, n = 4096, 224, 2048 alpha, beta = 1.0, 1.0 @@ -265,6 +393,7 @@ def test_max_autotune_addmm(self, max_autotune_gemm_backends, x_shape): assert "rocm" in dir(config) +<<<<<<< HEAD with ( config.patch( { @@ -277,6 +406,17 @@ def test_max_autotune_addmm(self, max_autotune_gemm_backends, x_shape): } ), tf32_off(), +======= + with config.patch( + { + "max_autotune": True, + "autotune_in_subproc": True, + "max_autotune_gemm_backends": max_autotune_gemm_backends, + "compile_threads": 2, + "rocm.ck_dir": self.ck_dir, + "rocm.ck_max_profiling_configs": 2, + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): @torch.compile(dynamic=False) @@ -289,6 +429,7 @@ def addmm(x, a, b, alpha, beta): torch.testing.assert_close(Y_compiled, Y_eager) @unittest.skipIf(not torch.version.hip, "ROCM only") +<<<<<<< HEAD @unittest.mock.patch.dict(os.environ, _test_env) @parametrize("max_autotune_gemm_backends", ("CK", "ATen,Triton,CK")) @parametrize("quantize_type", ("tensorwise", "rowwise")) @@ -302,6 +443,17 @@ def test_max_autotune_scaled_mm( self.skipTest(f"Unsupported arch {runtime_arch}") # output dtype dtype = torch.bfloat16 +======= + @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) + @parametrize("max_autotune_gemm_backends", ("CK", "ATen,Triton,CK")) + @parametrize("dtype", (torch.bfloat16,)) + @parametrize("use_fast_accum", (True,)) + @parametrize("quantize_type", ("tensorwise", "rowwise")) + @parametrize("has_bias", (True, False)) + def test_max_autotune_scaled_mm( + self, max_autotune_gemm_backends, dtype, use_fast_accum, quantize_type, has_bias + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tensor_options = {"device": "cuda", "dtype": dtype} M = 2240 @@ -315,9 +467,13 @@ def test_max_autotune_scaled_mm( if has_bias: bias = torch.randn(N, **tensor_options) +<<<<<<< HEAD dtype_float8 = ( torch.float8_e4m3fnuz if "gfx94" in runtime_arch else torch.float8_e4m3fn ) +======= + dtype_float8 = torch.float8_e4m3fnuz +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f_quantize = ( _quantize_tensorwise if quantize_type == "tensorwise" else _quantize_rowwise @@ -345,6 +501,7 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): ) return y +<<<<<<< HEAD y_eager = linear( x_fp8, x_inverse_scale, @@ -352,6 +509,29 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): w_inverse_scale_t, bias, ) +======= + if quantize_type == "tensorwise": + y_eager = linear( + x_fp8, + x_inverse_scale, + w_t_fp8, + w_inverse_scale_t, + bias, + ) + else: + # FIXME when rowwise quantize is supported by pt eager on ROCm + w_fp8_tw, w_inverse_scale_tw = _quantize_tensorwise(w, dtype_float8) + w_fp8_tw_t = w_fp8_tw.t() + w_inverse_scale_tw_t = w_inverse_scale_tw.t() + x_fp8_tw, x_inverse_scale_tw = _quantize_tensorwise(x, dtype_float8) + y_eager = linear( + x_fp8_tw, + x_inverse_scale_tw, + w_fp8_tw_t, + w_inverse_scale_tw_t, + bias, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with config.patch( { @@ -380,10 +560,19 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): @unittest.skipIf(not torch.version.hip, "ROCM only") @unittest.mock.patch.dict( os.environ, +<<<<<<< HEAD {**_test_env, "PYTORCH_MIOPEN_SUGGEST_NHWC": "1"}, ) @parametrize("max_autotune_conv_backends", ("CK", "ATEN,CK,TRITON")) def test_max_autotune_conv2d(self, max_autotune_conv_backends): +======= + {"PATH": _get_path_without_sccache(), "PYTORCH_MIOPEN_SUGGEST_NHWC": "1"}, + ) + @parametrize("max_autotune_conv_backends", ("CK", "ATEN,CK,TRITON")) + def test_max_autotune_conv2d(self, max_autotune_conv_backends): + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tensor_options = {"device": "cuda", "dtype": torch.float32} x = torch.randn(1, 8, 224, 224, **tensor_options) @@ -393,6 +582,7 @@ def test_max_autotune_conv2d(self, max_autotune_conv_backends): assert "rocm" in dir(config) +<<<<<<< HEAD with ( config.patch( { @@ -405,6 +595,17 @@ def test_max_autotune_conv2d(self, max_autotune_conv_backends): } ), tf32_off(), +======= + with config.patch( + { + "max_autotune": True, + "autotune_in_subproc": False, + "max_autotune_conv_backends": max_autotune_conv_backends, + "compile_threads": 4, + "rocm.ck_dir": self.ck_dir, + "rocm.ck_max_profiling_configs": 4, + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): @torch.compile(dynamic=False) @@ -417,7 +618,11 @@ def conv2d(x, w): torch.testing.assert_close(Y_compiled, Y_eager, atol=2e-4, rtol=2e-4) @unittest.skipIf(not torch.version.hip, "ROCM only") +<<<<<<< HEAD @unittest.mock.patch.dict(os.environ, _test_env) +======= + @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize("max_autotune_gemm_backends", ("CK", "ATen,Triton,CK")) def test_max_autotune_precompile_bmm( self, @@ -427,6 +632,11 @@ def test_max_autotune_precompile_bmm( Test gemm-max-autotune torch.bmm with CK backend """ +<<<<<<< HEAD +======= + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def bmm(a, b): return torch.bmm(a, b) @@ -437,6 +647,7 @@ def bmm(a, b): assert "rocm" in dir(config) +<<<<<<< HEAD with ( config.patch( { @@ -448,6 +659,16 @@ def bmm(a, b): } ), tf32_off(), +======= + with config.patch( + { + "max_autotune": True, + "max_autotune_gemm_backends": max_autotune_gemm_backends, + "compile_threads": 2, + "rocm.ck_max_profiling_configs": 2, + "rocm.ck_dir": self.ck_dir, + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): @torch.compile(dynamic=False) @@ -464,5 +685,9 @@ def compiled_bmm(x, w): from torch._inductor.utils import is_big_gpu # Set env to make it work in CI. +<<<<<<< HEAD if HAS_CUDA_AND_TRITON and HAS_CPU and is_big_gpu(): +======= + if HAS_CUDA and HAS_CPU and is_big_gpu(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) run_tests() diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index 6da49ab392290..28174075182dd 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -29,11 +29,17 @@ TensorMetadata, TensorMetadataAndValues, ) +<<<<<<< HEAD from torch._inductor.cpp_builder import normalize_path_separator from torch._inductor.custom_graph_pass import ( CustomGraphModulePass, CustomGraphPass, CustomPartitionerFn, +======= +from torch._inductor.custom_graph_pass import ( + CustomGraphModulePass, + CustomGraphPass, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) get_hash_for_files, ) from torch._inductor.graph import GraphLowering @@ -47,11 +53,15 @@ CacheArtifactFactory, CacheArtifactManager, ) +<<<<<<< HEAD from torch.testing._internal.common_cuda import ( SM80OrLater, TEST_MULTIGPU, with_tf32_off, ) +======= +from torch.testing._internal.common_cuda import SM80OrLater, TEST_MULTIGPU +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_device_type import largeTensorTest from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, @@ -61,6 +71,10 @@ ) from torch.testing._internal.inductor_utils import ( GPU_TYPE, +<<<<<<< HEAD +======= + HAS_CUDA, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) HAS_GPU, HAS_MULTIGPU, HAS_TRITON, @@ -68,6 +82,7 @@ requires_gpu, requires_triton, ) +<<<<<<< HEAD from torch.testing._internal.triton_utils import requires_cuda_and_triton @@ -75,6 +90,9 @@ from . import custom_inductor_config except ImportError: import custom_inductor_config +======= +from torch.testing._internal.triton_utils import requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if HAS_TRITON: @@ -873,7 +891,11 @@ def fn(x): @torch._functorch.config.patch({"enable_autograd_cache": False}) @config.patch("fx_graph_remote_cache", False) @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") +<<<<<<< HEAD @requires_cuda_and_triton +======= + @unittest.skipIf(not HAS_CUDA, "Requires CUDA") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_no_arguments_tensor_device_guards(self): """ Usually, when there are example inputs, the device index of the inputs @@ -903,7 +925,11 @@ def f(): @torch._functorch.config.patch({"enable_autograd_cache": False}) @config.patch("fx_graph_remote_cache", False) @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") +<<<<<<< HEAD @requires_cuda_and_triton +======= + @unittest.skipIf(not HAS_CUDA, "Requires CUDA") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_tensor_device_guards_cpu_tensor(self): """ CPU tensor arguments should still cache hit @@ -1007,10 +1033,16 @@ def fn(x, op): self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1) self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 1) +<<<<<<< HEAD @requires_cuda_and_triton @config.patch({"fx_graph_cache": True}) @config.patch({"fx_graph_remote_cache": False}) @with_tf32_off +======= + @requires_cuda + @config.patch({"fx_graph_cache": True}) + @config.patch({"fx_graph_remote_cache": False}) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_flex_attention_caching(self): from torch.nn.attention.flex_attention import create_block_mask, flex_attention @@ -1465,7 +1497,11 @@ def f(x, val): self.assertNotEqual(a, b) @config.patch({"fx_graph_cache": False, "fx_graph_remote_cache": False}) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.expectedFailure # TODO: pass in optimize_mem at runtime def test_async_compile_cache(self): class SimpleFunction(torch.autograd.Function): @@ -1808,9 +1844,13 @@ def f(x): assert not kwargs with tempfile.TemporaryDirectory() as temp_dir: +<<<<<<< HEAD path = normalize_path_separator( os.path.join(temp_dir, "compiled_artifact.bin") ) +======= + path = os.path.join(temp_dir, "compiled_artifact.bin") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with fresh_cache(): compiled_artifact = torch._inductor.standalone_compile(gm, args) @@ -1868,6 +1908,7 @@ def f(x): @config.patch({"fx_graph_cache": True}) @config.patch({"fx_graph_remote_cache": False}) @functorch_config.patch({"enable_autograd_cache": True}) +<<<<<<< HEAD @functorch_config.patch({"autograd_cache_normalize_inputs": True}) def test_split_module(self): class Mod(torch.nn.Module): @@ -1930,6 +1971,8 @@ def t(): @config.patch({"fx_graph_cache": True}) @config.patch({"fx_graph_remote_cache": False}) @functorch_config.patch({"enable_autograd_cache": True}) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize("config_patches", [True, False]) def test_dynamic_shapes_from_example_inputs(self, config_patches): def f(x): @@ -2061,6 +2104,7 @@ def backend(gm, args, **kwargs): result = torch.compile(f, backend=backend)(static_x) self.assertEqual(result, static_x * 3) +<<<<<<< HEAD @config.patch({"fx_graph_cache": True}) @config.patch({"fx_graph_remote_cache": False}) def test_custom_pass_handling(self): @@ -2128,6 +2172,8 @@ def __call__( def uuid(self) -> Optional[Union[bytes, str]]: return self._uuid +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestFxGraphCacheHashing(TestCase): def test_parameter_constants(self): @@ -2490,6 +2536,7 @@ def uuid(self) -> Optional[Union[bytes, str]]: pickler.dumps(details3), ) +<<<<<<< HEAD def test_hash_custom_backend_config(self): """ Test cache correctness when a custom inductor codegen config @@ -2563,6 +2610,8 @@ def test_hash_custom_partitioner_fn(self): pickler.dumps(details3), ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_bypass_unsupported(self): """ Test _reduce_unsupported @@ -2579,7 +2628,11 @@ def test_stable_strings(self): even if they are not the same id. """ s1 = "string" +<<<<<<< HEAD s2 = "strin" # codespell:ignore +======= + s2 = "strin" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) s2 += "g" self.assertNotEqual(id(s1), id(s2)) @@ -2619,7 +2672,11 @@ def test_get_hash_for_files(self): class TestCudaCompileCommand(TestCase): +<<<<<<< HEAD @requires_cuda_and_triton +======= + @unittest.skipIf(not HAS_CUDA, "Requires CUDA") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_cuda_compile_command(self): cmd_no_extra_args: str = cuda_compile_command( ["abc.cu", "def.cu"], "output", "so" @@ -2664,7 +2721,11 @@ def reset(self): torch._dynamo.reset() clear_caches() +<<<<<<< HEAD @requires_cuda_and_triton +======= + @unittest.skipIf(not HAS_CUDA, "Requires CUDA") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(not SM80OrLater, "Requires SM80+") @unittest.skipIf( TEST_WITH_ROCM, "Requires static cuda launcher, which does not support ROCM" @@ -2715,7 +2776,11 @@ def f(x, y, a, b): for k in global_stats.triton.cache.keys(): self.assertRegex(k, r"triton:[0-9a-f]{64}::[0-9a-f]{64}:c[0-9]+") +<<<<<<< HEAD @requires_cuda_and_triton +======= + @unittest.skipIf(not HAS_CUDA, "Requires CUDA") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(not SM80OrLater, "Requires SM80+") @config.patch({"fx_graph_cache": False}) @config.patch({"fx_graph_remote_cache": False}) @@ -2756,7 +2821,11 @@ def f(x, y, a, b): for k in global_stats.triton.cache.keys(): self.assertRegex(k, r"triton:[0-9a-f]{64}::[0-9a-f]{64}:c[0-9]+") +<<<<<<< HEAD @requires_cuda_and_triton +======= + @unittest.skipIf(not HAS_CUDA, "Requires CUDA") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(not SM80OrLater, "Requires SM80+") @config.patch({"fx_graph_cache": False}) @config.patch({"fx_graph_remote_cache": False}) @@ -2817,7 +2886,11 @@ def f(a, b, c, d, e, f): self.assertRegex(k, r"triton:[0-9a-f]{64}::[0-9a-f]{64}:c[0-9]+") @requires_triton() +<<<<<<< HEAD @requires_cuda_and_triton +======= + @unittest.skipIf(not HAS_CUDA, "Requires CUDA") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(not SM80OrLater, "Requires SM80+") @config.patch({"fx_graph_cache": False}) @config.patch({"fx_graph_remote_cache": False}) @@ -2846,8 +2919,13 @@ def get_autotune_stats(): def fn(x, y): return (x + y).relu() +<<<<<<< HEAD x = torch.randn(100, 100).to(GPU_TYPE) y = torch.randn(100, 100).to(GPU_TYPE) +======= + x = torch.randn(100, 100).cuda() + y = torch.randn(100, 100).cuda() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with config.patch( { @@ -2881,7 +2959,11 @@ def fn(x, y): class TestRemoteAOTAutogradCache(TestCase): +<<<<<<< HEAD @requires_cuda_and_triton +======= + @unittest.skipIf(not HAS_CUDA, "Requires CUDA") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(not SM80OrLater, "Requires SM80+") @config.patch({"fx_graph_cache": False}) @config.patch({"fx_graph_remote_cache": True}) @@ -2920,7 +3002,11 @@ def f(a, b): for k in global_stats.fx_graph.cache.keys(): self.assertRegex(k, r"pt2:fx-graph-v1::[0-9a-z]{52}:c[0-9]+") +<<<<<<< HEAD @requires_cuda_and_triton +======= + @unittest.skipIf(not HAS_CUDA, "Requires CUDA") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(not SM80OrLater, "Requires SM80+") @config.patch({"fx_graph_cache": False}) @config.patch({"fx_graph_remote_cache": True}) @@ -2995,7 +3081,11 @@ def fn(x, y): # This combination of settings exposed a bug where we cleared the # PyCodeCache disk artifacts while they were still needed: +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @config.patch( { "coordinate_descent_tuning": True, diff --git a/test/inductor/test_combo_kernels.py b/test/inductor/test_combo_kernels.py index 6523cddcec6db..a23c629faba2c 100644 --- a/test/inductor/test_combo_kernels.py +++ b/test/inductor/test_combo_kernels.py @@ -10,8 +10,13 @@ instantiate_parametrized_tests, TestCase, ) +<<<<<<< HEAD from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA_AND_TRITON from torch.testing._internal.triton_utils import requires_cuda_and_triton +======= +from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA +from torch.testing._internal.triton_utils import requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) aten = torch.ops.aten @@ -55,7 +60,11 @@ def tearDown(self): torch._inductor.metrics.reset() super().tearDown() +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_activation_functions(self): def test_activations(a, b, c): a1 = torch.nn.functional.relu(a) @@ -75,7 +84,11 @@ def test_activations(a, b, c): self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_reduce_functions(self): def test_reduce(a, b, c, d): a1 = torch.sum(a, dim=0) @@ -98,7 +111,11 @@ def test_reduce(a, b, c, d): self.assertEqual(out_eager, out_compiled) self.assertTrue(torch._inductor.metrics.generated_kernel_count <= 2) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_mutated_args(self): def test_mutated(a, b, c, d): a.add_(1) @@ -121,7 +138,11 @@ def test_mutated(a, b, c, d): self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_reduce_split(self): def fn(a, b): a1 = torch.linalg.vector_norm(a) @@ -137,7 +158,11 @@ def fn(a, b): self.assertEqual(out_eager, out_compiled) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_2d_blocking_partitioning(self): def fn(a0, a1, a2, b0, b1, b2): c0 = torch.add(a0, b0) @@ -184,7 +209,11 @@ def tearDown(self): torch._inductor.metrics.reset() super().tearDown() +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_activation_benchmark(self): def test_activations(a, b, c): a1 = torch.nn.functional.relu(a) @@ -204,7 +233,11 @@ def test_activations(a, b, c): self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 5) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_reduce_benchmark(self): def test_reduce(a, b, c, d): a1 = torch.sum(a, dim=0) @@ -227,7 +260,11 @@ def test_reduce(a, b, c, d): self.assertEqual(out_eager, out_compiled) self.assertTrue(4 < torch._inductor.metrics.generated_kernel_count <= 10) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_mutated_benchmark(self): def test_mutated(a, b, c, d): a.add_(1) @@ -250,7 +287,11 @@ def test_mutated(a, b, c, d): self.assertEqual(out_eager, out_compiled) self.assertTrue(torch._inductor.metrics.generated_kernel_count in [6, 9]) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_round_robin_dispatch(self): # combo kernel dispatch strategy: round robin def test_mutated(a, b, c, d): @@ -274,7 +315,11 @@ def test_mutated(a, b, c, d): self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 6) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_2d_blocking_benchmark(self): def fn(a0, a1, a2, b0, b1, b2): c0 = torch.add(a0, b0) @@ -329,7 +374,11 @@ def tearDown(self): torch._inductor.metrics.reset() super().tearDown() +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_dynamic_shapes_activations(self): def test_activations(a, b, c): a1 = torch.nn.functional.relu(a) @@ -349,7 +398,11 @@ def test_activations(a, b, c): self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 5) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_dynamic_shapes_2d_blocking(self): def fn(a0, a1, a2, b0, b1, b2): c0 = torch.add(a0, b0) @@ -371,7 +424,11 @@ def fn(a0, a1, a2, b0, b1, b2): self.assertTrue(7 <= torch._inductor.metrics.generated_kernel_count <= 8) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_dynamic_shapes_reduce(self): def test_reduce(a, b, c, d): a1 = torch.sum(a, dim=0) @@ -394,7 +451,11 @@ def test_reduce(a, b, c, d): self.assertEqual(out_eager, out_compiled) self.assertTrue(4 < torch._inductor.metrics.generated_kernel_count <= 10) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_dynamic_shapes_mutated(self): # combo kernel dispatch strategy: round robin def test_mutated(a, b, c, d): @@ -418,7 +479,11 @@ def test_mutated(a, b, c, d): self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 6) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._inductor.config.patch("combo_kernels_autotune", 0) def test_dynamic_shapes_activations_no_autotune(self): def test_activations(a, b, c): @@ -439,7 +504,11 @@ def test_activations(a, b, c): self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 5) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._dynamo.config.patch("automatic_dynamic_shapes", True) @torch._dynamo.config.patch("assume_static_by_default", True) def test_dynamic_shapes_persistent_reduction_no_x_dim(self): @@ -458,7 +527,11 @@ def fn(x, y): self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._dynamo.config.patch("automatic_dynamic_shapes", True) @torch._dynamo.config.patch("assume_static_by_default", True) def test_dynamic_shapes_persistent_reduction_no_x_dim_2(self): @@ -477,7 +550,11 @@ def fn(x, y): self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._dynamo.config.patch("automatic_dynamic_shapes", True) @torch._dynamo.config.patch("assume_static_by_default", True) def test_dynamic_shapes_2d_blocking_round_robin(self): @@ -516,7 +593,11 @@ def fn(a0, a1, a2, b0, b1, b2): self.assertEqual(out_eager, out_compiled) self.assertTrue(5 <= torch._inductor.metrics.generated_kernel_count <= 6) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._dynamo.config.patch("automatic_dynamic_shapes", True) @torch._dynamo.config.patch("assume_static_by_default", True) @torch._inductor.config.patch("triton.autotune_at_compile_time", True) @@ -541,5 +622,9 @@ def fn(x, y, z): if __name__ == "__main__": from torch._dynamo.test_case import run_tests +<<<<<<< HEAD if HAS_CPU or HAS_CUDA_AND_TRITON: +======= + if HAS_CPU or HAS_CUDA: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) run_tests(needs="filelock") diff --git a/test/inductor/test_compile.py b/test/inductor/test_compile.py index 6908936eca3f3..6bcf2be03f5cb 100644 --- a/test/inductor/test_compile.py +++ b/test/inductor/test_compile.py @@ -1,4 +1,5 @@ # Owner(s): ["module: inductor"] +<<<<<<< HEAD import os import shlex import subprocess @@ -9,6 +10,10 @@ from torch import _dynamo as dynamo, _inductor as inductor from torch._inductor.codecache import write from torch._inductor.cpp_builder import CppBuilder, CppOptions +======= +import torch +from torch import _dynamo as dynamo, _inductor as inductor +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.test_case import run_tests, TestCase from torch._inductor.utils import gen_gm_and_inputs from torch.fx import symbolic_trace @@ -16,6 +21,7 @@ from torch.testing._internal.inductor_utils import HAS_CPU +<<<<<<< HEAD _IS_MACOS = sys.platform.startswith("darwin") _IS_WINDOWS = sys.platform == "win32" @@ -35,6 +41,8 @@ def safe_command_output(cmd, timeout=30): return "runt timeout" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class MyModule(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -136,6 +144,7 @@ def test_inductor_via_op_with_multiple_outputs(self): mod_opt = inductor.compile(mod, inp) self.assertEqual(mod(*inp), mod_opt(*inp)) +<<<<<<< HEAD @mock.patch.dict(os.environ, {"TORCHINDUCTOR_DEBUG_SYMBOL": "1"}) def test_inductor_generate_debug_symbol(self): cpp_code = """ @@ -183,6 +192,8 @@ def check_windows_pdb_exist(module_path: str): else: check_linux_debug_section(binary_path) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": if HAS_CPU: diff --git a/test/inductor/test_compile_subprocess.py b/test/inductor/test_compile_subprocess.py index bf474bfbf1776..ff336a3c545e4 100644 --- a/test/inductor/test_compile_subprocess.py +++ b/test/inductor/test_compile_subprocess.py @@ -9,13 +9,17 @@ import os import sys import time +<<<<<<< HEAD import unittest from unittest import mock +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from unittest.mock import patch import torch import torch.library from torch._inductor.compile_fx import _InProcessFxCompile, FxCompile, FxCompileMode +<<<<<<< HEAD from torch._inductor.graph import GraphLowering from torch._inductor.test_case import TestCase from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS, TEST_WITH_ASAN @@ -37,6 +41,11 @@ if __name__ == "__main__": sys.exit(0) raise unittest.SkipTest("pass_fds not supported on Windows") +======= +from torch._inductor.test_case import TestCase +from torch.testing._internal.common_utils import TEST_WITH_ASAN +from torch.testing._internal.inductor_utils import GPU_TYPE, RUN_CPU, RUN_GPU +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Make the helper files in test/ importable @@ -62,6 +71,12 @@ "test_remove_noop_slice_scatter": TestFailure(("xpu"), is_skip=True), "test_remove_noop_view_default": TestFailure(("xpu"), is_skip=True), "test_remove_noop_view_dtype": TestFailure(("xpu"), is_skip=True), +<<<<<<< HEAD +======= + # TODO:remove test_upsample_bicubic2d after the following issue resolved: + # https://github.com/intel/intel-xpu-backend-for-triton/issues/4184 + "test_upsample_bicubic2d": TestFailure(("xpu"), is_skip=False), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } @@ -92,6 +107,7 @@ def tearDown(self): TestCase.tearDown(self) torch._dynamo.reset() +<<<<<<< HEAD @requires_gpu() @requires_triton() @unittest.skipIf( @@ -174,6 +190,8 @@ def baseline(x, y): ) self.assertTrue("'max_autotune': True" in source_codes[-1]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @patch("torch._inductor.compile_fx.fx_compile_async", True) def test_async(self): # Test that async+subprocess works. @@ -189,7 +207,11 @@ def model_add(x, y): _AsyncFxCompile._reset_stats() with contextlib.ExitStack() as stack: +<<<<<<< HEAD assert torch._inductor.compile_fx_async.BUG_CACHES_DONT_WORK_WITH_ASYNC +======= + # TODO: Turn off local caches - they don't play nice w/ async currently. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) stack.enter_context( torch._inductor.config.patch( autotune_local_cache=False, fx_graph_cache=False @@ -206,6 +228,7 @@ def model_add(x, y): start = time.time() last_report = start +<<<<<<< HEAD while True: start_stat_compiled_runs = _AsyncFxCompile._stat_compiled_runs # Sleep a bit so we don't drive the CPU unnecessarily. @@ -222,6 +245,15 @@ def model_add(x, y): if _AsyncFxCompile._stat_compiled_runs - start_stat_compiled_runs == 2: break +======= + while _AsyncFxCompile._stat_compiled_runs < 4: + # Sleep a bit so we don't drive the CPU unnecessarily. + time.sleep(0.25) + + x = torch.randn(100, 100) + y = torch.randn(100, 100) + model_add(x, y) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # DEBUGGING: Print a periodic message so we know we're still # running... @@ -235,12 +267,21 @@ def model_add(x, y): "Test timed out before producing a compiled artifact." ) +<<<<<<< HEAD self.assertGreater(_AsyncFxCompile._stat_compiled_runs, 1) # Make sure we ran eager at least once. Normally this will be # something like 80. self.assertGreater(_AsyncFxCompile._stat_eager_runs, 0) self.assertEqual(_AsyncFxCompile._stat_bg_started, 2) self.assertEqual(_AsyncFxCompile._stat_bg_finished, 2) +======= + self.assertEqual(_AsyncFxCompile._stat_compiled_runs, 4) + # Make sure we ran eager at least once. Normally this will be + # something like 80. + self.assertGreater(_AsyncFxCompile._stat_eager_runs, 0) + self.assertEqual(_AsyncFxCompile._stat_bg_started, 1) + self.assertEqual(_AsyncFxCompile._stat_bg_finished, 1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if RUN_CPU: diff --git a/test/inductor/test_compile_worker.py b/test/inductor/test_compile_worker.py index 8fde26c6acf67..1127837f9352a 100644 --- a/test/inductor/test_compile_worker.py +++ b/test/inductor/test_compile_worker.py @@ -1,7 +1,10 @@ # Owner(s): ["module: inductor"] import operator import os +<<<<<<< HEAD import tempfile +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.compile_worker.subproc_pool import ( raise_testexc, @@ -54,6 +57,7 @@ def test_crash(self): finally: pool.shutdown() +<<<<<<< HEAD @skipIfWindows(msg="pass_fds not supported on Windows.") def test_quiesce(self): pool = SubprocPool(2) @@ -80,6 +84,8 @@ def test_logging(self): finally: pool.shutdown() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index 6014a6e698607..8c46feb99d1f0 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -29,7 +29,10 @@ from torch._dynamo.testing import normalize_gm from torch._dynamo.utils import counters from torch._inductor import config as inductor_config +<<<<<<< HEAD from torch._inductor.cpp_builder import is_msvc_cl +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.test_case import run_tests, TestCase from torch.nn.attention.flex_attention import flex_attention from torch.nn.parallel import DistributedDataParallel as DDP @@ -41,12 +44,16 @@ from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, IS_S390X, +<<<<<<< HEAD IS_WINDOWS, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) parametrize, scoped_load_inline, skipIfWindows, ) from torch.testing._internal.hop_db import hop_db +<<<<<<< HEAD from torch.testing._internal.inductor_utils import ( GPU_TYPE, HAS_CPU, @@ -55,6 +62,10 @@ ) from torch.testing._internal.logging_utils import logs_to_string from torch.testing._internal.triton_utils import requires_cuda_and_triton +======= +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_CUDA, HAS_GPU +from torch.testing._internal.logging_utils import logs_to_string +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.utils._python_dispatch import TorchDispatchMode @@ -172,6 +183,7 @@ def run_as_subprocess(self, script) -> bytes: except subprocess.CalledProcessError as e: self.fail(f"Subprocess exited with return code: {e.returncode}") +<<<<<<< HEAD def test_hipify_not_loaded_with_import_torch(self): script = """ import torch @@ -186,6 +198,8 @@ def test_hipify_not_loaded_with_import_cpp_extension(self): """ self.run_as_subprocess(script) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_dynamo_flaky_segfault(self): script = """ import torch @@ -215,6 +229,7 @@ def model(i): for _ in range(3): self.run_as_subprocess(script) +<<<<<<< HEAD def gen_cache_miss_log_prefix(self): if IS_WINDOWS: if is_msvc_cl(): @@ -227,6 +242,8 @@ def gen_cache_miss_log_prefix(self): else: return "Cache miss due to new autograd node: " +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_reset(self): compiled_autograd.compiled_autograd_enabled = True torch._C._dynamo.compiled_autograd.set_autograd_compiler(lambda: None, True) @@ -1042,8 +1059,13 @@ def test_inputs_aliasing_bytecode_attr_mutations(self): # Freeze compiled autograd graph compiler = torch._dynamo.compiled_autograd.AutogradCompilerInstance(compiler_fn) param = torch.ones(100) +<<<<<<< HEAD active = torch.ones(100) * 2 inputs = [param, active] +======= + activ = torch.ones(100) * 2 + inputs = [param, activ] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _, proxies, _, _ = compiler.begin_capture( inputs=inputs, sizes=[], @@ -1094,7 +1116,11 @@ def bytecode_hook(code, out_code): try: runtime_wrapper( compiled_fn=compiled_fn, +<<<<<<< HEAD inputs=[param, active], +======= + inputs=[param, activ], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sizes=(), scalars=(), hooks=[], @@ -3009,7 +3035,11 @@ def backward(ctx, grad): b = MyFunc.apply(a) b.sum().backward() +<<<<<<< HEAD @requires_cuda_and_triton +======= + @unittest.skipIf(not HAS_CUDA, "requires cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_cudagraphs_cpu_division(self): from torch._dynamo.testing import reduce_to_scalar_loss @@ -3049,7 +3079,11 @@ def test_cudagraphs_cpu_graph(self): self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @unittest.skipIf(not HAS_CUDA, "requires cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_cudagraphs_sdpa(self): query = torch.rand( 32, 8, 128, 64, dtype=torch.float16, device="cuda", requires_grad=True @@ -3071,7 +3105,11 @@ def test_cudagraphs_sdpa(self): 2 if inductor_config.cpp_wrapper else 0, ) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @unittest.skipIf(not HAS_CUDA, "requires cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_cudagraphs_cpu_scalar_used_in_python_custom_op(self): class MyFn(torch.autograd.Function): @staticmethod @@ -3099,6 +3137,7 @@ def backward(ctx, gO): self.assertEqual(counters["compiled_autograd"]["captures"], 1) # Compiled autograd lifts custom autograd.Function bwd instead of tracing it. # Must skip since we do not know if the cpu scalar will be used only in ATen/prim ops. +<<<<<<< HEAD if inductor_config.graph_partition: # instead of skipping cudagraph, graph partition splits off cpu inputs/outputs and ops # and cudagraphify the remaining computation. So there is no cudagraph skip. @@ -3112,6 +3151,12 @@ def backward(ctx, gO): @scoped_load_inline @requires_cuda_and_triton +======= + self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) + + @scoped_load_inline + @unittest.skipIf(not HAS_CUDA, "requires cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_cudagraphs_cpu_scalar_used_in_cpp_custom_op(self, load_inline): cpp_source = """ struct CustomOpAutogradFunction : public torch::autograd::Function { @@ -3173,6 +3218,7 @@ def test_cudagraphs_cpu_scalar_used_in_cpp_custom_op(self, load_inline): # into it. We must skip since we do not know if the cpu scalar will be used only in ATen/prim ops. # In the future, we can consider having a cpu scalar movement pass sometime after we trace # into the custom C++ autograd::Function (like in AOTDispatcher) +<<<<<<< HEAD if inductor_config.graph_partition: # instead of skipping cudagraph, graph partition splits off cpu inputs/outputs and ops # and cudagraphify the remaining computation. So there is no cudagraph skip. @@ -3185,6 +3231,11 @@ def test_cudagraphs_cpu_scalar_used_in_cpp_custom_op(self, load_inline): self.assertEqual( counters["inductor"]["cudagraph_skips"], expected_cudagraph_skips, +======= + self.assertEqual( + counters["inductor"]["cudagraph_skips"], + 2 if inductor_config.cpp_wrapper else 1, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def test_logs(self): @@ -3198,7 +3249,11 @@ def test_logs(self): self.assertEqual(counters["compiled_autograd"]["compiles"], 1) assert "torch::autograd::AccumulateGrad (NodeCall" in logs.getvalue() assert ( +<<<<<<< HEAD self.gen_cache_miss_log_prefix() + "torch::autograd::GraphRoot" +======= + "Cache miss due to new autograd node: torch::autograd::GraphRoot" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) not in logs.getvalue() ) @@ -3405,6 +3460,10 @@ def fn(x, obj): sum(1 for e in expected_logs if e in logs.getvalue()), len(expected_logs) ) +<<<<<<< HEAD +======= + @skipIfWindows(msg="AssertionError: Scalars are not equal!") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_verbose_logs_cpp(self): torch._logging.set_logs(compiled_autograd_verbose=True) @@ -3432,9 +3491,14 @@ def fn(): self.check_output_and_recompiles(fn) patterns1 = [ +<<<<<<< HEAD r".*" + self.gen_cache_miss_log_prefix() + r"torch::autograd::GraphRoot \(NodeCall 0\) with key size (\d+), previous key sizes=\[\]\n", +======= + r".*Cache miss due to new autograd node: torch::autograd::GraphRoot \(NodeCall 0\) with key size (\d+), " + r"previous key sizes=\[\]\n", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] all_logs = logs.getvalue() @@ -3447,7 +3511,10 @@ def fn(): ) # for a single match: matches1=['match'], for multiple matches: matches1=[('match1', 'match2')]... self.assertEqual(len(matches1), len(patterns1)) +<<<<<<< HEAD @skipIfWindows(msg="node name demangling inconsistent on windows") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_verbose_logs_dynamic_shapes(self): logs, ctx = logs_to_string( torch._dynamo.compiled_autograd.__name__, "compiled_autograd_verbose" @@ -3472,8 +3539,12 @@ def test_verbose_logs_dynamic_shapes(self): actual_logs = logs.getvalue() expected_logs = [ +<<<<<<< HEAD self.gen_cache_miss_log_prefix() + "torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]", +======= + "Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] for expected in expected_logs: self.assertTrue(expected in actual_logs) @@ -3504,7 +3575,11 @@ def fn(): fn() unexpected_logs = [ +<<<<<<< HEAD self.gen_cache_miss_log_prefix() + "torch::autograd::GraphRoot (NodeCall 0)" +======= + "Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0)" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] self.assertEqual(sum(1 for e in unexpected_logs if e in logs.getvalue()), 0) @@ -3748,7 +3823,11 @@ def inner_compiler(gm_, example_inputs_): self.assertTrue(isinstance(view_nodes[0].args[1][0], torch.fx.Node)) self.assertTrue(isinstance(view_nodes[1].args[1][0], torch.fx.Node)) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @unittest.skipIf(not HAS_CUDA, "requires cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_flex_attention(self): def _squared(score, b, h, m, n): """Joint graph needed for correctness""" @@ -3916,7 +3995,11 @@ def forward(self, inputs, sizes, scalars, hooks, packed_data): compiler_fn=make_compiler_fn(backend="ca_eager", gm_hook=check), ) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @unittest.skipIf(not HAS_CUDA, "requires cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_cpu_offloading(self): def fn(): def pack(x): @@ -5084,7 +5167,11 @@ def wrap_test_class(orig_cls): dct[name] = unittest.expectedFailure elif name.startswith("test_"): backend = lookup_backend(name) +<<<<<<< HEAD if not HAS_CUDA_AND_TRITON and backend == "inductor": +======= + if not HAS_CUDA and backend == "inductor": +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue ctxs = [ compiled_autograd._enable( @@ -5197,7 +5284,10 @@ def wrap_test_class(orig_cls): "test_nested_checkpoint_set_early_stop", # dynamo disable "test_nested_checkpoint_two_children_early_stop_False", # dynamo disable "test_nested_checkpoint_two_children_early_stop_True", # dynamo disable +<<<<<<< HEAD "test_custom_autograd_ac_early_stop", # marked as skipped +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "test_dropout", # dynamo disable "test_dropout_inductor", # dynamo disable "test_function_with_kwargs", # dynamo disable @@ -5322,7 +5412,11 @@ def wrap_test_class(orig_cls): skipped_tests = set() +<<<<<<< HEAD if not HAS_CUDA_AND_TRITON: +======= +if not HAS_CUDA: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Found Tesla M60 which is too old to be supported by the triton GPU compiler skipped_tests.add("test_type_conversions") @@ -5348,7 +5442,11 @@ def wrap_test_class(orig_cls): test_higher_order_ops.ActivationCheckpointingTests ) +<<<<<<< HEAD if torch.distributed.is_available() and HAS_CUDA_AND_TRITON: +======= +if torch.distributed.is_available() and HAS_CUDA: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) test_dtensor = load_test_module("distributed/tensor/test_dtensor_compile") TestDTensorCompileWithCompiledAutograd = wrap_test_class( test_dtensor.TestDTensorCompile diff --git a/test/inductor/test_compiled_optimizers.py b/test/inductor/test_compiled_optimizers.py index c313348e93346..bb0d90e1844eb 100644 --- a/test/inductor/test_compiled_optimizers.py +++ b/test/inductor/test_compiled_optimizers.py @@ -1,16 +1,23 @@ # Owner(s): ["module: inductor"] +<<<<<<< HEAD import random import sys import types +======= +import sys +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import unittest import weakref from contextlib import ExitStack from copy import deepcopy from typing import NamedTuple +<<<<<<< HEAD from expecttest import assert_expected_inline +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch._inductor import torch._inductor.cudagraph_trees @@ -58,14 +65,22 @@ optim_db, optims, ) +<<<<<<< HEAD from torch.testing._internal.common_utils import parametrize, skipIfWindows +======= +from torch.testing._internal.common_utils import parametrize +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.inductor_utils import ( GPU_TYPE, HAS_CPU, HAS_GPU, has_triton, ) +<<<<<<< HEAD from torch.testing._internal.triton_utils import requires_cuda_and_triton, requires_gpu +======= +from torch.testing._internal.triton_utils import requires_cuda, requires_gpu +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_inputs(optim): @@ -190,6 +205,7 @@ class KernelCounts(NamedTuple): # tests you can get different kernel counts # This maps the test name to the # expected kernel count +<<<<<<< HEAD # fmt: off # expecttest got error after PYFMT add line break for the triple quotes @@ -257,6 +273,71 @@ class KernelCounts(NamedTuple): "test_sgd_tensor_lr_xpu": lambda x: assert_expected_inline(x, """2"""), } # fmt: on +======= +KERNEL_COUNT_OVERRIDES = { + "test_rmsprop_foreach_weight_decay_cpu": 12, + "test_nadam_foreach_weight_decay_momentum_decay_cpu": 20, + "test_adamw_amsgrad_capturable_foreach_cuda": 3, + "test_adamw_amsgrad_capturable_foreach_xpu": 3, + "test_adamw_amsgrad_capturable_cuda": 6, + "test_adamw_amsgrad_capturable_xpu": 6, + "test_adamw_tensor_lr_tensor_betas_amsgrad_capturable_cuda": 6, + "test_adamw_tensor_lr_tensor_betas_capturable_cuda": 6, + "test_adamw_tensor_lr_tensor_betas_amsgrad_capturable_xpu": 6, + "test_adamw_tensor_lr_amsgrad_capturable_cuda": 6, + "test_adamw_tensor_lr_amsgrad_capturable_xpu": 6, + "test_adam_tensor_lr_amsgrad_capturable_cuda": 6, + "test_adam_tensor_lr_amsgrad_capturable_xpu": 6, + "test_adam_tensor_lr_tensor_betas_amsgrad_capturable_cuda": 6, + "test_adam_tensor_lr_tensor_betas_capturable_cuda": 6, + "test_adam_amsgrad_capturable_cuda": 6, + "test_adam_amsgrad_capturable_xpu": 6, + "test_adadelta_tensor_lr_capturable_cuda": 6, + "test_adadelta_tensor_lr_capturable_xpu": 6, + "test_rmsprop_tensor_lr_capturable_cuda": 6, + "test_rmsprop_tensor_lr_capturable_xpu": 6, + "test_adadelta_foreach_weight_decay_maximize_cpu": 12, + "test_adadelta_foreach_rho_weight_decay_cpu": 12, + "test_adadelta_foreach_weight_decay_cpu": 12, + "test_sgd_foreach_momentum_weight_decay_cpu": 16, + "test_sgd_foreach_momentum_nesterov_weight_decay_cpu": 16, + "test_sgd_momentum_dampening_foreach_cuda": 5, + "test_sgd_momentum_dampening_foreach_xpu": 5, + "test_sgd_momentum_foreach_cuda": 5, + "test_sgd_momentum_foreach_xpu": 5, + "test_sgd_weight_decay_maximize_cuda": 4, + "test_sgd_weight_decay_maximize_xpu": 4, + "test_sgd_weight_decay_maximize_cpu": 4, + "test_sgd_weight_decay_cpu": 4, + "test_sgd_weight_decay_cuda": 4, + "test_sgd_weight_decay_xpu": 4, + "test_sgd_momentum_weight_decay_foreach_cuda": 2, + "test_sgd_momentum_weight_decay_foreach_xpu": 2, + "test_sgd_momentum_nesterov_weight_decay_foreach_cuda": 2, + "test_sgd_momentum_nesterov_weight_decay_foreach_xpu": 2, + "test_sgd_cuda": 4, + "test_sgd_cpu": 4, + "test_sgd_xpu": 4, + "test_adagrad_initial_accumulator_value_weight_decay_foreach_xpu": 2, + "test_adagrad_lr_decay_weight_decay_foreach_xpu": 2, + "test_adagrad_weight_decay_foreach_xpu": 2, + "test_adagrad_weight_decay_maximize_foreach_xpu": 2, + "test_adagrad_tensor_lr_cpu": 6, + "test_adagrad_tensor_lr_cuda": 6, + "test_adagrad_tensor_lr_xpu": 6, + "test_adamax_tensor_lr_weight_decay_capturable_cuda": 6, + "test_adamax_tensor_lr_weight_decay_capturable_xpu": 6, + "test_asgd_tensor_lr_weight_decay_maximize_capturable_cuda": 5, + "test_asgd_tensor_lr_weight_decay_maximize_capturable_xpu": 8, + "test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_cuda": 6, + "test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_xpu": 9, + "test_radam_tensor_lr_capturable_weight_decay_decoupled_weight_decay_cuda": 6, + "test_radam_tensor_lr_capturable_weight_decay_decoupled_weight_decay_xpu": 6, + "test_sgd_tensor_lr_cpu": 2, + "test_sgd_tensor_lr_cuda": 2, + "test_sgd_tensor_lr_xpu": 2, +} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # also tracks currently supported optimizers KERNEL_COUNTS = { @@ -511,12 +592,18 @@ def test_fn(self): # currently, we compile the step and the rest of the computation # separately because the step is a single element tensor # hence, the usual kernel count is 2 +<<<<<<< HEAD if isinstance(kernel_count, types.LambdaType): kernel_count(str(torch._inductor.metrics.generated_kernel_count)) else: self.assertEqual( torch._inductor.metrics.generated_kernel_count, kernel_count ) +======= + self.assertEqual( + torch._inductor.metrics.generated_kernel_count, kernel_count + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) finally: stack.close() @@ -584,9 +671,12 @@ class CompiledOptimizerParityTests(TestCase): @optims(optim_db, dtypes=[torch.float32]) @parametrize("use_closure", [True, False]) def test_correctness(self, device, dtype, optim_info, use_closure): +<<<<<<< HEAD torch.cuda.manual_seed_all(0) torch.manual_seed(0) random.seed(0) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) optim_cls = optim_info.optim_cls all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs( device, dtype, optim_info, skip=("differentiable",) @@ -608,10 +698,14 @@ def test_correctness(self, device, dtype, optim_info, use_closure): torch._inductor.metrics.reset() input = torch.ones([10, 10], device=device) model_eager = torch.nn.Sequential( +<<<<<<< HEAD *[ torch.nn.Linear(10, 10, device=device, bias=False) for _ in range(2) ] +======= + *[torch.nn.Linear(10, 10, device=device) for _ in range(2)] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) model_eager(input).sum().backward() model_compiled = deepcopy(model_eager) @@ -731,7 +825,10 @@ def check_cudagraphs_ran(self): SGD, kernel_count=1, lr=0.01, foreach=True ) +<<<<<<< HEAD @skipIfWindows +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @requires_gpu def test_static_address_finalizer(self): import gc @@ -924,7 +1021,11 @@ def fn(xs, ys): self.assertLess(end - start, 90) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_S429861(self): # Just verify we can compile this function without error try: @@ -943,7 +1044,11 @@ def test_S429861(self): kwargs = aot_graph_input_parser(forward) torch.compile(forward)(**kwargs) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_foreach_map_adam(self): params = [ torch.rand( diff --git a/test/inductor/test_control_flow.py b/test/inductor/test_control_flow.py index 715176a5ee51f..744af77933b0b 100644 --- a/test/inductor/test_control_flow.py +++ b/test/inductor/test_control_flow.py @@ -5,7 +5,10 @@ import torch import torch._dynamo.testing +<<<<<<< HEAD import torch.utils._pytree as pytree +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._higher_order_ops.associative_scan import associative_scan from torch._higher_order_ops.map import _fake_map from torch._higher_order_ops.scan import _fake_scan, scan @@ -38,6 +41,7 @@ def prepend_counters(inputs, num_counters=1, counter_values=(0, 1, 5)): return _prepend_product_of_values(inputs, counter_values, num_counters) +<<<<<<< HEAD # a testing loss_fn def loss_fn(result) -> torch.Tensor: flat_results, _ = pytree.tree_flatten(result) @@ -56,6 +60,8 @@ def loss_fn(result) -> torch.Tensor: return total_loss +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class CondModels: class Simple(torch.nn.Module): def forward(self, p, a, b): @@ -232,6 +238,7 @@ def false_fn(x): return y.sum() - torch.cond(x.sum() > 0, true_fn, false_fn, (x,)) +<<<<<<< HEAD class FunctionalCall(torch.nn.Module): def __init__(self): super().__init__() @@ -276,6 +283,8 @@ def fn(): return torch.cond(x0.sum() > 0, fn, fn) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class CondTests(TestCase): def _run_test( @@ -294,6 +303,7 @@ def _run_test( if dynamic: larger_inputs = [] for inp in inputs: +<<<<<<< HEAD # only tile non-scalar tensor inputs if inp.ndim > 0: # tile every first dim 5x @@ -301,6 +311,11 @@ def _run_test( larger_inputs.append(torch.tile(inp, tiling)) else: larger_inputs.append(inp) +======= + # tile every first dim 5x + tiling = [5] + [1] * (inp.ndim - 1) + larger_inputs.append(torch.tile(inp, tiling)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) input_sets.append(larger_inputs) for inputs in input_sets: for inp in inputs: @@ -505,9 +520,12 @@ def false_fn(x): @requires_gpu @parametrize("device", ["cpu", GPU_TYPE]) @torch._inductor.config.patch(size_asserts=False) +<<<<<<< HEAD # TODO: graph partition does not support creating tensor # with dynamic shape in conditional subgraph yet @torch._inductor.config.patch(graph_partition=False) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_cond_unbacked_symint_inner(self, device): class Model(torch.nn.Module): def forward(self, p, a): @@ -745,6 +763,7 @@ def test_cond_mismatched_branch_output_size(self, device, dynamic): dynamic=dynamic, ) +<<<<<<< HEAD @requires_gpu @parametrize("device", ["cpu", GPU_TYPE]) @parametrize("dynamic", [True, False]) @@ -768,6 +787,8 @@ def test_cond_select_with_input_idx(self, device, dynamic): dynamic=dynamic, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class WhileLoopModels: class Simple(torch.nn.Module): @@ -804,12 +825,17 @@ class Parameters(torch.nn.Module): class InnerModel(torch.nn.Module): def __init__(self, device): super().__init__() +<<<<<<< HEAD self.layer1 = torch.nn.Linear( 20, 30, device=device, dtype=torch.float64 ) self.layer2 = torch.nn.Linear( 30, 20, device=device, dtype=torch.float64 ) +======= + self.layer1 = torch.nn.Linear(20, 30, device=device) + self.layer2 = torch.nn.Linear(30, 20, device=device) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def forward(self, c, x): return c - 1, self.layer2(self.layer1(x - 2)) * 3.14 @@ -989,6 +1015,7 @@ def body_fn(c, a_view): ) return out1 + 1, out2 + 2 +<<<<<<< HEAD class ZeroLoop4(torch.nn.Module): def forward(self, c, a): a_view = torch.sin(a.view(-1, 1)) @@ -1006,6 +1033,8 @@ def body_fn(c, a_view): ) return out2.sin_(), a_view.cos_() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class UnbackedSymIntClosure(torch.nn.Module): def forward(self, c, a, b): d = a.sum().to(torch.int64).item() @@ -1029,7 +1058,11 @@ def forward(self, c, a, b): e = torch.nonzero(b).size(0) def cond_fn(c, a, b): +<<<<<<< HEAD return c + d + e + a.shape[0] - b.shape[0] < 10 +======= + return d + e + a.shape[0] - b.shape[0] < 10 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def body_fn(c, a, b): return c + 1, a + e, b + d @@ -1092,6 +1125,7 @@ def body_fn(loop_idx, x): (c, x), ) +<<<<<<< HEAD class WhileLoopStackOutputSimple(torch.nn.Module): def __init__(self, device): super().__init__() @@ -1113,10 +1147,22 @@ def body_fn(c, x): class WhileLoopTests(TestCase): def _run_test( self, model, inputs, device, dynamic=False, num_counters=1, autograd=False +======= + +class WhileLoopTests(TestCase): + def _run_test( + self, + model, + inputs, + device, + dynamic=False, + num_counters=1, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): import torch.utils._pytree as pytree cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor") +<<<<<<< HEAD import copy if not autograd: @@ -1134,10 +1180,28 @@ def mark_first_dim_dyn(inp): if dynamic: +======= + compiled_model = torch.compile(backend=cnt, fullgraph=True)(model) + + inputs = pytree.tree_map(lambda t: t.to(device=device), inputs) + input_sets = [inputs] + if dynamic: + + def mark_first_dim_dyn(inp): + torch._dynamo.mark_dynamic(inp, 0) + + pytree.tree_map(mark_first_dim_dyn, input_sets) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def tile_fn(inp): # tile every first dim 5x tiling = [5] + [1] * (inp.ndim - 1) t = torch.tile(inp, tiling) +<<<<<<< HEAD +======= + # mark every first dim as dynamic + torch._dynamo.mark_dynamic(inp, 0) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return t larger_inputs = pytree.tree_map(tile_fn, inputs) @@ -1154,6 +1218,7 @@ def tile_fn(inp): ) unflat_inputs = pytree.tree_unflatten(flat, inp_spec) inputs_with_counters = counters + unflat_inputs +<<<<<<< HEAD def process_inputs(inp): inp = inp.clone() @@ -1169,12 +1234,21 @@ def process_inputs(inp): result = model(*cloned_inputs) result_compiled = compiled_fn(*cloned_inputs2) +======= + cloned_inputs = pytree.tree_map( + lambda t: t.clone(), inputs_with_counters + ) + result = model(*inputs_with_counters) + with torch.no_grad(): + result_compiled = compiled_model(*inputs_with_counters) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # inputs must not be mutated torch.testing.assert_close(cloned_inputs, inputs_with_counters) torch.testing.assert_close( result, result_compiled, atol=1e-4, rtol=1e-4 ) +<<<<<<< HEAD if autograd and any( pytree.tree_map_only( torch.Tensor, lambda t: t.requires_grad, cloned_inputs @@ -1218,14 +1292,20 @@ def process_inputs(inp): rtol=1e-4, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(cnt.frame_count, 1, "only one compilation expected") @requires_gpu @parametrize("device", ["cpu", GPU_TYPE]) @parametrize("dynamic", [False, True]) +<<<<<<< HEAD @parametrize("autograd", [False, True]) @torch._dynamo.config.patch("capture_scalar_outputs", True) def test_while_loop_simple_control_flow(self, device, dynamic, autograd): +======= + def test_while_loop_simple_control_flow(self, device, dynamic): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # while_loop control flow without nesting self._run_test( model=WhileLoopModels.Simple(), @@ -1235,15 +1315,22 @@ def test_while_loop_simple_control_flow(self, device, dynamic, autograd): ), device=device, dynamic=dynamic, +<<<<<<< HEAD autograd=autograd, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @requires_gpu @parametrize("device", ["cpu", GPU_TYPE]) @parametrize("dynamic", [False, True]) +<<<<<<< HEAD @parametrize("autograd", [False, True]) @torch._dynamo.config.patch("capture_scalar_outputs", True) def test_while_loop_nested_control_flow(self, device, dynamic, autograd): +======= + def test_while_loop_nested_control_flow(self, device, dynamic): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # while_loop control flow with nesting self._run_test( model=WhileLoopModels.Nested(), @@ -1254,15 +1341,22 @@ def test_while_loop_nested_control_flow(self, device, dynamic, autograd): device=device, dynamic=dynamic, num_counters=2, +<<<<<<< HEAD autograd=autograd, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @requires_gpu @parametrize("device", ["cpu", GPU_TYPE]) @parametrize("dynamic", [False, True]) +<<<<<<< HEAD @parametrize("autograd", [False, True]) @torch._dynamo.config.patch("capture_scalar_outputs", True) def test_while_loop_with_outer_code(self, device, dynamic, autograd): +======= + def test_while_loop_with_outer_code(self, device, dynamic): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # while_loop control flow with outer code self._run_test( model=WhileLoopModels.OuterCode(), @@ -1272,12 +1366,16 @@ def test_while_loop_with_outer_code(self, device, dynamic, autograd): ), device=device, dynamic=dynamic, +<<<<<<< HEAD autograd=autograd, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @requires_gpu @parametrize("device", ["cpu", GPU_TYPE]) @parametrize("dynamic", [False, True]) +<<<<<<< HEAD @parametrize("autograd", [False, True]) @torch._dynamo.config.patch("capture_scalar_outputs", True) def test_while_loop_with_parameters(self, device, dynamic, autograd): @@ -1288,6 +1386,15 @@ def test_while_loop_with_parameters(self, device, dynamic, autograd): device=device, dynamic=dynamic, autograd=autograd, +======= + def test_while_loop_with_parameters(self, device, dynamic): + # while_loop control flow with parameters + self._run_test( + model=WhileLoopModels.Parameters(device), + inputs=(torch.randn(10, 20),), + device=device, + dynamic=dynamic, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @requires_gpu @@ -1295,9 +1402,13 @@ def test_while_loop_with_parameters(self, device, dynamic, autograd): # dynamic=True doesn't work now due to # https://github.com/pytorch/pytorch/issues/123596 @parametrize("dynamic", [False]) +<<<<<<< HEAD @parametrize("autograd", [False, True]) @torch._dynamo.config.patch("capture_scalar_outputs", True) def test_while_loop_with_outer_buffers(self, device, dynamic, autograd): +======= + def test_while_loop_with_outer_buffers(self, device, dynamic): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # while_loop control flow with outer code self._run_test( model=WhileLoopModels.OuterBuffers(), @@ -1307,15 +1418,24 @@ def test_while_loop_with_outer_buffers(self, device, dynamic, autograd): ), device=device, dynamic=dynamic, +<<<<<<< HEAD autograd=autograd, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @requires_gpu @parametrize("device", ["cpu", GPU_TYPE]) +<<<<<<< HEAD @parametrize("dynamic", [True, False]) @parametrize("autograd", [False, True]) @torch._dynamo.config.patch("capture_scalar_outputs", True) def test_while_loop_with_pytree_inputs(self, device, dynamic, autograd): +======= + # dynamic=True doesn't work due to we haven't handle lifted symbols + @parametrize("dynamic", [True, False]) + def test_while_loop_with_pytree_inputs(self, device, dynamic): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._run_test( model=WhileLoopModels.PytreeCarry(), inputs=( @@ -1326,15 +1446,22 @@ def test_while_loop_with_pytree_inputs(self, device, dynamic, autograd): ), device=device, dynamic=dynamic, +<<<<<<< HEAD autograd=autograd, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @requires_gpu @parametrize("device", ["cpu", GPU_TYPE]) @parametrize("dynamic", [True, False]) +<<<<<<< HEAD @parametrize("autograd", [False, True]) @torch._dynamo.config.patch("capture_scalar_outputs", True) def test_while_loop_with_data_dependent_ops(self, device, dynamic, autograd): +======= + def test_while_loop_with_data_dependent_ops(self, device, dynamic): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with torch._dynamo.config.patch( { "capture_dynamic_output_shape_ops": True, @@ -1350,15 +1477,22 @@ def test_while_loop_with_data_dependent_ops(self, device, dynamic, autograd): ), device=device, dynamic=dynamic, +<<<<<<< HEAD autograd=autograd, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @requires_gpu @parametrize("device", ["cpu", GPU_TYPE]) @parametrize("dynamic", [True, False]) +<<<<<<< HEAD @parametrize("autograd", [False, True]) @torch._dynamo.config.patch("capture_scalar_outputs", True) def test_while_loop_with_data_dependent_in_out(self, device, dynamic, autograd): +======= + def test_while_loop_with_data_dependent_in_out(self, device, dynamic): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with torch._dynamo.config.patch( { "capture_dynamic_output_shape_ops": True, @@ -1375,7 +1509,10 @@ def test_while_loop_with_data_dependent_in_out(self, device, dynamic, autograd): ), device=device, dynamic=dynamic, +<<<<<<< HEAD autograd=autograd, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @parametrize("dynamic", [True, False]) @@ -1421,7 +1558,10 @@ def test_while_loop_zero_loop(self, device, dynamic): WhileLoopModels.ZeroLoop(), WhileLoopModels.ZeroLoop2(), WhileLoopModels.ZeroLoop3(), +<<<<<<< HEAD WhileLoopModels.ZeroLoop4(), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ]: self._run_test( model=model, @@ -1436,8 +1576,12 @@ def test_while_loop_zero_loop(self, device, dynamic): @torch._dynamo.config.patch( {"capture_scalar_outputs": True, "capture_dynamic_output_shape_ops": True} ) +<<<<<<< HEAD @parametrize("autograd", [False, True]) def test_while_loop_with_unbacked_symint_closure(self, device, dynamic, autograd): +======= + def test_while_loop_with_unbacked_symint_closure(self, device, dynamic): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._run_test( model=WhileLoopModels.UnbackedSymIntClosure(), inputs=( @@ -1446,7 +1590,10 @@ def test_while_loop_with_unbacked_symint_closure(self, device, dynamic, autograd ), device=device, dynamic=dynamic, +<<<<<<< HEAD autograd=autograd, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @requires_gpu @@ -1481,11 +1628,18 @@ def test_while_loop_models_with_mixed_device(self, device): @requires_gpu @parametrize("device", ["cpu", GPU_TYPE]) @parametrize("dynamic", [True, False]) +<<<<<<< HEAD @parametrize("autograd", [False, True]) @torch._dynamo.config.patch( {"capture_scalar_outputs": True, "capture_dynamic_output_shape_ops": True} ) def test_while_loop_with_sym_expr_cond(self, device, dynamic, autograd): +======= + @torch._dynamo.config.patch( + {"capture_scalar_outputs": True, "capture_dynamic_output_shape_ops": True} + ) + def test_while_loop_with_sym_expr_cond(self, device, dynamic): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._run_test( model=WhileLoopModels.SymExprCond(), inputs=( @@ -1494,20 +1648,28 @@ def test_while_loop_with_sym_expr_cond(self, device, dynamic, autograd): ), device=device, dynamic=dynamic, +<<<<<<< HEAD autograd=autograd, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @requires_gpu @parametrize("device", ["cpu", GPU_TYPE]) @parametrize("dynamic", [True, False]) +<<<<<<< HEAD @parametrize("autograd", [False, True]) @torch._dynamo.config.patch("capture_scalar_outputs", True) def test_while_loop_with_conv(self, device, dynamic, autograd): +======= + def test_while_loop_with_conv(self, device, dynamic): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._run_test( model=WhileLoopModels.Conv(device), inputs=(torch.randn(2, 4, 4, 4, dtype=torch.float64),), device=device, dynamic=dynamic, +<<<<<<< HEAD autograd=autograd, ) @@ -1521,6 +1683,8 @@ def test_while_loop_stack_output_simple(self, device, dynamic): inputs=(torch.randn(3, 3, dtype=torch.float32),), device=device, dynamic=dynamic, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @@ -1529,7 +1693,11 @@ class AssociativeScanTests(TestCase): @parametrize("combine_mode", ["pointwise", "generic"]) @parametrize("backend", ["inductor"]) @parametrize("device", [torch.device("cpu"), GPU_TYPE]) +<<<<<<< HEAD # This test will fail as flip in combination with particular input lengths +======= + # This test will fail as flip in combination with particular input lenghts +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # produces weird results. # This is under investigations in # https://github.com/pytorch/pytorch/issues/131805 @@ -1553,7 +1721,11 @@ def fct(x: torch.Tensor, y: torch.Tensor): fct, x, 0, reverse=False, combine_mode=combine_mode ) +<<<<<<< HEAD # Skipping test because combine_mode currently only supports CUDA tensors +======= + # Skipping test because combine_mode currently only suppors CUDA tensors +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return result1 = associative_scan1( @@ -1743,6 +1915,11 @@ def __init__(self, reverse, dim): def forward(self, scan_op, _input, weight, bias): def combine_fn(carry, x): +<<<<<<< HEAD +======= + from torch.utils import _pytree as pytree + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) new_carry = { "param": carry["param"] @ x + carry["bias"], "bias": carry["bias"].sin(), @@ -2152,6 +2329,7 @@ def _run_test( inputs, device, dynamic=False, +<<<<<<< HEAD autograd=False, ): import copy @@ -2171,11 +2349,25 @@ def _run_test( cloned_inputs = [inp.clone() for inp in inputs] result = model(torch._higher_order_ops.map, *cloned_inputs) result_exp = model_eager(_fake_map, *cloned_inputs) +======= + ): + cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor") + compiled_model = torch.compile(backend=cnt, fullgraph=True, dynamic=dynamic)( + model + ) + + inputs = [inp.to(device=device) for inp in inputs] + model = model.to(device=device) + cloned_inputs = [inp.clone() for inp in inputs] + result = model(torch._higher_order_ops.map, *cloned_inputs) + result_exp = model(_fake_map, *cloned_inputs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) result_compiled = compiled_model(torch._higher_order_ops.map, *cloned_inputs) self.assertEqual(result, result_exp) self.assertEqual(result, result_compiled) +<<<<<<< HEAD if autograd: loss_fn(result).backward() loss_fn(result_exp).backward() @@ -2196,34 +2388,57 @@ def _run_test( @parametrize("autograd", [True, False]) @torch._dynamo.config.patch("capture_scalar_outputs", True) def test_map_simple(self, device, dynamic, autograd): +======= + @requires_gpu + @parametrize("device", ["cpu", GPU_TYPE]) + @parametrize("dynamic", [True, False]) + @torch._dynamo.config.patch("capture_scalar_outputs", True) + def test_map_simple(self, device, dynamic): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._run_test( model=MapModels.Simple(), inputs=(torch.randn(3, 4),), device=device, dynamic=dynamic, +<<<<<<< HEAD autograd=autograd, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @requires_gpu @parametrize("device", ["cpu", GPU_TYPE]) @parametrize("dynamic", [True, False]) +<<<<<<< HEAD @parametrize("autograd", [True, False]) @torch._dynamo.config.patch("capture_scalar_outputs", True) def test_map_simple_linear_with_view(self, device, dynamic, autograd): +======= + @torch._dynamo.config.patch("capture_scalar_outputs", True) + def test_map_simple_linear_with_view(self, device, dynamic): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._run_test( model=MapModels.SimpleWithLinearWithView(), inputs=(torch.randn(3, 4),), device=device, dynamic=dynamic, +<<<<<<< HEAD autograd=autograd, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @requires_gpu @parametrize("device", ["cpu", GPU_TYPE]) @parametrize("dynamic", [True, False]) +<<<<<<< HEAD @parametrize("autograd", [True, False]) @torch._dynamo.config.patch("capture_scalar_outputs", True) def test_map_pytree_in_out(self, device, dynamic, autograd): +======= + @torch._dynamo.config.patch("capture_scalar_outputs", True) + def test_map_pytree_in_out(self, device, dynamic): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._run_test( model=MapModels.PytreeInOut(), inputs=( @@ -2233,15 +2448,23 @@ def test_map_pytree_in_out(self, device, dynamic, autograd): ), device=device, dynamic=dynamic, +<<<<<<< HEAD autograd=autograd, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @requires_gpu @parametrize("device", ["cpu", GPU_TYPE]) @parametrize("dynamic", [True, False]) +<<<<<<< HEAD @parametrize("autograd", [True, False]) @torch._dynamo.config.patch("capture_scalar_outputs", True) def test_map_nested_with_cond(self, device, dynamic, autograd): +======= + @torch._dynamo.config.patch("capture_scalar_outputs", True) + def test_map_nested_with_cond(self, device, dynamic): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._run_test( model=MapModels.NestedWithCond(), inputs=( @@ -2251,7 +2474,10 @@ def test_map_nested_with_cond(self, device, dynamic, autograd): ), device=device, dynamic=dynamic, +<<<<<<< HEAD autograd=autograd, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) diff --git a/test/inductor/test_cooperative_reductions.py b/test/inductor/test_cooperative_reductions.py index 0b8f60dc0d269..294fc91263021 100644 --- a/test/inductor/test_cooperative_reductions.py +++ b/test/inductor/test_cooperative_reductions.py @@ -18,7 +18,11 @@ instantiate_parametrized_tests, parametrize, ) +<<<<<<< HEAD from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON +======= +from torch.testing._internal.inductor_utils import HAS_CUDA +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestingHeuristics(InductorChoices): @@ -381,5 +385,9 @@ def fn(x, y): if __name__ == "__main__": from torch._dynamo.test_case import run_tests +<<<<<<< HEAD if HAS_CUDA_AND_TRITON: +======= + if HAS_CUDA: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) run_tests(needs="filelock") diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py index 47a8f3aa063e3..d5d95f06d2dc6 100644 --- a/test/inductor/test_cpu_cpp_wrapper.py +++ b/test/inductor/test_cpu_cpp_wrapper.py @@ -110,6 +110,10 @@ def make_test_case( @config.patch( cpp_wrapper=True, +<<<<<<< HEAD +======= + search_autotune_cache=False, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cpp_wrapper_build_separate=test_build_separate, ) def fn(self): @@ -268,7 +272,11 @@ class BaseTest(NamedTuple): "test_multi_threading", condition=not IS_WINDOWS, # Two threads compile, so we expect the output code to be printed twice. +<<<<<<< HEAD code_string_count={"py::gil_scoped_release_simple release;": 2}, +======= + code_string_count={"py::gil_scoped_release release;": 2}, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), BaseTest("test_profiler_mark_wrapper_call"), BaseTest( diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index 10e7c3068f10a..172d172753f3b 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -152,7 +152,11 @@ class RecordFunctions(TorchDispatchMode): def __torch_dispatch__(self, func, types, args=(), kwargs=None): kwargs = kwargs if kwargs else {} if func == torch.ops.aten.convolution.default: +<<<<<<< HEAD # For CPU and mkldnn enable, we always using channels last +======= + # For CPU and mkldnn enable, we always using channles last +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) nonlocal fmt if ( torch.backends.mkldnn.enabled @@ -996,7 +1000,11 @@ def fn(x): v = torch.randn(10) # TODO: OMP parallel reduction order is not deterministic. +<<<<<<< HEAD # Hence, the accuracy might vary up and down. For short term, +======= + # Hence, the accurarcy might vary up and down. For short term, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # we increase the tolerance and will fix it later by using # aten parallel. self.common(fn, (v,), atol=5e-1, rtol=5e-1) @@ -1004,7 +1012,11 @@ def fn(x): def test_parallel_reduction_vectorization(self): # Fix issue: https://github.com/pytorch/pytorch/issues/151523 class Model(torch.nn.Module): +<<<<<<< HEAD def __init__(self, enable_masked_tail_vec): +======= + def __init__(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__() self.conv = torch.nn.Conv2d( in_channels=3, @@ -1013,16 +1025,24 @@ def __init__(self, enable_masked_tail_vec): stride=(2, 1), padding=0, ) +<<<<<<< HEAD self.enable_masked_tail_vec = enable_masked_tail_vec def forward(self, x, weight): x = self.conv(x) if not self.enable_masked_tail_vec: x = F.hardshrink(x, lambd=0) +======= + + def forward(self, x, weight): + x = self.conv(x) + x = F.hardshrink(x, lambd=0) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x = x.view(x.size(0), -1) x = torch.mv(weight, x[0]) return x +<<<<<<< HEAD for enable_masked_tail_vec in [True, False]: mod = Model(enable_masked_tail_vec).eval() x = torch.randn(2, 3, 127, 255) @@ -1030,6 +1050,14 @@ def forward(self, x, weight): # Use same criterion as test_inplace_squeeze_needed # for parallel reduction. self.common(mod, (x, weight), atol=5e-1, rtol=5e-1) +======= + mod = Model().eval() + x = torch.randn(2, 3, 127, 255) + weight = torch.randn(10, 254976) + # Use same criterion as test_inplace_squeeze_needed + # for parallel reduction. + self.common(mod, (x, weight), atol=5e-1, rtol=5e-1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_cat_mul(self): # https://github.com/pytorch/pytorch/issues/93365 @@ -2644,6 +2672,7 @@ def fn(a, dim, index, b): self.common(fn, inps) assert metrics.generated_cpp_vec_kernel_count == 2 +<<<<<<< HEAD def test_large_mean(self): size = (30000, 100000) t = torch.rand(size, dtype=torch.float) @@ -2656,6 +2685,8 @@ def test_large_mean(self): actual = torch.compile(op)(t) self.assertEqual(expected, actual) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode") @requires_vectorization @patch("torch.cuda.is_available", lambda: False) @@ -3117,6 +3148,7 @@ def get_traj_idx(lengths: torch.Tensor, num_slices: int) -> torch.Tensor: lengths = torch.zeros(11, dtype=torch.long) get_traj_idx(lengths, num_slices=4) +<<<<<<< HEAD def test_store_reduction(self): # fix https://github.com/pytorch/pytorch/issues/157683 def fn(x, y): @@ -3141,6 +3173,8 @@ def fn(x, y): ), ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @requires_vectorization @patch("torch.cuda.is_available", lambda: False) def test_sign_cpu_only(self): @@ -4101,6 +4135,7 @@ def fn(x1, x2): ) self.assertEqual(metrics.generated_kernel_count, 1) +<<<<<<< HEAD def test_relu_permute_reshape_reinterpret_view(self): def fn(x): n, c, h, w = x.shape @@ -4119,6 +4154,8 @@ def fn(x): # check that there is no transpose FileCheck().check_count("transpose_mxn", 0, exactly=True).run(code) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_attention_size_mismatch(self): class Attention(torch.nn.Module): def __init__(self, hidden_size, num_heads): @@ -4384,6 +4421,7 @@ def fn(x, y): y = torch.randint(0, 255, (3, 3), dtype=torch.uint8) self.common(fn, (x, y)) +<<<<<<< HEAD def test_float32_to_uint8(self): # https://github.com/pytorch/pytorch/issues/156788 @torch.compile @@ -4397,6 +4435,8 @@ def fn(x): msg=f"Expected {x.to(torch.uint8)} but got {fn(x)}", ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_non_contiguous_reduction_store(self): # https://github.com/pytorch/pytorch/issues/113018 class M(torch.nn.Module): @@ -5434,6 +5474,7 @@ def test_vector_norm_compile(self): res = compiled_vector_norm(x, ord=2, dim=[], keepdim=False, dtype=None) self.assertEqual(ref, res) +<<<<<<< HEAD def test_fractional_max_pool2d_3d_input(self): """Test for https://github.com/pytorch/pytorch/issues/156682 - 3D input causing assertion error""" @@ -5490,6 +5531,8 @@ def fn( result = compiled_func(xs, Ls) torch.testing.assert_close(result, expected) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py index fe1e59bd7f49a..8977064eb058e 100644 --- a/test/inductor/test_cpu_select_algorithm.py +++ b/test/inductor/test_cpu_select_algorithm.py @@ -26,7 +26,10 @@ ) from torch.testing._internal.common_utils import ( IS_MACOS, +<<<<<<< HEAD IS_WINDOWS, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) parametrize, skipIfWindows, TEST_MKL, @@ -52,7 +55,11 @@ def patches(fn): +<<<<<<< HEAD def skip_cache(self, choices, name, key, benchmark, hint_override=None): +======= + def skip_cache(self, choices, name, key, benchmark): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if benchmark is None: return {} timings = benchmark(choices) @@ -297,10 +304,13 @@ def forward(self, x): dtype == torch.float16 and torch.ops.mkldnn._is_mkldnn_fp16_supported() ) +<<<<<<< HEAD or ( dtype == torch.float32 and not dynamo_config.assume_static_by_default ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) and epilogue != "mul" and epilogue != "div" @@ -309,15 +319,33 @@ def forward(self, x): and epilogue == "add" and not bias ) +<<<<<<< HEAD +======= + or ( + dtype == torch.float32 + and epilogue == "add" + and not bias + and not dynamo_config.assume_static_by_default + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): # Several scenarios where epilogue fusion is not counted in: # 1. For bfloat16, the epilogue fusion is part of the template, # not fused via scheduler. This will also be true for float16 when +<<<<<<< HEAD # hardware has the float16 instruction. And this will also be true # for float32 dynamic mode. The exception is mul or div fusion # which is not supported for oneDNN linear. # 2. For bfloat16/float16, when oneDNN linear is not applied, linear w/o bias # plus epilogue add is treated as linear w/ bias. +======= + # hardware has the float16 instruction. The exception is mul or + # div fusion which is not supported for oneDNN linear. + # 2. For bfloat16/float16, when oneDNN linear is not applied, linear w/o bias + # plus epilogue add is treated as linear w/ bias. + # 3. For float32, when dynamic shapes is enabled, mkl linear is not applied. + # and linear w/o bias plus epilogue add is treated as addmm. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 0) else: self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1) @@ -798,7 +826,11 @@ def forward(self, arg7_1): with verify(dtype) as (atol, rtol): self.common(mod, (v,), atol=atol, rtol=rtol) self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 3) +<<<<<<< HEAD self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 0) +======= + self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 2) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf( not torch._C._cpu._is_amx_tile_supported(), "AMX ISA support is required" @@ -828,7 +860,11 @@ def forward(self, x): self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) vec_amx = VecAMX() # Currently brgemm config is only added for half +<<<<<<< HEAD if dtype == torch.half and not vec_amx.is_amx_fp16_supported(): +======= + if dtype == torch.half: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._check_brgemm_counter(vec_amx) else: self._check_amx_counter(vec_amx) @@ -1956,7 +1992,11 @@ def test_quantized_linear_with_pointwise_binary( input = torch.randn(*B, in_features).to(dtype=torch.float32) other = torch.randn(*B, out_features).to(dtype=dtype) +<<<<<<< HEAD # Avoid hitting qlinear inplace sum fusion +======= + # Avoid hiting qlinear inplace sum fusion +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if input_3d: other2 = torch.randn(B[0] * B[1], out_features).to(dtype=dtype) else: @@ -1973,7 +2013,11 @@ def __init__(self, bias, input_3d): def forward(self, x, other, other2): res = self.epilogue(self.linear(x) + other) +<<<<<<< HEAD # Avoid hitting qlinear inplace sum fusion +======= + # Avoid hiting qlinear inplace sum fusion +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.input_3d: other2 = other2.view(2, other2.size(0) // 2, other2.size(1)) else: @@ -2213,7 +2257,11 @@ def __init__(self, in_feature, out_feature, gemm_num): def forward(self, x): return [linear(x) for linear in self.linears] +<<<<<<< HEAD # each linear has different num of out features, thus invalid grouped gemm +======= + # each linear has different num of out features, thus invaild grouped gemm +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dtypes = [] if torch.ops.mkldnn._is_mkldnn_bf16_supported(): dtypes.append(torch.bfloat16) @@ -2680,7 +2728,11 @@ def forward(self, x): @torch.no_grad @unittest.skipIf(not TEST_MKL, "Test requires MKL") @parametrize("bs", (5,)) +<<<<<<< HEAD @parametrize("Mdim", (3, 64)) # Test small Mdim which uses reshaped weights +======= + @parametrize("Mdim", (64,)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dtypes(torch.float) def test_bmm_self_square(self, bs, Mdim, dtype): class M(torch.nn.Module): @@ -2769,6 +2821,7 @@ def forward(self, x, w): @patches @torch.no_grad +<<<<<<< HEAD @parametrize("bs", (1, 50)) @parametrize("Mdim", (192,)) @parametrize("Kdim", (196,)) @@ -2796,6 +2849,8 @@ def forward(self, x, y): @patches @torch.no_grad +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dtypes(torch.float) def test_aoti_bmm_unique_identifiers(self, dtype): try: @@ -3119,5 +3174,9 @@ def forward(self, x, weight): if __name__ == "__main__": from torch.testing._internal.inductor_utils import HAS_CPU +<<<<<<< HEAD if HAS_CPU and not (IS_MACOS or IS_WINDOWS): +======= + if HAS_CPU and not IS_MACOS: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) run_tests() diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index 5cfb622855725..8ae63c2f68712 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -26,7 +26,10 @@ run_fw_bw_and_get_code, ) from torch.fx.experimental.proxy_tensor import make_fx +<<<<<<< HEAD from torch.nn.attention import sdpa_kernel, SDPBackend +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing import FileCheck from torch.testing._internal.common_cuda import ( PLATFORM_SUPPORTS_FLASH_ATTENTION, @@ -178,10 +181,16 @@ def test_effn_attn_bias_padding_misaligned(self): inputs = [q, k, v, mask] def f(q, k, v, mask): +<<<<<<< HEAD with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION): return F.scaled_dot_product_attention( q, k, v, attn_mask=mask, dropout_p=0.0 ) +======= + return F.scaled_dot_product_attention( + q, k, v, attn_mask=mask, dropout_p=0.0 + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f_compiled = torch.compile(f) @@ -189,9 +198,15 @@ def f(q, k, v, mask): # padded bias should have an expanded dim FileCheck().check("buf0 =").check_same(", 0, ").run(code[0]) # single fused padded kernel +<<<<<<< HEAD FileCheck().check_count("empty_strided_cuda(", 1, exactly=True).check( "return" ).run(code[0]) +======= + FileCheck().check("def call").check_count( + "empty_strided_cuda", 1, exactly=True + ).check("return").run(code[0]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(out, f(*inputs)) @@ -935,7 +950,11 @@ def foo(x): inp = inp.to(torch.float) out, code = run_and_get_code(torch.compile(foo), inp) +<<<<<<< HEAD FileCheck().check_not("tl_math.exp").check("libdevice.exp").run(code[0]) +======= + FileCheck().check_not("libdevice.exp").check("tl_math.exp").run(code[0]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(foo(inp), out) def foo(x): @@ -1845,7 +1864,10 @@ def fn(x): self.assertEqual(graph.disable_cudagraphs_reason, None) self.assertEqual(graph.device_types, {"cuda"}) +<<<<<<< HEAD @unittest.skipIf(IS_FBCODE, "Not runnable in fbcode") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_triton_interpret(self): import subprocess @@ -1858,7 +1880,11 @@ def test_triton_interpret(self): def foo(x): return x + 1 +<<<<<<< HEAD # somehow gives different results.. still, check that it doesn't error +======= +# somehow gives different results.. still, check that it doesnt error +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) foo(torch.rand([256], device="cuda")) """ subprocess.run([sys.executable, "-c", script], check=True) @@ -2099,7 +2125,10 @@ def get_input() -> torch.Tensor: self.assertIn("znumel", code) @xfailIfPy312Plus # https://github.com/pytorch/pytorch/issues/142032 +<<<<<<< HEAD @unittest.skipIf(config.is_fbcode(), "Dependence on functorch.einops") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_repeated_masked_load(self): target_size = (8, 2) mem_eff_temporal_upsampling_interp_chunks = 2 @@ -2169,6 +2198,7 @@ def forward(self, x): self.assertEqual(default_output, max_autotune_output) +<<<<<<< HEAD def test_adaptive_avg_pool3d_issue_157248(self): """Test for GitHub issue #157248: Conv2d-unsqueeze-AdaptiveAvgPool3d produces incorrect results""" @@ -2221,4 +2251,12 @@ def forward(self, x): from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON if HAS_CUDA_AND_TRITON and not TEST_WITH_ASAN: +======= + +if __name__ == "__main__": + from torch._inductor.test_case import run_tests + from torch.testing._internal.inductor_utils import HAS_CUDA + + if HAS_CUDA and not TEST_WITH_ASAN: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) run_tests(needs="filelock") diff --git a/test/inductor/test_cudacodecache.py b/test/inductor/test_cudacodecache.py index b6786130416bd..de676d36fa038 100644 --- a/test/inductor/test_cudacodecache.py +++ b/test/inductor/test_cudacodecache.py @@ -1,15 +1,25 @@ # Owner(s): ["module: inductor"] import ctypes +<<<<<<< HEAD import torch +======= +import unittest + +import torch +from torch._inductor import config +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.async_compile import AsyncCompile from torch._inductor.codecache import CUDACodeCache from torch._inductor.codegen.cuda.cuda_env import nvcc_exist from torch._inductor.exc import CUDACompileError from torch._inductor.test_case import TestCase as InductorTestCase from torch._inductor.utils import fresh_cache +<<<<<<< HEAD from torch.testing._internal.triton_utils import requires_cuda_and_triton +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _SOURCE_CODE = r""" @@ -36,8 +46,13 @@ """ +<<<<<<< HEAD class TestCUDACodeCache(InductorTestCase): @requires_cuda_and_triton +======= +@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUDA_HOME setup") +class TestCUDACodeCache(InductorTestCase): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_cuda_load(self): with fresh_cache(): # Test both .o and .so compilation. @@ -49,8 +64,13 @@ def test_cuda_load(self): dll_wrapper, so_hash_key, source_code_path1 = CUDACodeCache.load( _SOURCE_CODE, "so" ) +<<<<<<< HEAD self.assertEqual(source_code_path0, source_code_path1) self.assertEqual(object_hash_key, so_hash_key) +======= + self.assertNotEqual(source_code_path0, source_code_path1) + self.assertNotEqual(object_hash_key, so_hash_key) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Test load and call functions in .so. x = torch.rand(10).float().cuda() @@ -65,14 +85,20 @@ def test_cuda_load(self): ) torch.testing.assert_close(y, expected_y) +<<<<<<< HEAD @requires_cuda_and_triton +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_compilation_error(self): with fresh_cache(): error_source_code = _SOURCE_CODE.replace("saxpy_device", "saxpy_wrong", 1) with self.assertRaises(CUDACompileError): CUDACodeCache.compile(error_source_code, "o") +<<<<<<< HEAD @requires_cuda_and_triton +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_async_compile(self): with fresh_cache(): async_compile = AsyncCompile() diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py index 1dfc0a390eca7..bce22760d7f58 100644 --- a/test/inductor/test_cudagraph_trees.py +++ b/test/inductor/test_cudagraph_trees.py @@ -5,7 +5,10 @@ import gc import importlib import itertools +<<<<<<< HEAD import re +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import sys import unittest import warnings @@ -41,7 +44,10 @@ skipIfRocm, TEST_CUDA_GRAPH, ) +<<<<<<< HEAD from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.utils._mode_utils import no_dispatch from torch.utils._python_dispatch import TorchDispatchMode @@ -57,8 +63,16 @@ importlib.import_module("functorch") importlib.import_module("filelock") +<<<<<<< HEAD aten = torch.ops.aten +======= +from torch.testing._internal.inductor_utils import HAS_CUDA + + +aten = torch.ops.aten +requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) requires_multigpu = functools.partial( unittest.skipIf, not TEST_MULTIGPU, "requires multiple cuda devices" ) @@ -123,7 +137,11 @@ def tearDown(self): torch._dynamo.reset() +<<<<<<< HEAD if HAS_CUDA_AND_TRITON: +======= +if HAS_CUDA: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_all_cudagraph_segments(): segments = torch.cuda.memory_snapshot() @@ -280,6 +298,7 @@ def foo(x, y): with capture_stderr() as captured_output: foo(torch.ones([10], device="cuda"), torch.ones([20])) +<<<<<<< HEAD if torch._inductor.config.graph_partition: # graph partition splits on cpu ops self.assertEqual(counters["inductor"]["cudagraph_skips"], 0) @@ -288,6 +307,12 @@ def foo(x, y): "skipping cudagraphs due to cpu device (arg1_1). Found from" ).check("y + 2").run(captured_output[0]) self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) +======= + FileCheck().check( + "skipping cudagraphs due to cpu device (arg1_1). Found from" + ).check("y + 2").run(captured_output[0]) + self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with capture_stderr() as captured_output: foo( @@ -297,10 +322,14 @@ def foo(x, y): FileCheck().check("skipping cudagraphs due to multiple devices").run( captured_output[0] ) +<<<<<<< HEAD self.assertEqual( counters["inductor"]["cudagraph_skips"], 1 if torch._inductor.config.graph_partition else 2, ) +======= + self.assertEqual(counters["inductor"]["cudagraph_skips"], 2) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._inductor.config.patch("triton.cudagraph_skip_dynamic_graphs", True) def test_skip_symbolic(self): @@ -338,7 +367,11 @@ def inp(): ).check(".add_(2)").run(captured_output[0]) self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) +<<<<<<< HEAD # mutation on inp doesn't hit cudagraphs +======= + # mutation on inp doesnt hit cudagraphs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(len(self.get_manager().roots), 0) # mutation on parameters/buffers hits cudagraphs @@ -570,8 +603,13 @@ def foo2(x): del out # when I tried inducing separate recordings via graph break, +<<<<<<< HEAD # the frame kept interfering by keeping outputs alive # this isn't great by simulates the logic. +======= + # the frame kept interferring by keeping outputs alive + # this isnt great by simulates the logic. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._dynamo.mutation_guard import GenerationTracker GenerationTracker.generation -= 1 @@ -581,7 +619,11 @@ def foo2(x): foo_opt(torch.ones([4, 4], device="cuda")) +<<<<<<< HEAD # Two separate traces - one has a child, one doesn't +======= + # Two separate traces - one has a child, one doesnt +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(self.get_root_children(), [1, 0]) def test_execution_into_recording(self): @@ -815,6 +857,7 @@ def foo(x): # the three saved tensors should die in the backward # we kept alive the output self.assertEqual(self.curr_node().expected_dead_indices_before_graph, []) +<<<<<<< HEAD if torch._inductor.config.graph_partition: self.assertEqual( self.curr_node().expected_dead_indices_after_graph, @@ -825,6 +868,12 @@ def foo(x): self.curr_node().expected_dead_indices_after_graph, [(0, 1), (0, 2)], ) +======= + self.assertEqual( + self.curr_node().expected_dead_indices_after_graph, + [(0, 1), (0, 2)], + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertFalse(self.get_manager().new_graph_id().id == 0) self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) @@ -914,6 +963,7 @@ def test_unaligned_static_input_no_cudagraphs(self): self._test_unaligned_static_input_impl(expected_clones=0) @torch._inductor.config.patch("graph_partition", True) +<<<<<<< HEAD @torch._inductor.config.patch("implicit_fallbacks", True) def test_graph_partition_custom_rule(self): def get_num_partitions(code): @@ -975,6 +1025,8 @@ def f(x, flag): self.assertEqual(num_partitions, 1) @torch._inductor.config.patch("graph_partition", True) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._inductor.config.patch("triton.cudagraph_trees", False) def test_graph_partition_gc(self): def _test_dummy(): @@ -1202,6 +1254,7 @@ def foo2(x): node = self.curr_node() first_node = next(node._path_from_root) +<<<<<<< HEAD if torch._inductor.config.graph_partition: # graph partition may changed the order of outputs self.assertFalse(first_node.unaliased_in_all_paths[1]) @@ -1209,6 +1262,10 @@ def foo2(x): else: self.assertFalse(first_node.unaliased_in_all_paths[0]) self.assertTrue(first_node.cached_tensor_outputs[0] is None) +======= + self.assertFalse(first_node.unaliased_in_all_paths[0]) + self.assertTrue(first_node.cached_tensor_outputs[0] is None) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._inductor.config.patch("implicit_fallbacks", True) def test_multinomial(self): @@ -1403,7 +1460,11 @@ def test_multiple_insert_removal_caching(self): torch._C._set_cached_tensors_enabled(False) def test_accumulate_grad(self): +<<<<<<< HEAD # cudagraph trees shouldn't interfere with accumulation logic +======= + # cudagraph trees shouldnt interfere with accumulation logic +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def compute_grad(grad_output, create_graph): x = torch.randn(5, 5, requires_grad=True, device="cuda") @@ -1444,7 +1505,11 @@ def foo(x): for _ in range(3): out = frozen(torch.rand([10, 10], device="cuda")) +<<<<<<< HEAD # didn't do additional recordings +======= + # didnt do additional recordings +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue(self.get_manager().new_graph_id().id == 2) def test_empty_cpu_tensor(self): @@ -1711,6 +1776,7 @@ def foo(x): # the three saved tensors should die in the backward # we kept alive the output self.assertEqual(self.curr_node().expected_dead_indices_before_graph, []) +<<<<<<< HEAD if torch._inductor.config.graph_partition: self.assertEqual( self.curr_node().expected_dead_indices_after_graph, @@ -1721,6 +1787,12 @@ def foo(x): self.curr_node().expected_dead_indices_after_graph, [(0, 1), (0, 2)], ) +======= + self.assertEqual( + self.curr_node().expected_dead_indices_after_graph, + [(0, 1), (0, 2)], + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertFalse(self.get_manager().new_graph_id().id == 0) def test_separate_recordings(self): @@ -2223,8 +2295,13 @@ def forward(self, x) -> torch.Tensor: with self.assertRaisesRegex( Exception, r"(?s)static input data pointer changed.\n" +<<<<<<< HEAD r"input name: primals_.*. data pointer changed from .* to .*. input stack trace:.*" r"input name: primals_.*. data pointer changed from .* to .*. input stack trace:.*," +======= + r"input name: primals_2. data pointer changed from .* to .*. input stack trace:.*" + r"input name: primals_3. data pointer changed from .* to .*. input stack trace:.*," +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r" in forward\n.* self.static_tensor.add\_\(torch.ones\(\(2, 2\), device=\"cuda\"\)\).*\n", ): self.curr_node().run( @@ -2795,6 +2872,7 @@ def f(x, y): self.assertEqual(self.get_manager().new_graph_id().id, 2) @torch._inductor.config.patch("graph_partition", True) +<<<<<<< HEAD def test_graph_partition_log_message(self): def foo(x, y): return (x + 1, y + 2) @@ -2811,6 +2889,8 @@ def foo(x, y): ).run(captured_output[0]) @torch._inductor.config.patch("graph_partition", True) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_graph_partition_cpu_scalar1(self): def f(x, y): return x + y @@ -2934,6 +3014,7 @@ def foo(x): self.assertEqual(x, torch.tensor(1, device="cpu")) @torch._inductor.config.patch("graph_partition", True) +<<<<<<< HEAD def test_graph_partition_cpu_scalar_multiple(self): def f(x, y, z): return x + y, x + z @@ -2956,6 +3037,8 @@ def f(x, y, z): self.assertEqual(self.get_manager().new_graph_id().id, 1) @torch._inductor.config.patch("graph_partition", True) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._inductor.config.patch("triton.cudagraphs", False) def test_graph_partition_reduce_overhead_mode_effectiveness(self): # test that `mode="reduce-overhead"` still controls whether @@ -3297,6 +3380,7 @@ def fn(x): @config.patch(implicit_fallbacks=True) @torch._inductor.config.patch("graph_partition", True) +<<<<<<< HEAD def test_graph_partition_custom_op_mutation_late_free(self): @torch.library.custom_op( "mylib::op1", @@ -3351,6 +3435,8 @@ def f(x): @config.patch(implicit_fallbacks=True) @torch._inductor.config.patch("graph_partition", True) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_graph_partition_custom_op_dynamoc_shapes(self): @torch.library.custom_op( "mylib::movement", @@ -3691,6 +3777,7 @@ def run(padded_size, original_size): self.assertEqual(self.get_manager().new_graph_id().id, 2) +<<<<<<< HEAD @torch._inductor.config.patch("graph_partition", True) def test_graph_partition_simple(self): def f(x, y): @@ -3963,6 +4050,8 @@ def foo(x): compiled_out = compiled_foo(x) self.assertEqual(eager_out, compiled_out) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_meta_tensor(self): def foobar(x, y): return x * 2, y * 3 @@ -4053,6 +4142,7 @@ def run(batch_size, seq_len, d): self.assertEqual(self.get_manager().new_graph_id().id, 4) +<<<<<<< HEAD @torch._inductor.config.patch("triton.cudagraph_or_error", True) def test_cudagraph_or_error(self): def f(x): @@ -4064,6 +4154,8 @@ def f(x): with self.assertRaises(RuntimeError): f(torch.tensor(1, device="cuda")) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestSAC(TestCase): def _make_observer_mode(self): class ObserverMode(TorchDispatchMode): @@ -4413,7 +4505,11 @@ def multi_fn(x, y, a, b): a = torch.randn(4, 4, device="cuda:1", requires_grad=True) b = torch.randn(4, 4, device="cuda:1", requires_grad=True) +<<<<<<< HEAD # No errors. TODO - get graphs from logging, couldn't figure out how +======= + # No errors. TODO - get graphs from logging, couldnt figure out how +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) multi_fn_c = torch.compile(multi_fn, backend="aot_eager_decomp_partition") out = multi_fn_c(x, y, a, b) @@ -4478,5 +4574,9 @@ def fn(x, y): sys.exit(0) raise unittest.SkipTest("cuda graph test is skipped") +<<<<<<< HEAD if HAS_CUDA_AND_TRITON: +======= + if HAS_CUDA: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) run_tests(needs="filelock") diff --git a/test/inductor/test_cudagraph_trees_expandable_segments.py b/test/inductor/test_cudagraph_trees_expandable_segments.py index 65597316091d4..036fe2687add5 100644 --- a/test/inductor/test_cudagraph_trees_expandable_segments.py +++ b/test/inductor/test_cudagraph_trees_expandable_segments.py @@ -8,13 +8,21 @@ import torch from torch.testing._internal.common_cuda import IS_JETSON, IS_WINDOWS from torch.testing._internal.common_utils import run_tests +<<<<<<< HEAD from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON +======= +from torch.testing._internal.inductor_utils import HAS_CUDA +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) +<<<<<<< HEAD if HAS_CUDA_AND_TRITON: +======= +if HAS_CUDA: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: from .test_cudagraph_trees import CudaGraphTreeTests except ImportError: @@ -32,12 +40,16 @@ sys.path.remove(str(REPO_ROOT)) if __name__ == "__main__": +<<<<<<< HEAD if ( torch.cuda.is_available() and not IS_JETSON and not IS_WINDOWS and HAS_CUDA_AND_TRITON ): +======= + if torch.cuda.is_available() and not IS_JETSON and not IS_WINDOWS and HAS_CUDA: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) get_disabled_tests(".") torch.cuda.memory._set_allocator_settings("expandable_segments:True") diff --git a/test/inductor/test_custom_post_grad_passes.py b/test/inductor/test_custom_post_grad_passes.py index c7823845bd570..f427f66c5322e 100644 --- a/test/inductor/test_custom_post_grad_passes.py +++ b/test/inductor/test_custom_post_grad_passes.py @@ -66,6 +66,7 @@ def change_cos_pass(graph): node.target = aten.sin.default +<<<<<<< HEAD class ChangeCosCustomPass(CustomGraphPass): def __init__(self) -> None: super().__init__() @@ -77,6 +78,8 @@ def uuid(self) -> bytes: return get_hash_for_files((__file__,)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestPostGradCustomPrePostPass(TestCustomPassBase): # mkldnn fusion's pattern_matcher # (torch/_inductor/fx_passes/mkldnn_fusion.py), @@ -145,7 +148,11 @@ def forward(self, x): return x1.relu() def test_custom_joint_pass_pre(self): +<<<<<<< HEAD with config.patch(joint_custom_pre_pass=ChangeCosCustomPass()): +======= + with config.patch(joint_custom_pre_pass=change_cos_pass): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def g(x): return x.sin().sin().sin() @@ -157,7 +164,11 @@ def f(x): torch.testing.assert_close(torch.compile(f)(x), g(x)) def test_custom_joint_pass_post(self): +<<<<<<< HEAD with config.patch(joint_custom_post_pass=ChangeCosCustomPass()): +======= + with config.patch(joint_custom_post_pass=change_cos_pass): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def g(x): return x.sin().sin().sin() diff --git a/test/inductor/test_cutlass_backend.py b/test/inductor/test_cutlass_backend.py index b807df5d6691c..c8e820cf7faca 100644 --- a/test/inductor/test_cutlass_backend.py +++ b/test/inductor/test_cutlass_backend.py @@ -12,12 +12,18 @@ from pathlib import Path from typing import Callable, Optional +<<<<<<< HEAD from torch._dynamo.exc import BackendCompilerFailed +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.codegen.cuda.serialization import get_cutlass_operation_serializer from torch._inductor.utils import clear_caches from torch.export import Dim from torch.testing._internal.logging_utils import log_settings +<<<<<<< HEAD from torch.utils import _pytree as pytree +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: @@ -58,12 +64,20 @@ _quantize_rowwise, _quantize_tensorwise, HAS_CPU, +<<<<<<< HEAD HAS_CUDA_AND_TRITON, +======= + HAS_CUDA, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) torch.set_float32_matmul_precision("high") +<<<<<<< HEAD if HAS_CUDA_AND_TRITON: +======= +if HAS_CUDA: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.cuda.memory._set_allocator_settings("expandable_segments:False") @@ -148,6 +162,7 @@ def gen_args(op, shape, dtype=torch.float16): ) +<<<<<<< HEAD def select_no_algorithm(*args, **kwargs): """ Utility function to skip precompilation and autotuning. @@ -160,6 +175,13 @@ class TestCutlassBackend(TestCase): def setUp(self): if not HAS_CUDA_AND_TRITON: self.skipTest("CUDA and triton are not available") +======= +@instantiate_parametrized_tests +class TestCutlassBackend(TestCase): + def setUp(self): + if not HAS_CUDA: + self.skipTest("CUDA is not available") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if torch.version.hip: self.skipTest("CUTLASS backend is not supported on HIP") @@ -187,7 +209,11 @@ def tearDown(self): def run_evt_test(self, model, op, shape, num_fusions=1): M, N = shape a = torch.ones(M, N).cuda().half() +<<<<<<< HEAD b = torch.ones(N, N).cuda().half().t() +======= + b = torch.ones(N, N).cuda().half() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) extra_args = gen_args(op, (M, N)) model = model.cuda() @@ -200,6 +226,7 @@ def run_evt_test(self, model, op, shape, num_fusions=1): ) torch.testing.assert_close(result, ref_result) +<<<<<<< HEAD def test_check_paths(self): cutlass_mock_imports_path = os.path.join( os.path.dirname(torch.__file__), @@ -213,6 +240,8 @@ def test_check_paths(self): self.assertTrue(os.path.exists(cutlass_mock_pydot_path)) self.assertTrue(os.path.exists(cutlass_mock_scipy_path)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) def test_max_autotune_cutlass_threshold(self): @@ -224,7 +253,11 @@ def mm(a, b): return a @ b a = torch.randn(100, 10).cuda().half() +<<<<<<< HEAD b = torch.randn(100, 10).cuda().half().t() +======= + b = torch.randn(10, 100).cuda().half() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with config.patch( { @@ -235,6 +268,13 @@ def mm(a, b): "cuda.cutlass_max_profiling_configs": 2, } ): +<<<<<<< HEAD +======= + + def select_no_algorithm(*args, **kwargs): + raise NoValidChoicesError + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with mock.patch( "torch._inductor.kernel.mm.autotune_select_algorithm", wraps=select_no_algorithm, @@ -252,10 +292,14 @@ def test_import_cutlass(self): self.assertTrue(try_import_cutlass()) +<<<<<<< HEAD if config.is_fbcode(): import python_cutlass else: import cutlass as python_cutlass # noqa: F401 +======= + import cutlass # noqa: F401 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import cutlass_library # noqa: F401 def test_cutlass_key(self): @@ -279,7 +323,11 @@ def test_cutlass_backend_subproc_mm(self): M, N, K = 4096, 2048, 25728 a = torch.randn(M, K).cuda().half() +<<<<<<< HEAD b = torch.randn(N, K).cuda().half().t() +======= + b = torch.randn(K, N).cuda().half() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with config.patch( { @@ -294,19 +342,34 @@ def test_cutlass_backend_subproc_mm(self): Y = torch.mm(a, b) torch.testing.assert_close(Y_compiled, Y) +<<<<<<< HEAD @unittest.skipIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) @parametrize("dtype", (torch.float16, torch.bfloat16)) def test_cutlass_backend_subproc_addmm(self, dtype): +======= + @unittest.skipIf( + True, "FIXME: Disabled temporarily since IMA or crashing in subprocess" + ) + @unittest.skipIf(not SM90OrLater, "need sm_90") + @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) + def test_cutlass_backend_subproc_addmm(self, shape_combo): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Test autotune_in_subproc works for addmm. """ M, N, K = 4096, 2048, 25728 +<<<<<<< HEAD dtype = torch.float16 a = torch.randn(M, K, dtype=dtype).cuda() b = torch.randn(N, K, dtype=dtype).cuda().t() +======= + + a = torch.randn(M, K).cuda().half() + b = torch.randn(K, N).cuda().half() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x_shapes = [ (M, N), @@ -328,10 +391,14 @@ def test_cutlass_backend_subproc_addmm(self, dtype): } ): for x_shape in x_shapes: +<<<<<<< HEAD torch._dynamo.reset() clear_caches() x = torch.randn(x_shape).cuda().to(dtype) +======= + x = torch.randn(x_shape).cuda().half() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Y_compiled = torch.compile(torch.addmm)(x, a, b, alpha=alpha, beta=beta) Y = torch.addmm(x, a, b, alpha=alpha, beta=beta) torch.testing.assert_close(Y_compiled, Y) @@ -346,7 +413,11 @@ def test_cutlass_backend_subproc_bmm(self): B, M, N, K = 10, 4096, 2048, 25728 a = torch.randn(B, M, K).cuda().half() +<<<<<<< HEAD b = torch.randn(B, N, K).cuda().half().permute(0, 2, 1) +======= + b = torch.randn(B, K, N).cuda().half() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with config.patch( { @@ -378,8 +449,13 @@ def forward(self, a, b, c): model = MyModel() a = torch.randn(128, 16).cuda().half() +<<<<<<< HEAD b = torch.randn(128, 16).cuda().half().t() c = torch.randn(512, 16).cuda().half().t() +======= + b = torch.randn(16, 128).cuda().half() + c = torch.randn(16, 512).cuda().half() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with config.patch( { @@ -420,8 +496,13 @@ def forward(self, a, b, c): model = MyModel() a = torch.randn(128, 16).cuda().half() +<<<<<<< HEAD b = torch.randn(128, 16).cuda().half().t() c = torch.randn(512, 16).cuda().half().t() +======= + b = torch.randn(16, 128).cuda().half() + c = torch.randn(16, 512).cuda().half() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with config.patch( { @@ -434,9 +515,13 @@ def forward(self, a, b, c): 2, 4, ], # guarantees > 1 choices +<<<<<<< HEAD "fx_graph_cache": False, "fx_graph_remote_cache": False, "autotune_local_cache": False, +======= + "force_disable_caches": True, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } ): from torch._inductor.utils import run_and_get_code @@ -655,7 +740,11 @@ def forward(self, x, a, b): ( torch.randn(x_shape(M, N)).cuda().to(dtype), torch.randn(M, K).cuda().to(dtype), +<<<<<<< HEAD torch.randn(N, K).cuda().to(dtype).t(), +======= + torch.randn(K, N).cuda().to(dtype), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) for (M, N, K) in shapes ] @@ -697,7 +786,10 @@ def forward(self, x, a, b): @parametrize("dynamic", (False, True)) @parametrize("use_aoti", (False, True)) @parametrize("dtype", (torch.float16, torch.bfloat16)) +<<<<<<< HEAD @parametrize("use_expand", (False, True)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) def test_max_autotune_cutlass_backend_bmm( self, @@ -705,7 +797,10 @@ def test_max_autotune_cutlass_backend_bmm( use_aoti: bool = False, max_autotune_gemm_backends: str = "CUTLASS", dtype: torch.dtype = torch.float16, +<<<<<<< HEAD use_expand: bool = False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): """ Main test for bmm. @@ -723,6 +818,7 @@ def forward(self, a, b): ] shapes = shapes[0:1] if not dynamic else shapes +<<<<<<< HEAD inputs = [] for B, M, N, K in shapes: if use_expand: @@ -734,6 +830,15 @@ def forward(self, a, b): B_tensor = torch.randn(B, N, K).cuda().to(dtype).permute(0, 2, 1) inputs.append((A, B_tensor)) +======= + inputs = [ + ( + torch.randn(B, M, K).cuda().to(dtype), + torch.randn(B, N, K).cuda().to(dtype).permute(0, 2, 1), + ) + for B, M, N, K in shapes + ] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dynamic_shapes = ( { "a": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC, 2: Dim.DYNAMIC}, @@ -768,7 +873,15 @@ def test_max_autotune_cutlass_backend_regular_mm_streamk( Make sure autotuning mm in sub processes work without crashes. """ +<<<<<<< HEAD compiled_model = torch.compile(torch.mm, dynamic=dynamic) +======= + def mm(a, b): + return a @ b + + a = torch.randn(128, 16).cuda().half() + b = torch.randn(16, 128).cuda().half() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with config.patch( { @@ -794,13 +907,20 @@ def test_max_autotune_cutlass_backend_regular_mm_streamk( ), ): a = torch.randn(M, K).cuda().half() +<<<<<<< HEAD b = torch.randn(N, K).cuda().half().t() Y_compiled = compiled_model(a, b) Y = torch.mm(a, b) +======= + b = torch.randn(K, N).cuda().half() + Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b) + Y = mm(a, b) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # we need relaxed numerical limits due to the sheer size of the # matmuls involved. Many small addition differences add up. torch.testing.assert_close(Y_compiled, Y, atol=0.01, rtol=0.01) +<<<<<<< HEAD @unittest.skipIf(not SM90OrLater, "need sm_90") def test_streamk_with_dynamic( self, @@ -855,6 +975,8 @@ def test_streamk_with_static( ): _ = compiled_model(a, b) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _test_max_autotune_cutlass_backend_epilogue_fusion( self, dynamic: bool = False, @@ -871,10 +993,17 @@ def _test_max_autotune_cutlass_backend_epilogue_fusion( # that allows fusions if batch_size is None: a = torch.randn(256, 32).cuda() +<<<<<<< HEAD b = torch.randn(256, 32).cuda().t() else: a = torch.randn(batch_size, 256, 32).cuda() b = torch.randn(batch_size, 256, 32).cuda().permute(0, 2, 1) +======= + b = torch.randn(32, 256).cuda() + else: + a = torch.randn(batch_size, 256, 32).cuda() + b = torch.randn(batch_size, 32, 256).cuda() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if fp16: a = a.half() b = b.half() @@ -1013,7 +1142,11 @@ def forward(self, x, w): } x = torch.randn(M, K).cuda().half() +<<<<<<< HEAD w = torch.randn(N, K).cuda().half().t() +======= + w = torch.randn(K, N).cuda().half() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) actual = AOTIRunnerUtil.run( model, @@ -1051,7 +1184,11 @@ def forward(self, x, w): } x = torch.randn(M, K).cuda().half() +<<<<<<< HEAD w = torch.randn(N, K).cuda().half().t() +======= + w = torch.randn(K, N).cuda().half() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) actual = AOTIRunnerUtil.run( model, @@ -1081,7 +1218,11 @@ def forward(self, x, w): M, N, K = 200, 5216, 10_432 x = torch.randn(M, K).cuda().half() +<<<<<<< HEAD w = torch.randn(N, K).cuda().half().t() +======= + w = torch.randn(K, N).cuda().half() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) actual = AOTIRunnerUtil.run( model, @@ -1150,7 +1291,14 @@ def my_addmm(x, a, b, alpha, beta): x = torch.randn((128, 128)).cuda().half() a = torch.randn(128, 128).cuda().half() +<<<<<<< HEAD b = torch.randn(128, 128).cuda().half().t() +======= + b = torch.randn(128, 128).cuda().half() + + def select_no_algorithm(*args, **kwargs): + raise NoValidChoicesError +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with fresh_cache(): with config.patch( @@ -1195,7 +1343,14 @@ def addmm(x, a, b, alpha, beta): x = torch.randn((128, 128)).cuda().half() a = torch.randn(128, 128).cuda().half() +<<<<<<< HEAD b = torch.randn(128, 128).cuda().half().t() +======= + b = torch.randn(128, 128).cuda().half() + + def select_no_algorithm(*args, **kwargs): + raise NoValidChoicesError +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with fresh_cache(): with config.patch( @@ -1267,6 +1422,12 @@ def linear( linear_compiled = torch.compile(linear, backend="inductor") +<<<<<<< HEAD +======= + def select_no_algorithm(*args, **kwargs): + raise NoValidChoicesError + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def run_test(use_fast_accum): with fresh_cache(): with config.patch( @@ -1344,6 +1505,12 @@ def test_cutlass_backend_shape_coverage_mm( ), ] +<<<<<<< HEAD +======= + def select_no_algorithm(*args, **kwargs): + raise NoValidChoicesError + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with ( fresh_cache(), config.patch( @@ -1401,7 +1568,14 @@ def test_cutlass_presets( M, N, K = (128, 128, 16) A = torch.randn(M, K).cuda().half() +<<<<<<< HEAD B = torch.randn(N, K).cuda().half().t() +======= + B = torch.randn(K, N).cuda().half() + + def select_no_algorithm(*args, **kwargs): + raise NoValidChoicesError +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with ( fresh_cache(), @@ -1510,7 +1684,11 @@ def test_standalone_runner(self): max_autotune_gemm_backends = "CUTLASS" a = torch.randn(128, 16).cuda().half() +<<<<<<< HEAD b = torch.randn(128, 16).cuda().half().t() +======= + b = torch.randn(16, 128).cuda().half() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with config.patch( { @@ -1593,7 +1771,11 @@ def mm(a, b): return a @ b a = torch.randn(128, 16).cuda().half() +<<<<<<< HEAD b = torch.randn(128, 16).cuda().half().t() +======= + b = torch.randn(16, 128).cuda().half() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with config.patch( { @@ -1601,8 +1783,12 @@ def mm(a, b): "max_autotune_gemm_backends": "ATEN,TRITON,CUTLASS", "cuda.cutlass_max_profiling_configs": 2, # needed for log searching +<<<<<<< HEAD "fx_graph_cache": False, "fx_graph_remote_cache": False, +======= + "force_disable_caches": True, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } ): with ( @@ -1625,6 +1811,7 @@ def mm(a, b): self.assertTrue(num_ops > 0, "The number of ops should be greater than 0") @unittest.skipIf(not SM90OrLater, "need sm_90") +<<<<<<< HEAD def test_maybe_append_choice_caching(self): """ Test if maybe_append_choice's caching leads to correct results and @@ -1796,6 +1983,8 @@ def counting_render(self, *args, **kwargs): self.assertEqual(render_call_count, num_matmuls + num_matmuls * 2) @unittest.skipIf(not SM90OrLater, "need sm_90") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) def test_cutlass_backend_matmul_same_tensor(self): max_autotune_gemm_backends = "CUTLASS" @@ -1816,6 +2005,7 @@ def test_cutlass_backend_matmul_same_tensor(self): @unittest.skipIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) +<<<<<<< HEAD def test_cutlass_backend_matmul_nonzero_offset(self): max_autotune_gemm_backends = "CUTLASS" @@ -1836,11 +2026,17 @@ def test_cutlass_backend_matmul_nonzero_offset(self): @unittest.skipIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_flexible_layout(self): class TestModel(torch.nn.Module): def forward(self, B): A = torch.zeros_like(B) +<<<<<<< HEAD return A @ B.t() +======= + return A @ B +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) M = 1024 B = torch.randn(M, M).cuda().half() @@ -1862,7 +2058,11 @@ def test_evt_flexible_layout(self): class TestModel(torch.nn.Module): def forward(self, B): A = torch.zeros_like(B) +<<<<<<< HEAD return (A @ B.t()).relu() +======= + return (A @ B).relu() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) M = 1024 B = torch.randn(M, M).cuda().half() @@ -1888,7 +2088,11 @@ class TestModel(torch.nn.Module): def forward(self, B): A = torch.zeros_like(B) for _ in range(100): +<<<<<<< HEAD A = A @ B.t() +======= + A = A @ B +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return A M = 1024 @@ -1908,6 +2112,7 @@ def forward(self, B): @unittest.skipIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) +<<<<<<< HEAD @parametrize("use_aoti", (False, True)) def test_compilation_time(self, use_aoti): M = 1024 @@ -1920,6 +2125,12 @@ def forward(self, a, b): model = MyModel().cuda() expected = model(A, B) +======= + def test_compilation_time(self): + M = 1024 + A = torch.randn(M, M).cuda().half() + B = torch.randn(M, M).cuda().half() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) start_time = time.time() with config.patch( @@ -1929,6 +2140,7 @@ def forward(self, a, b): "cuda.cutlass_max_profiling_configs": 1, } ): +<<<<<<< HEAD if use_aoti: actual = AOTIRunnerUtil.run( model, @@ -1938,6 +2150,9 @@ def forward(self, a, b): actual = torch.compile(model, fullgraph=True)(A, B) torch.testing.assert_close(actual, expected) +======= + _ = torch.compile(torch.mm)(A, B) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue(time.time() - start_time < 50) @unittest.skipIf(not SM90OrLater, "need sm_90") @@ -1964,7 +2179,11 @@ def forward(self, a, b, extra_args): M = 1024 N = 512 a = torch.ones(M, N).cuda().half() +<<<<<<< HEAD b = torch.ones(N, N).cuda().half().t() +======= + b = torch.ones(N, N).cuda().half() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) extra_args = gen_args(op, (M, N)) model = TestModel().cuda() @@ -1994,7 +2213,11 @@ def forward(self, a, b, extra_args): model = TestModel().cuda() a = torch.ones(M, N).cuda().half() +<<<<<<< HEAD b = torch.ones(N, N).cuda().half().t() +======= + b = torch.ones(N, N).cuda().half() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) extra_args = gen_args(op, (M, N), dtype=torch.float16) # baseline is cutlass kernel + triton @@ -2059,7 +2282,11 @@ def forward(self, a, b, extra_args): for i, shape in enumerate(shapes): M, N = shape a = torch.ones(M, N).cuda().half() +<<<<<<< HEAD b = torch.ones(N, N).cuda().half().t() +======= + b = torch.ones(N, N).cuda().half() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) extra_args = gen_args(op, (M, N)) model = TestModel().cuda() @@ -2087,7 +2314,11 @@ def forward(self, a, b, extra_args): M = 1024 N = 512 a = torch.ones(M, N).cuda().half() +<<<<<<< HEAD b = torch.ones(N, N).cuda().half().t() +======= + b = torch.ones(N, N).cuda().half() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) extra_args = gen_args(op, (M, N)) model = TestModel().cuda() @@ -2101,13 +2332,18 @@ def forward(self, a, b, extra_args): @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) @parametrize("arch", ("90", "100")) +<<<<<<< HEAD @parametrize("cuda_version", ("12.4", "12.8")) +======= + @parametrize("cuda_version", ("12.4", "12.6", "12.8")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_gemm_operation_serialization(self, arch: str, cuda_version: str): """ Testing serialization for GEMM operations generated by CUTLASS. This should cover GroupedGemmOperation as well. """ full_ops = _gen_ops_cached(arch, cuda_version) +<<<<<<< HEAD ops = pytree.tree_flatten(full_ops)[0] # sanity check @@ -2116,16 +2352,32 @@ def test_gemm_operation_serialization(self, arch: str, cuda_version: str): # test if configuration name is unique op_config_names = [op.configuration_name() for op in ops] self.assertEqual(len(op_config_names), len(set(op_config_names))) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) serializer = get_cutlass_operation_serializer() self.assertIsNotNone(serializer) +<<<<<<< HEAD serialized_ops = [serializer.serialize(op) for op in ops] deserialized_ops = [ serializer.deserialize(serialized_op) for serialized_op in serialized_ops ] for op, deserialized_op in zip(ops, deserialized_ops): self.assertTrue(_check_if_instances_equal(op, deserialized_op)) +======= + count = 0 + for ops in full_ops.values(): + for op_dict in ops.values(): + for op_list in op_dict.values(): + for op in op_list: + count += 1 + serialized = serializer.serialize(op) + deserialized = serializer.deserialize(serialized) + self.assertTrue(_check_if_instances_equal(op, deserialized)) + + self.assertGreater(count, 1000, "Too few ops generated") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+") @unittest.skipIf(not SM90OrLater, "need sm_90") @@ -2142,25 +2394,40 @@ def test_gemm_operation_serialization(self, arch: str, cuda_version: str): ), ) @parametrize("has_bias", (False, True)) +<<<<<<< HEAD @parametrize("use_fast_accum", (False, True)) @parametrize("input_dtype", (torch.bfloat16, torch.float16)) +======= + @parametrize("use_fast_accum", (False,)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_fp8_rowwise_scaling( self, float8_dtype: torch.dtype, shape: tuple[int, int, int], has_bias: bool, use_fast_accum: bool, +<<<<<<< HEAD input_dtype: torch.dtype, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): # Only bf16 output type is supported for row-wise scaling, not fp32 output_dtype: torch.dtype = torch.bfloat16 device = "cuda" M, K, N = shape # Matmul Y = X [M, K] x W [N, K] +<<<<<<< HEAD x = torch.randn(M, K, dtype=input_dtype, device=device) w = torch.randn(N, K, dtype=input_dtype, device=device) bias = None if has_bias: bias = torch.randn(N, device=device, dtype=input_dtype).to(torch.bfloat16) +======= + x = torch.randn(M, K, dtype=output_dtype, device=device) + w = torch.randn(N, K, dtype=output_dtype, device=device) + bias = None + if has_bias: + bias = torch.randn(N, device=device, dtype=torch.bfloat16) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # quantize weight (prior to inference) w_fp8, w_inverse_scale = _quantize_rowwise(w, float8_dtype) @@ -2210,6 +2477,7 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): ( ( 512, +<<<<<<< HEAD 1024, ), ), @@ -2304,6 +2572,8 @@ def forward(self, x): ( ( 512, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) 128, 64, ), @@ -2311,25 +2581,40 @@ def forward(self, x): ) @parametrize("has_bias", (False, True)) @parametrize("use_fast_accum", (False,)) +<<<<<<< HEAD @parametrize("input_dtype", (torch.bfloat16, torch.float16)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_fp8_tensorwise_scaling( self, float8_dtype: torch.dtype, shape: tuple[int, int, int], has_bias: bool, use_fast_accum: bool, +<<<<<<< HEAD input_dtype: torch.dtype, ): device = "cuda" M, K, N = shape # Matmul Y = X [M, K] x W [N, K] output_dtype = input_dtype +======= + ): + device = "cuda" + M, K, N = shape # Matmul Y = X [M, K] x W [N, K] + input_dtype = torch.bfloat16 + output_dtype = torch.bfloat16 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # input and output dtypes of _scaled_mm do not need to be the same, but # typically in a model they are x = torch.randn(M, K, dtype=input_dtype, device=device) w = torch.randn(N, K, dtype=input_dtype, device=device) bias = None if has_bias: +<<<<<<< HEAD bias = torch.randn(N, device=device, dtype=input_dtype) +======= + bias = torch.randn(N, device=device, dtype=torch.bfloat16) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # quantize weight (prior to inference) w_fp8, w_inverse_scale = _quantize_tensorwise(w, float8_dtype) @@ -2373,6 +2658,7 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): # setting a small absolute tolerance in these tests torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05) +<<<<<<< HEAD @unittest.skipIf(not SM90OrLater, "need sm_90") def test_config_number_post_filtering(self) -> None: """ @@ -2423,10 +2709,16 @@ def test_config_number_post_filtering(self) -> None: f"Got counts: {config_counts}", ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": from torch._inductor.utils import is_big_gpu # Set env to make it work in CI. +<<<<<<< HEAD if HAS_CUDA_AND_TRITON and HAS_CPU and is_big_gpu(): +======= + if HAS_CUDA and HAS_CPU and is_big_gpu(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) run_tests() diff --git a/test/inductor/test_cutlass_evt.py b/test/inductor/test_cutlass_evt.py index cae9558d2ec2a..f2b115a2b458d 100644 --- a/test/inductor/test_cutlass_evt.py +++ b/test/inductor/test_cutlass_evt.py @@ -4,21 +4,32 @@ import sympy import torch +<<<<<<< HEAD import torch._inductor.config as config +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._dynamo.test_case import TestCase from torch._inductor.codegen.cuda.cutlass_utils import ( torch_dtype_to_cutlass_type, try_import_cutlass, ) +<<<<<<< HEAD +======= +from torch._inductor.graph import GraphLowering +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.ir import ComputedBuffer, FixedLayout, PermuteView, Pointwise from torch._inductor.scheduler import BaseSchedulerNode from torch._inductor.utils import OrderedSet from torch.testing._internal.common_cuda import SM90OrLater +<<<<<<< HEAD from torch.testing._internal.inductor_utils import ( HAS_CPU, HAS_CUDA_AND_TRITON, MockGraphHandler, ) +======= +from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if try_import_cutlass(): @@ -30,6 +41,7 @@ from torch._inductor.codegen.cuda.cutlass_lib_extensions.evt_extensions import ( _render_argument_type, _trace, +<<<<<<< HEAD trace, ) @@ -39,6 +51,12 @@ import cutlass as python_cutlass # type: ignore[import-untyped, import-not-found] # noqa: F401 CutlassTensor = python_cutlass.backend.evt.ir.tensor.Tensor +======= + CutlassTensor, + trace, + ) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) BIAS_CODE = """def example_epilogue(accum, C, aux, bias): F = accum + C + aux E = relu(F) + bias @@ -108,6 +126,20 @@ def num_reads(self): return 1 +<<<<<<< HEAD +======= +class MockGraphHandler(GraphLowering): + def __init__(self, name_to_buffer): + import torch._inductor.sizevars + + self.sizevars = torch._inductor.sizevars.SizeVarAllocator() + self.name_to_buffer = name_to_buffer + self.graph_inputs = dict() + self.mutated_buffers = OrderedSet() + self.constants = dict() + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestCutlassEVT(TestCase): @unittest.skipIf(not SM90OrLater, "need sm_90") @unittest.skipIf(not try_import_cutlass(), "requires cutlass") @@ -337,6 +369,7 @@ def test_example_tensor_creation(self): from torch._inductor.codegen.cuda.cutlass_lib_extensions.evt_extensions import ( create_example_tensors, ) +<<<<<<< HEAD from torch._inductor.virtualized import V with V.set_graph_handler(MockGraphHandler({})): @@ -362,6 +395,31 @@ def test_example_tensor_creation(self): self.assertEqual( result["buf1"].element, torch_dtype_to_cutlass_type(torch.float32) ) +======= + + row_major_buf0 = MockComputedBuffer( + "buf0", None, torch.float32, (3, 4, 1), (4, 1, 0) + ) + col_major_buf1 = MockComputedBuffer( + "buf1", None, torch.float32, (3, 2, 1), (1, 3, 0) + ) + buffer_renames = {"buf0": "buf0", "buf1": "buf1", "acc": "buf0"} + name_to_buffer = {"buf0": row_major_buf0, "buf1": col_major_buf1} + result = create_example_tensors( + buffer_renames, name_to_buffer, lambda x: int(x) + ) + self.assertEqual(result["acc"].shape, (3, 4, 1)) + self.assertEqual(result["acc"].stride, (4, 1, 0)) + self.assertEqual( + result["acc"].element, torch_dtype_to_cutlass_type(torch.float32) + ) + + self.assertEqual(result["buf1"].shape, (3, 2, 1)) + self.assertEqual(result["buf1"].stride, (1, 3, 0)) + self.assertEqual( + result["buf1"].element, torch_dtype_to_cutlass_type(torch.float32) + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(not SM90OrLater, "need sm_90") @unittest.skipIf(not try_import_cutlass(), "requires cutlass") @@ -376,7 +434,11 @@ def test_evt_argument_codegen(self): epilogue_functor, _create_mock_buffer_name_map(EXAMPLE_TENSORS), lambda x: int(x), +<<<<<<< HEAD )[0], +======= + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """\ { /* thread */ { /* F */ @@ -386,12 +448,21 @@ def test_evt_argument_codegen(self): {}, /* C */ {}, /* compute_0 */ }, +<<<<<<< HEAD {/* ptr_aux */ (float*) (ptr_0 + ptr_0_offset), /* null_default */ float(0), /* dAux */ {2048, _1{}, _0{}}}, /* aux */ {}, /* compute_1 */ }, {/* ptr_aux */ (float*) (ptr_1 + ptr_1_offset), /* dAux */ {2048, _1{}, _0{}}}, /* F */ }, {/* ptr_col */ (float*) (ptr_2 + ptr_2_offset), /* null_default */ float(0), /* dCol */ {}}, /* bias */ +======= + {/* ptr_aux */ (float*) aux, /* null_default */ float(0), /* dAux */ {2048, _1{}, _0{}}}, /* aux */ + {}, /* compute_1 */ + }, + {/* ptr_aux */ (float*) F, /* dAux */ {2048, _1{}, _0{}}}, /* F */ + }, + {/* ptr_col */ (float*) bias, /* null_default */ float(0), /* dCol */ {}}, /* bias */ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {}, /* compute_2 */ {}, /* compute_3 */ {}, /* compute_4 */ @@ -433,14 +504,24 @@ def fn(accum, bias): epilogue_functor, _create_mock_buffer_name_map(example_tensors), lambda x: int(x), +<<<<<<< HEAD )[0], +======= + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """\ { /* thread */ { /* E */ {}, /* accum */ +<<<<<<< HEAD {/* ptr_aux */ (float*) (ptr_0 + ptr_0_offset), /* dAux */ {2048, _1{}, _0{}}}, /* E */ }, {/* ptr_col */ (float*) (ptr_1 + ptr_1_offset), /* null_default */ float(0), /* dCol */ {}}, /* bias */ +======= + {/* ptr_aux */ (float*) E, /* dAux */ {2048, _1{}, _0{}}}, /* E */ + }, + {/* ptr_col */ (float*) bias, /* null_default */ float(0), /* dCol */ {}}, /* bias */ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {}, /* compute_0 */ } """, @@ -449,7 +530,11 @@ def fn(accum, bias): @unittest.skipIf(not SM90OrLater, "need sm_90") @unittest.skipIf(not try_import_cutlass(), "requires cutlass") def test_evt_codegen(self): +<<<<<<< HEAD _, _, code, _ = trace( +======= + _, _, code = trace( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) BIAS_CODE, EXAMPLE_TENSORS, DataType.f32, @@ -565,5 +650,9 @@ def test_evt_codegen(self): if __name__ == "__main__": from torch._dynamo.test_case import run_tests +<<<<<<< HEAD if HAS_CPU or HAS_CUDA_AND_TRITON: +======= + if HAS_CPU or HAS_CUDA: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) run_tests(needs="filelock") diff --git a/test/inductor/test_decompose_mem_bound_mm.py b/test/inductor/test_decompose_mem_bound_mm.py index 919d97f987f64..2603c020e0a59 100644 --- a/test/inductor/test_decompose_mem_bound_mm.py +++ b/test/inductor/test_decompose_mem_bound_mm.py @@ -12,10 +12,19 @@ from torch.testing import FileCheck from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, +<<<<<<< HEAD parametrize, TEST_XPU, ) from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA_AND_TRITON +======= + patch_test_members, + is_navi3_arch, + parametrize, + TEST_XPU, +) +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.triton_utils import requires_gpu @@ -48,6 +57,7 @@ def forward(self, input1, input2): return output +<<<<<<< HEAD class TestDecomposeAddMM(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -58,6 +68,8 @@ def forward( return torch.ops.aten.addmm.default(z, x, y) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @requires_gpu @unittest.skipIf( TEST_XPU, @@ -71,12 +83,29 @@ def forward( ) @instantiate_parametrized_tests class TestDecomposeMemMM(TestCase): +<<<<<<< HEAD def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3): +======= + def __init__(self, method_name='runTest', methodName='runTest'): + super().__init__(method_name, methodName) + self.atol = 1e-3 + self.rtol = 1e-3 + + def setup_tolerance(self, rtol=None, atol=None): + if rtol is None: + rtol = self.rtol + if atol is None: + atol = self.rtol + + def compare_dict_tensors(self, ref_dict, res_dict, rtol=None, atol=None): + self.setup_tolerance(rtol, atol) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if len(set(ref_dict.keys())) != len(set(res_dict.keys())): return False for key1 in ref_dict.keys(): key2 = "_orig_mod." + key1 assert key2 in res_dict, f"{key1} does not exist in traced module" +<<<<<<< HEAD if not torch.allclose(ref_dict[key1], res_dict[key2], rtol=rtol, atol=atol): return False return True @@ -96,6 +125,30 @@ def compare_gradients(self, module, traced, rtol=1e-3, atol=1e-3): res_grad = {key: param.grad for key, param in traced.named_parameters()} self.assertTrue( self.compare_dict_tensors(ref_grad, res_grad, rtol=rtol, atol=atol) +======= + if not torch.allclose(ref_dict[key1], res_dict[key2], rtol=self.rtol, atol=self.atol): + return False + return True + + def compare_pred(self, module, traced, input, rtol=None, atol=None): + self.setup_tolerance(rtol, atol) + ref = module(*input) + res = traced(*input) + self.assertEqual(ref, res, rtol=self.rtol, atol=self.atol) + + def compare_parameters(self, module, traced, rtol=None, atol=None): + self.setup_tolerance(rtol, atol) + ref_params = dict(module.named_parameters()) + res_params = dict(traced.named_parameters()) + self.assertTrue(self.compare_dict_tensors(ref_params, res_params, rtol=self.rtol, atol=self.atol)) + + def compare_gradients(self, module, traced, rtol=None, atol=None): + self.setup_tolerance(rtol, atol) + ref_grad = {key: param.grad for key, param in module.named_parameters()} + res_grad = {key: param.grad for key, param in traced.named_parameters()} + self.assertTrue( + self.compare_dict_tensors(ref_grad, res_grad, rtol=self.rtol, atol=self.atol) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @parametrize( @@ -117,7 +170,11 @@ def test_decompose_bmm(self, b, m, n, k, should_decompose): self.compare_pred(module, traced, input) +<<<<<<< HEAD expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0 +======= + expected_val = 1 if should_decompose and HAS_CUDA else 0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual( counters["inductor"]["decompose_bmm"], expected_val, @@ -128,7 +185,11 @@ def test_decompose_bmm(self, b, m, n, k, should_decompose): self.compare_parameters(module, traced) self.compare_gradients(module, traced) +<<<<<<< HEAD expected_val = 3 if should_decompose and HAS_CUDA_AND_TRITON else 0 +======= + expected_val = 3 if should_decompose and HAS_CUDA else 0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual( counters["inductor"]["decompose_bmm"], expected_val, @@ -177,7 +238,11 @@ def test_decompose_linear(self, m, n, k, has_bias, should_decompose): self.compare_pred(module, traced, input) +<<<<<<< HEAD expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0 +======= + expected_val = 1 if should_decompose and HAS_CUDA else 0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if has_bias: self.assertEqual( counters["inductor"]["decompose_addmm"], @@ -202,6 +267,15 @@ def test_decompose_linear(self, m, n, k, has_bias, should_decompose): ) counters.clear() +<<<<<<< HEAD +======= + # We have to increase tolerance for navi3 because all fp16, bf16 + # GEMMs operations have an accuracy issue caused by hardware limitation + @patch_test_members({ + "atol": 2e-3 if is_navi3_arch() else 1e-3, + "rtol": 2e-3 if is_navi3_arch() else 1e-3 + }) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize( "m,k,n, should_decompose", [(20480, 5, 2, True), (20480, 32, 2, False), (2048, 2, 2, False)], @@ -224,7 +298,11 @@ def test_decompose_linear_mixed_precision( self.compare_pred(module, traced, input) +<<<<<<< HEAD expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0 +======= + expected_val = 1 if should_decompose and HAS_CUDA else 0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if has_bias: self.assertEqual( counters["inductor"]["decompose_addmm"], @@ -269,7 +347,11 @@ def test_decompose_mm(self, m, n, k, has_bias, should_decompose): self.compare_pred(module, traced, input) +<<<<<<< HEAD expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0 +======= + expected_val = 1 if should_decompose and HAS_CUDA else 0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual( counters["inductor"]["decompose_mm"], expected_val, @@ -281,16 +363,28 @@ def test_decompose_mm(self, m, n, k, has_bias, should_decompose): self.compare_parameters(module, traced) self.compare_gradients(module, traced) +<<<<<<< HEAD expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0 +======= + expected_val = 1 if should_decompose and HAS_CUDA else 0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual( counters["inductor"]["decompose_mm"] - decompose_mm_fwd, expected_val, ) counters.clear() +<<<<<<< HEAD @parametrize( "m,k,n, should_decompose", [(1, 64, 16, True), (2, 64, 16, False), (1, 64, 32, True)], +======= + # (1, 64, 32, False) vesrion fails + @unittest.skip + @parametrize( + "m,k,n, should_decompose", + [(1, 64, 16, True), (2, 64, 16, False), (1, 64, 32, False)], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def test_decompose_mm_cpu(self, m, n, k, should_decompose): torch._logging.set_logs(inductor=logging.DEBUG) @@ -310,6 +404,15 @@ def test_decompose_mm_cpu(self, m, n, k, should_decompose): ) counters.clear() +<<<<<<< HEAD +======= + # We have to increase tolerance for navi3 because all fp16, bf16 + # GEMMs operations have an accuracy issue caused by hardware limitation + @patch_test_members({ + "atol": 3e-3 if is_navi3_arch() else 1e-3, + "rtol": 4e-3 if is_navi3_arch() else 1e-3 + }) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize( "m,k,n, should_decompose", [(20480, 5, 2, True), (20480, 32, 2, False), (2048, 2, 2, False)], @@ -331,7 +434,11 @@ def test_decompose_mm_mixed_precision(self, m, n, k, has_bias, should_decompose) self.compare_pred(module, traced, input) +<<<<<<< HEAD expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0 +======= + expected_val = 1 if should_decompose and HAS_CUDA else 0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual( counters["inductor"]["decompose_mm"], expected_val, @@ -343,7 +450,11 @@ def test_decompose_mm_mixed_precision(self, m, n, k, has_bias, should_decompose) self.compare_parameters(module, traced) self.compare_gradients(module, traced) +<<<<<<< HEAD expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0 +======= + expected_val = 1 if should_decompose and HAS_CUDA else 0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual( counters["inductor"]["decompose_mm"] - decompose_mm_fwd, expected_val, @@ -367,7 +478,11 @@ def test_dynamic_shape(self, m, n, k, has_bias, should_decompose): self.compare_pred(module, traced, input) +<<<<<<< HEAD expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0 +======= + expected_val = 1 if should_decompose and HAS_CUDA else 0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if has_bias: self.assertEqual( counters["inductor"]["decompose_addmm"], @@ -381,7 +496,11 @@ def test_dynamic_shape(self, m, n, k, has_bias, should_decompose): self.compare_gradients(module, traced) expected_val = 0 +<<<<<<< HEAD if HAS_CUDA_AND_TRITON: +======= + if HAS_CUDA: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) expected_val = 1 if has_bias else 2 self.assertEqual( @@ -434,6 +553,7 @@ def test_check_device(self): self.assertFalse(check_device(input1, input2, device="mtia")) +<<<<<<< HEAD @torch._inductor.config.patch( post_grad_fusion_options={ "decompose_mm_pass": {"skip_dynamic_shape_dim_check": True}, @@ -459,6 +579,8 @@ def test_dynamic_shape_decompose_addmm(self): ) counters.clear() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/inductor/test_distributed_patterns.py b/test/inductor/test_distributed_patterns.py index 780fac7db5287..2a86beb422d90 100644 --- a/test/inductor/test_distributed_patterns.py +++ b/test/inductor/test_distributed_patterns.py @@ -436,7 +436,10 @@ def fn(x): self._assert_same_grad(r1, r2) self._assert_same_grad(p1, p2) +<<<<<<< HEAD @torch._dynamo.config.patch("graph_break_on_nn_param_ctor", False) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_nn_param_return3(self): def fn(x): p = torch.nn.Parameter(x + 123) @@ -453,7 +456,10 @@ def fn(x): self._assert_same_grad(r1, r2) self._assert_same_grad(p1, p2) +<<<<<<< HEAD @torch._dynamo.config.patch("graph_break_on_nn_param_ctor", False) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_nn_param_return4(self): def fn(x): p = torch.nn.Parameter(x + 123, requires_grad=False) diff --git a/test/inductor/test_external_callables.py b/test/inductor/test_external_callables.py index a8aab1c00d80b..d773be53c216b 100644 --- a/test/inductor/test_external_callables.py +++ b/test/inductor/test_external_callables.py @@ -16,7 +16,11 @@ def forward(self, x): return torch.matmul(x, self.matrix) +<<<<<<< HEAD # torch.add performs better than torch.mm and got chosen during tuning +======= +# torch.add performs better than torch.mm and got choosed during tuning +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def matmul_cpu(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor) -> None: torch.add(a, b, out=out) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index 740faa0b37577..f004f20cf2ff5 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -5,7 +5,10 @@ import random import string import unittest +<<<<<<< HEAD import warnings +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from collections import namedtuple from contextlib import contextmanager from dataclasses import dataclass @@ -28,10 +31,14 @@ _identity, _mask_mod_signature, _score_mod_signature, +<<<<<<< HEAD _WARNINGS_SHOWN, and_masks, AuxOutput, AuxRequest, +======= + and_masks, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) BlockMask, create_block_mask, flex_attention, @@ -45,26 +52,39 @@ from torch.testing._internal.common_device_type import ( dtypes, dtypesIfCUDA, +<<<<<<< HEAD dtypesIfXPU, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) flex_attention_supported_platform as supported_platform, instantiate_device_type_tests, largeTensorTest, skipCPUIf, skipCUDAIf, +<<<<<<< HEAD skipXPUIf, ) from torch.testing._internal.inductor_utils import HAS_GPU +======= +) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.utils._triton import has_triton, has_triton_tma_device # Use this decorator only when hitting Triton bugs on H100 running_on_a100_only = skipUnless( +<<<<<<< HEAD ( (torch.cuda.is_available() and has_triton()) and (torch.cuda.get_device_capability() == (8, 0) or torch.version.hip) ) or (torch.xpu.is_available() and has_triton()), "Requires Triton + A100 or Triton + ROCm or Triton + Intel GPU", +======= + (torch.cuda.is_available() and has_triton()) + and (torch.cuda.get_device_capability() == (8, 0) or torch.version.hip), + "Requires Triton + A100 or Triton + ROCm", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) Tolerances = namedtuple("Tolerances", ["atol", "rtol"]) @@ -98,6 +118,7 @@ def temp_float32_matmul_precision(precision: str): Args: precision (str): The precision to set ('highest', 'high', or 'medium'). """ +<<<<<<< HEAD def set_float32_matmul_precision_xpu(precision: str): if precision == "highest": @@ -115,6 +136,14 @@ def set_float32_matmul_precision_xpu(precision: str): torch.set_float32_matmul_precision(original_precision) if TEST_ON_XPU: set_float32_matmul_precision_xpu(original_precision) +======= + original_precision = torch.get_float32_matmul_precision() + try: + torch.set_float32_matmul_precision(precision) + yield + finally: + torch.set_float32_matmul_precision(original_precision) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def skip_on_cpu(test_func): @@ -136,12 +165,15 @@ def skip_on_rocm(test_func): return decorated_func +<<<<<<< HEAD def skip_on_xpu(test_func): """Decorator to skip tests that are not supported on Intel GPU.""" decorated_func = skipXPUIf(True, "Not supported on Intel GPU")(test_func) return decorated_func +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def rmse(ref, res): """ Calculate root mean squared error @@ -149,13 +181,20 @@ def rmse(ref, res): return torch.sqrt(torch.mean(torch.square(ref - res))) +<<<<<<< HEAD def create_attention(score_mod, block_mask, enable_gqa=False, kernel_options=None): +======= +def create_attention(score_mod, block_mask, enable_gqa=False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return functools.partial( flex_attention, score_mod=score_mod, block_mask=block_mask, enable_gqa=enable_gqa, +<<<<<<< HEAD kernel_options=kernel_options, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @@ -182,6 +221,7 @@ class DeviceConfig: and torch.utils._triton.has_triton() and torch.cuda.get_device_capability() >= (8, 0) ) +<<<<<<< HEAD TEST_ON_XPU = torch.xpu.is_available() and torch.utils._triton.has_triton() device_configs = {} @@ -196,6 +236,11 @@ class DeviceConfig: test_device = ("xpu",) else: test_device = ("cpu",) +======= + +device_configs = {} +test_device = ("cpu", "cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class SubstringSet: @@ -205,8 +250,11 @@ def __init__(self, items): def __contains__(self, item): if "cuda" in item: item = "cuda" +<<<<<<< HEAD if "xpu" in item: item = "xpu" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return item in self.items @@ -224,10 +272,13 @@ def __contains__(self, item): ), dtypes_fast=[torch.float16], ) +<<<<<<< HEAD device_configs["xpu"] = DeviceConfig( dtypes=([torch.float32, torch.bfloat16, torch.float16]), dtypes_fast=[torch.float16], ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device_configs["cpu"] = DeviceConfig( dtypes=( [torch.float32, torch.bfloat16, torch.float16] @@ -436,7 +487,11 @@ def batch_reserve(paged_attention: PagedAttention, target_seq_len: Tensor): ) +<<<<<<< HEAD @large_tensor_test_class("2GB", device=test_device[0]) +======= +@large_tensor_test_class("2GB", device="cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestFlexAttention(InductorTestCase): def setUp(self): super().setUp() @@ -702,6 +757,7 @@ def preprocess_paged_attention( paged_attention.assign(batch_idx, input_pos, k, v, k_cache, v_cache) # convert block mask and score mod +<<<<<<< HEAD kv_len_tensor = torch.full((KV_B,), KV_S, device=device, dtype=torch.int64) converted_block_mask = paged_attention.convert_logical_block_mask( block_mask, kv_len=kv_len_tensor @@ -709,6 +765,10 @@ def preprocess_paged_attention( converted_score_mod = paged_attention.get_score_mod( score_mod, kv_len=kv_len_tensor ) +======= + converted_block_mask = paged_attention.convert_logical_block_mask(block_mask) + converted_score_mod = paged_attention.get_score_mod(score_mod) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return k_cache, v_cache, converted_block_mask, converted_score_mod def run_paged_attention( @@ -720,7 +780,10 @@ def run_paged_attention( dtype: torch.dtype, device: str, block_mask: Optional[BlockMask] = None, +<<<<<<< HEAD kernel_options: Optional[dict] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> tuple[Tensor, Tensor]: B, Q_H, Q_S, KV_H, KV_S = ( q.shape[0], @@ -756,7 +819,10 @@ def run_paged_attention( block_mask=converted_block_mask, score_mod=converted_score_mod, enable_gqa=(not Q_H == KV_H), +<<<<<<< HEAD kernel_options=kernel_options, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) else: return_lse = False @@ -769,7 +835,10 @@ def run_paged_attention( block_mask=converted_block_mask, score_mod=converted_score_mod, enable_gqa=(not Q_H == KV_H), +<<<<<<< HEAD kernel_options=kernel_options, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return compiled_out, compiled_lse @@ -1242,7 +1311,10 @@ def run_automatic_dynamic_test( @supported_platform @dtypes(*device_configs["cpu"].dtypes) @dtypesIfCUDA(*device_configs["cuda"].dtypes) +<<<<<<< HEAD @dtypesIfXPU(*device_configs["xpu"].dtypes) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @common_utils.parametrize("score_mod", test_score_mods) def test_builtin_score_mods(self, device, dtype, score_mod: Callable): self.run_test(score_mod, dtype, device=device) @@ -1252,7 +1324,10 @@ def test_builtin_score_mods(self, device, dtype, score_mod: Callable): @common_utils.parametrize("score_mod", test_score_mods) @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) +<<<<<<< HEAD @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_builtin_score_mods_seqlen_lt_default_sparse_block_size( self, device, dtype, score_mod: Callable ): @@ -1267,7 +1342,10 @@ def test_builtin_score_mods_seqlen_lt_default_sparse_block_size( @running_on_a100_only @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) +<<<<<<< HEAD @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @common_utils.parametrize("score_mod", test_score_mods) def test_builtin_score_mods_seqlen_lt_custom_sparse_block_size( self, device, dtype: torch.dtype, score_mod: Callable @@ -1301,7 +1379,10 @@ def causal_mask(b, h, q, kv): @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) +<<<<<<< HEAD @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @common_utils.parametrize("score_mask_mod", test_score_mask_mod_map.items()) def test_builtin_score_mods_dynamic( self, device, dtype: torch.dtype, score_mask_mod: tuple[Callable, Callable] @@ -1311,7 +1392,10 @@ def test_builtin_score_mods_dynamic( @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) +<<<<<<< HEAD @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @common_utils.parametrize("score_mod", test_score_mods) def test_builtin_score_mods_automatic_dynamic( self, device, dtype: torch.dtype, score_mod: Callable @@ -1321,7 +1405,10 @@ def test_builtin_score_mods_automatic_dynamic( @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) +<<<<<<< HEAD @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @common_utils.parametrize("score_mod", test_score_mods) def test_builtin_score_mods_different_seqlen( self, device, dtype: torch.dtype, score_mod: Callable @@ -1345,7 +1432,10 @@ def test_builtin_score_mods_different_seqlen( @supported_platform @dtypes(*device_configs["cpu"].dtypes) @dtypesIfCUDA(*device_configs["cuda"].dtypes) +<<<<<<< HEAD @dtypesIfXPU(*device_configs["xpu"].dtypes) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @common_utils.parametrize("score_mod", test_score_mods) @common_utils.parametrize("BLOCK_SIZE", test_block_size) def test_builtin_score_mods_different_block_size( @@ -1366,7 +1456,10 @@ def test_builtin_score_mods_different_block_size( @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) +<<<<<<< HEAD @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @common_utils.parametrize("batch_dims", test_Bq_Bkv) @common_utils.parametrize("head_dims", test_Hq_Hkv) @common_utils.parametrize("score_mod", test_score_mods) @@ -1437,7 +1530,10 @@ def batch_mask_mod( @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) +<<<<<<< HEAD @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @common_utils.parametrize("batch_dims", test_Bq_Bkv) @common_utils.parametrize("head_dims", test_Hq_Hkv) @common_utils.parametrize("score_mod", test_score_mods) @@ -1468,10 +1564,15 @@ def mask_mod(b, h, q, kv): @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) +<<<<<<< HEAD @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) @common_utils.parametrize("score_mod", test_score_mods) @skip_on_rocm # TODO: NaNs on ROCM @skip_on_xpu # TODO: NaNs on XPU like ROCM, need another PR to fix. +======= + @common_utils.parametrize("score_mod", test_score_mods) + @skip_on_rocm # TODO: NaNs on ROCM +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_GQA(self, device, dtype: torch.dtype, score_mod: Callable): inputs = ( score_mod, @@ -1492,7 +1593,10 @@ def test_GQA(self, device, dtype: torch.dtype, score_mod: Callable): @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) +<<<<<<< HEAD @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @common_utils.parametrize( "q_s", test_strides[:2] ) # TODO: fix layout for query braodcasting @@ -1535,6 +1639,7 @@ def coerce_to_strides(val, shape, strides): v = coerce_to_strides(v1, v_shape, v_s) do = coerce_to_strides(do1, do_shape, do_s) +<<<<<<< HEAD kernel_options = {"USE_TMA": True} block_mask = _create_empty_block_mask(q, k) @@ -1542,6 +1647,11 @@ def coerce_to_strides(val, shape, strides): sdpa_partial = create_attention( score_mod=score_mod, block_mask=block_mask, kernel_options=kernel_options ) +======= + block_mask = _create_empty_block_mask(q, k) + score_mod = _generate_alibi_bias(8) + sdpa_partial = create_attention(score_mod=score_mod, block_mask=block_mask) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) compiled_sdpa = torch.compile(sdpa_partial, fullgraph=True) ref_out = sdpa_partial(q, k, v) compiled_out = compiled_sdpa(q, k, v) @@ -1584,7 +1694,11 @@ def coerce_to_strides(val, shape, strides): # test paged attention which does not support backward q.requires_grad, k.requires_grad, v.requires_grad = False, False, False paged_compiled_out, _ = self.run_paged_attention( +<<<<<<< HEAD score_mod, q, k, v, dtype, device=device, kernel_options=kernel_options +======= + score_mod, q, k, v, dtype, device=device +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) torch.testing.assert_close( ref_out, paged_compiled_out, atol=tolerance.atol, rtol=tolerance.rtol @@ -1640,7 +1754,10 @@ def index_weird2(score, b, h, q_idx, kv_idx): @supported_platform @dtypes(*device_configs["cpu"].dtypes) @dtypesIfCUDA(*device_configs["cuda"].dtypes) +<<<<<<< HEAD @dtypesIfXPU(*device_configs["xpu"].dtypes) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_skip_odd_keys(self, device, dtype: torch.dtype): def score_mod(score, b, h, q, kv): return torch.where(kv % 2 == 0, score, float("-inf")) @@ -1651,7 +1768,10 @@ def score_mod(score, b, h, q, kv): @supported_platform @dtypes(*device_configs["cpu"].dtypes) @dtypesIfCUDA(*device_configs["cuda"].dtypes) +<<<<<<< HEAD @dtypesIfXPU(*device_configs["xpu"].dtypes) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_function_composition(self, device, dtype: torch.dtype): def score_mod_1(score, b, h, m, n): return score + (m - n) @@ -1668,7 +1788,10 @@ def composed_score_mod(score, b, h, m, n): @supported_platform @dtypes(*device_configs["cpu"].dtypes) @dtypesIfCUDA(*device_configs["cuda"].dtypes) +<<<<<<< HEAD @dtypesIfXPU(*device_configs["xpu"].dtypes) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_captured_buffers_all_dims(self, device, dtype: torch.dtype): head_scale = torch.randn(H, device=device) batch_scale = torch.randn(B, device=device) @@ -1686,7 +1809,10 @@ def all_bias(score, batch, head, token_q, token_kv): @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) +<<<<<<< HEAD @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_seq_masking(self, device, dtype): seq_idx = torch.zeros(S, device=device, dtype=torch.bool) seq_idx[S // 2 :] = 1 @@ -1700,7 +1826,10 @@ def seq_mask_mod(score, b, h, q, kv): @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) +<<<<<<< HEAD @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_load_from_bias_seq_only(self, device, dtype): bias = torch.randn(S, S, device=device, dtype=dtype) @@ -1713,7 +1842,10 @@ def bias_mod(score, b, h, q, kv): @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) +<<<<<<< HEAD @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_load_from_bias_seq_batch(self, device, dtype): bias = torch.randn(B, S, S, device=device, dtype=dtype) @@ -1773,7 +1905,10 @@ def add_decomposed_rel_pos(self, q): @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) +<<<<<<< HEAD @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_load_from_bias_head_seq_batch(self, device, dtype): bias = torch.randn(B, H, S, S, device=device, dtype=dtype) @@ -1786,7 +1921,10 @@ def bias_mod(score, b, h, q, kv): @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) +<<<<<<< HEAD @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_load_rel_bias(self, device, dtype): rel_bias = torch.randn(2 * S, device=device, dtype=dtype) @@ -1799,7 +1937,10 @@ def bias_mod(score, b, h, q, kv): @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) +<<<<<<< HEAD @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_dependent_causal_bidirectional(self, device, dtype): num_bidirectional = torch.randint(0, S, (B,), device=device, dtype=torch.int32) @@ -1821,7 +1962,10 @@ def bias_mod(score, b, h, q, kv): @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) +<<<<<<< HEAD @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_natten_2d(self, device, dtype): H = 32 W = S // H @@ -1890,7 +2034,10 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) +<<<<<<< HEAD @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_silu_on_score(self, device, dtype): def silu_score(score, b, h, q, kv): return torch.nn.functional.silu(score) @@ -1901,7 +2048,10 @@ def silu_score(score, b, h, q, kv): @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) +<<<<<<< HEAD @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_padded_dense_causal(self, device, dtype): seq_len = torch.arange(B, device=device, dtype=torch.int32) + 1 @@ -1920,7 +2070,10 @@ def njt_score_mod(qk, b, h, q, kv): @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) +<<<<<<< HEAD @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_captured_scale(self, device, dtype): scale = torch.ones((), device=device, dtype=torch.int32) @@ -1933,7 +2086,10 @@ def score_mod_scale(qk, b, h, q, kv): @supported_platform @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) +<<<<<<< HEAD @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_recompile_changed_score_mod(self, device, dtype): scale = torch.ones((), device=device, dtype=torch.int32) ADD = True @@ -1955,7 +2111,10 @@ def score_mod_scale(qk, b, h, q, kv): @expectedFailure # If we capture a tensor then we can perform a reduction on it, and that shouldn't be allowed @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) +<<<<<<< HEAD @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_captured_reduction(self, device, dtype): scale = torch.randn((B, 8), device=device) @@ -1965,6 +2124,7 @@ def score_mod_scale(qk, b, h, q, kv): self.run_test(score_mod_scale, dtype, device=device) @supported_platform +<<<<<<< HEAD @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) @common_utils.parametrize( @@ -2237,6 +2397,8 @@ def test_shape(S, backend): _ = [test_shape(S, backend) for S in test_shapes] @supported_platform +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_multiple_score_mod_calls(self, device): query = torch.randn((1, 8, 1024, 64), dtype=torch.float32, device=device) keys = [ @@ -2643,7 +2805,10 @@ def f(q, k, v): @supported_platform @dtypes(*device_configs["cpu"].dtypes) @dtypesIfCUDA(*device_configs["cuda"].dtypes) +<<<<<<< HEAD @dtypesIfXPU(*device_configs["xpu"].dtypes) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_njt_causal(self, device, dtype): offsets = torch.tensor( [0, 1024, 1024 + 512, S], device=device, dtype=torch.int32 @@ -2685,12 +2850,15 @@ def score_mod(score, b, h, m, n): self.run_test_with_paged_attention( score_mod, dtype=torch.float16, device=device ) +<<<<<<< HEAD self.run_test_with_paged_attention( score_mod=score_mod, dtype=torch.bfloat16, KV_S=64, device=device, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @supported_platform @skip("TODO: Figure out why this is erroring") @@ -2712,7 +2880,10 @@ def bias_mod(score, batch, head, token_q, token_kv): @common_utils.parametrize("score_mod", test_score_mods) @dtypes(*device_configs["cpu"].dtypes) @dtypesIfCUDA(*device_configs["cuda"].dtypes) +<<<<<<< HEAD @dtypesIfXPU(*device_configs["xpu"].dtypes) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @common_utils.parametrize("head_dims", [(D, D // 2), (D // 2, D)]) def test_non_equal_head_dims(self, device, dtype, score_mod, head_dims): qk_d, v_d = head_dims @@ -2806,7 +2977,10 @@ def causal(b, h, q_idx, kv_idx): @common_utils.parametrize("head_dim", [17, 24, 94, 121]) @dtypes(*device_configs["cpu"].dtypes_fast) @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast) +<<<<<<< HEAD @dtypesIfXPU(*device_configs["xpu"].dtypes_fast) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_non_pow_2_headdim(self, device, dtype, head_dim): self.run_test(_rel_bias, dtype, device, B, H, S, head_dim, B, H, S, head_dim) @@ -2871,7 +3045,10 @@ def causal_constructor(S): @skip_on_cpu @dtypes(*device_configs["cpu"].dtypes) @dtypesIfCUDA(*device_configs["cuda"].dtypes) +<<<<<<< HEAD @dtypesIfXPU(*device_configs["xpu"].dtypes) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @common_utils.parametrize("score_mod", [_identity, _causal]) def test_logsumexp_correctness(self, device, dtype, score_mod): make_tensor = functools.partial( @@ -3007,7 +3184,13 @@ def test_differentiable_logsumexp_gradcheck(self, device): def flex_attention_lse_only(q, k, v): return flex_attention(q, k, v, return_lse=True)[1] +<<<<<<< HEAD func = torch.compile(flex_attention_lse_only, backend="aot_eager") +======= + func = torch.compile( + flex_attention_lse_only, backend="aot_eager", fullgraph=True + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue( torch.autograd.gradcheck(func, (query, key, value), raise_exception=True) @@ -3033,7 +3216,13 @@ def test_differentiable_logsumexp_compiled(self, device): k.grad = None v.grad = None +<<<<<<< HEAD out2, lse2 = torch.compile(flex_attention)(q, k, v, return_lse=True) +======= + out2, lse2 = torch.compile(flex_attention, fullgraph=True)( + q, k, v, return_lse=True + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) (out2.mean() + (lse2 * lse_mask).sum()).backward() q_grad2, k_grad2, v_grad2 = q.grad, k.grad, v.grad tolerance = Tolerances(atol=1e-1, rtol=1e-1) @@ -3080,7 +3269,11 @@ def mask(b, h, q, kv): ) q, k, v = make_tensor2(), make_tensor2(), make_tensor2() +<<<<<<< HEAD # Compile 2nd version with q/k/v(seqlen=2048) and block_mask(seqlen=4096), +======= + # Compile 2st version with q/k/v(seqlen=2048) and block_mask(seqlen=4096), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # The graph includes the BlockMask._adjust part. out = torch.compile(flex_attention, dynamic=True, fullgraph=True)( q, k, v, block_mask=block_mask @@ -3219,7 +3412,10 @@ def test_strided_backwards(self, device): torch.testing.assert_close(eager, compiled, atol=9e-3, rtol=0) @supported_platform +<<<<<<< HEAD @skip_on_cpu +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @common_utils.parametrize("mode", ["eager", "inductor", "paged_attention"]) @common_utils.parametrize( "permute_order", @@ -3235,11 +3431,14 @@ def test_strided_backwards(self, device): def test_flex_attention_stride_ordering(self, device, mode, permute_order, shape): from torch._inductor.ir import get_stride_order +<<<<<<< HEAD if torch.version.hip and mode == "paged_attention": raise self.skipTest( "TODO: figure out why mode_paged_attention_permute_order3_shape0 on MI200 caused mem fault" ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dtype = torch.float32 # Setup requires_grad = device in DEVICE_SUPPORTS_BACKWARDS @@ -3330,7 +3529,11 @@ def test_flex_attention_backward_stride_ordering( def test_non_contiguous_last_dim(self, device): """Test flex_attention with tensors having non contiguous last dimension.""" B, H, D = 4, 8, 64 +<<<<<<< HEAD dtype = torch.float16 if device in DEVICE_SUPPORTS_BACKWARDS else torch.float32 +======= + dtype = torch.float16 if device == "cuda" else torch.float32 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for S in [16, 64]: def column_major_tensor(): @@ -3552,7 +3755,11 @@ def test_force_write_lse(self, device): query, key, value = make_tensor(), make_tensor(), make_tensor() out_eager, lse_eager = flex_attention(query, key, value, return_lse=True) +<<<<<<< HEAD flex_compile = torch.compile(flex_attention) +======= + flex_compile = torch.compile(flex_attention, fullgraph=True) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out_compiled, lse_compiled = flex_compile(query, key, value, return_lse=True) out_paged, lse_paged = self.run_paged_attention( @@ -3560,9 +3767,13 @@ def test_force_write_lse(self, device): ) torch.testing.assert_close(lse_eager, lse_compiled, atol=3e-3, rtol=0) +<<<<<<< HEAD requires_grad = device in DEVICE_SUPPORTS_BACKWARDS if requires_grad: torch.testing.assert_close(lse_eager, lse_paged, atol=3e-3, rtol=0) +======= + torch.testing.assert_close(lse_eager, lse_paged, atol=3e-3, rtol=0) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @supported_platform @skip_on_cpu @@ -4086,9 +4297,13 @@ def causal_mask(b, h, q_idx, kv_idx): self.assertEqual(len(cnt.graphs), 1) graph = cnt.graphs[0] norm_graph = normalize_gm(graph.print_readable(print_output=False)) +<<<<<<< HEAD self.assertExpectedInline( norm_graph, """\ +======= + expected_graph = """\ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class GraphModule(torch.nn.Module): def forward(self, L_query_: "f64[2, 2, 128, 4]", L_key_: "f64[2, 2, 128, 4]", L_value_: "f64[2, 2, 128, 4]", L_block_mask_kv_indices: "i32[1, 1, 1, 1]", L_block_mask_kv_num_blocks: "i32[1, 1, 1]", L_block_mask_full_kv_num_blocks: "i32[1, 1, 1]", L_block_mask_full_kv_indices: "i32[1, 1, 1, 1]", L_block_mask_q_num_blocks: "i32[1, 1, 1]", L_block_mask_q_indices: "i32[1, 1, 1, 1]", L_block_mask_full_q_num_blocks: "i32[1, 1, 1]", L_block_mask_full_q_indices: "i32[1, 1, 1, 1]"): l_query_ = L_query_ @@ -4105,7 +4320,11 @@ def forward(self, L_query_: "f64[2, 2, 128, 4]", L_key_: "f64[2, 2, 128, 4]", L_ score_mod_0 = self.score_mod_0 mask_fn_0 = self.mask_fn_0 +<<<<<<< HEAD flex_attention = torch.ops.higher_order.flex_attention(l_query_, l_key_, l_value_, score_mod_0, (128, 128, l_block_mask_kv_num_blocks, l_block_mask_kv_indices, l_block_mask_full_kv_num_blocks, l_block_mask_full_kv_indices, l_block_mask_q_num_blocks, l_block_mask_q_indices, l_block_mask_full_q_num_blocks, l_block_mask_full_q_indices, 128, 128, mask_fn_0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), ()); l_query_ = l_key_ = l_value_ = score_mod_0 = l_block_mask_kv_num_blocks = l_block_mask_kv_indices = l_block_mask_full_kv_num_blocks = l_block_mask_full_kv_indices = l_block_mask_q_num_blocks = l_block_mask_q_indices = l_block_mask_full_q_num_blocks = l_block_mask_full_q_indices = mask_fn_0 = None +======= + flex_attention = torch.ops.higher_order.flex_attention(l_query_, l_key_, l_value_, score_mod_0, (128, 128, l_block_mask_kv_num_blocks, l_block_mask_kv_indices, l_block_mask_full_kv_num_blocks, l_block_mask_full_kv_indices, l_block_mask_q_num_blocks, l_block_mask_q_indices, l_block_mask_full_q_num_blocks, l_block_mask_full_q_indices, 128, 128, mask_fn_0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True}, (), ()); l_query_ = l_key_ = l_value_ = score_mod_0 = l_block_mask_kv_num_blocks = l_block_mask_kv_indices = l_block_mask_full_kv_num_blocks = l_block_mask_full_kv_indices = l_block_mask_q_num_blocks = l_block_mask_q_indices = l_block_mask_full_q_num_blocks = l_block_mask_full_q_indices = mask_fn_0 = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out: "f64[2, 2, 128, 4]" = flex_attention[0]; flex_attention = None return (out,) @@ -4118,7 +4337,14 @@ class mask_fn_0(torch.nn.Module): def forward(self, child: "i32[]", child_1: "i32[]", child_2: "i32[]", child_3: "i32[]"): ge: "b8[]" = child_2 >= child_3; child_2 = child_3 = None return ge +<<<<<<< HEAD """, # noqa: B950 +======= +""" + self.assertExpectedInline( + norm_graph, + expected_graph, # noqa: B950 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # Save the AOT graphs aot_graphs = [] @@ -4136,6 +4362,7 @@ def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs): out.sum().backward() joint_graph = normalize_gm(aot_graphs[1].print_readable(print_output=False)) +<<<<<<< HEAD self.assertExpectedInline( joint_graph, """\ @@ -4150,6 +4377,20 @@ def forward(self, primals_1: "f64[2, 2, 128, 4]", primals_2: "f64[2, 2, 128, 4]" getitem_6: "f64[2, 2, 128, 4]" = flex_attention_backward[1] getitem_7: "f64[2, 2, 128, 4]" = flex_attention_backward[2]; flex_attention_backward = None return (getitem_5, getitem_6, getitem_7) +======= + expected_joint_graph = """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_1: "f64[2, 2, 128, 4]", primals_2: "f64[2, 2, 128, 4]", primals_3: "f64[2, 2, 128, 4]", full: "i32[1, 1, 1]", full_default: "i32[1, 1, 1, 1]", convert_element_type: "i32[1, 1, 1]", convert_element_type_1: "i32[1, 1, 1, 1]", getitem_2: "f64[2, 2, 128, 4]", getitem_3: "f32[2, 2, 128]", tangents_1: "f64[2, 2, 128, 4]"): + full_default_4: "f32[2, 2, 128]" = torch.ops.aten.full.default([2, 2, 128], 0, dtype = torch.float32, layout = torch.strided, device = device(type='GPU_TYPE', index=0), pin_memory = False) + fw_graph0 = self.fw_graph0 + joint_graph0 = self.joint_graph0 + mask_graph0 = self.mask_graph0 + flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem_2, getitem_3, tangents_1, full_default_4, fw_graph0, joint_graph0, (1, 1, full, full_default, None, None, convert_element_type, convert_element_type_1, None, None, 1073741824, 1073741824, mask_graph0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True}, (), ()); primals_1 = primals_2 = primals_3 = getitem_2 = getitem_3 = tangents_1 = full_default_4 = fw_graph0 = joint_graph0 = full = full_default = convert_element_type = convert_element_type_1 = mask_graph0 = None + getitem_4: "f64[2, 2, 128, 4]" = flex_attention_backward[0] + getitem_5: "f64[2, 2, 128, 4]" = flex_attention_backward[1] + getitem_6: "f64[2, 2, 128, 4]" = flex_attention_backward[2]; flex_attention_backward = None + return (getitem_4, getitem_5, getitem_6) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class fw_graph0(torch.nn.Module): def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]", arg4_1: "i32[]"): @@ -4168,8 +4409,17 @@ def forward(self, arg0_1: "i32[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i3 full_default: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) return full_default """.replace( # noqa: B950 +<<<<<<< HEAD "GPU_TYPE", torch.device(device).type ), +======= + "GPU_TYPE", torch.device(device).type + ) + + self.assertExpectedInline( + joint_graph, + expected_joint_graph, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @supported_platform @@ -4255,7 +4505,11 @@ def flex_attention_as_strided_error_tensor( mask_mod_other_buffers=(), ): inner_q, inner_k, inner_v = query.elem, key.elem, value.elem +<<<<<<< HEAD out, lse, max_scores = flex_attention_hop( +======= + out, lse = flex_attention_hop( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inner_q, inner_k, inner_v, @@ -4266,11 +4520,15 @@ def flex_attention_as_strided_error_tensor( score_mod_other_buffers, mask_mod_other_buffers, ) +<<<<<<< HEAD return ( AsStridedErrorTensor(out), AsStridedErrorTensor(lse), AsStridedErrorTensor(max_scores), ) +======= + return AsStridedErrorTensor(out), AsStridedErrorTensor(lse) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Test setup B, H, S, D = 2, 1, 128, 16 @@ -4291,7 +4549,11 @@ def flex_attention_as_strided_error_tensor( ) # Test 2: Run flex_attention with normal tensors first +<<<<<<< HEAD compiled_fn = torch.compile(flex_attention, backend="aot_eager") +======= + compiled_fn = torch.compile(flex_attention, backend="aot_eager", fullgraph=True) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) normal_out, normal_lse = compiled_fn( query_elem, key_elem, value_elem, return_lse=True ) @@ -4453,9 +4715,15 @@ def flex_attn_fn(x): return output flex_module = SacModule(hidden_size=512, num_heads=8, context_fn=context_fn).to( +<<<<<<< HEAD device, dtype=torch.bfloat16 ) x = torch.ones(8, 1024, 512, device=device, dtype=torch.bfloat16) +======= + "cuda", dtype=torch.bfloat16 + ) + x = torch.ones(8, 1024, 512, device="cuda", dtype=torch.bfloat16) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Run without compilation output_module = flex_module(x) @@ -4551,11 +4819,19 @@ def make_tensor(): @supported_platform @skip_on_cpu @skipCUDAIf(not has_triton_tma_device(), "Requires TMA enabled CUDA device") +<<<<<<< HEAD def test_tma_with_customer_kernel_options(self, device): make_tensor = functools.partial( torch.ones, (1, 1, 256, 128), device=device, +======= + def test_tma_with_customer_kernel_options(self): + make_tensor = functools.partial( + torch.ones, + (1, 1, 256, 128), + device="cuda", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dtype=torch.bfloat16, ) query, key, value = make_tensor(), make_tensor(), make_tensor() @@ -4576,6 +4852,7 @@ def test_tma_with_customer_kernel_options(self, device): # vanilla compiled vs TMA compiled torch.testing.assert_close(out_tma_compiled, out_compiled, atol=2e-1, rtol=2e-1) +<<<<<<< HEAD @supported_platform @skip_on_cpu def test_large_batch_heads_grid_dimension(self, device): @@ -4652,6 +4929,8 @@ def simple_score_mod(score, b, h, q_idx, kv_idx): fa._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = original_flag fa._WARNINGS_SHOWN = original_warnings_shown +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestBlockMask(InductorTestCase): def setUp(self): @@ -4666,8 +4945,13 @@ def causal_mask(b, h, q, kv): block_mask = create_block_mask(causal_mask, 4, 2, 2048, 2048, device=device) self.assertEqual(block_mask.shape, (4, 2, 2048, 2048)) +<<<<<<< HEAD self.assertEqual(block_mask[0].shape, (1, 2, 2048, 2048)) self.assertEqual(block_mask[0, 0].shape, (1, 1, 2048, 2048)) +======= + self.assertEqual(block_mask[0].shape, (2, 2048, 2048)) + self.assertEqual(block_mask[0, 0].shape, (2048, 2048)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(block_mask.numel(), 4 * 2 * 2048 * 2048) self.assertEqual(block_mask.sparsity(), 46.875) self.assertEqual(block_mask[0].sparsity(), 46.875) @@ -4711,6 +4995,7 @@ def causal_mask(b, h, q, kv): # Index on batch dimension new_block_mask = block_mask[0] +<<<<<<< HEAD assert new_block_mask.kv_num_blocks.shape == (1, 2, 4) assert new_block_mask.kv_indices.shape == (1, 2, 4, 4) @@ -4731,6 +5016,15 @@ def causal_mask(b, h, q, kv): 4, ) assert new_block_mask.kv_indices.shape == (1, 1, 4, 4) +======= + assert new_block_mask.kv_num_blocks.shape == (2, 4) + assert new_block_mask.kv_indices.shape == (2, 4, 4) + + # Index on batch and head dimension + new_block_mask = block_mask[0, 1] + assert new_block_mask.kv_num_blocks.shape == (4,) + assert new_block_mask.kv_indices.shape == (4, 4) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # slicing on batch and head dimension new_block_mask = block_mask[0:2, 1:2] @@ -5039,6 +5333,47 @@ def test_init_mismatched_full_q(self, device): ) @supported_platform +<<<<<<< HEAD +======= + @common_utils.parametrize("compile", [False, True]) + def test_no_q_info(self, device, compile: bool): + def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + block_mask = create_block_mask(causal_mask, 1, 1, 2048, 2048, device=device) + # manually set q_num_blocks and q_indices to None + block_mask.q_num_blocks = None + block_mask.q_indices = None + block_mask.full_q_num_blocks = None + block_mask.full_q_indices = None + + mask_mod_sparse_flex = functools.partial(flex_attention, block_mask=block_mask) + if compile: + mask_mod_sparse_flex = torch.compile( + mask_mod_sparse_flex, backend="inductor" + ) + inputs = [ + torch.randn( + 2, + 2, + 2048, + 64, + device=device, + dtype=torch.float16, + requires_grad=True, + ) + for _ in range(3) + ] + + causal_mask_out = mask_mod_sparse_flex(*inputs) + sdpa_mask_out = torch.nn.functional.scaled_dot_product_attention( + *inputs, is_causal=True + ) + + torch.testing.assert_close(causal_mask_out, sdpa_mask_out, atol=5e-3, rtol=0.0) + + @supported_platform +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_doc_mask_clamped_repro(self, device): def _offsets_to_doc_ids_tensor(offsets): device = offsets.device @@ -5152,7 +5487,10 @@ def flex_attention_fn(): ) @supported_platform +<<<<<<< HEAD @skip_on_xpu +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_create_is_cuda_graphable(self, device): def mask_mod(b, h, q, kv): return q >= kv @@ -5193,6 +5531,7 @@ def create_inputs(S): with self.assertRaisesRegex(ValueError, "block_mask was created for"): flex_attention_call(*create_inputs(1024), block_mask=block_mask) +<<<<<<< HEAD @supported_platform @common_utils.parametrize("full_indices", [False, True]) def test_from_kv_blocks_without_q_computation(self, device, full_indices: bool): @@ -5487,6 +5826,10 @@ def _mask_mod(b, h, q, kv): @large_tensor_test_class("2GB", device=test_device[0]) +======= + +@large_tensor_test_class("2GB", device="cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestPagedAttention(InductorTestCase): def setUp(self): super().setUp() @@ -5608,12 +5951,16 @@ def causal_mask(b, h, q, kv): block_mask = create_block_mask( causal_mask, max_batch_size, 1, max_seq_len, max_seq_len, device=device ) +<<<<<<< HEAD kv_len_tensor = torch.full( (max_batch_size,), max_seq_len, device=device, dtype=torch.int64 ) new_block_mask = paged_cache.convert_logical_block_mask( block_mask, kv_len=kv_len_tensor ) +======= + new_block_mask = paged_cache.convert_logical_block_mask(block_mask) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) zeros = [0, 0, 0, 0] # Check that the new block mask is correct @@ -5806,7 +6153,10 @@ def test_update(self, device): @supported_platform @dtypes(*device_configs["cpu"].dtypes) @dtypesIfCUDA(*device_configs["cuda"].dtypes) +<<<<<<< HEAD @dtypesIfXPU(*device_configs["xpu"].dtypes) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @common_utils.parametrize("score_mod", test_score_mods) def test_paged_builtin_score_mods( self, device, dtype: torch.dtype, score_mod: Callable @@ -5889,6 +6239,7 @@ def causal_mask(b, h, q, kv): ) paged_cache.assign(batch_idx, input_pos, k, v, k_cache, v_cache) +<<<<<<< HEAD kv_len_tensor = torch.full( (max_batch_size,), max_seq_len, device=device, dtype=torch.int64 ) @@ -5901,6 +6252,13 @@ def causal_mask(b, h, q, kv): paged_cache.get_score_mod(score_mod, kv_len=kv_len_tensor), block_mask, enable_gqa=False, +======= + new_block_mask = paged_cache.convert_logical_block_mask(block_mask) + + compiled_sdpa = torch.compile( + create_attention( + paged_cache.get_score_mod(score_mod), block_mask, enable_gqa=False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) ) paged_out = compiled_sdpa(q, k_cache, v_cache, block_mask=new_block_mask) @@ -5942,16 +6300,25 @@ def get_params(dtypes: list[torch.dtype]) -> list[Params]: supports_learnable_bias = unittest.skipUnless( +<<<<<<< HEAD ( (torch.cuda.is_available() and has_triton()) and (torch.cuda.get_device_capability() >= (8, 0) or torch.version.hip) ), +======= + (torch.cuda.is_available() and has_triton()) + and (torch.cuda.get_device_capability() >= (8, 0) or torch.version.hip), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "Requires Triton + A100 or Triton + ROCm", ) @supports_learnable_bias +<<<<<<< HEAD @large_tensor_test_class("2GB", device=test_device[0]) +======= +@large_tensor_test_class("2GB", device="cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestLearnableBiases(InductorTestCase): def setUp(self): super().setUp() @@ -6004,7 +6371,11 @@ def _gold_check(self, eager, compiled, gold, tensor_name, fudge_factor=1.35): def _check_outputs_and_grads( self, out_eager, out_compiled, out_gold, tensors, names=None ): +<<<<<<< HEAD backwards_grad = torch.randn_like(out_eager, device="cpu").to(out_eager.device) +======= + backwards_grad = torch.randn_like(out_eager) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) grads_eager = torch.autograd.grad((out_eager,), tensors, backwards_grad) grads_compiled = torch.autograd.grad((out_compiled,), tensors, backwards_grad) grads_gold = torch.autograd.grad((out_gold,), tensors, backwards_grad) @@ -6500,6 +6871,7 @@ def bias_func(score, b, h, q_idx, kv_idx): @common_utils.parametrize( "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}" ) +<<<<<<< HEAD @torch.compile def test_learnable_bias_global_compiled(self, device, params): batch_size = 1 @@ -6579,6 +6951,8 @@ def score_mod(score, b, h, q_idx, kv_idx): @common_utils.parametrize( "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}" ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_relative_1d_bias_only_grad(self, device, params): query, key, value = self._init_tensors(params, device=device) query = query.detach().requires_grad_(False) @@ -6921,6 +7295,7 @@ def _test_learnable_bias_inner( ) +<<<<<<< HEAD instantiate_device_type_tests( TestFlexAttention, globals(), only_for=test_device, allow_xpu=True ) @@ -6937,6 +7312,12 @@ def _test_learnable_bias_inner( TestLearnableBiases, globals(), only_for=test_device, allow_xpu=True ) +======= +instantiate_device_type_tests(TestFlexAttention, globals(), only_for=test_device) +instantiate_device_type_tests(TestPagedAttention, globals(), only_for=test_device) +instantiate_device_type_tests(TestBlockMask, globals(), only_for=("cuda",)) +instantiate_device_type_tests(TestLearnableBiases, globals(), only_for=test_device) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_flex_decoding.py b/test/inductor/test_flex_decoding.py index 849aefff8a965..73c4ffa6e6443 100644 --- a/test/inductor/test_flex_decoding.py +++ b/test/inductor/test_flex_decoding.py @@ -2,7 +2,10 @@ # flake8: noqa: B950 import functools +<<<<<<< HEAD import sys +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import unittest from collections import namedtuple from typing import Callable, Optional, Union @@ -23,6 +26,7 @@ ) from torch.testing import FileCheck from torch.testing._internal import common_utils +<<<<<<< HEAD from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16, with_tf32_off from torch.testing._internal.common_device_type import ( flex_attention_supported_platform as supported_platform, @@ -47,6 +51,20 @@ torch.set_float32_matmul_precision("highest") else: torch.set_float32_matmul_precision("high") +======= +from torch.testing._internal.common_cuda import ( + PLATFORM_SUPPORTS_BF16, + PLATFORM_SUPPORTS_FLASH_ATTENTION, +) +from torch.testing._internal.common_device_type import ( + flex_attention_supported_platform as supported_platform, + instantiate_device_type_tests, +) + + +Tolerances = namedtuple("Tolerances", ["atol", "rtol"]) +torch.set_float32_matmul_precision("high") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) index = torch.ops.aten.index Tensor = torch.Tensor @@ -56,6 +74,7 @@ and torch.utils._triton.has_triton() and torch.cuda.get_device_capability() >= (8, 0) ) +<<<<<<< HEAD TEST_ON_XPU = torch.xpu.is_available() and torch.utils._triton.has_triton() if HAS_GPU: @@ -74,6 +93,18 @@ test_dtypes = [torch.float32, torch.bfloat16, torch.float16] test_dtypes_fast = [torch.float16] SKIP_UT_ON_CPU = False +======= + +if TEST_ON_CUDA: + test_device = ("cuda",) + test_dtypes = ( + [torch.float32, torch.bfloat16, torch.float16] + if PLATFORM_SUPPORTS_BF16 + else [torch.float16, torch.float32] + ) + test_dtypes_fast = [torch.float16] + SKIP_UT_ON_CPU = False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: test_device = ("cpu",) torch_config_string = torch.__config__.show() @@ -93,6 +124,7 @@ test_dtypes_fast = [torch.float32] +<<<<<<< HEAD def skip_on_xpu(test_func): """Decorator to skip tests that are not supported on Intel GPU.""" decorated_func = skipXPUIf(True, "Not supported on Intel GPU")(test_func) @@ -100,12 +132,18 @@ def skip_on_xpu(test_func): def create_attention(score_mod, block_mask, enable_gqa=False, kernel_options=None): +======= +def create_attention(score_mod, block_mask, enable_gqa=False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return functools.partial( flex_attention, score_mod=score_mod, block_mask=block_mask, enable_gqa=enable_gqa, +<<<<<<< HEAD kernel_options=kernel_options, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @@ -378,7 +416,10 @@ def run_test( V_D: int = D, block_mask: Optional[BlockMask] = None, device="cuda", +<<<<<<< HEAD kernel_options=None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): assert score_mod is not None or block_mask is not None, ( "Must provide score_mod or block_mask" @@ -409,10 +450,14 @@ def run_test( q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64) sdpa_partial = create_attention( +<<<<<<< HEAD score_mod, block_mask, enable_gqa=(not Q_H == KV_H), kernel_options=kernel_options, +======= + score_mod, block_mask, enable_gqa=(not Q_H == KV_H) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) compiled_sdpa = torch.compile(sdpa_partial) if not self.test_inference_only: @@ -559,6 +604,7 @@ def preprocess_paged_attention( paged_attention.assign(batch_idx, input_pos, k, v, k_cache, v_cache) # convert block mask and score mod +<<<<<<< HEAD kv_len_tensor = torch.full((KV_B,), KV_S, device=device, dtype=torch.int64) converted_block_mask = paged_attention.convert_logical_block_mask( block_mask, kv_len=kv_len_tensor @@ -566,6 +612,10 @@ def preprocess_paged_attention( converted_score_mod = paged_attention.get_score_mod( score_mod, kv_len=kv_len_tensor ) +======= + converted_block_mask = paged_attention.convert_logical_block_mask(block_mask) + converted_score_mod = paged_attention.get_score_mod(score_mod) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return k_cache, v_cache, converted_block_mask, converted_score_mod @@ -747,22 +797,37 @@ def run_test_with_call_paged_attention( ) @supported_platform +<<<<<<< HEAD @expectedFailure # tl.dot does not support embedding size less than 16 @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported") @common_utils.parametrize("dtype", test_dtypes_fast) def test_bw_decoding_fails(self, device, dtype): +======= + @expectedFailure + @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported") + @common_utils.parametrize("dtype", test_dtypes_fast) + def test_bw_decoding_fails(self, dtype): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) make_kv = functools.partial( torch.randn, (2, 2, 128, 4), dtype=dtype, +<<<<<<< HEAD device=device, +======= + device="cuda", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) requires_grad=True, ) make_q = functools.partial( torch.randn, (2, 2, 8, 4), dtype=dtype, +<<<<<<< HEAD device=device, +======= + device="cuda", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) requires_grad=True, ) q, k, v, backward_grad = make_q(), make_kv(), make_kv(), make_q() @@ -781,7 +846,10 @@ def sdpa_hop(q, k, v, score_mod, block_mask): @common_utils.parametrize("dtype", test_dtypes) @common_utils.parametrize("score_mod", test_score_mods) @common_utils.parametrize("head_dims", test_Hq_Hkv) +<<<<<<< HEAD @with_tf32_off +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_builtin_score_mods( self, device, dtype: torch.dtype, score_mod: Callable, head_dims ): @@ -849,6 +917,7 @@ def test_builtin_score_mods_different_block_size( ) self.run_test(score_mod, dtype, block_mask=block_mask, device=device) +<<<<<<< HEAD @unittest.skipIf(not has_triton_tma_device(), "Skip when TMA is not available") @common_utils.parametrize("dtype", test_dtypes_fast) def test_tma_decoding(self, device, dtype: torch.dtype): @@ -871,6 +940,8 @@ def test_tma_decoding(self, device, dtype: torch.dtype): kernel_options=kernel_options, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @supported_platform @common_utils.parametrize("dtype", test_dtypes_fast) @common_utils.parametrize("k_s", test_input_strides) @@ -1055,12 +1126,21 @@ def mask_mod(b, h, q, kv): @supported_platform @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported") +<<<<<<< HEAD def test_non_divisible_multi_token_offset_mask_with_captured_buffer(self, device): KV_S = S - 3 Q_S = 3 offset_kv = torch.randn(KV_S, device=device, dtype=torch.bfloat16) offset_q = torch.randn(Q_S, device=device, dtype=torch.bfloat16) offset_tensor = torch.tensor(S // 2 - 3, device=device, dtype=torch.int32) +======= + def test_non_divisible_multi_token_offset_mask_with_captured_buffer(self): + KV_S = S - 3 + Q_S = 3 + offset_kv = torch.randn(KV_S, device="cuda", dtype=torch.bfloat16) + offset_q = torch.randn(Q_S, device="cuda", dtype=torch.bfloat16) + offset_tensor = torch.tensor(S // 2 - 3, device="cuda", dtype=torch.int32) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def score_mod(score, b, h, q, kv): return score + offset_kv[kv] + offset_q[q] @@ -1068,6 +1148,7 @@ def score_mod(score, b, h, q, kv): def mask_mod(b, h, q, kv): return kv >= q + offset_tensor +<<<<<<< HEAD block_mask = create_block_mask(mask_mod, B, 1, Q_S, KV_S, device=device) self.run_test( Q_S=Q_S, @@ -1076,6 +1157,10 @@ def mask_mod(b, h, q, kv): score_mod=score_mod, device=device, ) +======= + block_mask = create_block_mask(mask_mod, B, 1, Q_S, KV_S) + self.run_test(Q_S=Q_S, KV_S=KV_S, block_mask=block_mask, score_mod=score_mod) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @supported_platform @common_utils.parametrize("dtype", test_dtypes_fast) @@ -1121,7 +1206,10 @@ def bias_mod(score, b, h, q, kv): @common_utils.parametrize("score_mod", test_score_mods) @common_utils.parametrize("dtype", test_dtypes) @common_utils.parametrize("head_dims", [(D, D // 2), (D // 2, D)]) +<<<<<<< HEAD @with_tf32_off +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_non_equal_head_dims(self, device, dtype, score_mod, head_dims): qk_d, v_d = head_dims self.run_test( @@ -1276,7 +1364,10 @@ def score_mod_scale(qk, b, h, q, kv): @supported_platform @common_utils.parametrize("head_dim", [17, 24, 94, 121]) @common_utils.parametrize("dtype", test_dtypes_fast) +<<<<<<< HEAD @common_utils.serialTest() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_non_pow_2_headdim(self, device, dtype, head_dim): self.run_test( _rel_bias, dtype, B, Hq, S, head_dim, B, Hkv, S, head_dim, device=device @@ -1578,6 +1669,7 @@ def score_mod(score, b, h, m, n): self.run_test(score_mod, device=device) self.run_test_with_paged_attention(score_mod, device=device) +<<<<<<< HEAD self.run_test_with_paged_attention( score_mod=score_mod, dtype=torch.bfloat16, @@ -1591,6 +1683,8 @@ def score_mod(score, b, h, m, n): V_D=16, device=device, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @supported_platform @patch.object(torch._inductor.config, "max_autotune", True) @@ -1666,6 +1760,10 @@ def mask_mod(b, h, q, kv): self.assertEqual(out[:, :, M:, :].sum(), 0) @supported_platform +<<<<<<< HEAD +======= + @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_windowed_no_mask_vs_sdpa(self, device): score_mod = _generate_windowed(1000) attention = functools.partial(flex_attention, score_mod=score_mod) @@ -1754,19 +1852,31 @@ def mask_mod(b, h, q, kv): @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported") @common_utils.parametrize("dtype", test_dtypes) @common_utils.parametrize("score_mod", [_identity, _causal]) +<<<<<<< HEAD def test_logsumexp_correctness(self, device, dtype, score_mod): +======= + def test_logsumexp_correctness(self, dtype, score_mod): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) make_kv = functools.partial( torch.randn, (B, Hkv, S, D), dtype=dtype, +<<<<<<< HEAD device=device, +======= + device="cuda", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) requires_grad=True, ) make_q = functools.partial( torch.randn, (B, Hkv, Hq // Hkv, D), dtype=dtype, +<<<<<<< HEAD device=device, +======= + device="cuda", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) requires_grad=True, ) q, k, v = make_q(), make_kv(), make_kv() @@ -1806,6 +1916,7 @@ def eager_sdpa_hop(q, k, v, score_mod): @supported_platform @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported") +<<<<<<< HEAD def test_not_pw_of_two(self, device): query = torch.randn(1, 12, 1, 16, device=device) key = torch.randn(1, 2, 128, 16, device=device) @@ -1817,18 +1928,29 @@ def test_not_pw_of_two(self, device): @supported_platform @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported") def test_logsumexp_only_return(self, device): +======= + def test_logsumexp_only_return(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) make_q = functools.partial( torch.randn, (B, Hkv, Hq // Hkv, D), dtype=torch.float32, +<<<<<<< HEAD device=device, +======= + device="cuda", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) requires_grad=True, ) make_kv = functools.partial( torch.randn, (B, Hkv, S, D), dtype=torch.float32, +<<<<<<< HEAD device=device, +======= + device="cuda", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) requires_grad=True, ) @@ -1847,7 +1969,10 @@ def func(q, k, v, score_mod): ) @supported_platform +<<<<<<< HEAD @skip_on_xpu # TODO: SYCL acc issue +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_non_sparse_mulitple_block_size(self, device): def generate_causal_offset(offset: torch.Tensor): def causal_offset_mask(b, h, q_idx, kv_idx): @@ -1964,7 +2089,11 @@ def causal_mask(b, h, q, kv): # init 4 requests with different prefill length prefill_length = [5, 98, 47, 194] +<<<<<<< HEAD queries, keys, values = [], [], [] +======= + querys, keys, values = [], [], [] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for seq_len in prefill_length: q = torch.randn( 1, @@ -1993,13 +2122,21 @@ def causal_mask(b, h, q, kv): dtype=dtype, requires_grad=False, ) +<<<<<<< HEAD queries.append(q) +======= + querys.append(q) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) keys.append(k) values.append(v) # get ground truth output ref_outs, golden_outs = [], [] +<<<<<<< HEAD for q, k, v in zip(queries, keys, values): +======= + for q, k, v in zip(querys, keys, values): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) q_ref, k_ref, v_ref = query_key_value_clones(q, k, v) q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64) @@ -2059,6 +2196,7 @@ def causal_mask(b, h, q, kv): input_pos = torch.tensor(prefill_length, device=device, dtype=torch.int32).view( max_batch_size, 1 ) +<<<<<<< HEAD kv_len_tensor = torch.full( (max_batch_size,), max_seq_len, device=device, dtype=torch.int64 ) @@ -2075,6 +2213,17 @@ def causal_mask(b, h, q, kv): ) paged_out = compiled_sdpa( torch.cat(queries, 0), k_cache, v_cache, block_mask=new_block_mask +======= + new_block_mask = paged_cache.convert_logical_block_mask(block_mask) + new_block_mask.seq_lengths = (1, new_block_mask.seq_lengths[1]) + compiled_sdpa = torch.compile( + create_attention( + paged_cache.get_score_mod(score_mod), new_block_mask, enable_gqa=False + ) + ) + paged_out = compiled_sdpa( + torch.cat(querys, 0), k_cache, v_cache, block_mask=new_block_mask +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) with torch.no_grad(): @@ -2088,9 +2237,13 @@ def causal_mask(b, h, q, kv): self._check_equal(golden_outs, ref_outs, paged_out, fudge_factor, "Out") +<<<<<<< HEAD instantiate_device_type_tests( TestFlexDecoding, globals(), only_for=test_device, allow_xpu=True ) +======= +instantiate_device_type_tests(TestFlexDecoding, globals(), only_for=test_device) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_foreach.py b/test/inductor/test_foreach.py index c51d0bba229ec..b8b67126bebf3 100644 --- a/test/inductor/test_foreach.py +++ b/test/inductor/test_foreach.py @@ -14,8 +14,13 @@ IS_FBCODE, parametrize, ) +<<<<<<< HEAD from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA_AND_TRITON from torch.testing._internal.triton_utils import requires_cuda_and_triton +======= +from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA +from torch.testing._internal.triton_utils import requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.utils._pytree import tree_flatten @@ -269,29 +274,49 @@ def fn(a0, a1): ) # called in test_cuda_cpp_wrapper.py +<<<<<<< HEAD @requires_cuda_and_triton def test_foreach_cpp_wrapper_cuda(self): self._test_single_list(op=torch._foreach_add) @requires_cuda_and_triton +======= + @requires_cuda + def test_foreach_cpp_wrapper_cuda(self): + self._test_single_list(op=torch._foreach_add) + + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @all_ops def test_single_list(self, op): self._test_single_list(op) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @scalar_bin_ops def test_single_scalar(self, op): self._test_single_scalar(op) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @scalar_tensor_bin_ops def test_single_scalar_tensor(self, op): self._test_single_scalar_tensor(op) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @all_ops def test_scheduler_fusion_list(self, op): if op in un_ops_under_test: @@ -319,7 +344,11 @@ def fn(a0, a1, b0, b1, c0, c1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @scalar_bin_ops def test_scheduler_fusion_scalar(self, op): def fn(a0, a1): @@ -336,7 +365,11 @@ def fn(a0, a1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @scalar_bin_ops def test_broadcasting(self, op): def fn(a0, a1, b0, b1): @@ -355,7 +388,11 @@ def fn(a0, a1, b0, b1): self.assertEqual(actual, expected) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @all_ops def test_singleton_lists(self, op): if op in un_ops_under_test: @@ -392,7 +429,11 @@ def fn(a0, b0, c0): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @bin_ops def test_type_promotion(self, op): def fn(a0, a1, b0, b1): @@ -413,7 +454,11 @@ def fn(a0, a1, b0, b1): self.assertEqual(actual, expected) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @scalar_bin_ops def test_kernel_split_arg_limit_list(self, op): # NB: foeach_copy won't pass this test because it will dce one set of buffers @@ -435,7 +480,11 @@ def fn(a, b): self.assertEqual(actual, expected) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @scalar_bin_ops @unittest.skip( "Triton recursion depth exceeded: https://github.com/triton-lang/triton/issues/1763" @@ -455,7 +504,11 @@ def fn(a): self.assertEqual(actual, expected) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @bin_ops def test_fusion_duplicate_buffer_list(self, op): def fn(a0, a1, b0, b1): @@ -479,7 +532,11 @@ def fn(a0, a1, b0, b1): kernel_count = 2 self.assertEqual(torch._inductor.metrics.generated_kernel_count, kernel_count) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @all_ops def test_non_foreach_consumer_list(self, op): if op in un_ops_under_test: @@ -507,7 +564,11 @@ def fn(a0, a1, b0, b1, c0, c1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @scalar_bin_ops def test_non_foreach_consumer_scalar(self, op): def fn(a0, a1): @@ -524,7 +585,11 @@ def fn(a0, a1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @all_ops def test_non_foreach_producer_list(self, op): if op in un_ops_under_test: @@ -554,7 +619,11 @@ def fn(a0, a1, b0, b1, c0, c1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @scalar_bin_ops def test_non_foreach_producer_scalar(self, op): def fn(a0, a1, b0, b1): @@ -574,7 +643,11 @@ def fn(a0, a1, b0, b1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @all_ops def test_non_foreach_consumer_producer_list(self, op): if op in un_ops_under_test: @@ -616,7 +689,11 @@ def fn(a0, a1, b0, b1, c0, c1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @scalar_bin_ops def test_non_foreach_consumer_producer_scalar(self, op): def fn(a0, a1, b0, b1): @@ -641,11 +718,18 @@ def fn(a0, a1, b0, b1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) +<<<<<<< HEAD @requires_cuda_and_triton @bin_ops @torch._dynamo.config.patch("automatic_dynamic_shapes", False) @torch._dynamo.config.patch("assume_static_by_default", False) @torch._inductor.config.patch("combo_kernel_foreach_dynamic_shapes", False) +======= + @requires_cuda + @bin_ops + @torch._dynamo.config.patch("automatic_dynamic_shapes", False) + @torch._dynamo.config.patch("assume_static_by_default", False) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_dynamic_shapes_fallback(self, op): def fn(a0, a1, b0, b1): return op([a0, a1], [b0, b1]) @@ -661,7 +745,11 @@ def fn(a0, a1, b0, b1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._dynamo.config.patch("automatic_dynamic_shapes", False) @torch._dynamo.config.patch("assume_static_by_default", False) @torch._inductor.config.patch("combo_kernel_foreach_dynamic_shapes", True) @@ -680,7 +768,11 @@ def fn(a0, a1, b0, b1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._dynamo.config.patch("automatic_dynamic_shapes", False) @torch._dynamo.config.patch("assume_static_by_default", False) @torch._inductor.config.patch("combo_kernel_foreach_dynamic_shapes", True) @@ -715,7 +807,11 @@ def fn(a0, a1, b0, b1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @decomp_ops def test_decomp(self, op): def fn(a0, a1, b0, b1, c0, c1): @@ -735,7 +831,11 @@ def fn(a0, a1, b0, b1, c0, c1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_fuse_concat(self): def fn(x1, x2, x3, w1, w2, w3): x = torch.stack([x1, x2, x3]) @@ -758,7 +858,11 @@ def fn(x1, x2, x3, w1, w2, w3): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_zero_elems(self): def fn(a0, a1, b0, b1): return torch._foreach_add([a0, a1], [b0, b1]) @@ -775,7 +879,11 @@ def fn(a0, a1, b0, b1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @bin_ops def test_2d_blocking(self, op): def fn(a0, a1, b0, b1): @@ -793,7 +901,11 @@ def fn(a0, a1, b0, b1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @bin_ops def test_2d_blocking_partitioning(self, op): def fn(a0, a1, b0, b1): @@ -811,7 +923,11 @@ def fn(a0, a1, b0, b1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @bin_ops def test_2d_blocking_partitioning_elems(self, op): """2D blocking should be grouped by number of yelems""" @@ -833,7 +949,11 @@ def fn(a0, a1, a2, b0, b1, b2): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @bin_ops @torch._inductor.config.patch("combo_kernel_allow_mixed_sizes", 2) def test_2d_blocking_partitioning_mixed_sizes(self, op): @@ -856,7 +976,11 @@ def fn(a0, a1, a2, b0, b1, b2): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @inplace_bin_ops def test_reinplacing(self, op): def fn(a0, a1, b0, b1): @@ -874,7 +998,11 @@ def fn(a0, a1, b0, b1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @inplace_bin_ops def test_reinplacing_mut_before(self, op): def fn(a0, a1, b0, b1): @@ -893,7 +1021,11 @@ def fn(a0, a1, b0, b1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @inplace_bin_ops def test_reinplacing_mut_after(self, op): def fn(a0, a1, b0, b1): @@ -912,7 +1044,11 @@ def fn(a0, a1, b0, b1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_multi_device(self): def test_foreach_add(a0, a1, b0, b1): return torch._foreach_add([a0, a1], [b0, b1]) @@ -930,7 +1066,11 @@ def test_foreach_add(a0, a1, b0, b1): self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_aliasing(self): def test_foreach_add(a0, a1, a2, b0, b1, b2): return torch._foreach_add_([a0, a1, a2], [b0, b1, b2]) @@ -952,7 +1092,11 @@ def test_foreach_add(a0, a1, a2, b0, b1, b2): self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._inductor.config.patch("combo_kernel_allow_mixed_sizes", 1) def test_2d_block_no_mixed_sizes_no_mask(self): """2D blocking with no mixed sizes constant mask""" @@ -974,7 +1118,11 @@ def fn(a0, a1, a2, b0, b1, b2): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._inductor.config.patch("combo_kernel_allow_mixed_sizes", 2) def test_2d_block_mixed_sizes_with_mask(self): """2D blocking with mixed sizes should have mask""" @@ -996,7 +1144,11 @@ def fn(a0, a1, a2, b0, b1, b2): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @foreach_map_bin_ops def test_foreach_map_backward_binary(self, op): from torch._dynamo.polyfills import foreach_map_fn @@ -1037,7 +1189,11 @@ def ref_fn(xs, ys): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 5) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_foreach_map_input_mutation(self): def fn(xs, ys): outs = foreach_map_add_inplace(xs, ys) @@ -1073,7 +1229,11 @@ def fn(xs, ys): ): _ = run_fw_bw_and_get_code(lambda: torch.compile(fn)(*inps)) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @foreach_map_un_ops def test_foreach_map_backward_unary(self, op): from torch._dynamo.polyfills import foreach_map_fn @@ -1109,5 +1269,9 @@ def ref_fn(xs): if __name__ == "__main__": from torch._inductor.test_case import run_tests +<<<<<<< HEAD if HAS_CPU or HAS_CUDA_AND_TRITON: +======= + if HAS_CPU or HAS_CUDA: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) run_tests(needs="filelock") diff --git a/test/inductor/test_fp8.py b/test/inductor/test_fp8.py index 82e4a923a92e1..8c8c4025c8089 100644 --- a/test/inductor/test_fp8.py +++ b/test/inductor/test_fp8.py @@ -8,7 +8,10 @@ from torch import Tensor from torch._inductor import config, utils from torch._inductor.test_case import run_tests, TestCase +<<<<<<< HEAD from torch._inductor.utils import run_and_get_code +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_cuda import ( PLATFORM_SUPPORTS_FP8, PLATFORM_SUPPORTS_MX_GEMM, @@ -23,9 +26,14 @@ _quantize_tensorwise, _to_fp8_saturated, HAS_CPU, +<<<<<<< HEAD HAS_CUDA_AND_TRITON, ) from torch.testing._internal.jit_utils import FileCheck +======= + HAS_CUDA, +) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.utils._triton import has_triton_tma_device @@ -42,7 +50,11 @@ def _fix_fp8_dtype_for_rocm( # with MI300 supported FP8 types if device is GPU: # e4m3fn -> e4m3fnuz # e5m2 -> e5m2fnuz +<<<<<<< HEAD # Supports single, tuple and list of dtypes +======= + # Supports single, typle and list of dtypes +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Keeps the same test name for CUDA and ROCm # Also it allows to enable FP8 inductor tests for CPU if ( @@ -465,6 +477,7 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): # autotuning for the compiled case, the results can be different because of # the way blocks of results are accumulated (float addition not associative), so # setting a small absolute tolerance in these tests +<<<<<<< HEAD if dtype == torch.bfloat16: self.assertEqual(y_eager, y_compiled, rtol=5e-2, atol=0.07) else: @@ -548,6 +561,8 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): # autotuning for the compiled case, the results can be different because of # the way blocks of results are accumulated (float addition not associative), so # setting a small absolute tolerance in these tests +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @@ -614,6 +629,7 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): ) self.assertEqual(y_eager.dtype, dtype) self.assertEqual(y_compiled.dtype, dtype) +<<<<<<< HEAD torch.testing.assert_close(y_eager, y_compiled, rtol=5e-2, atol=0.07) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @@ -689,6 +705,8 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): FileCheck().check("SCALING_ROWWISE : tl.constexpr = True").run(code[0]) self.assertEqual(y_eager.dtype, dtype) self.assertEqual(y_compiled.dtype, dtype) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @@ -747,7 +765,11 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): ) self.assertEqual(y_eager.dtype, dtype) self.assertEqual(y_compiled.dtype, dtype) +<<<<<<< HEAD torch.testing.assert_close(y_eager, y_compiled, rtol=5e-2, atol=0.07) +======= + torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.07) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("M", (1, 3, 33, 257, 1024)) @@ -926,5 +948,9 @@ def linear(x, w_t_fp8, w_inverse_scale, bias): if __name__ == "__main__": +<<<<<<< HEAD if HAS_CUDA_AND_TRITON or HAS_CPU: +======= + if HAS_CUDA or HAS_CPU: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) run_tests() diff --git a/test/inductor/test_fused_attention.py b/test/inductor/test_fused_attention.py index 25e96fa9f1e9f..273eb88160b12 100644 --- a/test/inductor/test_fused_attention.py +++ b/test/inductor/test_fused_attention.py @@ -15,12 +15,16 @@ SM80OrLater, ) from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm +<<<<<<< HEAD from torch.testing._internal.inductor_utils import ( GPU_TYPE, HAS_CPU, HAS_CUDA_AND_TRITON, HAS_XPU_AND_TRITON, ) +======= +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_CUDA, HAS_XPU +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def checkpoint_wrapper(fn): @@ -1080,6 +1084,7 @@ def dot_prod_attention( check_train=False, ) +<<<<<<< HEAD def _test_sdpa_rewriter_24(self): def dot_prod_attention( query: torch.Tensor, @@ -1120,6 +1125,10 @@ def dot_prod_attention( if HAS_XPU_AND_TRITON or (HAS_CUDA_AND_TRITON and PLATFORM_SUPPORTS_FUSED_ATTENTION): +======= + +if HAS_XPU or (HAS_CUDA and PLATFORM_SUPPORTS_FUSED_ATTENTION): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class SDPAPatternRewriterGpuTests(TestSDPAPatternRewriterTemplate): device = GPU_TYPE @@ -1186,9 +1195,12 @@ class SDPAPatternRewriterGpuTests(TestSDPAPatternRewriterTemplate): test_sdpa_rewriter_23_gpu = functools.partialmethod( TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_23 ) +<<<<<<< HEAD test_sdpa_rewriter_24_gpu = functools.partialmethod( TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_24 ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class SDPAPatternRewriterGpuDynamicTests(SDPAPatternRewriterGpuTests): use_static_shapes = False @@ -1255,9 +1267,12 @@ class SDPAPatternRewriterCpuTests(TestSDPAPatternRewriterTemplate): test_sdpa_rewriter_23_cpu = functools.partialmethod( TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_23 ) +<<<<<<< HEAD test_sdpa_rewriter_24_cpu = functools.partialmethod( TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_24 ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class SDPAPatternRewriterCpuDynamicTests(SDPAPatternRewriterCpuTests): use_static_shapes = False diff --git a/test/inductor/test_fxir_backend.py b/test/inductor/test_fxir_backend.py index 32ccce7e6c038..5c5e2d31a3a15 100644 --- a/test/inductor/test_fxir_backend.py +++ b/test/inductor/test_fxir_backend.py @@ -20,10 +20,16 @@ from torch._inductor.codegen.common import register_backend_for_device from torch._inductor.codegen.cpp import CppScheduling from torch._inductor.codegen.triton import TritonScheduling +<<<<<<< HEAD from torch._inductor.codegen.wrapper import PythonWrapperCodegen from torch._inductor.codegen.wrapper_fxir import FxConverter, WrapperFxCodegen from torch._inductor.test_case import TestCase as InductorTestCase from torch.export import Dim +======= +from torch._inductor.codegen.wrapper_fxir import FxConverter, WrapperFxCodegen +from torch._inductor.select_algorithm import extern_kernels +from torch._inductor.test_case import TestCase as InductorTestCase +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, @@ -36,6 +42,7 @@ ) +<<<<<<< HEAD if HAS_GPU: import triton import triton.language as tl @@ -43,6 +50,8 @@ from torch.testing._internal.triton_utils import add_kernel_2d_autotuned +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @requires_gpu() @config.patch( compile_threads=1, @@ -159,11 +168,16 @@ def foo(x, y): (gm,) = self._compile_and_check(foo, args, expected_num_triton_kernels=1) # Check for the extern kernel +<<<<<<< HEAD num_extern = self._count_ops(gm, torch.ops.aten.addmm.out) +======= + num_extern = self._count_ops(gm, extern_kernels.addmm) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(num_extern, 1) def test_fallback(self): """ +<<<<<<< HEAD Test a program that calls aten fallbacks. """ @@ -173,6 +187,17 @@ def foo(x): return torch.addbmm(x, batch1, batch2) args = (torch.randn(3, 4, device=self.device),) +======= + Test a program that calls an aten fallback. + """ + + length = 8 + + def foo(x): + return x + torch.randn(1, device=self.device) + + args = (torch.randn(length, device=self.device),) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Since the program has a random output, just check metadata. # Don't check for an exact value. @@ -181,10 +206,15 @@ def foo(x): ) # Check for the fallback kernel. +<<<<<<< HEAD num_fallback = self._count_ops( gm, torch.ops.aten.randint.low_out ) + self._count_ops(gm, torch.ops.aten.addbmm.default) self.assertEqual(num_fallback, 2) +======= + num_fallback = self._count_ops(gm, torch.ops.aten.randint.low_out) + self.assertEqual(num_fallback, 1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_cat_inputs(self): """ @@ -403,6 +433,7 @@ def get_input(): ] self.assertEqual(placeholder.meta["val"], symbol) +<<<<<<< HEAD @parametrize( "shape", [ @@ -527,6 +558,8 @@ def mocked_from_meta(inductor_meta, cfg, mode="python"): self.assertEqual(grid[1], 1) self.assertEqual(grid[2], 1) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @config.patch({"trace.enabled": True}) @unittest.mock.patch("torch._inductor.debug.DebugFormatter.output_code") def test_debug(self, mock_output_code): @@ -635,6 +668,7 @@ def run(*args, **kwargs): op="call_function", target=torch.empty_strided ) (shape, stride) = empty_strided.args +<<<<<<< HEAD if use_dynamic_shapes: self.assertEqual(type(shape[0]), torch.fx.Node) @@ -820,6 +854,10 @@ def forward(self, x): # Now the backend should have been called. self.assertTrue(called) +======= + output_is_symbolic = any(isinstance(dim, torch.SymInt) for dim in shape) + self.assertEqual(output_is_symbolic, use_dynamic_shapes) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": diff --git a/test/inductor/test_gpu_cpp_wrapper.py b/test/inductor/test_gpu_cpp_wrapper.py index 24163ece1f919..8475e2e749808 100644 --- a/test/inductor/test_gpu_cpp_wrapper.py +++ b/test/inductor/test_gpu_cpp_wrapper.py @@ -125,7 +125,11 @@ def make_test_case( assert callable(func), "not a callable" func = slowTest(func) if slow else func +<<<<<<< HEAD @config.patch(cpp_wrapper=True) +======= + @config.patch(cpp_wrapper=True, search_autotune_cache=False) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def fn(self): tests.setUpClass() tests.setUp() diff --git a/test/inductor/test_graph_transform_observer.py b/test/inductor/test_graph_transform_observer.py index 2bd0b6ef43f11..5ec03fa959255 100644 --- a/test/inductor/test_graph_transform_observer.py +++ b/test/inductor/test_graph_transform_observer.py @@ -11,7 +11,11 @@ from torch._inductor.test_case import run_tests, TestCase from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FUSED_ATTENTION from torch.testing._internal.common_utils import IS_LINUX +<<<<<<< HEAD from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON +======= +from torch.testing._internal.inductor_utils import HAS_CUDA +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: @@ -28,10 +32,14 @@ class TestGraphTransformObserver(TestCase): def test_sdpa_rewriter(self): if not ( +<<<<<<< HEAD HAS_CUDA_AND_TRITON and PLATFORM_SUPPORTS_FUSED_ATTENTION and HAS_PYDOT and HAS_DOT +======= + HAS_CUDA and PLATFORM_SUPPORTS_FUSED_ATTENTION and HAS_PYDOT and HAS_DOT +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): return diff --git a/test/inductor/test_group_batch_fusion.py b/test/inductor/test_group_batch_fusion.py index 090a7e8e29d3f..1935d76e6e722 100644 --- a/test/inductor/test_group_batch_fusion.py +++ b/test/inductor/test_group_batch_fusion.py @@ -286,6 +286,27 @@ def forward(self, x): return torch.stack((stack_input, stack_other), dim=0) +<<<<<<< HEAD +======= +@requires_gpu() +@torch._inductor.config.patch( + pre_grad_fusion_options={ + "batch_linear": {}, + "batch_linear_lhs": {}, + "batch_layernorm": {}, + "batch_tanh": {}, + "batch_relu": {}, + "batch_sigmoid": {}, + }, + post_grad_fusion_options={ + "batch_aten_add": {}, + "batch_aten_mul": {}, + "batch_aten_sub": {}, + "batch_aten_div": {}, + "group_linear": {"require_fbgemm": True}, + }, +) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestGroupBatchFusion(TestCase): def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3): if len(set(ref_dict.keys())) != len(set(res_dict.keys())): @@ -314,6 +335,7 @@ def compare_gradients(self, module, traced, rtol=1e-3, atol=1e-3): self.compare_dict_tensors(ref_grad, res_grad, rtol=rtol, atol=atol) ) +<<<<<<< HEAD @requires_gpu() @unittest.skipIf(not has_fbgemm, "requires fbgemm") @torch._inductor.config.patch( @@ -322,6 +344,9 @@ def compare_gradients(self, module, traced, rtol=1e-3, atol=1e-3): "group_linear": {"require_fbgemm": True}, }, ) +======= + @unittest.skipIf(not has_fbgemm, "requires fbgemm") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_group_linear_fusion(self): z = 10 for has_bias in [True, False]: @@ -344,6 +369,7 @@ def test_group_linear_fusion(self): counters["inductor"]["group_linear"], 4, ) +<<<<<<< HEAD counters.clear() @requires_gpu() @@ -354,6 +380,15 @@ def test_group_linear_fusion(self): "group_linear": {"require_fbgemm": True}, }, ) +======= + self.assertEqual( + counters["inductor"]["batch_aten_add"], + 0, + ) + counters.clear() + + @unittest.skipIf(not has_fbgemm, "requires fbgemm") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_group_linear_fusion_different_shapes(self): counters.clear() module = MyModule2().eval().to(GPU_TYPE) @@ -378,6 +413,7 @@ def test_group_linear_fusion_different_shapes(self): counters["inductor"]["group_linear"], 2, ) +<<<<<<< HEAD counters.clear() @requires_gpu() @@ -386,6 +422,15 @@ def test_group_linear_fusion_different_shapes(self): pre_grad_fusion_options={"batch_layernorm": {}}, post_grad_fusion_options={}, ) +======= + self.assertEqual( + counters["inductor"]["batch_aten_mul"], + 1, + ) + counters.clear() + + @unittest.skipIf(GPU_TYPE == "mps", "welford_reduce is yet not implemented for MPS") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_batch_layer_norm_fusion(self): for has_weight in [True, False]: for has_bias in [True, False]: @@ -403,11 +448,14 @@ def test_batch_layer_norm_fusion(self): self.compare_gradients(module, traced, rtol=1e-8, atol=1e-8) counters.clear() +<<<<<<< HEAD @requires_gpu() @torch._inductor.config.patch( pre_grad_fusion_options={"batch_linear_lhs": {}}, post_grad_fusion_options={}, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_batch_linear_lhs_fusion(self): z = 10 for has_bias in [True, False]: @@ -425,11 +473,14 @@ def test_batch_linear_lhs_fusion(self): self.compare_gradients(module, traced, rtol=1e-8, atol=1e-8) counters.clear() +<<<<<<< HEAD @requires_gpu() @torch._inductor.config.patch( pre_grad_fusion_options={"batch_linear": {}}, post_grad_fusion_options={}, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_batch_linear_pre_grad_fusion(self): for has_bias in [True, False]: counters.clear() @@ -446,6 +497,7 @@ def test_batch_linear_pre_grad_fusion(self): self.compare_gradients(module, traced, rtol=1e-8, atol=1e-8) counters.clear() +<<<<<<< HEAD @requires_gpu() @torch._inductor.config.patch( pre_grad_fusion_options={ @@ -459,6 +511,8 @@ def test_batch_linear_pre_grad_fusion(self): "batch_aten_div": {}, }, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_pointwise_op_fusion(self): counters.clear() module = TestPoitwiseOps(GPU_TYPE) @@ -1230,7 +1284,11 @@ def test_find_independent_subset_greedy_fuse(self): ) self.assertEqual(next(i), [lookup[n] for n in ["n2", "n3", "n5"]]) +<<<<<<< HEAD # fuse n2 and n3 which makes n4 now dependent on n1. +======= + # fuse n2 and n3 which makes n4 now dependant on n1. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) args = tuple(lookup[n] for n in ["n0", "n1"]) fused = g.create_node("placeholder", "target", name="n2+n3", args=args) lookup["n2"].replace_all_uses_with(fused) diff --git a/test/inductor/test_indexing.py b/test/inductor/test_indexing.py index 3359b237904fe..2de56e297841d 100644 --- a/test/inductor/test_indexing.py +++ b/test/inductor/test_indexing.py @@ -55,7 +55,11 @@ def test_indexing_simplification(self): sizevars.simplify_with_ranges(expr, var_ranges), i1 + 128 * i2 + 64 * ModularIndexing(r3, 1, 2), ) +<<<<<<< HEAD # all the modular indexing should be removed when the body can't be larger than the modulus +======= + # all the modular indexing should be removed when the body cant be larger than the modulus +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) var_ranges[r3] = 2 self.assertEqual( sizevars.simplify_with_ranges(expr, var_ranges), i1 + 128 * i2 + 64 * r3 @@ -247,7 +251,11 @@ def f(x): x = torch.randint(0, 255, (2, 4096, 5504), dtype=torch.uint8, device=GPU_TYPE) triton_code = run_and_get_triton_code(f, x) +<<<<<<< HEAD # Make sure the 2 load uses simplified indexing rather than something like +======= + # Make sure the 2 load uses simpified indexing rather than something like +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # tl.load(in_ptr0 + ((5504*x1) + (x0 // 2)), self.assertEqual(2, triton_code.count("tl.load(in_ptr0 + (x2 // 2),")) if DO_PERF_TEST: diff --git a/test/inductor/test_inductor_annotations.py b/test/inductor/test_inductor_annotations.py index 3824b25cdeaea..5eb76610e83c7 100644 --- a/test/inductor/test_inductor_annotations.py +++ b/test/inductor/test_inductor_annotations.py @@ -3,7 +3,11 @@ import torch._inductor.config as inductor_config from torch._inductor.test_case import run_tests, TestCase from torch._inductor.utils import run_and_get_code +<<<<<<< HEAD from torch.testing._internal.triton_utils import requires_cuda_and_triton +======= +from torch.testing._internal.triton_utils import requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class InductorAnnotationTestCase(TestCase): @@ -18,7 +22,11 @@ def f(a, b): _, code = run_and_get_code(f_comp, a, b) return code[0] +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_no_annotations(self): code = self.get_code() @@ -26,16 +34,27 @@ def test_no_annotations(self): self.assertTrue("training_annotation" not in code) @inductor_config.patch(annotate_training=True) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_training_annotation(self): code = self.get_code() self.assertTrue("from torch.cuda import nvtx" in code) +<<<<<<< HEAD self.assertTrue( code.count("training_annotation = nvtx._device_range_start('inference')") >= 1 ) self.assertTrue(code.count("nvtx._device_range_end(training_annotation)") >= 1) +======= + self.assertEqual( + code.count("training_annotation = nvtx._device_range_start('inference')"), 1 + ) + self.assertEqual(code.count("nvtx._device_range_end(training_annotation)"), 1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": diff --git a/test/inductor/test_inductor_scheduler.py b/test/inductor/test_inductor_scheduler.py index f180c9d003df4..7ba786c7e3965 100644 --- a/test/inductor/test_inductor_scheduler.py +++ b/test/inductor/test_inductor_scheduler.py @@ -1,12 +1,20 @@ # Owner(s): ["module: inductor"] +<<<<<<< HEAD from unittest import skipIf +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch._inductor.metrics as metrics import torch.utils.flop_counter from torch._dynamo.utils import counters +<<<<<<< HEAD from torch._inductor.utils import fresh_inductor_cache +======= +from torch._inductor.ir import FixedLayout +from torch._inductor.utils import fresh_cache +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_cuda import SM70OrLater from torch.testing._internal.common_device_type import ( dtypes, @@ -14,7 +22,10 @@ skipCUDAIf, ) from torch.testing._internal.common_utils import parametrize, run_tests, TestCase +<<<<<<< HEAD from torch.testing._internal.inductor_utils import IS_BIG_GPU +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def FlopCounterMode(*args, **kwargs): @@ -79,7 +90,11 @@ def test_disable_get_estimated_runtime_logging(self, device, dtype): for op, example_inputs, kwargs in tc: comp = torch.compile(op) torch._dynamo.reset() +<<<<<<< HEAD with fresh_inductor_cache(): +======= + with fresh_cache(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) comp(*example_inputs, **kwargs) self.assertEqual(metrics.num_bytes_accessed, 0) self.assertEqual(any(m[1] for m in metrics.node_runtimes), False) @@ -89,16 +104,55 @@ def test_disable_get_estimated_runtime_logging(self, device, dtype): @dtypes(torch.float, torch.float16) @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") +<<<<<<< HEAD +======= + def test_get_estimated_runtime_logging(self, device, dtype): + if device == "cpu": + return + tc = _test_cases(device, dtype) + expected_metrics = [ + # num_bytes_accessed, number of nonzero node_runtimes + (74 * dtype.itemsize, 1), + (60 * dtype.itemsize, 1), + (222 * dtype.itemsize, 4), + (77 * dtype.itemsize, 2), + ] + tc_plus_metrics = zip(tc, expected_metrics) + + metrics.reset() + torch._logging.set_logs(inductor_metrics=True) + for test_case, met in tc_plus_metrics: + op, example_inputs, kwargs = test_case + enba, enr = met + + comp = torch.compile(op) + torch._dynamo.reset() + with fresh_cache(): + comp(*example_inputs, **kwargs) + self.assertEqual(enba, metrics.num_bytes_accessed) + nonzero_node_runtimes = sum(1 for x in metrics.node_runtimes if x[1] != 0) + self.assertEqual(enr, nonzero_node_runtimes) + metrics.reset() + torch._logging.set_logs() + + @dtypes(torch.float, torch.float16) + @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize( "options", [ { "max_autotune": True, "max_autotune_gemm_backends": "TRITON", +<<<<<<< HEAD +======= + "force_disable_caches": True, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "max_autotune": True, "max_autotune_gemm_backends": "TRITON,ATEN", +<<<<<<< HEAD }, ], ) @@ -107,15 +161,38 @@ def test_disable_get_estimated_runtime_logging(self, device, dtype): def test_flop_counter_op(self, device, dtype, options): if device == "cpu": return +======= + "force_disable_caches": True, + }, + ], + ) + def test_flop_counter_op(self, device, dtype, options): + if device == "cpu": + return + if ( + options["max_autotune_gemm_backends"] == "TRITON" + and torch.cuda.is_available() + and not torch._inductor.utils.use_triton_template( + FixedLayout(torch.device("cuda"), torch.float16, [400, 800]) + ) + ): + return +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tc = _test_cases(device, dtype) torch._logging.set_logs(inductor_metrics=True) for op, example_inputs, kwargs in tc: comp = torch.compile(op, options=options) +<<<<<<< HEAD # next two lines are required, otherwise the flops will be cached from previous runs of this function. torch._dynamo.reset() with fresh_inductor_cache(): +======= + # next two lines are required, otherwise the flops will be cached from pervious runs of this function. + torch._dynamo.reset() + with fresh_cache(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # actually run to set the counters comp(*example_inputs, **kwargs) with FlopCounterMode() as mode: diff --git a/test/inductor/test_inplace_padding.py b/test/inductor/test_inplace_padding.py index 7ddd0dd4441b8..120ee6c0fc894 100644 --- a/test/inductor/test_inplace_padding.py +++ b/test/inductor/test_inplace_padding.py @@ -9,7 +9,10 @@ from torch._inductor.test_case import run_tests, TestCase from torch._inductor.utils import run_and_get_code from torch.testing import FileCheck +<<<<<<< HEAD from torch.testing._internal.common_utils import serialTest +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.inductor_utils import ( GPU_TYPE, HAS_GPU, @@ -212,7 +215,10 @@ def f(x, y): @requires_cuda_with_enough_memory(2e10) @inductor_config.patch(force_shape_pad=True) +<<<<<<< HEAD @serialTest() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_linear_and_cel(self): # Use nan for torch.empty torch.use_deterministic_algorithms(True) @@ -233,9 +239,15 @@ def f(x, y): loss.backward() return loss +<<<<<<< HEAD x = torch.randn(B * T, C, requires_grad=True).to(GPU_TYPE).bfloat16() x.retain_grad() y = torch.randint(0, V, (B * T,)).to(GPU_TYPE) +======= + x = torch.randn(B * T, C, requires_grad=True).cuda().bfloat16() + x.retain_grad() + y = torch.randint(0, V, (B * T,)).cuda() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) opt_f = torch.compile(f) diff --git a/test/inductor/test_inplacing_pass.py b/test/inductor/test_inplacing_pass.py index dd592f8c4e823..ab790b6793a2a 100644 --- a/test/inductor/test_inplacing_pass.py +++ b/test/inductor/test_inplacing_pass.py @@ -413,6 +413,7 @@ def f(b): # Both list inputs failed to reinplace. So we should have emitted clones for them. self.assertEqual(post_grad_graphs.count("aten.clone"), 2) +<<<<<<< HEAD def test_generalized_scatter(self): # This is an integration test for the reinplacing pass. def fn(x_1): @@ -438,6 +439,8 @@ def fn(x_1): result = torch.compile(fn, fullgraph=True, backend="inductor")(x) self.assertEqual(result, expected) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize( "factory_op", [ diff --git a/test/inductor/test_kernel_benchmark.py b/test/inductor/test_kernel_benchmark.py index 4c35cec9bee9b..abb4c4673cfb0 100644 --- a/test/inductor/test_kernel_benchmark.py +++ b/test/inductor/test_kernel_benchmark.py @@ -199,7 +199,11 @@ def test_matmul_bandwidth_computation(self): def triton_(in_out_ptr0, xnumel, XBLOCK : tl.constexpr): Note the in_out_ptr0 argument. It's for a 1000x1000 tensor, but it's +<<<<<<< HEAD inplace updated, so when computing the bandwidth, we should count +======= + inplace udpated, so when computing the bandwidth, we should count +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) the total memory access as 2 * 1000 * 1000 * 4 = 8MB. This amount is what this test asserts. """ @@ -386,9 +390,12 @@ def f(a, b, c): max_autotune=True, max_autotune_gemm_backends="TRITON", force_shape_pad=True ) def test_slice_mm_bandwidth_computation(self): +<<<<<<< HEAD if GPU_TYPE == "xpu" and not torch._inductor.utils.is_big_gpu(): raise unittest.SkipTest("unsupported device") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) M, N, K = 1000, 2000, 3000 @torch.compile diff --git a/test/inductor/test_layout_optim.py b/test/inductor/test_layout_optim.py index 8962e6bb18b5f..17daaa2b43467 100644 --- a/test/inductor/test_layout_optim.py +++ b/test/inductor/test_layout_optim.py @@ -300,7 +300,11 @@ def test_nll_loss_backward(self): The CUDA implementation of aten.nll_loss2d_backward.default requires the self tensor (whose layout will be used to create grad_input) to be contiguous. Layout optimization may change the self tensor's layout +<<<<<<< HEAD and cause failure. We fix that by adding layout constraints to the +======= + and cause failure. We fix that by adding layout constaints to the +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fallback of aten.nll_loss2d_backward.default . """ diff --git a/test/inductor/test_loop_ordering.py b/test/inductor/test_loop_ordering.py index 13e3c3684d381..521fca7b9791e 100644 --- a/test/inductor/test_loop_ordering.py +++ b/test/inductor/test_loop_ordering.py @@ -3,13 +3,19 @@ import contextlib import os import unittest +<<<<<<< HEAD from unittest import skipUnless +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import numpy as np import sympy import torch +<<<<<<< HEAD import torch.nn.functional as F +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch import nn from torch._dynamo.testing import rand_strided from torch._dynamo.utils import same @@ -19,7 +25,11 @@ from torch._inductor.scheduler import SchedulerNode from torch._inductor.test_case import run_tests, TestCase from torch._inductor.test_operators import realize +<<<<<<< HEAD from torch._inductor.utils import is_big_gpu, run_and_get_code, sympy_index_symbol +======= +from torch._inductor.utils import run_and_get_code, sympy_index_symbol +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.virtualized import ops, V from torch.testing import FileCheck from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8 @@ -478,6 +488,7 @@ def test_pattern2(tensor_x_inp, scale_x): expected_numbytes += tensor_fp8.nbytes + tensor_fp8_t.nbytes # output self.assertEqual(expected_numbytes, metrics.num_bytes_accessed) +<<<<<<< HEAD def test_outer_dimension_softmax(self): """ This test repros the not able to fuse problem for outer dimension @@ -519,6 +530,8 @@ def f(x): optf = torch.compile(f) print(f"ms={do_bench(lambda: optf(x))}") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Disable split reduction to make it easier to calculate the expected # number of bytes accessed. In this case, split reduction does not # help perf much. @@ -563,6 +576,7 @@ def f(x): ms = do_bench(lambda: opt_f(x)) print(f"{ms=:.3f}") +<<<<<<< HEAD @inductor_config.patch( { "max_autotune": True, @@ -609,6 +623,8 @@ def f(x): out, code = run_and_get_code(f, x) FileCheck().check_count("@triton.jit", 1, exactly=True).run(code[0]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @inductor_config.patch( { @@ -1091,7 +1107,11 @@ def test_penalized_small_dim(self): x = torch.rand([2000, 1], device=GPU_TYPE) y = torch.rand([4, 1], device=GPU_TYPE).T +<<<<<<< HEAD # don't tile when it doesn't affect total coalesced mem accesses much +======= + # dont tile when it doesnt affect total coalesced mem accesses much +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def f(x, y): return x + y diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index 320bdf3462e64..a119e7142d019 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -12,13 +12,21 @@ import unittest from typing import Callable, Optional from unittest import mock +<<<<<<< HEAD +======= +from unittest.mock import MagicMock +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch from torch import multiprocessing as mp, nn from torch._dynamo import reset from torch._dynamo.exc import BackendCompilerFailed from torch._dynamo.testing import rand_strided, reset_rng_state +<<<<<<< HEAD from torch._dynamo.utils import counters, same +======= +from torch._dynamo.utils import same +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor import config from torch._inductor.autotune_process import ( _TestBenchmarkRequest, @@ -30,6 +38,7 @@ from torch._inductor.ir import Buffer, ChoiceCaller, FixedLayout from torch._inductor.kernel.mm_plus_mm import aten_mm_plus_mm from torch._inductor.select_algorithm import ( +<<<<<<< HEAD add_feedback_saver, add_preprocessing_fn, AlgorithmSelectorCache, @@ -45,19 +54,38 @@ GemmConfig, ) from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8 +======= + AlgorithmSelectorCache, + TritonTemplate, + TritonTemplateCaller, +) +from torch._inductor.template_heuristics import CUDAConfigHeuristic, GemmConfig +from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8 +from torch.testing._internal.common_device_type import largeTensorTest +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, IS_WINDOWS, parametrize, TEST_WITH_ROCM, +<<<<<<< HEAD ) from torch.testing._internal.logging_utils import multiple_logs_to_string from torch.utils._triton import has_triton_stable_tma_api, has_triton_tma_device +======= + MI300_ARCH, + runOnRocmArch, + skipIfXpu, +) +from torch.testing._internal.logging_utils import multiple_logs_to_string +from torch.utils._triton import has_triton_tma_device +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) aten = torch.ops.aten from torch._inductor.mock_cache import global_stats, PatchCaches, Stats from torch._inductor.test_case import run_tests, TestCase +<<<<<<< HEAD from torch._inductor.utils import ( fresh_cache, get_k_splits, @@ -68,18 +96,32 @@ from torch.fx.experimental.proxy_tensor import make_fx from torch.testing import FileCheck from torch.testing._internal.common_utils import MI300_ARCH, runOnRocmArch, skipIfXpu +======= +from torch._inductor.utils import fresh_cache, run_and_get_code +from torch._inductor.virtualized import V +from torch.fx.experimental.proxy_tensor import make_fx +from torch.testing import FileCheck +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.inductor_utils import ( get_func_call, get_kernel_launch, GPU_TYPE, HAS_CPU, +<<<<<<< HEAD HAS_CUDA_AND_TRITON, +======= + HAS_CUDA, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) HAS_GPU, ) torch.set_float32_matmul_precision("high") +<<<<<<< HEAD if HAS_CUDA_AND_TRITON: +======= +if HAS_CUDA: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.cuda.memory._set_allocator_settings("expandable_segments:False") @@ -147,6 +189,7 @@ def mm(a, b): return torch.mm(a, b) M, N, K = 21, 31, 11 +<<<<<<< HEAD a = ( torch.randn(*((K, M) if a_transposed else (M, K))) .to(torch.float16) @@ -157,6 +200,10 @@ def mm(a, b): .to(torch.float16) .to(GPU_TYPE) ) +======= + a = torch.randn(*((K, M) if a_transposed else (M, K))).to(torch.float16).cuda() + b = torch.randn(*((N, K) if b_transposed else (K, N))).to(torch.float16).cuda() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with config.patch( { @@ -173,6 +220,7 @@ def mm(a, b): @unittest.skipIf( not has_triton_tma_device(), "Need device-side TMA support in Triton" ) +<<<<<<< HEAD @parametrize("a_transposed", (False, True)) @parametrize("b_transposed", (False, True)) @parametrize("dynamic", (False, True)) @@ -235,14 +283,21 @@ def next_multiple_16(a: int) -> int: not has_triton_tma_device(), "Need device-side TMA support in Triton" ) @skipIfXpu(msg="TMA path on Intel GPU not require this check") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize("dynamic", (False, True)) def test_max_autotune_regular_mm_persistent_tma_illegal_alignment(self, dynamic): def mm(a, b): return torch.mm(a, b) M, N, K = 21, 31, 11 +<<<<<<< HEAD a = torch.randn(M, K).to(torch.float16).to(GPU_TYPE) b = torch.randn(K, N).to(torch.float16).to(GPU_TYPE) +======= + a = torch.randn(M, K).to(torch.float16).cuda() + b = torch.randn(K, N).to(torch.float16).cuda() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with ( self.assertRaises(BackendCompilerFailed) as context, @@ -269,8 +324,13 @@ def mm(a, b): return torch.mm(a, b) M, N, K = 21, 31, 11 +<<<<<<< HEAD a = torch.randn(M, K).to(torch.float16).to(GPU_TYPE) b = torch.randn(K, N).to(torch.float16).to(GPU_TYPE) +======= + a = torch.randn(M, K).to(torch.float16).cuda() + b = torch.randn(K, N).to(torch.float16).cuda() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TMA requires 16-byte alignment: here we repeat the dims # by the factor of 8, as float16 is 2-byte. All dims are @@ -336,6 +396,7 @@ def addmm(x, a, b): return torch.addmm(x, a, b) M, N, K = 21, 31, 11 +<<<<<<< HEAD a = ( torch.randn(*((K, M) if a_transposed else (M, K))) .to(torch.float16) @@ -347,6 +408,11 @@ def addmm(x, a, b): .to(GPU_TYPE) ) x = torch.randn(N).to(torch.float16).to(GPU_TYPE) +======= + a = torch.randn(*((K, M) if a_transposed else (M, K))).to(torch.float16).cuda() + b = torch.randn(*((N, K) if b_transposed else (K, N))).to(torch.float16).cuda() + x = torch.randn(N).to(torch.float16).cuda() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with config.patch( { @@ -363,16 +429,25 @@ def addmm(x, a, b): @unittest.skipIf( not has_triton_tma_device(), "Need device-side TMA support in Triton" ) +<<<<<<< HEAD @skipIfXpu(msg="TMA path on Intel GPU not require this check") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize("dynamic", (False, True)) def test_max_autotune_addmm_persistent_tma_illegal_alignment(self, dynamic): def addmm(x, a, b): return torch.addmm(x, a, b) M, N, K = 21, 31, 11 +<<<<<<< HEAD a = torch.randn(M, K).to(torch.float16).to(GPU_TYPE) b = torch.randn(K, N).to(torch.float16).to(GPU_TYPE) x = torch.randn(N).to(torch.float16).to(GPU_TYPE) +======= + a = torch.randn(M, K).to(torch.float16).cuda() + b = torch.randn(K, N).to(torch.float16).cuda() + x = torch.randn(N).to(torch.float16).cuda() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with ( self.assertRaises(BackendCompilerFailed) as context, @@ -399,9 +474,15 @@ def addmm(x, a, b): return torch.addmm(x, a, b) M, N, K = 21, 31, 11 +<<<<<<< HEAD a = torch.randn(M, K).to(torch.float16).to(GPU_TYPE) b = torch.randn(K, N).to(torch.float16).to(GPU_TYPE) x = torch.randn(N).to(torch.float16).to(GPU_TYPE) +======= + a = torch.randn(M, K).to(torch.float16).cuda() + b = torch.randn(K, N).to(torch.float16).cuda() + x = torch.randn(N).to(torch.float16).cuda() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TMA requires 16-byte alignment: here we repeat the dims # by the factor of 8, as float16 is 2-byte. All dims are @@ -425,7 +506,10 @@ def addmm(x, a, b): torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=1e-2) @fresh_cache() +<<<<<<< HEAD @skipIfXpu(msg="XPU doesn't support sm carveout") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support sm carveout") @unittest.skipIf(IS_WINDOWS, "Windows doesn't support persistent TMA") @unittest.skipIf( @@ -447,15 +531,26 @@ def scaled_mm( # Create large matrices to ensure we use all possible sms size = 2560 +<<<<<<< HEAD a = torch.randn(size, size, device=GPU_TYPE, dtype=torch.bfloat16) b = ( torch.randn(size, size, device=GPU_TYPE, dtype=torch.bfloat16) +======= + a = torch.randn(size, size, device="cuda", dtype=torch.bfloat16) + b = ( + torch.randn(size, size, device="cuda", dtype=torch.bfloat16) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .transpose(0, 1) .contiguous() .transpose(0, 1) ) +<<<<<<< HEAD scale_a = torch.tensor(1, dtype=torch.float32, device=GPU_TYPE) scale_b = torch.tensor(1, dtype=torch.float32, device=GPU_TYPE) +======= + scale_a = torch.tensor(1, dtype=torch.float32, device="cuda") + scale_b = torch.tensor(1, dtype=torch.float32, device="cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) args = ( (a.to(torch.float8_e4m3fn), b.to(torch.float8_e4m3fn), scale_a, scale_b) @@ -717,7 +812,11 @@ def forward(self, x): m_c = torch.compile(mode="max-autotune")(mod) out, code = run_and_get_code(m_c, x) +<<<<<<< HEAD self.assertEqual(out, mod(x), atol=2e-3, rtol=2e-3) +======= + self.assertEqual(out, mod(x), atol=2e-3, rtol=1e-3) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) FileCheck().check("triton_tem_fused_baddbmm").run(code[0]) @@ -843,6 +942,7 @@ def test_cat_max_autotune_extern(self): self._test_cat_max_autotune_impl(using_triton_mm=False) @skipIfXpu( +<<<<<<< HEAD msg="The fusion not happened because it do not speedup on XPU, see issue #146568" ) @config.patch( @@ -851,6 +951,11 @@ def test_cat_max_autotune_extern(self): "benchmark_epilogue_fusion": False, } ) +======= + msg="The fusion not happend because it do not speedup on XPU, see issue #146568" + ) + @config.patch(max_autotune_gemm_backends="TRITON") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_cat_max_autotune_triton(self): self._test_cat_max_autotune_impl(using_triton_mm=True) @@ -876,7 +981,11 @@ def forward(self, x): self.assertEqual(out, m(input_tensor)) if not TEST_WITH_ROCM: +<<<<<<< HEAD FileCheck().check("def triton_poi_fused_add_cat_").run(code[0]) +======= + FileCheck().check("triton_poi_fused_cat_2.run").run(code[0]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_conv3d(self): fn = torch.nn.functional.conv3d @@ -902,15 +1011,26 @@ def test_conv_backend(self): self.assertIn("NoValidChoicesError", str(context.exception)) +<<<<<<< HEAD +======= + # Some ROCm GPUs don't have enough VRAM to run all autotune configurations and padding benchmarks + @largeTensorTest("30 GB", device=GPU_TYPE) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_non_contiguous_input_mm(self): """ Make sure the triton template can work with non-contiguous inputs without crash. Check https://github.com/pytorch/pytorch/issues/125437 for more details. """ x = rand_strided( +<<<<<<< HEAD (50257, 2048), (1, 50304), dtype=torch.bfloat16, device=GPU_TYPE ) y = rand_strided((2048, 768), (768, 1), dtype=torch.bfloat16, device=GPU_TYPE) +======= + (50257, 32768), (1, 50304), dtype=torch.bfloat16, device=GPU_TYPE + ) + y = rand_strided((32768, 768), (768, 1), dtype=torch.bfloat16, device=GPU_TYPE) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch.compile(mode="max-autotune") def f(x, y): @@ -923,9 +1043,15 @@ def f(x, y): def test_non_contiguous_input_addmm(self): b = torch.randn((768), dtype=torch.bfloat16, device=GPU_TYPE) x = rand_strided( +<<<<<<< HEAD (50257, 2048), (1, 50304), dtype=torch.bfloat16, device=GPU_TYPE ) y = rand_strided((2048, 768), (768, 1), dtype=torch.bfloat16, device=GPU_TYPE) +======= + (50257, 32768), (1, 50304), dtype=torch.bfloat16, device=GPU_TYPE + ) + y = rand_strided((32768, 768), (768, 1), dtype=torch.bfloat16, device=GPU_TYPE) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch.compile(mode="max-autotune") def f(x, y): @@ -937,10 +1063,17 @@ def f(x, y): def test_non_contiguous_input_bmm(self): x = rand_strided( +<<<<<<< HEAD (1, 50257, 2048), (0, 1, 50304), dtype=torch.bfloat16, device=GPU_TYPE ) y = rand_strided( (1, 2048, 768), (0, 768, 1), dtype=torch.bfloat16, device=GPU_TYPE +======= + (1, 50257, 32768), (0, 1, 50304), dtype=torch.bfloat16, device=GPU_TYPE + ) + y = rand_strided( + (1, 32768, 768), (0, 768, 1), dtype=torch.bfloat16, device=GPU_TYPE +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @torch.compile(mode="max-autotune") @@ -954,12 +1087,23 @@ def f(x, y): # TODO: fix accuracy failure of the triton template on XPU. # and enable this test case. @skipIfXpu +<<<<<<< HEAD def test_non_contiguous_input_mm_plus_mm(self): x1 = rand_strided((50257, 2048), (1, 50304), device=GPU_TYPE) y1 = rand_strided((2048, 768), (768, 1), device=GPU_TYPE) x2 = rand_strided((50257, 2048), (1, 50304), device=GPU_TYPE) y2 = rand_strided((2048, 768), (768, 1), device=GPU_TYPE) +======= + # Some ROCm GPUs don't have enough VRAM to run all autotune configurations and padding benchmarks + @largeTensorTest("30 GB", device=GPU_TYPE) + def test_non_contiguous_input_mm_plus_mm(self): + x1 = rand_strided((50257, 32768), (1, 50304), device=GPU_TYPE) + y1 = rand_strided((32768, 768), (768, 1), device=GPU_TYPE) + + x2 = rand_strided((50257, 32768), (1, 50304), device=GPU_TYPE) + y2 = rand_strided((32768, 768), (768, 1), device=GPU_TYPE) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch.compile(mode="max-autotune") def f(x1, y1, x2, y2): @@ -1034,9 +1178,15 @@ def f(x, y): loss.backward() return loss +<<<<<<< HEAD x = torch.randn(B * T, C, requires_grad=True).to(GPU_TYPE).bfloat16() x.retain_grad() y = torch.randint(0, V, (B * T,)).to(GPU_TYPE) +======= + x = torch.randn(B * T, C, requires_grad=True).cuda().bfloat16() + x.retain_grad() + y = torch.randint(0, V, (B * T,)).cuda() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch._inductor.utils as inductor_utils @@ -1048,6 +1198,10 @@ def f(x, y): assert same(expect, actual, tol=1e-2), f"ref:\n{expect}\nact:\n{actual}" @skipIfXpu +<<<<<<< HEAD +======= + @unittest.skipIf(TEST_WITH_ROCM, "decompose_k not supported on ROCm") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf( config.cpp_wrapper, "decompose_k not supported for cpp_wrapper yet" ) @@ -1070,8 +1224,13 @@ def test_max_autotune_decompose_k(self, sizes, dtype, dynamic): M, N, K = sizes +<<<<<<< HEAD a = torch.randn(M, K, dtype=dtype, device=GPU_TYPE, requires_grad=True) b = torch.randn(K, N, dtype=dtype, device=GPU_TYPE, requires_grad=True) +======= + a = torch.randn(M, K, dtype=dtype, device="cuda", requires_grad=True) + b = torch.randn(K, N, dtype=dtype, device="cuda", requires_grad=True) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) possible_splits = range(2, min(K // M, K // N) + 1) @@ -1095,7 +1254,11 @@ def check_divisors(code): # We assume with the large k dim relative to m, n, decompose_k will be most performant out, code = run_and_get_code(compiled_func, a, b) +<<<<<<< HEAD if dynamic or torch.version.hip: +======= + if dynamic: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) FileCheck().check_not("extern_kernels.bmm_dtype").check_not( "decompose_k" ).run(code[0]) @@ -1109,7 +1272,11 @@ def check_divisors(code): # Test adding epilogue also equivalent to eager compiled_func = torch.compile(lambda a, b: (a @ b).relu(), dynamic=dynamic) out, code = run_and_get_code(compiled_func, a, b) +<<<<<<< HEAD if dynamic or torch.version.hip: +======= + if dynamic: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) FileCheck().check_not("extern_kernels.bmm_dtype").check_not( "decompose_k" ).run(code[0]) @@ -1128,9 +1295,13 @@ def check_divisors(code): lambda a, b: (a.transpose(0, 1) @ b).relu(), dynamic=dynamic ) out, code = run_and_get_code(compiled_func, a, b) +<<<<<<< HEAD # DecomposeK is not enabled for AMD yet if dynamic or torch.version.hip: +======= + if dynamic: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) FileCheck().check_not("extern_kernels.bmm_dtype").check_not( "decompose_k" ).run(code[0]) @@ -1168,10 +1339,17 @@ def f(a, b): return (a_in @ b).relu() a = torch.randn( +<<<<<<< HEAD 32, 32768, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True ) b = torch.randn( 32768, 64, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True +======= + 32, 32768, dtype=torch.bfloat16, device="cuda", requires_grad=True + ) + b = torch.randn( + 32768, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) torch._dynamo.reset() @@ -1211,11 +1389,17 @@ def f(a, b): a_in = torch.cat([a for _ in range(256)], dim=0) return (a_in @ b).relu().sum() +<<<<<<< HEAD a = torch.randn( 8, 64, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True ) b = torch.randn( 64, 32768, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True +======= + a = torch.randn(8, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True) + b = torch.randn( + 64, 32768, dtype=torch.bfloat16, device="cuda", requires_grad=True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) torch._dynamo.reset() @@ -1255,21 +1439,38 @@ def f(a, b): a = a.transpose(0, 1) return a @ b +<<<<<<< HEAD a = torch.randn((32768, 256), device=GPU_TYPE, dtype=torch.bfloat16) b = torch.randn((32768, 1152), device=GPU_TYPE, dtype=torch.bfloat16) +======= + a = torch.randn((32768, 256), device="cuda", dtype=torch.bfloat16) + b = torch.randn((32768, 1152), device="cuda", dtype=torch.bfloat16) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) b = b[:, :1096] # Force only decomposeK choice with ( +<<<<<<< HEAD override_template_heuristics( device_type=GPU_TYPE, template_op_pairs=[(torch._inductor.kernel.mm.mm_template.name, "mm")], ), +======= + mock.patch( + "torch._inductor.kernel.mm.V.choices.get_base_mm_configs" + ) as base_mm_mock, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mock.patch( "torch._inductor.kernel.mm.use_decompose_k_choice" ) as decompose_mock, ): +<<<<<<< HEAD +======= + mm_configs_mock = MagicMock() + mm_configs_mock.return_value = [] + base_mm_mock.return_value = mm_configs_mock +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) decompose_mock.return_value = True compiled_f = torch.compile(f) out, code = run_and_get_code(compiled_f, a, b) @@ -1284,6 +1485,7 @@ def f(a, b): code[0] ) +<<<<<<< HEAD @unittest.skipIf(not torch.version.hip, "ROCM only") @parametrize("dtype", (torch.float16, torch.bfloat16, torch.float32)) @parametrize("sizes", ((64, 128, 256), (128, 256, 512), (256, 512, 1024))) @@ -1479,6 +1681,8 @@ def mm_transpose_relu(a, b): # Check that contiguous transform was used FileCheck().check("contiguous_mm").run(code[0]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_triton_template_generated_code_cache_key(self): generate_and_load_args = len( inspect.signature( @@ -1494,7 +1698,11 @@ def test_triton_template_generated_code_cache_key(self): # Make sure all args of generate_and_load_args are passed to make_key_args (Except generate_with_caching) # update this function each time new arg added to generate_and_load and make sure arg is added to make_key self.assertEqual(generate_and_load_args - 1, make_key_args) +<<<<<<< HEAD self.assertEqual(generate_and_load_args, 17) +======= + self.assertEqual(generate_and_load_args, 16) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @fresh_cache() @config.patch( @@ -1582,7 +1790,11 @@ def func_test1(x, y, z, m): 'layout':"[[10,30],[30,1],torch.float32,device(type='cuda',index=0),0]", 'num_consumer_groups':0,'num_buffers_warp_spec':0,'epilogue_fn_hash':'identity', 'kwargs':{'EVEN_K':False,'ALLOW_TF32':True,'USE_FAST_ACCUM':False,'ACC_TYPE':'tl.float32', +<<<<<<< HEAD 'BLOCK_M':16,'BLOCK_N':32,'BLOCK_K':16,'GROUP_M':8},'hint_override':None}""" +======= + 'BLOCK_M':16,'BLOCK_N':32,'BLOCK_K':16,'GROUP_M':8}}""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) expected = expected.replace("cuda", GPU_TYPE) self.assertExpectedInline( @@ -1616,12 +1828,21 @@ def func_test1(x, y, z, m): if not TEST_WITH_ROCM: expected = """{ 'input_nodes':[ +<<<<<<< HEAD "[[s77,s27],[s27,1],torch.float32,device(type='cuda',index=0),0]", "[[s27,s94],[s94,1],torch.float32,device(type='cuda',index=0),0]"], 'num_stages':1,'num_warps':2,'prefix_args':0,'suffix_args':0,'call_sizes':[s77,s94], 'layout':"[[s77,s94],[s94,1],torch.float32,device(type='cuda',index=0),0]",'num_consumer_groups':0, 'num_buffers_warp_spec':0,'epilogue_fn_hash':'identity','kwargs':{'EVEN_K':False,'ALLOW_TF32':True, 'USE_FAST_ACCUM':False,'ACC_TYPE':'tl.float32','BLOCK_M':16,'BLOCK_N':32,'BLOCK_K':16,'GROUP_M':8},'hint_override':None}""" +======= + "[[s77,s17],[s17,1],torch.float32,device(type='cuda',index=0),0]", + "[[s17,s94],[s94,1],torch.float32,device(type='cuda',index=0),0]"], + 'num_stages':1,'num_warps':2,'prefix_args':0,'suffix_args':0,'call_sizes':[s77,s94], + 'layout':"[[s77,s94],[s94,1],torch.float32,device(type='cuda',index=0),0]",'num_consumer_groups':0, + 'num_buffers_warp_spec':0,'epilogue_fn_hash':'identity','kwargs':{'EVEN_K':False,'ALLOW_TF32':True, + 'USE_FAST_ACCUM':False,'ACC_TYPE':'tl.float32','BLOCK_M':16,'BLOCK_N':32,'BLOCK_K':16,'GROUP_M':8}}""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) expected = expected.replace("cuda", GPU_TYPE) self.assertExpectedInline( remove_white_space(cache_key), @@ -1782,7 +2003,10 @@ def misses(): self.assertEqual(hits(), 4) self.assertEqual(misses(), 4) +<<<<<<< HEAD @fresh_cache() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skipIfXpu @unittest.skipIf(TEST_WITH_ROCM, "decompose_k not supported on ROCm") @unittest.skipIf( @@ -1792,6 +2016,7 @@ def misses(): max_autotune=True, max_autotune_gemm_backends="TRITON", autotune_fallback_to_aten=False, +<<<<<<< HEAD ) @parametrize("num_decompose_k_splits", (0, 5, 20)) @parametrize("decompose_k_threshold", (8, 16)) @@ -1828,6 +2053,21 @@ def test_max_autotune_decompose_k_envvars( else: self.assertTrue(decompose_count > 0) self.assertTrue(decompose_count <= num_decompose_k_splits) +======= + disable_decompose_k=True, + ) + def test_max_autotune_disable_decompose_K(self): + M, N, K = (32, 32, 32768) + + a = torch.randn(M, K, dtype=torch.float16, device="cuda", requires_grad=True) + b = torch.randn(K, N, dtype=torch.float16, device="cuda", requires_grad=True) + + compiled_func = torch.compile(lambda a, b: a @ b) + out, code = run_and_get_code(compiled_func, a, b) + + for codegen in code: + FileCheck().check_not("decompose_k").run(codegen) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skipIfXpu @unittest.skipIf( @@ -1840,6 +2080,7 @@ def f(a, b): M, N, K = (1024, 1024, 1024) +<<<<<<< HEAD a = torch.randn(M, K, dtype=torch.float16, device=GPU_TYPE, requires_grad=True) b = torch.randn(K, N, dtype=torch.float16, device=GPU_TYPE, requires_grad=True) @@ -1847,6 +2088,15 @@ def f(a, b): "torch._inductor.template_heuristics.registry.get_template_heuristic" ) as config_mock: config_heuristics = CUDAMMTemplateConfigHeuristic() +======= + a = torch.randn(M, K, dtype=torch.float16, device="cuda", requires_grad=True) + b = torch.randn(K, N, dtype=torch.float16, device="cuda", requires_grad=True) + + with mock.patch( + "torch._inductor.kernel.mm.V.choices.get_config_heuristics" + ) as config_mock: + config_heuristics = CUDAConfigHeuristic() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Traditionally, this would be set of all possible configs # We mock out the code path for the sake of the unit test @@ -1864,6 +2114,7 @@ def f(a, b): if "benchmark_gpu" in counter: self.assertEqual(counters["inductor"][counter], 2) +<<<<<<< HEAD @config.patch( { "max_autotune": True, @@ -1973,6 +2224,8 @@ def choice_validator(choices): finally: clear_preprocessing_fns() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestMaxAutotunePrecompile(TestCase): def test_precompilation_threads(self): @@ -2008,7 +2261,10 @@ def no_lookup( op: str, inputs: str, benchmark: Callable[[Any], dict[ChoiceCaller, float]], +<<<<<<< HEAD hint_override: Optional[int] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Optional[dict[ChoiceCaller, float]]: if benchmark is not None: return benchmark(choices) @@ -2062,6 +2318,24 @@ def fn(a, b, c): fn_c = torch.compile(mode="max-autotune-no-cudagraphs")(fn) self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 0) +<<<<<<< HEAD +======= + @fresh_cache() + @config.patch(search_autotune_cache=True) + def test_search_autotune_cache(self): + def fn(a, b, c): + a = (a @ b) @ c + a, b, c = (t.to(torch.float16) for t in [a, b, c]) + return (a @ b) @ c + + fn_c = torch.compile()(fn) + inputs = [torch.rand([256, 256], device=GPU_TYPE) for _ in range(3)] + from torch._dynamo.utils import counters + + self.assertEqual(fn(*inputs), fn_c(*inputs), atol=1e-2, rtol=1e-2) + self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 0) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @config.patch(autotune_local_cache=False, autotune_remote_cache=False) @runOnRocmArch(MI300_ARCH) def test_precompilations(self): @@ -2098,7 +2372,11 @@ def test_benchmark_choice_in_subproc(self): )() # a dummy graph to construct the GraphLowering graph = GraphLowering(gm) +<<<<<<< HEAD # the graph handler is needed to create benchmark example value below +======= + # the graph handler is neede to create benchmark example value below +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with V.set_graph_handler(graph): buf1 = self._create_buffer("mat1", (2, 3)) buf2 = self._create_buffer("mat2", (3, 2)) @@ -2138,7 +2416,11 @@ def test_benchmark_choice_fail_in_subproc(self): )() # a dummy graph to construct the GraphLowering graph = GraphLowering(gm) +<<<<<<< HEAD # the graph handler is needed to create benchmark example value below +======= + # the graph handler is neede to create benchmark example value below +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with V.set_graph_handler(graph): buf1 = self._create_buffer("mat1", (2, 3)) buf2 = self._create_buffer("mat2", (3, 2)) @@ -2486,6 +2768,7 @@ def test_tuning_pool_multiple_devices(self): tuning_pool.shutdown() +<<<<<<< HEAD def test_add_feedback_saver(self): """Test that add_feedback_saver correctly adds feedback functions.""" from torch._inductor.select_algorithm import get_algorithm_selector_cache @@ -2598,6 +2881,8 @@ def mm(a, b): # Clean up clear_feedback_savers() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @instantiate_parametrized_tests class TestPrologueFusion(TestCase): @@ -2637,9 +2922,12 @@ def check_code(self, code_str, num_kernels, num_allocs, num_deallocs): "del", num_deallocs, exactly=True ).run(code_str) +<<<<<<< HEAD @skipIfXpu( msg="Triton issue exposed by new driver, will be resolved after next triton update." ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize("sizes", ((64, 128, 256), (128, 128, 128), (63, 120, 250))) def test_upcast(self, sizes): M, K, N = sizes @@ -2756,7 +3044,11 @@ def foo(x, y): } ) @skipIfXpu( +<<<<<<< HEAD msg="The fusion not happened because it do not speedup on XPU, see issue #146568" +======= + msg="The fusion not happend because it do not speedup on XPU, see issue #146568" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def test_pending_fusions_multiple(self): def multi_use(x, y): @@ -2790,7 +3082,11 @@ def resolve_pending(x): } ) @skipIfXpu( +<<<<<<< HEAD msg="The fusion not happened because it do not speedup on XPU, see issue #146568" +======= + msg="The fusion not happend because it do not speedup on XPU, see issue #146568" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def test_pending_fusion_pro_and_epi(self): def test_multiple_fusions(x): @@ -2804,9 +3100,12 @@ def test_multiple_fusions(x): ).run(code[0]) self.assertEqual(out, test_multiple_fusions(x), atol=0.05, rtol=0.05) +<<<<<<< HEAD @skipIfXpu( msg="Triton issue exposed by new driver, will be resolved after next triton update." ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize("sizes", ((64, 128, 256), (128, 128, 128), (63, 120, 250))) def test_multiple_inputs(self, sizes): M, K, N = sizes @@ -2938,8 +3237,13 @@ def foo(x, y, z): out, code = run_and_get_code(torch.compile(foo), x, y, z) self.assertEqual(out, foo(x, y, z), atol=0.05, rtol=0.05) +<<<<<<< HEAD # there's one more dealloc than there should be because of a buffer reuse. TODO: # not sure why disabling buffer reuse doesn't stop +======= + # theres one more dealloc than there should be because of a buffer reuse. TODO: + # not sure why disabling buffer reuse doesnt stop +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.check_code(code[0], num_kernels=2, num_allocs=2, num_deallocs=4) # XPU have not enabled pad_mm in fx_passes, so there is always one kernel. diff --git a/test/inductor/test_memory.py b/test/inductor/test_memory.py index 80372bca9fdca..ce9f968422184 100644 --- a/test/inductor/test_memory.py +++ b/test/inductor/test_memory.py @@ -8,6 +8,7 @@ from torch._inductor import config, memory from torch._inductor.test_case import TestCase from torch._inductor.utils import run_and_get_triton_code +<<<<<<< HEAD from torch.testing._internal.common_utils import serialTest from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU @@ -21,6 +22,11 @@ TRITON_AVAILABLE = False +======= +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class Foo(torch.nn.Module): """ The default compiled graph is @@ -68,6 +74,7 @@ def test_reorder_peak_memory(self): outp_corr = self.model(self.inputs) compiled_model = torch.compile(self.model) code = run_and_get_triton_code(compiled_model, self.inputs) +<<<<<<< HEAD call_str = ( "def call(self, args):" @@ -78,6 +85,11 @@ def test_reorder_peak_memory(self): ( FileCheck() .check(call_str) +======= + ( + FileCheck() + .check("def call(args):") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .check("buf1 = ") .check("buf0 = ") .check("buf2 = ") @@ -112,12 +124,15 @@ def reorder_with_only_lpmf( methods=[memory.topological_sort_lpmf], ) +<<<<<<< HEAD call_str = ( "def call(self, args):" if torch._inductor.config.graph_partition else "def call(args):" ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with mock.patch.object( memory, "reorder_for_peak_memory", reorder_with_only_lpmf ): @@ -126,7 +141,11 @@ def reorder_with_only_lpmf( code = run_and_get_triton_code(compiled_model, self.inputs) ( FileCheck() +<<<<<<< HEAD .check(call_str) +======= + .check("def call(args):") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .check("buf1 = ") .check("buf0 = ") .check("buf2 = ") @@ -161,22 +180,31 @@ def reorder_with_only_bfs( methods=[memory.topological_sort_bfs], ) +<<<<<<< HEAD call_str = ( "def call(self, args):" if torch._inductor.config.graph_partition else "def call(args):" ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with mock.patch.object( memory, "reorder_for_peak_memory", reorder_with_only_bfs ): compiled_model = torch.compile(self.model) code = run_and_get_triton_code(compiled_model, self.inputs) +<<<<<<< HEAD ( FileCheck() .check(call_str) +======= + ( + FileCheck() + .check("def call(args):") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .check("buf0 = ") .check("buf1 = ") .check("buf2 = ") @@ -211,12 +239,15 @@ def reorder_with_only_dfs( methods=[memory.topological_sort_dfs], ) +<<<<<<< HEAD call_str = ( "def call(self, args):" if torch._inductor.config.graph_partition else "def call(args):" ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with mock.patch.object( memory, "reorder_for_peak_memory", reorder_with_only_dfs ): @@ -225,7 +256,11 @@ def reorder_with_only_dfs( code = run_and_get_triton_code(compiled_model, self.inputs) ( FileCheck() +<<<<<<< HEAD .check(call_str) +======= + .check("def call(args):") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .check("buf0 = ") .check("buf2 = ") .check("buf4 = ") @@ -239,6 +274,7 @@ def reorder_with_only_dfs( outp = compiled_model(self.inputs) self.assertTrue(same(outp, outp_corr)) +<<<<<<< HEAD @mock.patch.object(config, "allow_buffer_reuse", False) @unittest.skipUnless(TRITON_AVAILABLE, "Triton is not available") @config.patch("test_configs.track_memory_lifecycle", "assert") @@ -319,6 +355,8 @@ def f(a, p): # succ nodes should be forwarded to pre mutation buffer self.assertTrue(buffer_info[post][2] <= buffer_info[pre][2]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf( not torch.cuda.is_available() or torch.cuda.get_device_properties().total_memory < int(1e10), @@ -339,6 +377,7 @@ def f(a, b, c): expected_bound = a.size(0) * c.size(1) * a.dtype.itemsize * 2 self.assertLess(peak_mem, expected_bound) +<<<<<<< HEAD @serialTest() def test_fusion_acc_large_reads(self): def f(x, y, z): @@ -434,6 +473,8 @@ def replace_foreach(gm): code = run_and_get_triton_code(foo, inp, inp2) FileCheck().check("allocated=['buf0']").run(code) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_memory_planning.py b/test/inductor/test_memory_planning.py index 1bcdeaa08e955..d6904490c47ed 100644 --- a/test/inductor/test_memory_planning.py +++ b/test/inductor/test_memory_planning.py @@ -24,6 +24,7 @@ from torch.export import Dim +<<<<<<< HEAD try: from .test_aot_inductor import AOTIRunnerUtil except ImportError: @@ -32,6 +33,8 @@ ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @requires_gpu() @config.patch(memory_planning=True) class TestMemoryPlanning(TestCase): @@ -84,6 +87,16 @@ def test_cpp_wrapper(self): @skipIfXpu(msg="aoti doesn't work on XPU") def test_aoti(self): +<<<<<<< HEAD +======= + try: + from .test_aot_inductor import AOTIRunnerUtil + except ImportError: + from test_aot_inductor import ( # @manual=fbcode//caffe2/test/inductor:test_aot_inductor-library + AOTIRunnerUtil, + ) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f, args = self._generate(device=GPU_TYPE) dim0_x = Dim("dim0_x", min=1, max=2048) dynamic_shapes = ({0: dim0_x}, None, None) @@ -104,6 +117,7 @@ def test_aoti(self): ).check_next("aoti_torch__alloc_from_pool(pool1, 0").run(code) self.assertTrue(same(f(*args), result)) +<<<<<<< HEAD @config.patch({"triton.autotune_at_compile_time": False}) def test_unbacked_symint(self): # when allocation's size has unbacked symints @@ -152,6 +166,8 @@ def forward(self, x, y): "AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool(pool0, 0, cached_torch_dtype_float32, 3, int_array_4, int_array_5, &tmp_tensor_handle_1));" # noqa: B950 ).check("RAIIAtenTensorHandle(tmp_tensor_handle_1);").run(code) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": if HAS_GPU: diff --git a/test/inductor/test_minifier.py b/test/inductor/test_minifier.py index e8d695a1852d3..79efd5c64d4ef 100644 --- a/test/inductor/test_minifier.py +++ b/test/inductor/test_minifier.py @@ -249,7 +249,11 @@ def _aoti_check_relu_repro(self, res): assert res is not None ep_file_path = res.get_exported_program_path() assert ep_file_path is not None +<<<<<<< HEAD gm = export_load(ep_file_path).module(check_guards=False) +======= + gm = export_load(ep_file_path).module() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertExpectedInline( str(gm.code).strip(), """\ diff --git a/test/inductor/test_minifier_utils.py b/test/inductor/test_minifier_utils.py index 80c773830b4af..6617c82b5db29 100644 --- a/test/inductor/test_minifier_utils.py +++ b/test/inductor/test_minifier_utils.py @@ -63,7 +63,11 @@ def true_fn(x): ) model = M() +<<<<<<< HEAD gm = torch.export.export(model, inputs, strict=False).module(check_guards=False) +======= + gm = torch.export.export(model, inputs, strict=False).module() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO: make NNModuleToString.convert() generate string for nested submodules. model_string = get_module_string(gm) diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index 8bbf76af6bac6..d8253a7ad39e1 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -18,7 +18,10 @@ from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer from torch.nn import functional as F from torch.testing._internal.common_device_type import instantiate_device_type_tests +<<<<<<< HEAD from torch.testing._internal.common_mkldnn import reduced_f32_on_and_off +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_quantization import ( _generate_qdq_quantized_model, skipIfNoDynamoSupport, @@ -177,7 +180,10 @@ def _test_common( is_dynamic=False, quantizer=None, compile_options={}, # noqa: B006 +<<<<<<< HEAD quantization_with_autocast=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): if not hasattr(self, "device"): has_xpu = any( @@ -207,6 +213,7 @@ def _test_common( assert check_autocast == torch.float32 maybe_autocast = contextlib.nullcontext() if check_quantization: +<<<<<<< HEAD if quantization_with_autocast: with maybe_autocast: convert_model = _generate_qdq_quantized_model( @@ -216,6 +223,11 @@ def _test_common( convert_model = _generate_qdq_quantized_model( mod, inputs, is_qat, is_dynamic, quantizer ) +======= + convert_model = _generate_qdq_quantized_model( + mod, inputs, is_qat, is_dynamic, quantizer + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with torch.no_grad(), maybe_autocast: _ = torch.compile(convert_model)(*inputs) matcher_check_fn() @@ -224,12 +236,16 @@ def _test_common( clone_inputs = self._clone_inputs(inputs) expected = mod(*inputs) actual = torch.compile(mod, **compile_options)(*clone_inputs) +<<<<<<< HEAD if self.precision != 0: torch.testing.assert_close( actual, expected, atol=self.precision, rtol=self.precision ) else: torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol) +======= + torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) matcher_check_fn() def _test_code_common( @@ -319,11 +335,14 @@ def forward(self, x): memory_format, dtype, ) in options: +<<<<<<< HEAD if ( dtype != torch.float32 and torch.backends.mkldnn.matmul.fp32_precision == "tf32" ): continue +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) metrics.reset() if dim == 4: x_shape = (1, 3, 56, 56) @@ -362,7 +381,10 @@ def matcher_check_fn(): @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm +<<<<<<< HEAD @reduced_f32_on_and_off() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_conv2d_unary(self, device): self.device = device self._test_conv_unary_base(dim=4) @@ -370,7 +392,10 @@ def test_conv2d_unary(self, device): @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm +<<<<<<< HEAD @reduced_f32_on_and_off() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_conv3d_unary(self, device): self.device = device self._test_conv_unary_base(dim=5) @@ -454,7 +479,10 @@ def matcher_check_fn(): @skipIfXpu( msg="The operator 'mkldnn::_convolution_transpose_pointwise' is not currently implemented for the XPU device." ) +<<<<<<< HEAD @reduced_f32_on_and_off() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_conv_transpose2d_unary(self, device): self.device = device self._test_conv_transpose_unary_base(dim=4) @@ -465,7 +493,10 @@ def test_conv_transpose2d_unary(self, device): @skipIfXpu( msg="The operator 'mkldnn::_convolution_transpose_pointwise' is not currently implemented for the XPU device." ) +<<<<<<< HEAD @reduced_f32_on_and_off() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_conv_transpose3d_unary(self, device): self.device = device self._test_conv_transpose_unary_base(dim=5) @@ -520,11 +551,14 @@ def forward(self, x): memory_format, dtype, ) in options: +<<<<<<< HEAD if ( dtype != torch.float32 and torch.backends.mkldnn.matmul.fp32_precision == "tf32" ): continue +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) metrics.reset() if dim == 4: x_shape = (1, 3, 56, 56) @@ -560,7 +594,10 @@ def matcher_check_fn(): @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm +<<<<<<< HEAD @reduced_f32_on_and_off(0.02) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_conv2d_binary(self, device): self.device = device self._test_conv_binary_base(dim=4) @@ -568,7 +605,10 @@ def test_conv2d_binary(self, device): @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm +<<<<<<< HEAD @reduced_f32_on_and_off(0.02) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_conv3d_binary(self, device): self.device = device self._test_conv_binary_base(dim=5) @@ -667,7 +707,10 @@ def matcher_check_fn(): @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm +<<<<<<< HEAD @reduced_f32_on_and_off() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_conv2d_binary_broadcast_shapes(self, device): self.device = device self._test_conv_binary_broadcast_shapes_base(dim=4) @@ -675,7 +718,10 @@ def test_conv2d_binary_broadcast_shapes(self, device): @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm +<<<<<<< HEAD @reduced_f32_on_and_off() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_conv3d_binary_broadcast_shapes(self, device): self.device = device self._test_conv_binary_broadcast_shapes_base(dim=5) @@ -684,7 +730,10 @@ def test_conv3d_binary_broadcast_shapes(self, device): @skipIfNoONEDNN @skipIfRocm @unittest.skipIf(IS_FBCODE, "Failing in fbcode") +<<<<<<< HEAD @reduced_f32_on_and_off() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_conv2d_linear_add_broadcast_shapes(self, device): self.device = device @@ -716,7 +765,10 @@ def matcher_check_fn(): class TestPatternMatcher(TestPatternMatcherBase): +<<<<<<< HEAD @reduced_f32_on_and_off() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_linear_unary(self, device="cpu"): self.device = device @@ -747,6 +799,7 @@ def forward(self, x): dtypes.append(torch.bfloat16) if is_mkldnn_fp16_supported(self.device): dtypes.append(torch.float16) +<<<<<<< HEAD if torch.backends.mkldnn.matmul.fp32_precision in ["bf16", "tf32"]: dtypes.append(torch.float32) options = itertools.product(unary_list, [True, False], dtypes) @@ -756,6 +809,10 @@ def forward(self, x): and torch.backends.mkldnn.matmul.fp32_precision == "tf32" ): continue +======= + options = itertools.product(unary_list, [True, False], dtypes) + for unary_fn, bias, dtype in options: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) metrics.reset() mod = M(unary_fn, 10, 30, bias=bias).eval() # only fuse for linear when the dtype is bf16 @@ -764,7 +821,11 @@ def forward(self, x): def matcher_check_fn(): match_nodes = unary_list[unary_fn] +<<<<<<< HEAD if dtype != torch.float32 and self._check_unary_is_decomposed(unary_fn): +======= + if self._check_unary_is_decomposed(unary_fn): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Has extra dtype conversion nodes for autocast. match_nodes += 2 self.assertEqual( @@ -776,6 +837,7 @@ def matcher_check_fn(): ) self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype) +<<<<<<< HEAD # only generated 1 kernel for "to_dtype" expected_kernel_count = 2 if TEST_ACL else 1 if dtype == torch.float32: @@ -784,6 +846,11 @@ def matcher_check_fn(): self.assertEqual(metrics.generated_kernel_count, expected_kernel_count) @reduced_f32_on_and_off() +======= + # only generated 1 kernel for "to" + self.assertEqual(metrics.generated_kernel_count, 2 if TEST_ACL else 1) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(not TEST_MKL, "Test requires MKL") def test_linear_fp32(self, device="cpu"): self.device = device @@ -931,7 +998,10 @@ def matcher_check_fn(): # 1 kernel for "to_lowp", 2 kernels for unary ops self.assertEqual(metrics.generated_kernel_count, 3) +<<<<<<< HEAD @reduced_f32_on_and_off() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_linear_binary(self, device="cpu"): self.device = device @@ -953,8 +1023,11 @@ def forward(self, x, y): dtypes.append(torch.bfloat16) if is_mkldnn_fp16_supported(self.device): dtypes.append(torch.float16) +<<<<<<< HEAD if torch.backends.mkldnn.matmul.fp32_precision in ["bf16", "tf32"]: dtypes.append(torch.float32) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) options = itertools.product( binary_list, [[2, 3, 10], [2, 10]], [True, False], dtypes ) @@ -962,11 +1035,14 @@ def forward(self, x, y): for binary_fn, input_shape, bias, dtype in options: metrics.reset() +<<<<<<< HEAD if ( dtype != torch.float32 and torch.backends.mkldnn.matmul.fp32_precision == "tf32" ): continue +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def matcher_check_fn(): self.assertEqual( @@ -996,12 +1072,16 @@ def matcher_check_fn(): matcher_check_fn, check_autocast=dtype, ) +<<<<<<< HEAD # only generated 1 kernel for "to_dtype" expected_kernel_count = 2 if TEST_ACL else 1 if dtype == torch.float32: # In BF32, input is float32, will not generate kernel for "to_dtype" expected_kernel_count -= 1 self.assertEqual(metrics.generated_kernel_count, expected_kernel_count) +======= + self.assertEqual(metrics.generated_kernel_count, 2 if TEST_ACL else 1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_linear_binary_broadcast_shapes(self, device="cpu"): self.device = device @@ -1113,12 +1193,16 @@ def matcher_check_fn(): v = torch.randn(2, 4, 16).to(dtype) self._test_common(mod, (v,), matcher_check_fn, rtol=1e-2, atol=1e-2) +<<<<<<< HEAD def _qconv2d_test_helper( self, device="cpu", int8_mixed_bf16=False, quantization_with_autocast=False, ): +======= + def _qconv2d_test_helper(self, device="cpu", int8_mixed_bf16=False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class M(torch.nn.Module): def __init__( self, @@ -1151,7 +1235,11 @@ def matcher_check_fn(): ) self.assertEqual( counters["inductor"]["qconv_weight_prepack_matcher_nodes"], +<<<<<<< HEAD (16 if quantization_with_autocast else 18) if int8_mixed_bf16 else 12, +======= + 18 if int8_mixed_bf16 else 12, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) self.assertEqual( counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 3 @@ -1163,7 +1251,10 @@ def matcher_check_fn(): matcher_check_fn, check_quantization=True, check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float, +<<<<<<< HEAD quantization_with_autocast=quantization_with_autocast, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @skipIfNoDynamoSupport @@ -1197,6 +1288,7 @@ def test_qconv2d_int8_mixed_bf16(self): @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN +<<<<<<< HEAD @skipIfRocmArch(MI300_ARCH) def test_qconv2d_int8_mixed_bf16_use_autocast(self): r""" @@ -1207,6 +1299,8 @@ def test_qconv2d_int8_mixed_bf16_use_autocast(self): @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skipIfNoXPU def test_qconv2d_int8_mixed_bf16_xpu(self): r""" @@ -2353,7 +2447,10 @@ def _qlinear_test_helper( bias=True, is_dynamic=False, is_qat=False, +<<<<<<< HEAD quantization_with_autocast=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): class M(torch.nn.Module): def __init__(self, use_bias, do_permute=False): @@ -2392,14 +2489,21 @@ def _default_matcher_check_fn(): check_quantization=True, is_qat=is_qat, is_dynamic=is_dynamic, +<<<<<<< HEAD quantization_with_autocast=quantization_with_autocast, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @skipIfNoDynamoSupport @skipIfNoONEDNN def test_qlinear_cpu(self): r""" +<<<<<<< HEAD This testcase will quantize a single Linear Module. +======= + This testcase will quantize a single Linear Moduel. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ for bias in [True, False]: self._qlinear_test_helper((torch.randn((2, 4)),), bias=bias) @@ -2409,7 +2513,11 @@ def test_qlinear_cpu(self): @skipIfNoXPU def test_qlinear_xpu(self): r""" +<<<<<<< HEAD This testcase will quantize a single Linear Module. +======= + This testcase will quantize a single Linear Moduel. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ for bias in [True, False]: self._qlinear_test_helper( @@ -2420,7 +2528,11 @@ def test_qlinear_xpu(self): @skipIfNoONEDNN def test_dynamic_qlinear_cpu(self): r""" +<<<<<<< HEAD This testcase will quantize a single Linear Module. +======= + This testcase will quantize a single Linear Moduel. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ for bias in [True, False]: self._qlinear_test_helper( @@ -2431,7 +2543,11 @@ def test_dynamic_qlinear_cpu(self): @skipIfNoONEDNN def test_dynamic_qlinear_qat_cpu(self): r""" +<<<<<<< HEAD This testcase will quantize a single Linear Module. +======= + This testcase will quantize a single Linear Moduel. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ for bias in [True, False]: self._qlinear_test_helper( @@ -2442,7 +2558,11 @@ def test_dynamic_qlinear_qat_cpu(self): @skipIfNoONEDNN def test_dynamic_qlinear_input_dim_exceeds_2(self): r""" +<<<<<<< HEAD This testcase will quantize a single Linear Module. +======= + This testcase will quantize a single Linear Moduel. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ for bias in [True, False]: self._qlinear_test_helper( @@ -2454,7 +2574,11 @@ def test_dynamic_qlinear_input_dim_exceeds_2(self): @skipIfNoONEDNN def test_qlinear_int8_mixed_bf16(self): r""" +<<<<<<< HEAD This testcase will quantize a single Linear Module with int8_mixed_bf16 quantization. +======= + This testcase will quantize a single Linear Moduel with int8_mixed_bf16 quantization. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ for bias in [True, False]: self._qlinear_test_helper( @@ -2463,6 +2587,7 @@ def test_qlinear_int8_mixed_bf16(self): @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 +<<<<<<< HEAD @skipIfNoONEDNN def test_qlinear_int8_mixed_bf16_use_autocast(self): r""" @@ -2482,6 +2607,12 @@ def test_qlinear_int8_mixed_bf16_use_autocast(self): def test_qlinear_int8_mixed_bf16_xpu(self): r""" This testcase will quantize a single Linear Module with int8_mixed_bf16 quantization. +======= + @skipIfNoXPU + def test_qlinear_int8_mixed_bf16_xpu(self): + r""" + This testcase will quantize a single Linear Moduel with int8_mixed_bf16 quantization. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ for bias in [True, False]: self._qlinear_test_helper( @@ -2495,7 +2626,11 @@ def test_qlinear_int8_mixed_bf16_xpu(self): @skipIfNoONEDNN def test_qlinear_input_dim_exceeds_2(self): r""" +<<<<<<< HEAD This testcase will quantize a single Linear Module. +======= + This testcase will quantize a single Linear Moduel. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ for bias in [True, False]: self._qlinear_test_helper((torch.randn((2, 3, 4)),), bias=bias) @@ -2505,7 +2640,11 @@ def test_qlinear_input_dim_exceeds_2(self): @skipIfNoXPU def test_qlinear_input_dim_exceeds_2_xpu(self): r""" +<<<<<<< HEAD This testcase will quantize a single Linear Module. +======= + This testcase will quantize a single Linear Moduel. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ for bias in [True, False]: self._qlinear_test_helper( @@ -2517,7 +2656,11 @@ def test_qlinear_input_dim_exceeds_2_xpu(self): @skipIfNoONEDNN def test_qlinear_int8_mixed_bf16_input_dim_exceeds_2(self): r""" +<<<<<<< HEAD This testcase will quantize a single Linear Module with int8_mixed_bf16 quantization. +======= + This testcase will quantize a single Linear Moduel with int8_mixed_bf16 quantization. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ for bias in [True, False]: self._qlinear_test_helper( @@ -2527,6 +2670,7 @@ def test_qlinear_int8_mixed_bf16_input_dim_exceeds_2(self): @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN +<<<<<<< HEAD def test_qlinear_int8_mixed_bf16_input_dim_exceeds_2_use_autocast(self): r""" This testcase will quantize a single Linear Module with int8_mixed_bf16 quantization. @@ -2546,6 +2690,12 @@ def test_qlinear_int8_mixed_bf16_input_dim_exceeds_2_use_autocast(self): def test_qlinear_int8_mixed_bf16_input_dim_exceeds_2_xpu(self): r""" This testcase will quantize a single Linear Module with int8_mixed_bf16 quantization. +======= + @skipIfNoXPU + def test_qlinear_int8_mixed_bf16_input_dim_exceeds_2_xpu(self): + r""" + This testcase will quantize a single Linear Moduel with int8_mixed_bf16 quantization. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ for bias in [True, False]: self._qlinear_test_helper( @@ -2612,6 +2762,7 @@ def matcher_check_fn(): @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN +<<<<<<< HEAD def test_qlinear_int8_mixed_bf16_input_dim_exceeds_2_and_not_contiguous_use_autocast( self, ): @@ -2643,6 +2794,8 @@ def matcher_check_fn(): @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skipIfNoXPU def test_qlinear_int8_mixed_bf16_input_dim_exceeds_2_and_not_contiguous_xpu(self): r""" @@ -2988,8 +3141,13 @@ def matcher_check_fn(): mod, (v,), [ +<<<<<<< HEAD f"aoti_torch_{device}__qlinear_pointwise_tensor", f"aoti_torch_{device}__qlinear_pointwise_binary_tensor", +======= + "aoti_torch_cpu__qlinear_pointwise_tensor", + "aoti_torch_cpu__qlinear_pointwise_binary_tensor", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ], [], check_quantization=True, @@ -3058,6 +3216,7 @@ def test_qlinear_add_int8_mixed_bf16_xpu(self, use_relu, is_qat, is_dynamic): is_dynamic=is_dynamic, ) +<<<<<<< HEAD def _test_qlinear_fp8_inductor_cpu_helper(self, qlinear_op, post_op="none"): dtype = torch.float8_e4m3fn qlinear_prepack = torch.ops.onednn.qlinear_prepack @@ -3156,6 +3315,8 @@ def test_qlinear_add_fp8_inductor_cpu(self): qlinear_op = torch.ops.onednn.qlinear_pointwise.binary self._test_qlinear_fp8_inductor_cpu_helper(qlinear_op, "add") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _qlinear_dequant_promotion_test_helper( self, inputs, diff --git a/test/inductor/test_mmdecomp.py b/test/inductor/test_mmdecomp.py index 22a5d83324597..bccd45e429849 100644 --- a/test/inductor/test_mmdecomp.py +++ b/test/inductor/test_mmdecomp.py @@ -6,6 +6,7 @@ import torch from torch._inductor import config +<<<<<<< HEAD from torch._inductor.decomposition import mm from torch._subclasses.fake_tensor import FakeTensorMode from torch.fx.experimental.symbolic_shapes import ( @@ -13,6 +14,8 @@ ShapeEnv, StatelessSymbolicContext, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_cuda import SM80OrLater from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_nn import NNTestCase @@ -85,6 +88,7 @@ def torch_baddbmm(add, b, c, alpha, beta): return torch.baddbmm(add, b, c, alpha=alpha, beta=beta) +<<<<<<< HEAD def create_fake_tensor_with_dynamic_size(x, fake_mode): with fake_mode: dynamic_sizes = [DimDynamic.DYNAMIC for _ in range(x.dim())] @@ -98,6 +102,8 @@ def create_fake_tensor_with_dynamic_size(x, fake_mode): ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # The shapes we test on ts_list = [ (1, 32, 32, 1), @@ -172,7 +178,11 @@ def test_bmm_batch2_last_dim_size_is_one(self, device): @parametrize("dtype", [torch.float, torch.bfloat16, torch.int]) def test_some(self, device, dtype): # this Pytorch data type is not fully supported on cuda today +<<<<<<< HEAD # - unfortunately we can't skipIf because we don't see the actual params in skipIf +======= + # - unfortunately we can't skipIf because we don't see the actual parms in skipIf +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if device.startswith(GPU_TYPE) and dtype == torch.int: return @@ -192,7 +202,11 @@ def test_some(self, device, dtype): @parametrize("bs", [1, 2, 4, 10]) def test_some_batched(self, device, dtype, bs): # this Pytorch data type is not fully supported on cuda today +<<<<<<< HEAD # - unfortunately we can't skipIf because we don't see the actual params in skipIf +======= + # - unfortunately we can't skipIf because we don't see the actual parms in skipIf +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if device.startswith(GPU_TYPE) and dtype == torch.int: return @@ -207,6 +221,7 @@ def test_some_batched(self, device, dtype, bs): init_tensor([[[1], [2], [3], [4]]] * bs, dtype=dtype, device=device), ) +<<<<<<< HEAD @parametrize("dtype", [torch.float, torch.bfloat16]) def test_dynamic_shape_mm(self, device, dtype): # Test that the mm decomp does not evaluate expressions for dynamic shapes @@ -272,6 +287,8 @@ def test_dynamic_shape_mm(self, device, dtype): self.assertTrue(r_expr_types[0] == og_t1_expr_types[0]) self.assertTrue(r_expr_types[1] == og_t2_expr_types[1]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device_types = ("cpu", GPU_TYPE) instantiate_device_type_tests(TestDecomp, globals(), only_for=device_types) diff --git a/test/inductor/test_move_constructors_to_cuda.py b/test/inductor/test_move_constructors_to_cuda.py index b174c79f1ebd0..38b84de868d7a 100644 --- a/test/inductor/test_move_constructors_to_cuda.py +++ b/test/inductor/test_move_constructors_to_cuda.py @@ -9,7 +9,11 @@ from torch.testing import FileCheck from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_utils import IS_LINUX +<<<<<<< HEAD from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON +======= +from torch.testing._internal.inductor_utils import HAS_CUDA +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) requires_multigpu = functools.partial( @@ -112,5 +116,9 @@ def foo(x): if __name__ == "__main__": +<<<<<<< HEAD if IS_LINUX and HAS_CUDA_AND_TRITON: +======= + if IS_LINUX and HAS_CUDA: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) run_tests() diff --git a/test/inductor/test_mps_basic.py b/test/inductor/test_mps_basic.py index 529fe0727028b..c65e28ee49958 100644 --- a/test/inductor/test_mps_basic.py +++ b/test/inductor/test_mps_basic.py @@ -6,7 +6,11 @@ import numpy as np import torch +<<<<<<< HEAD from torch.testing import FileCheck, make_tensor +======= +from torch.testing import make_tensor +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_dtype import get_all_dtypes from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, @@ -85,6 +89,68 @@ def foo(x): def test_cast(self, dtype): self.common(lambda a: a.to(dtype), (torch.rand(1024),)) +<<<<<<< HEAD +======= + pointwise_unary_ops = [ + "i0", + "i0e", + "i1", + "i1e", + "erf", + "digamma", + "sinc", + "spherical_bessel_j0", + "bessel_j0", + "bessel_j1", + "bessel_y0", + "bessel_y1", + "modified_bessel_i0", + "modified_bessel_i1", + "modified_bessel_k0", + "modified_bessel_k1", + "scaled_modified_bessel_k0", + "scaled_modified_bessel_k1", + "entr", + ] + + @parametrize("op_name", pointwise_unary_ops) + def test_pointwise_unary_op(self, op_name): + self.common( + lambda x: getattr(torch.special, op_name)(x), + (torch.rand(128, 128),), + check_lowp=False, + ) + + def test_pointwise_polygamma(self): + self.common( + torch.special.polygamma, + ( + 1, + torch.rand(128, 128), + ), + check_lowp=False, + ) + + @parametrize( + "op_name", + [ + "zeta", + "xlog1py", + "chebyshev_polynomial_t", + "chebyshev_polynomial_u", + "chebyshev_polynomial_v", + "chebyshev_polynomial_w", + "hermite_polynomial_he", + ], + ) + def test_pointwise_binary_op(self, op_name): + self.common( + lambda x, y: getattr(torch.special, op_name)(x, y), + (torch.rand(128, 128), torch.rand(128, 128)), + check_lowp=False, + ) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_broadcast(self): self.common(torch.add, (torch.rand(32, 1024), torch.rand(1024))) @@ -130,10 +196,13 @@ def fn(x): self.common(fn, (torch.eye(64),), check_lowp=False) +<<<<<<< HEAD def test_reduced_max(self): # inductor test do not validate that max of say 16K half elements can be computed self.common(torch.max, (torch.rand(16384, dtype=torch.half),), check_lowp=False) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class MPSBasicTestsAOTI(TestCase): def check_model(self, m, inp, dynamic_shapes=None): @@ -228,6 +297,7 @@ def forward(self, a, b): dynamic_shapes = {"a": {0: dim0_a}, "b": {0: dim0_b}} self.check_model(m, inp, dynamic_shapes) +<<<<<<< HEAD def test_reuse_kernel(self): class Model(torch.nn.Module): def __init__(self) -> None: @@ -260,6 +330,8 @@ def forward(self, x, y): exactly=True, ).run(src_code) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/inductor/test_multi_kernel.py b/test/inductor/test_multi_kernel.py index f576016cf08c5..5b2a2fe9f6633 100644 --- a/test/inductor/test_multi_kernel.py +++ b/test/inductor/test_multi_kernel.py @@ -16,6 +16,7 @@ from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, +<<<<<<< HEAD skipIfRocm, skipIfXpu, ) @@ -25,6 +26,11 @@ IS_BIG_GPU, requires_triton, ) +======= + skipIfXpu, +) +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TransformerSnippet(nn.Module): @@ -77,7 +83,10 @@ def fn(self): { "triton.multi_kernel": int(os.environ.get("TORCHINDUCTOR_MULTI_KERNEL", "1")), "benchmark_kernel": True, +<<<<<<< HEAD "multi_kernel_hints": [64, 256, 4096], +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } ) @instantiate_parametrized_tests @@ -98,6 +107,7 @@ def test_softmax(self, expect_multi_kernel=True): else: self.assertFalse(_contains_multi_kernel_code(wrapper_code)) +<<<<<<< HEAD @requires_triton() # TODO: bobrenjc93 to fix multi-kernel for ROCM @skipIfRocm @@ -152,6 +162,8 @@ def fn(x, y): self.assertEqual(ref, act) self.assertTrue(_contains_multi_kernel_code(wrapper_code)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize("force_kernel", (0, 1)) @unittest.mock.patch.dict( os.environ, {"TORCHINDUCTOR_DISABLE_MULTI_KERNEL_CACHE": "1"} @@ -252,8 +264,13 @@ def test_batchnorm_training(self): once for input and once for output. They are ruled out as in-out argument because they are considered as graph inputs. +<<<<<<< HEAD Multi-kernel previously assumes that we never pass the same argument multi times for a kernel. No matter if we change inductor behavior to assure that, it's better +======= + Multi-kernel previously assumes that we never pass the same argument mutli times + for a kernel. No mater if we change inductor behavior to assure that, it's better +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) to make multi-kernel being able to handle those cases. """ bn = nn.BatchNorm2d(3).to(GPU_TYPE) @@ -293,7 +310,11 @@ def f(x, y): def test_reduction_scratch_buffer(self, force_multi_kernel=1): """ +<<<<<<< HEAD The explicitly realized buffer in the test function will be passed in +======= + The explicited realized buffer in the test function will be passed in +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) as a scratch buffer for the non-persistent reduction kernel but can be skipped for the persistent reduction kernel. diff --git a/test/inductor/test_online_softmax.py b/test/inductor/test_online_softmax.py index 808757b7e041f..1ff232d24c3ba 100644 --- a/test/inductor/test_online_softmax.py +++ b/test/inductor/test_online_softmax.py @@ -14,7 +14,11 @@ IS_LINUX, parametrize, ) +<<<<<<< HEAD from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA_AND_TRITON +======= +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1" @@ -293,6 +297,7 @@ def f(x, mask): self.assertTrue(not act.isnan().any()) self.assertTrue(torch.allclose(ref, act)) +<<<<<<< HEAD @inductor_config.patch(split_reductions=False) def test_3d_tiled_online_softmax(self): def f(x, y): @@ -306,9 +311,15 @@ def f(x, y): opt_f = torch.compile(f) torch.testing.assert_close(f(x, y), opt_f(x, y), atol=1e-3, rtol=1e-3) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) instantiate_parametrized_tests(TestOnlineSoftmax) if __name__ == "__main__": +<<<<<<< HEAD if IS_LINUX and HAS_CUDA_AND_TRITON: +======= + if IS_LINUX and HAS_CUDA: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) run_tests() diff --git a/test/inductor/test_op_dtype_prop.py b/test/inductor/test_op_dtype_prop.py index 6f7eec601666b..881c6afcba5fc 100644 --- a/test/inductor/test_op_dtype_prop.py +++ b/test/inductor/test_op_dtype_prop.py @@ -260,7 +260,11 @@ def test_downcast_div_mod(self): def fn(x, y): return x % y, x / y +<<<<<<< HEAD x, y = (torch.rand([8], dtype=torch.float16, device=GPU_TYPE) for _ in range(2)) +======= + x, y = (torch.rand([8], dtype=torch.float16, device="cuda") for _ in range(2)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out, code = run_and_get_code(torch.compile(fn), x, y) @@ -271,7 +275,11 @@ def fn(x, y): @config.patch("test_configs.runtime_triton_dtype_assert", True) def test_constant(self): def fn(): +<<<<<<< HEAD return (torch.full((2, 3), 3.1416, device=GPU_TYPE, dtype=torch.float16),) +======= + return (torch.full((2, 3), 3.1416, device="cuda", dtype=torch.float16),) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out, code = run_and_get_code(torch.compile(fn)) FileCheck().check("static_assert").check_same(".dtype").run(code[0]) @@ -284,7 +292,11 @@ def test_any(self): def fn(x): return torch.any(x) +<<<<<<< HEAD x = torch.rand([40], device=GPU_TYPE).to(torch.bool) +======= + x = torch.rand([40], device="cuda").to(torch.bool) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out, code = run_and_get_code(torch.compile(fn), x) self.assertEqual(fn(x), out) @@ -293,7 +305,11 @@ def fn(x): def test_assoc_scan(self): from torch._higher_order_ops.associative_scan import associative_scan +<<<<<<< HEAD x = torch.randn(10, device=GPU_TYPE) +======= + x = torch.randn(10, device="cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # dtype check correctly associative_scan( lambda acc, curr: acc + torch.abs(curr), x, dim=-1, combine_mode="pointwise" diff --git a/test/inductor/test_ordered_set.py b/test/inductor/test_ordered_set.py index 216b8ab0f0216..9cbe948fd318a 100644 --- a/test/inductor/test_ordered_set.py +++ b/test/inductor/test_ordered_set.py @@ -156,8 +156,13 @@ def f(s1, s2): "Pure python equivalent of isdisjoint()" return not OrderedSet(s1).intersection(s2) +<<<<<<< HEAD for large in "", "a", "ab", "abc", "ababac", "cdc", "cc", "efgfe", "ccb", "ef": s1 = self.thetype(large) +======= + for larg in "", "a", "ab", "abc", "ababac", "cdc", "cc", "efgfe", "ccb", "ef": + s1 = self.thetype(larg) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for rarg in ( "", "a", @@ -235,8 +240,12 @@ def test_symmetric_difference(self): self.assertRaises(TypeError, self.s.symmetric_difference, [[]]) for C in OrderedSet, frozenset, dict.fromkeys, str, list, tuple: self.assertEqual( +<<<<<<< HEAD self.thetype("abcba").symmetric_difference(C("cdc")), OrderedSet("abd"), # codespell:ignore +======= + self.thetype("abcba").symmetric_difference(C("cdc")), OrderedSet("abd") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) self.assertEqual( self.thetype("abcba").symmetric_difference(C("efgfe")), @@ -652,7 +661,11 @@ def test_symmetric_difference_update(self): ) self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]]) for p, q in ( +<<<<<<< HEAD ("cdc", "abd"), # codespell:ignore +======= + ("cdc", "abd"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ("efgfe", "abcefg"), ("ccb", "a"), ("ef", "abcef"), @@ -991,7 +1004,11 @@ def test_changingSizeWhileIterating(self): s = OrderedSet([1, 2, 3]) try: for i in s: +<<<<<<< HEAD s.update([4]) # noqa: B909 +======= + s.update([4]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) except RuntimeError: pass else: diff --git a/test/inductor/test_pad_mm.py b/test/inductor/test_pad_mm.py index 781f4588e1472..100786c200d7c 100644 --- a/test/inductor/test_pad_mm.py +++ b/test/inductor/test_pad_mm.py @@ -16,7 +16,11 @@ from torch._inductor.utils import fresh_cache, is_big_gpu, run_and_get_code from torch.testing import FileCheck from torch.testing._internal.common_utils import skipIfRocm +<<<<<<< HEAD from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON +======= +from torch.testing._internal.inductor_utils import HAS_CUDA +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class PadMMTest(TestCase): @@ -539,6 +543,7 @@ def fn(x, y): # Its name should contain `mm` because `mm` was the original aten op where the mm came from. FileCheck().check("def triton_tem_fused_mm").run(code[0]) +<<<<<<< HEAD def test_no_autocast_in_pad_bmm_joint_graph_pass(self): # Track bmm dtypes before and after joint graph passes bmm_dtypes_pre = {} @@ -651,4 +656,9 @@ def test_masked_mha(B, H, S, D, device, dtype): if __name__ == "__main__": if HAS_CUDA_AND_TRITON: +======= + +if __name__ == "__main__": + if HAS_CUDA: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) run_tests() diff --git a/test/inductor/test_padding.py b/test/inductor/test_padding.py index c67bde87a369b..9604667eb5bd8 100644 --- a/test/inductor/test_padding.py +++ b/test/inductor/test_padding.py @@ -49,6 +49,7 @@ def geninp(): return input_dict +<<<<<<< HEAD def get_padded_stride(shape, alignment_bytes, pad_output, itemsize): align = alignment_bytes // itemsize new_strides = [0 for _ in range(len(shape))] @@ -61,6 +62,8 @@ def get_padded_stride(shape, alignment_bytes, pad_output, itemsize): return tuple(new_strides) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class LinearAndSoftmax(nn.Module): """ It's very common that a transformer model will do a matmul and then @@ -109,10 +112,14 @@ def setUpClass(cls): if HAS_GPU: cls.prior_float32_matmul_precision = torch.get_float32_matmul_precision() cls.prior_default_device = torch.get_default_device() +<<<<<<< HEAD if torch.version.hip: torch.set_float32_matmul_precision("highest") else: torch.set_float32_matmul_precision("high") +======= + torch.set_float32_matmul_precision("high") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.set_default_device(GPU_TYPE) @classmethod @@ -378,7 +385,11 @@ def test_longformer(self, bs=4): @unittest.skipIf(not DO_PERF_TEST or not HAS_TRANSFORMER, "Perf test not enabled") def test_longformer_small_bs(self): """ +<<<<<<< HEAD The model exists in both HF and TB. In TB it uses a smaller batch size. +======= + The model exists in both HF and TB. In TB it uses a samller batch size. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ self.test_longformer(bs=2) @@ -419,7 +430,11 @@ def pad_mm(a, b, align=16): @unittest.skipIf(not DO_PERF_TEST, "Perf test not enabled") def test_padmm(self): """ +<<<<<<< HEAD Latency between original matmul and padded matmul: 2.717 v.s. 2.356 +======= + Latency between origional matmul and padded matmul: 2.717 v.s. 2.356 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ mat1_pad = torch.randn(8192, 30522, dtype=torch.float16) mat2_pad = torch.randn(30522, 768, dtype=torch.float16) @@ -443,7 +458,11 @@ def g(): pad_time = benchmarker.benchmark_gpu(g) print( +<<<<<<< HEAD f"Latency between original matmul and padded matmul: {ori_time:.3f} v.s. {pad_time:.3f}" +======= + f"Latency between origional matmul and padded matmul: {ori_time:.3f} v.s. {pad_time:.3f}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) self.do_profiling(f, g, "No MM Padding", "With mm padding") @@ -496,7 +515,11 @@ def test_LinearAndSoftmax_codegen(self, bias=True): self.assertEqual( m_bad_shape.linear.weight.grad, m_bad_shape_opt.linear.weight.grad ) +<<<<<<< HEAD self.assertTrue(len(wrapper_codes) == 2) # one for forward and one for backward +======= + self.assertTrue(len(wrapper_codes) == 2) # one for forward and oen for backward +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) forward_wrapper = wrapper_codes[0] # make sure the load for softmax is aligned @@ -776,6 +799,7 @@ def get_input(size: tuple[int], alignment_bytes: int) -> torch.Tensor: output_shape = (shape[0] * num_inputs, shape[1]) output_stride = input_tensors[0].stride() output_line = f"buf12 = empty_strided_{GPU_TYPE}({output_shape}, {output_stride}, torch.float32)" +<<<<<<< HEAD self.assertTrue(output_line in code[0]) @parametrize( @@ -907,6 +931,9 @@ def test_dynamic_shape_padding(self, shape, alignment_bytes, enable_pad): result.shape, alignment_bytes, enable_pad, result.dtype.itemsize ) self.assertEqual(result.stride(), expected_stride) +======= + self.assertTrue(any(output_line in line for line in code)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": diff --git a/test/inductor/test_pattern_matcher.py b/test/inductor/test_pattern_matcher.py index bfdc371006472..fef85dac65419 100644 --- a/test/inductor/test_pattern_matcher.py +++ b/test/inductor/test_pattern_matcher.py @@ -1004,7 +1004,11 @@ def fn(a, b, c): ] self.common(fn, args, 0, 0) +<<<<<<< HEAD # cat and split lengths are different +======= + # cat and split lenghts are different +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def fn(a, b, c): cat = torch.ops.aten.cat.default([a, b, c], 1) split_with_sizes = torch.ops.aten.split_with_sizes.default(cat, [5, 5], 1) @@ -1354,6 +1358,7 @@ def repl(inp, x1, x2): # addmm should be replaced FileCheck().check_not("extern_kernels.addmm(").run(code[0]) +<<<<<<< HEAD def test_addmm_dtype_mismatch(self): a = torch.nn.Linear(1024, 1024, bias=False).to(GPU_TYPE) a = a.to(dtype=torch.float16) @@ -1370,6 +1375,8 @@ def func(): self.assertEqual(actual, func()) FileCheck().check_not("addmm").run(code[0]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_replace_mul_zero(self): def test(x, y): return x + (y * 0) @@ -1752,6 +1759,7 @@ def my_func_static(x, w, epsilon): test, (code,) = run_and_get_code(my_func_static, *inputs) self.assertTrue("static_scaled_int8_quant" not in code) +<<<<<<< HEAD def test_fwd_only_generate_original_aten_meta(self): def f(x): return torch.ops.aten.sigmoid(x) @@ -1764,6 +1772,8 @@ def f(x): self.assertEqual(len(sigmoid_nodes), 1) self.assertTrue("original_aten" in sigmoid_nodes[0].meta) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": if IS_LINUX and HAS_GPU: diff --git a/test/inductor/test_perf.py b/test/inductor/test_perf.py index 2dd6d498936fe..58df3406a9278 100644 --- a/test/inductor/test_perf.py +++ b/test/inductor/test_perf.py @@ -28,16 +28,24 @@ # performance for that setting. # # Defines all the kernels for tests +<<<<<<< HEAD from torch.testing._internal.triton_utils import ( HAS_CUDA_AND_TRITON, requires_cuda_and_triton, ) +======= +from torch.testing._internal.triton_utils import HAS_CUDA, requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # set so that metrics appear torch._logging.set_logs(inductor_metrics=True) +<<<<<<< HEAD if HAS_CUDA_AND_TRITON: +======= +if HAS_CUDA: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import triton # @manual import triton.language as tl # @manual @@ -649,7 +657,11 @@ def f(a): @patch.object(config, "pattern_matcher", False) def test_fusion_choice4_cpu(self): +<<<<<<< HEAD # Fuse nodes with same number of elements and compatible original var ranges +======= + # Fuse nodes with same number of elements and compatible orginal var ranges +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # [buf0: {d0: 60, d1: 11}, buf1: {d0: 660}] -> buf0_buf1 def f(x, w): o1 = x * w @@ -923,7 +935,11 @@ def f(a, b): inp = (T(10, 10), TI(2, mx=5)) self.assertExpectedInline(count_numel(f, *inp), """42""") +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_inplace_triton_kernel_training(self): @triton.jit def sin_kernel( @@ -967,7 +983,11 @@ def f(x): x = T(3, grad=True) self.assertExpectedInline(count_numel_train(f, x), """9""") +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_triton_kernel_not_fusable_with_users(self): @triton.jit def _sin_kernel( @@ -1020,7 +1040,11 @@ def f(x): # (it will cost an extra kernel) self.assertExpectedInline(count_numel_train(f, x), """27""") +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_inplace_custom_op_training_two_mutated_inputs(self): @torch.library.custom_op( "_reinplacing::sin_cos", mutates_args={"out_sin", "out_cos"} @@ -1040,7 +1064,11 @@ def f(x): x = T(3, grad=True) self.assertExpectedInline(count_numel(f, x), """21""") +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_inplace_custom_op_training(self): @torch.library.custom_op("_reinplacing::sin", mutates_args={"result"}) def sin(x: torch.Tensor, result: torch.Tensor) -> None: @@ -1069,7 +1097,11 @@ def f(x): x = T(3, grad=True) self.assertExpectedInline(count_numel_train(f, x), """9""") +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_inplace_custom_op(self): with torch.library._scoped_library("mylib", "FRAGMENT") as m: m.define("foo(Tensor x, Tensor(a!) out) -> ()") @@ -1099,7 +1131,11 @@ def f(x, out): self.assertExpectedInline(count_numel(f, x, out), """21""") +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_inplace_custom_op_intermediate(self): with torch.library._scoped_library("mylib", "FRAGMENT") as m: m.define("foo(Tensor x, Tensor(a!) out) -> ()") @@ -1130,7 +1166,11 @@ def f(x, out): self.assertExpectedInline(count_numel(f, x, out), """21""") +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_inplace_custom_op_two_mutated_inputs(self): with torch.library._scoped_library("mylib", "FRAGMENT") as m: m.define("foo(Tensor q, Tensor(a!) k_cache, Tensor(b!) v_cache) -> Tensor") @@ -1156,6 +1196,7 @@ def f(): torch.compile(f, fullgraph=True), ) +<<<<<<< HEAD # Check that we are not allocate intermediate buffers # which can be reused. matches = re.findall(r"empty_strided_\w+\(", code) @@ -1165,6 +1206,15 @@ def f(): self.assertExpectedInline(count_numel(f), """45""") @requires_cuda_and_triton +======= + # Check that we are allocating the minimum number of intermediate buffers + matches = re.findall(r"empty_strided_\w+\(", code) + self.assertEqual(len(matches), 1) + + self.assertExpectedInline(count_numel(f), """39""") + + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_inplace_triton_kernel_v1(self): def f(x: torch.Tensor, y: torch.Tensor): output = torch.zeros_like(x) @@ -1176,7 +1226,11 @@ def f(x: torch.Tensor, y: torch.Tensor): inp = (T(10), T(10)) self.assertExpectedInline(count_numel(f, *inp), """50""") +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_inplace_triton_kernel_v2(self): def f(x: torch.Tensor, y: torch.Tensor): output = torch.zeros_like(x) @@ -1189,7 +1243,11 @@ def f(x: torch.Tensor, y: torch.Tensor): inp = (T(10), T(10)) self.assertExpectedInline(count_numel(f, *inp), """70""") +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_inplace_triton_kernel_v3(self): def f(x: torch.Tensor, y: torch.Tensor): output = torch.zeros_like(x) @@ -1202,7 +1260,11 @@ def f(x: torch.Tensor, y: torch.Tensor): inp = (T(10), T(10)) self.assertExpectedInline(count_numel(f, *inp), """80""") +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_inplace_triton_kernel_v4(self): def f(x: torch.Tensor, y: torch.Tensor): x_view = x.view(-1) @@ -1216,7 +1278,11 @@ def f(x: torch.Tensor, y: torch.Tensor): inp = (T(10), T(10)) self.assertExpectedInline(count_numel(f, *inp), """70""") +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_inplace_triton_kernel_v5(self): def f(x: torch.Tensor, y: torch.Tensor): x_view = x.view(-1) @@ -1230,7 +1296,11 @@ def f(x: torch.Tensor, y: torch.Tensor): inp = (T(10), T(10)) self.assertExpectedInline(count_numel(f, *inp), """80""") +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_inplace_triton_kernel_v6(self): def f(x: torch.Tensor, y: torch.Tensor): output = torch.zeros_like(x) @@ -1297,5 +1367,9 @@ def f(a, b): if __name__ == "__main__": from torch._inductor.test_case import run_tests +<<<<<<< HEAD if HAS_CUDA_AND_TRITON: +======= + if HAS_CUDA: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) run_tests(needs="filelock") diff --git a/test/inductor/test_profiler.py b/test/inductor/test_profiler.py index f22f0374813b0..b6bc560220a2d 100644 --- a/test/inductor/test_profiler.py +++ b/test/inductor/test_profiler.py @@ -12,7 +12,11 @@ from torch._inductor import config from torch.profiler import ProfilerActivity from torch.testing._internal.common_utils import TemporaryFileName +<<<<<<< HEAD from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON, IS_BIG_GPU +======= +from torch.testing._internal.inductor_utils import HAS_CUDA, IS_BIG_GPU +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.torch_version import TorchVersion from torch.utils._triton import has_triton @@ -313,5 +317,9 @@ def fn(x, y): if __name__ == "__main__": from torch._inductor.test_case import run_tests +<<<<<<< HEAD if HAS_CUDA_AND_TRITON: +======= + if HAS_CUDA: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) run_tests() diff --git a/test/inductor/test_provenance_tracing.py b/test/inductor/test_provenance_tracing.py index 7d6b714838ff9..e501686868837 100644 --- a/test/inductor/test_provenance_tracing.py +++ b/test/inductor/test_provenance_tracing.py @@ -1,14 +1,20 @@ # Owner(s): ["module: inductor"] +<<<<<<< HEAD import contextlib import io import json import logging import os +======= +import json +import logging +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import re import shutil import tempfile import unittest +<<<<<<< HEAD import zipfile from pathlib import Path @@ -26,6 +32,16 @@ from torch._inductor.virtualized import V from torch.testing._internal.common_utils import IS_MACOS from torch.testing._internal.triton_utils import requires_cuda_and_triton +======= +from pathlib import Path + +import torch +from torch._inductor import config +from torch._inductor.debug import create_node_mapping +from torch._inductor.test_case import run_tests, TestCase +from torch.testing._internal.inductor_utils import HAS_GPU +from torch.testing._internal.triton_utils import requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: @@ -34,9 +50,12 @@ from test_aot_inductor_utils import AOTIRunnerUtil +<<<<<<< HEAD trace_log = logging.getLogger("torch.__trace") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class Model(torch.nn.Module): def __init__(self): super().__init__() @@ -70,6 +89,7 @@ def forward(self, a): return torch.nn.functional.linear(a, self.weight, self.bias) +<<<<<<< HEAD class Model4(torch.nn.Module): def __init__(self): super().__init__() @@ -89,18 +109,29 @@ def forward(self, x, a, b, c): @config.patch("trace.enabled", True) @config.patch("trace.provenance_tracking_level", 1) +======= +@config.patch("trace.enabled", True) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestProvenanceTracingArtifact(TestCase): """ This test checks that generated provenance tracing artifact from "post_grad" to corresponding "inductor triton kernel node" is expected. """ +<<<<<<< HEAD def _check_provenance_tracing_kernel_to_post_grad(self, filepath, expected_data): self.assertTrue(filepath.is_dir()) filename = Path(filepath) / "inductor_provenance_tracking_node_mappings.json" with open(filename) as f: actual_data = json.load(f) actual_data = actual_data["cppCodeToPost"] +======= + def _check_provenance_tracing_artifact(self, filepath, expected_data): + self.assertTrue(filepath.is_dir()) + filename = Path(filepath) / "inductor_generated_kernel_to_post_grad_nodes.json" + with open(filename) as f: + actual_data = json.load(f) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # check that the generated provenance tracing artifact is expected self.assertEqual(sorted(actual_data.items()), sorted(expected_data.items())) @@ -118,11 +149,18 @@ def _test_triton_kernel_to_post_grad_tracing(self, device): c = torch.randn(10, 30, device=device) example_inputs = (a, b, c) +<<<<<<< HEAD model = Model().to(device) filepath = None for backend in ["aot_inductor", "inductor"]: reset_inductor_kernel_provenance_debug_handle() +======= + model = Model() + filepath = None + + for backend in ["aot_inductor", "inductor"]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: with config.patch( { @@ -145,12 +183,32 @@ def _test_triton_kernel_to_post_grad_tracing(self, device): self.assertTrue(m) filepath = Path(m.group(1)) if device == "cuda": +<<<<<<< HEAD +======= + expected_data = { + "triton_poi_fused_mul_0": ["mul"], + "triton_poi_fused_addmm_gelu_1": [ + "mul_3", + "mul_1", + "add_tensor", + "add", + "erf", + "mul_2", + ], + } + self._check_provenance_tracing_artifact(filepath, expected_data) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) expected_mapping = [ ( "cppCodeToPost", { +<<<<<<< HEAD "triton_poi_fused_mul_0:1": ["mul"], "triton_poi_fused_addmm_gelu_1:2": [ +======= + "triton_poi_fused_mul_0": ["mul"], + "triton_poi_fused_addmm_gelu_1": [ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "mul_3", "mul_1", "add_tensor", @@ -163,6 +221,7 @@ def _test_triton_kernel_to_post_grad_tracing(self, device): ( "postToCppCode", { +<<<<<<< HEAD "mul": ["triton_poi_fused_mul_0:1"], "mul_3": ["triton_poi_fused_addmm_gelu_1:2"], "mul_1": ["triton_poi_fused_addmm_gelu_1:2"], @@ -170,6 +229,15 @@ def _test_triton_kernel_to_post_grad_tracing(self, device): "add": ["triton_poi_fused_addmm_gelu_1:2"], "erf": ["triton_poi_fused_addmm_gelu_1:2"], "mul_2": ["triton_poi_fused_addmm_gelu_1:2"], +======= + "mul": ["triton_poi_fused_mul_0"], + "mul_3": ["triton_poi_fused_addmm_gelu_1"], + "mul_1": ["triton_poi_fused_addmm_gelu_1"], + "add_tensor": ["triton_poi_fused_addmm_gelu_1"], + "add": ["triton_poi_fused_addmm_gelu_1"], + "erf": ["triton_poi_fused_addmm_gelu_1"], + "mul_2": ["triton_poi_fused_addmm_gelu_1"], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, ), ( @@ -194,6 +262,7 @@ def _test_triton_kernel_to_post_grad_tracing(self, device): }, ), ] +<<<<<<< HEAD if backend == "aot_inductor": expected_mapping[0][1]["aoti_torch_cuda_mm_out:3"] = [ "mm_default" @@ -208,6 +277,8 @@ def _test_triton_kernel_to_post_grad_tracing(self, device): expected_mapping[1][1]["mm_default"] = [ "extern_kernels.mm:3" ] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._check_provenance_tracking_node_mappings( filepath, expected_mapping ) @@ -216,9 +287,15 @@ def _test_triton_kernel_to_post_grad_tracing(self, device): # check the inductor kernel to post grad nodes mapping is expected for cpu if backend == "aot_inductor": expected_data = { +<<<<<<< HEAD "cpp_fused_mul_0:1": ["mul"], "aoti_torch_cpu_addmm_out:3": ["addmm"], "cpp_fused_gelu_1:2": [ +======= + "cpp_fused_mul_0": ["mul"], + "aoti_torch_cpu_addmm_out": ["addmm", "mul"], + "cpp_fused_gelu_1": [ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "mul_3", "mul_1", "add", @@ -229,24 +306,37 @@ def _test_triton_kernel_to_post_grad_tracing(self, device): else: # backend == "inductor" expected_data = { +<<<<<<< HEAD "cpp_fused_mul_0:1": ["mul"], "cpp_fused_gelu_1:2": [ +======= + "cpp_fused_mul_0": ["mul"], + "aoti_torch_cpu_addmm_out": ["addmm", "mul"], + "cpp_fused_gelu_1": [ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "mul_3", "mul_1", "add", "erf", "mul_2", ], +<<<<<<< HEAD "extern_kernels.addmm:3": ["addmm"], } self._check_provenance_tracing_kernel_to_post_grad( filepath, expected_data ) +======= + "extern_kernels.addmm": ["addmm", "mul"], + } + self._check_provenance_tracing_artifact(filepath, expected_data) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) finally: if filepath: shutil.rmtree(filepath) +<<<<<<< HEAD @requires_cuda_and_triton def test_triton_kernel_to_post_grad_tracing_cuda(self): self._test_triton_kernel_to_post_grad_tracing(device="cuda") @@ -255,6 +345,17 @@ def test_triton_kernel_to_post_grad_tracing_cpu(self): self._test_triton_kernel_to_post_grad_tracing(device="cpu") @requires_cuda_and_triton +======= + @requires_cuda + def test_triton_kernel_to_post_grad_tracing_cuda(self): + self._test_triton_kernel_to_post_grad_tracing(device="cuda") + + @unittest.skipIf(HAS_GPU, "the test is only for cpu") + def test_triton_kernel_to_post_grad_tracing_cpu(self): + self._test_triton_kernel_to_post_grad_tracing(device="cpu") + + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_triton_kernel_to_post_grad_tracing_extern_kernel(self): M = 8 N = 6 @@ -266,7 +367,10 @@ def test_triton_kernel_to_post_grad_tracing_extern_kernel(self): filepath = None for backend in ["aot_inductor", "inductor"]: +<<<<<<< HEAD reset_inductor_kernel_provenance_debug_handle() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: with config.patch( { @@ -290,22 +394,39 @@ def test_triton_kernel_to_post_grad_tracing_extern_kernel(self): filepath = Path(m.group(1)) if backend == "inductor": expected_data = { +<<<<<<< HEAD "extern_kernels.addmm:1": ["addmm"], +======= + "aoti_torch_cuda_addmm_out": ["addmm", "_tensor_constant1"], + "triton_poi_fused_0": ["_tensor_constant1"], + "extern_kernels.addmm": ["addmm"], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else: # backend = aot_inductor expected_data = { +<<<<<<< HEAD "aoti_torch_cuda_addmm_out:2": ["addmm"], "triton_poi_fused_0:1": ["_tensor_constant1"], } self._check_provenance_tracing_kernel_to_post_grad( filepath, expected_data ) +======= + "aoti_torch_cuda_addmm_out": ["addmm", "_tensor_constant1"], + "triton_poi_fused_0": ["_tensor_constant1"], + } + self._check_provenance_tracing_artifact(filepath, expected_data) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) finally: if filepath: shutil.rmtree(filepath) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _test_pt_tracing_combo_kernel(self, backend): """This test checks that generated provenance tracing artifact from triton combo kernel to post grad nodes""" a = torch.randn(10, 10, device="cuda") @@ -314,7 +435,10 @@ def _test_pt_tracing_combo_kernel(self, backend): example_inputs = (a, b, c) model = Model2() +<<<<<<< HEAD reset_inductor_kernel_provenance_debug_handle() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with config.patch( { @@ -338,10 +462,17 @@ def _test_pt_tracing_combo_kernel(self, backend): m = re.match(r"WARNING.* debug trace: (.*)", cm.output[0]) self.assertTrue(m) filepath = Path(m.group(1)).resolve() +<<<<<<< HEAD expected_data = {"triton_poi_fused_0:1": ["relu", "sigmoid", "tanh"]} self._check_provenance_tracing_kernel_to_post_grad(filepath, expected_data) @requires_cuda_and_triton +======= + expected_data = {"triton_poi_fused_0": ["relu", "sigmoid", "tanh"]} + self._check_provenance_tracing_artifact(filepath, expected_data) + + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_triton_kernel_to_post_grad_tracing_combo_kernel(self): self._test_pt_tracing_combo_kernel(backend="inductor") self._test_pt_tracing_combo_kernel(backend="aot_inductor") @@ -413,6 +544,7 @@ def test_create_node_mapping(self): "triton_poi_fused_addmm_relu_sigmoid_0": ["relu", "add_tensor"] } +<<<<<<< HEAD result = create_mapping_pre_post_grad_nodes( pre_grad_graph_id, post_to_pre_grad_nodes_json, @@ -424,6 +556,13 @@ def test_create_node_mapping(self): ), } +======= + result = create_node_mapping( + pre_grad_graph_id, + post_to_pre_grad_nodes_json, + triton_kernel_to_post_grad_json, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual( result, { @@ -451,6 +590,7 @@ def test_create_node_mapping(self): ) +<<<<<<< HEAD class TestProvenanceTracingNodeMeta(TestCase): def get_node_with_target(self, gm, target): """ @@ -778,5 +918,7 @@ def forward(self, x): self.assertTrue("aoti_torch_cpu_convolution" in keys) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/inductor/test_select_algorithm.py b/test/inductor/test_select_algorithm.py index b30cdc2d946c1..53e948b59d3aa 100644 --- a/test/inductor/test_select_algorithm.py +++ b/test/inductor/test_select_algorithm.py @@ -1,8 +1,12 @@ # Owner(s): ["module: inductor"] +<<<<<<< HEAD import contextlib import functools import unittest.mock from typing import Callable +======= +import functools +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from unittest.mock import patch import torch @@ -12,6 +16,7 @@ import torch.nn.functional as F from torch._dynamo.testing import expectedFailureDynamicWrapper from torch._dynamo.utils import counters +<<<<<<< HEAD from torch._inductor import config from torch._inductor.autotune_process import TritonBenchmarkRequest from torch._inductor.ir import FixedLayout @@ -30,13 +35,24 @@ requires_gpu, requires_triton, ) +======= +from torch._inductor.autotune_process import TritonBenchmarkRequest +from torch._inductor.test_case import run_tests, TestCase +from torch._inductor.utils import is_big_gpu +from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm, skipIfXpu +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) aten = torch.ops.aten def patches(fn): +<<<<<<< HEAD def skip_cache(self, choices, name, key, benchmark, hint_override=None): +======= + def skip_cache(self, choices, name, key, benchmark): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if benchmark is None: return {} return benchmark(choices) @@ -67,8 +83,11 @@ def setUp(self): super().setUp() if not is_big_gpu(): return self.skipTest("Need a big GPU to run max_autotune=True") +<<<<<<< HEAD # Clear preprocessing functions to ensure clean state select_algorithm.clear_preprocessing_fns() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @patches def test_linear_relu(self): @@ -101,6 +120,7 @@ def foo(input, weight, bias): foo(*inps) self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) +<<<<<<< HEAD @patches def test_preprocessing_single_choice(self): # pass a list to the preprocessing function to assert that it was @@ -132,6 +152,8 @@ def foo(input, weight, bias): # The preprocessing function should have been called self.assertTrue(func_called[0]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @patch.object(select_algorithm, "VERIFY", dict(atol=5e-2, rtol=5e-2)) @patches def test_addmm_fp16(self): @@ -162,6 +184,10 @@ def foo(a, b): self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) @patches +<<<<<<< HEAD +======= + @skipIfXpu(msg="XPU has not supported _int_mm yet") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test__int_mm(self): @torch.compile def foo(a, b): @@ -417,6 +443,7 @@ def test_TritonTemplateCaller_str(self): self.assertEqual(caller_str, f"TritonTemplateCaller({module_path}, extra)") +<<<<<<< HEAD @contextlib.contextmanager def patch_lowering(lowering_overrides) -> Callable[[], None]: import torch._inductor.lowering as inductor_lowering @@ -531,6 +558,8 @@ def add(a, b): assert hook_identifier in kernels[0] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": if IS_LINUX and HAS_GPU and is_big_gpu(): run_tests() diff --git a/test/inductor/test_smoke.py b/test/inductor/test_smoke.py index 2a247fddbe76e..1019ef0d9abd8 100644 --- a/test/inductor/test_smoke.py +++ b/test/inductor/test_smoke.py @@ -6,11 +6,15 @@ import torch._logging from torch._inductor.test_case import TestCase from torch.testing._internal.common_utils import IS_LINUX +<<<<<<< HEAD from torch.testing._internal.inductor_utils import ( GPU_TYPE, HAS_CUDA_AND_TRITON, HAS_GPU, ) +======= +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA, HAS_GPU +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class MLP(torch.nn.Module): @@ -66,5 +70,9 @@ def test_compile_invalid_options(self): from torch._inductor.test_case import run_tests if IS_LINUX and HAS_GPU: +<<<<<<< HEAD if (not HAS_CUDA_AND_TRITON) or torch.cuda.get_device_properties(0).major <= 5: +======= + if (not HAS_CUDA) or torch.cuda.get_device_properties(0).major <= 5: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) run_tests() diff --git a/test/inductor/test_snode_runtime.py b/test/inductor/test_snode_runtime.py index c57393d993eab..1601ea6187077 100644 --- a/test/inductor/test_snode_runtime.py +++ b/test/inductor/test_snode_runtime.py @@ -56,7 +56,11 @@ class TestCase(InductorTestCase): """ Helper methods to compare runtime estimate against 0. Since this estimate is hardware dependent, +<<<<<<< HEAD stronger comparisons may fail depending on the host's specs. +======= + stronger comparisons may fail dependending on the host's specs. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) atol/rtol must be provided explicitly with each call, since precision/rel_tol overrides are not always utilized """ diff --git a/test/inductor/test_split_cat_fx_aten_passes.py b/test/inductor/test_split_cat_fx_aten_passes.py index 0ec7825df001c..2b4c48965bbe1 100644 --- a/test/inductor/test_split_cat_fx_aten_passes.py +++ b/test/inductor/test_split_cat_fx_aten_passes.py @@ -5,7 +5,11 @@ from torch._dynamo.utils import counters from torch._inductor.test_case import run_tests, TestCase from torch.testing._internal.inductor_utils import GPU_TYPE +<<<<<<< HEAD from torch.testing._internal.triton_utils import requires_cuda_and_triton +======= +from torch.testing._internal.triton_utils import requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: @@ -49,6 +53,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): return torch.ops.aten.cat.default([cat_1, cat_2], 1) +<<<<<<< HEAD class TestSplitCatSingular(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -65,6 +70,8 @@ def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): return torch.ops.aten.cat.default([cat_1, cat_2], 1) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestSplitCatPartial(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -248,7 +255,11 @@ def compare_gradients(self, module, traced, rtol=1e-3, atol=1e-3): self.compare_dict_tensors(ref_grad, res_grad, rtol=rtol, atol=atol) ) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._inductor.config.patch( pre_grad_fusion_options={}, post_grad_fusion_options={ @@ -291,6 +302,7 @@ def test_split_cat_post_grad(self): self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8) counters.clear() +<<<<<<< HEAD @requires_cuda_and_triton @torch._inductor.config.patch( pre_grad_fusion_options={}, @@ -318,6 +330,9 @@ def test_split_cat_post_grad_singular(self): counters.clear() @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._inductor.config.patch( pre_grad_fusion_options={}, post_grad_fusion_options={ @@ -342,7 +357,11 @@ def test_select_cat_post_grad(self): self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8) counters.clear() +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._inductor.config.patch( pre_grad_fusion_options={}, post_grad_fusion_options={ @@ -367,6 +386,7 @@ def test_move_view_after_cat_aten(self): counters.clear() +<<<<<<< HEAD class TestSplitCatAtenNormalizationPasses(TestCase): @torch._inductor.config.patch( pre_grad_fusion_options={}, @@ -400,5 +420,7 @@ def arg_only_size_different(x): counters.clear() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/inductor/test_split_cat_fx_passes.py b/test/inductor/test_split_cat_fx_passes.py index 4286bdfda7cd9..00e90b0d8deb9 100644 --- a/test/inductor/test_split_cat_fx_passes.py +++ b/test/inductor/test_split_cat_fx_passes.py @@ -115,6 +115,7 @@ def normalize_reshape_with_dynamic_shape(x): ) counters.clear() +<<<<<<< HEAD @torch._inductor.config.patch( pre_grad_fusion_options={ "normalization_pass": {}, @@ -142,6 +143,8 @@ def caoncat_only(x): ) counters.clear() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @patch def test_consecutive_split_merge(self): def multi_split(x): diff --git a/test/inductor/test_static_cuda_launcher.py b/test/inductor/test_static_cuda_launcher.py index 654bfd269f761..aa1f74f87b3ac 100644 --- a/test/inductor/test_static_cuda_launcher.py +++ b/test/inductor/test_static_cuda_launcher.py @@ -13,10 +13,17 @@ from torch._inductor.runtime.triton_helpers import libdevice from torch._inductor.test_case import TestCase from torch.testing._internal.common_utils import skipIfRocm +<<<<<<< HEAD from torch.testing._internal.triton_utils import requires_cuda_and_triton @requires_cuda_and_triton +======= +from torch.testing._internal.triton_utils import requires_cuda + + +@requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestStaticCudaLauncher(TestCase): def setUp(self): super().setUp() @@ -396,7 +403,11 @@ def kernel_many_args(out_tensor, {decl}): self.assertEqual(buf0, buf1) +<<<<<<< HEAD @requires_cuda_and_triton +======= +@requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._inductor.config.patch( {"use_static_cuda_launcher": True, "strict_static_cuda_launcher": True} ) diff --git a/test/inductor/test_subgraph_choice.py b/test/inductor/test_subgraph_choice.py index d2d5a3bf59a9e..35f378c40b55f 100644 --- a/test/inductor/test_subgraph_choice.py +++ b/test/inductor/test_subgraph_choice.py @@ -1,13 +1,27 @@ # Owner(s): ["module: inductor"] +<<<<<<< HEAD +======= +import functools +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import unittest from unittest import mock from unittest.mock import MagicMock import torch +<<<<<<< HEAD +======= +from torch._dispatch.python import enable_python_dispatcher +from torch._inductor.codegen.subgraph import SubgraphTemplate +from torch._inductor.decomposition import select_decomp_table +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.ir import Buffer, FixedLayout, FlexibleLayout from torch._inductor.lowering import register_lowering from torch._inductor.select_algorithm import autotune_select_algorithm from torch._inductor.test_case import run_tests, TestCase +<<<<<<< HEAD +======= +from torch.fx.experimental.proxy_tensor import make_fx +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_utils import skipIfXpu, TEST_WITH_ROCM from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU @@ -59,6 +73,7 @@ def _(a, b): choices = [aten_mm.bind((mat1, mat2), layout)] kPartitions = 256 +<<<<<<< HEAD decompose_k_subgraph_template = ( torch._inductor.kernel.mm.DecomposeKSugraphTemplate() @@ -67,6 +82,22 @@ def _(a, b): decompose_k_subgraph_template.maybe_append_choice( choices, k_split=kPartitions, +======= + with enable_python_dispatcher(): + decompositions = select_decomp_table() + + decompose_k_subgraph_template = SubgraphTemplate( + name="decompose_k_mm", + make_fx_graph=make_fx( + functools.partial(decomposeK, kPartitions=kPartitions), + decompositions, + tracing_mode="real", + ), + ) + + decompose_k_subgraph_template.maybe_append_choice( + choices, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) input_nodes=(mat1, mat2), layout=layout, ) @@ -128,6 +159,7 @@ def _(a, b): choices = [] kPartitions = 2 +<<<<<<< HEAD decompose_k_subgraph_template = ( torch._inductor.kernel.mm.DecomposeKSugraphTemplate() @@ -136,6 +168,21 @@ def _(a, b): decompose_k_subgraph_template.maybe_append_choice( choices, k_split=kPartitions, +======= + with enable_python_dispatcher(): + decompositions = select_decomp_table() + + decompose_k_subgraph_template = SubgraphTemplate( + name="decompose_k_mm", + make_fx_graph=make_fx( + functools.partial(decomposeK, kPartitions=kPartitions), + decompositions, + ), + ) + + decompose_k_subgraph_template.maybe_append_choice( + choices, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) input_nodes=(mat1, mat2), layout=layout, ) diff --git a/test/inductor/test_torchbind.py b/test/inductor/test_torchbind.py index c604f8450bbbf..11c26ef8b15fa 100644 --- a/test/inductor/test_torchbind.py +++ b/test/inductor/test_torchbind.py @@ -1,5 +1,9 @@ # Owner(s): ["module: functorch"] import json +<<<<<<< HEAD +======= +import tempfile +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import zipfile from pathlib import Path @@ -10,10 +14,15 @@ import torch._inductor.decomposition from torch._higher_order_ops.torchbind import CallTorchBind, enable_torchbind_tracing from torch._inductor import aot_compile, ir +<<<<<<< HEAD from torch._inductor.codecache import WritableTempFile from torch._inductor.package import package_aoti from torch._inductor.test_case import run_tests, TestCase from torch.testing._internal.common_utils import skipIfWindows +======= +from torch._inductor.package import package_aoti +from torch._inductor.test_case import run_tests, TestCase +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.inductor_utils import GPU_TYPE, requires_gpu from torch.testing._internal.torchbind_impls import ( _empty_tensor_queue, @@ -159,7 +168,10 @@ def test_torchbind_hop_schema_no_output(self): "call_torchbind(__torch__.torch.classes._TorchScriptTesting._TensorQueue _0, str method, Tensor _1) -> NoneType _0", ) +<<<<<<< HEAD @skipIfWindows(msg="AOTI is not fully support on Windows") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_torchbind_aot_compile(self): ep, inputs, _, _ = self.get_exported_model() aoti_files = aot_compile( @@ -174,7 +186,11 @@ def test_torchbind_aot_compile(self): custom_objs_config = file elif file.endswith("/custom_obj_0"): custom_obj_0 = file +<<<<<<< HEAD elif file.endswith("wrapper.json") and "metadata" not in file: +======= + elif file.endswith(".json") and "metadata" not in file: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) extern_json = file self.assertIsNotNone(custom_objs_config) @@ -282,7 +298,11 @@ def test_torchbind_aot_compile(self): ) # Test that the files are packaged +<<<<<<< HEAD with WritableTempFile(suffix=".pt2") as f: +======= + with tempfile.NamedTemporaryFile(suffix=".pt2") as f: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) package_path = package_aoti(f.name, aoti_files) with zipfile.ZipFile(package_path, "r") as zip_ref: @@ -304,7 +324,10 @@ def test_torchbind_aoti(self): self.assertEqual(result, orig_res) @torch._inductor.config.patch("aot_inductor.use_runtime_constant_folding", True) +<<<<<<< HEAD @skipIfWindows(msg="AOTI is not fully support on Windows") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_torchbind_aot_compile_constant_folding(self): ep, inputs, orig_res, _ = self.get_exported_model() pt2_path = torch._inductor.aoti_compile_and_package(ep) @@ -413,6 +436,7 @@ def forward(self, x, y): ): aot_compile(ep.module(), inputs, options={"aot_inductor.package": True}) +<<<<<<< HEAD def test_aoti_torchbind_name_collision(self): class M(torch.nn.Module): def __init__(self) -> None: @@ -437,6 +461,8 @@ def forward(self, x): result = optimized(*inputs) self.assertEqual(result, orig_res) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index b5d880cd90f4f..5219303032d20 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -69,7 +69,10 @@ from torch.nn import functional as F from torch.testing import FileCheck, make_tensor from torch.testing._internal.common_cuda import ( +<<<<<<< HEAD IS_SM90, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, SM80OrLater, @@ -96,7 +99,10 @@ MACOS_VERSION, parametrize, serialTest, +<<<<<<< HEAD skipIfMPS, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) skipIfRocm, skipIfWindows, skipIfXpu, @@ -139,7 +145,11 @@ skipCPUIf, skipCUDAIf, ) +<<<<<<< HEAD from torch.testing._internal.triton_utils import requires_cuda_and_triton +======= +from torch.testing._internal.triton_utils import requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _T = TypeVar("_T") @@ -193,7 +203,11 @@ def _large_cumprod_input(shape, dim, dtype, device): +<<<<<<< HEAD # Construct a cumprod input which guarantees not to overflow or underflow +======= + # Construct a cumprod input which guaruntees not to overflow or underflow +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if is_integer_dtype(dtype): # Large products don't fit in integers, the best we can do # is random +/-1 values to test the sign of the result @@ -434,8 +448,11 @@ def check_model( check_gradient=False, check_has_compiled=True, output_process_fn_grad=lambda x: x, +<<<<<<< HEAD # TODO: enable this for all tests exact_stride=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): kwargs = kwargs or {} torch._dynamo.reset() @@ -469,12 +486,16 @@ def upcast_fn(x): x.dtype == torch.float16 or x.dtype == torch.bfloat16 ): has_lowp_args = True +<<<<<<< HEAD # Preserve strides when casting result = torch.empty_strided( x.size(), x.stride(), device=x.device, dtype=torch.float ) result.copy_(x) return result +======= + return x.float() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: return x @@ -545,6 +566,7 @@ def reference_to_expect(actual_flat, correct_flat): correct_flat = reference_to_expect(actual_flat, correct_flat) correct = tree_unflatten(correct_flat, correct_spec) +<<<<<<< HEAD # Allow assert_equal to be a custom function, instead of True or False, for # cases where differences may not indicate incorrectness. if assert_equal: @@ -558,16 +580,25 @@ def custom_assert_with_self(*args, **kwargs): assert_equal_fn = self.assertEqual assert_equal_fn( +======= + if assert_equal: + self.assertEqual( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) actual, correct, atol=atol, rtol=rtol, equal_nan=True, exact_dtype=exact_dtype, +<<<<<<< HEAD exact_stride=exact_stride, ) # In case of input mutations, check that inputs are the same # (This never uses a custom assert_equal fn.) +======= + ) + # In case of input mutations, check that inputs are the same +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual( ref_inputs, example_inputs, @@ -576,7 +607,10 @@ def custom_assert_with_self(*args, **kwargs): equal_nan=True, # our testing sometimes uses higher precision inputs for the reference exact_dtype=False, +<<<<<<< HEAD exact_stride=exact_stride, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) else: for correct_val, actual_val in zip(correct_flat, actual_flat): @@ -590,8 +624,11 @@ def custom_assert_with_self(*args, **kwargs): assert correct_val.layout == actual_val.layout if exact_dtype: assert correct_val.dtype == actual_val.dtype +<<<<<<< HEAD if exact_stride: assert correct_val.stride() == actual_val.stride() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if check_gradient: actual = output_process_fn_grad(actual) @@ -645,7 +682,10 @@ def custom_assert_with_self(*args, **kwargs): rtol=grad_rtol or rtol, equal_nan=True, exact_dtype=exact_dtype, +<<<<<<< HEAD exact_stride=exact_stride, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) torch._dynamo.reset() @@ -671,8 +711,11 @@ def check_model_gpu( check_gradient=False, check_has_compiled=True, output_process_fn_grad=lambda x: x, +<<<<<<< HEAD # TODO: enable this for all tests exact_stride=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): kwargs = kwargs or {} if hasattr(model, "to"): @@ -699,7 +742,10 @@ def check_model_gpu( check_gradient=check_gradient, check_has_compiled=check_has_compiled, output_process_fn_grad=output_process_fn_grad, +<<<<<<< HEAD exact_stride=exact_stride, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if check_lowp: @@ -732,7 +778,10 @@ def downcast_fn(x): check_gradient=check_gradient, check_has_compiled=check_has_compiled, output_process_fn_grad=output_process_fn_grad, +<<<<<<< HEAD exact_stride=exact_stride, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @@ -1409,6 +1458,13 @@ def fn(*args): b = torch.add(args[0], args[0]) return (a, b) +<<<<<<< HEAD +======= + # Complex are not supported on MacOS-13 + if self.device == "mps" and MACOS_VERSION < 14.0: + raise unittest.SkipTest("No complex on MacOS13") + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x = torch.randn(41, dtype=torch.complex64, device=self.device) y = x.clone() # should not inplace write to the input @@ -1468,6 +1524,7 @@ def fn(a, b, alpha): self.common(fn, (x, y, 2)) +<<<<<<< HEAD def test_add_complex7(self): # Fix https://github.com/pytorch/pytorch/issues/160495 # Test scalar (0-dimensional) complex tensor addition: 0D + 0D @@ -1511,6 +1568,8 @@ def fn(a, b, alpha): y = torch.rand((), dtype=torch.complex64, device=self.device) self.common(fn, (x, y, 2)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_concat_add_inplace(self): def fn(x, y, z): return torch.cat([x, y], dim=1).add_(z) @@ -1537,6 +1596,13 @@ def fn(a, b, c): ) real_input = torch.tensor([-1.0, 0.0, 1.0, float("nan")]) interger_real_input = torch.tensor([-1, 0, 1]) +<<<<<<< HEAD +======= + # Complex are not supported on MacOS-13 + if self.device == "mps" and MACOS_VERSION < 14.0: + self.common(fn, (complex_input.real, real_input, interger_real_input)) + return +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.common(fn, (complex_input, real_input, interger_real_input)) def test_sgn(self): @@ -1689,6 +1755,12 @@ def copy(x): i = torch.arange(x.size(0), device=x.device) return x[i] +<<<<<<< HEAD +======= + if self.device == "mps" and MACOS_VERSION < 13.3: + raise unittest.SkipTest("Inaccurate on MacOS-13") + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x = torch.randn(8, device=self.device) copy_opt = torch.compile(copy, backend="inductor") @@ -2109,6 +2181,11 @@ def fn(a): return torch.max(a), torch.sum(a) # Requires masked loading for the intermediate reduction +<<<<<<< HEAD +======= + if self.device == "mps" and MACOS_VERSION < 13.3: + raise unittest.SkipTest("Fails with internal compiler error on MacOS-13") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sample = torch.full((3999971,), 0, dtype=torch.int64) sample[-1] = 1 self.common(fn, (sample,)) @@ -2166,6 +2243,16 @@ def fn(a): ): if not self.is_dtype_supported(dtype): continue +<<<<<<< HEAD +======= + # cumsum not implemented for integers on MacOS-13 + if ( + self.device == "mps" + and not dtype.is_floating_point + and MACOS_VERSION < 13.3 + ): + continue +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Use low=0 since when the mean value is 0, cumsum at all points # tends towards zero which makes the relative error term blow up inp = make_tensor(10, 3, 352, 352, low=0, dtype=dtype, device=self.device) @@ -2218,6 +2305,12 @@ def fn(lengths, data): offsets = torch.cumsum(lengths, 0) return data[offsets] +<<<<<<< HEAD +======= + if self.device == "mps" and MACOS_VERSION < 13.3: + raise unittest.SkipTest("CumSum for int64 needs MacOS-13.3+") + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) lengths = torch.full((2**14,), 2**2, dtype=torch.int64, device=self.device) lengths[-2] = 3 lengths[-1] = 3 @@ -2231,6 +2324,16 @@ def fn(a): for dtype in [torch.float32, torch.float64, torch.int32, torch.int64]: if not self.is_dtype_supported(dtype): continue +<<<<<<< HEAD +======= + # cumsum not implemented on MacOS-13 + if ( + self.device == "mps" + and not dtype.is_floating_point + and MACOS_VERSION < 13.3 + ): + continue +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inp = _large_cumprod_input( (10, 10000), dim=1, dtype=dtype, device=self.device ) @@ -2585,6 +2688,13 @@ def test_sum_int(self): def fn(x): return 2 * x.sum(-1) + x.sum() +<<<<<<< HEAD +======= + # Requires masked loading for the intermediate reduction + if self.device == "mps" and MACOS_VERSION < 13.3: + raise unittest.SkipTest("Fails with internal compiler error on MacOS-13") + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dtypes = torch.bool, torch.uint8, torch.int inps = [torch.randint(2, (64,), dtype=dtype) for dtype in dtypes] @@ -2592,6 +2702,12 @@ def fn(x): self.common(fn, (i,), check_lowp=False) def test_sum_dtype(self): +<<<<<<< HEAD +======= + if self.device == "mps" and MACOS_VERSION < 14.0: + raise unittest.SkipTest("bfloat unsupported on MacOS-13") + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sum_dtype = torch.double if self.device != "mps" else torch.bfloat16 def fn(x): @@ -2686,6 +2802,10 @@ def make_tensor(shape): inp = torch.full((2, n), float("inf"), device=self.device, dtype=_dtype) self.assertEqual(cfn(inp), fn(inp)) +<<<<<<< HEAD +======= + @xfail_if_mps_unimplemented +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @xfail_if_triton_cpu def test_logcumsumexp(self): def fn(x): @@ -2712,6 +2832,10 @@ def fn(x): rtol=1e-5, ) +<<<<<<< HEAD +======= + @xfail_if_mps_unimplemented +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_logcumsumexp_zero_dim(self): def fn(x): return x.logcumsumexp(0), x.logcumsumexp(-1) @@ -2991,6 +3115,7 @@ def forward(x, y): ), ) +<<<<<<< HEAD def test_torch_device_split(self): def fn(x): return x.split(2) @@ -3003,6 +3128,8 @@ def fn(x): for a, b in zip(out, ref): self.assertTrue(torch.allclose(a, b)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_relu(self): def fn(a, b): return (torch.relu(a), torch.relu(a + b) / 10) @@ -3024,6 +3151,11 @@ def fn(a, b): @skipIfXpu(msg="logaddexp_xpu not implemented for ComplexFloat") @skipCUDAIf(True, "Not implemented for CUDA") def test_logaddexp(self): +<<<<<<< HEAD +======= + if self.device == "mps" and MACOS_VERSION < 14.0: + raise unittest.SkipTest("Complex needs MacOS-14+") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.common( torch.logaddexp, ( @@ -3221,6 +3353,12 @@ def fn(a, b): a // b, ) +<<<<<<< HEAD +======= + if self.device == "mps" and MACOS_VERSION < 13.3: + raise unittest.SkipTest("Inaccurate for MPS no MacOS-13") + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.common( fn, (torch.randint(-100, 0, [8, 8]), torch.randint(1, 10, [8, 8])), @@ -3440,6 +3578,7 @@ def forward(x, y): cf(x, 1e-5) cf(x, 1e-6) +<<<<<<< HEAD def test_div_presicion_accuracy(self): # fix https://github.com/pytorch/pytorch/issues/157959 def forward(x, y): @@ -3449,6 +3588,8 @@ def forward(x, y): y = 101 self.common(forward, (x, y)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_mul_softmax_symfloat(self): def forward(x, y): z = x.mul(y * x.shape[-1]) @@ -3854,6 +3995,11 @@ def fn(a, b): torch.randn(256, 256), torch.randint(-128, 127, (256, 256), dtype=torch.int8), ), +<<<<<<< HEAD +======= + # MacOS-13 MM ops have precision issues + check_lowp=self.device != "mps" or MACOS_VERSION > 14.0, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) rtol=0.01, atol=0.1, ) @@ -4335,7 +4481,11 @@ def fn(a): (torch.randn([2, 20, 2]),), ) +<<<<<<< HEAD # It's a view so it doesn't generate a kernel +======= + # It's a view so it doens't generate a kernel +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @expectedFailureCodegenDynamic def test_slice3(self): def fn(a, b): @@ -4428,7 +4578,13 @@ def fn2(a): ) @parametrize("dilation", (1, 2)) +<<<<<<< HEAD @parametrize("dim", (subtest(2), subtest(3))) +======= + @parametrize( + "dim", (subtest(2), subtest(3, decorators=[xfail_if_mps_unimplemented])) + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_low_memory_max_pool(self, dilation: int, dim: int): prims = torch.ops.prims @@ -4459,6 +4615,12 @@ def fn(x): self.common(fn, (torch.randn(1, 3, *[10] * dim),)) def test_to_dtype(self): +<<<<<<< HEAD +======= + if self.device == "mps" and MACOS_VERSION < 14.0: + raise unittest.SkipTest("bfloat unsupported on MacOS-13") + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) new_dtype = torch.float64 if self.device != "mps" else torch.bfloat16 def fn(a, b): @@ -4705,7 +4867,11 @@ def test_conv3d(self): self.common( m, (torch.randn([1, 3, 8, 16, 32]),), +<<<<<<< HEAD atol=1e-3, +======= + atol=6e-5, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) rtol=0.001, # Make sure we compute also with fp16 in the reference. Otherwise, # the reference will compute with fp32 and cast back to fp16, which @@ -5400,6 +5566,13 @@ def test_tan(self): def fn(x): return aten.tan(x) + 2, aten.tan(x + 1) +<<<<<<< HEAD +======= + # tan is broken in MPSGraph for MacOS before version 13.3 + if self.device == "mps" and MACOS_VERSION < 13.3: + raise unittest.SkipTest("tan is inaccurate for MPS no MacOS-13") + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.common( fn, (torch.randn([16, 16]),), @@ -5773,6 +5946,7 @@ def forward(self, x): if self.device != "cpu": assertGeneratedKernelCountEqual(self, 1) +<<<<<<< HEAD def test_complex_from_real_imag(self): def fn(x, y): return aten.complex.default(x, y) @@ -5786,6 +5960,8 @@ def fn(x, y): reference_in_float=False, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_view_as_complex(self): class Repro(torch.nn.Module): def __init__(self) -> None: @@ -6019,6 +6195,13 @@ def test_pow1(self): def fn(x): return [aten.pow(x, e) for e in range(-8, 9)] +<<<<<<< HEAD +======= + # pow is broken in MPSGraph for MacOS before version 13.3 + if self.device == "mps" and MACOS_VERSION < 13.3: + raise unittest.SkipTest("pow is inaccurate for MPS no MacOS-13") + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.common( fn, (torch.randn([16, 16]),), @@ -6029,6 +6212,13 @@ def test_pow2(self): def fn(x): return aten.pow(1000, x), aten.pow(x, 1000) +<<<<<<< HEAD +======= + # pow is broken in MPSGraph for MacOS before version 13.3 + if self.device == "mps" and MACOS_VERSION < 13.3: + raise unittest.SkipTest("pow is inaccurate for MPS no MacOS-13") + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.common( fn, ( @@ -6094,6 +6284,7 @@ def fn(x): (torch.randn([8, 16, 8, 8]),), ) +<<<<<<< HEAD def test_unsigned_constant_tensors(self): def fn(x): c = torch.tensor(7, dtype=torch.uint8) @@ -6104,6 +6295,8 @@ def fn(x): (torch.randn([16, 16]),), ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Disable size_asserts for this test due to https://github.com/pytorch/pytorch/issues/145963 @config.patch(size_asserts=os.environ.get("TORCHINDUCTOR_SIZE_ASSERTS") == "1") @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True) @@ -7016,6 +7209,7 @@ def fn(a): self.common(fn, (torch.randn(8),)) +<<<<<<< HEAD def test_full_like_transposed(self): def fn(a): return torch.full_like(a, 3) @@ -7028,6 +7222,8 @@ def fn(a): self.common(fn, (torch.rand(3, 4)[:, ::2],), exact_stride=True) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_full_truncation(self): def fn(a): return a + torch.full_like(a, 7.777) @@ -7414,6 +7610,7 @@ def fn(a, descending): self.common(fn, (inp, False)) self.common(fn, (inp, True)) +<<<<<<< HEAD @parametrize("stable", (True, False)) @parametrize("descending", (True, False)) def test_nan_sort(self, descending, stable): @@ -7445,6 +7642,8 @@ def test_sort(x, descending, stable): b = test_sort(*inps) self.assertEqual(a, b, equal_nan=True) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_sort_stable(self): def fn(a, descending): return a.sort(dim=-1, stable=True, descending=descending) @@ -7536,6 +7735,7 @@ def fn(a): fn, (torch.randint(0, 999, size=[1, 1, 8, 8], dtype=torch.float32),) ) +<<<<<<< HEAD def test_constant_pad_2d_strides_nonpositive(self): def fn(a): return torch.constant_pad_nd(a, [0, 0, 0, -2, 0, 0]) @@ -7544,6 +7744,8 @@ def fn(a): fn, (torch.empty_strided((2, 4, 5), (20, 1, 4), dtype=torch.float32),) ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skip_if_gpu_halide # misaligned address def test_constant_pad_3d(self): def fn(a): @@ -8038,6 +8240,11 @@ def fn(a, b, c, beta): # Greatest relative difference: 1.0 at index (3, 19, 4) (up to 0.001 allowed) atol=0.002, rtol=0.001, +<<<<<<< HEAD +======= + # MacOS-13 MM ops have precision issues + check_lowp=self.device != "mps" or MACOS_VERSION > 14.0, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @config.patch({"triton.max_tiles": 2}) @@ -8281,6 +8488,13 @@ def fn(x, y): ) return torch.ops.aten.index.Tensor(y, [iota, sub]) +<<<<<<< HEAD +======= + # Requires masked loading for the intermediate reduction + if self.device == "mps" and MACOS_VERSION < 13.3: + raise unittest.SkipTest("Fails with internal compiler error on MacOS-13") + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.common(fn, [torch.randn(1, 1024), torch.randn(1, 1024, 2)]) @config.patch(fallback_random=True) @@ -8503,6 +8717,7 @@ def forward(self, x, start_pos): self.common(kv_cache_module, (inp, 1), check_lowp=False) assertGeneratedKernelCountEqual(self, 1) +<<<<<<< HEAD @skipIfMPS def test_slice_scatter_dtype_consistency(self): # Test dtype consistency of slice_scatter @@ -8521,6 +8736,8 @@ def fn(x, y): ], ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skip_if_gpu_halide # compile error on gpu def test_scatter1(self): def fn(a, dim, index, b): @@ -8610,6 +8827,12 @@ def fn(a, dim, index, b, reduce): a1.scatter_(dim, index, b, reduce=reduce) return (a, a1) +<<<<<<< HEAD +======= + if self.device == "mps" and MACOS_VERSION < 14.0: + raise unittest.SkipTest("Crashes on MacOS-13") + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) check_lowp = True if self.device == "xpu": check_lowp = False @@ -8759,6 +8982,12 @@ def fn(a, dim, index, b, reduce): a1.scatter_reduce_(dim, index, b, reduce=reduce) return (a, a1) +<<<<<<< HEAD +======= + if self.device == "mps" and MACOS_VERSION < 14.0: + raise unittest.SkipTest("Crashes on MacOS-13") + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) check_lowp = True if self.device == "xpu": check_lowp = False @@ -8966,7 +9195,11 @@ def test_fallback_mutable_op_basic(self): with torch.library._scoped_library("mylib", "FRAGMENT") as m: def impl(a, b, c, d, e=2): +<<<<<<< HEAD a.add_(b[0] * c * e) +======= + (a.add_(b[0] * c * e),) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if d is not None: d.add_(b[1]) @@ -9039,7 +9272,11 @@ def test_fallback_mutable_op_with_return(self): with torch.library._scoped_library("mylib", "FRAGMENT") as m: def impl(a, b, c, d, e=2): +<<<<<<< HEAD a.add_(b[0] * c * e) +======= + (a.add_(b[0] * c * e),) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if d is not None: d.add_(b[1]) return b[0] + b[1] @@ -9292,7 +9529,11 @@ def forward(self, v1: torch.Tensor): model = Model() x = torch.rand(10, 3, 0) +<<<<<<< HEAD self.common(model, (x,), exact_stride=True) +======= + self.common(model, (x,)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_randint(self): @torch.compile(fullgraph=True) @@ -9348,6 +9589,7 @@ def bin(index, max_size): @xfail_if_mps # 100% are not close def test_like_rands(self): def fn(x): +<<<<<<< HEAD return torch.rand_like(x), torch.randn_like(x), torch.randint_like(x, 1, 11) self.common(fn, [torch.zeros([20, 20])], exact_stride=True) @@ -9363,6 +9605,11 @@ def fn(x): ) self.common(fn, (torch.zeros([3, 4])[:, ::2].permute(1, 0),), exact_stride=True) +======= + return torch.rand_like(x), torch.randn_like(x) + + self.common(fn, [torch.zeros([20, 20])]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @config.patch(check_stack_no_cycles_TESTING_ONLY=True) def test_check_stack_no_cycles(self): @@ -9395,8 +9642,11 @@ def fn(x): a0 = fn(x).clone() a1 = fn(x).clone() self.assertFalse(torch.allclose(a0, a1)) +<<<<<<< HEAD self.assertEqual(a0.shape, a1.shape) self.assertEqual(a0.stride(), a1.stride()) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @requires_gpu() @skip_if_triton_cpu("Flaky on Triton CPU") @@ -9414,8 +9664,11 @@ def fn(x, device): a1 = test_like_rands_on_different_device(GPU_TYPE, "cpu") self.assertTrue(a0.device.type == GPU_TYPE) self.assertTrue(a1.device.type == "cpu") +<<<<<<< HEAD self.assertEqual(a0.shape, a1.shape) self.assertEqual(a0.stride(), a1.stride()) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_max_pool2d_with_indices_backward(self): def fn(a, b, c): @@ -9675,6 +9928,10 @@ def fn(a, b): ) assertGeneratedKernelCountEqual(self, 0) +<<<<<<< HEAD +======= + @xfail_if_mps_unimplemented +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_avg_pool3d_backward(self): def fn(a, b): return aten.avg_pool3d_backward( @@ -9696,6 +9953,10 @@ def fn(a, b): ], ) +<<<<<<< HEAD +======= + @xfail_if_mps_unimplemented +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skip_if_halide # compiles for 5+ minutes def test_avg_pool3d_backward2(self): def fn(a, b): @@ -9718,6 +9979,10 @@ def fn(a, b): ], ) +<<<<<<< HEAD +======= + @xfail_if_mps_unimplemented +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_avg_pool3d_backward3(self): def fn(a, b): return aten.avg_pool3d_backward( @@ -9741,6 +10006,10 @@ def fn(a, b): ) assertGeneratedKernelCountEqual(self, 1) +<<<<<<< HEAD +======= + @xfail_if_mps_unimplemented +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_avg_pool3d_backward4(self): def fn(a, b): return aten.avg_pool3d_backward( @@ -9942,7 +10211,10 @@ def fn(x): ], ) +<<<<<<< HEAD @skipIfXpu(msg="Incorrect XPU reference") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_argmax_argmin2(self): def fn(x): return ( @@ -9954,7 +10226,10 @@ def fn(x): self.common(fn, (torch.randn([144, 144]),)) +<<<<<<< HEAD @skipIfXpu(msg="Incorrect XPU reference") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_argmax_argmin_with_duplicates(self): def fn(x): return ( @@ -9976,7 +10251,10 @@ def fn(x): t1 = torch.randint(8, size=(1028, 1028)) self.common(fn, (t1,)) +<<<<<<< HEAD @skipIfXpu(msg="# Incorrect XPU reference ") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @xfail_if_mps # eager nan is wrong, see https://github.com/pytorch/pytorch/issues/130295 @skip_if_halide # nan behavior def test_argmax_argmin_with_nan(self): @@ -10077,7 +10355,10 @@ def shrink_rank(x, rank): [rank4_inps, rank3_inps, rank5_inps], ) +<<<<<<< HEAD @skipIfXpu(msg="Incorrect XPU reference") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_argmax_argmin3(self): def fn(x): return ( @@ -10360,19 +10641,33 @@ def test_zero_dim_reductions(self): for kd in [True, False]: inps0 = (torch.zeros(2, 0, device=self.device, dtype=torch.float16), 1, kd) failed_ops = [aten.argmin, aten.argmax, aten.max, aten.min] +<<<<<<< HEAD for op in failed_ops: with self.assertRaisesRegex( IndexError, "Expected reduction dim 1 to have non-zero size" ): mod = make_fx(op)(*inps0) +======= + for fo in failed_ops: + with self.assertRaisesRegex( + IndexError, "Expected reduction dim 1 to have non-zero size" + ): + mod = make_fx(fo)(*inps0) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _ = compile_fx_inner(mod, inps0) pass_ops = [ lambda *x: fn(*x) for fn in [aten.sum, aten.prod, aten.any, aten.all] ] +<<<<<<< HEAD for op in pass_ops: compiled = torch.compile(op, backend="inductor") expected = op(*inps0) +======= + for po in pass_ops: + compiled = torch.compile(po, backend="inductor") + expected = po(*inps0) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) actual = compiled(*inps0) self.assertTrue(torch.allclose(actual, expected, atol=1e-3, rtol=1e-3)) @@ -10513,7 +10808,11 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): with TestRefMode(): fn_compiled(inps) +<<<<<<< HEAD # for some reason, TorchDispatch doesn't capture the +======= + # for some reason, TorchDispatch doesnt capture the +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # cuda mm call (even without cudagraphs) if self.device == "cpu": self.assertTrue(matmul_seen) @@ -10635,6 +10934,11 @@ def f(x): not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)" ) def test_inductor_multiple_specializations(self): +<<<<<<< HEAD +======= + from triton.testing import do_bench + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch.compile( options={ "max_autotune": True, @@ -10649,7 +10953,11 @@ def inductor_matmul(a, b): m = 16 k = 1280 dynamic_a = torch.randn(m, k, device=GPU_TYPE, dtype=torch.bfloat16) +<<<<<<< HEAD dynamic_specialized_a = dynamic_a.clone() +======= + dynamic_specialized_a = torch.randn(m, k, device=GPU_TYPE, dtype=torch.bfloat16) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) b = torch.randn(k, m, device=GPU_TYPE, dtype=torch.bfloat16) torch._dynamo.decorators.mark_dynamic( dynamic_a, @@ -10664,6 +10972,7 @@ def inductor_matmul(a, b): b, 1, ) +<<<<<<< HEAD dynamic = inductor_matmul(dynamic_a, b) torch._dynamo.reset() dynamic_specialized = inductor_matmul(dynamic_specialized_a, b) @@ -10695,6 +11004,14 @@ def override(x): self.assertNotEqual(code1, code2) self.assertEqual(no_override(x_small), override(x_small)) +======= + dynamic = do_bench(lambda: inductor_matmul(dynamic_a, b)) + torch._dynamo.reset() + dynamic_specialized = do_bench( + lambda: inductor_matmul(dynamic_specialized_a, b) + ) + self.assertGreaterEqual(dynamic, dynamic_specialized) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @requires_gpu() def test_stride_preservation_with_stride_modifying_fx_pass(self): @@ -10703,7 +11020,11 @@ def f(x): def custom_pass(g: torch.fx.Graph) -> None: """ +<<<<<<< HEAD Applies `lambda x: x.t().contiguous().t()` to the output. +======= + Applies `lamda x: x.t().contiguous().t()` to the output. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ output_node = g.find_nodes(op="output")[0] assert len(output_node.args) == 1 @@ -10752,6 +11073,7 @@ def fn(x): self.common(fn, [torch.randn(1, 8, 396 * 300)]) +<<<<<<< HEAD @torch._dynamo.config.patch("capture_scalar_outputs", True) def test_pattern_matcher_unbacked(self): @torch.compile(fullgraph=True) @@ -10768,6 +11090,8 @@ def get_mask(W: torch.Tensor, percentage_nonzeros: torch.Tensor): p = torch.tensor(0.50, device=self.device) get_mask(x, p) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_sqrt_dynamic_shapes(self): # TIMM convit_base model: https://github.com/pytorch/pytorch/issues/97877. # TODO: support cuda path. @@ -11521,6 +11845,12 @@ def fn(input_ids) -> torch.Tensor: attention_mask = attention_mask.long() return torch.cumsum(attention_mask, dim=1) +<<<<<<< HEAD +======= + if self.device == "mps" and MACOS_VERSION < 13.3: + raise unittest.SkipTest("CumSum for int64 needs MacOS-13.3+") + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x = torch.randn(2, 2) self.common(fn, (x,), atol=0, rtol=0) @@ -11604,7 +11934,11 @@ def fn(x, size, memory_format): @staticmethod def _cases_resize_as_common(): for x, y_size, memory_format in CommonTemplate._cases_resize_common(): +<<<<<<< HEAD # each sizes /memory_format combination tested in 2 ways: +======= + # each sizes /memory_format combintation tested in 2 ways: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # 1. y is contiguous fn gets memory_format kwargs # 2. y has memory_format contiguity and fn gets preserve kwarg # 3. y has some other strides (not contiguous or channels last) and fn gets preserve @@ -11731,12 +12065,24 @@ def test_fft_real_input(self): def fn(x): return torch.fft.fftn(x) +<<<<<<< HEAD +======= + if self.device == "mps" and MACOS_VERSION < 14.0: + raise unittest.SkipTest("FFT needs MacOS-14+") + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.common(fn, (torch.randn((16, 16, 16)),), check_lowp=False) def test_fft_real_input_real_output(self): def fn(x): return torch.fft.fftn(x).real +<<<<<<< HEAD +======= + if self.device == "mps" and MACOS_VERSION < 14.0: + raise unittest.SkipTest("FFT needs MacOS-14+") + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.common(fn, (torch.randn((16, 16, 16)),), check_lowp=False) def test_searchsorted(self): @@ -12298,7 +12644,11 @@ def fn(x): # a new test case. self.assertEqual(len(bar_strides), 1) if self.device == "mps" and MACOS_VERSION < 15.0: +<<<<<<< HEAD # Before MacOS15 contiguous output were returned regardless of input +======= + # Before MacOS15 contigous output were returned regardless of input +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(bar_strides[0], expected_stride) else: self.assertNotEqual(bar_strides[0], expected_stride) @@ -12931,10 +13281,22 @@ def fn(x): not in [ "airy_ai", "erfcx", +<<<<<<< HEAD +======= + "gammainc", + "gammaincc", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "laguerre_polynomial_l", "legendre_polynomial_p", "log_ndtr", "ndtri", +<<<<<<< HEAD +======= + "shifted_chebyshev_polynomial_t", + "shifted_chebyshev_polynomial_u", + "shifted_chebyshev_polynomial_v", + "shifted_chebyshev_polynomial_w", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] else self.assertRaises(NotImplementedError) ) @@ -13046,9 +13408,21 @@ def forward(float_1, view_1): a = torch.randn(512, 4096, requires_grad=True) b = torch.randint(size=(512,), low=0, high=4095) +<<<<<<< HEAD self.common(forward, (a, b)) def test_isin_tensor_scalar(self): +======= + if self.device == "mps" and MACOS_VERSION < 13.3: + raise unittest.SkipTest("Fails with internal compiler error on MacOS-13") + + self.common(forward, (a, b)) + + def test_isin_tensor_scalar(self): + if self.device == "mps" and MACOS_VERSION < 14.0: + raise unittest.SkipTest("isin is not implemented on MacOS-13") + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for invert in [True, False]: torch._dynamo.reset() elements = 1 @@ -13221,7 +13595,11 @@ def __init__(self, dim): def forward(self, x): x = self.conv_t(x) +<<<<<<< HEAD x = torch.sigmoid(x) # trigger condition +======= + x = torch.sigmoid(x) # tigger condition +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return x for dim in (1, 2, 3): @@ -13304,7 +13682,11 @@ def f(x): "assert_size_stride(buf2, (16, 32), (32, 1)" ).run(code) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @config.patch(use_fast_math=True) def test_prepare_softmax_with_fast_math(self): """ @@ -13643,7 +14025,11 @@ def test_split_reduction_with_int64_size(self): op = torch.mean expected = op(t) actual = torch.compile(op)(t) +<<<<<<< HEAD # self.common takes more GPU memory. Do the check directly +======= + # self.common takes more GPU memory. Do the check dirctly +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue( torch.allclose(expected, actual, atol=1e-2, rtol=1e-2), f"{expected=} {actual=}", @@ -13713,6 +14099,10 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar ) @config.patch("min_num_split", 256) +<<<<<<< HEAD +======= + @xfail_if_mps # TypeError: cannot determine truth value of Relational +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_split_reduction_dynamic_shape(self): from torch._dynamo.decorators import mark_dynamic @@ -13781,6 +14171,7 @@ def forward(self, x): FileCheck().check("cpp_fused_add_0").run(code) self.assertEqual(refe_out, test_out) +<<<<<<< HEAD def test_triton_kernel_bool_param(self): if self.device != GPU_TYPE or self.device == "mps": raise unittest.SkipTest("requires GPU") @@ -14046,6 +14437,8 @@ def fn(a, b): # end of class CommonTemplate - add new tests here +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclasses.dataclass class TestFailure: @@ -14089,6 +14482,7 @@ def new_test(self, value=value): other_cls.is_dtype_supported = my_cls.is_dtype_supported +<<<<<<< HEAD def add_test_failures( test_failures: dict[str, TestFailure], added_test_failures: dict[str, TestFailure] ): @@ -14108,6 +14502,8 @@ def add_test_failures( test_failures[name] = new_failure +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if RUN_CPU: class SweepInputsCpuTest(SweepInputs2, TestCase): @@ -14420,7 +14816,11 @@ def forward( torch._inductor.aot_compile(traced, inputs) @skipCUDAIf(not SM90OrLater, "Requires sm90") +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(TEST_WITH_ROCM, "no grouped_mm support") @config.patch(implicit_fallbacks=True) def test_grouped_mm(self): @@ -14481,7 +14881,11 @@ def has_indirect(code, tl_fn: str): def has_assert(code, lower: bool, upper: bool): self.assertIn( +<<<<<<< HEAD "device_assert", code, msg=f"No device assert found:\n{code}" +======= + "device_assert", code, msg=f"No device asert found:\n{code}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) for line in code.split("\n"): if "device_assert" in line: @@ -14577,6 +14981,7 @@ def fn_gpu(x): self.assertEqual(type(r), np.ndarray) self.assertEqual(r, np.sin(x)) +<<<<<<< HEAD @config.patch(expand_dimension_for_pointwise_nodes=True) def test_rope_fusion(self): batch_size, seq_length, hidden_dim = 8, 16, 128 @@ -14626,6 +15031,8 @@ def apply_rotary_pos_emb( code = run_and_get_triton_code(compiled_fn, q, k, cos, sin, pos_ids) self.assertEqual(code.count(".run("), 1) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_numpy_autograd(self): def my_torch(x): y = torch.cat([torch.sin(x) ** 2, torch.max(x)[None]]) @@ -14983,11 +15390,19 @@ def fn(x): else: self.assertTrue("Graph fragment" in code) self.assertTrue( +<<<<<<< HEAD f'%sin : Tensor "f32[4, 4][4, 1]{GPU_TYPE}:0"[num_users=1] = call_function[target=torch.ops.aten.sin.default]' in code ) self.assertTrue( f'%relu : Tensor "f32[4, 4][4, 1]{GPU_TYPE}:0"[num_users=1] = call_function[target=torch.ops.aten.relu.default]' +======= + "%sin : [num_users=1] = call_function[target=torch.ops.aten.sin.default]" + in code + ) + self.assertTrue( + "%relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default]" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) in code ) @@ -15442,6 +15857,7 @@ def fn(x): "'XBLOCK': 'constexpr'" ).run(code[0]) +<<<<<<< HEAD @unittest.skipIf(TEST_WITH_ROCM or not IS_SM90, "no scaled_grouped_mm support") def test_respect_scaled_grouped_mm_layout_tag(self): # scaled_grouped_mm needs `mat2` to be column-major @@ -15533,6 +15949,303 @@ def f(x): FileCheck().check_count( "with torch.cuda._DeviceGuard(0)", 1, exactly=True ).run(code) +======= + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition(self): + def f(x, y): + x1 = x + 1 + y1 = y + 1 + y_cpu = y1.cpu() + 1 + z = x @ y + return x1 + y1 + z + y_cpu.to(GPU_TYPE) + + x, y = [torch.ones(2, 2, device=self.device) for _ in range(2)] + x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]] + eager_out = f(x, y) + + f_compiled = torch.compile(f) + compiled_out = f_compiled(x_cloned, y_cloned) + self.assertEqual(eager_out, compiled_out) + + _, code = run_and_get_code(f_compiled, x_cloned, y_cloned) + + if not config.cpp_wrapper: + FileCheck().check("def partition_0(args):").check( + "(buf0, buf1, arg0_1, arg1_1) = self.partitions[0](partition0_args)" + ).check("recursively_apply_fns = runner.recursively_apply_fns").run( + code[0] + ) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_foreach_op(self): + def fn(a0, a1): + c = torch._foreach_abs([a0, a1]) + return torch.mul(c[0], a0) + + compiled_fn = torch.compile(fn) + + a0 = torch.randn(2, 3, device=self.device) + a1 = torch.randn(2, 3, device=self.device) + eager_out = fn(a0, a1) + compiled_out = compiled_fn(a0, a1) + self.assertEqual(eager_out, compiled_out) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_multiple_functions(self): + def f(x, y): + x1 = x + 1 + y1 = y + 1 + y_cpu = y1.cpu() + 1 + z = x @ y + return x1 + y1 + z + y_cpu.to(GPU_TYPE) + + def g(x): + return x + 1 + + x, y = [torch.ones(2, 2, device=self.device) for _ in range(2)] + x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]] + eager_out = g(f(x, y)) + + f_compiled = torch.compile(f) + g_compiled = torch.compile(g) + compiled_out = g_compiled(f_compiled(x_cloned, y_cloned)) + + self.assertEqual(eager_out, compiled_out) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_condition_op(self): + def f(p, b): + def true_fn(x): + return torch.cos(x) + + def false_fn(x): + return torch.sin(x) + + return torch.cond(p, true_fn, false_fn, [b]) + + compiled_f = torch.compile(f) + + # static shape + p = torch.tensor([True], device=self.device) + a = torch.ones([2, 3], device=self.device) + eager_out = f(p, a) + compiled_out = compiled_f(p, a) + self.assertEqual(eager_out, compiled_out) + + # dynamic shape with backed symint + p = torch.tensor([True], device=self.device) + a = torch.ones([4, 5], device=self.device) + eager_out = f(p, a) + compiled_out = compiled_f(p, a) + self.assertEqual(eager_out, compiled_out) + + @torch._inductor.config.patch("graph_partition", True) + @torch._dynamo.config.patch("capture_scalar_outputs", True) + def test_graph_partition_unbacked_symint_multi_output_layout(self): + def f(p, size_tensor): + size_val = size_tensor.item() + b = torch.ones([size_val, 3], device=GPU_TYPE) + + def true_fn(x): + return torch.cos(x), torch.cos(x) + 1 + + def false_fn(x): + return torch.sin(x), torch.sin(x) + 1 + + cond_out = torch.cond(p, true_fn, false_fn, [b]) + return cond_out[0] + cond_out[1] + + compiled_f = torch.compile(f) + p = torch.tensor([True], device=GPU_TYPE) + size_tensor = torch.tensor(2, device=GPU_TYPE) + eager_out = f(p, size_tensor) + compiled_out = compiled_f(p, size_tensor) + self.assertEqual(eager_out, compiled_out) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_symint(self): + def f(x, y): + x1 = x + 1 + y1 = y + 1 + y_cpu = y1.cpu() + 1 + z = x @ y + return x1 + y1 + z + y_cpu.to(GPU_TYPE) + + f_compiled = torch.compile(f) + x, y = ( + torch.ones(3, 3, device=self.device), + torch.randn(3, 3, device=self.device), + ) + compiled_out = f_compiled(x, y) + self.assertEqual(compiled_out, f(x, y)) + + x, y = ( + torch.ones(4, 4, device=self.device), + torch.randn(4, 4, device=self.device), + ) + compiled_out = f_compiled(x, y) + self.assertEqual(compiled_out, f(x, y)) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_symint_cat_backward(self): + def f(x, w): + y = torch.cat((x, x), dim=0) + z = y @ w + return z @ z.T + + compiled_f = torch.compile(f) + + for shape in (2, 3): + torch.manual_seed(42) + eager_x = torch.randn(shape, 2, device=self.device) + eager_w = torch.randn(2, 2, device=self.device, requires_grad=True) + torch.manual_seed(42) + compiled_x = torch.randn(shape, 2, device=self.device) + compiled_w = torch.randn(2, 2, device=self.device, requires_grad=True) + + f(eager_x, eager_w).sum().backward() + compiled_f(compiled_x, compiled_w).sum().backward() + self.assertEqual(eager_w.grad, compiled_w.grad) + + @dynamo_config.patch("capture_dynamic_output_shape_ops", True) + @config.patch(implicit_fallbacks=True) + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_symint_from_nested_indirect_indexing(self): + def nested(x, repeats): + rank = torch.arange(repeats.numel(), device=x.device) + index = rank.repeat_interleave(repeats, dim=0) + return torch.index_select(x, index=index, dim=0) + + example_inputs = ( + torch.randn((32, 64), device=self.device), + repeats := torch.tensor([5, 10, 15], device=self.device), + ) + torch._dynamo.mark_dynamic(repeats, 0) # create backed symint + + nested_opt = torch.compile(nested, backend="inductor") + + expect = nested(*example_inputs) + actual = nested_opt(*example_inputs) + self.assertEqual(expect, actual) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_symint_from_mutation_index(self): + x = torch.zeros(7, device=GPU_TYPE) + + def fn(n, a): + a[n] = -1 + return a + + opt_fn = torch.compile(fn, fullgraph=True) + + for n in range(2, x.shape[0]): + opt_fn(n, x) + self.assertEqual(x[n], -1) + + # Negative index triggers new compilation. + opt_fn(-x.shape[0], x) + + self.assertEqual(x[0], -1) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_unbacked_symint(self): + def f(x, y): + x1 = x + 1 + y1 = y + 1 + y_cpu = y1.cpu() + 1 + z = x @ y + return x1 + y1 + z + y_cpu.to(GPU_TYPE) + + f_compiled = torch.compile(f) + x, y = ( + torch.ones(3, 3, device=self.device), + torch.randn(3, 3, device=self.device), + ) + + torch._dynamo.decorators.mark_unbacked(x, 0) + torch._dynamo.decorators.mark_unbacked(y, 1) + + compiled_out = f_compiled(x, y) + eager_out = f(x, y) + self.assertEqual(compiled_out, eager_out) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_dynamic_scalar_inputs(self): + def f(x, y, integer): + x1 = x + 1 + y1 = y + 1 + y_cpu = y1.cpu() + 1 + z = x @ y + z += integer + return x1 + y1 + z + y_cpu.to(GPU_TYPE) + + f_compiled = torch.compile(f) + x, y = ( + torch.ones(3, 3, device=self.device), + torch.randn(3, 3, device=self.device), + ) + + torch._dynamo.decorators.mark_unbacked(x, 0) + torch._dynamo.decorators.mark_unbacked(y, 1) + + compiled_out = f_compiled(x, y, 5) + self.assertEqual(compiled_out, f(x, y, 5)) + + compiled_out = f_compiled(x, y, 6) + self.assertEqual(compiled_out, f(x, y, 6)) + + @torch._inductor.config.patch("graph_partition", True) + @torch._dynamo.config.patch("capture_scalar_outputs", True) + def test_graph_partition_item(self): + def f(x): + y = x + 1 + scalar = y.item() + return x + y + scalar + + compiled_f = torch.compile(f) + compiled_out = f(torch.tensor(1, device=GPU_TYPE)) + self.assertEqual(compiled_out, f(torch.tensor(1, device=GPU_TYPE))) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_buffer_reuse(self): + def f(x, y): + x1 = x + 1 + y1 = y + 1 + y_cpu = y1.cpu() + 1 + z = x1 + y1 + x @ y + u = (y_cpu.to(GPU_TYPE) + 2) @ y + 3 + u_cpu = u.cpu() + 2 + return z + u_cpu.to(GPU_TYPE) + + x, y = [torch.ones(2, 2, device=GPU_TYPE) for _ in range(2)] + x_cloned, y_cloned = [tmp.clone() for tmp in [x, y]] + eager_out = f(x, y) + + f_compiled = torch.compile(f) + compiled_out = f_compiled(x_cloned, y_cloned) + + self.assertEqual(eager_out, compiled_out) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_fused_scheduler_node(self): + def foo(x): + x = x * 20 + x_alias = x[0] + y = x * 10 + y_alias = y[0] + torch._dynamo.graph_break() + ind = torch.tensor(4, device=GPU_TYPE) + x_alias2 = x[ind:] + y_alias2 = y[ind:] + return x, x_alias, x_alias2, y_alias, y_alias2 + + foo = torch.compile(foo) + x = torch.rand([20, 20], device=GPU_TYPE) + _, code = run_and_get_code(foo, x) + + if not config.cpp_wrapper: + FileCheck().check("def partition_0(args):").run(code[0]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class RNNTest(TestCase): device_type = GPU_TYPE diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py index 4bcdf0d0cddcf..cd553f3915233 100644 --- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py @@ -25,7 +25,10 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) from inductor.test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library +<<<<<<< HEAD add_test_failures, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CommonTemplate, copy_tests, run_and_get_cpp_code, @@ -138,12 +141,15 @@ def run(*ex, **kwargs): "test_mul_index_expr_dynamic_shapes": TestFailure(("cpu",)), "test_flip_cat_dynamic_shapes": TestFailure(("cpu",)), "test_pad_single_dynamic_shapes": TestFailure(("cpu",)), +<<<<<<< HEAD "test_slice_scatter_dtype_consistency_dynamic_shapes": TestFailure( ( "cpu", "mps", ) ), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "test_embedding_sparse_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), # # Failed to find for loop/triton kernel: @@ -159,7 +165,10 @@ def run(*ex, **kwargs): "test_bmm2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_both_scalars_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_compar_dynamic_shapes": TestFailure(("cpu",)), +<<<<<<< HEAD "test_complex_from_real_imag_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "test_const_int32_to_float_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_conv2d_backward_channels_last_dynamic_shapes": TestFailure(("cpu",)), "test_conv_backward_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), @@ -185,7 +194,10 @@ def run(*ex, **kwargs): "test_bucketize_int_int64_int64_dynamic_shapes": TestFailure(("cpu",)), "test_searchsorted_dynamic_shapes": TestFailure(("cpu",)), "test_like_rands_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), +<<<<<<< HEAD "test_like_rands_sliced_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "test_linspace2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_linspace3_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_linspace4_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), @@ -251,9 +263,12 @@ def run(*ex, **kwargs): "test_pointwise_laguerre_polynomial_l_dynamic_shapes": TestFailure(("cuda", "xpu")), "test_pointwise_legendre_polynomial_p_dynamic_shapes": TestFailure(("cuda", "xpu")), "test_polar_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu"), is_skip=True), +<<<<<<< HEAD "test_add_complex7_dynamic_shapes": TestFailure(("cpu",), is_skip=True), "test_add_complex8_dynamic_shapes": TestFailure(("cpu",), is_skip=True), "test_add_complex9_dynamic_shapes": TestFailure(("cpu",), is_skip=True), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "test_randn_generator_dynamic_shapes": TestFailure(("cpu",)), "test_randn_like_empty_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_single_elem_dynamic_shapes": TestFailure(("cpu",)), @@ -358,7 +373,11 @@ def run(*ex, **kwargs): "test_rand_like_deterministic_dynamic_shapes": TestFailure( ("cpu", "cuda", "xpu"), is_skip=True ), +<<<<<<< HEAD "test_repeat_interleave_2_dynamic_shapes": TestFailure(("cpu",)), +======= + "test_repeat_interleave_2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "test_slice_mutation2_dynamic_shapes": TestFailure( ("cpu", "cuda", "xpu"), is_skip=True ), @@ -393,10 +412,16 @@ def run(*ex, **kwargs): # Refinement means we don't actually generate dynamic shapes (but only on # cpu apparently?!) "test_nonzero_unbacked_refinement_dynamic_shapes": TestFailure(("cpu",)), +<<<<<<< HEAD } add_test_failures(test_failures, dynamic_shapes_test_failures) +======= + **dynamic_shapes_test_failures, +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not TEST_WITH_ROCM: test_failures.update( { diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py index 57d263a63e8ac..a9a8a6a286025 100644 --- a/test/inductor/test_torchinductor_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_dynamic_shapes.py @@ -36,7 +36,10 @@ GPU_TYPE, HAS_CPU, HAS_GPU, +<<<<<<< HEAD HAS_MPS, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) patch_inductor_backend, ) @@ -60,6 +63,7 @@ "test_kwargs_dynamic_shapes": TestFailure(("cpu",)), # calling div on only symint args "test_AllenaiLongformerBase_repro_dynamic_shapes": TestFailure( +<<<<<<< HEAD ("cpu", "cuda", "xpu", "mps") ), "test_argmax_argmin_with_duplicates_dynamic_shapes": TestFailure(("mps",)), @@ -88,6 +92,11 @@ ), } +======= + ("cpu", "cuda", "xpu") + ), +} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not torch._inductor.config.cpp_wrapper: test_failures["test_conv_inference_heuristics_dynamic_shapes"] = TestFailure( ("cuda",) @@ -104,6 +113,7 @@ test_failures["test_unbacked_reduction"] = TestFailure(("cpu"), is_skip=True) +<<<<<<< HEAD if any(os.getenv("BUILD_ENVIRONMENT", "").endswith(x) for x in ("-debug", "-asan")): # Fails with TORCH_INTERNAL_ASSERT(!is_heap_allocated()), see https://github.com/pytorch/pytorch/issues/130073 # After https://github.com/pytorch/pytorch/pull/161586, starts failing UBSAN so we can't even xfail. @@ -115,6 +125,12 @@ test_failures["test_resize_dynamic_shapes"] = TestFailure( ("cpu", "cuda"), is_skip=True ) +======= +if os.getenv("BUILD_ENVIRONMENT", "").endswith("-debug"): + # Fails with TORCH_INTERNAL_ASSERT(!is_heap_allocated()), see https://github.com/pytorch/pytorch/issues/130073 + test_failures["test_resize_as_dynamic_shapes"] = TestFailure(("cpu", "cuda")) + test_failures["test_resize_dynamic_shapes"] = TestFailure(("cpu", "cuda")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def make_dynamic_cls(cls, xfail_prop="_expected_failure_dynamic"): @@ -139,7 +155,11 @@ class DynamicShapesCpuTests(TestCase): copy_tests(DynamicShapesCommonTemplate, DynamicShapesCpuTests, "cpu", test_failures) +<<<<<<< HEAD if (HAS_GPU or HAS_MPS) and not TEST_WITH_ASAN: +======= +if HAS_GPU and not TEST_WITH_ASAN: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class DynamicShapesGPUTests(TestCase): common = check_model_gpu @@ -154,7 +174,11 @@ class TestInductorDynamic(TestCase): compile_fn = partial(torch.compile, dynamic=True) def setUp(self): +<<<<<<< HEAD # HAS_CUDA_AND_TRITON also checks compute capability to skip tests +======= + # HAS_CUDA also checks compute capability to skip tests +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # on older devices if not HAS_GPU: self.skipTest("Triton not available") @@ -1166,5 +1190,9 @@ def fn(a, descending): from torch._inductor.test_case import run_tests # Slow on ASAN after https://github.com/pytorch/pytorch/pull/94068 +<<<<<<< HEAD if (HAS_CPU or HAS_GPU or HAS_MPS) and not TEST_WITH_ASAN: +======= + if (HAS_CPU or HAS_GPU) and not TEST_WITH_ASAN: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) run_tests(needs="filelock") diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 807ccb48a7983..29cf257d8509d 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -2,7 +2,10 @@ import atexit import contextlib import functools +<<<<<<< HEAD import math +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import os import sys import unittest @@ -26,13 +29,21 @@ OpDTypes, ops, skipCPUIf, +<<<<<<< HEAD +======= + skipCUDAIf, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) skipXPUIf, ) from torch.testing._internal.common_methods_invocations import op_db, skipOps from torch.testing._internal.common_utils import ( +<<<<<<< HEAD IS_CI, IS_MACOS, IS_WINDOWS, +======= + IS_MACOS, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) IS_X86, skipCUDAMemoryLeakCheckIf, skipIfCrossRef, @@ -45,11 +56,19 @@ from torch.testing._internal.inductor_utils import ( GPU_TYPE, HAS_CPU, +<<<<<<< HEAD has_triton, HAS_XPU_AND_TRITON, maybe_skip_size_asserts, ) from torch.testing._internal.triton_utils import requires_gpu_and_triton +======= + HAS_CUDA, + has_triton, + HAS_XPU, + maybe_skip_size_asserts, +) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.utils._dtype_abbrs import dtype_abbrs from torch.utils._python_dispatch import TorchDispatchMode from torch.utils._pytree import tree_map @@ -69,6 +88,7 @@ sys.exit(0) raise +<<<<<<< HEAD if IS_WINDOWS and IS_CI: # TODO(xuhancn) : improve the compiler build performance on windows. sys.stderr.write( @@ -78,6 +98,8 @@ sys.exit(0) raise unittest.SkipTest("skip slow test") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bf16 = torch.bfloat16 # not tested f64 = torch.float64 f32 = torch.float32 @@ -285,6 +307,11 @@ def format_op(op): "torch.ops.aten._efficient_attention_forward": {f16, f32}, "to_sparse": {f32, f64}, "linalg.eig": {f32, f64}, +<<<<<<< HEAD +======= + # Double and complex datatype matmul is not supported in oneDNN + "byte": {f16, f32}, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ("linalg.pinv", "singular"): {f64}, # could not create a primitive "addmv": {f64}, @@ -292,6 +319,7 @@ def format_op(op): # a deconvolution forward propagation primitive "nn.functional.conv_transpose2d": {f32, f64}, "nn.functional.conv_transpose3d": {f32, f64}, +<<<<<<< HEAD # [Begin] Incorrect XPU reference due to new driver. "masked.prod": {b8, i32, i64}, "masked.amin": {i64}, @@ -303,6 +331,11 @@ def format_op(op): "std_mean": {f64}, "var_mean": {f64}, # [End] +======= + # not implemented for 'Half' + "sort": {b8}, + "argsort": {b8}, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } @@ -372,9 +405,14 @@ def wrapper_noop_set_seed(op, *args, **kwargs): return op(*args, **kwargs) +<<<<<<< HEAD wrapper_noop_set_seed_decorator = patch( "torch.testing._internal.common_methods_invocations.wrapper_set_seed", wrapper_noop_set_seed, +======= +torch.testing._internal.common_methods_invocations.wrapper_set_seed = ( + wrapper_noop_set_seed +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # key can be either op_name, or (op_name, dtype) @@ -409,7 +447,11 @@ def wrapper_noop_set_seed(op, *args, **kwargs): "rtol": 1e-4, }, ("_unsafe_masked_index_put_accumulate", f16): {"atol": 1e-4, "rtol": 0.01}, +<<<<<<< HEAD # Following tests are failing with strict comparison but atol=1 is acceptable due roundings errors +======= + # Following tests are failing with strict comparision but atol=1 is acceptable due roundings errors +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ("nn.functional.interpolate.bilinear", u8): {"atol": 1, "rtol": 0}, ("nn.functional.upsample_bilinear", u8): {"atol": 1, "rtol": 0}, ("nn.functional.interpolate.bicubic", u8): {"atol": 1, "rtol": 0}, @@ -438,7 +480,10 @@ def wrapper_noop_set_seed(op, *args, **kwargs): ("cumsum", f16): {"reference_in_float": True}, "cumprod": {"reference_in_float": True, "atol": 7e-5, "rtol": 0.002}, "logcumsumexp": {"grad_atol": 8e-4, "grad_rtol": 0.001}, +<<<<<<< HEAD ("logcumsumexp", f16): {"grad_atol": 3e-3, "grad_rtol": 0.01}, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "exponential": {"reference_in_float": True}, "geometric": {"reference_in_float": True}, ("kron", f16): {"reference_in_float": True}, @@ -448,7 +493,10 @@ def wrapper_noop_set_seed(op, *args, **kwargs): ("nn.functional.batch_norm.without_cudnn", f16): {"reference_in_float": True}, ("nn.functional.cosine_similarity", f16): {"reference_in_float": True}, ("nn.functional.instance_norm", f16): {"reference_in_float": True}, +<<<<<<< HEAD ("nn.functional.linear", f16): {"atol": 3e-4, "rtol": 0.01}, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ("nn.functional.local_response_norm", f16): {"reference_in_float": True}, ("nn.functional.normalize", f16): {"atol": 1e-3, "rtol": 0.05}, ("nn.functional.rms_norm", f16): {"reference_in_float": True}, @@ -554,7 +602,10 @@ def wrapper_noop_set_seed(op, *args, **kwargs): "grad_atol": 8e-4, "grad_rtol": 0.001, }, +<<<<<<< HEAD ("logcumsumexp", f16): {"grad_atol": 4e-3, "grad_rtol": 0.01}, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "exponential": {"reference_in_float": True}, "geometric": {"reference_in_float": True}, ("kron", f16): {"reference_in_float": True}, @@ -627,7 +678,11 @@ def wrapper_noop_set_seed(op, *args, **kwargs): ("var_mean", f16): {"atol": 1e-5, "rtol": 2e-3}, ("var_mean.unbiased", f16): {"atol": 1e-5, "rtol": 2e-3}, ("vdot", f16): {"atol": 1e-5, "rtol": 2e-3}, +<<<<<<< HEAD # Following tests are failing with strict comparison but atol=1 is acceptable due roundings errors +======= + # Following tests are failing with strict comparision but atol=1 is acceptable due roundings errors +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # High atol due to precision loss ("nn.functional.interpolate.bilinear", f64): {"atol": 5e-4, "rtol": 0}, ("nn.functional.upsample_bilinear", f64): {"atol": 5e-4, "rtol": 0}, @@ -682,6 +737,7 @@ def wrapper_noop_set_seed(op, *args, **kwargs): ("nn.functional.unfold", f16): { "reference_in_float": True, }, +<<<<<<< HEAD # Reference crash on Intel LTS2 driver. ("nn.functional.interpolate.trilinear", f32): { "check_gradient": False, @@ -690,6 +746,8 @@ def wrapper_noop_set_seed(op, *args, **kwargs): ("nn.functional.interpolate.trilinear", f64): { "check_gradient": False, }, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } if TEST_WITH_ROCM: inductor_override_kwargs["cuda"].update( @@ -880,7 +938,10 @@ def wrapper_noop_set_seed(op, *args, **kwargs): "nn.functional.adaptive_avg_pool3d": {f16}, "nn.functional.adaptive_max_pool1d": {f16, f32}, "nn.functional.adaptive_max_pool2d": {f16, f32}, +<<<<<<< HEAD "nn.functional.max_pool2d": {f16, f32, f64}, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "nn.functional.bilinear": {f16}, "nn.functional.conv_transpose1d": {f16}, "nn.functional.conv_transpose2d": {f16}, @@ -978,6 +1039,7 @@ def wrapper_noop_set_seed(op, *args, **kwargs): } +<<<<<<< HEAD # Custom replacements for assertEquals, in cases where a difference in value # may not indicate correctness. @@ -1103,6 +1165,8 @@ def get_sort_assert_equal_fn(args, kwargs): } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def collection_decorator(fn): @functools.wraps(fn) def inner(self, device, dtype, op): @@ -1120,7 +1184,10 @@ def inner(self, device, dtype, op): return inner +<<<<<<< HEAD @wrapper_noop_set_seed_decorator +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestInductorOpInfo(TestCase): def tearDown(self): torch._dynamo.reset() @@ -1133,10 +1200,15 @@ def tearDown(self): @skipCUDAMemoryLeakCheckIf( True ) # inductor kernels failing this test intermittently +<<<<<<< HEAD @requires_gpu_and_triton @skipXPUIf( not HAS_XPU_AND_TRITON, "Skipped! Supported XPU compiler and Triton not found" ) +======= + @skipCUDAIf(not HAS_CUDA, "Skipped! Triton not found") + @skipXPUIf(not HAS_XPU, "Skipped! Supported XPU compiler not found") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skipCPUIf(not HAS_CPU, "Skipped! Supported CPU compiler not found") @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") @skipIfTorchDynamo("Test uses dynamo already") @@ -1261,7 +1333,11 @@ def map_to_fake(e): return True, rng_mode.has_rng_op +<<<<<<< HEAD def get_contexts(has_rng_op, args, kwargs): +======= + def get_contexts(has_rng_op): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if has_rng_op: # TODO - enable this, running into errors return ( @@ -1278,6 +1354,7 @@ def get_contexts(has_rng_op, args, kwargs): ) ctx = functools.partial(maybe_skip_size_asserts, op) +<<<<<<< HEAD if op_name in CUSTOM_ASSERT_EQUALS_FNS: assert_equal_fn = CUSTOM_ASSERT_EQUALS_FNS[op_name](args, kwargs) return ( @@ -1287,6 +1364,8 @@ def get_contexts(has_rng_op, args, kwargs): ), ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ((ctx, {}),) try: @@ -1312,9 +1391,13 @@ def _get_tolerances(dtype): # print(f"RUNNING OP {op_name} on {device_type} with {dtype}", flush=True) rtol, atol = _get_tolerances(dtype) no_python, has_rng_op = do_nopython_and_has_rng(fn, args, kwargs) +<<<<<<< HEAD for context_fn, kwarg_overrides in get_contexts( has_rng_op, args, kwargs ): +======= + for context_fn, kwarg_overrides in get_contexts(has_rng_op): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with context_fn(): # Base kwargs adjusted_kwargs = { diff --git a/test/inductor/test_torchinductor_strided_blocks.py b/test/inductor/test_torchinductor_strided_blocks.py index 6bde7a8c540a4..08ab2aad66f92 100644 --- a/test/inductor/test_torchinductor_strided_blocks.py +++ b/test/inductor/test_torchinductor_strided_blocks.py @@ -9,7 +9,10 @@ import torch import torch.utils._pytree as pytree +<<<<<<< HEAD from torch._dynamo.debug_utils import InputReader +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor import config from torch._inductor.choices import InductorChoices from torch._inductor.codegen.triton import FixedTritonConfig @@ -19,7 +22,10 @@ from torch._inductor.utils import run_and_get_code from torch._inductor.virtualized import V from torch.testing._internal.common_utils import ( +<<<<<<< HEAD decorateIf, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) instantiate_parametrized_tests, parametrize, skipIfXpu, @@ -27,7 +33,10 @@ ) from torch.testing._internal.inductor_utils import ( GPU_TYPE, +<<<<<<< HEAD HAS_CUDA_AND_TRITON, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) HAS_GPU, requires_gpu, skip_windows_ci, @@ -54,6 +63,7 @@ } +<<<<<<< HEAD # These xfails are due to the current restrictions with the TMA descriptor API. # see Note: TMA API Restrictions. In some cases TMA descriptors cannot be generated, and so tests # that assert on the expected number of descriptors (= equivalent block ptrs) will fail @@ -89,6 +99,58 @@ def xfail_if_use_tensor_descriptor(fn): class BlockDescriptorTestBase(InductorTestCase): block_descriptor_constructor_str = "tl.make_block_ptr" +======= +def run_and_compare( + self: InductorTestCase, + func: Callable[..., Any], + *args, + compile_kwargs: Optional[dict] = None, + expected_num_block_pointers: Optional[int] = None, + expected_num_programs: int = 1, + expected_num_triton_kernels: int = 1, + config_patches: Optional[dict] = None, + rtol: Optional[float] = None, + atol: Optional[float] = None, +): + """ + Runs the module through Inductor, comparing to eager reference. + """ + if compile_kwargs is None: + compile_kwargs = {} + if config_patches is None: + config_patches = {} + + def flatten_tensors(tensors): + flat, spec = pytree.tree_flatten(tensors) + return flat + + with config.patch(config_patches): + compiled = torch.compile(func, backend="inductor", **compile_kwargs) + result, code = run_and_get_code(compiled, *args) + + # Check numerical accuracy + ref_tensors = flatten_tensors(func(*args)) + actual_tensors = flatten_tensors(result) + for ref, actual in zip(ref_tensors, actual_tensors): + # Don't clobber the default tolerance values + tol = {t: v for t, v in {"rtol": rtol, "atol": atol}.items() if v is not None} + self.assertTrue(torch.allclose(ref, actual, **tol)) + + def count_code(substr: str, expected: Optional[int]): + count = sum(prog.count(substr) for prog in code) + if expected is not None: + self.assertEqual(count, expected) + + # Check the code + self.assertEqual(len(code), expected_num_programs) + count_code("@triton.jit", expected_num_triton_kernels) + count_code("tl.make_block_ptr", expected_num_block_pointers) + + return result, code + + +class BlockPointerTestBase(InductorTestCase): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _discontiguous_tensor( self, view_size: tuple[int, ...], device: Union[torch.device, str] ) -> torch.Tensor: @@ -120,6 +182,7 @@ def _assert_tiling_ndims(self, code, blocks: list[str], num_dims: int) -> None: def _get_lines_containing_substr(self, code: str, substr: str) -> str: return "\n".join(line for line in code.split("\n") if substr in line) +<<<<<<< HEAD def _run_and_compare( self: InductorTestCase, func: Callable[..., Any], @@ -170,6 +233,8 @@ def count_code(substr: str, expected: Optional[int]): return result, code +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @instantiate_parametrized_tests class CommonTemplate: @@ -196,7 +261,12 @@ def foo(x, y): # Expect failure for bad inputs with self.assertRaises(AssertionError) if raises else contextlib.nullcontext(): # Expect 3 block pointers: 2 inputs 1 output +<<<<<<< HEAD self._run_and_compare( +======= + run_and_compare( + self, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) foo, *inputs, expected_num_block_pointers=expected_num_block_pointers, @@ -271,7 +341,12 @@ def get_input() -> torch.Tensor: args = [get_input() for arg_idx in range(2)] # Expect 3 block pointers: 2 inputs 1 output +<<<<<<< HEAD self._run_and_compare( +======= + run_and_compare( + self, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.add, *args, expected_num_block_pointers=3 if require_block_ptr else None, @@ -319,7 +394,12 @@ def foo(x, y): self.assertIn(1, all_dims) # Expect 3 block pointers: 2 inputs one output +<<<<<<< HEAD self._run_and_compare( +======= + run_and_compare( + self, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) foo, x, y, @@ -327,6 +407,7 @@ def foo(x, y): config_patches={"triton.prefer_nd_tiling": prefer_nd_tiling}, ) +<<<<<<< HEAD def test_broadcast_with_singleton_dims(self): # This tests the case when the input / output contains both zero strides # and singleton dimensions. In this case the broadcasting dimensions @@ -396,6 +477,8 @@ def load_args(reader): rtol=rtol, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize( "x_size,y_size", [ @@ -438,9 +521,14 @@ def get_input(size: tuple[int]) -> torch.Tensor: if i != 1: self.assertEqual(i, j) +<<<<<<< HEAD result, (triton_code,) = self._run_and_compare(foo, x, y) @xfail_if_use_tensor_descriptor +======= + result, (triton_code,) = run_and_compare(self, foo, x, y) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize("prefer_nd_tiling", [False, True]) @config.patch("triton.skip_l1_cache", False) def test_pointwise_broadcast_nonzero_strides(self, prefer_nd_tiling: bool): @@ -455,7 +543,12 @@ def test_pointwise_broadcast_nonzero_strides(self, prefer_nd_tiling: bool): col = torch.as_strided(full, col_shape, full.stride()) # Expect 3 block pointers: 2 inputs one output +<<<<<<< HEAD result, (triton_code,) = self._run_and_compare( +======= + result, (triton_code,) = run_and_compare( + self, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.add, full, col, @@ -551,7 +644,12 @@ def test_reduction( # Expect at least 1 block pointer for the input. # Add 2 more if we generate 2 kernels. +<<<<<<< HEAD result, (code,) = self._run_and_compare( +======= + result, (code,) = run_and_compare( + self, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.sum, view, expected_num_block_pointers=num_block_pointers, @@ -585,14 +683,22 @@ def foo(x, y): ] # Expect 2 block pointers: inputs +<<<<<<< HEAD result, (code,) = self._run_and_compare( +======= + result, (code,) = run_and_compare( + self, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) foo, *inputs, expected_num_block_pointers=num_block_pointers, expected_num_triton_kernels=num_triton_kernels, ) +<<<<<<< HEAD @xfail_if_use_tensor_descriptor +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_multiple_max_block_non_power_of_2(self): """ Check that we support dims of size n * MAX_BLOCK, where n is any positive integer, not @@ -617,14 +723,22 @@ def foo(x): self.assertTrue(len(nontrivial_dims) > 1) # Expect 2 block pointers: input and output +<<<<<<< HEAD self._run_and_compare(foo, view, expected_num_block_pointers=2) +======= + run_and_compare(self, foo, view, expected_num_block_pointers=2) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize( "nd_tiling,num_block_pointers", [ +<<<<<<< HEAD subtest( (True, 2), decorators=[xfail_if_use_tensor_descriptor] ), # With tiling, the index is affine. +======= + (True, 2), # With tiling, the index is affine. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) (False, 1), # We can't infer that the load is a power of 2. ], ) @@ -636,7 +750,12 @@ def test_dynamic_shapes_pointwise(self, nd_tiling: bool, num_block_pointers: int view_size = (4, 4) view = self._discontiguous_tensor(view_size, self.device) +<<<<<<< HEAD self._run_and_compare( +======= + run_and_compare( + self, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.div, view, view, @@ -648,9 +767,13 @@ def test_dynamic_shapes_pointwise(self, nd_tiling: bool, num_block_pointers: int @parametrize( "with_tiling,num_block_pointers", [ +<<<<<<< HEAD subtest( (True, 1), decorators=[xfail_if_use_tensor_descriptor] ), # With tiling, the index is affine. +======= + (True, 1), # With tiling, the index is affine. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) (False, 0), # We can't infer that the load is a power of 2. ], ) @@ -663,7 +786,12 @@ def test_dynamic_shapes_reduction(self, with_tiling: bool, num_block_pointers: i view_size = (4, 4) view = self._discontiguous_tensor(view_size, self.device) +<<<<<<< HEAD self._run_and_compare( +======= + run_and_compare( + self, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.prod, view, expected_num_block_pointers=num_block_pointers, @@ -693,6 +821,7 @@ def foo(x): x = torch.randn(x_size).to(device) # Expect 2 block pointers: input and output +<<<<<<< HEAD self._run_and_compare( x, compile_kwargs={"dynamic": True}, expected_num_block_pointers=2 ) @@ -703,6 +832,12 @@ def foo(x): param_kwargs["num_block_pointers"] == 3 and param_kwargs["num_tiles"] == 1 ), ) +======= + run_and_compare( + self, x, compile_kwargs={"dynamic": True}, expected_num_block_pointers=2 + ) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize( "full_size,view_size,num_block_pointers,num_tiles", [ @@ -760,7 +895,12 @@ def get_input() -> torch.Tensor: args = [get_input() for arg_idx in range(2)] # Expect up to 3 block pointers: 2 inputs 1 output. +<<<<<<< HEAD result, code = self._run_and_compare( +======= + result, code = run_and_compare( + self, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.add, *args, expected_num_block_pointers=num_block_pointers, @@ -779,7 +919,10 @@ def get_input() -> torch.Tensor: else: self.assertNotIn(tile_name, program) +<<<<<<< HEAD @xfail_if_use_tensor_descriptor +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize( "view_size,num_block_pointers,num_triton_kernels,reduction_op", [ @@ -805,7 +948,12 @@ def test_2d_reduction_odd_shapes( # Expect at least 1 block pointer for the input. # Add 2 more if we generate 2 kernels. +<<<<<<< HEAD result, (code,) = self._run_and_compare( +======= + result, (code,) = run_and_compare( + self, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) reduction_op, view, expected_num_block_pointers=num_block_pointers, @@ -821,9 +969,13 @@ def test_2d_reduction_odd_shapes( "size,expected_num_block_pointers,expected_num_triton_kernels,expect_fallback", [ ((8, 8), 1, 1, True), # Persistent Welford fallback +<<<<<<< HEAD subtest( ((128, 128), 9, 2, False), decorators=[xfail_if_use_tensor_descriptor] ), # Looped Welford reduction +======= + ((128, 128), 9, 2, False), # Looped Welford reduction +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ], ) def test_2d_welford_reduction( @@ -844,7 +996,12 @@ def test_2d_welford_reduction( view = self._discontiguous_tensor(size, self.device) # We expect many block pointers for this one. +<<<<<<< HEAD result, (code,) = self._run_and_compare( +======= + result, (code,) = run_and_compare( + self, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.var_mean, view, expected_num_block_pointers=expected_num_block_pointers, @@ -871,7 +1028,12 @@ def test_welford_non_block_pointer( view = self._discontiguous_tensor((259, 311), self.device) # We expect many block pointers for this one. +<<<<<<< HEAD result, (code,) = self._run_and_compare( +======= + result, (code,) = run_and_compare( + self, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.var_mean, view, expected_num_block_pointers=6, @@ -893,7 +1055,12 @@ def test_reduction_multiple_discontiguous_dims(self): # Use odd shapes to frustrate block pointer analysis. view = self._discontiguous_tensor((3, 7, 11), self.device) +<<<<<<< HEAD result, (code,) = self._run_and_compare( +======= + result, (code,) = run_and_compare( + self, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.sum, view, expected_num_block_pointers=0, @@ -904,7 +1071,10 @@ def test_reduction_multiple_discontiguous_dims(self): # Check for 2 reduction dimensions. self._assert_reduction_ndims(code, 2) +<<<<<<< HEAD @xfail_if_use_tensor_descriptor # Cannot use TMA API for store with no x dimension. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @test_torchinductor.skip_if_triton_cpu # Illegal instruction File; cannot xfail because it crashes process def test_2d_reduction_multi_kernel(self): """ @@ -919,7 +1089,12 @@ def foo(x): x = x.reshape(x.shape[0], -1) return torch.softmax(x, -1) +<<<<<<< HEAD result, (code,) = self._run_and_compare( +======= + result, (code,) = run_and_compare( + self, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) foo, view, expected_num_block_pointers=6, @@ -936,7 +1111,10 @@ def foo(x): # Check for 2 reduction dimensions. self._assert_reduction_ndims(code, 2) +<<<<<<< HEAD @xfail_if_use_tensor_descriptor +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_fused_2d_reduction( self, ): @@ -951,7 +1129,12 @@ def foo(x): view = self._discontiguous_tensor(view_size, self.device) # Expect at least 1 block pointer for the input. +<<<<<<< HEAD result, (code,) = self._run_and_compare( +======= + result, (code,) = run_and_compare( + self, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) foo, view, expected_num_block_pointers=1, @@ -980,7 +1163,12 @@ def foo(*args): arg1 = torch.empty(view_size) # No guarantees on the number of kernels or pointers. +<<<<<<< HEAD result, (code,) = self._run_and_compare( +======= + result, (code,) = run_and_compare( + self, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) foo, arg0, arg1, @@ -992,7 +1180,11 @@ def foo(*args): @parametrize( "tile_reductions", +<<<<<<< HEAD [False, subtest(True, decorators=[xfail_if_use_tensor_descriptor])], +======= + [False, True], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def test_enable_tiled_reductions(self, tile_reductions: bool): """ @@ -1001,7 +1193,12 @@ def test_enable_tiled_reductions(self, tile_reductions: bool): view = self._discontiguous_tensor((9, 11), self.device) # If tiled, we expect 1 block pointer for the input. +<<<<<<< HEAD result, (code,) = self._run_and_compare( +======= + result, (code,) = run_and_compare( + self, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.sum, view, expected_num_block_pointers=1 if tile_reductions else 0, @@ -1015,7 +1212,10 @@ def test_enable_tiled_reductions(self, tile_reductions: bool): # Check the code for multiple Rn_BLOCK's self._assert_reduction_ndims(code, 2 if tile_reductions else 1) +<<<<<<< HEAD @xfail_if_use_tensor_descriptor +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_complex_reshape_block_ptr(self): def func(x, y): add_ = x + y @@ -1029,7 +1229,12 @@ def func(x, y): return clone_0, clone_1 inps = (torch.rand((8, 2048), device=self.device, dtype=torch.float32),) * 2 +<<<<<<< HEAD result, code = self._run_and_compare( +======= + result, code = run_and_compare( + self, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) func, *inps, expected_num_triton_kernels=2, @@ -1037,7 +1242,10 @@ def func(x, y): ) self.assertTrue("Min" not in code[0]) +<<<<<<< HEAD @xfail_if_use_tensor_descriptor +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @requires_gpu() # FIXME this test failed on Triton-CPU def test_3d_permute_tiling(self): """ @@ -1051,7 +1259,12 @@ def foo(x, y, z): return a + b inps = (torch.rand((51, 51, 51), device=self.device, dtype=torch.float32),) * 3 +<<<<<<< HEAD result, (code,) = self._run_and_compare( +======= + result, (code,) = run_and_compare( + self, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) foo, *inps, expected_num_triton_kernels=1, @@ -1087,7 +1300,12 @@ def foo(x, length): ) with torch._dynamo.config.patch({"capture_scalar_outputs": True}): +<<<<<<< HEAD self._run_and_compare( +======= + run_and_compare( + self, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) foo, *inps, expected_num_triton_kernels=1, @@ -1112,7 +1330,12 @@ def fn(a): return aten.bernoulli(a).sum() / torch.prod(torch.tensor(a.size())) p = 0.3 +<<<<<<< HEAD result, code = self._run_and_compare( +======= + result, code = run_and_compare( + self, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fn, *[torch.ones(200, 200, device=self.device) * p], expected_num_triton_kernels=2, @@ -1121,7 +1344,10 @@ def fn(a): rtol=0.06, ) +<<<<<<< HEAD @xfail_if_use_tensor_descriptor +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_pointwise_index_order(self): """ Test the order of indices in pointwise kernels. Expect Z to be the leading dim, @@ -1132,7 +1358,12 @@ def test_pointwise_index_order(self): self._discontiguous_tensor((5, 5, 5), device=self.device) for _ in range(2) ] +<<<<<<< HEAD result, (triton_code,) = self._run_and_compare( +======= + result, (triton_code,) = run_and_compare( + self, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.add, *inps, expected_num_triton_kernels=1, @@ -1180,7 +1411,12 @@ def foo(x): return x.expand(*expanded_size).clone() inps = [torch.randn(base_size, device=self.device)] +<<<<<<< HEAD result, (triton_code,) = self._run_and_compare( +======= + result, (triton_code,) = run_and_compare( + self, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) foo, *inps, expected_num_triton_kernels=1, @@ -1209,7 +1445,12 @@ def foo(x, y, z): torch.randn((128,), device=self.device), torch.randn((8, 11, 128), device=self.device), ] +<<<<<<< HEAD result, (triton_code,) = self._run_and_compare( +======= + result, (triton_code,) = run_and_compare( + self, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) foo, *inps, expected_num_triton_kernels=1, @@ -1224,6 +1465,7 @@ def foo(x, y, z): # Singleton splits should be discarded. self._assert_pointwise_ndims(triton_code, 2) +<<<<<<< HEAD # Integration test to ensure that matched dims & strides from match_mod_div_expr # are unsigned and signed integers respectively. This test case has the following # index:=(ModularIndexing(xindex, 4, 4)) + 4*(ModularIndexing(xindex, 32, 2)) @@ -1285,6 +1527,8 @@ def model(x, y): expected_num_block_pointers=3, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @config.patch("triton.prefer_nd_tiling", True) @config.patch("triton.max_tiles", 3) @parametrize( @@ -1302,7 +1546,10 @@ def model(x, y): ), ], ) +<<<<<<< HEAD @xfail_if_use_tensor_descriptor +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_boundary_check(self, block_multiple, ynumel_exceed_ygrid_size, include_z): @dataclasses.dataclass class InputShape: @@ -1342,7 +1589,12 @@ def func(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: return a + b with V.set_choices_handler(FixedBlockSizeChoices()): +<<<<<<< HEAD result, code = self._run_and_compare( +======= + result, code = run_and_compare( + self, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) func, a, b, @@ -1373,7 +1625,11 @@ def func(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: @unittest.skipIf(not TRITON_HAS_CPU, "requires triton CPU backend") @config.patch(cpu_backend="triton") @config.patch("triton.use_block_ptr", True) +<<<<<<< HEAD class TritonBlockPointerTestCPU(BlockDescriptorTestBase): +======= +class TritonBlockPointerTestCPU(BlockPointerTestBase): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device = "cpu" @@ -1387,12 +1643,17 @@ class TritonBlockPointerTestCPU(BlockDescriptorTestBase): @unittest.skipIf(not HAS_GPU, "requires triton GPU backend") @config.patch("triton.use_block_ptr", True) +<<<<<<< HEAD class TritonBlockPointerTestGPU(BlockDescriptorTestBase): +======= +class TritonBlockPointerTestGPU(BlockPointerTestBase): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device = GPU_TYPE test_torchinductor.copy_tests(CommonTemplate, TritonBlockPointerTestGPU, GPU_TYPE) +<<<<<<< HEAD @unittest.skipIf( not ( @@ -1416,6 +1677,8 @@ class TritonTensorDescriptorTestCUDA(BlockDescriptorTestBase): test_failures=TMA_TEST_XFAIL, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_triton_heuristics.py b/test/inductor/test_triton_heuristics.py index 1573d4860a84c..18d9aa96c8a6e 100644 --- a/test/inductor/test_triton_heuristics.py +++ b/test/inductor/test_triton_heuristics.py @@ -3,13 +3,17 @@ import functools import sys import unittest +<<<<<<< HEAD from unittest import skipUnless +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from unittest.mock import MagicMock, patch import torch from torch._dynamo.testing import rand_strided from torch._inductor.runtime.triton_compat import HAS_WARP_SPEC from torch._inductor.utils import clone_preserve_strides +<<<<<<< HEAD from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, IS_LINUX, @@ -21,6 +25,11 @@ from torch.testing._internal.inductor_utils import ( GPU_TYPE, HAS_CUDA_AND_TRITON, +======= +from torch.testing._internal.common_utils import IS_LINUX, runOnRocm, skipIfXpu +from torch.testing._internal.inductor_utils import ( + GPU_TYPE, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) HAS_GPU, requires_cuda_with_enough_memory, ) @@ -76,7 +85,10 @@ def get_autotuned_amd_sqr_kernel(): )(amd_sqr_kernel) +<<<<<<< HEAD @instantiate_parametrized_tests +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestTritonHeuristics(TestCase): device_type = GPU_TYPE @@ -267,11 +279,16 @@ def grid(meta): def fn(x): return triton_sqr(x) +<<<<<<< HEAD x = torch.randn(32, device=GPU_TYPE) +======= + x = torch.randn(32, device="cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ref = fn(x) res = torch.compile(fn)(x) self.assertEqual(ref, res) +<<<<<<< HEAD @skipIfXpu @skipIfRocm @skipUnless(HAS_CUDA_AND_TRITON, "requires CUDA") @@ -300,6 +317,8 @@ def test_prune_configs_over_shared_memory_limit(self, do_pruning): ) self.assertEqual(len(configs), expected_count) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestArgumentCloneAndRestore(TestCase): # Our tensor is large enough. If a unexpected copy happens, the diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py index 5fe3623b271a5..160609992e762 100644 --- a/test/inductor/test_triton_kernels.py +++ b/test/inductor/test_triton_kernels.py @@ -31,12 +31,16 @@ skipIfWindows, skipIfXpu, ) +<<<<<<< HEAD from torch.testing._internal.inductor_utils import ( GPU_TYPE, HAS_CUDA_AND_TRITON, HAS_GPU, HAS_XPU_AND_TRITON, ) +======= +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA, HAS_GPU, HAS_XPU +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.logging_utils import log_settings, logs_to_string # Defines all the kernels for tests @@ -52,7 +56,11 @@ import triton from triton import language as tl +<<<<<<< HEAD if HAS_CUDA_AND_TRITON: +======= + if HAS_CUDA: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: from triton.language.extra.libdevice import ( # @manual fast_dividef, @@ -63,7 +71,11 @@ fast_dividef, fast_dividef as my_fast_dividef, ) +<<<<<<< HEAD elif HAS_XPU_AND_TRITON: +======= + elif HAS_XPU: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from triton.language.extra.intel.libdevice import ( # @manual fast_dividef, fast_dividef as my_fast_dividef, @@ -83,12 +95,15 @@ def _triton_get_ast_equal_to_str(params): BOOL_CONSTANT_C: tl.constexpr = tl.constexpr(True) FLOAT_CONSTANT_C = tl.constexpr(3.14) # intentionally un-annotated +<<<<<<< HEAD if hasattr(triton, "constexpr_function"): @triton.constexpr_function def log2(n): return len(bin(n)) - 3 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class KernelTests(torch._inductor.test_case.TestCase): def _kernel_launched_in_code(self, kernel_name: str, code: str) -> bool: @@ -1016,7 +1031,11 @@ def _mul2(x): def f(x): for _ in range(4): # The output of one kernel is the input to the next kernel, but +<<<<<<< HEAD # at some point we should reuse buffers not allocate new ones. +======= + # at some point we should re-use buffers not allocate new ones. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x = _mul2(x) return x + 1 @@ -1034,7 +1053,11 @@ def f(x): num_bufs_allocated = code.count(code_string) self.assertEqual(num_bufs_allocated, 2) +<<<<<<< HEAD # Check we're reusing buffers if not allocating. +======= + # Check we're re-using buffers if not allocating. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) num_bufs_reused = code.count( "// reuse" if inductor_config.cpp_wrapper else "# reuse" ) @@ -1319,10 +1342,17 @@ def f(x, y): else: if dynamic: # when half_n_elements passed to the Triton kernel is +<<<<<<< HEAD # dynamic, equal_to_1 specialization can't be enforced # also, equal_to_1 specialization doesn't occur (or appear in the signature) # for newer versions of triton (i.e. the ones where triton_version_uses_attrs_dict() == True) +======= + # dynamic, equal_to_1 specializaiton can't be enforced + + # also, equal_to_1 specialization doesn't occur (or appear in the signature) + # for newer versions ofo triton (i.e. the ones where triton_version_uses_attrs_dict() == True) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue(_triton_get_ast_equal_to_str(()) in sources[0]) else: self.assertTrue(_triton_get_ast_equal_to_str((3,)) in sources[0]) @@ -1389,6 +1419,7 @@ def f(x): self.assertEqual(compiled_out, eager_out) +<<<<<<< HEAD @unittest.skipIf( not HAS_GPU or not hasattr(triton, "constexpr_function"), "newer triton version required", @@ -1422,6 +1453,8 @@ def f(x): self.assertIn("@triton.constexpr_function", triton_code) self.assertEqual(compiled_out, eager_out) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @requires_gpu def test_triton_kernel_with_imported_symbol_with_custom_name(self): @triton.jit @@ -2239,7 +2272,11 @@ def f(x): self.assertEqual(compiled_out, eager_out) # TODO enable this test case on XPU. +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize("cfg", ["normal", "cpp_wrapper"]) def test_triton_kernel_dtype_view(self, cfg): # https://github.com/pytorch/pytorch/issues/136159 @@ -3623,6 +3660,7 @@ def f(x, y): self.assertNotIn(opname, code) @requires_gpu +<<<<<<< HEAD def test_subclass(self): libname = "my_cool_namespace" opname = "my_triton_operator" @@ -3657,6 +3695,8 @@ def f(x, y): self.assertEqual(out.b, expected.b) @requires_gpu +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dynamo_config.patch("recompile_limit", 1) def test_triton_dynamic_grid_no_recompile(self): libname = "my_cool_namespace" diff --git a/test/inductor/test_unbacked_symints.py b/test/inductor/test_unbacked_symints.py index cca1cb6a6dabb..0b2c7c0de1a23 100644 --- a/test/inductor/test_unbacked_symints.py +++ b/test/inductor/test_unbacked_symints.py @@ -489,6 +489,7 @@ def fn(q, k, vector, scalar): expected = fn(*example_inputs) torch.testing.assert_close(actual, expected) +<<<<<<< HEAD @skipGPUIf(not HAS_GPU, "requires gpu and triton") @skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet") @dynamo_config.patch({"capture_dynamic_output_shape_ops": True}) @@ -564,6 +565,8 @@ def fn(x): expected = fn(*example_inputs) torch.testing.assert_close(actual, expected) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) instantiate_device_type_tests(TestUnbackedSymints, globals(), allow_xpu=True) diff --git a/test/inductor/test_utils.py b/test/inductor/test_utils.py index 0fb1a8dcf3222..60f7483692a1e 100644 --- a/test/inductor/test_utils.py +++ b/test/inductor/test_utils.py @@ -1,11 +1,15 @@ # Owner(s): ["module: inductor"] +<<<<<<< HEAD import unittest +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from sympy import Symbol, sympify import torch from torch._inductor.fx_utils import count_flops_fx, countable_fx +<<<<<<< HEAD from torch._inductor.utils import get_device_tflops, sympy_str, sympy_subs from torch._inductor.virtualized import V from torch.testing._internal.common_device_type import ( @@ -13,6 +17,11 @@ instantiate_device_type_tests, ) from torch.testing._internal.common_utils import run_tests, TestCase +======= +from torch._inductor.test_case import run_tests, TestCase +from torch._inductor.utils import sympy_str, sympy_subs +from torch._inductor.virtualized import V +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestUtils(TestCase): @@ -65,7 +74,11 @@ def testSympySubs(self): result = sympy_subs(expr, {Symbol("x", integer=False): Symbol("y")}) self.assertEqual(result.name, "x") +<<<<<<< HEAD # replaced can't be string +======= + # replaced cant be string +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertRaises(AssertionError, sympy_subs, expr, {"x": "y"}) # replaced can be an expression @@ -194,6 +207,7 @@ def create_fx_node( countable_fx(fx_node_2), f"Expected false {f}: {fx_node_2}" ) +<<<<<<< HEAD @unittest.skipIf(not torch.cuda.is_available(), "skip if no device") @dtypes(torch.float16, torch.bfloat16, torch.float32) def test_get_device_tflops(self, dtype): @@ -202,6 +216,8 @@ def test_get_device_tflops(self, dtype): instantiate_device_type_tests(TestUtils, globals()) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/inductor/test_xpu_basic.py b/test/inductor/test_xpu_basic.py index 4501b8264c5f9..5282493b7df02 100644 --- a/test/inductor/test_xpu_basic.py +++ b/test/inductor/test_xpu_basic.py @@ -53,7 +53,13 @@ def fn(a, b): if __name__ == "__main__": from torch._dynamo.test_case import run_tests +<<<<<<< HEAD from torch.testing._internal.inductor_utils import HAS_XPU_AND_TRITON if HAS_XPU_AND_TRITON: +======= + from torch.testing._internal.inductor_utils import HAS_XPU + + if HAS_XPU: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) run_tests(needs="filelock") diff --git a/test/jit/test_alias_analysis.py b/test/jit/test_alias_analysis.py index 8905872c5c3cc..bc724751298b1 100644 --- a/test/jit/test_alias_analysis.py +++ b/test/jit/test_alias_analysis.py @@ -23,7 +23,11 @@ def test_becomes_wildcard_annotations(self): graph = parse_ir(graph_str) alias_db = graph.alias_db() split_node = graph.findNode("aten::split") +<<<<<<< HEAD # split input enters wildcard set, list initialized as containing wildcard set +======= + # split input enters wildcard set, list initalized as containing wildcard set +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue( alias_db.may_contain_alias(next(split_node.inputs()), split_node.output()) ) diff --git a/test/jit/test_backends.py b/test/jit/test_backends.py index eef8cc75fdcd9..1f173b8afc957 100644 --- a/test/jit/test_backends.py +++ b/test/jit/test_backends.py @@ -802,8 +802,12 @@ def test_attribute(self): # Attach bundled inputs which adds several attributes and functions to the model self.lowered_module = ( torch.utils.bundled_inputs.augment_model_with_bundled_inputs( +<<<<<<< HEAD lowered_module, # noqa: F821 input, +======= + lowered_module, input # noqa: F821 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) ) post_bundled = self.lowered_module( diff --git a/test/jit/test_builtins.py b/test/jit/test_builtins.py index 781080f5deb60..bd13523b27f68 100644 --- a/test/jit/test_builtins.py +++ b/test/jit/test_builtins.py @@ -131,6 +131,7 @@ def del_dict_multiple_operands(x: Dict[str, int]) -> Dict[str, int]: jit_out = torch.jit.script(del_dict_multiple_operands)({"hi": 5, "there": 6}) self.assertEqual(py_out, jit_out) +<<<<<<< HEAD def test_torch_check(self): """Test torch._check functionality with flexible argument handling""" @@ -289,6 +290,8 @@ def too_many_total_args(x): torch._check(True, "msg", cond=False) return x +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestTensorBuiltins(JitTestCase): def test_tensor_properties(self): diff --git a/test/jit/test_ignore_context_manager.py b/test/jit/test_ignore_context_manager.py index 98fb3e7e21d20..fb2d5c034a4e8 100644 --- a/test/jit/test_ignore_context_manager.py +++ b/test/jit/test_ignore_context_manager.py @@ -2,6 +2,10 @@ import os import sys +<<<<<<< HEAD +======= +import unittest +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch @@ -9,11 +13,19 @@ # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) +<<<<<<< HEAD +======= +from torch.jit.frontend import _IS_ASTUNPARSE_INSTALLED +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_utils import raise_on_run_directly from torch.testing._internal.jit_utils import JitTestCase class TestIgnoreContextManager(JitTestCase): +<<<<<<< HEAD +======= + @unittest.skipUnless(_IS_ASTUNPARSE_INSTALLED, "astunparse package is required") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_with_ignore_context_manager_with_inp_out(self): class A(torch.nn.Module): def forward(self): @@ -65,6 +77,10 @@ def forward(self): self.assertEqual(s(), 6) self.assertEqual(s(), model()) +<<<<<<< HEAD +======= + @unittest.skipUnless(_IS_ASTUNPARSE_INSTALLED, "astunparse package is required") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_with_ignore_context_manager_with_just_inp(self): class A(torch.nn.Module): def forward(self): @@ -79,6 +95,10 @@ def forward(self): self.assertEqual(s(), 4) self.assertEqual(s(), model()) +<<<<<<< HEAD +======= + @unittest.skipUnless(_IS_ASTUNPARSE_INSTALLED, "astunparse package is required") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_with_ignore_context_manager_with_just_out(self): class A(torch.nn.Module): def forward(self): diff --git a/test/jit/test_models.py b/test/jit/test_models.py index 4dd099dbaad5e..48856fbb3d2e6 100644 --- a/test/jit/test_models.py +++ b/test/jit/test_models.py @@ -7,7 +7,10 @@ import torch import torch.nn as nn import torch.nn.functional as F +<<<<<<< HEAD from torch.testing._internal.common_cuda import tf32_on_and_off +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_utils import ( enable_profiling_mode_for_profiling_tests, GRAPH_EXECUTOR, @@ -483,7 +486,10 @@ def test_super_resolution(self): self._test_super_resolution(self, device="cpu") @unittest.skipIf(not RUN_CUDA, "no CUDA") +<<<<<<< HEAD @tf32_on_and_off(0.02) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_super_resolution_cuda(self): # XXX: export_import on CUDA modules doesn't work (#11480) self._test_super_resolution(self, device="cuda", check_export_import=False) diff --git a/test/jit/test_module_containers.py b/test/jit/test_module_containers.py index eaedf48080b92..154d445f6071c 100644 --- a/test/jit/test_module_containers.py +++ b/test/jit/test_module_containers.py @@ -279,23 +279,41 @@ def __init__(self) -> None: self.moduledict = CustomModuleDict({"submod": self.submod}) def forward(self, inputs): +<<<<<<< HEAD assert self.modulelist[0] is self.submod, ( "__getitem__ failing for ModuleList" ) +======= + assert ( + self.modulelist[0] is self.submod + ), "__getitem__ failing for ModuleList" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert len(self.modulelist) == 1, "__len__ failing for ModuleList" for module in self.modulelist: assert module is self.submod, "__iter__ failing for ModuleList" +<<<<<<< HEAD assert self.sequential[0] is self.submod, ( "__getitem__ failing for Sequential" ) +======= + assert ( + self.sequential[0] is self.submod + ), "__getitem__ failing for Sequential" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert len(self.sequential) == 1, "__len__ failing for Sequential" for module in self.sequential: assert module is self.submod, "__iter__ failing for Sequential" +<<<<<<< HEAD assert self.moduledict["submod"] is self.submod, ( "__getitem__ failing for ModuleDict" ) +======= + assert ( + self.moduledict["submod"] is self.submod + ), "__getitem__ failing for ModuleDict" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert len(self.moduledict) == 1, "__len__ failing for ModuleDict" # note: unable to index moduledict with a string variable currently @@ -439,9 +457,15 @@ def __init__(self) -> None: self.moduledict = CustomModuleDict() def forward(self, inputs): +<<<<<<< HEAD assert "submod" not in self.moduledict, ( "__contains__ fails for ModuleDict" ) +======= + assert ( + "submod" not in self.moduledict + ), "__contains__ fails for ModuleDict" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return inputs m = MyModule() diff --git a/test/jit/test_modules.py b/test/jit/test_modules.py index ff4ca58e557e4..7505daa297577 100644 --- a/test/jit/test_modules.py +++ b/test/jit/test_modules.py @@ -21,7 +21,11 @@ def test_script_module_with_constants_list(self): """ # torch.nn.Linear has a __constants__ attribute defined +<<<<<<< HEAD # and initialized to a list. +======= + # and intialized to a list. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class Net(torch.nn.Linear): x: torch.jit.Final[int] diff --git a/test/jit/test_python_builtins.py b/test/jit/test_python_builtins.py index 771ba85895226..f0c8bc7b11ee3 100644 --- a/test/jit/test_python_builtins.py +++ b/test/jit/test_python_builtins.py @@ -405,7 +405,13 @@ def test_index_ellipses(self): def f(): x = torch.ones(10, 9, 8, 7, 6) return x{indices}.shape +<<<<<<< HEAD """.format(indices=indices) +======= + """.format( + indices=indices + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) test_str = test_str.replace(r"'", r"") scope = {} diff --git a/test/jit/test_recursive_script.py b/test/jit/test_recursive_script.py index d6addfddca1a7..d3501d0db18aa 100644 --- a/test/jit/test_recursive_script.py +++ b/test/jit/test_recursive_script.py @@ -4,7 +4,10 @@ import os import re import sys +<<<<<<< HEAD import threading +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import types import typing import typing_extensions @@ -774,6 +777,7 @@ def forward(self, x): mod.foo = None self.checkModule(mod, (torch.rand(2, 2),)) +<<<<<<< HEAD def test_thread_safe_error_stacks(self): # prior to #160386, this causes a segfault. See [Note: Thread-safe CallStack] callstacks = [] @@ -793,6 +797,8 @@ def callstack_creator(): del callstacks[0] self.assertTrue(len(callstacks) == 0) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_override_instance_method_ignore(self): class M(torch.nn.Module): @torch.jit.ignore diff --git a/test/jit/test_scriptmod_ann.py b/test/jit/test_scriptmod_ann.py index 4541c24dc5e0c..16810c653c5fc 100644 --- a/test/jit/test_scriptmod_ann.py +++ b/test/jit/test_scriptmod_ann.py @@ -139,7 +139,13 @@ def forward(self, x: List[int]): ): with self.assertWarnsRegex( UserWarning, +<<<<<<< HEAD "doesn't support instance-level annotations on empty non-base types", +======= + "doesn't support " + "instance-level annotations on " + "empty non-base types", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): torch.jit.script(M()) @@ -158,7 +164,13 @@ def forward(self, x: list[int]): ): with self.assertWarnsRegex( UserWarning, +<<<<<<< HEAD "doesn't support instance-level annotations on empty non-base types", +======= + "doesn't support " + "instance-level annotations on " + "empty non-base types", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): torch.jit.script(M()) @@ -177,7 +189,13 @@ def forward(self, x: Dict[str, int]): ): with self.assertWarnsRegex( UserWarning, +<<<<<<< HEAD "doesn't support instance-level annotations on empty non-base types", +======= + "doesn't support " + "instance-level annotations on " + "empty non-base types", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): torch.jit.script(M()) @@ -196,7 +214,13 @@ def forward(self, x: dict[str, int]): ): with self.assertWarnsRegex( UserWarning, +<<<<<<< HEAD "doesn't support instance-level annotations on empty non-base types", +======= + "doesn't support " + "instance-level annotations on " + "empty non-base types", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): torch.jit.script(M()) @@ -215,7 +239,13 @@ def forward(self, x: Optional[str]): ): with self.assertWarnsRegex( UserWarning, +<<<<<<< HEAD "doesn't support instance-level annotations on empty non-base types", +======= + "doesn't support " + "instance-level annotations on " + "empty non-base types", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): torch.jit.script(M()) @@ -234,7 +264,13 @@ def forward(self, x: List[int]): ): with self.assertWarnsRegex( UserWarning, +<<<<<<< HEAD "doesn't support instance-level annotations on empty non-base types", +======= + "doesn't support " + "instance-level annotations on " + "empty non-base types", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): torch.jit.script(M()) @@ -253,7 +289,13 @@ def forward(self, x: list[int]): ): with self.assertWarnsRegex( UserWarning, +<<<<<<< HEAD "doesn't support instance-level annotations on empty non-base types", +======= + "doesn't support " + "instance-level annotations on " + "empty non-base types", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): torch.jit.script(M()) @@ -272,7 +314,13 @@ def forward(self, x: Dict[str, int]): ): with self.assertWarnsRegex( UserWarning, +<<<<<<< HEAD "doesn't support instance-level annotations on empty non-base types", +======= + "doesn't support " + "instance-level annotations on " + "empty non-base types", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): torch.jit.script(M()) @@ -291,7 +339,13 @@ def forward(self, x: dict[str, int]): ): with self.assertWarnsRegex( UserWarning, +<<<<<<< HEAD "doesn't support instance-level annotations on empty non-base types", +======= + "doesn't support " + "instance-level annotations on " + "empty non-base types", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): torch.jit.script(M()) @@ -310,7 +364,13 @@ def forward(self, x: Optional[str]): ): with self.assertWarnsRegex( UserWarning, +<<<<<<< HEAD "doesn't support instance-level annotations on empty non-base types", +======= + "doesn't support " + "instance-level annotations on " + "empty non-base types", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): torch.jit.script(M()) @@ -331,7 +391,13 @@ def forward(self, x: Optional[str]): ): with self.assertWarnsRegex( UserWarning, +<<<<<<< HEAD "doesn't support instance-level annotations on empty non-base types", +======= + "doesn't support " + "instance-level annotations on " + "empty non-base types", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): torch.jit.script(M()) diff --git a/test/jit/test_tracer.py b/test/jit/test_tracer.py index 8d5cfffbcad8e..704daae93e2a6 100644 --- a/test/jit/test_tracer.py +++ b/test/jit/test_tracer.py @@ -960,9 +960,14 @@ def foo(a, b): V = Variable a, b = V(torch.rand(1)), V(torch.rand(1)) ge = torch.jit.trace(foo, (a, b)) +<<<<<<< HEAD a, b = ( V(torch.rand(1), requires_grad=True), V(torch.rand(1), requires_grad=True), +======= + a, b = V(torch.rand(1), requires_grad=True), V( + torch.rand(1), requires_grad=True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) (r,) = ge(a, b) da, db = torch.autograd.grad(r + 3, [a, b], create_graph=True) diff --git a/test/jit/test_union.py b/test/jit/test_union.py index c5afa13463221..881f6eebcec44 100644 --- a/test/jit/test_union.py +++ b/test/jit/test_union.py @@ -396,7 +396,13 @@ def fn(): with self.assertRaisesRegex( RuntimeError, +<<<<<<< HEAD "only int, float, complex, Tensor, device and string keys are supported", +======= + "only int, float, " + "complex, Tensor, device and string keys " + "are supported", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): torch.jit.script(fn) @@ -600,7 +606,13 @@ def fn(x: int) -> str: with self.assertRaisesRegex( RuntimeError, +<<<<<<< HEAD "y is set to type str in the true branch and type int in the false branch", +======= + "y is set to type str" + " in the true branch and type int " + "in the false branch", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): torch.jit.script(fn) @@ -618,7 +630,13 @@ def fn(x: int) -> str: with self.assertRaisesRegex( RuntimeError, +<<<<<<< HEAD "previously had type str but is now being assigned to a value of type int", +======= + "previously had type " + "str but is now being assigned to a" + " value of type int", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): torch.jit.script(fn) @@ -723,7 +741,12 @@ def fn(): template, "Union[List[str], List[torch.Tensor]]", lhs["list_literal_empty"], +<<<<<<< HEAD "there are multiple possible List type candidates in the Union annotation", +======= + "there are multiple possible List type " + "candidates in the Union annotation", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) self._assert_passes( @@ -895,7 +918,12 @@ def fn(): template, "Union[Dict[str, torch.Tensor], Dict[str, int]]", lhs["dict_literal_of_mixed"], +<<<<<<< HEAD "none of those dict types can hold the types of the given keys and values", +======= + "none of those dict types can hold the " + "types of the given keys and values", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # TODO: String frontend does not support tuple unpacking diff --git a/test/jit/test_union_pep604.py b/test/jit/test_union_pep604.py index 0cd2ce33165a3..1a7113be910ef 100644 --- a/test/jit/test_union_pep604.py +++ b/test/jit/test_union_pep604.py @@ -406,7 +406,13 @@ def fn(): with self.assertRaisesRegex( RuntimeError, +<<<<<<< HEAD "only int, float, complex, Tensor, device and string keys are supported", +======= + "only int, float, " + "complex, Tensor, device and string keys " + "are supported", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): torch.jit.script(fn) @@ -610,7 +616,13 @@ def fn(x: int) -> str: with self.assertRaisesRegex( RuntimeError, +<<<<<<< HEAD "y is set to type str in the true branch and type int in the false branch", +======= + "y is set to type str" + " in the true branch and type int " + "in the false branch", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): torch.jit.script(fn) @@ -628,7 +640,13 @@ def fn(x: int) -> str: with self.assertRaisesRegex( RuntimeError, +<<<<<<< HEAD "previously had type str but is now being assigned to a value of type int", +======= + "previously had type " + "str but is now being assigned to a" + " value of type int", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): torch.jit.script(fn) @@ -733,7 +751,12 @@ def fn(): template, "List[str] | List[torch.Tensor]", lhs["list_literal_empty"], +<<<<<<< HEAD "there are multiple possible List type candidates in the Union annotation", +======= + "there are multiple possible List type " + "candidates in the Union annotation", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) self._assert_passes( @@ -899,7 +922,12 @@ def fn(): template, "Dict[str, torch.Tensor] | Dict[str, int]", lhs["dict_literal_of_mixed"], +<<<<<<< HEAD "none of those dict types can hold the types of the given keys and values", +======= + "none of those dict types can hold the " + "types of the given keys and values", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # TODO: String frontend does not support tuple unpacking diff --git a/test/jit/test_warn.py b/test/jit/test_warn.py index 70f14cd2faff5..448f054bd85a2 100644 --- a/test/jit/test_warn.py +++ b/test/jit/test_warn.py @@ -135,6 +135,7 @@ def bar(): bar() FileCheck().check_count( +<<<<<<< HEAD str="UserWarning: I am warning you from foo", count=1, exactly=True, @@ -143,6 +144,14 @@ def bar(): count=1, exactly=True, ).run(f.getvalue()) +======= + str="UserWarning: I am warning you from foo", count=1, exactly=True + ).check_count( + str="UserWarning: I am warning you from bar", count=1, exactly=True + ).run( + f.getvalue() + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": diff --git a/test/lazy/test_generator.py b/test/lazy/test_generator.py index a513b41b08808..83a291fb876cd 100644 --- a/test/lazy/test_generator.py +++ b/test/lazy/test_generator.py @@ -42,12 +42,21 @@ def generate_tensor(): torch._lazy.mark_step() +<<<<<<< HEAD assert torch.allclose(cpu_t1, lazy_t1.to("cpu")), ( f"Expected {cpu_t1}, got {lazy_t1.to('cpu')}" ) assert torch.allclose(cpu_t2, lazy_t2.to("cpu")), ( f"Expected {cpu_t2}, got {lazy_t2.to('cpu')}" ) +======= + assert torch.allclose( + cpu_t1, lazy_t1.to("cpu") + ), f"Expected {cpu_t1}, got {lazy_t1.to('cpu')}" + assert torch.allclose( + cpu_t2, lazy_t2.to("cpu") + ), f"Expected {cpu_t2}, got {lazy_t2.to('cpu')}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skipIfTorchDynamo("Torch Dynamo does not support torch.Generator type") def test_generator_causes_multiple_compiles(self): @@ -69,22 +78,35 @@ def generate_tensor(seed): torch._lazy.mark_step() uncached_compile = metrics.counter_value("UncachedCompile") +<<<<<<< HEAD assert uncached_compile == 1, ( f"Expected 1 uncached compiles, got {uncached_compile}" ) +======= + assert ( + uncached_compile == 1 + ), f"Expected 1 uncached compiles, got {uncached_compile}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) t = generate_tensor(2) torch._lazy.mark_step() uncached_compile = metrics.counter_value("UncachedCompile") +<<<<<<< HEAD assert uncached_compile == 2, ( f"Expected 2 uncached compiles, got {uncached_compile}" ) +======= + assert ( + uncached_compile == 2 + ), f"Expected 2 uncached compiles, got {uncached_compile}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) t = generate_tensor(1) # noqa: F841 torch._lazy.mark_step() uncached_compile = metrics.counter_value("UncachedCompile") +<<<<<<< HEAD assert uncached_compile == 2, ( f"Expected 2 uncached compiles, got {uncached_compile}" ) @@ -92,6 +114,15 @@ def generate_tensor(seed): assert cached_compile == 1, ( f"Expected 1 cached compile, got {cached_compile}" ) +======= + assert ( + uncached_compile == 2 + ), f"Expected 2 uncached compiles, got {uncached_compile}" + cached_compile = metrics.counter_value("CachedCompile") + assert ( + cached_compile == 1 + ), f"Expected 1 cached compile, got {cached_compile}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) metrics.reset() diff --git a/test/mobile/model_test/gen_test_model.py b/test/mobile/model_test/gen_test_model.py index 5e760a739cec7..29b4f711de05c 100644 --- a/test/mobile/model_test/gen_test_model.py +++ b/test/mobile/model_test/gen_test_model.py @@ -118,16 +118,26 @@ def calcOpsCoverage(ops): uncovered_ops = production_ops - covered_ops coverage = round(100 * len(covered_ops) / len(production_ops), 2) +<<<<<<< HEAD # weighted coverage (take op occurrences into account) total_occurrences = sum(production_ops_dict["root_operators"].values()) +======= + # weighted coverage (take op occurances into account) + total_occurances = sum(production_ops_dict["root_operators"].values()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) covered_ops_dict = { op: production_ops_dict["root_operators"][op] for op in covered_ops } uncovered_ops_dict = { op: production_ops_dict["root_operators"][op] for op in uncovered_ops } +<<<<<<< HEAD covered_occurrences = sum(covered_ops_dict.values()) occurrences_coverage = round(100 * covered_occurrences / total_occurrences, 2) +======= + covered_occurances = sum(covered_ops_dict.values()) + occurances_coverage = round(100 * covered_occurances / total_occurances, 2) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) print(f"\n{len(uncovered_ops)} uncovered ops: {uncovered_ops}\n") print(f"Generated {len(all_generated_ops)} ops") @@ -135,7 +145,11 @@ def calcOpsCoverage(ops): f"Covered {len(covered_ops)}/{len(production_ops)} ({coverage}%) production ops" ) print( +<<<<<<< HEAD f"Covered {covered_occurrences}/{total_occurrences} ({occurrences_coverage}%) occurrences" +======= + f"Covered {covered_occurances}/{total_occurances} ({occurances_coverage}%) occurances" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) print(f"pytorch ver {torch.__version__}\n") diff --git a/test/mobile/test_lite_script_module.py b/test/mobile/test_lite_script_module.py index a1f84ca7e37b1..d145f3602960a 100644 --- a/test/mobile/test_lite_script_module.py +++ b/test/mobile/test_lite_script_module.py @@ -486,9 +486,23 @@ def forward(self): "Traceback of TorchScript" ).check("self.b.forwardError").check_next( "~~~~~~~~~~~~~~~~~~~ <--- HERE" +<<<<<<< HEAD ).check("return self.call").check_next("~~~~~~~~~ <--- HERE").check( "return torch.ones" ).check_next("~~~~~~~~~~ <--- HERE").run(str(exp)) +======= + ).check( + "return self.call" + ).check_next( + "~~~~~~~~~ <--- HERE" + ).check( + "return torch.ones" + ).check_next( + "~~~~~~~~~~ <--- HERE" + ).run( + str(exp) + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestLiteScriptQuantizedModule(QuantizationLiteTestCase): diff --git a/test/nn/test_convolution.py b/test/nn/test_convolution.py index 7dacfeed003cc..479a317cbe47d 100644 --- a/test/nn/test_convolution.py +++ b/test/nn/test_convolution.py @@ -30,7 +30,10 @@ skipCUDAIfMiopen, skipCUDAIfNoCudnn, skipCUDAIfNoMiopen, +<<<<<<< HEAD skipCUDAIfNotMiopenSuggestNHWC, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) skipCUDAIfRocm, skipMeta, skipMPS, @@ -52,6 +55,7 @@ parametrize as parametrize_test, run_tests, set_default_dtype, +<<<<<<< HEAD skipIfNotMiopenSuggestNHWC, skipIfRocmArch, skipIfRocmVersionLessThan, @@ -59,6 +63,12 @@ TEST_SCIPY, TEST_WITH_ROCM, xfailIf, +======= + skipIfRocmArch, + subtest, + TEST_SCIPY, + TEST_WITH_ROCM, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @@ -67,6 +77,10 @@ if TEST_WITH_ROCM: os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1" +<<<<<<< HEAD +======= + os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM"] = "1" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if TEST_SCIPY: @@ -718,7 +732,10 @@ def test_ConvTranspose2d_half_cublas_gemm(self): # Almost identical to the above `test_Conv2d_naive_groups` @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False) @tf32_on_and_off(0.001) +<<<<<<< HEAD @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_Conv2d_groups_nobias(self): dev_dtypes = [("cpu", torch.float)] if TEST_CUDA: @@ -764,7 +781,10 @@ def test_Conv2d_groups_nobias(self): # and https://github.com/pytorch/pytorch/pull/18463#issuecomment-477001024 @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False) @tf32_on_and_off(0.001) +<<<<<<< HEAD @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_Conv2d_groups_nobias_v2(self): torch.manual_seed(123) dev_dtypes = [("cpu", torch.float)] @@ -899,7 +919,10 @@ def test_conv_tbc(self): @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") @unittest.skipIf(not TEST_CUDNN, "needs cudnn") +<<<<<<< HEAD @skipIfNotMiopenSuggestNHWC +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_grouped_conv_cudnn_nhwc_support(self): # in order to catch the hols in grouped convolution in nhwc support for earlier cudnn version input = torch.randn((16, 16, 8, 8), dtype=torch.float16, device="cuda").to( @@ -2850,7 +2873,10 @@ def test_conv_transpose_with_output_size_and_no_batch_dim(self, device, N): @parametrize_test("strided", [False, True]) # Test with both contiguous and non-contiguous inputs. @parametrize_test("contiguous", [False, True]) +<<<<<<< HEAD @expectedFailureMPS # No double support +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_conv_backend( self, device, @@ -3149,7 +3175,10 @@ def test_conv_noncontig_weights_and_bias(self, device): @onlyCUDA @largeTensorTest("12GB") +<<<<<<< HEAD @skipIfRocmVersionLessThan((6, 0)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_conv_transposed_large(self, device): dtype = torch.half if self.device_type == "cuda" else torch.float conv = nn.ConvTranspose2d(1, 1, 1, 1, bias=False).to(device).to(dtype) @@ -3193,7 +3222,10 @@ def test_conv_transposed_large(self, device): self.assertEqual(maxdiff3, 0) @onlyCUDA +<<<<<<< HEAD @skipCUDAIfRocm +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @largeTensorTest("12GB") def test_conv_large(self, device): dtype = torch.half if self.device_type == "cuda" else torch.float @@ -3226,7 +3258,10 @@ def test_conv_large(self, device): self.assertEqual(grad1, grad2, atol=5e-2, rtol=5e-3) @onlyCUDA +<<<<<<< HEAD @skipCUDAIfRocm +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @largeTensorTest("20GB", "cpu") @largeTensorTest("60GB", "cuda") def test_conv_large_batch_1(self, device): @@ -3249,6 +3284,20 @@ def test_conv_large_batch_1(self, device): self.assertEqual(output.cpu().float(), output_cpu, atol=1e-3, rtol=1e-3) @onlyCUDA +<<<<<<< HEAD +======= + @skipCUDAIfRocm + @largeTensorTest("24GB", "cpu") + @largeTensorTest("20GB", "cuda") + def test_conv3d_large_batch_1(self, device): + x = torch.rand(1, 32, 512, 512, 256) + m = torch.nn.Conv3d(32, 1, kernel_size=1, padding=0, stride=1, bias=False) + yref = m(x) + y = m.to(device=device)(x.to(device=device)) + self.assertEqual(yref, y.cpu()) + + @onlyCUDA +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skipCUDAIfNoCudnn def test_contig_wrong_stride_cudnn(self, device): # x has to have batch_size 1 to test contiguous checks @@ -3363,7 +3412,10 @@ def test_ConvTranspose3d_size_1_kernel(self, device): @dtypes(torch.float) @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False) @tf32_on_and_off(0.001) +<<<<<<< HEAD @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_Conv2d_naive_groups(self, device, dtype): # Check that grouped convolutions matches two half convolutions m = nn.Conv2d(4, 4, kernel_size=3, groups=2).to(device, dtype) @@ -3632,6 +3684,7 @@ def helper( ) @onlyCUDA +<<<<<<< HEAD @skipCUDAIfNotMiopenSuggestNHWC @dtypes(torch.half, torch.float, torch.cfloat) def test_conv_cudnn_nhwc(self, device, dtype): @@ -3639,12 +3692,27 @@ def helper(n, c, h, w, out_channels, kernel_size, groups): input = torch.randint(-3, 3, (n, c, h, w), dtype=dtype, device=device).to( memory_format=torch.channels_last ) +======= + @dtypes(torch.half, torch.float, torch.cfloat) + def test_conv_cudnn_nhwc(self, device, dtype): + def helper(n, c, h, w, out_channels, kernel_size, groups): + # randint with dtype=torch.cfloat fails with + # RuntimeError: check_random_bounds handles only integral, floating-point and boolean types + # must create randint and randint_like using default int64, then cast to desired + input = torch.randint( + -3, 3, (n, c, h, w), dtype=torch.int64, device=device + ).to(dtype, memory_format=torch.channels_last) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) input.requires_grad_() conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups).to( device="cuda", dtype=dtype, memory_format=torch.channels_last ) for p in conv.parameters(): +<<<<<<< HEAD p.data = torch.randint_like(p, -3, 3) +======= + p.data = torch.randint_like(p, -3, 3, dtype=torch.int64).to(p.dtype) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # use FP64 channels-first conv as reference ref_input = input.detach().clone().contiguous().double().requires_grad_() @@ -3658,7 +3726,11 @@ def helper(n, c, h, w, out_channels, kernel_size, groups): out = conv(input) ref_out = ref_conv(ref_input) +<<<<<<< HEAD grad = torch.randint_like(out, -3, 3) +======= + grad = torch.randint_like(out, -3, 3, dtype=torch.int64).to(out.dtype) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ref_grad = grad.detach().clone().double().contiguous() out.backward(grad) @@ -3685,7 +3757,10 @@ def helper(n, c, h, w, out_channels, kernel_size, groups): helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=16) @onlyCUDA +<<<<<<< HEAD @skipCUDAIfRocm +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dtypes(torch.half, torch.float) def test_conv_cudnn_ndhwc(self, device, dtype): def helper(n, c, d, h, w, out_channels, kernel_size, groups): @@ -3815,7 +3890,10 @@ def _test_conv_cudnn_nhwc_nchw(self, layer, n, c, h, w, k, filter_size, device): ) @onlyCUDA +<<<<<<< HEAD @skipCUDAIfNotMiopenSuggestNHWC +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @tf32_on_and_off(0.05) def test_conv_cudnn_mismatch_memory_format(self, device): configs = [ @@ -3949,7 +4027,10 @@ def test_cudnn_convolution_add_relu(self, device, dtype): self.assertEqual(F.relu(conv2d_out + alpha * z), cudnn_out) @onlyCUDA +<<<<<<< HEAD @skipCUDAIfRocm +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_convert_conv2d_weight_memory_format(self, device): input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, device=device) model = nn.Sequential(nn.Conv2d(8, 4, 3), nn.BatchNorm2d(4)).to(device).float() @@ -3969,7 +4050,10 @@ def test_convert_conv2d_weight_memory_format(self, device): self.assertTrue(out.is_contiguous(memory_format=memory_format)) @onlyCUDA +<<<<<<< HEAD @skipCUDAIfRocm +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_convert_conv3d_weight_memory_format(self, device): input = torch.randint( 1, 10, (2, 8, 4, 4, 4), dtype=torch.float32, device=device @@ -4054,6 +4138,7 @@ def test_conv3d_64bit_indexing(self, device): @skipCUDAIfRocm @onlyCUDA +<<<<<<< HEAD @largeTensorTest("40GB", "cuda") def test_conv3d_cudnn_broken(self, device): for dtype in (torch.half, torch.bfloat16): @@ -4100,6 +4185,16 @@ def test_depthwise_conv_64bit_indexing(self, device): yref = c(x) y = c.to(device=device)(x.to(device=device)) self.assertEqual(yref, y, atol=1e-3, rtol=1e-4) +======= + @largeTensorTest("20GB") + @largeTensorTest("80GB", "cpu") + def test_depthwise_conv_64bit_indexing(self, device): + x = torch.randn(1, 2, 32800, 32800) + c = nn.Conv2d(2, 2, kernel_size=3, stride=1, padding=1, groups=2) + yref = c(x) + y = c.to(device=device)(x.to(device=device)) + self.assertEqual(yref, y) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) instantiate_device_type_tests(TestConvolutionNNDeviceType, globals(), allow_mps=True) diff --git a/test/nn/test_embedding.py b/test/nn/test_embedding.py index 3b21143711a56..e558cc86790fb 100644 --- a/test/nn/test_embedding.py +++ b/test/nn/test_embedding.py @@ -182,7 +182,10 @@ def test_embedding_functional(self): self.assertEqual(res_old, res_F) # https://github.com/pytorch/pytorch/issues/130806 +<<<<<<< HEAD @unittest.skipIf(not TEST_CUDA, "CUDA not available") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @largeTensorTest("40GB", device="cuda") def test_large_tensors(self): input = torch.randint(low=0, high=16032, size=[131072], device="cuda") diff --git a/test/nn/test_parametrization.py b/test/nn/test_parametrization.py index eb1f7c982b7ca..0cc5aefe0c834 100644 --- a/test/nn/test_parametrization.py +++ b/test/nn/test_parametrization.py @@ -1652,7 +1652,11 @@ def assert_weight_allclose_Q(weight, W): if can_initialize: assert_weight_allclose_Q(m.weight, w_init) +<<<<<<< HEAD # Initializing with a given orthogonal matrix works +======= + # Intializing with a given orthogonal matrix works +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) X = torch.randn_like(m.weight) if wide_matrix: X = X.mT @@ -1669,7 +1673,11 @@ def assert_weight_allclose_Q(weight, W): with self.assertRaisesRegex(NotImplementedError, msg): m.weight = w_new +<<<<<<< HEAD # Initializing with a non-orthogonal matrix makes m.weight be the Q part of the given matrix +======= + # Intializing with a non-orthogonal matrix makes m.weight be the Q part of the given matrix +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) w_new = torch.randn_like(m.weight) if can_initialize: m.weight = w_new diff --git a/test/nn/test_pooling.py b/test/nn/test_pooling.py index a8f77df22d311..4559692e8b437 100644 --- a/test/nn/test_pooling.py +++ b/test/nn/test_pooling.py @@ -504,7 +504,10 @@ def test_quantized_max_pool3d(self): class TestPoolingNNDeviceType(NNTestCase): +<<<<<<< HEAD @expectedFailureMPS # No double, float shape prop does not work +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyNativeDeviceTypes @dtypes(torch.float, torch.double) def test_adaptive_pooling_zero_batch(self, dtype, device): @@ -524,7 +527,10 @@ def test_adaptive_pooling_zero_batch(self, dtype, device): # when output_size = 0, in adaptive_{avg, max}_pool and its variants. # These tests are explicitly written because ErrorInputs does not support backward calls # Issue: https://github.com/pytorch/pytorch/issues/78868 +<<<<<<< HEAD @expectedFailureMPS # No double, float shape prop does not work +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyNativeDeviceTypes @dtypes(torch.float32, torch.float64) @dtypesIfCUDA(torch.float32, torch.float64, torch.bfloat16, torch.float16) @@ -558,7 +564,10 @@ def test_adaptive_pooling_empty_output_size(self, dtype, device): with self.assertRaisesRegex(RuntimeError, error_msg): fn(input2, output_size).sum().backward() +<<<<<<< HEAD @expectedFailureMPS # Error message does not match +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyNativeDeviceTypes def test_adaptive_avg_pooling_backward_fails(self, device): grad_output = torch.randn(1, 2, 7, device=device) @@ -585,7 +594,10 @@ def test_adaptive_max_pooling_backward_fails(self, device): with self.assertRaisesRegex(RuntimeError, "expected dimensions"): torch.ops.aten.adaptive_max_pool3d_backward(grad_output, input, indices) +<<<<<<< HEAD @expectedFailureMPS # Op not implemented +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyNativeDeviceTypes def test_FractionalMaxPool2d_zero_batch(self, device): mod = nn.FractionalMaxPool2d(3, output_ratio=(0.5, 0.5)) @@ -596,7 +608,10 @@ def test_FractionalMaxPool2d_zero_batch(self, device): inp = torch.randn(1, 0, 50, 32, device=device) mod(inp) +<<<<<<< HEAD @expectedFailureMPS # Op not implemented +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyNativeDeviceTypes def test_FractionalMaxPool3d_zero_batch(self, device): mod = nn.FractionalMaxPool3d(3, output_ratio=(0.5, 0.5, 0.5)).to(device) @@ -607,7 +622,10 @@ def test_FractionalMaxPool3d_zero_batch(self, device): inp = torch.randn(1, 0, 50, 32, 32, device=device) mod(inp) +<<<<<<< HEAD @expectedFailureMPS # Op not implemented +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyNativeDeviceTypes def test_FractionalMaxPool2d_zero_out_size(self, device): mod = nn.FractionalMaxPool2d([2, 2], output_size=[0, 1]) @@ -615,7 +633,10 @@ def test_FractionalMaxPool2d_zero_out_size(self, device): out = mod(inp) self.assertEqual(out, torch.empty((16, 50, 0, 1), device=device)) +<<<<<<< HEAD @expectedFailureMPS # Op not implemented +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyNativeDeviceTypes def test_FractionalMaxPool3d_zero_out_size(self, device): mod = nn.FractionalMaxPool3d([3, 2, 2], output_size=[0, 1, 1]) @@ -623,7 +644,10 @@ def test_FractionalMaxPool3d_zero_out_size(self, device): out = mod(inp) self.assertEqual(out, torch.empty((16, 0, 1, 1), device=device)) +<<<<<<< HEAD @expectedFailureMPS # Op not implemented +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyNativeDeviceTypes def test_FractionalMaxPool2d_zero_samples(self, device): samples = torch.rand([0, 16, 2], device=device) @@ -638,7 +662,10 @@ def test_FractionalMaxPool2d_zero_samples(self, device): with self.assertRaisesRegex(RuntimeError, "Expect _random_samples"): mod(inp1) +<<<<<<< HEAD @expectedFailureMPS # Op not implemented +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyNativeDeviceTypes def test_FractionalMaxPool3d_zero_samples(self, device): samples = torch.rand([0, 16, 3], device=device) @@ -654,6 +681,7 @@ def test_FractionalMaxPool3d_zero_samples(self, device): mod(inp1) @onlyNativeDeviceTypes +<<<<<<< HEAD def test_FractionalMaxPool3d_errors(self, device): samples = torch.rand([0, 16, 3], device=device) with self.assertRaisesRegex(ValueError, "kernel_size must greater than 0"): @@ -664,6 +692,8 @@ def test_FractionalMaxPool3d_errors(self, device): ) @onlyNativeDeviceTypes +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_MaxPool_zero_batch_dim(self, device): inp = torch.randn(0, 16, 50, device=device) mod = torch.nn.MaxPool1d(3, stride=2).to(device) @@ -832,7 +862,10 @@ def test_MaxUnpool_index_errors( else: unpool(output, indices) +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyNativeDeviceTypes def test_AdaptiveMaxPool_zero_batch_dim(self, device): inp = torch.randn(0, 16, 50, device=device) @@ -875,6 +908,7 @@ def test_AvgPool2d_empty(self, device): inp = torch.randn(16, 0, 20, 32, device=device) avgpool(inp) +<<<<<<< HEAD @parametrize_test("kernel", ["max", "avg"]) @parametrize_test("pooling_dims", [1, 2, 3]) def test_pooling_shape(self, device, kernel, pooling_dims): @@ -892,6 +926,22 @@ def check(expected_out_shape, sizes, *args, **kwargs): self.assertEqual( op(t, *args, **kwargs).shape, expected_out_shape[: pooling_dims + 2] ) +======= + @expectedFailureMPS # max_pool3d_with_indices not supported on MPS + def test_pooling_shape(self, device): + """Test the output shape calculation for pooling functions""" + + # Checks output shape against expected for 1D, 2D and 3D + def check(expected_out_shape, sizes, *args, **kwargs): + for kernel in ["max", "avg"]: + for i in [1, 2, 3]: + if hasattr(torch.nn.functional, f"{kernel}_pool{i}d"): + op = getattr(torch.nn.functional, f"{kernel}_pool{i}d") + t = torch.randn(sizes[: i + 2], device=device) + self.assertEqual( + op(t, *args, **kwargs).shape, expected_out_shape[: i + 2] + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) check( (1, 1, 3, 3, 4), @@ -972,7 +1022,10 @@ def test_adaptive_avg_pool3d_output_size_one(self, device): c = out.size(1) self.assertEqual(out.stride(), [c, 1, 1, 1, 1]) +<<<<<<< HEAD @expectedFailureMPS # Runtime Error not raised for mps +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @expectedFailureMeta # Runtime Error not raised for meta @onlyNativeDeviceTypes @dtypes(torch.uint8, torch.int8, torch.short, torch.int, torch.long) @@ -987,7 +1040,10 @@ def test_adaptive_pooling_no_suppot_input(self, device, dtype): with self.assertRaisesRegex(RuntimeError, "not implemented"): module(input) +<<<<<<< HEAD @expectedFailureMPS # TODO: fixme +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyNativeDeviceTypes @gcIfJetson @dtypes(torch.float, torch.double) @@ -1135,7 +1191,10 @@ def helper(n, c, h, w, ks): helper(1, 100000, 32, 32, ks=4) helper(1, 100000, 1, 4, ks=(1, 4)) # test for max_pool1d +<<<<<<< HEAD @expectedFailureMPS # TODO: Fixme +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyNativeDeviceTypes @dtypes(torch.half, torch.bfloat16, torch.float, torch.double) @dtypesIfCUDA(torch.half, torch.float, torch.double) @@ -1211,7 +1270,10 @@ def check(x, args, expected, memory_format): torch.channels_last, ) +<<<<<<< HEAD @expectedFailureMPS # TODO: Fixme +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyNativeDeviceTypes @dtypes(torch.half, torch.bfloat16, torch.float, torch.double) @dtypesIfCUDA(torch.half, torch.float, torch.double) @@ -1676,6 +1738,10 @@ def test_MaxPool1d_indices(self, device, dtype): def test_MaxPool2d_indices(self, device, dtype): self._test_maxpool_indices(2, device=device, dtype=dtype) +<<<<<<< HEAD +======= + @expectedFailureMPS +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16)) @dtypes(torch.float) def test_MaxPool3d_indices(self, device, dtype): @@ -1736,7 +1802,10 @@ def test_maxpool_indices_no_batch_dim(self, device, dtype): @dtypesIfCUDA(torch.half, torch.float, torch.double) @dtypes(torch.float) +<<<<<<< HEAD @expectedFailureMPS # Exception not raise +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyNativeDeviceTypes # TODO: Fails on XLA @gcIfJetson def test_max_pool_nan_inf(self, device, dtype): @@ -1773,7 +1842,10 @@ def test_max_pool_nan_inf(self, device, dtype): res2 = fn(x2, 1 if adaptive else 3) self.assertTrue(math.isinf(res2.item())) +<<<<<<< HEAD @expectedFailureMPS # float64 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @expectedFailureMeta # RuntimeError: Unrecognized tensor type ID: Meta @onlyNativeDeviceTypes def test_fractional_max_pool2d(self, device): @@ -1836,7 +1908,10 @@ def test_fractional_max_pool2d_backward_fails(self, device): grad_output, input, kernel_size, output_size, indices ) +<<<<<<< HEAD @expectedFailureMPS # float64 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @expectedFailureMeta # RuntimeError: Unrecognized tensor type ID: Meta @onlyNativeDeviceTypes def test_fractional_max_pool3d(self, device): @@ -1884,7 +1959,10 @@ def func(x): x, (2, 2, 2), output_size=output_size, _random_samples=samples ) +<<<<<<< HEAD @expectedFailureMPS # Not implemented +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dtypesIfCUDA(torch.half, torch.float, torch.double) @dtypes(torch.float) @onlyNativeDeviceTypes # TODO: Fails on XLA @@ -1914,7 +1992,10 @@ def test_fractional_max_pool_nan_inf(self, device, dtype): res2.backward(torch.randn_like(res2)) self.assertTrue(math.isinf(res2.item())) +<<<<<<< HEAD @expectedFailureMPS # TODO: Fix me +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyNativeDeviceTypes # TODO: RuntimeError message different on XLA def test_pooling_zero_stride(self, device): for op in ("max", "avg"): @@ -1937,6 +2018,10 @@ def test_pooling_zero_stride(self, device): ) @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16)) +<<<<<<< HEAD +======= + @expectedFailureMPS +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dtypes(torch.float) def test_pool_large_size(self, device, dtype): for op in ("max", "avg"): @@ -2031,6 +2116,10 @@ def test_pooling_bfloat16(self, device): prec=0.05, ) +<<<<<<< HEAD +======= + @expectedFailureMPS # max_pool3d_with_indices not supported on MPS device +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_maxpool3d_non_square_backward(self, device): # previous CUDA routine of this backward calculates kernel launch grid size # with last two dimensions interchanged, so the tailing along the longer dim diff --git a/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py b/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py new file mode 100644 index 0000000000000..2e47e48f140eb --- /dev/null +++ b/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py @@ -0,0 +1,849 @@ +# Owner(s): ["module: onnx"] +from __future__ import annotations + +import contextlib +import copy +import dataclasses +import os +import sys +import unittest +from pathlib import Path + +import onnxruntime +from parameterized import parameterized + +import torch +import torch._dynamo.backends.registry +from torch import nn +from torch.onnx import ( + _OrtBackend as OrtBackend, + _OrtBackendOptions as OrtBackendOptions, +) +from torch.testing._internal import common_utils +from torch.testing._internal.common_utils import skipIfNNModuleInlined + + +sys.path.append(str(Path(__file__).absolute().parents[1])) + +import onnx_test_common + + +def make_aot_ort(): + ort_backend = OrtBackend(options=OrtBackendOptions()) + return ort_backend, ort_backend + + +class TestDynamoWithONNXRuntime(onnx_test_common._TestONNXRuntime): + def setUp(self): + super().setUp() + torch._dynamo.reset() + OrtBackend.clear_cached_instances() + + def tearDown(self): + super().tearDown() + torch._dynamo.reset() + OrtBackend.clear_cached_instances() + + def test_get_ort_device_type(self): + from onnxruntime.capi import _pybind_state as ORTC + + self.assertEqual( + torch.onnx._internal.onnxruntime._get_ort_device_type("cuda"), + ORTC.OrtDevice.cuda(), + ) + self.assertEqual( + torch.onnx._internal.onnxruntime._get_ort_device_type("cpu"), + ORTC.OrtDevice.cpu(), + ) + self.assertEqual( + torch.onnx._internal.onnxruntime._get_ort_device_type("maia"), + ORTC.OrtDevice.npu(), + ) + + def test_torch_compile_backend_registration(self): + self.assertIn("onnxrt", torch._dynamo.backends.registry.list_backends()) + backend = torch._dynamo.backends.registry.lookup_backend("onnxrt") + self.assertEqual(backend.__module__, "torch.onnx._internal.onnxruntime") + + def _test_torch_compile_backend_caching_assert_reused( + self, options: OrtBackendOptions + ): + self.assertFalse(OrtBackend.get_cached_instances()) # assert setUp/tearDown + new_backend = OrtBackend.get_cached_instance_for_options(options) + reused_backend = OrtBackend.get_cached_instance_for_options(options) + self.assertEqual(len(OrtBackend.get_cached_instances()), 1) + self.assertIs(reused_backend, new_backend) + if options is None or options.ort_session_options is None: + # OrtBackendOptions.ort_session_options is a pybind11 object that + # cannot be pickled via dataclasses.asdict + self.assertEqual( + new_backend, + OrtBackend.get_cached_instance_for_options( + dataclasses.asdict(options) if options else None + ), + ) + + @parameterized.expand( + [ + (None,), + (OrtBackendOptions(),), + (OrtBackendOptions(use_aot_autograd=True),), + (OrtBackendOptions(use_aot_autograd=False),), + (OrtBackendOptions(preallocate_output=True),), + (OrtBackendOptions(preallocate_output=False),), + (OrtBackendOptions(infer_execution_providers=True),), + (OrtBackendOptions(infer_execution_providers=False),), + (OrtBackendOptions(preferred_execution_providers=["A", "B", "C"]),), + ( + OrtBackendOptions( + preferred_execution_providers=["A", "B", ("C", {"option": "value"})] + ), + ), + (OrtBackendOptions(default_execution_providers=["Something"]),), + (OrtBackendOptions(),), + ] + ) + def test_torch_compile_backend_caching_assert_reused( + self, options: OrtBackendOptions + ): + self._test_torch_compile_backend_caching_assert_reused(options) + + @parameterized.expand( + [ + (OrtBackendOptions(ort_session_options=onnxruntime.SessionOptions()),), + ] + ) + def test_torch_compile_backend_caching_assert_not_reused( + self, options: OrtBackendOptions + ): + with self.assertRaises(AssertionError): + self._test_torch_compile_backend_caching_assert_reused(options) + + def _test_model_numerically( + self, + model, + dynamo_backend, + example_args_collection, + fullgraph: bool = False, + test_backward: bool = False, + atol: float = 1e-5, + rtol: float = 1e-6, + ): + """Run original and compiled model and compare the results. + + Args: + model: The model to test. + dynamo_backend: The dynamo backend to use. Here we use string `onnxrt` or + the first returned value of `make_aot_ort()`. + example_args_collection: A tuple of example arguments to test. E.g., + ( + (torch.randn(2), torch.randn(2)), + (torch.randn(4), torch.randn(4)), + ) + if you want to test + model(torch.randn(2), torch.randn(2)) and + model(torch.randn(4), torch.randn(4)) + . + """ + compiled_model = torch.compile( + model if not isinstance(model, torch.nn.Module) else copy.deepcopy(model), + backend=dynamo_backend, + dynamic=True, + fullgraph=fullgraph, + ) + + for example_args in example_args_collection: + baseline_result = model(*example_args) + result = compiled_model(*example_args) + if isinstance(baseline_result, torch.Tensor): + torch.testing.assert_close( + baseline_result, result, atol=atol, rtol=rtol + ) + if test_backward: + baseline_result.sum().backward() + result.sum().backward() + for baseline_param, param in zip( + model.parameters(), compiled_model.parameters() + ): + torch.testing.assert_close( + baseline_param.grad, param.grad, atol=atol, rtol=rtol + ) + else: + assert test_backward is False, ( + "Calculating backward with multiple outputs is not supported yet." + ) + for baseline_elem, result_elem in zip(baseline_result, result): + torch.testing.assert_close( + baseline_elem, result_elem, atol=atol, rtol=rtol + ) + + def _assert_counting_information( + self, + ort_backend: OrtBackend, + # Number of session runs. + # If there is no graph break, this should be the same as + # total number of forward calls. + expected_execution_count: int, + # Number of GraphModule's cached. + # With one graph break, a model will be mapped + # to two GraphModule's. + number_of_cached_graph_modules: int, + # Number of ONNX models cached for each GraphModule, + # number_of_exported_onnx_models[i] contains # of ONNX models exported from + # the i-th element (type: torch.fx.GraphModule) in + # OrtBackend._all_ort_execution_info.execution_info_per_graph_module.values(). + number_of_exported_onnx_models_for_all_graph_modules: tuple[int, ...], + ): + self.assertEqual(expected_execution_count, ort_backend.execution_count) + self.assertEqual( + len(ort_backend._all_ort_execution_info.execution_info_per_graph_module), + number_of_cached_graph_modules, + ) + self.assertEqual( + len(ort_backend._all_ort_execution_info.execution_info_per_graph_module), + len(number_of_exported_onnx_models_for_all_graph_modules), + ) + for ( + onnx_info, + expected_number_of_onnx_models, + ) in zip( + ort_backend._all_ort_execution_info.execution_info_per_graph_module.values(), + number_of_exported_onnx_models_for_all_graph_modules, + ): + self.assertEqual(len(onnx_info), expected_number_of_onnx_models) + + def _assert_dynamic_input_and_output_shapes_in_all_onnx_models(self, backend): + for ( + onnx_session_infos + ) in backend._all_ort_execution_info.execution_info_per_graph_module.values(): + for onnx_session_info in onnx_session_infos: + inputs_have_dynamic_shapes = False + for input in onnx_session_info.input_value_infos: + if hasattr(input.type, "tensor_type") and hasattr( + input.type.tensor_type, "shape" + ): + for dim in input.type.tensor_type.shape.dim: + inputs_have_dynamic_shapes = ( + inputs_have_dynamic_shapes or hasattr(dim, "dim_param") + ) + output_have_dynamic_shapes = False + for output in onnx_session_info.output_value_infos: + if hasattr(output.type, "tensor_type") and hasattr( + output.type.tensor_type, "shape" + ): + for dim in output.type.tensor_type.shape.dim: + output_have_dynamic_shapes = ( + output_have_dynamic_shapes or hasattr(dim, "dim_param") + ) + self.assertTrue(inputs_have_dynamic_shapes) + self.assertTrue(output_have_dynamic_shapes) + + @parameterized.expand( + [ + (True,), + (False,), + ] + ) + def test_elementwise_function_single_output(self, test_local_backend: bool): + example_args_collection = tuple( + (torch.randn(batch, dtype=torch.float32),) for batch in (2, 4, 6, 8, 10) + ) + + def elementwise_model(x: torch.Tensor): + y = x.relu() + z = y.sigmoid() + return z + + if test_local_backend: + local_aot_ort, local_ort = make_aot_ort() + else: + # This will use the global ONNXRuntime backend registered + # in Dynamo to compile the tested model. + local_aot_ort, local_ort = "onnxrt", None + + self._test_model_numerically( + elementwise_model, + local_aot_ort, + example_args_collection, + ) + + # We can only check local backend's counting information + # since global backend's counting information comes from + # all compiled models. + if test_local_backend: + assert local_ort is not None + self._assert_counting_information( + local_ort, + # OrtBackend._ort_acclerated_call should have been called 5 times because + # we have 5 different batch sizes to test. + expected_execution_count=len(example_args_collection), + # Since this local_ort only compiled one function, + # there should be only one GraphModule in its cached. + number_of_cached_graph_modules=1, + # Since dynamic shape is enabled, we should only have one ONNX model + # to support different batch sizes. + number_of_exported_onnx_models_for_all_graph_modules=(1,), + ) + + @parameterized.expand( + [ + (True,), + (False,), + ] + ) + def test_elementwise_function_multiple_output(self, test_local_backend: bool): + example_args_collection = tuple( + (torch.randn(batch, dtype=torch.float32),) for batch in (2, 4, 8) + ) + + def elementwise_model_with_multiple_outputs(w: torch.Tensor): + x = w + w + y = x.relu() + z = y * y + return x, y, z + + if test_local_backend: + local_aot_ort, local_ort = make_aot_ort() + else: + local_aot_ort, local_ort = "onnxrt", None + + self._test_model_numerically( + elementwise_model_with_multiple_outputs, + local_aot_ort, + example_args_collection, + ) + + if test_local_backend: + assert local_ort is not None + self._assert_counting_information( + local_ort, + expected_execution_count=len(example_args_collection), + number_of_cached_graph_modules=1, + number_of_exported_onnx_models_for_all_graph_modules=(1,), + ) + + @parameterized.expand( + [ + (True,), + (False,), + ] + ) + def test_mlp_with_local_backend(self, test_local_backend: bool): + example_args_collection = tuple( + (torch.randn(batch, 2, dtype=torch.float32),) for batch in (1, 2, 4, 6, 8) + ) + + class MLP(nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = nn.Linear(2, 4, bias=True) + self.fc2 = nn.Linear(4, 2, bias=True) + + def forward(self, tensor_x: torch.Tensor): + tensor_x = self.fc1(tensor_x) + tensor_x = torch.sigmoid(tensor_x) + tensor_x = self.fc2(tensor_x) + tensor_x = torch.sigmoid(tensor_x) + return tensor_x + + if test_local_backend: + local_aot_ort, local_ort = make_aot_ort() + else: + local_aot_ort, local_ort = "onnxrt", None + + self._test_model_numerically( + MLP(), + local_aot_ort, + example_args_collection, + ) + + if test_local_backend: + assert local_ort is not None + self._assert_counting_information( + local_ort, + # OrtBackend._ort_acclerated_call should have been called 5 times because + # we have 5 different batch sizes to test. + expected_execution_count=len(example_args_collection), + # Since this local_ort only compiled one function, there should be only two + # GraphModule's in its cached. One for batch sizes 2, 4, 6, 8 and the other + # for batch size 1. + number_of_cached_graph_modules=2, + # Since dynamic shape is enabled, we should only have one ONNX model + # to support different batch sizes. + number_of_exported_onnx_models_for_all_graph_modules=(1, 1), + ) + + @parameterized.expand( + [ + (True, True), + (True, False), + ] + ) + @skipIfNNModuleInlined("https://github.com/pytorch/pytorch/issues/129456") + def test_llama_attention_with_local_backend( + self, test_local_backend: bool, test_backward: bool + ): + from transformers import LlamaConfig # noqa: F811 + from transformers.models.llama.modeling_llama import ( # noqa: F811 + LlamaAttention, + ) + + hidden_size = 16 + + config = LlamaConfig( + num_hidden_layers=1, + vocab_size=1024, + hidden_size=hidden_size, + intermediate_size=16, + max_position_embeddings=256, + num_attention_heads=2, + hidden_dropout_prob=0.0, + attention_dropout_prob=0.0, + ) + + class LlamaAttentionWrapper(torch.nn.Module): + def __init__(self, config): + super().__init__() + try: + # New version of LlamaAttention has layer_idx argument. + self.attention = LlamaAttention(config, layer_idx=0) + except TypeError: + # Fall back to old version of LlamaAttention. + self.attention = LlamaAttention(config) + + def forward(self, hidden_states, attention_mask, position_ids): + attn_output, _, _ = self.attention( + hidden_states, attention_mask, position_ids + ) + return attn_output + + def generate_example_inputs(batch: int, seq: int, hidden_size: int): + # shape: batch x seq x hidden_size + hidden_state = torch.randn(batch, seq, hidden_size) + # [0.0000e+00, ..., 0.0000e+00, -3.4028e+38, ...] + # shape: batch x 1 x seq x seq + attention_mask = torch.zeros(batch, 1, seq, seq, dtype=torch.float) + position_ids = torch.arange(0, seq, dtype=torch.int64) + position_ids = position_ids.unsqueeze(0).view(-1, seq) + + return hidden_state, attention_mask, position_ids + + # Reason for using multiple example argument groups: + # Export model to ONNX with one example argument group + # and test it with other example argument groups. + example_args_collection = ( + generate_example_inputs(2, 8, hidden_size), + generate_example_inputs(4, 7, hidden_size), + generate_example_inputs(9, 15, hidden_size), + ) + + if test_local_backend: + local_aot_ort, local_ort = make_aot_ort() + else: + local_aot_ort, local_ort = "onnxrt", None + + model = LlamaAttentionWrapper(config).eval() + + self._test_model_numerically( + model, + local_aot_ort, + example_args_collection, + fullgraph=True, + test_backward=test_backward, + ) + + if test_local_backend: + assert local_ort is not None + number_of_captured_graphs = 2 if test_backward else 1 + + execution_count = len(example_args_collection) * number_of_captured_graphs + self._assert_counting_information( + local_ort, + # Number of InferenceSession runs. + expected_execution_count=execution_count, + # Number of GraphModule's seen by ORT. + number_of_cached_graph_modules=number_of_captured_graphs, + # Number of InferenceSession's created per GraphModule. + number_of_exported_onnx_models_for_all_graph_modules=(1,) + * number_of_captured_graphs, + ) + self._assert_dynamic_input_and_output_shapes_in_all_onnx_models(local_ort) + + @parameterized.expand( + [ + (True, False), + (True, True), + ] + ) + @skipIfNNModuleInlined("https://github.com/pytorch/pytorch/issues/129456") + def test_llama_decoder_with_local_backend( + self, test_local_backend: bool, test_backward: bool + ): + from transformers import LlamaConfig # noqa: F811 + from transformers.models.llama.modeling_llama import ( # noqa: F811 + LlamaDecoderLayer, + ) + + hidden_size = 16 + + config = LlamaConfig( + num_hidden_layers=1, + vocab_size=1024, + hidden_size=hidden_size, + intermediate_size=16, + max_position_embeddings=256, + num_attention_heads=2, + hidden_dropout_prob=0.0, + attention_dropout_prob=0.0, + ) + + class LlamaDecoderWrapper(torch.nn.Module): + def __init__(self, config): + super().__init__() + try: + # New version of LlamaDecoderLayer has layer_idx argument. + self.decoder = LlamaDecoderLayer(config, layer_idx=0) + except TypeError: + # Fall back to old version of LlamaDecoderLayer. + self.decoder = LlamaDecoderLayer(config) + + def forward(self, hidden_states, attention_mask, position_ids): + (decoder_output,) = self.decoder( + hidden_states, attention_mask, position_ids + ) + return decoder_output + + def generate_example_inputs(batch: int, seq: int, hidden_size: int): + # shape: batch x seq x hidden_size + hidden_state = torch.randn(batch, seq, hidden_size) + # [0.0000e+00, ..., 0.0000e+00, -3.4028e+38, ...] + # shape: batch x 1 x seq x seq + attention_mask = torch.zeros(batch, 1, seq, seq, dtype=torch.float) + position_ids = torch.arange(0, seq, dtype=torch.int64) + position_ids = position_ids.unsqueeze(0).view(-1, seq) + return hidden_state, attention_mask, position_ids + + # Reason for using multiple example argument groups: + # Export model to ONNX with one example argument group + # and test it with other example argument groups. + example_args_collection = ( + generate_example_inputs(2, 8, hidden_size), + generate_example_inputs(4, 7, hidden_size), + generate_example_inputs(9, 15, hidden_size), + ) + + if test_local_backend: + local_aot_ort, local_ort = make_aot_ort() + else: + local_aot_ort, local_ort = "onnxrt", None + + model = LlamaDecoderWrapper(config).eval() + + self._test_model_numerically( + model, + local_aot_ort, + example_args_collection, + fullgraph=True, + test_backward=test_backward, + ) + + if test_local_backend: + assert local_ort is not None + number_of_captured_graphs = 2 if test_backward else 1 + + execution_count = len(example_args_collection) * number_of_captured_graphs + + self._assert_counting_information( + local_ort, + expected_execution_count=execution_count, + number_of_cached_graph_modules=number_of_captured_graphs, + number_of_exported_onnx_models_for_all_graph_modules=(1,) + * number_of_captured_graphs, + ) + self._assert_dynamic_input_and_output_shapes_in_all_onnx_models(local_ort) + + @parameterized.expand( + [ + (True, False), + (True, True), + ] + ) + @skipIfNNModuleInlined("https://github.com/pytorch/pytorch/issues/129456") + def test_llama_with_local_backend( + self, test_local_backend: bool, test_backward: bool + ): + from transformers import LlamaConfig # noqa: F811 + from transformers.models.llama.modeling_llama import LlamaModel # noqa: F811 + + config = LlamaConfig( + num_hidden_layers=1, + vocab_size=1024, + hidden_size=16, + intermediate_size=16, + max_position_embeddings=256, + num_attention_heads=2, + hidden_dropout_prob=0.0, + attention_dropout_prob=0.0, + ) + + config._attn_implementation = "eager" + + class LlamaModelWrapper(torch.nn.Module): + def __init__(self, config): + super().__init__() + self.llama = LlamaModel(config) + + def forward(self, input_ids, attention_mask, position_ids): + decoder_output = self.llama( + input_ids, attention_mask, position_ids, return_dict=False + ) + return decoder_output[0] + + def generate_example_inputs(batch: int, seq: int): + # shape: batch x seq x hidden_size + input_ids = torch.randint(0, 7, size=(batch, seq), dtype=torch.int64) + # Usually, its shape is a tensor with shape batch x seq x seq. + # However, to bypass some control flow in the model, we use None. + attention_mask = None + position_ids = torch.arange(0, seq, dtype=torch.int64) + position_ids = position_ids.unsqueeze(0).view(-1, seq) + return input_ids, attention_mask, position_ids + + # Reason for using multiple example argument groups: + # Export model to ONNX with one example argument group + # and test it with other example argument groups. + example_args_collection = ( + generate_example_inputs(2, 8), + generate_example_inputs(4, 7), + generate_example_inputs(9, 15), + ) + + if test_local_backend: + local_aot_ort, local_ort = make_aot_ort() + else: + local_aot_ort, local_ort = "onnxrt", None + + model = LlamaModelWrapper(config).eval() + + self._test_model_numerically( + model, + local_aot_ort, + example_args_collection, + fullgraph=True, + test_backward=test_backward, + atol=1e-4, + rtol=1e-4, + ) + + if test_local_backend: + assert local_ort is not None + number_of_captured_graphs = 2 if test_backward else 1 + execution_count = len(example_args_collection) * number_of_captured_graphs + self._assert_counting_information( + local_ort, + expected_execution_count=execution_count, + number_of_cached_graph_modules=number_of_captured_graphs, + number_of_exported_onnx_models_for_all_graph_modules=(1,) + * number_of_captured_graphs, + ) + self._assert_dynamic_input_and_output_shapes_in_all_onnx_models(local_ort) + + @parameterized.expand( + [ + (True,), + (False,), + ] + ) + def test_dump_model(self, test_local_backend: bool): + @contextlib.contextmanager + def onnxrt_dump_path(path): + key = "ONNXRT_DUMP_PATH" + before = os.environ.get(key, None) + os.environ[key] = path + yield + if before is None: + del os.environ[key] + else: + os.environ[key] = before + + example_args_collection = tuple( + (torch.randn(batch, 2, dtype=torch.float32),) for batch in (1, 2, 4, 6, 8) + ) + + class MLP(nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = nn.Linear(2, 4, bias=True) + self.fc2 = nn.Linear(4, 2, bias=True) + + def forward(self, tensor_x: torch.Tensor): + tensor_x = self.fc1(tensor_x) + tensor_x = torch.sigmoid(tensor_x) + tensor_x = self.fc2(tensor_x) + tensor_x = torch.sigmoid(tensor_x) + return tensor_x + + if test_local_backend: + local_aot_ort, _ = make_aot_ort() + else: + local_aot_ort, _ = "onnxrt", None + + prefix = f"test_dump_model_{'local' if test_local_backend else 'onnxrt'}_" + expected = f"{prefix}0.onnx" + expected_graph = f"{prefix}0.txt" + if os.path.exists(expected): + os.remove(expected) + if os.path.exists(expected_graph): + os.remove(expected_graph) + not_expected = f"{prefix}1.onnx" + self.assertFalse(os.path.exists(not_expected)) + + model = MLP() + compiled_model = torch.compile( + model if not isinstance(model, torch.nn.Module) else copy.deepcopy(model), + backend=local_aot_ort, + dynamic=True, + ) + + self.assertFalse(os.path.exists(expected)) + self.assertFalse(os.path.exists(not_expected)) + + with onnxrt_dump_path(prefix): + example_args = example_args_collection[0] + compiled_model(*example_args) + self.assertTrue(os.path.exists(expected)) + self.assertTrue(os.path.exists(expected_graph)) + self.assertFalse(os.path.exists(not_expected)) + + compiled_model(*example_args) + self.assertTrue(os.path.exists(expected)) + self.assertFalse(os.path.exists(not_expected)) + + @unittest.skipIf(not torch.cuda.is_available(), "No CUDA to run mix devicei nputs") + def test_mix_device_inputs(self): + data = torch.randn(4, 8, device="cuda") + ref_data = torch.randn(8, 4, device="cpu") + + def reshape_wrapper(data, ref_cpu_data): + # Dummy line to make sure ref_cpu_data + # is included in the captured graph. + ref_cpu_data += 1 + shape = ref_cpu_data.shape + # A call with GPU and CPU inputs. + return torch.reshape(data, shape) + + compiled_model = torch.compile( + reshape_wrapper, + backend="onnxrt", + dynamic=True, + ) + + result = compiled_model(data, ref_data) + + self.assertTrue(torch.allclose(result, data.view(ref_data.shape))) + + def test_no_input(self): + def reshape_wrapper(): + # A model without input. + ones = torch.ones(4, 8) + zeros = torch.zeros(4, 8) + return ones + zeros + + recorded_models = [] + + def record_onnx_model_transform(onnx_model): + # Record the ONNX model seen by the transform. + recorded_models.append(onnx_model) + + compiled_model = torch.compile( + reshape_wrapper, + backend="onnxrt", + dynamic=True, + options=torch.onnx._OrtBackendOptions( + pre_ort_model_transforms=[ + record_onnx_model_transform, + ] + ), + ) + + result = compiled_model() + + self.assertEqual(len(recorded_models), 1) + # NOTE: Constant folded by optimizer + self.assertTrue( + "Constant" in [node.op_type for node in recorded_models[0].graph.node] + ) + + self.assertEqual(result, torch.ones(4, 8)) + + def test_custom_onnx_transform(self): + # This test consists of 2 parts: + # 1. If a registered ONNX transform is called and recorded a model. + # 2. If a registered ONNX transform is called and changed the model + + # Part 1: Record the ONNX model seen by the transform. + # This list contains the models recorded by record_onnx_model_transform. + recorded_models = [] + + def record_onnx_model_transform(onnx_model): + # Record the ONNX model seen by the transform. + recorded_models.append(onnx_model) + + def example_model(x: torch.Tensor): + y = torch.sigmoid(x) + z = x + y + return z + + compiled_model = torch.compile( + example_model, + backend="onnxrt", + dynamic=True, + options=torch.onnx._OrtBackendOptions( + pre_ort_model_transforms=[record_onnx_model_transform] + ), + ) + + x = torch.randn(2) + assert len(recorded_models) == 0 + y = compiled_model(x) + assert len(recorded_models) == 1 + + # Part 2: Change the ONNX model seen by the transform so that + # ORT receives a different model. + # NOTE: the function is optimized away by optimizer + def replace_relu_with_sigmoid(onnx_model): + for node in onnx_model.graph.node: + if node.op_type == "Relu": + node.op_type = "Sigmoid" + + def another_example_model(x: torch.Tensor): + y = torch.relu(x) + z = x + y + return z + + another_compiled = torch.compile( + another_example_model, + backend="onnxrt", + dynamic=True, + options=torch.onnx._OrtBackendOptions( + pre_ort_model_transforms=[ + replace_relu_with_sigmoid, + record_onnx_model_transform, + ] + ), + ) + + another_y = another_compiled(x) + # We have 2 models recorded `record_onnx_model_transform` + # by the 2 torch.compile calls above. + assert len(recorded_models) == 2 + # Since we have changed "Relu" to "Sigmoid" in replace_sigmoid_with_relu, + # the result should be the same to previous y. + torch.testing.assert_close(y, another_y) + # another_example_model still uses "Relu", so the result should be different + # than y. + self.assertFalse(torch.allclose(y, another_example_model(x))) + + +if __name__ == "__main__": + common_utils.run_tests() diff --git a/test/onnx/exporter/test_api.py b/test/onnx/exporter/test_api.py index 24a9176bbe5bc..9fd3c123b5461 100644 --- a/test/onnx/exporter/test_api.py +++ b/test/onnx/exporter/test_api.py @@ -4,12 +4,22 @@ from __future__ import annotations import io +<<<<<<< HEAD import logging import os from onnxscript import BOOL, FLOAT, opset18 as op import torch +======= +import os + +import numpy as np +from onnxscript import BOOL, FLOAT, ir, opset18 as op + +import torch +import torch.onnx._flags +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.onnx._internal.exporter import _testing as onnx_testing from torch.testing._internal import common_utils @@ -28,11 +38,14 @@ def forward(self, x, b): return (y, z) +<<<<<<< HEAD class SampleModelReduction(torch.nn.Module): def forward(self, x): return x.sum() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class SampleModelForDynamicShapes(torch.nn.Module): def forward(self, x, b): return x.relu(), b.sigmoid() @@ -70,7 +83,10 @@ def assert_export( ) assert onnx_program is not None onnx_testing.assert_onnx_program(onnx_program, strategy=strategy) +<<<<<<< HEAD return onnx_program +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_args_normalization_with_no_kwargs(self): self.assert_export( @@ -78,6 +94,7 @@ def test_args_normalization_with_no_kwargs(self): (torch.randn(1, 1, 2), torch.randn(1, 1, 2)), ) +<<<<<<< HEAD def test_lower_opset_support(self): # First test that opset 18 (torchlib opset works) onnx_program = self.assert_export( @@ -152,6 +169,8 @@ def forward(self, a, x): self.assertEqual(len(onnx_program.model.graph.inputs), 1) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_dynamic_axes_enable_dynamic_shapes_with_fully_specified_axes(self): self.assert_export( SampleModelForDynamicShapes(), @@ -250,6 +269,38 @@ def test_partial_dynamic_shapes(self): }, ) +<<<<<<< HEAD +======= + def test_auto_convert_all_axes_to_dynamic_shapes_with_dynamo_export(self): + torch.onnx._flags.USE_EXPERIMENTAL_LOGIC = True + + class Nested(torch.nn.Module): + def forward(self, x): + (a0, a1), (b0, b1), (c0, c1, c2) = x + return a0 + a1 + b0 + b1 + c0 + c1 + c2 + + inputs = ( + (1, 2), + ( + torch.randn(4, 4), + torch.randn(4, 4), + ), + ( + torch.randn(4, 4), + torch.randn(4, 4), + torch.randn(4, 4), + ), + ) + + onnx_program = torch.onnx.dynamo_export( + Nested(), + inputs, + export_options=torch.onnx.ExportOptions(dynamic_shapes=True), + ) + assert onnx_program is not None + onnx_testing.assert_onnx_program(onnx_program) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_dynamic_shapes_supports_nested_input_model_with_input_names_assigned(self): # kwargs can still be renamed as long as it's in order input_names = ["input_x", "input_y", "input_z", "d", "e", "f"] @@ -297,11 +348,18 @@ def forward(self, input): # Use GELU activation function return torch.nn.functional.gelu(input, approximate="tanh") +<<<<<<< HEAD input = (torch.randn(1, 3, 4, 4),) onnx_program_op18 = torch.onnx.export( GeluModel(), input, opset_version=18, +======= + input = torch.randn(1, 3, 4, 4) + onnx_program_op18 = torch.onnx.export( + GeluModel(), + input, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dynamo=True, ) all_nodes_op18 = [n.op_type for n in onnx_program_op18.model.graph] @@ -355,6 +413,7 @@ def test_export_successful_when_dynamic_dimension_is_one(self): ), ) +<<<<<<< HEAD def test_is_in_onnx_export(self): class Mod(torch.nn.Module): def forward(self, x): @@ -396,6 +455,8 @@ def forward(self, x): ) onnx_testing.assert_onnx_program(onnx_program) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestCustomTranslationTable(common_utils.TestCase): def test_custom_translation_table_overrides_ops(self): @@ -528,5 +589,138 @@ def onnx_add(self: FLOAT, other: FLOAT) -> FLOAT: self.assertNotIn("Sub", all_nodes_decomp) +<<<<<<< HEAD +======= +class TestFakeTensorExport(common_utils.TestCase): + """Test exporting in fake mode.""" + + def test_onnx_program_raises_when_model_defined_in_fake_mode(self): + with torch.onnx.enable_fake_mode(): + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.nn.Parameter(torch.tensor(42.0)) + + def forward(self, x): + return self.weight + x + + onnx_program = torch.onnx.export( + Model(), (torch.tensor(1.0),), dynamo=True, optimize=False + ) + assert onnx_program is not None + # Convert to model proto and back to trigger to_bytes method which serializes the tensor + with self.assertRaises(Exception): + # The tensors need to be replaced with real tensors + _ = onnx_program.model_proto + + # Convert to model proto and back to trigger to_bytes method which serializes the tensor + with self.assertRaises(Exception): + # It doesn't matter if it is called inside or outside of the enable_fake_mode() context + _ = onnx_program.model_proto + + # If we replace with concrete tensors, the serialization will succeed. + # This needs to happen outside of the fake context + onnx_program.apply_weights({"weight": torch.tensor(42.0)}) + onnx_model = ir.serde.deserialize_model(onnx_program.model_proto) + np.testing.assert_allclose( + onnx_model.graph.initializers["weight"].const_value.numpy(), 42.0 + ) + + def test_onnx_program_save_raises_when_model_initialized_in_fake_mode(self): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.nn.Parameter(torch.tensor(42.0)) + + def forward(self, x): + return self.weight + x + + with torch.onnx.enable_fake_mode(): + onnx_program = torch.onnx.export( + Model(), (torch.tensor(1.0),), dynamo=True, optimize=False + ) + assert onnx_program is not None + # Convert to model proto and back to trigger to_bytes method which serializes the tensor + with self.assertRaises(Exception): + # The tensors need to be replaced with real tensors + _ = onnx_program.model_proto + + with self.assertRaises(Exception): + # It doesn't matter if it is called inside or outside of the enable_fake_mode() context + _ = onnx_program.model_proto + + # If we replace with concrete tensors, the serialization will succeed + # This needs to happen outside of the fake context + onnx_program.apply_weights({"weight": torch.tensor(42.0)}) + onnx_model = ir.serde.deserialize_model(onnx_program.model_proto) + np.testing.assert_allclose( + onnx_model.graph.initializers["weight"].const_value.numpy(), 42.0 + ) + + def test_onnx_program_save_succeeds_when_export_and_save_in_fake_mode(self): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.nn.Parameter(torch.tensor(42.0)) + + def forward(self, x): + return self.weight + x + + real_model = Model() + + with torch.onnx.enable_fake_mode(): + onnx_program = torch.onnx.export( + real_model, (torch.tensor(1.0),), dynamo=True, optimize=False + ) + + assert onnx_program is not None + # Convert to model proto and back to trigger to_bytes method which serializes the tensor + # Note that even though we are calling .model_proto (equivalently .save()) in fake mode, + # the concrete tensors are maintained. + # This is due to the usage of torch._subclasses.fake_tensor.unset_fake_temporarily() in + # TorchTensor.tobytes() + onnx_model = ir.serde.deserialize_model(onnx_program.model_proto) + np.testing.assert_allclose( + onnx_model.graph.initializers["weight"].const_value.numpy(), 42.0 + ) + + # This works inside or outside the fake mode + onnx_model = ir.serde.deserialize_model(onnx_program.model_proto) + np.testing.assert_allclose( + onnx_model.graph.initializers["weight"].const_value.numpy(), 42.0 + ) + + def test_is_in_onnx_export(self): + class Mod(torch.nn.Module): + def forward(self, x): + def f(x): + return x.sin() if torch.onnx.is_in_onnx_export() else x.cos() + + return f(x) + + self.assertFalse(torch.onnx.is_in_onnx_export()) + onnx_program = torch.onnx.export( + Mod(), + (torch.randn(3, 4),), + dynamo=True, + fallback=False, + ) + self.assertFalse(torch.onnx.is_in_onnx_export()) + + node_names = [n.op_type for n in onnx_program.model.graph] + self.assertIn("Sin", node_names) + + def test_torchscript_exporter_raises_deprecation_warning(self): + # Test that the deprecation warning is raised when using torchscript exporter + with self.assertWarnsRegex( + DeprecationWarning, "You are using the legacy TorchScript-based ONNX export" + ): + torch.onnx.export( + SampleModel(), (torch.randn(1, 1, 2),), io.BytesIO(), dynamo=False + ) + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": common_utils.run_tests() diff --git a/test/onnx/exporter/test_dynamic_shapes.py b/test/onnx/exporter/test_dynamic_shapes.py index 42a08e5647bdb..3da3a2cf0ebac 100644 --- a/test/onnx/exporter/test_dynamic_shapes.py +++ b/test/onnx/exporter/test_dynamic_shapes.py @@ -199,7 +199,10 @@ def test_dynamic_shapes_supports_nested_input_model_with_input_names_assigned( filename, dynamic_axes=dynamic_axes, input_names=input_names, +<<<<<<< HEAD dynamo=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) onnx_model = onnx.load(filename) diff --git a/test/onnx/exporter/test_small_models_e2e.py b/test/onnx/exporter/test_small_models_e2e.py index c5dd4132f5763..a038c3d8df03f 100644 --- a/test/onnx/exporter/test_small_models_e2e.py +++ b/test/onnx/exporter/test_small_models_e2e.py @@ -5,6 +5,11 @@ import logging +<<<<<<< HEAD +======= +import onnx.reference as onnx_ref + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import onnxruntime import pytest import transformers @@ -18,7 +23,11 @@ def has_onnxruntime_opset_23() -> bool: +<<<<<<< HEAD return version.parse(onnxruntime.__version__) >= version.parse("1.23") +======= + return version.parse(onnxruntime.__version__) >= version.parse("1.22") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class _WithExport: @@ -732,7 +741,10 @@ def forward(self, x): [node.op_type for node in onnx_program.model.graph], ) +<<<<<<< HEAD def test_attention_opset_23(self): +======= + def test_graph_attention_opset_23(self): class Model(torch.nn.Module): def forward(self, query, key, value): return torch.nn.functional.scaled_dot_product_attention( @@ -742,7 +754,30 @@ def forward(self, query, key, value): query = torch.rand(32, 8, 128, 64, dtype=torch.float16) key = torch.rand(32, 8, 128, 64, dtype=torch.float16) value = torch.rand(32, 8, 128, 64, dtype=torch.float16) + expected = Model()(query, key, value) + onnx_program = self.export(Model(), (query, key, value), opset_version=23) + self.assertIn("Attention", [node.op_type for node in onnx_program.model.graph]) + + ref = onnx_ref.ReferenceEvaluator(onnx_program.model_proto) + got = ref.run( + None, dict(query=query.numpy(), key=key.numpy(), value=value.numpy()) + )[0] + torch.testing.assert_close(torch.from_numpy(got), expected, atol=1e-2, rtol=1) + + def test_graph_accuracy_attention_opset_23(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) + class Model(torch.nn.Module): + def forward(self, query, key, value): + return torch.nn.functional.scaled_dot_product_attention( + query, key, value + ) + + query = torch.rand(32, 8, 128, 64, dtype=torch.float16) + key = torch.rand(32, 8, 128, 64, dtype=torch.float16) + value = torch.rand(32, 8, 128, 64, dtype=torch.float16) + +<<<<<<< HEAD onnx_program = self.export(Model(), (query, key, value), opset_version=23) self.assertEqual(["Attention"], [n.op_type for n in onnx_program.model.graph]) @@ -805,6 +840,15 @@ def forward(self, x): # Test with reference evaluator because ORT does not support the op as of version 1.22 onnx_testing.assert_onnx_program(onnx_program, backend="reference") +======= + onnx_program = self.export( + Model(), (query, key, value), opset_version=23, optimize=True + ) + self.assertEqual(["Attention"], [n.op_type for n in onnx_program.model.graph]) + # onnxruntime inlines any op defined as a function and without any implemented kernel + if has_onnxruntime_opset_23(): + onnx_testing.assert_onnx_program(onnx_program, atol=1e-2, rtol=1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": diff --git a/test/onnx/internal/test_registraion.py b/test/onnx/internal/test_registraion.py index fcc4cdeedd92f..9a4d1ff630b78 100644 --- a/test/onnx/internal/test_registraion.py +++ b/test/onnx/internal/test_registraion.py @@ -4,7 +4,11 @@ from collections.abc import Sequence from torch.onnx import errors +<<<<<<< HEAD from torch.onnx._internal.torchscript_exporter import registration +======= +from torch.onnx._internal import registration +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal import common_utils diff --git a/test/onnx/onnx_test_common.py b/test/onnx/onnx_test_common.py index ab2bfb51bdea4..21fa3ad2f819a 100644 --- a/test/onnx/onnx_test_common.py +++ b/test/onnx/onnx_test_common.py @@ -17,8 +17,12 @@ import torch from torch import export as torch_export +<<<<<<< HEAD from torch.onnx import _constants from torch.onnx._internal.torchscript_exporter import verification +======= +from torch.onnx import _constants, verification +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal import common_utils from torch.testing._internal.opinfo import core as opinfo_core from torch.types import Number @@ -66,6 +70,39 @@ def run_model_test(test_suite: _TestONNXRuntime, *args, **kwargs): return verification.verify(*args, options=options, **kwargs) +<<<<<<< HEAD +======= +def assert_dynamic_shapes(onnx_program: torch.onnx.ONNXProgram, dynamic_shapes: bool): + """Assert whether the exported model has dynamic shapes or not. + + Args: + onnx_program (torch.onnx.ONNXProgram): The output of torch.onnx.dynamo_export. + dynamic_shapes (bool): Whether the exported model has dynamic shapes or not. + When True, raises if graph inputs don't have at least one dynamic dimension + When False, raises if graph inputs have at least one dynamic dimension. + + Raises: + AssertionError: If the exported model has dynamic shapes and dynamic_shapes is False and vice-versa. + """ + + if dynamic_shapes is None: + return + + model_proto = onnx_program.model_proto + # Process graph inputs + dynamic_inputs = [] + for inp in model_proto.graph.input: + dynamic_inputs += [ + dim + for dim in inp.type.tensor_type.shape.dim + if dim.dim_value == 0 and dim.dim_param != "" + ] + assert dynamic_shapes == (len(dynamic_inputs) > 0), ( + "Dynamic shape check failed for graph inputs" + ) + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def parameterize_class_name(cls: type, idx: int, input_dicts: Mapping[Any, Any]): """Combine class name with the parameterized arguments. diff --git a/test/onnx/ops/test_ops.py b/test/onnx/ops/test_ops.py index 3736b930900f1..92c4c4706b3b7 100644 --- a/test/onnx/ops/test_ops.py +++ b/test/onnx/ops/test_ops.py @@ -4,20 +4,29 @@ from __future__ import annotations import onnx_ir.passes.common as common_passes +<<<<<<< HEAD import onnxruntime from onnxscript import ir from packaging import version import torch from torch.onnx._internal.exporter import _testing as onnx_testing +======= +from onnxscript import ir + +import torch +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.onnx.ops import _impl, _symbolic_impl from torch.testing._internal import common_utils +<<<<<<< HEAD def has_onnxruntime_opset_23() -> bool: return version.parse(onnxruntime.__version__) >= version.parse("1.23") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class SchemaTest(common_utils.TestCase): def test_symbolic_has_correct_schema(self): torch.library.opcheck( @@ -439,7 +448,11 @@ def export(self, model, args=(), kwargs=None, **options) -> torch.onnx.ONNXProgr def test_onnx_ops_can_be_decomposed_to_aten(self): input_data = torch.rand(2, 3, 4, 8) +<<<<<<< HEAD position_ids_data = torch.randint(0, 50, (2, 4)).long() +======= + position_ids_data = torch.randint(0, 50, (2, 3)).long() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sin_cache_data = torch.rand(50, 4) cos_cache_data = torch.rand(50, 4) @@ -480,7 +493,11 @@ def forward( def test_rotary_embedding_opcheck(self): input_data = torch.rand(2, 3, 4, 8) +<<<<<<< HEAD position_ids_data = torch.randint(0, 50, (2, 4)).long() +======= + position_ids_data = torch.randint(0, 50, (2, 3)).long() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sin_cache_data = torch.rand(50, 4) cos_cache_data = torch.rand(50, 4) @@ -491,7 +508,11 @@ def test_rotary_embedding_opcheck(self): def test_rotary_embedding(self): input_data = torch.rand(2, 3, 4, 8) +<<<<<<< HEAD position_ids_data = torch.randint(0, 50, (2, 4)).long() +======= + position_ids_data = torch.randint(0, 50, (2, 3)).long() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sin_cache_data = torch.rand(50, 4) cos_cache_data = torch.rand(50, 4) @@ -532,6 +553,7 @@ def forward( ) self.assertEqual(onnx_program.model.opset_imports[""], 23) self.assertEqual("RotaryEmbedding", onnx_program.model.graph.node(0).op_type) +<<<<<<< HEAD if has_onnxruntime_opset_23(): onnx_testing.assert_onnx_program(onnx_program) else: @@ -575,6 +597,8 @@ def forward(self, input_data, cos_cache_data, sin_cache_data): else: # Test with reference evaluator because ORT does not support the op as of version 1.22 onnx_testing.assert_onnx_program(onnx_program, backend="reference") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_attention_basic(self): """Test basic attention functionality.""" diff --git a/test/onnx/test_autograd_funs.py b/test/onnx/test_autograd_funs.py index 81c70d7d98777..30342b3ad01b3 100644 --- a/test/onnx/test_autograd_funs.py +++ b/test/onnx/test_autograd_funs.py @@ -5,11 +5,20 @@ import torch from torch.onnx import OperatorExportTypes +<<<<<<< HEAD +======= +from torch.onnx._globals import GLOBALS +from torch.onnx.utils import _model_to_graph +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal import common_utils class TestAutogradFuns(pytorch_test_common.ExportTestCase): +<<<<<<< HEAD opset_version = 20 +======= + opset_version = GLOBALS.export_onnx_opset_version +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) keep_initializers_as_inputs = False onnx_shape_inference = True @@ -131,7 +140,11 @@ def forward(self, input): input = torch.ones(1, 5) # Test ONNX_FALLTHROUGH_MODE +<<<<<<< HEAD graph, _, _ = torch.onnx.utils._model_to_graph( +======= + graph, _, _ = _model_to_graph( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) model, (input,), operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH, @@ -140,7 +153,11 @@ def forward(self, input): self.assertEqual(next(iter).kind(), "prim::PythonOp") # Test ATEN_FALLBACK_MODE +<<<<<<< HEAD graph, _, _ = torch.onnx.utils._model_to_graph( +======= + graph, _, _ = _model_to_graph( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) model, (input,), operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK, diff --git a/test/onnx/test_fx_passes.py b/test/onnx/test_fx_passes.py new file mode 100644 index 0000000000000..97d255abdcb14 --- /dev/null +++ b/test/onnx/test_fx_passes.py @@ -0,0 +1,60 @@ +# Owner(s): ["module: onnx"] +import torch +import torch._dynamo +import torch.fx +from torch.onnx._internal.fx.passes import _utils as pass_utils +from torch.testing._internal import common_utils + + +class TestFxPasses(common_utils.TestCase): + def test_set_node_name_correctly_renames_when_new_name_collides_recursively(self): + def func(x, y, z): + return x + y + z + + x = torch.randn(3) + y = torch.randn(3) + z = torch.randn(3) + gm, _ = torch._dynamo.export(func)(x, y, z) + torch._dynamo.reset() + + # Purposely name the nodes in a way that will cause a recursive collision later. + # See :func:`set_node_name` for name collision renaming logic. + base_name = "tensor" + nodes = list(gm.graph.nodes) + for i, node in enumerate(nodes[1:]): + if i == 0: + node.name = base_name + else: + node.name = f"{base_name}.{i}" + + # Run `set_node_name` and verify that the names are correct. + name_to_node = {node.name: node for node in gm.graph.nodes} + pass_utils.set_node_name(nodes[0], base_name, name_to_node) + assert nodes[0].name == base_name, f"Expected {base_name}, got {nodes[0].name}" + assert len({node.name for node in nodes}) == len(nodes), ( + f"Expected all names to be unique, got {nodes}" + ) + + def test_set_node_name_succeeds_when_no_name_collisions(self): + def func(x, y, z): + return x + y + z + + x = torch.randn(3) + y = torch.randn(3) + z = torch.randn(3) + gm, _ = torch._dynamo.export(func)(x, y, z) + torch._dynamo.reset() + + # Run `set_node_name` and verify that the names are correct. + new_name = "some_tensor" + nodes = list(gm.graph.nodes) + name_to_node = {node.name: node for node in nodes} + pass_utils.set_node_name(nodes[1], new_name, name_to_node) + assert nodes[1].name == new_name, f"Expected {new_name}, got {nodes[0].name}" + assert len({node.name for node in nodes}) == len(nodes), ( + f"Expected all names to be unique, got {nodes}" + ) + + +if __name__ == "__main__": + common_utils.run_tests() diff --git a/test/onnx/test_onnx_opset.py b/test/onnx/test_onnx_opset.py index 75de1f3fab83e..57308b6d9c5bc 100644 --- a/test/onnx/test_onnx_opset.py +++ b/test/onnx/test_onnx_opset.py @@ -11,7 +11,11 @@ import torch.onnx from torch.nn import Module from torch.onnx import producer_name, producer_version +<<<<<<< HEAD from torch.onnx._internal.torchscript_exporter._globals import GLOBALS +======= +from torch.onnx._globals import GLOBALS +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal import common_utils @@ -67,7 +71,10 @@ def check_onnx_opsets_operator( training=training, input_names=input_names, dynamic_axes=dynamic_axes, +<<<<<<< HEAD dynamo=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) model = onnx.load(io.BytesIO(f.getvalue())) check_onnx_opset_operator(model, ops[opset_version], opset_version) diff --git a/test/onnx/test_onnxscript_no_runtime.py b/test/onnx/test_onnxscript_no_runtime.py index e47c88b4c4406..3cc7f4cb190ed 100644 --- a/test/onnx/test_onnxscript_no_runtime.py +++ b/test/onnx/test_onnxscript_no_runtime.py @@ -10,7 +10,11 @@ from onnxscript.onnx_types import FLOAT import torch +<<<<<<< HEAD from torch.onnx._internal.torchscript_exporter import jit_utils +======= +from torch.onnx._internal import jit_utils +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal import common_utils @@ -86,20 +90,28 @@ def custom_layer_norm( x = torch.randn(1, 2, 3, 4, requires_grad=True) model_selu = torch.nn.SELU() selu_onnx = io.BytesIO() +<<<<<<< HEAD torch.onnx.export( model_selu, x, selu_onnx, opset_version=self.opset_version, dynamo=False ) +======= + torch.onnx.export(model_selu, x, selu_onnx, opset_version=self.opset_version) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) N, C = 3, 4 y = torch.randn(N, C) model_layer_norm = torch.nn.LayerNorm(C) layer_norm_onnx = io.BytesIO() torch.onnx.export( +<<<<<<< HEAD model_layer_norm, y, layer_norm_onnx, opset_version=self.opset_version, dynamo=False, +======= + model_layer_norm, y, layer_norm_onnx, opset_version=self.opset_version +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # 4. test on models @@ -162,11 +174,15 @@ def custom_selu(g, X): saved_model = io.BytesIO() torch.onnx.export( +<<<<<<< HEAD torch.jit.script(model), inputs, f=saved_model, opset_version=15, dynamo=False, +======= + torch.jit.script(model), inputs, f=saved_model, opset_version=15 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) loop_selu_proto = onnx.load(io.BytesIO(saved_model.getvalue())) self.assertEqual(len(loop_selu_proto.functions), 1) diff --git a/test/onnx/test_onnxscript_runtime.py b/test/onnx/test_onnxscript_runtime.py index dc19971498d95..bffb186cafc9e 100644 --- a/test/onnx/test_onnxscript_runtime.py +++ b/test/onnx/test_onnxscript_runtime.py @@ -9,7 +9,11 @@ from onnxscript.onnx_types import FLOAT import torch +<<<<<<< HEAD from torch.onnx._internal.torchscript_exporter import jit_utils +======= +from torch.onnx._internal import jit_utils +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal import common_utils diff --git a/test/onnx/test_pytorch_jit_onnx.py b/test/onnx/test_pytorch_jit_onnx.py index bc3c64ab8679b..ab34b7f74ec69 100644 --- a/test/onnx/test_pytorch_jit_onnx.py +++ b/test/onnx/test_pytorch_jit_onnx.py @@ -4,11 +4,16 @@ from pytorch_test_common import skipIfNoCuda import torch +<<<<<<< HEAD from torch.onnx._internal.torchscript_exporter import verification from torch.onnx._internal.torchscript_exporter._globals import GLOBALS from torch.onnx._internal.torchscript_exporter.utils import ( _trigger_symbolic_function_registration, ) +======= +from torch.onnx import verification +from torch.onnx._globals import GLOBALS +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal import common_utils @@ -23,7 +28,10 @@ def _jit_graph_to_onnx_model(graph, operator_export_type, opset_version): """ GLOBALS.export_onnx_opset_version = opset_version +<<<<<<< HEAD _trigger_symbolic_function_registration() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) graph = torch.onnx.utils._optimize_graph( graph, operator_export_type, params_dict={} ) diff --git a/test/onnx/test_pytorch_onnx_no_runtime.py b/test/onnx/test_pytorch_onnx_no_runtime.py new file mode 100644 index 0000000000000..b3a3aa01cf3c0 --- /dev/null +++ b/test/onnx/test_pytorch_onnx_no_runtime.py @@ -0,0 +1,1226 @@ +# Owner(s): ["module: onnx"] + +"""Tests for onnx export that don't run the exported model.""" + +from __future__ import annotations + +import contextlib +import io +import itertools +import unittest +import unittest.mock +import warnings +from typing import Callable, Optional, TYPE_CHECKING, Union + +import numpy as np + +import onnx +import onnx.numpy_helper +import pytorch_test_common + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.onnx import symbolic_helper, utils +from torch.onnx._internal import registration +from torch.testing._internal import common_quantization, common_utils, jit_utils + + +if TYPE_CHECKING: + from collections.abc import Iterable + + +def export_to_onnx( + model: Union[torch.nn.Module, torch.jit.ScriptFunction], + input: Union[torch.Tensor, tuple[torch.Tensor]], + custom_ops: Optional[ + Iterable[Union[contextlib.AbstractContextManager, contextlib.ContextDecorator]] + ] = None, + mocks: Optional[Iterable] = None, + operator_export_type: torch.onnx.OperatorExportTypes = torch.onnx.OperatorExportTypes.ONNX, + opset_version: int = 17, + **torch_onnx_export_kwargs, +) -> onnx.ModelProto: + """Exports `model(input)` to ONNX and returns it. + + Custom operators and/or unittest patches can be used help reproducing specific behaviors. + + Args: + model: model to export + input: model input with same format as `torch.onnx.export(..,args,...)` + custom_ops: list of custom operators to use during export + mocks: list of mocks to use during export + operator_export_type: export type as described by `torch.onnx.export(...operator_export_type,...)` + opset_version: ONNX opset version as described by `torch.onnx.export(...opset_version,...)` + torch_onnx_export_kwargs: extra torch.onnx.export kwargs arguments + Returns: + A valid ONNX model (`onnx.ModelProto`) + """ + custom_ops = custom_ops or [] + mocks = mocks or [] + with contextlib.ExitStack() as stack: + for ctx in itertools.chain(custom_ops, mocks): + stack.enter_context(ctx) + + f = io.BytesIO() + torch.onnx.export( + model, + input, + f, + operator_export_type=operator_export_type, + opset_version=opset_version, + **torch_onnx_export_kwargs, + ) + + # Validate ONNX graph before returning it + onnx_model = onnx.load_from_string(f.getvalue()) + onnx.checker.check_model(onnx_model) + return onnx_model + + +@common_utils.instantiate_parametrized_tests +class TestONNXExport(pytorch_test_common.ExportTestCase): + def test_fuse_addmm(self): + class AddmmModel(torch.nn.Module): + def forward(self, x): + return torch.mm(x, x) + x + + x = torch.ones(3, 3) + f = io.BytesIO() + torch.onnx.export(AddmmModel(), x, f) + + def test_onnx_transpose_incomplete_tensor_type(self): + # Smoke test to get us into the state where we are attempting to export + # a transpose op, where the input is a TensorType without size information. + # This would previously not work, since we would + # take the size of the input and use the length of its sizes as the + # number of dimensions in the permutation. + class Foo(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, x): + return x.contiguous().transpose(0, 1).sum() + + class TraceMe(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.foo = Foo() + + def forward(self, x): + return self.foo(x) + + tm = TraceMe() + tm = torch.jit.trace(tm, torch.rand(3, 4)) + f = io.BytesIO() + torch.onnx.export(tm, (torch.rand(3, 4),), f) + + def test_export_tensoroption_to(self): + def foo(x): + return x[0].detach().clone().cpu() + x + + traced = torch.jit.trace(foo, (torch.rand([2]))) + + f = io.BytesIO() + torch.onnx.export(traced, (torch.rand([2]),), f) + + def test_onnx_export_script_module(self): + class ModuleToExport(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, x): + y = x - x # noqa: F841 + return x + x + + mte = ModuleToExport() + f = io.BytesIO() + torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f) + + @common_utils.suppress_warnings + def test_onnx_export_func_with_warnings(self): + @torch.jit.script + def func_with_warning(inp): + return torch.nn.functional.sigmoid(inp) # triggers a deprecation warning + + class WarningTest(torch.nn.Module): + def forward(self, x): + return func_with_warning(x) + + # no exception + f = io.BytesIO() + torch.onnx.export(WarningTest(), torch.randn(42), f) + + def test_onnx_export_script_python_fail(self): + class PythonModule(torch.jit.ScriptModule): + @torch.jit.ignore + def forward(self, x): + return torch.neg(x) + + class ModuleToExport(torch.jit.ScriptModule): + def __init__(self) -> None: + super().__init__() + self.mod = PythonModule() + + @torch.jit.script_method + def forward(self, x): + y = self.mod(x) + return y + y + + mte = ModuleToExport() + f = io.BytesIO() + with self.assertRaisesRegex(RuntimeError, "Couldn't export Python"): + torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f) + + def test_onnx_export_script_inline_trace(self): + class ModuleToInline(torch.nn.Module): + def forward(self, x): + return torch.neg(x) + + class ModuleToExport(torch.jit.ScriptModule): + def __init__(self) -> None: + super().__init__() + self.mod = torch.jit.trace(ModuleToInline(), torch.zeros(1, 2, 3)) + + @torch.jit.script_method + def forward(self, x): + y = self.mod(x) + return y + y + + mte = ModuleToExport() + f = io.BytesIO() + torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f) + + def test_onnx_export_script_inline_script(self): + class ModuleToInline(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, x): + return torch.neg(x) + + class ModuleToExport(torch.jit.ScriptModule): + def __init__(self) -> None: + super().__init__() + self.mod = ModuleToInline() + + @torch.jit.script_method + def forward(self, x): + y = self.mod(x) + return y + y + + mte = ModuleToExport() + f = io.BytesIO() + torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f) + + def test_onnx_export_script_module_loop(self): + class ModuleToExport(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, x): + # test if we support end to end onnx export on loop and + # nested loops with and without loop index + for _ in range(5): + for i in range(3): + x = x + i + return x + + mte = ModuleToExport() + f = io.BytesIO() + torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f) + + @common_utils.suppress_warnings + def test_onnx_export_script_truediv(self): + class ModuleToExport(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, x): + z = x.size(0) / 2 + return x + z + + mte = ModuleToExport() + + f = io.BytesIO() + torch.onnx.export(mte, (torch.zeros(1, 2, 3, dtype=torch.float),), f) + + def test_onnx_export_script_non_alpha_add_sub(self): + class ModuleToExport(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, x): + bs = x.size(0) + 1 + return bs - 1 + + mte = ModuleToExport() + f = io.BytesIO() + torch.onnx.export(mte, (torch.rand(3, 4),), f) + + def test_onnx_export_script_module_if(self): + class ModuleToExport(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, x): + if bool(torch.sum(x) > 0): + x = torch.neg(x) + return x + + mte = ModuleToExport() + f = io.BytesIO() + torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f) + + def test_onnx_export_script_inline_params(self): + class ModuleToInline(torch.jit.ScriptModule): + def __init__(self) -> None: + super().__init__() + self.m = torch.nn.Parameter(torch.ones(3, 3)) + self.unused = torch.nn.Parameter(torch.ones(1, 2, 3)) + + @torch.jit.script_method + def forward(self, x): + return torch.mm(x, self.m) + + class ModuleToExport(torch.jit.ScriptModule): + def __init__(self) -> None: + super().__init__() + self.mod = ModuleToInline() + self.param = torch.nn.Parameter(torch.ones(3, 4)) + + @torch.jit.script_method + def forward(self, x): + y = self.mod(x) + return torch.mm(y, self.param) + + mte = ModuleToExport() + result = mte(torch.zeros(2, 3)) + reference = torch.mm( + torch.mm(torch.zeros(2, 3), torch.ones(3, 3)), torch.ones(3, 4) + ) + self.assertEqual(result, reference) + f = io.BytesIO() + torch.onnx.export(mte, (torch.ones(2, 3),), f) + + def test_onnx_export_speculate(self): + class Foo(torch.jit.ScriptModule): + def __init__(self, m): + super().__init__() + self.m = m + + @torch.jit.script_method + def forward(self, x): + x += x + # because we are testing if we emit `if` statement correctly + # we cannot use `True` as the condition. Constant prop + # would remove the `if` statements. + c = torch.sum(x) > 4 + if bool(c): + if bool(c): + y = self.m(x) + else: + y = self.m(x) + else: + y = self.m(x) + return y + + linear = torch.jit.trace( + torch.nn.Linear(10, 20).float(), torch.zeros(1, 10, dtype=torch.float) + ) + + @torch.jit.script + def transpose(x): + return x.t() + + f1 = Foo(transpose) + f2 = Foo(linear) + + f = io.BytesIO() + torch.onnx.export(f1, (torch.ones(1, 10, dtype=torch.float),), f) + f = io.BytesIO() + torch.onnx.export(f2, (torch.ones(1, 10, dtype=torch.float),), f) + + def test_onnx_export_shape_reshape(self): + class Foo(torch.nn.Module): + def forward(self, x): + import torch.onnx.operators + + x = x.repeat(5, 1, 1) + shape = torch.onnx.operators.shape_as_tensor(x) + reshaped = torch.onnx.operators.reshape_from_tensor_shape(x, shape) + return reshaped + + foo = torch.jit.trace(Foo(), torch.zeros(1, 2, 3)) + f = io.BytesIO() + torch.onnx.export(foo, (torch.zeros(1, 2, 3)), f) + + def test_export_dynamic_slice(self): + class DynamicSliceExportMod(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, x): + retval = x[0] + for i in range(x.size(1)): + retval += torch.sum(x[0:i], dim=0) + return retval + + input = torch.rand(3, 4, 5) + + f = io.BytesIO() + torch.onnx.export(DynamicSliceExportMod(), (input,), f, opset_version=10) + + def test_export_dict(self): + class DictModule(torch.nn.Module): + def forward(self, x_in: torch.Tensor) -> dict[str, torch.Tensor]: + return {"test_key_out": x_in} + + x_in = torch.tensor(1) + mod = DictModule() + mod.train(False) + + f = io.BytesIO() + torch.onnx.export(mod, (x_in,), f) + + with self.assertRaisesRegex(RuntimeError, r"DictConstruct.+is not supported"): + f = io.BytesIO() + torch.onnx.export(torch.jit.script(mod), (x_in,), f) + + def test_source_range_propagation(self): + class ExpandingModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + # Will be expanded during ONNX export + self.ln = torch.nn.LayerNorm([1]) + + def forward(self, input): + return self.ln(input) + + mod = ExpandingModule() + + graph, _, _ = utils._model_to_graph( + mod, + (torch.zeros(1),), + operator_export_type=torch.onnx.OperatorExportTypes.ONNX, + ) + + # Ensure that every node in the graph has a valid source range + for node in graph.nodes(): + self.assertTrue(node.sourceRange()) + + def test_clip_aten_fallback_due_exception(self): + def bad_clamp(g, self, min, max): + return symbolic_helper._onnx_unsupported("Bad boy!") + + class MyClip(torch.nn.Module): + def forward(self, x): + return torch.clamp(x, min=-0.5, max=0.5) + + onnx_model = export_to_onnx( + MyClip(), + torch.randn(3, 4, requires_grad=True), + custom_ops=[common_utils.custom_op("aten::clamp", bad_clamp, 17)], + operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, + ) + self.assertAtenOp(onnx_model, "clamp", "Tensor") + + def test_clip_aten_fallback_explicit_request(self): + class MyClip(torch.nn.Module): + def forward(self, x): + return torch.clamp(x, min=-0.5, max=0.5) + + # Copy of mocked method must be saved to prevent + # max recursion depth while trying to run original instance method + original_get_function_group = registration.registry.get_function_group + + def break_is_registered_op_api(name): + fake_missing_symbolics = {"aten::clamp"} + if name in fake_missing_symbolics: + return None + return original_get_function_group(name) + + # Force missing symbolic for well-known op using a mock + onnx_model = export_to_onnx( + MyClip(), + torch.randn(3, 4, requires_grad=True), + mocks=[ + unittest.mock.patch( + "torch.onnx._internal.registration.registry.get_function_group", + side_effect=break_is_registered_op_api, + # wraps=registration.registry.get_function_group + ) + ], + operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, + ) + self.assertAtenOp(onnx_model, "clamp", "Tensor") + + def _helper_test_to_(self, cast_fn: Callable[[torch.Tensor], torch.Tensor]): + """Helper to test aten::to(device) variants. + + `cast_fn` is converted into a `torch.jit.script`. It wraps `aten::to` + during export to preventing the devices to be hard-coded. + + Needed by detectron2 after https://github.com/facebookresearch/detectron2/pull/4132/ + """ + cast_fn = torch.jit.script(cast_fn) + onnx_model = export_to_onnx(cast_fn, torch.zeros([1, 3, 32, 32])) + for n in onnx_model.graph.node: + self.assertNotEqual(n.op_type, "To") + self.assertNotEqual(n.op_type, "Cast") + + def test_to__cpu_string(self): + def cast_cpu_string(src: torch.Tensor) -> torch.Tensor: + return src.to("cpu") + + self._helper_test_to_(cast_cpu_string) + + def test_to__device_cpu_string(self): + def cast_device_cpu_string(src: torch.Tensor) -> torch.Tensor: + return src.to(device="cpu") + + self._helper_test_to_(cast_device_cpu_string) + + def test_script_custom_class_error(self): + class BoxCoder: + def __init__(self, bbox_xform_clip: float) -> None: + self.bbox_xform_clip = bbox_xform_clip + + def decode(self, rel_codes: Tensor, boxes: list[Tensor]) -> Tensor: + boxes = torch.cat(boxes, dim=0) + pred_ctr_x = ( + torch.clamp(rel_codes[:, 0::4], max=self.bbox_xform_clip) + * boxes[:, 2] + ) + return pred_ctr_x + + class MyModule(torch.nn.Module): + __annotations__ = { + "box_coder": BoxCoder, + } + + def __init__(self) -> None: + super().__init__() + self.box_coder = BoxCoder(1.4) + + def forward(self, box_regression: Tensor, proposals: list[Tensor]): + return self.box_coder.decode(box_regression, proposals) + + model = torch.jit.script(MyModule()) + box_regression = torch.randn([4, 4]) + proposal = [torch.randn(2, 4), torch.randn(2, 4)] + + with self.assertRaises(RuntimeError): + f = io.BytesIO() + torch.onnx.export( + model, + (box_regression, proposal), + f, + ) + + def test_initializer_sequence(self): + class MyModule(torch.nn.Module): + def __init__(self, input_size, hidden_size, num_classes): + super().__init__() + self.fc1 = torch.nn.Linear(input_size, hidden_size) + self.relu = torch.nn.ReLU() + self.fc2 = torch.nn.Linear(hidden_size, num_classes) + + def forward(self, x): + out = self.fc1(x) + out = self.relu(out) + out = self.fc2(out) + return out + + test_model = MyModule(3, 4, 10) + state_dict_list = [k for (k, v) in test_model.state_dict().items()] + named_params_list = [k for (k, v) in test_model.named_parameters()] + + x = torch.randn(32, 3) + f = io.BytesIO() + torch.onnx.export(test_model, (x,), f, do_constant_folding=False) + loaded_model = onnx.load_from_string(f.getvalue()) + + actual_list = [p.name for p in loaded_model.graph.initializer] + assert actual_list == state_dict_list, ( + "Initializers' sequence is not as same as state_dict(). Expected: (" + + ", ".join(state_dict_list) + + "). Actual:(" + + ", ".join(actual_list) + + ")." + ) + assert actual_list == named_params_list, ( + "Initializers' sequence is not as same as named_parameters(). Expected: (" + + ", ".join(named_params_list) + + "). Actual:(" + + ", ".join(actual_list) + + ")." + ) + + def test_initializer_sequence_script_model(self): + def list_is_expected(short_list, long_list) -> bool: + if len(short_list) > len(long_list): + return False + + for i in range(len(short_list)): + if short_list[i] not in long_list[i]: + return False + + return True + + def loop(x, y): + for i in range(int(y)): + x = x + i + return x + + class MyModule(torch.nn.Module): + def __init__(self, input_size, hidden_size, num_classes): + super().__init__() + self.fc1 = torch.nn.Linear(input_size, hidden_size) + self.relu = torch.nn.ReLU() + self.fc2 = torch.nn.Linear(hidden_size, num_classes) + + def forward(self, x, y): + x = loop(x, y) + out = self.fc1(x) + out = self.relu(out) + out = self.fc2(out) + return out + + test_model = torch.jit.script(MyModule(3, 4, 10)) + state_dict_list = [k for (k, v) in test_model.state_dict().items()] + named_params_list = [k for (k, v) in test_model.named_parameters()] + + x = torch.ones(2, 3, dtype=torch.float) + y = torch.tensor(5, dtype=torch.long) + f = io.BytesIO() + + torch.onnx.export(test_model, (x, y), f, do_constant_folding=False) + loaded_model = onnx.load_from_string(f.getvalue()) + + actual_list = [p.name for p in loaded_model.graph.initializer] + assert list_is_expected(state_dict_list, actual_list), ( + "ScriptModel - Initializers' sequence is not as same as state_dict(). Expected: (" + + ", ".join(state_dict_list) + + "). Actual:(" + + ", ".join(actual_list) + + ")." + ) + assert list_is_expected(named_params_list, actual_list), ( + "ScriptModel - Initializers' sequence is not as same as named_parameters(). Expected: (" + + ", ".join(named_params_list) + + "). Actual:(" + + ", ".join(actual_list) + + ")." + ) + + def test_shape_value_map(self): + class RSoftMax(torch.nn.Module): + def __init__(self, radix, cardinality): + super().__init__() + self.radix = radix + self.cardinality = cardinality + + def forward(self, x): + batch = x.size(0) + x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) + x = F.softmax(x, dim=1) + x = x.reshape(batch, -1) + return x + + radix = 2 + cardinality = 1 + x = torch.randn(10, 1, 128, 1) + f = io.BytesIO() + torch.onnx.export( + RSoftMax(radix, cardinality), + (x,), + f, + input_names=["x"], + dynamic_axes={"x": [0]}, + ) + loaded_model = onnx.load_from_string(f.getvalue()) + self.assertEqual( + loaded_model.graph.output[0].type.tensor_type.shape.dim[1].dim_value, 128 + ) + + def test_onnx_proto_checker(self): + class Model(torch.nn.Module): + def forward(self, x): + return 2 * x + + x = torch.randn(1, 2, 3, requires_grad=True) + f = io.BytesIO() + torch.onnx.export(Model(), (x,), f) + model = onnx.load(f) + model.ir_version = 0 + + def check_proto(): + torch._C._check_onnx_proto(model.SerializeToString()) + + self.assertRaises(RuntimeError, check_proto) + + def test_maintain_dynamic_shapes_of_unreliable_nodes(self): + def symbolic_pythonop(g, *args, **kwargs): + return g.op("com.microsoft::PythonOp") + + torch.onnx.register_custom_op_symbolic("prim::PythonOp", symbolic_pythonop, 1) + self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "prim::PythonOp", 1) + + # necessay parameters for transformer embeddings + hidden_size = 48 + max_position_embeddings = 32 + batch_size = 2 + + # issue found that autograd.function making downstream + # node unreliable but with static shape. The issue was first + # discovered with using Apex FusedLayerNorm in Transformers + class CustomLayerNorm(torch.autograd.Function): + @staticmethod + def forward(ctx, embedding): + layer_norm = torch.nn.LayerNorm(hidden_size, eps=1e-12) + return layer_norm(embedding) + + class EmbeddingModule(torch.nn.Module): + def forward( + self, + embeddings=None, + ): + embedding_output = CustomLayerNorm.apply(embeddings) + query = embedding_output.transpose(0, 1) + target_len, batch_size, embedding_dim = query.size() + # Reshape is used for consuming batch_size, and if it is static, + # this will be a Constant node in the graph + query = query.reshape(target_len, batch_size, embedding_dim) + return query + + embeddings = torch.randn(batch_size, max_position_embeddings, hidden_size) + + f = io.BytesIO() + torch.onnx.export( + EmbeddingModule().eval(), + (embeddings,), + f, + input_names=["embeddings"], + dynamic_axes={ + "embeddings": { + 0: "batch_size", + 1: "max_position_embeddings", + 2: "hidden_size", + } + }, + custom_opsets={"com.microsoft": 1}, + ) + model = onnx.load(io.BytesIO(f.getvalue())) + + # If there is a constant node with dim=3 and max_position_embeddings, + # batch_size, hidden_size as shape, it means the shape becomes static. + # Normally, with dynamic batch size, this constant node should not exist. + const_node = [n for n in model.graph.node if n.op_type == "Constant"] + self.assertNotEqual(len(const_node), 0) + for node in const_node: + for a in node.attribute: + if a.name == "value": + shape = onnx.numpy_helper.to_array(a.t) + self.assertNotEqual( + shape.tolist(), + [max_position_embeddings, batch_size, hidden_size], + ) + + def test_is_fp_for_C_TypeList(self): + class M(torch.nn.Module): + def forward(self, x): + x = x.squeeze(1) + w = x.shape[2] + pos = x.view(2, -1).argmax(1) + x_int = pos % w + y_int = (pos - x_int) // w + return y_int, x_int + + model = torch.jit.script(M()) + inputs = torch.randn(2, 4, 6) + f = io.BytesIO() + torch.onnx.export( + model, inputs, f, dynamic_axes={"x": [0, 1]}, input_names=["x"] + ) + + def test_dropout_script(self): + eg = torch.zeros(1, 2, 3, requires_grad=True) + + @jit_utils._trace(eg) + def foo(x): + x = torch.neg(x) + return F.dropout(x) + + class MyDrop(torch.nn.Module): + def forward(self, x): + return foo(x) + + f = io.BytesIO() + with warnings.catch_warnings(record=True): + torch.onnx.export(MyDrop(), (eg,), f) + + def test_pack_padded_pad_packed_trace(self): + from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence + + T, B, C = 3, 5, 7 + + class PadPackedWrapper(torch.nn.Module): + def forward(self, x, seq_lens): + x = pack_padded_sequence(x, seq_lens) + x, _ = pad_packed_sequence(x) + return x + + x = np.ones((T, B, C)) + seq_lens = np.array([3, 3, 2, 2, 1], dtype=np.int32) + # set padding value so we can test equivalence + for b in range(B): + if seq_lens[b] < T: + x[seq_lens[b] :, b, :] = 0 + seq_lens = torch.from_numpy(seq_lens) + x = torch.autograd.Variable(torch.from_numpy(x), requires_grad=True) + + m = PadPackedWrapper() + m_traced = torch.jit.trace( + m, + ( + x, + seq_lens, + ), + ) + + y = m(x, seq_lens) + loss = torch.sum(y) + loss.backward() + grad = x.grad.clone() + x.grad.zero_() + + y_traced = m_traced(x, seq_lens) + loss_traced = torch.sum(y_traced) + loss_traced.backward() + grad_traced = x.grad.clone() + + self.assertEqual(y_traced, x) + self.assertEqual(y_traced, y) + self.assertEqual(grad, grad_traced) + + f = io.BytesIO() + torch.onnx.export(m, (x, seq_lens), f) + + # Suppression: ONNX warns when exporting RNNs because of potential batch size mismatch. + @common_utils.suppress_warnings + def test_rnn_trace_override(self): + from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence + + num_layers = 3 + T, B, C = 11, 5, 7 + + class RNNTraceWrapper(torch.nn.Module): + def __init__(self, cell_type): + super().__init__() + if cell_type == "RNN": + self.rnn = torch.nn.RNN( + input_size=C, hidden_size=C, num_layers=num_layers + ) + elif cell_type == "LSTM": + self.rnn = torch.nn.LSTM( + input_size=C, hidden_size=C, num_layers=num_layers + ) + elif cell_type == "GRU": + self.rnn = torch.nn.GRU( + input_size=C, hidden_size=C, num_layers=num_layers + ) + + def forward(self, x, seq_lens): + x = pack_padded_sequence(x, seq_lens) + x, _ = self.rnn(x) + x, _ = pad_packed_sequence(x) + return x + + for cell_type in ["RNN", "LSTM", "GRU"]: + x = torch.ones(T, B, C, requires_grad=True) + seq_lens = torch.from_numpy(np.array([11, 3, 2, 2, 1], dtype=np.int32)) + + m = RNNTraceWrapper(cell_type) + m_traced = torch.jit.trace( + m, + ( + x, + seq_lens, + ), + ) + + y = m(x, seq_lens) + loss = torch.sum(y) + loss.backward() + grad = x.grad.clone() + x.grad.zero_() + + y_traced = m_traced(x, seq_lens) + loss_traced = torch.sum(y_traced) + loss_traced.backward() + grad_traced = x.grad.clone() + + self.assertEqual(y_traced, y) + self.assertEqual(grad, grad_traced) + + f = io.BytesIO() + torch.onnx.export(m, (x, seq_lens), f) + + def test_pushpackingpastrnn_in_peephole_create_own_gather_input(self): + from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence + + num_layers = 3 + T, B, C = 11, 5, 7 + mask_start_point = 0 + + class LSTMTraceWrapper(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + self.rnn = torch.nn.LSTM( + input_size=C, hidden_size=C, num_layers=num_layers + ) + + def forward(self, x, seq_lens): + mask = torch.arange(mask_start_point, x.shape[1]) + seq_lens = seq_lens[mask] + x = pack_padded_sequence(x, seq_lens) + # Calculate sizes and prepare views to our zero buffer to pass as hx + max_batch_size = x.batch_sizes[0] + hx = torch.randn(num_layers, max_batch_size, C) + cx = torch.randn(num_layers, max_batch_size, C) + x, _ = self.rnn(x, (hx, cx)) + x, _ = pad_packed_sequence(x) + return x + + x = torch.ones(T, B, C) + # length 5 because of B + seq_lens = torch.from_numpy(np.array([11, 3, 2, 2, 1], dtype=np.int32)) + m = LSTMTraceWrapper() + + f = io.BytesIO() + torch.onnx.export( + m, + (x, seq_lens), + f, + verbose=True, + input_names=["input", "seq_len"], + dynamic_axes={"input": {1: "B"}}, + ) + onnx_proto = onnx.load_model_from_string(f.getvalue()) + # the first argument in onnx::Range should be constant node with value 0 + const_node = [] + constant_input_name = None + for n in onnx_proto.graph.node: + if n.op_type == "Constant": + const_node.append(n) + elif n.op_type == "Range": + constant_input_name = n.input[0] + self.assertNotEqual(constant_input_name, None) + self.assertNotEqual(len(const_node), 0) + + value = None + for n in const_node: + if n.output[0] == constant_input_name: + value = np.frombuffer(n.attribute[0].t.raw_data, dtype=np.int64) + self.assertEqual(value, 0) + + def test_trace_fork_wait_inline_onnx(self): + def fork_body(x): + return torch.neg(x), torch.neg(x) + + class MyMod(torch.nn.Module): + def forward(self, x): + fut = torch.jit._fork(fork_body, x) + val = torch.jit._wait(fut) + return val[1] + + # smoke test for ONNX export + f = io.BytesIO() + torch.onnx.export(MyMod(), (torch.rand(3, 4),), f) + + def test_trace_detach_onnx_erase(self): + class Mod(torch.nn.Module): + def forward(self, x, w): + return torch.matmul(x, w).detach() + + f = io.BytesIO() + torch.onnx.export(Mod(), (torch.rand(3, 4), torch.rand(4, 5)), f) + + def test_aten_fallback_must_fallback(self): + class ModelWithAtenNotONNXOp(torch.nn.Module): + def forward(self, x, y): + abcd = x + y + defg = torch.linalg.qr(abcd) + return defg + + x = torch.rand(3, 4) + y = torch.rand(3, 4) + f = io.BytesIO() + torch.onnx.export( + ModelWithAtenNotONNXOp(), + (x, y), + f, + do_constant_folding=False, + operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, + # support for linalg.qr was added in later op set versions. + opset_version=9, + ) + onnx_model = onnx.load(io.BytesIO(f.getvalue())) + self.assertAtenOp(onnx_model, "linalg_qr") + + def test_onnx_aten(self): + class ModelWithAtenFmod(torch.nn.Module): + def forward(self, x, y): + return torch.fmod(x, y) + + x = torch.randn(3, 4, dtype=torch.float32) + y = torch.randn(3, 4, dtype=torch.float32) + f = io.BytesIO() + torch.onnx.export( + ModelWithAtenFmod(), + (x, y), + f, + do_constant_folding=False, + operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN, + ) + onnx_model = onnx.load(io.BytesIO(f.getvalue())) + self.assertAtenOp(onnx_model, "fmod", "Tensor") + + def test_onnx_aten_fallback_must_not_fallback(self): + # For BUILD_CAFFE2=0, aten fallback only when not exportable + class ONNXExportable(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.quant = torch.ao.quantization.QuantStub() + self.fc1 = torch.nn.Linear(12, 8) + self.fc2 = torch.nn.Linear(8, 4) + self.fc3 = torch.nn.Linear(4, 6) + self.dequant = torch.ao.quantization.DeQuantStub() + + def forward(self, x): + x = self.quant(x) + x = x.view((-1, 12)) + h = F.relu(self.fc1(x)) + h = F.relu(self.fc2(h)) + h = F.relu(self.fc3(h)) + h = self.dequant(h) + return h + + dummy_input = torch.randn(12) + f = io.BytesIO() + torch.onnx.export( + ONNXExportable(), + (dummy_input,), + f, + do_constant_folding=False, + operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, + ) + onnx_model = onnx.load(io.BytesIO(f.getvalue())) + all_aten_nodes = [ + p + for p in onnx_model.graph.node + if p.op_type == "ATen" and p.domain == "org.pytorch.aten" + ] + self.assertEqual(len(all_aten_nodes), 0) + + def test_cat_with_empty_tensor(self): + class NoopConcat(torch.nn.Module): + def forward(self, x): + return torch.cat((torch.Tensor([]), x)) + + x = torch.randn(4, 5, 6) + # TODO: Parametrize this test for opset_version + for opset_version in {9, 11}: + f = io.BytesIO() + torch.onnx.export(NoopConcat(), (x,), f, opset_version=opset_version) + loaded_model = onnx.load_from_string(f.getvalue()) + self.assertEqual( + len(loaded_model.graph.output[0].type.tensor_type.shape.dim), 3 + ) + for idx, dim in enumerate(x.shape): + self.assertEqual( + loaded_model.graph.output[0] + .type.tensor_type.shape.dim[idx] + .dim_value, + dim, + ) + + def test_col2im(self): + # This test can be moved to test/onnx/test_pytorch_onnx_onnxruntime.py when ORT implement ::Col2Im + + # Random batched RGB 32x32 image-shaped input tensor of batch size 64 + original_image_inputs = torch.randn((64, 3, 32, 32)) + output_size = tuple(original_image_inputs.shape[2:]) + kernel_size = (1, 2) + dilation = 3 + padding = 2 + stride = 1 + model_im2col = torch.nn.Unfold( + kernel_size, dilation=dilation, padding=padding, stride=stride + ) + blocks = model_im2col(original_image_inputs) + + model = torch.nn.Fold( + output_size=output_size, + kernel_size=kernel_size, + dilation=dilation, + padding=padding, + stride=stride, + ) + f = io.BytesIO() + torch.onnx.export(model, (blocks,), f, opset_version=18) + + onnx_model = onnx.load(io.BytesIO(f.getvalue())) + self.assertEqual(onnx_model.graph.node[-1].op_type, "Col2Im") + self.assertEqual(onnx_model.graph.node[-1].domain, "") + self.assertEqual(len(onnx_model.graph.node[-1].input), 3) + self.assertEqual(onnx_model.graph.node[-1].attribute[0].name, "dilations") + self.assertEqual(onnx_model.graph.node[-1].attribute[1].name, "pads") + self.assertEqual(onnx_model.graph.node[-1].attribute[2].name, "strides") + + @unittest.skipIf( + not torch.hub._check_module_exists("torch_scatter"), + "torch_scatter not installed.", + ) + def test_random_namespace_custom_op_is_onnx_exportable(self): + from torch_scatter import scatter_max # type: ignore[import] + + class MyModel(torch.nn.Module): + def forward(self, src: torch.Tensor, idx: torch.Tensor): + return scatter_max(src, idx) + + m = MyModel().eval() + src = torch.ones([3, 10], dtype=torch.float32) + idx = torch.randint(0, 4, [3, 10], dtype=torch.long) + + def sym_scatter_max(g, src, index, dim, out, dim_size): + return g.op( + "torch_scatter::scatter_max", src, index, dim_size_i=-1, outputs=2 + ) + + torch.onnx.register_custom_op_symbolic( + "torch_scatter::scatter_max", sym_scatter_max, 1 + ) + f = io.BytesIO() + with torch.no_grad(): + torch.onnx.export( + m, + (src, idx), + f, + opset_version=13, + custom_opsets={"torch_scatter": 1}, + do_constant_folding=True, + ) + + @common_utils.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) + def test_fp8_export(self, fp8_dtype: torch.dtype): + class Model(torch.nn.Module): + def forward(self, x): + return x.to(torch.float32) + + x = torch.randn(2, 3).to(fp8_dtype) + + f = io.BytesIO() + torch.onnx.export(Model(), x, f, opset_version=19) + onnx.checker.check_model(f.getvalue()) + + onnx_type = { + torch.float8_e4m3fn: 17, + torch.float8_e5m2: 19, + } # From https://github.com/onnx/onnx/blob/main/onnx/onnx.proto3#L512-L521 + loaded_model = onnx.load_from_string(f.getvalue()) + self.assertEqual( + loaded_model.graph.input[0].type.tensor_type.elem_type, onnx_type[fp8_dtype] + ) + + +class TestQuantizeEagerONNXExport(common_utils.TestCase): + def _test_lower_graph_impl(self, model, data): + model.qconfig = torch.ao.quantization.default_qconfig + model = torch.ao.quantization.prepare(model) + model = torch.ao.quantization.convert(model) + + _ = model(data) + input_names = ["x"] + + def _export_to_onnx(model, input, input_names): + traced = torch.jit.trace(model, input) + buf = io.BytesIO() + torch.jit.save(traced, buf) + buf.seek(0) + + model = torch.jit.load(buf) + f = io.BytesIO() + torch.onnx.export( + model, + input, + f, + input_names=input_names, + operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, + opset_version=9, + ) + + _export_to_onnx(model, data, input_names) + + @common_quantization.skipIfNoFBGEMM + @unittest.skip( + "onnx opset9 does not support quantize_per_tensor and caffe2 \ + does not support conv3d" + ) + def test_lower_graph_conv3d(self): + model = torch.ao.quantization.QuantWrapper( + torch.nn.Conv3d(3, 5, 2, bias=True) + ).to(dtype=torch.float) + data_numpy = np.random.rand(1, 3, 6, 6, 6).astype(np.float32) + data = torch.from_numpy(data_numpy).to(dtype=torch.float) + self._test_lower_graph_impl(model, data) + + @pytorch_test_common.skipIfNoCuda + def test_composed_layer_norm_small_eps_fp16_keep_double(self): + class Net(torch.nn.Module): + def __init__(self, C): + super().__init__() + self.layer_norm = torch.nn.LayerNorm(C, eps=1e-8) + + def forward(self, x): + return self.layer_norm(x) + + N, C = 8, 4 + model = Net(C).cuda().half() + x = torch.randn(N, C).cuda().half() + f = io.BytesIO() + torch.onnx.export(model, (x,), f, opset_version=14) + onnx_model = onnx.load_from_string(f.getvalue()) + const_node = [n for n in onnx_model.graph.node if n.op_type == "Constant"] + self.assertNotEqual(len(const_node), 0) + double_type_count = 0 + for node in const_node: + for a in node.attribute: + # EPS constant should be in double type + if a.name == "value" and a.t.data_type == 11: + double_type_count += 1 + self.assertNotEqual(double_type_count, 0) + + @pytorch_test_common.skipIfNoCuda + def test_aten_device_with_index(self): + from transformers import AutoModelForSeq2SeqLM, AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small") + model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small") + model = torch.compile(model, backend="onnxrt") + model = model.eval() + device = "cuda:0" + model = model.to(device) + ids = tokenizer.batch_encode_plus(["This is a test"], return_tensors="pt").to( + device + ) + + with torch.no_grad(): + _ = model( + input_ids=ids["input_ids"], + attention_mask=ids["attention_mask"], + decoder_input_ids=ids["input_ids"], + decoder_attention_mask=ids["attention_mask"], + ) + + def test_aten_linalg_vector_norm_with_reducel2(self): + class Net(torch.nn.Module): + def forward(self, x): + x = F.normalize(x) + return x + + f = io.BytesIO() + torch.onnx.export(Net(), (torch.randn(1, 2, 2),), f) + onnx_model = onnx.load_from_string(f.getvalue()) + onnx_nodes = [n.op_type for n in onnx_model.graph.node] + self.assertIn("ReduceL2", onnx_nodes) + + +if __name__ == "__main__": + common_utils.run_tests() diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 6fa49ed61b71b..93aeb4869c7e7 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -41,9 +41,13 @@ import torch from torch import Tensor from torch.nn.utils import rnn as rnn_utils +<<<<<<< HEAD from torch.onnx import errors from torch.onnx._internal.torchscript_exporter import verification from torch.onnx._internal.torchscript_exporter._type_utils import JitScalarType +======= +from torch.onnx import errors, verification +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal import common_utils from torch.testing._internal.common_utils import skipIfNoLapack @@ -897,11 +901,15 @@ def forward( # export succeeds, but running ORT through run_test would fail because the exported model # has the inputs flattened into 3 inputs. torch.onnx.export( +<<<<<<< HEAD model, (x, {"y": (y0, y1)}), io.BytesIO(), opset_version=self.opset_version, dynamo=False, +======= + model, (x, {"y": (y0, y1)}), io.BytesIO(), opset_version=self.opset_version +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def test_primitive_input_integer(self): @@ -10795,7 +10803,10 @@ def forward(self, x): opset_version=self.opset_version, do_constant_folding=False, training=torch.onnx.TrainingMode.TRAINING, +<<<<<<< HEAD dynamo=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) ort_sess = verification._ort_session(model_onnx) ort_outs = verification._run_onnx(ort_sess, (x,)) @@ -10811,7 +10822,10 @@ def forward(self, x): opset_version=self.opset_version, do_constant_folding=False, training=torch.onnx.TrainingMode.TRAINING, +<<<<<<< HEAD dynamo=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) ort_outs = verification._run_onnx(ort_sess, (x,)) assert not torch.all(torch.eq(x, torch.from_numpy(ort_outs[0]))) @@ -10845,7 +10859,10 @@ def forward(self, x): opset_version=self.opset_version, do_constant_folding=False, training=torch.onnx.TrainingMode.TRAINING, +<<<<<<< HEAD dynamo=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) ort_sess = verification._ort_session(model_onnx) ort_outs = verification._run_onnx(ort_sess, (x,)) @@ -10871,7 +10888,10 @@ def forward(self, x): opset_version=self.opset_version, do_constant_folding=False, training=torch.onnx.TrainingMode.TRAINING, +<<<<<<< HEAD dynamo=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) ort_sess = verification._ort_session(model_onnx) ort_outs = verification._run_onnx(ort_sess, (x,)) @@ -12632,11 +12652,15 @@ def forward(self, x, y): dummy_input = (torch.tensor([expected_mean]), torch.tensor([expected_std])) model_onnx = io.BytesIO() torch.onnx.export( +<<<<<<< HEAD model_export, dummy_input, model_onnx, opset_version=self.opset_version, dynamo=False, +======= + model_export, dummy_input, model_onnx, opset_version=self.opset_version +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) ort_sess = verification._ort_session(model_onnx) ort_out = verification._run_onnx(ort_sess, inputs=dummy_input) @@ -12667,11 +12691,15 @@ def forward(self): model_onnx = io.BytesIO() test_inputs = () torch.onnx.export( +<<<<<<< HEAD model_export, test_inputs, model_onnx, opset_version=self.opset_version, dynamo=False, +======= + model_export, test_inputs, model_onnx, opset_version=self.opset_version +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) ort_sess = verification._ort_session(model_onnx) ort_out = verification._run_onnx(ort_sess, inputs=test_inputs) @@ -12714,11 +12742,15 @@ def forward(self, x, y): dummy_input = (torch.tensor([expected_min]), torch.tensor([expected_max])) model_onnx = io.BytesIO() torch.onnx.export( +<<<<<<< HEAD model_export, dummy_input, model_onnx, opset_version=self.opset_version, dynamo=False, +======= + model_export, dummy_input, model_onnx, opset_version=self.opset_version +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) ort_sess = verification._ort_session(model_onnx) @@ -13725,10 +13757,16 @@ def test_optional_output(self, module_class: type[torch.nn.Module], x_size: int) # Ensure condition is not constant dynamic_axes={"x": {0: dynamic_axis_name}}, input_names=["x"], +<<<<<<< HEAD dynamo=False, ) exported = onnx.load_from_string(f.getvalue()) expected_elem_type = JitScalarType.from_value(x).onnx_type() +======= + ) + exported = onnx.load_from_string(f.getvalue()) + expected_elem_type = torch.onnx.JitScalarType.from_value(x).onnx_type() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) expected_output_type = onnx.helper.make_optional_type_proto( onnx.helper.make_tensor_type_proto(expected_elem_type, (dynamic_axis_name,)) ) diff --git a/test/onnx/test_pytorch_onnx_shape_inference.py b/test/onnx/test_pytorch_onnx_shape_inference.py index e7c58e1ffdbe1..cb58d295acef7 100644 --- a/test/onnx/test_pytorch_onnx_shape_inference.py +++ b/test/onnx/test_pytorch_onnx_shape_inference.py @@ -10,8 +10,13 @@ import torch from torch.onnx import _constants, utils +<<<<<<< HEAD from torch.onnx._internal.torchscript_exporter import jit_utils from torch.onnx._internal.torchscript_exporter._globals import GLOBALS +======= +from torch.onnx._globals import GLOBALS +from torch.onnx._internal import jit_utils +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal import common_utils @@ -396,7 +401,10 @@ def linalg_inv_settype(g, self): f, opset_version=self.opset_version, custom_opsets={"com.microsoft": 1}, +<<<<<<< HEAD dynamo=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) model_proto = onnx.load(io.BytesIO(f.getvalue())) @@ -431,7 +439,10 @@ def linalg_inv_no_settype(g, self): f, opset_version=self.opset_version, custom_opsets={"com.microsoft": 1}, +<<<<<<< HEAD dynamo=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) model_proto = onnx.load(io.BytesIO(f.getvalue())) @@ -470,7 +481,10 @@ def linalg_inv_settype(g, self): custom_opsets={"com.microsoft": 1}, input_names=["x"], dynamic_axes={"x": {0: "batch"}}, +<<<<<<< HEAD dynamo=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) model_proto = onnx.load(io.BytesIO(f.getvalue())) @@ -511,7 +525,10 @@ def linalg_inv_settype(g, self): f, opset_version=self.opset_version, custom_opsets={"com.microsoft": 1}, +<<<<<<< HEAD dynamo=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) model_proto = onnx.load(io.BytesIO(f.getvalue())) diff --git a/test/onnx/test_symbolic_helper.py b/test/onnx/test_symbolic_helper.py index cc7a3a133732c..c8fabc4239972 100644 --- a/test/onnx/test_symbolic_helper.py +++ b/test/onnx/test_symbolic_helper.py @@ -3,7 +3,11 @@ import torch from torch.onnx import symbolic_helper +<<<<<<< HEAD from torch.onnx._internal.torchscript_exporter._globals import GLOBALS +======= +from torch.onnx._globals import GLOBALS +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal import common_utils diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py index 1f80f4163eb25..ea7947b05c063 100644 --- a/test/onnx/test_utility_funs.py +++ b/test/onnx/test_utility_funs.py @@ -1,9 +1,17 @@ # Owner(s): ["module: onnx"] import copy +<<<<<<< HEAD import io import re import warnings +======= +import functools +import io +import re +import warnings +from typing import Callable +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import onnx @@ -21,7 +29,11 @@ import torch.onnx import torch.utils.cpp_extension from torch.onnx import _constants, OperatorExportTypes, TrainingMode, utils +<<<<<<< HEAD from torch.onnx._internal.torchscript_exporter._globals import GLOBALS +======= +from torch.onnx._globals import GLOBALS +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.onnx.symbolic_helper import _unpack_list, parse_args from torch.testing._internal import common_utils from torch.testing._internal.common_utils import skipIfNoLapack @@ -84,6 +96,89 @@ def _model_to_graph( return graph, params_dict, torch_out +<<<<<<< HEAD +======= +@common_utils.instantiate_parametrized_tests +class TestUnconvertibleOps(pytorch_test_common.ExportTestCase): + """Unit tests for the `unconvertible_ops` function.""" + + def setUp(self): + class EinsumModule(torch.nn.Module): + def forward(self, x): + return torch.einsum("ii", x) + + self.einsum_module = EinsumModule() + + def test_it_returns_graph_and_unconvertible_ops_at_lower_opset_version(self): + x = torch.randn(4, 4) + + # Einsum is supported since opset 12. It should be unconvertible at opset 9. + graph, unconvertible_ops = utils.unconvertible_ops( + self.einsum_module, (x,), opset_version=9 + ) + nodes = graph.nodes() + self.assertEqual(next(nodes).kind(), "prim::Constant") + self.assertEqual(next(nodes).kind(), "prim::ListConstruct") + self.assertEqual(next(nodes).kind(), "prim::Constant") + self.assertEqual(next(nodes).kind(), "aten::einsum") + self.assertEqual(unconvertible_ops, ["aten::einsum"]) + + @common_utils.parametrize( + "jit_function", + [ + common_utils.subtest( + functools.partial(torch.jit.trace, example_inputs=torch.randn(4, 4)), + name="traced", + ), + common_utils.subtest(torch.jit.script, name="scripted"), + ], + ) + def test_it_returns_unconvertible_ops_at_lower_opset_version_for_jit_module( + self, jit_function: Callable + ): + module = jit_function(self.einsum_module) + x = torch.randn(4, 4) + + # Einsum is supported since opset 12. It should be unconvertible at opset 9. + _, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=9) + self.assertEqual(unconvertible_ops, ["aten::einsum"]) + + @common_utils.parametrize( + "jit_function", + [ + common_utils.subtest(lambda x: x, name="nn_module"), + common_utils.subtest( + functools.partial(torch.jit.trace, example_inputs=torch.randn(4, 4)), + name="traced", + ), + common_utils.subtest(torch.jit.script, name="scripted"), + ], + ) + def test_it_returns_empty_list_when_all_ops_convertible( + self, jit_function: Callable + ): + module = jit_function(self.einsum_module) + x = torch.randn(4, 4) + + # Einsum is supported since opset 12 + _, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=12) + self.assertEqual(unconvertible_ops, []) + + def test_it_returns_empty_list_when_model_contains_supported_inplace_ops(self): + class SkipConnectionModule(torch.nn.Module): + def forward(self, x): + out = x + out += x + out = torch.nn.functional.relu(out, inplace=True) + return out + + module = SkipConnectionModule() + x = torch.randn(4, 4) + _, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=13) + self.assertEqual(unconvertible_ops, []) + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parameterized.parameterized_class( [ {"opset_version": opset} @@ -111,9 +206,13 @@ def forward(self, x): x = torch.randn(3, 4) f = io.BytesIO() try: +<<<<<<< HEAD torch.onnx.export( MyModule(), x, f, opset_version=self.opset_version, dynamo=False ) +======= + torch.onnx.export(MyModule(), x, f, opset_version=self.opset_version) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) except ValueError: self.assertFalse(torch.onnx.is_in_onnx_export()) @@ -640,7 +739,11 @@ def test_constant_fold_upsample_scale_fold_as_constant(self): model = torch.nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) x = torch.randn(1, 32, 224, 224) f = io.BytesIO() +<<<<<<< HEAD torch.onnx.export(model, x, f, dynamo=False) +======= + torch.onnx.export(model, x, f) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) onnx_model = onnx.load(io.BytesIO(f.getvalue())) self.assertEqual(len(onnx_model.graph.initializer), 0) @@ -653,6 +756,7 @@ def forward(self, input): def is_model_stripped(f, verbose=None): if verbose is None: +<<<<<<< HEAD torch.onnx.export( MyModule(), x, f, opset_version=self.opset_version, dynamo=False ) @@ -664,6 +768,12 @@ def is_model_stripped(f, verbose=None): verbose=verbose, opset_version=self.opset_version, dynamo=False, +======= + torch.onnx.export(MyModule(), x, f, opset_version=self.opset_version) + else: + torch.onnx.export( + MyModule(), x, f, verbose=verbose, opset_version=self.opset_version +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) model = onnx.load(io.BytesIO(f.getvalue())) model_strip = copy.copy(model) @@ -686,9 +796,13 @@ def test_error_on_data_parallel(self): "exporter, please use 'attribute' module to " "unwrap model from torch.nn.DataParallel. Try ", ): +<<<<<<< HEAD torch.onnx.export( model, x, f, opset_version=self.opset_version, dynamo=False ) +======= + torch.onnx.export(model, x, f, opset_version=self.opset_version) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skipIfUnsupportedMinOpsetVersion(11) def test_sequence_dim(self): @@ -712,7 +826,10 @@ def forward(self, x, y): opset_version=self.opset_version, input_names=["x", "y"], dynamic_axes={"y": [1]}, +<<<<<<< HEAD dynamo=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) onnx_model = onnx.load(io.BytesIO(f.getvalue())) loop_output_value_info_proto = onnx_model.graph.output[0] @@ -724,9 +841,13 @@ def forward(self, x, y): # Case 2: no dynamic axes. f = io.BytesIO() y = torch.randn(2, 3) +<<<<<<< HEAD torch.onnx.export( script_model, (x, y), f, opset_version=self.opset_version, dynamo=False ) +======= + torch.onnx.export(script_model, (x, y), f, opset_version=self.opset_version) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) onnx_model = onnx.load(io.BytesIO(f.getvalue())) loop_output_value_info_proto = onnx_model.graph.output[0] ref_value_info_proto = onnx.helper.make_tensor_sequence_value_info( @@ -753,7 +874,10 @@ def forward(self, x): f, opset_version=self.opset_version, training=torch.onnx.TrainingMode.TRAINING, +<<<<<<< HEAD dynamo=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # verify that the model state is preserved self.assertEqual(model.training, old_state) @@ -767,7 +891,10 @@ def forward(self, x): f, opset_version=self.opset_version, training=torch.onnx.TrainingMode.EVAL, +<<<<<<< HEAD dynamo=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # verify that the model state is preserved self.assertEqual(model.training, old_state) @@ -795,9 +922,13 @@ def forward(self, x): # jit.freeze removes the training attribute in the module module = torch.jit.freeze(module) +<<<<<<< HEAD torch.onnx.export( module, (x,), io.BytesIO(), opset_version=self.opset_version, dynamo=False ) +======= + torch.onnx.export(module, (x,), io.BytesIO(), opset_version=self.opset_version) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skipIfUnsupportedMinOpsetVersion(15) def test_local_function(self): @@ -846,7 +977,10 @@ def forward(self, x, y, z): torch.nn.Dropout, torch.nn.LayerNorm, }, +<<<<<<< HEAD dynamo=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) onnx_model = onnx.load(io.BytesIO(f.getvalue())) @@ -881,7 +1015,10 @@ def forward(self, x, y, z): f, opset_version=self.opset_version, export_modules_as_functions={torch.nn.CELU}, +<<<<<<< HEAD dynamo=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) onnx_model = onnx.load(io.BytesIO(f.getvalue())) @@ -897,7 +1034,10 @@ def forward(self, x, y, z): f, opset_version=self.opset_version, export_modules_as_functions=set(), +<<<<<<< HEAD dynamo=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) onnx_model = onnx.load(io.BytesIO(f.getvalue())) @@ -912,7 +1052,10 @@ def forward(self, x, y, z): f, opset_version=self.opset_version, export_modules_as_functions=True, +<<<<<<< HEAD dynamo=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) onnx_model = onnx.load(io.BytesIO(f.getvalue())) @@ -949,7 +1092,10 @@ def forward(self, x, y, z): f, opset_version=self.opset_version, export_modules_as_functions={NWithOverloads}, +<<<<<<< HEAD dynamo=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) onnx_model = onnx.load(io.BytesIO(f.getvalue())) @@ -979,7 +1125,10 @@ def forward(self, x): export_modules_as_functions=True, opset_version=self.opset_version, do_constant_folding=False, +<<<<<<< HEAD dynamo=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) onnx_model = onnx.load(io.BytesIO(f.getvalue())) @@ -1012,7 +1161,10 @@ def forward(self, x): f, export_modules_as_functions=True, opset_version=self.opset_version, +<<<<<<< HEAD dynamo=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) onnx_model = onnx.load(io.BytesIO(f.getvalue())) @@ -1078,7 +1230,10 @@ def forward(self, x): export_modules_as_functions=True, opset_version=self.opset_version, verbose=True, # Allows the test case to print `Skipping module attribute 'freeze'` +<<<<<<< HEAD dynamo=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def test_node_scope(self): @@ -1323,7 +1478,10 @@ def gelu(g, self, approximate): f, opset_version=self.opset_version, custom_opsets={"com.microsoft": 1}, +<<<<<<< HEAD dynamo=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) graph = onnx.load(io.BytesIO(f.getvalue())) @@ -1344,9 +1502,13 @@ def gelu(g, self, approximate): model = torch.nn.GELU(approximate="none") x = torch.randn(3, 3) f = io.BytesIO() +<<<<<<< HEAD torch.onnx.export( model, (x,), f, opset_version=self.opset_version, dynamo=False ) +======= + torch.onnx.export(model, (x,), f, opset_version=self.opset_version) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) graph = onnx.load(io.BytesIO(f.getvalue())) self.assertEqual(graph.graph.node[0].op_type, "Gelu") @@ -1373,7 +1535,10 @@ def linalg_inv(g, self): f, opset_version=self.opset_version, custom_opsets={"com.microsoft": 1}, +<<<<<<< HEAD dynamo=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) graph = onnx.load(io.BytesIO(f.getvalue())) @@ -1677,7 +1842,10 @@ def forward(self, x): f, opset_version=self.opset_version, keep_initializers_as_inputs=True, +<<<<<<< HEAD dynamo=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) graph = onnx.load(io.BytesIO(f.getvalue())) self.assertEqual(graph.graph.input[1].name, "in_weight") @@ -1710,19 +1878,27 @@ def forward(self, x): ] f = io.BytesIO() +<<<<<<< HEAD torch.onnx.export( module, torch.ones(1, 10), f, output_names=["y"], dynamo=False ) +======= + torch.onnx.export(module, torch.ones(1, 10), f, output_names=["y"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) onnx_model = onnx.load(io.BytesIO(f.getvalue())) for n in onnx_model.graph.node: self.assertIn(n.name, ref_node_names) torch.onnx.export( +<<<<<<< HEAD torch.jit.script(module), torch.ones(1, 10), f, output_names=["y"], dynamo=False, +======= + torch.jit.script(module), torch.ones(1, 10), f, output_names=["y"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) onnx_model = onnx.load(io.BytesIO(f.getvalue())) for n in onnx_model.graph.node: @@ -1765,7 +1941,10 @@ def forward(self, x): f, training=TrainingMode.TRAINING, opset_version=self.opset_version, +<<<<<<< HEAD dynamo=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) graph = onnx.load(io.BytesIO(f.getvalue())) self.assertSetEqual({i.name for i in graph.graph.initializer}, param_name_set) @@ -1778,7 +1957,10 @@ def forward(self, x): f, training=TrainingMode.PRESERVE, opset_version=self.opset_version, +<<<<<<< HEAD dynamo=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) graph = onnx.load(io.BytesIO(f.getvalue())) self.assertSetEqual({i.name for i in graph.graph.initializer}, param_name_set) @@ -1786,9 +1968,13 @@ def forward(self, x): # Test eval mode. model.eval() f = io.BytesIO() +<<<<<<< HEAD torch.onnx.export( model, (x,), f, opset_version=self.opset_version, dynamo=False ) +======= + torch.onnx.export(model, (x,), f, opset_version=self.opset_version) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) graph = onnx.load(io.BytesIO(f.getvalue())) param_name_set.remove("param2") self.assertSetEqual({i.name for i in graph.graph.initializer}, param_name_set) @@ -1817,9 +2003,13 @@ def forward(self, x, y): x = torch.randn(3, 3, device=torch.device("cpu")) y = torch.randn(3, 3, device=torch.device("cuda")) f = io.BytesIO() +<<<<<<< HEAD torch.onnx.export( Model(), (x, y), f, opset_version=self.opset_version, dynamo=False ) +======= + torch.onnx.export(Model(), (x, y), f, opset_version=self.opset_version) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) graph = onnx.load(io.BytesIO(f.getvalue())) self.assertSetEqual({i.name for i in graph.graph.initializer}, {"w_cpu"}) @@ -1860,7 +2050,10 @@ def forward(self, input0, input1): dynamic_axes=dynamic_axes, verbose=True, keep_initializers_as_inputs=True, +<<<<<<< HEAD dynamo=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) graph = onnx.load(io.BytesIO(f.getvalue())) @@ -1888,7 +2081,11 @@ def forward(self, x): f = io.BytesIO() x = torch.randn(1, 32, 224, 224) +<<<<<<< HEAD torch.onnx.export(Model(), x, f, dynamo=False) +======= + torch.onnx.export(Model(), x, f) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) onnx_model = onnx.load(io.BytesIO(f.getvalue())) # aten::upsample converts to onnx::resize resize_nodes = [n for n in onnx_model.graph.node if n.op_type == "Resize"] @@ -1920,7 +2117,11 @@ def forward(self, x): self.assertExpectedRaisesInline( AssertionError, lambda: torch.onnx.export( +<<<<<<< HEAD model, (x,), f, opset_version=_onnx_opset_version, dynamo=False +======= + model, (x,), f, opset_version=_onnx_opset_version +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), ( "A mismatch between the number of arguments (2) and their descriptors (1) was found at symbolic function " diff --git a/test/onnx/test_verification.py b/test/onnx/test_verification.py new file mode 100644 index 0000000000000..4d2b4676d9b17 --- /dev/null +++ b/test/onnx/test_verification.py @@ -0,0 +1,298 @@ +# Owner(s): ["module: onnx"] + +import contextlib +import io +import tempfile +import unittest + +import numpy as np + +import onnx +import parameterized +import pytorch_test_common +from packaging import version + +import torch +from torch.onnx import _constants, _experimental, verification +from torch.testing._internal import common_utils + + +class TestVerification(pytorch_test_common.ExportTestCase): + def test_check_export_model_diff_returns_diff_when_constant_mismatch(self): + class UnexportableModel(torch.nn.Module): + def forward(self, x, y): + # tensor.data() will be exported as a constant, + # leading to wrong model output under different inputs. + return x + y.data + + test_input_groups = [ + ((torch.randn(2, 3), torch.randn(2, 3)), {}), + ((torch.randn(2, 3), torch.randn(2, 3)), {}), + ] + + results = verification.check_export_model_diff( + UnexportableModel(), test_input_groups + ) + self.assertRegex( + results, + r"Graph diff:(.|\n)*" + r"First diverging operator:(.|\n)*" + r"prim::Constant(.|\n)*" + r"Former source location:(.|\n)*" + r"Latter source location:", + ) + + def test_check_export_model_diff_returns_diff_when_dynamic_controlflow_mismatch( + self, + ): + class UnexportableModel(torch.nn.Module): + def forward(self, x, y): + for i in range(x.size(0)): + y = x[i] + y + return y + + test_input_groups = [ + ((torch.randn(2, 3), torch.randn(2, 3)), {}), + ((torch.randn(4, 3), torch.randn(2, 3)), {}), + ] + + export_options = _experimental.ExportOptions( + input_names=["x", "y"], dynamic_axes={"x": [0]} + ) + results = verification.check_export_model_diff( + UnexportableModel(), test_input_groups, export_options + ) + self.assertRegex( + results, + r"Graph diff:(.|\n)*" + r"First diverging operator:(.|\n)*" + r"prim::Constant(.|\n)*" + r"Latter source location:(.|\n)*", + ) + + def test_check_export_model_diff_returns_empty_when_correct_export(self): + class SupportedModel(torch.nn.Module): + def forward(self, x, y): + return x + y + + test_input_groups = [ + ((torch.randn(2, 3), torch.randn(2, 3)), {}), + ((torch.randn(2, 3), torch.randn(2, 3)), {}), + ] + + results = verification.check_export_model_diff( + SupportedModel(), test_input_groups + ) + self.assertEqual(results, "") + + def test_compare_ort_pytorch_outputs_no_raise_with_acceptable_error_percentage( + self, + ): + ort_outs = [np.array([[1.0, 2.0], [3.0, 4.0]])] + pytorch_outs = [torch.tensor([[1.0, 2.0], [3.0, 1.0]])] + options = verification.VerificationOptions( + rtol=1e-5, + atol=1e-6, + check_shape=True, + check_dtype=False, + ignore_none=True, + acceptable_error_percentage=0.3, + ) + verification._compare_onnx_pytorch_outputs( + ort_outs, + pytorch_outs, + options, + ) + + def test_compare_ort_pytorch_outputs_raise_without_acceptable_error_percentage( + self, + ): + ort_outs = [np.array([[1.0, 2.0], [3.0, 4.0]])] + pytorch_outs = [torch.tensor([[1.0, 2.0], [3.0, 1.0]])] + options = verification.VerificationOptions( + rtol=1e-5, + atol=1e-6, + check_shape=True, + check_dtype=False, + ignore_none=True, + acceptable_error_percentage=None, + ) + with self.assertRaises(AssertionError): + verification._compare_onnx_pytorch_outputs( + ort_outs, + pytorch_outs, + options, + ) + + +@common_utils.instantiate_parametrized_tests +class TestVerificationOnWrongExport(pytorch_test_common.ExportTestCase): + opset_version: int + + def setUp(self): + super().setUp() + + def incorrect_add_symbolic_function(g, self, other, alpha): + return self + + self.opset_version = _constants.ONNX_DEFAULT_OPSET + torch.onnx.register_custom_op_symbolic( + "aten::add", + incorrect_add_symbolic_function, + opset_version=self.opset_version, + ) + + def tearDown(self): + super().tearDown() + torch.onnx.unregister_custom_op_symbolic( + "aten::add", opset_version=self.opset_version + ) + + @common_utils.parametrize( + "onnx_backend", + [ + common_utils.subtest( + verification.OnnxBackend.REFERENCE, + decorators=[ + unittest.skipIf( + version.Version(onnx.__version__) < version.Version("1.13"), + reason="Reference Python runtime was introduced in 'onnx' 1.13.", + ) + ], + ), + verification.OnnxBackend.ONNX_RUNTIME_CPU, + ], + ) + def test_verify_found_mismatch_when_export_is_wrong( + self, onnx_backend: verification.OnnxBackend + ): + class Model(torch.nn.Module): + def forward(self, x): + return x + 1 + + with self.assertRaisesRegex(AssertionError, ".*Tensor-likes are not close!.*"): + verification.verify( + Model(), + (torch.randn(2, 3),), + opset_version=self.opset_version, + options=verification.VerificationOptions(backend=onnx_backend), + ) + + +@parameterized.parameterized_class( + [ + # TODO: enable this when ONNX submodule catches up to >= 1.13. + # {"onnx_backend": verification.OnnxBackend.ONNX}, + {"onnx_backend": verification.OnnxBackend.ONNX_RUNTIME_CPU}, + ], + class_name_func=lambda cls, + idx, + input_dicts: f"{cls.__name__}_{input_dicts['onnx_backend'].name}", +) +class TestFindMismatch(pytorch_test_common.ExportTestCase): + onnx_backend: verification.OnnxBackend + opset_version: int + graph_info: verification.GraphInfo + + def setUp(self): + super().setUp() + self.opset_version = _constants.ONNX_DEFAULT_OPSET + + def incorrect_relu_symbolic_function(g, self): + return g.op("Add", self, g.op("Constant", value_t=torch.tensor(1.0))) + + torch.onnx.register_custom_op_symbolic( + "aten::relu", + incorrect_relu_symbolic_function, + opset_version=self.opset_version, + ) + + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.layers = torch.nn.Sequential( + torch.nn.Linear(3, 4), + torch.nn.ReLU(), + torch.nn.Linear(4, 5), + torch.nn.ReLU(), + torch.nn.Linear(5, 6), + ) + + def forward(self, x): + return self.layers(x) + + self.graph_info = verification.find_mismatch( + Model(), + (torch.randn(2, 3),), + opset_version=self.opset_version, + options=verification.VerificationOptions(backend=self.onnx_backend), + ) + + def tearDown(self): + super().tearDown() + torch.onnx.unregister_custom_op_symbolic( + "aten::relu", opset_version=self.opset_version + ) + delattr(self, "opset_version") + delattr(self, "graph_info") + + def test_pretty_print_tree_visualizes_mismatch(self): + f = io.StringIO() + with contextlib.redirect_stdout(f): + self.graph_info.pretty_print_tree() + self.assertExpected(f.getvalue()) + + def test_preserve_mismatch_source_location(self): + mismatch_leaves = self.graph_info.all_mismatch_leaf_graph_info() + + self.assertTrue(len(mismatch_leaves) > 0) + + for leaf_info in mismatch_leaves: + f = io.StringIO() + with contextlib.redirect_stdout(f): + leaf_info.pretty_print_mismatch(graph=True) + self.assertRegex( + f.getvalue(), + r"(.|\n)*aten::relu.*/torch/nn/functional.py:[0-9]+(.|\n)*", + ) + + def test_find_all_mismatch_operators(self): + mismatch_leaves = self.graph_info.all_mismatch_leaf_graph_info() + + self.assertEqual(len(mismatch_leaves), 2) + + for leaf_info in mismatch_leaves: + self.assertEqual(leaf_info.essential_node_count(), 1) + self.assertEqual(leaf_info.essential_node_kinds(), {"aten::relu"}) + + def test_find_mismatch_prints_correct_info_when_no_mismatch(self): + self.maxDiff = None + + class Model(torch.nn.Module): + def forward(self, x): + return x + 1 + + f = io.StringIO() + with contextlib.redirect_stdout(f): + verification.find_mismatch( + Model(), + (torch.randn(2, 3),), + opset_version=self.opset_version, + options=verification.VerificationOptions(backend=self.onnx_backend), + ) + self.assertExpected(f.getvalue()) + + def test_export_repro_for_mismatch(self): + mismatch_leaves = self.graph_info.all_mismatch_leaf_graph_info() + self.assertTrue(len(mismatch_leaves) > 0) + leaf_info = mismatch_leaves[0] + with tempfile.TemporaryDirectory() as temp_dir: + repro_dir = leaf_info.export_repro(temp_dir) + + with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"): + options = verification.VerificationOptions(backend=self.onnx_backend) + verification.OnnxTestCaseRepro(repro_dir).validate(options) + + +if __name__ == "__main__": + common_utils.run_tests() diff --git a/test/onnx/torchlib/error_reproduction.py b/test/onnx/torchlib/error_reproduction.py index 9fd1dace77677..f6f61b376951c 100644 --- a/test/onnx/torchlib/error_reproduction.py +++ b/test/onnx/torchlib/error_reproduction.py @@ -205,7 +205,11 @@ def create_reproduction_report( onnxscript=={onnxscript.__version__} numpy=={np.__version__} torch=={torch.__version__}""" +<<<<<<< HEAD short_test_name = test_name.rsplit(".", maxsplit=1)[-1] +======= + short_test_name = test_name.split(".")[-1] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) reproduction_code = _REPRODUCTION_TEMPLATE.format( onnx_model_text=onnx_model_text, ort_inputs=input_text, @@ -245,7 +249,11 @@ def create_mismatch_report( error_text = str(error) error_stack = error_text + "\n" + "".join(traceback.format_tb(error.__traceback__)) +<<<<<<< HEAD short_test_name = test_name.rsplit(".", maxsplit=1)[-1] +======= + short_test_name = test_name.split(".")[-1] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff = difflib.unified_diff( str(actual).splitlines(), str(expected).splitlines(), diff --git a/test/onnx/torchlib/ops_test_common.py b/test/onnx/torchlib/ops_test_common.py index 54ecbdc195181..8ca5f5dfbd876 100644 --- a/test/onnx/torchlib/ops_test_common.py +++ b/test/onnx/torchlib/ops_test_common.py @@ -246,7 +246,11 @@ def duplicate_opinfo_for_prims( new_opinfo = copy.deepcopy(opinfo) new_opinfo.name = new_name new_opinfo.op = getattr(torch.ops.prims, prims_name) +<<<<<<< HEAD opinfos.append(new_opinfo) # noqa: B909 +======= + opinfos.append(new_opinfo) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return raise RuntimeError(f"OpInfo '{name}' not found in the database.") diff --git a/test/optim/test_lrscheduler.py b/test/optim/test_lrscheduler.py index c36e7b2e21d62..147de23140929 100644 --- a/test/optim/test_lrscheduler.py +++ b/test/optim/test_lrscheduler.py @@ -784,6 +784,7 @@ def test_sequentiallr5(self): scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones) self._test(scheduler, targets, epochs) +<<<<<<< HEAD def test_sequentiallr_no_warnings(self): scheduler1 = LinearLR(self.opt, start_factor=0.5, end_factor=0.1, total_iters=5) scheduler2 = ExponentialLR(self.opt, gamma=0.9) @@ -797,6 +798,8 @@ def test_sequentiallr_no_warnings(self): scheduler.step() self.assertTrue(len(ws) == 0, "No warning should be raised") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_get_last_lr_sequentiallr(self): epochs = 12 milestones = [3, 6] diff --git a/test/package/package_a/test_nn_module.py b/test/package/package_a/test_nn_module.py index 18cc9a395ada2..0060cf00e86f7 100644 --- a/test/package/package_a/test_nn_module.py +++ b/test/package/package_a/test_nn_module.py @@ -25,7 +25,11 @@ def __init__(self, nz=6, ngf=9, nc=3): torch.nn.ReLU(True), # state size. (ngf) x 32 x 32 torch.nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False), +<<<<<<< HEAD torch.nn.Tanh(), +======= + torch.nn.Tanh() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # state size. (nc) x 64 x 64 ) diff --git a/test/package/test_save_load.py b/test/package/test_save_load.py index edbba9f6f8ee8..24dfe17e9706f 100644 --- a/test/package/test_save_load.py +++ b/test/package/test_save_load.py @@ -208,10 +208,18 @@ def make_exporter(): # Ensure that the importer finds the 'PackageAObject' defined in 'importer1' first. return pe +<<<<<<< HEAD # This succeeds because OrderedImporter.get_name() properly # falls back to sys_importer which can find the original PackageAObject pe = make_exporter() pe.save_pickle("obj", "obj.pkl", obj2) +======= + # This should fail. The 'PackageAObject' type defined from 'importer1' + # is not necessarily the same 'obj2's version of 'PackageAObject'. + pe = make_exporter() + with self.assertRaises(pickle.PicklingError): + pe.save_pickle("obj", "obj.pkl", obj2) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This should also fail. The 'PackageAObject' type defined from 'importer1' # is not necessarily the same as the one defined from 'importer2' diff --git a/test/profiler/test_execution_trace.py b/test/profiler/test_execution_trace.py index 7da2898ffbe76..8d41a7e9f081d 100644 --- a/test/profiler/test_execution_trace.py +++ b/test/profiler/test_execution_trace.py @@ -1,5 +1,21 @@ # Owner(s): ["oncall: profiler"] +<<<<<<< HEAD +======= +# if tqdm is not shutdown properly, it will leave the monitor thread alive. +# This causes an issue in the multithreading test because we check all events +# in that test with their tids. The events that correspond to these lingering +# threads all have TID of (uint64_t)(-1) which is invalid. +# The work around is turnning off monitoring thread when tqdm is loaded. +# Since these are unit tests, it is safe to turn off monitor thread. +try: + import tqdm + + tqdm.tqdm.monitor_interval = 0 +except ImportError: + pass + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import json import os import tempfile @@ -39,6 +55,7 @@ from torch.utils._triton import has_triton +<<<<<<< HEAD # if tqdm is not shutdown properly, it will leave the monitor thread alive. # This causes an issue in the multithreading test because we check all events # in that test with their tids. The events that correspond to these lingering @@ -52,6 +69,8 @@ except ImportError: pass +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Json = dict[str, Any] @@ -404,7 +423,10 @@ def fn(a, b, c): nodes = self.get_execution_trace_root(fp.name) found_captured_triton_kernel_node = False +<<<<<<< HEAD found_call_compiled_fx_graph = False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for n in nodes: assert "name" in n if "triton_" in n["name"]: @@ -413,10 +435,14 @@ def fn(a, b, c): found_captured_triton_kernel_node = True assert len(n["inputs"]["values"]) > 0 assert len(n["outputs"]["values"]) == 0 +<<<<<<< HEAD elif "Call CompiledFxGraph" in n["name"]: found_call_compiled_fx_graph = True assert found_captured_triton_kernel_node assert found_call_compiled_fx_graph +======= + assert found_captured_triton_kernel_node +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(IS_WINDOWS, "torch.compile does not support WINDOWS") @unittest.skipIf( @@ -425,11 +451,14 @@ def fn(a, b, c): ) @skipCPUIf(True, "skip CPU device for testing profiling triton") def test_execution_trace_env_enabled_with_pt2(self, device): +<<<<<<< HEAD # clean up the local cache for triton kernel from torch._inductor.codecache import PyCodeCache as PyCodeCache PyCodeCache.cache_clear(purge=True) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import os os.environ["ENABLE_PYTORCH_EXECUTION_TRACE"] = "1" @@ -444,9 +473,13 @@ def fn(a, b, c): a, b, c = (torch.randn(4, 4, requires_grad=True).to(device) for _ in range(3)) inputs = [a, b, c] +<<<<<<< HEAD with torch._inductor.config.patch( compile_threads=1, fx_graph_cache=False, fx_graph_remote_cache=False ): +======= + with torch._inductor.config.patch(compile_threads=1): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fn(*inputs) with profile( @@ -480,6 +513,7 @@ def fn(a, b, c): assert len(n["outputs"]["values"]) == 0 assert found_captured_triton_kernel_node +<<<<<<< HEAD @unittest.skipIf(IS_WINDOWS, "torch.compile does not support WINDOWS") @unittest.skipIf( (not has_triton()) or (not TEST_CUDA and not TEST_XPU), @@ -585,6 +619,8 @@ def fn(a, b, c): ) assert fx_graph[7] == "# return %cos" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_execution_trace_start_stop(self, device): use_device = ( torch.profiler.ProfilerActivity.CUDA @@ -740,9 +776,15 @@ def test_execution_trace_record_integral_tensor_data(self): with tempfile.TemporaryDirectory() as temp_dir: fp_name = os.path.join(temp_dir, "test.et.json") +<<<<<<< HEAD os.environ["ENABLE_PYTORCH_EXECUTION_TRACE_SAVE_INTEGRAL_TENSOR_DATA"] = ( "aten::gather" ) +======= + os.environ[ + "ENABLE_PYTORCH_EXECUTION_TRACE_SAVE_INTEGRAL_TENSOR_DATA" + ] = "aten::gather" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) et = ExecutionTraceObserver() et.register_callback(fp_name) et.set_extra_resource_collection(True) diff --git a/test/profiler/test_memory_profiler.py b/test/profiler/test_memory_profiler.py index 5351f147cf33f..4e7c5c2e6ffcf 100644 --- a/test/profiler/test_memory_profiler.py +++ b/test/profiler/test_memory_profiler.py @@ -1324,6 +1324,7 @@ def step_fn(mark_region): aten::detach 7 (GRADIENT) -> 7 (GRADIENT) -- Optimizer -------------------------------------------------------------------------------------------- +<<<<<<< HEAD aten::detach 7 (GRADIENT) -> 7 (GRADIENT) aten::detach 7 (GRADIENT) -> 7 (GRADIENT) aten::clone 7 (GRADIENT) -> 10 (OPTIMIZER_STATE) @@ -1331,6 +1332,15 @@ def step_fn(mark_region): aten::detach 9 (GRADIENT) -> 9 (GRADIENT) aten::detach 9 (GRADIENT) -> 9 (GRADIENT) aten::clone 9 (GRADIENT) -> 11 (OPTIMIZER_STATE) +======= + aten::clone 7 (GRADIENT) -> 10 (OPTIMIZER_STATE) + aten::detach 10 (OPTIMIZER_STATE) -> 10 (OPTIMIZER_STATE) + aten::detach 10 (OPTIMIZER_STATE) -> 10 (OPTIMIZER_STATE) + aten::add_.Tensor 2 (PARAMETER), 10 (OPTIMIZER_STATE) -> 2 (PARAMETER) + aten::clone 9 (GRADIENT) -> 11 (OPTIMIZER_STATE) + aten::detach 11 (OPTIMIZER_STATE) -> 11 (OPTIMIZER_STATE) + aten::detach 11 (OPTIMIZER_STATE) -> 11 (OPTIMIZER_STATE) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) aten::add_.Tensor 3 (PARAMETER), 11 (OPTIMIZER_STATE) -> 3 (PARAMETER)""", ) diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py index 46b21cb4dc097..edadb1a4fd458 100644 --- a/test/profiler/test_profiler.py +++ b/test/profiler/test_profiler.py @@ -27,7 +27,10 @@ import torch.optim import torch.utils.data from torch._C._profiler import _ExperimentalConfig, _ExtraFields_PyCall +<<<<<<< HEAD from torch._inductor.utils import is_big_gpu +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.autograd.profiler import KinetoStepTracker, profile as _profile from torch.autograd.profiler_legacy import profile as _profile_legacy from torch.profiler import ( @@ -985,6 +988,7 @@ def test_flops(self): ) self.assertIn("Total MFLOPs", profiler_output) +<<<<<<< HEAD def test_override_time_units(self): US_IN_SECOND = 1000.0 * 1000.0 US_IN_MS = 1000.0 @@ -1029,6 +1033,8 @@ def test_override_time_units(self): self.assertTrue(cpu_time_str_us in profiler_output) self.assertTrue(cpu_time_total_str_us in profiler_output) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @patch.dict(os.environ, {"KINETO_USE_DAEMON": "1"}) @patch.dict(os.environ, {"KINETO_DAEMON_INIT_DELAY_S": "1"}) def test_kineto_profiler_api(self): @@ -1468,7 +1474,11 @@ def trace_and_check(exp_config: Optional[_ExperimentalConfig]) -> None: cats = {e.get("cat", None) for e in j["traceEvents"]} self.assertTrue( "cuda_sync" in cats, +<<<<<<< HEAD f"Expected to find cuda_sync event found = {cats}", +======= + "Expected to find cuda_sync event" f" found = {cats}", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) print("Testing enable_cuda_sync_events in _ExperimentalConfig") @@ -2336,6 +2346,7 @@ def verify_events(events): events = main_with_thread_fn(profile_all_threads) verify_events(events) +<<<<<<< HEAD @skipIfTorchDynamo("profiler gets ignored if dynamo activated") @unittest.skipIf(not kineto_available(), "Kineto is required") def test_python_gc_event(self): @@ -2404,6 +2415,8 @@ def validate_json(prof, gc_collection_on): payload() validate_json(prof, gc_flag) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class SimpleNet(nn.Module): def __init__(self) -> None: @@ -3158,6 +3171,7 @@ def validate_json(prof): assert "Overload Name" in key_averages.table() validate_json(prof) +<<<<<<< HEAD @unittest.skipIf(not torch.cuda.is_available(), "requries CUDA") def test_profiler_debug_autotuner(self): """ @@ -3205,6 +3219,8 @@ def names(prof): n2 = names(prof2) self.assertEqual(n1, n2) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/profiler/test_profiler_tree.py b/test/profiler/test_profiler_tree.py index 670e639c98e23..ab3d8d9d22614 100644 --- a/test/profiler/test_profiler_tree.py +++ b/test/profiler/test_profiler_tree.py @@ -191,6 +191,7 @@ def fmt_name(name: str) -> str: name, ) +<<<<<<< HEAD # HACK: this patches around the fact that PyBind11 improperly sets the # __qualname__ attribute on functions and methods; see # https://github.com/pybind/pybind11/issues/5774. This should be removed if @@ -201,6 +202,8 @@ def fmt_name(name: str) -> str: name, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return re.sub("object at 0x[0-9a-fA-F]+>", "object at 0xXXXXXXXXXXXX>", name) @classmethod @@ -764,7 +767,10 @@ def test_profiler_experimental_tree_with_stack_and_torch_dispatch(self): aten::add torch/_library/simple_registry.py(...): find_torch_dispatch_rule torch/_library/simple_registry.py(...): find +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch/_library/simple_registry.py(...): find test_profiler_tree.py(...): __torch_dispatch__ diff --git a/test/profiler/test_record_function.py b/test/profiler/test_record_function.py index 9ab80c9a07a17..efb4b0ad50fa9 100644 --- a/test/profiler/test_record_function.py +++ b/test/profiler/test_record_function.py @@ -1,6 +1,22 @@ # Owner(s): ["oncall: profiler"] # ruff: noqa: F841 +<<<<<<< HEAD +======= +# if tqdm is not shutdown properly, it will leave the monitor thread alive. +# This causes an issue in the multithreading test because we check all events +# in that test with their tids. The events that correspond to these lingering +# threads all have TID of (uint64_t)(-1) which is invalid. +# The work around is turnning off monitoring thread when tqdm is loaded. +# Since these are unit tests, it is safe to turn off monitor thread. +try: + import tqdm + + tqdm.tqdm.monitor_interval = 0 +except ImportError: + None + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from typing import Any import torch @@ -16,6 +32,7 @@ from torch.testing._internal.common_utils import run_tests, TestCase +<<<<<<< HEAD # if tqdm is not shutdown properly, it will leave the monitor thread alive. # This causes an issue in the multithreading test because we check all events # in that test with their tids. The events that correspond to these lingering @@ -29,6 +46,8 @@ except ImportError: pass +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Json = dict[str, Any] diff --git a/test/profiler/test_torch_tidy.py b/test/profiler/test_torch_tidy.py index efbd4b8189dee..b771a9bc4f69b 100644 --- a/test/profiler/test_torch_tidy.py +++ b/test/profiler/test_torch_tidy.py @@ -1,5 +1,6 @@ # Owner(s): ["oncall: profiler"] +<<<<<<< HEAD import gc import re import textwrap @@ -16,6 +17,8 @@ from torch.testing._internal.common_utils import run_tests, TestCase +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # if tqdm is not shutdown properly, it will leave the monitor thread alive. # This causes an issue in the multithreading test because we check all events # in that test with their tids. The events that correspond to these lingering @@ -27,10 +30,35 @@ tqdm.tqdm.monitor_interval = 0 except ImportError: +<<<<<<< HEAD pass Json = dict[str, Any] +======= + None + +import gc +import re +import textwrap +import unittest +import weakref +from typing import Any + +import torch +import torch.nn as nn +import torch.optim +import torch.utils.data +from torch._C._profiler import _TensorMetadata +from torch.profiler import _utils, profile +from torch.testing._internal.common_utils import run_tests, TestCase + + +Json = dict[str, Any] + +from torch._C._profiler import _ExtraFields_PyCall + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def find_node_with_name(nodes, name): for node in _utils.traverse_dfs(nodes): diff --git a/test/quantization/ao_migration/common.py b/test/quantization/ao_migration/common.py index 5797b4bab1d44..21982e86d2b30 100644 --- a/test/quantization/ao_migration/common.py +++ b/test/quantization/ao_migration/common.py @@ -47,6 +47,12 @@ def _test_dict_import( new_dict = getattr(new_location, dict_name) assert old_dict == new_dict, f"Dicts don't match: {dict_name}" for key in new_dict.keys(): +<<<<<<< HEAD assert old_dict[key] == new_dict[key], ( f"Dicts don't match: {dict_name} for key {key}" ) +======= + assert ( + old_dict[key] == new_dict[key] + ), f"Dicts don't match: {dict_name} for key {key}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/test/quantization/core/experimental/test_floatx.py b/test/quantization/core/experimental/test_floatx.py index ee7fe0a9d1860..87e66c5ae8372 100644 --- a/test/quantization/core/experimental/test_floatx.py +++ b/test/quantization/core/experimental/test_floatx.py @@ -426,6 +426,10 @@ def test_f4_save_load(self, device): class TestFloat8DtypeCPUOnly(TestCase): +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Test of mul implementation diff --git a/test/quantization/core/test_docs.py b/test/quantization/core/test_docs.py new file mode 100644 index 0000000000000..ab1689cccab2d --- /dev/null +++ b/test/quantization/core/test_docs.py @@ -0,0 +1,146 @@ +# Owner(s): ["oncall: quantization"] + +import re +import contextlib +from pathlib import Path + +import torch + +from torch.testing._internal.common_quantization import ( + QuantizationTestCase, + SingleLayerLinearModel, +) +from torch.testing._internal.common_quantized import override_quantized_engine +from torch.testing._internal.common_utils import raise_on_run_directly, IS_ARM64, IS_FBCODE +import unittest + + +@unittest.skipIf(IS_FBCODE, "some path issues in fbcode") +class TestQuantizationDocs(QuantizationTestCase): + r""" + The tests in this section import code from the quantization docs and check that + they actually run without errors. In cases where objects are undefined in the code snippet, + they must be provided in the test. The imports seem to behave a bit inconsistently, + they can be imported either in the test file or passed as a global input + """ + + def run(self, result=None): + with override_quantized_engine("qnnpack") if IS_ARM64 else contextlib.nullcontext(): + super().run(result) + + def _get_code( + self, path_from_pytorch, unique_identifier, offset=2, short_snippet=False + ): + r""" + This function reads in the code from the docs given a unique identifier. + Most code snippets have a 2 space indentation, for other indentation levels, + change the offset `arg`. the `short_snippet` arg can be set to allow for testing + of smaller snippets, the check that this arg controls is used to make sure that + we are not accidentally only importing a blank line or something. + """ + + def get_correct_path(path_from_pytorch): + r""" + Current working directory when CI is running test seems to vary, this function + looks for docs relative to this test file. + """ + core_dir = Path(__file__).parent + assert core_dir.match("test/quantization/core/"), ( + "test_docs.py is in an unexpected location. If you've been " + "moving files around, ensure that the test and build files have " + "been updated to have the correct relative path between " + "test_docs.py and the docs." + ) + pytorch_root = core_dir.parents[2] + return pytorch_root / path_from_pytorch + + path_to_file = get_correct_path(path_from_pytorch) + if path_to_file: + with open(path_to_file) as file: + content = file.readlines() + + # it will register as having a newline at the end in python + if "\n" not in unique_identifier: + unique_identifier += "\n" + + assert unique_identifier in content, f"could not find {unique_identifier} in {path_to_file}" + + # get index of first line of code + line_num_start = content.index(unique_identifier) + 1 + + # next find where the code chunk ends. + # this regex will match lines that don't start + # with a \n or " " with number of spaces=offset + r = r = re.compile("^[^\n," + " " * offset + "]") + # this will return the line of first line that matches regex + line_after_code = next(filter(r.match, content[line_num_start:])) + last_line_num = content.index(line_after_code) + + # remove the first `offset` chars of each line and gather it all together + code = "".join( + [x[offset:] for x in content[line_num_start + 1 : last_line_num]] + ) + + # want to make sure we are actually getting some code, + assert last_line_num - line_num_start > 3 or short_snippet, ( + f"The code in {path_to_file} identified by {unique_identifier} seems suspiciously short:" + f"\n\n###code-start####\n{code}###code-end####" + ) + return code + + return None + + def _test_code(self, code, global_inputs=None): + r""" + This function runs `code` using any vars in `global_inputs` + """ + # if couldn't find the + if code is not None: + expr = compile(code, "test", "exec") + exec(expr, global_inputs) + + def test_quantization_doc_ptdq(self): + path_from_pytorch = "docs/source/quantization.rst" + unique_identifier = "PTDQ API Example::" + code = self._get_code(path_from_pytorch, unique_identifier) + self._test_code(code) + + def test_quantization_doc_ptsq(self): + path_from_pytorch = "docs/source/quantization.rst" + unique_identifier = "PTSQ API Example::" + code = self._get_code(path_from_pytorch, unique_identifier) + self._test_code(code) + + def test_quantization_doc_qat(self): + path_from_pytorch = "docs/source/quantization.rst" + unique_identifier = "QAT API Example::" + + def _dummy_func(*args, **kwargs): + return None + + input_fp32 = torch.randn(1, 1, 1, 1) + global_inputs = {"training_loop": _dummy_func, "input_fp32": input_fp32} + code = self._get_code(path_from_pytorch, unique_identifier) + self._test_code(code, global_inputs) + + def test_quantization_doc_fx(self): + path_from_pytorch = "docs/source/quantization.rst" + unique_identifier = "FXPTQ API Example::" + + input_fp32 = SingleLayerLinearModel().get_example_inputs() + global_inputs = {"UserModel": SingleLayerLinearModel, "input_fp32": input_fp32} + + code = self._get_code(path_from_pytorch, unique_identifier) + self._test_code(code, global_inputs) + + def test_quantization_doc_custom(self): + path_from_pytorch = "docs/source/quantization.rst" + unique_identifier = "Custom API Example::" + + global_inputs = {"nnq": torch.ao.nn.quantized} + + code = self._get_code(path_from_pytorch, unique_identifier) + self._test_code(code, global_inputs) + +if __name__ == "__main__": + raise_on_run_directly("test/test_quantization.py") diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index b6df2089e87e7..8fc06d0040725 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -7,8 +7,13 @@ import numpy as np import operator import random +<<<<<<< HEAD import unittest from packaging.version import Version +======= +import sys +import unittest +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from typing import NamedTuple import torch @@ -73,7 +78,11 @@ class PointwisePostOp(NamedTuple): def avoid_vpmaddubsw_overflow_linear( batch_size, input_channels, output_channels, X, X_min, X_max, W, W_min, W_max ): +<<<<<<< HEAD if Version(np.__version__) >= Version("2.1"): +======= + if np.lib.NumpyVersion(np.__version__) >= '2.1.0': +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise unittest.SkipTest("numpy 2.1 overflow error") for i, j in np.ndindex((batch_size, output_channels)): for k in range(0, input_channels // 2 * 2, 2): @@ -154,6 +163,7 @@ def _get_random_tensor_and_q_params(shapes, rand_scale, torch_type): X_scale = 1e-10 return X, X_scale, X_zero_point +<<<<<<< HEAD def _quantize_fp8e4m3(t: torch.Tensor, channelwise: bool, scale: Optional[torch.Tensor] = None): quant_max = torch.finfo(torch.float8_e4m3fn).max eps = torch.Tensor([torch.finfo(torch.float32).eps]) @@ -181,6 +191,8 @@ def _dequantize_fp8e4m3(qt: torch.Tensor, scale: torch.Tensor): dqt = dqt * scale_reshape return dqt +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestQuantizedOps(TestCase): """Helper function to test quantized activation functions.""" @@ -3551,15 +3563,24 @@ def test_wrapped_fbgemm_linear_fp16(self): (2, 4), # batch_size (4, 5), # input_channels (4, 7), # output_channels +<<<<<<< HEAD (True, False), # bias None or not ) for batch_size, input_channels, output_channels, bias_is_none in options: +======= + ) + for batch_size, input_channels, output_channels in options: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pack_op = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16 linear_op = torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight x = torch.randn(batch_size, input_channels) w = torch.randn(output_channels, input_channels) +<<<<<<< HEAD bias = torch.randn(output_channels) if not bias_is_none else None +======= + bias = torch.randn(output_channels) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) w_packed = pack_op(w) out = linear_op(x, w_packed, bias, output_channels) @@ -3593,6 +3614,7 @@ def func(X, W, B): self.assertEqual(ref_out, compiled_out) +<<<<<<< HEAD def func(X, W): packed_W = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(W) return torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight(X, packed_W, None, W.size(0)) @@ -3605,6 +3627,8 @@ def func(X, W): self.assertEqual(ref_out, compiled_out) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Tests the correctness of the dynamic quantized lstm/gru.""" def _get_rnn_inputs(self, seq_len, num_batches, input_size, hidden_size, num_directions, reduce_range): @@ -4551,7 +4575,11 @@ def _test_qlinear_pt2e_helper( qlinear_op, post_op="none", unary_post_op_args=(), +<<<<<<< HEAD post_op_algorithms=("none",), +======= + post_op_algorithms=("none"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): qlinear_prepack = torch.ops.onednn.qlinear_prepack linear_op = F.linear @@ -4718,6 +4746,7 @@ def test_qlinear_add_relu_pt2e(self): qlinear = torch.ops.onednn.qlinear_pointwise.binary self._test_qlinear_pt2e_helper(qlinear, "add_relu") +<<<<<<< HEAD def _test_qlinear_fp8_helper( self, qlinear_op, @@ -4876,6 +4905,8 @@ def test_qlinear_add_relu_fp8(self): qlinear = torch.ops.onednn.qlinear_pointwise.binary self._test_qlinear_fp8_helper(qlinear, "add_relu") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(IS_MACOS, "Known test failure on Mac.") class TestQuantizedEmbeddingOps(TestCase): @@ -7510,10 +7541,17 @@ def test_qconv2d_hardtanh_pt2e(self): qconv_output_dtype=output_dtype, ) +<<<<<<< HEAD # Test qconv with post op swish @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode") @skipIfNoONEDNN def test_qconv2d_swish_pt2e(self): +======= + # Test qconv with post op silu + @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode") + @skipIfNoONEDNN + def test_qconv2d_silu_pt2e(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) input_channels_per_group = 2 output_channels_per_group = 2 groups_list = [1, 10] @@ -7835,6 +7873,7 @@ def test_qconv1d_relu_pt2e(self): qconv_output_dtype=output_dtype, ) +<<<<<<< HEAD def _make_qconv_tensors_fp8( self, batch_size, input_channels_per_group, input_feature_map_shape, output_channels_per_group, groups, kernels, strides, pads, dilations, @@ -8159,6 +8198,8 @@ def test_qconv3d_fp8(self): self._test_qconv_fp8_helper(3, pointwise_post_op) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestPadding(TestCase): @given(batch_size=st.integers(1, 64), diff --git a/test/quantization/core/test_quantized_tensor.py b/test/quantization/core/test_quantized_tensor.py index f241cc4387578..3a325533e9d58 100644 --- a/test/quantization/core/test_quantized_tensor.py +++ b/test/quantization/core/test_quantized_tensor.py @@ -1409,9 +1409,12 @@ def test_choose_qparams_optimized(self): self.assertEqual(y[0].numpy(), ref[0]) self.assertEqual(y[1].numpy(), ref[1]) +<<<<<<< HEAD with self.assertRaisesRegex(ValueError, "input tensor is empty and has no data"): torch.choose_qparams_optimized(torch.tensor([]), numel=0, n_bins=200, ratio=0.16, bit_width=8) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _test_pickle_checkpoint_qtensor(self, device): with TemporaryFileName() as fname: class M(torch.jit.ScriptModule): diff --git a/test/quantization/core/test_workflow_ops.py b/test/quantization/core/test_workflow_ops.py index 4cf34ac8c6c84..a91714ddcb104 100644 --- a/test/quantization/core/test_workflow_ops.py +++ b/test/quantization/core/test_workflow_ops.py @@ -29,7 +29,11 @@ from hypothesis import strategies as st import torch.testing._internal.hypothesis_utils as hu hu.assert_deadline_disabled() +<<<<<<< HEAD from torch.testing._internal.common_cuda import TEST_CUDA, TEST_WITH_ROCM +======= +from torch.testing._internal.common_cuda import TEST_CUDA +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_utils import TestCase, skipIfTorchDynamo # Reference method for fake quantize @@ -1038,6 +1042,7 @@ def test_fake_quantize_per_channel_affine_scale_dtypes(self): input, scale, zero_point, axis, quant_min, quant_max ) +<<<<<<< HEAD @skipIfTorchDynamo("Not a suitable test for TorchDynamo") @unittest.skipIf(TEST_WITH_ROCM, "Not a suitable test for ROCM") @given(dtype=st.sampled_from([torch.float, torch.float64, torch.half, torch.bfloat16]), @@ -1054,6 +1059,8 @@ def test_fake_quantize_per_tensor_affine_inf(self, dtype, device) -> None: ref_result = torch.Tensor([ref_result]).to(dtype).to(device) self.assertEqual(result, ref_result) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestFusedObsFakeQuant(TestCase): @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']), diff --git a/test/quantization/eager/test_quantize_eager_qat.py b/test/quantization/eager/test_quantize_eager_qat.py index c5ce0659f55fa..bda7a75998da2 100644 --- a/test/quantization/eager/test_quantize_eager_qat.py +++ b/test/quantization/eager/test_quantize_eager_qat.py @@ -248,9 +248,15 @@ def from_float(cls, mod, qconfig=None): + cls._FLOAT_MODULE.__name__ ) if not qconfig: +<<<<<<< HEAD assert hasattr(mod, "qconfig"), ( "Input float module must have qconfig defined" ) +======= + assert hasattr( + mod, "qconfig" + ), "Input float module must have qconfig defined" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert mod.qconfig, "Input float module must have a valid qconfig" qconfig = mod.qconfig conv, bn = mod[0], mod[1] diff --git a/test/quantization/fx/test_model_report_fx.py b/test/quantization/fx/test_model_report_fx.py index 80ab0f1e8618e..ffa35bb7876ba 100644 --- a/test/quantization/fx/test_model_report_fx.py +++ b/test/quantization/fx/test_model_report_fx.py @@ -1945,7 +1945,11 @@ def _get_prepped_for_calibration_model_helper(model, detector_set, example_input example_input = example_input.to(torch.float) q_config_mapping = torch.ao.quantization.get_default_qconfig_mapping() +<<<<<<< HEAD # if they passed in fusion parameter, make sure to test that +======= + # if they passed in fusion paramter, make sure to test that +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if fused: model = torch.ao.quantization.fuse_modules(model, model.get_fusion_modules()) diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index e38c56da2a71b..33db4e4990fd7 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -827,7 +827,11 @@ def conv_bn_res_relu_extra_inputs_getter(pattern): named_modules = dict(m.named_modules()) for node in m.graph.nodes: if node.op == "call_module" and type(named_modules[node.target]) == torch.nn.Conv2d: +<<<<<<< HEAD self.assertTrue(len(node.args) == 2, msg="Expecting the fused op to have two arguments") +======= + self.assertTrue(len(node.args) == 2), "Expecting the fused op to have two arguments" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_fusion_pattern_with_matchallnode(self): """This test tests that the node matched by MatchAllNode will be regared as an input @@ -6648,7 +6652,11 @@ class SubModule(nn.Module): """ def __init__(self, input_dim, output_dim): +<<<<<<< HEAD super().__init__() +======= + super(__class__, self).__init__() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.w = nn.Parameter(torch.randn(input_dim, output_dim)) self.b = nn.Parameter(torch.randn(input_dim)) @@ -6661,7 +6669,11 @@ class MainModule(nn.Module): """ def __init__(self, input_dim, hidden_dim, output_dim): +<<<<<<< HEAD super().__init__() +======= + super(__class__, self).__init__() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.submodule_1 = SubModule(hidden_dim, input_dim) setattr(self, 'submodule|2', SubModule(hidden_dim, hidden_dim)) setattr(self, 'submodule/3', SubModule(hidden_dim, hidden_dim)) diff --git a/test/quantization/jit/test_ondevice_quantization.py b/test/quantization/jit/test_ondevice_quantization.py index ce23b155810d8..e181d8271b3e8 100644 --- a/test/quantization/jit/test_ondevice_quantization.py +++ b/test/quantization/jit/test_ondevice_quantization.py @@ -99,17 +99,29 @@ def get_linear_packed_param_fp_weight(node): ): raise ValueError("Quantized weight must be produced.") fp_weight = weight.inputsAt(0).node() +<<<<<<< HEAD assert fp_weight.kind() == "prim::GetAttr", ( "Weight must be an attribute of the module." ) +======= + assert ( + fp_weight.kind() == "prim::GetAttr" + ), "Weight must be an attribute of the module." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fp_weight_name = fp_weight.s("name") return fp_weight_name @staticmethod def is_per_channel_quantized_packed_param(node): +<<<<<<< HEAD assert node.kind() == "quantized::linear_prepack", ( "Node must corresponds to linear_prepack." ) +======= + assert ( + node.kind() == "quantized::linear_prepack" + ), "Node must corresponds to linear_prepack." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) weight = node.inputsAt(0).node() assert ( weight.kind() != "aten::quantize_per_tensor" diff --git a/test/quantization/jit/test_quantize_jit.py b/test/quantization/jit/test_quantize_jit.py index c71f7182b7071..98fb64c7e3bb7 100644 --- a/test/quantization/jit/test_quantize_jit.py +++ b/test/quantization/jit/test_quantize_jit.py @@ -124,7 +124,15 @@ def forward(self, x): "aten::dequantize" ).check_not("aten::quantize_per_channel").check("aten::dequantize").check_next( "aten::conv2d" +<<<<<<< HEAD ).check_next("aten::quantize_per_tensor").check_next("aten::dequantize").run( +======= + ).check_next( + "aten::quantize_per_tensor" + ).check_next( + "aten::dequantize" + ).run( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) freezed.graph ) @@ -666,9 +674,15 @@ def forward(self, x): } assert len(activation_dtypes) == 1, "Expected to have 1 activation dtype" assert len(weight_dtypes) == 1, "Expected to have 1 weight dtype" +<<<<<<< HEAD assert next(iter(activation_dtypes)) != next(iter(weight_dtypes)), ( "Expected activation dtype to " ) +======= + assert next(iter(activation_dtypes)) != next( + iter(weight_dtypes) + ), "Expected activation dtype to " +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) " be different from wegiht dtype" def test_insert_observers_for_reused_weight(self): @@ -702,9 +716,15 @@ def forward(self, x): conv2_observers = attrs_with_prefix(m.conv2, "_observer_") assert len(conv1_observers) == 1, "Expected to have 1 observer submodules" assert len(conv2_observers) == 1, "Expected to have 1 observer submodules" +<<<<<<< HEAD assert conv1_observers == conv2_observers, ( "Expect conv1 and conv2 to have same observers since the class type is shared" ) +======= + assert ( + conv1_observers == conv2_observers + ), "Expect conv1 and conv2 to have same observers since the class type is shared" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_insert_observers_for_general_ops(self): """Make sure we skip observers for ops that doesn't require @@ -730,9 +750,19 @@ def forward(self, x): 'prim::GetAttr[name="conv"]' ).check("prim::CallMethod").check( 'Observer = prim::GetAttr[name="_observer_' +<<<<<<< HEAD ).check("aten::flatten").check_not( 'Observer = prim::GetAttr[name="_observer_' ).run(m.graph) +======= + ).check( + "aten::flatten" + ).check_not( + 'Observer = prim::GetAttr[name="_observer_' + ).run( + m.graph + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO: this is too long, split this to test_insert_observers.py and remove # insrt_observers prefix @@ -762,11 +792,25 @@ def forward(self, x): 'prim::GetAttr[name="conv1"]' ).check("prim::CallMethod").check( 'Observer = prim::GetAttr[name="_observer_' +<<<<<<< HEAD ).check("aten::flatten").check_not( 'Observer = prim::GetAttr[name="_observer_' ).check('prim::GetAttr[name="conv2"]').check( 'Observer = prim::GetAttr[name="_observer_' ).run(m.graph) +======= + ).check( + "aten::flatten" + ).check_not( + 'Observer = prim::GetAttr[name="_observer_' + ).check( + 'prim::GetAttr[name="conv2"]' + ).check( + 'Observer = prim::GetAttr[name="_observer_' + ).run( + m.graph + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_insert_observers_propagate_observed_in_submodule(self): """Make sure we propagate observed property through general ops""" @@ -795,11 +839,25 @@ def forward(self, x): 'prim::GetAttr[name="conv1"]' ).check("prim::CallMethod").check( 'Observer = prim::GetAttr[name="_observer_' +<<<<<<< HEAD ).check("prim::CallMethod").check_not( 'Observer = prim::GetAttr[name="_observer_' ).check('prim::GetAttr[name="conv2"]').check( 'Observer = prim::GetAttr[name="_observer_' ).run(m.graph) +======= + ).check( + "prim::CallMethod" + ).check_not( + 'Observer = prim::GetAttr[name="_observer_' + ).check( + 'prim::GetAttr[name="conv2"]' + ).check( + 'Observer = prim::GetAttr[name="_observer_' + ).run( + m.graph + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_insert_observers_propagate_observed_for_function(self): def channel_shuffle(x: torch.Tensor, groups: int) -> torch.Tensor: @@ -1035,9 +1093,15 @@ def forward(self, x): m(data) m = convert_jit(m, debug=True) +<<<<<<< HEAD assert len(m._modules._c.items()) == 1, ( "Expected to have single submodule of conv" ) +======= + assert ( + len(m._modules._c.items()) == 1 + ), "Expected to have single submodule of conv" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # make sure the quantized model is executable m(data) quant_func = ( @@ -1068,6 +1132,7 @@ def forward(self, x): qconfig_dict = {"": qconfig} m = prepare_jit(m, qconfig_dict) # observers for input, output and value between conv1/conv2 +<<<<<<< HEAD assert len(attrs_with_prefix(m, "_observer_")) == 3, ( "Expected to have 3 obervers" ) @@ -1079,6 +1144,19 @@ def forward(self, x): assert len(attrs_with_prefix(m.conv2, "_observer_")) == 1, ( "Expected to have 1 obervers" ) +======= + assert ( + len(attrs_with_prefix(m, "_observer_")) == 3 + ), "Expected to have 3 obervers" + # observer for weight + assert ( + len(attrs_with_prefix(m.conv1, "_observer_")) == 1 + ), "Expected to have 1 obervers" + # observer for weight + assert ( + len(attrs_with_prefix(m.conv2, "_observer_")) == 1 + ), "Expected to have 1 obervers" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) data = torch.randn(1, 3, 10, 10, dtype=torch.float) m(data) @@ -1087,6 +1165,7 @@ def forward(self, x): assert m.conv1._c._type() == m.conv2._c._type() # check all observers have been removed +<<<<<<< HEAD assert len(attrs_with_prefix(m, "_observer_")) == 0, ( "Expected to have 0 obervers" ) @@ -1096,6 +1175,17 @@ def forward(self, x): assert len(attrs_with_prefix(m.conv2, "_observer_")) == 0, ( "Expected to have 0 obervers" ) +======= + assert ( + len(attrs_with_prefix(m, "_observer_")) == 0 + ), "Expected to have 0 obervers" + assert ( + len(attrs_with_prefix(m.conv1, "_observer_")) == 0 + ), "Expected to have 0 obervers" + assert ( + len(attrs_with_prefix(m.conv2, "_observer_")) == 0 + ), "Expected to have 0 obervers" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) quant_func = ( "aten::quantize_per_channel" @@ -1314,7 +1404,15 @@ def forward(self, x): "aten::avg_pool2d" ).check("aten::q_scale").check_next("aten::q_zero_point").check_next( "prim::dtype" +<<<<<<< HEAD ).check_next("aten::quantize_per_tensor").check("aten::dequantize").run( +======= + ).check_next( + "aten::quantize_per_tensor" + ).check( + "aten::dequantize" + ).run( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) model.graph ) @@ -1733,7 +1831,13 @@ def forward(self, x): "aten::relu" ).check_not(f"quantized::conv{dim}d(").check_not( "quantized::relu(" +<<<<<<< HEAD ).run(m.graph) +======= + ).run( + m.graph + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skipIfNoFBGEMM def test_quantized_add_alpha(self): @@ -1884,7 +1988,13 @@ def forward(self, x, y): "aten::relu(" ).check_not("aten::relu_(").check_not("quantized::add(").check_not( "quantized::relu(" +<<<<<<< HEAD ).run(m.graph) +======= + ).run( + m.graph + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skipIfNoFBGEMM def test_quantized_add(self): @@ -2091,7 +2201,13 @@ def forward(self, x, y): "aten::relu(" ).check_not("aten::relu_(").check_not("quantized::add(").check_not( "quantized::relu(" +<<<<<<< HEAD ).run(m.graph) +======= + ).run( + m.graph + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skipIfNoFBGEMM def test_quantized_add_scalar_relu(self): @@ -2175,7 +2291,15 @@ def forward(self, x): "aten::relu(" ).check_not("aten::relu_(").check_not( "quantized::add_scalar(" +<<<<<<< HEAD ).check_not("quantized::relu(").run(m.graph) +======= + ).check_not( + "quantized::relu(" + ).run( + m.graph + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skipIfNoFBGEMM def test_quantized_cat(self): @@ -2510,7 +2634,13 @@ def forward(self, x, y): "aten::relu(" ).check_not("aten::relu_(").check_not("quantized::mul(").check_not( "quantized::relu(" +<<<<<<< HEAD ).run(m.graph) +======= + ).run( + m.graph + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skipIfNoFBGEMM def test_quantized_mul_scalar_relu(self): @@ -2593,7 +2723,15 @@ def forward(self, x): "aten::relu(" ).check_not("aten::relu_(").check_not( "quantized::mul_scalar(" +<<<<<<< HEAD ).check_not("quantized::relu(").run(m.graph) +======= + ).check_not( + "quantized::relu(" + ).run( + m.graph + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @override_qengines def test_hardswish(self): @@ -3063,7 +3201,13 @@ def forward(self, x): 'Observer = prim::GetAttr[name="_observer_' ).check("prim::CallMethod").check_not( 'Observer = prim::GetAttr[name="_observer_' +<<<<<<< HEAD ).run(m.graph) +======= + ).run( + m.graph + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_insert_quant_dequant_linear_dynamic(self): class M(torch.nn.Module): @@ -3084,9 +3228,15 @@ def forward(self, x): else default_dynamic_qconfig ) m = quantize_dynamic_jit(m, {"": qconfig}, debug=True) +<<<<<<< HEAD assert len(m._modules._c.items()) == 2, ( "Expected to have two submodule of linear" ) +======= + assert ( + len(m._modules._c.items()) == 2 + ), "Expected to have two submodule of linear" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) wt_quant_func = ( "aten::quantize_per_channel" @@ -3099,11 +3249,29 @@ def forward(self, x): act_quant_func ).check_next("aten::dequantize").check( "aten::_choose_qparams_per_tensor" +<<<<<<< HEAD ).check_next(act_quant_func).check_next("aten::dequantize").check( wt_quant_func ).check_next("aten::dequantize").check_not(wt_quant_func).check( "return" ).run(m.graph) +======= + ).check_next( + act_quant_func + ).check_next( + "aten::dequantize" + ).check( + wt_quant_func + ).check_next( + "aten::dequantize" + ).check_not( + wt_quant_func + ).check( + "return" + ).run( + m.graph + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @override_qengines def test_dynamic_multi_op(self): diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py index f9780dbf7b3df..513cde50d19d3 100644 --- a/test/quantization/pt2e/test_quantize_pt2e.py +++ b/test/quantization/pt2e/test_quantize_pt2e.py @@ -254,6 +254,7 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: maxpool_node = node input_act = maxpool_node.args[0] assert isinstance(input_act, Node) +<<<<<<< HEAD maxpool_node.meta["quantization_annotation"] = ( QuantizationAnnotation( input_qspec_map={ @@ -264,6 +265,18 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: ), _annotated=True, ) +======= + maxpool_node.meta[ + "quantization_annotation" + ] = QuantizationAnnotation( + input_qspec_map={ + input_act: act_qspec, + }, + output_qspec=SharedQuantizationSpec( + (input_act, maxpool_node) + ), + _annotated=True, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def validate(self, model: torch.fx.GraphModule) -> None: @@ -339,9 +352,15 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: def derive_qparams_fn( obs_or_fqs: list[ObserverOrFakeQuantize], ) -> tuple[Tensor, Tensor]: +<<<<<<< HEAD assert len(obs_or_fqs) == 2, ( f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}" ) +======= + assert ( + len(obs_or_fqs) == 2 + ), f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) act_obs_or_fq = obs_or_fqs[0] weight_obs_or_fq = obs_or_fqs[1] act_scale, act_zp = act_obs_or_fq.calculate_qparams() @@ -442,9 +461,15 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: def derive_qparams_fn( obs_or_fqs: list[ObserverOrFakeQuantize], ) -> tuple[Tensor, Tensor]: +<<<<<<< HEAD assert len(obs_or_fqs) == 1, ( f"Expecting one weight obs/fq, got: {len(obs_or_fqs)}" ) +======= + assert ( + len(obs_or_fqs) == 1 + ), f"Expecting one weight obs/fq, got: {len(obs_or_fqs)}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) weight_obs_or_fq = obs_or_fqs[0] ( weight_scale, @@ -748,6 +773,7 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: (first_input_node, cat_node) ) for input_node in input_nodes[1:]: +<<<<<<< HEAD input_qspec_map[input_node] = ( share_qparams_with_input_act0_qspec ) @@ -758,6 +784,18 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: output_qspec=share_qparams_with_input_act0_qspec, _annotated=True, ) +======= + input_qspec_map[ + input_node + ] = share_qparams_with_input_act0_qspec + + cat_node.meta[ + "quantization_annotation" + ] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=share_qparams_with_input_act0_qspec, + _annotated=True, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def validate(self, model: torch.fx.GraphModule) -> None: @@ -783,9 +821,15 @@ def validate(self, model: torch.fx.GraphModule) -> None: obs_ins0 = getattr(m, input0.target) obs_ins1 = getattr(m, input1.target) assert obs_ins0 == obs_ins1 +<<<<<<< HEAD assert len(conv_output_obs) == 2, ( "expecting two observer that follows conv2d ops" ) +======= + assert ( + len(conv_output_obs) == 2 + ), "expecting two observer that follows conv2d ops" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # checking that the output observers for the two convs are shared as well assert conv_output_obs[0] == conv_output_obs[1] @@ -850,9 +894,15 @@ def _test_transitive_sharing_with_cat_helper(self, quantizer): obs_ins2 = getattr(m, output_obs.target) assert obs_ins0 == obs_ins2, "input observer does not match output" +<<<<<<< HEAD assert len(conv_output_obs) == 2, ( "expecting two observer that follows conv2d ops" ) +======= + assert ( + len(conv_output_obs) == 2 + ), "expecting two observer that follows conv2d ops" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # checking that the output observers for the two convs are shared as well assert conv_output_obs[0] == conv_output_obs[1] @@ -967,6 +1017,7 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: (first_input_node, cat_node) ) for input_node in input_nodes[1:]: +<<<<<<< HEAD input_qspec_map[input_node] = ( share_qparams_with_input_act0_qspec ) @@ -977,6 +1028,18 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: output_qspec=share_qparams_with_input_act0_qspec, _annotated=True, ) +======= + input_qspec_map[ + input_node + ] = share_qparams_with_input_act0_qspec + + cat_node.meta[ + "quantization_annotation" + ] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=share_qparams_with_input_act0_qspec, + _annotated=True, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def validate(self, model: torch.fx.GraphModule) -> None: @@ -1063,6 +1126,7 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: share_qparams_with_input_act1_qspec = SharedQuantizationSpec( (second_input_node, cat_node) ) +<<<<<<< HEAD input_qspec_map[first_input_node] = ( share_qparams_with_input_act1_qspec ) @@ -1073,6 +1137,18 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: output_qspec=share_qparams_with_input_act1_qspec, _annotated=True, ) +======= + input_qspec_map[ + first_input_node + ] = share_qparams_with_input_act1_qspec + + cat_node.meta[ + "quantization_annotation" + ] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=share_qparams_with_input_act1_qspec, + _annotated=True, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def validate(self, model: torch.fx.GraphModule) -> None: @@ -1121,6 +1197,7 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: share_qparams_with_input_act1_qspec = SharedQuantizationSpec( (second_input_node, add_node) ) +<<<<<<< HEAD input_qspec_map[first_input_node] = ( share_qparams_with_input_act1_qspec ) @@ -1132,6 +1209,19 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: allow_implicit_sharing=False, _annotated=True, ) +======= + input_qspec_map[ + first_input_node + ] = share_qparams_with_input_act1_qspec + + add_node.meta[ + "quantization_annotation" + ] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=share_qparams_with_input_act1_qspec, + allow_implicit_sharing=False, + _annotated=True, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def validate(self, model: torch.fx.GraphModule) -> None: diff --git a/test/quantization/pt2e/test_quantize_pt2e_qat.py b/test/quantization/pt2e/test_quantize_pt2e_qat.py index 98682dc14e079..bbfb1f1fed1f4 100644 --- a/test/quantization/pt2e/test_quantize_pt2e_qat.py +++ b/test/quantization/pt2e/test_quantize_pt2e_qat.py @@ -277,9 +277,15 @@ def _verify_symmetric_xnnpack_qat_graph_helper( # Verify: conv literal args if expected_conv_literal_args is not None: +<<<<<<< HEAD assert len(expected_conv_literal_args) == 6, ( "wrong num conv args, bad test setup" ) +======= + assert ( + len(expected_conv_literal_args) == 6 + ), "wrong num conv args, bad test setup" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for i in range(6): if i + 3 < len(conv_node.args): self.assertEqual( diff --git a/test/run_test.py b/test/run_test.py index 44a15d4ab2c68..981f94a68d6fb 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -12,17 +12,28 @@ import signal import subprocess import sys +<<<<<<< HEAD import sysconfig +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import tempfile import time from collections import defaultdict from collections.abc import Sequence from contextlib import ExitStack from datetime import datetime +<<<<<<< HEAD from importlib.metadata import PackageNotFoundError, version from pathlib import Path from typing import Any, cast, NamedTuple, Optional, Union +======= +from pathlib import Path +from typing import Any, cast, NamedTuple, Optional, Union + +import pkg_resources + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch.distributed as dist from torch.multiprocessing import current_process, get_context @@ -36,6 +47,10 @@ TEST_CUDA, TEST_SAVE_XML, TEST_WITH_ASAN, +<<<<<<< HEAD +======= + TEST_WITH_CROSSREF, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TEST_WITH_ROCM, TEST_WITH_SLOW_GRADCHECK, ) @@ -182,19 +197,46 @@ def __contains__(self, item): "dynamo/test_misc", "inductor/test_cpu_repro", "inductor/test_cpu_select_algorithm", +<<<<<<< HEAD +======= + "inductor/test_aot_inductor_arrayref", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "inductor/test_torchinductor_codegen_dynamic_shapes", "lazy/test_meta_kernel", "onnx/test_utility_funs", "profiler/test_profiler", +<<<<<<< HEAD "test_jit", "dynamo/test_utils", "test_nn", +======= + "test_ao_sparsity", + "test_cpp_extensions_open_device_registration", + "test_jit", + "test_metal", + "test_mps", + "dynamo/test_torchrec", + "inductor/test_aot_inductor_utils", + "inductor/test_coordinate_descent_tuner", + "test_jiterator", + "inductor/test_cpu_cpp_wrapper", + "export/test_converter", + "inductor/test_inductor_freezing", + "dynamo/test_utils", + "test_nn", + "functorch/test_ops", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # these tests run long and fail in addition to that "dynamo/test_dynamic_shapes", "test_quantization", "inductor/test_torchinductor", "inductor/test_torchinductor_dynamic_shapes", "inductor/test_torchinductor_opinfo", +<<<<<<< HEAD +======= + "test_binary_ufuncs", + "test_unary_ufuncs", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # these tests fail when cuda is not available "inductor/test_aot_inductor", "inductor/test_best_config", @@ -213,12 +255,18 @@ def __contains__(self, item): # these tests fail when mkldnn is not available "inductor/test_custom_post_grad_passes", "inductor/test_mkldnn_pattern_matcher", +<<<<<<< HEAD "test_metal", # lacks quantization support "onnx/test_models_quantized_onnxruntime", "onnx/test_pytorch_onnx_onnxruntime", # sysctl -n hw.memsize is not available "test_mps", +======= + # lacks quantization support + "onnx/test_models_quantized_onnxruntime", + "onnx/test_pytorch_onnx_onnxruntime", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # https://github.com/pytorch/pytorch/issues/102078 "test_decomp", # https://github.com/pytorch/pytorch/issues/146698 @@ -229,6 +277,10 @@ def __contains__(self, item): # some false errors "doctests", # new failures to investigate and fix +<<<<<<< HEAD +======= + "cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "test_tensorboard", # onnx + protobuf failure, see # https://github.com/protocolbuffers/protobuf/issues/22104 @@ -237,9 +289,12 @@ def __contains__(self, item): "inductor/test_config", "test_public_bindings", "test_testing", +<<<<<<< HEAD # depend on z3-solver "fx/test_z3_gradual_types", "test_proxy_tensor", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] XPU_BLOCKLIST = [ @@ -251,7 +306,10 @@ def __contains__(self, item): "profiler/test_profiler_tree", "profiler/test_record_function", "profiler/test_torch_tidy", +<<<<<<< HEAD "test_openreg", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] XPU_TEST = [ @@ -262,6 +320,10 @@ def __contains__(self, item): RUN_PARALLEL_BLOCKLIST = [ "test_extension_utils", "test_cpp_extensions_jit", +<<<<<<< HEAD +======= + "test_cpp_extensions_open_device_registration", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "test_cpp_extensions_stream_and_event", "test_cpp_extensions_mtia_backend", "test_jit_disabled", @@ -649,6 +711,7 @@ def run_test( return ret_code +<<<<<<< HEAD def install_cpp_extensions(extensions_dir, env=os.environ): # Wipe the build folder, if it exists already build_dir = os.path.join(extensions_dir, "build") @@ -676,6 +739,28 @@ def install_cpp_extensions(extensions_dir, env=os.environ): platlib_path, os.path.splitdrive(platlib_path)[0] + os.sep ) install_directory = os.path.join(extensions_dir, "install", platlib_rel) +======= +def install_cpp_extensions(cpp_extensions_test_dir, env=os.environ): + # Wipe the build folder, if it exists already + cpp_extensions_test_build_dir = os.path.join(cpp_extensions_test_dir, "build") + if os.path.exists(cpp_extensions_test_build_dir): + shutil.rmtree(cpp_extensions_test_build_dir) + + # Build the test cpp extensions modules + cmd = [sys.executable, "setup.py", "install", "--root", "./install"] + return_code = shell(cmd, cwd=cpp_extensions_test_dir, env=env) + if return_code != 0: + return None, return_code + + install_directory = "" + # install directory is the one that is named site-packages + for root, directories, _ in os.walk( + os.path.join(cpp_extensions_test_dir, "install") + ): + for directory in directories: + if "-packages" in directory: + install_directory = os.path.join(root, directory) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert install_directory, "install_directory must not be empty" return install_directory, 0 @@ -822,6 +907,7 @@ def _test_cpp_extensions_aot(test_directory, options, use_ninja): # Build the test cpp extensions modules shell_env = os.environ.copy() shell_env["USE_NINJA"] = str(1 if use_ninja else 0) +<<<<<<< HEAD install_cmd = [ sys.executable, "-m", @@ -833,6 +919,10 @@ def _test_cpp_extensions_aot(test_directory, options, use_ninja): "./install", ] wheel_cmd = [sys.executable, "-m", "pip", "wheel", ".", "-w", "./dist"] +======= + install_cmd = [sys.executable, "setup.py", "install", "--root", "./install"] + wheel_cmd = [sys.executable, "setup.py", "bdist_wheel"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return_code = shell(install_cmd, cwd=cpp_extensions_test_dir, env=shell_env) if return_code != 0: return return_code @@ -919,7 +1009,11 @@ def _test_autoload(test_directory, options, enable=True): def run_test_with_openreg(test_module, test_directory, options): openreg_dir = os.path.join( +<<<<<<< HEAD test_directory, "cpp_extensions", "open_registration_extension", "torch_openreg" +======= + test_directory, "cpp_extensions", "open_registration_extension" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) install_dir, return_code = install_cpp_extensions(openreg_dir) if return_code != 0: @@ -1255,6 +1349,10 @@ def run_ci_sanity_check(test: ShardedTest, test_directory, options): "test_ci_sanity_check_fail": run_ci_sanity_check, "test_autoload_enable": test_autoload_enable, "test_autoload_disable": test_autoload_disable, +<<<<<<< HEAD +======= + "test_cpp_extensions_open_device_registration": run_test_with_openreg, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "test_openreg": run_test_with_openreg, "test_transformers_privateuse1": run_test_with_openreg, } @@ -1409,6 +1507,7 @@ def parse_args(): action="store_true", help="Enables removing tests based on TD", default=IS_CI +<<<<<<< HEAD and get_pr_number() is not None and not strtobool(os.environ.get("NO_TD", "False")) and not IS_MACOS @@ -1416,6 +1515,20 @@ def parse_args(): and "onnx" not in BUILD_ENVIRONMENT and os.environ.get("GITHUB_WORKFLOW", "slow") in ("trunk", "pull", "rocm", "rocm-mi300"), +======= + and ( + TEST_WITH_CROSSREF + or TEST_CONFIG == "distributed" + or TEST_CONFIG == "default" + ) + and get_pr_number() is not None + and not strtobool(os.environ.get("NO_TD", "False")) + and not TEST_WITH_ROCM + and not IS_MACOS + and "xpu" not in BUILD_ENVIRONMENT + and "onnx" not in BUILD_ENVIRONMENT + and os.environ.get("GITHUB_WORKFLOW", "slow") in ("trunk", "pull"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) parser.add_argument( "--shard", @@ -1473,7 +1586,10 @@ def parse_args(): parser.add_argument( "--upload-artifacts-while-running", action="store_true", +<<<<<<< HEAD default=IS_CI, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) group = parser.add_mutually_exclusive_group() @@ -1560,7 +1676,11 @@ def get_selected_tests(options) -> list[str]: if options.einops: selected_tests = list( filter( +<<<<<<< HEAD lambda test_name: test_name.startswith("dynamo/test_einops"), +======= + lambda test_name: test_name.startswith("test/dynamo/test_einops"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) selected_tests, ) ) @@ -1586,8 +1706,11 @@ def get_selected_tests(options) -> list[str]: "test_nn", "inductor/test_mps_basic", "inductor/test_torchinductor", +<<<<<<< HEAD "inductor/test_aot_inductor", "inductor/test_torchinductor_dynamic_shapes", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] else: # Exclude all mps tests otherwise @@ -1861,7 +1984,10 @@ def handle_complete(failure: Optional[TestFailure]): "If running on CI, add the 'keep-going' label to your PR and rerun your jobs." ) +<<<<<<< HEAD pool = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: for test in selected_tests_serial: options_clone = copy.deepcopy(options) @@ -1924,9 +2050,14 @@ def parallel_test_completion_callback(failure): del os.environ["NUM_PARALLEL_PROCS"] finally: +<<<<<<< HEAD if pool: pool.terminate() pool.join() +======= + pool.terminate() + pool.join() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return @@ -1937,6 +2068,7 @@ def check_pip_packages() -> None: "pytest-flakefinder", "pytest-xdist", ] +<<<<<<< HEAD try: for pkg in packages: version(pkg) @@ -1945,6 +2077,15 @@ def check_pip_packages() -> None: f"Missing pip dependency: {pkg}, please run `pip install -r .ci/docker/requirements-ci.txt`" ) sys.exit(1) +======= + installed_packages = [i.key for i in pkg_resources.working_set] + for package in packages: + if package not in installed_packages: + print_to_stderr( + f"Missing pip dependency: {package}, please run `pip install -r .ci/docker/requirements-ci.txt`" + ) + sys.exit(1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def main(): diff --git a/test/scripts/run_cuda_memcheck.py b/test/scripts/run_cuda_memcheck.py index ca3196f4f4910..4aad8e57dcb53 100755 --- a/test/scripts/run_cuda_memcheck.py +++ b/test/scripts/run_cuda_memcheck.py @@ -157,9 +157,15 @@ async def run1(coroutine_id): gpuid = coroutine_id % GPUS else: gpu_assignments = args.gpus.split(":") +<<<<<<< HEAD assert args.nproc == len(gpu_assignments), ( "Please specify GPU assignment for each process, separated by :" ) +======= + assert args.nproc == len( + gpu_assignments + ), "Please specify GPU assignment for each process, separated by :" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) gpuid = gpu_assignments[coroutine_id] while progress < len(ALL_TESTS): diff --git a/test/slow_tests.json b/test/slow_tests.json index cd9d6864f0ec4..3713323914a0a 100644 --- a/test/slow_tests.json +++ b/test/slow_tests.json @@ -1,4 +1,5 @@ { +<<<<<<< HEAD "EndToEndLSTM (__main__.RNNTest)": 194.9510040283203, "MultiheadAttention (__main__.ModulesTest)": 140.13499959309897, "test__adaptive_avg_pool2d (__main__.CPUReproTests)": 89.57710986667209, @@ -241,4 +242,266 @@ "test_vmapvjpvjp_meshgrid_list_of_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 140.70899963378906, "test_vmapvjpvjp_meshgrid_variadic_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 118.22750091552734, "test_vmapvjpvjp_nn_functional_bilinear_cuda_float32 (__main__.TestOperatorsCUDA)": 181.27366256713867 +======= + "EndToEndLSTM (__main__.RNNTest)": 184.65333048502603, + "MultiheadAttention (__main__.ModulesTest)": 134.43099975585938, + "test_AllenaiLongformerBase_repro_cpu_halide (__main__.HalideCpuTests)": 199.10467020670572, + "test__adaptive_avg_pool2d (__main__.CPUReproTests)": 83.39333131578233, + "test_adaptive_max_pool2d1_cpu_halide (__main__.HalideCpuTests)": 113.98933410644531, + "test_after_aot_cpu_runtime_error (__main__.MinifierIsolateTests)": 61.397444831000435, + "test_alexnet_prefix_cpu_halide (__main__.HalideCpuTests)": 176.93266805013022, + "test_aot_autograd_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 64.99899800618489, + "test_aot_autograd_symbolic_exhaustive_linalg_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 66.08271371750604, + "test_aot_autograd_symbolic_exhaustive_masked_norm_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 61.71266555786133, + "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 151.31399536132812, + "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 267.58533732096356, + "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 120.89933013916016, + "test_aot_autograd_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 73.94028554643903, + "test_aot_autograd_symbolic_module_exhaustive_nn_TransformerDecoderLayer_cpu_float32 (__main__.TestEagerFusionModuleInfoCPU)": 112.47666422526042, + "test_avg_pool3d_backward2_cpu (__main__.CpuTests)": 609.4812072753906, + "test_avg_pool3d_backward2_cuda (__main__.GPUTests)": 158.25587558746338, + "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 502.05988226996527, + "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 494.381110297309, + "test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 124.20333480834961, + "test_avg_pool3d_backward_cpu_halide (__main__.HalideCpuTests)": 61.64700063069662, + "test_backward_nn_functional_multi_head_attention_forward_cpu_float32 (__main__.TestCompositeComplianceCPU)": 71.78066380818684, + "test_backward_nn_functional_multi_head_attention_forward_cuda_float32 (__main__.TestCompositeComplianceCUDA)": 78.40683364868164, + "test_basic_cpu (__main__.EfficientConvBNEvalCpuTests)": 250.50655958387586, + "test_basic_cuda (__main__.EfficientConvBNEvalGpuTests)": 145.54050064086914, + "test_checkpointing_without_reentrant_input_requires_grad_False (__main__.TestAutogradWithCompiledAutograd)": 327.4082217746311, + "test_checkpointing_without_reentrant_input_requires_grad_True (__main__.TestAutogradWithCompiledAutograd)": 409.865227593316, + "test_collect_callgrind (__main__.TestBenchmarkUtils)": 310.50811258951825, + "test_comprehensive_diff_cuda_complex128 (__main__.TestDecompCUDA)": 90.77466710408528, + "test_comprehensive_diff_cuda_complex64 (__main__.TestDecompCUDA)": 88.94400024414062, + "test_comprehensive_diff_cuda_float64 (__main__.TestDecompCUDA)": 61.99116643269857, + "test_comprehensive_grid_sampler_2d_cpu_bfloat16 (__main__.TestDecompCPU)": 89.07300059000652, + "test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestDecompCPU)": 98.6163330078125, + "test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 65.7913335164388, + "test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestDecompCPU)": 400.17799886067706, + "test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 65.32166544596355, + "test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestDecompCPU)": 433.8283386230469, + "test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 65.70300038655598, + "test_comprehensive_grid_sampler_2d_cuda_bfloat16 (__main__.TestDecompCUDA)": 246.12633005777994, + "test_comprehensive_grid_sampler_2d_cuda_float16 (__main__.TestDecompCUDA)": 237.4903361002604, + "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestDecompCUDA)": 1256.5741882324219, + "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 68.78149922688802, + "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestDecompCUDA)": 1055.0651448567708, + "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 69.93966611226399, + "test_comprehensive_linalg_lu_solve_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 65.20016670227051, + "test_comprehensive_linalg_lu_solve_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 61.16316668192545, + "test_comprehensive_linalg_solve_triangular_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 62.08466657002767, + "test_comprehensive_linalg_solve_triangular_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 62.160666147867836, + "test_comprehensive_linalg_svd_cuda_complex128 (__main__.TestDecompCUDA)": 65.54600079854329, + "test_comprehensive_linalg_vector_norm_cpu_float16 (__main__.TestInductorOpInfoCPU)": 85.31400044759114, + "test_comprehensive_linalg_vector_norm_cpu_float32 (__main__.TestInductorOpInfoCPU)": 86.7923355102539, + "test_comprehensive_linalg_vector_norm_cpu_float64 (__main__.TestInductorOpInfoCPU)": 83.80366770426433, + "test_comprehensive_linalg_vector_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 65.01507412945783, + "test_comprehensive_linalg_vector_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 66.07433333220305, + "test_comprehensive_logspace_cpu_float32 (__main__.TestInductorOpInfoCPU)": 169.19166564941406, + "test_comprehensive_logspace_cpu_float64 (__main__.TestInductorOpInfoCPU)": 164.14199829101562, + "test_comprehensive_logspace_cpu_int32 (__main__.TestInductorOpInfoCPU)": 167.1233367919922, + "test_comprehensive_logspace_cpu_int64 (__main__.TestInductorOpInfoCPU)": 161.9933319091797, + "test_comprehensive_masked_norm_cpu_float16 (__main__.TestInductorOpInfoCPU)": 204.7566680908203, + "test_comprehensive_masked_norm_cpu_float32 (__main__.TestInductorOpInfoCPU)": 202.51532999674478, + "test_comprehensive_masked_norm_cpu_float64 (__main__.TestInductorOpInfoCPU)": 205.77066548665366, + "test_comprehensive_masked_norm_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 114.11033376057942, + "test_comprehensive_masked_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 105.25066757202148, + "test_comprehensive_masked_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 113.67999903361003, + "test_comprehensive_nn_functional_fractional_max_pool3d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 101.1036114162869, + "test_comprehensive_nn_functional_fractional_max_pool3d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 94.08183288574219, + "test_comprehensive_nn_functional_fractional_max_pool3d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 94.20638847351074, + "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestDecompCUDA)": 93.08233388264973, + "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float64 (__main__.TestDecompCUDA)": 94.11516571044922, + "test_comprehensive_nn_functional_grid_sample_cpu_float32 (__main__.TestDecompCPU)": 107.86000061035156, + "test_comprehensive_nn_functional_grid_sample_cpu_float64 (__main__.TestDecompCPU)": 94.72633361816406, + "test_comprehensive_nn_functional_grid_sample_cuda_float32 (__main__.TestDecompCUDA)": 284.54283142089844, + "test_comprehensive_nn_functional_grid_sample_cuda_float64 (__main__.TestDecompCUDA)": 228.18283081054688, + "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float32 (__main__.TestDecompCUDA)": 77.24066543579102, + "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float64 (__main__.TestDecompCUDA)": 77.22533416748047, + "test_comprehensive_nn_functional_max_pool1d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 153.27567036946616, + "test_comprehensive_nn_functional_max_pool1d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 151.73899841308594, + "test_comprehensive_nn_functional_max_pool1d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 137.59866841634116, + "test_comprehensive_nn_functional_max_pool2d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 1176.6233723958333, + "test_comprehensive_nn_functional_max_pool2d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 1034.320332845052, + "test_comprehensive_nn_functional_max_pool2d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 1053.9040120442708, + "test_comprehensive_nn_functional_max_pool2d_cpu_int32 (__main__.TestInductorOpInfoCPU)": 901.5313517252604, + "test_comprehensive_nn_functional_max_pool2d_cpu_int64 (__main__.TestInductorOpInfoCPU)": 914.4829915364584, + "test_comprehensive_nn_functional_max_pool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 1132.8611653645833, + "test_comprehensive_nn_functional_max_pool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 1129.974344889323, + "test_comprehensive_nn_functional_max_pool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 1135.6740112304688, + "test_comprehensive_nn_functional_max_pool3d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 891.2769978841146, + "test_comprehensive_nn_functional_max_pool3d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 687.6756591796875, + "test_comprehensive_nn_functional_max_pool3d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 683.6936645507812, + "test_comprehensive_nn_functional_max_pool3d_cpu_int32 (__main__.TestInductorOpInfoCPU)": 678.6616617838541, + "test_comprehensive_nn_functional_max_pool3d_cpu_int64 (__main__.TestInductorOpInfoCPU)": 701.6133422851562, + "test_comprehensive_nn_functional_max_pool3d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 495.5906626383464, + "test_comprehensive_nn_functional_max_pool3d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 487.7074940999349, + "test_comprehensive_nn_functional_max_unpool2d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 115.73200225830078, + "test_comprehensive_nn_functional_max_unpool2d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 118.66033426920573, + "test_comprehensive_nn_functional_max_unpool2d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 115.82266743977864, + "test_comprehensive_nn_functional_max_unpool3d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 67.43566640218098, + "test_comprehensive_nn_functional_unfold_cpu_bool (__main__.TestInductorOpInfoCPU)": 68.42166900634766, + "test_comprehensive_nn_functional_unfold_cpu_float16 (__main__.TestInductorOpInfoCPU)": 118.02966817220052, + "test_comprehensive_nn_functional_unfold_cpu_float32 (__main__.TestInductorOpInfoCPU)": 105.94366709391277, + "test_comprehensive_nn_functional_unfold_cpu_float64 (__main__.TestInductorOpInfoCPU)": 118.99266815185547, + "test_comprehensive_ormqr_cuda_complex128 (__main__.TestDecompCUDA)": 115.5125020345052, + "test_comprehensive_ormqr_cuda_complex64 (__main__.TestDecompCUDA)": 103.90849939982097, + "test_comprehensive_ormqr_cuda_float32 (__main__.TestDecompCUDA)": 66.59218077226119, + "test_comprehensive_ormqr_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 64.84800084431966, + "test_comprehensive_ormqr_cuda_float64 (__main__.TestDecompCUDA)": 60.27900060017904, + "test_comprehensive_svd_cuda_complex128 (__main__.TestDecompCUDA)": 68.57966613769531, + "test_comprehensive_svd_cuda_complex64 (__main__.TestDecompCUDA)": 66.81166776021321, + "test_compute_global_tensor_shape_1D_invalid_shape (__main__.UtilTest)": 209.35732873280844, + "test_constructor_autograd_SparseBSC_cuda (__main__.TestSparseAnyCUDA)": 154.30916849772134, + "test_constructor_autograd_SparseBSR_cuda (__main__.TestSparseAnyCUDA)": 142.58683141072592, + "test_constructor_autograd_SparseCSC_cuda (__main__.TestSparseAnyCUDA)": 94.73116620381673, + "test_constructor_autograd_SparseCSR_cuda (__main__.TestSparseAnyCUDA)": 110.29800033569336, + "test_conv1d_basic (__main__.TestXNNPACKConv1dTransformPass)": 244.17077806260852, + "test_conv1d_with_relu_fc (__main__.TestXNNPACKConv1dTransformPass)": 627.981665717231, + "test_conv2d_unary_cpu_cpp_wrapper (__main__.TestCppWrapper)": 68.8806660970052, + "test_conv3d_binary_broadcast_shapes_cpu_cpu (__main__.TestPatternMatcherGenericCPU)": 75.51066589355469, + "test_correctness_AdamW_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 78.39416631062825, + "test_correctness_Adam_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 74.26416778564453, + "test_count_nonzero_all (__main__.TestBool)": 630.1393364800347, + "test_custom_module_lstm (__main__.TestQuantizedOps)": 666.0326605902778, + "test_dispatch_symbolic_meta_outplace_all_strides_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestMetaCUDA)": 84.40749867757161, + "test_dtensor_op_db_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestDTensorOpsCPU)": 88.80566660563152, + "test_eig_check_magma_cuda_float32 (__main__.TestLinalgCUDA)": 153.85249682267508, + "test_error_detection_and_propagation (__main__.NcclErrorHandlingTest)": 67.68433125813802, + "test_fail_arithmetic_ops.py (__main__.TestTyping)": 64.70655483669705, + "test_fail_creation_ops.py (__main__.TestTyping)": 70.33796894550323, + "test_fn_fwgrad_bwgrad_cumprod_cuda_complex128 (__main__.TestFwdGradientsCUDA)": 73.33583068847656, + "test_fn_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 95.88233311971028, + "test_fn_gradgrad_map_nested_cpu_float64 (__main__.TestBwdGradientsCPU)": 84.52066802978516, + "test_fn_gradgrad_map_triple_nested_cpu_float64 (__main__.TestBwdGradientsCPU)": 518.5540161132812, + "test_fn_gradgrad_map_triple_nested_cuda_float64 (__main__.TestBwdGradientsCUDA)": 352.0611623128255, + "test_fuse_large_params_cpu (__main__.CpuTests)": 98.19175052642822, + "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 172.9732191297743, + "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 178.04811265733508, + "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 96.32300059000652, + "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 94.25100072224934, + "test_grad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 110.52466583251953, + "test_gradgrad_nn_LSTM_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 147.46899922688803, + "test_gradgrad_nn_LSTM_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 137.17833455403647, + "test_gradgrad_nn_TransformerDecoderLayer_cuda_float64 (__main__.TestModuleCUDA)": 223.40133412679037, + "test_gradgrad_nn_TransformerEncoder_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 130.75699996948242, + "test_gradgrad_nn_TransformerEncoder_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 159.8721669514974, + "test_gradgrad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 661.1241658528646, + "test_grid_sampler_2d_cpu_halide (__main__.HalideCpuTests)": 196.1066640218099, + "test_group_norm (__main__.TestQuantizedOps)": 143.82022105322943, + "test_indirect_device_assert (__main__.TritonCodeGenTests)": 252.9750010172526, + "test_inductor_no_recursionerror_on_for_loops_dynamic_shapes (__main__.DynamicShapesReproTests)": 68.59622192382812, + "test_inplace_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 132.5279998779297, + "test_inputs_overlapping_with_mutation_stress_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 151.57311164008246, + "test_jit_cuda_archflags (__main__.TestCppExtensionJIT)": 117.37533315022786, + "test_linalg_solve_triangular_large_cuda_complex128 (__main__.TestLinalgCUDA)": 577.0678304036459, + "test_linalg_solve_triangular_large_cuda_complex64 (__main__.TestLinalgCUDA)": 72.07283401489258, + "test_linear (__main__.TestStaticQuantizedModule)": 178.05622397528754, + "test_linear_relu (__main__.TestStaticQuantizedModule)": 64.9945551554362, + "test_lobpcg_ortho_cuda_float64 (__main__.TestLinalgCUDA)": 83.73499965667725, + "test_lstm_cpu (__main__.TestMkldnnCPU)": 66.0846659342448, + "test_many_overlapping_inputs_does_not_explode_guards_dynamic_shapes (__main__.DynamicShapesReproTests)": 125.42355600992839, + "test_max_pool2d2_cpu_halide (__main__.HalideCpuTests)": 445.62599690755206, + "test_max_pool2d3_cpu_halide (__main__.HalideCpuTests)": 134.19500223795572, + "test_max_pool2d5_cpu_halide (__main__.HalideCpuTests)": 363.20066324869794, + "test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 63.19877794053819, + "test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 61.39377763536241, + "test_proper_exit (__main__.TestDataLoader)": 240.04466501871744, + "test_proper_exit (__main__.TestDataLoaderPersistentWorkers)": 271.00699615478516, + "test_python_ref_executor__refs_special_zeta_executor_aten_cuda_float64 (__main__.TestCommonCUDA)": 64.18233426411946, + "test_qat_conv2d_unary (__main__.TestQuantizePT2EX86Inductor)": 151.71777767605252, + "test_qat_conv_bn_fusion_no_conv_bias (__main__.TestQuantizePT2EQAT_ConvBn1d)": 61.14148919847276, + "test_qat_conv_bn_fusion_no_conv_bias (__main__.TestQuantizePT2EQAT_ConvBn2d)": 60.4263552347819, + "test_qat_mobilenet_v2 (__main__.TestQuantizePT2EQATModels)": 88.72544479370117, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 69.56600189208984, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 82.00166829427083, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 78.14999898274739, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 68.93766784667969, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 75.8633321126302, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 78.89766947428386, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 67.93033345540364, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 76.1066665649414, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 77.59533437093098, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True (__main__.TestPatternMatcher)": 70.57233174641927, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 86.69966634114583, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 82.32333374023438, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 69.6453348795573, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 79.38400014241536, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 80.18400065104167, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 71.49599965413411, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 78.35600026448567, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 82.9933344523112, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 71.89866892496745, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 75.72566731770833, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 80.28999837239583, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 80.68799845377605, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 85.98066711425781, + "test_quick_core_backward__unsafe_masked_index_cpu_float64 (__main__.TestDecompCPU)": 418.50034586588544, + "test_quick_core_backward__unsafe_masked_index_cuda_float64 (__main__.TestDecompCUDA)": 842.5636698404948, + "test_quick_core_backward__unsafe_masked_index_put_accumulate_cpu_float64 (__main__.TestDecompCPU)": 658.1936645507812, + "test_quick_core_backward__unsafe_masked_index_put_accumulate_cuda_float64 (__main__.TestDecompCUDA)": 1321.1958414713542, + "test_quick_core_backward_expand_copy_cuda_float64 (__main__.TestDecompCUDA)": 72.79183260599773, + "test_quick_core_backward_nn_functional_max_unpool3d_grad_cpu_float64 (__main__.TestDecompCPU)": 68.16699981689453, + "test_quick_core_backward_nn_functional_max_unpool3d_grad_cuda_float64 (__main__.TestDecompCUDA)": 222.59966786702475, + "test_quick_core_backward_roll_cpu_float64 (__main__.TestDecompCPU)": 89.49299875895183, + "test_quick_core_backward_roll_cuda_float64 (__main__.TestDecompCUDA)": 208.05382792154947, + "test_quick_core_backward_select_scatter_cpu_float64 (__main__.TestDecompCPU)": 61.09833272298177, + "test_quick_core_backward_select_scatter_cuda_float64 (__main__.TestDecompCUDA)": 119.15299987792969, + "test_quick_core_backward_split_with_sizes_copy_cpu_float64 (__main__.TestDecompCPU)": 72.5490010579427, + "test_quick_core_backward_split_with_sizes_copy_cuda_float64 (__main__.TestDecompCUDA)": 137.61000188191733, + "test_quick_core_backward_std_cuda_float64 (__main__.TestDecompCUDA)": 83.77516682942708, + "test_register_spills_cuda (__main__.BenchmarkFusionCudaTest)": 112.9426663716634, + "test_replicatepad_64bit_indexing_cuda_float16 (__main__.TestNNDeviceTypeCUDA)": 68.61433410644531, + "test_rosenbrock_sparse_with_lrsched_False_SGD_cuda_float64 (__main__.TestOptimRenewedCUDA)": 71.73550089200337, + "test_rosenbrock_sparse_with_lrsched_True_SGD_cuda_float64 (__main__.TestOptimRenewedCUDA)": 66.45991698900859, + "test_runtime_checks_large_cpu (__main__.AOTInductorTestABICompatibleCpu)": 60.68633270263672, + "test_runtime_checks_large_cpu_with_stack_allocation (__main__.AOTInductorTestABICompatibleCpuWithStackAllocation)": 74.52111011081271, + "test_runtime_checks_large_cuda (__main__.AOTInductorTestABICompatibleGpu)": 156.46233622233072, + "test_save_load_large_string_attribute (__main__.TestSaveLoad)": 128.3509979248047, + "test_shuffler_iterdatapipe (__main__.IntegrationTestDataLoaderDataPipe)": 148.15933481852213, + "test_slow_tasks (__main__.TestFunctionalAutogradBenchmark)": 145.64644877115884, + "test_sort_stable_cpu (__main__.CpuTritonTests)": 76.39066569010417, + "test_split_cumsum_cpu (__main__.CpuTritonTests)": 89.5290018717448, + "test_svd_lowrank_cuda_complex128 (__main__.TestLinalgCUDA)": 150.72099796930948, + "test_tensor_split (__main__.TestVmapOperators)": 72.26428134347766, + "test_terminate_handler_on_crash (__main__.TestTorch)": 100.98866719669766, + "test_terminate_signal (__main__.ForkTest)": 134.33088995267948, + "test_terminate_signal (__main__.ParallelForkServerShouldWorkTest)": 133.97255667547384, + "test_terminate_signal (__main__.SpawnTest)": 137.73455943001642, + "test_torch_distributions_functions_dynamic_shapes (__main__.DynamicShapesFunctionTests)": 193.52591840426126, + "test_torchvision_smoke (__main__.TestTensorBoardPytorchGraph)": 144.84678077697754, + "test_train_parity_multi_group_unshard_async_op (__main__.TestFullyShard1DTrainingCore)": 62.523999532063804, + "test_transformer_backend_inductor_fullgraph_True (__main__.TestFullyShardCompile)": 82.06791687011719, + "test_transformer_backend_inductor_fullgraph_True_graph_partition (__main__.TestFullyShardCompile)": 82.57758394877116, + "test_triton_bsr_scatter_mm_blocksize_64_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 93.72849909464519, + "test_triton_bsr_scatter_mm_blocksize_64_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 86.33483123779297, + "test_triton_bsr_scatter_mm_blocksize_64_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 84.0580005645752, + "test_triton_bsr_softmax_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 128.47150166829428, + "test_triton_bsr_softmax_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 125.92099952697754, + "test_triton_bsr_softmax_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 105.98566563924153, + "test_unary_ops (__main__.TestTEFuserDynamic)": 173.52266354031033, + "test_unary_ops (__main__.TestTEFuserStatic)": 154.03555562761096, + "test_upsample_bicubic2d_cpu_halide (__main__.HalideCpuTests)": 95.91699727376302, + "test_variant_consistency_jit_nn_functional_max_pool2d_cpu_float32 (__main__.TestJitCPU)": 91.32800038655598, + "test_variant_consistency_jit_nn_functional_max_pool2d_cuda_float32 (__main__.TestJitCUDA)": 72.65949885050456, + "test_vmapjvpvjp_diff_cuda_float32 (__main__.TestOperatorsCUDA)": 64.64249992370605, + "test_vmapjvpvjp_linalg_lstsq_grad_oriented_cpu_float32 (__main__.TestOperatorsCPU)": 114.75466410319011, + "test_vmapjvpvjp_linalg_lu_solve_cpu_float32 (__main__.TestOperatorsCPU)": 61.643143063499814, + "test_vmapjvpvjp_linalg_lu_solve_cuda_float32 (__main__.TestOperatorsCUDA)": 76.99316660563152, + "test_vmapjvpvjp_linalg_multi_dot_cuda_float32 (__main__.TestOperatorsCUDA)": 67.82800102233887, + "test_vmapjvpvjp_linalg_pinv_singular_cpu_float32 (__main__.TestOperatorsCPU)": 60.267666498819985, + "test_vmapjvpvjp_linalg_solve_triangular_cuda_float32 (__main__.TestOperatorsCUDA)": 68.94433307647705, + "test_vmapjvpvjp_linalg_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 73.93966547648112, + "test_vmapjvpvjp_max_pool2d_with_indices_backward_cpu_float32 (__main__.TestOperatorsCPU)": 88.03500111897786, + "test_vmapjvpvjp_max_pool2d_with_indices_backward_cuda_float32 (__main__.TestOperatorsCUDA)": 90.39650090535481, + "test_vmapjvpvjp_nn_functional_max_pool2d_cpu_float32 (__main__.TestOperatorsCPU)": 79.07066853841145, + "test_vmapjvpvjp_nn_functional_max_pool2d_cuda_float32 (__main__.TestOperatorsCUDA)": 95.49366696675618, + "test_vmapjvpvjp_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 81.16833623250325, + "test_vmapjvpvjp_unbind_cpu_float32 (__main__.TestOperatorsCPU)": 61.30799865722656, + "test_vmapjvpvjp_unbind_cuda_float32 (__main__.TestOperatorsCUDA)": 79.50816663106282, + "test_vmapvjpvjp_linalg_lstsq_cuda_float32 (__main__.TestOperatorsCUDA)": 100.31945332613859, + "test_vmapvjpvjp_meshgrid_list_of_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 106.99416732788086, + "test_vmapvjpvjp_meshgrid_variadic_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 103.08566665649414, + "test_vmapvjpvjp_nn_functional_bilinear_cuda_float32 (__main__.TestOperatorsCUDA)": 149.96750259399414 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } \ No newline at end of file diff --git a/test/test_accelerator.py b/test/test_accelerator.py index 21731bd275b60..317be9d781881 100644 --- a/test/test_accelerator.py +++ b/test/test_accelerator.py @@ -1,6 +1,9 @@ # Owner(s): ["module: tests"] +<<<<<<< HEAD import gc +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import sys import unittest @@ -157,6 +160,7 @@ def test_generic_event_behavior(self): ): event1.elapsed_time(event2) +<<<<<<< HEAD @unittest.skipIf(TEST_MPS, "MPS doesn't support torch.accelerator memory API!") def test_memory_stats(self): # Ensure that device allocator is initialized @@ -234,6 +238,8 @@ def test_memory_stats(self): self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated) self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/test_autograd.py b/test/test_autograd.py index dbd1454ff7459..3afbfdfdf61c7 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -109,10 +109,13 @@ def graph_desc(fn): class TestAutograd(TestCase): +<<<<<<< HEAD def tearDown(self): torch.autograd._force_original_view_tracking(False) super(TestCase, self).tearDown() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_copy_slices_graph_task_updates(self): def f1(x, y): out = x.clone().view(-1) @@ -1196,6 +1199,7 @@ def fn(x, reduce=True): tmp_edge, inputs=(x,), grad_tensors=torch.tensor([1.0, 2.0, 3.0, 4.0]) ) +<<<<<<< HEAD def test_gradient_edge_graph_ownership(self): # Ensure we own the graph properly class Clone(torch.autograd.Function): @@ -1223,6 +1227,8 @@ def backward(ctx, gX): del out torch.autograd.backward(edge) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_grad_nonleaf(self): x_init = torch.randn(2, 2, requires_grad=True) x = x_init @@ -3888,6 +3894,7 @@ def backward(ctx, grad_output): torch.autograd.grad(y, x, create_graph=True) torch.autograd.grad(y, x) # should not error! +<<<<<<< HEAD def test_custom_autograd_ac_early_stop(self): refs = [] @@ -3920,6 +3927,8 @@ def scope(): for ref in refs: self.assertIsNone(ref()) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_detach(self): x = torch.randn(10, 10, requires_grad=True) y = x + 2 @@ -4192,7 +4201,11 @@ def backward(self, grad_output): self.assertIsNone(y.grad_fn) def test_backward_copy(self): +<<<<<<< HEAD # This tests checks backward engine for a very subtle bug that appeared +======= + # This tests checks backward engine for a very subtle bug that appreared +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # in one of the initial versions of autograd. Gradients tensors were # simply stored in lists while the function waited for all its gradients # to be computed. However, sometimes an output was used multiple times, @@ -4375,7 +4388,11 @@ def backward(ctx, grad_output): ctx.output_var.sum().backward() return ctx.x.grad * grad_output +<<<<<<< HEAD # Reentrant starts on CPU thread, finishes on GPU thread +======= + # Reentrant starts on CPU thread, finishs on GPU thread +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x = torch.randn(2, 2, requires_grad=True) out = Reenter.apply(x) out.sum().backward() @@ -10791,7 +10808,11 @@ def get_tensor_and_weak_ref(): dual = fwAD.make_dual(foo, tangent) self.assertFalse(tangent_ref.expired()) +<<<<<<< HEAD # Make sure that the tangent we provided has been reused as is +======= + # Make sure that the tangent we provided has been re-used as is +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue(fwAD.unpack_dual(dual)[1] is tangent) # Make sure that dual is keeping the tangent alive @@ -11150,7 +11171,11 @@ def test_advanced_packing_unpacking(self): self.assertEqual( dual_tangent.storage().data_ptr(), bar.storage().data_ptr() ) +<<<<<<< HEAD # And the tangent is actually reused as-is so it is still the same Tensor +======= + # And the tangent is actually re-used as-is so it is still the same Tensor +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertIs(dual_tangent, bar) # Ensure we properly share the version counter @@ -12032,19 +12057,31 @@ def backward(ctx, grad_output): (new_param**2).sum().backward() return grad_output +<<<<<<< HEAD # Reentrant starts on GPU thread, finishes on GPU thread +======= + # Reentrant starts on GPU thread, finishs on GPU thread +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x = torch.randn(2, 2, device=device, requires_grad=True) out = ReentrantFunc.apply(x) out.sum().backward() +<<<<<<< HEAD # Reentrant starts on CPU thread, finishes on GPU thread +======= + # Reentrant starts on CPU thread, finishs on GPU thread +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x = torch.randn(2, 2, requires_grad=True) # set ReentrantFunc node to GPU to emit tasks to GPU queue ReentrantFunc._cpu_mode = False out = ReentrantFunc.apply(x) out.sum().backward() +<<<<<<< HEAD # Reentrant starts on GPU thread, finishes on CPU thread +======= + # Reentrant starts on GPU thread, finishs on CPU thread +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x = torch.randn(2, 2, device=device, requires_grad=True) # set ReentrantFunc node to CPU to emit tasks to CPU queue ReentrantFunc._cpu_mode = True @@ -12455,6 +12492,7 @@ def test_resize_version_bump(self, device): x.resize_as_(y) self.assertEqual(x._version, 2) +<<<<<<< HEAD @unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator") def test_zero_dim_param_mixed_device_grad(self, device): # cpu 0-dim params with an accelerator device grad @@ -12478,6 +12516,8 @@ def forward(self, x): self.assertEqual(model.a.grad.device, torch.device("cpu")) self.assertEqual(model.b.grad.device, torch.device("cpu")) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestAllowMutationOnSaved(TestCase): def assertClonedLenEqual(self, ctx, n): @@ -13751,7 +13791,11 @@ def forward(self, x): y = x * x if torch.cuda.device_count() >= 2: # DataParallel is calling the forward in different threads +<<<<<<< HEAD # without propagating TLS, so hooks should not be called here +======= + # without progating TLS, so hooks should not be called here +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _self.assertEqual(len(w), 0) else: # DataParallel only uses one thread @@ -14173,13 +14217,17 @@ def fn(x): # early stop is enabled. return clone(x.sin().cos()) +<<<<<<< HEAD # Test default +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Early stopping is enabled by default a = torch.tensor(1.0, requires_grad=True) out = checkpoint(fn, a, use_reentrant=False) out.backward() self.assertEqual(counter[0], 1) +<<<<<<< HEAD # Test local setting counter = [0] a = torch.tensor(1.0, requires_grad=True) @@ -14194,6 +14242,9 @@ def fn(x): self.assertEqual(counter[0], 1) # Test context manager +======= + # Try using the context manager to set early stopping to False. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Expect early stopping to be disabled for all checkpoints ran under # the context manager, even though context manager is no longer active # when backward/recomputation is performed. @@ -14201,6 +14252,7 @@ def fn(x): a = torch.tensor(1.0, requires_grad=True) with torch.utils.checkpoint.set_checkpoint_early_stop(False): out = checkpoint(fn, a, use_reentrant=False) +<<<<<<< HEAD out.backward() self.assertEqual(counter[0], 2) @@ -14235,6 +14287,12 @@ def fn(x): out.backward() self.assertEqual(counter[0], 1) +======= + + out.backward() + self.assertEqual(counter[0], 2) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_nested_checkpoint_set_early_stop_no_recompution_needed(self): # Case 1: We have one tensor saved and its the input diff --git a/test/test_autograd_fallback.py b/test/test_autograd_fallback.py index d6252ac6f34a3..d3346d6ad50c8 100644 --- a/test/test_autograd_fallback.py +++ b/test/test_autograd_fallback.py @@ -6,7 +6,11 @@ import numpy as np import torch +<<<<<<< HEAD from torch.library import _scoped_library +======= +from torch.library import _scoped_library, Library +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, @@ -28,6 +32,7 @@ def autograd_fallback_mode(mode): class TestAutogradFallback(TestCase): test_ns = "_test_autograd_fallback" +<<<<<<< HEAD def setUp(self): super().setUp() self.libraries = [] @@ -38,14 +43,28 @@ def tearDown(self): for lib in self.libraries: lib._destroy() del self.libraries +======= + def tearDown(self): + if hasattr(torch.ops, self.test_ns): + delattr(torch.ops, self.test_ns) + if hasattr(self, "lib"): + del self.lib.m + del self.lib +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_op(self, name): return getattr(getattr(torch.ops, self.test_ns), name).default def get_lib(self): +<<<<<<< HEAD result = torch.library.Library(self.test_ns, "FRAGMENT") # noqa: TOR901 self.libraries.append(result) return result +======= + lib = Library(self.test_ns, "FRAGMENT") # noqa: TOR901 + self.lib = lib + return lib +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize("mode", ("nothing", "warn")) def test_no_grad(self, mode): diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index 569d1bac85958..f85d572c5e258 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -79,7 +79,11 @@ class TestBinaryUfuncs(TestCase): # Generic tests for elementwise binary (AKA binary universal (u) functions (funcs)) # TODO: below contiguous tensor results are compared with a variety of noncontiguous results. +<<<<<<< HEAD # It would be interesting to have the lhs and rhs have different discontinuities. +======= + # It would be interesting to have the lhs and rhs have different discontiguities. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Helper for comparing torch tensors and NumPy arrays # TODO: should this or assertEqual also validate that strides are equal? @@ -1688,11 +1692,20 @@ def test_cpu_tensor_pow_cuda_scalar_tensor(self, device): @onlyCUDA @dtypes(torch.complex64, torch.complex128) +<<<<<<< HEAD def test_pow_cuda_complex_extremal_passing(self, device, dtype): t = torch.tensor(complex(-1.0, float("inf")), dtype=dtype, device=device) cuda_out = t.pow(2) cpu_out = t.cpu().pow(2) self.assertEqual(cpu_out, cuda_out) +======= + def test_pow_cuda_complex_extremal_failing(self, device, dtype): + t = torch.tensor(complex(-1.0, float("inf")), dtype=dtype, device=device) + with self.assertRaises(AssertionError): + cuda_out = t.pow(2) + cpu_out = t.cpu().pow(2) + self.assertEqual(cpu_out, cuda_out) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skipIfTorchDynamo() @onlyNativeDeviceTypes @@ -2521,7 +2534,11 @@ def _test_copysign_numpy(a, b): # Verify Value self.assertEqual(torch_result, expected) # Verify Sign +<<<<<<< HEAD # Use double copysign to verify the correctness of 0.0 and -0.0, since +======= + # Use double copysign to verify the correctnes of 0.0 and -0.0, since +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # it always True for self.assertEqual(0.0 == -0.0). So, we use 1 as the # magnitude to verify the sign between torch and numpy results, elementwise. # Special case: NaN conversions between FP32 and FP16 is not bitwise diff --git a/test/test_cpp_extensions_aot.py b/test/test_cpp_extensions_aot.py index 2f69bcfeb9c48..9b44b7a64cbd6 100644 --- a/test/test_cpp_extensions_aot.py +++ b/test/test_cpp_extensions_aot.py @@ -148,7 +148,11 @@ def test_cusolver_extension(self): @unittest.skipIf(IS_WINDOWS, "Not available on Windows") def test_no_python_abi_suffix_sets_the_correct_library_name(self): +<<<<<<< HEAD # For this test, run_test.py will call `python -m pip install .` in the +======= + # For this test, run_test.py will call `python setup.py install` in the +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # cpp_extensions/no_python_abi_suffix_test folder, where the # `BuildExtension` class has a `no_python_abi_suffix` option set to # `True`. This *should* mean that on Python 3, the produced shared diff --git a/test/test_cpp_extensions_jit.py b/test/test_cpp_extensions_jit.py index e93167296a002..40c620a4cbb90 100644 --- a/test/test_cpp_extensions_jit.py +++ b/test/test_cpp_extensions_jit.py @@ -21,7 +21,10 @@ from torch.testing._internal.common_cuda import TEST_CUDA, TEST_CUDNN from torch.testing._internal.common_utils import gradcheck, TEST_XPU from torch.utils.cpp_extension import ( +<<<<<<< HEAD _get_cuda_arch_flags, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _TORCH_PATH, check_compiler_is_gcc, CUDA_HOME, @@ -220,12 +223,15 @@ def test_mps_extension(self): self.assertEqual(cpu_output, mps_output.to("cpu")) +<<<<<<< HEAD # Regression test for https://github.com/pytorch/pytorch/issues/163721 lib = torch.mps.compile_shader("void kernel noop(device float *x) {}") lib.noop(mps_output) module.mps_add_one_new_context(mps_output) self.assertEqual(cpu_output + 1.0, mps_output.to("cpu")) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _run_jit_cuda_archflags(self, flags, expected): # Compile an extension with given `flags` def _check_cuobjdump_output(expected_values, is_ptx=False): @@ -328,6 +334,7 @@ def test_jit_cuda_archflags(self): [f"{capability[0]}{capability[1]}" for capability in capabilities], None, ), +<<<<<<< HEAD } archflags["7.5+PTX"] = (["75"], ["75"]) major, minor = map(int, torch.version.cuda.split(".")[:2]) @@ -337,6 +344,14 @@ def test_jit_cuda_archflags(self): archflags["Volta"] = (["70"], ["70"]) archflags["5.0;6.0+PTX;7.0;7.5"] = (["50", "60", "70", "75"], ["60"]) if major < 12: +======= + "Maxwell+Tegra;6.1": (["53", "61"], None), + "Volta": (["70"], ["70"]), + } + archflags["7.5+PTX"] = (["75"], ["75"]) + archflags["5.0;6.0+PTX;7.0;7.5"] = (["50", "60", "70", "75"], ["60"]) + if int(torch.version.cuda.split(".")[0]) < 12: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # CUDA 12 drops compute capability < 5.0 archflags["Pascal 3.5"] = (["35", "60", "61"], None) @@ -357,6 +372,7 @@ def test_jit_cuda_archflags(self): # to avoid errors from here leaking into other tests pass +<<<<<<< HEAD @unittest.skipIf(not TEST_CUDA, "CUDA not found") def test_cuda_arch_flags_non_default_gencode(self): user_arch_flags = ["-gencode=arch=compute_86,code=sm_86"] @@ -386,6 +402,8 @@ def test_cuda_arch_flags_default_gencode(self): len(empty_flags), 0, "Empty list should generate default flags" ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(not TEST_CUDNN, "CuDNN not found") @unittest.skipIf(TEST_ROCM, "Not supported on ROCm") def test_jit_cudnn_extension(self): @@ -1040,7 +1058,11 @@ def test_warning(self): t = torch.rand(2).double() cpp_tensor_name = r"CPUDoubleType" +<<<<<<< HEAD # Without error handling, the warnings cannot be caught +======= + # Without error handling, the warnings cannot be catched +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) warn_mod = torch.utils.cpp_extension.load_inline( name="warn_mod", cpp_sources=[source], @@ -1074,23 +1096,39 @@ def test_warning(self): ) with warnings.catch_warnings(record=True) as w: +<<<<<<< HEAD # Caught with no error should be detected warn_mod.foo(t, 0) self.assertEqual(len(w), 1) # Caught with cpp error should also be detected +======= + # Catched with no error should be detected + warn_mod.foo(t, 0) + self.assertEqual(len(w), 1) + + # Catched with cpp error should also be detected +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with self.assertRaisesRegex(TypeError, t.type()): warn_mod.foo(t, 1) self.assertEqual(len(w), 2) +<<<<<<< HEAD # Caught with python error should also be detected +======= + # Catched with python error should also be detected +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with self.assertRaisesRegex( SystemError, "bad argument to internal function" ): warn_mod.foo(t, 2) self.assertEqual(len(w), 3) +<<<<<<< HEAD # Caught with pybind error should also be detected +======= + # Catched with pybind error should also be detected +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Note that there is no type name translation for pybind errors with self.assertRaisesRegex(KeyError, cpp_tensor_name): warn_mod.foo(t, 3) @@ -1233,7 +1271,11 @@ def test_aoti_torch_call_dispatcher(self): #include #include #include +<<<<<<< HEAD #include +======= + #include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle; diff --git a/test/test_cpp_extensions_open_device_registration.py b/test/test_cpp_extensions_open_device_registration.py new file mode 100644 index 0000000000000..4aedfdb630149 --- /dev/null +++ b/test/test_cpp_extensions_open_device_registration.py @@ -0,0 +1,269 @@ +# Owner(s): ["module: cpp-extensions"] + +import _codecs +import io +import os +import unittest +from unittest.mock import patch + +import numpy as np +import pytorch_openreg # noqa: F401 + +import torch +import torch.testing._internal.common_utils as common +import torch.utils.cpp_extension +from torch.serialization import safe_globals +from torch.testing._internal.common_utils import TemporaryFileName + + +@unittest.skipIf(common.TEST_XPU, "XPU does not support cppextension currently") +@common.markDynamoStrictTest +class TestCppExtensionOpenRegistration(common.TestCase): + """Tests Open Device Registration with C++ extensions.""" + + module = None + + def setUp(self): + super().setUp() + + # cpp extensions use relative paths. Those paths are relative to + # this file, so we'll change the working directory temporarily + self.old_working_dir = os.getcwd() + os.chdir(os.path.dirname(os.path.abspath(__file__))) + + assert self.module is not None + + def tearDown(self): + super().tearDown() + + # return the working directory (see setUp) + os.chdir(self.old_working_dir) + + @classmethod + def setUpClass(cls): + common.remove_cpp_extensions_build_root() + + cls.module = torch.utils.cpp_extension.load( + name="custom_device_extension", + sources=[ + "cpp_extensions/open_registration_extension.cpp", + ], + extra_include_paths=["cpp_extensions"], + extra_cflags=["-g"], + verbose=True, + ) + + def test_open_device_faketensor(self): + with torch._subclasses.fake_tensor.FakeTensorMode.push(): + a = torch.empty(1, device="openreg") + b = torch.empty(1, device="openreg:0") + result = a + b # noqa: F841 + + def test_open_device_named_tensor(self): + torch.empty([2, 3, 4, 5], device="openreg", names=["N", "C", "H", "W"]) + + # Not an open registration test - this file is just very convenient + # for testing torch.compile on custom C++ operators + def test_compile_autograd_function_returns_self(self): + x_ref = torch.randn(4, requires_grad=True) + out_ref = self.module.custom_autograd_fn_returns_self(x_ref) + out_ref.sum().backward() + + x_test = x_ref.detach().clone().requires_grad_(True) + f_compiled = torch.compile(self.module.custom_autograd_fn_returns_self) + out_test = f_compiled(x_test) + out_test.sum().backward() + + self.assertEqual(out_ref, out_test) + self.assertEqual(x_ref.grad, x_test.grad) + + # Not an open registration test - this file is just very convenient + # for testing torch.compile on custom C++ operators + @common.skipIfTorchDynamo("Temporary disabled due to torch._ops.OpOverloadPacket") + def test_compile_autograd_function_aliasing(self): + x_ref = torch.randn(4, requires_grad=True) + out_ref = torch.ops._test_funcs.custom_autograd_fn_aliasing(x_ref) + out_ref.sum().backward() + + x_test = x_ref.detach().clone().requires_grad_(True) + f_compiled = torch.compile(torch.ops._test_funcs.custom_autograd_fn_aliasing) + out_test = f_compiled(x_test) + out_test.sum().backward() + + self.assertEqual(out_ref, out_test) + self.assertEqual(x_ref.grad, x_test.grad) + + def test_open_device_scalar_type_fallback(self): + z_cpu = torch.Tensor([[0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]]).to(torch.int64) + z = torch.triu_indices(3, 3, device="openreg") + self.assertEqual(z_cpu, z) + + def test_open_device_tensor_type_fallback(self): + # create tensors located in custom device + x = torch.Tensor([[1, 2, 3], [2, 3, 4]]).to("openreg") + y = torch.Tensor([1, 0, 2]).to("openreg") + # create result tensor located in cpu + z_cpu = torch.Tensor([[0, 2, 1], [1, 3, 2]]) + # Check that our device is correct. + device = self.module.custom_device() + self.assertTrue(x.device == device) + self.assertFalse(x.is_cpu) + + # call sub op, which will fallback to cpu + z = torch.sub(x, y) + self.assertEqual(z_cpu, z) + + # call index op, which will fallback to cpu + z_cpu = torch.Tensor([3, 1]) + y = torch.Tensor([1, 0]).long().to("openreg") + z = x[y, y] + self.assertEqual(z_cpu, z) + + def test_open_device_tensorlist_type_fallback(self): + # create tensors located in custom device + v_openreg = torch.Tensor([1, 2, 3]).to("openreg") + # create result tensor located in cpu + z_cpu = torch.Tensor([2, 4, 6]) + # create tensorlist for foreach_add op + x = (v_openreg, v_openreg) + y = (v_openreg, v_openreg) + # Check that our device is correct. + device = self.module.custom_device() + self.assertTrue(v_openreg.device == device) + self.assertFalse(v_openreg.is_cpu) + + # call _foreach_add op, which will fallback to cpu + z = torch._foreach_add(x, y) + self.assertEqual(z_cpu, z[0]) + self.assertEqual(z_cpu, z[1]) + + # call _fused_adamw_ with undefined tensor. + self.module.fallback_with_undefined_tensor() + + @common.skipIfTorchDynamo() + @unittest.skipIf( + np.__version__ < "1.25", + "versions < 1.25 serialize dtypes differently from how it's serialized in data_legacy_numpy", + ) + def test_open_device_numpy_serialization(self): + """ + This tests the legacy _rebuild_device_tensor_from_numpy serialization path + """ + device = self.module.custom_device() + + # Legacy data saved with _rebuild_device_tensor_from_numpy on f80ed0b8 via + + # with patch.object(torch._C, "_has_storage", return_value=False): + # x = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32, device=device) + # x_foo = x.to(device) + # sd = {"x": x_foo} + # rebuild_func = x_foo._reduce_ex_internal(default_protocol)[0] + # self.assertTrue( + # rebuild_func is torch._utils._rebuild_device_tensor_from_numpy + # ) + # with open("foo.pt", "wb") as f: + # torch.save(sd, f) + + data_legacy_numpy = ( + b"PK\x03\x04\x00\x00\x08\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + b"\x00\x00\x00\x10\x00\x12\x00archive/data.pklFB\x0e\x00ZZZZZZZZZZZZZZ\x80\x02}q\x00X\x01" + b"\x00\x00\x00xq\x01ctorch._utils\n_rebuild_device_tensor_from_numpy\nq\x02(cnumpy.core.m" + b"ultiarray\n_reconstruct\nq\x03cnumpy\nndarray\nq\x04K\x00\x85q\x05c_codecs\nencode\nq\x06" + b"X\x01\x00\x00\x00bq\x07X\x06\x00\x00\x00latin1q\x08\x86q\tRq\n\x87q\x0bRq\x0c(K\x01K\x02K" + b"\x03\x86q\rcnumpy\ndtype\nq\x0eX\x02\x00\x00\x00f4q\x0f\x89\x88\x87q\x10Rq\x11(K\x03X\x01" + b"\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00" + b"PK\x05\x06\x00\x00\x00\x00\x04\x00\x04\x00\x06\x01\x00\x008\x03\x00\x00\x00\x00" + ) + buf_data_legacy_numpy = io.BytesIO(data_legacy_numpy) + + with safe_globals( + [ + (np.core.multiarray._reconstruct, "numpy.core.multiarray._reconstruct") + if np.__version__ >= "2.1" + else np.core.multiarray._reconstruct, + np.ndarray, + np.dtype, + _codecs.encode, + np.dtypes.Float32DType, + ] + ): + sd_loaded = torch.load(buf_data_legacy_numpy, weights_only=True) + buf_data_legacy_numpy.seek(0) + # Test map_location + sd_loaded_cpu = torch.load( + buf_data_legacy_numpy, weights_only=True, map_location="cpu" + ) + expected = torch.tensor( + [[1, 2, 3], [4, 5, 6]], dtype=torch.float32, device=device + ) + self.assertEqual(sd_loaded["x"].cpu(), expected.cpu()) + self.assertFalse(sd_loaded["x"].is_cpu) + self.assertTrue(sd_loaded_cpu["x"].is_cpu) + + def test_open_device_cpu_serialization(self): + torch.utils.rename_privateuse1_backend("openreg") + device = self.module.custom_device() + default_protocol = torch.serialization.DEFAULT_PROTOCOL + + with patch.object(torch._C, "_has_storage", return_value=False): + x = torch.randn(2, 3) + x_openreg = x.to(device) + sd = {"x": x_openreg} + rebuild_func = x_openreg._reduce_ex_internal(default_protocol)[0] + self.assertTrue( + rebuild_func is torch._utils._rebuild_device_tensor_from_cpu_tensor + ) + # Test map_location + with TemporaryFileName() as f: + torch.save(sd, f) + sd_loaded = torch.load(f, weights_only=True) + # Test map_location + sd_loaded_cpu = torch.load(f, weights_only=True, map_location="cpu") + self.assertFalse(sd_loaded["x"].is_cpu) + self.assertEqual(sd_loaded["x"].cpu(), x) + self.assertTrue(sd_loaded_cpu["x"].is_cpu) + + # Test metadata_only + with TemporaryFileName() as f: + with self.assertRaisesRegex( + RuntimeError, + "Cannot serialize tensors on backends with no storage under skip_data context manager", + ): + with torch.serialization.skip_data(): + torch.save(sd, f) + + def test_open_device_dlpack(self): + t = torch.randn(2, 3).to("openreg") + capsule = torch.utils.dlpack.to_dlpack(t) + t1 = torch.from_dlpack(capsule) + self.assertTrue(t1.device == t.device) + t = t.to("cpu") + t1 = t1.to("cpu") + self.assertEqual(t, t1) + + +if __name__ == "__main__": + common.run_tests() diff --git a/test/test_cuda.py b/test/test_cuda.py index d293601fad138..ceae45cc67b11 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -69,7 +69,10 @@ load_tests, MI300_ARCH, parametrize, +<<<<<<< HEAD recover_orig_fp32_precision, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) run_tests, serialTest, setBlasBackendsToDefaultFinally, @@ -174,7 +177,11 @@ def test_pinned_memory_with_cudaregister_multithread(self): for thread in threads: thread.join() +<<<<<<< HEAD @serialTest() +======= + @serialTest +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_host_memory_stats(self): # Helper functions def empty_stats(): @@ -373,6 +380,7 @@ def test_memory_allocation(self): torch.cuda.caching_allocator_delete(mem) self.assertEqual(torch.cuda.memory_allocated(), prev) +<<<<<<< HEAD def test_memory_stats(self): gc.collect() torch.cuda.empty_cache() @@ -409,6 +417,8 @@ def test_memory_stats(self): self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated) self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_check_error(self): # Assert this call doesn't raise. torch.cuda.check_error(0) @@ -762,7 +772,57 @@ def check_workspace_size(inp): torch._C._cuda_clearCublasWorkspaces() +<<<<<<< HEAD def test_cublas_allow_tf32_get_set(self): +======= + @contextlib.contextmanager + def _hip_allow_tf32(self): + # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new + # and only for MI300+ + hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None) + os.environ["HIPBLASLT_ALLOW_TF32"] = "1" + + try: + yield + finally: + if hip_allow_tf32 is not None: + os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32 + else: + del os.environ["HIPBLASLT_ALLOW_TF32"] + + @unittest.skipIf(not TEST_WITH_ROCM, "not relevant for CUDA testing") + def test_hipblaslt_allow_tf32(self): + tf32_ctx = self._hip_allow_tf32 + with tf32_ctx(): + os.environ["HIPBLASLT_ALLOW_TF32"] = "0" + # Save original value of allow_tf32 + orig = torch.backends.cuda.matmul.allow_tf32 + # If allow_tf32 variable is declared as static in aten/src/ATen/Context.cpp + # then matmul.allow_tf32 will return False after this point even if + # HIP_BLASLT_ALLOW_TF32 is set to 1 and matmul.allow_tf32 is changed. + os.environ["HIPBLASLT_ALLOW_TF32"] = "1" + # Toggle torch.backends.cuda.matmul.allow_tf32 couple of times. + torch.backends.cuda.matmul.allow_tf32 = not orig + test1 = torch.backends.cuda.matmul.allow_tf32 + torch.backends.cuda.matmul.allow_tf32 = orig + test2 = torch.backends.cuda.matmul.allow_tf32 + self.assertNotEqual(test1, test2) + # Restore original value of allow_tf32 + torch.backends.cuda.matmul.allow_tf32 = orig + + def test_cublas_allow_tf32_get_set(self): + """ + We only turn on TF32 for MI300 with a special env var. This is because TF32 + is only available in MI300+ and is in experimental mode (hipblaslt support + is current WIP) + """ + tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext + + with tf32_ctx(): + self._test_cublas_allow_tf32_get_set_inner() + + def _test_cublas_allow_tf32_get_set_inner(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int( os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"] ) @@ -777,12 +837,25 @@ def test_cublas_allow_tf32_get_set(self): torch.backends.cuda.matmul.allow_tf32 = orig def test_float32_matmul_precision_get_set(self): +<<<<<<< HEAD +======= + tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext + + with tf32_ctx(): + self._test_float32_matmul_precision_get_set_inner() + + def _test_float32_matmul_precision_get_set_inner(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) orig = torch.get_float32_matmul_precision() skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int( os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"] ) # this is really just checking that the environment variable is respected during testing +<<<<<<< HEAD # and not overwritten by another function that doesn't revert it to the initial value +======= + # and not overwritten by another function that doesn't revert it to the intitial value +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not skip_tf32_cublas: self.assertFalse(torch.backends.cuda.matmul.allow_tf32) self.assertEqual(torch.get_float32_matmul_precision(), "highest") @@ -837,6 +910,7 @@ def test_cudnn_allow_tf32_get_set(self): ): self.assertTrue(torch.backends.cudnn.allow_tf32) +<<<<<<< HEAD @recover_orig_fp32_precision def test_fp32_precision_with_tf32(self): with torch.backends.cudnn.flags( @@ -886,6 +960,8 @@ def test_invalid_status_for_legacy_api(self): with self.assertRaisesRegex(RuntimeError, "mix of the legacy and new APIs"): print(torch.backends.cuda.matmul.allow_tf32) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_type_conversions(self): x = torch.randn(5, 5) self.assertIsInstance(x.float(), torch.FloatTensor) @@ -1130,7 +1206,11 @@ def perform_copy(): tmp2 = torch.cuda.FloatTensor(t.size()) tmp2.zero_() self.assertNotEqual( +<<<<<<< HEAD tmp2.data_ptr(), ptr[0], msg="allocation reused to soon" +======= + tmp2.data_ptr(), ptr[0], msg="allocation re-used to soon" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) self.assertEqual(result.tolist(), [1, 2, 3, 4]) @@ -1141,7 +1221,11 @@ def perform_copy(): torch.cuda.current_stream().synchronize() with torch.cuda.stream(stream): tmp3 = torch.cuda.FloatTensor(t.size()) +<<<<<<< HEAD self.assertEqual(tmp3.data_ptr(), ptr[0], msg="allocation not reused") +======= + self.assertEqual(tmp3.data_ptr(), ptr[0], msg="allocation not re-used") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_record_stream_on_shifted_view(self): # See issue #27366 @@ -1222,20 +1306,32 @@ def test_noncontiguous_pinned_memory(self): def test_caching_pinned_memory(self): cycles_per_ms = get_cycles_per_ms() +<<<<<<< HEAD # check that allocations are reused after deletion +======= + # check that allocations are re-used after deletion +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) t = torch.FloatTensor([1]).pin_memory() ptr = t.data_ptr() del t t = torch.FloatTensor([1]).pin_memory() self.assertEqual(t.data_ptr(), ptr, msg="allocation not reused") +<<<<<<< HEAD # check that the allocation is not reused if it's in-use by a copy +======= + # check that the allocation is not re-used if it's in-use by a copy +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) gpu_tensor = torch.cuda.FloatTensor([0]) torch.cuda._sleep(int(1000 * cycles_per_ms)) # delay the copy by 1s gpu_tensor.copy_(t, non_blocking=True) del t t = torch.FloatTensor([1]).pin_memory() +<<<<<<< HEAD self.assertNotEqual(t.data_ptr(), ptr, msg="allocation reused too soon") +======= + self.assertNotEqual(t.data_ptr(), ptr, msg="allocation re-used too soon") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(list(gpu_tensor), [1]) def test_caching_allocator_record_stream_oom(self): @@ -1250,7 +1346,11 @@ def test_caching_allocator_record_stream_oom(self): x = torch.empty(40 * 1024 * 1024, device="cuda") with torch.cuda.stream(stream): y += x +<<<<<<< HEAD # delays reuse of `x` until after all operations in `stream` +======= + # delays re-use of `x` until after all operations in `stream` +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x.record_stream(stream) del x @@ -1490,7 +1590,10 @@ def run(dev: torch.device) -> int: ) @largeTensorTest("20GB", "cuda") +<<<<<<< HEAD @serialTest() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_randint_generation_for_large_numel(self) -> None: numel = 2**31 + 1 s = torch.randint(2, (numel,), device="cuda", dtype=torch.int8).sum() @@ -2958,7 +3061,11 @@ def test_graph_memory_stats_and_use_result_after_destroy_graph(self): current = postcapture_stats[stat] - precapture_stats[stat] # There will only ever be one expandable segment in each of the small and large pools. The way the +<<<<<<< HEAD # bookkeeping is done in the allocator means that we never increment the number of segments. +======= + # bookeeping is done in the allocator means that we never increment the number of segments. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.expandable_segments and "segment" in stat: expected = 0 # These two cases hit an edge case where the PyTorch allocator won't immediately unmap part of an @@ -2999,7 +3106,11 @@ def test_graph_memory_stats_and_use_result_after_destroy_graph(self): current = postdel_stats[stat] - precapture_stats[stat] # There will only ever be one expandable segment in each of the small and large pools. The way the +<<<<<<< HEAD # bookkeeping is done in the allocator means that we never increment the number of segments. +======= + # bookeeping is done in the allocator means that we never increment the number of segments. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.expandable_segments and "segment" in stat: expected = 0 # These two cases hit an edge case where the PyTorch allocator won't immediately unmap part of an @@ -3526,14 +3637,22 @@ def raw_malloc(): try: with torch.cuda.stream(stream): mem = torch.cuda.caching_allocator_alloc(1024) +<<<<<<< HEAD except BaseException: # noqa: B036 +======= + except BaseException: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if mem is None: return try: torch.cuda.caching_allocator_delete(mem) mem = None return None +<<<<<<< HEAD except BaseException: # noqa: B036 +======= + except BaseException: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pass def throws_on_cuda_event(capture_error_mode): @@ -3618,6 +3737,7 @@ def test_cuda_graph_raw_graph(self): graph.replay() @unittest.skipIf( +<<<<<<< HEAD not TEST_CUDA_GRAPH or not TEST_CUDA_PYTHON_BINDINGS, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs, cuda-bindings must be installed", ) @@ -3647,6 +3767,8 @@ def test_cuda_graph_raw_graph_exec(self, keep_graph): graph.replay() @unittest.skipIf( +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) def test_cuda_graph_raw_graph_reset_and_recapture(self): @@ -3665,7 +3787,11 @@ def test_cuda_graph_raw_graph_reset_and_recapture(self): graph.replay() self.assertTrue(torch.all(x == 3.0)) +<<<<<<< HEAD # Check that graph capture can succeed after resetting. +======= + # Check that graph capture can succeed after reseting. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) graph.reset() # Don't do x[:] = 0.0 because we want to capture a new address @@ -3715,6 +3841,7 @@ def test_cuda_graph_allocator_propagates_stream(self): self.assertEqual(len(x), 2) self.assertEqual(x[0], x[1]) +<<<<<<< HEAD @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) @@ -3748,6 +3875,8 @@ def my_func(a: torch.Tensor, b: torch.Tensor, perm: torch.Tensor): .strip() ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_batch_norm_gather_stats(self): input = torch.randn(1, 3, 3, 3, device="cuda") mean, invstd = torch.batch_norm_gather_stats( @@ -3889,6 +4018,7 @@ def test_hip_device_count(self): {"CUDA_VISIBLE_DEVICES": "0", "HIP_VISIBLE_DEVICES": None}, {"CUDA_VISIBLE_DEVICES": None, "HIP_VISIBLE_DEVICES": "0"}, {"CUDA_VISIBLE_DEVICES": "0,1,2,3", "HIP_VISIBLE_DEVICES": "0"}, +<<<<<<< HEAD {"ROCR_VISIBLE_DEVICES": "0", "HIP_VISIBLE_DEVICES": None}, ] @@ -3899,6 +4029,12 @@ def test_hip_device_count(self): ] ) +======= + {"ROCR_VISIBLE_DEVICES": "1,2,3", "HIP_VISIBLE_DEVICES": "0"}, + {"ROCR_VISIBLE_DEVICES": "0", "HIP_VISIBLE_DEVICES": None}, + ] + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for env_config in custom_envs: env = os.environ.copy() for key, value in env_config.items(): @@ -4342,7 +4478,11 @@ def foo(): finally: torch.cuda.memory._record_memory_history(None) +<<<<<<< HEAD @serialTest() +======= + @serialTest +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_max_split_expandable(self): try: torch.cuda.memory.empty_cache() @@ -4378,7 +4518,11 @@ def alloc(n): finally: torch.cuda.memory.set_per_process_memory_fraction(orig) +<<<<<<< HEAD @serialTest() +======= + @serialTest +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_garbage_collect_expandable(self): try: torch.cuda.memory.empty_cache() @@ -4408,7 +4552,11 @@ def alloc(n): # expandable_segment blocks can be in the free list when this is called. alloc(80) finally: +<<<<<<< HEAD orig = torch.cuda.get_per_process_memory_fraction(0) +======= + torch.cuda.memory.set_per_process_memory_fraction(orig) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_allocator_settings(self): def power2_div(size, div_factor): @@ -5348,7 +5496,10 @@ def test_mempool_empty_cache(self): segments = torch.cuda.memory._snapshot()["segments"] self.assertTrue(len(segments) > 0, "expected more than one segment") +<<<<<<< HEAD @serialTest() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_mempool_empty_cache_inactive(self): torch.cuda.empty_cache() allocator, dummy_allocator = self.get_dummy_allocator(check_vars=True) @@ -5433,7 +5584,11 @@ def test_mempool_with_allocator(self): out_2 = torch.randn(nelem_1mb, device="cuda") # pool now should have 2 segments since the CUDACachingAllocator had +<<<<<<< HEAD # to make a new 2 MB buffer to accommodate out_2 +======= + # to make a new 2 MB buffer to accomodate out_2 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(len(pool.snapshot()), 2) self.assertEqual(len(pool.snapshot()), 2) @@ -5564,6 +5719,7 @@ def my_function(pool): s = p.snapshot() self.assertEqual(len(s), 1, "Expected to have a single segment") +<<<<<<< HEAD @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) @@ -5707,6 +5863,8 @@ def test_graph_capture_reclaim_4_streams(self): "graph_capture_record_stream_reuse:False" ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skipIfRocm(msg="expandable_segments mode is not supported on ROCm") @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Load_inline doesn't work in fbcode") def test_mempool_expandable(self): @@ -5721,7 +5879,10 @@ def test_mempool_expandable(self): out_0 = torch.randn(nelem_1mb, device="cuda") torch.cuda.memory._set_allocator_settings("expandable_segments:False") +<<<<<<< HEAD @serialTest() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_mempool_ctx_multithread(self): torch.cuda.empty_cache() segments = torch.cuda.memory._snapshot()["segments"] @@ -6641,7 +6802,10 @@ def test_autocast_rnn(self): for grad, grad_control in zip(grads, grads_control): self.assertEqual(grad.half(), grad_control) +<<<<<<< HEAD @serialTest() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_autocast_cache_leak(self): # Reported at https://github.com/pytorch/pytorch/issues/48049 # Test is used to check, if autocast recaches the same parameters @@ -6656,7 +6820,11 @@ def test_autocast_cache_leak(self): first_iter_mem = torch.cuda.memory_allocated() for _ in range(3): out = linear(data) +<<<<<<< HEAD self.assertEqual(first_iter_mem, torch.cuda.memory_allocated()) +======= + self.assertTrue(first_iter_mem == torch.cuda.memory_allocated()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_autocast_checkpointing(self): model = torch.nn.Sequential( @@ -6680,10 +6848,19 @@ def test_cuda_autocast_deprecated_warning(self): with torch.cuda.amp.autocast(): _ = torch.ones(10) +<<<<<<< HEAD @unittest.skipIf( os.environ.get("USE_LEGACY_DRIVER", None) == "1", "Doesn't work with older driver" ) +======= + def test_cuda_module_loading_env(self): + torch.cuda.init() + val = os.environ.get("CUDA_MODULE_LOADING", "") + self.assertEqual(val, "LAZY") + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestCompileKernel(TestCase): @unittest.skipIf(TEST_WITH_ROCM, "ROCM does not support nvrtc") @unittest.skipIf(not TEST_CUDA, "No CUDA") @@ -6996,7 +7173,13 @@ def test_graph_external_wait_and_record(self): """ from torch.cuda import _compile_kernel +<<<<<<< HEAD spin_wait_kernel = _compile_kernel(kernel_source, "wait_for_cpu") +======= + spin_wait_kernel = _compile_kernel( + kernel_source, "wait_for_cpu", compute_capability="70" + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x = torch.ones(4, device="cuda") x_cpu = torch.zeros(x.shape, device="cpu").pin_memory() diff --git a/test/test_cuda_multigpu.py b/test/test_cuda_multigpu.py index 2882b0f58808a..ce8e878da4fef 100644 --- a/test/test_cuda_multigpu.py +++ b/test/test_cuda_multigpu.py @@ -967,7 +967,11 @@ def test_external_streams_multi_device(self): @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") def test_caching_pinned_memory_multi_gpu(self): +<<<<<<< HEAD # checks that the events preventing pinned memory from being reused +======= + # checks that the events preventing pinned memory from being re-used +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # too early are recorded on the correct GPU cycles_per_ms = get_cycles_per_ms() @@ -982,7 +986,11 @@ def test_caching_pinned_memory_multi_gpu(self): del t t = torch.FloatTensor([2]).pin_memory() +<<<<<<< HEAD self.assertNotEqual(t.data_ptr(), ptr, msg="allocation reused too soon") +======= + self.assertNotEqual(t.data_ptr(), ptr, msg="allocation re-used too soon") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with torch.cuda.device(0): gpu_tensor0.copy_(t, non_blocking=True) diff --git a/test/test_cuda_nvml_based_avail.py b/test/test_cuda_nvml_based_avail.py index c47607f4c7ac9..37b992a3bccbf 100644 --- a/test/test_cuda_nvml_based_avail.py +++ b/test/test_cuda_nvml_based_avail.py @@ -138,7 +138,11 @@ def test_partial_uuid_resolver(self): _transform_uuid_to_ordinals(["GPU-9e8d35e3", "GPU-123", "GPU-47"], uuids), [1], ) +<<<<<<< HEAD # First ambiguous UUID aborts parsing +======= + # First ambigous UUID aborts parsing +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual( _transform_uuid_to_ordinals(["GPU-9e8d35e3", "GPU-e", "GPU-47"], uuids), [1] ) diff --git a/test/test_cuda_primary_ctx.py b/test/test_cuda_primary_ctx.py index 284d048e9e080..bc0f03c4eb7b4 100644 --- a/test/test_cuda_primary_ctx.py +++ b/test/test_cuda_primary_ctx.py @@ -42,7 +42,11 @@ def test_set_device_0(self): self.assertFalse(torch._C._cuda_hasPrimaryContext(0)) torch.cuda.set_device(0) if _get_torch_cuda_version() >= (12, 0): +<<<<<<< HEAD # Now after the device was set, the context should present in CUDA 12. +======= + # Now after the device was set, the contex should present in CUDA 12. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue(torch._C._cuda_hasPrimaryContext(0)) else: # In CUDA 11 the context should not be created. diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py index 5a494f5487423..6e44c64629558 100644 --- a/test/test_custom_ops.py +++ b/test/test_custom_ops.py @@ -167,7 +167,10 @@ def foo_impl(x): lib.impl("foo", Foo.apply, "Autograd") lib.impl("foo", foo_impl, "CPU") lib.impl("foo", foo_impl, "CUDA") +<<<<<<< HEAD lib.impl("foo", foo_impl, "XPU") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x = torch.tensor(3.14159 / 3, requires_grad=True, device=device) with self.assertRaisesRegex( @@ -272,7 +275,10 @@ def foo_impl(x): lib.impl("foo", Foo.apply, "Autograd") lib.impl("foo", foo_impl, "CPU") lib.impl("foo", foo_impl, "CUDA") +<<<<<<< HEAD lib.impl("foo", foo_impl, "XPU") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x = torch.tensor([0, 1.0], requires_grad=True) with self.assertRaisesRegex( @@ -314,7 +320,10 @@ def foo_meta(x): lib.impl("foo", Foo.apply, "Autograd") lib.impl("foo", foo_impl, "CPU") lib.impl("foo", foo_impl, "CUDA") +<<<<<<< HEAD lib.impl("foo", foo_impl, "XPU") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) lib.impl("foo", foo_meta, "Meta") x = torch.tensor([0, 1.0], requires_grad=True) @@ -346,7 +355,10 @@ def foo_meta(x): lib.impl("foo", Foo.apply, "Autograd") lib.impl("foo", foo_impl, "CPU") lib.impl("foo", foo_impl, "CUDA") +<<<<<<< HEAD lib.impl("foo", foo_impl, "XPU") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) lib.impl("foo", foo_meta, "Meta") x = torch.tensor([0, 1.0]) @@ -373,7 +385,10 @@ def backward(ctx, gx): lib.impl("foo", Foo.apply, "CPU") lib.impl("foo", Foo.apply, "CUDA") +<<<<<<< HEAD lib.impl("foo", Foo.apply, "XPU") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) lib.impl("foo", lambda x: x.clone(), "Meta") x = torch.randn([], requires_grad=True) @@ -467,7 +482,10 @@ def foo_impl(x): lib.impl("foo", Foo.apply, "Autograd") lib.impl("foo", foo_impl, "CPU") lib.impl("foo", foo_impl, "CUDA") +<<<<<<< HEAD lib.impl("foo", foo_impl, "XPU") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x = torch.randn(3, requires_grad=True, device=device) # Should not raise @@ -517,7 +535,10 @@ def backward(ctx, gx): lib.impl("foo", Foo.apply, "CPU") lib.impl("foo", Foo.apply, "CUDA") +<<<<<<< HEAD lib.impl("foo", Foo.apply, "XPU") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x = torch.randn(3, requires_grad=True, device=device) with self.assertRaisesRegex(AssertionError, "incorrectly registered"): @@ -544,6 +565,65 @@ def test_assert_raises_regex(self, device): class TestCustomOp(CustomOpTestCaseBase): test_ns = "_test_custom_op" +<<<<<<< HEAD +======= + def test_deploy_interaction(self): + # run in a different process to avoid parallel issues when we monkeypatch torch._running_with_deploy + script = """ +import torch +torch._running_with_deploy = lambda: True + +# creating the library is a no-op, so you can DEF multiple times +m1 = torch.library.Library("mylib4392", "DEF") # noqa: TOR901 +m2 = torch.library.Library("mylib4392", "DEF") # noqa: TOR901 + +m = torch.library.Library("aten", "FRAGMENT") # noqa: TOR901 + +# define is a no-op +m.define("foobarbaz9996(Tensor x) -> Tensor") +assert not hasattr(torch.ops.aten, "foobarbaz9996"), "m.define should have been a noop" + +def sin_override(x): + raise AssertionError("m.impl should have been a noop") + +# impl is a no-op +m.impl("sin", sin_override, "CompositeImplicitAutograd") +x = torch.randn(3) +y = torch.sin(x) + +# should be a no-op +@torch.library.custom_op("mylib::foobar", mutates_args={}) +def foobar(x: torch.Tensor) -> torch.Tensor: + return x.sin() + +# should be a no-op +@foobar.register_fake +def _(x): + return torch.empty_like(x) + +# should be a no-op +m2.define("foobarbaz9996(Tensor x) -> Tensor") + +# should be a no-op +@torch.library.register_fake("mylib4392::foobarbaz9996") +def _(x): + return torch.empty_like(x) + """ + script = script.strip() + env = os.environ.copy() + try: + subprocess.check_output( + [sys.executable, "-c", script], + stderr=subprocess.STDOUT, + # On Windows, opening the subprocess with the default CWD makes `import torch` + # fail, so just set CWD to this script's directory + cwd=os.path.dirname(os.path.realpath(__file__)), + env=env, + ) + except subprocess.CalledProcessError as e: + self.fail(msg=("Subprocess exception:\n" + e.output.decode("utf-8"))) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @requires_compile def test_functionalize_error(self): with torch.library._scoped_library(self.test_ns, "FRAGMENT") as lib: @@ -581,7 +661,11 @@ def g(x): g(x) def test_invalid_schemas(self): +<<<<<<< HEAD # function schema validation goes through torchgen, so this is just a +======= + # function schmea validation goes through torchgen, so this is just a +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # basic test. with self.assertRaisesRegex(AssertionError, "Invalid function schema: foo"): custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", "(") @@ -1608,7 +1692,11 @@ def test_impl_abstract_overload(self): lib = self.lib() lib.define("sin.blah(Tensor x) -> Tensor") +<<<<<<< HEAD torch.library.register_fake( +======= + torch.library.impl_abstract( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f"{self.test_ns}::sin.blah", torch.empty_like, lib=lib ) @@ -1621,7 +1709,11 @@ def test_impl_meta(self): def foo(x: torch.Tensor, dim: int) -> torch.Tensor: raise NotImplementedError +<<<<<<< HEAD @torch.library.register_fake(f"{TestCustomOp.test_ns}::foo", lib=self.lib()) +======= + @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def foo_meta(x, dim): output_shape = list(x.shape) del output_shape[dim] @@ -1637,7 +1729,11 @@ def test_duplicate_impl(self): def foo(x: torch.Tensor, dim: int) -> torch.Tensor: raise NotImplementedError +<<<<<<< HEAD @torch.library.register_fake(f"{TestCustomOp.test_ns}::foo", lib=self.lib()) +======= + @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def foo_meta(x, dim): output_shape = list(x.shape) del output_shape[dim] @@ -1645,7 +1741,11 @@ def foo_meta(x, dim): with self.assertRaisesRegex(RuntimeError, r"test_custom_ops.py:\d+"): +<<<<<<< HEAD @torch.library.register_fake(f"{TestCustomOp.test_ns}::foo", lib=self.lib()) +======= + @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def foo_meta2(x, dim): output_shape = list(x.shape) del output_shape[dim] @@ -1656,7 +1756,11 @@ def test_new_data_dependent_symint(self): def foo(x: torch.Tensor) -> torch.Tensor: raise NotImplementedError +<<<<<<< HEAD @torch.library.register_fake(f"{TestCustomOp.test_ns}::foo", lib=self.lib()) +======= + @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def foo_meta(x): ctx = torch.library.get_ctx() r = ctx.new_dynamic_size(min=1) @@ -1683,7 +1787,11 @@ def test_basic_make_fx(self): def foo(x: torch.Tensor) -> torch.Tensor: raise NotImplementedError +<<<<<<< HEAD @torch.library.register_fake(f"{TestCustomOp.test_ns}::foo", lib=self.lib()) +======= + @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def foo_meta(x): return x.sum() @@ -1768,8 +1876,12 @@ def f(x): Hint: Enable tracing of dynamic shape operators with `torch._dynamo.config.capture_dynamic_output_shape_ops = True` Developer debug context: _torch_testing.numpy_nonzero.default +<<<<<<< HEAD For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0036.html""", +======= +""", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # pre-existing problem: torch.compile(dynamic=True) will, by default, @@ -1827,7 +1939,11 @@ def test_abstract_impl_on_existing_op(self): lib.define("foo(Tensor x) -> Tensor") qualname = f"{self.test_ns}::foo" +<<<<<<< HEAD @torch.library.register_fake(qualname, lib=self.lib()) +======= + @torch.library.impl_abstract(qualname, lib=self.lib()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def foo_impl(x): return x.sin() @@ -1850,7 +1966,11 @@ def foo_impl(x): op = self.get_op(qualname) with self.assertRaisesRegex(RuntimeError, r"already has .*Meta implementation"): +<<<<<<< HEAD torch.library.register_fake(qualname, foo_impl, lib=self.lib()) +======= + torch.library.impl_abstract(qualname, func=foo_impl, lib=self.lib()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_abstract_impl_on_existing_op_with_CompositeImplicitAutograd(self): lib = self.lib() @@ -1864,7 +1984,11 @@ def foo_impl(x): op = self.get_op(qualname) with self.assertRaisesRegex(RuntimeError, "CompositeImplicitAutograd"): +<<<<<<< HEAD torch.library.register_fake(qualname, foo_impl, lib=self.lib()) +======= + torch.library.impl_abstract(qualname, func=foo_impl, lib=self.lib()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_abstract_impl_on_existing_op_with_CompositeExplicitAutograd(self): lib = self.lib() @@ -1877,7 +2001,11 @@ def foo_impl(x): lib.impl("foo", foo_impl, "CompositeExplicitAutograd") op = self.get_op(qualname) +<<<<<<< HEAD torch.library.register_fake(qualname, lambda x: x.sum(), lib=self.lib()) +======= + torch.library.impl_abstract(qualname, func=lambda x: x.sum(), lib=self.lib()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with torch._subclasses.FakeTensorMode(): x = torch.randn(10) result = op(x) @@ -2333,6 +2461,7 @@ def test_autograd_function_backed_op(self, load_inline): loss.backward() self.assertEqual(x.grad, temp) +<<<<<<< HEAD # Using a non-existent DSO is a quick way to trigger an OSError, # which can be used to not break BC. def test_load_library(self): @@ -2341,6 +2470,8 @@ def test_load_library(self): ): torch.ops.load_library("libnoexist.so") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def op_with_incorrect_schema(testcase, name): lib = testcase.lib() @@ -2672,7 +2803,11 @@ def backward(ctx, grad): self.assertEqual(ctx.needs_input_grad, expected) return list(grad.unbind(0)) +<<<<<<< HEAD # call two applies, do a backward on the first +======= + # call two applys, do a backward on the first +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def t(): return torch.randn([], requires_grad=True) @@ -4637,10 +4772,15 @@ def test_version(self): loaded = read_profiles_from_yaml(yaml_str) +<<<<<<< HEAD only_for = ("cpu", "cuda", "xpu") instantiate_device_type_tests( TestCustomOpTesting, globals(), only_for=only_for, allow_xpu=True ) +======= +only_for = ("cpu", "cuda") +instantiate_device_type_tests(TestCustomOpTesting, globals(), only_for=only_for) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) instantiate_parametrized_tests(TestCustomOp) instantiate_parametrized_tests(TestCustomOpAPI) diff --git a/test/test_dataloader.py b/test/test_dataloader.py index b2f47c437fc33..43675af621cc6 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -25,7 +25,10 @@ from torch.testing._internal.common_utils import ( IS_CI, IS_JETSON, +<<<<<<< HEAD IS_MACOS, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) IS_S390X, IS_SANDCASTLE, IS_WINDOWS, @@ -735,12 +738,21 @@ class SleepDataset(Dataset): def __init__(self, size, sleep_sec): self.size = size self.sleep_sec = sleep_sec +<<<<<<< HEAD self.slept = False def __getitem__(self, idx): if not self.slept: time.sleep(self.sleep_sec) self.slept = True +======= + self.sleeped = False + + def __getitem__(self, idx): + if not self.sleeped: + time.sleep(self.sleep_sec) + self.sleeped = True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return idx def __len__(self): @@ -3134,6 +3146,7 @@ def test_pin_memory(self): self.assertTrue(sample["a_tensor"].is_pinned()) self.assertTrue(sample["another_dict"]["a_number"].is_pinned()) +<<<<<<< HEAD @skipIfXpu @skipIfRocm @unittest.skipIf(TEST_CUDA, "Test for when CUDA is not available") @@ -3143,6 +3156,8 @@ def test_pin_memory_no_cuda(self): self.assertFalse(sample["a_tensor"].is_pinned()) self.assertFalse(sample["another_dict"]["a_number"].is_pinned()) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") def test_pin_memory_device(self): loader = DataLoader( @@ -3478,10 +3493,13 @@ def _run_ind_worker_queue_test(self, batch_size, num_workers): if current_worker_idx == num_workers: current_worker_idx = 0 +<<<<<<< HEAD @unittest.skipIf( IS_WINDOWS or IS_MACOS, "Flaky on Windows and MacOS https://github.com/pytorch/pytorch/issues/68643", ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_ind_worker_queue(self): max_num_workers = None if hasattr(os, "sched_getaffinity"): diff --git a/test/test_datapipe.py b/test/test_datapipe.py index 2a57bef2075b8..a05b195165cd8 100644 --- a/test/test_datapipe.py +++ b/test/test_datapipe.py @@ -573,7 +573,11 @@ def operations(df): class TestDataFramesPipes(TestCase): """ +<<<<<<< HEAD Most of test will fail if pandas installed, but no dill available. +======= + Most of test will fail if pandas instaled, but no dill available. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Need to rework them to avoid multiple skips. """ @@ -1887,7 +1891,11 @@ def _non_bool_fn(data): with self.assertRaises(ValueError): list(filter_dp) +<<<<<<< HEAD # Functional Test: Specify input_col +======= + # Funtional Test: Specify input_col +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tuple_input_ds = dp.iter.IterableWrapper([(d - 1, d, d + 1) for d in range(10)]) # Single input_col @@ -3356,7 +3364,11 @@ def construct_sharded_pipe(): with self.assertRaises(Exception): dp.apply_sharding(2, 1, sharding_group=SHARDING_PRIORITIES.DEFAULT) +<<<<<<< HEAD # Test tud.datapipes.iter.grouping.SHARDING_PRIORITIES for backward compatibility +======= + # Test tud.datapipes.iter.grouping.SHARDING_PRIORITIES for backward compatbility +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO: Remove this test once tud.datapipes.iter.grouping.SHARDING_PRIORITIES is deprecated def test_sharding_groups_in_legacy_grouping_package(self): with self.assertWarnsRegex( diff --git a/test/test_decomp.py b/test/test_decomp.py index 5a2e427057460..98a3507eef380 100644 --- a/test/test_decomp.py +++ b/test/test_decomp.py @@ -15,7 +15,11 @@ from torch._export.utils import _is_cia_op from torch._ops import DispatchKey from torch.testing import make_tensor +<<<<<<< HEAD from torch.testing._internal.common_cuda import SM70OrLater, tf32_off +======= +from torch.testing._internal.common_cuda import tf32_off +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, onlyCPU, @@ -854,13 +858,18 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): # de-functionalise the graph, as that would break AoTAutograd # We run the real function *after* the decomposition to make sure that the # decomposition does not modify any of the inputs in-place. If it does +<<<<<<< HEAD # real_out should be different than decom_out so we should catch this +======= + # real_out should be differen than decom_out so we should catch this +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) real_out_unflat = func(*args, **kwargs) real_out = pytree.tree_leaves(real_out_unflat) assert len(real_out) == len(decomp_out) if do_relative_check: +<<<<<<< HEAD device_arg = kwargs.get("device", None) def upcast(x): @@ -871,6 +880,9 @@ def upcast(x): else: return upcast_tensor(x, dtype=torch.float64) +======= + upcast = partial(upcast_tensor, dtype=torch.float64) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) real_out_double, _ = tree_flatten( func(*tree_map(upcast, args), **tree_map(upcast, kwargs)) ) @@ -1235,6 +1247,7 @@ def f(x, w, b): for o_ref, o in zip(out_ref, out): self.assertEqual(o_ref.dtype, o.dtype) +<<<<<<< HEAD @onlyCUDA @unittest.skipIf(not SM70OrLater, "triton") def test_rms_norm_decomp_cuda(self, device): @@ -1262,6 +1275,8 @@ def forward_pass_fn(): "triton_per_fused__fused_rms_norm_backward_cosh_mul" in generated_codes[1] ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) instantiate_device_type_tests(DecompOneOffTests, globals()) diff --git a/test/test_deploy.py b/test/test_deploy.py new file mode 100644 index 0000000000000..b852802c0c20f --- /dev/null +++ b/test/test_deploy.py @@ -0,0 +1,43 @@ +# Owner(s): ["oncall: package/deploy"] + +import textwrap +import types + +from torch.testing._internal.common_utils import run_tests, TestCase +from torch.utils._freeze import Freezer, PATH_MARKER + + +class TestFreezer(TestCase): + """Tests the freeze.py script""" + + def test_compile_string(self): + freezer = Freezer(True) + code_str = textwrap.dedent( + """ + class MyCls: + def __init__(self) -> None: + pass + """ + ) + co = freezer.compile_string(code_str) + num_co = 0 + + def verify_filename(co: types.CodeType): + nonlocal num_co + + if not isinstance(co, types.CodeType): + return + + self.assertEqual(PATH_MARKER, co.co_filename) + num_co += 1 + + for nested_co in co.co_consts: + verify_filename(nested_co) + + verify_filename(co) + # there is at least one nested code object besides the top level one + self.assertTrue(num_co >= 2) + + +if __name__ == "__main__": + run_tests() diff --git a/test/test_dlpack.py b/test/test_dlpack.py index b960575cc6348..37a95e4e038b5 100644 --- a/test/test_dlpack.py +++ b/test/test_dlpack.py @@ -3,6 +3,7 @@ import torch from torch.testing import make_tensor from torch.testing._internal.common_device_type import ( +<<<<<<< HEAD deviceCountAtLeast, dtypes, dtypesIfMPS, @@ -26,6 +27,18 @@ TestCase, ) from torch.utils.dlpack import DLDeviceType, from_dlpack, to_dlpack +======= + dtypes, + instantiate_device_type_tests, + onlyCUDA, + onlyNativeDeviceTypes, + skipCUDAIfRocm, + skipMeta, +) +from torch.testing._internal.common_dtype import all_types_and_complex_and +from torch.testing._internal.common_utils import IS_JETSON, run_tests, TestCase +from torch.utils.dlpack import from_dlpack, to_dlpack +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Wraps a tensor, exposing only DLPack methods: @@ -60,7 +73,10 @@ class TestTorchDlPack(TestCase): torch.uint64, ) ) +<<<<<<< HEAD @dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat, torch.chalf)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_dlpack_capsule_conversion(self, device, dtype): x = make_tensor((5,), dtype=dtype, device=device) z = from_dlpack(to_dlpack(x)) @@ -78,7 +94,10 @@ def test_dlpack_capsule_conversion(self, device, dtype): torch.uint64, ) ) +<<<<<<< HEAD @dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat, torch.chalf)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_dlpack_protocol_conversion(self, device, dtype): x = make_tensor((5,), dtype=dtype, device=device) z = from_dlpack(x) @@ -87,8 +106,12 @@ def test_dlpack_protocol_conversion(self, device, dtype): @skipMeta @onlyNativeDeviceTypes def test_dlpack_shared_storage(self, device): +<<<<<<< HEAD dtype = torch.bfloat16 if device.startswith("mps") else torch.float64 x = make_tensor((5,), dtype=dtype, device=device) +======= + x = make_tensor((5,), dtype=torch.float64, device=device) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) z = from_dlpack(to_dlpack(x)) z[0] = z[0] + 20.0 self.assertEqual(z, x) @@ -128,14 +151,20 @@ def test_dlpack_conversion_with_streams(self, device, dtype): torch.uint64, ) ) +<<<<<<< HEAD @dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat, torch.chalf)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_from_dlpack(self, device, dtype): x = make_tensor((5,), dtype=dtype, device=device) y = torch.from_dlpack(x) self.assertEqual(x, y) @skipMeta +<<<<<<< HEAD @skipIfMPS # MPS crashes with noncontiguous now +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyNativeDeviceTypes @dtypes( *all_types_and_complex_and( @@ -182,7 +211,11 @@ def test_dlpack_conversion_with_diff_streams(self, device, dtype): # in the current stream to make sure that it was correctly populated. with torch.cuda.stream(stream_a): x = make_tensor((5,), dtype=dtype, device=device) + 1 +<<<<<<< HEAD z = torch.from_dlpack(x.__dlpack__(stream=stream_b.cuda_stream)) +======= + z = torch.from_dlpack(x.__dlpack__(stream_b.cuda_stream)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) stream_a.synchronize() stream_b.synchronize() self.assertEqual(z, x) @@ -199,7 +232,10 @@ def test_dlpack_conversion_with_diff_streams(self, device, dtype): torch.uint64, ) ) +<<<<<<< HEAD @dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat, torch.chalf)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_from_dlpack_dtype(self, device, dtype): x = make_tensor((5,), dtype=dtype, device=device) y = torch.from_dlpack(x) @@ -220,7 +256,11 @@ def __dlpack__(self, stream=None): assert stream == 1 else: assert stream == 0 +<<<<<<< HEAD capsule = self.tensor.__dlpack__(stream=stream) +======= + capsule = self.tensor.__dlpack__(stream) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return capsule # CUDA-based tests runs on non-default streams @@ -243,7 +283,11 @@ def test_dlpack_convert_default_stream(self, device): x = torch.zeros(1, device=device) torch.cuda._sleep(2**20) self.assertTrue(torch.cuda.default_stream().query()) +<<<<<<< HEAD x.__dlpack__(stream=1) +======= + x.__dlpack__(1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # check that the default stream has work (a pending cudaStreamWaitEvent) self.assertFalse(torch.cuda.default_stream().query()) @@ -255,6 +299,7 @@ def test_dlpack_tensor_invalid_stream(self, device, dtype): x = make_tensor((5,), dtype=dtype, device=device) x.__dlpack__(stream=object()) +<<<<<<< HEAD @skipMeta @onlyCUDA @skipCUDAIfRocm @@ -311,25 +356,39 @@ def test_dlpack_tensor_on_different_device(self, devices): with torch.device(dev1): x.__dlpack__() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO: add interchange tests once NumPy 1.22 (dlpack support) is required @skipMeta def test_dlpack_export_requires_grad(self): x = torch.zeros(10, dtype=torch.float32, requires_grad=True) +<<<<<<< HEAD with self.assertRaisesRegex(BufferError, r"require gradient"): +======= + with self.assertRaisesRegex(RuntimeError, r"require gradient"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x.__dlpack__() @skipMeta def test_dlpack_export_is_conj(self): x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]) y = torch.conj(x) +<<<<<<< HEAD with self.assertRaisesRegex(BufferError, r"conjugate bit"): +======= + with self.assertRaisesRegex(RuntimeError, r"conjugate bit"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) y.__dlpack__() @skipMeta def test_dlpack_export_non_strided(self): x = torch.sparse_coo_tensor([[0]], [1], size=(1,)) y = torch.conj(x) +<<<<<<< HEAD with self.assertRaisesRegex(BufferError, r"strided"): +======= + with self.assertRaisesRegex(RuntimeError, r"strided"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) y.__dlpack__() @skipMeta @@ -356,6 +415,7 @@ def test_automatically_select_in_creation(self, device): new_tensor = torch.tensor(wrap) self.assertEqual(tensor, new_tensor) +<<<<<<< HEAD @skipMeta @skipIfTorchDynamo("__dlpack__ doesn't work with dynamo") @onlyNativeDeviceTypes @@ -495,6 +555,10 @@ def test_dlpack_unsupported_dtype_error(self, device): instantiate_device_type_tests(TestTorchDlPack, globals(), allow_mps=True) +======= + +instantiate_device_type_tests(TestTorchDlPack, globals()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 7ba466119da85..4df8ae07bae9d 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -15,7 +15,11 @@ import torch.nn.functional as F from torch import sym_int, SymBool, SymFloat, SymInt from torch._C import _disabled_torch_function_impl +<<<<<<< HEAD from torch._dynamo.testing import CompileCounter, CompileCounterWithBackend +======= +from torch._dynamo.testing import CompileCounterWithBackend +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.utils import fresh_cache from torch.fx.experimental import sym_node from torch.fx.experimental.proxy_tensor import make_fx @@ -861,7 +865,11 @@ def test_mul_int_oo_nan(self): s2 = create_symint(shape_env, 5, duck=False) bool(s0 * (s1 // s0) == s2) +<<<<<<< HEAD def test_non_overlapping_and_dense_backed(self): +======= + def test_non_overlapping_and_dense(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shape_env = ShapeEnv() a0 = create_symint(shape_env, 5) r = torch.empty_strided((a0, 7), (1, a0), device="meta") @@ -896,6 +904,7 @@ def test_non_overlapping_and_dense_unbacked(self): ) ) +<<<<<<< HEAD def test_prims_non_overlapping_and_dense(self): shape_env = ShapeEnv() cf = torch._prims_common.is_non_overlapping_and_dense @@ -954,6 +963,8 @@ def test_prims_non_overlapping_and_dense(self): ) ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_sympy_optimized_add_binary_search(self): import sympy @@ -1915,6 +1926,7 @@ def is_complex(x): class TestDimConstraints(TestCase): +<<<<<<< HEAD @skipIfTorchDynamo("mark_dynamic not supported") def test_simplify_max_1_0(self): x = torch.rand(10) @@ -1937,6 +1949,8 @@ def func(x, v): self.assertEqual(func(x, 1), x * 400) self.assertEqual(func(x, 0), x * 400) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_dim_constraints_reduce_congruences_simple(self): from sympy import Symbol @@ -3130,7 +3144,10 @@ def custom_pass(graph: torch.fx.Graph) -> torch.fx.Graph: class TestUnbacked(TestCase): +<<<<<<< HEAD @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/156135") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._dynamo.config.patch("capture_scalar_outputs", True) @parametrize("backend", ["inductor", "eager"]) def test_deferred_neq_assert(self, backend): @@ -3178,7 +3195,10 @@ def func(x, y): with self.assertRaises(RuntimeError): func(torch.rand(2, 50), torch.tensor([51])) +<<<<<<< HEAD @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/156135") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._dynamo.config.patch("capture_scalar_outputs", True) @parametrize("backend", ["inductor", "eager"]) def test_deferred_sym_or_assert(self, backend): @@ -3200,7 +3220,10 @@ def test_has_free_symbols(self): self.assertTrue(has_free_symbols(sympy.sympify("a*2"))) self.assertTrue(has_free_symbols(sympy.sympify("a+b"))) +<<<<<<< HEAD @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/156135") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._dynamo.config.patch("capture_scalar_outputs", True) @parametrize("backend", ["inductor", "eager"]) def test_deferred_sym_eq_assert(self, backend): @@ -3245,10 +3268,16 @@ def func(x, y): f = y.item() t1 = x.view((f, f)) t2 = x.reshape((f, f)) +<<<<<<< HEAD t3 = torch._ops.ops.aten.view_copy(x, (f, f)) # TODO avoid _check_is_size here. torch._check_is_size(f) return t1 * 10, t2 * 10, t3 +======= + # TODO avoid _check_is_size here. + torch._check_is_size(f) + return t1 * 10, t2 * 10 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) compiled_func = torch.compile( fullgraph=True, @@ -3270,7 +3299,11 @@ def make_non_contiguous_tensor_and_test(cnt): self.assertEqual(compiled_result, eager_result) log_stream, ctx = logs_to_string( +<<<<<<< HEAD "torch._functorch._aot_autograd.graph_capture", "aot_graphs" +======= + "torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) with ctx(): make_non_contiguous_tensor_and_test(4) @@ -3288,12 +3321,19 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)", eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None _assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None view: "i64[u0, u0][s7*u0, s7]cpu" = torch.ops.aten.view.default(arg3_1, [_local_scalar_dense, _local_scalar_dense]) +<<<<<<< HEAD view_1: "i64[u0, u0][s7*u0, s7]cpu" = torch.ops.aten.view.default(arg3_1, [_local_scalar_dense, _local_scalar_dense]) view_2: "i64[u0, u0][s7*u0, s7]cpu" = torch.ops.aten.view.default(arg3_1, [_local_scalar_dense, _local_scalar_dense]); arg3_1 = _local_scalar_dense = None clone: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.clone.default(view_2); view_2 = None mul_11: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None mul_14: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view_1, 10); view_1 = None return (mul_11, mul_14, clone)""", # noqa: B950 +======= + view_1: "i64[u0, u0][s7*u0, s7]cpu" = torch.ops.aten.view.default(arg3_1, [_local_scalar_dense, _local_scalar_dense]); arg3_1 = _local_scalar_dense = None + mul_9: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None + mul_12: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view_1, 10); view_1 = None + return (mul_9, mul_12)""", # noqa: B950 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ignore_comments=True, ignore_empty_lines=True, ) @@ -3307,7 +3347,11 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)", torch._dynamo.decorators.mark_unbacked(x, 0) log_stream, ctx = logs_to_string( +<<<<<<< HEAD "torch._functorch._aot_autograd.graph_capture", "aot_graphs" +======= + "torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) with ctx(): compiled_result = compiled_func(x, torch.tensor([10])) @@ -3329,12 +3373,19 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1] eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None _assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None view: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.view.default(arg2_1, [_local_scalar_dense, _local_scalar_dense]) +<<<<<<< HEAD view_1: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.view.default(arg2_1, [_local_scalar_dense, _local_scalar_dense]) view_2: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.view.default(arg2_1, [_local_scalar_dense, _local_scalar_dense]); arg2_1 = _local_scalar_dense = None clone: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.clone.default(view_2); view_2 = None mul_6: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None mul_9: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view_1, 10); view_1 = None return (mul_6, mul_9, clone)""", # noqa: B950 +======= + view_1: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.view.default(arg2_1, [_local_scalar_dense, _local_scalar_dense]); arg2_1 = _local_scalar_dense = None + mul_4: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None + mul_7: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view_1, 10); view_1 = None + return (mul_4, mul_7)""", # noqa: B950 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ignore_comments=True, ignore_empty_lines=True, ) @@ -3349,7 +3400,11 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1] def test_unbacked_reshape2(self): cnt = CompileCounterWithBackend("inductor") +<<<<<<< HEAD # This reshape requires a clone when the input is not contiguous and we can't compute strides. +======= + # This reshape requires a clone when the input is not contiguous and we cant compute strides. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # reshape (u2, u3) -> (u0, u1) def func(x, y): u0, u1 = y.tolist() @@ -3368,7 +3423,11 @@ def func(x, y): torch._dynamo.decorators.mark_unbacked(x, 1) log_stream, ctx = logs_to_string( +<<<<<<< HEAD "torch._functorch._aot_autograd.graph_capture", "aot_graphs" +======= + "torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) with ctx(): result_eager = func(x, torch.tensor([5, 20])) @@ -3399,8 +3458,13 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)", _assert_scalar_4 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u2*u3, u0*u1) on node 'eq'"); eq = _assert_scalar_4 = None clone: "f32[u2, u3][Max(1, u3), 1]cpu" = torch.ops.aten.clone.default(arg3_1, memory_format = torch.contiguous_format); arg3_1 = None view: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.view.default(clone, [_local_scalar_dense, _local_scalar_dense_1]); clone = _local_scalar_dense = _local_scalar_dense_1 = None +<<<<<<< HEAD mul_21: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None return (mul_21,)""", # noqa: B950 +======= + mul_19: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None + return (mul_19,)""", # noqa: B950 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ignore_comments=True, ignore_empty_lines=True, ) @@ -3418,11 +3482,19 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)", # Pass a contiguous tensor. A recompilation will happen due to 0/1 speciialization on stride. log_stream, ctx = logs_to_string( +<<<<<<< HEAD "torch._functorch._aot_autograd.graph_capture", "aot_graphs" ) with ctx(): # This used to hit could guard on data-dependent expression Eq(10, u3) x.stride[0]==10. and x.size()=[u2, u3]. # but not anymore since we use contiguous_or_false . +======= + "torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs" + ) + with ctx(): + # This used to hit could guard on data-dependent expression Eq(10, u3) x.stride[0]==10. and x.size()=[u2, u3]. + # but not anymore since we use definitely_contiguous . +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # We need a way to mark strides unbacked to avoid the recompilation here. x = torch.randn(10, 10) torch._dynamo.decorators.mark_unbacked(x, 0) @@ -3484,7 +3556,11 @@ def make_non_contiguous_tensor(cnt): def test_invalid_view_unbacked_view(self): cnt = CompileCounterWithBackend("inductor") +<<<<<<< HEAD # This view (u2, u3) -> (u0, u1) can't happen in general unless we know that input is contiguous or we have +======= + # This view (u2, u3) -> (u0, u1) cant happen in general unless we know that input is contigous or we have +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # hints to to compute strides. def func(x, y): u0, u1 = y.tolist() @@ -3505,6 +3581,7 @@ def func(x, y): # throws a data dependent error. compiled_func(x, torch.tensor([5, 20])) +<<<<<<< HEAD @skipIfTorchDynamo() def test_unbind_not_dynamic(self): cnt = CompileCounter() @@ -3685,6 +3762,8 @@ def f(idx, x): out = torch.compile(f)(idx, x) self.assertEqual(out, f(idx, x)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) instantiate_parametrized_tests(TestUnbacked) diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py index 78e3ff69ed5b5..e82a06c5d0251 100644 --- a/test/test_fake_tensor.py +++ b/test/test_fake_tensor.py @@ -64,7 +64,10 @@ skipIfCrossRef, skipIfRocm, skipIfTorchDynamo, +<<<<<<< HEAD skipIfWindows, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TemporaryFileName, TEST_WITH_TORCHDYNAMO, TestCase, @@ -97,7 +100,11 @@ def checkType(self, t, device_str, size): @unittest.skipIf(not RUN_CUDA, "requires cuda") def test_cuda_initialized(self): +<<<<<<< HEAD # doesn't error +======= + # doesnt error +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with FakeTensorMode(): p = torch.randn(4, 2, requires_grad=True, device="cuda") x = torch.randn(8, 4, device="cuda") @@ -211,6 +218,7 @@ def test_zero_dim(self): self.assertEqual(out.device, y.device) self.assertTrue(isinstance(out, FakeTensor)) +<<<<<<< HEAD @unittest.skipIf(not RUN_CUDA, "requires cuda") def test_op_with_zero_dim_bypassed(self): if torch._functorch.config.fake_tensor_propagate_real_tensors: @@ -227,6 +235,8 @@ def test_op_with_zero_dim_bypassed(self): ) as exc: torch.nextafter(fake_x, fake_y) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_nan_to_num(self): with FakeTensorMode(): for dtype in [torch.float16, torch.float32]: @@ -269,6 +279,7 @@ def test_device_inplace_copy(self): assert x.copy_(y).device.type == "cpu" assert y.copy_(x).device.type == "cuda" +<<<<<<< HEAD def test_fake_device(self): t = torch.ones(3) t = t.view(1, 3) @@ -282,6 +293,8 @@ def test_fake_device(self): self.assertEqual(new_fake_t.device, fake_t.device) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_fake_dispatch_keys(self): with FakeTensorMode(): x = torch.rand([4]) @@ -1048,6 +1061,7 @@ def test_fast_div(self): y = fast_div(mode, x, 2) self.assertEqual(y.dtype, torch.float32) +<<<<<<< HEAD def test_nanmean_out(self): # Regression test to ensure we don't error out. with torch._subclasses.fake_tensor.FakeTensorMode() as mode: @@ -1068,6 +1082,8 @@ def test_unbind_copy_out(self): self.assertEqual(out[1].dtype, eye.dtype) self.assertEqual(out[2].dtype, eye.dtype) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) instantiate_parametrized_tests(FakeTensorTest) @@ -1500,7 +1516,11 @@ def forward(self, arg1, arg2, arg3): with torch._subclasses.CrossRefFakeMode(): Repro()(*args) except MetadataMismatchError as e: +<<<<<<< HEAD # We expect the cross ref to succeed for the first output to fail +======= + # We expect the cross ref to succed for the first output to fail +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # for the rng state, see Note [Seed and Offset] self.assertTrue("output[0]" not in str(e)) if self.__class__.__name__.startswith("PropagateRealTensors"): @@ -1721,6 +1741,7 @@ def test_nonzero_stride(self): self.assertEqual(fake_r.T.is_contiguous(), r.T.is_contiguous()) +<<<<<<< HEAD def test_nan_to_num(self): shape_env = ShapeEnv() fake_mode = FakeTensorMode(shape_env=shape_env) @@ -1731,6 +1752,8 @@ def test_nan_to_num(self): self.assertEqual(x.size(), y.size()) self.assertEqual(x.stride(), y.stride()) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(not RUN_CUDA, "requires cuda") def test_torch_load_with_fake_mode(self): model = torch.nn.Linear(5, 10) @@ -1943,6 +1966,7 @@ def test_cache_key_constants(self): self._test_cache_key(fm, 1.0, 1.0, 1) self._test_cache_key(fm, 0.0, 0.0, 0) +<<<<<<< HEAD def test_empty_list(self): with FakeTensorMode() as fm: func = aten.any.dims @@ -1953,6 +1977,8 @@ def test_empty_list(self): self.assertNotEqual(key_x, key_y) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def assertHitsMisses(self, hits, misses): """ Helper to assert on the number of recorded hits and misses. @@ -2332,9 +2358,12 @@ def test_cache_aten_index(self): lambda: torch.ops.aten.index(x, [None, idx_tensor1]), ) +<<<<<<< HEAD @skipIfWindows( msg="weird bug - cache may not be cleared after https://github.com/pytorch/pytorch/pull/154283" ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skipIfTorchDynamo("cache hit/miss changes with invoke_subgraph caching") def test_invoke_subgraph(self): """ @@ -2376,7 +2405,11 @@ def fn(x, y): self.assertEqual(len(backend.fw_graphs), 1) mod = backend.fw_graphs[0] +<<<<<<< HEAD # Ensure that we see hits every time +======= + # Ensure that we see hits everytime +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with FakeTensorMode(): x = torch.randn(6, 4) y = torch.randn(6, 4) @@ -2496,6 +2529,7 @@ def forward( self.assertBypasses("unrepresented symbol in output", 2) +<<<<<<< HEAD class FakeTensorPreferDeviceType(TestCase): @unittest.skipIf(not RUN_CUDA, "requires cuda") def test_fake_tensor_prefer_device_type(self): @@ -2572,5 +2606,7 @@ def test_fake_tensor_prefer_device_type_cpu_only(self): self.assertTrue(isinstance(result, FakeTensor)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/test_foreach.py b/test/test_foreach.py index 7ac128d6bac8a..08ef496a62b7f 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -12,7 +12,11 @@ import torch from torch.testing import make_tensor from torch.testing._comparison import default_tolerances +<<<<<<< HEAD from torch.testing._internal.common_cuda import _get_torch_cuda_version, TEST_MULTIGPU +======= +from torch.testing._internal.common_cuda import TEST_MULTIGPU +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_device_type import ( dtypes, instantiate_device_type_tests, @@ -20,6 +24,10 @@ onlyCUDA, OpDTypes, ops, +<<<<<<< HEAD +======= + skipCUDAVersionIn, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) from torch.testing._internal.common_dtype import ( all_types_and_complex_and, @@ -43,7 +51,11 @@ TEST_WITH_ROCM, TestCase, ) +<<<<<<< HEAD from torch.testing._internal.triton_utils import requires_cuda_and_triton +======= +from torch.testing._internal.triton_utils import requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _BOOL_SUB_ERR_MSG = "Subtraction, the `-` operator" @@ -79,6 +91,7 @@ def __init__(self, func): def __call__(self, inputs, is_cuda, expect_fastpath, **kwargs): actual = None zero_size = kwargs.pop("zero_size", False) +<<<<<<< HEAD # Skip profiler check for CUDA 12.6, 12.8 as the upgrade makes profiler results flaky # https://github.com/pytorch/pytorch/issues/148681. TODO: ADD IT BACK!!! @@ -86,6 +99,10 @@ def __call__(self, inputs, is_cuda, expect_fastpath, **kwargs): if ( is_cuda and not skip_profiler_check +======= + if ( + is_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) and torch.autograd.kineto_available() and torch.profiler.ProfilerActivity.CUDA in torch.profiler.supported_activities() @@ -96,7 +113,10 @@ def __call__(self, inputs, is_cuda, expect_fastpath, **kwargs): torch.cuda.synchronize() keys = tuple([e.key for e in p.key_averages()]) mta_called = any("multi_tensor_apply_kernel" in k for k in keys) +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert mta_called == (expect_fastpath and (not zero_size)), ( f"{mta_called=}, {expect_fastpath=}, {zero_size=}, {self.func.__name__=}, {keys=}" ) @@ -196,6 +216,12 @@ def test_all_zero_size_tensors_do_not_launch_kernel(self, device, dtype, op): zero_size=True, ) +<<<<<<< HEAD +======= + # Skip CUDA version 12.8 as the upgrade makes profiler results flaky + # https://github.com/pytorch/pytorch/issues/148681 + @skipCUDAVersionIn([(12, 8)]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skipIfRocmVersionLessThan((6, 0)) @ops( foreach_unary_op_db @@ -308,6 +334,12 @@ def _binary_test( else: self.assertEqual(expected, actual) +<<<<<<< HEAD +======= + # Skip CUDA version 12.8 as the upgrade makes profiler results flaky + # https://github.com/pytorch/pytorch/issues/148681 + @skipCUDAVersionIn([(12, 8)]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @ops(filter(lambda op: op.supports_scalar_self_arg, foreach_binary_op_db)) @parametrize("is_fastpath", (True, False)) def test_binary_op_with_scalar_self_support(self, device, dtype, op, is_fastpath): @@ -365,6 +397,12 @@ def clone(arg): @ops(foreach_pointwise_op_db) @parametrize("is_fastpath", (True, False)) +<<<<<<< HEAD +======= + # Skip CUDA version 12.8 as the upgrade makes profiler results flaky + # https://github.com/pytorch/pytorch/issues/148681 + @skipCUDAVersionIn([(12, 8)]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_pointwise_op_with_tensor_of_scalarlist_overload( self, device, dtype, op, is_fastpath ): @@ -702,6 +740,12 @@ def test_binary_op_list_error_cases(self, device, dtype, op): ): foreach_op_([tensor1], [tensor2]) +<<<<<<< HEAD +======= + # Skip CUDA version 12.8 as the upgrade makes profiler results flaky + # https://github.com/pytorch/pytorch/issues/148681 + @skipCUDAVersionIn([(12, 8)]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not found") @ops( filter(lambda op: op.supports_out, foreach_binary_op_db), @@ -817,6 +861,12 @@ def test_binary_op_list_slow_path(self, device, dtype, op): scalar_self_arg=False, ) +<<<<<<< HEAD +======= + # Skip CUDA version 12.8 as the upgrade makes profiler results flaky + # https://github.com/pytorch/pytorch/issues/148681 + @skipCUDAVersionIn([(12, 8)]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @ops( filter(lambda op: op.supports_out, foreach_binary_op_db), dtypes=floating_types_and(torch.half, torch.bfloat16), @@ -1340,6 +1390,12 @@ def test_foreach_copy_with_multi_device_inputs(self, device, dtype, op): copy_(t, s, non_blocking) self.assertEqual(ref_input, sample.input) +<<<<<<< HEAD +======= + # Skip CUDA version 12.8 as the upgrade makes profiler results flaky + # https://github.com/pytorch/pytorch/issues/148681 + @skipCUDAVersionIn([(12, 8)]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyCUDA @ops(filter(lambda op: op.name == "_foreach_copy", foreach_binary_op_db)) def test_foreach_copy_with_multi_dtypes(self, device, dtype, op): @@ -1375,7 +1431,11 @@ def test_foreach_copy_with_multi_dtypes_large_input(self): ref_out = torch.empty_like(self_tensor).copy_(src_tensor) self.assertEqual(self_tensor, ref_out) +<<<<<<< HEAD @requires_cuda_and_triton +======= + @requires_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @ops(filter(lambda op: op.name == "_foreach_copy", foreach_binary_op_db)) def test_foreach_copy_with_different_device_inputs(self, device, dtype, op): if dtype in (torch.complex128, torch.complex64): diff --git a/test/test_functionalization.py b/test/test_functionalization.py index 65e74297a531f..0588ff20e9219 100644 --- a/test/test_functionalization.py +++ b/test/test_functionalization.py @@ -199,7 +199,11 @@ def f(x): y.set_(x.storage()) return y +<<<<<<< HEAD # We should probably get the crossref test to work, +======= + # We should probaby get the crossref test to work, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # but fixing it for Storage() objects is annoying. r = _functionalize(f, reapply_views=True, crossref=False)(torch.ones(2)) self.assertEqual(str(r.device), "cpu") @@ -2318,7 +2322,11 @@ def forward(self, arg0_1): ] ) @unittest.skipIf( +<<<<<<< HEAD TEST_WITH_TORCHDYNAMO, "dynamo-ing code with proxy + fake doesn't work well" +======= + TEST_WITH_TORCHDYNAMO, "dynamo-ing code with proxy + fake doesnt work well" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) class TestCrossRefFunctionalization(TestFunctionalization): crossref = True diff --git a/test/test_functionalization_of_rng_ops.py b/test/test_functionalization_of_rng_ops.py index 9b4542500d50d..cca20c684bfca 100644 --- a/test/test_functionalization_of_rng_ops.py +++ b/test/test_functionalization_of_rng_ops.py @@ -302,7 +302,11 @@ def fn(x, y): fwd_compiler = functools.partial(count_philox_rand, freq=1) bwd_compiler = functools.partial(count_philox_rand, freq=0) aot_fn = aot_function(fn, fwd_compiler, bwd_compiler) +<<<<<<< HEAD # We can't check accuracy here because rand_like generated different rand numbers than dropout +======= + # We cant check accuracy here because rand_like generated different rand numbers than dropout +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) res = aot_fn(x, y) res.sum().backward() @@ -316,7 +320,11 @@ def fn(x): # Ensure the decomp is happening aot_fn = aot_function(fn, functools.partial(count_philox_rand, freq=1)) +<<<<<<< HEAD # We can't check accuracy here because rand_like generated different rand numbers than dropout +======= + # We cant check accuracy here because rand_like generated different rand numbers than dropout +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) aot_fn(x) diff --git a/test/test_fx.py b/test/test_fx.py index ba80f69828df3..63123eb421d87 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -908,7 +908,11 @@ def __init__(self, interpreter): wrapper = WrapperModule(interpreter) # Create a graph that: 1) Takes function arguments 2) Invokes the interpreter +<<<<<<< HEAD # 3) Returns the specified return value +======= + # 3) Returns the speficied return value +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # FIXME: The following code could be greatly simplified by symbolic_trace'ing # the wrapper with a Tracer that considers the Wrapper instance a root @@ -954,7 +958,11 @@ def __init__(self, interpreter): script_out = scripted_lowered(x) torch.testing.assert_close(script_out, ref_out) +<<<<<<< HEAD # Test TorchScript Ser/De +======= + # Test TorchScript ser/de +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import_copy = self.getExportImportCopy(scripted_lowered) imported_out = import_copy(x) torch.testing.assert_close(imported_out, ref_out) @@ -2225,8 +2233,13 @@ def forward( foo_scripted = torch.jit.script(Foo()) foo_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3) +<<<<<<< HEAD fixed = symbolic_trace(Foo()) fxed_scripted = torch.jit.script(fixed) +======= + fxed = symbolic_trace(Foo()) + fxed_scripted = torch.jit.script(fxed) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fxed_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3) def test_fn_type_annotation_empty(self): @@ -4660,6 +4673,10 @@ def tearDown(self): "linear": BUILT_IN_FUNC, "logsigmoid": BUILT_IN_FUNC, "one_hot": BUILT_IN_FUNC, +<<<<<<< HEAD +======= + "pad": ARG_TYPE_MISMATCH, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "pairwise_distance": BUILT_IN_FUNC, "pdist": BUILT_IN_FUNC, "pixel_shuffle": BUILT_IN_FUNC, @@ -4692,6 +4709,15 @@ def tearDown(self): "max_unpool3d": PROXY_ITERATED, "fold": PROXY_ITERATED, "unfold": PROXY_ITERATED, +<<<<<<< HEAD +======= + "adaptive_max_pool1d_with_indices": ARG_TYPE_MISMATCH, + "fractional_max_pool2d_with_indices": ARG_TYPE_MISMATCH, + "fractional_max_pool3d_with_indices": ARG_TYPE_MISMATCH, + "layer_norm": ARG_TYPE_MISMATCH, + "rms_norm": ARG_TYPE_MISMATCH, + "lp_pool1d": ARG_TYPE_MISMATCH, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "affine_grid": CONTROL_FLOW, "alpha_dropout": CONTROL_FLOW, "batch_norm": CONTROL_FLOW, @@ -4725,6 +4751,12 @@ def tearDown(self): "leaky_relu": CONTROL_FLOW, "local_response_norm": CONTROL_FLOW, "margin_ranking_loss": CONTROL_FLOW, +<<<<<<< HEAD +======= + "max_pool1d_with_indices": ARG_TYPE_MISMATCH, + "max_pool2d_with_indices": ARG_TYPE_MISMATCH, + "max_pool3d_with_indices": ARG_TYPE_MISMATCH, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "mse_loss": CONTROL_FLOW, "multi_head_attention_forward": CONTROL_FLOW, "multi_margin_loss": CONTROL_FLOW, diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index 72d770e6d3f02..97aad0ccc95c2 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -53,7 +53,11 @@ ) from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.common_nn import module_tests, get_new_module_tests +<<<<<<< HEAD from torch.testing._internal.common_utils import TEST_Z3, run_tests, TestCase, TEST_WITH_CROSSREF +======= +from torch.testing._internal.common_utils import TEST_Z3, run_tests, TestCase +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.jit_utils import JitTestCase import torch.utils._pytree as pytree @@ -963,6 +967,7 @@ def _test_split_graph(split_gm): # `keep_original_order=True` _test_split_graph(split_module(g, None, split_callback=lambda _ : 0, keep_original_order=True)) +<<<<<<< HEAD @unittest.skipIf(TEST_WITH_CROSSREF, "See https://github.com/pytorch/pytorch/issues/160077") def test_split_module_symint_dependency_handling(self): # Based on the code from - transformers/models/granitemoe/modeling_granitemoe.py @@ -1052,6 +1057,8 @@ def backend(gm, inps): actual = torch.compile(moe, backend=backend)(inp) torch.testing.assert_close(actual, expected) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_normalize_binary_operators(self): ops_to_test = { torch.add, diff --git a/test/test_fx_passes.py b/test/test_fx_passes.py index be22f8e61e509..a9bd093fa22ec 100644 --- a/test/test_fx_passes.py +++ b/test/test_fx_passes.py @@ -110,7 +110,11 @@ def forward5(a, b, c): @staticmethod def forward6(a, b, c): +<<<<<<< HEAD # add should have its own partition, as neither branches are supported +======= + # add should have its own partition, as neither branchs are supported +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) add = a + 1 # left branch relu = add.relu() @@ -283,7 +287,11 @@ class TestFXGraphPasses(JitTestCase): (TestPartitionFunctions.forward15, [['add_1', 'add', 'permute_1', 'view', 'permute_2', 'permute_3', 'permute']], False), (TestPartitionFunctions.forward16, [["permute_1", "add_1", "add"]], True), (TestPartitionFunctions.forward16, [['add_1', 'add', 'permute_1', 'view', 'permute_2', 'permute_3', 'permute']], False), +<<<<<<< HEAD # should be empty partition, not a partition with empty nodes +======= + # should be empty partition, not a partiton with empty nodes +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) (TestPartitionFunctions.forward18, [], False), ]) def test_partitioner(self, fn, expected_partition, bookend_non_compute_pass): @@ -344,9 +352,15 @@ def test_partitioner_independent_output(self, fn, expected_partition): [['add', 'add_1', 'add_2']], # vertical fusion [['add_2', 'add_3']], # horizontal fusion [['add_3', 'add_4']], +<<<<<<< HEAD [['add_6', 'add_5']], # arbitrary node order [['add_4', 'add_1', 'add_3', 'add_2']], # arbitrary node order [['add_5', 'add_6'], ['add_1', 'add_2', 'add_3', 'add_4']], # arbitrary partition order +======= + [['add_6', 'add_5']], # arbitray node order + [['add_4', 'add_1', 'add_3', 'add_2']], # arbitray node order + [['add_5', 'add_6'], ['add_1', 'add_2', 'add_3', 'add_4']], # arbitray partition order +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) [['add_5', 'linear2']], # includes call_function + call_module node [['add_6', 'relu']], # includes call_function + call_module node [['param', 'add_2']], # includes get_attr + call_module nodes diff --git a/test/test_fx_reinplace_pass.py b/test/test_fx_reinplace_pass.py index 4acda3bece746..0ea4073f51d68 100644 --- a/test/test_fx_reinplace_pass.py +++ b/test/test_fx_reinplace_pass.py @@ -43,7 +43,11 @@ def test_reinplace_with_view(self): def f(x): a = x.clone() a_view = a.view(-1) +<<<<<<< HEAD # We shouldn't re-inplace the first add(), because an alias of a is reused later in the program +======= + # We shouldn't re-inplace the first add(), because an alias of a is re-used later in the program +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) b = a.add(1) # noqa: F841 # Second add() is fine to re-inplace @@ -188,11 +192,22 @@ def f(a_): def forward(self, a__1): clone = torch.ops.aten.clone.default(a__1); a__1 = None +<<<<<<< HEAD select = torch.ops.aten.select.int(clone, 1, 1) select_1 = torch.ops.aten.select.int(select, 0, 1); select = None add = torch.ops.aten.add_.Tensor(select_1, 1); select_1 = add = None select_2 = torch.ops.aten.select.int(clone, 1, 1); select_2 = None select_3 = torch.ops.aten.select.int(clone, 1, 1) +======= + slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807) + select = torch.ops.aten.select.int(slice_1, 1, 1); slice_1 = None + select_1 = torch.ops.aten.select.int(select, 0, 1); select = None + add = torch.ops.aten.add_.Tensor(select_1, 1); select_1 = add = None + slice_2 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807) + select_2 = torch.ops.aten.select.int(slice_2, 1, 1); slice_2 = select_2 = None + slice_3 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807) + select_3 = torch.ops.aten.select.int(slice_3, 1, 1); slice_3 = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) select_4 = torch.ops.aten.select.int(select_3, 0, 1); select_3 = select_4 = None return clone """) @@ -225,7 +240,12 @@ def f(a_): def forward(self, a__1): clone = torch.ops.aten.clone.default(a__1); a__1 = None +<<<<<<< HEAD select = torch.ops.aten.select.int(clone, 1, 1) +======= + slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807) + select = torch.ops.aten.select.int(slice_1, 1, 1); slice_1 = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) select_1 = torch.ops.aten.select.int(select, 0, 1); select = None add = torch.ops.aten.add_.Tensor(select_1, 1); select_1 = add = None as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 1); clone = None @@ -260,7 +280,12 @@ def f(a_): def forward(self, a__1): clone = torch.ops.aten.clone.default(a__1); a__1 = None +<<<<<<< HEAD select = torch.ops.aten.select.int(clone, 1, 1) +======= + slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807) + select = torch.ops.aten.select.int(slice_1, 1, 1); slice_1 = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) select_1 = torch.ops.aten.select.int(select, 0, 1); select = None add = torch.ops.aten.add.Tensor(select_1, 1); select_1 = None as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 1); clone = None @@ -292,7 +317,12 @@ def f(a_): def forward(self, a__1): clone = torch.ops.aten.clone.default(a__1); a__1 = None +<<<<<<< HEAD select = torch.ops.aten.select.int(clone, 1, 1) +======= + slice_1 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807) + select = torch.ops.aten.select.int(slice_1, 1, 1); slice_1 = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) select_1 = torch.ops.aten.select.int(select, 0, 1); select = None add = torch.ops.aten.add.Tensor(select_1, 1); select_1 = None as_strided = torch.ops.aten.as_strided.default(clone, [4], [4], 0); clone = None @@ -347,9 +377,18 @@ def f(): def forward(self): zeros = torch.ops.aten.zeros.default([4, 4, 4], device = device(type='cpu'), pin_memory = False) ones = torch.ops.aten.ones.default([4, 2, 4], device = device(type='cpu'), pin_memory = False) +<<<<<<< HEAD slice_1 = torch.ops.aten.slice.Tensor(zeros, 1, 2, 9223372036854775807) copy = torch.ops.aten.copy_.default(slice_1, ones); slice_1 = ones = copy = None slice_2 = torch.ops.aten.slice.Tensor(zeros, 1, 2, 9223372036854775807); slice_2 = None +======= + slice_1 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807) + slice_2 = torch.ops.aten.slice.Tensor(slice_1, 1, 2, 9223372036854775807); slice_1 = None + copy = torch.ops.aten.copy_.default(slice_2, ones); slice_2 = ones = copy = None + slice_3 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807); slice_3 = None + slice_4 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807) + slice_5 = torch.ops.aten.slice.Tensor(slice_4, 1, 2, 9223372036854775807); slice_4 = slice_5 = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return zeros """) diff --git a/test/test_indexing.py b/test/test_indexing.py index 7a202efbe084f..e3da091798896 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -5,7 +5,10 @@ import unittest import warnings from functools import reduce +<<<<<<< HEAD from itertools import product +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import numpy as np @@ -16,14 +19,19 @@ dtypes, dtypesIfCPU, dtypesIfCUDA, +<<<<<<< HEAD dtypesIfMPS, expectedFailureMPS, instantiate_device_type_tests, onlyCPU, +======= + instantiate_device_type_tests, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) onlyCUDA, onlyNativeDeviceTypes, skipXLA, ) +<<<<<<< HEAD from torch.testing._internal.common_dtype import ( all_mps_types_and, all_types_and, @@ -33,11 +41,18 @@ from torch.testing._internal.common_utils import ( DeterministicGuard, parametrize, +======= +from torch.testing._internal.common_utils import ( + DeterministicGuard, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) run_tests, serialTest, skipIfTorchDynamo, TEST_CUDA, +<<<<<<< HEAD TEST_MPS, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TestCase, xfailIfTorchDynamo, ) @@ -152,10 +167,14 @@ def consec(size, start=1): ) lst = [list(range(i, i + 10)) for i in range(0, 100, 10)] +<<<<<<< HEAD _make_tensor = ( torch.DoubleTensor if not device.startswith("mps") else torch.FloatTensor ) tensor = _make_tensor(lst).to(device) +======= + tensor = torch.DoubleTensor(lst).to(device) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for _i in range(100): idx1_start = random.randrange(10) idx1_end = idx1_start + random.randrange(1, 10 - idx1_start + 1) @@ -171,7 +190,11 @@ def consec(size, start=1): else: lst_indexed = lst[idx1] tensor_indexed = tensor[idx1] +<<<<<<< HEAD self.assertEqual(_make_tensor(lst_indexed), tensor_indexed) +======= + self.assertEqual(torch.DoubleTensor(lst_indexed), tensor_indexed) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertRaises(ValueError, lambda: reference[1:9:0]) self.assertRaises(ValueError, lambda: reference[1:9:-1]) @@ -194,7 +217,10 @@ def delitem(): @onlyNativeDeviceTypes @dtypes(torch.half, torch.double) +<<<<<<< HEAD @dtypesIfMPS(torch.half) # TODO: add bf16 there? +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_advancedindex(self, device, dtype): # Tests for Integer Array Indexing, Part I - Purely integer array # indexing @@ -247,7 +273,11 @@ def validate_setting(x): x[ri([0, 2, 4]),], torch.tensor([5, 4, 3], dtype=dtype, device=device) ) +<<<<<<< HEAD # Only validates indexing and setting for Halfs +======= + # Only validates indexing and setting for halfs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if dtype == torch.half: reference = consec((10,)) validate_indexing(reference) @@ -924,6 +954,7 @@ def test_multiple_bool_indices(self, device): mask2 = torch.tensor([1, 1, 1], dtype=torch.bool, device=device) self.assertEqual(v[mask1, :, mask2].shape, (3, 7)) +<<<<<<< HEAD def test_multi_dimensional_bool_mask(self, device): x = torch.randn(2, 2, 3, device=device) b = ((True, False), (False, False)) @@ -984,6 +1015,8 @@ def test_multi_dimensional_bool_mask_assignment(self, device): torch.ops.aten.index_put_(v, [None, mask, None], torch.tensor(0)) self.assertEqual(v, torch.tensor([[[[0], [2]], [[3], [0]]]], device=device)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_byte_mask(self, device): v = torch.randn(5, 7, 3, device=device) mask = torch.ByteTensor([1, 0, 1, 1, 0]).to(device) @@ -1005,11 +1038,18 @@ def test_byte_mask_accumulate(self, device): self.assertEqual(y, torch.ones(size=(10, 10), device=device)) self.assertEqual(len(w), 2) +<<<<<<< HEAD # MPS: Fails locally, but passes in CI... @skipIfTorchDynamo( "This test causes SIGKILL when running with dynamo, https://github.com/pytorch/pytorch/issues/88472" ) @serialTest(TEST_CUDA or TEST_MPS) +======= + @skipIfTorchDynamo( + "This test causes SIGKILL when running with dynamo, https://github.com/pytorch/pytorch/issues/88472" + ) + @serialTest(TEST_CUDA) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_index_put_accumulate_large_tensor(self, device): # This test is for tensors with number of elements >= INT_MAX (2^31 - 1). N = (1 << 31) + 5 @@ -1208,11 +1248,18 @@ def func1(x, i, v): @onlyNativeDeviceTypes def test_index_put_accumulate_duplicate_indices(self, device): +<<<<<<< HEAD dtype = torch.float if device.startswith("mps") else torch.double for i in range(1, 512): # generate indices by random walk, this will create indices with # lots of duplicates interleaved with each other delta = torch.empty(i, dtype=dtype, device=device).uniform_(-1, 1) +======= + for i in range(1, 512): + # generate indices by random walk, this will create indices with + # lots of duplicates interleaved with each other + delta = torch.empty(i, dtype=torch.double, device=device).uniform_(-1, 1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) indices = delta.cumsum(0).long() input = torch.randn(indices.abs().max() + 1, device=device) @@ -1321,7 +1368,10 @@ def test_int_indices(self, device): torch.float8_e5m2, torch.float8_e4m3fn, ) +<<<<<<< HEAD @dtypesIfMPS(torch.float, torch.float16, torch.long, torch.bool) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_index_put_src_datatype(self, device, dtype): src = torch.ones(3, 2, 4, device=device, dtype=dtype) vals = torch.ones(3, 2, 4, device=device, dtype=dtype) @@ -1797,6 +1847,7 @@ def test_index_limits(self, device): self.assertRaises(IndexError, lambda: t[idx_min]) self.assertRaises(IndexError, lambda: t[idx_max]) +<<<<<<< HEAD @parametrize("reduce", ["prod", "amin", "amax", "mean"]) @dtypes(*all_types_and(torch.half, torch.bfloat16)) @expectedFailureMPS # Unimplemented for MPS device @@ -2125,6 +2176,8 @@ def ref_index_select(src, dim, idx): out = source.index_select(0, idx) self.assertEqual(out.item(), source.item()) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # The tests below are from NumPy test_indexing.py with some modifications to # make them compatible with PyTorch. It's licensed under the BDS license below: @@ -2396,9 +2449,13 @@ def test_truncate_leading_1s(self, device): self.assertEqual(kernel, kernel2) +<<<<<<< HEAD instantiate_device_type_tests( TestIndexing, globals(), except_for="meta", allow_mps=True ) +======= +instantiate_device_type_tests(TestIndexing, globals(), except_for="meta") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) instantiate_device_type_tests(NumpyTests, globals(), except_for="meta") if __name__ == "__main__": diff --git a/test/test_jit.py b/test/test_jit.py index c86fb111bfb85..823a98d70bf63 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -4764,7 +4764,11 @@ def fun(): self.assertIsNot(fun_compiled, fun_compiled_2) self.assertEqual(fun_compiled_2(), 7) +<<<<<<< HEAD # caching doesn't increase refcounts to function (holds weak reference) +======= + # caching doesnt increase refcounts to function (holds weak reference) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue(sys.getrefcount(fun), num_ref_counts) def test_string_ops(self): @@ -7374,7 +7378,11 @@ def func(): # tensor from empty list is type float in python and annotated type in torchscript if "annotate" in li and "dtype" not in option: continue +<<<<<<< HEAD # Skip unsigned tensor initialization for signed values on 3.10 +======= + # Skip unsigned tensor initializaton for signed values on 3.10 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if sys.version_info[:2] >= (3, 10) and "torch.uint8" in option and "-" in li: continue code = tensor_template.format(list_create=li, tensor_op=op, options=option) @@ -7990,7 +7998,11 @@ def test_varexit(cond): m += k return m +<<<<<<< HEAD # use of k tests the pathway where we have to insert uninitialized +======= + # use of k tests the pathway where we have to insert unitialized +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.checkScript(test_varexit, (3,)) self.checkScript(test_varexit, (2,)) @@ -10066,7 +10078,11 @@ def forward(self): buffer = io.BytesIO() torch.jit.save(cm, buffer) buffer.seek(0) +<<<<<<< HEAD # when tensor is loaded as constant it isn't specialized +======= + # when tensor is loaded as constant it isnt specialized +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cm_load = torch.jit.load(buffer) FileCheck().check_not("Float(1, 3)").run(cm_load.forward.graph) @@ -10300,7 +10316,11 @@ def method(self, x): def test_type_inferred_from_empty_annotation(self): """ +<<<<<<< HEAD Test that the type inferred from an empty or missing annotation is Torch.Tensor with `inferred=true` +======= + Test that the type inferred from an empty or missing annotation is Torch.Tensor wtih `inferred=true` +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ @torch.jit.script def fn(x): @@ -14835,7 +14855,11 @@ def forward(self): # testing overload declared first, then non-overload if sys.version_info < (3, 13): # test broken in 3.13 +<<<<<<< HEAD with self.assertRaisesRegex(Exception, "Overloads are not usable when a module"): +======= + with self.assertRaisesRegex(Exception, "Overloads are not useable when a module"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class W3(torch.nn.Module): @torch.jit._overload_method # noqa: F811 def forward(self, x): # noqa: F811 @@ -14888,7 +14912,11 @@ def forward(self, x): return self.hello(1), self.hello(x) if sys.version_info < (3, 13): # test broken in 3.13 +<<<<<<< HEAD with self.assertRaisesRegex(Exception, "Overloads are not usable when a module"): +======= + with self.assertRaisesRegex(Exception, "Overloads are not useable when a module"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) a = torch.jit.script(W2()) def test_narrow_copy(self): @@ -15606,7 +15634,11 @@ def forward(self): a = hasattr(self, "fee") b = hasattr(self, "foo") c = hasattr(self, "hi") +<<<<<<< HEAD d = hasattr(self, "nonexistent") +======= + d = hasattr(self, "nonexistant") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return (a, b, c, d) def foo(self): @@ -16044,7 +16076,11 @@ def f(x): # chunk returns a list in scripting and we don't unpack the list, # Thus it won't be replaced by ConstantChunk and run AD. # It's explicitly checked in test_chunk_constant_script_ad +<<<<<<< HEAD # Similarly for split, it's replaced by split_with_sizes in tracing, +======= +# Similary for split, it's replaced by split_with_sizes in tracing, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # but we don't have AD formula for aten::split(Tensor, int[], int), # an op registered in JIT so AD is not triggered in scripting. EXCLUDE_SCRIPT_AD_CHECK = { diff --git a/test/test_jit_autocast.py b/test/test_jit_autocast.py index b3cf4d9bee8f1..748020564df74 100644 --- a/test/test_jit_autocast.py +++ b/test/test_jit_autocast.py @@ -319,7 +319,11 @@ def fn(a, b, c, d): # TODO: fix and enable this test? # (we could technically fix this, but is it really worth it?) +<<<<<<< HEAD @unittest.skipIf(True, "unsupported autocast syntax") +======= + @unittest.skipIf(True, "unsuported autocast syntax") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_reused_autocast_expr(self): @torch.jit.script def fn(a, b, c, d): diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index c3e26d37da1b2..6bbb23a4c7a91 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -126,7 +126,11 @@ def setUp(self): super().setUp() self.tensorexpr_options = TensorExprTestOptions() +<<<<<<< HEAD # note: `self.dynamic_shapes` instantiated in specialization of class +======= + # note: `self.dynamic_shapes` instatiated in specialization of class +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # defined below fusion_strategy = [("DYNAMIC", 20)] if self.dynamic_shapes else [("STATIC", 20)] @@ -2939,10 +2943,14 @@ def test_unsupported(self, device, dtype, op): @slowTest @onlyCPU +<<<<<<< HEAD @ops( [op for op in op_db if get_name(op) not in known_failures], dtypes=OpDTypes.supported, ) +======= + @ops(op_db, dtypes=OpDTypes.supported) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_nnc_correctness(self, device, dtype, op): if not op.supports_tracing: self.skipTest("Requires tracing support") diff --git a/test/test_legacy_vmap.py b/test/test_legacy_vmap.py index bfd1075b25ed5..556f478265398 100644 --- a/test/test_legacy_vmap.py +++ b/test/test_legacy_vmap.py @@ -1679,7 +1679,11 @@ def get(shape): # Interesting case #2: Batch dim at end of tensor, success cases # view_as_complex requires that the dim with size 2 have stride 1 +<<<<<<< HEAD # in order for the view to function properly +======= + # in order for the view to function propertly +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) test(op, [get([B0, 2]).transpose(0, 1)], in_dims=1) test(vmap(op, in_dims=1), [get([B0, B1, 2]).movedim(1, 2)]) test(vmap(op, in_dims=2), [get([B0, 3, B1, 2]).movedim(2, 3)]) diff --git a/test/test_license.py b/test/test_license.py index 6f289a15bb4ec..4256bdb3e6a72 100644 --- a/test/test_license.py +++ b/test/test_license.py @@ -45,7 +45,11 @@ def test_distinfo_license(self): 'Found too many "torch-*dist-info" directories ' f'in "{site_packages}, expected only one' ) +<<<<<<< HEAD # setuptools renamed *dist-info/LICENSE to *dist-info/licenses/LICENSE since 77.0 +======= + # setuptools renamed *dist-info/LICENSE to *dist-info/licenses/LICENSE sicne 77.0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) license_file = os.path.join(distinfo[0], "licenses", "LICENSE") if not os.path.exists(license_file): license_file = os.path.join(distinfo[0], "LICENSE") diff --git a/test/test_linalg.py b/test/test_linalg.py index 31d4e0d1d92d5..f5b39f4403868 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -40,7 +40,11 @@ _get_torch_cuda_version, CDNA2OrLater, TEST_MULTIGPU from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel, \ _group_quantize_tensor_symmetric +<<<<<<< HEAD from torch.testing._internal.common_mkldnn import reduced_f32_on_and_off +======= +from torch.testing._internal.common_mkldnn import bf32_on_and_off +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.distributions.binomial import Binomial import torch.backends.opt_einsum as opt_einsum import operator @@ -109,6 +113,25 @@ def get_tunableop_untuned_filename(): return untuned_filename class TestLinalg(TestCase): +<<<<<<< HEAD +======= + @contextlib.contextmanager + def _hip_allow_tf32(self): + # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new + # and only for MI300+. Environment variable will be removed in the future. + import os + hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None) + os.environ["HIPBLASLT_ALLOW_TF32"] = "1" + + try: + yield + finally: + if hip_allow_tf32 is not None: + os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32 + else: + del os.environ["HIPBLASLT_ALLOW_TF32"] + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def setUp(self): super().setUp() torch.backends.cuda.matmul.allow_tf32 = False @@ -119,7 +142,11 @@ def tearDown(self): @contextlib.contextmanager def _tunableop_ctx(self): +<<<<<<< HEAD # Initialize and then tear down TunableOp +======= + # Inialize and then tear down TunableOp +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import glob import os self._set_tunableop_defaults() @@ -215,7 +242,11 @@ def _compare_untuned_tuned_entries(self, untuned_filename=None, tuned_filename=N @dtypes(torch.float, torch.cfloat) @precisionOverride({torch.float: 1e-06, torch.cfloat: 1e-06}) @tf32_on_and_off(5e-3) +<<<<<<< HEAD @reduced_f32_on_and_off(5e-3) +======= + @bf32_on_and_off(5e-3) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_inner(self, device, dtype): def check(a_sizes_, b_sizes_): for a_sizes, b_sizes in ((a_sizes_, b_sizes_), (b_sizes_, a_sizes_)): @@ -769,7 +800,11 @@ def cholesky_test_helper(n, batch_dims, upper): @skipCPUIfNoLapack @dtypes(*floating_and_complex_types()) @tf32_on_and_off(0.1 if TEST_WITH_ROCM else 0.01) +<<<<<<< HEAD @reduced_f32_on_and_off(0.01) +======= + @bf32_on_and_off(0.01) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_old_cholesky(self, device, dtype): from torch.testing._internal.common_utils import random_hermitian_pd_matrix @@ -4238,6 +4273,7 @@ def test(n=10, # how many tests to generate test(500) +<<<<<<< HEAD @dtypes(torch.float) def test_einsum_output_layout(self, device, dtype): batch, in_dim, out_dim = 2, 3, 5 @@ -4250,6 +4286,8 @@ def test_einsum_output_layout(self, device, dtype): self.assertEqual(result.stride(), expected.stride()) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_einsum_corner_cases(self, device): def check(equation, *operands, expected_output): tensors = [torch.tensor(operand, device=device, dtype=torch.float32) if not isinstance(operand, tuple) @@ -4257,7 +4295,11 @@ def check(equation, *operands, expected_output): output = torch.einsum(equation, tensors) self.assertEqual(output, torch.tensor(expected_output, dtype=torch.float32, device=device)) +<<<<<<< HEAD # Test equation variations +======= + # Test equation variantions +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) check(' ', 1, expected_output=1) check(' -> ', 1, expected_output=1) check(' , ', 2, 2, expected_output=4) @@ -4767,7 +4809,11 @@ def test_matmul_small_brute_force_tunableop(self, device, dtype): with self._tunableop_ctx(): torch.cuda.tunable.set_rotating_buffer_size(0) # Numerical check adds significant overhead, unsure if this is needed +<<<<<<< HEAD # or if there was a transient problem at the time. +======= + # or if there was a transiet problem at the time. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # if dtype is torch.half: # os.environ["PYTORCH_TUNABLEOP_NUMERICAL_CHECK"] = "1" ordinal = torch.cuda.current_device() @@ -5006,7 +5052,11 @@ def test_scaled_gemm_offline_tunableop(self, device, dtype): torch.cuda.tunable.tune_gemm_in_file(untuned_filename) new_results = len(torch.cuda.tunable.get_results()) +<<<<<<< HEAD # This stores total number of cumulative results +======= + # This stores total number of cummulative results +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) total_num_results = new_results - ref_results # Rowwise case will have an extra solution @@ -5199,7 +5249,11 @@ def test_validator_tunableop_rocm(self, device, dtype): # Validator,ROCBLAS_VERSION,X.Y,Z # Validator,HIPBLASLT_VERSION,X,Y.Z # Validator,ROCM_Version,X,Y.Z +<<<<<<< HEAD # Validator,GCN_ARCH_NAME, +======= + # Validator,GCN_ARCH_NAME, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) validator_num_lines = 5 with self._tunableop_ctx(): @@ -5216,7 +5270,11 @@ def test_validator_tunableop_rocm(self, device, dtype): # Check for rocBLAS and hipBLASLt self.assertTrue("ROCBLAS_VERSION" in validators) # format: [major].[minor].[patch].[tweak].[commit id] +<<<<<<< HEAD self.assertTrue(re.match(r'^\d+[a-z0-9.]+$', validators["ROCBLAS_VERSION"])) +======= + self.assertTrue(re.match(r'^\d+.\d+.\d+.\d+.[a-z0-9]+$', validators["ROCBLAS_VERSION"])) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue("HIPBLASLT_VERSION" in validators) self.assertTrue(re.match(r'^\d+-[a-z0-9]+$', validators["HIPBLASLT_VERSION"])) @@ -5239,7 +5297,11 @@ def test_minimum_tuning_iteration_tunableop(self, device, dtype): B = torch.randn(K, M, device=device, dtype=dtype) C = torch.matmul(A, B) +<<<<<<< HEAD # This stores total number of cumulative results +======= + # This stores total number of cummulative results +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) total_num_results = len(torch.cuda.tunable.get_results()) # There must be a new tuning result @@ -5267,7 +5329,11 @@ def test_matmul_check_entries_tunableop(self, device, dtype): B = torch.randn(K, M, device=device, dtype=dtype) C = torch.matmul(A, B) +<<<<<<< HEAD # This stores total number of cumulative results +======= + # This stores total number of cummulative results +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) total_num_results = len(torch.cuda.tunable.get_results()) # Take the difference to calculate the number of results from @@ -5300,7 +5366,11 @@ def test_disable_tuning_tunableop(self, device, dtype): B = torch.randn(K, M, device=device, dtype=dtype) C = torch.matmul(A, B) +<<<<<<< HEAD # This stores total number of cumulative results +======= + # This stores total number of cummulative results +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) total_num_results = len(torch.cuda.tunable.get_results()) # Take the difference to calculate the number of results from @@ -5323,7 +5393,11 @@ def test_disable_tuning_tunableop(self, device, dtype): # Take the difference to calculate the number of results from # this test. There should be no change in the number of results +<<<<<<< HEAD # since tuning is disable. +======= + # since tuning is disabe. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual((total_num_results - ref_num_results), 0) @onlyCUDA @@ -5332,7 +5406,11 @@ def test_dump_results_on_exit_tunableop(self, device, dtype): # Test that the TunableOp results file is created # and is NOT empty. # To test this we create a subprocess and then +<<<<<<< HEAD # execute a matmul from within the subprocess +======= + # execut a matmul from within the subprocess +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import os import multiprocessing as mp @@ -5381,7 +5459,11 @@ def test_gemm_bias_tunableop(self, device, dtype): torch.nn.functional.linear(X, matA, bias) +<<<<<<< HEAD # This stores total number of cumulative results +======= + # This stores total number of cummulative results +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) total_num_results = len(torch.cuda.tunable.get_results()) # There must be a new tuning result @@ -5435,7 +5517,11 @@ def test_gemm_bias_offline_tunableop(self, device, dtype): torch.cuda.tunable.tune_gemm_in_file(untuned_filename) new_results = len(torch.cuda.tunable.get_results()) +<<<<<<< HEAD # This stores total number of cumulative results +======= + # This stores total number of cummulative results +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) total_num_results = new_results - ref_results # There must be a new tuning results @@ -5511,7 +5597,11 @@ def test_scaled_gemm_tunableop(self, device, dtype): scaleB = torch.ones((1, matB.shape[1]), device=device) torch._scaled_mm(matA, matB, scale_a=scaleA, scale_b=scaleB, out_dtype=torch.bfloat16) +<<<<<<< HEAD # This stores total number of cumulative results +======= + # This stores total number of cummulative results +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) total_num_results = len(torch.cuda.tunable.get_results()) # Rowwise case will have an extra solution @@ -5526,8 +5616,18 @@ def test_scaled_gemm_tunableop(self, device, dtype): @runOnRocmArch(MI300_ARCH) @dtypes(torch.float) def test_tf32_tunableop(self, device, dtype): +<<<<<<< HEAD try: with self._tunableop_ctx(): +======= + # Test TunableOp with TF32. Supported by hipblasLT on MI300+. + # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new + # and only for MI300+. Eventually this flag will go away. + tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext + + try: + with self._tunableop_ctx(), tf32_ctx(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.backends.cuda.matmul.allow_tf32 = True torch.cuda.tunable.set_rotating_buffer_size(0) @@ -5590,8 +5690,18 @@ def test_tf32_offline_tunableop(self, device, dtype): # This test is the offline version of test_tf32_tunableop import os +<<<<<<< HEAD try: with self._tunableop_ctx(): +======= + # Test TunableOp with TF32. Supported by hipblasLT on MI300+. + # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new + # and only for MI300+. Eventually this flag will go away. + tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext + + try: + with self._tunableop_ctx(), tf32_ctx(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.backends.cuda.matmul.allow_tf32 = True ordinal = torch.cuda.current_device() torch.cuda.tunable.set_rotating_buffer_size(0) @@ -5625,7 +5735,11 @@ def test_tf32_offline_tunableop(self, device, dtype): torch.cuda.tunable.tune_gemm_in_file(untuned_filename) new_results = len(torch.cuda.tunable.get_results()) +<<<<<<< HEAD # This stores total number of cumulative results +======= + # This stores total number of cummulative results +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) total_num_results = new_results - ref_results # There must be a new tuning results @@ -5866,7 +5980,11 @@ def test_mm_submatrix_offline_tunableop(self, device, dtype): torch.cuda.tunable.tune_gemm_in_file(untuned_filename) new_results = len(torch.cuda.tunable.get_results()) +<<<<<<< HEAD # This stores total number of cumulative results +======= + # This stores total number of cummulative results +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) total_num_results = new_results - ref_results # There must be a new tuning results @@ -5989,7 +6107,10 @@ def test_tensordot_out_kernel_errors_with_autograd(self, device, dtype): self.assertEqual(len(w), 1) # 4GB should do, but we run tests in parallel in CI, so let's be generous +<<<<<<< HEAD @onlyCUDA +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @largeTensorTest('16GB', device='cuda') def test_large_bmm_mm_backward(self, device): A = torch.randn([1024, 2, 1024], device="cuda").mT.contiguous().mT @@ -6000,7 +6121,10 @@ def test_large_bmm_mm_backward(self, device): (A @ B).backward(G) # 4GB should do, but we run tests in parallel in CI, so let's be generous +<<<<<<< HEAD @onlyCUDA +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @largeTensorTest('16GB', device='cuda') def test_large_bmm_backward(self, device): A = torch.randn([1024, 2, 1024], device="cuda").mT.contiguous().mT @@ -6689,7 +6813,11 @@ def test_lu_unpack_check_input(self, device, dtype): with self.assertRaisesRegex(RuntimeError, "torch.int32 dtype"): torch.lu_unpack(lu_data, lu_pivots.long()) +<<<<<<< HEAD # check that once flags are unset, Nones are returned +======= + # check that onces flags are unset, Nones are returned +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) p, l, u = torch.lu_unpack(lu_data, lu_pivots, unpack_data=False) self.assertTrue(l.numel() == 0 and u.numel() == 0) p, l, u = torch.lu_unpack(lu_data, lu_pivots, unpack_pivots=False) @@ -6908,7 +7036,11 @@ def tracker(worker): lambdas1.append(worker.E[:]) tol = 1e-8 +<<<<<<< HEAD # tol for scipy lobpcg will be chosen so that the number of +======= + # tol for scipy lobpcg will be choosed so that the number of +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # iterations will be equal or very close to pytorch lobpcg # (that is around 170-180) @@ -6988,7 +7120,11 @@ def tracker(worker): -(input size: {m:4}, eigenpairs:{k:2}, units: ms per call)- ''') +<<<<<<< HEAD # Handling of very small tolerance +======= + # Handling of very small tolerence +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tol = 1e-100 lambdas1 = [] @@ -7188,7 +7324,11 @@ def maybe_transpose(cond, m): *[torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else [])) @dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half)) @tf32_on_and_off(0.05) +<<<<<<< HEAD @reduced_f32_on_and_off(0.05) +======= + @bf32_on_and_off(0.05) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_addmm(self, device, dtype): self._test_addmm_impl(torch.addmm, None, device, dtype) @@ -7198,7 +7338,11 @@ def test_addmm(self, device, dtype): *[torch.bfloat16, torch.half] if TEST_WITH_ROCM or SM53OrLater else [])) @dtypes(*floating_types_and(torch.bfloat16)) @tf32_on_and_off(0.05) +<<<<<<< HEAD @reduced_f32_on_and_off(0.05) +======= + @bf32_on_and_off(0.05) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_addmm_relu(self, device, dtype): self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype) @@ -7210,7 +7354,11 @@ def test_addmm_relu(self, device, dtype): *[torch.bfloat16, torch.half] if TEST_WITH_ROCM or SM53OrLater else [])) @dtypes(*floating_types_and(torch.bfloat16)) @tf32_on_and_off(0.05) +<<<<<<< HEAD @reduced_f32_on_and_off(0.05) +======= + @bf32_on_and_off(0.05) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_addmm_relu_tunableop_rocm(self, device, dtype): with self._tunableop_ctx(): torch.cuda.tunable.set_rotating_buffer_size(0) @@ -7224,14 +7372,22 @@ def test_addmm_relu_tunableop_rocm(self, device, dtype): *[torch.bfloat16, torch.half] if TEST_WITH_ROCM or SM53OrLater else [])) @dtypes(*floating_types_and(torch.bfloat16)) @tf32_on_and_off(0.05) +<<<<<<< HEAD @reduced_f32_on_and_off(0.05) +======= + @bf32_on_and_off(0.05) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_addmm_gelu(self, device, dtype): self._test_addmm_impl(torch._addmm_activation, "gelu", device, dtype) @dtypes(torch.float, torch.double) @dtypesIfCUDA(*floating_and_complex_types()) @tf32_on_and_off(0.05 if TEST_WITH_ROCM else 0.005) +<<<<<<< HEAD @reduced_f32_on_and_off(0.005) +======= + @bf32_on_and_off(0.005) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_addmm_sizes(self, device, dtype): for m in [0, 1, 25]: for n in [0, 1, 10]: @@ -7753,7 +7909,11 @@ def dyn_quant_matmul_4bit( all_elements_within_threshold, "Some elements have error >= 0.06" ) +<<<<<<< HEAD @onlyNativeDeviceTypes +======= + @onlyCPU +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize("m", [32, 64]) @parametrize("k", [32, 64]) @parametrize("n", [48, 64]) @@ -7799,6 +7959,7 @@ def weight_int8pack_mm(a, b_int8pack, b_scales): mean_err = ((res - ref).abs() / ref).mean() self.assertTrue(mean_err < 0.05) +<<<<<<< HEAD @slowTest @onlyCPU @largeTensorTest('12GB', device='cpu') @@ -7825,6 +7986,8 @@ def weight_int8pack_mm(a, b_int8pack, b_scales): # should pass without segfault weight_int8pack_mm(a, b_int8pack, b_scales) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyCPU @parametrize("m", [32, 35, 36, 40, 64]) @parametrize("k", [32, 35, 36, 40, 64]) @@ -7855,7 +8018,11 @@ def test_fp16_mv_transposed_first_argument_arm_cpu(self, device, m, k): @dtypes(torch.half, torch.float32, torch.float64, torch.int32, torch.int64, torch.cfloat, torch.cdouble) @dtypesIfCUDA(torch.float32, torch.float64, torch.cfloat, torch.cdouble) @tf32_on_and_off(0.01) +<<<<<<< HEAD @reduced_f32_on_and_off(0.01) +======= + @bf32_on_and_off(0.01) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_mm(self, device, dtype): def _test_mm(n, m, p, dtype, genf): # helper function @@ -8035,12 +8202,20 @@ def test_strided_mm_bmm(self, device, dtype): @onlyNativeDeviceTypes @dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half)) @tf32_on_and_off(0.05) +<<<<<<< HEAD @reduced_f32_on_and_off(0.05) +======= + @bf32_on_and_off(0.05) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_bmm(self, device, dtype): if self.device_type == 'cuda' and dtype is torch.bfloat16 and not SM53OrLater: # cuBLAS does not guarantee BFloat16 support on SM < 53. # So on PyTorch, we consider BFloat16 support on SM < 53 as +<<<<<<< HEAD # undefined behavior +======= + # undefined bahavior +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return batch_sizes = [1, 10] @@ -8111,7 +8286,11 @@ def _test_addbmm_baddbmm(self, func, b1, b2, ref, out_tensor): with self.assertWarnsOnceRegex( UserWarning, f"This overload of {func}_ is deprecated"): getattr(out_tensor, func + "_")(1, b1, b2) +<<<<<<< HEAD self.assertEqual(out_tensor, ref * 2) +======= + self.assertEqual(out_tensor, ref * 2), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) getattr(res3, func + "_")(b1, b2, beta=1) self.assertEqual(out_tensor, res3) @@ -8127,7 +8306,11 @@ def _test_addbmm_baddbmm(self, func, b1, b2, ref, out_tensor): self.assertEqual(out_tensor, getattr(torch, func)(1, out_tensor, 0, b1, b2)) res4 = getattr(torch, func)(out_tensor, b1, b2, beta=1, alpha=.5) +<<<<<<< HEAD self.assertEqual(res4, ref * 3) +======= + self.assertEqual(res4, ref * 3), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) nan = torch.full_like(out_tensor, math.nan) res5 = getattr(torch, func)(nan, b1, b2, beta=0, alpha=1) @@ -8148,12 +8331,20 @@ def _test_addbmm_baddbmm(self, func, b1, b2, ref, out_tensor): @onlyNativeDeviceTypes @dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half)) @tf32_on_and_off(0.05) +<<<<<<< HEAD @reduced_f32_on_and_off(0.05) +======= + @bf32_on_and_off(0.05) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_addbmm(self, device, dtype): if self.device_type == 'cuda' and dtype is torch.bfloat16 and not SM53OrLater: # cuBLAS does not guarantee BFloat16 support on SM < 53. # So on PyTorch, we consider BFloat16 support on SM < 53 as +<<<<<<< HEAD # undefined behavior +======= + # undefined bahavior +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return num_batches = 2 @@ -8222,12 +8413,20 @@ def generate_tensor(): @onlyNativeDeviceTypes @dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half)) @tf32_on_and_off(0.05) +<<<<<<< HEAD @reduced_f32_on_and_off(0.05) +======= + @bf32_on_and_off(0.05) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_baddbmm(self, device, dtype): if self.device_type == 'cuda' and dtype is torch.bfloat16 and not SM53OrLater: # cuBLAS does not guarantee BFloat16 support on SM < 53. # So on PyTorch, we consider BFloat16 support on SM < 53 as +<<<<<<< HEAD # undefined behavior +======= + # undefined bahavior +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return num_batches = 10 @@ -9182,7 +9381,11 @@ def dims_full_for_fn(): # ROCm 6.4 passes with tf32=on, but 6.4.1 needed tolerance reduced slightly @tf32_on_and_off(0.002 if torch.version.hip else 0.001) +<<<<<<< HEAD @reduced_f32_on_and_off(0.001) +======= + @bf32_on_and_off(0.001) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_broadcast_batched_matmul(self, device): n_dim = random.randint(1, 8) m_dim = random.randint(1, 8) @@ -9519,7 +9722,11 @@ def fn(torchfn, *args): fn(torch.slogdet, (0, 0))) @tf32_on_and_off(0.05 if TEST_WITH_ROCM else 0.005) +<<<<<<< HEAD @reduced_f32_on_and_off(0.07, 0.005) +======= + @bf32_on_and_off(0.07) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_tensordot(self, device): a = torch.arange(60., device=device).reshape(3, 4, 5) b = torch.arange(24., device=device).reshape(4, 3, 2) diff --git a/test/test_masked.py b/test/test_masked.py index 1c6bd42ab763f..89ed038792f98 100644 --- a/test/test_masked.py +++ b/test/test_masked.py @@ -57,7 +57,11 @@ def apply_masked_reduction_along_dim(op, input, *args, **kwargs): [[op([1, 2], *args0, **kwargs, dim=None, keepdim=False)] [op([3, 4, 5], *args0, **kwargs, dim=None, keepdim=False)]] +<<<<<<< HEAD where args0 is args where dim value is replaced with None if +======= + where args0 is args where dim value is replased with None if +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) present. Using the same example data, if the op is called with dim=(0, 1) diff --git a/test/test_maskedtensor.py b/test/test_maskedtensor.py index 03c05c7ea6da4..df5fc5d13a434 100644 --- a/test/test_maskedtensor.py +++ b/test/test_maskedtensor.py @@ -236,6 +236,7 @@ def test_to_sparse(self, device): _compare_mt_t(sparse_mt, data) _compare_mt_t(mt.grad, data.grad) +<<<<<<< HEAD def test_to_device(self, device): for sample in _generate_sample_data(device=device): data = sample.input @@ -262,6 +263,8 @@ def test_to_dtype(self, device): self.assertEqual(mt_dtype.get_mask().dtype, torch.bool) self.assertEqual(mt_dtype.get_data().dtype, new_dtype) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_to_dense(self, device): samples = _generate_sample_data( device=device, diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 175e6a9649cd2..1266b38427b7f 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -21,6 +21,7 @@ from torch.testing._internal.common_cuda import ( PLATFORM_SUPPORTS_BF16, SM53OrLater, +<<<<<<< HEAD SM80OrLater, SM89OrLater, SM90OrLater, @@ -32,6 +33,14 @@ PLATFORM_SUPPORTS_MX_GEMM, PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM, IS_SM90, +======= + SM89OrLater, + SM90OrLater, + xfailIfSM100OrLater, + _get_torch_cuda_version, + PLATFORM_SUPPORTS_FP8, + PLATFORM_SUPPORTS_MX_GEMM, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) from torch.testing._internal.common_device_type import ( dtypes, @@ -51,11 +60,16 @@ parametrize, run_tests, skipIfRocm, +<<<<<<< HEAD +======= + skipIfRocmVersionAndArch, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) skipIfRocmVersionLessThan, TEST_CUDA, TEST_WITH_ROCM, TestCase, ) +<<<<<<< HEAD from torch.testing._internal.common_quantized import ( _f32_to_floatx_unpacked, _floatx_unpacked_to_f32, @@ -63,6 +77,9 @@ to_mxfp8, generate_jagged_offs, ) +======= +from torch.testing._internal.common_quantized import _f32_to_floatx_unpacked, _floatx_unpacked_to_f32, ceil_div, to_blocked +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _IS_SM8X = False if TEST_CUDA: @@ -317,6 +334,7 @@ def grouped_mm_helper(self, alist, blist, gOlist, agradlist, bgradlist, outlist) self.assertEqual(bgrad, b.grad) @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") +<<<<<<< HEAD @xfailIfSM120OrLater @unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater") @parametrize("strided", [False, True]) @@ -325,6 +343,16 @@ def grouped_mm_helper(self, alist, blist, gOlist, agradlist, bgradlist, outlist) @dtypes(torch.bfloat16, torch.float32, torch.float16) def test_grouped_gemm_2d_2d(self, strided, a_row_major, b_row_major, dtype): device = "cuda" +======= + @xfailIfSM100OrLater + @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") + @parametrize("strided", [False, True]) + @parametrize("a_row_major", [False, True]) + @parametrize("b_row_major", [False, True]) + def test_grouped_gemm_2d_2d(self, strided, a_row_major, b_row_major): + device = "cuda" + dtype = torch.bfloat16 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) m, n, k, n_groups = 16, 32, 64, 4 if a_row_major: a = torch.randn(m, k * n_groups + k * int(strided), device=device, dtype=dtype)[:, :k * n_groups] @@ -341,7 +369,11 @@ def test_grouped_gemm_2d_2d(self, strided, a_row_major, b_row_major, dtype): offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32) f = torch._grouped_mm +<<<<<<< HEAD out = f(a, b.t(), offs=offs, out_dtype=dtype) +======= + out = f(a, b.t(), offs=offs, out_dtype=torch.bfloat16) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) gO = torch.rand_like(out) out.backward(gO) offs_cpu = offs.cpu() @@ -356,6 +388,7 @@ def test_grouped_gemm_2d_2d(self, strided, a_row_major, b_row_major, dtype): self.grouped_mm_helper(alist, blist, gO, agradlist, bgradlist, out) @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") +<<<<<<< HEAD @xfailIfSM120OrLater @unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater") @parametrize("strided", [False, True]) @@ -364,6 +397,16 @@ def test_grouped_gemm_2d_2d(self, strided, a_row_major, b_row_major, dtype): @dtypes(torch.bfloat16, torch.float32, torch.float16) def test_grouped_gemm_2d_3d(self, strided, a_row_major, b_row_major, dtype): device = "cuda" +======= + @xfailIfSM100OrLater + @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") + @parametrize("strided", [False, True]) + @parametrize("a_row_major", [False, True]) + @parametrize("b_row_major", [False, True]) + def test_grouped_gemm_2d_3d(self, strided, a_row_major, b_row_major): + device = "cuda" + dtype = torch.bfloat16 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) s_int = int(strided) m, n, k, n_groups = 16, 32, 64, 4 if a_row_major: @@ -390,12 +433,20 @@ def test_grouped_gemm_2d_3d(self, strided, a_row_major, b_row_major, dtype): a.grad = None b.grad = None +<<<<<<< HEAD offs = torch.arange(m, n_groups * m + 1, m, device=device, dtype=torch.int32) +======= + offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if check_zero_size: offs[0] = offs[1] f = torch._grouped_mm +<<<<<<< HEAD out = f(a, b.transpose(-2, -1), offs=offs, out_dtype=dtype) +======= + out = f(a, b.transpose(-2, -1), offs=offs, out_dtype=torch.bfloat16) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) gO = torch.rand_like(out) if not check_zero_size: out.backward(gO) @@ -413,6 +464,7 @@ def test_grouped_gemm_2d_3d(self, strided, a_row_major, b_row_major, dtype): @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") +<<<<<<< HEAD @xfailIfSM120OrLater @unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater") @parametrize("strided", [False, True]) @@ -421,6 +473,16 @@ def test_grouped_gemm_2d_3d(self, strided, a_row_major, b_row_major, dtype): @dtypes(torch.bfloat16, torch.float32, torch.float16) def test_grouped_gemm_3d_3d(self, strided, a_row_major, b_row_major, dtype): device = "cuda" +======= + @xfailIfSM100OrLater + @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") + @parametrize("strided", [False, True]) + @parametrize("a_row_major", [False, True]) + @parametrize("b_row_major", [False, True]) + def test_grouped_gemm_3d_3d(self, strided, a_row_major, b_row_major): + device = "cuda" + dtype = torch.bfloat16 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) s_int = int(strided) m, n, k, n_groups = 16, 32, 64, 4 if a_row_major: @@ -442,12 +504,17 @@ def test_grouped_gemm_3d_3d(self, strided, a_row_major, b_row_major, dtype): self.assertTrue(b_contig.is_contiguous() is not strided) f = torch._grouped_mm +<<<<<<< HEAD out = f(a, b.transpose(-2, -1), out_dtype=dtype) +======= + out = f(a, b.transpose(-2, -1), out_dtype=torch.bfloat16) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) gO = torch.rand_like(out) out.backward(gO) self.grouped_mm_helper(a, b, gO, a.grad, b.grad, out) @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") +<<<<<<< HEAD @xfailIfSM120OrLater @unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater") @parametrize("strided", [False, True]) @@ -456,6 +523,16 @@ def test_grouped_gemm_3d_3d(self, strided, a_row_major, b_row_major, dtype): @dtypes(torch.bfloat16, torch.float32, torch.float16) def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major, dtype): device = "cuda" +======= + @xfailIfSM100OrLater + @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") + @parametrize("strided", [False, True]) + @parametrize("a_row_major", [False, True]) + @parametrize("b_row_major", [False, True]) + def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major): + device = "cuda" + dtype = torch.bfloat16 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) s_int = int(strided) m, n, k, n_groups = 16, 32, 64, 4 if a_row_major: @@ -479,12 +556,20 @@ def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major, dtype): if check_zero_size and n_groups <= 1: continue +<<<<<<< HEAD offs = torch.arange(n, n_groups * n + 1, n, device=device, dtype=torch.int32) +======= + offs = torch.arange(n, n_groups * n + 1, n, device="cuda", dtype=torch.int32) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if check_zero_size: offs[0] = offs[1] f = torch._grouped_mm +<<<<<<< HEAD out = f(a, b.transpose(-2, -1), offs=offs, out_dtype=dtype) +======= + out = f(a, b.transpose(-2, -1), offs=offs, out_dtype=torch.bfloat16) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) gO = torch.rand_like(out) if not check_zero_size: out.backward(gO) @@ -502,6 +587,7 @@ def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major, dtype): @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") @xfailIfSM100OrLater +<<<<<<< HEAD # TODO(future PR): enable compile for torch._grouped_mm fallback path @unittest.skipIf(not SM90OrLater, "Grouped gemm with compile supported on SM90") @parametrize("op", ["2d/2d", "2d/3d", "3d/2d", "3d/3d"]) @@ -509,6 +595,13 @@ def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major, dtype): @parametrize("b_row_major", [False, True]) @parametrize("max_autotune", [False, True]) def test_grouped_gemm_compiled(self, op, a_row_major, b_row_major, max_autotune): +======= + @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") + @parametrize("op", ["2d/2d", "2d/3d", "3d/2d", "3d/3d"]) + @parametrize("a_row_major", [False, True]) + @parametrize("b_row_major", [False, True]) + def test_grouped_gemm_compiled(self, op, a_row_major, b_row_major): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch._dynamo.reset() device = "cuda" @@ -518,6 +611,7 @@ def test_grouped_gemm_compiled(self, op, a_row_major, b_row_major, max_autotune) align = 16 // dtype_AB.itemsize f_ref = torch._grouped_mm +<<<<<<< HEAD options = {} if max_autotune: @@ -530,6 +624,14 @@ def test_grouped_gemm_compiled(self, op, a_row_major, b_row_major, max_autotune) f = torch.compile( f_ref, options=options, +======= + f = torch.compile( + f_ref, + options={ + "max_autotune": True, + "max_autotune_gemm_backends": "TRITON", + }, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if op == "2d/2d": @@ -537,9 +639,15 @@ def test_grouped_gemm_compiled(self, op, a_row_major, b_row_major, max_autotune) m_align = (m + align - 1) // align * align n_align = (n + align - 1) // align * align if not a_row_major and not b_row_major: +<<<<<<< HEAD offs = torch.tensor([0, 1, 6, 6, 7], device=device, dtype=dtype_offset) else: offs = torch.tensor([0, 8, 16, 16, 27], device=device, dtype=dtype_offset) +======= + offs = torch.tensor([1, 3, 4, 6, 7], device=device, dtype=dtype_offset) + else: + offs = torch.tensor([8, 16, 32, 37], device=device, dtype=dtype_offset) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ngroups = offs.shape[0] k = offs[-1] k_align = (k + align - 1) // align * align @@ -553,7 +661,11 @@ def test_grouped_gemm_compiled(self, op, a_row_major, b_row_major, max_autotune) else: B = torch.randn(k, n_align, device=device, dtype=dtype_AB).t()[:n, :] elif op == "2d/3d": +<<<<<<< HEAD n, k = 7, 259 # k is larger here, to validate iterating over k tiles on an op +======= + n, k = 7, 13 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) n_align = (n + align - 1) // align * align k_align = (k + align - 1) // align * align if a_row_major: @@ -613,7 +725,11 @@ def test_grouped_gemm_compiled(self, op, a_row_major, b_row_major, max_autotune) -2, -1 )[:, :n, :] else: +<<<<<<< HEAD raise AssertionError(f"Invalid op: {op}") +======= + raise AssertionError(f"Invaild op: {op}") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) C_ref = f_ref(A, B.transpose(-2, -1), offs=offs) C = f(A, B.transpose(-2, -1), offs=offs) @@ -621,13 +737,21 @@ def test_grouped_gemm_compiled(self, op, a_row_major, b_row_major, max_autotune) @onlyCUDA +<<<<<<< HEAD +======= + @skipIfRocm +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16]) @parametrize("M", [1, 32, 64]) @parametrize("N", [1, 32, 64]) @parametrize("K", [1, 32, 64]) @parametrize("batch_size", [None, 1, 16]) +<<<<<<< HEAD # TODO: enable rocblas path on ROCm @parametrize("backend", ["cublaslt"] if torch.version.hip else ["cublas", "cublaslt"]) +======= + @parametrize("backend", ["cublas", "cublaslt"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_mm_bmm_dtype_overload(self, input_dtype, M, N, K, batch_size, backend): device = "cuda" dtype = input_dtype @@ -676,13 +800,21 @@ def create_inputs(B=None): @onlyCUDA +<<<<<<< HEAD +======= + @skipIfRocm +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16]) @parametrize("M", [1, 32, 64]) @parametrize("N", [1, 32, 64]) @parametrize("K", [1, 32, 64]) @parametrize("batch_size", [None, 1, 32]) +<<<<<<< HEAD # TODO: enable rocblas path on ROCm @parametrize("backend", ["cublaslt"] if torch.version.hip else ["cublas", "cublaslt"]) +======= + @parametrize("backend", ["cublas", "cublaslt"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_addmm_baddmm_dtype_overload(self, input_dtype, M, N, K, batch_size, backend): device = "cuda" dtype = input_dtype @@ -778,9 +910,13 @@ def expand(tensor): torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accum f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices" +<<<<<<< HEAD f8_grouped_msg = "FP8 grouped is only supported on SM90 and MI300+ devices" mx_skip_msg = "MX gemm is only supported on CUDA capability 10.0+" mxfp8_grouped_mm_skip_msg = "MXFP8 grouped GEMM is only supported when PyTorch is built with USE_FBGEMM_GENAI=1 on SM100+" +======= +mx_skip_msg = "MX gemm is only supported on CUDA capability 10.0+" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # avoid division by zero when calculating scale EPS = 1e-12 @@ -798,7 +934,11 @@ def amax_to_scale( if float8_dtype == e4m3_type: res = E4M3_MAX_POS / torch.clamp(amax, min=EPS) elif float8_dtype == e5m2_type: +<<<<<<< HEAD res = E5M2_MAX_POS / torch.clamp(amax, min=EPS) +======= + res = E4M3_MAX_POS / torch.clamp(amax, min=EPS) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: raise ValueError(f"Unsupported float8_dtype: {float8_dtype}") @@ -819,6 +959,7 @@ def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype, dim=None): return amax_to_scale(amax, float8_dtype, x.dtype) +<<<<<<< HEAD def tensor_to_scale_block( x: torch.Tensor, float8_dtype: torch.dtype, @@ -833,6 +974,8 @@ def tensor_to_scale_block( scale = scale.flatten(2, 3).flatten(0, 1) return x, scale +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor: # naive implementation: dq -> op -> q x_fp32 = x.to(torch.float) / x_scale @@ -841,6 +984,7 @@ def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor: return out_fp32.to(out_dtype) +<<<<<<< HEAD def mm_float8_emulated_block(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor: x = x.unflatten(1, (x_scale.shape[1], -1)).unflatten(0, (x_scale.shape[0], -1)) y = y.unflatten(1, (y_scale.shape[1], -1)).unflatten(0, (y_scale.shape[0], -1)) @@ -852,6 +996,8 @@ def mm_float8_emulated_block(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor: return out_fp32.to(out_dtype) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def addmm_float8_unwrapped( a_data: torch.Tensor, a_scale: torch.Tensor, @@ -911,8 +1057,11 @@ def to_fp8_saturated( return x.to(fp8_dtype) +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """Computes the error between two tensors in dB. @@ -1057,6 +1206,7 @@ def test_float8_scale(self, device) -> None: out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b) self.assertEqual(out_fp8, out_fp8_s) +<<<<<<< HEAD @unittest.skipIf(not PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM, mxfp8_grouped_mm_skip_msg) @parametrize("G", [1, 4, 16]) @parametrize("M", [2048, 2049]) @@ -1218,6 +1368,8 @@ def test_mxfp8_scaled_grouped_mm_2d_3d(self, G, M, N, K): torch.testing.assert_close(y_mxfp8, y_bf16, atol=8.0e-2, rtol=8.0e-2) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32]) def test_scaled_mm_vs_emulated(self, base_dtype): @@ -1406,6 +1558,10 @@ def test_float8_scale_fast_accum(self, device) -> None: out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True) self.assertEqual(out_fp8, out_fp8_s) +<<<<<<< HEAD +======= + @skipIfRocmVersionAndArch((7, 1), "gfx950") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyCUDA @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg) @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89-sm100 specific") @@ -1446,7 +1602,15 @@ def test_float8_error_messages(self, device) -> None: y_fp8 = y.to(e4m3_type).t() with self.assertRaisesRegex( +<<<<<<< HEAD RuntimeError, re.escape("Invalid scaling configuration") +======= + RuntimeError, + re.escape( + "For RowWise scaling, scale_a should be (1024, 1) and scale_b " + "should be (1, 2048). Got scale_a.size()=(1, 1) and scale_b.size()=(1, 2)" + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): torch._scaled_mm( x_fp8, @@ -1457,7 +1621,15 @@ def test_float8_error_messages(self, device) -> None: ) with self.assertRaisesRegex( +<<<<<<< HEAD RuntimeError, re.escape("Invalid scaling configuration") +======= + RuntimeError, + re.escape( + " For RowWise scaling, scale_a should be (1024, 1) and scale_b " + "should be (1, 2048). Got scale_a.size()=(1024, 1) and scale_b.size()=(1, 2049)" + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): torch._scaled_mm( x_fp8, @@ -1467,18 +1639,34 @@ def test_float8_error_messages(self, device) -> None: out_dtype=torch.bfloat16, ) with self.assertRaisesRegex( +<<<<<<< HEAD RuntimeError, re.escape("Invalid scaling configuration") +======= + RuntimeError, + re.escape("For non-TensorWise scaling, scale tensors must be 2-dimensional"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): torch._scaled_mm( x_fp8, y_fp8, scale_a=torch.ones((M), device="cuda"), +<<<<<<< HEAD scale_b=torch.ones((N, 1), device="cuda"), +======= + scale_b=torch.ones((N, N), device="cuda"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out_dtype=torch.bfloat16, ) with self.assertRaisesRegex( +<<<<<<< HEAD RuntimeError, re.escape("Invalid scaling configuration") +======= + RuntimeError, + re.escape( + "Both scale_a and scale_b must be contiguous for RowWise scaling." + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): torch._scaled_mm( x_fp8, @@ -1488,14 +1676,24 @@ def test_float8_error_messages(self, device) -> None: out_dtype=torch.bfloat16, ) +<<<<<<< HEAD def e5m2(): out = torch._scaled_mm( +======= + # Note re.compile is used, not re.escape. This is to accomodate fn vs fnuz type message. + with self.assertRaisesRegex( + RuntimeError, + r"Expected b\.dtype\(\) == at::kFloat8_e4m3fnu?z? to be true, but got false\.", + ): + torch._scaled_mm( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x_fp8, y_fp8.to(e5m2_type), scale_a=torch.ones((M, 1), device="cuda"), scale_b=torch.ones((1, N), device="cuda"), out_dtype=torch.bfloat16, ) +<<<<<<< HEAD return out if torch.cuda.get_device_capability() == (9, 0) and torch.version.cuda and torch.version.cuda >= "12.9": @@ -1521,6 +1719,14 @@ def test_scaled_mm_vs_emulated_row_wise(self, base_dtype): if torch.cuda.get_device_capability() < (9, 0): raise unittest.SkipTest("Need sm90+ for row-wise fp8 w/ cuBLAS") +======= + + @skipIfRocmVersionAndArch((7, 1), "gfx950") + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg) + @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89-sm100 specific") + @parametrize("base_dtype", [torch.bfloat16]) + def test_scaled_mm_vs_emulated_row_wise(self, base_dtype): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.manual_seed(42) input_dtype = e4m3_type output_dtype = base_dtype @@ -1551,6 +1757,7 @@ def test_scaled_mm_vs_emulated_row_wise(self, base_dtype): torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) +<<<<<<< HEAD @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg) @unittest.skipIf(not IS_SM90, "cuBLAS blockwise scaling requires sm90+") @unittest.skipIf( @@ -1605,6 +1812,9 @@ def test_scaled_mm_vs_emulated_block_wise(self, output_dtype, lhs_block, rhs_blo @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @unittest.skipIf(torch.version.hip is not None, "Float8_e4m3fn not supported on current ROCm CI setup (MI325X)") +======= + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize("which_dim_zero", [0, 1, 2]) @parametrize("use_torch_compile", [False, True]) def test_zero_dim_tensorwise(self, which_dim_zero, use_torch_compile) -> None: @@ -1631,6 +1841,10 @@ def test_zero_dim_tensorwise(self, which_dim_zero, use_torch_compile) -> None: self.assertEqual(out_dtype, out_fp8.dtype) self.assertEqual(out_fp32, out_fp8.to(torch.float)) +<<<<<<< HEAD +======= + @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support sm carveout") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(IS_WINDOWS, "Windows doesn't support row-wise scaling") @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @unittest.skipIf(not SM90OrLater, "sm89 kernel isn't opted into carveout yet") @@ -1659,6 +1873,7 @@ def test_honor_sm_carveout(self) -> None: torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16) prof.export_chrome_trace(f.name) +<<<<<<< HEAD if torch.version.hip: events = [evt for evt in json.load(open(f.name))["traceEvents"] if evt.get("cat", "") == "kernel"] # events were returned out of order; need to be sorted on "ts" timestamp @@ -1692,6 +1907,17 @@ def test_honor_sm_carveout(self) -> None: # correct behavior self.assertNotEqual(no_carveout, carveout_66) self.assertNotEqual(carveout_66, carveout_0) +======= + no_carveout, carveout_0, carveout_66, no_carveout_again = [ + math.prod(evt.get("args", {}).get("grid", [])) + for evt in json.load(open(f.name))["traceEvents"] + if evt.get("cat", "") == "kernel" + ] + + self.assertEqual(no_carveout, no_carveout_again) + self.assertNotEqual(no_carveout, carveout_66) + self.assertNotEqual(carveout_66, carveout_0) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_pack_uint4(self): """ @@ -1750,11 +1976,22 @@ def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, device = "cuda" M, K, N = mkn +<<<<<<< HEAD +======= + if torch.version.hip: + if not (M % 32 == 0 and K % 32 == 0 and N % 32 == 0): + raise unittest.SkipTest("Matrix dimensions must be multiples of 32 on ROCm, skipping") + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (recipe == "nvfp4" or recipe == "mxfp4") and K % 32 != 0: raise unittest.SkipTest("K must be divisible by 32 for nvfp4/mxfp4 cublas gemm, skipping") fp4_scaling_dtype = torch.float8_e8m0fnu if torch.version.hip else torch.float8_e4m3fn +<<<<<<< HEAD BLOCK_SIZE = 32 if torch.version.hip else (16 if recipe == "nvfp4" else 32) +======= + BLOCK_SIZE = 16 if recipe == "nvfp4" else 32 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) require_exact_match = True approx_match_sqnr_target = 22.0 @@ -1917,8 +2154,13 @@ def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, B = B.clamp(min=min_val, max=max_val).to(torch.float8_e4m3fn) else: # nvfp4 # mxfp4 scale_func = data_to_mx_scale if recipe == "mxfp4" else data_to_nvfp4_scale +<<<<<<< HEAD A_scale = scale_func(*([A_ref, BLOCK_SIZE] + recipe if recipe == "mxfp4" else [A_ref, BLOCK_SIZE])) B_scale = scale_func(*([B_ref, BLOCK_SIZE] + recipe if recipe == "mxfp4" else [B_ref, BLOCK_SIZE])) +======= + A_scale = scale_func(A_ref, BLOCK_SIZE, recipe if recipe == "mxfp4" else None) + B_scale = scale_func(B_ref, BLOCK_SIZE, recipe if recipe == "mxfp4" else None) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) max_val = FP4_MAX_VAL min_val = -1 * max_val @@ -1953,6 +2195,10 @@ def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, sqnr = compute_error(C_ref, C) assert sqnr.item() > approx_match_sqnr_target +<<<<<<< HEAD +======= + @skipIfRocm +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM or IS_WINDOWS, mx_skip_msg) @parametrize("recipe", ["mxfp8", "nvfp4"]) def test_blockwise_mxfp8_nvfp4_error_messages(self, device, recipe) -> None: @@ -1980,9 +2226,16 @@ def test_blockwise_mxfp8_nvfp4_error_messages(self, device, recipe) -> None: # Test wrong scale tensor size for scale_a with correct dtype with self.assertRaisesRegex( RuntimeError, +<<<<<<< HEAD f".*For Block[W,w]ise.*scaling.*scale_a should have {expected_a_size} " f"elements.*" , +======= + re.escape( + f"For BlockWise scaling: Expected scale_a size to be {expected_a_size} " + f"but got {expected_a_size - 1}" + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): incorrect_size_a = torch.ones(expected_a_size - 1, device=device, dtype=scale_dtype) correct_size_b = torch.ones(expected_b_size, device=device, dtype=scale_dtype) @@ -1997,9 +2250,16 @@ def test_blockwise_mxfp8_nvfp4_error_messages(self, device, recipe) -> None: # Test wrong scale tensor size for scale_b with correct dtype with self.assertRaisesRegex( RuntimeError, +<<<<<<< HEAD f"For Block[W,w]ise.*scaling.*scale_b should have {expected_b_size} " f"elements.*" , +======= + re.escape( + f"For BlockWise scaling: Expected scale_b size to be {expected_b_size} " + f"but got {expected_b_size + 1}" + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): correct_size_a = torch.ones(expected_a_size, device=device, dtype=scale_dtype) incorrect_size_b = torch.ones(expected_b_size + 1, device=device, dtype=scale_dtype) @@ -2014,8 +2274,14 @@ def test_blockwise_mxfp8_nvfp4_error_messages(self, device, recipe) -> None: # Test non-contiguous scale tensors with correct dtype with self.assertRaisesRegex( RuntimeError, +<<<<<<< HEAD "For Block[W,w]ise.*scaling.*both should be contiguous" , +======= + re.escape( + "For BlockWise scaling: Both scale_a and scale_b must be contiguous" + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): non_contiguous_a = torch.ones(expected_a_size * 2, device=device, dtype=scale_dtype)[::2] contiguous_b = torch.ones(expected_b_size, device=device, dtype=scale_dtype) @@ -2035,6 +2301,7 @@ def scaled_grouped_mm_helper(self, alist, blist, ascalelist, bscalelist, outlist # Testing only _scaled_grouped_mm() with multiple shapes, as # _scaled_mm() already has more combinations of parameters than +<<<<<<< HEAD # _scaled_grouped_mm(), for supporting more than one inputs layout # combinations. @unittest.skipIf(not PLATFORM_SUPPORTS_FP8_GROUPED_GEMM, f8_grouped_msg) @@ -2047,6 +2314,21 @@ def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided): m, n, k, n_groups = 16, 32, 64, 4 a = torch.randn(m, k * n_groups + k * int(strided), device=device).to(fp8_dtype)[:, :k * n_groups] b = torch.randn(n, k * n_groups + k * int(strided), device=device).to(fp8_dtype)[:, :k * n_groups] +======= + # _scaled_grouped_mm(), for supporing more than one inputs layout + # combinations. + + @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") + @xfailIfSM100OrLater + @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") + @parametrize("fast_accum", [False, True]) + @parametrize("strided", [False, True]) + def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided): + device = "cuda" + m, n, k, n_groups = 16, 32, 64, 4 + a = torch.randn(m, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups] + b = torch.randn(n, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) scale_a = torch.rand(m * n_groups, device=device, dtype=torch.float32) scale_b = torch.rand(n * n_groups, device=device, dtype=torch.float32) offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32) @@ -2065,6 +2347,7 @@ def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided): self.scaled_grouped_mm_helper(alist, blist, ascalelist, bscalelist, out, fast_accum) +<<<<<<< HEAD @unittest.skipIf(not PLATFORM_SUPPORTS_FP8_GROUPED_GEMM, f8_grouped_msg) @parametrize("fast_accum", [False, True]) # AMD does not support non-contiguous inputs yet @@ -2076,6 +2359,19 @@ def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided): s_int = int(strided) a = torch.randn(m * n_groups, k * (1 + s_int), device=device).to(fp8_dtype)[:, :k] b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(fp8_dtype)[::(1 + s_int), :, :k] +======= + @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") + @xfailIfSM100OrLater + @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") + @parametrize("fast_accum", [False, True]) + @parametrize("strided", [False, True]) + def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided): + device = "cuda" + m, n, k, n_groups = 16, 32, 64, 4 + s_int = int(strided) + a = torch.randn(m * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k] + b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue(a.is_contiguous() is not strided) self.assertTrue(b.is_contiguous() is not strided) for check_zero_size in (True, False): @@ -2087,6 +2383,10 @@ def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided): offs[0] = offs[1] scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32) scale_b = torch.rand(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n) +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f = torch._scaled_grouped_mm out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs, out_dtype=torch.bfloat16, use_fast_accum=fast_accum) @@ -2102,6 +2402,7 @@ def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided): self.scaled_grouped_mm_helper(alist, b, ascalelist, scale_b, outlist, fast_accum) +<<<<<<< HEAD @unittest.skipIf(not PLATFORM_SUPPORTS_FP8_GROUPED_GEMM, f8_grouped_msg) @parametrize("fast_accum", [False, True]) # AMD does not support non-contiguous inputs yet @@ -2113,6 +2414,19 @@ def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided): s_int = int(strided) a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(fp8_dtype)[::(1 + s_int), :, :k] b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(fp8_dtype)[::(1 + s_int), :, :k] +======= + @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") + @xfailIfSM100OrLater + @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") + @parametrize("fast_accum", [False, True]) + @parametrize("strided", [False, True]) + def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided): + device = "cuda" + m, n, k, n_groups = 16, 32, 64, 4 + s_int = int(strided) + a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k] + b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue(a.is_contiguous() is not strided) self.assertTrue(b.is_contiguous() is not strided) scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m) @@ -2125,6 +2439,7 @@ def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided): self.scaled_grouped_mm_helper(a, b, scale_a, scale_b, out, fast_accum) +<<<<<<< HEAD @unittest.skipIf(not PLATFORM_SUPPORTS_FP8_GROUPED_GEMM, f8_grouped_msg) @parametrize("fast_accum", [False, True]) # AMD does not support non-contiguous inputs yet @@ -2136,6 +2451,19 @@ def test_scaled_grouped_gemm_3d_2d(self, fast_accum, strided): s_int = int(strided) a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(fp8_dtype)[::(1 + s_int), :, :k] b = torch.randn(n * n_groups, k * (1 + s_int), device=device).to(fp8_dtype)[:, :k] +======= + @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") + @xfailIfSM100OrLater + @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") + @parametrize("fast_accum", [False, True]) + @parametrize("strided", [False, True]) + def test_scaled_grouped_gemm_3d_2d(self, fast_accum, strided): + device = "cuda" + m, n, k, n_groups = 16, 32, 64, 4 + s_int = int(strided) + a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k] + b = torch.randn(n * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue(a.is_contiguous() is not strided) self.assertTrue(b.is_contiguous() is not strided) scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m) @@ -2190,6 +2518,10 @@ def test_blockwise_mxfp8_compile(self) -> None: ) torch.testing.assert_close(C, C_ref, atol=0, rtol=0) +<<<<<<< HEAD +======= + @skipIfRocm +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg) def test_blockwise_nvfp4_compile(self) -> None: diff --git a/test/test_meta.py b/test/test_meta.py index b3e5faab4f659..19d1890612aca 100644 --- a/test/test_meta.py +++ b/test/test_meta.py @@ -575,8 +575,13 @@ def run_meta_crossref( elif func in (torch.ops.aten.repeat_interleave.Tensor, torch.ops.aten.repeat_interleave.Tensor_out): if kwargs.get("output_size", None) is None: meta_args = args +<<<<<<< HEAD if func is torch.ops.aten.repeat_interleave.Tensor_out: meta_kwargs["out"] = kwargs["out"] +======= + if func is torch.ops.aten.repeat_interleave.Tensor_out: + meta_kwargs["out"] = kwargs["out"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif func in (torch.ops.aten.index.Tensor, torch.ops.aten.index.Tensor_out): # Don't convert boolean tensors to meta as they will have nonzero # called on them @@ -681,10 +686,14 @@ def run_meta_crossref( } meta_function_expected_failures_conditional = { +<<<<<<< HEAD torch.repeat_interleave: lambda dtype, *args, **kwargs: ( not isinstance(kwargs.get("repeats", None), int) and (kwargs.get("output_size", None) is None) ), +======= + torch.repeat_interleave : (lambda dtype, *args, **kwargs: not isinstance(kwargs.get("repeats", None), int)), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } """ @@ -1505,7 +1514,11 @@ def test_batch_norm_backward(self, output_mask): def test_fill__alias_relationship(self): inps = torch.rand(2**52, device='meta') r = torch.ops.aten.fill_(inps, 1.0) +<<<<<<< HEAD # aten.fill_ returns an alias +======= + # aten.fill_ returns an aliase +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(id(inps), id(r)) # aten.fill returns a new tensor diff --git a/test/test_mkldnn.py b/test/test_mkldnn.py index e2ec92fc8dada..6e8bd18075416 100644 --- a/test/test_mkldnn.py +++ b/test/test_mkldnn.py @@ -22,12 +22,19 @@ from torch.utils import mkldnn as mkldnn_utils from torch.testing._internal.common_utils import TestCase, \ run_tests, TemporaryFileName, gradcheck, gradgradcheck, IS_WINDOWS, \ +<<<<<<< HEAD skipIfTorchDynamo, xfailIfTorchDynamo, recover_orig_fp32_precision +======= + skipIfTorchDynamo, xfailIfTorchDynamo +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, dtypes, ) +<<<<<<< HEAD from torch.testing._internal.common_mkldnn import reduced_f32_on_and_off +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # batched grad doesn't support mkldnn gradcheck = functools.partial(gradcheck, check_batched_grad=False) @@ -265,10 +272,14 @@ def _test_conv_base(self, dim): loss1.backward() if not train or (train and dim != 1): y_mkldnn = mkldnn_conv(x2).to_dense() +<<<<<<< HEAD if self.precision != 0: self.assertEqual(y_aten, y_mkldnn, atol=self.precision, rtol=self.precision) else: self.assertEqual(y_aten, y_mkldnn) +======= + self.assertEqual(y_aten, y_mkldnn) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not train: self._test_serialization(mkldnn_conv, (x.to_mkldnn(),)) self._test_tracing(mkldnn_conv, (x.to_mkldnn(),)) @@ -284,6 +295,7 @@ def _test_conv_base(self, dim): if bias: self.assertEqual(conv.bias.grad, mkldnn_conv.bias.grad) +<<<<<<< HEAD @reduced_f32_on_and_off() def test_conv1d(self): self._test_conv_base(dim=1) @@ -293,6 +305,14 @@ def test_conv2d(self): self._test_conv_base(dim=2) @reduced_f32_on_and_off() +======= + def test_conv1d(self): + self._test_conv_base(dim=1) + + def test_conv2d(self): + self._test_conv_base(dim=2) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_conv3d(self): self._test_conv_base(dim=3) @@ -407,7 +427,10 @@ def _test_conv_deconv_nhwc_base(self, conv_module, weight_memory_format, dtype, self.assertEqual(conv1.bias.grad, conv2.bias.grad, atol=prec, rtol=prec) self.assertEqual(x1.grad, x2.grad, atol=prec, rtol=prec) +<<<<<<< HEAD @reduced_f32_on_and_off() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_conv_nhwc_fp32(self): self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.contiguous_format, dtype=torch.float32) self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.channels_last, dtype=torch.float32) @@ -443,7 +466,10 @@ def test_conv_nhwc_lower_precision(self, dtype): self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.channels_last_3d, dtype=dtype, prec=prec) +<<<<<<< HEAD @reduced_f32_on_and_off() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_conv_transpose_nhwc_fp32(self): self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.contiguous_format, dtype=torch.float32) self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.channels_last, dtype=torch.float32) @@ -492,7 +518,11 @@ def _test_conv_transpose_base(self, dim): C = torch.randint(1, 3, (1,)).item() * groups x_shape = (N, C) + input_shapes[dim] data = torch.randn(x_shape, dtype=torch.float32) +<<<<<<< HEAD # conv: mkldnn transpose conv fp32 +======= + # conv: mkldnn tranpose conv fp32 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # conv_ref: thnn transpose conv fp32 conv = conv_module[dim](in_channels=C, out_channels=M, @@ -518,11 +548,15 @@ def _test_conv_transpose_base(self, dim): if train: y.sum().backward() +<<<<<<< HEAD if self.precision != 0: self.assertEqual(y, y_ref, atol=self.precision, rtol=self.precision) else: self.assertEqual(y, y_ref) +======= + self.assertEqual(y, y_ref) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if train: self.assertEqual(x.grad, x_ref.grad) self.assertEqual(conv.weight.grad, @@ -532,6 +566,7 @@ def _test_conv_transpose_base(self, dim): if bias: self.assertEqual(conv.bias.grad, conv_ref.bias.grad) +<<<<<<< HEAD @reduced_f32_on_and_off() def test_conv_transpose1d(self): self._test_conv_transpose_base(dim=1) @@ -541,6 +576,14 @@ def test_conv_transpose2d(self): self._test_conv_transpose_base(dim=2) @reduced_f32_on_and_off() +======= + def test_conv_transpose1d(self): + self._test_conv_transpose_base(dim=1) + + def test_conv_transpose2d(self): + self._test_conv_transpose_base(dim=2) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_conv_transpose3d(self): self._test_conv_transpose_base(dim=3) @@ -1675,6 +1718,7 @@ def test_mkldnn_scaled_mm(self, device) -> None: self.assertEqual(out_emulated.float(), out.float(), atol=5e-2, rtol=5e-2) +<<<<<<< HEAD @recover_orig_fp32_precision def test_mlkdnn_get_set(self): # get/set mkldnn ops @@ -1726,6 +1770,8 @@ def test_default_use_parent(self): with torch.backends.flags(fp32_precision="tf32"): self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "tf32") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) instantiate_device_type_tests(TestMkldnn, globals(), only_for=('cpu',)) diff --git a/test/test_mps.py b/test/test_mps.py index 9204bf5dba2c5..80205fd07ff96 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -23,13 +23,20 @@ from torch.testing._internal import opinfo from torch.testing._internal.common_utils import \ (gradcheck, gradgradcheck, parametrize, run_tests, TestCase, download_file, MACOS_VERSION, IS_CI, +<<<<<<< HEAD NoTest, skipIfSlowGradcheckEnv, suppress_warnings, serialTest, instantiate_parametrized_tests, xfailIf) +======= + NoTest, skipIfSlowGradcheckEnv, suppress_warnings, serialTest, instantiate_parametrized_tests) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_mps import mps_ops_modifier, mps_ops_grad_modifier, mps_ops_error_inputs_modifier from torch.testing import make_tensor from torch.testing._internal.common_dtype import get_all_dtypes, integral_types import torch.backends.mps from torch.distributions import Uniform, Exponential +<<<<<<< HEAD from torch.utils._python_dispatch import TorchDispatchMode +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from functools import partial from torch.testing._internal.common_methods_invocations import ( @@ -72,6 +79,17 @@ ) ) +<<<<<<< HEAD +======= +def xfailIf(condition): + def wrapper(func): + if condition: + return unittest.expectedFailure(func) + else: + return func + return wrapper + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Same logic as test_cuda.py if not torch.backends.mps.is_available(): print('MPS not available, skipping tests', file=sys.stderr) @@ -188,6 +206,11 @@ def test_matmul_autocast(self): @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) def test_scaled_dot_product_attention_autocast(self, dtype): # Regression test for https://github.com/pytorch/pytorch/issues/141774 +<<<<<<< HEAD +======= + if dtype == torch.bfloat16 and MACOS_VERSION < 14.0: + raise unittest.SkipTest("bfloat16 needs MacOS14+") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) query = torch.rand(4, 1, 16, 8, dtype=torch.float32, device="mps") key = torch.rand(4, 1, 16, 8, dtype=torch.float32, device="mps") @@ -199,6 +222,7 @@ def test_scaled_dot_product_attention_autocast(self, dtype): y = F.scaled_dot_product_attention(query, key, value.to(torch.float32)) self.assertEqual(y.to(y_autocast.dtype), y_autocast) +<<<<<<< HEAD def test_conv_transpose3d_autocast_fp32(self): m = nn.ConvTranspose3d(16, 33, 3, stride=2).to("mps") x = torch.randn(20, 16, 10, 50, 100, device="mps") @@ -225,6 +249,8 @@ def forward(self, x): y = model(x) self.assertEqual(y.dtype, torch.float16) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_gradscaler_mps(self): # big model to force chunking/depth in the gradscaler dispatch class Model(nn.Module): @@ -246,6 +272,11 @@ def forward(self, x): torch.manual_seed(42) def helper(model_cpu, model_mps, dtype, iterations, batch_size, atol=3e-4, rtol=1e-5): +<<<<<<< HEAD +======= + if dtype == torch.bfloat16 and MACOS_VERSION < 14.0: + raise unittest.SkipTest("bfloat16 needs MacOS14+") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) optimizer_cpu = torch.optim.SGD(model_cpu.parameters(), lr=0.01) optimizer_mps = torch.optim.SGD(model_mps.parameters(), lr=0.01) loss_fn = nn.MSELoss() @@ -655,7 +686,11 @@ def _testLeakyRelu(self, shape, dtype, negative_slope, contiguous): mps_x = cpu_x.detach().clone().to('mps') if not contiguous and not (0 in shape or len(shape) < 2): +<<<<<<< HEAD # Transposing will make the tensor non-contiguous +======= + # Tranposing will make the tensor non-contiguous +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cpu_x = cpu_x.transpose(0, 1) mps_x = mps_x.transpose(0, 1) assert not mps_x.is_contiguous() @@ -738,6 +773,7 @@ def test_avg_pool2d_ceil_mode(self): padding=(0, 1), stride=2) self.assertFalse(torch.isnan(y).any()) +<<<<<<< HEAD # Test some cases for avg_pool2d which used to mismatch CPU results. # Addresses this issue: https://github.com/pytorch/pytorch/issues/160743 def test_avg_pool2d_ceil_mode_mismatch(self): @@ -765,6 +801,8 @@ def test_avg_pool2d_ceil_mode_mismatch(self): msg = f'{input_size=}, {kwargs=}' self.assertEqual(out_mps, out_cpu, msg=msg) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestMPS(TestCaseMPS): def test_exp(self, device="mps", dtype=torch.float): @@ -982,7 +1020,11 @@ def test_cdist_same_inputs(self, device="mps"): x.requires_grad = True d = torch.cdist(x, y) d.backward(dist_grad) +<<<<<<< HEAD # Check that the backward pass does not contain invalid +======= + # Check that the backward passs does not contain invalid +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # values such as nan or inf assert torch.isfinite(x.grad).all() @@ -1122,6 +1164,12 @@ def test_large_bmm(self, dtype): @parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) def test_take_along_dim(self, dtype): +<<<<<<< HEAD +======= + if dtype == torch.bfloat16 and MACOS_VERSION < 14.0: + raise unittest.SkipTest("bfloat16 needs MacOS14+") + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x = torch.tensor([[-5.], [0.], [5.]], dtype=dtype) inds = torch.tensor([[0], [1], [2]]) ref = torch.take_along_dim(x, inds, 0) @@ -1234,7 +1282,11 @@ def test_linear_errors(self): torch.nn.functional.linear(torch.rand(size, device='mps'), torch.randint(-10, 10, size, dtype=torch.int8, device='mps')) +<<<<<<< HEAD # Weights on wrong device +======= + # Weigths on wrong device +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with self.assertRaisesRegex(RuntimeError, "argument weight is on cpu but expected on mps"): torch.nn.functional.linear(torch.rand(size, device='mps'), torch.rand(size, device='cpu')) @@ -1244,6 +1296,7 @@ def test_linear_errors(self): torch.nn.functional.linear(torch.rand(size, device='cpu'), torch.rand(size, device='mps')) +<<<<<<< HEAD def test_linear_non_contiguous(self): # Regression test for https://github.com/pytorch/pytorch/issues/161640 # Slice tensors to force non-contiguity @@ -1255,6 +1308,8 @@ def test_linear_non_contiguous(self): result_contig = torch.nn.functional.linear(input_s, weight_contiguous_equiv) self.assertEqual(result_contig, result_sliced) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _linear_helper(self, in_features, out_features, shape, bias=True, backward_pass=False): cpu_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, device="cpu", bias=bias) mps_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, device="mps", bias=bias) @@ -1354,6 +1409,10 @@ def test_linear3D_no_bias(self): def test_linear3D_no_bias_backward(self): self._linear_helper(in_features=2, out_features=3, shape=((4, 5, 2)), bias=True, backward_pass=True) +<<<<<<< HEAD +======= + @xfailIf(MACOS_VERSION < 14.0) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_linear_large(self): # Regression test for https://github.com/pytorch/pytorch/issues/122045 x_cpu = torch.randn(9, 1024, 1, device='cpu') @@ -1572,11 +1631,20 @@ def test_masked_fill(self): dst2[i] = val self.assertEqual(dst.to("cpu"), dst2, atol=0, rtol=0) +<<<<<<< HEAD # Regression test for https://github.com/pytorch/pytorch/issues/143477 # Allocating 48x25x1024x1024 tensor crashes on MacOS-13 mask_bool = torch.triu(torch.ones(1024, 1024, device=device), diagonal=1).bool() attn_scores = torch.rand(48, 25, 1024, 1024, device=device) attn_scores.masked_fill_(mask_bool, 0) +======= + if MACOS_VERSION >= 14.0: + # Regression test for https://github.com/pytorch/pytorch/issues/143477 + # Allocating 48x25x1024x1024 tensor crashes on MacOS-13 + mask_bool = torch.triu(torch.ones(1024, 1024, device=device), diagonal=1).bool() + attn_scores = torch.rand(48, 25, 1024, 1024, device=device) + attn_scores.masked_fill_(mask_bool, 0) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_masked_fill__non_contiguous(self): shape = (3, 5) @@ -1796,6 +1864,7 @@ def test_batch_norm_slices(self): self.assertEqual(res_cpu, res_mps) +<<<<<<< HEAD def test_batch_norm_backward_weight_bias_gradients(self): # See issue: https://github.com/pytorch/pytorch/issues/156555 N, C, L = 4, 3, 5 @@ -1816,6 +1885,8 @@ def test_batch_norm_backward_weight_bias_gradients(self): self.assertEqual(bn_cpu.weight.grad, bn_mps.weight.grad, atol=1e-5, rtol=1e-5) self.assertEqual(bn_cpu.bias.grad, bn_mps.bias.grad, atol=1e-5, rtol=1e-5) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_layer_norm_backward(self): inputs = torch.rand(4, 4, device="mps", requires_grad=True) x = torch.nn.LayerNorm(4).to("mps") @@ -2088,6 +2159,10 @@ def helper(input_shape, normalized_shape, eps=1e-05, elementwise_affine=True, dt # Regression test for https://github.com/pytorch/pytorch/issues/96113 torch.nn.LayerNorm((16,), elementwise_affine=True).to("mps")(torch.randn(1, 2, 16).to("mps", dtype=torch.float16)) +<<<<<<< HEAD +======= + @xfailIf(MACOS_VERSION < 14.0) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_ifft(self): # See: https://github.com/pytorch/pytorch/issues/124096 device = torch.device("mps") @@ -2344,6 +2419,7 @@ def helper(dim, layer='linear', dtype=torch.float32): helper(3, layer='conv') helper(-1, layer='conv') +<<<<<<< HEAD # Conv3d is only available from MacOS 13 onwards helper(0, layer='conv3d') helper(1, layer='conv3d') @@ -2351,6 +2427,16 @@ def helper(dim, layer='linear', dtype=torch.float32): helper(3, layer='conv3d') helper(4, layer='conv3d') helper(-1, layer='conv3d') +======= + if MACOS_VERSION >= 13.2: + # Conv3d is only available from MacOS 13 onwards + helper(0, layer='conv3d') + helper(1, layer='conv3d') + helper(2, layer='conv3d') + helper(3, layer='conv3d') + helper(4, layer='conv3d') + helper(-1, layer='conv3d') +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Test conv2d def test_conv2d_unit(self): @@ -3213,7 +3299,12 @@ def test_torch_repeat_interleave(self, device="mps"): def test_repeat_interleave(self, device="mps"): x = torch.tensor([0, 1, 2, 3], device=device) expected = torch.tensor([1, 2, 2, 3, 3, 3], device=device) +<<<<<<< HEAD self.assertEqual(torch.repeat_interleave(x), expected) +======= + # Prior to macos 13.3, input of dtype=torch.int64 returns dtype=torch.int32 + self.assertEqual(torch.repeat_interleave(x), expected, exact_dtype=MACOS_VERSION >= 13.3) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with self.assertRaises(RuntimeError): torch.repeat_interleave(torch.arange(4, device=device).reshape(2, 2)) @@ -3624,7 +3715,11 @@ def rotate_subset(data, dim): self.assertFalse(x2.is_contiguous()) return torch.concat((x1, x2), dim=dim) for dtype in MPS_DTYPES: +<<<<<<< HEAD if dtype == torch.bool: +======= + if dtype == torch.bool or (dtype.is_complex and MACOS_VERSION < 14.0): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue data = torch.arange(48).to(dtype=dtype).reshape(1, 2, 4, 6) data = data.to(memory_format=torch.channels_last) @@ -3637,6 +3732,7 @@ def rotate_subset(data, dim): # TODO: enable memory format test # self.assertEqual(cpu_result.is_contiguous(), mps_result.is_contiguous()) +<<<<<<< HEAD # See https://github.com/pytorch/pytorch/issues/152701 def test_jacfwd_cat(self): def fn(x, y): @@ -3647,6 +3743,8 @@ def fn(x, y): rc = torch.func.jacfwd(fn)(x, y) self.assertEqual(rc.shape, (5, 2)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # See https://github.com/pytorch/pytorch/issues/85967 def test_from_numpy_non_contiguous(self): a = np.arange(9).reshape(3, 3)[:, :2] @@ -3840,7 +3938,18 @@ def helper(dtype): a_cpu = t_cpu.cumsum(0, dtype=dtype) self.assertEqual(a.cpu(), a_cpu) +<<<<<<< HEAD [helper(dtype) for dtype in [torch.int8, torch.int16, torch.int32, torch.int64, torch.float32]] +======= + [helper(dtype) for dtype in [torch.int8, torch.int16, torch.int32, torch.float32]] + + try: + helper(torch.int64) + except Exception as e: + e_string = str(e) + self.assertEqual(e_string, "MPS does not support cumsum_out_mps op with int64 input." + + " Support has been added in macOS 13.3") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_cumsum_bool(self): a = torch.ones(2**16, dtype=torch.bool) @@ -3875,7 +3984,18 @@ def helper(dtype): a_cpu = t_cpu.cumprod(0, dtype=dtype) self.assertEqual(a.cpu(), a_cpu) +<<<<<<< HEAD [helper(dtype) for dtype in [torch.int8, torch.int16, torch.int32, torch.int64, torch.float32]] +======= + [helper(dtype) for dtype in [torch.int8, torch.int16, torch.int32, torch.float32]] + + try: + helper(torch.int64) + except Exception as e: + e_string = str(e) + self.assertEqual(e_string, "MPS does not support cumprod_out_mps op with int64 input." + + " Support has been added in macOS 13.3") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_cumprod_minus_one_axis(self): def helper(dtype): @@ -4658,6 +4778,10 @@ def helper(n, c, h, w, reduction_type, dtype=torch.float32): helper(2, 8, 4, 4, "min", torch.float16) helper(2, 8, 4, 4, "min", torch.int64) +<<<<<<< HEAD +======= + @unittest.skipIf(MACOS_VERSION < 13.3, "Long data type supported from macOS 13.3 and above") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_reduction_sum_max_long_val(self): x_mps = torch.tensor([sys.maxsize, sys.maxsize - 10, sys.maxsize - 5, sys.maxsize - 18], device="mps") x_cpu = x_mps.detach().clone().cpu() @@ -5359,9 +5483,12 @@ def helper(): helper() +<<<<<<< HEAD # Regression test for https://github.com/pytorch/pytorch/issues/160738 self.assertTrue(torch.var(torch.tensor(3.13, device='mps'), dim=0).isnan().item()) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Test forward amax def test_amax(self): def helper(shape, dim, keepdim): @@ -5715,7 +5842,12 @@ def helper(shapes, output_size, kernel_size, padding, stride, contiguous, dtype= helper((4, 15, 1600), (40, 40), (3, 5), (1, 2), (1, 1), True) helper((4, 45, 187), (35, 33), (3, 5), (0, 1), (2, 3), True) helper((1600, 15), (40, 40), (3, 5), (1, 2), (1, 1), False) +<<<<<<< HEAD helper((20, 15), (2, 10), (3, 5), (1, 2), (1, 1), False, torch.bfloat16) +======= + if MACOS_VERSION >= 14.0: + helper((20, 15), (2, 10), (3, 5), (1, 2), (1, 1), False, torch.bfloat16) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) helper((20, 15), (2, 10), (3, 5), (1, 2), (1, 1), False, torch.float16) helper((20, 15), (2, 10), (3, 5), (1, 2), (1, 1), False, test_bool=True) @@ -6004,6 +6136,10 @@ def helper(shape): helper((2, 8, 4, 5)) +<<<<<<< HEAD +======= + @xfailIf(MACOS_VERSION < 14.0) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_angle(self): def helper(shape, dtype): cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False) @@ -6055,7 +6191,11 @@ def helper(shape): helper((2, 8, 4, 5)) +<<<<<<< HEAD @parametrize("dtype", {torch.float, torch.half, torch.bfloat16}) +======= + @parametrize("dtype", {torch.float, torch.half} if MACOS_VERSION < 14 else {torch.float, torch.half, torch.bfloat16}) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_log1p(self, dtype): eps = torch.finfo(dtype).eps # Small values @@ -6316,7 +6456,11 @@ def helper(shape, contiguous=True): x = cpu_x.detach().clone().to('mps') if not contiguous and (0 not in shape and len(shape) >= 2): +<<<<<<< HEAD # Transposing will make the tensor non-contiguous +======= + # Tranposing will make the tensor non-contiguous +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cpu_x = cpu_x.transpose(0, 1) x = x.transpose(0, 1) assert not x.is_contiguous() @@ -6472,7 +6616,11 @@ def helper(shape, dtype=torch.float, contiguous=True): x = cpu_x.detach().clone().to('mps') if not contiguous and (0 not in shape and len(shape) >= 2): +<<<<<<< HEAD # Transposing will make the tensor non-contiguous +======= + # Tranposing will make the tensor non-contiguous +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cpu_x = cpu_x.transpose(0, 1) x = x.transpose(0, 1) assert not x.is_contiguous() @@ -6512,7 +6660,11 @@ def helper(shape, dtype=torch.float, contiguous=True): x = cpu_x.detach().clone().to('mps') if not contiguous and (0 not in shape and len(shape) >= 2): +<<<<<<< HEAD # Transposing will make the tensor non-contiguous +======= + # Tranposing will make the tensor non-contiguous +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cpu_x = cpu_x.transpose(0, 1) x = x.transpose(0, 1) assert not x.is_contiguous() @@ -6794,6 +6946,11 @@ def helper(shape, dim, index, source_shape, alpha, x_dtype=torch.float32, idx_dt def test_index_64bit(self): """ Test that index operations work for 4Gb+ tensors """ +<<<<<<< HEAD +======= + if MACOS_VERSION < 14.0: + raise unittest.SkipTest("Sonoma is needed for large tensors, see https://github.com/pytorch/pytorch/issues/84039") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Cleanup memory gc.collect() torch.mps.empty_cache() @@ -6830,10 +6987,19 @@ def compare_mm(m, n, k, dtype=torch.float): # see https://github.com/pytorch/pytorch/issues/116769#issuecomment-1920066984 compare_mm(32769, 1, 1025) +<<<<<<< HEAD # Test bfloat16 mm compare_mm(1024, 1, 32769, torch.bfloat16) @unittest.skipIf(total_memory < 12_000_000_000, "Needs at least 12Gb RAM to run the test") +======= + if MACOS_VERSION >= 14.0: + # Test bfloat16 mm + compare_mm(1024, 1, 32769, torch.bfloat16) + + @unittest.skipIf(total_memory < 12_000_000_000, "Needs at least 12Gb RAM to run the test") + @unittest.skipIf(MACOS_VERSION < 14.0, "Can't allocate 4Gb tensor on MacOS 13") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(IS_CI, "May be fixes https://github.com/pytorch/pytorch/issues/149999") def test_copy_large(self): """ Test that copy of 4Gb+ tensors works """ @@ -7323,11 +7489,19 @@ def test_arange(self): self.assertEqual(np.arange(7, 1, -1), torch.arange(7, 1, -1, device='mps')) self.assertEqual(np.arange(1, 2, .3, dtype=np.float32), torch.arange(1, 2, .3, device='mps')) self.assertEqual(np.arange(6.3, dtype=np.float32), torch.arange(6.3, device='mps')) +<<<<<<< HEAD def do_arange(start=1.2, end=10.3, dtype=torch.bfloat16, device='cpu'): return torch.arange(start, end, device=device, dtype=dtype) self.assertEqual(do_arange(device='mps'), do_arange(device='cpu')) +======= + # To be removed + if MACOS_VERSION >= 14.0: + def do_arange(start=1.2, end=10.3, dtype=torch.bfloat16, device='cpu'): + return torch.arange(start, end, device=device, dtype=dtype) + self.assertEqual(do_arange(device='mps'), do_arange(device='cpu')) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_arange_empty(self): out_mps = torch.tensor([], device="mps") @@ -7489,12 +7663,23 @@ def helper(shape, mean=0.0, std=1.0, dtype=torch.float): helper((2, 3, 4, 5, 6)) helper((100, 100), 2.5, 1.2) +<<<<<<< HEAD helper((10, 10), 2.5, 1.2, dtype=torch.bfloat16) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Test invalid inputs with self.assertRaises(TypeError): helper((10, 10), 10, 11, dtype=torch.int32) +<<<<<<< HEAD +======= + if MACOS_VERSION >= 14.0: + helper((10, 10), 2.5, 1.2, dtype=torch.bfloat16) + else: + with self.assertRaises(TypeError): + helper((10, 10), 2.5, 1.2, dtype=torch.bfloat16) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_bernoulli(self): shape = (10, 10) @@ -7521,6 +7706,7 @@ def test_bernoulli(self): for dtype in [torch.float16, torch.int8, torch.int16, torch.int32, torch.int64]: mps_out = torch.zeros(shape, device='mps', dtype=dtype).bernoulli(0.5) # Check that output is not all zeros or ones +<<<<<<< HEAD uniq = mps_out.unique() self.assertEqual(uniq, torch.arange(2, device='mps', dtype=dtype)) @@ -7556,6 +7742,14 @@ def test_dropout(self, dtype): else: self.assertEqual(input.grad, output_grad) +======= + if MACOS_VERSION > 13.0: + uniq = mps_out.unique() + self.assertEqual(uniq, torch.arange(2, device='mps', dtype=dtype)) + else: + self.assertEqual(mps_out.min().item(), 0.) + self.assertEqual(mps_out.max().item(), 1.) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_mps_generator(self): # explicit manual seeding by creating an MPS Generator @@ -7758,6 +7952,7 @@ def test_random_5d(self): # Test exponential @unittest.skip("This does not test anything") def test_exponential(self): +<<<<<<< HEAD def helper(shape, lambda_, dtype=torch.float32): mps_out = torch.zeros(shape, device='mps', dtype=dtype) @@ -7765,6 +7960,15 @@ def helper(shape, lambda_, dtype=torch.float32): print(mps_out.to('cpu').float().mean(), 1 / lambda_) print(mps_out.to('cpu').float().std() ** 2, 1 / (lambda_**2)) +======= + def helper(shape, lamda, dtype=torch.float32): + + mps_out = torch.zeros(shape, device='mps', dtype=dtype) + mps_out.exponential_(lamda) + + print(mps_out.to('cpu').float().mean(), 1 / lamda) + print(mps_out.to('cpu').float().std() ** 2, 1 / (lamda**2)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for dtype in [torch.float32, torch.float16]: helper([100, 100], 2, dtype) @@ -7782,12 +7986,15 @@ def test_exponential_1(self): self.assertEqual(Exponential(0.2).sample((1,)).size(), (1,)) self.assertEqual(Exponential(50.0).sample((1,)).size(), (1,)) +<<<<<<< HEAD @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) def test_exponential_nonzero(self, dtype): for _ in range(100): a = torch.empty(32_000, device="mps", dtype=dtype).exponential_() self.assertTrue((a != 0).all()) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Test add def test_add_sub(self): def helper(shape, alpha, op_name, inplace): @@ -7836,8 +8043,11 @@ def helper(shape, alpha, op_name, inplace): y = torch.arange(32, device='mps', dtype=torch.int32) self.assertEqual(torch.add(x, y, alpha=2).cpu(), torch.add(x.cpu(), y.cpu(), alpha=2)) self.assertEqual(torch.add(x, 3, alpha=2).cpu(), torch.add(x.cpu(), 3, alpha=2)) +<<<<<<< HEAD # Regression test for https://github.com/pytorch/pytorch/issues/160208 self.assertEqual(torch.add(y, x, alpha=2).cpu(), torch.add(y.cpu(), x.cpu(), alpha=2)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Test add def test_add_scalars(self): @@ -8050,6 +8260,7 @@ def test_inplace_bitwise_not(self, dtype): x[::2].bitwise_not_() self.assertEqual(x_mps.cpu(), x_cpu) +<<<<<<< HEAD def test_empty_posneginf(self): # just to check that it doesnt crash input_tensor = torch.empty(0, device="mps") @@ -8061,6 +8272,10 @@ def test_empty_posneginf(self): class TestLargeTensors(TestCaseMPS): @serialTest() +======= + +class TestLargeTensors(TestCaseMPS): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_64bit_binops(self): if torch.mps.recommended_max_memory() < 16_000_000_000: raise unittest.SkipTest("Needs at least 16Gb of RAM") @@ -8089,6 +8304,7 @@ def test_64bit_index_select(self): gc.collect() torch.mps.empty_cache() +<<<<<<< HEAD @serialTest() def test_rand_2b_raises(self): int32_max = torch.iinfo(torch.int32).max @@ -8099,6 +8315,8 @@ def test_rand_2b_raises(self): self.assertEqual(x.numel(), int32_max) del x +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestLogical(TestCaseMPS): def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False): @@ -8210,7 +8428,11 @@ def test_min_max(self, dtype): z_cpu = x_cpu.min() self.assertEqual(z, z_cpu) +<<<<<<< HEAD @parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +======= + @parametrize("dtype", [torch.float32, torch.float16] + ([torch.bfloat16] if MACOS_VERSION >= 14.0 else [])) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_min_max_nan_propagation(self, dtype): cpu_x = torch.tensor([1.0, float("nan"), 3.0], device="cpu", dtype=dtype) mps_x = cpu_x.detach().clone().to('mps') @@ -8257,10 +8479,17 @@ def helper(dtype): self.assertEqual(mps_out, cpu_ref) dtypes = [torch.float32, torch.float16, torch.bfloat16, torch.int32, torch.int16, torch.uint8, torch.int8] +<<<<<<< HEAD +======= + if MACOS_VERSION < 14.0: + # Int types expected to fail on MacOS < 14.0 + dtypes = [torch.float32, torch.float16, torch.bfloat16] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) [helper(dtype) for dtype in dtypes] # Mixed dtypes (see https://github.com/pytorch/pytorch/issues/151443 ) +<<<<<<< HEAD x = torch.arange(4.0, device="mps") y = torch.tensor([1, 3], device="mps", dtype=torch.float16) self.assertEqual(torch.isin(x, y), torch.tensor([False, True, False, True], device="mps")) @@ -8271,6 +8500,20 @@ def helper(dtype): self.assertEqual(torch.isin(x, 8.0), torch.tensor([False, False, False, False], device="mps")) # Scalar.Tensor variant(alaises to Scalar.Scalar), not covered by OpInfo self.assertEqual(torch.isin(2.0, x), torch.tensor(True, device="mps")) +======= + # torch.isin is broken in MacOS-13.2 even for the same dtype + if MACOS_VERSION >= 14.0: + x = torch.arange(4.0, device="mps") + y = torch.tensor([1, 3], device="mps", dtype=torch.float16) + self.assertEqual(torch.isin(x, y), torch.tensor([False, True, False, True], device="mps")) + + # Tensor.Scalar variant (aliases to eq), not covered by OpInfo + self.assertEqual(torch.isin(x, 2.0), torch.tensor([False, False, True, False], device="mps")) + self.assertEqual(torch.isin(x, 1.0, invert=True), torch.tensor([True, False, True, True], device="mps")) + self.assertEqual(torch.isin(x, 8.0), torch.tensor([False, False, False, False], device="mps")) + # Scalar.Tensor varaiant(alaises to Scalar.Scalar), not covered by OpInfo + self.assertEqual(torch.isin(2.0, x), torch.tensor(True, device="mps")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_isin_asserts(self): C = torch.randn(size=[1, 4], device='mps', dtype=torch.float32) @@ -8983,12 +9226,15 @@ def test_constant_pad_nd_preserves_memory_format(self): nhwc_padded = torch.constant_pad_nd(nhwc_tensor, [1, 2], 0.5) self.assertTrue(nhwc_padded.is_contiguous(memory_format=torch.channels_last)) +<<<<<<< HEAD def test_constant_pad_nd_with_empty_pad(self): # Empty constant pad is no-op # See https://github.com/pytorch/pytorch/issues/161066 input_mps = torch.randn((2, 3, 4), device="mps") output_mps = torch.constant_pad_nd(input_mps, []) self.assertEqual(output_mps, input_mps) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestLinalgMPS(TestCaseMPS): def _test_addmm_addmv(self, f, t, m, v, *, alpha=None, beta=None, transpose_out=False): @@ -9209,7 +9455,11 @@ def weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros): b_int4pack, b_scales_and_zeros_f32 = convert_weight_to_int4pack(b_f32) +<<<<<<< HEAD for dtype in [torch.float16, torch.float32, torch.bfloat16]: +======= + for dtype in [torch.float16, torch.float32] + ([torch.bfloat16] if MACOS_VERSION > 14.0 else []): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) a = a_f32.to(dtype=dtype) b = b_f32.to(dtype=dtype) b_scales_and_zeros = b_scales_and_zeros_f32.to(dtype=dtype) @@ -9237,7 +9487,11 @@ def weight_int8pack_mm(a, b_int8pack, b_scales): return torch._weight_int8pack_mm(a, b_int8pack, b_scales) b_int8pack, b_scales_f32 = convert_weight_to_int8pack(b_f32) +<<<<<<< HEAD for dtype in [torch.float16, torch.float32, torch.bfloat16]: +======= + for dtype in [torch.float16, torch.float32] + ([torch.bfloat16] if MACOS_VERSION > 14.0 else []): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) a = a_f32.to(dtype=dtype) b = b_f32.to(dtype=dtype) b_scales = b_scales_f32.to(dtype=dtype) @@ -9344,6 +9598,7 @@ def test_sdpa_mask_fp16_L6(self): def test_sdpa_mask_fp16_L6_S17_NH23_HS121(self): self._test_sdpa_mask(torch.float16, 7, 17, 23, 121) +<<<<<<< HEAD # Regression test from: https://github.com/pytorch/pytorch/issues/156707 @parametrize("dtype", [torch.float16, torch.float32]) def test_sdpa_full_mask(self, dtype): @@ -9356,6 +9611,8 @@ def test_sdpa_full_mask(self, dtype): out_mps = F.scaled_dot_product_attention(q.to('mps'), k.to('mps'), v.to('mps'), attn_mask=mask.to('mps')) self._compare_tensors(out_mps.cpu(), out_cpu) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize("dtype", [torch.float16, torch.float32]) def test_sdpa_3d_input(self, dtype): head_num, seq_len, embed_dim = 16, 16, 80 @@ -9454,7 +9711,11 @@ def test_sdpa_enable_gqa(self, dtype, is_causal): ) self._compare_tensors(y.cpu(), y_ref) +<<<<<<< HEAD @serialTest() +======= + @serialTest +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_sdpa_fp32_no_memory_leak(self): def get_mps_memory_usage(): return (torch.mps.current_allocated_memory() / (1024 * 1024), @@ -9472,6 +9733,7 @@ def get_mps_memory_usage(): # 5 MB different maximum allowed value(could be decreased even more) torch.testing.assert_close(memory_footprints[-1], memory_footprints[0], atol=5, rtol=1) +<<<<<<< HEAD def generate_qkv(self, batch: int, NH: int, q_len: int, s_len: int, head_dim: int, layout: str, dtype: torch.dtype): if layout == "contiguous": q = torch.randn(batch, NH, q_len, head_dim, dtype=dtype, device="mps") @@ -9503,6 +9765,19 @@ def run_fast_attention_test( dropout_p: float = 0.0, is_causal: bool = False, ): +======= + def generate_qkv(self, batch, NH, q_len, s_len, head_dim, contiguous, dtype): + if contiguous: + q = torch.randn(batch, NH, q_len, head_dim, dtype=dtype, device="mps") + k = torch.randn(batch, NH, s_len, head_dim, dtype=dtype, device="mps") + else: + q = torch.randn(batch, NH, head_dim, q_len, dtype=dtype, device="mps").mT + k = torch.randn(batch, NH, head_dim, s_len, dtype=dtype, device="mps").mT + v = torch.randn(batch, NH, s_len, head_dim, dtype=dtype, device="mps") + return q, k, v + + def run_fast_attention_test(self, q, k, v, with_mask, dropout_p=0.0, is_causal=False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) q_len = q.shape[2] s_len = k.shape[2] @@ -9543,15 +9818,23 @@ def run_fast_attention_test( self._compare_tensors(y.cpu(), y_ref) @parametrize("dtype", [torch.float16, torch.float32]) +<<<<<<< HEAD @parametrize("layout", ["contiguous", "mT", "transpose_seq_head", "permute"]) @parametrize("head_dim", [64, 96, 128]) # 64, 96, 128 are for the fast kernel @parametrize("with_mask", [True, False]) def test_fast_vector_attention(self, dtype: torch.dtype, layout: str, head_dim: int, with_mask: bool): +======= + @parametrize("contiguous", [True, False]) + @parametrize("head_dim", [64, 96, 128]) # 64, 96, 128 are for the fast kernel + @parametrize("with_mask", [True, False]) + def test_fast_vector_attention(self, dtype, contiguous, head_dim, with_mask): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.manual_seed(1729) batch = 1 NH = 2 q_len = 4 # <8 so that vector fast is eligible s_len = 16 # smaller than 1024 so that we use the one–pass variant +<<<<<<< HEAD q, k, v = self.generate_qkv(batch, NH, q_len, s_len, head_dim, layout, dtype) self.run_fast_attention_test(q, k, v, with_mask) @@ -9559,26 +9842,47 @@ def test_fast_vector_attention(self, dtype: torch.dtype, layout: str, head_dim: @parametrize("layout", ["contiguous", "mT", "transpose_seq_head", "permute"]) @parametrize("with_mask", [True, False]) def test_fast_vector_attention_2pass(self, dtype: torch.dtype, layout: str, with_mask: bool): +======= + q, k, v = self.generate_qkv(batch, NH, q_len, s_len, head_dim, contiguous, dtype) + self.run_fast_attention_test(q, k, v, with_mask) + + @parametrize("dtype", [torch.float32]) # float16 underflows sometimes, which leads to flaky tests + @parametrize("contiguous", [True, False]) + @parametrize("with_mask", [True, False]) + def test_fast_vector_attention_2pass(self, dtype, contiguous, with_mask): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.manual_seed(1729) batch = 1 NH = 32 q_len = 8 s_len = 1024 # large enough to trigger the two–pass path head_dim = 64 # supported head dimension for vector attention +<<<<<<< HEAD q, k, v = self.generate_qkv(batch, NH, q_len, s_len, head_dim, layout, dtype) +======= + q, k, v = self.generate_qkv(batch, NH, q_len, s_len, head_dim, contiguous, dtype) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.run_fast_attention_test(q, k, v, with_mask) @unittest.skip("Full attention fast kernel not implemented yet") @parametrize("dtype", [torch.float16, torch.float32]) +<<<<<<< HEAD @parametrize("layout", ["contiguous", "mT"]) @parametrize("head_dim", [64, 80, 128]) # 64, 80, 128 are for the fast kernel @parametrize("with_mask", [True, False]) def test_fast_full_attention(self, dtype: torch.dtype, layout: str, head_dim: int, with_mask: bool): +======= + @parametrize("contiguous", [True, False]) + @parametrize("head_dim", [64, 80, 128]) # 64, 80, 128 are for the fast kernel + @parametrize("with_mask", [True, False]) + def test_fast_full_attention(self, dtype, contiguous, head_dim, with_mask): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.manual_seed(1729) batch = 1 NH = 2 q_len = 32 # threshold to trigger full fast attention path s_len = 16 +<<<<<<< HEAD q, k, v = self.generate_qkv(batch, NH, q_len, s_len, head_dim, layout, dtype) self.run_fast_attention_test(q, k, v, with_mask) @@ -9654,6 +9958,12 @@ def new_fn(self, *args, **kwargs): TestSDPAMeta = create_sdpa_meta_test() instantiate_parametrized_tests(TestSDPAMeta) +======= + q, k, v = self.generate_qkv(batch, NH, q_len, s_len, head_dim, contiguous, dtype) + self.run_fast_attention_test(q, k, v, with_mask) + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestGatherScatter(TestCaseMPS): def test_slicing_with_step(self): # Slicing with step @@ -10636,7 +10946,11 @@ def helper(shape, in_channels=1, out_channels=1, kernel_size=3, groups=1): grad_in_cl = torch.empty(1, f, oc, device="mps").transpose(1, 2) grad_in_cl[:] = grad_in +<<<<<<< HEAD # It does not matter whether grad_in contiguous, or channels last, results should equal to each other +======= + # It does not matter whether grad_in contigous, or channels last, results should equal to each other +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) grad_rc = torch.autograd.grad((out,), (inp, conv.weight, conv.bias), (grad_in,), retain_graph=True) grad_rc_cl = torch.autograd.grad((out,), (inp, conv.weight, conv.bias), (grad_in_cl,), retain_graph=True) @@ -11005,6 +11319,10 @@ class TestAdvancedIndexing(TestCaseMPS): supported_dtypes = [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16, torch.uint8] supported_np_dtypes = [np.float32, np.float16, np.int64, np.int32, np.int16, np.uint8] +<<<<<<< HEAD +======= + @unittest.skipIf(MACOS_VERSION < 14.0, "Skipped on macOS < 14") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_nonzero_no_warning(self): device = "mps" t = torch.randn((2, 2), device=device) @@ -12164,7 +12482,11 @@ def test_serialization_map_location(self): self.assertEqual(x2.device.type, "mps") +<<<<<<< HEAD MPS_UNSUPPORTED_TYPES = [torch.double, torch.cdouble] +======= +MPS_UNSUPPORTED_TYPES = [torch.double, torch.cdouble] + ([torch.bfloat16] if MACOS_VERSION < 14.0 else []) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MPS_DTYPES = [t for t in get_all_dtypes() if t not in MPS_UNSUPPORTED_TYPES] MPS_GRAD_DTYPES = [torch.float32, torch.float16] @@ -12285,6 +12607,13 @@ def _compute_tolerances(self, op, dtype): return (7e-4, 2e-3) if op.name == "native_layer_norm": return (1e-4, 1.3e-5) +<<<<<<< HEAD +======= + if op.name in ["pow", "__rpow__"] and MACOS_VERSION < 13.3: + # The result of pow(9 , 8) is showing 43046716, whereas it should've been 43046721. + # fixed in macOS 13.3+ + return (1e-6, 2e-3 if dtype == torch.float16 else 4e-6) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if op.name in ['fft.rfftn', 'fft.hfftn', 'fft.hfft2', 'fft.fft', 'fft.fftn', 'fft.rfft']: # TODO: Investigate why this is needed # See https://github.com/pytorch/pytorch/issues/120237 @@ -12302,6 +12631,11 @@ def _compute_tolerances(self, op, dtype): def test_output_match(self, device, dtype, op): self.assertEqual(device, "mps:0") include_conjugated_inputs = dtype.is_complex and op.test_conjugated_samples +<<<<<<< HEAD +======= + if op.name.endswith("svd") and MACOS_VERSION < 14.0 and dtype == torch.complex64: + raise unittest.SkipTest("Can't even generate complex samples on MacOS-13") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_samples(): return op.sample_inputs( @@ -12352,6 +12686,7 @@ def get_samples(): # Similar to the above, float vs double precision aresults in slight error atol, rtol = 2e-5, 2e-6 +<<<<<<< HEAD if op.name in ["grid_sampler_3d", "asinh"]: atol, rtol = 1e-4, 1e-4 @@ -12364,6 +12699,8 @@ def get_samples(): self.assertEqual(values if keep_dim else values.squeeze(dim), mps_out[0]) continue +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol) @ops(mps_ops_grad_modifier(copy.deepcopy(test_consistency_op_db)), allowed_dtypes=MPS_GRAD_DTYPES) @@ -12456,6 +12793,7 @@ def req_grad(t): # which leads to larger errors if op.name == "_unsafe_masked_index" and dtype == torch.float16: atol, rtol = 3e-3, 3e-3 +<<<<<<< HEAD if op.name == "logcumsumexp": atol, rtol = 4e-3, 1e-3 if op.name == "nn.functional.max_pool3d" and dtype == torch.float16: @@ -12498,6 +12836,10 @@ def get_samples(): self.assertEqual(half_out, full_out.to(dtype), atol=atol, rtol=rtol) +======= + self.assertEqual(cpu_grad_inputs, mps_grad_inputs, atol=atol, rtol=rtol) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_fmax_mixed_dtypes(self, device): # Regression tesing for https://github.com/pytorch/pytorch/issues/149951 # fmax and fmin are implemented as binary metal shaders and they were implemented @@ -12610,7 +12952,11 @@ def test_numpy_ref_mps(self, device, dtype, op): def test_tensor_creation(self, device, dtype): def ones(device): return torch.ones((2, 2), dtype=dtype, device=device) +<<<<<<< HEAD if dtype not in MPS_DTYPES: +======= + if dtype not in MPS_DTYPES + ([torch.bfloat16] if MACOS_VERSION > 14.0 else []): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with self.assertRaises(TypeError): ones(device) else: @@ -12717,8 +13063,15 @@ def test_metal_include(self): lib = torch.mps.compile_shader("#include ") self.assertIsNotNone(lib) +<<<<<<< HEAD @parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16, torch.int32, torch.int64]) def test_reduction_utils(self, dtype): +======= + @parametrize("dtype", [torch.float32, torch.float16, torch.int32, torch.int64]) + def test_reduction_utils(self, dtype): + if dtype == torch.int64 and MACOS_VERSION < 13.3: + raise unittest.SkipTest("Using simd_shuffle_down_and_fill results in ICE on MacOS-13") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.codegen.mps import DTYPE_TO_METAL lib = torch.mps.compile_shader(f""" #include @@ -12727,6 +13080,7 @@ def test_reduction_utils(self, dtype): uint idx [[thread_position_in_grid]]) {{ out[idx] = c10::metal::simd_sum(inp[idx]); }} +<<<<<<< HEAD kernel void do_max(device {DTYPE_TO_METAL[dtype]}* out0, device int* out1, @@ -12765,6 +13119,21 @@ def test_reduction_utils(self, dtype): @parametrize("dtype", [torch.float32, torch.float16, torch.int32, torch.bfloat16]) def test_atomic_add(self, dtype): +======= + """) + x = torch.testing.make_tensor(28, device="mps", dtype=dtype) + y = torch.empty_like(x) + lib.do_sum(y, x) + x_sum = x.sum() + max_err = (y - x_sum).abs().max().item() + self.assertLess(max_err, 1e-2 if dtype == torch.float16 else 1e-5, + f"results are {y}, but all elements should have been {x_sum.item()}") + + @parametrize("dtype", [torch.float32, torch.float16, torch.int32, torch.bfloat16]) + def test_atomic_add(self, dtype): + if dtype == torch.bfloat16 and MACOS_VERSION < 14.0: + raise unittest.SkipTest("bfloat requires MacOS-14+") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.codegen.mps import DTYPE_TO_METAL mdtype = DTYPE_TO_METAL[dtype] lib = torch.mps.compile_shader(f""" diff --git a/test/test_multiprocessing_spawn.py b/test/test_multiprocessing_spawn.py index d093e01921dc1..962e56664516b 100644 --- a/test/test_multiprocessing_spawn.py +++ b/test/test_multiprocessing_spawn.py @@ -47,7 +47,11 @@ def _test_terminate_signal_func(i): def _test_terminate_exit_func(i, arg): if i == 0: sys.exit(arg) +<<<<<<< HEAD time.sleep(4.0) +======= + time.sleep(1.0) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _test_success_first_then_exception_func(i, arg): @@ -145,7 +149,11 @@ def test_terminate_signal(self): with self.assertRaisesRegex(Exception, message): mp.start_processes(_test_terminate_signal_func, nprocs=2, start_method=self.start_method) +<<<<<<< HEAD @parametrize("grace_period", [None, 20]) +======= + @parametrize("grace_period", [None, 5]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_terminate_exit(self, grace_period): exitcode = 123 ctx = mp.start_processes(_test_terminate_exit_func, args=(exitcode,), nprocs=2, start_method=self.start_method, join=False) @@ -201,7 +209,11 @@ def _test_nested(self): try: os.kill(pid, 0) except ProcessLookupError: +<<<<<<< HEAD pids.remove(pid) # noqa: B909 +======= + pids.remove(pid) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) break # This assert fails if any nested child process is still diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index ac97f2beda8e8..5cae5bc562c33 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -4444,18 +4444,24 @@ def test_jagged_op_different_output_shape_dim( @dtypes(torch.float32) @parametrize("requires_grad", [False, True]) @parametrize("components_require_grad", [False, True]) +<<<<<<< HEAD @parametrize( "func", [torch.nn.functional.softmax, torch.nn.functional.log_softmax], name_fn=lambda func: func.__name__, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_softmax_dim( self, device, dtype, requires_grad, components_require_grad, +<<<<<<< HEAD func, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): """ Softmax passes when reducing on valid reduction dimensions. @@ -4474,7 +4480,11 @@ def test_softmax_dim( for reduce_dim, _ in reduce_dims: nt = torch.nested.as_nested_tensor(ts, layout=torch.jagged) +<<<<<<< HEAD out_actual = func(nt, dim=reduce_dim) +======= + out_actual = torch.nn.functional.softmax(nt, dim=reduce_dim) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch._dynamo.disable(self.assertEqual)( len(out_actual.shape), len(output_shape) ) # disable if running on dynamo @@ -4504,10 +4514,19 @@ def test_softmax_dim( reduce_dim, reduce_dim_expected = reduce_dim_tuple if nt.dim() > reduce_dim: +<<<<<<< HEAD # nested tensor out_actual = func(nt, dim=reduce_dim) # dense tensor of dimensions 1 less than out_actual out_expected = func(nt.values(), dim=reduce_dim_expected) +======= + out_actual = torch.nn.functional.softmax( + nt, dim=reduce_dim + ) # nested tensor + out_expected = torch.nn.functional.softmax( + nt.values(), dim=reduce_dim_expected + ) # dense tensor of dimensions 1 less than out_actual +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue( torch.allclose(out_actual.values().view(-1), out_expected.view(-1)) ) @@ -4605,6 +4624,7 @@ def test_softmax_dim_reduce_ragged_idx_1( @dtypes(torch.float32) @parametrize("requires_grad", [False, True]) @parametrize("components_require_grad", [False, True]) +<<<<<<< HEAD @parametrize( "func", [torch.nn.functional.softmax, torch.nn.functional.log_softmax], @@ -4612,6 +4632,10 @@ def test_softmax_dim_reduce_ragged_idx_1( ) def test_softmax_reduce_batch_dim( self, device, dtype, requires_grad, components_require_grad, func +======= + def test_softmax_reduce_batch_dim( + self, device, dtype, requires_grad, components_require_grad +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): """ Softmax on NestedTensor fails when trying to reduce across batch dimension. @@ -4636,7 +4660,11 @@ def test_softmax_reduce_batch_dim( RuntimeError, "not supported when reducing across the batch dimension for NestedTensor", ): +<<<<<<< HEAD out = func(nt, dim=reduce_dim) +======= + out = torch.nn.functional.softmax(nt, dim=reduce_dim) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dtypes(torch.float32) @parametrize("requires_grad", [False, True]) @@ -5649,11 +5677,14 @@ def test_nested_tensor_from_jagged(self, device, dtype, pass_min_max): ): torch.nested.nested_tensor_from_jagged(values, offsets=None, lengths=None) +<<<<<<< HEAD with self.assertRaisesRegex(ValueError, "Expected jagged_dim >=1, but got 0."): torch.nested.nested_tensor_from_jagged( values, lengths=lengths, jagged_dim=0 ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyCPU def test_nested_tensor_from_jagged_fx_trace(self, device): def fn(x, y): @@ -6760,10 +6791,18 @@ def check_forward_backward(skip_backward=False): and check_cudnn and (dtype == torch.float16 or dtype == torch.bfloat16) ): +<<<<<<< HEAD with torch.nn.attention.sdpa_kernel( torch.nn.attention.SDPBackend.CUDNN_ATTENTION ): check_forward_backward() +======= + with self.assertRaisesRegex(RuntimeError, "cuDNN SDPA Nested Tensor"): + with torch.nn.attention.sdpa_kernel( + torch.nn.attention.SDPBackend.CUDNN_ATTENTION + ): + check_forward_backward() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skipIfTorchDynamo("SDPA test compiles internally") @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @@ -7203,7 +7242,11 @@ def forward(self, query, value, offsets): query = torch.rand(bs, d1, d3, device=device) value = torch.rand(30, d2, requires_grad=True, device=device) +<<<<<<< HEAD # total_length must > than max_length otherwise flash_attn backward will fail +======= + # total_length must > than max_length otherwise flash_attn backwark will fail +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) offsets = torch.tensor([0, 2, 3, 30], device=device) m = mha(use_legacy_api) @@ -7285,9 +7328,12 @@ def _rand_nt(noncontig_with_holes=noncontig_with_holes): return query, key, value +<<<<<<< HEAD @unittest.skip( "Temporarily skip - nested tensor backward pass broken after return-max-scores commit" ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyCUDA @flex_attention_supported_platform @dtypes(torch.float32) @@ -8099,7 +8145,10 @@ def f(values, offsets): "std.unbiased", "var", "var.unbiased", +<<<<<<< HEAD "hash_tensor", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, name="not_implemented", ), @@ -8547,6 +8596,17 @@ def f(values, offsets): COMPILE_FORWARD_SKIPS_AND_XFAILS = [ *FORWARD_SKIPS_AND_XFAILS, +<<<<<<< HEAD +======= + # Needs investigation in AOTAutograd: len(unwrapped_args) == num_args_tallied assertion fails + # e.g. Expected 5 == 4 + XFailRule( + error_type=AssertionError, + op_match_fn=lambda device, op: (op.full_name == "fill"), + sample_match_fn=lambda device, sample: ("noncontig_transposed" in sample.name), + name="fill_aot_autograd_bug_with_transposed_input", + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Bug: cross-device conversions with to() result in new nested ints within compile only XFailRule( error_type=AssertionError, @@ -8584,6 +8644,21 @@ def f(values, offsets): sample_match_fn=lambda device, sample: ("batch_dim" in sample.name), name="broken_select_backward_unbacked", ), +<<<<<<< HEAD +======= + # Bug: no idea what's going on here; needs investigation within AOTAutograd + XFailRule( + op_match_fn=lambda device, op: (op.full_name == "nan_to_num"), + sample_match_fn=lambda device, sample: ("noncontig_transposed" in sample.name), + name="crazy_aot_autograd_bug1", + ), + # Bug: also no idea what's going on here: needs investigation within AOTAutograd + XFailRule( + op_match_fn=lambda device, op: (op.full_name == "isreal"), + sample_match_fn=lambda device, sample: ("noncontig_transposed" in sample.name), + name="crazy_aot_autograd_bug2", + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] COMPILE_BACKWARD_SKIPS_AND_XFAILS = [ diff --git a/test/test_nn.py b/test/test_nn.py index 0c84d6ffe129e..a62c11d66295d 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -31,8 +31,12 @@ from torch.nn import Buffer, Parameter from torch.nn.parallel._functions import Broadcast from torch.testing._internal.common_dtype import integral_types, get_all_math_dtypes, floating_types +<<<<<<< HEAD from torch.testing._internal.common_utils import dtype_name, freeze_rng_state, run_tests, TestCase, \ skipIfNoLapack, skipIfRocm, \ +======= +from torch.testing._internal.common_utils import dtype_name, freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TEST_NUMPY, TEST_SCIPY, TEST_WITH_CROSSREF, TEST_WITH_ROCM, \ download_file, get_function_arglist, load_tests, skipIfMPS, \ IS_PPC, \ @@ -56,12 +60,20 @@ from torch.testing._internal.common_utils import dtype2prec_DONTUSE from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_off, tf32_on from torch.types import _TensorOrTensors +<<<<<<< HEAD from torch.testing._internal.common_mkldnn import reduced_f32_on_and_off +======= +from torch.testing._internal.common_mkldnn import bf32_on_and_off +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AMPERE_OR_ROCM = TEST_WITH_ROCM or torch.cuda.is_tf32_supported() if TEST_WITH_ROCM: os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1" +<<<<<<< HEAD +======= + os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM"] = "1" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # load_tests from common_utils is used to automatically filter tests for # sharding on sandcastle. This line silences flake warnings @@ -583,6 +595,7 @@ def test_register_buffer_allows_overwriting_with_same_name(self): m.buffer_name = Buffer(buffer3) self.assertEqual(m.buffer_name, Buffer(buffer3)) +<<<<<<< HEAD def test_register_buffer_allows_tensor_like_object(self): class TensorLike: @classmethod @@ -599,6 +612,8 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): self.assertEqual(m.buffer_name, buffer2) self.assertEqual(m.get_buffer('buffer_name'), buffer2) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_get_buffer(self): m = nn.Module() buffer1 = torch.randn(2, 3) @@ -2033,7 +2048,11 @@ def fn(input): eval_out0 = wrapped_m(input) # assert eval gives same result as last training iteration self.assertEqual(eval_out0, last_train_out) +<<<<<<< HEAD # assert doing more iteration in eval don't change things +======= + # assert doing more iteartion in eval don't change things +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(eval_out0, wrapped_m(input)) self.assertEqual(last_train_u, m.weight_u) self.assertEqual(last_train_v, m.weight_v) @@ -3513,7 +3532,10 @@ def test_cudnn_forward_exception(self): self.assertRaisesRegex(RuntimeError, re.escape("input.size(-1) must be equal to input_size"), rnn, x_wrong) @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available') +<<<<<<< HEAD @skipIfRocm +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_cudnn_weight_format(self): rnns = [ nn.LSTM(10, 20, batch_first=True), @@ -3521,7 +3543,12 @@ def test_cudnn_weight_format(self): nn.GRU(10, 20, batch_first=True), nn.RNN(10, 20, batch_first=True) ] +<<<<<<< HEAD first_warn = True +======= + # ROCm RNN does not issue warning about single contig chunk of memory, so don't assert it + first_warn = False if torch.version.hip else True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for rnn in rnns: rnn.cuda() input = torch.randn(5, 4, 10, requires_grad=True, device="cuda") @@ -5147,6 +5174,20 @@ def test_batchnorm_buffer_update_when_stats_are_not_tracked(self): self.assertTrue(torch.equal(running_mean, bn.running_mean)) self.assertTrue(torch.equal(running_var, bn.running_var)) +<<<<<<< HEAD +======= + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_batchnorm_nhwc_cuda(self): + for dtype in (torch.half, torch.float): + (N, C, H, W) = 2, 64, 50, 50 + model = torch.nn.BatchNorm2d(C, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + model = model.eval().cuda().to(dtype) + inp1 = torch.randn(N, C, H, W, device=torch.device('cuda'), dtype=dtype) + inp2 = inp1.contiguous(memory_format=torch.channels_last) + out1 = model(inp1) + out2 = model(inp2) + self.assertEqual(out1, out2) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") @parametrize_test("dims", [2, 3], name_fn=lambda x: f"{x}D") @@ -5170,10 +5211,26 @@ def test_batchnorm_buffer_update_when_stats_are_not_tracked(self): ("NCHW", "native", False, torch.float), ("NCHW", "native", True, torch.half), ("NCHW", "native", True, torch.bfloat16), +<<<<<<< HEAD +======= + + ("NHWC", "cpu", False, torch.float), + ("NHWC", "cpu", True, torch.half), + ("NHWC", "cpu", True, torch.bfloat16), + + ("NHWC", "native", False, torch.float), + ("NHWC", "native", True, torch.half), + ("NHWC", "native", True, torch.bfloat16), + + ("NHWC", "NCHW", False, torch.float), + ("NHWC", "NCHW", True, torch.half), + ("NHWC", "NCHW", True, torch.bfloat16), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ], name_fn=lambda f, b, m, t: f"{f}_vs_{b}{'_mixed' if m else ''}_{dtype_name(t)}" ) def test_batchnorm(self, dims, mode, memory_format, ref_backend, mixed, dtype): +<<<<<<< HEAD if torch.version.cuda: if self._testMethodName in ("test_batchnorm_2D_train_NCHW_vs_cpu_mixed_bfloat16", "test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16", @@ -5201,6 +5258,24 @@ def test_batchnorm(self, dims, mode, memory_format, ref_backend, mixed, dtype): if self._testMethodName == "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16": self.skipTest("3D float16 NCHW train failed on ROCm") +======= + if self._testMethodName == "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16": + self.skipTest("3D float16 NCHW train failed on CUDA and ROCm due to Native batchnorm accuracy issue SWDEV-541024") + if torch.version.hip: + if self._testMethodName in ("test_batchnorm_2D_train_NHWC_vs_NCHW_mixed_bfloat16", + "test_batchnorm_2D_train_NCHW_vs_cpu_mixed_bfloat16", + "test_batchnorm_3D_train_NHWC_vs_NCHW_mixed_bfloat16", + "test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16" + ) and _get_torch_rocm_version() < (6, 4): + # NCHW bfloat16 path uses native kernels for rocm<=6.3 + # train failed on rocm<=6.3 due to native tolerance issue SWDEV-507600 + self.skipTest("bfloat16 NHWC train failed on ROCm <= 6.3") + + if self._testMethodName in ("test_batchnorm_2D_train_NCHW_vs_native_mixed_bfloat16", + "test_batchnorm_3D_train_NCHW_vs_native_mixed_bfloat16" + ) and _get_torch_rocm_version() >= (6, 4): + self.skipTest("bfloat16 NCHW train failed due to native tolerance issue SWDEV-507600") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if dims == 3 and memory_format in ("NHWC", "NCHW"): memory_format = memory_format + "3D" @@ -5252,10 +5327,18 @@ def _get_memory_format_from_name(memory_format_name: str) -> torch.memory_format return ValueError("Unsupported memory_format") def _create_backend(inp: torch.Tensor, mixed: bool = False): +<<<<<<< HEAD if inp.dim() == 4: return nn.BatchNorm2d(inp.size(1), device=inp.device, dtype=torch.float if mixed else inp.dtype) else: return nn.BatchNorm3d(inp.size(1), device=inp.device, dtype=torch.float if mixed else inp.dtype) +======= + + mod = nn.BatchNorm2d(inp.size(1), device=inp.device, dtype=torch.float if mixed else inp.dtype) \ + if inp.dim() == 4 else \ + nn.BatchNorm3d(inp.size(1), device=inp.device, dtype=torch.float if mixed else inp.dtype) + return mod +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _test_batchnorm_train(inp, grad, mixed, ref_inp, ref_grad, ref_backend): mod = _create_backend(inp, mixed).train() @@ -5296,6 +5379,28 @@ def _train(memory_format_name, ref_backend, mixed, dtype): _test_batchnorm_train(inp=inp, grad=grad, mixed=mixed, ref_inp=ref_inp, ref_grad=ref_grad, ref_backend=ref_backend) +<<<<<<< HEAD +======= + # TODO: enable permute logic later + # size = (2, 8, 8, 1) + # input = _create_tensor(size, memory_format, dtype, device="cuda").detach().requires_grad_() + # grad = _create_tensor(size, memory_format=torch.contiguous_format, dtype=dtype, device="cuda") + # # grad = _create_tensor(size, memory_format=memory_format, dtype=dtype, device="cuda") + + # ref_input = input.detach().clone(memory_format=ref_memory_format).to(device=ref_device).requires_grad_(True) + # ref_grad = grad.detach().clone(memory_format=torch.contiguous_format).to(device=ref_device) + # # ref_grad = grad.detach().clone(memory_format=ref_memory_format).to(device=ref_device) + + # if memory_format == torch.channels_last: + # grad = grad.permute(0, 2, 1, 3) + # # grad = grad.permute(0, 2, 3, 1) + # if ref_memory_format == torch.channels_last: + # ref_grad = ref_grad.permute(0, 2, 1, 3) + # # ef_grad = ref_grad.permute(0, 2, 3, 1) + # _test_batchnorm_train(input=input, grad=grad, mixed=mixed, + # ref_input=ref_input, ref_grad=ref_grad, ref_backend=ref_backend) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _inference(memory_format_name, ref_backend, mixed, dtype): memory_format = _get_memory_format_from_name(memory_format_name) ref_memory_format = _get_backend_memory_format(ref_backend, memory_format) @@ -5312,6 +5417,7 @@ def _inference(memory_format_name, ref_backend, mixed, dtype): ref_out = ref_mod(ref_inp) self.assertEqual(out, ref_out) +<<<<<<< HEAD if mode == "train": _train(memory_format, ref_backend, mixed, dtype) else: @@ -5328,6 +5434,22 @@ def test_batchnorm_nhwc_cuda(self): out1 = model(inp1) out2 = model(inp2) self.assertTrue(torch.equal(out1, out2)) +======= + # TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM once ROCm officially supports NHWC in MIOpen + PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM = "PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM" + prev_val = os.getenv(PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM) + try: + os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM] = "1" + if mode == "train": + _train(memory_format, ref_backend, mixed, dtype) + else: + _inference(memory_format, ref_backend, mixed, dtype) + finally: + if prev_val is None: + del os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM] + else: + os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM] = prev_val +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_batchnorm_load_state_dict(self): bn = torch.nn.BatchNorm2d(3) @@ -7435,7 +7557,10 @@ def test_layer_norm_backwards_eps(self): if bias and elementwise_affine: self.assertEqual(ln.bias.grad, ln_cuda.bias.grad, f"bias grad failed: {m=} {n=}", rtol=rtol, atol=atol) +<<<<<<< HEAD @unittest.skipIf(not TEST_CUDA, "CUDA not available") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @largeTensorTest("40GB", device="cuda") def test_layer_norm_large_tensor(self): # test for https://github.com/pytorch/pytorch/issues/136291 @@ -8299,7 +8424,11 @@ def _test_module_empty_inputs(self, module, inputs): "Scipy v1.0 and/or numpy not found") @expectedFailureMPS # Unsupported Border padding mode https://github.com/pytorch/pytorch/issues/125098 @tf32_on_and_off() +<<<<<<< HEAD @reduced_f32_on_and_off() +======= + @bf32_on_and_off() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_affine_2d_rotate0(self, device): # scipy before 1.0.0 do not support homogeneous coordinate # scipy.ndimage.affine_transform, so we need to skip. @@ -8340,7 +8469,11 @@ def test_affine_2d_rotate0(self, device): "Scipy v1.0 and/or numpy not found") @expectedFailureMPS # Unsupported Border padding mode https://github.com/pytorch/pytorch/issues/125098 @tf32_on_and_off(0.01 if TEST_WITH_ROCM else 0.001) +<<<<<<< HEAD @reduced_f32_on_and_off(0.001) +======= + @bf32_on_and_off(0.001) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_affine_2d_rotate90(self, device): # scipy before 1.0.0 do not support homogeneous coordinate # scipy.ndimage.affine_transform, so we need to skip. @@ -8390,7 +8523,11 @@ def test_affine_2d_rotate90(self, device): "Scipy v1.0 and/or numpy not found") @expectedFailureMPS # Unsupported Border padding mode https://github.com/pytorch/pytorch/issues/125098 @tf32_on_and_off(0.005) +<<<<<<< HEAD @reduced_f32_on_and_off(0.005) +======= + @bf32_on_and_off(0.005) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_affine_2d_rotate45(self, device): # scipy before 1.0.0 do not support homogeneous coordinate # scipy.ndimage.affine_transform, so we need to skip. @@ -8468,7 +8605,11 @@ def test_avg_pool_large_tensor2(self, device): "Scipy v1.0 and/or numpy not found") @expectedFailureMPS # Unsupported Border padding mode https://github.com/pytorch/pytorch/issues/125098 @tf32_on_and_off(0.05 if TEST_WITH_ROCM else 0.005) +<<<<<<< HEAD @reduced_f32_on_and_off(0.005) +======= + @bf32_on_and_off(0.005) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_affine_2d_rotateRandom(self, device): # scipy before 1.0.0 do not support homogeneous coordinate # scipy.ndimage.affine_transform, so we need to skip. @@ -8519,8 +8660,14 @@ def test_affine_2d_rotateRandom(self, device): @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'), "Scipy v1.0 and/or numpy not found") +<<<<<<< HEAD @tf32_on_and_off(0.05 if TEST_WITH_ROCM else 0.005) @reduced_f32_on_and_off(0.005) +======= + @expectedFailureMPS # aten::grid_sampler_3d not implemented https://github.com/pytorch/pytorch/issues/77764 + @tf32_on_and_off(0.05 if TEST_WITH_ROCM else 0.005) + @bf32_on_and_off(0.005) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_affine_3d_rotateRandom(self, device): # scipy before 1.0.0 do not support homogeneous coordinate # scipy.ndimage.affine_transform, so we need to skip. @@ -8573,7 +8720,10 @@ def test_affine_3d_rotateRandom(self, device): self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary)) +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyCUDA @dtypes(torch.float, torch.half) def test_batchnorm_large_batch(self, device, dtype): @@ -8768,6 +8918,7 @@ def rms_norm_reference_fn(i, normalized_shape, weight, eps=None): self.assertEqual(Y_ref, Y) +<<<<<<< HEAD @onlyNativeDeviceTypes @dtypes(torch.float16, torch.bfloat16, torch.float32, torch.float64) @dtypesIfMPS(torch.float16, torch.bfloat16, torch.float32) @@ -8793,6 +8944,8 @@ def rms_norm_reference_fn(i, normalized_shape): self.assertEqual(Y_ref, Y) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyCPU def test_glu_bfloat16(self, device): def test_dtype(fn, input, dtype): @@ -8945,7 +9098,10 @@ def group_norm_ref(X, gamma, beta, groups, channels, eps): Y_cpu = group_norm(X.cpu()) self.assertEqual(Y_cpu, Y, rtol=0, atol=1e-5) +<<<<<<< HEAD @expectedFailureMPS # Double is not supported on MPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyNativeDeviceTypes @dtypes(torch.float64, torch.complex128) def test_pad(self, device, dtype): @@ -8957,7 +9113,11 @@ def test_pad(self, device, dtype): # Should raise error when negative padding results in negative output shape self.assertRaises(RuntimeError, lambda: F.pad(inputs, (-3, -2), mode='circular')) +<<<<<<< HEAD # assert that reflection padding errors when pad >= input size +======= + # assert that relfection padding errors when pad >= input size +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) expected_err_msg = r"Padding size should be less than the corresponding input dimension" inputs = torch.randn(1, 1, 2, 3, device=device, dtype=dtype) self.assertRaisesRegex(RuntimeError, expected_err_msg, @@ -8977,7 +9137,10 @@ def test_pad(self, device, dtype): out.fill_(4) self.assertTrue(torch.all(torch.abs(inputs) < 2)) +<<<<<<< HEAD @expectedFailureMPS # Unsupported float64/complex128 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyNativeDeviceTypes @dtypes(torch.float64, torch.complex128) def test_ReplicationPad_empty(self, device, dtype): @@ -9116,7 +9279,10 @@ def test_Bilinear_empty(self, device): self.assertEqual(inp1.grad, torch.zeros_like(inp1)) self.assertEqual(inp2.grad, torch.zeros_like(inp2)) +<<<<<<< HEAD @expectedFailureMPS # Double not supported +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @expectedFailureMeta # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1] @onlyNativeDeviceTypes def test_TransformerEncoderLayer_empty(self, device): @@ -9146,7 +9312,10 @@ def test_TransformerEncoderLayer_empty(self, device): _test_module_empty_input(self, encoder_layer, input, check_size=False) @expectedFailureMeta # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1] +<<<<<<< HEAD @expectedFailureMPS # Float64 is not supported +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyNativeDeviceTypes def test_TransformerEncoder_empty(self, device): for batch_first, input_shape in [(True, (0, 10, 512)), @@ -9157,7 +9326,10 @@ def test_TransformerEncoder_empty(self, device): _test_module_empty_input(self, transformer_encoder, input, check_size=False) @expectedFailureMeta # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1] +<<<<<<< HEAD @expectedFailureMPS # Float64 is not supported +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyNativeDeviceTypes def test_TransformerDecoderLayer_empty(self, device): for batch_first, memory_shape, tgt_shape in [(True, (0, 10, 512), (0, 20, 512)), @@ -9168,7 +9340,10 @@ def test_TransformerDecoderLayer_empty(self, device): self._test_module_empty_inputs(decoder_layer, [tgt, memory]) @expectedFailureMeta # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1] +<<<<<<< HEAD @expectedFailureMPS # Float64 is not supported +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyNativeDeviceTypes def test_TransformerDecoder_empty(self, device): for batch_first, memory_shape, tgt_shape in [(True, (0, 10, 512), (0, 20, 512)), @@ -9180,7 +9355,10 @@ def test_TransformerDecoder_empty(self, device): self._test_module_empty_inputs(transformer_decoder, [tgt, memory]) @expectedFailureMeta # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1] +<<<<<<< HEAD @expectedFailureMPS # Float64 is not supported +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyNativeDeviceTypes def test_Transformer_empty(self, device): for batch_first, src_shape, tgt_shape in [(True, (10, 0, 512), (20, 0, 512))]: @@ -9215,7 +9393,11 @@ def test_ReflectionPad_empty(self, device, dtype): @onlyNativeDeviceTypes def test_ReflectionPad_fails(self, device): +<<<<<<< HEAD with self.assertRaisesRegex(RuntimeError, r'Padding size 2 is not supported for 4D input tensor'): +======= + with self.assertRaisesRegex(RuntimeError, 'Only 2D, 3D, 4D, 5D'): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mod = torch.nn.ReflectionPad1d(2) inp = torch.randn(3, 3, 10, 10, device=device) mod(inp) @@ -9224,7 +9406,11 @@ def test_ReflectionPad_fails(self, device): inp = torch.randn(3, 3, 10, 10, device=device) torch.ops.aten.reflection_pad1d(inp, (2, 2)) +<<<<<<< HEAD with self.assertRaisesRegex(RuntimeError, r'Padding size 4 is not supported for 5D input tensor'): +======= + with self.assertRaisesRegex(RuntimeError, 'Only 2D, 3D, 4D, 5D'): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mod = torch.nn.ReflectionPad2d(2) inp = torch.randn(3, 3, 10, 10, 10, device=device) mod(inp) @@ -9233,7 +9419,11 @@ def test_ReflectionPad_fails(self, device): inp = torch.randn(3, 3, 10, 10, 10, device=device) torch.ops.aten.reflection_pad2d(inp, (2, 2, 2, 2)) +<<<<<<< HEAD with self.assertRaisesRegex(RuntimeError, r'Padding size 6 is not supported for 6D input tensor'): +======= + with self.assertRaisesRegex(RuntimeError, 'Only 2D, 3D, 4D, 5D'): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mod = torch.nn.ReflectionPad3d(3) inp = torch.randn(3, 3, 10, 10, 10, 10, device=device) mod(inp) @@ -9316,7 +9506,10 @@ def test_ReflectionPad3d_large(self, device): self.assertEqual(x.grad, ref_x.grad) +<<<<<<< HEAD @expectedFailureMPS # Unimplemented margin_loss +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyNativeDeviceTypes @dtypes(torch.float, torch.double) def test_MarginLoss_empty(self, device, dtype): @@ -9345,6 +9538,7 @@ def test_MarginLoss_empty(self, device, dtype): mod(x, y) @onlyCUDA +<<<<<<< HEAD @dtypes(torch.float, torch.double) def test_MarginLoss_race(self, device, dtype): loss = torch.nn.MultiMarginLoss().to(device) @@ -9364,6 +9558,8 @@ def test_MarginLoss_race(self, device, dtype): self.assertEqual(x_cpu.grad, x.grad.cpu()) @onlyCUDA +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_MarginLoss_warnings(self, device): model = torch.nn.Linear(128, 22, device=device) loss = torch.nn.MultiMarginLoss() @@ -9376,6 +9572,7 @@ def test_MarginLoss_warnings(self, device): l.backward() self.assertTrue(len(f.getvalue()) == 0) +<<<<<<< HEAD @onlyCUDA def test_mse_loss_error(self, device): i = torch.randn((10, 1), device=device) @@ -9384,6 +9581,8 @@ def test_mse_loss_error(self, device): F.mse_loss(i, t) @expectedFailureMPS # TODO: Fixme, and raise assert on empty tensor +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyNativeDeviceTypes def test_Unfold_empty(self, device): inp = torch.randn(0, 3, 3, 4, device=device) @@ -9607,7 +9806,10 @@ def verify_reduction_scalars(input, reduction, output): verify_reduction_scalars(input, reduction, output) # verify that bogus reduction strings are errors +<<<<<<< HEAD @expectedFailureMPS # CTCLoss unimplemented +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyNativeDeviceTypes def test_invalid_reduction_strings(self, device): input = torch.randn(3, 5, requires_grad=True, device=device) @@ -10094,7 +10296,10 @@ def test_upsamplingNearestExact3d_correctness(self, device, memory_format, isize @parametrize_test("align_corners", [True, False]) @parametrize_test("mode", ["bilinear", "bicubic"]) @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last]) +<<<<<<< HEAD @expectedFailureMPS # double device type +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyNativeDeviceTypes def test_upsamplingBiMode2d(self, device, antialias, align_corners, mode, memory_format): # Forward AD does not support XLA because XLA tensors don't have storage @@ -10164,7 +10369,10 @@ def test_upsamplingBiMode2d(self, device, antialias, align_corners, mode, memory @parametrize_test("num_channels", [3, 5]) @parametrize_test("mode", ["nearest", "nearest-exact", "bilinear", "bicubic"]) @parametrize_test("dtype", integral_types() + floating_types()) +<<<<<<< HEAD @skipIfMPS # Error message is wrong for some dtypes +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyNativeDeviceTypes def test_upsamplingBiMode2d_nonsupported_dtypes(self, device, antialias, num_channels, mode, dtype): x = torch.ones(1, num_channels, 32, 32, dtype=dtype, device=device) @@ -11094,7 +11302,11 @@ def test_rnn_retain_variables(self, device, dtype): @onlyCUDA @dtypes(torch.double) def test_lstmcell_backward_only_one_output_grad(self, device, dtype): +<<<<<<< HEAD # checks that undefined gradients doesn't hamper the backward +======= + # checks that undefined gradients doen't hamper the backward +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # see #11872 l = torch.nn.LSTMCell(2, 3).to(device).to(dtype=dtype) s = torch.randn(1, 2, device=device, dtype=dtype, requires_grad=True) @@ -11487,7 +11699,10 @@ def test_hardsigmoid_grad(self, device): self.assertTrue(gradcheck(F.hardsigmoid, (inputs,))) # currently fails on XLA +<<<<<<< HEAD @expectedFailureMPS # TypeError: the MPS framework doesn't support float64 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyNativeDeviceTypes def test_hardswish_grad(self, device): inputs = (torch.randn(4, 16, 16, device=device, dtype=torch.double) - 0.5) * 10 @@ -11695,7 +11910,10 @@ def test_batchnorm_simple_average_mixed(self, device, dtype): self._test_batchnorm_simple_average(device, dtype, torch.float) @onlyNativeDeviceTypes +<<<<<<< HEAD @expectedFailureMPS # Unsupported Border padding mode +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dtypes(torch.float, torch.double) def test_grid_sample_nan_inf(self, device, dtype): input = torch.zeros([1, 1, 3, 3], device=device, dtype=dtype) @@ -12045,6 +12263,7 @@ def test_activations_bfloat16(self, device): def test_softmax_bfloat16(self, device): for dim in [0, 1, 2, 3]: _test_bfloat16_ops(self, torch.nn.Softmax(dim=dim), device, inp_dims=(16, 33, 15, 16), prec=1e-2) +<<<<<<< HEAD # test softmax with large input value which causes exp() to overflow _test_bfloat16_ops(self, torch.nn.Softmax(dim=dim), device, inp_dims=(16, 33, 15, 16), prec=0.05, scale_factor=1000.0) @@ -12054,6 +12273,11 @@ def test_nll_loss_1d_input_1d_target_invalid_size(self, device): with self.assertRaisesRegex(ValueError, "For 1D input, 1D target must have size 1"): F.nll_loss(x, t) +======= + # test softmax with large input value which casues exp() to overflow + _test_bfloat16_ops(self, torch.nn.Softmax(dim=dim), device, inp_dims=(16, 33, 15, 16), prec=0.05, scale_factor=1000.0) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_nll_loss_mismatched_batch(self, device): x = torch.randn((10, 3), requires_grad=True, device=device) # t should have size (10,) @@ -12382,7 +12606,11 @@ def test_cross_entropy_label_smoothing_consistent_index_target_and_probs(self, d input = torch.randn(N, C, *other_dims, device=device, requires_grad=True) target = torch.empty(N, *other_dims, dtype=torch.long, device=device).random_(0, C) +<<<<<<< HEAD # construct target probability that should have the same result as label_smoothing +======= + # construct target probablity that should have the same result as label_smoothing +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) target_proba = F.one_hot(target, num_classes=C) # Need to put the C dim at index 1. target_proba = target_proba.permute(0, -1, *range(1, target_proba.dim() - 1)) @@ -12814,7 +13042,10 @@ def test_threshold_inplace_overlap(self, device): F.threshold(x, 0.5, 0.5, inplace=True) F.threshold_(x, 0.5, 0.5) +<<<<<<< HEAD @expectedFailureMPS # Double is unsupported +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyNativeDeviceTypes def test_triplet_margin_with_distance_loss_default_parity(self, device): # Test for `nn.TripletMarginWithDistanceLoss` and @@ -12849,7 +13080,10 @@ def test_triplet_margin_with_distance_loss_default_parity(self, device): self.assertTrue(gradcheck(lambda a, p, n: loss_op(a, p, n), (anchor, positive, negative))) +<<<<<<< HEAD @expectedFailureMPS # Double is unsupported +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyNativeDeviceTypes def test_triplet_margin_with_distance_loss(self, device): # Test for parity between `nn.TripletMarginWithDistanceLoss` and diff --git a/test/test_numpy_interop.py b/test/test_numpy_interop.py index 286882dfdb370..f0c86523bf68a 100644 --- a/test/test_numpy_interop.py +++ b/test/test_numpy_interop.py @@ -488,7 +488,11 @@ def test_parse_numpy_int_overflow(self, device): ) # type: ignore[call-overload] else: self.assertRaisesRegex( +<<<<<<< HEAD ValueError, +======= + RuntimeError, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "(Overflow|an integer is required)", lambda: torch.mean(torch.randn(1, 1), np.uint64(-1)), ) # type: ignore[call-overload] @@ -639,6 +643,7 @@ def test_empty_tensors_interop(self, device): # Regression test for https://github.com/pytorch/pytorch/issues/113037 self.assertEqual(torch.div(x, y, rounding_mode="floor").shape, y.shape) +<<<<<<< HEAD def test_ndarray_astype_object_graph_break(self): @torch.compile(backend="eager", fullgraph=True) def f(xs): @@ -661,6 +666,8 @@ def f(xs): ): f(xs) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) instantiate_device_type_tests(TestNumPyInterop, globals()) diff --git a/test/test_openreg.py b/test/test_openreg.py index 7ee8ccefcd093..b1b022d98cf3c 100644 --- a/test/test_openreg.py +++ b/test/test_openreg.py @@ -1,11 +1,15 @@ # Owner(s): ["module: PrivateUse1"] +<<<<<<< HEAD import _codecs import io +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import os import tempfile import types import unittest +<<<<<<< HEAD from unittest.mock import patch import numpy as np @@ -21,6 +25,18 @@ skipIfWindows, skipIfXpu, TemporaryFileName, +======= + +import psutil +import pytorch_openreg # noqa: F401 + +import torch +from torch.testing._internal.common_utils import ( + IS_LINUX, + run_tests, + skipIfTorchDynamo, + skipIfXpu, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TestCase, ) @@ -206,7 +222,11 @@ def test_backend_packed_sequence_methods(self): class TestOpenReg(TestCase): +<<<<<<< HEAD """Tests of mimic accelerator named OpenReg based on PrivateUse1""" +======= + """Tests of mimick accelerator named OpenReg based on PrivateUse1""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Stream & Event def test_stream_synchronize(self): @@ -286,8 +306,12 @@ def test_manual_seed(self): self.assertEqual(torch.openreg.initial_seed(), 2024) # type: ignore[misc] # Autograd +<<<<<<< HEAD @skipIfMPS @skipIfWindows() +======= + @unittest.skipIf(not IS_LINUX, "Only works on linux") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_autograd_init(self): # Make sure autograd is initialized torch.ones(2, requires_grad=True, device="openreg").sum().backward() @@ -341,6 +365,13 @@ def test_rewrapped_storage(self): self.assertNotEqual(pinned_a.data_ptr(), rewrapped_a.data_ptr()) # Serialization +<<<<<<< HEAD +======= + @unittest.skip( + "Temporarily disable due to the tiny differences between clang++ and g++ in defining static variable in inline function," + "this pr can fix this, https://github.com/pytorch/pytorch/pull/147095" + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_serialization(self): storage = torch.UntypedStorage(4, device=torch.device("openreg")) self.assertEqual(torch.serialization.location_tag(storage), "openreg:0") @@ -372,6 +403,7 @@ def test_serialization(self): self.assertFalse(tensor_cpu.is_openreg) self.assertEqual(torch._utils.get_tensor_metadata(tensor_cpu), {}) # type: ignore[misc] +<<<<<<< HEAD @skipIfTorchDynamo() @unittest.skipIf( np.__version__ < "1.25", @@ -478,6 +510,9 @@ def test_open_device_cpu_serialization(self): torch.save(sd, f) # Operators +======= + # Opeartors +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_factory(self): x = torch.empty(3, device="openreg") self.assertEqual(x.device.type, "openreg") @@ -491,6 +526,7 @@ def test_factory(self): self.assertEqual(z.device.type, "openreg") self.assertEqual(z.shape, torch.Size([0])) +<<<<<<< HEAD def test_fake_tensor(self): with torch._subclasses.fake_tensor.FakeTensorMode(): a = torch.empty(1, device="openreg") @@ -500,6 +536,8 @@ def test_fake_tensor(self): def test_named_tensor(self): return torch.empty([2, 3, 4, 5], device="openreg", names=["N", "C", "H", "W"]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_printing(self): a = torch.ones(20, device="openreg") # Does not crash! @@ -542,6 +580,7 @@ def test_quantize(self): self.assertEqual(quantized_tensor.device, torch.device("openreg:0")) self.assertEqual(quantized_tensor.dtype, torch.qint8) +<<<<<<< HEAD # custom autograd def test_compile_autograd_function_returns_self(self): in_ref = torch.randn(4, device="openreg", requires_grad=True) @@ -625,6 +664,8 @@ def test_tensorlist_type_fallback(self): self.assertEqual(z_cpu, z[0]) self.assertEqual(z_cpu, z[1]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/test_ops.py b/test/test_ops.py index 2d5af9966690f..dfb32176b4451 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -87,7 +87,11 @@ # Get names of all the operators which have ref in their entry in OpInfo (testing infra) # except for elementwise unary operators (separately implemented in test/test_unary_ufuncs.py), # elementwise binary operators (separately implemented in test_binary_ufuncs.py), +<<<<<<< HEAD # reduction operations (separately implemented in test_reductions.py), +======= +# reduction operations (separately impelemented in test_reductions.py), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # and Spectral Functions (separately implemented for only 1D as of now, in test/test_spectral_ops.py) _ref_test_ops = tuple( filter( @@ -118,17 +122,30 @@ def reduction_dtype_filter(op): aten = torch.ops.aten meta_consistency_out_dtype_mismatch_xfails = { +<<<<<<< HEAD +======= + xfail("alias_copy"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) xfail("all"), xfail("amax"), xfail("amin"), xfail("aminmax"), xfail("any"), +<<<<<<< HEAD +======= + xfail("as_strided_copy"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) xfail("bucketize"), xfail("conj_physical"), xfail("cross"), xfail("cummax"), xfail("cummin"), xfail("diag"), +<<<<<<< HEAD +======= + xfail("diagonal_copy"), + xfail("expand_copy"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) xfail("fft.ihfft2"), xfail("fft.ihfftn"), xfail("frexp"), @@ -163,6 +180,11 @@ def reduction_dtype_filter(op): xfail("msort"), xfail("multinomial"), xfail("nan_to_num"), +<<<<<<< HEAD +======= + xfail("nanmean"), + xfail("narrow_copy"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) xfail("native_batch_norm"), xfail("neg"), xfail("nn.functional.avg_pool3d"), @@ -172,6 +194,10 @@ def reduction_dtype_filter(op): xfail("nn.functional.softplus"), xfail("nn.functional.softshrink"), xfail("ormqr"), +<<<<<<< HEAD +======= + xfail("permute_copy"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) xfail("qr"), xfail("renorm"), xfail("round"), @@ -186,10 +212,22 @@ def reduction_dtype_filter(op): xfail("softmax"), xfail("sort"), xfail("sparse.sampled_addmm"), +<<<<<<< HEAD xfail("take"), xfail("tril"), xfail("triu"), xfail("unfold_copy"), +======= + xfail("squeeze_copy"), + xfail("t_copy"), + xfail("take"), + xfail("transpose_copy"), + xfail("tril"), + xfail("triu"), + xfail("unfold_copy"), + xfail("unsqueeze_copy"), + xfail("view_copy"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) xfail("where"), # Output has dynamic shape. # Does not have a meta kernel implementation. @@ -373,7 +411,11 @@ def to_cpu(arg): # output_process_fn_grad has a very unfortunate name # We use this function in linalg extensively to postprocess the inputs of functions +<<<<<<< HEAD # that are not completely well-defined. Think svd and multiplying the singular vectors by -1. +======= + # that are not completely well-defined. Think svd and muliplying the singular vectors by -1. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # CPU and CUDA implementations of the SVD can return valid SVDs that are different. # We use this function to compare them. cuda_results = sample.output_process_fn_grad(cuda_results) @@ -580,7 +622,11 @@ def _distance(a, b): # Tests that experimental Python References perform the same computation # as the operators they reference, when operator calls in the torch +<<<<<<< HEAD # namespace are remapped to the refs namespace (torch.foo becomes refs.foo). +======= + # namesapce are remapped to the refs namespace (torch.foo becomes refs.foo). +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyNativeDeviceTypesAnd(["hpu"]) @ops(python_ref_db) @skipIfTorchInductor("Takes too long for inductor") @@ -759,7 +805,11 @@ def test_noncontiguous_samples(self, device, dtype, op): else tuple(n_inp) + n_args ) +<<<<<<< HEAD # Filter the elements that are tensors that require grad +======= + # Filter the elemnts that are tensors that require grad +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) t_input_tensors = [ t for t in t_inputs if isinstance(t, torch.Tensor) and t.requires_grad ] @@ -1109,6 +1159,7 @@ def _case_four_transform(t): if op.is_factory_function and sample.kwargs.get("dtype", None) is None: op_out(out=out) else: +<<<<<<< HEAD # TODO: Remove me when all ops will raise type error on mismatched types exc_type = ( TypeError @@ -1125,6 +1176,9 @@ def _case_four_transform(t): else RuntimeError ) with self.assertRaises(exc_type, msg=msg_fail): +======= + with self.assertRaises(RuntimeError, msg=msg_fail): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) op_out(out=out) @ops( @@ -1601,6 +1655,7 @@ def _tensor_requires_grad(x): ) == 0: return +<<<<<<< HEAD if TEST_WITH_TORCHDYNAMO: # NOTE: Also for TEST_WITH_TORCHINDUCTOR tests # Under compile, some ops may be decomposed into supported ops @@ -1611,6 +1666,8 @@ def _tensor_requires_grad(x): ) == 0: return +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Reference operators often support additional dtypes, and that's OK if op in python_ref_db: if ( @@ -2511,6 +2568,10 @@ def test_refs_are_in_decomp_table(self, op): "mvlgamma.mvlgamma_p_1", # Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend "mvlgamma.mvlgamma_p_3", # Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend "mvlgamma.mvlgamma_p_5", # Could not run 'aten::_local_scalar_dense' with arguments from the 'Meta' backend +<<<<<<< HEAD +======= + "nanmean", # logical_not() got an unexpected keyword argument 'out' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "quantile", # quantile() q values must be in the range [0, 1] "nanquantile", # quantile() q values must be in the range [0, 1] "nn.functional.ctc_loss", # The tensor has a non-zero number of elements, but its data is not allocated yet @@ -2595,7 +2656,10 @@ def test_refs_are_in_decomp_table(self, op): @unMarkDynamoStrictTest class TestFakeTensor(TestCase): def setUp(self): +<<<<<<< HEAD super().setUp() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Turn on FakeTensor caching and cross-checking for these tests: cache_enabled = unittest.mock.patch( "torch._dynamo.config.fake_tensor_cache_enabled", True diff --git a/test/test_ops_jit.py b/test/test_ops_jit.py index 9dfb75cc6a8f1..b9faaf52d1f10 100644 --- a/test/test_ops_jit.py +++ b/test/test_ops_jit.py @@ -188,7 +188,11 @@ def get_sample(): # Note: only runs in float32 because schema isn't affected by dtype, # so running it on all dtypes is would be excessive if dtype == torch.float32: +<<<<<<< HEAD # TODO: no reason why we can't run this with tracing graph +======= + # TODO: no reason why we cant run this with tracing graph +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if support_script and op.name != "rsub": check_alias_annotation( name, diff --git a/test/test_optim.py b/test/test_optim.py index 6dd23d6328c89..a613644309181 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -187,8 +187,12 @@ def test_forloop_goes_right_direction( ) input = torch.randn(5, device=device, dtype=dtype) +<<<<<<< HEAD params = [weight, bias] if optim_cls.__name__ != "Muon" else [weight] optimizer = optim_cls(params, **optim_input.kwargs) +======= + optimizer = optim_cls([weight, bias], **optim_input.kwargs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) schedulers = [ s(optimizer) for s in (schedulers_constructor if schedulers_constructor else []) @@ -196,12 +200,16 @@ def test_forloop_goes_right_direction( def closure(): optimizer.zero_grad() +<<<<<<< HEAD wo = ( weight.mv(input) if optim_cls.__name__ == "Muon" else weight.mv(input) + bias ) loss = wo.pow(2).sum() +======= + loss = (weight.mv(input) + bias).pow(2).sum() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) loss.backward() if optim_info.only_supports_sparse_grads: # For this test, we naively convert the Tensor layout, which we know does @@ -252,8 +260,12 @@ def test_forloop_goes_right_direction_multigpu( bias = Parameter(torch.randn((10), device="cuda:1", dtype=dtype)) inpt = torch.randn(5, device="cuda:0", dtype=dtype) +<<<<<<< HEAD params = [weight, bias] if optim_cls.__name__ != "Muon" else [weight] optimizer = optim_cls(params, **optim_input.kwargs) +======= + optimizer = optim_cls([weight, bias], **optim_input.kwargs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) schedulers = [ s(optimizer) for s in (schedulers_constructor if schedulers_constructor else []) @@ -261,12 +273,16 @@ def test_forloop_goes_right_direction_multigpu( def closure(): optimizer.zero_grad() +<<<<<<< HEAD wo = ( weight.mv(inpt).cuda(1) if optim_cls.__name__ == "Muon" else weight.mv(inpt).cuda(1) + bias ) loss = wo.pow(2).sum() +======= + loss = (weight.mv(inpt).cuda(1) + bias).pow(2).sum() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) loss.backward() if optim_info.only_supports_sparse_grads: # For this test, we naively convert the Tensor layout, which we know does @@ -297,25 +313,41 @@ def test_param_group_with_lrscheduler_goes_right_direction( for schedulers_c in optim_info.scheduler_inputs: weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype)) +<<<<<<< HEAD weight2 = Parameter(torch.randn((10, 5), device=device, dtype=dtype)) +======= + bias = Parameter(torch.randn((10), device=device, dtype=dtype)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inpt = torch.randn(5, device=device, dtype=dtype) # avoid endless recompiles by wrapping LR in a tensor if we're compiling lr = torch.tensor(0.01) if torch.compiler.is_compiling() else 0.01 +<<<<<<< HEAD optimizer = optim_cls( [{"params": [weight]}, {"params": [weight2], "lr": lr}] ) +======= + optimizer = optim_cls([{"params": [weight]}, {"params": [bias], "lr": lr}]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) schedulers = [scheduler_c(optimizer) for scheduler_c in schedulers_c] def closure(): optimizer.zero_grad() +<<<<<<< HEAD loss = (weight.mv(inpt) + weight2.mv(inpt)).pow(2).sum() +======= + loss = (weight.mv(inpt) + bias).pow(2).sum() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) loss.backward() if optim_info.only_supports_sparse_grads: # For this test, we naively convert the Tensor layout, which we know does # NOT represent the expected use case for optims like SparseAdam! weight.grad = weight.grad.to_sparse() +<<<<<<< HEAD weight2.grad = weight2.grad.to_sparse() +======= + bias.grad = bias.grad.to_sparse() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return loss initial_value = closure().item() @@ -353,26 +385,39 @@ def test_tensor_lr(self, device, dtype, optim_info, num_dim): if "lr" in kwargs: del kwargs["lr"] +<<<<<<< HEAD params = [weight, bias] if optim_cls.__name__ != "Muon" else [weight] kwargs["lr"] = 1.0 if optim_info.step_requires_closure else 1e-3 optimizer_r = optim_cls(params, **kwargs) +======= + kwargs["lr"] = 1.0 if optim_info.step_requires_closure else 1e-3 + optimizer_r = optim_cls([weight, bias], **kwargs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: kwargs["lr"] = ( torch.tensor(kwargs["lr"]).reshape([1] * num_dim).to(lr_device) ) +<<<<<<< HEAD params_c = [weight_c, bias_c] if optim_cls.__name__ == "Muon": params_c = [weight_c] optimizer = optim_cls(params_c, **kwargs) +======= + optimizer = optim_cls([weight_c, bias_c], **kwargs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) except ValueError as e: self.assertRegex(str(e), ".*lr as a Tensor is not supported.*") continue def closure(optim, w, b, i): optim.zero_grad() +<<<<<<< HEAD wo = w.mv(i) if optim_cls.__name__ == "Muon" else w.mv(i) + b loss = wo.pow(2).sum() +======= + loss = (w.mv(i) + b).pow(2).sum() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) loss.backward() if optim_info.only_supports_sparse_grads: # For this test, we naively convert the Tensor layout, which we know does @@ -396,8 +441,12 @@ def closure(optim, w, b, i): optimizer.step() self.assertEqual(weight, weight_c) +<<<<<<< HEAD if optim_cls.__name__ != "Muon": self.assertEqual(bias, bias_c) +======= + self.assertEqual(bias, bias_c) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize("with_lrsched", [True, False]) @optims( @@ -1237,16 +1286,25 @@ def test_param_groups_weight_decay(self, device, dtype, optim_info): ) for optim_input in all_optim_inputs: weight_kwargs = optim_input.kwargs +<<<<<<< HEAD weight2_kwargs = deepcopy(optim_input.kwargs) weight2_kwargs["weight_decay"] = 0.0 weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype)) weight2 = Parameter(torch.randn((10, 5), device=device, dtype=dtype)) +======= + bias_kwargs = deepcopy(optim_input.kwargs) + bias_kwargs["weight_decay"] = 0.0 + + weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype)) + bias = Parameter(torch.randn((10), device=device, dtype=dtype)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) input = torch.randn(5, device=device, dtype=dtype) optimizer = optim_cls( [ dict(params=[weight], **weight_kwargs), +<<<<<<< HEAD dict(params=[weight2], **weight2_kwargs), ] ) @@ -1256,12 +1314,27 @@ def test_param_groups_weight_decay(self, device, dtype, optim_info): for _ in range(20): optimizer.zero_grad() loss = (weight.mv(input) + weight2.mv(input)).pow(2).sum() +======= + dict(params=[bias], **bias_kwargs), + ] + ) + + loss = (weight.mv(input) + bias).pow(2).sum() + initial_value = loss.item() + for _ in range(20): + optimizer.zero_grad() + loss = (weight.mv(input) + bias).pow(2).sum() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) loss.backward() if optim_info.only_supports_sparse_grads: # For this test, we naively convert the Tensor layout, which we know does # NOT represent the expected use case for optims like SparseAdam! weight.grad = weight.grad.to_sparse() +<<<<<<< HEAD weight2.grad = weight2.grad.to_sparse() +======= + bias.grad = bias.grad.to_sparse() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) optimizer.step() # Test that the direction of loss moved appropriately @@ -1288,6 +1361,7 @@ def test_param_groups_lr(self, device, dtype, optim_info): weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype)) bias = Parameter(torch.randn((10), device=device, dtype=dtype)) +<<<<<<< HEAD irrelevant = Parameter(torch.randn((2, 2), device=device, dtype=dtype)) irrelevant_clone = irrelevant.clone() input = torch.randn(5, device=device, dtype=dtype) @@ -1295,11 +1369,20 @@ def test_param_groups_lr(self, device, dtype, optim_info): optimizer = optim_cls( [ dict(params=params, **optim_input.kwargs), +======= + irrelevant = Parameter(torch.randn(2, device=device, dtype=dtype)) + irrelevant_clone = irrelevant.clone() + input = torch.randn(5, device=device, dtype=dtype) + optimizer = optim_cls( + [ + dict(params=[weight, bias], **optim_input.kwargs), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dict(params=[irrelevant]), ], **outer_kwargs, ) +<<<<<<< HEAD wo = ( weight.mv(input) if optim_cls.__name__ == "Muon" @@ -1315,6 +1398,13 @@ def test_param_groups_lr(self, device, dtype, optim_info): else weight.mv(input) + bias ) loss = wo.pow(2).sum() +======= + loss = (weight.mv(input) + bias).pow(2).sum() + initial_value = loss.item() + for _ in range(20): + optimizer.zero_grad() + loss = (weight.mv(input) + bias).pow(2).sum() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) loss.backward() irrelevant.grad = torch.rand_like(irrelevant) if optim_info.only_supports_sparse_grads: @@ -1372,8 +1462,13 @@ def closure(): if kwargs.get("weight_decay", 0) != 0: continue +<<<<<<< HEAD # AdamW/Muon params will be updated regardless of grads due to lr, so make lr smaller if optim_cls.__name__ == "AdamW" or optim_cls.__name__ == "Muon": +======= + # AdamW params will be updated regardless of grads due to lr, so make lr smaller + if optim_cls.__name__ == "AdamW": +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kwargs["lr"] = ( torch.tensor(1e-5) if isinstance(kwargs.get("lr", 1e-5), torch.Tensor) @@ -1470,8 +1565,11 @@ def test_state_dict_deterministic( bias = Parameter(torch.randn(2, requires_grad=True, device=device, dtype=dtype)) input = torch.randn(3, requires_grad=True, device=device, dtype=dtype) params = [weight, bias] +<<<<<<< HEAD if optim_cls.__name__ == "Muon": params = [weight] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def make_named_param(param, is_named): if not is_named: @@ -1486,8 +1584,12 @@ def without_param_names(state_dict): def fwd_bwd(optim, w, b, i): optim.zero_grad() +<<<<<<< HEAD wo = w.mv(i) if optim_cls.__name__ == "Muon" else w.mv(i) + b loss = wo.pow(2).sum() +======= + loss = (w.mv(i) + b).pow(2).sum() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) loss.backward() if optim_info.only_supports_sparse_grads: if w.grad is not None: @@ -1513,10 +1615,14 @@ def fwd_bwd(optim, w, b, i): with torch.no_grad(): weight_c = Parameter(weight.clone()) bias_c = Parameter(bias.clone()) +<<<<<<< HEAD params_c_list = ( [weight_c, bias_c] if optim_cls.__name__ != "Muon" else [weight_c] ) params_c = make_named_param(params_c_list, is_named=is_named_optim1) +======= + params_c = make_named_param([weight_c, bias_c], is_named=is_named_optim1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) optimizer_c = optim_cls(params_c, **optim_input.kwargs) closure_c = functools.partial(fwd_bwd, optimizer_c, weight_c, bias_c, input) @@ -1535,8 +1641,12 @@ def fwd_bwd(optim, w, b, i): optimizer_c.step() self.assertEqual(weight, weight_c) +<<<<<<< HEAD if optim_cls.__name__ != "Muon": self.assertEqual(bias, bias_c) +======= + self.assertEqual(bias, bias_c) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Make sure state dict is deterministic with equal (not identical) parameters # Param names are optional and not needed to be the consistent. @@ -1560,6 +1670,7 @@ def test_can_load_older_state_dict(self, device, dtype, optim_info): all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs( device, dtype, optim_info, skip=("differentiable",) ) +<<<<<<< HEAD def _get_model_and_input_tensor(device, dtype, optim_cls): if optim_cls.__name__ == "Muon": @@ -1578,6 +1689,16 @@ def _get_model_and_input_tensor(device, dtype, optim_cls): for optim_input in all_optim_inputs: torch.manual_seed(1) model, input = _get_model_and_input_tensor(device, dtype, optim_cls) +======= + for optim_input in all_optim_inputs: + torch.manual_seed(1) + model = torch.nn.Sequential( + torch.nn.Conv2d(4, 2, 1, stride=2), + torch.nn.BatchNorm2d(2, eps=1e-05, momentum=0.1), + ) + model.to(dtype=dtype, device=device) + input = torch.rand(1, 4, 16, 16, device=device, dtype=dtype) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) optimizer = optim_cls(model.parameters(), **optim_input.kwargs) def fwd_bwd(optim, mod, i): @@ -1625,6 +1746,7 @@ def test_can_load_from_to_named_state_dict( all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs( device, dtype, optim_info, skip=("differentiable",) ) +<<<<<<< HEAD def _get_model_and_input_tensor(device, dtype, optim_cls): if optim_cls.__name__ == "Muon": @@ -1643,6 +1765,16 @@ def _get_model_and_input_tensor(device, dtype, optim_cls): for optim_input in all_optim_inputs: torch.manual_seed(1) model, input = _get_model_and_input_tensor(device, dtype, optim_cls) +======= + for optim_input in all_optim_inputs: + torch.manual_seed(1) + model = torch.nn.Sequential( + torch.nn.Conv2d(4, 2, 1, stride=2), + torch.nn.BatchNorm2d(2, eps=1e-05, momentum=0.1), + ) + model.to(dtype=dtype, device=device) + input = torch.rand(1, 4, 16, 16, device=device, dtype=dtype) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def fwd_bwd(optim, mod, i): optim.zero_grad() @@ -1679,12 +1811,19 @@ def fwd_bwd(optim, mod, i): fwd_bwd(optimizer2, model, input) optimizer2.step() +<<<<<<< HEAD ref_names = [p[0] for p in model.named_parameters()] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Make sure that param_names are preserved when provided to at least one of the optimizers if is_named_optim0 or is_named_optim1: self.assertEqual( optimizer2.state_dict()["param_groups"][0]["param_names"], +<<<<<<< HEAD ref_names, +======= + ["0.weight", "0.bias", "1.weight", "1.bias"], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @parametrize("is_named_optim", [True, False]) @@ -1703,7 +1842,11 @@ def test_save_load_equality_with_weights_only( ) bias = Parameter(torch.randn(2, requires_grad=True, device=device, dtype=dtype)) input = torch.randn(3, requires_grad=True, device=device, dtype=dtype) +<<<<<<< HEAD params = [weight, bias] if optim_cls.__name__ != "Muon" else [weight] +======= + params = [weight, bias] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def make_named_param(param, is_named): if not is_named: @@ -1712,8 +1855,12 @@ def make_named_param(param, is_named): def fwd_bwd(optim, w, b, i): optim.zero_grad() +<<<<<<< HEAD wo = w.mv(i) if optim_cls.__name__ == "Muon" else w.mv(i) + b loss = wo.pow(2).sum() +======= + loss = (w.mv(i) + b).pow(2).sum() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) loss.backward() if optim_info.only_supports_sparse_grads: weight.grad = weight.grad.to_sparse() @@ -1997,7 +2144,11 @@ def post_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]): nonlocal data data += 2 +<<<<<<< HEAD params = [torch.tensor([[1, 1]], device=device, dtype=dtype)] +======= + params = [torch.tensor([1, 1], device=device, dtype=dtype)] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def dummy_closure(): return 1 @@ -2029,8 +2180,12 @@ def pre_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]): nonlocal data data += 2 +<<<<<<< HEAD # Create a random 2D tensor for compatibility with Muon. params = [torch.tensor([[1, 1]], device=device, dtype=dtype)] +======= + params = [torch.tensor([1, 1], device=device, dtype=dtype)] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def dummy_closure(): return 1 @@ -2074,7 +2229,11 @@ def local_post_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]): nonlocal data data.append(2) +<<<<<<< HEAD params = [torch.tensor([[1, 1]], device=device, dtype=dtype)] +======= + params = [torch.tensor([1, 1], device=device, dtype=dtype)] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def dummy_closure(): return 1 @@ -2280,8 +2439,12 @@ def test_defaults_changed_to_foreach(self, device, dtype, optim_info): def test_non_empty_state(self, device, dtype, optim_info): # There are internal tests that check that the state is not empty optim_cls = optim_info.optim_cls +<<<<<<< HEAD # Muon only accepts 2D parameter. model = torch.nn.Linear(5, 5, bias=False) +======= + model = torch.nn.Linear(5, 5) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) model.to(dtype=dtype, device=device) inpt = torch.rand(2, 5, dtype=dtype, device=device) diff --git a/test/test_overrides.py b/test/test_overrides.py index 8454677856d0f..dc23739f03779 100644 --- a/test/test_overrides.py +++ b/test/test_overrides.py @@ -77,7 +77,11 @@ def quux(a): # dictionary are function names in the torch API and the values are # function implementations. Implementations are added to # HANDLED_FUNCTION_DIAGONAL by decorating a python function with +<<<<<<< HEAD # implements_diagonal. See the overrides immediately below the definition +======= +# implements_diagonal. See the overrides immediately below the defintion +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # of DiagonalTensor for usage examples. HANDLED_FUNCTIONS_DIAGONAL = {} @@ -133,7 +137,11 @@ class DiagonalTensor: https://numpy.org/devdocs/user/basics.dispatch.html """ # This is defined as a class attribute so that SubDiagonalTensor +<<<<<<< HEAD # below which subclasses DiagonalTensor can reuse DiagonalTensor's +======= + # below which subclasses DiagonalTensor can re-use DiagonalTensor's +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # __torch_function__ implementation. handled_functions = HANDLED_FUNCTIONS_DIAGONAL @@ -615,6 +623,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): self.assertEqual(NothingImplemented() ** RPowOnly(), -1) +<<<<<<< HEAD def test_torch_function_in_lists(self): """Test that __torch_function__ is called for objects inside lists""" @@ -880,6 +889,8 @@ def __index__(self): self.assertNotIn('size', called_functions, "size should not be called - we should use getitem, not convert to advanced indexing") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def generate_tensor_like_override_tests(cls): from torch.testing._internal.generated.annotated_fn_args import annotated_args @@ -1400,6 +1411,7 @@ def test_resolve_name(self): ) class TestTorchFunctionWarning(TestCase): +<<<<<<< HEAD def test_torch_function_standalone_class(self): class StandaloneTorchFunctionClass: @classmethod @@ -1425,6 +1437,31 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): result2 = torch.abs(b) self.assertEqual(result1, torch.tensor(99.0)) self.assertEqual(result2, torch.tensor(99.0)) +======= + def test_warn_on_invalid_torch_function_standalone_class(self): + class StandaloneTorchFunctionClass: + def __torch_function__(self, *args, **kwargs): + pass + a = StandaloneTorchFunctionClass() + with self.assertWarnsRegex(DeprecationWarning, "as a plain method is deprecated"): + # Function that handles torch_function on the python side + torch.nn.functional.dropout(a) + with self.assertWarnsRegex(UserWarning, "as a plain method is deprecated"): + # Function that handles torch_function in C++ + torch.abs(a) + + def test_warn_on_invalid_torch_function_tensor_subclass(self): + class TensorSubclassTorchFunctionClass(torch.Tensor): + def __torch_function__(self, *args, **kwargs): + pass + b = TensorSubclassTorchFunctionClass() + with self.assertWarnsRegex(DeprecationWarning, "as a plain method is deprecated"): + # Function that handles torch_function on the python side + torch.nn.functional.dropout(b) + with self.assertWarnsRegex(UserWarning, "as a plain method is deprecated"): + # Function that handles torch_function in C++ + torch.abs(b) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestDisabledUserWarnings(TestCase): def test_no_implicit_user_warning_for_deprecated_functions(self): diff --git a/test/test_per_overload_api.py b/test/test_per_overload_api.py index e5cf2aa1d5679..823883b8084d5 100644 --- a/test/test_per_overload_api.py +++ b/test/test_per_overload_api.py @@ -7,7 +7,11 @@ class TestPerOverloadAPI(TestCase): def test_basics_opoverloadpacket(self): +<<<<<<< HEAD # add is only used as an example here. It is ok to update the test +======= + # add is ony used as an example here. It is ok to update the test +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # if the semantics of add are modified in the future. add_packet = torch.ops.aten.add diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 6d36b36996c4b..fa563e903f9d1 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1370,8 +1370,13 @@ def forward(self, crop_camera_1, mask_1): view_1 = torch.ops.aten.view.default(expand_1, [sym_size_int, sym_size_int_1, sym_size_int_2]); expand_1 = sym_size_int_1 = sym_size_int_2 = None bmm = torch.ops.aten.bmm.default(view, view_1); view = view_1 = None view_2 = torch.ops.aten.view.default(bmm, [sym_size_int, 3, 3]); bmm = None +<<<<<<< HEAD mul_9 = sym_size_int * 3 view_3 = torch.ops.aten.view.default(view_2, [mul_9, 3]); view_2 = mul_9 = None +======= + mul_6 = sym_size_int * 3 + view_3 = torch.ops.aten.view.default(view_2, [mul_6, 3]); view_2 = mul_6 = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mm = torch.ops.aten.mm.default(view_3, eye); view_3 = eye = None _unsafe_view = torch.ops.aten._unsafe_view.default(mm, [sym_size_int, 3, 3]); mm = sym_size_int = None index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], _unsafe_view); crop_camera_1 = mask_1 = _unsafe_view = index_put_ = None diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py index 09bbbcbadcc87..d21534f8202bd 100644 --- a/test/test_public_bindings.py +++ b/test/test_public_bindings.py @@ -512,7 +512,11 @@ def check_one_element(elem, modname, mod, *, is_public, is_all): "does not have `__all__` defined" ) fix_is_public = ( +<<<<<<< HEAD f"remove it from the modules' (`{modname}`) `__all__`" +======= + f"remove it from the modules's (`{modname}`) `__all__`" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if is_all else f"either define a `__all__` for `{modname}` or add a `_` at the beginning of the name" ) @@ -522,7 +526,11 @@ def check_one_element(elem, modname, mod, *, is_public, is_all): f"it is not inside the module's (`{modname}`) `__all__`" ) fix_is_public = ( +<<<<<<< HEAD f"add it from the modules' (`{modname}`) `__all__`" +======= + f"add it from the modules's (`{modname}`) `__all__`" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if looks_public: why_looks_public = ( diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py index 07a92244cd733..46830a155766b 100644 --- a/test/test_python_dispatch.py +++ b/test/test_python_dispatch.py @@ -1,6 +1,10 @@ # Owner(s): ["module: __torch_dispatch__"] # ruff: noqa: F841 +<<<<<<< HEAD +======= +import logging +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import pickle import sys import tempfile @@ -155,7 +159,11 @@ def second_fallback(op, *args, **kwargs): # New dispatcher call should hit the first callback again self.assertFalse(first_called) a, b = args +<<<<<<< HEAD # Make a subtraction here instead of add ! +======= + # Make a substraction here instead of add ! +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) c = a - b self.assertTrue(first_called) return c @@ -587,6 +595,7 @@ def test_error_for_unsupported_ns_or_kind(self) -> None: with self.assertRaisesRegex(ValueError, "reserved namespace"): my_lib1 = Library("prim", kind) # noqa: TOR901 +<<<<<<< HEAD def test_dispatcher_error_filenames(self) -> None: # Test that dispatcher errors report correct Python filenames and line numbers # when defining duplicate libraries (which triggers the filename tracking) @@ -628,6 +637,8 @@ def test_dispatcher_error_filenames(self) -> None: self.assertIn("FIRST_LIB_MARKER", first_line) self.assertIn("SECOND_LIB_MARKER", second_line) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_returning_symint(self) -> None: shape_env = ShapeEnv() fake_tensor_mode = FakeTensorMode(shape_env=shape_env) @@ -781,8 +792,14 @@ def test_produce_real_type(self) -> None: $0: f32[2, 2] = input('x') $1: f64[2, 2] = torch._ops.aten._to_copy.default($0, dtype=torch.float64) $2: f64[2, 2] = torch._ops.aten.cumprod.default($0, 0, dtype=torch.float64) +<<<<<<< HEAD $3: f32[2] = torch._ops.aten.select.int($0, 1, 1) $4: f32[2] = torch._ops.aten.clone.default($3, memory_format=torch.contiguous_format)""", +======= +$3: f32[2, 2] = torch._ops.aten.slice.Tensor($0, 0, 0, 9223372036854775807) +$4: f32[2] = torch._ops.aten.select.int($3, 1, 1) +$5: f32[2] = torch._ops.aten.clone.default($4, memory_format=torch.contiguous_format)""", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def test_optional_tensor_list(self) -> None: @@ -1758,6 +1775,52 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): self.assertEqual(s.device_index, 2) self.assertEqual(s.device_type, 3) +<<<<<<< HEAD +======= + def test_subclass_autograd_device_check(self) -> None: + class NonWrapperSubclass(torch.Tensor): + elem: torch.Tensor + + __slots__ = ["elem"] + + @staticmethod + def __new__(cls, elem, *args, **kwargs): + # Wrong device here! + r = torch.Tensor._make_subclass( + cls, elem.to("meta"), elem.requires_grad + ) + # ...the real tensor is held as an element on the tensor. + r.elem = elem + return r + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + def unwrap(e): + return e.elem if isinstance(e, NonWrapperSubclass) else e + + def wrap(e): + return NonWrapperSubclass(e) if isinstance(e, torch.Tensor) else e + + rs = tree_map( + wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)) + ) + logging.getLogger("NonWrapperSubclass").info( + f"{func.__module__}.{func.__name__}", # noqa: G004 + args, + kwargs, + rs, + ) + return rs + + x = NonWrapperSubclass(torch.tensor([3.0, 4.0], requires_grad=True)) + y = torch.randn(2, requires_grad=True) + z = x * y + self.assertIsInstance(z, NonWrapperSubclass) + z.sum().backward(torch.tensor(1)) + self.assertEqual(x.grad, y) + self.assertEqual(y.grad, x) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_none_wrapping(self): # A Tensor subclass that returns None when doing add # See LoggingTensor above for more details on the subclass @@ -1999,8 +2062,11 @@ def __new__(cls, data, wrapper): def __torch_dispatch__(cls, func, types, args, kwargs): if func.overloadpacket == torch.ops.aten.is_contiguous: return contiguous_data.is_contiguous() +<<<<<<< HEAD if func.overloadpacket == torch.ops.aten.sym_is_contiguous: return torch.ops.aten.sym_is_contiguous(contiguous_data) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return NotImplemented class ExampleTensor3(torch.Tensor): @@ -2014,8 +2080,11 @@ def __new__(cls, data, wrapper): def __torch_dispatch__(cls, func, types, args, kwargs): if func.overloadpacket == torch.ops.aten.is_contiguous: return not_contiguous_data.is_contiguous() +<<<<<<< HEAD if func.overloadpacket == torch.ops.aten.sym_is_contiguous: return torch.ops.aten.sym_is_contiguous(not_contiguous_data) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return NotImplemented err_msg = "Multiple dispatch failed for 'torch.ops.aten.is_contiguous'" @@ -2048,7 +2117,10 @@ def __new__(cls, data): @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): if func in [ +<<<<<<< HEAD torch.ops.aten.sym_is_contiguous.default, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.ops.aten.is_contiguous.default, torch.ops.aten.is_contiguous.memory_format, torch.ops.aten.is_strides_like_format.default, @@ -2482,6 +2554,7 @@ def __torch_dispatch__(self, func, types, args, kwargs=None): self.assertEqual(res, t.a) self.assertIs(type(res), torch.Tensor) +<<<<<<< HEAD def test_custom_dispatch_mode_supports_higher_order_operators(self): class Mode(TorchDispatchMode): supports_higher_order_operators = True @@ -2528,6 +2601,8 @@ def __torch_dispatch__(self, func, types, args, kwargs): self.assertEqual(m.last_args[1], uarg) self.assertTrue((a == uarg).all().item()) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestPythonDispatcher(TestCase): def test_basic(self): diff --git a/test/test_pytree.py b/test/test_pytree.py index e19f1471267cb..ea78ffc99d963 100644 --- a/test/test_pytree.py +++ b/test/test_pytree.py @@ -14,7 +14,11 @@ from typing import Any, NamedTuple, Optional import torch +<<<<<<< HEAD import torch.utils._pytree as python_pytree +======= +import torch.utils._pytree as py_pytree +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.fx.immutable_collections import immutable_dict, immutable_list from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, @@ -26,6 +30,7 @@ ) +<<<<<<< HEAD pytree_modules = { "python": python_pytree, } @@ -43,6 +48,13 @@ [subtest(module, name=name) for name, module in pytree_modules.items()], ) +======= +if IS_FBCODE: + # optree is not yet enabled in fbcode, so just re-test the python implementation + cxx_pytree = py_pytree +else: + import torch.utils._cxx_pytree as cxx_pytree +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GlobalPoint = namedtuple("GlobalPoint", ["x", "y"]) @@ -65,17 +77,24 @@ class TestEnum(enum.Enum): A = auto() +<<<<<<< HEAD python_leafspec = python_pytree.LeafSpec() class TestGenericPytree(TestCase): def test_aligned_public_apis(self): public_apis = python_pytree.__all__ +======= +class TestGenericPytree(TestCase): + def test_aligned_public_apis(self): + public_apis = py_pytree.__all__ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(public_apis, cxx_pytree.__all__) for name in public_apis: cxx_api = getattr(cxx_pytree, name) +<<<<<<< HEAD python_api = getattr(python_pytree, name) self.assertEqual(inspect.isclass(cxx_api), inspect.isclass(python_api)) @@ -91,6 +110,20 @@ def test_aligned_public_apis(self): cxx_param_names = list(cxx_signature.parameters) python_param_names = list(python_signature.parameters) self.assertEqual(cxx_param_names, python_param_names) +======= + py_api = getattr(py_pytree, name) + + self.assertEqual(inspect.isclass(cxx_api), inspect.isclass(py_api)) + self.assertEqual(inspect.isfunction(cxx_api), inspect.isfunction(py_api)) + if inspect.isfunction(cxx_api): + cxx_signature = inspect.signature(cxx_api) + py_signature = inspect.signature(py_api) + + # Check the parameter names are the same. + cxx_param_names = list(cxx_signature.parameters) + py_param_names = list(py_signature.parameters) + self.assertEqual(cxx_param_names, py_param_names) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Check the positional parameters are the same. cxx_positional_param_names = [ @@ -104,9 +137,15 @@ def test_aligned_public_apis(self): } ) ] +<<<<<<< HEAD python_positional_param_names = [ n for n, p in python_signature.parameters.items() +======= + py_positional_param_names = [ + n + for n, p in py_signature.parameters.items() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if ( p.kind in { @@ -115,6 +154,7 @@ def test_aligned_public_apis(self): } ) ] +<<<<<<< HEAD self.assertEqual( cxx_positional_param_names, python_positional_param_names, @@ -131,6 +171,21 @@ def test_aligned_public_apis(self): # Check parameter annotations are the same. if "TreeSpec" in str(cxx_param.annotation): self.assertIn("TreeSpec", str(python_param.annotation)) +======= + self.assertEqual(cxx_positional_param_names, py_positional_param_names) + + for py_name, py_param in py_signature.parameters.items(): + self.assertIn(py_name, cxx_signature.parameters) + cxx_param = cxx_signature.parameters[py_name] + + # Check parameter kinds and default values are the same. + self.assertEqual(cxx_param.kind, py_param.kind) + self.assertEqual(cxx_param.default, py_param.default) + + # Check parameter annotations are the same. + if "TreeSpec" in str(cxx_param.annotation): + self.assertIn("TreeSpec", str(py_param.annotation)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual( re.sub( r"(?:\b)([\w\.]*)TreeSpec(?:\b)", @@ -140,33 +195,60 @@ def test_aligned_public_apis(self): re.sub( r"(?:\b)([\w\.]*)TreeSpec(?:\b)", "TreeSpec", +<<<<<<< HEAD str(python_param.annotation), ), msg=( f"C++ parameter {cxx_param} " f"does not match Python parameter {python_param} " +======= + str(py_param.annotation), + ), + msg=( + f"C++ parameter {cxx_param} " + f"does not match Python parameter {py_param} " +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f"for API `{name}`" ), ) else: self.assertEqual( cxx_param.annotation, +<<<<<<< HEAD python_param.annotation, msg=( f"C++ parameter {cxx_param} " f"does not match Python parameter {python_param} " +======= + py_param.annotation, + msg=( + f"C++ parameter {cxx_param} " + f"does not match Python parameter {py_param} " +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f"for API `{name}`" ), ) +<<<<<<< HEAD @parametrize_pytree_module def test_register_pytree_node(self, pytree): +======= + @parametrize( + "pytree_impl", + [ + subtest(py_pytree, name="py"), + subtest(cxx_pytree, name="cxx"), + ], + ) + def test_register_pytree_node(self, pytree_impl): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class MyDict(UserDict): pass d = MyDict(a=1, b=2, c=3) # Custom types are leaf nodes by default +<<<<<<< HEAD values, spec = pytree.tree_flatten(d) self.assertEqual(values, [d]) self.assertIs(values[0], d) @@ -175,11 +257,22 @@ class MyDict(UserDict): # Register MyDict as a pytree node pytree.register_pytree_node( +======= + values, spec = pytree_impl.tree_flatten(d) + self.assertEqual(values, [d]) + self.assertIs(values[0], d) + self.assertEqual(d, pytree_impl.tree_unflatten(values, spec)) + self.assertTrue(spec.is_leaf()) + + # Register MyDict as a pytree node + pytree_impl.register_pytree_node( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MyDict, lambda d: (list(d.values()), list(d.keys())), lambda values, keys: MyDict(zip(keys, values)), ) +<<<<<<< HEAD values, spec = pytree.tree_flatten(d) self.assertEqual(values, [1, 2, 3]) self.assertEqual(d, pytree.tree_unflatten(values, spec)) @@ -187,11 +280,21 @@ class MyDict(UserDict): # Do not allow registering the same type twice with self.assertRaisesRegex(ValueError, "already registered"): pytree.register_pytree_node( +======= + values, spec = pytree_impl.tree_flatten(d) + self.assertEqual(values, [1, 2, 3]) + self.assertEqual(d, pytree_impl.tree_unflatten(values, spec)) + + # Do not allow registering the same type twice + with self.assertRaisesRegex(ValueError, "already registered"): + pytree_impl.register_pytree_node( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MyDict, lambda d: (list(d.values()), list(d.keys())), lambda values, keys: MyDict(zip(keys, values)), ) +<<<<<<< HEAD @parametrize_pytree_module def test_flatten_unflatten_leaf(self, pytree): def run_test_with_leaf(leaf): @@ -200,6 +303,22 @@ def run_test_with_leaf(leaf): self.assertEqual(treespec, pytree.LeafSpec()) unflattened = pytree.tree_unflatten(values, treespec) +======= + @parametrize( + "pytree_impl", + [ + subtest(py_pytree, name="py"), + subtest(cxx_pytree, name="cxx"), + ], + ) + def test_flatten_unflatten_leaf(self, pytree_impl): + def run_test_with_leaf(leaf): + values, treespec = pytree_impl.tree_flatten(leaf) + self.assertEqual(values, [leaf]) + self.assertEqual(treespec, pytree_impl.LeafSpec()) + + unflattened = pytree_impl.tree_unflatten(values, treespec) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(unflattened, leaf) run_test_with_leaf(1) @@ -209,6 +328,7 @@ def run_test_with_leaf(leaf): run_test_with_leaf(torch.randn(3, 3)) @parametrize( +<<<<<<< HEAD "pytree,gen_expected_fn", [ subtest( @@ -219,6 +339,18 @@ def run_test_with_leaf(leaf): ), ), name="python", +======= + "pytree_impl,gen_expected_fn", + [ + subtest( + ( + py_pytree, + lambda tup: py_pytree.TreeSpec( + tuple, None, [py_pytree.LeafSpec() for _ in tup] + ), + ), + name="py", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), subtest( (cxx_pytree, lambda tup: cxx_pytree.tree_structure((0,) * len(tup))), @@ -226,15 +358,26 @@ def run_test_with_leaf(leaf): ), ], ) +<<<<<<< HEAD def test_flatten_unflatten_tuple(self, pytree, gen_expected_fn): def run_test(tup): expected_spec = gen_expected_fn(tup) values, treespec = pytree.tree_flatten(tup) +======= + def test_flatten_unflatten_tuple(self, pytree_impl, gen_expected_fn): + def run_test(tup): + expected_spec = gen_expected_fn(tup) + values, treespec = pytree_impl.tree_flatten(tup) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertIsInstance(values, list) self.assertEqual(values, list(tup)) self.assertEqual(treespec, expected_spec) +<<<<<<< HEAD unflattened = pytree.tree_unflatten(values, treespec) +======= + unflattened = pytree_impl.tree_unflatten(values, treespec) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(unflattened, tup) self.assertIsInstance(unflattened, tuple) @@ -244,6 +387,7 @@ def run_test(tup): run_test((torch.tensor([1.0, 2]), 2, 10, 9, 11)) @parametrize( +<<<<<<< HEAD "pytree,gen_expected_fn", [ subtest( @@ -254,6 +398,18 @@ def run_test(tup): ), ), name="python", +======= + "pytree_impl,gen_expected_fn", + [ + subtest( + ( + py_pytree, + lambda lst: py_pytree.TreeSpec( + list, None, [py_pytree.LeafSpec() for _ in lst] + ), + ), + name="py", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), subtest( (cxx_pytree, lambda lst: cxx_pytree.tree_structure([0] * len(lst))), @@ -261,15 +417,26 @@ def run_test(tup): ), ], ) +<<<<<<< HEAD def test_flatten_unflatten_list(self, pytree, gen_expected_fn): def run_test(lst): expected_spec = gen_expected_fn(lst) values, treespec = pytree.tree_flatten(lst) +======= + def test_flatten_unflatten_list(self, pytree_impl, gen_expected_fn): + def run_test(lst): + expected_spec = gen_expected_fn(lst) + values, treespec = pytree_impl.tree_flatten(lst) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertIsInstance(values, list) self.assertEqual(values, lst) self.assertEqual(treespec, expected_spec) +<<<<<<< HEAD unflattened = pytree.tree_unflatten(values, treespec) +======= + unflattened = pytree_impl.tree_unflatten(values, treespec) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(unflattened, lst) self.assertIsInstance(unflattened, list) @@ -278,6 +445,7 @@ def run_test(lst): run_test([torch.tensor([1.0, 2]), 2, 10, 9, 11]) @parametrize( +<<<<<<< HEAD "pytree,gen_expected_fn", [ subtest( @@ -290,6 +458,20 @@ def run_test(lst): ), ), name="python", +======= + "pytree_impl,gen_expected_fn", + [ + subtest( + ( + py_pytree, + lambda dct: py_pytree.TreeSpec( + dict, + list(dct.keys()), + [py_pytree.LeafSpec() for _ in dct.values()], + ), + ), + name="py", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), subtest( ( @@ -300,15 +482,26 @@ def run_test(lst): ), ], ) +<<<<<<< HEAD def test_flatten_unflatten_dict(self, pytree, gen_expected_fn): def run_test(dct): expected_spec = gen_expected_fn(dct) values, treespec = pytree.tree_flatten(dct) +======= + def test_flatten_unflatten_dict(self, pytree_impl, gen_expected_fn): + def run_test(dct): + expected_spec = gen_expected_fn(dct) + values, treespec = pytree_impl.tree_flatten(dct) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertIsInstance(values, list) self.assertEqual(values, list(dct.values())) self.assertEqual(treespec, expected_spec) +<<<<<<< HEAD unflattened = pytree.tree_unflatten(values, treespec) +======= + unflattened = pytree_impl.tree_unflatten(values, treespec) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(unflattened, dct) self.assertIsInstance(unflattened, dict) @@ -319,6 +512,7 @@ def run_test(dct): run_test({"a": 1, "b": 2, "c": torch.randn(2, 3)}) @parametrize( +<<<<<<< HEAD "pytree,gen_expected_fn", [ subtest( @@ -331,6 +525,20 @@ def run_test(dct): ), ), name="python", +======= + "pytree_impl,gen_expected_fn", + [ + subtest( + ( + py_pytree, + lambda odict: py_pytree.TreeSpec( + OrderedDict, + list(odict.keys()), + [py_pytree.LeafSpec() for _ in odict.values()], + ), + ), + name="py", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), subtest( ( @@ -343,15 +551,26 @@ def run_test(dct): ), ], ) +<<<<<<< HEAD def test_flatten_unflatten_ordereddict(self, pytree, gen_expected_fn): def run_test(odict): expected_spec = gen_expected_fn(odict) values, treespec = pytree.tree_flatten(odict) +======= + def test_flatten_unflatten_ordereddict(self, pytree_impl, gen_expected_fn): + def run_test(odict): + expected_spec = gen_expected_fn(odict) + values, treespec = pytree_impl.tree_flatten(odict) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertIsInstance(values, list) self.assertEqual(values, list(odict.values())) self.assertEqual(treespec, expected_spec) +<<<<<<< HEAD unflattened = pytree.tree_unflatten(values, treespec) +======= + unflattened = pytree_impl.tree_unflatten(values, treespec) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(unflattened, odict) self.assertIsInstance(unflattened, OrderedDict) @@ -363,6 +582,7 @@ def run_test(odict): run_test(od) @parametrize( +<<<<<<< HEAD "pytree,gen_expected_fn", [ subtest( @@ -375,6 +595,20 @@ def run_test(odict): ), ), name="python", +======= + "pytree_impl,gen_expected_fn", + [ + subtest( + ( + py_pytree, + lambda ddct: py_pytree.TreeSpec( + defaultdict, + [ddct.default_factory, list(ddct.keys())], + [py_pytree.LeafSpec() for _ in ddct.values()], + ), + ), + name="py", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), subtest( ( @@ -387,15 +621,26 @@ def run_test(odict): ), ], ) +<<<<<<< HEAD def test_flatten_unflatten_defaultdict(self, pytree, gen_expected_fn): def run_test(ddct): expected_spec = gen_expected_fn(ddct) values, treespec = pytree.tree_flatten(ddct) +======= + def test_flatten_unflatten_defaultdict(self, pytree_impl, gen_expected_fn): + def run_test(ddct): + expected_spec = gen_expected_fn(ddct) + values, treespec = pytree_impl.tree_flatten(ddct) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertIsInstance(values, list) self.assertEqual(values, list(ddct.values())) self.assertEqual(treespec, expected_spec) +<<<<<<< HEAD unflattened = pytree.tree_unflatten(values, treespec) +======= + unflattened = pytree_impl.tree_unflatten(values, treespec) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(unflattened, ddct) self.assertEqual(unflattened.default_factory, ddct.default_factory) self.assertIsInstance(unflattened, defaultdict) @@ -407,6 +652,7 @@ def run_test(ddct): run_test(defaultdict(int, {"a": 1, "b": 2, "c": torch.randn(2, 3)})) @parametrize( +<<<<<<< HEAD "pytree,gen_expected_fn", [ subtest( @@ -417,6 +663,20 @@ def run_test(ddct): ), ), name="python", +======= + "pytree_impl,gen_expected_fn", + [ + subtest( + ( + py_pytree, + lambda deq: py_pytree.TreeSpec( + deque, + deq.maxlen, + [py_pytree.LeafSpec() for _ in deq], + ), + ), + name="py", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), subtest( ( @@ -429,15 +689,26 @@ def run_test(ddct): ), ], ) +<<<<<<< HEAD def test_flatten_unflatten_deque(self, pytree, gen_expected_fn): def run_test(deq): expected_spec = gen_expected_fn(deq) values, treespec = pytree.tree_flatten(deq) +======= + def test_flatten_unflatten_deque(self, pytree_impl, gen_expected_fn): + def run_test(deq): + expected_spec = gen_expected_fn(deq) + values, treespec = pytree_impl.tree_flatten(deq) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertIsInstance(values, list) self.assertEqual(values, list(deq)) self.assertEqual(treespec, expected_spec) +<<<<<<< HEAD unflattened = pytree.tree_unflatten(values, treespec) +======= + unflattened = pytree_impl.tree_unflatten(values, treespec) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(unflattened, deq) self.assertEqual(unflattened.maxlen, deq.maxlen) self.assertIsInstance(unflattened, deque) @@ -446,6 +717,7 @@ def run_test(deq): run_test(deque([1.0, 2])) run_test(deque([torch.tensor([1.0, 2]), 2, 10, 9, 11], maxlen=8)) +<<<<<<< HEAD @parametrize_pytree_module def test_flatten_unflatten_namedtuple(self, pytree): Point = namedtuple("Point", ["x", "y"]) @@ -458,11 +730,35 @@ def run_test(tup): else: expected_spec = cxx_pytree.tree_structure(Point(0, 1)) values, treespec = pytree.tree_flatten(tup) +======= + @parametrize( + "pytree_impl", + [ + subtest(py_pytree, name="py"), + subtest(cxx_pytree, name="cxx"), + ], + ) + def test_flatten_unflatten_namedtuple(self, pytree_impl): + Point = namedtuple("Point", ["x", "y"]) + + def run_test(tup): + if pytree_impl is py_pytree: + expected_spec = py_pytree.TreeSpec( + namedtuple, Point, [py_pytree.LeafSpec() for _ in tup] + ) + else: + expected_spec = cxx_pytree.tree_structure(Point(0, 1)) + values, treespec = pytree_impl.tree_flatten(tup) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertIsInstance(values, list) self.assertEqual(values, list(tup)) self.assertEqual(treespec, expected_spec) +<<<<<<< HEAD unflattened = pytree.tree_unflatten(values, treespec) +======= + unflattened = pytree_impl.tree_unflatten(values, treespec) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(unflattened, tup) self.assertIsInstance(unflattened, Point) @@ -476,6 +772,7 @@ def run_test(tup): subtest(torch.min, name="min"), ], ) +<<<<<<< HEAD @parametrize_pytree_module def test_flatten_unflatten_return_types(self, pytree, op): x = torch.randn(3, 3) @@ -486,21 +783,57 @@ def test_flatten_unflatten_return_types(self, pytree, op): for value in values: self.assertIsInstance(value, torch.Tensor) result = pytree.tree_unflatten(values, spec) +======= + @parametrize( + "pytree_impl", + [ + subtest(py_pytree, name="py"), + subtest(cxx_pytree, name="cxx"), + ], + ) + def test_flatten_unflatten_return_types(self, pytree_impl, op): + x = torch.randn(3, 3) + expected = op(x, dim=0) + + values, spec = pytree_impl.tree_flatten(expected) + # Check that values is actually List[Tensor] and not (ReturnType(...),) + for value in values: + self.assertIsInstance(value, torch.Tensor) + result = pytree_impl.tree_unflatten(values, spec) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(type(result), type(expected)) self.assertEqual(result, expected) +<<<<<<< HEAD @parametrize_pytree_module def test_flatten_unflatten_nested(self, pytree): def run_test(tree): values, treespec = pytree.tree_flatten(tree) +======= + @parametrize( + "pytree_impl", + [ + subtest(py_pytree, name="py"), + subtest(cxx_pytree, name="cxx"), + ], + ) + def test_flatten_unflatten_nested(self, pytree_impl): + def run_test(pytree): + values, treespec = pytree_impl.tree_flatten(pytree) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertIsInstance(values, list) self.assertEqual(len(values), treespec.num_leaves) # NB: python basic data structures (dict list tuple) all have # contents equality defined on them, so the following works for them. +<<<<<<< HEAD unflattened = pytree.tree_unflatten(values, treespec) self.assertEqual(unflattened, tree) +======= + unflattened = pytree_impl.tree_unflatten(values, treespec) + self.assertEqual(unflattened, pytree) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cases = [ [()], @@ -512,11 +845,25 @@ def run_test(tree): for case in cases: run_test(case) +<<<<<<< HEAD @parametrize_pytree_module def test_flatten_with_is_leaf(self, pytree): def run_test(tree, one_level_leaves): values, treespec = pytree.tree_flatten( tree, is_leaf=lambda x: x is not tree +======= + @parametrize( + "pytree_impl", + [ + subtest(py_pytree, name="py"), + subtest(cxx_pytree, name="cxx"), + ], + ) + def test_flatten_with_is_leaf(self, pytree_impl): + def run_test(pytree, one_level_leaves): + values, treespec = pytree_impl.tree_flatten( + pytree, is_leaf=lambda x: x is not pytree +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) self.assertIsInstance(values, list) self.assertEqual(len(values), treespec.num_nodes - 1) @@ -526,6 +873,7 @@ def run_test(tree, one_level_leaves): self.assertEqual( treespec, +<<<<<<< HEAD pytree.tree_structure( pytree.tree_unflatten([0] * treespec.num_leaves, treespec) ), @@ -533,6 +881,15 @@ def run_test(tree, one_level_leaves): unflattened = pytree.tree_unflatten(values, treespec) self.assertEqual(unflattened, tree) +======= + pytree_impl.tree_structure( + pytree_impl.tree_unflatten([0] * treespec.num_leaves, treespec) + ), + ) + + unflattened = pytree_impl.tree_unflatten(values, treespec) + self.assertEqual(unflattened, pytree) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cases = [ ([()], [()]), @@ -551,6 +908,7 @@ def run_test(tree, one_level_leaves): for case in cases: run_test(*case) +<<<<<<< HEAD @parametrize_pytree_module def test_tree_map(self, pytree): def run_test(tree): @@ -559,14 +917,35 @@ def f(x): sm1 = sum(map(f, pytree.tree_leaves(tree))) sm2 = sum(pytree.tree_leaves(pytree.tree_map(f, tree))) +======= + @parametrize( + "pytree_impl", + [ + subtest(py_pytree, name="py"), + subtest(cxx_pytree, name="cxx"), + ], + ) + def test_tree_map(self, pytree_impl): + def run_test(pytree): + def f(x): + return x * 3 + + sm1 = sum(map(f, pytree_impl.tree_leaves(pytree))) + sm2 = sum(pytree_impl.tree_leaves(pytree_impl.tree_map(f, pytree))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(sm1, sm2) def invf(x): return x // 3 self.assertEqual( +<<<<<<< HEAD pytree.tree_map(invf, pytree.tree_map(f, tree)), tree, +======= + pytree_impl.tree_map(invf, pytree_impl.tree_map(f, pytree)), + pytree, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) cases = [ @@ -579,6 +958,7 @@ def invf(x): for case in cases: run_test(case) +<<<<<<< HEAD @parametrize_pytree_module def test_tree_map_multi_inputs(self, pytree): def run_test(tree): @@ -592,6 +972,29 @@ def f(x, y, z): self.assertEqual( pytree.tree_map(f, tree_x, tree_y, tree_z), pytree.tree_map(lambda x: f(x, (x + 1,), {"a": x * 2, "b": 2}), tree), +======= + @parametrize( + "pytree_impl", + [ + subtest(py_pytree, name="py"), + subtest(cxx_pytree, name="cxx"), + ], + ) + def test_tree_map_multi_inputs(self, pytree_impl): + def run_test(pytree): + def f(x, y, z): + return x, [y, (z, 0)] + + pytree_x = pytree + pytree_y = pytree_impl.tree_map(lambda x: (x + 1,), pytree) + pytree_z = pytree_impl.tree_map(lambda x: {"a": x * 2, "b": 2}, pytree) + + self.assertEqual( + pytree_impl.tree_map(f, pytree_x, pytree_y, pytree_z), + pytree_impl.tree_map( + lambda x: f(x, (x + 1,), {"a": x * 2, "b": 2}), pytree + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) cases = [ @@ -604,6 +1007,7 @@ def f(x, y, z): for case in cases: run_test(case) +<<<<<<< HEAD @parametrize_pytree_module def test_tree_map_only(self, pytree): self.assertEqual(pytree.tree_map_only(int, lambda x: x + 2, [0, "a"]), [2, "a"]) @@ -627,6 +1031,57 @@ def test_tree_all_any(self, pytree): @parametrize_pytree_module def test_broadcast_to_and_flatten(self, pytree): +======= + @parametrize( + "pytree_impl", + [ + subtest(py_pytree, name="py"), + subtest(cxx_pytree, name="cxx"), + ], + ) + def test_tree_map_only(self, pytree_impl): + self.assertEqual( + pytree_impl.tree_map_only(int, lambda x: x + 2, [0, "a"]), [2, "a"] + ) + + @parametrize( + "pytree_impl", + [ + subtest(py_pytree, name="py"), + subtest(cxx_pytree, name="cxx"), + ], + ) + def test_tree_map_only_predicate_fn(self, pytree_impl): + self.assertEqual( + pytree_impl.tree_map_only(lambda x: x == 0, lambda x: x + 2, [0, 1]), [2, 1] + ) + + @parametrize( + "pytree_impl", + [ + subtest(py_pytree, name="py"), + subtest(cxx_pytree, name="cxx"), + ], + ) + def test_tree_all_any(self, pytree_impl): + self.assertTrue(pytree_impl.tree_all(lambda x: x % 2, [1, 3])) + self.assertFalse(pytree_impl.tree_all(lambda x: x % 2, [0, 1])) + self.assertTrue(pytree_impl.tree_any(lambda x: x % 2, [0, 1])) + self.assertFalse(pytree_impl.tree_any(lambda x: x % 2, [0, 2])) + self.assertTrue(pytree_impl.tree_all_only(int, lambda x: x % 2, [1, 3, "a"])) + self.assertFalse(pytree_impl.tree_all_only(int, lambda x: x % 2, [0, 1, "a"])) + self.assertTrue(pytree_impl.tree_any_only(int, lambda x: x % 2, [0, 1, "a"])) + self.assertFalse(pytree_impl.tree_any_only(int, lambda x: x % 2, [0, 2, "a"])) + + @parametrize( + "pytree_impl", + [ + subtest(py_pytree, name="py"), + subtest(cxx_pytree, name="cxx"), + ], + ) + def test_broadcast_to_and_flatten(self, pytree_impl): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cases = [ (1, (), []), # Same (flat) structures @@ -659,6 +1114,7 @@ def test_broadcast_to_and_flatten(self, pytree): ((1, 2), ([0, [0, 0], 0], [0, 0]), [1, 1, 1, 1, 2, 2]), (([1, 2, 3], 4), ([0, [0, 0], 0], [0, 0]), [1, 2, 2, 3, 4, 4]), ] +<<<<<<< HEAD for tree, to_tree, expected in cases: _, to_spec = pytree.tree_flatten(to_tree) result = pytree._broadcast_to_and_flatten(tree, to_spec) @@ -670,6 +1126,31 @@ def test_pytree_serialize_bad_input(self, pytree): pytree.treespec_dumps("random_blurb") @parametrize_pytree_module +======= + for pytree, to_pytree, expected in cases: + _, to_spec = pytree_impl.tree_flatten(to_pytree) + result = pytree_impl._broadcast_to_and_flatten(pytree, to_spec) + self.assertEqual(result, expected, msg=str([pytree, to_spec, expected])) + + @parametrize( + "pytree_impl", + [ + subtest(py_pytree, name="py"), + subtest(cxx_pytree, name="cxx"), + ], + ) + def test_pytree_serialize_bad_input(self, pytree_impl): + with self.assertRaises(TypeError): + pytree_impl.treespec_dumps("random_blurb") + + @parametrize( + "pytree", + [ + subtest(py_pytree, name="py"), + subtest(cxx_pytree, name="cxx"), + ], + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_is_namedtuple(self, pytree): DirectNamedTuple1 = namedtuple("DirectNamedTuple1", ["x", "y"]) @@ -710,7 +1191,17 @@ class IndirectNamedTuple2(DirectNamedTuple2): self.assertFalse(pytree.is_namedtuple_class(tuple)) self.assertFalse(pytree.is_namedtuple_class(list)) +<<<<<<< HEAD @parametrize_pytree_module +======= + @parametrize( + "pytree", + [ + subtest(py_pytree, name="py"), + subtest(cxx_pytree, name="cxx"), + ], + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_is_structseq(self, pytree): class FakeStructSeq(tuple): n_fields = 2 @@ -784,7 +1275,17 @@ class DirectNamedTuple2(NamedTuple): self.assertFalse(pytree.is_namedtuple(cls)) self.assertFalse(pytree.is_namedtuple_class(cls)) +<<<<<<< HEAD @parametrize_pytree_module +======= + @parametrize( + "pytree", + [ + subtest(py_pytree, name="py"), + subtest(cxx_pytree, name="cxx"), + ], + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_enum_treespec_roundtrip(self, pytree): data = {TestEnum.A: 5} spec = pytree.tree_structure(data) @@ -804,14 +1305,22 @@ def __init__(self, x, y): with self.assertWarnsRegex( FutureWarning, "torch.utils._pytree._register_pytree_node" ): +<<<<<<< HEAD python_pytree._register_pytree_node( +======= + py_pytree._register_pytree_node( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DummyType, lambda dummy: ([dummy.x, dummy.y], None), lambda xs, _: DummyType(*xs), ) with self.assertWarnsRegex(UserWarning, "already registered"): +<<<<<<< HEAD python_pytree._register_pytree_node( +======= + py_pytree._register_pytree_node( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DummyType, lambda dummy: ([dummy.x, dummy.y], None), lambda xs, _: DummyType(*xs), @@ -848,6 +1357,7 @@ def test_import_pytree_doesnt_import_optree(self): def test_treespec_equality(self): self.assertEqual( +<<<<<<< HEAD python_pytree.LeafSpec(), python_pytree.LeafSpec(), ) @@ -866,12 +1376,35 @@ def test_treespec_equality(self): self.assertTrue( python_pytree.TreeSpec(tuple, None, []) != python_pytree.TreeSpec(list, None, []), +======= + py_pytree.LeafSpec(), + py_pytree.LeafSpec(), + ) + self.assertEqual( + py_pytree.TreeSpec(list, None, []), + py_pytree.TreeSpec(list, None, []), + ) + self.assertEqual( + py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()]), + py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()]), + ) + self.assertFalse( + py_pytree.TreeSpec(tuple, None, []) == py_pytree.TreeSpec(list, None, []), + ) + self.assertTrue( + py_pytree.TreeSpec(tuple, None, []) != py_pytree.TreeSpec(list, None, []), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def test_treespec_repr(self): # Check that it looks sane +<<<<<<< HEAD tree = (0, [0, 0, [0]]) spec = python_pytree.tree_structure(tree) +======= + pytree = (0, [0, 0, [0]]) + _, spec = py_pytree.tree_flatten(pytree) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual( repr(spec), ( @@ -885,6 +1418,7 @@ def test_treespec_repr(self): @parametrize( "spec", [ +<<<<<<< HEAD # python_pytree.tree_structure([]) python_pytree.TreeSpec(list, None, []), # python_pytree.tree_structure(()) @@ -944,11 +1478,94 @@ def test_treespec_repr(self): list, None, [python_leafspec, python_leafspec], +======= + # py_pytree.tree_structure([]) + py_pytree.TreeSpec(list, None, []), + # py_pytree.tree_structure(()) + py_pytree.TreeSpec(tuple, None, []), + # py_pytree.tree_structure({}) + py_pytree.TreeSpec(dict, [], []), + # py_pytree.tree_structure([0]) + py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()]), + # py_pytree.tree_structure([0, 1]) + py_pytree.TreeSpec( + list, + None, + [ + py_pytree.LeafSpec(), + py_pytree.LeafSpec(), + ], + ), + # py_pytree.tree_structure((0, 1, 2)) + py_pytree.TreeSpec( + tuple, + None, + [ + py_pytree.LeafSpec(), + py_pytree.LeafSpec(), + py_pytree.LeafSpec(), + ], + ), + # py_pytree.tree_structure({"a": 0, "b": 1, "c": 2}) + py_pytree.TreeSpec( + dict, + ["a", "b", "c"], + [ + py_pytree.LeafSpec(), + py_pytree.LeafSpec(), + py_pytree.LeafSpec(), + ], + ), + # py_pytree.tree_structure(OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})]) + py_pytree.TreeSpec( + OrderedDict, + ["a", "b", "c"], + [ + py_pytree.TreeSpec( + tuple, + None, + [ + py_pytree.LeafSpec(), + py_pytree.LeafSpec(), + ], + ), + py_pytree.LeafSpec(), + py_pytree.TreeSpec( + dict, + ["a", "b", "c"], + [ + py_pytree.LeafSpec(), + py_pytree.LeafSpec(), + py_pytree.LeafSpec(), + ], + ), + ], + ), + # py_pytree.tree_structure([(0, 1, [2, 3])]) + py_pytree.TreeSpec( + list, + None, + [ + py_pytree.TreeSpec( + tuple, + None, + [ + py_pytree.LeafSpec(), + py_pytree.LeafSpec(), + py_pytree.TreeSpec( + list, + None, + [ + py_pytree.LeafSpec(), + py_pytree.LeafSpec(), + ], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), ], ), ], ), +<<<<<<< HEAD # python_pytree.tree_structure(defaultdict(list, {"a": [0, 1], "b": [1, 2], "c": {}})) python_pytree.TreeSpec( defaultdict, @@ -965,6 +1582,30 @@ def test_treespec_repr(self): [python_leafspec, python_leafspec], ), python_pytree.TreeSpec(dict, [], []), +======= + # py_pytree.tree_structure(defaultdict(list, {"a": [0, 1], "b": [1, 2], "c": {}})) + py_pytree.TreeSpec( + defaultdict, + [list, ["a", "b", "c"]], + [ + py_pytree.TreeSpec( + list, + None, + [ + py_pytree.LeafSpec(), + py_pytree.LeafSpec(), + ], + ), + py_pytree.TreeSpec( + list, + None, + [ + py_pytree.LeafSpec(), + py_pytree.LeafSpec(), + ], + ), + py_pytree.TreeSpec(dict, [], []), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ], ), ], @@ -973,6 +1614,7 @@ def test_pytree_serialize(self, spec): # Ensure that the spec is valid self.assertEqual( spec, +<<<<<<< HEAD python_pytree.tree_structure( python_pytree.tree_unflatten([0] * spec.num_leaves, spec) ), @@ -992,10 +1634,32 @@ def test_pytree_serialize_defaultdict_enum(self): None, [ python_leafspec, +======= + py_pytree.tree_structure( + py_pytree.tree_unflatten([0] * spec.num_leaves, spec) + ), + ) + + serialized_spec = py_pytree.treespec_dumps(spec) + self.assertIsInstance(serialized_spec, str) + self.assertEqual(spec, py_pytree.treespec_loads(serialized_spec)) + + def test_pytree_serialize_defaultdict_enum(self): + spec = py_pytree.TreeSpec( + defaultdict, + [list, [TestEnum.A]], + [ + py_pytree.TreeSpec( + list, + None, + [ + py_pytree.LeafSpec(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ], ), ], ) +<<<<<<< HEAD serialized_spec = python_pytree.treespec_dumps(spec) self.assertIsInstance(serialized_spec, str) @@ -1003,62 +1667,109 @@ def test_pytree_serialize_enum(self): spec = python_pytree.TreeSpec(dict, TestEnum.A, [python_leafspec]) serialized_spec = python_pytree.treespec_dumps(spec) +======= + serialized_spec = py_pytree.treespec_dumps(spec) + self.assertIsInstance(serialized_spec, str) + + def test_pytree_serialize_enum(self): + spec = py_pytree.TreeSpec(dict, TestEnum.A, [py_pytree.LeafSpec()]) + + serialized_spec = py_pytree.treespec_dumps(spec) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertIsInstance(serialized_spec, str) def test_pytree_serialize_namedtuple(self): Point1 = namedtuple("Point1", ["x", "y"]) +<<<<<<< HEAD python_pytree._register_namedtuple( +======= + py_pytree._register_namedtuple( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Point1, serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point1", ) +<<<<<<< HEAD spec = python_pytree.tree_structure(Point1(1, 2)) self.assertIs(spec.type, namedtuple) roundtrip_spec = python_pytree.treespec_loads( python_pytree.treespec_dumps(spec) ) +======= + spec = py_pytree.tree_structure(Point1(1, 2)) + self.assertIs(spec.type, namedtuple) + roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(spec, roundtrip_spec) class Point2(NamedTuple): x: int y: int +<<<<<<< HEAD python_pytree._register_namedtuple( +======= + py_pytree._register_namedtuple( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Point2, serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point2", ) +<<<<<<< HEAD spec = python_pytree.tree_structure(Point2(1, 2)) self.assertIs(spec.type, namedtuple) roundtrip_spec = python_pytree.treespec_loads( python_pytree.treespec_dumps(spec) ) +======= + spec = py_pytree.tree_structure(Point2(1, 2)) + self.assertIs(spec.type, namedtuple) + roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(spec, roundtrip_spec) class Point3(Point2): pass +<<<<<<< HEAD python_pytree._register_namedtuple( +======= + py_pytree._register_namedtuple( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Point3, serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point3", ) +<<<<<<< HEAD spec = python_pytree.tree_structure(Point3(1, 2)) self.assertIs(spec.type, namedtuple) roundtrip_spec = python_pytree.treespec_loads( python_pytree.treespec_dumps(spec) ) +======= + spec = py_pytree.tree_structure(Point3(1, 2)) + self.assertIs(spec.type, namedtuple) + roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(spec, roundtrip_spec) def test_pytree_serialize_namedtuple_bad(self): DummyType = namedtuple("DummyType", ["x", "y"]) +<<<<<<< HEAD spec = python_pytree.tree_structure(DummyType(1, 2)) +======= + spec = py_pytree.tree_structure(DummyType(1, 2)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with self.assertRaisesRegex( NotImplementedError, "Please register using `_register_namedtuple`" ): +<<<<<<< HEAD python_pytree.treespec_dumps(spec) +======= + py_pytree.treespec_dumps(spec) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_pytree_custom_type_serialize_bad(self): class DummyType: @@ -1066,17 +1777,29 @@ def __init__(self, x, y): self.x = x self.y = y +<<<<<<< HEAD python_pytree.register_pytree_node( +======= + py_pytree.register_pytree_node( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DummyType, lambda dummy: ([dummy.x, dummy.y], None), lambda xs, _: DummyType(*xs), ) +<<<<<<< HEAD spec = python_pytree.tree_structure(DummyType(1, 2)) with self.assertRaisesRegex( NotImplementedError, "No registered serialization name" ): python_pytree.treespec_dumps(spec) +======= + spec = py_pytree.tree_structure(DummyType(1, 2)) + with self.assertRaisesRegex( + NotImplementedError, "No registered serialization name" + ): + py_pytree.treespec_dumps(spec) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_pytree_custom_type_serialize(self): class DummyType: @@ -1084,7 +1807,11 @@ def __init__(self, x, y): self.x = x self.y = y +<<<<<<< HEAD python_pytree.register_pytree_node( +======= + py_pytree.register_pytree_node( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DummyType, lambda dummy: ([dummy.x, dummy.y], None), lambda xs, _: DummyType(*xs), @@ -1092,10 +1819,17 @@ def __init__(self, x, y): to_dumpable_context=lambda context: "moo", from_dumpable_context=lambda dumpable_context: None, ) +<<<<<<< HEAD spec = python_pytree.tree_structure(DummyType(1, 2)) serialized_spec = python_pytree.treespec_dumps(spec, 1) self.assertIn("moo", serialized_spec) roundtrip_spec = python_pytree.treespec_loads(serialized_spec) +======= + spec = py_pytree.tree_structure(DummyType(1, 2)) + serialized_spec = py_pytree.treespec_dumps(spec, 1) + self.assertIn("moo", serialized_spec) + roundtrip_spec = py_pytree.treespec_loads(serialized_spec) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(roundtrip_spec, spec) def test_pytree_serialize_register_bad(self): @@ -1107,7 +1841,11 @@ def __init__(self, x, y): with self.assertRaisesRegex( ValueError, "Both to_dumpable_context and from_dumpable_context" ): +<<<<<<< HEAD python_pytree.register_pytree_node( +======= + py_pytree.register_pytree_node( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DummyType, lambda dummy: ([dummy.x, dummy.y], None), lambda xs, _: DummyType(*xs), @@ -1121,7 +1859,11 @@ def __init__(self, x, y): self.x = x self.y = y +<<<<<<< HEAD python_pytree.register_pytree_node( +======= + py_pytree.register_pytree_node( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DummyType, lambda dummy: ([dummy.x, dummy.y], None), lambda xs, _: DummyType(*xs), @@ -1130,31 +1872,51 @@ def __init__(self, x, y): from_dumpable_context=lambda dumpable_context: None, ) +<<<<<<< HEAD spec = python_pytree.tree_structure(DummyType(1, 2)) +======= + spec = py_pytree.tree_structure(DummyType(1, 2)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with self.assertRaisesRegex( TypeError, "Object of type type is not JSON serializable" ): +<<<<<<< HEAD python_pytree.treespec_dumps(spec) +======= + py_pytree.treespec_dumps(spec) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_pytree_serialize_bad_protocol(self): import json Point = namedtuple("Point", ["x", "y"]) +<<<<<<< HEAD spec = python_pytree.tree_structure(Point(1, 2)) python_pytree._register_namedtuple( +======= + spec = py_pytree.tree_structure(Point(1, 2)) + py_pytree._register_namedtuple( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Point, serialized_type_name="test_pytree.test_pytree_serialize_bad_protocol.Point", ) with self.assertRaisesRegex(ValueError, "Unknown protocol"): +<<<<<<< HEAD python_pytree.treespec_dumps(spec, -1) serialized_spec = python_pytree.treespec_dumps(spec) +======= + py_pytree.treespec_dumps(spec, -1) + + serialized_spec = py_pytree.treespec_dumps(spec) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _, data = json.loads(serialized_spec) bad_protocol_serialized_spec = json.dumps((-1, data)) with self.assertRaisesRegex(ValueError, "Unknown protocol"): +<<<<<<< HEAD python_pytree.treespec_loads(bad_protocol_serialized_spec) def test_saved_serialized(self): @@ -1169,20 +1931,51 @@ def test_saved_serialized(self): dict, [4, 5, 6], [python_leafspec, python_leafspec, python_leafspec], +======= + py_pytree.treespec_loads(bad_protocol_serialized_spec) + + def test_saved_serialized(self): + # py_pytree.tree_structure(OrderedDict([(1, (0, 1)), (2, 2), (3, {4: 3, 5: 4, 6: 5})])) + complicated_spec = py_pytree.TreeSpec( + OrderedDict, + [1, 2, 3], + [ + py_pytree.TreeSpec( + tuple, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()] + ), + py_pytree.LeafSpec(), + py_pytree.TreeSpec( + dict, + [4, 5, 6], + [ + py_pytree.LeafSpec(), + py_pytree.LeafSpec(), + py_pytree.LeafSpec(), + ], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), ], ) # Ensure that the spec is valid self.assertEqual( complicated_spec, +<<<<<<< HEAD python_pytree.tree_structure( python_pytree.tree_unflatten( +======= + py_pytree.tree_structure( + py_pytree.tree_unflatten( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) [0] * complicated_spec.num_leaves, complicated_spec ) ), ) +<<<<<<< HEAD serialized_spec = python_pytree.treespec_dumps(complicated_spec) +======= + serialized_spec = py_pytree.treespec_dumps(complicated_spec) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) saved_spec = ( '[1, {"type": "collections.OrderedDict", "context": "[1, 2, 3]", ' '"children_spec": [{"type": "builtins.tuple", "context": "null", ' @@ -1195,11 +1988,19 @@ def test_saved_serialized(self): '[]}, {"type": null, "context": null, "children_spec": []}]}]}]' ) self.assertEqual(serialized_spec, saved_spec) +<<<<<<< HEAD self.assertEqual(complicated_spec, python_pytree.treespec_loads(saved_spec)) def test_tree_map_with_path(self): tree = [{i: i for i in range(10)}] all_zeros = python_pytree.tree_map_with_path( +======= + self.assertEqual(complicated_spec, py_pytree.treespec_loads(saved_spec)) + + def test_tree_map_with_path(self): + tree = [{i: i for i in range(10)}] + all_zeros = py_pytree.tree_map_with_path( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) lambda kp, val: val - kp[1].key + kp[0].idx, tree ) self.assertEqual(all_zeros, [dict.fromkeys(range(10), 0)]) @@ -1212,14 +2013,22 @@ class Data: c: Optional[str] = None d: str = field(init=False, default="") +<<<<<<< HEAD python_pytree.register_dataclass(Data) old_data = Data(torch.tensor(3), "b", "c") old_data.d = "d" new_data = python_pytree.tree_map(lambda x: x, old_data) +======= + py_pytree.register_dataclass(Data) + old_data = Data(torch.tensor(3), "b", "c") + old_data.d = "d" + new_data = py_pytree.tree_unflatten(*py_pytree.tree_flatten(old_data)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(new_data.a, torch.tensor(3)) self.assertEqual(new_data.b, "b") self.assertEqual(new_data.c, "c") self.assertEqual(new_data.d, "") +<<<<<<< HEAD python_pytree._deregister_pytree_node(Data) with self.assertRaisesRegex(ValueError, "Missing fields"): @@ -1240,6 +2049,28 @@ class Data: self.assertEqual(new_data.b, "moo") self.assertEqual(new_data.c, None) python_pytree._deregister_pytree_node(Data) +======= + py_pytree._deregister_pytree_node(Data) + + with self.assertRaisesRegex(ValueError, "Missing fields"): + py_pytree.register_dataclass(Data, field_names=["a", "b"]) + + with self.assertRaisesRegex(ValueError, "Unexpected fields"): + py_pytree.register_dataclass(Data, field_names=["a", "b", "e"]) + + with self.assertRaisesRegex(ValueError, "Unexpected fields"): + py_pytree.register_dataclass(Data, field_names=["a", "b", "c", "d"]) + + py_pytree.register_dataclass( + Data, field_names=["a"], drop_field_names=["b", "c"] + ) + old_data = Data(torch.tensor(3), "b", "c") + new_data = py_pytree.tree_unflatten(*py_pytree.tree_flatten(old_data)) + self.assertEqual(new_data.a, torch.tensor(3)) + self.assertEqual(new_data.b, "moo") + self.assertEqual(new_data.c, None) + py_pytree._deregister_pytree_node(Data) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_register_dataclass_class(self): class CustomClass: @@ -1248,11 +2079,19 @@ def __init__(self, x, y): self.y = y with self.assertRaisesRegex(ValueError, "field_names must be specified"): +<<<<<<< HEAD python_pytree.register_dataclass(CustomClass) python_pytree.register_dataclass(CustomClass, field_names=["x", "y"]) c = CustomClass(torch.tensor(0), torch.tensor(1)) mapped = python_pytree.tree_map(lambda x: x + 1, c) +======= + py_pytree.register_dataclass(CustomClass) + + py_pytree.register_dataclass(CustomClass, field_names=["x", "y"]) + c = CustomClass(torch.tensor(0), torch.tensor(1)) + mapped = py_pytree.tree_map(lambda x: x + 1, c) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(mapped.x, torch.tensor(1)) self.assertEqual(mapped.y, torch.tensor(2)) @@ -1263,10 +2102,17 @@ def test_constant(self): class Config: norm: str +<<<<<<< HEAD python_pytree.register_constant(Config) config = Config("l1") elements, spec = python_pytree.tree_flatten(config) +======= + py_pytree.register_constant(Config) + + config = Config("l1") + elements, spec = py_pytree.tree_flatten(config) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(elements, []) self.assertEqual(spec.context.value, config) @@ -1276,7 +2122,11 @@ def __init__(self, norm: str): self.norm = norm try: +<<<<<<< HEAD python_pytree.register_constant(Config) +======= + py_pytree.register_constant(Config) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertFalse(True) # must raise error before this except TypeError as e: msg = "register_constant(cls) expects `cls` to have a non-default `__eq__` implementation." @@ -1291,7 +2141,11 @@ def __eq__(self, other): return self.norm == other.norm try: +<<<<<<< HEAD python_pytree.register_constant(Config) +======= + py_pytree.register_constant(Config) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertFalse(True) # must raise error before this except TypeError as e: msg = "register_constant(cls) expects `cls` to have a non-default `__hash__` implementation." @@ -1307,23 +2161,40 @@ class ACustomPytree: tree1 = [ACustomPytree(x=12, y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5] tree2 = [ACustomPytree(x=2, y={"cin": [2, 2, 2], "bar": 2}, z="leaf"), 2] +<<<<<<< HEAD python_pytree.register_pytree_node( +======= + py_pytree.register_pytree_node( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ACustomPytree, flatten_fn=lambda f: ([f.x, f.y], f.z), unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z), flatten_with_keys_fn=lambda f: ((("x", f.x), ("y", f.y)), f.z), ) +<<<<<<< HEAD from_two_trees = python_pytree.tree_map_with_path( lambda kp, a, b: a + b, tree1, tree2 ) from_one_tree = python_pytree.tree_map(lambda a: a + 2, tree1) +======= + from_two_trees = py_pytree.tree_map_with_path( + lambda kp, a, b: a + b, tree1, tree2 + ) + from_one_tree = py_pytree.tree_map(lambda a: a + 2, tree1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(from_two_trees, from_one_tree) def test_tree_flatten_with_path_is_leaf(self): leaf_dict = {"foo": [(3)]} +<<<<<<< HEAD tree = (["hello", [1, 2], leaf_dict],) key_leaves, _ = python_pytree.tree_flatten_with_path( tree, is_leaf=lambda x: isinstance(x, dict) +======= + pytree = (["hello", [1, 2], leaf_dict],) + key_leaves, _ = py_pytree.tree_flatten_with_path( + pytree, is_leaf=lambda x: isinstance(x, dict) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) self.assertTrue(key_leaves[-1][1] is leaf_dict) @@ -1339,7 +2210,11 @@ class ACustomPytree: y: Any z: Any +<<<<<<< HEAD python_pytree.register_pytree_node( +======= + py_pytree.register_pytree_node( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ACustomPytree, flatten_fn=lambda f: ([f.x, f.y], f.z), unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z), @@ -1352,12 +2227,19 @@ class ACustomPytree: [ANamedTuple(x=torch.rand(2, 3), y=1, z="foo")], [ACustomPytree(x=12, y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5], ] +<<<<<<< HEAD for tree in SOME_PYTREES: key_leaves, spec = python_pytree.tree_flatten_with_path(tree) actual = python_pytree.tree_unflatten( [leaf for _, leaf in key_leaves], spec ) self.assertEqual(actual, tree) +======= + for pytree in SOME_PYTREES: + key_leaves, spec = py_pytree.tree_flatten_with_path(pytree) + actual = py_pytree.tree_unflatten([leaf for _, leaf in key_leaves], spec) + self.assertEqual(actual, pytree) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_tree_leaves_with_path(self): class ANamedTuple(NamedTuple): @@ -1371,7 +2253,11 @@ class ACustomPytree: y: Any z: Any +<<<<<<< HEAD python_pytree.register_pytree_node( +======= + py_pytree.register_pytree_node( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ACustomPytree, flatten_fn=lambda f: ([f.x, f.y], f.z), unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z), @@ -1384,9 +2270,15 @@ class ACustomPytree: [ANamedTuple(x=torch.rand(2, 3), y=1, z="foo")], [ACustomPytree(x=12, y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5], ] +<<<<<<< HEAD for tree in SOME_PYTREES: flat_out, _ = python_pytree.tree_flatten_with_path(tree) leaves_out = python_pytree.tree_leaves_with_path(tree) +======= + for pytree in SOME_PYTREES: + flat_out, _ = py_pytree.tree_flatten_with_path(pytree) + leaves_out = py_pytree.tree_leaves_with_path(pytree) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(flat_out, leaves_out) def test_key_str(self): @@ -1395,8 +2287,13 @@ class ANamedTuple(NamedTuple): y: int tree = (["hello", [1, 2], {"foo": [(3)], "bar": [ANamedTuple(x="baz", y=10)]}],) +<<<<<<< HEAD flat, _ = python_pytree.tree_flatten_with_path(tree) paths = [f"{python_pytree.keystr(kp)}: {val}" for kp, val in flat] +======= + flat, _ = py_pytree.tree_flatten_with_path(tree) + paths = [f"{py_pytree.keystr(kp)}: {val}" for kp, val in flat] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual( paths, [ @@ -1411,7 +2308,11 @@ class ANamedTuple(NamedTuple): def test_flatten_flatten_with_key_consistency(self): """Check that flatten and flatten_with_key produces consistent leaves/context.""" +<<<<<<< HEAD reg = python_pytree.SUPPORTED_NODES +======= + reg = py_pytree.SUPPORTED_NODES +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) EXAMPLE_TREE = { list: [1, 2, 3], @@ -1430,8 +2331,13 @@ def test_flatten_flatten_with_key_consistency(self): example = EXAMPLE_TREE.get(typ) if example is None: continue +<<<<<<< HEAD flat_with_path, spec1 = python_pytree.tree_flatten_with_path(example) flat, spec2 = python_pytree.tree_flatten(example) +======= + flat_with_path, spec1 = py_pytree.tree_flatten_with_path(example) + flat, spec2 = py_pytree.tree_flatten(example) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(flat, [x[1] for x in flat_with_path]) self.assertEqual(spec1, spec2) @@ -1442,9 +2348,15 @@ class ANamedTuple(NamedTuple): y: int tree = (["hello", [1, 2], {"foo": [(3)], "bar": [ANamedTuple(x="baz", y=10)]}],) +<<<<<<< HEAD flat, _ = python_pytree.tree_flatten_with_path(tree) for kp, val in flat: self.assertEqual(python_pytree.key_get(tree, kp), val) +======= + flat, _ = py_pytree.tree_flatten_with_path(tree) + for kp, val in flat: + self.assertEqual(py_pytree.key_get(tree, kp), val) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestCxxPytree(TestCase): @@ -1457,8 +2369,13 @@ def test_treespec_equality(self): def test_treespec_repr(self): # Check that it looks sane +<<<<<<< HEAD tree = (0, [0, 0, [0]]) spec = cxx_pytree.tree_structure(tree) +======= + pytree = (0, [0, 0, [0]]) + _, spec = cxx_pytree.tree_flatten(pytree) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual( repr(spec), "PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf, namespace='torch')" ) @@ -1495,7 +2412,11 @@ def test_pytree_serialize(self, spec): self.assertEqual(spec, cxx_pytree.treespec_loads(serialized_spec)) def test_pytree_serialize_namedtuple(self): +<<<<<<< HEAD python_pytree._register_namedtuple( +======= + py_pytree._register_namedtuple( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GlobalPoint, serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.GlobalPoint", ) @@ -1505,7 +2426,11 @@ def test_pytree_serialize_namedtuple(self): self.assertEqual(roundtrip_spec.type._fields, spec.type._fields) LocalPoint = namedtuple("LocalPoint", ["x", "y"]) +<<<<<<< HEAD python_pytree._register_namedtuple( +======= + py_pytree._register_namedtuple( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) LocalPoint, serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.LocalPoint", ) diff --git a/test/test_quantization.py b/test/test_quantization.py index 6d72da3279e1c..2353d9987b3f9 100644 --- a/test/test_quantization.py +++ b/test/test_quantization.py @@ -38,6 +38,16 @@ from quantization.core.test_workflow_module import TestFusedObsFakeQuantModule # noqa: F401 from quantization.core.test_backend_config import TestBackendConfig # noqa: F401 from quantization.core.test_utils import TestUtils # noqa: F401 +<<<<<<< HEAD +======= +log = logging.getLogger(__name__) +try: + # This test has extra data dependencies, so in some environments, e.g. Meta internal + # Buck, it has its own test runner. + from quantization.core.test_docs import TestQuantizationDocs # noqa: F401 +except ImportError as e: + log.warning(e) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Eager Mode Workflow. Tests for the functionality of APIs and different features implemented # using eager mode. @@ -60,7 +70,10 @@ from quantization.eager.test_bias_correction_eager import TestBiasCorrectionEager # noqa: F401 +<<<<<<< HEAD log = logging.getLogger(__name__) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # FX GraphModule Graph Mode Quantization. Tests for the functionality of APIs and different features implemented # using fx quantization. try: diff --git a/test/test_reductions.py b/test/test_reductions.py index f0ec8b434535b..5aba0fe6e803b 100644 --- a/test/test_reductions.py +++ b/test/test_reductions.py @@ -735,7 +735,11 @@ def test_numpy_named_args(self, device): res2 = x1.sum(axis=(0, 2), keepdims=True) self.assertEqual(res1, res2) +<<<<<<< HEAD # TODO: kill this and replace with common creation ops +======= + # TODO: kill this ane replace with common creation ops +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _make_tensors(self, shape, val_range=(-100, 100), use_floating=True, use_integral=True, use_complex=False) -> dict[str, list[torch.Tensor]]: float_types = [torch.double, @@ -1629,7 +1633,11 @@ def test_bucketization(self, device): RuntimeError, "only when boundaries tensor dimension is 1"): torch.searchsorted(boundaries, 1) +<<<<<<< HEAD # incompatible output tensor's dtype +======= + # incompatiable output tensor's dtype +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_output_dtype(dtype, is_int32): output = values_1d.to(dtype) with self.assertRaisesRegex( @@ -2018,7 +2026,11 @@ def test_repeated_dim(self, device): with self.assertRaisesRegex(RuntimeError, error_msg): op(x, dim=dim) +<<<<<<< HEAD # TODO: update this test to compare against NumPy +======= + # TODO: update this test to comapre against NumPy +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyCUDA def test_var(self, device): cpu_tensor = torch.randn(2, 3, 3) @@ -2513,7 +2525,11 @@ def test_median_real_values(self, device, dtype): k = int((t.numel() - 1) / 2) self.assertEqual(res, t.view(-1).sort()[0][k]) if t.numel() % 2 == 1: +<<<<<<< HEAD # We can only test against numpy for odd reductions because numpy +======= + # We can only test agains numpy for odd reductions because numpy +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # returns the mean of the two medians and torch returns the lower self.assertEqual(res.cpu().numpy(), np.median(t_numpy)) for dim in range(t.ndim): @@ -2524,7 +2540,11 @@ def test_median_real_values(self, device, dtype): self.assertEqual(res[0], (t.sort(dim)[0]).select(dim, k).unsqueeze_(dim)) self.assertEqual(res[0], t.gather(dim, res[1])) if size % 2 == 1: +<<<<<<< HEAD # We can only test against numpy for odd reductions because numpy +======= + # We can only test agains numpy for odd reductions because numpy +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # returns the mean of the two medians and torch returns the lower self.assertEqual(res[0].cpu().numpy(), np.median(t_numpy, dim, keepdims=True), exact_dtype=False) @@ -2548,7 +2568,11 @@ def test_median_nan_values(self, device, dtype): k = int((t.numel() - num_nan - 1) / 2) self.assertEqual(res, t.view(-1).sort()[0][k]) if (t.numel() - num_nan) % 2 == 1: +<<<<<<< HEAD # We can only test against numpy for odd reductions because numpy +======= + # We can only test agains numpy for odd reductions because numpy +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # returns the mean of the two medians and torch returns the lower self.assertEqual(res.item(), numpy_op(t.cpu().numpy())) for dim in range(t.ndim): @@ -2561,7 +2585,11 @@ def test_median_nan_values(self, device, dtype): k = ((size - num_nan - 1) / 2).type(torch.long) self.assertEqual(res[0], (t.sort(dim)[0]).gather(dim, k)) self.assertEqual(res[0], t.gather(dim, res[1])) +<<<<<<< HEAD # We can only test against numpy for odd reductions because numpy +======= + # We can only test agains numpy for odd reductions because numpy +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # returns the mean of the two medians and torch returns the lower mask = (size - num_nan) % 2 == 1 res = res[0].masked_select(mask).cpu() @@ -3526,7 +3554,11 @@ def test_tensor_compare_ops_empty(self, device): # raises an error if no `dim` parameter is specified. This exists separately from tests in # test_tensot_compare_ops_empty because not specifying a `dim` parameter in the former tests does # not throw errors. Also, checking the return type of argmax requires supplying a different dtype +<<<<<<< HEAD # argument than that for the input tensor. There is also variation in numpy testing. +======= + # argument than that for the input tensor. There is also variantion in numpy testing. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_tensor_compare_ops_argmax_argmix_kthvalue_dim_empty(self, device): shape = (2, 0, 4) master_input = torch.randn(shape, device=device) diff --git a/test/test_scatter_gather_ops.py b/test/test_scatter_gather_ops.py index ba967c142f1e7..8e2dc48da5356 100644 --- a/test/test_scatter_gather_ops.py +++ b/test/test_scatter_gather_ops.py @@ -380,6 +380,7 @@ def helper(input_size, idx_size): helper([50, 8, 7], 100) helper([50, 3, 4, 5], 100) +<<<<<<< HEAD @dtypes(torch.float32) def test_scatter_add_broadcasted_index_deterministic(self, device, dtype): for d in (0, 1): @@ -397,6 +398,8 @@ def test_scatter_add_broadcasted_index_deterministic(self, device, dtype): self.assertEqual(res, ref) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyCPU @dtypes(torch.float32, torch.float64, torch.bfloat16) def test_gather_expanded_index(self, device, dtype): @@ -456,7 +459,11 @@ def unsqueeze_helper(idx, dim): helper([50, 8, 7], 100) helper([50, 3, 4, 5], 100) +<<<<<<< HEAD # Generic Device Test Framework instantiation, see +======= +# Generic Device Test Framework instantation, see +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests # for details. instantiate_device_type_tests(TestScatterGather, globals()) diff --git a/test/test_schema_check.py b/test/test_schema_check.py index 91d9a484d3c89..829a72925042a 100644 --- a/test/test_schema_check.py +++ b/test/test_schema_check.py @@ -14,12 +14,18 @@ from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.common_device_type import ops, OpDTypes, instantiate_device_type_tests +<<<<<<< HEAD from torch.testing._internal.common_utils import IS_WINDOWS, slowTestIf pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) +======= +pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +sys.path.append(pytorch_test_dir) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def secretly_aliasing(x): return x.view(-1) @@ -496,9 +502,15 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): with SchemaInfoBindTestMode(self) as schemaInfoCheck: x.add(x) +<<<<<<< HEAD class TestSchemaCheckModeOpInfo(JitTestCase): @ops(op_db, dtypes=OpDTypes.supported) @slowTestIf(IS_WINDOWS) +======= + +class TestSchemaCheckModeOpInfo(JitTestCase): + @ops(op_db, dtypes=OpDTypes.supported) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_schema_correctness(self, device, dtype, op): # Currently torch.equal isn't supported with torch.complex32 # There's also errors with complex64 and complex128 diff --git a/test/test_segment_reductions.py b/test/test_segment_reductions.py index 0b269595db211..a5db1e6193726 100644 --- a/test/test_segment_reductions.py +++ b/test/test_segment_reductions.py @@ -558,7 +558,11 @@ def test_unsafe_flag(self, device, dtype): lengths = torch.tensor([0, 2, 3, 0], device=device, dtype=length_type) data = torch.arange(6, dtype=torch.float, device=device) +<<<<<<< HEAD # test for error on 1-D lengths +======= + # test for error on 1-D lenghts +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with self.assertRaisesRegex(RuntimeError, "Expected all rows of lengths along axis"): torch._segment_reduce(data, 'sum', lengths=lengths, axis=0, unsafe=False) diff --git a/test/test_serialization.py b/test/test_serialization.py index 8fa78cb5da4b5..c1b5db838cc07 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -61,7 +61,10 @@ ) from torch.testing._internal.two_tensor import TwoTensor # noqa: F401 from torch.utils._import_utils import import_dill +<<<<<<< HEAD from pickle import UnpicklingError +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not IS_WINDOWS: @@ -747,7 +750,11 @@ def test_serialization_filelike_stress(self): 'readinto() stress test') def test_serialization_filelike_uses_readinto(self): +<<<<<<< HEAD # For maximum efficiency, when reading a file-like object, +======= + # For maximum effiency, when reading a file-like object, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # ensure the C API calls readinto instead of read. a = torch.randn(5, 4) @@ -1357,6 +1364,7 @@ def test_weights_only_error(self, unsafe_global): "file an issue with the following so that we can make `weights_only=True`"): torch.load(f, weights_only=True) +<<<<<<< HEAD def test_weights_only_blocked_func_error_msg(self): import datetime import zoneinfo @@ -1390,6 +1398,8 @@ def test_weights_only_with_zoneinfo_unpickle_registration_success(self): loaded_data = torch.load(f) self.assertEqual(loaded_data, data) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize('weights_only', (False, True)) def test_serialization_math_bits(self, weights_only): t = torch.randn(1, dtype=torch.cfloat) diff --git a/test/test_sort_and_select.py b/test/test_sort_and_select.py index 5be1758186467..29aa8c87305c5 100644 --- a/test/test_sort_and_select.py +++ b/test/test_sort_and_select.py @@ -179,7 +179,11 @@ def test_sort_stable_none(self): def test_complex_unsupported_cpu(self): x = torch.tensor([3.0 + 2j, 4.0 + 3j]) with self.assertRaisesRegex( +<<<<<<< HEAD RuntimeError, " Sort does not support complex dtypes on CPU" +======= + ValueError, "Sort currently does not support complex dtypes on CPU." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): torch.sort(input=x) diff --git a/test/test_sparse.py b/test/test_sparse.py index 727c3a5f6bcdd..b4ee6f36b300c 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -14,19 +14,31 @@ parametrize, subtest, is_coalesced_indices, suppress_warnings, instantiate_parametrized_tests, \ skipIfCrossRef from torch.testing._internal.common_cuda import TEST_CUDA +<<<<<<< HEAD from torch.testing._internal.common_mps import mps_ops_modifier +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from numbers import Number from typing import Any from packaging import version from torch.testing._internal.common_cuda import \ (SM53OrLater, SM80OrLater, TEST_MULTIGPU) from torch.testing._internal.common_device_type import \ +<<<<<<< HEAD (instantiate_device_type_tests, ops, dtypes, dtypesIfCUDA, dtypesIfMPS, onlyCPU, onlyCUDA, precisionOverride, deviceCountAtLeast, OpDTypes, onlyNativeDeviceTypes, skipCUDAIf, expectedFailureMPS, largeTensorTest) from torch.testing._internal.common_methods_invocations import \ (op_db, reduction_ops, sparse_unary_ufuncs, sparse_masked_reduction_ops, binary_ufuncs) from torch.testing._internal.common_dtype import ( all_types, all_types_and_complex, all_mps_types, all_types_and_complex_and, floating_and_complex_types, +======= + (instantiate_device_type_tests, ops, dtypes, dtypesIfCUDA, onlyCPU, onlyCUDA, precisionOverride, + deviceCountAtLeast, OpDTypes, onlyNativeDeviceTypes) +from torch.testing._internal.common_methods_invocations import \ + (op_db, reduction_ops, sparse_unary_ufuncs, sparse_masked_reduction_ops, binary_ufuncs) +from torch.testing._internal.common_dtype import ( + all_types, all_types_and_complex, all_types_and_complex_and, floating_and_complex_types, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) floating_and_complex_types_and, integral_types, floating_types_and, ) from torch.testing._internal.opinfo.definitions.sparse import validate_sample_input_sparse @@ -43,6 +55,10 @@ def _op_supports_any_sparse(op): or op.supports_sparse_bsc) +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) reduction_ops_with_sparse_support = [ op for op in reduction_ops if 'masked.' not in op.name and _op_supports_any_sparse(op) and not isinstance(op, ReductionPythonRefInfo)] @@ -224,12 +240,18 @@ def randn(self, *args, **kwargs): return torch.empty(*args, **kwargs).normal_() @dtypes(torch.double) +<<<<<<< HEAD @dtypesIfMPS(torch.float32) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_print_coalesced(self, device, dtype): self._test_print(device, dtype, True) @dtypes(torch.double) +<<<<<<< HEAD @dtypesIfMPS(torch.float32) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_print_uncoalesced(self, device, dtype): self._test_print(device, dtype, False) @@ -268,7 +290,11 @@ def _test_print(self, device, dtype, coalesced): if values.dtype == torch.double: dtypes.append(torch.float) else: +<<<<<<< HEAD dtypes.append(torch.double if values.device != torch.device("mps:0") else torch.float32) +======= + dtypes.append(torch.double) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for dtype in dtypes: printed.append(f"########## {dtype} ##########") x = sp_tensor.detach().to(dtype) @@ -288,7 +314,10 @@ def _test_print(self, device, dtype, coalesced): @coalescedonoff @dtypes(torch.double, torch.cdouble) +<<<<<<< HEAD @dtypesIfMPS(torch.float32, torch.complex64) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_basic(self, device, dtype, coalesced): def test_shape(sparse_dims, nnz, with_size): if isinstance(with_size, Number): @@ -323,7 +352,10 @@ def test_shape(sparse_dims, nnz, with_size): @coalescedonoff @dtypes(torch.double, torch.cdouble, torch.bfloat16) +<<<<<<< HEAD @dtypesIfMPS(torch.float32, torch.complex64) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @precisionOverride({torch.bfloat16: 1e-2}) def test_coalesce(self, device, dtype, coalesced): @@ -371,6 +403,7 @@ def _test_coalesce(t): t, _, _ = self._gen_sparse(len(sparse_size), nnz, sparse_size + dense_size, dtype, device, coalesced) _test_coalesce(t) # this tests correctness +<<<<<<< HEAD @onlyCUDA @largeTensorTest("30GB", "cuda") @skipCUDAIf(not SM80OrLater and not TEST_WITH_ROCM, "CUDA capability < SM80 and not ROCM") @@ -387,6 +420,9 @@ def test_coalesce_accepts_large_tensor(self, device, dtype): @dtypes(torch.double) @dtypesIfMPS(torch.float32) +======= + @dtypes(torch.double) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/89395") def test_coalesce_reference_cycle(self, device, dtype): # Test coalesce doesn't create autograd graph cycles (gh-52253) @@ -414,7 +450,10 @@ def test_sparse_sum(): self.assertTrue(ref.expired()) @dtypes(torch.double) +<<<<<<< HEAD @dtypesIfMPS(torch.float32) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_ctor_large_sizes(self, device, dtype): # Test that integer overflow is detected when computing numel # of a sparse tensor with large dimensions (gh-57416). Notice @@ -429,7 +468,10 @@ def test_ctor_large_sizes(self, device, dtype): indices, values, (N + 1,) * 4, device=device)) @dtypes(torch.double, torch.cdouble) +<<<<<<< HEAD @dtypesIfMPS(torch.float32, torch.complex64) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_ctor_size_checks(self, device, dtype): indices = self.index_tensor([ [0, 0, 0], @@ -453,7 +495,10 @@ def test_ctor_size_checks(self, device, dtype): RuntimeError, lambda: self.sparse_tensor(indices, values, torch.Size([2, 4, 2, 1]))) +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @coalescedonoff @dtypes(torch.double) def test_ctor_is_coalesced_with_gradcheck(self, device, dtype, coalesced): @@ -479,9 +524,14 @@ def func(indices, values, shape, is_coalesced): "cannot set is_coalesced to true if indices correspond to uncoalesced COO tensor"): torch.autograd.gradcheck(func, (t._indices(), t._values().requires_grad_(True), shape, True)) +<<<<<<< HEAD @expectedFailureMPS @dtypes(*floating_and_complex_types_and(torch.float16, torch.bfloat16)) @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupported triggers assertion error") +======= + @dtypes(*floating_and_complex_types_and(torch.float16, torch.bfloat16)) + @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @gradcheck_semantics() def test_to_dense_with_gradcheck(self, device, dtype, gradcheck): @@ -545,7 +595,10 @@ def fn(x): @coalescedonoff @dtypes(torch.float16, torch.bfloat16, torch.float64, torch.int, torch.cfloat, torch.cdouble) +<<<<<<< HEAD @expectedFailureMPS # unique_dim not implemented for MPS device +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_to_sparse(self, device, dtype, coalesced): shape = [5, 2, 10, 4] max_nnz = 1 @@ -565,7 +618,10 @@ def test_to_sparse(self, device, dtype, coalesced): self.assertEqual(dim, result.sparse_dim()) @dtypes(torch.double, torch.cdouble) +<<<<<<< HEAD @dtypesIfMPS(torch.float32, torch.complex64) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_sparse_bool(self, device, dtype): a = torch.tensor([True, False], dtype=dtype, device=device).to(torch.bool) b = a.to_sparse().to_dense() @@ -573,7 +629,10 @@ def test_sparse_bool(self, device, dtype): @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/108667") @dtypes(torch.double, torch.cdouble) +<<<<<<< HEAD @dtypesIfMPS(torch.float32, torch.complex64) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_scalar(self, device, dtype): # tensor with value a = self.sparse_tensor(self.index_tensor([], device=device).unsqueeze(1), 12.3, [], dtype=dtype, device=device) @@ -604,7 +663,10 @@ def test_scalar(self, device, dtype): self.assertEqual(a, a.to_dense().to_sparse()) @dtypes(torch.double, torch.cdouble) +<<<<<<< HEAD @dtypesIfMPS(torch.float32, torch.complex64) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_shared(self, device, dtype): i = self.index_tensor([[2]], device=device) v = torch.tensor([5], dtype=dtype, device=device) @@ -620,9 +682,14 @@ def test_shared(self, device, dtype): i[0][0] = 0 self.assertEqual(torch.empty((3, 0), dtype=dtype, device=device), self.safeToDense(x)) +<<<<<<< HEAD @expectedFailureMPS @dtypes(torch.double, torch.cdouble) @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupported triggers assertion error") +======= + @dtypes(torch.double, torch.cdouble) + @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @gradcheck_semantics() def test_to_dense_hybrid(self, device, dtype, gradcheck): @@ -670,7 +737,10 @@ def fn(x): test_tensor(x, res) @dtypes(torch.double, torch.cdouble) +<<<<<<< HEAD @dtypesIfMPS(torch.float32, torch.complex64) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_contig(self, device, dtype): def test_tensor(x, exp_i, exp_v): x = x.coalesce() @@ -752,7 +822,10 @@ def test_tensor(x, exp_i, exp_v): test_tensor(x, exp_i, exp_v) @dtypes(torch.double, torch.cdouble) +<<<<<<< HEAD @dtypesIfMPS(torch.float32, torch.complex64) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_contig_hybrid(self, device, dtype): def test_tensor(x, exp_i, exp_v): x = x.coalesce() @@ -840,7 +913,10 @@ def test_tensor(x, exp_i, exp_v): test_tensor(x, exp_i, exp_v) @coalescedonoff +<<<<<<< HEAD @dtypesIfMPS(torch.float32, torch.complex64) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dtypes(torch.double, torch.cdouble) def test_clone(self, device, dtype, coalesced): def test_shape(sparse_dims, nnz, with_size): @@ -859,7 +935,10 @@ def test_shape(sparse_dims, nnz, with_size): test_shape(3, 0, [0, 0, 100, 5, 5, 5, 0]) @coalescedonoff +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dtypes(torch.double, torch.cdouble, torch.bfloat16) @precisionOverride({torch.bfloat16: 2e-2}) def test_Sparse_to_Sparse_copy_(self, device, dtype, coalesced): @@ -962,7 +1041,10 @@ def test_tensor(x): @coalescedonoff @dtypes(torch.double, torch.cdouble) +<<<<<<< HEAD @dtypesIfMPS(torch.float32, torch.complex64) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_transpose(self, device, dtype, coalesced): def test_shape(sparse_dims, nnz, with_size): x = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)[0] @@ -983,8 +1065,12 @@ def test_shape(sparse_dims, nnz, with_size): @coalescedonoff @dtypes(torch.double, torch.cdouble) +<<<<<<< HEAD @expectedFailureMPS @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupported triggers assertion error") +======= + @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @gradcheck_semantics() def test_permute(self, device, dtype, coalesced, gradcheck): # trivial checks @@ -1063,7 +1149,10 @@ def test_shape(di, dj, dk, nnz): @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1166") @dtypes(torch.double, torch.cdouble) +<<<<<<< HEAD @dtypesIfMPS(torch.float32, torch.complex64) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_t_empty(self, device, dtype): def test_in_place(x): shape_original = x.shape @@ -1093,7 +1182,10 @@ def test_not_in_place(x): @coalescedonoff @dtypes(torch.double, torch.cdouble) +<<<<<<< HEAD @dtypesIfMPS(torch.float32, torch.complex64) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_add_zeros(self, device, dtype, coalesced): def test_shape(sparse_dims, nnz, sizes): x, _, _ = self._gen_sparse(sparse_dims, nnz, sizes, dtype, device, coalesced) @@ -1108,7 +1200,10 @@ def test_shape(sparse_dims, nnz, sizes): test_shape(2, 20, [3, 17, 19, 5]) test_shape(2, 20, [3, 17, 19, 0]) +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dtypes(torch.double, torch.cdouble) def test_add_sub_nnz(self, device, dtype): # nnz should not grow unbounded (gh-34964) @@ -1121,7 +1216,10 @@ def test_add_sub_nnz(self, device, dtype): x.sub_(2 * x) self.assertLessEqual(x._nnz(), 10) +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @coalescedonoff @dtypes(torch.double, torch.cdouble) def test_cat(self, device, dtype, coalesced): @@ -1164,7 +1262,10 @@ def test_shapes(shapes, dim, fail_message=None): "Concatenating sparse tensors, but a dense tensor was found at position 1."): torch.cat((sp, dn)) +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @coalescedonoff @dtypes(torch.double, torch.cdouble) def test_unsqueeze(self, device, dtype, coalesced): @@ -1199,7 +1300,10 @@ def test_shape(sparse_dims, nnz, sizes, unsqueeze_dim, fail_message=None): @coalescedonoff @dtypes(torch.double, torch.cdouble) +<<<<<<< HEAD @dtypesIfMPS(torch.float32, torch.complex64) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_select(self, device, dtype, coalesced): def test_shape(sparse_dims, nnz, sizes, select_dim, select_index, fail_message=None): x, _, _ = self._gen_sparse(sparse_dims, nnz, sizes, dtype, device, coalesced) @@ -1245,7 +1349,10 @@ def test_select_no_type_promotion(self, device, dtype): self.assertEqual(t.dtype, t[0, 0].dtype) self.assertEqual(t.dtype, t[1, 1].dtype) +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @coalescedonoff @dtypes(torch.double, torch.cdouble) def test_index_select(self, device, dtype, coalesced): @@ -1281,7 +1388,11 @@ def _test_index_select_exhaustive_index(self, sizes, dims, device, dtype, coales # NOTE: indices are negative idx_dim_d_range = list(range(-sizes[d], 0)) for idx_len in range(sizes[d], sizes[d] + 1): +<<<<<<< HEAD # creates all possible valid indices into dim d of length idx_len +======= + # creates all possible valid indices into dim d of lenght idx_len +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for idx in itertools.product(*itertools.repeat(idx_dim_d_range, idx_len)): t_idx = torch.tensor(idx, dtype=torch.long, device=device) @@ -1298,21 +1409,30 @@ def _test_index_select_exhaustive_index(self, sizes, dims, device, dtype, coales small_sparse_result = t_small_sparse.index_select(d, t_idx) self.assertEqual(small_dense_result, small_sparse_result) +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @coalescedonoff @dtypes(torch.double, torch.cdouble) def test_index_select_exhaustive_index_small(self, device, dtype, coalesced): # will trigger brute-force algo self._test_index_select_exhaustive_index((3, 3, 4), range(3), device, dtype, coalesced) +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @coalescedonoff @dtypes(torch.double, torch.cdouble) def test_index_select_exhaustive_index_large(self, device, dtype, coalesced): # will trigger more sophisticated algos self._test_index_select_exhaustive_index((100, 50, 3, 3), (2, 3), device, dtype, coalesced) +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @coalescedonoff @dtypes(torch.double, torch.cdouble) def test_index_select_empty_and_non_contiguous_index(self, device, dtype, coalesced): @@ -1411,7 +1531,10 @@ def test_shape(di, dj, dk, nnz): "bmm sparse-dense CUDA is not yet supported in Windows, at least up to CUDA 10.1" ) @coalescedonoff +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dtypes(torch.double) def test_bmm(self, device, dtype, coalesced): def test_shape(num_mats, dim_i, dim_j, dim_k, nnz): @@ -1622,7 +1745,10 @@ def test_shape(di, dj, dk, nnz): self.assertEqual(self.safeToDense(res), self.safeToDense(true_result)) @coalescedonoff +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @precisionOverride({torch.bfloat16: 5e-2, torch.float16: 5e-2}) @dtypes(torch.double, torch.cdouble, torch.bfloat16, torch.float16) def test_sparse_addmm(self, device, dtype, coalesced): @@ -1664,9 +1790,14 @@ def fn(S, D1, D2, beta=beta, alpha=alpha): test_shape(7, 8, 9, 20, True, (1, 1)) @coalescedonoff +<<<<<<< HEAD @expectedFailureMPS @dtypes(torch.double) @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupported triggers assertion error") +======= + @dtypes(torch.double) + @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_sparse_mm(self, device, dtype, coalesced): def test_shape(d1, d2, d3, nnz, transposed): if transposed: @@ -1687,9 +1818,14 @@ def fn(S, D): test_shape(7, 8, 9, 20, True) @coalescedonoff +<<<<<<< HEAD @expectedFailureMPS @dtypes(torch.double) @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupported triggers assertion error") +======= + @dtypes(torch.double) + @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @gradcheck_semantics() def test_sparse_mul(self, device, dtype, coalesced, gradcheck): # https://github.com/pytorch/pytorch/issues/79914 @@ -1711,7 +1847,10 @@ def test_shape(sparse_dims, nnz, with_shape): # test_shape(2, 3, [2, 2, 0]) @coalescedonoff +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dtypes(torch.double) def test_dsmm(self, device, dtype, coalesced): def test_shape(di, dj, dk, nnz): @@ -1731,7 +1870,10 @@ def test_shape(di, dj, dk, nnz): test_shape(1000, 100, 0, 20) @coalescedonoff +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dtypes(torch.double) def test_hsmm(self, device, dtype, coalesced): def test_shape(di, dj, dk, nnz): @@ -1751,7 +1893,10 @@ def test_shape(di, dj, dk, nnz): test_shape(1000, 100, 0, 20) @coalescedonoff +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dtypes(torch.double) def test_spadd(self, device, dtype, coalesced): @@ -1839,7 +1984,10 @@ def test_sparse_add_out_bfloat16(self, device, dtype, coalesced): self.assertEqual(res_fp32, res_bf16, atol=1e-2, rtol=0) @coalescedonoff +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dtypes(torch.double, torch.cdouble) def test_norm(self, device, dtype, coalesced): def test_shape(sparse_dims, nnz, with_size): @@ -1868,7 +2016,10 @@ def test_shape(sparse_dims, nnz, with_size): x.norm(**kwargs) @coalescedonoff +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dtypes(torch.double) @unittest.skipIf(TEST_WITH_CROSSREF, "fallback triggers cuda device error") def test_sparse_sum(self, device, dtype, coalesced): @@ -1933,7 +2084,10 @@ def fn(S): S = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)[0] run_tests(S.requires_grad_(True), test_dim) +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _test_basic_ops_shape(self, nnz_x1, nnz_x2, shape_i, shape_v, dtype, device, coalesced): shape = shape_i + (shape_v) x1, _, _ = self._gen_sparse(len(shape_i), nnz_x1, shape, dtype, device, coalesced) @@ -2042,7 +2196,10 @@ def _test_basic_ops_hybrid(): _test_basic_ops_hybrid() @dtypes(torch.double, torch.cdouble) +<<<<<<< HEAD @dtypesIfMPS(torch.float32, torch.complex64) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_add_dense_sparse_mismatch(self, device, dtype): def test_shape(dense_size, sparse_dims_shape, dense_dims_shape, sparse_size): x = torch.zeros(dense_size, dtype=dtype, device=device) @@ -2059,7 +2216,10 @@ def test_shape(dense_size, sparse_dims_shape, dense_dims_shape, sparse_size): @skipIfTorchDynamo("Not a TorchDynamo suitable test") @dtypes(torch.double, torch.cdouble) +<<<<<<< HEAD @dtypesIfMPS(torch.float32, torch.complex64) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_add_noncontiguous(self, device, dtype): indices = self.index_tensor([[1, 2], [0, 2]], device=device) values = torch.tensor([1.], dtype=dtype, device=device).expand(2, 3, 4, 5) @@ -2082,7 +2242,10 @@ def _test_sparse_mask_shape(self, nnz_x1, nnz_x2, shape_i, shape_v, dtype, devic self.assertEqual(self.safeToDense(y2), expected) @coalescedonoff +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dtypes(torch.double, torch.cdouble) def test_sparse_mask(self, device, dtype, coalesced): def _test_sparse_mask_fixed(): @@ -2153,7 +2316,10 @@ def _test_sparse_mask_fixed(): @coalescedonoff @dtypes(torch.double, torch.cdouble) +<<<<<<< HEAD @dtypesIfMPS(torch.float32, torch.complex64) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_sparse_mask_hybrid(self, device, dtype, coalesced): def _test_sparse_mask_hybrid_fixed(): i = self.index_tensor([ @@ -2215,7 +2381,10 @@ def _test_sparse_mask_hybrid_fixed(): self._test_sparse_mask_shape(0, 0, [10, 10, 0], [2, 0], dtype, device, coalesced) @dtypes(torch.double, torch.cdouble) +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skipIfCrossRef def test_sparse_mask_backward(self, device, dtype): from itertools import product, repeat @@ -2250,7 +2419,10 @@ def test_sparse_mask_backward(self, device, dtype): @coalescedonoff @dtypes(torch.double, torch.cdouble) +<<<<<<< HEAD @dtypesIfMPS(torch.float32, torch.complex64) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_zeros(self, device, dtype, coalesced): def _test_zeros(nnzs, shape, out_shape_i, out_shape_v=None): out_shape = out_shape_i + (out_shape_v or []) @@ -2275,7 +2447,10 @@ def test_shape(i_shapes, v_shapes, shape, nnzs): test_shape([2, 3, 4], [0, 4, 5, 6], [2, 3, 0], [9, 12]) @coalescedonoff +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dtypes(torch.double, torch.cdouble) def test_zeros_like(self, device, dtype, coalesced): def _test_zeros_like(nnzs, template_shape_i, template_shape_v=None): @@ -2359,7 +2534,10 @@ def _test_empty_like(self, sparse_tensor, dtype, device, coalesced): result = torch.empty_like(dense_tensor, layout=torch.sparse_coo) @coalescedonoff +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dtypes(torch.double, torch.cdouble) def test_empty_like(self, device, dtype, coalesced): # tests https://github.com/pytorch/pytorch/issues/43699 @@ -2416,7 +2594,10 @@ def _all_narrow_combs(self, shape): yield [dim, start, length] @coalescedonoff +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dtypes(torch.double, torch.cdouble) def test_narrow(self, device, dtype, coalesced): shape = [3, 3, 4, 2] @@ -2459,7 +2640,10 @@ def is_integral(dtype): sparse_tensor.requires_grad_() @coalescedonoff +<<<<<<< HEAD @dtypesIfMPS(*all_mps_types()) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dtypes(*all_types()) def test_log1p(self, device, dtype, coalesced): if coalesced: @@ -2525,7 +2709,10 @@ def _test_neg_negative(self, sparse_tensor): @coalescedonoff @dtypes(torch.double, torch.cdouble) +<<<<<<< HEAD @dtypesIfMPS(torch.float32, torch.complex64) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_neg_negative(self, device, dtype, coalesced): if coalesced: @@ -2607,7 +2794,10 @@ def is_integral(dtype): @coalescedonoff @dtypes(*all_types()) +<<<<<<< HEAD @dtypesIfMPS(*all_mps_types()) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_asin_arcsin(self, device, dtype, coalesced): if coalesced: input_coalesced = torch.sparse_coo_tensor( @@ -2653,7 +2843,10 @@ def test_asin_arcsin(self, device, dtype, coalesced): self._test_asin_arcsin(input_uncoalesced, coalesced) @coalescedonoff +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dtypes(torch.double) def test_mv(self, device, dtype, coalesced): def test_shape(di, dj, dk, nnz): @@ -2681,7 +2874,10 @@ def test_shape(di, dj, dk, nnz): res = x.mv(y) @dtypes(*floating_and_complex_types()) +<<<<<<< HEAD @dtypesIfMPS(torch.float32, torch.bfloat16, torch.complex64) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_sparse_add_coalesce(self, device, dtype): i = self.index_tensor([[1, 2, 1]], device=device) v = torch.tensor([3, 4, 5], dtype=dtype, device=device) @@ -2759,7 +2955,10 @@ def test_new_device_multi_gpu(self): @coalescedonoff @dtypes(torch.double, torch.cdouble) +<<<<<<< HEAD @dtypesIfMPS(torch.float32, torch.complex64) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_new(self, device, dtype, coalesced): def test_shape(sparse_dims, nnz, with_size): x, indices, values = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced) @@ -2816,7 +3015,10 @@ def test_factory(self, device, dtype): self.assertEqual(True, sparse_tensor.requires_grad) @dtypes(torch.double, torch.cdouble) +<<<<<<< HEAD @dtypesIfMPS(torch.float32, torch.complex64) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_factory_size_check(self, device, dtype): indices = self.index_tensor([[1, 2], [0, 2]], device=device) @@ -2871,7 +3073,10 @@ def test_factory_empty_indices(self, device): self.assertEqual(tensor._indices(), expected_indices) @dtypes(torch.double, torch.cdouble) +<<<<<<< HEAD @dtypesIfMPS(torch.float32, torch.complex64) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_factory_nnz(self, device, dtype): indices = self.index_tensor([[0]], device=device) # (sparse_dim, nnz): (1, 1) values = torch.tensor([[1, 1], [1, 1]], dtype=dtype, device=device) # (nnz, ...): (2, 2) @@ -2886,7 +3091,10 @@ def test_factory_nnz(self, device, dtype): torch.sparse_coo_tensor(indices, values, sizes, dtype=dtype, device=device) @dtypes(torch.double, torch.cdouble) +<<<<<<< HEAD @dtypesIfMPS(torch.float32, torch.complex64) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_factory_nnz_zero(self, device, dtype): def test_shape(i_shape, v_shape, size, expected_size): if size: @@ -2908,7 +3116,10 @@ def test_shape(i_shape, v_shape, size, expected_size): test_shape([3, 0], [0, 2, 4, 0], [1, 2, 3, 2, 4, 0], [1, 2, 3, 2, 4, 0]) @dtypes(torch.double, torch.cdouble) +<<<<<<< HEAD @dtypesIfMPS(torch.float32, torch.complex64) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_factory_dense_dim(self, device, dtype): indices = self.index_tensor([[0]], device=device) values = torch.tensor([[[1, 1, 1], [1, 1, 1]]], dtype=dtype, device=device) @@ -3149,7 +3360,10 @@ def _test_resize_shape(self, x_i, x_v, x_size, y_i, y_v, y_size, dtype, device): x_dense.view(-1)[0:x_v_numel].view(x_v)) @dtypes(torch.double, torch.cdouble) +<<<<<<< HEAD @dtypesIfMPS(torch.float32, torch.complex64) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_resize(self, device, dtype): # 1. Expand the size of some dense dimensions [Supported] self._test_resize_shape([1, 1], [1, 2, 3], [2, 2, 3], @@ -3235,7 +3449,10 @@ def test_is_nonzero(self, device): .is_nonzero()) @dtypes(torch.double, torch.cdouble) +<<<<<<< HEAD @dtypesIfMPS(torch.float32, torch.complex64) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_change_tensor_metadata(self, device, dtype): i = self.index_tensor([[0], [1]], device=device) v = torch.tensor([[3, 4, 5]], dtype=dtype, device=device) @@ -3278,7 +3495,10 @@ def test_change_tensor_metadata(self, device, dtype): self.assertEqual(list(t.coalesce().values().size()), [1, 3]) @coalescedonoff +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dtypes(torch.double) def test_pickle(self, device, dtype, coalesced): import pickle @@ -3310,7 +3530,10 @@ def test_pickle(self, device, dtype, coalesced): sp_tensor_loaded = pickle.loads(serialized) self.assertEqual(sp_tensor, sp_tensor_loaded) +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_any(self, device): t = torch.sparse_coo_tensor(torch.tensor(([0, 0], [2, 0])), torch.tensor([False, False]), device=device) t_any = torch.tensor(False) @@ -3328,7 +3551,10 @@ def test_isnan(self, device): self.assertEqual(torch.isnan(t).int(), t_nan.int()) @coalescedonoff +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dtypes(torch.float32, torch.float64) def test_div_rounding_mode(self, device, dtype, coalesced): sparse, _, _ = self._gen_sparse(2, 10, (10, 10), dtype, @@ -3349,13 +3575,19 @@ def test_div_rounding_mode(self, device, dtype, coalesced): torch.div(sparse, -2, rounding_mode=mode, out=actual) self.assertEqual(self.safeToDense(actual), expect) +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_div_by_sparse_error(self, device): self.assertRaisesRegex(RuntimeError, 'Sparse division requires', lambda: torch.tensor(1., device=device).to_sparse() / torch.tensor(1., device=device).to_sparse()) +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_floor_divide_by_sparse_error(self, device): self.assertRaisesRegex(RuntimeError, 'Sparse floor division requires', lambda: torch.tensor(1., device=device).to_sparse() @@ -3368,7 +3600,10 @@ def test_sparse_to_numpy(self, device): self.assertRaises(TypeError, lambda: t.numpy()) @coalescedonoff +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dtypes(torch.double) def test_softmax(self, device, dtype, coalesced): import torch.nn.functional as F @@ -3681,15 +3916,23 @@ def _check_zero_nnz_softmax_op(self, func, ndim, device, dtype): @dtypes(torch.double, torch.float) +<<<<<<< HEAD @expectedFailureMPS @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupported triggers assertion error") +======= + @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_softmax_zero_nnz(self, device, dtype): self._check_zero_nnz_softmax_op(torch.sparse.softmax, 1, device, dtype) self._check_zero_nnz_softmax_op(torch.sparse.softmax, 10, device, dtype) @dtypes(torch.double, torch.float) +<<<<<<< HEAD @expectedFailureMPS @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupported triggers assertion error") +======= + @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_log_softmax_zero_nnz(self, device, dtype): self._check_zero_nnz_softmax_op(torch.sparse.log_softmax, 1, device, dtype) self._check_zero_nnz_softmax_op(torch.sparse.log_softmax, 10, device, dtype) @@ -3697,7 +3940,10 @@ def test_log_softmax_zero_nnz(self, device, dtype): # TODO: Check after why ROCm's cusparseXcsrgemm2Nnz function doesn't return the same nnz value as CUDA @coalescedonoff @dtypes(*floating_and_complex_types()) +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dtypesIfCUDA(*floating_types_and(*[torch.half] if SM53OrLater and not TEST_WITH_ROCM else [], *[torch.bfloat16] if SM80OrLater and not TEST_WITH_ROCM else [], torch.complex64, @@ -3828,7 +4074,10 @@ def assign_to(): self.assertRaises(TypeError, assign_to) +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dtypes(torch.double, torch.cdouble) def test_full_broadcast_to(self, device, dtype): def can_broadcast(s0, s1): @@ -3859,7 +4108,10 @@ def can_broadcast(s0, s1): torch._sparse_broadcast_to(s, s1) @coalescedonoff +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dtypes(torch.double, torch.cdouble) def test_sparse_broadcast_to(self, device, dtype, coalesced): def test(sparse_dims, nnz, with_size, new_size): @@ -3889,7 +4141,10 @@ def _test_mul_skips(self, device, dtype, coalesced): self.skipTest(f"Test with dtype={dtype}, device={device} runs only with coalesced inputs") @coalescedonoff +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # NOTE: addcmul_out is not implemented for bool. @dtypes(*all_types_and_complex_and(torch.bfloat16, torch.float16)) @precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2}) @@ -3941,7 +4196,10 @@ def check_empty(sparse_shape, nnz, dense_shape, coalesce): # check_autograd(x, y) @coalescedonoff +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16)) @precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2}) def test_sparse_dense_mul(self, device, dtype, coalesced): @@ -4073,11 +4331,19 @@ def valid_cases(): # some normal cases yield (make_diags((1, 5)), make_offsets([0]), (5, 5)) yield (make_diags((3, 3)), make_offsets([-1, 0, 1]), (4, 4)) +<<<<<<< HEAD # non-contiguous diags yield (make_diags((5, 4), noncontiguous=True), make_offsets([-1, 1, 0, 2, -2]), (5, 5)) # non-contiguous offsets yield (make_diags((3, 4)), make_offsets([1, -1, 0, -2, 2])[::2], (5, 5)) # non-contiguous diags + offsets +======= + # noncontigous diags + yield (make_diags((5, 4), noncontiguous=True), make_offsets([-1, 1, 0, 2, -2]), (5, 5)) + # noncontigous offsets + yield (make_diags((3, 4)), make_offsets([1, -1, 0, -2, 2])[::2], (5, 5)) + # noncontigous diags + offsets +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) yield (make_diags((3, 4), noncontiguous=True), make_offsets([1, -1, 0, -2, 2])[::2], (5, 5)) # correct dimensionality, 2d, 2d , and shapes match, but the number of diagonals is zero yield (make_diags((0, 3)), make_offsets([]), (3, 3)) @@ -4127,7 +4393,10 @@ def test_small_nnz_coalesced(self): self.assertFalse(torch.sparse_coo_tensor([[0, 1], [0, 1]], [1, 2], (2, 2)).is_coalesced()) @coalescedonoff +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dtypes(*all_types_and_complex_and(torch.bool)) def test_sum(self, device, dtype, coalesced): def run_test(shape, nnz): @@ -4201,7 +4470,11 @@ def _sparse_to_dense(tensor): return tensor.to(torch.int8).to_dense().to(torch.bool) +<<<<<<< HEAD _sparse_unary_ops = ops(mps_ops_modifier(sparse_unary_ufuncs, sparse=True), dtypes=OpDTypes.supported, +======= +_sparse_unary_ops = ops(sparse_unary_ufuncs, dtypes=OpDTypes.supported, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) allowed_dtypes=all_types_and_complex()) class TestSparseUnaryUfuncs(TestCase): exact_dtype = True @@ -4253,8 +4526,13 @@ def test_inplace(self, device, dtype, op): @_sparse_unary_ops def test_sparse_zero_dims(self, device, dtype, op): # test 0x0 sparse_coo_tensor +<<<<<<< HEAD indices = torch.empty(2, 0, dtype=torch.int64, device=device) values = torch.empty(0, dtype=dtype, device=device) +======= + indices = torch.empty(2, 0, dtype=torch.int64) + values = torch.empty(0, dtype=dtype) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sparse_0x0 = torch.sparse_coo_tensor(indices, values, (0, 0)) expected = torch.sparse_coo_tensor(indices, op(values), (0, 0)) actual = op(sparse_0x0) @@ -4713,7 +4991,11 @@ def create_invalid_tensor(check_invariants=None): # However, invariants check can be disabled via # constructor's optional argument so that the invalid +<<<<<<< HEAD # tensor is successfully constructed: +======= + # tensor is succesfully constructed: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r = create_invalid_tensor(check_invariants=False) self.assertEqual(r.layout, layout) @@ -4735,7 +5017,11 @@ def create_invalid_tensor(check_invariants=None): self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled()) self.assertFalse(torch.sparse.check_sparse_tensor_invariants.is_enabled()) +<<<<<<< HEAD # Test an attempt to reuse an activate context manager instance +======= + # Test an attempt to re-use an activate context manager instance +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) check_ctx2 = torch.sparse.check_sparse_tensor_invariants(True) with check_ctx: self.assertTrue(torch.sparse.check_sparse_tensor_invariants.is_enabled()) @@ -5601,12 +5887,20 @@ def generic_constructor(*args, **kwargs): # e.g., TestSparseUnaryUfuncsCPU and TestSparseUnaryUfuncsCUDA +<<<<<<< HEAD instantiate_device_type_tests(TestSparseUnaryUfuncs, globals(), allow_mps=True, except_for='meta') +======= +instantiate_device_type_tests(TestSparseUnaryUfuncs, globals(), except_for='meta') +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) instantiate_device_type_tests(TestSparseMaskedReductions, globals(), except_for='meta') # e.g., TestSparseCPU and TestSparseCUDA +<<<<<<< HEAD instantiate_device_type_tests(TestSparse, globals(), allow_mps=True, except_for='meta') +======= +instantiate_device_type_tests(TestSparse, globals(), except_for='meta') +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) instantiate_device_type_tests(TestSparseAny, globals(), except_for='meta') diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py index 8fb490e1b5bc7..e027196b544b1 100644 --- a/test/test_sparse_csr.py +++ b/test/test_sparse_csr.py @@ -2791,7 +2791,11 @@ def test_autograd_sparse_csr_unary(self, device, dtype, op): raise ValueError("Expected at least one 2D tensor in samples.") for sample in samples: +<<<<<<< HEAD # We must skip samples of low dimensionality, we can't convert them to sparsed compressed layouts +======= + # We must skip samples of low dimensionality, we can't covert them to sparsed compressed layouts +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if sample.input.ndim < 2: continue sparse_input = sample.input.to_sparse_csr().requires_grad_(True) @@ -3255,7 +3259,11 @@ def test_dense_to_from_sparse_compressed(self, device, hybrid, batched, layout): # helpers def _check_against_scipy_matrix(pt_matrix, dense, blocksize, **kwargs): +<<<<<<< HEAD # scipy has no bsc layout, so we check against the bsr layout of the transposed dense +======= + # scipy has no bsc layout, so we check against the bsr layout of the tranposed dense +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if layout == torch.sparse_bsc: sp_matrix = self._construct_sp_matrix(dense.t(), layout=torch.sparse_bsr, blocksize=blocksize[::-1]) else: @@ -3272,7 +3280,11 @@ def _check_against_scipy_matrix(pt_matrix, dense, blocksize, **kwargs): self.assertEqual(torch.tensor(sp_matrix.indptr, dtype=torch.int64), compressed_indices_mth(pt_matrix)) self.assertEqual(torch.tensor(sp_matrix.indices, dtype=torch.int64), plain_indices_mth(pt_matrix)) if layout == torch.sparse_bsc: +<<<<<<< HEAD # we must transpose the blocks before comparing +======= + # we must tranpose the blocks before comparing +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(torch.tensor(sp_matrix.data), pt_matrix.values().transpose(-2, -1)) else: self.assertEqual(torch.tensor(sp_matrix.data), pt_matrix.values()) @@ -3371,7 +3383,11 @@ def _generate_subject(sparse_shape, batch_shape, hybrid_shape): # special cases for batched tensors if batched: +<<<<<<< HEAD # batched sparse tensors need only have the same number of non-zeros in each batch not necessarily the +======= + # batched sparse tensors need only have the same number of non-zeros in each batch not nessesarily the +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # same sparsity pattern in each batch sparse_shape = sparse_sizes[0] hybrid_shape = hybrid_sizes[0] @@ -3382,7 +3398,11 @@ def _generate_subject(sparse_shape, batch_shape, hybrid_shape): # number of elements/blocks in each batch (total not nnz) batch_mask_shape = sparse_shape if layout in blocked_layouts: +<<<<<<< HEAD # if we are blocked the mask is generated for the block valued elements +======= + # if we are blocked the mask is genereated for the block valued elemetns +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) batch_mask_shape = sparse_shape[0] // blocksize[0], sparse_shape[1] // blocksize[1] # random bool vector w/ length equal to max possible nnz for the sparse_shape @@ -3603,8 +3623,13 @@ def test_triton_bsr_softmax(self, device, dtype): @onlyCUDA @dtypes(torch.half, torch.bfloat16, torch.float) @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float) +<<<<<<< HEAD @unittest.skipIf((not TEST_WITH_TORCHINDUCTOR) or (IS_FBCODE and IS_REMOTE_GPU), "Skipped for internal with remote GPUs") +======= + @unittest.skipIf((not TEST_WITH_TORCHINDUCTOR) or (IS_FBCODE and IS_REMOTE_GPU) or torch._running_with_deploy(), + "Skipped for deploy and internal with remote GPUs") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size): from functools import partial from torch.sparse._triton_ops import bsr_dense_mm @@ -3680,8 +3705,13 @@ def kernel_impl(*args, **kwargs): @onlyCUDA @dtypes(torch.half) +<<<<<<< HEAD @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Skipped for internal with remote GPUs") +======= + @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU or torch._running_with_deploy(), + "Skipped for deploy and internal with remote GPUs") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_triton_bsr_dense_bmm_error_messages(self, device, dtype): from torch.sparse._triton_ops import bsr_dense_mm @@ -3815,7 +3845,11 @@ def test_triton_sampled_addmm(self, device, dtype, block_size): input_broadcasted_clone.col_indices(), # For testing `out=` let's make values to have "weird" strides # so that if the kernel modifies values to it's needs, the result +<<<<<<< HEAD # is being copied into out.values. +======= + # is being compied into out.values. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) input_broadcasted_clone.values().transpose(-3, -2).contiguous().transpose(-3, -2), layout=input_broadcasted_clone.layout, size=input_broadcasted_clone.shape @@ -3930,7 +3964,11 @@ def test_triton_bsr_scatter_mm(self, device, dtype, blocksize): try: result = bsr_scatter_mm(bsr, dense, indices_data=indices_data) except triton.compiler.OutOfResources: +<<<<<<< HEAD # ensure that there was at least one successful test: +======= + # ensure that there was at least one succesful test: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert SPLIT_N < SPLIT_N_list[0] break diff --git a/test/test_sparse_semi_structured.py b/test/test_sparse_semi_structured.py index 51fb4aa48c221..1c53ef5b5dcb5 100644 --- a/test/test_sparse_semi_structured.py +++ b/test/test_sparse_semi_structured.py @@ -714,18 +714,28 @@ def test_pack_both_ways_id(self, dtype) -> None: max_diff = (ref_gemm - pack_gemm).abs().argmax() torch.testing.assert_close( ref_gemm, pack_gemm, +<<<<<<< HEAD **atol_rtol_kw[dtype], msg=f"packed is wrong at pos: ({max_diff // N}, {max_diff % N})", ) +======= + **atol_rtol_kw[dtype] + ), f"packed is wrong at pos: ({max_diff // N}, {max_diff % N})" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Test A.t@B pack_gemm = torch._sparse_semi_structured_linear(b.t(), packed_t, meta_t) max_diff = (ref_gemm - pack_gemm).abs().argmax() torch.testing.assert_close( ref_gemm, pack_gemm, +<<<<<<< HEAD **atol_rtol_kw[dtype], msg=f"packed_t is wrong at pos: ({max_diff // N}, {max_diff % N})", ) +======= + **atol_rtol_kw[dtype] + ), f"packed_t is wrong at pos: ({max_diff // N}, {max_diff % N})" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @training_dtypes @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") diff --git a/test/test_stateless.py b/test/test_stateless.py index d24194ed460e9..70623c69e2945 100644 --- a/test/test_stateless.py +++ b/test/test_stateless.py @@ -210,7 +210,11 @@ def test_circular_references(self, functional_call): prev_buffer = module.buffer.clone() res = functional_call(module, parameters, x, tie_weights=False) self.assertEqual(x, res) +<<<<<<< HEAD # check that the weights remain unmodified and were correctly accessed +======= + # check that the weights remain unmodified and were correctly accesed +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cur_weight = module.l1.weight cur_buffer = module.buffer self.assertEqual(cur_weight, prev_weight) @@ -753,7 +757,11 @@ def test_functional_call_tuple_dicts(self): res = torch.func.functional_call(mod, (), x) self.assertEqual(res, mod(x)) +<<<<<<< HEAD # three dictionaries +======= + # three dictonaries +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) a = ({'l1.weight': torch.ones(1, 1)}, {'l1.bias': torch.ones(1)}, {'buffer': torch.zeros(1)}) res = torch.func.functional_call(mod, a, x) self.assertEqual(res, x + 1) diff --git a/test/test_sympy_utils.py b/test/test_sympy_utils.py index a7bcd04ce14e7..9d50dba56a0b7 100644 --- a/test/test_sympy_utils.py +++ b/test/test_sympy_utils.py @@ -5,7 +5,11 @@ import math import pickle import sys +<<<<<<< HEAD from collections.abc import Callable +======= +from typing import Callable +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import sympy @@ -24,7 +28,10 @@ FloorDiv, Identity, OpaqueUnaryFn_cos, +<<<<<<< HEAD BitwiseFn_bitwise_and, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) simple_floordiv_gcd, ) from torch.utils._sympy.interp import sympy_interp @@ -424,10 +431,17 @@ def test_interp(self, fn): sargs = [sympy.sympify(a) for a in args] sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols) ref_r = getattr(ReferenceAnalysis, fn)(*sargs) +<<<<<<< HEAD # Yes, I know this is a long-winded way of saying xreplace; the # point is to test sympy_interp r = sympy_interp( ReferenceAnalysis, dict(zip(symbols, sargs, strict=False)), sympy_expr +======= + # Yes, I know this is a longwinded way of saying xreplace; the + # point is to test sympy_interp + r = sympy_interp( + ReferenceAnalysis, dict(zip(symbols, sargs)), sympy_expr +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) self.assertEqual(ref_r, r) @@ -502,7 +516,11 @@ def trace_f(px, py): self.assertEqual( sympy_interp( +<<<<<<< HEAD PythonReferenceAnalysis, dict(zip(symbols, args, strict=False)), sympy_expr +======= + PythonReferenceAnalysis, dict(zip(symbols, args)), sympy_expr +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), gm(*args), ) @@ -556,7 +574,11 @@ def test_tensor_interp(self, fn): direct_result = tensor_fn(*tensor_args) interp_result = sympy_interp( TensorReferenceAnalysis, +<<<<<<< HEAD dict(zip(symbols, tensor_args, strict=False)), +======= + dict(zip(symbols, tensor_args)), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sympy_expr, ) @@ -874,10 +896,13 @@ def test_pickle(self): r = pickle.loads(pickle.dumps(x)) self.assertEqual(x, r) +<<<<<<< HEAD x = BitwiseFn_bitwise_and(sympy.Symbol("a"), sympy.Symbol("b")) r = pickle.loads(pickle.dumps(x)) self.assertEqual(x, r) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestSingletonInt(TestCase): def test_basic(self): diff --git a/test/test_tensor_creation_ops.py b/test/test_tensor_creation_ops.py index 15c04b8154c3a..e2982e973e475 100644 --- a/test/test_tensor_creation_ops.py +++ b/test/test_tensor_creation_ops.py @@ -1531,7 +1531,11 @@ def test_combinations(self, device): expected = torch.empty(0, 5, dtype=a.dtype, device=device) self.assertEqual(c, expected) +<<<<<<< HEAD # test empty input +======= + # test empty imput +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) a = torch.empty(0, device=device) c1 = torch.combinations(a) c2 = torch.combinations(a, with_replacement=True) @@ -1965,11 +1969,14 @@ def test_zeros(self, device): expected = torch.tensor([[0., 0.], [0., 0.]], device=device, dtype=torch.complex32) self.assertEqual(complexHalfTensor, expected) +<<<<<<< HEAD def test_zeros_bounds_checking(self, device): # Test negative large integer with self.assertRaisesRegex(RuntimeError, r"zeros: Dimension size must be non-negative."): torch.zeros(-6744789213055875072, device=device) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO: this test should be updated def test_zeros_out(self, device): shape = (3, 4) diff --git a/test/test_tensorexpr.py b/test/test_tensorexpr.py index 17d3a58535d65..07b55f334394f 100644 --- a/test/test_tensorexpr.py +++ b/test/test_tensorexpr.py @@ -695,12 +695,20 @@ def test_type_as(x, y): _atol = 2e-3 _rtol = 1e-5 if data_type is torch.bfloat16: +<<<<<<< HEAD # Compared to aten logic, NNC could save additional BF16/Fp32 conversion. +======= + # Compared to aten logic, NNC coudl save addtional BF16/Fp32 conversion. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Take d = a + b - c as an example, the aten logic is as follows at # operator level: # tmp = to_bf16(to_fp32(a) + to_fp32(b)) # d = to_bf16(to_fp32(tmp) + to_fp32(c)) +<<<<<<< HEAD # But NNC could fuse the compression and remove the redundant conversions. +======= + # But NNC could fuse the compression and remove the redudant conversions. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # The final statement is as follows # d = to_bf16(to_fp32(a) + to_fp32(b) + to_fp32(c)) # Hence, we simulate NNC computation by feeding fp32 tensors and converting diff --git a/test/test_testing.py b/test/test_testing.py index 00fb106ac2ab6..cbf03d73aeb71 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -2351,7 +2351,11 @@ def _check_python_output(cls, program) -> str: # fail, so just set CWD to this script's directory cwd=os.path.dirname(os.path.realpath(__file__)),).decode("utf-8") +<<<<<<< HEAD # The test is flaky on ROCm/XPU and has been open and close multiple times +======= + # The test is flaky on ROCm and has been open and close multiple times +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # https://github.com/pytorch/pytorch/issues/110040 @skipIfRocm def test_circular_dependencies(self) -> None: @@ -2369,7 +2373,10 @@ def test_circular_dependencies(self) -> None: "torch.distributed.benchmarks", # depends on RPC and DDP Optim "torch.distributed.examples", # requires CUDA and torchvision "torch.distributed.tensor.examples", # example scripts +<<<<<<< HEAD "torch.distributed._tools.sac_ilp", # depends on pulp +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "torch.csrc", # files here are devtools, not part of torch "torch.include", # torch include files after install ] diff --git a/test/test_torch.py b/test/test_torch.py index a6c265c309a2a..68635dd23685e 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -43,8 +43,12 @@ skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName, wrapDeterministicFlagAPITest, DeterministicGuard, CudaSyncGuard, bytes_to_scalar, parametrize, skipIfMPS, noncontiguous_like, +<<<<<<< HEAD AlwaysWarnTypedStorageRemoval, TEST_WITH_TORCHDYNAMO, xfailIfTorchDynamo, xfailIfS390X, set_warn_always_context) +======= + AlwaysWarnTypedStorageRemoval, TEST_WITH_TORCHDYNAMO, xfailIfTorchDynamo, set_warn_always_context) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from multiprocessing.reduction import ForkingPickler from torch.testing._internal.common_device_type import ( expectedFailureMeta, @@ -59,14 +63,21 @@ from torch.testing._internal.common_cuda import ( tf32_on_and_off, TEST_CUDNN, TEST_MULTIGPU, _create_scaling_case, _create_scaling_models_optimizers) +<<<<<<< HEAD from torch.testing._internal.common_mkldnn import reduced_f32_on_and_off +======= +from torch.testing._internal.common_mkldnn import bf32_on_and_off +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_dtype import ( floating_types_and, get_all_math_dtypes, all_types_and_complex_and, complex_types, all_types_and, floating_types, floating_and_complex_types, integral_types_and, get_all_qint_dtypes, all_types_complex_float8_and, ) from torch.testing._internal.two_tensor import TwoTensor +<<<<<<< HEAD from torch.testing._internal.common_utils import IS_WINDOWS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if TEST_WITH_TORCHINDUCTOR: from torch._inductor.test_case import TestCase @@ -159,7 +170,10 @@ def test_constants(self, device): self.assertEqual(torch.inf, math.inf) @onlyNativeDeviceTypes +<<<<<<< HEAD @slowTestIf(IS_WINDOWS) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64, torch.bool, torch.float32, torch.complex64, torch.float64, torch.complex128, torch.uint16, torch.uint32, torch.uint64) @@ -192,7 +206,10 @@ def test_int64_upsample3d(self, device, dtype): @dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64, torch.bool, torch.float32, torch.complex64, torch.float64, torch.complex128, torch.uint16, torch.uint32, torch.uint64) +<<<<<<< HEAD @slowTestIf(IS_WINDOWS) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_storage(self, device, dtype): v = make_tensor((3, 5), dtype=dtype, device=device, low=-9, high=9) self.assertEqual(v.storage()[0], v[0][0]) @@ -223,7 +240,10 @@ def test_storage(self, device, dtype): torch.bool, torch.float32, torch.complex64, torch.float64, torch.complex128, torch.quint8, torch.qint8, torch.qint32, torch.quint4x2) +<<<<<<< HEAD @slowTestIf(IS_WINDOWS) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_storage_setitem(self, device, dtype): # Skip quantized dtypes for CUDA, since they're not supported if torch.device(device).type == 'cuda': @@ -255,7 +275,14 @@ def test_storage_setitem(self, device, dtype): @skipIfTorchDynamo("Not a suitable test for TorchDynamo") @onlyNativeDeviceTypes +<<<<<<< HEAD @slowTestIf(IS_WINDOWS) +======= + @unittest.skipIf( + "RelWithAssert" in torch.__config__.show(), + "failing in debug build, see https://github.com/pytorch/pytorch/pull/156731 for example", + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_storage_use_count(self, device): a = torch.randn(10, device=device) prev_cf = torch._C._storage_Use_Count(a.untyped_storage()._cdata) @@ -266,7 +293,10 @@ def test_storage_use_count(self, device): @xfailIfTorchDynamo @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) +<<<<<<< HEAD @slowTestIf(IS_WINDOWS) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_tensor_storage_type(self, device, dtype): a = make_tensor((10,), dtype=dtype, device=device, low=-9, high=9) @@ -277,7 +307,10 @@ def test_tensor_storage_type(self, device, dtype): @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16, torch.uint16, torch.uint32, torch.uint64)) +<<<<<<< HEAD @slowTestIf(IS_WINDOWS) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_tensor_from_storage(self, device, dtype): a = make_tensor((4, 5, 3), dtype=dtype, device=device, low=-9, high=9) a_s = a.storage() @@ -295,7 +328,10 @@ def test_tensor_from_storage(self, device, dtype): @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) +<<<<<<< HEAD @slowTestIf(IS_WINDOWS) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_set_storage(self, device, dtype): a = make_tensor((4, 5, 3), dtype=dtype, device=device, low=-9, high=9) a_s = a.storage() @@ -334,7 +370,10 @@ def _check_storage_meta(self, s, s_check): @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) +<<<<<<< HEAD @slowTestIf(IS_WINDOWS) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_typed_storage_meta(self, device, dtype): args_list = [ [], @@ -348,7 +387,10 @@ def test_typed_storage_meta(self, device, dtype): self._check_storage_meta(s, s_check) @onlyNativeDeviceTypes +<<<<<<< HEAD @slowTestIf(IS_WINDOWS) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_untyped_storage_meta(self, device): args_list = [ [], @@ -363,7 +405,10 @@ def test_untyped_storage_meta(self, device): @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) +<<<<<<< HEAD @slowTestIf(IS_WINDOWS) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_storage_meta_from_tensor(self, device, dtype): t_check = make_tensor((4, 5, 3), dtype=dtype, device=device, low=-9, high=9) t = t_check.to('meta') @@ -373,7 +418,10 @@ def test_storage_meta_from_tensor(self, device, dtype): self._check_storage_meta(s, s_check) @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) +<<<<<<< HEAD @slowTestIf(IS_WINDOWS) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_storage_meta_errors(self, device, dtype): s0 = torch.TypedStorage([1, 2, 3, 4], device='meta', dtype=dtype) @@ -414,7 +462,10 @@ def test_storage_meta_errors(self, device, dtype): @onlyCPU @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) +<<<<<<< HEAD @slowTestIf(IS_WINDOWS) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_storage_meta_ok(self, device, dtype): s0 = torch.TypedStorage([1, 2, 3, 4], device='meta', dtype=dtype) @@ -430,7 +481,10 @@ def test_module_share_memory(self): model.share_memory() @dtypes(torch.float32, torch.complex64) +<<<<<<< HEAD @slowTestIf(IS_WINDOWS) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_deepcopy(self, device, dtype): from copy import deepcopy a = torch.randn(5, 5, dtype=dtype, device=device) @@ -458,7 +512,10 @@ def test_deepcopy(self, device, dtype): self.assertEqual(deepcopy(a).foo, 3) @dtypes(torch.float32, torch.complex64) +<<<<<<< HEAD @slowTestIf(IS_WINDOWS) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_deepcopy_scalar(self, device, dtype): from copy import deepcopy a = torch.tensor(5, dtype=dtype, device=device) @@ -1106,7 +1163,11 @@ def test_broadcast(self, fn, device): small2_expanded = small2.expand(*dims_full) if small.is_cuda and fn in ['map', 'map2']: +<<<<<<< HEAD # map and map2 are not implemented on CUDA tensors +======= + # map and map2 are not implementd on CUDA tensors +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return if hasattr(large_expanded, fn): @@ -1561,6 +1622,7 @@ def test_nondeterministic_alert_interpolate_bilinear(self, device): 'upsample_bilinear2d_backward_out_cuda', torch.device(device).type == 'cuda') +<<<<<<< HEAD def test_no_nondeterministic_alert_interpolate_bilinear(self, device): input = torch.randn(1, 2, 4, 4, device=device, requires_grad=True) @@ -1595,6 +1657,8 @@ def fn(): 'upsample_trilinear3d_backward_out_cuda', False) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skipIfTorchInductor("aot-autograd issue") def test_deterministic_replication_pad2d(self, device): test_cases = [ @@ -2162,11 +2226,16 @@ def _cond_fn(x): ind_cpu = ind.cpu() repeats = torch.full((1,), 2, device=device) mask = torch.randint(2, (size,), device=device, dtype=bool) +<<<<<<< HEAD mask_cpu = mask.cpu() expect_no_sync = (lambda: _ind_put_fn(x, mask, 1.), lambda: _ind_put_fn(x, mask_cpu, y), lambda: _ind_put_fn(x, ind, y), lambda: _ind_get_fn(x, mask_cpu), +======= + expect_no_sync = (lambda: _ind_put_fn(x, mask, 1.), + lambda: _ind_put_fn(x, ind, y), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) lambda: _ind_get_fn(x, ind), lambda: torch.nn.functional.one_hot(ind, num_classes=size), lambda: torch.randperm(20000, device=device), @@ -2336,7 +2405,11 @@ def test_corrcoef(self, device, dtype): for x in self._generate_correlation_tensors(device, dtype): res = torch.corrcoef(x) ref = np.corrcoef(x.cpu().numpy()) +<<<<<<< HEAD self.assertEqual(res, ref, atol=1e-04, rtol=1e-03, exact_dtype=False) +======= + self.assertEqual(res, ref, exact_dtype=False) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skipRocmIfTorchInductor @dtypes(torch.int, torch.float, torch.cfloat) @@ -2572,7 +2645,11 @@ def test_cdist_cuda_backward(self, device): self.assertEqual(y1.grad, y2.grad, rtol=0, atol=0.001) @tf32_on_and_off(0.05 if TEST_WITH_ROCM else 0.005) +<<<<<<< HEAD @reduced_f32_on_and_off(0.08) +======= + @bf32_on_and_off(0.08) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_cdist_large(self, device): for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']: x = torch.randn(1000, 10, device=device) @@ -2583,7 +2660,11 @@ def test_cdist_large(self, device): @slowTest @tf32_on_and_off(0.01) +<<<<<<< HEAD @reduced_f32_on_and_off(0.08) +======= + @bf32_on_and_off(0.08) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_cdist_large_batch(self, device): for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']: x = torch.randn(4, 3, 1000, 10, device=device) @@ -2593,7 +2674,11 @@ def test_cdist_large_batch(self, device): self.assertEqual(expected, actual) @tf32_on_and_off(0.005) +<<<<<<< HEAD @reduced_f32_on_and_off(0.04) +======= + @bf32_on_and_off(0.04) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_cdist_non_contiguous(self, device): for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']: x = torch.randn(5, 7, device=device).mT @@ -2621,7 +2706,11 @@ def test_cdist_non_contiguous(self, device): self.assertEqual(expected, actual) @tf32_on_and_off(0.005) +<<<<<<< HEAD @reduced_f32_on_and_off(0.04) +======= + @bf32_on_and_off(0.04) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_cdist_non_contiguous_batch(self, device): for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']: x = torch.randn(4, 3, 2, 5, 7, device=device).mT @@ -2692,7 +2781,11 @@ def test_cdist_same_inputs(self, device): x.requires_grad = True d = torch.cdist(x, y) d.backward(dist_grad) +<<<<<<< HEAD # Check that the backward pass does not contain invalid +======= + # Check that the backward passs does not contain invalid +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # values such as nan or inf assert torch.isfinite(x.grad).all() @@ -2724,7 +2817,11 @@ def test_cumsum(self, device): [0, 0, 0], [1, 2, 3]])) +<<<<<<< HEAD # Check that cumulative sum over a zero length dimension doesn't crash on backprop. +======= + # Check that cummulative sum over a zero length dimension doesn't crash on backprop. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Also check that cumsum over other dimensions in a tensor with a zero-length # dimensiuon also works # Also include a basic suite of similar tests for other bases cases. @@ -2776,7 +2873,11 @@ def test_cumprod(self, device): [0, 0, 0], [1, 1, 1]])) +<<<<<<< HEAD # Check that cumulative prod over a zero length dimension doesn't crash on backprop. +======= + # Check that cummulative prod over a zero length dimension doesn't crash on backprop. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Also check that cumprod over other dimensions in a tensor with a zero-length # dimensiuon also works # Also include a basic suite of similar tests for other bases cases. @@ -3448,9 +3549,275 @@ def test_narrow_copy_non_contiguous(self, device): actual = torch.narrow_copy(inp, 1, 0, 10) self.assertEqual(expected, actual) +<<<<<<< HEAD # FIXME: find a test suite for the take operator @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) @slowTestIf(IS_WINDOWS) +======= + # FIXME: move to indexing test suite + @parametrize("reduce", ['prod', 'amin', 'amax', 'mean']) + @dtypes(*all_types_and(torch.half, torch.bfloat16)) + def test_index_reduce(self, device, dtype, reduce): + size = (3, 4, 5) + index_dtypes = [torch.int, torch.long] + include_selfs = [True, False] + amin_init = float('inf') if dtype.is_floating_point else torch.iinfo(dtype).max + amax_init = -float('inf') if dtype.is_floating_point else torch.iinfo(dtype).min + reduction_init = {'prod': 1, 'mean': 0, 'amin': amin_init, 'amax': amax_init} + + for dest_noncontig, src_noncontig, index_noncontig in product([True, False], repeat=3): + for idx_dtype, include_self in product(index_dtypes, include_selfs): + for dim in range(len(size)): + num_src = np.random.randint(10) + num_dest = size[dim] + dest = make_tensor(size, device=device, dtype=dtype, noncontiguous=dest_noncontig) + src_size = size[:dim] + (num_src,) + size[dim + 1:] + src = make_tensor(src_size, device=device, dtype=dtype, noncontiguous=src_noncontig) + idx = torch.testing.make_tensor( + num_src, low=0, high=num_dest, dtype=idx_dtype, device=device, noncontiguous=index_noncontig + ) + expected = dest.clone() + dest.index_reduce_(dim, idx, src, reduce, include_self=include_self) + # fill rows in idx with reduction inits if include_self=False + if (not include_self): + expected.index_fill_(dim, idx.long(), reduction_init[reduce]) + expected = expected.transpose(0, dim) + src = src.transpose(0, dim) + for i in range(num_src): + if reduce == 'prod': + expected[idx[i]] *= src[i] + elif reduce == 'amin': + torch.minimum(expected[idx[i]], src[i], out=expected[idx[i]]) + elif reduce == 'amax': + torch.maximum(expected[idx[i]], src[i], out=expected[idx[i]]) + else: + expected[idx[i]] += src[i] + if reduce == 'mean': + counts = torch.ones_like(expected) if include_self else torch.zeros_like(expected) + counts.index_add_(0, idx, torch.ones_like(src)) + counts.masked_fill_(counts == 0, 1) + if (dtype.is_floating_point): + expected.div_(counts) + else: + expected.div_(counts, rounding_mode="floor") + expected = expected.transpose(0, dim) + + self.assertEqual(dest, expected) + + # FIXME: move to test indexing + @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) + def test_index_copy(self, device, dtype): + # We just test for num_copy <= num_dest, as otherwise there are repeated indices + # and the behavior is undefined + num_copy, num_dest = 3, 5 + + def make_arg(batch_sizes, n, dim, contig): + size_arg = batch_sizes[:dim] + (n,) + batch_sizes[dim:] + return make_tensor(size_arg, dtype=dtype, device=device, low=None, high=None, noncontiguous=not contig) + + def ref_index_copy(tgt, dim, idx, src): + for i in range(idx.size(0)): + idx_dest = dim * (slice(None),) + (idx[i],) + idx_src = dim * (slice(None),) + (i,) + tgt[idx_dest] = src[idx_src] + + # More thorough testing as in index_add + for dest_contig, src_contig, index_contig in product([True, False], repeat=3): + for other_sizes in ((), (4, 5)): + for dim in range(len(other_sizes)): + dest = make_arg(other_sizes, num_dest, dim, dest_contig) + src = make_arg(other_sizes, num_copy, dim, src_contig) + idx = torch.randperm(num_dest, dtype=torch.int64, device=device)[:num_copy] + if not index_contig: + idx = torch.repeat_interleave(idx, 2, dim=-1) + idx = idx[..., ::2] + dest2 = dest.clone() + dest.index_copy_(dim, idx, src) + ref_index_copy(dest2, dim, idx, src) + self.assertEqual(dest, dest2) + + # FIXME: move to test indexing + # onlyNativeDeviceTypes due to an XLA error: + # https://github.com/pytorch/pytorch/issues/53256 + @onlyNativeDeviceTypes + @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) + def test_index_copy_scalars(self, device, dtype): + # Create the 8 possible combinations of scalar sizes for target / index / source + scalars = ((make_tensor(size_t, dtype=dtype, device=device, low=None, high=None), + make_tensor(size_i, dtype=torch.int64, device=device, low=0, high=1), + make_tensor(size_s, dtype=dtype, device=device, low=None, high=None)) + for size_t, size_i, size_s in product([(), (1,)], repeat=3)) + for target, idx, source in scalars: + target.index_copy_(0, idx, source) + self.assertEqual(target.item(), source.item()) + + # FIXME: move to test indexing + @onlyCPU + def test_errors_index_copy(self, device): + # We do not test the GPU as the CUDA_ASSERT would break the CUDA context + idx_dim = 8 + tgt_dim = 5 + batch_dim = 3 + + # Too large of an index + a = torch.randn(batch_dim, tgt_dim, device=device) + idx = torch.full((idx_dim,), tgt_dim, device=device) + c = torch.zeros(batch_dim, idx_dim, device=device) + with self.assertRaises(IndexError): + a.index_copy_(1, idx, c) + + # Too small (negative indices) + idx = torch.full((idx_dim,), -1, device=device) + with self.assertRaises(IndexError): + a.index_copy_(1, idx, c) + + # Too small (very negative indices) - they should be unsupported even + # when support for negative indices is implemented for index_copy_ + idx = torch.full((idx_dim,), -tgt_dim - 1, device=device) + with self.assertRaises(IndexError): + a.index_copy_(1, idx, c) + + def _prepare_data_for_index_copy_and_add_deterministic( + self, dim: int, device: torch.device + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert (dim >= 0 and dim < 3) + a = [5, 4, 3] + a[dim] = 2000 + x = torch.zeros(a, device=device) + b = a.copy() + elems = a[dim] * 20 + b[dim] = elems + src = torch.rand(b, device=device) + index = torch.randint(a[dim], (elems,), device=device) + return (x, index, src) + + # FIXME: move to test indexing + @onlyNativeDeviceTypes + def test_index_copy_deterministic(self, device: torch.device) -> None: + for dim in range(3): + x, index, src = self._prepare_data_for_index_copy_and_add_deterministic(dim, device) + with DeterministicGuard(True): + y0 = torch.index_copy(x, dim, index, src) + + x0 = x.detach().clone() + index_list = index.tolist() + for i in range(len(index_list)): + if dim == 0: + x0[index_list[i], :, :] = src[i, :, :] + elif dim == 1: + x0[:, index_list[i], :] = src[:, i, :] + elif dim == 2: + x0[:, :, index_list[i]] = src[:, :, i] + + self.assertEqual(x0, y0, atol=0, rtol=0) + + # FIXME: move to test indexing + @onlyNativeDeviceTypes + def test_index_add_deterministic(self, device: torch.device) -> None: + for dim in range(3): + x, index, src = self._prepare_data_for_index_copy_and_add_deterministic(dim, device) + alpha = random.random() + 1 + # on CPU it should be deterministic regardless of the deterministic mode + with DeterministicGuard(True): + y0 = torch.index_add(x, dim, index, src, alpha=alpha) + for _ in range(3): + y = torch.index_add(x, dim, index, src, alpha=alpha) + self.assertEqual(y, y0, atol=0, rtol=0) + + with DeterministicGuard(False): + for _ in range(3): + y_nd = torch.index_add(x, dim, index, src, alpha=alpha) + self.assertEqual(y_nd, y0, atol=1e-3, rtol=1e-5) + + # FIXME: find a test suite for the put operator + @onlyNativeDeviceTypes + def test_index_put_non_accumulate_deterministic(self, device) -> None: + with DeterministicGuard(True): + for i in range(3): + m = random.randint(10, 20) + elems = random.randint(20000, 30000) + values = torch.rand(elems, device=device) + indices = torch.randint(m, (elems,), device=device) + input = torch.rand(m, device=device) + output = input.index_put((indices,), values, accumulate=False) + + input_list = input.tolist() + indices_list = indices.tolist() + values_list = values.tolist() + for i, v in zip(indices_list, values_list): + input_list[i] = v + + self.assertEqual(output, input_list) + + # FIXME: move to test indexing + @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) + @skipIfMPS + def test_index_fill(self, device, dtype): + x = torch.tensor([[1, 2], [4, 5]], dtype=dtype, device=device) + index = torch.tensor([0], device=device) + x.index_fill_(1, index, 0) + self.assertEqual(x, torch.tensor([[0, 2], [0, 5]], dtype=dtype, device=device)) + if not x.is_complex() and not device == "meta": + with self.assertRaisesRegex(RuntimeError, r"Scalar"): + x.index_fill_(1, index, 1 + 1j) + # Make sure that the result stays 0-dim while applied to + # a 0-dim input + x = torch.tensor(1, dtype=dtype, device=device) + self.assertEqual(0, x.index_fill(0, index, -1).dim()) + self.assertEqual(0, x.index_fill_(0, index, -1).dim()) + + # FIXME: move to test indexing + # The test fails for zero-dimensional tensors on XLA + @onlyNativeDeviceTypes + @dtypes(*all_types_complex_float8_and(torch.half, torch.bool, torch.bfloat16)) + def test_index_select(self, device, dtype): + num_src, num_out = 3, 5 + + def make_arg(batch_sizes, n, dim, contig): + size_arg = batch_sizes[:dim] + (n,) + batch_sizes[dim:] + return make_tensor(size_arg, dtype=dtype, device=device, low=None, high=None, noncontiguous=not contig) + + def ref_index_select(src, dim, idx): + # some types not supported on numpy + not_np_dtypes = (torch.bfloat16, torch.float8_e5m2, torch.float8_e5m2fnuz, torch.float8_e4m3fn, torch.float8_e4m3fnuz) + if dtype in not_np_dtypes: + src = src.float() + out = torch.from_numpy(np.take(src.cpu().numpy(), idx.cpu().numpy(), axis=dim)) + if dtype in not_np_dtypes: + out = out.to(device=device, dtype=dtype) + return out + + for src_contig, idx_contig in product([True, False], repeat=2): + for other_sizes in ((), (4, 5)): + for dim in range(len(other_sizes)): + src = make_arg(other_sizes, num_src, dim, src_contig) + idx = make_tensor( + (num_out,), dtype=torch.int64, device=device, low=0, high=num_src, noncontiguous=not idx_contig + ) + out = torch.index_select(src, dim, idx) + out2 = ref_index_select(src, dim, idx) + self.assertEqual(out, out2) + + for idx_type in (torch.int32, torch.int64): + other_sizes = (3, 2) + dim = 1 + src = make_arg(other_sizes, num_src, dim, True) + idx = make_tensor((num_out,), dtype=idx_type, device=device, low=0, high=num_src, noncontiguous=False) + out = torch.index_select(src, dim, idx) + out2 = ref_index_select(src, dim, idx) + self.assertEqual(out, out2) + + # Create the 4 possible combinations of scalar sizes for index / source + scalars = ((make_tensor(size_s, dtype=dtype, device=device), + torch.zeros(size_i, dtype=torch.int64, device=device)) + for size_s, size_i in product([(), (1,)], repeat=2)) + for source, idx in scalars: + out = source.index_select(0, idx) + self.assertEqual(out.item(), source.item()) + + # FIXME: find a test suite for the take operator + @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_take(self, device, dtype): idx_size = (4,) @@ -3561,7 +3928,11 @@ def test_put_accumulate(self, device, dtype): # Test for parallel adds with accumulate == True low_precision = dtype == torch.half or dtype == torch.bfloat16 # Less numbers to avoid overflow with low_precision +<<<<<<< HEAD # Grainsize is 3000 for the for_loop to be parallelized on CPU +======= + # Grainsize is 3000 for the for_loop to be parallized on CPU +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sizes = ((100,)) if low_precision else ((200,), (3002,)) # Bfloat16 has a particularly bad performance here # This operation is nondeterministic on GPU, so we are generous with the rtol @@ -6818,7 +7189,11 @@ def test_index_add_cornercase(self): dest.index_add(0, index, source) def test_linspace_logspace(self): +<<<<<<< HEAD # Ensure the output does not require grad regardless of inputs requiring guard or not. +======= + # Ensure the output does not require grad regardless of inputs requiring gard or not. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # The output of factory functions should not be part of any computational graph. start = 0.0 end = 3.0 @@ -8092,7 +8467,11 @@ def test_print(self): self.assertExpectedInline(str(x), '''tensor([1.0000e+02, 1.0000e-02])''') torch.set_printoptions(sci_mode=False) self.assertEqual(x.__repr__(), str(x)) +<<<<<<< HEAD self.assertExpectedInline(str(x), '''tensor([100.0000, 0.0100])''') +======= + self.assertExpectedInline(str(x), '''tensor([ 100.0000, 0.0100])''') +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.set_printoptions(sci_mode=None) # reset to the default value # test no leading space if all elements positive @@ -8455,7 +8834,11 @@ def test_Size(self): self.assertEqual(2 * size, (1, 2, 3, 1, 2, 3)) def test_Size_concat_non_tuple_sequence(self): +<<<<<<< HEAD # check that TypeError gets raised on adding non-tuple sequences. +======= + # check that TypeError get's raised on adding non-tuple sequences. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from collections.abc import Sequence class DummySequence(Sequence): @@ -9187,7 +9570,11 @@ def test_manual_seed(self): f"after calling manual_seed({seed:x}), but got {actual_initial_seed:x} instead") self.assertEqual(expected_initial_seed, actual_initial_seed, msg=msg) for invalid_seed in [min_int64 - 1, max_uint64 + 1]: +<<<<<<< HEAD with self.assertRaisesRegex(ValueError, r'Overflow when unpacking long long'): +======= + with self.assertRaisesRegex(RuntimeError, r'Overflow when unpacking long'): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.manual_seed(invalid_seed) torch.set_rng_state(rng_state) @@ -9478,7 +9865,10 @@ def test_type(self): self.assertEqual(x.type(torch.int32).dtype, torch.int32) # FIXME: port to a quantization test suite +<<<<<<< HEAD @xfailIfS390X +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_qengine(self): qengines = torch.backends.quantized.supported_engines original_qe = torch.backends.quantized.engine @@ -10606,8 +10996,13 @@ def test_size_stride(self) -> None: def test_invalid_arg_error_handling(self) -> None: """ Tests that errors from old TH functions are propagated back """ for invalid_val in [-1, 2**65]: +<<<<<<< HEAD self.assertRaises((ValueError, RuntimeError), lambda: torch.set_num_threads(invalid_val)) self.assertRaises((ValueError, RuntimeError), lambda: torch.set_num_interop_threads(invalid_val)) +======= + self.assertRaises(RuntimeError, lambda: torch.set_num_threads(invalid_val)) + self.assertRaises(RuntimeError, lambda: torch.set_num_interop_threads(invalid_val)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _get_tensor_prop(self, t): preserved = ( @@ -10859,7 +11254,11 @@ def add_neg_dim_tests(): assert not hasattr(TestTorch, test_name), "Duplicated test name: " + test_name setattr(TestTorch, test_name, make_neg_dim_test(name, tensor_arg, arg_constr, types, extra_dim)) +<<<<<<< HEAD # TODO: these empty classes are temporarily instantiated for XLA compatibility +======= +# TODO: these empy classes are temporarily instantiated for XLA compatibility +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # once XLA updates their test suite it should be removed class TestViewOps(TestCase): pass diff --git a/test/test_transformers.py b/test/test_transformers.py index b2a3959a50429..b76ceb16e1374 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -17,7 +17,11 @@ import math import itertools import torch.optim as optim +<<<<<<< HEAD from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCUDA, largeTensorTest +======= +from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCUDA, onlyCPU, largeTensorTest +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from typing import Optional import torch.utils.cpp_extension from torch.testing._internal.common_nn import NNTestCase @@ -49,8 +53,15 @@ PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, PLATFORM_SUPPORTS_FUSED_ATTENTION, PLATFORM_SUPPORTS_CUDNN_ATTENTION, +<<<<<<< HEAD tf32_on_and_off, tf32_enabled, +======= + SM90OrLater, + tf32_on_and_off, + tf32_enabled, + ROCM_VERSION, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if TEST_FAIRSEQ: @@ -97,7 +108,11 @@ def _check_equal( """ Compare test tensor against golden and reference tensors. Golden is the highest precision possible serving as the "ground truth" +<<<<<<< HEAD Reference is the same precision as test and should also serve as less precisie ground truth. +======= + Refernce is the same precision as test and should also serve as less precisie ground truth. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) We calcculate the "reference error" by comparing the golden to reference and use this as the measruing stick for the test tensor. @@ -339,11 +354,21 @@ def test_train_with_pad_and_catch_error(self, device): l1_bool = nn.L1Loss()(test_train_bool[:, 0:2, :], test_eval_bool[:, 0:2, :]).item() self.assertTrue(l1_bool < 1e-4, "Eval/Train difference in pad_mask BOOL") +<<<<<<< HEAD @tf32_on_and_off(0.001) +======= + @tf32_on_and_off(0.001, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize("attn_mask_dim", [2, 3, None]) @parametrize("key_padding_mask_dim", [2, None]) @parametrize("mask_dtype", [torch.bool, torch.float32]) def test_multiheadattention_fastpath_attn_mask(self, device, attn_mask_dim, key_padding_mask_dim, mask_dtype): +<<<<<<< HEAD +======= + if TEST_WITH_ROCM: + if attn_mask_dim is not None and mask_dtype == torch.bool: + self.skipTest("boolean mask is not fully supported on ROCm yet.") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # MHA converts all with torch.no_grad(): B = 2 @@ -426,7 +451,11 @@ def hook(module, inputs, output): # remove hook handle.remove() +<<<<<<< HEAD @tf32_on_and_off(0.0021 if TEST_WITH_ROCM else 0.001) +======= + @tf32_on_and_off(0.001) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize("use_torchscript", [False]) @parametrize("enable_nested_tensor", [True, False]) @parametrize("use_autocast", [True, False]) @@ -519,7 +548,11 @@ def test_transformerencoder_fastpath(self, device, use_torchscript, enable_neste slowpath_output = slowpath_output.masked_fill(src_key_padding_mask.unsqueeze(-1), 0) self.assertEqual(fastpath_output_expanded, slowpath_output) +<<<<<<< HEAD @tf32_on_and_off(0.001) +======= + @tf32_on_and_off(0.001, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize("with_no_grad", [True, False]) @parametrize("training", [True, False]) @parametrize("enable_nested_tensor", [False]) @@ -1105,7 +1138,11 @@ def forward( return_all_hiddens=False, )[0] +<<<<<<< HEAD @tf32_on_and_off(0.003) +======= + @tf32_on_and_off(0.003, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize("input_dim,attn_mask_dim,is_causal", [(3, None, False), (3, 2, False), (3, 2, True), (3, 3, False), (3, 3, True), (4, None, False), (4, 2, False), (4, 2, True), (4, 4, False), (4, 4, True)], @@ -1655,6 +1692,22 @@ def test_invalid_sdpa_kernel_grouped_query_attention_cuda(self, device, fused_ke F.scaled_dot_product_attention(rand_query, rand_key, rand_value, dropout_p=0.0, is_causal=False, enable_gqa=True) +<<<<<<< HEAD +======= + @onlyCPU + def test_invalid_sdpa_kernel_grouped_query_attention_cpu(self, device): + rand_query = torch.rand(8, 8, 64, 64, device=device, dtype=torch.float16, requires_grad=True) + rand_key = torch.rand(8, 4, 64, 64, device=device, dtype=torch.float16, requires_grad=True) + rand_value = torch.rand(8, 4, 64, 64, device=device, dtype=torch.float16, requires_grad=True) + + with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): + with self.assertRaisesRegex(RuntimeError, "No available kernel"): + with self.assertWarnsRegex(UserWarning, "For dense inputs, both fused kernels require query, " + "key and value to have"): + F.scaled_dot_product_attention(rand_query, rand_key, rand_value, dropout_p=0.0, + is_causal=False, enable_gqa=True) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyCUDA @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not flash_attention fused scaled dot product attention") @parametrize("kernel", PLATFORM_SPECIFIC_SDPA) @@ -1702,7 +1755,11 @@ def test_invalid_fused_inputs_attn_mask_present(self, device, kernel: SDPBackend @onlyCUDA @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support fused SDPA or pre-SM80 hardware") def test_unaligned_tensors(self, device): +<<<<<<< HEAD # The alignment is dependent on arch so we specify SM80OrLater +======= + # The alignment is depdent on arch so we specifiy SM80OrLater +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dtype = torch.float16 size = SdpaShape(2, 2, 8, 5) make_tensor = partial(torch.rand, size, device=device, dtype=dtype) @@ -2064,11 +2121,14 @@ def ref(x): sdp_math = torch.nn.functional.scaled_dot_product_attention(x, x, x, scale=-1.0 / 0.0001) self.assertEqual(ref_result, sdp_math) +<<<<<<< HEAD def test_scaled_dot_product_attention_fp16_overflow(self, device): # Regression test for https://github.com/pytorch/pytorch/issues/160841 x = torch.full((1, 32, 23, 80), 64.0, dtype=torch.half, device=device) y = torch.nn.functional.scaled_dot_product_attention(x, x, x) self.assertFalse(y.isnan().any().item()) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestSDPACpuOnly(NNTestCase): """ Used to test CPU only functionality of scaled_dot_product_attention """ @@ -2089,6 +2149,7 @@ def test_fused_sdp_choice_cpu(self, device, type: str, dropout: float, dtype: to else: assert torch._fused_sdp_choice(q, k, v, dropout_p=dropout) == SDPBackend.FLASH_ATTENTION.value +<<<<<<< HEAD def _generate_fixed_qkv_helper( self, device, @@ -2109,6 +2170,8 @@ def _generate_fixed_qkv_helper( v = make_tensor(kv_shape).transpose(1, 2) return q, k, v +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION]) @parametrize("dtype", [torch.float64, torch.float32, torch.bfloat16, torch.float16]) @parametrize("batch_size", [2, 12]) @@ -2142,20 +2205,34 @@ def test_scaled_dot_product_fused_attention_mask_vs_math_cpu( tol = Tolerances(5e-2, 5e-2) if dtype is torch.float16: tol = Tolerances(1e-2, 1e-2) +<<<<<<< HEAD tol_grad = Tolerances(1e-5, 5e-6) if dtype is torch.bfloat16: tol_grad = Tolerances(5e-2, 5e-2) if dtype is torch.float16: tol_grad = Tolerances(1e-1, 1e-1) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for mask_shape in itertools.product( [q_seq_len, 1], [kv_seq_len, 1] ) if mask_dim == 2 else itertools.product( [batch_size, 1], [n_head, 1], [q_seq_len, 1], [kv_seq_len, 1] ): +<<<<<<< HEAD q, k, v = self._generate_fixed_qkv_helper( device, dtype, batch_size, n_head, n_head, q_seq_len, kv_seq_len, head_dim) q2, k2, v2 = self._generate_fixed_qkv_helper( device, dtype, batch_size, n_head, n_head, q_seq_len, kv_seq_len, head_dim) +======= + make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=dtype, requires_grad=False) + q_shape = SdpaShape(batch_size, n_head, q_seq_len, head_dim) + kv_shape = SdpaShape(batch_size, n_head, kv_seq_len, head_dim) + q = make_tensor(q_shape) + k = make_tensor(kv_shape) + v = make_tensor(kv_shape) + q2, k2, v2 = q.clone(), k.clone(), v.clone() + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if train: q.requires_grad_(True) k.requires_grad_(True) @@ -2164,6 +2241,15 @@ def test_scaled_dot_product_fused_attention_mask_vs_math_cpu( k2.requires_grad_(True) v2.requires_grad_(True) +<<<<<<< HEAD +======= + if dtype in [torch.bfloat16, torch.float16]: + q2, k2, v2 = q2.float(), k2.float(), v2.float() + # (B, nh, T, hs) + q = q.view(batch_size, q_seq_len, n_head, head_dim).transpose(1, 2) + k = k.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2) + v = v.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if set_attn_mask and not casual: if bool_mask: attn_mask = torch.randint(0, 2, size=mask_shape, dtype=torch.bool, device=device) @@ -2171,11 +2257,22 @@ def test_scaled_dot_product_fused_attention_mask_vs_math_cpu( attn_mask = torch.randn(mask_shape, dtype=dtype, device=device) else: attn_mask = None +<<<<<<< HEAD +======= + q2 = q2.view(batch_size, q_seq_len, n_head, head_dim).transpose(1, 2) + k2 = k2.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2) + v2 = v2.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with sdpa_kernel(backends=[fused_kernel]): actual = torch.nn.functional.scaled_dot_product_attention( q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=casual) with sdpa_kernel(backends=[SDPBackend.MATH]): +<<<<<<< HEAD +======= + if not bool_mask and dtype in [torch.bfloat16, torch.float16] and attn_mask is not None: + attn_mask = attn_mask.float() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) math_ref = torch.nn.functional.scaled_dot_product_attention( q2, k2, v2, attn_mask=attn_mask, dropout_p=0.0, is_causal=casual) @@ -2194,6 +2291,7 @@ def test_scaled_dot_product_fused_attention_mask_vs_math_cpu( grad_q_actual, grad_k_actual, grad_v_actual = q.grad, k.grad, v.grad grad_q_ref, grad_k_ref, grad_v_ref = q2.grad, k2.grad, v2.grad +<<<<<<< HEAD self.assertFalse(grad_q_actual is None) self.assertFalse(grad_k_actual is None) self.assertFalse(grad_v_actual is None) @@ -2266,6 +2364,11 @@ def test_scaled_dot_product_fused_attention_gqa_vs_math_cpu( self.assertEqual(grad_q_actual, grad_q_ref, atol=tol_grad.atol, rtol=tol_grad.rtol) self.assertEqual(grad_k_actual, grad_k_ref, atol=tol_grad.atol, rtol=tol_grad.rtol) self.assertEqual(grad_v_actual, grad_v_ref, atol=tol_grad.atol, rtol=tol_grad.rtol) +======= + self.assertEqual(grad_q_actual, grad_q_ref, atol=tol.atol, rtol=tol.rtol) + self.assertEqual(grad_k_actual, grad_k_ref, atol=tol.atol, rtol=tol.rtol) + self.assertEqual(grad_v_actual, grad_v_ref, atol=tol.atol, rtol=tol.rtol) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_sdpa_with_inf(self, device): # https://github.com/pytorch/pytorch/issues/127055. @@ -2663,6 +2766,7 @@ def test_cudnn_attention_d256_heuristic(self, device): v_shape = SdpaShape(batch, num_heads, seq_len, head_dim_v) query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape) +<<<<<<< HEAD def test(): with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION], set_priority=True): actual = torch.nn.functional.scaled_dot_product_attention( @@ -2681,6 +2785,20 @@ def test(): else: with self.assertRaisesRegex(RuntimeError, "No available kernel."): test() +======= + with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH], set_priority=True): + actual = torch.nn.functional.scaled_dot_product_attention( + query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) + actual.backward(torch.randn_like(actual)) + with sdpa_kernel(backends=[SDPBackend.MATH]): + math_ref = torch.nn.functional.scaled_dot_product_attention( + query.contiguous().to(torch.float32), + key.contiguous().to(torch.float32), + value.contiguous().to(torch.float32), + attn_mask=None, dropout_p=0.0, is_causal=False) + + self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") def test_fused_attention_different_dk_dv(self, device): @@ -2706,7 +2824,10 @@ def test_fused_attention_different_dk_dv(self, device): @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") +<<<<<<< HEAD @unittest.skipIf(True, "broken as of cuDNN 9.10") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_cudnn_attention_fail_d128(self, device): # Test that cuDNN attention dispatching correctly bails out on d > 128 b, h = 1, 2 @@ -2721,6 +2842,10 @@ def test_cudnn_attention_fail_d128(self, device): ISSM90 = device_cap == (9, 0) ISSM100 = device_cap == (10, 0) with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]): +<<<<<<< HEAD +======= + # SM90/100 support d <= 256 as of cuDNN 9.5.1+ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (ISSM90 or ISSM100) and torch.backends.cudnn.version() >= 90501: torch.nn.functional.scaled_dot_product_attention(q, k, v) else: @@ -2810,6 +2935,7 @@ def test_attention(backend: SDPBackend, permute_order: list[list[int]]): for permute_order in permute_orders: test_attention(SDPBackend.CUDNN_ATTENTION, list(permute_order) + [3]) +<<<<<<< HEAD @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system") def test_cudnn_attention_compiles(self): q = torch.randn(2, 8, 1024, 128, dtype=torch.half, device='cuda', requires_grad=True) @@ -2841,6 +2967,8 @@ def test_cudnn_attention_seqlen1_dropout_heuristic(self): out = torch.nn.functional.scaled_dot_product_attention(q, q, q, dropout_p=0.5) out.backward(grad) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") @parametrize("mask_dim", [1, 2, 3, 4]) def test_mem_efficient_attention_mask_variants(self, device, mask_dim: list[int]): @@ -3163,7 +3291,11 @@ def test_sdp_flash_attention_grad_against_math(self, device, contiguous_inputs: # Cast up and compare # Since we are doing the compute on fp16 we have to bump the tolerance +<<<<<<< HEAD # Bump down the tolerance for blfoat16 +======= + # Bump down the tolearnce for blfoat16 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) atol = 7e-4 if dtype == torch.float16 else 7e-3 rtol = 7e-4 if dtype == torch.float16 else 7e-3 if TEST_WITH_ROCM: @@ -3184,6 +3316,7 @@ def test_fused_sdp_choice(self, device, type: str): value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) +<<<<<<< HEAD device_capability = None if "cuda" in str(device): device_capability = torch.cuda.get_device_capability() @@ -3197,6 +3330,17 @@ def test_fused_sdp_choice(self, device, type: str): elif PLATFORM_SUPPORTS_FLASH_ATTENTION: self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.FLASH_ATTENTION.value) elif type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION and not prefer_cudnn: # e.g., we're on Windows +======= + # TODO we are currently disabling this by default, lets assert that this returns + # FlashAttention, we need to change when we make remove opt-in for cudnn + if type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION and SM90OrLater: + self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.FLASH_ATTENTION.value) + with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]): + self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value) + elif PLATFORM_SUPPORTS_FLASH_ATTENTION: + self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.FLASH_ATTENTION.value) + elif type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION: # e.g., we're on Windows +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.EFFICIENT_ATTENTION.value) with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]): self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value) @@ -3318,7 +3462,11 @@ def test_mem_eff_backwards_determinism(self, device): out = F.scaled_dot_product_attention(query, key, value) upward_grad = torch.rand_like(out) out.backward(upward_grad) +<<<<<<< HEAD initial_query_grad = query.grad +======= + intial_query_grad = query.grad +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Re-run the op with the same upward grad and check that the backward is # not deterministic @@ -3327,7 +3475,11 @@ def test_mem_eff_backwards_determinism(self, device): query.grad = None out = F.scaled_dot_product_attention(query, key, value) out.backward(upward_grad) +<<<<<<< HEAD if not torch.equal(initial_query_grad, query.grad): +======= + if not torch.equal(intial_query_grad, query.grad): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff_anwser_once = True break self.assertTrue(diff_anwser_once) @@ -3337,7 +3489,11 @@ def test_mem_eff_backwards_determinism(self, device): out = F.scaled_dot_product_attention(query, key, value) upward_grad = torch.rand_like(out) out.backward(upward_grad) +<<<<<<< HEAD initial_query_grad = query.grad +======= + intial_query_grad = query.grad +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Re-run the op with the same upward grad and check that the backward is # deterministic now that we have enforced it @@ -3346,7 +3502,11 @@ def test_mem_eff_backwards_determinism(self, device): query.grad = None out = F.scaled_dot_product_attention(query, key, value) out.backward(upward_grad) +<<<<<<< HEAD if not torch.equal(initial_query_grad, query.grad): +======= + if not torch.equal(intial_query_grad, query.grad): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff_anwser_once = True break self.assertFalse(diff_anwser_once) @@ -3655,7 +3815,11 @@ def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_le query, key, value, is_causal=is_causal, scale=scale, enable_gqa=enable_gqa) else: # Problem: We pad sizes in the composite region of the top level SDPA. But we need the +<<<<<<< HEAD # Debug mask when have dropout. So I am going to manually pad up here when testing dropout +======= + # Debug mask when have dropout. So I am going to manualy pad up here when testing dropout +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) q_padded, q_og_size = pad_last_dim(query, 8) k_padded, k_og_size = pad_last_dim(key, 8) v_padded, v_og_size = pad_last_dim(value, 8) @@ -4172,6 +4336,12 @@ def rand_nt(sequence_list, num_heads, head_dim): class TestSDPAXpuOnly(NNTestCase): """ Used to test XPU only functionality of scaled_dot_product_attention Mostly migrate from TestSDPACudaOnly in test/test_transformers.py +<<<<<<< HEAD +======= + + Note that as SDPBackend.OVERRIDEABLE is not managed by sdpa_kernel so that + math ref has to be called explicitly via torch.ops.aten._scaled_dot_product_attention_math. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ @parametrize("type", ["dense"]) @@ -4197,6 +4367,10 @@ def test_fused_attention_different_dk_dv(self, device): v_shape = SdpaShape(batch, num_heads, 2, head_dim_v) query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape) +<<<<<<< HEAD +======= + # test that we do not dispatch to onednn for an unsupported case +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) actual = F.scaled_dot_product_attention( query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) @@ -4234,6 +4408,10 @@ def test_fused_attention_gqa(self, device, dtype, batch_size, n_head, n_head_kv, v_shape = SdpaShape(batch_size, n_head_kv, kv_size, head_dim) query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape) +<<<<<<< HEAD +======= + # test that we do not dispatch to onednn for an unsupported case +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) actual = F.scaled_dot_product_attention( query, key, value, attn_mask=None, dropout_p=0.0, is_causal=is_causal, enable_gqa=True) @@ -4303,6 +4481,7 @@ def test_attention(permute_order: list[list[int]]): for permute_order in permute_orders: test_attention(list(permute_order) + [3]) +<<<<<<< HEAD def test_backends_set_to_math(self, device): dtype = torch.bfloat16 q_shape = SdpaShape(1, 1, 8, 16) @@ -4332,6 +4511,8 @@ def test_default_priority_order(self, device): self.assertTrue(overrideable_index < math_index < flash_index, f"Expected overrideable < math < flash, got {overrideable_index}, {math_index}, {flash_index}") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_scaled_dot_product_attention_fused_kernels_safe_softmax(self, device): dtype = torch.bfloat16 make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=False) diff --git a/test/test_transformers_privateuse1.py b/test/test_transformers_privateuse1.py index 0aa15260d0949..697cf877111f4 100644 --- a/test/test_transformers_privateuse1.py +++ b/test/test_transformers_privateuse1.py @@ -4,7 +4,11 @@ from collections import namedtuple from functools import partial +<<<<<<< HEAD import torch_openreg # noqa: F401 +======= +import pytorch_openreg # noqa: F401 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch from torch.nn.attention import SDPBackend diff --git a/test/test_type_promotion.py b/test/test_type_promotion.py index 59d856ec4fc9f..6a1ebe665f48c 100644 --- a/test/test_type_promotion.py +++ b/test/test_type_promotion.py @@ -1046,13 +1046,21 @@ def test_cat_out_different_dtypes(self, device): and not (out_dtype.is_floating_point or out_dtype.is_complex)) or ((x_dtype.is_complex or y_dtype.is_complex) and not out_dtype.is_complex)): # This combinations do not support type conversion to a different class out type +<<<<<<< HEAD with self.assertRaises(TypeError): +======= + with self.assertRaises(RuntimeError): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.cat([x, y], out=out) else: torch.cat([x, y], out=out) self.assertEqual(out, expected_out, exact_dtype=True) +<<<<<<< HEAD # Verifies that unary ops require matching out types +======= + # Verfies that unary ops require matching out types +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onlyNativeDeviceTypes @dtypes(*itertools.product((torch.int64, torch.float32, torch.float64, diff --git a/test/test_typing.py b/test/test_typing.py index f28091fa8d046..015d400d6be6a 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -35,7 +35,11 @@ def _key_func(key: str) -> str: +<<<<<<< HEAD """Split at the first occurrence of the ``:`` character. +======= + """Split at the first occurance of the ``:`` character. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Windows drive-letters (*e.g.* ``C:``) are ignored herein. """ @@ -135,7 +139,11 @@ def _parse_reveals(file: IO[str]) -> list[str]: comments = "/n".join(comments_array) # Only search for the `{*}` pattern within comments, +<<<<<<< HEAD # otherwise there is the risk of accidentally grabbing dictionaries and sets +======= + # otherwise there is the risk of accidently grabbing dictionaries and sets +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) key_set = set(re.findall(r"\{(.*?)\}", comments)) kwargs = { k: FORMAT_DICT.get(k, f"") for k in key_set diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py index 9939e8e76ce94..30c4126423a0c 100644 --- a/test/test_unary_ufuncs.py +++ b/test/test_unary_ufuncs.py @@ -54,8 +54,11 @@ ) from torch.utils import _pytree as pytree +<<<<<<< HEAD from torch.testing._internal.common_utils import IS_WINDOWS, slowTestIf +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if TEST_SCIPY: import scipy @@ -273,7 +276,10 @@ def _helper_reference_numerics( # and noncontiguities. @suppress_warnings @ops(reference_filtered_ops) +<<<<<<< HEAD @slowTestIf(IS_WINDOWS) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_reference_numerics_normal(self, device, dtype, op): tensors = generate_elementwise_unary_tensors( op, device=device, dtype=dtype, requires_grad=False @@ -282,7 +288,10 @@ def test_reference_numerics_normal(self, device, dtype, op): @suppress_warnings @ops(reference_filtered_ops) +<<<<<<< HEAD @slowTestIf(IS_WINDOWS) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_reference_numerics_small(self, device, dtype, op): if dtype in (torch.bool,): raise self.skipTest("bool has no small values") @@ -294,7 +303,10 @@ def test_reference_numerics_small(self, device, dtype, op): @suppress_warnings @ops(reference_filtered_ops) +<<<<<<< HEAD @slowTestIf(IS_WINDOWS) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_reference_numerics_large(self, device, dtype, op): if dtype in (torch.bool, torch.uint8, torch.int8): raise self.skipTest("bool, uint8, and int8 dtypes have no large values") @@ -309,7 +321,10 @@ def test_reference_numerics_large(self, device, dtype, op): reference_filtered_ops, allowed_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half), ) +<<<<<<< HEAD @slowTestIf(IS_WINDOWS) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_reference_numerics_extremal(self, device, dtype, op): tensors = generate_elementwise_unary_extremal_value_tensors( op, device=device, dtype=dtype, requires_grad=False @@ -318,7 +333,10 @@ def test_reference_numerics_extremal(self, device, dtype, op): # Tests for testing (non)contiguity consistency @ops(unary_ufuncs) +<<<<<<< HEAD @slowTestIf(IS_WINDOWS) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_contig_vs_every_other(self, device, dtype, op): contig = make_tensor( (1026,), device=device, dtype=dtype, low=op.domain[0], high=op.domain[1] @@ -335,7 +353,10 @@ def test_contig_vs_every_other(self, device, dtype, op): self.assertEqual(result, expected) @ops(unary_ufuncs) +<<<<<<< HEAD @slowTestIf(IS_WINDOWS) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_contig_vs_transposed(self, device, dtype, op): contig = make_tensor( (789, 357), device=device, dtype=dtype, low=op.domain[0], high=op.domain[1] @@ -352,7 +373,10 @@ def test_contig_vs_transposed(self, device, dtype, op): self.assertEqual(result, expected) @ops(unary_ufuncs) +<<<<<<< HEAD @slowTestIf(IS_WINDOWS) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_non_contig(self, device, dtype, op): shapes = [(5, 7), (1024,)] for shape in shapes: @@ -369,7 +393,10 @@ def test_non_contig(self, device, dtype, op): self.assertEqual(op(contig, **torch_kwargs), op(non_contig, **torch_kwargs)) @ops(unary_ufuncs) +<<<<<<< HEAD @slowTestIf(IS_WINDOWS) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_non_contig_index(self, device, dtype, op): contig = make_tensor( (2, 2, 1, 2), @@ -388,7 +415,10 @@ def test_non_contig_index(self, device, dtype, op): self.assertEqual(op(contig, **torch_kwargs), op(non_contig, **torch_kwargs)) @ops(unary_ufuncs) +<<<<<<< HEAD @slowTestIf(IS_WINDOWS) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_non_contig_expand(self, device, dtype, op): shapes = [(1, 3), (1, 7), (5, 7)] for shape in shapes: @@ -410,7 +440,10 @@ def test_non_contig_expand(self, device, dtype, op): ) @ops(unary_ufuncs) +<<<<<<< HEAD @slowTestIf(IS_WINDOWS) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_contig_size1(self, device, dtype, op): contig = make_tensor( (5, 100), dtype=dtype, device=device, low=op.domain[0], high=op.domain[1] @@ -426,7 +459,10 @@ def test_contig_size1(self, device, dtype, op): self.assertEqual(op(contig, **torch_kwargs), op(contig2, **torch_kwargs)) @ops(unary_ufuncs) +<<<<<<< HEAD @slowTestIf(IS_WINDOWS) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_contig_size1_large_dim(self, device, dtype, op): contig = make_tensor( (5, 2, 3, 1, 4, 5, 3, 2, 1, 2, 3, 4), @@ -448,7 +484,10 @@ def test_contig_size1_large_dim(self, device, dtype, op): # Tests that computation on a multiple batches is the same as # per-batch computation. @ops(unary_ufuncs) +<<<<<<< HEAD @slowTestIf(IS_WINDOWS) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_batch_vs_slicing(self, device, dtype, op): input = make_tensor( (1024, 512), dtype=dtype, device=device, low=op.domain[0], high=op.domain[1] @@ -1094,7 +1133,11 @@ def test_silu(self, device, dtype): def test_silu_complex(self, device, dtype): atol = 1e-6 rtol = 1e-6 +<<<<<<< HEAD inp_outs = [ +======= + inouts = [ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) (0.2 + 0.3j, 0.08775215595960617065 + 0.18024823069572448730j), (1e-19 + 1e-18j, 4.99999984132761269448e-20 + 5.00000022906852482872e-19j), (-1.0 + 2.0j, -0.78546208143234252930 + -0.44626939296722412109j), @@ -1102,7 +1145,11 @@ def test_silu_complex(self, device, dtype): (2.0j, -1.55740761756896972656 + 0.99999988079071044922j), ] +<<<<<<< HEAD for inp, out in inp_outs: +======= + for inp, out in inouts: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) res = torch.nn.functional.silu( torch.tensor(inp, dtype=dtype, device=device) ) @@ -1110,7 +1157,11 @@ def test_silu_complex(self, device, dtype): self.assertEqual(res.real, out.real, atol=atol, rtol=rtol) self.assertEqual(res.imag, out.imag, atol=atol, rtol=rtol) +<<<<<<< HEAD for inp, out in inp_outs: +======= + for inp, out in inouts: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) res = torch.nn.functional.silu( torch.tensor(inp, dtype=dtype, device=device), inplace=True ) @@ -1118,7 +1169,11 @@ def test_silu_complex(self, device, dtype): self.assertEqual(res.real, out.real, atol=atol, rtol=rtol) self.assertEqual(res.imag, out.imag, atol=atol, rtol=rtol) +<<<<<<< HEAD # It is not obvious how to merge this into OpInfo because these inputs +======= + # It is not obvious how to merge this into OpInfo becuase these inputs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # succeed for gradcheck but are expected to fail for gradgradcheck @dtypes(torch.double) def test_sinc(self, device, dtype): @@ -1184,7 +1239,11 @@ def test_log1p_complex(self, device, dtype): # Not using numpy's log1p here because by the time of writing this, # np.log1p has precision problems for small complex input values, see here: # https://github.com/numpy/numpy/issues/22609 +<<<<<<< HEAD inp_outs = [ +======= + inouts = [ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) (0.2 + 0.3j, 0.21263386770217202 + 0.24497866312686414j), (1e-19 + 1e-18j, 1e-19 + 1e-18j), (1e-18 + 0.1j, 0.00497517 + 0.0996687j), @@ -1198,7 +1257,11 @@ def test_log1p_complex(self, device, dtype): ] # test the extreme values if dtype == torch.complex128: +<<<<<<< HEAD inp_outs += [ +======= + inouts += [ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) (-1 + 1e250j, 575.6462732485114 + 1.5707963267948966j), (1e250 + 1j, 575.6462732485114 + 1e-250j), (1e250 + 1e250j, 575.9928468387914 + 0.7853981633974483j), @@ -1207,7 +1270,11 @@ def test_log1p_complex(self, device, dtype): (1e250 + 1e-250j, 575.6462732485114 + 0.0j), ] elif dtype == torch.complex64: +<<<<<<< HEAD inp_outs += [ +======= + inouts += [ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) (-1 + 1e30j, 69.07755278982137 + 1.5707963267948966j), (1e30 + 1j, 69.07755278982137 + 1e-30j), (1e30 + 1e30j, 69.42412638010134 + 0.7853981633974483j), @@ -1217,7 +1284,11 @@ def test_log1p_complex(self, device, dtype): ] # test the log1p individually +<<<<<<< HEAD for inp, out in inp_outs: +======= + for inp, out in inouts: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) res = torch.log1p(torch.tensor(inp, dtype=dtype, device=device)) self.assertFalse(torch.any(torch.isnan(res))) # setting up atol == 0.0 because some part has very small values @@ -1225,7 +1296,11 @@ def test_log1p_complex(self, device, dtype): self.assertEqual(res.imag, out.imag, atol=0.0, rtol=1e-6) # test the log1p in tensor +<<<<<<< HEAD inp_lst, out_lst = (list(elmt) for elmt in zip(*inp_outs)) +======= + inp_lst, out_lst = (list(elmt) for elmt in zip(*inouts)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inp_tens = torch.tensor(inp_lst, dtype=dtype, device=device) out_tens = torch.tensor(out_lst, dtype=dtype, device=device) res_tens = torch.log1p(inp_tens) @@ -1306,7 +1381,11 @@ def test_igamma_edge_cases(self, device, dtype): zero_to_large = torch.tensor([0.0, 1.0, 1e3], **tkwargs) small_to_inf = torch.tensor([1e-3, 1.0, float("inf")], **tkwargs) nans = torch.zeros((3,), **tkwargs) + float("nan") +<<<<<<< HEAD inp_outs = [ +======= + inpouts = [ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # (a , x), out ((zeros, small_to_inf), ones), ((small_to_inf, zeros), zeros), @@ -1316,7 +1395,11 @@ def test_igamma_edge_cases(self, device, dtype): ((infs, infs), nans), ((-small_to_inf, small_to_inf), nans), ] +<<<<<<< HEAD for inputs, output in inp_outs: +======= + for inputs, output in inpouts: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) input0, input1 = inputs calc = torch.igamma(input0, input1) if torch.all(torch.isnan(output)): @@ -1335,7 +1418,11 @@ def test_igammac_edge_cases(self, device, dtype): zero_to_large = torch.tensor([0.0, 1.0, 1e3], **tkwargs) small_to_inf = torch.tensor([1e-3, 1.0, float("inf")], **tkwargs) nans = torch.zeros((3,), **tkwargs) + float("nan") +<<<<<<< HEAD inp_outs = [ +======= + inpouts = [ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # (a , x), out ((zeros, small_to_inf), zeros), ((small_to_inf, zeros), ones), @@ -1345,7 +1432,11 @@ def test_igammac_edge_cases(self, device, dtype): ((infs, infs), nans), ((-small_to_inf, small_to_inf), nans), ] +<<<<<<< HEAD for inputs, output in inp_outs: +======= + for inputs, output in inpouts: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) input0, input1 = inputs calc = torch.igammac(input0, input1) if torch.all(torch.isnan(output)): diff --git a/test/test_utils.py b/test/test_utils.py index 0314da6e320a1..d54fdc849d5bd 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -20,7 +20,12 @@ import torch.utils.cpp_extension import torch.utils.data from torch._utils import try_import +<<<<<<< HEAD from torch._utils_internal import deprecated +======= +from torch.autograd._functions.utils import check_onnx_broadcast +from torch.onnx.symbolic_opset9 import _prepare_onnx_paddings +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, @@ -60,9 +65,12 @@ from torch.testing._internal.common_utils import run_tests, TestCase +<<<<<<< HEAD # mypy: disable-error-code="name-defined" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class RandomDatasetMock(torch.utils.data.Dataset): def __getitem__(self, index): return torch.tensor([torch.rand(1).item(), random.uniform(0, 1)]) @@ -788,6 +796,68 @@ def test_smoke(self): self.assertTrue(info_output.count("\n") >= 17) +<<<<<<< HEAD +======= +class TestONNXUtils(TestCase): + def test_prepare_onnx_paddings(self): + sizes = [2, 3, 4] + pad = [1, 2, 3, 4] + paddings = _prepare_onnx_paddings(len(sizes), pad) + self.assertEqual(paddings, [0, 3, 1, 0, 4, 2]) + + def test_check_onnx_broadcast(self): + def try_check_onnx_broadcast(dims1, dims2, expect_broadcast, expect_fail): + broadcast = True + fail = False + try: + broadcast = check_onnx_broadcast(dims1, dims2) + except ValueError: + fail = True + self.assertEqual(broadcast, expect_broadcast) + self.assertEqual(fail, expect_fail) + + # Case 1, check the case when len(dims1) < len(dims2) and numel(dims2) > 1 + dims1 = [3, 4] + dims2 = [2, 3, 4] + try_check_onnx_broadcast(dims1, dims2, True, True) + + # Case 2, check the case when len(dims1) < len(dims2) and numel(dims2) == 1 + dims1 = [3, 4] + dims2 = [1, 1, 1] + try_check_onnx_broadcast(dims1, dims2, True, False) + + # Case 3, check the case when len(dims1) > len(dims2) and numel(dims2) == 1 + dims1 = [1, 1] + dims2 = [1] + try_check_onnx_broadcast(dims1, dims2, True, False) + + # Case 4, check the case when len(dims1) > len(dims2) and dims1[x:] == dims2 + dims1 = [2, 3, 4] + dims2 = [3, 4] + try_check_onnx_broadcast(dims1, dims2, True, False) + + # Case 5, check the case when len(dims1) > len(dims2), but dims1[x:] != dims2 + dims1 = [2, 3, 4] + dims2 = [1, 4] + try_check_onnx_broadcast(dims1, dims2, True, True) + + # Case 6, check the equal case, no broadcast + dims1 = [3, 4] + dims2 = [3, 4] + try_check_onnx_broadcast(dims1, dims2, False, False) + + # Case 7, check the case when len(dims1) == len(dims2), but dims1 != dims2 + dims1 = [3, 4] + dims2 = [1, 4] + try_check_onnx_broadcast(dims1, dims2, True, True) + + # Case 8, check the case when len(dims1) == len(dims2) and numel(s2) == 1 + dims1 = [3, 4] + dims2 = [1, 1] + try_check_onnx_broadcast(dims1, dims2, True, False) + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestHipify(TestCase): def test_import_hipify(self): from torch.utils.hipify import hipify_python # noqa: F401 @@ -795,9 +865,13 @@ def test_import_hipify(self): class TestHipifyTrie(TestCase): def setUp(self): +<<<<<<< HEAD from torch.utils.hipify import hipify_python self.trie = hipify_python.Trie() +======= + self.trie = torch.utils.hipify.hipify_python.Trie() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_add_and_search_trie(self): self.trie.add("banana") @@ -1142,6 +1216,7 @@ def test_import_missing(self): self.assertIsNone(missing_module) +<<<<<<< HEAD @deprecated() def _deprecated_api(x, y=15): return x + y @@ -1157,5 +1232,7 @@ def test_deprecated(self): _deprecated_api(1, y=2) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if __name__ == "__main__": run_tests() diff --git a/test/test_view_ops.py b/test/test_view_ops.py index 5bec225787cc6..c1f0360c3bd1e 100644 --- a/test/test_view_ops.py +++ b/test/test_view_ops.py @@ -11,16 +11,26 @@ from torch.testing._internal.common_device_type import ( dtypes, dtypesIfMPS, +<<<<<<< HEAD expectedFailureMPS, instantiate_device_type_tests, onlyCPU, onlyNativeDeviceTypes, +======= + instantiate_device_type_tests, + onlyCPU, + onlyNativeDeviceTypes, + onlyNativeDeviceTypesAnd, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) skipLazy, skipMeta, skipXLA, ) from torch.testing._internal.common_dtype import ( +<<<<<<< HEAD all_mps_types_and, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) all_types_and, all_types_and_complex_and, complex_types, @@ -158,11 +168,16 @@ def test_conj_self(self, device, dtype): @skipIfTorchDynamo("TorchDynamo fails with unknown reason") @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bool)) +<<<<<<< HEAD @dtypesIfMPS(*integral_types_and(torch.cfloat, torch.float, torch.half, torch.bool)) def test_view_dtype_new(self, device, dtype): dtypes = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()} if device.startswith("mps"): del dtypes[torch.float64] +======= + def test_view_dtype_new(self, device, dtype): + dtypes = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) del dtypes[torch.bool] def generate_inputs(): @@ -275,7 +290,10 @@ def calc_expected_size_and_stride(a, view_dtype): # has a greater element size than the original dtype @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) +<<<<<<< HEAD @dtypesIfMPS(*all_mps_types_and(torch.bool)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_view_dtype_upsize_errors(self, device, dtype): dtype_size = torch._utils._element_size(dtype) @@ -377,7 +395,10 @@ def fn(contiguous_input=True, dim0=0, dim1=1): @onlyNativeDeviceTypes @dtypes(*complex_types(), torch.complex32) +<<<<<<< HEAD @dtypesIfMPS(torch.cfloat, torch.chalf) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_view_as_real(self, device, dtype): def fn(contiguous_input=True): t = torch.randn(3, 4, dtype=dtype, device=device) @@ -404,7 +425,13 @@ def fn(contiguous_input=True): @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) +<<<<<<< HEAD @dtypesIfMPS(*all_mps_types_and(torch.bool)) +======= + @dtypesIfMPS( + *integral_types_and(torch.half, torch.bfloat16, torch.bool, torch.float32) + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_view_tensor_split(self, device, dtype): a = make_tensor((40, 30), dtype=dtype, device=device, low=-9, high=9) a_split_dim0 = a.tensor_split(7, 0) @@ -416,7 +443,10 @@ def test_view_tensor_split(self, device, dtype): @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) +<<<<<<< HEAD @dtypesIfMPS(*all_mps_types_and(torch.cfloat, torch.bool)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_view_tensor_hsplit(self, device, dtype): t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9) t_hsplit = torch.hsplit(t, 2) @@ -427,7 +457,10 @@ def test_view_tensor_hsplit(self, device, dtype): @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) +<<<<<<< HEAD @dtypesIfMPS(*all_mps_types_and(torch.cfloat, torch.bool)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_view_tensor_vsplit(self, device, dtype): t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9) t_vsplit = torch.vsplit(t, 2) @@ -438,7 +471,10 @@ def test_view_tensor_vsplit(self, device, dtype): @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) +<<<<<<< HEAD @dtypesIfMPS(*all_mps_types_and(torch.cfloat, torch.bool)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_view_tensor_dsplit(self, device, dtype): t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9) t_dsplit = torch.dsplit(t, 2) @@ -447,9 +483,15 @@ def test_view_tensor_dsplit(self, device, dtype): t[2, 2, 2] = 7 self.assertEqual(t_dsplit[1][2, 2, 0], t[2, 2, 2]) +<<<<<<< HEAD @onlyNativeDeviceTypes @dtypes(*all_types_and(torch.half, torch.bfloat16)) @dtypesIfMPS(*all_mps_types_and(torch.bool)) +======= + @onlyNativeDeviceTypesAnd("mps") + @dtypes(*all_types_and(torch.half, torch.bfloat16)) + @dtypesIfMPS(*integral_types_and(torch.half, torch.bool, torch.float32)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_imag_noncomplex(self, device, dtype): t = torch.ones((5, 5), dtype=dtype, device=device) @@ -458,7 +500,10 @@ def test_imag_noncomplex(self, device, dtype): @onlyNativeDeviceTypes @dtypes(*complex_types()) +<<<<<<< HEAD @dtypesIfMPS(torch.cfloat) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_real_imag_view(self, device, dtype): def compare_with_numpy(contiguous_input=True): t = torch.randn(3, 3, dtype=dtype, device=device) @@ -489,7 +534,10 @@ def compare_with_numpy(contiguous_input=True): self.assertEqual(a[5:].imag, a.imag[5:]) @onlyNativeDeviceTypes +<<<<<<< HEAD @expectedFailureMPS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dtypes(*complex_types()) def test_conj_imag_view(self, device, dtype) -> None: t = _make_tensor((4, 5), dtype, device) @@ -521,12 +569,15 @@ def test_conj_view_with_shared_memory(self, device) -> None: all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), ) ) +<<<<<<< HEAD @dtypesIfMPS( *product( [torch.cfloat, torch.chalf], all_mps_types_and(torch.cfloat, torch.chalf, torch.bool), ) ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @suppress_warnings def test_set_real_imag(self, device, dtypes): x = torch.randn(10, dtype=dtypes[0], device=device) @@ -1656,7 +1707,11 @@ def test_broadcast_shapes(self, device): inputs_with_neg_vals = [[1, 1, -12], [-1, 1], [-11]] for integral_inputs_with_neg_vals in inputs_with_neg_vals: with self.assertRaisesRegex( +<<<<<<< HEAD ValueError, "Attempting to broadcast a dimension with negative length!" +======= + RuntimeError, "Trying to create tensor with negative dimension" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): torch.broadcast_shapes(*integral_inputs_with_neg_vals) @@ -1664,21 +1719,33 @@ def test_broadcast_shapes(self, device): for error_input in integral_inputs_error_case: with self.assertRaisesRegex( RuntimeError, +<<<<<<< HEAD ".*expected shape should be broadcastable to*", +======= + "Shape mismatch: objects cannot be broadcast to a single shape", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): torch.broadcast_shapes(*error_input) negative_inputs = [(-1,), (1, -12), (4, -11), (-4, 1), (1, 1, -2)] for s0 in negative_inputs: with self.assertRaisesRegex( +<<<<<<< HEAD ValueError, "Attempting to broadcast a dimension with negative length!" +======= + RuntimeError, "Trying to create tensor with negative dimension" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): torch.broadcast_shapes(s0) for s1 in negative_inputs: with self.assertRaisesRegex( +<<<<<<< HEAD ValueError, "Attempting to broadcast a dimension with negative length!", +======= + RuntimeError, "Trying to create tensor with negative dimension" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): torch.broadcast_shapes(s0, s1) @@ -1971,7 +2038,11 @@ def test_tensor_split_errors(self, device): with self.assertRaises(numpy_err, msg=msg): np.array_split(a.cpu().numpy(), sections_or_indices, dim) +<<<<<<< HEAD # additional tests for tensor_split with tensor_indices_or_sections +======= + # addtional tests for tensor_split with tensor_indices_or_sections +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with self.assertRaisesRegex( RuntimeError, r"tensor_split expected tensor_indices_or_sections to have dtype of long, but got Float", diff --git a/test/test_weak.py b/test/test_weak.py index 629ed12db3267..8081fe2017775 100644 --- a/test/test_weak.py +++ b/test/test_weak.py @@ -159,7 +159,11 @@ def test_weak_keyed_bad_delitem(self): self.assertRaises(KeyError, d.__delitem__, o) self.assertRaises(KeyError, d.__getitem__, o) +<<<<<<< HEAD # If a key isn't of a weakly referenceable type, __getitem__ and +======= + # If a key isn't of a weakly referencable type, __getitem__ and +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # __setitem__ raise TypeError. __delitem__ should too. self.assertRaises(TypeError, d.__delitem__, 13) self.assertRaises(TypeError, d.__getitem__, 13) diff --git a/test/test_xpu.py b/test/test_xpu.py index 04d045b00d8bc..c4d65f6900cb5 100644 --- a/test/test_xpu.py +++ b/test/test_xpu.py @@ -1,6 +1,9 @@ # Owner(s): ["module: intel"] +<<<<<<< HEAD import gc +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import re import subprocess import sys @@ -102,7 +105,10 @@ def test_get_device_properties(self): self.assertEqual(device_name, torch.xpu.get_device_name()) device_capability = torch.xpu.get_device_capability(current_device) +<<<<<<< HEAD self.assertTrue(device_capability["device_id"] > 0) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertTrue(device_capability["max_work_group_size"] > 0) self.assertTrue(device_capability["max_num_sub_groups"] > 0) self.assertEqual( @@ -134,10 +140,13 @@ def test_get_device_properties(self): device_properties.architecture, device_capability["architecture"], ) +<<<<<<< HEAD self.assertEqual( len(str(device_properties.uuid)), 36 ) # xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx self.assertEqual(len(device_properties.uuid.bytes), 16) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @unittest.skipIf(IS_WINDOWS, "not applicable to Windows (only fails with fork)") def test_wrong_xpu_fork(self): @@ -525,6 +534,7 @@ def test_device_memory_allocated(self): ) del a +<<<<<<< HEAD def test_memory_stats(self): gc.collect() torch.xpu.empty_cache() @@ -561,6 +571,8 @@ def test_memory_stats(self): self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated) self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @skipXPUIf( int(torch.version.xpu) < 20250000, "Test requires SYCL compiler version 2025.0.0 or newer.", diff --git a/test/torch_np/numpy_tests/core/test_getlimits.py b/test/torch_np/numpy_tests/core/test_getlimits.py index 8b4911e1106d2..74b8b0ad4fe93 100644 --- a/test/torch_np/numpy_tests/core/test_getlimits.py +++ b/test/torch_np/numpy_tests/core/test_getlimits.py @@ -1,7 +1,13 @@ # Owner(s): ["module: dynamo"] +<<<<<<< HEAD """Test functions for limits module.""" +======= +""" Test functions for limits module. + +""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import functools import warnings from unittest import expectedFailure as xfail, skipIf diff --git a/test/torch_np/numpy_tests/core/test_indexing.py b/test/torch_np/numpy_tests/core/test_indexing.py index 91dae96868376..f5b97131b2713 100644 --- a/test/torch_np/numpy_tests/core/test_indexing.py +++ b/test/torch_np/numpy_tests/core/test_indexing.py @@ -219,7 +219,11 @@ def test_single_int_index(self): assert_raises(IndexError, a.__getitem__, 1 << 30) # Index overflow produces IndexError # Note torch raises RuntimeError here +<<<<<<< HEAD assert_raises((IndexError, ValueError), a.__getitem__, 1 << 64) +======= + assert_raises((IndexError, RuntimeError), a.__getitem__, 1 << 64) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_single_bool_index(self): # Single boolean index diff --git a/test/torch_np/numpy_tests/core/test_multiarray.py b/test/torch_np/numpy_tests/core/test_multiarray.py index ba19b62e821a6..4950d6e32852c 100644 --- a/test/torch_np/numpy_tests/core/test_multiarray.py +++ b/test/torch_np/numpy_tests/core/test_multiarray.py @@ -4104,7 +4104,10 @@ def test_decimal(decimal_sep_localization): def test_decimal_period_separator(): pass +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_decimal_comma_separator(): with CommaDecimalPointLocale(): pass @@ -6787,10 +6790,14 @@ def test_dot_out(self): class TestArange(TestCase): def test_infinite(self): assert_raises( +<<<<<<< HEAD (RuntimeError, ValueError), np.arange, 0, np.inf, # "unsupported range", +======= + (RuntimeError, ValueError), np.arange, 0, np.inf # "unsupported range", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def test_nan_step(self): diff --git a/test/torch_np/numpy_tests/core/test_numeric.py b/test/torch_np/numpy_tests/core/test_numeric.py index 75bf5c0fc6287..46dc2f010292f 100644 --- a/test/torch_np/numpy_tests/core/test_numeric.py +++ b/test/torch_np/numpy_tests/core/test_numeric.py @@ -2733,6 +2733,7 @@ def test_errors(self): assert_raises(np.AxisError, np.moveaxis, x, 3, 0) # 'source.*out of bounds', assert_raises(np.AxisError, np.moveaxis, x, -4, 0) # 'source.*out of bounds', assert_raises( +<<<<<<< HEAD np.AxisError, np.moveaxis, x, @@ -2745,6 +2746,12 @@ def test_errors(self): x, [0, 0], [0, 1], # 'repeated axis in `source`', +======= + np.AxisError, np.moveaxis, x, 0, 5 # 'destination.*out of bounds', + ) + assert_raises( + ValueError, np.moveaxis, x, [0, 0], [0, 1] # 'repeated axis in `source`', +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) assert_raises( ValueError, # 'repeated axis in `destination`', diff --git a/test/torch_np/numpy_tests/core/test_scalar_ctors.py b/test/torch_np/numpy_tests/core/test_scalar_ctors.py index a630eda39ce8c..a3d5cee2e7ce1 100644 --- a/test/torch_np/numpy_tests/core/test_scalar_ctors.py +++ b/test/torch_np/numpy_tests/core/test_scalar_ctors.py @@ -3,7 +3,10 @@ """ Test the scalar constructors, which also do type-coercion """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import functools from unittest import skipIf as skipif diff --git a/test/torch_np/numpy_tests/core/test_scalar_methods.py b/test/torch_np/numpy_tests/core/test_scalar_methods.py index 2bfce58f39442..ff16b0753e595 100644 --- a/test/torch_np/numpy_tests/core/test_scalar_methods.py +++ b/test/torch_np/numpy_tests/core/test_scalar_methods.py @@ -3,7 +3,10 @@ """ Test the scalar constructors, which also do type-coercion """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import fractions import functools import types diff --git a/test/torch_np/numpy_tests/core/test_scalarinherit.py b/test/torch_np/numpy_tests/core/test_scalarinherit.py index 7c7fec495f182..0b378d2fb0ceb 100644 --- a/test/torch_np/numpy_tests/core/test_scalarinherit.py +++ b/test/torch_np/numpy_tests/core/test_scalarinherit.py @@ -1,7 +1,13 @@ # Owner(s): ["module: dynamo"] +<<<<<<< HEAD """Test printing of scalar types.""" +======= +""" Test printing of scalar types. + +""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import functools from unittest import skipIf as skipif diff --git a/test/torch_np/numpy_tests/core/test_shape_base.py b/test/torch_np/numpy_tests/core/test_shape_base.py index 9563d0c8bacba..baca34a911a2f 100644 --- a/test/torch_np/numpy_tests/core/test_shape_base.py +++ b/test/torch_np/numpy_tests/core/test_shape_base.py @@ -811,10 +811,14 @@ def test_invalid_nesting(self, block): assert_raises_regex(ValueError, msg, block, [[1], 2]) assert_raises_regex(ValueError, msg, block, [[], 2]) assert_raises_regex( +<<<<<<< HEAD ValueError, msg, block, [[[1], [2]], [[3, 4]], [5]], # missing brackets +======= + ValueError, msg, block, [[[1], [2]], [[3, 4]], [5]] # missing brackets +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def test_empty_lists(self, block): diff --git a/test/torch_np/numpy_tests/fft/test_helper.py b/test/torch_np/numpy_tests/fft/test_helper.py index 2b6a384bf899c..e7c4c11a3d620 100644 --- a/test/torch_np/numpy_tests/fft/test_helper.py +++ b/test/torch_np/numpy_tests/fft/test_helper.py @@ -5,7 +5,10 @@ Copied from fftpack.helper by Pearu Peterson, October 2005 """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_utils import ( run_tests, TEST_WITH_TORCHDYNAMO, diff --git a/test/torch_np/numpy_tests/fft/test_pocketfft.py b/test/torch_np/numpy_tests/fft/test_pocketfft.py index 70607110d214d..1f101dc94531f 100644 --- a/test/torch_np/numpy_tests/fft/test_pocketfft.py +++ b/test/torch_np/numpy_tests/fft/test_pocketfft.py @@ -372,7 +372,11 @@ def worker(args, q): assert_allclose( q.get(timeout=5), expected, +<<<<<<< HEAD atol=2e-14, +======= + atol=2e-14 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # msg="Function returned wrong value in multithreaded context", ) diff --git a/test/torch_np/numpy_tests/lib/test_arraysetops.py b/test/torch_np/numpy_tests/lib/test_arraysetops.py index 79f41fc415af2..f19900c3ccc68 100644 --- a/test/torch_np/numpy_tests/lib/test_arraysetops.py +++ b/test/torch_np/numpy_tests/lib/test_arraysetops.py @@ -1,7 +1,13 @@ # Owner(s): ["module: dynamo"] +<<<<<<< HEAD """Test functions for 1D array set operations.""" +======= +"""Test functions for 1D array set operations. + +""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from unittest import expectedFailure as xfail, skipIf import numpy diff --git a/test/torch_np/numpy_tests/lib/test_function_base.py b/test/torch_np/numpy_tests/lib/test_function_base.py index 13dba55837cf5..465902eb4e705 100644 --- a/test/torch_np/numpy_tests/lib/test_function_base.py +++ b/test/torch_np/numpy_tests/lib/test_function_base.py @@ -2881,8 +2881,12 @@ def test_linear_nan_1D(self, dtype): np.testing.assert_equal(res.dtype, arr.dtype) H_F_TYPE_CODES = [ +<<<<<<< HEAD (int_type, np.float64) for int_type in "Bbhil" # np.typecodes["AllInteger"] +======= + (int_type, np.float64) for int_type in "Bbhil" # np.typecodes["AllInteger"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] + [ (np.float16, np.float16), (np.float32, np.float32), diff --git a/test/torch_np/numpy_tests/lib/test_histograms.py b/test/torch_np/numpy_tests/lib/test_histograms.py index 82382cfc147e3..ce58d0664e543 100644 --- a/test/torch_np/numpy_tests/lib/test_histograms.py +++ b/test/torch_np/numpy_tests/lib/test_histograms.py @@ -505,7 +505,12 @@ def test_simple(self): assert_equal( len(a), numbins, +<<<<<<< HEAD err_msg=f"For the {estimator} estimator with datasize of {testlen}", +======= + err_msg=f"For the {estimator} estimator " + f"with datasize of {testlen}", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def test_small(self): @@ -551,7 +556,12 @@ def test_small(self): assert_equal( len(a), expbins, +<<<<<<< HEAD err_msg=f"For the {estimator} estimator with datasize of {testlen}", +======= + err_msg=f"For the {estimator} estimator " + f"with datasize of {testlen}", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def test_incorrect_methods(self): diff --git a/test/torch_np/numpy_tests/lib/test_twodim_base.py b/test/torch_np/numpy_tests/lib/test_twodim_base.py index f873ae8091a2f..8dbffac7061a0 100644 --- a/test/torch_np/numpy_tests/lib/test_twodim_base.py +++ b/test/torch_np/numpy_tests/lib/test_twodim_base.py @@ -1,7 +1,13 @@ # Owner(s): ["module: dynamo"] +<<<<<<< HEAD """Test functions for matrix module""" +======= +"""Test functions for matrix module + +""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import functools from unittest import expectedFailure as xfail, skipIf as skipif diff --git a/test/torch_np/numpy_tests/linalg/test_linalg.py b/test/torch_np/numpy_tests/linalg/test_linalg.py index f8fa81bca63e5..f6ef4ee515194 100644 --- a/test/torch_np/numpy_tests/linalg/test_linalg.py +++ b/test/torch_np/numpy_tests/linalg/test_linalg.py @@ -1,6 +1,12 @@ # Owner(s): ["module: dynamo"] +<<<<<<< HEAD """Test functions for linalg module""" +======= +""" Test functions for linalg module + +""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import functools import itertools import os @@ -488,7 +494,11 @@ class SolveCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): # kept apart from TestSolve for use for testing with matrices. def do(self, a, b, tags): x = linalg.solve(a, b) +<<<<<<< HEAD assert_almost_equal(b, dot_generalized(a, x), single_decimal=5) +======= + assert_almost_equal(b, dot_generalized(a, x)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert_(consistent_subclass(x, b)) diff --git a/test/torch_np/test_random.py b/test/torch_np/test_random.py index 8ef19caa7624c..11e5bf82a4584 100644 --- a/test/torch_np/test_random.py +++ b/test/torch_np/test_random.py @@ -1,7 +1,12 @@ # Owner(s): ["module: dynamo"] +<<<<<<< HEAD """Light smoke test switching between numpy to pytorch random streams.""" +======= +"""Light smoke test switching between numpy to pytorch random streams. +""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from contextlib import contextmanager from functools import partial diff --git a/test/torch_np/test_ufuncs_basic.py b/test/torch_np/test_ufuncs_basic.py index b8d923cdd3f19..64c370c87588e 100644 --- a/test/torch_np/test_ufuncs_basic.py +++ b/test/torch_np/test_ufuncs_basic.py @@ -9,7 +9,10 @@ by >>> import torch._numpy as np """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import operator from unittest import skipIf as skip, SkipTest diff --git a/test/typing/fail/arithmetic_ops.py b/test/typing/fail/arithmetic_ops.py index b3f816329445a..b96f008bf3231 100644 --- a/test/typing/fail/arithmetic_ops.py +++ b/test/typing/fail/arithmetic_ops.py @@ -7,12 +7,17 @@ # See ../pass/arithmetic_ops.py for more information +<<<<<<< HEAD TENSOR, FLOAT = randn(3), 1.5 +======= +TENSOR, INT, FLOAT = randn(3), 2, 1.5 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) FLOAT & TENSOR # E: Unsupported operand types for & ("float" and "Tensor") FLOAT | TENSOR # E: Unsupported operand types for | ("float" and "Tensor") FLOAT ^ TENSOR # E: Unsupported operand types for ^ ("float" and "Tensor") # FIXME: false negatives (https://github.com/pytorch/pytorch/issues/155701) +<<<<<<< HEAD # # FLOAT << TENSOR # E: Unsupported operand types for & ("float" and "Tensor") # FLOAT >> TENSOR # E: Unsupported operand types for & ("float" and "Tensor") @@ -22,3 +27,8 @@ # TENSOR ^ FLOAT # E: Unsupported operand types for ^ ("Tensor" and "float" ) # TENSOR << FLOAT # E: Unsupported operand types for & ("Tensor" and "float") # TENSOR >> FLOAT # E: Unsupported operand types for & ("Tensor" and "float") +======= +# TENSOR & FLOAT # E: Unsupported operand types for & ("Tensor" and "float" ) +# TENSOR | FLOAT # E: Unsupported operand types for | ("Tensor" and "float" ) +# TENSOR ^ FLOAT # E: Unsupported operand types for ^ ("Tensor" and "float" ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/test/typing/pass/arithmetic_ops.py b/test/typing/pass/arithmetic_ops.py index 556ef90523e94..35c5e44934454 100644 --- a/test/typing/pass/arithmetic_ops.py +++ b/test/typing/pass/arithmetic_ops.py @@ -1,9 +1,14 @@ +<<<<<<< HEAD from typing import Union +======= +from typing import Any, Union +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from typing_extensions import assert_type, TypeAlias from torch import randn, Tensor +<<<<<<< HEAD # Test deduced types of arithmetic operations between tensors, ints, floats and bools # The expected type should always be `Tensor`, but isn't. # See https://github.com/pytorch/pytorch/issues/145838 @@ -13,11 +18,21 @@ # # Unary ops # +======= +TENSOR, INT, FLOAT, BOOL = randn(3), 2, 1.5, True + +# Test deduced types of arithmetic operations between tensors, ints, floats and bools +# The expected type should always be `Tensor`: `Any` and `bool` below are wrong. +# See https://github.com/pytorch/pytorch/issues/145838 + +# Unary ops +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert_type(+TENSOR, Tensor) assert_type(-TENSOR, Tensor) assert_type(~TENSOR, Tensor) +<<<<<<< HEAD # # Binary ops that return a bolean # @@ -204,6 +219,139 @@ assert_type(INT ^ TENSOR, Tensor) assert_type(TENSOR ^ FLOAT, Tensor) # Should fail type checking assert_type(FLOAT ^ TENSOR, Tensor) # type: ignore[operator] +======= +# Binary ops + +assert_type(TENSOR == TENSOR, Tensor) +assert_type(TENSOR != TENSOR, Tensor) +assert_type(TENSOR < TENSOR, Tensor) +assert_type(TENSOR > TENSOR, Tensor) +assert_type(TENSOR <= TENSOR, Tensor) +assert_type(TENSOR >= TENSOR, Tensor) +assert_type(TENSOR + TENSOR, Tensor) +assert_type(TENSOR - TENSOR, Tensor) +assert_type(TENSOR * TENSOR, Tensor) +assert_type(TENSOR // TENSOR, Any) +assert_type(TENSOR / TENSOR, Tensor) +assert_type(TENSOR % TENSOR, Tensor) +assert_type(TENSOR**TENSOR, Tensor) +assert_type(TENSOR << TENSOR, Tensor) +assert_type(TENSOR >> TENSOR, Tensor) +assert_type(TENSOR & TENSOR, Tensor) +assert_type(TENSOR | TENSOR, Tensor) +assert_type(TENSOR ^ TENSOR, Tensor) + +assert_type(TENSOR == BOOL, Tensor) +assert_type(TENSOR != BOOL, Tensor) +assert_type(TENSOR < BOOL, Tensor) +assert_type(TENSOR > BOOL, Tensor) +assert_type(TENSOR <= BOOL, Tensor) +assert_type(TENSOR >= BOOL, Tensor) +assert_type(TENSOR + BOOL, Tensor) +assert_type(TENSOR - BOOL, Tensor) +assert_type(TENSOR * BOOL, Tensor) +assert_type(TENSOR // BOOL, Any) +assert_type(TENSOR / BOOL, Tensor) +assert_type(TENSOR % BOOL, Tensor) +assert_type(TENSOR**BOOL, Tensor) +assert_type(TENSOR << BOOL, Tensor) +assert_type(TENSOR >> BOOL, Tensor) +assert_type(TENSOR & BOOL, Tensor) +assert_type(TENSOR | BOOL, Tensor) +assert_type(TENSOR ^ BOOL, Tensor) + +assert_type(BOOL == TENSOR, bool) +assert_type(BOOL != TENSOR, bool) +assert_type(BOOL < TENSOR, Tensor) +assert_type(BOOL > TENSOR, Tensor) +assert_type(BOOL <= TENSOR, Tensor) +assert_type(BOOL >= TENSOR, Tensor) +assert_type(BOOL + TENSOR, Tensor) +assert_type(BOOL - TENSOR, Any) +assert_type(BOOL * TENSOR, Tensor) +assert_type(BOOL // TENSOR, Any) +assert_type(BOOL / TENSOR, Any) +assert_type(BOOL % TENSOR, Any) +assert_type(BOOL**TENSOR, Any) +assert_type(BOOL << TENSOR, Any) +assert_type(BOOL >> TENSOR, Any) +assert_type(BOOL & TENSOR, Tensor) +assert_type(BOOL | TENSOR, Tensor) +assert_type(BOOL ^ TENSOR, Tensor) + +assert_type(TENSOR == INT, Tensor) +assert_type(TENSOR != INT, Tensor) +assert_type(TENSOR < INT, Tensor) +assert_type(TENSOR > INT, Tensor) +assert_type(TENSOR <= INT, Tensor) +assert_type(TENSOR >= INT, Tensor) +assert_type(TENSOR + INT, Tensor) +assert_type(TENSOR - INT, Tensor) +assert_type(TENSOR * INT, Tensor) +assert_type(TENSOR // INT, Any) +assert_type(TENSOR / INT, Tensor) +assert_type(TENSOR % INT, Tensor) +assert_type(TENSOR**INT, Tensor) +assert_type(TENSOR << INT, Tensor) +assert_type(TENSOR >> INT, Tensor) +assert_type(TENSOR & INT, Tensor) +assert_type(TENSOR | INT, Tensor) +assert_type(TENSOR ^ INT, Tensor) + +assert_type(INT == TENSOR, bool) +assert_type(INT != TENSOR, bool) +assert_type(INT < TENSOR, Tensor) +assert_type(INT > TENSOR, Tensor) +assert_type(INT <= TENSOR, Tensor) +assert_type(INT >= TENSOR, Tensor) +assert_type(INT + TENSOR, Tensor) +assert_type(INT - TENSOR, Any) +assert_type(INT * TENSOR, Tensor) +assert_type(INT // TENSOR, Any) +assert_type(INT / TENSOR, Any) +assert_type(INT % TENSOR, Any) +assert_type(INT**TENSOR, Any) +assert_type(INT << TENSOR, Any) +assert_type(INT >> TENSOR, Any) +assert_type(INT & TENSOR, Tensor) +assert_type(INT | TENSOR, Tensor) +assert_type(INT ^ TENSOR, Tensor) + +assert_type(TENSOR == FLOAT, Tensor) +assert_type(TENSOR != FLOAT, Tensor) +assert_type(TENSOR < FLOAT, Tensor) +assert_type(TENSOR > FLOAT, Tensor) +assert_type(TENSOR <= FLOAT, Tensor) +assert_type(TENSOR >= FLOAT, Tensor) +assert_type(TENSOR + FLOAT, Tensor) +assert_type(TENSOR - FLOAT, Tensor) +assert_type(TENSOR * FLOAT, Tensor) +assert_type(TENSOR // FLOAT, Any) +assert_type(TENSOR / FLOAT, Tensor) +assert_type(TENSOR % FLOAT, Tensor) +assert_type(TENSOR**FLOAT, Tensor) +assert_type(TENSOR << FLOAT, Tensor) +assert_type(TENSOR >> FLOAT, Tensor) +assert_type(TENSOR & FLOAT, Tensor) +assert_type(TENSOR | FLOAT, Tensor) +assert_type(TENSOR ^ FLOAT, Tensor) + +assert_type(FLOAT == TENSOR, bool) +assert_type(FLOAT != TENSOR, bool) +assert_type(FLOAT < TENSOR, Tensor) +assert_type(FLOAT > TENSOR, Tensor) +assert_type(FLOAT <= TENSOR, Tensor) +assert_type(FLOAT >= TENSOR, Tensor) +assert_type(FLOAT + TENSOR, Tensor) +assert_type(FLOAT - TENSOR, Any) +assert_type(FLOAT * TENSOR, Tensor) +assert_type(FLOAT // TENSOR, Any) +assert_type(FLOAT / TENSOR, Any) +assert_type(FLOAT % TENSOR, Any) +assert_type(FLOAT**TENSOR, Any) +assert_type(FLOAT << TENSOR, Any) +assert_type(FLOAT >> TENSOR, Any) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) NUMBER: TypeAlias = Union[int, float, bool] @@ -427,3 +575,41 @@ def __xor__(self, other: NUMBER) -> "Binary": # type: ignore[override] assert_type(BOOL >> BINARY, Binary) assert_type(BOOL - BINARY, Binary) assert_type(BOOL ^ BINARY, Binary) +<<<<<<< HEAD +======= + +# Tensor operators whose types could be improved +# This is the "diff" of the first and second sections. + +assert_type(BOOL // TENSOR, Any) +assert_type(FLOAT // TENSOR, Any) +assert_type(INT // TENSOR, Any) +assert_type(TENSOR // BOOL, Any) +assert_type(TENSOR // FLOAT, Any) +assert_type(TENSOR // INT, Any) +assert_type(TENSOR // TENSOR, Any) + +assert_type(BOOL**TENSOR, Any) +assert_type(FLOAT**TENSOR, Any) +assert_type(INT**TENSOR, Any) + +assert_type(BOOL - TENSOR, Any) +assert_type(FLOAT - TENSOR, Any) +assert_type(INT - TENSOR, Any) + +assert_type(BOOL / TENSOR, Any) +assert_type(FLOAT / TENSOR, Any) +assert_type(INT / TENSOR, Any) + +assert_type(BOOL % TENSOR, Any) +assert_type(FLOAT % TENSOR, Any) +assert_type(INT % TENSOR, Any) + +assert_type(BOOL << TENSOR, Any) +assert_type(FLOAT << TENSOR, Any) +assert_type(INT << TENSOR, Any) + +assert_type(BOOL >> TENSOR, Any) +assert_type(FLOAT >> TENSOR, Any) +assert_type(INT >> TENSOR, Any) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/test/xpu/test_gemm.py b/test/xpu/test_gemm.py index 1164a2b676368..6eac49cc6228d 100644 --- a/test/xpu/test_gemm.py +++ b/test/xpu/test_gemm.py @@ -12,9 +12,12 @@ import numpy as np import torch +<<<<<<< HEAD import torch._inductor.decomposition from torch._higher_order_ops.out_dtype import out_dtype from torch.fx.experimental.proxy_tensor import make_fx +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing import make_tensor from torch.testing._internal.common_device_type import ( dtypes, @@ -1367,6 +1370,7 @@ def test_mm_with_offset(self, device): cpu_out = torch.matmul(a.cpu(), b.cpu()) self.assertEqual(gpu_out.cpu(), cpu_out) +<<<<<<< HEAD @parametrize("m", [0, 8, 17]) @parametrize("k", [0, 16, 32]) @parametrize("n", [16, 32]) @@ -1446,6 +1450,8 @@ def forward(self, x_1, w_1): return out_dtype""", ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) instantiate_device_type_tests(TestBasicGEMM, globals(), only_for="xpu", allow_xpu=True) diff --git a/third_party/tensorpipe.BUILD b/third_party/tensorpipe.BUILD index 5e5b69b4cb4ec..2e5c6a8dba2db 100644 --- a/third_party/tensorpipe.BUILD +++ b/third_party/tensorpipe.BUILD @@ -7,7 +7,10 @@ LIBUV_COMMON_SRCS = [ "third_party/libuv/src/inet.c", "third_party/libuv/src/random.c", "third_party/libuv/src/strscpy.c", +<<<<<<< HEAD "third_party/libuv/src/strtok.c", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "third_party/libuv/src/threadpool.c", "third_party/libuv/src/timer.c", "third_party/libuv/src/uv-common.c", @@ -38,7 +41,13 @@ LIBUV_POSIX_SRCS = [ LIBUV_LINUX_SRCS = LIBUV_POSIX_SRCS + [ "third_party/libuv/src/unix/proctitle.c", +<<<<<<< HEAD "third_party/libuv/src/unix/linux.c", +======= + "third_party/libuv/src/unix/linux-core.c", + "third_party/libuv/src/unix/linux-inotify.c", + "third_party/libuv/src/unix/linux-syscalls.c", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "third_party/libuv/src/unix/procfs-exepath.c", "third_party/libuv/src/unix/random-getrandom.c", "third_party/libuv/src/unix/random-sysctl-linux.c", @@ -59,7 +68,10 @@ cc_library( "third_party/libuv/src/unix/*.h", ], ), +<<<<<<< HEAD copts = ["-D_GNU_SOURCE"], +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) visibility = ["//visibility:public"], ) @@ -151,7 +163,11 @@ cc_library( ".", ], copts = [ +<<<<<<< HEAD "-std=c++17", +======= + "-std=c++14", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ], visibility = ["//visibility:public"], deps = [ @@ -168,7 +184,11 @@ cc_library( ".", ], copts = [ +<<<<<<< HEAD "-std=c++17", +======= + "-std=c++14", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ], visibility = ["//visibility:public"], deps = [ diff --git a/third_party/xnnpack.buck.bzl b/third_party/xnnpack.buck.bzl index b353d5d0d5982..29d96c34ad210 100644 --- a/third_party/xnnpack.buck.bzl +++ b/third_party/xnnpack.buck.bzl @@ -1437,7 +1437,11 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F labels = labels, platform_srcs = [ ( +<<<<<<< HEAD "(arm64|aarch64)", +======= + "(arm64|aarch64)$", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) prod_srcs_for_arch_wrapper("neonfma") + prod_srcs_for_arch_wrapper("neonfma_aarch64"), ), ] if not is_arvr_mode() else [], @@ -2227,7 +2231,10 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], # doesn't cover iphonesimulator-x86_64 "ovr_config//runtime:arm64-linux-ubuntu-neon": [":arm64_lib"], +<<<<<<< HEAD "ovr_config//runtime:fbcode-arm64": [":arm64_lib"], +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "ovr_config//runtime:platform010": [":x86_and_x86_64_lib"], }), ) diff --git a/third_party/xpu.txt b/third_party/xpu.txt index c402bb1984830..5eac1c27f80d8 100644 --- a/third_party/xpu.txt +++ b/third_party/xpu.txt @@ -1 +1,5 @@ +<<<<<<< HEAD 789f59d8261b521282a26025c4a7a201621b4683 +======= +3a9419c8bb6a98dd3e3cd473c36691fb4abeae40 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/tools/autograd/build.bzl b/tools/autograd/build.bzl index c5ddf7a20b800..edd9454ebe777 100644 --- a/tools/autograd/build.bzl +++ b/tools/autograd/build.bzl @@ -12,9 +12,12 @@ def define_targets(rules): "//torchgen", ], ) +<<<<<<< HEAD rules.filegroup( name = "deprecated_yaml", srcs = ["deprecated.yaml"], visibility = ["//:__subpackages__"], ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index c050c6cbdc4c3..484daac708447 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1267,11 +1267,14 @@ mean: not_implemented("native_layer_norm_backward mean") rstd: not_implemented("native_layer_norm_backward rstd") +<<<<<<< HEAD - name: _fused_rms_norm(Tensor input, int[] normalized_shape, Tensor? weight, float? eps) -> (Tensor, Tensor) input, weight: "GradMode::is_enabled() || grads[1].defined() ? infinitely_differentiable_native_rms_norm_backward(grads[0], grads[1], input, normalized_shape, result1, weight, grad_input_mask) : (grads[0].defined() ? _fused_rms_norm_backward(grads[0], input, normalized_shape, result1, weight, grad_input_mask) : std::tuple())" result0: rms_norm_jvp(input_p, input_t, weight_p, weight_t, result1, normalized_shape) result1: rms_norm_rstd_jvp(input_p, input_t, result1, normalized_shape) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor) input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_group_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, N, C, HxW, group, eps, grad_input_mask) : (grads[0].defined() ? native_group_norm_backward_symint(grads[0].device().is_xpu() ? grads[0] : grads[0].contiguous(grads[0].device().is_cpu() ? input.suggest_memory_format() : c10::MemoryFormat::Contiguous), input.device().is_xpu() ? input : input.contiguous(input.device().is_cpu() ? input.suggest_memory_format() : c10::MemoryFormat::Contiguous), result1, result2, weight, N, C, HxW, group, grad_input_mask) : std::tuple())" result0: group_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, group) @@ -1799,9 +1802,12 @@ self: zeros_like(grad) result: auto_element_wise +<<<<<<< HEAD - name: hash_tensor(Tensor self, int[1] dim=[], *, bool keepdim=False, int mode=0) -> Tensor output_differentiability: [False] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # DO NOT define a backward for to_dense # See [Note: Sometimes view derivatives] # - name: to_dense(Tensor self, ScalarType? dtype=None, *, bool? masked_grad=None) -> Tensor @@ -2801,7 +2807,11 @@ self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector(padding.size(), 0), groups, grad_input_mask) : std::tuple()" - name: miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor) +<<<<<<< HEAD input, weight, bias: "grad.defined() ? (training ? miopen_batch_norm_backward(input, grad.contiguous(), weight, running_mean, running_var, result1, result2, epsilon) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple()" +======= + input, weight, bias: "grad.defined() ? (training ? miopen_batch_norm_backward(input, grad.contiguous(input.suggest_memory_format()), weight, running_mean, running_var, result1, result2, epsilon) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple()" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, epsilon) - name: miopen_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon) -> (Tensor, Tensor, Tensor) @@ -2904,10 +2914,13 @@ output_differentiability: [True, False, False, False, False, False] query, key, value, bias: _efficient_attention_backward_symint(grad, query, key, value, bias, output, cu_seqlens_q, cu_seqlens_k, max_seqlen_batch_q, max_seqlen_batch_k, logsumexp, dropout_p, philox_seed, philox_offset, custom_mask_type, bias.requires_grad(), scale) +<<<<<<< HEAD - name: _cudnn_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) output_differentiability: [True, False, False, False, False, False, False, False, False] query, key, value: _cudnn_attention_backward_symint(grad, query, key, value, output, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - name: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) output_differentiability: [True, False, False, False, False, False, False, False, False] query, key, value: _scaled_dot_product_cudnn_attention_backward_symint(grad, query, key, value, output, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale) diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index 5a003cadf6b32..251235f58415b 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -97,7 +97,10 @@ "is_sparse_csr", "size", "stride", +<<<<<<< HEAD "sym_is_contiguous", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "sym_size", "sym_stride", "sym_storage_offset", diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp index bfc5b80835c4b..fb298ed48d247 100644 --- a/tools/autograd/templates/python_variable_methods.cpp +++ b/tools/autograd/templates/python_variable_methods.cpp @@ -264,7 +264,11 @@ static PyObject * THPVariable_contiguous(PyObject* self, PyObject* args, PyObjec auto& self_ = THPVariable_Unpack(self); auto memory_format = r.memoryformat(0); // avoids touching the GIL or current device if self is already contiguous +<<<<<<< HEAD if (self_.is_contiguous_or_false(memory_format)) { +======= + if (self_.is_contiguous(memory_format)) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // NOTE: this logic is duplicated from VariableType.cpp. Since we need to // record this call to contiguous() in the trace regardless of whether // we actually call contiguous here, we need to record this information diff --git a/tools/bazel.bzl b/tools/bazel.bzl index 9b662859adb46..768d8787445e0 100644 --- a/tools/bazel.bzl +++ b/tools/bazel.bzl @@ -2,7 +2,11 @@ load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library", "cc_test") load("@rules_cuda//cuda:defs.bzl", "cuda_library", "requires_cuda_enabled") load("@rules_python//python:defs.bzl", "py_binary", "py_library") load("@pip_deps//:requirements.bzl", "requirement") +<<<<<<< HEAD load("@pytorch//torch/headeronly/macros:cmake_configure_file.bzl", "cmake_configure_file") +======= +load("@pytorch//c10/macros:cmake_configure_file.bzl", "cmake_configure_file") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) load("@pytorch//tools/config:defs.bzl", "if_cuda") def _genrule(**kwds): diff --git a/tools/build/bazel/requirements.in b/tools/build/bazel/requirements.in index ae94ca4a24c4d..05711db83a8d3 100644 --- a/tools/build/bazel/requirements.in +++ b/tools/build/bazel/requirements.in @@ -1,7 +1,16 @@ +<<<<<<< HEAD pyyaml==6.0.2 numpy==1.26.4 requests==2.32.4 setuptools==78.1.1 sympy==1.12 typing-extensions==4.11.0 +======= +PyYAML==6.0.1 +numpy==1.26.4 +requests==2.32.2 +setuptools==78.1.1 +sympy==1.12 +typing_extensions==4.11.0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) networkx==2.8.8 diff --git a/tools/build/bazel/requirements.txt b/tools/build/bazel/requirements.txt index 288c8cf1fba6f..7272d61ac19cc 100644 --- a/tools/build/bazel/requirements.txt +++ b/tools/build/bazel/requirements.txt @@ -1,3 +1,4 @@ +<<<<<<< HEAD # This file was autogenerated by uv via the following command: # uv pip compile --generate-hashes tools/build/bazel/requirements.in --output-file tools/build/bazel/requirements.txt certifi==2025.7.14 \ @@ -101,6 +102,113 @@ charset-normalizer==3.4.2 \ idna==3.10 \ --hash=sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9 \ --hash=sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3 +======= +# +# This file is autogenerated by pip-compile with Python 3.11 +# by the following command: +# +# pip-compile --allow-unsafe --generate-hashes tools/build/bazel/requirements.in +# +certifi==2024.7.4 \ + --hash=sha256:5a1e7645bc0ec61a09e26c36f6106dd4cf40c6db3a1fb6352b0244e7fb057c7b \ + --hash=sha256:c198e21b1289c2ab85ee4e67bb4b4ef3ead0892059901a8d5b622f24a1101e90 + # via requests +charset-normalizer==3.3.2 \ + --hash=sha256:06435b539f889b1f6f4ac1758871aae42dc3a8c0e24ac9e60c2384973ad73027 \ + --hash=sha256:06a81e93cd441c56a9b65d8e1d043daeb97a3d0856d177d5c90ba85acb3db087 \ + --hash=sha256:0a55554a2fa0d408816b3b5cedf0045f4b8e1a6065aec45849de2d6f3f8e9786 \ + --hash=sha256:0b2b64d2bb6d3fb9112bafa732def486049e63de9618b5843bcdd081d8144cd8 \ + --hash=sha256:10955842570876604d404661fbccbc9c7e684caf432c09c715ec38fbae45ae09 \ + --hash=sha256:122c7fa62b130ed55f8f285bfd56d5f4b4a5b503609d181f9ad85e55c89f4185 \ + --hash=sha256:1ceae2f17a9c33cb48e3263960dc5fc8005351ee19db217e9b1bb15d28c02574 \ + --hash=sha256:1d3193f4a680c64b4b6a9115943538edb896edc190f0b222e73761716519268e \ + --hash=sha256:1f79682fbe303db92bc2b1136016a38a42e835d932bab5b3b1bfcfbf0640e519 \ + --hash=sha256:2127566c664442652f024c837091890cb1942c30937add288223dc895793f898 \ + --hash=sha256:22afcb9f253dac0696b5a4be4a1c0f8762f8239e21b99680099abd9b2b1b2269 \ + --hash=sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3 \ + --hash=sha256:2e81c7b9c8979ce92ed306c249d46894776a909505d8f5a4ba55b14206e3222f \ + --hash=sha256:3287761bc4ee9e33561a7e058c72ac0938c4f57fe49a09eae428fd88aafe7bb6 \ + --hash=sha256:34d1c8da1e78d2e001f363791c98a272bb734000fcef47a491c1e3b0505657a8 \ + --hash=sha256:37e55c8e51c236f95b033f6fb391d7d7970ba5fe7ff453dad675e88cf303377a \ + --hash=sha256:3d47fa203a7bd9c5b6cee4736ee84ca03b8ef23193c0d1ca99b5089f72645c73 \ + --hash=sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc \ + --hash=sha256:42cb296636fcc8b0644486d15c12376cb9fa75443e00fb25de0b8602e64c1714 \ + --hash=sha256:45485e01ff4d3630ec0d9617310448a8702f70e9c01906b0d0118bdf9d124cf2 \ + --hash=sha256:4a78b2b446bd7c934f5dcedc588903fb2f5eec172f3d29e52a9096a43722adfc \ + --hash=sha256:4ab2fe47fae9e0f9dee8c04187ce5d09f48eabe611be8259444906793ab7cbce \ + --hash=sha256:4d0d1650369165a14e14e1e47b372cfcb31d6ab44e6e33cb2d4e57265290044d \ + --hash=sha256:549a3a73da901d5bc3ce8d24e0600d1fa85524c10287f6004fbab87672bf3e1e \ + --hash=sha256:55086ee1064215781fff39a1af09518bc9255b50d6333f2e4c74ca09fac6a8f6 \ + --hash=sha256:572c3763a264ba47b3cf708a44ce965d98555f618ca42c926a9c1616d8f34269 \ + --hash=sha256:573f6eac48f4769d667c4442081b1794f52919e7edada77495aaed9236d13a96 \ + --hash=sha256:5b4c145409bef602a690e7cfad0a15a55c13320ff7a3ad7ca59c13bb8ba4d45d \ + --hash=sha256:6463effa3186ea09411d50efc7d85360b38d5f09b870c48e4600f63af490e56a \ + --hash=sha256:65f6f63034100ead094b8744b3b97965785388f308a64cf8d7c34f2f2e5be0c4 \ + --hash=sha256:663946639d296df6a2bb2aa51b60a2454ca1cb29835324c640dafb5ff2131a77 \ + --hash=sha256:6897af51655e3691ff853668779c7bad41579facacf5fd7253b0133308cf000d \ + --hash=sha256:68d1f8a9e9e37c1223b656399be5d6b448dea850bed7d0f87a8311f1ff3dabb0 \ + --hash=sha256:6ac7ffc7ad6d040517be39eb591cac5ff87416c2537df6ba3cba3bae290c0fed \ + --hash=sha256:6b3251890fff30ee142c44144871185dbe13b11bab478a88887a639655be1068 \ + --hash=sha256:6c4caeef8fa63d06bd437cd4bdcf3ffefe6738fb1b25951440d80dc7df8c03ac \ + --hash=sha256:6ef1d82a3af9d3eecdba2321dc1b3c238245d890843e040e41e470ffa64c3e25 \ + --hash=sha256:753f10e867343b4511128c6ed8c82f7bec3bd026875576dfd88483c5c73b2fd8 \ + --hash=sha256:7cd13a2e3ddeed6913a65e66e94b51d80a041145a026c27e6bb76c31a853c6ab \ + --hash=sha256:7ed9e526742851e8d5cc9e6cf41427dfc6068d4f5a3bb03659444b4cabf6bc26 \ + --hash=sha256:7f04c839ed0b6b98b1a7501a002144b76c18fb1c1850c8b98d458ac269e26ed2 \ + --hash=sha256:802fe99cca7457642125a8a88a084cef28ff0cf9407060f7b93dca5aa25480db \ + --hash=sha256:80402cd6ee291dcb72644d6eac93785fe2c8b9cb30893c1af5b8fdd753b9d40f \ + --hash=sha256:8465322196c8b4d7ab6d1e049e4c5cb460d0394da4a27d23cc242fbf0034b6b5 \ + --hash=sha256:86216b5cee4b06df986d214f664305142d9c76df9b6512be2738aa72a2048f99 \ + --hash=sha256:87d1351268731db79e0f8e745d92493ee2841c974128ef629dc518b937d9194c \ + --hash=sha256:8bdb58ff7ba23002a4c5808d608e4e6c687175724f54a5dade5fa8c67b604e4d \ + --hash=sha256:8c622a5fe39a48f78944a87d4fb8a53ee07344641b0562c540d840748571b811 \ + --hash=sha256:8d756e44e94489e49571086ef83b2bb8ce311e730092d2c34ca8f7d925cb20aa \ + --hash=sha256:8f4a014bc36d3c57402e2977dada34f9c12300af536839dc38c0beab8878f38a \ + --hash=sha256:9063e24fdb1e498ab71cb7419e24622516c4a04476b17a2dab57e8baa30d6e03 \ + --hash=sha256:90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b \ + --hash=sha256:923c0c831b7cfcb071580d3f46c4baf50f174be571576556269530f4bbd79d04 \ + --hash=sha256:95f2a5796329323b8f0512e09dbb7a1860c46a39da62ecb2324f116fa8fdc85c \ + --hash=sha256:96b02a3dc4381e5494fad39be677abcb5e6634bf7b4fa83a6dd3112607547001 \ + --hash=sha256:9f96df6923e21816da7e0ad3fd47dd8f94b2a5ce594e00677c0013018b813458 \ + --hash=sha256:a10af20b82360ab00827f916a6058451b723b4e65030c5a18577c8b2de5b3389 \ + --hash=sha256:a50aebfa173e157099939b17f18600f72f84eed3049e743b68ad15bd69b6bf99 \ + --hash=sha256:a981a536974bbc7a512cf44ed14938cf01030a99e9b3a06dd59578882f06f985 \ + --hash=sha256:a9a8e9031d613fd2009c182b69c7b2c1ef8239a0efb1df3f7c8da66d5dd3d537 \ + --hash=sha256:ae5f4161f18c61806f411a13b0310bea87f987c7d2ecdbdaad0e94eb2e404238 \ + --hash=sha256:aed38f6e4fb3f5d6bf81bfa990a07806be9d83cf7bacef998ab1a9bd660a581f \ + --hash=sha256:b01b88d45a6fcb69667cd6d2f7a9aeb4bf53760d7fc536bf679ec94fe9f3ff3d \ + --hash=sha256:b261ccdec7821281dade748d088bb6e9b69e6d15b30652b74cbbac25e280b796 \ + --hash=sha256:b2b0a0c0517616b6869869f8c581d4eb2dd83a4d79e0ebcb7d373ef9956aeb0a \ + --hash=sha256:b4a23f61ce87adf89be746c8a8974fe1c823c891d8f86eb218bb957c924bb143 \ + --hash=sha256:bd8f7df7d12c2db9fab40bdd87a7c09b1530128315d047a086fa3ae3435cb3a8 \ + --hash=sha256:beb58fe5cdb101e3a055192ac291b7a21e3b7ef4f67fa1d74e331a7f2124341c \ + --hash=sha256:c002b4ffc0be611f0d9da932eb0f704fe2602a9a949d1f738e4c34c75b0863d5 \ + --hash=sha256:c083af607d2515612056a31f0a8d9e0fcb5876b7bfc0abad3ecd275bc4ebc2d5 \ + --hash=sha256:c180f51afb394e165eafe4ac2936a14bee3eb10debc9d9e4db8958fe36afe711 \ + --hash=sha256:c235ebd9baae02f1b77bcea61bce332cb4331dc3617d254df3323aa01ab47bd4 \ + --hash=sha256:cd70574b12bb8a4d2aaa0094515df2463cb429d8536cfb6c7ce983246983e5a6 \ + --hash=sha256:d0eccceffcb53201b5bfebb52600a5fb483a20b61da9dbc885f8b103cbe7598c \ + --hash=sha256:d965bba47ddeec8cd560687584e88cf699fd28f192ceb452d1d7ee807c5597b7 \ + --hash=sha256:db364eca23f876da6f9e16c9da0df51aa4f104a972735574842618b8c6d999d4 \ + --hash=sha256:ddbb2551d7e0102e7252db79ba445cdab71b26640817ab1e3e3648dad515003b \ + --hash=sha256:deb6be0ac38ece9ba87dea880e438f25ca3eddfac8b002a2ec3d9183a454e8ae \ + --hash=sha256:e06ed3eb3218bc64786f7db41917d4e686cc4856944f53d5bdf83a6884432e12 \ + --hash=sha256:e27ad930a842b4c5eb8ac0016b0a54f5aebbe679340c26101df33424142c143c \ + --hash=sha256:e537484df0d8f426ce2afb2d0f8e1c3d0b114b83f8850e5f2fbea0e797bd82ae \ + --hash=sha256:eb00ed941194665c332bf8e078baf037d6c35d7c4f3102ea2d4f16ca94a26dc8 \ + --hash=sha256:eb6904c354526e758fda7167b33005998fb68c46fbc10e013ca97f21ca5c8887 \ + --hash=sha256:eb8821e09e916165e160797a6c17edda0679379a4be5c716c260e836e122f54b \ + --hash=sha256:efcb3f6676480691518c177e3b465bcddf57cea040302f9f4e6e191af91174d4 \ + --hash=sha256:f27273b60488abe721a075bcca6d7f3964f9f6f067c8c4c605743023d7d3944f \ + --hash=sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5 \ + --hash=sha256:fb69256e180cb6c8a894fee62b3afebae785babc1ee98b81cdf68bbca1987f33 \ + --hash=sha256:fd1abc0d89e30cc4e02e4064dc67fcc51bd941eb395c502aac3ec19fab46b519 \ + --hash=sha256:ff8fa367d09b717b2a17a052544193ad76cd49979c805768879cb63d9ca50561 + # via requests +idna==3.7 \ + --hash=sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc \ + --hash=sha256:82fee1fc78add43492d3a1898bfa6d8a904cc97d8427f683ed8e798d07761aa0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # via requests mpmath==1.3.0 \ --hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \ @@ -109,7 +217,11 @@ mpmath==1.3.0 \ networkx==2.8.8 \ --hash=sha256:230d388117af870fce5647a3c52401fcf753e94720e6ea6b4197a5355648885e \ --hash=sha256:e435dfa75b1d7195c7b8378c3859f0445cd88c6b0375c181ed66823a9ceb7524 +<<<<<<< HEAD # via -r tools/build/bazel/requirements.in +======= + # via -r requirements.in +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) numpy==1.26.4 \ --hash=sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b \ --hash=sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818 \ @@ -147,6 +259,7 @@ numpy==1.26.4 \ --hash=sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef \ --hash=sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3 \ --hash=sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f +<<<<<<< HEAD # via -r tools/build/bazel/requirements.in pyyaml==6.0.2 \ --hash=sha256:01179a4a8559ab5de078078f37e5c1a30d76bb88519906844fd7bdea1b7729ff \ @@ -223,3 +336,81 @@ urllib3==2.5.0 \ --hash=sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760 \ --hash=sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc # via requests +======= + # via -r requirements.in +pyyaml==6.0.1 \ + --hash=sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5 \ + --hash=sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc \ + --hash=sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df \ + --hash=sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741 \ + --hash=sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206 \ + --hash=sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27 \ + --hash=sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595 \ + --hash=sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62 \ + --hash=sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98 \ + --hash=sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696 \ + --hash=sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290 \ + --hash=sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9 \ + --hash=sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d \ + --hash=sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6 \ + --hash=sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867 \ + --hash=sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47 \ + --hash=sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486 \ + --hash=sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6 \ + --hash=sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3 \ + --hash=sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007 \ + --hash=sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938 \ + --hash=sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0 \ + --hash=sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c \ + --hash=sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735 \ + --hash=sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d \ + --hash=sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28 \ + --hash=sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4 \ + --hash=sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba \ + --hash=sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8 \ + --hash=sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef \ + --hash=sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5 \ + --hash=sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd \ + --hash=sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3 \ + --hash=sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0 \ + --hash=sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515 \ + --hash=sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c \ + --hash=sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c \ + --hash=sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924 \ + --hash=sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34 \ + --hash=sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43 \ + --hash=sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859 \ + --hash=sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673 \ + --hash=sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54 \ + --hash=sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a \ + --hash=sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b \ + --hash=sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab \ + --hash=sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa \ + --hash=sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c \ + --hash=sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585 \ + --hash=sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d \ + --hash=sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f + # via -r requirements.in +requests==2.32.2 \ + --hash=sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289 \ + --hash=sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c + # via -r requirements.in +sympy==1.12 \ + --hash=sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5 \ + --hash=sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8 + # via -r requirements.in +typing-extensions==4.11.0 \ + --hash=sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0 \ + --hash=sha256:c1f94d72897edaf4ce775bb7558d5b79d8126906a14ea5ed1635921406c0387a + # via -r requirements.in +urllib3==2.2.2 \ + --hash=sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472 \ + --hash=sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168 + # via requests + +# The following packages are considered to be unsafe in a requirements file: +setuptools==78.1.1 \ + --hash=sha256:c3a9c4211ff4c309edb8b8c4f1cbfa7ae324c4ba9f91ff254e3d305b9fd54561 \ + --hash=sha256:fcc17fd9cd898242f6b4adfaca46137a9edef687f43e6f78469692a5e70d851d + # via -r requirements.in +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/tools/build_pytorch_libs.py b/tools/build_pytorch_libs.py index 9d43de80f1298..815e4948e5960 100644 --- a/tools/build_pytorch_libs.py +++ b/tools/build_pytorch_libs.py @@ -2,7 +2,10 @@ import os import platform +<<<<<<< HEAD import subprocess +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .optional_submodules import checkout_nccl from .setup_helpers.cmake import CMake, USE_NINJA @@ -88,8 +91,12 @@ def build_pytorch( ) -> None: my_env = _create_build_env() if ( +<<<<<<< HEAD not check_negative_env_flag("USE_DISTRIBUTED") and not check_negative_env_flag("USE_CUDA") +======= + not check_negative_env_flag("USE_CUDA") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) and not check_negative_env_flag("USE_NCCL") and not check_env_flag("USE_SYSTEM_NCCL") ): @@ -100,6 +107,7 @@ def build_pytorch( ) if cmake_only: return +<<<<<<< HEAD build_custom_step = os.getenv("BUILD_CUSTOM_STEP") if build_custom_step: try: @@ -116,4 +124,6 @@ def build_pytorch( print("Output (stdout and stderr):") print(e.output) raise +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cmake.build(my_env) diff --git a/tools/build_with_debinfo.py b/tools/build_with_debinfo.py index d2e0fefa61ac2..0bd63d1bce990 100755 --- a/tools/build_with_debinfo.py +++ b/tools/build_with_debinfo.py @@ -95,8 +95,12 @@ def main() -> None: sys.exit(-95) if not is_devel_setup(): print( +<<<<<<< HEAD "Not a devel setup of PyTorch, " "please run `python -m pip install --no-build-isolation -v -e .` first" +======= + "Not a devel setup of PyTorch, please run `python3 setup.py develop --user` first" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) sys.exit(-1) if not has_build_ninja(): diff --git a/tools/dynamo/gb_id_mapping.py b/tools/dynamo/gb_id_mapping.py index 8fef79bd80777..7a8c91f435aea 100644 --- a/tools/dynamo/gb_id_mapping.py +++ b/tools/dynamo/gb_id_mapping.py @@ -1,7 +1,13 @@ +<<<<<<< HEAD +======= +# mypy: ignore-errors + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import argparse import ast import json import re +<<<<<<< HEAD from pathlib import Path from typing import Any, Optional @@ -18,16 +24,43 @@ def load_registry(path: Path) -> dict[str, Any]: def save_registry(reg: dict[str, Any], path: Path) -> None: +======= +import sys +from pathlib import Path + + +def get_source_segment(source, node): + return ast.get_source_segment(source, node) + + +def load_registry(path): + if path.exists(): + with path.open() as f: + return json.load(f) + return {} + + +def save_registry(reg, path): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with path.open("w") as f: json.dump(reg, f, indent=2) +<<<<<<< HEAD def next_gb_id(reg: dict[str, Any]) -> str: ids = [int(x[2:]) for x in reg if x.startswith("GB") and x[2:].isdigit()] return f"GB{(max(ids, default=-1) + 1):04d}" def clean_string(s: Any) -> Any: +======= +def next_gb_id(reg): + ids = [int(x[2:]) for x in reg if x.startswith("GB") and x[2:].isdigit()] + return f"GB{(max(ids, default=0) + 1):04d}" + + +def clean_string(s): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Normalizes string literals by removing formatting artifacts and escape sequences. Handles f-strings, quotes, newlines, and other syntax elements for cleaner output. @@ -48,6 +81,7 @@ def clean_string(s: Any) -> Any: return s +<<<<<<< HEAD def expand_hints(hints: list[str], dynamo_dir: Optional[str] = None) -> list[str]: """ Expands hint references to their actual values from graph_break_hints. @@ -71,10 +105,21 @@ def expand_hints(hints: list[str], dynamo_dir: Optional[str] = None) -> list[str name: value for name, value in hints_namespace.items() if isinstance(value, list) and name.isupper() and not name.startswith("_") +======= +def expand_hints(hints): + # Expands hint references to their actual values from graph_break_hints. + from torch._dynamo import graph_break_hints + + hint_constants = { + name: value + for name, value in graph_break_hints.__dict__.items() + if isinstance(value, list) and name.isupper() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } expanded_hints = [] for hint in hints: +<<<<<<< HEAD expanded = False for name, value in hint_constants.items(): if f"*graph_break_hints.{name}" in hint: @@ -88,6 +133,16 @@ def expand_hints(hints: list[str], dynamo_dir: Optional[str] = None) -> list[str def extract_info_from_keyword(source: str, kw: ast.keyword) -> Any: +======= + for name, value in hint_constants.items(): + if f"*graph_break_hints.{name}" in hint: + expanded_hints.extend(value) + break + return expanded_hints + + +def extract_info_from_keyword(source, kw): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Extracts and returns the value of a keyword argument from an AST node. @@ -113,6 +168,7 @@ def extract_info_from_keyword(source: str, kw: ast.keyword) -> Any: return clean_string(param_source) +<<<<<<< HEAD def find_unimplemented_v2_calls( path: str, dynamo_dir: Optional[str] = None ) -> list[dict[str, Any]]: @@ -123,6 +179,16 @@ def find_unimplemented_v2_calls( file_paths = path_obj.glob("**/*.py") else: file_paths = [path_obj] # type: ignore[assignment] +======= +def find_unimplemented_v2_calls(path): + results = [] + path = Path(path) + + if path.is_dir(): + file_paths = path.glob("**/*.py") + else: + file_paths = [path] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for file_path in file_paths: with open(file_path) as f: @@ -132,18 +198,28 @@ def find_unimplemented_v2_calls( for node in ast.walk(tree): if isinstance(node, ast.FunctionDef): +<<<<<<< HEAD if node.name in ( "unimplemented_v2", "unimplemented_v2_with_warning", ): +======= + if node.name == "unimplemented_v2": +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue if ( isinstance(node, ast.Call) and isinstance(node.func, ast.Name) +<<<<<<< HEAD and node.func.id in ("unimplemented_v2", "unimplemented_v2_with_warning") ): info: dict[str, Any] = { +======= + and node.func.id == "unimplemented_v2" + ): + info = { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "gb_type": None, "context": None, "explanation": None, @@ -165,7 +241,11 @@ def find_unimplemented_v2_calls( expanded_hints.extend(items) if "*graph_break_hints." in hints: +<<<<<<< HEAD expanded_hints.extend(expand_hints([hints], dynamo_dir)) +======= + expanded_hints.extend(expand_hints([hints])) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) info["hints"] = expanded_hints @@ -176,7 +256,125 @@ def find_unimplemented_v2_calls( return results +<<<<<<< HEAD def create_registry(dynamo_dir: str, registry_path: str) -> None: +======= +def cmd_add_new_gb_type(gb_type, file_path, registry_path, additional_info=None): + """ + Add a new graph break type to the registry. + + Args: + gb_type: The graph break type to add + file_path: Path to the file containing the unimplemented_v2 call + registry_path: Path to the registry JSON file + """ + registry_path = Path(registry_path) + reg = load_registry(registry_path) + + existing_gb_types = {entry[0]["Gb_type"] for entry in reg.values()} + if gb_type in existing_gb_types: + print( + f"Error: gb_type '{gb_type}' already exists in registry. Please rename the gb_type so it can be unique." + ) + return False + + calls = find_unimplemented_v2_calls(Path(file_path)) + matching_call = next((call for call in calls if call["gb_type"] == gb_type), None) + + if not matching_call: + print( + f"Error: Could not find unimplemented_v2 call with gb_type '{gb_type}' in {file_path}" + ) + return False + + gb_id = next_gb_id(reg) + reg[gb_id] = [ + { + "Gb_type": gb_type, + "Context": matching_call["context"], + "Explanation": matching_call["explanation"], + "Hints": matching_call["hints"] or [], + **({"Additional_Info": [additional_info]} if additional_info else {}), + } + ] + + save_registry(reg, registry_path) + print(f"Added {gb_type} to registry with ID {gb_id}") + return True + + +def cmd_update_gb_type( + old_gb_type, file_path, registry_path, new_gb_type=None, additional_info=None +): + """ + Update an existing graph break type in the registry by adding a new version + to the version history list. + + Args: + old_gb_type: The current graph break type to update + file_path: Path to the file containing the updated unimplemented_v2 call + registry_path: Path to the registry JSON file + new_gb_type: Optional new gb_type name to replace the old one + """ + registry_path = Path(registry_path) + reg = load_registry(registry_path) + + gb_id_map = {entry[0]["Gb_type"]: id for id, entry in reg.items()} + gb_id = gb_id_map.get(old_gb_type) + + if gb_id is None: + print(f"Error: gb_type '{old_gb_type}' not found in registry.") + return False + + search_gb_type = new_gb_type if new_gb_type else old_gb_type + calls = find_unimplemented_v2_calls(Path(file_path)) + matching_call = next( + (call for call in calls if call["gb_type"] == search_gb_type), None + ) + + if not matching_call: + print( + f"Error: Could not find unimplemented_v2 call with gb_type '{search_gb_type}' in {file_path}" + ) + return False + + if ( + matching_call["gb_type"] != old_gb_type + and matching_call["gb_type"] in gb_id_map + ): + print( + f"Error: New gb_type '{matching_call['gb_type']}' already exists in registry. Please use a unique gb_type." + ) + return False + + new_entry = { + "Gb_type": matching_call["gb_type"], + "Context": matching_call["context"], + "Explanation": matching_call["explanation"], + "Hints": matching_call["hints"] or [], + } + + if additional_info: + additional_info_list = reg[gb_id][0].get("Additional_Info", []) + new_entry["Additional_Info"] = ( + additional_info_list + [additional_info] + if additional_info_list + else [additional_info] + ) + elif "Additional_Info" in reg[gb_id][0]: + new_entry["Additional_Info"] = reg[gb_id][0]["Additional_Info"] + + reg[gb_id].insert(0, new_entry) + + save_registry(reg, registry_path) + print( + f"Updated {old_gb_type} to {matching_call['gb_type']} in registry with ID {gb_id}" + ) + return True + + +def create_registry(dynamo_dir, registry_path): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) calls = find_unimplemented_v2_calls(dynamo_dir) registry = {} @@ -202,9 +400,16 @@ def create_registry(dynamo_dir: str, registry_path: str) -> None: json.dump(registry, f, indent=2) +<<<<<<< HEAD def main() -> None: repo_root = Path(__file__).resolve().parent.parent.parent registry_path = repo_root / "torch" / "_dynamo" / "graph_break_registry.json" +======= +def main(): + script_dir = Path(__file__).resolve().parent + repo_root = script_dir.parent.parent + registry_path = script_dir / "graph_break_registry.json" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: import torch._dynamo @@ -224,6 +429,33 @@ def main() -> None: help="Directory to search for unimplemented_v2 calls.", ) +<<<<<<< HEAD +======= + add_parser = subparsers.add_parser("add", help="Add a gb_type to registry") + add_parser.add_argument("gb_type", help="The gb_type to add") + add_parser.add_argument( + "file_path", help="Path to the file containing the unimplemented_v2 call" + ) + add_parser.add_argument( + "--additional-info", help="Optional additional information to include" + ) + + update_parser = subparsers.add_parser( + "update", help="Update an existing gb_type in registry" + ) + update_parser.add_argument("gb_type", help="The gb_type to update") + update_parser.add_argument( + "file_path", + help="Path to the file containing the updated unimplemented_v2 call", + ) + update_parser.add_argument( + "--new_gb_type", help="New gb_type name if it has changed", default=None + ) + update_parser.add_argument( + "--additional-info", help="Optional additional information to include" + ) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) parser.add_argument( "--registry-path", type=str, @@ -235,6 +467,25 @@ def main() -> None: if args.command == "create": create_registry(args.dynamo_dir, args.registry_path) +<<<<<<< HEAD +======= + elif args.command == "add": + success = cmd_add_new_gb_type( + args.gb_type, args.file_path, args.registry_path, args.additional_info + ) + if not success: + sys.exit(1) + elif args.command == "update": + success = cmd_update_gb_type( + args.gb_type, + args.file_path, + args.registry_path, + args.new_gb_type, + args.additional_info, + ) + if not success: + sys.exit(1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: parser.print_help() diff --git a/tools/flight_recorder/components/builder.py b/tools/flight_recorder/components/builder.py index e0aaef31c1c32..88c212fb9b18f 100644 --- a/tools/flight_recorder/components/builder.py +++ b/tools/flight_recorder/components/builder.py @@ -24,7 +24,10 @@ Traceback, ) from tools.flight_recorder.components.utils import ( +<<<<<<< HEAD add_stack_id_in_entries, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) align_trace_from_beginning, check_current_entry_match, check_no_missing_dump_files, @@ -392,9 +395,12 @@ def build_db( # Ensure version is consistent across all ranks. check_version(version_by_ranks, version) entries = align_trace_from_beginning(entries) +<<<<<<< HEAD stack_id_trace_map: dict[str, int] = {} if args.just_print_entries: entries, stack_id_trace_map = add_stack_id_in_entries(entries) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # flattened database groups, _groups, memberships, _memberships, _pg_guids = build_groups_memberships( @@ -402,6 +408,7 @@ def build_db( ) logger.debug("built groups, memberships") +<<<<<<< HEAD if args.just_print_entries: just_print_entries( entries, _groups, _memberships, _pg_guids, args, stack_id_trace_map @@ -411,6 +418,15 @@ def build_db( if not args.allow_incomplete_ranks: check_no_missing_dump_files(entries, memberships) +======= + if not args.allow_incomplete_ranks: + check_no_missing_dump_files(entries, memberships) + + if args.just_print_entries: + just_print_entries(entries, _groups, _memberships, _pg_guids, args) + sys.exit(0) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tracebacks, collectives, nccl_calls = build_collectives( entries, _groups, _memberships, _pg_guids, version ) diff --git a/tools/flight_recorder/components/config_manager.py b/tools/flight_recorder/components/config_manager.py index abd7f5372133c..1eac16d306d28 100644 --- a/tools/flight_recorder/components/config_manager.py +++ b/tools/flight_recorder/components/config_manager.py @@ -67,7 +67,10 @@ def __init__(self: "JobConfig"): ) self.parser.add_argument("-j", "--just_print_entries", action="store_true") self.parser.add_argument("-v", "--verbose", action="store_true") +<<<<<<< HEAD self.parser.add_argument("--print_stack_trace", action="store_true") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def parse_args( self: "JobConfig", args: Optional[Sequence[str]] diff --git a/tools/flight_recorder/components/loader.py b/tools/flight_recorder/components/loader.py index 7634226bae528..3a8bebb02e1a3 100644 --- a/tools/flight_recorder/components/loader.py +++ b/tools/flight_recorder/components/loader.py @@ -78,9 +78,15 @@ def read_dir(args: argparse.Namespace) -> tuple[dict[str, dict[str, Any]], str]: if prefix is None: prefix = _determine_prefix(files) for f in files: +<<<<<<< HEAD if (offset := f.find(prefix)) == -1: continue details[f] = read_dump(f[:offset] + prefix, os.path.join(root, f)) +======= + if f.find(prefix) != 0: + continue + details[f] = read_dump(prefix, os.path.join(root, f)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) filecount += 1 if not version: version = str(details[f]["version"]) diff --git a/tools/flight_recorder/components/types.py b/tools/flight_recorder/components/types.py index 20e093688ba14..8f3cda205f4d8 100644 --- a/tools/flight_recorder/components/types.py +++ b/tools/flight_recorder/components/types.py @@ -388,17 +388,26 @@ def __init__( self, event: dict[Any, Any], memberships: dict[str, set[Any]], pg_name: str ): self.profiling_name = event["profiling_name"] +<<<<<<< HEAD comm_lib_backend, name = self.profiling_name.split(":") assert comm_lib_backend in ["nccl", "xccl"], ( f"name formatting error? {comm_lib_backend} != 'nccl' or 'xccl'" ) +======= + nccl, name = self.profiling_name.split(":") + assert nccl == "nccl", f"name formatting error? {nccl} != 'nccl'" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) parts = name.split(" ") type = parts[0] meta = parts[1] if len(parts) == 2 else None self.state = event["state"] +<<<<<<< HEAD # Store the hashed pg_name for accessing memberships, and original pg info for display self.pg_name = pg_name # This is the hashed version used for memberships lookup self.original_pg_name, self.pg_desc = event["process_group"] +======= + self.pg_name, self.pg_desc = event["process_group"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert type in COLLECTIVES | P2P | {"coalesced"}, ( f"{type} is not a supported operation" ) @@ -421,7 +430,10 @@ def __init__( else: self.input_sizes, self.output_sizes = None, None self.collective_seq_id = event["collective_seq_id"] +<<<<<<< HEAD self.stack_id = event.get("stack_id", -1) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.p2p_seq_id = event["p2p_seq_id"] self.input_dtypes = event["input_dtypes"] self.output_dtypes = event["output_dtypes"] @@ -430,9 +442,15 @@ def __init__( self.is_verbose = os.getenv("FR_TRACE_VERBOSE_OUTPUT", "0") == "1" def _init_global_src_dst(self, pg_ranks: set[Any]) -> None: +<<<<<<< HEAD pg_ranks_sorted = sorted(pg_ranks) self._src_g = pg_ranks_sorted[self._src] if self._src is not None else None self._dst_g = pg_ranks_sorted[self._dst] if self._dst is not None else None +======= + pg_ranks = sorted(pg_ranks) + self._src_g = pg_ranks[self._src] if self._src is not None else None + self._dst_g = pg_ranks[self._dst] if self._dst is not None else None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @property def src(self) -> int: @@ -461,7 +479,10 @@ def __repr__(self) -> str: f"pg_name={self.pg_name}", f"pg_description={self.pg_desc}", f"pg_size={self.pg_size}", +<<<<<<< HEAD f"stack_id={self.stack_id}", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f"state={self.state}", ) return f"{self.type}(%s)" % ", ".join(s for s in verbose_info if s) diff --git a/tools/flight_recorder/components/utils.py b/tools/flight_recorder/components/utils.py index 69455a5a433b0..d0f6f4cee0121 100644 --- a/tools/flight_recorder/components/utils.py +++ b/tools/flight_recorder/components/utils.py @@ -115,6 +115,7 @@ def visualize_ops( for r in all_ops: if len(all_ops[r]) > i: rank, event = all_rank_events[r][i] +<<<<<<< HEAD # Check if the pg_guid exists for this rank and process group pg_key = (event["process_group"][0], rank) if pg_key in _pg_guids: @@ -128,6 +129,15 @@ def visualize_ops( else: # Skip this entry if pg_guid mapping doesn't exist row.append(None) # type: ignore[arg-type] +======= + row.append( + Op( + event, + memberships, + _pg_guids[(event["process_group"][0], rank)], + ) + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) progress = True else: row.append(None) # type: ignore[arg-type] @@ -250,6 +260,7 @@ def visualize_ops( for r in all_ops: if len(all_ops[r]) > i: rank, event = all_rank_events[r][i] +<<<<<<< HEAD # Check if the pg_guid exists for this rank and process group pg_key = (event["process_group"][0], rank) if pg_key in _pg_guids: @@ -263,6 +274,15 @@ def visualize_ops( else: # Skip this entry if pg_guid mapping doesn't exist row.append(None) # type: ignore[arg-type] +======= + row.append( + Op( + event, + memberships, + _pg_guids[(event["process_group"][0], rank)], + ) + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) progress = True else: row.append(None) # type: ignore[arg-type] @@ -628,7 +648,10 @@ def just_print_entries( _memberships: dict[str, set[Any]], _pg_guids: dict[tuple[str, int], str], args: argparse.Namespace, +<<<<<<< HEAD stack_id_trace_map: dict[str, int], +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: rows = [] ranks = sorted(all_entries.keys()) @@ -663,6 +686,7 @@ def just_print_entries( logger.info(tabulate(rows, headers=headers)) +<<<<<<< HEAD if stack_id_trace_map and args.print_stack_trace: headers = ["stack_id", "frame_stack"] rows = [] @@ -674,6 +698,8 @@ def just_print_entries( logger.info(tabulate(rows, headers=headers)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def check_no_missing_dump_files( entries: dict[int, Any], memberships: list[Membership] @@ -701,6 +727,7 @@ def get_version_detail(version: str) -> tuple[int, int]: return major, minor +<<<<<<< HEAD def add_stack_id_in_entries( entries: dict[int, list[dict[str, Any]]], ) -> tuple[dict[int, list[dict[str, Any]]], dict[str, int]]: @@ -722,6 +749,8 @@ def add_stack_id_in_entries( return entries, stack_id_trace_map +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def align_trace_from_beginning( entries: dict[int, list[dict[str, Any]]], ) -> dict[int, list[dict[str, Any]]]: diff --git a/tools/linter/adapters/_linter/block.py b/tools/linter/adapters/_linter/block.py index 4097da50a7e4e..31d1463fac3c0 100644 --- a/tools/linter/adapters/_linter/block.py +++ b/tools/linter/adapters/_linter/block.py @@ -14,9 +14,12 @@ from tokenize import TokenInfo +<<<<<<< HEAD _OVERRIDES = {"@override", "@typing_extensions.override", "@typing.override"} +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @total_ordering @dc.dataclass class Block: @@ -71,6 +74,7 @@ class Category(str, Enum): @property def start_line(self) -> int: +<<<<<<< HEAD """The line number for the def or class statement""" return self.tokens[self.begin].start[0] @@ -85,6 +89,13 @@ def end_line(self) -> int: # def function(): ... # # and the dedent correctly pointed to one past the end of self.tokens +======= + return self.tokens[max(self.indent, self.index)].start[0] + + @property + def end_line(self) -> int: + return self.tokens[max(self.dedent, self.index)].start[0] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @property def line_count(self) -> int: @@ -111,7 +122,13 @@ def decorators(self) -> list[str]: @cached_property def is_override(self) -> bool: +<<<<<<< HEAD return not self.is_class and bool(_OVERRIDES.intersection(self.decorators)) +======= + return not self.is_class and any( + d.rpartition(".")[2] == "override" for d in self.decorators + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DATA_FIELDS = ( "category", @@ -159,9 +176,15 @@ def _get_decorators(tokens: Sequence[TokenInfo], block_start: int) -> list[str]: def decorators() -> Iterator[str]: rev = reversed(range(block_start)) newlines = (i for i in rev if tokens[i].type == token.NEWLINE) +<<<<<<< HEAD it = iter(itertools.chain(newlines, [-1])) # The -1 accounts for the very first line in the file +======= + newlines = itertools.chain(newlines, [-1]) # To account for the first line + + it = iter(newlines) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) end = next(it, -1) # Like itertools.pairwise in Python 3.10 for begin in it: for i in range(begin + 1, end): diff --git a/tools/linter/adapters/_linter/bracket_pairs.py b/tools/linter/adapters/_linter/bracket_pairs.py index 323f4da88bced..236ae0a15bb8e 100644 --- a/tools/linter/adapters/_linter/bracket_pairs.py +++ b/tools/linter/adapters/_linter/bracket_pairs.py @@ -16,10 +16,16 @@ def bracket_pairs(tokens: Sequence[TokenInfo]) -> dict[int, int]: """Returns a dictionary mapping opening to closing brackets""" braces: dict[int, int] = {} stack: list[int] = [] +<<<<<<< HEAD in_fstring = False for i, t in enumerate(tokens): if t.type == token.OP and not in_fstring: +======= + + for i, t in enumerate(tokens): + if t.type == token.OP: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if t.string in BRACKETS: stack.append(i) elif inv := BRACKETS_INV.get(t.string): @@ -35,11 +41,17 @@ def bracket_pairs(tokens: Sequence[TokenInfo]) -> dict[int, int]: raise ParseError(t, f"Mismatched braces '{b}' at {begin}") elif t.type == FSTRING_START: stack.append(FSTRING_START) +<<<<<<< HEAD in_fstring = True elif t.type == FSTRING_END: if stack.pop() != FSTRING_START: raise ParseError(t, "Mismatched FSTRING_START/FSTRING_END") in_fstring = False +======= + elif t.type == FSTRING_END: + if stack.pop() != FSTRING_START: + raise ParseError(t, "Mismatched FSTRING_START/FSTRING_END") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if stack: raise ParseError(t, "Left open") return braces diff --git a/tools/linter/adapters/black_linter.py b/tools/linter/adapters/black_linter.py new file mode 100644 index 0000000000000..c22a89032cfb3 --- /dev/null +++ b/tools/linter/adapters/black_linter.py @@ -0,0 +1,225 @@ +from __future__ import annotations + +import argparse +import concurrent.futures +import json +import logging +import os +import subprocess +import sys +import time +from enum import Enum +from typing import BinaryIO, NamedTuple + + +IS_WINDOWS: bool = os.name == "nt" + + +class LintSeverity(str, Enum): + ERROR = "error" + WARNING = "warning" + ADVICE = "advice" + DISABLED = "disabled" + + +class LintMessage(NamedTuple): + path: str | None + line: int | None + char: int | None + code: str + severity: LintSeverity + name: str + original: str | None + replacement: str | None + description: str | None + + +def as_posix(name: str) -> str: + return name.replace("\\", "/") if IS_WINDOWS else name + + +def _run_command( + args: list[str], + *, + stdin: BinaryIO, + timeout: int, +) -> subprocess.CompletedProcess[bytes]: + logging.debug("$ %s", " ".join(args)) + start_time = time.monotonic() + try: + return subprocess.run( + args, + stdin=stdin, + capture_output=True, + shell=IS_WINDOWS, # So batch scripts are found. + timeout=timeout, + check=True, + ) + finally: + end_time = time.monotonic() + logging.debug("took %dms", (end_time - start_time) * 1000) + + +def run_command( + args: list[str], + *, + stdin: BinaryIO, + retries: int, + timeout: int, +) -> subprocess.CompletedProcess[bytes]: + remaining_retries = retries + while True: + try: + return _run_command(args, stdin=stdin, timeout=timeout) + except subprocess.TimeoutExpired as err: + if remaining_retries == 0: + raise err + remaining_retries -= 1 + logging.warning( + "(%s/%s) Retrying because command failed with: %r", + retries - remaining_retries, + retries, + err, + ) + time.sleep(1) + + +def check_file( + filename: str, + retries: int, + timeout: int, +) -> list[LintMessage]: + try: + with open(filename, "rb") as f: + original = f.read() + with open(filename, "rb") as f: + proc = run_command( + [sys.executable, "-mblack", "--stdin-filename", filename, "-"], + stdin=f, + retries=retries, + timeout=timeout, + ) + except subprocess.TimeoutExpired: + return [ + LintMessage( + path=filename, + line=None, + char=None, + code="BLACK", + severity=LintSeverity.ERROR, + name="timeout", + original=None, + replacement=None, + description=( + "black timed out while trying to process a file. " + "Please report an issue in pytorch/pytorch with the " + "label 'module: lint'" + ), + ) + ] + except (OSError, subprocess.CalledProcessError) as err: + return [ + LintMessage( + path=filename, + line=None, + char=None, + code="BLACK", + severity=LintSeverity.ADVICE, + name="command-failed", + original=None, + replacement=None, + description=( + f"Failed due to {err.__class__.__name__}:\n{err}" + if not isinstance(err, subprocess.CalledProcessError) + else ( + "COMMAND (exit code {returncode})\n" + "{command}\n\n" + "STDERR\n{stderr}\n\n" + "STDOUT\n{stdout}" + ).format( + returncode=err.returncode, + command=" ".join(as_posix(x) for x in err.cmd), + stderr=err.stderr.decode("utf-8").strip() or "(empty)", + stdout=err.stdout.decode("utf-8").strip() or "(empty)", + ) + ), + ) + ] + + replacement = proc.stdout + if original == replacement: + return [] + + return [ + LintMessage( + path=filename, + line=None, + char=None, + code="BLACK", + severity=LintSeverity.WARNING, + name="format", + original=original.decode("utf-8"), + replacement=replacement.decode("utf-8"), + description="Run `lintrunner -a` to apply this patch.", + ) + ] + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Format files with black.", + fromfile_prefix_chars="@", + ) + parser.add_argument( + "--retries", + default=3, + type=int, + help="times to retry timed out black", + ) + parser.add_argument( + "--timeout", + default=90, + type=int, + help="seconds to wait for black", + ) + parser.add_argument( + "--verbose", + action="store_true", + help="verbose logging", + ) + parser.add_argument( + "filenames", + nargs="+", + help="paths to lint", + ) + args = parser.parse_args() + + logging.basicConfig( + format="<%(threadName)s:%(levelname)s> %(message)s", + level=logging.NOTSET + if args.verbose + else logging.DEBUG + if len(args.filenames) < 1000 + else logging.INFO, + stream=sys.stderr, + ) + + with concurrent.futures.ThreadPoolExecutor( + max_workers=os.cpu_count(), + thread_name_prefix="Thread", + ) as executor: + futures = { + executor.submit(check_file, x, args.retries, args.timeout): x + for x in args.filenames + } + for future in concurrent.futures.as_completed(futures): + try: + for lint_message in future.result(): + print(json.dumps(lint_message._asdict()), flush=True) + except Exception: + logging.critical('Failed at "%s".', futures[future]) + raise + + +if __name__ == "__main__": + main() diff --git a/tools/linter/adapters/codespell_linter.py b/tools/linter/adapters/codespell_linter.py index 13498cff13204..396947955d9c1 100644 --- a/tools/linter/adapters/codespell_linter.py +++ b/tools/linter/adapters/codespell_linter.py @@ -49,8 +49,13 @@ def format_error_message( if message is None and error is not None: message = ( f"Failed due to {error.__class__.__name__}:\n{error}\n" +<<<<<<< HEAD "Please either fix the error or add the word(s) to the dictionary file.\n" "HINT: all-lowercase words in the dictionary can cover all case variations." +======= + "Please either fix the error or " + "add the word(s) to the dictionary file (lowercase is preferred)." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return LintMessage( path=filename, diff --git a/tools/linter/adapters/docstring_linter-grandfather.json b/tools/linter/adapters/docstring_linter-grandfather.json index 49b12adb127bd..e735e94d9f86a 100644 --- a/tools/linter/adapters/docstring_linter-grandfather.json +++ b/tools/linter/adapters/docstring_linter-grandfather.json @@ -1,4 +1,5 @@ { +<<<<<<< HEAD "torch/_inductor/autoheuristic/artifacts/_MMRankingA100.py": { "class MMRankingA100": 279, "def MMRankingA100.fill_choices()": 199 @@ -15,10 +16,37 @@ "class MixedMMH100": 132, "def MixedMMH100.get_best_choices()": 85 }, +======= + "torch/_inductor/async_compile.py": { + "class AsyncCompile": 281 + }, + "torch/_inductor/autoheuristic/artifacts/_MMRankingA100.py": { + "class MMRankingA100": 278, + "def MMRankingA100.fill_choices()": 199 + }, + "torch/_inductor/autoheuristic/artifacts/_MMRankingH100.py": { + "class MMRankingH100": 303, + "def MMRankingH100.fill_choices()": 203 + }, + "torch/_inductor/autoheuristic/artifacts/_MixedMMA100.py": { + "class MixedMMA100": 132, + "def MixedMMA100.get_best_choices()": 85 + }, + "torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py": { + "class MixedMMH100": 131, + "def MixedMMH100.get_best_choices()": 85 + }, + "torch/_inductor/autotune_process.py": { + "class CUDABenchmarkRequest": 115, + "class TritonBenchmarkRequest": 121, + "def TritonBenchmarkRequest.make_run_fn()": 81 + }, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "torch/_inductor/bounds.py": { "class ValueRangeAnalysis": 107 }, "torch/_inductor/codecache.py": { +<<<<<<< HEAD "class CppPythonBindingsCodeCache": 179, "class HalideCodeCache": 357, "class PyCodeCache": 102 @@ -35,11 +63,33 @@ "class CppOverrides": 429, "class CppScheduling": 786, "class CppVecKernel": 865, +======= + "class AotCodeCompiler": 516, + "class CUDACodeCache": 107, + "class CppCodeCache": 125, + "class CppPythonBindingsCodeCache": 168, + "class HalideCodeCache": 350 + }, + "torch/_inductor/codegen/common.py": { + "class CSE": 167, + "class CSEProxy": 310, + "class Kernel": 286, + "class KernelArgs": 325, + "class OpOverrides": 227 + }, + "torch/_inductor/codegen/cpp.py": { + "class CppKernel": 572, + "class CppKernelProxy": 601, + "class CppOverrides": 429, + "class CppScheduling": 777, + "class CppVecKernel": 857, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "class OuterLoopFusedSchedulerNode": 159, "def CppKernel.codegen_loops_impl()": 144, "def CppKernelProxy.codegen_functions()": 183, "def CppKernelProxy.legalize_lowp_fp_dtype_loopbody()": 224, "def CppScheduling.fuse()": 81, +<<<<<<< HEAD "def CppVecKernel.reduction_combine_vec()": 100, "def OuterLoopFusedSchedulerNode.check_outer_fusion_loop_level_attr()": 85, "def TilingSelect.select_tiling()": 170 @@ -54,17 +104,44 @@ "torch/_inductor/codegen/cpp_grouped_gemm_template.py": { "def CppGroupedGemmTemplate.add_choices()": 154, "def CppGroupedGemmTemplate.render()": 153 +======= + "def CppVecKernel.reduction()": 193, + "def CppVecKernel.reduction_combine_vec()": 87, + "def TilingSelect.select_tiling()": 165 + }, + "torch/_inductor/codegen/cpp_flex_attention_template.py": { + "class CppFlexAttentionTemplate": 374, + "def CppFlexAttentionTemplate.modification()": 94 + }, + "torch/_inductor/codegen/cpp_gemm_template.py": { + "class CppGemmTemplate": 998, + "def CppGemmTemplate.add_choices()": 163, + "def CppGemmTemplate.get_options()": 243 + }, + "torch/_inductor/codegen/cpp_grouped_gemm_template.py": { + "def CppGroupedGemmTemplate.add_choices()": 141, + "def CppGroupedGemmTemplate.render()": 146 + }, + "torch/_inductor/codegen/cpp_micro_gemm.py": { + "def create_micro_gemm()": 94 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, "torch/_inductor/codegen/cpp_template.py": { "class CppTemplate": 114 }, "torch/_inductor/codegen/cpp_template_kernel.py": { +<<<<<<< HEAD "class CppTemplateKernel": 499, "def CppTemplateKernel.store_outputs()": 111 +======= + "class CppTemplateKernel": 469, + "def CppTemplateKernel.store_outputs()": 102 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, "torch/_inductor/codegen/cpp_utils.py": { "def create_epilogue_with_attr()": 165 }, +<<<<<<< HEAD "torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py": { "def CppWrapperCpuArrayRef.generate_return()": 128, "def CppWrapperCpuArrayRef.write_wrapper_decl()": 208 @@ -88,15 +165,65 @@ }, "torch/_inductor/codegen/rocm/ck_universal_gemm_template.py": { "class CKGemmTemplate": 950 +======= + "torch/_inductor/codegen/cpp_wrapper_cpu.py": { + "def CppWrapperCpu.generate_extern_kernel_args_decl_if_needed()": 152, + "def CppWrapperCpu.generate_input_output_runtime_checks()": 115, + "def CppWrapperCpu.generate_py_arg()": 96, + "def CppWrapperCpu.val_to_arg_str()": 88, + "def CppWrapperCpu.write_wrapper_decl()": 140 + }, + "torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py": { + "def CppWrapperCpuArrayRef.generate_return()": 127, + "def CppWrapperCpuArrayRef.write_wrapper_decl()": 208 + }, + "torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py": { + "def EmitGemmUniversal3xInstanceWithEVT.emit()": 98 + }, + "torch/_inductor/codegen/cuda/device_op_overrides.py": { + "class CUDADeviceOpOverrides": 222, + "def CUDADeviceOpOverrides.tma_descriptor_helpers()": 102 + }, + "torch/_inductor/codegen/cuda/gemm_template.py": { + "class CUTLASS2xGemmTemplate": 265, + "class CUTLASS3xGemmTemplate": 326 + }, + "torch/_inductor/codegen/debug_utils.py": { + "class DebugPrinterManager": 228 + }, + "torch/_inductor/codegen/halide.py": { + "class HalideKernel": 982, + "class HalideOverrides": 329, + "class HalidePrinter": 129, + "def HalideKernel.halide_kernel_meta()": 82 + }, + "torch/_inductor/codegen/mps.py": { + "class MetalKernel": 354, + "class MetalOverrides": 335, + "def MetalKernel.reduction()": 109 + }, + "torch/_inductor/codegen/rocm/ck_conv_template.py": { + "class CKGroupedConvFwdTemplate": 531, + "def CKGroupedConvFwdTemplate.globals()": 143 + }, + "torch/_inductor/codegen/rocm/ck_universal_gemm_template.py": { + "class CKGemmTemplate": 947 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, "torch/_inductor/codegen/rocm/rocm_benchmark_request.py": { "class ROCmBenchmarkRequest": 117 }, "torch/_inductor/codegen/simd.py": { +<<<<<<< HEAD +======= + "class IterationRangesRoot": 122, + "class SIMDScheduling": 1054, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "def SIMDScheduling.candidate_tilings()": 126, "def SIMDScheduling.generate_node_schedule()": 95 }, "torch/_inductor/codegen/triton.py": { +<<<<<<< HEAD "class TritonKernel": 2562, "class TritonOverrides": 469, "class TritonPrinter": 172, @@ -107,6 +234,19 @@ "def TritonKernel.reduction()": 396, "def TritonKernel.scan()": 110, "def TritonScheduling.benchmark_codegened_module()": 85, +======= + "class BlockPtrOptions": 272, + "class TritonKernel": 2455, + "class TritonOverrides": 505, + "class TritonPrinter": 172, + "class TritonScheduling": 396, + "def TritonKernel.codegen_kernel()": 222, + "def TritonKernel.codegen_kernel_benchmark()": 89, + "def TritonKernel.load()": 134, + "def TritonKernel.reduction()": 383, + "def TritonKernel.scan()": 103, + "def TritonScheduling.benchmark_codegened_module()": 83, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "def TritonScheduling.benchmark_combo_kernel()": 91 }, "torch/_inductor/codegen/triton_combo_kernel.py": { @@ -118,6 +258,7 @@ }, "torch/_inductor/codegen/wrapper.py": { "def PythonWrapperCodegen.benchmark_compiled_module()": 92, +<<<<<<< HEAD "def PythonWrapperCodegen.define_user_defined_triton_kernel()": 266, "def PythonWrapperCodegen.generate_example_arg_value()": 84, "def user_defined_kernel_grid_fn_code()": 102 @@ -150,10 +291,54 @@ }, "torch/_inductor/fx_passes/b2b_gemm.py": { "def b2b_gemm_handler()": 182 +======= + "def PythonWrapperCodegen.define_user_defined_triton_kernel()": 249, + "def PythonWrapperCodegen.generate_example_arg_value()": 83, + "def user_defined_kernel_grid_fn_code()": 96 + }, + "torch/_inductor/comm_lowering.py": { + "def register_comm_lowerings()": 189 + }, + "torch/_inductor/comms.py": { + "def enforce_comm_ordering_for_fsdp()": 170, + "def reinplace_fsdp_all_gather()": 110 + }, + "torch/_inductor/compile_fx.py": { + "def _InProcessFxCompile.codegen_and_compile()": 379, + "def fw_compiler_freezing()": 93 + }, + "torch/_inductor/config.py": { + "class cpp": 107, + "class triton": 182 + }, + "torch/_inductor/constant_folding.py": { + "class ConstantFolder": 223, + "def ConstantFolder.run_node()": 94 + }, + "torch/_inductor/cpu_vec_isa.py": { + "class VecISA": 120 + }, + "torch/_inductor/debug.py": { + "class DebugContext": 158, + "class DebugFormatter": 189, + "def DebugFormatter.log_autotuning_results()": 81 + }, + "torch/_inductor/dependencies.py": { + "class MemoryDep": 225 + }, + "torch/_inductor/fx_passes/b2b_gemm.py": { + "def b2b_gemm_handler()": 180 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, "torch/_inductor/fx_passes/binary_folding.py": { "def binary_folding_init()": 416 }, +<<<<<<< HEAD +======= + "torch/_inductor/fx_passes/freezing_patterns.py": { + "def addmm_patterns_init()": 94 + }, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "torch/_inductor/fx_passes/group_batch_fusion.py": { "def BatchLayernormFusion.fuse()": 131, "def PostGradBatchLinearFusion.fuse()": 83, @@ -161,12 +346,17 @@ }, "torch/_inductor/fx_passes/joint_graph.py": { "def constant_fold_uniform_value()": 109, +<<<<<<< HEAD "def remove_no_ops()": 97 +======= + "def remove_no_ops()": 93 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, "torch/_inductor/fx_passes/micro_pipeline_tp.py": { "def find_all_gather_patterns()": 116, "def find_reduce_scatter_patterns()": 125 }, +<<<<<<< HEAD "torch/_inductor/fx_passes/split_cat.py": { "def SplitCatSimplifier.replace_cat()": 152, "def merge_getitem_cat()": 97, @@ -213,11 +403,83 @@ }, "torch/_inductor/kernel/mm.py": { "def tuned_addmm()": 151 +======= + "torch/_inductor/fx_passes/post_grad.py": { + "def lower_scan_to_while_loop()": 154 + }, + "torch/_inductor/fx_passes/split_cat.py": { + "def SplitCatSimplifier.replace_cat()": 145, + "def merge_getitem_cat()": 97, + "def merge_split_cat_aten()": 87, + "def move_reshape_out_of_split_stack()": 110 + }, + "torch/_inductor/fx_utils.py": { + "def FakeTensorUpdater.incremental_update()": 100 + }, + "torch/_inductor/graph.py": { + "class GraphLowering": 2032, + "def GraphLowering.call_function()": 116, + "def GraphLowering.extract_autotune_inputs()": 90, + "def GraphLowering.output()": 87, + "def GraphLowering.placeholder()": 92, + "def GraphLowering.run_node()": 380 + }, + "torch/_inductor/ir.py": { + "class Buffer": 122, + "class ComputedBuffer": 329, + "class Conditional": 138, + "class ExternKernel": 793, + "class FallbackKernel": 439, + "class FlexibleLayout": 139, + "class IRNode": 244, + "class Layout": 202, + "class Loops": 128, + "class Reduction": 737, + "class Scan": 199, + "class Sort": 150, + "class UserDefinedTritonKernel": 183, + "class View": 174, + "class WelfordReduction": 221, + "class WhileLoop": 203, + "def ConcatKernel.create()": 95, + "def ExternKernel.process_kernel()": 110, + "def ExternKernel.require_strides()": 149, + "def FallbackKernel.create()": 81, + "def FallbackKernel.export_extern_kernel_node()": 82, + "def Reduction.create()": 136, + "def Reduction.num_splits()": 152, + "def Scan.create()": 83, + "def WelfordReduction.create()": 110, + "def WhileLoop.create()": 161 + }, + "torch/_inductor/jagged_lowerings.py": { + "def register_jagged_ops()": 156 + }, + "torch/_inductor/kernel/bmm.py": { + "def tuned_bmm()": 91 + }, + "torch/_inductor/kernel/conv.py": { + "def convolution()": 231 + }, + "torch/_inductor/kernel/flex_attention.py": { + "def flex_attention()": 303, + "def flex_attention_backward()": 323, + "def lower_cpu()": 273 + }, + "torch/_inductor/kernel/flex_decoding.py": { + "def create_flex_decoding_kernel()": 288 + }, + "torch/_inductor/kernel/mm.py": { + "def tuned_addmm()": 169, + "def tuned_mm()": 127, + "def tuned_scaled_mm()": 130 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, "torch/_inductor/loop_body.py": { "class CaptureIndexing": 174 }, "torch/_inductor/lowering.py": { +<<<<<<< HEAD "def avg_pool2d_backward()": 164, "def avg_pool3d_backward()": 198, "def cat()": 123, @@ -235,11 +497,29 @@ }, "torch/_inductor/mkldnn_lowerings.py": { "def register_onednn_fusion_ops()": 1156 +======= + "def avg_pool2d_backward()": 155, + "def avg_pool3d_backward()": 189, + "def cat()": 123, + "def index_put_impl_()": 125, + "def make_pointwise()": 85, + "def max_pool2d_with_indices_backward()": 140, + "def scatter_reduce_()": 111, + "def sdpa_constraint()": 132, + "def searchsorted()": 84 + }, + "torch/_inductor/mkldnn_ir.py": { + "class MkldnnRnnLayer": 114 + }, + "torch/_inductor/mkldnn_lowerings.py": { + "def register_onednn_fusion_ops()": 1152 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, "torch/_inductor/mock_cache.py": { "class PatchCaches": 108 }, "torch/_inductor/pattern_matcher.py": { +<<<<<<< HEAD "class ReplacementPatternEntry": 202, "def ReplacementPatternEntry.replace_with_graph()": 188 }, @@ -248,11 +528,22 @@ }, "torch/_inductor/runtime/autotune_cache.py": { "class AutotuneCache": 201 +======= + "class ReplacementPatternEntry": 196, + "def ReplacementPatternEntry.replace_with_graph()": 177 + }, + "torch/_inductor/quantized_lowerings.py": { + "def register_woq_mm_ops()": 136 + }, + "torch/_inductor/runtime/autotune_cache.py": { + "class AutotuneCache": 190 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, "torch/_inductor/runtime/benchmarking.py": { "class InductorBenchmarker": 111 }, "torch/_inductor/scheduler.py": { +<<<<<<< HEAD "class BaseSchedulerNode": 695, "class BaseScheduling": 142, "class SchedulerBuffer": 106, @@ -263,5 +554,31 @@ }, "torch/_inductor/utils.py": { "class IndentedBuffer": 145 +======= + "class BaseSchedulerNode": 697, + "class BaseScheduling": 139, + "class Scheduler": 2568, + "class SchedulerBuffer": 103, + "class SchedulerNode": 256 + }, + "torch/_inductor/select_algorithm.py": { + "class AlgorithmSelectorCache": 694, + "class TritonTemplate": 224, + "class TritonTemplateKernel": 770, + "def AlgorithmSelectorCache.log_results()": 92, + "def AlgorithmSelectorCache.make_benchmark_fn[2]()": 145 + }, + "torch/_inductor/sizevars.py": { + "class SizeVarAllocator": 780 + }, + "torch/_inductor/template_heuristics.py": { + "class ROCmConfigHeuristic": 212 + }, + "torch/_inductor/utils.py": { + "class IndentedBuffer": 136 + }, + "torch/_inductor/wrapper_benchmark.py": { + "def parse_profile_event_list()": 119 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } \ No newline at end of file diff --git a/tools/linter/adapters/docstring_linter.py b/tools/linter/adapters/docstring_linter.py index 477bfe7d9a809..a0896c271ac2a 100644 --- a/tools/linter/adapters/docstring_linter.py +++ b/tools/linter/adapters/docstring_linter.py @@ -10,7 +10,10 @@ _FILE = Path(__file__).absolute() _PATH = [Path(p).absolute() for p in sys.path] +<<<<<<< HEAD _OVERRIDES = {"@override", "@typing_extensions.override", "@typing.override"} +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if TYPE_CHECKING or _FILE.parent not in _PATH: from . import _linter @@ -155,7 +158,11 @@ def has_class_init_doc(b: _linter.Block) -> bool: def _is_bad_block(self, b: _linter.Block, pf: _linter.PythonFile) -> bool: max_lines = self._max_lines[b.category] return ( +<<<<<<< HEAD not (b.is_override or pf.omitted(pf.tokens, b.begin, b.dedent)) +======= + not pf.omitted(pf.tokens, b.begin, b.dedent) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) and b.line_count > max_lines and len(b.docstring) < self.args.min_docstring and (self.args.lint_local or not b.is_local) diff --git a/tools/linter/adapters/pip_init.py b/tools/linter/adapters/pip_init.py index 05a7a8acf9324..230cfd756e4e4 100644 --- a/tools/linter/adapters/pip_init.py +++ b/tools/linter/adapters/pip_init.py @@ -13,6 +13,7 @@ import time +<<<<<<< HEAD def run_command( args: list[str], env: dict[str, str] | None = None, @@ -21,12 +22,23 @@ def run_command( start_time = time.monotonic() try: return subprocess.run(args, env=env, text=True, encoding="utf-8", check=True) +======= +def run_command(args: list[str]) -> subprocess.CompletedProcess[bytes]: + logging.debug("$ %s", " ".join(args)) + start_time = time.monotonic() + try: + return subprocess.run(args, check=True) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) finally: end_time = time.monotonic() logging.debug("took %dms", (end_time - start_time) * 1000) +<<<<<<< HEAD def main() -> None: +======= +if __name__ == "__main__": +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) parser = argparse.ArgumentParser(description="pip initializer") parser.add_argument( "packages", @@ -41,6 +53,14 @@ def main() -> None: parser.add_argument( "--dry-run", help="do not install anything, just print what would be done." ) +<<<<<<< HEAD +======= + parser.add_argument( + "--no-black-binary", + help="do not use pre-compiled binaries from pip for black.", + action="store_true", + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) args = parser.parse_args() @@ -50,6 +70,7 @@ def main() -> None: stream=sys.stderr, ) +<<<<<<< HEAD env: dict[str, str] = { **os.environ, "UV_PYTHON": sys.executable, @@ -60,6 +81,19 @@ def main() -> None: uv_index = env.get("UV_INDEX", env.get("PIP_EXTRA_INDEX_URL")) if uv_index: env["UV_INDEX"] = uv_index +======= + uv_available = ( + any(prefix in sys.base_prefix for prefix in ["uv/python", "uv\\python"]) + and shutil.which("uv") is not None + ) + + if uv_available: + pip_args = ["uv", "pip", "install"] + elif sys.executable: + pip_args = [sys.executable, "-mpip", "install"] + else: + pip_args = ["pip3", "install"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # If we are in a global install, use `--user` to install so that you do not # need root access in order to initialize linters. @@ -67,6 +101,7 @@ def main() -> None: # However, `pip install --user` interacts poorly with virtualenvs (see: # https://bit.ly/3vD4kvl) and conda (see: https://bit.ly/3KG7ZfU). So in # these cases perform a regular installation. +<<<<<<< HEAD in_conda = env.get("CONDA_PREFIX") is not None in_virtualenv = env.get("VIRTUAL_ENV") is not None need_user_flag = not in_conda and not in_virtualenv @@ -81,6 +116,11 @@ def main() -> None: pip_args = ["pip3", "install"] if need_user_flag: +======= + in_conda = os.environ.get("CONDA_PREFIX") is not None + in_virtualenv = os.environ.get("VIRTUAL_ENV") is not None + if not in_conda and not in_virtualenv: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pip_args.append("--user") pip_args.extend(args.packages) @@ -92,14 +132,23 @@ def main() -> None: "Package {package_name} did not have a version specified. " "Please specify a version to produce a consistent linting experience." ) +<<<<<<< HEAD +======= + if args.no_black_binary and "black" in package_name: + pip_args.append(f"--no-binary={package_name}") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dry_run = args.dry_run == "1" if dry_run: print(f"Would have run: {pip_args}") sys.exit(0) +<<<<<<< HEAD run_command(pip_args, env=env) if __name__ == "__main__": main() +======= + run_command(pip_args) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/tools/linter/adapters/pyfmt_linter.py b/tools/linter/adapters/pyfmt_linter.py index ce5f8252a20f0..686f0a4e4decf 100644 --- a/tools/linter/adapters/pyfmt_linter.py +++ b/tools/linter/adapters/pyfmt_linter.py @@ -2,6 +2,10 @@ import argparse import concurrent.futures +<<<<<<< HEAD +======= +import fnmatch +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import json import logging import os @@ -12,6 +16,10 @@ from pathlib import Path from typing import NamedTuple +<<<<<<< HEAD +======= +import black +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import isort import usort @@ -19,6 +27,56 @@ IS_WINDOWS: bool = os.name == "nt" REPO_ROOT = Path(__file__).absolute().parents[3] +<<<<<<< HEAD +======= +# TODO: remove this when it gets empty and remove `black` in PYFMT +USE_BLACK_FILELIST = re.compile( + "|".join( + ( + r"\A\Z", # empty string + *map( + fnmatch.translate, + [ + # ** + # .ci/** + # .github/** + # benchmarks/** + # functorch/** + # tools/** + # torchgen/** + # test/** + # test/[a-h]*/** + # test/[i-j]*/** + "test/j*/**", + # test/[k-m]*/** + "test/[k-m]*/**", + # test/optim/** + # "test/[p-z]*/**", + "test/[p-z]*/**", + # torch/** + # torch/_[a-c]*/** + "torch/_[a-c]*/**", + # torch/_[e-h]*/** + "torch/_[e-h]*/**", + # torch/_i*/** + # torch/_[j-z]*/** + "torch/_[j-z]*/**", + # torch/[a-c]*/** + "torch/a[a-n]*/**", + "torch/a[p-z]*/**", + "torch/[b-c]*/**", + # torch/d*/** + # torch/[e-m]*/** + # torch/optim/** + # torch/[p-z]*/** + "torch/[p-z]*/**", + ], + ), + ) + ) +) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class LintSeverity(str, Enum): ERROR = "error" @@ -78,6 +136,26 @@ def run_usort(content: str, path: Path) -> str: return usort.usort_string(content, path=path, config=usort_config) +<<<<<<< HEAD +======= +def run_black(content: str, path: Path) -> str: + black_config = black.parse_pyproject_toml(black.find_pyproject_toml((str(path),))) # type: ignore[attr-defined,arg-type] + # manually patch options that do not have a 1-to-1 match in Mode arguments + black_config["target_versions"] = { + black.TargetVersion[ver.upper()] # type: ignore[attr-defined] + for ver in black_config.pop("target_version", []) + } + black_config["string_normalization"] = not black_config.pop( + "skip_string_normalization", False + ) + black_mode = black.Mode(**black_config) + black_mode.is_pyi = path.suffix.lower() == ".pyi" + black_mode.is_ipynb = path.suffix.lower() == ".ipynb" + + return black.format_str(content, mode=black_mode) + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def run_ruff_format(content: str, path: Path) -> str: try: return subprocess.check_output( @@ -109,7 +187,14 @@ def check_file(filename: str) -> list[LintMessage]: # NB: run isort first to enforce style for blank lines replacement = run_isort(replacement, path=path) replacement = run_usort(replacement, path=path) +<<<<<<< HEAD replacement = run_ruff_format(replacement, path=path) +======= + if USE_BLACK_FILELIST.match(path.absolute().relative_to(REPO_ROOT).as_posix()): + replacement = run_black(replacement, path=path) + else: + replacement = run_ruff_format(replacement, path=path) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if original == replacement: return [] diff --git a/tools/linter/adapters/test_device_bias_linter.py b/tools/linter/adapters/test_device_bias_linter.py index a2079e4fe810a..b5005eca61202 100644 --- a/tools/linter/adapters/test_device_bias_linter.py +++ b/tools/linter/adapters/test_device_bias_linter.py @@ -1,9 +1,15 @@ #!/usr/bin/env python3 """ This lint verifies that every Python test file (file that matches test_*.py or +<<<<<<< HEAD *_test.py in the test folder) has a cuda hard code in `requires_gpu()` or `requires_triton()` decorated function or `if HAS_GPU:` guarded main section, to ensure that the test not fail on other GPU devices. +======= +*_test.py in the test folder) has a cuda hard code in `requires_gpu()` +decorated function to ensure that the test not fail on other GPU. + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ from __future__ import annotations @@ -39,6 +45,7 @@ class LintMessage(NamedTuple): DEVICE_BIAS = ["cuda", "xpu", "mps"] +<<<<<<< HEAD GPU_RELATED_DECORATORS = {"requires_gpu", "requires_triton"} @@ -87,11 +94,27 @@ def __init__(self, filename: str, is_gpu_test_suite: bool) -> None: def _has_proper_decorator(self, node: ast.FunctionDef) -> bool: for d in node.decorator_list: if isinstance(d, ast.Name) and d.id in GPU_RELATED_DECORATORS: +======= + + +class DeviceBiasVisitor(ast.NodeVisitor): + def __init__(self, filename: str): + self.filename = filename + self.lint_messages: list[LintMessage] = [] + + def _has_requires_gpu_decorator(self, node: ast.FunctionDef) -> bool: + for d in node.decorator_list: + if isinstance(d, ast.Name) and d.id == "requires_gpu": +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return True if ( isinstance(d, ast.Call) and isinstance(d.func, ast.Name) +<<<<<<< HEAD and d.func.id in GPU_RELATED_DECORATORS +======= + and d.func.id == "requires_gpu" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): return True return False @@ -100,6 +123,10 @@ def _has_proper_decorator(self, node: ast.FunctionDef) -> bool: def _check_keyword_device(self, subnode: ast.keyword, msg_prefix: str) -> None: if subnode.arg != "device": return +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) val = subnode.value if isinstance(val, ast.Constant) and any( bias in val.value for bias in DEVICE_BIAS @@ -142,6 +169,7 @@ def _check_device_methods(self, subnode: ast.Call, msg_prefix: str) -> None: f"{msg_prefix} .to('{arg.value}'), suggest to use .to(GPU_TYPE)", ) +<<<<<<< HEAD def _check_with_statement(self, node: ast.With, msg_prefix: str) -> None: for item in node.items: ctx_expr = item.context_expr @@ -162,6 +190,17 @@ def _check_with_statement(self, node: ast.With, msg_prefix: str) -> None: ) def _check_node(self, node: ast.AST, msg_prefix: str) -> None: +======= + def visit_FunctionDef(self, node: ast.FunctionDef) -> None: + # Check if the function is decorated with @requires_gpu, which indicates + # that the function is intended to run on GPU devices (e.g., CUDA or XPU), + # but ensure it does not hardcode the device to CUDA. + if not self._has_requires_gpu_decorator(node): + self.generic_visit(node) + return + + msg_prefix = "`@requires_gpu` function should not hardcode" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for subnode in ast.walk(node): if isinstance(subnode, ast.keyword): self._check_keyword_device(subnode, msg_prefix) @@ -169,6 +208,7 @@ def _check_node(self, node: ast.AST, msg_prefix: str) -> None: subnode.func, ast.Attribute ): self._check_device_methods(subnode, msg_prefix) +<<<<<<< HEAD elif isinstance(subnode, ast.With): self._check_with_statement(subnode, msg_prefix) @@ -182,6 +222,9 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None: # If the function is guarded by HAS_GPU in main(), we still need to check for device bias msg_prefix = "The test suites is shared amount GPUS, should not hardcode" self._check_node(node, msg_prefix) +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.generic_visit(node) def record(self, node: ast.AST, message: str) -> None: @@ -204,16 +247,27 @@ def check_file(filename: str) -> list[LintMessage]: with open(filename) as f: source = f.read() tree = ast.parse(source, filename=filename) +<<<<<<< HEAD is_gpu_test_suite = is_main_has_gpu(tree) checker = DeviceBiasVisitor(filename, is_gpu_test_suite) checker.visit(tree) +======= + checker = DeviceBiasVisitor(filename) + checker.visit(tree) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return checker.lint_messages def main() -> None: parser = argparse.ArgumentParser( +<<<<<<< HEAD description="Detect Device bias in functions decorated with requires_gpu/requires_triton" " or guarded by HAS_GPU block in main() that may break other GPU devices.", +======= + description="Detect Device bias in python functions decorated with [require_gpu]" + " that may potentially break support for other GPU devices.", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fromfile_prefix_chars="@", ) parser.add_argument( diff --git a/tools/linter/dictionary.txt b/tools/linter/dictionary.txt index c4a250db04836..d438a7bb9bf0a 100644 --- a/tools/linter/dictionary.txt +++ b/tools/linter/dictionary.txt @@ -1,3 +1,4 @@ +<<<<<<< HEAD aLoad aLoads ans @@ -8,18 +9,28 @@ bLoad bLoads bStore bStores +======= +ans +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) BU contiguities contiguity coo DEPENDEES +<<<<<<< HEAD deser din dout +======= +Din +Dout +dOut +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ElementE followings fro froms +<<<<<<< HEAD Halfs hsa indexT @@ -44,12 +55,23 @@ overrideable oW padD posIn +======= +hsa +nd +nin +nout +NowNs +optins +OT +overrideable +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ptd rebuild rebuilt reenable reenabled requestor +<<<<<<< HEAD ser serde serder @@ -62,3 +84,8 @@ te THW tne WONT +======= +ser'de +supercedes +te +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/tools/lite_interpreter/gen_selected_mobile_ops_header.py b/tools/lite_interpreter/gen_selected_mobile_ops_header.py index f90d33c5ba452..d6e63edc01910 100644 --- a/tools/lite_interpreter/gen_selected_mobile_ops_header.py +++ b/tools/lite_interpreter/gen_selected_mobile_ops_header.py @@ -25,8 +25,13 @@ selected_kernel_dtypes_h_template_str = """ #include +<<<<<<< HEAD #include #include +======= +#include +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace at { inline constexpr bool should_include_kernel_dtype( diff --git a/tools/lldb/deploy_debugger.py b/tools/lldb/deploy_debugger.py new file mode 100644 index 0000000000000..7a28c72a6caf2 --- /dev/null +++ b/tools/lldb/deploy_debugger.py @@ -0,0 +1,38 @@ +import lldb # type: ignore[import] + + +# load into lldb instance with: +# command script import tools/lldb/deploy_debugger.py + +target = lldb.debugger.GetSelectedTarget() +bp = target.BreakpointCreateByRegex("__deploy_register_code") +bp.SetScriptCallbackBody( + """\ +process = frame.thread.GetProcess() +target = process.target +symbol_addr = frame.module.FindSymbol("__deploy_module_info").GetStartAddress() +info_addr = symbol_addr.GetLoadAddress(target) +e = lldb.SBError() +ptr_size = 8 +str_addr = process.ReadPointerFromMemory(info_addr, e) +file_addr = process.ReadPointerFromMemory(info_addr + ptr_size, e) +file_size = process.ReadPointerFromMemory(info_addr + 2*ptr_size, e) +load_bias = process.ReadPointerFromMemory(info_addr + 3*ptr_size, e) +name = process.ReadCStringFromMemory(str_addr, 512, e) +r = process.ReadMemory(file_addr, file_size, e) +from tempfile import NamedTemporaryFile +from pathlib import Path +stem = Path(name).stem +with NamedTemporaryFile(prefix=stem, suffix='.so', delete=False) as tf: + tf.write(r) + print("torch_deploy registering debug information for ", tf.name) + cmd1 = f"target modules add {tf.name}" + # print(cmd1) + lldb.debugger.HandleCommand(cmd1) + cmd2 = f"target modules load -f {tf.name} -s {hex(load_bias)}" + # print(cmd2) + lldb.debugger.HandleCommand(cmd2) + +return False +""" +) diff --git a/tools/nightly.py b/tools/nightly.py index ba66eb7022288..14d22d6d7f10f 100755 --- a/tools/nightly.py +++ b/tools/nightly.py @@ -250,7 +250,10 @@ def __init__( self._env = { "PIP_EXTRA_INDEX_URL": self.pip_source.index_url, "UV_INDEX": self.pip_source.index_url, +<<<<<<< HEAD "UV_PYTHON_DOWNLOADS": "never", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "FORCE_COLOR": "1", "CLICOLOR_FORCE": "1", } @@ -436,7 +439,11 @@ def python( check=check, text=True, encoding="utf-8", +<<<<<<< HEAD env={**os.environ, **self._env, **env}, +======= + env={**self._env, **env}, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) **popen_kwargs, ) @@ -476,12 +483,20 @@ def uv( cmd = [str(self.bindir / "uv"), *args] env = popen_kwargs.pop("env", None) or {} check = popen_kwargs.pop("check", True) +<<<<<<< HEAD +======= + env["UV_PYTHON"] = str(python) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return subprocess.run( cmd, check=check, text=True, encoding="utf-8", +<<<<<<< HEAD env={**os.environ, **self._env, **env, "UV_PYTHON": str(python)}, +======= + env={**self._env, **env}, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) **popen_kwargs, ) @@ -686,7 +701,11 @@ def logging_manager(*, debug: bool = False) -> Generator[logging.Logger, None, N logging_record_exception(e) print(f"log file: {log_file}") sys.exit(1) +<<<<<<< HEAD except BaseException as e: # noqa: B036 +======= + except BaseException as e: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # You could logging.debug here to suppress the backtrace # entirely, but there is no reason to hide it from technically # savvy users. diff --git a/tools/packaging/build_wheel.py b/tools/packaging/build_wheel.py index 10c4516a32805..6d743e1baeac4 100644 --- a/tools/packaging/build_wheel.py +++ b/tools/packaging/build_wheel.py @@ -4,7 +4,10 @@ import contextlib import logging import os +<<<<<<< HEAD import re +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import subprocess import sys import tempfile @@ -17,12 +20,19 @@ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" ) logger = logging.getLogger(__name__) +<<<<<<< HEAD logger.setLevel(logging.INFO) +======= +logger.setLevel(logging.DEBUG) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ROOT_PATH = Path(__file__).absolute().parent.parent.parent SETUP_PY_PATH = ROOT_PATH / "setup.py" REQUIREMENTS_PATH = ROOT_PATH / "requirements.txt" +<<<<<<< HEAD PYPROJECT_TOML_PATH = ROOT_PATH / "pyproject.toml" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def run_cmd( @@ -47,6 +57,7 @@ def interpreter_version(interpreter: str) -> str: return str(version_string.split(" ")[1]) +<<<<<<< HEAD def get_supported_python_versions() -> list[str]: """Extract supported Python versions from pyproject.toml classifiers.""" with open(PYPROJECT_TOML_PATH) as f: @@ -120,6 +131,8 @@ def _find_manylinux_interpreters() -> list[str]: return interpreters +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @contextlib.contextmanager def venv(interpreter: str) -> Iterator[str]: # Should this use EnvBuilder? Probably, maybe a good todo in the future @@ -176,6 +189,7 @@ def parse_args() -> argparse.Namespace: ), ) parser.add_argument( +<<<<<<< HEAD "--find-python", type=str, choices=["manylinux"], @@ -186,6 +200,8 @@ def parse_args() -> argparse.Namespace: ), ) parser.add_argument( +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "-d", "--destination", default="dist/", @@ -197,6 +213,7 @@ def parse_args() -> argparse.Namespace: def main() -> None: args = parse_args() +<<<<<<< HEAD if args.find_python: if args.python: @@ -217,6 +234,9 @@ def main() -> None: else: pythons = args.python or [sys.executable] +======= + pythons = args.python or [sys.executable] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) build_times: dict[str, float] = dict() if len(pythons) > 1 and args.destination == "dist/": diff --git a/tools/packaging/split_wheel.py b/tools/packaging/split_wheel.py new file mode 100644 index 0000000000000..fd52c39a22b02 --- /dev/null +++ b/tools/packaging/split_wheel.py @@ -0,0 +1,109 @@ +"""Script to build split pytorch wheels + +What is split build / why is it important? + > Split build is splitting the PyTorch build into a libtorch & + > PyTorch python frontend package. This allows us to to publish + > both as separate packages and opens up our ability to have users + > install different libtorch backends per their PyTorch frontend + > + > Example: opening up the door to things like: + > pip install torch[cuda] + > pip install torch[rocm] + > pip install torch[cpu] + > etc. + +Why does this exist? + > Currently our split build requires you to invoke setup.py twice + > Which ends up complicating the build process and adds some level + > of complexity to our setup.py / build invocation for split builds. + > Ideally this script will eventually not be needed but for + > development purposes we should have an easy way to invoke this script +""" + +import argparse +import logging +import os +import subprocess +import sys +from pathlib import Path +from typing import Optional + + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + + +# NOTE: This will need to be updated if this script is ever moved +ROOT_PATH = Path(__file__).absolute().parents[2] +SETUP_PY_PATH = ROOT_PATH / "setup.py" + + +def requirements_installed() -> bool: + try: + import setuptools # type: ignore[import-untyped] # noqa: F401 + + return True + except ImportError: + logger.error( + "Requirements not installed, run the following command to install:" + ) + logger.error( + " > %s -m pip install -r %s/requirements.txt", sys.executable, ROOT_PATH + ) + return False + + +def setup_py(cmd_args: list[str], extra_env: Optional[dict[str, str]] = None) -> None: + if extra_env is None: + extra_env = {} + cmd = [sys.executable, str(SETUP_PY_PATH), *cmd_args] + logger.debug("+ %s", " ".join(cmd)) + subprocess.run( + cmd, + # Give the parent environment to the subprocess + env={**os.environ, **extra_env}, + check=True, + ) + + +def split_build(cmd: str) -> None: + logger.info("Running %s for libtorch wheel", cmd) + setup_py( + [cmd], + extra_env={"BUILD_LIBTORCH_WHL": "1", "BUILD_PYTHON_ONLY": "0"}, + ) + logger.info("Running %s for torch wheel", cmd) + # NOTE: Passing CMAKE_FRESH=1 is necessary here since the torch frontend has it's + # own cmake files that it needs to generate + setup_py( + [cmd], + extra_env={ + "BUILD_LIBTORCH_WHL": "0", + "BUILD_PYTHON_ONLY": "1", + "CMAKE_FRESH": "1", + }, + ) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + command_subparser = parser.add_subparsers(dest="command") + # Ideally these should mirror setuptools commands if we need support here for that + command_subparser.add_parser("install") + command_subparser.add_parser("bdist_wheel") + command_subparser.add_parser("develop") + return parser.parse_args() + + +def main() -> None: + args = parse_args() + if not requirements_installed(): + sys.exit(1) + split_build(args.command) + + +if __name__ == "__main__": + main() diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index 0dc1e8de37d8c..d443f4e781571 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -422,6 +422,7 @@ def gen_nn_functional(fm: FileManager) -> None: "Tensor", ) ], +<<<<<<< HEAD f"max_pool{d}d_with_indices": [ defs( f"max_pool{d}d_with_indices", @@ -435,6 +436,8 @@ def gen_nn_functional(fm: FileManager) -> None: "tuple[Tensor, Tensor]", ) ], +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } ) @@ -564,6 +567,7 @@ def gen_nn_functional(fm: FileManager) -> None: "Tensor", ) ], +<<<<<<< HEAD "elu": [ defs( "elu", @@ -663,6 +667,8 @@ def gen_nn_functional(fm: FileManager) -> None: "Tensor", ) ], +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } ) @@ -1024,6 +1030,7 @@ def gen_pyi( "None", ) ], +<<<<<<< HEAD "_functionalize_mutation_counter": [ defs( "_functionalize_mutation_counter", @@ -1045,6 +1052,8 @@ def gen_pyi( "_int", ) ], +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "_functionalize_are_all_mutations_hidden_from_autograd": [ defs( "_functionalize_are_all_mutations_hidden_from_autograd", @@ -1070,8 +1079,13 @@ def gen_pyi( "_functionalize_was_storage_changed": [ defs("_functionalize_was_storage_changed", ["tensor: Tensor"], "_bool") ], +<<<<<<< HEAD "_functionalize_mark_storage_changed": [ "def _functionalize_mark_storage_changed(tensor: Tensor) -> _bool: ..." +======= + "_functionalize_set_storage_changed": [ + "def _functionalize_set_storage_changed(tensor: Tensor) -> _bool: ..." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ], "_functionalize_has_metadata_mutation": [ defs( @@ -1430,6 +1444,7 @@ def replace_special_case(hint: str) -> str: "S", ) ], +<<<<<<< HEAD "_make_dtensor": [ "@staticmethod\n" + defs( @@ -1444,6 +1459,8 @@ def replace_special_case(hint: str) -> str: "S", ) ], +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "__contains__": [defs("__contains__", ["self", "item: Any", "/"], "_bool")], "__getitem__": [defs("__getitem__", ["self", INDICES, "/"], "Tensor")], "__setitem__": [ @@ -1863,10 +1880,17 @@ def replace_special_case(hint: str) -> str: # Include only the functions that contain hints, to prevent undefined # symbols to be included in the `__all__` directive. +<<<<<<< HEAD hinted_function_names = { name for name, hint in unsorted_function_hints.items() if hint } all_symbols = sorted(hinted_function_names.union(structseqs)) +======= + hinted_function_names = [ + name for name, hint in unsorted_function_hints.items() if hint + ] + all_symbols = sorted(list(structseqs) + hinted_function_names) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) all_directive = [ "__all__ = [", *(f' "{name}",' for name in all_symbols), @@ -1952,6 +1976,7 @@ def main() -> None: default=".", help="path to output directory", ) +<<<<<<< HEAD parser.add_argument( "--template-dir", default=".", @@ -1961,6 +1986,10 @@ def main() -> None: fm = FileManager( install_dir=args.out, template_dir=args.template_dir, dry_run=False ) +======= + args = parser.parse_args() + fm = FileManager(install_dir=args.out, template_dir=".", dry_run=False) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) gen_pyi( args.native_functions_path, args.tags_path, diff --git a/tools/setup_helpers/__init__.py b/tools/setup_helpers/__init__.py index e227fd2ac0d95..9f84a23bc9ca0 100644 --- a/tools/setup_helpers/__init__.py +++ b/tools/setup_helpers/__init__.py @@ -2,6 +2,7 @@ import os import sys +<<<<<<< HEAD import warnings @@ -13,6 +14,11 @@ def which(thefile: str) -> str | None: stacklevel=2, ) +======= + + +def which(thefile: str) -> str | None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) path = os.environ.get("PATH", os.defpath).split(os.pathsep) for d in path: fname = os.path.join(d, thefile) diff --git a/tools/setup_helpers/cmake.py b/tools/setup_helpers/cmake.py index 02ab011dd482d..4b84d6324725e 100644 --- a/tools/setup_helpers/cmake.py +++ b/tools/setup_helpers/cmake.py @@ -1,3 +1,4 @@ +<<<<<<< HEAD """Manages CMake.""" from __future__ import annotations @@ -34,6 +35,25 @@ from distutils.version import ( # type: ignore[assignment,no-redef] LooseVersion as Version, ) +======= +"Manages CMake." + +from __future__ import annotations + +import multiprocessing +import os +import platform +import sys +import sysconfig +from distutils.version import LooseVersion +from pathlib import Path +from subprocess import CalledProcessError, check_call, check_output +from typing import Any, cast + +from . import which +from .cmake_utils import CMakeValue, get_cmake_cache_variables_from_file +from .env import BUILD_DIR, check_negative_env_flag, IS_64BIT, IS_DARWIN, IS_WINDOWS +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _mkdir_p(d: str) -> None: @@ -45,6 +65,7 @@ def _mkdir_p(d: str) -> None: ) from e +<<<<<<< HEAD # Print to stderr eprint = functools.partial(print, file=sys.stderr, flush=True) @@ -53,13 +74,22 @@ def _mkdir_p(d: str) -> None: # Use ninja if it is on the PATH. Previous version of PyTorch required the # ninja python package, but we no longer use it, so we do not have to import it USE_NINJA = bool(not check_negative_env_flag("USE_NINJA") and shutil.which("ninja")) +======= +# Ninja +# Use ninja if it is on the PATH. Previous version of PyTorch required the +# ninja python package, but we no longer use it, so we do not have to import it +USE_NINJA = not check_negative_env_flag("USE_NINJA") and which("ninja") is not None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if "CMAKE_GENERATOR" in os.environ: USE_NINJA = os.environ["CMAKE_GENERATOR"].lower() == "ninja" +<<<<<<< HEAD CMAKE_MINIMUM_VERSION = Version(CMAKE_MINIMUM_VERSION_STRING) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class CMake: "Manages cmake." @@ -76,6 +106,7 @@ def _cmake_cache_file(self) -> str: """ return os.path.join(self.build_dir, "CMakeCache.txt") +<<<<<<< HEAD @property def _ninja_build_file(self) -> str: r"""Returns the path to build.ninja. @@ -140,6 +171,55 @@ def run(self, args: list[str], env: dict[str, str]) -> None: command = [self._cmake_command] + args eprint(" ".join(command)) +======= + @staticmethod + def _get_cmake_command() -> str: + "Returns cmake command." + + cmake_command = "cmake" + if IS_WINDOWS: + return cmake_command + cmake3_version = CMake._get_version(which("cmake3")) + cmake_version = CMake._get_version(which("cmake")) + + _cmake_min_version = LooseVersion("3.27.0") + if all( + ver is None or ver < _cmake_min_version + for ver in [cmake_version, cmake3_version] + ): + raise RuntimeError( + "no cmake or cmake3 with version >= 3.27.0 found:" + + str([cmake_version, cmake3_version]) + ) + + if cmake3_version is None: + cmake_command = "cmake" + elif cmake_version is None: + cmake_command = "cmake3" + else: + if cmake3_version >= cmake_version: + cmake_command = "cmake3" + else: + cmake_command = "cmake" + return cmake_command + + @staticmethod + def _get_version(cmd: str | None) -> Any: + "Returns cmake version." + + if cmd is None: + return None + for line in check_output([cmd, "--version"]).decode("utf-8").split("\n"): + if "version" in line: + return LooseVersion(line.strip().split(" ")[2]) + raise RuntimeError("no version found") + + def run(self, args: list[str], env: dict[str, str]) -> None: + "Executes cmake with arguments and an environment." + + command = [self._cmake_command] + args + print(" ".join(command)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: check_call(command, cwd=self.build_dir, env=env) except (CalledProcessError, KeyboardInterrupt): @@ -150,7 +230,11 @@ def run(self, args: list[str], env: dict[str, str]) -> None: @staticmethod def defines(args: list[str], **kwargs: CMakeValue) -> None: +<<<<<<< HEAD """Adds definitions to a cmake argument list.""" +======= + "Adds definitions to a cmake argument list." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for key, value in sorted(kwargs.items()): if value is not None: args.append(f"-D{key}={value}") @@ -172,11 +256,16 @@ def generate( my_env: dict[str, str], rerun: bool, ) -> None: +<<<<<<< HEAD """Runs cmake to generate native build files.""" +======= + "Runs cmake to generate native build files." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if rerun and os.path.isfile(self._cmake_cache_file): os.remove(self._cmake_cache_file) +<<<<<<< HEAD cmake_cache_file_available = os.path.exists(self._cmake_cache_file) if cmake_cache_file_available: cmake_cache_variables = self.get_cmake_cache_variables() @@ -197,6 +286,11 @@ def generate( if cmake_cache_file_available and ( not USE_NINJA or os.path.exists(self._ninja_build_file) +======= + ninja_build_file = os.path.join(self.build_dir, "build.ninja") + if os.path.exists(self._cmake_cache_file) and not ( + USE_NINJA and not os.path.exists(ninja_build_file) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): # Everything's in place. Do not rerun. return @@ -210,9 +304,15 @@ def generate( generator = os.getenv("CMAKE_GENERATOR", "Visual Studio 16 2019") supported = ["Visual Studio 16 2019", "Visual Studio 17 2022"] if generator not in supported: +<<<<<<< HEAD eprint("Unsupported `CMAKE_GENERATOR`: " + generator) eprint("Please set it to one of the following values: ") eprint("\n".join(supported)) +======= + print("Unsupported `CMAKE_GENERATOR`: " + generator) + print("Please set it to one of the following values: ") + print("\n".join(supported)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sys.exit(1) args.append("-G" + generator) toolset_dict = {} @@ -221,7 +321,11 @@ def generate( toolset_dict["version"] = toolset_version curr_toolset = os.getenv("VCToolsVersion") if curr_toolset is None: +<<<<<<< HEAD eprint( +======= + print( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "When you specify `CMAKE_GENERATOR_TOOLSET_VERSION`, you must also " "activate the vs environment of this version. Please read the notes " "in the build steps carefully." @@ -338,7 +442,10 @@ def generate( # future, as CMake can detect many of these libraries pretty comfortably. We have them here for now before CMake # integration is completed. They appear here not in the CMake.defines call below because they start with either # "BUILD_" or "USE_" and must be overwritten here. +<<<<<<< HEAD use_numpy = not check_negative_env_flag("USE_NUMPY") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) build_options.update( { # Note: Do not add new build options to this dict if it is directly read from environment variable -- you @@ -348,15 +455,25 @@ def generate( "BUILD_TEST": build_test, # Most library detection should go to CMake script, except this one, which Python can do a much better job # due to NumPy's inherent Pythonic nature. +<<<<<<< HEAD "USE_NUMPY": use_numpy, +======= + "USE_NUMPY": not check_negative_env_flag("USE_NUMPY"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } ) # Detect build dependencies from python lib path (in order to set *_HOME variables) # NVSHMEM +<<<<<<< HEAD nvshmem_py_dir = py_lib_path + "/nvidia/nvshmem" if os.path.exists(nvshmem_py_dir): build_options["NVSHMEM_PY_DIR"] = nvshmem_py_dir +======= + nvshmem_home = py_lib_path + "/nvidia/nvshmem" + if os.path.exists(nvshmem_home): + build_options["NVSHMEM_HOME"] = nvshmem_home +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Options starting with CMAKE_ cmake__options = { @@ -367,13 +484,18 @@ def generate( # error if the user also attempts to set these CMAKE options directly. specified_cmake__options = set(build_options).intersection(cmake__options) if len(specified_cmake__options) > 0: +<<<<<<< HEAD eprint( +======= + print( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ", ".join(specified_cmake__options) + " should not be specified in the environment variable. They are directly set by PyTorch build script." ) sys.exit(1) build_options.update(cmake__options) +<<<<<<< HEAD if use_numpy: try: # This helps CMake find the correct include directory for NumPy @@ -388,6 +510,8 @@ def generate( # use_numpy is just a hint.... so we can fail silently here pass +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CMake.defines( args, Python_EXECUTABLE=sys.executable, @@ -410,8 +534,16 @@ def generate( my_env[env_var_name] = str(my_env[env_var_name].encode("utf-8")) except UnicodeDecodeError as e: shex = ":".join(f"{ord(c):02x}" for c in my_env[env_var_name]) +<<<<<<< HEAD eprint(f"Invalid ENV[{env_var_name}] = {shex}") eprint(e) +======= + print( + f"Invalid ENV[{env_var_name}] = {shex}", + file=sys.stderr, + ) + print(e, file=sys.stderr) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # According to the CMake manual, we should pass the arguments first, # and put the directory as the last element. Otherwise, these flags # may not be passed correctly. @@ -422,7 +554,11 @@ def generate( self.run(args, env=my_env) def build(self, my_env: dict[str, str]) -> None: +<<<<<<< HEAD """Runs cmake to build binaries.""" +======= + "Runs cmake to build binaries." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .env import build_type @@ -460,6 +596,7 @@ def build(self, my_env: dict[str, str]) -> None: # CMake 3.12 provides a '-j' option. build_args += ["-j", max_jobs] self.run(build_args, my_env) +<<<<<<< HEAD def clear_cache(self) -> None: """Clears the CMake cache.""" @@ -467,3 +604,5 @@ def clear_cache(self) -> None: os.remove(self._cmake_cache_file) if os.path.isfile(self._ninja_build_file): os.remove(self._ninja_build_file) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/tools/setup_helpers/env.py b/tools/setup_helpers/env.py index 3eb23af44a231..e0165e37c4add 100644 --- a/tools/setup_helpers/env.py +++ b/tools/setup_helpers/env.py @@ -11,8 +11,11 @@ from collections.abc import Iterable +<<<<<<< HEAD CMAKE_MINIMUM_VERSION_STRING = "3.27" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) IS_WINDOWS = platform.system() == "Windows" IS_DARWIN = platform.system() == "Darwin" IS_LINUX = platform.system() == "Linux" diff --git a/tools/setup_helpers/generate_code.py b/tools/setup_helpers/generate_code.py index e53efd7288c1f..4ba8f5e145ce7 100644 --- a/tools/setup_helpers/generate_code.py +++ b/tools/setup_helpers/generate_code.py @@ -189,12 +189,15 @@ def main() -> None: ) options = parser.parse_args() +<<<<<<< HEAD # Path: aten/src/ATen aten_path = os.path.dirname(os.path.dirname(options.native_functions_path)) operator_selector = get_selector( options.selected_op_list_path, options.operators_yaml_path ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) generate_code( options.gen_dir, options.native_functions_path, @@ -204,6 +207,7 @@ def main() -> None: options.disable_autograd, options.force_schema_registration, # options.selected_op_list +<<<<<<< HEAD operator_selector=operator_selector, ) @@ -235,6 +239,20 @@ def main() -> None: ts_native_functions = "torch/csrc/lazy/ts_backend/ts_native_functions.cpp" ts_node_base = "torch/csrc/lazy/ts_backend/ts_node.h" lazy_install_dir = os.path.join(install_dir, "lazy", "generated") +======= + operator_selector=get_selector( + options.selected_op_list_path, options.operators_yaml_path + ), + ) + + if options.gen_lazy_ts_backend: + aten_path = os.path.dirname(os.path.dirname(options.native_functions_path)) + ts_backend_yaml = os.path.join(aten_path, "native/ts_native_functions.yaml") + ts_native_functions = "torch/csrc/lazy/ts_backend/ts_native_functions.cpp" + ts_node_base = "torch/csrc/lazy/ts_backend/ts_node.h" + install_dir = options.install_dir or os.fspath(options.gen_dir / "torch/csrc") + lazy_install_dir = os.path.join(install_dir, "lazy/generated") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) os.makedirs(lazy_install_dir, exist_ok=True) assert os.path.isfile(ts_backend_yaml), ( diff --git a/tools/stats/import_test_stats.py b/tools/stats/import_test_stats.py index 0d39eb8203f75..75251238c7887 100644 --- a/tools/stats/import_test_stats.py +++ b/tools/stats/import_test_stats.py @@ -108,7 +108,11 @@ def process_disabled_test(the_response: dict[str, Any]) -> dict[str, Any]: return disabled_test_from_issues try: +<<<<<<< HEAD url = "https://ossci-metrics.s3.amazonaws.com/disabled-tests-condensed.json?versionId=UsscdNP.2GMOzUxAvqIx8GAj4MuhX1Xi" +======= + url = "https://ossci-metrics.s3.amazonaws.com/disabled-tests-condensed.json?versionId=oZfFXdfoa7trdcAiH1aL91T9jUDckwlX" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return fetch_and_cache(dirpath, filename, url, process_disabled_test) except Exception: print("Couldn't download test skip set, leaving all tests enabled...") diff --git a/tools/stats/monitor.py b/tools/stats/monitor.py index 38d1f94b178b2..82b8857755a2c 100644 --- a/tools/stats/monitor.py +++ b/tools/stats/monitor.py @@ -78,9 +78,12 @@ class GpuData: uuid: str utilization: float mem_utilization: float +<<<<<<< HEAD allocated_mem: float allocated_mem_value: float total_mem_value: float +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: @@ -262,7 +265,10 @@ def _generate_stats(self, data_list: list[float]) -> UtilizationStats: return UtilizationStats( avg=round(avg, 2), max=round(maxi, 2), +<<<<<<< HEAD raw=data_list, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def _output_data(self) -> None: @@ -342,33 +348,45 @@ def _calculate_gpu_utilization(self, data_list: list[UsageData]) -> list[GpuUsag calculate_gpu = [] gpu_mem_utilization = defaultdict(list) gpu_utilization = defaultdict(list) +<<<<<<< HEAD gpu_allocated_mem = defaultdict(list) gpu_allocated_mem_values = defaultdict(list) gpu_total_mem_values = defaultdict(float) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for data in data_list: for gpu in data.gpu_list: gpu_mem_utilization[gpu.uuid].append(gpu.mem_utilization) gpu_utilization[gpu.uuid].append(gpu.utilization) +<<<<<<< HEAD gpu_allocated_mem[gpu.uuid].append(gpu.allocated_mem) gpu_allocated_mem_values[gpu.uuid].append(gpu.allocated_mem_value) gpu_total_mem_values[gpu.uuid] = gpu.total_mem_value +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for gpu_uuid in gpu_utilization.keys(): gpu_util_stats = self._generate_stats(gpu_utilization[gpu_uuid]) gpu_mem_util_stats = self._generate_stats(gpu_mem_utilization[gpu_uuid]) +<<<<<<< HEAD gpu_allocated_mem_stats = self._generate_stats(gpu_allocated_mem[gpu_uuid]) gpu_allocated_mem_value_stats = self._generate_stats( gpu_allocated_mem_values[gpu_uuid] ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) calculate_gpu.append( GpuUsage( uuid=gpu_uuid, util_percent=gpu_util_stats, mem_util_percent=gpu_mem_util_stats, +<<<<<<< HEAD allocated_mem_percent=gpu_allocated_mem_stats, allocated_mem_value=gpu_allocated_mem_value_stats, total_mem_value=gpu_total_mem_values[gpu_uuid], +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) ) return calculate_gpu @@ -399,6 +417,7 @@ def _collect_gpu_data(self) -> list[GpuData]: # see https://docs.nvidia.com/deploy/nvml-api/group__nvmlDeviceQueries.html gpu_utilization = pynvml.nvmlDeviceGetUtilizationRates(gpu_handle) gpu_uuid = pynvml.nvmlDeviceGetUUID(gpu_handle) +<<<<<<< HEAD gpu_memory_info = pynvml.nvmlDeviceGetMemoryInfo(gpu_handle) mem_utilization = gpu_utilization.memory @@ -406,14 +425,20 @@ def _collect_gpu_data(self) -> list[GpuData]: total_mem_MB = gpu_memory_info.total / 1024**2 allocate_mem_percent = allocate_mem_MB / total_mem_MB * 100 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) gpu_data_list.append( GpuData( uuid=gpu_uuid, utilization=gpu_utilization.gpu, +<<<<<<< HEAD mem_utilization=mem_utilization, allocated_mem=allocate_mem_percent, allocated_mem_value=allocate_mem_MB, total_mem_value=total_mem_MB, +======= + mem_utilization=gpu_utilization.memory, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) ) elif self._has_amdsmi: @@ -424,20 +449,26 @@ def _collect_gpu_data(self) -> list[GpuData]: gpu_uuid = amdsmi.amdsmi_get_gpu_device_uuid(handle) gpu_utilization = engine_usage["gfx_activity"] gpu_mem_utilization = gpu_utilization["umc_activity"] +<<<<<<< HEAD mem_info = amdsmi.amdsmi_get_gpu_memory_usage(handle) allocate_mem_MB = mem_info["vram_usage"] / 1024**2 total_mem_MB = mem_info["vram_total"] / 1024**2 allocate_mem_percent = allocate_mem_MB / total_mem_MB * 100 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) gpu_data_list.append( GpuData( uuid=gpu_uuid, utilization=gpu_utilization, mem_utilization=gpu_mem_utilization, +<<<<<<< HEAD allocated_mem=allocate_mem_percent, allocated_mem_value=allocate_mem_MB, total_mem_value=total_mem_MB, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) ) return gpu_data_list @@ -535,9 +566,13 @@ def get_processes_running_python_tests() -> list[Any]: cmd = " ".join(process.cmdline()) processName = process.name() pid = process.pid +<<<<<<< HEAD is_python = "python" in processName and "python" in cmd is_pytest = "pytest" in cmd if is_python or is_pytest: +======= + if "python" in processName and cmd.startswith("python"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) python_test_processes.append({"pid": pid, "cmd": cmd}) except Exception: pass diff --git a/tools/stats/utilization_stats_lib.py b/tools/stats/utilization_stats_lib.py index 33551fd55de5f..9e1b2c3dd4324 100644 --- a/tools/stats/utilization_stats_lib.py +++ b/tools/stats/utilization_stats_lib.py @@ -5,7 +5,11 @@ from dataclasses_json import DataClassJsonMixin +<<<<<<< HEAD _DATA_MODEL_VERSION = 1.5 +======= +_DATA_MODEL_VERSION = 1.0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # data model for test log usage @@ -13,7 +17,10 @@ class UtilizationStats: avg: Optional[float] = None max: Optional[float] = None +<<<<<<< HEAD raw: Optional[list[float]] = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclass @@ -37,9 +44,12 @@ class GpuUsage(DataClassJsonMixin): uuid: Optional[str] = None util_percent: Optional[UtilizationStats] = None mem_util_percent: Optional[UtilizationStats] = None +<<<<<<< HEAD allocated_mem_percent: Optional[UtilizationStats] = None allocated_mem_value: Optional[UtilizationStats] = None total_mem_value: Optional[float] = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclass diff --git a/tools/test/docstring_linter_testdata/more_python_code.py.txt.after.json b/tools/test/docstring_linter_testdata/more_python_code.py.txt.after.json index a62e93ecc2615..87bcdc9754981 100644 --- a/tools/test/docstring_linter_testdata/more_python_code.py.txt.after.json +++ b/tools/test/docstring_linter_testdata/more_python_code.py.txt.after.json @@ -1,5 +1,6 @@ { "tools/test/docstring_linter_testdata/more_python_code.py.txt": { +<<<<<<< HEAD " 1": "def a_very_very_long(): lines=8, docs=0: (grandfathered)", "10": "class LintInit: lines=6, docs=0: (grandfathered)" }, @@ -8,5 +9,13 @@ "24": "class LongWithShortDocstring: lines=6, docs=10: (grandfathered)", "54": "def long_without_docstring(): lines=7, docs=0: (grandfathered)", "71": "def ImpossibleCombo.needs_docs(): lines=12, docs=0: (grandfathered)" +======= + "11": "class LintInit: lines=6, docs=0: (grandfathered)" + }, + "tools/test/docstring_linter_testdata/python_code.py.txt": { + "20": "class LongWithoutDocstring: lines=4, docs=0: (grandfathered)", + "25": "class LongWithShortDocstring: lines=6, docs=10: (grandfathered)", + "72": "def ImpossibleCombo.needs_docs(): lines=12, docs=0: (grandfathered)" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } diff --git a/tools/test/docstring_linter_testdata/more_python_code.py.txt.before.json b/tools/test/docstring_linter_testdata/more_python_code.py.txt.before.json index f6f71e0a45d6a..f06e9ffc33500 100644 --- a/tools/test/docstring_linter_testdata/more_python_code.py.txt.before.json +++ b/tools/test/docstring_linter_testdata/more_python_code.py.txt.before.json @@ -1,5 +1,6 @@ { "tools/test/docstring_linter_testdata/more_python_code.py.txt": { +<<<<<<< HEAD " 1": "def a_very_very_long(): lines=8, docs=0: FAIL", "10": "class LintInit: lines=6, docs=0: FAIL" }, @@ -8,5 +9,13 @@ "24": "class LongWithShortDocstring: lines=6, docs=10: FAIL", "54": "def long_without_docstring(): lines=7, docs=0: FAIL", "71": "def ImpossibleCombo.needs_docs(): lines=12, docs=0: FAIL" +======= + "11": "class LintInit: lines=6, docs=0: FAIL" + }, + "tools/test/docstring_linter_testdata/python_code.py.txt": { + "20": "class LongWithoutDocstring: lines=4, docs=0: FAIL", + "25": "class LongWithShortDocstring: lines=6, docs=10: FAIL", + "72": "def ImpossibleCombo.needs_docs(): lines=12, docs=0: FAIL" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } diff --git a/tools/test/docstring_linter_testdata/more_python_code.py.txt.before.txt b/tools/test/docstring_linter_testdata/more_python_code.py.txt.before.txt index de8cf370f7cc4..4f9c52159a3a0 100644 --- a/tools/test/docstring_linter_testdata/more_python_code.py.txt.before.txt +++ b/tools/test/docstring_linter_testdata/more_python_code.py.txt.before.txt @@ -1,3 +1,4 @@ +<<<<<<< HEAD tools/test/docstring_linter_testdata/more_python_code.py.txt:1: No docstring found for function 'a_very_very_long' (8 lines) 1 | def a_very_very_long(): ^ @@ -6,6 +7,8 @@ tools/test/docstring_linter_testdata/more_python_code.py.txt:1: No docstring fou 4 | # Lots of lines! 5 | # Lots of lines! +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tools/test/docstring_linter_testdata/more_python_code.py.txt:10: No docstring found for class 'LintInit' (6 lines) 8 | 9 | @@ -14,7 +17,11 @@ tools/test/docstring_linter_testdata/more_python_code.py.txt:10: No docstring fo 11 | def __init__(self) -> None: 12 | # Lots of lines! +<<<<<<< HEAD tools/test/docstring_linter_testdata/python_code.py.txt:17: No docstring found for class 'LongWithoutDocstring' (6 lines) +======= +tools/test/docstring_linter_testdata/python_code.py.txt:17: No docstring found for class 'LongWithoutDocstring' (4 lines) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) 15 | 16 | 17 | class LongWithoutDocstring: @@ -30,6 +37,7 @@ tools/test/docstring_linter_testdata/python_code.py.txt:24: docstring found for 25 | """TODO""" 26 | +<<<<<<< HEAD tools/test/docstring_linter_testdata/python_code.py.txt:54: No docstring found for function 'long_without_docstring' (7 lines) 52 | 53 | @@ -38,6 +46,8 @@ tools/test/docstring_linter_testdata/python_code.py.txt:54: No docstring found f 55 | # 56 | # +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tools/test/docstring_linter_testdata/python_code.py.txt:71: No docstring found for function 'needs_docs' (12 lines). If the method overrides a method on a parent class, adding the `@typing_extensions.override` decorator will make this error go away. 69 | """This docstring, while short, is enough""" 70 | diff --git a/tools/test/docstring_linter_testdata/more_python_code.py.txt.grandfather.json b/tools/test/docstring_linter_testdata/more_python_code.py.txt.grandfather.json index 1c4c8b6963a31..a343d65bd2944 100644 --- a/tools/test/docstring_linter_testdata/more_python_code.py.txt.grandfather.json +++ b/tools/test/docstring_linter_testdata/more_python_code.py.txt.grandfather.json @@ -1,5 +1,6 @@ { "tools/test/docstring_linter_testdata/more_python_code.py.txt": { +<<<<<<< HEAD "class LintInit": 6, "def a_very_very_long()": 8 }, @@ -8,5 +9,13 @@ "class LongWithoutDocstring": 6, "def ImpossibleCombo.needs_docs()": 12, "def long_without_docstring()": 7 +======= + "class LintInit": 6 + }, + "tools/test/docstring_linter_testdata/python_code.py.txt": { + "class LongWithShortDocstring": 6, + "class LongWithoutDocstring": 4, + "def ImpossibleCombo.needs_docs()": 12 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } \ No newline at end of file diff --git a/tools/test/docstring_linter_testdata/python_code.py.txt.json b/tools/test/docstring_linter_testdata/python_code.py.txt.json index b95486e7ff563..148ad9737301c 100644 --- a/tools/test/docstring_linter_testdata/python_code.py.txt.json +++ b/tools/test/docstring_linter_testdata/python_code.py.txt.json @@ -4,7 +4,11 @@ "code": "DOCSTRING_LINTER", "description": null, "line": 17, +<<<<<<< HEAD "name": "No docstring found for class 'LongWithoutDocstring' (6 lines)", +======= + "name": "No docstring found for class 'LongWithoutDocstring' (4 lines)", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "original": null, "path": "tools/test/docstring_linter_testdata/python_code.py.txt", "replacement": null, @@ -22,6 +26,7 @@ "severity": "error" }, { +<<<<<<< HEAD "char": 0, "code": "DOCSTRING_LINTER", "description": null, @@ -33,6 +38,8 @@ "severity": "error" }, { +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "char": 4, "code": "DOCSTRING_LINTER", "description": null, diff --git a/tools/test/docstring_linter_testdata/python_code.py.txt.lintrunner b/tools/test/docstring_linter_testdata/python_code.py.txt.lintrunner index 2db9a576291a0..3d4cb1ddfde4d 100644 --- a/tools/test/docstring_linter_testdata/python_code.py.txt.lintrunner +++ b/tools/test/docstring_linter_testdata/python_code.py.txt.lintrunner @@ -1,4 +1,8 @@ +<<<<<<< HEAD tools/test/docstring_linter_testdata/python_code.py.txt:17: No docstring found for class 'LongWithoutDocstring' (6 lines) +======= +tools/test/docstring_linter_testdata/python_code.py.txt:17: No docstring found for class 'LongWithoutDocstring' (4 lines) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) 15 | 16 | 17 | class LongWithoutDocstring: @@ -14,6 +18,7 @@ tools/test/docstring_linter_testdata/python_code.py.txt:24: docstring found for 25 | """TODO""" 26 | +<<<<<<< HEAD tools/test/docstring_linter_testdata/python_code.py.txt:54: No docstring found for function 'long_without_docstring' (7 lines) 52 | 53 | @@ -22,6 +27,8 @@ tools/test/docstring_linter_testdata/python_code.py.txt:54: No docstring found f 55 | # 56 | # +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tools/test/docstring_linter_testdata/python_code.py.txt:71: No docstring found for function 'needs_docs' (12 lines). If the method overrides a method on a parent class, adding the `@typing_extensions.override` decorator will make this error go away. 69 | """This docstring, while short, is enough""" 70 | diff --git a/tools/test/docstring_linter_testdata/python_code.py.txt.recursive.json b/tools/test/docstring_linter_testdata/python_code.py.txt.recursive.json index 65d46215a3c25..67ef14fefbc67 100644 --- a/tools/test/docstring_linter_testdata/python_code.py.txt.recursive.json +++ b/tools/test/docstring_linter_testdata/python_code.py.txt.recursive.json @@ -11,7 +11,11 @@ "is_method": false, "line_count": 4, "parent": null, +<<<<<<< HEAD "start_line": 1 +======= + "start_line": 2 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "class", @@ -25,7 +29,11 @@ "is_method": false, "line_count": 3, "parent": null, +<<<<<<< HEAD "start_line": 6 +======= + "start_line": 7 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "class", @@ -42,7 +50,11 @@ "is_method": true, "line_count": 3, "parent": 2, +<<<<<<< HEAD "start_line": 13 +======= + "start_line": 14 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } ], "decorators": [], @@ -54,7 +66,11 @@ "is_method": false, "line_count": 6, "parent": null, +<<<<<<< HEAD "start_line": 10 +======= + "start_line": 11 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "class", @@ -71,7 +87,11 @@ "is_method": true, "line_count": 3, "parent": 4, +<<<<<<< HEAD "start_line": 20 +======= + "start_line": 21 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } ], "decorators": [], @@ -81,9 +101,15 @@ "index": 4, "is_local": false, "is_method": false, +<<<<<<< HEAD "line_count": 6, "parent": null, "start_line": 17 +======= + "line_count": 4, + "parent": null, + "start_line": 20 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "class", @@ -100,7 +126,11 @@ "is_method": true, "line_count": 3, "parent": 6, +<<<<<<< HEAD "start_line": 27 +======= + "start_line": 28 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } ], "decorators": [], @@ -112,7 +142,11 @@ "is_method": false, "line_count": 6, "parent": null, +<<<<<<< HEAD "start_line": 24 +======= + "start_line": 25 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "class", @@ -129,7 +163,11 @@ "is_method": true, "line_count": 3, "parent": 8, +<<<<<<< HEAD "start_line": 34 +======= + "start_line": 35 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } ], "decorators": [], @@ -141,7 +179,11 @@ "is_method": false, "line_count": 6, "parent": null, +<<<<<<< HEAD "start_line": 31 +======= + "start_line": 32 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "def", @@ -153,9 +195,15 @@ "index": 10, "is_local": false, "is_method": false, +<<<<<<< HEAD "line_count": 6, "parent": null, "start_line": 38 +======= + "line_count": 3, + "parent": null, + "start_line": 42 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "def", @@ -169,7 +217,11 @@ "is_method": false, "line_count": 8, "parent": null, +<<<<<<< HEAD "start_line": 45 +======= + "start_line": 46 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "def", @@ -181,9 +233,15 @@ "index": 12, "is_local": false, "is_method": false, +<<<<<<< HEAD "line_count": 7, "parent": null, "start_line": 54 +======= + "line_count": 3, + "parent": null, + "start_line": 59 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "class", @@ -206,7 +264,11 @@ "is_method": false, "line_count": 6, "parent": 15, +<<<<<<< HEAD "start_line": 73 +======= + "start_line": 74 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "class", @@ -220,7 +282,11 @@ "is_method": false, "line_count": 3, "parent": 15, +<<<<<<< HEAD "start_line": 80 +======= + "start_line": 81 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } ], "decorators": [], @@ -232,7 +298,11 @@ "is_method": false, "line_count": 11, "parent": 14, +<<<<<<< HEAD "start_line": 72 +======= + "start_line": 73 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "class", @@ -246,7 +316,11 @@ "is_method": false, "line_count": 6, "parent": 15, +<<<<<<< HEAD "start_line": 73 +======= + "start_line": 74 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "class", @@ -260,7 +334,11 @@ "is_method": false, "line_count": 3, "parent": 15, +<<<<<<< HEAD "start_line": 80 +======= + "start_line": 81 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } ], "decorators": [], @@ -272,7 +350,11 @@ "is_method": true, "line_count": 12, "parent": 13, +<<<<<<< HEAD "start_line": 71 +======= + "start_line": 72 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "def", @@ -289,7 +371,11 @@ "is_method": false, "line_count": 6, "parent": 15, +<<<<<<< HEAD "start_line": 73 +======= + "start_line": 74 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "class", @@ -303,7 +389,11 @@ "is_method": false, "line_count": 3, "parent": 15, +<<<<<<< HEAD "start_line": 80 +======= + "start_line": 81 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } ], "decorators": [], @@ -315,7 +405,11 @@ "is_method": false, "line_count": 11, "parent": 14, +<<<<<<< HEAD "start_line": 72 +======= + "start_line": 73 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "class", @@ -329,7 +423,11 @@ "is_method": false, "line_count": 6, "parent": 15, +<<<<<<< HEAD "start_line": 73 +======= + "start_line": 74 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "class", @@ -343,7 +441,11 @@ "is_method": false, "line_count": 3, "parent": 15, +<<<<<<< HEAD "start_line": 80 +======= + "start_line": 81 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } ], "decorators": [], @@ -353,9 +455,15 @@ "index": 13, "is_local": false, "is_method": false, +<<<<<<< HEAD "line_count": 21, "parent": null, "start_line": 62 +======= + "line_count": 15, + "parent": null, + "start_line": 69 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "class", @@ -372,7 +480,11 @@ "is_method": true, "line_count": 2, "parent": 18, +<<<<<<< HEAD "start_line": 86 +======= + "start_line": 87 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "def", @@ -386,9 +498,15 @@ "index": 20, "is_local": false, "is_method": true, +<<<<<<< HEAD "line_count": 6, "parent": 18, "start_line": 92 +======= + "line_count": 2, + "parent": 18, + "start_line": 97 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "def", @@ -402,7 +520,11 @@ "is_method": true, "line_count": 2, "parent": 18, +<<<<<<< HEAD "start_line": 99 +======= + "start_line": 100 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "def", @@ -416,7 +538,11 @@ "is_method": true, "line_count": 4, "parent": 18, +<<<<<<< HEAD "start_line": 102 +======= + "start_line": 103 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } ], "decorators": [ @@ -430,7 +556,11 @@ "is_method": false, "line_count": 21, "parent": null, +<<<<<<< HEAD "start_line": 85 +======= + "start_line": 86 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "def", @@ -442,8 +572,14 @@ "index": 23, "is_local": false, "is_method": false, +<<<<<<< HEAD "line_count": 5, "parent": null, "start_line": 107 +======= + "line_count": 1, + "parent": null, + "start_line": 112 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } ] diff --git a/tools/test/docstring_linter_testdata/python_code.py.txt.recursive.terse.json b/tools/test/docstring_linter_testdata/python_code.py.txt.recursive.terse.json index dd4c90dc2710c..110ab5e93d879 100644 --- a/tools/test/docstring_linter_testdata/python_code.py.txt.recursive.terse.json +++ b/tools/test/docstring_linter_testdata/python_code.py.txt.recursive.terse.json @@ -1,17 +1,29 @@ { "class ImpossibleCombo": { "children": { +<<<<<<< HEAD "71": { "children": { "72": { "children": { "73": { +======= + "72": { + "children": { + "73": { + "children": { + "74": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 6, "name": "class ImpossibleCombo.needs_docs.not_short.Long", "status": "good" }, +<<<<<<< HEAD "80": { +======= + "81": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 3, "name": "class ImpossibleCombo.needs_docs.not_short.Short", @@ -23,13 +35,21 @@ "name": "def ImpossibleCombo.needs_docs.not_short", "status": "good" }, +<<<<<<< HEAD "73": { +======= + "74": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 6, "name": "class ImpossibleCombo.needs_docs.not_short.Long", "status": "good" }, +<<<<<<< HEAD "80": { +======= + "81": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 3, "name": "class ImpossibleCombo.needs_docs.not_short.Short", @@ -41,15 +61,25 @@ "name": "def ImpossibleCombo.needs_docs", "status": "good" }, +<<<<<<< HEAD "72": { "children": { "73": { +======= + "73": { + "children": { + "74": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 6, "name": "class ImpossibleCombo.needs_docs.not_short.Long", "status": "good" }, +<<<<<<< HEAD "80": { +======= + "81": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 3, "name": "class ImpossibleCombo.needs_docs.not_short.Short", @@ -61,13 +91,21 @@ "name": "def ImpossibleCombo.needs_docs.not_short", "status": "good" }, +<<<<<<< HEAD "73": { +======= + "74": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 6, "name": "class ImpossibleCombo.needs_docs.not_short.Long", "status": "good" }, +<<<<<<< HEAD "80": { +======= + "81": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 3, "name": "class ImpossibleCombo.needs_docs.not_short.Short", @@ -75,13 +113,22 @@ } }, "docstring_len": 44, +<<<<<<< HEAD "line": 62, "lines": 21, +======= + "line": 69, + "lines": 15, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "status": "good" }, "class LongWithDocstring": { "children": { +<<<<<<< HEAD "13": { +======= + "14": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 3, "name": "def LongWithDocstring.short1", @@ -89,13 +136,21 @@ } }, "docstring_len": 44, +<<<<<<< HEAD "line": 10, +======= + "line": 11, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "lines": 6, "status": "good" }, "class LongWithShortDocstring": { "children": { +<<<<<<< HEAD "27": { +======= + "28": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 3, "name": "def LongWithShortDocstring.short1", @@ -103,13 +158,21 @@ } }, "docstring_len": 10, +<<<<<<< HEAD "line": 24, +======= + "line": 25, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "lines": 6, "status": "good" }, "class LongWithoutDocstring": { "children": { +<<<<<<< HEAD "20": { +======= + "21": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 3, "name": "def LongWithoutDocstring.short1", @@ -117,18 +180,28 @@ } }, "docstring_len": 0, +<<<<<<< HEAD "line": 17, "lines": 6, +======= + "line": 20, + "lines": 4, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "status": "good" }, "class NotDocstring": { "children": { +<<<<<<< HEAD " 86": { +======= + " 87": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 2, "name": "def NotDocstring.short1", "status": "good" }, +<<<<<<< HEAD " 92": { "docstring_len": 0, "lines": 6, @@ -136,12 +209,25 @@ "status": "good" }, " 99": { +======= + " 97": { + "docstring_len": 0, + "lines": 2, + "name": "def NotDocstring.long_with_override", + "status": "good" + }, + "100": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 2, "name": "def NotDocstring.short2", "status": "good" }, +<<<<<<< HEAD "102": { +======= + "103": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 4, "name": "def NotDocstring.short3", @@ -149,25 +235,41 @@ } }, "docstring_len": 0, +<<<<<<< HEAD "line": 85, +======= + "line": 86, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "lines": 21, "status": "good" }, "class Short": { "docstring_len": 0, +<<<<<<< HEAD "line": 6, +======= + "line": 7, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "lines": 3, "status": "good" }, "class ShortWithDocstring": { "docstring_len": 44, +<<<<<<< HEAD "line": 1, +======= + "line": 2, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "lines": 4, "status": "good" }, "class _Protected": { "children": { +<<<<<<< HEAD "34": { +======= + "35": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 3, "name": "def _Protected.short1", @@ -175,32 +277,55 @@ } }, "docstring_len": 10, +<<<<<<< HEAD "line": 31, +======= + "line": 32, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "lines": 6, "status": "good" }, "def long": { "docstring_len": 44, +<<<<<<< HEAD "line": 45, +======= + "line": 46, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "lines": 8, "status": "good" }, "def long_with_omit": { "docstring_len": 0, +<<<<<<< HEAD "line": 107, "lines": 5, +======= + "line": 112, + "lines": 1, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "status": "good" }, "def long_without_docstring": { "docstring_len": 0, +<<<<<<< HEAD "line": 54, "lines": 7, +======= + "line": 59, + "lines": 3, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "status": "good" }, "def short": { "docstring_len": 0, +<<<<<<< HEAD "line": 38, "lines": 6, +======= + "line": 42, + "lines": 3, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "status": "good" } } diff --git a/tools/test/docstring_linter_testdata/python_code.py.txt.recursive.terse.line.json b/tools/test/docstring_linter_testdata/python_code.py.txt.recursive.terse.line.json index cadee32ab874f..8a1afea19cb86 100644 --- a/tools/test/docstring_linter_testdata/python_code.py.txt.recursive.terse.line.json +++ b/tools/test/docstring_linter_testdata/python_code.py.txt.recursive.terse.line.json @@ -1,19 +1,33 @@ { +<<<<<<< HEAD " 1": { +======= + " 2": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 44, "lines": 4, "name": "class ShortWithDocstring", "status": "good" }, +<<<<<<< HEAD " 6": { +======= + " 7": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 3, "name": "class Short", "status": "good" }, +<<<<<<< HEAD " 10": { "children": { "13": { +======= + " 11": { + "children": { + "14": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 3, "name": "def LongWithDocstring.short1", @@ -25,9 +39,15 @@ "name": "class LongWithDocstring", "status": "good" }, +<<<<<<< HEAD " 17": { "children": { "20": { +======= + " 20": { + "children": { + "21": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 3, "name": "def LongWithoutDocstring.short1", @@ -35,6 +55,7 @@ } }, "docstring_len": 0, +<<<<<<< HEAD "lines": 6, "name": "class LongWithoutDocstring", "status": "good" @@ -42,6 +63,15 @@ " 24": { "children": { "27": { +======= + "lines": 4, + "name": "class LongWithoutDocstring", + "status": "good" + }, + " 25": { + "children": { + "28": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 3, "name": "def LongWithShortDocstring.short1", @@ -53,9 +83,15 @@ "name": "class LongWithShortDocstring", "status": "good" }, +<<<<<<< HEAD " 31": { "children": { "34": { +======= + " 32": { + "children": { + "35": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 3, "name": "def _Protected.short1", @@ -67,6 +103,7 @@ "name": "class _Protected", "status": "good" }, +<<<<<<< HEAD " 38": { "docstring_len": 0, "lines": 6, @@ -74,11 +111,21 @@ "status": "good" }, " 45": { +======= + " 42": { + "docstring_len": 0, + "lines": 3, + "name": "def short", + "status": "good" + }, + " 46": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 44, "lines": 8, "name": "def long", "status": "good" }, +<<<<<<< HEAD " 54": { "docstring_len": 0, "lines": 7, @@ -92,12 +139,31 @@ "72": { "children": { "73": { +======= + " 59": { + "docstring_len": 0, + "lines": 3, + "name": "def long_without_docstring", + "status": "good" + }, + " 69": { + "children": { + "72": { + "children": { + "73": { + "children": { + "74": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 6, "name": "class ImpossibleCombo.needs_docs.not_short.Long", "status": "good" }, +<<<<<<< HEAD "80": { +======= + "81": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 3, "name": "class ImpossibleCombo.needs_docs.not_short.Short", @@ -109,13 +175,21 @@ "name": "def ImpossibleCombo.needs_docs.not_short", "status": "good" }, +<<<<<<< HEAD "73": { +======= + "74": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 6, "name": "class ImpossibleCombo.needs_docs.not_short.Long", "status": "good" }, +<<<<<<< HEAD "80": { +======= + "81": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 3, "name": "class ImpossibleCombo.needs_docs.not_short.Short", @@ -127,15 +201,25 @@ "name": "def ImpossibleCombo.needs_docs", "status": "good" }, +<<<<<<< HEAD "72": { "children": { "73": { +======= + "73": { + "children": { + "74": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 6, "name": "class ImpossibleCombo.needs_docs.not_short.Long", "status": "good" }, +<<<<<<< HEAD "80": { +======= + "81": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 3, "name": "class ImpossibleCombo.needs_docs.not_short.Short", @@ -147,13 +231,21 @@ "name": "def ImpossibleCombo.needs_docs.not_short", "status": "good" }, +<<<<<<< HEAD "73": { +======= + "74": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 6, "name": "class ImpossibleCombo.needs_docs.not_short.Long", "status": "good" }, +<<<<<<< HEAD "80": { +======= + "81": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 3, "name": "class ImpossibleCombo.needs_docs.not_short.Short", @@ -161,6 +253,7 @@ } }, "docstring_len": 44, +<<<<<<< HEAD "lines": 21, "name": "class ImpossibleCombo", "status": "good" @@ -168,11 +261,21 @@ " 85": { "children": { " 86": { +======= + "lines": 15, + "name": "class ImpossibleCombo", + "status": "good" + }, + " 86": { + "children": { + " 87": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 2, "name": "def NotDocstring.short1", "status": "good" }, +<<<<<<< HEAD " 92": { "docstring_len": 0, "lines": 6, @@ -180,12 +283,25 @@ "status": "good" }, " 99": { +======= + " 97": { + "docstring_len": 0, + "lines": 2, + "name": "def NotDocstring.long_with_override", + "status": "good" + }, + "100": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 2, "name": "def NotDocstring.short2", "status": "good" }, +<<<<<<< HEAD "102": { +======= + "103": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 4, "name": "def NotDocstring.short3", @@ -197,9 +313,15 @@ "name": "class NotDocstring", "status": "good" }, +<<<<<<< HEAD "107": { "docstring_len": 0, "lines": 5, +======= + "112": { + "docstring_len": 0, + "lines": 1, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "name": "def long_with_omit", "status": "good" } diff --git a/tools/test/docstring_linter_testdata/python_code.py.txt.report.json b/tools/test/docstring_linter_testdata/python_code.py.txt.report.json index 43a8648aad288..8c82149703976 100644 --- a/tools/test/docstring_linter_testdata/python_code.py.txt.report.json +++ b/tools/test/docstring_linter_testdata/python_code.py.txt.report.json @@ -11,7 +11,11 @@ "is_method": false, "line_count": 4, "parent": null, +<<<<<<< HEAD "start_line": 1 +======= + "start_line": 2 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "class", @@ -25,7 +29,11 @@ "is_method": false, "line_count": 3, "parent": null, +<<<<<<< HEAD "start_line": 6 +======= + "start_line": 7 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "class", @@ -41,7 +49,11 @@ "is_method": false, "line_count": 6, "parent": null, +<<<<<<< HEAD "start_line": 10 +======= + "start_line": 11 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "def", @@ -55,7 +67,11 @@ "is_method": true, "line_count": 3, "parent": 2, +<<<<<<< HEAD "start_line": 13 +======= + "start_line": 14 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "class", @@ -69,9 +85,15 @@ "index": 4, "is_local": false, "is_method": false, +<<<<<<< HEAD "line_count": 6, "parent": null, "start_line": 17 +======= + "line_count": 4, + "parent": null, + "start_line": 20 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "def", @@ -85,7 +107,11 @@ "is_method": true, "line_count": 3, "parent": 4, +<<<<<<< HEAD "start_line": 20 +======= + "start_line": 21 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "class", @@ -101,7 +127,11 @@ "is_method": false, "line_count": 6, "parent": null, +<<<<<<< HEAD "start_line": 24 +======= + "start_line": 25 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "def", @@ -115,7 +145,11 @@ "is_method": true, "line_count": 3, "parent": 6, +<<<<<<< HEAD "start_line": 27 +======= + "start_line": 28 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "class", @@ -131,7 +165,11 @@ "is_method": false, "line_count": 6, "parent": null, +<<<<<<< HEAD "start_line": 31 +======= + "start_line": 32 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "def", @@ -145,7 +183,11 @@ "is_method": true, "line_count": 3, "parent": 8, +<<<<<<< HEAD "start_line": 34 +======= + "start_line": 35 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "def", @@ -157,9 +199,15 @@ "index": 10, "is_local": false, "is_method": false, +<<<<<<< HEAD "line_count": 6, "parent": null, "start_line": 38 +======= + "line_count": 3, + "parent": null, + "start_line": 42 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "def", @@ -173,7 +221,11 @@ "is_method": false, "line_count": 8, "parent": null, +<<<<<<< HEAD "start_line": 45 +======= + "start_line": 46 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "def", @@ -185,9 +237,15 @@ "index": 12, "is_local": false, "is_method": false, +<<<<<<< HEAD "line_count": 7, "parent": null, "start_line": 54 +======= + "line_count": 3, + "parent": null, + "start_line": 59 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "class", @@ -204,9 +262,15 @@ "index": 13, "is_local": false, "is_method": false, +<<<<<<< HEAD "line_count": 21, "parent": null, "start_line": 62 +======= + "line_count": 15, + "parent": null, + "start_line": 69 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "def", @@ -224,7 +288,11 @@ "is_method": true, "line_count": 12, "parent": 13, +<<<<<<< HEAD "start_line": 71 +======= + "start_line": 72 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "def", @@ -241,7 +309,11 @@ "is_method": false, "line_count": 11, "parent": 14, +<<<<<<< HEAD "start_line": 72 +======= + "start_line": 73 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "class", @@ -255,7 +327,11 @@ "is_method": false, "line_count": 6, "parent": 15, +<<<<<<< HEAD "start_line": 73 +======= + "start_line": 74 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "class", @@ -269,7 +345,11 @@ "is_method": false, "line_count": 3, "parent": 15, +<<<<<<< HEAD "start_line": 80 +======= + "start_line": 81 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "class", @@ -290,7 +370,11 @@ "is_method": false, "line_count": 21, "parent": null, +<<<<<<< HEAD "start_line": 85 +======= + "start_line": 86 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "def", @@ -304,7 +388,11 @@ "is_method": true, "line_count": 2, "parent": 18, +<<<<<<< HEAD "start_line": 86 +======= + "start_line": 87 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "def", @@ -318,9 +406,15 @@ "index": 20, "is_local": false, "is_method": true, +<<<<<<< HEAD "line_count": 6, "parent": 18, "start_line": 92 +======= + "line_count": 2, + "parent": 18, + "start_line": 97 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "def", @@ -334,7 +428,11 @@ "is_method": true, "line_count": 2, "parent": 18, +<<<<<<< HEAD "start_line": 99 +======= + "start_line": 100 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "def", @@ -348,7 +446,11 @@ "is_method": true, "line_count": 4, "parent": 18, +<<<<<<< HEAD "start_line": 102 +======= + "start_line": 103 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, { "category": "def", @@ -360,8 +462,14 @@ "index": 23, "is_local": false, "is_method": false, +<<<<<<< HEAD "line_count": 5, "parent": null, "start_line": 107 +======= + "line_count": 1, + "parent": null, + "start_line": 112 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } ] diff --git a/tools/test/docstring_linter_testdata/python_code.py.txt.single.line.json b/tools/test/docstring_linter_testdata/python_code.py.txt.single.line.json index 16b1f18567f78..981e777314aa5 100644 --- a/tools/test/docstring_linter_testdata/python_code.py.txt.single.line.json +++ b/tools/test/docstring_linter_testdata/python_code.py.txt.single.line.json @@ -1,4 +1,5 @@ { +<<<<<<< HEAD " 1": "class ShortWithDocstring: lines=4, docs=44", " 6": "class Short: lines=3, docs=0", " 10": "class LongWithDocstring: lines=6, docs=44", @@ -23,4 +24,30 @@ " 99": "def NotDocstring.short2(): lines=2, docs=0", "102": "def NotDocstring.short3(): lines=4, docs=0", "107": "def long_with_omit(): lines=5, docs=0" +======= + " 2": "class ShortWithDocstring: lines=4, docs=44", + " 7": "class Short: lines=3, docs=0", + " 11": "class LongWithDocstring: lines=6, docs=44", + " 14": "def LongWithDocstring.short1(): lines=3, docs=0", + " 20": "class LongWithoutDocstring: lines=4, docs=0", + " 21": "def LongWithoutDocstring.short1(): lines=3, docs=0", + " 25": "class LongWithShortDocstring: lines=6, docs=10", + " 28": "def LongWithShortDocstring.short1(): lines=3, docs=0", + " 32": "class _Protected: lines=6, docs=10", + " 35": "def _Protected.short1(): lines=3, docs=0", + " 42": "def short(): lines=3, docs=0", + " 46": "def long(): lines=8, docs=44", + " 59": "def long_without_docstring(): lines=3, docs=0", + " 69": "class ImpossibleCombo: lines=15, docs=44", + " 72": "def ImpossibleCombo.needs_docs(): lines=12, docs=0", + " 73": "def ImpossibleCombo.needs_docs.not_short(): lines=11, docs=0", + " 74": "class ImpossibleCombo.needs_docs.not_short.Long: lines=6, docs=0", + " 81": "class ImpossibleCombo.needs_docs.not_short.Short: lines=3, docs=0", + " 86": "class NotDocstring: lines=21, docs=0", + " 87": "def NotDocstring.short1(): lines=2, docs=0", + " 97": "def NotDocstring.long_with_override(): lines=2, docs=0", + "100": "def NotDocstring.short2(): lines=2, docs=0", + "103": "def NotDocstring.short3(): lines=4, docs=0", + "112": "def long_with_omit(): lines=1, docs=0" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } diff --git a/tools/test/docstring_linter_testdata/python_code.py.txt.terse.json b/tools/test/docstring_linter_testdata/python_code.py.txt.terse.json index 224da17c004fd..33105088d4b7e 100644 --- a/tools/test/docstring_linter_testdata/python_code.py.txt.terse.json +++ b/tools/test/docstring_linter_testdata/python_code.py.txt.terse.json @@ -1,146 +1,248 @@ { "class ImpossibleCombo": { "docstring_len": 44, +<<<<<<< HEAD "line": 62, "lines": 21, +======= + "line": 69, + "lines": 15, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "status": "good" }, "class ImpossibleCombo.needs_docs.not_short.Long": { "docstring_len": 0, +<<<<<<< HEAD "line": 73, +======= + "line": 74, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "lines": 6, "status": "good" }, "class ImpossibleCombo.needs_docs.not_short.Short": { "docstring_len": 0, +<<<<<<< HEAD "line": 80, +======= + "line": 81, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "lines": 3, "status": "good" }, "class LongWithDocstring": { "docstring_len": 44, +<<<<<<< HEAD "line": 10, +======= + "line": 11, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "lines": 6, "status": "good" }, "class LongWithShortDocstring": { "docstring_len": 10, +<<<<<<< HEAD "line": 24, +======= + "line": 25, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "lines": 6, "status": "good" }, "class LongWithoutDocstring": { "docstring_len": 0, +<<<<<<< HEAD "line": 17, "lines": 6, +======= + "line": 20, + "lines": 4, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "status": "good" }, "class NotDocstring": { "docstring_len": 0, +<<<<<<< HEAD "line": 85, +======= + "line": 86, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "lines": 21, "status": "good" }, "class Short": { "docstring_len": 0, +<<<<<<< HEAD "line": 6, +======= + "line": 7, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "lines": 3, "status": "good" }, "class ShortWithDocstring": { "docstring_len": 44, +<<<<<<< HEAD "line": 1, +======= + "line": 2, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "lines": 4, "status": "good" }, "class _Protected": { "docstring_len": 10, +<<<<<<< HEAD "line": 31, +======= + "line": 32, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "lines": 6, "status": "good" }, "def ImpossibleCombo.needs_docs": { "docstring_len": 0, +<<<<<<< HEAD "line": 71, +======= + "line": 72, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "lines": 12, "status": "good" }, "def ImpossibleCombo.needs_docs.not_short": { "docstring_len": 0, +<<<<<<< HEAD "line": 72, +======= + "line": 73, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "lines": 11, "status": "good" }, "def LongWithDocstring.short1": { "docstring_len": 0, +<<<<<<< HEAD "line": 13, +======= + "line": 14, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "lines": 3, "status": "good" }, "def LongWithShortDocstring.short1": { "docstring_len": 0, +<<<<<<< HEAD "line": 27, +======= + "line": 28, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "lines": 3, "status": "good" }, "def LongWithoutDocstring.short1": { "docstring_len": 0, +<<<<<<< HEAD "line": 20, +======= + "line": 21, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "lines": 3, "status": "good" }, "def NotDocstring.long_with_override": { "docstring_len": 0, +<<<<<<< HEAD "line": 92, "lines": 6, +======= + "line": 97, + "lines": 2, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "status": "good" }, "def NotDocstring.short1": { "docstring_len": 0, +<<<<<<< HEAD "line": 86, +======= + "line": 87, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "lines": 2, "status": "good" }, "def NotDocstring.short2": { "docstring_len": 0, +<<<<<<< HEAD "line": 99, +======= + "line": 100, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "lines": 2, "status": "good" }, "def NotDocstring.short3": { "docstring_len": 0, +<<<<<<< HEAD "line": 102, +======= + "line": 103, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "lines": 4, "status": "good" }, "def _Protected.short1": { "docstring_len": 0, +<<<<<<< HEAD "line": 34, +======= + "line": 35, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "lines": 3, "status": "good" }, "def long": { "docstring_len": 44, +<<<<<<< HEAD "line": 45, +======= + "line": 46, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "lines": 8, "status": "good" }, "def long_with_omit": { "docstring_len": 0, +<<<<<<< HEAD "line": 107, "lines": 5, +======= + "line": 112, + "lines": 1, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "status": "good" }, "def long_without_docstring": { "docstring_len": 0, +<<<<<<< HEAD "line": 54, "lines": 7, +======= + "line": 59, + "lines": 3, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "status": "good" }, "def short": { "docstring_len": 0, +<<<<<<< HEAD "line": 38, "lines": 6, +======= + "line": 42, + "lines": 3, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "status": "good" } } diff --git a/tools/test/docstring_linter_testdata/python_code.py.txt.terse.line.json b/tools/test/docstring_linter_testdata/python_code.py.txt.terse.line.json index 0e7d43c440f31..cd0838da987b0 100644 --- a/tools/test/docstring_linter_testdata/python_code.py.txt.terse.line.json +++ b/tools/test/docstring_linter_testdata/python_code.py.txt.terse.line.json @@ -1,28 +1,45 @@ { +<<<<<<< HEAD " 1": { +======= + " 2": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 44, "lines": 4, "name": "class ShortWithDocstring", "status": "good" }, +<<<<<<< HEAD " 6": { +======= + " 7": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 3, "name": "class Short", "status": "good" }, +<<<<<<< HEAD " 10": { +======= + " 11": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 44, "lines": 6, "name": "class LongWithDocstring", "status": "good" }, +<<<<<<< HEAD " 13": { +======= + " 14": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 3, "name": "def LongWithDocstring.short1", "status": "good" }, +<<<<<<< HEAD " 17": { "docstring_len": 0, "lines": 6, @@ -30,35 +47,61 @@ "status": "good" }, " 20": { +======= + " 20": { + "docstring_len": 0, + "lines": 4, + "name": "class LongWithoutDocstring", + "status": "good" + }, + " 21": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 3, "name": "def LongWithoutDocstring.short1", "status": "good" }, +<<<<<<< HEAD " 24": { +======= + " 25": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 10, "lines": 6, "name": "class LongWithShortDocstring", "status": "good" }, +<<<<<<< HEAD " 27": { +======= + " 28": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 3, "name": "def LongWithShortDocstring.short1", "status": "good" }, +<<<<<<< HEAD " 31": { +======= + " 32": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 10, "lines": 6, "name": "class _Protected", "status": "good" }, +<<<<<<< HEAD " 34": { +======= + " 35": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 3, "name": "def _Protected.short1", "status": "good" }, +<<<<<<< HEAD " 38": { "docstring_len": 0, "lines": 6, @@ -66,11 +109,21 @@ "status": "good" }, " 45": { +======= + " 42": { + "docstring_len": 0, + "lines": 3, + "name": "def short", + "status": "good" + }, + " 46": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 44, "lines": 8, "name": "def long", "status": "good" }, +<<<<<<< HEAD " 54": { "docstring_len": 0, "lines": 7, @@ -84,41 +137,77 @@ "status": "good" }, " 71": { +======= + " 59": { + "docstring_len": 0, + "lines": 3, + "name": "def long_without_docstring", + "status": "good" + }, + " 69": { + "docstring_len": 44, + "lines": 15, + "name": "class ImpossibleCombo", + "status": "good" + }, + " 72": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 12, "name": "def ImpossibleCombo.needs_docs", "status": "good" }, +<<<<<<< HEAD " 72": { +======= + " 73": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 11, "name": "def ImpossibleCombo.needs_docs.not_short", "status": "good" }, +<<<<<<< HEAD " 73": { +======= + " 74": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 6, "name": "class ImpossibleCombo.needs_docs.not_short.Long", "status": "good" }, +<<<<<<< HEAD " 80": { +======= + " 81": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 3, "name": "class ImpossibleCombo.needs_docs.not_short.Short", "status": "good" }, +<<<<<<< HEAD " 85": { +======= + " 86": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 21, "name": "class NotDocstring", "status": "good" }, +<<<<<<< HEAD " 86": { +======= + " 87": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 2, "name": "def NotDocstring.short1", "status": "good" }, +<<<<<<< HEAD " 92": { "docstring_len": 0, "lines": 6, @@ -126,20 +215,39 @@ "status": "good" }, " 99": { +======= + " 97": { + "docstring_len": 0, + "lines": 2, + "name": "def NotDocstring.long_with_override", + "status": "good" + }, + "100": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 2, "name": "def NotDocstring.short2", "status": "good" }, +<<<<<<< HEAD "102": { +======= + "103": { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "docstring_len": 0, "lines": 4, "name": "def NotDocstring.short3", "status": "good" }, +<<<<<<< HEAD "107": { "docstring_len": 0, "lines": 5, +======= + "112": { + "docstring_len": 0, + "lines": 1, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "name": "def long_with_omit", "status": "good" } diff --git a/tools/test/set_linter_testdata/python_code.py.txt b/tools/test/set_linter_testdata/python_code.py.txt index e805a3ca92be1..3be8bfe8cf2b0 100644 --- a/tools/test/set_linter_testdata/python_code.py.txt +++ b/tools/test/set_linter_testdata/python_code.py.txt @@ -30,9 +30,12 @@ class A: set = A().set +<<<<<<< HEAD # An f string as in https://github.com/pytorch/pytorch/issues/159056 f_string = f" {h:{w}} " +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Braced sets set1 = {1} diff --git a/tools/test/set_linter_testdata/python_code.py.txt.json b/tools/test/set_linter_testdata/python_code.py.txt.json index 22935a7904dfa..fcf97fb29c3e1 100644 --- a/tools/test/set_linter_testdata/python_code.py.txt.json +++ b/tools/test/set_linter_testdata/python_code.py.txt.json @@ -47,7 +47,11 @@ "char": 7, "code": "SET_LINTER", "description": null, +<<<<<<< HEAD "line": 38, +======= + "line": 35, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "name": "Builtin `set` is deprecated", "original": null, "path": "tools/test/set_linter_testdata/python_code.py.txt", @@ -58,7 +62,11 @@ "char": 9, "code": "SET_LINTER", "description": null, +<<<<<<< HEAD "line": 38, +======= + "line": 35, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "name": "Builtin `set` is deprecated", "original": null, "path": "tools/test/set_linter_testdata/python_code.py.txt", @@ -69,7 +77,11 @@ "char": 7, "code": "SET_LINTER", "description": null, +<<<<<<< HEAD "line": 39, +======= + "line": 36, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "name": "Builtin `set` is deprecated", "original": null, "path": "tools/test/set_linter_testdata/python_code.py.txt", @@ -80,7 +92,11 @@ "char": 12, "code": "SET_LINTER", "description": null, +<<<<<<< HEAD "line": 39, +======= + "line": 36, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "name": "Builtin `set` is deprecated", "original": null, "path": "tools/test/set_linter_testdata/python_code.py.txt", @@ -91,7 +107,11 @@ "char": 15, "code": "SET_LINTER", "description": null, +<<<<<<< HEAD "line": 41, +======= + "line": 38, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "name": "Builtin `set` is deprecated", "original": null, "path": "tools/test/set_linter_testdata/python_code.py.txt", @@ -102,7 +122,11 @@ "char": 36, "code": "SET_LINTER", "description": null, +<<<<<<< HEAD "line": 41, +======= + "line": 38, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "name": "Builtin `set` is deprecated", "original": null, "path": "tools/test/set_linter_testdata/python_code.py.txt", @@ -113,7 +137,11 @@ "char": 17, "code": "SET_LINTER", "description": null, +<<<<<<< HEAD "line": 44, +======= + "line": 41, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "name": "Builtin `set` is deprecated", "original": null, "path": "tools/test/set_linter_testdata/python_code.py.txt", @@ -124,7 +152,11 @@ "char": 22, "code": "SET_LINTER", "description": null, +<<<<<<< HEAD "line": 44, +======= + "line": 41, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "name": "Builtin `set` is deprecated", "original": null, "path": "tools/test/set_linter_testdata/python_code.py.txt", @@ -135,7 +167,11 @@ "char": 30, "code": "SET_LINTER", "description": null, +<<<<<<< HEAD "line": 44, +======= + "line": 41, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "name": "Builtin `set` is deprecated", "original": null, "path": "tools/test/set_linter_testdata/python_code.py.txt", @@ -146,7 +182,11 @@ "char": 50, "code": "SET_LINTER", "description": null, +<<<<<<< HEAD "line": 44, +======= + "line": 41, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "name": "Builtin `set` is deprecated", "original": null, "path": "tools/test/set_linter_testdata/python_code.py.txt", @@ -157,7 +197,11 @@ "char": 10, "code": "SET_LINTER", "description": null, +<<<<<<< HEAD "line": 47, +======= + "line": 44, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "name": "Builtin `set` is deprecated", "original": null, "path": "tools/test/set_linter_testdata/python_code.py.txt", @@ -168,7 +212,11 @@ "char": 51, "code": "SET_LINTER", "description": null, +<<<<<<< HEAD "line": 47, +======= + "line": 44, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "name": "Builtin `set` is deprecated", "original": null, "path": "tools/test/set_linter_testdata/python_code.py.txt", @@ -179,7 +227,11 @@ "char": 75, "code": "SET_LINTER", "description": null, +<<<<<<< HEAD "line": 47, +======= + "line": 44, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "name": "Builtin `set` is deprecated", "original": null, "path": "tools/test/set_linter_testdata/python_code.py.txt", @@ -190,7 +242,11 @@ "char": 77, "code": "SET_LINTER", "description": null, +<<<<<<< HEAD "line": 47, +======= + "line": 44, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "name": "Builtin `set` is deprecated", "original": null, "path": "tools/test/set_linter_testdata/python_code.py.txt", @@ -203,9 +259,15 @@ "description": null, "line": null, "name": "Suggested fixes for set_linter", +<<<<<<< HEAD "original": "# Basic tests\nimport tempfile\n\nprint(f\"{tempfile.gettempdir()}/memory_snapshot.pickle\")\n\nignored = set() # noqa: set_linter\na = set()\nb = \"set()\"\nc = set\nd = c.set\nf = (\n set(\n )\n)\nignored = (\n set( # noqa: set_linter\n )\n)\n\n# Non-sets\n\nd = {}\nlong_string = \"\"\" set()\nset() set x.set set()\n\\\"\"\"\"\n\nclass A:\n def set(self, x):\n self.x = x\n\nset = A().set\n\n# An f string as in https://github.com/pytorch/pytorch/issues/159056\nf_string = f\" {h:{w}} \"\n\n# Braced sets\n\nset1 = {1}\nset2 = {1, 2}\n\niterator_set = {i for i in range(10)}\n\n# A dict with two sets.\ndict_set = {\"a\": {2, 3}, \"b\": {i for i in range(3)}}\n\n# A set containing an object constructed with a dict and a set\nsos_set = {Something({i: i + 1 for i in range(3)}, {i + 1 for i in range(3)})}\n", "path": "tools/test/set_linter_testdata/python_code.py.txt", "replacement": "# Basic tests\nimport tempfile\nfrom torch.utils._ordered_set import OrderedSet\n\n\nprint(f\"{tempfile.gettempdir()}/memory_snapshot.pickle\")\n\nignored = set() # noqa: set_linter\na = OrderedSet()\nb = \"set()\"\nc = OrderedSet\nd = c.set\nf = (\n OrderedSet(\n )\n)\nignored = (\n set( # noqa: set_linter\n )\n)\n\n# Non-sets\n\nd = {}\nlong_string = \"\"\" set()\nset() set x.set set()\n\\\"\"\"\"\n\nclass A:\n def set(self, x):\n self.x = x\n\nset = A().set\n\n# An f string as in https://github.com/pytorch/pytorch/issues/159056\nf_string = f\" {h:{w}} \"\n\n# Braced sets\n\nset1 = OrderedSet([1])\nset2 = OrderedSet([1, 2])\n\niterator_set = OrderedSet([i for i in range(10)])\n\n# A dict with two sets.\ndict_set = {\"a\": OrderedSet([2, 3]), \"b\": OrderedSet([i for i in range(3)])}\n\n# A set containing an object constructed with a dict and a set\nsos_set = OrderedSet([Something({i: i + 1 for i in range(3)}, OrderedSet([i + 1 for i in range(3)]))])\n", +======= + "original": "# Basic tests\nimport tempfile\n\nprint(f\"{tempfile.gettempdir()}/memory_snapshot.pickle\")\n\nignored = set() # noqa: set_linter\na = set()\nb = \"set()\"\nc = set\nd = c.set\nf = (\n set(\n )\n)\nignored = (\n set( # noqa: set_linter\n )\n)\n\n# Non-sets\n\nd = {}\nlong_string = \"\"\" set()\nset() set x.set set()\n\\\"\"\"\"\n\nclass A:\n def set(self, x):\n self.x = x\n\nset = A().set\n\n# Braced sets\n\nset1 = {1}\nset2 = {1, 2}\n\niterator_set = {i for i in range(10)}\n\n# A dict with two sets.\ndict_set = {\"a\": {2, 3}, \"b\": {i for i in range(3)}}\n\n# A set containing an object constructed with a dict and a set\nsos_set = {Something({i: i + 1 for i in range(3)}, {i + 1 for i in range(3)})}\n", + "path": "tools/test/set_linter_testdata/python_code.py.txt", + "replacement": "# Basic tests\nimport tempfile\nfrom torch.utils._ordered_set import OrderedSet\n\n\nprint(f\"{tempfile.gettempdir()}/memory_snapshot.pickle\")\n\nignored = set() # noqa: set_linter\na = OrderedSet()\nb = \"set()\"\nc = OrderedSet\nd = c.set\nf = (\n OrderedSet(\n )\n)\nignored = (\n set( # noqa: set_linter\n )\n)\n\n# Non-sets\n\nd = {}\nlong_string = \"\"\" set()\nset() set x.set set()\n\\\"\"\"\"\n\nclass A:\n def set(self, x):\n self.x = x\n\nset = A().set\n\n# Braced sets\n\nset1 = OrderedSet([1])\nset2 = OrderedSet([1, 2])\n\niterator_set = OrderedSet([i for i in range(10)])\n\n# A dict with two sets.\ndict_set = {\"a\": OrderedSet([2, 3]), \"b\": OrderedSet([i for i in range(3)])}\n\n# A set containing an object constructed with a dict and a set\nsos_set = OrderedSet([Something({i: i + 1 for i in range(3)}, OrderedSet([i + 1 for i in range(3)]))])\n", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "severity": "error" } ] diff --git a/tools/test/set_linter_testdata/python_code.py.txt.lintrunner b/tools/test/set_linter_testdata/python_code.py.txt.lintrunner index 4926368e9ab17..3c93a79bc437f 100644 --- a/tools/test/set_linter_testdata/python_code.py.txt.lintrunner +++ b/tools/test/set_linter_testdata/python_code.py.txt.lintrunner @@ -30,6 +30,7 @@ tools/test/set_linter_testdata/python_code.py.txt:12:4: Builtin `set` is depreca 13 | ) 14 | ) +<<<<<<< HEAD tools/test/set_linter_testdata/python_code.py.txt:38:8: Builtin `set` is deprecated 36 | # Braced sets 37 | @@ -132,4 +133,108 @@ tools/test/set_linter_testdata/python_code.py.txt:47:78: Builtin `set` is deprec 45 | 46 | # A set containing an object constructed with a dict and a set 47 | sos_set = {Something({i: i + 1 for i in range(3)}, {i + 1 for i in range(3)})} +======= +tools/test/set_linter_testdata/python_code.py.txt:35:8: Builtin `set` is deprecated + 33 | # Braced sets + 34 | + 35 | set1 = {1} + ^ + 36 | set2 = {1, 2} + 37 | + +tools/test/set_linter_testdata/python_code.py.txt:35:10: Builtin `set` is deprecated + 33 | # Braced sets + 34 | + 35 | set1 = {1} + ^ + 36 | set2 = {1, 2} + 37 | + +tools/test/set_linter_testdata/python_code.py.txt:36:8: Builtin `set` is deprecated + 34 | + 35 | set1 = {1} + 36 | set2 = {1, 2} + ^ + 37 | + 38 | iterator_set = {i for i in range(10)} + +tools/test/set_linter_testdata/python_code.py.txt:36:13: Builtin `set` is deprecated + 34 | + 35 | set1 = {1} + 36 | set2 = {1, 2} + ^ + 37 | + 38 | iterator_set = {i for i in range(10)} + +tools/test/set_linter_testdata/python_code.py.txt:38:16: Builtin `set` is deprecated + 36 | set2 = {1, 2} + 37 | + 38 | iterator_set = {i for i in range(10)} + ^ + 39 | + 40 | # A dict with two sets. + +tools/test/set_linter_testdata/python_code.py.txt:38:37: Builtin `set` is deprecated + 36 | set2 = {1, 2} + 37 | + 38 | iterator_set = {i for i in range(10)} + ^ + 39 | + 40 | # A dict with two sets. + +tools/test/set_linter_testdata/python_code.py.txt:41:18: Builtin `set` is deprecated + 39 | + 40 | # A dict with two sets. + 41 | dict_set = {"a": {2, 3}, "b": {i for i in range(3)}} + ^ + 42 | + 43 | # A set containing an object constructed with a dict and a set + +tools/test/set_linter_testdata/python_code.py.txt:41:23: Builtin `set` is deprecated + 39 | + 40 | # A dict with two sets. + 41 | dict_set = {"a": {2, 3}, "b": {i for i in range(3)}} + ^ + 42 | + 43 | # A set containing an object constructed with a dict and a set + +tools/test/set_linter_testdata/python_code.py.txt:41:31: Builtin `set` is deprecated + 39 | + 40 | # A dict with two sets. + 41 | dict_set = {"a": {2, 3}, "b": {i for i in range(3)}} + ^ + 42 | + 43 | # A set containing an object constructed with a dict and a set + +tools/test/set_linter_testdata/python_code.py.txt:41:51: Builtin `set` is deprecated + 39 | + 40 | # A dict with two sets. + 41 | dict_set = {"a": {2, 3}, "b": {i for i in range(3)}} + ^ + 42 | + 43 | # A set containing an object constructed with a dict and a set + +tools/test/set_linter_testdata/python_code.py.txt:44:11: Builtin `set` is deprecated + 42 | + 43 | # A set containing an object constructed with a dict and a set + 44 | sos_set = {Something({i: i + 1 for i in range(3)}, {i + 1 for i in range(3)})} + ^ + +tools/test/set_linter_testdata/python_code.py.txt:44:52: Builtin `set` is deprecated + 42 | + 43 | # A set containing an object constructed with a dict and a set + 44 | sos_set = {Something({i: i + 1 for i in range(3)}, {i + 1 for i in range(3)})} + ^ + +tools/test/set_linter_testdata/python_code.py.txt:44:76: Builtin `set` is deprecated + 42 | + 43 | # A set containing an object constructed with a dict and a set + 44 | sos_set = {Something({i: i + 1 for i in range(3)}, {i + 1 for i in range(3)})} + ^ + +tools/test/set_linter_testdata/python_code.py.txt:44:78: Builtin `set` is deprecated + 42 | + 43 | # A set containing an object constructed with a dict and a set + 44 | sos_set = {Something({i: i + 1 for i in range(3)}, {i + 1 for i in range(3)})} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ^ diff --git a/tools/test/set_linter_testdata/python_code.py.txt.python b/tools/test/set_linter_testdata/python_code.py.txt.python index 52aaf12f26315..f3df3ae0e455e 100644 --- a/tools/test/set_linter_testdata/python_code.py.txt.python +++ b/tools/test/set_linter_testdata/python_code.py.txt.python @@ -32,9 +32,12 @@ class A: set = A().set +<<<<<<< HEAD # An f string as in https://github.com/pytorch/pytorch/issues/159056 f_string = f" {h:{w}} " +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Braced sets set1 = OrderedSet([1]) diff --git a/tools/test/test_docstring_linter.py b/tools/test/test_docstring_linter.py index f1b98391b9ae9..7c5d1e2a1b786 100644 --- a/tools/test/test_docstring_linter.py +++ b/tools/test/test_docstring_linter.py @@ -28,7 +28,11 @@ TEST_FILE = Path("tools/test/docstring_linter_testdata/python_code.py.txt") TEST_FILE2 = Path("tools/test/docstring_linter_testdata/more_python_code.py.txt") TEST_BLOCK_NAMES = Path("tools/test/docstring_linter_testdata/block_names.py.txt") +<<<<<<< HEAD ARGS = "--max-class=5", "--max-def=6", "--min-docstring=16" +======= +ARGS = "--max-class=3", "--max-def=4", "--min-docstring=16" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TestDocstringLinter(LinterTestCase): diff --git a/tools/test/test_set_linter.py b/tools/test/test_set_linter.py index 003096c3c408e..f6f54aa42af2b 100644 --- a/tools/test/test_set_linter.py +++ b/tools/test/test_set_linter.py @@ -77,7 +77,10 @@ def test_match_braced_sets(self) -> None: ("{i for i in range(2, 3)}", 1), ("{1, 2}", 1), ("{One({'a': 1}), Two([{}, {2}, {1, 2}])}", 3), +<<<<<<< HEAD ('f" {h:{w}} "', 0), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) for s, expected in TESTS: pf = SetLinter.make_file(s) diff --git a/tools/testing/discover_tests.py b/tools/testing/discover_tests.py index 96aee230f89f8..dccf95f7a6a3d 100644 --- a/tools/testing/discover_tests.py +++ b/tools/testing/discover_tests.py @@ -13,7 +13,11 @@ def parse_test_module(test: str) -> str: +<<<<<<< HEAD return test.split(".", maxsplit=1)[0] +======= + return test.split(".")[0] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def discover_tests( diff --git a/tools/testing/modulefinder_determinator.py b/tools/testing/modulefinder_determinator.py index e0ef858b96b21..067faeafd703e 100644 --- a/tools/testing/modulefinder_determinator.py +++ b/tools/testing/modulefinder_determinator.py @@ -23,6 +23,10 @@ "test_cpp_extensions_aot_ninja", "test_cpp_extensions_aot_no_ninja", "test_cpp_extensions_jit", +<<<<<<< HEAD +======= + "test_cpp_extensions_open_device_registration", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "test_cpp_extensions_stream_and_event", "test_cpp_extensions_mtia_backend", "test_cuda", @@ -186,7 +190,11 @@ def get_dep_modules(test: str) -> set[str]: def parse_test_module(test: str) -> str: +<<<<<<< HEAD return test.split(".", maxsplit=1)[0] +======= + return test.split(".")[0] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def print_to_stderr(message: str) -> None: diff --git a/tools/testing/test_selections.py b/tools/testing/test_selections.py index 9493e35f97d72..196b58619aad5 100644 --- a/tools/testing/test_selections.py +++ b/tools/testing/test_selections.py @@ -10,6 +10,7 @@ from tools.testing.test_run import ShardedTest, TestRun +<<<<<<< HEAD try: from torch.testing._internal.common_cuda import SM80OrLater from torch.testing._internal.common_utils import TEST_CUDA @@ -18,6 +19,8 @@ SM80OrLater = False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if TYPE_CHECKING: from collections.abc import Sequence @@ -26,13 +29,21 @@ IS_MEM_LEAK_CHECK = os.getenv("PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", "0") == "1" BUILD_ENVIRONMENT = os.getenv("BUILD_ENVIRONMENT", "") +<<<<<<< HEAD +======= +USE_3_PROCS = "sm86" in BUILD_ENVIRONMENT or "cuda" not in BUILD_ENVIRONMENT +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # NUM_PROCS_FOR_SHARDING_CALC must remain consistent across all shards of a job # to ensure that sharding is consistent, NUM_PROCS is the actual number of procs # used to run tests. If they are not equal, the only consequence should be # unequal shards. IS_ROCM = os.path.exists("/opt/rocm") +<<<<<<< HEAD NUM_PROCS = 1 if IS_MEM_LEAK_CHECK else 3 if not TEST_CUDA or SM80OrLater else 2 +======= +NUM_PROCS = 1 if IS_MEM_LEAK_CHECK else 3 if USE_3_PROCS else 2 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) NUM_PROCS_FOR_SHARDING_CALC = NUM_PROCS if not IS_ROCM or IS_MEM_LEAK_CHECK else 2 THRESHOLD = 60 * 10 # 10 minutes diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index 1632147f0220e..1b2d9eca0865a 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -29,6 +29,10 @@ endif() set(LIBSHM_SRCDIR ${TORCH_SRC_DIR}/lib/${LIBSHM_SUBDIR}) add_subdirectory(${LIBSHM_SRCDIR}) +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Generate files set(TOOLS_PATH "${TORCH_ROOT}/tools") @@ -145,6 +149,7 @@ if(USE_CUDA) list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_CUFILE) endif() +<<<<<<< HEAD if(TARGET torch::nvtx3) list(APPEND TORCH_PYTHON_LINK_LIBRARIES torch::nvtx3) else() @@ -152,6 +157,9 @@ if(USE_CUDA) list(APPEND TORCH_PYTHON_LINK_LIBRARIES torch::nvtoolsext) endif() endif() +======= + list(APPEND TORCH_PYTHON_LINK_LIBRARIES CUDA::nvtx3) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) endif() if(USE_ROCM) @@ -265,7 +273,11 @@ add_custom_command( OUTPUT "${TORCH_SRC_DIR}/utils/data/datapipes/datapipe.pyi" COMMAND +<<<<<<< HEAD ${CMAKE_COMMAND} -E env --modify PYTHONPATH=path_list_prepend:"${TORCH_ROOT}" -- +======= + ${CMAKE_COMMAND} -E env PYTHONPATH="${TORCH_ROOT}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "${Python_EXECUTABLE}" ${TORCH_SRC_DIR}/utils/data/datapipes/gen_pyi.py DEPENDS "${TORCH_SRC_DIR}/utils/data/datapipes/datapipe.pyi.in" @@ -477,11 +489,14 @@ else() set(TORCH_VERSION_DEBUG 0) endif() +<<<<<<< HEAD set(CUDA_VERSION "") if(CUDAToolkit_VERSION_MAJOR) set(CUDA_VERSION "${CUDAToolkit_VERSION_MAJOR}.${CUDAToolkit_VERSION_MINOR}") endif() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) add_custom_target( gen_torch_version ALL "${Python_EXECUTABLE}" "${TOOLS_PATH}/generate_torch_version.py" @@ -505,11 +520,18 @@ if(NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin") ) # Pybind11 requires explicit linking of the torch_python library if(BUILD_LIBTORCHLESS) +<<<<<<< HEAD target_link_libraries(nnapi_backend PRIVATE ${TORCH_LIB}) else() target_link_libraries(nnapi_backend PRIVATE torch) endif() target_link_libraries(nnapi_backend PRIVATE torch_python pybind::pybind11 fmt::fmt-header-only) +======= + target_link_libraries(nnapi_backend PRIVATE ${TORCH_LIB} torch_python pybind::pybind11) + else() + target_link_libraries(nnapi_backend PRIVATE torch torch_python pybind::pybind11) + endif() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) endif() set(TORCH_PYTHON_COMPILE_OPTIONS ${TORCH_PYTHON_COMPILE_OPTIONS} PARENT_SCOPE) diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 5fe3f7e178b73..9f70b542dcc85 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -30,7 +30,10 @@ from torch._C import ( _cpu, _dynamo, _export, +<<<<<<< HEAD _functionalization, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _functorch, _lazy, _lazy_ts_backend, @@ -41,7 +44,10 @@ from torch._C import ( ) from torch._prims_common import DeviceLikeType from torch.autograd.graph import Node as _Node +<<<<<<< HEAD from torch.cuda import _POOL_HANDLE +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.fx.node import Node as FxNode from torch.package import PackageExporter from torch.storage import TypedStorage, UntypedStorage @@ -952,7 +958,10 @@ class FunctionSchema: is_vararg: _bool, is_varret: _bool, ) -> None: ... +<<<<<<< HEAD def _is_view_op(self) -> _bool: ... +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class _UpgraderEntry: bumped_at_version: _int @@ -1215,8 +1224,11 @@ def _get_mkldnn_enabled() -> _bool: ... # THPModule_userEnabledMkldnn def _set_mkldnn_enabled(arg: _bool) -> None: ... # THPModule_setUserEnabledMkldnn def _get_cudnn_benchmark() -> _bool: ... # THPModule_benchmarkCuDNN def _set_cudnn_benchmark(arg: _bool) -> None: ... # THPModule_setBenchmarkCuDNN +<<<<<<< HEAD def _get_miopen_immediate() -> _bool: ... # THPModule_userImmediateMiopen def _set_miopen_immediate(arg: _bool) -> None: ... # THPModule_setUserImmediateMiopen +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _get_cudnn_deterministic() -> _bool: ... # THPModule_deterministicCuDNN def _set_cudnn_deterministic(arg: _bool) -> None: ... # THPModule_setDeterministicCuDNN def _get_mkldnn_deterministic() -> _bool: ... # THPModule_deterministicMkldnn @@ -1275,7 +1287,10 @@ def _set_sm_carveout_experimental(arg: _int | None) -> None: ... def _set_conj(x: Tensor, conj: _bool) -> None: ... def _set_neg(x: Tensor, neg: _bool) -> None: ... def _set_meta_in_tls_dispatch_include(meta_in_tls: _bool) -> None: ... +<<<<<<< HEAD def _autocast_supported_devices() -> list[str]: ... +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _meta_in_tls_dispatch_include() -> _bool: ... def _stash_obj_in_tls(key: str, arg: Any) -> None: ... def _get_obj_in_tls(key: str) -> Any: ... @@ -1301,6 +1316,7 @@ def _group_tensors_by_device_and_dtype( tuple[torch.device, torch.dtype], tuple[list[list[Tensor | None]], list[_int]], ]: ... +<<<<<<< HEAD def _initCrashHandler() -> None: ... # NB: There is no Capsule type in typing, see @@ -1319,6 +1335,13 @@ def _from_dlpack(data: Any) -> Tensor: ... # THPModule_fromDLPack def _torchDeviceToDLDevice( device: torch.device, ) -> tuple[_int, _int]: ... # THPModule_torchDeviceToDLDevice +======= + +# NB: There is no Capsule type in typing, see +# https://github.com/python/cpython/issues/109562 +def _to_dlpack(data: Tensor) -> Any: ... # THPModule_toDLPack +def _from_dlpack(data: Any) -> Tensor: ... # THPModule_fromDLPack +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _get_cpp_backtrace( frames_to_skip: _int, maximum_number_of_frames: _int, @@ -1378,8 +1401,11 @@ def _disabled_torch_dispatch_impl( ) -> Any: ... # THPModule_disable_dispatch_function def _get_linalg_preferred_backend() -> _LinalgBackend: ... def _set_linalg_preferred_backend(arg: _LinalgBackend): ... +<<<<<<< HEAD def _get_fp32_precision_getter(backend: str, op: str) -> str: ... def _set_fp32_precision_setter(backend: str, op: str, value: str) -> str: ... +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class _LinalgBackend: Default: _LinalgBackend @@ -1845,9 +1871,12 @@ class _SetExcludeDispatchKeyGuard: def __enter__(self): ... def __exit__(self, *exc_info: object) -> None: ... +<<<<<<< HEAD def _get_dtensor_allow_implicit_replication() -> _bool: ... def _set_dtensor_allow_implicit_replication(value: _bool) -> None: ... +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Defined in torch/csrc/utils/schema_info.h class _SchemaInfo: @@ -1953,10 +1982,15 @@ def _mtia_isBuilt() -> _bool: ... def _mtia_isInBadFork() -> _bool: ... def _mtia_deviceSynchronize() -> None: ... def _mtia_getCurrentStream(device: _int) -> Stream: ... +<<<<<<< HEAD def _mtia_getCurrentRawStream(device: _int) -> _int: ... def _mtia_setCurrentStream(stream: Stream) -> None: ... def _mtia_getDefaultStream(device: _int) -> Stream: ... def _mtia_setStream(stream_id: _int, device_index: _int, device_type: _int) -> None: ... +======= +def _mtia_setCurrentStream(stream: Stream) -> None: ... +def _mtia_getDefaultStream(device: _int) -> Stream: ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _mtia_memoryStats(device: _int) -> dict[str, Any]: ... def _mtia_getDeviceCapability(device: _int) -> tuple[_int, _int]: ... def _mtia_getDeviceProperties(device: _int) -> dict[str, Any]: ... @@ -1975,9 +2009,13 @@ def _mtia_resetPeakMemoryStats(device: _int) -> None: ... # Defined in torch/csrc/mps/Module.cpp def _mps_deviceSynchronize() -> None: ... +<<<<<<< HEAD def _mps_get_core_count() -> _int: ... def _mps_get_default_generator() -> Generator: ... def _mps_get_name() -> _str: ... +======= +def _mps_get_default_generator() -> Generator: ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _mps_emptyCache() -> None: ... def _mps_setMemoryFraction(fraction: _float) -> None: ... def _mps_currentAllocatedMemory() -> _int: ... @@ -2065,7 +2103,10 @@ def _cuda_record_memory_history_legacy( alloc_trace_record_context: _bool, clear_history: _bool, compile_context: _bool, +<<<<<<< HEAD global_record_annotations: _bool, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: ... def _cuda_record_memory_history( enabled: str | None, @@ -2074,7 +2115,10 @@ def _cuda_record_memory_history( max_entries: _int, clear_history: _bool, compile_context: _bool, +<<<<<<< HEAD global_record_annotations: _bool, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: ... def _cuda_isHistoryEnabled() -> _bool: ... def _cuda_getAllocatorBackend() -> str: ... @@ -2222,7 +2266,10 @@ class _SDPBackend(Enum): FLASH_ATTENTION = 1 EFFICIENT_ATTENTION = 2 CUDNN_ATTENTION = 3 +<<<<<<< HEAD OVERRIDEABLE = 4 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _is_flash_attention_available() -> _bool: ... def _can_use_cudnn_attention(params: _SDPAParams, debug: _bool) -> _bool: ... @@ -2312,7 +2359,11 @@ class _CUDAGraph: def __new__(cls, keep_graph: _bool = ...) -> Self: ... def capture_begin( self, +<<<<<<< HEAD pool: _POOL_HANDLE | None = ..., +======= + pool: tuple[_int, _int] | None = ..., +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) capture_error_mode: str = "global", ) -> None: ... def capture_end(self) -> None: ... @@ -2320,11 +2371,18 @@ class _CUDAGraph: def register_generator_state(self, Generator) -> None: ... def replay(self) -> None: ... def reset(self) -> None: ... +<<<<<<< HEAD def pool(self) -> _POOL_HANDLE: ... def enable_debug_mode(self) -> None: ... def debug_dump(self, debug_path: str) -> None: ... def raw_cuda_graph(self) -> _int: ... def raw_cuda_graph_exec(self) -> _int: ... +======= + def pool(self) -> tuple[_int, _int]: ... + def enable_debug_mode(self) -> None: ... + def debug_dump(self, debug_path: str) -> None: ... + def raw_cuda_graph(self) -> _int: ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Defined in torch/csrc/cuda/MemPool.cpp class _MemPool: @@ -2333,10 +2391,19 @@ class _MemPool: allocator: _cuda_CUDAAllocator | None = None, is_user_created: _bool = True, use_on_oom: _bool = False, +<<<<<<< HEAD +======= + symmetric: _bool = False, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: ... @property def id(self) -> tuple[_int, _int]: ... @property +<<<<<<< HEAD +======= + def is_symmetric(self) -> _bool: ... + @property +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def allocator(self) -> _cuda_CUDAAllocator | None: ... def use_count(self) -> _int: ... @@ -2366,7 +2433,10 @@ class _XpuDeviceProperties: name: str platform_name: str vendor: str +<<<<<<< HEAD device_id: _int +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) driver_version: str version: str max_compute_units: _int @@ -2385,7 +2455,10 @@ class _XpuDeviceProperties: gpu_subslice_count: _int architecture: _int type: str +<<<<<<< HEAD uuid: Any +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Defined in torch/csrc/xpu/Stream.cpp class _XpuStreamBase(Stream): @@ -2441,11 +2514,14 @@ def _accelerator_getStream(device_index: _int) -> Stream: ... def _accelerator_synchronizeDevice(device_index: _int) -> None: ... def _accelerator_exchangeDevice(device_index: _int) -> _int: ... def _accelerator_maybeExchangeDevice(device_index: _int) -> _int: ... +<<<<<<< HEAD def _accelerator_isAllocatorInitialized() -> _bool: ... def _accelerator_emptyCache() -> None: ... def _accelerator_getDeviceStats(device_index: _int) -> dict[str, Any]: ... def _accelerator_resetAccumulatedStats(device_index: _int) -> None: ... def _accelerator_resetPeakStats(device_index: _int) -> None: ... +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Defined in torch/csrc/jit/python/python_tracer.cpp class TracingState: diff --git a/torch/_C/_autograd.pyi b/torch/_C/_autograd.pyi index b166b280df9da..4304b9854f46a 100644 --- a/torch/_C/_autograd.pyi +++ b/torch/_C/_autograd.pyi @@ -77,7 +77,10 @@ class _KinetoEvent: def cuda_elapsed_us(self) -> int: ... def privateuse1_elapsed_us(self) -> int: ... def is_user_annotation(self) -> bool: ... +<<<<<<< HEAD def is_hidden_event(self) -> bool: ... +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class _ProfilerResult: def events(self) -> list[_KinetoEvent]: ... diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index ad3d8e3abf245..0a61ae72ec828 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -279,12 +279,19 @@ class Work: def is_success(self) -> bool: ... def exception(self) -> Any: ... def wait(self, timeout: timedelta = ...) -> bool: ... +<<<<<<< HEAD def block_current_stream(self) -> None: ... +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_future(self) -> Future: ... def source_rank(self) -> int: ... def _source_rank(self) -> int: ... def result(self) -> list[Tensor]: ... +<<<<<<< HEAD def synchronize(self) -> None: ... +======= + def synchronize(self): ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def boxed(self) -> ScriptObject: ... @staticmethod def unbox(obj: ScriptObject) -> Work: ... @@ -298,8 +305,11 @@ class Backend: def _timeout(self) -> timedelta: ... @_timeout.setter def _timeout(self, val: timedelta) -> None: ... +<<<<<<< HEAD global_ranks_in_group: list[int] group_name: str +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __init__( self, @@ -312,12 +322,18 @@ class Backend: def supports_coalescing(self) -> bool: ... @property def supports_time_estimate(self) -> bool: ... +<<<<<<< HEAD def set_timeout(self, timeout: timedelta) -> None: ... +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @property def options(self) -> Options: ... def rank(self) -> int: ... def size(self) -> int: ... +<<<<<<< HEAD def name(self) -> str: ... +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def abort(self) -> None: ... def shutdown(self) -> None: ... def eager_connect_single_device(self, device: torch.device | None) -> None: ... @@ -353,6 +369,7 @@ class ProcessGroup: ) -> None: ... def rank(self) -> int: ... def size(self) -> int: ... +<<<<<<< HEAD def get_group_store(self) -> Store: ... def split_group( self, @@ -372,6 +389,9 @@ class ProcessGroup: ) -> ProcessGroup: ... def abort(self) -> None: ... def set_timeout(self, timeout: timedelta) -> None: ... +======= + def abort(self) -> None: ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def shutdown(self) -> None: ... @overload def broadcast( @@ -384,7 +404,10 @@ class ProcessGroup: self, tensor: Tensor, root: int, +<<<<<<< HEAD timeout: timedelta | None = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Work: ... @overload def allreduce( @@ -397,14 +420,20 @@ class ProcessGroup: self, tensors: list[Tensor], op=..., +<<<<<<< HEAD timeout: timedelta | None = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Work: ... @overload def allreduce( self, tensor: Tensor, op=..., +<<<<<<< HEAD timeout: timedelta | None = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Work: ... def allreduce_coalesced( self, @@ -429,7 +458,10 @@ class ProcessGroup: tensor: Tensor, root: int, op=..., +<<<<<<< HEAD timeout: timedelta | None = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Work: ... @overload def allgather( @@ -443,7 +475,10 @@ class ProcessGroup: self, output_tensors: list[Tensor], input_tensor: Tensor, +<<<<<<< HEAD timeout: timedelta | None = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Work: ... def _allgather_base( self, @@ -476,7 +511,10 @@ class ProcessGroup: output_tensors: list[Tensor], input_tensor: Tensor, root: int, +<<<<<<< HEAD timeout: timedelta | None = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Work: ... @overload def scatter( @@ -491,7 +529,10 @@ class ProcessGroup: output_tensor: Tensor, input_tensors: list[Tensor], root: int, +<<<<<<< HEAD timeout: timedelta | None = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Work: ... @overload def reduce_scatter( @@ -505,8 +546,11 @@ class ProcessGroup: self, output_tensors: Tensor, input_tensor: list[Tensor], +<<<<<<< HEAD op=..., timeout: timedelta | None = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Work: ... def _reduce_scatter_base( self, @@ -530,7 +574,10 @@ class ProcessGroup: input: Tensor, output_split_sizes: list[int], input_split_sizes: list[int], +<<<<<<< HEAD timeout: timedelta | None = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Work: ... @overload def alltoall( @@ -544,7 +591,10 @@ class ProcessGroup: self, output: list[Tensor], input: list[Tensor], +<<<<<<< HEAD timeout: timedelta | None = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Work: ... def send( self, @@ -559,10 +609,14 @@ class ProcessGroup: tag: int, ) -> Work: ... def recv_anysource(self, tensors: list[Tensor], tag: int) -> Work: ... +<<<<<<< HEAD @overload def barrier(self, opts=...) -> Work: ... @overload def barrier(self, timeout: timedelta | None = None) -> Work: ... +======= + def barrier(self, opts=...) -> Work: ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def boxed(self) -> ScriptObject: ... @staticmethod def unbox(obj: ScriptObject) -> ProcessGroup: ... @@ -610,6 +664,11 @@ class ProcessGroupGloo(Backend): class Options(Backend.Options): devices: list[ProcessGroupGloo.Device] threads: int +<<<<<<< HEAD +======= + global_ranks_in_group: list[int] + group_name: str +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __init__(self): ... @@ -644,13 +703,21 @@ class ProcessGroupNCCL(Backend): cga_cluster_size: int min_ctas: int max_ctas: int +<<<<<<< HEAD def unsafe_get_ptr(self) -> int: ... +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class Options(Backend.Options): config: ProcessGroupNCCL.NCCLConfig is_high_priority_stream: bool split_from: ProcessGroupNCCL split_color: int +<<<<<<< HEAD +======= + global_ranks_in_group: list[int] + group_name: str +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __init__(self, is_high_priority_stream: bool = False): ... @@ -734,7 +801,11 @@ def _allow_inflight_collective_as_graph_input() -> bool: ... def _unregister_all_process_groups() -> None: ... def _unregister_process_group(group_name: str) -> None: ... +<<<<<<< HEAD # Initializes the device state in CUmodule so that it’s able to perform NVSHMEM +======= +# Intializes the device state in CUmodule so that it’s able to perform NVSHMEM +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # operations. CUmodule is a pointer to a CUDA module, carried by a int64 in # Python. At C++ interface, it is converted to a uintptr_t. def _nvshmemx_cumodule_init(module: int) -> None: ... @@ -764,6 +835,7 @@ class _SymmetricMemory: device_type: DeviceType, device_idx: int, ) -> bool: ... +<<<<<<< HEAD # Set Symmetric Memory allocation backend. @staticmethod def set_backend(name: str) -> None: ... @@ -771,6 +843,8 @@ class _SymmetricMemory: def get_backend(device: torch.device) -> Optional[str]: ... @staticmethod def get_mempool_allocator(device: torch.device) -> Any: ... +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @property def rank(self) -> int: ... @property @@ -806,12 +880,15 @@ class _SymmetricMemory: channel: int = 0, timeout_ms: int = 0, ) -> None: ... +<<<<<<< HEAD def get_remote_tensor( self, peer: int, sizes: torch.types._size, dtype: torch.dtype, ) -> torch.Tensor: ... +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @staticmethod def memset32( tensor: torch.Tensor, offset: int, val: int, count: int = 1 @@ -836,14 +913,18 @@ class _SymmetricMemory: def signal_pad_size(self) -> int: ... class ProcessGroupXCCL(Backend): +<<<<<<< HEAD class Options(Backend.Options): def __init__(self): ... +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __init__( self, store: Store, rank: int, size: int, +<<<<<<< HEAD options: Options, ) -> None: ... @property @@ -851,3 +932,6 @@ class ProcessGroupXCCL(Backend): def _set_process_group(pg: ProcessGroup) -> None: ... def _current_process_group() -> ProcessGroup: ... +======= + ): ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_C/_dynamo/eval_frame.pyi b/torch/_C/_dynamo/eval_frame.pyi index 117795db5ac3e..02c06a5cd013a 100644 --- a/torch/_C/_dynamo/eval_frame.pyi +++ b/torch/_C/_dynamo/eval_frame.pyi @@ -2,9 +2,18 @@ import enum import types from typing import Optional, overload +<<<<<<< HEAD from torch._dynamo.guards import GuardManagerWrapper from torch._dynamo.types import DynamoCallback, DynamoGuardCompleteHook, DynamoGuardHook from torch._guards import CompileId +======= +from torch._dynamo.types import ( + DynamoCallback, + DynamoGuardCompleteHook, + DynamoGuardHook, + GuardFn, +) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ... def set_skip_guard_eval_unsafe(value: bool) -> bool: ... @@ -22,6 +31,7 @@ def raise_sigtrap() -> None: ... class _CacheEntry: def check_fn(self, *args: object, **kwargs: object) -> bool: ... +<<<<<<< HEAD def update_diff_guard_root_manager(self) -> None: ... code: types.CodeType compile_id: CompileId @@ -36,6 +46,13 @@ class _ExtraState: def invalidate( self, cache_entry: _CacheEntry, guard_manager: GuardManagerWrapper ) -> None: ... +======= + code: types.CodeType + next: _CacheEntry | None + +class _ExtraState: + def invalidate(self, cache_entry: _CacheEntry, guard_manager: object) -> None: ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class _FrameAction(enum.IntEnum): DEFAULT = 0 @@ -61,7 +78,11 @@ class _PyInterpreterFrame: f_globals: dict[str, object] f_builtins: dict[str, object] f_lasti: int +<<<<<<< HEAD f_lineno: int +======= + f_lineo: int +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f_back: types.FrameType # A tuple containing cell objects captured by this frame. closure: tuple[types.CellType] @@ -72,9 +93,15 @@ py_opcode_caches: list[int] def code_framelocals_names(code: types.CodeType) -> tuple[str]: ... def _load_precompile_entry( +<<<<<<< HEAD code: types.CodeType, guard_manager: GuardManagerWrapper, dynamo_code: types.CodeType, ) -> None: ... def _reset_precompile_entries(code: types.CodeType) -> None: ... def _debug_get_precompile_entries(code: types.CodeType) -> list[_PrecompileEntry]: ... +======= + code: types.CodeType, guard_manager: GuardFn, dynamo_code: types.CodeType +) -> None: ... +def _reset_precompile_entries(code: types.CodeType) -> None: ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_C/_dynamo/guards.pyi b/torch/_C/_dynamo/guards.pyi index aa6614504fc23..2589be6109056 100644 --- a/torch/_C/_dynamo/guards.pyi +++ b/torch/_C/_dynamo/guards.pyi @@ -1,3 +1,4 @@ +<<<<<<< HEAD import enum from typing import Any, Callable, Optional from typing_extensions import TypeAlias @@ -9,10 +10,18 @@ import torch # imports GuardManagerType: TypeAlias = enum.Enum +======= +# mypy: allow-untyped-defs +from typing import Any, Callable + +import torch + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class GlobalStateGuard: def check(self) -> bool: ... def reason(self) -> str: ... +<<<<<<< HEAD class LeafGuard: def verbose_code_parts(self) -> list[str]: ... @@ -26,18 +35,33 @@ class GuardDebugInfo: class GuardManager: def check(self, value: Any) -> bool: ... def check_verbose(self, value: Any) -> GuardDebugInfo: ... +======= +class LeafGuard: ... +class GuardDebugInfo: ... + +class GuardManager: + def check(self, value) -> bool: ... + def check_verbose(self, value) -> GuardDebugInfo: ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Accessors def globals_dict_manager( self, f_globals: dict[str, Any], +<<<<<<< HEAD source: str, example_value: Any, guard_manager_enum: GuardManagerType, +======= + source, + example_value, + guard_manager_enum, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> GuardManager: ... def framelocals_manager( self, key: tuple[str, int], +<<<<<<< HEAD source: str, example_value: Any, guard_manager_enum: GuardManagerType, @@ -126,10 +150,23 @@ class GuardManager: source: str, example_value: Any, guard_manager_enum: GuardManagerType, +======= + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def dict_getitem_manager( + self, + key, + source, + example_value, + guard_manager_enum, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> GuardManager: ... def global_weakref_manager( self, global_name: str, +<<<<<<< HEAD source: str, example_value: Any, guard_manager_enum: GuardManagerType, @@ -139,24 +176,48 @@ class GuardManager: source: str, example_value: Any, guard_manager_enum: GuardManagerType, +======= + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def type_manager( + self, + source, + example_value, + guard_manager_enum, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> GuardManager: ... def getattr_manager( self, attr: str, +<<<<<<< HEAD source: str, example_value: Any, guard_manager_enum: GuardManagerType, +======= + source, + example_value, + guard_manager_enum, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> GuardManager: ... def tensor_property_size_manager( self, idx: int, +<<<<<<< HEAD source: str, example_value: Any, guard_manager_enum: GuardManagerType, +======= + source, + example_value, + guard_manager_enum, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> GuardManager: ... def tensor_property_shape_manager( self, idx: int, +<<<<<<< HEAD source: str, example_value: Any, guard_manager_enum: GuardManagerType, @@ -167,10 +228,23 @@ class GuardManager: source: str, example_value: Any, guard_manager_enum: GuardManagerType, +======= + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def tensor_property_storage_offset_manager( + self, + idx: None, + source, + example_value, + guard_manager_enum, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> GuardManager: ... def indexed_manager( self, idx: int, +<<<<<<< HEAD source: str, example_value: Any, guard_manager_enum: GuardManagerType, @@ -330,6 +404,33 @@ class GuardManager: ) -> None: ... def mark_tag_safe(self) -> None: ... def mark_tag_safe_root(self) -> None: ... +======= + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def lambda_manager( + self, + python_lambda, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + + # Leaf guards + def add_lambda_guard(self, user_lambda, verbose_code_parts: list[str]) -> None: ... + def add_id_match_guard(self, id_val, verbose_code_parts: list[str]) -> None: ... + def add_equals_match_guard( + self, + equals_val, + verbose_code_parts: list[str], + ) -> None: ... + def add_global_state_guard(self, verbose_code_parts: list[str]) -> None: ... + def add_torch_function_mode_stack_guard( + self, initial_stack, verbose_code_parts: list[str] + ) -> None: ... + def add_mapping_keys_guard(sef, value, verbose_code_parts: list[str]) -> None: ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class RootGuardManager(GuardManager): def get_epilogue_lambda_guards(self) -> list[LeafGuard]: ... @@ -341,11 +442,15 @@ class RootGuardManager(GuardManager): def clone_manager( self, clone_filter_fn: Callable[[GuardManager], bool] ) -> RootGuardManager: ... +<<<<<<< HEAD def attach_compile_id(self, compile_id: str) -> None: ... +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class DictGuardManager(GuardManager): def get_key_manager( self, +<<<<<<< HEAD index: int, source: str, example_value: Any, @@ -383,16 +488,44 @@ def install_object_aliasing_guard( y: GuardManager, verbose_code_parts: list[str], ) -> None: ... +======= + index, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + def get_value_manager( + self, + index, + source, + example_value, + guard_manager_enum, + ) -> GuardManager: ... + +def install_object_aliasing_guard( + guard_managers: list[GuardManager], + tensor_names: list[str], + verbose_code_parts: list[str], +): ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def install_no_tensor_aliasing_guard( guard_managers: list[GuardManager], tensor_names: list[str], verbose_code_parts: list[str], +<<<<<<< HEAD ) -> None: ... +======= +): ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def install_storage_overlapping_guard( overlapping_guard_managers: list[GuardManager], non_overlapping_guard_managers: list[GuardManager], verbose_code_parts: list[str], +<<<<<<< HEAD ) -> None: ... +======= +): ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def install_symbolic_shape_guard( guard_managers: list[GuardManager], nargs_int: int, @@ -400,7 +533,11 @@ def install_symbolic_shape_guard( py_addr: int, py_addr_keep_alive: Any, verbose_code_parts: list[str], +<<<<<<< HEAD ) -> None: ... +======= +): ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def profile_guard_manager( guard_manager: GuardManager, f_locals: dict[str, Any], @@ -414,22 +551,35 @@ class TensorGuards: dynamic_dims_sizes: list[torch.SymInt | None] | None = None, dynamic_dims_strides: list[torch.SymInt | None] | None = None, ) -> None: ... +<<<<<<< HEAD def check(self, *args: Any) -> bool: ... def check_verbose( self, *args: Any, tensor_check_names: Optional[list[str]] = None ) -> bool | str: ... +======= + def check(self, *args) -> bool: ... + def check_verbose(self, *args, tensor_check_names=None) -> bool | str: ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def assert_size_stride( item: torch.Tensor, size: torch.types._size, stride: torch.types._size, op_name: str | None = None, +<<<<<<< HEAD ) -> None: ... +======= +): ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def assert_alignment( item: torch.Tensor, alignment: int, op_name: str | None = None, +<<<<<<< HEAD ) -> None: ... +======= +): ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def check_obj_id(obj: object, expected: int) -> bool: ... def check_type_id(obj: object, expected: int) -> bool: ... def dict_version(d: dict[Any, Any]) -> int: ... diff --git a/torch/_C/_export/pt2_archive_constants.pyi b/torch/_C/_export/pt2_archive_constants.pyi index ce225f0f1880b..f85a5e1974195 100644 --- a/torch/_C/_export/pt2_archive_constants.pyi +++ b/torch/_C/_export/pt2_archive_constants.pyi @@ -10,10 +10,15 @@ MODELS_FILENAME_FORMAT: str = ... AOTINDUCTOR_DIR: str = ... MTIA_DIR: str = ... WEIGHTS_DIR: str = ... +<<<<<<< HEAD WEIGHTS_CONFIG_FILENAME_FORMAT: str = ... WEIGHT_FILENAME_PREFIX: str = ... CONSTANTS_DIR: str = ... CONSTANTS_CONFIG_FILENAME_FORMAT: str = ... +======= +WEIGHT_FILENAME_PREFIX: str = ... +CONSTANTS_DIR: str = ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TENSOR_CONSTANT_FILENAME_PREFIX: str = ... CUSTOM_OBJ_FILENAME_PREFIX: str = ... SAMPLE_INPUTS_DIR: str = ... diff --git a/torch/_C/_nn.pyi.in b/torch/_C/_nn.pyi.in index 7be3dcff4da67..7d62995bf1005 100644 --- a/torch/_C/_nn.pyi.in +++ b/torch/_C/_nn.pyi.in @@ -67,6 +67,7 @@ def pad_sequence( padding_value: float = 0.0, padding_side: Literal["left", "right"] = "right", ) -> Tensor: ... +<<<<<<< HEAD # Upsample functions used by torch.nn.functional.interpolate def upsample_nearest1d( @@ -135,5 +136,7 @@ def upsample_bicubic2d( align_corners: bool, scale_factors: Sequence[float] | None, ) -> Tensor: ... +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def flatten_dense_tensors(tensors: list[Tensor]) -> Tensor: ... def unflatten_dense_tensors(flat: Tensor, tensors: list[Tensor]) -> list[Tensor]: ... diff --git a/torch/__init__.py b/torch/__init__.py index 0625ad60bfff6..e874bc923c164 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -32,6 +32,7 @@ TypeVar as _TypeVar, Union as _Union, ) +<<<<<<< HEAD from typing_extensions import ParamSpec as _ParamSpec, TypeIs as _TypeIs @@ -40,6 +41,20 @@ # they are likely stale. def _running_with_deploy() -> builtins.bool: return False +======= +from typing_extensions import ParamSpec as _ParamSpec + + +if TYPE_CHECKING: + from .types import Device, IntLikeType + + +# multipy/deploy is setting this import before importing torch, this is the most +# reliable way we have to detect if we're running within deploy. +# https://github.com/pytorch/multipy/blob/d60f34ad38c371e441fe7ffdb77a3c3dda5a5d19/multipy/runtime/interpreter/interpreter_impl.cpp#L134-L137 +def _running_with_deploy() -> builtins.bool: + return sys.modules.get("torch._meta_registrations", None) is object +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._utils import ( @@ -54,12 +69,30 @@ def _running_with_deploy() -> builtins.bool: USE_GLOBAL_DEPS, USE_RTLD_GLOBAL_WITH_LIBTORCH, ) +<<<<<<< HEAD from torch.torch_version import __version__ as __version__ if TYPE_CHECKING: from torch.types import Device, IntLikeType +======= + + +# TODO(torch_deploy) figure out how to freeze version.py in fbcode build +if _running_with_deploy(): + __version__ = "torch-deploy-1.8" + # TODO: Remove this ugly hack when deploy typing extensions are updated to 4.10+ + if not TYPE_CHECKING: + import typing_extensions + + _TypeIs = typing_extensions.TypeGuard + typing_extensions.TypeIs = _TypeIs +else: + from typing_extensions import TypeIs as _TypeIs + + from torch.torch_version import __version__ as __version__ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __all__ = [ "BoolStorage", @@ -193,6 +226,7 @@ def _load_dll_libraries() -> None: if os.path.exists(p) ] +<<<<<<< HEAD if not builtins.any( os.path.exists(os.path.join(p, "nvToolsExt64_1.dll")) for p in dll_paths ): @@ -207,6 +241,8 @@ def _load_dll_libraries() -> None: else: nvtoolsext_dll_path = "" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if cuda_version and builtins.all( not glob.glob(os.path.join(p, "cudart64*.dll")) for p in dll_paths ): @@ -219,9 +255,13 @@ def _load_dll_libraries() -> None: else: cuda_path = "" +<<<<<<< HEAD dll_paths.extend( p for p in (nvtoolsext_dll_path, cuda_path) if os.path.exists(p) ) +======= + dll_paths.extend(p for p in (cuda_path,) if os.path.exists(p)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True) with_load_library_flags = hasattr(kernel32, "AddDllDirectory") @@ -244,7 +284,11 @@ def _load_dll_libraries() -> None: textwrap.dedent( """ Microsoft Visual C++ Redistributable is not installed, this may lead to the DLL load failure. +<<<<<<< HEAD It can be downloaded at https://aka.ms/vs/17/release/vc_redist.x64.exe +======= + It can be downloaded at https://aka.ms/vs/16/release/vc_redist.x64.exe +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ ).strip() ) @@ -283,6 +327,7 @@ def _load_dll_libraries() -> None: def _get_cuda_dep_paths(path: str, lib_folder: str, lib_name: str) -> list[str]: +<<<<<<< HEAD # Libraries can either be in # path/nvidia/lib_folder/lib or # path/nvidia/cuXX/lib (since CUDA 13.0) or @@ -297,12 +342,22 @@ def _get_cuda_dep_paths(path: str, lib_folder: str, lib_name: str) -> list[str]: nvidia_lib_paths += glob.glob( os.path.join(path, "nvidia", f"cu{maj_cuda_version}", "lib", lib_name) ) +======= + # Libraries can either be in path/nvidia/lib_folder/lib or path/lib_folder/lib + nvidia_lib_paths = glob.glob( + os.path.join(path, "nvidia", lib_folder, "lib", lib_name) + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) lib_paths = glob.glob(os.path.join(path, lib_folder, "lib", lib_name)) return nvidia_lib_paths + lib_paths +<<<<<<< HEAD def _preload_cuda_deps(lib_folder: str, lib_name: str, required: bool = True) -> None: # type: ignore[valid-type] +======= +def _preload_cuda_deps(lib_folder: str, lib_name: str) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Preloads cuda deps if they could not be found otherwise.""" # Should only be called on Linux if default path resolution have failed assert platform.system() == "Linux", "Should only be called on Linux" @@ -313,15 +368,25 @@ def _preload_cuda_deps(lib_folder: str, lib_name: str, required: bool = True) -> if candidate_lib_paths: lib_path = candidate_lib_paths[0] break +<<<<<<< HEAD if not lib_path and required: raise ValueError(f"{lib_name} not found in the system path {sys.path}") if lib_path: ctypes.CDLL(lib_path) +======= + if not lib_path: + raise ValueError(f"{lib_name} not found in the system path {sys.path}") + ctypes.CDLL(lib_path) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # See Note [Global dependencies] def _load_global_deps() -> None: +<<<<<<< HEAD if platform.system() == "Windows": +======= + if _running_with_deploy() or platform.system() == "Windows": +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return # Determine the file extension based on the platform @@ -341,6 +406,7 @@ def _load_global_deps() -> None: try: with open("/proc/self/maps") as f: _maps = f.read() +<<<<<<< HEAD # libtorch_global_deps.so always depends in cudart, check if its installed and loaded if "libcudart.so" not in _maps: @@ -348,6 +414,14 @@ def _load_global_deps() -> None: # If all above-mentioned conditions are met, preload nvrtc and nvjitlink _preload_cuda_deps("cuda_nvrtc", "libnvrtc.so.*[0-9]") _preload_cuda_deps("cuda_nvrtc", "libnvrtc-builtins.so.*[0-9]") +======= + # libtorch_global_deps.so always depends in cudart, check if its installed via wheel + if "nvidia/cuda_runtime/lib/libcudart.so" not in _maps: + return + # If all above-mentioned conditions are met, preload nvrtc and nvjitlink + # Please note that order are important for CUDA-11.8 , as nvjitlink does not exist there + _preload_cuda_deps("cuda_nvrtc", "libnvrtc.so.*[0-9]") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _preload_cuda_deps("nvjitlink", "libnvJitLink.so.*[0-9]") except Exception: pass @@ -355,6 +429,11 @@ def _load_global_deps() -> None: except OSError as err: # Can only happen for wheel with cuda libs as PYPI deps # As PyTorch is not purelib, but nvidia-*-cu12 is +<<<<<<< HEAD +======= + from torch.version import cuda as cuda_version + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cuda_libs: dict[str, str] = { "cublas": "libcublas.so.*[0-9]", "cudnn": "libcudnn.so.*[0-9]", @@ -368,9 +447,20 @@ def _load_global_deps() -> None: "cusparselt": "libcusparseLt.so.*[0-9]", "cusolver": "libcusolver.so.*[0-9]", "nccl": "libnccl.so.*[0-9]", +<<<<<<< HEAD "nvshmem": "libnvshmem_host.so.*[0-9]", "cufile": "libcufile.so.*[0-9]", } +======= + } + # cufiile is only available on cuda 12+ + # TODO: Remove once CUDA 11.8 binaries are deprecated + if cuda_version is not None: + t_version = cuda_version.split(".") + t_major = int(t_version[0]) # type: ignore[operator] + if t_major >= 12: + cuda_libs["cufile"] = "libcufile.so.*[0-9]" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) is_cuda_lib_err = [ lib for lib in cuda_libs.values() if lib.split(".")[0] in err.args[0] @@ -379,14 +469,21 @@ def _load_global_deps() -> None: raise err for lib_folder, lib_name in cuda_libs.items(): _preload_cuda_deps(lib_folder, lib_name) +<<<<<<< HEAD # libnvToolsExt is Optional Dependency _preload_cuda_deps("nvtx", "libnvToolsExt.so.*[0-9]", required=False) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL) if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv("TORCH_USE_RTLD_GLOBAL")) and ( +<<<<<<< HEAD platform.system() != "Windows" +======= + _running_with_deploy() or platform.system() != "Windows" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): # Do it the hard way. You might want to load libtorch with RTLD_GLOBAL in a # few circumstances: @@ -1018,10 +1115,17 @@ def sym_fresh_size(expr): of the PyTorch repository rather than the C extensions which are expected in the `torch._C` namespace. This can occur when using the `install` workflow. e.g. +<<<<<<< HEAD $ python -m pip install --no-build-isolation -v . && python -c "import torch" This error can generally be solved using the `develop` workflow $ python -m pip install --no-build-isolation -v -e . && python -c "import torch" # This should succeed +======= + $ python setup.py install && python -c "import torch" + + This error can generally be solved using the `develop` workflow + $ python setup.py develop && python -c "import torch" # This should succeed +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) or by running Python from a different directory. """ ).strip() @@ -1120,7 +1224,11 @@ def is_tensor(obj: _Any, /) -> _TypeIs["torch.Tensor"]: r"""Returns True if `obj` is a PyTorch tensor. Note that this function is simply doing ``isinstance(obj, Tensor)``. +<<<<<<< HEAD Using that ``isinstance`` check is better for type checking with mypy, +======= + Using that ``isinstance`` check is better for typechecking with mypy, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) and more explicit - so it's recommended to use that instead of ``is_tensor``. @@ -2087,7 +2195,11 @@ def _dtype(self): # Shared memory manager needs to know the exact location of manager executable def _manager_path(): +<<<<<<< HEAD if platform.system() == "Windows": +======= + if _running_with_deploy() or platform.system() == "Windows": +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return b"" path = get_file_path("torch", "bin", "torch_shm_manager") prepare_multiprocessing_environment(get_file_path("torch")) @@ -2147,7 +2259,11 @@ def _manager_path(): ) ################################################################################ +<<<<<<< HEAD # Import TorchDynamo's lazy APIs to avoid circular dependencies +======= +# Import TorchDynamo's lazy APIs to avoid circular dependenices +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ################################################################################ # needs to be before from torch.functional import * to avoid circular dependencies @@ -2230,7 +2346,10 @@ def _assert(condition, message): testing as testing, types as types, utils as utils, +<<<<<<< HEAD version as version, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) xpu as xpu, ) from torch.signal import windows as windows @@ -2508,7 +2627,11 @@ def compile( Args: model (Callable or None): Module/function to optimize +<<<<<<< HEAD fullgraph (bool): If False (default), torch.compile attempts to discover compilable regions +======= + fullgraph (bool): If False (default), torch.compile attempts to discover compileable regions +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) in the function that it will optimize. If True, then we require that the entire function be capturable into a single graph. If this is not possible (that is, if there are graph breaks), then this will raise an error. @@ -2693,6 +2816,7 @@ def _register_device_module(device_type, module): # Register MPS specific decomps torch.backends.mps._init() +<<<<<<< HEAD from torch import compiler as compiler @@ -2708,6 +2832,23 @@ def registerOp(cls, op_key, full_schema, op_impl, dispatch_key): cls.ops_table[(op_key, dispatch_key)] = op_impl return cls.ops_table[(op_key, dispatch_key)] +======= +if not _running_with_deploy(): + from torch import compiler as compiler + + class _TritonLibrary: + lib = torch.library.Library("triton", "DEF") + ops_table: dict[tuple[str, str], _Callable] = {} + + @classmethod + def registerOp(cls, op_key, full_schema, op_impl, dispatch_key): + if (op_key, dispatch_key) not in cls.ops_table: + cls.lib.define(full_schema) + cls.lib.impl("triton::" + op_key, op_impl, dispatch_key) + cls.ops_table[(op_key, dispatch_key)] = op_impl + + return cls.ops_table[(op_key, dispatch_key)] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Deprecated attributes diff --git a/torch/_classes.py b/torch/_classes.py index a811c7c30be61..0a9a86ad099d5 100644 --- a/torch/_classes.py +++ b/torch/_classes.py @@ -1,15 +1,28 @@ +<<<<<<< HEAD import types from typing import Any +======= +# mypy: allow-untyped-defs +import types +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch._C class _ClassNamespace(types.ModuleType): +<<<<<<< HEAD def __init__(self, name: str) -> None: super().__init__("torch.classes" + name) self.name = name def __getattr__(self, attr: str) -> Any: +======= + def __init__(self, name): + super().__init__("torch.classes" + name) + self.name = name + + def __getattr__(self, attr): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) proxy = torch._C._get_custom_class_python_wrapper(self.name, attr) if proxy is None: raise RuntimeError(f"Class {self.name}.{attr} not registered!") @@ -22,16 +35,27 @@ class _Classes(types.ModuleType): def __init__(self) -> None: super().__init__("torch.classes") +<<<<<<< HEAD def __getattr__(self, name: str) -> _ClassNamespace: +======= + def __getattr__(self, name): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace = _ClassNamespace(name) setattr(self, name, namespace) return namespace @property +<<<<<<< HEAD def loaded_libraries(self) -> Any: return torch.ops.loaded_libraries def load_library(self, path: str) -> None: +======= + def loaded_libraries(self): + return torch.ops.loaded_libraries + + def load_library(self, path): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Loads a shared library from the given path into the current process. diff --git a/torch/_custom_op/impl.py b/torch/_custom_op/impl.py index 208c18e392a46..ac1eae8e2af1b 100644 --- a/torch/_custom_op/impl.py +++ b/torch/_custom_op/impl.py @@ -648,7 +648,11 @@ def custom_op_from_existing(op): name = op.name().split("::")[-1] schema_str = str(op._schema) # CustomOp expects the schema string without the namespace +<<<<<<< HEAD schema_str = schema_str.rsplit("::", maxsplit=1)[-1] +======= + schema_str = schema_str.split("::")[-1] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) schema = FunctionSchema.parse(schema_str) return CustomOp(lib, ns, schema, name, op, _private_access=True) diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index 8e9796d2f7c1b..b5f6d4ede0216 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -418,7 +418,10 @@ def _core_aten_decompositions_post_autograd() -> dict[ aten.native_dropout_backward, aten.native_group_norm_backward, aten.native_layer_norm_backward, +<<<<<<< HEAD aten._fused_rms_norm_backward, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) aten.new_empty, aten.new_full, aten.new_ones, diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index ba09c6173c5f3..87d6eb5302e00 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -52,7 +52,11 @@ class Reduction(Enum): # This wraps a decomposition and performs various type promotion logic within it, depending on the strategy provided +<<<<<<< HEAD # We're currently reusing ELEMENTWISE_TYPE_PROMOTION_KIND, although some of the usages are on non-elementwise ops +======= +# We're currently re-using ELEMENTWISE_TYPE_PROMOTION_KIND, although some of the usages are on non-elementwise ops +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Will need to validate the non-elementwise uses def type_casts( f: Callable, @@ -814,7 +818,11 @@ def slice_scatter( if start == 0 and end == dim_size and step == 1: return src.clone() +<<<<<<< HEAD indices: list[Optional[Tensor]] = [None] * input.dim() +======= + indices = [None] * input.dim() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) idx = torch.arange(dim_size, device=input.device) indices[dim] = (idx - start) // step @@ -947,7 +955,11 @@ def check_positive(param, param_name, strict=True): ) torch._check( all(c > 0 for c in output_size), +<<<<<<< HEAD lambda: f"Given an input with spatial size {tuple(shape[-2:])}, " +======= + lambda: f"Given an input with spacial size {tuple(shape[-2:])}, " +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f"kernel_size={kernel_size}, dilation={dilation}, " f"padding={padding}, stride={stride}, " "the calculated shape of the array of sliding blocks " @@ -1667,9 +1679,15 @@ def native_layer_norm_backward( N = prod(inner_dims) # type: ignore[arg-type] M = prod(outer_dims) # type: ignore[arg-type] +<<<<<<< HEAD from torch.fx.experimental.symbolic_shapes import statically_known_true if statically_known_true(M == 0) or statically_known_true(N == 0): +======= + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ( input.new_zeros(input_shape) if output_mask[0] else None, input.new_zeros(input_shape[axis:]) if output_mask[1] else None, @@ -1677,7 +1695,10 @@ def native_layer_norm_backward( ) mean = _unsqueeze_to_dim(mean, input_cast.dim()) # type: ignore[union-attr] rstd = _unsqueeze_to_dim(rstd, input_cast.dim()) # type: ignore[union-attr] +<<<<<<< HEAD assert input_cast is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x_hat = (input_cast - mean) * rstd if weight_cast is not None: grad_x_hat = grad_out_cast * weight_cast @@ -1743,6 +1764,7 @@ def native_layer_norm_backward_out( return grad_input +<<<<<<< HEAD @register_decomposition(aten._fused_rms_norm_backward.default) def _fused_rms_norm_backward( grad_out: Tensor, @@ -1818,6 +1840,8 @@ def _fused_rms_norm_backward( ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def native_batch_norm_helper( input: Tensor, weight: Optional[Tensor], @@ -3987,9 +4011,15 @@ def _unsafe_masked_index(x, mask, indices, fill): lambda: "tensors used as masks must be bool tensors", ) +<<<<<<< HEAD from torch.fx.experimental.symbolic_shapes import guard_or_false if guard_or_false(x.numel() == 0): +======= + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if guard_size_oblivious(x.numel() == 0): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) meta_result = torch._meta_registrations.meta_index_Tensor(x, indices) return x.new_full(meta_result.shape, fill) @@ -4121,7 +4151,11 @@ def nll_loss2d_forward( return _nll_loss_forward(self, target, weight, reduction, ignore_index) +<<<<<<< HEAD # These are adapted from aten/src/ATen/native/UpSample.h, which is based on +======= +# These are adapted from aten/src/ATen/native/UpSample.h, wich is based on +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm def _upsample_cubic_convolution1(x: Tensor, A: float) -> Tensor: return ((A + 2) * x - (A + 3)) * x * x + 1 @@ -4461,7 +4495,11 @@ def should_fold(tensor1: torch.Tensor, tensor2: torch.Tensor, is_out: bool) -> b t1, t2 = (tensor1, tensor2) if tensor1.ndim >= tensor2.ndim else (tensor2, tensor1) +<<<<<<< HEAD from torch.fx.experimental.symbolic_shapes import guard_or_false +======= + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not (t1.ndim >= 3 and t2.ndim <= 2): return False @@ -4469,7 +4507,11 @@ def should_fold(tensor1: torch.Tensor, tensor2: torch.Tensor, is_out: bool) -> b return True if tensor1.ndim == 2: return False +<<<<<<< HEAD if guard_or_false(t1.numel() == 0): +======= + if guard_size_oblivious(t1.numel() == 0): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return True t1_shape = t1.shape @@ -4481,7 +4523,11 @@ def should_fold(tensor1: torch.Tensor, tensor2: torch.Tensor, is_out: bool) -> b for size in reversed(t1_shape[1:]): expected_stride.append(size * expected_stride[-1]) return all( +<<<<<<< HEAD guard_or_false(size == 1) or guard_or_false(left == right) +======= + guard_size_oblivious(size == 1) or left == right +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for left, right, size in zip( t1_stride, list(reversed(expected_stride)), t1_shape ) @@ -5075,7 +5121,10 @@ def scaled_dot_product_flash_attention_for_cpu( is_causal=is_causal, dropout_mask=None, scale=scale, +<<<<<<< HEAD enable_gqa=query.size(1) != key.size(1), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # Why this change? # In pre-dispatch export scaled_dot_product_attention is executed via diff --git a/torch/_deploy.py b/torch/_deploy.py new file mode 100644 index 0000000000000..0443a2447d00d --- /dev/null +++ b/torch/_deploy.py @@ -0,0 +1,104 @@ +# mypy: allow-untyped-defs +import io + +import torch +from torch.package import Importer, OrderedImporter, PackageImporter, sys_importer +from torch.package._package_pickler import create_pickler +from torch.package._package_unpickler import PackageUnpickler +from torch.serialization import _maybe_decode_ascii + + +def _save_storages(importer, obj): + serialized_storages = [] + serialized_dtypes = [] + + importer = importer if isinstance(importer, torch.package.PackageImporter) else None + importers: Importer + if importer is not None: + importers = OrderedImporter(importer, sys_importer) + else: + importers = sys_importer + + def persistent_id(obj): + if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage): + if isinstance(obj, torch.storage.TypedStorage): + # TODO: Once we decide to break serialization FC, we can + # remove this case + dtype = obj.dtype + else: + dtype = torch.uint8 + + serialized_storages.append(obj) + serialized_dtypes.append(dtype) + return ("storage", len(serialized_storages) - 1) + + if hasattr(obj, "__reduce_deploy__"): + if _serialized_reduces.get(id(obj)) is None: + _serialized_reduces[id(obj)] = ( + "reduce_deploy", + id(obj), + *obj.__reduce_deploy__(importers), + ) + return _serialized_reduces[id(obj)] + + return None + + # Write the pickle data for `obj` + data_buf = io.BytesIO() + pickler = create_pickler(data_buf, importers) + pickler.persistent_id = persistent_id + pickler.dump(obj) + data_value = data_buf.getvalue() + return ( + data_value, + serialized_storages, + serialized_dtypes, + importer.zip_reader if importer else None, + ) + + +def _load_storages(id, zip_reader, obj_bytes, serialized_storages, serialized_dtypes): + def persistent_load(saved_id): + assert isinstance(saved_id, tuple) + typename = _maybe_decode_ascii(saved_id[0]) + data = saved_id[1:] + + if typename == "storage": + # TODO: Once we decide to break serialization FC, we can + # stop wrapping with TypedStorage + storage = serialized_storages[data[0]] + dtype = serialized_dtypes[data[0]] + return torch.storage.TypedStorage( + wrap_storage=storage.untyped(), dtype=dtype + ) + + if typename == "reduce_deploy": + reduce_id, func, args = data + if reduce_id not in _loaded_reduces: + _loaded_reduces[reduce_id] = func(_raw_packages[zip_reader], *args) + return _loaded_reduces[reduce_id] + + return None + + importer: Importer + if zip_reader is not None: + importer = OrderedImporter(_get_package(zip_reader), sys_importer) + else: + importer = sys_importer + + unpickler = PackageUnpickler(importer, io.BytesIO(obj_bytes)) + unpickler.persistent_load = persistent_load # type: ignore[method-assign] + result = _deploy_objects[id] = unpickler.load() + return result + + +def _get_package(zip_reader): + if zip_reader not in _raw_packages: + _raw_packages[zip_reader] = PackageImporter(zip_reader) + return _raw_packages[zip_reader] + + +_raw_packages: dict = {} +_deploy_objects: dict = {} +_serialized_reduces: dict = {} +_loaded_reduces: dict = {} diff --git a/torch/_dynamo/__init__.py b/torch/_dynamo/__init__.py index 561acf62f785c..935e764fd631a 100644 --- a/torch/_dynamo/__init__.py +++ b/torch/_dynamo/__init__.py @@ -10,6 +10,7 @@ import torch +<<<<<<< HEAD from . import ( aot_compile, config, @@ -18,6 +19,9 @@ functional_export, resume_execution, ) +======= +from . import config, convert_frame, eval_frame, resume_execution +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .backends.registry import list_backends, lookup_backend, register_backend from .callback import callback_handler, on_compile_end, on_compile_start from .code_context import code_context @@ -28,7 +32,10 @@ disable, disallow_in_graph, dont_skip_tracing, +<<<<<<< HEAD error_on_graph_break, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) forbid_in_graph, graph_break, mark_dynamic, @@ -57,6 +64,7 @@ from .mutation_guard import GenerationTracker from .pgo import reset_code_state from .symbolic_convert import TensorifyState +<<<<<<< HEAD from .utils import ( graph_break_reasons, guard_failures, @@ -64,6 +72,9 @@ register_hook_for_recompile_user_context, reset_frame_count, ) +======= +from .utils import graph_break_reasons, guard_failures, orig_code_map, reset_frame_count +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Register polyfill functions @@ -73,6 +84,7 @@ __all__ = [ "allow_in_graph", "assume_constant_result", +<<<<<<< HEAD "config", "disable", "disallow_in_graph", @@ -84,6 +96,13 @@ "is_compiling", "list_backends", "lookup_backend", +======= + "disallow_in_graph", + "dont_skip_tracing", + "forbid_in_graph", + "substitute_in_graph", + "graph_break", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "mark_dynamic", "maybe_mark_dynamic", "mark_static", @@ -91,6 +110,7 @@ "nonstrict_trace", "optimize", "optimize_assert", +<<<<<<< HEAD "OptimizedModule", "patch_dynamo_config", "register_backend", @@ -101,6 +121,23 @@ "set_stance", "skip_frame", "substitute_in_graph", +======= + "patch_dynamo_config", + "skip_frame", + "export", + "explain", + "run", + "replay", + "disable", + "set_stance", + "reset", + "OptimizedModule", + "is_compiling", + "register_backend", + "list_backends", + "lookup_backend", + "config", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] # allowlist this for weights_only load of NJTs diff --git a/torch/_dynamo/_trace_wrapped_higher_order_op.py b/torch/_dynamo/_trace_wrapped_higher_order_op.py index 9b000ee926a1b..c09b76098c7fa 100644 --- a/torch/_dynamo/_trace_wrapped_higher_order_op.py +++ b/torch/_dynamo/_trace_wrapped_higher_order_op.py @@ -49,6 +49,7 @@ __all__ = ["trace_wrapped"] +<<<<<<< HEAD @torch.library.custom_op("flex_lib::zeros_and_scatter", mutates_args=()) # type: ignore[misc] def zeros_and_scatter( shape: list[int], @@ -89,6 +90,49 @@ def _(info, indims, shape, indices, value): # type: ignore[no-untyped-def] value, ) return out, None +======= +if not torch._running_with_deploy(): + # torch.library.custom_op does not work with torch.deploy/multipy # codespell:ignore + + @torch.library.custom_op("flex_lib::zeros_and_scatter", mutates_args=()) # type: ignore[misc] + def zeros_and_scatter( + shape: list[int], + indices: list[Tensor], + vals: Tensor, + ) -> Tensor: + """Custom Op so that we can register a custom lowering for the new_output + scatter in the backwards pass""" + grad = torch.zeros(shape, device=vals.device, dtype=vals.dtype) + return torch.ops.aten.index_put(grad, indices, vals, accumulate=True) + + @zeros_and_scatter.register_fake # type: ignore[misc] + def _( + shape: list[int], + indices: list[Tensor], + vals: Tensor, + ) -> Tensor: + return vals.new_empty(shape) + + @zeros_and_scatter.register_vmap # type: ignore[misc] + def _(info, indims, shape, indices, value): # type: ignore[no-untyped-def] + """The batching rule is special in that it returns a tensor that is not batched""" + indices_indims = indims[1] + expanded_indices = [] + for idx, idx_indim in zip(indices, indices_indims): + # The index is not a being batched, we should unsqueeze and expand to val + if idx_indim is None: + expanded_indices.append(idx.expand(value.shape)) + else: + # the index is being part of the vmap batch, it should be the same size as val + assert idx.shape == value.shape + expanded_indices.append(idx) + + out = torch.ops.flex_lib.zeros_and_scatter( + shape, + expanded_indices, + value, + ) + return out, None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class ModIndex(torch.autograd.Function): @@ -116,11 +160,14 @@ def backward(ctx, gradOut): # type: ignore[no-untyped-def] None, ) +<<<<<<< HEAD @classmethod @torch._export.wrappers.allow_in_pre_dispatch_graph def apply(cls, *args, **kwargs): # type: ignore[no-untyped-def] return super().apply(*args, **kwargs) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mod_index = ModIndex.apply diff --git a/torch/_dynamo/backends/common.py b/torch/_dynamo/backends/common.py index b7604db5429d6..ddfde88737529 100644 --- a/torch/_dynamo/backends/common.py +++ b/torch/_dynamo/backends/common.py @@ -1,3 +1,8 @@ +<<<<<<< HEAD +======= +# mypy: ignore-errors + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ This module provides common utilities and base classes for TorchDynamo backends. @@ -19,9 +24,12 @@ import contextlib import functools import logging +<<<<<<< HEAD from collections.abc import Iterable from typing import Any, Callable from typing_extensions import ParamSpec, TypeVar +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from unittest.mock import patch import torch @@ -37,6 +45,7 @@ log = logging.getLogger(__name__) +<<<<<<< HEAD P = ParamSpec("P") R = TypeVar("R") @@ -49,6 +58,15 @@ def __init__(self, **kwargs: Any) -> None: def __call__( self, gm: torch.fx.GraphModule, example_inputs: Iterable[Any], **kwargs: Any ) -> Callable[..., Any]: +======= + +class AotAutograd: + def __init__(self, **kwargs) -> None: + self.__name__ = "compiler_fn" + self.kwargs = kwargs + + def __call__(self, gm: torch.fx.GraphModule, example_inputs, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if kwargs: log.warning("aot_autograd-based backend ignoring extra kwargs %s", kwargs) @@ -72,8 +90,13 @@ def __call__( counters["aot_autograd"]["not_ok"] += 1 return gm +<<<<<<< HEAD def wrap_bw_compiler(bw_compiler_fn: Callable[P, R]) -> Callable[..., R]: def _wrapped_bw_compiler(*args: P.args, **kwargs: P.kwargs) -> R: +======= + def wrap_bw_compiler(bw_compiler_fn): + def _wrapped_bw_compiler(*args, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Note [Wrapping bw_compiler in disable] # The two disables here: # - stop TorchDynamo from trying to compile the bw_compiler function itself @@ -81,7 +104,11 @@ def _wrapped_bw_compiler(*args: P.args, **kwargs: P.kwargs) -> R: return disable( disable( bw_compiler_fn, reason="do not trace backward compiler function" +<<<<<<< HEAD )(*args, **kwargs), # type: ignore[misc] +======= + )(*args, **kwargs), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) reason="do not trace generated backwards pass", ) @@ -105,9 +132,13 @@ def _wrapped_bw_compiler(*args: P.args, **kwargs: P.kwargs) -> R: # debug asserts slow down compile time noticeably, # So only default them on when the aot_eager backend is used. if self.kwargs.get("fw_compiler", None) == nop: +<<<<<<< HEAD patch_config: contextlib.AbstractContextManager[Any] = patch( "functorch.compile.config.debug_assert", True ) +======= + patch_config = patch("functorch.compile.config.debug_assert", True) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: patch_config = contextlib.nullcontext() @@ -124,11 +155,19 @@ def _wrapped_bw_compiler(*args: P.args, **kwargs: P.kwargs) -> R: raise +<<<<<<< HEAD def aot_autograd(**kwargs: Any) -> AotAutograd: return AotAutograd(**kwargs) def mem_efficient_fusion_kwargs(use_decomps: bool) -> dict[str, Any]: +======= +def aot_autograd(**kwargs) -> AotAutograd: + return AotAutograd(**kwargs) + + +def mem_efficient_fusion_kwargs(use_decomps): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from functorch.compile import ( default_decompositions, min_cut_rematerialization_partition, @@ -148,21 +187,33 @@ def mem_efficient_fusion_kwargs(use_decomps: bool) -> dict[str, Any]: return kwargs +<<<<<<< HEAD def fake_tensor_unsupported(fn: Callable[[Any, list[Any], Any], R]) -> Any: +======= +def fake_tensor_unsupported(fn): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Decorator for backends that need real inputs. We swap out fake tensors for zero tensors. """ @functools.wraps(fn) +<<<<<<< HEAD def wrapper(model: Any, inputs: Any, **kwargs: Any) -> Any: with _disable_current_modes(): inputs = list(map(defake, inputs)) return fn(model, inputs, **kwargs) # type: ignore[call-arg] +======= + def wrapper(model, inputs, **kwargs): + with _disable_current_modes(): + inputs = list(map(defake, inputs)) + return fn(model, inputs, **kwargs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return wrapper +<<<<<<< HEAD def device_from_inputs(example_inputs: Iterable[Any]) -> torch.device: for x in example_inputs: if hasattr(x, "device"): @@ -175,3 +226,15 @@ def dtype_from_inputs(example_inputs: Iterable[Any]) -> torch.dtype: if hasattr(x, "dtype"): return x.dtype return torch.float32 # Default fallback +======= +def device_from_inputs(example_inputs) -> torch.device: + for x in example_inputs: + if hasattr(x, "device"): + return x.device + + +def dtype_from_inputs(example_inputs) -> torch.dtype: + for x in example_inputs: + if hasattr(x, "dtype"): + return x.dtype +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_dynamo/backends/cudagraphs.py b/torch/_dynamo/backends/cudagraphs.py index f8599d393833e..2129d11ff6ca3 100644 --- a/torch/_dynamo/backends/cudagraphs.py +++ b/torch/_dynamo/backends/cudagraphs.py @@ -1,3 +1,8 @@ +<<<<<<< HEAD +======= +# mypy: ignore-errors + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ This module implements CUDA graphs support for TorchDynamo backends. @@ -23,11 +28,17 @@ import functools from collections import defaultdict +<<<<<<< HEAD from collections.abc import Sequence from typing import Any, Callable, Optional import torch import torch.fx +======= +from typing import Optional + +import torch +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._dynamo import config from torch._dynamo.backends.common import aot_autograd from torch._dynamo.backends.debugging import boxed_nop @@ -51,8 +62,13 @@ from .registry import register_backend +<<<<<<< HEAD def find_input_mutations(g: torch.fx.Graph) -> set[int]: def meta_fk(meta: dict[str, Any]) -> Any: +======= +def find_input_mutations(g): + def meta_fk(meta): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return meta["val"] if "val" in meta else meta["fake_result"] inputs = defaultdict(set) @@ -90,9 +106,13 @@ def meta_fk(meta: dict[str, Any]) -> Any: return mutated_inputs +<<<<<<< HEAD def get_device_node_mapping( gm: torch.fx.GraphModule, ) -> dict[torch.device, torch.fx.Node]: +======= +def get_device_node_mapping(gm: torch.fx.GraphModule): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device_node_mapping: dict[torch.device, torch.fx.Node] = {} for n in gm.graph.nodes: t = n.meta.get("val", None) @@ -102,7 +122,11 @@ def get_device_node_mapping( def check_for_mutation_ignore_cuda_graph_managed_tensor( +<<<<<<< HEAD aot_model: torch.fx.GraphModule, num_fixed: int +======= + aot_model: torch.fx.GraphModule, num_fixed +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Optional[str]: mutation_indices = find_input_mutations(aot_model.graph) - set(range(num_fixed)) if not mutation_indices: @@ -112,7 +136,11 @@ def check_for_mutation_ignore_cuda_graph_managed_tensor( return get_mutation_stack_trace(placeholders, mutation_indices) +<<<<<<< HEAD def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed: int) -> Optional[str]: +======= +def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed) -> Optional[str]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not config.cudagraph_backend_support_input_mutation: if mut_skip := check_for_mutation_ignore_cuda_graph_managed_tensor( aot_model, num_fixed @@ -130,12 +158,17 @@ def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed: int) -> Optional[ return None +<<<<<<< HEAD def get_device_index(gm: torch.fx.GraphModule) -> int: +======= +def get_device_index(gm) -> int: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device = next(iter(get_device_node_mapping(gm))) assert device.type == "cuda" return device.index +<<<<<<< HEAD def get_stack_traces(gm: torch.fx.GraphModule) -> list[Optional[str]]: output = output_node(gm) assert len(output.args) == 1 @@ -149,16 +182,32 @@ def get_stack_traces(gm: torch.fx.GraphModule) -> list[Optional[str]]: def cudagraphs(dynamo_model: torch.fx.GraphModule, dynamo_inputs: Sequence[Any]) -> Any: +======= +def get_stack_traces(gm) -> list[Optional[str]]: + output = output_node(gm) + assert len(output.args) == 1 + return [ + (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None) + for arg in output.args[0] + ] + + +def cudagraphs(dynamo_model, dynamo_inputs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.cudagraph_trees import cudagraphify_impl do_cudagraphs = BoxedBool(True) boxed_device_index = BoxedDeviceIndex(None) +<<<<<<< HEAD def forward_cudagraphs( aot_model: torch.fx.GraphModule, aot_inputs: list[Any], is_inference: bool = False, ) -> Any: +======= + def forward_cudagraphs(aot_model, aot_inputs, is_inference=False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) interp = boxed_nop(aot_model, aot_inputs) fixed = num_fw_fixed_arguments(len(dynamo_inputs), len(aot_inputs)) if skip_msg := check_for_skip(aot_model, fixed): @@ -175,17 +224,28 @@ def forward_cudagraphs( range(fixed), device_index=boxed_device_index.value, is_backward=False, +<<<<<<< HEAD is_inference=False, # Q: should forward is_inference here? +======= + is_inference=False, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) stack_traces=get_stack_traces(aot_model), placeholders=get_placeholder_info(aot_model.graph), mutated_input_idxs=find_input_mutations(aot_model.graph), ) +<<<<<<< HEAD out._boxed_call = True # type: ignore[attr-defined] return out def backward_cudagraphs( aot_model: torch.fx.GraphModule, aot_inputs: list[Any] ) -> Any: +======= + out._boxed_call = True + return out + + def backward_cudagraphs(aot_model, aot_inputs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) interp = boxed_nop(aot_model, aot_inputs) if not do_cudagraphs: return aot_model @@ -193,6 +253,7 @@ def backward_cudagraphs( fixed = count_tangents(aot_model) if skip_msg := check_for_skip(aot_model, fixed): log_cudagraph_skip_and_bump_counter( +<<<<<<< HEAD f"skipping cudagraphs due to {skip_msg}" ) @@ -210,6 +271,22 @@ def fn(inputs: list[Any]) -> Any: return aot_model(inputs) fn._boxed_call = True # type: ignore[attr-defined] +======= + "skipping cudagraphs due to %s", skip_msg + ) + + # See [Backward Generation Handling] + manager = torch._inductor.cudagraph_trees.get_manager( + boxed_device_index.value, create_if_none_exists=False + ) + assert manager is not None + + def fn(inputs): + manager.set_to_running_backward() + return aot_model(inputs) + + fn._boxed_call = True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return fn out = cudagraphify_impl( @@ -223,7 +300,11 @@ def fn(inputs: list[Any]) -> Any: placeholders=get_placeholder_info(aot_model.graph), mutated_input_idxs=find_input_mutations(aot_model.graph), ) +<<<<<<< HEAD out._boxed_call = True # type: ignore[attr-defined] +======= + out._boxed_call = True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return out aot_cudagraphs = aot_autograd( @@ -239,13 +320,21 @@ class CudagraphsBackend: compiler_name = "cudagraphs" @staticmethod +<<<<<<< HEAD def reset() -> None: +======= + def reset(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.cudagraph_trees import reset_cudagraph_trees reset_cudagraph_trees() @staticmethod +<<<<<<< HEAD def __call__(model: torch.fx.GraphModule, inputs: Sequence[Any]) -> Any: +======= + def __call__(model, inputs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return cudagraphs(model, inputs) @@ -254,12 +343,16 @@ def __call__(model: torch.fx.GraphModule, inputs: Sequence[Any]) -> Any: register_backend(name="cudagraphs", compiler_fn=CudagraphsBackend()) +<<<<<<< HEAD def cudagraphs_inner( model: Callable[..., Any], inputs: Sequence[Any], copy_outputs: bool = True, copy_inputs: bool = True, ) -> Callable[..., Sequence[Any]]: +======= +def cudagraphs_inner(model, inputs, copy_outputs=True, copy_inputs=True): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """This isn't registered as a backend, but is used in some benchmarks""" assert isinstance(inputs, (list, tuple)) if copy_inputs: @@ -284,7 +377,11 @@ def cudagraphs_inner( if not isinstance(static_outputs, (list, tuple)): static_outputs = (static_outputs,) +<<<<<<< HEAD def run(*new_inputs: Any) -> Sequence[Any]: +======= + def run(*new_inputs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert len(static_inputs) == len(new_inputs) if copy_inputs: for dst, src in zip(static_inputs, new_inputs): diff --git a/torch/_dynamo/backends/debugging.py b/torch/_dynamo/backends/debugging.py index 32fc72cfa52a3..f1e62119db1ce 100644 --- a/torch/_dynamo/backends/debugging.py +++ b/torch/_dynamo/backends/debugging.py @@ -1,3 +1,8 @@ +<<<<<<< HEAD +======= +# mypy: ignore-errors + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ This module provides debugging backends for TorchDynamo to help diagnose and troubleshoot compilation and execution issues. It includes: @@ -26,37 +31,54 @@ import dataclasses import functools import logging +<<<<<<< HEAD from collections.abc import Iterable from importlib import import_module from typing import Any, Callable, Optional, TYPE_CHECKING, Union +======= +from importlib import import_module +from typing import Any, Optional +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch from functorch.compile import min_cut_rematerialization_partition from torch import _guards +<<<<<<< HEAD from torch._dynamo.output_graph import GraphCompileReason +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._functorch import config as functorch_config from torch._functorch.compilers import ts_compile from .common import aot_autograd +<<<<<<< HEAD from .registry import CompiledFn, CompilerFn, register_debug_backend as register_backend if TYPE_CHECKING: from torch.fx.node import Target +======= +from .registry import register_debug_backend as register_backend +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) log = logging.getLogger(__name__) @register_backend +<<<<<<< HEAD def eager( gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any ) -> Callable[..., Any]: +======= +def eager(gm, fake_tensor_inputs, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if kwargs: log.warning("eager backend ignoring extra kwargs %s", kwargs) return gm.forward +<<<<<<< HEAD def make_eager_backend_with_torch_function_mode( mode: torch.overrides.TorchFunctionMode, ) -> Callable[..., Any]: @@ -66,14 +88,25 @@ def make_eager_backend_with_torch_function_mode( def make_eager_backend_with_torch_function_modes( modes: Iterable[torch.overrides.TorchFunctionMode], ) -> Callable[..., Any]: +======= +def make_eager_backend_with_torch_function_mode(mode): + return make_eager_backend_with_torch_function_modes([mode]) + + +def make_eager_backend_with_torch_function_modes(modes): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Used to trace HOPs (cond and while) for eager execution, the metadata TF mode mutates vars outside of the scope of the HOP, and we can't have graph breaks in the HOP, so we need to externally run this mode and not trace it.""" from contextlib import ExitStack +<<<<<<< HEAD def fn( gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any ) -> Callable[..., Any]: +======= + def fn(gm, fake_tensor_inputs, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) stack = ExitStack() for mode in modes: stack.enter_context(mode) @@ -86,15 +119,23 @@ def fn( @register_backend +<<<<<<< HEAD def eager_noexcept( gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any ) -> Callable[..., Any]: +======= +def eager_noexcept(gm, fake_tensor_inputs, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if kwargs: log.warning("eager_noexcept backend ignoring extra kwargs %s", kwargs) # This backend is intended to check that dynamo-generated GraphModules # do not cause errors. +<<<<<<< HEAD def inner(*args: Any) -> Any: +======= + def inner(*args): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: return gm(*args) except Exception as e: @@ -106,15 +147,23 @@ def inner(*args: Any) -> Any: @register_backend +<<<<<<< HEAD def pre_dispatch_eager( gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any ) -> torch.fx.GraphModule: +======= +def pre_dispatch_eager(gm, fake_tensor_inputs, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if kwargs: log.warning("pre_dispatch_eager backend ignoring extra kwargs %s", kwargs) from torch.fx.experimental.proxy_tensor import make_fx +<<<<<<< HEAD def runnable_gm(*args: Any) -> Any: +======= + def runnable_gm(*args): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return torch.fx.Interpreter(gm).run(*args) pre_dispatch_gm = make_fx(runnable_gm, pre_dispatch=True)(*fake_tensor_inputs) @@ -124,9 +173,13 @@ def runnable_gm(*args: Any) -> Any: @register_backend +<<<<<<< HEAD def eager_debug( gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any ) -> Callable[..., Any]: +======= +def eager_debug(gm, fake_tensor_inputs, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if kwargs: log.warning("eager_debug backend ignoring extra kwargs %s", kwargs) @@ -135,21 +188,31 @@ def eager_debug( # We could add more debugging bits here. # Right now, this backend can be used to check for and error on # custom dispatcher ops that have incorrect schemas. +<<<<<<< HEAD def inner(*args: Any) -> Any: +======= + def inner(*args): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with SchemaCheckMode(): return torch.fx.Interpreter(gm).run(*args) return inner +<<<<<<< HEAD @register_backend(name="ts") # type: ignore[misc] def torchscript( gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor] ) -> torch.jit.ScriptModule: +======= +@register_backend(name="ts") +def torchscript(gm, fake_tensor_inputs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return torch.jit.script(gm) # used boxed call to discard inputs when they are no longer needed +<<<<<<< HEAD def boxed_nop( fx_g: torch.fx.GraphModule, example_inputs: list[torch.Tensor] ) -> Callable[..., Any]: @@ -184,6 +247,31 @@ def run(args: Any) -> Any: return torch.fx.Interpreter(fx_g).boxed_run(args) run._boxed_call = True # type: ignore[attr-defined] +======= +def boxed_nop(fx_g, example_inputs): + def run(args): + return torch.fx.Interpreter(fx_g).boxed_run(args) + + run._boxed_call = True + return run + + +def boxed_nop_with_mode(fx_g, example_inputs, *, mode): + def run(args): + with mode: + return torch.fx.Interpreter(fx_g).boxed_run(args) + + run._boxed_call = True + return run + + +def fake_crossref_boxed_nop(fx_g, example_inputs, ignore_op_fn=None): + def run(args): + with torch._subclasses.CrossRefFakeMode(ignore_op_fn): + return torch.fx.Interpreter(fx_g).boxed_run(args) + + run._boxed_call = True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return run @@ -191,9 +279,13 @@ def ignore_builtins(op: torch._ops.OpOverload) -> bool: return op.namespace in ("aten", "prims", "prim") +<<<<<<< HEAD def get_nop_func() -> Callable[ [torch.fx.GraphModule, list[torch.Tensor]], Callable[..., Any] ]: +======= +def get_nop_func(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not torch._functorch.config.fake_tensor_crossref: return boxed_nop elif torch._functorch.config.fake_tensor_crossref == "all": @@ -206,12 +298,21 @@ def get_nop_func() -> Callable[ # Useful for debugging purpose # aot_eager uses AOT Autograd backend with nop compiler. It is helpful in debugging. def aot_eager( +<<<<<<< HEAD gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], fw_compiler: Optional[Callable[..., Any]] = None, bw_compiler: Optional[Callable[..., Any]] = None, **kwargs: Any, ) -> Callable[..., Any]: +======= + gm, + fake_tensor_inputs, + fw_compiler=None, + bw_compiler=None, + **kwargs, +): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return aot_autograd( fw_compiler=fw_compiler or boxed_nop, bw_compiler=bw_compiler or boxed_nop, @@ -234,9 +335,13 @@ def aot_eager( # inductor problems. # aot_eager_decomp_partition just replaces the inductor compiler with nop to help # isolate inductor vs aot_eager errors +<<<<<<< HEAD def aot_eager_decomp_partition( gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any ) -> Callable[..., Any]: +======= +def aot_eager_decomp_partition(gm, fake_tensor_inputs, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if kwargs: log.warning( "aot_eager_decomp_partition backend ignoring extra kwargs %s", kwargs @@ -248,7 +353,11 @@ def aot_eager_decomp_partition( if bisect_changes := CompilerBisector.get_config_change( "aot_eager_decomp_partition" ): +<<<<<<< HEAD config_patches.update(bisect_changes) # type: ignore[arg-type] +======= + config_patches.update(bisect_changes) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with functorch_config.patch(config_patches): return aot_autograd( @@ -272,12 +381,16 @@ def aot_eager_decomp_partition( # aot_eager_decomp_partition_with_mode is similar as aot_eager_decomp_partition, # except that it takes a TorchDispatchMode mode and run the fw/bw in the mode +<<<<<<< HEAD def aot_eager_decomp_partition_with_mode( gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], mode: Any, **kwarg: Any, ) -> Callable[..., Any]: +======= +def aot_eager_decomp_partition_with_mode(gm, fake_tensor_inputs, mode, **kwarg): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return aot_autograd( # these are taken from memory_efficient_fusion() fw_compiler=functools.partial(boxed_nop_with_mode, mode=mode), @@ -294,6 +407,7 @@ def aot_eager_decomp_partition_with_mode( register_backend( name="aot_eager_decomp_partition_with_mode", +<<<<<<< HEAD compiler_fn=aot_eager_decomp_partition_with_mode, # type: ignore[arg-type] ) @@ -301,6 +415,13 @@ def aot_eager_decomp_partition_with_mode( def aot_eager_decomp_partition_crossref( gm: torch.fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], **kwargs: Any ) -> Callable[..., Any]: +======= + compiler_fn=aot_eager_decomp_partition_with_mode, +) + + +def aot_eager_decomp_partition_crossref(gm, fake_tensor_inputs, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # if the config is set, respect it, otherwise only test custom_ops. # custom_op bad metas always manifest as an error whereas aten will only sometimes. # by default, use the less noisy option @@ -338,9 +459,13 @@ class TestingOnlyCompileError(Exception): @register_backend +<<<<<<< HEAD def relu_compile_error_TESTING_ONLY( gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor] ) -> torch.fx.GraphModule: +======= +def relu_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for node in gm.graph.nodes: if node.target == torch.relu: raise ReluCompileError @@ -348,9 +473,13 @@ def relu_compile_error_TESTING_ONLY( @register_backend +<<<<<<< HEAD def relu_runtime_error_TESTING_ONLY( gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor] ) -> torch.fx.GraphModule: +======= +def relu_runtime_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for node in gm.graph.nodes: if node.target == torch.relu: node.target = torch._assert @@ -360,9 +489,13 @@ def relu_runtime_error_TESTING_ONLY( @register_backend +<<<<<<< HEAD def relu_accuracy_error_TESTING_ONLY( gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor] ) -> torch.fx.GraphModule: +======= +def relu_accuracy_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for node in gm.graph.nodes: if node.target == torch.relu: node.target = torch.add @@ -373,9 +506,13 @@ def relu_accuracy_error_TESTING_ONLY( @register_backend +<<<<<<< HEAD def non_leaf_compile_error_TESTING_ONLY( gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor] ) -> torch.fx.GraphModule: +======= +def non_leaf_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Require at least one non-trivial thing in the graph, # see https://github.com/pytorch/pytorch/issues/102898 for node in gm.graph.nodes: @@ -399,9 +536,17 @@ class ExplainOutput: graphs: list[torch.fx.GraphModule] graph_count: int graph_break_count: int +<<<<<<< HEAD break_reasons: list[GraphCompileReason] op_count: int ops_per_graph: Optional[list[list["Target"]]] = None +======= + break_reasons: list[ + Any + ] # Type is GraphCompileReason but doesn't matter for this purpose + op_count: int + ops_per_graph: Optional[list[torch.fx.Node]] = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out_guards: Optional[list[_guards.Guard]] = None compile_times: Optional[str] = None @@ -437,6 +582,7 @@ def __str__(self) -> str: def _explain_graph_detail( +<<<<<<< HEAD gm: torch.fx.GraphModule, graphs: list[torch.fx.GraphModule], op_count: int, @@ -449,6 +595,10 @@ def _explain_graph_detail( list[list["Target"]], list[GraphCompileReason], ]: +======= + gm: torch.fx.GraphModule, graphs, op_count, ops_per_graph, break_reasons +): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ This function is a utility which processes a torch.fx.GraphModule and accumulates information about its ops, graph breaks, and other details. It @@ -470,8 +620,13 @@ def _explain_graph_detail( ops = [node.target for node in gm.graph.nodes if node.op == "call_function"] op_count += len(ops) ops_per_graph.append(ops) +<<<<<<< HEAD if gm.compile_subgraph_reason.graph_break: # type: ignore[union-attr] break_reasons.append(gm.compile_subgraph_reason) # type: ignore[arg-type] +======= + if gm.compile_subgraph_reason.graph_break: + break_reasons.append(gm.compile_subgraph_reason) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return gm, graphs, op_count, ops_per_graph, break_reasons @@ -501,6 +656,7 @@ def fn(x): print(eb.output()) """ +<<<<<<< HEAD def __init__(self, backend: Union[CompilerFn, str]) -> None: from .registry import lookup_backend @@ -515,6 +671,19 @@ def __call__( ops_per_graph: list[list[Target]] = [] gm, self.graphs, self.op_count, _, self.break_reasons = _explain_graph_detail( gm, self.graphs, self.op_count, ops_per_graph, self.break_reasons +======= + def __init__(self, backend) -> None: + from .registry import lookup_backend + + self.backend = lookup_backend(backend) + self.graphs = [] + self.op_count = 0 + self.break_reasons = [] + + def __call__(self, gm: torch.fx.GraphModule, example_inputs): + gm, self.graphs, self.op_count, _, self.break_reasons = _explain_graph_detail( + gm, self.graphs, self.op_count, [], self.break_reasons +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return self.backend(gm, example_inputs) diff --git a/torch/_dynamo/backends/distributed.py b/torch/_dynamo/backends/distributed.py index b282a62188163..ab193d392265e 100644 --- a/torch/_dynamo/backends/distributed.py +++ b/torch/_dynamo/backends/distributed.py @@ -1,3 +1,8 @@ +<<<<<<< HEAD +======= +# mypy: ignore-errors + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ This module implements distributed training optimizations for TorchDynamo backends. @@ -19,22 +24,32 @@ import logging import traceback from dataclasses import dataclass, field +<<<<<<< HEAD from typing import Any, Callable, Optional, TYPE_CHECKING +======= +from typing import Any, Optional +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from unittest import mock import torch from torch import fx +<<<<<<< HEAD from torch._dynamo.backends.registry import CompiledFn, CompilerFn +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._dynamo.output_graph import GraphCompileReason from torch._dynamo.utils import deepcopy_to_fake_tensor, detect_fake_mode from torch._logging import trace_structured from torch.fx.node import Node +<<<<<<< HEAD if TYPE_CHECKING: from torch._functorch._aot_autograd.schemas import ViewAndMutationMeta +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Regular log messages should go through 'log'. # ddp_graph_log is a separate artifact logger reserved for dumping graphs. # See docs/source/logging.rst for more info. @@ -42,7 +57,11 @@ ddp_graph_log = torch._logging.getArtifactLogger(__name__, "ddp_graphs") +<<<<<<< HEAD def args_str(args: Any) -> str: +======= +def args_str(args): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # a debug helper if torch.is_tensor(args): return f"T[{args.shape}]" @@ -61,7 +80,11 @@ class Bucket: nodes: list[fx.Node] = field(default_factory=list) # param_ids is just used for unit testing +<<<<<<< HEAD param_ids: list[int] = field(default_factory=list) +======= + param_ids: list = field(default_factory=list) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # keep track of any buckets that were extended for logging purposes opcount_increased_to_capture_external_output: int = 0 @@ -81,9 +104,15 @@ def bucket_has_external_output(bucket: Bucket) -> bool: return False +<<<<<<< HEAD def pretty_print_buckets(buckets: list[Bucket], bucket_bytes_cap: int) -> None: headers = ("Index", "Size (b)", "Param Names") rows: list[tuple[Optional[int], Optional[int], str]] = [] +======= +def pretty_print_buckets(buckets: list[Bucket], bucket_bytes_cap: int): + headers = ("Index", "Size (b)", "Param Names") + rows = [] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) extended_buckets = [] for idx, bucket in enumerate(reversed(buckets)): if len(bucket.params) > 0: @@ -139,7 +168,11 @@ def pretty_print_buckets(buckets: list[Bucket], bucket_bytes_cap: int) -> None: log.debug("DDPOptimizer captured no parameters and did not split this graph.") +<<<<<<< HEAD def has_higher_order_op(gm: fx.GraphModule) -> bool: +======= +def has_higher_order_op(gm): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Check if there is a higher order op in the graph for node in gm.graph.nodes: if node.op == "get_attr": @@ -149,7 +182,11 @@ def has_higher_order_op(gm: fx.GraphModule) -> bool: return False +<<<<<<< HEAD def propagate_metadata(orig_gm: fx.GraphModule, split_gm: fx.GraphModule) -> None: +======= +def propagate_metadata(orig_gm, split_gm) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for name, module in split_gm.named_modules(): if "." not in name and len(name): # TODO: add split id to CompileId: https://github.com/pytorch/tlparse/pull/83/files#r1880649384 @@ -157,7 +194,11 @@ def propagate_metadata(orig_gm: fx.GraphModule, split_gm: fx.GraphModule) -> Non module._param_name_to_source = orig_gm._param_name_to_source +<<<<<<< HEAD def propagate_dynamo_source(orig_gm: fx.GraphModule, split_gm: fx.GraphModule) -> None: +======= +def propagate_dynamo_source(orig_gm, split_gm) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) name_to_dynamo_source = {} for node in orig_gm.graph.find_nodes(op="placeholder"): name_to_dynamo_source[node.name] = node._dynamo_source @@ -169,6 +210,7 @@ def propagate_dynamo_source(orig_gm: fx.GraphModule, split_gm: fx.GraphModule) - node._dynamo_source = name_to_dynamo_source.get(node.name, None) +<<<<<<< HEAD class DDPOptimizerContext: def __init__(self) -> None: self.curr_bucket: int = -1 @@ -194,6 +236,16 @@ def __init__( def compile_submod( self, input_mod: fx.GraphModule, args: list[torch.Tensor], kwargs: Any ) -> Any: +======= +# compile each of the partitioned submodules using the user-provided compiler +class SubmodCompiler(torch.fx.interpreter.Interpreter): + def __init__(self, module, compiler, fake_mode) -> None: + super().__init__(module) + self.compiler = compiler + self.fake_mode = fake_mode + + def compile_submod(self, input_mod, args, kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Compile the submodule, using a wrapper to make sure its output is always a tuple, @@ -202,14 +254,22 @@ def compile_submod( assert len(kwargs) == 0, "We assume only args for these modules" class WrapperModule(torch.nn.Module): +<<<<<<< HEAD def __init__( self, submod: Callable[..., Any], unwrap_singleton_tuple: bool ) -> None: +======= + def __init__(self, submod, unwrap_singleton_tuple) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__() self.submod = submod self.unwrap_singleton_tuple = unwrap_singleton_tuple +<<<<<<< HEAD def forward(self, *args: Any) -> Any: +======= + def forward(self, *args): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x = self.submod(*args) # TODO(whc) # for some reason the isinstance check is necessary if I split one node per submod @@ -227,12 +287,20 @@ def forward(self, *args: Any) -> Any: sn.args = (sn.args,) input_mod.recompile() +<<<<<<< HEAD input_mod.compile_subgraph_reason = GraphCompileReason( # type: ignore[assignment] +======= + input_mod.compile_subgraph_reason = GraphCompileReason( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "DDPOptimizer intentional graph-break (See Note [DDPOptimizer])." " Set `torch._dynamo.config.optimize_ddp = False` to disable.", [ # it's close to useless to get a real stacktrace here, and quite verbose. +<<<<<<< HEAD traceback.FrameSummary(__file__, 0, "DDPOptimizer"), +======= + traceback.FrameSummary(__file__, 0, DDPOptimizer), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ], ) @@ -279,7 +347,11 @@ def run_node(self, n: Node) -> Any: assert isinstance(kwargs, dict) if n.op == "call_module": +<<<<<<< HEAD real_mod = self.fetch_attr(str(n.target)) +======= + real_mod = self.fetch_attr(n.target) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.fake_mode: curr_submod = deepcopy_to_fake_tensor(real_mod, self.fake_mode) else: @@ -309,10 +381,17 @@ class FakeifyFirstAOTInvocationGuard: def __init__(self) -> None: self.tc = torch._guards.TracingContext.try_get() assert self.tc +<<<<<<< HEAD self.tc.fakify_first_call = True def __del__(self) -> None: self.tc.fakify_first_call = False # type: ignore[union-attr] +======= + torch._guards.TracingContext.try_get().fakify_first_call = True + + def __del__(self) -> None: + self.tc.fakify_first_call = False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # For aot_eager and other backends, tracing context is not set has_tracing_context = torch._guards.TracingContext.try_get() is not None @@ -330,9 +409,15 @@ def __del__(self) -> None: # We update the original (outer) graph with a call into the compiled module # instead of the uncompiled one. +<<<<<<< HEAD self.module.delete_submodule(n.target) # type: ignore[operator] n.target = "compiled_" + n.target # type: ignore[operator] self.module.add_submodule(n.target, compiled_submod_real) # type: ignore[operator] +======= + self.module.delete_submodule(n.target) + n.target = "compiled_" + n.target + self.module.add_submodule(n.target, compiled_submod_real) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Finally, we have to produce inputs for use compiling the next submodule, # and these need to be FakeTensors, so we execute the module under fake_mode @@ -342,6 +427,7 @@ def __del__(self) -> None: mock.patch.object(self.fake_mode, "allow_non_fake_inputs", True), ): if has_tracing_context and invoked_aot_autograd: +<<<<<<< HEAD tracing_ctx = torch._guards.TracingContext.try_get() assert tracing_ctx is not None # DDPOptimizer maintains 1 dynamo graph -> N AOT graphs @@ -352,6 +438,8 @@ def __del__(self) -> None: ddp_ctx.curr_bucket += 1 ddp_ctx.metadata_per_bucket.append(tracing_ctx.fw_metadata) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out = compiled_submod_real(*new_args, **kwargs) # output should be fake or subclass assert all( @@ -430,7 +518,11 @@ class DDPOptimizer: def __init__( self, bucket_bytes_cap: int, +<<<<<<< HEAD backend_compile_fn: CompilerFn, +======= + backend_compile_fn, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) first_bucket_cap: Optional[int] = None, ) -> None: if first_bucket_cap is not None: @@ -448,14 +540,22 @@ def __init__( self.backend_compile_fn = backend_compile_fn +<<<<<<< HEAD def _ignore_parameter(self, parameter: torch.nn.Parameter) -> bool: return hasattr(parameter, "_ddp_ignored") and parameter._ddp_ignored def add_param(self, bucket: Bucket, param: torch.nn.Parameter, name: str) -> None: +======= + def _ignore_parameter(self, parameter): + return hasattr(parameter, "_ddp_ignored") and parameter._ddp_ignored + + def add_param(self, bucket, param, name): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bucket.size += param.untyped_storage().nbytes() bucket.params.append(name) bucket.param_ids.append(id(param)) +<<<<<<< HEAD def add_module_params_to_bucket( self, mod: torch.nn.Module, @@ -463,12 +563,19 @@ def add_module_params_to_bucket( processed_modules: set[torch.nn.Module], prefix: str, ) -> None: +======= + def add_module_params_to_bucket(self, mod, bucket, processed_modules, prefix): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) processed_modules.add(mod) for name, param in mod.named_parameters(): if param.requires_grad and not self._ignore_parameter(param): self.add_param(bucket, param, f"{prefix}_{name}") +<<<<<<< HEAD def add_param_args(self, bucket: Bucket, node: fx.Node) -> None: +======= + def add_param_args(self, bucket, node): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for arg in node.args: if not isinstance(arg, torch.fx.node.Node): continue @@ -480,11 +587,17 @@ def add_param_args(self, bucket: Bucket, node: fx.Node) -> None: and param.requires_grad and not self._ignore_parameter(param) ): +<<<<<<< HEAD self.add_param(bucket, param, str(arg.target)) def compile_fn( self, gm: fx.GraphModule, example_inputs: list[torch.Tensor] ) -> CompiledFn: +======= + self.add_param(bucket, param, arg.target) + + def compile_fn(self, gm: fx.GraphModule, example_inputs: list[torch.Tensor]): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Implements graph splitting, first determining a set of of buckets by counting parameter sizes in reverse graph order, then invoking the user/backend compiler @@ -493,7 +606,11 @@ def compile_fn( """ # 1: compute the partition map according to DDP bucket logic buckets = [Bucket()] # (size, param_names) +<<<<<<< HEAD processed_modules: set[torch.nn.Module] = set() +======= + processed_modules = set() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for node in reversed(gm.graph.nodes): if node.op in ("output", "placeholder"): continue @@ -573,9 +690,13 @@ def compile_fn( partition_map[node] = idx split_gm = fx.passes.split_module.split_module( +<<<<<<< HEAD gm, None, # type: ignore[arg-type] lambda node: partition_map[node], +======= + gm, None, lambda node: partition_map[node] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # See note [Assumption on Dynamo Metadata] diff --git a/torch/_dynamo/backends/inductor.py b/torch/_dynamo/backends/inductor.py index ae62dd56678b8..6d7406876b100 100644 --- a/torch/_dynamo/backends/inductor.py +++ b/torch/_dynamo/backends/inductor.py @@ -1,3 +1,8 @@ +<<<<<<< HEAD +======= +# mypy: ignore-errors + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ This module provides the TorchInductor backend integration for TorchDynamo. @@ -10,13 +15,17 @@ model = torch.compile(model, backend="inductor") """ +<<<<<<< HEAD from typing import Any +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._dynamo import register_backend from torch._dynamo.utils import dynamo_timed @register_backend +<<<<<<< HEAD def inductor(*args: Any, **kwargs: Any) -> Any: with dynamo_timed("inductor_import", log_pt2_compile_event=True): # do import here to avoid loading inductor into memory when it is not used @@ -26,6 +35,11 @@ def inductor(*args: Any, **kwargs: Any) -> Any: maybe_warm_pool() +======= +def inductor(*args, **kwargs): + with dynamo_timed("inductor_import", log_pt2_compile_event=True): + # do import here to avoid loading inductor into memory when it is not used +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.compile_fx import compile_fx return compile_fx(*args, **kwargs) diff --git a/torch/_dynamo/backends/onnxrt.py b/torch/_dynamo/backends/onnxrt.py index 93490e64f4ae2..faa0a552e1bfe 100644 --- a/torch/_dynamo/backends/onnxrt.py +++ b/torch/_dynamo/backends/onnxrt.py @@ -1,7 +1,13 @@ +<<<<<<< HEAD +======= +# mypy: ignore-errors + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This backend is maintained by ONNX team. To direct issues # to the right people, please tag related GitHub issues with `module: onnx`. # # Maintainers' Github IDs: wschin, xadupre +<<<<<<< HEAD # from torch.onnx._internal.onnxruntime import ( # is_onnxrt_backend_supported, # torch_compile_backend, @@ -37,3 +43,37 @@ # ) # register_backend(name="onnxrt", compiler_fn=information_displaying_backend) +======= +from torch.onnx._internal.onnxruntime import ( + is_onnxrt_backend_supported, + torch_compile_backend, +) + +from .registry import register_backend + + +def has_onnxruntime(): + # FIXME: update test/dynamo/test_backends.py to call is_onnxrt_backend_supported() + return is_onnxrt_backend_supported() + + +if is_onnxrt_backend_supported(): + register_backend(name="onnxrt", compiler_fn=torch_compile_backend) +else: + + def information_displaying_backend(*args, **kwargs): + raise ImportError( + "onnxrt is not registered as a backend. " + "Please make sure all dependencies such as " + "numpy, onnx, onnxscript, and onnxruntime-training are installed. " + "Suggested procedure to fix dependency problem:\n" + " (1) pip or conda install numpy onnx onnxscript onnxruntime-training.\n" + " (2) Open a new python terminal.\n" + " (3) Call the API `torch.onnx.is_onnxrt_backend_supported()`:\n" + " (4) If it returns `True`, then you can use `onnxrt` backend.\n" + " (5) If it returns `False`, please execute the package importing section in " + "torch/onnx/_internal/onnxruntime.py under pdb line-by-line to see which import fails." + ) + + register_backend(name="onnxrt", compiler_fn=information_displaying_backend) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_dynamo/backends/registry.py b/torch/_dynamo/backends/registry.py index 699d82fff3f00..2ab875081e9e5 100644 --- a/torch/_dynamo/backends/registry.py +++ b/torch/_dynamo/backends/registry.py @@ -1,3 +1,8 @@ +<<<<<<< HEAD +======= +# mypy: ignore-errors + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ This module implements TorchDynamo's backend registry system for managing compiler backends. @@ -63,7 +68,11 @@ def my_compiler_function(fx_graph, example_inputs): import sys from collections.abc import Sequence from importlib.metadata import EntryPoint +<<<<<<< HEAD from typing import Any, Callable, Optional, Protocol, Union +======= +from typing import Callable, Optional, Protocol +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch from torch import fx @@ -86,7 +95,11 @@ def register_backend( compiler_fn: Optional[CompilerFn] = None, name: Optional[str] = None, tags: Sequence[str] = (), +<<<<<<< HEAD ) -> Callable[..., Any]: +======= +): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Decorator to add a given compiler to the registry to allow calling `torch.compile` with string shorthand. Note: for projects not @@ -100,14 +113,22 @@ def register_backend( """ if compiler_fn is None: # @register_backend(name="") syntax +<<<<<<< HEAD return functools.partial(register_backend, name=name, tags=tags) # type: ignore[return-value] +======= + return functools.partial(register_backend, name=name, tags=tags) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert callable(compiler_fn) name = name or compiler_fn.__name__ assert name not in _COMPILER_FNS, f"duplicate name: {name}" if compiler_fn not in _BACKENDS: _BACKENDS[name] = None _COMPILER_FNS[name] = compiler_fn +<<<<<<< HEAD compiler_fn._tags = tuple(tags) # type: ignore[attr-defined] +======= + compiler_fn._tags = tuple(tags) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return compiler_fn @@ -117,7 +138,11 @@ def register_backend( ) +<<<<<<< HEAD def lookup_backend(compiler_fn: Union[str, CompilerFn]) -> CompilerFn: +======= +def lookup_backend(compiler_fn): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Expand backend strings to functions""" if isinstance(compiler_fn, str): if compiler_fn not in _BACKENDS: @@ -129,33 +154,53 @@ def lookup_backend(compiler_fn: Union[str, CompilerFn]) -> CompilerFn: if compiler_fn not in _COMPILER_FNS: entry_point = _BACKENDS[compiler_fn] +<<<<<<< HEAD if entry_point is not None: register_backend(compiler_fn=entry_point.load(), name=compiler_fn) +======= + register_backend(compiler_fn=entry_point.load(), name=compiler_fn) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) compiler_fn = _COMPILER_FNS[compiler_fn] return compiler_fn +<<<<<<< HEAD # NOTE: can't type this due to public api mismatch; follow up with dev team def list_backends(exclude_tags=("debug", "experimental")) -> list[str]: # type: ignore[no-untyped-def] +======= +def list_backends(exclude_tags=("debug", "experimental")) -> list[str]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Return valid strings that can be passed to: torch.compile(..., backend="name") """ _lazy_import() +<<<<<<< HEAD exclude_tags_set = set(exclude_tags or ()) +======= + exclude_tags = set(exclude_tags or ()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) backends = [ name for name in _BACKENDS.keys() if name not in _COMPILER_FNS +<<<<<<< HEAD or not exclude_tags_set.intersection(_COMPILER_FNS[name]._tags) # type: ignore[attr-defined] +======= + or not exclude_tags.intersection(_COMPILER_FNS[name]._tags) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] return sorted(backends) @functools.cache +<<<<<<< HEAD def _lazy_import() -> None: +======= +def _lazy_import(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .. import backends from ..utils import import_submodule @@ -169,7 +214,11 @@ def _lazy_import() -> None: @functools.cache +<<<<<<< HEAD def _discover_entrypoint_backends() -> None: +======= +def _discover_entrypoint_backends(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # importing here so it will pick up the mocked version in test_backends.py from importlib.metadata import entry_points @@ -177,9 +226,18 @@ def _discover_entrypoint_backends() -> None: if sys.version_info < (3, 10): eps = entry_points() eps = eps[group_name] if group_name in eps else [] +<<<<<<< HEAD eps_dict = {ep.name: ep for ep in eps} else: eps = entry_points(group=group_name) eps_dict = {name: eps[name] for name in eps.names} for backend_name in eps_dict: _BACKENDS[backend_name] = eps_dict[backend_name] +======= + eps = {ep.name: ep for ep in eps} + else: + eps = entry_points(group=group_name) + eps = {name: eps[name] for name in eps.names} + for backend_name in eps: + _BACKENDS[backend_name] = eps[backend_name] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_dynamo/backends/tensorrt.py b/torch/_dynamo/backends/tensorrt.py index 493e21a9dfc5f..bfd0d70ccb9b1 100644 --- a/torch/_dynamo/backends/tensorrt.py +++ b/torch/_dynamo/backends/tensorrt.py @@ -1,3 +1,8 @@ +<<<<<<< HEAD +======= +# mypy: ignore-errors + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # import torch # type: ignore[import] # from .common import device_from_inputs, fake_tensor_unsupported # type: ignore[import] # from .registry import register_backend # type: ignore[import] diff --git a/torch/_dynamo/backends/torchxla.py b/torch/_dynamo/backends/torchxla.py index 7fa5d2d8668b6..be65f911833b6 100644 --- a/torch/_dynamo/backends/torchxla.py +++ b/torch/_dynamo/backends/torchxla.py @@ -1,3 +1,4 @@ +<<<<<<< HEAD import logging from typing import Any, Callable @@ -7,12 +8,23 @@ from ..backends.common import aot_autograd from .registry import CompiledFn, register_backend, register_experimental_backend +======= +# mypy: ignore-errors + +import logging + +from functorch.compile import make_boxed_func + +from ..backends.common import aot_autograd +from .registry import register_backend, register_experimental_backend +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) log = logging.getLogger(__name__) @register_experimental_backend +<<<<<<< HEAD def openxla_eval( model: fx.GraphModule, fake_tensor_inputs: list[torch.Tensor] ) -> CompiledFn: @@ -28,6 +40,17 @@ def openxla_eval_boxed( def xla_backend_helper( model: fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], boxed: bool = False ) -> Callable[..., Any]: +======= +def openxla_eval(model, fake_tensor_inputs): + return xla_backend_helper(model, fake_tensor_inputs, boxed=False) + + +def openxla_eval_boxed(model, fake_tensor_inputs): + return xla_backend_helper(model, fake_tensor_inputs, boxed=True) + + +def xla_backend_helper(model, fake_tensor_inputs, boxed=False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: import torch_xla.core.dynamo_bridge as bridge except ImportError as e: @@ -37,7 +60,11 @@ def xla_backend_helper( compiled_graph = None +<<<<<<< HEAD def fwd(*args: torch.Tensor) -> Any: +======= + def fwd(*args): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) nonlocal model nonlocal compiled_graph if compiled_graph is None: diff --git a/torch/_dynamo/backends/tvm.py b/torch/_dynamo/backends/tvm.py index 7e2ab19bb9c0a..1271a0f391e63 100644 --- a/torch/_dynamo/backends/tvm.py +++ b/torch/_dynamo/backends/tvm.py @@ -1,3 +1,8 @@ +<<<<<<< HEAD +======= +# mypy: ignore-errors + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ This module provides TVM backend integration for TorchDynamo. @@ -27,10 +32,16 @@ import sys import tempfile from types import MappingProxyType +<<<<<<< HEAD from typing import Any, Callable, Optional import torch from torch import fx +======= +from typing import Optional + +import torch +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .common import device_from_inputs, fake_tensor_unsupported from .registry import register_backend @@ -40,6 +51,7 @@ @register_backend +<<<<<<< HEAD @fake_tensor_unsupported # type: ignore[arg-type] def tvm( gm: fx.GraphModule, @@ -50,6 +62,17 @@ def tvm( if options is None: options = MappingProxyType({"scheduler": None, "trials": 20000, "opt_level": 3}) assert options is not None +======= +@fake_tensor_unsupported +def tvm( + gm, + example_inputs, + *, + options: Optional[MappingProxyType] = MappingProxyType( + {"scheduler": None, "trials": 20000, "opt_level": 3} + ), +): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import tvm # type: ignore[import] from tvm import relay # type: ignore[import] from tvm.contrib import graph_executor # type: ignore[import] @@ -147,7 +170,11 @@ def tvm( ) m = graph_executor.GraphModule(lib["default"](dev)) +<<<<<<< HEAD def to_torch_tensor(nd_tensor: tvm.nd.array) -> torch.Tensor: +======= + def to_torch_tensor(nd_tensor): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """A helper function to transfer a NDArray to torch.tensor.""" if nd_tensor.dtype == "bool": # DLPack does not support boolean so it can't be handled by @@ -156,7 +183,11 @@ def to_torch_tensor(nd_tensor: tvm.nd.array) -> torch.Tensor: return torch.from_numpy(nd_tensor.numpy()) return torch.utils.dlpack.from_dlpack(nd_tensor.to_dlpack()) +<<<<<<< HEAD def to_tvm_tensor(torch_tensor: torch.Tensor) -> tvm.nd.array: +======= + def to_tvm_tensor(torch_tensor): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """A helper function to transfer a torch.tensor to NDArray.""" if torch_tensor.dtype == torch.bool: # same reason as above, fallback to numpy conversion which @@ -164,7 +195,11 @@ def to_tvm_tensor(torch_tensor: torch.Tensor) -> tvm.nd.array: return tvm.nd.array(torch_tensor.cpu().numpy()) return tvm.nd.from_dlpack(torch_tensor) +<<<<<<< HEAD def exec_tvm(*i_args: torch.Tensor) -> list[torch.Tensor]: +======= + def exec_tvm(*i_args): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) args = [a.contiguous() for a in i_args] shape_info, _ = m.get_input_info() active_inputs = {name for name, _ in shape_info.items()} @@ -193,7 +228,11 @@ def exec_tvm(*i_args: torch.Tensor) -> list[torch.Tensor]: tvm_auto_scheduler = functools.partial(tvm, scheduler="auto_scheduler") +<<<<<<< HEAD def has_tvm() -> bool: +======= +def has_tvm(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: importlib.import_module("tvm") return True @@ -202,7 +241,11 @@ def has_tvm() -> bool: @functools.cache +<<<<<<< HEAD def llvm_target() -> str: +======= +def llvm_target(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if sys.platform == "linux": cpuinfo = open("/proc/cpuinfo").read() if "avx512" in cpuinfo: diff --git a/torch/_dynamo/bytecode_analysis.py b/torch/_dynamo/bytecode_analysis.py index 8bdf155e00603..e28e87d3b22d5 100644 --- a/torch/_dynamo/bytecode_analysis.py +++ b/torch/_dynamo/bytecode_analysis.py @@ -1,3 +1,8 @@ +<<<<<<< HEAD +======= +# mypy: allow-untyped-defs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ This module provides utilities for analyzing and optimizing Python bytecode. Key functionality includes: @@ -16,6 +21,7 @@ import dataclasses import dis import sys +<<<<<<< HEAD from typing import Any, TYPE_CHECKING, Union @@ -24,6 +30,11 @@ # and refactoring in callsite; that way we don't have to guard this import from .bytecode_transformation import Instruction +======= +from typing import Any, Union + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TERMINAL_OPCODES = { dis.opmap["RETURN_VALUE"], dis.opmap["JUMP_FORWARD"], @@ -36,7 +47,11 @@ TERMINAL_OPCODES.add(dis.opmap["JUMP_FORWARD"]) else: TERMINAL_OPCODES.add(dis.opmap["JUMP_ABSOLUTE"]) +<<<<<<< HEAD if (3, 12) <= sys.version_info < (3, 14): +======= +if sys.version_info >= (3, 12): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TERMINAL_OPCODES.add(dis.opmap["RETURN_CONST"]) if sys.version_info >= (3, 13): TERMINAL_OPCODES.add(dis.opmap["JUMP_BACKWARD_NO_INTERRUPT"]) @@ -48,7 +63,11 @@ stack_effect = dis.stack_effect +<<<<<<< HEAD def get_indexof(insts: list["Instruction"]) -> dict["Instruction", int]: +======= +def get_indexof(insts): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Get a mapping from instruction memory address to index in instruction list. Additionally checks that each instruction only appears once in the list. @@ -60,12 +79,20 @@ def get_indexof(insts: list["Instruction"]) -> dict["Instruction", int]: return indexof +<<<<<<< HEAD def remove_dead_code(instructions: list["Instruction"]) -> list["Instruction"]: +======= +def remove_dead_code(instructions): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Dead code elimination""" indexof = get_indexof(instructions) live_code = set() +<<<<<<< HEAD def find_live_code(start: int) -> None: +======= + def find_live_code(start): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for i in range(start, len(instructions)): if i in live_code: return @@ -74,7 +101,10 @@ def find_live_code(start: int) -> None: if inst.exn_tab_entry: find_live_code(indexof[inst.exn_tab_entry.target]) if inst.opcode in JUMP_OPCODES: +<<<<<<< HEAD assert inst.target is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) find_live_code(indexof[inst.target]) if inst.opcode in TERMINAL_OPCODES: return @@ -106,7 +136,11 @@ def find_live_code(start: int) -> None: return [inst for i, inst in enumerate(instructions) if i in live_code] +<<<<<<< HEAD def remove_pointless_jumps(instructions: list["Instruction"]) -> list["Instruction"]: +======= +def remove_pointless_jumps(instructions): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Eliminate jumps to the next instruction""" pointless_jumps = { id(a) @@ -116,11 +150,19 @@ def remove_pointless_jumps(instructions: list["Instruction"]) -> list["Instructi return [inst for inst in instructions if id(inst) not in pointless_jumps] +<<<<<<< HEAD def propagate_line_nums(instructions: list["Instruction"]) -> None: """Ensure every instruction has line number set in case some are removed""" cur_line_no = None def populate_line_num(inst: "Instruction") -> None: +======= +def propagate_line_nums(instructions): + """Ensure every instruction has line number set in case some are removed""" + cur_line_no = None + + def populate_line_num(inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) nonlocal cur_line_no if inst.starts_line: cur_line_no = inst.starts_line @@ -131,12 +173,20 @@ def populate_line_num(inst: "Instruction") -> None: populate_line_num(inst) +<<<<<<< HEAD def remove_extra_line_nums(instructions: list["Instruction"]) -> None: +======= +def remove_extra_line_nums(instructions): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Remove extra starts line properties before packing bytecode""" cur_line_no = None +<<<<<<< HEAD def remove_line_num(inst: "Instruction") -> None: +======= + def remove_line_num(inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) nonlocal cur_line_no if inst.starts_line is None: return @@ -156,14 +206,22 @@ class ReadsWrites: visited: set[Any] +<<<<<<< HEAD def livevars_analysis( instructions: list["Instruction"], instruction: "Instruction" ) -> set[Any]: +======= +def livevars_analysis(instructions, instruction): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) indexof = get_indexof(instructions) must = ReadsWrites(set(), set(), set()) may = ReadsWrites(set(), set(), set()) +<<<<<<< HEAD def walk(state: ReadsWrites, start: int) -> None: +======= + def walk(state, start): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if start in state.visited: return state.visited.add(start) @@ -183,7 +241,10 @@ def walk(state: ReadsWrites, start: int) -> None: if inst.exn_tab_entry: walk(may, indexof[inst.exn_tab_entry.target]) if inst.opcode in JUMP_OPCODES: +<<<<<<< HEAD assert inst.target is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) walk(may, indexof[inst.target]) state = may if inst.opcode in TERMINAL_OPCODES: @@ -204,19 +265,31 @@ class StackSize: high: Union[int, float] fixed_point: FixedPointBox +<<<<<<< HEAD def zero(self) -> None: +======= + def zero(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.low = 0 self.high = 0 self.fixed_point.value = False +<<<<<<< HEAD def offset_of(self, other: "StackSize", n: int) -> None: +======= + def offset_of(self, other, n): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) prior = (self.low, self.high) self.low = min(self.low, other.low + n) self.high = max(self.high, other.high + n) if (self.low, self.high) != prior: self.fixed_point.value = False +<<<<<<< HEAD def exn_tab_jump(self, depth: int) -> None: +======= + def exn_tab_jump(self, depth): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) prior = (self.low, self.high) self.low = min(self.low, depth) self.high = max(self.high, depth) @@ -224,7 +297,11 @@ def exn_tab_jump(self, depth: int) -> None: self.fixed_point.value = False +<<<<<<< HEAD def stacksize_analysis(instructions: list["Instruction"]) -> Union[int, float]: +======= +def stacksize_analysis(instructions) -> Union[int, float]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert instructions fixed_point = FixedPointBox() stack_sizes = { @@ -245,7 +322,10 @@ def stacksize_analysis(instructions: list["Instruction"]) -> Union[int, float]: eff = stack_effect(inst.opcode, inst.arg, jump=False) stack_sizes[next_inst].offset_of(stack_size, eff) if inst.opcode in JUMP_OPCODES: +<<<<<<< HEAD assert inst.target is not None, f"missing target: {inst}" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) stack_sizes[inst.target].offset_of( stack_size, stack_effect(inst.opcode, inst.arg, jump=True) ) @@ -255,6 +335,14 @@ def stacksize_analysis(instructions: list["Instruction"]) -> Union[int, float]: depth = inst.exn_tab_entry.depth + int(inst.exn_tab_entry.lasti) + 1 stack_sizes[inst.exn_tab_entry.target].exn_tab_jump(depth) +<<<<<<< HEAD +======= + if False: + for inst in instructions: + stack_size = stack_sizes[inst] + print(stack_size.low, stack_size.high, inst) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) low = min(x.low for x in stack_sizes.values()) high = max(x.high for x in stack_sizes.values()) diff --git a/torch/_dynamo/bytecode_transformation.py b/torch/_dynamo/bytecode_transformation.py index 14a6f78bfcd48..f269a77a3b2ff 100644 --- a/torch/_dynamo/bytecode_transformation.py +++ b/torch/_dynamo/bytecode_transformation.py @@ -1,3 +1,8 @@ +<<<<<<< HEAD +======= +# mypy: allow-untyped-defs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ This module provides utilities for analyzing, transforming and manipulating Python bytecode. It includes functionality for: @@ -21,11 +26,18 @@ import sys import types import uuid +<<<<<<< HEAD from collections.abc import Iterable, Iterator, Mapping, Sequence from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union from ..utils._backport_slots import dataclass_slots from . import config +======= +from collections.abc import Iterator, Sequence +from typing import Any, Callable, cast, Optional, Union + +from ..utils._backport_slots import dataclass_slots +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .bytecode_analysis import ( get_indexof, propagate_line_nums, @@ -35,10 +47,13 @@ from .utils import is_safe_constant +<<<<<<< HEAD if TYPE_CHECKING: from .output_graph import DynamoTracerOutput +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclass_slots @dataclasses.dataclass class InstructionExnTabEntry: @@ -56,9 +71,13 @@ def __repr__(self) -> str: f"depth={self.depth}, lasti={self.lasti})" ) +<<<<<<< HEAD def __eq__(self, o: object) -> bool: if not isinstance(o, InstructionExnTabEntry): return False +======= + def __eq__(self, o) -> bool: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ( self.start is o.start and self.end is o.end @@ -89,7 +108,11 @@ class Instruction: def __hash__(self) -> int: return id(self) +<<<<<<< HEAD def __eq__(self, other: object) -> bool: +======= + def __eq__(self, other) -> bool: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return id(self) == id(other) def short_inst_repr(self) -> str: @@ -150,26 +173,42 @@ def __repr__(self) -> str: if sys.version_info >= (3, 12): +<<<<<<< HEAD def inst_has_op_bits(name: str) -> bool: +======= + def inst_has_op_bits(name): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return name in ("LOAD_ATTR", "LOAD_GLOBAL", "LOAD_SUPER_ATTR") elif sys.version_info >= (3, 11): +<<<<<<< HEAD def inst_has_op_bits(name: str) -> bool: +======= + def inst_has_op_bits(name): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return name == "LOAD_GLOBAL" else: +<<<<<<< HEAD def inst_has_op_bits(name: str): +======= + def inst_has_op_bits(name): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return False def create_instruction( +<<<<<<< HEAD name: str, *, arg: Optional[int] = None, argval: Optional[Any] = _NotProvided, target: Optional[Instruction] = None, +======= + name, *, arg=None, argval=_NotProvided, target=None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Instruction: """ At most one of `arg`, `argval`, and `target` can be not None/_NotProvided. @@ -207,16 +246,24 @@ def create_instruction( # Python 3.11 remaps +<<<<<<< HEAD def create_jump_absolute(target: Instruction) -> Instruction: +======= +def create_jump_absolute(target) -> Instruction: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inst = "JUMP_FORWARD" if sys.version_info >= (3, 11) else "JUMP_ABSOLUTE" return create_instruction(inst, target=target) +<<<<<<< HEAD def is_jump_absolute(target: Instruction) -> bool: return target.opname in ("JUMP_FORWARD", "JUMP_ABSOLUTE") def create_load_const(val: Any, checked: bool = True) -> Instruction: +======= +def create_load_const(val, checked=True) -> Instruction: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ In general we should only create `LOAD_CONST` for immutable objects, but sometimes it's convenient _and safe_ for Dynamo create `LOAD_CONST` for @@ -233,7 +280,11 @@ def create_dup_top() -> Instruction: return create_instruction("DUP_TOP") +<<<<<<< HEAD def create_rot_n(n: int) -> list[Instruction]: +======= +def create_rot_n(n) -> list[Instruction]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Returns a "simple" sequence of instructions that rotates TOS to the n-th position in the stack. For Python < 3.11, returns a single ROT_* @@ -251,6 +302,7 @@ def create_rot_n(n: int) -> list[Instruction]: # e.g. rotate 3 is equivalent to swap 3, swap 2 return [create_instruction("SWAP", arg=i) for i in range(n, 1, -1)] +<<<<<<< HEAD # ROT_N does not exist in Python <= 3.9, but we can simulate it if sys.version_info < (3, 10) and n >= 5: """ @@ -266,6 +318,11 @@ def create_rot_n(n: int) -> list[Instruction]: create_instruction("BUILD_TUPLE", arg=n - 1), create_instruction("UNPACK_SEQUENCE", arg=n - 1), ] +======= + # ensure desired rotate function exists + if sys.version_info < (3, 10) and n >= 5: + raise AttributeError(f"rotate {n} not supported for Python < 3.10") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if n <= 4: return [create_instruction("ROT_" + ["TWO", "THREE", "FOUR"][n - 2])] @@ -290,6 +347,7 @@ def add_push_null( In this case, instructions WILL be modified. """ if isinstance(inst_or_insts, Instruction): +<<<<<<< HEAD insts: list[Instruction] = [inst_or_insts] else: assert isinstance(inst_or_insts, list) @@ -302,6 +360,19 @@ def inst_has_bit_set(idx: int) -> bool: def set_inst_bit(idx: int) -> None: assert insts[idx].arg is not None insts[idx].arg |= 1 # type: ignore[operator] +======= + insts = [inst_or_insts] + else: + insts = inst_or_insts + + def inst_has_bit_set(idx): + assert insts[idx].arg is not None + return insts[idx].arg & 1 == 1 + + def set_inst_bit(idx): + assert insts[idx].arg is not None + insts[idx].arg |= 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if sys.version_info >= (3, 13): # In 3.13, NULL follows the callable @@ -338,9 +409,14 @@ def add_push_null_call_function_ex( is not set, due to an expected CALL_FUNCTION_EX instruction. """ if isinstance(inst_or_insts, Instruction): +<<<<<<< HEAD insts: list[Instruction] = [inst_or_insts] else: assert isinstance(inst_or_insts, list) +======= + insts = [inst_or_insts] + else: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) insts = inst_or_insts if sys.version_info < (3, 11): @@ -361,7 +437,11 @@ def add_push_null_call_function_ex( return insts +<<<<<<< HEAD def create_call_function(nargs: int, push_null: bool) -> list[Instruction]: +======= +def create_call_function(nargs, push_null) -> list[Instruction]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Creates a sequence of instructions that makes a function call. @@ -416,7 +496,11 @@ def create_call_function(nargs: int, push_null: bool) -> list[Instruction]: return [create_instruction("CALL_FUNCTION", arg=nargs)] +<<<<<<< HEAD def create_call_method(nargs: int) -> list[Instruction]: +======= +def create_call_method(nargs) -> list[Instruction]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if sys.version_info >= (3, 12): return [create_instruction("CALL", arg=nargs)] if sys.version_info >= (3, 11): @@ -427,28 +511,43 @@ def create_call_method(nargs: int) -> list[Instruction]: return [create_instruction("CALL_METHOD", arg=nargs)] +<<<<<<< HEAD def create_load_method(name: str) -> Instruction: +======= +def create_load_method(name) -> Instruction: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if sys.version_info >= (3, 12): # in 3.12, create a LOAD_ATTR instruction with the low bit set return create_instruction("LOAD_ATTR", arg=1, argval=name) return create_instruction("LOAD_METHOD", argval=name) +<<<<<<< HEAD def create_setup_with(target: Instruction) -> Instruction: +======= +def create_setup_with(target) -> Instruction: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) opname = "BEFORE_WITH" if sys.version_info >= (3, 11) else "SETUP_WITH" return create_instruction(opname, target=target) +<<<<<<< HEAD def create_swap(n: int) -> list[Instruction]: +======= +def create_swap(n) -> list[Instruction]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if sys.version_info >= (3, 11): return [create_instruction("SWAP", arg=n)] # in Python < 3.11, SWAP is a macro that expands to multiple instructions if n == 1: return [] +<<<<<<< HEAD elif n == 2: return [create_instruction("ROT_TWO")] elif n == 3: return [create_instruction("ROT_THREE"), create_instruction("ROT_TWO")] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ e.g. swap "a" and "b" in this stack: 0 a 1 2 3 b @@ -485,6 +584,7 @@ def create_swap(n: int) -> list[Instruction]: ] +<<<<<<< HEAD def create_binary_slice( start: Optional[int], end: Optional[int], store: bool = False ) -> list[Instruction]: @@ -545,6 +645,8 @@ def create_print_value(value: Any) -> list[Instruction]: ] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def lnotab_writer( lineno: int, byteno: int = 0 ) -> tuple[list[int], Callable[[int, int], None]]: @@ -556,7 +658,11 @@ def lnotab_writer( assert sys.version_info < (3, 10) lnotab: list[int] = [] +<<<<<<< HEAD def update(lineno_new: int, byteno_new: int) -> None: +======= + def update(lineno_new, byteno_new): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) nonlocal byteno, lineno while byteno_new != byteno or lineno_new != lineno: byte_offset = max(0, min(byteno_new - byteno, 255)) @@ -569,9 +675,13 @@ def update(lineno_new: int, byteno_new: int) -> None: return lnotab, update +<<<<<<< HEAD def linetable_310_writer( first_lineno: int, ) -> tuple[list[int], Callable[[int, int], None], Callable[[int], None]]: +======= +def linetable_310_writer(first_lineno): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Used to create typing.CodeType.co_linetable See https://github.com/python/cpython/blob/main/Objects/lnotab_notes.txt @@ -583,7 +693,11 @@ def linetable_310_writer( lineno_delta = 0 byteno = 0 +<<<<<<< HEAD def _update(byteno_delta: int, lineno_delta: int) -> None: +======= + def _update(byteno_delta, lineno_delta): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) while byteno_delta != 0 or lineno_delta != 0: byte_offset = max(0, min(byteno_delta, 254)) line_offset = max(-127, min(lineno_delta, 127)) @@ -592,7 +706,11 @@ def _update(byteno_delta: int, lineno_delta: int) -> None: lineno_delta -= line_offset linetable.extend((byte_offset, line_offset & 0xFF)) +<<<<<<< HEAD def update(lineno_new: int, byteno_new: int) -> None: +======= + def update(lineno_new, byteno_new): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) nonlocal lineno, lineno_delta, byteno byteno_delta = byteno_new - byteno byteno = byteno_new @@ -600,7 +718,11 @@ def update(lineno_new: int, byteno_new: int) -> None: lineno_delta = lineno_new - lineno lineno = lineno_new +<<<<<<< HEAD def end(total_bytes: int) -> None: +======= + def end(total_bytes): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _update(total_bytes - byteno, lineno_delta) return linetable, update, end @@ -621,9 +743,13 @@ def encode_varint(n: int) -> list[int]: return b +<<<<<<< HEAD def linetable_311_writer( first_lineno: int, ) -> tuple[list[int], Callable[[Optional["dis.Positions"], int], None]]: +======= +def linetable_311_writer(first_lineno: int): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Used to create typing.CodeType.co_linetable See https://github.com/python/cpython/blob/3.11/Objects/locations.md @@ -633,11 +759,19 @@ def linetable_311_writer( linetable = [] lineno = first_lineno +<<<<<<< HEAD def update(positions: Optional["dis.Positions"], inst_size: int) -> None: nonlocal lineno lineno_new = positions.lineno if positions else None def _update(delta: int, size: int) -> None: +======= + def update(positions: "dis.Positions", inst_size): + nonlocal lineno + lineno_new = positions.lineno if positions else None + + def _update(delta, size): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert 0 < size <= 8 # first byte - use 13 (no column info) is positions is # malformed, otherwise use 14 (long form) @@ -816,9 +950,13 @@ def assemble(instructions: list[Instruction], firstlineno: int) -> tuple[bytes, return bytes(code), bytes(lnotab) +<<<<<<< HEAD def _get_instruction_by_offset( offset_to_inst: dict[int, Instruction], offset: int ) -> Optional[Instruction]: +======= +def _get_instruction_by_offset(offset_to_inst: dict[int, Instruction], offset: int): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Get the instruction located at a given offset, accounting for EXTENDED_ARGs """ @@ -828,11 +966,17 @@ def _get_instruction_by_offset( return None +<<<<<<< HEAD def virtualize_jumps(instructions: Iterable[Instruction]) -> None: """Replace jump targets with pointers to make editing easier""" jump_targets = { inst.offset: inst for inst in instructions if inst.offset is not None } +======= +def virtualize_jumps(instructions) -> None: + """Replace jump targets with pointers to make editing easier""" + jump_targets = {inst.offset: inst for inst in instructions} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for inst in instructions: if inst.opcode in dis.hasjabs or inst.opcode in dis.hasjrel: @@ -855,7 +999,11 @@ def flip_jump_direction(instruction: Instruction) -> None: assert instruction.opcode in _REL_JUMPS +<<<<<<< HEAD def _get_instruction_front(instructions: list[Instruction], idx: int) -> Instruction: +======= +def _get_instruction_front(instructions: list[Instruction], idx: int): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ i.e. get the first EXTENDED_ARG instruction (if any) when targeting instructions[idx] with a jump. @@ -869,7 +1017,11 @@ def _get_instruction_front(instructions: list[Instruction], idx: int) -> Instruc return target +<<<<<<< HEAD def devirtualize_jumps(instructions: list[Instruction]) -> None: +======= +def devirtualize_jumps(instructions): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Fill in args for virtualized jump target after instructions may have moved""" jumps = set(dis.hasjabs).union(set(dis.hasjrel)) @@ -877,11 +1029,14 @@ def devirtualize_jumps(instructions: list[Instruction]) -> None: for inst in instructions: if inst.opcode in jumps: if inst.opcode not in dis.hasjabs: +<<<<<<< HEAD assert ( inst.target is not None and inst.target.offset is not None and inst.offset is not None ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if inst.target.offset < inst.offset: if sys.version_info < (3, 11): raise RuntimeError("Got negative jump offset for Python < 3.11") @@ -900,7 +1055,10 @@ def devirtualize_jumps(instructions: list[Instruction]) -> None: # compute jump instruction arg for inst in instructions: if inst.opcode in jumps: +<<<<<<< HEAD assert inst.target is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) target = _get_instruction_front(instructions, indexof[inst.target]) if inst.opcode in dis.hasjabs: if sys.version_info < (3, 10): @@ -913,7 +1071,10 @@ def devirtualize_jumps(instructions: list[Instruction]) -> None: raise RuntimeError("Python 3.11+ should not have absolute jumps") else: # relative jump # byte offset between target and next instruction +<<<<<<< HEAD assert target.offset is not None and inst.offset is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inst.arg = abs( int(target.offset - inst.offset - instruction_size(inst)) ) @@ -924,9 +1085,13 @@ def devirtualize_jumps(instructions: list[Instruction]) -> None: inst.argrepr = f"to {target.offset}" +<<<<<<< HEAD def virtualize_exception_table( exn_tab_bytes: bytes, instructions: list[Instruction] ) -> None: +======= +def virtualize_exception_table(exn_tab_bytes: bytes, instructions: list[Instruction]): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Replace exception table entries with pointers to make editing easier""" exn_tab = parse_exception_table(exn_tab_bytes) offset_to_inst = {cast(int, inst.offset): inst for inst in instructions} @@ -935,7 +1100,11 @@ def virtualize_exception_table( exn_tab_iter = iter(exn_tab) try: +<<<<<<< HEAD def step() -> tuple[ExceptionTableEntry, InstructionExnTabEntry]: +======= + def step(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) nonlocal end_offset_idx entry = next(exn_tab_iter) # find rightmost offset <= entry.end, since entry.end may not be @@ -949,9 +1118,15 @@ def step() -> tuple[ExceptionTableEntry, InstructionExnTabEntry]: assert end_offset_idx > 0 end_offset = offsets[end_offset_idx - 1] inst_entry = InstructionExnTabEntry( +<<<<<<< HEAD _get_instruction_by_offset(offset_to_inst, entry.start), # type: ignore[arg-type] _get_instruction_by_offset(offset_to_inst, end_offset), # type: ignore[arg-type] _get_instruction_by_offset(offset_to_inst, entry.target), # type: ignore[arg-type] +======= + _get_instruction_by_offset(offset_to_inst, entry.start), + _get_instruction_by_offset(offset_to_inst, end_offset), + _get_instruction_by_offset(offset_to_inst, entry.target), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) entry.depth, entry.lasti, ) @@ -959,7 +1134,10 @@ def step() -> tuple[ExceptionTableEntry, InstructionExnTabEntry]: entry, inst_entry = step() for inst in instructions: +<<<<<<< HEAD assert inst.offset is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) while inst.offset > entry.end: entry, inst_entry = step() if inst.offset >= entry.start: @@ -981,18 +1159,27 @@ def compute_exception_table( start = _get_instruction_front( instructions, indexof[inst.exn_tab_entry.start] ).offset +<<<<<<< HEAD assert start is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # point to the last 2 bytes of the end instruction end = ( cast(int, inst.exn_tab_entry.end.offset) + instruction_size(inst.exn_tab_entry.end) - 2 ) +<<<<<<< HEAD assert end is not None target = _get_instruction_front( instructions, indexof[inst.exn_tab_entry.target] ).offset assert target is not None +======= + target = _get_instruction_front( + instructions, indexof[inst.exn_tab_entry.target] + ).offset +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) key = (start, end) val = (target, inst.exn_tab_entry.depth, inst.exn_tab_entry.lasti) if key in exn_dict: @@ -1012,7 +1199,11 @@ def compute_exception_table( key_stack: list[tuple[int, int]] = [] exn_tab: list[ExceptionTableEntry] = [] +<<<<<<< HEAD def pop() -> None: +======= + def pop(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Pop the key_stack and append an exception table entry if possible. """ @@ -1046,7 +1237,11 @@ def pop() -> None: def check_inst_exn_tab_entries_nested( +<<<<<<< HEAD tab: list[InstructionExnTabEntry], indexof: dict[Instruction, int] +======= + tab: list[InstructionExnTabEntry], indexof +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: """ Checks `tab` is a properly sorted list of nested InstructionExnTabEntry's, @@ -1091,7 +1286,11 @@ def propagate_inst_exn_table_entries(instructions: list[Instruction]) -> None: instructions[i].exn_tab_entry = copy.copy(entry) +<<<<<<< HEAD def check_inst_exn_tab_entries_valid(instructions: list[Instruction]) -> None: +======= +def check_inst_exn_tab_entries_valid(instructions: list[Instruction]): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Checks that exn_tab_entries of instructions are valid. An entry's start, end, and target must be in instructions. @@ -1124,9 +1323,13 @@ def strip_extended_args(instructions: list[Instruction]) -> None: # instruction, exception table entries, and positions. # Returns the modified sequence of instructions (including the modified # old instruction!) that can be manipulated elsewhere. +<<<<<<< HEAD def overwrite_instruction( old_inst: Instruction, new_insts: list[Instruction] ) -> list[Instruction]: +======= +def overwrite_instruction(old_inst, new_insts): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # update old_inst.exnt_tab_entry.end if necessary if ( old_inst.exn_tab_entry @@ -1233,6 +1436,7 @@ def remove_fused_load_store(instructions: list[Instruction]) -> None: instructions[:] = new_insts +<<<<<<< HEAD # adds GRAPH_BREAK_IF_LEAF (not a real instruction) before RETURN_* instructions # for testing purposes def add_graph_break_if_leaf_instructions(instructions: list[Instruction]) -> None: @@ -1276,6 +1480,8 @@ def remove_graph_break_if_leaf_instructions(instructions: list[Instruction]) -> instructions[:] = new_insts +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def explicit_super(code: types.CodeType, instructions: list[Instruction]) -> None: """convert super() with no args into explicit arg form""" cell_and_free = (code.co_cellvars or ()) + (code.co_freevars or ()) @@ -1318,7 +1524,11 @@ def fix_extended_args(instructions: list[Instruction]) -> int: """Fill in correct argvals for EXTENDED_ARG ops""" output: list[Instruction] = [] +<<<<<<< HEAD def maybe_pop_n(n: int) -> None: +======= + def maybe_pop_n(n): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for _ in range(n): if output and output[-1].opcode == dis.EXTENDED_ARG: output.pop() @@ -1347,7 +1557,11 @@ def maybe_pop_n(n: int) -> None: return added +<<<<<<< HEAD def instruction_size(inst: Instruction) -> int: +======= +def instruction_size(inst) -> int: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch if sys.version_info >= (3, 11): @@ -1355,21 +1569,33 @@ def instruction_size(inst: Instruction) -> int: return 2 +<<<<<<< HEAD def check_offsets(instructions: Sequence[Instruction]) -> None: +======= +def check_offsets(instructions) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) offset = 0 for inst in instructions: assert inst.offset == offset offset += instruction_size(inst) +<<<<<<< HEAD def update_offsets(instructions: Sequence[Instruction]) -> None: +======= +def update_offsets(instructions) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) offset = 0 for inst in instructions: inst.offset = offset offset += instruction_size(inst) +<<<<<<< HEAD def debug_bytes(*args: bytes) -> str: +======= +def debug_bytes(*args) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) index = range(max(map(len, args))) result = [ " ".join(f"{x:03}" for x in arg) @@ -1381,9 +1607,15 @@ def debug_bytes(*args: bytes) -> str: return "bytes mismatch\n" + "\n".join(result) +<<<<<<< HEAD def debug_checks(code: types.CodeType) -> None: """Make sure our assembler produces same bytes as we start with""" dode, _ = transform_code_object(code, lambda x, y: None, safe=True) +======= +def debug_checks(code): + """Make sure our assembler produces same bytes as we start with""" + dode = transform_code_object(code, lambda x, y: None, safe=True) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert code.co_code == dode.co_code, debug_bytes(code.co_code, dode.co_code) assert code.co_lnotab == dode.co_lnotab, debug_bytes(code.co_lnotab, dode.co_lnotab) @@ -1394,7 +1626,11 @@ def debug_checks(code: types.CodeType) -> None: HAS_CONST = set(dis.hasconst) +<<<<<<< HEAD def get_const_index(code_options: dict[str, Any], val: Any) -> int: +======= +def get_const_index(code_options, val) -> int: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for i, v in enumerate(code_options["co_consts"]): # NOTE: stronger comparison is required, since we have # examples where two values compare equal but have @@ -1406,6 +1642,7 @@ def get_const_index(code_options: dict[str, Any], val: Any) -> int: return len(code_options["co_consts"]) - 1 +<<<<<<< HEAD def fix_vars( instructions: list[Instruction], code_options: dict[str, Any], @@ -1415,6 +1652,13 @@ def fix_vars( names = {name: idx for idx, name in enumerate(code_options["co_names"])} def get_name_index(name: str) -> int: +======= +def fix_vars(instructions: list[Instruction], code_options, varname_from_oparg=None): + # compute instruction arg from argval if arg is not provided + names = {name: idx for idx, name in enumerate(code_options["co_names"])} + + def get_name_index(name) -> int: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: idx = names[name] except KeyError: @@ -1449,7 +1693,11 @@ def get_name_index(name: str) -> int: } for i in range(len(instructions)): +<<<<<<< HEAD def should_compute_arg() -> bool: +======= + def should_compute_arg(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # argval is prioritized over arg return instructions[i].argval is not _NotProvided @@ -1517,7 +1765,11 @@ def should_compute_arg() -> bool: instructions[i].arg = idx +<<<<<<< HEAD def clear_instruction_args(instructions: list[Instruction]) -> None: +======= +def clear_instruction_args(instructions): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Clear the instruction arg for instructions that have argvals. # Useful for using dis'd bytecode within generated bytecode. for inst in instructions: @@ -1574,6 +1826,7 @@ def get_code_keys() -> list[str]: return keys +<<<<<<< HEAD def transform_code_object( code: types.CodeType, transformations: Callable[ @@ -1581,23 +1834,36 @@ def transform_code_object( ], safe: bool = False, ) -> tuple[types.CodeType, Optional["DynamoTracerOutput"]]: +======= +def transform_code_object(code, transformations, safe=False) -> types.CodeType: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) keys = get_code_keys() code_options = {k: getattr(code, k) for k in keys} assert len(code_options["co_varnames"]) == code_options["co_nlocals"] instructions = cleaned_instructions(code, safe) +<<<<<<< HEAD # propagate line nums again for added instructions propagate_line_nums(instructions) tracer_output = transformations(instructions, code_options) _, bytecode = clean_and_assemble_instructions(instructions, keys, code_options) return bytecode, tracer_output +======= + propagate_line_nums(instructions) + + transformations(instructions, code_options) + return clean_and_assemble_instructions(instructions, keys, code_options)[1] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def clean_and_assemble_instructions( instructions: list[Instruction], keys: list[str], code_options: dict[str, Any] ) -> tuple[list[Instruction], types.CodeType]: +<<<<<<< HEAD remove_graph_break_if_leaf_instructions(instructions) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # also implicitly checks for no duplicate instructions check_inst_exn_tab_entries_valid(instructions) @@ -1636,7 +1902,11 @@ def clean_and_assemble_instructions( return instructions, types.CodeType(*[code_options[k] for k in keys]) +<<<<<<< HEAD def populate_kw_names_argval(instructions: Sequence[Instruction], consts: Any) -> None: +======= +def populate_kw_names_argval(instructions, consts): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for inst in instructions: if inst.opname == "KW_NAMES": inst.argval = consts[inst.arg] @@ -1644,7 +1914,11 @@ def populate_kw_names_argval(instructions: Sequence[Instruction], consts: Any) - # If safe=True, we do not make any bytecode modifications. # Mainly used for debugging bytecode_transformation (see debug_checks) +<<<<<<< HEAD def cleaned_instructions(code: types.CodeType, safe: bool = False) -> list[Instruction]: +======= +def cleaned_instructions(code, safe=False) -> list[Instruction]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) instructions = _cached_cleaned_instructions(code, safe) # We have a lot of code that implicitly mutates the instruction array. We # could do better here by making the copies explicit when necessary. @@ -1652,7 +1926,11 @@ def cleaned_instructions(code: types.CodeType, safe: bool = False) -> list[Instr # Copy an instructions array, making sure to remap the individual instruction targets. +<<<<<<< HEAD def _clone_instructions(instructions: Sequence[Instruction]) -> list[Instruction]: +======= +def _clone_instructions(instructions): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This is super hot and this is the fastest way to do this (tried copy.copy # and dataclasses.replace). copied = [ @@ -1674,10 +1952,17 @@ def _clone_instructions(instructions: Sequence[Instruction]) -> list[Instruction remap = dict(zip(instructions, copied)) # Handle `None` in the remapper so we don't need an extra `if`. +<<<<<<< HEAD remap[None] = None # type: ignore[index, assignment] for i in copied: i.target = remap[i.target] # type: ignore[index] +======= + remap[None] = None + + for i in copied: + i.target = remap[i.target] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if entry := i.exn_tab_entry: i.exn_tab_entry = InstructionExnTabEntry( remap[entry.start], @@ -1690,12 +1975,17 @@ def _clone_instructions(instructions: Sequence[Instruction]) -> list[Instruction @functools.lru_cache +<<<<<<< HEAD def _cached_cleaned_instructions( code: types.CodeType, safe: bool = False ) -> Sequence[Instruction]: instructions = list(map(convert_instruction, dis.get_instructions(code))) # propagate now in case we remove some instructions propagate_line_nums(instructions) +======= +def _cached_cleaned_instructions(code, safe=False) -> Sequence[Instruction]: + instructions = list(map(convert_instruction, dis.get_instructions(code))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) check_offsets(instructions) if sys.version_info >= (3, 11): populate_kw_names_argval(instructions, code.co_consts) @@ -1713,8 +2003,11 @@ def _cached_cleaned_instructions( remove_binary_store_slice(instructions) if sys.version_info >= (3, 13): remove_fused_load_store(instructions) +<<<<<<< HEAD if config.debug_force_graph_break_on_leaf_return: add_graph_break_if_leaf_instructions(instructions) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if sys.version_info >= (3, 11): update_offsets(instructions) devirtualize_jumps(instructions) @@ -1724,7 +2017,11 @@ def _cached_cleaned_instructions( _unique_id_counter = itertools.count() +<<<<<<< HEAD def unique_id(name: str, with_uuid: bool = False) -> str: +======= +def unique_id(name, with_uuid=False) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ret = f"{name}_{next(_unique_id_counter)}" if with_uuid: ret += f"_{uuid.uuid4()}".replace("-", "_") @@ -1736,12 +2033,16 @@ def is_generator(code: types.CodeType) -> bool: return (code.co_flags & co_generator) > 0 +<<<<<<< HEAD def bytecode_from_template( fn: Callable[..., Any], varname_map: Optional[Mapping[Any, Any]] = None, noreturn: bool = True, noprefix: bool = True, ) -> list[Instruction]: +======= +def bytecode_from_template(fn, varname_map=None, noreturn=True, noprefix=True): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Generates bytecode from a template function `fn` for use in dynamo bytecode generation. diff --git a/torch/_dynamo/cache_size.py b/torch/_dynamo/cache_size.py index d1a46742f37ac..3b19ab907736a 100644 --- a/torch/_dynamo/cache_size.py +++ b/torch/_dynamo/cache_size.py @@ -1,7 +1,14 @@ +<<<<<<< HEAD import logging import weakref from dataclasses import dataclass from typing import Any, Optional +======= +# mypy: allow-untyped-defs +import logging +import weakref +from dataclasses import dataclass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._guards import CompileId @@ -9,7 +16,11 @@ from .types import DynamoFrameType +<<<<<<< HEAD log: logging.Logger = logging.getLogger(__name__) +======= +log = logging.getLogger(__name__) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ [Note on cache size limit] @@ -99,9 +110,13 @@ def will_compilation_exceed_specific_limit(self, limit: int) -> bool: return self.num_cache_entries_with_same_id_matched_objs >= limit +<<<<<<< HEAD def _get_weakref_from_f_locals( frame: DynamoFrameType, local_name: str ) -> Optional[weakref.ref[Any]]: +======= +def _get_weakref_from_f_locals(frame: DynamoFrameType, local_name: str): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) obj = frame.f_locals.get(local_name, None) weak_id = None try: @@ -111,7 +126,11 @@ def _get_weakref_from_f_locals( return weak_id +<<<<<<< HEAD def _has_same_id_matched_objs(frame: DynamoFrameType, cache_entry: Any) -> bool: +======= +def _has_same_id_matched_objs(frame: DynamoFrameType, cache_entry) -> bool: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Checks if the ID_MATCH'd objects saved on cache_entry are same as the ones in frame.f_locals. @@ -133,7 +152,11 @@ def _has_same_id_matched_objs(frame: DynamoFrameType, cache_entry: Any) -> bool: def compute_cache_size( +<<<<<<< HEAD frame: DynamoFrameType, cache_entry: Any +======= + frame: DynamoFrameType, cache_entry +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> CacheSizeRelevantForFrame: # Walk the linked list to calculate the cache size num_cache_entries = 0 diff --git a/torch/_dynamo/callback.py b/torch/_dynamo/callback.py index 58cfe66baee7a..7b71da93bd6d6 100644 --- a/torch/_dynamo/callback.py +++ b/torch/_dynamo/callback.py @@ -39,7 +39,11 @@ class CallbackTrigger(enum.Enum): # backward compilation can be deferred to runtime LAZY_BACKWARD = 2 # some backends autotune at runtime +<<<<<<< HEAD TRITON_AUTOTUNING = 3 # Temporarily disabled due to spam +======= + TRITON_AUTOTUNING = 3 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # cudagraphs record at runtime CUDAGRAPH_RECORDING = 4 @@ -126,9 +130,15 @@ def install_callbacks( args = CallbackArgs(trigger, compile_id) try: with self.__pending_callbacks_counter_lock: +<<<<<<< HEAD self.__pending_callbacks_counter += 1 if self.__pending_callbacks_counter == 1: self.run_start_callbacks(args) +======= + if self.__pending_callbacks_counter == 0: + self.run_start_callbacks(args) + self.__pending_callbacks_counter += 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) yield finally: with self.__pending_callbacks_counter_lock: diff --git a/torch/_dynamo/codegen.py b/torch/_dynamo/codegen.py index d929e3270f38d..9bf11c7d5f0ba 100644 --- a/torch/_dynamo/codegen.py +++ b/torch/_dynamo/codegen.py @@ -1,3 +1,8 @@ +<<<<<<< HEAD +======= +# mypy: allow-untyped-defs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ This module provides utilities for generating Python bytecode in PyTorch's Dynamo system. It includes functionality for: @@ -16,8 +21,12 @@ import sys import types from collections import Counter +<<<<<<< HEAD from collections.abc import Iterable from typing import Any, Callable, Optional, TYPE_CHECKING, Union +======= +from typing import Optional, TYPE_CHECKING, Union +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch.nn from torch.utils._ordered_set import OrderedSet @@ -54,8 +63,11 @@ if TYPE_CHECKING: +<<<<<<< HEAD from torch._dynamo.variables.builder import GraphArg +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .symbolic_convert import InstructionTranslatorBase @@ -75,8 +87,13 @@ def __init__( tx: "InstructionTranslatorBase", root: Optional[torch.nn.Module] = None, graph_output_var: Optional[str] = None, +<<<<<<< HEAD tempvars: Optional[dict[Union[VariableTracker, Source], Any]] = None, overridden_sources: Optional[dict[Source, Source]] = None, +======= + tempvars=None, + overridden_sources=None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: self.root = root self.top_of_stack: Optional[Union[VariableTracker, Source]] = None @@ -87,7 +104,11 @@ def __init__( # locals, and maps the VariableTracker/Source to the local variable # name. Note that it could map to None initially, in which case we'll # overwrite it to map to real temporary names via `add_cache`. +<<<<<<< HEAD self.tempvars: dict[Union[VariableTracker, Source], Any] = tempvars or {} +======= + self.tempvars = tempvars or {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.tx = tx self.graph_output_var = graph_output_var self.code_options = self.tx.output.code_options @@ -99,9 +120,13 @@ def __init__( # without affecting other components, e.g., guards. self.overridden_sources: dict[Source, Source] = overridden_sources or {} +<<<<<<< HEAD def restore_stack( self, stack_values: list[Any], *, value_from_source: bool = True ) -> None: +======= + def restore_stack(self, stack_values, *, value_from_source=True): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) prev = self.value_from_source self.value_from_source &= value_from_source try: @@ -109,6 +134,7 @@ def restore_stack( finally: self.value_from_source = prev +<<<<<<< HEAD def graph_output_vars(self) -> list[VariableTracker]: return [x.variable for x in self.graph_outputs.values()] @@ -121,6 +147,16 @@ def call_reconstruct( def add_push_null( self, gen_fn: Callable[[], None], call_function_ex: bool = False ) -> None: +======= + def graph_output_vars(self): + return [x.variable for x in self.graph_outputs.values()] + + def call_reconstruct(self, value): + res = value.reconstruct(self) + assert res is None, f"reconstruct!=None {value}" + + def add_push_null(self, gen_fn, call_function_ex=False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ `gen_fn` generates instructions via PyCodegen methods that push a single callable to the stack. @@ -149,9 +185,13 @@ def add_push_null( # NULL will be at top of stack self.clear_tos() +<<<<<<< HEAD def __call__( self, value: Union[VariableTracker, Source], allow_cache: bool = True ) -> None: +======= + def __call__(self, value, allow_cache=True): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Generate code such that top-of-stack (TOS) is set to value. @@ -306,7 +346,11 @@ def __call__( value.as_tensor(self.tx, torch.float64) ) +<<<<<<< HEAD def gen_fn() -> None: +======= + def gen_fn(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.load_graph_output(graph_outputs[graph_outputs_key].index) output.append(self.create_load_attr("item")) @@ -331,7 +375,11 @@ def gen_fn() -> None: output.extend(create_call_function(1, False)) elif isinstance(value, UnspecializedPythonVariable) and value.need_unwrap: +<<<<<<< HEAD def gen_fn() -> None: +======= + def gen_fn(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.load_graph_output(graph_outputs[graph_outputs_key].index) output.append(self.create_load_attr("item")) @@ -372,7 +420,11 @@ def gen_fn() -> None: self.top_of_stack = value +<<<<<<< HEAD def add_graph_output(self, value: VariableTracker) -> int: +======= + def add_graph_output(self, value): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) graph_outputs_key = id(value.as_proxy()) if graph_outputs_key not in self.graph_outputs: self.graph_outputs[graph_outputs_key] = GraphOutputEntry( @@ -380,26 +432,43 @@ def add_graph_output(self, value: VariableTracker) -> int: ) return graph_outputs_key +<<<<<<< HEAD def load_graph_output(self, index: int) -> None: output = self._output assert self.graph_output_var is not None +======= + def load_graph_output(self, index): + output = self._output +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) output.append(self.create_load(self.graph_output_var)) output.append(self.create_load_const(index)) output.append(self.create_binary_subscr()) +<<<<<<< HEAD def add_cache(self, value: Union[VariableTracker, Source]) -> None: +======= + def add_cache(self, value): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) var = self.new_var() self.tempvars[value] = var self._output.append(self.create_store(var)) +<<<<<<< HEAD def foreach(self, items: Iterable[Union[VariableTracker, Source]]) -> None: +======= + def foreach(self, items): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for i in items: self(i) def create_binary_subscr(self) -> Instruction: return create_instruction("BINARY_SUBSCR") +<<<<<<< HEAD def setup_globally_cached(self, name: str, value: Any) -> list[Instruction]: +======= + def setup_globally_cached(self, name, value): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Store value in a new global""" name = re.sub(r"[^a-zA-Z0-9_]+", "_", name) f_globals = self.tx.f_globals @@ -409,15 +478,26 @@ def setup_globally_cached(self, name: str, value: Any) -> list[Instruction]: f_globals[name] = value return [self.create_load_global(name, add=True)] +<<<<<<< HEAD def clear_tos(self) -> None: self.top_of_stack = None def append_output(self, inst: Instruction) -> None: +======= + def clear_tos(self): + self.top_of_stack = None + + def append_output(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(inst, Instruction) self._output.append(inst) self.clear_tos() +<<<<<<< HEAD def extend_output(self, insts: list[Instruction]) -> None: +======= + def extend_output(self, insts): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert all(isinstance(x, Instruction) for x in insts) self._output.extend(insts) self.clear_tos() @@ -425,15 +505,24 @@ def extend_output(self, insts: list[Instruction]) -> None: def get_instructions(self) -> list[Instruction]: return self._output +<<<<<<< HEAD def create_load(self, name: str) -> Instruction: assert name in self.code_options["co_varnames"], f"{name} missing" return create_instruction("LOAD_FAST", argval=name) def create_load_closure(self, name: str) -> Instruction: +======= + def create_load(self, name) -> Instruction: + assert name in self.code_options["co_varnames"], f"{name} missing" + return create_instruction("LOAD_FAST", argval=name) + + def create_load_closure(self, name) -> Instruction: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert name in self.cell_and_freevars() inst_name = "LOAD_FAST" if sys.version_info >= (3, 13) else "LOAD_CLOSURE" return create_instruction(inst_name, argval=name) +<<<<<<< HEAD def create_load_deref(self, name: str) -> Instruction: assert name in self.cell_and_freevars() return create_instruction("LOAD_DEREF", argval=name) @@ -447,11 +536,27 @@ def create_store_deref(self, name: str) -> Instruction: return create_instruction("STORE_DEREF", argval=name) def create_load_global(self, name: str, add: bool = False) -> Instruction: +======= + def create_load_deref(self, name) -> Instruction: + assert name in self.cell_and_freevars() + return create_instruction("LOAD_DEREF", argval=name) + + def create_store(self, name) -> Instruction: + assert name in self.code_options["co_varnames"], f"{name} missing" + return create_instruction("STORE_FAST", argval=name) + + def create_store_deref(self, name) -> Instruction: + assert name in self.cell_and_freevars() + return create_instruction("STORE_DEREF", argval=name) + + def create_load_global(self, name, add=False) -> Instruction: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if add: self.tx.output.update_co_names(name) assert name in self.code_options["co_names"], f"{name} not in co_names" return create_instruction("LOAD_GLOBAL", argval=name) +<<<<<<< HEAD def create_load_const(self, value: Any) -> Instruction: return create_load_const(value) @@ -466,10 +571,27 @@ def call_method(self, nargs: int) -> None: self.extend_output(create_call_method(nargs)) def create_load_attr(self, name: str) -> Instruction: +======= + def create_load_const(self, value) -> Instruction: + return create_load_const(value) + + def create_load_const_unchecked(self, value) -> Instruction: + return create_load_const(value, checked=False) + + def load_method(self, name): + self.tx.output.update_co_names(name) + self.append_output(create_load_method(name)) + + def call_method(self, nargs): + self.extend_output(create_call_method(nargs)) + + def create_load_attr(self, name) -> Instruction: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if name not in self.code_options["co_names"]: self.code_options["co_names"] += (name,) return create_instruction("LOAD_ATTR", argval=name) +<<<<<<< HEAD def load_attr(self, name: str) -> None: self.append_output(self.create_load_attr(name)) @@ -477,16 +599,32 @@ def create_load_attrs(self, names: str) -> list[Instruction]: return [self.create_load_attr(name) for name in names.split(".")] def create_store_attr(self, name: str) -> Instruction: +======= + def load_attr(self, name): + self.append_output(self.create_load_attr(name)) + + def create_load_attrs(self, names): + return [self.create_load_attr(name) for name in names.split(".")] + + def create_store_attr(self, name) -> Instruction: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if name not in self.code_options["co_names"]: self.code_options["co_names"] += (name,) return create_instruction("STORE_ATTR", argval=name) +<<<<<<< HEAD def store_attr(self, name: str) -> None: self.append_output(self.create_store_attr(name)) def load_function_name( self, fn_name: str, push_null: bool, num_on_stack: int = 0 ) -> list[Instruction]: +======= + def store_attr(self, name): + self.append_output(self.create_store_attr(name)) + + def load_function_name(self, fn_name, push_null, num_on_stack=0): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Load the global fn_name on the stack num_on_stack down""" output = [] if push_null and sys.version_info >= (3, 11): @@ -507,7 +645,11 @@ def load_function_name( ) return output +<<<<<<< HEAD def rot_n(self, n: int) -> list[Instruction]: +======= + def rot_n(self, n): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: return create_rot_n(n) except AttributeError: @@ -520,6 +662,7 @@ def rot_n(self, n: int) -> list[Instruction]: create_instruction("UNPACK_SEQUENCE", arg=n), ] +<<<<<<< HEAD def pop_top(self) -> None: self.append_output(create_instruction("POP_TOP")) @@ -543,17 +686,42 @@ def make_function_with_closure( push_null: bool, num_on_stack: int = 0, ) -> None: +======= + def pop_top(self): + self.append_output(create_instruction("POP_TOP")) + + def call_function(self, nargs: int, push_null: bool): + self.extend_output(create_call_function(nargs, push_null=push_null)) + + def dup_top(self): + self.append_output(create_dup_top()) + + def store(self, varname): + self.append_output(self.create_store(varname)) + + def load_deref(self, varname): + self.append_output(self.create_load_deref(varname)) + + def make_function_with_closure( + self, fn_name: str, code: types.CodeType, push_null: bool, num_on_stack=0 + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) freevars = code.co_freevars assert freevars output = self._output +<<<<<<< HEAD def gen_fn() -> None: self.clear_tos() +======= + def gen_fn(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Emitting `LOAD_FAST/LOAD_CLOSURE` with names in `co_freevars` # requires that in the generated bytecode, these cells would keep # their original local names, which we ensure via # `CellVariable.local_name`. for var in freevars: +<<<<<<< HEAD if tx is self.tx: # root frame assert var in self.cell_and_freevars() output.append(self.create_load_closure(var)) @@ -561,6 +729,10 @@ def gen_fn() -> None: assert var in tx.cell_and_freevars() assert tx.post_prune_cell_and_freevars self(tx.post_prune_cell_and_freevars[var]) +======= + assert var in self.cell_and_freevars() + output.append(self.create_load_closure(var)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) output.append(create_instruction("BUILD_TUPLE", arg=len(freevars))) output.append(self.create_load_const(code)) if sys.version_info < (3, 11): @@ -584,7 +756,11 @@ def gen_fn() -> None: output.extend(self.rot_n(num_on_stack + 1)) self.clear_tos() +<<<<<<< HEAD def create_load_python_module(self, mod: types.ModuleType) -> Instruction: +======= + def create_load_python_module(self, mod) -> Instruction: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Generate a LOAD_GLOBAL instruction to fetch a given python module. """ @@ -612,7 +788,11 @@ def make_call_generated_code(self, fn_name: str) -> None: seen_sources: OrderedSet[Source] = OrderedSet() +<<<<<<< HEAD def collect_temp_source(source: Source) -> None: +======= + def collect_temp_source(source): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if source in seen_sources: # This source is used at least twice, so it can be reused self.mark_source_temp(source) @@ -678,10 +858,14 @@ def collect_temp_source(source: Source) -> None: self.extend_output(create_call_function(len(graphargs), False)) +<<<<<<< HEAD def create_import_name(self, module_name: str) -> Instruction: return create_instruction("IMPORT_NAME", argval=module_name) def load_import_from(self, module_name: str, object_name: str) -> None: +======= + def load_import_from(self, module_name, object_name) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) source = AttrSource(self.tx.import_source(module_name), object_name) # Note: This approach is somewhat aggressive because typically, a source is marked # as a tempvar only when it is used more than once. In this case, we're marking it @@ -690,9 +874,13 @@ def load_import_from(self, module_name: str, object_name: str) -> None: self.mark_source_temp(source) self(source) +<<<<<<< HEAD def create_call_function_kw( self, nargs: int, kw_names: Iterable[str], push_null: bool ) -> list[Instruction]: +======= + def create_call_function_kw(self, nargs, kw_names, push_null) -> list[Instruction]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if sys.version_info >= (3, 13): output = create_call_function(nargs, push_null) assert output[-1].opname == "CALL" @@ -716,5 +904,9 @@ def create_call_function_kw( create_instruction("CALL_FUNCTION_KW", arg=nargs), ] +<<<<<<< HEAD def create_delete(self, value: object) -> Instruction: +======= + def create_delete(self, value) -> Instruction: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return create_instruction("DELETE_FAST", argval=value) diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py index 84145d64f38a4..1fcf75b15342b 100644 --- a/torch/_dynamo/compiled_autograd.py +++ b/torch/_dynamo/compiled_autograd.py @@ -1,3 +1,8 @@ +<<<<<<< HEAD +======= +# mypy: allow-untyped-defs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Provides functionality for compiling PyTorch's autograd (automatic differentiation) system. @@ -20,12 +25,19 @@ import operator import time from collections import Counter, defaultdict +<<<<<<< HEAD from collections.abc import Generator, Sequence from typing import Any, Callable, Optional, TYPE_CHECKING, Union import torch import torch.utils._pytree as pytree from torch._dispatch.python import enable_python_dispatcher +======= +from typing import Optional, TYPE_CHECKING, Union + +import torch +import torch.utils._pytree as pytree +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._dynamo.external_utils import ( call_accumulate_grad, call_backward, @@ -44,11 +56,18 @@ AutogradLazyBackwardCompileInfo, CachedAutogradLazyBackwardCompileInfo, ) +<<<<<<< HEAD from torch._guards import compile_context, CompileContext, CompileId, Source from torch._logging import getArtifactLogger, trace_structured from torch._prims_common import clone_preserve_strides from torch._subclasses import FakeTensorMode from torch._subclasses.fake_tensor import FakeTensor +======= +from torch._guards import compile_context, CompileContext, CompileId +from torch._logging import getArtifactLogger, trace_structured +from torch._prims_common import clone_preserve_strides +from torch._subclasses import FakeTensorMode +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.fx import GraphModule from torch.fx.experimental._backward_state import BackwardState from torch.fx.experimental.proxy_tensor import ( @@ -62,7 +81,10 @@ ) from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv from torch.fx.traceback import preserve_node_meta, set_stack_trace +<<<<<<< HEAD from torch.types import FloatLikeType, IntLikeType +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.utils._ordered_set import OrderedSet from torch.utils._traceback import CapturedTraceback @@ -81,23 +103,39 @@ verbose_log = getArtifactLogger(__name__, "compiled_autograd_verbose") +<<<<<<< HEAD def snapshot_verbose_logging_enabled() -> bool: +======= +def snapshot_verbose_logging_enabled(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return torch._logging._internal.log_state.is_artifact_enabled( "compiled_autograd_verbose" ) +<<<<<<< HEAD def snapshot_cudagraph_enabled() -> bool: return torch._inductor.config.triton.cudagraphs def maybe_clone(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]: +======= +def snapshot_cudagraph_enabled(): + return torch._inductor.config.triton.cudagraphs + + +def maybe_clone(x): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if x is not None: return clone_preserve_strides(x) return x +<<<<<<< HEAD def extract_bw_module(CompiledFunction: Any) -> Callable[..., Any]: +======= +def extract_bw_module(CompiledFunction): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance( CompiledFunction._lazy_backward_info, AutogradLazyBackwardCompileInfo ): @@ -125,13 +163,21 @@ def extract_bw_module(CompiledFunction: Any) -> Callable[..., Any]: # So different semantics are needed, this implementation below will check # for NaNs at the end of the autograd call, instead of after each node class NaNChecker: +<<<<<<< HEAD def __init__(self, accumulate_grad: bool) -> None: +======= + def __init__(self, accumulate_grad: bool): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.accumulate_grad = accumulate_grad self.params_indices: list[int] = [] self.params_to_check: dict[str, torch.Tensor] = {} self.output_names: list[str] = [] +<<<<<<< HEAD def prep_with_graph(self, graph: torch.fx.Graph) -> None: +======= + def prep_with_graph(self, graph: torch.fx.Graph): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inputs_node = next(iter(graph.nodes)) acc_grad_nodes = graph.find_nodes( op="call_function", target=call_accumulate_grad @@ -155,7 +201,11 @@ def prep_with_graph(self, graph: torch.fx.Graph) -> None: self.output_names = [node.name for node in output_nodes] +<<<<<<< HEAD def prep_with_inputs(self, inputs: tuple[torch.Tensor]) -> None: +======= + def prep_with_inputs(self, inputs: tuple[torch.Tensor]): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not self.accumulate_grad: # Using .grad, nothing to prep return @@ -171,7 +221,11 @@ def prep_with_inputs(self, inputs: tuple[torch.Tensor]) -> None: self.params_to_check[f"inputs[{idx}]"] = inputs[idx] +<<<<<<< HEAD def check(self, out: tuple[torch.Tensor]) -> None: +======= + def check(self, out: tuple[torch.Tensor]): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.accumulate_grad: # Using .backward, graph outputs are empty assert not out @@ -204,6 +258,7 @@ def check(self, out: tuple[torch.Tensor]) -> None: # function is called. It's possible to avoid lazy binding and instead bind # all of this upfront (perhaps at import time) via codegen changes. class OpNamespace: +<<<<<<< HEAD def __init__(self) -> None: self.custom_function_name_counter: Counter[str] = Counter() @@ -214,6 +269,12 @@ def add( is_custom_function: bool, is_traceable: bool, ) -> str: +======= + def __init__(self): + self.custom_function_name_counter: Counter[str] = Counter() + + def add(self, name, fn, is_custom_function, is_traceable): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if is_custom_function: name = "CppNode" + name count = self.custom_function_name_counter[name] @@ -227,30 +288,50 @@ def add( else: # C++ autograd function was not marked as traceable # Dynamo can't dry run it at compile time, so must fallback to eager +<<<<<<< HEAD @torch._dynamo.disable # type: ignore[misc] def run_non_traceable_cpp_in_eager(*args: Any, **kwargs: Any) -> Any: +======= + @torch._dynamo.disable + def run_non_traceable_cpp_in_eager(*args, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return result(*args, **kwargs) setattr(self, name, run_non_traceable_cpp_in_eager) return name +<<<<<<< HEAD def get(self, name: str) -> Any: +======= + def get(self, name): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return getattr(self, name) class Op: +<<<<<<< HEAD def __init__( self, name: str, fn: Callable[..., Any], is_custom_function: bool ) -> None: +======= + def __init__(self, name, fn, is_custom_function): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.fn = fn self.is_custom_function = is_custom_function self.__name__ = name self.__module__ = "torch._dynamo.compiled_autograd.ops" +<<<<<<< HEAD def __call__(self, *args: Any, **kwargs: Any) -> Any: return self.fn(*args, **kwargs) def __repr__(self) -> str: +======= + def __call__(self, *args, **kwargs): + return self.fn(*args, **kwargs) + + def __repr__(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.__module__ + "." + self.__name__ @@ -270,7 +351,11 @@ def __repr__(self) -> str: COMPILE_COUNTER = itertools.count() +<<<<<<< HEAD def make_compile_context(compiled_autograd_id: int) -> Any: +======= +def make_compile_context(compiled_autograd_id): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return compile_context( CompileContext( CompileId( @@ -283,7 +368,11 @@ def make_compile_context(compiled_autograd_id: int) -> Any: class AutogradCompilerInstance: +<<<<<<< HEAD def __init__(self, compiler_fn: Callable[..., Any]) -> None: +======= + def __init__(self, compiler_fn) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.compiler_fn = compiler_fn self.stack = contextlib.ExitStack() self.close = self.stack.close @@ -297,12 +386,20 @@ def __init__(self, compiler_fn: Callable[..., Any]) -> None: self.proxy_mode = ProxyTorchDispatchMode(self.fx_tracer, "symbolic") self.hooks_proxy: Optional[Proxy] = None +<<<<<<< HEAD def wrap_fake(self, x: torch.Tensor, source: Optional[Source]) -> FakeTensor: +======= + def wrap_fake(self, x, source): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(x, torch.Tensor) return self.fake_tensor_mode.from_tensor(x, source=source) @staticmethod +<<<<<<< HEAD def source(name: str, idx: Any) -> GetItemSource: +======= + def source(name, idx) -> GetItemSource: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return GetItemSource(LocalSource(name), idx) def begin_capture( @@ -313,7 +410,11 @@ def begin_capture( origins: list[list[tuple[int, str]]], accumulate_grad: bool, check_nans: bool, +<<<<<<< HEAD ) -> tuple[str, list[torch.Tensor], list[IntLikeType], list[FloatLikeType]]: +======= + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) counters["compiled_autograd"]["captures"] += 1 self.id = next(COMPILE_COUNTER) self.aot_id_counter: dict[int, int] = defaultdict(int) @@ -345,10 +446,13 @@ def begin_capture( self.stack.enter_context(preserve_node_meta()) inputs_origins, sizes_origins, scalars_origins = origins +<<<<<<< HEAD # Turn on PythonDispatcher during initial trace to make it identifiable # that tracing is happening, which is needed to prevent hashing symints self.stack.enter_context(enable_python_dispatcher()) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # tensor inputs to fake tensors x = inputs[0] # mypy will complain about unbound x try: @@ -361,7 +465,11 @@ def begin_capture( self.bind_objects_to_proxies(inputs, args_proxy, inputs_origins) # size inputs to symints +<<<<<<< HEAD sym_sizes = [ +======= + sizes = [ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.shape_env.create_unspecified_symint_and_symbol( val, self.source("sizes", idx), @@ -373,8 +481,13 @@ def begin_capture( # We want to mark every size as dynamic, but since there's no way to # mark a primitive `int` as dynamic, we need to wrap it in a tensor. # In the graph, we unwrap it with `unwrap_maybe_dynamic_int` back into a primitive. +<<<<<<< HEAD proxies = [self.sizes_proxy[i] for i in range(len(sym_sizes))] # type: ignore[index] for i, symint in enumerate(sym_sizes): +======= + proxies = [self.sizes_proxy[i] for i in range(len(sizes))] # type: ignore[index] + for i, symint in enumerate(sizes): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) proxies[i] = self.fx_tracer.create_proxy( "call_function", unwrap_maybe_dynamic_int, @@ -382,7 +495,11 @@ def begin_capture( {}, ) self.symnode_proxy_lookup[symint.node] = proxies[i] +<<<<<<< HEAD proxies = self.bind_objects_to_proxies(sym_sizes, proxies, sizes_origins) +======= + proxies = self.bind_objects_to_proxies(sizes, proxies, sizes_origins) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for idx, val in enumerate(scalars): source = self.source("scalars", idx) @@ -422,14 +539,23 @@ def begin_capture( return ( str(CompileContext.current_compile_id()), inputs, +<<<<<<< HEAD sym_sizes, scalars, # type: ignore[return-value] +======= + sizes, + scalars, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def log_compile_reasons( self, compile_reasons: list[str], +<<<<<<< HEAD ) -> None: +======= + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert compile_reasons trace_structured( "artifact", @@ -442,6 +568,7 @@ def log_compile_reasons( def proxy_call_aot_backward( self, +<<<<<<< HEAD pinputs: Sequence[Any], psaved_tensors: Sequence[torch.Tensor], saved_tensors: Sequence[torch.Tensor], @@ -449,6 +576,15 @@ def proxy_call_aot_backward( ctx: Any, maybe_backward_state_idx: Optional[int], ) -> Sequence[Any]: +======= + pinputs, + psaved_tensors, + saved_tensors, + pctx, + ctx, + maybe_backward_state_idx, + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # The AOTBackward call consists of three things: the prologue, the # backward graph, and the epilogue. # Our strategy is: @@ -478,11 +614,15 @@ def proxy_call_aot_backward( ) @torch._dynamo.allow_in_graph # type: ignore[misc] +<<<<<<< HEAD def call_aot_bwd_prologue( ctx_saved_tensors: Sequence[torch.Tensor], ctx_symints: Sequence[IntLikeType], *flat_args: Sequence[Any], ) -> Any: +======= + def call_aot_bwd_prologue(ctx_saved_tensors, ctx_symints, *flat_args): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out = torch._functorch._aot_autograd.runtime_wrappers._backward_prologue_functional( ctx_saved_tensors, ctx_symints, @@ -508,8 +648,13 @@ def call_aot_bwd_prologue( pbackward_state = self.hooks_proxy[maybe_backward_state_idx] # type: ignore[index] # Copy-paste the AOT backward graph into the compiled autograd graph +<<<<<<< HEAD def copy_paste_aot_backward_graph() -> list[torch.Tensor]: def num_inputs(graph: torch.fx.Graph) -> int: +======= + def copy_paste_aot_backward_graph(): + def num_inputs(graph): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) num_args = 0 for node in graph.nodes: if node.op == "placeholder": @@ -521,7 +666,11 @@ def num_inputs(graph: torch.fx.Graph) -> int: # set up the proxy inputs to bw_module # the calling convention is: [*symints, *args (primals and tangents), backward_state] +<<<<<<< HEAD num_args = num_inputs(bw_module.graph) # type: ignore[attr-defined] +======= + num_args = num_inputs(bw_module.graph) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pall_args = [ pgrads[i] for i in range(num_args - int(pbackward_state is not None)) ] @@ -547,11 +696,19 @@ def num_inputs(graph: torch.fx.Graph) -> int: deduped_aot_id += f"_{self.aot_id_counter[aot_id]}" self.aot_id_counter[aot_id] += 1 +<<<<<<< HEAD def make_unique(node_name: str) -> str: # make it both informative and unique return f"aot{deduped_aot_id}_{node_name}" for node in bw_module.graph.nodes: # type: ignore[attr-defined] +======= + def make_unique(node_name): + # make it both informative and unique + return f"aot{deduped_aot_id}_{node_name}" + + for node in bw_module.graph.nodes: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if node.op == "placeholder": ph = pall_args[args_idx].node ph.name = make_unique(node.name) @@ -599,7 +756,11 @@ def make_unique(node_name: str) -> str: # In general we don't know what the shapes of the outputs are, so allocate # some dummy sizes for them. +<<<<<<< HEAD def dummy() -> torch.Tensor: +======= + def dummy(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with disable_proxy_modes_tracing(): return torch.zeros(0, 0, 0, 0, 123) @@ -611,11 +772,17 @@ def dummy() -> torch.Tensor: outputs = copy_paste_aot_backward_graph() +<<<<<<< HEAD def proxy_subclass_constructor( subclass_meta: Any, is_runtime: bool, unwrapped_args: Sequence[Any] ) -> torch.Tensor: @torch._dynamo.allow_in_graph # type: ignore[misc] def make_subclass(*unwrapped_args: Any) -> Any: +======= + def proxy_subclass_constructor(subclass_meta, is_runtime, unwrapped_args): + @torch._dynamo.allow_in_graph + def make_subclass(*unwrapped_args): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return subclass_meta.creation_fn(unwrapped_args, is_runtime=is_runtime) punwrapped_args = pytree.tree_map(self.to_proxy, unwrapped_args) @@ -642,6 +809,7 @@ def make_subclass(*unwrapped_args: Any) -> Any: def proxy_call_backward( self, +<<<<<<< HEAD inputs: Sequence[Any], output_metadatas: Sequence[Any], saved_tensors: Sequence[torch.Tensor], @@ -649,6 +817,15 @@ def proxy_call_backward( ctx: torch.autograd.function.BackwardCFunction, maybe_backward_state_idx: Optional[int], ) -> tuple[Optional[torch.Tensor], ...]: +======= + inputs, + output_metadatas, + saved_tensors, + backward_idx: int, + ctx: torch.autograd.function.BackwardCFunction, + maybe_backward_state_idx: Optional[int], + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert self.hooks_proxy is not None pctx = self.hooks_proxy[backward_idx] # type: ignore[index] pinputs = self.to_proxy(inputs) @@ -693,6 +870,7 @@ def proxy_call_backward( def call_copy_slices_prologue( self, +<<<<<<< HEAD inputs: Sequence[Any], base_sizes: Sequence[Any], base_strides: Sequence[Any], @@ -701,6 +879,16 @@ def call_copy_slices_prologue( view_strides: Sequence[Any], view_storage_offset: Any, ) -> Sequence[torch.Tensor]: +======= + inputs, + base_sizes, + base_strides, + base_storage_offset, + view_sizes, + view_strides, + view_storage_offset, + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) args = ( inputs, self.to_proxy(base_sizes), @@ -712,6 +900,7 @@ def call_copy_slices_prologue( ) return self.proxy_call(copy_slices_prologue, args, [None] * 3) +<<<<<<< HEAD def call_copy_slices_epilogue( self, needs_input_grad: Sequence[bool], @@ -719,17 +908,25 @@ def call_copy_slices_epilogue( res: Sequence[Any], grad_slice: torch.Tensor, ) -> Sequence[torch.Tensor]: +======= + def call_copy_slices_epilogue(self, needs_input_grad, result, res, grad_slice): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.proxy_call( copy_slices_epilogue, (needs_input_grad, result, res, grad_slice), [None] * len(needs_input_grad), ) +<<<<<<< HEAD def allocate_dummy(self) -> torch.Tensor: +======= + def allocate_dummy(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with disable_proxy_modes_tracing(): # Weird quantity so it's easy to grep return torch.zeros([0, 123456789]) +<<<<<<< HEAD def bind_function( self, fn_name: str, @@ -747,13 +944,24 @@ def apply_functional( args: Any, output_metadata: Sequence[Any], ) -> Sequence[torch.Tensor]: +======= + def bind_function(self, fn_name, fn, is_custom_function, is_traceable): + """Binds ops.fn_name = fn""" + return ops.add(fn_name, fn, is_custom_function, is_traceable) + + def apply_functional(self, fn_name, grads, args, output_metadata): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Proxies a call to ops.fn_name(grads, *args) into the graph""" op = ops.get(fn_name) return self.proxy_call(op, (grads, *args), output_metadata) +<<<<<<< HEAD def proxy_call( self, fn: Callable[..., Any], args: Any, output_metadata: Sequence[Any] ) -> Sequence[torch.Tensor]: +======= + def proxy_call(self, fn, args, output_metadata): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Proxies a call to fn(*args) into the graph""" flat_args, _ = pytree.tree_flatten(args) proxy_args = pytree.tree_map(lambda e: self.to_proxy(e), args) @@ -764,9 +972,13 @@ def proxy_call( self.bind_objects_to_proxies(result, [proxy_out[i] for i in range(len(result))]) return result +<<<<<<< HEAD def validate_outputs( self, _: Any, outputs: Sequence[Any], args: Any, output_metadata: Sequence[Any] ) -> Sequence[torch.Tensor]: +======= + def validate_outputs(self, _, outputs, args, output_metadata): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Proxies a call to ops.validate_outputs(outputs, *args) into the graph""" op = ops.get("validate_outputs") proxy_args = pytree.tree_map(self.to_proxy, (outputs, *args)) @@ -777,7 +989,11 @@ def validate_outputs( self.bind_objects_to_proxies(outputs, new_proxy_outputs) return outputs +<<<<<<< HEAD def accumulate(self, old_var: Any, new_var: Any) -> torch.Tensor: +======= + def accumulate(self, old_var, new_var): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) old_var_proxy = self.to_proxy(old_var) new_var_proxy = self.to_proxy(new_var) proxy_out = self.fx_tracer.create_proxy( @@ -787,9 +1003,13 @@ def accumulate(self, old_var: Any, new_var: Any) -> torch.Tensor: self.bind_objects_to_proxies([result], [proxy_out]) return result +<<<<<<< HEAD def accumulate_grad( self, variable: torch.Tensor, grad: torch.Tensor, has_post_hooks: bool ) -> None: +======= + def accumulate_grad(self, variable, grad, has_post_hooks): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.fx_tracer.create_proxy( "call_function", call_accumulate_grad, @@ -801,9 +1021,13 @@ def accumulate_grad( kwargs={}, ) +<<<<<<< HEAD def proxy_call_hook( self, hook: Callable[..., Any], *args: Any, **kwargs: Any ) -> torch.fx.Proxy: +======= + def proxy_call_hook(self, hook, *args, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.fx_tracer.create_proxy( "call_function", call_hook, @@ -814,7 +1038,11 @@ def proxy_call_hook( kwargs, ) +<<<<<<< HEAD def unpack_hook(self, hook_id: int, data_id: int) -> torch.Tensor: +======= + def unpack_hook(self, hook_id, data_id): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert self.hooks_proxy is not None hook = self.hooks_proxy[hook_id] # type: ignore[index] data = self.packed_data_proxy[data_id] # type: ignore[index] @@ -827,9 +1055,13 @@ def unpack_hook(self, hook_id: int, data_id: int) -> torch.Tensor: self.bind_objects_to_proxies([out], [proxy]) return out +<<<<<<< HEAD def tensor_pre_hook( self, inputs: list[torch.Tensor], hook_id: int, i: int ) -> list[torch.Tensor]: +======= + def tensor_pre_hook(self, inputs, hook_id, i: int): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert self.hooks_proxy is not None hook = self.hooks_proxy[hook_id] # type: ignore[index] proxy = self.proxy_call_hook( @@ -838,6 +1070,7 @@ def tensor_pre_hook( hook_type="tensor_pre_hook", ) with disable_proxy_modes_tracing(): +<<<<<<< HEAD inputs[i] = maybe_clone(inputs[i]) # type: ignore[assignment] self.bind_objects_to_proxies([inputs[i]], [proxy]) return inputs @@ -845,6 +1078,13 @@ def tensor_pre_hook( def cpp_tensor_pre_hook( self, inputs: list[torch.Tensor], hook_id: int, i: int ) -> list[torch.Tensor]: +======= + inputs[i] = maybe_clone(inputs[i]) + self.bind_objects_to_proxies([inputs[i]], [proxy]) + return inputs + + def cpp_tensor_pre_hook(self, inputs: list[torch.Tensor], hook_id: int, i: int): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) proxy = self.fx_tracer.create_proxy( "call_function", torch._C._dynamo.compiled_autograd.call_cpp_tensor_pre_hooks, @@ -852,11 +1092,19 @@ def cpp_tensor_pre_hook( {}, ) with disable_proxy_modes_tracing(): +<<<<<<< HEAD inputs[i] = maybe_clone(inputs[i]) # type: ignore[assignment] self.bind_objects_to_proxies([inputs[i]], [proxy]) return inputs def pre_hook(self, inputs: Sequence[Any], hook_id: int) -> list[torch.Tensor]: +======= + inputs[i] = maybe_clone(inputs[i]) + self.bind_objects_to_proxies([inputs[i]], [proxy]) + return inputs + + def pre_hook(self, inputs, hook_id): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert self.hooks_proxy is not None hook = self.hooks_proxy[hook_id] # type: ignore[index] proxies = self.proxy_call_hook( @@ -869,9 +1117,13 @@ def pre_hook(self, inputs: Sequence[Any], hook_id: int) -> list[torch.Tensor]: self.bind_objects_to_proxies(inputs, proxies) return inputs +<<<<<<< HEAD def post_hook( self, outputs: list[torch.Tensor], inputs: Sequence[torch.Tensor], hook_id: int ) -> list[torch.Tensor]: +======= + def post_hook(self, outputs, inputs, hook_id): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert self.hooks_proxy is not None hook = self.hooks_proxy[hook_id] # type: ignore[index] proxies = self.proxy_call_hook( @@ -881,6 +1133,7 @@ def post_hook( hook_type="post_hook", ) with disable_proxy_modes_tracing(): +<<<<<<< HEAD outputs = [maybe_clone(x) for x in outputs] # type: ignore[misc] self.bind_objects_to_proxies(outputs, proxies) return outputs @@ -888,6 +1141,13 @@ def post_hook( def post_acc_grad_hook( self, input: torch.Tensor, hook_id: int ) -> list[torch.Tensor]: +======= + outputs = [maybe_clone(x) for x in outputs] + self.bind_objects_to_proxies(outputs, proxies) + return outputs + + def post_acc_grad_hook(self, input, hook_id): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(input, torch.Tensor) assert self.hooks_proxy is not None hook = self.hooks_proxy[hook_id] # type: ignore[index] @@ -897,16 +1157,26 @@ def post_acc_grad_hook( hook_type="post_acc_grad_hook", ) with disable_proxy_modes_tracing(): +<<<<<<< HEAD res = [maybe_clone(input)] self.bind_objects_to_proxies(res, [proxy]) return res # type: ignore[return-value] +======= + input = [maybe_clone(input)] + self.bind_objects_to_proxies(input, [proxy]) + return input +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Note: [Compiled autograd and cudagraphs] # Eager autograd backward implements scalars as 0-dim tensors, see DivBackward0::other_. # When compiled autograd traces those nodes, it lifts the scalar tensors, resulting in a graph # with some cpu 0-dim tensor inputs. To prevent the entire graph from skipping cudagraph, we move the # scalars tensors to cuda. This works because ATen/prims ops will accept cuda 0-dim tensors too. +<<<<<<< HEAD def move_graph_nodes_to_cuda(self, graph: torch.fx.Graph) -> list[int]: +======= + def move_graph_nodes_to_cuda(self, graph) -> list[int]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) to_move: dict[int, torch.fx.Node] = {} has_cuda_inputs = False nodes = list(graph.nodes) @@ -955,7 +1225,11 @@ def move_graph_nodes_to_cuda(self, graph: torch.fx.Graph) -> list[int]: return [] +<<<<<<< HEAD def is_sym_node(self, node: Any) -> bool: +======= + def is_sym_node(self, node): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ( isinstance(node, torch.fx.Node) and node.op == "call_function" @@ -963,7 +1237,11 @@ def is_sym_node(self, node: Any) -> bool: in [torch.ops.aten.sym_size.int, torch.ops.aten.sym_numel.default] ) +<<<<<<< HEAD def dce(self) -> None: +======= + def dce(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Most of these removed nodes would have been removed during Dynamo and AOTDispatch # Remove some of these nodes earlier to improve compilation speed @@ -973,7 +1251,11 @@ def dce(self) -> None: unpack_nodes.update(node.users.keys()) assert i == len(_graph_placeholders) - 1 +<<<<<<< HEAD def is_impure(node: torch.fx.Node) -> bool: +======= + def is_impure(node): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if node in unpack_nodes or ( node.op == "call_function" and node.target in _impure_targets ): @@ -985,7 +1267,11 @@ def is_impure(node: torch.fx.Node) -> bool: after = len(self.fx_tracer.graph.nodes) verbose_log.debug("DCE removed %d nodes", before - after) +<<<<<<< HEAD def remove_unused_sizes(self) -> set[int]: +======= + def remove_unused_sizes(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) used_sizes = [] unused_sizes = [] @@ -1019,10 +1305,17 @@ def remove_unused_sizes(self) -> set[int]: return used_sizes_idx +<<<<<<< HEAD def create_graph_module(self, id: str) -> GraphModule: return GraphModule(self.fx_tracer.root, self.fx_tracer.graph, id) def end_capture(self, outputs: Any) -> tuple[Callable[..., Any], Any]: +======= + def create_graph_module(self, id): + return GraphModule(self.fx_tracer.root, self.fx_tracer.graph, id) + + def end_capture(self, outputs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.fx_tracer.create_proxy( "call_function", FakeCompiledAutogradEngine._exec_final_callbacks_stub, @@ -1100,6 +1393,7 @@ def end_capture(self, outputs: Any) -> tuple[Callable[..., Any], Any]: payload_fn=lambda: graph.print_readable(print_output=False), ) +<<<<<<< HEAD def runtime_wrapper( compiled_fn: Callable[..., Any], inputs: Any, @@ -1108,6 +1402,9 @@ def runtime_wrapper( hooks: Any, packed_inputs: Any, ) -> tuple[Any, Any]: +======= + def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks, packed_inputs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) global in_compiled_autograd_region try: in_compiled_autograd_region = True @@ -1149,22 +1446,38 @@ def runtime_wrapper( return runtime_wrapper, self.compiler_fn(graph) @staticmethod +<<<<<<< HEAD def get_all_nodes(args: Sequence[Any]) -> list[torch.fx.Node]: +======= + def get_all_nodes(args): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # filter out non-Node args, like None nodes = [n for n in args if type(n) is torch.fx.Node] return nodes @staticmethod +<<<<<<< HEAD def is_placeholder(node: torch.fx.Node) -> bool: if node.op == "placeholder" or ( node.op == "call_function" and node.target == operator.getitem and node.args[0].op == "placeholder" # type: ignore[union-attr, arg-type] +======= + def is_placeholder(node): + if node.op == "placeholder" or ( + node.op == "call_function" + and node.target == operator.getitem + and node.args[0].op == "placeholder" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): return True return False +<<<<<<< HEAD def reorder_accumulate_grad_nodes(self) -> None: +======= + def reorder_accumulate_grad_nodes(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Usage of AOTAutograd causes all the accumulate_grad_ nodes to get pushed to the end of the graph. This differs from eager mode, which schedules them as soon as possible. This @@ -1185,7 +1498,11 @@ def reorder_accumulate_grad_nodes(self) -> None: if getitem_node is not None: arg.append(getitem_node) +<<<<<<< HEAD def delay_unpack_hook_nodes(self) -> None: +======= + def delay_unpack_hook_nodes(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ We can delay unpack hooks until they are needed, even later than in the eager autograd engine. """ @@ -1198,7 +1515,11 @@ def delay_unpack_hook_nodes(self) -> None: first_user = min(node.users) first_user.prepend(node) +<<<<<<< HEAD def reorder_tensor_pre_hook_nodes(self) -> None: +======= + def reorder_tensor_pre_hook_nodes(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Usage of AOTAutograd causes all the tensor_pre_hook nodes to get pushed to the end of the graph. This differs from eager mode, which schedules @@ -1218,7 +1539,11 @@ def reorder_tensor_pre_hook_nodes(self) -> None: input_node.append(getitem_node) getitem_node.append(node) +<<<<<<< HEAD def reorder_pre_hook_nodes_to_schedule_asap(self) -> None: +======= + def reorder_pre_hook_nodes_to_schedule_asap(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ In this function, we schedule the pre hooks as soon as possible. This does not match eager behavior (schedule pre hook right before its @@ -1246,7 +1571,11 @@ def reorder_pre_hook_nodes_to_schedule_asap(self) -> None: hook_block.append(n) for a, b in zip(to_remove, to_append): input_nodes.remove(a) +<<<<<<< HEAD input_nodes.append(b) # type: ignore[arg-type] +======= + input_nodes.append(b) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) arg = max(input_nodes) # last input if arg is not node.prev and not self.is_placeholder(arg): @@ -1254,7 +1583,11 @@ def reorder_pre_hook_nodes_to_schedule_asap(self) -> None: for n in hook_block: getitem_node.append(n) +<<<<<<< HEAD def reorder_pre_hook_nodes_to_mimic_eager(self) -> None: +======= + def reorder_pre_hook_nodes_to_mimic_eager(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Usage of AOTAutograd causes all the pre_hook nodes to get pushed to the end of the graph. This differs from eager mode, which schedules them @@ -1289,7 +1622,11 @@ def reorder_pre_hook_nodes_to_mimic_eager(self) -> None: for getitem in users: registered_node.prepend(getitem) +<<<<<<< HEAD def reorder_post_acc_grad_hook_nodes(self) -> None: +======= + def reorder_post_acc_grad_hook_nodes(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Usage of AOTAutograd causes all the post_acc_grad_hook nodes to get pushed to the end of the graph. This differs from eager mode, which @@ -1325,7 +1662,11 @@ def reorder_post_acc_grad_hook_nodes(self) -> None: acc_grad_node.append(getitem_node) getitem_node.append(node) +<<<<<<< HEAD def reorder_post_hook_nodes(self) -> None: +======= + def reorder_post_hook_nodes(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Usage of AOTAutograd causes all the post_hook nodes to get pushed to the end of the graph. This differs from eager mode, which schedules them as @@ -1382,7 +1723,11 @@ def reorder_post_hook_nodes(self) -> None: arg.append(getitem_node) getitem_node.append(node) +<<<<<<< HEAD def to_proxy(self, t: Any) -> Any: +======= + def to_proxy(self, t): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if t is None: return None if isinstance(t, list): @@ -1399,11 +1744,16 @@ def to_proxy(self, t: Any) -> Any: return proxy_tensor.proxy def bind_objects_to_proxies( +<<<<<<< HEAD self, objects: Sequence[Any], proxies: Any, origins: Optional[list[tuple[int, str]]] = None, ) -> Sequence[Any]: +======= + self, objects, proxies, origins: Optional[list[tuple[int, str]]] = None + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(proxies, torch.fx.Proxy): if origins: assert len(origins) == len(objects) @@ -1420,7 +1770,11 @@ def bind_objects_to_proxies( track_tensor_tree(objects, proxies, constant=None, tracer=self.fx_tracer) return proxies +<<<<<<< HEAD def bind_backward_state(self, index: int) -> BackwardState: +======= + def bind_backward_state(self, index: int): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert self.hooks_proxy is not None proxy = self.hooks_proxy[index] # type: ignore[index] bw_state = BackwardState() @@ -1432,7 +1786,11 @@ def set_node_origin( node_name: str, nodecall_index: int, pyobj: Optional[torch.autograd.Function], +<<<<<<< HEAD ) -> None: +======= + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) maybe_aot_id = "" if pyobj is not None: forward_cls = pyobj._forward_cls # type: ignore[attr-defined] @@ -1467,11 +1825,15 @@ def set_node_origin( @contextlib.contextmanager +<<<<<<< HEAD def _enable( compiler_fn: Callable[..., Any], dynamic: bool = True, ignore_active_disable_ctx: bool = True, ) -> Generator[None, None, None]: +======= +def _enable(compiler_fn, dynamic: bool = True, ignore_active_disable_ctx=True): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # The entrypoint to enable CA. # It is recommended to enable via `torch._dynamo.config.compiled_autograd = True` rather # than using this context manager directly. If you are torch.compiling the corresponding @@ -1512,8 +1874,12 @@ def _enable( else: # we need to import this, because user might not have imported it if they directly use this context manager # we need to lazily import it, because of circular dependencies +<<<<<<< HEAD if torch.cuda.is_available(): from torch._inductor import cudagraph_trees # noqa: F401 +======= + import torch._inductor.cudagraph_trees +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ( prior_compiler, @@ -1544,7 +1910,11 @@ def _enable( @contextlib.contextmanager +<<<<<<< HEAD def _disable() -> Generator[None, None, None]: +======= +def _disable(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ( prior_compiler, prior_dynamic, @@ -1580,6 +1950,7 @@ def reset() -> None: # Reimplementation of part of CopySlices::apply in Python. # The shared code is really similar so we're not going to try to deduplicate. def copy_slices_prologue( +<<<<<<< HEAD inputs: Sequence[torch.Tensor], base_sizes: Sequence[IntLikeType], base_strides: Sequence[IntLikeType], @@ -1588,6 +1959,16 @@ def copy_slices_prologue( view_strides: Sequence[IntLikeType], view_storage_offset: IntLikeType, ) -> list[torch.Tensor]: +======= + inputs, + base_sizes, + base_strides, + base_storage_offset, + view_sizes, + view_strides, + view_storage_offset, +): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) grad = inputs[0] result = grad.new_empty_strided(base_sizes, base_strides) assert grad is not None @@ -1599,6 +1980,7 @@ def copy_slices_prologue( # Reimplementation of part of CopySlices::apply in Python. # The shared code is really similar so we're not going to try to deduplicate. +<<<<<<< HEAD def copy_slices_epilogue( needs_input_grad: Sequence[bool], result: torch.Tensor, @@ -1606,14 +1988,22 @@ def copy_slices_epilogue( grad_slice: torch.Tensor, ) -> list[Optional[torch.Tensor]]: grad_inputs: list[Optional[torch.Tensor]] = [None] * len(needs_input_grad) +======= +def copy_slices_epilogue(needs_input_grad, result, res, grad_slice): + grad_inputs = [None] * len(needs_input_grad) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for i in range(len(needs_input_grad)): if needs_input_grad[i]: if res[i] is None: continue if i == 0: +<<<<<<< HEAD to_copy = res[i] assert to_copy is not None grad_slice.copy_(to_copy) +======= + grad_slice.copy_(res[i]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) grad_inputs[i] = result else: grad_inputs[i] = res[i] diff --git a/torch/_dynamo/comptime.py b/torch/_dynamo/comptime.py index 2864168dfb82b..d870d3f723a82 100644 --- a/torch/_dynamo/comptime.py +++ b/torch/_dynamo/comptime.py @@ -1,3 +1,8 @@ +<<<<<<< HEAD +======= +# mypy: allow-untyped-defs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ This module provides the public comptime interface to TorchDynamo, enabling users to execute arbitrary Python code during symbolic evaluation of their programs. @@ -38,6 +43,7 @@ def my_model(x): import dis import time import traceback +<<<<<<< HEAD from collections.abc import Sequence from typing import Any, Callable, Optional, TextIO, Union @@ -45,6 +51,11 @@ def my_model(x): from torch._dynamo.symbolic_convert import InstructionTranslatorBase from torch._dynamo.variables.base import VariableTracker from torch._subclasses.fake_tensor import FakeTensor +======= +from typing import Optional, Union + +import torch +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.fx.experimental.symbolic_shapes import free_symbols from .exc import unimplemented_v2 @@ -64,10 +75,17 @@ class ComptimeVar: actual data in the Tensor is.) """ +<<<<<<< HEAD def __init__(self, v: VariableTracker) -> None: self.__variable = v def as_proxy(self) -> Union[VariableTracker, Sequence[VariableTracker]]: +======= + def __init__(self, v) -> None: + self.__variable = v + + def as_proxy(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Returns an fx.Proxy (or tuple/list of fx.Proxy) representing this variable in the FX graph we are assembling to pass @@ -81,13 +99,21 @@ def as_proxy(self) -> Union[VariableTracker, Sequence[VariableTracker]]: """ return self.__variable.as_proxy() +<<<<<<< HEAD def is_proxy(self) -> bool: +======= + def is_proxy(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Returns True if as_proxy() would succeed. """ return self.__variable.is_proxy() +<<<<<<< HEAD def as_fake(self) -> Union[FakeTensor, torch.SymInt]: +======= + def as_fake(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Returns a "fake" value (either a FakeTensor or a SymInt) representing the variable in question. This only works @@ -104,16 +130,26 @@ def size(self, dim: Optional[int] = None) -> Union[int, torch.SymInt]: Returns the size of the tensor (if dim is None) or the size at the dimension dim. The returned size may be a SymInt. """ +<<<<<<< HEAD return self.as_fake().size(dim) # type: ignore[union-attr, return-value] def python_type(self) -> type: +======= + return self.as_fake().size(dim) + + def python_type(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Returns what type(v) would have returned for the variable at compile time. """ return self.__variable.python_type() +<<<<<<< HEAD def as_python_constant(self) -> Any: +======= + def as_python_constant(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Returns the Python value this variable would have, but only if it is completely known at compile-time (e.g., it is constant). @@ -125,19 +161,31 @@ def as_python_constant(self) -> Any: """ return self.__variable.as_python_constant() +<<<<<<< HEAD def is_python_constant(self) -> bool: +======= + def is_python_constant(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Returns True if as_python_constant would succeed. """ return self.__variable.is_python_constant() +<<<<<<< HEAD def is_dynamic(self) -> bool: +======= + def is_dynamic(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(self.__variable, SymNodeVariable): fs = free_symbols(self.__variable.sym_num) return bool(fs) return False +<<<<<<< HEAD def force_static(self) -> None: +======= + def force_static(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Forces that a value is static, inducing a guard on its specific value """ @@ -151,7 +199,11 @@ def force_static(self) -> None: f"cannot force {self.__variable} ({type(self.__variable)}) static" ) +<<<<<<< HEAD def _i_will_not_complain_if_bc_breaks_VariableTracker(self) -> VariableTracker: +======= + def _i_will_not_complain_if_bc_breaks_VariableTracker(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Returns the internal data structure VariableTracker that Dynamo uses to represent variables at compile time. There are no BC guarantees on @@ -173,10 +225,17 @@ class ComptimeContext: file a feature request at https://github.com/pytorch/pytorch/ """ +<<<<<<< HEAD def __init__(self, tx: InstructionTranslatorBase) -> None: self.__tx = tx def get_local(self, name: str, *, stacklevel: int = 0) -> ComptimeVar: +======= + def __init__(self, tx) -> None: + self.__tx = tx + + def get_local(self, name: str, *, stacklevel=0) -> ComptimeVar: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Retrieve the compile-time known information about a local. """ @@ -189,7 +248,11 @@ def get_local(self, name: str, *, stacklevel: int = 0) -> ComptimeVar: return ComptimeVar(var) +<<<<<<< HEAD def graph_break(self, msg: str = "ComptimeContext.graph_break") -> None: +======= + def graph_break(self, msg="ComptimeContext.graph_break"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Manually trigger a graph break """ @@ -200,14 +263,22 @@ def graph_break(self, msg: str = "ComptimeContext.graph_break") -> None: hints=[], ) +<<<<<<< HEAD def graph(self) -> torch.fx.Graph: +======= + def graph(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Retrieve the partially constructed FX graph that would be passed to the user compiler after compilation. """ return self.__tx.output.graph +<<<<<<< HEAD def assert_static(self, val: ComptimeVar) -> None: +======= + def assert_static(self, val): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Asserts that the int is static (and not dynamic, per dynamic shapes) """ @@ -215,9 +286,13 @@ def assert_static(self, val: ComptimeVar) -> None: "expected static but got dynamic (run with TORCH_LOGS=dynamic for more info)" ) +<<<<<<< HEAD def print_graph( self, *, verbose: bool = True, file: Optional[TextIO] = None ) -> None: +======= + def print_graph(self, *, verbose=True, file=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Print the partially constructed FX graph that would be passed to the user compiler after compilation. @@ -226,6 +301,7 @@ def print_graph( self.__tx.output.graph.python_code("self", verbose=verbose).src, file=file ) +<<<<<<< HEAD def parent(self) -> "ComptimeContext": return ComptimeContext(self.__tx.parent) # type: ignore[arg-type] @@ -241,6 +317,21 @@ def print(self, val: Any, *, file: Optional[TextIO] = None) -> None: def print_disas( self, *, file: Optional[TextIO] = None, stacklevel: int = 0 ) -> None: +======= + def parent(self): + return ComptimeContext(self.__tx.parent) + + def __get_tx(self, stacklevel): + tx = self.__tx + for _ in range(stacklevel): + tx = tx.parent + return tx + + def print(self, val, *, file=None): + print(repr(val), file=file) + + def print_disas(self, *, file=None, stacklevel=0): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Print the current series of opcodes being executed (not including parent frames), including where you are in the particular opcode @@ -255,9 +346,13 @@ def print_disas( file=file, ) +<<<<<<< HEAD def print_value_stack( self, *, file: Optional[TextIO] = None, stacklevel: int = 0 ) -> None: +======= + def print_value_stack(self, *, file=None, stacklevel=0): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Print the current Python value stack. Note that this is NOT the same as the traceback; use print_bt() to print that. Note that at @@ -272,9 +367,13 @@ def print_value_stack( for s in tx.stack: print(f"- {s.debug_repr()}", file=file) +<<<<<<< HEAD def print_locals( self, *, file: Optional[TextIO] = None, stacklevel: int = 0 ) -> None: +======= + def print_locals(self, *, file=None, stacklevel=0): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Print all of the locals available in the current context. By default this view is very limited; you can get more information @@ -284,7 +383,11 @@ def print_locals( for k, v in tx.symbolic_locals.items(): print(f"{k} = {v.debug_repr()}", file=file) +<<<<<<< HEAD def print_bt(self, *, file: Optional[TextIO] = None, stacklevel: int = 0) -> None: +======= + def print_bt(self, *, file=None, stacklevel=0): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Print the user code backtrace, starting at the beginning of the frame Dynamo started evaluating. Note that this MAY NOT go all @@ -303,7 +406,11 @@ def print_bt(self, *, file: Optional[TextIO] = None, stacklevel: int = 0) -> Non file=file, ) +<<<<<<< HEAD def print_guards(self, *, file: Optional[TextIO] = None) -> None: +======= + def print_guards(self, *, file=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Print the currently installed guards for the Dynamo context. This does NOT include guards associated with variables that @@ -317,9 +424,13 @@ def print_guards(self, *, file: Optional[TextIO] = None) -> None: file=file, ) +<<<<<<< HEAD def _i_will_not_complain_if_bc_breaks_InstructionTranslator( self, ) -> InstructionTranslatorBase: +======= + def _i_will_not_complain_if_bc_breaks_InstructionTranslator(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Returns the internal data structure InstructionTranslator that Dynamo uses to track state of symbolic evaluation. There are no BC @@ -328,22 +439,31 @@ def _i_will_not_complain_if_bc_breaks_InstructionTranslator( """ return self.__tx +<<<<<<< HEAD def sleep(self, sec: Union[int, float]) -> None: +======= + def sleep(self, sec): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) time.sleep(sec) class _Comptime: @staticmethod +<<<<<<< HEAD def __call__( fn: Callable[[ComptimeContext], Any], fallback_fn: Callable[[], Any] = lambda: None, ) -> Any: +======= + def __call__(fn, fallback_fn=lambda: None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """fn gets called at compile time in TorchDynamo, calls fallback_fn otherwise""" fallback_fn() # Convenience wrappers that are more compact to use @staticmethod +<<<<<<< HEAD def graph_break() -> None: comptime(lambda ctx: ctx.graph_break()) @@ -357,6 +477,21 @@ def print_graph() -> None: @staticmethod def print_disas(*, stacklevel: int = 0) -> None: +======= + def graph_break(): + comptime(lambda ctx: ctx.graph_break()) + + @staticmethod + def print(e): + comptime(lambda ctx: ctx.print(ctx.get_local("e")), lambda: print(e)) + + @staticmethod + def print_graph(): + comptime(lambda ctx: ctx.print_graph()) + + @staticmethod + def print_disas(*, stacklevel=0): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) comptime( lambda ctx: ctx.print_disas( stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1 @@ -364,7 +499,11 @@ def print_disas(*, stacklevel: int = 0) -> None: ) @staticmethod +<<<<<<< HEAD def print_value_stack(*, stacklevel: int = 0) -> None: +======= + def print_value_stack(*, stacklevel=0): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) comptime( lambda ctx: ctx.print_value_stack( stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1 @@ -375,7 +514,11 @@ def print_value_stack(*, stacklevel: int = 0) -> None: # in an expression context; e.g., x + print_value_stack_and_return(y + z), # you will see x on the stack prior to the addition operation @staticmethod +<<<<<<< HEAD def print_value_stack_and_return(e: Any, *, stacklevel: int = 0) -> Any: +======= + def print_value_stack_and_return(e, *, stacklevel=0): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) comptime( lambda ctx: ctx.print_value_stack( stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1 @@ -384,7 +527,11 @@ def print_value_stack_and_return(e: Any, *, stacklevel: int = 0) -> Any: return e @staticmethod +<<<<<<< HEAD def print_locals(*, stacklevel: int = 0) -> None: +======= + def print_locals(*, stacklevel=0): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) comptime( lambda ctx: ctx.print_locals( stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1 @@ -392,7 +539,11 @@ def print_locals(*, stacklevel: int = 0) -> None: ) @staticmethod +<<<<<<< HEAD def print_bt(*, stacklevel: int = 0) -> None: +======= + def print_bt(*, stacklevel=0): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) comptime( lambda ctx: ctx.print_bt( stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1 @@ -400,6 +551,7 @@ def print_bt(*, stacklevel: int = 0) -> None: ) @staticmethod +<<<<<<< HEAD def print_guards() -> None: comptime(lambda ctx: ctx.print_guards()) @@ -413,6 +565,21 @@ def force_static(val: Any) -> None: @staticmethod def breakpoint() -> None: +======= + def print_guards(): + comptime(lambda ctx: ctx.print_guards()) + + @staticmethod + def assert_static(val): + comptime(lambda ctx: ctx.assert_static(ctx.get_local("val"))) + + @staticmethod + def force_static(val): + comptime(lambda ctx: ctx.get_local("val").force_static()) + + @staticmethod + def breakpoint(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Like pdb breakpoint(), but drop into pdb whenever this line of code is compiled by dynamo. Use it by putting @@ -430,14 +597,22 @@ def breakpoint() -> None: (Pdb) p ctx.get_local("attention").as_fake() """ +<<<<<<< HEAD def inner(inner_ctx: ComptimeContext) -> None: +======= + def inner(inner_ctx): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ctx = inner_ctx.parent() # noqa: F841 builtins.breakpoint() comptime(inner) @staticmethod +<<<<<<< HEAD def sleep(sec: Union[int, float]) -> None: +======= + def sleep(sec): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) comptime(lambda ctx: ctx.sleep(ctx.get_local("sec").as_python_constant())) diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index b8d1008dec8e1..3db5eef3eb78d 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -1,3 +1,8 @@ +<<<<<<< HEAD +======= +# mypy: allow-untyped-defs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Configuration module for TorchDynamo compiler and optimization settings. @@ -110,12 +115,15 @@ # Valid options: "dynamic", "unbacked" automatic_dynamic_shapes_mark_as: Literal["dynamic", "unbacked"] = "dynamic" +<<<<<<< HEAD # log graph in/out metadata # This is only turned on for export today since we # know we are tracing a flat callable. later, this # can extended to other use cases as well. log_graph_in_out_metadata = False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This flag changes how the shapes of parameters are treated. # If this flag is set to True, then the shapes of torch.nn.Parameter as well as of torch.Tensor are attempted to be dynamic # If this flag is set to False, then the shapes of torch.nn.Parameter are assumed to be static, @@ -264,6 +272,15 @@ # hybrid backed unbacked symints prefer_deferred_runtime_asserts_over_guards = False +<<<<<<< HEAD +======= +# For complex dynamic shapes guards that we're unable to specify with dynamo/export's +# range constraints + dims + derived dims language, we raise constraint violation +# errors or specialize by default. If set to True, this flag avoids crashing/specialization, +# and allows complex guards as runtime assertions in the graph. +allow_complex_guards_as_runtime_asserts = False + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # By default, dynamo will treat all ints as backed SymInts, which means (1) it # will wait to see the int change over multiple runs before generalizing and # (2) it will still always 0/1 specialize an int. When true, this knob @@ -326,10 +343,13 @@ # No longer used optimize_ddp_lazy_compile = False +<<<<<<< HEAD # lambda guarding on object aliasing to improve opportunity for dict tag # optimization use_lamba_guard_for_object_aliasing = True +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Whether to skip guarding on FSDP-managed modules skip_fsdp_guards = True # Whether to apply torch._dynamo.disable() to FSDP2 hooks. @@ -351,6 +371,7 @@ # the dictionary tag is same across invocation calls. skip_tensor_guards_with_matching_dict_tags = True +<<<<<<< HEAD # Skips guards on func.__defaults__ if the element to be guarded is a constant skip_guards_on_constant_func_defaults = True @@ -381,6 +402,8 @@ # useful for regional compilation. max_saved_pointers_for_recursive_dict_tags_check = 256 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # If True, raises exception if TorchDynamo is called with a context manager raise_on_ctx_manager_usage = True @@ -430,7 +453,11 @@ enable_cpp_guard_manager = True # Use C++ guard manager for symbolic shapes +<<<<<<< HEAD enable_cpp_symbolic_shape_guards = not is_fbcode() +======= +enable_cpp_symbolic_shape_guards = False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Enable tracing through contextlib.contextmanager enable_trace_contextlib = True @@ -448,10 +475,13 @@ justknob="pytorch/compiler:inline_inbuilt_nn_modules", ) +<<<<<<< HEAD # Resume tracing in nested frames if a nested graph break occurs # Old behavior is to bubble up the graph break to the top level frame. nested_graph_breaks = False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Install "free" tensor variables (globals, non-locals, nn module attributes) # as graph attributes. This is useful for export, as it # produces a consistent number of inputs to the graph. @@ -481,6 +511,7 @@ # traced FX graph is empty when RETURN_* is traced. allow_empty_graphs = False +<<<<<<< HEAD # Used for testing - forces all top-level functions to be nested when traced with Dynamo debug_force_nested_calls = False @@ -493,12 +524,18 @@ # always pass. debug_disable_compile_counter = False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # When set, total compile time instruction count is recorded using # torch._dynamo.utilsCompileTimeInstructionCounter. record_compile_time_instruction_count = False +<<<<<<< HEAD def default_debug_dir_root() -> str: +======= +def default_debug_dir_root(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # [@compile_ignored: debug] DEBUG_DIR_VAR_NAME = "TORCH_COMPILE_DEBUG_DIR" if DEBUG_DIR_VAR_NAME in os.environ: @@ -590,12 +627,15 @@ def default_debug_dir_root() -> str: # the inference_mode is still respected. fake_tensor_disable_inference_mode = True +<<<<<<< HEAD # Experimental feature for running automatic caching precompile. # Enables automatic DynamoCache save/load caching_precompile = os.environ.get("TORCH_CACHING_PRECOMPILE", "0") == "1" strict_precompile = os.environ.get("TORCH_STRICT_PRECOMPILE", "0") == "1" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Enables the Compiled Autograd engine to trace autograd calls made under torch.compile(). # Note: AOTAutograd will still trace and partition an AOT backward graph local to that # compiled region. But AOTAutograd traces without knowledge of backward hooks which are @@ -607,6 +647,7 @@ def default_debug_dir_root() -> str: # registering backward hooks on tensors contained within the compiled region. compiled_autograd = False +<<<<<<< HEAD # Checks if we should graph break when seeing nn parameter constructors # in dynamo; this is so that we clearly fail and ask users to move outside @@ -614,6 +655,8 @@ def default_debug_dir_root() -> str: # See https://github.com/pytorch/pytorch/issues/157452 for more context graph_break_on_nn_param_ctor = True +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Overrides torch.compile() kwargs for Compiled Autograd: compiled_autograd_kwargs_override: dict[str, Any] = {} @@ -665,9 +708,12 @@ def default_debug_dir_root() -> str: os.environ.get("UNSAFE_SKIP_FSDP_MODULE_GUARDS", "0") == "1" ) +<<<<<<< HEAD # Common prefix to append to the id of each compile run to filter out data pt2_compile_id_prefix: Optional[str] = os.environ.get("PT2_COMPILE_ID_PREFIX", None) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Run GC at the end of compilation run_gc_after_compile = Config( # type: ignore[var-annotated] default=True, @@ -683,15 +729,22 @@ def default_debug_dir_root() -> str: # and AOTAutograd runtime wrapper. record_runtime_overhead = True +<<<<<<< HEAD enable_aot_compile = False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # HACK: this is for testing custom ops profiling only _custom_ops_profile: Optional[Any] = None if TYPE_CHECKING: from torch.utils._config_typing import * # noqa: F401, F403 +<<<<<<< HEAD def _make_closure_patcher(**changes: Any) -> Any: ... +======= + def _make_closure_patcher(**changes): ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) install_config_module(sys.modules[__name__]) diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 686f0945179f3..9ddf81577d36a 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -1,3 +1,8 @@ +<<<<<<< HEAD +======= +# mypy: allow-untyped-decorators + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ This module implements TorchDynamo's core frame conversion functionality, transforming Python frames into FX graphs. It handles: @@ -15,10 +20,13 @@ The conversion process preserves program semantics while enabling optimizations through torch.compile() and related systems. +<<<<<<< HEAD NOTE: _torchdynamo_orig_backend is used for convert frame wrappers to identify the inner wrapped function. By going down the _torchdynamo_orig_backend chain, one can recover the original unwrapped backend, which is checked for during the Dynamo cache lookup. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ from __future__ import annotations @@ -39,10 +47,15 @@ import threading import time import traceback +<<<<<<< HEAD import types import typing import weakref from dataclasses import dataclass +======= +import typing +import weakref +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from pathlib import Path from types import CellType, CodeType, FunctionType, ModuleType from typing import Any, Callable, Optional, TypeVar, Union @@ -73,7 +86,10 @@ from torch.nn.parallel.distributed import DistributedDataParallel from torch.utils._python_dispatch import ( _disable_current_modes, +<<<<<<< HEAD is_in_any_mode_without_ignore_compile_internals, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) is_in_torch_dispatch_mode, ) from torch.utils._traceback import CapturedTraceback, format_traceback_short @@ -107,7 +123,10 @@ InternalTorchDynamoError, PackageError, RecompileLimitExceeded, +<<<<<<< HEAD ResumePrologueTracingError, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ShortenTraceback, SkipCodeRecursiveException, TorchRuntimeError, @@ -121,7 +140,10 @@ GuardedCode, ) from .hooks import Hooks +<<<<<<< HEAD from .output_graph import DynamoTracerOutput +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .pgo import log_frame_dynamic_whitelist, put_code_state from .replay_record import ExecutionRecord from .resume_execution import TORCH_DYNAMO_RESUME_IN_PREFIX @@ -135,7 +157,10 @@ from .trace_rules import is_numpy from .types import ConvertFrameReturn, FrameAction, FrameExecStrategy, wrap_guarded_code from .utils import ( +<<<<<<< HEAD _get_error_on_graph_break, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) chromium_event_timed, CleanupManager, CompileTimeInstructionCounter, @@ -143,7 +168,10 @@ dynamo_timed, format_bytecode, gen_record_file_name, +<<<<<<< HEAD get_hook_for_recompile_user_context, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) get_metrics_context, increment_frame, is_namedtuple, @@ -228,6 +256,7 @@ def fx_forward_from_src_skip_result( return result +<<<<<<< HEAD def log_dynamo_start(code: CodeType, skip: int = 0) -> list[str]: convert_frame_intern = structured.intern_string(__file__) captured_tb = CapturedTraceback.extract(skip=4 + skip).summary() @@ -262,6 +291,8 @@ def log_dynamo_start(code: CodeType, skip: int = 0) -> list[str]: return stack_strings +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]: """ Context manager to: @@ -297,9 +328,13 @@ def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T: cuda_rng_state = None if torch.cuda.is_available(): cuda_rng_state = torch.cuda.get_rng_state() +<<<<<<< HEAD cuda_matmul_fp32_prec = torch._C._get_fp32_precision_getter( "cuda", "matmul" ) +======= + allow_tf32 = torch._C._get_cublas_allow_tf32() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) prior_fwd_from_src = torch.fx.graph_module._forward_from_src torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result cleanup = setup_compile_debug() @@ -331,15 +366,23 @@ def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T: torch._C._unset_default_mobile_cpu_allocator() if cuda_rng_state is not None: torch.cuda.set_rng_state(cuda_rng_state) +<<<<<<< HEAD torch._C._set_fp32_precision_setter( "cuda", "matmul", cuda_matmul_fp32_prec ) +======= + torch._C._set_cublas_allow_tf32(allow_tf32) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.fx.graph_module._forward_from_src = prior_fwd_from_src assert guards.check(), ( f"Global {guards.reason()}state changed while dynamo tracing, please report a bug" ) +<<<<<<< HEAD _fn._torchdynamo_orig_backend = fn # type: ignore[attr-defined] +======= + _fn._torchdynamo_orig_callable = fn # type: ignore[attr-defined] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return _fn @@ -519,6 +562,7 @@ def profile_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: return profile_wrapper +<<<<<<< HEAD @dataclass class ConvertFrameBox: error_on_graph_break: Optional[bool] = None @@ -547,6 +591,8 @@ def get_compile_id( ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class ConvertFrameAssert: def __init__( self, @@ -558,12 +604,19 @@ def __init__( ) -> None: # assert export_constraints is None reset_graph_break_dup_checker() +<<<<<<< HEAD self._torchdynamo_orig_backend = compiler_fn +======= + self._torchdynamo_orig_callable = compiler_fn +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._one_graph = one_graph self._export = export self._export_constraints = export_constraints self._package = package +<<<<<<< HEAD self._box = ConvertFrameBox() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @property def _clone_with_backend(self) -> Callable[[CompilerFn], ConvertFrameAssert]: @@ -584,6 +637,10 @@ def __call__( skip: int = 0, ) -> ConvertFrameReturn: increment_frame() +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) code = frame.f_code cache_size = compute_cache_size(frame, cache_entry) @@ -662,8 +719,29 @@ def __call__( global initial_global_state initial_global_state = GlobalStateGuard() +<<<<<<< HEAD compile_id = get_compile_id(frame_state) frame_id = compile_id.frame_id +======= + global FRAME_COUNTER + if "_id" not in frame_state: + frame_state["_id"] = FRAME_COUNTER + FRAME_COUNTER += 1 + frame_id = frame_state["_id"] + assert isinstance(frame_id, int) + + frame_compile_id = FRAME_COMPILE_COUNTER[frame_id] + FRAME_COMPILE_COUNTER[frame_id] += 1 + + compiled_autograd_id = None + if prior := CompileContext.current_compile_id(): + compiled_autograd_id = prior.compiled_autograd_id + compile_id = CompileId( + compiled_autograd_id=compiled_autograd_id, + frame_id=frame_id, + frame_compile_id=frame_compile_id, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) signpost_event( "dynamo", @@ -685,13 +763,21 @@ def __call__( dynamo_tls.traced_frame_infos.append(info) with compile_context(CompileContext(compile_id)): +<<<<<<< HEAD result = _compile( +======= + return _compile( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) frame.f_code, frame.f_globals, frame.f_locals, frame.f_builtins, frame.closure, +<<<<<<< HEAD self._torchdynamo_orig_backend, +======= + self._torchdynamo_orig_callable, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._one_graph, self._export, self._export_constraints, @@ -703,6 +789,7 @@ def __call__( compile_id=compile_id, skip=skip + 1, package=self._package, +<<<<<<< HEAD convert_frame_box=self._box, ) @@ -713,6 +800,10 @@ def __call__( DynamoCache.record_package(self._package) return result +======= + ) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def convert_frame_assert( compiler_fn: CompilerFn, @@ -721,7 +812,11 @@ def convert_frame_assert( export_constraints: Optional[typing.Never] = None, package: Optional[CompilePackage] = None, ) -> ConvertFrameAssert: +<<<<<<< HEAD """Fully convert a frame into an FX graph, raising an exception if we fail.""" +======= + """Fully convert a frame into an FX graph""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ConvertFrameAssert( compiler_fn, one_graph, export, export_constraints, package ) @@ -732,6 +827,12 @@ def convert_frame_assert( from torch.utils.hooks import RemovableHandle +<<<<<<< HEAD +======= +if typing.TYPE_CHECKING: + from .output_graph import OutputGraph + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # we have to use `OrderedDict` to make `RemovableHandle` work. _bytecode_hooks: dict[int, BytecodeHook] = OrderedDict() @@ -746,6 +847,7 @@ def register_bytecode_hook(hook: BytecodeHook) -> RemovableHandle: return handle +<<<<<<< HEAD @preserve_global_state def trace_frame( code: types.CodeType, @@ -1070,6 +1172,8 @@ def transform( raise +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _compile( code: CodeType, globals: dict[str, object], @@ -1089,6 +1193,7 @@ def _compile( compile_id: CompileId, skip: int = 0, package: Optional[CompilePackage] = None, +<<<<<<< HEAD # Can be used to record things for the caller, both # in the case of normal and exception code paths convert_frame_box: Optional[ConvertFrameBox] = None, @@ -1096,17 +1201,98 @@ def _compile( from torch._inductor.async_compile import async_compile_pool_manager from torch.fx.experimental.validator import ( BisectValidationException, +======= +) -> ConvertFrameReturn: + from torch.fx.experimental.validator import ( + bisect, + BisectValidationException, + translation_validation_enabled, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ValidationException, ) # Only nonlocal defs here please! # Time spent compiling this frame before restarting or failing analysis dynamo_time_before_restart: float = 0.0 +<<<<<<< HEAD @compile_time_strobelight_meta(phase_name="compile_inner") def compile_inner( code: CodeType, one_graph: bool, hooks: Hooks ) -> tuple[ConvertFrameReturn, Optional[DynamoTracerOutput]]: +======= + output: Optional[OutputGraph] = None + tracer: Optional[InstructionTranslator] = None + + tf_mode_stack: list[torch.overrides.TorchFunctionMode] = ( + torch.overrides._get_current_function_mode_stack() + ) + + @preserve_global_state + def transform( + instructions: list[Instruction], code_options: dict[str, object] + ) -> None: + nonlocal output + nonlocal tracer + speculation_log.restart() # type: ignore[has-type] + exn_vt_stack = ExceptionStack() + tracer = InstructionTranslator( + instructions, + code, + locals, + globals, + builtins, + closure, + tf_mode_stack, + code_options, + compiler_fn, + one_graph, + export, + export_constraints, + frame_state=frame_state, + speculation_log=speculation_log, # type: ignore[has-type] + exn_vt_stack=exn_vt_stack, + distributed_state=distributed_state, # type: ignore[has-type] + package=package, + ) + + try: + tracer.output.mark_bytecode_tracing_start() + with tracing(tracer.output.tracing_context), tracer.set_current_tx(): + tracer.run() + except exc.UnspecializeRestartAnalysis: + speculation_log.clear() # type: ignore[has-type] + raise + except ( + exc.SpeculationRestartAnalysis, + exc.TensorifyScalarRestartAnalysis, + exc.SkipFrame, + ): + raise + except Exception: + if translation_validation_enabled(): + bisect(tracer.output.shape_env) + raise + finally: + tracer.output.call_cleanup_hooks() + + output = tracer.output + assert output is not None + assert output.output_instructions + instructions[:] = output.output_instructions + code_options.update(output.code_options) + propagate_inst_exn_table_entries(instructions) + check_inst_exn_tab_entries_valid(instructions) + instructions[:] = remove_pointless_jumps(remove_dead_code(instructions)) + + @compile_time_strobelight_meta(phase_name="compile_inner") + def compile_inner( + code: CodeType, + one_graph: bool, + hooks: Hooks, + transform: Callable[[list[Instruction], dict[str, Any]], Any], + ) -> ConvertFrameReturn: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with contextlib.ExitStack() as stack: stack.enter_context( torch._dynamo.callback_handler.install_callbacks( @@ -1114,11 +1300,18 @@ def compile_inner( ) ) stack.enter_context(CompileTimeInstructionCounter.record()) +<<<<<<< HEAD return _compile_inner(code, one_graph, hooks) return ( ConvertFrameReturn(), None, +======= + return _compile_inner(code, one_graph, hooks, transform) + + return ( + ConvertFrameReturn() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # dead, but see https://github.com/python/mypy/issues/7577 @maybe_cprofile @@ -1126,7 +1319,12 @@ def _compile_inner( code: CodeType, one_graph: bool, hooks: Hooks, +<<<<<<< HEAD ) -> tuple[ConvertFrameReturn, DynamoTracerOutput]: +======= + transform: Callable[[list[Instruction], dict[str, Any]], Any], + ) -> ConvertFrameReturn: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) nonlocal dynamo_time_before_restart last_attempt_start_time = start_time = time.time() @@ -1147,6 +1345,7 @@ def log_bytecode( ) out_code = None +<<<<<<< HEAD try: dynamo_output = compile_frame( code, @@ -1168,14 +1367,61 @@ def log_bytecode( log.debug("No graph captured with export/fullgraph=True") assert e._torch_dynamo_tracer_output is not None return ConvertFrameReturn(), e._torch_dynamo_tracer_output +======= + for attempt in itertools.count(): + CompileContext.get().attempt = attempt + try: + with dynamo_timed( + f"compile_attempt_{attempt}", log_pt2_compile_event=True + ): + out_code = transform_code_object(code, transform) + break + except exc.RestartAnalysis as e: + if not isinstance(e, exc.TensorifyScalarRestartAnalysis): + TensorifyState.clear() + log.info( + "Restarting analysis due to %s", + LazyString(format_traceback_short, e.__traceback__), + ) + # If restart reason is None just log the type of the exception + restart_reasons.add(e.restart_reason or str(type(e))) + # We now have a new "last attempt", reset the clock + last_attempt_start_time = time.time() + if attempt > 100: + unimplemented_v2( + gb_type="Excessive RestartAnalysis() calls", + context="", + explanation="Dynamo attempted to trace the same frame 100+ times. " + "Giving up on compiling as the compile time tradeoff is likely not " + "worth the performance gain.", + hints=[], + ) + except exc.SkipFrame as e: + if not isinstance(e, exc.TensorifyScalarRestartAnalysis): + TensorifyState.clear() + log.debug( + "Skipping frame %s %s \ + %s %s", + e, + code.co_name, + code.co_filename, + code.co_firstlineno, + ) + if one_graph: + log.debug("No graph captured with one_graph=True") + return ConvertFrameReturn() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert distributed_state is None or distributed_state.all_states is not None, ( # type: ignore[has-type] "compiler collective wasn't run before compilation completed" ) +<<<<<<< HEAD out_code = dynamo_output.bytecode tracer_output = dynamo_output.tracer_output if dynamo_output.last_attempt_start_time is not None: last_attempt_start_time = dynamo_output.last_attempt_start_time +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert out_code is not None log_bytecode( @@ -1186,17 +1432,28 @@ def log_bytecode( out_code, ) +<<<<<<< HEAD for idx, hook in enumerate(_bytecode_hooks.values()): with dynamo_timed(f"bytecode_hooks_{idx}", log_pt2_compile_event=True): hook_output = hook(code, out_code) if hook_output is not None: out_code = hook_output +======= + for hook in _bytecode_hooks.values(): + hook_output = hook(code, out_code) + if hook_output is not None: + out_code = hook_output +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) orig_code_map[out_code] = code output_codes.add(out_code) dynamo_time_before_restart = last_attempt_start_time - start_time +<<<<<<< HEAD assert tracer_output.output_graph is not None output = tracer_output.output_graph +======= + assert output is not None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Tests for new code objects. # The rationale for these tests can be found in torch/csrc/dynamo/eval_frame.c @@ -1242,23 +1499,40 @@ def count_args(code: CodeType) -> int: # are extra graphs now. if output.export and output.is_empty_graph(): +<<<<<<< HEAD return ConvertFrameReturn(), tracer_output +======= + return ConvertFrameReturn() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert output.guards is not None CleanupManager.instance[out_code] = output.cleanups nonlocal cache_entry with dynamo_timed("build_guards", log_pt2_compile_event=True): +<<<<<<< HEAD check_fn = dynamo_output.build_guards( code, hooks=hooks, save=package is not None, cache_entry=cache_entry, +======= + check_fn = CheckFunctionManager( + code, + output, + cache_entry, + hooks.guard_fail_fn if hooks else None, + hooks.guard_filter_fn if hooks else None, + guards_serialization_mode="save" if package else None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if package is not None: assert check_fn.guards_state is not None package.add_guarded_code(check_fn.guards_state, out_code) +<<<<<<< HEAD package.add_inlined_source(output.tracing_context.traced_code) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) compile_id_str = str(compile_id) if compile_id is not None else "Unknown" annotation_str = "Torch-Compiled Region: " + compile_id_str @@ -1277,7 +1551,11 @@ def count_args(code: CodeType) -> int: # they are benign and do not generate any new graphs. hooks.guard_export_fn(output.guards) +<<<<<<< HEAD return wrap_guarded_code(guarded_code), tracer_output +======= + return wrap_guarded_code(guarded_code) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) metrics_context = get_metrics_context() code_context = ( @@ -1286,7 +1564,10 @@ def count_args(code: CodeType) -> int: with ( _use_lazy_graph_module(config.use_lazy_graph_module), compile_context(CompileContext(compile_id)), +<<<<<<< HEAD async_compile_pool_manager(), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) chromium_event_timed( "dynamo", reset_event_log_on_exit=True, log_pt2_compile_event=True ), @@ -1300,6 +1581,11 @@ def count_args(code: CodeType) -> int: code_context, ): restart_reasons: set[str] = set() +<<<<<<< HEAD +======= + # This is shared across restarts + speculation_log = SpeculationLog() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if compile_pg := get_compile_pg(): distributed_state = DistributedState(compile_pg, LocalState()) else: @@ -1312,6 +1598,7 @@ def count_args(code: CodeType) -> int: recompile_reason = ( "Unable to find recompilation reasons" if not reasons else reasons[0] ) +<<<<<<< HEAD # Recheck for recompilation, for when inline_inbuilt_nn_modules is set to False inline_inbuilt_nn_modules_candidate = False if not config.inline_inbuilt_nn_modules and frame: @@ -1346,6 +1633,9 @@ def count_args(code: CodeType) -> int: user_context()[:256] for user_context in recompile_user_contexts } metrics_context.set("recompile_user_contexts", user_contexts_msg) +======= + metrics_context.update_outer({"recompile_reason": recompile_reason}) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) exceeded, limit_type = exceeds_recompile_limit(cache_size, compile_id) if exceeded: @@ -1353,14 +1643,21 @@ def count_args(code: CodeType) -> int: def format_func_info(code: CodeType) -> str: return f"'{code.co_name}' ({code.co_filename}:{code.co_firstlineno})" +<<<<<<< HEAD # NS: Don't add period at the end of string, as it'll be added to URL # rendering it incorrect +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) log.warning( "torch._dynamo hit config.%s (%s)\n" " function: %s\n" " last reason: %s\n" 'To log all recompilation reasons, use TORCH_LOGS="recompiles".\n' +<<<<<<< HEAD "To diagnose recompilation issues, see %s", +======= + "To diagnose recompilation issues, see %s.", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) limit_type, getattr(config, limit_type), format_func_info(code), @@ -1373,7 +1670,11 @@ def format_func_info(code: CodeType) -> str: ) elif one_graph: raise FailOnRecompileLimitHit( +<<<<<<< HEAD f"{limit_type} reached with fullgraph=True. Excessive recompilations can degrade " +======= + f"{limit_type} reached with one_graph=True. Excessive recompilations can degrade " +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "performance due to the compilation overhead of each recompilation. To monitor " "recompilations, enable TORCH_LOGS=recompiles. If recompilations are expected, consider " "increasing torch._dynamo.config.cache_size_limit to an appropriate value." @@ -1420,17 +1721,49 @@ def format_func_info(code: CodeType) -> str: # # 2 extra here # torch/_logging/_internal.py:1064 in trace_structured # torch/_dynamo/convert_frame.py:780 in +<<<<<<< HEAD stack_trace = log_dynamo_start(code, skip) start_time_ns = time.time_ns() fail_type: Optional[str] = None fail_reason: Optional[str] = None exception_stack_trace: Optional[list[str]] = None +======= + convert_frame_intern = structured.intern_string(__file__) + # Initialize the ChromiumEventLogger on start + torch._logging.trace_structured( + "dynamo_start", + lambda: { + "stack": list( + itertools.takewhile( + lambda f: f["filename"] != convert_frame_intern, + structured.from_traceback( + CapturedTraceback.extract(skip=4 + skip).summary() + ), + ) + ) + + [ + { + "line": code.co_firstlineno, + "name": code.co_name, + "filename": structured.intern_string(code.co_filename), + } + ] + }, + ) + start_time_ns = time.time_ns() + fail_type: Optional[str] = None + fail_reason: Optional[str] = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fail_user_frame_filename: Optional[str] = None fail_user_frame_lineno: Optional[int] = None torch._dynamo.utils.ReinplaceCounters.clear() guarded_code = None try: +<<<<<<< HEAD guarded_code, tracer_output = compile_inner(code, one_graph, hooks) +======= + guarded_code = compile_inner(code, one_graph, hooks, transform) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # NB: We only put_code_state in success case. Success case here # does include graph breaks; specifically, if a graph break still @@ -1442,12 +1775,16 @@ def format_func_info(code: CodeType) -> str: # to upload for graph break though, because this can prevent # extra graph break compilations.) put_code_state() +<<<<<<< HEAD if ( tracer_output and (output_graph := tracer_output.output_graph) and output_graph.has_outputs() ): log_frame_dynamic_whitelist(code) +======= + log_frame_dynamic_whitelist(code) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return guarded_code except Exception as e: @@ -1456,7 +1793,10 @@ def format_func_info(code: CodeType) -> str: # info here and add it to the metrics context below. fail_type = type(e).__qualname__ fail_reason = str(e) +<<<<<<< HEAD exception_stack_trace = [traceback.format_exc()] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) exception_handler(e, code, frame, export=export) # NB: this is the post-mutation exception torch._logging.trace_structured( @@ -1470,7 +1810,10 @@ def format_func_info(code: CodeType) -> str: fail_user_frame_filename, fail_user_frame_lineno = exc.get_exc_message( e, compile_id ) +<<<<<<< HEAD tracer_output = getattr(e, "_torch_dynamo_tracer_output", None) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance( e, ( @@ -1485,7 +1828,10 @@ def format_func_info(code: CodeType) -> str: BisectValidationException, ShortenTraceback, PackageError, +<<<<<<< HEAD ResumePrologueTracingError, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), ): raise @@ -1505,6 +1851,7 @@ def format_func_info(code: CodeType) -> str: log.info("run_gc_after_compile: running gc") gc.collect(1) +<<<<<<< HEAD output = None if tracer_output: output = tracer_output.output_graph @@ -1513,6 +1860,11 @@ def format_func_info(code: CodeType) -> str: # tracer should already be None, keep an extra check here just in case. if tracer := output.root_tx: tracer.f_locals = {} +======= + if tracer: + tracer.output.local_scope = {} + tracer.f_locals = {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .utils import curr_frame @@ -1522,7 +1874,10 @@ def format_func_info(code: CodeType) -> str: shape_env_guard_count = len(output.shape_env.guards) graph_op_count = output.count_calls() graph_node_count = len(output.graph.nodes) +<<<<<<< HEAD graph_node_shapes = output.get_graph_sizes_structured() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) graph_input_count = len(output.placeholders) non_compliant_ops = {op.__qualname__ for op in output.non_compliant_ops} compliant_custom_ops = { @@ -1534,7 +1889,10 @@ def format_func_info(code: CodeType) -> str: shape_env_guard_count = None graph_op_count = None graph_node_count = None +<<<<<<< HEAD graph_node_shapes = {} +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) graph_input_count = None non_compliant_ops = set({}) compliant_custom_ops = set({}) @@ -1563,14 +1921,22 @@ def format_func_info(code: CodeType) -> str: "restart_reasons": restart_reasons, "dynamo_time_before_restart_s": dynamo_time_before_restart, "has_guarded_code": guarded_code is not None, +<<<<<<< HEAD +======= + "config_suppress_errors": config.suppress_errors, + "config_inline_inbuilt_nn_modules": config.inline_inbuilt_nn_modules, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "specialize_float": config.specialize_float, "is_forward": True, "dynamo_compile_time_before_restart_us": to_int_us( dynamo_time_before_restart ), +<<<<<<< HEAD "stack_trace": stack_trace, "graph_node_shapes": str(graph_node_shapes), "exception_stack_trace": exception_stack_trace, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } # TODO: replace with CompileEventLogger.compilation_metrics # There are some columns here not in PT2 Compile Events @@ -1578,6 +1944,7 @@ def format_func_info(code: CodeType) -> str: metrics_context.update_outer(metrics) # === END WARNING WARNING WARNING === +<<<<<<< HEAD # If tracer is available, then tracer.error_on_graph_break reflects value of # global symbolic_convert.error_on_graph_break at the time of the graph break - # symbolic_convert.error_on_graph_break may have been (correctly) changed during cleanup. @@ -1589,6 +1956,8 @@ def format_func_info(code: CodeType) -> str: else _get_error_on_graph_break() ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class ConvertFrame: def __init__( @@ -1597,7 +1966,11 @@ def __init__( hooks: Hooks, package: Optional[CompilePackage] = None, ) -> None: +<<<<<<< HEAD self._torchdynamo_orig_backend = compiler_fn +======= + self._torchdynamo_orig_callable = compiler_fn +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._inner_convert = convert_frame_assert( compiler_fn, one_graph=False, package=package ) @@ -1605,10 +1978,14 @@ def __init__( @property def _clone_with_backend(self) -> Callable[[WrapBackendDebug], ConvertFrame]: +<<<<<<< HEAD return lambda backend: convert_frame( backend, self._hooks, ) +======= + return lambda backend: convert_frame(backend, self._hooks) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __call__( self, @@ -1627,6 +2004,7 @@ def __call__( counters["frames"]["ok"] += 1 return result except Exception as e: +<<<<<<< HEAD # Do not allow errors to be suppressed if we're tracing a resume function prologue if isinstance(e, ResumePrologueTracingError): raise @@ -1643,6 +2021,8 @@ def __call__( # as an observed exception - we don't expect that exception to be suppressed. raise +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # These two exception types are "soft" failure, in the sense that # we know this is due to something we didn't implement all the # way, scare the user less about it. That being said, if you @@ -1723,9 +2103,13 @@ def __call__( def convert_frame( +<<<<<<< HEAD compiler_fn: CompilerFn, hooks: Hooks, package: Optional[CompilePackage] = None, +======= + compiler_fn: CompilerFn, hooks: Hooks, package: Optional[CompilePackage] = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> ConvertFrame: """Try to convert a frame into an FX graph, if error leave frame unmodified""" return ConvertFrame(compiler_fn, hooks, package=package) @@ -1741,6 +2125,7 @@ def replay(filename: str) -> None: record = ExecutionRecord.load(in_file) record.globals = dict(itertools.chain(record.globals.items(), globals().items())) +<<<<<<< HEAD with decorators.error_on_graph_break(False): try: _compile( @@ -1762,6 +2147,28 @@ def replay(filename: str) -> None: ) finally: config.replay_record_enabled = original_replay_val +======= + try: + _compile( + record.code, + record.globals, + record.locals, + record.builtins, + record.closure, + compiler_fn=eager, + one_graph=False, + export=False, + export_constraints=None, + hooks=Hooks(), + cache_size=CacheSizeRelevantForFrame(0, 0), + cache_entry=None, + frame=None, + frame_state={}, + compile_id=CompileId(frame_id=42, frame_compile_id=999), + ) + finally: + config.replay_record_enabled = original_replay_val +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def first_real_inst_idx(code: CodeType) -> int: @@ -1785,6 +2192,7 @@ def __call__( ) -> ConvertFrameReturn: ... +<<<<<<< HEAD def should_skip_due_to_torch_dispatch_mode() -> bool: return is_in_any_mode_without_ignore_compile_internals() @@ -1793,6 +2201,12 @@ class CatchErrorsWrapper: def __init__(self, callback: ConvertFrameProtocol, hooks: Hooks) -> None: functools.wraps(callback)(self) self._torchdynamo_orig_backend = callback +======= +class CatchErrorsWrapper: + def __init__(self, callback: ConvertFrameProtocol, hooks: Hooks) -> None: + functools.wraps(callback)(self) + self._torchdynamo_orig_callable = callback +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.hooks = hooks def __call__( @@ -1802,6 +2216,10 @@ def __call__( frame_state: dict[str, Union[int, FrameStateSizeEntry]], ) -> ConvertFrameReturn: assert frame_state is not None +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) input_codes.add(frame.f_code) is_skipfile = trace_rules.check(frame.f_code) @@ -1815,8 +2233,13 @@ def __call__( or is_skipfile or config.disable or ( +<<<<<<< HEAD should_skip_due_to_torch_dispatch_mode() and not getattr(self._torchdynamo_orig_backend, "_export", False) +======= + is_in_torch_dispatch_mode(include_infra_modes=False) + and not getattr(self._torchdynamo_orig_callable, "_export", False) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) ): if log.isEnabledFor(logging.DEBUG): @@ -1837,6 +2260,7 @@ def __call__( ) return ConvertFrameReturn() +<<<<<<< HEAD if ( frame.f_code.co_filename == "" and frame.f_code.co_name == "__new__" ) or ( @@ -1844,6 +2268,10 @@ def __call__( and frame.f_code.co_name == "_make" ): # nametuple constructor/_make +======= + if frame.f_code.co_filename == "" and frame.f_code.co_name == "__new__": + # nametuple constructor +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ConvertFrameReturn() if torch._dynamo.utils.get_optimize_ddp_mode() == "ddp_optimizer": ddp_module = DistributedDataParallel._get_active_ddp_module() @@ -1853,15 +2281,26 @@ def __call__( ddp_optimizer = DDPOptimizer( bucket_bytes_cap=ddp_module.bucket_bytes_cap, +<<<<<<< HEAD backend_compile_fn=self._torchdynamo_orig_backend._torchdynamo_orig_backend, # type: ignore[attr-defined] ) assert hasattr( self._torchdynamo_orig_backend, "_clone_with_backend" +======= + backend_compile_fn=self._torchdynamo_orig_callable._torchdynamo_orig_callable, # type: ignore[attr-defined] + ) + assert hasattr( + self._torchdynamo_orig_callable, "_clone_with_backend" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), ( "DDPOptimizer only supports callback fns that know how to clone themselves." ) hijacked_callback = ( +<<<<<<< HEAD self._torchdynamo_orig_backend._clone_with_backend( +======= + self._torchdynamo_orig_callable._clone_with_backend( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ddp_optimizer.compile_fn, ) ) @@ -1871,10 +2310,16 @@ def __call__( with compile_lock, _disable_current_modes(): # skip=1: skip this frame +<<<<<<< HEAD result = self._torchdynamo_orig_backend( frame, cache_entry, self.hooks, frame_state, skip=1 ) return result +======= + return self._torchdynamo_orig_callable( + frame, cache_entry, self.hooks, frame_state, skip=1 + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def catch_errors_wrapper( diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py index 2321213a0a3ba..736c687bbd3a2 100644 --- a/torch/_dynamo/debug_utils.py +++ b/torch/_dynamo/debug_utils.py @@ -1,3 +1,9 @@ +<<<<<<< HEAD +======= +# mypy: allow-untyped-defs +# mypy: disable-error-code="method-assign" + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Debug utilities for TorchDynamo compilation and execution. @@ -16,8 +22,11 @@ - BuckTargetWriter: Manages Buck build system integration """ +<<<<<<< HEAD from __future__ import annotations +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import atexit import copy import cProfile @@ -34,14 +43,21 @@ import textwrap from collections import Counter from importlib import import_module +<<<<<<< HEAD from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar +======= +from typing import Any, Callable, Optional, TypeVar +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch._prims_common as utils import torch._subclasses.meta_utils from torch import Tensor from torch._dynamo.testing import rand_strided +<<<<<<< HEAD from torch._inductor.cpp_builder import normalize_path_separator +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._prims_common import is_float_dtype from torch.multiprocessing.reductions import StorageWeakRef from torch.utils._content_store import ContentStoreReader, ContentStoreWriter @@ -50,6 +66,7 @@ from .utils import clone_inputs, get_debug_dir +<<<<<<< HEAD if TYPE_CHECKING: from collections.abc import Sequence @@ -57,6 +74,8 @@ from torch.storage import UntypedStorage +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) log = logging.getLogger(__name__) T = TypeVar("T") @@ -71,7 +90,10 @@ extra_deps = [] extra_imports = "" +<<<<<<< HEAD cur_target = "" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if use_buck: extra_deps = [ "//caffe2/torch/fb/sparsenn:sparsenn_operators_gpu", @@ -87,7 +109,11 @@ class BuckTargetWriter: +<<<<<<< HEAD def __init__(self, filename: str) -> None: +======= + def __init__(self, filename): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.subdir, self.py_file = os.path.split(os.path.abspath(filename)) self.target = self.py_file.replace(".py", "") @@ -101,7 +127,11 @@ def __init__(self, filename: str) -> None: tmp = tmp[tmp.find("fbcode/") :][7:] self.cmd_line_path = f"//{tmp}:{self.target}" +<<<<<<< HEAD def build(self) -> str: +======= + def build(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) extra_cpp_deps = "\n".join([f' "{x}",' for x in extra_deps]) return textwrap.dedent( f""" @@ -127,7 +157,11 @@ def build(self) -> str: """ ) +<<<<<<< HEAD def write(self, print_msg: bool = True) -> list[str]: +======= + def write(self, print_msg=True): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) target_file = os.path.join(self.subdir, "TARGETS") with open(target_file, "w") as fd: fd.write(self.build()) @@ -141,7 +175,11 @@ def write(self, print_msg: bool = True) -> list[str]: return cmd_split +<<<<<<< HEAD def minifier_dir() -> str: +======= +def minifier_dir(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) path = os.path.join(get_debug_dir(), "minifier") if path is None: path = f"{tempfile.gettempdir()}/minifier_{getpass.getuser()}" @@ -179,7 +217,11 @@ class NNModuleToString: ] @staticmethod +<<<<<<< HEAD def can_convert_to_string(gm: torch.fx.GraphModule) -> bool: +======= + def can_convert_to_string(gm): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cant_convert = set() for _, module in gm.named_children(): if type(module) not in NNModuleToString.safe_reprs: @@ -191,7 +233,11 @@ def can_convert_to_string(gm: torch.fx.GraphModule) -> bool: return True @staticmethod +<<<<<<< HEAD def convert(gm: torch.fx.GraphModule) -> str: +======= + def convert(gm): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.nn.modules.module import _addindent tab = " " * 4 @@ -256,7 +302,11 @@ def __init__(self) -> None: @functools.cache # subprocess is expensive +<<<<<<< HEAD def _cuda_system_info_comment() -> str: +======= +def _cuda_system_info_comment(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not torch.cuda.is_available(): return "# torch.cuda.is_available()==False, no GPU info collected\n" @@ -280,7 +330,11 @@ def _cuda_system_info_comment() -> str: return model_str +<<<<<<< HEAD def generate_env_vars_string(*, stable_output: bool = False) -> str: +======= +def generate_env_vars_string(*, stable_output=False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Generate a string configuration for environment variables related to Dynamo, Inductor, and Triton. """ @@ -290,7 +344,11 @@ def generate_env_vars_string(*, stable_output: bool = False) -> str: allow_list = ["TORCH", "DYNAMO", "INDUCTOR", "TRITON"] skip_list = ["TRITON_LIBDEVICE_PATH", "TRITON_PTXAS_PATH", "TRITON_LIBCUDA_PATH"] +<<<<<<< HEAD def filter(key: str) -> bool: +======= + def filter(key): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return any(string in key for string in allow_list) and key not in skip_list config_lines = [ @@ -299,6 +357,7 @@ def filter(key: str) -> bool: if filter(key) ] config_string = "\n".join(config_lines) +<<<<<<< HEAD return normalize_path_separator(f"""\ import os {config_string} @@ -306,6 +365,15 @@ def filter(key: str) -> bool: def generate_config_string(*, stable_output: bool = False) -> str: +======= + return f"""\ +import os +{config_string} + """ + + +def generate_config_string(*, stable_output=False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch._functorch.config import torch._inductor.config @@ -325,11 +393,19 @@ def generate_config_string(*, stable_output: bool = False) -> str: """ +<<<<<<< HEAD def get_minifier_repro_path() -> str: return os.path.join(minifier_dir(), "minifier_launcher.py") def helper_for_dump_minify(contents: str) -> None: +======= +def get_minifier_repro_path(): + return os.path.join(minifier_dir(), "minifier_launcher.py") + + +def helper_for_dump_minify(contents): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) minified_repro_path = get_minifier_repro_path() log.warning("Writing minified repro to:\n%s", minified_repro_path) @@ -348,7 +424,11 @@ class AccuracyError(Exception): pass +<<<<<<< HEAD def clone_inputs_retaining_gradness(example_inputs: Sequence[Any]) -> list[Any]: +======= +def clone_inputs_retaining_gradness(example_inputs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ This clone inputs is different from utils clone_input. In case of minifier, all the tensors are leaf tensors while creating a new graph. So, we set the @@ -358,6 +438,7 @@ def clone_inputs_retaining_gradness(example_inputs: Sequence[Any]) -> list[Any]: for idx in range(len(example_inputs)): if isinstance(cloned_inputs[idx], torch.Tensor): cloned_inputs[idx].requires_grad_(example_inputs[idx].requires_grad) +<<<<<<< HEAD return cloned_inputs # type: ignore[return-value] @@ -367,6 +448,12 @@ def run_fwd_maybe_bwd( only_fwd: bool = False, disable_clone: bool = False, ) -> Any: +======= + return cloned_inputs + + +def run_fwd_maybe_bwd(gm, args, only_fwd=False, disable_clone=False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Runs a forward and possibly backward iteration for a given mod and args. @@ -394,6 +481,7 @@ def run_fwd_maybe_bwd( def same_two_models( +<<<<<<< HEAD gm: torch.fx.GraphModule, opt_gm: torch.fx.GraphModule, example_inputs: Sequence[Any], @@ -402,6 +490,16 @@ def same_two_models( require_fp64: bool = False, ignore_non_fp: bool = False, ) -> bool: +======= + gm, + opt_gm, + example_inputs, + only_fwd=False, + *, + require_fp64=False, + ignore_non_fp=False, +): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Check two models have same accuracy. @@ -451,7 +549,11 @@ def same_two_models( return passing +<<<<<<< HEAD def cast_dtype_args_to_fp64(model: torch.fx.GraphModule) -> torch.fx.GraphModule: +======= +def cast_dtype_args_to_fp64(model): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for node in model.graph.nodes: if ( node.op == "call_function" @@ -472,9 +574,13 @@ def cast_dtype_args_to_fp64(model: torch.fx.GraphModule) -> torch.fx.GraphModule return model +<<<<<<< HEAD def cast_to( dtype: torch.dtype, model: torch.fx.GraphModule, inputs: list[Any] ) -> tuple[torch.fx.GraphModule, list[Any]]: +======= +def cast_to(dtype, model, inputs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.utils._pytree import tree_map model = model.to(dtype) @@ -492,13 +598,18 @@ def cast_to( return model, inputs +<<<<<<< HEAD def cast_to_fp64( model: torch.fx.GraphModule, inputs: list[Any] ) -> tuple[torch.fx.GraphModule, list[Any]]: +======= +def cast_to_fp64(model, inputs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return cast_to(torch.float64, model, inputs) def backend_accuracy_fails( +<<<<<<< HEAD gm: torch.fx.GraphModule, example_inputs: Sequence[Any], compiler_fn: Callable[[torch.fx.GraphModule, list[Any]], torch.fx.GraphModule], @@ -507,6 +618,16 @@ def backend_accuracy_fails( require_fp64: bool = False, ignore_non_fp: bool = False, ) -> bool: +======= + gm, + example_inputs, + compiler_fn, + only_fwd=False, + *, + require_fp64=False, + ignore_non_fp=False, +): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: compiled_gm = compiler_fn( copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs) @@ -540,10 +661,17 @@ def backend_accuracy_fails( def _stride_or_default( +<<<<<<< HEAD stride: Optional[torch._prims_common.StrideType], *, shape: torch._prims_common.ShapeType, ) -> torch._prims_common.StrideType: +======= + stride: Optional["torch._prims_common.StrideType"], + *, + shape: "torch._prims_common.ShapeType", +) -> "torch._prims_common.StrideType": +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return stride if stride is not None else utils.make_contiguous_strides_for(shape) @@ -562,6 +690,7 @@ class NopInputReader: def __init__(self) -> None: self.total = 0 +<<<<<<< HEAD def storage( self, storage_hash: Optional[str], @@ -576,13 +705,26 @@ def tensor(self, *args: Any, **kwargs: Any) -> Optional[torch.Tensor]: pass def symint(self, *args: Any, **kwargs: Any) -> Optional[int]: +======= + def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None): + self.total += 1 + + def tensor(self, *args, **kwargs): + pass + + def symint(self, *args, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pass # TODO: Support bundling the entire repro into a zip file for ease of # transferring around class InputReader: +<<<<<<< HEAD def __init__(self, save_dir: Optional[str] = None, *, pbar: Optional[tqdm] = None): +======= + def __init__(self, save_dir=None, *, pbar=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # If None, we will generate random data instead. It's important # to natively support this use case as it will allow people to # share repros without including the real data, if the problem @@ -590,6 +732,7 @@ def __init__(self, save_dir: Optional[str] = None, *, pbar: Optional[tqdm] = Non if save_dir is None: log.warning("no save_dir specified, will generate random data") self.store = ContentStoreReader(save_dir) if save_dir is not None else None +<<<<<<< HEAD self.args: list[Any] = [] self.pbar = pbar @@ -604,6 +747,15 @@ def storage( if self.pbar is not None: self.pbar.update(1) device = _device_or_default(device) # type: ignore[arg-type] +======= + self.args = [] + self.pbar = pbar + + def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None): + if self.pbar is not None: + self.pbar.update(1) + device = _device_or_default(device) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dtype_hint = _dtype_or_default(dtype_hint) if self.store is not None and storage_hash is not None: try: @@ -624,6 +776,7 @@ def storage( def tensor( self, +<<<<<<< HEAD storage: UntypedStorage, shape: torch._prims_common.ShapeType, stride: Optional[torch._prims_common.StrideType] = None, @@ -634,6 +787,18 @@ def tensor( is_leaf: Optional[bool] = None, **metadata: Any, ) -> torch.Tensor: +======= + storage, + shape, + stride=None, + *, + storage_offset=None, + dtype=None, + requires_grad=None, + is_leaf=None, + **metadata, + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) stride = _stride_or_default(stride, shape=shape) storage_offset = _storage_offset_or_default(storage_offset) dtype = _dtype_or_default(dtype) @@ -655,7 +820,11 @@ def tensor( self.args.append(t) return t # for BC +<<<<<<< HEAD def symint(self, val: Any) -> Any: +======= + def symint(self, val): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.args.append(val) return val # for BC @@ -673,8 +842,13 @@ def symint(self, val: Any) -> Any: class InputWriter: +<<<<<<< HEAD def __init__(self, save_dir: Optional[str], *, stable_hash: bool = False) -> None: self._lines: list[str] = [] +======= + def __init__(self, save_dir, *, stable_hash=False): + self._lines = [] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO: consider ensuring tensor and storage counters line up? self.storage_counter = itertools.count() self.save_dir = save_dir @@ -683,9 +857,15 @@ def __init__(self, save_dir: Optional[str], *, stable_hash: bool = False) -> Non if save_dir is not None else None ) +<<<<<<< HEAD self.seen_storages: dict[StorageWeakRef, str] = {} def lines(self) -> list[str]: +======= + self.seen_storages = {} + + def lines(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r = [ "def load_args(reader):", ] @@ -700,6 +880,7 @@ def lines(self) -> list[str]: # of initialization may be appropriate # # If we had a FakeTensor, device_hint tells us what device should be +<<<<<<< HEAD def storage( self, untyped_storage: UntypedStorage, @@ -707,6 +888,9 @@ def storage( device_hint: Optional[torch._prims_common.DeviceLikeType] = None, dtype_hint: Optional[torch.dtype] = None, ) -> str: +======= + def storage(self, untyped_storage, *, dtype_hint=None, device_hint=None) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ws = StorageWeakRef(untyped_storage) v = self.seen_storages.get(ws) if v is not None: @@ -721,7 +905,11 @@ def storage( device = untyped_storage.device if device.type == "meta": assert device_hint is not None +<<<<<<< HEAD device = device_hint # type: ignore[assignment] +======= + device = device_hint +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if _device_or_default(None) != device: maybe_device = f", device={device!r}" nbytes = untyped_storage.nbytes() @@ -734,7 +922,11 @@ def storage( self.seen_storages[ws] = v return v +<<<<<<< HEAD def tensor(self, name: str, t: torch.Tensor) -> None: +======= + def tensor(self, name, t) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq storage = self.storage( @@ -766,7 +958,11 @@ def tensor(self, name: str, t: torch.Tensor) -> None: + f") # {name}" ) +<<<<<<< HEAD def unsupported(self, name: str, arg: Any) -> None: +======= + def unsupported(self, name, arg): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # NB: Try hard not to /print/ a tensor, that will be very slow self._lines.append(f"# {name} was unsupported type for dumping: {type(arg)}") # Best effort dump as much useful stuff we can lol, in case you want @@ -784,13 +980,21 @@ def unsupported(self, name: str, arg: Any) -> None: self._lines.append('"""') # write out that the arg was filtered out as it is constant +<<<<<<< HEAD def const(self, name: str) -> None: +======= + def const(self, name) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._lines.append( f"reader.const({name!r}) # {name}, filtered out during compilation" ) # TODO: this doesn't actually symint atm +<<<<<<< HEAD def symint(self, name: str, val: Any) -> None: +======= + def symint(self, name, val) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(val, torch.SymInt): val = val.node.hint self._lines.append(f"reader.symint({val!r}) # {name}") @@ -819,10 +1023,15 @@ def forward(self, primals_1: "f32[1001, 6]", primals_2: "f32[s0]", primals_3: "S from torch.utils._dtype_abbrs import dtype_abbrs +<<<<<<< HEAD dtype_map: dict[str, torch.dtype] = { value: key for key, value in dtype_abbrs.items() } dtype_pattern: str = "|".join(dtype_abbrs.values()) +======= + dtype_map = {value: key for key, value in dtype_abbrs.items()} + dtype_pattern = "|".join(dtype_abbrs.values()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Extracting the source code from the function source = inspect.getsource(func) @@ -838,6 +1047,7 @@ class TensorContainer: # Dictionary for tensors from annotations kwargs: dict[str, Any] = {} +<<<<<<< HEAD sym_shapes_dict: dict[str, int] = sym_shapes or {} def get_sym_int(symint: str) -> int: @@ -848,11 +1058,27 @@ def get_sym_int(symint: str) -> int: return sym_shapes_dict.get(symint, default_sym_shape) # type: ignore[return-value] def gen_tensor(shape: torch._prims_common.ShapeType, dtype: torch.dtype) -> Tensor: +======= + sym_shapes = sym_shapes or {} + + def get_sym_int(symint): + torch._check( + symint in sym_shapes or default_sym_shape is not None, + lambda: f"{symint} not in symbolic_shapes and default sym shape not passed in", + ) + return sym_shapes.get(symint, default_sym_shape) + + def gen_tensor(shape, dtype) -> Tensor: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Resolve symbolic shapes to concrete values resolved_shape = [] dynamic_dims = [] for i, dim in enumerate(shape): +<<<<<<< HEAD dim = dim.strip() # type: ignore[attr-defined] +======= + dim = dim.strip() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if "s" in dim: s = get_sym_int(dim) resolved_shape.append(s) @@ -907,9 +1133,15 @@ def profile_to_file(filename: str) -> Callable[[T], T]: prof = cProfile.Profile() filename = os.path.abspath(os.path.expanduser(filename)) +<<<<<<< HEAD def decorator(fn: Any) -> Any: @functools.wraps(fn) def wrapper(*args: Any, **kwargs: Any) -> Any: +======= + def decorator(fn): + @functools.wraps(fn) + def wrapper(*args, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) prof.enable() try: return fn(*args, **kwargs) @@ -918,7 +1150,11 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: return wrapper +<<<<<<< HEAD def save_it() -> None: +======= + def save_it(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) prof.dump_stats(filename) sys.stderr.write( textwrap.dedent( diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index 8143a31608d57..07bf47f9fc6be 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -1,3 +1,9 @@ +<<<<<<< HEAD +======= +# mypy: allow-untyped-defs +# ruff: noqa: TCH004 + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ This module provides decorators and utilities for controlling TorchDynamo's behavior during compilation. """ @@ -6,12 +12,19 @@ import inspect import weakref from dataclasses import dataclass +<<<<<<< HEAD from types import TracebackType from typing import Any, Callable, Optional, overload, TYPE_CHECKING, TypeVar, Union from typing_extensions import ParamSpec import torch from torch.compiler import is_compiling +======= +from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union +from typing_extensions import ParamSpec + +import torch +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.utils._contextlib import _DecoratorContextManager from torch.utils._python_dispatch import is_traceable_wrapper_subclass @@ -27,10 +40,18 @@ ) from .exc import IncorrectUsage from .external_utils import ( +<<<<<<< HEAD get_nonrecursive_disable_wrapper, wrap_dunder_call_ctx_manager, ) from .utils import _get_error_on_graph_break, _set_error_on_graph_break, is_function +======= + _dynamo_config_patch_proxy_dunder_call, + get_nonrecursive_disable_wrapper, + is_compiling, +) +from .utils import is_function +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if TYPE_CHECKING: @@ -54,11 +75,17 @@ _P = ParamSpec("_P") _R = TypeVar("_R") +<<<<<<< HEAD FuncType = Callable[..., Any] F = TypeVar("F", bound=FuncType) def run(fn: Optional[Callable[_P, _R]] = None) -> Any: +======= + + +def run(fn=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Don't do any dynamic compiles, just use prior optimizations""" if fn is not None: fn = innermost_fn(fn) @@ -67,7 +94,11 @@ def run(fn: Optional[Callable[_P, _R]] = None) -> Any: return RunOnlyContext() +<<<<<<< HEAD def disable(fn=None, recursive=True, *, reason=None, wrapping=True): # type: ignore[no-untyped-def] +======= +def disable(fn=None, recursive=True, *, reason=None, wrapping=True): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Decorator to disable TorchDynamo @@ -87,7 +118,11 @@ def disable(fn=None, recursive=True, *, reason=None, wrapping=True): # type: ig return DisableContext(msg=reason, wrapping=wrapping) else: +<<<<<<< HEAD def wrap(fn: Callable[_P, _R]) -> Callable[_P, _R]: +======= + def wrap(fn): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fn = innermost_fn(fn) assert callable(fn) @@ -106,7 +141,11 @@ def wrap(fn: Callable[_P, _R]) -> Callable[_P, _R]: skip_code(_nonrecursive_disable_wrapper_code) +<<<<<<< HEAD def skip(fn: Optional[Callable[_P, _R]] = None) -> Callable[..., Any]: +======= +def skip(fn=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Skip frames associated with the function code, but still process recursively invoked frames @@ -116,7 +155,11 @@ def skip(fn: Optional[Callable[_P, _R]] = None) -> Callable[..., Any]: fn = innermost_fn(fn) assert callable(fn) skip_code(fn.__code__) +<<<<<<< HEAD fn._torchdynamo_disable = True # type: ignore[attr-defined] +======= + fn._torchdynamo_disable = True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return fn @@ -134,7 +177,11 @@ def __init__( stance: str = "default", *, skip_guard_eval_unsafe: bool = False, +<<<<<<< HEAD force_backend: Union[str, Callable[..., Any], None] = None, +======= + force_backend=None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: if force_backend is not None and stance != "default": raise RuntimeError("non-default stance cannot have force_backend set") @@ -142,13 +189,18 @@ def __init__( self.stance = DynamoStance(stance, skip_guard_eval_unsafe, force_backend) self.prev = _set_stance(self.stance) +<<<<<<< HEAD def __call__(self, fn: F) -> F: +======= + def __call__(self, fn): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _set_stance(self.prev) wrapper = super().__call__(fn) # forbid wrapper in graph wrapper._dynamo_forbidden = True # type: ignore[attr-defined] return wrapper +<<<<<<< HEAD def __enter__(self) -> None: _set_stance(self.stance) @@ -170,6 +222,24 @@ def assume_constant_result(fn): # type: ignore[no-untyped-def] def allow_in_graph(fn): # type: ignore[no-untyped-def] +======= + def __enter__(self): + _set_stance(self.stance) + + def __exit__(self, exc_type, exc_val, exc_tb): + _set_stance(self.prev) + + def clone(self): + return self.__class__(self.stance.stance, force_backend=self.stance.backend) + + +def assume_constant_result(fn): + fn._dynamo_marked_constant = True + return fn + + +def allow_in_graph(fn): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Tells the compiler frontend (Dynamo) to skip symbolic introspection of the function and instead directly write it to the graph when encountered. @@ -187,14 +257,22 @@ def allow_in_graph(fn): # type: ignore[no-untyped-def] trace_rules._allowed_callable_ids.add(fn_id) # Avoid id reuse which creates subtle bugs. +<<<<<<< HEAD def deregister() -> None: +======= + def deregister(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) trace_rules._allowed_callable_ids.remove(fn_id) weakref.finalize(fn, deregister) return fn +<<<<<<< HEAD def nonstrict_trace(traceable_fn: Callable[_P, _R]) -> Callable[_P, _R]: +======= +def nonstrict_trace(traceable_fn): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Like `allow_in_graph`, but with the following enhancements/differences: # # 1. Supports user-defined class as inputs, as long as the class has been @@ -215,7 +293,11 @@ def nonstrict_trace(traceable_fn: Callable[_P, _R]) -> Callable[_P, _R]: assert callable(traceable_fn), "nonstrict_trace expects a callable" @functools.wraps(traceable_fn) +<<<<<<< HEAD def wrapped(*args: _P.args, **kwargs: _P.kwargs) -> _R: +======= + def wrapped(*args, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return traceable_fn(*args, **kwargs) wrapped_id = id(wrapped) @@ -227,7 +309,11 @@ def wrapped(*args: _P.args, **kwargs: _P.kwargs) -> _R: trace_rules._nonstrict_trace_callable_ids.add(wrapped_id) # Avoid id reuse which creates subtle bugs. +<<<<<<< HEAD def deregister() -> None: +======= + def deregister(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) trace_rules._allowed_callable_ids.remove(wrapped_id) trace_rules._nonstrict_trace_callable_ids.remove(wrapped_id) @@ -236,8 +322,13 @@ def deregister() -> None: return wrapped +<<<<<<< HEAD def _disallow_in_graph_helper(throw_if_not_allowed: bool) -> Callable[..., Any]: def inner(fn: Any) -> Any: +======= +def _disallow_in_graph_helper(throw_if_not_allowed): + def inner(fn): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(fn, (list, tuple)): return [disallow_in_graph(x) for x in fn] assert callable(fn), "disallow_in_graph expects a callable" @@ -259,7 +350,11 @@ def inner(fn: Any) -> Any: return inner +<<<<<<< HEAD def disallow_in_graph(fn: Callable[..., Any]) -> Any: +======= +def disallow_in_graph(fn): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Customize which functions TorchDynamo will exclude in the generated graph and force a graph break on. @@ -285,17 +380,29 @@ def fn(a): @_disallow_in_graph_helper(throw_if_not_allowed=False) +<<<<<<< HEAD def graph_break(msg: str = "") -> None: +======= +def graph_break(msg=""): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Force a graph break""" # NOTE: primarily used for internal debugging purposes! @_disallow_in_graph_helper(throw_if_not_allowed=False) +<<<<<<< HEAD def skip_frame(msg: str = "") -> None: """Force a skipped frame""" def forbid_in_graph(fn: Any) -> Any: +======= +def skip_frame(msg=""): + """Force a skipped frame""" + + +def forbid_in_graph(fn): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Customize which functions TorchDynamo will assert are not present while tracing. @@ -397,9 +504,13 @@ def wrapper(traceable_fn: Callable[_P, _R]) -> Callable[_P, _R]: else: traceable_sig = inspect.signature(traceable_fn) +<<<<<<< HEAD def sig_ident( sig: inspect.Signature, ) -> tuple[tuple[str, ...], set[str], dict[str, Any]]: +======= + def sig_ident(sig): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Ignore annotations for parameters and return type return ( tuple( @@ -479,9 +590,13 @@ def sig_ident( def wrapped(*args: _P.args, **kwargs: _P.kwargs) -> _R: return original_fn(*args, **kwargs) +<<<<<<< HEAD def dispatch_fn( self: VariableBuilder, value: Callable[_P, _R] ) -> PolyfilledFunctionVariable: +======= + def dispatch_fn(self, value: Callable[_P, _R]) -> PolyfilledFunctionVariable: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return PolyfilledFunctionVariable( value, source=self.source, @@ -506,9 +621,13 @@ def dispatch_fn( # Helper function to flatten a tensor subclass and apply a function to # all inner tensors that match the outer dim. Used to reduce duplication # across the various marking APIs. +<<<<<<< HEAD def _apply_func_to_inner_tensors_of_same_dim( func: Callable[..., Any], t: object, *args: Any, **kwargs: Any ) -> None: +======= +def _apply_func_to_inner_tensors_of_same_dim(func, t, *args, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert is_traceable_wrapper_subclass(t) attrs, _ctx = t.__tensor_flatten__() @@ -533,12 +652,16 @@ class directly; instead, use :func:`mark_dynamic`. @forbid_in_graph +<<<<<<< HEAD def mark_unbacked( t: Any, index: Union[int, list[Any], tuple[Any]], strict: bool = False, specialize_on: Optional[list[Any]] = None, ) -> None: +======= +def mark_unbacked(t, index, strict=False, specialize_on=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Mark a tensor as having an unbacked dim. This changes the semantics of operations, we will always report the size does not equal zero/one, we will turn asserts @@ -581,6 +704,7 @@ def mark_unbacked( @forbid_in_graph +<<<<<<< HEAD def mark_dynamic( t: Any, index: Union[int, list[Any], tuple[Any]], @@ -590,6 +714,9 @@ def mark_dynamic( max: Optional[int] = None, specialize_on: Optional[list[Any]] = None, ) -> None: +======= +def mark_dynamic(t, index, *, min=None, max=None, specialize_on=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Mark a tensor as having a dynamic dim and set corresponding min and max range for the dim. @@ -638,16 +765,25 @@ def mark_dynamic( if not hasattr(t, "_dynamo_dynamic_indices"): t._dynamo_dynamic_indices = set() t._dynamo_dynamic_range = set() +<<<<<<< HEAD t._dynamo_hint_overrides = {} +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not hasattr(t, "_specialize_on"): t._specialize_on = {} +<<<<<<< HEAD if hint_override: t._dynamo_hint_overrides[index] = hint_override # TODO(voz): Should we bounds check? t._dynamo_dynamic_indices.add(index) t._dynamo_dynamic_range.add(_DimRange(index, min, max)) # type: ignore[arg-type] +======= + # TODO(voz): Should we bounds check? + t._dynamo_dynamic_indices.add(index) + t._dynamo_dynamic_range.add(_DimRange(index, min, max)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # FX tracers don't respect @forbid_in_graph and choke on the following error since it passes in proxies: # TypeError: 'Attribute' object does not support item assignment @@ -663,7 +799,11 @@ def mark_dynamic( @forbid_in_graph +<<<<<<< HEAD def maybe_mark_dynamic(t: Any, index: Union[int, list[Any], tuple[Any]]) -> None: +======= +def maybe_mark_dynamic(t, index): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Mark a tensor as having a dynamic dim, but don't enforce it (i.e., if this dimension ends up getting specialized, don't error). @@ -685,9 +825,13 @@ def maybe_mark_dynamic(t: Any, index: Union[int, list[Any], tuple[Any]]) -> None maybe_mark_dynamic(t, i) +<<<<<<< HEAD def mark_static( t: Any, index: Optional[Union[int, list[Any], tuple[Any]]] = None ) -> None: +======= +def mark_static(t, index=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Mark a tensor as having a static dim or mark a nn module class as static. @@ -752,7 +896,11 @@ def mark_static( @forbid_in_graph +<<<<<<< HEAD def mark_static_address(t: Any, guard: bool = True) -> None: +======= +def mark_static_address(t, guard=True): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Marks an input tensor whose data_ptr will not change across multiple calls to a dynamo-compiled function. This indicates to cudagraphs that an extra allocation @@ -771,7 +919,11 @@ def mark_static_address(t: Any, guard: bool = True) -> None: # One day, Dynamo will support tracing into einops directly (no allow_in_graph needed) # Note that PyTorch supports multiple versions of einops, so when that day comes, # we still need to be really careful about version matches. +<<<<<<< HEAD def _allow_in_graph_einops() -> None: +======= +def _allow_in_graph_einops(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import einops try: @@ -802,15 +954,24 @@ def _allow_in_graph_einops() -> None: # Proxy class for torch._dynamo.config patching - so dynamo can identify context managers/decorators # created by patch_dynamo_config, compared to ones created by a raw torch._dynamo.config.patch. class DynamoConfigPatchProxy: +<<<<<<< HEAD def __init__(self, config_patch: Any) -> None: self.config_patch = config_patch @property def changes(self) -> dict[str, Any]: +======= + def __init__(self, config_patch): + self.config_patch = config_patch + + @property + def changes(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.config_patch.changes # Decorator implementation that simply sets up `self` as a context manager. # Placed in external_utils so that we can trace through it. +<<<<<<< HEAD __call__ = wrap_dunder_call_ctx_manager def __enter__(self) -> None: @@ -822,6 +983,14 @@ def __exit__( exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: +======= + __call__ = _dynamo_config_patch_proxy_dunder_call + + def __enter__(self): + return self.config_patch.__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.config_patch.__exit__(exc_type, exc_val, exc_tb) @@ -853,7 +1022,11 @@ def __exit__( del config +<<<<<<< HEAD def _patch_dynamo_config_check(changes: dict[str, Any]) -> None: +======= +def _patch_dynamo_config_check(changes: dict[str, Any]): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for k, v in changes.items(): if k not in _allowed_config_patches: raise ValueError( @@ -897,6 +1070,7 @@ def patch_dynamo_config( return DynamoConfigPatchProxy(config_patch) +<<<<<<< HEAD @overload def dont_skip_tracing(fn: None = None) -> DynamoConfigPatchProxy: ... @@ -906,6 +1080,9 @@ def dont_skip_tracing(fn: Callable[_P, _R]) -> Callable[_P, _R]: ... def dont_skip_tracing(fn: Optional[Any] = None) -> Any: +======= +def dont_skip_tracing(fn=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Context manager/decorator to trace into functions intentionally marked by developers to be skipped when tracing. @@ -916,6 +1093,7 @@ def dont_skip_tracing(fn: Optional[Any] = None) -> Any: if fn: return ctx(fn) return ctx +<<<<<<< HEAD class ErrorOnGraphBreakDecoratorContextManager: @@ -955,3 +1133,5 @@ def error_on_graph_break( The default value of torch.compile's `error_on_graph_break` setting is False. """ return ErrorOnGraphBreakDecoratorContextManager(error_on_graph_break) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_dynamo/device_interface.py b/torch/_dynamo/device_interface.py index 26cf4796fd073..3294318d88a17 100644 --- a/torch/_dynamo/device_interface.py +++ b/torch/_dynamo/device_interface.py @@ -1,11 +1,23 @@ +<<<<<<< HEAD +======= +# mypy: allow-untyped-defs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Device abstraction layer for TorchDynamo and Inductor backends. This module provides a unified interface for different hardware backends (CUDA, XPU, +<<<<<<< HEAD CPU, MPS, MTIA) through a common device interface. Key components include: - DeviceInterface: Base class defining the common API for all device types - Device-specific implementations: CudaInterface, XpuInterface, CpuInterface, MpsInterface, MtiaInterface +======= +CPU, MPS) through a common device interface. Key components include: + +- DeviceInterface: Base class defining the common API for all device types +- Device-specific implementations: CudaInterface, XpuInterface, CpuInterface, MpsInterface +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - Device registration system for managing available backends - Worker APIs for multi-processing scenarios - Stream and event management across different devices @@ -17,10 +29,16 @@ import inspect import time +<<<<<<< HEAD from collections import namedtuple from collections.abc import Iterable from dataclasses import dataclass from typing import Any, Callable, Literal, Optional, Union +======= +from collections.abc import Iterable +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch @@ -43,17 +61,29 @@ class DeviceInterface: """ class device: +<<<<<<< HEAD def __new__(cls, device: torch.types.Device) -> Any: raise NotImplementedError class Event: def __new__(cls, *args: Any, **kwargs: Any) -> Any: +======= + def __new__(cls, device: torch.types.Device): + raise NotImplementedError + + class Event: + def __new__(cls, *args, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise NotImplementedError( "Event should be inherited from torch.Event, otherwise, it couldn't be captured by dynamo." ) class Stream: +<<<<<<< HEAD def __new__(cls, *args: Any, **kwargs: Any) -> Any: +======= + def __new__(cls, *args, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise NotImplementedError( "Stream should be inherited from torch.Stream, otherwise, it couldn't be captured by dynamo." ) @@ -67,7 +97,11 @@ class Worker: """ @staticmethod +<<<<<<< HEAD def set_device(device: int) -> None: +======= + def set_device(device: int): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise NotImplementedError @staticmethod @@ -75,6 +109,7 @@ def current_device() -> int: raise NotImplementedError @staticmethod +<<<<<<< HEAD def get_device_properties(device: torch.types.Device = None) -> Any: raise NotImplementedError @@ -84,6 +119,17 @@ def current_device() -> int: @staticmethod def set_device(device: torch.types.Device) -> None: +======= + def get_device_properties(device: torch.types.Device = None): + raise NotImplementedError + + @staticmethod + def current_device(): + raise NotImplementedError + + @staticmethod + def set_device(device: torch.types.Device): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise NotImplementedError @staticmethod @@ -95,7 +141,11 @@ def exchange_device(device: int) -> int: raise NotImplementedError @staticmethod +<<<<<<< HEAD def device_count() -> int: +======= + def device_count(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise NotImplementedError @staticmethod @@ -103,6 +153,7 @@ def is_available() -> bool: raise NotImplementedError @staticmethod +<<<<<<< HEAD def stream(stream: torch.Stream) -> Any: raise NotImplementedError @@ -116,6 +167,21 @@ def set_stream(stream: torch.Stream) -> None: @staticmethod def _set_stream_by_id(stream_id: int, device_index: int, device_type: int) -> None: +======= + def stream(stream: torch.Stream): + raise NotImplementedError + + @staticmethod + def current_stream(): + raise NotImplementedError + + @staticmethod + def set_stream(stream: torch.Stream): + raise NotImplementedError + + @staticmethod + def _set_stream_by_id(stream_id: int, device_index: int, device_type: int): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise NotImplementedError @staticmethod @@ -123,6 +189,7 @@ def get_raw_stream(device_idx: int) -> int: raise NotImplementedError @staticmethod +<<<<<<< HEAD def synchronize(device: torch.types.Device = None) -> None: raise NotImplementedError @@ -136,6 +203,21 @@ def get_compute_capability(device: torch.types.Device = None) -> Any: @staticmethod def is_bf16_supported(including_emulation: bool = False) -> bool: +======= + def synchronize(device: torch.types.Device = None): + raise NotImplementedError + + @classmethod + def get_device_properties(cls, device: torch.types.Device = None): + return cls.Worker.get_device_properties(device) + + @staticmethod + def get_compute_capability(device: torch.types.Device = None): + raise NotImplementedError + + @staticmethod + def is_bf16_supported(including_emulation: bool = False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise NotImplementedError @classmethod @@ -187,11 +269,19 @@ def __init__( self.idx = index self.prev_idx = -1 +<<<<<<< HEAD def __enter__(self) -> None: if self.idx is not None: self.prev_idx = self.device_interface.exchange_device(self.idx) def __exit__(self, type: Any, value: Any, traceback: Any) -> Literal[False]: +======= + def __enter__(self): + if self.idx is not None: + self.prev_idx = self.device_interface.exchange_device(self.idx) + + def __exit__(self, type: Any, value: Any, traceback: Any): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.idx is not None: self.idx = self.device_interface.maybe_exchange_device(self.prev_idx) return False @@ -207,7 +297,11 @@ class CudaInterface(DeviceInterface): class Worker: @staticmethod +<<<<<<< HEAD def set_device(device: int) -> None: +======= + def set_device(device: int): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) caching_worker_current_devices["cuda"] = device @staticmethod @@ -217,7 +311,11 @@ def current_device() -> int: return torch.cuda.current_device() @staticmethod +<<<<<<< HEAD def get_device_properties(device: torch.types.Device = None) -> Any: +======= + def get_device_properties(device: torch.types.Device = None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if device is not None: if isinstance(device, str): device = torch.device(device) @@ -246,8 +344,13 @@ def get_device_properties(device: torch.types.Device = None) -> Any: synchronize = staticmethod(torch.cuda.synchronize) get_device_properties = staticmethod(torch.cuda.get_device_properties) # type: ignore[assignment] get_raw_stream = staticmethod(get_cuda_stream) # type: ignore[assignment, arg-type] +<<<<<<< HEAD exchange_device = staticmethod(torch.cuda._exchange_device) # type: ignore[arg-type, has-type] maybe_exchange_device = staticmethod(torch.cuda._maybe_exchange_device) # type: ignore[arg-type, has-type] +======= + exchange_device = staticmethod(torch.cuda._exchange_device) # type: ignore[arg-type] + maybe_exchange_device = staticmethod(torch.cuda._maybe_exchange_device) # type: ignore[arg-type] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) memory_allocated = staticmethod(torch.cuda.memory_allocated) is_bf16_supported = staticmethod(torch.cuda.is_bf16_supported) # type: ignore[arg-type] @@ -257,7 +360,11 @@ def is_available() -> bool: return torch.cuda.is_available() @staticmethod +<<<<<<< HEAD def get_compute_capability(device: torch.types.Device = None) -> Union[int, str]: +======= + def get_compute_capability(device: torch.types.Device = None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if torch.version.hip is None: major, min = torch.cuda.get_device_capability(device) return major * 10 + min @@ -288,6 +395,7 @@ def raise_if_triton_unavailable(device: torch.types.Device = None) -> None: raise RuntimeError("triton not built with the 'nvidia' backend") +<<<<<<< HEAD get_mtia_stream: Optional[Callable[[int], int]] if torch.mtia._is_compiled(): from torch._C import _mtia_getCurrentRawStream as get_mtia_stream @@ -369,6 +477,8 @@ def raise_if_triton_unavailable(evice: torch.types.Device = None) -> None: raise RuntimeError("triton not built with the 'mtia' backend") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) get_xpu_stream: Optional[Callable[[int], int]] if torch.xpu._is_compiled(): from torch._C import _xpu_getCurrentRawStream as get_xpu_stream @@ -383,7 +493,11 @@ class XpuInterface(DeviceInterface): class Worker: @staticmethod +<<<<<<< HEAD def set_device(device: int) -> None: +======= + def set_device(device: int): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) caching_worker_current_devices["xpu"] = device @staticmethod @@ -393,7 +507,11 @@ def current_device() -> int: return torch.xpu.current_device() @staticmethod +<<<<<<< HEAD def get_device_properties(device: torch.types.Device = None) -> Any: +======= + def get_device_properties(device: torch.types.Device = None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if device is not None: if isinstance(device, str): device = torch.device(device) @@ -432,7 +550,11 @@ def is_available() -> bool: return torch.xpu.is_available() @staticmethod +<<<<<<< HEAD def get_compute_capability(device: torch.types.Device = None) -> Any: +======= + def get_compute_capability(device: torch.types.Device = None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cc = torch.xpu.get_device_capability(device) return cc @@ -445,7 +567,11 @@ def is_triton_capable(device: torch.types.Device = None) -> bool: return True @staticmethod +<<<<<<< HEAD def raise_if_triton_unavailable(device: torch.types.Device = None) -> None: +======= + def raise_if_triton_unavailable(evice: torch.types.Device = None) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import triton.backends if "intel" not in triton.backends.backends: @@ -459,6 +585,7 @@ class CpuDeviceProperties: class CpuInterface(DeviceInterface): class Event(torch.Event): +<<<<<<< HEAD def __init__(self, enable_timing: bool = True) -> None: self.time = 0.0 @@ -466,13 +593,26 @@ def elapsed_time(self, end_event: Any) -> float: return (end_event.time - self.time) * 1000 def record(self, stream: Any = None) -> None: +======= + def __init__(self, enable_timing=True): + self.time = 0.0 + + def elapsed_time(self, end_event) -> float: + return (end_event.time - self.time) * 1000 + + def record(self, stream=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.time = time.perf_counter() class Worker: @staticmethod +<<<<<<< HEAD def get_device_properties( device: torch.types.Device = None, ) -> CpuDeviceProperties: +======= + def get_device_properties(device: torch.types.Device = None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import multiprocessing cpu_count = multiprocessing.cpu_count() @@ -483,7 +623,11 @@ def is_available() -> bool: return True @staticmethod +<<<<<<< HEAD def is_bf16_supported(including_emulation: bool = False) -> bool: +======= + def is_bf16_supported(including_emulation: bool = False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return True @staticmethod @@ -491,6 +635,7 @@ def get_compute_capability(device: torch.types.Device = None) -> str: return "" @staticmethod +<<<<<<< HEAD def get_raw_stream(device_idx: Any) -> int: return 0 @@ -500,6 +645,17 @@ def current_device() -> int: @staticmethod def synchronize(device: torch.types.Device = None) -> None: +======= + def get_raw_stream(device_idx) -> int: + return 0 + + @staticmethod + def current_device(): + return 0 + + @staticmethod + def synchronize(device: torch.types.Device = None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pass @staticmethod @@ -532,7 +688,11 @@ def is_available() -> bool: return torch.backends.mps.is_available() @staticmethod +<<<<<<< HEAD def current_device() -> int: +======= + def current_device(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return 0 @staticmethod @@ -540,11 +700,16 @@ def get_compute_capability(device: torch.types.Device = None) -> str: return "" @staticmethod +<<<<<<< HEAD def synchronize(device: torch.types.Device = None) -> None: +======= + def synchronize(device: torch.types.Device = None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.mps.synchronize() class Worker: @staticmethod +<<<<<<< HEAD def get_device_properties(device: torch.types.Device = None) -> Any: return namedtuple("MPSProperties", ["multi_processor_count"])( torch.backends.mps.get_core_count() # type: ignore[arg-type] @@ -552,6 +717,13 @@ def get_device_properties(device: torch.types.Device = None) -> Any: @staticmethod def current_device() -> int: +======= + def get_device_properties(device: torch.types.Device = None): + return {} + + @staticmethod + def current_device(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return 0 @@ -561,7 +733,11 @@ def current_device() -> int: def register_interface_for_device( device: Union[str, torch.device], device_interface: type[DeviceInterface] +<<<<<<< HEAD ) -> None: +======= +): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(device, torch.device): device = device.type device_interfaces[device] = device_interface @@ -583,7 +759,11 @@ def get_registered_device_interfaces() -> Iterable[tuple[str, type[DeviceInterfa return device_interfaces.items() +<<<<<<< HEAD def init_device_reg() -> None: +======= +def init_device_reg(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) global _device_initialized register_interface_for_device("cuda", CudaInterface) for i in range(torch.cuda.device_count()): @@ -593,10 +773,13 @@ def init_device_reg() -> None: for i in range(torch.xpu.device_count()): register_interface_for_device(f"xpu:{i}", XpuInterface) +<<<<<<< HEAD register_interface_for_device("mtia", MtiaInterface) for i in range(torch.mtia.device_count()): register_interface_for_device(f"mtia:{i}", MtiaInterface) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) register_interface_for_device("cpu", CpuInterface) register_interface_for_device("mps", MpsInterface) diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 177541e8f3341..f2551e4d5fb01 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -1,3 +1,7 @@ +<<<<<<< HEAD +======= +# mypy: allow-untyped-defs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # mypy: disable-error-code="method-assign" """ @@ -36,7 +40,10 @@ import threading import traceback import types +<<<<<<< HEAD import unittest +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import warnings import weakref from dataclasses import dataclass @@ -103,6 +110,7 @@ ) from .hooks import Hooks from .mutation_guard import install_generation_tagging_init +<<<<<<< HEAD from .utils import ( _get_error_on_graph_break, _set_error_on_graph_break, @@ -126,6 +134,15 @@ GuardFail, GuardFilterEntry, ) +======= +from .utils import common_constant_types, compile_times + + +if TYPE_CHECKING: + from torch._subclasses import fake_tensor + + from .types import CacheEntry, DynamoCallback +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) log = logging.getLogger(__name__) @@ -145,7 +162,11 @@ class Unset(Enum): unset = Unset.token +<<<<<<< HEAD def _maybe_set_eval_frame(callback: DynamoCallback) -> DynamoCallback: +======= +def _maybe_set_eval_frame(callback: DynamoCallback): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # A wrapper on set_eval_frame that is guarded by a Justknob. # Users can disable torchDynamo by setting the JK to False. if not justknobs_check("pytorch/compiler:enable_compiler_set_eval_frame"): @@ -187,7 +208,11 @@ def _set_stance(stance: DynamoStance) -> DynamoStance: _EXAMPLE_INPUTS: Optional[dict[str, list[Any]]] = None +<<<<<<< HEAD def get_example_inputs(key: str) -> list[Any]: +======= +def get_example_inputs(key) -> list[Any]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) global _EXAMPLE_INPUTS if _EXAMPLE_INPUTS is None: _EXAMPLE_INPUTS = {} @@ -198,7 +223,11 @@ def get_example_inputs(key: str) -> list[Any]: return _EXAMPLE_INPUTS[key] +<<<<<<< HEAD def _callback_from_stance(callback: DynamoCallback) -> DynamoCallback: +======= +def _callback_from_stance(callback): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if _stance.stance == "default": # force_backend if _stance.backend is not None and callback not in (False, None): @@ -223,6 +252,7 @@ def _callback_from_stance(callback: DynamoCallback) -> DynamoCallback: if callback in (False, None): return callback +<<<<<<< HEAD def fail_callback( frame: DynamoFrameType, *args: Any, **kwargs: Any ) -> ConvertFrameReturn: @@ -248,15 +278,30 @@ def fail_callback( # to prevent cache miss due to different backend fail_callback._torchdynamo_orig_backend = callback # type: ignore[attr-defined] +======= + def fail_callback(frame, *args, **kwargs): + if trace_rules.check(frame.f_code): + return ConvertFrameReturn() + raise RuntimeError( + "Detected recompile when torch.compile stance is 'fail_on_recompile'" + ) + + # to prevent cache miss due to different callback + fail_callback._torchdynamo_orig_callable = callback # type: ignore[attr-defined] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return fail_callback else: raise RuntimeError(f"invalid torch.compile stance '{_stance}'") +<<<<<<< HEAD def _create_wrapped_callback( compiler_fn: CompilerFn, ) -> convert_frame.CatchErrorsWrapper: +======= +def _create_wrapped_callback(compiler_fn): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) hooks = Hooks() return convert_frame.catch_errors_wrapper( convert_frame.convert_frame( # type: ignore[arg-type] @@ -267,7 +312,11 @@ def _create_wrapped_callback( ) +<<<<<<< HEAD def _get_or_add_example_inputs(frame: DynamoFrameType) -> list[Any]: +======= +def _get_or_add_example_inputs(frame): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) key = frame.f_code.co_filename + str(frame.f_code.co_firstlineno) example_inputs = get_example_inputs(key) @@ -277,10 +326,15 @@ def _get_or_add_example_inputs(frame: DynamoFrameType) -> list[Any]: return example_inputs +<<<<<<< HEAD def _create_delayed_compile_callback( callback: DynamoCallback, stance: str ) -> Callable[..., Any]: def callback_fn(*args: Any, **kwargs: Any) -> convert_frame.ConvertFrameReturn: +======= +def _create_delayed_compile_callback(callback, stance): + def callback_fn(*args, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) frame = args[0] example_inputs = _get_or_add_example_inputs(frame) @@ -297,6 +351,7 @@ def callback_fn(*args: Any, **kwargs: Any) -> convert_frame.ConvertFrameReturn: dynamism = track_dynamism_across_examples(example_inputs) code_context.get_context(frame.f_code)["dynamism"] = dynamism +<<<<<<< HEAD compiler_fn = callback._torchdynamo_orig_backend._torchdynamo_orig_backend # type: ignore[union-attr] return _create_wrapped_callback(compiler_fn)(*args, **kwargs) @@ -311,6 +366,19 @@ def _is_skip_guard_eval_unsafe_stance() -> bool: def _reset_guarded_backend_cache() -> None: +======= + compiler_fn = callback._torchdynamo_orig_callable._torchdynamo_orig_callable + return _create_wrapped_callback(compiler_fn)(*args, **kwargs) + + return callback_fn + + +def _is_skip_guard_eval_unsafe_stance(): + return _stance.skip_guard_eval_unsafe + + +def _reset_guarded_backend_cache(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) global cached_backends for backend in cached_backends.values(): if hasattr(backend, "reset"): @@ -358,7 +426,11 @@ class OptimizedModule(torch.nn.Module): "_super_module_initialized", } +<<<<<<< HEAD def __init__(self, mod: torch.nn.Module, dynamo_ctx: _TorchDynamoContext) -> None: +======= + def __init__(self, mod: torch.nn.Module, dynamo_ctx) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # NOTE: this must go first, because attribute reads/writes of `self` # uses `_orig_mod`, and sometimes users override `Module.__init__` to # do attribute reads/writes on `self`. @@ -376,7 +448,11 @@ def __init__(self, mod: torch.nn.Module, dynamo_ctx: _TorchDynamoContext) -> Non self._initialize() self.training = self._orig_mod.training +<<<<<<< HEAD def _initialize(self) -> None: +======= + def _initialize(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Do this stuff in constructor to lower overhead slightly if isinstance(self.dynamo_ctx, DisableContext): # No need to check trace rules @@ -400,7 +476,11 @@ def _initialize(self) -> None: self._forward = self.forward self.forward = self._call_lazy_check +<<<<<<< HEAD def __call__(self, *args: Any, **kwargs: Any) -> Any: +======= + def __call__(self, *args, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if torch.nn.modules.module._has_any_global_hook(): warnings.warn( "Using `torch.compile(module)` when there are global hooks on " @@ -413,39 +493,66 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: ) return super().__call__(*args, **kwargs) +<<<<<<< HEAD def __reduce__( self, ) -> tuple[type[OptimizedModule], tuple[torch.nn.Module, _TorchDynamoContext]]: return (self.__class__, (self._orig_mod, self.dynamo_ctx)) def __getstate__(self) -> dict[str, Any]: +======= + def __reduce__(self): + return (self.__class__, (self._orig_mod, self.dynamo_ctx)) + + def __getstate__(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) state = dict(self.__dict__) state.pop("forward", None) state.pop("__call__", None) return state +<<<<<<< HEAD def __setstate__(self, state: dict[str, Any]) -> None: +======= + def __setstate__(self, state): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.__dict__ = state self._initialize() @property +<<<<<<< HEAD def training(self) -> bool: return self._orig_mod.training @training.setter def training(self, value: bool) -> None: +======= + def training(self): + return self._orig_mod.training + + @training.setter + def training(self, value): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Ignore the `training` mutation in `super().__init__()`, since that's # setting the default on `nn.Module`, but we are mirroring the # `training` attr in `self._orig_mod`. if self._super_module_initialized: self._orig_mod.training = value +<<<<<<< HEAD def __getattr__(self, name: str) -> Any: +======= + def __getattr__(self, name): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if name == "_orig_mod": return self._modules["_orig_mod"] return getattr(self._orig_mod, name) +<<<<<<< HEAD def __setattr__(self, name: str, val: Any) -> None: +======= + def __setattr__(self, name, val) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Allow patching over class attributes if hasattr(type(self), name): return super().__setattr__(name, val) @@ -454,7 +561,11 @@ def __setattr__(self, name: str, val: Any) -> None: return super().__setattr__(name, val) return setattr(self._orig_mod, name, val) +<<<<<<< HEAD def __delattr__(self, name: str) -> None: +======= + def __delattr__(self, name): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This mirrors `__setattr__` if hasattr(type(self), name): return super().__delattr__(name) @@ -463,7 +574,11 @@ def __delattr__(self, name: str) -> None: return super().__delattr__(name) return delattr(self._orig_mod, name) +<<<<<<< HEAD def _call_lazy_check(self, *args: Any, **kwargs: Any) -> Any: +======= + def _call_lazy_check(self, *args, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if ( hasattr(self._orig_mod, "_initialize_hook") and hasattr(self._orig_mod, "_infer_parameters") @@ -476,14 +591,22 @@ def _call_lazy_check(self, *args: Any, **kwargs: Any) -> Any: self._orig_mod._infer_parameters(self._orig_mod, args, kwargs) return self._forward(*args, **kwargs) +<<<<<<< HEAD def __dir__(self) -> list[str]: +======= + def __dir__(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) orig_mod_attrs = self._orig_mod.__dir__() return orig_mod_attrs + [ attr for attr in super().__dir__() if attr not in orig_mod_attrs ] +<<<<<<< HEAD def remove_from_cache(f: Any) -> None: +======= +def remove_from_cache(f): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Make sure f.__code__ is not cached to force a recompile """ @@ -500,6 +623,7 @@ def remove_from_cache(f: Any) -> None: log.warning("could not determine __code__ for %s", f) +<<<<<<< HEAD def nothing() -> None: pass @@ -511,21 +635,41 @@ def always_false() -> bool: def innermost_fn( fn: Callable[..., Any], unaltered_fn_attr: str = "_torchdynamo_orig_callable" ) -> Callable[..., Any]: +======= +def nothing(): + pass + + +def always_false(): + return False + + +def innermost_fn(fn): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ In case of nesting of _TorchDynamoContext calls, find the innermost function. TorchDynamo caches on fn.__code__ object, so its necessary to find the innermost function to pass on the optimize, run, disable etc. """ unaltered_fn = fn +<<<<<<< HEAD while hasattr(unaltered_fn, unaltered_fn_attr): unaltered_fn = getattr(unaltered_fn, unaltered_fn_attr) +======= + while hasattr(unaltered_fn, "_torchdynamo_orig_callable"): + unaltered_fn = unaltered_fn._torchdynamo_orig_callable +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert callable(unaltered_fn), ( f"A callable function is expected, but {type(unaltered_fn)} is provided." ) return unaltered_fn +<<<<<<< HEAD def make_set_enable_dynamic(enable: bool) -> Any: +======= +def make_set_enable_dynamic(enable: bool): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(enable, bool) if enable: # Assume everything is dynamic by default @@ -547,12 +691,20 @@ class DynamoTLS(threading.local): dynamo_tls = DynamoTLS() +<<<<<<< HEAD def clear_dynamo_tls() -> None: +======= +def clear_dynamo_tls(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dynamo_tls.traced_frame_infos.clear() @atexit.register +<<<<<<< HEAD def _log_traced_frames() -> None: +======= +def _log_traced_frames(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ At program exit, log all of the frames Dynamo has attempted to trace from, excluding the continuation frames generated by Dynamo. @@ -563,23 +715,45 @@ def _log_traced_frames() -> None: log.info(msg) +<<<<<<< HEAD def guard_collectives_hook(guard_eval_result: bool) -> bool: +======= +def guard_collectives_hook(guard_eval_result): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch.distributed as dist from torch._dynamo.utils import dynamo_timed # guard_eval_result == True ==> cache hit if pg := distributed.get_guard_pg(): with dynamo_timed( +<<<<<<< HEAD "guard_collective", log_pt2_compile_event=False, log_waitcounter=True ): log.debug("guard_collective %s", guard_eval_result) +======= + "guard_collective", log_pt2_compile_event=True, log_waitcounter=True + ): + log.info("guard_collective %s", guard_eval_result) + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "guard_collective", + "encoding": "string", + }, + payload_fn=lambda: str(guard_eval_result), + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO: a bit awkward to time, this isn't inside of the dynamo compile region all_results = [None] * pg.size() dist.all_gather_object(all_results, guard_eval_result, group=pg) # True = everyone hit, OK to run # False = someone missed, force recompile everywhere res = all(all_results) +<<<<<<< HEAD log.debug("guard_collective %s -> %s", guard_eval_result, res) +======= + log.info("guard_collective %s -> %s", guard_eval_result, res) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return res return guard_eval_result @@ -591,6 +765,7 @@ class _TorchDynamoContext: def __init__( self, callback: DynamoCallback, +<<<<<<< HEAD on_enter: Callable[[], Any] = nothing, backend_ctx_ctor: Callable[ [], contextlib.AbstractContextManager[Any] @@ -605,6 +780,17 @@ def __init__( compiler_config: Optional[Any] = None, package: Optional[CompilePackage] = None, hooks: Optional[Hooks] = None, +======= + on_enter=nothing, + backend_ctx_ctor=null_context, + patch_fn=nothing, + first_ctx=False, + *, + export=False, + dynamic=None, + compiler_config=None, + package=None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: super().__init__() assert callable(callback) or callback is False or callback is None @@ -612,27 +798,42 @@ def __init__( self._backend_ctx_ctor = backend_ctx_ctor self.prior: Union[Unset, DynamoCallback] = unset self.first_ctx = first_ctx +<<<<<<< HEAD self.fullgraph = fullgraph self.error_on_graph_break = error_on_graph_break +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.export = export self._dynamic = dynamic self.compiler_config = compiler_config self.cleanup_fns: list[Callable[[], Any]] = [] self.enter_exit_hooks = [] self._package = package +<<<<<<< HEAD self._hooks = hooks patch_fn() # Save the backends so that we can reset them during torch._dynamo.reset backend = innermost_fn(callback, unaltered_fn_attr="_torchdynamo_orig_backend") # type: ignore[arg-type] cached_backends.setdefault(id(backend), backend) # type: ignore[arg-type] +======= + patch_fn() + + # Save the backends so that we can reset them during torch._dynamo.reset + backend = innermost_fn(callback) + cached_backends.setdefault(id(backend), backend) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if dynamic is not None: self.enter_exit_hooks.append(make_set_enable_dynamic(dynamic)) if on_enter is not nothing: # this case is not common +<<<<<<< HEAD def call_on_enter() -> Callable[[], None]: +======= + def call_on_enter(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) on_enter() return nothing @@ -640,14 +841,22 @@ def call_on_enter() -> Callable[[], None]: if backend_ctx_ctor is not contextlib.nullcontext: # this case is not common +<<<<<<< HEAD def call_backend_ctx() -> functools.partial[Optional[bool]]: +======= + def call_backend_ctx(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ctx = backend_ctx_ctor() ctx.__enter__() return functools.partial(ctx.__exit__, None, None, None) self.enter_exit_hooks.append(call_backend_ctx) +<<<<<<< HEAD def __enter__(self) -> None: +======= + def __enter__(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if config.raise_on_ctx_manager_usage: raise RuntimeError( "torch._dynamo.optimize(...) is used with a context manager. " @@ -661,12 +870,16 @@ def __enter__(self) -> None: ) _maybe_set_eval_frame(_callback_from_stance(self.callback)) +<<<<<<< HEAD def __exit__( self, exc_type: Optional[type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[types.TracebackType], ) -> Optional[bool]: +======= + def __exit__(self, exc_type, exc_val, exc_tb): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert self.prior is not unset set_eval_frame(None) set_skip_guard_eval_unsafe(self.prior_skip_guard_eval_unsafe) @@ -675,6 +888,7 @@ def __exit__( self.cleanup_fns.clear() _maybe_set_eval_frame(_callback_from_stance(self.prior)) self.prior = unset +<<<<<<< HEAD return None def __call__(self, fn: Any) -> Any: @@ -725,6 +939,16 @@ def aot_compile(example_inputs: tuple[tuple[Any, ...], dict[str, Any]]) -> Any: ), ) +======= + + def __call__(self, fn): + # public api for compiler config/options + def get_compiler_config(): + return self.compiler_config + + fn = innermost_fn(fn) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # add context containing GraphModule to any GraphModule forward functions if isinstance(fn, GraphModule): # add context containing GraphModule to any GraphModule forward functions @@ -765,9 +989,13 @@ def aot_compile(example_inputs: tuple[tuple[Any, ...], dict[str, Any]]) -> Any: filename = inspect.getsourcefile(fn) except TypeError: filename = None +<<<<<<< HEAD if config.debug_force_nested_calls: fn = external_utils.wrap_inline(fn) elif config.wrap_top_frame or ( +======= + if config.wrap_top_frame or ( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) (filename is None or trace_rules.check(fn)) and ( getattr(fn, "__name__", "") @@ -778,6 +1006,7 @@ def aot_compile(example_inputs: tuple[tuple[Any, ...], dict[str, Any]]) -> Any: # call to a builtin without a frame for us to capture fn = external_utils.wrap_inline(fn) +<<<<<<< HEAD def do_nothing(*arg: Any, **kwargs: Any) -> None: pass @@ -793,6 +1022,24 @@ def compile_wrapper(*args: Any, **kwargs: Any) -> Any: prior = set_eval_frame(None) try: if is_fx_symbolic_tracing(): +======= + def do_nothing(*arg, **kwargs): + pass + + if hasattr(self, "callback"): + callback = self.callback + else: + callback = do_nothing + + is_jit_tracing = torch._C._is_tracing + is_fx_tracing = torch.fx._symbolic_trace.is_fx_tracing + + @functools.wraps(fn) + def compile_wrapper(*args, **kwargs): + prior = set_eval_frame(None) + try: + if is_fx_tracing(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if config.error_on_nested_fx_trace: raise RuntimeError( "Detected that you are using FX to symbolically trace " @@ -811,10 +1058,13 @@ def compile_wrapper(*args: Any, **kwargs: Any) -> Any: prior_skip_guard_eval_unsafe = set_skip_guard_eval_unsafe( _is_skip_guard_eval_unsafe_stance() ) +<<<<<<< HEAD prior_error_on_graph_break = None if not self.fullgraph and self.error_on_graph_break is not None: prior_error_on_graph_break = _get_error_on_graph_break() _set_error_on_graph_break(self.error_on_graph_break) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Ensure that if an assertion occurs after graph pushes # something onto the DynamicLayerStack then we pop it off (the @@ -825,7 +1075,10 @@ def compile_wrapper(*args: Any, **kwargs: Any) -> Any: saved_dynamic_layer_stack_depth = ( torch._C._functorch.get_dynamic_layer_stack_depth() ) +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _maybe_set_eval_frame(_callback_from_stance(callback)) try: @@ -846,8 +1099,11 @@ def compile_wrapper(*args: Any, **kwargs: Any) -> Any: finally: # Restore the dynamic layer stack depth if necessary. set_eval_frame(None) +<<<<<<< HEAD if prior_error_on_graph_break is not None: _set_error_on_graph_break(prior_error_on_graph_break) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth( saved_dynamic_layer_stack_depth ) @@ -859,6 +1115,7 @@ def compile_wrapper(*args: Any, **kwargs: Any) -> Any: _maybe_set_eval_frame(prior) # hooks to properly handle inlining +<<<<<<< HEAD if self.error_on_graph_break is not None: compile_wrapper._torchdynamo_inline = ( # type: ignore[attr-defined] external_utils.wrap_inline_with_error_on_graph_break( @@ -867,6 +1124,9 @@ def compile_wrapper(*args: Any, **kwargs: Any) -> Any: ) else: compile_wrapper._torchdynamo_inline = fn # type: ignore[attr-defined] +======= + compile_wrapper._torchdynamo_inline = fn # type: ignore[attr-defined] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Save the function pointer to find the original callable while nesting # of decorators. @@ -876,8 +1136,11 @@ def compile_wrapper(*args: Any, **kwargs: Any) -> Any: # provide public api _fn.get_compiler_config() assert not hasattr(compile_wrapper, "get_compiler_config") compile_wrapper.get_compiler_config = get_compiler_config # type: ignore[attr-defined] +<<<<<<< HEAD if torch._dynamo.config.enable_aot_compile: compile_wrapper.aot_compile = aot_compile # type: ignore[attr-defined] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # If the function is called using torch._dynamo.optimize decorator, we # should prevent any type of skipping. @@ -924,6 +1187,7 @@ def compile_wrapper(*args: Any, **kwargs: Any) -> Any: class OptimizeContext(_TorchDynamoContext): def __init__( self, +<<<<<<< HEAD callback: DynamoCallback, backend_ctx_ctor: Callable[[], contextlib.AbstractContextManager[Any]], first_ctx: bool = False, @@ -940,6 +1204,21 @@ def __init__( hooks: Optional[Hooks] = None, ) -> None: def on_enter() -> None: +======= + callback, + backend_ctx_ctor, + first_ctx=False, + *, + export=False, + dynamic=None, + compiler_config=None, + rebuild_ctx: Optional[ + Callable[[], Union[OptimizeContext, _NullDecorator]] + ] = None, + package=None, + ) -> None: + def on_enter(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) install_generation_tagging_init() super().__init__( @@ -948,13 +1227,19 @@ def on_enter() -> None: backend_ctx_ctor=backend_ctx_ctor, patch_fn=TorchPatcher.patch, first_ctx=first_ctx, +<<<<<<< HEAD fullgraph=fullgraph, error_on_graph_break=error_on_graph_break, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) export=export, dynamic=dynamic, compiler_config=compiler_config, package=package, +<<<<<<< HEAD hooks=hooks, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if config.compiled_autograd: @@ -962,7 +1247,11 @@ def on_enter() -> None: if _dynamic is None: _dynamic = not torch._dynamo.config.assume_static_by_default +<<<<<<< HEAD def call_compiled_autograd() -> functools.partial[Optional[bool]]: +======= + def call_compiled_autograd(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert rebuild_ctx is not None compiler_fn = rebuild_ctx() ctx = torch._dynamo.compiled_autograd._enable( @@ -973,9 +1262,13 @@ def call_compiled_autograd() -> functools.partial[Optional[bool]]: self.enter_exit_hooks.append(call_compiled_autograd) +<<<<<<< HEAD def __reduce__( self, ) -> tuple[type[OptimizeContext], tuple[Any, ...], dict[str, Any]]: +======= + def __reduce__(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ( self.__class__, (self.callback, self._backend_ctx_ctor, self.first_ctx), @@ -990,12 +1283,20 @@ def __reduce__( class RunOnlyContext(_TorchDynamoContext): def __init__(self) -> None: # cudagraph trees relies on generation increment +<<<<<<< HEAD def on_enter() -> None: +======= + def on_enter(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch._dynamo.mutation_guard.GenerationTracker.generation += 1 super().__init__(callback=False, on_enter=on_enter) +<<<<<<< HEAD def __reduce__(self) -> tuple[type[RunOnlyContext], tuple[Any, ...]]: +======= + def __reduce__(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return (self.__class__, ()) @@ -1005,7 +1306,11 @@ def __init__(self, msg: Optional[str] = None, wrapping: bool = True) -> None: self.msg = msg self.wrapping = wrapping +<<<<<<< HEAD def __call__(self, fn: Callable[..., Any]) -> Callable[..., Any]: +======= + def __call__(self, fn): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Earlier this code was in the base class _TorchDynamoContext. But we # moved it here to have better code organization. For disable, we just # want the callback to be None. We don't have to check trace_rules or @@ -1036,7 +1341,11 @@ def __call__(self, fn: Callable[..., Any]) -> Callable[..., Any]: f"A callable function is expected, but {type(fn)} is provided." ) +<<<<<<< HEAD def _fn(*args: Any, **kwargs: Any) -> Any: +======= + def _fn(*args, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) prior = set_eval_frame(None) try: _maybe_set_eval_frame(_callback_from_stance(self.callback)) @@ -1064,11 +1373,16 @@ def _fn(*args: Any, **kwargs: Any) -> Any: return _fn +<<<<<<< HEAD def __reduce__(self) -> tuple[type[DisableContext], tuple[Any, ...]]: +======= + def __reduce__(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return (self.__class__, ()) def _optimize_catch_errors( +<<<<<<< HEAD compile_fn: convert_frame.ConvertFrameProtocol, hooks: Hooks, backend_ctx_ctor: Callable[ @@ -1082,17 +1396,32 @@ def _optimize_catch_errors( rebuild_ctx: Optional[Callable[[], Union[OptimizeContext, _NullDecorator]]] = None, package: Optional[CompilePackage] = None, ) -> OptimizeContext: +======= + compile_fn, + hooks: Hooks, + backend_ctx_ctor=null_context, + export=False, + dynamic=None, + compiler_config=None, + rebuild_ctx=None, + package=None, +): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return OptimizeContext( convert_frame.catch_errors_wrapper(compile_fn, hooks), backend_ctx_ctor=backend_ctx_ctor, first_ctx=True, +<<<<<<< HEAD fullgraph=fullgraph, error_on_graph_break=error_on_graph_break, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) export=export, dynamic=dynamic, compiler_config=compiler_config, rebuild_ctx=rebuild_ctx, package=package, +<<<<<<< HEAD hooks=hooks, ) @@ -1108,22 +1437,41 @@ def get_compiler_fn( elif hasattr(compiler_fn, "compiler_name"): compiler_str = compiler_fn.compiler_name # type: ignore[union-attr] assert isinstance(compiler_str, str) +======= + ) + + +def get_compiler_fn(compiler_fn): + from .repro.after_dynamo import wrap_backend_debug + + if hasattr(compiler_fn, "compiler_name"): + compiler_str = compiler_fn.compiler_name +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif isinstance(compiler_fn, str): compiler_str = compiler_fn else: compiler_str = None +<<<<<<< HEAD compiler_fn = lookup_backend(compiler_fn) # type: ignore[arg-type] +======= + compiler_fn = lookup_backend(compiler_fn) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return wrap_backend_debug(compiler_fn, compiler_str) class _NullDecorator(contextlib.nullcontext): # type: ignore[type-arg] +<<<<<<< HEAD def __call__(self, fn: Callable[..., Any]) -> Callable[..., Any]: +======= + def __call__(self, fn): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert callable(fn), ( f"A callable function is expected, but {type(fn)} is provided." ) return fn +<<<<<<< HEAD # Make dynamo graph to have same input/output spec as user code def argument_names( f_sig: inspect.Signature, args: list[Any], kwargs: dict[str, Any] @@ -1208,6 +1556,9 @@ def signature_to_fullargspec(sig: inspect.Signature) -> inspect.FullArgSpec: def check_if_dynamo_supported() -> None: +======= +def check_if_dynamo_supported(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if sys.version_info >= (3, 14): raise RuntimeError("Python 3.14+ not yet supported for torch.compile") elif sysconfig.get_config_var("Py_GIL_DISABLED") == 1 and sys.version_info < ( @@ -1221,7 +1572,11 @@ def check_if_dynamo_supported() -> None: ) +<<<<<<< HEAD def is_dynamo_supported() -> bool: +======= +def is_dynamo_supported(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: check_if_dynamo_supported() return True @@ -1229,11 +1584,19 @@ def is_dynamo_supported() -> bool: return False +<<<<<<< HEAD def check_if_inductor_supported() -> None: check_if_dynamo_supported() def is_inductor_supported() -> bool: +======= +def check_if_inductor_supported(): + check_if_dynamo_supported() + + +def is_inductor_supported(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: check_if_inductor_supported() return True @@ -1241,15 +1604,24 @@ def is_inductor_supported() -> bool: return False +<<<<<<< HEAD def check_for_incompatible_configs() -> None: +======= +def check_for_incompatible_configs(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Some of the configs should be mutually exclusive assert not (config.suppress_errors and config.fail_on_recompile_limit_hit), ( "Dynamo configs suppress_error and fail_on_recompile_limit_hit can not both be active at the same time." ) +<<<<<<< HEAD def optimize(*args: Any, **kwargs: Any) -> Union[OptimizeContext, _NullDecorator]: def rebuild_ctx() -> Union[OptimizeContext, _NullDecorator]: +======= +def optimize(*args, **kwargs): + def rebuild_ctx(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ca_kwargs_override = config.compiled_autograd_kwargs_override if ca_kwargs_override: # NOTE: The process of translating other `torch.compile` kwargs to `torch._dynamo.optimize` kwargs @@ -1265,6 +1637,7 @@ def rebuild_ctx() -> Union[OptimizeContext, _NullDecorator]: def _optimize( rebuild_ctx: Callable[[], Union[OptimizeContext, _NullDecorator]], +<<<<<<< HEAD backend: Union[str, Callable[..., Any]] = "inductor", *, nopython: bool = False, @@ -1275,6 +1648,17 @@ def _optimize( disable: bool = False, dynamic: Optional[bool] = None, package: Optional[CompilePackage] = None, +======= + backend="inductor", + *, + nopython=False, + guard_export_fn=None, + guard_fail_fn=None, + guard_filter_fn=None, + disable=False, + dynamic=None, + package=None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Union[OptimizeContext, _NullDecorator]: """ The main entrypoint of TorchDynamo. Do graph capture and call @@ -1291,11 +1675,14 @@ def _optimize( - Or, a string backend name in `torch._dynamo.list_backends()` nopython: If True, graph breaks will be errors and there will be a single whole-program graph. +<<<<<<< HEAD error_on_graph_break: If not None, the current `error_on_graph_break` setting is set to the given value. See `torch._dynamo.error_on_graph_break()` for more details on what `error_on_graph_break` means. Unlike `nopython=True` (i.e. `fullgraph=True`), there is no guarantee of a single whole-program graph. If `nopython` is True, `error_on_graph_break` does nothing. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) disable: If True, turn this decorator into a no-op dynamic: If True, upfront compile as dynamic a kernel as possible. If False, disable all dynamic shapes support (always specialize). If None, automatically @@ -1326,7 +1713,11 @@ def toy_example(a, b): ... ): return _NullDecorator() +<<<<<<< HEAD if nopython and not config.debug_force_graph_break_on_leaf_return: +======= + if nopython: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return optimize_assert( backend, dynamic=dynamic, @@ -1341,6 +1732,7 @@ def toy_example(a, b): ... backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context) # The backend function is stashed in the callable returned by +<<<<<<< HEAD # _optimize_catch_errors in the field _torchdynamo_orig_backend. This can # be used by eval_frame.c to insert a guard on the backend. @@ -1362,6 +1754,14 @@ def toy_example(a, b): ... fullgraph=False, error_on_graph_break=error_on_graph_break and not config.debug_force_graph_break_on_leaf_return, +======= + # _optimize_catch_errors in the field _torchdynamo_orig_callable. This can + # be used by eval_frame.c to insert a guard on the backend. + return _optimize_catch_errors( + convert_frame.convert_frame(backend, hooks=hooks, package=package), + hooks, + backend_ctx_ctor, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dynamic=dynamic, compiler_config=( backend.get_compiler_config() @@ -1375,10 +1775,15 @@ def toy_example(a, b): ... # TODO(voz): Consider making "explain" output alongside a run / part of a run @patch("torch._dynamo.symbolic_convert.explain", True) +<<<<<<< HEAD def explain(f: Callable[..., Any], *extra_args: Any, **extra_kwargs: Any) -> Any: from .backends.debugging import ExplainOutput def inner(*args: Any, **kwargs: Any) -> ExplainOutput: +======= +def explain(f, *extra_args, **extra_kwargs): + def inner(*args, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO(voz): Do we want a decorator for this? from . import reset # type: ignore[attr-defined] @@ -1387,12 +1792,21 @@ def inner(*args: Any, **kwargs: Any) -> ExplainOutput: graphs: list[torch.fx.GraphModule] = [] break_reasons: list[Any] = [] op_count: int = 0 +<<<<<<< HEAD ops_per_graph: list[list[Target]] = [] out_guards: list[_guards.Guard] = [] def dynamo_graph_accumulating_compiler( gm: torch.fx.GraphModule, example_inputs: Any ) -> Callable[..., Any]: +======= + ops_per_graph: list[torch.fx.Node] = [] + out_guards: list[_guards.Guard] = [] + + def dynamo_graph_accumulating_compiler( + gm: torch.fx.GraphModule, example_inputs + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .backends.debugging import _explain_graph_detail nonlocal graphs @@ -1406,7 +1820,11 @@ def dynamo_graph_accumulating_compiler( return gm.forward +<<<<<<< HEAD def guard_export_print(guards: Iterable[_guards.Guard]) -> None: +======= + def guard_export_print(guards): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) nonlocal out_guards out_guards.extend(guards) @@ -1424,6 +1842,10 @@ def guard_export_print(guards: Iterable[_guards.Guard]) -> None: # TODO(voz): Do we want a decorator for this? reset() +<<<<<<< HEAD +======= + from .backends.debugging import ExplainOutput +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ExplainOutput( graphs, @@ -1453,9 +1875,15 @@ class FlattenInputOutputSignature(torch.fx.Transformer): def __init__( self, m: torch.fx.GraphModule, +<<<<<<< HEAD flat_args: list[Any], matched_input_elements_positions: list[int], flat_results: Sequence[Any], +======= + flat_args: tuple[Any], + matched_input_elements_positions: list[int], + flat_results: list[Any], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) matched_output_elements_positions: list[int], example_fake_inputs: list[torch.Tensor], flat_args_dynamic_dims: list[set[int]], @@ -1503,9 +1931,13 @@ def __init__( self.matched_output_elements_positions = matched_output_elements_positions self.flat_results = flat_results +<<<<<<< HEAD def placeholder( self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] ) -> Any: +======= + def placeholder(self, target, args, kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) arg = next(self.old_args_gen) if "val" in self.current_node.meta: arg.node.meta["val"] = self.current_node.meta["val"] @@ -1520,11 +1952,17 @@ def placeholder( ] return arg +<<<<<<< HEAD def output( self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] ) -> Any: dynamo_result_flat = args[0] lookup = [*dynamo_result_flat, *self.new_args] # type: ignore[misc] +======= + def output(self, target, args, kwargs): + dynamo_result_flat = args[0] + lookup = [*dynamo_result_flat, *self.new_args] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) new_results_flat = [] for i in range(len(self.flat_results)): if self.matched_output_elements_positions[i] is not None: @@ -1537,7 +1975,11 @@ def output( new_results_flat.append(const_val) return super().output(target, (new_results_flat,), {}) +<<<<<<< HEAD def run_node(self, n: Node) -> Any: +======= + def run_node(self, n): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.current_node = n result_proxy = super().run_node(n) if "val" in self.current_node.meta: @@ -1557,7 +1999,11 @@ def run_node(self, n: Node) -> Any: ) return result_proxy +<<<<<<< HEAD def transform(self) -> torch.fx.GraphModule: +======= + def transform(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) result_gm = super().transform() if "dynamo_flat_name_to_original_fqn" in self.module.meta: # type: ignore[operator] result_gm.meta["dynamo_flat_name_to_original_fqn"] = self.module.meta[ # type: ignore[index] @@ -1576,17 +2022,26 @@ class ExportResult(NamedTuple): # NOTE: this function only supports graphs created by Dynamo's OutputGraph module +<<<<<<< HEAD def check_signature_rewritable(graph: torch.fx.GraphModule) -> None: +======= +def check_signature_rewritable(graph): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) input_errors = [] for node in graph.graph.find_nodes(op="placeholder"): # set in OutputGraph._call_user_compiler assert hasattr(node, "_dynamo_source") assert hasattr(graph, "_source_to_user_stacks") +<<<<<<< HEAD # NOTE: We can safely ignore these type warnings if and only if # the function is made from OutputGraph (checked in the assertions) source = node._dynamo_source # type: ignore[attr-defined] user_stacks = graph._source_to_user_stacks.get(source) # type: ignore[operator, union-attr] +======= + source = node._dynamo_source + user_stacks = graph._source_to_user_stacks.get(source) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if user_stacks is None: continue assert len(user_stacks) > 0 @@ -1623,6 +2078,7 @@ def check_signature_rewritable(graph: torch.fx.GraphModule) -> None: def rewrite_signature( +<<<<<<< HEAD f_sig: inspect.Signature, graph: torch.fx.GraphModule, fake_mode: Optional[fake_tensor.FakeTensorMode], @@ -1639,6 +2095,22 @@ def rewrite_signature( def check_user_input_output( flat_values: list[Any], error_type: UserErrorType ) -> None: +======= + f_sig, + graph, + fake_mode, + flat_args, + in_spec, + example_fake_inputs, + graph_captured_input, + graph_captured_output, + dynamo_traced_result, + flat_args_dynamic_dims, +): + orig_args, orig_kwargs = pytree.tree_unflatten(flat_args, in_spec) + + def check_user_input_output(flat_values, error_type): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) supported_types = [ torch.Tensor, torch.SymInt, @@ -1648,7 +2120,11 @@ def check_user_input_output( _IntWrapper, ] + list(common_constant_types) +<<<<<<< HEAD def is_supported_type(val: Any) -> bool: +======= + def is_supported_type(val): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return isinstance(val, tuple(supported_types)) value_type = "input" if error_type == UserErrorType.INVALID_INPUT else "output" @@ -1674,7 +2150,11 @@ def is_supported_type(val: Any) -> bool: flat_results_traced, out_spec_traced = pytree.tree_flatten(dynamo_traced_result) check_user_input_output(flat_results_traced, UserErrorType.INVALID_OUTPUT) +<<<<<<< HEAD def check_optional_input_and_error(f_sig: inspect.Signature) -> None: +======= + def check_optional_input_and_error(f_sig: inspect.Signature): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Check if function has optional input. for name, param in f_sig.parameters.items(): if param.default is not inspect.Parameter.empty: @@ -1690,9 +2170,13 @@ def check_optional_input_and_error(f_sig: inspect.Signature) -> None: case_name="optional_input", ) +<<<<<<< HEAD def produce_matching( debug_type: str, sources: Iterable[Any], candidates: Iterable[Any] ) -> list[Optional[int]]: +======= + def produce_matching(debug_type, sources, candidates): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) matched_elements_positions: list[Optional[int]] = [] dict_of_source_vals = {} for i, val in enumerate(sources): @@ -1725,14 +2209,106 @@ def produce_matching( new_graph = FlattenInputOutputSignature( graph, flat_args, +<<<<<<< HEAD matched_input_elements_positions, # type: ignore[arg-type] flat_results_traced, matched_output_elements_positions, # type: ignore[arg-type] +======= + matched_input_elements_positions, + flat_results_traced, + matched_output_elements_positions, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) example_fake_inputs, flat_args_dynamic_dims, fake_mode, ).transform() +<<<<<<< HEAD +======= + # Make dynamo graph to have same input/output spec as user code + def argument_names(f_sig, args, kwargs) -> list[str]: + def signature_to_fullargspec(sig: inspect.Signature): + # Get a list of Parameter objects from the Signature object + params = list(sig.parameters.values()) + # Separate positional arguments, keyword-only arguments and varargs/varkw + args = [ + p.name + for p in params + if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + ] + kwonlyargs = [ + p.name for p in params if p.kind == inspect.Parameter.KEYWORD_ONLY + ] + varargs = next( + (p.name for p in params if p.kind == inspect.Parameter.VAR_POSITIONAL), + None, + ) + varkw = next( + (p.name for p in params if p.kind == inspect.Parameter.VAR_KEYWORD), + None, + ) + # Get default values for positional arguments and keyword-only arguments + defaults = tuple( + p.default + for p in params + if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + and p.default is not inspect.Parameter.empty + ) + kwonlydefaults = { + p.name: p.default + for p in params + if p.kind == inspect.Parameter.KEYWORD_ONLY + and p.default is not inspect.Parameter.empty + } + # Get annotations for parameters and return value + annotations = {} + if sig.return_annotation: + annotations = {"return": sig.return_annotation} + for parameter in params: + annotations[parameter.name] = parameter.annotation + # Return a FullArgSpec object with the extracted attributes + return inspect.FullArgSpec( + args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations + ) + + fullargspec = signature_to_fullargspec(f_sig) + + # 1. Map `args` 1-to-1 to positional arguments in original signature. + input_strs = fullargspec.args[: len(args)] + + if len(args) > len(fullargspec.args): + # 2. If there are more arguments left in `args`, they map to varargs in original + # signature. Assign names as {varargs}_0, {varargs}_1, ... + assert fullargspec.varargs is not None, "More arguments than expected" + input_strs += [ + f"{fullargspec.varargs}_{i}" + for i in range(0, len(args) - len(input_strs)) + ] + elif len(args) < len(fullargspec.args): + # 3. If there are fewer arguments in `args` than `fullargspec.args`, + # it implies these are arguments either with default values, or provided in + # `kwargs`. The former can be safely ignored. Because Dynamo.export does not + # export them as part of the function signature. The latter will be handled + # in the next step. + for unprovided_arg in fullargspec.args[ + len(args) : -len(fullargspec.defaults or []) + ]: + assert unprovided_arg in kwargs, f"Missing argument {unprovided_arg}" + + # 4. Keyword arguments provided in `kwargs`. + input_strs += list(kwargs.keys()) + + # 5. Keyword-only arguments with default values if not provided are not exported + # as part of the function signature. + for kwonly_arg in fullargspec.kwonlyargs: + kwonlydefaults = fullargspec.kwonlydefaults or {} + assert kwonly_arg in kwargs or kwonly_arg in kwonlydefaults, ( + f"Missing keyword only argument {kwonly_arg}" + ) + + return input_strs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) new_graph.graph._codegen = _PyTreeCodeGen( _PyTreeInfo( argument_names(f_sig, orig_args, orig_kwargs), @@ -1746,7 +2322,11 @@ def produce_matching( def export( f: Callable[..., Any], +<<<<<<< HEAD *extra_args: Any, +======= + *extra_args, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) aten_graph: bool = False, pre_dispatch: bool = False, decomposition_table: Optional[ @@ -1759,9 +2339,16 @@ def export( same_signature: bool = True, disable_constraint_solver: bool = False, prefer_deferred_runtime_asserts_over_guards: bool = False, +<<<<<<< HEAD _log_export_usage: bool = True, constraints: Optional[list[Constraint]] = None, **extra_kwargs: Any, +======= + allow_complex_guards_as_runtime_asserts: bool = False, + _log_export_usage: bool = True, + constraints: Optional[list[Constraint]] = None, + **extra_kwargs, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Callable[..., ExportResult]: """ Export an input function f to a format that can be executed outside of PyTorch using the FX graph. @@ -1816,9 +2403,12 @@ def export( Note - this headerdoc was authored by ChatGPT, with slight modifications by the author. """ +<<<<<<< HEAD if config.debug_force_graph_break_on_leaf_return: raise unittest.SkipTest("Cannot force graph break on export") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if _log_export_usage: log_export_usage(event="export.private_api", flags={"_dynamo"}) @@ -1828,7 +2418,11 @@ def export( _assume_static_by_default = assume_static_by_default _constraints = constraints +<<<<<<< HEAD def inner(*args: Any, **kwargs: Any) -> ExportResult: +======= + def inner(*args, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not _constraints: combined_args = _combine_args(_f, args, kwargs) constraints = _process_dynamic_shapes(combined_args, dynamic_shapes) @@ -1848,7 +2442,11 @@ def inner(*args: Any, **kwargs: Any) -> ExportResult: assert aten_graph, "pre_dispatch=True can only be used when aten_graph=True" f = innermost_fn(f) call_to_inspect = f.forward if isinstance(f, torch.nn.Module) else f +<<<<<<< HEAD original_signature = inspect.signature(call_to_inspect) # type: ignore[arg-type] +======= + original_signature = inspect.signature(call_to_inspect) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) graph = None out_guards = None graph_captured_input = None @@ -1856,18 +2454,30 @@ def inner(*args: Any, **kwargs: Any) -> ExportResult: fake_mode = None result_traced = None +<<<<<<< HEAD def guard_export_print(guards: _guards.GuardsSet) -> None: +======= + def guard_export_print(guards: _guards.GuardsSet): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) nonlocal out_guards assert out_guards is None, ( "whole graph export entails exactly one guard export" ) out_guards = guards +<<<<<<< HEAD example_inputs: list[Any] = [] def dynamo_normalization_capturing_compiler( gm: torch.fx.GraphModule, inner_example_inputs: list[Any] ) -> Callable[..., Any]: +======= + example_inputs = [] + + def dynamo_normalization_capturing_compiler( + gm: torch.fx.GraphModule, inner_example_inputs + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) nonlocal graph assert graph is None, ( "Tried to emit a second graph during export. Tracing through 'f' must produce a single graph." @@ -1883,7 +2493,11 @@ def dynamo_normalization_capturing_compiler( fake_mode = _guards.detect_fake_mode() example_inputs = inner_example_inputs +<<<<<<< HEAD def result_capturing_wrapper(*graph_inputs: Any) -> Any: +======= + def result_capturing_wrapper(*graph_inputs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) nonlocal graph_captured_result nonlocal graph_captured_input @@ -1907,7 +2521,11 @@ def result_capturing_wrapper(*graph_inputs: Any) -> Any: ignore_fresh_unbacked = null_context() assert ambient_fake_mode is not None if shape_env := ambient_fake_mode.shape_env: +<<<<<<< HEAD ignore_fresh_unbacked = shape_env.ignore_fresh_unbacked_symbols() # type: ignore[assignment] +======= + ignore_fresh_unbacked = shape_env.ignore_fresh_unbacked_symbols() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with ( ambient_fake_mode, @@ -1925,6 +2543,7 @@ def result_capturing_wrapper(*graph_inputs: Any) -> Any: value, static_shapes=True ) +<<<<<<< HEAD from torch._export.non_strict_utils import ( key_path_to_source, KeyPath, @@ -1933,6 +2552,9 @@ def result_capturing_wrapper(*graph_inputs: Any) -> Any: def fakify_with_ambient( path: KeyPath, t: Union[torch.Tensor, _IntWrapper, Any] ) -> Any: +======= + def fakify_with_ambient(path, t): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(t, torch.Tensor): return ambient_fake_mode.from_tensor(t, static_shapes=True) elif isinstance(t, _IntWrapper): @@ -1945,6 +2567,13 @@ def fakify_with_ambient( _DimHintType.AUTO, ) ): # type: ignore[union-attr] +<<<<<<< HEAD +======= + from torch._export.non_strict_utils import ( + key_path_to_source, + ) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) source = key_path_to_source(path) symint = ambient_fake_mode.shape_env.create_unspecified_symint_and_symbol( # type: ignore[union-attr] t.val, source, DimDynamic.DYNAMIC @@ -1959,9 +2588,13 @@ def fakify_with_ambient( fakify_with_ambient, graph_inputs ) graph_captured_result = torch.func.functional_call( +<<<<<<< HEAD graph, fake_params_buffers, # type: ignore[arg-type] fake_graph_inputs, # type: ignore[arg-type] +======= + graph, fake_params_buffers, fake_graph_inputs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return graph_captured_result @@ -1985,6 +2618,10 @@ def fakify_with_ambient( capture_dynamic_output_shape_ops=True, capture_scalar_outputs=True, prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, +<<<<<<< HEAD +======= + allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), _compiling_state_context(), ): @@ -2103,7 +2740,11 @@ def fakify_with_ambient( if aten_graph: # Running graph with interpreter is needed for propagating the stack_trace +<<<<<<< HEAD def graph_with_interpreter(*args: Any) -> Any: +======= + def graph_with_interpreter(*args): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with torch.fx.traceback.preserve_node_meta(): return torch.fx.Interpreter(graph).run(*args) # type: ignore[arg-type] @@ -2153,12 +2794,20 @@ def graph_with_interpreter(*args: Any) -> Any: flat_args, in_spec, example_fake_inputs, +<<<<<<< HEAD graph_captured_input, # type: ignore[arg-type] +======= + graph_captured_input, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) graph_captured_result, result_traced, # type: ignore[possibly-undefined] flat_args_dynamic_dims, ) +<<<<<<< HEAD return ExportResult(graph, out_guards) +======= + return ExportResult(graph, out_guards) # type: ignore[arg-type] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if extra_args or extra_kwargs: warnings.warn( @@ -2168,19 +2817,31 @@ def graph_with_interpreter(*args: Any) -> Any: FutureWarning, stacklevel=2, ) +<<<<<<< HEAD return inner(*extra_args, **extra_kwargs) # type: ignore[return-value] +======= + return inner(*extra_args, **extra_kwargs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: return inner +<<<<<<< HEAD def optimize_assert(*args: Any, **kwargs: Any) -> OptimizeContext: +======= +def optimize_assert(*args, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if "rebuild_ctx" in kwargs and kwargs["rebuild_ctx"] is not None: # called from optimize rebuild_ctx = kwargs["rebuild_ctx"] del kwargs["rebuild_ctx"] else: +<<<<<<< HEAD def rebuild_ctx() -> OptimizeContext: +======= + def rebuild_ctx(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return optimize_assert(*args, **kwargs) return _optimize_assert(rebuild_ctx, *args, **kwargs) @@ -2188,6 +2849,7 @@ def rebuild_ctx() -> OptimizeContext: def _optimize_assert( rebuild_ctx: Callable[[], OptimizeContext], +<<<<<<< HEAD backend: Union[str, Callable[..., Any], None], *, hooks: Hooks = Hooks(None, None, None), @@ -2203,12 +2865,25 @@ def _optimize_assert( Used for fullgraph=True and export, since we must always error on graph breaks and ignore symbolic_convert.error_on_graph_break. Can also be used for testing. +======= + backend, + *, + hooks=Hooks(None, None, None), + export=False, + export_constraints=None, + dynamic=None, + package=None, +): + """ + The same as `torch._dynamo.optimize(backend, nopython=True)` +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ backend = get_compiler_fn(backend) # Find if backend has any extra context manager backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context) +<<<<<<< HEAD if config.caching_precompile and package is None: # Create an uninitialized package that will be set/filled by # _OptimizeContext.__call__ @@ -2219,6 +2894,8 @@ def _optimize_assert( package = CompilePackage(fn=None, dynamo=None, ignore_inlined_sources=False) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return _optimize_catch_errors( convert_frame.convert_frame_assert( backend, @@ -2228,7 +2905,10 @@ def _optimize_assert( ), hooks, backend_ctx_ctor, +<<<<<<< HEAD fullgraph=True, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) export=export, dynamic=dynamic, rebuild_ctx=rebuild_ctx, @@ -2239,7 +2919,11 @@ def _optimize_assert( class TorchPatcher: @staticmethod @functools.cache +<<<<<<< HEAD def patch() -> None: +======= + def patch(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # A better way to disable the following would be decorate the source # functions with @torch._disable_dynamo. However, this causes issues # with torch.deploy internally. @@ -2332,6 +3016,7 @@ def patch() -> None: ) @staticmethod +<<<<<<< HEAD def suppress_torch_distributed_warnings( fn: Callable[..., Any], ) -> Callable[..., Any]: @@ -2340,11 +3025,23 @@ def inner_fn(*args: Any, **kwargs: Any) -> Any: torch._logging._internal.user_warning_filter ): return fn(*args, **kwargs) +======= + def suppress_torch_distributed_warnings(fn): + def inner_fn(*args, **kwargs): + warnings.filterwarnings( + "ignore", category=UserWarning, module="torch.distributed" + ) + return fn(*args, **kwargs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return inner_fn +<<<<<<< HEAD def skip_code(code: types.CodeType) -> None: +======= +def skip_code(code: types.CodeType): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) set_code_exec_strategy( code, FrameExecStrategy(FrameAction.SKIP, FrameAction.DEFAULT) ) diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index e69b768ba3746..480b709f9f293 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -26,20 +26,29 @@ - Debugging utilities for error reporting """ +<<<<<<< HEAD import json +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import logging import os import re import textwrap import typing from enum import auto, Enum +<<<<<<< HEAD from functools import lru_cache from pathlib import Path +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from traceback import extract_stack, format_exc, format_list, StackSummary from typing import Any, NoReturn, Optional, TYPE_CHECKING import torch._guards +<<<<<<< HEAD from torch._utils_internal import get_file_path_2 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from . import config from .utils import counters @@ -50,7 +59,10 @@ from torch._guards import CompileId +<<<<<<< HEAD from .output_graph import DynamoTracerOutput +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .symbolic_convert import InstructionTranslatorBase from .types import DynamoFrameType @@ -68,6 +80,7 @@ def exportdb_error_message(case_name: str) -> str: class TorchDynamoException(RuntimeError): +<<<<<<< HEAD def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self._torch_dynamo_tracer_output: Optional[DynamoTracerOutput] = None @@ -78,6 +91,12 @@ class InternalTorchDynamoError(TorchDynamoException): class ResumePrologueTracingError(TorchDynamoException): +======= + pass + + +class InternalTorchDynamoError(TorchDynamoException): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pass @@ -268,6 +287,7 @@ class UnsafeScriptObjectError(TorchDynamoException): class UncapturedHigherOrderOpError(TorchDynamoException): +<<<<<<< HEAD def __init__(self, msg: str, real_stack: Optional[StackSummary] = None) -> None: super().__init__(msg) self.msg = msg @@ -276,6 +296,9 @@ def __init__(self, msg: str, real_stack: Optional[StackSummary] = None) -> None: if real_stack is not None else torch._guards.TracingContext.extract_stack() ) +======= + pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class IncorrectUsage(Exception): @@ -366,7 +389,11 @@ class ObservedTypeError(ObservedException): def get_dynamo_observed_exception(exc_type: type[Exception]) -> type[ObservedException]: if exc_type not in observed_exception_map: name = getattr(exc_type, "__name__", str(exc_type)) +<<<<<<< HEAD observed_exception_map[exc_type] = type( # type: ignore[assignment] +======= + observed_exception_map[exc_type] = type( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f"Observed{name}Error", (ObservedException,), {} ) return observed_exception_map[exc_type] @@ -384,8 +411,13 @@ def raise_observed_exception( # CPython here raises an exception. Since there is no python code, we have to manually setup the exception # stack and raise the exception. exception_vt = BuiltinVariable(exc_type).call_function(tx, args or [], kwargs or {}) # type: ignore[arg-type] +<<<<<<< HEAD tx.exn_vt_stack.set_current_exception(exception_vt) # type: ignore[arg-type] raise get_dynamo_observed_exception(exc_type) +======= + tx.exn_vt_stack.set_current_exception(exception_vt) + raise observed_exception_map[exc_type] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def handle_observed_exception(tx: Any) -> None: @@ -512,6 +544,7 @@ def format_graph_break_message( return msg +<<<<<<< HEAD @lru_cache(maxsize=1) def _load_gb_type_to_gb_id_map() -> dict[str, Any]: """ @@ -562,6 +595,8 @@ def get_gbid_documentation_link(gb_type: str) -> Optional[str]: return None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO replace old unimplemented later def unimplemented_v2( gb_type: str, @@ -584,10 +619,17 @@ def unimplemented_v2( msg = format_graph_break_message(gb_type, context, explanation, hints) +<<<<<<< HEAD documentation_link = get_gbid_documentation_link(gb_type) if documentation_link: msg += f"\n For more details about this graph break, please visit: {documentation_link}" +======= + # Temporarily disabling the generation of the weblinks in error message + + # documentation_link = get_gbid_documentation_link(gb_type) + # msg += f"\n For more details about this graph break, please visit: {documentation_link}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if log_warning: log.warning(msg) @@ -596,6 +638,14 @@ def unimplemented_v2( raise Unsupported(msg) +<<<<<<< HEAD +======= +def warning(msg: str) -> None: + counters["warnings"][msg] += 1 + assert msg != os.environ.get("BREAK", False) + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # KeyError has special handling for its args # see https://github.com/python/cpython/blob/3.11/Objects/exceptions.c#L2534 for details class KeyErrorMsg: diff --git a/torch/_dynamo/external_utils.py b/torch/_dynamo/external_utils.py index 2ff3f6752f568..b49edb8e07016 100644 --- a/torch/_dynamo/external_utils.py +++ b/torch/_dynamo/external_utils.py @@ -1,3 +1,8 @@ +<<<<<<< HEAD +======= +# This module contains functions that *will be allowed* by dynamo + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ This module contains utility functions that are explicitly allowed to be called during TorchDynamo compilation. These functions are carefully vetted to ensure they work @@ -198,12 +203,19 @@ def nonrecursive_disable_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: return nonrecursive_disable_wrapper +<<<<<<< HEAD def wrap_dunder_call_ctx_manager(self: Any, func: Callable[_P, _R]) -> Callable[_P, _R]: """ Apply self as a ctx manager around a call to func """ # NOTE: do not functools.wraps(func) because we don't ever want this frame to be skipped! +======= +def _dynamo_config_patch_proxy_dunder_call( + self: Any, func: Callable[_P, _R] +) -> Callable[_P, _R]: + @functools.wraps(func) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R: with self: return func(*args, **kwargs) @@ -227,6 +239,7 @@ def call_accumulate_grad( [grad], variable, variable.grad, has_post_hooks ) variable.grad = updated_grad[0] +<<<<<<< HEAD def wrap_inline_with_error_on_graph_break( @@ -278,3 +291,5 @@ def insert_const_values_with_mask( out.append(tup[idx]) idx += 1 return tuple(out) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_dynamo/graph_break_registry.json b/torch/_dynamo/graph_break_registry.json index 28fd02294ad3c..c2ffa27130161 100644 --- a/torch/_dynamo/graph_break_registry.json +++ b/torch/_dynamo/graph_break_registry.json @@ -168,7 +168,11 @@ { "Gb_type": "Attempted to wrap torch._higher_order_ops.invoke_subgraph", "Context": "", +<<<<<<< HEAD "Explanation": "Directly using invoke_subgraph is not supported. Use nested_compile_region", +======= + "Explanation": "Directly using invoke_subgraph is not supported. Use mark_compile_region", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "Hints": [] } ], @@ -222,7 +226,11 @@ { "Gb_type": "Builtin `operator.*` comparison with constant `self` failed", "Context": "call_method {self} {name} {args} {kwargs}", +<<<<<<< HEAD "Explanation": "\"Failed to compare {self} with {other}, \" + f\"because {other} is not a Python constant or its mutation check fails.\"", +======= + "Explanation": "\"Failed to compare {self} with {other}, because {other} is not a Python constant or its mutation check fails.\"", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "Hints": [] } ], @@ -391,7 +399,10 @@ "Context": "context", "Explanation": "Higher order ops do not support aliasing. Found in {source_target.name()}", "Hints": [ +<<<<<<< HEAD "Replace `return input` with `return input.clone()` to avoid aliasing.", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "Consider using the debug context to change user code to avoid aliasing.", "Please open an issue." ] @@ -851,7 +862,11 @@ "GB0088": [ { "Gb_type": "Observed exception", +<<<<<<< HEAD "Context": "raised exception {curr_exc.python_type_name()}({curr_exc.args})", +======= + "Context": "str(raised_exception)", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "Explanation": "observed_exn_gb_explanation", "Hints": [ "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." @@ -1660,6 +1675,7 @@ ], "GB0170": [ { +<<<<<<< HEAD "Gb_type": "Data-dependent branching", "Context": "attempted to jump with {value}", "Explanation": "_explanation", @@ -1675,6 +1691,8 @@ "Hints": [] }, { +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "Gb_type": "_gb_type", "Context": "attempted to jump with {value}", "Explanation": "_explanation", @@ -2153,6 +2171,7 @@ "Explanation": "Dynamo does not support tracing builtin index() on a Tensor", "Hints": [] } +<<<<<<< HEAD ], "GB0219": [ { @@ -2718,5 +2737,7 @@ "Explanation": "Dyanmo does not support tracing mutations on a class when its __dict__ is materialized", "Hints": [] } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] } diff --git a/torch/_dynamo/graph_deduplication.py b/torch/_dynamo/graph_deduplication.py index be2b51a7abdf7..57b3cede5e8f3 100644 --- a/torch/_dynamo/graph_deduplication.py +++ b/torch/_dynamo/graph_deduplication.py @@ -9,7 +9,11 @@ import logging import operator +<<<<<<< HEAD from collections import defaultdict, deque +======= +from collections import defaultdict +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from collections.abc import Generator, Iterable from typing import Optional @@ -80,8 +84,11 @@ def apply_graph_deduplication(output_graph) -> dict[str, torch.fx.GraphModule]: ( subgraph, external_node_usages, +<<<<<<< HEAD node_usage_to_tuple_elems, ind_to_tuple_spec, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) = _create_subgraph(region, inds_with_external_users) # Ignore regions with no args for now, could they possibly be evaluated at compile time? @@ -102,8 +109,11 @@ def apply_graph_deduplication(output_graph) -> dict[str, torch.fx.GraphModule]: region, get_subgraph_node, external_node_usages, +<<<<<<< HEAD node_usage_to_tuple_elems, ind_to_tuple_spec, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inds_with_external_users, subgraph_name, node_to_additional_deps, @@ -126,18 +136,26 @@ def _replace_region_with_subgraph( region: Region, get_subgraph_node: Node, external_node_usages: Iterable[OrderedSet[UsageIndex]], +<<<<<<< HEAD node_usage_to_tuple_elems: dict[UsageIndex, OrderedSet[int]], ind_to_tuple_spec: dict[int, dict[tuple[int, ...], int]], +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inds_with_external_users: list[int], subgraph_name: str, node_to_additional_deps: dict[Node, OrderedSet[Node]], node_to_mutated_arg_positions: dict[Node, OrderedSet[int]], ) -> None: sub_args = [] +<<<<<<< HEAD flattened_getitem_nodes: OrderedSet[Node] = OrderedSet() for usages in external_node_usages: usage = next(iter(usages)) node_ind, usage_ind = usage +======= + for usages in external_node_usages: + node_ind, usage_ind = next(iter(usages)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) node = region[node_ind] flattened_args_kwargs = _get_flat_args(node, {}) for user_ind, node_usage_ind in usages: @@ -148,19 +166,27 @@ def _replace_region_with_subgraph( "NYI: Failed to substitute region %s due to mutation", region ) return +<<<<<<< HEAD if usage in node_usage_to_tuple_elems: tuple_elems = [region[i] for i in node_usage_to_tuple_elems[usage]] flattened_getitem_nodes.update(tuple_elems) sub_args.extend(tuple_elems) else: sub_args.append(flattened_args_kwargs[usage_ind]) +======= + sub_args.append(flattened_args_kwargs[usage_ind]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Input/Output aliasing not supported in HOPs today # Note: we should use the nodes in the original graph (the region here) # because we use the original traced example values for this check +<<<<<<< HEAD if _has_aliasing( region, sub_args, inds_with_external_users, flattened_getitem_nodes ): +======= + if _has_aliasing(region, sub_args, inds_with_external_users): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return invoke_args = (get_subgraph_node, subgraph_name, *sub_args) @@ -171,6 +197,7 @@ def _replace_region_with_subgraph( invoke_args, # type: ignore[arg-type] {}, ) +<<<<<<< HEAD ind = 0 flattened_output_nodes: OrderedSet[Node] = OrderedSet() @@ -200,6 +227,18 @@ def _replace_region_with_subgraph( if node not in flattened_output_nodes: graph.erase_node(node) +======= + for ind, external_user_ind in enumerate(inds_with_external_users): + node = region[external_user_ind] + subgraph_output = graph.create_node( + "call_function", operator.getitem, (invoke_subgraph_node, ind), {} + ) + node.replace_all_uses_with(subgraph_output, propagate_meta=True) + + # Erase in reverse topological order + for node in reversed(region): + graph.erase_node(node) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Remove any nodes with additional deps # This is safe; we've guaranteed that there is # no input mutation, so all additional deps @@ -254,6 +293,7 @@ def _get_inds_with_external_users(region: Region, inds_unique: set[int]) -> None inds_unique.add(ind) +<<<<<<< HEAD def _create_subgraph( region: Region, inds_with_external_users: list[int], @@ -291,6 +331,17 @@ def _create_subgraph( placeholder = subgraph.placeholder(f"subgraph_input_{node.name}") region_to_subgraph_node[node] = placeholder +======= +def _copy_nodes_and_remap_inputs( + subgraph: torch.fx.Graph, region: Region +) -> list[OrderedSet[UsageIndex]]: + external_input_to_usages = _get_external_inputs(region) + external_node_usages = list[OrderedSet[UsageIndex]]() + region_to_subgraph_node = {} + for node, usage_indices in external_input_to_usages.items(): + placeholder = subgraph.placeholder(f"subgraph_input_{node.name}") + region_to_subgraph_node[node] = placeholder +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) external_node_usages.append(usage_indices) def map_arg(node: Node) -> Node: @@ -299,6 +350,7 @@ def map_arg(node: Node) -> Node: else: return node +<<<<<<< HEAD def copy_to_subgraph(node: Node) -> Node: subgraph_node = subgraph.node_copy(node, lambda old: map_arg(old)) region_to_subgraph_node[node] = subgraph_node @@ -322,6 +374,31 @@ def copy_to_subgraph(node: Node) -> Node: subgraph.output(tuple(output_list)) return subgraph, external_node_usages, node_usage_to_tuple_elems, ind_to_tuple_spec +======= + for node in region: + subgraph_node = subgraph.node_copy(node, lambda old: map_arg(old)) + region_to_subgraph_node[node] = subgraph_node + + return external_node_usages + + +def _create_subgraph_outputs( + subgraph: torch.fx.Graph, inds_to_output: list[int] +) -> None: + node_list = [n for n in subgraph.nodes if n.op not in ("placeholder", "output")] + out_tup = tuple(node_list[ind] for ind in inds_to_output) + subgraph.output(out_tup) + + +def _create_subgraph( + region: Region, + inds_with_external_users: list[int], +) -> tuple[torch.fx.Graph, list[OrderedSet[UsageIndex]]]: + subgraph: torch.fx.Graph = torch.fx.Graph() + external_node_usages = _copy_nodes_and_remap_inputs(subgraph, region) + _create_subgraph_outputs(subgraph, inds_with_external_users) + return subgraph, external_node_usages +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _stable_topological_sort( @@ -446,6 +523,7 @@ def _add_mutation_dependencies( def _has_aliasing( +<<<<<<< HEAD region: Region, inputs: list[Node], inds_with_external_users: list[int], @@ -455,6 +533,13 @@ def _has_aliasing( for node in inputs: if node in flattened_getitem_nodes: continue +======= + region: Region, inputs: list[Node], inds_with_external_users: list[int] +) -> bool: + input_storages: dict[StorageWeakRef, Node] = dict() + + for node in inputs: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) example_value = node.meta["example_value"] if isinstance(example_value, torch.Tensor): storage = StorageWeakRef(example_value._typed_storage()) @@ -468,11 +553,18 @@ def _has_aliasing( ) return True input_storages[storage] = node +<<<<<<< HEAD output_storages: dict[StorageWeakRef, Node] = dict() for i in inds_with_external_users: out_node = region[i] if out_node in flattened_getitem_nodes: continue +======= + + output_storages: dict[StorageWeakRef, Node] = dict() + for i in inds_with_external_users: + out_node = region[i] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if out_node: example_value = out_node.meta["example_value"] assert not isinstance(example_value, list) @@ -488,6 +580,10 @@ def _has_aliasing( ) return True output_storages[storage] = out_node +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) intersected_storages = input_storages.keys() & output_storages.keys() if len(intersected_storages) > 0: # input-output aliasing @@ -501,6 +597,7 @@ def _has_aliasing( aliased, ) return True +<<<<<<< HEAD return False @@ -589,3 +686,7 @@ def _replace_tuple_outputs( graph.erase_node(node) erased_nodes.add(node) return erased_nodes +======= + + return False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_dynamo/graph_region_tracker.py b/torch/_dynamo/graph_region_tracker.py index c1463d290bc9c..4ca3a2acba8b3 100644 --- a/torch/_dynamo/graph_region_tracker.py +++ b/torch/_dynamo/graph_region_tracker.py @@ -13,8 +13,11 @@ optimization operations. """ +<<<<<<< HEAD from __future__ import annotations +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import copyreg import io import logging @@ -125,6 +128,7 @@ def _normalize_args( return (sorted_keys, tuple(_extract_args(arg) for arg in all_args)) +<<<<<<< HEAD def _sort_with_ref_region( index_to_rank: dict[int, int], regions: list[list[Any]] ) -> None: @@ -137,6 +141,8 @@ def _sort_with_ref_region( region[:] = [region[i] for i in sorted_indices] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_global_state_key() -> GlobalStateKey: return ( torch.is_grad_enabled(), @@ -165,7 +171,11 @@ def __init__(self, origin: Node) -> None: self._queue: deque[Optional[Node]] = deque() @staticmethod +<<<<<<< HEAD def create(origin: Node) -> BackwardBfsArgIter: +======= + def create(origin: Node) -> "BackwardBfsArgIter": +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) it = BackwardBfsArgIter(origin) it.add_children(origin) # pop the origin node, since it is the origin of @@ -240,13 +250,18 @@ def _is_identical(self, n0: Node, n1: Node) -> bool: and n0 is not n1 ) +<<<<<<< HEAD def track_node(self, tx: InstructionTranslatorBase, node: Node) -> None: +======= + def track_node(self, tx: "InstructionTranslatorBase", node: Node) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ The main entry point for tracking a node. This function will hash the node argument and group nodes with the same hash together. It updates the hash_to_duplicates and node_to_duplicates dictionaries to track the new node. """ try: +<<<<<<< HEAD if ( node not in self.node_to_duplicates ): # don't allow nodes to be added twice @@ -257,6 +272,15 @@ def track_node(self, tx: InstructionTranslatorBase, node: Node) -> None: ] duplicates.append(node) self.node_to_duplicates[node] = duplicates +======= + duplicates = self.hash_to_duplicates[ + self._hash_node( + tx.f_code.co_filename, tx.lineno, tx.instruction_pointer, node + ) + ] + duplicates.append(node) + self.node_to_duplicates[node] = duplicates +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) except NodeHashException as e: log.debug("Unable to hash node %s with exception %s", node, e) @@ -344,6 +368,7 @@ def get_identical_regions(self, graph: torch.fx.Graph) -> list[list[Region]]: self._is_identical, ) # sort topologically +<<<<<<< HEAD # we need to handle edge cases where some nodes have no dependencies # so first we map each node to its ranking, ref_region = region_group[0] @@ -351,6 +376,10 @@ def get_identical_regions(self, graph: torch.fx.Graph) -> list[list[Region]]: index: topological_ranking[n] for index, n in enumerate(ref_region) } _sort_with_ref_region(index_to_rank, region_group) +======= + for region in region_group: + region.sort(key=lambda n: topological_ranking[n]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return [ region_group for region_group in region_groups if len(region_group[0]) > 1 @@ -444,7 +473,10 @@ def fully_expand_region_group( candidate not in seen_nodes and candidate not in nodes_to_add and candidate.op != "placeholder" +<<<<<<< HEAD and candidate.op != "get_attr" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) and is_identical_fn(candidate, current_node) and not region_wrapper.will_inclusion_create_cycle(candidate) ) diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index be7ff5051f2d5..03312f73e28fa 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -1,3 +1,8 @@ +<<<<<<< HEAD +======= +# mypy: allow-untyped-defs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Core guard system for Dynamo that detects when compiled code needs to be recompiled due to changes in program state. Guards are conditions that must remain true for previously-compiled @@ -31,7 +36,10 @@ import pickle import sys import textwrap +<<<<<<< HEAD import traceback +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import types import warnings import weakref @@ -39,6 +47,7 @@ from copy import deepcopy from inspect import currentframe from typing import Any, Callable, NoReturn, Optional, TYPE_CHECKING, Union +<<<<<<< HEAD try: @@ -47,6 +56,8 @@ from typing_extensions import LiteralString from typing_extensions import TypeAliasType, TypeVar +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from weakref import ReferenceType import torch @@ -56,6 +67,7 @@ from torch._C._dynamo.guards import ( check_obj_id, check_type_id, +<<<<<<< HEAD ClosureGuardAccessor, CodeGuardAccessor, dict_version, @@ -68,10 +80,15 @@ GuardAccessor, GuardDebugInfo, GuardManager, +======= + dict_version, + DictGuardManager, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) install_no_tensor_aliasing_guard, install_object_aliasing_guard, install_storage_overlapping_guard, install_symbolic_shape_guard, +<<<<<<< HEAD LeafGuard, profile_guard_manager, RelationalGuard, @@ -80,6 +97,10 @@ TypeDictGuardAccessor, TypeGuardAccessor, TypeMROGuardAccessor, +======= + profile_guard_manager, + RootGuardManager, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) from torch._dynamo.source import ( get_global_source_name, @@ -88,8 +109,11 @@ is_from_flatten_script_object_source, is_from_local_source, is_from_optimizer_source, +<<<<<<< HEAD is_from_skip_guard_source, is_from_unspecialized_builtin_nn_module_source, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TensorProperty, TensorPropertySource, ) @@ -105,7 +129,10 @@ Source, StorageOverlap, ) +<<<<<<< HEAD from torch._inductor.utils import IndentedBuffer +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._logging import structured from torch._utils_internal import justknobs_check from torch.fx.experimental.symbolic_shapes import ( @@ -128,8 +155,11 @@ CallFunctionNoArgsSource, CallMethodItemSource, ChainedSource, +<<<<<<< HEAD ClosureSource, CodeSource, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ConstantSource, ConstDictKeySource, DataclassFieldsSource, @@ -147,19 +177,27 @@ GradSource, ListGetItemSource, LocalSource, +<<<<<<< HEAD NamedTupleFieldsSource, NNModuleSource, NonSerializableSetGetItemSource, +======= + NNModuleSource, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) NumpyTensorSource, OptimizerSource, ScriptObjectQualifiedNameSource, ShapeEnvSource, SubclassAttrListSource, TorchFunctionModeStackSource, +<<<<<<< HEAD TorchSource, TupleIteratorGetItemSource, TypeDictSource, TypeMROSource, +======= + TupleIteratorGetItemSource, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TypeSource, UnspecializedBuiltinNNModuleSource, UnspecializedNNModuleSource, @@ -197,7 +235,11 @@ ) +<<<<<<< HEAD guard_manager_testing_hook_fn: Optional[Callable[[Any, Any, Any], Any]] = None +======= +guard_manager_testing_hook_fn: Optional[Callable[[Any, Any], Any]] = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: import numpy as np @@ -206,6 +248,7 @@ if TYPE_CHECKING: +<<<<<<< HEAD from collections.abc import Generator, KeysView, Sequence from sympy import Symbol @@ -214,6 +257,13 @@ from torch._dynamo.output_graph import OutputGraph, OutputGraphGuardsState T = TypeVar("T") +======= + from sympy import Symbol + + from torch._dynamo.output_graph import OutputGraphGuardsState + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) log = logging.getLogger(__name__) guards_log = torch._logging.getArtifactLogger(__name__, "guards") recompiles_log = torch._logging.getArtifactLogger(__name__, "recompiles") @@ -223,6 +273,7 @@ verbose_guards_log = torch._logging.getArtifactLogger(__name__, "verbose_guards") +<<<<<<< HEAD dunder_attrs_assumed_constants = ( "__defaults__", "__kwdefaults__", @@ -245,6 +296,8 @@ def writeline(self, line: str, skip_prefix: bool = False) -> None: # type: igno super().writeline("+- " + line) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class GuardManagerWrapper: """ A helper class that contains the root guard manager. An instance of this @@ -253,12 +306,17 @@ class is stored in the Dynamo cache entry, so that the cache entry can the check_nopybind from C++. """ +<<<<<<< HEAD def __init__(self, root: Optional[RootGuardManager] = None) -> None: +======= + def __init__(self, root=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if root is None: self.root = RootGuardManager() else: self.root = root +<<<<<<< HEAD self.diff_guard_root: Optional[RootGuardManager] = None self.closure_vars: Optional[dict[str, Any]] = None self.args: Optional[list[str]] = None @@ -272,19 +330,42 @@ def __init__(self, root: Optional[RootGuardManager] = None) -> None: self.no_tensor_aliasing_sources: list[str] = [] self.printed_relational_guards: set[RelationalGuard] = set() +======= + self.diff_guard_root = None + self.closure_vars = None + self.args = None + self.code_parts = [] + self.verbose_code_parts = None + self.global_scope = None + self.guard_fail_fn = None + self.cache_entry = None + self.extra_state = None + self.id_matched_objs = {} + self.no_tensor_aliasing_sources = [] + + self.printed_relational_guards = set() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.diff_guard_sources: OrderedSet[str] = OrderedSet() @contextmanager +<<<<<<< HEAD def _preserve_printed_relational_guards(self) -> Generator[None, None, None]: +======= + def _preserve_printed_relational_guards(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.printed_relational_guards = set() try: yield finally: self.printed_relational_guards = set() +<<<<<<< HEAD # TODO: clarify what fn and attributes guard manager has to get the right things here def collect_diff_guard_sources(self) -> OrderedSet[str]: +======= + def collect_diff_guard_sources(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # At the time of finalize, we have only marked guard managers with # TENSOR_MATCH guards as diff guard managers. So, we do a tree traversal # and collect all the nodes in the tree (branches) that lead to tensor @@ -294,7 +375,11 @@ def collect_diff_guard_sources(self) -> OrderedSet[str]: # 0, so we collect them as well. Later on, we accumulate the diff guard # sources for all the guard managers. +<<<<<<< HEAD def visit_dict_manager(node: DictGuardManager) -> bool: +======= + def visit_dict_manager(node): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) is_diff_guard_node = ( node.get_source() in self.diff_guard_sources or node.fail_count() > 0 ) @@ -308,7 +393,11 @@ def visit_dict_manager(node: DictGuardManager) -> bool: return is_diff_guard_node +<<<<<<< HEAD def visit_manager(node: GuardManager) -> bool: +======= + def visit_manager(node): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert not isinstance(node, DictGuardManager) is_diff_guard_node = ( @@ -322,7 +411,11 @@ def visit_manager(node: GuardManager) -> bool: return is_diff_guard_node +<<<<<<< HEAD def visit(node: GuardManager) -> bool: +======= + def visit(node): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if node is None: return False if isinstance(node, DictGuardManager): @@ -333,6 +426,7 @@ def visit(node: GuardManager) -> bool: return self.diff_guard_sources +<<<<<<< HEAD def finalize(self) -> None: if config.use_recursive_dict_tags_for_guards and justknobs_check( "pytorch/compiler:use_recursive_dict_tags_for_guards" @@ -565,6 +659,13 @@ def visit(node: GuardManager) -> list[GuardManager]: node.mark_tag_safe_root() def populate_diff_guard_manager(self) -> None: +======= + def finalize(self): + self.collect_diff_guard_sources() + self.populate_diff_guard_manager() + + def populate_diff_guard_manager(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.diff_guard_root = self.clone_with_chosen_sources(self.diff_guard_sources) # Ensure that that C++ side points to the updated diff guard manager. @@ -577,28 +678,42 @@ def populate_diff_guard_manager(self) -> None: if self.cache_entry: self.cache_entry.update_diff_guard_root_manager() +<<<<<<< HEAD def clone_with_chosen_sources( self, chosen_sources: OrderedSet[str] ) -> RootGuardManager: def filter_fn(node_mgr: GuardManager) -> bool: +======= + def clone_with_chosen_sources(self, chosen_sources): + def filter_fn(node_mgr): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return node_mgr.get_source() in chosen_sources return self.root.clone_manager(filter_fn) +<<<<<<< HEAD def get_guard_lines(self, guard: LeafGuard) -> list[str]: +======= + def get_guard_lines(self, guard): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) guard_name = guard.__class__.__name__ parts = guard.verbose_code_parts() parts = [guard_name + ": " + part for part in parts] return parts +<<<<<<< HEAD def get_manager_line( self, guard_manager: GuardManager, accessor_str: Optional[str] = None ) -> str: +======= + def get_manager_line(self, guard_manager, accessor_str=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) source = guard_manager.get_source() t = guard_manager.__class__.__name__ s = t + ": source=" + source if accessor_str: s += ", " + accessor_str +<<<<<<< HEAD s += f", type={guard_manager.get_type_of_guarded_value()}" s += f", tag_safe=({guard_manager.is_tag_safe()}, {guard_manager.is_tag_safe_root()})" return s @@ -606,6 +721,11 @@ def get_manager_line( def construct_dict_manager_string( self, mgr: DictGuardManager, body: IndentedBufferWithPrefix ) -> None: +======= + return s + + def construct_dict_manager_string(self, mgr, body): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for idx, (key_mgr, val_mgr) in sorted(mgr.get_key_value_managers().items()): body.writeline(f"KeyValueManager pair at index={idx}") with body.indent(): @@ -617,12 +737,19 @@ def construct_dict_manager_string( body.writeline(f"ValueManager: {self.get_manager_line(val_mgr)}") self.construct_manager_string(val_mgr, body) +<<<<<<< HEAD def construct_manager_string( self, mgr: GuardManager, body: IndentedBufferWithPrefix ) -> None: with body.indent(): for guard in mgr.get_leaf_guards(): if isinstance(guard, RelationalGuard): +======= + def construct_manager_string(self, mgr, body): + with body.indent(): + for guard in mgr.get_leaf_guards(): + if isinstance(guard, torch._C._dynamo.guards.RelationalGuard): # type: ignore[attr-defined] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if guard not in self.printed_relational_guards: self.printed_relational_guards.add(guard) body.writelines(self.get_guard_lines(guard)) @@ -648,7 +775,23 @@ def construct_manager_string( ) self.construct_manager_string(child_mgr, body) +<<<<<<< HEAD def __str__(self) -> str: +======= + def __str__(self): + from torch._inductor.utils import IndentedBuffer + + class IndentedBufferWithPrefix(IndentedBuffer): + def prefix(self): + return "| " * (self._indent * self.tabwidth) + + def writeline(self, line, skip_prefix=False): + if skip_prefix: + super().writeline(line) + else: + super().writeline("+- " + line) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with self._preserve_printed_relational_guards(): body = IndentedBufferWithPrefix() body.tabwidth = 1 @@ -661,6 +804,7 @@ def __str__(self) -> str: body.writelines(self.get_guard_lines(guard)) return body.getvalue() +<<<<<<< HEAD def check(self, x: Any) -> bool: # Only needed for debugging purposes. return self.root.check(x) @@ -674,16 +818,38 @@ def populate_code_parts_for_debugging(self) -> None: relational_guards_seen = set() def get_code_parts(leaf_guard: LeafGuard) -> list[str]: +======= + def check(self, x): + # Only needed for debugging purposes. + return self.root.check(x) + + def check_verbose(self, x): + # Only needed for debugging purposes. + return self.root.check_verbose(x) + + def populate_code_parts_for_debugging(self): + # This should be called when the guard manager is fully populated + relational_guards_seen = set() + + def get_code_parts(leaf_guard): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) code_parts = [] for verbose_code_part in leaf_guard.verbose_code_parts(): code_part = verbose_code_part.split("#")[0].rstrip() code_parts.append(code_part) return code_parts +<<<<<<< HEAD def visit(mgr: GuardManager) -> None: nonlocal relational_guards_seen for guard in mgr.get_leaf_guards(): if isinstance(guard, RelationalGuard): +======= + def visit(mgr): + nonlocal relational_guards_seen + for guard in mgr.get_leaf_guards(): + if isinstance(guard, torch._C._dynamo.guards.RelationalGuard): # type: ignore[attr-defined] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if guard not in relational_guards_seen: self.code_parts.extend(get_code_parts(guard)) relational_guards_seen.add(guard) @@ -696,7 +862,11 @@ def visit(mgr: GuardManager) -> None: visit(self.root) +<<<<<<< HEAD def from_numpy(a: Any) -> torch.Tensor: +======= +def from_numpy(a): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # If not numpy array, piggy back on e.g. tensor guards to check type # Re-enable torch function since we disable it on leaf guards # we need it to properly construct the tensor if a default device is set @@ -706,7 +876,11 @@ def from_numpy(a: Any) -> torch.Tensor: # For user stack printing @functools.cache +<<<<<<< HEAD def uninteresting_files() -> set[str]: +======= +def uninteresting_files(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch._dynamo.external_utils import torch._dynamo.polyfills @@ -722,7 +896,11 @@ def uninteresting_files() -> set[str]: _CLOSURE_VARS: Optional[dict[str, object]] = None +<<<<<<< HEAD def _get_closure_vars() -> dict[str, object]: +======= +def _get_closure_vars(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) global _CLOSURE_VARS if _CLOSURE_VARS is None: _CLOSURE_VARS = { @@ -736,7 +914,10 @@ def _get_closure_vars() -> dict[str, object]: "___normalize_range_iter": normalize_range_iter, "___tuple_iterator_getitem": tuple_iterator_getitem, "___dataclass_fields": dataclass_fields, +<<<<<<< HEAD "___namedtuple_fields": lambda x: x._fields, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at, "__math_isnan": math.isnan, "__numpy_isnan": None if np is None else np.isnan, @@ -759,13 +940,18 @@ def _ast_unparse(node: ast.AST) -> str: strip_function_call = torch._C._dynamo.strip_function_call +<<<<<<< HEAD def get_verbose_code_part(code_part: str, guard: Optional[Guard]) -> str: +======= +def get_verbose_code_part(code_part: str, guard: Guard) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) extra = "" if guard is not None: if guard.user_stack: for fs in reversed(guard.user_stack): if fs.filename not in uninteresting_files(): extra = f" # {format_frame(fs, line=True)}" +<<<<<<< HEAD if len(extra) > 1024: # For fx graphs, the line can be very long in case of # torch.stack ops, where many inputs are set to None @@ -773,6 +959,8 @@ def get_verbose_code_part(code_part: str, guard: Optional[Guard]) -> str: # guards log file. In such cases, do not print the line # contents. extra = f" # {format_frame(fs)}" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) break elif guard.stack: summary = guard.stack.summary() @@ -784,6 +972,7 @@ def get_verbose_code_part(code_part: str, guard: Optional[Guard]) -> str: def get_verbose_code_parts( +<<<<<<< HEAD code_parts: Union[str, list[str]], guard: Optional[Guard], recompile_hint: Optional[str] = None, @@ -803,6 +992,16 @@ def get_verbose_code_parts( def convert_int_to_concrete_values(dim: Any) -> Optional[int]: +======= + code_parts: Union[str | list[str]], guard: Guard +) -> list[str]: + if not isinstance(code_parts, list): + code_parts = [code_parts] + return [get_verbose_code_part(code_part, guard) for code_part in code_parts] + + +def convert_int_to_concrete_values(dim) -> Optional[int]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if dim is None: return None if not is_symbolic(dim): @@ -812,6 +1011,7 @@ def convert_int_to_concrete_values(dim: Any) -> Optional[int]: return dim.node.maybe_as_int() +<<<<<<< HEAD def convert_to_concrete_values(size_or_stride: list[Any]) -> list[Optional[int]]: return [convert_int_to_concrete_values(dim) for dim in size_or_stride] @@ -824,6 +1024,13 @@ def get_tensor_guard_code_part( pytype: type, dispatch_keys: DispatchKeySet, ) -> str: +======= +def convert_to_concrete_values(size_or_stride): + return [convert_int_to_concrete_values(dim) for dim in size_or_stride] + + +def get_tensor_guard_code_part(value, name, sizes, strides, pytype, dispatch_keys): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dispatch_key = ( dispatch_keys | torch._C._dispatch_tls_local_include_set() ) - torch._C._dispatch_tls_local_exclude_set() @@ -837,7 +1044,11 @@ def get_tensor_guard_code_part( return guard_str +<<<<<<< HEAD def get_key_index(dct: dict[Any, Any], key: Any) -> int: +======= +def get_key_index(dct, key): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Ensure that we call dict.keys and not value.keys (which can call # overridden keys method). In the C++ guards, we relied on PyDict_Next # to traverse the dictionary, which uses the internal data structure and @@ -845,7 +1056,11 @@ def get_key_index(dct: dict[Any, Any], key: Any) -> int: return list(builtin_dict_keys(dct)).index(key) +<<<<<<< HEAD def get_key_index_source(source: Any, index: Any) -> str: +======= +def get_key_index_source(source, index): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"list(dict.keys({source}))[{index}]" @@ -857,6 +1072,7 @@ def raise_local_type_error(obj: Any) -> NoReturn: ) +<<<<<<< HEAD def should_optimize_getattr_on_nn_module(value: Any) -> bool: # If inline_inbuilt_nn_modules flag is True, Dynamo has already traced # through the __getattr__, and therefore it is always safe to optimize @@ -867,6 +1083,8 @@ def should_optimize_getattr_on_nn_module(value: Any) -> bool: ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclasses.dataclass(frozen=True) class NNModuleAttrAccessorInfo: # Represents where is the attr name is present in the nn module attribute @@ -883,12 +1101,17 @@ class NNModuleAttrAccessorInfo: def getitem_on_dict_manager( +<<<<<<< HEAD source: Union[DictGetItemSource, DictSubclassGetItemSource], base_guard_manager: DictGuardManager, base_example_value: Any, example_value: Any, guard_manager_enum: GuardManagerType, ) -> GuardManager: +======= + source, base_guard_manager, base_example_value, example_value, guard_manager_enum +): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) base_source_name = source.base.name() if isinstance(source.index, ConstDictKeySource): index = source.index.index @@ -927,7 +1150,11 @@ def getitem_on_dict_manager( ) +<<<<<<< HEAD def match_on_id_for_tensor(guard: Guard) -> bool: +======= +def match_on_id_for_tensor(guard): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) source = guard.originating_source # For numpy tensors, always use TENSOR_MATCH because __from_numpy leads # to a new tensor every time and therefore id differs. @@ -954,7 +1181,11 @@ class GuardManagerType(enum.Enum): @functools.cache +<<<<<<< HEAD def code_framelocals_names_reversed_cached(code: types.CodeType) -> list[str]: +======= +def code_framelocals_names_reversed_cached(code: types.CodeType): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return list(reversed(code_framelocals_names(code))) @@ -962,22 +1193,36 @@ class GuardBuilder(GuardBuilderBase): def __init__( self, f_code: types.CodeType, +<<<<<<< HEAD id_ref: Callable[[object, str], int], source_ref: Callable[[Source], str], lookup_weakrefs: Callable[[object], Optional[weakref.ref[object]]], +======= + id_ref: Callable[[Any, str], str], + source_ref: Callable[[Source], str], + lookup_weakrefs: Callable[[object], ReferenceType[object]], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) local_scope: dict[str, object], global_scope: dict[str, object], guard_manager: GuardManagerWrapper, check_fn_manager: CheckFunctionManager, +<<<<<<< HEAD save_guards: bool = False, runtime_global_scope: Optional[dict[str, object]] = None, ) -> None: +======= + serialization_mode: Optional[str] = None, + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.f_code = f_code self.id_ref = id_ref self.source_ref = source_ref self.lookup_weakrefs = lookup_weakrefs self.scope: dict[str, dict[str, object]] = {"L": local_scope, "G": global_scope} +<<<<<<< HEAD self.runtime_global_scope = runtime_global_scope or global_scope +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.scope["__builtins__"] = builtins.__dict__.copy() for ( name, @@ -1002,7 +1247,11 @@ def __init__( # Collect the guard managers and debug info to insert no tensor aliasing # guards. self.no_tensor_aliasing_names: list[str] = [] +<<<<<<< HEAD self.no_tensor_aliasing_guard_managers: list[GuardManager] = [] +======= + self.no_tensor_aliasing_guard_managers: list[GuardManagerWrapper] = [] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.check_fn_manager: CheckFunctionManager = check_fn_manager @@ -1011,7 +1260,10 @@ def __init__( # to access the same object - self._module["param"] is same as # self.param. self.key_order_guarded_dict_ids = set() +<<<<<<< HEAD assert self.check_fn_manager.output_graph is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for source in self.check_fn_manager.output_graph.guard_on_key_order: self.key_order_guarded_dict_ids.add(id(self.get(source.name()))) @@ -1021,6 +1273,7 @@ def __init__( self.id_matched_objs: dict[str, ReferenceType[object]] = {} # Save the guard managers to avoid repeatedly traversing sources. +<<<<<<< HEAD self._cached_guard_managers: dict[str, GuardManager] = {} self._cached_duplicate_input_guards: set[tuple[str, str]] = set() self.object_aliasing_guard_codes: list[tuple[str, str]] = [] @@ -1035,6 +1288,15 @@ def __init__( def guard_on_dict_keys_and_ignore_order( self, example_value: dict[Any, Any], guard: Guard ) -> None: +======= + self._cached_guard_managers: dict[ + str, torch._C._dynamo.guards.GuardManager + ] = {} + self._cached_duplicate_input_guards: set[tuple[str, str]] = set() + self.serialization_mode = serialization_mode + + def guard_on_dict_keys_and_ignore_order(self, example_value, guard): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dict_mgr = self.get_guard_manager(guard) if isinstance(dict_mgr, DictGuardManager): raise NotImplementedError( @@ -1062,7 +1324,11 @@ def guard_on_dict_keys_and_ignore_order( guard_manager_enum=guard_manager_enum, ) +<<<<<<< HEAD def guard_on_dict_keys_and_order(self, value: dict[Any, Any], guard: Guard) -> None: +======= + def guard_on_dict_keys_and_order(self, value, guard): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Add key managers for the DictGuardManager. Then add either an # ID_MATCH or EQUALS_MATCH guard on the key. dict_mgr = self.get_guard_manager(guard) @@ -1101,7 +1367,11 @@ def guard_on_dict_keys_and_order(self, value: dict[Any, Any], guard: Guard) -> N ) @staticmethod +<<<<<<< HEAD def _get_generic_dict_manager_example_value(example_value: Any) -> Optional[Any]: +======= + def _get_generic_dict_manager_example_value(example_value): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # due to a bug in 3.13.0 (introduced by https://github.com/python/cpython/pull/116115, # reported in https://github.com/python/cpython/issues/125608, # fixed by https://github.com/python/cpython/pull/125611), we cannot take @@ -1120,6 +1390,7 @@ def _get_generic_dict_manager_example_value(example_value: Any) -> Optional[Any] def getattr_on_nn_module( self, +<<<<<<< HEAD source: AttrSource, base_guard_manager: GuardManager, base_example_value: Any, @@ -1128,6 +1399,16 @@ def getattr_on_nn_module( source_name: str, guard_manager_enum: GuardManagerType, ) -> GuardManager: +======= + source, + base_guard_manager, + base_example_value, + example_value, + base_source_name, + source_name, + guard_manager_enum, + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ This tries to avoid calling the expensive nn module custom getattr method by checking if the attribute is accessible via __dict__. For attributes that @@ -1146,6 +1427,7 @@ def getattr_on_nn_module( """ def getitem_on_dict_mgr( +<<<<<<< HEAD mgr: GuardManager, key: Any, source_name: str, @@ -1153,6 +1435,10 @@ def getitem_on_dict_mgr( example_value: Any, guard_manager_enum: GuardManagerType, ) -> GuardManager: +======= + mgr, key, source_name, base_example_value, example_value, guard_manager_enum + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(mgr, DictGuardManager): # Case where the user code relies on key order, e.g., # named_parameters @@ -1262,7 +1548,10 @@ def getitem_on_dict_mgr( ) if l2_key: +<<<<<<< HEAD assert l2_source_name is not None and l2_guard_manager_enum is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return getitem_on_dict_mgr( mgr=l1_mgr, key=l2_key, @@ -1273,13 +1562,18 @@ def getitem_on_dict_mgr( ) return l1_mgr +<<<<<<< HEAD def requires_key_order_guarding(self, source: Source) -> bool: +======= + def requires_key_order_guarding(self, source): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) source_name = source.name() if source_name == "": return False obj_id = id(self.get(source_name)) return obj_id in self.key_order_guarded_dict_ids +<<<<<<< HEAD def get_guard_manager_type( self, source: Source, @@ -1287,33 +1581,52 @@ def get_guard_manager_type( Union[KeysView[Any], set[Any], frozenset[Any], dict[Any, Any]] ], ) -> GuardManagerType: +======= + def get_guard_manager_type(self, source, example_value): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) guard_manager_enum = GuardManagerType.GUARD_MANAGER if self.requires_key_order_guarding(source): # Fix this if condition if isinstance(example_value, dict_keys): guard_manager_enum = GuardManagerType.DICT_GUARD_MANAGER +<<<<<<< HEAD elif isinstance(example_value, (set, frozenset)): # we don't need to guard on key order for set/frozenset # but the if above will be true for these types as set is # implemented using a dict in Dynamo guard_manager_enum = GuardManagerType.GUARD_MANAGER +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: assert isinstance(example_value, dict) guard_manager_enum = GuardManagerType.DICT_GUARD_MANAGER return guard_manager_enum +<<<<<<< HEAD def manager_guards_on_keys(self, mgr_enum: GuardManagerType) -> bool: return mgr_enum == GuardManagerType.DICT_GUARD_MANAGER def get_global_guard_manager(self) -> GuardManager: return self.guard_manager.root.globals_dict_manager( f_globals=self.runtime_global_scope, +======= + def manager_guards_on_keys(self, mgr_enum): + return mgr_enum == GuardManagerType.DICT_GUARD_MANAGER + + def get_global_guard_manager(self): + return self.guard_manager.root.globals_dict_manager( + f_globals=self.scope["G"], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) source="G", example_value=self.scope["G"], guard_manager_enum=GuardManagerType.GUARD_MANAGER, ) +<<<<<<< HEAD def get_guard_manager_from_source(self, source: Source) -> GuardManager: +======= + def get_guard_manager_from_source(self, source): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) root_guard_manager = self.guard_manager.root example_value = None @@ -1392,6 +1705,7 @@ def get_guard_manager_from_source(self, source: Source) -> GuardManager: example_value=example_value, guard_manager_enum=guard_manager_enum, ) +<<<<<<< HEAD elif istype(source, TypeDictSource): assert base_guard_manager # to make mypy happy out = base_guard_manager.type_dict_manager( @@ -1406,6 +1720,8 @@ def get_guard_manager_from_source(self, source: Source) -> GuardManager: example_value=example_value, guard_manager_enum=guard_manager_enum, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif istype( source, ( @@ -1418,6 +1734,7 @@ def get_guard_manager_from_source(self, source: Source) -> GuardManager: ): assert base_guard_manager # to make mypy happy out = base_guard_manager +<<<<<<< HEAD elif istype(source, TorchSource): out = root_guard_manager.lambda_manager( python_lambda=lambda _: torch, @@ -1425,6 +1742,8 @@ def get_guard_manager_from_source(self, source: Source) -> GuardManager: example_value=example_value, guard_manager_enum=guard_manager_enum, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif istype(source, TorchFunctionModeStackSource): out = root_guard_manager.lambda_manager( python_lambda=lambda _: get_torch_function_mode_stack_at( @@ -1451,9 +1770,18 @@ def get_guard_manager_from_source(self, source: Source) -> GuardManager: ) elif istype(source, (AttrSource, UnspecializedParamBufferSource)): assert base_guard_manager # to make mypy happy +<<<<<<< HEAD assert isinstance(source, AttrSource) if should_optimize_getattr_on_nn_module(base_example_value): assert base_source_name +======= + + if ( + isinstance(base_example_value, torch.nn.Module) + and get_custom_getattr(base_example_value) + is unpatched_nn_module_getattr + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out = self.getattr_on_nn_module( source, base_guard_manager, @@ -1473,7 +1801,10 @@ def get_guard_manager_from_source(self, source: Source) -> GuardManager: elif istype(source, (DictGetItemSource, DictSubclassGetItemSource)): assert base_guard_manager # to make mypy happy assert isinstance(base_example_value, (dict, collections.OrderedDict)) +<<<<<<< HEAD assert isinstance(source, (DictGetItemSource, DictSubclassGetItemSource)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(base_guard_manager, DictGuardManager): assert self.manager_guards_on_keys(base_guard_manager_enum) out = getitem_on_dict_manager( @@ -1553,7 +1884,10 @@ def get_guard_manager_from_source(self, source: Source) -> GuardManager: ) elif istype(source, DefaultsSource): assert base_guard_manager # to make mypy happy +<<<<<<< HEAD assert base_source_name +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert callable(base_example_value) if not source.is_kw: out = base_guard_manager.func_defaults_manager( @@ -1661,6 +1995,7 @@ def get_guard_manager_from_source(self, source: Source) -> GuardManager: example_value=example_value, guard_manager_enum=guard_manager_enum, ) +<<<<<<< HEAD elif istype(source, NonSerializableSetGetItemSource): assert base_guard_manager out = base_guard_manager.set_getitem_manager( @@ -1669,6 +2004,8 @@ def get_guard_manager_from_source(self, source: Source) -> GuardManager: example_value=example_value, guard_manager_enum=guard_manager_enum, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif istype(source, WeakRefCallSource): assert base_guard_manager # to make mypy happy out = base_guard_manager.weakref_call_manager( @@ -1691,6 +2028,7 @@ def get_guard_manager_from_source(self, source: Source) -> GuardManager: example_value=example_value, guard_manager_enum=guard_manager_enum, ) +<<<<<<< HEAD elif istype(source, NamedTupleFieldsSource): assert base_guard_manager out = base_guard_manager.lambda_manager( @@ -1713,6 +2051,8 @@ def get_guard_manager_from_source(self, source: Source) -> GuardManager: example_value=example_value, guard_manager_enum=guard_manager_enum, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: raise AssertionError( f"missing guard manager builder {source} - {source.name()}" @@ -1721,16 +2061,28 @@ def get_guard_manager_from_source(self, source: Source) -> GuardManager: self._cached_guard_managers[source.name()] = out return out +<<<<<<< HEAD def get_guard_manager(self, guard: Guard) -> GuardManager: +======= + def get_guard_manager(self, guard: Guard): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.get_guard_manager_from_source(guard.originating_source) def add_python_lambda_leaf_guard_to_root( self, +<<<<<<< HEAD code_parts: list[str], verbose_code_parts: list[str], closure_vars: Optional[dict[str, object]] = None, is_epilogue: bool = True, ) -> None: +======= + code_parts, + verbose_code_parts, + closure_vars=None, + is_epilogue=True, + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if closure_vars is None: closure_vars = _get_closure_vars() # Adds a lambda leaf guard to the root guard manager. It wraps the @@ -1785,6 +2137,7 @@ def arg_ref(self, guard: Union[str, Guard]) -> str: return name +<<<<<<< HEAD def _guard_on_attribute( self, guard: Guard, @@ -1795,6 +2148,10 @@ def _guard_on_attribute( attr_source = CodeSource(guard.originating_source) else: attr_source = AttrSource(guard.originating_source, attr_name) # type: ignore[assignment] +======= + def _guard_on_attribute(self, guard: Guard, attr_name: str, guard_fn): + attr_source = AttrSource(guard.originating_source, attr_name) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Copy the stack info new_guard = Guard( attr_source, guard_fn, stack=guard.stack, user_stack=guard.user_stack @@ -1802,6 +2159,7 @@ def _guard_on_attribute( new_guard.create(self) # Note: the order of the guards in this file matters since we sort guards on the same object by lineno +<<<<<<< HEAD def HASATTR(self, guard: Guard) -> None: source = guard.originating_source if isinstance(source, NNModuleSource): @@ -1809,6 +2167,12 @@ def HASATTR(self, guard: Guard) -> None: if isinstance(source, CodeSource): # No need to guard that a function has a __code__ attribute return +======= + def HASATTR(self, guard: Guard): + source = guard.originating_source + if isinstance(source, NNModuleSource): + source = source.base +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(source, AttrSource), f"invalid source {guard.name}" base_source = source.base base = base_source.name() @@ -1835,8 +2199,17 @@ def HASATTR(self, guard: Guard) -> None: # if the base value is nn.Module, check if we can speedup the # guard by going through __dict__ attrs. +<<<<<<< HEAD if should_optimize_getattr_on_nn_module(base_example_value): self.getattr_on_nn_module( +======= + if ( + isinstance(base_example_value, torch.nn.Module) + and get_custom_getattr(base_example_value) + is unpatched_nn_module_getattr + ): + return self.getattr_on_nn_module( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) source, base_manager, base_example_value, @@ -1855,6 +2228,7 @@ def HASATTR(self, guard: Guard) -> None: else: base_manager.add_no_hasattr_guard(attr, get_verbose_code_parts(code, guard)) +<<<<<<< HEAD def NOT_PRESENT_IN_GENERIC_DICT( self, guard: Guard, attr: Optional[Any] = None ) -> None: @@ -1867,6 +2241,16 @@ def NOT_PRESENT_IN_GENERIC_DICT( if (ref, attr) in self.already_guarded_not_present_in_generic_dict: return +======= + def NOT_PRESENT_IN_GENERIC_DICT(self, guard: Guard, attr=None) -> None: + assert attr is not None + ref = self.arg_ref(guard) + val = self.get(guard.name) + assert isinstance(val, torch.nn.Module) + + base_manager = self.get_guard_manager(guard) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mod_dict_source = f"{guard.name}.__dict__" mod_generic_dict_manager = base_manager.get_generic_dict_manager( source=mod_dict_source, @@ -1878,7 +2262,10 @@ def NOT_PRESENT_IN_GENERIC_DICT( mod_generic_dict_manager.add_dict_contains_guard( False, attr, get_verbose_code_parts(code, guard) ) +<<<<<<< HEAD self.already_guarded_not_present_in_generic_dict.add((ref, attr)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def TYPE_MATCH(self, guard: Guard) -> None: # ___check_type_id is same as `id(type(x)) == y` @@ -1888,10 +2275,16 @@ def TYPE_MATCH(self, guard: Guard) -> None: else: t = type(value) +<<<<<<< HEAD if t.__qualname__ != t.__name__: # Type match guards must be local scope, this is # raised in self.serialize_guards guard._unserializable = True +======= + if self.serialization_mode == "save": + if t.__qualname__ != t.__name__: + raise_local_type_error(value) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) obj_id = self.id_ref(t, f"type({guard.name})") code = f"___check_type_id({self.arg_ref(guard)}, {obj_id})" @@ -1901,7 +2294,15 @@ def TYPE_MATCH(self, guard: Guard) -> None: obj_id, get_verbose_code_parts(code, guard) ) +<<<<<<< HEAD def DICT_VERSION(self, guard: Guard) -> None: +======= + def DICT_VERSION(self, guard: Guard): + if self.serialization_mode == "save": + raise torch._dynamo.exc.PackageError( + "DICT_VERSION guard cannot be serialized." + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # ___check_dict_version is same as `dict_version(x) == y` ref = self.arg_ref(guard) val = self.get(guard.name) @@ -1915,7 +2316,11 @@ def DICT_VERSION(self, guard: Guard) -> None: val, get_verbose_code_parts(code, guard) ) +<<<<<<< HEAD def DICT_CONTAINS(self, guard: Guard, key: str, invert: bool) -> None: +======= + def DICT_CONTAINS(self, guard: Guard, key: str, invert: bool): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dict_ref = self.arg_ref(guard) maybe_not = "not " if invert else "" @@ -1926,6 +2331,7 @@ def DICT_CONTAINS(self, guard: Guard, key: str, invert: bool) -> None: not invert, key, get_verbose_code_parts(code, guard) ) +<<<<<<< HEAD def SET_CONTAINS(self, guard: Guard, key: Any, invert: bool) -> None: set_ref = self.arg_ref(guard) item = key @@ -1940,6 +2346,9 @@ def SET_CONTAINS(self, guard: Guard, key: Any, invert: bool) -> None: ) def BOOL_MATCH(self, guard: Guard) -> None: +======= + def BOOL_MATCH(self, guard: Guard): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # checks val == True or val == False ref = self.arg_ref(guard) val = self.get(guard.name) @@ -1956,7 +2365,11 @@ def BOOL_MATCH(self, guard: Guard) -> None: get_verbose_code_parts(code, guard) ) +<<<<<<< HEAD def NONE_MATCH(self, guard: Guard) -> None: +======= + def NONE_MATCH(self, guard: Guard): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # checks `val is None` ref = self.arg_ref(guard) val = self.get(guard.name) @@ -1968,12 +2381,18 @@ def NONE_MATCH(self, guard: Guard) -> None: get_verbose_code_parts(code, guard) ) +<<<<<<< HEAD def ID_MATCH(self, guard: Guard, recompile_hint: Optional[str] = None) -> None: return self.id_match_unchecked(guard, recompile_hint) def id_match_unchecked( self, guard: Guard, recompile_hint: Optional[str] = None ) -> None: +======= + def ID_MATCH(self, guard: Guard): + if self.serialization_mode == "save": + raise torch._dynamo.exc.PackageError("ID_MATCH guard cannot be serialized.") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # ___check_obj_id is same as `id(x) == y` if isinstance(guard.originating_source, TypeSource): # optional optimization to produce cleaner/faster guard code @@ -1985,9 +2404,16 @@ def id_match_unchecked( val = self.get(guard.name) id_val = self.id_ref(val, guard.name) code = f"___check_obj_id({ref}, {id_val})" +<<<<<<< HEAD self._set_guard_export_info(guard, [code], provided_func_name="ID_MATCH") self.get_guard_manager(guard).add_id_match_guard( id_val, get_verbose_code_parts(code, guard, recompile_hint) +======= + self._set_guard_export_info(guard, [code]) + + self.get_guard_manager(guard).add_id_match_guard( + id_val, get_verbose_code_parts(code, guard) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # Keep track of ID_MATCH'd objects. This will be used to modify the @@ -2002,7 +2428,11 @@ def id_match_unchecked( if weak_id is not None: self.id_matched_objs[local_name] = weak_id +<<<<<<< HEAD def NOT_NONE_MATCH(self, guard: Guard, value: Optional[Any] = None) -> None: +======= + def NOT_NONE_MATCH(self, guard: Guard, value=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ref = self.arg_ref(guard) val = self.get(guard.name) assert isinstance(val, torch.Tensor) @@ -2013,7 +2443,11 @@ def NOT_NONE_MATCH(self, guard: Guard, value: Optional[Any] = None) -> None: get_verbose_code_parts(code, guard) ) +<<<<<<< HEAD def DISPATCH_KEY_SET_MATCH(self, guard: Guard) -> None: +======= + def DISPATCH_KEY_SET_MATCH(self, guard: Guard): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ref = self.arg_ref(guard) val = self.get(guard.name) assert isinstance(val, torch._C.DispatchKeySet) @@ -2023,6 +2457,7 @@ def DISPATCH_KEY_SET_MATCH(self, guard: Guard) -> None: val, get_verbose_code_parts(code_parts, guard) ) +<<<<<<< HEAD def NAME_MATCH(self, guard: Guard) -> None: self._guard_on_attribute(guard, "__name__", GuardBuilder.EQUALS_MATCH) # type: ignore[arg-type] @@ -2037,16 +2472,37 @@ def DUAL_LEVEL(self, guard: Guard) -> None: forward_ad = torch.autograd.forward_ad def fn(x: Any) -> bool: +======= + def NAME_MATCH(self, guard: Guard): + self._guard_on_attribute(guard, "__name__", GuardBuilder.EQUALS_MATCH) + + def DUAL_LEVEL(self, guard: Guard): + # Invalidate dual level if current dual level is different than the one + # in the fx graph + dual_level = self.check_fn_manager.output_graph.dual_level + code = [f"torch.autograd.forward_ad._current_level == {dual_level}"] + self._set_guard_export_info(guard, [code]) + # TODO(anijain2305) - Consider this moving this guard to C++ + forward_ad = torch.autograd.forward_ad + + def fn(x): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return forward_ad._current_level == dual_level self.guard_manager.root.add_lambda_guard( fn, get_verbose_code_parts(code, guard) ) +<<<<<<< HEAD def FUNCTORCH_STACK_MATCH(self, guard: Guard) -> None: # Invalidate functorch code if current level is different than # the one when FX graph was generated assert self.check_fn_manager.output_graph is not None +======= + def FUNCTORCH_STACK_MATCH(self, guard: Guard): + # Invalidate functorch code if current level is different than + # the one when FX graph was generated +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cis = self.check_fn_manager.output_graph.functorch_layers states = [ci.get_state() for ci in cis] code = [f"torch._functorch.pyfunctorch.compare_functorch_state({states})"] @@ -2055,22 +2511,34 @@ def FUNCTORCH_STACK_MATCH(self, guard: Guard) -> None: # TODO(anijain2305) - Consider this moving this guard to C++ compare_fn = torch._functorch.pyfunctorch.compare_functorch_state +<<<<<<< HEAD def fn(x: Any) -> bool: +======= + def fn(x): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return compare_fn(states) self.guard_manager.root.add_lambda_guard( fn, get_verbose_code_parts(code, guard) ) +<<<<<<< HEAD def AUTOGRAD_SAVED_TENSORS_HOOKS(self, guard: Guard) -> None: +======= + def AUTOGRAD_SAVED_TENSORS_HOOKS(self, guard: Guard): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) get_hooks = torch._functorch._aot_autograd.utils.top_saved_tensors_hooks are_inline_hooks = ( torch._functorch._aot_autograd.utils.saved_tensors_hooks_are_inlineable ) +<<<<<<< HEAD def hooks_ids_fn( hooks: tuple[Callable[[torch.Tensor], Any], Callable[[Any], torch.Tensor]], ) -> Optional[tuple[int, ...]]: +======= + def hooks_ids_fn(hooks): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not are_inline_hooks(hooks): return None @@ -2084,27 +2552,43 @@ def hooks_ids_fn( ] self._set_guard_export_info(guard, code) +<<<<<<< HEAD def fn(x: Any) -> bool: +======= + def fn(x): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return guard_hooks_ids == hooks_ids_fn(get_hooks()) self.guard_manager.root.add_lambda_guard( fn, get_verbose_code_parts(code, guard) ) +<<<<<<< HEAD def TENSOR_SUBCLASS_METADATA_MATCH(self, guard: Guard) -> None: +======= + def TENSOR_SUBCLASS_METADATA_MATCH(self, guard: Guard): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) value = self.get(guard.name) original_metadata = deepcopy(self.get(guard.name).__tensor_flatten__()[1]) if hasattr(value, "__metadata_guard__"): verify_guard_fn_signature(value) +<<<<<<< HEAD def metadata_checker(x: Any) -> bool: +======= + def metadata_checker(x): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return value.__metadata_guard__( original_metadata, x.__tensor_flatten__()[1] ) else: +<<<<<<< HEAD def metadata_checker(x: Any) -> bool: +======= + def metadata_checker(x): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return x.__tensor_flatten__()[1] == original_metadata global_name = f"___check_metadata_{id(metadata_checker)}_c{CompileContext.current_compile_id()}" @@ -2112,7 +2596,11 @@ def metadata_checker(x: Any) -> bool: metadata_checker, get_verbose_code_parts(global_name, guard) ) +<<<<<<< HEAD def EQUALS_MATCH(self, guard: Guard, recompile_hint: Optional[str] = None) -> None: +======= + def EQUALS_MATCH(self, guard: Guard, recompile_hint: Optional[str] = None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ref = self.arg_ref(guard) val = self.get(guard.name) if np: @@ -2184,7 +2672,11 @@ def EQUALS_MATCH(self, guard: Guard, recompile_hint: Optional[str] = None) -> No self._set_guard_export_info(guard, code) self.get_guard_manager(guard).add_lambda_guard( +<<<<<<< HEAD _get_closure_vars()["__math_isnan"], # type: ignore[arg-type] +======= + _get_closure_vars()["__math_isnan"], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) get_verbose_code_parts(code, guard), ) return @@ -2197,7 +2689,11 @@ def EQUALS_MATCH(self, guard: Guard, recompile_hint: Optional[str] = None) -> No self._set_guard_export_info(guard, code) self.get_guard_manager(guard).add_lambda_guard( +<<<<<<< HEAD _get_closure_vars()["__numpy_isnan"], # type: ignore[arg-type] +======= + _get_closure_vars()["__numpy_isnan"], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) get_verbose_code_parts(code, guard), ) return @@ -2220,7 +2716,11 @@ def EQUALS_MATCH(self, guard: Guard, recompile_hint: Optional[str] = None) -> No self._set_guard_export_info(guard, code) return +<<<<<<< HEAD def CONSTANT_MATCH(self, guard: Guard) -> None: +======= + def CONSTANT_MATCH(self, guard: Guard): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) val = self.get(guard.name) if istype(val, bool): self.BOOL_MATCH(guard) @@ -2231,6 +2731,7 @@ def CONSTANT_MATCH(self, guard: Guard) -> None: else: self.EQUALS_MATCH(guard) +<<<<<<< HEAD def NN_MODULE(self, guard: Guard) -> None: # don't support this in serialization because it uses unsupported ID_MATCH self.ID_MATCH(guard, "[inline-inbuilt-nn-modules-candidate]") @@ -2240,6 +2741,19 @@ def NN_MODULE(self, guard: Guard) -> None: if not self.guard_nn_modules: # If guard_nn_modules is true, we will guard on the right set of guards self._guard_on_attribute(guard, "training", GuardBuilder.CONSTANT_MATCH) # type: ignore[arg-type] +======= + def NN_MODULE(self, guard: Guard): + # don't support this in serialization because it uses unsupported ID_MATCH + if self.serialization_mode == "save": + raise torch._dynamo.exc.PackageError( + "NN_MODULE guard cannot be serialized." + ) + self.ID_MATCH(guard) + val = self.get(guard.name) + if hasattr(val, "training"): + assert istype(val.training, bool) + self._guard_on_attribute(guard, "training", GuardBuilder.CONSTANT_MATCH) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: exc.unimplemented_v2( gb_type="Attempted to guard on uninitialized nn.Module", @@ -2251,6 +2765,7 @@ def NN_MODULE(self, guard: Guard) -> None: ], ) +<<<<<<< HEAD def FUNCTION_MATCH(self, guard: Guard) -> None: """things like torch.add and user defined functions""" # don't support this in serialization because it uses unsupported ID_MATCH @@ -2279,6 +2794,36 @@ def BUILTIN_MATCH(self, guard: Guard) -> None: return self.ID_MATCH(guard) def SEQUENCE_LENGTH(self, guard: Guard) -> None: +======= + def FUNCTION_MATCH(self, guard: Guard): + """things like torch.add and user defined functions""" + # don't support this in serialization because it uses unsupported ID_MATCH + if self.serialization_mode == "save": + raise torch._dynamo.exc.PackageError( + "FUNCTION_MATCH guard cannot be serialized." + ) + return self.ID_MATCH(guard) + + def CLOSURE_MATCH(self, guard: Guard): + """matches a closure by __code__ id.""" + # don't support this in serialization because it uses unsupported FUNCTION_MATCH + if self.serialization_mode == "save": + raise torch._dynamo.exc.PackageError( + "CLOSURE_MATCH guard cannot be serialized." + ) + val = self.get(guard.name) + # Strictly only want user-defined functions + if type(val) == types.FunctionType and hasattr(val, "__code__"): + self._guard_on_attribute(guard, "__code__", GuardBuilder.HASATTR) + self._guard_on_attribute(guard, "__code__", GuardBuilder.FUNCTION_MATCH) + else: + self.FUNCTION_MATCH(guard) + + def BUILTIN_MATCH(self, guard: Guard): + return self.FUNCTION_MATCH(guard) + + def SEQUENCE_LENGTH(self, guard): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This guard is used to check length of PySequence objects like list, # tuple, collections.deque etc ref = self.arg_ref(guard) @@ -2304,7 +2849,11 @@ def SEQUENCE_LENGTH(self, guard: Guard) -> None: len(value), get_verbose_code_parts(code, guard) ) +<<<<<<< HEAD def TUPLE_ITERATOR_LEN(self, guard: Guard) -> None: +======= + def TUPLE_ITERATOR_LEN(self, guard): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ref = self.arg_ref(guard) value = self.get(guard.name) t = type(value) @@ -2320,7 +2869,11 @@ def TUPLE_ITERATOR_LEN(self, guard: Guard) -> None: tuple_iterator_len(value), obj_id, get_verbose_code_parts(code, guard) ) +<<<<<<< HEAD def RANGE_ITERATOR_MATCH(self, guard: Guard) -> None: +======= + def RANGE_ITERATOR_MATCH(self, guard): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ref = self.arg_ref(guard) value = self.get(guard.name) t = type(value) @@ -2339,6 +2892,7 @@ def RANGE_ITERATOR_MATCH(self, guard: Guard) -> None: ) # TODO(voz): Deduplicate w/ AOTAutograd dupe input guards +<<<<<<< HEAD def DUPLICATE_INPUT(self, guard: Guard, source_b: Source) -> None: if self.save_guards: if name := get_local_source_name(source_b): @@ -2346,6 +2900,13 @@ def DUPLICATE_INPUT(self, guard: Guard, source_b: Source) -> None: if name := get_global_source_name(source_b): self.check_fn_manager.additional_used_global_vars.add(name) +======= + def DUPLICATE_INPUT(self, guard, source_b): + if self.serialization_mode == "save": + raise torch._dynamo.exc.PackageError( + "DUPLICATE_INPUT guard cannot be serialized yet." + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ref_a = self.arg_ref(guard) ref_b = self.arg_ref(source_b.name()) @@ -2365,6 +2926,7 @@ def DUPLICATE_INPUT(self, guard: Guard, source_b: Source) -> None: code = [f"{ref_b} is {ref_a}"] self._set_guard_export_info(guard, code) +<<<<<<< HEAD if config.use_lamba_guard_for_object_aliasing: # Save the code part so that we can install a lambda guard at the # end. Read the Note - On Lambda guarding of object aliasing - to @@ -2380,6 +2942,19 @@ def DUPLICATE_INPUT(self, guard: Guard, source_b: Source) -> None: ) def WEAKREF_ALIVE(self, guard: Guard) -> None: +======= + install_object_aliasing_guard( + self.get_guard_manager(guard), + self.get_guard_manager_from_source(source_b), + get_verbose_code_parts(code, guard), + ) + + def WEAKREF_ALIVE(self, guard): + if self.serialization_mode == "save": + raise torch._dynamo.exc.PackageError( + "WEAKREF_ALIVE guard cannot be serialized." + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) code = [f"{self.arg_ref(guard)} is not None"] self._set_guard_export_info(guard, code) @@ -2387,7 +2962,11 @@ def WEAKREF_ALIVE(self, guard: Guard) -> None: get_verbose_code_parts(code, guard) ) +<<<<<<< HEAD def MAPPING_KEYS_CHECK(self, guard: Guard) -> None: +======= + def MAPPING_KEYS_CHECK(self, guard): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Guard on the key order of types.MappingProxyType object""" ref = self.arg_ref(guard) value = self.get(guard.name) @@ -2397,7 +2976,11 @@ def MAPPING_KEYS_CHECK(self, guard: Guard) -> None: self._set_guard_export_info(guard, code) self.get_guard_manager(guard).add_mapping_keys_guard(value, code) +<<<<<<< HEAD def DICT_KEYS_MATCH(self, guard: Guard) -> None: +======= + def DICT_KEYS_MATCH(self, guard): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Insert guard to check that the keys of a dict are same""" ref = self.arg_ref(guard) value = self.get(guard.name) @@ -2422,13 +3005,18 @@ def DICT_KEYS_MATCH(self, guard: Guard) -> None: else: self.guard_on_dict_keys_and_ignore_order(value, guard) +<<<<<<< HEAD def EMPTY_NN_MODULE_HOOKS_DICT(self, guard: Guard) -> None: +======= + def EMPTY_NN_MODULE_HOOKS_DICT(self, guard): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Special guard to skip guards on empty hooks. This is controlled by skip_nnmodule_hook_guards""" if config.skip_nnmodule_hook_guards: # This is unsafe if you add/remove a hook on nn module variable return self.SEQUENCE_LENGTH(guard) +<<<<<<< HEAD def GRAD_MODE(self, guard: Guard) -> None: pass # we always guard on this via GlobalStateGuard() @@ -2446,6 +3034,24 @@ def DEFAULT_DEVICE(self, guard: Guard) -> None: assert guard.source is GuardSource.GLOBAL assert self.check_fn_manager.output_graph is not None +======= + def GRAD_MODE(self, guard: Guard): + pass # we always guard on this via GlobalStateGuard() + + def DETERMINISTIC_ALGORITHMS(self, guard: Guard): + pass # we always guard on this via GlobalStateGuard() + + def TORCH_FUNCTION_STATE(self, guard: Guard): + pass # we always guard on this via GlobalStateGuard() + + def FSDP_TRAINING_STATE(self, guard: Guard): + pass # we always guard on this via GlobalStateGuard() + + def DEFAULT_DEVICE(self, guard: Guard): + """Guard on CURRENT_DEVICE per torch.utils._device""" + assert guard.source is GuardSource.GLOBAL + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) code = [ f"utils_device.CURRENT_DEVICE == {self.check_fn_manager.output_graph.current_device!r}" ] @@ -2455,6 +3061,7 @@ def DEFAULT_DEVICE(self, guard: Guard) -> None: get_verbose_code_parts(code, guard) ) +<<<<<<< HEAD def SHAPE_ENV(self, guard: Guard) -> None: from torch._dynamo.output_graph import OutputGraph @@ -2462,6 +3069,13 @@ def SHAPE_ENV(self, guard: Guard) -> None: output_graph = self.check_fn_manager.output_graph assert output_graph is not None if self.check_fn_manager.shape_code_parts is not None: +======= + def SHAPE_ENV(self, guard: Guard): + assert guard.name == "" + output_graph = self.check_fn_manager.output_graph + if self.serialization_mode == "load": + assert self.check_fn_manager.shape_code_parts is not None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shape_code_parts = self.check_fn_manager.shape_code_parts python_code_parts = shape_code_parts.python_code_parts verbose_code_parts = shape_code_parts.verbose_code_parts @@ -2473,11 +3087,18 @@ def SHAPE_ENV(self, guard: Guard) -> None: # shape variables to sources from tracked_fakes. This must happen after # tensor checks. # NB: self.output_graph can be None in the debug_nops tests +<<<<<<< HEAD assert isinstance(output_graph, OutputGraph) fs = output_graph.tracked_fakes input_contexts = [a.symbolic_context for a in fs] def get_sources(t_id: int, dim: int) -> list[Source]: +======= + fs = output_graph.tracked_fakes + input_contexts = [a.symbolic_context for a in fs] + + def get_sources(t_id, dim): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Looks up base sources mapped to a tensor id and uses them to create # sources for the corresponding tensor dimension. return [ @@ -2485,7 +3106,10 @@ def get_sources(t_id: int, dim: int) -> list[Source]: for source in output_graph.tracked_fakes_id_to_source[t_id] ] +<<<<<<< HEAD assert output_graph.shape_env is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if output_graph.export_constraints: names: dict[str, tuple[int, int]] = {} source_pairs: list[tuple[Source, Source]] = [] @@ -2494,7 +3118,11 @@ def get_sources(t_id: int, dim: int) -> list[Source]: ] = [] phantom_symbols: dict[str, Symbol] = {} relaxed_sources: set[Source] = set() +<<<<<<< HEAD for constraint in output_graph.export_constraints: # type: ignore[attr-defined] +======= + for constraint in output_graph.export_constraints: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if constraint.t_id in output_graph.tracked_fakes_id_to_source: torch.export.dynamic_shapes._process_equalities( constraint, @@ -2518,6 +3146,7 @@ def get_sources(t_id: int, dim: int) -> list[Source]: else: equalities_inputs = None +<<<<<<< HEAD def _get_code_parts(langs: tuple[str, ...]) -> list[_ShapeGuardsHelper]: return output_graph.shape_env.produce_guards_verbose( [a.fake for a in fs], # type: ignore[misc] @@ -2527,6 +3156,17 @@ def _get_code_parts(langs: tuple[str, ...]) -> list[_ShapeGuardsHelper]: source_ref=self.source_ref, # Export keeps static. ignore_static=(not output_graph.export), +======= + def _get_code_parts(langs): + return output_graph.shape_env.produce_guards_verbose( + [a.fake for a in fs], + [a.source for a in fs], + input_contexts=input_contexts, + equalities_inputs=equalities_inputs, + source_ref=self.source_ref, + # Export keeps static. + ignore_static=(not self.check_fn_manager.output_graph.export), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) langs=langs, ) @@ -2534,7 +3174,11 @@ def _get_code_parts(langs: tuple[str, ...]) -> list[_ShapeGuardsHelper]: try: # For exporting we need the python code parts python_code_parts, verbose_code_parts, cpp_code_parts = ( +<<<<<<< HEAD _get_code_parts(("python", "verbose_python", "cpp")) # type: ignore[assignment] +======= + _get_code_parts(("python", "verbose_python", "cpp")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) python_fallback = False except OverflowError: @@ -2551,10 +3195,17 @@ def _get_code_parts(langs: tuple[str, ...]) -> list[_ShapeGuardsHelper]: # When exporting, we may work with the shape constraints some more in # postprocessing, so don't freeze yet +<<<<<<< HEAD if not output_graph.export: output_graph.shape_env.freeze() if self.save_guards: +======= + if not self.check_fn_manager.output_graph.export: + output_graph.shape_env.freeze() + + if self.serialization_mode == "save": +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # For SHAPE_ENV we want to skip serializing the entire ShapeEnv so instead # we directly serialize the generated code here. maybe_cpp_code_parts = locals().get("cpp_code_parts") @@ -2695,7 +3346,11 @@ def _get_code_parts(langs: tuple[str, ...]) -> list[_ShapeGuardsHelper]: closure_vars={**SYMPY_INTERP, **_get_closure_vars()}, ) +<<<<<<< HEAD def TENSOR_MATCH(self, guard: Guard, value: Optional[Any] = None) -> None: +======= + def TENSOR_MATCH(self, guard: Guard, value=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if config._unsafe_skip_fsdp_module_guards and guard.is_fsdp_module(): return # For tensors that are part of the Dynamo extracted Fx graph module, an @@ -2748,7 +3403,10 @@ def TENSOR_MATCH(self, guard: Guard, value: Optional[Any] = None) -> None: # The list of tensor fields and calls we care about can be found in `terms` below. # TODO(voz): We are missing storage offset in all our tensor guards? code: list[str] = [] +<<<<<<< HEAD assert self.check_fn_manager.output_graph is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.check_fn_manager.output_graph.export: self.TYPE_MATCH(guard) terms = [ @@ -2779,12 +3437,16 @@ def TENSOR_MATCH(self, guard: Guard, value: Optional[Any] = None) -> None: # insert aliasing guards on them if not ( config.skip_no_tensor_aliasing_guards_on_parameters +<<<<<<< HEAD and ( istype(value, torch.nn.Parameter) or is_from_unspecialized_builtin_nn_module_source( guard.originating_source ) ) +======= + and istype(value, torch.nn.Parameter) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) and not isinstance(guard.originating_source, NumpyTensorSource): # Keep track of all the tensor guard managers to insert # NoAliasing check at the end. @@ -2800,19 +3462,28 @@ def TENSOR_MATCH(self, guard: Guard, value: Optional[Any] = None) -> None: verbose_code_parts = get_verbose_code_parts( get_tensor_guard_code_part( +<<<<<<< HEAD value, tensor_name, size, stride, pytype, dispatch_keys, +======= + value, tensor_name, size, stride, pytype, dispatch_keys +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), guard, ) guard_manager.add_tensor_match_guard( value, +<<<<<<< HEAD size, # type: ignore[arg-type] stride, # type: ignore[arg-type] +======= + size, + stride, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tensor_name, verbose_code_parts, pytype, @@ -2880,6 +3551,7 @@ def TENSOR_MATCH(self, guard: Guard, value: Optional[Any] = None) -> None: self._set_guard_export_info(guard, code) # A util that in the case of export, adds data onto guards +<<<<<<< HEAD def _set_guard_export_info( self, guard: Guard, @@ -2887,6 +3559,9 @@ def _set_guard_export_info( provided_guarded_object: Optional[Any] = None, provided_func_name: Optional[str] = None, ) -> None: +======= + def _set_guard_export_info(self, guard, code_list, provided_guarded_object=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # WARNING: It is important that cur_frame/caller do NOT stay in # the current frame, because they will keep things live longer # than they should. See TestMisc.test_release_module_memory @@ -2895,7 +3570,11 @@ def _set_guard_export_info( caller = cur_frame.f_back del cur_frame assert caller is not None +<<<<<<< HEAD func_name = provided_func_name or caller.f_code.co_name +======= + func_name = caller.f_code.co_name +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) del caller # We use func_name for export, so might as well get a nice defensive check out of it assert func_name in self.__class__.__dict__, ( @@ -2919,9 +3598,13 @@ def _set_guard_export_info( getattr(guarded_object.__class__, "__weakrefoffset__", 0) != 0 ) # See D64140537 for why we are checking for tuple. +<<<<<<< HEAD if supports_weakref and not isinstance( guarded_object, (enum.Enum, tuple, weakref.ProxyTypes) ): +======= + if supports_weakref and not isinstance(guarded_object, (enum.Enum, tuple)): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) obj_ref = weakref.ref(guarded_object) guard.set_export_info( @@ -2966,7 +3649,11 @@ class ExprCounter(ast.NodeVisitor): def __init__(self, config: PyExprCSEPass.Config) -> None: self._config = config +<<<<<<< HEAD def visit(self, node: ast.AST) -> None: +======= + def visit(self, node: ast.AST) -> Any: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(node, PyExprCSEPass.ALLOWED_NODE_TYPES): self._config.expr_count[_ast_unparse(node)] += 1 super().visit(node) @@ -3034,7 +3721,11 @@ def replace(self, expr: str) -> tuple[list[str], str]: return replacer.preface, _ast_unparse(new_node) +<<<<<<< HEAD def must_add_nn_module_guards(guard: Guard) -> bool: +======= +def must_add_nn_module_guards(guard): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # For config.guard_nn_modules=False, we can skip all the guards that # originate from inside of nn module except for a few categories. return ( @@ -3049,11 +3740,19 @@ def must_add_nn_module_guards(guard: Guard) -> bool: class DeletedGuardManagerWrapper(GuardManagerWrapper): +<<<<<<< HEAD def __init__(self, reason: str) -> None: super().__init__() self.invalidation_reason = reason def populate_diff_guard_manager(self) -> None: +======= + def __init__(self, reason): + super().__init__() + self.invalidation_reason = reason + + def populate_diff_guard_manager(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.diff_guard_root = None @@ -3072,23 +3771,33 @@ class GuardsState: shape_code_parts: Optional[ShapeCodeParts] +<<<<<<< HEAD class _Missing: pass class GuardsStatePickler(pickle.Pickler): def __init__(self, *args: Any, **kwargs: Any) -> None: +======= +class GuardsStatePickler(pickle.Pickler): + def __init__(self, *args, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__(*args, **kwargs) self.fake_mode = torch._subclasses.FakeTensorMode() self.tensor_converter = torch._subclasses.fake_tensor.FakeTensorConverter() @classmethod +<<<<<<< HEAD def _unpickle_module(cls, state: Any) -> torch.nn.Module: +======= + def _unpickle_module(cls, state): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mod = torch.nn.Module() mod.__setstate__(state) return mod @classmethod +<<<<<<< HEAD def _unpickle_tensor( cls, meta_tensor: torch.Tensor, @@ -3100,12 +3809,19 @@ def _unpickle_tensor( fake_mode = torch._subclasses.FakeTensorMode() tensor_converter = torch._subclasses.fake_tensor.FakeTensorConverter() ret = tensor_converter.from_meta_and_device( +======= + def _unpickle_tensor(cls, meta_tensor, device, pytype, dispatch_keys_raw): + fake_mode = torch._subclasses.FakeTensorMode() + tensor_converter = torch._subclasses.fake_tensor.FakeTensorConverter() + return tensor_converter.from_meta_and_device( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fake_mode, meta_tensor, device, pytype, torch._C.DispatchKeySet.from_raw_repr(dispatch_keys_raw), ) +<<<<<<< HEAD ret.grad = grad return ret @@ -3119,13 +3835,24 @@ def _unpickle_traceable_wrapper_subclass( ctx: Any, inner_data: list[tuple[str, Callable[..., Any], tuple[Any, ...]]], ) -> torch.Tensor: +======= + + @classmethod + def _unpickle_traceable_wrapper_subclass( + cls, meta_tensor, device, pytype, dispatch_keys_raw, ctx, inner_data + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Unpickle the inner tensor components. These could also be subclass instances. inner_tensors = {} for attr, unpickle_func, unpickle_func_args in inner_data: inner_tensors[attr] = unpickle_func(*unpickle_func_args) outer_size, outer_stride = meta_tensor.shape, meta_tensor.stride() +<<<<<<< HEAD out = type(meta_tensor).__tensor_unflatten__( # type: ignore[attr-defined] +======= + out = type(meta_tensor).__tensor_unflatten__( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inner_tensors, ctx, outer_size, outer_stride ) out.pytype = pytype @@ -3133,6 +3860,7 @@ def _unpickle_traceable_wrapper_subclass( return out @classmethod +<<<<<<< HEAD def _unpickle_python_module(cls, alias: str) -> types.ModuleType: return importlib.import_module(alias) @@ -3159,6 +3887,24 @@ def _unpickle_c_op(cls, name: str) -> Any: def reducer_override( self, obj: Any ) -> Union[tuple[Callable[..., Any], tuple[Any, ...]], Any]: +======= + def _unpickle_python_module(cls, alias: str): + return importlib.import_module(alias) + + @classmethod + def _unpickle_dispatch_key_set(cls, raw_repr: int): + return torch._C.DispatchKeySet.from_raw_repr(raw_repr) + + @classmethod + def _unpickle_functorch_interpreter(cls, json: bytes): + return torch._C._functorch.CInterpreter.deserialize(json) + + @classmethod + def _unpickle_mapping_proxy(cls, d): + return types.MappingProxyType(d) + + def reducer_override(self, obj): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import sympy if isinstance(obj, torch.Tensor) and obj.device.type != "meta": @@ -3186,11 +3932,18 @@ def reducer_override( ) return type(self)._unpickle_tensor, ( +<<<<<<< HEAD torch.empty_like(obj, device="meta", requires_grad=obj.requires_grad), obj.device, type(obj), torch._C._dispatch_keys(obj).raw_repr(), obj.grad, +======= + torch.empty_like(obj, device="meta"), + obj.device, + type(obj), + torch._C._dispatch_keys(obj).raw_repr(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) elif isinstance(obj, torch.nn.Module): @@ -3222,6 +3975,7 @@ def reducer_override( elif isinstance(obj, types.MappingProxyType): return type(self)._unpickle_mapping_proxy, (obj.copy(),) +<<<<<<< HEAD elif isinstance( obj, torch._ops.OpOverloadPacket ) and obj._qualified_op_name.startswith("_C::"): @@ -3243,6 +3997,8 @@ def reducer_override( assert obj.__qualname__ != obj.__name__ return _Missing, () +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if type(obj).__qualname__ != type(obj).__name__: raise torch._dynamo.exc.PackageError( f"Type {type(obj)} for object {obj} cannot be saved " @@ -3271,17 +4027,28 @@ def pickle_guards_state(state: GuardsState) -> bytes: class CheckFunctionManager: def __init__( self, +<<<<<<< HEAD f_code: types.CodeType, output_graph: OutputGraphGuardsState, cache_entry: Optional[CacheEntry] = None, +======= + f_code, + output_graph=None, + cache_entry=None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) guard_fail_fn: Optional[Callable[[GuardFail], None]] = None, guard_filter_fn: Optional[ Callable[[list[GuardFilterEntry]], list[bool]] ] = None, +<<<<<<< HEAD shape_code_parts: Optional[ShapeCodeParts] = None, runtime_global_scope: Optional[dict[str, Any]] = None, save_guards: bool = False, strict_error: bool = False, +======= + guards_serialization_mode: Optional[str] = None, + shape_code_parts: Optional[ShapeCodeParts] = None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): guards = output_graph.guards if output_graph else None self._weakrefs: dict[int, ReferenceType[object]] = {} @@ -3289,8 +4056,12 @@ def __init__( existing_diff_guard_sources = ( update_diff_guard_managers_for_existing_cache_entries(cache_entry) ) +<<<<<<< HEAD self.output_graph: Optional[OutputGraphGuardsState] = output_graph assert self.output_graph is not None +======= + self.output_graph = output_graph +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Only used for serialization. self.shape_code_parts = shape_code_parts @@ -3300,14 +4071,19 @@ def __init__( self.torch_function_mode_stack = ( output_graph.torch_function_mode_stack if output_graph else None ) +<<<<<<< HEAD self.used_builtin_vars: OrderedSet[str] = OrderedSet() self.additional_used_local_vars: OrderedSet[str] = OrderedSet() self.additional_used_global_vars: OrderedSet[str] = OrderedSet() self.runtime_global_scope = runtime_global_scope +======= + self.guards_serialization_mode = guards_serialization_mode +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not justknobs_check("pytorch/compiler:guard_nn_modules"): log.warning("guard_nn_modules is turned off using justknobs killswitch") +<<<<<<< HEAD # TODO Be more explicit about the behavior for the users. if torch._dynamo.config.caching_precompile: _guard_filter_fn = guard_filter_fn or (lambda gs: [True for g in gs]) @@ -3341,12 +4117,27 @@ def guard_filter_fn(guards: list[GuardFilterEntry]) -> list[bool]: ) def make_guard_filter_entry(guard: Guard) -> GuardFilterEntry: +======= + sorted_guards = sorted(guards or (), key=Guard.sort_key) + builder, guard_manager = self.build_guards( + sorted_guards, + existing_diff_guard_sources, + f_code, + output_graph, + None if guard_filter_fn else self.guards_serialization_mode, + ) + + if guard_filter_fn: + + def make_guard_filter_entry(guard): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MISSING = object() name = strip_local_scope(guard.name) if name == "": has_value = False value = MISSING else: +<<<<<<< HEAD try: # Guard evaluation is expected to fail when we guard on # things like "not hasattr(x, 'foo')". In cases like this, @@ -3358,14 +4149,29 @@ def make_guard_filter_entry(guard: Guard) -> GuardFilterEntry: value = MISSING has_value = False is_global = get_global_source_name(guard.originating_source) is not None +======= + has_value = True + value = builder.get(guard.name) + is_global = get_global_source_name(guard.originating_source) is not None + guard_fn = guard.create_fn + if isinstance(guard_fn, functools.partial): + guard_fn = guard.create_fn.func +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return GuardFilterEntry( name=name, has_value=has_value, value=value, +<<<<<<< HEAD guard_type=guard.create_fn_name(), derived_guard_types=( tuple(guard.guard_types) if guard.guard_types else () ), +======= + guard_type=guard_fn.__name__, + derived_guard_types=tuple(guard.guard_types) + if guard.guard_types + else (), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) is_global=is_global, orig_guard=guard, ) @@ -3378,6 +4184,7 @@ def make_guard_filter_entry(guard: Guard) -> GuardFilterEntry: sorted_guards = [ guard for i, guard in enumerate(sorted_guards) if filter_results[i] ] +<<<<<<< HEAD # Redo the guards because filtering relies on the results from the last guard builder. builder, guard_manager = self.build_guards( @@ -3387,6 +4194,16 @@ def make_guard_filter_entry(guard: Guard) -> GuardFilterEntry: output_graph, save_guards, ) +======= + # Redo the guards because filtering relies on the results from the last guard builder. + builder, guard_manager = self.build_guards( + sorted_guards, + existing_diff_guard_sources, + f_code, + output_graph, + self.guards_serialization_mode, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.guard_manager = guard_manager self.compile_check_fn(builder, sorted_guards, guard_fail_fn) @@ -3409,11 +4226,18 @@ def make_guard_filter_entry(guard: Guard) -> GuardFilterEntry: # TODO(anijain2305, ydwu4) - Skipping export because of following test # python -s test/dynamo/test_export.py -k test_export_with_symbool_inputs latency = 0.0 +<<<<<<< HEAD if not output_graph.skip_guards_check and not output_graph.export: if not self.guard_manager.check(output_graph.local_scope): reasons = get_guard_fail_reason_helper( self.guard_manager, +======= + if not output_graph.export and self.guards_serialization_mode != "load": + if not self.guard_manager.check(output_graph.local_scope): + reasons = get_guard_fail_reason_helper( + self.guard_manager, # type: ignore[arg-type] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) output_graph.local_scope, CompileContext.current_compile_id(), ) @@ -3421,7 +4245,11 @@ def make_guard_filter_entry(guard: Guard) -> GuardFilterEntry: if guard_manager_testing_hook_fn is not None: guard_manager_testing_hook_fn( +<<<<<<< HEAD self.guard_manager, output_graph.local_scope, builder +======= + self.guard_manager, output_graph.local_scope +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # NB for developers: n_iters is chosen to be 1 to prevent excessive @@ -3446,6 +4274,7 @@ def make_guard_filter_entry(guard: Guard) -> GuardFilterEntry: CompileEventLogger.increment_toplevel("guard_latency_us", int(latency)) self.guards_state: Optional[bytes] = None +<<<<<<< HEAD if save_guards: from torch._dynamo.output_graph import OutputGraph @@ -3461,6 +4290,80 @@ def make_guard_filter_entry(guard: Guard) -> GuardFilterEntry: f"Guard evaluation failed: {str(e)}", traceback=traceback.format_exc().split("\n"), ) +======= + if self.guards_serialization_mode == "save": + used_global_vars = set() + used_local_vars = set() + + def prune_variable(source): + if name := get_global_source_name(source): + assert isinstance(name, str) + used_global_vars.add(name) + elif name := get_local_source_name(source): + assert isinstance(name, str) + used_local_vars.add(name) + + output_graph_guards_state = self.output_graph.dump_guards_state() + # Only serialize the global variables that are actually used in guards. + for guard in sorted_guards: + if isinstance(guard.originating_source, ShapeEnvSource): + assert self.shape_code_parts + for source in self.shape_code_parts.shape_env_sources: + prune_variable(source) + else: + prune_variable(guard.originating_source) + + for source in self.output_graph.guard_on_key_order: + prune_variable(source) + + def normalize_create_fn(x): + if isinstance(x, functools.partial): + + def _ref(x): + if isinstance(x, (TensorWeakRef, weakref.ref)): + return x() + return x + + new_args = tuple(_ref(a) for a in x.args) + new_keywords = {k: _ref(v) for k, v in x.keywords.items()} + return functools.partial(x.func, *new_args, **new_keywords) + + return x + + output_graph_guards_state = dataclasses.replace( + output_graph_guards_state, + local_scope={ + k: v + for k, v in output_graph_guards_state.local_scope.items() + if k in used_local_vars + }, + global_scope={ + k: v + for k, v in output_graph_guards_state.global_scope.items() + if k in used_global_vars + }, + _guards=torch._guards.GuardsSet( + { + dataclasses.replace( + guard, + obj_weakref=None, + guarded_class_weakref=None, + create_fn=normalize_create_fn(guard.create_fn), + ) + for guard in sorted_guards + } + ), + input_source_to_sizes_strides=pytree.tree_map( + convert_int_to_concrete_values, + output_graph_guards_state.input_source_to_sizes_strides, + ), + ) + guards_state = GuardsState( + output_graph=output_graph_guards_state, + shape_code_parts=self.shape_code_parts, + ) + self.guards_state = pickle_guards_state(guards_state) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO: don't do the string rep, do something more structured here torch._logging.trace_structured( @@ -3478,6 +4381,7 @@ def make_guard_filter_entry(guard: Guard) -> GuardFilterEntry: self._weakrefs.clear() self.output_graph = None +<<<<<<< HEAD UNSUPPORTED_SERIALIZATION_GUARD_TYPES: tuple[LiteralString, ...] = ( "DICT_VERSION", "NN_MODULE", @@ -3618,12 +4522,26 @@ def build_guards( output_graph: OutputGraphGuardsState, save_guards: bool, ) -> tuple[GuardBuilder, GuardManagerWrapper]: +======= + def build_guards( + self, + sorted_guards, + existing_diff_guard_sources, + f_code, + output_graph, + serialization_mode=None, + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) guard_manager = GuardManagerWrapper() guard_manager.diff_guard_sources = existing_diff_guard_sources w_builder = None +<<<<<<< HEAD def source_ref(source: Source) -> str: +======= + def source_ref(source): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) guard_source = source.guard_source() if guard_source is GuardSource.CONSTANT: # No need to track constants @@ -3642,6 +4560,7 @@ def source_ref(source: Source) -> str: output_graph.global_scope, guard_manager, self, +<<<<<<< HEAD save_guards, runtime_global_scope=self.runtime_global_scope, ) @@ -3651,6 +4570,16 @@ def cleanup_builder(weak_b: weakref.ref[GuardBuilder]) -> None: b = weak_b() if b: b.scope = None # type: ignore[assignment] +======= + serialization_mode, + ) + + # Break retain cycle. See test_release_scope_memory + def cleanup_builder(weak_b): + b = weak_b() + if b: + b.scope = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Break retain cycle. See test_release_input_memory w_builder = weakref.ref(builder, cleanup_builder) @@ -3674,12 +4603,16 @@ def cleanup_builder(weak_b: weakref.ref[GuardBuilder]) -> None: guard.create(builder) return builder, guard_manager +<<<<<<< HEAD def compile_check_fn( self, builder: GuardBuilder, guards_out: list[Guard], guard_fail_fn: Optional[Callable[[GuardFail], None]], ) -> None: +======= + def compile_check_fn(self, builder, guards_out, guard_fail_fn): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # see parallel handling of ".0" / "___implicit0" in _eval_frame.c largs = builder.argnames largs += ["**___kwargs_ignored"] @@ -3690,11 +4623,15 @@ def compile_check_fn( verbose_code_parts = [] structured_guard_fns: list[Callable[[], dict[str, Any]]] = [] +<<<<<<< HEAD assert self.torch_function_mode_stack is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch_function_mode_stack_check_fn = make_torch_function_mode_stack_guard( self.torch_function_mode_stack ) +<<<<<<< HEAD # Add compile id info in the guard manager for debugging purpose self.guard_manager.root.attach_compile_id( str(CompileContext.current_compile_id()) @@ -3706,6 +4643,10 @@ def compile_check_fn( self.guard_manager.root.add_global_state_guard( global_state, ["___check_global_state()"] ) +======= + # Insert the global_state guard + self.guard_manager.root.add_global_state_guard(["___check_global_state()"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.guard_manager.root.add_torch_function_mode_stack_guard( self.torch_function_mode_stack, @@ -3714,9 +4655,13 @@ def compile_check_fn( # Clear references to torch_function modes held in the list self.torch_function_mode_stack = None +<<<<<<< HEAD def add_code_part( code_part: str, guard: Optional[Guard], log_only: bool = False ) -> None: +======= + def add_code_part(code_part, guard, log_only=False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) verbose_code_part = get_verbose_code_part(code_part, guard) guards_log.debug("%s", verbose_code_part) @@ -3779,6 +4724,7 @@ def add_code_part( ["check_no_aliasing(" + ", ".join(no_tensor_aliasing_names) + ")"], ) +<<<<<<< HEAD # Note - On Lambda guarding of object aliasing # We previously installed object‑aliasing guards as relational guards, # but that undermined the recursive‑dict guard optimization: placing the @@ -3799,6 +4745,8 @@ def add_code_part( aliasing_code_parts, aliasing_verbose_code_parts ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) aotautograd_guards: list[GuardEnvExpr] = ( self.output_graph.aotautograd_guards if self.output_graph else [] ) @@ -3854,7 +4802,12 @@ def add_code_part( "dynamo_guards", payload_fn=lambda: [f() for f in structured_guard_fns] ) +<<<<<<< HEAD if convert_frame.initial_global_state is None: +======= + global_state = convert_frame.initial_global_state + if global_state is None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # we should only hit this case in NopTests() global_state = convert_frame.GlobalStateGuard() closure_vars = { @@ -3886,7 +4839,11 @@ def add_code_part( self.guard_manager.extra_state = None self.guard_manager.no_tensor_aliasing_sources = no_tensor_aliasing_names +<<<<<<< HEAD def invalidate(self, obj_str: str) -> None: +======= + def invalidate(self, obj_str): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Some tests reveal that CheckFunctionManager has no attribute # guard_manager, but this case should not be of any concern. # This case doesn't seem easy to repro. @@ -3903,7 +4860,11 @@ def invalidate(self, obj_str: str) -> None: extra_state.invalidate(cache_entry, deleted_guard_manager) self.guard_manager = deleted_guard_manager +<<<<<<< HEAD def id_ref(self, obj: object, obj_str: str) -> int: +======= + def id_ref(self, obj, obj_str): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """add a weakref, return the id""" try: if id(obj) not in self._weakrefs: @@ -3918,14 +4879,22 @@ def id_ref(self, obj: object, obj_str: str) -> int: pass # cannot weakref bool object return id(obj) +<<<<<<< HEAD def lookup_weakrefs(self, obj: object) -> Optional[weakref.ref[object]]: +======= + def lookup_weakrefs(self, obj): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Lookup the _weakrefs created in id_ref function for ID_MATCH'd objects""" if id(obj) in self._weakrefs: return self._weakrefs[id(obj)] return None +<<<<<<< HEAD def build_guard_function(code_parts: list[str], closure_args: str) -> tuple[str, str]: +======= +def build_guard_function(code_parts, closure_args) -> tuple[str, str]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.utils import IndentedBuffer csepass = PyExprCSEPass() @@ -3934,7 +4903,10 @@ def build_guard_function(code_parts: list[str], closure_args: str) -> tuple[str, def replace(expr: str) -> tuple[list[str], str]: return csepass.replace(expr) +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) except RecursionError: # If we hit recursion limits during CSE analysis, fall back to a no-op replace function # This can happen with extremely complex guard expressions @@ -3969,21 +4941,36 @@ def replace(expr: str) -> tuple[list[str], str]: return guard_body.getvalue(), make_guard_fn.getvalue() +<<<<<<< HEAD def is_recompiles_enabled() -> bool: return torch._logging._internal.log_state.is_artifact_enabled("recompiles") def is_recompiles_verbose_enabled() -> bool: +======= +def is_recompiles_enabled(): + return torch._logging._internal.log_state.is_artifact_enabled("recompiles") + + +def is_recompiles_verbose_enabled(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return torch._logging._internal.log_state.is_artifact_enabled("recompiles_verbose") # this will only be used if cpp guards are disabled +<<<<<<< HEAD def make_torch_function_mode_stack_guard( initial_stack: list[torch.overrides.TorchFunctionMode], ) -> Callable[[], bool]: types = [type(x) for x in initial_stack] def check_torch_function_mode_stack() -> bool: +======= +def make_torch_function_mode_stack_guard(initial_stack): + types = [type(x) for x in initial_stack] + + def check_torch_function_mode_stack(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cur_stack = get_torch_function_mode_stack() if len(cur_stack) != len(types): @@ -3998,6 +4985,7 @@ def check_torch_function_mode_stack() -> bool: return check_torch_function_mode_stack +<<<<<<< HEAD Scope = TypeAliasType("Scope", dict[str, object]) @@ -4008,6 +4996,12 @@ def recompilation_reason_for_no_tensor_aliasing_guard( global_scope = dict(guard_manager.global_scope) ids_to_source = collections.defaultdict(list) for tensor_source in guard_manager.no_tensor_aliasing_sources: +======= +def recompilation_reason_for_no_tensor_aliasing_guard(guard_manager, scope): + global_scope = dict(guard_manager.global_scope) + ids_to_source = collections.defaultdict(list) + for tensor_source in guard_manager.no_tensor_aliasing_sources: # type: ignore[attr-defined] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) global_scope["__compile_source__"] = tensor_source tensor_id = id(eval(tensor_source, global_scope, scope)) ids_to_source[tensor_id].append(tensor_source) @@ -4034,17 +5028,26 @@ def strip_local_scope(s: str) -> str: def get_guard_fail_reason_helper( +<<<<<<< HEAD guard_manager: GuardManagerWrapper, f_locals: dict[str, object], compile_id: Optional[CompileId], +======= + guard_manager: GuardFn, + f_locals: dict[str, object], + compile_id: CompileId, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> str: """ Return the reason why `guard_manager` failed. Updates `guard_failures` with the generated reason. Only the first failed check of guard_manager is reported. """ +<<<<<<< HEAD assert guard_manager.global_scope is not None assert guard_manager.closure_vars is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) scope = {"L": f_locals, "G": guard_manager.global_scope["G"]} scope.update(guard_manager.closure_vars) reasons: list[str] = [] @@ -4052,7 +5055,11 @@ def get_guard_fail_reason_helper( no_tensor_aliasing_check_failed = False verbose_code_parts: list[str] = [] +<<<<<<< HEAD guard_debug_info = guard_manager.check_verbose(f_locals) +======= + guard_debug_info = guard_manager.check_verbose(f_locals) # type: ignore[attr-defined] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # For test_export_with_map_cond, the check_verbose fail even without the # C++ guard manager. We need to fix the issue to remove the comment. # assert not guard_debug_info.result @@ -4103,17 +5110,27 @@ def get_guard_fail_reason_helper( def get_guard_fail_reason( +<<<<<<< HEAD guard_manager: GuardManagerWrapper, code: types.CodeType, f_locals: dict[str, object], compile_id: CompileId, skip_logging: bool = False, +======= + guard_manager: GuardFn, + code: types.CodeType, + f_locals: dict[str, object], + compile_id: CompileId, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> str: if isinstance(guard_manager, DeletedGuardManagerWrapper): return f"{compile_id}: {guard_manager.invalidation_reason}" reason_str = get_guard_fail_reason_helper(guard_manager, f_locals, compile_id) +<<<<<<< HEAD if skip_logging: return reason_str +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) guard_failures[orig_code_map[code]].append(reason_str) try: @@ -4130,9 +5147,13 @@ def get_guard_fail_reason( def get_and_maybe_log_recompilation_reasons( +<<<<<<< HEAD cache_entry: Optional[CacheEntry], frame: DynamoFrameType, skip_logging: bool = False, +======= + cache_entry, frame: DynamoFrameType +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> list[str]: """ Return the list of guard failure reasons using cache_entry. @@ -4146,7 +5167,10 @@ def get_and_maybe_log_recompilation_reasons( cache_entry.code, frame.f_locals, cache_entry.compile_id, +<<<<<<< HEAD skip_logging, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if reason: reasons.append(reason) @@ -4154,8 +5178,11 @@ def get_and_maybe_log_recompilation_reasons( code = frame.f_code +<<<<<<< HEAD if skip_logging: return reasons +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # at least one of "recompiles" or "recompiles_verbose" is enabled do_recompiles_log = is_recompiles_enabled() or is_recompiles_verbose_enabled() @@ -4194,20 +5221,32 @@ def get_and_maybe_log_recompilation_reasons( return reasons +<<<<<<< HEAD def update_diff_guard_managers_for_existing_cache_entries( cache_entry: Optional[CacheEntry], ) -> OrderedSet[str]: +======= +def update_diff_guard_managers_for_existing_cache_entries(cache_entry): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) first_cache_entry = cache_entry # On the first pass, go through the cache entries and accumulate the diff # guard sources. Different guard managers can fail with different sources. # So, we collect all of them first. +<<<<<<< HEAD acc_diff_guard_sources: OrderedSet[str] = OrderedSet() +======= + acc_diff_guard_sources = set() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) while cache_entry is not None: acc_diff_guard_sources.update( cache_entry.guard_manager.collect_diff_guard_sources() ) +<<<<<<< HEAD cache_entry = cache_entry.next # type: ignore[assignment] +======= + cache_entry = cache_entry.next +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # On the second pass, set the diff_guard_sources for each cache line to the # accumulated value. And the re-populate the diff guard manager. @@ -4215,7 +5254,11 @@ def update_diff_guard_managers_for_existing_cache_entries( while cache_entry is not None: cache_entry.guard_manager.diff_guard_sources = acc_diff_guard_sources cache_entry.guard_manager.populate_diff_guard_manager() +<<<<<<< HEAD cache_entry = cache_entry.next # type: ignore[assignment] +======= + cache_entry = cache_entry.next +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # return the accumulated sources to set up the new cache line. return acc_diff_guard_sources @@ -4227,7 +5270,11 @@ def guard_error_hook( f_locals: dict[str, object], index: int, last: bool, +<<<<<<< HEAD ) -> None: +======= +): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) print( f"ERROR RUNNING GUARDS {code.co_name} {code.co_filename}:{code.co_firstlineno}" ) @@ -4247,7 +5294,11 @@ def guard_error_hook( set_guard_error_hook(guard_error_hook) +<<<<<<< HEAD def unique(seq: Sequence[T]) -> Generator[T, None, None]: +======= +def unique(seq): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) seen = set() for x in seq: if x not in seen: @@ -4255,9 +5306,13 @@ def unique(seq: Sequence[T]) -> Generator[T, None, None]: seen.add(x) +<<<<<<< HEAD def make_dupe_guard( obj_source: Source, dupe_source: Source ) -> Optional[functools.partial[Any]]: +======= +def make_dupe_guard(obj_source, dupe_source): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Note - we may end up in a situation where we invoke something like # def fn(x, y) # with fn(x, x) @@ -4291,7 +5346,11 @@ def make_dupe_guard( return None +<<<<<<< HEAD def install_guard(*guards: Guard, skip: int = 0) -> None: +======= +def install_guard(*guards, skip=0): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Add dynamo guards to the current tracing context. @@ -4307,7 +5366,10 @@ def install_guard(*guards: Guard, skip: int = 0) -> None: add = TracingContext.get().guards_context.dynamo_guards.add for guard in guards: assert isinstance(guard, Guard) +<<<<<<< HEAD if is_from_skip_guard_source(guard.originating_source): continue +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) add(guard, collect_debug_stack=collect_debug_stack, skip=skip + 1) diff --git a/torch/_dynamo/metrics_context.py b/torch/_dynamo/metrics_context.py index 786dc1a9d34d0..f6b3c227559bd 100644 --- a/torch/_dynamo/metrics_context.py +++ b/torch/_dynamo/metrics_context.py @@ -13,6 +13,7 @@ execution performance. """ +<<<<<<< HEAD from __future__ import annotations import heapq @@ -26,6 +27,14 @@ from collections.abc import Iterator from torch.utils._traceback import CapturedTraceback +======= +import heapq +import logging +import time +from collections.abc import Iterator +from typing import Any, Callable, Optional +from typing_extensions import TypeAlias +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) log = logging.getLogger(__name__) @@ -70,9 +79,14 @@ def __init__(self, on_exit: OnExitType): self._metrics: dict[str, Any] = {} self._start_time_ns: int = 0 self._level: int = 0 +<<<<<<< HEAD self._edits: list[tuple[CapturedTraceback, set[str]]] = [] def __enter__(self) -> Self: +======= + + def __enter__(self) -> "MetricsContext": +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Initialize metrics recording. """ @@ -120,6 +134,7 @@ def increment(self, metric: str, value: int) -> None: self._metrics[metric] = 0 self._metrics[metric] += value +<<<<<<< HEAD def _render_edits(self, pred: set[str]) -> str: return "\n\n" + "\n\n".join( "Previous Traceback:\n" + "".join(e.format()) @@ -127,6 +142,8 @@ def _render_edits(self, pred: set[str]) -> str: if k & pred ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def set(self, metric: str, value: Any, overwrite: bool = False) -> None: """ Set a metric to a given value. Raises if the metric has been assigned previously @@ -136,11 +153,16 @@ def set(self, metric: str, value: Any, overwrite: bool = False) -> None: raise RuntimeError(f"Cannot set {metric} outside of a MetricsContext") if metric in self._metrics and not overwrite: raise RuntimeError( +<<<<<<< HEAD self._render_edits({metric}) + f"\n\nRuntimeError: Metric '{metric}' has already been set in the current context " "(see above for current and previous traceback)." ) self._edits.append((CapturedTraceback.extract(skip=1), {metric})) +======= + f"Metric '{metric}' has already been set in the current context" + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._metrics[metric] = value def set_key_value(self, metric: str, key: str, value: Any) -> None: @@ -168,11 +190,16 @@ def update(self, values: dict[str, Any], overwrite: bool = False) -> None: existing = self._metrics.keys() & values.keys() if existing and not overwrite: raise RuntimeError( +<<<<<<< HEAD self._render_edits(set(values.keys())) + f"\n\nRuntimeError: Metric(s) {existing} have already been set in the current context. " "(see above for current and previous traceback)." ) self._edits.append((CapturedTraceback.extract(skip=1), set(values.keys()))) +======= + f"Metric(s) {existing} have already been set in the current context" + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._metrics.update(values) def update_outer(self, values: dict[str, Any]) -> None: diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 4cdf353da99ed..bb0eea47a5d4c 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -1,3 +1,8 @@ +<<<<<<< HEAD +======= +# mypy: allow-untyped-defs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Core graph building functionality for PyTorch's Dynamo system. This module contains the essential components for constructing and managing FX graphs during compilation: @@ -30,6 +35,7 @@ import re import sys import traceback +<<<<<<< HEAD import warnings import weakref from collections.abc import Generator, Sequence @@ -37,6 +43,11 @@ from types import CodeType from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union from typing_extensions import ParamSpec, TypeVar +======= +import weakref +from dataclasses import dataclass, field as dc_field +from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import sympy @@ -58,7 +69,10 @@ ) from torch._subclasses.fake_tensor import FakeTensor from torch._utils_internal import signpost_event +<<<<<<< HEAD from torch.export.dynamic_shapes import _ConstraintTarget +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.fx._lazy_graph_module import _make_graph_module # type: ignore[attr-defined] from torch.fx.experimental._backward_state import BackwardState from torch.fx.experimental.symbolic_shapes import ( @@ -68,7 +82,10 @@ ShapeEnv, Specialization, ) +<<<<<<< HEAD from torch.fx.node import Target +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts from torch.multiprocessing.reductions import StorageWeakRef from torch.utils._ordered_set import OrderedSet @@ -77,6 +94,7 @@ from . import config, exc, logging as torchdynamo_logging, variables from .backends.registry import CompiledFn, CompilerFn from .bytecode_transformation import ( +<<<<<<< HEAD create_binary_slice, create_call_function, create_dup_top, @@ -84,6 +102,11 @@ create_load_const, create_rot_n, create_swap, +======= + create_call_function, + create_instruction, + create_load_const, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Instruction, unique_id, ) @@ -102,9 +125,14 @@ from .graph_region_tracker import GraphRegionTracker from .guards import GuardBuilder, install_guard from .mutation_guard import is_dynamic_nn_module +<<<<<<< HEAD from .side_effects import AttributeMutationExisting, SideEffects, ValueMutationExisting from .source import ( _get_source_debug_name, +======= +from .side_effects import AttributeMutationExisting, SideEffects +from .source import ( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AttrSource, BackwardStateSource, ConstantSource, @@ -150,7 +178,11 @@ ) from .variables.ctx_manager import ContextWrappingVariable from .variables.lists import BaseListVariable +<<<<<<< HEAD from .variables.misc import NullVariable +======= +from .variables.misc import CellVariable, NullVariable +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .variables.nn_module import NNModuleVariable from .variables.tensor import ( NumpyNdarrayVariable, @@ -159,6 +191,7 @@ UnspecializedPythonVariable, ) from .variables.torch_function import TensorWithTFOverrideVariable +<<<<<<< HEAD from .variables.user_defined import UserDefinedDictVariable @@ -166,6 +199,14 @@ from torch._dynamo.package import CompilePackage from torch._dynamo.symbolic_convert import InstructionTranslatorBase +======= + + +if TYPE_CHECKING: + from torch._dynamo.symbolic_convert import InstructionTranslatorBase + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) log = logging.getLogger(__name__) graph_tabular_log = torch._logging.getArtifactLogger(__name__, "graph") graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code") @@ -197,31 +238,54 @@ class MutationInfo: class VariableTrackerCache: +<<<<<<< HEAD def __init__(self) -> None: self.cache: dict[VariableTrackerCacheKey, VariableTracker] = {} def lookup(self, value: Any, source: Source) -> Optional[VariableTracker]: +======= + def __init__(self): + self.cache = {} + + def lookup(self, value, source): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) key = VariableTrackerCacheKey(id(value), source) if key not in self.cache: return None return self.cache[key] +<<<<<<< HEAD def add(self, value: Any, source: Source, vt: VariableTracker) -> None: key = VariableTrackerCacheKey(id(value), source) self.cache[key] = vt def clone(self) -> "VariableTrackerCache": +======= + def add(self, value, source, vt): + key = VariableTrackerCacheKey(id(value), source) + self.cache[key] = vt + + def clone(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Needed for copy and restore graph state new_cache = VariableTrackerCache() new_cache.cache.update(self.cache) return new_cache +<<<<<<< HEAD def clear(self) -> None: +======= + def clear(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.cache.clear() @functools.cache +<<<<<<< HEAD def _step_logger() -> Any: +======= +def _step_logger(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return torchdynamo_logging.get_step_logger(log) @@ -232,16 +296,28 @@ class GraphCompileReason: reason: str user_stack: list[traceback.FrameSummary] +<<<<<<< HEAD # Indicates if this was a graph break reason due to graph break. graph_break: bool = True def __post_init__(self) -> None: +======= + # Indicates if this was a graph compile reason due to graph break. + graph_break: bool = True + + def __post_init__(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.graph_break: graph_break_reasons.append(self) +<<<<<<< HEAD def _get_gen_rand_values_fn(random_calls: Any) -> Callable[[], list[Any]]: def _gen_rand_values() -> list[Any]: +======= +def _get_gen_rand_values_fn(random_calls): + def _gen_rand_values(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return [fn(*args, **kwargs) for fn, args, kwargs in random_calls] return _gen_rand_values @@ -258,18 +334,29 @@ def __init__(self, nn_modules: dict[str, torch.nn.Module]): def __repr__(self) -> str: return "FakeRootModule(...)" +<<<<<<< HEAD def add_nn_modules(self, nn_modules: dict[str, torch.nn.Module]) -> None: +======= + def add_nn_modules(self, nn_modules: dict[str, torch.nn.Module]): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for k, v in nn_modules.items(): setattr(self, k, v) class WrapperBackend: +<<<<<<< HEAD def __init__(self, backend: CompilerFn) -> None: self.backend: CompilerFn = backend def __call__( self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor] ) -> CompiledFn: +======= + def __init__(self, backend: CompilerFn): + self.backend: CompilerFn = backend + + def __call__(self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.restore = checkpoint_params(gm) self.gm = gm copy_gm = copy.deepcopy(self.gm) @@ -322,6 +409,7 @@ class OutputGraphGuardsState: dual_level: int functorch_layers: list[torch._functorch.pyfunctorch.FuncTorchInterpreter] current_device: Optional[torch.device] +<<<<<<< HEAD global_state_guard: torch._C._dynamo.guards.GlobalStateGuard _guards: torch._guards.GuardsSet _aotautograd_guards: list[torch._guards.GuardEnvExpr] @@ -343,6 +431,25 @@ def guards(self) -> torch._guards.GuardsSet: @property def aotautograd_guards(self) -> list[torch._guards.GuardEnvExpr]: +======= + + export: bool = False + export_constraints: bool = False + + _guards: Optional[torch._guards.GuardsSet] = None + _aotautograd_guards: Optional[list[torch._guards.GuardEnvExpr]] = None + + @property + def shape_env(self): + raise AssertionError(f"shape_env shouldn't be accessed from {type(self)}") + + @property + def guards(self): + return self._guards + + @property + def aotautograd_guards(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self._aotautograd_guards @@ -352,10 +459,13 @@ class StackLocalsMetadata: Stores metadata for a frame's stack and locals for the purposes of building resume functions """ +<<<<<<< HEAD num_stack: int = 0 # number of stack elements, minus removed NULLs locals_names: dict[str, int] = dc_field( default_factory=dict ) # order of locals codegen'd to the stack +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) stack_null_idxes: list[int] = dc_field(default_factory=list) locals_null_keys: list[str] = dc_field(default_factory=list) stack_ctx_args: list[tuple[int, tuple[Any, ...]]] = dc_field(default_factory=list) @@ -363,6 +473,7 @@ class StackLocalsMetadata: locals_ctx_args: list[tuple[str, tuple[Any, ...]]] = dc_field(default_factory=list) +<<<<<<< HEAD # TODO we should expand this to make it work for atribtrary in/out @dataclass class ExportMetaData: @@ -401,6 +512,8 @@ def get_builtins_dict(global_scope: Scope) -> dict[str, Any]: return f_builtins +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class OutputGraph(OutputGraphGuardsState): """ Wrapper class to hold outputs of InstructionTranslator. Mainly the @@ -418,6 +531,7 @@ def __init__( self, code_options: dict[str, Any], compiler_fn: Optional[CompilerFn], +<<<<<<< HEAD root_tx: "InstructionTranslatorBase", export: bool, export_constraints: Sequence[_ConstraintTarget], @@ -428,6 +542,18 @@ def __init__( torch_function_mode_stack: list[torch.overrides.TorchFunctionMode], package: Optional["CompilePackage"], ) -> None: +======= + root_tx, + export: bool, + export_constraints, + frame_state, + local_scope: Scope, + global_scope: Scope, + f_code, + torch_function_mode_stack, + package, + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__( local_scope, global_scope, @@ -437,19 +563,26 @@ def __init__( dual_level=torch.autograd.forward_ad._current_level, functorch_layers=torch._functorch.pyfunctorch.retrieve_all_functorch_interpreters(), current_device=torch.utils._device.CURRENT_DEVICE, +<<<<<<< HEAD # initial_global_state is only None during NopTest. global_state_guard=torch._dynamo.convert_frame.initial_global_state or torch._C._dynamo.guards.GlobalStateGuard(), # These are set by @property instead, just initialize them as blank _guards=torch._guards.GuardsSet(), _aotautograd_guards=[], +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) self.tracers = [SubgraphTracer(self, is_export=export)] # Map from graph input's `Source` to its `VariableTracker` to # de-duplicate graph inputs by source and reuse the tracker self.input_source_to_var: dict[Source, VariableTracker] = {} self.export = export +<<<<<<< HEAD self.export_constraints = export_constraints # type: ignore[assignment] +======= + self.export_constraints = export_constraints +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.frame_state = frame_state self.cleanup_hooks: list[Callable[[], Any]] = [] # compile_id is an id number for the current torch.compile @@ -486,6 +619,10 @@ def __init__( allow_scalar_outputs=config.capture_scalar_outputs, allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops, prefer_deferred_runtime_asserts_over_guards=config.prefer_deferred_runtime_asserts_over_guards, +<<<<<<< HEAD +======= + allow_complex_guards_as_runtime_asserts=config.allow_complex_guards_as_runtime_asserts, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) co_fields=self.co_fields, ) @@ -613,12 +750,16 @@ def __init__( self.maybe_install_saved_tensors_hooks_subgraphs() ) +<<<<<<< HEAD # mangled alias -> module fqn name self.import_sources: dict[str, str] = {} self.export_metadata = ExportMetaData() def mark_bytecode_tracing_start(self) -> None: +======= + def mark_bytecode_tracing_start(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.compiler_trace_stack.enter_context( dynamo_timed( "bytecode_tracing", @@ -626,6 +767,7 @@ def mark_bytecode_tracing_start(self) -> None: ) ) +<<<<<<< HEAD def mark_bytecode_tracing_stop(self) -> None: self.compiler_trace_stack.close() @@ -636,12 +778,41 @@ def install_builtins_dict_in_fglobals(self) -> str: def add_backward_state_hook( self, hook: VariableTracker, prefix: str = "hook" ) -> tuple[str, torch.fx.Proxy]: +======= + def mark_bytecode_tracing_stop(self): + self.compiler_trace_stack.close() + + def install_builtins_dict_in_fglobals(self): + # f_globals["__builtins__"] can be a dict or a module. This is an + # implementation detail - + # https://docs.python.org/3/library/builtins.html. + + # This makes guarding on any builtin messy because the guard check_fn + # has to check if the __builtins__ is a module or dict, and then access + # by either using getattr or getitem respectively. + + # To solve this problem, we insert a new entry in f_globals which points + # to the builtins __dict__ and then we guard any builtin on this dict. + # To avoid any collision with the pre-existing keys, we use the + # install_global to give us a unique dict key. + + f_builtins = self.global_scope["__builtins__"] + if not isinstance(f_builtins, dict): + f_builtins = f_builtins.__dict__ + return self.install_global("__builtins_dict__", f_builtins) + + def add_backward_state_hook(self, hook: VariableTracker, prefix="hook"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) name = f"{prefix}{len(self.backward_state)}" assert name not in self.backward_state self.backward_state[name] = hook return name, self.get_backward_state_proxy() +<<<<<<< HEAD def get_backward_state_proxy(self) -> torch.fx.Proxy: +======= + def get_backward_state_proxy(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.backward_state_proxy is None: if self.export: unimplemented_v2( @@ -662,7 +833,11 @@ def get_backward_state_proxy(self) -> torch.fx.Proxy: return self.backward_state_proxy # This gets its own helper function so guards DEBUG logs are more informative +<<<<<<< HEAD def init_ambient_guards(self) -> None: +======= + def init_ambient_guards(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Register a SHAPE_ENV guard to make sure we setup shape guards # that show up in ShapeEnv self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV)) @@ -720,8 +895,12 @@ def maybe_install_saved_tensors_hooks_subgraphs(self) -> Optional[list[str]]: assert unpack_subgraph_name == "saved_tensors_hooks_unpack_0" return [pack_subgraph_name, unpack_subgraph_name] +<<<<<<< HEAD def dump_guards_state(self) -> OutputGraphGuardsState: # Dump a serializable version of self without extras +======= + def dump_guards_state(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return OutputGraphGuardsState( local_scope=self.local_scope, global_scope=self.global_scope, @@ -731,18 +910,27 @@ def dump_guards_state(self) -> OutputGraphGuardsState: dual_level=self.dual_level, functorch_layers=self.functorch_layers, current_device=self.current_device, +<<<<<<< HEAD global_state_guard=self.global_state_guard, name_of_builtins_dict_key_in_fglobals=self.name_of_builtins_dict_key_in_fglobals, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) export=self.export, export_constraints=self.export_constraints, _guards=self.guards, _aotautograd_guards=self.aotautograd_guards, +<<<<<<< HEAD skip_guards_check=self.skip_guards_check, ) def synthetic_graph_input( self, fn: Callable[..., Any], args: tuple[Any, ...] ) -> VariableTracker: +======= + ) + + def synthetic_graph_input(self, fn, args): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ call fn(*args) before the graph runs and turn the result into a fake input. """ @@ -768,15 +956,23 @@ def synthetic_graph_input( ) return result +<<<<<<< HEAD def add_cleanup_hook(self, fn: Callable[[], Any]) -> None: self.cleanup_hooks.append(fn) def call_cleanup_hooks(self) -> None: +======= + def add_cleanup_hook(self, fn: Callable[[], Any]): + self.cleanup_hooks.append(fn) + + def call_cleanup_hooks(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for hook in reversed(self.cleanup_hooks): hook() self.cleanup_hooks.clear() @property +<<<<<<< HEAD def root_tracer(self) -> "SubgraphTracer": return self.tracers[0] @@ -785,15 +981,30 @@ def current_tracer(self) -> "SubgraphTracer": return self.tracers[-1] def is_root_tracer(self) -> bool: +======= + def root_tracer(self): + return self.tracers[0] + + @property + def current_tracer(self): + return self.tracers[-1] + + def is_root_tracer(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Helper to tell if we are inside the higher order operator tracing. return len(self.tracers) == 1 @property +<<<<<<< HEAD def graph(self) -> torch.fx.Graph: +======= + def graph(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.current_tracer.graph # TODO(rzou): can delete after we refactor speculate_subgraph to use nested GraphTracer. @graph.setter +<<<<<<< HEAD def graph(self, value: torch.fx.Graph) -> None: self.current_tracer.graph = value @@ -807,6 +1018,21 @@ def real_value_cache(self) -> dict[fx.Node, torch.Tensor]: @property def bound_symbols(self) -> dict[sympy.Symbol, Union[torch.fx.Proxy, "LazyProxy"]]: +======= + def graph(self, value): + self.current_tracer.graph = value + + @property + def input_name_to_proxy(self): + return self.current_tracer.input_name_to_proxy + + @property + def real_value_cache(self): + return self.current_tracer.real_value_cache + + @property + def bound_symbols(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.current_tracer.bound_symbols # If you are here, and you're looking for create_graph_input, @@ -815,6 +1041,7 @@ def bound_symbols(self) -> dict[sympy.Symbol, Union[torch.fx.Proxy, "LazyProxy"] # - self.root_tracer.create_graph_input # See NOTE [HigherOrderOperator tracing design] for more context. +<<<<<<< HEAD def create_proxy(self, *args: Any, **kwargs: Any) -> torch.fx.Proxy: return self.current_tracer.create_proxy(*args, **kwargs) @@ -828,6 +1055,19 @@ def remove_node(self, *args: Any, **kwargs: Any) -> None: def subtracer( self, source_target: Optional[Target], prior_tracer: "SubgraphTracer" ) -> Generator[fx.Tracer, None, None]: +======= + def create_proxy(self, *args, **kwargs): + return self.current_tracer.create_proxy(*args, **kwargs) + + def create_node(self, *args, **kwargs): + return self.current_tracer.create_node(*args, **kwargs) + + def remove_node(self, *args, **kwargs): + return self.current_tracer.remove_node(*args, **kwargs) + + @contextlib.contextmanager + def subtracer(self, source_target, prior_tracer): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) new_scope_ctx = enter_new_scope() try: if prior_tracer: @@ -851,6 +1091,7 @@ def subtracer( self.tracers.pop() @property +<<<<<<< HEAD def output(self) -> "OutputGraph": return self @@ -863,6 +1104,17 @@ def fake_mode(self) -> torch._subclasses.FakeTensorMode: def shape_env(self) -> ShapeEnv: assert self.tracing_context.fake_mode is not None assert self.tracing_context.fake_mode.shape_env is not None +======= + def output(self): + return self + + @property + def fake_mode(self): + return self.tracing_context.fake_mode + + @property + def shape_env(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.tracing_context.fake_mode.shape_env @property @@ -874,12 +1126,19 @@ def nn_modules(self) -> dict[str, Any]: return self.tracing_context.module_context.nn_modules @property +<<<<<<< HEAD def aotautograd_guards(self) -> list[torch._guards.GuardEnvExpr]: return self.tracing_context.guards_context.aotautograd_guards def save_global_state( self, out: Optional[dict[str, tuple[Callable[..., Any], bool]]] = None ) -> None: +======= + def aotautograd_guards(self): + return self.tracing_context.guards_context.aotautograd_guards + + def save_global_state(self, out=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Saves to out if it is provided. Else saves to the tracing context's global_state. """ @@ -915,6 +1174,7 @@ def save_global_state( torch.is_autocast_cache_enabled(), ) +<<<<<<< HEAD def push_tx(self, tx: "InstructionTranslatorBase") -> None: self._current_tx.append(tx) @@ -935,6 +1195,25 @@ def has_outputs(self) -> bool: return len([x for x in self.graph.nodes if x.op == "output"]) > 0 def get_submodule(self, keys: str) -> Union[torch.nn.Module, Any]: +======= + def push_tx(self, tx): + self._current_tx.append(tx) + + def pop_tx(self): + return self._current_tx.pop() + + @property + def current_tx(self): + return self.root_tx if not self._current_tx else self._current_tx[-1] + + def count_calls(self): + return count_calls(self.graph) + + def is_empty_graph(self): + return len(list(self.graph.nodes)) == 0 + + def get_submodule(self, keys): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert keys obj: Union[torch.nn.Module, dict[str, torch.nn.Module]] = self.nn_modules for k in keys.split("."): @@ -944,7 +1223,11 @@ def get_submodule(self, keys: str) -> Union[torch.nn.Module, Any]: obj = getattr(obj, k) return obj +<<<<<<< HEAD def new_var(self, name: str = "tmp") -> str: +======= + def new_var(self, name="tmp"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) existing = set(self.code_options["co_varnames"]) # In common case, this will be O(1) while True: @@ -953,13 +1236,21 @@ def new_var(self, name: str = "tmp") -> str: self.code_options["co_varnames"] += (var,) return var +<<<<<<< HEAD def update_co_names(self, name: str) -> None: +======= + def update_co_names(self, name): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Ensure self.code_options.co_names contains name""" if name not in self.code_options["co_names"]: self.code_options["co_names"] += (name,) @staticmethod +<<<<<<< HEAD def module_key_name(*names: Any) -> str: +======= + def module_key_name(*names): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # create a new unique name name = "_".join(map(str, names)) # Strip the guard lookup L/G access @@ -988,9 +1279,15 @@ def register_static_attr_and_return_proxy( def register_attr_or_module( self, target: Union[torch.nn.Module, torch.Tensor, Any], +<<<<<<< HEAD *names: Any, **options: Any, ) -> VariableTracker: +======= + *names, + **options, + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if is_dynamic_nn_module(target, self.export): # Instead of returning UnspecializedNNModuleVariable, call # VariableTracker.build so that it is tracked for mutation. @@ -1017,13 +1314,20 @@ def register_attr_or_module( # are registered as get_attr nodes in the root graph. tracer = self.root_tracer +<<<<<<< HEAD def wrap_name(module_key: str) -> VariableTracker: +======= + def wrap_name(module_key): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert self.param_name_to_source is not None self.param_name_to_source[module_key] = source # Check if the attr has already been registered. This can happen # when two different sources point to the same tensor. +<<<<<<< HEAD assert self.root_tx is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if target in self.root_tx.output.side_effects: return self.root_tx.output.side_effects[target] @@ -1045,8 +1349,13 @@ def wrap_name(module_key: str) -> VariableTracker: # different sources pointing to the same tensor object. vt = self.root_tx.output.side_effects.track_object_existing(target, vt) +<<<<<<< HEAD assert "tensor_dict" not in vt.as_proxy().node.meta vt.as_proxy().node.meta["tensor_dict"] = _extract_tensor_dict(target) +======= + assert "tensor_dict" not in vt.proxy.node.meta + vt.proxy.node.meta["tensor_dict"] = _extract_tensor_dict(target) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return vt @@ -1056,7 +1365,11 @@ def wrap_name(module_key: str) -> VariableTracker: if source: install_guard(source.make_guard(GuardBuilder.NN_MODULE)) +<<<<<<< HEAD def wrap_name(module_key: str) -> VariableTracker: +======= + def wrap_name(module_key): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return NNModuleVariable(type(target), module_key, target, **options) else: @@ -1064,7 +1377,11 @@ def wrap_name(module_key: str) -> VariableTracker: # from higher order ops. NNModuleVariable tracker can't be # sourceless, so let's return a unspecializedNNModule variable # tracker. +<<<<<<< HEAD def wrap_name(module_key: str) -> VariableTracker: +======= + def wrap_name(module_key): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return variables.UnspecializedNNModuleVariable(target, **options) elif isinstance(target, (torch.SymInt, torch.SymFloat)): @@ -1075,7 +1392,11 @@ def wrap_name(module_key: str) -> VariableTracker: # own storage # alas, this is like this for now +<<<<<<< HEAD def wrap_name(module_key: str) -> VariableTracker: +======= + def wrap_name(module_key): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return SymNodeVariable.create( self, self.create_proxy("get_attr", module_key, (), {}), @@ -1086,7 +1407,11 @@ def wrap_name(module_key: str) -> VariableTracker: # HACKY CODE REGION END else: +<<<<<<< HEAD def wrap_name(module_key: str) -> VariableTracker: +======= + def wrap_name(module_key): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.output.update_co_names(module_key) self.global_scope[module_key] = target return VariableTracker.build( @@ -1105,7 +1430,11 @@ def wrap_name(module_key: str) -> VariableTracker: self.nn_modules[name] = target if isinstance(target, torch.nn.Module): +<<<<<<< HEAD def register_leaf_name(leaf_name: str) -> None: +======= + def register_leaf_name(leaf_name): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert self.param_name_to_source is not None new_source = ParamBufferSource(source, leaf_name) new_name = f"{name}.{leaf_name}" @@ -1126,9 +1455,13 @@ def register_leaf_name(leaf_name: str) -> None: return wrap_name(name) +<<<<<<< HEAD def handle_aliases_for_stolen_lists( self, tx: "InstructionTranslatorBase" ) -> tuple[list[Instruction], dict[Source, Source]]: +======= + def handle_aliases_for_stolen_lists(self, tx): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # If list inputs are stolen, but still needed after the function call, create aliases to keep them alive maybe_gm = self.local_scope.get("self") stolen_list_names = get_locals_to_steal(maybe_gm) @@ -1214,9 +1547,13 @@ def handle_aliases_for_stolen_lists( # other parts of Dynamo like guards. return alias_insts, overridden_sources +<<<<<<< HEAD def _get_stack_values_to_restore( self, tx: "InstructionTranslatorBase", stack_pops: int ) -> tuple[list[VariableTracker], StackLocalsMetadata]: +======= + def _get_stack_values_to_restore(self, tx, stack_pops): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Gets the stack + locals values belonging to tx that need to be restored. @@ -1228,6 +1565,10 @@ def _get_stack_values_to_restore( Returns: - stack_values: stack and locals values that need to be restored +<<<<<<< HEAD +======= + - restore_vars: names of locals corresponding to the locals part of `stack_values` +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - meta: locations of NULLs and ContextWrappingVariables in the stack/locals (ignores the top `stack_pops` values on the stack) """ @@ -1256,10 +1597,16 @@ def _get_stack_values_to_restore( meta.stack_ctx_args.append((len(stack_values) - 1, target_values)) meta.stack_ctx_idxes_orig.append(i) +<<<<<<< HEAD meta.num_stack = len(stack_values) cell_and_freevars = set(tx.cellvars() + tx.freevars()) +======= + # Add all the local vars to the "stack" so restore at the end + restore_vars: list[str] = [] + val_to_names: dict[VariableTracker, list[str]] = {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # NB: Typically (i.e., for graph compile from RETURN_VALUE), # symbolic_locals will be empty at this point, as prune_dead_locals # will clear out all of symbolic_locals because RETURN_VALUE is the @@ -1274,6 +1621,7 @@ def _get_stack_values_to_restore( # This will in turn result in spurious variables showing up in the graph. # This was very tricky to debug. For an example, dump the graph at call_user_compiler # while running test_subgraphs.py +<<<<<<< HEAD # Do not include top-frame unmodified locals here - otherwise, the compiled graph may # erroneously include them as part of the return. We manually codegen them afterward. if ( @@ -1288,6 +1636,14 @@ def _get_stack_values_to_restore( # Do not load variable if it is NULL. if sys.version_info >= (3, 12): # NOTE: do not use isinstance, since it realizes lazy VT's +======= + if isinstance(v.source, LocalSource) and v.source.local_name == k: + continue # no need to restore initial state + if isinstance(v, CellVariable) and v.local_name == k: + continue # no need to restore initial state + # Do not load variable if it is NULL. + if sys.version_info >= (3, 12): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Continuation function will load the NULL for v. if type.__instancecheck__(NullVariable, v): meta.locals_null_keys.append(k) @@ -1295,23 +1651,43 @@ def _get_stack_values_to_restore( else: # A variable should never be NULL in < 3.12 assert not type.__instancecheck__(NullVariable, v) +<<<<<<< HEAD meta.locals_names[k] = len(meta.locals_names) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(v, ContextWrappingVariable): target_values = ( () if v.target_values is None else tuple(v.target_values) ) meta.locals_ctx_args.append((k, target_values)) +<<<<<<< HEAD stack_values.append(v) return stack_values, meta +======= + if v not in val_to_names: + val_to_names[v] = [] + val_to_names[v].append(k) + for v in val_to_names.keys(): + restore_vars.extend(val_to_names[v]) + stack_values.extend([v] * len(val_to_names[v])) + + return stack_values, restore_vars, meta +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def compile_subgraph( self, tx: "InstructionTranslatorBase", reason: GraphCompileReason, +<<<<<<< HEAD partial_convert: bool = False, stack_pops: int = 0, ) -> list[StackLocalsMetadata]: +======= + partial_convert=False, + stack_pops=0, + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Compiles the current subgraph, with inputs w.r.t. self.root_tx, and codegens: - Call the compiled subgraph @@ -1329,9 +1705,15 @@ def compile_subgraph( assert self.root_tx is not None +<<<<<<< HEAD if not config.nested_graph_breaks: # expect to only compile 1 frame assert self.root_tx is tx +======= + # FIXME temporary assert to make sure we're not accidentally compiling nested graph breaks + # before we're done the full implementation + assert self.root_tx is tx +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # bytecode tracing has finished. Pop the context manager for dynamo_timed self.mark_bytecode_tracing_stop() @@ -1345,16 +1727,29 @@ def compile_subgraph( # prefix instructions (Python 3.11+) prefix_insts: list[Instruction] = [] if sys.version_info >= (3, 11): +<<<<<<< HEAD for inst in self.root_tx.prefix_insts: if inst.opname == "COPY_FREE_VARS": prefix_insts.append( create_instruction( "COPY_FREE_VARS", arg=len(self.root_tx.code_options["co_freevars"]), +======= + for inst in tx.prefix_insts: + if inst.opname == "MAKE_CELL": + prefix_insts.append( + create_instruction("MAKE_CELL", argval=inst.argval) + ) + elif inst.opname == "COPY_FREE_VARS": + prefix_insts.append( + create_instruction( + "COPY_FREE_VARS", arg=len(tx.code_options["co_freevars"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) ) else: prefix_insts.append(copy.copy(inst)) +<<<<<<< HEAD # stack values and restore vars for each frame are pushed in reverse order # i.e. last element corresponds to root frame (1), @@ -1381,6 +1776,8 @@ def compile_subgraph( # "Garbage collect the heap". self.side_effects.prune_dead_object_new(tx) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.add_output_instructions(prefix_insts) assert not (self.pregraph_bytecode and self.export), ( @@ -1393,8 +1790,37 @@ def compile_subgraph( ) self.add_output_instructions(alias_insts) +<<<<<<< HEAD + self.cleanup_graph() + +======= + # Exit from all context manager variables to make sure global state is restored + for block in reversed(self.root_tx.block_stack): + block.exit(self.root_tx, is_graph_break=reason.graph_break) + self.cleanup_graph() + # stack values and restore vars for each frame are pushed in reverse order + # i.e. last element corresponds to root frame, first element corresponds to current frame + all_stack_values = [] + all_restore_vars = [] + all_stack_locals_metas = [] + cur_tx: Optional[InstructionTranslatorBase] = tx + while True: + assert cur_tx is not None + # this should have been checked by the caller + assert all(block.can_restore() for block in cur_tx.block_stack) + stack_values, restore_vars, meta = self._get_stack_values_to_restore( + cur_tx, stack_pops + ) + all_stack_values.append(stack_values) + all_restore_vars.append(restore_vars) + all_stack_locals_metas.append(meta) + if cur_tx is self.root_tx: + break + cur_tx = tx.parent + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Use nn.Module "proxies" in the constructed GraphModule so that # the resulting GM does not hold additional strong references to the original modules. # This prevents a strong ref cycle where Dynamo created code holds on to references @@ -1429,6 +1855,7 @@ def compile_subgraph( ) self.add_output_instructions(random_calls_instructions) +<<<<<<< HEAD # Codegen stack convention before the unsupported instruction # NOTE: in these comment blocks, "locals" EXCLUDE free and cell vars. # NOTE: stack and locals must be codegen'd BEFORE the unsupported instruction, since the latter @@ -1453,6 +1880,15 @@ def compile_subgraph( if ( self.root_tx is tx # single frame and stack_values_flat +======= + # call compiled fx graph + graph_output_var = None + stored_graph_output_var = False + root_stack_values = all_stack_values[-1] + if ( + self.root_tx is tx + and root_stack_values +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) and all( not isinstance( v, @@ -1463,10 +1899,17 @@ def compile_subgraph( ), ) and not (isinstance(v, SymNodeVariable) and v.python_type() is float) +<<<<<<< HEAD for v in stack_values_flat ) and all(isinstance(x, TensorVariable) for x in stack_values_flat) and len(set(stack_values_flat)) == len(stack_values_flat) +======= + for v in root_stack_values + ) + and all(isinstance(x, TensorVariable) for x in root_stack_values) + and len(set(root_stack_values)) == len(root_stack_values) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) and self.side_effects.is_empty() and not tx.debug_locals and not self.backward_state @@ -1475,6 +1918,7 @@ def compile_subgraph( ): # optimization to generate better code in a common case self.add_output_instructions( +<<<<<<< HEAD [ # load in reverse since UNPACK_SEQUENCE will reverse *self.compile_and_call_fx_graph( @@ -1488,6 +1932,19 @@ def compile_subgraph( graph_output_var = self.new_var("graph_out") # load stack values in a flat manner - we will codegen bytecode to place them correctly # according to our convention above +======= + self.compile_and_call_fx_graph( + tx, list(reversed(root_stack_values)), root + ) + + [create_instruction("UNPACK_SEQUENCE", arg=len(root_stack_values))] + ) + else: + graph_output_var = self.new_var("graph_out") + # load stack values in a flat manner for now - will likely change later. + stack_values_flat = [ + val for vals in reversed(all_stack_values) for val in vals + ] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pass1 = PyCodegen( self.root_tx, root, @@ -1514,6 +1971,7 @@ def compile_subgraph( ) self.codegen_suffix(tx, stack_values_flat, pass2) +<<<<<<< HEAD if ( torch._dynamo.config.log_graph_in_out_metadata and stack_values_flat @@ -1562,6 +2020,8 @@ def compile_subgraph( self.export_metadata.out_spec = out_spec.as_python_constant() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) output = [] if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0: output.extend( @@ -1579,6 +2039,7 @@ def compile_subgraph( self.run_compiler_collective() self.add_output_instructions(output + pass2.get_instructions()) +<<<<<<< HEAD # store all stack and locals for each frame # current state of the stack: # *(frame N stack), *(frame N locals), @@ -1721,6 +2182,28 @@ def codegen_suffix( stack_values: list[VariableTracker], cg: PyCodegen, ) -> None: +======= + # restore all the live local vars of the root + local_restore_cg = PyCodegen( + self.root_tx, overridden_sources=overridden_sources + ) + # TODO this local restoration should be removed when fully implementing nested graph breaks + self.add_output_instructions( + [ + local_restore_cg.create_store(var) + for var in reversed(all_restore_vars[-1]) + ] + ) + + if graph_output_var and stored_graph_output_var: + self.add_output_instructions( + [local_restore_cg.create_delete(graph_output_var)] + ) + + return all_stack_locals_metas + + def codegen_suffix(self, tx, stack_values, cg): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # NOTE: `codegen_save_tempvars` must run first to update `source` fields # for variables with `AttributeMutationNew`, as they don't implement # `reconstruct` themselves. @@ -1729,7 +2212,10 @@ def codegen_suffix( assert not self.export for name, val in self.backward_state.items(): cg(val) +<<<<<<< HEAD assert self.backward_state_var is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cg.append_output(cg.create_load(self.backward_state_var)) cg.store_attr(name) self.side_effects.codegen_hooks(cg) @@ -1745,7 +2231,11 @@ def codegen_suffix( cg.restore_stack(stack_values, value_from_source=not tx.export) self.side_effects.codegen_update_mutated(cg) +<<<<<<< HEAD def cleanup_graph(self) -> None: +======= + def cleanup_graph(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Remove "creation_timestamp" from node meta @@ -1775,6 +2265,7 @@ def cleanup_graph(self) -> None: self.graph.erase_node(node1) self.graph.erase_node(node2) +<<<<<<< HEAD def bypass_package(self, reason: str = "", **kwargs: Any) -> None: """ Do not save this output graph to the CompilePackage @@ -1803,6 +2294,10 @@ def bypass_package(self, reason: str = "", **kwargs: Any) -> None: def get_graph_sizes_structured(self) -> dict[str, list[Union[int, str]]]: ret: dict[str, list[Union[int, str]]] = {} +======= + def get_graph_sizes_structured(self): + ret = {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for node in self.graph.nodes: example_value = node.meta.get("example_value", None) if isinstance(example_value, torch._subclasses.FakeTensor): @@ -1810,7 +2305,11 @@ def get_graph_sizes_structured(self) -> dict[str, list[Union[int, str]]]: ret[node.name] = [s if isinstance(s, int) else repr(s) for s in size] return ret +<<<<<<< HEAD def get_graph_sizes(self, name: str) -> str: +======= + def get_graph_sizes(self, name: str): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) graph_sizes_str = "TRACED GRAPH TENSOR SIZES\n" graph_sizes_str += f"===== {name} =====\n" for node in self.graph.nodes: @@ -1836,7 +2335,11 @@ def get_graph_sizes(self, name: str) -> str: return graph_sizes_str @contextlib.contextmanager +<<<<<<< HEAD def restore_global_state(self) -> Any: +======= + def restore_global_state(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Momentarily restores the global state to what it was prior to tracing the current output """ @@ -1853,7 +2356,11 @@ def restore_global_state(self) -> Any: GlobalContextCheckpointState(current_global_state) ) +<<<<<<< HEAD def run_compiler_collective(self) -> None: +======= + def run_compiler_collective(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tx = self.root_tx assert tx is not None if (ds := tx.distributed_state) is not None and ds.all_states is None: @@ -1877,7 +2384,11 @@ def run_compiler_collective(self) -> None: ), dynamo_timed("compiler_collective", log_pt2_compile_event=True), ): +<<<<<<< HEAD all_states: list[Any] = [None] * compile_pg.size() +======= + all_states = [None] * compile_pg.size() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dist.all_gather_object(all_states, ds.local_state, group=compile_pg) ds.all_states = all_states # Clear speculation log, because are tracing may diverge due to @@ -1885,12 +2396,16 @@ def run_compiler_collective(self) -> None: tx.speculation_log.clear() raise exc.CompileCollectiveRestartAnalysis +<<<<<<< HEAD def compile_and_call_fx_graph( self, tx: "InstructionTranslatorBase", rv: list[VariableTracker], root: FakeRootModule, ) -> list[Instruction]: +======= + def compile_and_call_fx_graph(self, tx, rv, root): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Generate code from self.graph and return the Instruction()s to call that generated code. @@ -1904,8 +2419,11 @@ def compile_and_call_fx_graph( assert self.should_exit self.run_compiler_collective() +<<<<<<< HEAD if count_calls(self.graph) == 0 and len(rv) == 0: return [] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) name = unique_id("__compiled_fn", with_uuid=True) @@ -1919,7 +2437,11 @@ def compile_and_call_fx_graph( {}, ) sub_gms = self.dedup_pass() +<<<<<<< HEAD root.add_nn_modules(sub_gms) # type: ignore[arg-type] +======= + root.add_nn_modules(sub_gms) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.current_tracer._maybe_preserve_original_meta(tx, output_node) if not config.do_not_emit_runtime_asserts: @@ -1959,6 +2481,7 @@ def compile_and_call_fx_graph( for register_finalizer in self.register_finalizer_fns: register_finalizer(gm) +<<<<<<< HEAD if next(gm.parameters(), None) is not None: # If dynamo produces a graph with parameters, skip package stuff # Bypass output graph @@ -1973,12 +2496,18 @@ def compile_and_call_fx_graph( if self.package is not None: gm._backend_id = name +======= + gm._backend_id = name +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) gm.compile_subgraph_reason = self.compile_subgraph_reason gm.meta["dynamo_flat_name_to_original_fqn"] = ( self.dynamo_flat_name_to_original_fqn.copy() ) gm.meta["dynamo_compile_id"] = self.dynamo_compile_id +<<<<<<< HEAD gm.meta["backend_id"] = name +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) graph_code_log.debug( "%s", @@ -1995,7 +2524,10 @@ def compile_and_call_fx_graph( ) self.call_cleanup_hooks() old_fake_mode = self.tracing_context.fake_mode +<<<<<<< HEAD assert old_fake_mode is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not self.export: import torch._functorch.config as _config @@ -2043,7 +2575,10 @@ def compile_and_call_fx_graph( ) counters["stats"]["unique_graphs"] += 1 +<<<<<<< HEAD assert old_fake_mode.shape_env is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if specializations := old_fake_mode.shape_env.specializations: specialization_guards = [] specialization_cache: dict[Specialization, Callable[[Any], Any]] = {} @@ -2051,10 +2586,14 @@ def compile_and_call_fx_graph( for specialization in specializations: source_index = sources.index(specialization.source) check_fn_source = inspect.getsource(specialization.check_fn).strip() +<<<<<<< HEAD # Required because the LABDA_GUARD API requires a root guard manager unused_root_guard_manager = RootGuardManager() check_fn = guards.LAMBDA_GUARD( # type: ignore[attr-defined] unused_root_guard_manager, +======= + check_fn = guards.LAMBDA_GUARD( # type: ignore[attr-defined] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) specialization.check_fn, [check_fn_source], ) @@ -2076,8 +2615,13 @@ def compile_and_call_fx_graph( ) ) +<<<<<<< HEAD @torch._dynamo.disable(reason="do not trace Dynamo-compiled graph") # type: ignore[misc] def specialized_dispatch(*args: Any, **kwargs: Any) -> Any: +======= + @torch._dynamo.disable(reason="do not trace Dynamo-compiled graph") + def specialized_dispatch(*args, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for check_fn, specialization in specialization_guards: if check_fn(args): if specialization in specialization_cache: @@ -2107,10 +2651,13 @@ def specialized_dispatch(*args: Any, **kwargs: Any) -> Any: assert self.root_tx is not None cg = PyCodegen(self.root_tx) +<<<<<<< HEAD for idx, arg in enumerate(self.graphargs): self.export_metadata.graph_input_idx_to_local_source[idx] = arg.source +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cg.make_call_generated_code(name) return cg.get_instructions() @@ -2210,16 +2757,27 @@ def _call_user_compiler( return compiled_fn +<<<<<<< HEAD def dedup_pass(self) -> dict[str, torch.fx.GraphModule]: +======= + def dedup_pass(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if torch._dynamo.config.use_graph_deduplication: return apply_graph_deduplication(self) else: return {} +<<<<<<< HEAD def install_subgraph(self, name: str, sub_gm: torch.fx.GraphModule) -> str: next_name = get_unique_name_wrt(name, self.nn_modules, requires_suffix=True) sub_gm.__name__ = next_name # type: ignore[assignment] sub_gm.torchdynamo_force_dynamic = False # type: ignore[assignment] +======= + def install_subgraph(self, name, sub_gm): + next_name = get_unique_name_wrt(name, self.nn_modules, requires_suffix=True) + sub_gm.__name__ = next_name + sub_gm.torchdynamo_force_dynamic = False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This graph module is not present in the user space, so it can't be # accessed by a source. Set source=None. self.register_attr_or_module(sub_gm, next_name, source=None) @@ -2248,7 +2806,11 @@ def remove_unused_graphargs(self) -> None: assert self.should_exit # Miniature DCE pass, but only for obviously trivial operations +<<<<<<< HEAD def is_static_true(b_node: fx.node.Argument) -> bool: +======= + def is_static_true(b_node: fx.node.Argument): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if b_node is True: return True if not isinstance(b_node, fx.Node): @@ -2267,7 +2829,11 @@ def is_static_true(b_node: fx.node.Argument) -> bool: # doesn't have unbacked inputs, since it's all in the ShapeEnv return False +<<<<<<< HEAD def is_symnode_arg(a: fx.node.Argument) -> bool: +======= + def is_symnode_arg(a: fx.node.Argument): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.fx.experimental.sym_node import SymTypes if isinstance(a, (int, float, bool)): @@ -2279,7 +2845,11 @@ def is_symnode_arg(a: fx.node.Argument) -> bool: # NB: We assume that you cannot do mutations on int/float/bool, # because they are immutable types, and therefore is always safe to # DCE. +<<<<<<< HEAD def is_symnode_compute_node(node: fx.Node) -> bool: +======= + def is_symnode_compute_node(node): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.fx.experimental.sym_node import SymTypes if node.op != "call_function": @@ -2313,7 +2883,11 @@ def is_symnode_compute_node(node: fx.Node) -> bool: ): self.remove_node(node) +<<<<<<< HEAD def placeholder_binds_symbol(node: fx.Node) -> Optional[sympy.Symbol]: +======= + def placeholder_binds_symbol(node): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) arg = node.meta["grapharg"] example = arg.example if isinstance(example, torch.SymInt) and isinstance( @@ -2322,7 +2896,11 @@ def placeholder_binds_symbol(node: fx.Node) -> Optional[sympy.Symbol]: return example.node.expr return None +<<<<<<< HEAD def remove_unused(node: fx.Node) -> None: +======= + def remove_unused(node): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) log.debug("REMOVE UNUSED GRAPHARG %s", node.meta["grapharg"].source.name()) # I'm not really sure why you need to delete these from the # node since the node is going to get removed @@ -2332,9 +2910,13 @@ def remove_unused(node: fx.Node) -> None: used_symbols: set[sympy.Symbol] = set() +<<<<<<< HEAD def update_used_symbols( used_symbols: set[sympy.Symbol], fake: Union[torch.SymInt, torch.Tensor] ) -> None: +======= + def update_used_symbols(used_symbols, fake: Union[torch.SymInt, torch.Tensor]): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) used_symbols |= free_symbols(fake) recheck_placeholders = [] @@ -2428,7 +3010,11 @@ def add_output_instructions(self, prefix: list[Instruction]) -> None: self.output_instructions.extend(prefix) self.should_exit = True +<<<<<<< HEAD def install_global_unsafe(self, name: str, value: Any) -> None: +======= + def install_global_unsafe(self, name, value) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ WARNING: prefer the safer `install_global_by_id/install_global`. torch.compile instances should be independent of each other; @@ -2440,7 +3026,11 @@ def install_global_unsafe(self, name: str, value: Any) -> None: self.installed_globals.add(name) self.cleanups.append(CleanupHook.create(self.global_scope, name, value)) +<<<<<<< HEAD def install_global_by_id(self, prefix: str, value: Any) -> str: +======= + def install_global_by_id(self, prefix, value) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Installs a global if it hasn't been installed already. This is determined by (prefix, id(value)) pair. @@ -2455,7 +3045,11 @@ def install_global_by_id(self, prefix: str, value: Any) -> str: self.install_global_unsafe(name, value) return name +<<<<<<< HEAD def install_global(self, prefix: str, value: Any) -> str: +======= + def install_global(self, prefix, value) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Installs a global, generating a unique name for it. @@ -2469,7 +3063,11 @@ def install_global(self, prefix: str, value: Any) -> str: def cleanup(self) -> None: # There is a reference cycle between tracer and OutputGraph, causing # some of the tensor objects to be held alive for longer than necessary. +<<<<<<< HEAD self.root_tx = None # type: ignore[assignment] +======= + self.root_tx = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.nn_modules.clear() self.param_name_to_source = None @@ -2492,7 +3090,11 @@ def add_graph_finalizer( ) -> None: self.register_finalizer_fns.append(register_finalizer) +<<<<<<< HEAD def example_value_from_input_node(self, node: torch.fx.Node) -> Any: +======= + def example_value_from_input_node(self, node: torch.fx.Node): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Extract the non-fake example tensor""" if node.op == "placeholder": return node.meta["grapharg"].example @@ -2500,6 +3102,7 @@ def example_value_from_input_node(self, node: torch.fx.Node) -> Any: return self.nn_modules[node.target] # type: ignore[index] +<<<<<<< HEAD class DynamoTracerOutput: error_on_graph_break: bool is_tracing_resume_prologue: bool @@ -2516,6 +3119,8 @@ def __init__( self.output_graph = tracer.output +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) err_epilogue = ( "With the current config, we will graph break " "(and fall back to eager-mode PyTorch) on all ops " @@ -2525,6 +3130,7 @@ def __init__( ) +<<<<<<< HEAD def check_pt2_compliant_op( output_graph: OutputGraph, kind: str, target: Any, args: Any, kwargs: Any ) -> None: @@ -2532,11 +3138,22 @@ def check_pt2_compliant_op( return def encountered_compliant_op(target: torch._ops.OpOverload) -> None: +======= +def check_pt2_compliant_op(output_graph, kind, target, args, kwargs): + if kind != "call_function": + return + + def encountered_compliant_op(target): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if target.namespace in {"prim", "prims", "aten"}: return output_graph.compliant_custom_ops.add(target) +<<<<<<< HEAD def encountered_non_compliant_op(target: torch._ops.OpOverload, msg: str) -> None: +======= + def encountered_non_compliant_op(target, msg): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) output_graph.non_compliant_ops.add(target) if config.only_allow_pt2_compliant_ops: unimplemented_v2( @@ -2602,6 +3219,7 @@ def encountered_non_compliant_op(target: torch._ops.OpOverload, msg: str) -> Non _compile_id_counter = itertools.count() +<<<<<<< HEAD P = ParamSpec("P") R = TypeVar("R") @@ -2614,12 +3232,21 @@ def __init__( *args: P.args, **kwargs: P.kwargs, ) -> None: +======= + +class LazyProxy: + def __init__(self, tracer, fn, *args, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.tracer = tracer self.fn = fn self.args = args self.kwargs = kwargs +<<<<<<< HEAD def __call__(self) -> Any: +======= + def __call__(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.fn(*self.args, **self.kwargs) @@ -2631,6 +3258,7 @@ class SubgraphTracer(fx.Tracer): compiling and executing the graph. """ +<<<<<<< HEAD def __init__( self, output_graph: "OutputGraph", @@ -2638,6 +3266,9 @@ def __init__( is_export: bool = False, source_target: Optional[Target] = None, ) -> None: +======= + def __init__(self, output_graph, parent=None, is_export=False, source_target=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__() self.output_graph = weakref.proxy(output_graph) self.graph = torch.fx.Graph() @@ -2667,14 +3298,22 @@ def __init__( # need to keep track of what free variables were lifted so we can # rewrite the HigherOrderOperator call using the traced body_fn. # Dicts maintain the order of args for the HigherOrderOperator call. +<<<<<<< HEAD self.lifted_freevars: dict[fx.Proxy, fx.Proxy] = {} +======= + self.lifted_freevars = {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # map basic symbols (unbacked and unbacked) to their bound proxies. # There are only two cases where bound_symbols will be recorded: # 1. when we create_graph_input for a backed SymInt that's basic symbol +<<<<<<< HEAD # 2. when we track_produced_symints for intermediate results # bound_symbols always map the symbol to the proxy whose # tracer is the current tracer that's readily accessible in current tracer's graph. +======= + # 2. when we track_unbacked_symbols for intermediate results that contain unbacked symints. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.bound_symbols: dict[sympy.Symbol, Union[torch.fx.Proxy, LazyProxy]] = {} self.prev_inst = None @@ -2698,15 +3337,25 @@ def __init__( self.debug_level: int = parent.debug_level + 1 if parent is not None else 0 self._cur_code = None +<<<<<<< HEAD self._orig_gm_meta: Optional[list[Any]] = None self._orig_gm_lineno_map: Optional[dict[int, Optional[int]]] = None self._orig_gm_firstlineno: Optional[int] = None +======= + self._orig_gm_meta = None + self._orig_gm_lineno_map = None + self._orig_gm_firstlineno = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Each SubgraphTracer is associated with a source target, which indicates # which operator this subgraph is attached to. We compute a source_fn_stack # based on the source target. For the root tracer, it's set to []. # This is useful for debugging and transforming the exported graph. if self.parent is None: +<<<<<<< HEAD self.source_fn_stack: list[Any] = [] +======= + self.source_fn_stack = [] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: self.source_fn_stack = self.parent.source_fn_stack + [ (self.graph._target_to_str(source_target), source_target) @@ -2723,9 +3372,13 @@ def __init__( ) # preserve original meta if it is available +<<<<<<< HEAD def _maybe_preserve_original_meta( self, tx: "InstructionTranslatorBase", node: fx.Node ) -> None: +======= + def _maybe_preserve_original_meta(self, tx, node): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if ( self._orig_gm_meta and self._orig_gm_lineno_map @@ -2747,6 +3400,7 @@ def _maybe_preserve_original_meta( def create_proxy( self, +<<<<<<< HEAD kind: str, target: Any, args: Any, @@ -2755,6 +3409,16 @@ def create_proxy( type_expr: Optional[Any] = None, proxy_factory_fn: Optional[Callable[[fx.Node], fx.Proxy]] = None, ) -> fx.Proxy: +======= + kind, + target, + args, + kwargs, + name=None, + type_expr=None, + proxy_factory_fn=None, + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # NOTE: [Nested SubgraphTracer and free_variable handling] # -------------------------------------------------------- # Read NOTE [HigherOrderOperator tracing design] first. @@ -2798,6 +3462,7 @@ def create_proxy( args, kwargs = pytree.tree_unflatten(new_flat_args, tree_spec) rv = super().create_proxy( +<<<<<<< HEAD kind, target, args, @@ -2805,6 +3470,9 @@ def create_proxy( name, type_expr, proxy_factory_fn, # type: ignore[arg-type] +======= + kind, target, args, kwargs, name, type_expr, proxy_factory_fn +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # append stack trace to fx node @@ -2825,7 +3493,11 @@ def create_proxy( tx_code = tx.f_code header = tx.get_line_of_code_header(lineno=cur_inst.positions.lineno) +<<<<<<< HEAD def get_trace_call_log_str() -> str: +======= + def get_trace_call_log_str(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) line = get_instruction_source_311(tx_code, cur_inst).rstrip() return f"TRACE FX call {rv.node.name} from {header}\n{line}" @@ -2935,6 +3607,7 @@ def get_trace_call_log_str() -> str: return rv def create_node( +<<<<<<< HEAD self, op: str, target: Target, @@ -2943,6 +3616,10 @@ def create_node( name: Optional[str] = None, type_expr: Optional[Any] = None, ) -> fx.Node: +======= + self, op, target, args=None, kwargs=None, name=None, type_expr=None + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) check_pt2_compliant_op(self.output_graph, op, target, args, kwargs) if self.parent is not None: flat_args = pytree.arg_tree_leaves(*args, **kwargs) @@ -2960,7 +3637,11 @@ def create_node( # Note: we did not override erase_node since # we call self.graph.erase_node elsewhere +<<<<<<< HEAD def remove_node(self, node: fx.Node) -> None: +======= + def remove_node(self, node): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if len(node.users) > 0: user_graph_nodes: list[torch.fx.Node] = [] for user in node.users.keys(): @@ -2983,6 +3664,7 @@ def remove_node(self, node: fx.Node) -> None: # Remove this if https://github.com/pytorch/pytorch/issues/99007 gets # fixed. def create_graph_input( +<<<<<<< HEAD self, name: str, type_expr: Any, @@ -2990,6 +3672,10 @@ def create_graph_input( before: bool = False, source: Optional[Source] = None, ) -> fx.Proxy: +======= + self, name, type_expr, example_value, before=False, source=None + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(example_value, torch.Tensor): self._input_versions_at_beginning.append(example_value._version) log.debug( @@ -3015,7 +3701,10 @@ def create_graph_input( # So we are a bit more strict about what sources can become inputs # in export if self.is_export and self.parent is None: +<<<<<<< HEAD assert source is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not is_from_local_source(source, only_allow_input=True): self.output_graph.source_to_user_stacks.setdefault(source, []).append( TracingContext.extract_stack() @@ -3050,6 +3739,7 @@ def create_graph_input( self._used_names.add(name) # NOTE: [Auto lift basic free symbols when create_graph_input] +<<<<<<< HEAD # There are two sources of basic symbols: # # - They can come from inputs, e.g. when an input tensor is specified as dynamic. We handle @@ -3078,6 +3768,29 @@ def create_graph_input( # immediately after they're created at wrap_fx_proxy with track_produced_symints. Notice # that for basic symbols that're already tracked by create_graph_input, we won't track it again. # +======= + # Whenever we call create_graph_input, we try to also lift the basic symbols in example values + # as graph input. + # This applies to both top-level graph and subgraphs in higher order ops. + # It has several cases: + # 1. When create_graph_input for a tensor that has symbolic shapes, + # we look for basic symbols in its size and stride, we check if the symbol is bound + # in current graph (i.e. bound_symbols), it it's not bound, we'll create a placeholder + # for it then recursively check its parent, creates ph if not bound. + # Every tracer maintains a mapping (i.e. lifted_freevars) + # that maps from parent proxy to proxy in current tracer for the symbol. + # 2. When create_graph_input for a tensor with unbacked symbolic shapes, + # Backed symbols all come from inputs's symbolic shape. But unbacked symbols + # can be created while tracing. So we use track_unbacked_symbols will intercept + # at wrap_fx_proxy, and try to bind the unbacked symbols immediately after they're + # created. + # 3. subgraph will also lifted basic symbols in compound exprs of tensor shape. + # For example, if an input to subgraph takes size [s1+s2//8], we'll look for the + # the free symbols in the sizes and lift as inputs similar to 1 in _lift_symbols_in_symint) + # 4. When create_graph_input for a SymInt, if the symint is a basic symbol, we'll track it + # in bound_symbols so that we don't lift the same basic symbol twice. When the symint is a + # compound expr, we'll just create the proxy for the compouned expr but not lift its basic symbols. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Also see NOTE: [Export inputs must be explicitly passed in] is_strict_export = self.is_export is_non_strict_export = torch.compiler.is_compiling() @@ -3105,9 +3818,13 @@ def create_graph_input( return proxy # See NOTE: [Nested SubgraphTracer and free_variable handling] for more details +<<<<<<< HEAD def lift_tracked_freevar_to_input( self, proxy: fx.Proxy ) -> Union[LazyProxy, fx.Proxy]: +======= + def lift_tracked_freevar_to_input(self, proxy): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # You're doing something wrong if we are the root SubgraphTracer because # Dynamo adds tensors to graph inputs before creating a proxy for them. assert self.parent is not None, ( @@ -3147,7 +3864,11 @@ def lift_tracked_freevar_to_input( self.lifted_freevars[proxy] = new_proxy return new_proxy +<<<<<<< HEAD def maybe_lift_tracked_freevar_to_input(self, arg: Any) -> Any: +======= + def maybe_lift_tracked_freevar_to_input(self, arg): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ If arg is a free variable, then lift it to be an input. Returns the new lifted arg (if arg was a freevar), else the @@ -3173,6 +3894,7 @@ def maybe_lift_tracked_freevar_to_input(self, arg: Any) -> Any: # See NOTE: [Auto lift basic free symbols when create_graph_input] for overall design # You MUST call this API every time when creating a proxy in wrap_fx_proxy for a call +<<<<<<< HEAD # that produced symints or tensors with unbacked symint shapes. # This function is used to track the symints with its proxies created during # dynamo tracing so that subgraph knows how to bind a symbol input with parent's proxy. @@ -3182,6 +3904,16 @@ def maybe_lift_tracked_freevar_to_input(self, arg: Any) -> Any: def track_produced_symints( self, example_value: Any, e_proxy: Union[LazyProxy, torch.fx.Proxy] ) -> None: +======= + # that produced unbacked symints or tensors with unbacked symint shapes. + # This function is used to track the unbacked symints with its proxies created during + # dynamo tracing so that subgraph knows how to bind a symbol input with parent's proxy. + # LazyProxy are created for tensor shapes that're unbacked so that we don't create proxies + # for symbols that're not going to be used. + def track_unbacked_symbols( + self, example_value, e_proxy: Union[LazyProxy, torch.fx.Proxy] + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # When binding the symbols in an exmaple_value, we bind the symbols # to the proxy's associated Tracer instead of current tracer. # This is because: @@ -3198,12 +3930,17 @@ def track_produced_symints( tracer = e_proxy.tracer assert isinstance(tracer, SubgraphTracer) +<<<<<<< HEAD def need_bind(s: Any) -> bool: +======= + def need_bind(s) -> bool: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.fx.experimental.symbolic_shapes import is_symbolic return ( is_symbolic(s) and isinstance(s.node.expr, sympy.Symbol) +<<<<<<< HEAD and s.node.expr not in self.bound_symbols ) @@ -3218,12 +3955,26 @@ def _proxy_with_example_value( proxy = tracer.create_proxy(*args, **kwargs) set_example_value(proxy.node, example_value) return proxy +======= + and s.node.shape_env.is_unbacked_symint(s.node.expr) + and s.node.expr not in self.bound_symbols + ) + + def _proxy_with_example_value(example_value, *args, **kwargs): + proxy = tracer.create_proxy(*args, **kwargs) + set_example_value(proxy.node, example_value) + return proxy +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(example_value, torch.Tensor): for i, s in enumerate(example_value.size()): if need_bind(s): log.debug( +<<<<<<< HEAD "track_produced_symints %s for %s.size()[%s] at debug_level %s", +======= + "_track_unbacked_symbols %s for %s.size()[%s] at debug_level %s", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) s, e_proxy, i, @@ -3239,6 +3990,7 @@ def _proxy_with_example_value( {}, type_expr=type(s), ) +<<<<<<< HEAD self.track_produced_symints(s, lazy_proxy) storage_offset = example_value.storage_offset() @@ -3260,12 +4012,19 @@ def _proxy_with_example_value( type_expr=type(storage_offset), ) self.track_produced_symints(storage_offset, lazy_proxy) +======= + self.track_unbacked_symbols(s, lazy_proxy) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if example_value.layout is torch.strided: for i, s in enumerate(example_value.stride()): if need_bind(s): log.debug( +<<<<<<< HEAD "track_produced_symints %s for %s.stride()[%s] at debug_level %s", +======= + "_track_unbacked_symbols %s for %s.stride()[%s] at debug_level %s", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) s, e_proxy, i, @@ -3281,6 +4040,7 @@ def _proxy_with_example_value( {}, type_expr=type(s), ) +<<<<<<< HEAD self.track_produced_symints(s, lazy_proxy) elif example_value.layout is torch.sparse_coo: @@ -3292,12 +4052,31 @@ def _proxy_with_example_value( elif example_value.layout in {torch.sparse_csc, torch.sparse_bsc}: self.track_produced_symints(example_value.ccol_indices(), e_proxy) self.track_produced_symints(example_value.row_indices(), e_proxy) +======= + self.track_unbacked_symbols(s, lazy_proxy) + + elif example_value.layout is torch.sparse_coo: + self.track_unbacked_symbols(example_value._indices(), e_proxy) + self.track_unbacked_symbols(example_value._values(), e_proxy) + elif example_value.layout in {torch.sparse_csr, torch.sparse_bsr}: + self.track_unbacked_symbols(example_value.crow_indices(), e_proxy) + self.track_unbacked_symbols(example_value.col_indices(), e_proxy) + elif example_value.layout in {torch.sparse_csc, torch.sparse_bsc}: + self.track_unbacked_symbols(example_value.ccol_indices(), e_proxy) + self.track_unbacked_symbols(example_value.row_indices(), e_proxy) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if is_traceable_wrapper_subclass(example_value): attrs, ctx = example_value.__tensor_flatten__() for attr in attrs: inner_t = getattr(example_value, attr) +<<<<<<< HEAD self.track_produced_symints(inner_t, getattr(e_proxy, attr)) elif isinstance(example_value, torch.SymInt): +======= + self.track_unbacked_symbols(inner_t, getattr(e_proxy, attr)) + elif isinstance(example_value, torch.SymInt): + # Only bind unbacked symbols. backed symbols are lifted as inputs. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if need_bind(example_value): expr = example_value.node.expr tracer.bound_symbols[expr] = e_proxy @@ -3305,7 +4084,11 @@ def _proxy_with_example_value( # See Note [Auto lift basic free symbols when create_graph_input] def _lift_basic_symbols( self, example_value: Union[torch.SymInt, torch.Tensor], src: Optional[Source] +<<<<<<< HEAD ) -> None: +======= + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # The before arg is for inserting symints in the sizes/strides of a tensor # before the tensor. This ordering ensures that when we look at the tensor's # symbols, they're already lifted/tracked. E.g. this assumption is used @@ -3329,7 +4112,11 @@ def _lift_symbols_in_symint( self.parent._lift_basic_symbols(s, source) for s0 in self_to_be_bound: parent_proxy = self.parent.bound_symbols[s0] +<<<<<<< HEAD example_val = parent_proxy.node.meta["example_value"] # type: ignore[union-attr] +======= + example_val = parent_proxy.node.meta["example_value"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(example_val, torch.SymInt) ph = self.create_graph_input( str(s0), @@ -3344,7 +4131,11 @@ def _lift_symbols_in_symint( source.name() if source is not None else "subgraph inputs", self.debug_level, ) +<<<<<<< HEAD self.lifted_freevars[parent_proxy] = ph # type: ignore[index] +======= + self.lifted_freevars[parent_proxy] = ph +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # For root_tracer: else: assert len(self_to_be_bound) == 1, ( @@ -3454,7 +4245,11 @@ def lookup_unbound_symbols(self, s: torch.SymInt) -> list[sympy.Symbol]: # Sort the symbols so that we can have a deterministic lifting order return sorted(to_be_bound, key=lambda s: s.name) +<<<<<<< HEAD def has_input_mutation(self) -> MutationInfo: +======= + def has_input_mutation(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) input_versions_at_beginning = self._input_versions_at_beginning input_nodes = [] @@ -3483,7 +4278,11 @@ def has_input_mutation(self) -> MutationInfo: return MutationInfo(False, "") +<<<<<<< HEAD def has_aliasing(self) -> AliasingInfo: +======= + def has_aliasing(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._higher_order_ops.utils import _collect_fake_inputs input_storages: dict[StorageWeakRef, torch.fx.Node] = dict() diff --git a/torch/_dynamo/package.py b/torch/_dynamo/package.py index 9aa00a6a9d1e3..6e06b5c8d664a 100644 --- a/torch/_dynamo/package.py +++ b/torch/_dynamo/package.py @@ -8,18 +8,25 @@ from a different process or host. """ +<<<<<<< HEAD import abc import ast +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import contextlib import dataclasses import functools import hashlib import importlib +<<<<<<< HEAD import inspect +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import logging import os import pickle import platform +<<<<<<< HEAD import shutil import sys import types @@ -36,6 +43,19 @@ from .bytecode_transformation import get_code_keys from .utils import dynamo_timed, increment_frame +======= +import sys +import types +from collections.abc import Generator +from typing import Any, NewType, Optional + +import torch +import torch._inductor.package +from torch._dynamo.precompile_context import PrecompileCacheArtifact, PrecompileContext +from torch.compiler._cache import CacheArtifactFactory + +from .bytecode_transformation import get_code_keys +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) logger = logging.getLogger(__name__) @@ -103,6 +123,7 @@ class _GuardedCodeCacheEntry: _FunctionId = NewType("_FunctionId", str) # __resume_at +<<<<<<< HEAD @dataclasses.dataclass(frozen=True) class InlinedSource: module: str @@ -123,6 +144,10 @@ class DynamoCaptureOutput: @dataclasses.dataclass class _DynamoCodeCacheEntry(DynamoCaptureOutput): +======= +@dataclasses.dataclass +class _DynamoCodeCacheEntry: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Contains the serializable information associated with a single code object in dynamo. To restore an execution of compiled code, we will need the following @@ -136,16 +161,20 @@ class _DynamoCodeCacheEntry(DynamoCaptureOutput): 4. A list of guarded code that eval frame dispatches to. 5. A list of imported module objects unioned from all compiled branches. 6. A list of "backends" (compiled fx graph) unioned from all compield branches. +<<<<<<< HEAD 7. A string path used to access the original code object users defined. A code object can be accessed by "{python_module}.{function_name}.{code_source}" . 8. A boolean flag indicating whether the function is installed to global scope. 9. A boolean flag indicating whether the function has a compile id. 10. Whether or not this code entry was bypassed +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ python_code: SerializedCode python_module: str function_names: list[_FunctionId] +<<<<<<< HEAD import_sources: dict[str, str] code_source: Optional[str] install_to_global: bool @@ -269,12 +298,20 @@ def _find_code_source(obj: Any) -> Optional[str]: if code_source is None: _raise_resolution_error(code, toplevel) return toplevel.__qualname__, code_source.strip(".") +======= + guarded_codes: list[_GuardedCodeCacheEntry] + import_sources: dict[str, str] + backend_ids: list[_BackendId] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclasses.dataclass class _DynamoCacheEntry: codes: list[_DynamoCodeCacheEntry] +<<<<<<< HEAD inlined_sources: set[InlinedSource] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) python_version: str = platform.python_version() torch_version: str = torch.__version__ @@ -293,6 +330,7 @@ def after_deserialization(self) -> _DynamoCacheEntry: return pickle.loads(self.content) +<<<<<<< HEAD def _hash_source(source: str) -> str: sha256_hash = hashlib.sha256() sha256_hash.update(source.encode()) @@ -345,6 +383,8 @@ def _ctx() -> Iterator[None]: return _ctx() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class CompilePackage: """ CompilePackage is considered a low level component and should not be directly exposed to @@ -358,12 +398,16 @@ class CompilePackage: updates with compiled functions and resume functions. """ +<<<<<<< HEAD def __init__( self, fn: Optional[Callable[..., Any]], dynamo: Optional[_DynamoCacheEntry] = None, ignore_inlined_sources: bool = False, ) -> None: +======= + def __init__(self, fn: Any, dynamo: Optional[_DynamoCacheEntry] = None) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._innermost_fn = None self._codes: dict[types.CodeType, _DynamoCodeCacheEntry] = {} @@ -372,6 +416,7 @@ def __init__( # For debugging/testing purpose only. self._cached_backends: dict[_BackendId, Any] = {} +<<<<<<< HEAD self._inlined_sources: set[InlinedSource] = set() self._resume_codes: set[types.CodeType] = set() self._initialized = False @@ -394,6 +439,17 @@ def initialize( assert not self._initialized self._inlined_sources = set() self._innermost_fn = innermost_fn(fn) # type: ignore[assignment] +======= + + self._initialize(fn, dynamo) + self.uninstall() + self.validate() + + def _initialize(self, fn: Any, dynamo: Optional[_DynamoCacheEntry] = None) -> None: + from .eval_frame import innermost_fn + + self._innermost_fn = innermost_fn(fn) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert self._innermost_fn is not None if dynamo is not None: assert isinstance(dynamo, _DynamoCacheEntry) @@ -405,6 +461,7 @@ def initialize( raise RuntimeError( f"Compile package was created with a different PyTorch version: {dynamo.torch_version}" ) +<<<<<<< HEAD if not ignore_inlined_sources: for code in dynamo.inlined_sources: m = importlib.import_module(code.module) @@ -415,6 +472,8 @@ def initialize( ) self._inlined_sources = dynamo.inlined_sources +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) main, *codes = dynamo.codes self._codes = {self._innermost_fn.__code__: main} @@ -424,15 +483,22 @@ def initialize( self._add_function( self._innermost_fn.__code__, self._innermost_fn.__module__ ) +<<<<<<< HEAD self._initialized = True +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _add_function( self, python_code: types.CodeType, python_module: str, +<<<<<<< HEAD function_name: Optional[_FunctionId] = None, code_source: Optional[str] = None, install_to_global: bool = False, +======= + name: Optional[_FunctionId] = None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: if python_code not in self._codes: code = _DynamoCodeCacheEntry( @@ -442,18 +508,27 @@ def _add_function( guarded_codes=[], import_sources={}, backend_ids=[], +<<<<<<< HEAD code_source=code_source, install_to_global=install_to_global, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) self._codes[python_code] = code else: code = self._codes[python_code] assert code.python_module == python_module +<<<<<<< HEAD assert code.install_to_global == install_to_global assert code.code_source == code_source if function_name is not None: code.function_names.append(function_name) +======= + + if name is not None: + code.function_names.append(name) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @property def cached_backends(self) -> dict[_BackendId, Any]: @@ -462,6 +537,7 @@ def cached_backends(self) -> dict[_BackendId, Any]: @functools.cached_property def source_id(self) -> str: assert self._innermost_fn is not None +<<<<<<< HEAD return CompilePackage.source_id_from_fn(self._innermost_fn) def _add_user_function(self, code: types.CodeType) -> None: @@ -475,26 +551,38 @@ def _add_user_function(self, code: types.CodeType) -> None: function_name=_FunctionId(function_name), code_source=code_source, ) +======= + sha256_hash = hashlib.sha256() + sha256_hash.update(self._innermost_fn.__qualname__.encode()) + sha256_hash.update(str(self._innermost_fn.__code__.co_firstlineno).encode()) + return sha256_hash.hexdigest() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @contextlib.contextmanager def code_context(self, code: types.CodeType) -> Generator[None, None, None]: assert self._current_entry is None +<<<<<<< HEAD # Sometimes user code cannot be inlined in dynamo resulting in extra user code # being compiled. We should record these as when they are actually invoked. if code not in self._codes: self._add_user_function(code) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) entry = self._codes[code] self._current_entry = entry try: yield finally: +<<<<<<< HEAD if ( entry.bypassed ): # Remove the code from the cache entry if it's been bypassed del self._codes[code] entry.has_compile_id = True +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._current_entry = None def add_guarded_code( @@ -503,14 +591,18 @@ def add_guarded_code( dynamo_code: types.CodeType, ) -> None: assert self._current_entry is not None +<<<<<<< HEAD if self._current_entry.bypassed: return +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) guarded_code_entry = _GuardedCodeCacheEntry( guards_state=guards_state, dynamo_code=SerializedCode.from_code_object(dynamo_code), ) self._current_entry.guarded_codes.append(guarded_code_entry) +<<<<<<< HEAD def add_inlined_source(self, sources: list[types.CodeType]) -> None: assert self._current_entry is not None if self._current_entry.bypassed: @@ -538,10 +630,13 @@ def bypass_current_entry(self) -> None: assert self._current_entry is not None self._current_entry.bypassed = True +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def add_resume_function( self, python_code: types.CodeType, python_module: str, +<<<<<<< HEAD function_name: Optional[str], ) -> None: self._add_function( @@ -551,6 +646,13 @@ def add_resume_function( install_to_global=True, ) self._resume_codes.add(python_code) +======= + name: Optional[str], + ) -> None: + self._add_function( + python_code, python_module, _FunctionId(name) if name else None + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def add_import_source(self, alias: str, module_name: str) -> None: assert self._current_entry is not None @@ -567,7 +669,10 @@ def add_backend_id(self, backend_id: str, backend: Optional[Any] = None) -> None def validate(self) -> None: assert self._current_entry is None assert self._innermost_fn is not None +<<<<<<< HEAD assert self._initialized +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert next(iter(self._codes)) is self._innermost_fn.__code__ def _install_global(self, module: types.ModuleType, name: str, value: Any) -> None: @@ -595,6 +700,7 @@ def install(self, backends: dict[_BackendId, Any]) -> None: """ from torch._C._dynamo.eval_frame import _load_precompile_entry +<<<<<<< HEAD from .output_graph import get_builtins_dict self.uninstall() @@ -683,6 +789,50 @@ def source_id_from_fn(fn: Callable[..., Any]) -> str: sha256_hash.update(innermost_fn_.__qualname__.encode()) sha256_hash.update(str(innermost_fn_.__code__.co_firstlineno).encode()) return sha256_hash.hexdigest() +======= + self.uninstall() + + for code, entry in self._codes.items(): + module = sys.modules[entry.python_module] + for alias, module_name in entry.import_sources.items(): + self._install_global( + module, alias, importlib.import_module(module_name) + ) + for function_name in entry.function_names: + fn = types.FunctionType(code, module.__dict__, function_name) + self._install_global(module, function_name, fn) + for backend_id in entry.backend_ids: + if backend_id not in backends: + raise RuntimeError( + f"Backend {backend_id} is not found in the given backends" + ) + backend = backends[backend_id] + self._install_global( + module, + backend_id, + torch._dynamo.disable(backend), + ) + + for code, entry in self._codes.items(): + for guarded_code in entry.guarded_codes: + guards_state = pickle.loads(guarded_code.guards_state) + assert isinstance(guards_state, torch._dynamo.guards.GuardsState) + check_fn_manager = torch._dynamo.guards.CheckFunctionManager( + code, + guards_state.output_graph, + guards_serialization_mode="load", + shape_code_parts=guards_state.shape_code_parts, + ) + _load_precompile_entry( + code, + check_fn_manager.guard_manager, + SerializedCode.to_code_object(guarded_code.dynamo_code), + ) + + def cache_entry(self) -> _DynamoCacheEntry: + self.validate() + return _DynamoCacheEntry(codes=list(self._codes.values())) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @CacheArtifactFactory.register @@ -695,6 +845,7 @@ def after_deserialization(self) -> Any: return pickle.loads(self.content) +<<<<<<< HEAD _Backends = dict[_BackendId, PrecompileCacheArtifact[Any]] @@ -709,6 +860,15 @@ def record_package(self, package: CompilePackage) -> None: """ Records a package to PrecompileContext, so that it can be serialized later. """ +======= +class DynamoStore: + """ + A DynamoStore tracks active CompilePackages, and provides methods to store and retrieve them. + """ + + def record_package(self, package: CompilePackage) -> None: + """Records a package to PrecompileContext, so that it can be serialized later.""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cache_entry = package.cache_entry() pickled_result = pickle.dumps(cache_entry) PrecompileContext.record_artifact( @@ -716,14 +876,19 @@ def record_package(self, package: CompilePackage) -> None: ) def record_eager_backend(self, backend_id: _BackendId, backend: Any) -> None: +<<<<<<< HEAD """ Records eager fx graphs to PrecompileContext for testing purposes. """ +======= + """Records eager fx graphs to PrecompileContext for testing purposes.""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pickled_result = pickle.dumps(backend) PrecompileContext.record_artifact( EagerCacheArtifact.type(), key=backend_id, content=pickled_result ) +<<<<<<< HEAD @abc.abstractmethod def clear(self) -> None: ... @@ -749,12 +914,19 @@ def save_cache_entry(self, cache_entry: _DynamoCacheEntry, key: str) -> None: Saves a package to a given path. Grabs backends from PrecompileContext. """ backend_content: _Backends = {} +======= + def save_package(self, package: CompilePackage, path: str) -> None: + """Saves a package to a given path. Grabs backends from PrecompileContext.""" + backend_content = {} + cache_entry = package.cache_entry() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for backend_id in cache_entry.backend_ids: serialized_backend = PrecompileContext.serialize_artifact_by_key(backend_id) if serialized_backend is None: raise RuntimeError( f"Backend {backend_id} is not found in the given backends" ) +<<<<<<< HEAD assert isinstance(serialized_backend, PrecompileCacheArtifact) backend_content[backend_id] = serialized_backend @@ -881,11 +1053,27 @@ def read(self, path: str) -> tuple[_DynamoCacheEntry, _Backends]: Read dynamo cache entry and backends from disk. """ path = os.path.join(self.path_prefix, path) if self.path_prefix else path +======= + backend_content[backend_id] = serialized_backend + try: + with open(os.path.join(path, "dynamo"), "wb") as dynamo_path: + pickle.dump(cache_entry, dynamo_path) + with open(os.path.join(path, "backends"), "wb") as backend_path: + pickle.dump(backend_content, backend_path) + except Exception as e: + raise RuntimeError(f"Failed to save package to {path}: {e}") from e + + def load_package( + self, fn: Any, path: str + ) -> tuple[CompilePackage, dict[_BackendId, Any]]: + """Loads a package from a given path and returns it plus a list of deserialized backends""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: with open(os.path.join(path, "dynamo"), "rb") as dynamo_path: cache_entry = pickle.load(dynamo_path) with open(os.path.join(path, "backends"), "rb") as backend_path: backend_content = pickle.load(backend_path) +<<<<<<< HEAD return cache_entry, backend_content except Exception as e: raise RuntimeError(f"Failed to load package from path {path}: {e}") from e @@ -941,3 +1129,11 @@ def load_and_install_package( DynamoCache = DiskDynamoCache(os.path.join(cache_dir(), "dynamo")) +======= + except Exception as e: + raise RuntimeError(f"Failed to load package from path {path}: {e}") from e + for backend_id, backend in backend_content.items(): + backend_content[backend_id] = backend.after_deserialization() + package = CompilePackage(fn, cache_entry) + return package, backend_content +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_dynamo/pgo.py b/torch/_dynamo/pgo.py index 1a2c98ee6c7dd..ddaa9c878ad1f 100644 --- a/torch/_dynamo/pgo.py +++ b/torch/_dynamo/pgo.py @@ -173,7 +173,10 @@ class CodeState: _INIT_CODE_STATE: Optional[defaultdict[CodeId, CodeState]] = None _CODE_STATE: Optional[defaultdict[CodeId, CodeState]] = None +<<<<<<< HEAD _LOGGED_DYNAMIC_ALLOWLIST: bool = False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclasses.dataclass(frozen=True) @@ -520,7 +523,18 @@ def process_automatic_dynamic( return res +<<<<<<< HEAD def format_cache_key(key: str) -> str: +======= +def get_cache_key() -> Optional[str]: + # TODO: info versions of these logs that log only once + if torch._inductor.config.force_disable_caches: + warn_once( + "dynamo_pgo force disabled by torch._inductor.config.force_disable_caches" + ) + return None + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # NB: We always use global rank for keys, even though they are overkill # for local only cache rank = None @@ -528,6 +542,7 @@ def format_cache_key(key: str) -> str: rank = dist.get_rank() tag = torch.compiler.config.cache_key_tag +<<<<<<< HEAD return f"{key}:{rank}:{tag}" @@ -538,6 +553,8 @@ def get_cache_key() -> Optional[str]: "dynamo_pgo force disabled by torch.compiler.config.force_disable_caches" ) return None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # NB: We namespace the cache keys so that only user-specified job id # can alias with each other. @@ -548,15 +565,24 @@ def get_cache_key() -> Optional[str]: "automatically generated job id associated with a specific MAST job " "name and version." ) +<<<<<<< HEAD return format_cache_key(r) if (name_version := torch._utils_internal.get_mast_job_name_version()) is not None: mast_job_name, mast_job_version = name_version return format_cache_key(f"mast:{mast_job_name}:{mast_job_version}") +======= + return f"{r}:{rank}:{tag}" + + if (name_version := torch._utils_internal.get_mast_job_name_version()) is not None: + mast_job_name, mast_job_version = name_version + return f"mast:{mast_job_name}:{mast_job_version}:{rank}:{tag}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return None +<<<<<<< HEAD def get_extra_cache_key(sticky_key: str) -> Optional[str]: if torch.compiler.config.force_disable_caches: warn_once( @@ -567,6 +593,8 @@ def get_extra_cache_key(sticky_key: str) -> Optional[str]: return format_cache_key(sticky_key) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This solely controls local PGO def code_state_path(cache_key: str) -> Optional[str]: if not torch._dynamo.config.automatic_dynamic_local_pgo: @@ -580,7 +608,11 @@ def code_state_path(cache_key: str) -> Optional[str]: def should_use_remote_dynamo_pgo_cache() -> bool: +<<<<<<< HEAD if torch.compiler.config.force_disable_caches: +======= + if torch._inductor.config.force_disable_caches: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return False if (r := torch._dynamo.config.automatic_dynamic_remote_pgo) is not None: @@ -630,7 +662,10 @@ def _collect_dynamic_sources(code_state: CodeState) -> OrderedSet[str]: def log_frame_dynamic_whitelist(f_code: types.CodeType) -> None: +<<<<<<< HEAD global _LOGGED_DYNAMIC_ALLOWLIST +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) code_id = CodeId.make(f_code) frame_state = get_code_state()[code_id] frame_whitelist = ",".join(_collect_dynamic_sources(frame_state)) @@ -639,6 +674,7 @@ def log_frame_dynamic_whitelist(f_code: types.CodeType) -> None: CompileEventLogger.pt2_compile( name, recompile_dynamic_whitelist=frame_whitelist ) +<<<<<<< HEAD if not _LOGGED_DYNAMIC_ALLOWLIST: torch._utils_internal.add_mlhub_insight( category="dynamic_shapes_analysis", @@ -649,6 +685,8 @@ def log_frame_dynamic_whitelist(f_code: types.CodeType) -> None: ) # add mlhub insight only once per rank _LOGGED_DYNAMIC_ALLOWLIST = True +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def render_code_state(cs: defaultdict[CodeId, CodeState]) -> str: @@ -671,6 +709,7 @@ def render_code_state(cs: defaultdict[CodeId, CodeState]) -> str: return code_state_str +<<<<<<< HEAD def merge_pgo_entry(src: FrameStateSizeEntry, dst: FrameStateSizeEntry) -> None: def rank(entry: FrameStateSizeEntry) -> int: if not isinstance(entry.size, tuple): # scalar @@ -681,6 +720,8 @@ def rank(entry: FrameStateSizeEntry) -> int: dst |= src +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @CacheArtifactFactory.register class PGOCacheArtifact(CacheArtifact): @override @@ -710,6 +751,7 @@ def _rewrite_cache_key_for_mega_cache(original_key: str) -> str: return original_key +<<<<<<< HEAD def hit(key: str, ty: str) -> defaultdict[CodeId, CodeState]: global _INIT_CODE_STATE assert isinstance(_CODE_STATE, defaultdict) @@ -726,6 +768,34 @@ def hit(key: str, ty: str) -> defaultdict[CodeId, CodeState]: def get_local_code_state(cache_key: str) -> Optional[defaultdict[CodeId, CodeState]]: global _CODE_STATE +======= +def get_code_state() -> defaultdict[CodeId, CodeState]: + global _CODE_STATE, _INIT_CODE_STATE + if _CODE_STATE is not None: + return _CODE_STATE + + # Initialize it (even if we don't look up profile) + _CODE_STATE = defaultdict(CodeState) + + cache_key = get_cache_key() + if cache_key is None: + return _CODE_STATE + + def hit(ty: str) -> defaultdict[CodeId, CodeState]: + global _INIT_CODE_STATE + assert isinstance(_CODE_STATE, defaultdict) + log.info("get_code_state %s hit %s, %d entries", path, ty, len(_CODE_STATE)) + trace_structured_artifact( + f"get_{ty}_code_state", + "string", + lambda: render_code_state(_CODE_STATE), # type: ignore[arg-type] + ) + set_feature_use("pgo", True) + _INIT_CODE_STATE = copy.deepcopy(_CODE_STATE) + return _CODE_STATE + + # Attempt local +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) path = code_state_path(cache_key) if path is not None and os.path.exists(path): with dynamo_timed( @@ -747,6 +817,7 @@ def get_local_code_state(cache_key: str) -> Optional[defaultdict[CodeId, CodeSta CacheArtifactManager.record_artifact( PGOCacheArtifact.type(), cache_key, content ) +<<<<<<< HEAD return hit(path, "local") return None @@ -790,6 +861,11 @@ def lookup_remote_cache_entry( def get_remote_code_state(cache_key: str) -> Optional[defaultdict[CodeId, CodeState]]: global _CODE_STATE +======= + return hit("local") + + # Attempt remote +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) remote_cache = get_remote_cache() if remote_cache is not None: with dynamo_timed( @@ -798,6 +874,7 @@ def get_remote_code_state(cache_key: str) -> Optional[defaultdict[CodeId, CodeSt dynamo_compile_column_us="pgo_get_remote_code_state_time_us", ): CompileEventLogger.pt2_compile(name, cache_key=cache_key) +<<<<<<< HEAD code_state = lookup_remote_cache_entry(remote_cache, cache_key, name) if code_state is not None: _CODE_STATE = code_state @@ -872,6 +949,39 @@ def get_code_state() -> defaultdict[CodeId, CodeState]: extra_read_key = get_extra_cache_key(sticky_read) if extra_read_key is not None: add_extra_remote_code_state(extra_read_key) +======= + # TODO: I don't really understand why there's a JSON container format + try: + cache_data = remote_cache.get(cache_key) + except Exception: + log.warning( + "get_code_state failed remote read on %s", cache_key, exc_info=True + ) + else: + if cache_data is not None: + try: + assert isinstance(cache_data, dict) + data = cache_data["data"] + assert isinstance(data, str) + payload = base64.b64decode(data) + CompileEventLogger.pt2_compile( + name, cache_size_bytes=len(payload) + ) + _CODE_STATE = pickle.loads(payload) + except Exception: + log.warning( + "get_code_state failed parsing remote result on %s", + cache_key, + exc_info=True, + ) + else: + CacheArtifactManager.record_artifact( + PGOCacheArtifact.type(), cache_key, payload + ) + return hit("remote") + else: + log.info("get_code_state remote miss on %s", cache_key) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) log.info("get_code_state using default") @@ -895,10 +1005,13 @@ def put_code_state() -> None: put_local_code_state(cache_key) put_remote_code_state(cache_key) +<<<<<<< HEAD if (sticky_write := torch.compiler.config.pgo_extra_write_key) is not None: extra_write_key = get_extra_cache_key(sticky_write) if extra_write_key is not None: put_remote_code_state(extra_write_key) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def write_local_impl(cache_key: str, pickled_code: bytes) -> Optional[tuple[str, int]]: @@ -986,7 +1099,13 @@ def put_remote_code_state(cache_key: str) -> None: # NB: this does NOT reset the cached code state on disk def reset_code_state() -> None: +<<<<<<< HEAD global _CODE_STATE, _INIT_CODE_STATE, _LOGGED_DYNAMIC_ALLOWLIST _CODE_STATE = None _INIT_CODE_STATE = None _LOGGED_DYNAMIC_ALLOWLIST = False +======= + global _CODE_STATE, _INIT_CODE_STATE + _CODE_STATE = None + _INIT_CODE_STATE = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index 4fc777ffe7efd..63fcd1d00fa65 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -9,8 +9,12 @@ # mypy: allow-untyped-defs import types +<<<<<<< HEAD from collections import OrderedDict from collections.abc import Hashable, Iterable, MutableMapping, Sequence +======= +from collections.abc import Iterable, MutableMapping, Sequence +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from itertools import repeat as _repeat from typing import Any, Callable, TYPE_CHECKING @@ -24,14 +28,20 @@ # See also the POLYFILLED_MODULE_NAMES in torch/_dynamo/polyfills/loader.py # Put the submodules here to avoid circular imports from . import ( +<<<<<<< HEAD _collections as _collections, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) builtins as builtins, functools as functools, itertools as itertools, operator as operator, os as os, pytree as pytree, +<<<<<<< HEAD struct as struct, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sys as sys, ) @@ -77,6 +87,7 @@ def radians(x): return math.pi / 180.0 * x +<<<<<<< HEAD def impl_CONTAINS_OP_fallback(a, b): # performs fallback "a in b" if hasattr(b, "__iter__"): @@ -88,6 +99,8 @@ def impl_CONTAINS_OP_fallback(a, b): raise TypeError(f"argument of type {type(b)} is not iterable") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def accumulate_grad(x, new_grad): # polyfills according to the Gradient Layout Contract if new_grad is None: @@ -115,6 +128,7 @@ def list_cmp(op: Callable[[Any, Any], bool], left: Sequence[Any], right: Sequenc return op(len(left), len(right)) +<<<<<<< HEAD def dict___eq__(d, other): if (len(d) != len(other)) or (d.keys() != other.keys()): return False @@ -129,6 +143,8 @@ def dict___eq__(d, other): return True +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def set_symmetric_difference(set1, set2): symmetric_difference_set = set() for x in set1: @@ -147,6 +163,7 @@ def set_symmetric_difference_update(set1, set2): def set_isdisjoint(set1, set2): +<<<<<<< HEAD if not isinstance(set2, Iterable): raise TypeError(f"'{type(set2)}' object is not iterable") @@ -156,6 +173,11 @@ def set_isdisjoint(set1, set2): raise TypeError(f"unhashable type: '{type(y)}'") if x == y: return False +======= + for x in set1: + if x in set2: + return False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return True @@ -163,6 +185,7 @@ def set_intersection(set1, *others): if len(others) == 0: return set1.copy() +<<<<<<< HEAD if not all(isinstance(s, Iterable) for s in others): raise TypeError(f"set.difference expected an iterable, got {type(others)}") @@ -175,6 +198,12 @@ def set_intersection(set1, *others): for x in set1: for set2 in others: if not any(x == y for y in set2): +======= + intersection_set = set() + for x in set1: + for set2 in others: + if x not in set2: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) break else: intersection_set.add(x) @@ -189,6 +218,7 @@ def set_intersection_update(set1, *others): def set_union(set1, *others): # frozenset also uses this function +<<<<<<< HEAD if len(others) == 0: return set1.copy() @@ -204,6 +234,11 @@ def set_union(set1, *others): set_update(union_set, set2) # frozenset also uses this function +======= + union_set = set(set1.copy()) + for set2 in others: + set_update(union_set, set2) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return type(set1)(union_set) @@ -224,10 +259,13 @@ def set_difference(set1, *others): if not all(isinstance(s, Iterable) for s in others): raise TypeError(f"set.difference expected an iterable, got {type(others)}") +<<<<<<< HEAD for s in others: if any(not isinstance(x, Hashable) for x in s): raise TypeError("unhashable type") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) difference_set = set() for x in set1: for set2 in others: @@ -244,6 +282,7 @@ def set_difference_update(set1, *others): set1.update(result) +<<<<<<< HEAD def assert_dict_equal(self_, d1, d2, msg=None): self_.assertTrue(d1 == d2, msg) @@ -257,6 +296,8 @@ def assert_sequence_equal(self_, seq1, seq2, msg=None, seq_type=None): return self_.assertTrue(seq1 == seq2, msg) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def getattr_and_trace(*args, **kwargs): wrapper_obj = args[0] attr_name = args[1] @@ -288,9 +329,12 @@ def construct_dict(cls, /, *args, **kwargs): if args: src = args[0] +<<<<<<< HEAD if not isinstance(src, Iterable): raise TypeError(f"{type(src)} object is not iterable") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Ensure that the overridden __iter__ method is invoked if isinstance(src, (dict, MutableMapping, types.MappingProxyType)): for key in src: diff --git a/torch/_dynamo/polyfills/builtins.py b/torch/_dynamo/polyfills/builtins.py index d3544fa354faf..75ef66d04c7f1 100644 --- a/torch/_dynamo/polyfills/builtins.py +++ b/torch/_dynamo/polyfills/builtins.py @@ -7,7 +7,11 @@ import builtins import functools import operator +<<<<<<< HEAD from typing import Callable, TYPE_CHECKING, TypeVar +======= +from typing import TYPE_CHECKING, TypeVar +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from ..decorators import substitute_in_graph @@ -58,6 +62,7 @@ def enumerate(iterable: Iterable[_T], start: int = 0) -> Iterable[tuple[int, _T] @substitute_in_graph(builtins.sum, can_constant_fold_through=True) # type: ignore[arg-type] def sum(iterable: Iterable[_T], /, start: _T = 0) -> _T: # type: ignore[assignment] return functools.reduce(operator.add, iterable, start) +<<<<<<< HEAD class _CallableIterator: @@ -120,3 +125,5 @@ def sequence_protocol(iterable): # type: ignore[no-untyped-def] raise TypeError("iter(v, w): v must be a callable") return _CallableIterator(fn, sentinel) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_dynamo/polyfills/itertools.py b/torch/_dynamo/polyfills/itertools.py index 2b64327b93de9..901674927acff 100644 --- a/torch/_dynamo/polyfills/itertools.py +++ b/torch/_dynamo/polyfills/itertools.py @@ -5,9 +5,14 @@ from __future__ import annotations import itertools +<<<<<<< HEAD import operator import sys from typing import Callable, Optional, overload, TYPE_CHECKING, TypeVar +======= +import sys +from typing import Callable, overload, TYPE_CHECKING, TypeVar +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from typing_extensions import TypeAlias from ..decorators import substitute_in_graph @@ -18,6 +23,7 @@ __all__ = [ +<<<<<<< HEAD "accumulate", "chain", "chain_from_iterable", @@ -25,6 +31,12 @@ "cycle", "dropwhile", "filterfalse", +======= + "chain", + "chain_from_iterable", + "compress", + "dropwhile", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "islice", "tee", "zip_longest", @@ -45,6 +57,7 @@ def chain(*iterables: Iterable[_T]) -> Iterator[_T]: yield from iterable +<<<<<<< HEAD # Reference: https://docs.python.org/3/library/itertools.html#itertools.accumulate @substitute_in_graph(itertools.accumulate, is_embedded_type=True) # type: ignore[arg-type] def accumulate( @@ -81,6 +94,11 @@ def chain_from_iterable(iterable: Iterable[Iterable[_T]], /) -> Iterator[_T]: # If iterable is an infinite generator, this will lead to infinite recursion for it in iterable: yield from it +======= +@substitute_in_graph(itertools.chain.from_iterable) # type: ignore[arg-type] +def chain_from_iterable(iterable: Iterable[Iterable[_T]], /) -> Iterator[_T]: + return itertools.chain(*iterable) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) chain.from_iterable = chain_from_iterable # type: ignore[attr-defined] @@ -92,6 +110,7 @@ def compress(data: Iterable[_T], selectors: Iterable[_U], /) -> Iterator[_T]: return (datum for datum, selector in zip(data, selectors) if selector) +<<<<<<< HEAD # Reference: https://docs.python.org/3/library/itertools.html#itertools.cycle @substitute_in_graph(itertools.cycle, is_embedded_type=True) # type: ignore[arg-type] def cycle(iterable: Iterable[_T]) -> Iterator[_T]: @@ -110,6 +129,8 @@ def _cycle(iterator: Iterator[_T]) -> Iterator[_T]: return _cycle(iterator) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Reference: https://docs.python.org/3/library/itertools.html#itertools.dropwhile @substitute_in_graph(itertools.dropwhile, is_embedded_type=True) # type: ignore[arg-type] def dropwhile(predicate: _Predicate[_T], iterable: Iterable[_T], /) -> Iterator[_T]: @@ -124,6 +145,7 @@ def dropwhile(predicate: _Predicate[_T], iterable: Iterable[_T], /) -> Iterator[ yield from iterator +<<<<<<< HEAD @substitute_in_graph(itertools.filterfalse, is_embedded_type=True) # type: ignore[arg-type] def filterfalse(function: _Predicate[_T], iterable: Iterable[_T], /) -> Iterator[_T]: it = iter(iterable) @@ -133,6 +155,8 @@ def filterfalse(function: _Predicate[_T], iterable: Iterable[_T], /) -> Iterator return filter(lambda x: not function(x), it) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Reference: https://docs.python.org/3/library/itertools.html#itertools.islice @substitute_in_graph(itertools.islice, is_embedded_type=True) # type: ignore[arg-type] def islice(iterable: Iterable[_T], /, *args: int | None) -> Iterator[_T]: diff --git a/torch/_dynamo/polyfills/loader.py b/torch/_dynamo/polyfills/loader.py index d348a422ff576..fdbd1abc2a9a0 100644 --- a/torch/_dynamo/polyfills/loader.py +++ b/torch/_dynamo/polyfills/loader.py @@ -13,14 +13,20 @@ # See also the TYPE_CHECKING block in torch/_dynamo/polyfills/__init__.py POLYFILLED_MODULE_NAMES: tuple[str, ...] = ( +<<<<<<< HEAD "_collections", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "builtins", "functools", "itertools", "operator", "os", "pytree", +<<<<<<< HEAD "struct", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "sys", "fx", "tensor", diff --git a/torch/_dynamo/polyfills/operator.py b/torch/_dynamo/polyfills/operator.py index 4ce889b297c9f..6dfc5b4971894 100644 --- a/torch/_dynamo/polyfills/operator.py +++ b/torch/_dynamo/polyfills/operator.py @@ -5,18 +5,27 @@ from __future__ import annotations import operator +<<<<<<< HEAD from typing import Any, Callable, overload, TYPE_CHECKING, TypeVar +======= +from typing import Any, Callable, overload, TypeVar +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from typing_extensions import TypeVarTuple, Unpack from ..decorators import substitute_in_graph +<<<<<<< HEAD if TYPE_CHECKING: from collections.abc import Iterable # Most unary and binary operators are handled by BuiltinVariable (e.g., `pos`, `add`) __all__ = ["attrgetter", "itemgetter", "methodcaller", "countOf"] +======= +# Most unary and binary operators are handled by BuiltinVariable (e.g., `pos`, `add`) +__all__ = ["attrgetter", "itemgetter", "methodcaller"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _T = TypeVar("_T") @@ -107,9 +116,12 @@ def caller(obj: Any) -> Any: return getattr(obj, name)(*args, **kwargs) return caller +<<<<<<< HEAD # Reference: https://docs.python.org/3/library/operator.html#operator.countOf @substitute_in_graph(operator.countOf, can_constant_fold_through=True) # type: ignore[arg-type,misc] def countOf(a: Iterable[_T], b: _T, /) -> int: return sum(it is b or it == b for it in a) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_dynamo/polyfills/sys.py b/torch/_dynamo/polyfills/sys.py index ab666c385806f..b4aee1817ee0d 100644 --- a/torch/_dynamo/polyfills/sys.py +++ b/torch/_dynamo/polyfills/sys.py @@ -23,6 +23,7 @@ def intern(string: str, /) -> str: @substitute_in_graph(sys.getrecursionlimit, can_constant_fold_through=True) def getrecursionlimit() -> int: return sys.getrecursionlimit() +<<<<<<< HEAD if hasattr(sys, "get_int_max_str_digits"): @@ -32,3 +33,5 @@ def get_int_max_str_digits() -> int: return sys.get_int_max_str_digits() __all__ += ["get_int_max_str_digits"] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_dynamo/precompile_context.py b/torch/_dynamo/precompile_context.py index 38f97e583375d..82ad37dd3fef6 100644 --- a/torch/_dynamo/precompile_context.py +++ b/torch/_dynamo/precompile_context.py @@ -1,3 +1,4 @@ +<<<<<<< HEAD import copy import logging import pickle @@ -5,6 +6,11 @@ from collections import defaultdict from itertools import chain from typing import Any, Callable, Generic, Optional, TypeVar, Union +======= +from abc import abstractmethod +from collections import defaultdict +from typing import Any, Generic, Optional, TypeVar +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from typing_extensions import override from torch.compiler._cache import ( @@ -24,7 +30,10 @@ """ T = TypeVar("T") +<<<<<<< HEAD logger = logging.getLogger(__name__) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class PrecompileCacheArtifact(CacheArtifact, Generic[T]): @@ -65,6 +74,7 @@ def after_deserialization(self) -> T: ... +<<<<<<< HEAD class EditablePrecompileCacheArtifact(Generic[T]): """ A PrecompileCacheArtifact whose content isn't encoded until we call PrecompileContext.serialize() @@ -95,6 +105,8 @@ def edit_contents(self, edit_fn: Callable[..., Any]) -> None: self.content = edit_fn(self.content) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class PrecompileContext(CacheArtifactManager): """ PrecompileContext is a special CacheArtifactManager for handling precompilation @@ -104,8 +116,12 @@ class PrecompileContext(CacheArtifactManager): The following artifact types are supported by PrecompileContext: - BundledAOTAutogradCacheArtifact +<<<<<<< HEAD - DynamoCodeStateArtifact - AutotuneCacheArtifact (regular autotune results, same as Megacache) +======= + - CodeStateArtifact (from torch._dynamo.package once available) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ # Protected by the compile_lock @@ -113,9 +129,13 @@ class PrecompileContext(CacheArtifactManager): # This allows us to implement serialize_by_key easily. # On call to `serialize()`, all cache artifacts in _new_cache_artifacts_by_key # are transferred to _new_cache_artifacts before serialization. +<<<<<<< HEAD _new_cache_artifacts_by_key: dict[ str, Union[EditablePrecompileCacheArtifact[object], CacheArtifact] ] = {} +======= + _new_cache_artifacts_by_key: dict[str, CacheArtifact] = {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _new_cache_artifacts: CacheArtifactsResult = defaultdict(list) # Keep a separate seen artifacts list to make avoid unnecessary duplicates # This list will not be cleared between serialize() calls @@ -140,12 +160,16 @@ def record_artifact( artifact_type: str, key: str, content: Any, +<<<<<<< HEAD editable: bool = False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: """ Called from each caching operation to record the artifact in this "mega" list """ +<<<<<<< HEAD artifact: Union[EditablePrecompileCacheArtifact[object], CacheArtifact] if editable: artifact = EditablePrecompileCacheArtifact(artifact_type, content, key) @@ -161,6 +185,19 @@ def record_artifact( cls._seen_artifacts.add(artifact) cls._new_cache_artifacts_by_key[key] = artifact +======= + artifact = CacheArtifactFactory.encode_create(artifact_type, key, content) + # TODO: although this covers completely same artifacts, it's possible + # with AOTAutogradCacheEntries to have multiple artifacts whose keys + # (i.e. backend_ids) are different, but whose contents are equal. + # In those cases, it would be much better if we only serialize once instead + # of N times. + if artifact in cls._seen_artifacts: + return + + cls._new_cache_artifacts_by_key[key] = artifact + cls._seen_artifacts.add(artifact) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @classmethod def _save_artifacts_by_type(cls) -> None: @@ -169,12 +206,16 @@ def _save_artifacts_by_type(cls) -> None: by artifact type. This function transfers artifacts from _new_cache_artifacts_by_key to _new_cache_artifacts """ for artifact in cls._new_cache_artifacts_by_key.values(): +<<<<<<< HEAD if isinstance(artifact, EditablePrecompileCacheArtifact): artifact = artifact.real_encode() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cls._new_cache_artifacts[artifact.__class__.type()].append(artifact) cls._new_cache_artifacts_by_key.clear() @classmethod +<<<<<<< HEAD def edit_artifact(cls, key: str, edit_fn: Callable[..., Any]) -> None: """ Edit the content of an existing artifact @@ -189,25 +230,35 @@ def edit_artifact(cls, key: str, edit_fn: Callable[..., Any]) -> None: artifact.edit_contents(edit_fn) @classmethod +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def serialize_artifact_by_key(cls, key: str) -> Optional[CacheArtifact]: """ Serialize all artifacts with the given key returned in a list. """ +<<<<<<< HEAD result = cls._new_cache_artifacts_by_key.get(key, None) if isinstance(result, EditablePrecompileCacheArtifact): result = result.real_encode() return result +======= + return cls._new_cache_artifacts_by_key.get(key, None) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @classmethod def serialize(cls) -> Optional[tuple[bytes, CacheInfo]]: cls._save_artifacts_by_type() +<<<<<<< HEAD # No need to serialize if there are no new dynamo compiles if "precompile_dynamo" not in cls._new_cache_artifacts: return None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return super().serialize() @staticmethod def populate_caches(artifacts: CacheArtifactsResult) -> CacheInfo: +<<<<<<< HEAD PrecompileContext._ensure_cache_artifacts_registered() artifacts_by_key = {} @@ -240,6 +291,12 @@ def populate_caches(artifacts: CacheArtifactsResult) -> CacheInfo: @classmethod def _ensure_cache_artifacts_registered(cls) -> None: from torch._dynamo.package import _DynamoCacheArtifact # noqa: F401 +======= + raise NotImplementedError("TODO") + + @classmethod + def _ensure_cache_artifacts_registered(cls) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._functorch._aot_autograd.autograd_cache import ( # noqa: F401 BundledAOTAutogradCacheArtifact, ) diff --git a/torch/_dynamo/profiler.py b/torch/_dynamo/profiler.py index 2055507f72a4c..9426afce20f63 100644 --- a/torch/_dynamo/profiler.py +++ b/torch/_dynamo/profiler.py @@ -12,8 +12,11 @@ by tracking both captured and total operations, timing, and graph statistics. """ +<<<<<<< HEAD from __future__ import annotations +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import dataclasses import os from typing import Any @@ -37,7 +40,11 @@ def __iadd__(self, other: Self) -> Self: self.fusions += other.fusions return self +<<<<<<< HEAD def __add__(self, other: ProfileMetrics) -> ProfileMetrics: +======= + def __add__(self, other: "ProfileMetrics") -> "ProfileMetrics": +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(other, ProfileMetrics) return ProfileMetrics( self.microseconds + other.microseconds, @@ -45,7 +52,11 @@ def __add__(self, other: ProfileMetrics) -> ProfileMetrics: self.fusions + other.fusions, ) +<<<<<<< HEAD def __truediv__(self, other: Any) -> ProfileMetrics: +======= + def __truediv__(self, other: Any) -> "ProfileMetrics": +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(other, int): other = ProfileMetrics(other, other, other) return ProfileMetrics( diff --git a/torch/_dynamo/replay_record.py b/torch/_dynamo/replay_record.py index 5d01217fdbb61..4e163dd209fcd 100644 --- a/torch/_dynamo/replay_record.py +++ b/torch/_dynamo/replay_record.py @@ -15,9 +15,14 @@ import dataclasses from dataclasses import field +<<<<<<< HEAD from io import BufferedReader, BufferedWriter from types import CellType, CodeType, ModuleType from typing import Any, IO, Union +======= +from types import CellType, CodeType, ModuleType +from typing import Any, IO +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from typing_extensions import Self from torch.utils._import_utils import import_dill @@ -52,12 +57,20 @@ class ExecutionRecord: builtins: dict[str, Any] = field(default_factory=dict) code_options: dict[str, Any] = field(default_factory=dict) +<<<<<<< HEAD def dump(self, f: Union[IO[str], BufferedWriter]) -> None: +======= + def dump(self, f: IO[str]) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert dill is not None, "replay_record requires `pip install dill`" dill.dump(self, f) @classmethod +<<<<<<< HEAD def load(cls, f: Union[IO[bytes], BufferedReader]) -> Self: +======= + def load(cls, f: IO[bytes]) -> Self: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert dill is not None, "replay_record requires `pip install dill`" return dill.load(f) diff --git a/torch/_dynamo/repro/after_aot.py b/torch/_dynamo/repro/after_aot.py index 998acc7397753..a33bffda38c44 100644 --- a/torch/_dynamo/repro/after_aot.py +++ b/torch/_dynamo/repro/after_aot.py @@ -1,3 +1,8 @@ +<<<<<<< HEAD +======= +# mypy: allow-untyped-defs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Utilities for reproducing and debugging issues in PyTorch's Dynamo AOT compilation. @@ -17,8 +22,11 @@ the Dynamo AOT compilation pipeline, particularly for the Inductor backend. """ +<<<<<<< HEAD from __future__ import annotations +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import argparse import copy import functools @@ -30,6 +38,7 @@ import sys import textwrap import uuid +<<<<<<< HEAD from importlib import import_module from tempfile import TemporaryFile from typing import Any, Callable, IO, Optional, TYPE_CHECKING, Union @@ -51,6 +60,14 @@ class Heuristics: # type: ignore[no-redef] pass +======= +from collections.abc import Sequence +from importlib import import_module +from tempfile import TemporaryFile +from typing import Any, Callable, TYPE_CHECKING, Union +from typing_extensions import Unpack + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch.fx as fx import torch.nn as nn @@ -75,10 +92,15 @@ class Heuristics: # type: ignore[no-redef] ) from torch._dynamo.utils import clone_inputs, counters, same from torch._environment import is_fbcode +<<<<<<< HEAD from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table from torch._inductor.cpp_builder import normalize_path_separator from torch._library.fake_class_registry import FakeScriptObject from torch._ops import OpOverload +======= +from torch._inductor.output_code import OutputCode +from torch._library.fake_class_registry import FakeScriptObject +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.experimental.symbolic_shapes import ( fx_placeholder_targets, @@ -90,10 +112,14 @@ class Heuristics: # type: ignore[no-redef] if TYPE_CHECKING: +<<<<<<< HEAD from collections.abc import Sequence from torch._inductor.compile_fx import _CompileFxCallable, _CompileFxKwargs from torch._inductor.output_code import OutputCode +======= + from torch._inductor.compile_fx import _CompileFxCallable, _CompileFxKwargs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.utils import InputType @@ -109,9 +135,15 @@ class Heuristics: # type: ignore[no-redef] def wrap_compiler_debug( +<<<<<<< HEAD unconfigured_compiler_fn: _CompileFxCallable, compiler_name: str, ) -> _CompileFxCallable: +======= + unconfigured_compiler_fn: "_CompileFxCallable", + compiler_name: str, +) -> "_CompileFxCallable": +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Minifier for Fx Graph modules after Aot Autograd has finished. We wrap both forward and backward call separately with the backend compiler_fn - like @@ -123,8 +155,13 @@ def wrap_compiler_debug( @functools.wraps(unconfigured_compiler_fn) def debug_wrapper( gm: torch.fx.GraphModule, +<<<<<<< HEAD example_inputs: Sequence[InputType], **kwargs: Unpack[_CompileFxKwargs], +======= + example_inputs: Sequence["InputType"], + **kwargs: Unpack["_CompileFxKwargs"], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> OutputCode: from torch._subclasses import FakeTensorMode @@ -164,7 +201,11 @@ def debug_wrapper( # We may run regular PyTorch compute that may trigger Dynamo, do NOT # recursively attempt to accuracy minify in that case! def deferred_for_real_inputs( +<<<<<<< HEAD real_inputs: Sequence[InputType], **_kwargs: object +======= + real_inputs: Sequence["InputType"], **_kwargs: object +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Any: # This is a bit obscure: if we recursively try to accuracy minify # the SAME function, this would trigger. But most of the time @@ -176,7 +217,11 @@ def deferred_for_real_inputs( with config.patch(repro_after=None): return inner_debug_fn(real_inputs) +<<<<<<< HEAD def inner_debug_fn(real_inputs: Sequence[InputType]) -> Any: +======= + def inner_debug_fn(real_inputs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Aot Autograd fw_compiler and bw_compiler can have fake tensors. So, example_inputs can be fake tensors. We can call compiler_fn (which is @@ -205,7 +250,11 @@ def inner_debug_fn(real_inputs: Sequence[InputType]) -> Any: ) failed = not same_two_models( gm, +<<<<<<< HEAD inner_compiled_fn, # type: ignore[arg-type] +======= + inner_compiled_fn, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) real_inputs, only_fwd=True, ignore_non_fp=config.repro_ignore_non_fp, @@ -269,7 +318,11 @@ def inner_debug_fn(real_inputs: Sequence[InputType]) -> Any: # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +<<<<<<< HEAD def maybe_fbcode_instructions() -> str: +======= +def maybe_fbcode_instructions(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if is_fbcode(): extra_deps_formatted = "\n".join([f' "{dep}",' for dep in extra_deps]) if len(extra_deps_formatted) > 0: @@ -302,6 +355,7 @@ def maybe_fbcode_instructions() -> str: def generate_compiler_repro_string( +<<<<<<< HEAD gm: torch.fx.GraphModule, args: Sequence[Any], *, @@ -332,6 +386,10 @@ def generate_compiler_repro_string( """ ).strip() +======= + gm, args, *, stable_output=False, save_dir=None, stable_hash=False +): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) model_str = textwrap.dedent( f""" {generate_env_vars_string(stable_output=stable_output)} @@ -341,8 +399,11 @@ def generate_compiler_repro_string( from torch._dynamo.testing import rand_strided from math import inf import torch._inductor.inductor_prims +<<<<<<< HEAD {distributed_imports} {triton_imports} +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {generate_config_string(stable_output=stable_output)} @@ -351,7 +412,11 @@ def generate_compiler_repro_string( {extra_imports} {maybe_fbcode_instructions()} +<<<<<<< HEAD """ +======= + """ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if not stable_output: model_str += f"# torch version: {torch.version.__version__}\n" @@ -361,6 +426,7 @@ def generate_compiler_repro_string( model_str += f"# torch git version: {torch.version.git_version}\n\n\n" model_str += _cuda_system_info_comment() +<<<<<<< HEAD kernel_side_table_prefix = ( "torch._higher_order_ops.triton_kernel_wrap.kernel_side_table" ) @@ -420,6 +486,16 @@ def generate_compiler_repro_string( # Extract from graph placeholders and their corresponding arguments placeholder_targets = fx_placeholder_targets(gm) for placeholder, arg in zip(placeholder_targets, args): +======= + model_str += NNModuleToString.convert(gm) + + # get hint shape/stride when dynamic shape enabled + def hint_if_symint(x): + return tuple(i.node.hint if isinstance(i, torch.SymInt) else i for i in x) + + writer = InputWriter(save_dir, stable_hash=stable_hash) + for placeholder, arg in zip(fx_placeholder_targets(gm), args): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(arg, (int, torch.SymInt)): writer.symint(placeholder, arg) elif isinstance(arg, torch.Tensor): @@ -428,6 +504,7 @@ def generate_compiler_repro_string( elif arg is None: writer.const(placeholder) else: +<<<<<<< HEAD writer.unsupported(placeholder, arg) # Extract symbolic variables from the same arguments @@ -454,12 +531,20 @@ def generate_compiler_repro_string( load_args_lines = writer.lines() load_args_code = "\n".join(load_args_lines) model_str += load_args_code + "\n" +======= + # It's better to produce a slightly wrong repro string than none + # at all + writer.unsupported(placeholder, arg) + + model_str += "\n".join(writer.lines()) + "\n" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) model_str += "mod = Repro()\n" return model_str def save_graph_repro( +<<<<<<< HEAD fd: IO[Any], gm: torch.fx.GraphModule, args: Sequence[Any], @@ -473,6 +558,21 @@ def save_graph_repro( check_str: Optional[str] = None, stable_hash: bool = False, ) -> None: +======= + fd, + gm, + args, + compiler_name, + *, + stable_output=False, + save_dir=None, + command="run", + accuracy=None, + tracing_mode=None, + check_str=None, + stable_hash=False, +): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if any( isinstance(arg, torch.fx.experimental._backward_state.BackwardState) for arg in args @@ -482,6 +582,7 @@ def save_graph_repro( ) return +<<<<<<< HEAD if save_dir is not None: save_dir = normalize_path_separator(save_dir) @@ -493,6 +594,8 @@ def save_graph_repro( for node in gm.graph.nodes ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fd.write( generate_compiler_repro_string( gm, @@ -500,7 +603,10 @@ def save_graph_repro( stable_output=stable_output, save_dir=save_dir, stable_hash=stable_hash, +<<<<<<< HEAD has_distributed_ops=has_distributed_ops, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) ) if accuracy is None: @@ -513,6 +619,7 @@ def save_graph_repro( tracing_mode = "symbolic" fd.write("if __name__ == '__main__':\n") fd.write(" from torch._dynamo.repro.after_aot import run_repro\n") +<<<<<<< HEAD # Add distributed initialization before run_repro if needed if has_distributed_ops: @@ -527,6 +634,8 @@ def save_graph_repro( " )\n" ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fd.write( f" with torch.no_grad():\n" f" run_repro(mod, load_args, accuracy={accuracy!r}, command={command!r}, " @@ -537,6 +646,7 @@ def save_graph_repro( f" # mod(*args)" ) +<<<<<<< HEAD # Add distributed cleanup after run_repro if has_distributed_ops: fd.write("\n dist.destroy_process_group()\n") @@ -549,6 +659,10 @@ def dump_compiler_graph_state( *, accuracy: Optional[Union[str, bool]] = None, ) -> None: +======= + +def dump_compiler_graph_state(gm, args, compiler_name, *, accuracy=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) subdir = os.path.join(minifier_dir(), "checkpoints") if not os.path.exists(subdir): os.makedirs(subdir, exist_ok=True) @@ -576,9 +690,13 @@ def dump_compiler_graph_state( # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +<<<<<<< HEAD def dump_to_minify( gm: torch.fx.GraphModule, args: Sequence[Any], compiler_name: str ) -> None: +======= +def dump_to_minify(gm, args, compiler_name: str): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out = io.StringIO() # TODO: factor this out subdir = os.path.join(minifier_dir(), "checkpoints") @@ -589,6 +707,7 @@ def dump_to_minify( def isolate_fails( +<<<<<<< HEAD fx_g: torch.fx.GraphModule, args: Sequence[Any], compiler_name: str, @@ -598,6 +717,17 @@ def isolate_fails( tracing_mode: Optional[str] = None, check_str: Optional[str] = None, ) -> bool: +======= + fx_g, + args, + compiler_name: str, + env=None, + save_dir=None, + accuracy=None, + tracing_mode=None, + check_str=None, +): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if env is None: env = {} subdir = os.path.join(os.getcwd(), "isolate") @@ -653,16 +783,24 @@ def isolate_fails( # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +<<<<<<< HEAD def inductor_fails( fx_g: torch.fx.GraphModule, args: Sequence[Any], check_str: Optional[str] = None ) -> bool: +======= +def inductor_fails(fx_g, args, check_str=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) has_cuda = False for arg in args: if isinstance(arg, torch.Tensor) and arg.is_cuda: has_cuda = True break +<<<<<<< HEAD def sync() -> None: +======= + def sync(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if has_cuda: # Ensures that segfaults are surfaced torch.cuda.synchronize() @@ -692,6 +830,7 @@ def sync() -> None: def inductor_accuracy_fails( +<<<<<<< HEAD fx_g: torch.fx.GraphModule, args: Sequence[Any], check_str: Optional[str] = None, @@ -699,12 +838,21 @@ def inductor_accuracy_fails( require_fp64: bool = False, ignore_non_fp: bool = False, ) -> bool: +======= + fx_g, args, check_str=None, *, require_fp64=False, ignore_non_fp=False +): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.compile_fx import compile_fx_inner return backend_aot_accuracy_fails( fx_g, +<<<<<<< HEAD args, # type: ignore[arg-type] compile_fx_inner, # type: ignore[arg-type] +======= + args, + compile_fx_inner, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) require_fp64=require_fp64, ignore_non_fp=ignore_non_fp, ) @@ -718,9 +866,13 @@ def inductor_accuracy_fails( # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +<<<<<<< HEAD def repro_common( options: Any, mod: nn.Module, load_args: Any ) -> tuple[torch.fx.GraphModule, Sequence[Any]]: +======= +def repro_common(options, mod, load_args): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Invariant for graphs we generate with the repro script assert not any(mod.named_parameters()) for n, b in mod.named_buffers(): @@ -763,7 +915,11 @@ def repro_common( return mod, args +<<<<<<< HEAD ACCURACY_FAILS: dict[str, Callable[[torch.fx.GraphModule, Any], bool]] = { +======= +ACCURACY_FAILS: dict[str, Callable[[nn.Module, Any], bool]] = { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "": inductor_fails, # This might look inverted but it's not. strict_accuracy means "we will # minify any time we see anything that diverges", whereas accuracy is more @@ -776,7 +932,11 @@ def repro_common( } +<<<<<<< HEAD def repro_minifier_query(options: Any, mod: nn.Module, load_args: Any) -> None: +======= +def repro_minifier_query(options, mod, load_args): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mod, args = repro_common(options, mod, load_args) fail_fn = functools.partial( ACCURACY_FAILS[options.accuracy], @@ -788,7 +948,11 @@ def repro_minifier_query(options: Any, mod: nn.Module, load_args: Any) -> None: sys.exit(0) +<<<<<<< HEAD def repro_minify(options: Any, mod: nn.Module, load_args: Any) -> None: +======= +def repro_minify(options, mod, load_args): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from functorch.compile import minifier mod, args = repro_common(options, mod, load_args) @@ -825,7 +989,11 @@ def repro_minify(options: Any, mod: nn.Module, load_args: Any) -> None: ) +<<<<<<< HEAD def repro_analyze(options: Any, mod: nn.Module, load_args: Any) -> None: +======= +def repro_analyze(options, mod, load_args): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.compile_fx import compile_fx_inner from torch._inductor.hooks import intermediate_hook @@ -843,7 +1011,11 @@ def repro_analyze(options: Any, mod: nn.Module, load_args: Any) -> None: known_names = set() +<<<<<<< HEAD def save_hook(name: str, val: Any) -> None: +======= + def save_hook(name, val): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) known_names.add(name) if not options.skip_saving_inductor_intermediates: writer.write_tensor(os.path.join("inductor", name), val) @@ -860,10 +1032,17 @@ def save_hook(name: str, val: Any) -> None: tqdm(desc="Saving inductor intermediates", total=total) as pbar, ): assert not isinstance(compiled, str) +<<<<<<< HEAD compiled(new_args) # type: ignore[arg-type] assert not new_args def compare_tuples(tuple1: tuple[Any], tuple2: tuple[Any]) -> Optional[str]: +======= + compiled(new_args) + assert not new_args + + def compare_tuples(tuple1, tuple2): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff_indices = [i for i in range(len(tuple1)) if tuple1[i] != tuple2[i]] diff_values = [(tuple1[i], tuple2[i]) for i in diff_indices] @@ -872,7 +1051,11 @@ def compare_tuples(tuple1: tuple[Any], tuple2: tuple[Any]) -> Optional[str]: else: return " and ".join(f"{a} != {b}" for a, b in diff_values) +<<<<<<< HEAD def check_hook(name: str, val: Any) -> None: +======= + def check_hook(name, val): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) meta = writer.compute_tensor_metadata(val) meta2 = reader.read_tensor_metadata(os.path.join("inductor", name)) reason = compare_tuples(meta, meta2) @@ -886,6 +1069,7 @@ def check_hook(name: str, val: Any) -> None: intermediate_hook(check_hook), tqdm(desc="Checking inductor determinism", total=total) as pbar, ): +<<<<<<< HEAD compiled(new_args) # type: ignore[arg-type] assert not new_args @@ -895,6 +1079,17 @@ def __init__(self, mod: torch.nn.Module, subdir: str) -> None: self.subdir = subdir def run_node(self, n: torch.fx.Node) -> Any: +======= + compiled(new_args) + assert not new_args + + class WriterInterp(fx.Interpreter): + def __init__(self, mod, subdir) -> None: + super().__init__(mod) + self.subdir = subdir + + def run_node(self, n): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r = super().run_node(n) name = n.name if name in known_names: @@ -905,13 +1100,21 @@ def run_node(self, n: torch.fx.Node) -> Any: # NB: the module cast doesn't actually do anything, since there are no # parameters/buffers on the module if not options.skip_saving_float64_intermediates: +<<<<<<< HEAD new_mod, new_args = cast_to_fp64(copy.deepcopy(mod), clone_inputs(args)) # type: ignore[arg-type] +======= + new_mod, new_args = cast_to_fp64(copy.deepcopy(mod), clone_inputs(args)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with tqdm(desc="Saving float64 intermediates", total=total) as pbar: WriterInterp(new_mod, "float64").boxed_run(new_args) assert not new_args class ExactReaderInterp(fx.Interpreter): +<<<<<<< HEAD def run_node(self, n: torch.fx.Node) -> Any: +======= + def run_node(self, n): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r = super().run_node(n) name = n.name if name in known_names: @@ -926,7 +1129,11 @@ def run_node(self, n: torch.fx.Node) -> Any: # TODO: check eager determinism if not options.skip_check_deterministic: +<<<<<<< HEAD new_mod, new_args = cast_to_fp64(copy.deepcopy(mod), clone_inputs(args)) # type: ignore[arg-type] +======= + new_mod, new_args = cast_to_fp64(copy.deepcopy(mod), clone_inputs(args)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with tqdm(desc="Checking float64 determinism", total=total) as pbar: ExactReaderInterp(new_mod).boxed_run(new_args) assert not new_args @@ -934,7 +1141,11 @@ def run_node(self, n: torch.fx.Node) -> Any: # Now that we've saved everything, interp through the eager graph # and do comparisons class ReaderInterp(fx.Interpreter): +<<<<<<< HEAD def run_node(self, n: torch.fx.Node) -> Any: +======= + def run_node(self, n): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r = super().run_node(n) name = n.name if name in known_names: @@ -942,7 +1153,11 @@ def run_node(self, n: torch.fx.Node) -> Any: float64 = reader.read_tensor(os.path.join("float64", name)) logged = False +<<<<<<< HEAD def log_error(msg: str, *args: Any) -> None: +======= + def log_error(msg, *args): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) nonlocal logged logged = True pbar.write(f"DIVERGED at {name}: {msg % args}") @@ -964,6 +1179,7 @@ def log_error(msg: str, *args: Any) -> None: assert not args +<<<<<<< HEAD def repro_get_args( options: Any, mod: nn.Module, load_args: Any ) -> tuple[torch.fx.GraphModule, list[Any]]: @@ -972,6 +1188,14 @@ def repro_get_args( def repro_run(options: Any, mod: nn.Module, load_args: Any) -> None: +======= +def repro_get_args(options, mod, load_args): + mod, args = repro_common(options, mod, load_args) + return mod, args + + +def repro_run(options, mod, load_args): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.compile_fx import compile_fx_inner mod, args = repro_common(options, mod, load_args) @@ -986,7 +1210,11 @@ def repro_run(options: Any, mod: nn.Module, load_args: Any) -> None: # seems counterintuitive if not same_two_models( mod, +<<<<<<< HEAD compiled, # type: ignore[arg-type] +======= + compiled, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) args, only_fwd=True, ignore_non_fp=config.repro_ignore_non_fp, @@ -1008,6 +1236,7 @@ def repro_run(options: Any, mod: nn.Module, load_args: Any) -> None: # TODO: lazily load the inputs or something, rather than cloning them def run_repro( +<<<<<<< HEAD mod: nn.Module, load_args: Any, *, @@ -1019,6 +1248,19 @@ def run_repro( check_str: Optional[str] = None, **kwargs: Any, ) -> Any: +======= + mod, + load_args, + *, + command="run", + accuracy: Union[bool, str] = "", + save_dir=None, + tracing_mode=None, + patch_code=None, + check_str=None, + **kwargs, +): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for k in kwargs: log.warning( "Unrecognized kwarg %s; perhaps this repro was made on a newer version of PyTorch", @@ -1051,7 +1293,11 @@ def run_repro( formatter_class=argparse.RawTextHelpFormatter, ) +<<<<<<< HEAD def common_flags(parser: argparse.ArgumentParser) -> None: +======= + def common_flags(parser): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) accuracy_group = parser.add_mutually_exclusive_group() accuracy_group.add_argument( "--no-accuracy", diff --git a/torch/_dynamo/repro/after_dynamo.py b/torch/_dynamo/repro/after_dynamo.py index 65b9fc2eaa35d..7cd9798f00093 100644 --- a/torch/_dynamo/repro/after_dynamo.py +++ b/torch/_dynamo/repro/after_dynamo.py @@ -1,3 +1,8 @@ +<<<<<<< HEAD +======= +# mypy: allow-untyped-defs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Utilities for reproducing and debugging issues in Dynamo after graph capture. @@ -24,12 +29,21 @@ import shutil import sys import textwrap +<<<<<<< HEAD from collections.abc import Sequence from importlib import import_module from typing import Any, Callable, Optional, Union import torch import torch.fx as fx +======= +from importlib import import_module +from typing import Union + +import torch +import torch.fx as fx +from torch._dynamo.backends.registry import CompiledFn +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._dynamo.debug_utils import ( AccuracyError, backend_accuracy_fails, @@ -51,7 +65,11 @@ from torch.hub import tqdm from .. import config +<<<<<<< HEAD from ..backends.registry import CompilerFn, lookup_backend, register_debug_backend +======= +from ..backends.registry import lookup_backend, register_debug_backend +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from ..debug_utils import clone_inputs_retaining_gradness @@ -66,11 +84,15 @@ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +<<<<<<< HEAD def _accuracy_fails( gm: torch.fx.GraphModule, example_inputs: Sequence[Any], compiler_fn: Callable[[torch.fx.GraphModule, list[Any]], torch.fx.GraphModule], ) -> bool: +======= +def _accuracy_fails(gm, example_inputs, compiler_fn): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return backend_accuracy_fails( gm, example_inputs, @@ -81,15 +103,22 @@ def _accuracy_fails( class WrapBackendDebug: +<<<<<<< HEAD def __init__( self, unconfigured_compiler_fn: CompilerFn, compiler_name: Optional[str] ) -> None: functools.wraps(unconfigured_compiler_fn)(self) self._torchdynamo_orig_backend = unconfigured_compiler_fn +======= + def __init__(self, unconfigured_compiler_fn, compiler_name: str) -> None: + functools.wraps(unconfigured_compiler_fn)(self) + self._torchdynamo_orig_callable = unconfigured_compiler_fn # type: ignore[attr-defined] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._compiler_name = compiler_name if hasattr(unconfigured_compiler_fn, "__name__"): self.__name__ = unconfigured_compiler_fn.__name__ if hasattr(unconfigured_compiler_fn, "compiler_name"): +<<<<<<< HEAD self.__name__ = unconfigured_compiler_fn.compiler_name # type: ignore[attr-defined] if hasattr(unconfigured_compiler_fn, "get_compiler_config"): self.get_compiler_config = unconfigured_compiler_fn.get_compiler_config # type: ignore[attr-defined] @@ -98,16 +127,33 @@ def __call__( self, gm: torch.fx.GraphModule, example_inputs: list[Any], **kwargs: Any ) -> torch.fx.GraphModule: compiler_fn = functools.partial(self._torchdynamo_orig_backend, **kwargs) +======= + self.__name__ = unconfigured_compiler_fn.compiler_name + if hasattr(unconfigured_compiler_fn, "get_compiler_config"): + self.get_compiler_config = unconfigured_compiler_fn.get_compiler_config # type: ignore[attr-defined] + + def __call__(self, gm, example_inputs, **kwargs): + compiler_fn = functools.partial(self._torchdynamo_orig_callable, **kwargs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert config.repro_after in ("dynamo", "aot", None) if config.repro_after == "dynamo": +<<<<<<< HEAD def add_paths(exc: Exception) -> None: exc.minifier_path = os.path.join(minifier_dir(), "minifier_launcher.py") # type: ignore[attr-defined] if use_buck: exc.buck_command = " ".join( # type: ignore[attr-defined] BUCK_CMD_PREFIX + [BuckTargetWriter(exc.minifier_path).cmd_line_path] # type: ignore[attr-defined] +======= + def add_paths(exc): + exc.minifier_path = os.path.join(minifier_dir(), "minifier_launcher.py") + if use_buck: + exc.buck_command = " ".join( + BUCK_CMD_PREFIX + + [BuckTargetWriter(exc.minifier_path).cmd_line_path] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if config.repro_level == 3: @@ -117,7 +163,11 @@ def add_paths(exc: Exception) -> None: if config.repro_level == 4: # Check Accuracy compiled_gm = compiler_fn(copy.deepcopy(gm), example_inputs) +<<<<<<< HEAD if _accuracy_fails(gm, example_inputs, compiler_fn): # type: ignore[arg-type] +======= + if _accuracy_fails(gm, example_inputs, compiler_fn): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) log.warning( "Accuracy failed for the TorchDynamo produced graph. Creating script to minify the error." ) @@ -132,7 +182,11 @@ def add_paths(exc: Exception) -> None: else: try: compiled_gm = compiler_fn(copy.deepcopy(gm), example_inputs) +<<<<<<< HEAD run_fwd_maybe_bwd(compiled_gm, example_inputs) # type: ignore[arg-type] +======= + run_fwd_maybe_bwd(compiled_gm, example_inputs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) except Exception as exc: log.warning( "Compiled Fx GraphModule failed. Creating script to minify the error." @@ -155,12 +209,19 @@ def add_paths(exc: Exception) -> None: else: compiled_gm = compiler_fn(gm, example_inputs) +<<<<<<< HEAD return compiled_gm # type: ignore[return-value] def wrap_backend_debug( unconfigured_compiler_fn: CompilerFn, compiler_name: Optional[str] ) -> WrapBackendDebug: +======= + return compiled_gm + + +def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ A minifier decorator that wraps the TorchDynamo produced Fx graph modules. As opposed to wrap_compiler_debug, this wrapper intercepts at the @@ -178,6 +239,7 @@ def wrap_backend_debug( def generate_dynamo_fx_repro_string( +<<<<<<< HEAD gm: torch.fx.GraphModule, args: Sequence[Any], compiler_name: Optional[str], @@ -187,6 +249,17 @@ def generate_dynamo_fx_repro_string( save_dir: Optional[str] = None, command: str = "run", ) -> str: +======= + gm, + args, + compiler_name, + check_accuracy=False, + *, + stable_output=False, + save_dir=None, + command="run", +): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Generate a repro string for backend-agnostic minified version. """ @@ -233,12 +306,16 @@ def generate_dynamo_fx_repro_string( ) +<<<<<<< HEAD def dump_backend_repro_as_file( gm: torch.fx.GraphModule, args: Sequence[Any], compiler_name: Optional[str], check_accuracy: bool = False, ) -> None: +======= +def dump_backend_repro_as_file(gm, args, compiler_name, check_accuracy=False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Saves the repro to a repro.py file """ @@ -266,12 +343,16 @@ def dump_backend_repro_as_file( shutil.copyfile(file_name, latest_repro) +<<<<<<< HEAD def dump_backend_state( gm: torch.fx.GraphModule, args: Sequence[Any], compiler_name: Optional[str], check_accuracy: bool = False, ) -> None: +======= +def dump_backend_state(gm, args, compiler_name, check_accuracy=False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Dumps the dynamo graph to repro the issue. 1) It tries to convert Fx GraphModule to a string. If we can, it writes to a @@ -289,9 +370,13 @@ def dump_backend_state( # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +<<<<<<< HEAD def dump_to_minify_after_dynamo( gm: torch.fx.GraphModule, args: Sequence[Any], compiler_name: Optional[str] ) -> None: +======= +def dump_to_minify_after_dynamo(gm, args, compiler_name): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO: factor this out subdir = os.path.join(minifier_dir(), "checkpoints") if not os.path.exists(subdir): @@ -315,11 +400,19 @@ def dump_to_minify_after_dynamo( @register_debug_backend # type: ignore[arg-type] def dynamo_minifier_backend( +<<<<<<< HEAD gm: fx.GraphModule, example_inputs: Sequence[Any], compiler_name: Optional[str] ) -> fx.GraphModule: from functorch.compile import minifier compiler_fn = lookup_backend(compiler_name) # type: ignore[arg-type] +======= + gm: fx.GraphModule, example_inputs, compiler_name: CompiledFn +): + from functorch.compile import minifier + + compiler_fn = lookup_backend(compiler_name) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO: It's inconsistent to pass SymInt inputs but REAL tensors. # We should pass ints and look at the GraphModule placeholders @@ -330,7 +423,11 @@ def dynamo_minifier_backend( try: compiled_gm = compiler_fn(gm, example_inputs) +<<<<<<< HEAD run_fwd_maybe_bwd(compiled_gm, example_inputs) # type: ignore[arg-type] +======= + run_fwd_maybe_bwd(compiled_gm, example_inputs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise ValueError("No issue was detected") except Exception as exc: orig_failure = str(exc) @@ -356,25 +453,40 @@ def dynamo_minifier_backend( @register_debug_backend # type: ignore[arg-type] +<<<<<<< HEAD def dynamo_accuracy_minifier_backend( gm: fx.GraphModule, example_inputs: Sequence[Any], compiler_name: Optional[str] ) -> fx.GraphModule: from functorch.compile import minifier compiler_fn = lookup_backend(compiler_name) # type: ignore[arg-type] +======= +def dynamo_accuracy_minifier_backend(gm, example_inputs, compiler_name): + from functorch.compile import minifier + + compiler_fn = lookup_backend(compiler_name) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Set the eval mode to remove randomness. gm.eval() # Check Accuracy +<<<<<<< HEAD if _accuracy_fails(gm, example_inputs, compiler_fn): # type: ignore[arg-type] +======= + if _accuracy_fails(gm, example_inputs, compiler_fn): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) log.warning("Accuracy failed for the TorchDynamo produced graph") dump_state_fn = functools.partial( dump_backend_state, compiler_name=compiler_name, check_accuracy=True ) fails_fn = functools.partial( _accuracy_fails, +<<<<<<< HEAD compiler_fn=compiler_fn, # type: ignore[arg-type] +======= + compiler_fn=compiler_fn, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) dump_state_fn(fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs) minifier( @@ -388,12 +500,16 @@ def dynamo_accuracy_minifier_backend( return gm +<<<<<<< HEAD def backend_fails( gm: fx.GraphModule, example_inputs: Sequence[Any], compiler_fn: CompilerFn, orig_failure: Sequence[Any], ) -> bool: +======= +def backend_fails(gm, example_inputs, compiler_fn, orig_failure): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Minifier uses this function to identify if the minified graph module fails with the same error. @@ -410,8 +526,13 @@ def backend_fails( try: # Run the original gm to check eager validity run_fwd_maybe_bwd(gm, clone_inputs_retaining_gradness(example_inputs)) +<<<<<<< HEAD compiled_gm = compiler_fn(gm, example_inputs) # type: ignore[arg-type] run_fwd_maybe_bwd(compiled_gm, clone_inputs_retaining_gradness(example_inputs)) # type: ignore[arg-type] +======= + compiled_gm = compiler_fn(gm, example_inputs) + run_fwd_maybe_bwd(compiled_gm, clone_inputs_retaining_gradness(example_inputs)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) except Exception as e: new_failure = str(e) if SequenceMatcher(None, orig_failure, new_failure).ratio() > 0.5: @@ -424,7 +545,11 @@ def backend_fails( # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +<<<<<<< HEAD def run_load_args(options: Any, mod: torch.nn.Module, load_args: Any) -> list[Any]: +======= +def run_load_args(options, mod, load_args): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not hasattr(load_args, "_version"): log.warning( "load_args does not have a _version attribute, please file a bug to PyTorch " @@ -450,7 +575,11 @@ def run_load_args(options: Any, mod: torch.nn.Module, load_args: Any) -> list[An return args +<<<<<<< HEAD def repro_minify(options: Any, mod: torch.nn.Module, load_args: Any) -> None: +======= +def repro_minify(options, mod, load_args): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) args = run_load_args(options, mod, load_args) # Setup debug minifier compiler @@ -469,7 +598,11 @@ def repro_minify(options: Any, mod: torch.nn.Module, load_args: Any) -> None: dynamo_minifier_backend = functools.partial( compiler_fn, +<<<<<<< HEAD compiler_name=options.backend, # type: ignore[call-arg] +======= + compiler_name=options.backend, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) opt_mod = torch._dynamo.optimize(dynamo_minifier_backend)(mod) @@ -477,20 +610,35 @@ def repro_minify(options: Any, mod: torch.nn.Module, load_args: Any) -> None: opt_mod(*args) +<<<<<<< HEAD def repro_run(options: Any, mod: torch.nn.Module, load_args: Any) -> None: +======= +def repro_run(options, mod, load_args): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) opt_mod = torch._dynamo.optimize(options.backend)(mod) if options.accuracy != "": mod.eval() +<<<<<<< HEAD opt_mod.eval() # type: ignore[union-attr] +======= + opt_mod.eval() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with torch.amp.autocast("cuda", enabled=options.autocast): # TODO: disable clone args = run_load_args(options, mod, load_args) +<<<<<<< HEAD assert same_two_models(mod, mod, args), "Eager itself failed" # type: ignore[arg-type] if not same_two_models( mod, # type: ignore[arg-type] opt_mod, # type: ignore[arg-type] +======= + assert same_two_models(mod, mod, args), "Eager itself failed" + if not same_two_models( + mod, + opt_mod, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) args, only_fwd=config.repro_forward_only, ignore_non_fp=config.repro_ignore_non_fp, @@ -499,19 +647,28 @@ def repro_run(options: Any, mod: torch.nn.Module, load_args: Any) -> None: else: with torch.amp.autocast("cuda", enabled=options.autocast): args = run_load_args(options, mod, load_args) +<<<<<<< HEAD run_fwd_maybe_bwd(mod, args, only_fwd=options.only_fwd, disable_clone=True) # type: ignore[arg-type] +======= + run_fwd_maybe_bwd(mod, args, only_fwd=options.only_fwd, disable_clone=True) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) del args args = run_load_args(options, mod, load_args) run_fwd_maybe_bwd( +<<<<<<< HEAD opt_mod, # type: ignore[arg-type] args, only_fwd=options.only_fwd, disable_clone=True, # type: ignore[arg-type] +======= + opt_mod, args, only_fwd=options.only_fwd, disable_clone=True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def run_repro( +<<<<<<< HEAD mod: torch.nn.Module, load_args: Any, *, @@ -522,6 +679,18 @@ def run_repro( backend: str = "inductor", **kwargs: Any, ) -> None: +======= + mod, + load_args, + *, + command="run", + accuracy: Union[bool, str] = "", + save_dir=None, + autocast=False, + backend="inductor", + **kwargs, +): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for k in kwargs: log.warning( "Unrecognized kwarg %s; perhaps this repro was made on a newer version of PyTorch", @@ -547,7 +716,11 @@ def run_repro( formatter_class=argparse.RawTextHelpFormatter, ) +<<<<<<< HEAD def common_flags(parser: argparse.ArgumentParser) -> None: +======= + def common_flags(parser): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) accuracy_group = parser.add_mutually_exclusive_group() accuracy_group.add_argument( "--no-accuracy", diff --git a/torch/_dynamo/repro/aoti.py b/torch/_dynamo/repro/aoti.py index e0aaf4caee475..c7d54006e887c 100644 --- a/torch/_dynamo/repro/aoti.py +++ b/torch/_dynamo/repro/aoti.py @@ -1,3 +1,8 @@ +<<<<<<< HEAD +======= +# mypy: allow-untyped-defs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Utilities for debugging and reproducing issues in Ahead of Time with Inductor (AOTI) compilation. @@ -24,9 +29,14 @@ import shutil import sys import textwrap +<<<<<<< HEAD from collections.abc import Sequence from importlib import import_module from typing import Any, IO, Optional, Union +======= +from importlib import import_module +from typing import Any, Optional, Union +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch from torch._dynamo.debug_utils import ( @@ -53,7 +63,11 @@ class AOTIMinifierError(Exception): +<<<<<<< HEAD def __init__(self, original_exception: Union[str, Exception]) -> None: +======= + def __init__(self, original_exception): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) additional_message = "This error is caused by a bug in the AOTI minifier, please report a bug to PyTorch" full_message = f"{additional_message}: {str(original_exception)}" super().__init__(full_message) @@ -65,7 +79,11 @@ def dump_to_minify( compiler_name: str, command: str = "minify", options: Optional[dict[str, Any]] = None, +<<<<<<< HEAD ) -> None: +======= +): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ If command is "minify": Dump exported_program to `debug_dir/minifier/minifier_launcher.py`, with minify command. @@ -110,8 +128,13 @@ def dump_to_minify( log.warning("No write permissions for %s", file_name) +<<<<<<< HEAD def get_module_string(gm: torch.fx.GraphModule) -> str: def _convert_to_comment(s_: str) -> str: +======= +def get_module_string(gm): + def _convert_to_comment(s_): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) s = s_.split("\n") if len(s) == 1: return "# " + s_ @@ -131,13 +154,19 @@ def _convert_to_comment(s_: str) -> str: def save_graph_repro_ep( +<<<<<<< HEAD fd: IO[Any], compiler_name: str, +======= + fd, + compiler_name, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) *, exported_program: Optional[ExportedProgram] = None, gm: Optional[torch.nn.Module] = None, args: Optional[tuple[Any]] = None, config_patches: Optional[dict[str, str]] = None, +<<<<<<< HEAD stable_output: bool = False, save_dir: Optional[str] = None, command: str = "run", @@ -146,6 +175,16 @@ def save_graph_repro_ep( module_in_comment: bool = False, strict: bool = False, ) -> None: +======= + stable_output=False, + save_dir=None, + command="run", + accuracy=None, + check_str=None, + module_in_comment=False, + strict=False, +): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Save graph for reproducing the error. # Either exported_program or gm will be saved, depending on which one is defined. # Only one of exported_program and gm should be defined. @@ -162,10 +201,17 @@ def save_graph_repro_ep( assert args is not None exported_program = torch.export.export(gm, args, strict=strict) elif gm is None: +<<<<<<< HEAD gm = exported_program.module(check_guards=False) # save a graph preview using gm module_string = get_module_string(gm) # type: ignore[arg-type] +======= + gm = exported_program.module() + + # save a graph preview using gm + module_string = get_module_string(gm) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fd.write(module_string) # save a graph repro using exported_program @@ -189,6 +235,7 @@ def save_graph_repro_ep( def dump_compiler_graph_state( +<<<<<<< HEAD gm: torch.fx.GraphModule, args: Sequence[Any], compiler_name: str, @@ -197,6 +244,16 @@ def dump_compiler_graph_state( accuracy: Optional[Union[str, bool]] = None, strict: bool = False, ) -> None: +======= + gm, + args, + compiler_name, + *, + config_patches=None, + accuracy=None, + strict=False, +): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) subdir = os.path.join(minifier_dir(), "checkpoints") if not os.path.exists(subdir): os.makedirs(subdir, exist_ok=True) @@ -233,12 +290,21 @@ def dump_compiler_graph_state( def generate_compiler_repro_exported_program( +<<<<<<< HEAD exported_program: ExportedProgram, *, options: Optional[dict[str, str]] = None, stable_output: bool = False, save_dir: Optional[str] = None, ) -> str: +======= + exported_program, + *, + options: Optional[dict[str, str]] = None, + stable_output=False, + save_dir=None, +): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) model_str = textwrap.dedent( f""" {generate_env_vars_string(stable_output=stable_output)} @@ -260,10 +326,15 @@ def generate_compiler_repro_exported_program( if hasattr(torch.version, "git_version"): model_str += f"# torch git version: {torch.version.git_version}\n\n\n" model_str += _cuda_system_info_comment() +<<<<<<< HEAD if save_dir: ep_path = os.path.join(save_dir, "exported_program.pt2") else: ep_path = "exported_program.pt2" +======= + + ep_path = os.path.join(save_dir, "exported_program.pt2") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.export.save(exported_program, ep_path) model_str += f"exported_program = torch.export.load('{ep_path}')\n" @@ -272,7 +343,11 @@ def generate_compiler_repro_exported_program( return model_str +<<<<<<< HEAD def repro_load_args(load_args: Any, save_dir: Optional[str]) -> tuple[Any]: +======= +def repro_load_args(load_args, save_dir): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not hasattr(load_args, "_version"): log.warning( "load_args does not have a _version attribute, please file a bug to PyTorch " @@ -298,6 +373,7 @@ def repro_load_args(load_args: Any, save_dir: Optional[str]) -> tuple[Any]: return tuple(args) +<<<<<<< HEAD def repro_common( options: Any, exported_program: ExportedProgram ) -> tuple[torch.fx.GraphModule, Any, Any]: @@ -312,15 +388,29 @@ def repro_get_args( exported_program: ExportedProgram, config_patches: Optional[dict[str, Any]], ) -> tuple[torch.fx.GraphModule, Any, Any]: +======= +def repro_common(options, exported_program): + torch._inductor.config.generate_intermediate_hooks = True + mod = exported_program.module() + args, kwargs = exported_program.example_inputs + return mod, args, kwargs + + +def repro_get_args(options, exported_program, config_patches): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mod, args, kwargs = repro_common(options, exported_program) return mod, args, kwargs +<<<<<<< HEAD def repro_run( options: Any, exported_program: ExportedProgram, config_patches: Optional[dict[str, Any]], ) -> None: +======= +def repro_run(options, exported_program, config_patches): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor import _aoti_compile_and_package_inner gm, args, kwargs = repro_common(options, exported_program) @@ -348,10 +438,14 @@ def repro_run( def export_for_aoti_minifier( +<<<<<<< HEAD gm: torch.nn.Module, tuple_inputs: tuple[Any], strict: bool = False, skip_export_error: bool = True, +======= + gm, tuple_inputs, strict=False, skip_export_error=True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Optional[torch.nn.Module]: # Some graphs cannot be used for AOTI/export (illegal graphs), these should be # considered as graphs that don't fail in the minifier, so the minifier keeps searching. @@ -368,7 +462,11 @@ def export_for_aoti_minifier( try: ep = torch.export.export(gm, tuple_inputs, strict=strict) +<<<<<<< HEAD gm = ep.module(check_guards=False) +======= + gm = ep.module() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return gm except Exception as e: if skip_export_error: @@ -386,11 +484,15 @@ def export_for_aoti_minifier( return None +<<<<<<< HEAD def repro_minify( options: Any, exported_program: ExportedProgram, config_patches: Optional[dict[str, Any]], ) -> None: +======= +def repro_minify(options, exported_program, config_patches): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from functorch.compile import minifier from torch._inductor import _aoti_compile_and_package_inner from torch._inductor.compile_fx import _aoti_flatten_inputs @@ -415,11 +517,15 @@ def repro_minify( need_sync = True break +<<<<<<< HEAD def module_fails( gm: torch.fx.GraphModule, flat_example_inputs: list[Any], check_str: Optional[str] = None, ) -> bool: +======= + def module_fails(gm, flat_example_inputs, check_str=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Need to export first so the in_spec and out_spec are populated tuple_inputs = tuple(flat_example_inputs) gm = export_for_aoti_minifier( @@ -469,6 +575,7 @@ def module_fails( def run_repro( +<<<<<<< HEAD exported_program: ExportedProgram, *, config_patches: Optional[dict[str, str]] = None, @@ -481,6 +588,20 @@ def run_repro( skip_export_error: bool = True, **more_kwargs: Any, ) -> Any: +======= + exported_program, + *, + config_patches: Optional[dict[str, str]] = None, + command="run", + accuracy: Union[bool, str] = "", + save_dir=None, + tracing_mode=None, + check_str=None, + minifier_export_mode="python", + skip_export_error=True, + **more_kwargs, +): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for k in more_kwargs: log.warning( "Unrecognized kwarg %s; perhaps this repro was made on a newer version of PyTorch", @@ -508,7 +629,11 @@ def run_repro( formatter_class=argparse.RawTextHelpFormatter, ) +<<<<<<< HEAD def common_flags(parser: argparse.ArgumentParser) -> None: +======= + def common_flags(parser): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) accuracy_group = parser.add_mutually_exclusive_group() accuracy_group.add_argument( "--no-accuracy", diff --git a/torch/_dynamo/resume_execution.py b/torch/_dynamo/resume_execution.py index 840e02a9cdb80..ba1190d86c89d 100644 --- a/torch/_dynamo/resume_execution.py +++ b/torch/_dynamo/resume_execution.py @@ -1,3 +1,8 @@ +<<<<<<< HEAD +======= +# mypy: allow-untyped-defs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ This module provides functionality for resuming Python execution at specific points in code, primarily used by PyTorch Dynamo for control flow handling and optimization. It implements @@ -17,12 +22,18 @@ import dataclasses import sys import types +<<<<<<< HEAD from collections.abc import Iterable from contextlib import AbstractContextManager from typing import Any, Callable, cast, Optional from .bytecode_transformation import ( add_push_null, +======= +from typing import Any, cast, Optional + +from .bytecode_transformation import ( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bytecode_from_template, create_call_function, create_instruction, @@ -50,10 +61,16 @@ # trace_rules.py import this constant for consistency TORCH_DYNAMO_RESUME_IN_PREFIX = "torch_dynamo_resume_in" +<<<<<<< HEAD IS_TRACING_RESUME_PROLOGUE_VARNAME = "__is_tracing_resume_prologue" def _initial_push_null(insts: list[Instruction]) -> None: +======= + + +def _initial_push_null(insts): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if sys.version_info >= (3, 11): insts.append(create_instruction("PUSH_NULL")) if sys.version_info < (3, 13): @@ -61,11 +78,15 @@ def _initial_push_null(insts: list[Instruction]) -> None: # Generates bytecode from template and splits the code where LOAD_FAST dummy is present. +<<<<<<< HEAD def _bytecode_from_template_with_split( template: Callable[..., Any], stack_index: int, varname_map: Optional[dict[str, Any]] = None, ) -> tuple[list[Instruction], list[Instruction]]: +======= +def _bytecode_from_template_with_split(template, stack_index, varname_map=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template_code = bytecode_from_template(template, varname_map=varname_map) template_code.append(create_instruction("POP_TOP")) @@ -83,7 +104,11 @@ def _bytecode_from_template_with_split( ), (None, None), ) +<<<<<<< HEAD assert dummy_idx is not None and dummy_inst is not None +======= + assert dummy_idx is not None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # replace LOAD_FAST dummy with first NOP marking exception area overwrite_instruction(dummy_inst, [create_instruction("NOP")]) @@ -95,7 +120,11 @@ def _bytecode_from_template_with_split( return template_code[: dummy_idx + 1], template_code[dummy_idx + 1 :] +<<<<<<< HEAD def _try_except_tf_mode_template(dummy: Any, stack_var_name: Any) -> None: +======= +def _try_except_tf_mode_template(dummy, stack_var_name): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # NOTE: Make sure this name matches what is generated by symbolic_convert:import_source # on torch._dynamo.utils. global __import_torch_dot__dynamo_dot_utils @@ -113,9 +142,13 @@ class ReenterWith: stack_index: int target_values: Optional[tuple[Any, ...]] = None +<<<<<<< HEAD def try_except_torch_function_mode( self, code_options: dict[str, Any], cleanup: list[Instruction] ) -> list[Instruction]: +======= + def try_except_torch_function_mode(self, code_options, cleanup: list[Instruction]): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Codegen based off of: try: @@ -137,9 +170,13 @@ def try_except_torch_function_mode( # If we do not want to destroy the stack, we can do the same thing as a # `SETUP_WITH` block, only that we store the context manager in a local_symbol +<<<<<<< HEAD def try_finally( self, code_options: dict[str, Any], cleanup: list[Instruction] ) -> list[Instruction]: +======= + def try_finally(self, code_options, cleanup: list[Instruction]): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Codegen based off of: load args @@ -170,7 +207,11 @@ def try_finally( ] ) +<<<<<<< HEAD def _template(ctx: AbstractContextManager[Any], dummy: Any) -> None: +======= + def _template(ctx, dummy): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ctx.__enter__() try: dummy @@ -183,9 +224,13 @@ def _template(ctx: AbstractContextManager[Any], dummy: Any) -> None: cleanup[:] = epilogue + cleanup return create_ctx + setup_try_finally +<<<<<<< HEAD def __call__( self, code_options: dict[str, Any], cleanup: list[Instruction] ) -> tuple[list[Instruction], Optional[Instruction]]: +======= + def __call__(self, code_options, cleanup): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Codegen based off of: with ctx(args): @@ -205,7 +250,11 @@ def __call__( ] ) +<<<<<<< HEAD def _template(ctx: AbstractContextManager[Any], dummy: Any) -> None: +======= + def _template(ctx, dummy): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with ctx: dummy @@ -249,6 +298,7 @@ class ResumeFunctionMetadata: prefix_block_target_offset_remap: list[int] = dataclasses.field( default_factory=list ) +<<<<<<< HEAD # per-offset map from new block target offsets to original block target offsets block_target_offset_remap: dict[int, dict[int, int]] = dataclasses.field( default_factory=dict @@ -260,6 +310,13 @@ def _filter_iter( l2: Iterable[Any], cond: Callable[[Any, Any], bool], ) -> list[Any]: +======= + # map from new block target offsets to original block target offsets + block_target_offset_remap: Optional[dict[int, int]] = None + + +def _filter_iter(l1, l2, cond): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Two-pointer conditional filter. e.g. _filter_iter(insts, sorted_offsets, lambda i, o: i.offset == o) @@ -278,7 +335,11 @@ def _filter_iter( return res +<<<<<<< HEAD def _load_tuple_and_call(tup: tuple[Any, ...]) -> list[Instruction]: +======= +def _load_tuple_and_call(tup): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) insts: list[Instruction] = [] _initial_push_null(insts) insts.extend(create_load_const(val) for val in tup) @@ -291,7 +352,11 @@ class ContinueExecutionCache: generated_code_metadata = ExactWeakKeyDictionary() @classmethod +<<<<<<< HEAD def lookup(cls, code: types.CodeType, lineno: int, *key: Any) -> types.CodeType: +======= + def lookup(cls, code, lineno, *key): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if code not in cls.cache: cls.cache[code] = {} key = tuple(key) @@ -302,8 +367,13 @@ def lookup(cls, code: types.CodeType, lineno: int, *key: Any) -> types.CodeType: @classmethod def generate( cls, +<<<<<<< HEAD code: types.CodeType, lineno: int, +======= + code, + lineno, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) offset: int, setup_fn_target_offsets: tuple[int, ...], # only used in Python 3.11+ nstack: int, @@ -313,9 +383,12 @@ def generate( stack_ctx_vars: tuple[tuple[int, tuple[Any, ...]], ...], argnames_ctx_vars: tuple[tuple[str, tuple[Any, ...]], ...], null_idxes: tuple[int, ...], +<<<<<<< HEAD # mainly used to ensure distinct code objects per stack trace, # which prevents excessive recompilation of inner frames nested_code_objs: tuple[types.CodeType], +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> types.CodeType: assert offset is not None assert not ( @@ -336,12 +409,16 @@ def generate( stack_ctx_vars, argnames_ctx_vars, null_idxes, +<<<<<<< HEAD nested_code_objs, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) is_py311_plus = sys.version_info >= (3, 11) meta = ResumeFunctionMetadata(code) +<<<<<<< HEAD def update( instructions: list[Instruction], code_options: dict[str, Any] ) -> None: @@ -349,6 +426,12 @@ def update( args = ["__nested_resume_fns", "__nested_frame_values"] args += [f"___stack{i}" for i in range(nstack)] +======= + def update(instructions: list[Instruction], code_options: dict[str, Any]): + meta.instructions = copy.deepcopy(instructions) + + args = [f"___stack{i}" for i in range(nstack)] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) args.extend(v for v in argnames if v not in args) freevars = tuple(code_options["co_cellvars"] or []) + tuple( code_options["co_freevars"] or [] @@ -376,8 +459,16 @@ def update( code_options["co_varnames"] = tuple( args + [v for v in argnames_null if v not in args] +<<<<<<< HEAD + [v for v in code_options["co_varnames"] if v not in args] + [IS_TRACING_RESUME_PROLOGUE_VARNAME] +======= + + [ + v + for v in code_options["co_varnames"] + if v not in args and v not in freevars + ] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) code_options["co_flags"] = code_options["co_flags"] & ~( CO_VARARGS | CO_VARKEYWORDS @@ -392,6 +483,7 @@ def update( ) prefix.append(create_instruction("RESUME", arg=0)) +<<<<<<< HEAD # Set is_tracing_resume_prologue to prevent graph breaks. # This doesn't really do anything at runtime, but dynamo will trace this # and will know that we're in a resume function prologue. @@ -404,6 +496,8 @@ def update( ] ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cleanup: list[Instruction] = [] hooks = {fn.stack_index: fn for fn in setup_fns} hook_target_offsets = { @@ -465,6 +559,7 @@ def update( ] ) +<<<<<<< HEAD # Call nested resume function if nested_code_objs: prefix.extend( @@ -527,6 +622,8 @@ def update( ] ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) prefix.append(create_jump_absolute(target)) # because the line number table monotonically increases from co_firstlineno @@ -551,19 +648,31 @@ def update( inst.exn_tab_entry and inst.exn_tab_entry.target in old_hook_target_remap ): +<<<<<<< HEAD inst.exn_tab_entry.target = old_hook_target_remap[ # type: ignore[assignment] +======= + inst.exn_tab_entry.target = old_hook_target_remap[ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inst.exn_tab_entry.target ] # TODO(jansel): add dead code elimination here instructions[:] = prefix + instructions +<<<<<<< HEAD new_code, _ = transform_code_object(code, update) +======= + new_code = transform_code_object(code, update) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ContinueExecutionCache.generated_code_metadata[new_code] = meta return new_code @staticmethod +<<<<<<< HEAD def unreachable_codes(code_options: dict[str, Any]) -> list[Instruction]: +======= + def unreachable_codes(code_options) -> list[Instruction]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Codegen a `raise None` to make analysis work for unreachable code""" return [ create_load_const(None), @@ -572,6 +681,7 @@ def unreachable_codes(code_options: dict[str, Any]) -> list[Instruction]: @classmethod def generate_based_on_original_code_object( +<<<<<<< HEAD cls, code: types.CodeType, lineno: int, @@ -579,6 +689,10 @@ def generate_based_on_original_code_object( setup_fn_target_offsets: tuple[int, ...], *args: Any, ) -> types.CodeType: +======= + cls, code, lineno, offset: int, setup_fn_target_offsets: tuple[int, ...], *args + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ This handles the case of generating a resume into code generated to resume something else. We want to always generate starting @@ -590,11 +704,19 @@ def generate_based_on_original_code_object( meta: ResumeFunctionMetadata = ContinueExecutionCache.generated_code_metadata[ code ] +<<<<<<< HEAD new_offset = -1 def find_new_offset( instructions: list[Instruction], code_options: dict[str, Any] ) -> None: +======= + new_offset = None + + def find_new_offset( + instructions: list[Instruction], code_options: dict[str, Any] + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) nonlocal new_offset (target,) = (i for i in instructions if i.offset == offset) # match the functions starting at the last instruction as we have added a prefix @@ -604,17 +726,24 @@ def find_new_offset( if i1 is target ) assert target.opcode == new_target.opcode +<<<<<<< HEAD assert new_target.offset is not None new_offset = new_target.offset transform_code_object(code, find_new_offset) assert new_offset >= 0 +======= + new_offset = new_target.offset + + transform_code_object(code, find_new_offset) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if sys.version_info >= (3, 11): # setup_fn_target_offsets currently contains the target offset of # each setup_fn, based on `code`. When we codegen the resume function # based on the original code object, `meta.code`, the offsets in # setup_fn_target_offsets must be based on `meta.code` instead. +<<<<<<< HEAD if new_offset not in meta.block_target_offset_remap: block_target_offset_remap = meta.block_target_offset_remap[ new_offset @@ -623,6 +752,14 @@ def find_new_offset( def remap_block_offsets( instructions: list[Instruction], code_options: dict[str, Any] ) -> None: +======= + if not meta.block_target_offset_remap: + block_target_offset_remap = meta.block_target_offset_remap = {} + + def remap_block_offsets( + instructions: list[Instruction], code_options: dict[str, Any] + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # NOTE: each prefix block generates exactly one PUSH_EXC_INFO, # so we can tell which block a prefix PUSH_EXC_INFO belongs to, # by counting. Then we can use meta.prefix_block-target_offset_remap @@ -666,8 +803,12 @@ def remap_block_offsets( # if offset is not in setup_fn_target_offsets, it is an error setup_fn_target_offsets = tuple( +<<<<<<< HEAD meta.block_target_offset_remap[new_offset][n] for n in setup_fn_target_offsets +======= + meta.block_target_offset_remap[n] for n in setup_fn_target_offsets +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return ContinueExecutionCache.lookup( meta.code, lineno, new_offset, setup_fn_target_offsets, *args diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 80b22e55227cd..bfe03619e309d 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -1,3 +1,8 @@ +<<<<<<< HEAD +======= +# mypy: allow-untyped-defs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Side effect tracking and management for TorchDynamo's compilation system. @@ -26,12 +31,19 @@ import inspect import warnings import weakref +<<<<<<< HEAD from collections.abc import Generator, MutableMapping +======= +from collections.abc import MutableMapping +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from types import CellType from typing import Any, Optional, TYPE_CHECKING import torch.nn +<<<<<<< HEAD from torch._dynamo.variables.misc import AutogradFunctionContextVariable +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from . import graph_break_hints, utils, variables from .bytecode_transformation import ( @@ -57,6 +69,7 @@ if TYPE_CHECKING: +<<<<<<< HEAD from torch._dynamo.output_graph import OutputGraph from torch._dynamo.symbolic_convert import InstructionTranslatorBase from torch._dynamo.variables.lists import ListVariable @@ -65,17 +78,32 @@ def _manual_dict_setitem( dict_from: dict[Any, Any], dict_to: dict[Any, Any], mro_index: int ) -> None: +======= + from torch._dynamo.symbolic_convert import InstructionTranslator + + +def _manual_dict_setitem(dict_from, dict_to, mro_index): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Carefully calls the dict or OrderedDict `clear` or `__setitem__`. We have # to be careful because we don't want to trigger the user defined object # setitem or clear. The mro_index is used to find the dict/OrderedDict from # the class mro. dict_class = type(dict_to).__mro__[mro_index] +<<<<<<< HEAD dict_class.clear(dict_to) # type: ignore[attr-defined] for k, v in dict_from.items(): dict_class.__setitem__(dict_to, k, v) # type: ignore[index] def _manual_list_update(list_from: list[Any], list_to: list[Any]) -> None: +======= + dict_class.clear(dict_to) + for k, v in dict_from.items(): + dict_class.__setitem__(dict_to, k, v) + + +def _manual_list_update(list_from, list_to): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) list.clear(list_to) list.extend(list_to, list_from) @@ -106,6 +134,7 @@ class SideEffects: def __init__( self, +<<<<<<< HEAD output_graph: "OutputGraph", id_to_variable: Optional[dict[int, VariableTracker]] = None, store_attr_mutations: Optional[ @@ -127,6 +156,15 @@ def __init__( ] ] = None, ) -> None: +======= + output_graph, + id_to_variable=None, + store_attr_mutations=None, + keepalive=None, + save_for_backward=None, + tensor_hooks=None, + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__() self.output_graph_weakref = weakref.ref(output_graph) self.id_to_variable = id_to_variable or {} @@ -139,6 +177,7 @@ def __init__( self._has_existing_dict_mutation = False # Track Compiled Autograd final callbacks that must be called at the end of Compiled Autograd backward graph. # Only applicable if this graph is created from Dynamo tracing in Compiled Autograd. +<<<<<<< HEAD self.ca_final_callbacks_var: Optional[ListVariable] = None # Tracks VariableTracker objects whose mutations can be skipped. @@ -159,6 +198,9 @@ def stop_ignoring_mutations_on(self, var: VariableTracker) -> None: """Remove a variable from the skip mutation set, restoring normal mutation tracking.""" if var in self.ignore_mutation_on_these_variables: self.ignore_mutation_on_these_variables.remove(var) +======= + self.ca_final_callbacks_var = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __eq__(self, other: object) -> bool: assert isinstance(other, SideEffects) @@ -192,12 +234,19 @@ def diff(self, other: "SideEffects") -> Optional[str]: else: return None +<<<<<<< HEAD def clone(self) -> "SideEffects": """Create a shallow copy""" ref = self.output_graph_weakref() assert ref is not None return self.__class__( output_graph=ref, +======= + def clone(self): + """Create a shallow copy""" + return self.__class__( + output_graph=self.output_graph_weakref(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) id_to_variable=dict(self.id_to_variable), store_attr_mutations={ k: dict(v) for k, v in self.store_attr_mutations.items() @@ -207,6 +256,7 @@ def clone(self) -> "SideEffects": tensor_hooks=self.tensor_hooks, ) +<<<<<<< HEAD def __contains__(self, item: Any) -> bool: return id(item) in self.id_to_variable @@ -216,27 +266,55 @@ def __getitem__(self, item: Any) -> VariableTracker: def should_allow_side_effects_under_checkpoint(self) -> bool: output_graph = self.output_graph_weakref() return bool( +======= + def __contains__(self, item): + return id(item) in self.id_to_variable + + def __getitem__(self, item): + return self.id_to_variable[id(item)] + + def should_allow_side_effects_under_checkpoint(self): + output_graph = self.output_graph_weakref() + return ( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) output_graph and output_graph.current_tx.output.current_tracer.under_activation_checkpoint and output_graph.current_tx.output.current_tracer.allow_side_effects_under_checkpoint ) +<<<<<<< HEAD def should_allow_externally_visible_side_effects_in_subtracer(self) -> bool: output_graph = self.output_graph_weakref() return bool( +======= + def should_allow_externally_visible_side_effects_in_subtracer(self): + output_graph = self.output_graph_weakref() + return ( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) output_graph and output_graph.current_tx.output.current_tracer.unsafe_allow_externally_visible_side_effects ) +<<<<<<< HEAD def is_reconstructing_generator(self) -> bool: output_graph = self.output_graph_weakref() return bool( +======= + def is_reconstructing_generator(self): + output_graph = self.output_graph_weakref() + + return ( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) output_graph and output_graph.current_tx.output.current_tracer.is_reconstructing_generator ) +<<<<<<< HEAD def check_allowed_side_effect(self, item: VariableTracker) -> bool: +======= + def check_allowed_side_effect(self, item): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._dynamo.variables.misc import AutogradFunctionContextVariable # People do things like self.dim = dim inside autograd.Function. @@ -263,17 +341,23 @@ def check_allowed_side_effect(self, item: VariableTracker) -> bool: explanation="This is not supported.", hints=[], ) +<<<<<<< HEAD return False def store_attr( self, item: VariableTracker, name: str, value: VariableTracker ) -> None: +======= + + def store_attr(self, item: VariableTracker, name: str, value: VariableTracker): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert self.is_attribute_mutation(item) self.check_allowed_side_effect(item) if item not in self.store_attr_mutations: self.store_attr_mutations[item] = {} self.store_attr_mutations[item][name] = value +<<<<<<< HEAD def load_attr( self, item: VariableTracker, @@ -281,6 +365,9 @@ def load_attr( deleted_ok: bool = False, check: bool = False, ) -> VariableTracker: +======= + def load_attr(self, item, name, deleted_ok=False, check=False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if check: assert self.is_attribute_mutation(item) result = self.store_attr_mutations[item][name] @@ -293,7 +380,11 @@ def load_attr( ) return result +<<<<<<< HEAD def store_cell(self, cellvar: VariableTracker, value: VariableTracker) -> None: +======= + def store_cell(self, cellvar, value): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if cellvar.is_immutable(): unimplemented_v2( gb_type="Write to immutable cell", @@ -305,7 +396,11 @@ def store_cell(self, cellvar: VariableTracker, value: VariableTracker) -> None: assert isinstance(value, variables.VariableTracker) self.store_attr(cellvar, "cell_contents", value) +<<<<<<< HEAD def load_cell(self, cellvar: VariableTracker) -> VariableTracker: +======= + def load_cell(self, cellvar): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(cellvar, variables.CellVariable) if self.has_pending_mutation_of_attr(cellvar, "cell_contents"): return self.load_attr(cellvar, "cell_contents", check=False) @@ -318,6 +413,7 @@ def load_cell(self, cellvar: VariableTracker) -> VariableTracker: hints=[*graph_break_hints.USER_ERROR], ) +<<<<<<< HEAD def load_global(self, gvar: VariableTracker, name: str) -> VariableTracker: assert isinstance(gvar, variables.VariableTracker) return self.load_attr(gvar, name) @@ -325,17 +421,31 @@ def load_global(self, gvar: VariableTracker, name: str) -> VariableTracker: def store_global( self, gvar: VariableTracker, name: str, value: VariableTracker ) -> None: +======= + def load_global(self, gvar: VariableTracker, name: str): + assert isinstance(gvar, variables.VariableTracker) + return self.load_attr(gvar, name) + + def store_global(self, gvar: VariableTracker, name: str, value: VariableTracker): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(gvar, variables.VariableTracker) assert isinstance(value, variables.VariableTracker) self.store_attr(gvar, name, value) @staticmethod +<<<<<<< HEAD def cls_supports_mutation_side_effects(cls: type) -> bool: return inspect.getattr_static(cls, "__getattribute__", None) in ( object.__getattribute__, dict.__getattribute__, set.__getattribute__, frozenset.__getattribute__, +======= + def cls_supports_mutation_side_effects(cls): + return inspect.getattr_static(cls, "__getattribute__", None) in ( + object.__getattribute__, + dict.__getattribute__, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int.__getattribute__, str.__getattribute__, list.__getattribute__, @@ -343,20 +453,35 @@ def cls_supports_mutation_side_effects(cls: type) -> bool: BaseException.__getattribute__, ) +<<<<<<< HEAD def is_attribute_mutation(self, item: VariableTracker) -> bool: return isinstance(item.mutation_type, AttributeMutation) def has_pending_mutation(self, item: VariableTracker) -> bool: +======= + def is_attribute_mutation(self, item): + return isinstance(item.mutation_type, AttributeMutation) + + def has_pending_mutation(self, item): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.is_attribute_mutation(item) and bool( self.store_attr_mutations.get(item) ) +<<<<<<< HEAD def has_pending_mutation_of_attr(self, item: VariableTracker, name: str) -> bool: +======= + def has_pending_mutation_of_attr(self, item, name): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.is_attribute_mutation( item ) and name in self.store_attr_mutations.get(item, ()) +<<<<<<< HEAD def is_modified(self, item: VariableTracker) -> bool: +======= + def is_modified(self, item): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if item.is_immutable(): return False if isinstance(item.mutation_type, (AttributeMutationNew, ValueMutationNew)): @@ -371,14 +496,23 @@ def is_modified(self, item: VariableTracker) -> bool: if self.is_attribute_mutation(item): return item in self.store_attr_mutations +<<<<<<< HEAD return item.mutation_type.is_modified # type: ignore[attr-defined] +======= + return item.mutation_type.is_modified +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _track_obj( self, item: Any, variable: VariableTracker, +<<<<<<< HEAD mutation_type_cls: type = ValueMutationExisting, ) -> VariableTracker: +======= + mutation_type_cls=ValueMutationExisting, + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Start tracking an existing or new variable for mutation""" if id(item) in self.id_to_variable: raise AssertionError( @@ -400,7 +534,11 @@ def track_object_existing( self, item: Any, variable: VariableTracker, +<<<<<<< HEAD ) -> VariableTracker: +======= + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self._track_obj( item, variable, @@ -412,8 +550,13 @@ def track_object_new( cls_source: Source, user_cls: Any, variable_cls: Any, +<<<<<<< HEAD options: dict[str, Any], ) -> VariableTracker: +======= + options, + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if user_cls is torch.autograd.function.FunctionCtx: with warnings.catch_warnings(record=True): obj = torch.autograd.Function() @@ -428,7 +571,11 @@ def track_object_new( self.keepalive.append(obj) return variable +<<<<<<< HEAD def get_variable_cls(self, user_cls: type) -> type: +======= + def get_variable_cls(self, user_cls): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.overrides import TorchFunctionMode from .variables.ctx_manager import GenericContextWrappingVariable @@ -452,8 +599,11 @@ def get_variable_cls(self, user_cls: type) -> type: variable_cls = variables.UnspecializedNNModuleVariable elif issubclass(user_cls, (dict, collections.OrderedDict)): variable_cls = variables.UserDefinedDictVariable +<<<<<<< HEAD elif issubclass(user_cls, (set, frozenset)): variable_cls = variables.UserDefinedSetVariable +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif issubclass(user_cls, tuple): variable_cls = variables.UserDefinedTupleVariable elif issubclass(user_cls, list): @@ -469,11 +619,19 @@ def get_variable_cls(self, user_cls: type) -> type: def get_example_value( self, +<<<<<<< HEAD base_cls_vt: VariableTracker, cls_vt: VariableTracker, init_args: list[VariableTracker], ) -> Any: user_cls = cls_vt.value # type: ignore[attr-defined] +======= + base_cls_vt, + cls_vt, + init_args, + ): + user_cls = cls_vt.value +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if issubclass(user_cls, torch.nn.Module): # TODO(anijain2305) - Is it possible to remove this specialization? obj = nn_module_new(user_cls) @@ -500,10 +658,17 @@ def get_example_value( def track_new_user_defined_object( self, +<<<<<<< HEAD base_cls_vt: VariableTracker, cls_vt: VariableTracker, init_args: list[VariableTracker], ) -> VariableTracker: +======= + base_cls_vt, + cls_vt, + init_args, + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Creates a UserDefinedObjectVariable (or its subclass) variable tracker and mark it for attribute mutation tracking. @@ -513,7 +678,11 @@ def track_new_user_defined_object( base_cls_vt.__new__(user_cls, *init_args) """ cls_source = cls_vt.source +<<<<<<< HEAD user_cls = cls_vt.value # type: ignore[attr-defined] +======= + user_cls = cls_vt.value +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) variable_cls = self.get_variable_cls(user_cls) obj = self.get_example_value(base_cls_vt, cls_vt, init_args) @@ -530,7 +699,11 @@ def track_new_user_defined_object( def track_cell_new( self, +<<<<<<< HEAD ) -> VariableTracker: +======= + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) obj = object() variable = variables.CellVariable( mutation_type=AttributeMutationNew(), @@ -541,7 +714,11 @@ def track_cell_new( def track_cell_existing( self, source: Optional[Source], cell: CellType, contents: VariableTracker +<<<<<<< HEAD ) -> VariableTracker: +======= + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) variable = variables.CellVariable( # We don't support mutation to cell without source because we need # source to properly codegen the mutations. @@ -553,7 +730,11 @@ def track_cell_existing( self.keepalive.append(cell) return variable +<<<<<<< HEAD def track_global_existing(self, source: Source, item: Any) -> VariableTracker: +======= + def track_global_existing(self, source: Source, item: Any): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) variable = variables.NewGlobalVariable( mutation_type=AttributeMutationExisting(), source=source, @@ -562,6 +743,7 @@ def track_global_existing(self, source: Source, item: Any) -> VariableTracker: self.keepalive.append(item) return variable +<<<<<<< HEAD def track_save_for_backward( self, ctx: VariableTracker, args: list[VariableTracker] ) -> None: @@ -571,6 +753,13 @@ def track_save_for_backward( def track_runahead_tensor_and_symvar_side_effects( self, other: "SideEffects" ) -> None: +======= + def track_save_for_backward(self, ctx, args): + assert isinstance(ctx, variables.AutogradFunctionContextVariable) + self.save_for_backward.append((ctx, args)) + + def track_tensor_variables_from_runahead_side_effects(self, other): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # In higher order ops we want to keep track of tensors seen in the # speculate_subgraph so that we don't lift them again as a new input in # other speculate_subgraph or in the root tracer. @@ -578,16 +767,28 @@ def track_runahead_tensor_and_symvar_side_effects( other_id = id(other_item) other_variable = other.id_to_variable[other_id] if other_id not in self.id_to_variable and isinstance( +<<<<<<< HEAD other_variable, (variables.TensorVariable, variables.SymNodeVariable) ): self.track_object_existing(other_item, other_variable) def prune_dead_object_new(self, tx: "InstructionTranslatorBase") -> None: +======= + other_variable, variables.TensorVariable + ): + self.track_object_existing(other_item, other_variable) + + def prune_dead_object_new(self, tx): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Avoid VT cycles from e.g., recursive function. visited: set[VariableTracker] = set() live_new_objects: set[VariableTracker] = set() +<<<<<<< HEAD def visit(var: VariableTracker) -> None: +======= + def visit(var: VariableTracker): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if var in visited: return visited.add(var) @@ -603,7 +804,11 @@ def visit(var: VariableTracker) -> None: self.store_attr_mutations[var], ) +<<<<<<< HEAD def is_live(var: VariableTracker) -> bool: +======= + def is_live(var: VariableTracker): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(var.mutation_type, AttributeMutationNew): return var in live_new_objects return True @@ -617,6 +822,7 @@ def is_live(var: VariableTracker) -> bool: # The only live side effects come from returns (tx.stack), any intermediates # during a graph break (tx.symbolic_locals), and mutation on pre-existing variables. # Recursively visit Variables and see if any of them have been mutated. +<<<<<<< HEAD init_live_vars = [] # gather stack/symbolic_locals for all tx's up the chain cur_tx: Optional[InstructionTranslatorBase] = tx @@ -632,6 +838,18 @@ def is_live(var: VariableTracker) -> bool: tx.output.backward_state, self.tensor_hooks, ], +======= + VariableTracker.visit( + visit, + # TODO track from all possible sources. + ( + tx.stack, + tx.symbolic_locals, + pre_existing_vars, + tx.output.backward_state, + self.tensor_hooks, + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # Manually release the self-referential function, which indirectly # captures certain `VariableTracker` and affects parts of PT test/logic @@ -651,10 +869,14 @@ def is_live(var: VariableTracker) -> bool: k: v for k, v in self.store_attr_mutations.items() if is_live(k) } +<<<<<<< HEAD def mutation(self, var: VariableTracker) -> None: if var in self.ignore_mutation_on_these_variables: return +======= + def mutation(self, var): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.check_allowed_side_effect(var) if isinstance(var.mutation_type, ValueMutationExisting): var.mutation_type.is_modified = True @@ -665,6 +887,7 @@ def mutation(self, var: VariableTracker) -> None: ): self._has_existing_dict_mutation = True +<<<<<<< HEAD def has_existing_dict_mutation(self) -> bool: return self._has_existing_dict_mutation @@ -672,6 +895,15 @@ def _get_modified_vars(self) -> list[VariableTracker]: return [var for var in self.id_to_variable.values() if self.is_modified(var)] def codegen_save_tempvars(self, cg: PyCodegen) -> None: +======= + def has_existing_dict_mutation(self): + return self._has_existing_dict_mutation + + def _get_modified_vars(self): + return [var for var in self.id_to_variable.values() if self.is_modified(var)] + + def codegen_save_tempvars(self, cg: PyCodegen): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # We must codegen modified VT to their source by default, so that # mutation and aliasing are properly accounted for. # @@ -731,7 +963,11 @@ def codegen_save_tempvars(self, cg: PyCodegen) -> None: # base_cls.__new__(user_cls, *args) if isinstance(var, variables.UserDefinedObjectVariable): +<<<<<<< HEAD def load_new_method() -> None: +======= + def load_new_method(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert var.base_cls_vt is not None cg(var.base_cls_vt) # type: ignore[attr-defined] cg.extend_output([cg.create_load_attr("__new__")]) @@ -741,6 +977,7 @@ def load_new_method() -> None: cg.add_push_null( lambda: cg.load_import_from(utils.__name__, "object_new") ) +<<<<<<< HEAD assert var.mutation_type.cls_source is not None cg(var.mutation_type.cls_source) @@ -750,6 +987,16 @@ def load_new_method() -> None: # Call the __new__ method cg.extend_output(create_call_function(1 + len(var.init_args), False)) # type: ignore[attr-defined] +======= + cg(var.mutation_type.cls_source) + + # Generate the args to the __new__ method + for arg in var.init_args: + cg(arg) + + # Call the __new__ method + cg.extend_output(create_call_function(1 + len(var.init_args), False)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cg.add_cache(var) var.source = LocalSource(cg.tempvars[var]) @@ -766,6 +1013,7 @@ def load_new_method() -> None: ] ) +<<<<<<< HEAD def register_hook( self, tensor: "variables.TensorVariable", @@ -773,6 +1021,9 @@ def register_hook( handle: "variables.RemovableHandleVariable", name: str, ) -> None: +======= + def register_hook(self, tensor, hook, handle, name): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(tensor, variables.TensorVariable) assert isinstance(hook, variables.VariableTracker) assert ( @@ -788,10 +1039,17 @@ def register_hook( assert not handle.idx handle.idx = idx +<<<<<<< HEAD def remove_hook(self, idx: int) -> None: del self.tensor_hooks[idx] def codegen_hooks(self, cg: PyCodegen) -> None: +======= + def remove_hook(self, idx): + del self.tensor_hooks[idx] + + def codegen_hooks(self, cg): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for ( tensor, hook, @@ -833,7 +1091,11 @@ def codegen_hooks(self, cg: PyCodegen) -> None: # - The handle's exact user-specified name, "user_code_variable_name", is discerned and associated during STORE_FAST. assert tensor.source, "Hooks on non input tensors NYI - should not get here" +<<<<<<< HEAD def gen_fn() -> None: +======= + def gen_fn(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cg(tensor) cg.extend_output([cg.create_load_attr(name)]) @@ -845,17 +1107,27 @@ def gen_fn() -> None: # be associated with the return value of register_hook(). This consumes the top of stack. cg.add_cache(handle) +<<<<<<< HEAD def get_ca_final_callbacks_var(self) -> "variables.ListVariable": +======= + def get_ca_final_callbacks_var(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .variables.base import ValueMutationNew if self.ca_final_callbacks_var is None: self.ca_final_callbacks_var = variables.ListVariable( [], mutation_type=ValueMutationNew() ) +<<<<<<< HEAD return self.ca_final_callbacks_var def codegen_update_mutated(self, cg: PyCodegen) -> None: +======= + return self.ca_final_callbacks_var + + def codegen_update_mutated(self, cg: PyCodegen): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) suffixes = [] for var in self._get_modified_vars(): if isinstance(var, variables.ListVariable): @@ -1148,7 +1420,11 @@ def codegen_update_mutated(self, cg: PyCodegen) -> None: cg.pop_top() elif isinstance(var, variables.RandomVariable): # set correct random seed state +<<<<<<< HEAD def gen_fn() -> None: +======= + def gen_fn(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cg(var.source) # type: ignore[attr-defined] cg.load_attr("setstate") @@ -1168,7 +1444,11 @@ def gen_fn() -> None: for suffix in reversed(suffixes): cg.extend_output(suffix) +<<<<<<< HEAD def is_empty(self) -> bool: +======= + def is_empty(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return not ( any(map(self.is_modified, self.id_to_variable.values())) or self.tensor_hooks @@ -1176,15 +1456,23 @@ def is_empty(self) -> bool: or self.tensor_hooks ) +<<<<<<< HEAD def clear(self) -> None: +======= + def clear(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.keepalive.clear() self.id_to_variable.clear() @contextlib.contextmanager +<<<<<<< HEAD def allow_side_effects_under_checkpoint( tx: "InstructionTranslatorBase", ) -> Generator[None, None, None]: +======= +def allow_side_effects_under_checkpoint(tx: "InstructionTranslator"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert tx.output.current_tracer.under_activation_checkpoint orig_val = tx.output.current_tracer.allow_side_effects_under_checkpoint try: @@ -1195,9 +1483,13 @@ def allow_side_effects_under_checkpoint( @contextlib.contextmanager +<<<<<<< HEAD def allow_externally_visible_side_effects_in_subtracer( tx: "InstructionTranslatorBase", ) -> Generator[None, None, None]: +======= +def allow_externally_visible_side_effects_in_subtracer(tx: "InstructionTranslator"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) orig_val = tx.output.current_tracer.unsafe_allow_externally_visible_side_effects try: tx.output.current_tracer.unsafe_allow_externally_visible_side_effects = True @@ -1207,9 +1499,13 @@ def allow_externally_visible_side_effects_in_subtracer( @contextlib.contextmanager +<<<<<<< HEAD def disallow_side_effects_in_generator( tx: "InstructionTranslatorBase", ) -> Generator[None, None, None]: +======= +def disallow_side_effects_in_generator(tx: "InstructionTranslator"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) orig_val = tx.output.current_tracer.is_reconstructing_generator try: tx.output.current_tracer.is_reconstructing_generator = True diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index c1906eeee710c..e949e7502d8be 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -1,3 +1,8 @@ +<<<<<<< HEAD +======= +# mypy: allow-untyped-defs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ This module provides Source classes that track the origins of values in PyTorch Dynamo. Sources represent where values come from (e.g. local variables, globals, attributes) and @@ -20,9 +25,15 @@ import dataclasses import enum import functools +<<<<<<< HEAD from typing import Any, Callable, Optional, TYPE_CHECKING, Union from torch._guards import ChainedSource, Guard, GuardSource, Source +======= +from typing import Any, Optional, TYPE_CHECKING, Union + +from torch._guards import ChainedSource, GuardSource, Source +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from . import utils from .bytecode_transformation import create_call_function, create_instruction @@ -94,7 +105,11 @@ } +<<<<<<< HEAD def is_constant_source(source: Source) -> bool: +======= +def is_constant_source(source): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(source, ConstantSource): return True try: @@ -106,6 +121,7 @@ def is_constant_source(source: Source) -> bool: return False +<<<<<<< HEAD def _get_source_debug_name(source: Source) -> str: try: return source.name() @@ -113,6 +129,8 @@ def _get_source_debug_name(source: Source) -> str: return "" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclasses.dataclass(frozen=True) class LocalSource(Source): local_name: str @@ -129,16 +147,27 @@ class LocalSource(Source): # or `co_freevars`. is_derefed_cell_contents: bool = False +<<<<<<< HEAD def reconstruct(self, codegen: "PyCodegen") -> None: +======= + def reconstruct(self, codegen: "PyCodegen"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.is_derefed_cell_contents: codegen.load_deref(self.local_name) else: codegen.append_output(codegen.create_load(self.local_name)) +<<<<<<< HEAD def guard_source(self) -> GuardSource: return GuardSource.LOCAL def name(self) -> str: +======= + def guard_source(self): + return GuardSource.LOCAL + + def name(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"L[{repr(self.local_name)}]" @@ -146,6 +175,7 @@ def name(self) -> str: class SyntheticLocalSource(Source): local_name: str +<<<<<<< HEAD def reconstruct(self, codegen: "PyCodegen") -> None: codegen.append_output(codegen.create_load(self.local_name)) @@ -153,6 +183,15 @@ def guard_source(self) -> GuardSource: return GuardSource.SYNTHETIC_LOCAL def name(self) -> str: +======= + def reconstruct(self, codegen: "PyCodegen"): + codegen.append_output(codegen.create_load(self.local_name)) + + def guard_source(self): + return GuardSource.SYNTHETIC_LOCAL + + def name(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"SYNTHETIC_LOCAL[{self.local_name!r}]" @@ -160,15 +199,26 @@ def name(self) -> str: class RandomValueSource(Source): random_call_index: int +<<<<<<< HEAD def guard_source(self) -> GuardSource: return GuardSource.RANDOM_VALUE def reconstruct(self, codegen: "PyCodegen") -> None: +======= + def guard_source(self): + return GuardSource.RANDOM_VALUE + + def reconstruct(self, codegen: "PyCodegen"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) codegen.append_output(codegen.create_load(codegen.tx.output.random_values_var)) codegen.append_output(codegen.create_load_const(self.random_call_index)) codegen.append_output(create_instruction("BINARY_SUBSCR")) +<<<<<<< HEAD def name(self) -> str: +======= + def name(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"random_value_{self.random_call_index}" @@ -176,6 +226,7 @@ def name(self) -> str: class GlobalSource(Source): global_name: str +<<<<<<< HEAD def reconstruct(self, codegen: "PyCodegen") -> None: codegen.append_output(codegen.create_load_global(self.global_name, add=True)) @@ -183,6 +234,15 @@ def guard_source(self) -> GuardSource: return GuardSource.GLOBAL def name(self) -> str: +======= + def reconstruct(self, codegen: "PyCodegen"): + codegen.append_output(codegen.create_load_global(self.global_name, add=True)) + + def guard_source(self): + return GuardSource.GLOBAL + + def name(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"G[{repr(self.global_name)}]" @@ -190,7 +250,11 @@ def name(self) -> str: class GlobalWeakRefSource(Source): global_name: str +<<<<<<< HEAD def reconstruct(self, codegen: "PyCodegen") -> None: +======= + def reconstruct(self, codegen: "PyCodegen"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) codegen.add_push_null( lambda: codegen.append_output( codegen.create_load_global(self.global_name, add=True) @@ -198,15 +262,23 @@ def reconstruct(self, codegen: "PyCodegen") -> None: ) codegen.extend_output(create_call_function(0, False)) +<<<<<<< HEAD def guard_source(self) -> GuardSource: return GuardSource.GLOBAL def name(self) -> str: +======= + def guard_source(self): + return GuardSource.GLOBAL + + def name(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"G[{repr(self.global_name)}]()" @dataclasses.dataclass(frozen=True) class WeakRefCallSource(ChainedSource): +<<<<<<< HEAD def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null(lambda: codegen(self.base)) codegen.extend_output(create_call_function(0, False)) @@ -215,6 +287,16 @@ def guard_source(self) -> GuardSource: return self.base.guard_source() def name(self) -> str: +======= + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null(lambda: codegen(self.base)) + codegen.extend_output(create_call_function(0, False)) + + def guard_source(self): + return self.base.guard_source() + + def name(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"{self.base.name()}()" @@ -227,7 +309,11 @@ class CallFunctionNoArgsSource(WeakRefCallSource): class AttrSource(ChainedSource): member: str +<<<<<<< HEAD def __post_init__(self) -> None: +======= + def __post_init__(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert self.base, "Can't construct an AttrSource without a valid base source" if "." in self.member: member_parts = self.member.split(".") @@ -236,6 +322,7 @@ def __post_init__(self) -> None: ) object.__setattr__(self, "member", member_parts[-1]) +<<<<<<< HEAD def reconstruct(self, codegen: "PyCodegen") -> None: codegen(self.base) codegen.extend_output(codegen.create_load_attrs(self.member)) @@ -244,6 +331,16 @@ def guard_source(self) -> GuardSource: return self.base.guard_source() def name(self) -> str: +======= + def reconstruct(self, codegen: "PyCodegen"): + codegen(self.base) + codegen.extend_output(codegen.create_load_attrs(self.member)) + + def guard_source(self): + return self.base.guard_source() + + def name(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not self.member.isidentifier(): return f"getattr({self.base.name()}, {self.member!r})" return f"{self.base.name()}.{self.member}" @@ -253,7 +350,11 @@ def name(self) -> str: class GenericAttrSource(ChainedSource): member: str +<<<<<<< HEAD def __post_init__(self) -> None: +======= + def __post_init__(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert self.base, "Can't construct an AttrSource without a valid base source" if "." in self.member: member_parts = self.member.split(".") @@ -262,6 +363,7 @@ def __post_init__(self) -> None: ) object.__setattr__(self, "member", member_parts[-1]) +<<<<<<< HEAD def reconstruct(self, codegen: "PyCodegen") -> None: codegen(self.base) codegen.extend_output(codegen.create_load_attrs(self.member)) @@ -305,6 +407,19 @@ def name(self) -> str: return f"{self.base.name()}.__mro__" +======= + def reconstruct(self, codegen: "PyCodegen"): + codegen(self.base) + codegen.extend_output(codegen.create_load_attrs(self.member)) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + return f"object.__getattribute__({self.base.name()}, {self.member!r})" + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclasses.dataclass(frozen=True) class LocalCellSource(Source): """ @@ -314,7 +429,11 @@ class LocalCellSource(Source): local_name: str +<<<<<<< HEAD def reconstruct(self, codegen: "PyCodegen") -> None: +======= + def reconstruct(self, codegen: "PyCodegen"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Although `LOAD_FAST` and `LOAD_CLOSURE` have the same semantics, # Dynamo's bytecode transformation differentiates them slightly, so we # always emit `LOAD_CLOSURE` here. @@ -324,6 +443,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None: # local cell object should never be used for guards. +<<<<<<< HEAD # Represents obj.__code__ where object is type object @dataclasses.dataclass(frozen=True) class CodeSource(ChainedSource): @@ -352,6 +472,8 @@ def name(self) -> str: return f"{self.base.name()}.__closure__" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Represents tensor.grad source. It could be represented by AttrSource as well. # But, we could access grad field on tensor directly in C++ without going # through the Python bytecodes. Therefore, we use a separate source for grad @@ -360,6 +482,7 @@ def name(self) -> str: class GradSource(ChainedSource): member: str = "grad" +<<<<<<< HEAD def reconstruct(self, codegen: "PyCodegen") -> None: codegen(self.base) codegen.extend_output(codegen.create_load_attrs(self.member)) @@ -368,12 +491,26 @@ def guard_source(self) -> GuardSource: return self.base.guard_source() def name(self) -> str: +======= + def reconstruct(self, codegen: "PyCodegen"): + codegen(self.base) + codegen.extend_output(codegen.create_load_attrs(self.member)) + + def guard_source(self): + return self.base.guard_source() + + def name(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"{self.base.name()}.{self.member}" @dataclasses.dataclass(frozen=True) class ParamBufferSource(AttrSource): +<<<<<<< HEAD def guard_source(self) -> GuardSource: +======= + def guard_source(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source()] @@ -396,6 +533,7 @@ class UnspecializedParamBufferSource(AttrSource): class EphemeralSource(Source): desc: Optional[str] = None +<<<<<<< HEAD def guard_source(self) -> GuardSource: return GuardSource.EPHEMERAL @@ -421,20 +559,42 @@ def name(self) -> str: return self.base.name() +======= + def guard_source(self): + return GuardSource.EPHEMERAL + + def name(self): + return f"" + + def make_guard(self, fn): + raise NotImplementedError + + def is_ephemeral(self): + return True + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TensorProperty(enum.Enum): SIZE = 0 STRIDE = 1 STORAGE_OFFSET = 2 +<<<<<<< HEAD def method_name(self) -> str: +======= + def method_name(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self is TensorProperty.SIZE: return "size" elif self is TensorProperty.STRIDE: return "stride" elif self is TensorProperty.STORAGE_OFFSET: return "storage_offset" +<<<<<<< HEAD else: raise AssertionError(f"unhandled {self}") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclasses.dataclass(frozen=True) @@ -442,14 +602,22 @@ class TensorPropertySource(ChainedSource): prop: TensorProperty idx: Optional[int] = None # None for STORAGE_OFFSET +<<<<<<< HEAD def __post_init__(self) -> None: +======= + def __post_init__(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert self.base is not None if self.prop is TensorProperty.STORAGE_OFFSET: assert self.idx is None else: assert self.idx is not None +<<<<<<< HEAD def reconstruct(self, codegen: "PyCodegen") -> None: +======= + def reconstruct(self, codegen: "PyCodegen"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) codegen.add_push_null( lambda: codegen.load_import_from( utils.__name__, f"call_{self.prop.method_name()}" @@ -463,10 +631,17 @@ def reconstruct(self, codegen: "PyCodegen") -> None: create_call_function(2 if self.idx is not None else 1, False) ) +<<<<<<< HEAD def guard_source(self) -> GuardSource: return self.base.guard_source() def name(self) -> str: +======= + def guard_source(self): + return self.base.guard_source() + + def name(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.prop is TensorProperty.SIZE: return f"{self.base.name()}.size()[{self.idx}]" elif self.prop is TensorProperty.STRIDE: @@ -482,6 +657,7 @@ def name(self) -> str: class IndexedSource(ChainedSource): idx: int +<<<<<<< HEAD def __post_init__(self) -> None: assert self.base is not None @@ -492,11 +668,24 @@ def guard_source(self) -> GuardSource: return self.base.guard_source() def name(self) -> str: +======= + def __post_init__(self): + assert self.base is not None + + def reconstruct(self, codegen: "PyCodegen"): + raise NotImplementedError + + def guard_source(self): + return self.base.guard_source() + + def name(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"({self.idx}, {self.base.name()})" @dataclasses.dataclass(frozen=True) class NegateSource(ChainedSource): +<<<<<<< HEAD def __post_init__(self) -> None: assert self.base is not None @@ -507,12 +696,25 @@ def guard_source(self) -> GuardSource: return self.base.guard_source() def name(self) -> str: +======= + def __post_init__(self): + assert self.base is not None + + def reconstruct(self, codegen: "PyCodegen"): + raise NotImplementedError + + def guard_source(self): + return self.base.guard_source() + + def name(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # NB: use method call so that function stripping regexes work return f"{self.base.name()}.__neg__()" @dataclasses.dataclass(frozen=True) class ConvertIntSource(ChainedSource): +<<<<<<< HEAD def __post_init__(self) -> None: assert self.base is not None @@ -523,11 +725,24 @@ def guard_source(self) -> GuardSource: return self.base.guard_source() def name(self) -> str: +======= + def __post_init__(self): + assert self.base is not None + + def reconstruct(self, codegen: "PyCodegen"): + codegen(self.base) + + def guard_source(self): + return self.base.guard_source() + + def name(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"cast_symbool_to_symint_guardless({self.base.name()})" @dataclasses.dataclass(frozen=True) class FlattenScriptObjectSource(ChainedSource): +<<<<<<< HEAD def __post_init__(self) -> None: assert self.base is not None @@ -538,11 +753,24 @@ def guard_source(self) -> GuardSource: return self.base.guard_source() def name(self) -> str: +======= + def __post_init__(self): + assert self.base is not None + + def reconstruct(self, codegen: "PyCodegen"): + codegen(self.base) + + def guard_source(self): + return self.base.guard_source() + + def name(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"{self.base.name()}.__obj_flatten__()" @dataclasses.dataclass(frozen=True) class ScriptObjectQualifiedNameSource(ChainedSource): +<<<<<<< HEAD def __post_init__(self) -> None: assert self.base is not None @@ -553,10 +781,23 @@ def guard_source(self) -> GuardSource: return self.base.guard_source() def name(self) -> str: +======= + def __post_init__(self): + assert self.base is not None + + def reconstruct(self, codegen: "PyCodegen"): + codegen(self.base) + + def guard_source(self): + return self.base.guard_source() + + def name(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"{self.base.name()}._type().qualified_name()" class AttrProxySource(ChainedSource): +<<<<<<< HEAD def reconstruct(self, codegen: "PyCodegen") -> None: codegen(self.base) @@ -564,6 +805,15 @@ def guard_source(self) -> GuardSource: return self.base.guard_source() def name(self) -> str: +======= + def reconstruct(self, codegen: "PyCodegen"): + codegen(self.base) + + def guard_source(self): + return self.base.guard_source() + + def name(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"{self.base.name()}.get_base()" @@ -574,7 +824,11 @@ class DefaultsSource(ChainedSource): field: str = dataclasses.field(init=False, repr=False, compare=False) _name: str = dataclasses.field(init=False, repr=False, compare=False) +<<<<<<< HEAD def __post_init__(self) -> None: +======= + def __post_init__(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert self.base, ( "Base must be a valid source in order to properly track and guard this Defaults to its origin." ) @@ -591,16 +845,27 @@ def __post_init__(self) -> None: self, "_name", f"{self.base.name()}.{self.field}[{self.idx_key}]" ) +<<<<<<< HEAD def reconstruct(self, codegen: "PyCodegen") -> None: +======= + def reconstruct(self, codegen: "PyCodegen"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) codegen(self.base) codegen.extend_output(codegen.create_load_attrs(self.field)) codegen.append_output(codegen.create_load_const(self.idx_key)) codegen.append_output(create_instruction("BINARY_SUBSCR")) +<<<<<<< HEAD def guard_source(self) -> GuardSource: return self.base.guard_source() def name(self) -> str: +======= + def guard_source(self): + return self.base.guard_source() + + def name(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self._name @@ -609,14 +874,22 @@ class GetItemSource(ChainedSource): index: Any index_is_slice: bool = False +<<<<<<< HEAD def __post_init__(self) -> None: +======= + def __post_init__(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert self.base is not None if isinstance(self.index, slice): # store the hashable version of the slice so the whole GetItemSource is hashable super().__setattr__("index", self.index.__reduce__()) super().__setattr__("index_is_slice", True) +<<<<<<< HEAD def reconstruct(self, codegen: "PyCodegen") -> None: +======= + def reconstruct(self, codegen: "PyCodegen"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) codegen(self.base) if self.index_is_slice: codegen.append_output(codegen.create_load_const(self.unpack_slice())) @@ -624,15 +897,26 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen.append_output(codegen.create_load_const(self.index)) codegen.append_output(create_instruction("BINARY_SUBSCR")) +<<<<<<< HEAD def guard_source(self) -> GuardSource: return self.base.guard_source() def unpack_slice(self) -> slice: +======= + def guard_source(self): + return self.base.guard_source() + + def unpack_slice(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert self.index_is_slice slice_class, slice_args = self.index return slice_class(*slice_args) +<<<<<<< HEAD def name(self) -> str: +======= + def name(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Index can be of following types # 1) index is a slice - example 1:4 # 2) index is a constant - example string, integer @@ -647,10 +931,17 @@ def name(self) -> str: class ConstDictKeySource(ChainedSource): index: Any +<<<<<<< HEAD def guard_source(self) -> GuardSource: return self.base.guard_source() def reconstruct(self, codegen: "PyCodegen") -> None: +======= + def guard_source(self): + return self.base.guard_source() + + def reconstruct(self, codegen: "PyCodegen"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) codegen.add_push_null( lambda: codegen.load_import_from(utils.__name__, "dict_keys_getitem") ) @@ -658,6 +949,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen.append_output(codegen.create_load_const(self.index)) codegen.extend_output(create_call_function(2, False)) +<<<<<<< HEAD def name(self) -> str: # The list creation will be CSE'd by PyExprCSEPass return f"list(dict.keys({self.base.name()}))[{self.index!r}]" @@ -694,6 +986,16 @@ def is_dict_key(self) -> bool: return False +======= + def name(self): + # The list creation will be CSE'd by PyExprCSEPass + return f"list(dict.keys({self.base.name()}))[{self.index!r}]" + + def is_dict_key(self): + return True + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Used to access an item from the dictionary @dataclasses.dataclass(frozen=True) class DictGetItemSource(ChainedSource): @@ -702,17 +1004,28 @@ class DictGetItemSource(ChainedSource): # 2) constant - like string, integer index: Any +<<<<<<< HEAD def __post_init__(self) -> None: +======= + def __post_init__(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .variables import ConstantVariable assert isinstance( self.index, ConstDictKeySource ) or ConstantVariable.is_literal(self.index) +<<<<<<< HEAD def guard_source(self) -> GuardSource: return self.base.guard_source() def reconstruct(self, codegen: "PyCodegen") -> None: +======= + def guard_source(self): + return self.base.guard_source() + + def reconstruct(self, codegen: "PyCodegen"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Load dict codegen(self.base) @@ -723,7 +1036,11 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen.append_output(codegen.create_load_const(self.index)) codegen.append_output(create_instruction("BINARY_SUBSCR")) +<<<<<<< HEAD def name(self) -> str: +======= + def name(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(self.index, ConstDictKeySource): return f"{self.base.name()}[{self.index.name()}]" else: @@ -739,17 +1056,28 @@ class DictSubclassGetItemSource(ChainedSource): # 2) constant - like string, integer index: Any +<<<<<<< HEAD def __post_init__(self) -> None: +======= + def __post_init__(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .variables import ConstantVariable assert isinstance( self.index, ConstDictKeySource ) or ConstantVariable.is_literal(self.index) +<<<<<<< HEAD def guard_source(self) -> GuardSource: return self.base.guard_source() def reconstruct(self, codegen: "PyCodegen") -> None: +======= + def guard_source(self): + return self.base.guard_source() + + def reconstruct(self, codegen: "PyCodegen"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # reconstruct dict.__getitem__(dct, key) # Load dict.__getitem__ @@ -768,7 +1096,11 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen.extend_output(create_call_function(2, False)) +<<<<<<< HEAD def name(self) -> str: +======= + def name(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(self.index, ConstDictKeySource): return f"dict.__getitem__({self.base.name()}, {self.index.name()})" else: @@ -781,7 +1113,11 @@ class ListGetItemSource(GetItemSource): Same as GetItemSource with reconstruct and name overridden to be list specific. """ +<<<<<<< HEAD def reconstruct(self, codegen: "PyCodegen") -> None: +======= + def reconstruct(self, codegen: "PyCodegen"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Reconstruct list.__getitem__(lst, index) to avoid any side effects # from possibly overridden __getitem__. @@ -803,7 +1139,11 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen.extend_output(create_call_function(2, False)) +<<<<<<< HEAD def name(self) -> str: +======= + def name(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Index can be of following types # 1) index is a slice - example 1:4 # 2) index is a constant - example string, integer @@ -818,7 +1158,11 @@ def name(self) -> str: @dataclasses.dataclass(frozen=True) class TupleIteratorGetItemSource(GetItemSource): +<<<<<<< HEAD def reconstruct(self, codegen: "PyCodegen") -> None: +======= + def reconstruct(self, codegen: "PyCodegen"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) codegen.add_push_null( lambda: codegen.load_import_from(utils.__name__, "tuple_iterator_getitem") ) @@ -826,11 +1170,16 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen.append_output(codegen.create_load_const(self.index)) codegen.extend_output(create_call_function(2, False)) +<<<<<<< HEAD def name(self) -> str: +======= + def name(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"___tuple_iterator_getitem({self.base.name()}, {self.index!r})" @dataclasses.dataclass(frozen=True) +<<<<<<< HEAD class NamedTupleFieldsSource(ChainedSource): def reconstruct(self, codegen: "PyCodegen") -> None: codegen(self.base) @@ -846,38 +1195,64 @@ def name(self) -> str: @dataclasses.dataclass(frozen=True) class DataclassFieldsSource(ChainedSource): def reconstruct(self, codegen: "PyCodegen") -> None: +======= +class DataclassFieldsSource(ChainedSource): + def reconstruct(self, codegen: "PyCodegen"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) codegen.add_push_null( lambda: codegen.load_import_from(utils.__name__, "dataclass_fields") ) codegen(self.base) codegen.extend_output(create_call_function(1, False)) +<<<<<<< HEAD def guard_source(self) -> GuardSource: return self.base.guard_source() def name(self) -> str: +======= + def guard_source(self): + return self.base.guard_source() + + def name(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"___dataclass_fields({self.base.name()})" @dataclasses.dataclass(frozen=True) class TypeSource(ChainedSource): +<<<<<<< HEAD def __post_init__(self) -> None: assert self.base is not None def reconstruct(self, codegen: "PyCodegen") -> None: +======= + def __post_init__(self): + assert self.base is not None + + def reconstruct(self, codegen: "PyCodegen"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) codegen.add_push_null(lambda: codegen.load_import_from("builtins", "type")) codegen(self.base) codegen.extend_output(create_call_function(1, False)) +<<<<<<< HEAD def guard_source(self) -> GuardSource: return self.base.guard_source() def name(self) -> str: +======= + def guard_source(self): + return self.base.guard_source() + + def name(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"type({self.base.name()})" @dataclasses.dataclass(frozen=True) class OptimizerSource(ChainedSource): +<<<<<<< HEAD def reconstruct(self, codegen: "PyCodegen") -> None: codegen(self.base) @@ -885,11 +1260,21 @@ def guard_source(self) -> GuardSource: return self.base.guard_source() def name(self) -> str: +======= + def reconstruct(self, codegen: "PyCodegen"): + codegen(self.base) + + def guard_source(self): + return self.base.guard_source() + + def name(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.base.name() @dataclasses.dataclass(frozen=True) class NNModuleSource(ChainedSource): +<<<<<<< HEAD def reconstruct(self, codegen: "PyCodegen") -> None: codegen(self.base) @@ -897,29 +1282,51 @@ def guard_source(self) -> GuardSource: return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source()] def name(self) -> str: +======= + def reconstruct(self, codegen: "PyCodegen"): + codegen(self.base) + + def guard_source(self): + return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source()] + + def name(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.base.name() @dataclasses.dataclass(frozen=True) class UnspecializedNNModuleSource(NNModuleSource): +<<<<<<< HEAD def guard_source(self) -> GuardSource: +======= + def guard_source(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return _GUARD_SOURCE_UNSPECIALIZED_NN_MODULE[self.base.guard_source()] @dataclasses.dataclass(frozen=True) class UnspecializedBuiltinNNModuleSource(UnspecializedNNModuleSource): +<<<<<<< HEAD def guard_source(self) -> GuardSource: +======= + def guard_source(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return _GUARD_SOURCE_UNSPECIALIZED_BUILTIN_NN_MODULE[self.base.guard_source()] @dataclasses.dataclass(frozen=True) class FSDPNNModuleSource(NNModuleSource): +<<<<<<< HEAD def guard_source(self) -> GuardSource: +======= + def guard_source(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return _GUARD_SOURCE_FSDP_MODULE[self.base.guard_source()] @dataclasses.dataclass(frozen=True) class GlobalStateSource(Source): +<<<<<<< HEAD def name(self) -> str: return "" @@ -951,6 +1358,12 @@ def reconstruct(self, codegen: "PyCodegen") -> None: ) def guard_source(self) -> GuardSource: +======= + def name(self): + return "" + + def guard_source(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return GuardSource.GLOBAL @@ -958,15 +1371,26 @@ def guard_source(self) -> GuardSource: class TorchFunctionModeStackSource(Source): ind: int +<<<<<<< HEAD def name(self) -> str: return f"___get_torch_function_mode_stack_at({self._get_index()})" def _get_index(self) -> int: +======= + def name(self): + return f"___get_torch_function_mode_stack_at({self._get_index()})" + + def _get_index(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .variables.torch_function import TorchFunctionModeStackVariable return TorchFunctionModeStackVariable.get_mode_index(self.ind) +<<<<<<< HEAD def reconstruct(self, codegen: "PyCodegen") -> None: +======= + def reconstruct(self, codegen: "PyCodegen"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) codegen.add_push_null( lambda: codegen.load_import_from( utils.__name__, "get_torch_function_mode_stack_at" @@ -975,7 +1399,11 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen.extend_output([codegen.create_load_const(self._get_index())]) codegen.extend_output(create_call_function(1, False)) +<<<<<<< HEAD def guard_source(self) -> GuardSource: +======= + def guard_source(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return GuardSource.GLOBAL @@ -983,6 +1411,7 @@ def guard_source(self) -> GuardSource: class ConstantSource(Source): source_name: str +<<<<<<< HEAD def reconstruct(self, codegen: "PyCodegen") -> None: codegen.append_output(codegen.create_load_global(self.source_name, add=False)) @@ -993,6 +1422,18 @@ def name(self) -> str: return self.source_name def make_guard(self, fn: Any) -> Any: +======= + def reconstruct(self, codegen: "PyCodegen"): + codegen.append_output(codegen.create_load_global(self.source_name, add=False)) + + def guard_source(self): + return GuardSource.CONSTANT + + def name(self): + return self.source_name + + def make_guard(self, fn): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise NotImplementedError @@ -1001,10 +1442,17 @@ class NumpyTensorSource(ChainedSource): def name(self) -> str: return f"___from_numpy({self.base.name()})" +<<<<<<< HEAD def guard_source(self) -> GuardSource: return self.base.guard_source() def reconstruct(self, codegen: "PyCodegen") -> None: +======= + def guard_source(self): + return self.base.guard_source() + + def reconstruct(self, codegen: "PyCodegen"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) codegen.add_push_null(lambda: codegen.load_import_from("torch", "as_tensor")) codegen(self.base) codegen.extend_output(create_call_function(1, False)) @@ -1015,7 +1463,11 @@ class SubclassAttrListSource(ChainedSource): def name(self) -> str: return f"{self.base.name()}.__tensor_flatten__()[0]" +<<<<<<< HEAD def guard_source(self) -> GuardSource: +======= + def guard_source(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.base.guard_source() @@ -1026,7 +1478,11 @@ class FloatTensorSource(ChainedSource): def name(self) -> str: return f"___as_tensor({self.base.name()})" +<<<<<<< HEAD def guard_source(self) -> GuardSource: +======= + def guard_source(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.base.guard_source() @@ -1035,7 +1491,11 @@ class CallMethodItemSource(ChainedSource): def name(self) -> str: return f"{self.base.name()}.item()" +<<<<<<< HEAD def guard_source(self) -> GuardSource: +======= + def guard_source(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.base.guard_source() @@ -1044,15 +1504,23 @@ def guard_source(self) -> GuardSource: # guard contents from the ambient ShapeEnv @dataclasses.dataclass(frozen=True) class ShapeEnvSource(Source): +<<<<<<< HEAD def name(self) -> str: return "" def guard_source(self) -> GuardSource: +======= + def name(self): + return "" + + def guard_source(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return GuardSource.SHAPE_ENV @dataclasses.dataclass(frozen=True) class BackwardStateSource(Source): +<<<<<<< HEAD def name(self) -> str: return "" @@ -1063,6 +1531,16 @@ def guard_source(self) -> GuardSource: def get_local_source_name( source: Source, *, only_allow_input: bool = False ) -> Optional[str]: +======= + def name(self): + return "" + + def guard_source(self): + return GuardSource.BACKWARD_STATE + + +def get_local_source_name(source: Source, *, only_allow_input=False) -> Optional[str]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(source, ChainedSource): return get_local_source_name(source.base, only_allow_input=only_allow_input) if not isinstance(source, LocalSource): @@ -1072,7 +1550,11 @@ def get_local_source_name( return source.local_name +<<<<<<< HEAD def is_from_local_source(source: Source, *, only_allow_input: bool = False) -> bool: +======= +def is_from_local_source(source: Source, *, only_allow_input=False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return get_local_source_name(source, only_allow_input=only_allow_input) is not None @@ -1088,7 +1570,11 @@ def get_global_source_name(source: Source) -> Optional[str]: return source.global_name +<<<<<<< HEAD def is_from_nonlocal_source(source: Source) -> bool: +======= +def is_from_nonlocal_source(source: Source): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(source, ChainedSource): return is_from_nonlocal_source(source.base) return ( @@ -1098,6 +1584,7 @@ def is_from_nonlocal_source(source: Source) -> bool: ) +<<<<<<< HEAD def is_from_closure_source(source: Source) -> bool: if isinstance(source, ClosureSource): return True @@ -1107,13 +1594,20 @@ def is_from_closure_source(source: Source) -> bool: def is_from_source(source: Source, target: Source) -> bool: +======= +def is_from_source(source: Source, target: Source): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(source, ChainedSource): return is_from_source(source.base, target) return source == target @functools.lru_cache +<<<<<<< HEAD def is_from_unspecialized_nn_module_source(source: Source) -> bool: +======= +def is_from_unspecialized_nn_module_source(source: Source): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(source, UnspecializedNNModuleSource): return True if isinstance(source, ChainedSource): @@ -1122,7 +1616,11 @@ def is_from_unspecialized_nn_module_source(source: Source) -> bool: @functools.lru_cache +<<<<<<< HEAD def is_from_unspecialized_builtin_nn_module_source(source: Source) -> bool: +======= +def is_from_unspecialized_builtin_nn_module_source(source: Source): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(source, UnspecializedBuiltinNNModuleSource): return True if isinstance(source, ChainedSource): @@ -1131,7 +1629,11 @@ def is_from_unspecialized_builtin_nn_module_source(source: Source) -> bool: @functools.lru_cache +<<<<<<< HEAD def is_from_unspecialized_param_buffer_source(source: Source) -> bool: +======= +def is_from_unspecialized_param_buffer_source(source: Source): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(source, UnspecializedParamBufferSource): return True if isinstance(source, ChainedSource): @@ -1140,7 +1642,11 @@ def is_from_unspecialized_param_buffer_source(source: Source) -> bool: @functools.lru_cache +<<<<<<< HEAD def is_from_flatten_script_object_source(source: Source) -> bool: +======= +def is_from_flatten_script_object_source(source: Source): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(source, FlattenScriptObjectSource): return True elif isinstance(source, ChainedSource): @@ -1149,7 +1655,11 @@ def is_from_flatten_script_object_source(source: Source) -> bool: @functools.lru_cache +<<<<<<< HEAD def is_from_optimizer_source(source: Source) -> bool: +======= +def is_from_optimizer_source(source: Source): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(source, OptimizerSource): return True if isinstance(source, ChainedSource): @@ -1160,7 +1670,11 @@ def is_from_optimizer_source(source: Source) -> bool: # TODO: can probably write a generic "test this on everything in the chain" # helper @functools.lru_cache +<<<<<<< HEAD def is_from_defaults(source: Source) -> bool: +======= +def is_from_defaults(source: Source): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(source, DefaultsSource): return True @@ -1183,6 +1697,7 @@ def is_from_defaults(source: Source) -> bool: if isinstance(source, ChainedSource): return is_from_defaults(source.base) return False +<<<<<<< HEAD @functools.lru_cache @@ -1194,3 +1709,5 @@ def is_from_skip_guard_source(source: Source) -> bool: return is_from_skip_guard_source(source.base) return False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 4dd1321a5057d..418b043db8796 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -1,3 +1,8 @@ +<<<<<<< HEAD +======= +# mypy: allow-untyped-defs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Core module responsible for converting Python bytecode into TorchDynamo's symbolic execution format. @@ -22,8 +27,11 @@ optimization of PyTorch programs. """ +<<<<<<< HEAD from __future__ import annotations +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import collections import collections.abc import contextlib @@ -42,15 +50,25 @@ import threading import traceback import types +<<<<<<< HEAD import weakref from traceback import StackSummary from typing import Any, Callable, cast, NoReturn, Optional, TYPE_CHECKING, Union from typing_extensions import TypeAlias, TypeIs +======= +import typing +import weakref +from typing import Any, Callable, cast, NoReturn, Optional, TYPE_CHECKING, Union +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from unittest.mock import patch import torch import torch._logging +<<<<<<< HEAD from torch._dynamo.exc import ObservedException, TensorifyScalarRestartAnalysis +======= +from torch._dynamo.exc import TensorifyScalarRestartAnalysis +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._guards import tracing, TracingContext from torch._logging.structured import dump_file from torch.fx.experimental.symbolic_shapes import guard_bool @@ -72,6 +90,7 @@ ) from .bytecode_transformation import ( cleaned_instructions, +<<<<<<< HEAD create_binary_slice, create_call_function, create_copy, @@ -79,11 +98,19 @@ create_instruction, create_jump_absolute, create_rot_n, +======= + create_call_function, + create_instruction, + create_jump_absolute, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) create_swap, get_code_keys, Instruction, is_generator, +<<<<<<< HEAD is_jump_absolute, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) unique_id, ) from .code_context import code_context @@ -94,13 +121,17 @@ collapse_resume_frames, format_graph_break_message, get_stack_above_dynamo, +<<<<<<< HEAD ResumePrologueTracingError, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) unimplemented_v2, Unsupported, ) from .funcname_cache import get_funcname from .guards import GuardBuilder, install_guard from .output_graph import GraphCompileReason, OutputGraph +<<<<<<< HEAD from .polyfills import impl_CONTAINS_OP_fallback from .replay_record import DummyModule, ExecutionRecorder from .resume_execution import ( @@ -108,6 +139,10 @@ IS_TRACING_RESUME_PROLOGUE_VARNAME, ReenterWith, ) +======= +from .replay_record import DummyModule, ExecutionRecorder +from .resume_execution import ContinueExecutionCache, ReenterWith +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .source import ( AttrSource, DictGetItemSource, @@ -115,12 +150,18 @@ GlobalWeakRefSource, LocalCellSource, LocalSource, +<<<<<<< HEAD SkipGuardSource, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Source, ) from .trace_rules import is_builtin_constant, is_forbidden from .utils import ( +<<<<<<< HEAD _get_error_on_graph_break, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) counters, get_fake_value, get_instruction_source_311, @@ -153,7 +194,10 @@ from .variables.lazy import LazyVariableTracker from .variables.lists import ( BaseListVariable, +<<<<<<< HEAD IteratorVariable, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ListIteratorVariable, ListVariable, SliceVariable, @@ -183,10 +227,13 @@ if TYPE_CHECKING: +<<<<<<< HEAD from collections.abc import Generator, Sequence from torch._subclasses.fake_tensor import FakeTensorMode +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .package import CompilePackage log = logging.getLogger(__name__) @@ -207,6 +254,7 @@ tx, [handle_contains(tx, [*reversed(args)], {})], {} ) +<<<<<<< HEAD PT2_ISSUE_TRACKER_URL = "https://github.com/pytorch/pytorch/issues/new?&labels=oncall%3A+pt2&projects=&template=pt2-bug-report.yml" ExceptionVals: TypeAlias = Union[ @@ -214,6 +262,10 @@ UserDefinedExceptionClassVariable, UserDefinedExceptionObjectVariable, ] +======= + +PT2_ISSUE_TRACKER_URL = "https://github.com/pytorch/pytorch/issues/new?&labels=oncall%3A+pt2&projects=&template=pt2-bug-report.yml" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @functools.cache @@ -232,6 +284,7 @@ class SpeculationEntry: lineno: int instruction_pointer: int inst: Instruction # for debugging only +<<<<<<< HEAD _failed: bool = False error_on_graph_break: Optional[bool] = None reason: Optional[GraphCompileReason] = None @@ -242,12 +295,23 @@ def fail_and_restart_analysis(self, error_on_graph_break: bool) -> None: """ self._failed = True self.error_on_graph_break = error_on_graph_break +======= + failed: bool = False + reason: Optional[GraphCompileReason] = None + + def fail_and_restart_analysis(self): + """ + Start tracing of the current frame over again, and don't take this branch. + """ + self.failed = True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.reason is not None: restart_reason = self.reason.reason else: restart_reason = "Unknown fail_and_restart_analysis" raise exc.SpeculationRestartAnalysis(restart_reason=restart_reason) +<<<<<<< HEAD def failed(self, tx: InstructionTranslatorBase) -> bool: if self._failed: assert self.error_on_graph_break is not None @@ -255,6 +319,8 @@ def failed(self, tx: InstructionTranslatorBase) -> bool: return True return False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclasses.dataclass class SpeculationLog: @@ -269,15 +335,26 @@ class SpeculationLog: entries: list[SpeculationEntry] = dataclasses.field(default_factory=list) index: int = 0 +<<<<<<< HEAD def restart(self) -> None: self.index = 0 def clear(self) -> None: +======= + def restart(self): + self.index = 0 + + def clear(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.entries.clear() self.index = 0 def next( +<<<<<<< HEAD self, filename: str, lineno: int, instruction_pointer: int, inst: Instruction +======= + self, filename: str, lineno: int, instruction_pointer, inst +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> SpeculationEntry: """ Lookup or create a SpeculationEntry() that is shared across @@ -368,14 +445,22 @@ def empty(cls) -> bool: @functools.cache +<<<<<<< HEAD def _step_logger() -> Callable[..., None]: +======= +def _step_logger(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return torchdynamo_logging.get_step_logger(log) @contextlib.contextmanager +<<<<<<< HEAD def save_and_restart_speculation_log( tx: InstructionTranslatorBase, ) -> Generator[None, None, None]: +======= +def save_and_restart_speculation_log(tx: "InstructionTranslatorBase"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # When reconstructing a generator after a graph break, we advance it until # it is fully exhausted. This process adds new entries to the speculation # log that were not previously observed. Without temporarily clearing the @@ -393,9 +478,13 @@ def save_and_restart_speculation_log( @contextlib.contextmanager +<<<<<<< HEAD def temporarely_allow_writes_to_output_graph( tx: InstructionTranslatorBase, ) -> Generator[None, None, None]: +======= +def temporarely_allow_writes_to_output_graph(tx: "InstructionTranslatorBase"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: tmp = tx.output.should_exit tx.output.should_exit = False @@ -414,10 +503,17 @@ class BlockStackEntry: Union[ContextWrappingVariable, GenericContextWrappingVariable] ] = None +<<<<<<< HEAD def can_restore(self) -> bool: return self.with_context is not None def resume_fn(self) -> ReenterWith: +======= + def can_restore(self): + return self.with_context is not None + + def resume_fn(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert self.stack_index is not None if ( self.with_context @@ -430,12 +526,20 @@ def resume_fn(self) -> ReenterWith: else: return ReenterWith(self.stack_index - 1) +<<<<<<< HEAD def exit(self, tx: InstructionTranslatorBase, is_graph_break: bool) -> None: +======= + def exit(self, tx, is_graph_break): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert self.with_context is not None if ( is_graph_break and self.with_context.exit_on_graph_break() ) or not is_graph_break: +<<<<<<< HEAD return self.with_context.exit(tx) # type: ignore[arg-type] +======= + return self.with_context.exit(tx) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class SpeculationLogDivergence(AssertionError): @@ -453,17 +557,26 @@ class YieldValueOp(Exception): """ +<<<<<<< HEAD def stack_op(fn: Callable[..., object]) -> Callable[..., Any]: +======= +def stack_op(fn: typing.Callable[..., object]): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) nargs = len(inspect.signature(fn).parameters) fn_var = BuiltinVariable(fn) @functools.wraps(fn) +<<<<<<< HEAD def impl(self: InstructionTranslator, inst: Instruction) -> None: +======= + def impl(self: "InstructionTranslator", inst: Instruction): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.push(fn_var.call_function(self, self.popn(nargs), {})) return impl +<<<<<<< HEAD def is_stdlib(mod: object) -> bool: if sys.version_info < (3, 10): # For < 3.10, no easy way to identify a stdlib module name. @@ -478,6 +591,13 @@ def _detect_and_normalize_assert_statement( truth_fn: Callable[[object], bool], push: bool, ) -> bool: +======= +def _detect_and_normalize_assert_statement( + self: "InstructionTranslatorBase", + truth_fn: typing.Callable[[object], bool], + push: bool, +): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Detect if this jump instruction is assert and normalize the assert # by pushing dummy error message when nothing is given. # @@ -542,12 +662,16 @@ def _detect_and_normalize_assert_statement( explain = False +<<<<<<< HEAD def log_graph_break( code_options: dict[str, Any], reason: str = "", exc_info: bool = False, user_stack: Optional[StackSummary] = None, ) -> None: +======= +def log_graph_break(code_options, reason="", exc_info=False, user_stack=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if user_stack is None: user_stack = torch._guards.TracingContext.extract_stack() @@ -567,7 +691,11 @@ def log_graph_break( traceback.format_list(stack_above_dynamo) ) else: +<<<<<<< HEAD user_stack = get_stack_above_dynamo() + user_stack # type: ignore[assignment] +======= + user_stack = get_stack_above_dynamo() + user_stack +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) user_stack = collapse_resume_frames(user_stack) user_stack_formatted = "".join(traceback.format_list(user_stack)) user_stack_trace = ( @@ -623,9 +751,13 @@ def log_graph_break( ) +<<<<<<< HEAD def generic_jump( truth_fn: Callable[[object], bool], push: bool ) -> Callable[[InstructionTranslatorBase, Instruction], None]: +======= +def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # graph break message fields for data dependent branching _gb_type = "Data-dependent branching" _explanation = ( @@ -637,12 +769,16 @@ def generic_jump( "Use `torch.cond` to express dynamic control flow.", ] +<<<<<<< HEAD def jump_graph_break( self: InstructionTranslatorBase, inst: Instruction, value: VariableTracker, extra_msg: str = "", ) -> None: +======= + def jump_graph_break(self, inst, value, extra_msg=""): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) log_graph_break( self.code_options, reason=format_graph_break_message( @@ -674,6 +810,7 @@ def jump_graph_break( self.pop() if_next = self.create_call_resume_at( +<<<<<<< HEAD self.next_instruction, all_stack_locals_metadata, False ) if push: @@ -682,6 +819,13 @@ def jump_graph_break( if_jump = self.create_call_resume_at( inst.target, all_stack_locals_metadata, False ) +======= + self.next_instruction, all_stack_locals_metadata + ) + if push: + self.push(value) + if_jump = self.create_call_resume_at(inst.target, all_stack_locals_metadata) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if sys.version_info >= (3, 13): # 3.13 requires stack[-1] to be bool type @@ -691,7 +835,11 @@ def jump_graph_break( jump_inst.copy_positions(inst) self.output.add_output_instructions([jump_inst] + if_next + if_jump) +<<<<<<< HEAD def inner(self: InstructionTranslatorBase, inst: Instruction) -> None: +======= + def inner(self: "InstructionTranslatorBase", inst: Instruction): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) value: VariableTracker = self.pop() if ( config.rewrite_assert_with_torch_assert @@ -874,6 +1022,7 @@ def inner(self: InstructionTranslatorBase, inst: Instruction) -> None: self.jump(inst) else: unimplemented_v2( +<<<<<<< HEAD gb_type="Data-dependent branching", context=f"attempted to jump with {value}", explanation=_explanation, @@ -881,11 +1030,18 @@ def inner(self: InstructionTranslatorBase, inst: Instruction) -> None: *graph_break_hints.FUNDAMENTAL, "Use `torch.cond` to express dynamic control flow.", ], +======= + gb_type=_gb_type, + context=f"attempted to jump with {value}", + explanation=_explanation, + hints=_hints, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return inner +<<<<<<< HEAD def break_graph_if_unsupported( *, push: int ) -> Callable[ @@ -898,6 +1054,14 @@ def decorator( def wrapper(self: InstructionTranslatorBase, inst: Instruction) -> None: speculation = self.speculate() if speculation.failed(self): +======= +def break_graph_if_unsupported(*, push): + def decorator(inner_fn): + @functools.wraps(inner_fn) + def wrapper(self: "InstructionTranslatorBase", inst: Instruction): + speculation = self.speculate() + if speculation.failed: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert speculation.reason is not None return handle_graph_break(self, inst, speculation.reason) try: @@ -942,6 +1106,7 @@ def wrapper(self: InstructionTranslatorBase, inst: Instruction) -> None: excp.remove_from_stats() excp.add_to_stats("graph_break") speculation.reason = GraphCompileReason(excp.msg, excp.real_stack) +<<<<<<< HEAD speculation.fail_and_restart_analysis(self.error_on_graph_break) def handle_graph_break( @@ -949,6 +1114,15 @@ def handle_graph_break( inst: Instruction, reason: GraphCompileReason, ) -> None: +======= + speculation.fail_and_restart_analysis() + + def handle_graph_break( + self: "InstructionTranslatorBase", + inst: Instruction, + reason: GraphCompileReason, + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if ( sys.version_info >= (3, 11) and sys.version_info < (3, 12) @@ -996,7 +1170,10 @@ def handle_graph_break( self.output.add_output_instructions( [create_instruction("KW_NAMES", argval=kw_names)] ) +<<<<<<< HEAD assert inst.arg is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) call_insts = create_call_function(inst.arg, False) call_insts[-1].copy_positions(inst) self.output.add_output_instructions(call_insts) @@ -1014,7 +1191,11 @@ def handle_graph_break( self.push(UnknownVariable()) self.output.add_output_instructions( self.create_call_resume_at( +<<<<<<< HEAD self.next_instruction, all_stack_locals_metadata, False +======= + self.next_instruction, all_stack_locals_metadata +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) ) @@ -1026,10 +1207,17 @@ def handle_graph_break( class BytecodeDistpatchTableMeta(type): """Installs a `cls.dispatch_table` on every subclass to speed up calls to self.OPCODE()""" +<<<<<<< HEAD def __init__(cls: type, name: str, bases: Any, dct: Any) -> None: super().__init__(name, bases, dct) # type: ignore[misc] def _missing(opname: str, *args: Any) -> None: +======= + def __init__(cls, name, bases, dct) -> None: + super().__init__(name, bases, dct) + + def _missing(opname, *args): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) unimplemented_v2( gb_type="Missing bytecode handler", context=f"{opname} with args {args}", @@ -1065,6 +1253,7 @@ class ExceptionStack: # + PUSH_EXC_INFO := pushes the current_exception to the *exception stack* # + POP_EXCEPT := pops TOS from the *exception stack* +<<<<<<< HEAD _exc_stack: list[ExceptionVals] = dataclasses.field(default_factory=list) _current_exception: Optional[ExceptionVals] = dataclasses.field(default=None) @@ -1076,10 +1265,24 @@ def set_current_exception(self, val: ExceptionVals) -> None: self._current_exception = val def move_current_exception_to_stack(self) -> None: +======= + _exc_stack: list[VariableTracker] = dataclasses.field(default_factory=list) + _current_exception: Optional[VariableTracker] = dataclasses.field(default=None) + + def clear_current_exception(self): + self._current_exception = None + + def set_current_exception(self, val): + self._set_context_and_break_context_reference_cycle(val) + self._current_exception = val + + def move_current_exception_to_stack(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert self._current_exception is not None self.append(self._current_exception) self.clear_current_exception() +<<<<<<< HEAD def get_current_exception(self) -> ExceptionVals: assert self._current_exception is not None return self._current_exception @@ -1088,14 +1291,29 @@ def _set_context_recursive( self, val: ExceptionVals, prev_idx: int ) -> ExceptionVals: if (ctx := val.__context__) and type(ctx) is not ConstantVariable: # type: ignore[union-attr] +======= + def get_current_exception(self): + assert self._current_exception is not None + return self._current_exception + + def _set_context_recursive(self, val, prev_idx): + if (ctx := val.__context__) and type(ctx) is not ConstantVariable: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return val if len(self._exc_stack) + prev_idx > 0: prev = self._exc_stack[prev_idx] self._set_context_recursive(prev, prev_idx - 1) +<<<<<<< HEAD val.set_context(prev) # type: ignore[union-attr, arg-type] return val def _break_context_reference_cycle(self, val: ExceptionVals) -> None: +======= + val.set_context(prev) + return val + + def _break_context_reference_cycle(self, val): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # See test_exceptions::test_raise_does_not_create_context_chain_cycle # Based on https://github.com/python/cpython/blob/e635bf2e49797ecb976ce45a67fce2201a25ca68/Python/errors.c#L207-L228 # As noted on CPython, this is O(chain length) but the context chains @@ -1103,21 +1321,33 @@ def _break_context_reference_cycle(self, val: ExceptionVals) -> None: o = slow_o = val slow_update_toggle = False # floyd's algorithm for detecting cycle while True: +<<<<<<< HEAD context = o.__context__ # type: ignore[union-attr] +======= + context = o.__context__ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if type(context) is ConstantVariable: # context not set break if context is val: +<<<<<<< HEAD o.set_context(ConstantVariable(None)) # type: ignore[union-attr, arg-type] break o = context # type: ignore[assignment] +======= + o.set_context(ConstantVariable(None)) + break + + o = context +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if o is slow_o: # pre-existing cycle - all exceptions on the path were # visited and checked break if slow_update_toggle: +<<<<<<< HEAD # visited all exceptions slow_o = slow_o.__context__ # type: ignore[union-attr, assignment] slow_update_toggle = not slow_update_toggle @@ -1125,10 +1355,17 @@ def _break_context_reference_cycle(self, val: ExceptionVals) -> None: def _set_context_and_break_context_reference_cycle( self, val: ExceptionVals ) -> None: +======= + slow_o = slow_o.__context__ # visited all exceptions + slow_update_toggle = not slow_update_toggle + + def _set_context_and_break_context_reference_cycle(self, val): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # set Exception.__context__ self._set_context_recursive(val, len(self._exc_stack) - 1) self._break_context_reference_cycle(val) +<<<<<<< HEAD def pop(self) -> ExceptionVals: return self._exc_stack.pop() @@ -1142,6 +1379,21 @@ def __getitem__(self, index: int) -> ExceptionVals: return self._exc_stack[index] def __str__(self) -> str: +======= + def pop(self): + return self._exc_stack.pop() + + def append(self, val): + self._exc_stack.append(val) + + def __len__(self): + return len(self._exc_stack) + + def __getitem__(self, index): + return self._exc_stack[index] + + def __str__(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"{self._exc_stack=} - {self._current_exception=}" __repr__ = __str__ @@ -1154,7 +1406,10 @@ class InstructionTranslatorBase( symbolic_locals: dict[str, VariableTracker] symbolic_globals: dict[str, VariableTracker] symbolic_torch_function_state: SymbolicTorchFunctionState +<<<<<<< HEAD post_prune_cell_and_freevars: Optional[dict[str, VariableTracker]] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) stack: list[VariableTracker] instruction_pointer: Optional[int] current_instruction: Instruction @@ -1172,11 +1427,19 @@ class InstructionTranslatorBase( strict_checks_fn: Optional[Callable[[VariableTracker], bool]] start_point: Optional[int] is_leaf_tracer: bool +<<<<<<< HEAD parent: Optional[InstructionTranslatorBase] debug_locals: list[tuple[VariableTracker, list[VariableTracker]]] package: Optional[CompilePackage] def mark_inconsistent_side_effects(self) -> None: +======= + parent: Optional["InstructionTranslatorBase"] + debug_locals: list[tuple[VariableTracker, list[VariableTracker]]] + package: Optional["CompilePackage"] + + def mark_inconsistent_side_effects(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ InstructionTranslator has encountered instructions which may cause dynamo to see a different version of history from eager @@ -1184,7 +1447,11 @@ def mark_inconsistent_side_effects(self) -> None: """ self.inconsistent_side_effects = True +<<<<<<< HEAD def maybe_has_backedge(self) -> bool: +======= + def maybe_has_backedge(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This function employs a heuristic. It does not reliably detect a backedge. # The heuristic is straightforward: starting from the current instruction and # continuing to the end, if any jump instruction targets an instruction before @@ -1204,6 +1471,7 @@ def maybe_has_backedge(self) -> bool: # graph during a for loop. In general, its better to have fewer false # negatives so that Dynamo does not skip the whole frame. +<<<<<<< HEAD # If any parent tx has a backedge, then return True cur_tx: Optional[InstructionTranslatorBase] = self while cur_tx is not None: @@ -1226,10 +1494,43 @@ def freevars(self) -> list[str]: return self.code_options["co_freevars"] def cell_and_freevars(self) -> list[str]: +======= + cur_offset = self.current_instruction.offset + assert self.instruction_pointer is not None + for inst in self.instructions[self.instruction_pointer :]: + if inst.opname in ("RETURN_VALUE", "RETURN_CONST"): + return False + if inst.opname in JUMP_OPNAMES: + jump_offset = inst.argval + if jump_offset < cur_offset: + return True + return False + + def cellvars(self): + if not hasattr(self, "_cellvars"): + self._cellvars = tuple(self.code_options["co_cellvars"] or []) + # An inlined function might depend on the cellvar of the parent + # function. So, recursively obtain parent cellvars. + if isinstance(self, InliningInstructionTranslator): + self._cellvars += self.parent.cellvars() + return self._cellvars + + def freevars(self): + if not hasattr(self, "_freevars"): + self._freevars = tuple(self.code_options["co_freevars"] or []) + # An inlined function might depend on the freevar of the parent + # function. So, recursively obtain parent freevars. + if isinstance(self, InliningInstructionTranslator): + self._freevars += self.parent.freevars() + return self._freevars + + def cell_and_freevars(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not hasattr(self, "_cell_and_freevars"): self._cell_and_freevars = self.cellvars() + self.freevars() return self._cell_and_freevars +<<<<<<< HEAD def prune_dead_locals(self) -> None: # keep cell and freevar references alive self.post_prune_cell_and_freevars = { @@ -1237,18 +1538,30 @@ def prune_dead_locals(self) -> None: for k, v in self.symbolic_locals.items() if k in self.cell_and_freevars() } +======= + def prune_dead_locals(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Only keep the locals that must remain on the stack. reads = livevars_analysis(self.instructions, self.current_instruction) self.symbolic_locals = { k: v for k, v in self.symbolic_locals.items() if k in reads } +<<<<<<< HEAD +======= + # "Garbage collect the heap". + self.output.side_effects.prune_dead_object_new(self) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def call_function( self, fn: VariableTracker, args: list[VariableTracker], kwargs: dict[str, VariableTracker], +<<<<<<< HEAD ) -> None: +======= + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(fn, VariableTracker) assert isinstance(args, list) assert isinstance(kwargs, dict) @@ -1265,13 +1578,18 @@ def call_function( raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}") self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type] +<<<<<<< HEAD def inline_generator_function( self, fn: VariableTracker, args: Sequence[Any], kwargs: dict[str, Any] ) -> Any: +======= + def inline_generator_function(self, fn, args, kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Redirect the call to the generator "call_function" """ if not isinstance(fn, LocalGeneratorFunctionVariable): +<<<<<<< HEAD fn = LocalGeneratorFunctionVariable(fn) # type: ignore[arg-type] return fn.call_function(self, args, kwargs) # type: ignore[arg-type] @@ -1283,11 +1601,25 @@ def inline_user_function_return( """ self.is_leaf_tracer = False if config.enable_faithful_generator_behavior and is_generator(fn.get_code()): # type: ignore[attr-defined] +======= + fn = LocalGeneratorFunctionVariable(fn) + return fn.call_function(self, args, kwargs) + + def inline_user_function_return(self, fn, args, kwargs): + """ + A call to some user defined function by inlining it. + """ + if config.enable_faithful_generator_behavior and is_generator(fn.get_code()): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.inline_generator_function(fn, args, kwargs) else: return InliningInstructionTranslator.inline_call(self, fn, args, kwargs) +<<<<<<< HEAD def get_line_of_code_header(self, lineno: Optional[int] = None) -> str: +======= + def get_line_of_code_header(self, lineno=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if lineno is None: lineno = self.lineno inline_depth_str = ( @@ -1297,13 +1629,21 @@ def get_line_of_code_header(self, lineno: Optional[int] = None) -> str: funcname_str = "" if funcname is None else f" ({funcname})" return f"{self.f_code.co_filename}:{lineno} in {self.f_code.co_name}{funcname_str}{inline_depth_str}" +<<<<<<< HEAD def get_log_starts_line_log_str(self) -> str: +======= + def get_log_starts_line_log_str(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) log_str = f"TRACE starts_line {self.get_line_of_code_header()}\n" line = linecache.getline(self.f_code.co_filename, self.lineno).rstrip() log_str += f" {line}" return log_str +<<<<<<< HEAD def starts_line(self, lineno: int) -> None: +======= + def starts_line(self, lineno): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.lineno == lineno: return self.lineno = lineno @@ -1314,10 +1654,15 @@ def starts_line(self, lineno: int) -> None: if self.is_trace_source_log_enabled: trace_source_log.debug("%s", LazyString(self.get_log_starts_line_log_str)) +<<<<<<< HEAD def step(self) -> bool: """Process exactly one instruction, return False we should exit""" self.error_on_graph_break = _get_error_on_graph_break() +======= + def step(self): + """Process exactly one instruction, return False we should exit""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ip = self.instruction_pointer if ip is None: return False @@ -1333,9 +1678,14 @@ def step(self) -> bool: and self.is_non_empty_graph() ): self.current_speculation = self.speculate() +<<<<<<< HEAD if self.current_speculation.failed(self): self.step_graph_break(inst) return False +======= + if self.current_speculation.failed: + return self.step_graph_break(inst) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.is_trace_bytecode_log_enabled: trace_bytecode_log.debug( @@ -1360,12 +1710,20 @@ def step(self) -> bool: raise log.debug("step triggered compile", exc_info=True) +<<<<<<< HEAD self.current_speculation.fail_and_restart_analysis(self.error_on_graph_break) return False if sys.version_info >= (3, 11): def update_block_stack(self, inst: Instruction) -> None: +======= + self.current_speculation.fail_and_restart_analysis() + + if sys.version_info >= (3, 11): + + def update_block_stack(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # 3.11+ no longer uses a block stack, but we still keep track of one # so that we know which contexts are currently active. # For our purposes, all exception table entries with the same target @@ -1406,6 +1764,7 @@ def update_block_stack(self, inst: Instruction) -> None: else: +<<<<<<< HEAD def update_block_stack(self, inst: Instruction) -> None: pass @@ -1415,6 +1774,16 @@ def next_instruction(self) -> Instruction: return self.instructions[self.instruction_pointer] def step_graph_break(self, continue_inst: Instruction) -> None: +======= + def update_block_stack(self, inst): + pass + + @property + def next_instruction(self): + return self.instructions[self.instruction_pointer] # type: ignore[index] + + def step_graph_break(self, continue_inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # generate code from checkpoint assert not self.output.output_instructions assert self.current_speculation is not None @@ -1422,14 +1791,19 @@ def step_graph_break(self, continue_inst: Instruction) -> None: # where we call step_graph_break right now is when the stack is empty, # so let's enforce that for now. assert not self.stack +<<<<<<< HEAD # NOTE: if we support non-empty self.stack in the future, the `stack_pops` argument # below should be set to the stack length to ensure that the stack is codegen'd # for the rest of the function. all_stack_locals_metadata = self.output.compile_subgraph( +======= + self.output.compile_subgraph( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self, partial_convert=True, reason=GraphCompileReason("step_unsupported", [self.frame_summary()]), ) +<<<<<<< HEAD if self.parent: # nested graph break assert config.nested_graph_breaks @@ -1472,17 +1846,29 @@ def step_graph_break(self, continue_inst: Instruction) -> None: ) def run_ctx_mgr(self) -> Any: +======= + self.output.add_output_instructions( + [create_jump_absolute(continue_inst)] + self.instructions + ) + + def run_ctx_mgr(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # NB: Don't push the top level frame summary; set_current_loc will # take care of it. However, DO make sure we attach real_stack to # exceptions return TracingContext.current_frame(None) +<<<<<<< HEAD def run(self) -> None: +======= + def run(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with self.run_ctx_mgr(): dump_file(self.f_code.co_filename) try: self.output.push_tx(self) self.start_point = self.instruction_pointer +<<<<<<< HEAD try: while self.step(): pass @@ -1494,6 +1880,10 @@ def run(self) -> None: f"{type(e).__qualname__}: {str(e)}" ).with_traceback(e.__traceback__) from None raise +======= + while self.step(): + pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) except TensorifyScalarRestartAnalysis: raise except BackendCompilerFailed: @@ -1529,13 +1919,21 @@ def run(self) -> None: # twice is not an issue (second stop is a no op). self.output.mark_bytecode_tracing_stop() +<<<<<<< HEAD def push(self, val: Optional[VariableTracker]) -> None: +======= + def push(self, val: Optional[VariableTracker]): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert val is None or isinstance(val, VariableTracker), ( f"push expects VariableTracker, got {typestr(val)}" ) self.stack.append(val) # type: ignore[arg-type] +<<<<<<< HEAD def push_many(self, vals: list[VariableTracker]) -> None: +======= + def push_many(self, vals: list[VariableTracker]): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for val in vals: self.push(val) @@ -1545,7 +1943,11 @@ def pop(self) -> VariableTracker: def popn(self, n: int) -> list[VariableTracker]: return [*reversed([self.pop() for _ in range(n)])] +<<<<<<< HEAD def LOAD_FAST(self, inst: Instruction) -> None: +======= + def LOAD_FAST(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) name = inst.argval if self.exec_recorder and name in self.f_locals: self.exec_recorder.add_local_var(name, self.f_locals[name]) @@ -1580,7 +1982,11 @@ def LOAD_FAST(self, inst: Instruction) -> None: if name.startswith("__stack"): self.symbolic_locals.pop(name) +<<<<<<< HEAD def LOAD_DEREF(self, inst: Instruction) -> None: +======= + def LOAD_DEREF(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert inst.argval in self.cell_and_freevars() cell = self.symbolic_locals[inst.argval] contents_var = self.output.side_effects.load_cell(cell) @@ -1589,11 +1995,16 @@ def LOAD_DEREF(self, inst: Instruction) -> None: if self.exec_recorder and inst.argval in self.f_locals: self.exec_recorder.add_local_var(inst.argval, self.f_locals[inst.argval]) +<<<<<<< HEAD def STORE_FAST(self, inst: Instruction) -> None: +======= + def STORE_FAST(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) name = inst.argval loaded_vt = self.pop() loaded_vt.set_name_hint(name) self.symbolic_locals[name] = loaded_vt +<<<<<<< HEAD if name == IS_TRACING_RESUME_PROLOGUE_VARNAME: val = loaded_vt.as_python_constant() assert type(val) is bool @@ -1603,6 +2014,13 @@ def DELETE_FAST(self, inst: Instruction) -> None: del self.symbolic_locals[inst.argval] def STORE_DEREF(self, inst: Instruction) -> None: # type: ignore[override] +======= + + def DELETE_FAST(self, inst): + del self.symbolic_locals[inst.argval] + + def STORE_DEREF(self, inst): # type: ignore[override] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert inst.argval in self.cell_and_freevars() cell = self.symbolic_locals[inst.argval] val = self.pop() @@ -1614,6 +2032,7 @@ def STORE_DEREF(self, inst: Instruction) -> None: # type: ignore[override] LOAD_CLOSURE = LOAD_FAST +<<<<<<< HEAD def _load_const(self, inst: Instruction) -> ConstantVariable: i = inst.arg if i is None: @@ -1629,6 +2048,21 @@ def LOAD_CONST(self, inst: Instruction) -> None: self.push(self._load_const(inst)) def _load_global(self, inst: Instruction) -> None: +======= + def _load_const(self, inst): + i = inst.arg + if i is None: + return ConstantVariable.create(value=inst.argval) + val = self._constants_cache[i] + if not val: + self._constants_cache[i] = val = ConstantVariable.create(value=inst.argval) + return val + + def LOAD_CONST(self, inst): + self.push(self._load_const(inst)) + + def _load_global(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) name = inst.argval if self.exec_recorder: @@ -1650,21 +2084,33 @@ def _load_global(self, inst: Instruction) -> None: self.push(VariableTracker.build(self, value, GlobalSource(name))) @functools.cached_property +<<<<<<< HEAD def nn_modules_globals_vt(self) -> VariableTracker: +======= + def nn_modules_globals_vt(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) module_name = "torch.nn.modules.module" module_source = self.import_source(module_name) fglobals_value = _import_module(module_name) return VariableTracker.build(self, fglobals_value, module_source) +<<<<<<< HEAD def LOAD_GLOBAL(self, inst: Instruction) -> None: assert inst.arg is not None +======= + def LOAD_GLOBAL(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if sys.version_info >= (3, 11) and sys.version_info < (3, 13) and inst.arg % 2: self.PUSH_NULL(inst) self._load_global(inst) if sys.version_info >= (3, 13) and inst.arg % 2: self.PUSH_NULL(inst) +<<<<<<< HEAD def STORE_GLOBAL(self, inst: Instruction) -> None: +======= + def STORE_GLOBAL(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) value = self.pop() name = inst.argval source = GlobalSource(name) @@ -1685,7 +2131,11 @@ def STORE_GLOBAL(self, inst: Instruction) -> None: # Cache note: This cache only exists for the duration of this # InstructionTranslator - so it should be safe to do. @cache_method +<<<<<<< HEAD def import_source(self, module_name: str) -> GlobalSource: +======= + def import_source(self, module_name): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Create an alias to a module for use in guards""" if "torch_package" in module_name: value = torch.package.package_importer._package_imported_modules[ @@ -1700,14 +2150,21 @@ def import_source(self, module_name: str) -> GlobalSource: if self.package is not None: self.package.add_import_source(alias, module_name) +<<<<<<< HEAD self.output.import_sources[alias] = module_name +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f_globals = self.output.global_scope assert alias not in f_globals or f_globals[alias] is value f_globals[alias] = value self.output.update_co_names(alias) return GlobalSource(alias) +<<<<<<< HEAD def resolve_name(self, name: str, package: str, level: int) -> str: +======= + def resolve_name(self, name, package, level): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Copied from the Cpython implementation of __import__ Resolve a relative module name to an absolute one. @@ -1719,7 +2176,11 @@ def resolve_name(self, name: str, package: str, level: int) -> str: base = bits[0] return f"{base}.{name}" if name else base +<<<<<<< HEAD def calc_package(self) -> str: +======= + def calc_package(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Copied from the Cpython implementation of __import__ https://github.com/python/cpython/blob/5a094f0255eea1db58fb2cf14c200971e64ec36e/Lib/importlib/_bootstrap.py#L1090 @@ -1748,7 +2209,11 @@ def calc_package(self) -> str: package = package.rpartition(".")[0] return package +<<<<<<< HEAD def IMPORT_NAME(self, inst: Instruction) -> None: +======= + def IMPORT_NAME(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) level, fromlist = self.popn(2) level = level.as_python_constant() fromlist = fromlist.as_python_constant() @@ -1808,14 +2273,22 @@ def IMPORT_NAME(self, inst: Instruction) -> None: # fb internal 3.12 opcode EAGER_IMPORT_NAME = IMPORT_NAME +<<<<<<< HEAD def IMPORT_FROM(self, inst: Instruction) -> None: +======= + def IMPORT_FROM(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.DUP_TOP(inst) self._load_attr(inst) # Cache note: This cache only exists for the duration of this # InstructionTranslator - so it should be safe to do. @cache_method +<<<<<<< HEAD def load_builtin_from_argval(self, argval: Any) -> VariableTracker: +======= + def load_builtin_from_argval(self, argval): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if argval not in self.f_builtins: raise Unsupported(f"name '{argval}' is not defined") val = self.f_builtins[argval] @@ -1830,6 +2303,7 @@ def load_builtin_from_argval(self, argval: Any) -> VariableTracker: assert is_builtin_constant(val) return ConstantVariable.create(value=val) +<<<<<<< HEAD def load_builtin(self, inst: Instruction) -> None: self.push(self.load_builtin_from_argval(inst.argval)) @@ -1837,6 +2311,14 @@ def jump(self, inst: Instruction) -> None: assert self.instruction_pointer is not None assert self.start_point is not None assert inst.target is not None +======= + def load_builtin(self, inst): + self.push(self.load_builtin_from_argval(inst.argval)) + + def jump(self, inst): + assert self.instruction_pointer is not None + assert self.start_point is not None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) get_metrics_context().increment( "ir_count", self.instruction_pointer - self.start_point ) @@ -1851,6 +2333,7 @@ def jump(self, inst: Instruction) -> None: JUMP_IF_FALSE_OR_POP = generic_jump(operator.not_, True) JUMP_IF_TRUE_OR_POP = generic_jump(operator.truth, True) +<<<<<<< HEAD def SETUP_LOOP(self, inst: Instruction) -> None: # only exists in python<=3.7 assert inst.target is not None @@ -1875,16 +2358,47 @@ def BEGIN_FINALLY(self, inst: Instruction) -> None: self.push(None) def WITH_CLEANUP_START(self, inst: Instruction) -> None: +======= + def SETUP_LOOP(self, inst): + # only exists in python<=3.7 + self.block_stack.append(BlockStackEntry(inst, inst.target, len(self.stack))) + + def SETUP_EXCEPT(self, inst): + # only exists in python<=3.7 + self.block_stack.append(BlockStackEntry(inst, inst.target, len(self.stack))) + + def POP_BLOCK(self, inst): + self.block_stack.pop() + + def SETUP_WITH(self, inst): + self.setup_or_before_with(inst) + + def SETUP_FINALLY(self, inst): + self.block_stack.append(BlockStackEntry(inst, inst.target, len(self.stack))) + + def BEGIN_FINALLY(self, inst): + self.push(None) + + def WITH_CLEANUP_START(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) exit, exc = self.popn(2) assert exc is None self.push(exc) self.push(exit.call_function(self, [ConstantVariable.create(None)] * 3, {})) +<<<<<<< HEAD def WITH_CLEANUP_FINISH(self, inst: Instruction) -> None: self.popn(2) self.push(None) def FOR_ITER(self, inst: Instruction) -> None: +======= + def WITH_CLEANUP_FINISH(self, inst): + self.popn(2) + self.push(None) + + def FOR_ITER(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) it = self.pop().realize() try: val = it.next_variable(self) @@ -1904,7 +2418,11 @@ def FOR_ITER(self, inst: Instruction) -> None: self.push(ConstantVariable.create(None)) self.jump(inst) +<<<<<<< HEAD def _create_exception_type(self, val: VariableTracker) -> VariableTracker: +======= + def _create_exception_type(self, val): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance( val, (variables.BuiltinVariable, UserDefinedExceptionClassVariable) ): @@ -1913,7 +2431,11 @@ def _create_exception_type(self, val: VariableTracker) -> VariableTracker: val = val.call_function(self, [], {}) # type: ignore[arg-type] return val +<<<<<<< HEAD def _raise_exception_variable(self, val: VariableTracker) -> NoReturn: +======= + def _raise_exception_variable(self, val) -> NoReturn: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # User can raise exception in 2 ways # 1) raise exception type - raise NotImplementedError # 2) raise exception instance - raise NotImplemetedError("foo") @@ -1931,11 +2453,19 @@ def _raise_exception_variable(self, val: VariableTracker) -> NoReturn: val = variables.BuiltinVariable(RuntimeError).call_function(self, [], {}) # type: ignore[arg-type] # Save the exception in a global data structure +<<<<<<< HEAD self.exn_vt_stack.set_current_exception(val) # type: ignore[arg-type] # 2) when user raises exception instance if self._isinstance_exception(val): observed_exception_type = exc.get_dynamo_observed_exception(val.exc_type) # type: ignore[attr-defined, union-attr] +======= + self.exn_vt_stack.set_current_exception(val) + + # 2) when user raises exception instance + if self._isinstance_exception(val): + observed_exception_type = exc.get_dynamo_observed_exception(val.exc_type) # type: ignore[attr-defined] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise observed_exception_type(f"raised exception {val}") unimplemented_v2( gb_type="Failed to raise exception", @@ -1944,7 +2474,11 @@ def _raise_exception_variable(self, val: VariableTracker) -> NoReturn: hints=[*graph_break_hints.USER_ERROR], ) +<<<<<<< HEAD def RAISE_VARARGS(self, inst: Instruction) -> None: +======= + def RAISE_VARARGS(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if inst.arg == 0: if not len(self.exn_vt_stack): msg = ConstantVariable("No active exception to reraise") @@ -1958,21 +2492,35 @@ def RAISE_VARARGS(self, inst: Instruction) -> None: self._raise_exception_variable(val) elif inst.arg == 1: # raise TOS +<<<<<<< HEAD val = self.stack[-1] # type: ignore[assignment] +======= + val = self.stack[-1] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._raise_exception_variable(val) else: # raise .. from ... from_vt = self.pop() +<<<<<<< HEAD val = self.pop() # type: ignore[assignment] +======= + val = self.pop() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: self._raise_exception_variable(val) finally: # Update __cause__/__supppress_context__ in the raised exception curr_exc = self.exn_vt_stack.get_current_exception() cause = self._create_exception_type(from_vt) +<<<<<<< HEAD curr_exc.call_setattr(self, ConstantVariable("__cause__"), cause) # type: ignore[arg-type, union-attr, assignment] def CLEANUP_THROW(self, inst: Instruction) -> None: +======= + curr_exc.call_setattr(self, ConstantVariable("__cause__"), cause) + + def CLEANUP_THROW(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # https://github.com/python/cpython/pull/96010 tos = self.stack[-1] assert isinstance(tos, ExceptionVariable) @@ -1986,7 +2534,11 @@ def CLEANUP_THROW(self, inst: Instruction) -> None: else: self.RERAISE(inst) +<<<<<<< HEAD def RERAISE(self, inst: Instruction) -> None: +======= + def RERAISE(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # https://docs.python.org/3/library/dis.html#opcode-RERAISE # Re-raises the exception currently on top of the stack. If oparg is # non-zero, pops an additional value from the stack which is used to @@ -2009,7 +2561,11 @@ def RERAISE(self, inst: Instruction) -> None: _tb = self.pop() self._raise_exception_variable(val) +<<<<<<< HEAD def _isinstance_exception(self, val: VariableTracker) -> TypeIs[ExceptionVals]: +======= + def _isinstance_exception(self, val): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return isinstance( val, ( @@ -2019,7 +2575,11 @@ def _isinstance_exception(self, val: VariableTracker) -> TypeIs[ExceptionVals]: ), ) +<<<<<<< HEAD def WITH_EXCEPT_START(self, inst: Instruction) -> None: +======= + def WITH_EXCEPT_START(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if sys.version_info >= (3, 11): # At the top of the stack are 4 values: # - TOP = exc_info() @@ -2032,7 +2592,11 @@ def WITH_EXCEPT_START(self, inst: Instruction) -> None: fn = self.stack[-4] val = self.stack[-1] assert self._isinstance_exception(val) +<<<<<<< HEAD typ = BuiltinVariable(val.exc_type) # type: ignore[attr-defined, union-attr] +======= + typ = BuiltinVariable(val.exc_type) # type: ignore[attr-defined] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tb = ConstantVariable(None) else: assert len(self.stack) >= 7 @@ -2044,12 +2608,17 @@ def WITH_EXCEPT_START(self, inst: Instruction) -> None: self.call_function(fn, [typ, val, tb], {}) +<<<<<<< HEAD def exception_handler(self, raised_exception: ObservedException) -> None: +======= + def exception_handler(self, raised_exception): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) observed_exn_gb_explanation = ( "Dynamo found no exception handler at the top-level compiled function " "when encountering an exception. Exception will propagate outside the compiled region." ) +<<<<<<< HEAD def bubble_exception_to_interpreter() -> None: # Bubble the exception to the interpreter curr_exc = self.exn_vt_stack.get_current_exception() @@ -2065,6 +2634,8 @@ def bubble_exception_to_interpreter() -> None: ], ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if sys.version_info >= (3, 11): exn_tab_entry = self.current_instruction.exn_tab_entry if exn_tab_entry: @@ -2085,13 +2656,29 @@ def bubble_exception_to_interpreter() -> None: self.push(self.exn_vt_stack.get_current_exception()) # 4) jump to the handler +<<<<<<< HEAD self.jump(exn_tab_entry) # type: ignore[arg-type] +======= + self.jump(exn_tab_entry) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: # No handler found. Bubble the exception to the parent # instruction translator. We use special exception for this. self.stack.clear() if type(self) is InstructionTranslator: +<<<<<<< HEAD bubble_exception_to_interpreter() +======= + unimplemented_v2( + gb_type="Observed exception", + context=str(raised_exception), + explanation=observed_exn_gb_explanation, + hints=[ + *graph_break_hints.USER_ERROR, + *graph_break_hints.SUPPORTABLE, + ], + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise raised_exception else: if len(self.block_stack): @@ -2163,10 +2750,25 @@ def bubble_exception_to_interpreter() -> None: # instruction translator. We use special exception for this. self.stack.clear() if type(self) is InstructionTranslator: +<<<<<<< HEAD bubble_exception_to_interpreter() raise raised_exception def PUSH_EXC_INFO(self, inst: Instruction) -> None: +======= + unimplemented_v2( + gb_type="Observed exception", + context=str(raised_exception), + explanation=observed_exn_gb_explanation, + hints=[ + *graph_break_hints.USER_ERROR, + *graph_break_hints.SUPPORTABLE, + ], + ) + raise raised_exception + + def PUSH_EXC_INFO(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # https://docs.python.org/3/library/dis.html#opcode-PUSH_EXC_INFO # Pops a value from the stack. Pushes the current exception to the top # of the stack. Pushes the value originally popped back to the stack. @@ -2188,14 +2790,22 @@ def PUSH_EXC_INFO(self, inst: Instruction) -> None: val = self.pop() if len(self.exn_vt_stack) == 0: +<<<<<<< HEAD prev_exc: VariableTracker = ConstantVariable(None) +======= + prev_exc = ConstantVariable(None) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: prev_exc = self.exn_vt_stack[-1] self.push(prev_exc) self.push(val) self.exn_vt_stack.move_current_exception_to_stack() +<<<<<<< HEAD def POP_EXCEPT(self, inst: Instruction) -> None: +======= + def POP_EXCEPT(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if sys.version_info >= (3, 11): _ = self.pop() # This exception is handled and therefore we can clear the error indicator @@ -2216,7 +2826,11 @@ def POP_EXCEPT(self, inst: Instruction) -> None: assert len(self.exn_vt_stack) self.exn_vt_stack.pop() +<<<<<<< HEAD def check_if_exc_matches(self) -> bool: +======= + def check_if_exc_matches(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert len(self.stack) >= 2 expected_exc_types = self.pop() if sys.version_info >= (3, 11): @@ -2286,7 +2900,11 @@ def check_if_exc_matches(self) -> bool: hints=[*graph_break_hints.USER_ERROR], ) if self._isinstance_exception(exc_instance) and issubclass( +<<<<<<< HEAD exc_instance.exc_type, # type: ignore[union-attr] +======= + exc_instance.exc_type, # type: ignore[attr-defined] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) expected_type.fn, # type: ignore[attr-defined] ): return True @@ -2297,6 +2915,7 @@ def check_if_exc_matches(self) -> bool: return False +<<<<<<< HEAD def CHECK_EXC_MATCH(self, inst: Instruction) -> None: self.push(variables.ConstantVariable(self.check_if_exc_matches())) @@ -2305,22 +2924,44 @@ def JUMP_IF_NOT_EXC_MATCH(self, inst: Instruction) -> None: self.jump(inst) def COMPARE_OP(self, inst: Instruction) -> None: +======= + def CHECK_EXC_MATCH(self, inst): + self.push(variables.ConstantVariable(self.check_if_exc_matches())) + + def JUMP_IF_NOT_EXC_MATCH(self, inst): + if not self.check_if_exc_matches(): + self.jump(inst) + + def COMPARE_OP(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if inst.argval == "exception match": self.CHECK_EXC_MATCH(inst) else: self.push(compare_op_handlers[inst.argval](self, self.popn(2), {})) +<<<<<<< HEAD def GET_ITER(self, inst: Instruction) -> None: self.call_function(BuiltinVariable(iter), [self.pop()], {}) @break_graph_if_unsupported(push=1) def CALL_FUNCTION(self, inst: Instruction) -> None: +======= + def GET_ITER(self, inst): + self.call_function(BuiltinVariable(iter), [self.pop()], {}) + + @break_graph_if_unsupported(push=1) + def CALL_FUNCTION(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) args = self.popn(inst.argval) fn = self.pop() self.call_function(fn, args, {}) @break_graph_if_unsupported(push=1) +<<<<<<< HEAD def CALL_FUNCTION_EX(self, inst: Instruction) -> None: +======= + def CALL_FUNCTION_EX(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kwargsvars: VariableTracker if inst.argval == 0: kwargsvars = ConstDictVariable({}) @@ -2371,7 +3012,11 @@ def CALL_FUNCTION_EX(self, inst: Instruction) -> None: self.call_function(fn, argsvars.items, kwargsvars) @break_graph_if_unsupported(push=1) +<<<<<<< HEAD def CALL_FUNCTION_KW(self, inst: Instruction) -> None: +======= + def CALL_FUNCTION_KW(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) argnames = self.pop() args = self.popn(inst.argval) fn = self.pop() @@ -2382,7 +3027,11 @@ def CALL_FUNCTION_KW(self, inst: Instruction) -> None: assert len(kwargs) == len(argnames) self.call_function(fn, args, kwargs) +<<<<<<< HEAD def LOAD_METHOD_SUPER(self, inst: Instruction) -> None: +======= + def LOAD_METHOD_SUPER(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.CALL_FUNCTION(dataclasses.replace(inst, argval=2)) arg = inst.argval[0] argval = self.code_options["co_names"][arg] @@ -2391,13 +3040,21 @@ def LOAD_METHOD_SUPER(self, inst: Instruction) -> None: else: self.LOAD_METHOD(dataclasses.replace(inst, argval=argval)) +<<<<<<< HEAD def LOAD_ATTR_SUPER(self, inst: Instruction) -> None: +======= + def LOAD_ATTR_SUPER(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.CALL_FUNCTION(dataclasses.replace(inst, argval=2)) arg = inst.argval[0] argval = self.code_options["co_names"][arg] self._load_attr(dataclasses.replace(inst, argval=argval)) +<<<<<<< HEAD def LOAD_METHOD(self, inst: Instruction) -> None: +======= + def LOAD_METHOD(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._load_attr(inst) obj = self.pop() if sys.version_info >= (3, 13): @@ -2413,14 +3070,22 @@ def LOAD_METHOD(self, inst: Instruction) -> None: self.push(obj) self.push(None) +<<<<<<< HEAD def CALL_METHOD(self, inst: Instruction) -> None: +======= + def CALL_METHOD(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) args = self.popn(inst.argval) dummy = self.pop() assert dummy is None fn = self.pop() self.call_function(fn, args, {}) +<<<<<<< HEAD def _load_attr(self, inst: Instruction) -> None: +======= + def _load_attr(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) obj = self.pop() result = BuiltinVariable(getattr).call_function( self, # type: ignore[arg-type] @@ -2429,16 +3094,26 @@ def _load_attr(self, inst: Instruction) -> None: ) self.push(result) +<<<<<<< HEAD def LOAD_ATTR(self, inst: Instruction) -> None: +======= + def LOAD_ATTR(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if sys.version_info >= (3, 12): if inst.arg % 2: self.LOAD_METHOD(inst) return self._load_attr(inst) +<<<<<<< HEAD def STORE_ATTR(self, inst: Instruction) -> None: speculation = self.speculate() if speculation.failed(self): +======= + def STORE_ATTR(self, inst): + speculation = self.speculate() + if speculation.failed: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.store_attr_graph_break(inst) val, obj = self.popn(2) @@ -2462,9 +3137,15 @@ def STORE_ATTR(self, inst: Instruction) -> None: log.debug("STORE_ATTR triggered compile", exc_info=True) e.remove_from_stats() e.add_to_stats("graph_break") +<<<<<<< HEAD speculation.fail_and_restart_analysis(self.error_on_graph_break) def store_attr_graph_break(self, inst: Instruction) -> None: +======= + speculation.fail_and_restart_analysis() + + def store_attr_graph_break(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) log_graph_break(self.code_options, reason="STORE_ATTR-caused graph break") if not self.should_compile_partial_graph(): unimplemented_v2( @@ -2482,12 +3163,19 @@ def store_attr_graph_break(self, inst: Instruction) -> None: self.output.add_output_instructions([copy.copy(inst)]) self.popn(2) self.output.add_output_instructions( +<<<<<<< HEAD self.create_call_resume_at( self.next_instruction, all_stack_locals_metadata, False ) ) def DELETE_ATTR(self, inst: Instruction) -> None: +======= + self.create_call_resume_at(self.next_instruction, all_stack_locals_metadata) + ) + + def DELETE_ATTR(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) obj = self.pop() BuiltinVariable(delattr).call_function( self, # type: ignore[arg-type] @@ -2495,6 +3183,7 @@ def DELETE_ATTR(self, inst: Instruction) -> None: {}, ) +<<<<<<< HEAD def create_call_resume_at( self, inst: Instruction, @@ -2840,6 +3529,40 @@ def BUILD_LIST(self, inst: Instruction) -> None: self.push(ListVariable(items, mutation_type=ValueMutationNew())) def BUILD_SET(self, inst: Instruction) -> None: +======= + def create_call_resume_at(self, offset, all_stack_locals_metadata): + raise AssertionError( + f"create_call_resume_at not overridden by subclass {type(self)}" + ) + + def should_compile_partial_graph(self) -> bool: + raise AssertionError( + f"should_compile_partial_graph not overridden by subclass {type(self)}" + ) + + @break_graph_if_unsupported(push=0) + def STORE_SUBSCR(self, inst): + val, obj, key = self.popn(3) + obj.call_method(self, "__setitem__", [key, val], {}) + + def DELETE_SUBSCR(self, inst): + obj, key = self.popn(2) + obj.call_method(self, "__delitem__", [key], {}) + + def BUILD_TUPLE(self, inst): + items = self.popn(inst.argval) + self.push(TupleVariable(items)) + + def BUILD_SLICE(self, inst): + items = self.popn(inst.argval) + self.push(SliceVariable(items)) + + def BUILD_LIST(self, inst): + items = self.popn(inst.argval) + self.push(ListVariable(items, mutation_type=ValueMutationNew())) + + def BUILD_SET(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if config.inject_BUILD_SET_unimplemented_TESTING_ONLY: unimplemented_v2( gb_type="missing BUILD_SET handler", @@ -2851,7 +3574,11 @@ def BUILD_SET(self, inst: Instruction) -> None: new_set = SetVariable(items, mutation_type=ValueMutationNew()) self.push(new_set) +<<<<<<< HEAD def BUILD_LIST_UNPACK(self, inst: Instruction, cls: type = ListVariable) -> None: +======= + def BUILD_LIST_UNPACK(self, inst, cls=ListVariable): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) seqs = self.popn(inst.argval) items = [] for seq in seqs: @@ -2867,21 +3594,37 @@ def BUILD_LIST_UNPACK(self, inst: Instruction, cls: type = ListVariable) -> None ) self.push(cls(items, mutation_type=ValueMutationNew())) +<<<<<<< HEAD def BUILD_TUPLE_UNPACK(self, inst: Instruction) -> None: +======= + def BUILD_TUPLE_UNPACK(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.BUILD_LIST_UNPACK(inst, cls=TupleVariable) BUILD_TUPLE_UNPACK_WITH_CALL = BUILD_TUPLE_UNPACK +<<<<<<< HEAD def BUILD_MAP(self, inst: Instruction) -> None: +======= + def BUILD_MAP(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) items = self.popn(inst.argval * 2) d = dict(zip(items[::2], items[1::2])) self.push(ConstDictVariable(d, mutation_type=ValueMutationNew())) +<<<<<<< HEAD def BUILD_MAP_UNPACK(self, inst: Instruction) -> None: items = self.popn(inst.argval) # ensure everything is a dict items = [BuiltinVariable(dict).call_function(self, [x], {}) for x in items] # type: ignore[arg-type] result: dict[Any, Any] = {} +======= + def BUILD_MAP_UNPACK(self, inst): + items = self.popn(inst.argval) + # ensure everything is a dict + items = [BuiltinVariable(dict).call_function(self, [x], {}) for x in items] # type: ignore[arg-type] + result = {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for x in items: assert isinstance(x, ConstDictVariable) result.update(x.items) @@ -2894,7 +3637,11 @@ def BUILD_MAP_UNPACK(self, inst: Instruction) -> None: BUILD_MAP_UNPACK_WITH_CALL = BUILD_MAP_UNPACK +<<<<<<< HEAD def BUILD_CONST_KEY_MAP(self, inst: Instruction) -> None: +======= + def BUILD_CONST_KEY_MAP(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) keys = self.pop() values = self.popn(inst.argval) assert isinstance(keys, TupleVariable) @@ -2910,14 +3657,21 @@ def BUILD_CONST_KEY_MAP(self, inst: Instruction) -> None: ) ) +<<<<<<< HEAD def MAP_ADD(self, inst: Instruction) -> None: k, v = self.popn(2) assert inst.argval > 0 assert inst.arg is not None +======= + def MAP_ADD(self, inst): + k, v = self.popn(2) + assert inst.argval > 0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) obj = self.stack[-inst.arg].realize() assert isinstance(obj, ConstDictVariable) obj.call_method(self, "__setitem__", (k, v), {}) # type: ignore[arg-type] +<<<<<<< HEAD def SET_ADD(self, inst: Instruction) -> None: v = self.pop() assert inst.argval > 0 @@ -2931,22 +3685,45 @@ def SET_UPDATE(self, inst: Instruction) -> None: v = self.pop() assert inst.argval > 0 assert inst.arg is not None +======= + def SET_ADD(self, inst): + v = self.pop() + assert inst.argval > 0 + obj = self.stack[-inst.arg] + assert isinstance(obj, SetVariable) + assert obj.is_mutable() + return obj.call_method(self, "add", [v], {}) + + def SET_UPDATE(self, inst): + v = self.pop() + assert inst.argval > 0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) obj = self.stack[-inst.arg] assert isinstance(obj, SetVariable) assert obj.is_mutable() obj.call_method(self, "update", [v], {}) +<<<<<<< HEAD def LIST_APPEND(self, inst: Instruction) -> None: v = self.pop() assert inst.argval > 0 assert inst.arg is not None +======= + def LIST_APPEND(self, inst): + v = self.pop() + assert inst.argval > 0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) obj = self.stack[-inst.arg].realize() assert isinstance(obj, ListVariable) assert obj.is_mutable() self.output.side_effects.mutation(obj) obj.items.append(v) +<<<<<<< HEAD def MAKE_FUNCTION(self, inst: Instruction) -> None: +======= + def MAKE_FUNCTION(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) flags = inst.arg if sys.version_info < (3, 11): fn_name = self.pop() @@ -2963,6 +3740,7 @@ def MAKE_FUNCTION(self, inst: Instruction) -> None: if sys.version_info < (3, 13): # in 3.13, this is handled in SET_FUNCTION_ATTRIBUTE +<<<<<<< HEAD if flags is not None: if flags & 0x08: closure = self.pop() @@ -2972,6 +3750,16 @@ def MAKE_FUNCTION(self, inst: Instruction) -> None: kwdefaults = self.pop() if flags & 0x01: defaults = self.pop() +======= + if flags & 0x08: + closure = self.pop() + if flags & 0x04: + annotations = self.pop() + if flags & 0x02: + kwdefaults = self.pop() + if flags & 0x01: + defaults = self.pop() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.push( NestedUserFunctionVariable( @@ -2985,7 +3773,11 @@ def MAKE_FUNCTION(self, inst: Instruction) -> None: ) ) +<<<<<<< HEAD def UNPACK_SEQUENCE(self, inst: Instruction) -> None: +======= + def UNPACK_SEQUENCE(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) seq = self.pop() if isinstance(seq, TensorVariable): val = seq.unpack_var_sequence(self, idxes=range(inst.argval)) # type: ignore[arg-type] @@ -3014,7 +3806,11 @@ def UNPACK_SEQUENCE(self, inst: Instruction) -> None: for i in reversed(val): self.push(i) +<<<<<<< HEAD def UNPACK_EX(self, inst: Instruction) -> None: +======= + def UNPACK_EX(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert 0 <= inst.argval <= 0xFFFF prefix = inst.argval & 0xFF # low byte suffix = inst.argval >> 8 # high byte @@ -3038,6 +3834,7 @@ def UNPACK_EX(self, inst: Instruction) -> None: hints=[*graph_break_hints.USER_ERROR], ) +<<<<<<< HEAD @break_graph_if_unsupported(push=0) def graph_break_on_leaf_function(self, inst: Instruction) -> None: if self.is_leaf_tracer: @@ -3059,12 +3856,25 @@ def POP_TOP(self, inst: Instruction) -> None: self.pop() def ROT_TWO(self, inst: Instruction) -> None: +======= + def NOP(self, inst): + pass + + def POP_TOP(self, inst): + self.pop() + + def ROT_TWO(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) a = self.pop() b = self.pop() self.push(a) self.push(b) +<<<<<<< HEAD def ROT_THREE(self, inst: Instruction) -> None: +======= + def ROT_THREE(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) a = self.pop() b = self.pop() c = self.pop() @@ -3072,7 +3882,11 @@ def ROT_THREE(self, inst: Instruction) -> None: self.push(c) self.push(b) +<<<<<<< HEAD def ROT_FOUR(self, inst: Instruction) -> None: +======= + def ROT_FOUR(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) a = self.pop() b = self.pop() c = self.pop() @@ -3082,12 +3896,20 @@ def ROT_FOUR(self, inst: Instruction) -> None: self.push(c) self.push(b) +<<<<<<< HEAD def DUP_TOP(self, inst: Instruction) -> None: +======= + def DUP_TOP(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) a = self.pop() self.push(a) self.push(a) +<<<<<<< HEAD def DUP_TOP_TWO(self, inst: Instruction) -> None: +======= + def DUP_TOP_TWO(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) a = self.pop() b = self.pop() self.push(b) @@ -3095,7 +3917,11 @@ def DUP_TOP_TWO(self, inst: Instruction) -> None: self.push(b) self.push(a) +<<<<<<< HEAD def _convert_value(self, value: VariableTracker, flag: int) -> VariableTracker: +======= + def _convert_value(self, value, flag): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if flag == 1: return BuiltinVariable(str).call_function(self, [value], {}) # type: ignore[arg-type] elif flag == 2: @@ -3104,7 +3930,11 @@ def _convert_value(self, value: VariableTracker, flag: int) -> VariableTracker: return BuiltinVariable(ascii).call_function(self, [value], {}) # type: ignore[arg-type] return value +<<<<<<< HEAD def _format_value(self, fmt_spec: VariableTracker, flags: int) -> None: +======= + def _format_value(self, fmt_spec, flags): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) value = self.pop() if isinstance(value, SymNodeVariable): from torch._dynamo.variables.lazy import ( @@ -3124,9 +3954,14 @@ def _format_value(self, fmt_spec: VariableTracker, flags: int) -> None: self.call_function(BuiltinVariable(str.format), [fmt_var, value], {}) +<<<<<<< HEAD def FORMAT_VALUE(self, inst: Instruction) -> None: flags = inst.arg assert flags is not None +======= + def FORMAT_VALUE(self, inst): + flags = inst.arg +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (flags & 0x04) == 0x04: fmt_spec = self.pop() else: @@ -3134,11 +3969,18 @@ def FORMAT_VALUE(self, inst: Instruction) -> None: return self._format_value(fmt_spec, flags) +<<<<<<< HEAD def BUILD_STRING(self, inst: Instruction) -> None: format_string_parts: list[str] = [] args: list[VariableTracker] = [] kwargs: dict[str, VariableTracker] = {} assert inst.arg is not None +======= + def BUILD_STRING(self, inst): + format_string_parts: list[str] = [] + args: list[VariableTracker] = [] + kwargs: dict[str, VariableTracker] = {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for part in self.popn(inst.arg): if isinstance(part, ConstantVariable): format_string_parts.append("{}") @@ -3167,7 +4009,11 @@ def BUILD_STRING(self, inst: Instruction) -> None: ) ) +<<<<<<< HEAD def IS_OP(self, inst: Instruction) -> None: +======= + def IS_OP(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert inst.argval == 0 or inst.argval == 1 if inst.argval == 0: new_argval = "is" @@ -3176,6 +4022,7 @@ def IS_OP(self, inst: Instruction) -> None: new_inst = create_instruction("COMPARE_OP", argval=new_argval) self.COMPARE_OP(new_inst) +<<<<<<< HEAD def CONTAINS_OP(self, inst: Instruction) -> None: assert inst.argval == 0 or inst.argval == 1 left, right = self.popn(2) @@ -3206,15 +4053,35 @@ def LIST_EXTEND(self, inst: Instruction) -> None: v = self.pop() assert inst.argval > 0 assert inst.arg is not None +======= + def CONTAINS_OP(self, inst): + assert inst.argval == 0 or inst.argval == 1 + left, right = self.popn(2) + op = inst.argval + self.push(right.call_method(self, "__contains__", [left], {})) + if op == 1: + self.UNARY_NOT(inst) + + def LIST_EXTEND(self, inst): + v = self.pop() + assert inst.argval > 0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) obj = self.stack[-inst.arg] assert isinstance(obj, ListVariable) assert obj.is_mutable() obj.call_method(self, "extend", [v], {}) +<<<<<<< HEAD def LIST_TO_TUPLE(self, inst: Instruction) -> None: self.push(BuiltinVariable(tuple).call_function(self, [self.pop()], {})) # type: ignore[arg-type] def STOPITERATION_ERROR(self, inst: Instruction) -> None: +======= + def LIST_TO_TUPLE(self, inst): + self.push(BuiltinVariable(tuple).call_function(self, [self.pop()], {})) # type: ignore[arg-type] + + def STOPITERATION_ERROR(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # wrap the generator body in a try: ... except StopIteration: ... which # converts the StopIteration into a RuntimeError # https://peps.python.org/pep-0479/ @@ -3222,7 +4089,11 @@ def STOPITERATION_ERROR(self, inst: Instruction) -> None: # https://github.com/python/cpython/commit/28187141cc34063ef857976ddbca87ba09a882c2 val = self.stack[-1] assert self._isinstance_exception(val) +<<<<<<< HEAD if val.exc_type is StopIteration: # type: ignore[union-attr] +======= + if val.exc_type is StopIteration: # type: ignore[attr-defined] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) new_val = variables.BuiltinVariable(RuntimeError).call_function( self, # type: ignore[arg-type] [ConstantVariable("generator raised StopIteration")], @@ -3232,10 +4103,16 @@ def STOPITERATION_ERROR(self, inst: Instruction) -> None: new_val.call_setattr(self, ConstantVariable("__cause__"), val) # type: ignore[attr-defined] self.stack[-1] = new_val +<<<<<<< HEAD def DICT_MERGE(self, inst: Instruction) -> None: v = self.pop() assert inst.argval > 0 assert inst.arg is not None +======= + def DICT_MERGE(self, inst): + v = self.pop() + assert inst.argval > 0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) obj = self.stack[-inst.arg].realize() assert isinstance(obj, ConstDictVariable) assert obj.is_mutable() @@ -3243,17 +4120,28 @@ def DICT_MERGE(self, inst: Instruction) -> None: DICT_UPDATE = DICT_MERGE +<<<<<<< HEAD def GEN_START(self, inst: Instruction) -> None: self.pop() def GET_LEN(self, inst: Instruction) -> None: +======= + def GEN_START(self, inst): + self.pop() + + def GET_LEN(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tos = self.stack[-1] if tos.is_python_constant(): self.push(ConstantVariable.create(len(tos.as_python_constant()))) else: self.push(tos.call_method(self, "__len__", [], {})) +<<<<<<< HEAD def MATCH_MAPPING(self, inst: Instruction) -> None: +======= + def MATCH_MAPPING(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tos = self.stack[-1] assert isinstance(tos, ConstDictVariable) if isinstance(tos.items, collections.abc.Mapping): @@ -3261,7 +4149,11 @@ def MATCH_MAPPING(self, inst: Instruction) -> None: else: self.push(ConstantVariable.create(False)) +<<<<<<< HEAD def MATCH_SEQUENCE(self, inst: Instruction) -> None: +======= + def MATCH_SEQUENCE(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tos = self.stack[-1] assert tos.is_python_constant() tos_value = tos.as_python_constant() @@ -3272,7 +4164,11 @@ def MATCH_SEQUENCE(self, inst: Instruction) -> None: else: self.push(ConstantVariable.create(False)) +<<<<<<< HEAD def MATCH_KEYS(self, inst: Instruction) -> None: +======= + def MATCH_KEYS(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tos = self.stack[-1] tos1 = self.stack[-2] assert isinstance(tos1, ConstDictVariable) @@ -3286,10 +4182,17 @@ def MATCH_KEYS(self, inst: Instruction) -> None: if sys.version_info < (3, 11): self.push(ConstantVariable.create(False)) +<<<<<<< HEAD def LOAD_ASSERTION_ERROR(self, inst: Instruction) -> None: self.push(self.load_builtin_from_argval("AssertionError")) def LOAD_BUILD_CLASS(self, inst: Instruction) -> None: +======= + def LOAD_ASSERTION_ERROR(self, inst): + self.push(self.load_builtin_from_argval("AssertionError")) + + def LOAD_BUILD_CLASS(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) unimplemented_v2( gb_type="LOAD_BUILD_CLASS bytecode not supported", context="", @@ -3337,7 +4240,11 @@ def LOAD_BUILD_CLASS(self, inst: Instruction) -> None: INPLACE_OR = stack_op(operator.ior) # 3.11 opcodes +<<<<<<< HEAD def RESUME(self, inst: Instruction) -> None: +======= + def RESUME(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if inst.arg == 0: self.append_prefix_inst(inst) self.accept_prefix_inst = False @@ -3346,6 +4253,7 @@ def RESUME(self, inst: Instruction) -> None: if sys.version_info >= (3, 11): +<<<<<<< HEAD def BINARY_OP(self, inst: Instruction) -> None: assert inst.arg is not None return _binary_op_lookup[inst.arg](self, inst) @@ -3354,6 +4262,15 @@ def PRECALL(self, inst: Instruction) -> None: pass def KW_NAMES(self, inst: Instruction) -> None: +======= + def BINARY_OP(self, inst): + return _binary_op_lookup[inst.arg](self, inst) + + def PRECALL(self, inst): + pass + + def KW_NAMES(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kw_names = self.code_options["co_consts"][inst.arg] assert isinstance(kw_names, tuple) for name in kw_names: @@ -3361,10 +4278,17 @@ def KW_NAMES(self, inst: Instruction) -> None: assert self.kw_names is None self.kw_names = ConstantVariable.create(value=kw_names) # type: ignore[assignment] +<<<<<<< HEAD def PUSH_NULL(self, inst: Instruction) -> None: self.push(NullVariable()) def _call(self, inst: Instruction, call_kw: bool = False) -> None: +======= + def PUSH_NULL(self, inst): + self.push(NullVariable()) + + def _call(self, inst, call_kw=False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # see https://docs.python.org/3.11/library/dis.html#opcode-CALL # for convention if call_kw: @@ -3376,7 +4300,10 @@ def _call(self, inst: Instruction, call_kw: bool = False) -> None: else: kw_names = self.kw_names.value if self.kw_names else () +<<<<<<< HEAD assert inst.arg is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) contents = self.popn(inst.arg + 2) if sys.version_info >= (3, 13): # NULL and callable swapped @@ -3407,6 +4334,7 @@ def _call(self, inst: Instruction, call_kw: bool = False) -> None: self.kw_names = None @break_graph_if_unsupported(push=1) +<<<<<<< HEAD def CALL(self, inst: Instruction) -> None: self._call(inst) @@ -3416,6 +4344,15 @@ def COPY(self, inst: Instruction) -> None: def SWAP(self, inst: Instruction) -> None: assert inst.arg is not None +======= + def CALL(self, inst): + self._call(inst) + + def COPY(self, inst): + self.push(self.stack[-inst.arg]) + + def SWAP(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.stack[-1], self.stack[-inst.arg] = self.stack[-inst.arg], self.stack[-1] JUMP_BACKWARD = jump @@ -3426,6 +4363,7 @@ def SWAP(self, inst: Instruction) -> None: POP_JUMP_FORWARD_IF_FALSE = generic_jump(operator.not_, False) POP_JUMP_BACKWARD_IF_FALSE = generic_jump(operator.not_, False) +<<<<<<< HEAD def CACHE(self, inst: Instruction) -> None: pass @@ -3433,6 +4371,15 @@ def BEFORE_WITH(self, inst: Instruction) -> None: self.setup_or_before_with(inst) def setup_or_before_with(self, inst: Instruction) -> None: +======= + def CACHE(self, inst): + pass + + def BEFORE_WITH(self, inst): + self.setup_or_before_with(inst) + + def setup_or_before_with(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ctx = self.pop() if not isinstance( ctx, (ContextWrappingVariable, GenericContextWrappingVariable) @@ -3479,7 +4426,10 @@ def setup_or_before_with(self, inst: Instruction) -> None: ): target = None else: +<<<<<<< HEAD assert self.next_instruction.exn_tab_entry is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) target = self.next_instruction.exn_tab_entry.target else: target = inst.target @@ -3487,7 +4437,11 @@ def setup_or_before_with(self, inst: Instruction) -> None: self.push(exit) if target: +<<<<<<< HEAD if isinstance(self, InstructionTranslator) or config.nested_graph_breaks: +======= + if isinstance(self, InstructionTranslator): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.block_stack.append( BlockStackEntry(inst, target, len(self.stack), ctx) ) @@ -3496,11 +4450,19 @@ def setup_or_before_with(self, inst: Instruction) -> None: self.push(ctx.enter(self)) +<<<<<<< HEAD def append_prefix_inst(self, inst: Instruction) -> None: assert self.accept_prefix_inst self.prefix_insts.append(inst) def MAKE_CELL(self, inst: Instruction) -> None: +======= + def append_prefix_inst(self, inst): + assert self.accept_prefix_inst + self.prefix_insts.append(inst) + + def MAKE_CELL(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if sys.version_info >= (3, 12) and not self.accept_prefix_inst: # In 3.12+, MAKE_CELL is not longer necessarily a prefix instruction. # It can be generated by inlined comprehensions. @@ -3511,24 +4473,40 @@ def MAKE_CELL(self, inst: Instruction) -> None: else: self.append_prefix_inst(inst) +<<<<<<< HEAD def COPY_FREE_VARS(self, inst: Instruction) -> None: self.append_prefix_inst(inst) def RETURN_GENERATOR(self, inst: Instruction) -> None: +======= + def COPY_FREE_VARS(self, inst): + self.append_prefix_inst(inst) + + def RETURN_GENERATOR(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.append_prefix_inst(inst) # 3.12 opcodes # BINARY/STORE_SLICE opcodes are broken down into # BUILD_SLICE 2 and BINARY/STORE_SUBSCR +<<<<<<< HEAD def END_FOR(self, inst: Instruction) -> None: +======= + def END_FOR(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if sys.version_info >= (3, 13): self.pop() else: self.popn(2) +<<<<<<< HEAD def LOAD_FAST_CHECK(self, inst: Instruction) -> None: if istype(self.symbolic_locals.get(inst.argval, None), NullVariable): +======= + def LOAD_FAST_CHECK(self, inst): + if isinstance(self.symbolic_locals.get(inst.argval, None), NullVariable): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) unimplemented_v2( gb_type="LOAD_FAST_CHECK on uninitialized variable", context=inst.argval, @@ -3537,22 +4515,35 @@ def LOAD_FAST_CHECK(self, inst: Instruction) -> None: ) self.LOAD_FAST(inst) +<<<<<<< HEAD def LOAD_FAST_AND_CLEAR(self, inst: Instruction) -> None: +======= + def LOAD_FAST_AND_CLEAR(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if inst.argval not in self.symbolic_locals: self.push(NullVariable()) else: self.LOAD_FAST(inst) self.symbolic_locals[inst.argval] = NullVariable() +<<<<<<< HEAD def LOAD_SUPER_ATTR(self, inst: Instruction) -> None: self.CALL_FUNCTION(dataclasses.replace(inst, argval=2)) assert inst.arg is not None +======= + def LOAD_SUPER_ATTR(self, inst): + self.CALL_FUNCTION(dataclasses.replace(inst, argval=2)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if inst.arg & 1: self.LOAD_METHOD(inst) else: self._load_attr(inst) +<<<<<<< HEAD def CALL_INTRINSIC_1(self, inst: Instruction) -> None: +======= + def CALL_INTRINSIC_1(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if inst.argval == 3: # INTRINSIC_STOPITERATION_ERROR self.STOPITERATION_ERROR(inst) @@ -3570,7 +4561,11 @@ def CALL_INTRINSIC_1(self, inst: Instruction) -> None: hints=[*graph_break_hints.SUPPORTABLE], ) +<<<<<<< HEAD def END_SEND(self, inst: Instruction) -> None: +======= + def END_SEND(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tos = self.pop() self.pop() self.push(tos) @@ -3579,10 +4574,17 @@ def END_SEND(self, inst: Instruction) -> None: # fused instructions LOAD_FAST_LOAD_FAST, STORE_FAST_STORE_FAST, STORE_FAST_LOAD_FAST # are broken down. @break_graph_if_unsupported(push=1) +<<<<<<< HEAD def CALL_KW(self, inst: Instruction) -> None: self._call(inst, call_kw=True) def TO_BOOL(self, inst: Instruction) -> None: +======= + def CALL_KW(self, inst): + self._call(inst, call_kw=True) + + def TO_BOOL(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TO_BOOL only precedes a conditional jump or UNARY_NOT (see compile.c in CPython) # So we can skip this instruction as long as we remember to codegen a TO_BOOL # before conditional jumps/UNARY_NOT. @@ -3592,9 +4594,14 @@ def TO_BOOL(self, inst: Instruction) -> None: "UNARY_NOT", ) +<<<<<<< HEAD def SET_FUNCTION_ATTRIBUTE(self, inst: Instruction) -> None: flags = inst.arg assert flags is not None +======= + def SET_FUNCTION_ATTRIBUTE(self, inst): + flags = inst.arg +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fn = self.pop() assert isinstance(fn, NestedUserFunctionVariable) attr = self.pop() @@ -3610,6 +4617,7 @@ def SET_FUNCTION_ATTRIBUTE(self, inst: Instruction) -> None: self.push(fn) +<<<<<<< HEAD def CONVERT_VALUE(self, inst: Instruction) -> None: self.push(self._convert_value(self.pop(), inst.argval)) @@ -3620,15 +4628,31 @@ def FORMAT_WITH_SPEC(self, inst: Instruction) -> None: self._format_value(self.pop(), 0) def is_non_empty_graph(self) -> bool: +======= + def CONVERT_VALUE(self, inst): + self.push(self._convert_value(self.pop(), inst.argval)) + + def FORMAT_SIMPLE(self, inst): + self._format_value(ConstantVariable.create(""), 0) + + def FORMAT_WITH_SPEC(self, inst): + self._format_value(self.pop(), 0) + + def is_non_empty_graph(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.output.count_calls() > 1: # perf optimization only self.is_non_empty_graph = lambda: True # type: ignore[method-assign] return True return False +<<<<<<< HEAD def format_frame_summary( self, additional_stack_frames: Optional[list[Any]] = None ) -> str: +======= + def format_frame_summary(self, additional_stack_frames=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if additional_stack_frames is None: additional_stack_frames = [] return "".join( @@ -3637,7 +4661,11 @@ def format_frame_summary( ) ) +<<<<<<< HEAD def frame_summary(self) -> traceback.FrameSummary: +======= + def frame_summary(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return traceback.FrameSummary( getattr(self.f_code, "co_filename", ""), self.lineno, @@ -3645,12 +4673,20 @@ def frame_summary(self) -> traceback.FrameSummary: lookup_line=False, ) +<<<<<<< HEAD def is_co_filename_from_nn_modules(self) -> bool: +======= + def is_co_filename_from_nn_modules(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) filename = getattr(self.f_code, "co_filename", "") nn_modules_pattern = re.compile(r".*torch/nn/modules.*") return nn_modules_pattern.match(filename) is not None +<<<<<<< HEAD def store_global_weakref_by_id(self, prefix: str, value: Any) -> str: +======= + def store_global_weakref_by_id(self, prefix, value): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) global_name = self.output.install_global_by_id(prefix, weakref.ref(value)) install_guard( GlobalWeakRefSource(global_name).make_guard(GuardBuilder.WEAKREF_ALIVE) @@ -3658,6 +4694,7 @@ def store_global_weakref_by_id(self, prefix: str, value: Any) -> str: return global_name @property +<<<<<<< HEAD def fake_mode(self) -> Optional[FakeTensorMode]: return self.output.tracing_context.fake_mode @@ -3665,6 +4702,13 @@ def fake_mode(self) -> Optional[FakeTensorMode]: def strict_translation_mode( self, check_fn: Callable[[VariableTracker], bool] ) -> Any: +======= + def fake_mode(self): + return self.output.tracing_context.fake_mode + + @contextlib.contextmanager + def strict_translation_mode(self, check_fn: Callable[[VariableTracker], bool]): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Strict mode is enabled on a per-VariableTracker level depending on the return value of check_fn(node). """ @@ -3704,7 +4748,11 @@ def __init__( distributed_state: Optional[DistributedState], # This determines whether to use the execution recorder. closure: Optional[tuple[types.CellType]] = None, +<<<<<<< HEAD package: Optional[CompilePackage] = None, +======= + package: Optional["CompilePackage"] = None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: super().__init__() self.speculation_log = speculation_log @@ -3715,10 +4763,14 @@ def __init__( self.symbolic_locals = symbolic_locals self.symbolic_globals = symbolic_globals self.symbolic_torch_function_state = symbolic_torch_function_state +<<<<<<< HEAD # used to keep cell/freevars alive after pruning symbolic_locals (prune_dead_locals) # in order to generate any nested closures self.post_prune_cell_and_freevars = None self.stack: list[VariableTracker] = [] +======= + self.stack = [] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.instruction_pointer = 0 self.start_point = None self.current_instruction = create_instruction("NOP") @@ -3756,6 +4808,7 @@ def __init__( self.num_calls: dict[str, int] = {} # Flag to indicate whether tracing is used for export. self.export = export +<<<<<<< HEAD # NOTE: one_graph is used for export/fullgraph=True to always force errors on graph breaks. # To toggle erroring/resuming on graph breaks during fullgraph=False compile, self.error_on_graph_break # is used instead. Every step(), its value is updated to the global tls.error_on_graph_break. @@ -3766,6 +4819,9 @@ def __init__( self.error_on_graph_break = False # Also do not graph break when tracing resume function prologues self.is_tracing_resume_prologue = False +======= + self.one_graph = False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.current_speculation = None @@ -3792,7 +4848,11 @@ def __init__( self.inline_depth = inline_depth self.inconsistent_side_effects = False +<<<<<<< HEAD self._constants_cache: list[Optional[ConstantVariable]] = [None] * len( +======= + self._constants_cache: list[Optional[VariableTracker]] = [None] * len( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f_code.co_consts ) @@ -3807,11 +4867,19 @@ def __init__( class InstructionTranslator(InstructionTranslatorBase): @staticmethod +<<<<<<< HEAD def current_tx() -> InstructionTranslator: return tls.current_tx @contextlib.contextmanager def set_current_tx(self) -> Any: +======= + def current_tx() -> "InstructionTranslator": + return tls.current_tx + + @contextlib.contextmanager + def set_current_tx(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) prior = getattr(tls, "current_tx", None) tls.current_tx = self try: @@ -3822,6 +4890,7 @@ def set_current_tx(self) -> Any: def __init__( self, instructions: list[Instruction], +<<<<<<< HEAD f_code: types.CodeType, f_locals: dict[str, Any], f_globals: dict[str, Any], @@ -3838,6 +4907,24 @@ def __init__( exn_vt_stack: ExceptionStack, distributed_state: Optional[DistributedState], package: Optional[CompilePackage], +======= + f_code, + f_locals, + f_globals, + f_builtins, + closure, + torch_function_mode_stack, + code_options, + compiler_fn, + one_graph, + export, + export_constraints, + frame_state, + speculation_log: SpeculationLog, + exn_vt_stack: ExceptionStack, + distributed_state: Optional[DistributedState], + package: Optional["CompilePackage"], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: _step_logger()( logging.INFO, @@ -3945,12 +5032,19 @@ def __init__( side_effects.store_cell(cell_var, contents_var) else: cell_var = side_effects.track_cell_new() +<<<<<<< HEAD cell_var.local_name = name # type: ignore[attr-defined] +======= + cell_var.local_name = name +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.symbolic_locals[name] = cell_var # Populate `symbolic_locals` with cells captured by this frame, # effectively implementing the `COPY_FREE_VARS` instruction. +<<<<<<< HEAD assert closure is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for name, cell in zip(self.freevars(), closure): cell_source = LocalCellSource(name) contents_source = LocalSource(name, is_derefed_cell_contents=True) @@ -3964,7 +5058,11 @@ def __init__( cell_var = side_effects.track_cell_existing( cell_source, cell, contents_var ) +<<<<<<< HEAD cell_var.local_name = name # type: ignore[attr-defined] +======= + cell_var.local_name = name +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.symbolic_locals[name] = cell_var self.symbolic_torch_function_state = SymbolicTorchFunctionState( @@ -3978,7 +5076,11 @@ def __init__( self.symbolic_locals ) +<<<<<<< HEAD def _throw_if_in_functorch(self) -> None: +======= + def _throw_if_in_functorch(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Fallback to eager in case of a graph break inside vmap eager = torch._dynamo.lookup_backend("eager") compiler_fn = inspect.getattr_static( @@ -4009,14 +5111,136 @@ def _throw_if_in_functorch(self) -> None: hints=[], ) +<<<<<<< HEAD def get_example_value(self, source: Source) -> Any: +======= + def get_example_value(self, source: Source): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(source, LocalSource): return self.f_locals[source.local_name] if isinstance(source, GlobalSource): return self.f_globals[source.global_name] raise KeyError +<<<<<<< HEAD def symbolic_locals_contain_module_class(self) -> bool: +======= + def run(self): + super().run() + + def should_compile_partial_graph(self): + if sys.version_info >= (3, 11): + # Do not compile if current instruction's block is not the top with block + entry = self.current_instruction.exn_tab_entry + if entry and ( + not self.block_stack or entry.target is not self.block_stack[-1].target + ): + return False + return ( + all(b.can_restore() for b in self.block_stack) + and not self.one_graph + and not self.active_generic_context_managers + ) + + def create_call_resume_at(self, inst, all_stack_locals_metadata): + self.instruction_pointer = None + + if inst.opname == "RETURN_VALUE": + return [create_instruction("RETURN_VALUE")] + elif inst.opname == "RETURN_CONST": + return [create_instruction("RETURN_CONST", argval=inst.argval)] + + reads = livevars_analysis(self.instructions, inst) + all_argnames = tuple( + k + for k in self.symbolic_locals.keys() + if k in reads and k not in self.cell_and_freevars() + ) + # NOTE: do not use isinstance, since it realizes lazy VT's + argnames_null_set = set(all_stack_locals_metadata[0].locals_null_keys) + argnames = tuple(k for k in all_argnames if k not in argnames_null_set) + argnames_null = tuple(k for k in all_argnames if k in argnames_null_set) + if sys.version_info < (3, 12): + assert len(argnames_null) == 0, "variables should not be NULL in < 3.12" + # compile_subgraph did not codegen any NULLs, + # so we should not count NullVariables + stack_len = len(self.stack) - len(all_stack_locals_metadata[0].stack_null_idxes) + nargs = stack_len + len(argnames) + + cg = PyCodegen(self) + + # Handle inactive context variables. + # The resume function assumes that context variables are the class, NOT the object. + # e.g. torch.set_grad_enabled(True) will be reconstructed as torch.set_grad_enabled + # NOTE: if the unsupported instruction modifies the inactive context variable, it may + # result in silent incorrectness! + for (i, _), i_orig in zip( + all_stack_locals_metadata[0].stack_ctx_args, + all_stack_locals_metadata[0].stack_ctx_idxes_orig, + ): + # Replace the current stack var with the context class + ctx = cast(ContextWrappingVariable, self.stack[i_orig]) + ctx.reconstruct_type(cg) + cg.extend_output(create_swap(stack_len - i + 1)) + cg.append_output(create_instruction("POP_TOP")) + + for name, _ in all_stack_locals_metadata[0].locals_ctx_args: + # Replace the local with the context class + ctx = cast(ContextWrappingVariable, self.symbolic_locals[name]) + ctx.reconstruct_type(cg) + cg.append_output(create_instruction("STORE_FAST", argval=name)) + + name = unique_id(f"__resume_at_{inst.offset}", with_uuid=True) + + new_code: types.CodeType = ContinueExecutionCache.lookup( + self.f_code, + self.lineno, + inst.offset, + tuple(b.target.offset for b in self.block_stack), + stack_len, + argnames, + argnames_null, + tuple(b.resume_fn() for b in self.block_stack), + tuple(all_stack_locals_metadata[0].stack_ctx_args), + tuple(all_stack_locals_metadata[0].locals_ctx_args), + tuple(all_stack_locals_metadata[0].stack_null_idxes), + ) + + # Add original GraphModule context to the resume function to handle + # the case of a graph break while tracing a GraphModule + orig_graphmodule_maybe = code_context.get_context(self.f_code).get( + "orig_graphmodule", lambda: None + )() + if orig_graphmodule_maybe is not None: + code_context.get_context(new_code)["orig_graphmodule"] = weakref.ref( + orig_graphmodule_maybe + ) + + if new_code.co_freevars: + # expose code object for debugging purposes + self.output.install_global_unsafe(name, new_code) + cg.make_function_with_closure(name, new_code, True, stack_len) + package_name = None + else: + # This is safe: we pre-generate a unique name + self.output.install_global_unsafe( + name, types.FunctionType(new_code, self.f_globals, name) + ) + cg.extend_output(cg.load_function_name(name, True, stack_len)) + package_name = name + + if self.package is not None: + self.package.add_resume_function( + new_code, self.f_globals["__name__"], package_name + ) + + cg.extend_output([cg.create_load(k) for k in argnames]) + cg.extend_output(create_call_function(nargs, False)) + cg.append_output(create_instruction("RETURN_VALUE")) + return cg.get_instructions() + + def symbolic_locals_contain_module_class(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for v in self.symbolic_locals.values(): if isinstance(v, UserDefinedClassVariable) and issubclass( v.as_python_constant(), torch.nn.Module @@ -4024,7 +5248,11 @@ def symbolic_locals_contain_module_class(self) -> bool: return True return False +<<<<<<< HEAD def replace_tos_if_return_is_generator(self) -> None: +======= + def replace_tos_if_return_is_generator(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if ( len(self.stack) and (tos := self.stack[-1]) @@ -4035,7 +5263,11 @@ def replace_tos_if_return_is_generator(self) -> None: mutation_type=ValueMutationNew(), ) +<<<<<<< HEAD def _return(self, inst: Instruction) -> None: +======= + def _return(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.replace_tos_if_return_is_generator() assert self.instruction_pointer is not None assert self.start_point is not None @@ -4050,8 +5282,11 @@ def _return(self, inst: Instruction) -> None: and not self.symbolic_locals_contain_module_class() and not self.export and not self.one_graph +<<<<<<< HEAD and not self.error_on_graph_break and not self.is_tracing_resume_prologue +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): raise exc.SkipFrame("because no content in function call") @@ -4066,8 +5301,11 @@ def _return(self, inst: Instruction) -> None: reason=GraphCompileReason( "return_value", [self.frame_summary()], graph_break=False ), +<<<<<<< HEAD # the value to be returned stack_pops=1 if inst.opname == "RETURN_VALUE" else 0, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # check that our stack/locals meta are correct: # we should only be tracing 1 frame, and there should not be any NULLs on the stack @@ -4078,6 +5316,7 @@ def _return(self, inst: Instruction) -> None: if inst.opname == "RETURN_VALUE" else create_instruction("RETURN_CONST", argval=inst.argval) ) +<<<<<<< HEAD # NOTE: does the stack need to be empty after the return? self.output.add_output_instructions([return_inst]) raise ReturnValueOp @@ -4086,6 +5325,15 @@ def RETURN_VALUE(self, inst: Instruction) -> None: self._return(inst) def RETURN_CONST(self, inst: Instruction) -> None: +======= + self.output.add_output_instructions([return_inst]) + raise ReturnValueOp + + def RETURN_VALUE(self, inst): + self._return(inst) + + def RETURN_CONST(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._return(inst) @@ -4106,13 +5354,21 @@ class InliningInstructionTranslator(InstructionTranslatorBase): parent: InstructionTranslatorBase @classmethod +<<<<<<< HEAD def inline_call(cls, parent: Any, func: Any, args: Any, kwargs: Any) -> Any: +======= + def inline_call(cls, parent, func, args, kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with patch.dict(counters, {"unimplemented": counters["inline_call"]}): tracer = cls.build_inline_tracer(parent, func, args, kwargs) return tracer.inline_call_() @staticmethod +<<<<<<< HEAD def check_inlineable(func: Any) -> trace_rules.SkipResult: +======= + def check_inlineable(func): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if func.has_self(): unimplemented_v2( gb_type="Inline attempt with __self__", @@ -4175,11 +5431,19 @@ def check_inlineable(func: Any) -> trace_rules.SkipResult: @staticmethod def build_inline_tracer( +<<<<<<< HEAD parent: Any, func: VariableTracker, args: list[VariableTracker], kwargs: Any, ) -> InliningInstructionTranslator: +======= + parent, + func: VariableTracker, + args: list[VariableTracker], + kwargs, + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance( func, ( @@ -4259,7 +5523,11 @@ def build_inline_tracer( cur_inst = parent.current_instruction parent_code = parent.f_code +<<<<<<< HEAD def get_trace_call_log_str() -> str: +======= + def get_trace_call_log_str(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) header = parent.get_line_of_code_header( lineno=cur_inst.positions.lineno ) @@ -4303,7 +5571,11 @@ def get_trace_call_log_str() -> str: ) return tracer +<<<<<<< HEAD def inline_call_(self) -> VariableTracker: +======= + def inline_call_(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) parent = self.parent code = self.f_code @@ -4325,6 +5597,7 @@ def inline_call_(self) -> VariableTracker: except Exception: log.debug("FAILED INLINING %s", code) raise +<<<<<<< HEAD finally: parent.error_on_graph_break = self.error_on_graph_break @@ -4332,6 +5605,8 @@ def inline_call_(self) -> VariableTracker: # graph break return ConstantVariable.create(None) # return dummy variable +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert self.symbolic_result is not None if self.f_globals is parent.f_globals: @@ -4354,6 +5629,7 @@ def inline_call_(self) -> VariableTracker: ): assert isinstance(self, InliningGeneratorInstructionTranslator) # When the generator returns None, we raise StopIteration +<<<<<<< HEAD args = [] if not ( isinstance(self.symbolic_result, ConstantVariable) @@ -4361,6 +5637,9 @@ def inline_call_(self) -> VariableTracker: ): args = [self.symbolic_result] exc.raise_observed_exception(StopIteration, self, args=args) +======= + exc.raise_observed_exception(StopIteration, self) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: return self.symbolic_result else: @@ -4432,6 +5711,7 @@ def __init__( self.one_graph = parent.one_graph @property +<<<<<<< HEAD def fake_mode(self) -> Optional[FakeTensorMode]: return self.parent.fake_mode @@ -4455,6 +5735,18 @@ def create_call_resume_at( return super().create_call_resume_at( inst, all_stack_locals_metadata, disable_current_frame_resume ) +======= + def fake_mode(self): + return self.parent.fake_mode + + def run_ctx_mgr(self): + return TracingContext.current_frame(self.parent.frame_summary()) + + def should_compile_partial_graph(self): + return False # inlining functions is all-or-nothing + + def create_call_resume_at(self, inst, all_stack_locals_metadata): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) unimplemented_v2( gb_type="Graph break in inlined function", context="", @@ -4462,19 +5754,31 @@ def create_call_resume_at( hints=[], ) +<<<<<<< HEAD def RETURN_VALUE(self, inst: Instruction) -> None: +======= + def RETURN_VALUE(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.symbolic_result = self.pop() # type: ignore[assignment] self.instruction_pointer = None raise ReturnValueOp +<<<<<<< HEAD def RETURN_CONST(self, inst: Instruction) -> None: +======= + def RETURN_CONST(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.symbolic_result = self._load_const(inst) self.instruction_pointer = None raise ReturnValueOp +<<<<<<< HEAD def get_globals_source_and_value( self, name: str ) -> tuple[Any, VariableTracker, Source]: +======= + def get_globals_source_and_value(self, name): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # NamedTuple's `__new__` has a fake global scope that's not an actual # module. TODO generalize the check for other non-importable cases. # https://github.com/python/cpython/blob/8421b03b16a4852a527256cb7cdce2ab2d318548/Lib/collections/__init__.py#L441-L447 @@ -4503,6 +5807,7 @@ def get_globals_source_and_value( # Dont use lazy vt because we will do a setattr afterwards fglobals_vt = VariableBuilder(self, globals_source)(fglobals_value) global_source = DictGetItemSource(globals_source, name) # type: ignore[assignment] +<<<<<<< HEAD if is_stdlib(fglobals_value): # Users don't inplace mutate a stdlib attribute (like inspect, @@ -4512,6 +5817,11 @@ def get_globals_source_and_value( return fglobals_value, fglobals_vt, global_source def _load_global(self, inst: Instruction) -> None: +======= + return fglobals_value, fglobals_vt, global_source + + def _load_global(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) name = inst.argval if name not in self.f_globals: return self.load_builtin(inst) @@ -4528,7 +5838,11 @@ def _load_global(self, inst: Instruction) -> None: value = self.f_globals[name] self.push(VariableTracker.build(self, value, global_source)) +<<<<<<< HEAD def STORE_GLOBAL(self, inst: Instruction) -> None: +======= + def STORE_GLOBAL(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.output.global_scope is self.f_globals: # If the global scope matches that of the root frame, use handler in # root frame instruction translator, to enforce consistency. @@ -4551,13 +5865,21 @@ class InliningGeneratorInstructionTranslator(InliningInstructionTranslator): generated_items: list[VariableTracker] # Flag whether or not the InlineGenerator should consume the entire iterator +<<<<<<< HEAD def __init__(self, *args: Any, **kwargs: Any) -> None: +======= + def __init__(self, *args, **kwargs) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__(*args, **kwargs) self.generated_items = [] self.generator_exhausted = False self.is_generator_from_ctx_manager = False +<<<<<<< HEAD def YIELD_VALUE(self, inst: Instruction) -> None: +======= + def YIELD_VALUE(self, inst: Instruction): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) top = self.pop() self.generated_items.append(top) if len(self.generated_items) > MAX_ITERATOR_LIMIT: @@ -4574,13 +5896,18 @@ def YIELD_VALUE(self, inst: Instruction) -> None: # Stop tracing raise YieldValueOp +<<<<<<< HEAD def GET_YIELD_FROM_ITER(self, inst: Instruction) -> None: +======= + def GET_YIELD_FROM_ITER(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tos = self.stack[-1] if not isinstance(tos, ListIteratorVariable): self.pop() res = BuiltinVariable(iter).call_function(self, [tos], {}) # type: ignore[arg-type] self.push(res) +<<<<<<< HEAD def RETURN_VALUE(self, inst: Instruction) -> None: self.generator_exhausted = True return super().RETURN_VALUE(inst) @@ -4590,6 +5917,17 @@ def RETURN_CONST(self, inst: Instruction) -> None: return super().RETURN_CONST(inst) def YIELD_FROM(self, inst: Instruction) -> None: +======= + def RETURN_VALUE(self, inst): + self.generator_exhausted = True + return super().RETURN_VALUE(inst) + + def RETURN_CONST(self, inst): + self.generator_exhausted = True + return super().RETURN_CONST(inst) + + def YIELD_FROM(self, inst): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert len(self.stack) >= 2 val = self.pop() tos = self.stack[-1] @@ -4627,11 +5965,19 @@ def YIELD_FROM(self, inst: Instruction) -> None: # Add the value to yield into generated_items and replace the top of the stack with None self.YIELD_VALUE(inst) +<<<<<<< HEAD def SEND(self, inst: Instruction) -> None: assert len(self.stack) >= 2 val = self.pop() tos = self.stack[-1] if isinstance(tos, (IteratorVariable, LocalGeneratorObjectVariable)) or ( +======= + def SEND(self, inst): + assert len(self.stack) >= 2 + val = self.pop() + tos = self.stack[-1] + if isinstance(tos, (ListIteratorVariable, LocalGeneratorObjectVariable)) or ( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) isinstance(tos, UserDefinedObjectVariable) and isinstance(tos.value, collections.abc.Iterator) ): diff --git a/torch/_dynamo/tensor_version_op.py b/torch/_dynamo/tensor_version_op.py index 8709c5618d859..39c2e9adea0e5 100644 --- a/torch/_dynamo/tensor_version_op.py +++ b/torch/_dynamo/tensor_version_op.py @@ -1,3 +1,8 @@ +<<<<<<< HEAD +======= +# mypy: allow-untyped-defs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """This module implements tensor version operations for Dynamo tracing. It provides primitives for handling tensor versioning during tracing, particularly in the @@ -16,11 +21,15 @@ Note this is similar to how no_grad is handled. """ +<<<<<<< HEAD from contextlib import AbstractContextManager from typing import Any import torch from torch import SymInt +======= +import torch +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._prims import _make_prim, RETURN_TYPE from torch._subclasses import FakeTensorMode from torch._subclasses.functional_tensor import FunctionalTensorMode @@ -35,14 +44,22 @@ ) +<<<<<<< HEAD @_tensor_version.py_impl(FakeTensorMode) # type: ignore[misc] def _tensor_version_fake(fake_mode: FakeTensorMode, self_tensor: Any) -> SymInt: +======= +@_tensor_version.py_impl(FakeTensorMode) +def _tensor_version_fake(fake_mode, self_tensor): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ The initial dynamo capture of _tensor_version + _unsafe_set_version_counter turns the `._version` into an unbacked SymInt so that we don't need to specialize on the `._version` of input tensors to the graph. """ +<<<<<<< HEAD assert fake_mode.shape_env is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return fake_mode.shape_env.create_unbacked_symint() @@ -56,6 +73,7 @@ def _tensor_version_fake(fake_mode: FakeTensorMode, self_tensor: Any) -> SymInt: torch.fx.node.has_side_effect(_unsafe_set_version_counter) +<<<<<<< HEAD @_tensor_version.py_impl(FunctionalTensorMode) # type: ignore[misc] def _tensor_version_functional(mode: FunctionalTensorMode, self: Any) -> int: return self._version @@ -67,4 +85,13 @@ def _unsafe_set_version_counter_functional( tensors: tuple[torch.Tensor, ...], versions: tuple[int, ...], ) -> None: +======= +@_tensor_version.py_impl(FunctionalTensorMode) +def _tensor_version_functional(mode, self): + return self._version + + +@_unsafe_set_version_counter.py_impl(FunctionalTensorMode) +def _unsafe_set_version_counter_functional(ctx, tensors, versions): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch._C._autograd._unsafe_set_version_counter(tensors, versions) diff --git a/torch/_dynamo/test_case.py b/torch/_dynamo/test_case.py index 77860c720a6e2..4bab77f6bede5 100644 --- a/torch/_dynamo/test_case.py +++ b/torch/_dynamo/test_case.py @@ -1,3 +1,8 @@ +<<<<<<< HEAD +======= +# mypy: allow-untyped-defs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Testing utilities for Dynamo, providing a specialized TestCase class and test running functionality. This module extends PyTorch's testing framework with Dynamo-specific testing capabilities. @@ -16,11 +21,18 @@ import re import sys import unittest +<<<<<<< HEAD from typing import Any, Callable, Union import torch import torch.testing from torch._dynamo import polyfills +======= +from typing import Union + +import torch +import torch.testing +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._logging._internal import trace_log from torch.testing._internal.common_utils import ( # type: ignore[attr-defined] IS_WINDOWS, @@ -101,6 +113,7 @@ def tearDown(self) -> None: log.warning("Running test changed grad mode") torch.set_grad_enabled(self._prior_is_grad_enabled) +<<<<<<< HEAD def assertEqual(self, x: Any, y: Any, *args: Any, **kwargs: Any) -> None: # type: ignore[override] if ( config.debug_disable_compile_counter @@ -113,6 +126,8 @@ def assertEqual(self, x: Any, y: Any, *args: Any, **kwargs: Any) -> None: # typ # assertExpectedInline might also need to be disabled for wrapped nested # graph break tests +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class CPythonTestCase(TestCase): """ @@ -147,12 +162,21 @@ class CPythonTestCase(TestCase): assertRegex = unittest.TestCase.assertRegex assertNotRegex = unittest.TestCase.assertNotRegex assertCountEqual = unittest.TestCase.assertCountEqual +<<<<<<< HEAD assertMultiLineEqual = polyfills.assert_multi_line_equal assertSequenceEqual = polyfills.assert_sequence_equal assertListEqual = unittest.TestCase.assertListEqual assertTupleEqual = unittest.TestCase.assertTupleEqual assertSetEqual = unittest.TestCase.assertSetEqual assertDictEqual = polyfills.assert_dict_equal +======= + assertMultiLineEqual = unittest.TestCase.assertMultiLineEqual + assertSequenceEqual = unittest.TestCase.assertSequenceEqual + assertListEqual = unittest.TestCase.assertListEqual + assertTupleEqual = unittest.TestCase.assertTupleEqual + assertSetEqual = unittest.TestCase.assertSetEqual + assertDictEqual = unittest.TestCase.assertDictEqual +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assertRaises = unittest.TestCase.assertRaises assertRaisesRegex = unittest.TestCase.assertRaisesRegex assertWarns = unittest.TestCase.assertWarns @@ -161,6 +185,7 @@ class CPythonTestCase(TestCase): fail = unittest.TestCase.fail failureException = unittest.TestCase.failureException +<<<<<<< HEAD def compile_fn( self, fn: Callable[..., Any], @@ -175,6 +200,17 @@ def compile_fn( return fn def _dynamo_test_key(self) -> str: +======= + def compile_fn(self, fn, backend, nopython): + # We want to compile only the test function, excluding any setup code + # from unittest + method = getattr(self, self._testMethodName) + method = torch._dynamo.optimize(backend, nopython=nopython)(method) + setattr(self, self._testMethodName, method) + return fn + + def _dynamo_test_key(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) suffix = super()._dynamo_test_key() test_cls = self.__class__ test_file = inspect.getfile(test_cls).split(os.sep)[-1].split(".")[0] diff --git a/torch/_dynamo/test_minifier_common.py b/torch/_dynamo/test_minifier_common.py index f48dae1d0e33e..4fcda7597befd 100644 --- a/torch/_dynamo/test_minifier_common.py +++ b/torch/_dynamo/test_minifier_common.py @@ -1,3 +1,8 @@ +<<<<<<< HEAD +======= +# mypy: allow-untyped-defs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Common utilities for testing Dynamo's minifier functionality. This module provides the base infrastructure for running minification tests in Dynamo. @@ -23,8 +28,12 @@ import sys import tempfile import traceback +<<<<<<< HEAD from collections.abc import Sequence from typing import Any, Optional, Union +======= +from typing import Optional +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from unittest.mock import patch import torch @@ -39,7 +48,11 @@ class MinifierTestResult: minifier_code: str repro_code: str +<<<<<<< HEAD def _get_module(self, t: str) -> str: +======= + def _get_module(self, t): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) match = re.search(r"class Repro\(torch\.nn\.Module\):\s+([ ].*\n| *\n)+", t) assert match is not None, "failed to find module" r = match.group(0) @@ -47,7 +60,11 @@ def _get_module(self, t: str) -> str: r = re.sub(r"\n{3,}", "\n\n", r) return r.strip() +<<<<<<< HEAD def get_exported_program_path(self) -> Optional[str]: +======= + def get_exported_program_path(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Extract the exported program file path from AOTI minifier's repro.py # Regular expression pattern to match the file path pattern = r'torch\.export\.load\(\s*["\'](.*?)["\']\s*\)' @@ -59,10 +76,17 @@ def get_exported_program_path(self) -> Optional[str]: return file_path return None +<<<<<<< HEAD def minifier_module(self) -> str: return self._get_module(self.minifier_code) def repro_module(self) -> str: +======= + def minifier_module(self): + return self._get_module(self.minifier_code) + + def repro_module(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self._get_module(self.repro_code) @@ -70,7 +94,11 @@ class MinifierTestBase(torch._dynamo.test_case.TestCase): DEBUG_DIR = tempfile.mkdtemp() @classmethod +<<<<<<< HEAD def setUpClass(cls) -> None: +======= + def setUpClass(cls): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().setUpClass() if not os.path.exists(cls.DEBUG_DIR): cls.DEBUG_DIR = tempfile.mkdtemp() @@ -93,14 +121,22 @@ def setUpClass(cls) -> None: ) @classmethod +<<<<<<< HEAD def tearDownClass(cls) -> None: +======= + def tearDownClass(cls): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if os.getenv("PYTORCH_KEEP_TMPDIR", "0") != "1": shutil.rmtree(cls.DEBUG_DIR) else: print(f"test_minifier_common tmpdir kept at: {cls.DEBUG_DIR}") cls._exit_stack.close() # type: ignore[attr-defined] +<<<<<<< HEAD def _gen_codegen_fn_patch_code(self, device: str, bug_type: str) -> str: +======= + def _gen_codegen_fn_patch_code(self, device, bug_type): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert bug_type in ("compile_error", "runtime_error", "accuracy") return f"""\ {torch._dynamo.config.codegen_config()} @@ -108,11 +144,15 @@ def _gen_codegen_fn_patch_code(self, device: str, bug_type: str) -> str: torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_TESTING_ONLY = {bug_type!r} """ +<<<<<<< HEAD def _maybe_subprocess_run( self, args: Sequence[Any], *, isolate: bool, cwd: Optional[str] = None ) -> subprocess.CompletedProcess[bytes]: from torch._inductor.cpp_builder import normalize_path_separator +======= + def _maybe_subprocess_run(self, args, *, isolate, cwd=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not isolate: assert len(args) >= 2, args assert args[0] == "python3", args @@ -123,8 +163,12 @@ def _maybe_subprocess_run( else: assert len(args) >= 2, args with open(args[1]) as f: +<<<<<<< HEAD # Need normalize path of the code. code = normalize_path_separator(f.read()) +======= + code = f.read() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) args = args[1:] # WARNING: This is not a perfect simulation of running @@ -178,9 +222,13 @@ def _maybe_subprocess_run( # Run `code` in a separate python process. # Returns the completed process state and the directory containing the # minifier launcher script, if `code` outputted it. +<<<<<<< HEAD def _run_test_code( self, code: str, *, isolate: bool ) -> tuple[subprocess.CompletedProcess[bytes], Union[str, Any]]: +======= + def _run_test_code(self, code, *, isolate): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) proc = self._maybe_subprocess_run( ["python3", "-c", code], isolate=isolate, cwd=self.DEBUG_DIR ) @@ -196,6 +244,7 @@ def _run_test_code( # Runs the minifier launcher script in `repro_dir` def _run_minifier_launcher( +<<<<<<< HEAD self, repro_dir: str, isolate: bool, @@ -203,6 +252,10 @@ def _run_minifier_launcher( minifier_args: Sequence[Any] = (), repro_after: Optional[str] = None, ) -> tuple[subprocess.CompletedProcess[bytes], str]: +======= + self, repro_dir, isolate, *, minifier_args=(), repro_after=None + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertIsNotNone(repro_dir) launch_file = _as_posix_path(os.path.join(repro_dir, "minifier_launcher.py")) with open(launch_file) as f: @@ -223,9 +276,13 @@ def _run_minifier_launcher( return launch_proc, launch_code # Runs the repro script in `repro_dir` +<<<<<<< HEAD def _run_repro( self, repro_dir: str, *, isolate: bool = True ) -> tuple[subprocess.CompletedProcess[bytes], str]: +======= + def _run_repro(self, repro_dir, *, isolate=True): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.assertIsNotNone(repro_dir) repro_file = _as_posix_path(os.path.join(repro_dir, "repro.py")) with open(repro_file) as f: @@ -243,7 +300,11 @@ def _run_repro( # `run_code` is the code to run for the test case. # `patch_code` is the code to be patched in every generated file; usually # just use this to turn on bugs via the config +<<<<<<< HEAD def _gen_test_code(self, run_code: str, repro_after: str, repro_level: int) -> str: +======= + def _gen_test_code(self, run_code, repro_after, repro_level): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) repro_after_line = "" if repro_after == "aot_inductor": repro_after_line = ( @@ -276,6 +337,7 @@ def _gen_test_code(self, run_code: str, repro_after: str, repro_level: int) -> s # isolate=True only if the bug you're testing would otherwise # crash the process def _run_full_test( +<<<<<<< HEAD self, run_code: str, repro_after: str, @@ -283,6 +345,9 @@ def _run_full_test( *, isolate: bool, minifier_args: Sequence[Any] = (), +======= + self, run_code, repro_after, expected_error, *, isolate, minifier_args=() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Optional[MinifierTestResult]: if isolate: repro_level = 3 diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index 805c3be524e8f..a42831154b905 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -42,7 +42,11 @@ ) from .guards import CheckFunctionManager, CompileId, GuardedCode from .types import ConvertFrameReturn, DynamoFrameType, wrap_guarded_code +<<<<<<< HEAD from .utils import CompileCounterInt, same +======= +from .utils import same +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) np: Optional[types.ModuleType] = None @@ -200,6 +204,7 @@ def insert_nops(instructions: list[Any], code_options: Any) -> None: return ConvertFrameReturn() debug_checks(frame.f_code) +<<<<<<< HEAD code, _ = transform_code_object(frame.f_code, insert_nops) graph = OutputGraph( code_options={}, @@ -207,6 +212,15 @@ def insert_nops(instructions: list[Any], code_options: Any) -> None: root_tx=None, # type: ignore[arg-type] export=False, export_constraints=[], +======= + code = transform_code_object(frame.f_code, insert_nops) + graph = OutputGraph( + code_options={}, + compiler_fn=None, + root_tx=None, + export=False, + export_constraints=None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) frame_state={"_id": 0}, # TODO: shouldn't this be f_locals/f_globals from frame? local_scope=locals(), @@ -227,8 +241,13 @@ def insert_nops(instructions: list[Any], code_options: Any) -> None: class CompileCounter: def __init__(self) -> None: +<<<<<<< HEAD self.frame_count: Union[int, CompileCounterInt] = 0 self.clear() +======= + self.frame_count = 0 + self.op_count = 0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __call__( self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor] @@ -240,19 +259,30 @@ def __call__( return gm.forward def clear(self) -> None: +<<<<<<< HEAD if config.debug_disable_compile_counter: self.frame_count = CompileCounterInt(0) else: self.frame_count = 0 +======= + self.frame_count = 0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.op_count = 0 class CompileCounterWithBackend: def __init__(self, backend: str) -> None: +<<<<<<< HEAD self.frame_count: Union[int, CompileCounterInt] = 0 self.backend = backend self.graphs: list[torch.fx.GraphModule] = [] self.clear() +======= + self.frame_count = 0 + self.op_count = 0 + self.backend = backend + self.graphs: list[torch.fx.GraphModule] = [] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __call__( self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor] @@ -267,10 +297,14 @@ def __call__( return lookup_backend(self.backend)(gm, example_inputs) def clear(self) -> None: +<<<<<<< HEAD if config.debug_disable_compile_counter: self.frame_count = CompileCounterInt(0) else: self.frame_count = 0 +======= + self.frame_count = 0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.op_count = 0 self.graphs = [] @@ -420,12 +454,20 @@ def rand_strided( device: Union[str, torch.device] = "cpu", extra_size: int = 0, ) -> torch.Tensor: +<<<<<<< HEAD needed_size = extra_size if all(s > 0 for s in size): # only need to allocate if all sizes are non-zero needed_size += ( sum((shape - 1) * stride for shape, stride in zip(size, stride)) + 1 ) +======= + needed_size = ( + sum((shape - 1) * stride for shape, stride in zip(size, stride)) + + 1 + + extra_size + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if dtype.is_floating_point: if dtype.itemsize == 1: """ diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 47ad8cda0c974..38b447101bc23 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -1,3 +1,8 @@ +<<<<<<< HEAD +======= +# mypy: allow-untyped-defs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Tracing rules and policies for TorchDynamo compilation decisions. @@ -21,6 +26,10 @@ import abc import builtins +<<<<<<< HEAD +======= +import collections +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import copy import dataclasses import functools @@ -34,6 +43,10 @@ import sys import traceback import types +<<<<<<< HEAD +======= +import typing +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import unittest from collections import defaultdict from pathlib import Path @@ -48,6 +61,7 @@ from . import config from .resume_execution import TORCH_DYNAMO_RESUME_IN_PREFIX +<<<<<<< HEAD from .utils import ( getfile, hashable, @@ -55,6 +69,9 @@ NP_SUPPORTED_MODULES, unwrap_if_wrapper, ) +======= +from .utils import getfile, hashable, NP_SUPPORTED_MODULES, unwrap_if_wrapper +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .variables import ( BuiltinVariable, FunctionalCallVariable, @@ -63,13 +80,19 @@ LocalGeneratorObjectVariable, NestedUserFunctionVariable, PolyfilledFunctionVariable, +<<<<<<< HEAD ReparametrizeModuleCallVariable, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SkipFunctionVariable, TorchInGraphFunctionVariable, UserFunctionVariable, UserMethodVariable, ) +<<<<<<< HEAD from .variables.base import VariableTracker +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) np: Optional[types.ModuleType] = None @@ -79,6 +102,13 @@ pass +<<<<<<< HEAD +======= +if typing.TYPE_CHECKING: + from .variables.base import VariableTracker + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ A note on skip/inline rules: @@ -146,6 +176,7 @@ """ +<<<<<<< HEAD manual_torch_name_rule_map: dict[ str, Union[ @@ -154,6 +185,9 @@ type[UserFunctionVariable], ], ] = { +======= +manual_torch_name_rule_map: dict[str, Any] = { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "torch.onnx.is_in_onnx_export": TorchInGraphFunctionVariable, "torch.onnx.operators.shape_as_tensor": TorchInGraphFunctionVariable, "torch.overrides.is_tensor_like": TorchInGraphFunctionVariable, @@ -171,7 +205,10 @@ "torch.distributed.distributed_c10d.get_process_group_ranks": TorchInGraphFunctionVariable, "torch._utils.is_compiling": TorchInGraphFunctionVariable, "torch.fx._symbolic_trace.is_fx_tracing": TorchInGraphFunctionVariable, +<<<<<<< HEAD "torch.fx._symbolic_trace.is_fx_symbolic_tracing": TorchInGraphFunctionVariable, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "torch._dynamo.external_utils.is_compiling": TorchInGraphFunctionVariable, "torch._dynamo.utils._disable_side_effect_safety_checks_for_current_subtracer": UserFunctionVariable, "torch.compiler.is_compiling": TorchInGraphFunctionVariable, @@ -208,10 +245,13 @@ "torch.fx.node.map_aggregate": UserFunctionVariable, "torch.fx.node.map_arg": UserFunctionVariable, "torch.fx.immutable_collections._no_mutation": UserFunctionVariable, +<<<<<<< HEAD "torch.fx.immutable_collections._immutable_list_flatten": UserFunctionVariable, "torch.fx.immutable_collections._immutable_list_unflatten": UserFunctionVariable, "torch.fx.immutable_collections._immutable_dict_flatten": UserFunctionVariable, "torch.fx.immutable_collections._immutable_dict_unflatten": UserFunctionVariable, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # symbol operators implemented in Python "torch.sym_not": TorchInGraphFunctionVariable, "torch.sym_float": TorchInGraphFunctionVariable, @@ -245,8 +285,11 @@ "torch._C.set_autocast_xla_dtype": SkipFunctionVariable, "torch._C.set_autocast_xla_enabled": SkipFunctionVariable, "torch.resize_as_": SkipFunctionVariable, +<<<<<<< HEAD "torch._functorch.predispatch._add_batch_dim": TorchInGraphFunctionVariable, "torch._functorch.predispatch._remove_batch_dim": TorchInGraphFunctionVariable, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "torch.resize_as_sparse_": SkipFunctionVariable, "torch.get_default_device": TorchInGraphFunctionVariable, # functorch/vmap @@ -317,7 +360,10 @@ # functional_call "torch._functorch.functional_call.functional_call": FunctionalCallVariable, "torch.nn.utils.stateless._groupby_tensor": TorchInGraphFunctionVariable, +<<<<<<< HEAD "torch.nn.utils.stateless._reparametrize_module": ReparametrizeModuleCallVariable, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # functorch/deprecated "torch._functorch.deprecated.jvp": UserFunctionVariable, "torch._functorch.deprecated.hessian": UserFunctionVariable, @@ -327,6 +373,11 @@ "torch._functorch.deprecated.grad_and_value": UserFunctionVariable, "torch._functorch.deprecated.vjp": UserFunctionVariable, # functorch/C++ bindings +<<<<<<< HEAD +======= + "torch._C._functorch._add_batch_dim": TorchInGraphFunctionVariable, + "torch._C._functorch._remove_batch_dim": TorchInGraphFunctionVariable, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "torch._C._functorch._wrap_for_grad": TorchInGraphFunctionVariable, "torch._C._functorch._unwrap_for_grad": TorchInGraphFunctionVariable, "torch._C._functorch._unwrap_batched": TorchInGraphFunctionVariable, @@ -335,8 +386,11 @@ "torch._C._functorch.is_batchedtensor": TorchInGraphFunctionVariable, "torch._C._functorch.peek_interpreter_stack": TorchInGraphFunctionVariable, "torch._C._functorch.unwrap_if_dead": TorchInGraphFunctionVariable, +<<<<<<< HEAD "torch._functorch.predispatch._vmap_increment_nesting": TorchInGraphFunctionVariable, "torch._functorch.predispatch._vmap_decrement_nesting": TorchInGraphFunctionVariable, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # everything else "torch._functorch.pyfunctorch.coerce_cinterpreter": TorchInGraphFunctionVariable, "torch._higher_order_ops.triton_kernel_wrap.do_prune_configs": UserFunctionVariable, @@ -349,7 +403,10 @@ "torch._dynamo.mark_static": UserFunctionVariable, "torch._dynamo.nonstrict_trace": UserFunctionVariable, "torch._dynamo.patch_dynamo_config": UserFunctionVariable, +<<<<<<< HEAD "torch._dynamo.error_on_graph_break": UserFunctionVariable, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable, "torch.fx.experimental.symbolic_shapes.guard_or_true": TorchInGraphFunctionVariable, "torch.fx.experimental.symbolic_shapes.guard_or_false": TorchInGraphFunctionVariable, @@ -586,6 +643,10 @@ "torch._C._dispatch_has_kernel", "torch._C._dispatch_is_alias_key", "torch._C._dispatch_is_included_in_alias", +<<<<<<< HEAD +======= + "torch._C._dispatch_is_main_interpreter", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "torch._C._dispatch_isTensorSubclassLike", "torch._C._dispatch_key_for_device", "torch._C._dispatch_key_name", @@ -661,7 +722,10 @@ "torch._C._get_cublas_allow_tf32", "torch._C._get_cudnn_allow_tf32", "torch._C._get_cudnn_benchmark", +<<<<<<< HEAD "torch._C._get_miopen_immediate", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "torch._C._get_cudnn_deterministic", "torch._C._get_cudnn_enabled", "torch._C._get_custom_class_python_wrapper", @@ -1947,7 +2011,10 @@ "torch.geqrf", "torch.ger", "torch.get_device", +<<<<<<< HEAD "torch.get_device_module", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "torch.gradient", "torch.greater_equal", "torch.greater", @@ -1961,7 +2028,10 @@ "torch.hamming_window", "torch.hann_window", "torch.hardshrink", +<<<<<<< HEAD "torch.hash_tensor", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "torch.heaviside", "torch.hinge_embedding_loss", "torch.histc", @@ -2368,11 +2438,15 @@ "torch._functorch.utils.enable_single_level_autograd_function", "torch._functorch.utils.exposed_in", "torch._functorch.utils.unwrap_dead_wrappers", +<<<<<<< HEAD "torch._functorch.predispatch.lazy_load_decompositions", "torch._functorch.predispatch._vmap_increment_nesting", "torch._functorch.predispatch._vmap_decrement_nesting", "torch._functorch.predispatch._add_batch_dim", "torch._functorch.predispatch._remove_batch_dim", +======= + "torch._functorch.vmap.lazy_load_decompositions", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "torch._guards.compile_context", "torch._guards.detect_fake_mode", "torch._guards.tracing", @@ -2417,6 +2491,10 @@ "torch._lowrank.svd_lowrank", "torch._preload_cuda_deps", "torch._register_device_module", +<<<<<<< HEAD +======= + "torch._running_with_deploy", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "torch._utils._dummy_type", "torch._utils._flatten_dense_tensors", "torch._utils._unflatten_dense_tensors", @@ -2690,6 +2768,10 @@ "torch.cuda.set_stream", "torch.cuda.set_sync_debug_mode", "torch.cuda.stream", +<<<<<<< HEAD +======= + "torch.cuda.synchronize", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "torch.cuda.temperature", "torch.cuda.utilization", "torch.einsum", @@ -2968,7 +3050,10 @@ "torch.xpu.random.seed_all", "torch.xpu.random.seed", "torch.xpu.set_stream", +<<<<<<< HEAD "torch.xpu.stream", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "torch.xpu.synchronize", ], TorchInGraphFunctionVariable, @@ -2995,6 +3080,7 @@ def get_torch_obj_rule_map() -> dict[Any, type["VariableTracker"]]: if ".py#" not in k: obj = load_object(k) else: +<<<<<<< HEAD torch_dir = _module_dir(torch) if torch_dir is None: continue @@ -3002,6 +3088,10 @@ def get_torch_obj_rule_map() -> dict[Any, type["VariableTracker"]]: if obj is not None: if is_lru_cache_wrapped_function(obj): obj = obj.__wrapped__ +======= + obj = _module_dir(torch) + k[len("torch/") :] + if obj is not None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if obj in d and d[obj] != v: raise AssertionError( f"Duplicate torch object {obj} with different rules: {v}, {d[obj]}" @@ -3011,7 +3101,11 @@ def get_torch_obj_rule_map() -> dict[Any, type["VariableTracker"]]: return d +<<<<<<< HEAD def _load_obj_from_str(fully_qualified_name: str) -> Any: +======= +def _load_obj_from_str(fully_qualified_name): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) module, obj_name = fully_qualified_name.rsplit(".", maxsplit=1) return getattr(importlib.import_module(module), obj_name) @@ -3021,7 +3115,11 @@ def _load_obj_from_str(fully_qualified_name: str) -> Any: """ +<<<<<<< HEAD def load_object(name: str) -> Any: +======= +def load_object(name): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: x = name.split("#") if len(x) == 2: @@ -3042,7 +3140,11 @@ def load_object(name: str) -> Any: @functools.cache +<<<<<<< HEAD def get_tensor_method() -> frozenset[Any]: +======= +def get_tensor_method(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) disallowed_tensor_methods = {"__new__", "_make_wrapper_subclass", "_make_subclass"} s = set() for name in dir(torch.Tensor): @@ -3071,7 +3173,11 @@ def get_tensor_method() -> frozenset[Any]: """ +<<<<<<< HEAD def is_aten_op_or_tensor_method(obj: Any) -> bool: +======= +def is_aten_op_or_tensor_method(obj): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return obj in get_tensor_method() or isinstance( obj, (torch._ops.OpOverloadPacket, torch._ops.OpOverload), @@ -3107,16 +3213,28 @@ def __call__(self) -> set[int]: self.function_ids = value return self.function_ids +<<<<<<< HEAD def get_name(self, idx: int, default: str) -> str: +======= + def get_name(self, idx: int, default: str): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self() # lazy init assert self.function_names is not None return self.function_names.get(idx, default) +<<<<<<< HEAD def add(self, idx: int) -> None: function_ids = self() # lazy init function_ids.add(idx) def remove(self, idx: int) -> None: +======= + def add(self, idx: int): + function_ids = self() # lazy init + function_ids.add(idx) + + def remove(self, idx: int): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) function_ids = self() if idx in function_ids: function_ids.remove(idx) @@ -3184,7 +3302,11 @@ def _numpy_function_ids() -> dict[int, str]: "sample", } +<<<<<<< HEAD def is_supported(k: str, v: Any, mod: Any) -> bool: +======= + def is_supported(k, v, mod): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not callable(v): return False if not getattr(v, "__module__", None): @@ -3243,53 +3365,93 @@ def _maybe_init_lazy_module(obj: object) -> None: fn() +<<<<<<< HEAD def is_callable_allowed(obj: Any) -> bool: +======= +def is_callable_allowed(obj) -> bool: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _maybe_init_lazy_module(obj) return id(obj) in _allowed_callable_ids +<<<<<<< HEAD def is_nonstrict_trace_callable(obj: Any) -> bool: +======= +def is_nonstrict_trace_callable(obj) -> bool: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _maybe_init_lazy_module(obj) return id(obj) in _nonstrict_trace_callable_ids +<<<<<<< HEAD def is_callable_disallowed(obj: Any) -> bool: +======= +def is_callable_disallowed(obj) -> bool: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _maybe_init_lazy_module(obj) return id(obj) in _disallowed_callable_ids +<<<<<<< HEAD def is_forbidden(obj: Any) -> bool: +======= +def is_forbidden(obj) -> bool: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _maybe_init_lazy_module(obj) return inspect.getattr_static(obj, "_dynamo_forbidden", False) +<<<<<<< HEAD def is_builtin_callable(obj: Any) -> bool: +======= +def is_builtin_callable(obj) -> bool: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # See also torch/_dynamo/polyfills/loader.py, which removes items in _builtin_function_ids return id(obj) in _builtin_function_ids +<<<<<<< HEAD def is_builtin_constant(obj: Any) -> bool: return id(obj) in _builtin_constant_ids def is_polyfilled_callable(obj: Any) -> bool: +======= +def is_builtin_constant(obj) -> bool: + return id(obj) in _builtin_constant_ids + + +def is_polyfilled_callable(obj) -> bool: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # See also @torch._dynamo.decorators.substitute_in_graph(...), which adds items in _polyfilled_function_ids return id(obj) in _polyfilled_function_ids +<<<<<<< HEAD def is_numpy(obj: Any) -> bool: +======= +def is_numpy(obj) -> bool: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if np is None: return False return isinstance(obj, (np.ndarray, np.generic)) or id(obj) in _numpy_function_ids +<<<<<<< HEAD def is_numpy_dtype(obj: Any) -> bool: +======= +def is_numpy_dtype(obj) -> bool: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if np is None: return False return isinstance(obj, np.dtype) +<<<<<<< HEAD def is_numpy_type_info(obj: Any) -> bool: +======= +def is_numpy_type_info(obj) -> bool: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if np is None: return False return isinstance(obj, (np.finfo, np.iinfo)) @@ -3297,6 +3459,10 @@ def is_numpy_type_info(obj: Any) -> bool: BUILTIN_SKIPLIST = ( abc, +<<<<<<< HEAD +======= + collections, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) copy, random, traceback, @@ -3326,7 +3492,11 @@ def is_numpy_type_info(obj: Any) -> bool: ) +<<<<<<< HEAD def _as_posix_path(path: str) -> str: +======= +def _as_posix_path(path): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) posix_path = Path(os.path.normpath(path)).as_posix() # os.path.normpath and pathlib.Path remove trailing slash, so we need to add it back if path.endswith((os.path.sep, "/")): @@ -3334,13 +3504,21 @@ def _as_posix_path(path: str) -> str: return posix_path +<<<<<<< HEAD def _strip_init_py(s: str) -> str: +======= +def _strip_init_py(s): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) suffix = "__init__.py" s = s.removesuffix(suffix) return _as_posix_path(s) +<<<<<<< HEAD def _module_dir(m: types.ModuleType) -> Optional[str]: +======= +def _module_dir(m: types.ModuleType): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Protect against a module not exporting __file__ - this can happen for # frozen modules, for example. file = getattr(m, "__file__", None) @@ -3403,7 +3581,10 @@ def _module_dir(m: types.ModuleType) -> Optional[str]: "torch._dynamo.compiled_autograd", "torch._dynamo.comptime", "torch._dynamo.polyfills", +<<<<<<< HEAD "torch._dynamo.test_case", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "torch._functorch._aot_autograd.subclass_parametrization", "torch._functorch.autograd_function", "torch._functorch.eager_transforms", @@ -3478,6 +3659,10 @@ def _module_dir(m: types.ModuleType) -> Optional[str]: "torch._custom_op", "torch._custom_ops", "torch._decomp", +<<<<<<< HEAD +======= + "torch._deploy", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "torch._dispatch", "torch._dynamo", "torch._export", @@ -3562,36 +3747,54 @@ def _module_dir(m: types.ModuleType) -> Optional[str]: @functools.cache +<<<<<<< HEAD def get_legacy_mod_inlinelist() -> set[str]: torch_dir = _module_dir(torch) if torch_dir is None: return set() inlinelist = { _as_posix_path(torch_dir + m[len("torch.") :].replace(".", "/")) +======= +def get_legacy_mod_inlinelist(): + inlinelist = { + _as_posix_path(_module_dir(torch) + m[len("torch.") :].replace(".", "/")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for m in LEGACY_MOD_INLINELIST } return inlinelist @functools.cache +<<<<<<< HEAD def get_mod_inlinelist() -> set[str]: torch_dir = _module_dir(torch) if torch_dir is None: return set() inlinelist = { _as_posix_path(torch_dir + m[len("torch.") :].replace(".", "/")) +======= +def get_mod_inlinelist(): + inlinelist = { + _as_posix_path(_module_dir(torch) + m[len("torch.") :].replace(".", "/")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for m in MOD_INLINELIST } return inlinelist @functools.cache +<<<<<<< HEAD def get_mod_skiplist() -> set[str]: torch_dir = _module_dir(torch) if torch_dir is None: return set() skiplist = { _as_posix_path(torch_dir + m[len("torch.") :].replace(".", "/")) +======= +def get_mod_skiplist(): + skiplist = { + _as_posix_path(_module_dir(torch) + m[len("torch.") :].replace(".", "/")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for m in MOD_SKIPLIST } return skiplist @@ -3648,14 +3851,22 @@ def get_mod_skiplist() -> set[str]: FORCE_SKIP_FILES = {f"{_module_dir(torch)}optim/lr_scheduler.py"} +<<<<<<< HEAD def _recompile_re() -> None: +======= +def _recompile_re(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) global SKIP_DIRS_RE SKIP_DIRS_RE = re.compile( rf"^[^\s<]*({'|'.join(re.escape(_as_posix_path(d)) for d in SKIP_DIRS)})" ) +<<<<<<< HEAD def add(import_name: str) -> None: +======= +def add(import_name: str): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(import_name, types.ModuleType): return add(import_name.__name__) assert isinstance(import_name, str) @@ -3677,7 +3888,11 @@ class SkipResult: reason: Optional[str] +<<<<<<< HEAD def check_file(filename: Optional[str], is_inlined_call: bool = False) -> SkipResult: +======= +def check_file(filename, is_inlined_call=False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Should skip this file?""" if filename is None: return SkipResult(True, "filename is None") @@ -3715,10 +3930,15 @@ def check_file(filename: Optional[str], is_inlined_call: bool = False) -> SkipRe ): return SkipResult(True, "FBCODE_SKIP_TORCHREC_DIRS") +<<<<<<< HEAD unittest_dir = _module_dir(unittest) if ( unittest_dir is not None and filename.startswith(unittest_dir) +======= + if ( + filename.startswith(_module_dir(unittest)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) and not torch._dynamo.config.enable_trace_unittest ): return SkipResult(True, "unittest") @@ -3773,7 +3993,11 @@ def f3(x, y): """ +<<<<<<< HEAD def check_verbose(obj: Any, is_inlined_call: bool = False) -> SkipResult: +======= +def check_verbose(obj, is_inlined_call=False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance( obj, ( @@ -3792,6 +4016,7 @@ def check_verbose(obj: Any, is_inlined_call: bool = False) -> SkipResult: elif isinstance(obj, types.CodeType): fi = FunctionInfo(None, obj.co_name, obj.co_filename, obj) elif isinstance(obj, (types.FunctionType, types.MethodType)): +<<<<<<< HEAD filename = getfile(obj) assert filename is not None fi = FunctionInfo( @@ -3804,11 +4029,24 @@ def check_verbose(obj: Any, is_inlined_call: bool = False) -> SkipResult: filename = getfile(obj) assert filename is not None fi = FunctionInfo(obj, None, filename, None) +======= + fi = FunctionInfo( + obj, + obj.__name__, + getfile(obj), + obj.__code__, # type: ignore[union-attr] # FIXME Add MethodType.__code__ to typeshed + ) + else: + fi = FunctionInfo(obj, None, getfile(obj), None) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Consulte the central trace rules defined in torch._dynamo.trace_rules. reasons: set[str] = set() rule = lookup_inner(fi.py_obj, fi.name, fi.filename, is_inlined_call, reasons) +<<<<<<< HEAD assert rule is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if issubclass( rule, ( @@ -3834,7 +4072,11 @@ def check_verbose(obj: Any, is_inlined_call: bool = False) -> SkipResult: ) +<<<<<<< HEAD def check(obj: Any, is_inlined_call: bool = False) -> bool: +======= +def check(obj, is_inlined_call=False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return check_verbose(obj, is_inlined_call).skipped @@ -3845,23 +4087,38 @@ def check(obj: Any, is_inlined_call: bool = False) -> bool: _recompile_re() +<<<<<<< HEAD def is_torch_inline_allowed(filename: str) -> bool: +======= +def is_torch_inline_allowed(filename): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return any(filename.startswith(d) for d in get_mod_inlinelist()) @functools.cache +<<<<<<< HEAD def dynamo_dir() -> Optional[str]: +======= +def dynamo_dir(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch._dynamo return _module_dir(torch._dynamo) +<<<<<<< HEAD def is_torch(filename: str) -> bool: dynamo_path = dynamo_dir() if dynamo_path is not None and filename.startswith(dynamo_path): return False torch_path = _module_dir(torch) return torch_path is not None and filename.startswith(torch_path) +======= +def is_torch(filename): + if filename.startswith(dynamo_dir()): + return False + return filename.startswith(_module_dir(torch)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ @@ -3869,7 +4126,11 @@ def is_torch(filename: str) -> bool: """ +<<<<<<< HEAD def lookup_callable(obj: Callable[..., Any]) -> Optional[type[VariableTracker]]: +======= +def lookup_callable(obj): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not hashable(obj): return None # Custom allow/disallow in graph takes precedence over the general lookup. @@ -3890,18 +4151,31 @@ def lookup_callable(obj: Callable[..., Any]) -> Optional[type[VariableTracker]]: """ +<<<<<<< HEAD def lookup(obj: Any) -> Optional[type[VariableTracker]]: +======= +def lookup(obj): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return lookup_inner(obj) # also takes config.dont_skip_tracing into account def lookup_inner( +<<<<<<< HEAD obj: Any, name: Optional[str] = None, filename: Optional[str] = None, is_direct_call: bool = True, reasons: Union[None, set[str]] = None, ) -> Optional[type[VariableTracker]]: +======= + obj, + name=None, + filename=None, + is_direct_call=True, + reasons: Union[None, set[str]] = None, +): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) result = _lookup_inner( obj, name=name, @@ -3916,6 +4190,7 @@ def lookup_inner( if config.dont_skip_tracing and result is SkipFunctionVariable: if filename is None: filename = getfile(obj) +<<<<<<< HEAD assert filename is not None filename = _as_posix_path(filename) torch_dir = _module_dir(torch) @@ -3925,6 +4200,14 @@ def lookup_inner( "test_dont_skip_tracing_functions.py" ): return SkipFunctionVariable +======= + filename = _as_posix_path(filename) + dynamo_path = _as_posix_path(_module_dir(torch)) + "_dynamo" + if filename.startswith(dynamo_path) and not filename.endswith( + "test_dont_skip_tracing_functions.py" + ): + return SkipFunctionVariable +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if reasons is not None: reasons.add( "Attempted skip but we are ignoring skips due to torch._dynamo.config.dont_skip_tracing" @@ -3934,12 +4217,21 @@ def lookup_inner( def _lookup_inner( +<<<<<<< HEAD obj: Any, name: Optional[str] = None, filename: Optional[str] = None, is_direct_call: bool = True, reasons: Optional[set[str]] = None, ) -> Optional[type[VariableTracker]]: +======= + obj, + name=None, + filename=None, + is_direct_call=True, + reasons: Union[None, set[str]] = None, +): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Step 1: lookup obj's tracing rule in `torch_name_rule_map`. # The rules defined in `torch_name_rule_map` mainly includes two parts: # - Manually defined rules for any functions. @@ -4013,7 +4305,11 @@ def _lookup_inner( filename = getfile(obj) skip_result = check_file(filename, is_direct_call) +<<<<<<< HEAD if reasons is not None and skip_result.reason is not None: +======= + if reasons is not None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) reasons.add(skip_result.reason) if skip_result.skipped: return SkipFunctionVariable @@ -4021,7 +4317,11 @@ def _lookup_inner( return UserFunctionVariable +<<<<<<< HEAD def clear_lru_cache() -> None: +======= +def clear_lru_cache(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch._dynamo.trace_rules.get_torch_obj_rule_map.cache_clear() torch._dynamo.trace_rules.get_tensor_method.cache_clear() torch._dynamo.trace_rules.get_legacy_mod_inlinelist.cache_clear() diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 058a66cf5b772..b91039c4aead8 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -1,3 +1,8 @@ +<<<<<<< HEAD +======= +# mypy: allow-untyped-defs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Utility functions and classes used throughout the TorchDynamo system. @@ -60,7 +65,11 @@ TypeVar, Union, ) +<<<<<<< HEAD from typing_extensions import Literal, ParamSpec, TypeAlias, TypeGuard, TypeIs +======= +from typing_extensions import Literal, TypeAlias, TypeGuard, TypeIs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch._functorch.config @@ -95,12 +104,16 @@ if typing.TYPE_CHECKING: from collections.abc import ( +<<<<<<< HEAD Container, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Generator, ItemsView, Iterable, Iterator, KeysView, +<<<<<<< HEAD Mapping, Sequence, ValuesView, @@ -114,6 +127,11 @@ from torch._dynamo.variables.base import VariableTracker from torch._prims_common import DeviceLikeType +======= + ValuesView, + ) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: import numpy as np @@ -153,8 +171,11 @@ T = TypeVar("T") +<<<<<<< HEAD R = TypeVar("R") _P = ParamSpec("_P") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) unpatched_nn_module_getattr = torch.nn.Module.__getattr__ unpatched_nn_module_call = torch.nn.Module.__call__ @@ -194,43 +215,71 @@ class ReinplaceCounters: # Track sizes of known not re-inplaced tensors (exclude dynamic shapes). @classmethod +<<<<<<< HEAD def add_missed_bytes(cls, trigger: ReInplaceTrigger, bytes: int) -> None: +======= + def add_missed_bytes(cls, trigger: ReInplaceTrigger, bytes: int): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if bytes != 0: cls._values[f"missed_bytes_{trigger.name}"] += bytes # Track number of not re-inplaced tensors. @classmethod +<<<<<<< HEAD def add_missed_opportunities(cls, trigger: ReInplaceTrigger, count: int) -> None: +======= + def add_missed_opportunities(cls, trigger: ReInplaceTrigger, count: int): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if count != 0: cls._values[f"missed_tensors_{trigger}"] += count @classmethod +<<<<<<< HEAD def clear(cls) -> None: cls._values.clear() @classmethod def get_total_missed(cls) -> int: +======= + def clear(cls): + cls._values.clear() + + @classmethod + def get_total_missed(cls): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sum = 0 for trigger in ReInplaceTrigger: sum += cls._values.get(f"missed_tensors_{trigger}", 0) return sum @classmethod +<<<<<<< HEAD def get_total_missed_bytes(cls) -> int: +======= + def get_total_missed_bytes(cls): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sum = 0 for trigger in ReInplaceTrigger: sum += cls._values.get(f"missed_bytes_{trigger.name}", 0) return sum @classmethod +<<<<<<< HEAD def log(cls) -> None: +======= + def log(cls): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # if not empty log. if cls._values: signpost_event("inductor", "reinplace_counters", cls._values) def tabulate( +<<<<<<< HEAD rows: Union[list[tuple[str, Any]], list[list[Any]]], +======= + rows: Union[list[tuple[str, object]], list[list[object]]], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) headers: Union[tuple[str, ...], list[str]], ) -> str: try: @@ -260,6 +309,7 @@ def reset_frame_count() -> None: curr_frame = 0 +<<<<<<< HEAD _recompile_user_contexts: Optional[list[Callable[[], str]]] = None @@ -281,6 +331,8 @@ def get_hook_for_recompile_user_context() -> Optional[list[Callable[[], str]]]: return _recompile_user_contexts +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) op_count = 0 @@ -395,7 +447,11 @@ def log_instant_event( metadata: dict[str, Any], time_ns: Optional[int] = None, log_level: CompileEventLogLevel = CompileEventLogLevel.CHROMIUM, +<<<<<<< HEAD ) -> None: +======= + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if time_ns is None: time_ns = time.time_ns() chromium_log = get_chromium_event_logger() @@ -417,7 +473,11 @@ def add_data( log_level: CompileEventLogLevel, overwrite: bool = False, **metadata: object, +<<<<<<< HEAD ) -> None: +======= + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Centralized API for adding data to various events Log an event to a toplevel "dynamo" event or metrics context @@ -460,7 +520,11 @@ def add_data( @staticmethod def add_toplevel( log_level: CompileEventLogLevel, overwrite: bool = False, **metadata: object +<<<<<<< HEAD ) -> None: +======= + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Syntactic sugar for logging to the toplevel event """ @@ -474,7 +538,11 @@ def add_toplevel( @staticmethod def increment( event_name: str, log_level: CompileEventLogLevel, key: str, value: int +<<<<<<< HEAD ) -> None: +======= + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Increments an existing field, or adds it """ @@ -507,7 +575,11 @@ def increment_toplevel( key: str, value: int = 1, log_level: CompileEventLogLevel = CompileEventLogLevel.COMPILATION_METRIC, +<<<<<<< HEAD ) -> None: +======= + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Increments a value on the toplevel metric. By default, logs to metric. """ @@ -522,7 +594,11 @@ def increment_toplevel( @staticmethod def add_to_set( event_name: str, log_level: CompileEventLogLevel, key: str, value: Any +<<<<<<< HEAD ) -> None: +======= + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Add metadata to a set of values with key . Creates a set if it doesn't exist. """ @@ -555,7 +631,11 @@ def add_to_set_toplevel( key: str, value: Any, log_level: CompileEventLogLevel = CompileEventLogLevel.COMPILATION_METRIC, +<<<<<<< HEAD ) -> None: +======= + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Same as add to set, just does it automatically to the toplevel event instead of having to explicitly name it. Defaults to COMPILATION_METRIC log level. @@ -571,7 +651,11 @@ def add_to_set_toplevel( # Helper functions that are syntactic sugar @staticmethod +<<<<<<< HEAD def chromium(event_name: str, **metadata: object) -> None: +======= + def chromium(event_name: str, **metadata: object): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Add to in chromium. Each key/value of metadata will appear in the chromium trace. should be the name of a timed event span passed to `dynamo_timed`. @@ -581,7 +665,11 @@ def chromium(event_name: str, **metadata: object) -> None: ) @staticmethod +<<<<<<< HEAD def pt2_compile(event_name: str, **metadata: object) -> None: +======= + def pt2_compile(event_name: str, **metadata: object): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Add to in chromium and PT2 Compile Events. Each key/value of metadata will appear in the chromium trace. Each kwarg name becomes @@ -594,7 +682,11 @@ def pt2_compile(event_name: str, **metadata: object) -> None: ) @staticmethod +<<<<<<< HEAD def compilation_metric(overwrite: bool = False, **metadata: object) -> None: +======= + def compilation_metric(overwrite: bool = False, **metadata: object): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Add to the CompilationMetrics context. Also logs to PT2 Compile Events and chromium. @@ -608,7 +700,11 @@ def compilation_metric(overwrite: bool = False, **metadata: object) -> None: @staticmethod def instant( event_name: str, metadata: dict[str, Any], time_ns: Optional[int] = None +<<<<<<< HEAD ) -> None: +======= + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Log an instant event to chromium logs with name at time . The `args` field in Perfetto will point to metadata. should be a value obtained from time.time_ns(). @@ -618,7 +714,11 @@ def instant( ) @staticmethod +<<<<<<< HEAD def try_add_pt2_compile(event_name: str, **metadata: object) -> None: +======= + def try_add_pt2_compile(event_name: str, **metadata: object): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Adds to an existing pt2_compile event, but silently returns if the event doesn't exist or ChromiumEventLogger is not initialized. @@ -630,7 +730,11 @@ def try_add_pt2_compile(event_name: str, **metadata: object) -> None: chromium_log.try_add_event_data(event_name, **metadata) @staticmethod +<<<<<<< HEAD def try_(method_fn: Callable[_P, Any], *args: _P.args, **kwargs: _P.kwargs) -> None: +======= + def try_(method_fn, *args, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Special function that quietly runs a given method, returning if CHROMIUM_EVENT_LOG is None or metrics context is not set """ @@ -801,9 +905,13 @@ def compile_times( ) -> tuple[list[str], list[object]]: ... +<<<<<<< HEAD def compile_times( # type: ignore[misc] repr: str = "str", aggregate: bool = False ) -> Union[str, None, tuple[list[str], list[str]]]: +======= +def compile_times(repr="str", aggregate: bool = False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Get metrics about torchdynamo frontend/backend compilation times. @@ -817,7 +925,11 @@ def compile_times( # type: ignore[misc] per metric. """ +<<<<<<< HEAD def fmt_fn(values: list[float], item_fn: Callable[[float], str] = str) -> str: +======= + def fmt_fn(values, item_fn=lambda x: x): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if aggregate: return item_fn(sum(values)) return ", ".join(map(item_fn, values)) @@ -864,8 +976,13 @@ def __init__(self, maxsize: int = 4096) -> None: self.maxsize = maxsize self.reset() +<<<<<<< HEAD def reset(self) -> None: self.set: OrderedDict[Any, Any] = OrderedDict() +======= + def reset(self): + self.set = OrderedDict() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def add(self, key: Union[str, tuple[object, object]]) -> bool: if key in self.set: @@ -882,7 +999,11 @@ def add(self, key: Union[str, tuple[object, object]]) -> bool: graph_break_dup_warning_checker = DuplicateWarningChecker() +<<<<<<< HEAD def setup_compile_debug() -> contextlib.ExitStack: +======= +def setup_compile_debug(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) compile_debug = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1" if compile_debug: @@ -895,7 +1016,11 @@ def reset_graph_break_dup_checker() -> None: graph_break_dup_warning_checker.reset() +<<<<<<< HEAD def add_file_handler() -> contextlib.ExitStack: +======= +def add_file_handler(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) log_path = os.path.join(get_debug_dir(), "torchdynamo") os.makedirs(log_path, exist_ok=True) @@ -908,7 +1033,11 @@ def add_file_handler() -> contextlib.ExitStack: return exitstack +<<<<<<< HEAD def setup_log_file() -> contextlib.ExitStack: +======= +def setup_log_file(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) exitstack = contextlib.ExitStack() if config.log_file_name is not None: log_file_handler = logging.FileHandler(config.log_file_name) @@ -920,12 +1049,20 @@ def setup_log_file() -> contextlib.ExitStack: return exitstack +<<<<<<< HEAD def gen_record_file_name(exc: Exception, code: CodeType) -> str: +======= +def gen_record_file_name(exc, code) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"{get_debug_dir()}/error_recordings/\ {code.co_name}_{type(exc).__name__}_{code.co_firstlineno}.rec" +<<<<<<< HEAD def write_record_to_file(filename: str, exec_record: ExecutionRecord) -> None: +======= +def write_record_to_file(filename: str, exec_record) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: if os.path.exists(filename): log.warning( @@ -951,7 +1088,11 @@ def identity(x: T) -> T: return x +<<<<<<< HEAD def hashable(x: Any) -> bool: +======= +def hashable(x): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: hash(x) return True @@ -962,13 +1103,18 @@ def hashable(x: Any) -> bool: return False +<<<<<<< HEAD def nothing(*args: Any, **kwargs: Any) -> None: +======= +def nothing(*args, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pass class ExactWeakKeyDictionary: """Similar to weakref.WeakKeyDictionary, but use `is`/`id` rather than `==` to compare equality""" +<<<<<<< HEAD def __init__(self) -> None: self.values: dict[int, Any] = {} self.refs: dict[int, weakref.ReferenceType[Any]] = {} @@ -983,18 +1129,42 @@ def __contains__(self, key: Any) -> bool: return id(key) in self.values def __setitem__(self, key: Any, value: Any) -> None: +======= + def __init__(self): + self.values = {} + self.refs = {} + + def __getitem__(self, key): + return self.values[id(key)] + + def get(self, key, default=None): + return self.values.get(id(key), default) + + def __contains__(self, key): + return id(key) in self.values + + def __setitem__(self, key, value): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) idx = id(key) if idx not in self.refs: self.refs[idx] = weakref.ref(key, lambda ref: self._remove_id(idx)) self.values[idx] = value +<<<<<<< HEAD def _remove_id(self, idx: int) -> None: +======= + def _remove_id(self, idx): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if idx in self.values: del self.values[idx] if idx in self.refs: del self.refs[idx] +<<<<<<< HEAD def clear(self) -> None: +======= + def clear(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.refs.clear() self.values.clear() @@ -1013,7 +1183,11 @@ def istype( def istype(obj: object, allowed_types: Iterable[type]) -> bool: ... +<<<<<<< HEAD def istype(obj: object, allowed_types: Any) -> bool: +======= +def istype(obj, allowed_types): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """isinstance() without subclasses""" if isinstance(allowed_types, (tuple, list, set)): return type(obj) in allowed_types @@ -1033,7 +1207,11 @@ def istype(obj: object, allowed_types: Any) -> bool: ) +<<<<<<< HEAD def is_typing(value: Any) -> bool: +======= +def is_typing(value): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # _Final catches most of typing classes: # - Any # - Callable @@ -1047,7 +1225,11 @@ def is_typing(value: Any) -> bool: return isinstance(value, typing._Final) or value is typing.Generic # type: ignore[attr-defined] +<<<<<<< HEAD def is_numpy_int_type(value: Any) -> bool: +======= +def is_numpy_int_type(value): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not np: return False @@ -1066,7 +1248,11 @@ def is_numpy_int_type(value: Any) -> bool: ) +<<<<<<< HEAD def is_numpy_float_type(value: Any) -> bool: +======= +def is_numpy_float_type(value): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not np: return False @@ -1178,11 +1364,19 @@ def is_wrapper_or_member_descriptor( ) +<<<<<<< HEAD def unwrap_if_wrapper(fn: Any) -> Any: return unwrap_with_attr_name_if_wrapper(fn)[0] def unwrap_with_attr_name_if_wrapper(fn: Any) -> tuple[Any, Optional[str]]: +======= +def unwrap_if_wrapper(fn): + return unwrap_with_attr_name_if_wrapper(fn)[0] + + +def unwrap_with_attr_name_if_wrapper(fn): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO(anijain2305) - Investigate if we can get rid of this function # unpack @torch._dynamo.optimize()(fn) wrapped function if is_function(fn) and inspect.getattr_static(fn, "_torchdynamo_inline", False): @@ -1193,14 +1387,22 @@ def unwrap_with_attr_name_if_wrapper(fn: Any) -> tuple[Any, Optional[str]]: return fn, attr_name +<<<<<<< HEAD def is_numpy_ndarray(value: Any) -> TypeGuard[np.ndarray]: # type: ignore[type-arg] +======= +def is_numpy_ndarray(value): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not np: return False return istype(value, np.ndarray) +<<<<<<< HEAD def istensor(obj: Any) -> bool: +======= +def istensor(obj): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Check of obj is a tensor""" tensor_list: tuple[type, ...] = ( torch.Tensor, @@ -1211,11 +1413,16 @@ def istensor(obj: Any) -> bool: return istype(obj, tensor_list) +<<<<<<< HEAD def is_lazy_module(mod: Any) -> bool: +======= +def is_lazy_module(mod): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return isinstance(mod, LazyModuleMixin) @functools.lru_cache(4096) +<<<<<<< HEAD def print_once(*args: Any) -> None: print(*args) @@ -1225,13 +1432,28 @@ def make_cell(val: Any = None) -> types.CellType: x = val def f() -> Any: +======= +def print_once(*args): + print(*args) + + +def make_cell(val=None): + """Some black magic to create a cell object that usually only exists in a closure""" + x = val + + def f(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return x assert f.__closure__ is not None and len(f.__closure__) == 1 return f.__closure__[0] +<<<<<<< HEAD def proxy_args_kwargs(args: Any, kwargs: Any) -> tuple[tuple[Any, ...], dict[str, Any]]: +======= +def proxy_args_kwargs(args, kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: proxy_args = tuple(arg.as_proxy() for arg in args) proxy_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()} @@ -1291,9 +1513,12 @@ class CompilationMetrics: compliant_custom_ops: Optional[set[str]] = None restart_reasons: Optional[set[str]] = None dynamo_time_before_restart_s: Optional[float] = None +<<<<<<< HEAD stack_trace: Optional[list[str]] = None exception_stack_trace: Optional[list[str]] = None graph_node_shapes: Optional[str] = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Sometimes, we will finish analyzing a frame but conclude we don't want # to install any guarded code. True means we actually decided to install # a compiled frame @@ -1362,11 +1587,17 @@ class CompilationMetrics: # The number of parameters counted by fields. This is mostly a proxy for # the number of distinct type of params. param_count: Optional[int] = None +<<<<<<< HEAD recompile_user_contexts: Optional[set[str]] = None inline_inbuilt_nn_modules_candidate: Optional[bool] = False @classmethod def create(cls, metrics: dict[str, Any]) -> CompilationMetrics: +======= + + @classmethod + def create(cls, metrics: dict[str, Any]): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Factory method to create a CompilationMetrics from a dict of fields. Includes the logic to add legacy fields and any pre-processing, e.g., @@ -1491,6 +1722,7 @@ def add_compilation_metrics_to_chromium(c: CompilationMetrics) -> None: fail_user_frame_filename=c.fail_user_frame_filename, fail_user_frame_lineno=c.fail_user_frame_lineno, # Sets aren't JSON serializable +<<<<<<< HEAD non_compliant_ops=( list(c.non_compliant_ops) if c.non_compliant_ops is not None else None ), @@ -1500,6 +1732,17 @@ def add_compilation_metrics_to_chromium(c: CompilationMetrics) -> None: restart_reasons=( list(c.restart_reasons) if c.restart_reasons is not None else None ), +======= + non_compliant_ops=list(c.non_compliant_ops) + if c.non_compliant_ops is not None + else None, + compliant_custom_ops=list(c.compliant_custom_ops) + if c.compliant_custom_ops is not None + else None, + restart_reasons=list(c.restart_reasons) + if c.restart_reasons is not None + else None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dynamo_time_before_restart_s=c.dynamo_time_before_restart_s, has_guarded_code=c.has_guarded_code, dynamo_config=c.dynamo_config, @@ -1549,7 +1792,11 @@ def _scrubbed_inductor_config_for_logging() -> Optional[str]: # TypeSafeSerializer for json.dumps() # Skips complex types as values in config dict class TypeSafeSerializer(json.JSONEncoder): +<<<<<<< HEAD def default(self, o: Any) -> Any: +======= + def default(self, o): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: return super().default(o) except Exception: @@ -1557,6 +1804,7 @@ def default(self, o: Any) -> Any: keys_to_scrub: set[Any] = set() inductor_conf_str = None +<<<<<<< HEAD inductor_config_copy = None if torch._inductor.config: @@ -1565,6 +1813,11 @@ def default(self, o: Any) -> Any: except (TypeError, AttributeError): inductor_conf_str = "Inductor Config cannot be pickled" +======= + inductor_config_copy = ( + torch._inductor.config.get_config_copy() if torch._inductor.config else None + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if inductor_config_copy is not None: try: for key, val in inductor_config_copy.items(): @@ -1595,7 +1848,11 @@ def record_compilation_metrics( metrics: dict[str, Any], exc_type: Optional[type[BaseException]], exc_value: Optional[BaseException], +<<<<<<< HEAD ) -> None: +======= +): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if torch._inductor.utils.should_use_remote_fx_graph_cache(): try: from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION @@ -1625,8 +1882,11 @@ def record_compilation_metrics( torch._logging.get_structured_logging_overhead() ), "dynamo_config": _get_dynamo_config_for_logging(), +<<<<<<< HEAD "config_suppress_errors": config.suppress_errors, "config_inline_inbuilt_nn_modules": config.inline_inbuilt_nn_modules, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "inductor_config": _scrubbed_inductor_config_for_logging(), "cuda_version": torch.version.cuda, "triton_version": triton.__version__ if has_triton() else "", @@ -1717,7 +1977,11 @@ def get_outermost_event(self) -> Optional[str]: stack = self.get_stack() return stack[0] if stack else None +<<<<<<< HEAD def get_pt2_compile_substack(self) -> list[str]: +======= + def get_pt2_compile_substack(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ A smaller subset of the main stack that gets used to log PT2 Compile Events internally. @@ -1733,6 +1997,7 @@ def get_event_data(self) -> dict[str, Any]: self.tls.event_data = {} return self.tls.event_data +<<<<<<< HEAD def __init__(self) -> None: self.tls = threading.local() @@ -1744,11 +2009,22 @@ def __init__(self) -> None: self.id_ = f"{config.pt2_compile_id_prefix}-{uuid.uuid4()}" else: self.id_ = str(uuid.uuid4()) +======= + def __init__(self): + self.tls = threading.local() + # Generate a unique id for this logger, which we can use in scuba to filter down + # to a single python run. + self.id_ = str(uuid.uuid4()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO: log to init/id tlparse after I add support for it log.info("ChromiumEventLogger initialized with id %s", self.id_) +<<<<<<< HEAD def try_add_event_data(self, event_name: str, **kwargs: Any) -> None: +======= + def try_add_event_data(self, event_name: str, **kwargs) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Same as add_event_data, but will silently not log if the event isn't in the stack. """ @@ -1759,7 +2035,11 @@ def try_add_event_data(self, event_name: str, **kwargs: Any) -> None: def add_event_data( self, event_name: str, +<<<<<<< HEAD **kwargs: Any, +======= + **kwargs, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: """ Adds additional metadata info to an in-progress event @@ -1776,7 +2056,11 @@ def add_event_data( event_data[event_name] = {} event_data[event_name].update(kwargs) +<<<<<<< HEAD def increment(self, event_name: str, key: str, value: int) -> None: +======= + def increment(self, event_name: str, key: str, value: int): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Increment an integer event data field by the given amount """ @@ -1799,7 +2083,11 @@ def add_to_set( event_name: str, key: str, value: Any, +<<<<<<< HEAD ) -> None: +======= + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Add a value to a set within a event_name's metadata if it exists """ @@ -1895,7 +2183,11 @@ def log_event_end( event_metadata, ) +<<<<<<< HEAD def pop_stack(stack: list[str]) -> None: +======= + def pop_stack(stack): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) while event_name != stack[-1]: # If the event isn't the most recent one to end, pop # off the stack until it is. @@ -2056,14 +2348,22 @@ class CleanupHook: scope: dict[str, Any] name: str +<<<<<<< HEAD def __call__(self, *args: Any) -> None: +======= + def __call__(self, *args): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Make sure we're not shutting down if CleanupManager is not None: CleanupManager.count -= 1 del self.scope[self.name] @staticmethod +<<<<<<< HEAD def create(scope: dict[str, Any], name: str, val: Any) -> CleanupHook: +======= + def create(scope, name, val): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert name not in scope CleanupManager.count += 1 scope[name] = val @@ -2074,7 +2374,11 @@ class CleanupManager(ExactWeakKeyDictionary): count = 0 instance: ClassVar[CleanupManager] +<<<<<<< HEAD def _remove_id(self, idx: int) -> None: +======= + def _remove_id(self, idx): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for hook in self.values[idx]: hook() super()._remove_id(idx) @@ -2083,7 +2387,11 @@ def _remove_id(self, idx: int) -> None: CleanupManager.instance = CleanupManager() +<<<<<<< HEAD def clone_tensor(x: torch.Tensor) -> torch.Tensor: +======= +def clone_tensor(x): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Clone the tensor and its gradient""" y = x.clone().requires_grad_(x.requires_grad) if x.is_leaf and x.grad is not None: @@ -2091,16 +2399,24 @@ def clone_tensor(x: torch.Tensor) -> torch.Tensor: return y +<<<<<<< HEAD def clone_input( x: torch.Tensor, *, dtype: Optional[torch.dtype] = None ) -> torch.Tensor: +======= +def clone_input(x, *, dtype=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """copy while preserving strides""" # TODO: this is questionable if is_fake(x): # this func fails on fake tensors in __torch_dispatch__ return x +<<<<<<< HEAD def torch_clone(x: torch.Tensor) -> torch.Tensor: +======= + def torch_clone(x): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) y = torch.clone(x) if x.is_leaf: y.requires_grad_(x.requires_grad) @@ -2167,6 +2483,7 @@ def torch_clone(x: torch.Tensor) -> torch.Tensor: return result +<<<<<<< HEAD @overload def clone_inputs( example_inputs: dict[str, Union[T, tuple[T, ...]]], @@ -2179,6 +2496,10 @@ def clone_inputs(example_inputs: Sequence[T]) -> list[T]: ... def clone_inputs(example_inputs: Any) -> Any: res: Union[dict[str, Any], list[Any]] +======= +def clone_inputs(example_inputs): + res: Union[dict[Any, Any], list[Any]] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if type(example_inputs) is dict: res = dict(example_inputs) for key, value in res.items(): @@ -2196,7 +2517,11 @@ def clone_inputs(example_inputs: Any) -> Any: return res +<<<<<<< HEAD def skip_frame_if_in_functorch_mode(val: torch.Tensor) -> None: +======= +def skip_frame_if_in_functorch_mode(val: torch.Tensor): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: val.data_ptr() # will throw for functorch tensors except RuntimeError as e: @@ -2210,7 +2535,11 @@ def skip_frame_if_in_functorch_mode(val: torch.Tensor) -> None: @contextmanager +<<<<<<< HEAD def preserve_rng_state() -> Generator[None, None, None]: +======= +def preserve_rng_state(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) disable_functorch = torch._C._DisableFuncTorch disable_current_modes = torch.utils._python_dispatch._disable_current_modes with disable_current_modes(), disable_functorch(): @@ -2228,6 +2557,7 @@ def preserve_rng_state() -> Generator[None, None, None]: def is_jit_model( +<<<<<<< HEAD model0: Any, ) -> TypeIs[ Union[ @@ -2237,6 +2567,10 @@ def is_jit_model( torch.jit.ScriptModule, ] ]: +======= + model0, +): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return isinstance( model0, ( @@ -2248,7 +2582,11 @@ def is_jit_model( ) +<<<<<<< HEAD def torchscript(model: Any, example_inputs: Any, verbose: bool = False) -> Any: +======= +def torchscript(model, example_inputs, verbose=False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if is_jit_model(model): # already done? return model @@ -2266,19 +2604,31 @@ def torchscript(model: Any, example_inputs: Any, verbose: bool = False) -> Any: return None +<<<<<<< HEAD def getfile(obj: Any) -> Optional[str]: +======= +def getfile(obj): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: return inspect.getfile(obj) except (TypeError, OSError): return None +<<<<<<< HEAD def is_namedtuple(obj: Any) -> bool: +======= +def is_namedtuple(obj): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Test if an object is a namedtuple or a torch.return_types.* quasi-namedtuple""" return is_namedtuple_cls(type(obj)) +<<<<<<< HEAD def is_namedtuple_cls(cls: Any) -> bool: +======= +def is_namedtuple_cls(cls): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Test if an object is a namedtuple or a (torch.return_types|torch.autograd.forward_ad).* quasi-namedtuple""" try: if issubclass(cls, tuple): @@ -2309,7 +2659,11 @@ def is_namedtuple_cls(cls: Any) -> bool: @functools.lru_cache(1) +<<<<<<< HEAD def namedtuple_fields(cls: type) -> tuple[str, ...]: +======= +def namedtuple_fields(cls) -> tuple[str, ...]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Get the fields of a namedtuple or a torch.return_types.* quasi-namedtuple""" if cls is slice: return ("start", "stop", "step") @@ -2325,16 +2679,28 @@ class Marker: # frustrating ones e.g. torch.return_types.max assert cls.__module__ == "torch.return_types" +<<<<<<< HEAD obj = cls(map(Marker, range(cls.n_fields))) # type: ignore[attr-defined] +======= + obj = cls(map(Marker, range(cls.n_fields))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fields: dict[str, int] = {} for name in dir(obj): if name[0] != "_" and isinstance(getattr(obj, name), Marker): fields[name] = getattr(obj, name).index +<<<<<<< HEAD assert len(fields) == cls.n_fields # type: ignore[attr-defined] return tuple(sorted(fields, key=fields.get)) # type: ignore[arg-type] def checkpoint_params(gm: torch.fx.GraphModule) -> Callable[[], None]: +======= + assert len(fields) == cls.n_fields + return tuple(sorted(fields, key=fields.get)) # type: ignore[arg-type] + + +def checkpoint_params(gm): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with torch.no_grad(): rng_state = torch.clone(torch.random.get_rng_state()) if torch.cuda.is_available(): @@ -2344,7 +2710,11 @@ def checkpoint_params(gm: torch.fx.GraphModule) -> Callable[[], None]: for param in itertools.chain(gm.parameters(), gm.buffers()) ] +<<<<<<< HEAD def restore() -> None: +======= + def restore(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with torch.no_grad(): torch.random.set_rng_state(rng_state) if torch.cuda.is_available(): @@ -2356,9 +2726,13 @@ def restore() -> None: return restore +<<<<<<< HEAD def timed( model: Any, example_inputs: Iterable[Any], times: int = 1 ) -> tuple[Any, float]: +======= +def timed(model, example_inputs, times=1): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if torch.cuda.is_available(): synchronize = torch.cuda.synchronize else: @@ -2375,12 +2749,20 @@ def timed( return result, t1 - t0 # type: ignore[possibly-undefined] +<<<<<<< HEAD def check_is_cuda(gm: torch.fx.GraphModule, example_inputs: Iterable[Any]) -> bool: +======= +def check_is_cuda(gm, example_inputs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return all(x.is_cuda for x in itertools.chain(example_inputs, gm.parameters(True))) @lru_cache(32) +<<<<<<< HEAD def rot_n_helper(n: int) -> Callable[..., Any]: +======= +def rot_n_helper(n): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert n > 1 vars = [f"v{i}" for i in range(n)] rotated = reversed(vars[-1:] + vars[:-1]) @@ -2424,7 +2806,11 @@ def rot_n_helper(n: int) -> Callable[..., Any]: """ +<<<<<<< HEAD def is_safe_constant(v: Any) -> bool: +======= +def is_safe_constant(v): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if istype(v, (tuple, frozenset)): return all(map(is_safe_constant, v)) return isinstance( @@ -2443,7 +2829,11 @@ def is_safe_constant(v: Any) -> bool: @functools.cache +<<<<<<< HEAD def common_constants() -> set[int]: +======= +def common_constants(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return { # We zero-one specialize shapes, so specialize these constants # too @@ -2458,7 +2848,11 @@ def is_torch_sym(value: Any) -> TypeGuard[Union[torch.SymBool, torch.SymInt]]: ) +<<<<<<< HEAD def is_int_specialization_case(value: Any, source: Any) -> bool: +======= +def is_int_specialization_case(value, source): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .source import is_from_defaults return not TracingContext.get().force_unspec_int_unbacked_size_like and ( @@ -2489,7 +2883,11 @@ def is_int_specialization_case(value: Any, source: Any) -> bool: ) +<<<<<<< HEAD def specialize_symnode(arg: Any) -> Any: +======= +def specialize_symnode(arg): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .variables import ConstantVariable, LazyVariableTracker, SymNodeVariable # Guard and specialize @@ -2514,7 +2912,11 @@ def specialize_symnode(arg: Any) -> Any: return arg +<<<<<<< HEAD def guard_if_dyn(arg: Any) -> Any: +======= +def guard_if_dyn(arg): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .variables import ConstantVariable arg = specialize_symnode(arg) @@ -2525,11 +2927,19 @@ def guard_if_dyn(arg: Any) -> Any: return arg +<<<<<<< HEAD def check_constant_args(args: Iterable[Any], kwargs: Mapping[Any, Any]) -> bool: return all(x.is_python_constant() for x in itertools.chain(args, kwargs.values())) def check_unspec_python_args(args: Iterable[Any], kwargs: Mapping[Any, Any]) -> bool: +======= +def check_constant_args(args, kwargs): + return all(x.is_python_constant() for x in itertools.chain(args, kwargs.values())) + + +def check_unspec_python_args(args, kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .variables.constant import ConstantVariable from .variables.tensor import UnspecializedPythonVariable @@ -2542,9 +2952,13 @@ def check_unspec_python_args(args: Iterable[Any], kwargs: Mapping[Any, Any]) -> return unspec_count > 0 +<<<<<<< HEAD def check_unspec_or_constant_args( args: Iterable[Any], kwargs: Mapping[Any, Any] ) -> bool: +======= +def check_unspec_or_constant_args(args, kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # A fused version of: # return check_constant_args(args, kwargs) or check_unspec_python_args(args, kwargs) from .variables.tensor import UnspecializedPythonVariable @@ -2555,7 +2969,11 @@ def check_unspec_or_constant_args( return True +<<<<<<< HEAD def check_numpy_ndarray_args(args: Iterable[Any], kwargs: Mapping[Any, Any]) -> bool: +======= +def check_numpy_ndarray_args(args, kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .variables.tensor import NumpyNdarrayVariable return any( @@ -2578,10 +2996,13 @@ def check_numpy_ndarray_args(args: Iterable[Any], kwargs: Mapping[Any, Any]) -> for method in itertools.chain(dict.__dict__.values(), OrderedDict.__dict__.values()) if callable(method) } +<<<<<<< HEAD set_methods = {method for method in set.__dict__.values() if callable(method)} frozenset_methods = { method for method in frozenset.__dict__.values() if callable(method) } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tuple_new = tuple.__new__ tuple_methods = {method for method in tuple.__dict__.values() if callable(method)} @@ -2590,17 +3011,26 @@ def check_numpy_ndarray_args(args: Iterable[Any], kwargs: Mapping[Any, Any]) -> str_methods = {method for method in str.__dict__.values() if callable(method)} +<<<<<<< HEAD K = TypeVar("K") V = TypeVar("V") def builtin_dict_keys(d: dict[K, V]) -> KeysView[K]: +======= + +def builtin_dict_keys(d): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Avoids overridden keys method of the dictionary assert isinstance(d, dict) return dict.keys(d) +<<<<<<< HEAD def get_items_from_dict(obj: dict[K, V]) -> Iterable[tuple[K, Union[V, Any]]]: +======= +def get_items_from_dict(obj): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Get items without calling the user defined __getitem__ or keys method. assert isinstance(obj, dict) if istype(obj, (dict, OrderedDict)): @@ -2611,28 +3041,45 @@ def get_items_from_dict(obj: dict[K, V]) -> Iterable[tuple[K, Union[V, Any]]]: return [(k, dict.__getitem__(obj, k)) for k in dict.keys(obj)] +<<<<<<< HEAD def nn_module_new(cls: Any) -> Any: +======= +def nn_module_new(cls): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) obj = object_new(cls) torch.nn.Module.__init__(obj) return obj +<<<<<<< HEAD def product(it: Iterable[T]) -> int: return functools.reduce(operator.mul, it, 1) def tuple_iterator_getitem(it: Any, index: int) -> Any: +======= +def product(it): + return functools.reduce(operator.mul, it, 1) + + +def tuple_iterator_getitem(it, index): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _, (obj,), start = it.__reduce__() return obj[start + index] +<<<<<<< HEAD def dataclass_fields(cls: Any) -> Any: +======= +def dataclass_fields(cls): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return torch._dynamo.disable(dataclasses.fields)(cls) iter_next = next +<<<<<<< HEAD def normalize_range_iter(range_iter: Any) -> tuple[int, int, int]: _, (range_obj,), maybe_idx = range_iter.__reduce__() # In 3.12+, `maybe_idx` could be None, and `range_obj.start` would've been @@ -2642,19 +3089,34 @@ def normalize_range_iter(range_iter: Any) -> tuple[int, int, int]: # start. See: # https://github.com/python/cpython/blob/ea77feecbba389916af8f90b2fc77f07910a2963/Objects/rangeobject.c#L885-L899 start = range_obj.start + (maybe_idx or 0) * range_obj.step +======= +def normalize_range_iter(range_iter) -> tuple[int, int, int]: + _, (range_obj,), maybe_idx = range_iter.__reduce__() + # In 3.12+, `maybe_idx` could be None, and `range_obj.start` would've been + # already incremented by the current index. + start = range_obj.start + (maybe_idx or 0) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) stop = range_obj.stop step = range_obj.step return (start, stop, step) +<<<<<<< HEAD def to_subclass(t: Any, cls: type) -> Any: +======= +def to_subclass(t, cls): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return t.as_subclass(cls) dict_getitem = dict.__getitem__ +<<<<<<< HEAD def dict_keys_getitem(d: dict[Any, Any], n: int) -> Any: +======= +def dict_keys_getitem(d, n): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Call dict(d) to prevent calling overridden __iter__/keys dict_class = dict if isinstance(d, OrderedDict): @@ -2662,12 +3124,16 @@ def dict_keys_getitem(d: dict[Any, Any], n: int) -> Any: return next(itertools.islice(dict_class.keys(d), n, n + 1)) +<<<<<<< HEAD def set_getitem(s: set[T], n: int) -> T: # Set ordering might not be stable return list(s)[n] def enum_repr(value: Any, local: bool) -> str: +======= +def enum_repr(value, local): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # enum class can override __str__ method. Use __class__ and name attribute # to extract the class name and key name. name = value.__class__.__name__ @@ -2677,7 +3143,11 @@ def enum_repr(value: Any, local: bool) -> str: return local_name +<<<<<<< HEAD def set_example_value(node: torch.fx.Node, example_value: Any) -> None: +======= +def set_example_value(node, example_value): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # NB: example_value is a bit of a misnomer, because this is always a fake # tensor of some sort. Furthermore, these example values serve as the # runtime state of Dynamo tracing, which means if metadata mutation @@ -2685,9 +3155,13 @@ def set_example_value(node: torch.fx.Node, example_value: Any) -> None: # this to accurately reflect what the state of the value was at the time # the program was traced). node.meta["example_value"] = example_value +<<<<<<< HEAD fake_mode = TracingContext.get().fake_mode assert fake_mode is not None shape_env = fake_mode.shape_env +======= + shape_env = TracingContext.get().fake_mode.shape_env +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if ( symbol_to_path := torch.fx.experimental.symbolic_shapes.compute_unbacked_bindings( @@ -2697,7 +3171,11 @@ def set_example_value(node: torch.fx.Node, example_value: Any) -> None: node.meta["unbacked_bindings"] = symbol_to_path +<<<<<<< HEAD def _get_fake_tensor(vt: VariableTracker) -> Any: +======= +def _get_fake_tensor(vt): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fake_tensor = vt.as_proxy().node.meta.get("example_value") if not is_fake(fake_tensor): from . import graph_break_hints @@ -2712,6 +3190,7 @@ def _get_fake_tensor(vt: VariableTracker) -> Any: return fake_tensor +<<<<<<< HEAD def slice_length(s: slice, seq_len: int) -> int: start, stop, step = s.indices(seq_len) return max(0, (stop - start + (step - (1 if step > 0 else -1))) // step) @@ -2736,6 +3215,16 @@ def iter_contains( ) -> Any: from .variables import BuiltinVariable, ConstantVariable, TensorVariable +======= +def iter_contains(items, search, tx, check_tensor_identity=False): + from .variables import ( + BuiltinVariable, + ConstantVariable, + TensorVariable, + VariableTracker, + ) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if search.is_python_constant(): found_const = any( x.is_python_constant() @@ -2776,11 +3265,19 @@ def key_is_id( return isinstance(k, (torch.Tensor, torch.nn.Module, MethodWrapperType)) +<<<<<<< HEAD def key_to_id(value: Any) -> list[Any]: return [id(k) if key_is_id(k) else k for k in value.keys()] def const_repr(x: Any, *, local: Any) -> str: +======= +def key_to_id(value): + return [id(k) if key_is_id(k) else k for k in value.keys()] + + +def const_repr(x, *, local) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .trace_rules import is_builtin_callable if isinstance(x, (list, tuple)): @@ -2801,7 +3298,11 @@ def const_repr(x: Any, *, local: Any) -> str: return x.__name__ elif isinstance(x, type): +<<<<<<< HEAD def fullname(o: Any) -> str: +======= + def fullname(o): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) klass = o.__class__ module = klass.__module__ if module == "builtins": @@ -2813,7 +3314,11 @@ def fullname(o: Any) -> str: return f"{x!r}" +<<<<<<< HEAD def dict_keys_repr(const_keys: Any, *, local: Any) -> str: +======= +def dict_keys_repr(const_keys, *, local) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) keys_str = ",".join(const_repr(s, local=local) for s in const_keys) return "[" + keys_str + "]" @@ -2824,7 +3329,11 @@ def dict_keys_repr(const_keys: Any, *, local: Any) -> str: from torch._subclasses import UnsupportedFakeTensorException # noqa: F401 +<<<<<<< HEAD def get_safe_global_name(tx: InstructionTranslatorBase, root: str, obj: Any) -> str: +======= +def get_safe_global_name(tx, root, obj): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # The global_mangled_class_name should be different for different # invocations of torch.compile. Otherwise, we can run into a situation # where multiple torch.compile invocations reuse the same global name, @@ -2834,16 +3343,24 @@ def get_safe_global_name(tx: InstructionTranslatorBase, root: str, obj: Any) -> return f"{root}_{id(obj)}_c{tx.output.compile_id}" +<<<<<<< HEAD def is_in(item: T, *containers: Container[T]) -> bool: +======= +def is_in(item: Any, *containers) -> bool: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for container in containers: if item in container: return True return False +<<<<<<< HEAD def get_unique_name_wrt( prefix: str, *containers: Any, requires_suffix: bool = False ) -> str: +======= +def get_unique_name_wrt(prefix: str, *containers, requires_suffix=False) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Return a name that starts with `prefix` and is not in any of the `containers` (e.g., map, set). @@ -2859,7 +3376,11 @@ def get_unique_name_wrt( raise AssertionError("unreachable") +<<<<<<< HEAD def wrap_fake_exception(fn: Callable[[], Any]) -> Any: +======= +def wrap_fake_exception(fn): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: return fn() except UnsupportedFakeTensorException as e: @@ -2876,14 +3397,22 @@ def wrap_fake_exception(fn: Callable[[], Any]) -> Any: ) +<<<<<<< HEAD def deepcopy_to_fake_tensor( obj: Any, fake_mode: torch._subclasses.fake_tensor.FakeTensorMode ) -> Any: +======= +def deepcopy_to_fake_tensor(obj, fake_mode): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with torch._subclasses.fake_tensor.FakeCopyMode(fake_mode): return wrap_fake_exception(lambda: copy.deepcopy(obj)) +<<<<<<< HEAD def rmse(ref: torch.Tensor, res: torch.Tensor) -> torch.Tensor: +======= +def rmse(ref, res): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Calculate root mean squared error """ @@ -2891,6 +3420,7 @@ def rmse(ref: torch.Tensor, res: torch.Tensor) -> torch.Tensor: def same( +<<<<<<< HEAD ref: Any, res: Any, fp64_ref: Any = None, @@ -2904,6 +3434,21 @@ def same( use_larger_multiplier_for_smaller_tensor: bool = False, force_max_multiplier: bool = False, ) -> bool: +======= + ref, + res, + fp64_ref=None, + cos_similarity=False, + tol=1e-4, + equal_nan=False, + exact_dtype=True, + relax_numpy_equality=False, + ignore_non_fp=False, + log_error=log.error, + use_larger_multiplier_for_smaller_tensor=False, + force_max_multiplier: bool = False, +): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Check correctness to see if ref and res match""" if fp64_ref is None: fp64_ref = ref @@ -2984,7 +3529,11 @@ def same( assert not isinstance(ref, torch._subclasses.FakeTensor) assert not isinstance(res, torch._subclasses.FakeTensor) +<<<<<<< HEAD def to_tensor(t: Any) -> torch.Tensor: +======= + def to_tensor(t): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return t if isinstance(t, torch.Tensor) else torch.tensor(t) ref, res, fp64_ref = (to_tensor(val) for val in (ref, res, fp64_ref)) @@ -3023,7 +3572,11 @@ def to_tensor(t: Any) -> torch.Tensor: score = torch.nn.functional.cosine_similarity(ref, res, dim=0, eps=1e-6) if score < 0.99: log.warning("Similarity score=%s", score.detach().cpu().item()) +<<<<<<< HEAD return bool(score >= 0.99) +======= + return score >= 0.99 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: if not exact_dtype: ref = ref.to(res.dtype) @@ -3063,7 +3616,11 @@ def to_tensor(t: Any) -> torch.Tensor: res_error = rmse(fp64_ref, res).item() +<<<<<<< HEAD def get_multiplier() -> float: +======= + def get_multiplier(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # In some particular cases, we expect high difference in results. # At the moment one of this cases is inductor freezing bfloat16 convolution const folding. # In case of it the res_error is at least one order of magnitude higher. @@ -3194,13 +3751,21 @@ def get_multiplier() -> float: raise RuntimeError(f"unsupported type: {type(ref).__name__}") +<<<<<<< HEAD def format_func_info(code: CodeType) -> str: +======= +def format_func_info(code): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) short_filename = code.co_filename.split("/")[-1] return f"'{code.co_name}' ({short_filename}:{code.co_firstlineno})" @contextlib.contextmanager +<<<<<<< HEAD def disable_cache_limit() -> Generator[None, None, None]: +======= +def disable_cache_limit(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) prior = config.recompile_limit config.recompile_limit = sys.maxsize prior_acc_limit = config.accumulated_recompile_limit @@ -3229,7 +3794,11 @@ def disable_cache_limit() -> Generator[None, None, None]: # return same dir unless user changes config between calls @functools.cache +<<<<<<< HEAD def _get_debug_dir(root_dir: str) -> str: +======= +def _get_debug_dir(root_dir): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dir_name = ( "run_" + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f") @@ -3240,12 +3809,20 @@ def _get_debug_dir(root_dir: str) -> str: return os.path.join(root_dir, dir_name) +<<<<<<< HEAD def get_debug_dir() -> str: +======= +def get_debug_dir(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) debug_root = config.debug_dir_root return _get_debug_dir(debug_root) +<<<<<<< HEAD def extract_fake_example_value(node: torch.fx.Node, required: bool = True) -> Any: +======= +def extract_fake_example_value(node, required=True): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if "example_value" in node.meta and is_fake(node.meta["example_value"]): return node.meta["example_value"] elif required: @@ -3263,15 +3840,24 @@ def extract_fake_example_value(node: torch.fx.Node, required: bool = True) -> An return None +<<<<<<< HEAD def ensure_graph_fake(e: Any, tx: InstructionTranslatorBase) -> Any: +======= +def ensure_graph_fake(e, tx): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert maybe_get_fake_mode(e) is tx.fake_mode return e +<<<<<<< HEAD def get_fake_values_from_nodes( tx: InstructionTranslatorBase, nodes: Any, allow_non_graph_fake: bool ) -> Any: def visit(n: torch.fx.Node) -> Any: +======= +def get_fake_values_from_nodes(tx, nodes, allow_non_graph_fake): + def visit(n: torch.fx.Node): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if n.op == "call_function" and "example_value" not in n.meta: # fake tensor validity is checked inside get_fake_value using # ensure_graph_fake @@ -3279,7 +3865,11 @@ def visit(n: torch.fx.Node) -> Any: elif n.op == "get_attr" and "example_value" not in n.meta: assert n.target in tx.output.nn_modules +<<<<<<< HEAD gm = tx.output.nn_modules[n.target] # type: ignore[index] +======= + gm = tx.output.nn_modules[n.target] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(gm, torch.fx.GraphModule) return gm @@ -3291,11 +3881,15 @@ def visit(n: torch.fx.Node) -> Any: return torch.fx.node.map_arg(nodes, visit) +<<<<<<< HEAD def get_fake_value( node: torch.fx.Node, tx: InstructionTranslatorBase, allow_non_graph_fake: bool = False, ) -> Any: +======= +def get_fake_value(node, tx, allow_non_graph_fake=False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Run the computation represented by `node` using fake tensors and return the result. @@ -3339,6 +3933,7 @@ def get_fake_value( id_to_initial_version = {} nnmodule = None +<<<<<<< HEAD fake_mode = tx.fake_mode assert fake_mode is not None if op == "call_method" and len(args) > 0 and isinstance(args[0], torch.nn.Module): @@ -3347,6 +3942,14 @@ def get_fake_value( if op == "call_module": nnmodule = tx.output.nn_modules[node.target] # type: ignore[index] +======= + if op == "call_method" and len(args) > 0 and isinstance(args[0], torch.nn.Module): + # If the first argument is nn.Module, should copy to fake mode. + args = (deepcopy_to_fake_tensor(args[0], tx.fake_mode),) + tuple(args[1:]) + + if op == "call_module": + nnmodule = tx.output.nn_modules[node.target] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if is_lazy_module(nnmodule) and hasattr(nnmodule, "_initialize_hook"): # In the case of a lazy module, we want to run @@ -3356,23 +3959,37 @@ def get_fake_value( nnmodule._infer_parameters(nnmodule, args) # no matter it's lazy module or not, we should copy to fake mode. +<<<<<<< HEAD nnmodule = deepcopy_to_fake_tensor(nnmodule, fake_mode) +======= + nnmodule = deepcopy_to_fake_tensor(nnmodule, tx.fake_mode) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if node.name in ["interpolate", "is_integer", "wrapped_gradient"] or any( isinstance(a, complex) for a in args ): # We need to specialize symfloats for now. Eventually we should do a tensorify pass in dynamo. args = tuple( +<<<<<<< HEAD ( float(arg) if isinstance(arg, torch.SymFloat) and arg.node.hint is not None else arg ) +======= + float(arg) + if isinstance(arg, torch.SymFloat) and arg.node.hint is not None + else arg +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for arg in args ) try: +<<<<<<< HEAD with fake_mode, enable_python_dispatcher(): +======= + with tx.fake_mode, enable_python_dispatcher(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ret_val = wrap_fake_exception( lambda: run_node(tx.output, node, args, kwargs, nnmodule) ) @@ -3434,7 +4051,11 @@ def get_fake_value( elif isinstance( cause, torch._subclasses.fake_tensor.UnsupportedOperatorException ): +<<<<<<< HEAD op = cause.func # type: ignore[assignment] +======= + op = cause.func +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import_suggestion = "" if isinstance(op, torch._ops.OpOverload): maybe_pystub = torch._C._dispatch_pystub( @@ -3498,12 +4119,20 @@ def get_fake_value( _current_node = threading.local() +<<<<<<< HEAD def get_current_node() -> Optional[torch.fx.Node]: +======= +def get_current_node(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return getattr(_current_node, "value", None) @contextmanager +<<<<<<< HEAD def set_current_node(node: torch.fx.Node) -> Generator[None, None, None]: +======= +def set_current_node(node): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) old = get_current_node() _current_node.value = node try: @@ -3512,9 +4141,13 @@ def set_current_node(node: torch.fx.Node) -> Generator[None, None, None]: _current_node.value = old +<<<<<<< HEAD def run_node( tracer: Any, node: torch.fx.Node, args: Any, kwargs: Any, nnmodule: Any ) -> Any: +======= +def run_node(tracer, node, args, kwargs, nnmodule): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Runs a given node, with the given args and kwargs. @@ -3533,7 +4166,11 @@ def run_node( with set_current_node(node): +<<<<<<< HEAD def make_error_message(e: Any) -> str: +======= + def make_error_message(e): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ( f"Dynamo failed to run FX node with fake tensors: {op} {node.target}(*{args}, **{kwargs}): got " + repr(e) @@ -3543,9 +4180,15 @@ def make_error_message(e: Any) -> str: try: if op == "call_function": +<<<<<<< HEAD return node.target(*args, **kwargs) # type: ignore[operator] elif op == "call_method": if not hasattr(args[0], node.target): # type: ignore[arg-type] +======= + return node.target(*args, **kwargs) + elif op == "call_method": + if not hasattr(args[0], node.target): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .exc import unimplemented_v2 unimplemented_v2( @@ -3554,7 +4197,11 @@ def make_error_message(e: Any) -> str: explanation=make_error_message("attribute not defined"), hints=[], ) +<<<<<<< HEAD return getattr(args[0], node.target)(*args[1:], **kwargs) # type: ignore[arg-type] +======= + return getattr(args[0], node.target)(*args[1:], **kwargs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif op == "call_module": assert nnmodule is not None return nnmodule(*args, **kwargs) @@ -3591,7 +4238,11 @@ def make_error_message(e: Any) -> str: raise AssertionError(op) +<<<<<<< HEAD def get_real_value(node: torch.fx.Node, tracer: Any) -> Any: +======= +def get_real_value(node, tracer): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Run the actual computation represented by `node` and return the result. This will execute any dependent nodes in the graph as well. @@ -3630,10 +4281,17 @@ def get_real_value(node: torch.fx.Node, tracer: Any) -> Any: return real_value +<<<<<<< HEAD def assert_no_fake_params_or_buffers(gm: torch.fx.GraphModule) -> None: from torch._subclasses.fake_tensor import FakeTensorConfig, is_fake def stack_or_hint(t: Any) -> str: +======= +def assert_no_fake_params_or_buffers(gm): + from torch._subclasses.fake_tensor import FakeTensorConfig, is_fake + + def stack_or_hint(t): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if FakeTensorConfig.debug: import traceback @@ -3651,21 +4309,33 @@ def stack_or_hint(t: Any) -> str: ) +<<<<<<< HEAD def fqn(obj: Any) -> str: +======= +def fqn(obj: Any): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Returns the fully qualified name of the object. """ return f"{obj.__module__}.{obj.__qualname__}" +<<<<<<< HEAD def ifdynstaticdefault(count1: Any, count2: Any) -> Any: +======= +def ifdynstaticdefault(count1, count2): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if torch._dynamo.config.assume_static_by_default: return count1 else: return count2 +<<<<<<< HEAD def import_submodule(mod: types.ModuleType) -> None: +======= +def import_submodule(mod: types.ModuleType): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Ensure all the files in a given submodule are imported """ @@ -3674,17 +4344,29 @@ def import_submodule(mod: types.ModuleType) -> None: importlib.import_module(f"{mod.__name__}.{filename[:-3]}") +<<<<<<< HEAD def object_has_getattribute(value: Any) -> bool: return class_has_getattribute(type(value)) def object_setattr_ignore_descriptor(obj: Any, name: str, value: Any) -> None: +======= +def object_has_getattribute(value: Any): + return class_has_getattribute(type(value)) + + +def object_setattr_ignore_descriptor(obj, name, value): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # https://github.com/python/cpython/blob/3.11/Objects/object.c#L1286-L1335 d = object.__getattribute__(obj, "__dict__") d[name] = value +<<<<<<< HEAD def class_has_getattribute(cls: type) -> bool: +======= +def class_has_getattribute(cls: type): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: if isinstance( inspect.getattr_static(cls, "__getattribute__"), @@ -3696,9 +4378,13 @@ def class_has_getattribute(cls: type) -> bool: return False +<<<<<<< HEAD def get_custom_getattr( value: Any, ignore_nn_module_getattr: bool = False ) -> Optional[Any]: +======= +def get_custom_getattr(value: Any, ignore_nn_module_getattr: bool = False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: getattr_fn = inspect.getattr_static(type(value), "__getattr__") except AttributeError: @@ -3715,7 +4401,11 @@ class TensorStaticReason(enum.Enum): NN_MODULE_PROPERTY = 5 +<<<<<<< HEAD def tensor_static_reason_to_message(reason: TensorStaticReason) -> str: +======= +def tensor_static_reason_to_message(reason: TensorStaticReason): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if reason == TensorStaticReason.PARAMETER: return "mark_dynamic on parameter, parameters are always static today." if reason == TensorStaticReason.NOT_TENSOR: @@ -3759,8 +4449,13 @@ def tensor_always_has_static_shape( return False, None +<<<<<<< HEAD def lazy_format_graph_tabular(fn_name: str, gm: torch.fx.GraphModule) -> Any: def inner() -> str: +======= +def lazy_format_graph_tabular(fn_name, gm): + def inner(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: from tabulate import tabulate # TODO: Check that this is installed except ImportError: @@ -3780,9 +4475,13 @@ def inner() -> str: return LazyString(inner) +<<<<<<< HEAD def format_bytecode( prefix: str, name: str, filename: str, line_no: int, code: Any ) -> str: +======= +def format_bytecode(prefix, name, filename, line_no, code): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"{prefix} {name} {filename} line {line_no} \n{dis.Bytecode(code).dis()}\n" @@ -3797,21 +4496,37 @@ def format_bytecode( all_hook_names = forward_hook_names + backward_hook_names + state_dict_hook_names +<<<<<<< HEAD def nn_module_has_global_hooks() -> bool: # This is limited to backward hooks for now because NNModuleVariable # supports fwd hooks underneath. return bool( len(torch.nn.modules.module._global_backward_hooks) or len(torch.nn.modules.module._global_backward_pre_hooks) +======= +def nn_module_has_global_hooks(): + # This is limited to backward hooks for now because NNModuleVariable + # supports fwd hooks underneath. + return len(torch.nn.modules.module._global_backward_hooks) or len( + torch.nn.modules.module._global_backward_pre_hooks +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def nn_module_get_all_hooks( +<<<<<<< HEAD mod: torch.nn.Module, check_forward_hooks: bool = False, check_backward_hooks: bool = False, check_state_dict_hooks: bool = False, ) -> list[Any]: +======= + mod, + check_forward_hooks=False, + check_backward_hooks=False, + check_state_dict_hooks=False, +): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Sometimes its useful to differentiate between types of hooks such as forward/backward/pre hooks executed during module.__call__, and state_dict hooks which are executed separately. @@ -3840,11 +4555,19 @@ def nn_module_get_all_hooks( def nnmodule_has_hooks( +<<<<<<< HEAD mod: torch.nn.Module, check_forward_hooks: bool = False, check_backward_hooks: bool = False, check_state_dict_hooks: bool = False, ) -> bool: +======= + mod, + check_forward_hooks=False, + check_backward_hooks=False, + check_state_dict_hooks=False, +): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Helper function to check if a module has any hooks attached to it. """ @@ -3857,7 +4580,11 @@ def nnmodule_has_hooks( return bool(hooks) +<<<<<<< HEAD def to_numpy_helper(value: Any) -> Any: +======= +def to_numpy_helper(value): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Convert tensor and tnp.ndarray to numpy.ndarray.""" if is_fake(value): return value @@ -3871,7 +4598,11 @@ def to_numpy_helper(value: Any) -> Any: return value +<<<<<<< HEAD def numpy_to_tensor(value: Any) -> Any: +======= +def numpy_to_tensor(value): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Convert tnp.ndarray to tensor, leave other types intact. If a list/tuple, loop through it to convert.""" assert np is not None if isinstance(value, np.ndarray): @@ -3884,20 +4615,33 @@ def numpy_to_tensor(value: Any) -> Any: return value +<<<<<<< HEAD class numpy_to_tensor_wrapper(Generic[_P, R]): def __init__(self, f: Callable[_P, R]) -> None: +======= +class numpy_to_tensor_wrapper: + def __init__(self, f): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.f = f self.__name__ = "wrapped_" + self.f.__name__ def __repr__(self) -> str: return f">" +<<<<<<< HEAD def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> Any: +======= + def __call__(self, *args, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out = self.f(*args, **kwargs) return numpy_to_tensor(out) +<<<<<<< HEAD def numpy_attr_wrapper(obj: Any, name: str) -> Any: +======= +def numpy_attr_wrapper(obj, name): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(obj, tnp.ndarray): out = getattr(obj, name) return numpy_to_tensor(out) @@ -3909,14 +4653,22 @@ def numpy_attr_wrapper(obj: Any, name: str) -> Any: class numpy_method_wrapper: """Convert obj from torch.Tensor to tnp.ndarray and call method. Then convert result back to torch.Tensor.""" +<<<<<<< HEAD def __init__(self, method: str) -> None: +======= + def __init__(self, method: str): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.method = method self.__name__ = "wrapped_" + self.method def __repr__(self) -> str: return f">" +<<<<<<< HEAD def __call__(self, *args: Any, **kwargs: Any) -> Any: +======= + def __call__(self, *args, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) obj = args[0] if isinstance(obj, torch.Tensor): obj = tnp.ndarray(obj) @@ -3925,17 +4677,28 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: return numpy_to_tensor(out) +<<<<<<< HEAD class numpy_operator_wrapper(Generic[_P, R]): """Implements dunder methods for tnp.ndarray via functions from the operator library""" def __init__(self, op: Callable[..., Any]) -> None: +======= +class numpy_operator_wrapper: + """Implements dunder methods for tnp.ndarray via functions from the operator library""" + + def __init__(self, op: Callable[..., Any]): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.op = op self.__name__ = f"wrapped_{op.__name__}" def __repr__(self) -> str: return f">" +<<<<<<< HEAD def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> Any: +======= + def __call__(self, *args, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert not kwargs args = ( @@ -3945,7 +4708,11 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> Any: return numpy_to_tensor(out) +<<<<<<< HEAD def defake(x: Any) -> Any: +======= +def defake(x): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not isinstance(x, FakeTensor): return x size: torch._prims_common.ShapeType @@ -3977,6 +4744,7 @@ def defake(x: Any) -> Any: return y +<<<<<<< HEAD def _disable_side_effect_safety_checks_for_current_subtracer( fn: Callable[_P, R], *args: _P.args, **kwargs: _P.kwargs ) -> R: @@ -3984,19 +4752,34 @@ def _disable_side_effect_safety_checks_for_current_subtracer( def is_utils_checkpoint(obj: Any) -> bool: +======= +def _disable_side_effect_safety_checks_for_current_subtracer(fn, *args, **kwargs): + return fn(*args, **kwargs) + + +def is_utils_checkpoint(obj): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Lazy import to avoid circular dependencies import torch.utils.checkpoint return obj is torch.utils.checkpoint.checkpoint +<<<<<<< HEAD def is_invoke_subgraph(obj: Any) -> bool: +======= +def is_invoke_subgraph(obj): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._higher_order_ops.invoke_subgraph import invoke_subgraph_placeholder return obj is invoke_subgraph_placeholder +<<<<<<< HEAD def build_invoke_subgraph_variable(**options: Any) -> Any: +======= +def build_invoke_subgraph_variable(**options): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .variables.higher_order_ops import TorchHigherOrderOperatorVariable return TorchHigherOrderOperatorVariable.make( @@ -4005,7 +4788,11 @@ def build_invoke_subgraph_variable(**options: Any) -> Any: ) +<<<<<<< HEAD def build_checkpoint_variable(**options: Any) -> Any: +======= +def build_checkpoint_variable(**options): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch._higher_order_ops.wrap as higher_order_ops from .variables.higher_order_ops import TorchHigherOrderOperatorVariable @@ -4024,14 +4811,22 @@ def build_checkpoint_variable(**options: Any) -> Any: ) +<<<<<<< HEAD def is_compile_supported(device_type: DeviceLikeType) -> Any: +======= +def is_compile_supported(device_type): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .eval_frame import is_dynamo_supported type = torch.device(device_type).type compile_supported = is_dynamo_supported() if type == "cpu": pass +<<<<<<< HEAD elif type in ["cuda", "xpu", "mtia"] and compile_supported: +======= + elif type in ["cuda", "xpu"] and compile_supported: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) compile_supported = has_triton() else: compile_supported = False @@ -4090,12 +4885,20 @@ def _extract_anchors_from_expr(segment: str) -> Optional[_Anchors]: lines = segment.split("\n") # get character index given byte offset +<<<<<<< HEAD def normalize(lineno: int, offset: int) -> int: +======= + def normalize(lineno, offset): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return _fix_offset(lines[lineno], offset) # Gets the next valid character index in `lines`, if # the current location is not valid. Handles empty lines. +<<<<<<< HEAD def next_valid_char(lineno: int, col: int) -> tuple[int, int]: +======= + def next_valid_char(lineno, col): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) while lineno < len(lines) and col >= len(lines[lineno]): col = 0 lineno += 1 @@ -4103,14 +4906,22 @@ def next_valid_char(lineno: int, col: int) -> tuple[int, int]: return lineno, col # Get the next valid character index in `lines`. +<<<<<<< HEAD def increment(lineno: int, col: int) -> tuple[int, int]: +======= + def increment(lineno, col): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) col += 1 lineno, col = next_valid_char(lineno, col) assert lineno < len(lines) and col < len(lines[lineno]) return lineno, col # Get the next valid character at least on the next line +<<<<<<< HEAD def nextline(lineno: int, col: int) -> tuple[int, int]: +======= + def nextline(lineno, col): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) col = 0 lineno += 1 lineno, col = next_valid_char(lineno, col) @@ -4127,7 +4938,10 @@ def nextline(lineno: int, col: int) -> tuple[int, int]: # -2 since end_lineno is 1-indexed and because we added an extra # bracket to `segment` when calling ast.parse cur_lineno = cast(int, expr.left.end_lineno) - 2 +<<<<<<< HEAD assert expr.left.end_col_offset is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cur_col = normalize(cur_lineno, expr.left.end_col_offset) cur_lineno, cur_col = next_valid_char(cur_lineno, cur_col) @@ -4160,14 +4974,20 @@ def nextline(lineno: int, col: int) -> tuple[int, int]: # subscript^^^^^^^^^^^^^^^^^^^^ # find left bracket (first '[' after value) left_lineno = cast(int, expr.value.end_lineno) - 2 +<<<<<<< HEAD assert expr.value.end_col_offset is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) left_col = normalize(left_lineno, expr.value.end_col_offset) left_lineno, left_col = next_valid_char(left_lineno, left_col) while lines[left_lineno][left_col] != "[": left_lineno, left_col = increment(left_lineno, left_col) # find right bracket (final character of expression) right_lineno = cast(int, expr.end_lineno) - 2 +<<<<<<< HEAD assert expr.end_col_offset is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) right_col = normalize(right_lineno, expr.end_col_offset) return _Anchors(left_lineno, left_col, right_lineno, right_col) elif isinstance(expr, ast.Call): @@ -4176,14 +4996,20 @@ def nextline(lineno: int, col: int) -> tuple[int, int]: # call^^^^^^^^^^^^^^^^^^^^^^^^ # find left bracket (first '(' after func) left_lineno = cast(int, expr.func.end_lineno) - 2 +<<<<<<< HEAD assert expr.func.end_col_offset is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) left_col = normalize(left_lineno, expr.func.end_col_offset) left_lineno, left_col = next_valid_char(left_lineno, left_col) while lines[left_lineno][left_col] != "(": left_lineno, left_col = increment(left_lineno, left_col) # find right bracket (final character of expression) right_lineno = cast(int, expr.end_lineno) - 2 +<<<<<<< HEAD assert expr.end_col_offset is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) right_col = normalize(right_lineno, expr.end_col_offset) return _Anchors(left_lineno, left_col, right_lineno, right_col) @@ -4322,14 +5148,22 @@ def get_instruction_source_311(code: types.CodeType, inst: dis.Instruction) -> s return result +<<<<<<< HEAD def get_static_address_type(t: Any) -> Any: +======= +def get_static_address_type(t): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(t, torch.Tensor): return getattr(t, "_dynamo_static_input_type", None) return None +<<<<<<< HEAD def is_rng_state_getter_or_setter(value: Any) -> bool: +======= +def is_rng_state_getter_or_setter(value): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) getters = ( # The following two functions are not identical, so don't remove anyone! torch._C.Generator.get_state, @@ -4346,7 +5180,11 @@ def is_rng_state_getter_or_setter(value: Any) -> bool: return value in (*setters, *getters) +<<<<<<< HEAD def is_tensor_base_attr_getter(value: Any) -> bool: +======= +def is_tensor_base_attr_getter(value): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ( isinstance(value, types.MethodWrapperType) and value.__name__ == "__get__" @@ -4354,7 +5192,11 @@ def is_tensor_base_attr_getter(value: Any) -> bool: ) +<<<<<<< HEAD def is_tensor_getset_descriptor(name: str) -> bool: +======= +def is_tensor_getset_descriptor(name): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: attr = inspect.getattr_static(torch.Tensor, name) return type(attr) is types.GetSetDescriptorType @@ -4362,11 +5204,19 @@ def is_tensor_getset_descriptor(name: str) -> bool: return False +<<<<<<< HEAD def is_torch_function_object(value: Any) -> bool: return hasattr(value, "__torch_function__") def has_torch_function(vt: VariableTracker) -> bool: +======= +def is_torch_function_object(value): + return hasattr(value, "__torch_function__") + + +def has_torch_function(vt: torch._dynamo.variables.base.VariableTracker) -> bool: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This emulates # https://github.com/pytorch/pytorch/blob/8d81806211bc3c0ee6c2ef235017bacf1d775a85/torch/csrc/utils/disable_torch_function.cpp#L315-L323 from torch._dynamo.variables import UserDefinedObjectVariable @@ -4396,9 +5246,13 @@ def has_torch_function(vt: VariableTracker) -> bool: # see note [Tensor Fakification and Symbol Caching] +<<<<<<< HEAD def to_fake_tensor( t: torch.Tensor, fake_mode: torch._subclasses.fake_tensor.FakeTensorMode ) -> Any: +======= +def to_fake_tensor(t, fake_mode): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) symbolic_context = None source = None if tracing_context := torch._guards.TracingContext.try_get(): @@ -4412,7 +5266,11 @@ def to_fake_tensor( # NB: this works for both classes and instances +<<<<<<< HEAD def is_frozen_dataclass(value: Any) -> bool: +======= +def is_frozen_dataclass(value): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ( not object_has_getattribute(value) and not class_has_getattribute(value) @@ -4423,7 +5281,11 @@ def is_frozen_dataclass(value: Any) -> bool: ) +<<<<<<< HEAD def get_first_attr(obj: Any, *attrs: str) -> Any: +======= +def get_first_attr(obj, *attrs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Return the first available attribute or throw an exception if none is present. """ @@ -4435,15 +5297,24 @@ def get_first_attr(obj: Any, *attrs: str) -> Any: @contextlib.contextmanager +<<<<<<< HEAD def maybe_enable_compiled_autograd( should_enable: bool, fullgraph: bool = True, dynamic: bool = True ) -> Generator[Any, None, None]: +======= +def maybe_enable_compiled_autograd(should_enable, fullgraph=True, dynamic=True): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not should_enable: yield else: +<<<<<<< HEAD def compiler_fn(gm: Any) -> Any: def inner_compiler(gm_: Any, example_inputs_: Any) -> Any: +======= + def compiler_fn(gm): + def inner_compiler(gm_, example_inputs_): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch._dynamo.utils.counters["compiled_autograd"]["compiles"] += 1 return torch._inductor.compile(gm_, example_inputs_) @@ -4455,7 +5326,11 @@ def inner_compiler(gm_: Any, example_inputs_: Any) -> Any: yield ctx +<<<<<<< HEAD def invalid_removeable_handle() -> RemovableHandle: +======= +def invalid_removeable_handle(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # need a subclass so weakref works class Invalid(dict): # type: ignore[type-arg] pass @@ -4467,7 +5342,11 @@ class Invalid(dict): # type: ignore[type-arg] # Attribute changes to the original object/proxy will be reflected in the other. # This is useful for cases where we want a keep-alive reference to a module without increasing # its reference count. +<<<<<<< HEAD def nn_module_proxy(mod: Any) -> Any: +======= +def nn_module_proxy(mod): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not isinstance(mod, torch.nn.Module): return mod if isinstance(mod, torch.fx.GraphModule): @@ -4479,21 +5358,33 @@ def nn_module_proxy(mod: Any) -> Any: class GmWrapper(torch.nn.Module): +<<<<<<< HEAD def __init__( self, gm: torch.fx.GraphModule, unflatten_fn: Callable[[list[Any]], Any] ) -> None: +======= + def __init__(self, gm, unflatten_fn): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__() self.gm = gm self.unflatten_fn = unflatten_fn +<<<<<<< HEAD def forward(self, *args: Any) -> Any: +======= + def forward(self, *args): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) args: list[Any] = list(args) return self.gm(*self.unflatten_fn(args)) +<<<<<<< HEAD def flatten_graph_inputs( gm: torch.fx.GraphModule, inputs: Any, compile_gm: Callable[[Any, Any], Any] ) -> Callable[..., Any]: +======= +def flatten_graph_inputs(gm: torch.fx.GraphModule, inputs, compile_gm): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Mutate inputs so that they are flat and wrap gm such that it accepts those inputs. This is needed for graphs that take @@ -4512,10 +5403,17 @@ def flatten_graph_inputs( assert isinstance(inputs[0], list) boxed_inputs_count = len(inputs[0]) +<<<<<<< HEAD def flatten_fn(args: Any) -> Any: return args[0] + list(args[1:]) def unflatten_fn(flat_args: Any) -> Any: +======= + def flatten_fn(args): + return args[0] + list(args[1:]) + + def unflatten_fn(flat_args): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return (flat_args[:boxed_inputs_count], *flat_args[boxed_inputs_count:]) compiled_fn = compile_gm(GmWrapper(gm, unflatten_fn), flatten_fn(inputs)) @@ -4527,7 +5425,11 @@ def unflatten_fn(flat_args: Any) -> Any: # note this doesn't check the spec, assuming it is the same flatten_fn = pytree.arg_tree_leaves +<<<<<<< HEAD def wrapper(*args: Any) -> Any: +======= + def wrapper(*args): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) flat_args = flatten_fn(args) # flat_args is a new list, so we need to clear references from the old list @@ -4540,18 +5442,30 @@ def wrapper(*args: Any) -> Any: return wrapper +<<<<<<< HEAD def get_locals_to_steal(maybe_gm: Any) -> list[Any]: +======= +def get_locals_to_steal(maybe_gm): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not isinstance(maybe_gm, torch.fx.GraphModule) or not hasattr(maybe_gm, "meta"): return [] return maybe_gm.meta.get("locals_to_steal", []) +<<<<<<< HEAD def set_locals_to_steal(gm: torch.fx.GraphModule, locals_to_steal: list[Any]) -> None: +======= +def set_locals_to_steal(gm, locals_to_steal): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) gm.meta["locals_to_steal"] = locals_to_steal class Lit: +<<<<<<< HEAD def __init__(self, s: str) -> None: +======= + def __init__(self, s): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.s = s def __repr__(self) -> str: @@ -4561,7 +5475,11 @@ def __repr__(self) -> str: warn_once_cache: set[str] = set() +<<<<<<< HEAD def warn_once(msg: str, stacklevel: int = 1) -> None: +======= +def warn_once(msg, stacklevel=1): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Dynamo causes all warnings.warn (in user code and in Dynamo code) to print all the time. # https://github.com/pytorch/pytorch/issues/128427. # warn_once is a workaround: if the msg has been warned on before, then we will not @@ -4573,14 +5491,22 @@ def warn_once(msg: str, stacklevel: int = 1) -> None: warnings.warn(msg, stacklevel=stacklevel + 1) +<<<<<<< HEAD def strip_color_from_string(text: str) -> str: +======= +def strip_color_from_string(text): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This regular expression matches ANSI escape codes ansi_escape = re.compile(r"\x1B[@-_][0-?]*[ -/]*[@-~]") return ansi_escape.sub("", text) @contextlib.contextmanager +<<<<<<< HEAD def _disable_saved_tensors_hooks_during_tracing() -> Generator[None, None, None]: +======= +def _disable_saved_tensors_hooks_during_tracing(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # See NOTE: [Deferring tensor pack/unpack hooks until runtime] try: prior = torch._C._autograd._saved_tensors_hooks_set_tracing(True) @@ -4589,22 +5515,38 @@ def _disable_saved_tensors_hooks_during_tracing() -> Generator[None, None, None] torch._C._autograd._saved_tensors_hooks_set_tracing(prior) +<<<<<<< HEAD def is_parameter_freezing() -> bool: return torch._inductor.config.freezing and not torch.is_grad_enabled() def get_torch_function_mode_stack() -> list[Any]: +======= +def is_parameter_freezing(): + return torch._inductor.config.freezing and not torch.is_grad_enabled() + + +def get_torch_function_mode_stack(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return [ get_torch_function_mode_stack_at(i) for i in range(_len_torch_function_stack()) ] +<<<<<<< HEAD def get_torch_function_mode_stack_at(ind: int) -> Any: +======= +def get_torch_function_mode_stack_at(ind): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert ind < _len_torch_function_stack() and ind >= 0 return torch._C._get_function_stack_at(ind) +<<<<<<< HEAD def set_torch_function_mode_stack(stack: list[Any]) -> None: +======= +def set_torch_function_mode_stack(stack): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for _ in range(_len_torch_function_stack()): _pop_torch_function_stack() @@ -4612,17 +5554,29 @@ def set_torch_function_mode_stack(stack: list[Any]) -> None: _push_on_torch_function_stack(mode) +<<<<<<< HEAD def clear_torch_function_mode_stack() -> None: +======= +def clear_torch_function_mode_stack(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for _ in range(_len_torch_function_stack()): _pop_torch_function_stack() # call from C dynamo in order to inspect values in pdb +<<<<<<< HEAD def _breakpoint_for_c_dynamo(*args: Any) -> None: breakpoint() def verify_guard_fn_signature(value: Any) -> None: +======= +def _breakpoint_for_c_dynamo(*args): + breakpoint() + + +def verify_guard_fn_signature(value): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fn = value.__metadata_guard__ sig = inspect.signature(fn) if len(sig.parameters) != 2: @@ -4639,7 +5593,11 @@ def verify_guard_fn_signature(value: Any) -> None: ) +<<<<<<< HEAD def does_not_override_dict_iter_methods(user_cls: Any) -> bool: +======= +def does_not_override_dict_iter_methods(user_cls): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ( user_cls.items in (dict.items, OrderedDict.items) and user_cls.values in (dict.values, OrderedDict.values) @@ -4652,23 +5610,39 @@ def does_not_override_dict_iter_methods(user_cls: Any) -> bool: # __torch_function__ calls triggered on tensor properties in the pre graph # bytecode. @torch._disable_dynamo +<<<<<<< HEAD def call_size(x: Any, i: int) -> int: +======= +def call_size(x, i): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return x.size(i) @torch._disable_dynamo +<<<<<<< HEAD def call_stride(x: Any, i: int) -> int: +======= +def call_stride(x, i): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return x.stride(i) @torch._disable_dynamo +<<<<<<< HEAD def call_storage_offset(x: Any) -> int: +======= +def call_storage_offset(x): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return x.storage_offset() # Helper function to extract relevant parts of a tensor's __dict__ to store in node meta. # To avoid ref cycles, it's important that no tensors are present here, so leave those out. +<<<<<<< HEAD def _extract_tensor_dict(t: torch.Tensor) -> dict[str, Any]: +======= +def _extract_tensor_dict(t): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) KEYS_TO_COPY = [ "_dynamo_static_input_type", "tag", @@ -4687,13 +5661,21 @@ def _extract_tensor_dict(t: torch.Tensor) -> dict[str, Any]: user_obj_id_to_weakref: dict[int, weakref.ReferenceType[object]] = {} +<<<<<<< HEAD def get_user_object_from_id(obj_id: int) -> Any: +======= +def get_user_object_from_id(obj_id): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) obj = user_obj_id_to_weakref[obj_id]() assert obj is not None, "User object is no longer alive" return obj +<<<<<<< HEAD def store_user_object_weakref(obj: object) -> None: +======= +def store_user_object_weakref(obj): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) obj_id = id(obj) user_obj_id_to_weakref[obj_id] = weakref.ref(obj) @@ -4726,7 +5708,11 @@ def value(cls) -> int: @classmethod @contextmanager +<<<<<<< HEAD def record(cls) -> Generator[None, None, None]: +======= + def record(cls): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: if config.record_compile_time_instruction_count: cls.start() @@ -4736,12 +5722,16 @@ def record(cls) -> Generator[None, None, None]: cls.end() +<<<<<<< HEAD class CompileCounterInt(int): def __add__(self, other: Any) -> CompileCounterInt: return CompileCounterInt(super().__add__(other)) def set_feature_use(feature: str, usage: bool) -> None: +======= +def set_feature_use(feature: str, usage: bool): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Records whether we are using a feature Generally a feature is a JK. @@ -4759,7 +5749,11 @@ def set_feature_use(feature: str, usage: bool) -> None: ) +<<<<<<< HEAD def get_optimize_ddp_mode() -> str: +======= +def get_optimize_ddp_mode(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) optimize_ddp = config.optimize_ddp if isinstance(optimize_ddp, bool): mode = "ddp_optimizer" if optimize_ddp else "no_optimization" @@ -4818,6 +5812,7 @@ def is_node_meta_valid(node: Optional[torch.fx.Node]) -> bool: return node is None or "example_value" in node.meta or "val" in node.meta +<<<<<<< HEAD # If True, enforce fullgraph=True - raise errors on graph break _error_on_graph_break = False @@ -4831,6 +5826,8 @@ def _set_error_on_graph_break(value: bool) -> None: _error_on_graph_break = value +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch._disable_dynamo def record_pregraph_bytecode_enter() -> AbstractContextManager[None]: cm: AbstractContextManager[None] = ( @@ -4849,7 +5846,11 @@ def record_pregraph_bytecode_exit(cm: AbstractContextManager[None]) -> None: # Returns a set of code objects present traced in the current TracingContext, or None # if there is no current TracingContext. +<<<<<<< HEAD def get_traced_code() -> Optional[list[CodeType]]: +======= +def get_traced_code() -> list[CodeType]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._guards import TracingContext return TracingContext.get_traced_code() diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index 31bc7db5128f7..fdb46b692672a 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -27,7 +27,10 @@ DisabledSavedTensorsHooksVariable, DualLevelContextManager, DynamoConfigPatchVariable, +<<<<<<< HEAD ErrorOnGraphBreakVariable, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) FSDPParamGroupUseTrainingStateVariable, GradIncrementNestingCtxManagerVariable, GradInplaceRequiresGradCtxManagerVariable, @@ -75,16 +78,26 @@ from .higher_order_ops import ( FunctionalCallVariable, FunctorchHigherOrderVariable, +<<<<<<< HEAD ReparametrizeModuleCallVariable, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TorchHigherOrderOperatorVariable, ) from .iter import ( CountIteratorVariable, +<<<<<<< HEAD +======= + CycleIteratorVariable, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) FilterVariable, IteratorVariable, ItertoolsVariable, MapVariable, +<<<<<<< HEAD ObjectIteratorVariable, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) RepeatIteratorVariable, ZipVariable, ) @@ -140,7 +153,10 @@ ) from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable from .user_defined import ( +<<<<<<< HEAD FrozenDataClassVariable, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MutableMappingVariable, RemovableHandleVariable, UserDefinedClassVariable, @@ -149,7 +165,10 @@ UserDefinedExceptionObjectVariable, UserDefinedListVariable, UserDefinedObjectVariable, +<<<<<<< HEAD UserDefinedSetVariable, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) UserDefinedTupleVariable, ) @@ -168,6 +187,10 @@ "CreateTMADescriptorExperimentalVariable", "CreateTMADescriptorStableVariable", "CUDADeviceVariable", +<<<<<<< HEAD +======= + "CycleIteratorVariable", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "DataPtrVariable", "DefaultDictVariable", "DeletedVariable", @@ -200,7 +223,10 @@ "RemovableHandleVariable", "RepeatIteratorVariable", "SDPAParamsVariable", +<<<<<<< HEAD "ErrorOnGraphBreakVariable", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "SkipFunctionVariable", "SliceVariable", "StringFormatVariable", diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py index eac2251320008..3822374adf582 100644 --- a/torch/_dynamo/variables/base.py +++ b/torch/_dynamo/variables/base.py @@ -375,9 +375,13 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracke if not variables.ConstantVariable.is_literal(value): raise NotImplementedError source = self.source and AttrSource(self.source, name) +<<<<<<< HEAD if source and not isinstance(self, variables.ConstantVariable): # The second condition is to avoid guards on const getattr objects # like __code__.co_argcount +======= + if source: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH)) return variables.ConstantVariable.create(value, source=source) @@ -549,12 +553,15 @@ def call_method( "This can happen unintentionally if a previous graph break happens with a builtin iterator " "in the local scope." ) +<<<<<<< HEAD hints.append( "List/dict comprehensions in Python <= 3.11 result in implicit function calls, which Dynamo " "cannot trace as a top level frame. Possible workarounds are (1) use a loop instead of a comprehension, " "(2) fix any graph breaks in the function above the comprehension, (3) wrap the comprehension in a " "function, or (4) use Python 3.12+." ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) unimplemented_v2( gb_type="Unsupported method call", context=f"call_method {self} {name} {args} {kwargs}", diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index e49eef3707762..14ff2315627d2 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -36,6 +36,10 @@ import sys import traceback import types +<<<<<<< HEAD +======= +import warnings +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import weakref from collections.abc import MutableMapping from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union @@ -44,7 +48,10 @@ import torch from torch import SymInt +<<<<<<< HEAD from torch._dispatch.python import enable_python_dispatcher +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._dynamo.utils import ( get_metrics_context, is_int_specialization_case, @@ -52,7 +59,10 @@ set_feature_use, ) from torch._guards import TracingContext +<<<<<<< HEAD from torch._higher_order_ops.flat_apply import flat_apply +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._higher_order_ops.torchbind import call_torchbind from torch._ops import HigherOrderOperator from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode @@ -105,14 +115,20 @@ GetItemSource, GradSource, is_constant_source, +<<<<<<< HEAD is_from_closure_source, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) is_from_global_source, is_from_nonlocal_source, is_from_optimizer_source, is_from_unspecialized_nn_module_source, ListGetItemSource, LocalSource, +<<<<<<< HEAD NonSerializableSetGetItemSource, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) NumpyTensorSource, OptimizerSource, RandomValueSource, @@ -134,7 +150,10 @@ get_locals_to_steal, get_static_address_type, is_frozen_dataclass, +<<<<<<< HEAD is_function, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) is_function_or_wrapper, is_invoke_subgraph, is_lru_cache_wrapped_function, @@ -164,12 +183,18 @@ VariableTracker, VariableTrackerMeta, ) +<<<<<<< HEAD from .builtin import BuiltinVariable +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .constant import ConstantVariable, EnumVariable from .ctx_manager import ( AutocastModeVariable, DynamoConfigPatchVariable, +<<<<<<< HEAD ErrorOnGraphBreakVariable, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) EventVariable, NullContextVariable, PreserveVersionContextVariable, @@ -282,7 +307,10 @@ UserDefinedExceptionClassVariable, UserDefinedListVariable, UserDefinedObjectVariable, +<<<<<<< HEAD UserDefinedSetVariable, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) UserDefinedTupleVariable, ) @@ -308,7 +336,12 @@ def safe_has_grad(t): +<<<<<<< HEAD with torch._logging.hide_warnings(torch._logging._internal.safe_grad_filter): +======= + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "The .grad attribute of a Tensor") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return hasattr(t, "grad") @@ -448,6 +481,7 @@ def __call__(self, value): if vt.source is None: vt.source = self.source +<<<<<<< HEAD def _is_deduplicable_sym_variable(value, vt): # Constants like 0, 1, 2, etc. can be unspecialized as SymNodeVariables sometimes, but we # should NOT track them. If we use a single SymNodeVariable instance to track them @@ -460,6 +494,10 @@ def _is_deduplicable_sym_variable(value, vt): self._can_lift_attrs_to_inputs(vt) or _is_deduplicable_sym_variable(value, vt) ) +======= + if ( + self._can_lift_attrs_to_inputs(vt) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) and value not in self.tx.output.side_effects and not is_wrapper_or_member_descriptor(value) ): @@ -629,10 +667,14 @@ def _wrap(self, value): has_triton_tensor_descriptor_host_tma, ) +<<<<<<< HEAD from ..decorators import ( DynamoConfigPatchProxy, ErrorOnGraphBreakDecoratorContextManager, ) +======= + from ..decorators import DynamoConfigPatchProxy +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if has_triton(): from triton.runtime.autotuner import Autotuner @@ -690,10 +732,20 @@ def from_tensor(): ) and type(value) not in config.nontraceable_tensor_subclasses ): +<<<<<<< HEAD if ( type(value).__torch_dispatch__ is torch.Tensor.__torch_dispatch__ or is_traceable_wrapper_subclass(value) ): +======= + if type(value).__torch_dispatch__ is torch.Tensor.__torch_dispatch__: + # This case it's either tensor or subclass with default + # torch_dispatch (they might override torch_function or not), + # and we can always trace into them. + return self.wrap_tensor(value) + elif is_traceable_wrapper_subclass(value): + # For non-default torch_dispatch, we have more requirements. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.wrap_tensor(value) if is_namedtuple(value): @@ -784,6 +836,7 @@ def build_key_value(i, k, v): var = TorchFunctionModeVariable(value, source=self.source) self.tx.output.side_effects.track_object_existing(value, var) return var +<<<<<<< HEAD elif istype(value, set): if any(isinstance(x, torch.Tensor) for x in value): unimplemented_v2( @@ -816,6 +869,8 @@ def build_key_value(i, k, v): ] result = SetVariable(items, source=self.source) return self.tx.output.side_effects.track_object_existing(value, result) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif istype(value, frozenset) and all( ( # For DBR quantization, we could get a frozenset of torch funcs. @@ -989,8 +1044,11 @@ def build_key_value(i, k, v): ) elif isinstance(value, DynamoConfigPatchProxy): return DynamoConfigPatchVariable(value.changes) +<<<<<<< HEAD elif isinstance(value, ErrorOnGraphBreakDecoratorContextManager): return ErrorOnGraphBreakVariable(value.error_on_graph_break) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif callable(value) and trace_rules.lookup_callable(value) is not None: if trace_rules.is_callable_allowed(value): self.tx.output.has_user_defined_allowed_in_graph = True @@ -1224,12 +1282,15 @@ def build_key_value(i, k, v): ) and BuiltinMethodVariable.is_supported_builtin_method(value): self.install_guards(GuardBuilder.ID_MATCH) return BuiltinMethodVariable(value, source=self.source) +<<<<<<< HEAD elif is_function(value) and value in (float.fromhex, float.hex): self.install_guards(GuardBuilder.ID_MATCH) return GetAttrVariable( BuiltinVariable(float, source=self.source), value.__name__, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif is_function_or_wrapper(value): value, attr_name = unwrap_with_attr_name_if_wrapper(value) # For these wrappers, Dynamo points to the wrapped function, @@ -1331,6 +1392,7 @@ def build_key_value(i, k, v): and not is_traceable_wrapper_subclass_type(value) ): return TensorSubclassVariable(value, source=self.source) +<<<<<<< HEAD if not is_from_closure_source(self.source): # For closure source, the variable comes from LOAD_SUPER_ATTR, @@ -1341,6 +1403,11 @@ def build_key_value(i, k, v): # ID_MATCH even if its a global variable. self.install_guards(GuardBuilder.ID_MATCH) +======= + # This is a userdefined class, so install an ID_MATCH even if its a + # global variable. + self.install_guards(GuardBuilder.ID_MATCH) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return UserDefinedClassVariable( value, source=self.source, @@ -1502,6 +1569,7 @@ def build_key_value(i, k, v): ) result = UserDefinedListVariable(value, list_vt=list_vt, source=self.source) return self.tx.output.side_effects.track_object_existing(value, result) +<<<<<<< HEAD elif isinstance(value, (set, frozenset)): self.install_guards(GuardBuilder.TYPE_MATCH) self.install_guards(GuardBuilder.SEQUENCE_LENGTH) @@ -1524,6 +1592,11 @@ def build_key_value(i, k, v): self.install_guards(GuardBuilder.TYPE_MATCH) result = MutableMappingVariable(value, source=self.source) return self.tx.output.side_effects.track_object_existing(value, result) +======= + elif issubclass(type(value), MutableMapping): + self.install_guards(GuardBuilder.TYPE_MATCH) + return MutableMappingVariable(value, source=self.source) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif is_frozen_dataclass(value): self.install_guards(GuardBuilder.TYPE_MATCH) result = FrozenDataClassVariable.create(self.tx, value, source=self.source) @@ -1656,8 +1729,11 @@ def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]): source=source, ) +<<<<<<< HEAD # Apply relevant logic from `VariableTracker.build(value[i])` # (except for the `create_graph_input` stuff). +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) guards = [] for i, tensor_variable in enumerate(list_variable.items): source_i = GetItemSource(base=source, index=i, index_is_slice=False) @@ -1666,6 +1742,10 @@ def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]): tensor_variable.proxy.node.meta["tensor_dict"] = _extract_tensor_dict( value[i] ) +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) guard = functools.partial( GuardBuilder.TENSOR_MATCH, value=TensorWeakRef(value[i]) ) @@ -1682,6 +1762,7 @@ def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]): ) tensor_list_proxy.node.meta["grapharg"] = grapharg +<<<<<<< HEAD # The following is very important for maintaining the "python object # <==> variable tracker" 1-to-1 mapping, which is mainly handled via # `side_effects`. Note that constructing `tensor_variable` above @@ -1703,6 +1784,8 @@ def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]): for vt in output: vt.realize() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) result = BaseListVariable.cls_for_instance(value)(output, source=self.source) if istype(value, (list, collections.deque)): return self.tx.output.side_effects.track_mutable(value, result) @@ -1943,12 +2026,15 @@ def wrap_literal(self, value): "integer into a tensor." ) +<<<<<<< HEAD process_automatic_dynamic( self.tx, self.source.name(), FrameStateSizeEntry.make_scalar(value), is_unspecialized_nn_module=self.source.guard_source().is_unspecialized_nn_module(), ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.install_guards( functools.partial( GuardBuilder.EQUALS_MATCH, recompile_hint=recompile_hint @@ -2054,8 +2140,37 @@ def wrap_tensor(self, value: torch.Tensor): return self.tx.output.input_source_to_var[source] options = {} +<<<<<<< HEAD subclass_type = infer_subclass_type(value) if subclass_type is not None: +======= + if type(value) in ( + torch.Tensor, + torch.nn.Parameter, + torch._subclasses.fake_tensor.FakeTensor, + torch._subclasses.functional_tensor.FunctionalTensor, + ) or is_traceable_wrapper_subclass(value): + # Ordinarily, we would fakeify a tensor so that it can get dynamic + # shapes and be computed on without triggering actual operations. + # However, how can we fakeify a tensor subclass? Ordinary + # inheritance (nor multiple inheritance) won't work work. + # + # Instead, our plan is to *manually simulate* the tensor subclass + # inheriting from a fake tensor with dynamo. This means our + # data representation for a tensor subclass will be a fake tensor + # + tensor subclass type + any extra data the subclass may have + # been storing on the tensor. Because all Python accesses are + # mediated through TensorWithTFOverrideVariable, we can ensure + # that we dispatch differently, e.g., according to + # __torch_function__ + # + # To simplify things for now, the __dict__ tracking bits haven't + # been implemented yet, but they can be added into this design at + # a later point in time. + subclass_type = None + else: + subclass_type = type(value) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.install_guards(GuardBuilder.TYPE_MATCH) if get_static_address_type(value) == "guarded": @@ -2913,7 +3028,11 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe elif example_value is None or proxy.node.target is torch.manual_seed: return ConstantVariable.create(None, **options) elif isinstance(example_value, (torch.SymInt, torch.SymFloat, torch.SymBool)): +<<<<<<< HEAD tx.output.current_tracer.track_produced_symints(example_value, proxy) +======= + tx.output.current_tracer.track_unbacked_symbols(example_value, proxy) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) set_example_value(proxy.node, example_value) return SymNodeVariable(proxy, example_value, **options) elif ( @@ -2955,8 +3074,11 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe torch.seed, operator.mod, torch._functorch.vmap._validate_and_get_batch_size, +<<<<<<< HEAD torch._functorch.predispatch._vmap_increment_nesting, torch._functorch.predispatch._vmap_decrement_nesting, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # some mac builds are missing torch.distributed.get_rank() getattr(torch.distributed, "get_rank", _missing), getattr(torch.distributed, "get_world_size", _missing), @@ -2990,10 +3112,16 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe ): set_example_value(proxy.node, example_value) return ConstantVariable.create(example_value, **options) +<<<<<<< HEAD elif isinstance(example_value, (int, float, bool)) and ( proxy.node.target is call_torchbind or proxy.node.target is flat_apply or (proxy.node.op == "call_method" and proxy.node.target == "item") +======= + elif ( + isinstance(example_value, (int, float, bool)) + and proxy.node.target is call_torchbind +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): set_example_value(proxy.node, example_value) return ConstantVariable.create(example_value, **options) @@ -3009,6 +3137,7 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe ) +<<<<<<< HEAD def infer_subclass_type(value): if type(value) in ( torch.Tensor, @@ -3058,6 +3187,8 @@ def get_specialized_props(target_cls, tx, example_value, subclass_type): return specialized_props +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def construct_tensor_variable( target_cls, tx, proxy, example_value, subclass_type, options ): @@ -3074,9 +3205,30 @@ def construct_tensor_variable( # So that subgraphs can access the unbacked symbol's proxy in parent graph # when lifting unbacked symbols of input tensors to subgraph inputs. # We do it lazily because the tensor may not be used in subgraphs. +<<<<<<< HEAD if proxy.node.op != "placeholder": tx.output.current_tracer.track_produced_symints(example_value, proxy) options.update(get_specialized_props(target_cls, tx, example_value, subclass_type)) +======= + tx.output.current_tracer.track_unbacked_symbols(example_value, proxy) + specialized_props = target_cls.specialize(example_value) + # TODO: not sure about this fake mode test + if ( + isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor) + and example_value.fake_mode is tx.fake_mode + ): + if subclass_type: + tensor_type = subclass_type + elif isinstance(example_value, torch.nn.Parameter): + tensor_type = torch.nn.Parameter + elif isinstance(example_value, torch.nn.Buffer): + tensor_type = torch.nn.Buffer + else: + tensor_type = torch.Tensor + specialized_props["class_type"] = tensor_type + + options.update(specialized_props) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return target_cls(proxy, **options) @@ -3248,6 +3400,10 @@ def _automatic_dynamic( ) if static_shapes and not is_dynamic_source(name): +<<<<<<< HEAD +======= + record_automatic_dynamic(tx, name, e) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return StatefulSymbolicContext( dynamic_sizes=[DimDynamic.STATIC] * e.dim(), dynamic_strides=[DimDynamic.INFER_STRIDE] * e.dim(), @@ -3364,10 +3520,18 @@ def update_dim2constraint(dim, constraint_range, name): if is_dynamic_source(name): log.debug("%s marked dynamic via source whitelist", name) automatic_dynamic_size = True +<<<<<<< HEAD +======= + automatic_dynamic_stride = True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if is_unbacked_source(name): log.debug("%s marked unbacked via source whitelist", name) automatic_dynamic_size = True +<<<<<<< HEAD +======= + automatic_dynamic_stride = True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) automatic_dynamic = automatic_dynamic_size or automatic_dynamic_stride @@ -3498,6 +3662,7 @@ def wrap_to_fake_tensor_and_record( type(e), ) +<<<<<<< HEAD # Note [enable_python_dispatcher in dynamo] # Dynamo disables itself when it runs fake tensor prop, which means that tensor subclasses # have no way to know (purely based off of global state) if they are currently being run under compile or not. @@ -3511,6 +3676,15 @@ def wrap_to_fake_tensor_and_record( symbolic_context=symbolic_context, ) ) +======= + fake_e = wrap_fake_exception( + lambda: tx.fake_mode.from_tensor( + e, + source=source, + symbolic_context=symbolic_context, + ) + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if ( source is not None and isinstance(fake_e, FakeTensor) @@ -3602,12 +3776,15 @@ def create(tx: "InstructionTranslator", value) -> VariableTracker: if trace_rules.is_callable_allowed(value): tx.output.has_user_defined_allowed_in_graph = True return trace_rules.lookup_callable(value)(value) +<<<<<<< HEAD elif callable(value) and UserDefinedClassVariable.is_supported_new_method( value ): # NamedTuple._make uses an alias of tuple.__new__ obj = trace_rules.lookup_callable(value.__self__)(value.__self__) return GetAttrVariable(obj, "__new__") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif is_function_or_wrapper(value): return trace_rules.lookup(value)(value) elif isinstance( diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index b46707f2f1172..7709e58b19557 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -33,20 +33,31 @@ import typing import unittest from collections import defaultdict, OrderedDict +<<<<<<< HEAD from collections.abc import Iterable, KeysView, Sequence from typing import Any, Callable, TYPE_CHECKING, Union +======= +from collections.abc import KeysView, Sequence +from typing import Callable, TYPE_CHECKING, Union +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch from torch import sym_float, sym_int from torch._subclasses.meta_utils import is_sparse_any +<<<<<<< HEAD from torch.overrides import BaseTorchFunctionMode +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.utils._python_dispatch import is_traceable_wrapper_subclass from .. import config, graph_break_hints, polyfills, variables from ..exc import ( AttributeMutationError, ObservedAttributeError, +<<<<<<< HEAD ObservedUserStopIteration, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise_observed_exception, unimplemented_v2, Unsupported, @@ -70,7 +81,10 @@ cmp_name_to_op_mapping, dict_methods, extract_fake_example_value, +<<<<<<< HEAD frozenset_methods, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) get_fake_value, guard_if_dyn, is_tensor_getset_descriptor, @@ -78,7 +92,10 @@ istype, numpy_operator_wrapper, proxy_args_kwargs, +<<<<<<< HEAD set_methods, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) str_methods, tensortype_to_dtype, ) @@ -88,7 +105,10 @@ from .dicts import ( ConstDictVariable, DefaultDictVariable, +<<<<<<< HEAD DictKeysVariable, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DictViewVariable, FrozensetVariable, is_hashable, @@ -109,12 +129,16 @@ TensorVariable, UnspecializedPythonVariable, ) +<<<<<<< HEAD from .user_defined import ( MutableMappingVariable, UserDefinedDictVariable, UserDefinedObjectVariable, UserDefinedVariable, ) +======= +from .user_defined import UserDefinedObjectVariable, UserDefinedVariable +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if TYPE_CHECKING: @@ -155,6 +179,7 @@ operator.ge: polyfills.cmp_ge, } +<<<<<<< HEAD bin_ops = ( operator.pow, operator.mul, @@ -275,6 +300,8 @@ def __torch_function__(self, func, types, args=(), kwargs=None): if most_recent_func != BUILTIN_TO_TENSOR_FN_MAP[op]: BUILTIN_TO_TENSOR_RFN_MAP[op] = most_recent_func +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class BuiltinVariable(VariableTracker): """ @@ -308,7 +335,10 @@ def _constant_fold_functions(): bool, callable, chr, +<<<<<<< HEAD complex, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) divmod, float, getattr, @@ -673,6 +703,7 @@ def list_iadd_handler(tx: "InstructionTranslator", a, b): def expand_list_like(tx: "InstructionTranslator", lst, const): if isinstance(lst, ConstantVariable): lst, const = const, lst +<<<<<<< HEAD try: return lst.__class__( items=lst.items * const.as_python_constant(), @@ -684,6 +715,12 @@ def expand_list_like(tx: "InstructionTranslator", lst, const): tx, args=list(map(ConstantVariable.create, exc.args)), ) +======= + return lst.__class__( + items=lst.items * const.as_python_constant(), + mutation_type=ValueMutationNew(), + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) list_like_expansion_handlers: list[ tuple[ @@ -700,6 +737,7 @@ def expand_list_like(tx: "InstructionTranslator", lst, const): def create_cmp_op_handlers(op): def compare_by_value(tx: "InstructionTranslator", a, b): +<<<<<<< HEAD try: return ConstantVariable(op(a.value, b.value)) except TypeError as exc: @@ -708,6 +746,9 @@ def compare_by_value(tx: "InstructionTranslator", a, b): tx, args=list(map(ConstantVariable.create, exc.args)), ) +======= + return ConstantVariable(op(a.value, b.value)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) result: list[ tuple[ @@ -1158,6 +1199,7 @@ def builtin_dispatch(tx: "InstructionTranslator", args, kwargs): return builtin_dispatch +<<<<<<< HEAD def call_vars(self, tx: "InstructionTranslator", *args): if len(args) == 0: unimplemented_v2( @@ -1173,6 +1215,8 @@ def call_vars(self, tx: "InstructionTranslator", *args): except ObservedAttributeError: raise_observed_exception(TypeError, tx) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _handle_insert_op_in_graph(self, tx: "InstructionTranslator", args, kwargs): from .builder import wrap_fx_proxy, wrap_fx_proxy_cls @@ -1181,15 +1225,28 @@ def _handle_insert_op_in_graph(self, tx: "InstructionTranslator", args, kwargs): # insert handling for torch function here from .builder import SourcelessBuilder +<<<<<<< HEAD from .torch_function import can_dispatch_torch_function, dispatch_torch_function global BUILTIN_TO_TENSOR_RFN_MAP, BUILTIN_TO_TENSOR_FN_MAP +======= + from .torch_function import ( + BUILTIN_TO_TENSOR_FN_MAP, + BUILTIN_TO_TENSOR_RFN_MAP, + can_dispatch_torch_function, + dispatch_torch_function, + ) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if can_dispatch_torch_function(tx, args, kwargs): # Only remap the fn to tensor methods if we aren't exporting # export serde does not handle method descriptors today if not tx.export: +<<<<<<< HEAD # Ensure the builtin maps are populated before accessing them populate_builtin_to_tensor_fn_map() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Use sourceless builder, we built the map ourselves if not isinstance(args[0], TensorVariable): if self.fn in BUILTIN_TO_TENSOR_RFN_MAP: @@ -1389,11 +1446,19 @@ def call_method( if ( self.fn is tuple and len(args) == 2 +<<<<<<< HEAD and args[1].has_force_unpack_var_sequence(tx) and not kwargs ): if isinstance(args[0], BuiltinVariable) and args[0].fn is tuple: init_args = args[1].force_unpack_var_sequence(tx) +======= + and args[1].has_unpack_var_sequence(tx) + and not kwargs + ): + if isinstance(args[0], BuiltinVariable) and args[0].fn is tuple: + init_args = args[1].unpack_var_sequence(tx) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return variables.TupleVariable( init_args, mutation_type=ValueMutationNew() ) @@ -1414,6 +1479,7 @@ def call_method( args[1:], ) +<<<<<<< HEAD if self.fn is float and len(args) == 1 and name in ("fromhex", "hex"): if isinstance(args[0], ConstantVariable): try: @@ -1427,6 +1493,8 @@ def call_method( args=list(map(ConstantVariable.create, e.args)), ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.fn is object and name == "__init__": # object.__init__ is a no-op return variables.ConstantVariable(None) @@ -1442,6 +1510,7 @@ def call_method( elif isinstance(args[0], variables.ConstDictVariable): return args[0].call_method(tx, name, args[1:], kwargs) +<<<<<<< HEAD if self.fn is set: resolved_fn = getattr(self.fn, name) if resolved_fn in set_methods: @@ -1456,18 +1525,23 @@ def call_method( if isinstance(args[0], variables.FrozensetVariable): return args[0].call_method(tx, name, args[1:], kwargs) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.fn is str and len(args) >= 1: resolved_fn = getattr(self.fn, name) if resolved_fn in str_methods: if isinstance(args[0], ConstantVariable): return args[0].call_method(tx, name, args[1:], kwargs) +<<<<<<< HEAD if self.fn is float and len(args) >= 1: if isinstance(args[0], ConstantVariable): return ConstantVariable.create( getattr(float, name)(args[0].as_python_constant()) ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return super().call_method(tx, name, args, kwargs) def _call_int_float(self, tx: "InstructionTranslator", arg): @@ -1773,7 +1847,11 @@ def _call_iter_tuple_list( if ( getattr(obj, "source", False) and isinstance(obj, ConstDictVariable) +<<<<<<< HEAD and not istype(obj, (SetVariable, FrozensetVariable)) +======= + and not istype(obj, SetVariable) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): tx.output.guard_on_key_order.add(obj.source) @@ -1809,10 +1887,14 @@ def _call_tuple_list(self, tx, obj=None, *args, **kwargs): list(obj.force_unpack_var_sequence(tx)), mutation_type=ValueMutationNew(), ) +<<<<<<< HEAD elif isinstance(obj, variables.LocalGeneratorObjectVariable) or ( isinstance(obj, UserDefinedObjectVariable) and obj.has_force_unpack_var_sequence(tx) ): +======= + elif isinstance(obj, variables.LocalGeneratorObjectVariable): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self._call_iter_tuple_generator(tx, obj, *args, **kwargs) else: return self._call_iter_tuple_list(tx, obj, *args, **kwargs) @@ -1820,8 +1902,11 @@ def _call_tuple_list(self, tx, obj=None, *args, **kwargs): def call_iter(self, tx: "InstructionTranslator", obj, *args, **kwargs): if isinstance(obj, variables.IteratorVariable): ret = obj +<<<<<<< HEAD elif isinstance(obj, variables.RangeVariable): ret = obj.call_method(tx, "__iter__", [], {}) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: # Handle the case where we are iterating over a tuple, list or iterator ret = self._call_iter_tuple_list(tx, obj, *args, **kwargs) @@ -1830,6 +1915,7 @@ def call_iter(self, tx: "InstructionTranslator", obj, *args, **kwargs): # If the object doesn't implement a __iter__ method, it will be an error in eager mode when calling iter on it anyway. # If the object implements a __iter__ method, inlining effectively forwards the call to another iter call # (e.g. when __iter__ just returns iter(self.list)) or return a user-defined iterator. +<<<<<<< HEAD # If the object implements a __getitem__ method, iter(...) will call obj.__getitem__() # with an integer argument starting at 0, until __getitem__ raises IndexError ret = variables.UserFunctionVariable( @@ -1842,6 +1928,9 @@ def call_iter(self, tx: "InstructionTranslator", obj, *args, **kwargs): # Wrap the return value in a IteratorVariable subclass (LazyObjectIteratorVariable) # that forwards the next_variable call to the object. ret = variables.ObjectIteratorVariable(ret) +======= + return obj.call_method(tx, "__iter__", args, kwargs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ret call_tuple = _call_tuple_list @@ -1887,17 +1976,21 @@ def call_cast(self, _, *args, **kwargs): hints=["Ensure your call to cast() has exactly 2 arguments."], ) +<<<<<<< HEAD def call_dir(self, tx: "InstructionTranslator", arg): if isinstance(arg, variables.UserDefinedClassVariable): return VariableTracker.build(tx, dir(arg.value)) if isinstance(arg, BuiltinVariable): return VariableTracker.build(tx, dir(arg.fn)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def call_dict(self, tx: "InstructionTranslator", *args, **kwargs): return BuiltinVariable.call_custom_dict(tx, dict, *args, **kwargs) @staticmethod def call_custom_dict(tx: "InstructionTranslator", user_cls, *args, **kwargs): +<<<<<<< HEAD args = list(args) if ( len(args) == 1 @@ -1909,6 +2002,8 @@ def call_custom_dict(tx: "InstructionTranslator", user_cls, *args, **kwargs): # VT(foo.__dict__). This simplifies the construction of the new # dict. args[0] = args[0].get_forwarded_dict(tx) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return tx.inline_user_function_return( VariableTracker.build(tx, polyfills.construct_dict), [VariableTracker.build(tx, user_cls), *args], @@ -1926,10 +2021,14 @@ def call_custom_dict_fromkeys( assert len(args) == 1 and len(kwargs) == 1 and "value" in kwargs args = (*args, kwargs.pop("value")) if len(args) == 0: +<<<<<<< HEAD msg = ConstantVariable.create( "fromkeys expected at least 1 arguments, got 0" ) raise_observed_exception(TypeError, tx, args=[msg]) +======= + raise UserError(TypeError, "fromkeys expected at least 1 argument, got 0") # type: ignore[arg-type] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if len(args) == 1: args = (*args, ConstantVariable.create(None)) assert len(args) == 2 @@ -1980,7 +2079,11 @@ def call_set(self, tx: "InstructionTranslator", *args, **kwargs): ], ) arg = args[0] +<<<<<<< HEAD if istype(arg, variables.SetVariable): +======= + if isinstance(arg, variables.SetVariable): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return arg.clone(mutation_type=ValueMutationNew()) elif arg.has_force_unpack_var_sequence(tx): items = arg.force_unpack_var_sequence(tx) @@ -2015,10 +2118,17 @@ def call_frozenset(self, tx: "InstructionTranslator", *args, **kwargs): ], ) arg = args[0] +<<<<<<< HEAD if istype(arg, variables.FrozensetVariable): return FrozensetVariable([x.vt for x in arg.set_items]) elif arg.has_force_unpack_var_sequence(tx): items = arg.force_unpack_var_sequence(tx) +======= + if isinstance(arg, variables.FrozensetVariable): + return FrozensetVariable([x.vt for x in arg.set_items]) + elif arg.has_unpack_var_sequence(tx): + items = arg.unpack_var_sequence(tx) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return FrozensetVariable(items) raise_observed_exception( TypeError, @@ -2030,7 +2140,14 @@ def call_zip(self, tx: "InstructionTranslator", *args, **kwargs): if kwargs: assert len(kwargs) == 1 and "strict" in kwargs strict = kwargs.pop("strict", False) +<<<<<<< HEAD args = [BuiltinVariable(iter).call_function(tx, [arg], {}) for arg in args] +======= + args = [ + arg.unpack_var_sequence(tx) if arg.has_unpack_var_sequence(tx) else arg + for arg in args + ] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return variables.ZipVariable( args, strict=strict, mutation_type=ValueMutationNew() ) @@ -2169,6 +2286,7 @@ def call_issubclass(self, tx: "InstructionTranslator", left_ty, right_ty): def call_super(self, tx: "InstructionTranslator", a, b): return variables.SuperVariable(a, b) +<<<<<<< HEAD def call_next(self, tx: "InstructionTranslator", *args): arg = args[0] try: @@ -2177,6 +2295,11 @@ def call_next(self, tx: "InstructionTranslator", *args): if len(args) == 2: return args[1] raise +======= + def call_next(self, tx: "InstructionTranslator", arg: VariableTracker): + try: + return arg.next_variable(tx) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) except Unsupported as ex: if isinstance(arg, variables.BaseListVariable): ex.remove_from_stats() @@ -2201,6 +2324,7 @@ def call_filter(self, tx: "InstructionTranslator", fn, seq): seq = seq.unpack_var_sequence(tx) if seq.has_unpack_var_sequence(tx) else seq return variables.FilterVariable(fn, seq, mutation_type=ValueMutationNew()) +<<<<<<< HEAD def var_getattr(self, tx: "InstructionTranslator", name): source = self.source and AttrSource(self.source, name) if self.fn is object: @@ -2213,6 +2337,8 @@ def var_getattr(self, tx: "InstructionTranslator", name): return VariableTracker.build(tx, value, source) return variables.GetAttrVariable(self, name, source=source) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def call_getattr( self, tx: "InstructionTranslator", @@ -2308,6 +2434,11 @@ def call_getattr( "assertRaisesRegex", "assertNotWarns", "assertWarnsRegex", +<<<<<<< HEAD +======= + "assertDictEqual", + "assertSequenceEqual", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "assertWarns", ) ): @@ -2420,6 +2551,7 @@ def call_setattr( "the mutation out of `torch.compile` region", ], ) +<<<<<<< HEAD elif obj.dtype != val.dtype: # type: ignore[attr-defined] unimplemented_v2( gb_type="Failed to mutate tensor data attribute to different dtype", @@ -2431,6 +2563,8 @@ def call_setattr( "the mutation out of `torch.compile` region", ], ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Remove the old reference in tracked fakes - if we don't do this # new .data value size and shape differences will cause @@ -2589,6 +2723,7 @@ def call_neg(self, tx: "InstructionTranslator", a): (operator.neg)(a.as_proxy()), sym_num=None, ) +<<<<<<< HEAD if ( isinstance(a, UserDefinedObjectVariable) @@ -2596,6 +2731,8 @@ def call_neg(self, tx: "InstructionTranslator", a): ): return a.call_method(tx, "__neg__", [], {}) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # None no-ops this handler and lets the driving function proceed return None @@ -2729,6 +2866,7 @@ def _comparison_with_symnode(self, tx: "InstructionTranslator", left, right): sym_num=None, ) +<<<<<<< HEAD def call_xor(self, tx: "InstructionTranslator", a, b): if isinstance(a, (DictKeysVariable, SetVariable, UserDefinedObjectVariable)): return a.call_method(tx, "__xor__", [b], {}) @@ -2745,6 +2883,8 @@ def call_isub(self, tx: "InstructionTranslator", a, b): if isinstance(a, (DictKeysVariable, SetVariable, UserDefinedObjectVariable)): return a.call_method(tx, "__isub__", [b], {}) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def call_and_(self, tx: "InstructionTranslator", a, b): # Rely on constant_handler if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): @@ -2759,6 +2899,7 @@ def call_and_(self, tx: "InstructionTranslator", a, b): ), sym_num=None, ) +<<<<<<< HEAD if isinstance(a, (DictKeysVariable, SetVariable, UserDefinedObjectVariable)): return a.call_method(tx, "__and__", [b], {}) # None no-ops this handler and lets the driving function proceed @@ -2779,6 +2920,13 @@ def call_iand(self, tx: "InstructionTranslator", a, b): ) if isinstance(a, (DictKeysVariable, SetVariable, UserDefinedObjectVariable)): return a.call_method(tx, "__iand__", [b], {}) +======= + if hasattr(a, "set_items") and hasattr(b, "set_items"): + return SetVariable(list(a.set_items & b.set_items)) + # None no-ops this handler and lets the driving function proceed + + call_iand = call_and_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def call_or_(self, tx: "InstructionTranslator", a, b): # Rely on constant_handler @@ -2794,6 +2942,7 @@ def call_or_(self, tx: "InstructionTranslator", a, b): ), sym_num=None, ) +<<<<<<< HEAD # This call looks like `{"one": torch.ones(1)} | {"two": torch.ones(2)}`. if isinstance( @@ -2844,6 +2993,17 @@ def call_ior(self, tx: "InstructionTranslator", a, b): # None no-ops this handler and lets the driving function proceed return None +======= + if hasattr(a, "set_items") and hasattr(b, "set_items"): + return SetVariable(list(a.set_items | b.set_items)) + # This call looks like `{"one": torch.ones(1)} | {"two": torch.ones(2)}`. + if isinstance(a, ConstDictVariable): + return a.call_method(tx, "__or__", args=[b], kwargs={}) + # None no-ops this handler and lets the driving function proceed + return None + + call_ior = call_or_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def call_not_(self, tx: "InstructionTranslator", a): if isinstance(a, SymNodeVariable): diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index 11822016827ea..66e5e5a437b2d 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -125,7 +125,11 @@ def unpack_var_sequence(self, tx): def const_getattr(self, tx: "InstructionTranslator", name): if not hasattr(self.value, name): +<<<<<<< HEAD raise_observed_exception(AttributeError, tx, args=[name]) +======= + raise NotImplementedError +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) member = getattr(self.value, name) if callable(member): raise NotImplementedError @@ -173,6 +177,7 @@ def call_method( raise_observed_exception(type(e), tx) elif isinstance(self.value, (float, int)): if not (args or kwargs): +<<<<<<< HEAD try: return ConstantVariable.create(getattr(self.value, name)()) except (OverflowError, ValueError) as exc: @@ -181,6 +186,9 @@ def call_method( tx, args=list(map(ConstantVariable.create, exc.args)), ) +======= + return ConstantVariable.create(getattr(self.value, name)()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if ( hasattr(operator, name) and len(args) == 1 @@ -206,16 +214,20 @@ def call_method( elif isinstance(self.value, bytes) and name == "decode": method = getattr(self.value, name) return ConstantVariable.create(method(*const_args, **const_kwargs)) +<<<<<<< HEAD elif type(self.value) is complex and name in complex.__dict__.keys(): method = getattr(self.value, name) try: return ConstantVariable.create(method(*const_args, **const_kwargs)) except Exception as e: raise_observed_exception(type(e), tx) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if name == "__len__" and not (args or kwargs): return ConstantVariable.create(len(self.value)) elif name == "__round__" and len(args) == 1 and args[0].is_python_constant(): +<<<<<<< HEAD try: return ConstantVariable.create( round(self.value, args[0].as_python_constant()) @@ -234,6 +246,16 @@ def call_method( raise_observed_exception( type(e), tx, args=list(map(ConstantVariable.create, e.args)) ) +======= + return ConstantVariable.create( + round(self.value, args[0].as_python_constant()) + ) + elif name == "__contains__" and len(args) == 1 and args[0].is_python_constant(): + assert not kwargs + search = args[0].as_python_constant() + result = search in self.value + return ConstantVariable.create(result) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return super().call_method(tx, name, args, kwargs) def call_obj_hasattr( diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index 15a5540395d18..bbb3fbd4ce58b 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -38,7 +38,10 @@ from ..exc import unimplemented_v2 from ..guards import GuardBuilder, install_guard from ..source import AttrSource, GlobalStateSource +<<<<<<< HEAD from ..utils import _get_error_on_graph_break, _set_error_on_graph_break +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .base import VariableTracker from .functions import ( NestedUserFunctionVariable, @@ -197,6 +200,7 @@ def exit_on_graph_break(self): return True +<<<<<<< HEAD class RepararametrizeModuleContextVariable(GenericContextWrappingVariable): def __init__(self, ctx_manager_vt, mod): self.cm_vt = ctx_manager_vt @@ -225,6 +229,8 @@ def __getattr__(self, name): return getattr(self.cm_vt, name) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable): """represents torch grad requires grad""" @@ -523,7 +529,11 @@ def enter(self, tx): self.set_cleanup_hook(tx, lambda: torch._C._functorch._vmap_decrement_nesting()) self.proxy = tx.output.create_node( "call_function", +<<<<<<< HEAD torch._functorch.predispatch._vmap_increment_nesting, +======= + torch._C._functorch._vmap_increment_nesting, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) (batch_size_node, randomness), {}, ) @@ -532,10 +542,14 @@ def enter(self, tx): def exit(self, tx: "InstructionTranslator", *args): self.cleanup() tx.output.create_node( +<<<<<<< HEAD "call_function", torch._functorch.predispatch._vmap_decrement_nesting, (), {}, +======= + "call_function", torch._C._functorch._vmap_decrement_nesting, (), {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return variables.ConstantVariable.create(None) @@ -810,8 +824,15 @@ def enter(self, tx): def _call_func(self, tx: "InstructionTranslator", values): assert len(values) == 1 value = values[0] +<<<<<<< HEAD tx.output.create_node( "call_function", torch._C._set_deterministic_algorithms, (value,), {} +======= + ( + tx.output.create_node( + "call_function", torch._C._set_deterministic_algorithms, (value,), {} + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) torch._C._set_deterministic_algorithms(value) @@ -943,8 +964,12 @@ def __init__(self, target_values=None, **kwargs) -> None: super().__init__(target_values=target_values, **kwargs) def enter(self, tx): +<<<<<<< HEAD none = variables.ConstantVariable.create(None) return self.target_values if self.target_values else none +======= + return variables.ConstantVariable.create(None) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def exit(self, tx: "InstructionTranslator", *args): return variables.ConstantVariable.create(None) @@ -1411,6 +1436,19 @@ def __init__(self, target_values, **kwargs) -> None: self.initial_values[key] = torch._dynamo.config.__getattr__(key) self.initial_values = (tuple(self.initial_values.items()),) +<<<<<<< HEAD +======= + def enter(self, tx): + # resets all config patches at the end of tracing + self.set_cleanup_hook(tx) + self._call_func(tx, self.target_values) + return variables.ConstantVariable.create(None) + + def exit(self, tx: "InstructionTranslator", *args): + self._call_func(tx, self.initial_values) + return variables.ConstantVariable.create(None) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _call_func(self, tx: "InstructionTranslator", values): assert len(values) == 1 value = values[0] @@ -1429,6 +1467,7 @@ def fn_name(self): return "patch_dynamo_config" +<<<<<<< HEAD class ErrorOnGraphBreakVariable(ContextWrappingVariable): """represents torch._dynamo.error_on_graph_break""" @@ -1450,6 +1489,8 @@ def fn_name(self): return "error_on_graph_break" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class WithExitFunctionVariable(VariableTracker): _nonvar_fields = { "target", diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index c33979aae07df..71f6f6be4724b 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -41,7 +41,10 @@ dict_keys, dict_values, istype, +<<<<<<< HEAD raise_args_mismatch, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) specialize_symnode, ) from .base import ValueMutationNew, VariableTracker @@ -58,6 +61,17 @@ # - (perhaps) Define how it is compared in _HashableTracker._eq_impl +<<<<<<< HEAD +======= +def raise_args_mismatch(tx, name): + raise_observed_exception( + TypeError, + tx, + args=[ConstantVariable(f"wrong number of arguments for {name}() call")], + ) + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def was_instancecheck_override(obj): return type(obj).__dict__.get("__instancecheck__", False) @@ -91,8 +105,11 @@ def is_hashable(x): return x.as_proxy().node.meta.get("example_value") is not None elif isinstance(x, variables.TupleVariable): return all(is_hashable(e) for e in x.items) +<<<<<<< HEAD elif isinstance(x, variables.FrozenDataClassVariable): return all(is_hashable(e) for e in x.fields.values()) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif ( isinstance(x, variables.UserDefinedObjectVariable) and not was_instancecheck_override(x.value) @@ -108,7 +125,10 @@ def is_hashable(x): variables.SymNodeVariable, variables.ConstantVariable, variables.EnumVariable, +<<<<<<< HEAD variables.FrozensetVariable, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) variables.UserDefinedClassVariable, variables.UserFunctionVariable, variables.SkipFunctionVariable, @@ -120,14 +140,20 @@ def is_hashable(x): variables.TypingVariable, variables.FunctoolsPartialVariable, variables.WeakRefVariable, +<<<<<<< HEAD variables.TorchHigherOrderOperatorVariable, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), ) class ConstDictVariable(VariableTracker): +<<<<<<< HEAD CONTAINS_GUARD = GuardBuilder.DICT_CONTAINS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _nonvar_fields = { "user_cls", *VariableTracker._nonvar_fields, @@ -172,6 +198,7 @@ def underlying_value(self): # Access the underlying value inside the referent_vt for the key representation Hashable = ConstDictVariable._HashableTracker return Hashable(self.vt.referent_vt).underlying_value +<<<<<<< HEAD elif isinstance(self.vt, variables.FrozenDataClassVariable): Hashable = ConstDictVariable._HashableTracker fields_values = { @@ -180,6 +207,8 @@ def underlying_value(self): return variables.FrozenDataClassVariable.HashWrapper( self.vt.python_type(), fields_values ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif isinstance(self.vt, variables.UserDefinedObjectVariable): # The re module in Python 3.13+ has a dictionary (_cache2) with # an object as key (`class _ZeroSentinel(int): ...`): @@ -245,14 +274,19 @@ def __init__( def make_hashable(key): return key if isinstance(key, Hashable) else Hashable(key) +<<<<<<< HEAD dict_cls = self._get_dict_cls_from_user_cls(user_cls) self.items = dict_cls({make_hashable(x): v for x, v in items.items()}) +======= + self.items = {make_hashable(x): v for x, v in items.items()} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # need to reconstruct everything if the dictionary is an intermediate value # or if a pop/delitem was executed self.should_reconstruct_all = not is_from_local_source(self.source) self.original_items = items.copy() self.user_cls = user_cls +<<<<<<< HEAD def _get_dict_cls_from_user_cls(self, user_cls): accepted_dict_types = (dict, collections.OrderedDict, collections.defaultdict) @@ -272,6 +306,8 @@ def _get_dict_cls_from_user_cls(self, user_cls): dict_cls = dict return dict_cls +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def as_proxy(self): return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()} @@ -306,6 +342,7 @@ def __contains__(self, vt) -> bool: and not isinstance(self.items[Hashable(vt)], variables.DeletedVariable) ) +<<<<<<< HEAD def len(self) -> int: return sum( not isinstance(x, variables.DeletedVariable) for x in self.items.values() @@ -313,6 +350,21 @@ def len(self) -> int: def has_new_items(self) -> bool: return self.should_reconstruct_all or any( +======= + def len(self): + return len( + [ + x + for x in self.items.values() + if not isinstance(x, variables.DeletedVariable) + ] + ) + + def has_new_items(self): + if self.should_reconstruct_all: + return True + return any( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.is_new_item(self.original_items.get(key.vt), value) for key, value in self.items.items() ) @@ -420,7 +472,11 @@ def install_dict_contains_guard(self, tx, args): install_guard( self.make_guard( functools.partial( +<<<<<<< HEAD type(self).CONTAINS_GUARD, +======= + GuardBuilder.DICT_CONTAINS, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) key=args[0].value, invert=not contains, ) @@ -461,17 +517,25 @@ def call_method( return ConstantVariable.create(None) elif name == "__getitem__": # Key guarding - Nothing to do. LazyVT for value will take care. +<<<<<<< HEAD if len(args) != 1: raise_args_mismatch(tx, name) return self.getitem_const_raise_exception_if_absent(tx, args[0]) elif name == "items": if args or kwargs: raise_args_mismatch(tx, name) +======= + assert len(args) == 1 + return self.getitem_const_raise_exception_if_absent(tx, args[0]) + elif name == "items": + assert not (args or kwargs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.install_dict_keys_match_guard() if self.source: tx.output.guard_on_key_order.add(self.source) return DictItemsVariable(self) elif name == "keys": +<<<<<<< HEAD if len(args): raise_args_mismatch(tx, name) self.install_dict_keys_match_guard() @@ -491,12 +555,32 @@ def call_method( self.install_dict_keys_match_guard() if args or kwargs: raise_args_mismatch(tx, name) +======= + self.install_dict_keys_match_guard() + if self.source: + tx.output.guard_on_key_order.add(self.source) + assert not (args or kwargs) + return DictKeysVariable(self) + elif name == "values": + self.install_dict_keys_match_guard() + if self.source: + tx.output.guard_on_key_order.add(self.source) + assert not (args or kwargs) + return DictValuesVariable(self) + elif name == "copy": + self.install_dict_keys_match_guard() + assert not (args or kwargs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.clone( items=self.items.copy(), mutation_type=ValueMutationNew(), source=None ) elif name == "__len__": +<<<<<<< HEAD if args or kwargs: raise_args_mismatch(tx, name) +======= + assert not (args or kwargs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.install_dict_keys_match_guard() return ConstantVariable.create(len(self.items)) elif name == "__setitem__" and self.is_mutable(): @@ -514,6 +598,7 @@ def call_method( tx.output.side_effects.mutation(self) self.items.__delitem__(Hashable(args[0])) return ConstantVariable.create(None) +<<<<<<< HEAD elif name == "get": if len(args) not in (1, 2): raise_args_mismatch(tx, name) @@ -579,6 +664,22 @@ def call_method( elif name == "clear": if args or kwargs: raise_args_mismatch(tx, name) +======= + elif name in ("pop", "get") and len(args) in (1, 2) and args[0] not in self: + # missing item, return the default value. Install no DICT_CONTAINS guard. + self.install_dict_contains_guard(tx, args) + if len(args) == 1: + if name == "pop": + raise_observed_exception(KeyError, tx) + return ConstantVariable(None) + else: + return args[1] + elif name == "pop" and arg_hashable and self.is_mutable(): + self.should_reconstruct_all = True + tx.output.side_effects.mutation(self) + return self.items.pop(Hashable(args[0])) + elif name == "clear": +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.should_reconstruct_all = True tx.output.side_effects.mutation(self) self.items.clear() @@ -610,16 +711,24 @@ def call_method( return ConstantVariable.create(None) else: return super().call_method(tx, name, args, kwargs) +<<<<<<< HEAD elif name == "__contains__": if not len(args): raise_args_mismatch(tx, name) +======= + elif name in ("get", "__getattr__") and args[0] in self: + # Key guarding - Nothing to do. + return self.getitem_const(tx, args[0]) + elif name == "__contains__" and len(args) == 1: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not arg_hashable: raise_unhashable(args[0]) self.install_dict_contains_guard(tx, args) contains = args[0] in self return ConstantVariable.create(contains) +<<<<<<< HEAD elif name == "setdefault" and self.is_mutable(): if len(args) not in (1, 2): raise_args_mismatch(tx, name) @@ -627,6 +736,9 @@ def call_method( if not arg_hashable: raise_unhashable(args[0]) +======= + elif name == "setdefault" and arg_hashable and self.is_mutable(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.install_dict_keys_match_guard() assert not kwargs assert len(args) <= 2 @@ -643,6 +755,7 @@ def call_method( return x elif name == "move_to_end": self.install_dict_keys_match_guard() +<<<<<<< HEAD tx.output.side_effects.mutation(self) if args[0] not in self: raise_observed_exception(KeyError, tx) @@ -718,6 +831,25 @@ def call_method( mutation_type=ValueMutationNew(), source=None, user_cls=user_cls, +======= + assert not kwargs and len(args) == 1 + tx.output.side_effects.mutation(self) + key = Hashable(args[0]) + val = self.items[key] + self.items.pop(key) + self.items[key] = val + return ConstantVariable.create(None) + elif name == "__or__": + assert len(args) == 1 + if not isinstance(args[0], ConstDictVariable): + raise TypeError( + f"unsupported operand type(s) for |: 'dict' and '{args[0].python_type().__name__}'" + ) + + self.install_dict_keys_match_guard() + new_dict_vt = self.clone( + items=self.items.copy(), mutation_type=ValueMutationNew(), source=None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # NB - Guard on all the keys of the other dict to ensure @@ -725,9 +857,12 @@ def call_method( args[0].install_dict_keys_match_guard() new_dict_vt.items.update(args[0].items) return new_dict_vt +<<<<<<< HEAD elif name == "__ior__": self.call_method(tx, "update", args, kwargs) return self +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: return super().call_method(tx, name, args, kwargs) @@ -908,8 +1043,11 @@ def reconstruct(self, codegen): class SetVariable(ConstDictVariable): """We model a sets as dictionary with None values""" +<<<<<<< HEAD CONTAINS_GUARD = GuardBuilder.SET_CONTAINS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __init__( self, items: list[VariableTracker], @@ -946,6 +1084,7 @@ def reconstruct(self, codegen: "PyCodegen"): codegen.foreach([x.vt for x in self.set_items]) codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items))) +<<<<<<< HEAD def _fast_set_method(self, tx, fn, args, kwargs): try: res = fn( @@ -958,6 +1097,8 @@ def _fast_set_method(self, tx, fn, args, kwargs): ) return VariableTracker.build(tx, res) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def call_method( self, tx, @@ -966,6 +1107,7 @@ def call_method( kwargs: dict[str, VariableTracker], ) -> "VariableTracker": # We forward the calls to the dictionary model +<<<<<<< HEAD from ..utils import check_constant_args if ( @@ -983,6 +1125,8 @@ def call_method( py_type = self.python_type() return self._fast_set_method(tx, getattr(py_type, name), args, kwargs) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if name == "__init__": temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, *kwargs) tx.output.side_effects.mutation(self) @@ -1008,9 +1152,14 @@ def call_method( super().call_method(tx, name, (result,), kwargs) return result elif name == "isdisjoint": +<<<<<<< HEAD if len(args) != 1: raise_args_mismatch(tx, name) assert not kwargs +======= + assert not kwargs + assert len(args) == 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return variables.UserFunctionVariable( polyfills.set_isdisjoint ).call_function(tx, [self, args[0]], {}) @@ -1072,9 +1221,12 @@ def call_method( else: return ConstantVariable.create(value=None) elif name in ("issubset", "issuperset"): +<<<<<<< HEAD if len(args) != 1: raise_args_mismatch(tx, name) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) op = { "issubset": operator.le, "issuperset": operator.ge, @@ -1085,6 +1237,7 @@ def call_method( return variables.BuiltinVariable(op.get(name)).call_function( tx, [self, other], {} ) +<<<<<<< HEAD elif name in ("__and__", "__or__", "__xor__", "__sub__"): m = { "__and__": "intersection", @@ -1123,6 +1276,8 @@ def call_method( return ConstantVariable.create( cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return super().call_method(tx, name, args, kwargs) def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker): @@ -1133,7 +1288,12 @@ def install_dict_keys_match_guard(self): pass def install_dict_contains_guard(self, tx, args): +<<<<<<< HEAD super().install_dict_contains_guard(tx, args) +======= + # Already EQUALS_MATCH guarded + pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class FrozensetVariable(SetVariable): @@ -1158,7 +1318,11 @@ def python_type(self): return frozenset def as_python_constant(self): +<<<<<<< HEAD return frozenset({k.vt.as_python_constant() for k in self.set_items}) +======= + return {k.vt.as_python_constant() for k in self.set_items} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def reconstruct(self, codegen: "PyCodegen"): codegen.foreach([x.vt for x in self.set_items]) @@ -1189,6 +1353,7 @@ def call_method( # In[3]: s # frozenset({1, 2}) return ConstantVariable.create(None) +<<<<<<< HEAD elif name in ( "copy", "difference", @@ -1197,6 +1362,8 @@ def call_method( ): r = super().call_method(tx, name, args, kwargs) return FrozensetVariable(r.items) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return super().call_method(tx, name, args, kwargs) @@ -1218,6 +1385,7 @@ def debug_repr(self): + "])" ) +<<<<<<< HEAD def install_dict_keys_match_guard(self): # Already EQUALS_MATCH guarded pass @@ -1226,6 +1394,8 @@ def install_dict_contains_guard(self, tx, args): # Already EQUALS_MATCH guarded pass +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @property def set_items(self): return self.items @@ -1283,11 +1453,14 @@ def reconstruct(self, codegen: "PyCodegen"): codegen.load_method(self.kv) codegen.call_method(0) +<<<<<<< HEAD def call_obj_hasattr(self, tx, name): if name in self.python_type().__dict__: return ConstantVariable.create(True) return ConstantVariable.create(False) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def call_method( self, tx, @@ -1324,6 +1497,7 @@ def call_method( ) -> "VariableTracker": if name == "__contains__": return self.dv_dict.call_method(tx, name, args, kwargs) +<<<<<<< HEAD elif name in ( "__and__", "__iand__", @@ -1338,6 +1512,8 @@ def call_method( m = getattr(self.set_items, name) r = m(args[0].set_items) return SetVariable(r) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if name in cmp_name_to_op_mapping: if not isinstance(args[0], (SetVariable, DictKeysVariable)): return ConstantVariable.create(NotImplemented) diff --git a/torch/_dynamo/variables/distributed.py b/torch/_dynamo/variables/distributed.py index 59f3102c6519b..cabf085dcf310 100644 --- a/torch/_dynamo/variables/distributed.py +++ b/torch/_dynamo/variables/distributed.py @@ -266,10 +266,13 @@ def call_method( return ConstantVariable.create(self.value.size(*const_args, **const_kwargs)) if name == "get_coordinate": return ConstantVariable.create(self.value.get_coordinate()) +<<<<<<< HEAD if name == "get_rank": return ConstantVariable.create(self.value.get_rank()) if name == "get_local_rank": return ConstantVariable.create(self.value.get_local_rank()) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if name == "get_group": const_args = [x.as_python_constant() for x in args] const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()} diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index d1755c85abf61..7deff2377c35d 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -56,6 +56,7 @@ Unsupported, ) from ..guards import GuardBuilder, install_guard +<<<<<<< HEAD from ..source import ( AttrSource, ClosureSource, @@ -64,6 +65,9 @@ GetItemSource, SkipGuardSource, ) +======= +from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from ..utils import ( check_constant_args, check_unspec_or_constant_args, @@ -159,15 +163,20 @@ def bind_args_cached(func, tx, fn_source, args, kwargs): ba[name] = wrap_bound_arg(tx, args[i]) elif name in rem_kw: if name in spec.posonly_names: +<<<<<<< HEAD raise_observed_exception( TypeError, tx, args=[ConstantVariable.create(f"{name} is positional-only")], ) +======= + raise TypeError(f"{name} is positional-only") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ba[name] = wrap_bound_arg(tx, rem_kw.pop(name)) elif name in spec.pos_default_map: idx = spec.pos_default_map[name] default_source = None +<<<<<<< HEAD if fn_source and not ( ConstantVariable.is_literal(spec.defaults[idx]) and config.skip_guards_on_constant_func_defaults @@ -184,12 +193,20 @@ def bind_args_cached(func, tx, fn_source, args, kwargs): ) ], ) +======= + if fn_source: + default_source = DefaultsSource(fn_source, idx) + ba[name] = wrap_bound_arg(tx, spec.defaults[idx], default_source) + else: + raise TypeError(f"Missing required positional argument: {name}") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # 2) *args extra = args[len(spec.all_pos_names) :] if spec.varargs_name: ba[spec.varargs_name] = wrap_bound_arg(tx, tuple(extra)) elif extra: +<<<<<<< HEAD raise_observed_exception( TypeError, tx, @@ -198,6 +215,10 @@ def bind_args_cached(func, tx, fn_source, args, kwargs): f"Too many positional arguments: got {len(args)}, expected {len(spec.all_pos_names)}" ) ], +======= + raise TypeError( + f"Too many positional arguments: got {len(args)}, expected {len(spec.all_pos_names)}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # 3) Keyword-only @@ -210,6 +231,7 @@ def bind_args_cached(func, tx, fn_source, args, kwargs): kwdefault_source = DefaultsSource(fn_source, name, is_kw=True) ba[name] = wrap_bound_arg(tx, spec.kwdefaults[name], kwdefault_source) else: +<<<<<<< HEAD raise_observed_exception( TypeError, tx, @@ -219,11 +241,15 @@ def bind_args_cached(func, tx, fn_source, args, kwargs): ) ], ) +======= + raise TypeError(f"Missing required keyword-only argument: {name}") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # 4) **kwargs if spec.varkw_name: ba[spec.varkw_name] = wrap_bound_arg(tx, rem_kw) elif rem_kw: +<<<<<<< HEAD raise_observed_exception( TypeError, tx, @@ -231,6 +257,9 @@ def bind_args_cached(func, tx, fn_source, args, kwargs): ConstantVariable.create(f"Unexpected keyword arguments: {list(rem_kw)}") ], ) +======= + raise TypeError(f"Unexpected keyword arguments: {list(rem_kw)}") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ba @@ -304,6 +333,7 @@ def _create_nested_fn( def fn_var_getattr(tx, fn, source, name): source = source and AttrSource(source, name) +<<<<<<< HEAD if source and name == "__annotations__": # We get a large number of silly guards from annotations from inspect @@ -311,6 +341,8 @@ def fn_var_getattr(tx, fn, source, name): # graph is even rarer. So skip guards. source = SkipGuardSource(source) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: subobj = inspect.getattr_static(fn, name) except AttributeError: @@ -424,6 +456,7 @@ def has_self(self): def get_globals(self): return self.fn.__globals__ +<<<<<<< HEAD def get_source(self): source = self.source @@ -431,6 +464,8 @@ def get_source(self): source = self.source_fn return source +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def bind_args(self, parent, args, kwargs) -> dict[str, VariableTracker]: """ Assume `args` and `kwargs` are VariableTracker arguments for a call to @@ -443,9 +478,13 @@ def bind_args(self, parent, args, kwargs) -> dict[str, VariableTracker]: if not isinstance(fn, FunctionType): raise TypeError("Only supports regular Python functions.") root_tx = parent.output.root_tx +<<<<<<< HEAD source = self.get_source() result = bind_args_cached(fn, root_tx, source, args, kwargs) +======= + result = bind_args_cached(fn, root_tx, self.source, args, kwargs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) init_cellvars(parent, result, fn.__code__) closure = self.fn.__closure__ or () @@ -458,8 +497,15 @@ def bind_args(self, parent, args, kwargs) -> dict[str, VariableTracker]: if cell in side_effects: cell_var = side_effects[cell] +<<<<<<< HEAD elif source: closure_cell = GetItemSource(ClosureSource(source), idx) +======= + elif self.source: + closure_cell = GetItemSource( + AttrSource(self.source, "__closure__"), idx + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) closure_cell_contents = AttrSource(closure_cell, "cell_contents") try: contents_var = VariableTracker.build( @@ -489,8 +535,12 @@ def bind_args(self, parent, args, kwargs) -> dict[str, VariableTracker]: def var_getattr(self, tx: "InstructionTranslator", name: str): if name in cmp_name_to_op_mapping: return variables.GetAttrVariable(self, name) +<<<<<<< HEAD source = self.get_source() return fn_var_getattr(tx, self.fn, source, name) +======= + return fn_var_getattr(tx, self.fn, self.source, name) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def call_obj_hasattr( self, tx: "InstructionTranslator", name: str @@ -505,6 +555,10 @@ def call_function( kwargs: "dict[str, VariableTracker]", ) -> "VariableTracker": # Handle patch_dynamo_config call +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.fn is torch._dynamo.patch_dynamo_config: try: args_const = [arg.as_python_constant() for arg in args] @@ -521,6 +575,7 @@ def call_function( "Please fix your call to patch_dynamo_config by using simpler inputs. " f"args: {args}, kwargs: {kwargs}" ) from e +<<<<<<< HEAD elif self.fn is torch._dynamo.error_on_graph_break: try: bound = inspect.signature(self.fn).bind(*args, **kwargs) @@ -536,6 +591,10 @@ def call_function( ) from e # Handle a `nonstrict_trace(fn)` call elif self.fn is torch._dynamo.nonstrict_trace: +======= + # Handle a `nonstrict_trace(fn)` call + if self.fn is torch._dynamo.nonstrict_trace: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bound = inspect.signature(self.fn).bind(*args, **kwargs) fn_var = bound.args[0] if not isinstance(fn_var, BaseUserFunctionVariable): @@ -725,11 +784,14 @@ def next_variable(self, tx): finally: counters["unimplemented"] |= counters["inline_call"] +<<<<<<< HEAD def call_obj_hasattr(self, tx, name): if name in self.python_type().__dict__: return ConstantVariable.create(True) return ConstantVariable.create(False) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def has_unpack_var_sequence(self, tx): return False @@ -1006,6 +1068,7 @@ def call_function( args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]", ) -> "VariableTracker": +<<<<<<< HEAD if not is_generator(self.vt.get_code()): unimplemented_v2( gb_type="non-generator contextlib.contextmanager", @@ -1017,6 +1080,9 @@ def call_function( "Remove the `@contextlib.contextmanager` decorator.", ], ) +======= + assert is_generator(self.vt.get_code()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inline_tracer = self._build_inline_tracer(tx, args, kwargs) code = self.vt.get_code() @@ -1064,6 +1130,7 @@ def _build_inline_tracer(self, tx, args, kwargs): class UserMethodVariable(UserFunctionVariable): """Some unsupported user-defined method""" +<<<<<<< HEAD def __init__(self, fn, obj, source_fn=None, **kwargs) -> None: super().__init__(fn=fn, **kwargs) self.obj = obj @@ -1082,6 +1149,11 @@ def __init__(self, fn, obj, source_fn=None, **kwargs) -> None: # `source_fn` rather than the original `source`. if source_fn is None and kwargs.get("source") is not None: self.source_fn = AttrSource(kwargs.get("source"), "__func__") +======= + def __init__(self, fn, obj, **kwargs) -> None: + super().__init__(fn=fn, **kwargs) + self.obj = obj +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __repr__(self) -> str: return f"{self.__class__.__name__}({self.fn}, {self.obj})" @@ -1157,6 +1229,7 @@ def inspect_parameter_names(self): return super().inspect_parameter_names()[1:] def var_getattr(self, tx: "InstructionTranslator", name: str): +<<<<<<< HEAD if name == "__self__": return self.obj if name == "__func__": @@ -1164,6 +1237,13 @@ def var_getattr(self, tx: "InstructionTranslator", name: str): # information is stored in self.source_fn, use that to construct the # variable tracker. return VariableTracker.build(tx, self.fn, self.source_fn) +======= + source = self.source and AttrSource(self.source, name) + if name == "__self__": + return self.obj + if name == "__func__": + return VariableTracker.build(tx, self.fn, source) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return super().var_getattr(tx, name) @@ -1468,6 +1548,7 @@ def as_python_constant(self): @classmethod def create_with_source(cls, value, source): +<<<<<<< HEAD # Use closure match guard (i.e. guard on __code__ object instead of # function id) to avoid guarding on nested functions. if inspect.getattr_static(value, "_torchdynamo_disable", False): @@ -1489,6 +1570,13 @@ def create_with_source(cls, value, source): # attribute lookup. They are unlikely to be changed, so we can skip # guarding them. install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH)) +======= + if not is_wrapper_or_member_descriptor(value): + # These descriptors are not guaranteed to return the same object on + # attribute lookup. They are unlikely to be changed, so we can skip + # guarding them. + install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return cls(value, source=source) def call_function( @@ -1868,6 +1956,7 @@ def call_function( ) -> "VariableTracker": constant_args = check_constant_args(args, kwargs) if constant_args: +<<<<<<< HEAD try: value = self.fn( *[x.as_python_constant() for x in args], @@ -1879,6 +1968,12 @@ def call_function( tx, args=list(map(ConstantVariable.create, exc.args)), ) +======= + value = self.fn( + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return variables.UserDefinedClassVariable( value, mutation_type=ValueMutationNew() ) diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 5ac883c7d3932..05c4bb6224224 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -26,9 +26,13 @@ import logging import types import warnings +<<<<<<< HEAD from collections.abc import Sequence from dataclasses import dataclass from typing import Any, Optional, TYPE_CHECKING +======= +from typing import Optional, TYPE_CHECKING +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch._C import torch.fx @@ -37,7 +41,10 @@ from torch._dynamo.utils import get_fake_value from torch._dynamo.variables.builtin import BuiltinVariable from torch._dynamo.variables.constant import ConstantVariable +<<<<<<< HEAD from torch._dynamo.variables.ctx_manager import RepararametrizeModuleContextVariable +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._dynamo.variables.functions import UserFunctionVariable from torch._dynamo.variables.nn_module import UnspecializedNNModuleVariable from torch._dynamo.variables.tensor import SymNodeVariable @@ -71,6 +78,7 @@ hc_log = torch._logging.getArtifactLogger(__name__, "hierarchical_compile") +<<<<<<< HEAD @dataclass class OutputSpec: """ @@ -98,6 +106,8 @@ def __post_init__(self): assert len(self.masks_to_filter_const_values) == len(self.const_values) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def raise_hard_error_if_graph_break(reason): def deco(fn): @functools.wraps(fn) @@ -105,6 +115,7 @@ def graph_break_as_hard_error(*args, **kwargs): try: return fn(*args, **kwargs) except (Unsupported, ObservedException) as e: +<<<<<<< HEAD import sys if isinstance(e, Unsupported): @@ -118,6 +129,10 @@ def graph_break_as_hard_error(*args, **kwargs): f"{reason} Got {msg}", real_stack ) raise exc.with_traceback(sys.exc_info()[2]) from None +======= + msg = " Scroll up to find out what causes the graph break." + raise UncapturedHigherOrderOpError(reason + msg) from e +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return graph_break_as_hard_error @@ -244,7 +259,11 @@ def inline_call(*args, **kwargs): def _call_function_and_unflatten_output( +<<<<<<< HEAD tx, fn, args, kwargs, flat_example_value, ret_spec +======= + tx, fn, args, kwargs, flat_example_value, ret_treespec +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): from .builder import wrap_fx_proxy @@ -260,6 +279,7 @@ def _call_function_and_unflatten_output( example_value=flat_example_value, ) +<<<<<<< HEAD if ret_spec.masks_to_filter_const_values: from torch._dynamo.external_utils import insert_const_values_with_mask @@ -269,12 +289,19 @@ def _call_function_and_unflatten_output( flat_variable, ret_spec.masks_to_filter_const_values, ret_spec.const_values ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Transform variable back into a list (previously made into a tuple by # speculate_subgraph function) so as to respect the pytree API typing. flat_list_variable = BuiltinVariable(list).call_function(tx, [flat_variable], {}) return ( +<<<<<<< HEAD _make_inlined(tx, pytree.tree_unflatten)(flat_list_variable, ret_spec.treespec) if ret_spec.treespec +======= + _make_inlined(tx, pytree.tree_unflatten)(flat_list_variable, ret_treespec) + if ret_treespec +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else flat_variable ) @@ -312,6 +339,7 @@ def _check_supported_callable_arg( ) +<<<<<<< HEAD def _call_while_loop( self: VariableTracker, tx: "InstructionTranslator", @@ -552,6 +580,8 @@ def unspecialize_carried_inputs(tx, carry) -> VariableTracker: ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def are_same_graph_modules(fn_name, a_mod, b_mod, fake_mode): from torch._subclasses._fake_tensor_utils import _CacheKeyState from torch._subclasses.fake_tensor import extract_tensor_metadata @@ -866,7 +896,10 @@ def fixup_branch_inps(graph, lifted_freevars, shared, unique_l, unique_r): def _insert_or_replace_phs(new_args, name_suffix): for arg in new_args: new_ph = graph.placeholder(arg.node.name + name_suffix) +<<<<<<< HEAD new_ph.meta = arg.node.meta +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Override with new_ph if there exists a old placeholder. if arg in lifted_freevars: old_ph = lifted_freevars[arg].node @@ -913,9 +946,12 @@ def speculate_subgraph( set_subgraph_inputs="automatic", restore_side_effects=True, should_flatten_outputs=False, +<<<<<<< HEAD # if should_flatten_outputs is True, `remove_consts_from_outputs` remove the # const outputs from the subgraph output. remove_consts_from_outputs=True, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) under_activation_checkpoint=False, # TODO - supports input_mutation and aliasing should be False by default for strictness supports_input_mutation=True, @@ -1000,26 +1036,38 @@ def speculate_subgraph( if restore_side_effects: new_side_effects = tx.output.side_effects.clone() +<<<<<<< HEAD prev_side_effects.track_runahead_tensor_and_symvar_side_effects( +======= + prev_side_effects.track_tensor_variables_from_runahead_side_effects( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) new_side_effects ) tx.output.side_effects = prev_side_effects treespec = None +<<<<<<< HEAD masks_to_filter_const_values = None const_values = None if should_flatten_outputs: from torch._dynamo.external_utils import filter_out_const_values +======= + if should_flatten_outputs: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Flatten the speculated subgraph output. output, treespec = _make_inlined(tx, pytree.tree_flatten)( output ).unpack_var_sequence(tx) +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Actually, transform the list (returned by flatten) into a tuple # for dynamo consistency. output = BuiltinVariable(tuple).call_function(tx, [output], {}) +<<<<<<< HEAD if remove_consts_from_outputs: # Filter out the constants and save them into a spec. Filtering # out constants makes the graph simpler for the backends. We @@ -1038,6 +1086,8 @@ def speculate_subgraph( output, masks_to_filter_const_values ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Register output to graph # Modeled off of compile_and_call_fx_graph # TODO: support pytree output @@ -1045,6 +1095,7 @@ def speculate_subgraph( # like bwd. if always_restore: # Nothing left to do here +<<<<<<< HEAD return ( ( output, @@ -1055,6 +1106,9 @@ def speculate_subgraph( tx.output.graph, subtracer.lifted_freevars, ) +======= + return (output, treespec), tx.output.graph, subtracer.lifted_freevars +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: validate_subgraph_output_types(output) @@ -1163,19 +1217,26 @@ def move_lifted_freevars_phs_to_end( context=context, explanation=f"Higher order ops do not support aliasing. Found in {source_target.name()}", hints=[ +<<<<<<< HEAD "Replace `return input` with `return input.clone()` to avoid aliasing.", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "Consider using the debug context to change user code to avoid aliasing.", "Please open an issue.", ], ) return ( +<<<<<<< HEAD ( output, OutputSpec( treespec, masks_to_filter_const_values, const_values ), ), +======= + (output, treespec), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) graph, lifted_freevars, ) @@ -1215,6 +1276,7 @@ def __init__( @staticmethod def make(value, source=None, **kwargs): +<<<<<<< HEAD variable_class = _hop_name_to_variable_class.get(value.__name__) if variable_class is not None: return variable_class(value, source, **kwargs) @@ -1224,10 +1286,71 @@ def make(value, source=None, **kwargs): if isinstance(value, BaseHOP): return BaseHOPVariable(value, source, **kwargs) unimplemented(f"HigherOrderOperator {value.__name__}") +======= + from torch._higher_order_ops import BaseHOP + + if value.__name__ == "cond": + return CondHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "while_loop": + return WhileLoopHigherOrderVariable(value, source, **kwargs) + elif value.__name__ in ("map", "map_impl"): + return MapHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "executorch_call_delegate": + return ExecutorchCallDelegateHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "out_dtype": + return OutDtypeHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "wrap": + return WrapHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "hints_wrapper": + return HintsWrapperHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "flex_attention": + return FlexAttentionHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "flex_attention_backward": + return FlexAttentionBackwardHighOrderVariable(value, source, **kwargs) + elif value.__name__ in ( + "wrap_activation_checkpoint", + "tag_activation_checkpoint", + ): + return CheckpointHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "_export_tracepoint": + return ExportTracepointHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "trace_wrapped": + return TraceWrappedHigherOrderOperatorVariable(value, source, **kwargs) + elif value.__name__ == "strict_mode": + return StrictModeHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "run_with_rng_state": + return RunWithRNGStateHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "associative_scan": + return AssociativeScanHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "scan": + return ScanHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "call_torchbind": + return CallTorchbindHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "wrap_with_set_grad_enabled": + return WrapWithSetGradEnabledHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "wrap_with_autocast": + return WrapWithAutocastHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "dynamo_bypassing_wrapper": + return DynamoBypassingWrapperHigherOrderVariable(value, source, **kwargs) + elif ( + value.__name__ == "auto_functionalized" + or value.__name__ == "auto_functionalized_v2" + ): + return AutoFunctionalizeHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "invoke_subgraph": + return InvokeSubgraphHigherOrderVariable(value, source, **kwargs) + elif isinstance(value, BaseHOP): + return BaseHOPVariable(value, source, **kwargs) + elif value.__name__ == "custom_function_call": + return CustomFunctionHigherOrderOperatorVariable(value, source, **kwargs) + else: + unimplemented(f"HigherOrderOperator {value.__name__}") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def call_function( self, tx: "InstructionTranslator", +<<<<<<< HEAD args: Sequence[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: @@ -1242,6 +1365,9 @@ def _call_function( self, tx: "InstructionTranslator", args: Sequence[VariableTracker], +======= + args: list[VariableTracker], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kwargs: dict[str, VariableTracker], ) -> VariableTracker: unimplemented(f"HigherOrderOperator {self.value.__name__}") @@ -1255,7 +1381,11 @@ class CustomFunctionHigherOrderOperatorVariable(TorchHigherOrderOperatorVariable Wraps torch._functorch.autograd_function.custom_function_call """ +<<<<<<< HEAD def _call_function( +======= + def call_function( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self, tx: "InstructionTranslator", args: "list[VariableTracker]", @@ -1266,7 +1396,11 @@ def _call_function( torch._dynamo.variables.UserDefinedObjectVariable( self.value, source=self.source ), +<<<<<<< HEAD source=AttrSource(self.source, "__call__"), +======= + source=AttrSource(AttrSource(self.source, "__call__"), "__func__"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ).call_function(tx, args, kwargs) @@ -1277,7 +1411,11 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable): @raise_hard_error_if_graph_break( reason="Cond doesn't work unless it is captured completely with torch.compile." ) +<<<<<<< HEAD def _call_function( +======= + def call_function( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self, tx: "InstructionTranslator", args: "list[VariableTracker]", @@ -1333,9 +1471,13 @@ def _call_function( f"{operands.python_type()}", ) operands_seq = operands.unpack_var_sequence(tx) +<<<<<<< HEAD if not only_consist_of( operands, (TensorVariable, ConstantVariable, SymNodeVariable) ): +======= + if not only_consist_of(operands, (TensorVariable, ConstantVariable)): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) unimplemented( "Expect operands to be a tuple of pytrees that only consists of tensor leaves." ) @@ -1362,7 +1504,11 @@ def speculate_branch(branch): ix = 1 if branch else 2 # TODO: Support kwargs ( +<<<<<<< HEAD (ret_val, ret_spec), +======= + (ret_val, ret_treespec), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ret_graph, ret_lifted_freevars, ) = speculate_subgraph( @@ -1373,8 +1519,11 @@ def speculate_branch(branch): "cond", source_target=self.value, should_flatten_outputs=True, +<<<<<<< HEAD # TODO - removing consts from control flow ops need more work remove_consts_from_outputs=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) supports_input_mutation=self.supports_input_mutation, supports_aliasing=self.supports_aliasing, ) @@ -1390,23 +1539,42 @@ def speculate_branch(branch): "Expected branches to return a possibly nested pytree of tensors " f"or constant ints but it consists of others {ret.python_type()}.", ) +<<<<<<< HEAD return ret_val, ret_spec, ret_graph, ret_lifted_freevars (true_r, true_spec, true_graph, true_lifted_freevars) = speculate_branch(True) +======= + return ret_val, ret_treespec, ret_graph, ret_lifted_freevars + + (true_r, true_treespec, true_graph, true_lifted_freevars) = speculate_branch( + True + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) true_nn_modules = dict(tx.output.nn_modules) ( false_r, +<<<<<<< HEAD false_spec, +======= + false_treespec, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) false_graph, false_lifted_freevars, ) = speculate_branch(False) false_nn_modules = dict(tx.output.nn_modules) +<<<<<<< HEAD same_spec = _make_inlined(tx, pytree.TreeSpec.__eq__)( true_spec.treespec, false_spec.treespec ) if not same_spec.as_python_constant(): +======= + same_treespec = _make_inlined(tx, pytree.TreeSpec.__eq__)( + true_treespec, false_treespec + ) + if not same_treespec.as_python_constant(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) unimplemented("Expected branches to return the same pytree structure.") ( @@ -1451,7 +1619,11 @@ def speculate_branch(branch): p_args, {}, None, +<<<<<<< HEAD true_spec, +======= + true_treespec, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @@ -1461,7 +1633,11 @@ def __init__(self, hop, source, script_obj_var, method_name) -> None: self.script_obj_var = script_obj_var self.method_name = method_name +<<<<<<< HEAD def _call_function( +======= + def call_function( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self, tx: "InstructionTranslator", args: list[VariableTracker], @@ -1514,12 +1690,17 @@ class WhileLoopHigherOrderVariable(TorchHigherOrderOperatorVariable): @raise_hard_error_if_graph_break( reason="while_loop doesn't work unless it is captured completely with torch.compile." ) +<<<<<<< HEAD def _call_function( +======= + def call_function( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self, tx: "InstructionTranslator", args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: +<<<<<<< HEAD return _call_while_loop(self, tx, args, kwargs, stack_output=False) @@ -1537,6 +1718,236 @@ def _call_function( kwargs: dict[str, VariableTracker], ) -> VariableTracker: return _call_while_loop(self, tx, args, kwargs, stack_output=True) +======= + from torch._higher_order_ops.while_loop import _create_unbacked_symint + + from . import TensorVariable + + args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) + cond_fn, body_fn, operands, additional_inputs = args + + # Input checks + for i, k in enumerate(["cond_fn", "body_fn", "operands"]): + if v := kwargs.pop(k, None): + assert i == len(args), ( + "did not provide the right number of non-keyword args" + ) + args.append(v) + + if kwargs: + unimplemented( + f"torch.while_loop: Got unexpected kwargs: {list(kwargs.keys())}" + ) + + if len(args) != 4: + unimplemented( + f"Expected 4 arguments but got {len(args)}.\n" + f"Usage: while_loop(cond_fn, body_fn, operands)", + ) + + # cond_fn and body_fn input check + _check_supported_callable_arg(tx, cond_fn, "cond_fn") + _check_supported_callable_arg(tx, body_fn, "body_fn") + + # operands input check + operands_seq = operands.unpack_var_sequence(tx) + + # additional_inputs input check + if not isinstance(additional_inputs, (ListVariable, TupleVariable)): + unimplemented( + f"Expected additional_inputs to be a list/tuple but got " + f"{additional_inputs.python_type()}. It seems to be an " + f"internal error, please report an issue to PyTorch." + ) + additional_inputs_seq = additional_inputs.unpack_var_sequence(tx) + + with discard_graph_changes(tx): + # See NOTE [unspecialize int carry with unbacked symints] + # Note: this must be run under discard graph changes. + def create_unbacked_sym_node_var(tx) -> SymNodeVariable: + example_value = _create_unbacked_symint( + tx.output.fake_mode, ignore_fresh_unbacked_symbols=True + ) + proxy = tx.output.current_tracer.create_graph_input( + "unbacked_symint", type(example_value), example_value + ) + return SymNodeVariable.create(tx, proxy, example_value) + + new_operands_seq = [ + ( + create_unbacked_sym_node_var(tx) + if ( + isinstance(carry, ConstantVariable) + and carry.python_type() is int + ) + or (isinstance(carry, SymNodeVariable)) + else carry + ) + for carry in operands_seq + ] + + # create cond subgrpahs + ( + (cond_r, _cond_treespec), + cond_graph, + cond_lifted_freevars, + ) = speculate_subgraph( + tx, + cond_fn, + new_operands_seq + additional_inputs_seq, + {}, + "while_loop", + source_target=self.value, + # NOTE [why we cannot use "automatic" for while_loop]: + # The reason is that we want to enforce + # the ordering of inputs and outputs to be consistent and the the ordering + # of cond_fn and body_fn to the consistent. + # e.g. suppose we use "automatic" and we have: + # + # def body_fn(ph1, ph2): + # new_a, new_b = ph2.cos(), ph1.sin() + # return new_a, new_b + # + # a, b = torch.randn(3), torch.randn(3) + # new_a, new_b = body_fn(a, b) + # + # Using automatic, the ordering of arguments will be the order that they're + # used. In this example, the capture graph looks like: + # + # def captured_body(ph1, ph2): + # new_a, new_b = ph1.cos(), ph2.add_(1) + # return new_a, new_b + # + # This is fine when we change the calling convention of captured_body to be + # new_a, new_b = captured_body(b, a). + # But for while_loop, the next iteration's input is previous iteration output + # we'll end up feeding captured_body(new_a, new_b) instead. + # So it's best we always enforce the ordering of carried_inputs the same as outputs + # with "flatten_manual". + set_subgraph_inputs="flatten_manual", + supports_input_mutation=self.supports_input_mutation, + supports_aliasing=self.supports_aliasing, + ) + cond_nn_modules = dict(tx.output.nn_modules) + validate_subgraph_output_types(cond_r) + if isinstance(cond_r, TensorVariable): + cond_r_meta = _extract_tensor_metadata( + cond_r.proxy.node.meta["example_value"], include_contiguity=False + ) + if ( + not cond_r_meta.dtype == torch.bool + or not cond_r_meta.shape == torch.Size([]) + ): + unimplemented( + f"Expected cond_fn to return a scalar tensor or a bool but got {cond_r_meta.shape}" + ) + elif isinstance(cond_r, ConstantVariable): + # short-circuiting while_loop when cond_fn returns a constant such as 0, 1 True or False + pred = cond_r.as_python_constant() + if pred: + unimplemented( + f"Infinite loop detected because while_loop's cond_fn always returns the same value {pred}" + ) + else: + return operands + + # create body subgraph + ( + (body_r, body_treespec), + body_graph, + body_lifted_freevars, + ) = speculate_subgraph( + tx, + body_fn, + new_operands_seq + additional_inputs_seq, + {}, + "while_loop", + source_target=self.value, + set_subgraph_inputs="flatten_manual", + should_flatten_outputs=True, + supports_input_mutation=False, + supports_aliasing=False, + ) + validate_subgraph_output_types(body_r) + + # We set include contiguity=False because we have vmap x HOP tests, where if + # include_contiguity=True will call t.is_contiguous inside of vmap and get an error + # "querying is_contiguous inside of vmap for memory_format other than + # torch.contiguous_format is not yet implemented". This is okay because stride + # is still checked. + check_meta_consistency_vt( + body_r.unpack_var_sequence(tx), + operands_seq, + "body_fn_output", + "carried_inputs", + include_contiguity=False, + ) + + ( + cond_graph, + body_graph, + cond_shared, + _body_shared, + cond_unique, + body_unique, + ) = _merge_graph_inputs( + cond_graph, + cond_lifted_freevars, + "cond_fn", + body_graph, + body_lifted_freevars, + "body_fn", + ) + + # Note: cond_shared and body_shared refer to the same proxy in parent graph + # so using either of them is OK. Use cond_shared as it doesn't matter. + additional_lifted_inputs = cond_shared + cond_unique + body_unique + + body_nn_modules = dict(tx.output.nn_modules) + + cond_name = tx.output.install_subgraph( + "cond_fn", + torch.fx.GraphModule(cond_nn_modules, cond_graph), + ) + body_name = tx.output.install_subgraph( + "body_fn", + torch.fx.GraphModule(body_nn_modules, body_graph), + ) + + cond_node = make_attr(tx, cond_name) + body_node = make_attr(tx, body_name) + + p_args = ( + cond_node, + body_node, + tuple([operand.as_proxy() for operand in operands_seq]), + tuple( + [inp.as_proxy() for inp in additional_inputs_seq] + + additional_lifted_inputs + ), + ) + + flat_example_value = pytree.tree_map_only( + torch.fx.Proxy, + lambda a: a.node.meta["example_value"], + body_r.as_proxy(), + ) + unspecialized_flat_example_value = pytree.tree_map_only( + (int, torch.SymInt), + lambda _: _create_unbacked_symint( + tx.output.fake_mode, ignore_fresh_unbacked_symbols=False + ), + flat_example_value, + ) + return _call_function_and_unflatten_output( + tx, + torch.ops.higher_order.while_loop, + p_args, + {}, + unspecialized_flat_example_value, + body_treespec, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable): @@ -1546,7 +1957,11 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable): @raise_hard_error_if_graph_break( reason="associative_scan must be captured completely with torch.compile." ) +<<<<<<< HEAD def _call_function( +======= + def call_function( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self, tx: "InstructionTranslator", args: list[VariableTracker], @@ -1625,7 +2040,11 @@ def arg_extractor(combine_fn, xs, additional_inputs): sub_args = sub_args + sub_args_additional_inputs ( +<<<<<<< HEAD (combine_result, _combine_spec), +======= + (combine_result, _combine_treespec), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) combine_graph, combine_lifted_freevars, ) = speculate_subgraph( @@ -1734,13 +2153,26 @@ def arg_extractor(combine_fn, xs, additional_inputs): additional_inputs_proxy, ) +<<<<<<< HEAD +======= + with tx.fake_mode: + out_meta = tuple( + inp_proxy.node.meta["example_value"].clone() for inp_proxy in xs_proxy + ) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return _call_function_and_unflatten_output( tx, torch.ops.higher_order.associative_scan, p_args, {}, +<<<<<<< HEAD None, OutputSpec(xs_treespec), +======= + out_meta, + xs_treespec, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @@ -1751,13 +2183,21 @@ class ScanHigherOrderVariable(TorchHigherOrderOperatorVariable): @raise_hard_error_if_graph_break( reason="scan must be captured completely with torch.compile." ) +<<<<<<< HEAD def _call_function( +======= + def call_function( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self, tx: "InstructionTranslator", args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: +<<<<<<< HEAD from torch._higher_order_ops.scan import _extract_carry_and_out +======= + from torch._higher_order_ops.scan import _extract_carry_and_out, stack_y +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._higher_order_ops.utils import first_slice_copy args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) @@ -1848,7 +2288,11 @@ def arg_extractor(combine_fn, init, xs, additional_inputs): sub_args = sub_args_init + sub_args_inp + sub_args_additional_inputs ( +<<<<<<< HEAD (combine_result, _combine_spec), +======= + (combine_result, _combine_treespec), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) combine_graph, combine_lifted_freevars, ) = speculate_subgraph( @@ -1882,7 +2326,11 @@ def arg_extractor(combine_fn, init, xs, additional_inputs): f"Expect combine_fn to return a tuple (next_carry, y) but got {combine_result_vars}" ) carry_tree, out_vars = combine_result_vars +<<<<<<< HEAD carry_vars, _ = _make_inlined(tx, pytree.tree_flatten)( +======= + carry_vars, carry_treespec = _make_inlined(tx, pytree.tree_flatten)( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) carry_tree ).unpack_var_sequence(tx) carry_vars = carry_vars.unpack_var_sequence(tx) @@ -1891,9 +2339,13 @@ def arg_extractor(combine_fn, init, xs, additional_inputs): ).unpack_var_sequence(tx) # additional output checking +<<<<<<< HEAD _combine_spec = OutputSpec( _make_inlined(tx, pytree.tree_structure)(combine_result) ) +======= + _combine_treespec = _make_inlined(tx, pytree.tree_structure)(combine_result) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) check_meta_consistency_vt( init_vars, @@ -1922,6 +2374,10 @@ def arg_extractor(combine_fn, init, xs, additional_inputs): additional_inputs_proxy = list(additional_inputs.as_proxy()) + list( combine_freevars_proxy ) +<<<<<<< HEAD +======= + y_proxies = [out_var.as_proxy() for out_var in out_vars] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) combine_gm = torch.fx.GraphModule(dict(tx.output.nn_modules), combine_graph) combine_fn_name = tx.output.install_subgraph("scan_combine_fn", combine_gm) @@ -1933,8 +2389,24 @@ def arg_extractor(combine_fn, init, xs, additional_inputs): additional_inputs_proxy, ) +<<<<<<< HEAD return _call_function_and_unflatten_output( tx, torch.ops.higher_order.scan, p_args, {}, None, _combine_spec +======= + with tx.fake_mode: + example_carry = [ + init_p.node.meta["example_value"].clone() for init_p in init_proxy + ] + # For the fake mode, we need to duplicate the init tensor along the dim + # to have the same size as the xs arguments + example_stacked_out = [ + stack_y(y.node.meta["example_value"], scan_length) for y in y_proxies + ] + out_meta = [*example_carry, *example_stacked_out] + + return _call_function_and_unflatten_output( + tx, torch.ops.higher_order.scan, p_args, {}, out_meta, _combine_treespec +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @@ -1954,7 +2426,11 @@ class MapHigherOrderVariable(TorchHigherOrderOperatorVariable): @raise_hard_error_if_graph_break( reason="map doesn't work unless it is captured completely with torch.compile." ) +<<<<<<< HEAD def _call_function( +======= + def call_function( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self, tx: "InstructionTranslator", args: list[VariableTracker], @@ -2013,8 +2489,11 @@ def _call_function( source_target=self.value, set_subgraph_inputs="flatten_manual", should_flatten_outputs=True, +<<<<<<< HEAD # TODO - removing consts from control flow ops need more work remove_consts_from_outputs=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) supports_input_mutation=self.supports_input_mutation, supports_aliasing=self.supports_aliasing, ) @@ -2052,7 +2531,11 @@ def _call_function( class ExecutorchCallDelegateHigherOrderVariable(TorchHigherOrderOperatorVariable): +<<<<<<< HEAD def _call_function( +======= + def call_function( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self, tx: "InstructionTranslator", args: "list[VariableTracker]", @@ -2125,6 +2608,7 @@ def call_function( return super().call_function(tx, args, kwargs) +<<<<<<< HEAD class ReparametrizeModuleCallVariable(FunctorchHigherOrderVariable): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -2136,6 +2620,8 @@ def call_function( return RepararametrizeModuleContextVariable(ctx_manager_vt, args[0]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable): supports_input_mutation = True supports_aliasing = True @@ -2202,7 +2688,11 @@ def create_wrapped_node( return proxy_args, {}, example_value, body_r, treespec, body_gmod, body_name +<<<<<<< HEAD def _call_function( +======= + def call_function( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self, tx: "InstructionTranslator", args: "list[VariableTracker]", @@ -2390,7 +2880,11 @@ class HintsWrapperHigherOrderVariable(TorchHigherOrderOperatorVariable): @raise_hard_error_if_graph_break( reason="Hints_wrapper doesn't work unless it is captured completely with torch.compile." ) +<<<<<<< HEAD def _call_function( +======= + def call_function( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self, tx, args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]" ) -> "VariableTracker": _check_supported_callable_arg(tx, args[0], "body_fn") @@ -2460,7 +2954,11 @@ def _call_function( class OutDtypeHigherOrderVariable(TorchHigherOrderOperatorVariable): +<<<<<<< HEAD def _call_function( +======= + def call_function( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self, tx: "InstructionTranslator", args: "list[VariableTracker]", @@ -2498,7 +2996,11 @@ class StrictModeHigherOrderVariable(TorchHigherOrderOperatorVariable): @raise_hard_error_if_graph_break( reason="strict_mode HOO doesn't work unless it is captured completely with torch.compile." ) +<<<<<<< HEAD def _call_function( +======= + def call_function( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self, tx: "InstructionTranslator", args: "list[VariableTracker]", @@ -2516,7 +3018,11 @@ def _call_function( ) ( +<<<<<<< HEAD (ret_val, ret_spec), +======= + (ret_val, ret_treespec), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ret_graph, ret_lifted_freevars, ) = speculate_subgraph( @@ -2554,12 +3060,20 @@ def _call_function( p_args, {}, flat_example_value, +<<<<<<< HEAD ret_spec, +======= + ret_treespec, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) class CheckpointHigherOrderVariable(WrapHigherOrderVariable): +<<<<<<< HEAD def _call_function( +======= + def call_function( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self, tx: "InstructionTranslator", args: list[VariableTracker], @@ -2568,6 +3082,11 @@ def _call_function( from torch._higher_order_ops.wrap import TagActivationCheckpoint from torch.utils.checkpoint import noop_context_fn +<<<<<<< HEAD +======= + from .builder import wrap_fx_proxy + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) context_fn = None if "context_fn" in kwargs and kwargs["context_fn"] != noop_context_fn: ctx = kwargs.pop("context_fn") @@ -2591,7 +3110,11 @@ def _call_function( _, example_value, _body_r, +<<<<<<< HEAD out_spec, +======= + treespec, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) checkpointed_gmod, _, ) = self.create_wrapped_node( @@ -2607,6 +3130,7 @@ def _call_function( _, checkpoint_kwargs = proxy_args_kwargs([], checkpoint_kwargs) +<<<<<<< HEAD return _call_function_and_unflatten_output( tx, self.value, @@ -2616,17 +3140,49 @@ def _call_function( out_spec, ) +======= + # Store the invocation as a call + variable = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=tuple(p_args), + kwargs=checkpoint_kwargs, + ), + example_value=example_value, + ) + + if treespec is None: + return variable + + # Transform variable back into a list (previously made into a tuple by + # speculate_subgraph function) so as to respect the pytree API typing. + variable = BuiltinVariable(list).call_function(tx, [variable], {}) + + return _make_inlined(tx, pytree.tree_unflatten)(variable, treespec) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class DynamoBypassingWrapperHigherOrderVariable(WrapHigherOrderVariable): def __init__(self, hop, source) -> None: super().__init__(hop, source) +<<<<<<< HEAD def _call_function( +======= + def call_function( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self, tx: "InstructionTranslator", args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: +<<<<<<< HEAD +======= + from .builder import wrap_fx_proxy + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) func_var = args[0] if isinstance(func_var, torch._dynamo.variables.UserFunctionVariable): @@ -2644,7 +3200,11 @@ def _call_function( _, example_value, _body_r, +<<<<<<< HEAD out_spec, +======= + treespec, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) gmod, _, ) = self.create_wrapped_node( @@ -2660,6 +3220,7 @@ def _call_function( gmod_meta_key = "_dynamo_bypassing_wrapper_fn" gmod.meta[gmod_meta_key] = func +<<<<<<< HEAD return _call_function_and_unflatten_output( tx, self.value, @@ -2669,6 +3230,29 @@ def _call_function( out_spec, ) +======= + # Store the invocation as a call + variable = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=(gmod_meta_key,) + tuple(p_args), + kwargs={}, + ), + example_value=example_value, + ) + + if treespec is None: + return variable + + # Transform variable back into a list (previously made into a tuple by + # speculate_subgraph function) so as to respect the pytree API typing. + variable = BuiltinVariable(list).call_function(tx, [variable], {}) + + return _make_inlined(tx, pytree.tree_unflatten)(variable, treespec) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class ExportTracepointHigherOrderVariable(TorchHigherOrderOperatorVariable): def call_function( @@ -2694,7 +3278,11 @@ def call_function( class RunWithRNGStateHigherOrderVariable(TorchHigherOrderOperatorVariable): +<<<<<<< HEAD def _call_function( +======= + def call_function( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self, tx: "InstructionTranslator", args: "list[VariableTracker]", @@ -2717,7 +3305,11 @@ def _call_function( class AutoFunctionalizeHigherOrderVariable(TorchHigherOrderOperatorVariable): +<<<<<<< HEAD def _call_function( +======= + def call_function( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self, tx, args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]" ) -> "VariableTracker": from .builder import wrap_fx_proxy @@ -2754,7 +3346,11 @@ def to_proxy(self, tx, arg): else: return arg.as_proxy() +<<<<<<< HEAD def _call_function( +======= + def call_function( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self, tx, args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]" ) -> "VariableTracker": from .builder import wrap_fx_proxy @@ -2786,7 +3382,11 @@ class TraceWrappedHigherOrderOperatorVariable(TorchHigherOrderOperatorVariable): here in the call to dynamo from compiled autograd. """ +<<<<<<< HEAD def _call_function( +======= + def call_function( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self, tx: "InstructionTranslator", args: "list[VariableTracker]", @@ -2847,7 +3447,11 @@ def create_scalar(): with TransformGetItemToIndex(): ( +<<<<<<< HEAD (_body_output, _body_spec), +======= + (_body_output, _body_treespec), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) body_graph, body_lifted_freevars, ) = speculate_subgraph( @@ -2877,7 +3481,11 @@ def create_scalar(): return proxy_args +<<<<<<< HEAD def _call_function( +======= + def call_function( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self, tx: "InstructionTranslator", args: "list[VariableTracker]", @@ -3359,7 +3967,11 @@ class BaseHOPVariable(WrapHigherOrderVariable): def python_type(self): return type(self.value) +<<<<<<< HEAD def _call_function( +======= + def call_function( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self, tx: "InstructionTranslator", args: "list[VariableTracker]", @@ -3390,7 +4002,11 @@ def _call_function( class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable): +<<<<<<< HEAD supports_input_mutation = True +======= + supports_input_mutation = False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) supports_aliasing = False def install_subgraph_in_output_graph( @@ -3456,7 +4072,11 @@ def install_subgraph_in_output_graph( @raise_hard_error_if_graph_break( reason="torch.compile requires the `nested_compile_region` decorated function to be capturable into a single graph", ) +<<<<<<< HEAD def _call_function( +======= + def call_function( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self, tx: "InstructionTranslator", args: "list[VariableTracker]", @@ -3495,6 +4115,7 @@ def _call_function( flat_example_value, treespec, ) +<<<<<<< HEAD # Map operator names to their corresponding variable for fast TorchHigherOrderOperatorVariable.make() @@ -3526,3 +4147,5 @@ def _call_function( "invoke_subgraph": InvokeSubgraphHigherOrderVariable, "custom_function_call": CustomFunctionHigherOrderOperatorVariable, } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index 80b9915aaa217..b616c3aee6f4a 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -16,8 +16,14 @@ """ import itertools +<<<<<<< HEAD import sys from typing import TYPE_CHECKING, Union +======= +import operator +import sys +from typing import Optional, TYPE_CHECKING, Union +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .. import graph_break_hints, polyfills, variables from ..bytecode_transformation import create_call_function, create_instruction @@ -59,6 +65,7 @@ def call_function( ) -> "VariableTracker": # See also: module `torch._dynamo.polyfills.itertools` +<<<<<<< HEAD if self.value is itertools.product: if any(kw != "repeat" for kw in kwargs.keys()): unimplemented_v2( @@ -78,6 +85,88 @@ def call_function( variables.TupleVariable(list(item)) for item in itertools.product(*seqs, repeat=r) ] +======= + if ( + self.value is itertools.product + and not kwargs + and all(arg.has_unpack_var_sequence(tx) for arg in args) + ): + seqs = [arg.unpack_var_sequence(tx) for arg in args] + items = [ + variables.TupleVariable(list(item)) for item in itertools.product(*seqs) + ] + return variables.ListIteratorVariable( + items, mutation_type=ValueMutationNew() + ) + elif self.value is itertools.accumulate: + from .builtin import BuiltinVariable + + if any(key not in ["initial", "func"] for key in kwargs.keys()): + unimplemented_v2( + gb_type="Unsupported kwargs for itertools.accumulate", + context=f"call_function {self} {args} {kwargs}", + explanation=f"Expected kwargs: 'initial', 'func', but got " + f"{','.join(set(kwargs.keys()) - {'initial', 'func'})}", + hints=[*graph_break_hints.USER_ERROR], + ) + + if len(args) in [1, 2] and args[0].has_unpack_var_sequence(tx): + seq = args[0].unpack_var_sequence(tx) + + if "func" in kwargs and len(args) == 1: + func = kwargs["func"].call_function + elif len(args) == 2: + func = args[1].call_function + elif len(args) == 1: + # Default to operator.add + func = BuiltinVariable(operator.add).call_function + else: + unimplemented_v2( + gb_type="Unsupported `func` in itertools.accumulate", + context=f"call_function {self} {args} {kwargs}", + explanation="Dynamo does not know how to get the " + "function to use for itertools.accumulate. " + "itertools.accumulate expects the `func` as the second " + "argument or as a keyword argument.", + hints=[*graph_break_hints.USER_ERROR], + ) + else: + unimplemented_v2( + gb_type="Unsupported arguments for itertools.accumulate", + context=f"call_function {self} {args} {kwargs}", + explanation="Dynamo does not know how to trace " + f"itertools.accumulate with args: {args} and kwargs: {kwargs}. " + "itertools.accumulate expects an iterable, an optional " + "binary function for accumulation, and an optional initial " + "value to set the starting state.", + hints=[ + "Make sure the arguments to itertools.accumulate are correct.", + *graph_break_hints.SUPPORTABLE, + ], + ) + + items = [] + acc = kwargs.get("initial") + if acc is not None: + items.append(acc) + for item in seq: + if acc is None: + acc = item + else: + try: + acc = func(tx, [acc, item], {}) + except Exception as e: + unimplemented_v2( + gb_type="Unexpected failure during itertools.accumulate() iteration", + context=f"call_function {self} {args} {kwargs}", + explanation="Unexpected failure in invoking function during accumulate. " + f"Failed running func {func}({item}{acc})", + hints=[*graph_break_hints.DIFFICULT], + from_exc=e, + ) + items.append(acc) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return variables.ListIteratorVariable( items, mutation_type=ValueMutationNew() ) @@ -190,6 +279,7 @@ def keyfunc(x): return variables.CountIteratorVariable( *args, mutation_type=ValueMutationNew() ) +<<<<<<< HEAD elif ( self.value is itertools.permutations and (len(args) == 1 or (len(args) == 2 and args[1].is_python_constant())) @@ -207,6 +297,11 @@ def keyfunc(x): ] return variables.ListIteratorVariable( items, mutation_type=ValueMutationNew() +======= + elif self.value is itertools.cycle: + return variables.CycleIteratorVariable( + *args, mutation_type=ValueMutationNew() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) else: return super().call_function(tx, args, kwargs) @@ -247,6 +342,7 @@ def has_force_unpack_var_sequence(self, tx) -> bool: return True +<<<<<<< HEAD class ObjectIteratorVariable(IteratorVariable): """ VariableTracker for iter(obj) that implements the iterator protocol (i.e., @@ -279,6 +375,8 @@ def next_variable(self, tx): raise +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class RepeatIteratorVariable(IteratorVariable): def __init__(self, item: VariableTracker, **kwargs) -> None: super().__init__(**kwargs) @@ -332,6 +430,57 @@ def reconstruct(self, codegen: "PyCodegen"): codegen.extend_output(create_call_function(2, False)) +<<<<<<< HEAD +======= +class CycleIteratorVariable(IteratorVariable): + def __init__( + self, + iterator: IteratorVariable, + saved: Optional[list[VariableTracker]] = None, + saved_index: int = 0, + item: Optional[VariableTracker] = None, + **kwargs, + ) -> None: + if saved is None: + saved = [] + super().__init__(**kwargs) + self.iterator = iterator + self.saved = saved + self.saved_index = saved_index + self.item = item + + def next_variable(self, tx): + assert self.is_mutable() + + if self.iterator is not None: + try: + new_item = self.iterator.next_variable(tx) + if len(self.saved) > MAX_ITERATOR_LIMIT: + unimplemented_v2( + gb_type="input iterator to itertools.cycle has too many items", + context=f"next({self})", + explanation=f"Has reached internal Dynamo max iterator limit: {MAX_ITERATOR_LIMIT}", + hints=[], + ) + tx.output.side_effects.mutation(self) + self.saved.append(new_item) + self.item = new_item + if self.item is None: + return self.next_variable(tx) + return self.item + except ObservedUserStopIteration: + handle_observed_exception(tx) + self.iterator = None + return self.next_variable(tx) + elif len(self.saved) > 0: + tx.output.side_effects.mutation(self) + self.saved_index = (self.saved_index + 1) % len(self.saved) + return self.item + else: + raise_observed_exception(StopIteration, tx) + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class ZipVariable(IteratorVariable): """ Represents zip(*iterables) @@ -345,7 +494,11 @@ class ZipVariable(IteratorVariable): def __init__( self, +<<<<<<< HEAD iterables: list[VariableTracker], +======= + iterables: list[Union[list[VariableTracker], VariableTracker]], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) strict: bool = False, **kwargs, ) -> None: @@ -379,10 +532,13 @@ def unpack_var_sequence(self, tx) -> list["VariableTracker"]: def next_variable(self, tx): assert self.is_mutable() +<<<<<<< HEAD if len(self.iterables) == 0: raise_observed_exception(StopIteration, tx) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) old_index = self.index args = [] @@ -546,10 +702,14 @@ def _next(): while True: item = _next() self.index += 1 +<<<<<<< HEAD if isinstance(self.fn, ConstantVariable) and self.fn.value is None: res = item else: res = self.fn.call_function(tx, [item], {}) +======= + res = self.fn.call_function(tx, [item], {}) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pred_res = variables.UserFunctionVariable( polyfills.predicate ).call_function(tx, [res], {}) diff --git a/torch/_dynamo/variables/lazy.py b/torch/_dynamo/variables/lazy.py index 44d346a48cd2a..a4b3c946dfb72 100644 --- a/torch/_dynamo/variables/lazy.py +++ b/torch/_dynamo/variables/lazy.py @@ -17,7 +17,10 @@ def __init__(self, value: Any, source: Any) -> None: assert source self.value = value self.source = source +<<<<<<< HEAD self.name_hint: Optional[str] = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.vt: Optional[VariableTracker] = None def realize(self) -> None: @@ -32,12 +35,17 @@ def realize(self) -> None: else: self.vt = builder.VariableBuilder(tx, self.source)(self.value) +<<<<<<< HEAD if self.name_hint is not None: self.vt.set_name_hint(self.name_hint) del self.value del self.source del self.name_hint +======= + del self.value + del self.source +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @final @@ -97,12 +105,15 @@ def peek_value(self) -> Any: assert not self.is_realized() return self._cache.value +<<<<<<< HEAD def set_name_hint(self, name: str) -> None: if self.is_realized(): self._cache.vt.set_name_hint(name) # type: ignore[union-attr] else: self._cache.name_hint = name +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __str__(self) -> str: if self.is_realized(): return repr(self.unwrap()) diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index 654bf2e756c47..328073ab0e354 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -19,7 +19,10 @@ class that handles its unique behaviors while integrating with Dynamo's import collections import inspect import operator +<<<<<<< HEAD import sys +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from typing import Optional, TYPE_CHECKING import torch @@ -28,7 +31,11 @@ class that handles its unique behaviors while integrating with Dynamo's from .. import graph_break_hints, polyfills, variables from ..bytecode_transformation import create_call_function, create_instruction from ..exc import raise_observed_exception, unimplemented_v2 +<<<<<<< HEAD from ..source import AttrSource, NamedTupleFieldsSource +======= +from ..source import AttrSource +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from ..utils import ( cmp_name_to_op_mapping, cmp_name_to_op_str_mapping, @@ -38,8 +45,11 @@ class that handles its unique behaviors while integrating with Dynamo's Lit, namedtuple_fields, odict_values, +<<<<<<< HEAD raise_args_mismatch, range_iterator, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) set_example_value, ) from .base import ValueMutationNew, VariableTracker @@ -111,9 +121,12 @@ def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker): index = arg.as_python_constant() if isinstance(index, slice): +<<<<<<< HEAD if index.step == 0: msg = ConstantVariable.create("slice step cannot be zero") raise_observed_exception(ValueError, tx, args=[msg]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Set source to None because slicing a list gives a new local return self.clone( items=self.items[index], @@ -142,12 +155,15 @@ def call_method( if name == "__getitem__": from .tensor import TensorVariable +<<<<<<< HEAD if len(args) != 1: msg = ConstantVariable.create( f"{name} takes exactly one argument ({len(args)} given)" ) raise_observed_exception(TypeError, tx, args=[msg]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert not kwargs and len(args) == 1 if isinstance(args[0], TensorVariable): value = get_fake_value(args[0].as_proxy().node, tx) @@ -164,6 +180,7 @@ def call_method( ) else: value = args[0] +<<<<<<< HEAD if value.python_type() not in (int, slice): msg = f"indices must be integers or slices, not {value.python_type()}" @@ -178,11 +195,20 @@ def call_method( if not len(args): raise_args_mismatch(tx, name) +======= + return self.getitem_const(tx, value) + elif name == "__contains__": + assert len(args) == 1 + assert not kwargs + return iter_contains(self.unpack_var_sequence(tx), args[0], tx) + elif name == "index": +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return tx.inline_user_function_return( VariableTracker.build(tx, polyfills.index), [self] + list(args), kwargs, ) +<<<<<<< HEAD elif name == "count": if len(args) != 1: raise_args_mismatch(tx, name) @@ -229,6 +255,9 @@ def call_method( if len(args) != 1: raise_args_mismatch(tx, name) +======= + elif name in cmp_name_to_op_mapping: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) left = self right = args[0] # TODO this type check logic mirrors the following @@ -278,6 +307,7 @@ def __init__(self, items, **kwargs) -> None: else: raise AssertionError +<<<<<<< HEAD def maybe_as_int(x): return ( ConstantVariable(int(x.value)) if isinstance(x, ConstantVariable) else x @@ -288,6 +318,8 @@ def maybe_as_int(x): step = maybe_as_int(step) stop = maybe_as_int(stop) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert stop is not None super().__init__([start, stop, step], **kwargs) @@ -374,12 +406,16 @@ def apply_index(self, index): index = length + index if index < 0 or index >= length: +<<<<<<< HEAD tx = torch._dynamo.symbolic_convert.InstructionTranslator.current_tx() raise_observed_exception( IndexError, tx, args=[ConstantVariable("range object index out of range")], ) +======= + raise IndexError(f"index {index} is out of range") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return variables.ConstantVariable.create(self.start() + (index * self.step())) @@ -413,11 +449,16 @@ def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker): if isinstance(index, slice): return self.apply_slice(index) +<<<<<<< HEAD elif isinstance(index, int): return self.apply_index(index) else: msg = ConstantVariable("range indices must be integers or slices") raise_observed_exception(TypeError, tx, args=[msg]) +======= + else: + return self.apply_index(index) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def as_proxy(self): return self.python_type()(*self._as_proxy()) @@ -433,6 +474,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen.foreach(self.items) codegen.extend_output(create_call_function(3, False)) +<<<<<<< HEAD def call_obj_hasattr( self, tx: "InstructionTranslator", name: str ) -> "VariableTracker": @@ -521,6 +563,19 @@ def var_getattr(self, tx: "InstructionTranslator", name): if name in fields: return self.items[fields.index(name)] return super().var_getattr(tx, name) +======= + def var_getattr(self, tx: "InstructionTranslator", name): + fields = ["start", "stop", "step"] + if name not in fields: + unimplemented_v2( + gb_type="Unsupported attribute for range() object", + context=f"var_getattr {self} {name}", + explanation=f"Expected attribute to be one of {','.join(fields)} " + f"but got {name}", + hints=[*graph_break_hints.USER_ERROR], + ) + return self.items[fields.index(name)] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class CommonListMethodsVariable(BaseListVariable): @@ -539,12 +594,16 @@ def call_method( if name == "append" and self.is_mutable(): assert not kwargs +<<<<<<< HEAD if len(args) != 1: raise_args_mismatch(tx, name) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) (arg,) = args tx.output.side_effects.mutation(self) self.items.append(arg) return ConstantVariable.create(None) +<<<<<<< HEAD elif name == "extend" and self.is_mutable(): if len(args) != 1 or kwargs: raise_args_mismatch(tx, name) @@ -553,14 +612,27 @@ def call_method( msg = ConstantVariable.create(f"{type(args[0])} object is not iterable") raise_observed_exception(TypeError, tx, args=[msg]) +======= + elif ( + name == "extend" + and self.is_mutable() + and args + and args[0].has_force_unpack_var_sequence(tx) + ): + assert not kwargs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) (arg,) = args arg.force_apply_to_var_sequence( tx, lambda item: self.call_method(tx, "append", [item], {}) ) return ConstantVariable.create(None) elif name == "insert" and self.is_mutable(): +<<<<<<< HEAD if kwargs or len(args) != 2: raise_args_mismatch(tx, name) +======= + assert not kwargs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) idx, value = args if isinstance(idx, SymNodeVariable): const_idx = idx.evaluate_expr() @@ -571,6 +643,7 @@ def call_method( return ConstantVariable.create(None) elif name == "pop" and self.is_mutable(): assert not kwargs +<<<<<<< HEAD if kwargs or len(args) > 1: raise_args_mismatch(tx, name) @@ -588,6 +661,12 @@ def call_method( elif name == "clear" and self.is_mutable(): if args or kwargs: raise_observed_exception(TypeError, tx) +======= + tx.output.side_effects.mutation(self) + return self.items.pop(*[a.as_python_constant() for a in args]) + elif name == "clear" and self.is_mutable(): + assert not kwargs and not args +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tx.output.side_effects.mutation(self) self.items.clear() return ConstantVariable.create(None) @@ -595,6 +674,7 @@ def call_method( name == "__setitem__" and self.is_mutable() and args +<<<<<<< HEAD and ( args[0].is_python_constant() or isinstance(args[0], SymNodeVariable) @@ -606,10 +686,14 @@ def call_method( ) ) ) +======= + and args[0].is_python_constant() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): assert not kwargs key, value = args tx.output.side_effects.mutation(self) +<<<<<<< HEAD if isinstance(key, SymNodeVariable): self.items[key.evaluate_expr()] = value elif isinstance(key, SliceVariable): @@ -674,6 +758,25 @@ def call_method( idx = self.call_method(tx, "index", args, kwargs) self.call_method(tx, "pop", [idx], {}) return ConstantVariable.create(None) +======= + if isinstance(key, SliceVariable): + self.items[key.as_python_constant()] = list(value.items) + else: + self.items[key.as_python_constant()] = value + return ConstantVariable.create(None) + elif name == "copy": + # List copy() doesn't have args and kwargs + assert not kwargs + assert not args + items = list(self.items) + return self.modified(items, mutation_type=ValueMutationNew()) + elif name == "reverse" and self.is_mutable(): + assert not kwargs + assert not args + self.items.reverse() + tx.output.side_effects.mutation(self) + return ConstantVariable.create(None) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: return super().call_method(tx, name, args, kwargs) @@ -699,6 +802,7 @@ def call_method( args: list["VariableTracker"], kwargs: dict[str, "VariableTracker"], ) -> "VariableTracker": +<<<<<<< HEAD from .tensor import SymNodeVariable if name == "__setitem__" and self.is_mutable(): @@ -742,6 +846,30 @@ def call_method( raise_observed_exception( type(e), tx, args=list(map(ConstantVariable.create, e.args)) ) +======= + if ( + name == "__setitem__" + and self.is_mutable() + and args + and args[0].is_python_constant() + ): + assert not kwargs + key, value = args + tx.output.side_effects.mutation(self) + if isinstance(key, SliceVariable): + if not value.has_force_unpack_var_sequence(tx): + unimplemented_v2( + gb_type="Unsupported conversion for slice assignment", + context=f"call_method {self} {name} {args}", + explanation=f"Missing dynamo support for converting {value} into a list for slice assignment.", + hints=[*graph_break_hints.SUPPORTABLE], + ) + self.items[key.as_python_constant()] = value.force_unpack_var_sequence( + tx + ) + else: + self.items[key.as_python_constant()] = value +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ConstantVariable.create(None) if name == "sort" and self.is_mutable(): @@ -1144,10 +1272,17 @@ class NamedTupleVariable(TupleVariable): *TupleVariable._nonvar_fields, } +<<<<<<< HEAD def __init__(self, items, tuple_cls, dynamic_attributes=None, **kwargs) -> None: super().__init__(items, **kwargs) self.tuple_cls = tuple_cls self.dynamic_attributes = {} if not dynamic_attributes else dynamic_attributes +======= + def __init__(self, items, tuple_cls, **kwargs) -> None: + super().__init__(items, **kwargs) + self.tuple_cls = tuple_cls + self.dynamic_attributes = {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def is_namedtuple(self): return isinstance(getattr(self.tuple_cls, "_fields", None), tuple) and callable( @@ -1247,10 +1382,13 @@ def check_and_create_method(): else: return None +<<<<<<< HEAD if name == "_fields": source = NamedTupleFieldsSource(self.source) if self.source else None return VariableTracker.build(tx, self.fields(), source=source) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if name in self.dynamic_attributes: return self.dynamic_attributes[name] @@ -1360,8 +1498,24 @@ def next_variable(self, tx): self.index += 1 return self.items[old_index] +<<<<<<< HEAD def call_obj_hasattr(self, tx, name): return variables.ConstantVariable.create(hasattr(iter([]), name)) +======= + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ): + if name == "__contains__": + assert len(args) == 1 + assert not kwargs + return iter_contains(self.items[self.index :], args[0], tx) + + return super().call_method(tx, name, args, kwargs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def python_type(self): return type(iter([])) @@ -1371,6 +1525,7 @@ def as_python_constant(self): raise NotImplementedError return iter([x.as_python_constant() for x in self.items]) +<<<<<<< HEAD def has_unpack_var_sequence(self, tx): return True @@ -1378,6 +1533,10 @@ def unpack_var_sequence(self, tx): r = list(self.items[self.index :]) self.index = len(self.items) return r +======= + def unpack_var_sequence(self, tx): + return list(self.items[self.index :]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def force_unpack_var_sequence(self, tx) -> list[VariableTracker]: return self.unpack_var_sequence(tx) @@ -1395,6 +1554,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None: class TupleIteratorVariable(ListIteratorVariable): pass +<<<<<<< HEAD class RangeIteratorVariable(IteratorVariable): @@ -1444,3 +1604,5 @@ def reconstruct(self, codegen: "PyCodegen"): codegen.append_output(codegen.create_load_const(self.step)) codegen.extend_output(create_call_function(3, False)) codegen.append_output(create_instruction("GET_ITER")) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 60086fe6758c7..e1b4c25e6c687 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -42,7 +42,10 @@ AttrSource, GenericAttrSource, GetItemSource, +<<<<<<< HEAD TypeMROSource, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TypeSource, WeakRefCallSource, ) @@ -135,7 +138,13 @@ def _resolved_getattr_and_source(self, tx: "InstructionTranslator", name): # Equivalent of something like type(L['self']).__mro__[1].attr_name if type_to_use_source: source = AttrSource( +<<<<<<< HEAD GetItemSource(TypeMROSource(type_to_use_source), index), +======= + GetItemSource( + AttrSource(type_to_use_source, "__mro__"), index + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) name, ) return resolved_getattr, source @@ -246,14 +255,24 @@ def call_method( # different from type(self) with polymorphism. cls_source = None if self.objvar.source: +<<<<<<< HEAD cls_source = TypeSource(self.objvar.source) +======= + cls_source = AttrSource(self.objvar.source, "__class__") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cls_variable = VariableTracker.build( tx, self.objvar.value_type, cls_source ) +<<<<<<< HEAD return variables.UserFunctionVariable( inner_fn.__func__, source=AttrSource(source, "__func__") ).call_function(tx, [cls_variable, *args], kwargs) +======= + return variables.UserMethodVariable( + inner_fn.__func__, cls_variable, source=source + ).call_function(tx, args, kwargs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif isinstance(inner_fn, types.FunctionType): return variables.UserFunctionVariable( inner_fn, source=source @@ -305,11 +324,14 @@ def call_method( ): return self.objvar._dict_vt.call_method(tx, name, args, kwargs) elif ( +<<<<<<< HEAD isinstance(self.objvar, variables.UserDefinedSetVariable) and inner_fn in self.objvar._set_methods ): return self.objvar._set_vt.call_method(tx, name, args, kwargs) elif ( +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) isinstance(self.objvar, variables.UserDefinedTupleVariable) and inner_fn in tuple_methods ): @@ -657,6 +679,7 @@ def __init__(self, fn_cls, **kwargs) -> None: def call_apply(self, tx: "InstructionTranslator", args, kwargs): requires_grad = False +<<<<<<< HEAD def visit(vt): nonlocal requires_grad if isinstance(vt, variables.TensorVariable): @@ -664,6 +687,15 @@ def visit(vt): requires_grad = True if isinstance(vt, variables.NNModuleVariable): if vt.is_training(tx): +======= + def visit(node): + nonlocal requires_grad + if isinstance(node, variables.TensorVariable): + if node.requires_grad is not False: + requires_grad = True + if isinstance(node, variables.NNModuleVariable): + if node.is_training(tx): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) requires_grad = True VariableTracker.visit(visit, (args, kwargs)) @@ -1010,7 +1042,11 @@ def call_method( ) -> "VariableTracker": if name == "queue_callback": if torch._dynamo.compiled_autograd.in_compiled_autograd_region: +<<<<<<< HEAD assert tx.one_graph or tx.error_on_graph_break, ( +======= + assert tx.one_graph, ( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True" ) return variables.UserFunctionVariable( @@ -1182,6 +1218,7 @@ def call_method( return super().call_method(tx, name, args, kwargs) +<<<<<<< HEAD def get_forwarded_dict(self, tx): assert ( self.name == "__dict__" @@ -1191,6 +1228,8 @@ def get_forwarded_dict(self, tx): self.obj.ban_mutation = True return VariableTracker.build(tx, self.obj.value.__dict__, self.source) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class MethodWrapperVariable(VariableTracker): def __init__(self, method_wrapper, **kwargs) -> None: diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index 10ad8c4a12865..c292e64de5247 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -101,6 +101,7 @@ def convert_to_fake(x): proxy_args, proxy_kwargs = proxy_args_kwargs(args, kwargs) fake_args = [convert_to_fake(arg) for arg in proxy_args] fake_kwargs = {k: convert_to_fake(v) for k, v in proxy_kwargs.items()} +<<<<<<< HEAD try: mod._infer_parameters(mod, fake_args, fake_kwargs) except AttributeError: @@ -108,6 +109,9 @@ def convert_to_fake(x): AttributeError, tx, ) +======= + mod._infer_parameters(mod, fake_args, fake_kwargs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @contextmanager @@ -909,11 +913,15 @@ def set_nn_module_stack_source(self, source): @functools.cache def _nn_module_method_ids(): # Allow __setattr__ to fall through to base class handler +<<<<<<< HEAD supported = { torch.nn.Module.__setattr__, torch.nn.Module.__init__, torch.nn.Module.__delattr__, } +======= + supported = {torch.nn.Module.__setattr__, torch.nn.Module.__init__} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return { id(x.__code__) for x in torch.nn.Module.__dict__.values() @@ -993,7 +1001,11 @@ def call_function( fn = self.value_type.forward if self.source: +<<<<<<< HEAD source = self.get_source_by_walking_mro(name) +======= + source = AttrSource(AttrSource(self.source, "__class__"), name) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: source = None @@ -1021,7 +1033,11 @@ def call_method( if name in ["_call_impl", "_wrapped_call_impl"]: fn = getattr(self.value_type, name) if self.source: +<<<<<<< HEAD source = self.get_source_by_walking_mro(name) +======= + source = AttrSource(AttrSource(self.source, "__class__"), name) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: source = None @@ -1036,7 +1052,13 @@ def call_method( method = None if isinstance(method, staticmethod): +<<<<<<< HEAD source = AttrSource(self.get_source_by_walking_mro(name), "__func__") +======= + source = AttrSource( + AttrSource(AttrSource(self.source, "__class__"), name), "__func__" + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return tx.inline_user_function_return( variables.UserFunctionVariable(method.__func__, source=source), args, @@ -1095,10 +1117,16 @@ def call_method( # Handle submodules self.is_state_mutated = True +<<<<<<< HEAD if ( method is torch.nn.Module.__setattr__ and isinstance(args[1], variables.DeletedVariable) ) or method is torch.nn.Module.__delattr__: +======= + if method is torch.nn.Module.__setattr__ and isinstance( + args[1], variables.DeletedVariable + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Trace through __delattr__ to track mutations on the module # members like `_modules``. return tx.inline_user_function_return( diff --git a/torch/_dynamo/variables/optimizer.py b/torch/_dynamo/variables/optimizer.py index 499c956843beb..3840b34cd6fc6 100644 --- a/torch/_dynamo/variables/optimizer.py +++ b/torch/_dynamo/variables/optimizer.py @@ -239,6 +239,17 @@ def map_sources_and_install_guards(self, tx): self.grad_to_source = {} self.tensor_to_source = {} +<<<<<<< HEAD +======= + # Tracing the _init_group is expensive. But we still have to insert the + # necessary guards for _init_group. So, we manually handle insertion of + # guards. We also want to mark all the tensors inside the state dict to + # be static address. + + # Mark all the tensors in the state dict to be static address. This has + # to be done first because the variable builder relies on the static + # address annotation. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def mark_static(x): mark_static_address(x) diff --git a/torch/_dynamo/variables/script_object.py b/torch/_dynamo/variables/script_object.py index a120ab488ed95..4ec9a4606357d 100644 --- a/torch/_dynamo/variables/script_object.py +++ b/torch/_dynamo/variables/script_object.py @@ -25,8 +25,12 @@ import torch +<<<<<<< HEAD from .. import graph_break_hints from ..exc import unimplemented_v2, UnsafeScriptObjectError, Unsupported +======= +from ..exc import unimplemented, UnsafeScriptObjectError, Unsupported +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .base import VariableTracker from .user_defined import UserDefinedObjectVariable @@ -76,6 +80,7 @@ def var_getattr(self, tx, name: str) -> VariableTracker: method = getattr(self.value, name, None) if method is None: +<<<<<<< HEAD unimplemented_v2( gb_type="FakeScriptObject missing method implementation", context=f"value={self.value}, method={name}", @@ -94,6 +99,16 @@ def var_getattr(self, tx, name: str) -> VariableTracker: hints=[ "Use method calls instead of attribute access.", ], +======= + unimplemented( + f"FakeScriptObject doesn't define method {name}. Did you forget to implement it in the fake class?" + ) + + if not callable(method): + unimplemented( + "Only method calls on TorchScript objects can be supported safely." + " Please use method calls instead of attribute access." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return TorchHigherOrderOperatorVariable.make( @@ -111,6 +126,7 @@ def var_getattr(self, tx, name: str) -> VariableTracker: "Dynamo cannot safely trace script object due to graph break." ) def call_method(self, tx, name, args, kwargs): +<<<<<<< HEAD unimplemented_v2( gb_type="Weird method call on TorchScript object", context=f"value={self.value}, method={name}", @@ -122,3 +138,6 @@ def call_method(self, tx, name, args, kwargs): "Avoid calling this method.", ], ) +======= + unimplemented(f"call method {name} on script object is not safe.") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 08dab47451abf..1a9538f34290a 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -316,18 +316,28 @@ def dynamic_getattr(self, tx: "InstructionTranslator", name): real_value = getattr(_input_associated_real_value, name) attr_source = AttrSource(self.source, name) +<<<<<<< HEAD +======= + install_guard(attr_source.make_guard(GuardBuilder.HASATTR)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Typically we'd want to use variable builder here # but unfortunately id(real_value.__self__) is not id() if is_bound_tensor_method(real_value): +<<<<<<< HEAD # No need to install the guard because its a bound tensor method +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .misc import GetAttrVariable return GetAttrVariable( self, name, source=attr_source, py_type=type(real_value) ) +<<<<<<< HEAD install_guard(attr_source.make_guard(GuardBuilder.HASATTR)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return VariableTracker.build(tx, real_value, attr_source) def method_attr_ndim(self, tx): @@ -1090,6 +1100,7 @@ def method___setitem__(self, key, value): *proxy_args_kwargs([self, key, value], {}), ) +<<<<<<< HEAD if isinstance(value, TensorVariable): # [Note: Tensor.__setitem__ and VariableTracker metadata] # At this point, we proxied a node representing `self[key] = value` into the graph. @@ -1114,6 +1125,8 @@ def method___setitem__(self, key, value): for k, v in specialized_props.items(): setattr(self, k, v) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if config.use_graph_deduplication or config.track_nodes_for_deduplication: tx.output.region_tracker.add_node_mutation(proxy.node, 0) @@ -1564,11 +1577,15 @@ def call_method( args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]", ) -> "VariableTracker": +<<<<<<< HEAD from ..exc import unimplemented_v2 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from ..utils import numpy_method_wrapper args, kwargs = self.patch_args(name, args, kwargs) +<<<<<<< HEAD if name == "astype": from .builtin import BuiltinVariable @@ -1594,6 +1611,8 @@ def call_method( ), hints=[*graph_break_hints.FUNDAMENTAL], ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if name in ["__len__", "size", "tolist"]: # delegate back to TensorVariable return super().call_method(tx, name, args, kwargs) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index bfebedc88d6eb..13bca9b834027 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -52,6 +52,7 @@ tracable_create_parameter, ) from ..device_interface import get_registered_device_interfaces +<<<<<<< HEAD from ..exc import raise_observed_exception, unimplemented_v2 from ..guards import GuardBuilder, install_guard from ..source import ( @@ -60,6 +61,11 @@ SyntheticLocalSource, TorchSource, ) +======= +from ..exc import unimplemented, unimplemented_v2 +from ..guards import GuardBuilder, install_guard +from ..source import CallFunctionNoArgsSource, SyntheticLocalSource +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from ..utils import ( check_unspec_or_constant_args, guard_if_dyn, @@ -77,7 +83,10 @@ ) from .dicts import ConstDictVariable from .distributed import DistributedVariable, ProcessGroupVariable +<<<<<<< HEAD from .functions import bind_args_cached +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .lists import ListVariable, TupleVariable from .torch_function import ( can_dispatch_torch_function, @@ -125,6 +134,7 @@ torch.autograd.graph.disable_saved_tensors_hooks, torch.cpu.amp.autocast_mode.autocast, torch.cuda.amp.autocast_mode.autocast, +<<<<<<< HEAD # We'll let Dynamo inline into the contextlib part of these context # manager instances, all the way till it invokes the wrapped function # itself (at which point we wrap it back to special context manager @@ -133,6 +143,10 @@ # This allows us to support calling functions decorated with these # context managers, without much extra effort or code dup. torch.nn.attention.sdpa_kernel.__wrapped__, # type: ignore[attr-defined] +======= + torch.nn.attention.sdpa_kernel, + torch.nn.attention._sdpa_kernel_variadic, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] ) @@ -196,7 +210,10 @@ def tracing_state_functions() -> dict[Callable[[], Any], Optional[bool]]: torch.jit.is_tracing: False, torch._C._get_tracing_state: None, torch.fx._symbolic_trace.is_fx_tracing: False, +<<<<<<< HEAD torch.fx._symbolic_trace.is_fx_symbolic_tracing: False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.onnx.is_in_onnx_export: False, torch._dynamo.external_utils.is_compiling: True, torch._utils.is_compiling: True, @@ -418,6 +435,7 @@ def call_function( return FSDPParamGroupUseTrainingStateVariable.create( tx, args[0], args[1].as_python_constant() ) +<<<<<<< HEAD elif self.value is torch.nn.attention.sdpa_kernel.__wrapped__: # type: ignore[attr-defined] name_to_arg_map = bind_args_cached( self.value, tx, self.source, args, kwargs @@ -425,6 +443,19 @@ def call_function( backends = name_to_arg_map["backends"].as_python_constant() set_priority = name_to_arg_map["set_priority"].as_python_constant() return SDPAKernelVariable.create(tx, backends, set_priority) +======= + elif self.value is torch.nn.attention.sdpa_kernel: + assert len(args) == 1 or (len(kwargs) == 1 and "backends" in kwargs) + backends = args[0] if len(args) == 1 else kwargs["backends"] + set_priority = kwargs["set_priority"] if "set_priority" in kwargs else False + return SDPAKernelVariable.create( + tx, backends.as_python_constant(), set_priority + ) + elif self.value is torch.nn.attention._sdpa_kernel_variadic: + return SDPAKernelVariable.create( + tx, [arg.as_python_constant() for arg in args] + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return super().call_function(tx, args, kwargs) @@ -597,6 +628,7 @@ def handle_torch_compile(self, tx: "InstructionTranslator", *args, **kwargs): # torch.compile is a no-op in dynamo return args[0] +<<<<<<< HEAD unimplemented_v2( gb_type="torch.compile call with > 1 args", context=f"args={args}, kwargs={kwargs}", @@ -606,6 +638,9 @@ def handle_torch_compile(self, tx: "InstructionTranslator", *args, **kwargs): *graph_break_hints.SUPPORTABLE, ], ) +======= + unimplemented("torch.compile is used as a decorator in the compiled frame") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register(*REWRITE_OPS_TO_TENSOR_SIZE_METHOD) def handle_tensor_size_rewrites(self, tx: "InstructionTranslator", input): @@ -632,6 +667,7 @@ def handle_use_deterministic_algorithms( self, tx: "InstructionTranslator", mode, warn_only=False ): if warn_only and warn_only.as_python_constant(): +<<<<<<< HEAD unimplemented_v2( gb_type="Attempted to use torch.use_deterministic_algorithms(warn_only=True)", context=f"mode={mode}, warn_only={warn_only}", @@ -641,6 +677,9 @@ def handle_use_deterministic_algorithms( *graph_break_hints.SUPPORTABLE, ], ) +======= + unimplemented("torch.use_deterministic_algorithms(warn_only=True)") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return DeterministicAlgorithmsVariable.create(tx, mode.as_python_constant()) @register(torch.are_deterministic_algorithms_enabled) @@ -691,6 +730,7 @@ def handle_device_interface_stream(self, tx: "InstructionTranslator", stream): @register(torch.from_numpy) def handle_from_numpy(self, tx: "InstructionTranslator", *args): if not config.trace_numpy: +<<<<<<< HEAD unimplemented_v2( gb_type="call `torch.from_numpy` with `torch._dynamo.config.trace_numpy=False`", context=f"trace_numpy={config.trace_numpy}", @@ -712,6 +752,11 @@ def handle_from_numpy(self, tx: "InstructionTranslator", *args): *graph_break_hints.USER_ERROR, ], ) +======= + unimplemented("torch.from_numpy. config.trace_numpy is False") + if not np: + unimplemented("torch.from_numpy. NumPy is not available") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return wrap_fx_proxy_cls( target_cls=TensorVariable, tx=tx, @@ -923,6 +968,7 @@ def handle_nested_tensor( from .lists import BaseListVariable if layout and layout.as_python_constant() == torch.strided: +<<<<<<< HEAD unimplemented_v2( gb_type="Attempted to use strided NestedTensor", context=f"layout={layout}", @@ -942,6 +988,11 @@ def handle_nested_tensor( *graph_break_hints.USER_ERROR, ], ) +======= + unimplemented("torch.compile does not support strided NestedTensor") + if not isinstance(tensor_list, BaseListVariable): + unimplemented("nested_tensor with non-list input") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register(torch.nn.functional.one_hot) def handle_one_hot(self, tx: "InstructionTranslator", *args, **kwargs): @@ -950,6 +1001,7 @@ def handle_one_hot(self, tx: "InstructionTranslator", *args, **kwargs): and args[1].is_python_constant() and args[1].as_python_constant() == -1 ): +<<<<<<< HEAD unimplemented_v2( gb_type="Attempted to use `torch.nn.functional.one_hot` with data-dependent output shape", context=f"args={args}, kwargs={kwargs}", @@ -958,6 +1010,10 @@ def handle_one_hot(self, tx: "InstructionTranslator", *args, **kwargs): "Explicitly set the `num_classes` param of the function call " "`torch.nn.functional.one_hot` to something other than -1.", ], +======= + unimplemented( + "torch.nn.functional.one_hot with data-dependent output shape" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @register(torch.fx.experimental.symbolic_shapes.guard_size_oblivious) @@ -1126,6 +1182,7 @@ def handle_pop_torch_function( ): assert not args and not kwargs if not tx.symbolic_torch_function_state.mode_stack: +<<<<<<< HEAD unimplemented_v2( gb_type="Attempted to pop from empty torch function mode stack", context="", @@ -1135,6 +1192,9 @@ def handle_pop_torch_function( *graph_break_hints.USER_ERROR, ], ) +======= + raise unimplemented("Popping from an empty torch function mode stack") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TorchFunctionModeStackVariable.register_mutation(tx) return tx.symbolic_torch_function_state.pop_torch_function_mode() @@ -1163,6 +1223,7 @@ def handle_get_stack_at(self, tx: "InstructionTranslator", *args, **kwargs): assert ind >= 0 and ind < len(tx.symbolic_torch_function_state.mode_stack) return tx.symbolic_torch_function_state.mode_stack[ind] +<<<<<<< HEAD @register(torch.get_device_module.__wrapped__) def handle_get_device_module(self, tx, *args, **kwargs): if len(args) + len(kwargs) > 1 or (kwargs and "device" not in kwargs): @@ -1201,6 +1262,8 @@ def handle_get_device_module(self, tx, *args, **kwargs): ) return VariableTracker.build(tx, module, new_source) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register(torch.set_default_device) def handle_set_default_device( self, tx: "InstructionTranslator", *args, **kwargs @@ -1263,6 +1326,7 @@ def call_function( arg_type = flat_arg_vt.python_type() if not is_graphable_type(arg_type): type_name = flat_arg_vt.python_type().__qualname__ +<<<<<<< HEAD unimplemented_v2( gb_type="Invalid input type for nonstrict_trace-ed function", context=f"Encountered input of type <{type_name}>.", @@ -1277,6 +1341,15 @@ def call_function( "* `torch.utils._pytree.register_dataclass`\n" "* `torch.utils._pytree.register_pytree_node`", ], +======= + unimplemented( + f""" +For `nonstrict_trace`-ed function, the only allowed input types are basic types (e.g., torch.Tensor, int, float) or pytree containers of those. Here you are calling the function with arguments that contain a value of type <{type_name}>, please use one of the following to register the type with pytree: + * `torch.utils._pytree.register_constant` + * `torch.utils._pytree.register_dataclass` + * `torch.utils._pytree.register_pytree_node` +""" # NOQA: B950 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # Since we checked with `is_graphable` above, `as_proxy` on the @@ -1297,6 +1370,7 @@ def call_function( import torch.utils._pytree as pytree if pytree.is_constant_class(typ): +<<<<<<< HEAD unimplemented_v2( gb_type="Input marked with `pytree.register_constant` constructed in the `torch.compile` region", context=f"Input={input_spec_vt}, offending type <{type_name}>.", @@ -1328,6 +1402,27 @@ def call_function( *graph_break_hints.SUPPORTABLE, ], from_exc=e, +======= + unimplemented( + f""" +You are calling a `nonstrict_trace`-ed function with an input that contains an object of type <{type_name}>, which was marked with `pytree.register_constant`. However, the object was constructed _inside_ the `torch.compile` region. + +Please construct the object _outside_ the `torch.compile` region, or submit an issue to GitHub. + """ # NOQA: B950 + ) + else: + unimplemented( + f""" +You are calling a `nonstrict_trace`-ed function where one one of the inputs has been registered with a `pytree_flatten` that puts an object of type <{type_name}> into the context. + +Please consider modifying that `pytree_flatten` to avoid putting the object into context, and apply one of the following to <{type_name}> + * `torch.utils._pytree.register_constant` + * `torch.utils._pytree.register_dataclass` + * `torch.utils._pytree.register_pytree_node` + +If the above doesn't work, please subtmit an issue to GitHub. +""" # NOQA: B950 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) fn = self.value @@ -1362,6 +1457,7 @@ def patched_fn(*args, **kwargs): # 2. Create a proxy call to `flat_apply`, then fake-tensor propagate # the call and wrap output into a VariableTracker. proxy = tx.output.create_proxy("call_function", flat_apply, all_args, {}) +<<<<<<< HEAD try: # TODO support more output types once `flat_apply` supports # pytree-able output types. We can have Dynamo trace through an @@ -1383,6 +1479,14 @@ def patched_fn(*args, **kwargs): ), hints=[*graph_break_hints.SUPPORTABLE], ) +======= + out_vt = wrap_fx_proxy(tx, proxy) + # TODO support more output types + # Q: flat_apply will likely pytree_flatten the output for this, then + # how do we intercept the output before flatten, and wrap those? + # - Maybe we can have `flat_apply` return the output spec, so that + # Dynamo can unflatten and wrap the result. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return out_vt @@ -1397,6 +1501,7 @@ def patched_fn(*args, **kwargs): source = CallFunctionNoArgsSource(self.source) install_guard(source.make_guard(GuardBuilder.EQUALS_MATCH)) # constant fold +<<<<<<< HEAD try: return ConstantVariable.create( self.as_python_constant()( @@ -1410,6 +1515,14 @@ def patched_fn(*args, **kwargs): tx, args=list(map(ConstantVariable.create, exc.args)), ) +======= + return ConstantVariable.create( + self.as_python_constant()( + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ), + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.is_tensor_method(): name = self.value.__name__ @@ -1460,6 +1573,7 @@ def patched_fn(*args, **kwargs): For now, dynamo will explicitly graph break when it encounters user code with this behavior. """ log.warning(msg) +<<<<<<< HEAD unimplemented_v2( gb_type="Attempted to call torch in-graph function on only torch.SymInt arguments", context=f"fn={self.value}, args={args}, kwargs={kwargs}", @@ -1471,6 +1585,9 @@ def patched_fn(*args, **kwargs): *graph_break_hints.SUPPORTABLE, ], ) +======= + unimplemented(msg) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO(voz): Replace w/ dynamic shape rewrite table. # Ideally, we would be able to do this at ctor time, but alas we need a combination @@ -1487,6 +1604,7 @@ def patched_fn(*args, **kwargs): # variant torch ops, the original function could come from a user # defined `@allow_in_graph` function as well, which doesn't have the # same semantics as the torch ops. +<<<<<<< HEAD # Calling fake tensor propagation can mutate the out= tensor in # tx.output.tracked_fakes. tracked_fakes are used to apply @@ -1512,6 +1630,17 @@ def patched_fn(*args, **kwargs): # e.g., out=output_tensor if isinstance(out_kwarg_vt, variables.TensorVariable): saved_out_shapes = out_kwarg_vt.proxy.node.meta["example_value"].shape +======= + fake_out_shape = None + if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable): + # Calling fake tensor propagation can mutate the out= tensor in + # tx.output.tracked_fakes. tracked_fakes are used to apply + # symbolic_shape guards. Mutating them destroys the information + # prior to tracing, which is essential for creating right + # guards. So save the shape now, and check later if it has + # changed. If it has, graph break. + fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tensor_variable = wrap_fx_proxy( tx=tx, @@ -1528,6 +1657,7 @@ def patched_fn(*args, **kwargs): and "requires_grad" in kwargs and kwargs["requires_grad"].as_python_constant() ): +<<<<<<< HEAD unimplemented_v2( gb_type="Attempted to use tensor creation function with requires_grad=True", context=f"fn={self.value}, args={args}, kwargs={kwargs}", @@ -1541,6 +1671,18 @@ def patched_fn(*args, **kwargs): # Handle e.g., `torch.add(a, b, out=result)` if saved_out_shapes is not None: +======= + unimplemented( + """factory functions that return tensors that require grad are not supported. +Either create the tensor outside the compiled region, or do not set the tensor to require_grad""" + ) + + # Handle e.g., `torch.add(a, b, out=result)` + if "out" in kwargs and not ( + isinstance(kwargs["out"], variables.ConstantVariable) + and kwargs["out"].as_python_constant() is None + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # out variants of torch operators like torch.sort and torch.sigmoid # mutate the tensors in the out field. # @@ -1552,6 +1694,7 @@ def patched_fn(*args, **kwargs): # Note that although these tensor variablels would hold different # proxies, the in-place mutation semantics is preserved in the FX # graph, so we won't have correctness issues. +<<<<<<< HEAD if isinstance(saved_out_shapes, list): for out_tensor_vt, saved_out_shape in zip( out_kwarg_vt.items, # type: ignore[union-attr] @@ -1619,6 +1762,63 @@ def patched_fn(*args, **kwargs): *graph_break_hints.SUPPORTABLE, ], ) +======= + if isinstance(tensor_variable, TupleVariable): + assert isinstance(kwargs["out"], (TupleVariable, ListVariable)) + for out_tensor, result_tensor in zip( + kwargs["out"].items, tensor_variable.items + ): + if ( + isinstance(out_tensor, variables.TensorVariable) + and isinstance(result_tensor, variables.TensorVariable) + and out_tensor._size + != result_tensor._size # we actually want to compare None values here + ): + # It's hard to get out variants with resizing on graph inputs work + # properly across dynamo/aot/inductor, just fall back. + unimplemented("out variants with resizing on graph inputs") + elif isinstance(tensor_variable, TensorVariable): + assert isinstance(kwargs["out"], TensorVariable) + assert "example_value" in kwargs["out"].proxy.node.meta + fake_tensor = tensor_variable.proxy.node.meta["example_value"] + fake_out = kwargs["out"].proxy.node.meta["example_value"] + if fake_out_shape != fake_tensor.shape: + # It's hard to get out variants with resizing on graph inputs work + # properly across dynamo/aot/inductor, just fall back. + unimplemented("out variants with resizing on graph inputs") + if not torch._prims_common.is_contiguous(fake_out): + # It's difficult to handle strides correctly in functionalization + # when calling an out= op with a non-contiguous out argument + unimplemented( + "out= op was called where output tensor was non-contiguous" + ) + elif ( + isinstance(tensor_variable, ConstantVariable) + and tensor_variable.value is None + ): + # Handle out-variant custom ops that return None. + if isinstance(kwargs["out"], TensorVariable): + assert "example_value" in kwargs["out"].proxy.node.meta + fake_out = kwargs["out"].proxy.node.meta["example_value"] + if not torch._prims_common.is_contiguous(fake_out): + # It's difficult to handle strides correctly in functionalization + # when calling an out= op with a non-contiguous out argument + unimplemented( + "out= op was called where output tensor was non-contiguous" + ) + elif isinstance(kwargs["out"], ListVariable): + for idx, x in enumerate(kwargs["out"].items): + assert "example_value" in x.proxy.node.meta # type: ignore[attr-defined] + fake_out = x.proxy.node.meta["example_value"] # type: ignore[attr-defined] + if not torch._prims_common.is_contiguous(fake_out): + # It's difficult to handle strides correctly in functionalization + # when calling an out= op with a non-contiguous out argument + unimplemented( + "out= op was called where some of the output tensors were non-contiguous" + ) + else: + unimplemented(f"out variant of {type(kwargs['out'])}") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return tensor_variable @@ -1642,6 +1842,7 @@ def handle_ntuple(value): torch.nn.modules.utils._ntuple(count)(value.as_python_constant()), ) else: +<<<<<<< HEAD unimplemented_v2( gb_type="Attempted to use `torch.nn.modules.utils._ntuple` with unsupported argument type", context=f"value={value}", @@ -1650,6 +1851,9 @@ def handle_ntuple(value): "Change use of _ntuple with argument as constant or tensor.", ], ) +======= + unimplemented(f"torch.nn.modules.utils._ntuple({value})") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.value is torch.nn.modules.utils._ntuple: return variables.LambdaVariable(handle_ntuple) @@ -1660,6 +1864,7 @@ def handle_ntuple(value): def call_nn_parameter(cls, tx, data=None, requires_grad=True): """A call to torch.nn.Parameter() gets lifted to before the graph""" if tx.export: +<<<<<<< HEAD unimplemented_v2( gb_type="Attempted to use `torch.nn.Parameter()` with export", context="", @@ -1669,11 +1874,15 @@ def call_nn_parameter(cls, tx, data=None, requires_grad=True): *graph_break_hints.SUPPORTABLE, ], ) +======= + unimplemented("nn parameter construction not supported with export") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(requires_grad, variables.VariableTracker): try: requires_grad = requires_grad.as_python_constant() except NotImplementedError: +<<<<<<< HEAD unimplemented_v2( gb_type="non-constant `requires_grad` argument to `torch.nn.Parameter`", context=f"requires_grad={requires_grad}", @@ -1694,11 +1903,18 @@ def call_nn_parameter(cls, tx, data=None, requires_grad=True): *graph_break_hints.USER_ERROR, ], ) +======= + unimplemented("Parameter(requires_grad=...) not constant") + + if not isinstance(data, variables.TensorVariable): + unimplemented(f"Parameter(data={data}) not implemented") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # this results in cleaner graphs, but only works for inputs if data.source: return cls._nn_param_via_prefix_insert(tx, data, requires_grad) +<<<<<<< HEAD if config.graph_break_on_nn_param_ctor: # Need user to manually move since we cannot unimplemented_v2( @@ -1735,12 +1951,22 @@ def call_nn_parameter(cls, tx, data=None, requires_grad=True): *graph_break_hints.DIFFICULT, ], ) +======= + if isinstance( + data, TensorWithTFOverrideVariable + ) or is_traceable_wrapper_subclass_type(data.class_type): + unimplemented("Parameter constructor with tensor subclass NYI") + + if not can_convert_to_tracable_parameter(): + unimplemented("Workaround for issues with nn_parameter construction") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: shape = tuple(data.var_getattr(tx, "shape").as_python_constant()) dtype = data.var_getattr(tx, "dtype").as_python_constant() device = data.var_getattr(tx, "device").as_python_constant() except NotImplementedError as e: +<<<<<<< HEAD unimplemented_v2( gb_type="`torch.nn.Parameter` with non-constant Tensor attributes", context=f"data={data}", @@ -1751,6 +1977,9 @@ def call_nn_parameter(cls, tx, data=None, requires_grad=True): ], from_exc=e, ) +======= + unimplemented(f"Parameter not python_constant: {e}") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) placeholder = tx.output.synthetic_graph_input( new_parameter_placeholder, [shape, dtype, device, requires_grad] @@ -1776,7 +2005,11 @@ def call_nn_parameter(cls, tx, data=None, requires_grad=True): result.class_type = torch.nn.Parameter # TODO(jansel/bdhirsh) - There is some issue with +<<<<<<< HEAD # tracable_create_parameter. It does not seem to use the right +======= + # tracable_create_paramter. It does not seem to use the right +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # grad_enabled. Since this is parameter, we can just override the # has_grad_fn field to False to workaround the issue. result.has_grad_fn = False @@ -1791,8 +2024,12 @@ def _nn_param_via_prefix_insert(tx: "InstructionTranslator", data, requires_grad varname = tx.output.new_var() # construct the nn.Parameter before the graph save it to varname +<<<<<<< HEAD assert tx.output.root_tx is not None cg = PyCodegen(tx.output.root_tx) +======= + cg = PyCodegen(tx) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cg.add_push_null(lambda: cg.load_import_from("torch.nn", "Parameter")) cg(data.source) cg(variables.ConstantVariable(requires_grad)) @@ -1802,6 +2039,7 @@ def _nn_param_via_prefix_insert(tx: "InstructionTranslator", data, requires_grad data_node = data.as_proxy().node if data_node.op not in ("placeholder", "get_attr"): +<<<<<<< HEAD unimplemented_v2( gb_type="Unexpected type of data placeholder op for parameter construction", context=f"data_node.op={data_node.op}", @@ -1809,13 +2047,21 @@ def _nn_param_via_prefix_insert(tx: "InstructionTranslator", data, requires_grad hints=[ *graph_break_hints.DIFFICULT, ], +======= + unimplemented( + "Unexpected type of data placeholder op for parameter construction" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # add the newly constructed nn.Parameter as a graph input source = SyntheticLocalSource(varname) example_value = torch.nn.Parameter( +<<<<<<< HEAD tx.output.example_value_from_input_node(data.as_proxy().node), requires_grad=requires_grad, +======= + tx.output.example_value_from_input_node(data.as_proxy().node) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) result = VariableTracker.build(tx, example_value, source) # Realize the VT because we will delete the guards on it in the next line. diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 4458468d8118c..110440d4e78ea 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -38,6 +38,10 @@ from torch._guards import Source from torch.overrides import ( _get_overloaded_args, +<<<<<<< HEAD +======= + BaseTorchFunctionMode, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) get_default_nowrap_functions, TorchFunctionMode, ) @@ -59,7 +63,11 @@ from .base import VariableTracker from .constant import ConstantVariable from .ctx_manager import GenericContextWrappingVariable +<<<<<<< HEAD from .functions import UserFunctionVariable, UserMethodVariable +======= +from .functions import UserMethodVariable +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .lazy import LazyVariableTracker from .lists import TupleVariable from .tensor import TensorSubclassVariable, TensorVariable @@ -123,6 +131,76 @@ operator.length_hint, ] +<<<<<<< HEAD +======= +BUILTIN_TO_TENSOR_FN_MAP = {} + +# These functions represent the r* versions of the above ops +# Basically, if __add__(1, Tensor) is called, it is translated +# to __radd__(Tensor, 1). +# In the builtin var, we check if there is a tensor in the first args position, +# if not, we swap the args and use the r* version of the op. +BUILTIN_TO_TENSOR_RFN_MAP = {} + + +def populate_builtin_to_tensor_fn_map(): + global BUILTIN_TO_TENSOR_FN_MAP + + most_recent_func = None + + class GetMethodMode(BaseTorchFunctionMode): + """ + Mode to extract the correct methods from torch function invocations + (Used to get the correct torch.Tensor methods from builtins) + """ + + def __torch_function__(self, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + nonlocal most_recent_func + most_recent_func = func + return func(*args, **kwargs) + + inp0 = torch.ones(1) + inp1 = torch.ones(1) + inp0_int = torch.ones(1, dtype=torch.int32) + inp1_int = torch.ones(1, dtype=torch.int32) + with GetMethodMode(): + setups_and_oplists = [ + (lambda o: o(inp0), un_ops), + (lambda o: o(inp0_int), un_int_ops), + (lambda o: o(inp0, inp1), bin_ops), + (lambda o: o(inp0_int, inp1_int), bin_int_ops), + (lambda o: o(inp0_int, 0), tensor_and_int_ops), + ] + for setup_fn, op_list in setups_and_oplists: + for op in op_list: + setup_fn(op) + assert most_recent_func is not None + BUILTIN_TO_TENSOR_FN_MAP[op] = most_recent_func + + # gather the reverse functions + rsetups_and_oplists = [ + ( + lambda o: o(1, inp1), + bin_ops, + ), # Get r* ops, (ex. __sub__(int, Tensor) -> __rsub__(Tensor, int)) + (lambda o: o(1, inp1_int), bin_int_ops), + (lambda o: o(0, inp0_int), tensor_and_int_ops), + ] + + rskips = {operator.matmul, operator.imatmul, operator.getitem} + for setup_fn, op_list in rsetups_and_oplists: + for op in op_list: + if op in rskips: + continue + setup_fn(op) + assert most_recent_func is not None + if most_recent_func != BUILTIN_TO_TENSOR_FN_MAP[op]: + BUILTIN_TO_TENSOR_RFN_MAP[op] = most_recent_func + + +populate_builtin_to_tensor_fn_map() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) banned_attrs = [ fn.__self__.__name__ @@ -389,7 +467,11 @@ def _flatten_vts(vts): output = [] while vts: +<<<<<<< HEAD vt = vts.popleft() +======= + vt = vts.pop() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not vt.is_realized() and vt.peek_type() in (dict, list, tuple): vt.realize() @@ -397,10 +479,15 @@ def _flatten_vts(vts): if vt.is_realized(): if isinstance(vt, ListVariable): vts.extend(vt.items) +<<<<<<< HEAD continue elif isinstance(vt, ConstDictVariable): vts.extend(vt.items.values()) continue +======= + elif isinstance(vt, ConstDictVariable): + vts.extend(vt.items.values()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) output.append(vt) @@ -620,8 +707,13 @@ def var_getattr(self, tx: "InstructionTranslator", name): elif isinstance(attr, property): getter_source = AttrSource(attr_source, "fget") getter = attr.fget +<<<<<<< HEAD getter_var = UserFunctionVariable(getter, source=getter_source) return getter_var.call_function(tx, [self], {}) +======= + getter_var = UserMethodVariable(getter, self, source=getter_source) + return getter_var.call_function(tx, [], {}) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif isinstance(attr, classmethod): return UserMethodVariable( diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 9c28ceb762b09..a0dee0c1e4cb2 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -10,9 +10,13 @@ attribute access, and other Python object behaviors. - Specialized subclasses for common patterns: - UserDefinedDictVariable: For dict subclasses +<<<<<<< HEAD - UserDefinedSetVariable: For set subclasses - UserDefinedTupleVariable: For tuple subclasses - UserDefinedExceptionObjectVariable: For exception subclasses +======= + - UserDefinedTupleVariable: For tuple subclasses +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - FrozenDataClassVariable: Special handling of frozen dataclasses - MutableMappingVariable: For collections.abc.MutableMapping subclasses @@ -46,53 +50,85 @@ from torch._guards import TracingContext from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type +<<<<<<< HEAD from .. import graph_break_hints, polyfills, variables +======= +from .. import polyfills, variables +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from ..bytecode_transformation import create_call_function from ..create_parameter_op import do_not_convert_to_tracable_parameter from ..exc import ( handle_observed_exception, ObservedAttributeError, +<<<<<<< HEAD ObservedKeyError, ObservedTypeError, ObservedUserStopIteration, raise_observed_exception, unimplemented_v2, +======= + raise_observed_exception, + unimplemented, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) from ..guards import GuardBuilder, install_guard from ..source import ( AttrSource, CallFunctionNoArgsSource, DataclassFieldsSource, +<<<<<<< HEAD DictGetItemSource, GetItemSource, RandomValueSource, TypeDictSource, TypeMROSource, +======= + GetItemSource, + RandomValueSource, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TypeSource, UnspecializedParamBufferSource, ) from ..utils import ( +<<<<<<< HEAD check_constant_args, cmp_name_to_op_mapping, dict_methods, frozenset_methods, +======= + build_checkpoint_variable, + check_constant_args, + cmp_name_to_op_mapping, + dict_methods, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) get_custom_getattr, has_torch_function, is_frozen_dataclass, is_lru_cache_wrapped_function, is_namedtuple_cls, +<<<<<<< HEAD +======= + is_utils_checkpoint, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) is_wrapper_or_member_descriptor, istype, list_methods, namedtuple_fields, object_has_getattribute, proxy_args_kwargs, +<<<<<<< HEAD set_methods, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tensortype_to_dtype, tuple_methods, unpatched_nn_module_getattr, ) +<<<<<<< HEAD from .base import ValueMutationNew, VariableTracker +======= +from .base import AttributeMutationExisting, ValueMutationNew, VariableTracker +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .dicts import DefaultDictVariable from .lists import SizeVariable @@ -139,6 +175,7 @@ def is_forbidden_context_manager(ctx): return ctx in f_ctxs +<<<<<<< HEAD def is_cython_function(obj): return ( callable(obj) @@ -147,6 +184,8 @@ def is_cython_function(obj): ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class UserDefinedVariable(VariableTracker): value: object @@ -157,10 +196,13 @@ class UserDefinedClassVariable(UserDefinedVariable): def __init__(self, value, **kwargs) -> None: super().__init__(**kwargs) self.value = value +<<<<<<< HEAD # Used when we materialize class.__dict__ to a MappingProxyObject. In # this case, we don't want to allow mutation in the class because there # is no way to reflect it in the created MappingProxyVariable. self.ban_mutation = False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def as_python_constant(self): return self.value @@ -224,8 +266,11 @@ def supported_c_new_functions(): return { object.__new__, dict.__new__, +<<<<<<< HEAD set.__new__, frozenset.__new__, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tuple.__new__, list.__new__, }.union(exceptions) @@ -258,9 +303,12 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracke elif name == "__dict__": options = {"source": source} return variables.GetAttrVariable(self, name, **options) +<<<<<<< HEAD elif name == "__mro__": attr_source = self.source and TypeMROSource(self.source) return VariableTracker.build(tx, self.value.__mro__, attr_source) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Special handling of collections.OrderedDict.fromkeys() # Wrap it as GetAttrVariable(collections.OrderedDict, "fromkeys") to make it consistent with @@ -303,15 +351,25 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracke func = obj.__get__(None, self.value) return VariableTracker.build(tx, func, source) elif source: +<<<<<<< HEAD if inspect.ismemberdescriptor(obj): +======= + # __mro__ is a member in < 3.12, an attribute in >= 3.12 + if inspect.ismemberdescriptor(obj) or ( + sys.version_info >= (3, 12) and name == "__mro__" + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return VariableTracker.build(tx, obj.__get__(self.value), source) if ConstantVariable.is_literal(obj): return ConstantVariable.create(obj) elif isinstance(obj, enum.Enum): return EnumVariable(obj) +<<<<<<< HEAD elif self.value is collections.OrderedDict: return variables.GetAttrVariable(self, name) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif name in getattr(self.value, "__dict__", {}) or ( self.value.__module__.startswith("torch.") or self.value.__module__ == "torch" @@ -419,18 +477,24 @@ def call_method( return BuiltinVariable.call_custom_dict_fromkeys( tx, self.value, *args, **kwargs ) +<<<<<<< HEAD elif self.value is collections.OrderedDict and name == "move_to_end": return args[0].call_method(tx, name, [*args[1:]], kwargs) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif name == "__eq__" and len(args) == 1 and hasattr(args[0], "value"): return variables.ConstantVariable(self.value == args[0].value) elif name == "__ne__" and len(args) == 1 and hasattr(args[0], "value"): return variables.ConstantVariable(self.value != args[0].value) +<<<<<<< HEAD elif issubclass(self.value, dict) and name != "__new__": # __new__ is handled below return variables.BuiltinVariable(dict).call_method(tx, name, args, kwargs) elif issubclass(self.value, (set, frozenset)) and name != "__new__": # __new__ is handled below return variables.BuiltinVariable(set).call_method(tx, name, args, kwargs) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif ( name == "__new__" and self.value is collections.OrderedDict @@ -450,6 +514,7 @@ def call_method( args[0], args[1:], ) +<<<<<<< HEAD elif name == "__setattr__" and self.ban_mutation: unimplemented_v2( gb_type="Class attribute mutation when the __dict__ was already materialized", @@ -457,6 +522,8 @@ def call_method( explanation="Dyanmo does not support tracing mutations on a class when its __dict__ is materialized", hints=graph_break_hints.SUPPORTABLE, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return super().call_method(tx, name, args, kwargs) def call_function( @@ -484,7 +551,11 @@ def call_function( # import here to avoid circular dependency from .ctx_manager import NullContextVariable +<<<<<<< HEAD return NullContextVariable(*args, **kwargs) +======= + return NullContextVariable() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif self.value is collections.OrderedDict: return tx.inline_user_function_return( VariableTracker.build(tx, polyfills.construct_dict), @@ -504,6 +575,7 @@ def call_function( ) elif is_typeddict(self.value): if self.value.__optional_keys__: +<<<<<<< HEAD unimplemented_v2( gb_type="TypedDict with optional keys", context=str(self.value), @@ -556,6 +628,32 @@ def deque_signature(iterable=None, maxlen=None): if "maxlen" in bound_args.arguments: maxlen = bound_args.arguments["maxlen"] +======= + unimplemented("TypedDict with optional keys not supported") + return variables.BuiltinVariable(dict).call_dict(tx, *args, **kwargs) + elif self.value is collections.deque: + maxlen = variables.ConstantVariable.create(None) + if not kwargs: + if len(args) == 0: + items = [] + elif len(args) == 1 and args[0].has_force_unpack_var_sequence(tx): + items = args[0].force_unpack_var_sequence(tx) + elif len(args) == 2 and args[0].has_force_unpack_var_sequence(tx): + items = args[0].force_unpack_var_sequence(tx) + maxlen = args[1] + else: + unimplemented("deque() with more than 2 arg not supported") + elif tuple(kwargs) == ("maxlen",): + maxlen = kwargs["maxlen"] + if len(args) == 0: + items = [] + if len(args) == 1 and args[0].has_force_unpack_var_sequence(tx): + items = args[0].force_unpack_var_sequence(tx) + else: + unimplemented("deque() with more than 1 arg not supported") + else: + unimplemented("deque() with invalid kwargs not supported") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return variables.lists.DequeVariable( items, maxlen=maxlen, mutation_type=ValueMutationNew() ) @@ -567,6 +665,7 @@ def deque_signature(iterable=None, maxlen=None): return variables.WeakRefVariable(args[0], callback) elif self.value is functools.partial: if not args: +<<<<<<< HEAD unimplemented_v2( gb_type="missing args to functools.partial", context="", @@ -576,6 +675,9 @@ def deque_signature(iterable=None, maxlen=None): *graph_break_hints.USER_ERROR, ], ) +======= + unimplemented("functools.partial malformed") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # The first arg, a callable (the ctor below will assert on types) fn = args[0] rest_args = args[1:] @@ -602,7 +704,10 @@ def deque_signature(iterable=None, maxlen=None): and self.source and not is_forbidden_context_manager(self.value) ): +<<<<<<< HEAD from . import TorchCtxManagerClassVariable +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .functions import ( BaseUserFunctionVariable, FunctionDecoratedByContextlibContextManagerVariable, @@ -622,6 +727,7 @@ def deque_signature(iterable=None, maxlen=None): ): # We are not changing the behavior of Dynamo as these function were # already ignored on trace_rules.py before #136033 landed +<<<<<<< HEAD unimplemented_v2( gb_type="unsupported contextlib.* API", context=f"{self.value}", @@ -668,6 +774,19 @@ def deque_signature(iterable=None, maxlen=None): kwargs_dict = args[2].keys_as_python_constant() return fn_var.call_function(tx, args_list, kwargs_dict) +======= + unimplemented( + f"{self.value} not supported. This may be due to its use of " + "context-specific operations that are not supported in " + "Dynamo yet (i.e. Exception handling)" + ) + + if self.value is contextlib._GeneratorContextManager and isinstance( + args[0], BaseUserFunctionVariable + ): + if not torch._dynamo.config.enable_trace_contextlib: + unimplemented("contextlib.contextmanager") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Wrap UserFunctionVariable in FunctionDecoratedByContextlibContextManagerVariable # if the function is annotated with @contextlib.contextmanager # This shouldn't be necessary once generator functions are fully @@ -1033,6 +1152,7 @@ def call_method( if torch._dynamo.config.enable_faithful_generator_behavior and isinstance( self.value, types.GeneratorType ): +<<<<<<< HEAD unimplemented_v2( gb_type="call_method on generator", context=f"object={self.value}, method={name}, args={args}, kwargs={kwargs}", @@ -1051,14 +1171,31 @@ def call_method( source_fn = None if source: source_fn = self.get_source_by_walking_mro(name) +======= + unimplemented("Generator as graph argument is not supported") + + # check for methods implemented in C++ + if isinstance(method, types.FunctionType): + source = ( + None + if self.source is None + else AttrSource(AttrSource(self.source, "__class__"), name) + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO(jansel): add a guard to check for monkey patching? from ..mutation_guard import unpatched_nn_module_init if method is torch.nn.Module.__init__: method = unpatched_nn_module_init +<<<<<<< HEAD return UserMethodVariable( method, self, source_fn=source_fn, source=source ).call_function(tx, args, kwargs) +======= + return UserMethodVariable(method, self, source=source).call_function( + tx, args, kwargs + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if method is list.__len__ and self.source and not (args or kwargs): install_guard(self.source.make_guard(GuardBuilder.SEQUENCE_LENGTH)) @@ -1072,6 +1209,7 @@ def method_setattr_standard( try: name = name.as_python_constant() except NotImplementedError: +<<<<<<< HEAD unimplemented_v2( gb_type="non-const setattr name on user-defined object", context=f"object={self}, name={name}, value={value}", @@ -1082,6 +1220,11 @@ def method_setattr_standard( "Attempted setattr on a user-defined object that does not have " "an AttributeMutation mutation_type" ) +======= + unimplemented(f"non-const setattr name: {name}") + if not tx.output.side_effects.is_attribute_mutation(self): + unimplemented(f"setattr({self}, {name}, ...)") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if directly_update_dict: self.attrs_directly_modifed_on_dict.add(name) @@ -1130,6 +1273,7 @@ def unpack_var_sequence(self, tx): ] return super().unpack_var_sequence(tx) +<<<<<<< HEAD def has_force_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool: try: variables.BuiltinVariable(iter).call_function(tx, [self], {}) @@ -1151,6 +1295,8 @@ def force_unpack_var_sequence(self, tx): break return result +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def next_variable(self, tx): return self.call_method(tx, "__next__", [], {}) @@ -1198,6 +1344,7 @@ def call_function( ).call_function(tx, [var], kwargs) if self.source is None: +<<<<<<< HEAD unimplemented_v2( gb_type="attempted to call sourceless user-defined object as a method", context=f"object={self.value}, function={func}, args={args}, kwargs={kwargs}", @@ -1205,6 +1352,10 @@ def call_function( hints=[ f"Ensure the user-defined object {self.value} is constructed outside the compiled region.", ], +======= + unimplemented( + "Sourceless UserDefinedObjectVariable method not supported" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) func_src = AttrSource(self.source, "__func__") func_var = VariableTracker.build(tx, func, func_src) @@ -1293,6 +1444,7 @@ def get_source_by_walking_mro(self, name): for idx, klass in enumerate(type(self.value).__mro__): if name in klass.__dict__: +<<<<<<< HEAD if idx != 0: mro_source = TypeMROSource(self.cls_source) klass_source = GetItemSource(mro_source, idx) @@ -1339,6 +1491,19 @@ def get_source_by_walking_mro(self, name): ) def var_getattr(self, tx: "InstructionTranslator", name): +======= + mro_source = AttrSource(self.cls_source, "__mro__") + klass_source = GetItemSource(mro_source, idx) + dict_source = AttrSource(klass_source, "__dict__") + # TODO(anijain2305) - This is a mapping proxy object. Ideally we + # should use DictGetItemSource here. + return GetItemSource(dict_source, name) + + unimplemented(f"Could not find {name} in {type(self.value).__mro__}") + + def var_getattr(self, tx: "InstructionTranslator", name): + from .. import trace_rules +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from . import ConstantVariable source = AttrSource(self.source, name) if self.source else None @@ -1420,6 +1585,7 @@ def var_getattr(self, tx: "InstructionTranslator", name): return out elif getattr_fn is not None: +<<<<<<< HEAD unimplemented_v2( gb_type="User-defined object with non-function __getattr__", context=f"object={self.value}, name={name}, getattr_fn={getattr_fn}", @@ -1429,12 +1595,16 @@ def var_getattr(self, tx: "InstructionTranslator", name): "Ensure the object's __getattr__ is a function type.", ], ) +======= + unimplemented("UserDefined with non-function __getattr__") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from ..mutation_guard import unpatched_nn_module_init if subobj is torch.nn.Module.__init__: subobj = unpatched_nn_module_init +<<<<<<< HEAD subobj_from_class = inspect.getattr_static( self.value.__class__, name, NO_SUCH_SUBOBJ ) @@ -1459,6 +1629,17 @@ def var_getattr(self, tx: "InstructionTranslator", name): return variables.UserFunctionVariable( subobj.fget, source=source ).call_function(tx, [self], {}) +======= + if isinstance(subobj, property): + if self.source: + # Read the class attribute to reach the property + source = AttrSource(AttrSource(self.source, "__class__"), name) + # Get the getter function + source = AttrSource(source, "fget") + return variables.UserMethodVariable( + subobj.fget, self, source=source + ).call_function(tx, [], {}) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif isinstance(subobj, _collections._tuplegetter): # namedtuple fields are represented by _tuplegetter, and here we # emulate its `__get__`, which is implemented in C. @@ -1471,6 +1652,7 @@ def var_getattr(self, tx: "InstructionTranslator", name): # Safe because `staticmethod.__get__` basically won't trigger user # code and just returns the underlying `__func__`: # https://github.com/python/cpython/blob/3.11/Objects/funcobject.c#L1088-L1100 +<<<<<<< HEAD if is_accessible_from_type_mro: # Accessing from __dict__ does not resolve the descriptor, it # returns a staticmethod object, so access the __func__ @@ -1490,6 +1672,13 @@ def var_getattr(self, tx: "InstructionTranslator", name): self.var_getattr(tx, "__class__"), source_fn=source_fn, source=source, +======= + func = subobj.__get__(self.value) + return VariableTracker.build(tx, func, source) + elif isinstance(subobj, classmethod): + return variables.UserMethodVariable( + subobj.__func__, self.var_getattr(tx, "__class__"), source=source +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) elif isinstance(subobj, types.ClassMethodDescriptorType): # e.g.: inspect.getattr_static({}, "fromkeys") @@ -1550,6 +1739,7 @@ def var_getattr(self, tx: "InstructionTranslator", name): if isinstance(subobj, types.MethodType): if dynamic_subobj.__self__ is not self.value: if not isinstance(dynamic_subobj.__func__, types.FunctionType): +<<<<<<< HEAD unimplemented_v2( gb_type="User-defined object method with non-function __func__", context=f"object={self.value}, name={name}, method={dynamic_subobj}, " @@ -1559,6 +1749,10 @@ def var_getattr(self, tx: "InstructionTranslator", name): hints=[ "Ensure that the method's __func__ is a function type.", ], +======= + unimplemented( + f"Found a method whose __func__ is not of FunctionType - {dynamic_subobj}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # Use the __self__ attribute of the method to find the @@ -1579,6 +1773,7 @@ def var_getattr(self, tx: "InstructionTranslator", name): func = subobj if inspect.ismethod(dynamic_subobj): +<<<<<<< HEAD source_fn = None if is_accessible_from_type_mro: source_fn = self.get_source_by_walking_mro(name) @@ -1587,6 +1782,18 @@ def var_getattr(self, tx: "InstructionTranslator", name): ) elif inspect.isfunction(dynamic_subobj): return VariableTracker.build(tx, func, source) +======= + return variables.UserMethodVariable(func, self, source=source) + elif inspect.isfunction(dynamic_subobj): + if is_utils_checkpoint(func): + return build_checkpoint_variable(source=source) + elif source is not None: + return trace_rules.lookup(func).create_with_source( + func, source=source + ) + else: + return trace_rules.lookup(func)(func) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if ( # wrap the source only if inline_inbuilt_nn_modules is set or fsdp modules. This is a temporary solution to @@ -1608,6 +1815,7 @@ def var_getattr(self, tx: "InstructionTranslator", name): source = self._wrap_source(source) if subobj is not NO_SUCH_SUBOBJ: +<<<<<<< HEAD if ( is_wrapper_or_member_descriptor(subobj) or torch._C._dynamo.utils.is_instancemethod(subobj) @@ -1619,6 +1827,12 @@ def var_getattr(self, tx: "InstructionTranslator", name): if is_accessible_from_type_mro: source = self.get_source_by_walking_mro(name) +======= + if is_wrapper_or_member_descriptor(subobj): + options = {"source": source} + return variables.GetAttrVariable(self, name, **options) + if source: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return variables.LazyVariableTracker.create(subobj, source) else: # Check if the subobj is accessible from the class itself. If the class source is known, we can create a @@ -1657,6 +1871,7 @@ def call_obj_hasattr( class FrozenDataClassVariable(UserDefinedObjectVariable): +<<<<<<< HEAD class HashWrapper: """This class is hashed if a dataclass is used as a key in a dict. It's necessary to avoid side effects from calling the __init__ of the dataclass class when hashing""" @@ -1675,6 +1890,8 @@ def __eq__(self, other): def __hash__(self): return hash((self.cls, self.fields)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @staticmethod def create(tx, value, source): from dataclasses import fields @@ -1748,6 +1965,7 @@ def as_proxy(self): ctor = self.python_type() return ctor(*args, **kwargs) +<<<<<<< HEAD def reconstruct(self, codegen: "PyCodegen") -> None: # Handle specific pytree classes import torch.utils._pytree as pytree @@ -1763,6 +1981,8 @@ def reconstruct(self, codegen: "PyCodegen") -> None: # For other frozen dataclasses, fall back to the base class behavior super().reconstruct(codegen) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # NB: This is called during __init__ for a frozen dataclass # use this to accumulate the most up-to-date field values def method_setattr_standard(self, tx: "InstructionTranslator", name, value): @@ -1816,7 +2036,11 @@ def call_method(self, tx, name, args, kwargs): self.exc_vt.args = args self.value.args = args return variables.ConstantVariable(None) +<<<<<<< HEAD elif ( +======= + if ( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) name == "__setattr__" and len(args) == 2 and isinstance(args[0], variables.ConstantVariable) @@ -1824,18 +2048,24 @@ def call_method(self, tx, name, args, kwargs): in ("__cause__", "__context__", "__suppress_context__", "__traceback__") ): self.exc_vt.call_setattr(tx, args[0], args[1]) +<<<<<<< HEAD elif name == "with_traceback": return self.exc_vt.call_method(tx, name, args, kwargs) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return super().call_method(tx, name, args, kwargs) @property def __context__(self): return self.exc_vt.__context__ +<<<<<<< HEAD @property def args(self): return self.exc_vt.args +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def set_context(self, context: "variables.ExceptionVariable"): return self.exc_vt.set_context(context) @@ -1940,7 +2170,11 @@ def __init__(self, value, dict_vt=None, **kwargs): "dict_vt must be constructed by builder.py when source is present" ) self._dict_vt = variables.ConstDictVariable( +<<<<<<< HEAD {}, type(value), mutation_type=ValueMutationNew() +======= + {}, mutation_type=ValueMutationNew() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) self._dict_methods = dict_methods @@ -1953,6 +2187,7 @@ def call_method( ) -> "VariableTracker": method = self._maybe_get_baseclass_method(name) if method in self._dict_methods: +<<<<<<< HEAD # Dict subclasses can override __missing__ to provide fallback # behavior instead of raising a KeyError. This is used, for example, # by collections.Counter. @@ -1967,6 +2202,9 @@ def call_method( return self.call_method(tx, "__missing__", args, kwargs) else: raise +======= + return self._dict_vt.call_method(tx, name, args, kwargs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return super().call_method(tx, name, args, kwargs) def unpack_var_sequence(self, tx): @@ -1980,6 +2218,7 @@ def unpack_var_sequence(self, tx): def is_underlying_vt_modified(self, side_effects): return side_effects.is_modified(self._dict_vt) +<<<<<<< HEAD @property def user_cls(self): return self._dict_vt.user_cls @@ -2069,6 +2308,8 @@ def install_dict_keys_match_guard(self): def install_dict_contains_guard(self): return self._set_vt.install_dict_contains_guard() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class UserDefinedListVariable(UserDefinedObjectVariable): """ @@ -2138,7 +2379,11 @@ def __init__(self, value, tuple_vt=None, init_args=None, **kwargs): from torch._dynamo.symbolic_convert import InstructionTranslator tx = InstructionTranslator.current_tx() +<<<<<<< HEAD elems = init_args[0].force_unpack_var_sequence(tx) +======= + elems = init_args[0].unpack_var_sequence(tx) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._tuple_vt = variables.TupleVariable( elems, mutation_type=ValueMutationNew() ) @@ -2169,6 +2414,10 @@ class MutableMappingVariable(UserDefinedObjectVariable): def __init__(self, value, **kwargs): super().__init__(value, **kwargs) self.generic_dict_vt = variables.ConstDictVariable({}) +<<<<<<< HEAD +======= + self.mutation_type = AttributeMutationExisting() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": # A common pattern in the init code of MutableMapping objects is to diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py index 02e8c118cfafb..dc23ee887130d 100644 --- a/torch/_export/__init__.py +++ b/torch/_export/__init__.py @@ -148,7 +148,10 @@ def aot_compile( with torch.no_grad(): so_path = torch._inductor.aot_compile(gm, args, kwargs, options=options) # type: ignore[arg-type] +<<<<<<< HEAD assert isinstance(so_path, (str, list)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return so_path def aot_load(so_path: str, device: str) -> Callable: diff --git a/torch/_export/converter.py b/torch/_export/converter.py index bba7c2d16aa65..2cb7617a93324 100644 --- a/torch/_export/converter.py +++ b/torch/_export/converter.py @@ -134,7 +134,11 @@ def execute_subgraph_from_prim_loop( ): """ subgraph: GraphModule from sub-block. +<<<<<<< HEAD iter_idx: The index of interaction. +======= + iter_idx: The index of interation. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) len_loop_local_arguments: The number of loop local arguments in args. """ @@ -624,9 +628,15 @@ def convert_graph_inputs(self): self.fx_graph, name, self.is_top_level_graph() ) elif name in self.name_to_constant: +<<<<<<< HEAD assert isinstance(self.name_to_constant[name], torch.ScriptObject), ( "Input conversion only handles ScriptObject" ) +======= + assert isinstance( + self.name_to_constant[name], torch.ScriptObject + ), "Input conversion only handles ScriptObject" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) normalized_name = normalize_name(name) self.input_specs.append( InputSpec( @@ -661,7 +671,13 @@ def convert_aten_Float(self, node: torch._C.Node): def to_float_tensor(t): return t.to(dtype=torch.float).item() +<<<<<<< HEAD inp_list = [self.get_fx_value_by_ir_value(inp) for inp in node.inputs()] # noqa: C416 +======= + inp_list = [ + self.get_fx_value_by_ir_value(inp) for inp in node.inputs() + ] # noqa: C416 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fx_node = self.fx_graph.call_function( to_float_tensor, tuple(inp_list), @@ -747,7 +763,13 @@ def convert_prim_Constant(self, node: torch._C.Node): self.name_to_constant[name] = value def convert_prim_CallMethod(self, node: torch._C.Node): +<<<<<<< HEAD inp_list = [self.get_fx_value_by_ir_value(inp) for inp in node.inputs()] # noqa: C416 +======= + inp_list = [ + self.get_fx_value_by_ir_value(inp) for inp in node.inputs() + ] # noqa: C416 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fx_node = self.fx_graph.call_method( node.s("name"), tuple(inp_list), @@ -779,9 +801,15 @@ def convert_prim_GetAttr(self, node: torch._C.Node): self.name_to_node[output_name] = self.fx_graph.get_attr(attr_fqn) else: if attr_fqn not in self.name_to_non_tensor_attribute_node: +<<<<<<< HEAD self.name_to_non_tensor_attribute_node[attr_fqn] = ( self.name_to_non_tensor_attribute[attr_fqn] ) +======= + self.name_to_non_tensor_attribute_node[ + attr_fqn + ] = self.name_to_non_tensor_attribute[attr_fqn] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.name_to_node[output_name] = self.name_to_non_tensor_attribute_node[ attr_fqn ] @@ -810,7 +838,11 @@ def convert_call_function_op(self, node: torch._C.Node): fx_node = self.fx_graph.call_function(target, args, kwargs) +<<<<<<< HEAD # TODO: convert sourceRange() into stack_trace +======= + # TODO: covnert sourceRange() into stack_trace +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # fx_node.meta["stack_trace"] = node.sourceRange() if node.outputsSize() == 1: @@ -846,6 +878,7 @@ def convert_prim_DictConstruct(self, node: torch._C.Node): k = self.get_fx_value_by_ir_value(inp) else: v = self.get_fx_value_by_ir_value(inp) +<<<<<<< HEAD assert k is not None and v is not None, ( "DictConstruct has an empty key value pair." ) @@ -855,6 +888,17 @@ def convert_prim_DictConstruct(self, node: torch._C.Node): assert k is None and v is None, ( "DictConstruct has an odd number of elements (violating our assumption)." ) +======= + assert ( + k is not None and v is not None + ), "DictConstruct has an empty key value pair." + output_dict[k] = v + k, v = None, None + + assert ( + k is None and v is None + ), "DictConstruct has an odd number of elements (violating our assumption)." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) output_name = node.output().debugName() self.name_to_node[output_name] = output_dict @@ -883,7 +927,11 @@ def convert_aten_Int(self, node: torch._C.Node): torch.ops.aten._local_scalar_dense.default, (to_copy_node,) ) +<<<<<<< HEAD # TODO: convert sourceRange() into stack_trace +======= + # TODO: covnert sourceRange() into stack_trace +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # fx_node.meta["stack_trace"] = node.sourceRange() output_name = node.output().debugName() @@ -942,7 +990,11 @@ def convert_aten_div(self, node: torch._C.Node): kwargs, ) +<<<<<<< HEAD # TODO: convert sourceRange() into stack_trace +======= + # TODO: covnert sourceRange() into stack_trace +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # fx_node.meta["stack_trace"] = node.sourceRange() output_name = node.output().debugName() @@ -1006,7 +1058,11 @@ def convert_aten_add(self, node: torch._C.Node): ): target = torch.ops.aten.add.t else: +<<<<<<< HEAD raise RuntimeError(f"unable to determined the target for {node}") +======= + raise RuntimeError(f"unable to determind the target for {node}") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: target = get_op_overload(node) @@ -1120,9 +1176,15 @@ def convert_prim_Loop(self, node: torch._C.Node): ), # + 1 because the 0th element is the condition. ) global_argument_index = global_arguments.index(name) +<<<<<<< HEAD fx_block_args[i + node.outputsSize() + global_argument_index] = ( self.name_to_node[name] ) +======= + fx_block_args[ + i + node.outputsSize() + global_argument_index + ] = self.name_to_node[name] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _check_set_attr_in_if_block(self, if_node: torch._C.Node): for block in if_node.blocks(): @@ -1541,9 +1603,15 @@ def retrace_as_exported_program( for spec in ep.graph_signature.input_specs: # Mark as constant tensors for erroneously traced buffers. if spec.kind == InputKind.BUFFER and spec.target in name_to_constant: +<<<<<<< HEAD assert isinstance(name_to_constant[spec.target], torch.Tensor), ( f"{type(name_to_constant[spec.target])} has been erroneously marked as buffer" ) +======= + assert isinstance( + name_to_constant[spec.target], torch.Tensor + ), f"{type(name_to_constant[spec.target])} has been erroneously marked as buffer" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) spec.kind = InputKind.CONSTANT_TENSOR spec.persistent = None ep.verifier().check(ep) @@ -1565,7 +1633,11 @@ def lift_get_attr(self): # # This function should happen in TS2EPConverter instead of # TS2FXGraphConverter since it gets attributes from self.ts_model +<<<<<<< HEAD # which is not accessible in TS2FXGraphConverter. It is similar to where +======= + # which is not accessable in TS2FXGraphConverter. It is similar to where +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # we collect self.name_to_param and self.name_to_buffer. name_to_attribute_fqn: dict[str, str] = {} diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index fffe85beb467e..872cbdd1c4cfa 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -169,10 +169,14 @@ def fakify( return t if isinstance(t, _IntWrapper): +<<<<<<< HEAD if t.dynamism is not None and t.dynamism.type in ( # type: ignore[union-attr] _DimHintType.DYNAMIC, _DimHintType.AUTO, ): +======= + if t.dynamism is not None and t.dynamism.type in (_DimHintType.DYNAMIC, _DimHintType.AUTO): # type: ignore[union-attr] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) symint = mode.shape_env.create_unspecified_symint_and_symbol( # type: ignore[union-attr] t.val, source, DimDynamic.DYNAMIC ) @@ -330,7 +334,12 @@ def make_fake_inputs( args, kwargs, dynamic_shapes, +<<<<<<< HEAD prefer_deferred_runtime_asserts_over_guards=False, +======= + _is_torch_jit_trace=False, + allow_complex_guards_as_runtime_asserts=False, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): """ Given an nn module, example inputs, and constraints, return a new fake mode, @@ -365,8 +374,12 @@ def make_fake_inputs( # a toplevel TracingContext with a fake mode, so we do not want to # create another fake mode. fake_mode = context.fake_mode +<<<<<<< HEAD assert fake_mode is not None else: +======= + elif not _is_torch_jit_trace: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(nn_module.forward, functools.partial): # functools handles nesting by itself, no need to recurse code = nn_module.forward.func.__code__ @@ -382,12 +395,31 @@ def make_fake_inputs( shape_env=ShapeEnv( tracked_fakes=[], co_fields=co_fields, +<<<<<<< HEAD prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, +======= + prefer_deferred_runtime_asserts_over_guards=True, + allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) trace_asserts=True, ), allow_non_fake_inputs=True, export=True, ) +<<<<<<< HEAD +======= + else: + with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False): + fake_mode = FakeTensorMode( + shape_env=ShapeEnv( + tracked_fakes=[], + prefer_deferred_runtime_asserts_over_guards=True, + allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, + trace_asserts=True, + ), + allow_non_fake_inputs=True, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if fake_mode.shape_env is None or fake_mode.shape_env.tracked_fakes is None: raise ValueError( "Detected fake_mode does not have a shape_env with tracked fakes. " @@ -396,7 +428,15 @@ def make_fake_inputs( ) with fake_mode: +<<<<<<< HEAD original_signature = inspect.signature(nn_module.forward) +======= + # FIXME(ycao) ScriptMethod doesn't have signature, I am using an empty one to unblock + if not _is_torch_jit_trace: + original_signature = inspect.signature(nn_module.forward) + else: + original_signature = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sources: dict[tuple[int, int], list[Source]] = defaultdict(list) sourced_prefixes = make_sourced_prefixes(nn_module, args, kwargs) fake_args, fake_kwargs = tree_map_with_path( @@ -477,6 +517,10 @@ def produce_guards_and_solve_constraints( dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None], equalities_inputs: EqualityConstraint, original_signature: inspect.Signature, +<<<<<<< HEAD +======= + _is_torch_jit_trace=False, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): """ Given a fake mode, sources pairs corresponding to equal dynamic shape dimensions, @@ -517,6 +561,7 @@ def produce_guards_and_solve_constraints( raise constraint_violation_error dim_constraints.solve() forced_specializations = dim_constraints.forced_specializations() +<<<<<<< HEAD msg = dim_constraints.prettify_results( original_signature, @@ -525,6 +570,18 @@ def produce_guards_and_solve_constraints( forced_specializations, # type: ignore[arg-type] ) +======= + if not _is_torch_jit_trace: + msg = dim_constraints.prettify_results( + original_signature, + dynamic_shapes, # type: ignore[arg-type] + constraint_violation_error, + forced_specializations, # type: ignore[arg-type] + ) + else: + # FIXME(ycao): This is a hack to get around missing signature from ScriptMethod + msg = "dummy constraint violation message" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if constraint_violation_error: constraint_violation_error.args = (constraint_violation_error.args[0] + msg,) elif forced_specializations: @@ -852,7 +909,11 @@ def _fakify_script_objects( mod: torch.nn.Module, args: Sequence[Any], kwargs: dict[Any, Any], +<<<<<<< HEAD fake_mode: Optional[torch._subclasses.fake_tensor.FakeTensorMode], +======= + fake_mode: torch._subclasses.fake_tensor.FakeTensorMode, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): # This context manager is used to fakify script objects into FakeScriptObject. # Inputs: @@ -983,12 +1044,16 @@ def _override(self, func, args, kwargs): def rewrite(dim, item): # Redirect to torch.select for indexing. +<<<<<<< HEAD if item is None: return dim + 1, (torch.unsqueeze, [dim]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(item, (int, torch.SymInt)): return dim, (torch.select, [dim, item]) # Redirect to torch.ops.aten.slice for slicing. if isinstance(item, slice): +<<<<<<< HEAD step = item.step or 1 if item.start is None and item.stop is None and step == 1: # no-op @@ -1042,6 +1107,32 @@ def run(): return t return run, [], {} +======= + return dim + 1, ( + torch.ops.aten.slice, + [dim, item.start, item.stop, item.step or 1], + ) + # Otherwise do nothing. + + items = args[1] if isinstance(args[1], tuple) else (args[1],) + dim = 0 + # Sequence rewrites. + sequence = [] + for item in items: + if (r := rewrite(dim, item)) is None: + return func, args, kwargs + dim, call_spec = r + sequence.append(call_spec) + + def run(): + # Run sequence. + t = args[0] + for _method, _args in sequence: + t = _method(t, *_args) + return t + + return run, [], {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return func, args, kwargs diff --git a/torch/_export/pass_base.py b/torch/_export/pass_base.py index 952e904ca26e0..14fa3f5d4410f 100644 --- a/torch/_export/pass_base.py +++ b/torch/_export/pass_base.py @@ -252,11 +252,16 @@ def call_function( else: raise ExportPassBaseError(f"Unsupported target type: {target}") +<<<<<<< HEAD def get_attr( # type: ignore[override] self, target: str, args: tuple[Argument, ...], kwargs: dict[str, Argument], +======= + def get_attr( + self, target: str, args: tuple[Argument, ...], kwargs: dict[str, Argument] # type: ignore[override] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Argument: return super().get_attr(target, args, kwargs) @@ -268,11 +273,16 @@ def call_module( ) -> None: raise ExportPassBaseError("call_module is not supported.") +<<<<<<< HEAD def call_method( # type: ignore[override] self, target: str, args: tuple[Argument, ...], kwargs: dict[str, Argument], +======= + def call_method( + self, target: str, args: tuple[Argument, ...], kwargs: dict[str, Argument] # type: ignore[override] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: raise ExportPassBaseError("call_method is not supported.") @@ -432,6 +442,7 @@ def output(self, results: list[Argument], meta: NodeMetadata) -> ProxyValue: def call_submodule( self, graph_module: fx.GraphModule, inputs: tuple[Argument, ...] ) -> PassResult: +<<<<<<< HEAD prev_tracer, self.tracer = ( self.tracer, self.ExportTracer(self, graph_module.graph._codegen), @@ -443,6 +454,15 @@ def call_submodule( torch.fx.Interpreter( # type: ignore[assignment] torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()) ), +======= + prev_tracer, self.tracer = self.tracer, self.ExportTracer( + self, graph_module.graph._codegen + ) + self.tracer.fake_tensor_mode = prev_tracer.fake_tensor_mode + interpreter = self.ExportInterpreter(self, graph_module) + prev_interpreter, self.interpreter = self.interpreter, torch.fx.Interpreter( # type: ignore[assignment] + torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) inputs_data = pytree.tree_map_only(ProxyValue, lambda x: x.data, inputs) with fx_traceback.preserve_node_meta(): @@ -468,9 +488,15 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: fake_tensor_mode = None for i in inputs: if isinstance(i, FakeTensor): +<<<<<<< HEAD assert fake_tensor_mode is None or fake_tensor_mode is i.fake_mode, ( "Multiple fake tensor mode detected." ) +======= + assert ( + fake_tensor_mode is None or fake_tensor_mode is i.fake_mode + ), "Multiple fake tensor mode detected." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fake_tensor_mode = i.fake_mode if fake_tensor_mode is None: self.tracer.fake_tensor_mode = FakeTensorMode(allow_non_fake_inputs=True) diff --git a/torch/_export/passes/_node_metadata_hook.py b/torch/_export/passes/_node_metadata_hook.py index f1958815293c1..417f3cd41f5dd 100644 --- a/torch/_export/passes/_node_metadata_hook.py +++ b/torch/_export/passes/_node_metadata_hook.py @@ -1,29 +1,43 @@ # mypy: allow-untyped-defs import contextlib +<<<<<<< HEAD from typing import Any, Optional import torch import torch.utils._pytree as pytree from torch._dispatch.python import enable_python_dispatcher from torch._subclasses.fake_tensor import FakeTensorMode +======= +from typing import Optional + +import torch +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.fx.graph_module import GraphModule _EMPTY_NN_MODULE_STACK_KEY = "_empty_nn_module_stack_from_metadata_hook" +<<<<<<< HEAD def _node_metadata_hook( node: torch.fx.Node, metadata: Optional[dict[str, Any]] = None, fake_mode: Optional[FakeTensorMode] = None, ) -> None: +======= +def _node_metadata_hook(node: torch.fx.Node, stack_trace: Optional[str] = None) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Hook for adding the appropriate metadata to nodes that are created during a pass using graph.create_node. An example of how to use it: ``` with _set_node_metadata_hook(gm, +<<<<<<< HEAD functools.partial(_node_metadata_hook, metadata={"stack_trace": "file"}) +======= + functools.partial(_node_metadata_hook, stack_trace="file") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): pass(gm) ``` @@ -32,11 +46,19 @@ def _node_metadata_hook( that nodes being added are only call_function nodes, and copies over the first argument node's nn_module_stack. """ +<<<<<<< HEAD fake_mode = fake_mode or contextlib.nullcontext() assert node.op == "call_function" and callable(node.target), ( f"node: {node}, target: {node.target}" ) +======= + assert node.op == "call_function" and callable(node.target) + + arg_meta = [arg.meta for arg in node.args if isinstance(arg, torch.fx.Node)] + assert len(arg_meta) >= 1 + arg_meta = arg_meta[0] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if ( isinstance(node.target, torch._ops.OpOverload) @@ -44,6 +66,7 @@ def _node_metadata_hook( ): node.meta["val"] = None else: +<<<<<<< HEAD fake_args, fake_kwargs = pytree.tree_map_only( torch.fx.Node, lambda arg: arg.meta["val"], (node.args, node.kwargs) ) @@ -84,6 +107,28 @@ def _node_metadata_hook( f"{node.target.__name__}_0", f"{node.target.__class__.__name__}.{node.target.__name__}", ), +======= + fake_args = [ + arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg + for arg in node.args + ] + fake_res = node.target(*fake_args) + node.meta["val"] = fake_res + + node.meta["stack_trace"] = stack_trace + node.meta["nn_module_stack"] = arg_meta.get( + "nn_module_stack", + { + _EMPTY_NN_MODULE_STACK_KEY: ( + _EMPTY_NN_MODULE_STACK_KEY, + _EMPTY_NN_MODULE_STACK_KEY, + ) + }, + ) + node.meta["torch_fn"] = ( + f"{node.target.__name__}_0", + f"{node.target.__class__.__name__}.{node.target.__name__}", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) diff --git a/torch/_export/passes/insert_custom_op_guards.py b/torch/_export/passes/insert_custom_op_guards.py index bfea7b08c9248..bb710103230da 100644 --- a/torch/_export/passes/insert_custom_op_guards.py +++ b/torch/_export/passes/insert_custom_op_guards.py @@ -16,6 +16,7 @@ def insert_custom_op_guards(gm: torch.fx.GraphModule, ops_to_guard: set[str]) -> """ for node in gm.graph.nodes: if node.op == "call_function" and str(node.target) in ops_to_guard: +<<<<<<< HEAD with ( _set_node_metadata_hook( gm, @@ -26,6 +27,14 @@ def insert_custom_op_guards(gm: torch.fx.GraphModule, ops_to_guard: set[str]) -> ), gm.graph.inserting_before(node), ): +======= + with _set_node_metadata_hook( + gm, + functools.partial( + _node_metadata_hook, stack_trace=node.meta.get("stack_trace") + ), + ), gm.graph.inserting_before(node): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for arg in (*node.args, *node.kwargs.values()): if isinstance(arg, torch.fx.Node) and isinstance( arg.meta.get("val"), torch.Tensor diff --git a/torch/_export/passes/lift_constants_pass.py b/torch/_export/passes/lift_constants_pass.py index 20253a91c2583..4695d1846f2e9 100644 --- a/torch/_export/passes/lift_constants_pass.py +++ b/torch/_export/passes/lift_constants_pass.py @@ -165,7 +165,11 @@ def lift_constants_pass( constant_attrs: ConstantAttrMap, ) -> dict[str, _ConstantAttributeType]: """ +<<<<<<< HEAD Takes a graph module, graph signature, and modifies them inplace to lift any +======= + Takes a graph module, graph signature, and modifies them implace to lift any +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) constants (tensors or custom classes) as inputs to the graph. Returns a dictionary of names to constants. @@ -183,12 +187,21 @@ def lift_constants_pass( """ all_constants: dict[str, _ConstantAttributeType] = {} +<<<<<<< HEAD input_specs = graph_signature.input_specs num_custom_obj = sum( input_spec.kind == InputKind.CUSTOM_OBJ for input_spec in input_specs ) num_tensor_constants = sum( input_spec.kind == InputKind.CONSTANT_TENSOR for input_spec in input_specs +======= + inputs = graph_signature.input_specs + num_custom_obj = sum( + input_specs.kind == InputKind.CUSTOM_OBJ for input_specs in inputs + ) + num_tensor_constants = sum( + input_specs.kind == InputKind.CONSTANT_TENSOR for input_specs in inputs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) fake_mode = detect_fake_mode( @@ -197,6 +210,7 @@ def lift_constants_pass( first_user_input_loc, first_user_input = 0, next(iter(gm.graph.nodes)) used_target_names = set() +<<<<<<< HEAD input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"] assert len(input_nodes) == len(input_specs) @@ -205,6 +219,21 @@ def lift_constants_pass( if input_spec.kind == InputKind.USER_INPUT: first_user_input = node first_user_input_loc = i +======= + for node in gm.graph.nodes: + if node.op == "placeholder": + if node.name in graph_signature.user_inputs: + first_user_input = node + break + used_target_names.add(inputs[first_user_input_loc].target) + first_user_input_loc += 1 + # If we ever hit here, it means that + # there was no user input so the constants + # should be inserted right before the first + # non-placeholder node. + if node.op != "placeholder": + first_user_input = node +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) break lifted_objs = ConstantAttrMap() @@ -368,7 +397,11 @@ def lift_constants_pass( def rewrite_script_object_meta( gm: torch.fx.GraphModule, +<<<<<<< HEAD ) -> dict[str, _ConstantAttributeType]: +======= +) -> dict[str, _ConstantAttributeType,]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """When tracing, we produce a graph with FakeScriptObject in the meta["val"]. diff --git a/torch/_export/passes/replace_autocast_with_hop_pass.py b/torch/_export/passes/replace_autocast_with_hop_pass.py index 71b90a3ff1bfb..386910f702883 100644 --- a/torch/_export/passes/replace_autocast_with_hop_pass.py +++ b/torch/_export/passes/replace_autocast_with_hop_pass.py @@ -100,8 +100,13 @@ def _split_autocast(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: split_autocast creates a new graph module that splits the input graph module into multiple submodules based on the `_enter_autocast` and `_exit_autocast` nodes. It doesn't mutate the input graph module. +<<<<<<< HEAD Nodes between the **outer-most** `_enter_autocast` and `_exit_autocast(_enter_autocast)` are split into a submodule. Nested autocast regions are not split. +======= + Nodes between the **outer-most** `_enter_autocast` and `_exit_autocast(_enter_autocast)` are splitted + into a submodule. Nested autocast regions are not splitted. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) `_enter_autocast` and `_exit_autocast(_enter_autocast)` nodes are in the submodule as well. Below is an example of splitting. A, B, C, D, E are blocks of non-autocast nodes in the original graph diff --git a/torch/_export/passes/replace_quantized_ops_with_standard_ops_pass.py b/torch/_export/passes/replace_quantized_ops_with_standard_ops_pass.py index 4d91876801011..c880efd2b2734 100644 --- a/torch/_export/passes/replace_quantized_ops_with_standard_ops_pass.py +++ b/torch/_export/passes/replace_quantized_ops_with_standard_ops_pass.py @@ -292,7 +292,11 @@ def _conv1d_op_with_squeeze( def _transform_conv_with_packedparam(gm: torch.fx.GraphModule, node: torch.fx.Node): +<<<<<<< HEAD """Conv specific transformation function.""" +======= + """Conv specfic transformation function.""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(node.target, torch._ops.OpOverload) opname = node.target._opname scale_node, zero_point_node = node.args[2], node.args[3] @@ -347,7 +351,11 @@ def _transform_conv_with_packedparam(gm: torch.fx.GraphModule, node: torch.fx.No def _transform_linear_with_packedparam(gm: torch.fx.GraphModule, node: torch.fx.Node): +<<<<<<< HEAD """Linear specific transformation function.""" +======= + """Linear specfic transformation function.""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) scale_node, zero_point_node = node.args[2], node.args[3] inp_node, param_node = node.args[0], node.args[1] diff --git a/torch/_export/passes/replace_with_hop_pass_util.py b/torch/_export/passes/replace_with_hop_pass_util.py index 974058092448c..a34145f269f92 100644 --- a/torch/_export/passes/replace_with_hop_pass_util.py +++ b/torch/_export/passes/replace_with_hop_pass_util.py @@ -46,7 +46,11 @@ def set_hoo_node_meta(call_func_node): enter_block_node.meta.get("nn_module_stack", {}) ) output_node = next(iter(reversed(sub_gm.graph.nodes)), None) +<<<<<<< HEAD # Split_module pass intentionally doesn't add output node +======= + # Split_module pass intentially doesn't add output node +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # if the graph doesn't return anything. # TODO (tmanlaibaatar) Figure out if this is right behaviour # for split_module @@ -97,7 +101,11 @@ def set_hoo_node_meta(call_func_node): node_replace_(node, get_item_node) else: raise NotImplementedError( +<<<<<<< HEAD f"replace_with_hop_pass doesn't support output type {type(output_args)}" +======= + f"repalce_with_hop_pass doesnt' support output type {type(output_args)}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) else: # TODO (shangdiy): remove this line, since the export graph can be non-functional diff --git a/torch/_export/serde/dynamic_shapes.py b/torch/_export/serde/dynamic_shapes.py index e6a0295163dcb..2191ead4067ce 100644 --- a/torch/_export/serde/dynamic_shapes.py +++ b/torch/_export/serde/dynamic_shapes.py @@ -107,6 +107,7 @@ def _dump_dynamic_shapes( would generate the following output: ``` { +<<<<<<< HEAD "dynamic_shapes": ( [ ["dx", 4], @@ -121,6 +122,22 @@ def _dump_dynamic_shapes( "min": 4, "max": 16, "derived": ["dx + 1"], +======= + 'dynamic_shapes': ( + [ + ['dx', 4], + ['dx + 1', 4], + ], + ['_DimHint.STATIC'], + ['_DimHint.STATIC', '_DimHint.STATIC'], + None, + ), + 'dims': { + 'dx': { + 'min': 4, + 'max': 16, + 'derived': ['dx + 1'], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, }, } @@ -149,7 +166,11 @@ def _standardize_shapes(path, tensor, shape): # type: ignore[no-untyped-def] return out def _track_dim_from_dims( +<<<<<<< HEAD val: Union[None, int, _DimHint, Dim], +======= + val: Union[None, int, _DimHint, Dim] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Union[None, int, str]: """ Tracks dims, ranges, derived dims from the standardized dynamic_shapes spec. @@ -295,7 +316,11 @@ def _load_dynamic_shapes( dim_cache[_expr] = ddim # cache derived dims def deserialize_shape( +<<<<<<< HEAD val: Union[None, int, str], +======= + val: Union[None, int, str] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Union[None, int, Dim, _DimHint]: if val is None or isinstance(val, int): return val diff --git a/torch/_export/serde/export_schema.thrift b/torch/_export/serde/export_schema.thrift index f4a08f8739993..def587ee5e949 100644 --- a/torch/_export/serde/export_schema.thrift +++ b/torch/_export/serde/export_schema.thrift @@ -1,5 +1,9 @@ // @generated by update_schema.py +<<<<<<< HEAD // checksum<> +======= +// checksum<> +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace py3 torch._export namespace cpp2 torch._export.schema @@ -50,8 +54,11 @@ enum ScalarType { UINT16 = 28, FLOAT8E4M3FN = 29, FLOAT8E5M2 = 30, +<<<<<<< HEAD FLOAT8E4M3FNUZ = 31, FLOAT8E5M2FNUZ = 32, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } @@ -134,11 +141,14 @@ struct CustomObjArgument { 20: string class_fqn; } +<<<<<<< HEAD struct ComplexValue { 10: double real; 20: double imag; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) union Argument { 10: bool as_none; 20: TensorArgument as_tensor; @@ -166,7 +176,10 @@ union Argument { 230: SymFloatArgument as_sym_float; 240: list as_sym_floats; 250: OptionalTensorArgument as_optional_tensor; +<<<<<<< HEAD 260: ComplexValue as_complex; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } struct NamedArgument { @@ -260,11 +273,14 @@ struct BufferMutationSpec { 20: string buffer_name; } +<<<<<<< HEAD struct ParameterMutationSpec { 10: TensorArgument arg; 20: string parameter_name; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) struct GradientToParameterSpec { 10: TensorArgument arg; 20: string parameter_name; @@ -292,7 +308,10 @@ union OutputSpec { 50: GradientToUserInputSpec gradient_to_user_input; 60: UserInputMutationSpec user_input_mutation; 70: OutputTokenSpec token; +<<<<<<< HEAD 80: ParameterMutationSpec parameter_mutation; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } struct GraphSignature { @@ -342,6 +361,7 @@ struct ExportedProgram { 60: SchemaVersion schema_version; 70: list verifiers; 80: string torch_version; +<<<<<<< HEAD 90: list guards_code; } @@ -354,6 +374,21 @@ struct PayloadMeta { struct PayloadConfig { 10: map config; +======= +} + +struct Program { + 200: map methods; +} + +struct Model { + 10: string name; + 20: map tensorPaths; + 40: Program program; + 50: map delegates; + 60: map deviceAllocationMap; + 70: map constantPaths; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } struct AOTInductorModelPickleData { diff --git a/torch/_export/serde/schema.py b/torch/_export/serde/schema.py index f4ce89c006f59..1733cf048baad 100644 --- a/torch/_export/serde/schema.py +++ b/torch/_export/serde/schema.py @@ -5,11 +5,19 @@ from enum import IntEnum from typing import Annotated, Optional +<<<<<<< HEAD from torch._export.serde.union import _Union, _union_dataclass # NOTE: Please update this value if any modifications are made to the schema SCHEMA_VERSION = (8, 14) +======= +from torch._export.serde.union import _Union + + +# NOTE: Please update this value if any modifications are made to the schema +SCHEMA_VERSION = (8, 8) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TREESPEC_VERSION = 1 @@ -33,8 +41,11 @@ class ScalarType(IntEnum): UINT16 = 28 FLOAT8E4M3FN = 29 FLOAT8E5M2 = 30 +<<<<<<< HEAD FLOAT8E4M3FNUZ = 31 FLOAT8E5M2FNUZ = 32 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class Layout(IntEnum): @@ -62,7 +73,11 @@ class Device: index: Annotated[Optional[int], 20] = None +<<<<<<< HEAD @_union_dataclass +======= +@dataclass(repr=False) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class SymExprHint(_Union): as_int: Annotated[int, 10] as_bool: Annotated[bool, 20] @@ -79,19 +94,31 @@ class SymExpr: hint: Annotated[Optional[SymExprHint], 20] = None +<<<<<<< HEAD @_union_dataclass +======= +@dataclass(repr=False) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class SymInt(_Union): as_expr: Annotated[SymExpr, 10] as_int: Annotated[int, 20] +<<<<<<< HEAD @_union_dataclass +======= +@dataclass(repr=False) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class SymFloat(_Union): as_expr: Annotated[SymExpr, 10] as_float: Annotated[float, 20] +<<<<<<< HEAD @_union_dataclass +======= +@dataclass(repr=False) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class SymBool(_Union): as_expr: Annotated[SymExpr, 10] as_bool: Annotated[bool, 20] @@ -114,7 +141,11 @@ class TensorMeta: # of SymInt and ints (ex. [1, s0, ...]). We will serialize this type of list to # be List[SymIntArgument] and map the SymInts to the "as_name" field, and ints # to the "as_int" field. +<<<<<<< HEAD @_union_dataclass +======= +@dataclass(repr=False) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class SymIntArgument(_Union): as_name: Annotated[str, 10] as_int: Annotated[int, 20] @@ -126,7 +157,11 @@ class SymIntArgument(_Union): # of SymFloat and float (ex. [1.0, s0, ...]). We will serialize this type of list to # be List[SymFloatArgument] and map the SymFloats to the "as_name" field, and ints # to the "as_float" field. +<<<<<<< HEAD @_union_dataclass +======= +@dataclass(repr=False) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class SymFloatArgument(_Union): as_name: Annotated[str, 10] as_float: Annotated[float, 20] @@ -138,7 +173,11 @@ class SymFloatArgument(_Union): # of SymBool and bools (ex. [True, i0, ...]). We will serialize this type of list to # be List[SymboolArgument] and map the SymBools to the "as_name" field, and bools # to the "as_bool" field. +<<<<<<< HEAD @_union_dataclass +======= +@dataclass(repr=False) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class SymBoolArgument(_Union): as_name: Annotated[str, 10] as_bool: Annotated[bool, 20] @@ -158,7 +197,11 @@ class TokenArgument: # (Tensor?[], ex. [Tensor, None, ...]), where the list will be serialized to the # type List[OptionalTensorArgument], with tensor values seiralized to the # "as_tensor" field, and None values serialized to the "as_none" field. +<<<<<<< HEAD @_union_dataclass +======= +@dataclass(repr=False) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class OptionalTensorArgument(_Union): as_tensor: Annotated[TensorArgument, 20] as_none: Annotated[bool, 10] @@ -176,6 +219,7 @@ class CustomObjArgument: class_fqn: Annotated[str, 20] +<<<<<<< HEAD @dataclass class ComplexValue: real: Annotated[float, 10] @@ -184,6 +228,10 @@ class ComplexValue: # This is actually a union type @_union_dataclass +======= +# This is actually a union type +@dataclass(repr=False) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class Argument(_Union): as_none: Annotated[bool, 10] as_tensor: Annotated[TensorArgument, 20] @@ -211,7 +259,10 @@ class Argument(_Union): as_sym_float: Annotated[SymFloatArgument, 230] as_sym_floats: Annotated[list[SymFloatArgument], 240] as_optional_tensor: Annotated[OptionalTensorArgument, 250] +<<<<<<< HEAD as_complex: Annotated[ComplexValue, 260] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class ArgumentKind(IntEnum): @@ -262,7 +313,11 @@ class UserInputSpec: arg: Annotated[Argument, 10] +<<<<<<< HEAD @_union_dataclass +======= +@dataclass(repr=False) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class ConstantValue(_Union): as_none: Annotated[bool, 10] as_int: Annotated[int, 20] @@ -307,7 +362,11 @@ class InputTokenSpec: arg: Annotated[TokenArgument, 10] +<<<<<<< HEAD @_union_dataclass +======= +@dataclass(repr=False) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class InputSpec(_Union): user_input: Annotated[UserInputSpec, 10] parameter: Annotated[InputToParameterSpec, 20] @@ -335,12 +394,15 @@ class BufferMutationSpec: @dataclass +<<<<<<< HEAD class ParameterMutationSpec: arg: Annotated[TensorArgument, 10] parameter_name: Annotated[str, 20] @dataclass +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class GradientToParameterSpec: arg: Annotated[TensorArgument, 10] parameter_name: Annotated[str, 20] @@ -363,7 +425,11 @@ class OutputTokenSpec: arg: Annotated[TokenArgument, 10] +<<<<<<< HEAD @_union_dataclass +======= +@dataclass(repr=False) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class OutputSpec(_Union): user_output: Annotated[UserOutputSpec, 10] loss_output: Annotated[LossOutputSpec, 20] @@ -372,7 +438,10 @@ class OutputSpec(_Union): gradient_to_user_input: Annotated[GradientToUserInputSpec, 50] user_input_mutation: Annotated[UserInputMutationSpec, 60] token: Annotated[OutputTokenSpec, 70] +<<<<<<< HEAD parameter_mutation: Annotated[ParameterMutationSpec, 80] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclass @@ -398,7 +467,11 @@ class ModuleCallSignature: out_spec: Annotated[str, 40] # This field is used to prettify the graph placeholders +<<<<<<< HEAD # after we Ser/Der and retrace +======= + # after we ser/der and retrace +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) forward_arg_names: Annotated[Optional[list[str]], 50] = None @@ -429,7 +502,11 @@ class GraphModule: # Invariant: Every time a change is made to the schema, one of the versions +<<<<<<< HEAD # should be updated. +======= +# should be upadted. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclass class SchemaVersion: major: Annotated[ @@ -449,7 +526,10 @@ class ExportedProgram: schema_version: Annotated[SchemaVersion, 60] verifiers: Annotated[list[str], 70] = field(default_factory=list) torch_version: Annotated[str, 80] = "<=2.4" +<<<<<<< HEAD guards_code: Annotated[list[str], 90] = field(default_factory=list) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ######################################################################### @@ -457,6 +537,7 @@ class ExportedProgram: ######################################################################### +<<<<<<< HEAD # The metadata for payload saved in PT2 archive. # payload includes params, buffers, tensor constants, and custom objects. @dataclass @@ -476,6 +557,31 @@ class PayloadMeta: @dataclass class PayloadConfig: config: Annotated[dict[str, PayloadMeta], 10] +======= +@dataclass +class Program: + methods: Annotated[dict[str, ExportedProgram], 200] + + +# This is the top-level model definition that be will serialized into the package +@dataclass +class Model: + # unique identifier of the model in the package, e.g. local, remote, merge + name: Annotated[str, 10] + # key is the FQN of tensor in exported program + # value is the archive path of tensor payloads + # e.g. "L__self__linear.weight" : "/data/tensor/L__self__linear.weight" + tensorPaths: Annotated[dict[str, str], 20] + # program exported from torch.export() + program: Annotated[Program, 40] + # Backend-specialized Lowered GraphModule + # e.g. "aotinductor-a100" : ExportedProgram_with_AOTInductor_delegate + delegates: Annotated[dict[str, Program], 50] + deviceAllocationMap: Annotated[dict[str, str], 60] + # key is the FQN of constant in exported program (constant tensor or torchbind objs) + # value is the archive path of serialized constants + constantPaths: Annotated[dict[str, str], 70] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # diff --git a/torch/_export/serde/schema.yaml b/torch/_export/serde/schema.yaml index 951351e7786aa..6189ad0a4323b 100644 --- a/torch/_export/serde/schema.yaml +++ b/torch/_export/serde/schema.yaml @@ -1,5 +1,9 @@ # @generated by update_schema.py +<<<<<<< HEAD # checksum<<74d07b92c36d5854263145c231553dcda15215f0460e7ace43554248c05378ec>> +======= +# checksum<<110c364974d3b0f7dcbdf6862781212bdcc7178925c43c894c336fc2b6ca6628>> +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTInductorModelPickleData: kind: struct fields: @@ -73,8 +77,11 @@ Argument: type: List[SymFloatArgument] as_optional_tensor: type: OptionalTensorArgument +<<<<<<< HEAD as_complex: type: ComplexValue +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ArgumentKind: kind: enum fields: @@ -88,6 +95,7 @@ BufferMutationSpec: type: TensorArgument buffer_name: type: str +<<<<<<< HEAD ComplexValue: kind: struct fields: @@ -95,6 +103,8 @@ ComplexValue: type: float imag: type: float +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ConstantValue: kind: union fields: @@ -140,9 +150,12 @@ ExportedProgram: torch_version: type: str default: <=2.4 +<<<<<<< HEAD guards_code: type: List[str] default: '[]' +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ExternKernelNode: kind: struct fields: @@ -305,6 +318,24 @@ MemoryFormat: ChannelsLast: 2 ChannelsLast3d: 3 PreserveFormat: 4 +<<<<<<< HEAD +======= +Model: + kind: struct + fields: + name: + type: str + tensorPaths: + type: Dict[str, str] + program: + type: Program + delegates: + type: Dict[str, Program] + deviceAllocationMap: + type: Dict[str, str] + constantPaths: + type: Dict[str, str] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ModuleCallEntry: kind: struct fields: @@ -380,13 +411,17 @@ OutputSpec: type: UserInputMutationSpec token: type: OutputTokenSpec +<<<<<<< HEAD parameter_mutation: type: ParameterMutationSpec +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) OutputTokenSpec: kind: struct fields: arg: type: TokenArgument +<<<<<<< HEAD ParameterMutationSpec: kind: struct fields: @@ -410,6 +445,13 @@ PayloadMeta: type: bool tensor_meta: type: Optional[TensorMeta] +======= +Program: + kind: struct + fields: + methods: + type: Dict[str, ExportedProgram] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) RangeConstraint: kind: struct fields: @@ -437,8 +479,11 @@ ScalarType: UINT16: 28 FLOAT8E4M3FN: 29 FLOAT8E5M2: 30 +<<<<<<< HEAD FLOAT8E4M3FNUZ: 31 FLOAT8E5M2FNUZ: 32 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SchemaVersion: kind: struct fields: @@ -551,5 +596,9 @@ UserOutputSpec: type: Argument SCHEMA_VERSION: - 8 +<<<<<<< HEAD - 14 +======= +- 8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TREESPEC_VERSION: 1 diff --git a/torch/_export/serde/schema_check.py b/torch/_export/serde/schema_check.py index 29b9766ae18a4..b732f6e731c7d 100644 --- a/torch/_export/serde/schema_check.py +++ b/torch/_export/serde/schema_check.py @@ -129,6 +129,7 @@ def dump_field(f) -> tuple[dict[str, Any], str, Optional[str], str, int]: t, cpp_type, thrift_type = dump_type(f.type, 0) ret = {"type": t} cpp_default: Optional[str] = None +<<<<<<< HEAD assert typing.get_origin(f.type) == Annotated, ( f"Field {f.name} must be annotated with an integer id." ) @@ -136,6 +137,15 @@ def dump_field(f) -> tuple[dict[str, Any], str, Optional[str], str, int]: assert type(thrift_id) is int, ( f"Field {f.name} must be annotated with an integer id." ) +======= + assert ( + typing.get_origin(f.type) == Annotated + ), f"Field {f.name} must be annotated with an integer id." + thrift_id = f.type.__metadata__[0] + assert ( + type(thrift_id) is int + ), f"Field {f.name} must be annotated with an integer id." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) value = dataclasses.MISSING if f.default is not dataclasses.MISSING: @@ -173,7 +183,13 @@ def dump_field(f) -> tuple[dict[str, Any], str, Optional[str], str, int]: def _handle_int_enum(name, ty): yaml_ret[name] = {"kind": "enum", "fields": {x.name: x.value for x in ty}} +<<<<<<< HEAD cpp_enum_defs[name] = f""" +======= + cpp_enum_defs[ + name + ] = f""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) enum class {name} {{ {chr(10).join([f" {x.name} = {x.value}," for x in ty])} }}; @@ -238,6 +254,7 @@ def accessor(name, ty): from_json_def = f"""{{ {name} nlohmann_json_default_obj; +<<<<<<< HEAD { chr(10).join( [ @@ -249,6 +266,16 @@ def accessor(name, ty): }} """ cpp_class_defs[name] = f""" +======= +{chr(10).join( + [f' nlohmann_json_t.{name} = nlohmann_json_j.value("{name}", nlohmann_json_default_obj.{name});' + for name, f in cpp_fields.items()])} +}} +""" + cpp_class_defs[ + name + ] = f""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class {name} {{ private: {field_decls} @@ -263,7 +290,13 @@ class {name} {{ cpp_json_defs.append(f"inline {from_json_decl} {from_json_def}") cpp_type_decls.append(f"class {name};") +<<<<<<< HEAD thrift_type_defs[name] = f""" +======= + thrift_type_defs[ + name + ] = f""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) struct {name} {{ {chr(10).join(f" {f['thrift_id']}: {f['thrift_type']} {n};" for n, f in thrift_fields.items())} }}""" @@ -306,7 +339,13 @@ def accessor(name, ty, idx): ] ) +<<<<<<< HEAD cpp_class_defs[name] = f""" +======= + cpp_class_defs[ + name + ] = f""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class {name} {{ struct Void {{}}; @@ -349,7 +388,13 @@ class {name} {{ """ cpp_type_decls.append(f"class {name};") +<<<<<<< HEAD thrift_type_defs[name] = f""" +======= + thrift_type_defs[ + name + ] = f""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) union {name} {{ {chr(10).join(f" {f['thrift_id']}: {f['thrift_type']} {n};" for n, f in thrift_fields.items())} }}""" @@ -448,7 +493,10 @@ class ForwardRef {{ ptr_ = std::make_unique(*other.ptr_); return *this; }} +<<<<<<< HEAD ~ForwardRef(); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const T& operator*() const {{ return *ptr_; }} @@ -520,7 +568,10 @@ class F64 {{ template ForwardRef::ForwardRef(ForwardRef&&) = default; template ForwardRef& ForwardRef::operator=(ForwardRef&&) = default; +<<<<<<< HEAD template ForwardRef::~ForwardRef() = default; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }} // namespace _export }} // namespace torch """ @@ -691,7 +742,11 @@ def check(commit: _Commit, force_unsafe: bool = False): for f, d in fields.items(): if kind == "struct" and "default" not in d: reason += ( +<<<<<<< HEAD f"Field {k}.{f} is added to schema.py without a default value as an incompatible change " +======= + f"Field {k}.{f} is added to schema.py without a default value as an incomparible change " +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) + "which requires major version bump.\n" ) next_version = [commit.base["SCHEMA_VERSION"][0] + 1, 1] diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 07674b5702947..22de9f6b3bd90 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -14,7 +14,11 @@ import traceback import typing from collections import namedtuple, OrderedDict +<<<<<<< HEAD from collections.abc import Iterable, Iterator, Sequence +======= +from collections.abc import Iterable, Iterator +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from contextlib import contextmanager from dataclasses import dataclass, field from enum import Enum @@ -35,14 +39,20 @@ from torch.utils._sympy.symbol import prefix_str, SymT from torch.utils._sympy.value_ranges import ValueRanges from torch.utils._traceback import CapturedTraceback +<<<<<<< HEAD from torch.utils._triton import has_triton +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from ..utils import remove_proxy_from_state_dict from .schema import ( # type: ignore[attr-defined] Argument, ArgumentKind, BufferMutationSpec, +<<<<<<< HEAD ComplexValue, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ConstantValue, CustomObjArgument, Device, @@ -71,7 +81,10 @@ OptionalTensorArgument, OutputSpec, OutputTokenSpec, +<<<<<<< HEAD ParameterMutationSpec, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) RangeConstraint, ScalarType, SCHEMA_VERSION, @@ -146,8 +159,11 @@ def _reverse_map(d: dict[Any, Enum]): torch.bfloat16: ScalarType.BFLOAT16, torch.float8_e4m3fn: ScalarType.FLOAT8E4M3FN, torch.float8_e5m2: ScalarType.FLOAT8E5M2, +<<<<<<< HEAD torch.float8_e4m3fnuz: ScalarType.FLOAT8E4M3FNUZ, torch.float8_e5m2fnuz: ScalarType.FLOAT8E5M2FNUZ, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } @@ -227,6 +243,7 @@ class _SerializedProgram: example_inputs: bytes +<<<<<<< HEAD class LazyMap(dict): """ Dictionary class for deferred instantiation of node metadata values. @@ -252,12 +269,15 @@ def __repr__(self): return self.map.__repr__() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def deserialize_device(d: Device) -> torch.device: if d.index is None: return torch.device(type=d.type) # type: ignore[call-overload] return torch.device(type=d.type, index=d.index) +<<<<<<< HEAD def deserialize_size(sizes: Sequence[SymInt]) -> tuple[int, ...]: for sym_int_size in sizes: assert sym_int_size.type == "as_int", ( @@ -283,6 +303,8 @@ def deserialize_storage_offset(offset: SymInt) -> int: return offset.as_int +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _print_sympy(s: Union[torch.SymInt, torch.SymBool, torch.SymFloat, sympy.Expr]): if isinstance(s, (torch.SymInt, torch.SymBool, torch.SymFloat)): s = s.node.expr @@ -353,7 +375,11 @@ def serialize_tensor_meta(t: torch.Tensor) -> TensorMeta: requires_grad=t.requires_grad, device=Device(type=t.device.type, index=t.device.index), strides=[serialize_sym_int(s) for s in t.stride()], +<<<<<<< HEAD storage_offset=serialize_sym_int(t.storage_offset()), +======= + storage_offset=serialize_sym_int(0), # TODO needs to be fixed. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) layout=_TORCH_TO_SERIALIZE_LAYOUT[t.layout], ) @@ -377,9 +403,15 @@ def _reconstruct_fake_tensor( json_tensor_meta = json.loads(serialized_tensor_meta.decode("utf-8")) tensor_meta = _dict_to_dataclass(TensorMeta, json_tensor_meta) # Find the current fake mode +<<<<<<< HEAD assert _CURRENT_DESERIALIZER is not None, ( "Need access to current deserializer state" ) +======= + assert ( + _CURRENT_DESERIALIZER is not None + ), "Need access to current deserializer state" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fake_tensor = _CURRENT_DESERIALIZER.deserialize_tensor_meta(tensor_meta) if is_parameter: fake_tensor = torch.nn.Parameter(fake_tensor) # type: ignore[assignment] @@ -392,9 +424,15 @@ def serialize_torch_artifact( if artifact is None: return b"" +<<<<<<< HEAD assert FakeTensor not in copyreg.dispatch_table, ( "Refusing to stomp on existing FakeTensor reducer" ) +======= + assert ( + FakeTensor not in copyreg.dispatch_table + ), "Refusing to stomp on existing FakeTensor reducer" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: copyreg.pickle(FakeTensor, _reduce_fake_tensor) buffer = io.BytesIO() @@ -411,7 +449,11 @@ def serialize_torch_artifact( def deserialize_torch_artifact( +<<<<<<< HEAD serialized: Union[dict[str, Any], tuple[Any, ...], bytes], +======= + serialized: Union[dict[str, Any], tuple[Any, ...], bytes] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): if isinstance(serialized, (dict, tuple)): return serialized @@ -470,7 +512,11 @@ def _symbol_index(sym: sympy.Symbol, sym_type: SymT): def serialize_range_constraints( +<<<<<<< HEAD range_constraints: dict[sympy.Symbol, ValueRanges], +======= + range_constraints: dict[sympy.Symbol, ValueRanges] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> dict[str, RangeConstraint]: return { str(k): RangeConstraint( @@ -554,9 +600,15 @@ def handle_placeholder(self, node: torch.fx.Node): graph_input = Argument.create( as_custom_obj=CustomObjArgument(name=node.name, class_fqn=class_fqn) ) +<<<<<<< HEAD self.graph_state.custom_obj_values[node.name] = ( self.serialize_script_obj_meta(val) ) +======= + self.graph_state.custom_obj_values[ + node.name + ] = self.serialize_script_obj_meta(val) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: raise AssertionError(f"Unimplemented graph input type: {node.meta['val']}") self.graph_state.inputs.append(graph_input) @@ -672,6 +724,7 @@ def serialize_tensor_list_output(node): metadata=self.serialize_metadata(node), is_hop_single_tensor_return=False, ) +<<<<<<< HEAD elif ( node.target is torch._higher_order_ops.triton_kernel_wrap.triton_kernel_wrapper_functional @@ -742,6 +795,8 @@ def serialize_tensor_list_output(node): metadata=self.serialize_metadata(node), is_hop_single_tensor_return=_is_hop_single_tensor_return(node), ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: ex_node = Node( target=self.serialize_operator(node.target), @@ -752,9 +807,15 @@ def serialize_tensor_list_output(node): ) elif type(node.target) in _serialization_registry: # Sanity check for unhandled serialization. +<<<<<<< HEAD assert type(node.target) in _serialization_registry, ( f"{type(node.target)} is not supported in export serialization." ) +======= + assert ( + type(node.target) in _serialization_registry + ), f"{type(node.target)} is not supported in export serialization." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) handler = _serialization_registry[type(node.target)] namespace = handler.namespace() @@ -982,6 +1043,7 @@ def serialize_input(self, arg, arg_type: Optional[Any] = None) -> Argument: return Argument.create( as_graph=GraphArgument(name=arg.target, graph=graph) ) +<<<<<<< HEAD elif type(attr).__name__ == "LoweredBackendModule": # Special handling for executorch_call_delegate HOP # It's first argument is a LoweredBackendModule, for which we @@ -991,6 +1053,8 @@ def serialize_input(self, arg, arg_type: Optional[Any] = None) -> Argument: assert module_name is not None, "module_name should not be None" assert backend_id is not None, "backend_id should not be None" return Argument.create(as_string=f"{module_name}-{backend_id}") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: raise SerializeError( f"Unsupported getattr attribute {arg.target} with type: {type(attr)}" @@ -1058,10 +1122,13 @@ def serialize_input(self, arg, arg_type: Optional[Any] = None) -> Argument: return Argument.create(as_int=arg) elif type(arg) is float: return Argument.create(as_float=arg) +<<<<<<< HEAD elif type(arg) is complex: return Argument.create( as_complex=ComplexValue(real=arg.real, imag=arg.imag) ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif arg is None: return Argument.create(as_none=True) elif isinstance(arg, (list, tuple)): @@ -1352,6 +1419,7 @@ def serialize_output_spec(self, spec: ep.OutputSpec) -> OutputSpec: buffer_name=spec.target, ) ) +<<<<<<< HEAD elif spec.kind == ep.OutputKind.PARAMETER_MUTATION: assert spec.target is not None assert isinstance(spec.arg, ep.TensorArgument) @@ -1361,6 +1429,8 @@ def serialize_output_spec(self, spec: ep.OutputSpec) -> OutputSpec: parameter_name=spec.target, ) ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif spec.kind == ep.OutputKind.GRADIENT_TO_PARAMETER: assert spec.target is not None assert isinstance(spec.arg, ep.TensorArgument) @@ -1442,9 +1512,15 @@ def store_namedtuple_fields(ts): f"but somehow previously was found to have field names {field_names}." ) else: +<<<<<<< HEAD self.treespec_namedtuple_fields[serialized_type_name] = ( NamedTupleDef(field_names=ts.context._fields) ) +======= + self.treespec_namedtuple_fields[ + serialized_type_name + ] = NamedTupleDef(field_names=ts.context._fields) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for child in ts.children_specs: store_namedtuple_fields(child) @@ -1555,7 +1631,11 @@ def _is_single_tensor_list_return(target: Any) -> bool: assert isinstance( return_schema.real_type, (torch.OptionalType, torch.TensorType) ) +<<<<<<< HEAD # When the return type is annotated as Tensor type, the op can also return an +======= + # When the return type is annoated as Tensor type, the op can also return an +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # undefined Tensor which will be implicitly converted to None in Python. output_arguments.append(Argument.create(as_none=True)) elif isinstance(meta, FakeTensor): @@ -1626,6 +1706,7 @@ def serialize_hoo_outputs(self, node: torch.fx.Node) -> list[Argument]: outputs.append(self.serialize_output(name, element_meta_val)) return outputs +<<<<<<< HEAD elif isinstance(meta_val, dict): tensor_args = [] # use the dict key as the idx @@ -1637,6 +1718,8 @@ def serialize_hoo_outputs(self, node: torch.fx.Node) -> list[Argument]: name = self._output_node_name_at_index(node, idx) tensor_args.append(self.serialize_tensor_output(name, meta)) return [Argument.create(as_tensors=tensor_args)] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: return [self.serialize_output(node.name, meta_val)] @@ -1674,9 +1757,15 @@ def _handle_getitem_users(self, node: torch.fx.Node) -> list[TensorArgument]: idx_to_name = {} for user in node.users: +<<<<<<< HEAD assert user.target is operator.getitem, ( f"User node {user} of {node} is incorrect" ) +======= + assert ( + user.target is operator.getitem + ), f"User node {user} of {node} is incorrect" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) idx_to_name[user.args[1]] = user.name for idx, _ in enumerate(meta_val): @@ -1794,7 +1883,10 @@ def serialize(self, exported_program: ep.ExportedProgram) -> _SerializedProgram: ), verifiers=[v.dialect for v in exported_program.verifiers], torch_version=torch.__version__, +<<<<<<< HEAD guards_code=exported_program._guards_code, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # Test canonical form is well defined. @@ -1828,7 +1920,11 @@ class Result: def __init__(self) -> None: self.serialized_name_to_node: dict[str, torch.fx.Node] = {} +<<<<<<< HEAD self.serialized_name_to_meta: LazyMap = LazyMap() # str -> MetaType +======= + self.serialized_name_to_meta: dict[str, MetaType] = {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.graph = torch.fx.Graph() self.module = torch.nn.Module() @@ -1844,7 +1940,11 @@ def save_graph_module(self) -> Iterator[None]: self.graph = torch.fx.Graph() self.module = torch.nn.Module() self.serialized_name_to_node = {} +<<<<<<< HEAD self.serialized_name_to_meta = LazyMap() +======= + self.serialized_name_to_meta = {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.unbacked_symbols: set[sympy.Symbol] = set() try: yield @@ -2033,6 +2133,7 @@ def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph: # Handle the tensor metas. for name, tensor_value in serialized_graph.tensor_values.items(): log.debug("[deserialize_tensor_meta] %s (input): %s", name, tensor_value) +<<<<<<< HEAD self.serialized_name_to_meta[name] = ( lambda v=tensor_value: self.deserialize_tensor_meta(v) ) @@ -2059,6 +2160,34 @@ def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph: log.debug("[deserialize_script_obj_meta] %s", script_obj_meta) self.serialized_name_to_meta[name] = ( lambda v=script_obj_meta: self.deserialize_script_obj_meta(v) +======= + meta_val = self.deserialize_tensor_meta(tensor_value) + log.debug("[deserialize_tensor_meta] %s (output): %s", name, meta_val) + self.serialized_name_to_meta[name] = meta_val + + for name, sym_int_value in serialized_graph.sym_int_values.items(): + log.debug("[deserialize_sym_int] %s (input): %s", name, sym_int_value) + int_val = self.deserialize_sym_int(sym_int_value) + log.debug("[deserialize_sym_int] %s (output): %s", name, int_val) + self.serialized_name_to_meta[name] = int_val + + for name, sym_float_value in serialized_graph.sym_float_values.items(): + log.debug("[deserialize_sym_float] %s (input): %s", name, sym_float_value) + float_val = self.deserialize_sym_float(sym_float_value) + log.debug("[deserialize_sym_float] %s (output): %s", name, float_val) + self.serialized_name_to_meta[name] = float_val + + for name, sym_bool_value in serialized_graph.sym_bool_values.items(): + log.debug("[deserialize_sym_bool] %s (input): %s", name, sym_bool_value) + bool_val = self.deserialize_sym_bool(sym_bool_value) + log.debug("[deserialize_sym_bool] %s (output): %s", name, bool_val) + self.serialized_name_to_meta[name] = bool_val + + for name, script_obj_meta in serialized_graph.custom_obj_values.items(): + log.debug("[deserialize_script_obj_meta] %s", script_obj_meta) + self.serialized_name_to_meta[name] = self.deserialize_script_obj_meta( + script_obj_meta +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) log.debug("\n[deserialize graph nodes]") @@ -2164,6 +2293,7 @@ def _is_single_tensor_return(target) -> bool: fx_node = self.graph.create_node("call_function", target, args, {}, name) self.deserialize_sym_op_outputs(serialized_node, fx_node) +<<<<<<< HEAD elif ( target is torch._higher_order_ops.triton_kernel_wrap.triton_kernel_wrapper_functional @@ -2171,6 +2301,9 @@ def _is_single_tensor_return(target) -> bool: raise SerializeError( "deserialize nyi for torch._higher_order_ops.triton_kernel_wrap.triton_kernel_wrapper_functional" ) +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif isinstance(target, torch._ops.HigherOrderOperator): args, kwargs = self.deserialize_hoo_inputs(serialized_node.inputs) metadata = self.deserialize_metadata(serialized_node.metadata) @@ -2222,7 +2355,11 @@ def _is_single_tensor_return(target) -> bool: _additional_msg = ( ( f"We failed to resolve {target} to an operator. " +<<<<<<< HEAD + "If it's a custom op/custom triton op, this is usually because the custom op is not registered" +======= + + "If it's a custom op/custom triton op, this is usally because the custom op is not registered" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) + " when deserializing. Please import the custom op to register it before deserializing." + " Otherwise, please file an issue on github." ) @@ -2243,6 +2380,7 @@ def _is_single_tensor_return(target) -> bool: fx_node.kwargs, fx_node.meta.get("val"), ) +<<<<<<< HEAD # handle ShapeEnv asserts if target == torch.ops.aten._assert_scalar.default: @@ -2258,11 +2396,19 @@ def _is_single_tensor_return(target) -> bool: self.shape_env._constrain_range_for_size(sym.node.expr) # handle nn_module_stack; serialization throws away empty dicts +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if ( fx_node.op not in ["placeholder", "output"] and "nn_module_stack" not in fx_node.meta ): +<<<<<<< HEAD fx_node.meta["nn_module_stack"] = {} +======= + fx_node.meta[ + "nn_module_stack" + ] = {} # serialization throws away empty dicts +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def deserialize_input_spec(self, i: InputSpec) -> ep.InputSpec: log.debug("[deserialize_input_spec] %s", i) @@ -2337,12 +2483,15 @@ def deserialize_output_spec(self, o: OutputSpec) -> ep.OutputSpec: arg=ep.TensorArgument(name=o.buffer_mutation.arg.name), target=o.buffer_mutation.buffer_name, ) +<<<<<<< HEAD elif o.type == "parameter_mutation": return ep.OutputSpec( kind=ep.OutputKind.PARAMETER_MUTATION, arg=ep.TensorArgument(name=o.parameter_mutation.arg.name), target=o.parameter_mutation.parameter_name, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif o.type == "gradient_to_parameter": return ep.OutputSpec( kind=ep.OutputKind.GRADIENT_TO_PARAMETER, @@ -2445,6 +2594,11 @@ def deserialize( if symbol_name_to_range: for k, vr in symbol_name_to_range.items(): lower = vr.lower +<<<<<<< HEAD +======= + if vr.upper >= 2: # max is >= 2, not sym bool range + lower = max(2, lower) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.symbol_name_to_range[k] = symbolic_shapes.ValueRanges( _int_to_sympy_int(lower, -int_oo), vr.upper ) @@ -2596,8 +2750,11 @@ def deserialize_input(self, inp: Argument) -> Any: return inp.as_bool elif typ_ == "as_string": return inp.as_string +<<<<<<< HEAD elif typ_ == "as_complex": return complex(inp.as_complex.real, inp.as_complex.imag) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif typ_ == "as_sym_int": return self.deserialize_sym_argument(inp.as_sym_int) elif typ_ == "as_sym_float": @@ -3039,7 +3196,10 @@ def deserialize( constants=res.constants, verifiers=[load_verifier(v) for v in exported_program.verifiers], ) +<<<<<<< HEAD result._guards_code = exported_program.guards_code +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) log.debug("\n[deserialize]: %s", result) return result @@ -3072,7 +3232,11 @@ def _dataclass_to_dict(obj): return "Infinity" elif obj == -math.inf: return "-Infinity" +<<<<<<< HEAD elif math.isnan(obj): +======= + elif obj == math.nan: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return "NaN" else: return obj @@ -3126,6 +3290,7 @@ def _dict_to_dataclass(cls, data): field_type = cls.__annotations__[_type] return cls.create(**{_type: _dict_to_dataclass(field_type, _value)}) elif dataclasses.is_dataclass(cls): +<<<<<<< HEAD fields = {} type_hints = typing.get_type_hints(cls) # For forward compatibility consideration, we ignore all the keys @@ -3137,6 +3302,15 @@ def _dict_to_dataclass(cls, data): new_field_obj = _dict_to_dataclass(type_hints[name], data[name]) fields[name] = new_field_obj return cls(**fields) # type: ignore[operator] +======= + obj = cls(**data) # type: ignore[assignment,operator] + type_hints = typing.get_type_hints(cls) + for f in dataclasses.fields(cls): + name = f.name + new_field_obj = _dict_to_dataclass(type_hints[name], getattr(obj, name)) + setattr(obj, name, new_field_obj) + return obj +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif isinstance(data, list): if len(data) == 0: return data @@ -3150,6 +3324,7 @@ def _dict_to_dataclass(cls, data): return data +<<<<<<< HEAD def _bytes_to_dataclass(cls: Any, artifact_bytes: bytes) -> Any: artifact_str = artifact_bytes.decode("utf-8") artifact_dict = json.loads(artifact_str) @@ -3157,6 +3332,8 @@ def _bytes_to_dataclass(cls: Any, artifact_bytes: bytes) -> Any: return artifact_dataclass +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def deserialize( artifact: SerializedArtifact, expected_opset_version: Optional[dict[str, int]] = None, @@ -3164,8 +3341,15 @@ def deserialize( _unsafe_skip_version_check=False, ) -> ep.ExportedProgram: assert isinstance(artifact.exported_program, bytes) +<<<<<<< HEAD serialized_exported_program = _bytes_to_dataclass( ExportedProgram, artifact.exported_program +======= + exported_program_str = artifact.exported_program.decode("utf-8") + exported_program_dict = json.loads(exported_program_str) + serialized_exported_program = _dict_to_dataclass( + ExportedProgram, exported_program_dict +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return ExportedProgramDeserializer(expected_opset_version).deserialize( serialized_exported_program, @@ -3198,8 +3382,11 @@ def _get_argument(a: Argument): return None elif a.type == "as_strings": return None +<<<<<<< HEAD elif a.type == "as_complex": return None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif a.type == "as_sym_int": return a.as_sym_int elif a.type == "as_sym_ints": @@ -3504,7 +3691,10 @@ def canonicalize( range_constraints = dict( sorted(ep.range_constraints.items(), key=operator.itemgetter(0)) ) +<<<<<<< HEAD guards_code = sorted(ep.guards_code) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) module_call_graph = sorted(ep.graph_module.module_call_graph, key=lambda x: x.fqn) signature = ep.graph_module.signature graph = ep.graph_module.graph @@ -3536,6 +3726,7 @@ def rank_output(out) -> tuple[int, Optional[str], int]: idx, (_arg, spec) = out assert isinstance(spec, OutputSpec) if spec.type == "user_output": +<<<<<<< HEAD return 4, None, idx elif spec.type == "loss_output": return 4, None, idx @@ -3549,6 +3740,19 @@ def rank_output(out) -> tuple[int, Optional[str], int]: return 6, None, idx elif spec.type == "user_input_mutation": return 3, None, idx +======= + return 3, None, idx + elif spec.type == "loss_output": + return 3, None, idx + elif spec.type == "buffer_mutation": + return 1, spec.buffer_mutation.buffer_name, idx + elif spec.type == "gradient_to_parameter": + return 4, spec.gradient_to_parameter.parameter_name, idx + elif spec.type == "gradient_to_user_input": + return 5, None, idx + elif spec.type == "user_input_mutation": + return 2, None, idx +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif spec.type == "token": return 0, None, idx else: @@ -3661,9 +3865,12 @@ def replace_output(out): elif spec.type == "buffer_mutation": t = spec.buffer_mutation.arg t.name = replace_table[t.name] +<<<<<<< HEAD elif spec.type == "parameter_mutation": t = spec.parameter_mutation.arg t.name = replace_table[t.name] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif spec.type == "gradient_to_parameter": t = spec.gradient_to_parameter.arg t.name = replace_table[t.name] @@ -3701,7 +3908,10 @@ def replace_output(out): schema_version=ep.schema_version, verifiers=ep.verifiers, torch_version=ep.torch_version, +<<<<<<< HEAD guards_code=guards_code, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @@ -3732,9 +3942,15 @@ def register_extension( extension_handler: type[ExtensionHandler], ): """Register custom de/serialization method for a node with non-standard type.""" +<<<<<<< HEAD assert issubclass(extension_handler, ExtensionHandler), ( f"Expected ExtensionHandler, got {extension_handler}." ) +======= + assert issubclass( + extension_handler, ExtensionHandler + ), f"Expected ExtensionHandler, got {extension_handler}." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert op_type not in _serialization_registry, f"{op_type} is already registered." assert isinstance(op_type, type) # Maybe a good idea to enforce this first. assert not ( diff --git a/torch/_export/serde/union.py b/torch/_export/serde/union.py index c65ad38d337fe..c35e8070e0786 100644 --- a/torch/_export/serde/union.py +++ b/torch/_export/serde/union.py @@ -1,12 +1,16 @@ # mypy: allow-untyped-defs import functools from collections.abc import Hashable +<<<<<<< HEAD from dataclasses import dataclass, fields from typing import TypeVar from typing_extensions import dataclass_transform T = TypeVar("T", bound="_Union") +======= +from dataclasses import fields +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class _UnionTag(str): @@ -23,9 +27,15 @@ def create(t, cls): def __eq__(self, cmp) -> bool: assert isinstance(cmp, str) other = str(cmp) +<<<<<<< HEAD assert other in _get_field_names(self._cls), ( f"{other} is not a valid tag for {self._cls}. Available tags: {_get_field_names(self._cls)}" ) +======= + assert other in _get_field_names( + self._cls + ), f"{other} is not a valid tag for {self._cls}. Available tags: {_get_field_names(self._cls)}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return str(self) == other def __hash__(self): @@ -37,6 +47,7 @@ def _get_field_names(cls) -> set[str]: return {f.name for f in fields(cls)} +<<<<<<< HEAD # If you turn a schema class that inherits from union into a dataclass, please use # this decorator to configure it. It's safe, faster and allows code sharing. # @@ -49,6 +60,8 @@ def _union_dataclass(cls: type[T]) -> type[T]: return dataclass(repr=False, eq=False)(cls) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class _Union: _type: _UnionTag @@ -60,10 +73,14 @@ def create(cls, **kwargs): return obj def __post_init__(self): +<<<<<<< HEAD assert not any( f.name in ("type", "_type", "create", "value") for f in fields(self) # type: ignore[arg-type, misc] ) +======= + assert not any(f.name in ("type", "_type", "create", "value") for f in fields(self)) # type: ignore[arg-type, misc] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @property def type(self) -> str: @@ -84,11 +101,14 @@ def __getattribute__(self, name): raise AttributeError(f"Field {name} is not set.") return attr +<<<<<<< HEAD def __eq__(self, other: object) -> bool: if not isinstance(other, _Union): return False return self.type == other.type and self.value == other.value +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __str__(self): return self.__repr__() diff --git a/torch/_export/utils.py b/torch/_export/utils.py index b7807145a9fa8..a3cde8315975c 100644 --- a/torch/_export/utils.py +++ b/torch/_export/utils.py @@ -8,7 +8,10 @@ import math import operator import re +<<<<<<< HEAD from collections import defaultdict +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from collections.abc import Iterable from contextlib import contextmanager from inspect import ismethod, Parameter @@ -19,7 +22,10 @@ from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode from torch._subclasses.functional_tensor import FunctionalTensor from torch.fx._utils import first_call_function_nn_module_stack +<<<<<<< HEAD from torch.fx.experimental.proxy_tensor import PreDispatchTorchFunctionMode +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts @@ -212,6 +218,7 @@ def _getattr(model: torch.fx.GraphModule, attr_name: str): return params_buffers_to_node_meta +<<<<<<< HEAD def _maybe_find_pre_dispatch_tf_mode_for_export(): if not torch._C._is_torch_function_mode_enabled(): return None @@ -235,6 +242,8 @@ def _maybe_find_pre_dispatch_tf_mode_for_export(): return mode +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _populate_param_buffer_metadata_to_new_gm( params_buffers_to_node_meta: dict[str, Any], gm: torch.fx.GraphModule, @@ -280,8 +289,11 @@ def _get_shape_env_from_gm(gm: torch.fx.GraphModule): def _rename_without_collisions( name_map: dict[str, str], +<<<<<<< HEAD find_available: dict[str, int], used_names: set[str], +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) orig_name: str, name: str, is_placeholder: bool = False, @@ -289,12 +301,16 @@ def _rename_without_collisions( """ Renames nodes to avoid name collisions, with suffixing. name_map: map from original name to new name +<<<<<<< HEAD find_available: map prefix to available suffix used_names: cache of used names +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) orig_name: mapping key name: candidate name (potentially suffixed, e.g. mul_2) is_placeholder: if the node is a placeholder, avoid detecting suffix """ +<<<<<<< HEAD match = re.match(r"(.*)_(\d+)", name) key = name @@ -315,6 +331,21 @@ def _rename_without_collisions( name_map[orig_name] = new_name used_names.add(new_name) +======= + if name in name_map.values(): + # non-placeholder nodes may be suffixed with the count + # instead of adding another suffix, we will try to increment it + match = re.match(r"(.*)_(\d+)", name) + if match and not is_placeholder: + name, n = match.group(1), int(match.group(2)) + else: + n = 0 + while (dup_name := f"{name}_{n + 1}") in name_map.values(): + n += 1 + name_map[orig_name] = dup_name + else: + name_map[orig_name] = name +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return name_map[orig_name] @@ -331,7 +362,11 @@ def get_keystr(key_path: KeyPath) -> str: return f"*args{keystr(key_path[1:])}" else: kwarg_key = key_path[1] +<<<<<<< HEAD assert isinstance(kwarg_key, (GetAttrKey, MappingKey)) +======= + assert isinstance(kwarg_key, MappingKey) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) name = str(kwarg_key)[1:-1] # get rid of the enclosed [] return f"{name}{keystr(key_path[2:])}" @@ -419,7 +454,11 @@ def _check_symint( # this means we deferred a guard from export analysis to runtime, let this pass # we'll add a runtime assert checking equality to this replacement expression pass +<<<<<<< HEAD elif arg != int(symint): +======= + elif arg != symint: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) path = get_keystr(keypath) if i is not None: path += f".shape[{i}]" @@ -484,9 +523,15 @@ def register_dataclass_as_pytree_node( from_dumpable_context: Optional[FromDumpableContextFn] = None, return_none_fields: bool = False, ) -> None: +<<<<<<< HEAD assert dataclasses.is_dataclass(cls), ( f"Only dataclasses can be registered with this function: {cls}" ) +======= + assert dataclasses.is_dataclass( + cls + ), f"Only dataclasses can be registered with this function: {cls}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def default_flatten_fn(obj: Any) -> tuple[list[Any], Context]: flattened = [] @@ -680,6 +725,7 @@ def _insert_aten_to_metadata_assert_pass(gm: torch.fx.GraphModule) -> None: continue if (tensor_val := node.args[0].meta.get("val")) is not None: +<<<<<<< HEAD with ( gm.graph.inserting_before(node), _set_node_metadata_hook( @@ -691,6 +737,13 @@ def _insert_aten_to_metadata_assert_pass(gm: torch.fx.GraphModule) -> None: "nn_module_stack": node.meta.get("nn_module_stack"), }, ), +======= + with gm.graph.inserting_before(node), _set_node_metadata_hook( + gm, + functools.partial( + _node_metadata_hook, + stack_trace=node.meta.get("stack_trace"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), ): gm.graph.call_function( @@ -717,10 +770,14 @@ def apply_runtime_assertion_pass(gm: torch.fx.GraphModule, graph_signature): "in insert_deferred_runtime_asserts" ) with _set_node_metadata_hook( +<<<<<<< HEAD gm, functools.partial( _node_metadata_hook, metadata={"stack_trace": stack_trace} ), +======= + gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): shape_env = _get_shape_env_from_gm(gm) if shape_env: @@ -909,6 +966,7 @@ def _bind_signature_to_inputs(mod, fake_args, fake_kwargs): return {**sig.bind_partial(*fake_args).arguments, **fake_kwargs} +<<<<<<< HEAD def _build_cache(name, find_available, used_names): used_names.add(name) match = re.match(r"(.*)_(\d+)", name) @@ -918,6 +976,8 @@ def _build_cache(name, find_available, used_names): find_available[prefix] = int(n) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None: """ Propagate placeholder names from the top-level graph into HigherOrderOp subgraphs, @@ -925,7 +985,10 @@ def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None: Different HOO subgraph types have different input schemas, so we first enumerate them and gather the top-level named placeholder nodes. """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # gather all HOO subgraphs and their top-level named placeholder nodes subgraph_ph_tuples: list[tuple[torch.fx.GraphModule, list[torch.fx.Node]]] = [] for node in gm.graph.nodes: @@ -949,17 +1012,25 @@ def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None: # propagate names for subgraph, hoo_phs in subgraph_ph_tuples: name_map: dict[str, str] = {} +<<<<<<< HEAD find_available: dict[str, int] = defaultdict(int) used_names: set[str] = set() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for i, node in enumerate(subgraph.graph.nodes): if i < len(hoo_phs): # placeholder, retain name name_map[node.name] = hoo_phs[i].name node.name = node.target = hoo_phs[i].name +<<<<<<< HEAD _build_cache(node.name, find_available, used_names) else: # non-placeholder, check for collisions node.name = _rename_without_collisions( name_map, find_available, used_names, node.name, node.name ) +======= + else: # non-placeholder, check for collisions + node.name = _rename_without_collisions(name_map, node.name, node.name) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # recurse and recompile _name_hoo_subgraph_placeholders(subgraph) @@ -1019,8 +1090,11 @@ def _extract_pytree_key(x): raise RuntimeError(f"Pytree key of type {type(x)} not handled for {x}") name_map: dict[str, str] = {} +<<<<<<< HEAD find_available: dict[str, int] = defaultdict(int) used_names: set[str] = set() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # map user input names with mod.forward() signature combined_args = _bind_signature_to_inputs(mod, fake_args, fake_kwargs) @@ -1037,8 +1111,11 @@ def _extract_pytree_key(x): if user_input_name: _rename_without_collisions( name_map, +<<<<<<< HEAD find_available, used_names, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) user_input_name, placeholder_prefixes[InputKind.USER_INPUT] + "_".join(_extract_pytree_key(x).lower() for x in arg_path), @@ -1058,8 +1135,11 @@ def _extract_pytree_key(x): _rename_without_collisions( name_map, +<<<<<<< HEAD find_available, used_names, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) spec.arg.name, placeholder_prefixes[spec.kind] + base_name, is_placeholder=True, @@ -1078,9 +1158,13 @@ def _extract_pytree_key(x): for node in gm.graph.nodes: if node.op == "placeholder": continue +<<<<<<< HEAD _rename_without_collisions( name_map, find_available, used_names, node.name, node.name ) +======= + _rename_without_collisions(name_map, node.name, node.name) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # assign new node names for node in gm.graph.nodes: @@ -1159,7 +1243,11 @@ def remove_proxy_from_state_dict(state_dict: dict, in_place: bool) -> dict: def _detect_fake_mode_from_gm( gm: torch.fx.GraphModule, +<<<<<<< HEAD ) -> Optional[torch._subclasses.fake_tensor.FakeTensorMode]: +======= +) -> torch._subclasses.fake_tensor.FakeTensorMode: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ For a given graph module, we look at the "val" of placeholder nodes to find the fake inputs. Additionally, if gm doesn't have placeholders, we further look at the "example_value" or "val" of other nodes. @@ -1334,7 +1422,11 @@ def _collect_all_valid_cia_ops() -> set["OperatorBase"]: def _get_decomp_for_cia(op: "OperatorBase"): +<<<<<<< HEAD # [NOTE] Separating out func.decompose +======= + # [NOTE] Seperating out func.decompose +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Ideally we should be able to just register func.decompose but # we can't as this decomp is gonna be registered to the py_impl. # As a result it will infinitely recurse. So we first check if the op @@ -1410,7 +1502,10 @@ def register_module_as_pytree_input_node(cls: type[torch.nn.Module]) -> None: import torch +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class Module(torch.nn.Module): def __init__(self): super().__init__() @@ -1419,15 +1514,23 @@ def __init__(self): def forward(self, x): return self.linear(x) +<<<<<<< HEAD torch._export.utils.register_module_as_pytree_node(InputDataClass) +======= + torch._export.utils.register_module_as_pytree_node(InputDataClass) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class Mod(torch.nn.Module): def forward(self, x, m): return m(x) + x +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ep = torch.export.export(Mod(), (torch.randn(3), Module())) print(ep) diff --git a/torch/_export/verifier.py b/torch/_export/verifier.py index 28593291b22cc..f58e25897b5d9 100644 --- a/torch/_export/verifier.py +++ b/torch/_export/verifier.py @@ -215,8 +215,11 @@ def _allowed_op_types() -> tuple[type[Any], ...]: torch.sym_min, torch.sym_not, torch.sym_sqrt, +<<<<<<< HEAD torch.sym_sum, torch.export.custom_ops._call_custom_autograd_function_in_pre_dispatch, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO (tmanlaibaatar) # Predispatch export is able to contain autograd ops. # These will be modeled as HOO later @@ -224,11 +227,14 @@ def _allowed_op_types() -> tuple[type[Any], ...]: torch.amp.autocast_mode._enter_autocast, torch.amp.autocast_mode._exit_autocast, torch.fx.experimental.symbolic_shapes.cast_symbool_to_symint_guardless, +<<<<<<< HEAD torch._functorch.predispatch._add_batch_dim, torch._functorch.predispatch._remove_batch_dim, torch._functorch.predispatch._vmap_increment_nesting, torch._functorch.predispatch._vmap_decrement_nesting, torch._functorch.predispatch.lazy_load_decompositions, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if not isinstance(op, _allowed_op_types()): @@ -283,6 +289,7 @@ def _is_type(name, ty): if type(attr).__name__ == "LoweredBackendModule": if ( _is_type("backend_id", str) +<<<<<<< HEAD and hasattr(attr, "original_module") and hasattr(attr, "module_name") and getattr(attr, "backend_id", None) == "aoti" @@ -290,6 +297,8 @@ def _is_type(name, ty): continue if ( _is_type("backend_id", str) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) and _is_type("processed_bytes", bytes) and _is_type("compile_specs", list) and hasattr(attr, "original_module") @@ -476,12 +485,16 @@ def _verify_exported_program_signature(exported_program) -> None: ) num_tokens = len(gs.output_tokens) +<<<<<<< HEAD end = ( len(gs.buffers_to_mutate) + len(gs.parameters_to_mutate) + len(gs.user_inputs_to_mutate) + num_tokens ) +======= + end = len(gs.buffers_to_mutate) + len(gs.user_inputs_to_mutate) + num_tokens +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mutate_nodes: list[str] = output_nodes[num_tokens:end] user_output_nodes = output_nodes[end : end + len(gs.user_outputs)] @@ -493,6 +506,7 @@ def _verify_exported_program_signature(exported_program) -> None: f"Dict of buffers that are mutated, in order: {gs.buffers_to_mutate} \n" f"Buffer nodes available: {gs.buffers} \n" ) +<<<<<<< HEAD elif mutation_node in gs.parameters_to_mutate: if gs.parameters_to_mutate[mutation_node] not in gs.parameters: raise SpecViolationError( @@ -500,6 +514,8 @@ def _verify_exported_program_signature(exported_program) -> None: f"Dict of parameters that are mutated, in order: {gs.parameters_to_mutate} \n" f"Parameter nodes available: {gs.parameters} \n" ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif mutation_node in gs.user_inputs_to_mutate: if gs.user_inputs_to_mutate[mutation_node] not in gs.user_inputs: raise SpecViolationError( diff --git a/torch/_export/wrappers.py b/torch/_export/wrappers.py index e023169403937..b8349a8ca58bf 100644 --- a/torch/_export/wrappers.py +++ b/torch/_export/wrappers.py @@ -1,12 +1,19 @@ # mypy: allow-untyped-defs +<<<<<<< HEAD import inspect from contextlib import contextmanager from functools import wraps +======= +from contextlib import contextmanager +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch._custom_ops from torch._C import DispatchKey +<<<<<<< HEAD from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._higher_order_ops.flat_apply import ( _ConstantFunction, flat_apply, @@ -17,6 +24,10 @@ from torch._ops import HigherOrderOperator from torch._subclasses.fake_tensor import FakeTensorMode from torch.fx.experimental.proxy_tensor import ( +<<<<<<< HEAD +======= + get_proxy_slot, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) PreDispatchTorchFunctionMode, ProxyTorchDispatchMode, track_tensor_tree, @@ -130,7 +141,11 @@ def call(self, *args): return cls +<<<<<<< HEAD def _register_func_spec_proxy_in_tracer(tracer, name, spec): +======= +def _register_subclass_spec_proxy_in_tracer(tracer, name, spec): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ This is a wrapper utility method on top of tracer to cache the already registered subclass spec attribute. This is useful because @@ -147,6 +162,7 @@ def _register_func_spec_proxy_in_tracer(tracer, name, spec): return tracer.create_proxy("get_attr", qualname, (), {}) +<<<<<<< HEAD def _emit_flat_apply_call( *, tracer, @@ -182,6 +198,8 @@ def _is_init(fn): return callable(fn) and fn.__name__ == "__init__" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def mark_subclass_constructor_exportable_experimental(constructor_subclass): """ Experimental decorator that makes subclass to be traceable in export @@ -203,6 +221,13 @@ def __new__(cls, elem, *, requires_grad=False): def __init__(self, elem, ...): # ... """ +<<<<<<< HEAD +======= + + def _is_init(fn): + return callable(fn) and fn.__name__ == "__init__" + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not _is_init(constructor_subclass): raise RuntimeError( f"torch._export.wrappers.mark_constructor_exportable_experimental can only be applied on subclass tensor.__init__" @@ -211,15 +236,19 @@ def __init__(self, elem, ...): ) def wrapper(*args, **kwargs): +<<<<<<< HEAD constructor_subclass(*args, **kwargs) if not torch.compiler.is_exporting(): return +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not is_traceable_wrapper_subclass_type(type(args[0])): assert constructor_subclass.__qualname__.endswith("__init__") obj_name = constructor_subclass.__qualname__[: -len("__init__")] raise RuntimeError( +<<<<<<< HEAD f"Can't intercept {obj_name} in export because this object is not a traceable " f"tensor subclass. Please look at DTensor.__init__ implementation as an example of proper usage of this API." ) @@ -335,3 +364,72 @@ def wrapper(*args, **kwargs): return out return wrapper +======= + f"Applying mark_constructor_exportable_experimental on {obj_name} is not valid as it is not a traceable " + f"tensor subclass. Please look at DTensor.__init__ implementation as an example of proper usage of this API." + ) + constructor_subclass(*args, **kwargs) + if not torch._C._is_torch_function_mode_enabled(): + return + torch_function_mode_stack = torch.overrides._get_current_function_mode_stack() + + pre_dispatch_tf_modes = [ + mode + for mode in torch_function_mode_stack + if isinstance(mode, PreDispatchTorchFunctionMode) + ] + assert ( + len(pre_dispatch_tf_modes) <= 1 + ), f"Expected only one PreDispatchTorchFunctionMode, found {len(pre_dispatch_tf_modes)}" + + if len(pre_dispatch_tf_modes) == 0: + return + + mode = pre_dispatch_tf_modes[0] + + tracer = mode.tracer + subclass = args[0] + + flat_args, in_spec = to_graphable((tuple(args[1:]), kwargs)) + + constructor_spec_name = "_".join( + constructor_subclass.__qualname__.lower().split(".") + ) + qualname = tracer.get_fresh_qualname(constructor_spec_name) # type: ignore[union-attr] + setattr(tracer.root, qualname, in_spec) # type: ignore[union-attr] + spec_proxy = tracer.create_proxy("get_attr", qualname, (), {}) + flat_proxy_args = pytree.tree_map_only( + torch.Tensor, lambda x: get_proxy_slot(x, tracer).proxy, flat_args + ) + + _, func_spec = torch.utils._pytree.tree_flatten( + _ConstantFunction(type(subclass)) + ) + + # We actually don't want to create a new spec for each instance + # In fx graph, it will look like dtensor_const_func_spec + # We can't directly shove DTensor.__init__ into fx as it is not + # allowed type. + fxable_constructor_call_spec_name = ( + type(subclass).__name__.lower() + "_const_func_spec" + ) + + # We should try to reuse the constructor call spec as it is guaranteed to be same + # for each subclass type. This is different from proxy-ing the init arguments which + # can't be reused because for example, DTensor can receive different DeviceMesh etc + # as it's arguments + func_spec_proxy = _register_subclass_spec_proxy_in_tracer( + tracer, fxable_constructor_call_spec_name, func_spec + ) + + inner_proxy = tracer.create_proxy( + "call_function", + flat_apply, + (func_spec_proxy, spec_proxy, *flat_proxy_args), + {}, + ) + track_tensor_tree(subclass, inner_proxy, constant=None, tracer=tracer) + return + + return wrapper +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_functorch/_activation_checkpointing/graph_info_provider.py b/torch/_functorch/_activation_checkpointing/graph_info_provider.py index 2a5da58fdd633..dd9439c102bf1 100644 --- a/torch/_functorch/_activation_checkpointing/graph_info_provider.py +++ b/torch/_functorch/_activation_checkpointing/graph_info_provider.py @@ -96,9 +96,15 @@ def inialize_from_graph( @property def recomputable_node_only_graph(self) -> nx.DiGraph: if self._lazily_initialized_graphs[self.__RECOMPUTABLE_NODE_ONLY_GRAPH] is None: +<<<<<<< HEAD self._lazily_initialized_graphs[self.__RECOMPUTABLE_NODE_ONLY_GRAPH] = ( self._create_recomputable_node_only_graph() ) +======= + self._lazily_initialized_graphs[ + self.__RECOMPUTABLE_NODE_ONLY_GRAPH + ] = self._create_recomputable_node_only_graph() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self._lazily_initialized_graphs[self.__RECOMPUTABLE_NODE_ONLY_GRAPH] @property @@ -119,17 +125,29 @@ def recomputable_node_only_graph_with_larger_graph_context(self) -> nx.DiGraph: @property def full_joint_nx_graph(self) -> nx.DiGraph: if self._lazily_initialized_graphs[self.__FULL_NX_JOINT_GRAPH] is None: +<<<<<<< HEAD self._lazily_initialized_graphs[self.__FULL_NX_JOINT_GRAPH] = ( self._create_full_joint_graph() ) +======= + self._lazily_initialized_graphs[ + self.__FULL_NX_JOINT_GRAPH + ] = self._create_full_joint_graph() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self._lazily_initialized_graphs[self.__FULL_NX_JOINT_GRAPH] @property def simplified_fx_joint_graph(self) -> Graph: if self._lazily_initialized_graphs[self.__SIMPLIFIED_FX_JOINT_GRAPH] is None: +<<<<<<< HEAD self._lazily_initialized_graphs[self.__SIMPLIFIED_FX_JOINT_GRAPH] = ( self._recreate_psuedo_joint_graph() ) +======= + self._lazily_initialized_graphs[ + self.__SIMPLIFIED_FX_JOINT_GRAPH + ] = self._recreate_psuedo_joint_graph() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self._lazily_initialized_graphs[self.__SIMPLIFIED_FX_JOINT_GRAPH] def get_non_ac_peak_memory(self) -> float: @@ -285,7 +303,13 @@ def _visualize_recomputable_candidate_graph_with_larger_context( float( self.recomputable_node_only_graph_with_larger_graph_context.nodes[ node +<<<<<<< HEAD ]["memory"] +======= + ][ + "memory" + ] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) ) ) diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index ec1e70a9a00f4..a7d680b512da4 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -2,7 +2,10 @@ """ Utils for caching the outputs of AOTAutograd """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from __future__ import annotations import base64 @@ -44,7 +47,10 @@ sha256_hash, write_atomic, ) +<<<<<<< HEAD from torch._inductor.cudagraph_utils import BoxedDeviceIndex +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.output_code import ( CompiledFxGraph, CompiledFxGraphConstants, @@ -78,6 +84,10 @@ if TYPE_CHECKING: from torch._inductor.compile_fx import _CompileFxKwargs +<<<<<<< HEAD +======= + from torch._inductor.cudagraph_utils import BoxedDeviceIndex +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.remote_cache import JsonDataTy, RemoteCache from torch._inductor.utils import BoxedBool from torch.fx.node import Node @@ -95,7 +105,11 @@ class FXGraphCacheMiss(BypassAOTAutogradCache): def should_use_remote_autograd_cache(): +<<<<<<< HEAD if torch.compiler.config.force_disable_caches: +======= + if torch._inductor.config.force_disable_caches: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return False if config.enable_remote_autograd_cache is not None: return config.enable_remote_autograd_cache @@ -116,15 +130,22 @@ def should_use_remote_autograd_cache(): def should_use_local_autograd_cache(): +<<<<<<< HEAD if torch.compiler.config.force_disable_caches: +======= + if torch._inductor.config.force_disable_caches: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return False return config.enable_autograd_cache +<<<<<<< HEAD def should_bundle_autograd_cache(): return config.bundled_autograd_cache or torch._dynamo.config.caching_precompile +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def check_node_safe(node: Node): """ Checks that the node only uses supported operators. We are starting with very @@ -279,16 +300,37 @@ def check_cacheable(gm: torch.fx.GraphModule): # Subgraphs are only used for caching logic. if hasattr(gm, "saved_tensors_hooks_pack_0"): check_cacheable(gm.saved_tensors_hooks_pack_0) # type: ignore[arg-type] +<<<<<<< HEAD # We have guarantee of unpack sugraph existence if pack subgraph exists check_cacheable(gm.saved_tensors_hooks_unpack_0) # type: ignore[arg-type] +======= + # We have guarantee of unpack sugraph existance if pack subgraph exists + check_cacheable(gm.saved_tensors_hooks_unpack_0) # type: ignore[arg-type] + + +def check_metadata_cacheable(metadata: ViewAndMutationMeta): + """ + When view replay is turned on, we bypass autograd cache if + the output is aliased. + """ + if config.view_replay_for_aliased_outputs: + for info in metadata.output_info: + if info.functional_tensor is not None: + raise BypassAOTAutogradCache( + "Cannot cache a graph with functional tensor" + ) + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class AOTAutogradCacheDetails(FxGraphHashDetails): """ Object to capture all the details for a dynamo graph module relevant to computing a safe and stable cache key for AOTAutograd. """ +<<<<<<< HEAD def get_triton_source_codes_from_gm( self, gm: torch.fx.GraphModule, @@ -325,6 +367,8 @@ def get_triton_source_codes_from_gm( return triton_kernel_source_codes +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __init__( self, gm: torch.fx.GraphModule, @@ -342,7 +386,10 @@ def __init__( [], [], ) +<<<<<<< HEAD self.triton_kernel_source_codes = self.get_triton_source_codes_from_gm(gm) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if hasattr(gm, "saved_tensors_hooks_pack_0"): @@ -408,6 +455,7 @@ def _reduce_tensor(self, tensor): return (_ident, (metadata,)) +<<<<<<< HEAD @contextlib.contextmanager def normalize_placeholder_names(gm: torch.fx.GraphModule): """ @@ -459,6 +507,8 @@ def normalize_placeholder_names(gm: torch.fx.GraphModule): gm.recompile() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def autograd_cache_key( gm: torch.fx.GraphModule, example_inputs, @@ -482,6 +532,10 @@ def autograd_cache_key( if triton.__version__ < "3.2.0": raise BypassAOTAutogradCache("AOTAutogradCache requires triton 3.2.0") +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) details = AOTAutogradCacheDetails(gm, example_inputs, config, fx_config) pickler = AOTAutogradCachePickler(gm) # The prefix distinguishes among the other kinds of objects we cache @@ -504,6 +558,7 @@ class InductorOutput(Generic[TOut], ABC): """ @abstractmethod +<<<<<<< HEAD def pre_save(self) -> None: ... @abstractmethod @@ -511,6 +566,18 @@ def load(self, example_inputs) -> TOut: ... @abstractmethod def post_compile(self, result: TOut, fx_config: _CompileFxKwargs) -> TOut: ... +======= + def pre_save(self) -> None: + ... + + @abstractmethod + def load(self, example_inputs) -> TOut: + ... + + @abstractmethod + def post_compile(self, result: TOut, fx_config: _CompileFxKwargs) -> TOut: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclass @@ -549,6 +616,10 @@ def post_compile( }, payload_fn=lambda: json.dumps(cache_info), ) +<<<<<<< HEAD +======= + counters["inductor"]["fxgraph_cache_hit"] += 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Run normal post compile graph.post_compile(self.example_inputs, constants, fx_config) return graph @@ -673,9 +744,13 @@ def post_compile( # See note [Wrapping bw_compiler in disable] # This is done by _wrapped_bw_compiler in torch/_dynamo/backends/common.py # But since on cache hit we do not call the bw_compiler, we need to reapply the disable +<<<<<<< HEAD return torch._dynamo.disable( # type: ignore[return-value] compiled_bw, reason="do not trace generated backwards pass" ) +======= + return torch._dynamo.disable(compiled_bw, reason="do not trace generated backwards pass") # type: ignore[return-value] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Forward types don't have any extra parameters, so this is just a TypeAlias, in essence @@ -694,9 +769,13 @@ def post_compile( # See note [Wrapping bw_compiler in disable] # This is done by _wrapped_bw_compiler in torch/_dynamo/backends/common.py # But since on cache hit we do not call the bw_compiler, we need to reapply the disable +<<<<<<< HEAD return torch._dynamo.disable( # type: ignore[return-value] compiled_bw, reason="do not trace generated backwards pass" ) +======= + return torch._dynamo.disable(compiled_bw, reason="do not trace generated backwards pass") # type: ignore[return-value] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclass @@ -781,6 +860,10 @@ def pre_save(self): """ Perform any preparations to make the cache entry ready for serialization. """ +<<<<<<< HEAD +======= + check_metadata_cacheable(self.runtime_metadata) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.compiled_fw.pre_save() if self.compiled_bw is not None: self.compiled_bw.pre_save() @@ -997,6 +1080,7 @@ def sanitize_gm_for_cache(gm: torch.fx.GraphModule): and then put them back before returning. This way, we generate a cache key based off of a canonical graph without these fields, and also guarantee they aren't used to affect the cache's output. """ +<<<<<<< HEAD # Mapping from each field to a default value IGNORED_FIELDS: dict[str, Any] = { "meta": {}, # metadata used by export @@ -1013,6 +1097,23 @@ def sanitize_gm_for_cache(gm: torch.fx.GraphModule): with normalize_placeholder_names(gm): yield finally: +======= + IGNORED_FIELDS = ( + "meta", # metadata used by export + "compile_subgraph_reason", # Used by dynamo only for logging, no change in inductor/autograd behavior + "_param_name_to_source", # Encapsulated by aot_config.aot_autograd_arg_pos_to_source + "_backend_id", + ) + saved_fields = {} + for field in IGNORED_FIELDS: + saved_fields[field] = getattr(gm, field, None) + # Clear the field + setattr(gm, field, None) + try: + yield + finally: + # Put the fields back after dispatch_and_compile is complete +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for field, value in saved_fields.items(): setattr(gm, field, value) @@ -1046,6 +1147,7 @@ def after_deserialization(self) -> Callable: # which is set by compile_fx. But in precompile, we never actually call compile_fx # so we don't have a place to track cudagraphs here. cudagraphs = torch._inductor.config.triton.cudagraphs +<<<<<<< HEAD boxed_forward_device_index = BoxedDeviceIndex(None) compiled_fn = entry.wrap_post_compile( [], @@ -1054,6 +1156,10 @@ def after_deserialization(self) -> Callable: "cudagraphs": cudagraphs, "boxed_forward_device_index": boxed_forward_device_index, }, +======= + compiled_fn = entry.wrap_post_compile( + [], entry.sanitized_aot_config, {"cudagraphs": cudagraphs} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # TODO: this ignores flat_params, which can exist @@ -1109,7 +1215,12 @@ def clear(): pass @staticmethod +<<<<<<< HEAD def try_load( +======= + def load( + dispatch_and_compile: Callable, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mod: Union[torch.fx.GraphModule, torch._dynamo.utils.GmWrapper], args, aot_config: AOTConfig, @@ -1117,7 +1228,11 @@ def try_load( boxed_forward_device_index: Optional[BoxedDeviceIndex], local: bool, remote: bool, +<<<<<<< HEAD ) -> Optional[Callable]: +======= + ) -> Callable: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Load a result from the cache, and reconstruct a runtime wrapper around the object """ @@ -1137,10 +1252,17 @@ def try_load( cache_key, debug_lines = autograd_cache_key( gm, args, aot_config, fx_config ) +<<<<<<< HEAD entry: Optional[GenericAOTAutogradCacheEntry] = ( AOTAutogradCache._lookup( cache_key, local, remote, args, cache_info, aot_config ) +======= + entry: Optional[ + GenericAOTAutogradCacheEntry + ] = AOTAutogradCache._lookup( + cache_key, local, remote, args, cache_info, aot_config +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if entry is not None: compiled_fn = entry.wrap_post_compile(args, aot_config, fx_config) @@ -1165,8 +1287,14 @@ def try_load( # FXGraphCache and AOTAutogradCache? # get_metrics_context().increment(...) if ( +<<<<<<< HEAD ephemeral_increase := add_ephemeral_timeout_increase_for_distributed(time_saved_ns) +======= + ephemeral_increase := add_ephemeral_timeout_increase_for_distributed( + time_saved_ns + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) != 0: cache_info["ephemeral_timeout_increase"] = ephemeral_increase @@ -1179,10 +1307,14 @@ def try_load( except FXGraphCacheMiss as e: counters["aot_autograd"]["autograd_cache_miss"] += 1 cache_state = "miss" +<<<<<<< HEAD if ( config.strict_autograd_cache or torch._dynamo.config.caching_precompile ): +======= + if config.strict_autograd_cache: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise e # Most often this is BypassAOTAutogradCache, but # if there's ever different reason we can't cache, @@ -1212,10 +1344,14 @@ def try_load( ) if remote: log_cache_bypass("bypass_aot_autograd", str(e)) +<<<<<<< HEAD if ( config.strict_autograd_cache or torch._dynamo.config.caching_precompile ): +======= + if config.strict_autograd_cache: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise e if compiled_fn is None: # Set the cache key so we can save a cache result later @@ -1226,6 +1362,10 @@ def try_load( time.time_ns(), forward_symints=symints, ) +<<<<<<< HEAD +======= + compiled_fn = dispatch_and_compile() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cache_info.update( { @@ -1259,7 +1399,10 @@ def try_load( }, payload_fn=lambda: json.dumps(cache_info), ) +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return compiled_fn @classmethod @@ -1329,7 +1472,11 @@ def _lookup( AOTAutogradCacheArtifact.type(), key, pickled_content ) if ( +<<<<<<< HEAD should_bundle_autograd_cache() +======= + config.bundled_autograd_cache +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) and aot_config is not None and aot_config.precompile_backend_id is not None ): @@ -1371,7 +1518,11 @@ def save(key: str, entry: GenericAOTAutogradCacheEntry, remote: bool): AOTAutogradCacheArtifact.type(), key, content ) if ( +<<<<<<< HEAD should_bundle_autograd_cache() +======= + config.bundled_autograd_cache +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) and entry.sanitized_aot_config.precompile_backend_id is not None ): precompile_key = entry.sanitized_aot_config.precompile_backend_id @@ -1379,10 +1530,14 @@ def save(key: str, entry: GenericAOTAutogradCacheEntry, remote: bool): # useful, remove it from the entry. entry.sanitized_aot_config.precompile_backend_id = None PrecompileContext.record_artifact( +<<<<<<< HEAD BundledAOTAutogradCacheArtifact.type(), precompile_key, entry, editable=True, +======= + BundledAOTAutogradCacheArtifact.type(), precompile_key, content +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) AOTAutogradCache._write_to_local_cache(key, content) counters["aot_autograd"]["autograd_cache_saved"] += 1 @@ -1403,9 +1558,15 @@ def save(key: str, entry: GenericAOTAutogradCacheEntry, remote: bool): return None if remote: +<<<<<<< HEAD remote_cache: Optional[RemoteCache[JsonDataTy]] = ( AOTAutogradCache.get_remote_cache() ) +======= + remote_cache: Optional[ + RemoteCache[JsonDataTy] + ] = AOTAutogradCache.get_remote_cache() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if remote_cache is not None: time_taken_ms = int( (entry.forward_time_taken_ns + entry.backward_time_taken_ns) // 1e6 @@ -1450,7 +1611,11 @@ def make_entry( num_symints_saved_for_bw: Optional[int], serialized_bw_module: Optional[SerializedGraphModule], ) -> GenericAOTAutogradCacheEntry: +<<<<<<< HEAD if should_bundle_autograd_cache(): +======= + if config.bundled_autograd_cache: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Helper function to unwrap all the wrappers we added during aotdispatch # They get reapplied on cache load def unwrap_compiled_fx_graph(obj): diff --git a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py index acfd40fe78c7f..6096d4180155f 100644 --- a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py +++ b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py @@ -11,6 +11,10 @@ import collections import contextlib import logging +<<<<<<< HEAD +======= +from functools import wraps +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from typing import Callable, Optional import torch @@ -27,6 +31,7 @@ transform_subclass, ) +<<<<<<< HEAD from .descriptors import ( AOTInput, AOTOutput, @@ -35,6 +40,8 @@ PlainAOTOutput, TangentAOTInput, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .functional_utils import ( are_all_mutations_hidden_from_autograd, are_all_mutations_under_no_grad_or_inference_mode, @@ -43,10 +50,17 @@ has_metadata_mutation, MetadataKey, to_fun, +<<<<<<< HEAD ViewMetaSequence, was_inductor_storage_resized, ) from .schemas import ( +======= + was_inductor_storage_resized, +) +from .schemas import ( + FunctionalTensorMetadataEq, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) InputAliasInfo, MemoryFormatMeta, MutationType, @@ -55,7 +69,11 @@ ViewAndMutationMeta, ) from .subclass_utils import create_subclass_meta +<<<<<<< HEAD from .utils import _get_autocast_states, KNOWN_TYPES, simple_wraps, strict_zip +======= +from .utils import _get_autocast_states, KNOWN_TYPES, strict_zip +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) zip = strict_zip @@ -68,7 +86,11 @@ # We assume tangents memory format to be similar to corresponding output's memory_format. # The idea is that we are technically making a guess about the strides of our tangents, # while we trace out the joint. +<<<<<<< HEAD # If runtime specified tangents will not have the same memory format as predicted traced tangents, +======= +# If runtime specfied tangents will not have the same memory format as predicted traced tangents, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # we coerce them at runtime to traced tangents memory format. @@ -90,7 +112,11 @@ def coerce_tangent_and_suggest_memory_format(x: Tensor): out = out.contiguous(memory_format=memory_format.memory_format) updated = was is not out +<<<<<<< HEAD # For subclass we keep memory format of outer strides at the beginning of the list +======= + # For subclass we keep memory format of outer strides at the beggining of the list +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out_memory_format = [memory_format] if is_subclass else memory_format # Note [Tangents memory format, Part 2] @@ -154,7 +180,10 @@ def coerce_tangent_and_suggest_memory_format(x: Tensor): def run_functionalized_fw_and_collect_metadata( f, *, +<<<<<<< HEAD flat_args_descs: list[AOTInput], +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) keep_input_mutations: bool, # TODO: refactor to kill this flag is_train: bool = False, @@ -177,7 +206,11 @@ def _to_fun(t): else: return t +<<<<<<< HEAD @simple_wraps(f) +======= + @wraps(f) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def inner(*flat_args): # This function is meant to be run with the forward, which expects a flat list of tensor/symint/other args. assert all(isinstance(a, tuple(KNOWN_TYPES)) for a in flat_args) @@ -203,6 +236,7 @@ def inner(*flat_args): with disable_above, mode, suppress_pending: # precondition: The passed in function already handles unflattening inputs + flattening outputs flat_f_args = pytree.tree_map(_to_fun, flat_args) +<<<<<<< HEAD flat_f_args_descs = flat_args_descs flat_f_outs = f(*flat_f_args) @@ -221,6 +255,9 @@ def inner(*flat_args): # actually do the trace flat_f_outs_descs = [PlainAOTOutput(i) for i in range(len(flat_f_outs))] +======= + flat_f_outs = f(*flat_f_args) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # We didn't do any tracing, so we don't need to process the # unbacked symbols, they will just disappear into the ether. # Also, prevent memoization from applying. @@ -411,7 +448,10 @@ def inner(*flat_args): # maps the id of an intermediate base to its index in the output of the compiled forward intermediate_base_tensor_id_to_output_idx: dict[int, int] = {} intermediate_bases: list[torch.Tensor] = [] +<<<<<<< HEAD intermediate_bases_descs: list[AOTInput] = [] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Why Do We Care If Storage Changed? # It's important to understand the implications of storage changes in complex scenarios. Take this example: # @@ -436,7 +476,11 @@ def inner(*flat_args): # the autograd engine mistakenly assumes that 'x' and 'out' are aliased, treating 'x' as 'out._base'. # This misinterpretation leads to an 'alias_of_input' flag, causing an unnecessary as_strided() call to be generated, # which could lead to issues later in the code. +<<<<<<< HEAD for o, desc in zip(flat_f_outs, flat_f_outs_descs): +======= + for o in flat_f_outs: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) functional_tensor_storage_changed = isinstance( o, FunctionalTensor ) and torch._functionalize_was_storage_changed( # type: ignore[attr-defined] @@ -597,6 +641,7 @@ def inner(*flat_args): output_type = ( OutputType.alias_of_intermediate_save_as_output ) +<<<<<<< HEAD intermediate_base_tensor_id_to_output_idx[id(o._base)] = ( new_out_idx ) @@ -609,6 +654,12 @@ def inner(*flat_args): intermediate_bases_descs.append( TangentAOTInput(IntermediateBaseAOTOutput(desc)) ) +======= + intermediate_base_tensor_id_to_output_idx[ + id(o._base) + ] = new_out_idx + intermediate_bases.append(o._base) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif ( # See https://github.com/pytorch/pytorch/issues/100348 for this case. # This protects against the specific case where a user fn returns (output, output.detach()) @@ -617,7 +668,11 @@ def inner(*flat_args): and not o.requires_grad ): # In theory we could use any of these tensors to regenerate the aliased outputs from, +<<<<<<< HEAD # since they all alias each other and have identical metadata +======= + # since they all alias each other and have identical metatadata +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out_alias = outs_with_identical_metadata_that_require_grad[0] existing_out_idx = out_tensor_ids[id(out_alias)] output_type = OutputType.alias_of_intermediate_base_is_user_output @@ -640,7 +695,11 @@ def inner(*flat_args): # # The FunctionalTensor will be saved if one of the 2 conditions below # is true: +<<<<<<< HEAD view_meta_sequence = None +======= + functional_tensor = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if ( # 1. If the output_type is either of: # (i) alias_of_intermediate; @@ -672,7 +731,11 @@ def inner(*flat_args): and not input_info[base_idx].mutates_metadata ): if isinstance(o, FunctionalTensor): +<<<<<<< HEAD view_meta_sequence = ViewMetaSequence(o) +======= + functional_tensor = FunctionalTensorMetadataEq(o.elem) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out_info = OutputAliasInfo( output_type=output_type, @@ -680,7 +743,11 @@ def inner(*flat_args): base_idx=base_idx, dynamic_dims=dynamic_dims, requires_grad=isinstance(o, torch.Tensor) and o.requires_grad, +<<<<<<< HEAD view_meta_sequence=view_meta_sequence, +======= + functional_tensor=functional_tensor, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) output_info.append(out_info) @@ -709,7 +776,11 @@ def _is_subclass_mutated_input_tangent_always_subclass(inp): or torch._functorch.config.disable_guess_zero_tangent_for_mutated_input_subclass ) +<<<<<<< HEAD f_input_tangents_pairs = [ +======= + f_input_tangents = [ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Note: [AOTAutograd Tangent Subclassness for mutated inputs] # Generally when creating tangents to trace with, we assume that tangents will have # the same subclass-ness as their forward outs @@ -736,6 +807,7 @@ def _is_subclass_mutated_input_tangent_always_subclass(inp): # (a * b).sum().backward() # # We can not deduce it easily now, so introducing a debug config to be able to turn off this for specific cases. +<<<<<<< HEAD # NJT guarantees to have its tangent as NJT, because it has dedicated integration in Autograd # See torch/csrc/autograd/python_function.cpp, use_zeros_like. ( @@ -748,10 +820,22 @@ def _is_subclass_mutated_input_tangent_always_subclass(inp): TangentAOTInput(InputMutationAOTOutput(inp_desc)), ) for inp, inp_desc, info in zip(flat_f_args, flat_f_args_descs, input_info) +======= + # NJT gurantees to have its tangent as NJT, because it has dedicated integration in Autograd + # See torch/csrc/autograd/python_function.cpp, use_zeros_like. + ( + _plain_fake_tensor_like_subclass(inp) + if is_traceable_wrapper_subclass(inp) + and not _is_subclass_mutated_input_tangent_always_subclass(inp) + else inp + ) + for inp, info in zip(flat_f_args, input_info) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if info.mutation_type == MutationType.MUTATED_OUT_GRAPH and info.mutates_data and info.requires_grad ] +<<<<<<< HEAD f_input_tangents, f_input_tangents_descs = ( [x[0] for x in f_input_tangents_pairs], [x[1] for x in f_input_tangents_pairs], @@ -760,6 +844,11 @@ def _is_subclass_mutated_input_tangent_always_subclass(inp): f_output_tangents_pairs = [ (o, TangentAOTInput(desc)) for o, info, desc in zip(flat_f_outs, output_info, flat_f_outs_descs) +======= + f_output_tangents = [ + o + for o, info in zip(flat_f_outs, output_info) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if info.output_type in [ OutputType.non_alias, @@ -769,6 +858,7 @@ def _is_subclass_mutated_input_tangent_always_subclass(inp): and issubclass(info.raw_type, torch.Tensor) and info.requires_grad ] +<<<<<<< HEAD f_output_tangents, f_output_tangents_descs = ( [x[0] for x in f_output_tangents_pairs], [x[1] for x in f_output_tangents_pairs], @@ -781,18 +871,29 @@ def _is_subclass_mutated_input_tangent_always_subclass(inp): ) # TODO: I'm pretty sure you don't need a tree_map here +======= + # intermediate bases are also included in the backward graph + f_tangents = f_input_tangents + f_output_tangents + intermediate_bases +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) traced_tangents = pytree.tree_map(from_fun, f_tangents) traced_tangents = pytree.tree_map( view_avoid_dupes_with_primals, traced_tangents ) +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) traced_tangents = [ coerce_tangent_and_suggest_memory_format(tt)[0] for i, tt in enumerate(traced_tangents) ] +<<<<<<< HEAD # NB: update this if the maps above ever change structure. # Also, it might be helpful to add coercion information to the tangent desc! traced_tangents_descs = f_tangents_descs +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) nonlocal static_input_indices static_input_indices = static_input_indices or [] if torch._dynamo.compiled_autograd.in_compiled_autograd_region: @@ -853,7 +954,10 @@ def _is_subclass_mutated_input_tangent_always_subclass(inp): num_intermediate_bases=len(intermediate_bases), keep_input_mutations=keep_input_mutations, traced_tangents=traced_tangents, +<<<<<<< HEAD traced_tangents_descs=traced_tangents_descs, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) subclass_inp_meta=create_subclass_meta(flat_args), subclass_fw_graph_out_meta=create_subclass_meta(fw_graph_outs), subclass_tangent_meta=create_subclass_meta( diff --git a/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py b/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py new file mode 100644 index 0000000000000..3382cb102dcad --- /dev/null +++ b/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py @@ -0,0 +1,338 @@ +# mypy: allow-untyped-defs +""" +This module dispatches the graphs to either the forward-only or joint compilation +pathways, taking into account the AOTConfig and the collected ViewAndMutationMetadata. +""" + +import dataclasses +from typing import Any, Optional + +import torch +import torch.utils._pytree as pytree +import torch.utils.dlpack +from torch import Tensor +from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo.utils import detect_fake_mode, lazy_format_graph_code +from torch._logging import getArtifactLogger, trace_structured +from torch._subclasses.functional_tensor import FunctionalTensorMode +from torch.fx.experimental.proxy_tensor import make_fx +from torchgen.utils import dataclass_repr + +from .. import config +from .functional_utils import ( + assert_functional_graph, + propagate_input_mutation_stacktraces, +) +from .schemas import AOTConfig, SubclassMeta, ViewAndMutationMeta +from .traced_function_transforms import ( + aot_dispatch_subclass, + create_functionalized_fn, + create_joint, + fn_input_mutations_to_outputs, + fn_prepped_for_autograd, + handle_effect_tokens_fn, +) +from .utils import ( + copy_fwd_metadata_to_bw_nodes, + register_buffer_assignment_hook, + root_module_when_exporting_non_strict, + unlift_tokens, +) + + +aot_graphs_log = getArtifactLogger(__name__, "aot_graphs") + + +def _create_graph(f, args, *, aot_config: AOTConfig) -> torch.fx.GraphModule: + # FunctionalTensorMode must be enabled here. + # See Note [Accessing .grad_fn on FunctionalTensor] + with enable_python_dispatcher(), FunctionalTensorMode( + pre_dispatch=aot_config.pre_dispatch, + export=aot_config.is_export, + # Allow token discovery for joint fn tracing as tokens can be used in backward. + _allow_token_discovery=True, + ): + fx_g = make_fx( + f, + decomposition_table=aot_config.decompositions, + record_module_stack=True, + pre_dispatch=aot_config.pre_dispatch, + )(*args) + + return fx_g + + +# TODO: Refactor the following code so detach() persists item_memo +def _detach_and_copy_item_memo(t): + detached_t = t.detach() + if hasattr(t, "item_memo"): + detached_t.item_memo = t.item_memo + return detached_t + + +def aot_dispatch_base_graph( + flat_fn, + flat_args: list[Tensor], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, +) -> tuple[torch.fx.GraphModule, list[Any], Optional[SubclassMeta]]: + # aot_dispatch_base requires functionalization, but doesn't need to handle as many cases as the autograd case. + # The cases that aot_dispatch_base doesn't need to handle include: + # - outputs that are aliases of graph intermediates + # - outputs that are aliases of graph inputs + # While cases that it does need to handle include: + # - input mutations (including when inputs are aliases of each other) + # - input metadata mutations + fn_to_trace = fn_input_mutations_to_outputs( + flat_fn, + fw_metadata, + keep_data_input_mutations=aot_config.keep_inference_input_mutations, + ) + + fn_to_trace, updated_flat_args = create_functionalized_fn( + fn_to_trace, + flat_args, + meta=fw_metadata, + aot_config=aot_config, + trace_joint=False, + ) + + # TODO: replace with AOTDispatchSubclassWrapper once we refactor + # fn_input_mutations_to_outputs and create_functionalized_fn + # into CompilerWrappers. + ( + fn_to_trace, + updated_flat_args_subclasses_desugared, + maybe_subclass_meta, + ) = aot_dispatch_subclass( + fn_to_trace, + updated_flat_args, + is_joint_structure=False, + meta=fw_metadata, + fw_only=flat_fn, + ) + + (fn_to_trace, updated_flat_args_subclasses_desugared) = handle_effect_tokens_fn( + fn_to_trace, + updated_flat_args_subclasses_desugared, + meta=fw_metadata, + trace_joint=False, + ) + + aot_graphs_log.debug( + "aot_config id: %s, fw_metadata=%s,subclass_metadata=%s", + str(aot_config.aot_id), + str(fw_metadata), + str(maybe_subclass_meta), + ) + + # We track buffer assignments when exporting in non-strict mode. + # (In contrast, strict mode errors on any attribute assignment.) + mod_when_exporting_non_strict = root_module_when_exporting_non_strict(flat_fn) + if aot_config.is_export and mod_when_exporting_non_strict is not None: + # For any buffer that is assigned, we want to associate it to the final proxy node + # that it is assigned to. This node can then be added as a buffer mutation output. + assigned_buffers: dict[str, str] = {} + hook = register_buffer_assignment_hook( + mod_when_exporting_non_strict, assigned_buffers + ) + + fake_mode = detect_fake_mode() + if fake_mode: + saved_updated_flat_args_subclasses_desugared = pytree.tree_map_only( + torch.Tensor, + _detach_and_copy_item_memo, + updated_flat_args_subclasses_desugared, + ) + else: + saved_updated_flat_args_subclasses_desugared = pytree.tree_map_only( + torch.Tensor, lambda t: t.detach(), updated_flat_args_subclasses_desugared + ) + + fw_module = _create_graph( + fn_to_trace, + updated_flat_args_subclasses_desugared, + aot_config=aot_config, + ) + + if aot_config.is_export and mod_when_exporting_non_strict is not None: + # We update metadata to consider any assigned buffers as buffer mutations. + i = len(dict(mod_when_exporting_non_strict.named_parameters())) + for name, _ in mod_when_exporting_non_strict.named_buffers(): + if name in assigned_buffers and not fw_metadata.input_info[i].mutates_data: # type: ignore[possibly-undefined] + fw_metadata.input_info[i] = dataclasses.replace( + fw_metadata.input_info[i], mutates_data=True + ) + fw_metadata.num_mutated_inp_runtime_indices += 1 + i += 1 + + # We add nodes corresponding to buffer assignments as output nodes in the graph. + add_nodes = [] + output_node = list(fw_module.graph.nodes)[-1] + for name in assigned_buffers.values(): # type: ignore[possibly-undefined] + for node in fw_module.graph.nodes: + if node.name == name: + add_nodes.append(node) + node.users[output_node] = None + output_node.args = ((*add_nodes, *output_node.args[0]),) + + hook.remove() # type: ignore[possibly-undefined] + + # As long as we opted to remove input mutations, then + # there should be *NO* mutating ops in the graph at this point. + copy_count = assert_functional_graph(fw_module.graph) + fw_module.graph.eliminate_dead_code() + fw_module.recompile() + + copy_count2 = assert_functional_graph(fw_module.graph) + propagate_input_mutation_stacktraces(fw_module.graph) + + # See Note [Side-Effectful Tokens in AOTAutograd] + num_tokens = len(fw_metadata.tokens) + if num_tokens != 0 and config.unlift_effect_tokens: + unlift_tokens(fw_module, fw_metadata, aot_config) + saved_updated_flat_args_subclasses_desugared = ( + saved_updated_flat_args_subclasses_desugared[num_tokens:] + ) + + assert copy_count == copy_count2 + + if aot_config.enable_log: + aot_graphs_log.info( + "%s", + lazy_format_graph_code( + "Forward graph", + fw_module, + aot_config.aot_id, + include_stride=True, + include_device=True, + colored=True, + ), + ) + + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "aot_forward_graph_fw_metadata", + "encoding": "string", + }, + payload_fn=lambda: dataclass_repr(fw_metadata), + ) + if maybe_subclass_meta is not None: + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "aot_forward_graph_fw_subclass_metadata", + "encoding": "string", + }, + payload_fn=lambda: dataclass_repr(maybe_subclass_meta), + ) + + trace_structured( + "aot_inference_graph", + payload_fn=lambda: fw_module.print_readable( + print_output=False, include_stride=True, include_device=True + ), + ) + + # TODO: should factor this into a separate function for export that always only returns just the graph. + if aot_config.is_export: + assert ( + maybe_subclass_meta is None + ), "aot_export_module does not support tensor subclass inputs for now." + return fw_module, saved_updated_flat_args_subclasses_desugared, maybe_subclass_meta + + +# Has the precondition that there +# are no duplicate arguments in flat_args (e.g., the same Tensor +# object never shows up twice. However, two tensor inputs MAY alias +# the same storage, so long as they have separate TensorImpls.) +def aot_dispatch_autograd_graph( + flat_fn, + flat_args: list[Any], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, +) -> tuple[torch.fx.GraphModule, tuple[list[Any], list[Any]], Optional[SubclassMeta]]: + # traced_tangents corresponds to the set of outputs in the traced forward that should get grad_outputs in the traced backward. + # It includes outputs of the original forward, *and* any updated inputs due to input mutations. + # However, it does *not* include any outputs that are aliases of inputs or intermediates, or any metadata-only input mutations. + joint_inputs = (flat_args, fw_metadata.traced_tangents) + + fn_prepared_for_autograd = fn_prepped_for_autograd( + flat_fn, + fw_metadata, + ) + joint_fn_to_trace = create_joint(fn_prepared_for_autograd, aot_config=aot_config) + + joint_fn_to_trace, updated_joint_inputs = create_functionalized_fn( + joint_fn_to_trace, + joint_inputs, + meta=fw_metadata, + aot_config=aot_config, + trace_joint=True, + ) + + # TODO: replace with AOTDispatchSubclassWrapper once we refactor + # fn_input_mutations_to_outputs and create_functionalized_fn + # into CompilerWrappers. + subclass_tracing_info = aot_dispatch_subclass( + joint_fn_to_trace, + updated_joint_inputs, + is_joint_structure=True, + meta=fw_metadata, + fw_only=flat_fn, + ) + + joint_fn_to_trace = subclass_tracing_info.plain_tensor_trace_fn + updated_joint_inputs = subclass_tracing_info.plain_tensor_args + + (joint_fn_to_trace, updated_joint_inputs) = handle_effect_tokens_fn( + joint_fn_to_trace, + updated_joint_inputs, + meta=fw_metadata, + trace_joint=True, + ) + + # When we call _create_graph, this may mutate the metadata of joint + # inputs. But callers are expecting to get the original joint inputs. So + # we make aliases of all the inputs to make sure we have a copy that + # doesn't get modified. + # + # This destroys requires_grad/grad_fn information. However, backends + # beneath AOTAutograd are indifferent to this information, so it doesn't + # matter. + + fake_mode = detect_fake_mode() + if fake_mode: + saved_updated_joint_inputs = pytree.tree_map_only( + torch.Tensor, _detach_and_copy_item_memo, updated_joint_inputs + ) + else: + saved_updated_joint_inputs = pytree.tree_map_only( + torch.Tensor, lambda t: t.detach(), updated_joint_inputs + ) + maybe_subclass_meta = subclass_tracing_info.maybe_subclass_meta + + fx_g = _create_graph(joint_fn_to_trace, updated_joint_inputs, aot_config=aot_config) + + # There should be *NO* mutating ops in the graph at this point. + assert_functional_graph(fx_g.graph) + + # Redundant with the check above, but worth having in case tracing introduced + # a fake tensor. Unlikely. + # See Note: [Fake Modules and AOTAutograd] + torch._dynamo.utils.assert_no_fake_params_or_buffers(fx_g) + fx_g.graph.eliminate_dead_code() + copy_fwd_metadata_to_bw_nodes(fx_g) + fx_g.recompile() + + # TODO: in AOTAutograd, we create metadata like _indices_of_inps_to_detach to detect + # when we need to manually detach() some inputs in the forward. + # Higher order ops might eventually need to do the same. + if aot_config.is_export: + assert ( + maybe_subclass_meta is None + ), "aot_export_module does not support tensor subclass inputs for now." + return fx_g, saved_updated_joint_inputs, maybe_subclass_meta diff --git a/torch/_functorch/_aot_autograd/functional_utils.py b/torch/_functorch/_aot_autograd/functional_utils.py index 958804e5c763f..c1acba693f414 100644 --- a/torch/_functorch/_aot_autograd/functional_utils.py +++ b/torch/_functorch/_aot_autograd/functional_utils.py @@ -6,7 +6,10 @@ 3. regenerating/replaying views from their base 4. checking if a graph is functional i.e. whether it contains any mutation ops """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from __future__ import annotations from dataclasses import dataclass @@ -14,7 +17,10 @@ import torch from torch import Tensor +<<<<<<< HEAD from torch._C import _functionalization +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._logging import getArtifactLogger from torch._subclasses.fake_tensor import FakeTensor from torch._subclasses.functional_tensor import FunctionalTensor @@ -225,9 +231,15 @@ def gen_alias_from_base( aliased_base_tensor, target_meta_tensor, target_requires_grad, +<<<<<<< HEAD target_view_meta_sequence: Optional[ViewMetaSequence] = None, *, replay_views: bool, +======= + target_functional_tensor: Optional[FunctionalTensorMetadataEq] = None, + *, + replay_views, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): # Patch the correct requires_grad field of the output tensor, depending on whether: # (i) the reconstructed output (out) was came from a tensor that requires grad or not; @@ -246,11 +258,21 @@ def patch_requires_grad(out): # to replay them (view functions) on the aliased_base_tensor. if ( replay_views +<<<<<<< HEAD and target_view_meta_sequence is not None and not any(vm.has_symbolic_inputs for vm in target_view_meta_sequence.sequence) ): out = _functionalization.apply_view_meta_sequence( aliased_base_tensor, target_view_meta_sequence.sequence +======= + and target_functional_tensor is not None + and not torch._functionalize_is_symbolic(target_functional_tensor.tensor) + ): + functional_tensor = target_functional_tensor.tensor + + out = torch._functionalize_apply_view_metas( + functional_tensor, aliased_base_tensor +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # If re-applying the ViewMeta sequence succeeded, there should be no more # problems going forward. We just check we got to the target shape and @@ -356,6 +378,7 @@ def make(t): ) +<<<<<<< HEAD # ViewMeta sequence wrapper for equality comparisons. # # Even though we can compare each ViewMeta instance, we compare the resulting @@ -395,6 +418,27 @@ def __eq__(self, other: object) -> bool: return NotImplemented return self.metadata == other.metadata +======= +# Wrapper around a FunctionalTensorWrapper for comparing only the resulting metadata +# after applying all the ViewMeta operations. +class FunctionalTensorMetadataEq: + def __init__(self, tensor: torch.Tensor) -> None: + assert torch._is_functional_tensor(tensor) + self.tensor = tensor + + def __eq__(self, other: object) -> bool: + # If other is None, then it probably means that we weren't able to recreate + # the FunctionalTensorMetadataEq. One of this cases is when we update the + # view metadata by calling: create_synthetic_base_metadata. + if other is None: + return True + + # Comparison agains any other type is not implemented. + if not isinstance(other, FunctionalTensorMetadataEq): + return NotImplemented + + return has_same_metadata(self.tensor, other.tensor) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # new_arg and arg here are either: @@ -472,6 +516,7 @@ def assert_functional_graph(fx_g: torch.fx.Graph) -> int: # this is mostly a hack to avoid failing XLA tests. # See https://github.com/pytorch/pytorch/pull/122434#issuecomment-2101012113 if "set_buffer_donor_" not in str(n.args[0]): +<<<<<<< HEAD assert n.args[0] in placeholders, ( f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}" ) @@ -480,6 +525,16 @@ def assert_functional_graph(fx_g: torch.fx.Graph) -> int: assert not n.target._schema.is_mutable, ( f"aot_autograd expected to have an entirely functional graph, but found {n.format_node()}" ) +======= + assert ( + n.args[0] in placeholders + ), f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}" + mutation_count += 1 + else: + assert ( + not n.target._schema.is_mutable + ), f"aot_autograd expected to have an entirely functional graph, but found {n.format_node()}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return mutation_count @@ -492,9 +547,15 @@ def propagate_input_mutation_stacktraces(fx_g: torch.fx.Graph) -> None: if n.target is torch.ops.aten.copy_.default: # Can only copy_ into an input, and can only do so once if "set_buffer_donor_" not in str(n.args[0]): +<<<<<<< HEAD assert n.args[0] in placeholders, ( f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}" ) +======= + assert ( + n.args[0] in placeholders + ), f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) placeholders.remove(n.args[0]) copy_from_node = n.args[1] # Pre-condition: every node has a "stack_trace" field in its meta, diff --git a/torch/_functorch/_aot_autograd/input_output_analysis.py b/torch/_functorch/_aot_autograd/input_output_analysis.py index 06581e1524fde..8e75f5afd7aa1 100644 --- a/torch/_functorch/_aot_autograd/input_output_analysis.py +++ b/torch/_functorch/_aot_autograd/input_output_analysis.py @@ -24,7 +24,10 @@ from torch.fx.experimental.symbolic_shapes import is_concrete_int from .collect_metadata_analysis import coerce_tangent_and_suggest_memory_format +<<<<<<< HEAD from .descriptors import AOTInput, InputMutationAOTOutput, TangentAOTInput +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .schemas import ( BackwardSignature, GraphSignature, @@ -53,14 +56,18 @@ def remove_dupe_metadata( num_data_mutations = len([x for x in m.input_info if x.mutates_data]) other_traced_tangents = m.traced_tangents[num_data_mutations:] inp_traced_tangents = m.traced_tangents[:num_data_mutations] +<<<<<<< HEAD other_traced_tangents_descs = m.traced_tangents_descs[num_data_mutations:] inp_traced_tangents_descs = m.traced_tangents_descs[:num_data_mutations] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) filtered_inp_traced_tangents = [ # See Note [Tangents memory format] x for i, x in enumerate(inp_traced_tangents) if keep_arg_mask[m.mutated_inp_runtime_indices[i]] ] +<<<<<<< HEAD filtered_inp_traced_tangents_descs = [ x_desc for i, x_desc in enumerate(inp_traced_tangents_descs) @@ -70,6 +77,9 @@ def remove_dupe_metadata( traced_tangents_descs = ( filtered_inp_traced_tangents_descs + other_traced_tangents_descs ) +======= + traced_tangents = filtered_inp_traced_tangents + other_traced_tangents +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert m.subclass_tangent_meta is not None subclass_tangent_meta = [ @@ -89,14 +99,21 @@ def remove_dupe_metadata( dynamic_dims=o.dynamic_dims, base_idx=None if o.base_idx is None else add_dupe_map[o.base_idx], requires_grad=o.requires_grad, +<<<<<<< HEAD view_meta_sequence=o.view_meta_sequence, +======= + functional_tensor=o.functional_tensor, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) for o in m.output_info ], num_intermediate_bases=m.num_intermediate_bases, keep_input_mutations=m.keep_input_mutations, traced_tangents=traced_tangents, +<<<<<<< HEAD traced_tangents_descs=traced_tangents_descs, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # We are guaranteed not to get here, since dupes are not supported today with subclass inputs. subclass_inp_meta=[], subclass_fw_graph_out_meta=[], @@ -122,7 +139,10 @@ def create_synthetic_base_metadata( synthetic_base_info: list[Union[int, tuple[int, torch.Tensor]]], outer_args: list[Any], inner_args: list[Any], +<<<<<<< HEAD inner_args_desc: list[AOTInput], +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> tuple[ViewAndMutationMeta, list[int]]: # maps inner arg indices to outer arg indices synthetic_base_to_indices: dict[int, list[int]] = {} @@ -242,12 +262,17 @@ def create_synthetic_base_metadata( # Map the input idx pre-synthetic-bases to the new idx post-synthetic-bases base_idx=new_base_idx, # type: ignore[arg-type] requires_grad=o.requires_grad, +<<<<<<< HEAD view_meta_sequence=o.view_meta_sequence, +======= + functional_tensor=o.functional_tensor, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) ) inner_mutated_tangents_and_memory_formats = [ # See Note [Tangents memory format] +<<<<<<< HEAD ( coerce_tangent_and_suggest_memory_format(x), TangentAOTInput(InputMutationAOTOutput(x_desc)), @@ -263,6 +288,15 @@ def create_synthetic_base_metadata( ] inner_mutated_tangents_memory_formats = [ x[0][1] for x in inner_mutated_tangents_and_memory_formats +======= + coerce_tangent_and_suggest_memory_format(x) + for inner_idx, x in enumerate(inner_args) + if input_infos[inner_idx].mutates_data and input_infos[inner_idx].requires_grad + ] + inner_mutated_tangents = [x[0] for x in inner_mutated_tangents_and_memory_formats] + inner_mutated_tangents_memory_formats = [ + x[1] for x in inner_mutated_tangents_and_memory_formats +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] output_info = existing_output_infos + input_metadata_output_info @@ -270,10 +304,13 @@ def create_synthetic_base_metadata( traced_tangents = ( inner_mutated_tangents + m.traced_tangents[len(inner_mutated_tangents) :] ) +<<<<<<< HEAD traced_tangents_descs = ( inner_mutated_tangents_descs + m.traced_tangents_descs[len(inner_mutated_tangents) :] ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert m.subclass_tangent_meta is not None subclass_tangent_meta = [ PlainTensorMeta(0, memory_format=x) @@ -287,7 +324,10 @@ def create_synthetic_base_metadata( num_intermediate_bases=m.num_intermediate_bases, keep_input_mutations=m.keep_input_mutations, traced_tangents=traced_tangents, +<<<<<<< HEAD traced_tangents_descs=traced_tangents_descs, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # We are guaranteed not to get here, since synthetic_base codepaths are not supported today with subclass inputs. subclass_inp_meta=[], subclass_fw_graph_out_meta=[], @@ -306,12 +346,19 @@ def compute_overlapping_inputs(aot_config, fwd_inputs, aliased_input_indices): tracing_context = torch._guards.TracingContext.try_get() if tracing_context is not None: +<<<<<<< HEAD assert tracing_context.fake_mode is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shape_env = tracing_context.fake_mode.shape_env # Check whether we can actually get the dynamo sources from within AOTAutograd. if aot_config.aot_autograd_arg_pos_to_source and shape_env is not None: +<<<<<<< HEAD maybe_suppress_guards = shape_env.suppress_guards # type: ignore[assignment] +======= + maybe_suppress_guards = shape_env.suppress_guards +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Check whether there are any symbolic values being used. # We do this for 2 reasons: @@ -460,7 +507,10 @@ def create_graph_signature( named_buffers=buffer_names, num_user_inputs=num_user_args, num_user_outputs=num_user_fw_outs, +<<<<<<< HEAD trace_joint=trace_joint, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) loss_index=loss_index, backward_signature=backward_signature, ) diff --git a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py new file mode 100644 index 0000000000000..7e0b21b0b0446 --- /dev/null +++ b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py @@ -0,0 +1,1845 @@ +# mypy: allow-untyped-defs +""" +Functions in this module do most of the "work" of AOTAutograd. +An aot_dispatch_* function: +- Takes in the input flat_fn, flat_args, and some metadata +- Runs a set of pre compile wrappers (e.g. argument deduping) +- Runs the actual compiler +- Wraps the returned callable in a set of post compile wrappers +- Returns the wrapped callable and metadata. +""" + +import copy +import dataclasses +import itertools +import logging +import operator +import time +import traceback +from collections import defaultdict +from contextlib import nullcontext +from typing import Any, Callable, Optional, TYPE_CHECKING + +import torch +import torch.utils._pytree as pytree +import torch.utils.dlpack +from torch import Tensor +from torch._dynamo.utils import detect_fake_mode, dynamo_timed, lazy_format_graph_code +from torch._guards import CompileContext, TracingContext +from torch._logging import getArtifactLogger, trace_structured +from torch._subclasses import FakeTensor +from torch._subclasses.meta_utils import is_sparse_any +from torch.fx.experimental._backward_state import BackwardState +from torch.fx.experimental.proxy_tensor import is_sym_node +from torch.fx.experimental.symbolic_shapes import fx_placeholder_vals +from torch.fx.graph_module import GraphModule +from torch.fx.passes._tensorify_python_scalars import tensorify_python_scalars +from torch.multiprocessing.reductions import StorageWeakRef +from torch.types import py_sym_types +from torch.utils._python_dispatch import is_traceable_wrapper_subclass +from torchgen.utils import dataclass_repr + +from .. import config +from .autograd_cache import ( + AOTAutogradCache, + serialize_graph_module, + should_use_remote_autograd_cache, +) +from .dispatch_and_compile_graph import ( + aot_dispatch_autograd_graph, + aot_dispatch_base_graph, +) +from .logging_utils import track_graph_compiling +from .runtime_wrappers import ( + AOTDedupeWrapper, + AOTDispatchAutograd, + AOTDispatchSubclassWrapper, + AOTSyntheticBaseWrapper, + AutogradLazyBackwardCompileInfo, + CompilerWrapper, + DebugAssertWrapper, + EffectTokensWrapper, + FakifiedOutWrapper, + FunctionalizedRngRuntimeWrapper, + make_runtime_safe, + post_compile, + pre_compile, + RuntimeWrapper, +) +from .schemas import AOTConfig, MutationType, ViewAndMutationMeta +from .subclass_utils import compute_inner_mutated_inp_indices_from_subclass_meta +from .utils import ( + _get_symint_hints, + contain_metadata_mutation_ops, + get_cuda_generator_meta_val, + make_boxed_func, + strict_zip, + unlift_tokens, +) + + +if TYPE_CHECKING: + from collections.abc import Sequence + +zip = strict_zip + +log = logging.getLogger(__name__) +aot_joint_log = getArtifactLogger(__name__, "aot_joint_graph") +aot_graphs_log = getArtifactLogger(__name__, "aot_graphs") + +aten = torch.ops.aten + +# Returns a Callable and a ViewAndMutationMeta. +# Currently, only export needs the ViewAndMutationMeta after this function. +DispatchReturn = tuple[Callable, ViewAndMutationMeta] + + +def _create_wrappers_for_dispatch(needs_autograd: bool) -> list[CompilerWrapper]: + """ + Wrappers that run on every dispatch function + """ + return [AOTDedupeWrapper(), AOTSyntheticBaseWrapper(trace_joint=needs_autograd)] + + +# Export's dispatching logic is unique in a few ways: it only needs the "graph" +# bits of aot_autograd, and doesn't need to do any specific wrapping. +def aot_dispatch_export( + flat_fn: Callable, + flat_args: list[Any], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, + needs_autograd: bool, +) -> DispatchReturn: + wrappers = _create_wrappers_for_dispatch(needs_autograd) + flat_fn, flat_args, fw_metadata = pre_compile( + wrappers, + flat_fn, + flat_args, + aot_config, + fw_metadata=fw_metadata, + ) + if needs_autograd and not aot_config.pre_dispatch: + graph, _, _ = aot_dispatch_autograd_graph( + flat_fn, flat_args, aot_config, fw_metadata=fw_metadata + ) + else: + graph, _, _ = aot_dispatch_base_graph( + flat_fn, flat_args, aot_config, fw_metadata=fw_metadata + ) + + # NB: the wrappers that run in pre_compile for export are + # either a no-op, because they're not needed, or will raise a runtime error, + # since they don't support export. + # We still run these wrappers to make sure that they're not needed pre compile, + # but we technically don't need to run them post compile at all here. + compiled_fn, fw_metadata = post_compile( + wrappers, graph, aot_config, runtime_metadata=fw_metadata + ) + + # Therefore, since no wrapperes run, we don't get back a callable - we get back the raw fx graph + # (either a joint or an inference-only graph) + assert isinstance(compiled_fn, torch.fx.GraphModule) + return compiled_fn, fw_metadata + + +def sanitize_aot_config(input: AOTConfig) -> AOTConfig: + return AOTConfig( + fw_compiler=None, # type: ignore[arg-type] + bw_compiler=None, # type: ignore[arg-type] + partition_fn=None, # type: ignore[arg-type] + decompositions={}, + inference_compiler=None, + num_params_buffers=input.num_params_buffers, + aot_id=input.aot_id, + keep_inference_input_mutations=input.keep_inference_input_mutations, + is_export=input.is_export, + no_tangents=input.no_tangents, + aot_autograd_arg_pos_to_source=input.aot_autograd_arg_pos_to_source, + dynamic_shapes=input.dynamic_shapes, + enable_log=input.enable_log, + static_input_indices=input.static_input_indices, + pre_dispatch=input.pre_dispatch, + cache_info=None, + precompile_backend_id=input.precompile_backend_id, + ) + + +def aot_dispatch_base( + flat_fn, + flat_args: list[Any], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, +) -> DispatchReturn: + """ + Handles functions that don't need autograd. Runs wrappers and compiles with fw_compiler. + """ + wrappers = _create_wrappers_for_dispatch(needs_autograd=False) + flat_fn, flat_args, fw_metadata = pre_compile( + wrappers, flat_fn, flat_args, aot_config, fw_metadata=fw_metadata + ) + fw_module, updated_flat_args, maybe_subclass_meta = aot_dispatch_base_graph( # type: ignore[misc] + flat_fn, flat_args, aot_config, fw_metadata=fw_metadata + ) + # Save the forward_graph_str right after aot_dispatch_base_graph, + # to save in the cache + aot_forward_graph_str = None + if aot_config.cache_info is not None: + aot_forward_graph_str = fw_module.print_readable( + print_output=False, + include_stride=True, + include_device=True, + fast_sympy_print=True, + ) + + fakified_out_wrapper = FakifiedOutWrapper() + ( + fw_module, + updated_flat_args, + fw_metadata, + ) = fakified_out_wrapper.pre_compile( + fw_module, updated_flat_args, aot_config, fw_metadata=fw_metadata + ) + functionalized_rng_wrapper = FunctionalizedRngRuntimeWrapper() + ( + fw_module, + updated_flat_args, + fw_metadata, + ) = functionalized_rng_wrapper.pre_compile( + fw_module, updated_flat_args, aot_config, fw_metadata=fw_metadata + ) + assert isinstance(fw_module, GraphModule) + + if aot_config.enable_log: + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "torch._functorch.config", + "encoding": "string", + }, + payload_fn=lambda: torch._functorch.config.get_config_copy(), + ) + + disable_amp = torch._C._is_any_autocast_enabled() + context = torch._C._DisableAutocast if disable_amp else nullcontext + + with context(), track_graph_compiling(aot_config, "inference"): + compiler = ( + aot_config.inference_compiler + if aot_config.inference_compiler is not None + else aot_config.fw_compiler + ) + + if tracing_context := torch._guards.TracingContext.try_get(): + tracing_context.fw_metadata = ( + fw_metadata + if maybe_subclass_meta is None + else maybe_subclass_meta.fw_metadata + ) + + with TracingContext.report_output_strides() as fwd_output_strides: + fake_mode = detect_fake_mode() + if fake_mode is not None and fake_mode.shape_env is not None: + tensorify_python_scalars(fw_module, fake_mode.shape_env, fake_mode) + compiled_fw = compiler(fw_module, updated_flat_args) + + if fakified_out_wrapper.needs_post_compile: + fakified_out_wrapper.set_fwd_output_strides(fwd_output_strides) + + make_runtime_safe(fw_metadata, maybe_subclass_meta) + + # However, RuntimeWrapper does not expect the rng offsets in the + # output. So, we have to create another wrapper and take out the offset. As + # a result, we have to account for not boxed_call compilers as well. + if not getattr(compiled_fw, "_boxed_call", False): + compiled_fw = make_boxed_func(compiled_fw) + + # Create a wrapper to set up the rng functionalize and fakified out bits + compiled_fw = functionalized_rng_wrapper.post_compile( + compiled_fw, aot_config, runtime_metadata=fw_metadata + ) + cache_info = aot_config.cache_info + if cache_info is not None: + if hasattr(compiled_fw, "_fx_graph_cache_key"): + time_taken_ns = time.time_ns() - cache_info.start_time_ns + guards_expr = AOTAutogradCache.generate_guards_expression(cache_info) + entry = AOTAutogradCache.make_entry( + compiled_fw_func=compiled_fw, # type: ignore[arg-type] + compiled_bw_func=None, + aot_joint_graph_str=None, + aot_forward_graph_str=aot_forward_graph_str, + aot_backward_graph_str=None, + runtime_metadata=fw_metadata, + dispatch_wrappers=wrappers, + maybe_subclass_meta=maybe_subclass_meta, + num_fw_outs_saved_for_bw=None, + indices_of_inps_to_detach=[], + forward_time_taken_ns=time_taken_ns, + backward_time_taken_ns=0, + sanitized_aot_config=sanitize_aot_config(aot_config), + guards_expr=guards_expr, + backward_state_indices=None, + num_symints_saved_for_bw=None, + serialized_bw_module=None, + ) + AOTAutogradCache.save( + cache_info.cache_key, entry, remote=should_use_remote_autograd_cache() + ) + + compiled_fw = fakified_out_wrapper.post_compile( + compiled_fw, + aot_config, + runtime_metadata=fw_metadata, + ) + + compiled_fw = EffectTokensWrapper().post_compile( + compiled_fw, + aot_config, + runtime_metadata=fw_metadata, + ) + + # Why do we need to pass in num_fw_outs_saved_for_bw? + # See Note: [Partitioner handling for Subclasses, Part 2] + compiled_fw = AOTDispatchSubclassWrapper( + trace_joint=False, + # TODO: once we use pre_compile this will be flat_fn at the top of this function + fw_only=None, + maybe_subclass_meta=maybe_subclass_meta, + num_fw_outs_saved_for_bw=None, + ).post_compile( + compiled_fw, + aot_config, # not used + runtime_metadata=fw_metadata, + ) + + if not getattr(compiled_fw, "_boxed_call", False): + compiled_fw = make_boxed_func(compiled_fw) + + compiled_fn = RuntimeWrapper( + indices_of_inps_to_detach=[], + trace_joint=False, + disable_amp=disable_amp, + ).post_compile( + compiled_fw, + aot_config, + runtime_metadata=fw_metadata, + ) + + compiled_fn = post_compile( + wrappers, compiled_fn, aot_config, runtime_metadata=fw_metadata + ) + return compiled_fn + + +def collect_fw_donated_buffer_idxs( + fw_ins: list[Optional[FakeTensor]], + user_fw_outs: list[Optional[FakeTensor]], + bw_outs: list[Optional[FakeTensor]], + saved_tensors: list[FakeTensor], +) -> list[int]: + """ + Checks if the saved tensors are donated buffers, which means a saved tensor is not + an alias of any tensors in fw_ins, user_fw_outs, and bw_outs. + """ + + storage_refs = set() + for t in itertools.chain(fw_ins, user_fw_outs, bw_outs): + # Only access storage if a tensor has storage (not sparse) + if t is not None and isinstance(t, FakeTensor) and not is_sparse_any(t): + storage_refs.add(StorageWeakRef(t.untyped_storage())) + + num_saved_tensor = len(saved_tensors) + donated_buffer_idxs = [] + for i in range(num_saved_tensor): + t = saved_tensors[i] + if ( + t is not None + and not is_sparse_any(t) + and StorageWeakRef(t.untyped_storage()) not in storage_refs + ): + donated_buffer_idxs.append(i) + + return donated_buffer_idxs + + +def collect_bw_donated_buffer_idxs( + fw_module: torch.fx.GraphModule, + bw_module: torch.fx.GraphModule, + fw_metadata: ViewAndMutationMeta, +) -> list[int]: + """ + Collects backward donated buffer indexes from fw_module and bw_module. + """ + + # [Note: Metadata mutation in proxy tracing] + # node.meta["val"] is a snapshot of the tensor value when tracing a graph, + # instead of the final state after the graph has run. node.meta["val"] is + # not updated even if later there is a metadata mutation op. + # See: https://github.com/pytorch/pytorch/pull/141308#issuecomment-2495798947 + # + # Currently, metadata mutation op happens only for sacrificial parameter + # specifically the `set_` op. This motivates banning metadata mutation from + # proxy tracing. + # + # Since node.meta["val"] is used to detect donated buffer, we return an empty + # list if there exists metadata mutation op. + if contain_metadata_mutation_ops(fw_module) or contain_metadata_mutation_ops( + bw_module + ): + return [] + + fw_ins = fw_module.graph.find_nodes(op="placeholder") + bw_outs = next(reversed(bw_module.graph.find_nodes(op="output"))).args[0] + fw_outs = next(reversed(fw_module.graph.find_nodes(op="output"))).args[0] + + fw_ins = [ + n.meta["val"] if (hasattr(n, "meta") and "val" in n.meta) else None + for n in fw_ins + ] + fw_outs = [ + n.meta["val"] if (hasattr(n, "meta") and "val" in n.meta) else None + for n in fw_outs + ] + bw_outs = [ + n.meta["val"] if (hasattr(n, "meta") and "val" in n.meta) else None + for n in bw_outs + ] + + user_fw_outs = fw_outs[: fw_metadata.num_forward] + saved_tensors = fw_outs[fw_metadata.tensors_saved_for_backwards_slice] + + fw_donated_buffer = collect_fw_donated_buffer_idxs( + fw_ins, + user_fw_outs, + bw_outs, + saved_tensors, + ) + + assert fw_metadata.num_symints_saved_for_bw is not None + return [fw_metadata.num_symints_saved_for_bw + i for i in fw_donated_buffer] + + +@dataclasses.dataclass +class InvokeSubgraphHopGraphs: + """ + A data structure to hold all the information needed to partition the + `joint_hop_gm` and joint graph and the restitch the `new_fw_hop_gm` and + `new_bw_hop_gm` into the bigger `joint_gm`. + """ + + # To avoid re-partitioning subgraphs + partitioning_done: bool = False + old_num_fw_outputs: Optional[int] = None + old_num_fw_inputs: Optional[int] = None + + new_fw_hop_gm: Optional[torch.fx.GraphModule] = None + new_bw_hop_gm: Optional[torch.fx.GraphModule] = None + new_num_sym_nodes: Optional[int] = None + new_num_saved_nodes: Optional[int] = None + + +def run_joint_graph_passes_on_hops( + joint_gm: torch.fx.GraphModule, + joint_inputs: Any, + aot_config: AOTConfig, +) -> torch.fx.GraphModule: + """ + This pass runs the joint graph passes on the HOP graph. In torch.compile, we + typically have many passes which work on the joint graph and then end with a + partitioner. + + + The partitioner part is quite mechanical to handle. HOP have their own + forward and backward graph. The process can be broken into following steps + + 1) Get a `joint_hop_gm` from the `fw_hop_gm` and `bw_hop_gm` + 2) Run joint graph passes on the `joint_hop_gm` to get `new_fw_hop_gm` and `new_bw_hop_gm` + 3) Stitch the `new_fw_hop_gm` and `new_bw_hop_gm` back into the `joint_gm`. + + The terminology used in the code is + `joint_graph/joint_gm` : Refers to the main graph. This may contain many HOPs which have their own `hop_graph` + `fw_hop_graph/fw_hop_gm` : Refers to the forward graph associated with a HOP. + `bw_hop_graph/bw_hop_gm` : Refers to the backward graph associated with a HOP. + `joint_hop_graph/joint_hop_gm` : Refers to the subgraph associated with the HOP like invoke_subgraph. + `new_fw_hop_graph/new_fw_hop_gm` : Refers to the forward graph after partitioning is applied to `joint_hop_gm`. + `new_bw_hop_graph/new_bw_hop_gm` : Refers to the backward graph after partitioning is applied to `joint_hop_gm`. + + NB: This pass works for invoke_subgraph today because we took extra care in + the Autograd.Dispatch key of invoke_subgraph to vastly simplify Step 1. + """ + from torch._higher_order_ops import invoke_subgraph + + def num_outputs(mod): + return len(mod.graph.find_nodes(op="output")[0].args[0]) + + def num_inputs(mod): + return len(mod.graph.find_nodes(op="placeholder")) + + def prepare_for_partitioner(mod, num_primals, num_fw_outputs): + # min-cut partitioner requires the placeholders to have primals and + # tangents string in the node.name. The signature of the joint graph is + # (*primals, *tangents) + + # We also have to update the output signature which is right now + # (*grads, *fw_outs) and we have to change to (*fw_outs, *grads) for the + # partitioner to work. + new_graph = torch.fx.Graph() + env = {} + + primals_counter = itertools.count(0) + tangents_counter = itertools.count(0) + + for idx, node in enumerate(mod.graph.nodes): + if node.op == "placeholder": + if idx < num_primals: + env[node] = new_graph.placeholder( + f"primals_{next(primals_counter)}" + ) + else: + env[node] = new_graph.placeholder( + f"tangents_{next(tangents_counter)}" + ) + env[node].meta = copy.copy(node.meta) + elif node.op == "output": + # Reverse the (*grads, *fw_outs) to (*fw_outs, *grads) + # The reason for having the reversed signature in the first + # place is to simplify step 3. + old_outputs = node.args[0] + new_outputs = ( + *old_outputs[-num_fw_outputs:], + *old_outputs[:-num_fw_outputs], + ) + new_outputs = [env[n] if n else None for n in new_outputs] + new_graph.output(tuple(new_outputs)) + else: + env[node] = new_graph.node_copy(node, lambda n: env[n]) + env[node].meta = copy.copy(node.meta) + + new_graph.lint() + + out = torch.fx.GraphModule(mod, new_graph) + return out + + new_hop_graphs: dict[str, InvokeSubgraphHopGraphs] = defaultdict( + lambda: InvokeSubgraphHopGraphs() + ) + + # Step 1 - Get a `joint_hop_gm` from the `fw_hop_gm` and `bw_hop_gm` This is + # easy to do for `invoke_subgraph` HOP. During the Autograd dispatch key + # tracing, we have put the joint_hop_graph in the backward hop graph itself. + # So to recover the joint_hop_gm, we just have to look at the backward + # HOP graphs. + # So we will merge step 1 and step 2 in this next section + + # Save the fw and bwd hop nodes. We will later in-place modify the graph + # using these nodes. + fw_hop_nodes = [] + bw_hop_nodes = [] + for node in joint_gm.graph.nodes: + if ( + node.op == "call_function" + and node.target is invoke_subgraph + and isinstance(node.args[1], str) + ): + if node.args[1].startswith("fw"): + fw_hop_nodes.append(node) + elif node.args[1].startswith("bw"): + bw_hop_nodes.append(node) + + if not bw_hop_nodes: + return joint_gm + + assert len(fw_hop_nodes) == len(bw_hop_nodes) + + # Create a bw to hop node mapping. This helps us in identifying the bw and + # fw subgraph pairs without relying on the identifier. This is important + # because we can have different subgraphs for bwd for same subgraph in the + # fwd because of differing strides in the backward. + bw_to_fw_hop_node = dict(zip(list(reversed(bw_hop_nodes)), fw_hop_nodes)) + + for node in bw_hop_nodes: + identifier = node.args[1].removeprefix("bw") + + # If partitioning already done for this identifier, skip. This saves + # redundant joint graph passes for same subgraphs. + if new_hop_graphs[identifier].partitioning_done: + continue + + # Collect some information from the forward hop graph + fw_hop_node = bw_to_fw_hop_node[node] + fw_hop_gm = getattr(joint_gm, fw_hop_node.args[0].target) + assert isinstance(fw_hop_gm, torch.fx.GraphModule) + num_fw_inputs = num_inputs(fw_hop_gm) + num_fw_outputs = num_outputs(fw_hop_gm) + new_hop_graphs[identifier].old_num_fw_inputs = num_fw_inputs + new_hop_graphs[identifier].old_num_fw_outputs = num_fw_outputs + + # Step 1) - Get the `joint_hop_gm`. As mentioned earlier, the + # backward graph is the joint graph. + joint_hop_gm = getattr(joint_gm, node.args[0].target) + assert isinstance(joint_hop_gm, torch.fx.GraphModule) + + # Prepare the graph for the partitioner + joint_hop_gm = prepare_for_partitioner( + joint_hop_gm, num_fw_inputs, num_fw_outputs + ) + + # TODO: invoke_subgraph should track which of its inputs static indices + # so it can propagate them to the partitioner (and use in cudagraphs) + static_lifetime_input_indices: list[int] = [] + # Step 2) and 3) - Run joint graph passes and partitioner + new_fw_hop_gm, new_bw_hop_gm = aot_config.partition_fn( + joint_hop_gm, + [], + num_fwd_outputs=num_fw_outputs, + static_lifetime_input_indices=static_lifetime_input_indices, + ) + + # Save the new forward and backward graph modules + new_hop_graphs[identifier].new_fw_hop_gm = new_fw_hop_gm + new_hop_graphs[identifier].new_bw_hop_gm = new_bw_hop_gm + + # Save the number of symints and saved tensors + new_fw_out_nodes = new_fw_hop_gm.graph.find_nodes(op="output")[0].args[0] + extra_outputs = new_fw_out_nodes[num_fw_outputs:] + symint_outputs = [n for n in extra_outputs if is_sym_node(n)] + + new_hop_graphs[identifier].new_num_sym_nodes = len(symint_outputs) + new_hop_graphs[identifier].new_num_saved_nodes = len(extra_outputs) - len( + symint_outputs + ) + + new_hop_graphs[identifier].partitioning_done = True + + # Step 3) Restitch the new fw and bw graphs back into the main graph. + # + # This is a very mechanical process. There are a quite a few pieces that we + # need to connect together to make it work. Lets try to understand the + # problem statement first. + # + # For the forward graph, the signature of the old_fw_hop_gm is + # inputs - (*primals) + # outputs - (*fw_outs) + # Now the signature of the new_fw_hop_gm is + # inputs - (*primals) -- This is same + # outputs - (*fw_outs, *saved_tensors) - This is different + # At a high level, this is an easy transformation, in the new graph we just + # have to replace the old_fw_hop_gm with the new_fw_hop_gm. Everything else + # falls into place, because the input signature (i.e. args) is same. And + # even though output signature is different, fw_outs are still at the same + # indexes as before. So the forward of the `joint_gm` works nicely. + # + # Now, lets look at the backward hop graph. Old signature + # inputs - (*primals, *tangents) + # outputs - (*grad_outs, *fw_outs) + # New signature + # inputs - (*saved_tensors, *tangents) -- Different + # outputs - (*grad_outs) -- Different + # Here both input and output signature change. The output signature handling + # is quite easy because the grads_out are sitting at the right place, so we + # dont have to do anything. + # + # For the input signature, we have to collect the saved tensors from the + # corresponding forward graph output. We collect all saved_tensors when we + # see the forward graph, and save it into a map and then later use it during + # the backward. + + # The stack of fw_nodes for invoke_subgraph HOP. There is an implicit + # assumption about the graph structure, i.e., if we have hop1, hop2, hop3, + # ... in the forward part of the joint graph, we will have .., hop3, hop2, + # hop1 order for the backward. This structure allows us to just use a stack + # to collect all the information that we need to pass from the forward hop + # node to the corresponding backward node. + + already_added_new_hop_mods = set() + + def add_new_hop_gm(new_subgraph_mod, name): + new_subgraph_attr_name = f"partitioned_{name}" + if new_subgraph_attr_name in already_added_new_hop_mods: + return new_subgraph_attr_name + + joint_gm.register_module(new_subgraph_attr_name, new_subgraph_mod) + already_added_new_hop_mods.add(new_subgraph_attr_name) + return new_subgraph_attr_name + + def propagate_meta_info(new_hop_gm, new_call_function_node, old_call_function_node): + # Copy all the fields from the old call_function node. And then override + # the `val` meta field with the outputs of new_hop_gm. + new_call_function_node.meta = copy.copy(old_call_function_node.meta) + + output = new_hop_gm.graph.find_nodes(op="output")[0] + out_example_vals = [n.meta["val"] if n else None for n in output.args[0]] + new_call_function_node.meta["val"] = tuple(out_example_vals) + + for bw_node in reversed(bw_hop_nodes): + identifier = bw_node.args[1].removeprefix("bw") + + # Make changes to the corresponding fw and bw node pair simultaneously. + # The removes the need of any bookkeeping. + + # Fw node changes + # Insert the new_fw_hop_gm. This is straightforward. Get the + # new_fw_hop_gm, insert the hop_gm as a get_attr fw_node, and then + # add a call_function fw_node. Additionally, also use getitem + # call_functions to collect the saved_tensor nodes + + fw_node = bw_to_fw_hop_node[bw_node] + new_fw_hop_gm = new_hop_graphs[identifier].new_fw_hop_gm + assert new_fw_hop_gm is not None + + old_num_fw_outputs = new_hop_graphs[identifier].old_num_fw_outputs + new_num_sym_nodes = new_hop_graphs[identifier].new_num_sym_nodes + new_num_saved_nodes = new_hop_graphs[identifier].new_num_saved_nodes + assert old_num_fw_outputs is not None + assert new_num_sym_nodes is not None + assert new_num_saved_nodes is not None + total_outputs = old_num_fw_outputs + new_num_saved_nodes + new_num_sym_nodes + + extra_fw_outputs = [] + + # Insert the new_fw_hop_gm into the joint_gm + with joint_gm.graph.inserting_after(fw_node): + new_fw_mod_attr_name = add_new_hop_gm(new_fw_hop_gm, f"fw{identifier}") + new_fw_mod_attr = joint_gm.graph.get_attr(new_fw_mod_attr_name) + + # new_hop_fw_gm output signature is (*fw_outs, *saved_tensors) + with joint_gm.graph.inserting_after(new_fw_mod_attr): + new_fw_node = joint_gm.graph.call_function( + the_function=invoke_subgraph, + args=( + new_fw_mod_attr, + new_fw_mod_attr_name, + *fw_node.args[2:], + ), + ) + propagate_meta_info(new_fw_hop_gm, new_fw_node, fw_node) + + # old_num_fw_outputs = (*fw_outs) + # new_num_fw_outputs = (*fw_outs, *saved_tensors, *sym_nodes) + with joint_gm.graph.inserting_after(new_fw_node): + for fw_out_idx in range(old_num_fw_outputs, total_outputs): + saved_tensor_node = joint_gm.graph.call_function( + the_function=operator.getitem, args=(new_fw_node, fw_out_idx) + ) + saved_tensor_node.meta = copy.copy(new_fw_node.meta) + saved_tensor_node.meta["val"] = new_fw_node.meta["val"][fw_out_idx] + extra_fw_outputs.append(saved_tensor_node) + + fw_node.replace_all_uses_with(new_fw_node) + joint_gm.graph.erase_node(fw_node) + + # Bw node changes + # Prepare the operands for the bwd graph + # Old bw graph signature : (*primals, *tangents) + # New signature will be : (*sym_nodes, *saved_tensors, *tangents) + # We have already collected the saved_tensors in the forward hop processing. + + # extra_fw_outputs are in the order (*saved_nodes, *sym_nodes). + # Partitioner has this quirk where the backward wants sym_nodes + # first. So extract the sym and saved nodes. + + new_bw_hop_gm = new_hop_graphs[identifier].new_bw_hop_gm + assert new_bw_hop_gm is not None + + saved_tensor_nodes = extra_fw_outputs[:new_num_saved_nodes] + sym_nodes = extra_fw_outputs[new_num_saved_nodes:] + + num_primals = new_hop_graphs[identifier].old_num_fw_inputs + assert num_primals is not None + tangents = list(bw_node.args[2 + num_primals :]) + operands = sym_nodes + saved_tensor_nodes + tangents + + # Insert the new_bw_hop_gm into the joint_gm + with joint_gm.graph.inserting_after(bw_node): + new_bw_mod_attr_name = add_new_hop_gm(new_bw_hop_gm, bw_node.args[1]) + new_bw_mod_attr = joint_gm.graph.get_attr(new_bw_mod_attr_name) + + with joint_gm.graph.inserting_after(new_bw_mod_attr): + new_bw_node = joint_gm.graph.call_function( + the_function=invoke_subgraph, + args=( + new_bw_mod_attr, + new_bw_mod_attr_name, + *operands, + ), + ) + propagate_meta_info(new_bw_hop_gm, new_bw_node, bw_node) + # Since the partitioner is run after the graph passes, we have lost + # the eager information and cannot faithfully extract the eager + # inputs for the new partitioned backward graph. For the forward + # graph, it was fine because the input signature remains same. + new_bw_node.meta.pop("eager_input_vals", None) + + bw_node.replace_all_uses_with(new_bw_node) + joint_gm.graph.erase_node(bw_node) + + joint_gm.graph.eliminate_dead_code() + joint_gm.graph.lint() + joint_gm.recompile() + return joint_gm + + +def maybe_log_graph( + gm, + graph_name, + aot_config, + structured_log_prefix_fn, + out_structured_logs: Optional[list[str]] = None, +): + if not aot_config.enable_log: + return + aot_graphs_log.debug( + "%s", + lazy_format_graph_code( + f"{graph_name}", + gm, + aot_config.aot_id, + include_stride=True, + include_device=True, + colored=True, + ), + ) + + def gm_str_fn() -> str: + return gm.print_readable( + print_output=False, include_stride=True, include_device=True + ) + + if out_structured_logs is not None: + out_structured_logs.append(f"{structured_log_prefix_fn()}:{gm_str_fn()}") + else: + trace_structured( + f"{structured_log_prefix_fn()}", + payload_fn=lambda: gm_str_fn(), + ) + + +def create_wrap_fn(fn, args): + from functools import wraps + + from torch.fx.experimental.proxy_tensor import maybe_enable_thunkify + + from .functional_utils import from_fun, has_data_mutation, to_fun + + def assert_no_mutation(t): + assert not has_data_mutation( + t + ), "Saved tensors hooks with inputs mutations are not allowed" + + @wraps(fn) + def _wrapper(*args): + with maybe_enable_thunkify(): + disable_above = torch._C._ExcludeDispatchKeyGuard( + torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) + ) + + with disable_above: + f_args = pytree.tree_map(to_fun, args) + f_outs = fn(*f_args) + pytree.tree_map(assert_no_mutation, f_args) + return pytree.tree_map(from_fun, f_outs) + + return _wrapper, args + + +def prepare_hook_gm(aot_config, fn, args): + from torch._functorch._aot_autograd.dispatch_and_compile_graph import _create_graph + + fn, args = create_wrap_fn(fn, args) + gm = _create_graph(fn, args, aot_config=aot_config) + return gm + + +# Inline Autograd saved_tensors_hooks into epilogue of forward graph +# and prologue of backward graph. +# This changes forward graph outputs and inputs. +# Pack hook can return tensors, sym scalars, constants. +# All tensors to save for backward will be grouped together at front. +# Sym scalars grouped on another end. Constants are inlined in the graph. +def maybe_inline_graph_saved_tensors_hooks( + fw_module, + bw_module, + num_inner_fwd_outputs, + inner_meta, + aot_config, + static_input_indices, +): + if torch._dynamo.compiled_autograd.in_compiled_autograd_region: + return + + get_hooks = torch._functorch._aot_autograd.utils.top_saved_tensors_hooks + are_inline_hooks = ( + torch._functorch._aot_autograd.utils.saved_tensors_hooks_are_inlineable + ) + + hooks = get_hooks() + if not are_inline_hooks(hooks): + return + + pack_hook_gm, unpack_hook_gm = hooks + + structured_logs: list[str] = [] + maybe_log_graph( + fw_module, + "Forward graph pre saved_tensors_hooks inlining", + aot_config, + lambda: "aot_forward_graph_pre_saved_tensors_hooks", + structured_logs, + ) + maybe_log_graph( + bw_module, + "Backward graph pre saved_tensors_hooks inlining", + aot_config, + lambda: "aot_backward_graph_pre_saved_tensors_hooks", + structured_logs, + ) + fw_g = fw_module.graph + bw_g = bw_module.graph + + fw_g_names = {node.name for node in fw_g.nodes} + bw_g_names = {node.name for node in bw_g.nodes} + + def _gen_unused_name(candidate: str): + c = candidate + i = 0 + while c in fw_g_names or c in bw_g_names: + c = f"{candidate}_{i}" + i = i + 1 + return c + + bw_g_inputs = bw_g.find_nodes(op="placeholder") + + fw_out_n = fw_g.output_node() + fw_outs = fw_out_n.args[0] # type: ignore[var-annotated] + fw_outs_inner_set = set(fw_outs[:num_inner_fwd_outputs]) + fw_outs_saved_for_bw = fw_outs[num_inner_fwd_outputs:] + fw_outs_packed_tensors = [] # type: ignore[var-annotated] + fw_outs_packed_syms = [] # type: ignore[var-annotated] + + # The main use case for saved_tensors_hooks is activation quantization, + # for memory usage optimization. + # Desired behavior is to quantize saved activations to free the original saved tensor. + # Saved nodes may include forward inputs, outputs, parameters. + # They may be held by something else and will not be deallocated after quantization. + # Donated buffers are intermediates in the graph invisible for the user, + # this guarantees that they can be deallocated. + # Using this as a default behavior to select saved nodes to apply hooks. + # There is also a config to apply hooks for all saved nodes without any filtering. + # The plan is to propagate meta about the source of the saved node to the user hook function. + mode = torch._functorch.config.saved_tensors_hooks_filtering_mode + allow_set = None + exclude_set = None + + if mode == "donated": + # collect_bw_donated_buffer_idxs requires inner_meta to have num_symints_saved_for_bw + inner_meta.num_symints_saved_for_bw = len( + [n for n in fw_outs_saved_for_bw if is_sym_node(n)] + ) + bw_donated_idxs = collect_bw_donated_buffer_idxs( + fw_module, + bw_module, + inner_meta, + ) + fw_donated_idxs = [ + i - inner_meta.num_symints_saved_for_bw for i in bw_donated_idxs + ] + allow_set = {fw_outs_saved_for_bw[i].name for i in fw_donated_idxs} + elif mode == "no_static": + fw_g_inputs = fw_g.find_nodes(op="placeholder") + exclude_set = {fw_g_inputs[i].name for i in static_input_indices} + + if (allow_set is not None) and (not allow_set): + # This means we have empty whitelist, + # No donated (intermediate) saved. + # Do not do anything in this case + return + + if aot_config.enable_log: + structured_logs.append(f"fw_outs_saved_for_bw:{fw_outs_saved_for_bw}") + structured_logs.append(f"mode:{mode}") + structured_logs.append(f"allow_set:{allow_set}") + structured_logs.append(f"exclude_set:{exclude_set}") + + for saved in fw_outs_saved_for_bw: + if ((allow_set is not None) and (saved.name not in allow_set)) or ( + (exclude_set is not None) and (saved.name in exclude_set) + ): + if isinstance(saved.meta["val"], torch.Tensor): + fw_outs_packed_tensors.append(saved) + continue + + val = saved.meta["val"] + if not isinstance(val, torch.Tensor): + continue + + pack_out_val = pack_hook_gm(val) + + requires_sc_handling = any( + is_traceable_wrapper_subclass(x) for x in pytree.tree_leaves(pack_out_val) + ) + if requires_sc_handling: + raise NotImplementedError( + "Tensor subclasses in GraphModule saved tensors hooks are not supported" + "You can workaround it by manually returning subclass's inner tensors" + " in the pack hook, and reconstructing the subclass in the unpack hook" + ) + + pack_gm = prepare_hook_gm(aot_config, pack_hook_gm, (val,)) + pack_g = pack_gm.graph + maybe_log_graph( + pack_gm, + f"saved_tensors_pack_hook {saved.name}", + aot_config, + lambda: f"aot_saved_tensors_hooks_pack {saved.name}", + structured_logs, + ) + pack_out_val = pack_gm(val) + + # Install pack hook graph as eiplogue of fw_module. + # Saved tensor output becomes input of pack hook graph. + # Replace saved tensor output with pack hook graph output. + # Outputs symbolic scalars, tensors are accumulated separately. + # Then in forward outputs and backward inputs installed in order + # sym_scalars, packed_saved_tensors. + # Keeping all tensors together allows to preserve + # the same identification at runtime, + # updating only number of saved sym_scalars and tensors. + pack_g_inputs = pack_g.find_nodes(op="placeholder") + assert len(pack_g_inputs) == 1 + env = {pack_g_inputs[0]: saved} + fw_pack_out_args = None + with fw_g.inserting_before(fw_out_n): + for node in pack_g.nodes: + if node.op == "placeholder": + continue + new_n = fw_g.node_copy(node, lambda n: env[n]) + fw_g_names.add(new_n.name) + env[node] = new_n + # Output node is temporarily copied to have remapped arguments. + # Removed in the end. + if node.op == "output": + fw_pack_out_args = new_n.args[0] + fw_g.erase_node(new_n) + + env.clear() + assert fw_pack_out_args + fw_outs_bw_ins_node_names = [] + for out_idx, _n in enumerate(pytree.tree_leaves(fw_pack_out_args)): + if not isinstance(_n, torch.fx.Node): + fw_outs_bw_ins_node_names.append("") + continue + + # This happens when hook is noop and it is either user input or user output. + # Do not do anything with this node. + if _n.op == "placeholder" or _n in fw_outs_inner_set: + # This means the hook returned input primals unchanged + # Do not rename in this case. + n = _n + new_node_name = _n.name + fw_outs_bw_ins_node_names.append(new_node_name) + else: + # We can not specify desired name in node_copy. + # Copying node manually to set specifc name, + # to have matching fw_outs, bw_inputs names. + new_node_name = _gen_unused_name(f"{saved.name}_hook_{out_idx}") + with fw_g.inserting_before(_n): + n = fw_g.create_node( + _n.op, + _n.target, + _n.args, + _n.kwargs, + name=new_node_name, + ) + assert n.name == new_node_name + fw_outs_bw_ins_node_names.append(new_node_name) + n.meta = copy.copy(_n.meta) + _n.replace_all_uses_with(n) + fw_g.erase_node(_n) + if isinstance(n.meta["val"], torch.Tensor): + fw_outs_packed_tensors.append(n) + elif is_sym_node(n): + fw_outs_packed_syms.append(n) + + # Install unpack hook graph as a prologue of backward graph + # Saved tensors inputs are replaced with packed tensors and packed sym scalars. + # The saved tensors inputs usages in the graph are replaced with unpack hook graph outputs. + unpack_gm = prepare_hook_gm(aot_config, unpack_hook_gm, (pack_out_val,)) + unpack_g = unpack_gm.graph + maybe_log_graph( + unpack_gm, + f"saved_tensors_unpack_hook {saved.name}", + aot_config, + lambda: f"aot_saved_tensors_hooks_unpack {saved.name}", + structured_logs, + ) + + def find_saved_in_bw_inputs(bw_inputs): + for n in bw_inputs: + if n.name == saved.name: + return n + + bw_g_input = find_saved_in_bw_inputs(bw_g_inputs) + assert bw_g_input + original_bw_g_input_users = list(bw_g_input.users.keys()) + bw_g_input_used_directly = False + + # Replace backward graph saved tensor input with copy of pack graph outputs + # All non-Tensor, non-symscalars outputs are constanted. + + unpack_g_inputs = unpack_g.find_nodes(op="placeholder") + env = {} + for out_idx, (unp_in_n, out_n, val) in enumerate( + zip( + unpack_g_inputs, + pytree.tree_leaves(fw_pack_out_args), + pytree.tree_leaves(pack_out_val), + ) + ): + is_sym = isinstance(val, py_sym_types) + if isinstance(val, torch.Tensor) or is_sym: + # We want forward_outputs names to match backward_inputs, + # Potentially backward may already have "{saved.name}_hook_{idx}", + # In this case fx.Graph will add suffix. + new_node_name = fw_outs_bw_ins_node_names[out_idx] + if bw_g_input.name == new_node_name: + env[unp_in_n] = bw_g_input + bw_g_input_used_directly = True + else: + # Backward calling convention: ctx_symints,ctx_saved_tensors + # Inserting packed sym scalars before first saved tensor input. + # Inserting packed tensors before last saved tensor input. + # Saved tensor inputs between them will be removed. + with bw_g.inserting_before( + bw_g_inputs[0] + ) if is_sym else bw_g.inserting_before(bw_g_input): + new_n = bw_g.placeholder(new_node_name) + assert new_n.name == new_node_name + new_n.meta = copy.copy(out_n.meta) + env[unp_in_n] = new_n + else: + # Inline values of non-Tensor, non-SymScalars + env[unp_in_n] = val + + # Inserting unpack hook after placeholders. + bw_unpack_out_n = None + with bw_g.inserting_before(bw_g_inputs[-1].next): + for node in unpack_g.nodes: + if node.op == "placeholder": + continue + new_n = bw_g.node_copy(node, lambda n: env[n]) + bw_g_names.add(new_n.name) + env[node] = new_n + # Temporary insert output, to have remapped by node_copy args. + # Removed in the end. + if node.op == "output": + bw_unpack_out_n = new_n + + assert bw_unpack_out_n + _leaves = pytree.tree_leaves(bw_unpack_out_n.args) + assert len(_leaves) == 1 + unpack_saved_tensor_n = _leaves[0] + + if not bw_g_input_used_directly: + bw_g_input.replace_all_uses_with(unpack_saved_tensor_n) + bw_g.erase_node(bw_g_input) + else: + # Keep usages of bw_g_input in inserted unpacked hook graph. + # Replace other usages of bw_g_input with unpack_saved_tensor_n. + from torch._C import _fx_map_arg + + def maybe_replace_node(n): + return unpack_saved_tensor_n if n == bw_g_input else n + + for use_node in original_bw_g_input_users: + new_args = _fx_map_arg(use_node.args, maybe_replace_node) + new_kwargs = _fx_map_arg(use_node.kwargs, maybe_replace_node) + assert isinstance(new_args, tuple) + assert isinstance(new_kwargs, dict) + use_node._update_args_kwargs(new_args, new_kwargs) + bw_g.erase_node(bw_unpack_out_n) + + # Changing forward graph outputs, + # Inserting packed_tensors and packed_syms on the place of saved tensors. + # Packed sym_scalars are together with saved symints + symint_outs_saved_for_bw = [n for n in fw_outs_saved_for_bw if is_sym_node(n)] + fw_new_outs = pytree.tree_leaves( + ( + fw_outs[:num_inner_fwd_outputs], + fw_outs_packed_tensors, + fw_outs_packed_syms, + symint_outs_saved_for_bw, + ) + ) + fw_out_n.args = (tuple(fw_new_outs),) + + # Assert that saved tensors and symints in forward outputs are aligned with backward inputs + _fw_n = num_inner_fwd_outputs + _fw_num_t = len(fw_outs_packed_tensors) + _fw_num_s = len(fw_outs_packed_syms) + len(symint_outs_saved_for_bw) + fw_outs_saved_tensors = fw_new_outs[_fw_n : _fw_n + _fw_num_t] + fw_outs_saved_syms = fw_new_outs[_fw_n + _fw_num_t :] + bw_new_ins = list(bw_g.find_nodes(op="placeholder")) + bw_ins_saved_syms = bw_new_ins[:_fw_num_s] + bw_ins_saved_tensors = bw_new_ins[_fw_num_s : _fw_num_s + _fw_num_t] + + fw_t_names = [n.name for n in fw_outs_saved_tensors] + bw_t_names = [n.name for n in bw_ins_saved_tensors] + fw_s_names = [n.name for n in fw_outs_saved_syms] + bw_s_names = [n.name for n in bw_ins_saved_syms] + + def _log_structured_logs(): + if not aot_config.enable_log: + return + + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "aot_saved_tensors_hooks_graphs", + "encoding": "string", + }, + payload_fn=lambda: "\n".join(structured_logs), + ) + + if aot_config.enable_log: + structured_logs.append( + f"fw_outs[:num_inner_fwd_outputs]:{fw_outs[:num_inner_fwd_outputs]}" + ) + structured_logs.append(f"fw_outs_packed_tensors:{fw_outs_packed_tensors}") + structured_logs.append(f"fw_t_names:{fw_t_names}") + structured_logs.append(f"bw_t_names:{bw_t_names}") + structured_logs.append(f"fw_s_names:{fw_s_names}") + structured_logs.append(f"bw_s_names:{bw_s_names}") + structured_logs.append(f"\nfw_g_pre_assert:{fw_g}") + structured_logs.append(f"\nbw_g_pre_assert:{bw_g}") + maybe_log_graph( + fw_module, + "Forward graph after transform pre-assert", + aot_config, + lambda: "aot_forward_graph_pre_assert_saved_tensors_hooks", + structured_logs, + ) + maybe_log_graph( + bw_module, + "Backward graph after transform pre-assert", + aot_config, + lambda: "aot_backward_graph_pre_assert_saved_tensors_hooks", + structured_logs, + ) + _log_structured_logs() + + assert fw_t_names == bw_t_names + assert fw_s_names == bw_s_names + + fw_g.lint() + bw_g.lint() + fw_module.recompile() + bw_module.recompile() + + +def aot_dispatch_autograd( + flat_fn, + flat_args: list[Any], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, +) -> DispatchReturn: + """ + Autograd logic. Generates a joint graph, partitions it, manipulates the input with various wrappers, + and returns a wrapped torch.autograd.Function with a forward and backward. + """ + wrappers = _create_wrappers_for_dispatch(needs_autograd=True) + flat_fn, flat_args, fw_metadata = pre_compile( + wrappers, + flat_fn, + flat_args, + aot_config, + fw_metadata=fw_metadata, + ) + + fw_metadata.deterministic = torch.are_deterministic_algorithms_enabled() + with dynamo_timed("aot_trace_joint_graph", log_pt2_compile_event=True): + fx_g, joint_inputs, maybe_subclass_meta = aot_dispatch_autograd_graph( + flat_fn, flat_args, aot_config, fw_metadata=fw_metadata + ) + + # Copied from aot_dispatch_autograd_graph. + disable_amp = torch._C._is_any_autocast_enabled() + joint_graph_str = None + if aot_config.enable_log: + aot_joint_log.info( + "%s", + lazy_format_graph_code( + "Joint graph", + fx_g, + aot_config.aot_id, + include_stride=True, + include_device=True, + colored=True, + ), + ) + joint_graph_str = fx_g.print_readable( + print_output=False, include_stride=True, include_device=True + ) + trace_structured( + "aot_joint_graph", + payload_fn=lambda: joint_graph_str, + ) + + with torch.no_grad(): + inner_meta = ( + fw_metadata + if maybe_subclass_meta is None + else maybe_subclass_meta.fw_metadata + ) + with track_graph_compiling(aot_config, "joint"): + # See Note: [Partitioner handling for Subclasses, Part 1] + # See Note: [Recomputing subclass mutation handling] + mutated_inp_runtime_indices = ( + compute_inner_mutated_inp_indices_from_subclass_meta( + fw_metadata, inner_meta + ) + ) + num_tokens = len(fw_metadata.tokens) + num_mutated_inp_runtime_indices = len(mutated_inp_runtime_indices) + num_inner_fwd_outputs = ( + num_mutated_inp_runtime_indices + + inner_meta.num_outputs + + inner_meta.num_intermediate_bases + + inner_meta.num_outputs_rng_offset + + num_tokens # See Note [Side-Effectful Tokens in AOTAutograd] + ) + fake_mode = detect_fake_mode() + fx_g = run_joint_graph_passes_on_hops(fx_g, joint_inputs, aot_config) + + # TODO(anijain2305) - Add tensorify_python_scalars to the HOP graph passes. + if fake_mode is not None and fake_mode.shape_env is not None: + tensorify_python_scalars(fx_g, fake_mode.shape_env, fake_mode) + + static_lifetime_input_indices = fw_metadata.static_input_indices + fw_module, bw_module = aot_config.partition_fn( + fx_g, + joint_inputs, + num_fwd_outputs=num_inner_fwd_outputs, + static_lifetime_input_indices=static_lifetime_input_indices, + ) + rng_states = [ + n + for n in fw_module.graph.find_nodes(op="placeholder") + if "fwd_rng_state" in n.name + ] + fw_metadata.num_graphsafe_rng_states = len(rng_states) + if rng_states: + fw_metadata.graphsafe_rng_state_index = ( + rng_states[0].meta["val"].device.index + ) + + # See Note [Side-Effectful Tokens in AOTAutograd] + if config.unlift_effect_tokens and ( + num_tokens > 0 or fw_metadata.num_backward_tokens > 0 + ): + unlift_tokens(fw_module, fw_metadata, aot_config, bw_module) + + num_inner_fwd_outputs -= num_tokens + joint_inputs = ( + joint_inputs[0][num_tokens:], + joint_inputs[1], + ) + + maybe_inline_graph_saved_tensors_hooks( + fw_module, + bw_module, + num_inner_fwd_outputs, + inner_meta, + aot_config, + fw_metadata.static_input_indices, + ) + static_lifetime_input_indices = fw_metadata.static_input_indices + + fw_outs = next(iter(fw_module.graph.find_nodes(op="output"))).args[0] + # we only need to bookkeep the symints that are saved for bw, not any symints + # the user forward might have returned in its own output + fw_outs_saved_for_bw = fw_outs[num_inner_fwd_outputs:] + num_fw_outs_saved_for_bw = len(fw_outs_saved_for_bw) + symint_outs_saved_for_bw = [] + for idx, node in enumerate(fw_outs_saved_for_bw): + if is_sym_node(node): + symint_outs_saved_for_bw.append(node) + elif ( + isinstance(node, torch.fx.Node) + and "val" in getattr(node, "meta", {}) + and isinstance(node.meta["val"], FakeTensor) + ): + # record dynamic tensor activations + dynamic_dims: set[int] = { + dim + for dim, size in enumerate(node.meta["val"].shape) + if not isinstance(size, int) + } + if dynamic_dims: + fw_metadata.dynamic_saved_tensors_idxs[idx] = dynamic_dims + + fw_metadata.num_symints_saved_for_bw = len(symint_outs_saved_for_bw) + inner_meta.num_symints_saved_for_bw = len(symint_outs_saved_for_bw) + num_symints_saved_for_bw = len(symint_outs_saved_for_bw) + if torch._functorch.config.donated_buffer: + fw_metadata.bw_donated_idxs = collect_bw_donated_buffer_idxs( + fw_module, + bw_module, + inner_meta, + ) + inner_meta.bw_donated_idxs = fw_metadata.bw_donated_idxs + + if aot_config.enable_log: + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "torch._functorch.config", + "encoding": "string", + }, + payload_fn=lambda: torch._functorch.config.get_config_copy(), + ) + aot_graphs_log.info( + "aot_config id: %s, fw_metadata=%s, inner_meta=%s", + str(aot_config.aot_id), + str(fw_metadata), + str(inner_meta), + ) + + # Note [Detaching inputs that never need gradients] + # See https://github.com/pytorch/pytorch/issues/97745 + # Suppose we have a function like this that we want to compile: + # + # def f(x, y): + # return torch.mul(x, y.detach()) + # + # What gradients should we compute for x and y? + # By default, AOTAutograd will compute a gradient for **every** input that requires gradients, + # and so we'll compute: + # x_grad_input = y + # y_grad_input = None + # Does this preserve the semantics of eager mode? + # Unfortunately, no. + # Doing the above will cause autograd to **continue** to backprop the autograd tape + # that was generated from constructing y. + # + # This is **different** from what would have happened in eager mode. + # In eager mode, if we backprop through the output of this function, autograd will only traverse + # the bit of the autograd tape corresponding to "x". + # In particular, if a user had previously backpropped through y's autograd tape, + # And then they try to backprop through the output of the above function, + # then we'll hit the dreaded "Trying to backward through the graph a second time" error. + # + # You might think: If autograd sees that a gradient is None, shouldn't it stop early, + # instead of continuing the backprop through the ancestors of that node in the graph? + # + # Autograd has two passes: + # (1) a first pass that traverses the autograd graph and figures out which nodes need to be executed + # (2) a second pass that actually goes ahead and executes each node when it becomes ready, + # propagating gradients + # By the time we're executing a node and we see that it produces a None, the set of nodes to execute + # is already locked-in. + # + # The fix: instead, we can recognize statically that the graph we're compiling will never contribute + # gradients to y, and prevent autograd from trying to traverse y's autograd tape at all. + # We can do this by manually detach'ing y before sending it through the `CompiledFunction`. + # + # Note that this solution is not bulletproof. + # It's possible to construct a case where eager may or may not have have tried to autograd through y, + # depending on the actual grad_outputs that were passed in during the backward. + # There is no easy fix for this: the simplest fix would be to run with `retain_graph=True`, + # allowing autograd to re-use the graph. + # + # An example of this case is: + # def f(x): + # return x.detach() * 2, x * 3 + # If we were to only backprop through outs[0], in eager, we would stop + # If we backward only on the first output, we shouldn't send a grad through x. + # But the custom autograd function doesn't know that: it will materialize zero grads for x * 3 + # and we will end up with a zero grad at x. + # If we later backprop through the second output, this will also require backprop'ing through x. + # Meaning we'll need to use `retain_graph=True` to be able to backprop through x the second time. + _indices_of_inps_to_detach: list[int] = [] + + # reversed() since we expect output at end of graph + bw_output = next(reversed(bw_module.graph.find_nodes(op="output"))) + bw_outs: Sequence[torch.fx.Node] = bw_output.args[0] # type: ignore[assignment] + + # TODO: we should apply the below "detach inputs if their gradients are statically known to be None" + # optimization even if we have subclass inputs/outputs (we do not handle this today). + # Computing which our our inputs get None gradients is a bit more complicated, + # if any of our inputs are subclasses. Why? + # (a) we need to make sure that we call .detach() on the input subclasses, since autograd sees subclasses. + # (b) The grad_outputs that we AOT computed in our backward graph are the desugared tensor tensors, + # so we need to figure out which subclass fw inputs they map to. + if maybe_subclass_meta is None: + num_backward_tokens: int = inner_meta.num_backward_tokens + assert ( + len(bw_outs) + == len(fw_metadata.input_info) + + inner_meta.num_outputs_rng_offset + + num_backward_tokens + ) + bw_outs_no_rng_no_tokens = bw_outs + if (inner_meta.num_outputs_rng_offset + num_backward_tokens) > 0: + bw_outs_no_rng_no_tokens = bw_outs[ + : -(inner_meta.num_outputs_rng_offset + num_backward_tokens) + ] + assert len(bw_outs_no_rng_no_tokens) == len(fw_metadata.input_info) + + for i, (bw_out) in enumerate(bw_outs_no_rng_no_tokens): + # If our input experiences a metadata mutation inside the graph (e.g. set_()), + # we *must* not detach, otherwise it will be the detach'd input that gets the metadata mutation + metadata_mutation_in_graph = ( + fw_metadata.input_info[i].mutation_type + == MutationType.MUTATED_IN_GRAPH + and fw_metadata.input_info[i].mutates_storage_metadata + ) + is_non_leaf = ( + fw_metadata.input_info[i].requires_grad + and not fw_metadata.input_info[i].is_leaf + ) + if bw_out is None and not metadata_mutation_in_graph and is_non_leaf: + _indices_of_inps_to_detach.append(i) + + fw_module_str = None + bw_module_str = None + if aot_config.enable_log: + aot_graphs_log.info( + "%s", + lazy_format_graph_code( + "Forward graph", + fw_module, + aot_config.aot_id, + include_stride=True, + include_device=True, + colored=True, + ), + ) + aot_graphs_log.info( + "%s", + lazy_format_graph_code( + "Backward graph", + bw_module, + aot_config.aot_id, + include_stride=True, + include_device=True, + colored=True, + ), + ) + fw_module_str = fw_module.print_readable( + print_output=False, include_stride=True, include_device=True + ) + bw_module_str = bw_module.print_readable( + print_output=False, include_stride=True, include_device=True + ) + + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "aot_forward_graph_fw_metadata", + "encoding": "string", + }, + payload_fn=lambda: dataclass_repr(fw_metadata), + ) + if maybe_subclass_meta is not None: + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "aot_forward_graph_fw_subclass_metadata", + "encoding": "string", + }, + payload_fn=lambda: dataclass_repr(maybe_subclass_meta), + ) + + trace_structured( + "aot_forward_graph", + payload_fn=lambda: fw_module_str, + ) + trace_structured( + "aot_backward_graph", + payload_fn=lambda: bw_module_str, + ) + + # AMP is already traced out in joint graph. we do not wish to reapply it accidentally + # in the compiler. + with track_graph_compiling(aot_config, "forward"), torch._C._DisableAutocast(): + # flat_args at this point might still be subclasses- + # make sure to pass the unwrapped fake tensors into the compiler! + adjusted_flat_args = joint_inputs[0] + + fakified_out_wrapper = FakifiedOutWrapper() + ( + fw_module, + adjusted_flat_args, + fw_metadata, + ) = fakified_out_wrapper.pre_compile( + fw_module, adjusted_flat_args, aot_config, fw_metadata=fw_metadata + ) + + functionalized_rng_wrapper = FunctionalizedRngRuntimeWrapper( + return_new_outs=False + ) + + if rng_states: + index = fw_metadata.graphsafe_rng_state_index + assert index is not None + rng_states = [ + get_cuda_generator_meta_val(index) + for _ in range(fw_metadata.num_graphsafe_rng_states) + ] + adjusted_flat_args.extend(rng_states) # type: ignore[arg-type] + + ( + fw_module, + adjusted_flat_args, + fw_metadata, + ) = functionalized_rng_wrapper.pre_compile( + fw_module, adjusted_flat_args, aot_config, fw_metadata=fw_metadata + ) + if tracing_context := torch._guards.TracingContext.try_get(): + tracing_context.fw_metadata = inner_meta + + with TracingContext.report_output_strides() as fwd_output_strides: + compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args) + + if not getattr(compiled_fw_func, "_boxed_call", False): + compiled_fw_func = make_boxed_func(compiled_fw_func) + + if fakified_out_wrapper.needs_post_compile: + fakified_out_wrapper.set_fwd_output_strides(fwd_output_strides) + + compiled_fw_func = EffectTokensWrapper().post_compile( + compiled_fw_func, + aot_config, + runtime_metadata=fw_metadata, + ) + + compiled_fw_func = AOTDispatchSubclassWrapper( + fw_only=None, + trace_joint=False, + maybe_subclass_meta=maybe_subclass_meta, + num_fw_outs_saved_for_bw=num_fw_outs_saved_for_bw, + ).post_compile( + compiled_fw_func, + aot_config, # not used + runtime_metadata=fw_metadata, + ) + + compiled_fw_func = functionalized_rng_wrapper.post_compile( + compiled_fw_func, aot_config, runtime_metadata=fw_metadata + ) + compiled_fw_func = fakified_out_wrapper.post_compile( + compiled_fw_func, + aot_config, + runtime_metadata=fw_metadata, + ) + + # NB: It's important to compile backwards ahead of time, as this may + # add extra guards which we need to apply to the Dynamo cache at + # forwards + with track_graph_compiling(aot_config, "backward"), torch._C._DisableAutocast(): + placeholder_list = fx_placeholder_vals(bw_module) + + forward_saved_for_backwards_strides = None + if fwd_output_strides is not None: + forward_saved_for_backwards_strides = fwd_output_strides[ + inner_meta.tensors_saved_for_backwards_slice + ] + + # saved activations can have different stride to eager if + # the compiler does layout optimization. We should restride the + # tensor passed in for compiling the backward graph using the + # saved tensor's stride. + for i in range(len(placeholder_list)): + ph_arg = placeholder_list[i] + if not isinstance(ph_arg, torch.Tensor): + continue + + if forward_saved_for_backwards_strides is None: + continue + + real_stride = None + # Per all_args calling convention + j = i - num_symints_saved_for_bw + if 0 <= j < len(forward_saved_for_backwards_strides): + real_stride = forward_saved_for_backwards_strides[j] + if real_stride is None: + continue + + # Comparing ph_arg.stride() with real_stride directly may + # cause dynamic dimensions in ph_arg being specialized to static + # value. Using the hints to avoid that. + if _get_symint_hints(ph_arg.stride()) != real_stride: + # Note that here we use the stride of the real tensor to + # restride a FakeTensor. This does not cause trouble + # for dynamic shape since this code path only get + # executed if layout optimization is enabled. And we + # disable layout optimization for dynamic shape right + # now. + # + # A solution that decide stride order based on real + # tensor's stride and then apply that stride order to + # the FakeTensor does not work smoothly since some + # tensor's layout is not 'dense'. E.g. mixnet_l has a + # tensor with size [8, 64, 112, 112] and strides + # (2408448, 1, 21504, 192). The solution mentioned will + # decide a stride of (802816, 1, 7168, 64) for this + # tensor which is wrong. + placeholder_list[i] = ph_arg.as_strided(ph_arg.size(), real_stride) + + compiled_bw_func = None + if num_symints_saved_for_bw > 0: + try: + # See Note: [Backward graph lazy lowering] + with torch._subclasses.fake_tensor.unset_fake_temporarily(): + # If bw_module contains lifted constants, they will be real tensors stored as + # GraphModule. Deepcopying tensors under fake mode is not supported and will + # raise when attempting to set storage. + bw_module_copy = copy.deepcopy(bw_module) + compiled_bw_func = aot_config.bw_compiler( + bw_module_copy, placeholder_list + ) + del bw_module_copy + except Exception as e: + exc = e + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "eager_compile_backwards_failure", + "encoding": "string", + }, + payload_fn=lambda: "\n".join( + traceback.format_exception( + type(exc), exc, exc.__traceback__ + ) + ), + ) + log.warning( + "failed to eagerly compile backwards for dynamic, suppressing in case backwards not needed", + exc_info=True, + ) + # Compiled autograd will run the bw_module in the backward pass, + # so recompilation need happen anyway if the backward pass is ever + # called. + # + # The reason we do the GraphModule recompilation here is because + # the lazy recompilation will cause issue in the backward pass + # with compiled autograd. + # + # Do the _LazyGraphModule.force_recompile here rather than when + # bw_module is first generated by the partitioner because the bw_module.recompile + # may be called in some code path later and cause the _LazyGraphModule.forward + # becomes the lazy version again. One example is when dynamic shape is enabled + # upfront, the bw_compiler will be called above which can cause extra + # graph module recompilation on bw_module. + if torch._dynamo.compiled_autograd.in_compiled_autograd_region: + from torch.fx._lazy_graph_module import _LazyGraphModule + + _LazyGraphModule.force_recompile(bw_module) + + saved_context = TracingContext.try_get() + saved_compile_context = CompileContext.try_get() + + backward_state_indices = [ + idx for idx, x in enumerate(flat_args) if isinstance(x, BackwardState) + ] + assert len(backward_state_indices) <= 1 + + lazy_backward_info = AutogradLazyBackwardCompileInfo( + bw_module, + placeholder_list, + saved_context, + saved_compile_context, + ) + + make_runtime_safe(fw_metadata, maybe_subclass_meta) + + try_save_cache_entry: Optional[Callable] = None + + if aot_config.cache_info is not None: + forward_time_taken_ns = time.time_ns() - aot_config.cache_info.start_time_ns + + # NB: aot_config here is technically not needed as an argument: we could just + # close over aot_config.cache_info, since aot_config never changes. + # But closing over random variables is confusing IMO, so I'm leaving it. + def try_save_cache_entry( # noqa: F811 + compiled_bw_func: Callable, + bw_module: torch.fx.GraphModule, + _fw_metadata: ViewAndMutationMeta, + aot_config: AOTConfig, + ): + fw_key = getattr(compiled_fw_func, "_fx_graph_cache_key", None) + bw_key = getattr(compiled_bw_func, "_fx_graph_cache_key", None) + cache_info = aot_config.cache_info + if cache_info is not None and fw_key and bw_key: + assert forward_time_taken_ns is not None + # TODO: technically, AOTAutograd does a *little* bit of post processing work + # in the backward that isn't measured here. But it's small enough that it's not worth + # the complexity of threading a bunch of times through the code, so we + # use the compiled_bw_func's inductor compile time instead. + # It's possible this changes in the future, in which case we should + # update backward_time_taken_ns to be more inclusive + backward_time_taken_ns = getattr(compiled_bw_func, "_time_taken_ns", 0) + + aot_forward_graph_str: Optional[str] = fw_module_str + aot_backward_graph_str: Optional[str] = bw_module_str + aot_joint_graph_str: Optional[str] = joint_graph_str + guards_expr = AOTAutogradCache.generate_guards_expression(cache_info) + + entry = AOTAutogradCache.make_entry( + compiled_fw_func, # type: ignore[arg-type] + compiled_bw_func, # type: ignore[arg-type] + aot_joint_graph_str, + aot_forward_graph_str, + aot_backward_graph_str, + _fw_metadata, + wrappers, + maybe_subclass_meta, + num_fw_outs_saved_for_bw, + _indices_of_inps_to_detach, + forward_time_taken_ns, + backward_time_taken_ns, + sanitized_aot_config=sanitize_aot_config(aot_config), + guards_expr=guards_expr, + backward_state_indices=backward_state_indices, + num_symints_saved_for_bw=num_symints_saved_for_bw, + serialized_bw_module=serialize_graph_module(bw_module), + ) + remote = should_use_remote_autograd_cache() + AOTAutogradCache.save(cache_info.cache_key, entry, remote) + + if compiled_bw_func is not None: + # If we already compiled the backward, we save its cache entry now + try_save_cache_entry(compiled_bw_func, bw_module, fw_metadata, aot_config) + try_save_cache_entry = None + + compiled_fn = AOTDispatchAutograd.post_compile( + compiled_fw_func, + compiled_bw_func, + maybe_subclass_meta, + num_symints_saved_for_bw, + backward_state_indices, + disable_amp, + _indices_of_inps_to_detach, + lazy_backward_info, + aot_config, + fw_metadata=fw_metadata, + try_save_cache_entry=try_save_cache_entry, + ) + + if config.debug_assert: + flat_requires_grad: list[Optional[bool]] = [ + a.requires_grad if isinstance(a, Tensor) else None for a in flat_args + ] + compiled_fn = DebugAssertWrapper( + flat_requires_grad=flat_requires_grad + ).post_compile(compiled_fn, aot_config, runtime_metadata=fw_metadata) + + compiled_fn = post_compile( + wrappers, + compiled_fn, + aot_config, + runtime_metadata=fw_metadata, + ) + return compiled_fn diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index e2f66bdef70f4..69a21f6bd8944 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -6,7 +6,10 @@ 3. handle functionalized randomness 4. deduplicate inputs and consolidate views into their bases (see input_output_analysis) """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import builtins import collections import contextlib @@ -18,12 +21,16 @@ from functools import wraps from typing import Any, Callable, Optional, TYPE_CHECKING, Union +<<<<<<< HEAD if TYPE_CHECKING: from collections.abc import Sequence import torch import torch.fx as fx +======= +import torch +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch.utils.dlpack from torch import Tensor from torch._dynamo import config as dynamo_config @@ -45,6 +52,7 @@ from .. import config from .collect_metadata_analysis import run_functionalized_fw_and_collect_metadata +<<<<<<< HEAD from .descriptors import ( AOTInput, AOTOutput, @@ -55,6 +63,9 @@ ) from .functional_utils import gen_alias_from_base from .graph_capture_wrappers import aot_dispatch_subclass +======= +from .functional_utils import gen_alias_from_base +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .input_output_analysis import ( compute_overlapping_inputs, create_synthetic_base_metadata, @@ -63,9 +74,12 @@ from .logging_utils import describe_input, format_guard_bug_msg, track_graph_compiling from .schemas import ( AOTConfig, +<<<<<<< HEAD CompilerWrapper, FxValue, InductorWrapper, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) InputAliasInfo, MemoryFormatMeta, MutationType, @@ -74,7 +88,10 @@ SubclassCreationMeta, SubclassMeta, TensorAlias, +<<<<<<< HEAD TraceFn, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ViewAndMutationMeta, ) from .subclass_utils import ( @@ -82,6 +99,7 @@ runtime_unwrap_tensor_subclasses, wrap_tensor_subclasses, ) +<<<<<<< HEAD from .utils import ( call_and_expect_output_descs, call_func_at_runtime_with_args, @@ -96,6 +114,72 @@ zip = strict_zip +======= +from .traced_function_transforms import aot_dispatch_subclass +from .utils import ( + call_func_at_runtime_with_args, + make_boxed_func, + partial_flatten_asdict, + strict_zip, +) + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +zip = strict_zip + + +class CompilerWrapper: + """ + A wrapper around the inputs and outputs to the compiler_fn. We separate these into two parts: + + 1. The prologue, which edits the input to the compiler_fn(flat_fn, flat_args, etc) + 2. The epilogue, which edits the outputs of the compiler_fn (compiled_fn, real arguments) + + Each wrapper below should be implemented as a CompilerWrapper, so that we can facilitate + caching on the compiled output, and re-wrapping the output via epilogues. + Extra metadata that is needed to compute pre or post compile can be passed in via attributes. + """ + + def pre_compile( + self, + flat_fn, + flat_args: list[Tensor], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, + ) -> tuple[Callable, list[Tensor], ViewAndMutationMeta]: + """ + Process the inputs to the compiler_fn. You can pass in extra metadata via kwargs. + Args: + flat_fn: The function to compile + flat_args: Metadata from example inputs of the function to compile + aot_config: AOTConfig passed in at compile time + fw_metadata: ViewAndMutationMeta generated from flat_fn and flat_args + """ + return flat_fn, flat_args, fw_metadata + + def post_compile(self, compiled_fn, aot_config, *, runtime_metadata) -> Callable: + """ + Given an output of the compiler, wrap it with information received from prologue. + Args: + compiled_fn: Callable after calling compiler_fn + aot_config: AOTConfig after calling prologue + runtime_metadata: ViewAndMutationMeta after calling all wrappers's pre_compile steps. + Example: + + def wrapped_compiled_fn(args): + # do something with args, aot_config, fw_metadata + return compiled_fn(args) + + return wrapped_compiled_fn + """ + return compiled_fn + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # The wrapper created by this function handles all of the runtime aliasing and mutation "epilogue" logic # that needs to run after the compiled function. # @@ -149,7 +233,11 @@ def __init__(self, info, runtime_metadata, trace_joint): self.base_idx = info.base_idx self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity self.requires_grad = info.requires_grad +<<<<<<< HEAD self.view_meta_sequence = info.view_meta_sequence +======= + self.functional_tensor = info.functional_tensor +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.replay_views = config.view_replay_for_aliased_outputs def __call__(self, orig_inputs, fw_outs, out): @@ -158,7 +246,11 @@ def __call__(self, orig_inputs, fw_outs, out): aliased_base_tensor, self.unwrap_out(out), self.requires_grad, +<<<<<<< HEAD self.view_meta_sequence, +======= + self.functional_tensor, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) replay_views=self.replay_views, ) @@ -189,7 +281,11 @@ def __init__(self, info, runtime_metadata, trace_joint): self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity self.requires_grad = info.requires_grad +<<<<<<< HEAD self.view_meta_sequence = info.view_meta_sequence +======= + self.functional_tensor = info.functional_tensor +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.replay_views = config.view_replay_for_aliased_outputs def __call__(self, orig_inputs, fw_outs, out): @@ -198,7 +294,11 @@ def __call__(self, orig_inputs, fw_outs, out): self._unwrap_aliased_base_tensor(aliased_base_tensor), self.unwrap_out(out), self.requires_grad, +<<<<<<< HEAD self.view_meta_sequence, +======= + self.functional_tensor, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) replay_views=self.replay_views, ) @@ -287,9 +387,15 @@ def _create_runtime_wrapper( for info in runtime_metadata.output_info ) +<<<<<<< HEAD def record_runtime_wrapper_prologue_enter() -> Optional[ AbstractContextManager[None] ]: +======= + def record_runtime_wrapper_prologue_enter() -> ( + Optional[AbstractContextManager[None]] + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if ( torch.autograd.profiler._is_profiler_enabled and dynamo_config.record_runtime_overhead @@ -472,9 +578,14 @@ def _runtime_wrapper(*args, **kwargs): return _runtime_wrapper +<<<<<<< HEAD # WARNING: this does NOT operate on TraceFn @dataclass class FunctionalizedRngRuntimeWrapper(InductorWrapper): +======= +@dataclass +class FunctionalizedRngRuntimeWrapper(CompilerWrapper): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO: I would love to get rid of this argument, but it's # Wrapped pretty tightly around our aot_dispatch_autograd logic. # Specifically, tensors_saved_for_backwards_slice's value is both used for calculating indices @@ -486,21 +597,36 @@ class FunctionalizedRngRuntimeWrapper(InductorWrapper): def pre_compile( self, +<<<<<<< HEAD flat_fn: torch.fx.GraphModule, +======= + flat_fn, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) flat_args, aot_config, *, fw_metadata, +<<<<<<< HEAD ) -> None: if config.functionalize_rng_ops: # Update example inputs for the fw_compiler fake_mode = detect_fake_mode() assert fake_mode is not None +======= + ) -> tuple[Callable, list[Tensor], ViewAndMutationMeta]: + if config.functionalize_rng_ops: + # Update example inputs for the fw_compiler + fake_mode = detect_fake_mode() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) seed, offset = CUDARngStateHelper.get_torch_state_as_tuple(fake_mode) flat_args.extend([seed, offset]) # We are not clearing flat_args here because # 1) There is a check in the debug compiler at the end # 2) It does not matter as these are fake tensors +<<<<<<< HEAD +======= + return flat_fn, flat_args, fw_metadata +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def post_compile( self, @@ -548,9 +674,14 @@ def _functionalized_rng_runtime_epilogue( return outs +<<<<<<< HEAD # WARNING: this does NOT operate on TraceFn @dataclass class FakifiedOutWrapper(InductorWrapper): +======= +@dataclass +class FakifiedOutWrapper(CompilerWrapper): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out_metas: list[torch.Tensor] = field(default_factory=list) # TracingContext.fwd_output_strides # Generated from actually doing compile @@ -560,12 +691,20 @@ class FakifiedOutWrapper(InductorWrapper): def pre_compile( self, +<<<<<<< HEAD fw_module: fx.GraphModule, # Must be fw_module from aot_dispatch_*_graph +======= + fw_module, # Must be fw_module from aot_dispatch_*_graph +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) flat_args, aot_config, *, fw_metadata, +<<<<<<< HEAD ) -> None: +======= + ) -> tuple[Callable, list[Tensor], ViewAndMutationMeta]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tracing_context = torch._guards.TracingContext.try_get() if tracing_context and tracing_context.fakify_first_call: self.out_metas = [ @@ -573,6 +712,10 @@ def pre_compile( ] else: self.needs_post_compile = False +<<<<<<< HEAD +======= + return fw_module, flat_args, fw_metadata +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _compute_output_meta_with_inductor_strides(self): out = self.out_metas @@ -646,13 +789,19 @@ class AOTDispatchSubclassWrapper(CompilerWrapper): def pre_compile( self, +<<<<<<< HEAD flat_fn: TraceFn, flat_args: list[FxValue], flat_args_descs: list[AOTInput], +======= + flat_fn, + flat_args: list[Tensor], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) aot_config: AOTConfig, *, fw_metadata: ViewAndMutationMeta, ): +<<<<<<< HEAD (new_flat_fn, new_flat_args, new_flat_args_descs, subclass_meta) = ( aot_dispatch_subclass( flat_fn, @@ -665,6 +814,17 @@ def pre_compile( ) self.maybe_subclass_meta = subclass_meta return new_flat_fn, new_flat_args, new_flat_args_descs, fw_metadata +======= + (new_flat_fn, new_flat_args, subclass_meta) = aot_dispatch_subclass( + flat_fn, + flat_args, + is_joint_structure=self.trace_joint, + meta=fw_metadata, + fw_only=self.fw_only, # type: ignore[arg-type] + ) + self.maybe_subclass_meta = subclass_meta + return new_flat_fn, new_flat_args, fw_metadata +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def post_compile( self, @@ -832,6 +992,7 @@ def add_dupe_args(self, args): def pre_compile( self, +<<<<<<< HEAD flat_fn: TraceFn, flat_args: list[FxValue], flat_args_descs: list[AOTInput], @@ -839,11 +1000,20 @@ def pre_compile( *, fw_metadata: ViewAndMutationMeta, ) -> tuple[TraceFn, list[FxValue], list[AOTInput], ViewAndMutationMeta]: +======= + flat_fn, + flat_args: list[Tensor], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, + ) -> tuple[Callable, list[Tensor], ViewAndMutationMeta]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Use information about whether or not flat_fn mutates its arguments # or not to handle dupe args # Strategy 1: For any input that is not mutated, we can leafify it if we # need to remove a duplicate. +<<<<<<< HEAD leaf_flat_args: list[FxValue] = [] leaf_flat_args_descs: list[AOTInput] = [] args_set = set() @@ -857,19 +1027,38 @@ def pre_compile( args_set.add(a) leaf_flat_args.append(a) leaf_flat_args_descs.append(a_desc) +======= + leaf_flat_args = [] + args_set = set() + ok = True + + for i, a in enumerate(flat_args): + if not isinstance(a, torch.Tensor): + leaf_flat_args.append(a) + elif a not in args_set: + args_set.add(a) + leaf_flat_args.append(a) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif ( not fw_metadata.input_info[i].mutates_data and not fw_metadata.input_info[i].mutates_metadata ): leaf_flat_args.append(a.detach().requires_grad_(a.requires_grad)) +<<<<<<< HEAD leaf_flat_args_descs.append(a_desc) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: ok = False break if ok: self.needs_post_compile = False +<<<<<<< HEAD return flat_fn, leaf_flat_args, leaf_flat_args_descs, fw_metadata +======= + return flat_fn, leaf_flat_args, fw_metadata +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if requires_subclass_dispatch(leaf_flat_args, fw_metadata): raise RuntimeError( @@ -929,18 +1118,27 @@ def pre_compile( keep_arg_mask.append(True) add_dupe_map.append(j) j += 1 +<<<<<<< HEAD assert len(add_dupe_map) == duped_arg_len, ( f"Expects add_dupe_map to have length {duped_arg_len} but got {len(add_dupe_map)}" ) +======= + assert ( + len(add_dupe_map) == duped_arg_len + ), f"Expects add_dupe_map to have length {duped_arg_len} but got {len(add_dupe_map)}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.keep_arg_mask = keep_arg_mask self.add_dupe_map = add_dupe_map deduped_flat_args = self.remove_dupe_args(flat_args) +<<<<<<< HEAD # TODO: instead of arbitrarily removing args, it might be useful to # have a record that these were duped, perhaps as a mutable attribute # on the kept arg? Do this if someone needs it deduped_flat_args_descs = self.remove_dupe_args(flat_args_descs) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Update our input metadata to remove duped input metadata. updated_fw_metadata = remove_dupe_metadata( @@ -968,6 +1166,7 @@ def pre_compile( DuplicateInputs(kept_arg_source, dupe_arg_source) ) +<<<<<<< HEAD @simple_wraps(flat_fn) def wrapped_flat_fn( *args: FxValue, @@ -981,10 +1180,20 @@ def wrapped_flat_fn( ref_fw_metadata = run_functionalized_fw_and_collect_metadata( without_output_descs(wrapped_flat_fn), flat_args_descs=deduped_flat_args_descs, +======= + @wraps(flat_fn) + def wrapped_flat_fn(*args): + return flat_fn(*self.add_dupe_args(args)) + + if config.debug_assert: + ref_fw_metadata = run_functionalized_fw_and_collect_metadata( + wrapped_flat_fn, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static_input_indices=aot_config.static_input_indices, keep_input_mutations=fw_metadata.keep_input_mutations, is_train=fw_metadata.is_train, )(*deduped_flat_args) +<<<<<<< HEAD assert ref_fw_metadata == updated_fw_metadata, ( f"ref_metadata={str(ref_fw_metadata)}, actual_metadata={str(updated_fw_metadata)}" ) @@ -995,6 +1204,13 @@ def wrapped_flat_fn( deduped_flat_args_descs, updated_fw_metadata, ) +======= + assert ( + ref_fw_metadata == updated_fw_metadata + ), f"ref_metadata={str(ref_fw_metadata)}, actual_metadata={str(updated_fw_metadata)}" + + return wrapped_flat_fn, deduped_flat_args, updated_fw_metadata +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def post_compile( self, @@ -1069,6 +1285,7 @@ class AOTSyntheticBaseWrapper(CompilerWrapper): def pre_compile( self, +<<<<<<< HEAD flat_fn: TraceFn, flat_args: list[FxValue], flat_args_descs: list[AOTInput], @@ -1085,6 +1302,18 @@ def pre_compile( aot_config, flat_args, flat_args_descs, +======= + flat_fn, + flat_args: list[Any], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, + ) -> tuple[Callable, list[Tensor], ViewAndMutationMeta]: + is_inference = not self.trace_joint + flat_args_with_synthetic_bases, synthetic_base_info = merge_view_inputs( + aot_config, + flat_args, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fw_metadata.input_info, is_inference=is_inference, ) @@ -1092,7 +1321,11 @@ def pre_compile( # Happy path: we don't need synthetic bases if synthetic_base_info is None: self.needs_post_compile = False +<<<<<<< HEAD return flat_fn, flat_args, flat_args_descs, fw_metadata +======= + return flat_fn, flat_args, fw_metadata +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # export path: ban synthetic bases for now, add later if requested. if requires_subclass_dispatch(flat_args, fw_metadata): @@ -1122,11 +1355,15 @@ def pre_compile( fw_metadata_updated, aliased_arg_idx_with_metadata_mutations, ) = create_synthetic_base_metadata( +<<<<<<< HEAD fw_metadata, synthetic_base_info, flat_args, flat_args_with_synthetic_bases, flat_args_descs_with_synthetic_bases, +======= + fw_metadata, synthetic_base_info, flat_args, flat_args_with_synthetic_bases +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # Save old input args for post-compile self.old_input_info = fw_metadata.input_info @@ -1153,7 +1390,11 @@ def _unpack_synthetic_bases(primals: tuple[Any, ...]) -> list[Any]: f_args_inner.append(view_arg) return f_args_inner +<<<<<<< HEAD @simple_wraps(flat_fn) +======= + @wraps(flat_fn) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def wrapped_flat_fn(*args): unpacked_args = _unpack_synthetic_bases(args) # This is a bit subtle. The goal of this entire function (aot_dispatch_synthetic_bases) @@ -1174,6 +1415,7 @@ def wrapped_flat_fn(*args): for i, x in enumerate(unpacked_args) if i in self.aliased_arg_idx_with_metadata_mutations ] +<<<<<<< HEAD out, out_descs = call_and_expect_output_descs(flat_fn, unpacked_args) if len(aliased_args_with_metadata_mutations) > 0: # TODO: record more detailed desc information here @@ -1195,6 +1437,16 @@ def wrapped_flat_fn(*args): ref_fw_metadata = run_functionalized_fw_and_collect_metadata( without_output_descs(wrapped_flat_fn), flat_args_descs=flat_args_descs_with_synthetic_bases, +======= + if len(aliased_args_with_metadata_mutations) > 0: + return *(flat_fn(*unpacked_args)), *aliased_args_with_metadata_mutations + else: + return flat_fn(*unpacked_args) + + if config.debug_assert: + ref_fw_metadata = run_functionalized_fw_and_collect_metadata( + wrapped_flat_fn, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static_input_indices=aot_config.static_input_indices, keep_input_mutations=fw_metadata.keep_input_mutations, is_train=fw_metadata.is_train, @@ -1206,7 +1458,10 @@ def wrapped_flat_fn(*args): return ( wrapped_flat_fn, flat_args_with_synthetic_bases, +<<<<<<< HEAD flat_args_descs_with_synthetic_bases, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fw_metadata_updated, ) @@ -1224,10 +1479,15 @@ def post_compile( @wraps(compiled_fn) def wrapped_compiled_fn(args): +<<<<<<< HEAD # TODO: this sure seems expensive to run at runtime (which # post_compile seems to imply it does?!) args_with_synthetic_bases, _, synthetic_base_info = merge_view_inputs( aot_config, args, None, self.old_input_info, is_inference=is_inference +======= + args_with_synthetic_bases, synthetic_base_info = merge_view_inputs( + aot_config, args, self.old_input_info, is_inference=is_inference +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) assert synthetic_base_info is not None aliased_args_w_metadata_mutations = [ @@ -1333,18 +1593,25 @@ def wrapped_compiled_fn(args): def merge_view_inputs( aot_config: AOTConfig, fwd_inputs: list[Any], +<<<<<<< HEAD # This is None when called at runtime from post_compile closure fwd_inputs_descs: Optional[list[AOTInput]], +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mutated_input_info: list[InputAliasInfo], *, # The autograd case currently has more restrictions than the inference case. is_inference: bool, +<<<<<<< HEAD ) -> tuple[ list[Any], list[AOTInput], Optional[list[Union[int, tuple[int, torch.Tensor]]]] ]: if fwd_inputs_descs is None: fwd_inputs_descs = [DummyAOTInput(i) for i in range(len(fwd_inputs))] +======= +) -> tuple[list[Any], Optional[list[Union[int, tuple[int, torch.Tensor]]]]]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _are_differentiable_views(view1, view2): if view1 is view2: return True @@ -1366,20 +1633,31 @@ def _same_dtype_views(view1, view2): assert len(fwd_inputs) == len(mutated_input_info) if not [info for info in mutated_input_info if info.mutates_data]: # Return early when there are no mutations. +<<<<<<< HEAD return fwd_inputs, fwd_inputs_descs, None +======= + return fwd_inputs, None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) storage_ref_to_idx: dict[StorageWeakRef, list[int]] = collections.defaultdict(list) base_args = [] other_args = [] +<<<<<<< HEAD base_args_descs = [] other_args_descs = [] for i, (inpt, source) in enumerate(zip(fwd_inputs, fwd_inputs_descs)): +======= + for i, inpt in enumerate(fwd_inputs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(inpt, Tensor): storage_ref = StorageWeakRef(inpt.untyped_storage()) storage_ref_to_idx[storage_ref].append(i) else: other_args.append(inpt) +<<<<<<< HEAD other_args_descs.append(source) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Note [Synthetic Base Info Metadata] # This list contains metadata that tells you what the i'th argument in the inner calling convention should be. # It's either: @@ -1397,9 +1675,12 @@ def _same_dtype_views(view1, view2): other_args.extend( fwd_inputs[curr_idx] for curr_idx in aliased_input_indices ) +<<<<<<< HEAD other_args_descs.extend( fwd_inputs_descs[curr_idx] for curr_idx in aliased_input_indices ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue # Here, we attempt to do a more complicated check to detect false aliasing @@ -1415,9 +1696,12 @@ def _same_dtype_views(view1, view2): other_args.extend( fwd_inputs[curr_idx] for curr_idx in aliased_input_indices ) +<<<<<<< HEAD other_args_descs.extend( fwd_inputs_descs[curr_idx] for curr_idx in aliased_input_indices ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue # We detected an input that was mutated, AND aliases with another input. @@ -1433,6 +1717,7 @@ def _same_dtype_views(view1, view2): # The "inputs that are aliased but have different differentiable bases" case # is more complicated and hopefully pretty rare. Not currently handled. if not is_inference: +<<<<<<< HEAD assert _are_differentiable_views(view1, view2), ( "aot_autograd() does not yet handle non-differentiable view input mutations." ) @@ -1443,13 +1728,28 @@ def _same_dtype_views(view1, view2): ) non_none_bases = [ (i, fwd_inputs[i]._base) +======= + assert _are_differentiable_views( + view1, view2 + ), "aot_autograd() does not yet handle non-differentiable view input mutations." + # Regenerating views when reinterpreting complex / real tensors seems non-trivial, + # not handling for now + assert _same_dtype_views( + view1, view2 + ), "aot_autograd() does not yet handle input mutations on views with different dtypes." + non_none_bases = [ + fwd_inputs[i]._base +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for i in aliased_input_indices if fwd_inputs[i]._base is not None ] aliases_with_none_bases = [ fwd_inputs[i] for i in aliased_input_indices if fwd_inputs[i]._base is None ] +<<<<<<< HEAD synthetic_base_desc: AOTInput +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if len(non_none_bases) == 0: # Case where none of the aliases have a ._base # we generate a synthetic base without gradients, and generate views off of it @@ -1476,7 +1776,11 @@ def _same_dtype_views(view1, view2): # to have incorrect sizes. example_idx = aliased_input_indices[0] example_alias = fwd_inputs[example_idx] +<<<<<<< HEAD # Note that this function is reused at both trace time and runtime. +======= + # Note that this function is re-used at both trace time and runtime. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # At trace time, we're under a FakeMode so synthetic_base becomes a FakeTensor. synthetic_base = torch.empty( (0,), dtype=example_alias.dtype, device=example_alias.device @@ -1484,6 +1788,7 @@ def _same_dtype_views(view1, view2): # We don't actually have a convenient way of going from storage -> tensor, # So using set_() here (we suffer some minor overhead, but this case is rare). synthetic_base.set_(example_alias.untyped_storage()) +<<<<<<< HEAD synthetic_base_desc = SyntheticBaseAOTInput(fwd_inputs_descs[example_idx]) else: # Case where all of the aliases require gradients, and have the same _base. @@ -1499,6 +1804,20 @@ def _same_dtype_views(view1, view2): ) base_args.append(synthetic_base) base_args_descs.append(synthetic_base_desc) +======= + else: + # Case where all of the aliases require gradients, and have the same _base. + synthetic_base = non_none_bases[0] + for other_base in non_none_bases[1:]: + assert ( + other_base is synthetic_base + ), "aot_autograd() does not yet handle non-differentiable view input mutations." + for alias in aliases_with_none_bases: + assert ( + alias is synthetic_base + ), "aot_autograd() does not yet handle non-differentiable view input mutations." + base_args.append(synthetic_base) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for curr_view_idx in aliased_input_indices: curr_view = fwd_inputs[curr_view_idx] base_idx = len(base_args) - 1 @@ -1508,7 +1827,11 @@ def _same_dtype_views(view1, view2): if len(base_args) == 0: assert len(other_args) == len(fwd_inputs) # If no synthetic bases are necessary, just return the original inputs. +<<<<<<< HEAD return fwd_inputs, fwd_inputs_descs, None +======= + return fwd_inputs, None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: from torch.fx.experimental.symbolic_shapes import SymIntEqByExpr @@ -1524,7 +1847,10 @@ def make_hashable(arg): # (2) Metadata telling functionalization how to generate the inner argument list given the outer calling convention. # We post-process it into a list, where meta[i] tells you info about the i'th argument in the inner calling convention. args_to_functionalization = base_args + other_args +<<<<<<< HEAD args_to_functionalization_descs = base_args_descs + other_args_descs +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Map each argument into its old index. # There may be some repeated arguments, so we collect their indices in a list. @@ -1551,11 +1877,15 @@ def make_hashable(arg): # Quick assert: every argument in the inner calling convention should be accounted for. for x in post_processed_calling_convention_meta: assert x != -1 +<<<<<<< HEAD return ( args_to_functionalization, args_to_functionalization_descs, post_processed_calling_convention_meta, ) +======= + return args_to_functionalization, post_processed_calling_convention_meta +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Note: [Backward graph lazy lowering] @@ -1563,7 +1893,11 @@ def make_hashable(arg): # unless we suspect that inductor might specialize and insert additional guards. When we do lazy # lowering, we stash the AOT backward graph (bw_module) in this class. # +<<<<<<< HEAD # Lowering passes are performed on a deepcopy of this bw_module due to compatibility +======= +# Lowering passes are performed on a deepcopy of this bw_module due to compatbility +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # with compiled autograd. See: https://github.com/pytorch/pytorch/pull/149229#discussion_r2002122645. @dataclass class AutogradLazyBackwardCompileInfo: @@ -1886,7 +2220,11 @@ def coerce_to_expected_memory_format(x: torch.Tensor, memory_format: MemoryForma return x # Empty_strided creates a raw Tensor. +<<<<<<< HEAD # We are guaranteed that only raw Tensors has expected size and stride. +======= + # We are guranteed that only raw Tensors has expected size and stride. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Subclasses have only expected memory_format. restrided = torch.empty_strided( size=expected_size, @@ -1949,11 +2287,20 @@ def process_runtime_tangent(x, meta: Union[PlainTensorMeta, SubclassCreationMeta expected_meta = meta.meta runtime_type = type(x) +<<<<<<< HEAD # When we're inside compiled autograd's AOTDispatcher step, # regular Tensors look like FunctionalTensors. # Tensor subclasses still look like Tensor subclasses though. if isinstance(x, torch._subclasses.functional_tensor.FunctionalTensor): runtime_type = torch.Tensor +======= + if torch._dynamo.compiled_autograd.in_compiled_autograd_region: + # When we're inside compiled autograd's AOTDispatcher step, + # regular Tensors look like FunctionalTensors. + # Tensor subclasses still look like Tensor subclasses though. + if isinstance(x, torch._subclasses.functional_tensor.FunctionalTensor): + runtime_type = torch.Tensor +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) runtime_meta = None runtime_subclass_keys: Sequence[str] = [] @@ -2329,12 +2676,19 @@ def backward(double_ctx, *args): @staticmethod def _backward_impl(ctx, all_args): +<<<<<<< HEAD from torch._inductor.async_compile import async_compile_pool_manager # compiled autograd reimplements this function at proxy_call_aot_backward assert not backward_state_indices, ( "BackwardState requires CompiledAutograd" ) +======= + # compiled autograd reimplements this function at proxy_call_aot_backward + assert ( + not backward_state_indices + ), "BackwardState requires CompiledAutograd" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ctx.maybe_clear_saved_tensors() saved_tensors_use_once = ( @@ -2347,6 +2701,7 @@ def _backward_impl(ctx, all_args): lazy_backward_info, AutogradLazyBackwardCompileInfo ) +<<<<<<< HEAD if ( hasattr(lazy_backward_info, "saved_context") and lazy_backward_info.saved_context is not None @@ -2385,6 +2740,8 @@ def _backward_impl(ctx, all_args): ddp_ctx.curr_bucket -= 1 lazy_backward_info.saved_context.fw_metadata = curr_fw_meta +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not saved_tensors_use_once: fw_metadata.bw_donated_idxs = [] # Update bw_donated_idxs if using lazy_backward_info from `aot_dispatch_autograd` @@ -2410,7 +2767,10 @@ def _backward_impl(ctx, all_args): with ( tracing(saved_context), compile_context(saved_compile_context), +<<<<<<< HEAD async_compile_pool_manager(), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) context(), track_graph_compiling(aot_config, "backward"), metrics_context, @@ -2518,6 +2878,7 @@ def debug_compiled_function(args: list[Any]): def pre_compile( wrappers: list[CompilerWrapper], +<<<<<<< HEAD flat_fn: TraceFn, flat_args: list[FxValue], flat_args_descs: list[AOTInput], @@ -2525,15 +2886,30 @@ def pre_compile( *, fw_metadata: ViewAndMutationMeta, ) -> tuple[TraceFn, list[FxValue], list[AOTInput], ViewAndMutationMeta]: +======= + flat_fn: Callable, + flat_args: list[Any], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, +) -> tuple[Callable, list[Tensor], ViewAndMutationMeta]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Runs a sequence of wrappers on the given function and arguments. Mutates wrappers in place. """ for wrapper in wrappers: +<<<<<<< HEAD flat_fn, flat_args, flat_args_descs, fw_metadata = wrapper.pre_compile( flat_fn, flat_args, flat_args_descs, aot_config, fw_metadata=fw_metadata ) return flat_fn, flat_args, flat_args_descs, fw_metadata +======= + flat_fn, flat_args, fw_metadata = wrapper.pre_compile( + flat_fn, flat_args, aot_config, fw_metadata=fw_metadata + ) + return flat_fn, flat_args, fw_metadata +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def post_compile( diff --git a/torch/_functorch/_aot_autograd/schemas.py b/torch/_functorch/_aot_autograd/schemas.py index a65351c31934e..fc51632518c33 100644 --- a/torch/_functorch/_aot_autograd/schemas.py +++ b/torch/_functorch/_aot_autograd/schemas.py @@ -4,6 +4,7 @@ input/output types, metadata, config, function signatures etc. """ +<<<<<<< HEAD from __future__ import annotations import collections @@ -48,6 +49,33 @@ from .graph_capture_wrappers import JointFnHandle +======= +import collections +import dataclasses +import functools +import itertools +from collections.abc import Iterable, Sequence +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Callable, NewType, Optional, Union + +import torch +import torch.utils._pytree as pytree +from torch._guards import Source +from torch._ops import OpOverload +from torch._subclasses import FakeTensor +from torch._subclasses.fake_tensor import is_fake +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + +from .. import config +from .functional_utils import ( + _check_if_mutation_can_be_in_graph, + FunctionalTensorMetadataEq, +) +from .utils import strict_zip + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) zip = strict_zip @@ -113,6 +141,7 @@ class OutputAliasInfo: dynamic_dims: Optional[set[int]] # requires_grad requires_grad: bool +<<<<<<< HEAD # Sequence of ViewMeta objects. # # Provides us the means to re-run view functions on other tensors. @@ -121,6 +150,17 @@ class OutputAliasInfo: # we compare the ViewMeta elements appropriately, i.e. their type and # the elements returned by the `as_tuple()` call. view_meta_sequence: Optional[ViewMetaSequence] = None +======= + # FunctionalTensorWrapper that represents this output. + # + # Provides us the means to replay views from it. + # + # We need to wrap the actual FunctionalTensorWrapper with this class so that + # we only compare the tensor's metadata. That's because with the transformations + # of the model throughout AOTAutograd, the sequence of ViewMeta and the base + # tensor might change. + functional_tensor: Optional[FunctionalTensorMetadataEq] = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class MutationType(Enum): @@ -184,7 +224,11 @@ class MemoryFormatMeta: memory_format: Optional[torch.memory_format] = None @staticmethod +<<<<<<< HEAD def from_tensor(t: torch.Tensor) -> Optional[MemoryFormatMeta]: +======= + def from_tensor(t: torch.Tensor) -> Optional["MemoryFormatMeta"]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # We only memorize expected memory format for # 1. Traceable wrapper subclasses # We can not create restrided subclass tensor, as torch.empty_strided works only with dense tensors. @@ -242,7 +286,11 @@ class SubclassCreationMeta: # arg_count is inclusive of the arg_counts of any # inner tensor subclasses: If I have a TwoTensor and # both of its inner elements are TwoTensors, then the +<<<<<<< HEAD # arg_count of the outer-most subclass will be 4 +======= + # arg_count of the outer-most sublass will be 4 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) arg_count: int # Mark where or not symints were included. This flag is only used in one assertion # in "wrap_tensor_subclasses" @@ -250,7 +298,11 @@ class SubclassCreationMeta: # meta and attrs are produced by the subclass's __tensor_flatten__. # We need to keep them around along with outer_size / outer_stride to plumb them # into __tensor_unflatten__ +<<<<<<< HEAD attrs: dict[str, Union[SubclassCreationMeta, PlainTensorMeta]] +======= + attrs: dict[str, Union["SubclassCreationMeta", PlainTensorMeta]] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) outer_size: Iterable[Union[None, int, torch.SymInt]] outer_stride: Iterable[Union[None, int, torch.SymInt]] meta: Any @@ -402,12 +454,18 @@ class ViewAndMutationMeta: # metadata pass of the user's forward function. # Their only use today is to pass them as a best-guess for tangents when tracing the joint. # Stashing them as part of our "metadata" makes it simpler if we want to run our analysis +<<<<<<< HEAD # pass once, and reuse the output throughout AOTAutograd traced_tangents: list[Any] # TODO doc traced_tangents_descs: list[AOTInput] +======= + # pass once, and re-use the output throughout AOTAutograd + traced_tangents: list[Any] + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Each of these is a list telling us about subclasses for the inputs/outputs/grad_outs # They are used throughout AOTDispatch to tell us how to generate a list of subclass tensors, # Given a (potentially larger) list of plain torch tensors. @@ -660,6 +718,20 @@ def extract_metadata(t): self.traced_tangent_metas = [extract_metadata(t) for t in self.traced_tangents] # Clear traced tangents at runtime self.traced_tangents = [] +<<<<<<< HEAD +======= + new_output_info = [] + for out in self.output_info: + if config.view_replay_for_aliased_outputs: + new_out = out + else: + # If we're not using view_replay, remove the functional tensor. + # Functional tensors are unfortunately not serializable, + # so doing this is required for AOTAutograd caching. + new_out = dataclasses.replace(out, functional_tensor=None) + new_output_info.append(new_out) + self.output_info = new_output_info +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for inp_meta in self.subclass_inp_meta: if isinstance(inp_meta, SubclassCreationMeta): inp_meta.make_runtime_safe() @@ -699,7 +771,11 @@ def __eq__(self, other): and len(self.traced_tangents) == len(other.traced_tangents) and all( x.shape == y.shape and x.dtype == y.dtype +<<<<<<< HEAD for x, y in zip(self.traced_tangents, other.traced_tangents) +======= + for x, y, in zip(self.traced_tangents, other.traced_tangents) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) and self.num_backward_tokens == other.num_backward_tokens ) @@ -736,9 +812,15 @@ class SubclassMeta: # in case we made incorrect assumptions about the subclass-ness of our grad_outputs # # Optional field because we don't compute for inference graphs +<<<<<<< HEAD grad_input_metas: Optional[list[Union[PlainTensorMeta, SubclassCreationMeta]]] = ( None ) +======= + grad_input_metas: Optional[ + list[Union[PlainTensorMeta, SubclassCreationMeta]] + ] = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __init__(self) -> None: # The fields in this class get set after its construction. @@ -813,7 +895,10 @@ class GraphSignature: # "graph outputs that correspond to updated buffers" # to the FQN names of those mutated buffers. buffers_to_mutate: dict[GraphOutputName, FQN] +<<<<<<< HEAD parameters_to_mutate: dict[GraphOutputName, FQN] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) user_inputs_to_mutate: dict[GraphOutputName, GraphInputName] in_spec: pytree.TreeSpec @@ -837,10 +922,16 @@ def from_tracing_metadata( named_buffers: list[str], num_user_inputs: int, num_user_outputs: int, +<<<<<<< HEAD trace_joint: bool, loss_index: Optional[int], backward_signature: Optional[BackwardSignature], ) -> GraphSignature: +======= + loss_index: Optional[int], + backward_signature: Optional[BackwardSignature], + ) -> "GraphSignature": +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) graph_inputs = graph_input_names graph_outputs = graph_output_names parameters = list(named_parameters) @@ -883,9 +974,14 @@ def from_tracing_metadata( mutations = [] for idx, input_info in enumerate(view_mutation_metadata.input_info): if input_info.mutates_data: +<<<<<<< HEAD if trace_joint: # Only buffers can be mutated, not parameters assert idx >= len(parameters) +======= + # Only buffers can be mutated, not parameters + assert idx >= len(parameters) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mutations.append(names[idx + num_tokens]) assert len(mutations) == view_mutation_metadata.num_mutated_inp_runtime_indices @@ -898,16 +994,24 @@ def from_tracing_metadata( user_inputs_to_mutate = {} buffers_to_mutate = {} +<<<<<<< HEAD parameters_to_mutate = {} +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for output_name, mutation_name in outputs_to_mutations.items(): if mutation_name in user_inputs: user_inputs_to_mutate[output_name] = mutation_name else: +<<<<<<< HEAD assert mutation_name in buffers or mutation_name in parameters if mutation_name in buffers: buffers_to_mutate[output_name] = mutation_name else: parameters_to_mutate[output_name] = mutation_name +======= + assert mutation_name in buffers + buffers_to_mutate[output_name] = mutation_name +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) start, stop = stop, stop + num_user_outputs user_outputs = graph_outputs[start:stop] @@ -928,7 +1032,10 @@ def from_tracing_metadata( inputs_to_parameters=inputs_to_parameters, # type: ignore[arg-type] user_inputs_to_mutate=user_inputs_to_mutate, buffers_to_mutate=buffers_to_mutate, # type: ignore[arg-type] +<<<<<<< HEAD parameters_to_mutate=parameters_to_mutate, # type: ignore[arg-type] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) in_spec=in_spec, out_spec=out_spec, backward_signature=backward_signature, @@ -974,16 +1081,20 @@ class AOTConfig: # Used only by standalone_compile. ignore_shape_env: bool = False precompile_backend_id: Optional[str] = None +<<<<<<< HEAD force_non_lazy_backward_lowering: bool = False # This config makes sure to check certain things like # mutating input with req_grad in export joint tracing. export_trace_joint: bool = False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __post_init__(self): if self.pre_dispatch: assert self.is_export, "Can only have pre_dispatch IR for export." +<<<<<<< HEAD # TODO: types here # plain_tensor_trace_fn, when it is joint, has tuple structure on the trace # info too! @@ -1297,3 +1408,9 @@ def graph_module(self): @graph_module.setter def graph_module(self, value): self._aot_graph_capture.graph_module = value +======= +SubclassTracingInfo = collections.namedtuple( + "SubclassTracingInfo", + ["plain_tensor_trace_fn", "plain_tensor_args", "maybe_subclass_meta"], +) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_functorch/_aot_autograd/subclass_utils.py b/torch/_functorch/_aot_autograd/subclass_utils.py index d06f727e25aa9..a9bb2c6be048c 100644 --- a/torch/_functorch/_aot_autograd/subclass_utils.py +++ b/torch/_functorch/_aot_autograd/subclass_utils.py @@ -9,7 +9,10 @@ import typing from collections.abc import Iterable from typing import Any, Callable, Optional, TypeVar, Union +<<<<<<< HEAD from typing_extensions import TypeGuard +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch.utils._pytree as pytree @@ -18,6 +21,7 @@ from torch.types import IntLikeType from torch.utils._python_dispatch import is_traceable_wrapper_subclass +<<<<<<< HEAD from .descriptors import ( AOTInput, AOTOutput, @@ -31,6 +35,9 @@ ) from .schemas import ( FxValue, +======= +from .schemas import ( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MutationType, PlainTensorMeta, SubclassCreationMeta, @@ -125,8 +132,13 @@ def create_subclass_metadata( new_start_idx = ( new_start_idx +<<<<<<< HEAD + count_symints * len(enumerate_filter_symints(a.size())) + count_symints * len(enumerate_filter_symints(a.stride())) +======= + + count_symints * len(filter_symints(a.size())) + + count_symints * len(filter_symints(a.stride())) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return ( @@ -180,12 +192,21 @@ def create_subclass_meta( return infos +<<<<<<< HEAD def enumerate_filter_symints(lst: Iterable[IntLikeType]) -> list[tuple[int, SymInt]]: # Capture all SymInts from the iterable. def symint_check(s: IntLikeType) -> TypeGuard[SymInt]: return isinstance(s, SymInt) and not s.node.is_nested_int() return [(i, s) for i, s in enumerate(lst) if symint_check(s)] +======= +def filter_symints(lst: Iterable[IntLikeType]): + # Capture all SymInts from the iterable. + def symint_check(s: IntLikeType) -> bool: + return isinstance(s, SymInt) and not s.node.is_nested_int() + + return [s for s in lst if symint_check(s)] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def compute_symint_placeholders(lst: Iterable[Union[None, int, SymInt]]) -> list[bool]: @@ -193,12 +214,15 @@ def compute_symint_placeholders(lst: Iterable[Union[None, int, SymInt]]) -> list return [s is None for s in lst] +<<<<<<< HEAD # Intended to make it easier to define function that is # either (AOTInput -> AOTInput) or (AOTOutput -> AOTOutput) # but not the other combos AOTDescriptor = TypeVar("AOTDescriptor", AOTInput, AOTOutput) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This function takes in a pytree of arguments and unwraps any tensor # subclasses. # @@ -213,6 +237,7 @@ def compute_symint_placeholders(lst: Iterable[Union[None, int, SymInt]]) -> list # primals (but not tangents) on entry to the forward. See the runtime version of # this function below. def unwrap_tensor_subclasses( +<<<<<<< HEAD wrapped_args: list[FxValue], wrapped_args_descs: list[AOTDescriptor], *, @@ -229,12 +254,24 @@ def flatten_subclass( if not is_traceable_wrapper_subclass(t): out[0].append(t) out[1].append(desc) +======= + wrapped_args: list[Union[Tensor, int]], + *, + append_symints: bool, +): + def flatten_subclass(t: Union[Tensor, int], *, out=None): + # unwrap a subclass into plain tensors and their size/stride if "append_symint" + # is True + if not is_traceable_wrapper_subclass(t): + out.append(t) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return attrs, _ = t.__tensor_flatten__() for attr in attrs: inner_tensor = getattr(t, attr) +<<<<<<< HEAD n_desc: Any = ( SubclassGetAttrAOTInput(desc, attr) if isinstance(desc, AOTInput) @@ -261,6 +298,20 @@ def flatten_subclass( flatten_subclass(typing.cast(Tensor, x), desc, out=(xs_inner, descs_inner)) return xs_inner, descs_inner +======= + flatten_subclass(inner_tensor, out=out) + + if append_symints: + out.extend(filter_symints(t.size())) + out.extend(filter_symints(t.stride())) + + xs_inner: list[Union[int, Tensor, SymInt]] = [] + + for x in wrapped_args: + flatten_subclass(typing.cast(Tensor, x), out=xs_inner) + + return xs_inner +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # subclass_metas is needed at runtime to compute which indices are symints in @@ -328,9 +379,13 @@ def unwrap_tensor_subclasses_with_indices_to_original(wrapped_args): ret_unwrapped = [] ret_indices_to_original = [] for i, a in enumerate(wrapped_args): +<<<<<<< HEAD a_unwrapped, _ = unwrap_tensor_subclasses( [a], [DummyAOTInput(9999)], append_symints=False ) +======= + a_unwrapped = unwrap_tensor_subclasses([a], append_symints=False) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ret_unwrapped.extend(a_unwrapped) n = len(a_unwrapped) ret_indices_to_original.extend([i] * n) @@ -347,8 +402,13 @@ def remap_unwrapped_subclass_arg_indices(wrapped_args, static_input_indices): if is_traceable_wrapper_subclass(arg): num_indices = ( len(get_plain_tensors(typing.cast(Tensor, arg), out=[])) +<<<<<<< HEAD + len(enumerate_filter_symints(arg.size())) + len(enumerate_filter_symints(arg.stride())) +======= + + len(filter_symints(arg.size())) + + len(filter_symints(arg.stride())) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) for _ in range(num_indices): @@ -412,7 +472,11 @@ def wrap_tensor_subclasses( # we computed subclass metadata on every forward output, but this did **not** include activations # created by the partitioner. # as a result, `unwrapped_args` here will correspond to (*unwrapped_user_fw_outs, *activations), +<<<<<<< HEAD # but `subclass_metas` will only correspond to subclass metadata on `user_fw_outs`. +======= + # but `subclass_metas` will only correspond to subclass metatadata on `user_fw_outs`. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # We then need to make sure that we return (*wrapped_user_fw_outs, *activations). if num_fw_outs_saved_for_bw is not None: assert len(unwrapped_args) == num_args_tallied + num_fw_outs_saved_for_bw, ( @@ -425,9 +489,15 @@ def wrap_tensor_subclasses( return wrapped_args + activations return tuple(list(wrapped_args) + list(activations)) else: +<<<<<<< HEAD assert len(unwrapped_args) == num_args_tallied, ( f"Expected {len(unwrapped_args)} == {num_args_tallied}" ) +======= + assert ( + len(unwrapped_args) == num_args_tallied + ), f"Expected {len(unwrapped_args)} == {num_args_tallied}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return tuple(wrapped_args) @@ -438,7 +508,11 @@ def wrap_tensor_subclasses( def wrap_tensor_subclasses_maybe_joint( unwrapped_args, *, is_joint_structure: bool, meta: ViewAndMutationMeta ) -> Union[tuple[Any, ...], list[Any]]: +<<<<<<< HEAD # Since this function is reused for both inference and joint graphs, +======= + # Since this function is re-used for both inference and joint graphs, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if is_joint_structure: assert isinstance(unwrapped_args, tuple) and len(unwrapped_args) == 2 assert isinstance(unwrapped_args[0], (tuple, list)) and isinstance( diff --git a/torch/_functorch/_aot_autograd/traced_function_transforms.py b/torch/_functorch/_aot_autograd/traced_function_transforms.py new file mode 100644 index 0000000000000..8b8b5d11884ab --- /dev/null +++ b/torch/_functorch/_aot_autograd/traced_function_transforms.py @@ -0,0 +1,924 @@ +# mypy: allow-untyped-defs +""" +This module is responsible for transforming functions to be traced into a form +that is easier for the downstream infra (e.g. Autograd, FX, AOTAutograd analysis) +to handle. + +It does so by: +1. functionalization (including RNG functionalzation) +2. creating a joint graph when required +3. transforming mutations into extra outputs +4. dispatching subclasses +""" + +import warnings +from contextlib import contextmanager, nullcontext +from functools import wraps +from typing import Any, Callable, Union +from unittest.mock import patch + +import torch +import torch.fx.traceback as fx_traceback +import torch.utils._pytree as pytree +from torch import Tensor +from torch._decomp.decompositions_for_rng import PhiloxStateTracker +from torch._guards import detect_fake_mode +from torch._prims_common import CUDARngStateHelper +from torch.fx.experimental.proxy_tensor import ( + maybe_disable_thunkify, + maybe_enable_thunkify, +) +from torch.fx.experimental.symbolic_shapes import ( + guard_or_true, + PropagateUnbackedSymInts, + sym_eq, +) +from torch.nn.utils import stateless + +from .. import config +from .collect_metadata_analysis import run_functionalized_fw_and_collect_metadata +from .functional_utils import ( + _check_if_mutation_can_be_in_graph, + are_all_mutations_hidden_from_autograd, + are_all_mutations_under_no_grad_or_inference_mode, + from_fun, + has_data_mutation, + has_metadata_mutation, + is_fun, + sync_functional_tensor, + to_fun, + was_inductor_storage_resized, +) +from .logging_utils import setup_stacktrace_preservation_hooks +from .schemas import ( + AOTConfig, + MutationType, + OutputType, + SubclassMeta, + SubclassTracingInfo, + ViewAndMutationMeta, +) +from .subclass_utils import ( + create_subclass_meta, + remap_unwrapped_subclass_arg_indices, + requires_subclass_dispatch, + unwrap_tensor_subclasses, + wrap_tensor_subclasses_maybe_joint, +) +from .utils import maybe_to_fresh_input + + +# This function returns a new function that returns mutated inputs as outputs. +# if keep_data_input_mutations is set, then we assume that data-only mutations +# will be left in the graph, and we only return metadata-mutated inputs as outputs. +def fn_input_mutations_to_outputs( + fn: Callable, + meta: ViewAndMutationMeta, + keep_data_input_mutations: bool, +) -> Any: + @wraps(fn) + def inner_fn(*args): + outs = fn(*args) + assert len(meta.output_info) == len(outs) + # The compiled fw will return mutated input tensors, *including* metadata-only mutation. + # However, if keep_data_input_mutations is set, the compiled fw only needs to return metadata-mutated inputs. + # (because data-only input mutations are handled directly in the compiled graph) + mutated_inputs_to_return = [ + x for (i, x) in enumerate(args) if i in meta.mutated_inp_runtime_indices + ] + return *mutated_inputs_to_return, *outs + + return inner_fn + + +# This function takes in a fn with external aliasing and mutation, +# and returns a new fn with no external aliasing and mutation, +# as needed for autograd. +# The main transformations are: +# - Return mutated inputs as extra outputs +# - Clone mutated inputs that require gradients, +# because autograd will require us to pass the pre-mutated inputs into autograd.grad +# - Return intermediate bases of outputs as additional outputs, +# needed to appease autograd.Function +# The new function returns: +# (1) The updated outputs +# (2) A boolean mask of len(new_fn_outputs), +# that can be used to tell autograd.grad which outputs should get tangents +# if we trace the backward. +def fn_prepped_for_autograd( + fn: Callable, + meta: ViewAndMutationMeta, +) -> Any: + @wraps(fn) + def inner_fn(*args): + args_maybe_cloned = [ + maybe_to_fresh_input(i, t, meta) for i, t in enumerate(args) + ] + + outs = fn(*args_maybe_cloned) + assert isinstance(outs, (tuple, list)) + outs = list(outs) + assert len(meta.output_info) == len(outs) + + mutated_inputs_to_return = [ + x + for (i, x) in enumerate(args_maybe_cloned) + if i in meta.mutated_inp_runtime_indices + ] + + intermediate_bases = [] + for i, (o, info) in enumerate(zip(outs, meta.output_info)): + if info.output_type == OutputType.alias_of_intermediate_save_as_output: + intermediate_bases.append(o._base) + + assert meta.num_intermediate_bases == len(intermediate_bases) + + # the compiled forward should return (mutated_inputs, user_outs, intermediate_bases) + fw_outs_to_return = *mutated_inputs_to_return, *outs, *intermediate_bases + + # Also return a boolean mask specifying which outputs to this function will be used as tangents + mutated_inputs_grad_mask = [ + meta.input_info[meta.mutated_inp_runtime_indices[i]].mutates_data + and meta.input_info[meta.mutated_inp_runtime_indices[i]].requires_grad + for (i, x) in enumerate(mutated_inputs_to_return) + ] + + # Pass any (non-aliased) outputs in as tangents, since they'll be returned as outputs in the fw + # For outputs that are aliases of intermediates, we will have returned the output's _base as an output in the graph instead, + # which we *should* send to grad() + output_grad_mask = [ + meta.output_info[i].output_type + in [ + OutputType.non_alias, + OutputType.unsafe_view_alias, + OutputType.custom_function_view, + ] + # Also, only tensor outputs should participate in the backward + # (in particular, Symint outputs in the forward graph shouldn't get tangents) + and issubclass(meta.output_info[i].raw_type, Tensor) + and meta.output_info[i].requires_grad + for (i, x) in enumerate(outs) + ] + + intermediate_base_grad_mask = [True for _ in range(len(intermediate_bases))] + + out_grad_mask = ( + mutated_inputs_grad_mask + output_grad_mask + intermediate_base_grad_mask + ) + assert len(out_grad_mask) == len(fw_outs_to_return) + + # Take care to grab and sync the updated inputs from primals_after_cloning (the inputs we actually mutate!) + # and not primals (the preserved inputs, pre-mutation, that we pass to grad()) + # This is annoying: our joint function needs to be aware of functionalization + # (syncing mutated inputs before calling autograd.grad()) + # In theory, we could make the autograd engine do this automatically, although that probably isn't any cleaner. + for arg in args_maybe_cloned: + if not isinstance(arg, Tensor): + continue + sync_functional_tensor(arg) + + return fw_outs_to_return, out_grad_mask + + return inner_fn + + +# Given a fn, computes the joint. +# NOTE: fn is expects the following behavior: +# (1) fn() needs to return a tuple of (outs, mask), +# where `mask` tells us which outputs are meant to have tangents. +# we don't know this info automatically, because we don't actually want to blindly +# compute tangents for every output that requires grad. +# Specifically, outputs that alias inputs won't participate in the backward and get tangents. +# (2) fn() cannot mutate any inputs that require gradient. +# otherwise, when we compute autograd.grad(), we will not take those input mutations into account +# (the way this is handled is that we ensure any inputs that normally get mutated are cloned first) +def create_joint(fn: Callable, *, aot_config: AOTConfig) -> Any: + def inner_fn(primals: list[Any], tangents: list[Any]): + outs, tangent_mask = fn(*primals) + + assert len(tangent_mask) == len(outs) + outs_to_grad = [ + o for needs_tangent, o in zip(tangent_mask, outs) if needs_tangent + ] + assert len(outs_to_grad) == len(tangents) + + # Get the inputs that need gradients + grad_primals = [] + inputs_needs_grads = [] + # Note that we're not using primals here, + # being carefully not to pass any mutated inputs into autograd.grad() + for p in primals: + is_grad_tensor = isinstance(p, Tensor) and p.requires_grad + inputs_needs_grads.append(is_grad_tensor) + if is_grad_tensor: + grad_primals.append(p) + + # Get the outputs that need gradients + needed_outs = [] + needed_tangents = [] + for out, tangent in zip(outs_to_grad, tangents): + if isinstance(out, Tensor) and out.requires_grad: + # A bit sketchy, but fixes e.g. test_aot_autograd_exhaustive_matmul_cpu_float32 + # The issue is that we are sensitive to decomps that don't accurately maintain + # their output's _base.shape compared to eager mode, and this helps mitigate a bit. + # The guard_or_true also sketchy; if unbacked + # symints are involved, we're just going to assume that the + # decomps setup the base shape correctly + + # Return out if the result of out.shape==tangent.shape is unknown or known to be true. + # otherwise if its a known false return out.view(tangent.shape). + needed_outs.append( + out + if guard_or_true(sym_eq(out.shape, tangent.shape)) + else out.view(tangent.shape) + ) + needed_tangents.append(tangent) + + setup_stacktrace_preservation_hooks([out.grad_fn for out in needed_outs]) + + if config.functionalize_rng_ops: + PhiloxStateTracker.mark_beginning_of_backward() + backward_out: tuple[Tensor, ...] = () + # Call the backwards pass + if grad_primals: + functional_tensor_mode = torch.utils._python_dispatch._detect_infra_mode( + torch._C._TorchDispatchModeKey.FUNCTIONAL + ) + if functional_tensor_mode is not None: + # Side-Effect Tokens: + # We want to have independent chains of tokens for forward and backward. + # functional_tensor_mode._tokens is used by both. + # We memoize the result tokens of forward in functional_tensor_mode._tokens_forward_output, + # to return them as joint graph outputs. + # We clean functional_tensor_mode._tokens before backward, to prevent reuse of forward tokens in backward. + # Joint graph tracing allows tokens discovery, + # So all the tokens in backward will be created and added as a graph inputs during tracing. + functional_tensor_mode._tokens_forward_output = ( + functional_tensor_mode._tokens + ) + functional_tensor_mode._tokens = {} + + with set_partitioner_tag_is_backward(), fx_traceback.preserve_node_meta(): + # for full graph export, we always export a joint graph where we assume no tangents are needed. + if aot_config.no_tangents: + assert len(needed_tangents) == 1 and needed_tangents[0].numel() == 1 + backward_out = torch.autograd.grad( + needed_outs, + grad_primals, + allow_unused=True, + ) + else: + backward_out = torch.autograd.grad( + needed_outs, + grad_primals, + grad_outputs=needed_tangents, + allow_unused=True, + ) + backward_out_iter = iter(backward_out) + return outs, [ + next(backward_out_iter) if i else None for i in inputs_needs_grads + ] + + def inner_fn_with_anomaly(*args): + with fx_traceback.preserve_node_meta(), warnings.catch_warnings(): + warnings.filterwarnings("ignore", "Anomaly Detection has been enabled.") + with torch.autograd.detect_anomaly(check_nan=False): + return inner_fn(*args) + + return inner_fn_with_anomaly + + +def create_functionalized_rng_ops_wrapper(func, args, trace_joint=True) -> Any: + # Functionalization of rng ops changes the calling convention of the joint graph. + # It goes from (primals, tangents) to (seed, offset, primals, tangents) + # At runtime, we pass on the current seed and offset. This is hidden from + # the user. + fake_mode = detect_fake_mode() + if fake_mode is None: + fake_mode = nullcontext() + + def override_get_rng_state(device: Union[int, str, torch.device] = "cuda"): + out = PhiloxStateTracker.get_state_as_tensor() + return out + + def override_set_rng_state(x, device: Union[int, str, torch.device] = "cuda"): + PhiloxStateTracker.set_state_from_tensor(x) + + def append_rng_offsets(args): + if trace_joint: + # args signature before: Tuple(fwd_outputs), Tuple(bwd_outputs) + # args signature after: Tuple(fwd_outputs, new_fwd_rng_offset), Tuple(bwd_offset, new_bwd_rng_offset) + return ( + (*args[0], PhiloxStateTracker.get_updated_fwd_offset()), + (*args[1], PhiloxStateTracker.get_updated_bwd_offset()), + ) + else: + # args signature before: Tuple(fwd_outputs) + # args signature after: Tuple(fwd_outputs, new_fwd_rng_offset) + return (*args, PhiloxStateTracker.get_updated_fwd_offset()) + + def traced_joint( + primals, tangents, fwd_seed, fwd_base_offset, bwd_seed, bwd_base_offset + ): + with patch("torch.cuda.get_rng_state", override_get_rng_state), patch( + "torch.cuda.set_rng_state", override_set_rng_state + ): + return append_rng_offsets(func(primals, tangents)) + + def traced_forward(*primals_fwd_seed_fwd_base_offset): + # The signature is (*primals, seed, offset) + with patch("torch.cuda.get_rng_state", override_get_rng_state), patch( + "torch.cuda.set_rng_state", override_set_rng_state + ): + return append_rng_offsets(func(*primals_fwd_seed_fwd_base_offset[:-2])) + + if trace_joint: + # Get the current seed and offset to setup tracing. + fwd_seed, fwd_base_offset = CUDARngStateHelper.get_torch_state_as_tuple( + fake_mode + ) + bwd_seed, bwd_base_offset = CUDARngStateHelper.get_torch_state_as_tuple( + fake_mode + ) + PhiloxStateTracker.record_state(fwd_seed, fwd_base_offset, "forward") + PhiloxStateTracker.record_state(bwd_seed, bwd_base_offset, "backward") + return traced_joint, ( + *args, + fwd_seed, + fwd_base_offset, + bwd_seed, + bwd_base_offset, + ) + else: + # Get the current seed and offset to setup tracing. + fwd_seed, fwd_base_offset = CUDARngStateHelper.get_torch_state_as_tuple( + fake_mode + ) + PhiloxStateTracker.record_state(fwd_seed, fwd_base_offset, "forward") + return traced_forward, (*args, fwd_seed, fwd_base_offset) + + +@contextmanager +def set_partitioner_tag(tag: str): + meta_key = "partitioner_tag" + assert fx_traceback.has_preserved_node_meta() + + original_val = fx_traceback.current_meta.get(meta_key, None) + fx_traceback.current_meta[meta_key] = tag + try: + yield + finally: + fx_traceback.current_meta[meta_key] = original_val + + +def set_partitioner_tag_is_backward(): + return set_partitioner_tag("is_backward") + + +def set_partitioner_tag_must_be_in_backward(): + return set_partitioner_tag("must_be_in_backward") + + +# This creates the final function that we want to trace using make_fx(), +# in both aot_dispatch_autograd and aot_dispatch_base. +# Preconditions: +# - fn corresponds to the user's fw function +# - fn arguments have been flattened, duplicate arguments have been handled +# - In the returned function, the "primals" arguments *includes* synthetic bases. +# This function does the work of functionalizing the input function, +# and performing copy_() calls at the end of the function if `keep_input_mutations` is set. +# The function returned has signature that is either: +# (1) "traced_fn(primals: List[Any])" if trace_joint is False +# (2) "traced_fn(primals: List[Any], tangents: List[Any])" if trace_joint is True +# Returns a new (functionalized) function, and updated arguments to call it with. +def create_functionalized_fn( + fn, + args, + *, + meta: ViewAndMutationMeta, + aot_config: AOTConfig, + trace_joint: bool, +) -> Any: + @wraps(fn) + def _functionalized_f_helper(*args): + with maybe_enable_thunkify(): + # See Note [Disabling Functionalize TLS Above Python Functionalization] + disable_above = torch._C._ExcludeDispatchKeyGuard( + torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) + ) + + with disable_above: + # The functionalization code here can potentially trigger traces + # into the graph, but we'd prefer to NOT do this, because if we + # trace them now, we will end up with FX nodes that don't have + # module stack annotations, which makes unflattener unhappy. + # Wrap inputs into functional wrappers + f_args = pytree.tree_map(to_fun, args) + + # Run the joint + f_outs = fn(*f_args) + + if trace_joint: + # We support a limited amount of mutation of graph inputs during the backward pass. + # (This is used e.g. by Float8, which needs to update buffers during the backward pass) + # Here, we perform extra checks for primals that were mutated in the **backward** + # We're doing the checks here instead of doing them with the rest of the input mutation handling because: + # - We need to detect inputs that were mutated in the backward **separately** from mutations that happened + # during the forward, because the handling is different: some input mutations from the the forward + # can be only handled in a fw-only runtime epilogue, and in theory if we wanted to handle those same + # types of mutations in the backward we would need a bw-only runtime epilogue. + # - We could in theory have our analysis pass differentiate mutations in the fw from mutations in + # the bw by running our analysis first on the fw-only graph, and then on the joint graph. This would + # require an extra round of tracing though, so it's more efficient to do in-line here. + assert ( + isinstance(args, tuple) + and len(args) == 2 + and isinstance(args[0], (list, tuple)) + ) + # Only look at mutations that happened to forward inputs (e.g. fw buffers that were saved for bw) + primals_before = args[0] + primals_after = pytree.tree_map(from_fun, f_args[0]) + for idx, (f_inpt, before, after, inpt_info) in enumerate( + zip(f_args[0], primals_before, primals_after, meta.input_info) + ): + # Store information about mutations in joint(for backward analysis) + joint_mutates_data = has_data_mutation(f_inpt) + + joint_mutates_metadata = has_metadata_mutation( + f_inpt, before, check_only_storage_mutation=False + ) + + # Ban metadata mutations on fw inputs during the bw + if not inpt_info.mutates_metadata: + assert ( + not joint_mutates_metadata + ), "Found a graph input that had its metadata mutated in the backward. This is not supported" + + # Ban storage resizing on fw inputs during the bw + if not inpt_info.mutation_inductor_storage_resize: + assert not was_inductor_storage_resized( + f_inpt + ), "Found a graph input that had storage resizing in the backward. This is not supported" + + # Allow data mutations on fw inputs during the bw, but only if they do not require grad + # So we can guarantee that we can keep the mutations in the graph + if ( + joint_mutates_data + and not inpt_info.mutates_data + and not inpt_info.mutates_storage_metadata + ): + # Not banning here mutations on inpt_info.requires_grad - + # we'll check at runtime and fail only when backward is under torch.is_grad_enabled (create_graph) + # Add node meta for copy_ for partitioner that this node should be in backward graph. + with torch.fx.traceback.preserve_node_meta(), set_partitioner_tag_must_be_in_backward(): + before.copy_(after) + meta.indices_of_inputs_that_requires_grad_with_mutations_in_bw.append( + idx + ) + # Now that we covered mutations to *forward* inputs during the backward, + # we also need to cover mutations to *backward-only* inputs during the backward (e.g. mutation to a grad_out). + # Today, we will just error in all cases of this happening unless someone needs us to support it. + tangents_before = args[1] + tangents_after = pytree.tree_map(from_fun, f_args[1]) + for f_inpt, before, after in zip( + f_args[1], tangents_before, tangents_after + ): + assert not has_metadata_mutation( + f_inpt, before, check_only_storage_mutation=False + ), "Found an input to the backward that had metadata mutated during the backward pass. This is not supported" + if has_data_mutation(f_inpt): + can_be_in_graph = _check_if_mutation_can_be_in_graph( + keep_input_mutations=True, + mutates_data=True, + mutates_metadata=False, + mutations_hidden_from_autograd=are_all_mutations_hidden_from_autograd( + f_inpt + ), + mutations_under_no_grad_or_inference_mode=are_all_mutations_under_no_grad_or_inference_mode( + f_inpt + ), + mutates_storage_metadata=False, + mutation_inductor_storage_resize=was_inductor_storage_resized( + f_inpt + ), + requires_grad=f_inpt.requires_grad, + ) + assert ( + can_be_in_graph + ), "a backward input that had data mutated in an autograd-aware way. This is not supported" + # Perform the input mutation + with torch.fx.traceback.preserve_node_meta(): + before.copy_(after) + + if aot_config.keep_inference_input_mutations: + # Note: This is a bit annoying. There's a layering issue here, where: + # (1) functionalization needs to operate on **synthetic base** inputs, before unpacking them into the "real" inputs. + # (2) For keep_input_mutations, we support tracing a call to copy_() directly on mutated inputs. + # However, we **only** want to support this for inputs that have data-only (and no metadata) mutations, + # because inductor (and backends in generally) would prefer not to see these (e.g. as_strided_(), resize_()). + # This makes it pretty difficult for this logic to operate on synthetic bases. + # (3) In addition, there are cases where it's significantly cheaper to perform the copy on the individual + # (unpacked) input aliases, instead of the synthetic base. + # Example case where (3) could be important: + # + # def f(x, y): + # x.mul_(2) + # y.mul_(3) + # return x, y + # a = torch.ones(1'000'000) + # x, y = out(a[0:9], a[1:10]) + # + # It would be much better to add copy_() calls into the graph for the two tiny slices, instead of materializing + # a giant "updated synthetic base" and copying into a's entire storage. + # + # For now, we are pessimistically not performing the optimization from (3); + # we will materialize an "updated" synthetic base, and copy it back to the synthetic input base. + # This allows us to factor aot autograd much more nicely, since only one area of the code needs to worry + # about synthetic bases. + for i, (inpt_old, inpt_f) in enumerate( + zip(args, f_args) if not trace_joint else zip(args[0], f_args[0]) + ): + if not isinstance(inpt_f, torch.Tensor): + continue + assert is_fun(inpt_f) + inpt_new = from_fun(inpt_f) + if ( + meta.input_info[i].mutation_type + == MutationType.MUTATED_IN_GRAPH + ): + # See Note [set_() Input Mutations in AOTAutograd] + # all mutations on the input must be under no_grad, so it is safe to put in the graph + # Here, we're saying that if an input experienced a set call, inp.set_(other), + # then we can effectively not have to worry about whether its data was mutated. + # There are 3 cases: + # (1) We mutate inp *after* the set_() call. other is a graph intermediate. + # In this case, we're not really mutating the input storage of "inp"; + # we're mutating the storage of an intermdiate value (other), + # and slamming that storage into the input tensor. So no data mutation is necessary. + # (2) We mutate inp *after* the set_() call. other is a graph *input*. + # In this case, the data mutation will be properly handled in the runtime + # epilogue during the processing of "other" + # (3) We mutate inp *before* the set_() call. + # This case is *not* currently handled. + if meta.input_info[i].mutates_storage_metadata: + with torch.no_grad(): + inpt_old.set_(inpt_new) + + # Note [Ordering of resize_() and set_()] + # Importantly: the common usage in FSDP is that we have a dummy parameter + # that sees a set_() and **Then** a resize_(). + # We must put those mutations into the graph in the same order, + # Since running them in the opposite order will have different behavior. + # We fully ban resize_() followed by set_() for now, although in principal + # we could support this + if meta.input_info[i].mutation_inductor_storage_resize: + # resizing is not supported on subclasses (we error earlier if this happens) + from torch._subclasses.functional_tensor import ( + FunctionalTensor, + ) + + assert isinstance(inpt_f, FunctionalTensor) + old_storage_size = torch._functionalize_get_storage_size( # type: ignore[attr-defined] + inpt_f.elem, before=True + ) + new_storage_size = torch._functionalize_get_storage_size( # type: ignore[attr-defined] + inpt_f.elem, before=False + ) + if old_storage_size != new_storage_size: + assert ( + old_storage_size == 0 or new_storage_size == 0 + ), f"""\ + Encountered a storage resize during tracing on input {i}. Old nbytes={old_storage_size}, new nbytes={new_storage_size} + We only support storage resizing on graph inputs as long as the input either starts or ends with a storage size of 0 + (the case for FSDP)""" + torch.ops.inductor.resize_storage_bytes_( + inpt_old, new_storage_size + ) + if new_storage_size == 0: + # Even if we marked the input as having a data mutation (thus needing a copy_()), + # We should **ignore** it if our input has no storage + # (this can happen if, e.g. we temporarily resize our input, copy data into it, + # and resize it back down to zero) + continue + # Optimization: if the copy_() is a no-op then don't include it in the graph. + # In theory inductor could optimize this away, however in fsdp, we end up with + # param.copy_(param), where param is a zero-storage-size tensor, + # and running this op in eager mode (using the aot_eager backend) will result in a segfault. + # So we may as well optimize it away here. + if inpt_old is inpt_new: + # (This check needs to be done after putting resize_() in the graph, + # since a resize_(0) doesn't actually change the FunctionalTensor's inner tensor) + continue + # We found an input that had a (data-only) mutation. + # Since keep_input_mutations is set, we need to faithfully apply a copy_() + # so the compiler will see the input mutation in the graph. + if ( + meta.input_info[i].mutates_data + and meta.input_info[i].mutations_hidden_from_autograd + ): + # Hidden from autograd = run under no_grad, **and** don't bump VC + # (although if the tensor was created in inference mode, it has no VC) + if inpt_old.is_inference(): + maybe_preserve_vc = nullcontext() + else: + maybe_preserve_vc = torch.autograd._unsafe_preserve_version_counter( + inpt_old # type: ignore[assignment] + ) + with torch.no_grad(), maybe_preserve_vc: + inpt_old.copy_(inpt_new) + elif ( + meta.input_info[i].mutates_data + and meta.input_info[ + i + ].mutations_under_no_grad_or_inference_mode + ): + # Under no_grad = run under no_grad (we still bump the VC though) + # (inference_mode will also bump the VC, as long as the tensor in question + # was created outside of inference_mode) + with torch.no_grad(): + inpt_old.copy_(inpt_new) + elif meta.input_info[i].mutates_data: + inpt_old.copy_(inpt_new) + + # When an output tensor is a functionalized mutated input, and we + # were able to move the mutation in to the graph then we can return + # the mutated input directly. This prevents duplicating the + # tensors contents. + flat_outs, outs_spec = pytree.tree_flatten(f_outs) + flat_outs = [from_fun(o) for o in flat_outs] + num_outs = len(meta.output_info) + + for i in range(num_outs): + info = meta.output_info[i] + if info.output_type != OutputType.is_input: + continue + + assert info.base_idx is not None + if ( + meta.input_info[info.base_idx].mutation_type + == MutationType.MUTATED_IN_GRAPH + ): + fw_args = args[0] if trace_joint else args + flat_outs[i] = fw_args[info.base_idx] + return pytree.tree_unflatten(flat_outs, outs_spec) + + return pytree.tree_map(from_fun, f_outs) + + # Kinda annoying, but needed to make sure that the fx graph we trace out has "primals" + # and "tangents" as its input names (which are special-cased by the partitioner) + # TODO (tmanlaibaatar) revisit this if we ever need to turn on non-strict joint graph export + def joint_helper(primals, tangents): + return _functionalized_f_helper(primals, tangents) + + helper = joint_helper if trace_joint else _functionalized_f_helper + if config.functionalize_rng_ops: + # Setup the wrapper for functionalization of rng ops + helper, args = create_functionalized_rng_ops_wrapper(helper, args, trace_joint) + + return helper, args + + +def handle_effect_tokens_fn( + fn, + args, + *, + meta: ViewAndMutationMeta, + trace_joint: bool, +) -> Any: + num_tokens = len(meta.tokens) + + @wraps(fn) + def inner_fn(*args): + # See Note [Disabling Functionalize TLS Above Python Functionalization] + disable_above = torch._C._ExcludeDispatchKeyGuard( + torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) + ) + + with disable_above: + # See Note [Side-Effectful Tokens in AOTAutograd] + if trace_joint: + assert isinstance(args, tuple) and isinstance(args[0], (list, tuple)) + tokens = args[0][:num_tokens] + assert all(token.numel() == 0 for token in tokens) + args = (args[0][num_tokens:], *args[1:]) + else: + tokens = args[:num_tokens] + assert all(token.numel() == 0 for token in tokens) + args = args[num_tokens:] + + # Populate the current FunctionalTensorMode with the tokens per + # operator. See Note [FunctionalTensorMode is Stateful] + functional_tensor_mode = torch.utils._python_dispatch._detect_infra_mode( + torch._C._TorchDispatchModeKey.FUNCTIONAL + ) + assert functional_tensor_mode is not None + f_tokens = pytree.tree_map(to_fun, tokens) + for i, k in enumerate(meta.tokens.keys()): + functional_tensor_mode._tokens[k] = f_tokens[i] + + # Run the joint + outs = fn(*args) + + # Return both the tokens and the outputs + # See Note [Side-Effectful Tokens in AOTAutograd] + if trace_joint: + assert len(outs) == 2 + assert len(functional_tensor_mode._tokens_forward_output) == num_tokens + fwd_out_tokens = functional_tensor_mode._tokens_forward_output.values() + + bwd_out_tokens = functional_tensor_mode._tokens.values() + + f_fwd_out_tokens = [from_fun(t) for t in fwd_out_tokens] + f_bwd_out_tokens = [from_fun(t) for t in bwd_out_tokens] + + meta.num_backward_tokens = len(bwd_out_tokens) + return ((*f_fwd_out_tokens, *outs[0]), (*outs[1], *f_bwd_out_tokens)) + + out_tokens = [from_fun(t) for t in functional_tensor_mode._tokens.values()] + return (*out_tokens, *outs) + + # Additionally pass in tokens as inputs + # See Note [Side-Effectful Tokens in AOTAutograd] + additional_fwd_token_inputs = [torch.tensor([])] * num_tokens + + if trace_joint: + args = ([*additional_fwd_token_inputs, *args[0]], *args[1:]) + else: + args = [*additional_fwd_token_inputs, *args] + return inner_fn, args + + +# Given a function operating on Subclass -> Subclass, returns an function that operates on Tensor -> Tensor +# Also returns: +# - the new set of arguments to pass into this function (now that tensor subclasses have been eliminated) +# - the updated ViewAndMutationMeta for this dense -> dense function. +# The other important arguments are: +# - flat_fn_maybe_joint: when is_joint_structure=True, this is the joint fw-bw function. +# when is_joint_structure=False, this is just the forward function. +# - fw_only: this is *always* the forward-only function. +# Why do we need this? We need to collect updated ViewAndMutationMeta on our new dense -> dense functions. +# In particular, we need this to tell the partitioner how many dense forward outputs there are. +def aot_dispatch_subclass( + flat_fn_maybe_joint, + args: list[Any], + *, + is_joint_structure: bool, + meta: ViewAndMutationMeta, + fw_only: Callable, +) -> SubclassTracingInfo: + # Skip logic if we don't need to trace through any subclasses + req_subclass_dispatch = requires_subclass_dispatch(args, meta) + if not req_subclass_dispatch: + return SubclassTracingInfo( + plain_tensor_trace_fn=flat_fn_maybe_joint, + plain_tensor_args=args, + maybe_subclass_meta=None, + ) + + # TODO: add subclass guards (later PR). + + # What's going on here? We need to compute subclass metadata about the outputs of the joint (grad_inputs). + # Annoying: we don't know the grad input metas until we're in the middle of tracing the joint, + # so we set it later, while we're tracing the joint (see inner_fn() below). + # Another option would be to run our run_functionalized_fw_and_collect_metadata() function + # directly on the joint, but this would hurt compile time (adding yet another pass through the joint). + subclass_meta = SubclassMeta() + + def inner_fn(fn, args, *, use_trace_joint: bool): + # Step 1: wrap tensor inputs into subclasses if necessary + all_args = wrap_tensor_subclasses_maybe_joint( + args, is_joint_structure=use_trace_joint, meta=meta + ) + + # Step 2: call the inner function, with our (maybe subclass) inputs + wrapped_outs = fn(*all_args) + + if use_trace_joint: + # See Note: [Computing Subclass Metadata about grad_inputs] + # We also stash subclass info on our grad_inputs, if we're tracing the joint. + nonlocal subclass_meta + assert isinstance(wrapped_outs, tuple) and len(wrapped_outs) == 2 + # Don't need fw outs since we already have subclass metadata on them + grad_inputs = wrapped_outs[1] + subclass_meta.grad_input_metas = create_subclass_meta(grad_inputs) + + # Add extra symints as outputs to the forward/backward graphs + # ignore nested ints here + forward_outs = unwrap_tensor_subclasses( + wrapped_outs[0], append_symints=True + ) + # ignore nested ints here + backward_outs = unwrap_tensor_subclasses( + wrapped_outs[1], append_symints=True + ) + return (forward_outs, backward_outs) + + # Step 3: Unwrap any subclass outputs back into dense tensors + unwrapped_outs = unwrap_tensor_subclasses(wrapped_outs, append_symints=True) + return unwrapped_outs + + def joint_fn(primals, tangents): + with maybe_enable_thunkify(): + return inner_fn( + flat_fn_maybe_joint, (primals, tangents), use_trace_joint=True + ) + + def fw_fn(*primals): + with maybe_enable_thunkify(): + return inner_fn(flat_fn_maybe_joint, primals, use_trace_joint=False) + + def metadata_fn(*primals): + return inner_fn(fw_only, primals, use_trace_joint=False) + + if is_joint_structure: + args_unwrapped = ( + # Add extra symints (size/strides) as input to the forward graph + unwrap_tensor_subclasses(args[0], append_symints=True), + # We pass append_symints=False here because the partitioner will + # capture and add any extra argument + unwrap_tensor_subclasses(args[1], append_symints=False), + ) + else: + args_unwrapped = unwrap_tensor_subclasses(args, append_symints=True) + remapped_static_indices = remap_unwrapped_subclass_arg_indices( + args, meta.static_input_indices + ) + + if is_joint_structure: + primals_unwrapped = args_unwrapped[0] + fn_to_trace = joint_fn + else: + primals_unwrapped = args_unwrapped + fn_to_trace = fw_fn + + # Note: [Partitioner handling for Subclasses, Part 1] + # The way the partitioner works is that: + # (1) we pass is a single graph containing the joint fw/bw, + # where the # of graph outputs corresponds to # fw_outputs + # grad_inputs + # (2) The partitioner accepts an arguments, num_fwd_outputs, + # and assumes that the first "num_fwd_outputs" graph outputs correspond + # to outputs of the forward graph. + # How do tensor subclasses enter the picture? + # the num_fwd_outputs in the final graph is actually non-trivial to compute, + # because it can be influenced by input mutations and intermediate bases. + # So we compute it by inspecting the current ViewAndMutationMeta object. + # However, the original ViewAndMutationMeta that we computed was created + # on the subclass -> subclass graph, + # which can have a different number of outputs than the dense -> dense graph. + # That's why we created a fresh metadata object on the dense -> dense function here, + # and plumb it back up to the partitioner. + # See Note: [Partitioner handling for Subclasses, Part 2] for more info. + meta_updated = run_functionalized_fw_and_collect_metadata( + metadata_fn, + static_input_indices=remapped_static_indices, + keep_input_mutations=meta.keep_input_mutations, + is_train=meta.is_train, + )(*primals_unwrapped) + + subclass_meta.fw_metadata = meta_updated + + return SubclassTracingInfo( + plain_tensor_trace_fn=fn_to_trace, + plain_tensor_args=args_unwrapped, + maybe_subclass_meta=subclass_meta, + ) + + +def create_functional_call(mod, params_spec, params_len, store_orig_mod=False): + # Redundant with dynamo, but worth having in case this gets invoked elsewhere. + # https://github.com/pytorch/pytorch/issues/103569 + + def functional_call(*args, **kwargs): + with stateless._reparametrize_module( + mod, pytree.tree_unflatten(args[:params_len], params_spec) + ), maybe_disable_thunkify(): + if isinstance(mod, torch.fx.GraphModule): + with fx_traceback.preserve_node_meta(), warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", "Anomaly Detection has been enabled." + ) + with torch.autograd.detect_anomaly(check_nan=False): + detect_fake_mode().epoch += 1 + out = PropagateUnbackedSymInts(mod).run( + *args[params_len:], **kwargs + ) + else: + out = mod(*args[params_len:], **kwargs) + + if not isinstance(out, (tuple, list)): + raise RuntimeError( + "Graph output must be a (). This is so that we can avoid " + "pytree processing of the outputs. Please change the module to " + "have tuple outputs or use aot_module instead." + ) + return out + + # Note [Preserving the nn module stack metadata during export non-strict mode] + # This path is currently only used by the non-strict export flow, + # where we cannot rely on dynamo to preserve nn stack metadata in our captured graph. + # Instead, we stash the original user nn module here, and rely on `make_fx` to grab + # this stashed module and use it to track nn module stack metadata + if store_orig_mod and not hasattr(functional_call, "_orig_mod"): + functional_call._orig_mod = mod # type: ignore[attr-defined] + + return functional_call diff --git a/torch/_functorch/_aot_autograd/utils.py b/torch/_functorch/_aot_autograd/utils.py index f028b63b3a8c7..fde5e20ece9b3 100644 --- a/torch/_functorch/_aot_autograd/utils.py +++ b/torch/_functorch/_aot_autograd/utils.py @@ -8,8 +8,12 @@ import warnings from contextlib import nullcontext from functools import wraps +<<<<<<< HEAD from typing import Any, Callable, Optional, TypeVar, Union from typing_extensions import ParamSpec +======= +from typing import Any, Callable, Optional, Union +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch.utils._pytree as pytree @@ -20,8 +24,11 @@ from torch.fx.experimental._backward_state import BackwardState from torch.fx.experimental.proxy_tensor import py_sym_types +<<<<<<< HEAD from .descriptors import AOTOutput +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) KNOWN_TYPES = [ torch.Tensor, @@ -143,9 +150,15 @@ def call_func_at_runtime_with_args( class PytreeThunk: spec: Optional[pytree.TreeSpec] = None # These are some kinda dumb microoptimizations that save about 3-4 us of overhead. +<<<<<<< HEAD is_simple: Optional[bool] = ( None # if the output spec is a tuple/list, we won't bother unflattening it. ) +======= + is_simple: Optional[ + bool + ] = None # if the output spec is a tuple/list, we won't bother unflattening it. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) is_really_simple: Optional[bool] = None # if the output spec is a LeafSpec def set(self, spec: pytree.TreeSpec) -> None: @@ -338,12 +351,21 @@ def do(module, subgraph, expected_num_erased): num_erased_inputs = len(input_token_nodes) +<<<<<<< HEAD assert num_erased_inputs == expected_num_erased, ( f"{subgraph} num_erased_inputs:{num_erased_inputs} {input_token_nodes}!=expected {expected_num_erased}" ) assert num_erased_outs == expected_num_erased, ( f"{subgraph} num_erased_outs:{num_erased_outs} {output_token_nodes}!=expected {expected_num_erased}" ) +======= + assert ( + num_erased_inputs == expected_num_erased + ), f"{subgraph} num_erased_inputs:{num_erased_inputs} {input_token_nodes}!=expected {expected_num_erased}" + assert ( + num_erased_outs == expected_num_erased + ), f"{subgraph} num_erased_outs:{num_erased_outs} {output_token_nodes}!=expected {expected_num_erased}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) module.recompile() @@ -516,6 +538,7 @@ def saved_tensors_hooks_are_inlineable(hooks) -> bool: return isinstance(pack, torch.fx.GraphModule) and isinstance( unpack, torch.fx.GraphModule ) +<<<<<<< HEAD _P = ParamSpec("_P") @@ -580,3 +603,5 @@ def fn_wrappers(fn): f = f.__wrapped__ fns.append(f) return fns +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index 1e0cb6a2ef8be..fa8ba45072141 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -1,10 +1,18 @@ # mypy: ignore-errors +<<<<<<< HEAD import contextlib import itertools from contextlib import nullcontext from functools import wraps from typing import Any, Callable, Optional +======= +import itertools +from collections.abc import KeysView, Sequence +from contextlib import contextmanager, nullcontext +from functools import partial, wraps +from typing import Any, Callable, NewType, Optional, Protocol, TypeVar +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from unittest.mock import patch import torch @@ -24,11 +32,23 @@ ) from torch._guards import detect_fake_mode from torch._inductor.cudagraph_utils import BoxedDeviceIndex +<<<<<<< HEAD from torch._inductor.utils import BoxedBool from torch._subclasses import FakeTensor, FakeTensorMode from torch.export._tree_utils import reorder_kwargs from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.experimental.symbolic_shapes import ShapeEnv +======= +from torch._inductor.output_code import OutputCode +from torch._inductor.utils import BoxedBool, InputType +from torch._subclasses import FakeTensor, FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + _pytree_subclasses_that_lose_info, + make_fx, +) +from torch.fx.experimental.symbolic_shapes import ShapeEnv +from torch.utils._python_dispatch import is_traceable_wrapper_subclass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static_inputs_log = torch._logging.getArtifactLogger( @@ -44,6 +64,7 @@ from ._aot_autograd.collect_metadata_analysis import ( # noqa: F401 run_functionalized_fw_and_collect_metadata, ) +<<<<<<< HEAD from ._aot_autograd.descriptors import ( AOTInput, BufferAOTInput, @@ -56,6 +77,8 @@ construct_fake_mode, process_inputs, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from ._aot_autograd.functional_utils import ( # noqa: F401 _check_if_mutation_can_be_in_graph, are_all_mutations_hidden_from_autograd, @@ -69,6 +92,7 @@ sync_functional_tensor, to_fun, ) +<<<<<<< HEAD from ._aot_autograd.graph_capture_wrappers import ( # noqa: F401 aot_dispatch_subclass, create_functional_call, @@ -83,12 +107,22 @@ aot_stage2_compile, aot_stage2_export, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from ._aot_autograd.input_output_analysis import ( # noqa: F401 compute_overlapping_inputs, create_graph_signature, create_synthetic_base_metadata, remove_dupe_metadata, ) +<<<<<<< HEAD +======= +from ._aot_autograd.jit_compile_runtime_wrappers import ( # noqa: F401 + aot_dispatch_autograd, + aot_dispatch_base, + aot_dispatch_export, +) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from ._aot_autograd.logging_utils import ( # noqa: F401 callback_set, describe_input, @@ -109,21 +143,31 @@ ) from ._aot_autograd.schemas import ( # noqa: F401 AOTConfig, +<<<<<<< HEAD AOTDispatchCompiler, AOTGraphCapture, AOTState, BackwardSignature, FakifiedFlatArgs, +======= + BackwardSignature, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) FQN, GraphInputName, GraphOutputName, GraphSignature, InputAliasInfo, +<<<<<<< HEAD JointWithDescriptors, MutationType, OutputAliasInfo, OutputType, SerializableAOTDispatchCompiler, +======= + MutationType, + OutputAliasInfo, + OutputType, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SubclassCreationMeta, SubclassMeta, TensorAlias, @@ -136,6 +180,18 @@ wrap_tensor_subclasses, wrap_tensor_subclasses_maybe_joint, ) +<<<<<<< HEAD +======= +from ._aot_autograd.traced_function_transforms import ( # noqa: F401 + aot_dispatch_subclass, + create_functional_call, + create_functionalized_fn, + create_functionalized_rng_ops_wrapper, + create_joint, + fn_input_mutations_to_outputs, + fn_prepped_for_autograd, +) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from ._aot_autograd.utils import ( # noqa: F401 _get_autocast_states, _get_symint_hints, @@ -148,7 +204,10 @@ normalize_as_list, partial_flatten_asdict, root_module_when_exporting_non_strict, +<<<<<<< HEAD simple_wraps, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) strict_zip, ) from .partitioners import default_partition @@ -380,7 +439,11 @@ # # We view every forward output when creating out tangent tensors to handle the problematic # case in which a subclass does extra aliasing between graph outputs/inputs in a way that +<<<<<<< HEAD # is not visible above the subclass. +======= +# is not visible above the sublass. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # # Ordinarily, when constructing the joint function that we want to trace in AOTAutograd, # we're guaranteed that the tangent tensors that we pass @@ -454,6 +517,7 @@ aot_autograd_decompositions = {} +<<<<<<< HEAD def create_aot_state( stack: contextlib.ExitStack, @@ -464,6 +528,154 @@ def create_aot_state( fake_mode: FakeTensorMode, shape_env: Optional[ShapeEnv], ) -> AOTState: +======= +FakifiedFlatArgs = NewType("FakifiedFlatArgs", list[Any]) + + +TOutputCode = TypeVar("TOutputCode", bound=OutputCode) + + +class AOTDispatchCompiler(Protocol): + """ + Represents a fw or bw_compiler passed to AOTAutograd. + """ + + def __call__( + self, + gm: torch.fx.GraphModule, + example_inputs: Sequence[InputType], + ) -> Any: + ... + + +# TODO: bikeshed on this name +class SerializableAOTDispatchCompiler(AOTDispatchCompiler): + """ + Represents an AOTDispatchCompiler that returns an OutputCode, and is + therefore cacheable. SerializableAOTDispatchCompiler always return an OutputCode. + A _CompileFxCallable usually gets converted into an AOTDispatchCompiler after binding all of + the kwargs in _CompileFxKwargs. + """ + + def __init__( + self, + output_code_ty: type[TOutputCode], + compiler_fn: Callable[[torch.fx.GraphModule, Sequence[InputType]], TOutputCode], + ): + self.output_code_ty = output_code_ty + self.compiler_fn = compiler_fn + + def __call__( + self, + gm: torch.fx.GraphModule, + example_inputs: Sequence[InputType], + ) -> OutputCode: + return self.compiler_fn(gm, example_inputs) + + +def process_inputs( + flat_args: list[Any], + aot_config: AOTConfig, + fake_mode: FakeTensorMode, + shape_env: Optional[ShapeEnv], + ignore_shape_env: bool = False, +) -> FakifiedFlatArgs: + with fake_mode: + + def convert(idx, x): + if shape_env is not None and not ignore_shape_env: + from torch._dynamo.source import ConstantSource + + if isinstance(x, int): + # We always specialize on scalar values in export. + if aot_config.is_export: + return x + source = ConstantSource(f"sym_{idx}") + return shape_env.create_symintnode( + shape_env.create_symbol(x, source), hint=x, source=source + ) + if isinstance(x, torch.ScriptObject): + return torch._library.fake_class_registry.maybe_to_fake_obj( + fake_mode, x + ) + if not isinstance(x, torch.Tensor): + return x + if isinstance(x, FakeTensor): + assert x.fake_mode is fake_mode + return x + if is_traceable_wrapper_subclass(x): + attrs, _ = x.__tensor_flatten__() + if all(isinstance(getattr(x, attr), FakeTensor) for attr in attrs): + assert all( + getattr(x, attr).fake_mode is fake_mode for attr in attrs + ) + return x + + # see note [Tensor Fakification and Symbol Caching] + symbolic_context = None + source = None + trace = True + if tracing_context := torch._guards.TracingContext.try_get(): + if x in tracing_context.tensor_to_context: + symbolic_context = tracing_context.tensor_to_context[x] + source = symbolic_context.tensor_source + # We already fakeified this tensor in Dynamo, don't + # dump the trace for it again + trace = False + if ( + idx < aot_config.num_params_buffers + and config.static_weight_shapes + and not symbolic_context + ): + # TODO: Ensure that this codepath is never exercised from + # Dynamo + return fake_mode.from_tensor(x, static_shapes=True) + + result = fake_mode.from_tensor( + x, + static_shapes=ignore_shape_env, + symbolic_context=symbolic_context, + source=source, + trace=trace, + ) + return result + + return FakifiedFlatArgs([convert(idx, x) for idx, x in enumerate(flat_args)]) + + +def construct_fake_mode( + flat_args: list[Any], aot_config: AOTConfig +) -> tuple[FakeTensorMode, Optional[ShapeEnv]]: + fake_mode = detect_fake_mode(flat_args) + if fake_mode is None: + shape_env = ShapeEnv() if aot_config.dynamic_shapes else None + fake_mode = FakeTensorMode(shape_env=shape_env) + else: + shape_env = fake_mode.shape_env + return (fake_mode, shape_env) + + +def create_aot_dispatcher_function( + flat_fn, + fake_flat_args: FakifiedFlatArgs, + aot_config: AOTConfig, + fake_mode: FakeTensorMode, + shape_env: Optional[ShapeEnv], +) -> tuple[Callable, ViewAndMutationMeta]: + with dynamo_timed("create_aot_dispatcher_function", log_pt2_compile_event=True): + return _create_aot_dispatcher_function( + flat_fn, fake_flat_args, aot_config, fake_mode, shape_env + ) + + +def _create_aot_dispatcher_function( + flat_fn, + fake_flat_args: FakifiedFlatArgs, + aot_config: AOTConfig, + fake_mode: FakeTensorMode, + shape_env: Optional[ShapeEnv], +) -> tuple[Callable, ViewAndMutationMeta]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Traces the forward and backward graphs of the attr:`flat_fn` to generate a joint graph. The joint graph is an Fx graph with Aten ops. Please refer to @@ -480,6 +692,7 @@ def create_aot_state( inputs in flat_args are parameters and buffers, and the rest are inputs. We use this to assume that parameters/buffer's shapes don't change. +<<<<<<< HEAD """ # Old name for now to avoid messing with stats. Also, note this is pushed @@ -487,6 +700,13 @@ def create_aot_state( stack.enter_context( dynamo_timed("create_aot_dispatcher_function", log_pt2_compile_event=True) ) +======= + + Note: this function is used both by aot_function and aot_export (controlled by aot_config.is_export) + When aot_config.is_export is True, we return an FX graph + metadata + When aot_config.is_export is False, we return an ordinary runtime function + """ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This is the main entry point. # TODO: Chillee argues that dynamo itself should pass in fake tensors to @@ -518,6 +738,7 @@ def create_aot_state( # If any saved tensor hooks are active, we **don't** want to trace them. # Instead, we'll let them run at runtime, around the custom autograd.Function # that we generate in torch.compile. +<<<<<<< HEAD stack.enter_context(torch.autograd.set_multithreading_enabled(False)) stack.enter_context(preserve_rng_state()) stack.enter_context(fake_mode) @@ -654,10 +875,143 @@ def _dup_fake_script_obj(fake_flat_args): if len([x for x in fw_metadata.input_info if x.mutates_metadata]) != 0: raise RuntimeError( f"""\ +======= + with torch.autograd.set_multithreading_enabled( + False + ), preserve_rng_state(), ( + fake_mode + ), ( + python_dispatcher_mode + ), PhiloxStateTracker(), torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(): + from torch._library.fake_class_registry import ( + FakeScriptObject, + maybe_to_fake_obj, + ) + + # Tracing may mutate the states the fake script object, + # so we need to duplicate the fake script objects so that subsequent tracing + # won't be affected. + def _dup_fake_script_obj(fake_flat_args): + return [ + maybe_to_fake_obj(detect_fake_mode(fake_flat_args), arg.real_obj) + if isinstance(arg, FakeScriptObject) + else arg + for arg in fake_flat_args + ] + + needs_autograd = any( + x.requires_grad for x in fake_flat_args if isinstance(x, Tensor) + ) + + with enable_python_dispatcher(): + # Patch set_rng_state as set_rng_state with fake tensors is + # nonsensical. This does not affect the collection of metadata. + with patch("torch.cuda.set_rng_state", lambda *args: None): + mod = root_module_when_exporting_non_strict(flat_fn) + if mod is not None: + ctx = _detect_attribute_assignment(mod) + else: + ctx = nullcontext() + + if torch._functorch.config.fake_tensor_propagate_real_tensors: + # Running dynamo_timed causes fake tensor issues when + # propagate real tensor is switched on. + dynamo_timed_ctx = nullcontext() + else: + dynamo_timed_ctx = dynamo_timed( + "aot_collect_metadata", log_pt2_compile_event=True + ) + + with dynamo_timed_ctx, ctx: + fw_metadata = run_functionalized_fw_and_collect_metadata( + flat_fn, + static_input_indices=aot_config.static_input_indices, + keep_input_mutations=aot_config.keep_inference_input_mutations, + is_train=needs_autograd, + pre_dispatch=aot_config.pre_dispatch, + is_export=aot_config.is_export, + )(*_dup_fake_script_obj(fake_flat_args)) + + req_subclass_dispatch = requires_subclass_dispatch( + fake_flat_args, fw_metadata + ) + CompileEventLogger.try_add_pt2_compile( + "backend_compile", requires_subclass_dispatch=req_subclass_dispatch + ) + + output_and_mutation_safe = not any( + x.requires_grad + # view-type operations preserve requires_grad even in no_grad. + # Do not count aliases of inputs with requires_grad as reason to make a training graph, + # as AOTAutograd will perform view-replay to regenerate the view outputs at runtime, + # setting their grad_fn properly. + and not ( + x.output_type + in (OutputType.alias_of_input, OutputType.is_input) + and fw_metadata.input_info[x.base_idx].requires_grad + ) + for x in fw_metadata.output_info + ) and not any( + x.requires_grad + and x.mutates_data + and not x.mutations_under_no_grad_or_inference_mode + and not x.mutations_hidden_from_autograd + for x in fw_metadata.input_info + ) + + if needs_autograd and output_and_mutation_safe: + # We realized that none of the outputs require grad, + # and none of the inputs that require grad are mutated. + # so we actually have an inference graph. + needs_autograd = False + # A bit silly: right now in the subclass codepath, our ViewAndMutationMeta + # changes depending on whether we pass in is_train / keep_input_mutations, + # so we're forced to recompute the metadata. + # TODO: refactor the subclass path of run_functionalized_fw_and_collect_metadata + # so that this is unnecessary. + if req_subclass_dispatch: + fw_metadata = run_functionalized_fw_and_collect_metadata( + flat_fn, + keep_input_mutations=aot_config.keep_inference_input_mutations, + is_train=False, + pre_dispatch=aot_config.pre_dispatch, + static_input_indices=aot_config.static_input_indices, + )(*fake_flat_args) + else: + fw_metadata = ViewAndMutationMeta( + input_info=fw_metadata.input_info, + output_info=fw_metadata.output_info, + num_intermediate_bases=fw_metadata.num_intermediate_bases, + keep_input_mutations=aot_config.keep_inference_input_mutations, + traced_tangents=fw_metadata.traced_tangents, + subclass_inp_meta=fw_metadata.subclass_inp_meta, + subclass_fw_graph_out_meta=fw_metadata.subclass_fw_graph_out_meta, + subclass_tangent_meta=fw_metadata.subclass_tangent_meta, + is_train=False, + tokens=fw_metadata.tokens, + static_input_indices=fw_metadata.static_input_indices, + ) + + if fw_metadata.num_intermediate_bases > 0: + assert not req_subclass_dispatch, f"""\ +torch.compile is currently being used with tensor subclass inputs: +{','.join([str(type(x)) for x in fake_flat_args])}. We are attempting to a compile a graph with two graph outputs +that alias one another, which is currently unsupported in the subclass use case. If you run into this, +please file a github issue""" + + if aot_config.is_export: + # aot_export: ban input metadata mutations for now to keep shared code paths simpler. + # Keeping .resize_() in the graph will require some work + # Allowing it but keeping the graph functional will require some calling convention changes. + if len([x for x in fw_metadata.input_info if x.mutates_metadata]) != 0: + raise RuntimeError( + f"""\ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Found an input that received a metadata mutation, through e.g. a call to `.resize_()` or `.transpose_()`. This is currently banned in the aot_export workflow. If you need this functionality, please file a github issue. fw_metadata={str(fw_metadata)}""" +<<<<<<< HEAD ) # In export, banning data mutations on inputs that require grad for now. # This should be rare, and is tricky to get right. When we trace the backward, @@ -676,10 +1030,30 @@ def _dup_fake_script_obj(fake_flat_args): ): raise RuntimeError( f"""\ +======= + ) + # In export, banning data mutations on inputs that require grad for now. + # This should be rare, and is tricky to get right. When we trace the backward, + # we currently trace with autograd.grad instead of .backward(), which makes it difficult + # to ensure that we run autograd all the way through the input **before** it saw the mutation. + if ( + len( + [ + x + for x in fw_metadata.input_info + if x.requires_grad and x.mutates_data + ] + ) + != 0 + ): + raise RuntimeError( + f"""\ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Found a graph input that requires gradients, and received a mutation. This is currently banned in the aot_export workflow. If you need this functionality, please file a github issue. fw_metadata={str(fw_metadata)}""" +<<<<<<< HEAD ) if req_subclass_dispatch: raise RuntimeError( @@ -706,6 +1080,56 @@ def _dup_fake_script_obj(fake_flat_args): aot_config=aot_config, stack=stack, ) +======= + ) + if req_subclass_dispatch: + raise RuntimeError( + """\ +aot_export is not currently supported with traceable tensor subclass. +If you need this feature, please comment on """ + ) + + # Need to decide on a strategy for functionalized RNG: toggling via global config seems bad, + # and turning it on will require a non-trivial calling convention change for any export runtime. + if config.functionalize_rng_ops: + raise RuntimeError( + """\ +Functionalized RNG is not currently supported in the aot_export workflow. Please file a github issue, +or otherwise set torch._functorch.config.functionalize_rng_ops = False.""" + ) + + def choose_dispatcher(needs_autograd, aot_config): + """ + Pick a dispatcher based on the config rules. + """ + if aot_config.is_export: + # export uses just the "graph bits", whereas the other + # two dispatchers include some extra work around handling a runtime epilogue + CompileEventLogger.try_add_pt2_compile( + "backend_compile", dispatch_mode="export" + ) + return partial(aot_dispatch_export, needs_autograd=needs_autograd) + elif needs_autograd and not aot_config.pre_dispatch: + CompileEventLogger.try_add_pt2_compile( + "backend_compile", dispatch_mode="autograd" + ) + return aot_dispatch_autograd + else: + CompileEventLogger.try_add_pt2_compile( + "backend_compile", dispatch_mode="inference" + ) + return aot_dispatch_base + + compiler_fn = choose_dispatcher(needs_autograd, aot_config) + + compiled_fn, fw_metadata = compiler_fn( + flat_fn, + _dup_fake_script_obj(fake_flat_args), + aot_config, + fw_metadata=fw_metadata, + ) + return compiled_fn, fw_metadata +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def aot_function( @@ -738,7 +1162,11 @@ def aot_function( This API is experimental and likely to change. Args: +<<<<<<< HEAD fn (Callable): A Python function that takes one or more arguments. Must +======= + fn (Callable): A Python function that takes one ore more arguments. Must +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return one or more Tensors. fw_compiler (Callable): A Python function that accepts an Fx graph with Aten ops and input args, and returns a Callable that semantically is @@ -765,7 +1193,11 @@ def aot_function( A simple example usage of :func:`aot_function` is as follows. This example will print the forward and backward graphs of the function ``fn`` +<<<<<<< HEAD >>> fn = lambda x: x.sin().cos() +======= + >>> fn = lambda x : x.sin().cos() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> def print_compile_fn(fx_module, args): >>> print(fx_module) >>> return fx_module @@ -808,6 +1240,7 @@ def returned_function(*args, **kwargs): fake_flat_args: FakifiedFlatArgs = process_inputs( flat_args, aot_config, fake_mode, shape_env ) +<<<<<<< HEAD # TODO: We actually could use the pytree path to make better descs. # Also, the descs here are bad if you do aot_module. fake_flat_args_descs = [ @@ -825,6 +1258,15 @@ def returned_function(*args, **kwargs): ) aot_graph_capture = aot_stage1_graph_capture(aot_state, flat_fn) compiled_fn, _ = aot_stage2_compile(aot_state, aot_graph_capture) +======= + compiled_fn, _ = create_aot_dispatcher_function( + flat_fn, + fake_flat_args, + aot_config, + fake_mode, + shape_env, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cached_res = (compiled_fn, out_spec) cached_fn, out_spec = cached_res @@ -886,6 +1328,7 @@ def forward(self, *args, **kwargs): return AOTModule() +<<<<<<< HEAD def prepare_aot_module_simplified( mod: nn.Module, args, @@ -1011,6 +1454,87 @@ def prepare_aot_module_simplified( in_spec, out_spec, ) +======= +def _try_get_metadata_from_dynamo( + mod: torch.nn.Module, param_keys: KeysView[str], full_args_num: int +) -> tuple[Optional[list[torch._guards.Source]], list[int]]: + """ + Metadata is forwarded from Dynamo to AOTDispatch via special fields on GraphModule. + We first verify that `mod` does come from Dynamo, then we handle cases where + metadata might be missing. + + Returns: + aot_autograd_arg_pos_to_source: used to dedup params and their guards + static_input_indices: used to identify static inputs for cudagraphs + """ + # Note [Assumption on Dynamo Metadata] + # This function assumes a graph module from dynamo provides `dynamo_compiled_id`, + # _param_name_to_source, and every placeholder node has `_dynamo_source` attributes. + # When gm is modified (e.g., DDPOptimizer via split_module), metadata needs to + # be propagated in order to be recognized as a dynamo graph + + if not (isinstance(mod, torch.fx.GraphModule) and "dynamo_compile_id" in mod.meta): + # graph was not captured by dynamo + return None, [] + + if not hasattr(mod, "_param_name_to_source"): + # is from export + return None, [] + + # We now know this came from dynamo, and (1) we care about guards, + # so setting up aot_autograd_arg_pos_to_source for downstream dedup guards + # can now be done safely. (2) Dynamo logic protects the 1:1 sizing below. + # Additionally, we mark static indices for cudagraphs. + param_name_to_source = mod._param_name_to_source + seen_sources = set() + + aot_autograd_arg_pos_to_source = [] + static_input_indices = [] + # Collect the new inputs lifted by aotdispatch + for i, name in enumerate(param_keys): + assert name in param_name_to_source, f"{name} not found." + source = param_name_to_source[name] + assert source not in seen_sources, source + seen_sources.add(source) + aot_autograd_arg_pos_to_source.append(source) + + static_input_indices.append(i) + + # Collect the dynamo graph inputs + # TODO(mlazos): Revisit if this is still needed. With Dynamo install ID + # matched tensors back into the Fx graph, this might not be necessary. + for pos, node in enumerate(mod.graph.find_nodes(op="placeholder")): + assert hasattr(node, "_dynamo_source") + source = node._dynamo_source + # `source`` specifies the source from user code. ddp optimizer may have + # intermediate values becoming submodule placeholders which does not + # have a source + assert source is None or source not in seen_sources, source + seen_sources.add(source) + aot_autograd_arg_pos_to_source.append(source) + source_name = source.name() if source else str(source) + + # input[i] in dynamo is now: + # input[i + len(extra_params)] in AOT, + # where extra_params are the params/buffers that dynamo baked into the + # OutputGraph + actual_pos = pos + len(param_keys) + + if "tensor_dict" in node.meta and node.meta["tensor_dict"].get( + "_dynamo_static_input_type", None + ): + static_inputs_log.debug( + "Adding static input pos %s for source %s", actual_pos, source_name + ) + static_input_indices.append(actual_pos) + else: + static_inputs_log.debug( + "Non-static input pos %s for source %s", actual_pos, source_name + ) + + assert full_args_num == len(aot_autograd_arg_pos_to_source) + return aot_autograd_arg_pos_to_source, static_input_indices +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def aot_module_simplified( @@ -1022,8 +1546,11 @@ def aot_module_simplified( decompositions: Optional[dict] = None, keep_inference_input_mutations=False, inference_compiler: Optional[AOTDispatchCompiler] = None, +<<<<<<< HEAD # TODO: This doesn't seem to be used in any nontrivial way, check if it's # actually needed +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cudagraphs: Optional[BoxedBool] = None, boxed_forward_device_index: Optional[BoxedDeviceIndex] = None, ignore_shape_env: bool = False, @@ -1038,14 +1565,29 @@ def aot_module_simplified( :func:`aot_module_simplified` removes these overheads. """ +<<<<<<< HEAD if cudagraphs is None: cudagraphs = BoxedBool(torch._inductor.config.triton.cudagraphs) +======= + params = { + **dict(mod.named_parameters(remove_duplicate=False)), + **dict(mod.named_buffers(remove_duplicate=False)), + } + params_flat, params_spec = pytree.tree_flatten(params) + params_flat = list(params_flat) + params_len = len(params_flat) + + if cudagraphs is None: + cudagraphs = BoxedBool(torch._inductor.config.triton.cudagraphs) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if bw_compiler is None: bw_compiler = fw_compiler if inference_compiler is None: inference_compiler = fw_compiler +<<<<<<< HEAD with contextlib.ExitStack() as stack: ( functional_call, @@ -1098,25 +1640,113 @@ def aot_module_simplified( functional_call, fake_flat_args, full_args_descs, +======= + full_args = [] + # First, the params + full_args.extend(params_flat) + + if tracing_context := torch._guards.TracingContext.try_get(): + tracing_context.params_flat = params_flat + ( + tracing_context.params_flat_unwrap_subclasses, + tracing_context.params_unwrapped_to_flat_index, + ) = unwrap_tensor_subclasses_with_indices_to_original(params_flat) + + # Next, the input args + full_args.extend(args) + + ( + aot_autograd_arg_pos_to_source, + static_input_indices, + ) = _try_get_metadata_from_dynamo(mod, params.keys(), len(full_args)) + + dynamic_shapes = False + for x in full_args: + if isinstance(x, FakeTensor): + dynamic_shapes = x.fake_mode.shape_env is not None + break + + aot_config = AOTConfig( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + inference_compiler=inference_compiler, + partition_fn=partition_fn, + decompositions=decompositions, + num_params_buffers=params_len, + aot_id=next(AOT_COUNTER), + keep_inference_input_mutations=keep_inference_input_mutations, + dynamic_shapes=dynamic_shapes, + aot_autograd_arg_pos_to_source=aot_autograd_arg_pos_to_source, + static_input_indices=static_input_indices, + is_export=False, + no_tangents=False, + cache_info=None, + ignore_shape_env=ignore_shape_env, + precompile_backend_id=getattr(mod, "_backend_id", None), + ) + fake_mode, shape_env = construct_fake_mode(full_args, aot_config) + fake_flat_args = process_inputs( + full_args, aot_config, fake_mode, shape_env, ignore_shape_env + ) + + def dispatch_and_compile(): + functional_call = create_functional_call(mod, params_spec, params_len) + with compiled_autograd._disable(): + compiled_fn, _ = create_aot_dispatcher_function( + functional_call, + fake_flat_args, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) aot_config, fake_mode, shape_env, ) +<<<<<<< HEAD aot_graph_capture = aot_stage1_graph_capture(aot_state, functional_call) compiled_fn, _ = aot_stage2_compile(aot_state, aot_graph_capture) +======= + return compiled_fn + + # We only care if the forward will return an OutputCode. + if isinstance(fw_compiler, SerializableAOTDispatchCompiler): + local = should_use_local_autograd_cache() + remote = should_use_remote_autograd_cache() + if local or remote: + set_feature_use("aot_autograd_remote_cache", remote) + compiled_fn = AOTAutogradCache.load( + dispatch_and_compile, + mod, + fake_flat_args, + aot_config, + cudagraphs, + boxed_forward_device_index, + local, + remote, + ) + else: + compiled_fn = dispatch_and_compile() + else: + compiled_fn = dispatch_and_compile() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(mod, torch._dynamo.utils.GmWrapper): # This function is called by the flatten_graph_inputs wrapper, which boxes # the inputs so that they can be freed before the end of this scope. # For overhead reasons, this is not the default wrapper, see comment: # https://github.com/pytorch/pytorch/pull/122535/files#r1560096481 +<<<<<<< HEAD def forward(runtime_args: list[Any]): flat_args = [] flat_args.extend(params_buffers_flat) +======= + def boxed_forward(runtime_args: list[Any]): + flat_args = [] + flat_args.extend(params_flat) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) flat_args.extend(runtime_args) runtime_args.clear() return compiled_fn(flat_args) +<<<<<<< HEAD else: # TODO: There is something deeply wrong here; compiled_fn running with # the boxed calling convention, but aot_module_simplified somehow @@ -1128,6 +1758,24 @@ def forward(*runtime_args: tuple[Any]): full_args.extend(params_buffers_flat) full_args.extend(runtime_args) return compiled_fn(full_args) +======= + # Just for convenience + boxed_forward.zero_grad = mod.zero_grad + boxed_forward.named_parameters = mod.named_parameters + boxed_forward.named_buffers = mod.named_buffers + return boxed_forward + + # TODO: There is something deeply wrong here; compiled_fn running with + # the boxed calling convention, but aot_module_simplified somehow + # historically returned a function that was not the boxed calling + # convention. This should get fixed... + # NB: GraphModule/nn.Module rely on the non-boxed calling convention here + def forward(*runtime_args: tuple[Any]): + full_args = [] + full_args.extend(params_flat) + full_args.extend(runtime_args) + return compiled_fn(full_args) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Just for convenience forward.zero_grad = mod.zero_grad @@ -1137,6 +1785,7 @@ def forward(*runtime_args: tuple[Any]): return forward +<<<<<<< HEAD def boxed_nop_preserve_node_meta(fx_g, example_inputs): def run(args): with torch.fx.traceback.preserve_node_meta(): @@ -1305,6 +1954,8 @@ def unflattened_compiled_fn(*args, **kwargs): return unflattened_compiled_fn +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def aot_export_module( mod: nn.Module, args, @@ -1317,7 +1968,11 @@ def aot_export_module( # Your module can return multiple outputs, so you must specify which output the loss is. output_loss_index: Optional[int] = None, pre_dispatch: bool = False, +<<<<<<< HEAD # If None, will be inferred from inputs and mod.graph.nodes if mod is a graph module, but the inferred result might be wrong. +======= + # If None, will be infered from inputs and mod.graph.nodes if mod is a graph module, but the inferred result might be wrong. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dynamic_shapes: Optional[bool] = None, kwargs=None, ) -> tuple[torch.fx.GraphModule, GraphSignature]: @@ -1449,11 +2104,16 @@ def fn_to_trace(*args): no_tangents=True, pre_dispatch=pre_dispatch, dynamic_shapes=dynamic_shapes, +<<<<<<< HEAD trace_joint=trace_joint, kwargs=kwargs, ) # TODO: subsume this path with the aot_stage2_graph_capture path +======= + kwargs=kwargs, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if trace_joint: @wraps(functional_call) @@ -1485,7 +2145,13 @@ def flattened_joint(*args): output_gradients = [] for a, grad in zip(args, gradients): if isinstance(a, torch.Tensor) and a.requires_grad: +<<<<<<< HEAD assert grad is not None, """\ +======= + assert ( + grad is not None + ), """\ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Found a parameter that did not receive a gradient. "This is most likely a bug, but if this needs to be supported please comment on this Github issue: https://github.com/pytorch/pytorch/issues/101192 @@ -1519,7 +2185,11 @@ def aot_export_joint_simple( *, trace_joint: bool, # It looks like the main consequence of this API is that for dynamic shapes, +<<<<<<< HEAD # it will assume that params/buffers are static. +======= + # it will assume that parms/buffers are static. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # With the new inferred dynamic shapes API, maybe this doesn't matter? num_params_buffers: int = 0, decompositions: Optional[dict] = None, @@ -1552,7 +2222,10 @@ def aot_export_joint_simple( func, args, decompositions=decompositions, +<<<<<<< HEAD trace_joint=trace_joint, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) in_spec, _kw_in_spec = in_spec.children_specs # At this point, we can just directly return the (joint or inference graph) that we traced. @@ -1599,9 +2272,13 @@ def aot_export_joint_simple( if config.debug_assert: # Smoke test that after partitioning, we can run the forward without any calling convention changes. fw_module, _bw_module = aot_config.default_partition( # noqa: F821 +<<<<<<< HEAD fx_g, args, num_fwd_outputs=len(fw_metadata.output_infos), # noqa: F821 +======= + fx_g, args, num_fwd_outputs=len(fw_metadata.output_infos) # noqa: F821 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # Attempt to run the fw_module with the original user inputs fake_mode = detect_fake_mode(args) @@ -1631,11 +2308,16 @@ def _aot_export_function( # We don't know this info at trace time though, so we need to make it an explicit config. no_tangents: bool = False, pre_dispatch: bool = False, +<<<<<<< HEAD # If None, `dynamic_shapes` will be inferred from inputs, but the inferred result might be wrong. dynamic_shapes: Optional[bool] = None, keep_input_mutations: bool = False, # Under export, configures whether we are getting inference or training IR trace_joint: bool = False, +======= + # If None, `dynamic_shapes` will be infered from inputs, but the inferred result might be wrong. + dynamic_shapes: Optional[bool] = None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kwargs=None, ) -> tuple[torch.fx.GraphModule, ViewAndMutationMeta, pytree.TreeSpec, pytree.TreeSpec]: kwargs = kwargs or {} @@ -1674,19 +2356,27 @@ def _aot_export_function( # For now there's no use case involving keeping input mutations in the graph # (which we can only do in the inference case anyway). # We can add this later if we need to. +<<<<<<< HEAD keep_inference_input_mutations=keep_input_mutations, +======= + keep_inference_input_mutations=False, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dynamic_shapes=dynamic_shapes, aot_autograd_arg_pos_to_source=None, is_export=True, no_tangents=no_tangents, pre_dispatch=pre_dispatch, +<<<<<<< HEAD export_trace_joint=trace_joint, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if fake_mode is None: fake_mode, shape_env = construct_fake_mode(flat_args, aot_config) else: shape_env = fake_mode.shape_env fake_flat_args = process_inputs(flat_args, aot_config, fake_mode, shape_env) +<<<<<<< HEAD # TODO: Improve the descs here with pytree information fake_flat_args_descs = [PlainAOTInput(i) for i in range(len(fake_flat_args))] @@ -1706,5 +2396,118 @@ def _aot_export_function( return fx_g, meta, in_spec, out_spec.spec +======= + + fx_g, meta = create_aot_dispatcher_function( + flat_fn, + fake_flat_args, + aot_config, + fake_mode, + shape_env, + ) + return fx_g, meta, in_spec, out_spec.spec + + +@contextmanager +def _detect_attribute_assignment(mod: torch.nn.Module): + # Do not allow assignment of tensor attributes during export unless + # the attribute is registered as a buffer. + + NN_MODULE_STD_ATTRS = [ + "_backward_hooks", + "_backward_pre_hooks", + "_buffers", + "_forward_hooks", + "_forward_hooks_always_called", + "_forward_hooks_with_kwargs", + "_forward_pre_hooks", + "_forward_pre_hooks_with_kwargs", + "_is_full_backward_hook", + "_load_state_dict_post_hooks", + "_load_state_dict_pre_hooks", + "_modules", + "_non_persistent_buffers_set", + "_parameters", + "_state_dict_hooks", + "_state_dict_pre_hooks", + "training", + ] + NN_MODULE_LAZY_STD_ATTRS = [ + "_initialize_hook", + "_load_hook", + ] + STD_ATTRS = { + *NN_MODULE_STD_ATTRS, + *NN_MODULE_LAZY_STD_ATTRS, + } + + def _get_attributes(mod): + # return any attributes of a module that are not standard attributes + return {k: v for k, v in mod.__dict__.items() if k not in STD_ATTRS} + + # save state of attributes before enter + snapshot = pytree.tree_map( + lambda x: x, + _get_attributes(mod), + is_leaf=lambda x: type(x) in _pytree_subclasses_that_lose_info, + ) + try: + yield + finally: + # after exit, compare state of attributes with snapshot + # to detect which tensor attributes were assigned + assigned_tensor_attributes = [] + + def _collect_assigned_tensor_attributes(kp, v, _v): + if _v is not v: + attr, *rest = kp + if isinstance(v, torch.Tensor): + assigned_tensor_attributes.append( + f"self.{attr.key}{pytree.keystr(rest)}" + ) + # TODO(avik): Assigning all other types are allowed right now. + # Maybe in the future we want to limit this to primitive types? + return v + + new_attrs = _get_attributes(mod) + if len(new_attrs) != len(snapshot): + added_attrs = new_attrs.keys() - snapshot.keys() + deleted_attrs = snapshot.keys() - new_attrs.keys() + + if len(added_attrs) > 0: + raise ValueError( + f"During torch.export, following attrs were created in the model.forward: {added_attrs} " + f"Such attributes must be registered as buffers using the `register_buffer` " + f"API and must be initialized at model.__init__ " + f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)." + ) + + if len(deleted_attrs) > 0: + raise ValueError( + f"During torch.export, following attrs were deleted in the model.forward: {deleted_attrs} " + f"Such attributes must be registered as buffers using the `register_buffer` " + f"API and must be initialized at model.__init__ " + f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)." + ) + + pytree.tree_map_with_path( + _collect_assigned_tensor_attributes, snapshot, new_attrs + ) + # restore state of all attributes (including, e.g., of primitive types) + mod.__dict__.update(snapshot) + + if assigned_tensor_attributes: + if len(assigned_tensor_attributes) > 1: + noun, verb = "attributes", "were" + else: + noun, verb = "attribute", "was" + raise ValueError( + f"The tensor {noun} {', '.join(assigned_tensor_attributes)} {verb} assigned during export. " + "Such attributes must be registered as buffers using the `register_buffer` API " + "(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)." + ) + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) compiled_function = aot_function compiled_module = aot_module diff --git a/torch/_functorch/apis.py b/torch/_functorch/apis.py index 1faa767d4d05c..67bdea4e1a6b1 100644 --- a/torch/_functorch/apis.py +++ b/torch/_functorch/apis.py @@ -92,7 +92,11 @@ def vmap( doesn't provide a batched ``torch.dot`` API; instead of unsuccessfully rummaging through docs, use :func:`vmap` to construct a new function. +<<<<<<< HEAD >>> torch.dot # [D], [D] -> [] +======= + >>> torch.dot # [D], [D] -> [] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> batched_dot = torch.func.vmap(torch.dot) # [N, D], [N, D] -> [N] >>> x, y = torch.randn(2, 5), torch.randn(2, 5) >>> batched_dot(x, y) @@ -104,7 +108,11 @@ def vmap( >>> weights = torch.randn(feature_size, requires_grad=True) >>> >>> def model(feature_vec): +<<<<<<< HEAD >>> # Very simple linear model with activation +======= + >>> # Very simple linear model with activation +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> return feature_vec.dot(weights).relu() >>> >>> examples = torch.randn(batch_size, feature_size) @@ -120,7 +128,11 @@ def vmap( >>> # Setup >>> N = 5 +<<<<<<< HEAD >>> f = lambda x: x**2 +======= + >>> f = lambda x: x ** 2 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> x = torch.randn(N, requires_grad=True) >>> y = f(x) >>> I_N = torch.eye(N) @@ -137,49 +149,84 @@ def vmap( :func:`vmap` can also be nested, producing an output with multiple batched dimensions +<<<<<<< HEAD >>> torch.dot # [D], [D] -> [] >>> batched_dot = torch.vmap( ... torch.vmap(torch.dot) ... ) # [N1, N0, D], [N1, N0, D] -> [N1, N0] >>> x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5) >>> batched_dot(x, y) # tensor of size [2, 3] +======= + >>> torch.dot # [D], [D] -> [] + >>> batched_dot = torch.vmap(torch.vmap(torch.dot)) # [N1, N0, D], [N1, N0, D] -> [N1, N0] + >>> x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5) + >>> batched_dot(x, y) # tensor of size [2, 3] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) If the inputs are not batched along the first dimension, ``in_dims`` specifies the dimension that each inputs are batched along as +<<<<<<< HEAD >>> torch.dot # [N], [N] -> [] >>> batched_dot = torch.vmap(torch.dot, in_dims=1) # [N, D], [N, D] -> [D] >>> x, y = torch.randn(2, 5), torch.randn(2, 5) >>> batched_dot( ... x, y ... ) # output is [5] instead of [2] if batched along the 0th dimension +======= + >>> torch.dot # [N], [N] -> [] + >>> batched_dot = torch.vmap(torch.dot, in_dims=1) # [N, D], [N, D] -> [D] + >>> x, y = torch.randn(2, 5), torch.randn(2, 5) + >>> batched_dot(x, y) # output is [5] instead of [2] if batched along the 0th dimension +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) If there are multiple inputs each of which is batched along different dimensions, ``in_dims`` must be a tuple with the batch dimension for each input as +<<<<<<< HEAD >>> torch.dot # [D], [D] -> [] >>> batched_dot = torch.vmap(torch.dot, in_dims=(0, None)) # [N, D], [D] -> [N] >>> x, y = torch.randn(2, 5), torch.randn(5) >>> batched_dot( ... x, y ... ) # second arg doesn't have a batch dim because in_dim[1] was None +======= + >>> torch.dot # [D], [D] -> [] + >>> batched_dot = torch.vmap(torch.dot, in_dims=(0, None)) # [N, D], [D] -> [N] + >>> x, y = torch.randn(2, 5), torch.randn(5) + >>> batched_dot(x, y) # second arg doesn't have a batch dim because in_dim[1] was None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) If the input is a Python struct, ``in_dims`` must be a tuple containing a struct matching the shape of the input: +<<<<<<< HEAD >>> f = lambda dict: torch.dot(dict["x"], dict["y"]) >>> x, y = torch.randn(2, 5), torch.randn(5) >>> input = {"x": x, "y": y} >>> batched_dot = torch.vmap(f, in_dims=({"x": 0, "y": None},)) +======= + >>> f = lambda dict: torch.dot(dict['x'], dict['y']) + >>> x, y = torch.randn(2, 5), torch.randn(5) + >>> input = {'x': x, 'y': y} + >>> batched_dot = torch.vmap(f, in_dims=({'x': 0, 'y': None},)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> batched_dot(input) By default, the output is batched along the first dimension. However, it can be batched along any dimension by using ``out_dims`` +<<<<<<< HEAD >>> f = lambda x: x**2 >>> x = torch.randn(2, 5) >>> batched_pow = torch.vmap(f, out_dims=1) >>> batched_pow(x) # [5, 2] +======= + >>> f = lambda x: x ** 2 + >>> x = torch.randn(2, 5) + >>> batched_pow = torch.vmap(f, out_dims=1) + >>> batched_pow(x) # [5, 2] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) For any function that uses kwargs, the returned function will not batch the kwargs but will accept kwargs @@ -190,7 +237,11 @@ def vmap( >>> >>> batched_pow = torch.vmap(fn) >>> assert torch.allclose(batched_pow(x), x * 4) +<<<<<<< HEAD >>> batched_pow(x, scale=x) # scale is not batched, output has shape [2, 2, 5] +======= + >>> batched_pow(x, scale=x) # scale is not batched, output has shape [2, 2, 5] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .. note:: vmap does not provide general autobatching or handle variable-length @@ -343,7 +394,11 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Calla >>> batch_size, feature_size = 3, 5 >>> >>> def model(weights, feature_vec): +<<<<<<< HEAD >>> # Very simple linear model with activation +======= + >>> # Very simple linear model with activation +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> assert feature_vec.dim() == 1 >>> return feature_vec.dot(weights).relu() >>> @@ -355,9 +410,13 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Calla >>> examples = torch.randn(batch_size, feature_size) >>> targets = torch.randn(batch_size) >>> inputs = (weights, examples, targets) +<<<<<<< HEAD >>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))( ... *inputs ... ) +======= + >>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Example of using ``grad`` with ``has_aux`` and ``argnums``: diff --git a/torch/_functorch/benchmark_utils.py b/torch/_functorch/benchmark_utils.py index ba0b31c018bd1..9e3bc9f02aa9c 100644 --- a/torch/_functorch/benchmark_utils.py +++ b/torch/_functorch/benchmark_utils.py @@ -185,12 +185,17 @@ def benchmark_utilization( ``` def f(a): return a.sum() +<<<<<<< HEAD a = torch.rand(2**20, device="cuda") utilization, mm_conv_utilization = benchmark_utilization( f, a, "tmp", trace_file_name="tmp_chrome_trace" ) +======= + a = torch.rand(2**20, device="cuda") + utilization, mm_conv_utilization = benchmark_utilization(f, a, "tmp", trace_file_name = "tmp_chrome_trace") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ``` Args: diff --git a/torch/_functorch/compile_utils.py b/torch/_functorch/compile_utils.py index 929b58540f413..5fc9a2a2b931b 100644 --- a/torch/_functorch/compile_utils.py +++ b/torch/_functorch/compile_utils.py @@ -179,7 +179,11 @@ def raise_getitems(gm: fx.GraphModule) -> fx.GraphModule: ) # loop through getitem nodes in the graph and raise them to the parent node +<<<<<<< HEAD # in reverse order to preserve their original relative order +======= + # in reverse order to perserve their original relative order +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for node in reversed(getitem_nodes): assert len(node.all_input_nodes) == 1 parent = node.all_input_nodes[0] diff --git a/torch/_functorch/compilers.py b/torch/_functorch/compilers.py index 5295a526e25c1..9f56c1a448d2d 100644 --- a/torch/_functorch/compilers.py +++ b/torch/_functorch/compilers.py @@ -31,7 +31,11 @@ log = logging.getLogger(__name__) +<<<<<<< HEAD # These canonicalization are needed here (and not decompositions), as the ops +======= +# These canonicalizations are needed here (and not decompositions), as the ops +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # we're trying to canonicalize to CompositeImplicitAutograd. def _canonicalize(fx_g): for node in fx_g.graph.find_nodes( @@ -150,6 +154,7 @@ def check_significant_strides(a, b): def check(nv, rv, desc): assert callable(desc) assert nv.dtype == rv.dtype, f"{desc()}: {nv.dtype} != {rv.dtype}" +<<<<<<< HEAD assert subst_symint_tuple(nv.size()) == rv.size(), ( f"{desc()}: {nv.size()} aka {subst_symint_tuple(nv.size())} != {rv.size()}" ) @@ -157,6 +162,15 @@ def check(nv, rv, desc): assert same_strides, ( f"{desc()}: {nv.stride()} aka {subst_symint_tuple(nv.stride())} != {rv.stride()}" ) +======= + assert ( + subst_symint_tuple(nv.size()) == rv.size() + ), f"{desc()}: {nv.size()} aka {subst_symint_tuple(nv.size())} != {rv.size()}" + same_strides = check_significant_strides(nv, rv) + assert ( + same_strides + ), f"{desc()}: {nv.stride()} aka {subst_symint_tuple(nv.stride())} != {rv.stride()}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r = super().run_node(n) if "val" in n.meta: @@ -249,7 +263,11 @@ def memory_efficient_fusion( Args: fn (Union[Callable, nn.Module]): A Python function or a ``nn.Module`` +<<<<<<< HEAD that takes one or more arguments. Must return one or more Tensors. +======= + that takes one ore more arguments. Must return one or more Tensors. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) **kwargs: Any other overrides you want to make to the settings Returns: diff --git a/torch/_functorch/config.py b/torch/_functorch/config.py index 5bf2dee3e1d7d..d403f29b92e64 100644 --- a/torch/_functorch/config.py +++ b/torch/_functorch/config.py @@ -7,7 +7,10 @@ """ Global flags for aot autograd """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import os import sys from typing import Literal, Optional, TYPE_CHECKING @@ -61,10 +64,13 @@ # need to add env vars or make it configurable bundled_autograd_cache: bool = False +<<<<<<< HEAD # Whether or not to normalize placeholder names in graphs # from dynaom in AOTAutogradCache autograd_cache_normalize_inputs = not is_fbcode() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def remote_autograd_cache_default() -> Optional[bool]: if os.environ.get("TORCHINDUCTOR_AUTOGRAD_REMOTE_CACHE") == "1": @@ -234,6 +240,7 @@ def remote_autograd_cache_default() -> Optional[bool]: # of tensors in question. fake_tensor_propagate_real_tensors = False +<<<<<<< HEAD # AOTDispatcher traces out a backward graph at the time of the forward pass. # This flags controls whether or not that backward graph gets autocast behavior # applied to it. @@ -268,6 +275,8 @@ def remote_autograd_cache_default() -> Optional[bool]: # z.backward() backward_pass_autocast = "same_as_forward" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This controls whether we collect donated buffer. This flag must be set # False if a user wants to retain_graph=True for backward. donated_buffer = False if is_fbcode() else True @@ -281,6 +290,7 @@ def remote_autograd_cache_default() -> Optional[bool]: # real tensor outputs. generate_fake_kernels_from_real_mismatches = False +<<<<<<< HEAD # When there are device mismatches in FakeTensor device propagation, # prefer a specific device type over others. This is particularly useful # in full compiled mode where intermediate tensors with device mismatches @@ -292,6 +302,8 @@ def remote_autograd_cache_default() -> Optional[bool]: # CPU, or "cuda" to prefer CUDA devices over CPU. fake_tensor_prefer_device_type: Optional[str] = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # CUDAGraph save run_with_rng functionalization. # TODO: turn on by default graphsafe_rng_functionalization = True @@ -307,7 +319,11 @@ def remote_autograd_cache_default() -> Optional[bool]: # which can reorder or ,delete duplicate nodes in the graph # - If any of these passes reorder/delete/duplicate a collective # in a setting where the compiler is being run independently on multiple +<<<<<<< HEAD # ranks, we run the risk that the compiler will make a different decision on +======= +# ranks, we run the risk that the compiler will make a different decison on +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # different ranks, resulting in a NCCL hang when using torch.compile # To handle this, we will (by default) ensure that collectives are not modified # by the compiler. @@ -338,7 +354,11 @@ def remote_autograd_cache_default() -> Optional[bool]: # This is a temporary config to ensure all ranks take the same decision in the partitioner # it will untimately be removed once we share size_hints across ranks through compiler collectives +<<<<<<< HEAD _sync_decision_cross_ranks = False +======= +_broadcast_rank0_decision = False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # By default apply inlined saved_tensors_hooks only for "donated" buffers. # "donated" buffers are invisible to the user, they are intermediates of the forward graph. diff --git a/torch/_functorch/eager_transforms.py b/torch/_functorch/eager_transforms.py index d99995b86f2ba..132bd92b340ab 100644 --- a/torch/_functorch/eager_transforms.py +++ b/torch/_functorch/eager_transforms.py @@ -233,7 +233,11 @@ def vjp(func: Callable, *primals, has_aux: bool = False): >>> x = torch.randn([5]) >>> f = lambda x: x.sin().sum() >>> (_, vjpfunc) = torch.func.vjp(f, x) +<<<<<<< HEAD >>> grad = vjpfunc(torch.tensor(1.0))[0] +======= + >>> grad = vjpfunc(torch.tensor(1.))[0] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> assert torch.allclose(grad, torch.func.grad(f)(x)) However, :func:`vjp` can support functions with multiple outputs by @@ -248,9 +252,15 @@ def vjp(func: Callable, *primals, has_aux: bool = False): :func:`vjp` can even support outputs being Python structs >>> x = torch.randn([5]) +<<<<<<< HEAD >>> f = lambda x: {"first": x.sin(), "second": x.cos()} >>> (_, vjpfunc) = torch.func.vjp(f, x) >>> cotangents = {"first": torch.ones([5]), "second": torch.ones([5])} +======= + >>> f = lambda x: {'first': x.sin(), 'second': x.cos()} + >>> (_, vjpfunc) = torch.func.vjp(f, x) + >>> cotangents = {'first': torch.ones([5]), 'second': torch.ones([5])} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> vjps = vjpfunc(cotangents) >>> assert torch.allclose(vjps[0], x.cos() + -x.sin()) @@ -274,7 +284,11 @@ def vjp(func: Callable, *primals, has_aux: bool = False): >>> >>> (_, vjpfunc) = torch.func.vjp(f, x) >>> vjps = vjpfunc(torch.ones_like(x)) +<<<<<<< HEAD >>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.0)) +======= + >>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .. note:: Using PyTorch ``torch.no_grad`` together with ``vjp``. @@ -930,7 +944,12 @@ def assert_output_is_tensor_or_tensors(output: Any, api: str) -> None: return if not isinstance(output, tuple): raise RuntimeError( +<<<<<<< HEAD f"{api}: Expected output of f to be a Tensor or Tensors, got {type(output)}" +======= + f"{api}: Expected output of f to be a Tensor or Tensors, got " + f"{type(output)}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if len(output) == 0: raise RuntimeError( @@ -1022,10 +1041,17 @@ def jvp( >>> from torch.func import jvp >>> x = torch.randn([]) +<<<<<<< HEAD >>> f = lambda x: x * torch.tensor([1.0, 2.0, 3]) >>> value, grad = jvp(f, (x,), (torch.tensor(1.0),)) >>> assert torch.allclose(value, f(x)) >>> assert torch.allclose(grad, torch.tensor([1.0, 2, 3])) +======= + >>> f = lambda x: x * torch.tensor([1., 2., 3]) + >>> value, grad = jvp(f, (x,), (torch.tensor(1.),)) + >>> assert torch.allclose(value, f(x)) + >>> assert torch.allclose(grad, torch.tensor([1., 2, 3])) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) :func:`jvp` can support functions with multiple inputs by passing in the tangents for each of the inputs diff --git a/torch/_functorch/functional_call.py b/torch/_functorch/functional_call.py index 8d019871ffee3..635554aef00a9 100644 --- a/torch/_functorch/functional_call.py +++ b/torch/_functorch/functional_call.py @@ -60,10 +60,14 @@ def functional_call( .. code-block:: python +<<<<<<< HEAD a = ( {"weight": torch.ones(1, 1)}, {"buffer": torch.zeros(1)}, ) # two separate dictionaries +======= + a = ({'weight': torch.ones(1, 1)}, {'buffer': torch.zeros(1)}) # two separate dictionaries +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mod = nn.Bar(1, 1) # return self.weight @ x + self.buffer print(mod.weight) # tensor(...) print(mod.buffer) # tensor(...) @@ -86,12 +90,18 @@ def functional_call( t = torch.randn(4, 3) model = nn.Linear(3, 3) +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def compute_loss(params, x, t): y = functional_call(model, params, x) return nn.functional.mse_loss(y, t) +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) grad_weights = grad(compute_loss)(dict(model.named_parameters()), x, t) .. note:: If the user does not need grad tracking outside of grad transforms, they can detach all of the @@ -184,11 +194,17 @@ def stack_module_state( models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)] data = torch.randn(batch_size, 3) +<<<<<<< HEAD def wrapper(params, buffers, data): return torch.func.functional_call(models[0], (params, buffers), data) +======= + def wrapper(params, buffers, data): + return torch.func.functional_call(models[0], (params, buffers), data) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) params, buffers = stack_module_state(models) output = vmap(wrapper, (0, 0, None))(params, buffers, data) @@ -199,8 +215,11 @@ def wrapper(params, buffers, data): .. code-block:: python import torch.nn as nn +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class Foo(nn.Module): def __init__(self, in_features, out_features): super().__init__() @@ -211,7 +230,10 @@ def __init__(self, in_features, out_features): def forward(self, x): return self.l2(self.l1(x)) +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) num_models = 5 in_features, out_features = 3, 3 models = [Foo(in_features, out_features) for i in range(num_models)] diff --git a/torch/_functorch/make_functional.py b/torch/_functorch/make_functional.py index 16988a022a977..47b97b4447df9 100644 --- a/torch/_functorch/make_functional.py +++ b/torch/_functorch/make_functional.py @@ -374,12 +374,18 @@ def make_functional( model = nn.Linear(3, 3) func, params = make_functional(model) +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def compute_loss(params, x, t): y = func(params, x) return nn.functional.mse_loss(y, t) +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) grad_weights = grad(compute_loss)(params, x, t) If the model has any buffers, please use :func:`make_functional_with_buffers` instead. @@ -445,12 +451,18 @@ def make_functional_with_buffers( model = nn.Linear(3, 3) func, params, buffers = make_functional_with_buffers(model) +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def compute_loss(params, buffers, x, t): y = func(params, buffers, x) return nn.functional.mse_loss(y, t) +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) grad_weights = grad(compute_loss)(params, buffers, x, t) Args: @@ -473,7 +485,11 @@ def compute_loss(params, buffers, x, t): def transpose_stack( +<<<<<<< HEAD tuple_of_tuple_of_tensors: tuple[tuple[Tensor, ...], ...], +======= + tuple_of_tuple_of_tensors: tuple[tuple[Tensor, ...], ...] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> tuple[Tensor, ...]: tuple_of_tuple_of_tensors = tuple(zip(*tuple_of_tuple_of_tensors)) results = tuple( diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index 9030cfc3c17ca..460e54eb691dd 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -48,7 +48,10 @@ ilp_knapsack, ) from ._activation_checkpointing.knapsack_evaluator import KnapsackEvaluator +<<<<<<< HEAD from ._aot_autograd.descriptors import AOTOutput, SavedForBackwardsAOTOutput +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from ._aot_autograd.logging_utils import get_aot_graph_name from ._aot_autograd.utils import get_cuda_generator_meta_val, is_with_effects from .compile_utils import fx_graph_cse, get_aten_target, raise_getitems @@ -176,7 +179,10 @@ def _extract_graph_with_inputs_outputs( joint_graph: fx.Graph, inputs: list[fx.Node], outputs: list[fx.Node], +<<<<<<< HEAD outputs_descs: list[AOTOutput], +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) subgraph: Optional[str] = None, ) -> fx.Graph: """ @@ -204,10 +210,13 @@ def _extract_graph_with_inputs_outputs( env[node] = InvalidNode # type: ignore[assignment] continue +<<<<<<< HEAD if _must_be_in_forward(node) and subgraph != "forward": env[node] = InvalidNode # type: ignore[assignment] continue +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if node in env: # Node must be one of our inputs. (Any member of env which wasn't an # input to start must have been created by this loop and won't be in @@ -235,6 +244,7 @@ def _extract_graph_with_inputs_outputs( if isinstance(x, fx.Node): if x not in env: raise RuntimeError(f"Node {x} couldn't be found in env") +<<<<<<< HEAD assert not isinstance(env[x], InvalidNodeBase), ( f"Node {x} was invalid, but is output" ) @@ -243,6 +253,15 @@ def _extract_graph_with_inputs_outputs( output_values.append(x) out = new_graph.output(tuple(output_values)) out.meta["desc"] = outputs_descs +======= + assert not isinstance( + env[x], InvalidNodeBase + ), f"Node {x} was invalid, but is output" + output_values.append(env[x]) + else: + output_values.append(x) + new_graph.output(tuple(output_values)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) new_graph.eliminate_dead_code() new_graph.lint() @@ -282,18 +301,24 @@ def _has_tag_is_backward(node: fx.Node) -> bool: return node.meta.get("partitioner_tag", None) == "is_backward" +<<<<<<< HEAD def _has_tag_must_be_in_forward(node: fx.Node) -> bool: return node.meta.get("partitioner_tag", None) == "must_be_in_forward" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _has_tag_must_be_in_backward(node: fx.Node) -> bool: return node.meta.get("partitioner_tag", None) == "must_be_in_backward" +<<<<<<< HEAD def _must_be_in_forward(node: fx.Node) -> bool: return _has_tag_must_be_in_forward(node) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _must_be_in_backward(node: fx.Node) -> bool: return _has_tag_must_be_in_backward(node) or ( _has_tag_is_backward(node) and is_with_effects(node) @@ -302,6 +327,7 @@ def _must_be_in_backward(node: fx.Node) -> bool: def _extract_fwd_bwd_outputs( joint_module: fx.GraphModule, *, num_fwd_outputs +<<<<<<< HEAD ) -> tuple[list[fx.Node], list[fx.Node], list[AOTOutput], list[AOTOutput]]: outputs = pytree.arg_tree_leaves( *(node.args for node in joint_module.graph.find_nodes(op="output")) @@ -316,6 +342,15 @@ def _extract_fwd_bwd_outputs( fwd_outputs_descs = outputs_descs[:num_fwd_outputs] bwd_outputs_descs = outputs_descs[num_fwd_outputs:] return fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs +======= +) -> tuple[list[fx.Node], list[fx.Node]]: + outputs = pytree.arg_tree_leaves( + *(node.args for node in joint_module.graph.find_nodes(op="output")) + ) + fwd_outputs = outputs[:num_fwd_outputs] + bwd_outputs = outputs[num_fwd_outputs:] + return fwd_outputs, bwd_outputs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _remove_by_name(saved_values: list[fx.Node], name: str): @@ -471,10 +506,17 @@ def perform_quantization( args=(clamp_max_scaled_node, quant_type), name="fp8_quant_" + str(node.name), ) +<<<<<<< HEAD quant_activation_node.meta["val"] = ( torch.ops.prims.convert_element_type.default( clamp_max_scaled_node.meta["val"], quant_type ) +======= + quant_activation_node.meta[ + "val" + ] = torch.ops.prims.convert_element_type.default( + clamp_max_scaled_node.meta["val"], quant_type +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) quant_activation_node.meta["tensor_meta"] = extract_tensor_metadata( quant_activation_node.meta["val"] @@ -523,7 +565,11 @@ def should_quantize(node: torch.fx.Node) -> bool: ].get("skip_dynamo_guards", False): return size_in_mb >= size_threshold else: +<<<<<<< HEAD # case 1: we always quantize tensors with dynamic shapes +======= + # case 1: we alway quantize tensors with dynamic shapes +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if torch._inductor.config.post_grad_fusion_options[ "activation_quantization_aten_pass" ].get("quantize_dynamic_shape", False): @@ -531,7 +577,11 @@ def should_quantize(node: torch.fx.Node) -> bool: size_in_mb >= size_threshold ) or not statically_known_false(size_in_mb >= size_threshold) else: +<<<<<<< HEAD # case 2: we always not quantize tensors with dynamic shapes +======= + # case 2: we alway not quantize tensors with dynamic shapes +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return statically_known_true(size_in_mb >= size_threshold) @@ -589,10 +639,17 @@ def quantize_activation_fw(graph: torch.fx.Graph) -> None: args=(node, quant_type), name="fp8_quant_" + str(node.name), ) +<<<<<<< HEAD quant_node.meta["val"] = ( torch.ops.prims.convert_element_type.default( node.meta["val"], quant_type ) +======= + quant_node.meta[ + "val" + ] = torch.ops.prims.convert_element_type.default( + node.meta["val"], quant_type +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) quant_node.meta["tensor_meta"] = extract_tensor_metadata( quant_node.meta["val"] @@ -600,9 +657,15 @@ def quantize_activation_fw(graph: torch.fx.Graph) -> None: node_to_quant[node] = quant_node # only update the return node args, and remain all other users unchanged output_updated_args = [ +<<<<<<< HEAD node_to_quant[node] if node in node_to_quant else node for node in fwd_outputs ] # add the scale nodes to the output find the first sym_node in the output +======= + node_to_quant[node] if node in node_to_quant else node for node in fwd_outputs # type: ignore[union-attr] + ] + # add the scale nodes to the ouput find the first sym_node in the output +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) idx = find_first_sym_node(output_updated_args) scale_nodes = tensor_scale_nodes + sym_scale_nodes if scale_nodes: @@ -639,10 +702,17 @@ def quantize_activation_bw(graph: torch.fx.Graph) -> None: torch.ops.prims.convert_element_type.default, args=(node, dequant_type), ) +<<<<<<< HEAD activation_node.meta["val"] = ( torch.ops.prims.convert_element_type.default( node.meta["val"], dequant_type ) +======= + activation_node.meta[ + "val" + ] = torch.ops.prims.convert_element_type.default( + node.meta["val"], dequant_type +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) activation_node.meta["tensor_meta"] = extract_tensor_metadata( activation_node.meta["val"] @@ -655,18 +725,31 @@ def quantize_activation_bw(graph: torch.fx.Graph) -> None: divided_target_node_32.meta["val"] = torch.ops.aten.div.Tensor( activation_node.meta["val"], scale_node.meta["val"] ) +<<<<<<< HEAD divided_target_node_32.meta["tensor_meta"] = ( extract_tensor_metadata(divided_target_node_32.meta["val"]) ) +======= + divided_target_node_32.meta[ + "tensor_meta" + ] = extract_tensor_metadata(divided_target_node_32.meta["val"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with graph.inserting_after(divided_target_node_32): dequant_node = graph.call_function( torch.ops.prims.convert_element_type.default, args=(divided_target_node_32, dequant_type), ) +<<<<<<< HEAD dequant_node.meta["val"] = ( torch.ops.prims.convert_element_type.default( divided_target_node_32.meta["val"], dequant_type ) +======= + dequant_node.meta[ + "val" + ] = torch.ops.prims.convert_element_type.default( + divided_target_node_32.meta["val"], dequant_type +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) dequant_node.meta["tensor_meta"] = extract_tensor_metadata( dequant_node.meta["val"] @@ -678,10 +761,17 @@ def quantize_activation_bw(graph: torch.fx.Graph) -> None: args=(node, dequant_type), name="dequant_" + str(node.name), ) +<<<<<<< HEAD dequant_node.meta["val"] = ( torch.ops.prims.convert_element_type.default( node.meta["val"], dequant_type ) +======= + dequant_node.meta[ + "val" + ] = torch.ops.prims.convert_element_type.default( + node.meta["val"], dequant_type +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) dequant_node.meta["tensor_meta"] = extract_tensor_metadata( dequant_node.meta["val"] @@ -694,11 +784,55 @@ def quantize_activation_bw(graph: torch.fx.Graph) -> None: counters["inductor"]["activation_quantization_bwd_aten_pass"] += 1 +<<<<<<< HEAD def perform_fp8_activation_quantization( fwd_module: fx.GraphModule, bwd_module: fx.GraphModule, bwd_module_inputs: dict[str, fx.Node], ) -> None: +======= +def enable_activation_quantization( + saved_values: list[fx.Node], + fwd_module: fx.GraphModule, + bwd_module: fx.GraphModule, + static_lifetime_input_nodes: Optional[OrderedSet[fx.Node]] = None, +) -> None: + if ( + inductor_config.post_grad_fusion_options.get( + "activation_quantization_aten_pass", None + ) + is None + ): + return + + static_input_names = ( + [node.name for node in static_lifetime_input_nodes] + if static_lifetime_input_nodes + else [] + ) + saved_values_names = {node.name: node for node in saved_values} + if torch._inductor.config.post_grad_fusion_options[ + "activation_quantization_aten_pass" + ].get("exclude_primals", False): + saved_values_names = { + node.name: node for node in saved_values if "primals" not in node.name + } + fwd_module_outputs = fwd_module.graph.find_nodes(op="output")[0].args[0] + bwd_module_inputs = { + node.name: node for node in bwd_module.graph.find_nodes(op="placeholder") + } + for node in fwd_module_outputs: + if node.name in saved_values_names and should_quantize(node): + if node.name in static_input_names: + log.debug("Skipping quantization of static input %s: ", node.name) + continue + node.meta["saved_for_quantization"] = True + node.meta["dequant_type"] = node.meta["val"].dtype + # some of the fwd outputs and bwd inputs are not share the same object + bwd_module_inputs[node.name].meta["saved_for_quantization"] = True + bwd_module_inputs[node.name].meta["dequant_type"] = node.meta["val"].dtype + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) trace_structured( "artifact", metadata_fn=lambda: { @@ -782,6 +916,7 @@ def perform_fp8_activation_quantization( ) +<<<<<<< HEAD def enable_activation_quantization( saved_values: list[fx.Node], fwd_module: fx.GraphModule, @@ -829,6 +964,8 @@ def enable_activation_quantization( perform_fp8_activation_quantization(fwd_module, bwd_module, bwd_module_inputs) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _extract_fwd_bwd_modules( joint_module: fx.GraphModule, saved_values: list[fx.Node], @@ -837,8 +974,13 @@ def _extract_fwd_bwd_modules( num_fwd_outputs: int, static_lifetime_input_nodes: Optional[OrderedSet[fx.Node]] = None, ) -> tuple[fx.GraphModule, fx.GraphModule]: +<<<<<<< HEAD fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = ( _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs) +======= + fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs( + joint_module, num_fwd_outputs=num_fwd_outputs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) placeholders = joint_module.graph.find_nodes(op="placeholder") primal_inputs = [*filter(_is_primal, placeholders)] @@ -851,7 +993,10 @@ def _extract_fwd_bwd_modules( joint_module.graph, saved_sym_nodes + saved_values + tangent_inputs + bwd_seed_offset_inputs, bwd_outputs, +<<<<<<< HEAD bwd_outputs_descs, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "backward", ) @@ -925,11 +1070,14 @@ def _extract_fwd_bwd_modules( joint_module.graph, primal_inputs + fwd_seed_offset_inputs, fwd_outputs + saved_values + saved_sym_nodes, +<<<<<<< HEAD fwd_outputs_descs + [ SavedForBackwardsAOTOutput(i) for i in range(len(saved_values) + len(saved_sym_nodes)) ], +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "forward", ) bwd_graph = _extract_graph_with_inputs_outputs( @@ -940,7 +1088,10 @@ def _extract_fwd_bwd_modules( + bwd_seed_offset_inputs + backward_state_inputs, bwd_outputs, +<<<<<<< HEAD bwd_outputs_descs, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "backward", ) @@ -993,11 +1144,19 @@ def default_partition( primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes)) inputs = primal_inputs + fwd_seed_offset_inputs +<<<<<<< HEAD fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = ( _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs) ) forward_only_graph = _extract_graph_with_inputs_outputs( joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward" +======= + fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs( + joint_module, num_fwd_outputs=num_fwd_outputs + ) + forward_only_graph = _extract_graph_with_inputs_outputs( + joint_module.graph, inputs, fwd_outputs, "forward" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) forward_node_names = OrderedSet( node.name for node in forward_only_graph.nodes if node.op != "output" @@ -1122,7 +1281,11 @@ def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule: """ This pass finds the first bwd node in the graph (by looking at users of tangents) and then reorders the graph by walking from this node to all the +<<<<<<< HEAD way to the end of the graph. At each op in this traversal, we insert this op +======= + way to the end of the graph. At each op in this traveral, we insert this op +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) in a new graph and try to bring only the relevant subgraph from the other non-bwd edges relevant for this op. This closely mimics the behavior of autograd engine. @@ -1345,6 +1508,7 @@ def get_device(node) -> Optional[torch.device]: return torch.device("cpu") def get_sample_rng_state(device: Optional[torch.device]): +<<<<<<< HEAD from torch._guards import detect_fake_mode # noqa: F401 fake_mode = detect_fake_mode() @@ -1353,6 +1517,11 @@ def get_sample_rng_state(device: Optional[torch.device]): if device is not None and device.type == "cuda": return fake_mode.from_tensor(torch.cuda.get_rng_state()) return fake_mode.from_tensor(torch.get_rng_state()) +======= + if device is not None and device.type == "cuda": + return torch.cuda.get_rng_state() + return torch.get_rng_state() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Step 1 - Construct a mapping of rng node between the fwd and its counterpart in bwd. joint_graph_rng_ops = get_rng_ops(joint_module) @@ -1392,7 +1561,11 @@ def get_sample_rng_state(device: Optional[torch.device]): get_device(node_pair["fwd"]) for node_pair in recomputable_rng_ops_map.values() ) devices.discard(torch.device("cpu")) +<<<<<<< HEAD # multiple cuda devices won't work with cudagraphs anyway, +======= + # multiple cuda devices wont work with cudagraphs anyway, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # fallback to non graphsafe rng checkpointing multi_cuda_devices = len(devices) > 1 @@ -1447,8 +1620,11 @@ def get_sample_rng_state(device: Optional[torch.device]): args=(functional_fw_node, 0), kwargs={}, ) +<<<<<<< HEAD state.meta["val"] = get_sample_rng_state(device) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) rng_output = fw_graph.create_node( "call_function", operator.getitem, @@ -1458,9 +1634,12 @@ def get_sample_rng_state(device: Optional[torch.device]): ), kwargs={}, ) +<<<<<<< HEAD # Copy the meta data from the original node rng_output.meta = copy.copy(fw_node.meta) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fw_node.replace_all_uses_with(rng_output) fw_graph.erase_node(fw_node) fw_rng_state_outputs.append(state) @@ -1516,6 +1695,7 @@ def force_save_collectives(joint_module: fx.GraphModule) -> None: node.meta["recompute"] = CheckpointPolicy.MUST_SAVE +<<<<<<< HEAD def force_save_bw_mutation_src(joint_module: fx.GraphModule) -> None: # If we have mutations of the same primal in forward and backward, # We must not recompute the source of mutation to not apply twice. @@ -1539,6 +1719,8 @@ def force_save_bw_mutation_src(joint_module: fx.GraphModule) -> None: break +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule: """ If there are two consecutive checkpointed blocks with no operator in @@ -2415,7 +2597,11 @@ def get_saved_values_knapsack(memory_budget, node_info, joint_graph): # if idx in all_recomputable_banned_nodes: try: dont_ban.add(all_recomputable_banned_nodes[idx]) +<<<<<<< HEAD except BaseException: # noqa: B036 +======= + except BaseException: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pass assert dont_ban.issubset(all_recomputable_banned_nodes) @@ -2517,7 +2703,11 @@ def estimate_for_budget(b): )[0] +<<<<<<< HEAD def _sync_decision_cross_ranks( +======= +def _broadcast_rank0_decision( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) joint_graph: torch.fx.Graph, saved_values: list[torch.fx.Node] ): # use the same policy across different GPUs @@ -2553,6 +2743,7 @@ def has_same_nodes(joint_graph): ): with no_dispatch(), unset_fake_temporarily(): objects = [[x.name for x in saved_values]] +<<<<<<< HEAD saved_ops_names_all_ranks: list[list[str]] = [ [] for _ in range(torch.distributed.get_world_size()) ] @@ -2680,6 +2871,16 @@ def thread_graphsafe_rng_from_hops(module, is_backward): return module +======= + # TODO: maybe use a different process group for this + torch.distributed.broadcast_object_list(objects, src=0) + saved_values_names = objects[0] + name_to_node = get_name_to_node(joint_graph) + saved_values = [name_to_node[n] for n in saved_values_names] + return saved_values + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def min_cut_rematerialization_partition( joint_module: fx.GraphModule, _joint_inputs, @@ -2731,7 +2932,10 @@ def min_cut_rematerialization_partition( joint_module = cleanup_recompute_tags(joint_module) if not config.unsafe_allow_optimization_of_collectives: force_save_collectives(joint_module) +<<<<<<< HEAD force_save_bw_mutation_src(joint_module) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def classify_nodes(joint_module, static_lifetime_input_indices): name_to_node = get_name_to_node(joint_module.graph) @@ -2750,14 +2954,23 @@ def classify_nodes(joint_module, static_lifetime_input_indices): filter(_is_fwd_seed_offset, joint_module.graph.nodes) ) inputs = primal_inputs + fwd_seed_offset_inputs +<<<<<<< HEAD fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = ( _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs) +======= + fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs( + joint_module, num_fwd_outputs=num_fwd_outputs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) required_bw_nodes.update( o for o in bwd_outputs if o is not None and o.op != "output" ) forward_only_graph = _extract_graph_with_inputs_outputs( +<<<<<<< HEAD joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward" +======= + joint_module.graph, inputs, fwd_outputs, "forward" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) required_fw_nodes: OrderedSet[fx.Node] = OrderedSet( name_to_node[node.name] @@ -2823,8 +3036,13 @@ def classify_nodes(joint_module, static_lifetime_input_indices): node_info, memory_budget=memory_budget, ) +<<<<<<< HEAD if config._sync_decision_cross_ranks: saved_values = _sync_decision_cross_ranks(joint_graph, saved_values) +======= + if config._broadcast_rank0_decision: + saved_values = _broadcast_rank0_decision(joint_graph, saved_values) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # save_for_backward on tensors and stashes symints in autograd .ctx saved_sym_nodes = list(filter(is_sym_node, saved_values)) saved_values = list(filter(lambda n: not is_sym_node(n), saved_values)) @@ -2849,9 +3067,12 @@ def classify_nodes(joint_module, static_lifetime_input_indices): fw_module = raise_getitems(fw_module) bw_module = raise_getitems(bw_module) +<<<<<<< HEAD fw_module = thread_graphsafe_rng_from_hops(fw_module, is_backward=False) bw_module = thread_graphsafe_rng_from_hops(bw_module, is_backward=True) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if AOT_PARTITIONER_DEBUG: # Calculate sorted sizes of saved values sorted_sizes = sorted([(_size_of(i), str(i)) for i in saved_values]) diff --git a/torch/_functorch/top_operators_github_usage.py b/torch/_functorch/top_operators_github_usage.py index 171c6fc6c1e01..efbb8acf46e75 100644 --- a/torch/_functorch/top_operators_github_usage.py +++ b/torch/_functorch/top_operators_github_usage.py @@ -4,7 +4,10 @@ From https://docs.google.com/spreadsheets/d/12R3nCOLskxPYjjiNkdqy4OdQ65eQp_htebXGODsjSeA/edit#gid=0 Try to keep this list in sync with that. """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import operator diff --git a/torch/_functorch/vmap.py b/torch/_functorch/vmap.py index 5e3893fef5cd0..a193ba8121237 100644 --- a/torch/_functorch/vmap.py +++ b/torch/_functorch/vmap.py @@ -9,18 +9,31 @@ import contextlib import functools import itertools +<<<<<<< HEAD +======= +import os +import threading +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from functools import partial from typing import Any, Callable, Optional, Union import torch from torch import Tensor +<<<<<<< HEAD from torch._C._functorch import is_batchedtensor from torch._functorch.predispatch import ( +======= +from torch._C._functorch import ( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _add_batch_dim, _remove_batch_dim, _vmap_decrement_nesting, _vmap_increment_nesting, +<<<<<<< HEAD lazy_load_decompositions, +======= + is_batchedtensor, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) from torch.utils._pytree import ( _broadcast_to_and_flatten, @@ -257,6 +270,60 @@ def _get_name(func: Callable): return repr(func) +<<<<<<< HEAD +======= +DECOMPOSITIONS_LOADED = False +DECOMPOSITIONS_LOCK = threading.Lock() +VMAP_DECOMPOSITIONS_LIB = None + + +# torch.package, Python 3.11, and torch.jit-less environments are unhappy with +# decompositions. Only load them when needed if possible. +def lazy_load_decompositions(): + global DECOMPOSITIONS_LOADED + if DECOMPOSITIONS_LOADED: + return + + with DECOMPOSITIONS_LOCK: + if DECOMPOSITIONS_LOADED: + return + + if not (os.environ.get("PYTORCH_JIT", "1") == "1" and __debug__): + DECOMPOSITIONS_LOADED = True + return + + # use an alternate way to register an operator into the decomposition table + # _register_jit_decomposition doesn't work for some operators, e.g. addr, + # because the Tensor types generated cannot be unioned by torchscript + # decomp should be type OpOverload + global VMAP_DECOMPOSITIONS_LIB + VMAP_DECOMPOSITIONS_LIB = torch.library.Library( + "aten", "IMPL", "FuncTorchBatched" + ) + + from torch._decomp import decomposition_table + + def _register_python_decomposition_vmap(decomp): + if decomp in decomposition_table: + VMAP_DECOMPOSITIONS_LIB.impl(decomp, decomposition_table[decomp]) + else: + raise RuntimeError(f"could not find decomposition for {decomp}") + + _register_python_decomposition_vmap(torch.ops.aten.mse_loss_backward.default) + _register_python_decomposition_vmap( + torch.ops.aten.smooth_l1_loss_backward.default + ) + _register_python_decomposition_vmap(torch.ops.aten.huber_loss_backward.default) + _register_python_decomposition_vmap(torch.ops.aten.nll_loss_forward.default) + _register_python_decomposition_vmap(torch.ops.aten.nll_loss2d_forward.default) + _register_python_decomposition_vmap(torch.ops.aten.nll_loss_backward.default) + _register_python_decomposition_vmap(torch.ops.aten.nll_loss2d_backward.default) + _register_python_decomposition_vmap(torch.ops.aten.addr.default) + + DECOMPOSITIONS_LOADED = True + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs): lazy_load_decompositions() _check_out_dims_is_int_or_int_pytree(out_dims, func) diff --git a/torch/_guards.py b/torch/_guards.py index f6f053ea064cb..9046fc16ade3d 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -1,3 +1,7 @@ +<<<<<<< HEAD +======= +# mypy: allow-untyped-defs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from __future__ import annotations import contextlib @@ -36,16 +40,22 @@ if TYPE_CHECKING: +<<<<<<< HEAD from collections.abc import Generator, Iterator +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from types import CodeType import sympy +<<<<<<< HEAD from torch._dynamo.backends.distributed import DDPOptimizerContext from torch._dynamo.codegen import PyCodegen from torch._functorch._aot_autograd.schemas import ViewAndMutationMeta from torch._subclasses.fake_tensor import FakeTensorMode +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ torch._guards is the definitional source of truth for general purpose guard structures. @@ -88,7 +98,11 @@ class CompileId: # TODO: consider also tracking the recompilation count # See Note: Updating CompileId +<<<<<<< HEAD def __str__(self) -> str: +======= + def __str__(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # NOTE: Keep this in sync with both from_string and the tlparse repo if self.compiled_autograd_id is not None: assert (self.frame_id is None) == (self.frame_compile_id is None) @@ -102,7 +116,11 @@ def __str__(self) -> str: return f"{self.frame_id}/{self.frame_compile_id}" @classmethod +<<<<<<< HEAD def from_string(cls, compile_id: Optional[str]) -> Optional[CompileId]: +======= + def from_string(cls, compile_id: Optional[str]): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Factory method that creates a CompileId from its string representation. Keep this in sync with the __str__ method. @@ -130,7 +148,11 @@ class TraceId(NamedTuple): # up by one attempt: int +<<<<<<< HEAD def __str__(self) -> str: +======= + def __str__(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Keep this in sync with tlparse repo if self.attempt == 0: return str(self.compile_id) @@ -190,7 +212,11 @@ def is_unspecialized_builtin_nn_module(self) -> bool: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE, ) +<<<<<<< HEAD def is_local(self) -> bool: +======= + def is_local(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self in ( GuardSource.LOCAL, GuardSource.LOCAL_SPECIALIZED_NN_MODULE, @@ -223,7 +249,11 @@ class SLoc: framework_loc: Optional[Union[traceback.FrameSummary, str]] maybe_user_loc: Optional[str] +<<<<<<< HEAD def __str__(self) -> str: +======= + def __str__(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) floc = ( self.framework_loc if isinstance(self.framework_loc, str) @@ -268,19 +298,32 @@ class Guard: guard_types: Optional[list[str]] = None code_list: Optional[list[str]] = None obj_weakref: Optional[object] = None +<<<<<<< HEAD guarded_class_weakref: Optional[weakref.ReferenceType[Any]] = None +======= + guarded_class_weakref: Optional[type] = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) stack: Optional[CapturedTraceback] = None user_stack: Optional[traceback.StackSummary] = None _hash: Optional[int] = None +<<<<<<< HEAD _unserializable: bool = False def __hash__(self) -> int: +======= + + def __hash__(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self._hash is None: self._hash = hash((self.name, self.source, id(self.create_fn))) return self._hash +<<<<<<< HEAD def sort_key(self) -> tuple[bool, int, int, str, int]: +======= + def sort_key(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Put the duplicate input guards at the end. The duplicate guards have # two sources while guard.name only considers one source. @@ -296,10 +339,17 @@ def sort_key(self) -> tuple[bool, int, int, str, int]: self.inner_create_fn().__code__.co_firstlineno, ) +<<<<<<< HEAD def __lt__(self, other: Guard) -> bool: return self.sort_key() < other.sort_key() def inner_create_fn(self) -> Callable[[GuardBuilderBase, Guard], Any]: +======= + def __lt__(self, other): + return self.sort_key() < other.sort_key() + + def inner_create_fn(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(self.create_fn, functools.partial): return self.create_fn.func else: @@ -314,7 +364,11 @@ def source(self) -> GuardSource: return self.originating_source.guard_source() @staticmethod +<<<<<<< HEAD def weakref_to_str(obj_weakref: object) -> str: +======= + def weakref_to_str(obj_weakref): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ This is a workaround of a Python weakref bug. @@ -338,7 +392,11 @@ def __getattr__(self, x): else: return str(obj_weakref) +<<<<<<< HEAD def __repr__(self) -> str: +======= + def __repr__(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) s = f""" {self.source.name.lower() if self.source else ""} {repr(self.name)} {self.inner_create_fn().__name__} {{ @@ -350,7 +408,11 @@ def __repr__(self) -> str: """ return s +<<<<<<< HEAD def __str__(self) -> str: +======= + def __str__(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) output = f"Name: {repr(self.name)}\n" source = self.source.name.lower() if self.source else "" output += f" Source: {source}\n" @@ -361,7 +423,11 @@ def __str__(self) -> str: output += f" Guarded Class Weakref: {self.guarded_class_weakref}\n" return output +<<<<<<< HEAD def create(self, builder: GuardBuilderBase) -> Any: +======= + def create(self, builder: GuardBuilderBase): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: return self.create_fn(builder, self) except Exception: @@ -370,6 +436,7 @@ def create(self, builder: GuardBuilderBase) -> Any: log.error("Created at:\n%s", "".join(self.stack.format()[-4:]).rstrip()) raise +<<<<<<< HEAD def is_specialized_nn_module(self) -> bool: return self.source.is_specialized_nn_module() @@ -393,6 +460,18 @@ def set_export_info( code_list: list[str], obj_weakref: object, ) -> None: +======= + def is_specialized_nn_module(self): + return self.source.is_specialized_nn_module() + + def is_fsdp_module(self): + return self.source.is_fsdp_module() + + def is_local(self): + return self.source.is_local() + + def set_export_info(self, guard_type, guarded_class, code_list, obj_weakref): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not self.guard_types: self.guard_types = [] @@ -447,7 +526,11 @@ class DuplicateInputs(GuardEnvExpr): input_source_a: Source input_source_b: Source +<<<<<<< HEAD def __post_init__(self) -> None: +======= + def __post_init__(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert self.input_source_a != self.input_source_b @@ -478,7 +561,11 @@ class StorageOverlap(GuardEnvExpr): can also be taken in at restore_graphstate(T) calls. When to snapshot, is, at the moment, an implementation detail of upstream callers. Checkpointable +<<<<<<< HEAD does not provide any guarantees around consistency, idempotency, or safety of calling its APIs, yet. +======= +does not provide any garuantees around consistency, idempotency, or safety of calling its APIs, yet. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) In the future, it will have a closer coupling to a generic Checkpoint management system. """ @@ -489,7 +576,11 @@ class Checkpointable(Generic[T]): def copy_graphstate(self) -> T: ... @abstractmethod +<<<<<<< HEAD def restore_graphstate(self, state: T) -> None: ... +======= + def restore_graphstate(self, state: T): ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class GuardsCheckpointState: @@ -499,10 +590,17 @@ class GuardsCheckpointState: dynamo_guards: set[Guard] = set() +<<<<<<< HEAD def __init__(self, dynamo_guards: set[Guard]) -> None: self.dynamo_guards = dynamo_guards def diff(self, other: GuardsCheckpointState) -> Optional[set[Guard]]: +======= + def __init__(self, dynamo_guards): + self.dynamo_guards = dynamo_guards + + def diff(self, other): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Produces a delta against another GuardsCheckpointState. @@ -514,19 +612,30 @@ def diff(self, other: GuardsCheckpointState) -> Optional[set[Guard]]: return None return r +<<<<<<< HEAD def __eq__(self, other: object) -> bool: if not isinstance(other, GuardsCheckpointState): return False +======= + def __eq__(self, other): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.diff(other) is None class ModuleContextCheckpointState: nn_modules: dict[str, torch.nn.Module] = {} +<<<<<<< HEAD def __init__(self, nn_modules: dict[str, torch.nn.Module]) -> None: self.nn_modules = nn_modules def diff(self, other: ModuleContextCheckpointState) -> Optional[set[str]]: +======= + def __init__(self, nn_modules): + self.nn_modules = nn_modules + + def diff(self, other): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Produces a delta against another ModuleContextCheckpointState. @@ -538,9 +647,13 @@ def diff(self, other: ModuleContextCheckpointState) -> Optional[set[str]]: return None return r +<<<<<<< HEAD def __eq__(self, other: object) -> bool: if not isinstance(other, ModuleContextCheckpointState): return False +======= + def __eq__(self, other): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.diff(other) is None @@ -548,21 +661,37 @@ class ModuleContext(Checkpointable[ModuleContextCheckpointState]): def __init__(self) -> None: self.nn_modules: dict[str, Any] = {} +<<<<<<< HEAD def copy_graphstate(self) -> ModuleContextCheckpointState: return ModuleContextCheckpointState(dict(self.nn_modules)) def restore_graphstate(self, state: ModuleContextCheckpointState) -> None: +======= + def copy_graphstate(self): + return ModuleContextCheckpointState(dict(self.nn_modules)) + + def restore_graphstate(self, state): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(state, ModuleContextCheckpointState) self.nn_modules = state.nn_modules class GlobalContextCheckpointState: +<<<<<<< HEAD global_state: dict[str, tuple[Callable, Any]] = {} def __init__(self, global_states: dict[str, tuple[Callable, Any]]) -> None: self.global_state = global_states def diff(self, other: GlobalContextCheckpointState) -> Optional[set[str]]: +======= + global_state: dict[str, tuple[Callable, ...]] = {} + + def __init__(self, global_states): + self.global_state = global_states + + def diff(self, other): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Produces a delta against another GlobalContextCheckpointState. @@ -574,9 +703,13 @@ def diff(self, other: GlobalContextCheckpointState) -> Optional[set[str]]: return None return r +<<<<<<< HEAD def __eq__(self, other: object) -> bool: if not isinstance(other, GlobalContextCheckpointState): return False +======= + def __eq__(self, other): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.diff(other) is None @@ -596,12 +729,21 @@ class GlobalContext(Checkpointable[GlobalContextCheckpointState]): } def __init__(self) -> None: +<<<<<<< HEAD self.global_state: dict[str, tuple[Callable, Any]] = {} def copy_graphstate(self) -> GlobalContextCheckpointState: return GlobalContextCheckpointState(self.global_state) def restore_graphstate(self, state: GlobalContextCheckpointState) -> None: +======= + self.global_state: dict[str, tuple[Callable, ...]] = {} + + def copy_graphstate(self): + return GlobalContextCheckpointState(dict(self.global_state)) + + def restore_graphstate(self, state): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(state, GlobalContextCheckpointState) self.global_state = state.global_state assert ( @@ -615,19 +757,31 @@ def restore_graphstate(self, state: GlobalContextCheckpointState) -> None: # Like a Set[Guard] but will record the user stack on all guards at the # time they were installed at their destination class GuardsSet: +<<<<<<< HEAD def __init__(self, inner: Optional[set[Guard]] = None) -> None: +======= + def __init__(self, inner=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if inner is None: inner = set() self.inner = inner +<<<<<<< HEAD def __iter__(self) -> Iterator[Guard]: return iter(self.inner) def __len__(self) -> int: +======= + def __iter__(self): + return iter(self.inner) + + def __len__(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return len(self.inner) # Subtraction along with bool is typically used to determine the delta of # added guards between checkpoints for higher order ops +<<<<<<< HEAD def __sub__(self, other: GuardsSet) -> GuardsSet: return GuardsSet(self.inner - other.inner) @@ -637,21 +791,42 @@ def __bool__(self) -> bool: def add( self, guard: Guard, *, collect_debug_stack: bool = True, skip: int = 0 ) -> None: +======= + def __sub__(self, other): + return GuardsSet(self.inner - other.inner) + + def __bool__(self): + return bool(self.inner) + + def add(self, guard: Guard, *, collect_debug_stack=True, skip=0): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if guard in self.inner: return if collect_debug_stack: if guard.stack is None: guard.stack = CapturedTraceback.extract(skip=1 + skip) +<<<<<<< HEAD if guard.user_stack is None: guard.user_stack = TracingContext.extract_stack() self.inner.add(guard) def update(self, *others: set[Guard]) -> None: +======= + if guard.user_stack is None: + guard.user_stack = TracingContext.extract_stack() + self.inner.add(guard) + + def update(self, *others: set[Guard]): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for o in others: for g in o: self.add(g, skip=1) +<<<<<<< HEAD def remove_guards_with_source(self, source: Source) -> None: +======= + def remove_guards_with_source(self, source): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Delete all guards that contains a given source""" from ._dynamo.source import is_from_source @@ -673,10 +848,17 @@ def __init__(self) -> None: self.dynamo_guards: GuardsSet = GuardsSet() self.aotautograd_guards: list[GuardEnvExpr] = [] +<<<<<<< HEAD def copy_graphstate(self) -> GuardsCheckpointState: return GuardsCheckpointState(set(self.dynamo_guards.inner)) def restore_graphstate(self, state: GuardsCheckpointState) -> None: +======= + def copy_graphstate(self): + return GuardsCheckpointState(set(self.dynamo_guards.inner)) + + def restore_graphstate(self, state): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # NB: "steals" the passed in state assert isinstance(state, GuardsCheckpointState) self.dynamo_guards = GuardsSet(state.dynamo_guards) @@ -684,12 +866,17 @@ def restore_graphstate(self, state: GuardsCheckpointState) -> None: class HopSubgraphCache: @abstractmethod +<<<<<<< HEAD def add_dynamo_installed_submodule(self, fn_id: int, identifier: str) -> None: ... +======= + def add_dynamo_installed_submodule(self, fn_id: int, identifier: str): ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @abstractmethod def get_dynamo_installed_submodules(self, fn_id: int) -> list[str]: ... @abstractmethod +<<<<<<< HEAD def add_autograd_key_entry(self, identifier: str, key: Callable) -> None: ... @abstractmethod @@ -700,6 +887,18 @@ def add_proxy_dispatch_entry(self, identifier: str, key: Callable) -> None: ... @abstractmethod def get_proxy_dispatch_entry(self, identifier: str) -> Optional[Callable]: ... +======= + def add_autograd_key_entry(self, identifier: str, key: Callable): ... + + @abstractmethod + def get_autograd_key_entry(self, identifier: str): ... + + @abstractmethod + def add_proxy_dispatch_entry(self, identifier: str, key: Callable): ... + + @abstractmethod + def get_proxy_dispatch_entry(self, identifier: str): ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @abstractmethod def add_lazy_bwd_entry( @@ -707,12 +906,20 @@ def add_lazy_bwd_entry( identifier: str, tangent_metadata: tuple[object], gmod: torch.fx.GraphModule, +<<<<<<< HEAD ) -> int: ... +======= + ): ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @abstractmethod def get_lazy_bwd_entry( self, identifier: str, tangent_metadata: tuple[object] +<<<<<<< HEAD ) -> tuple[Optional[torch.fx.GraphModule], Optional[int]]: ... +======= + ) -> int: ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class InvokeSubgraphCache(HopSubgraphCache): @@ -724,12 +931,17 @@ def __init__(self) -> None: str, dict[tuple[object], tuple[torch.fx.GraphModule, int]] ] = defaultdict(dict) +<<<<<<< HEAD def add_dynamo_installed_submodule(self, fn_id: int, identifier: str) -> None: +======= + def add_dynamo_installed_submodule(self, fn_id: int, identifier: str): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.dynamo_installed_submodules[fn_id].append(identifier) def get_dynamo_installed_submodules(self, fn_id: int) -> list[str]: return self.dynamo_installed_submodules.get(fn_id, []) +<<<<<<< HEAD def add_autograd_key_entry(self, identifier: str, key: Callable) -> None: self.autograd_cache[identifier] = key @@ -740,6 +952,18 @@ def add_proxy_dispatch_entry(self, identifier: str, key: Callable) -> None: self.proxy_dispatch_cache[identifier] = key def get_proxy_dispatch_entry(self, identifier: str) -> Optional[Callable]: +======= + def add_autograd_key_entry(self, identifier: str, key: Callable): + self.autograd_cache[identifier] = key + + def get_autograd_key_entry(self, identifier: str): + return self.autograd_cache.get(identifier, None) + + def add_proxy_dispatch_entry(self, identifier: str, key: Callable): + self.proxy_dispatch_cache[identifier] = key + + def get_proxy_dispatch_entry(self, identifier: str): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.proxy_dispatch_cache.get(identifier, None) def add_lazy_bwd_entry( @@ -747,15 +971,23 @@ def add_lazy_bwd_entry( identifier: str, tangent_metadata: tuple[object], gmod: torch.fx.GraphModule, +<<<<<<< HEAD ) -> int: +======= + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Save the number of existing graph modules in the dictionary to get the suffix num_gmods = len(self.lazy_bwd_cache[identifier]) self.lazy_bwd_cache[identifier][tangent_metadata] = (gmod, num_gmods) return num_gmods +<<<<<<< HEAD def get_lazy_bwd_entry( self, identifier: str, tangent_metadata: tuple[object] ) -> tuple[Optional[torch.fx.GraphModule], Optional[int]]: +======= + def get_lazy_bwd_entry(self, identifier: str, tangent_metadata: tuple[object]): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if identifier not in self.lazy_bwd_cache: return (None, None) @@ -808,7 +1040,11 @@ def get() -> CompileContext: def try_get() -> Optional[CompileContext]: return getattr(_TLS, "compile_context", None) +<<<<<<< HEAD def __init__(self, compile_id: Optional[CompileId]) -> None: +======= + def __init__(self, compile_id): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert compile_id is None or isinstance(compile_id, CompileId) self.compile_id: Optional[CompileId] = compile_id self.attempt = 0 @@ -816,14 +1052,22 @@ def __init__(self, compile_id: Optional[CompileId]) -> None: self.shape_env_guards: list[str] = [] @staticmethod +<<<<<<< HEAD def current_compile_id() -> Optional[CompileId]: +======= + def current_compile_id(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self = CompileContext.try_get() if self is None: return None return self.compile_id @staticmethod +<<<<<<< HEAD def current_trace_id() -> Optional[TraceId]: +======= + def current_trace_id(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self = CompileContext.try_get() if self is None: return None @@ -852,6 +1096,7 @@ def get() -> TracingContext: "TracingContext.get() must be called within an ongoing trace." ) +<<<<<<< HEAD def __init__(self, fake_mode: Optional[FakeTensorMode]) -> None: self.guards_context = GuardsContext() self.module_context = ModuleContext() @@ -860,12 +1105,23 @@ def __init__(self, fake_mode: Optional[FakeTensorMode]) -> None: self.previously_cleaned_instructions: dict[Any, Any] = dict() self.fake_mode: Optional[FakeTensorMode] = fake_mode self.frame_summary_stack: list[traceback.FrameSummary] = [] +======= + def __init__(self, fake_mode): + self.guards_context = GuardsContext() + self.module_context = ModuleContext() + self.global_context = GlobalContext() + self.previously_inlined_functions = dict() + self.previously_cleaned_instructions = dict() + self.fake_mode = fake_mode + self.frame_summary_stack = [] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This is morally part of frame_summary_stack, but it is kept separate # for clarity. As we process a frame, this variable gets updated # to keep track of what line we are in the function. We make a # function call, this gets cleared and the frame location is pushed # to frame_summary_stack (prepping this variable for the inner frame's # progress) +<<<<<<< HEAD self.loc_in_frame: Optional[tuple[str, int, str]] = None # this is only set after aot_autograd self.fw_metadata: Optional[ViewAndMutationMeta] = None @@ -876,6 +1132,16 @@ def __init__(self, fake_mode: Optional[FakeTensorMode]) -> None: self.params_flat: Optional[list[Any]] = None self.params_flat_unwrap_subclasses: Optional[list[Any]] = None self.params_unwrapped_to_flat_index: Optional[list[Any]] = None +======= + self.loc_in_frame = None + # this is only set after aot_autograd + self.fw_metadata = None + # this is only set after aot_autograd + self.aot_graph_name = None + self.params_flat = None + self.params_flat_unwrap_subclasses = None + self.params_unwrapped_to_flat_index = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # this is for extended return calling convention from backend # compiler to aot_autograd # Per output, what the compiler specified stride of the output is, @@ -895,7 +1161,11 @@ def __init__(self, fake_mode: Optional[FakeTensorMode]) -> None: # See note [Tensor Fakification and Symbol Caching] self.tensor_to_context = WeakTensorKeyDictionary() +<<<<<<< HEAD # If this true, Aot Autograd will return output Fake Tensors with appropriate +======= + # If this true, Aot Autograd will return output Fake Tensors with appropiate +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # meta on the first invocation # see note: [Returning Fake Tensors on First AOT Autograd Call] self.fakify_first_call = False @@ -903,7 +1173,11 @@ def __init__(self, fake_mode: Optional[FakeTensorMode]) -> None: # list of code objects for inlined functions self.traced_code: list[CodeType] = [] +<<<<<<< HEAD def clear(self) -> None: +======= + def clear(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Look at the note in output_graph.py in function `save_global_state` # for the context on clearing global context. self.global_context.global_state = {} @@ -912,7 +1186,11 @@ def clear(self) -> None: @staticmethod @contextmanager +<<<<<<< HEAD def patch(**kwargs: Any) -> Generator[None, None, None]: +======= + def patch(**kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) prior = {} ctx = TracingContext.get() @@ -928,7 +1206,11 @@ def patch(**kwargs: Any) -> Generator[None, None, None]: setattr(ctx, key, val) @staticmethod +<<<<<<< HEAD def extract_stack() -> traceback.StackSummary: +======= + def extract_stack(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self = TracingContext.try_get() if self is None: return traceback.StackSummary() @@ -937,7 +1219,11 @@ def extract_stack() -> traceback.StackSummary: stack = stack + [self._populate_loc_in_frame_summary()] return traceback.StackSummary.from_list(stack) +<<<<<<< HEAD def _populate_loc_in_frame_summary(self) -> traceback.FrameSummary: +======= + def _populate_loc_in_frame_summary(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert self.loc_in_frame is not None filename, lineno, frame_name = self.loc_in_frame return traceback.FrameSummary(filename, lineno, frame_name, lookup_line=False) @@ -946,7 +1232,11 @@ def _populate_loc_in_frame_summary(self) -> traceback.FrameSummary: # associated with the current frame state @staticmethod @contextlib.contextmanager +<<<<<<< HEAD def clear_frame() -> Generator[None, None, None]: +======= + def clear_frame(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tc = TracingContext.get() with ( unittest.mock.patch.object(tc, "frame_summary_stack", []), @@ -978,9 +1268,13 @@ def clear_frame() -> Generator[None, None, None]: @staticmethod @contextlib.contextmanager +<<<<<<< HEAD def current_frame( frame_summary: Optional[traceback.FrameSummary], ) -> Generator[None, None, None]: +======= + def current_frame(frame_summary): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # frame_summary can be None to solely take advantage of real_stack # attachment to thrown exceptions tc = TracingContext.get() @@ -1001,9 +1295,13 @@ def current_frame( @staticmethod @contextlib.contextmanager +<<<<<<< HEAD def report_output_strides() -> Generator[ Optional[list[Optional[tuple[int, ...]]]], None, None ]: +======= + def report_output_strides(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tc = TracingContext.try_get() if tc is None: yield None @@ -1016,13 +1314,21 @@ def report_output_strides() -> Generator[ tc.output_strides = old_output_strides @staticmethod +<<<<<<< HEAD def set_current_loc(filename: str, lineno: int, frame_name: str) -> None: +======= + def set_current_loc(filename, lineno, frame_name): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Save the current location in the frame. Lazily generate the # framesummary. TracingContext.get().loc_in_frame = (filename, lineno, frame_name) @staticmethod +<<<<<<< HEAD def get_traced_code() -> Optional[list[CodeType]]: +======= + def get_traced_code(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tc = TracingContext.try_get() if tc is None: return None @@ -1030,9 +1336,13 @@ def get_traced_code() -> Optional[list[CodeType]]: @contextmanager +<<<<<<< HEAD def compile_context( context: Optional[CompileContext], ) -> Generator[Optional[CompileContext], None, None]: +======= +def compile_context(context: Optional[CompileContext]): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) old_context = getattr(_TLS, "compile_context", None) _TLS.compile_context = context try: @@ -1042,9 +1352,13 @@ def compile_context( @contextmanager +<<<<<<< HEAD def tracing( context: Optional[TracingContext], ) -> Generator[Optional[TracingContext], None, None]: +======= +def tracing(context: Optional[TracingContext]): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ This function installs the passed in tracing context as a dynamic scoped global variable. @@ -1074,6 +1388,7 @@ def tracing( # TODO(voz): Consider a toplevel torch/_source.py @dataclasses.dataclass(frozen=True) class Source: +<<<<<<< HEAD def is_dict_key(self) -> bool: return False @@ -1081,6 +1396,15 @@ def is_ephemeral(self) -> bool: return False def reconstruct(self, codegen: PyCodegen) -> None: +======= + def is_dict_key(self): + return False + + def is_ephemeral(self): + return False + + def reconstruct(self, codegen): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise NotImplementedError def guard_source(self) -> GuardSource: @@ -1089,7 +1413,11 @@ def guard_source(self) -> GuardSource: def name(self) -> str: raise NotImplementedError +<<<<<<< HEAD def make_guard(self, fn: Callable[..., Any]) -> Guard: +======= + def make_guard(self, fn) -> Guard: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.guard_source() is GuardSource.CONSTANT: raise NotImplementedError return Guard(self, fn) @@ -1097,7 +1425,11 @@ def make_guard(self, fn: Callable[..., Any]) -> Guard: def is_specialized_nn_module(self) -> bool: return self.guard_source().is_specialized_nn_module() +<<<<<<< HEAD def subguards_allowed(self) -> bool: +======= + def subguards_allowed(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """True if you can guard on attributes of this""" return self.guard_source() != GuardSource.SYNTHETIC_LOCAL @@ -1107,11 +1439,19 @@ def subguards_allowed(self) -> bool: class ChainedSource(Source): base: Source +<<<<<<< HEAD def is_dict_key(self) -> bool: # Recurse until you either hit a ConstDictKey or a Source return self.base.is_dict_key() def is_ephemeral(self) -> bool: +======= + def is_dict_key(self): + # Recurse until you either hit a ConstDictKey or a Source + return self.base.is_dict_key() + + def is_ephemeral(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.base.is_ephemeral() def get_base(self) -> Source: @@ -1121,7 +1461,11 @@ def get_base(self) -> Source: return current +<<<<<<< HEAD def detect_fake_mode(inputs: Any = None) -> Optional[FakeTensorMode]: +======= +def detect_fake_mode(inputs: Any = None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Attempts to "detect" what the current fake mode is. If there is one ambiently available from TracingContext, we preferentially use that. Otherwise, we @@ -1165,7 +1509,11 @@ def detect_fake_mode(inputs: Any = None) -> Optional[FakeTensorMode]: return None +<<<<<<< HEAD def active_fake_mode() -> Optional[FakeTensorMode]: +======= +def active_fake_mode(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Inspects the dispatch mode stack for an active fake mode and returns it. Returns None if no fake mode is active. diff --git a/torch/_higher_order_ops/__init__.py b/torch/_higher_order_ops/__init__.py index e809c729dc424..aaad7bcb2966b 100644 --- a/torch/_higher_order_ops/__init__.py +++ b/torch/_higher_order_ops/__init__.py @@ -27,10 +27,14 @@ from torch._higher_order_ops.scan import scan from torch._higher_order_ops.strict_mode import strict_mode from torch._higher_order_ops.torchbind import call_torchbind +<<<<<<< HEAD from torch._higher_order_ops.while_loop import ( while_loop, while_loop_stack_output_op as while_loop_stack_output, ) +======= +from torch._higher_order_ops.while_loop import while_loop +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._higher_order_ops.wrap import ( dynamo_bypassing_wrapper, tag_activation_checkpoint, @@ -72,5 +76,8 @@ "strict_mode", "aoti_call_delegate", "map", +<<<<<<< HEAD "while_loop_stack_output", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] diff --git a/torch/_higher_order_ops/aoti_call_delegate.py b/torch/_higher_order_ops/aoti_call_delegate.py index bb2c62de7617a..cb618a117d6ce 100644 --- a/torch/_higher_order_ops/aoti_call_delegate.py +++ b/torch/_higher_order_ops/aoti_call_delegate.py @@ -156,9 +156,13 @@ def call_delegate_functionalize( ) with ctx.redispatch_to_next(): res = aoti_call_delegate( +<<<<<<< HEAD lowered_module, original_gm, unwrapped_weight_args, # type: ignore[arg-type] unwrapped_input_args, # type: ignore[arg-type] +======= + lowered_module, original_gm, unwrapped_weight_args, unwrapped_input_args # type: ignore[arg-type] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return ctx.wrap_tensors(res) diff --git a/torch/_higher_order_ops/associative_scan.py b/torch/_higher_order_ops/associative_scan.py index fa59ee244fec1..b79dbdd106f49 100644 --- a/torch/_higher_order_ops/associative_scan.py +++ b/torch/_higher_order_ops/associative_scan.py @@ -5,11 +5,16 @@ import torch import torch._prims_common as utils +<<<<<<< HEAD +======= +import torch._subclasses.functional_tensor +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch.utils._pytree as pytree from torch._C import DispatchKey from torch._higher_order_ops.utils import ( _maybe_compile_and_run_fn, _maybe_run_with_interpreter, +<<<<<<< HEAD check_input_alias_and_mutation_return_outputs, check_meta_consistency, create_bw_fn, @@ -20,6 +25,12 @@ save_tensors_and_symints_for_backward, saved_tensors_and_symints, split_into_chunks, +======= + autograd_not_implemented, + check_meta_consistency, + first_slice_copy, + reenter_make_fx, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) unique_graph_id, validate_subgraph_args_types, ) @@ -36,9 +47,15 @@ def wrap_combine_fn_flat(*args, combine_fn, spec, num_leaves): +<<<<<<< HEAD assert len(args) == 2 * num_leaves, ( f"Combin_fn received wrong number of arguments, expected {2 * num_leaves}, but got {len(args)}" ) +======= + assert ( + len(args) == 2 * num_leaves + ), f"Combin_fn received wrong number of arguments, expected {2 * num_leaves}, but got {len(args)}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) lhs = pytree.tree_unflatten(args[:num_leaves], spec) rhs = pytree.tree_unflatten(args[num_leaves:], spec) return combine_fn(lhs, rhs) @@ -84,9 +101,15 @@ def __call__(self, combine_fn, xs, additional_inputs): # the additional_inputs being a list. See https://github.com/pytorch/pytorch/issues/145785 # Once this issue is resolved, the assertion should only allow tuples # and the tuple cast should be removed +<<<<<<< HEAD assert isinstance(additional_inputs, (tuple, list)), ( "additional_inputs must be a tuple." ) +======= + assert isinstance( + additional_inputs, (tuple, list) + ), "additional_inputs must be a tuple." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) additional_inputs = ( tuple(additional_inputs) if isinstance(additional_inputs, list) @@ -95,6 +118,7 @@ def __call__(self, combine_fn, xs, additional_inputs): validate_subgraph_args_types(additional_inputs) return super().__call__(combine_fn, xs, additional_inputs) +<<<<<<< HEAD def gen_schema(self, combine_fn, xs, additional_inputs): from torch._higher_order_ops.schema import HopSchemaGenerator from torch._higher_order_ops.utils import materialize_as_graph @@ -147,6 +171,8 @@ def gen_schema(self, combine_fn, xs, additional_inputs): schema_gen.add_schema_tree_spec(combine_fn, xs, additional_inputs) return schema_gen.gen_schema() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) associative_scan_op = AssociativeScanOp() @@ -191,6 +217,7 @@ def associative_scan( def add(x: torch.Tensor, y: torch.Tensor): return x + y +<<<<<<< HEAD cumsum = associative_scan(add, x, dim) @@ -198,6 +225,11 @@ def add(x: torch.Tensor, y: torch.Tensor): # TODO: Support lifted arguments in inductor for associative_scan # TODO: Support autograd for cases with lifted arguments for combine_mode=pointwise +======= + cumsum = associative_scan(add, x, dim) + + """ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # The reason we flatten xs before calling into dynamo is that # we want to create a consistent input ordering for combine_fn # and we also want to the input ordering matches the output ordering. @@ -249,6 +281,12 @@ def _validate_input(cfn, lxs, d, r, cm): if reverse: leaves_xs = [torch.flip(elem, [0]) for elem in leaves_xs] +<<<<<<< HEAD +======= + # TODO: Support Autograd + # TODO: Unify handling of pytrees for control flow ops, such as cond, while_loop, etc. + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if combine_mode == "generic": # The generic_associative_scan implementation calls the combine_fn with a `batch` along the scan dimension # For example, consider: @@ -435,9 +473,15 @@ def trace_associative_scan( assert outputs is not None outputs = pytree.tree_leaves(outputs) +<<<<<<< HEAD assert len(outputs) == len(xs), ( f"expected combine_fn to return {len(xs)} results but got {len(outputs)}" ) +======= + assert len(outputs) == len( + xs + ), f"expected combine_fn to return {len(xs)} results but got {len(outputs)}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) xs_fake_tensors: list[torch.Tensor | torch.SymInt | int] = [ first_slice_copy(x) for x in xs @@ -472,6 +516,7 @@ def associative_scan_op_dense(combine_fn, xs, additional_inputs): return generic_associative_scan(combine_fn, xs, additional_inputs=additional_inputs) +<<<<<<< HEAD class AssociativeScanAutogradOp(torch.autograd.Function): r""" associative_scan Example:: @@ -844,6 +889,11 @@ def associative_scan_autograd(combine_fn, xs, additional_inputs): *(tuple(xs) + tuple(additional_inputs)), ) return (*flat_out,) +======= +associative_scan_op.py_autograd_impl( + autograd_not_implemented(associative_scan_op, deferred_error=True) +) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @associative_scan_op.py_impl(ProxyTorchDispatchMode) diff --git a/torch/_higher_order_ops/auto_functionalize.py b/torch/_higher_order_ops/auto_functionalize.py index d5aa0d09c8b18..337df905da9ad 100644 --- a/torch/_higher_order_ops/auto_functionalize.py +++ b/torch/_higher_order_ops/auto_functionalize.py @@ -3,7 +3,11 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from dataclasses import dataclass +<<<<<<< HEAD from typing import Any, Callable, get_args, Optional, Union +======= +from typing import Any, get_args, Optional, Union +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch._library.utils as library_utils @@ -522,8 +526,12 @@ def do_auto_functionalize( ) with ctx.redispatch_to_next(): unwrapped_outs = auto_functionalized( +<<<<<<< HEAD op, **unwrapped_kwargs, # type: ignore[arg-type] +======= + op, **unwrapped_kwargs # type: ignore[arg-type] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # List of the name of args that get mutated (according to the schema) @@ -572,6 +580,7 @@ def sync_update(o, orig_arg): return ctx.wrap_tensors(unwrapped_actual_out) # type: ignore[arg-type] +<<<<<<< HEAD # Wrapper for GraphModule that applies functionalization during execution to enable # epilogue graph inlining and better fusion opportunities in subgraphs # When tracing this wrapper, we'll get a graph module with epilogue. @@ -594,6 +603,8 @@ def __hash__(self): return id(self.orig_callable) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def do_auto_functionalize_v2( mode: "torch._subclasses.functional_tensor.FunctionalTensorMode", op: Union[OpOverload, HopInstance], @@ -614,7 +625,23 @@ def do_auto_functionalize_v2( def _functionalize_callable(arg: Any): if callable(arg): +<<<<<<< HEAD return FunctionalCallableWithEpilogue(arg) +======= + + def functional_fn(*args, **kwargs): + # We call torch.func.functionalize. This allows us to inline the epilogue graph. + # Inlining has the benefit of allowing easiser fusion inside subgraph. + # Though the epilogue graph contains copy_, it is OK becuase inductor can handle it + # and this is also how we have been supporting top-level graph input mutation. + return tuple( + pytree.tree_leaves(torch.func.functionalize(arg)(*args, **kwargs)) + ) + + return torch._higher_order_ops.base_hop.FunctionWithNoFreeVars( + functional_fn + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return arg args, kwargs = pytree.tree_map(_functionalize_callable, (args, kwargs)) @@ -705,8 +732,12 @@ def set_result(base_index): with ctx.redispatch_to_next(): unwrapped_outs = auto_functionalized_v2( +<<<<<<< HEAD op, **auto_func_kwargs, # type: ignore[arg-type] +======= + op, **auto_func_kwargs # type: ignore[arg-type] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) unwrapped_actual_out: Union[Any, tuple[Any]] = ( @@ -718,9 +749,15 @@ def set_result(base_index): ) if isinstance(op, HigherOrderOperator): +<<<<<<< HEAD assert len(schema.returns) > 0, ( f"hop is expected to return at least one output {schema}." ) +======= + assert ( + len(schema.returns) > 0 + ), f"hop is expected to return at least one output {schema}." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert len(unwrapped_actual_out) == len(schema.returns) else: if len(schema.returns) == 0: @@ -944,7 +981,11 @@ def auto_functionalized_v2_proxy( # Below code materializes the callable inputs to the hop as graph modules. # kwargs may contain general callables, that are not proxable e.g. FunctionWithNoFreeVars # this could happen when we auto_functionalize the backward of the hop, +<<<<<<< HEAD # where backward fn is a callablle that wraps forward graph module. +======= + # where backward fn is a callablle that wrapps forward graph module. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This function materialize the callable args according to the schema of the hop. # We cannot materialize the callables in kwargs directly because the inputs to callable diff --git a/torch/_higher_order_ops/base_hop.py b/torch/_higher_order_ops/base_hop.py index 11826c3f6369b..baa89d3cd4bf8 100644 --- a/torch/_higher_order_ops/base_hop.py +++ b/torch/_higher_order_ops/base_hop.py @@ -6,7 +6,10 @@ import torch.utils._pytree as pytree from torch._C import DispatchKey from torch._dispatch.python import suspend_functionalization +<<<<<<< HEAD from torch._higher_order_ops.auto_functionalize import FunctionalCallableWithEpilogue +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._higher_order_ops.utils import ( check_input_alias_and_mutation_return_outputs, HopInstance, @@ -40,6 +43,7 @@ class InvokeQuant(BaseHOP): def __init__(self): return super().__init__("invoke_quant") +<<<<<<< HEAD invoke_quant = InvokeQuant() @@ -48,6 +52,13 @@ def g(x): return x.sin().cos() +======= + invoke_quant = InvokeQuant() + + def g(x): + return x.sin().cos() + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch.compile(backend="aot_eager") def f(x): return invoke_quant(g, x, scheme="nf4") @@ -71,6 +82,7 @@ def __init__(self, hop_name) -> None: ) def __call__(self, subgraph, *operands, **kwargs): +<<<<<<< HEAD if not isinstance( subgraph, ( @@ -79,6 +91,9 @@ def __call__(self, subgraph, *operands, **kwargs): FunctionalCallableWithEpilogue, ), ): +======= + if not isinstance(subgraph, (torch.fx.GraphModule, FunctionWithNoFreeVars)): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise RuntimeError( f"{self._name}: when calling this API without torch.compile, " f"we require that the subgraph be a torch.fx.GraphModule (or " @@ -116,10 +131,14 @@ def _call_ProxyTorchDispatchMode(self, proxy_mode, subgraph, *operands, **kwargs out = self(subgraph, *operands, **kwargs) return track_tensor_tree( +<<<<<<< HEAD out, out_proxy, constant=None, tracer=proxy_mode.tracer, # type: ignore[arg-type] +======= + out, out_proxy, constant=None, tracer=proxy_mode.tracer # type: ignore[arg-type] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def _call_FakeTensorMode(self, mode, subgraph, *operands, **kwargs): @@ -198,7 +217,11 @@ def gen_schema(self, subgraph, *operands, **kwargs): import warnings warnings.warn( +<<<<<<< HEAD "Aliasing is not supported for HOP subgraph.\n" +======= + "Aliasing is not suppported for HOP subgraph.\n" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f"{subgraph.print_readable(print_output=False)}\n" f"Alias info: inp-inp alias: {inp_inp_alias}, inp-out alias: {inp_out_alias}, out-out alias{out_out_alias}" f"This may lead to silent incorrectness." @@ -236,11 +259,15 @@ def backward(ctx, *grad_outputs): kwargs = ctx.kwargs # TODO: Something special needs to happen with min cut partitioner +<<<<<<< HEAD with ( suspend_functionalization(), disable_functional_mode(), torch.enable_grad(), ): +======= + with suspend_functionalization(), disable_functional_mode(), torch.enable_grad(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with disable_proxy_modes_tracing(): from .invoke_subgraph import create_fw_bw_graph from .utils import _from_fun diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py index 7c13b9a0fd147..4f479f3766610 100644 --- a/torch/_higher_order_ops/cond.py +++ b/torch/_higher_order_ops/cond.py @@ -1,12 +1,19 @@ # mypy: allow-untyped-decorators # mypy: allow-untyped-defs import contextlib +<<<<<<< HEAD import functools +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import logging import warnings from typing import Any, Callable, Optional, Union import torch +<<<<<<< HEAD +======= +import torch._subclasses.functional_tensor +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch.utils._pytree as pytree from torch._C import DispatchKey from torch._C._functorch import ( @@ -19,10 +26,13 @@ from torch._higher_order_ops.utils import ( _maybe_run_with_interpreter, _set_compilation_env, +<<<<<<< HEAD check_input_alias_and_mutation_return_outputs, create_bw_fn, fill_none_with_masks, filter_with_masks, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) materialize_as_graph, reenter_make_fx, save_tensors_and_symints_for_backward, @@ -40,6 +50,11 @@ ) from torch.utils._python_dispatch import _get_current_dispatch_mode +<<<<<<< HEAD +======= +from .utils import clone_outputs_aliasing_inputs + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) log = logging.getLogger(__name__) @@ -57,6 +72,7 @@ def __call__(self, pred, true_fn, false_fn, operands): validate_subgraph_args_types(operands) return super().__call__(pred, true_fn, false_fn, operands) +<<<<<<< HEAD def gen_schema(self, pred, true_fn, false_fn, operands): from torch._higher_order_ops.schema import HopSchemaGenerator from torch._higher_order_ops.utils import materialize_as_graph @@ -103,6 +119,8 @@ def gen_schema(self, pred, true_fn, false_fn, operands): schema_gen.add_schema_tree_spec(pred, true_fn, false_fn, operands) return schema_gen.gen_schema() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cond_op = CondOp() @@ -155,12 +173,17 @@ def cond(pred, true_branch, false_branch, operands): def true_fn(x: torch.Tensor): return x.cos() +<<<<<<< HEAD def false_fn(x: torch.Tensor): return x.sin() +======= + def false_fn(x: torch.Tensor): + return x.sin() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return cond(x.shape[0] > 4, true_fn, false_fn, (x,)) Restrictions: @@ -234,6 +257,7 @@ def _validate_input(pred, true_fn, false_fn, operands): def _cond_op_wrapper(*args, **kwargs): return cond_op(*args, **kwargs) +<<<<<<< HEAD with ( _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(), @@ -244,6 +268,12 @@ def _cond_op_wrapper(*args, **kwargs): backend: Union[str, Callable[..., Any]] = ( make_eager_backend_with_torch_function_mode(metadata_mode) ) +======= + with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(), _temp_remove_pre_dispatch_torch_function_mode(): + with _temp_remove_metadata_torch_function_mode() as metadata_mode: + if metadata_mode: + backend = make_eager_backend_with_torch_function_mode(metadata_mode) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: backend = "eager" return torch.compile(_cond_op_wrapper, backend=backend, fullgraph=True)( @@ -251,10 +281,71 @@ def _cond_op_wrapper(*args, **kwargs): ) +<<<<<<< HEAD def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands): assert isinstance(operands, (list, tuple)), ( f"Cond operands must be a list or tuple of tensors and SymInts {operands}" ) +======= +def create_bw_fn(fn: Callable, args: tuple[Any]) -> Callable: + """ + For a fn that accepts flat inputs and returns flat outputs: + fw_out = fn(*args), + this function returns: + grad_args = bw_fn(*args_and_grad_output) + with the following invariants: + 1. args + fw_out has an 1-1 correspondence to args_and_grad_output + 2. grad_args has an 1-1 corresponsence to args + 3. for tensor arg whose requires_grad is False, its corresponding grad in + grad_args will be a zero tensor with the same shape. + """ + + from torch._functorch.aot_autograd import AOTConfig, create_joint + from torch._higher_order_ops.utils import prepare_fw_with_masks_all_requires_grad + + dummy_aot_config = AOTConfig( + fw_compiler=None, # type: ignore[arg-type] + bw_compiler=None, # type: ignore[arg-type] + partition_fn=None, # type: ignore[arg-type] + decompositions={}, + num_params_buffers=0, + aot_id=0, + keep_inference_input_mutations=False, + ) + n_primals = len(args) + + bw_fn = create_joint( + prepare_fw_with_masks_all_requires_grad(fn), aot_config=dummy_aot_config + ) + + def flat_fn(*args_and_grad_outs): + primals = args_and_grad_outs[:n_primals] + tangents = args_and_grad_outs[n_primals:] + grad_args = bw_fn(primals, tangents)[1] + assert len(args) == len(grad_args) + # In order to keep HOPs functional where the backward graph, + # would have outputs that are aliasing inputs. + # For example in cases where the backward of the function is simply + # passing the upstream gradients through. + maybe_clone = clone_outputs_aliasing_inputs(args_and_grad_outs) + + return [ + ( + torch.zeros_like(arg) + if isinstance(arg, torch.Tensor) and grad is None + else maybe_clone(grad) + ) + for grad, arg in zip(grad_args, primals) + ] + + return flat_fn + + +def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands): + assert isinstance( + operands, (list, tuple) + ), f"Cond operands must be a list or tuple of tensors and SymInts {operands}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) true_graph = reenter_make_fx(true_fn)(*operands) false_graph = reenter_make_fx(false_fn)(*operands) @@ -301,9 +392,15 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands): @cond_op.py_impl(DispatchKey.CompositeExplicitAutograd) def cond_op_dense(pred, true_fn, false_fn, operands): +<<<<<<< HEAD assert all(isinstance(o, (torch.Tensor, int)) for o in operands), ( f"Dense implementation operands must be a list of tensors and ints {operands}" ) +======= + assert all( + isinstance(o, (torch.Tensor, int)) for o in operands + ), f"Dense implementation operands must be a list of tensors and ints {operands}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mode = _get_current_dispatch_mode() assert mode is None, "Mode should never be enabled for CPU/CUDA key" if pred: @@ -344,6 +441,7 @@ def backward(ctx, *flat_grads): operands = saved_tensors_and_symints(ctx) args = operands + flat_grads # TODO: we need to materialize the bw graphs because dynamo is unable to +<<<<<<< HEAD # trace through the joint function when torch.compile torch.autograd.grad. grads_tensor_masks = [] @@ -364,13 +462,22 @@ def wrapped(*args): true_bw_gm = materialize_as_graph( create_fn_remove_none(ctx._true_bw_fn), +======= + # trace through the joint funcion when torch.compile torch.autograd.grad. + true_bw_gm = materialize_as_graph( + ctx._true_bw_fn, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) args, ctx._fw_include_key_set, ctx._fw_exclude_key_set, force_enable_grad=True, ) false_bw_gm = materialize_as_graph( +<<<<<<< HEAD create_fn_remove_none(ctx._false_bw_fn), +======= + ctx._false_bw_fn, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) args, ctx._fw_include_key_set, ctx._fw_exclude_key_set, @@ -382,7 +489,11 @@ def wrapped(*args): false_bw_gm, args, ) +<<<<<<< HEAD return None, None, None, *fill_none_with_masks(grads, grads_tensor_masks) +======= + return None, None, None, *grads +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Note: @@ -544,11 +655,19 @@ def _has_unbacked_symbols(s: Union[int, torch.SymInt]) -> bool: """ This follows the logic in symbolic_shapes._compute_symbolic_stride +<<<<<<< HEAD Step 2: Since tensor stride is an accumulative multiplication of the sizes, which is a permutated (due to view ops) non-descending sequence. Case 1: No size is 1. In this case, strides have unique values. For example, suppose we have a tensor with: +======= + Step 2: Since tensor stride is an accumulative muliplication of the sizes, which is a permutated + (due to view ops) non-decending sequence. + + Case 1: No size is 1. In this case, strides have unique values. + For example, suppose we have a tenosr with: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) size [3, 4, 3, 5, 4, 5], stride (1200, 300, 1, 12, 3, 60), merged_size [u0, u1, u2, u3, u4, u5]. @@ -565,7 +684,11 @@ def _has_unbacked_symbols(s: Union[int, torch.SymInt]) -> bool: ... Case 2: At least one dimension has size 1, which can produce duplicates in strides. +<<<<<<< HEAD In this case, theoretically, we cannot uniquely determine the expr of strides because +======= + In this case, theorectically, we cannot uniquely determine the expr of strides because +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) the accessing stride_expr with same key in different order causes the final stride expression to be different. @@ -575,7 +698,11 @@ def _has_unbacked_symbols(s: Union[int, torch.SymInt]) -> bool: merged_size: (u0, u1) The stride expr could either be (u1, 1) or (1, u0) depending on whether we start with u1 or u0. +<<<<<<< HEAD For this reason, we try to break tie by sorting via descending index so we always get (u1, 1). +======= + For this reason, we try to break tie by sorting via decending index so we always get (u1, 1). +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Note that backend might optimize the strides anyway so this is usually not a problem as long as two branches matches. See relevant discussions in https://github.com/pytorch/pytorch/issues/142024. @@ -648,9 +775,15 @@ def _maybe_expr(s: Union[int, torch.SymInt]): if _maybe_expr(a_val) in a_stride_expr: a_expr = a_stride_expr[_maybe_expr(a_val)] +<<<<<<< HEAD assert b_stride_expr[_maybe_expr(b_val)] == a_expr, ( f"a_stride_expr:{a_stride_expr}, b_stride_expr:{b_stride_expr}" ) +======= + assert ( + b_stride_expr[_maybe_expr(b_val)] == a_expr + ), f"a_stride_expr:{a_stride_expr}, b_stride_expr:{b_stride_expr}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) merged_strides[i] = a_expr else: if a_val == 1: @@ -707,12 +840,21 @@ def cond_func(ctx, pred, true_fn, false_fn, inputs): @cond_op.py_impl(torch._C._functorch.TransformType.Vmap) def cond_batch_rule(interpreter, pred, true_fn, false_fn, inputs): +<<<<<<< HEAD assert isinstance(inputs, (list, tuple)), ( "Cond inputs must be a list or tuple of tensors" ) assert all(isinstance(i, torch.Tensor) for i in inputs), ( "Cond inputs must be a list of tensors" ) +======= + assert isinstance( + inputs, (list, tuple) + ), "Cond inputs must be a list or tuple of tensors" + assert all( + isinstance(i, torch.Tensor) for i in inputs + ), "Cond inputs must be a list of tensors" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pred_is_batched = isinstance(pred, torch.Tensor) and is_batchedtensor(pred) pred_ = get_unwrapped(pred) if pred_is_batched else pred diff --git a/torch/_higher_order_ops/effects.py b/torch/_higher_order_ops/effects.py index 23f7a5e474bdf..b5daaa5c82e47 100644 --- a/torch/_higher_order_ops/effects.py +++ b/torch/_higher_order_ops/effects.py @@ -240,9 +240,15 @@ def handle_effects( key = get_effect_key(op, args, kwargs) assert key is not None if key not in tokens: +<<<<<<< HEAD assert allow_token_discovery, ( f"Could not find a token for effect {key} which came from the function {op}" ) +======= + assert ( + allow_token_discovery + ), f"Could not find a token for effect {key} which came from the function {op}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) proxy_tensor_mode = torch._C._get_dispatch_mode( torch._C._TorchDispatchModeKey.PROXY ) diff --git a/torch/_higher_order_ops/executorch_call_delegate.py b/torch/_higher_order_ops/executorch_call_delegate.py index 3274502b943cd..d0b69934e0653 100644 --- a/torch/_higher_order_ops/executorch_call_delegate.py +++ b/torch/_higher_order_ops/executorch_call_delegate.py @@ -49,10 +49,14 @@ def _unwrap_proxy(e): if not isinstance(e, (torch.Tensor, torch.SymInt, torch.SymFloat)): return e return get_proxy_slot( +<<<<<<< HEAD cast(torch.Tensor, e), proxy_mode.tracer, e, lambda e: e.proxy, # type: ignore[attr-defined] +======= + cast(torch.Tensor, e), proxy_mode.tracer, e, lambda e: e.proxy # type: ignore[attr-defined] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if not is_lowered_module(lowered_module): diff --git a/torch/_higher_order_ops/flat_apply.py b/torch/_higher_order_ops/flat_apply.py index 654e2ea38384a..9f0b9a78d3e83 100644 --- a/torch/_higher_order_ops/flat_apply.py +++ b/torch/_higher_order_ops/flat_apply.py @@ -108,7 +108,11 @@ def impl(func, in_spec, *flat_args): # # TODO: The following can be updated to support non-graphable outputs and pytrees. # For non-graphable constant outputs: the assumption would be that they are constant +<<<<<<< HEAD # (every time the function runs those MUST be the same) +======= + # (everytime the function runs those MUST be the same) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # For pytree outputs: # I'm not sure if we need to return (flat_output, spec) or just (flat_output,): # in the latter case the tracers need to carry out the output specs diff --git a/torch/_higher_order_ops/flex_attention.py b/torch/_higher_order_ops/flex_attention.py index e622a0ebee036..d48c3f5ad67db 100644 --- a/torch/_higher_order_ops/flex_attention.py +++ b/torch/_higher_order_ops/flex_attention.py @@ -38,9 +38,15 @@ def _construct_strides( ) -> Sequence[int]: """From a list of sizes and a fill order, construct the strides of the permuted tensor.""" # Initialize strides +<<<<<<< HEAD assert len(sizes) == len(fill_order), ( "Length of sizes must match the length of the fill order" ) +======= + assert len(sizes) == len( + fill_order + ), "Length of sizes must match the length of the fill order" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) strides = [0] * len(sizes) # Start with stride 1 for the innermost dimension @@ -92,7 +98,11 @@ def __call__( kernel_options: dict[str, Any], score_mod_other_buffers: tuple = (), mask_mod_other_buffers: tuple = (), +<<<<<<< HEAD ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +======= + ) -> tuple[torch.Tensor, torch.Tensor]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) validate_subgraph_args_types(score_mod_other_buffers + mask_mod_other_buffers) return super().__call__( query, @@ -134,7 +144,10 @@ def __call__( torch.Tensor, torch.Tensor, torch.Tensor, tuple[Optional[torch.Tensor], ...] ]: validate_subgraph_args_types(score_mod_other_buffers + mask_mod_other_buffers) +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return super().__call__( query, key, @@ -209,7 +222,11 @@ def math_attention( kernel_options: dict[str, Any], score_mod_other_buffers: tuple = (), mask_mod_other_buffers: tuple = (), +<<<<<<< HEAD ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +======= +) -> tuple[torch.Tensor, torch.Tensor]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Eager implementation This implementation uses vmap to vectorize the score_mod function over the batch, head, m, and n dimensions. @@ -252,6 +269,7 @@ def math_attention( masked_rows = torch.all(post_mod_scores == -float("inf"), dim=-1) logsumexp = torch.where(masked_rows, -float("inf"), logsumexp) +<<<<<<< HEAD # working precision will be used so no need to cast to fp32 max_scores = torch.max(post_mod_scores, dim=-1)[0] @@ -265,6 +283,11 @@ def math_attention( logsumexp / math.log(2), max_scores / math.log(2), ) +======= + post_mod_scores = torch._safe_softmax(post_mod_scores, dim=-1) + + return post_mod_scores.to(query.dtype) @ value, logsumexp / math.log(2) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @flex_attention.py_impl(DispatchKey.CompositeExplicitAutograd) @@ -278,8 +301,13 @@ def sdpa_dense( kernel_options: dict[str, Any], score_mod_other_buffers: tuple = (), mask_mod_other_buffers: tuple = (), +<<<<<<< HEAD ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: out, lse, max_scores = math_attention( +======= +) -> tuple[torch.Tensor, torch.Tensor]: + out, lse = math_attention( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) query, key, value, @@ -291,7 +319,11 @@ def sdpa_dense( mask_mod_other_buffers, ) out = _permute_strides(out, query.stride()) +<<<<<<< HEAD return out, lse, max_scores +======= + return out, lse +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def trace_flex_attention( @@ -305,7 +337,11 @@ def trace_flex_attention( kernel_options: dict[str, Any], score_mod_other_buffers: tuple = (), mask_mod_other_buffers: tuple = (), +<<<<<<< HEAD ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +======= +) -> tuple[torch.Tensor, torch.Tensor]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Traces the flex_attention operator with the given score_mod function and other_buffers. Trace SDPA will call make_fx with "fake" example vals and then trace the score_mod function @@ -375,7 +411,11 @@ def flex_attention_proxy_torch_dispatch_mode( kernel_options: dict[str, Any], score_mod_other_buffers: tuple = (), mask_mod_other_buffers: tuple = (), +<<<<<<< HEAD ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +======= +) -> tuple[torch.Tensor, torch.Tensor]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert mode is not None, "Mode should always be enabled for python fallback key" return trace_flex_attention( mode, @@ -403,7 +443,11 @@ def flex_attention_functionalize( kernel_options: dict[str, Any], score_mod_other_buffers: tuple = (), mask_mod_other_buffers: tuple = (), +<<<<<<< HEAD ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +======= +) -> tuple[torch.Tensor, torch.Tensor]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Defines the functionalization rules for the flex_attention operator. Write now we are unwrapping each tensor and then redispatching to the next, however we want to @@ -488,7 +532,11 @@ def flex_attention_fake_impl( kernel_options: dict[str, Any], score_mod_other_buffers: tuple = (), mask_mod_other_buffers: tuple = (), +<<<<<<< HEAD ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +======= +) -> tuple[torch.Tensor, torch.Tensor]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if has_user_subclass( ( query, @@ -509,17 +557,28 @@ def flex_attention_fake_impl( if query.is_nested: out = torch.empty_like(query, memory_format=torch.contiguous_format) logsumexp = query.sum(dim=-1) +<<<<<<< HEAD max_scores = query.max(dim=-1)[0] return out, logsumexp, max_scores +======= + return out, logsumexp +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) v_head_dim = value.size(-1) batch_size, num_heads, seq_len_q, _q_head_dim = query.shape logsumexp = query.new_empty(batch_size, num_heads, seq_len_q, dtype=torch.float32) +<<<<<<< HEAD max_scores = query.new_empty(batch_size, num_heads, seq_len_q, dtype=torch.float32) out_shape = (batch_size, num_heads, seq_len_q, v_head_dim) out = query.new_empty(out_shape) out = _permute_strides(out, query.stride()) return out, logsumexp, max_scores +======= + out_shape = (batch_size, num_heads, seq_len_q, v_head_dim) + out = query.new_empty(out_shape) + out = _permute_strides(out, query.stride()) + return out, logsumexp +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Registers dispatches for SAC @@ -607,7 +666,11 @@ def joint_f( *other_buffers: tuple[Tensor, ...], ) -> tuple[Tensor, ...]: def fw_with_masks( +<<<<<<< HEAD *args: tuple[Tensor, ...], +======= + *args: tuple[Tensor, ...] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> tuple[tuple[Tensor], tuple[bool]]: fw_out = score_mod(*args) out_requires_grad = fw_out.requires_grad @@ -640,15 +703,25 @@ def forward( kernel_options: dict[str, Any], mask_mod_other_buffers: tuple[Any, ...], *score_mod_other_buffers: tuple[Any, ...], +<<<<<<< HEAD ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +======= + ) -> tuple[torch.Tensor, torch.Tensor]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) any_buffer_requires_grad = any( buffer.requires_grad for buffer in mask_mod_other_buffers if isinstance(buffer, torch.Tensor) ) +<<<<<<< HEAD assert not any_buffer_requires_grad, ( "Captured buffers from mask mod that require grad are not supported." ) +======= + assert ( + not any_buffer_requires_grad + ), "Captured buffers from mask mod that require grad are not supported." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ctx._fw_graph = fw_graph ctx._joint_graph = joint_graph ctx._mask_graph = block_mask[-1] @@ -656,7 +729,11 @@ def forward( ctx.kernel_options = kernel_options ctx._score_mod_other_buffers_len = len(score_mod_other_buffers) with torch._C._AutoDispatchBelowAutograd(): +<<<<<<< HEAD out, logsumexp, max_scores = flex_attention( +======= + out, logsumexp = flex_attention( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) query, key, value, @@ -667,8 +744,12 @@ def forward( score_mod_other_buffers, mask_mod_other_buffers, ) +<<<<<<< HEAD # no grads for you sir ctx.mark_non_differentiable(max_scores) +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) save_tensors_and_symints_for_backward( ctx, ( @@ -677,12 +758,16 @@ def forward( value, out, logsumexp, +<<<<<<< HEAD max_scores, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) *block_mask[:-1], *score_mod_other_buffers, *mask_mod_other_buffers, ), ) +<<<<<<< HEAD return out, logsumexp, max_scores @staticmethod @@ -692,6 +777,12 @@ def backward( # type: ignore[override] grad_logsumexp: Tensor, grad_max_scores: Tensor, ) -> tuple[Optional[Tensor], ...]: +======= + return out, logsumexp + + @staticmethod + def backward(ctx: Any, grad_out: Tensor, grad_logsumexp: Tensor) -> tuple[Optional[Tensor], ...]: # type: ignore[override] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fw_args = saved_tensors_and_symints(ctx) ( query, @@ -699,7 +790,10 @@ def backward( # type: ignore[override] value, out, logsumexp, +<<<<<<< HEAD max_scores, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) query_lengths, kv_lengths, kv_num_blocks, @@ -778,7 +872,11 @@ def flex_attention_autograd( kernel_options: dict[str, Any], score_mod_other_buffers: tuple[Tensor, ...] = (), mask_mod_other_buffers: tuple[Tensor, ...] = (), +<<<<<<< HEAD ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +======= +) -> tuple[torch.Tensor, torch.Tensor]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex with TransformGetItemToIndex(): @@ -787,11 +885,14 @@ def flex_attention_autograd( for t in (query, key, value, *score_mod_other_buffers) ) if torch.is_grad_enabled() and input_requires_grad: +<<<<<<< HEAD if block_mask[7] is None: raise RuntimeError( "BlockMask q_indices is None. Backward pass requires q_indices to be computed. " "Please create the BlockMask with compute_q_blocks=True" ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) example_vals = ( query.new_zeros((), requires_grad=input_requires_grad), query.new_zeros((), dtype=torch.int), @@ -804,7 +905,11 @@ def flex_attention_autograd( ) else: fw_graph, bw_graph = score_mod, None +<<<<<<< HEAD out, logsumexp, max_scores = FlexAttentionAutogradOp.apply( +======= + out, logsumexp = FlexAttentionAutogradOp.apply( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) query, key, value, @@ -816,7 +921,11 @@ def flex_attention_autograd( mask_mod_other_buffers, *score_mod_other_buffers, ) +<<<<<<< HEAD return out, logsumexp, max_scores +======= + return out, logsumexp +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # ---------------------------- Backward HOP Implementation ---------------------------- @@ -965,9 +1074,15 @@ def _maybe_new_buffer( actual_grad_value.copy_(grad_value) if Bq != Bkv: +<<<<<<< HEAD assert Bq > 1 and Bkv == 1, ( f"Bq and Bkv must broadcast. Got Bq={Bq} and Bkv={Bkv}" ) +======= + assert ( + Bq > 1 and Bkv == 1 + ), f"Bq and Bkv must broadcast. Got Bq={Bq} and Bkv={Bkv}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) actual_grad_key = torch.sum(actual_grad_key, 0, keepdim=True) actual_grad_value = torch.sum(actual_grad_value, 0, keepdim=True) @@ -1266,7 +1381,11 @@ def flex_attention_backward_fake_tensor_mode( [ ( torch.empty_like(buffer, memory_format=torch.contiguous_format) +<<<<<<< HEAD if isinstance(buffer, torch.Tensor) +======= + if isinstance(buffer, torch.Tensor) and buffer.requires_grad +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else None ) for buffer in score_mod_other_buffers diff --git a/torch/_higher_order_ops/hints_wrap.py b/torch/_higher_order_ops/hints_wrap.py index 3f21c518cbd74..1adcc8e78dbf3 100644 --- a/torch/_higher_order_ops/hints_wrap.py +++ b/torch/_higher_order_ops/hints_wrap.py @@ -38,7 +38,12 @@ def __call__(self, body_fn, args, kwargs, hints): if not all(isinstance(t, (torch.Tensor, int, float, bool)) for t in args): raise RuntimeError( +<<<<<<< HEAD f"args must be a tuple of tensors, ints, floats, or bools, got {args}" +======= + "args must be a tuple of tensors, ints, floats, or bools, got " + f"{args}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if not isinstance(kwargs, dict): diff --git a/torch/_higher_order_ops/invoke_subgraph.py b/torch/_higher_order_ops/invoke_subgraph.py index 11b663ea4f61a..754b862bb08ea 100644 --- a/torch/_higher_order_ops/invoke_subgraph.py +++ b/torch/_higher_order_ops/invoke_subgraph.py @@ -1,9 +1,17 @@ # mypy: allow-untyped-defs +<<<<<<< HEAD import contextlib from contextlib import nullcontext from dataclasses import dataclass, field from typing import Any, Callable, Optional, Union +======= + +import contextlib +from contextlib import nullcontext +from dataclasses import dataclass, field +from typing import Optional, Union +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch.utils._pytree as pytree @@ -45,7 +53,11 @@ @dataclass class OutputMetadata: num_fw_outs: Optional[int] = None +<<<<<<< HEAD indexes_with_symint: set[int] = field(default_factory=set) +======= + indexes_with_none: set[int] = field(default_factory=set) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) indexes_with_no_grad: set[int] = field(default_factory=set) @@ -69,6 +81,7 @@ def __call__( identifier: Optional[str], *operands, ): +<<<<<<< HEAD assert identifier is None or isinstance(identifier, str), ( "identifier must be a None or a string" ) @@ -79,6 +92,15 @@ def __call__( ), ( f"invoke_subgraph operands must be a list of tensors/ints/SymInts/Generator {operands}" ) +======= + assert identifier is None or isinstance( + identifier, str + ), "identifier must be a None or a string" + + assert all( + isinstance(o, (torch.Tensor, int, torch.SymInt)) for o in operands + ), f"invoke_subgraph operands must be a list of tensors/ints/SymInts {operands}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return super().__call__(subgraph, identifier, *operands) @@ -130,6 +152,7 @@ def invoke_subgraph_placeholder(func, *args, **kwargs): def _invoke_subgraph_placeholder_wrapper(func, args): return invoke_subgraph_placeholder(func, *args) +<<<<<<< HEAD with ( _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(), @@ -140,6 +163,12 @@ def _invoke_subgraph_placeholder_wrapper(func, args): backend: Union[str, Callable[..., Any]] = ( make_eager_backend_with_torch_function_mode(metadata_mode) ) +======= + with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(), _temp_remove_pre_dispatch_torch_function_mode(): + with _temp_remove_metadata_torch_function_mode() as metadata_mode: + if metadata_mode: + backend = make_eager_backend_with_torch_function_mode(metadata_mode) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: backend = "eager" @@ -258,8 +287,13 @@ def create_fw_bw_graph(subgraph, operands, grad_outputs=None): output_metadata.num_fw_outs = num_fw_outs for idx, fw_out in enumerate(fw_outs): +<<<<<<< HEAD if isinstance(fw_out, torch.SymInt): output_metadata.indexes_with_symint.add(idx) +======= + if fw_out is None: + output_metadata.indexes_with_none.add(idx) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif not fw_out.requires_grad: output_metadata.indexes_with_no_grad.add(idx) @@ -331,8 +365,13 @@ def get_output_metadata(subgraph, *operands): output_metadata.num_fw_outs = num_fw_outs for idx, fw_out in enumerate(fw_outs): +<<<<<<< HEAD if isinstance(fw_out, torch.SymInt): output_metadata.indexes_with_symint.add(idx) +======= + if fw_out is None: + output_metadata.indexes_with_none.add(idx) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif not fw_out.requires_grad: output_metadata.indexes_with_no_grad.add(idx) return output_metadata @@ -428,10 +467,17 @@ def forward( *operands, ) +<<<<<<< HEAD # Check that int (coming from symint) is at expected indexes. for idx, o in enumerate(out): if isinstance(o, int): assert idx in output_metadata.indexes_with_symint +======= + # Check that None is at expected indexes. + for idx, o in enumerate(out): + if o is None: + assert idx in output_metadata.indexes_with_none +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return out @@ -452,7 +498,11 @@ def backward( filtered_grad_outs = [] for idx, o in enumerate(grad_outs): if o is None: +<<<<<<< HEAD assert idx in output_metadata.indexes_with_symint +======= + assert idx in output_metadata.indexes_with_none +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif idx in output_metadata.indexes_with_no_grad: # Deliberately skip over the grad_outs which we know should be # None because the corresponding fwd_out does not require_grad. @@ -470,7 +520,10 @@ def backward( from torch._subclasses.fake_tensor import extract_tensor_metadata fake_mode = detect_fake_mode(primals + filtered_grad_outs) +<<<<<<< HEAD assert fake_mode is not None, "fake_mode should be enabled for HOPs" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) state = _CacheKeyState(fake_mode.shape_env) tangent_metadata: list[object] = [] @@ -566,9 +619,15 @@ def _(ctx, subgraph, identifier, *operands): # We call auto_functionalized_v2 to support input mutation of invoke_subgraph. # See NOTE [Support input mutation of hops] for the overall design. # +<<<<<<< HEAD # invoke_subgraph is special because of its identifier based caching mechanism. # In invoke_subgraph's functionalization key implementation, we create a new # identifier because the subgraph is replaced by FunctionWithNoFreeVars in a +======= + # invoke_subgraph is special because of its identifier based caching machanism. + # In invoke_subgraph's functionalization key implementation, we create a new + # identifer because the subgraph is replaced by FunctionWithNoFreeVars in a +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # functional + epilogue form. assert isinstance(identifier, str), identifier return do_auto_functionalize_v2( @@ -613,7 +672,10 @@ def _(proxy_mode: ProxyTorchDispatchMode, subgraph, identifier, *operands): from torch._guards import detect_fake_mode fake_mode = detect_fake_mode(operands) +<<<<<<< HEAD assert fake_mode is not None and fake_mode.shape_env is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) insert_deferred_runtime_asserts( graph, fake_mode.shape_env, @@ -642,7 +704,11 @@ def _unwrap_proxy(arg): # with a previously cached identifier, the corresponding graph module might not # exist as a submodule in the new tracer's root. Therefore, we register it as a submodule below. # +<<<<<<< HEAD # The alternative is to give a new identifier when we re-trace the invoke_subgraph but this will increase +======= + # The alternative is to give a new identifer when we re-trace the invoke_subgraph but this will increase +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # the compilatoin time, which defeats the purpose of caching. registered_before = False for ( diff --git a/torch/_higher_order_ops/map.py b/torch/_higher_order_ops/map.py index 57d2cd3cb9001..e6fd31462b93c 100644 --- a/torch/_higher_order_ops/map.py +++ b/torch/_higher_order_ops/map.py @@ -13,6 +13,10 @@ from torch._subclasses.functional_tensor import disable_functional_mode from torch.fx.experimental.proxy_tensor import ( disable_proxy_modes_tracing, +<<<<<<< HEAD +======= + make_fx, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ProxyTorchDispatchMode, track_tensor_tree, ) @@ -21,6 +25,7 @@ _from_fun, _stack_pytree, _unstack_pytree, +<<<<<<< HEAD create_bw_fn, fill_none_with_masks, filter_with_masks, @@ -28,6 +33,12 @@ save_tensors_and_symints_for_backward, saved_tensors_and_symints, split_into_chunks, +======= + clone_outputs_aliasing_inputs, + prepare_fw_with_masks, + save_tensors_and_symints_for_backward, + saved_tensors_and_symints, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @@ -42,13 +53,91 @@ def __call__(self, *args, **kwargs): map_impl = MapImpl() +<<<<<<< HEAD +======= +def create_fw_bw_graph(f, num_mapped_args, *args): + mapped_xs = args[:num_mapped_args] + pos_args = args[num_mapped_args:] + + # See Note [HOP create fw_bw graph] in create_fw_bw_graph in utils.py + + with suspend_functionalization(), disable_functional_mode(): + with disable_proxy_modes_tracing(): + unwrapped_mapped_xs = pytree.tree_map(_from_fun, mapped_xs) + example_xs = _unstack_pytree(unwrapped_mapped_xs)[0] + + example_pos_args = [ + _from_fun(arg) if isinstance(arg, torch.Tensor) else arg + for arg in pos_args + ] + example_flat_out = pytree.tree_map( + _from_fun, f(*example_xs, *example_pos_args) + ) + if any( + not isinstance(out, torch.Tensor) + for out in example_flat_out + if out is not None + ): + raise RuntimeError( + "Expect outputs of map only contains tensors or None. " + f"Got types {[type(out) for out in example_flat_out]}." + ) + example_grad = [_from_fun(out) for out in example_flat_out] + + fw_graph = make_fx(f)(*example_xs, *example_pos_args) + + from torch._functorch.aot_autograd import AOTConfig, create_joint + + dummy_aot_config = AOTConfig( + fw_compiler=None, # type: ignore[arg-type] + bw_compiler=None, # type: ignore[arg-type] + partition_fn=None, # type: ignore[arg-type] + decompositions={}, + num_params_buffers=0, + aot_id=0, + keep_inference_input_mutations=False, + ) + + def joint_f(*example_args): + joint_mapped_args = example_args[:joint_num_mapped] + args = example_args[joint_num_mapped:] + + mapped_input = joint_mapped_args[:num_mapped_args] + mapped_grads = joint_mapped_args[num_mapped_args:] + + joint = create_joint(prepare_fw_with_masks(f), aot_config=dummy_aot_config) + _, grads = joint( + list(mapped_input) + list(args), + [ + grad + for grad in mapped_grads + if grad is not None and grad.requires_grad + ], + ) + + # In order to keep map functional for backward graph, + # we clone outputs that are aliasing inputs + maybe_clone = clone_outputs_aliasing_inputs(example_args) + + return pytree.tree_map(maybe_clone, grads) + + joint_num_mapped = len(example_grad) + len(example_xs) + joint_graph = make_fx(joint_f)(*example_xs, *example_grad, *example_pos_args) + return fw_graph, joint_graph + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def map( f: Callable[[pytree.PyTree, tuple[pytree.PyTree, ...]], pytree.PyTree], xs: Union[pytree.PyTree, torch.Tensor], *args: TypeVarTuple, ): r""" +<<<<<<< HEAD Performs a map of f with xs. Intuitively, you can think of the semantic being: +======= + Perfoms a map of f with xs. Intuitively, you can think of the semantic being: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out = [] for idx in len(xs.size(0)): @@ -124,6 +213,7 @@ def wrapped_fn(*flat_args, f, xs_tree_spec, args_tree_spec, num_xs): class MapAutogradOp(torch.autograd.Function): @staticmethod +<<<<<<< HEAD def forward(ctx, f, num_mapped_args, *flat_args): ctx._f = f ctx._num_mapped_args = num_mapped_args @@ -137,11 +227,23 @@ def forward(ctx, f, num_mapped_args, *flat_args): with torch._C._AutoDispatchBelowAutograd(): return ( *map_impl(f, flat_args[:num_mapped_args], flat_args[num_mapped_args:]), +======= + def forward(ctx, fw_graph, joint_graph, num_mapped_args, *flat_args): + save_tensors_and_symints_for_backward(ctx, flat_args) + ctx._joint_graph = joint_graph + ctx._num_mapped_args = num_mapped_args + with torch._C._AutoDispatchBelowAutograd(): + return ( + *map_impl( + fw_graph, flat_args[:num_mapped_args], flat_args[num_mapped_args:] + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @staticmethod def backward(ctx, *flat_grads): fw_args = saved_tensors_and_symints(ctx) +<<<<<<< HEAD num_mapped_args = ctx._num_mapped_args num_pos_args = ctx._num_pos_args num_grads = len(flat_grads) @@ -214,6 +316,24 @@ def trace_map(proxy_mode, func_overload, f, xs, pos_args): body_graph = f body_graph = reenter_make_fx(body_graph)(*example_input, *pos_args) +======= + fw_mapped_args = fw_args[: ctx._num_mapped_args] + pos_args = fw_args[ctx._num_mapped_args :] + + grads = map_impl( + ctx._joint_graph, + fw_mapped_args + flat_grads, + pos_args, + ) + return None, None, None, *grads + + +def trace_map(proxy_mode, func_overload, f, xs, pos_args): + example_input = _unstack_pytree(xs)[0] + body_graph = f + + body_graph = reenter_make_fx(body_graph)(*example_input, *pos_args) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) next_name = proxy_mode.tracer.get_fresh_qualname("body_graph_") @@ -240,7 +360,12 @@ def map_dense(f, xs, pos_args): @map_impl.py_autograd_impl def map_autograd(f, xs, pos_args): num_mapped_args = len(xs) +<<<<<<< HEAD flat_out = MapAutogradOp.apply(f, num_mapped_args, *xs, *pos_args) +======= + fw_graph, bw_graph = create_fw_bw_graph(f, num_mapped_args, *xs, *pos_args) + flat_out = MapAutogradOp.apply(fw_graph, bw_graph, num_mapped_args, *xs, *pos_args) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return flat_out diff --git a/torch/_higher_order_ops/out_dtype.py b/torch/_higher_order_ops/out_dtype.py index 38c07e37bdb85..93107f70c683d 100644 --- a/torch/_higher_order_ops/out_dtype.py +++ b/torch/_higher_order_ops/out_dtype.py @@ -111,8 +111,13 @@ def is_int_mm(op, output_dtype, args): and len(args) == 2 and args[0].dtype == torch.int8 and args[1].dtype == torch.int8 +<<<<<<< HEAD and (args[0].is_cuda or args[0].is_xpu) and (args[1].is_cuda or args[1].is_xpu) +======= + and args[0].is_cuda + and args[1].is_cuda +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) diff --git a/torch/_higher_order_ops/run_const_graph.py b/torch/_higher_order_ops/run_const_graph.py index ed7c5278f5fe6..eebb0a296d524 100644 --- a/torch/_higher_order_ops/run_const_graph.py +++ b/torch/_higher_order_ops/run_const_graph.py @@ -1,24 +1,38 @@ +<<<<<<< HEAD from typing import Any, TYPE_CHECKING +======= +# mypy: allow-untyped-defs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch from torch._C import DispatchKey from torch._higher_order_ops.utils import autograd_not_implemented from torch._ops import HigherOrderOperator from torch._subclasses.fake_tensor import FakeTensorMode +<<<<<<< HEAD if TYPE_CHECKING: from torch._subclasses.functional_tensor import BaseFunctionalizeAPI +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree from torch.utils import _pytree as pytree class RunConstGraph(HigherOrderOperator): +<<<<<<< HEAD def __init__(self) -> None: super().__init__("run_const_graph") def __call__(self, graph: torch.fx.GraphModule, args: tuple[object, ...]) -> object: +======= + def __init__(self): + super().__init__("run_const_graph") + + def __call__(self, graph, args): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return super().__call__(graph, args) @@ -26,6 +40,7 @@ def __call__(self, graph: torch.fx.GraphModule, args: tuple[object, ...]) -> obj @run_const_graph.py_impl(ProxyTorchDispatchMode) +<<<<<<< HEAD def run_const_graph_dispatch_mode( mode: ProxyTorchDispatchMode, graph: torch.fx.GraphModule, args: tuple[object, ...] ) -> object: @@ -34,6 +49,14 @@ def run_const_graph_dispatch_mode( assert isinstance(const_gm, torch.fx.GraphModule) assert not hasattr(mode.tracer.root, "_const_graph") # type: ignore[union-attr] mode.tracer.root.register_module("_const_graph", const_gm) # type: ignore[union-attr] +======= +def run_const_graph_dispatch_mode(mode, graph, args): + const_gm, weights = graph, args + p_args = pytree.tree_map(mode.tracer.unwrap_proxy, (graph, args)) + assert isinstance(const_gm, torch.fx.GraphModule) + assert not hasattr(mode.tracer.root, "_const_graph") + mode.tracer.root.register_module("_const_graph", const_gm) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) proxy = mode.tracer.create_proxy("call_function", run_const_graph, p_args, {}) @@ -42,6 +65,7 @@ def run_const_graph_dispatch_mode( @run_const_graph.py_functionalize_impl +<<<<<<< HEAD def run_const_graph_functional( ctx: "BaseFunctionalizeAPI", graph: torch.fx.GraphModule, args: tuple[Any, ...] ) -> Any: @@ -50,6 +74,14 @@ def run_const_graph_functional( with ctx.redispatch_to_next(): out = run_const_graph(graph, unwrapped_args) return ctx.wrap_tensors(out) # type: ignore[arg-type] +======= +def run_const_graph_functional(ctx, graph, args): + unwrapped_args = ctx.unwrap_tensors(args) + + with ctx.redispatch_to_next(): + out = run_const_graph(*unwrapped_args) + return ctx.wrap_tensors(out) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) run_const_graph.py_autograd_impl( @@ -58,17 +90,25 @@ def run_const_graph_functional( @run_const_graph.py_impl(FakeTensorMode) +<<<<<<< HEAD def run_const_graph_fake_tensor_mode( mode: FakeTensorMode, graph: torch.fx.GraphModule, args: tuple[object, ...] ) -> object: +======= +def run_const_graph_fake_tensor_mode(mode, graph, args): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(graph, torch.fx.GraphModule) with mode: return graph(*args) @run_const_graph.py_impl(DispatchKey.CPU) +<<<<<<< HEAD def run_const_graph_cpu( graph: torch.fx.GraphModule, args: tuple[object, ...] ) -> object: +======= +def run_const_graph_cpu(graph, args): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(graph, torch.fx.GraphModule) return graph(*args) diff --git a/torch/_higher_order_ops/scan.py b/torch/_higher_order_ops/scan.py index e4aa0161ad3c9..7180fabb0d690 100644 --- a/torch/_higher_order_ops/scan.py +++ b/torch/_higher_order_ops/scan.py @@ -1,12 +1,18 @@ # mypy: allow-untyped-defs import functools import itertools +<<<<<<< HEAD from typing import Any, Callable +======= +from collections.abc import Sequence +from typing import Any, Callable, Optional +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch._prims_common as utils import torch.utils._pytree as pytree from torch._C import DispatchKey +<<<<<<< HEAD from torch._higher_order_ops.utils import ( _maybe_compile_and_run_fn, check_input_alias_and_mutation_return_outputs, @@ -16,11 +22,21 @@ first_slice_copy_with_grad, get_tensor_mask, mask_list, +======= +from torch._higher_order_ops.cond import create_bw_fn +from torch._higher_order_ops.utils import ( + _maybe_compile_and_run_fn, + check_meta_consistency, + first_slice_copy, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) materialize_as_graph, reenter_make_fx, save_tensors_and_symints_for_backward, saved_tensors_and_symints, +<<<<<<< HEAD split_into_chunks, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) unique_graph_id, validate_subgraph_args_types, ) @@ -40,9 +56,15 @@ def wrap_combine_fn_flat( *args, combine_fn, spec_init, spec_xs, num_init_leaves, num_inp_leaves ): +<<<<<<< HEAD assert len(args) == (num_init_leaves + num_inp_leaves), ( f"Combin_fn received wrong number of arguments, expected {num_init_leaves + num_inp_leaves}, but got {len(args)}" ) +======= + assert len(args) == ( + num_init_leaves + num_inp_leaves + ), f"Combin_fn received wrong number of arguments, expected {num_init_leaves + num_inp_leaves}, but got {len(args)}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) carry = pytree.tree_unflatten(args[:num_init_leaves], spec_init) xs = pytree.tree_unflatten(args[num_init_leaves:], spec_xs) return combine_fn(carry, xs) @@ -63,6 +85,53 @@ def stack_y(y: torch.Tensor, scan_length: int) -> torch.Tensor: ) +<<<<<<< HEAD +======= +# NOTE: These functions can be reused in associative_scan and eventually moved to +# torch._higher_order_ops.utils +def get_tensor_mask(tensor_list: list[Any]) -> list[bool]: + # Returns a mask whether a list element is a tensor or not + return [True if isinstance(v, torch.Tensor) else False for v in tensor_list] + + +def mask_list( + mask: list[bool], inp: list[Any], other: Optional[list[Any]] = None +) -> list[Any]: + # Masks elements on an `inp` list. + # If other is None, then the elements of the `inp` list where the mask is False are removed + # If other is not None, then the elements of the `inp` list where the mask is False are + # replaced with the elements of the `other` list + assert len(mask) == len( + inp + ), "The length of the mask needs to be identical to the length of the input" + if other is not None: + assert len(inp) == len( + other + ), "If an input and an other list is provided, they need to have the same length" + return [i if m else o for m, i, o in zip(mask, inp, other)] + else: + return [i for m, i in zip(mask, inp) if m] + + +def first_slice_copy_with_grad(li: list[Any]) -> list[Any]: + # First_slice_copy does not keep the original requires_grad flag, + # but we need it for materialize_as_graph + # in order to compute the correct gradients + # The reason why first_slice_copy doesn't keep requires_grad flag is + # because it's called in torch.autograd.Function.backward/forward. + slc = [first_slice_copy(x).requires_grad_(x.requires_grad) for x in li] + return slc + + +def split_into_chunks(iterable: Sequence[Any], chunk_sizes: list[int]) -> list[Any]: + it = iter(iterable) + assert sum(chunk_sizes) == len( + iterable + ), "the sum of all chunks needs to match the length of the iterable." + return [list(itertools.islice(it, size)) for size in chunk_sizes] + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def call_operator(operator, *args): return pytree.tree_leaves(operator(*args)) @@ -95,7 +164,11 @@ def scan( and the second output of ``combine_fn`` represents a slice of the output. This function must be pure, i.e., no lifted arguments are supported at the moment and may not have any side effects. +<<<<<<< HEAD init (torch.Tensor or pytree with tensor leaves): The initial scan carry, a tensor, or nested pytree of tensors. +======= + init (torch.Tensor or pytree with tensor leaves): The inital scan carry, a tensor, or nested pytree of tensors. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) The ``init`` is expected to have the same pytree structure as the first output element (i.e. carry) of ``combine_fn``. xs (torch.Tensor or pytree with tensor leaves): The input tensor, or nested pytree of tensors. @@ -114,7 +187,11 @@ def scan( - The combine_fn shouldn't have any aliasing between input-input, input-output, and output-output. E.g. return a view or the same tensor as input is not supported. As a workaround, can clone the output to avoid aliasing. +<<<<<<< HEAD - The combine_fn shouldn't mutate any inputs. We'll remove the mutation restriction for inference soon. Please file an issue +======= + - The combine_fn shoudn't mutate any inputs. We'll remove the mutation restriction for inference soon. Please file an issue +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if you input mutation support for training is needed. - The combine_fn's init carry should match the next_carry in pytree structure and in tensor metadata. @@ -126,7 +203,10 @@ def add(x: torch.Tensor, y: torch.Tensor): # clone the output to avoid output-output aliasing return next_carry, y.clone() +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) i0 = torch.zeros(1) xs = torch.arange(5) # returns torch.tensor([10.]), torch.tensor([[0], [1.], [3.], [6.], [10.]]) @@ -223,9 +303,15 @@ def __call__(self, combine_fn, init, xs, additional_inputs): # the additional_inputs being a list. See https://github.com/pytorch/pytorch/issues/145785 # Once this issue is resolved, the assertion should only allow tuples # and the tuple cast should be removed +<<<<<<< HEAD assert isinstance(additional_inputs, (tuple, list)), ( "additional_inputs must be a tuple." ) +======= + assert isinstance( + additional_inputs, (tuple, list) + ), "additional_inputs must be a tuple." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) additional_inputs = ( tuple(additional_inputs) if isinstance(additional_inputs, list) @@ -234,6 +320,7 @@ def __call__(self, combine_fn, init, xs, additional_inputs): validate_subgraph_args_types(additional_inputs) return super().__call__(combine_fn, init, xs, additional_inputs) +<<<<<<< HEAD def gen_schema(self, combine_fn, init, xs, additional_inputs): from torch._higher_order_ops.schema import HopSchemaGenerator from torch._higher_order_ops.utils import materialize_as_graph @@ -284,6 +371,8 @@ def gen_schema(self, combine_fn, init, xs, additional_inputs): schema_gen.add_schema_tree_spec(combine_fn, init, xs, additional_inputs) return schema_gen.gen_schema() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) scan_op = ScanOp() @@ -595,9 +684,15 @@ def combine_fn_with_carry_checkpoint(*args): carry, y = _extract_carry_and_out(combine_fn(*args), num_leaves_init) return [ *carry, +<<<<<<< HEAD # We additionally checkpoint all the intermediate carry outputs for backward. *[ n_c.detach().clone() if isinstance(n_c, torch.Tensor) else n_c +======= + # We additionally checkpoint all the intemediate carry outputs for backward. + *[ + n_c.clone().detach() if isinstance(n_c, torch.Tensor) else n_c +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for n_c in carry ], *y, @@ -803,7 +898,11 @@ def construct_args_single_step_bw(): # Prepare the bwd_init bwd_init = [*initial_g_additional_inputs, *g_c_T] +<<<<<<< HEAD # 5.) Perform the backward scan: +======= + # 5.) Perform the backwrad scan: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # The ``combine_fn_bw_wrapped`` receives the # initial_g_additional_inputs and the last carry as the ``bwd_init`` and the # gradients of the outputs (g_ys), as well as the fw_carries and the fw_xs of the forward as the ``bwd_xs`` diff --git a/torch/_higher_order_ops/schema.py b/torch/_higher_order_ops/schema.py index b1cdacb323731..079aa795372d8 100644 --- a/torch/_higher_order_ops/schema.py +++ b/torch/_higher_order_ops/schema.py @@ -18,7 +18,11 @@ class HopArgumentInfo: example_value: Any # Provide an default_value default_value: Any +<<<<<<< HEAD # Whether this argument gets mutated in the hop subgraph. +======= + # Whether this arugment gets mutated in the hop subgraph. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # For output, this should always be False is_mutated: bool kw_only: bool @@ -35,9 +39,15 @@ def from_example( kw_only: bool = False, ) -> HopArgumentInfo: if default_value is not None: +<<<<<<< HEAD assert type(example_value) == type(default_value), ( f"example_value type {type(example_value)} doesn't match default_value type: {type(default_value)}" ) +======= + assert type(example_value) == type( + default_value + ), f"example_value type {type(example_value)} doesn't match default_value type: {type(default_value)}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return HopArgumentInfo( name=name, @@ -65,8 +75,11 @@ def from_example(obj: Any) -> Any: return torch._C.AnyType.get() elif isinstance(obj, torch.SymInt): return torch._C.SymIntType.get() +<<<<<<< HEAD elif isinstance(obj, torch.SymBool): return torch._C.SymBoolType.get() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return torch._C._jit_try_infer_type(obj).type() @@ -209,12 +222,21 @@ def from_hop_argument_info( args.append(CArgumentGen.from_hop_argument_info(i, arg_info)) # NOTE: we want the output to always be a single argument with torch._C.TupleType. +<<<<<<< HEAD assert isinstance(out_argument_info.example_value, tuple), ( f"expect out_argument_info's example_value to be a tuple but got {out_argument_info.example_value}" ) assert not out_argument_info.is_mutated, ( "out_argument_info.is_mutated should always be set to False." ) +======= + assert isinstance( + out_argument_info.example_value, tuple + ), f"expect out_argument_info's example_value to be a tuple but got {out_argument_info.example_value}" + assert ( + not out_argument_info.is_mutated + ), "out_argument_info.is_mutated should always be set to False." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) rets = None if len(out_argument_info.example_value) == 1: rets = [CArgumentGen.from_hop_argument_info(0, out_argument_info, True)] diff --git a/torch/_higher_order_ops/strict_mode.py b/torch/_higher_order_ops/strict_mode.py index 1ed920c4a150c..3a4836a7028fb 100644 --- a/torch/_higher_order_ops/strict_mode.py +++ b/torch/_higher_order_ops/strict_mode.py @@ -1,6 +1,9 @@ # mypy: allow-untyped-defs +<<<<<<< HEAD from typing import Any, Callable, Union +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch._subclasses.functional_tensor import torch.utils._pytree as pytree @@ -35,9 +38,13 @@ def strict_mode(callable, operands): modes = [metadata_mode, predispatch_mode] modes = [mode for mode in modes if mode is not None] if modes: +<<<<<<< HEAD backend: Union[str, Callable[..., Any]] = ( make_eager_backend_with_torch_function_modes(modes) ) +======= + backend = make_eager_backend_with_torch_function_modes(modes) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: backend = "eager" with torch._dynamo.utils.disable_cache_limit(): diff --git a/torch/_higher_order_ops/torchbind.py b/torch/_higher_order_ops/torchbind.py index c10e674b7ac0c..7ccb92fbe5b1a 100644 --- a/torch/_higher_order_ops/torchbind.py +++ b/torch/_higher_order_ops/torchbind.py @@ -81,9 +81,15 @@ def enable_torchbind_tracing(): torch.ScriptMethod.__call__ = torchbind_method_redispatch # type: ignore[method-assign] yield finally: +<<<<<<< HEAD assert KNOWN_TYPES.pop() is torch.ScriptObject, ( "Someone else messed with KNOWN_TYPES during tracing, exploding." ) +======= + assert ( + KNOWN_TYPES.pop() is torch.ScriptObject + ), "Someone else messed with KNOWN_TYPES during tracing, exploding." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.ScriptMethod.__call__ = _orig_scriptmethod_call # type: ignore[method-assign] @@ -127,16 +133,26 @@ def inner(mode, *args, **kwargs): ret = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) if "val" not in out_proxy.node.meta: +<<<<<<< HEAD assert out is None or isinstance(out, (int, float, bool)), ( "Currently, only these constant dtypes are supported to be returned from torchbind methods." ) +======= + assert out is None or isinstance( + out, (int, float, bool) + ), "Currently, only these constant dtypes are supported to be returned from torchbind methods." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out_proxy.node.meta["val"] = out return ret # When tracing with fake script object, the call_torchbind op will return a fake tensor # When tracing with real script object, the call_torchbind op may return a real tensor, +<<<<<<< HEAD # we need to convert it to fake tensor manually. Dynamic shape is supported. +======= +# we need to convert it to fake tensor mannually. Dynamic shape is surpported. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @call_torchbind.py_impl(FakeTensorMode) def call_torchbind_fake(mode, *args, **kwargs): with mode: diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py index fa8ab598eb89c..e22567f6f514e 100644 --- a/torch/_higher_order_ops/triton_kernel_wrap.py +++ b/torch/_higher_order_ops/triton_kernel_wrap.py @@ -18,7 +18,10 @@ import torch.utils._pytree as pytree from torch import SymInt, Tensor from torch._C import DispatchKey +<<<<<<< HEAD from torch._higher_order_ops.utils import redirect_to_mode +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._ops import HigherOrderOperator from torch._prims_common import clone_preserve_strides from torch._subclasses.fake_tensor import FakeTensorMode @@ -29,7 +32,10 @@ ) from torch.fx.experimental.symbolic_shapes import guard_scalar from torch.types import IntLikeType +<<<<<<< HEAD from torch.utils.checkpoint import _CachedTorchDispatchMode, _CachingTorchDispatchMode +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if TYPE_CHECKING: @@ -95,7 +101,11 @@ def create_tma_experimental_metadata( def maybe_unpack_tma_experimental_metadata( +<<<<<<< HEAD tma_meta: Union[TMAExperimentalMetadata, TMAStableMetadata], +======= + tma_meta: Union[TMAExperimentalMetadata, TMAStableMetadata] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Optional[tuple[list[IntLikeType], list[IntLikeType], IntLikeType]]: if not tma_meta or len(tma_meta) != 2: return None @@ -111,7 +121,11 @@ def create_tma_stable_metadata( def maybe_unpack_tma_stable_metadata( +<<<<<<< HEAD tma_meta: Union[TMAExperimentalMetadata, TMAStableMetadata], +======= + tma_meta: Union[TMAExperimentalMetadata, TMAStableMetadata] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Optional[tuple[list[IntLikeType]]]: if not tma_meta or len(tma_meta) != 2: return None @@ -463,16 +477,23 @@ def get_signature_value(idx: int, arg: Any) -> str: elif make_ir_sig_params == 3: codegen_fns = backend.get_codegen_implementation() ttir_module = src.make_ir(options, codegen_fns, context) +<<<<<<< HEAD elif make_ir_sig_params == 4: codegen_args = [options] if get_codegen_implementation_sig_params == 1 else [] codegen_fns = backend.get_codegen_implementation(*codegen_args) module_map = backend.get_module_map() ttir_module = src.make_ir(options, codegen_fns, module_map, context) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: codegen_args = [options] if get_codegen_implementation_sig_params == 1 else [] codegen_fns = backend.get_codegen_implementation(*codegen_args) module_map = backend.get_module_map() +<<<<<<< HEAD ttir_module = src.make_ir(target, options, codegen_fns, module_map, context) +======= + ttir_module = src.make_ir(options, codegen_fns, module_map, context) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not ttir_module.verify(): raise RuntimeError("Verification for TTIR module has failed") @@ -805,9 +826,12 @@ def get_tma_stores( elif op.name == "tt.experimental_descriptor_store": assert len(op.args) >= 1 result.add(op.args[0]) +<<<<<<< HEAD elif op.name == "tt.descriptor_store": assert len(op.args) >= 1 result.add(op.args[0]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for val in list(result): if val in ops: @@ -857,6 +881,12 @@ def analyze_kernel_mutations( # (e.g. `tt.elementwise_inline_asm`), we assume it does not mutate any input parameters. if op.name in UNKNOWN_OPS: if op.name == "tt.elementwise_inline_asm" and op.is_pure: +<<<<<<< HEAD +======= + log.warning( + "TTIR mutation analysis: Skipping pure tt.elementwise_inline_asm op (is_pure=True)" + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue raise RuntimeError( f"ttir analysis hit an op we do not know how to analyze: {op.name}" @@ -1044,7 +1074,11 @@ def triton_kernel_wrapper_mutation_dense( # as we need to launch the kernel here, we "unwrap" the # tma_descriptor_metadata, create the TMA descriptors # from it, and replace the tensors in the kwargs by the +<<<<<<< HEAD # corresponding TMA descriptors before launching +======= + # correspoinding TMA descriptors before launching +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kwargs = kwargs.copy() for k, v in tma_descriptor_metadata.items(): tensor = kwargs[k] @@ -1129,8 +1163,12 @@ def trace_triton_kernel_wrapper( out = func_overload(**node_args) proxy_args = pytree.tree_map( +<<<<<<< HEAD proxy_mode.tracer.unwrap_proxy, # type: ignore[union-attr] node_args, +======= + proxy_mode.tracer.unwrap_proxy, node_args # type: ignore[union-attr] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) out_proxy = proxy_mode.tracer.create_proxy( "call_function", @@ -1344,9 +1382,12 @@ def triton_kernel_wrapper_functional_functionalize( triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutogradCUDA) triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutogradCPU) +<<<<<<< HEAD # Adds SAC support for triton ops redirect_to_mode(triton_kernel_wrapper_mutation, _CachingTorchDispatchMode) redirect_to_mode(triton_kernel_wrapper_mutation, _CachedTorchDispatchMode) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ############################################################################### # The "TritonHOPifier": a class that transforms a call to a triton kernel into @@ -1671,9 +1712,15 @@ def call_triton_kernel( # Update the kwargs in each config # maybe_unpack_heuristic_result raises unsupported if the value is non-constant +<<<<<<< HEAD new_configs[config_idx].__dict__["kwargs"][kwarg_key] = ( self.maybe_unpack_heuristic_result(heuristic_result) ) +======= + new_configs[config_idx].__dict__["kwargs"][ + kwarg_key + ] = self.maybe_unpack_heuristic_result(heuristic_result) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) iter_kernel = iter_kernel.fn assert isinstance(iter_kernel, JITFunction) @@ -1753,9 +1800,15 @@ def call_triton_kernel( for config in new_configs: for name in special_param_names: if name not in config.__dict__["kwargs"]: +<<<<<<< HEAD assert name in config.__dict__, ( f"{name} must be in autotuning configs to be used as a kernel parameter" ) +======= + assert ( + name in config.__dict__ + ), f"{name} must be in autotuning configs to be used as a kernel parameter" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) config.__dict__["kwargs"][name] = config.__dict__[name] updated = True diff --git a/torch/_higher_order_ops/utils.py b/torch/_higher_order_ops/utils.py index 7e5b235264fc5..586007e02d56d 100644 --- a/torch/_higher_order_ops/utils.py +++ b/torch/_higher_order_ops/utils.py @@ -1,8 +1,12 @@ # mypy: allow-untyped-defs import contextlib import functools +<<<<<<< HEAD from collections.abc import Iterable, Sequence from contextlib import AbstractContextManager, contextmanager, ExitStack, nullcontext +======= +from contextlib import contextmanager, ExitStack, nullcontext +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from dataclasses import dataclass from typing import Any, Callable, Optional, overload, TypeVar, Union @@ -103,9 +107,13 @@ def _maybe_compile_and_run_fn(fn, *args): with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(): with _temp_remove_metadata_torch_function_mode() as metadata_mode: if metadata_mode: +<<<<<<< HEAD backend: Union[str, Callable[..., Any]] = ( make_eager_backend_with_torch_function_mode(metadata_mode) ) +======= + backend = make_eager_backend_with_torch_function_mode(metadata_mode) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: backend = "eager" return torch.compile(fn, backend=backend, fullgraph=True)(*args) @@ -118,6 +126,7 @@ def reenter_make_fx(fn): @functools.wraps(fn) def wrapped(*args): +<<<<<<< HEAD assert _CURRENT_MAKE_FX_TRACER is not None, ( "Cannot reenter make_fx when we're not under a make_fx tracing session" ) @@ -125,6 +134,14 @@ def wrapped(*args): _maybe_run_with_interpreter(fn), *args ) return gm +======= + assert ( + _CURRENT_MAKE_FX_TRACER is not None + ), "Cannot reenter make_fx when we're not under a make_fx tracing session" + return _CURRENT_MAKE_FX_TRACER.trace_subgraph( + _maybe_run_with_interpreter(fn), *args + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return wrapped @@ -240,7 +257,10 @@ def diff_device( def _set_compilation_env(): _old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag _old_allow_empty_graphs = torch._dynamo.config.allow_empty_graphs +<<<<<<< HEAD _old_capture_scalar_outputs = torch._dynamo.config.capture_scalar_outputs +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # The issue is tracked in https://github.com/pytorch/pytorch/issues/144360: when dynamo finds # the top-level frame produces no graph, the default behavior is to fallback to eager. # Then when it encounters an inner function, it will try to trace that function again, which is unnecessary. @@ -254,22 +274,36 @@ def _set_compilation_env(): # once we are confident fx tracing works with dynamo. torch.fx._symbolic_trace._is_fx_tracing_flag = False torch._dynamo.config.allow_empty_graphs = True +<<<<<<< HEAD torch._dynamo.config.capture_scalar_outputs = True +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) yield finally: torch.fx._symbolic_trace._is_fx_tracing_flag = _old_is_tracing torch._dynamo.config.allow_empty_graphs = _old_allow_empty_graphs +<<<<<<< HEAD torch._dynamo.config.capture_scalar_outputs = _old_capture_scalar_outputs +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # The invariant here is that we always trace the branch with fake tensor def _maybe_fake_tracing(fn, inputs: list[Any], pre_dispatch): +<<<<<<< HEAD fake_mode_det = detect_fake_mode(inputs) fake_mode: AbstractContextManager = nullcontext() tracing_mode = "fake" if fake_mode_det is not None: fake_mode = fake_mode_det tracing_mode = "real" +======= + fake_mode = detect_fake_mode(inputs) + tracing_mode = "real" + if fake_mode is None: + fake_mode = nullcontext() + tracing_mode = "fake" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Note: we need to turn off proxy tensor mode to avoid tracing infra # code that happens in make_fx e.g. we now call as_strided when wrapping tensor @@ -281,12 +315,18 @@ def _maybe_fake_tracing(fn, inputs: list[Any], pre_dispatch): pre_dispatch=pre_dispatch, _error_on_data_dependent_ops=False, )(*inputs) +<<<<<<< HEAD if not isinstance(fake_mode, nullcontext) and fake_mode.shape_env is not None: # type: ignore[attr-defined] insert_deferred_runtime_asserts( gm, fake_mode.shape_env, # type: ignore[attr-defined] "hoo_maybe_fake_tracing", export=True, # type: ignore[attr-defined] +======= + if not isinstance(fake_mode, nullcontext) and fake_mode.shape_env is not None: + insert_deferred_runtime_asserts( + gm, fake_mode.shape_env, "hoo_maybe_fake_tracing", export=True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return gm @@ -334,15 +374,23 @@ def analyze_potential_input_alias_or_mutation(name, aliases, input_mutations): def _has_potential_branch_input_mutation(gm, inputs, pre_dispatch=False): ( +<<<<<<< HEAD (_, _, _), inp_mutation, ) = potential_input_alias_or_mutation(gm, inputs, pre_dispatch) +======= + _, + _, + _, + ), inp_mutation = potential_input_alias_or_mutation(gm, inputs, pre_dispatch) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return len(inp_mutation) > 0 def has_potential_input_alias_or_mutation(gm, inputs, pre_dispatch=False): ( +<<<<<<< HEAD ( inp_inp_alias_map, inp_out_alias_map, @@ -350,6 +398,12 @@ def has_potential_input_alias_or_mutation(gm, inputs, pre_dispatch=False): ), inp_mutation, ) = potential_input_alias_or_mutation(gm, inputs, pre_dispatch) +======= + inp_inp_alias_map, + inp_out_alias_map, + out_out_alias_map, + ), inp_mutation = potential_input_alias_or_mutation(gm, inputs, pre_dispatch) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ( any( ( @@ -405,7 +459,13 @@ def _check_alias_and_mutation(graph_module, inputs_fake, name, pre_dispatch): graph_module, inputs_fake, pre_dispatch=pre_dispatch ) if aliases: +<<<<<<< HEAD raise RuntimeError(f"{name} might be aliasing the input or the output!") # noqa: F541 +======= + raise RuntimeError( + f"{name} might be aliasing the input or the output!" + ) # noqa: F541 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if inp_mutation: raise RuntimeError(f"{name} might be modifying the input!") # noqa: F541 @@ -501,9 +561,13 @@ def fw_with_masks(*args): # require_gradness reasoning much easier. if pytree.tree_any_only(torch.Tensor, lambda t: t.requires_grad, args): fw_out = pytree.tree_map_only( +<<<<<<< HEAD torch.Tensor, lambda x: x.requires_grad_(True) if x.dtype.is_floating_point else x, fw_out, +======= + torch.Tensor, lambda x: x.requires_grad_(True), fw_out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return fw_out, pytree.tree_map_only( torch.Tensor, lambda x: x.requires_grad, fw_out @@ -517,9 +581,15 @@ def fw_with_masks(*args): # replaced with an all-zero tensor for better optimization def unmask_none_gradients(grads, operands): allowed_types = (torch.Tensor, int, torch.SymInt) +<<<<<<< HEAD assert all(isinstance(o, allowed_types) for o in operands), ( f"operands can only be of {allowed_types} but got {[type(o) for o in operands]}" ) +======= + assert all( + isinstance(o, allowed_types) for o in operands + ), f"operands can only be of {allowed_types} but got {[type(o) for o in operands]}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) unmasked_grads = [] for g, o in zip(grads, operands): @@ -722,6 +792,7 @@ def saved_tensors_and_symints(ctx): return tuple(args) +<<<<<<< HEAD def split_into_chunks(iterable: Sequence[Any], chunk_sizes: list[int]) -> list[Any]: assert sum(chunk_sizes) == len(iterable), ( "the sum of all chunks needs to match the length of the iterable." @@ -785,6 +856,8 @@ def flat_fn(*args_and_grad_outs): return flat_fn +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_dummy_aot_autograd_config(): from torch._functorch.aot_autograd import AOTConfig @@ -804,6 +877,7 @@ def first_slice_copy(t: torch.Tensor, dim: int = 0) -> torch.Tensor: return torch.select_copy(t, dim, 0) +<<<<<<< HEAD # Returns a mask whether a list element is a tensor or not def get_tensor_mask(tensor_list: Iterable[Any]) -> list[bool]: return [True if isinstance(v, torch.Tensor) else False for v in tensor_list] @@ -838,6 +912,8 @@ def first_slice_copy_with_grad(li: Iterable[Any]) -> list[Any]: return slc +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Reports the difference between meta of two tensors in a string def diff_tensor_meta( meta1: TensorMetadata, meta2: TensorMetadata, check_grad=True @@ -872,9 +948,13 @@ def validate_subgraph_args_types(lifted_args: Union[tuple[Any, ...], list[Any]]) allowed_types = (torch.Tensor, int, torch.SymInt) assert all( isinstance(arg, (torch.Tensor, int, torch.SymInt)) for arg in lifted_args +<<<<<<< HEAD ), ( f"{lifted_args} can only be of {allowed_types} but got {tuple(type(arg) for arg in lifted_args)}" ) +======= + ), f"{lifted_args} can only be of {allowed_types} but got {tuple(type(arg) for arg in lifted_args)}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO: Return a more detailed information as to which node @@ -892,10 +972,13 @@ def check_input_alias_and_mutation( return inp_inp_alias_map, inp_out_alias_map, out_out_alias_map, mutated_inputs +<<<<<<< HEAD def _tensor_storage(t) -> StorageWeakRef: return StorageWeakRef(t._typed_storage()) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def check_input_alias_and_mutation_return_outputs( gm: torch.fx.GraphModule, fake_args: Union[list[FakeTensor], tuple[FakeTensor, ...]], @@ -909,11 +992,15 @@ def check_input_alias_and_mutation_return_outputs( # This function can be called under autograd, functional, proxy and fake tensor mode. # We need to return either a fake tensor or a real tensor depending on the mode. # to detect the input mutation/aliasing. +<<<<<<< HEAD with ( disable_proxy_modes_tracing(), disable_functional_mode(), suspend_functionalization(), ): +======= + with disable_proxy_modes_tracing(), disable_functional_mode(), suspend_functionalization(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _from_functional_tensor(t: torch.Tensor) -> torch.Tensor: if isinstance(t, FunctionalTensor) or torch._is_functional_tensor(t): @@ -945,14 +1032,24 @@ def _tensor_version(t) -> Optional[int]: return t._version return None +<<<<<<< HEAD def _get_shape_env( fake_args, ) -> torch.fx.experimental.symbolic_shapes.ShapeEnv: +======= + def _tensor_storage(t) -> StorageWeakRef: + return StorageWeakRef(t._typed_storage()) + + def _get_shape_env( + fake_args, + ) -> Optional[torch.fx.experimental.symbolic_shapes.ShapeEnv]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # detect_fake_mode requires there could be only one active fake mode. This # restricts the usage of this function because the global TracingContext # has a persistent fake mode but fake tensors can be created # outside of the tracing context (e.g. in testing). # Instead, we just look at fake_args fake tensor mode +<<<<<<< HEAD for arg in fake_args: if isinstance(arg, FakeTensor) and arg.fake_mode.shape_env is not None: return arg.fake_mode.shape_env @@ -969,6 +1066,23 @@ def _get_shape_env( # We allow non fake inputs for this purpose. This is fine for mutation detection purpose: # inputs are all fake and all mutations/aliasing are still detected. allow_non_fake_inputs=True, +======= + if len(fake_args) == 0: + return torch.fx.experimental.symbolic_shapes.ShapeEnv() + + for arg in fake_args: + if isinstance(arg, FakeTensor): + return arg.fake_mode.shape_env + return None + + # Clone the fake args to avoid mutating the original fake args + with ExitStack() as ctx_stack: + # We need to re-use prev_fake_mode's shape env to resolve + # the runtime assertions for unbacked symbols. + new_fake_mode = torch._subclasses.FakeTensorMode( + shape_env=_get_shape_env(fake_args), + allow_non_fake_inputs=False, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # We need to temporarily turn inference_mode off because # under inference mode, tensor version counter is not tracked. @@ -1046,11 +1160,21 @@ def _get_shape_env( @overload +<<<<<<< HEAD def register_fake(hop, fn: None = None) -> Callable[[F], F]: ... @overload def register_fake(hop, fn: F) -> F: ... +======= +def register_fake(hop, fn: None = None) -> Callable[[F], F]: + ... + + +@overload +def register_fake(hop, fn: F) -> F: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def register_fake(hop, fn=None): @@ -1170,22 +1294,30 @@ def _materialize_as_graph_inner(): with suspend_functionalization(), disable_functional_mode(): with disable_proxy_modes_tracing(): unfunc_t = [_from_fun(arg) for arg in args] +<<<<<<< HEAD with contextlib.ExitStack() as stack: stack.enter_context( torch.utils._python_dispatch._disable_current_modes() ) stack.enter_context( +======= + with contextlib.ExitStack() as stack: + stack.enter_context( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch._C._ForceDispatchKeyGuard(include_key_set, exclude_key_set), ) if force_enable_grad: stack.enter_context(torch.enable_grad()) +<<<<<<< HEAD # fake_mode is needed because parent tracer's fake_mode might # be None but the associated inputs have fake mode or there # is a global tracing context with fake mode. We nneed to # make sure the fake mode when tracing subgraph is consistent. if fake_mode := detect_fake_mode(unfunc_t): stack.enter_context(fake_mode) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return _maybe_reenter_make_fx(fn)(*unfunc_t) gm = _materialize_as_graph_inner() @@ -1258,6 +1390,7 @@ def _has_gen_schema(op: HigherOrderOperator): return hasattr(type(op), method) and getattr(type(op), method) is not getattr( HigherOrderOperator, method ) +<<<<<<< HEAD def filter_with_masks(data: list[Optional[torch.Tensor]], masks: list[bool]): @@ -1268,3 +1401,5 @@ def filter_with_masks(data: list[Optional[torch.Tensor]], masks: list[bool]): def fill_none_with_masks(data: list[Optional[torch.Tensor]], masks: list[bool]): data_iter = iter(data) return [next(data_iter) if kept else None for kept in masks] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_higher_order_ops/while_loop.py b/torch/_higher_order_ops/while_loop.py index 02aa6ac0215ec..58417429e8a7e 100644 --- a/torch/_higher_order_ops/while_loop.py +++ b/torch/_higher_order_ops/while_loop.py @@ -1,7 +1,11 @@ # mypy: allow-untyped-defs import contextlib +<<<<<<< HEAD import functools from typing import Any, Callable, Union +======= +from typing import Callable, Union +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch.utils._pytree as pytree @@ -10,11 +14,15 @@ _maybe_run_with_interpreter, _set_compilation_env, autograd_not_implemented, +<<<<<<< HEAD check_input_alias_and_mutation_return_outputs, check_meta_consistency, fill_none_with_masks, filter_with_masks, materialize_as_graph, +======= + check_meta_consistency, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) reenter_make_fx, validate_subgraph_args_types, ) @@ -22,7 +30,10 @@ from torch._subclasses.fake_tensor import FakeTensorMode from torch.fx.experimental.proxy_tensor import ( _temp_remove_metadata_torch_function_mode, +<<<<<<< HEAD disable_proxy_modes_tracing, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ProxyTorchDispatchMode, track_tensor_tree, ) @@ -53,6 +64,7 @@ def __call__( validate_subgraph_args_types(additional_inputs) return super().__call__(cond_fn, body_fn, carried_inputs, additional_inputs) +<<<<<<< HEAD def gen_schema(self, cond_fn, body_fn, carried_inputs, additional_inputs): from torch._higher_order_ops.schema import HopSchemaGenerator from torch._higher_order_ops.utils import materialize_as_graph @@ -130,6 +142,8 @@ def _find_example_value(n, real_inp): ) return schema_gen.gen_schema() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) while_loop_op = WhileLoopOp() @@ -190,9 +204,15 @@ def body_fn(int_iter, x): - body_fn and cond_fn must not in-place mutate the carried_inputs. A clone before the mutation is required. +<<<<<<< HEAD - body_fn and cond_fn must not mutate python variables (e.g. list/dict) created outside of the body_fn. - body_fn and cond_fn's output cannot alias any of the inputs. A clone is required. +======= + - body_fn and cond_fn must not mutate python varialbles (e.g. list/dict) created outside of the body_fn. + + - body_fn and cond_fn's output cannot aliase any of the inputs. A clone is required. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .. warning:: Temporal Limitations: @@ -253,9 +273,13 @@ def _while_loop_op_wrapper(*args, **kwargs): with _temp_remove_metadata_torch_function_mode() as metadata_mode: with _temp_remove_metadata_torch_function_mode() as metadata_mode: if metadata_mode: +<<<<<<< HEAD backend: Union[str, Callable[..., Any]] = ( make_eager_backend_with_torch_function_mode(metadata_mode) ) +======= + backend = make_eager_backend_with_torch_function_mode(metadata_mode) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: backend = "eager" return torch.compile( @@ -264,9 +288,13 @@ def _while_loop_op_wrapper(*args, **kwargs): @while_loop_op.py_impl(DispatchKey.CompositeExplicitAutograd) +<<<<<<< HEAD def while_loop_dense( cond_fn, body_fn, carried_inputs, additional_inputs, stack_output=False ): +======= +def while_loop_dense(cond_fn, body_fn, carried_inputs, additional_inputs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) carried_vals = carried_inputs def _validate_cond_output(pred): @@ -286,6 +314,7 @@ def _validate_cond_output(pred): f"carried_inputs must be a tuple or list but got {type(carried_inputs)}" ) +<<<<<<< HEAD # Check condition and set up flag should_loop = cond_fn(*carried_vals, *additional_inputs) _validate_cond_output(should_loop) @@ -339,6 +368,24 @@ def while_loop_autograd(cond_fn, body_fn, operands, additional_inputs): *operands, *additional_inputs, ) +======= + while pred := cond_fn(*carried_vals, *additional_inputs): + _validate_cond_output(pred) + out = body_fn(*carried_vals, *additional_inputs) + assert isinstance( + out, tuple + ), f"body_fn should return a tuple but got {type(out)}" + assert len(out) == len( + carried_inputs + ), "body_fn should return the same number of elements as carried_inputs" + carried_vals = out + return carried_vals + + +while_loop_op.py_autograd_impl( + autograd_not_implemented(while_loop_op, deferred_error=True) +) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _find_or_create_fake_mode() -> FakeTensorMode: @@ -354,9 +401,15 @@ def _find_or_create_fake_mode() -> FakeTensorMode: def _create_unbacked_symint( fake_mode: FakeTensorMode, ignore_fresh_unbacked_symbols: bool ) -> torch.SymInt: +<<<<<<< HEAD assert fake_mode is not None and fake_mode.shape_env is not None, ( "Must provide a fake_mode with shape_env." ) +======= + assert ( + fake_mode is not None and fake_mode.shape_env is not None + ), "Must provide a fake_mode with shape_env." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ctx = ( contextlib.nullcontext() if not ignore_fresh_unbacked_symbols @@ -367,6 +420,7 @@ def _create_unbacked_symint( @while_loop_op.py_impl(ProxyTorchDispatchMode) +<<<<<<< HEAD def while_loop_tracing( mode, cond_fn, @@ -379,6 +433,11 @@ def while_loop_tracing( def _trace_while_loop( proxy_mode, op, cond_fn, body_fn, carried_inputs, additional_inputs +======= +def while_loop_tracing(mode, cond_fn, body_fn, carried_inputs, additional_inputs): + def _trace_while_loop( + proxy_mode, while_loop_op, cond_fn, body_fn, carried_inputs, additional_inputs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): # NOTE [unspecialize int carry with unbacked symints] # When we support int carry, we'll also need to support int output of body_fn because. @@ -412,12 +471,18 @@ def _trace_while_loop( # For this reason, we treat int, symint outputs in the same way: # - they can match against any of int, symint carry # - we unspecialize them with new unbacked symints in fake while_loop +<<<<<<< HEAD # Similarly, we could do some analysis to refine the output ranges but it's easier to start with # fresh unbacked symints. One surprising case can be: an input unbacked symint is constrained by +======= + # Similarly, we could do some analysis to refine the output ranges but it's eaiser to start with + # fresh unbacked symints. One suprising case can be: an input unbacked symint is constrained by +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # users to be >= 0 (either before while_loop or inside body_fn) and it increments by 1 in each # iteration. Ideally, we should know that the final output is >= 0 but we didn't constrain the # unbacked symint output of subgraph as of today because this requires a smart range analysis. fake_mode: FakeTensorMode = _find_or_create_fake_mode() +<<<<<<< HEAD def _unspecialize_carried_inputs(x): if isinstance(x, (int, torch.SymInt)): @@ -456,6 +521,23 @@ def produce_graph(fn): cond_graph = produce_graph(cond_fn) body_graph = produce_graph(body_fn) +======= + unspecialized_carried_inputs = pytree.tree_map_only( + (int, torch.SymInt), + # For temporarily created unbacked symints, we don't need to bind them to any proxy + lambda _: _create_unbacked_symint( + fake_mode, ignore_fresh_unbacked_symbols=True + ), + carried_inputs, + ) + + cond_graph = reenter_make_fx(cond_fn)( + *unspecialized_carried_inputs, *additional_inputs + ) + body_graph = reenter_make_fx(body_fn)( + *unspecialized_carried_inputs, *additional_inputs + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) next_name = None i = 0 @@ -477,10 +559,17 @@ def produce_graph(fn): proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args) out_proxy = proxy_mode.tracer.create_proxy( +<<<<<<< HEAD "call_function", op, proxy_args, {}, name=op._name ) out = op( +======= + "call_function", while_loop_op, proxy_args, {}, name="while_loop" + ) + + out = while_loop_op( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cond_graph, body_graph, unspecialized_carried_inputs, additional_inputs ) return track_tensor_tree( @@ -488,18 +577,26 @@ def produce_graph(fn): ) return _trace_while_loop( +<<<<<<< HEAD mode, op, cond_fn, body_fn, carried_inputs, additional_inputs, +======= + mode, while_loop_op, cond_fn, body_fn, carried_inputs, additional_inputs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @while_loop_op.py_impl(FakeTensorMode) def while_loop_fake_tensor_mode( +<<<<<<< HEAD mode, cond_fn, body_fn, carried_inputs, additional_inputs, stack_output=False +======= + mode, cond_fn, body_fn, carried_inputs, additional_inputs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): with mode: # NOTE: [Handling unback symints in subgraph of while_loop] @@ -544,6 +641,7 @@ def while_loop_fake_tensor_mode( "body_output", include_contiguity=False, ) +<<<<<<< HEAD if stack_output: n_iter = _create_unbacked_symint(mode, ignore_fresh_unbacked_symbols=False) @@ -564,6 +662,8 @@ def while_loop_fake_tensor_mode( fake_outputs, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # See NOTE [unspecialize int carry with unbacked symints] return pytree.tree_map_only( (int, torch.SymInt), @@ -577,6 +677,7 @@ def while_loop_fake_tensor_mode( @while_loop_op.py_functionalize_impl +<<<<<<< HEAD def while_loop_func( ctx, cond_fn, body_fn, carried_inputs, additional_inputs, stack_output=False ): @@ -584,6 +685,11 @@ def while_loop_func( op = while_loop_stack_output_op if stack_output else while_loop_op +======= +def while_loop_func(ctx, cond_fn, body_fn, carried_inputs, additional_inputs): + from torch._higher_order_ops.utils import _check_alias_and_mutation + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) unwrapped_carried_inputs = ctx.unwrap_tensors(carried_inputs) unwrapped_additional_inputs = ctx.unwrap_tensors(additional_inputs) unwrapped_inputs = unwrapped_carried_inputs + unwrapped_additional_inputs @@ -596,13 +702,18 @@ def while_loop_func( (body_fn, "body_fn"), ]: _check_alias_and_mutation(fn, unwrapped_inputs, fn_name, pre_dispatch) +<<<<<<< HEAD ret = op( +======= + ret = while_loop_op( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) functional_cond_fn, functional_body_fn, unwrapped_carried_inputs, unwrapped_additional_inputs, ) return ctx.wrap_tensors(ret) +<<<<<<< HEAD class WhileLoopStackOutputOp(HigherOrderOperator): @@ -927,3 +1038,5 @@ def body_fn(*flat_args): while_loop_stack_output_op.py_autograd_impl( autograd_not_implemented(while_loop_stack_output_op, deferred_error=True) ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_higher_order_ops/wrap.py b/torch/_higher_order_ops/wrap.py index 8e9ca0503402c..e251193f9ecad 100644 --- a/torch/_higher_order_ops/wrap.py +++ b/torch/_higher_order_ops/wrap.py @@ -2,6 +2,7 @@ import inspect import itertools import logging +<<<<<<< HEAD from typing import Any, Optional import torch @@ -10,6 +11,12 @@ from torch._ops import HigherOrderOperator from torch.fx import GraphModule from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree +======= +from typing import Optional + +from torch._logging import warning_once +from torch._ops import HigherOrderOperator +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.types import _dtype @@ -231,8 +238,12 @@ def divide_kwargs(kwargs): } return checkpoint_kwargs, gmod_kwargs +<<<<<<< HEAD @staticmethod def tag_nodes(gmod, is_sac): +======= + def tag_nodes(self, gmod, is_sac): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.utils.checkpoint import CheckpointPolicy unique_graph_id = next(uid) @@ -248,6 +259,7 @@ def tag_nodes(gmod, is_sac): return gmod def __call__(self, gmod, *args, **kwargs): +<<<<<<< HEAD dispatch_key_set = torch._ops._compute_keyset( args, kwargs, self.non_fallthrough_keys ) @@ -330,3 +342,46 @@ def proxy_mode_key( return track_tensor_tree( example_out, out_proxy, constant=None, tracer=proxy_mode.tracer ) +======= + import torch.fx.traceback as fx_traceback + from torch.fx import Interpreter + + if "_checkpoint_context_fn" in gmod.meta: + warning_once( + log, + """ +Detected that context_fn is passed to torch.utils.checkpoint under torch.compile. +Please make sure the checkpointed region does not contain in-place ops (e.g. torch.relu_). +""", + ) + # use_reentrant is set to False because this op is going to be traced. + # And we ensure that AOT Autograd traces through the non reentrant + # version of checkpointing. + kwargs["use_reentrant"] = False + # preserve_rng_state is set to False because we want to prevent AOTAutograd from tracing through + # `torch.random.fork_rng` op (which is not supported yet under CUDA). + # This doesn't mean that we don't preserve RNG state. Instead, we will always preserve RNG state + # regardless of this flag (by doing RNG functionalization via `replace_random_passes` in Inductor + # instead of in AOTAutograd). + kwargs["preserve_rng_state"] = False + kwargs["context_fn"] = gmod.meta["_checkpoint_context_fn"] + # We first tag all nodes as "recompute" in this graph, and then we undo the "recompute" tag + # for specific nodes in _CachingTorchDispatchMode in torch/utils/checkpoint.py. + gmod = self.tag_nodes(gmod, is_sac=True) + # Using interpreter allows preservation of metadata through torch.compile stack. + with fx_traceback.preserve_node_meta(): + from torch.utils.checkpoint import checkpoint + + return checkpoint(Interpreter(gmod).run, *args, **kwargs) + else: + gmod = self.tag_nodes(gmod, is_sac=False) + # Using interpreter allows preservation of metadata through torch.compile stack. + # TODO: We want to use the same `checkpoint(Interpreter(gmod).run, *args, **kwargs)` here + # as the `context_fn != None` case, but that depends on in-place op support in TorchDispatchMode + torch.compile. + # (for details on in-place op issue, run `test_compile_selective_checkpoint_inplace_op` unit test) + with fx_traceback.preserve_node_meta(): + return Interpreter(gmod).run(*args) + + +tag_activation_checkpoint = TagActivationCheckpoint() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_inductor/__init__.py b/torch/_inductor/__init__.py index d287337afaa69..ed2f415fd2b19 100644 --- a/torch/_inductor/__init__.py +++ b/torch/_inductor/__init__.py @@ -6,6 +6,10 @@ import os from typing import Any, IO, Literal, Optional, TYPE_CHECKING, Union +<<<<<<< HEAD +======= +import torch._inductor.config +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch.fx from .standalone_compile import CompiledArtifact # noqa: TC001 @@ -14,7 +18,10 @@ if TYPE_CHECKING: from torch._inductor.utils import InputType from torch.export import ExportedProgram +<<<<<<< HEAD from torch.export.pt2_archive._package import AOTICompiledModel +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.export.pt2_archive._package_weights import Weights from torch.types import FileLike @@ -223,7 +230,11 @@ def _aoti_compile_and_package_inner( not_strict_accuracy = check_accuracy == "accuracy" if not same_two_models( gm, +<<<<<<< HEAD compiled_model, # type: ignore[arg-type] +======= + compiled_model, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) args, only_fwd=True, require_fp64=not_strict_accuracy, @@ -238,7 +249,11 @@ def _aoti_compile_and_package_inner( def aoti_load_package( path: FileLike, run_single_threaded: bool = False, device_index: int = -1 +<<<<<<< HEAD ) -> AOTICompiledModel: +======= +) -> Any: # type: ignore[type-arg] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Loads the model from the PT2 package. @@ -275,7 +290,11 @@ def aot_compile( kwargs: Optional[dict[str, Any]] = None, *, options: Optional[dict[str, Any]] = None, +<<<<<<< HEAD ) -> Union[str, list[Union[str, Weights]], torch.fx.GraphModule]: +======= +) -> Union[str, list[Union[str, Weights]]]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Ahead-of-time compile a given FX graph with TorchInductor into a shared library. @@ -292,6 +311,7 @@ def aot_compile( """ from .compile_fx import _aoti_flatten_inputs, compile_fx_aot +<<<<<<< HEAD if hasattr(gm, "_guards_fn"): # Do not compile the guards function, since it may contain checks # that are not currently supported by AOTI. In particular, non-Tensor @@ -301,6 +321,8 @@ def aot_compile( delattr(gm, "_guards_fn") gm.recompile() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) flat_example_inputs, options = _aoti_flatten_inputs( gm, args, kwargs, options=options ) diff --git a/torch/_inductor/async_compile.py b/torch/_inductor/async_compile.py index 9f941c04e7b38..f713d3c7b4097 100644 --- a/torch/_inductor/async_compile.py +++ b/torch/_inductor/async_compile.py @@ -2,13 +2,19 @@ from __future__ import annotations import atexit +<<<<<<< HEAD import contextlib +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import functools import json import logging import multiprocessing import os +<<<<<<< HEAD import re +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import sys from concurrent.futures import Future, ThreadPoolExecutor from concurrent.futures.process import BrokenProcessPool @@ -38,11 +44,15 @@ StaticAutotunerFuture, torch_key, ) +<<<<<<< HEAD from torch._inductor.compile_worker.subproc_pool import ( AnyPool, SubprocException, SubprocPool, ) +======= +from torch._inductor.compile_worker.subproc_pool import AnyPool, SubprocPool +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.compile_worker.tracked_process_pool import ( TrackedProcessPoolExecutor, ) @@ -53,7 +63,10 @@ ) from torch._inductor.utils import clear_on_fresh_cache from torch._inductor.virtualized import V +<<<<<<< HEAD from torch._utils_internal import log_triton_builds +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.hub import _Faketqdm, tqdm from torch.utils._ordered_set import OrderedSet from torch.utils._triton import has_triton_package @@ -73,10 +86,13 @@ _triton_kernel_metrics: Optional[dict[str, dict[str, Any]]] = None +<<<<<<< HEAD size_hints_regex = re.compile( r"size_hints=(\{.*?\})", ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def pre_fork_setup(): """ @@ -88,10 +104,20 @@ def pre_fork_setup(): # Computing the triton key can be slow. If we call it before fork, # it will be cached for the forked subprocesses. +<<<<<<< HEAD from torch._inductor.runtime.triton_compat import HAS_TRITON, triton_key if HAS_TRITON: triton_key() +======= + try: + from triton.compiler.compiler import triton_key + + triton_key() + except ImportError: + # Triton might not be installed or might be an old version. + pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def caching_device_properties(): @@ -148,7 +174,10 @@ def shutdown_compile_workers() -> None: """Shut down all outstanding compile-worker pools.""" for pool in _pool_set: pool.shutdown() +<<<<<<< HEAD AsyncCompile._ready_future = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) after_fork() @@ -228,6 +257,7 @@ def remove_future(kernel_src: str) -> None: del CompiledTritonKernels._cache[key] +<<<<<<< HEAD @contextlib.contextmanager def async_compile_pool_manager(): """ @@ -247,6 +277,9 @@ class AsyncCompile: _ready_future: Optional[Future[Any]] = None +======= +class AsyncCompile: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __init__(self) -> None: pass @@ -265,7 +298,10 @@ def _get_ready(): @functools.lru_cache(1) def process_pool() -> AnyPool: assert get_compile_threads() > 1 +<<<<<<< HEAD AsyncCompile._ready_future = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) log.info( "Creating '%s' pool with %d workers", config.worker_start_method, @@ -293,6 +329,11 @@ def process_pool() -> AnyPool: # kill the worker thread that sends the shutdown message to the workers... multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize) +<<<<<<< HEAD +======= + # Set an attribute we can check to see if the pool is ready. + pool.ready_future = pool.submit(AsyncCompile._get_ready) # type: ignore[union-attr] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _pool_set.add(pool) return pool @@ -301,24 +342,32 @@ def warm_pool(cls) -> None: if get_compile_threads() <= 1: return _compile_start() +<<<<<<< HEAD # Pool is created on first access. Note for a SubprocPool, the sidecar process starts, # but its ProcessPoolExecutor does not initialize until a wakeup() call or the first # job is submitted. +======= + # Pool is initialized on first access +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cls.process_pool() _compile_end() @classmethod +<<<<<<< HEAD def wait_pool_ready(cls, timeout=120) -> None: cls.use_process_pool() if cls._ready_future is not None: cls._ready_future.result(timeout=timeout) @classmethod +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def submit(cls, task: Callable[..., Any]) -> Any: if get_compile_threads() <= 1: return task() return cls.pool().submit(task) +<<<<<<< HEAD @classmethod def use_process_pool(cls): if get_compile_threads() <= 1: @@ -356,6 +405,12 @@ def wakeup(cls) -> None: pool = cls.process_pool() if isinstance(pool, SubprocPool): pool.wakeup() +======= + def use_process_pool(self): + return ( + get_compile_threads() > 1 and self.process_pool().ready_future.done() # type: ignore[union-attr] + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"): """ @@ -426,6 +481,7 @@ def reload_kernel_in_parent(): "use_static_cuda_launcher": torch._inductor.config.use_static_cuda_launcher } +<<<<<<< HEAD if len(torch._inductor.config.autotune_lookup_table) > 0: m = size_hints_regex.search(source_code) if m: @@ -447,6 +503,8 @@ def reload_kernel_in_parent(): fn_hash: torch._inductor.config.autotune_lookup_table[fn_hash] } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) task = self.process_pool().submit( _worker_compile_triton, load_kernel, @@ -455,18 +513,25 @@ def reload_kernel_in_parent(): ) def get_result() -> CachingAutotuner: +<<<<<<< HEAD try: kernel, elapsed_us = task.result() except SubprocException as e: raise e.with_name(kernel_name) from e +======= + kernel, elapsed_us = task.result() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Now that we've compiled, we should clear the future # so it can't be used again kernel.set_compile_info(compile_id, is_backward) CompiledTritonKernels.remove_future(source_code) +<<<<<<< HEAD kernel.restore_after_unpickle(old_values=None) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kernel.precompile( warm_cache_only=False, reload_kernel=reload_kernel_in_parent, @@ -491,6 +556,7 @@ def get_result() -> CachingAutotuner: log_waitcounter=True, waitcounter_name_override="compile_triton", ): +<<<<<<< HEAD fail = None try: start_ns = time_ns() @@ -514,6 +580,24 @@ def get_result() -> CachingAutotuner: raise finally: log_triton_builds(fail=fail) +======= + start_ns = time_ns() + _set_triton_ptxas_path() + kernel = load_kernel() + kernel.set_compile_info(compile_id, is_backward) + kernel.precompile( + warm_cache_only=False, + static_triton_bundle_key=CompiledTritonKernels.key(source_code), + ) + elapsed_us = (time_ns() - start_ns) // 1000 + get_metrics_context().add_top_n( + "triton_kernel_compile_times_us", kernel_name, elapsed_us + ) + info = kernel.autotune_cache_info or {} + info["compile_time_us"] = elapsed_us + _add_triton_kernel_info(kernel_name, info) + return kernel +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def multi_kernel(self, *args, **kwargs) -> Any: from torch._inductor.codegen.multi_kernel import MultiKernelCall @@ -580,6 +664,7 @@ def halide(self, meta: HalideMeta, source_code: str): ) return LambdaFuture(get_result) +<<<<<<< HEAD def cutedsl(self, kernel_name: str, source_code: str): """ Compile CuteDSL (CUTLASS Python DSL) kernels. @@ -619,6 +704,8 @@ def task(): future = self.submit(task) return LambdaFuture(lambda: future.result()) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def wait(self, scope: dict[str, Any]) -> None: if get_compile_threads() > 1: with dynamo_timed( @@ -660,6 +747,7 @@ def _wait_futures(self, scope: dict[str, Any]) -> None: pbar.update(1) +<<<<<<< HEAD def maybe_warm_pool() -> None: if ( os.environ.get("TORCH_TNT_IN_USE", "0") == "1" @@ -678,6 +766,20 @@ def maybe_warm_pool() -> None: # could start them lazily if we're willing to lose a small amount of compile time. AsyncCompile.wakeup() +======= +if ( + os.environ.get("TORCH_TNT_IN_USE", "0") == "1" + or os.environ.get("TORCH_WARM_POOL", "1") != "1" + # The subprocess pool is only used for the Triton backend + or not has_triton_package() + # Skip for fbcode. We have internal reports of usages inside multiprocessing + # pools that lead a multiplicative number of compile subprocesses. + or config.is_fbcode() +): + pass +else: + AsyncCompile.warm_pool() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # On exit give the workers a chance to clean themselves up. Without this the # resource_tracker can complain about leaked semaphores coming from the diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py index a504b54f132b7..9018d977e4778 100644 --- a/torch/_inductor/autotune_process.py +++ b/torch/_inductor/autotune_process.py @@ -31,12 +31,16 @@ get_hash, PyCodeCache, ) +<<<<<<< HEAD from torch._inductor.utils import ( get_gpu_type, get_ld_library_path, is_gpu, python_subprocess_env, ) +======= +from torch._inductor.utils import get_gpu_type, get_ld_library_path, is_gpu +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._logging import getArtifactLogger from torch.utils._ordered_set import OrderedSet @@ -44,7 +48,11 @@ if TYPE_CHECKING: from types import ModuleType +<<<<<<< HEAD from torch._inductor.select_algorithm import PartialRender, TritonTemplateCaller +======= + from torch._inductor.select_algorithm import TritonTemplateCaller +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from . import config from .runtime.benchmarking import benchmarker @@ -128,8 +136,16 @@ def start(self): f"--read-fd={str(subproc_read_fd)}", f"--write-fd={str(subproc_write_fd)}", ] +<<<<<<< HEAD env = { **python_subprocess_env(), +======= + extra_env = { + # We need to set the PYTHONPATH so the subprocess can find torch. + "PYTHONPATH": os.environ.get( + "TORCH_CUSTOM_PYTHONPATH", os.pathsep.join(sys.path) + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # We shouldn't be using the Triton async compile subprocess pool, # but as a precaution set the env var that disables its creation. "TORCH_WARM_POOL": "0", @@ -141,10 +157,17 @@ def start(self): else "0", } if self.device is not None: +<<<<<<< HEAD env[CUDA_VISIBLE_DEVICES] = str(self.device) self.process = subprocess.Popen( cmd, env=env, +======= + extra_env[CUDA_VISIBLE_DEVICES] = str(self.device) + self.process = subprocess.Popen( + cmd, + env={**os.environ, **extra_env}, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pass_fds=(subproc_read_fd, subproc_write_fd), ) os.close(subproc_read_fd) @@ -766,7 +789,11 @@ def update_workspace_size(self) -> None: return self.ensure_dll_loaded() unique_input_count = len( +<<<<<<< HEAD dict.fromkeys(meta.name for meta in self.input_tensor_meta) +======= + {meta.name for meta in self.input_tensor_meta} # noqa: set_linter +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) args = [c_void_p(None) for _ in range(unique_input_count + 1)] stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream) @@ -876,6 +903,7 @@ def __str__(self) -> str: return f"{self.kernel_name=}" +<<<<<<< HEAD class CuteDSLBenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest): """Benchmark request for CuteDSL (CUTLASS Python DSL) kernels.""" @@ -925,6 +953,8 @@ def cleanup_run_fn(self) -> None: """Clean up any resources used by the kernel.""" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @functools.cache def get_tuning_process_pool() -> TuningProcessPool: pool = TuningProcessPool() diff --git a/torch/_inductor/choices.py b/torch/_inductor/choices.py index 2189e44f9e246..b1fd958f56b06 100644 --- a/torch/_inductor/choices.py +++ b/torch/_inductor/choices.py @@ -1,7 +1,11 @@ from __future__ import annotations import typing +<<<<<<< HEAD from typing import Any, Optional, TYPE_CHECKING, Union +======= +from typing import Any, Optional, TYPE_CHECKING +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import sympy @@ -9,6 +13,7 @@ from . import config from .codecache import write_text +<<<<<<< HEAD from .kernel_inputs import KernelInputs # noqa: TC001 from .metrics import get_metric_table, is_metric_table_enabled from .runtime.hints import DeviceProperties, ReductionHint @@ -19,6 +24,15 @@ CPUConfigHeuristic, CUDAConfigHeuristic, MTIAConfigHeuristic, +======= +from .metrics import get_metric_table, is_metric_table_enabled +from .runtime.hints import DeviceProperties, ReductionHint +from .scheduler import BaseSchedulerNode, Scheduler, WhyNoFuse +from .template_heuristics import ( + BaseConfigHeuristic, + CPUConfigHeuristic, + CUDAConfigHeuristic, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ROCmConfigHeuristic, XPUConfigHeuristic, ) @@ -33,11 +47,16 @@ from torch.utils._ordered_set import OrderedSet +<<<<<<< HEAD from .codegen.common import KernelTemplate from .codegen.simd_kernel_features import SIMDKernelFeatures from .codegen.triton import TritonKernel from .ir import ChoiceCaller from .select_algorithm import ExternKernelChoice +======= + from .codegen.simd_kernel_features import SIMDKernelFeatures + from .codegen.triton import TritonKernel +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class Sortable(typing.Protocol): @@ -71,11 +90,69 @@ def get_config_heuristics( return XPUConfigHeuristic() elif device_type == "cpu": return CPUConfigHeuristic() +<<<<<<< HEAD elif device_type == "mtia": return MTIAConfigHeuristic() else: return BaseConfigHeuristic() +======= + else: + return BaseConfigHeuristic() + + # GEMM configs + def get_base_mm_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: + mm_heuristics = self.get_config_heuristics(device_type) + if config.max_autotune_gemm_search_space != "EXHAUSTIVE": + return mm_heuristics.get_mm_configs() + else: + return mm_heuristics.get_exhaustive_mm_configs() + + def get_extra_mm_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: + mm_heuristics = self.get_config_heuristics(device_type) + return mm_heuristics.get_extra_mm_configs() + + def get_int8_mm_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: + mm_heuristics = self.get_config_heuristics(device_type) + return mm_heuristics.get_int8_mm_configs() + + def get_mixed_mm_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: + mm_heuristics = self.get_config_heuristics(device_type) + return mm_heuristics.get_mixed_mm_configs() + + def get_persistent_mm_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: + mm_heuristics = self.get_config_heuristics(device_type) + return mm_heuristics.get_persistent_mm_configs() + + def get_scaled_mm_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: + mm_heuristics = self.get_config_heuristics(device_type) + return mm_heuristics.get_scaled_mm_configs() + + def get_scaled_persistent_mm_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: + mm_heuristics = self.get_config_heuristics(device_type) + return mm_heuristics.get_scaled_persistent_mm_configs() + + def get_mm_plus_mm_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: + mm_heuristics = self.get_config_heuristics(device_type) + return mm_heuristics.get_mm_plus_mm_configs() + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Conv configs def get_conv_configs( self, device_type: Optional[str] = "cuda" @@ -84,7 +161,10 @@ def get_conv_configs( return conv_heuristics.get_conv_configs() # Flex attention configs +<<<<<<< HEAD # TODO(coconutruben): break out flexattention/decode configs into the new retrieval mechanism +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_flex_attention_fwd_configs( self, head_dim: int, dtype: torch.dtype, device_type: Optional[str] = "cuda" ) -> list[Any]: @@ -103,6 +183,7 @@ def get_flex_decode_configs( flex_heuristics = self.get_config_heuristics(device_type) return flex_heuristics.get_flex_decode_configs(head_dim, dtype) +<<<<<<< HEAD def get_mm_configs( self, kernel_inputs: KernelInputs, @@ -167,6 +248,8 @@ def get_mm_configs( if choice is not None: yield choice +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def triton_kernel_kwargs( self, kernel_cls: type[TritonKernel], @@ -216,9 +299,13 @@ def should_use_persistent_reduction( if cooperative_reduction: # The RSPLIT of cooperative reductions means each thread block is operating on fewer elements try: +<<<<<<< HEAD threshold *= 32 // min( V.graph.sizevars.size_hint_or_throw(features.numel), 32 ) +======= + threshold *= 32 // min(V.graph.sizevars.size_hint(features.numel), 32) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) except ValueError: pass # unbacked symint @@ -386,6 +473,7 @@ def can_fuse( WhyNoFuse(node1, node2)("Fusion will increase peak memory") return False +<<<<<<< HEAD if ( config.realize_acc_reads_size_threshold is not None and scheduler.fusion_accumulate_large_reads( @@ -397,6 +485,8 @@ def can_fuse( WhyNoFuse(node1, node2)("Fusion accumulate large amount of reads") return False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return True @staticmethod diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 7b24208a2c512..2c03d1614fcaa 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -30,7 +30,10 @@ from datetime import timedelta from functools import lru_cache, partial from pathlib import Path +<<<<<<< HEAD from tempfile import _TemporaryFileWrapper +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from time import time, time_ns from types import ModuleType from typing import ( @@ -53,7 +56,10 @@ from torch._dynamo.utils import CompileEventLogger, counters, dynamo_timed from torch._inductor import config, exc, metrics from torch._inductor.codegen.common import ( +<<<<<<< HEAD custom_backend_codegen_configs, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) custom_backend_passes, init_backend_registration, ) @@ -76,15 +82,21 @@ get_ld_and_objcopy, get_name_and_dir_from_output_file_path, normalize_path_separator, +<<<<<<< HEAD run_asm_build_object, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) from torch._inductor.cpu_vec_isa import pick_vec_isa from torch._inductor.custom_graph_pass import ( CustomGraphModulePass, CustomGraphPass, CustomGraphPassType, +<<<<<<< HEAD CustomPartitionerFn, CustomPartitionerFnType, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) from torch._inductor.freezing_utils import has_frozen_params, is_frozen_param from torch._inductor.runtime.compile_tasks import _reload_python_module @@ -124,6 +136,29 @@ if config.is_fbcode(): from triton.fb.build import build_paths +<<<<<<< HEAD +======= + from torch._inductor.fb.utils import ( + log_global_cache_errors, + log_global_cache_stats, + log_global_cache_vals, + use_global_cache, + ) +else: + + def log_global_cache_errors(*args: Any, **kwargs: Any) -> None: + pass + + def log_global_cache_stats(*args: Any, **kwargs: Any) -> None: + pass + + def log_global_cache_vals(*args: Any, **kwargs: Any) -> None: + pass + + def use_global_cache() -> bool: + return False + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T = TypeVar("T") @@ -146,7 +181,10 @@ LOCK_TIMEOUT = 600 output_code_log = torch._logging.getArtifactLogger(__name__, "output_code") +<<<<<<< HEAD autotuning_log = torch._logging.getArtifactLogger(__name__, "autotuning") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) log = logging.getLogger(__name__) @@ -174,10 +212,23 @@ def get_kernel_bin_format(device: str) -> str: return "" +<<<<<<< HEAD +======= +@functools.cache +def get_global_cache_path_impl(global_cache_dir: str) -> Optional[Path]: + return ( + Path(os.path.join(global_cache_dir, CacheBase.get_system()["hash"])) + if global_cache_dir is not None + else None + ) + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class CacheBase: @staticmethod @functools.cache def get_system() -> dict[str, Any]: +<<<<<<< HEAD from torch._inductor.runtime.triton_compat import HAS_TRITON, triton_key if HAS_TRITON: @@ -185,6 +236,15 @@ def get_system() -> dict[str, Any]: # is not updated with each code change triton_version = triton_key() else: +======= + try: + from triton.compiler.compiler import triton_key + + # Use triton_key instead of triton.__version__ as the version + # is not updated with each code change + triton_version = triton_key() + except ModuleNotFoundError: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) triton_version = None try: @@ -219,6 +279,13 @@ def get_system() -> dict[str, Any]: def get_local_cache_path() -> Path: return Path(os.path.join(cache_dir(), "cache", CacheBase.get_system()["hash"])) +<<<<<<< HEAD +======= + @staticmethod + def get_global_cache_path() -> Optional[Path]: + return get_global_cache_path_impl(config.global_cache_dir) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __init__(self) -> None: self.system = CacheBase.get_system() @@ -265,43 +332,83 @@ def set_value(self, *keys: str, value: Any) -> None: class PersistentCache(CacheBase): +<<<<<<< HEAD +======= + @functools.cache # noqa: B019 + def get_global_cache(self) -> dict[str, Any]: + global_cache_path = self.get_global_cache_path() + if global_cache_path is None or not global_cache_path.is_file(): + return {} + with open(global_cache_path) as global_cache_fp: + global_cache = json.load(global_cache_fp) + return global_cache["cache"] + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def lookup( self, choices: list[ChoiceCaller], op: str, inputs: str, benchmark: Optional[Callable[[Any], dict[ChoiceCaller, float]]], +<<<<<<< HEAD hint_override: Optional[int] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> dict[ChoiceCaller, float]: """ Check to see if we have benchmarked the given choice callers. For each choice caller: +<<<<<<< HEAD 1. Check local_cache[op][inputs][choice][precision], return benchmark if cached. 2. If benchmark is not None: +======= + 1. Check global_cache[op][inputs][choice][precision], return benchmark if cached. + 2. Check local_cache[op][inputs][choice][precision], return benchmark if cached. + 3. If benchmark is not None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) a. `max_autotune_gemm=True`: benchmark the choice, update local_cache[op][inputs][choice], and return the benchmark. b. `max_autotune_gemm=False`: don't benchmark the choice, return nothing. """ precision = torch.get_float32_matmul_precision() +<<<<<<< HEAD cache_key = f"{inputs}_{hint_override}" if hint_override is not None else inputs timings = {} def check_cache(cache: dict[str, Any]) -> bool: +======= + + log_stats = partial(log_global_cache_stats, self.system, op, inputs, precision) + log_vals = partial(log_global_cache_vals, self.system, op, inputs, precision) + log_errors = partial( + log_global_cache_errors, self.system, op, inputs, precision + ) + timings = {} + + def check_cache(cache: dict[str, Any], callback: Any = None) -> bool: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Check if `cache` contains data for all the choices""" hit = True for choice in choices: choice_hash = choice.hash_key() +<<<<<<< HEAD if choice_hash in cache.get(op, {}).get(cache_key, {}).get( precision, {} ): # cache hit timings[choice] = cache[op][cache_key][precision][choice_hash] +======= + if choice_hash in cache.get(op, {}).get(inputs, {}).get(precision, {}): + # cache hit + timings[choice] = cache[op][inputs][precision][choice_hash] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: # cache miss hit = False break +<<<<<<< HEAD return hit local_cache = self.get_local_cache() if config.autotune_local_cache else {} @@ -315,6 +422,46 @@ def check_cache(cache: dict[str, Any]) -> bool: local_cache[op][cache_key][precision][choice.hash_key()] = timing self.update_local_cache(local_cache) +======= + if callback: + callback(cached=hit) + return hit + + if config.max_autotune or config.max_autotune_gemm: + local_cache = self.get_local_cache() if config.autotune_local_cache else {} + # check local cache first since it is data specific to the current machine + if ( + not check_cache(local_cache) + and not ( + use_global_cache() + and check_cache(self.get_global_cache(), callback=log_stats) + ) + and benchmark is not None + ): + try: + # re-benchmark everything to try to get consistent numbers from the same machine + timings = benchmark(choices) + assert all(choice in timings for choice in choices) + local_cache.setdefault(op, {}) + local_cache[op].setdefault(inputs, {}).setdefault(precision, {}) + for choice, timing in timings.items(): + local_cache[op][inputs][precision][choice.hash_key()] = timing + except RuntimeError as e: + # catch and log autotuning failures + log_errors(e) + raise e + + self.update_local_cache(local_cache) + + timings_to_log = { + choice.hash_key(): timings[choice] for choice in choices + } + log_vals(timings_to_log) + elif use_global_cache(): + # only check global cache, not local one + check_cache(self.get_global_cache(), callback=log_stats) + # may have a partial cache hit, where not everything is benchmarked +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return timings @@ -363,6 +510,7 @@ def get_hash( raise AssertionError(f"Unknown hash type {hash_type}") +<<<<<<< HEAD class WritableTempFile: """ Avoid "Permission denied error" on Windows: @@ -393,6 +541,8 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: os.unlink(self.temp_file.name) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def write( content: Union[str, bytes], extension: str, @@ -853,7 +1003,11 @@ def __init__( # Global settings affecting matmul codegen. self.cuda_matmul_settings = ( +<<<<<<< HEAD torch.backends.cuda.matmul.fp32_precision, +======= + torch.backends.cuda.matmul.allow_tf32, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction, torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction, ) @@ -866,6 +1020,7 @@ def __init__( self.post_grad_custom_pre_pass = self._get_custom_pass_detail( config.post_grad_custom_pre_pass ) +<<<<<<< HEAD # TODO: change to more holistic config rather than bundled_autograd_cache self.precompile_enabled = torch._functorch.config.bundled_autograd_cache self.post_grad_custom_post_pass = self._get_custom_pass_detail( @@ -877,6 +1032,11 @@ def __init__( self.joint_custom_post_pass = self._get_custom_pass_detail( config.joint_custom_post_pass ) +======= + self.post_grad_custom_post_pass = self._get_custom_pass_detail( + config.post_grad_custom_post_pass + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._pre_fusion_custom_pass = self._get_custom_pass_detail_unsafe( config._pre_fusion_custom_pass ) @@ -890,6 +1050,7 @@ def __init__( map(self._get_custom_pass_detail, custom_backend_passes.values()) ) +<<<<<<< HEAD # Save custom inductor codegen configs self.custom_backend_codegen_configs = { device: custom_config.save_config_portable(ignore_private_configs=False) @@ -902,6 +1063,8 @@ def __init__( config.custom_partitioner_fn ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This is mainly added to handle these two inductor configs, which are (unfortunately) # sometimes cache safe: # - _pre_fusion_custom_pass @@ -934,6 +1097,7 @@ def _get_custom_pass_detail( assert isinstance(custom_pass, (CustomGraphPass, CustomGraphModulePass)) return custom_pass.uuid() +<<<<<<< HEAD def _get_custom_partitioner_fn_detail( self, custom_partitioner_fn: CustomPartitionerFnType ) -> Optional[Any]: @@ -942,6 +1106,8 @@ def _get_custom_partitioner_fn_detail( assert isinstance(custom_partitioner_fn, CustomPartitionerFn) return custom_partitioner_fn.uuid() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def compiled_fx_graph_hash( gm: torch.fx.GraphModule, @@ -1113,7 +1279,11 @@ def _get_shape_env(cls: type[GuardedCache[T]]) -> Optional[ShapeEnv]: Helper to get the shape env from the tracing context. """ ctx = torch._guards.TracingContext.try_get() +<<<<<<< HEAD if not ctx or not ctx.fake_mode: +======= + if not ctx: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return None return ctx.fake_mode.shape_env @@ -1248,6 +1418,7 @@ def cache_hit_post_compile( lambda: {"filename": artifact_path}, payload_fn=lambda: code, ) +<<<<<<< HEAD trace_structured( "artifact", metadata_fn=lambda: { @@ -1264,6 +1435,8 @@ def cache_hit_post_compile( }, payload_fn=lambda: graph.inductor_provenance_stack_traces_str, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return graph, cache_info @staticmethod @@ -1426,10 +1599,13 @@ def _check_can_cache(gm: torch.fx.GraphModule) -> None: for p in (config.post_grad_custom_pre_pass, config.post_grad_custom_post_pass): if p and (not isinstance(p, CustomGraphPass) or not p.uuid()): raise BypassFxGraphCache("Unsupported post grad custom pass") +<<<<<<< HEAD # Same with the joint custom passes for p in (config.joint_custom_pre_pass, config.joint_custom_post_pass): if p and (not isinstance(p, CustomGraphPass) or not p.uuid()): raise BypassFxGraphCache("Unsupported joint custom pass") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # We should find any users of _pre_fusion_custom_pass and _fuse_ddp_communication_passes # and ensure they are not passing us raw callables if config._pre_fusion_custom_pass is not None: @@ -1602,6 +1778,7 @@ def clear() -> None: @functools.cache def split_aot_inductor_output_path(path: str) -> tuple[str, str]: +<<<<<<< HEAD def get_module_ext_type() -> str: if _IS_WINDOWS: return ".pyd" @@ -1610,6 +1787,10 @@ def get_module_ext_type() -> str: """Returns the path where the AOT Inductor compiled kernels are stored.""" if path.endswith(get_module_ext_type()): +======= + """Returns the path where the AOT Inductor compiled kernels are stored.""" + if path.endswith(".so"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return os.path.split(path) elif path.endswith(".pt2"): return os.path.split(path) @@ -1713,6 +1894,12 @@ def compile( """ generated_files: list[Union[str, Weights]] = additional_files # type: ignore[assignment] +<<<<<<< HEAD +======= + if sys.platform == "win32": + raise RuntimeError("AotCodeCompiler not yet supported for inductor") + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _set_gpu_runtime_env() # cpp_extension consults the env picked_vec_isa = pick_vec_isa() @@ -1768,6 +1955,7 @@ def compile( key=config.aot_inductor.model_name_for_generated_files, ) +<<<<<<< HEAD header_code = "" header_path = "" if config.aot_inductor.compile_standalone: @@ -1812,6 +2000,10 @@ def compile( with WritableTempFile("w", suffix=".gv") as temp_file: tree.to_dotfile(temp_file.name) """ +======= + # Log the AOTInductor wrapper and kernel code, if needed. + with tempfile.NamedTemporaryFile("w+") as t: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) t.writelines((wrapper_code, "\n", kernel_code, "\n")) t.flush() V.debug.output_code(t.name, extension="cpp") @@ -1820,8 +2012,11 @@ def compile( generated_files.append(wrapper_path) if not config.aot_inductor.package_cpp_only: generated_files.append(kernel_path) +<<<<<<< HEAD if config.aot_inductor.compile_standalone: generated_files.append(header_path) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) output_code_log.info("Wrapper code written to: %s", wrapper_path) output_code_log.info("Kernel code written to: %s", kernel_path) @@ -1843,6 +2038,7 @@ def compile( }, payload_fn=lambda: kernel_code, ) +<<<<<<< HEAD if config.aot_inductor.compile_standalone: output_code_log.info("Header code written to: %s", header_path) trace_structured( @@ -1854,6 +2050,8 @@ def compile( }, payload_fn=lambda: header_code, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # We use a file lock below to protect FS operations. The lock file # is scoped to the 'key', so make sure the consts_s is protected @@ -1866,9 +2064,12 @@ def compile( cmake_path = str(Path(specified_sub_dir) / "CMakeLists.txt") def _compile_consts(consts: bytes, platform: str) -> str: +<<<<<<< HEAD # Load from aot_inductor, and update the value on demand. use_asm_build: bool = config.aot_inductor.use_consts_asm_build +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if platform == "linux": if graph.mutated_buffers & OrderedSet(graph.constants.keys()): # .data section is between .text and .bss. When the size of .data is large, @@ -1885,6 +2086,7 @@ def _compile_consts(consts: bytes, platform: str) -> str: elif platform == "darwin": section_attr = "__DATA,__data" symbol_prefix = "_" +<<<<<<< HEAD elif platform == "win32": symbol_prefix = "" # ASM build is not supported on Windows, force use CPP build. @@ -2018,6 +2220,38 @@ def get_zero_consts_asm_code( consts_s = Path(consts_s) object_build_options = CppTorchDeviceOptions( device_type=device_type, +======= + else: + raise RuntimeError(f"Unsupported platform: {platform}") + + is_large_consts = len(consts) > 1024 + consts_asm = f"\t.section\t{section_attr}\n" + consts_asm += f"\t.balign {ALIGN_BYTES}\n" + consts_asm += f"\t.globl\t{symbol_prefix}_binary_constants_bin_start\n" + consts_asm += f"{symbol_prefix}_binary_constants_bin_start:\n" + if not is_large_consts: + for c in consts: + consts_asm += f"\t.byte {c}\n" + # Add one element even if constants are empty + # Otherwise assembler will not put them in data section + if not consts: + consts_asm += "\t.space 1\n" + else: + consts_asm += "\t.quad 0x1234567899abcdef\n" + consts_asm += f"\t.space {len(consts) - 8}\n" + consts_asm += f".globl\t{symbol_prefix}_binary_constants_bin_end\n" + consts_asm += f"{symbol_prefix}_binary_constants_bin_end:\n" + _, consts_s = write( + consts_asm, + "S", + specified_dir=str(specified_sub_dir), + ) + consts_s = Path(consts_s) + object_build_options = CppTorchDeviceOptions( + # Intel compiler failed to compile this manually constructed assembly file. + # it is ok to use gcc to compile the .S to a .o and linked with Intel compiler . + device_type=device_type if device_type != "xpu" else "cpu", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) aot_mode=graph.aot_mode, compile_only=True, use_relative_path=use_relative_path, @@ -2029,21 +2263,31 @@ def get_zero_consts_asm_code( BuildOption=object_build_options, ) consts_o = object_builder.get_target_file_path() +<<<<<<< HEAD if use_asm_build is False and is_zero_size_consts: run_asm_build_object(str(consts_s), consts_o, str(consts_s.parent)) else: object_builder.build() if is_large_consts and use_asm_build: +======= + object_builder.build() + + if is_large_consts: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with open(consts_o, "r+b") as f: f.seek(0) hdr = f.read(1024) # Search for magic number and write the actual data over it +<<<<<<< HEAD start_idx = ( hdr.find(b"\xef\xcd\xab\x99\x78\x56\x34\x12") if sys.byteorder == "little" else hdr.find(b"\x12\x34\x56\x78\x99\xab\xcd\xef") ) +======= + start_idx = hdr.find(b"\xef\xcd\xab\x99\x78\x56\x34\x12") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert start_idx != -1 f.seek(start_idx) pos = 0 @@ -2312,6 +2556,7 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: cubins_o = [] asm_files = [] +<<<<<<< HEAD if not _IS_WINDOWS: ld, objcopy = get_ld_and_objcopy(use_relative_path) kernels = getattr(V.graph.wrapper_code, "_kernel_name_to_body", {}) @@ -2356,6 +2601,32 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: cubins_o.append( convert_cubin_to_obj(cubin_file, kernel_name, ld, objcopy) ) +======= + ld, objcopy = get_ld_and_objcopy(use_relative_path) + for kernel_name, value in CudaKernelParamCache.cache.items(): + if asm_file := value["asm"]: + asm_files.append(asm_file) + + cubin_file = value[get_cpp_wrapper_cubin_path_name()] + if config.aot_inductor.emit_multi_arch_kernel and device_type == "cuda": + current_arch = _nvcc_arch_as_compile_option() + cmd = ( + f"{_cuda_compiler()} -fatbin {asm_file} -o {cubin_file} " + # Triton only allows generating PTX version as same as the current arch + f"-gencode arch=compute_{current_arch},code=compute_{current_arch} " + # Include SASS for the current specific arch + f"-gencode arch=compute_{current_arch},code=sm_{current_arch} " + ) + subprocess.run( + cmd.split(), capture_output=True, text=True, check=True + ) + + if config.aot_inductor.embed_kernel_binary: + # Embed cubin files into model.so using objcopy + cubins_o.append( + convert_cubin_to_obj(cubin_file, kernel_name, ld, objcopy) + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) output_name, output_dir = get_name_and_dir_from_output_file_path(output_so) so_build_options = CppTorchDeviceOptions( @@ -2435,6 +2706,7 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: os.remove(o_file) if use_mmap_weights: +<<<<<<< HEAD def get_page_size() -> int: # Don't use resource.getpagesize() on Windows, as it is a Unix specific package @@ -2473,6 +2745,11 @@ class SYSTEM_INFO(Structure): return sys_page_size page_size_ = get_page_size() +======= + import resource + + page_size_ = resource.getpagesize() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) page_size = max(16384, page_size_) with open(output_so, "a+b") as f_so: @@ -2486,6 +2763,7 @@ class SYSTEM_INFO(Structure): generated_files.append(output_so) if config.aot_inductor.package: +<<<<<<< HEAD if config.trace.provenance_tracking_level != 0: kernel_info = torch._inductor.debug.create_kernel_information_json() kernel_info_json = os.path.join( @@ -2495,6 +2773,8 @@ class SYSTEM_INFO(Structure): f.write(json.dumps(kernel_info, indent=4)) generated_files.append(kernel_info_json) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # We want to return the directory that contains all the AOTI # generated files, not just the so # return os.path.split(output_so)[0] @@ -2634,7 +2914,11 @@ def _get_cpp_prefix_header(device: str) -> Optional[str]: def _get_cpp_wrapper_header(device: str, aot_mode: bool = False) -> str: """Given a device type (and optionally whether we're in AOT Inductor mode), returns the path to the cpp_wrapper header file to be precompiled.""" +<<<<<<< HEAD base_device = device.split(":", maxsplit=1)[0] +======= + base_device = device.split(":")[0] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) is_array_ref = config.aot_inductor.allow_stack_allocation and base_device == "cpu" return ( "torch/csrc/inductor/" @@ -3203,9 +3487,13 @@ def _codegen_buffer(cls, name: str, arg: HalideInputSpec, cuda: bool) -> list[st return [ f"halide_buffer_t {name};", +<<<<<<< HEAD f"halide_dimension_t {name}_dims[] = {{{', '.join(dims)}}};" if len(dims) > 0 else f"halide_dimension_t * {name}_dims = nullptr;", +======= + f"halide_dimension_t {name}_dims[] = {{{', '.join(dims)}}};", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f"{name}.device = {device};", f"{name}.device_interface = {device_interface};", f"{name}.host = {host};", @@ -3632,7 +3920,11 @@ def _cutlass_path() -> str: if config.is_fbcode(): from libfb.py import parutil +<<<<<<< HEAD return parutil.get_dir_path("cutlass-4-headers") +======= + return parutil.get_dir_path("cutlass-3-headers") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: return config.cuda.cutlass_dir @@ -3673,9 +3965,13 @@ def cutlass_key() -> bytes: Note: OSS and fbcode will have different keys. """ if config.is_fbcode(): +<<<<<<< HEAD with importlib.resources.path( "cutlass_library", "src_hash.txt" ) as resource_path: +======= + with importlib.resources.path("cutlass", "src_hash.txt") as resource_path: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with open(resource_path) as resource_file: return resource_file.read().encode() @@ -3704,12 +4000,18 @@ def _cuda_lib_options() -> list[str]: if "torch/lib" in path: # don't want to depend on pytorch continue +<<<<<<< HEAD extra_ldflags.append(f"-L{path}") # -rpath ensures the DLL can find its dependencies when loaded, even # if the library path is non-standard. # But do not add the stubs folder to rpath as the driver is expected to be found at runtime if os.path.basename(path) != "stubs": extra_ldflags.extend(["-Xlinker", f"-rpath={path}"]) +======= + # -rpath ensures the DLL can find its dependencies when loaded, even + # if the library path is non-standard. + extra_ldflags.extend([f"-L{path}", "-Xlinker", f"-rpath={path}"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) extra_ldflags.append("-lcuda") extra_ldflags.append("-lcudart") else: @@ -3818,10 +4120,14 @@ def cuda_compile_command( res = f"{_cuda_compiler()} {' '.join(options)} -o {dst_file} {src_file}" else: raise NotImplementedError(f"Unsupported output file suffix {dst_file_ext}!") +<<<<<<< HEAD if log.isEnabledFor(logging.DEBUG): log.debug("CUDA command: %s", res) else: autotuning_log.debug("CUDA command: %s", res) +======= + log.debug("CUDA command: %s", res) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return res @@ -3966,7 +4272,10 @@ def get_kernel_binary_remote_cache( return None @classmethod +<<<<<<< HEAD @lru_cache(None) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def write(cls, source_code: str, dst_file_ext: str) -> tuple[str, str]: """ Writes source code into a file with dst_file_ext as the file extension. @@ -3991,6 +4300,12 @@ def write(cls, source_code: str, dst_file_ext: str) -> tuple[str, str]: cutlass_key(), # hack to deal with AOTI .o compilation ] +<<<<<<< HEAD +======= + + [dst_file_ext] + if dst_file_ext == "o" + else [] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) key, input_path = write(source_code, cls._SOURCE_CODE_SUFFIX, extra=extra) return key, input_path @@ -4001,6 +4316,7 @@ def compile( ) -> tuple[str, str, str]: """ Compiles CUDA source_code into a file with dst_file_ext extension. +<<<<<<< HEAD If dst_file_ext is "so", first compiles to ".o" and then links to ".so". Returns a tuple of dst_file_path, hash_key, source_code_path """ @@ -4016,6 +4332,12 @@ def compile( key_with_ext = key + dst_file_ext if key_with_ext not in cls.cache: +======= + Returns a tuple of dst_file_path, hash_key, source_code_path + """ + key, input_path = cls.write(source_code, dst_file_ext) + if key not in cls.cache: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.utils._filelock import FileLock lock_dir = get_lock_dir() @@ -4047,12 +4369,17 @@ def compile( binary_remote_cache.put( error_path, config.cuda.binary_remote_cache_force_write ) +<<<<<<< HEAD cls.cache[key_with_ext] = CUDACodeCache.CacheEntry( +======= + cls.cache[key] = CUDACodeCache.CacheEntry( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) input_path, output_path, error_json ) raise exc.CUDACompileError(cmd_parts, error_output) if not os.path.exists(output_path): cmd = cuda_compile_command( +<<<<<<< HEAD src_files, output_path, dst_file_ext, extra_args ) with open(input_path, "a") as f: @@ -4060,6 +4387,15 @@ def compile( f.write(f"// CUDA {operation_name} cmd\n// {cmd}\n") start_time = time() log.debug("CUDA %s: %s", operation_name, cmd) +======= + [input_path], output_path, dst_file_ext, extra_args + ) + with open(input_path, "a") as f: + f.write("\n") + f.write(f"// CUDA Compile cmd\n// {cmd}\n") + start_time = time() + log.debug("CUDA Compilation: %s", cmd) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cmd_parts = cmd.split(" ") try: if use_re_build(): @@ -4077,7 +4413,11 @@ def compile( except subprocess.CalledProcessError as error: cls._record_cuda_compile_error( error.output.decode("utf-8"), +<<<<<<< HEAD key_with_ext, +======= + key, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cmd_parts, input_path, output_path, @@ -4088,7 +4428,11 @@ def compile( if "COMPILE FAILED WITH" in str(error): cls._record_cuda_compile_error( str(error), +<<<<<<< HEAD key_with_ext, +======= + key, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cmd_parts, input_path, output_path, @@ -4097,14 +4441,23 @@ def compile( raise exc.CUDACompileError(cmd_parts, str(error)) from error raise error end_time = time() +<<<<<<< HEAD log_duration_msg = f"CUDA {operation_name} took {end_time - start_time} seconds. Command: {cmd}" +======= + log_duration_msg = f"CUDA Compilation took {end_time - start_time} seconds. Compile command: {cmd}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) log.info(log_duration_msg) else: log.debug( +<<<<<<< HEAD "CUDA %s skipped: %s since output already exists", operation_name, output_path, +======= + "CUDA Compilation skipped: %s since output already exists", + input_path, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # Upload to remote cache if enabled if ( @@ -4115,16 +4468,25 @@ def compile( binary_remote_cache.put( output_path, config.cuda.binary_remote_cache_force_write ) +<<<<<<< HEAD cls.cache[key_with_ext] = CUDACodeCache.CacheEntry( input_path, output_path, None ) cache_entry: CUDACodeCache.CacheEntry = cls.cache[key_with_ext] +======= + cls.cache[key] = CUDACodeCache.CacheEntry(input_path, output_path, None) + cache_entry: CUDACodeCache.CacheEntry = cls.cache[key] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if cache_entry.error_json is not None: # Restore cached Exception and raise it as if we had compiled cmd_parts, error_output = json.loads(cache_entry.error_json) raise exc.CUDACompileError(cmd_parts, error_output.encode("utf-8")) +<<<<<<< HEAD return (cls.cache[key_with_ext].output_path, key, input_path) +======= + return (cls.cache[key].output_path, key, input_path) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @classmethod def load(cls, source_code: str, dst_file_ext: str) -> tuple[DLLWrapper, str, str]: @@ -4147,7 +4509,11 @@ def load(cls, source_code: str, dst_file_ext: str) -> tuple[DLLWrapper, str, str def _record_cuda_compile_error( cls, error_str: str, +<<<<<<< HEAD key_with_ext: str, +======= + key: str, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cmd_parts: list[str], input_path: str, output_path: str, @@ -4156,9 +4522,13 @@ def _record_cuda_compile_error( binary_remote_cache: Any = None, ) -> None: error_json = json.dumps([cmd_parts, error_str]) +<<<<<<< HEAD cls.cache[key_with_ext] = CUDACodeCache.CacheEntry( input_path, output_path, error_json ) +======= + cls.cache[key] = CUDACodeCache.CacheEntry(input_path, output_path, error_json) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) error_path = binary_error_path(output_path) with open(error_path, "w", encoding="utf-8") as fh: fh.write(error_json) diff --git a/torch/_inductor/codegen/block_analysis.py b/torch/_inductor/codegen/block_analysis.py index b47c8325e2154..e6eb91ccd2c5e 100644 --- a/torch/_inductor/codegen/block_analysis.py +++ b/torch/_inductor/codegen/block_analysis.py @@ -17,6 +17,7 @@ class BlockPatternMatcher: Matches block indexing expressions. """ +<<<<<<< HEAD _indexing_wild_signed_int = functools.partial( sympy.Wild, properties=[lambda x: x.is_integer] ) @@ -24,6 +25,8 @@ class BlockPatternMatcher: sympy.Wild, properties=[lambda x: x.is_integer and x.is_nonnegative] ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @classmethod def get_subexpr_involving_symbol(cls, expr: Expr, symbol: Symbol) -> Expr: """ @@ -70,6 +73,7 @@ def match_mod_div_block_expr( index = cls._preprocess(index) # Pattern match to find the strides and offset. +<<<<<<< HEAD wild_unsigned_int = functools.partial( cls._indexing_wild_unsigned_int, exclude=[index_var] ) @@ -82,6 +86,11 @@ def match_mod_div_block_expr( strides: list[Expr] = [ wild_signed_int(f"stride_mod{idx}") for idx in range(num_dims) ] +======= + wild = functools.partial(sympy.Wild, exclude=[index_var]) + dims: list[Expr] = [wild(f"dim_mod{idx}") for idx in range(num_dims)] + strides: list[Expr] = [wild(f"stride_mod{idx}") for idx in range(num_dims)] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # The first dimension's index is computed by division. # The remaining are computed by modulo. @@ -99,8 +108,12 @@ def match_mod_div_block_expr( # for more details. In short, here we check that each subexpression in sympy.Add contains # only FloorDiv or ModularIndexing expressions. if num_dims >= 5: +<<<<<<< HEAD stride = sympy.symbols("stride", cls=wild_signed_int) denom, other = sympy.symbols("denominator other", cls=wild_unsigned_int) +======= + stride, denom, other = sympy.symbols("stride denominator other", cls=wild) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mod_div_pattern = stride * ModularIndexing(index_var, denom, other) floor_div_pattern = stride * FloorDiv(index_var, denom) first_dim_floor_div_matched = False @@ -184,7 +197,11 @@ def match_affine_block_expr( stride. """ index = cls._preprocess(index) +<<<<<<< HEAD stride = cls._indexing_wild_signed_int(name="stride", exclude=[index_var]) +======= + stride = sympy.Wild("stride", exclude=[index_var]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) m = index.match(index_var * stride) if m is None: return None diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 9802358b02eee..48afa0e5eb070 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -34,7 +34,10 @@ import torch.fx from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND from torch.utils import _pytree as pytree +<<<<<<< HEAD from torch.utils._config_module import ConfigModule +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.numbers import int_oo from torch.utils._sympy.printers import PythonPrinter as _PythonPrinter @@ -44,7 +47,10 @@ from .. import config, metrics from ..dtype_propagation import DtypePropagationOpsHandler from ..ops_handler import BasicMathOpsMixin, DefaultHandler +<<<<<<< HEAD from ..shape_propagation import ShapePropagationOpsHandler +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from ..utils import ( boolean_ops, DeferredLineBase, @@ -71,7 +77,10 @@ from ..ir import Buffer, ChoiceCaller, FixedLayout, IRNode from ..loop_body import LoopBody from ..scheduler import BaseScheduling, Scheduler, SchedulerNode +<<<<<<< HEAD from ..shape_propagation import BlockShapeType +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .wrapper import PythonWrapperCodegen _T = TypeVar("_T") @@ -255,9 +264,12 @@ def get_stride(self) -> list[sympy.Expr]: def get_name(self) -> str: return self.outer_name +<<<<<<< HEAD def get_is_pinned(self) -> bool: return False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_inputs_that_alias_output(self) -> list[str]: return [] @@ -308,7 +320,10 @@ class DeviceCodegen: scheduling: SchedulingConstructor wrapper_codegen: WrapperConstructor cpp_wrapper_codegen: Optional[WrapperConstructor] = None +<<<<<<< HEAD fx_wrapper_codegen: Optional[WrapperConstructor] = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg, TMADescriptorArg, ConstexprArg] @@ -365,8 +380,13 @@ def cpp_device_ptr(self) -> str: def tma_descriptor_helpers(self) -> str: raise NotImplementedError +<<<<<<< HEAD def cpp_scratch( self, idx: int, workspace: TritonScratchWorkspace, prefix: Optional[str] = None +======= + def cpp_global_scratch( + self, idx: int, workspace: TritonScratchWorkspace +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Optional[tuple[list[str], str]]: # optionally return (scratch definition, arg name) raise NotImplementedError @@ -374,7 +394,10 @@ def cpp_scratch( device_op_overrides_dict: dict[str, DeviceOpOverrides] = {} custom_backend_passes: dict[str, Optional[CustomGraphModulePass]] = {} +<<<<<<< HEAD custom_backend_codegen_configs: dict[str, Optional[ConfigModule]] = {} +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # The code generated by Inductor consists of two main parts: kernel code and wrapper code. @@ -403,6 +426,7 @@ def register_backend_for_device( device_scheduling: SchedulingConstructor, device_wrapper_codegen: WrapperConstructor, device_cpp_wrapper_codegen: Optional[WrapperConstructor] = None, +<<<<<<< HEAD device_fx_wrapper_codegen: Optional[WrapperConstructor] = None, device_custom_pass: Optional[CustomGraphModulePass] = None, device_custom_config: Optional[ConfigModule] = None, @@ -422,6 +446,14 @@ def register_backend_for_device( f"{device_custom_config=} cannot be the same as the default inductor config {config=}" ) custom_backend_codegen_configs[device] = device_custom_config +======= + device_custom_pass: Optional[CustomGraphModulePass] = None, +) -> None: + device_codegens[device] = DeviceCodegen( + device_scheduling, device_wrapper_codegen, device_cpp_wrapper_codegen + ) + custom_backend_passes[device] = device_custom_pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class BackendFeature(Enum): @@ -468,6 +500,7 @@ def get_scheduling_for_device(device: str) -> Optional[SchedulingConstructor]: def get_wrapper_codegen_for_device( +<<<<<<< HEAD device: str, cpp_wrapper: bool = False, fx_wrapper: bool = False ) -> Optional[WrapperConstructor]: if device in device_codegens: @@ -478,6 +511,17 @@ def get_wrapper_codegen_for_device( return wrapper_codegen_obj.cpp_wrapper_codegen else: return wrapper_codegen_obj.wrapper_codegen +======= + device: str, cpp_wrapper: bool = False +) -> Optional[WrapperConstructor]: + if device in device_codegens: + wrapper_codegen_obj: DeviceCodegen = device_codegens[device] + return ( + wrapper_codegen_obj.cpp_wrapper_codegen + if cpp_wrapper + else wrapper_codegen_obj.wrapper_codegen + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return None @@ -485,6 +529,7 @@ def get_custom_backend_pass_for_device(device: str) -> Optional[CustomGraphModul return custom_backend_passes[device] if device in custom_backend_passes else None +<<<<<<< HEAD def get_custom_backend_config_for_device(device: str) -> Optional[ConfigModule]: return ( custom_backend_codegen_configs[device] @@ -499,6 +544,10 @@ def init_backend_registration() -> None: Register the backend for different devices, including the scheduling for kernel code generation and the host side wrapper code generation. """ +======= +@functools.cache +def init_backend_registration() -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .cpp import CppScheduling from .cpp_wrapper_cpu import CppWrapperCpu from .cpp_wrapper_cpu_array_ref import CppWrapperCpuArrayRef @@ -507,10 +556,15 @@ def init_backend_registration() -> None: from .cuda_combined_scheduling import CUDACombinedScheduling from .halide import HalideScheduling from .mps import MetalScheduling +<<<<<<< HEAD from .python_wrapper_mtia import PythonWrapperMtia from .triton import TritonScheduling from .wrapper import PythonWrapperCodegen from .wrapper_fxir import WrapperFxCodegen +======= + from .triton import TritonScheduling + from .wrapper import PythonWrapperCodegen +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if get_scheduling_for_device("cpu") is None: cpu_backends = { @@ -525,7 +579,10 @@ def init_backend_registration() -> None: CppWrapperCpuArrayRef if config.aot_inductor.allow_stack_allocation else CppWrapperCpu, +<<<<<<< HEAD WrapperFxCodegen, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if get_scheduling_for_device("cuda") is None: @@ -539,7 +596,10 @@ def init_backend_registration() -> None: lambda scheduling: cuda_backends[config.cuda_backend](scheduling), PythonWrapperCodegen, CppWrapperGpu, +<<<<<<< HEAD WrapperFxCodegen, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if get_scheduling_for_device("xpu") is None: @@ -548,7 +608,10 @@ def init_backend_registration() -> None: TritonScheduling, PythonWrapperCodegen, CppWrapperGpu, +<<<<<<< HEAD WrapperFxCodegen, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if get_scheduling_for_device("mps") is None: @@ -557,6 +620,7 @@ def init_backend_registration() -> None: MetalScheduling, PythonWrapperCodegen, CppWrapperMps, +<<<<<<< HEAD WrapperFxCodegen, ) @@ -567,6 +631,8 @@ def init_backend_registration() -> None: PythonWrapperMtia, CppWrapperGpu, WrapperFxCodegen, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) private_backend = torch._C._get_privateuse1_backend_name() @@ -580,14 +646,20 @@ def init_backend_registration() -> None: device_scheduling = _get_custom_mod_func("Scheduling") wrapper_codegen = _get_custom_mod_func("PythonWrapperCodegen") cpp_wrapper_codegen = _get_custom_mod_func("CppWrapperCodegen") +<<<<<<< HEAD fx_wrapper_codegen = _get_custom_mod_func("WrapperFxCodegen") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if device_scheduling and wrapper_codegen and cpp_wrapper_codegen: register_backend_for_device( private_backend, device_scheduling, wrapper_codegen, cpp_wrapper_codegen, +<<<<<<< HEAD fx_wrapper_codegen, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) except RuntimeError: pass @@ -616,7 +688,10 @@ def get_device_op_overrides(device: str) -> DeviceOpOverrides: if not device_op_overrides_dict: from . import cpu_device_op_overrides, mps_device_op_overrides # noqa: F401 from .cuda import device_op_overrides # noqa: F401 +<<<<<<< HEAD from .mtia import device_op_overrides as mtia_op_overrides # noqa: F401 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .xpu import device_op_overrides as xpu_op_overrides # noqa: F401 return device_op_overrides_dict[device] @@ -820,6 +895,7 @@ def doprint( expr = V.graph.sizevars.simplify(expr) return super().doprint(expr) +<<<<<<< HEAD def parenthesize(self, item: sympy.Expr, level: int, strict: bool = False) -> str: if isinstance(item, sympy.Mod): # use parenthesis to enforce precedence. @@ -828,6 +904,8 @@ def parenthesize(self, item: sympy.Expr, level: int, strict: bool = False) -> st else: return super().parenthesize(item, level, strict) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class OpDecompositions: """ @@ -1794,7 +1872,10 @@ def __init__( name: str, bounds: ValueRanges[Any], dtype: Optional[torch.dtype] = None, +<<<<<<< HEAD shape: BlockShapeType = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): super().__init__() assert isinstance(bounds, ValueRanges), type(bounds) @@ -1802,7 +1883,10 @@ def __init__( self.bounds = bounds self.use_count = 1 # track how many times this expression is used self.dtype = dtype +<<<<<<< HEAD self.shape = shape +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __str__(self) -> str: return self.name @@ -1912,7 +1996,10 @@ def generate( write: bool = True, assignment: bool = True, dtype: Optional[torch.dtype] = None, +<<<<<<< HEAD shape: BlockShapeType = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> CSEVariableType: if isinstance(expr, OpsValue): expr = expr.value @@ -1933,12 +2020,17 @@ def generate( assert isinstance(expr, str) cache_key = expr var = self.try_get(cache_key) +<<<<<<< HEAD if shape is None and not assignment: # since there's no assignment to a variable, use any shape here # other than None to avoid the unknown shape failures shape = () if not var: var = self.newvar(bounds, dtype, shape) +======= + if not var: + var = self.newvar(bounds, dtype) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.put(cache_key, var) if write: if V.kernel.current_node: @@ -1984,10 +2076,16 @@ def newvar( self, bounds: ValueRanges[Any] = ValueRanges.unknown(), dtype: Optional[torch.dtype] = None, +<<<<<<< HEAD shape: BlockShapeType = None, ) -> CSEVariableType: var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}" var = V.kernel.create_cse_var(var_name, bounds, dtype, shape) +======= + ) -> CSEVariableType: + var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}" + var = V.kernel.create_cse_var(var_name, bounds, dtype) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.varname_map[var_name] = var return var @@ -1996,12 +2094,19 @@ def namedvar( name: str, bounds: ValueRanges[Any] = ValueRanges.unknown(), dtype: Optional[torch.dtype] = None, +<<<<<<< HEAD shape: BlockShapeType = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> CSEVariableType: torch._check_value( name not in self.varname_map, lambda: f"duplicate name: {name}" ) +<<<<<<< HEAD var = V.kernel.create_cse_var(name, bounds, dtype, shape) +======= + var = V.kernel.create_cse_var(name, bounds, dtype) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.varname_map[name] = var return var @@ -2410,6 +2515,7 @@ def get_dtype(name: str) -> torch.dtype: def __init__(self, name: str) -> None: self.name = name +<<<<<<< HEAD @property def uid(self) -> str: """ @@ -2433,6 +2539,8 @@ def choice_or_none(self, **kwargs: Any) -> Optional[ChoiceCaller]: return temp_choices[0] return None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def maybe_append_choice( self, choices: list[Any], **kwargs: Any ) -> Optional[NotImplementedError]: @@ -2480,6 +2588,7 @@ def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> value = getattr(self.parent_handler, name)(*args, **kwargs) dtype_handler = DtypePropagationOpsHandler() +<<<<<<< HEAD shape_handler = ShapePropagationOpsHandler() backend = get_current_backend() @@ -2491,16 +2600,30 @@ def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> if name == "masked" and backend == "triton": output_dtype = value.dtype output_shape = value.shape +======= + + backend = get_current_backend() + + output_dtype = None + if name == "masked" and backend == "triton": + output_dtype = value.dtype +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif name == "masked" and backend == "cpp": output_dtype = V.interpreter.current_node.meta.get( OptimizationContext.key, None ).dtype +<<<<<<< HEAD # TODO: fix me output_shape = None elif backend in ("triton", "cpp", "mps"): dtype_op = getattr(dtype_handler, name) output_dtype = dtype_op(*args, **kwargs) output_shape = shape_op(*args, **kwargs) +======= + elif backend in ("triton", "cpp", "mps"): + dtype_op = getattr(dtype_handler, name) + output_dtype = dtype_op(*args, **kwargs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if backend in ("triton", "cpp"): # maybe there are some exceptions on mps? @@ -2508,7 +2631,11 @@ def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> output_idx = 0 +<<<<<<< HEAD def do_cse(v: Union[str, CSEVariable]) -> CSEVariable: +======= + def do_cse(v: str) -> CSEVariable: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # we tree_map over the output, so we need to fetch corresponding dtype nonlocal output_idx var_dtype: Optional[torch.dtype] = ( @@ -2516,6 +2643,7 @@ def do_cse(v: Union[str, CSEVariable]) -> CSEVariable: if isinstance(output_dtype, (list, tuple)) else output_dtype ) +<<<<<<< HEAD var_shape: BlockShapeType = ( output_shape[output_idx] # type: ignore[assignment] if isinstance(output_shape, (list, tuple)) @@ -2531,13 +2659,23 @@ def do_cse(v: Union[str, CSEVariable]) -> CSEVariable: v.dtype = var_dtype if v.shape is None: v.shape = var_shape +======= + output_idx += 1 + + # some cpp op implementations don't set the dtype + if backend == "cpp" and isinstance(v, CSEVariable) and v.dtype is None: + v.dtype = var_dtype +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) csevar = V.kernel.cse.generate( V.kernel.compute, v, bounds=bounds, dtype=output_dtype, +<<<<<<< HEAD shape=output_shape, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) csevar.update_on_args(name, args, kwargs) @@ -2634,6 +2772,7 @@ def indirect_indexing( pos = var.bounds & ValueRanges(0, int_oo) new_bounds = new_bounds | pos +<<<<<<< HEAD var = self.kernel.cse.generate( self.kernel.compute, stm, @@ -2641,6 +2780,9 @@ def indirect_indexing( dtype=var.dtype, shape=var.shape, ) +======= + var = self.kernel.cse.generate(self.kernel.compute, stm, bounds=new_bounds) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sympy_var = self.parent_handler.indirect_indexing(var, size, check) if generate_assert(check): diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 9d36e24d5f9e5..56833afe656c5 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -24,7 +24,10 @@ from ..._dynamo.utils import counters from .. import config, cpp_builder, cpu_vec_isa, ir, metrics +<<<<<<< HEAD from ..debug import set_kernel_post_grad_provenance_tracing +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from ..loop_body import LoopBody from ..scheduler import ( BaseSchedulerNode, @@ -44,6 +47,10 @@ is_welford_reduction, parallel_num_threads, Placeholder, +<<<<<<< HEAD +======= + set_kernel_post_grad_provenance_tracing, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sympy_index_symbol, sympy_index_symbol_with_prefix, sympy_product, @@ -216,17 +223,25 @@ def reduction_combine( reduction_type, var, next_value, +<<<<<<< HEAD helper_val=None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) index: Optional[sympy.Symbol] = None, src_dtype=None, ): is_bool = src_dtype == torch.bool if reduction_type == "sum": +<<<<<<< HEAD if helper_val: return f"cascade_sum_combine({next_value}, &{helper_val})" else: conjunction = "|" if is_bool else "+" return f"{var} {conjunction} {next_value}" +======= + conjunction = "|" if is_bool else "+" + return f"{var} {conjunction} {next_value}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if reduction_type == "prod": return f"{var} * {next_value}" if reduction_type == "xor_sum": @@ -366,6 +381,7 @@ def replace_acc_name(buffer: IndentedBuffer, name: str, new_name: str): buffer._lines[i] = re.sub(r"\b" + f"{name}" + r"\b", f"{new_name}", line) +<<<<<<< HEAD def replace_cascade_sum_with_add(buffer: IndentedBuffer): """ Replaces `acc = cascade_sum_combine(value, ...)` with `acc = acc + value;` @@ -391,6 +407,8 @@ def replace_cascade_sum_with_add(buffer: IndentedBuffer): buffer._lines[i] = new_content +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @functools.lru_cache def stride_at(index: sympy.Expr, var: sympy.Symbol): if not index.has(var): @@ -933,8 +951,13 @@ def frexp(x): return tuple(V.kernel.cse.try_get(cache_key) for cache_key in cache_keys) code = BracesBuffer() +<<<<<<< HEAD exponent = V.kernel.cse.newvar(dtype=torch.int32, shape=x.shape) mantissa = V.kernel.cse.newvar(dtype=x.dtype, shape=x.shape) +======= + exponent = V.kernel.cse.newvar(dtype=torch.int32) + mantissa = V.kernel.cse.newvar(dtype=x.dtype) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) code.writeline(f"int32_t {exponent};") code.writeline(f"auto {mantissa} = std::frexp({x}, &{exponent});") V.kernel.compute.splice(code) @@ -1119,10 +1142,13 @@ def sign(x): code.writeline("()") return code +<<<<<<< HEAD @staticmethod def device_assert_async(cond, msg): return f'({cond} ? 0 : (throw std::runtime_error("{msg}"), 0))' +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CppOverrides._initialize_pointwise_overrides("cpp") @@ -1907,6 +1933,7 @@ def index_expr(expr, dtype): class CppKernel(Kernel): +<<<<<<< HEAD """ Base class for C++ kernel code generation in PyTorch Inductor. This class is responsible for generating C++ code from the intermediate representation. @@ -1916,6 +1943,8 @@ class CppKernel(Kernel): num_threads: Number of threads for parallel execution """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) overrides = CppOverrides # type: ignore[assignment] sexpr = cexpr newvar_prefix = "auto " @@ -1953,9 +1982,12 @@ def __init__(self, args, num_threads): self.welford_helper_cse = CSE( self.newvar_prefix, self.suffix, name_prefix="welford_helper" ) +<<<<<<< HEAD self.cascade_helper_cse = CSE( self.newvar_prefix, self.suffix, name_prefix="cascade_helper" ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.preloads = IndentedBuffer() self.poststores = IndentedBuffer() self.num_threads = num_threads # num_threads the kernel specialized for @@ -2171,6 +2203,7 @@ def finalize_reduction_prefix(self, size: Optional[int] = None): for gen_fn in self.reduction_prefix_generators: self.reduction_prefix.splice(gen_fn(size)) +<<<<<<< HEAD def need_use_acc_helper(self, reduction_type, dtype, use_scalar): # Check if we need accumulate helper for the reduction operation. # using accumulate helper generates the necessary code to improve precision for @@ -2288,6 +2321,8 @@ def _use_acc_helper( f"{result}_local = cascade_sum_final(&{helper_val});" ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def reduction(self, dtype, src_dtype, reduction_type, value): argmax_or_argmin = reduction_type in ("argmax", "argmin") reduction_key = src_dtype, reduction_type, value @@ -2306,6 +2341,7 @@ def reduction(self, dtype, src_dtype, reduction_type, value): acc, acc_type, reduction_type, init_dtype, reduction_init ) ) +<<<<<<< HEAD if self.need_use_acc_helper(reduction_type, dtype, True): # use cascade_helper for vec kernel @@ -2336,6 +2372,15 @@ def reduction(self, dtype, src_dtype, reduction_type, value): self.stores.writeline( f"{acc} = {reduction_combine(reduction_type, acc, value, index=index)};" ) +======= + assert self.reduction_depth is not None + index = self.itervars[self.reduction_depth] + for i in range(self.reduction_depth + 1, len(self.itervars)): + index = index * self.ranges[i] + self.itervars[i] + self.stores.writeline( + f"{acc} = {reduction_combine(reduction_type, acc, value, index)};" + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._gen_parallel_reduction_buffers(acc, acc_type, reduction_type, init_dtype) result = reduction_project(reduction_type, acc) @@ -2972,6 +3017,7 @@ def store(self, name, index, value, mode=None): raise NotImplementedError(f"store mode={mode}") def reduction(self, dtype, src_dtype, reduction_type, value): +<<<<<<< HEAD """ Perform vectorized reduction operation. @@ -2988,6 +3034,8 @@ def reduction(self, dtype, src_dtype, reduction_type, value): Returns: The result of the reduction operation """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Note: For argmax and argmin on bool type, we always convert bool to float. # Fix issue: https://github.com/pytorch/pytorch/issues/143568 assert reduction_type in VECTORIZABLE_RTYPES @@ -3013,7 +3061,10 @@ def reduction(self, dtype, src_dtype, reduction_type, value): ) assert isinstance(acc, CppCSEVariable) acc_vec = f"{acc}_vec" +<<<<<<< HEAD masked_acc = f"masked_{acc}" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) masked_acc_vec = f"masked_{acc_vec}" self.reduction_var_names += [f"{acc}", acc_vec, masked_acc_vec] self.is_reduction = True @@ -3031,9 +3082,13 @@ def reduction(self, dtype, src_dtype, reduction_type, value): self.reduction_init_vec, ) ) +<<<<<<< HEAD use_acc_helper = self.need_use_acc_helper(reduction_type, dtype, False) if use_acc_helper: +======= + if reduction_type == "welford_reduce": +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # use masked acc_vec for tail vec kernel self.reduction_prefix_generators.append( self._gen_reduction_prefix( @@ -3045,11 +3100,16 @@ def reduction(self, dtype, src_dtype, reduction_type, value): ) ) +<<<<<<< HEAD # use welford_helper/cascade_helper for vec kernel +======= + # use welford_helper for vec kernel +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert self.reduction_depth is not None reduction_size = functools.reduce( operator.mul, self.ranges[self.reduction_depth :] ) +<<<<<<< HEAD if reduction_type == "welford_reduce": helper_val = self.welford_helper_cse.generate( self.compute, f"reduction {reduction_key}", write=False @@ -3060,6 +3120,13 @@ def reduction(self, dtype, src_dtype, reduction_type, value): ) masked_helper_val = f"masked_{helper_val}" helper_vec_range = ( +======= + welford_helper_val = self.welford_helper_cse.generate( + self.compute, f"reduction {reduction_key}", write=False + ) + masked_welford_helper_val = f"masked_{welford_helper_val}" + welford_helper_vec_range = ( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ( FloorDiv(reduction_size, self.ranges[self.tiling_idx]) * FloorDiv(self.ranges[self.tiling_idx], self.tiling_factor) @@ -3069,7 +3136,11 @@ def reduction(self, dtype, src_dtype, reduction_type, value): if FloorDiv(self.ranges[self.tiling_idx], self.tiling_factor) else sympy.Integer(0) ) +<<<<<<< HEAD masked_helper_vec_range = ( +======= + masked_welford_helper_vec_range = ( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ( FloorDiv(reduction_size, self.ranges[self.tiling_idx]) if self.tiling_idx >= self.reduction_depth @@ -3078,6 +3149,7 @@ def reduction(self, dtype, src_dtype, reduction_type, value): if self.ranges[self.tiling_idx] % self.tiling_factor else sympy.Integer(0) ) +<<<<<<< HEAD # scalar helper for scalar sum is also needed when vec kernel is included # Note: is it different from welford reduction as welford reduction of scalar version # does not need helper, and the helper needs the information of reduction size to initialize @@ -3099,11 +3171,21 @@ def reduction(self, dtype, src_dtype, reduction_type, value): masked_acc, masked_helper_val, masked_helper_vec_range, +======= + self._use_welford_helper( + acc_vec, welford_helper_val, welford_helper_vec_range, dtype + ) + self._use_welford_helper( + masked_acc_vec, + masked_welford_helper_val, + masked_welford_helper_vec_range, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dtype, ) # use masked acc_vec for tail vec kernel acc_vec_ = masked_acc_vec if self.tail_size else acc_vec +<<<<<<< HEAD helper_val_ = masked_helper_val if self.tail_size else helper_val if reduction_type == "sum": self.stores.writeline( @@ -3113,6 +3195,14 @@ def reduction(self, dtype, src_dtype, reduction_type, value): self.stores.writeline( f"{acc_vec_} = {self.reduction_combine_vec(reduction_type, acc_vec_, value, helper_val_)};" ) +======= + welford_helper_val_ = ( + masked_welford_helper_val if self.tail_size else welford_helper_val + ) + self.stores.writeline( + f"{acc_vec_} = {self.reduction_combine_vec(reduction_type, acc_vec_, value, welford_helper_val_)};" + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: assert self.reduction_depth is not None index = self.itervars[self.reduction_depth] @@ -3143,7 +3233,11 @@ def reduction(self, dtype, src_dtype, reduction_type, value): reduction_combine_fn=reduction_combine, reduction_init_fn=reduction_init, ) +<<<<<<< HEAD if use_acc_helper: +======= + if reduction_type == "welford_reduce": +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # use masked acc_vec for tail vec kernel self._gen_parallel_reduction_buffers( masked_acc_vec, @@ -3190,11 +3284,15 @@ def reduction(self, dtype, src_dtype, reduction_type, value): vec_dtype = torch.float if is_bool else dtype vec = f"at::vec::Vectorized<{DTYPE_TO_CPP[vec_dtype]}>" vec_reduce_all_func = f"at::vec::vec_reduce_all<{DTYPE_TO_CPP[vec_dtype]}, {self._get_num_vectors(vec_dtype)}>" +<<<<<<< HEAD result_vec = f"{acc_vec}" if use_acc_helper: assert reduction_type == "sum" result_vec = f"{acc_vec} + {masked_acc_vec}" next_value = f"{vec_reduce_all_func}([]({vec}& x, {vec}& y) {reduce_all_body}, {result_vec})" +======= + next_value = f"{vec_reduce_all_func}([]({vec}& x, {vec}& y) {reduce_all_body}, {acc_vec})" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.reduction_suffix.writeline( f"{acc} = {reduction_combine(reduction_type, acc, next_value, src_dtype=src_dtype)};" @@ -3207,12 +3305,15 @@ def reduction(self, dtype, src_dtype, reduction_type, value): self.reduction_suffix.writeline( f"{tmpvar} = {reduction_combine(reduction_type, tmpvar, masked_tmpvar)};" ) +<<<<<<< HEAD elif use_acc_helper: assert reduction_type == "sum" masked_tmpvar = f"masked_{tmpvar}" self.reduction_suffix.writeline( f"{tmpvar} = {tmpvar} + {masked_tmpvar};" ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) result = reduction_project(reduction_type, tmpvar) self.reduction_cse.reduction_cache[reduction_key] = result @@ -3222,10 +3323,18 @@ def store_reduction(self, name, index, value): index = self.rename_indexing(index) var = self.args.output(name) out_dtype = V.graph.get_dtype(name) +<<<<<<< HEAD if out_dtype.is_floating_point and out_dtype != torch.double: dtype = torch.float else: dtype = out_dtype +======= + dtype = ( + (out_dtype if out_dtype == torch.double else torch.float) + if out_dtype.is_floating_point + else torch.int64 + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out_num_vectors = V.kernel._get_num_vectors(out_dtype) src_num_vectors = V.kernel._get_num_vectors(dtype) code = IndentedBuffer() @@ -3337,12 +3446,66 @@ def reduction_acc_type_vec(self, reduction_type, dtype): return f"{self._get_mask_type()}" return vec_type +<<<<<<< HEAD +======= + def _welford_helper_init( + self, welford_helper_val, welford_helper_vec_range, dtype, num_threads=None + ): + vec_num_range_thread = ( + CeilDiv(welford_helper_vec_range, num_threads) + if num_threads + else welford_helper_vec_range + ) + vec_num_range_thread_expr = cexpr_index(vec_num_range_thread) + chunk_size = 4096 + num_chunks = CeilDiv(vec_num_range_thread, chunk_size) + welford_helper_init_line = ( + f"WelfordHelper<{self._get_vec_type(dtype)}, {chunk_size}> {welford_helper_val}" + f"(" + f"{vec_num_range_thread_expr}" + f");" + ) + if isinstance(num_chunks, sympy.Integer) and num_chunks <= 1: + # When the number of chunks <= 1, there is no need to use cascade summation to improve + # reduction accuracy. We can initialize a static WelfordHelper to improve performance. + return f"static {welford_helper_init_line}" + else: + return welford_helper_init_line + + def _use_welford_helper( + self, acc_vec, welford_helper_val, welford_helper_vec_range, dtype + ): + num_threads = ( + "max_threads" if config.cpp.dynamic_threads else parallel_num_threads() + ) + self.non_parallel_reduction_prefix.writeline( + self._welford_helper_init( + welford_helper_val, welford_helper_vec_range, dtype + ) + ) + self.local_reduction_init.writeline( + self._welford_helper_init( + welford_helper_val, welford_helper_vec_range, dtype, num_threads + ) + ) + self.non_parallel_reduction_suffix.writeline( + f"{acc_vec} = welford_combine({acc_vec}, &{welford_helper_val});" + ) + self.local_reduction_stores.writeline( + f"{acc_vec}_local = welford_combine({acc_vec}_local, &{welford_helper_val});" + ) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def reduction_combine_vec( self, reduction_type, var, next_value, +<<<<<<< HEAD helper_val=None, +======= + welford_helper_val=None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) index: Optional[sympy.Symbol] = None, horizontal_reduction: Optional[bool] = None, src_dtype: Optional[torch.dtype] = torch.float32, @@ -3367,6 +3530,7 @@ def reduction_combine_vec( else f"at::vec::minimum({var}, {next_value})" ) elif reduction_type == "sum": +<<<<<<< HEAD if helper_val: if self.tail_size: return f"cascade_sum_combine({next_value}, {cexpr_index(self.tail_size)}, &{helper_val})" @@ -3378,6 +3542,13 @@ def reduction_combine_vec( else: conjunction = "|" if is_bool else "+" return f"{var} {conjunction} {next_value}" +======= + if self.tail_size: + return f"sum_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" + else: + conjunction = "|" if is_bool else "+" + return f"{var} {conjunction} {next_value}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif reduction_type == "prod": if self.tail_size: return f"prod_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" @@ -3389,11 +3560,21 @@ def reduction_combine_vec( else: return f"{var} ^ {next_value}" elif reduction_type == "welford_reduce": +<<<<<<< HEAD if helper_val: if self.tail_size: return f"welford_combine({var}, {next_value}, {cexpr_index(self.tail_size)}, &{helper_val})" else: return f"welford_combine({var}, {next_value}, &{helper_val})" +======= + if welford_helper_val: + if self.tail_size: + return f"welford_combine({var}, {next_value}, {cexpr_index(self.tail_size)}, &{welford_helper_val})" + else: + return ( + f"welford_combine({var}, {next_value}, &{welford_helper_val})" + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: if self.tail_size: return f"welford_combine({var}, {next_value}, {cexpr_index(self.tail_size)})" @@ -3927,7 +4108,11 @@ def _is_valid_indices( call_ranges[tiling_indice], fallback=0 ) if call_range < factor_lowp: +<<<<<<< HEAD V.graph.sizevars.check_lt(call_range, factor_lowp) # type: ignore[arg-type] +======= + V.graph.sizevars.guard_lt(call_range, factor_lowp) # type: ignore[arg-type] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tiling_factor = factor_lowp // 2 break elif call_ranges[tiling_indice] < factor_lowp: @@ -4522,12 +4707,18 @@ def gen_body(self, code: Optional[BracesBuffer] = None): def aggregate_reduction_buffers( self, inner_loop_reduction_outer_not: bool, outer_loop: Optional["LoopLevel"] ): +<<<<<<< HEAD """ CppKernel/CppVecKernel/CppTile2dKernel have reduction buffers themselves. Here, we decide how to aggregate them together and place new reduction buffers under CppKernelProxy. """ +======= + # CppKernel/CppVecKernel/CppTile2dKernel have reduction buffers themselves. + # Here, we decide how to aggregate them together and place new reduction buffers + # under CppKernelProxy. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def aggregate_reduction_prefix_suffix(outer_loop: "LoopLevel"): assert len(self.kernels) >= 2 main_loop_kernel = self.kernels[0] @@ -4571,9 +4762,12 @@ def aggregate_reduction_prefix_suffix(outer_loop: "LoopLevel"): replace_acc_name( tail_loop_kernel.reduction_suffix, name, new_name ) +<<<<<<< HEAD # If tail loop kernel is a scalar kernel, use direct sum instead of cascade_sum_combine # as the reduction vars are extended: tmp_acc -> tmp_acc_arr[]. replace_cascade_sum_with_add(tail_loop_kernel.stores) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) suffix_buf.splice( move_code_under_inner_loop( tail_loop_kernel.reduction_suffix, @@ -4862,7 +5056,11 @@ def can_fuse_multi_outputs_template( isinstance(template_buf.layout, ir.MultiOutputLayout) and isinstance(node2.node, ir.MultiOutput) and len(node2.node.inputs) == 1 +<<<<<<< HEAD and node2.node.inputs[0].get_name() == template_buf.name # type: ignore[union-attr] +======= + and node2.node.inputs[0].get_name() == template_buf.name +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return False @@ -5162,6 +5360,7 @@ def is_contiguous_index(x): ): continue # Local Buffer is a view of global buffer +<<<<<<< HEAD local_buffer_stride: list[int] = [] stride = global_buffer_layout.stride[-1] local_buffer_size = get_call_ranges(scheduler_node)[ @@ -5175,6 +5374,13 @@ def is_contiguous_index(x): global_buffer_layout.dtype, local_buffer_size, local_buffer_stride, +======= + local_buffer_layout = ir.FixedLayout( + global_buffer_layout.device, + global_buffer_layout.dtype, + global_buffer_layout.size[size_offset:], + global_buffer_layout.stride[size_offset:], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def try_share_local_buffer(local_buffer_layout, local_buffers): @@ -5344,7 +5550,11 @@ def template_buffer_has_other_users( flag_template_buffer_has_other_users = template_buffer_has_other_users( ctb, template_node.outputs_by_name, epilogue_ir_nodes ) +<<<<<<< HEAD kernel, render = ctb.make_kernel_render( # type: ignore[misc] +======= + kernel, render = ctb.make_kernel_render( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ctb, flag_template_buffer_has_other_users=flag_template_buffer_has_other_users, epilogue_nodes=epilogue_ir_nodes, @@ -5396,6 +5606,13 @@ def define_kernel(self, src_code, nodes, kernel_args=None): else "" ) kernel_name = "_".join(["cpp", fused_name, wrapper.next_kernel_suffix()]) +<<<<<<< HEAD +======= + # below add provenance tracing info for cpu CppKernel types + if config.trace.enabled: + set_kernel_post_grad_provenance_tracing(nodes, kernel_name) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel" src_code = src_code.replace(str(Placeholder.KERNEL_NAME), kernel_decl_name) src_code = src_code.replace(str(Placeholder.DESCRIPTIVE_NAME), kernel_name) @@ -5434,6 +5651,7 @@ def flush(self): kernel_name = self.define_kernel( src_code, self.kernel_group.scheduled_nodes ) +<<<<<<< HEAD # below add provenance tracing info for cpu CppKernel types debug_handle: Optional[int] = None if config.trace.provenance_tracking_level != 0: @@ -5443,6 +5661,9 @@ def flush(self): self.kernel_group.call_kernel( V.graph.wrapper_code, kernel_name, debug_handle=debug_handle ) +======= + self.kernel_group.call_kernel(V.graph.wrapper_code, kernel_name) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.reset_kernel_group() self._set_flush_status(False) @@ -5483,7 +5704,11 @@ def codegen_group(self, name=None) -> str: "win32", ] if enable_kernel_profile: +<<<<<<< HEAD code.writelines(["#include "]) +======= + code.writelines(["#include "]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) code.writeline("#include ") # 2. Function definition @@ -5492,11 +5717,16 @@ def codegen_group(self, name=None) -> str: arg_defs, _, _ = self.args.cpp_argdefs() arg_defs = ",\n".ljust(25).join(arg_defs) func_export_decl = get_export_declaration() +<<<<<<< HEAD inline_attr = ( "C10_ALWAYS_INLINE_ATTRIBUTE" if config.cpp.force_inline_kernel else "" ) code.writeline( f'extern "C" {func_export_decl} void {inline_attr} {kernel_decl_name}({arg_defs})' +======= + code.writeline( + f'extern "C" {func_export_decl} void {kernel_decl_name}({arg_defs})' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # 3. Function body @@ -5506,10 +5736,14 @@ def codegen_group(self, name=None) -> str: prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else "" code.writelines( [ +<<<<<<< HEAD ( "torch::aot_inductor::RAIIAtenRecordFunctionHandle " f'record_{prefix + kernel_name}_("{prefix + kernel_name}", nullptr);' ) +======= + f'RECORD_FUNCTION("{prefix + kernel_name}", c10::ArrayRef({{}}));' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] ) for old, new in self.args.aliases(): @@ -5517,6 +5751,7 @@ def codegen_group(self, name=None) -> str: code.splice(self.loops_code) return code.getvalue() +<<<<<<< HEAD def call_kernel(self, wrapper, kernel_name, debug_handle: Optional[int] = None): _, call_args, arg_types = self.args.cpp_argdefs() wrapper.generate_kernel_call( @@ -5525,6 +5760,12 @@ def call_kernel(self, wrapper, kernel_name, debug_handle: Optional[int] = None): triton=False, arg_types=arg_types, debug_handle=debug_handle, +======= + def call_kernel(self, wrapper, kernel_name): + _, call_args, arg_types = self.args.cpp_argdefs() + wrapper.generate_kernel_call( + kernel_name, call_args, triton=False, arg_types=arg_types +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @@ -5723,6 +5964,7 @@ def get_simd_vec_depth(loops): simd_vec_depth = get_simd_vec_depth(self.loops) +<<<<<<< HEAD def has_scalar_kernel(loop_nest: LoopNest): assert isinstance(loop_nest.kernel, CppKernelProxy) return any( @@ -5730,6 +5972,8 @@ def has_scalar_kernel(loop_nest: LoopNest): for kernel in loop_nest.kernel.kernels ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # When the number of steps of the first inner loop is much larger than the number of steps of # all outer loops, change `start_depth` to the first inner loop and recalculate `max_depth`. if ( @@ -5743,7 +5987,10 @@ def has_scalar_kernel(loop_nest: LoopNest): simd_vec_depth is not None and max_depth > simd_vec_depth and self.loops[max_depth].is_reduction +<<<<<<< HEAD and has_scalar_kernel(self) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) ): start_depth = max_depth diff --git a/torch/_inductor/codegen/cpp_flex_attention_template.py b/torch/_inductor/codegen/cpp_flex_attention_template.py index a1ceecf7f7c9e..ab48d89aff4ab 100644 --- a/torch/_inductor/codegen/cpp_flex_attention_template.py +++ b/torch/_inductor/codegen/cpp_flex_attention_template.py @@ -792,7 +792,11 @@ def get_arg_name(name): return "" if start_offset == -1: +<<<<<<< HEAD start_offset = self.len_score_other +======= + start_offset = getattr(self, len_attr) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) length = getattr(self, len_attr) for i in range(length): @@ -814,7 +818,11 @@ def modification(self, subgraph_buffer, output_name, output_idx): from ..loop_body import LoopBody from ..utils import sympy_index_symbol_with_prefix, SymT from ..virtualized import V +<<<<<<< HEAD from .cpp import CppKernelProxy, KernelGroup, ParallelDepth +======= + from .cpp import CppKernelProxy, KernelGroup +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kernel_group = KernelGroup() kernel_input_args = { @@ -883,6 +891,7 @@ def fn(*args): var_sizes_list.append((var_sizes, ())) cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list) +<<<<<<< HEAD def max_parallel_depth(): return ParallelDepth(parallel_depth=0, start_depth=0) @@ -892,6 +901,9 @@ def max_parallel_depth(): cpp_kernel_proxy.loop_nest, "max_parallel_depth", max_parallel_depth ): kernel_group.finalize_kernel(cpp_kernel_proxy, []) +======= + kernel_group.finalize_kernel(cpp_kernel_proxy, []) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) output_code = kernel_group.loops_code.getvalue() var_q_symbol, var_kv_symbol = self.block_vars @@ -985,8 +997,12 @@ def render( # type: ignore[override,return] self.input_dtype = query.layout.dtype num_threads = parallel_num_threads() +<<<<<<< HEAD assert isinstance(self.output_node, ir.IRNode) buf_out: ir.IRNode = TensorBox.create(self.output_node) +======= + buf_out = TensorBox.create(self.output_node) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if template_buffer_node is not None: buf_out = template_buffer_node options = dict( @@ -995,9 +1011,15 @@ def render( # type: ignore[override,return] value=value, kv_num_blocks=self.input_nodes[3], kv_indices=self.input_nodes[4], +<<<<<<< HEAD full_kv_num_blocks=( self.input_nodes[5] if not self.no_full_kv_block else None ), +======= + full_kv_num_blocks=self.input_nodes[5] + if not self.no_full_kv_block + else None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) full_kv_indices=self.input_nodes[6] if not self.no_full_kv_block else None, score_mod_other_buffers=self.score_mod_other_buffers, mask_mod_other_buffers=self.mask_mod_other_buffers, diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index bfcebbd6a3810..11ce5951e83b3 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -917,6 +917,12 @@ def add_choices( if input_indices is None: input_indices = list(range(len(input_nodes))) +<<<<<<< HEAD +======= + only_one_input = ( + input_nodes[0] == input_nodes[1] if len(input_nodes) > 1 else False + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def reorder_and_filter(inputs, layout_or_out): if has_bias: @@ -1016,9 +1022,12 @@ def normalize_shapes(inputs, layout_or_out): assert micro_gemm is not None pre_block_weights = cls.check_if_block_weight(new_inputs[1], micro_gemm) micro_gemm.use_local_vnni_blocking(not pre_block_weights) +<<<<<<< HEAD only_one_input = ( input_nodes[0] == input_nodes[1] if len(input_nodes) > 1 else False ) and not pre_block_weights # If weights are blocked, use the second input +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def preprocessor(inputs, layout): new_inputs, new_layout = normalize_shapes( @@ -1094,6 +1103,7 @@ def get_padded_size(n, block_n, k, should_block_weight): new_size = [padded_n // block_n, k, block_n] return new_size, padded_n +<<<<<<< HEAD @staticmethod def _maybe_remove_storage_offset(node: ir.IRNode): if node.get_layout().offset == 0: @@ -1106,6 +1116,8 @@ def _maybe_remove_storage_offset(node: ir.IRNode): # W.data_ptr[...] return ir.ExternKernel.copy_input(node) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @classmethod def prep_weight( cls, @@ -1161,7 +1173,10 @@ def prep_weight( elif isinstance(W, ir.IRNode): # Require W layout to be fixed & contiguous, happens inplace. ir.ExternKernel.require_contiguous(W) +<<<<<<< HEAD new_inputs[1] = cls._maybe_remove_storage_offset(W) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not skip_int8_compensation and _is_int8_gemm(new_inputs): BCompensate = None @@ -1240,7 +1255,11 @@ def block_weight(cls, W, new_size, padding): permute_size[-2], permute_size[-3] = permute_size[-3], permute_size[-2] blocked_w = L.constant_pad_nd(W, (0, padding)) blocked_w = L.permute( +<<<<<<< HEAD L.view(blocked_w, permute_size), # type: ignore[arg-type] +======= + L.view(blocked_w, permute_size), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) permute_dims, ) else: diff --git a/torch/_inductor/codegen/cpp_micro_gemm.py b/torch/_inductor/codegen/cpp_micro_gemm.py index d6b8806bdd910..7c65f4b14c611 100644 --- a/torch/_inductor/codegen/cpp_micro_gemm.py +++ b/torch/_inductor/codegen/cpp_micro_gemm.py @@ -195,7 +195,11 @@ def get_b_layout(self) -> LayoutType: ALLOCATE_WEIGHT_BUFFER = r""" {%- if is_msvc_compiler %} // MSVC doesn't support stack-allocated dynamic-sized arrays, so using heap memory here. +<<<<<<< HEAD auto heap_deq_b_buf_ptr = std::make_unique<{{buffer_dtype}}[]>({{buffer_size}}); +======= + std::unique_ptr<{{buffer_dtype}}[]> heap_deq_b_buf_ptr(new {{buffer_dtype}}[{{buffer_size}}]); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {{buffer_dtype}}* {{buffer_name}} = heap_deq_b_buf_ptr.get(); {%- else %} // It's safe to use a stack-allocated array since the blocking strategy would @@ -211,12 +215,21 @@ def codegen_allocate_weight_buffer( ) -> str: buffer_size = " * ".join(map(str, size_args)) return KernelTemplate._template_from_string(self.ALLOCATE_WEIGHT_BUFFER).render( +<<<<<<< HEAD { "buffer_name": buffer_name, "buffer_dtype": buffer_dtype, "buffer_size": buffer_size, "is_msvc_compiler": cpp_builder.is_msvc_cl(), } +======= + dict( + buffer_name=buffer_name, + buffer_dtype=buffer_dtype, + buffer_size=buffer_size, + is_msvc_compiler=cpp_builder.is_msvc_cl(), + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def is_woq_int4(self): @@ -963,6 +976,7 @@ def check_amx_extra(config, m, n, k, alpha, num_threads, **kwargs): return k % vnni_size == 0 and alpha == 1 +<<<<<<< HEAD def check_int8_bf16_amx_extra(config, m, n, k, alpha, num_threads, **kwargs): # We need avx512_bf16 to dequant int8 to bf16 vec_isa = kwargs.get("vec_isa", None) @@ -981,6 +995,8 @@ def check_amx_fp16_extra(config, m, n, k, alpha, num_threads, **kwargs): return vec_isa.is_amx_fp16_supported() and k % vnni_size == 0 and alpha == 1 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register_micro_gemm( *generate_gemm_config( VecAMX, @@ -993,11 +1009,16 @@ def check_amx_fp16_extra(config, m, n, k, alpha, num_threads, **kwargs): ), *generate_gemm_config( VecAMX, +<<<<<<< HEAD [(32, 32, 32), (48, 16, 32)], +======= + [(32, 32, 32), (48, 16, 32), (16, 48, 32)], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) input_dtype=torch.bfloat16, input2_dtype=torch.int8, output_dtype=torch.float, compute_dtype=torch.float, +<<<<<<< HEAD extra_check=check_int8_bf16_amx_extra, ), *generate_gemm_config( @@ -1005,14 +1026,22 @@ def check_amx_fp16_extra(config, m, n, k, alpha, num_threads, **kwargs): [(32, 16, 32), (32, 32, 32), (48, 16, 32), (16, 48, 32)], input_dtype=torch.bfloat16, output_dtype=torch.float, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) extra_check=check_amx_extra, ), *generate_gemm_config( VecAMX, [(32, 32, 32), (48, 16, 32), (16, 48, 32)], +<<<<<<< HEAD input_dtype=torch.float16, output_dtype=torch.float, extra_check=check_amx_fp16_extra, +======= + input_dtype=torch.bfloat16, + output_dtype=torch.float, + extra_check=check_amx_extra, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), *generate_gemm_config( VecAMX, @@ -1050,6 +1079,7 @@ class CppMicroGemmAMX(CppMicroGemm): for (int idx_dq = 0, idx_q = 0; idx_dq < buf_size; idx_q += ldb, idx_dq += {{block_n}}) { {%- for vec_idx in range(0, block_n, 32) %} {%- if (block_n - vec_idx) >= 32 %} +<<<<<<< HEAD // 1) Load 32 x int8 __m256i v8 = _mm256_loadu_si256((const __m256i*)(base_addr + idx_q + {{vec_idx}})); // 2) Widen: 32 x i8 -> 32 x i16 @@ -1082,6 +1112,14 @@ class CppMicroGemmAMX(CppMicroGemm): __m256i bf16 = (__m256i)_mm512_cvtneps_pbh(f32); // 6) Store 16 x bf16 (256 bits) _mm256_storeu_si256((__m256i*)(dequantized_B_buf + idx_dq + {{vec_idx}}), bf16); +======= + auto b_int8_idx_{{vec_idx}} = at::vec::Vectorized::loadu( + base_addr + idx_q + {{vec_idx}} , + static_cast(32) + ); + auto b_bf16_idx_{{vec_idx}} = at::vec::convert<{{input_t}}>(b_int8_idx_{{vec_idx}}); + b_bf16_idx_{{vec_idx}}.store(dequantized_B_buf + idx_dq + {{vec_idx}}); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {%- else %} auto b_int8_tail = at::vec::Vectorized::loadu( base_addr + idx_q + {{block_n - (block_n % 32)}}, @@ -1238,11 +1276,15 @@ class CppMicroGemmAMX(CppMicroGemm): _tile_dpbusd({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}}); {%- endif %} {%- else %} +<<<<<<< HEAD {%- if input_dtype == torch.float16 %} _tile_dpfp16ps({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}}); {%- else %} _tile_dpbf16ps({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}}); {%- endif %} +======= + _tile_dpbf16ps({{tile_idx_c}}, {{tile_idx_a}}, {{tile_idx_b}}); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {%- endif %} {%- endfor %} {%- endfor %} @@ -1417,7 +1459,11 @@ def check_woq_int4_extra(config, m, n, k, alpha, num_threads, **kwargs): q_group_size = kwargs.get("q_group_size", None) assert q_group_size is not None if ( +<<<<<<< HEAD q_group_size not in [32, 64, 128] +======= + q_group_size < 32 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) or k % q_group_size != 0 or config.register_blocking.block_k > q_group_size ): @@ -1563,7 +1609,13 @@ class CppMicroGemmWoQInt4Avx512(CppMicroGemmFP32Vec): auto load_scale_and_zeros = [&](int i, int _kb) { // load 2x bfloat16 vector __m512i t = _mm512_loadu_si512((__m512i*)(ScaleAndZeros + _kb * lds + 32 * i)); +<<<<<<< HEAD _mm_prefetch(ScaleAndZeros + (_kb + PREFETCH_SIZE_KB) * lds + 32 * i, _MM_HINT_T0); +======= + if (_kb + PREFETCH_SIZE_KB < KB) { + _mm_prefetch(ScaleAndZeros + (_kb + PREFETCH_SIZE_KB) * lds + 32 * i, _MM_HINT_T0); + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // convert to 2x f32 vector __m512 a, b; @@ -1597,7 +1649,13 @@ class CppMicroGemmWoQInt4Avx512(CppMicroGemmFP32Vec): if constexpr (col == 0) { float aa = static_cast(A[row * lda + k]); +<<<<<<< HEAD _mm_prefetch(A + row * lda + k + PREFETCH_SIZE_K, _MM_HINT_T0); +======= + if (k + PREFETCH_SIZE_K < K) { + _mm_prefetch(A + row * lda + k + PREFETCH_SIZE_K, _MM_HINT_T0); + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) va = _mm512_set1_ps(aa); } @@ -1607,7 +1665,13 @@ class CppMicroGemmWoQInt4Avx512(CppMicroGemmFP32Vec): // to reduce de-quantize overhead. if constexpr (col == 0) { __m256i b4 = _mm256_loadu_si256((__m256i*)(B + k * ldb)); +<<<<<<< HEAD _mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb, _MM_HINT_T0); +======= + if (k + PREFETCH_SIZE_K < K) { + _mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb, _MM_HINT_T0); + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __m512i b32 = _mm512_cvtepu8_epi32(_mm256_castsi256_si128(b4)); vb[0] = _mm512_permutexvar_ps(b32, lut); @@ -1699,8 +1763,12 @@ class CppMicroGemmWoQInt4Amx(CppMicroGemmAMX): TEMPLATE_ENTRY = r""" inline bool {{kernel_name}}_is_block_start(int index, int k_start, int group_size) { +<<<<<<< HEAD // check if (k_start + index) % group_size == 0, assuming group_size = 32/64/128 return ((k_start + index) & (group_size - 1)) == 0; +======= + return (k_start + index) % group_size == 0; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } {{declare_kernel}} { @@ -1784,7 +1852,13 @@ class CppMicroGemmWoQInt4Amx(CppMicroGemmAMX): auto load_scale_and_zeros = [&](int i, int _kb) { // load 2x bfloat16 vector __m512i t = _mm512_loadu_si512((__m512i*)(ScaleAndZeros + _kb * lds + 32 * i)); +<<<<<<< HEAD _mm_prefetch(ScaleAndZeros + (_kb + PREFETCH_SIZE_KB) * lds + 32 * i, _MM_HINT_T0); +======= + if (_kb + PREFETCH_SIZE_KB < KB) { + _mm_prefetch(ScaleAndZeros + (_kb + PREFETCH_SIZE_KB) * lds + 32 * i, _MM_HINT_T0); + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // convert to 2x f32 vector __m512 a, b; @@ -1813,9 +1887,17 @@ class CppMicroGemmWoQInt4Amx(CppMicroGemmAMX): c10::ForcedUnroll{}(load_scale_and_zeros, kb++); } +<<<<<<< HEAD _mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb_int4, _MM_HINT_T0); // load 256 bits = 64 elements in int4 +======= + // load 256 bits = 64 elements in int4 + if (k + PREFETCH_SIZE_K < K) { + _mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb_int4, _MM_HINT_T0); + } + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __m128i b4 = _mm_loadu_si128((__m128i*)(B + n / 2 * K + k * ldb_int4)); b32[0] = _mm512_cvtepu8_epi32(b4); b32[1] = _mm512_srli_epi32(b32[0], 4); @@ -1824,8 +1906,13 @@ class CppMicroGemmWoQInt4Amx(CppMicroGemmAMX): vb[1] = _mm512_permutexvar_ps(b32[1], lut); vb[1] = _mm512_fmadd_ps(vb[1], scale[1], zero[1]); +<<<<<<< HEAD __m128i b4_2 = _mm_loadu_si128((__m128i*)(B + n / 2 * K + (k + 1) * ldb_int4)); b32[0 + COLS] = _mm512_cvtepu8_epi32(b4_2); +======= + b4 = _mm_loadu_si128((__m128i*)(B + n / 2 * K + (k + 1) * ldb_int4)); + b32[0 + COLS] = _mm512_cvtepu8_epi32(b4); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) b32[1 + COLS] = _mm512_srli_epi32(b32[0 + COLS], 4); vb[0 + COLS] = _mm512_permutexvar_ps(b32[0 + COLS] , lut); vb[0 + COLS] = _mm512_fmadd_ps(vb[0 + COLS], scale[0], zero[0]); @@ -1950,7 +2037,11 @@ def create_from_config(cls, config: CppMicroGemmConfig): alpha, ) +<<<<<<< HEAD def skip_amx_kernel_for_woq(dynamic_M): +======= + def skip_amx_kernel_for_woq(config, dynamic_M, micro_gemm_cls): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # For WoQ GEMM, AMX micro-kernel may not perform well if m is small. # Exception: for dynamic shapes, we consider using the AMX micro-kernel. if ( @@ -1959,7 +2050,15 @@ def skip_amx_kernel_for_woq(dynamic_M): or input2_dtype not in [torch.int8, torch.uint8] ): return False +<<<<<<< HEAD m_threshold = 5 +======= + # For WOQ INT8, use AMX for m >= block_m + # For WOQ INT4, use AMX for m >= 5 + block_m, *_ = config.register_blocking + is_woq_int4 = micro_gemm_cls == CppMicroGemmWoQInt4Amx + m_threshold = 5 if is_woq_int4 else block_m +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return m < m_threshold assert isinstance(n, int) or n.is_number, n @@ -2001,11 +2100,20 @@ def skip_amx_kernel_for_woq(dynamic_M): num_threads, dynamic_M=dynamic_M, q_group_size=q_group_size, +<<<<<<< HEAD vec_isa=vec_isa, ): continue block_m, block_n, block_k = config.register_blocking if config.vec_isa_cls == VecAMX and skip_amx_kernel_for_woq(dynamic_M): +======= + ): + continue + block_m, block_n, block_k = config.register_blocking + if config.vec_isa_cls == VecAMX and skip_amx_kernel_for_woq( + config, dynamic_M, cls + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue # Criteria on the ranking of configurations # 1. ISA: AMX > VEC @@ -2034,6 +2142,7 @@ def skip_amx_kernel_for_woq(dynamic_M): + (block_m * block_k + block_k * block_n) * config.input_dtype.itemsize ) +<<<<<<< HEAD size_score = register_bytes # if number of mxn blocks can not occupy all the threads, # we favor smaller register blocks. @@ -2042,6 +2151,11 @@ def skip_amx_kernel_for_woq(dynamic_M): matched_configs.append( ( (isa_score, dividable_score, occupancy_score, size_score), +======= + matched_configs.append( + ( + (isa_score, dividable_score, occupancy_score, register_bytes), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cls, config, ) diff --git a/torch/_inductor/codegen/cpp_template.py b/torch/_inductor/codegen/cpp_template.py index d72f13a3e3fac..2604395f593df 100644 --- a/torch/_inductor/codegen/cpp_template.py +++ b/torch/_inductor/codegen/cpp_template.py @@ -131,7 +131,11 @@ def header(self) -> IndentedBuffer: "win32", ] if enable_kernel_profile: +<<<<<<< HEAD res.writelines(["#include "]) +======= + res.writelines(["#include "]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return res def render(self, **kwargs) -> str: diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py index b0dee69b012b7..9630ee17313f0 100644 --- a/torch/_inductor/codegen/cpp_template_kernel.py +++ b/torch/_inductor/codegen/cpp_template_kernel.py @@ -2,7 +2,10 @@ import itertools from collections.abc import Iterable from typing import Any, Callable, Optional, Union +<<<<<<< HEAD from unittest.mock import patch +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import sympy from sympy.parsing.sympy_parser import parse_expr @@ -19,7 +22,11 @@ from ..utils import sympy_index_symbol, sympy_index_symbol_with_prefix from ..virtualized import V from .common import REMOVED +<<<<<<< HEAD from .cpp import CppKernel, CppKernelProxy, KernelGroup, ParallelDepth +======= +from .cpp import CppKernel, CppKernelProxy, KernelGroup +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .cpp_utils import cexpr_index, DTYPE_TO_CPP, LocalBufferContext @@ -34,7 +41,11 @@ def parse_expr_with_index_symbols(expr): return expr.subs(int_symbols) +<<<<<<< HEAD def wrap_with_tensorbox(node) -> Union[ir.TensorBox, ir.ShapeAsConstantBuffer]: +======= +def wrap_with_tensorbox(node) -> ir.TensorBox: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ( ir.TensorBox.create(node) if isinstance(node, ir.Buffer) else ir.TensorBox(node) ) @@ -162,7 +173,10 @@ def slice_nd(self, node, ranges: list[tuple[Any, Any]]) -> ir.ReinterpretView: assert len(_range) == 2 start, end = parse_expr_with_index_symbols(_range) sliced = L.slice_(sliced, dim, start, end, clamp=False) +<<<<<<< HEAD assert isinstance(sliced, ir.TensorBox) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(sliced.data, ir.ReinterpretView), sliced.data return sliced.data @@ -175,10 +189,17 @@ def select(self, node, dim: int, idx: int) -> ir.ReinterpretView: assert isinstance(sliced.data, ir.ReinterpretView), sliced.data return sliced.data +<<<<<<< HEAD def view(self, node, sizes: list[Any]) -> ir.IRNode: node = wrap_with_tensorbox(node) sizes = parse_expr_with_index_symbols(sizes) return L.view(node, sizes).data # type: ignore[arg-type] +======= + def view(self, node, sizes: list[Any]) -> ir.View: + node = wrap_with_tensorbox(node) + sizes = parse_expr_with_index_symbols(sizes) + return L.view(node, sizes).data +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def permute(self, node, dims): node = wrap_with_tensorbox(node) @@ -190,11 +211,15 @@ def maybe_codegen_profile(self) -> str: if config.cpp.enable_kernel_profile: graph_id = V.graph.graph_id prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else "" +<<<<<<< HEAD handle_str = ( "torch::aot_inductor::RAIIAtenRecordFunctionHandle " f'record_{prefix}{self.kernel_name}_("{prefix}{self.kernel_name}", nullptr);' ) return handle_str +======= + return f'RECORD_FUNCTION("{prefix}{self.kernel_name}", c10::ArrayRef({{}}));' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: return "" @@ -293,6 +318,7 @@ def fn(*args): var_sizes_list.append(var_sizes) cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list) +<<<<<<< HEAD def max_parallel_depth(): return ParallelDepth(parallel_depth=0, start_depth=0) @@ -302,6 +328,9 @@ def max_parallel_depth(): cpp_kernel_proxy.loop_nest, "max_parallel_depth", max_parallel_depth ): kernel_group.finalize_kernel(cpp_kernel_proxy, []) +======= + kernel_group.finalize_kernel(cpp_kernel_proxy, []) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return kernel_group.loops_code.getvalue() def store_grouped_gemm_pointwise_nodes( @@ -355,6 +384,7 @@ def fn(*args): var_sizes_list.append(var_sizes) cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list) +<<<<<<< HEAD def max_parallel_depth(): return ParallelDepth(parallel_depth=0, start_depth=0) @@ -364,6 +394,9 @@ def max_parallel_depth(): cpp_kernel_proxy.loop_nest, "max_parallel_depth", max_parallel_depth ): kernel_group.finalize_kernel(cpp_kernel_proxy, []) +======= + kernel_group.finalize_kernel(cpp_kernel_proxy, []) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return kernel_group.loops_code.getvalue() def store_output( @@ -607,7 +640,11 @@ def info_dict( ) -> dict[str, Union[ir.PrimitiveInfoType, list[ir.PrimitiveInfoType]]]: return {"backend": "CPP", "op_type": "unknown"} +<<<<<<< HEAD def output_node(self) -> Union[ir.TensorBox, ir.ShapeAsConstantBuffer]: +======= + def output_node(self) -> ir.TensorBox: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ir.TensorBox.create( ir.CppTemplateBuffer( layout=self.layout, diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py index 929c227039463..3d53420d701ba 100644 --- a/torch/_inductor/codegen/cpp_utils.py +++ b/torch/_inductor/codegen/cpp_utils.py @@ -22,7 +22,10 @@ from ..dependencies import Dep from ..loop_body import LoopBody from ..scheduler import BaseSchedulerNode, SchedulerBuffer +<<<<<<< HEAD from ..shape_propagation import BlockShapeType +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix, sympy_subs from ..virtualized import ops, OpsValue, V from .common import CSEVariable, Kernel, KernelArgs, OptimizationContext @@ -146,9 +149,14 @@ def __init__( name, bounds: ValueRanges[Any], dtype: Optional[torch.dtype] = None, +<<<<<<< HEAD shape: BlockShapeType = None, ) -> None: super().__init__(name, bounds, dtype, shape=shape) +======= + ) -> None: + super().__init__(name, bounds, dtype) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.is_vec = False self.dependent_itervars = OrderedSet[sympy.Symbol]() @@ -201,6 +209,7 @@ def doprint(self, expr, *, simplify: bool = True, p=True): expr = V.graph.sizevars.simplify(expr) return super().doprint(expr) +<<<<<<< HEAD def parenthesize(self, item: sympy.Expr, level: int, strict: bool = False) -> str: if isinstance(item, sympy.Mod): # use parenthesis to enforce precedence. @@ -209,6 +218,8 @@ def parenthesize(self, item: sympy.Expr, level: int, strict: bool = False) -> st else: return super().parenthesize(item, level, strict) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # A function to print, useful for printing sympy symbols. cexpr = CppPrinter().doprint diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 83d1d0614674b..7cc05151452d5 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -21,8 +21,12 @@ from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.symbol import symbol_is_type, SymT +<<<<<<< HEAD from .. import config, cpp_builder, ir from ..debug import set_kernel_post_grad_provenance_tracing +======= +from .. import config, ir +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from ..utils import _align, DeferredLineBase, LineContext, normalize_name from ..virtualized import V from .aoti_hipify_utils import maybe_hipify_code_wrapper @@ -59,6 +63,7 @@ def __init__(self): self.device = "cpu" # must be initialized prior to calling super().__init__() self.included_devices: OrderedSet[str] = OrderedSet() +<<<<<<< HEAD self.model_class_name_suffix = ( config.aot_inductor.model_name_for_generated_files if config.aot_inductor.compile_standalone @@ -68,6 +73,9 @@ def __init__(self): super().__init__() +======= + super().__init__() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.declare = "auto " self.declare_maybe_reference = "decltype(auto) " self.ending = ";" @@ -118,12 +126,16 @@ def _generate_temporary_array_pointer( # e.g. const double** is possible, but not const double* const*. This means # that an array containing pointers must _already_ be properly const-qualified # by the c_type, and not add additional const-ness. +<<<<<<< HEAD # MSVC does not support implicitly converting a const iterator to a const pointer. ptr_call = ( "data()" if force_mutable or c_type.endswith("*") or cpp_builder.is_msvc_cl() else "cbegin()" ) +======= + ptr_call = "data()" if force_mutable or c_type.endswith("*") else "cbegin()" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ( f"std::array<{c_type}, {len(elements)}>{{{', '.join(elements)}}}.{ptr_call}" ) @@ -222,6 +234,7 @@ def write_header(self): self.add_device_include(self.device) if V.graph.aot_mode: +<<<<<<< HEAD if not config.aot_inductor.compile_standalone: with open( os.path.join( @@ -234,6 +247,23 @@ def write_header(self): self.header.splice(f"""#include \"{self.model_class_name_suffix}.h\"""") self.header.splice("\n") +======= + with open( + os.path.join(os.path.dirname(__file__), "aoti_runtime", "interface.cpp") + ) as f: + self.header.splice(f.read()) + self.header.splice("\n") + + enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [ + "linux", + "win32", + ] + if config.profiler_mark_wrapper_call or enable_kernel_profile: + # No C shim for profiling APIs, assuming profiling is a debugging feature which + # does not provide any ABI compatibility promise. + self.header.splice("#include ") + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _include_extra_header(self, header: str): # This is needed for cpp to python dtype conversion self.header.splice(f"#include <{header}>") @@ -508,8 +538,11 @@ def gen_check(handle_kind, idx, name, tensor): def write_wrapper_decl(self): inputs_len = len(V.graph.graph_inputs.keys()) if V.graph.aot_mode: +<<<<<<< HEAD self.codegen_additional_funcs() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if V.graph.const_module: self.header.splice(V.graph.const_module.wrapper_code.header) @@ -521,12 +554,21 @@ def write_wrapper_decl(self): if V.graph.is_const_graph: self.prefix.splice( +<<<<<<< HEAD f""" void {self.aoti_model_class_name}::_const_run_impl( std::vector& output_handles, DeviceStreamType stream, AOTIProxyExecutorHandle proxy_executor ) {{ +======= + """ + void AOTInductorModel::_const_run_impl( + std::vector& output_handles, + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ ) else: @@ -534,18 +576,32 @@ def write_wrapper_decl(self): # If we do not split the constant graph, we'll just create # an empty implementation when wrapping the main module. self.prefix.splice( +<<<<<<< HEAD f""" void {self.aoti_model_class_name}::_const_run_impl( std::vector& output_handles, DeviceStreamType stream, AOTIProxyExecutorHandle proxy_executor ) {{}} +======= + """ + void AOTInductorModel::_const_run_impl( + std::vector& output_handles, + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ ) +<<<<<<< HEAD run_impl_proto = f""" void {self.aoti_model_class_name}::run_impl( +======= + run_impl_proto = """ + void AOTInductorModel::run_impl( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles // are stolen; the array itself is borrowed @@ -555,7 +611,11 @@ def write_wrapper_decl(self): // borrowed DeviceStreamType stream, AOTIProxyExecutorHandle proxy_executor +<<<<<<< HEAD ) {{ +======= + ) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __check_inputs_outputs(input_handles, output_handles); """ @@ -585,7 +645,11 @@ def write_wrapper_decl(self): # Weights are promoted in the JIT mode num_args = len(V.graph.graph_inputs) + len(V.graph.constants) # release GIL to support multiple instances inference (in different threads of the same process) +<<<<<<< HEAD self.prefix.splice("py::gil_scoped_release_simple release;") +======= + self.prefix.splice("py::gil_scoped_release release;") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.prefix.splice( f""" @@ -627,10 +691,17 @@ def write_wrapper_decl(self): ), "Expect all constants to be Tensor" for idx, constants_key in enumerate(V.graph.constants.keys()): if V.graph.aot_mode: +<<<<<<< HEAD # Weights are stored in constants_ and owned by ConstantHandle there. # Don't call std::move here because it will cause constants_ to lose the ownership. self.prefix.writeline( f"""[[maybe_unused]] auto& {constants_key} = constants_->at({idx});""" +======= + # Weights are stored in constants_ and owned by RAIIAtenTensorHandle there. + # Don't call std::move here because it will cause constants_ to lose the ownership. + self.prefix.writeline( + f"""[[maybe_unused]] auto {constants_key} = constants_->at({idx});""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) else: # Append constants as inputs to the graph @@ -666,9 +737,12 @@ def codegen_input_device_type_var_decl(self, code: IndentedBuffer, name): f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type({name}, &{name}_device_type));" ) +<<<<<<< HEAD def codegen_additional_funcs(self): pass +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def codegen_model_kernels(self): self.prefix.writeline("namespace {") @@ -722,6 +796,7 @@ def codegen_model_kernels(self): ) self.prefix.writeline("}") +<<<<<<< HEAD # MSVC string was longer than the limit of 16380 single-byte characters. # https://learn.microsoft.com/en-us/cpp/error-messages/compiler-errors-1/compiler-error-c2026 MSVC_C2026_MAX_STRING_LENGTH = 16000 @@ -744,6 +819,8 @@ def truncate_string(s: str, length: int) -> list[str]: else: self.prefix.writeline(f'{arg_name} = R"({arg_str_val})";') +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def codegen_model_constructor(self): """ // Generated code example @@ -772,7 +849,11 @@ def codegen_model_constructor(self): ) self.prefix.splice( f""" +<<<<<<< HEAD {self.aoti_model_class_name}::{self.aoti_model_class_name}(std::shared_ptr constants_map, +======= + AOTInductorModel::AOTInductorModel(std::shared_ptr constants_map, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::shared_ptr> constants_array, const std::string& device_str, std::optional cubin_dir) @@ -891,6 +972,7 @@ def escape_string(x): .replace("\t", "\\t") ) +<<<<<<< HEAD # Origin code: self.prefix.writeline(f'in_spec_ = R"({config.aot_inductor.serialized_in_spec})";') # Fix msvc C2026 error via codegen_write_arg_with_large_length_string self.codegen_write_arg_with_large_length_string( @@ -901,6 +983,13 @@ def escape_string(x): self.codegen_write_arg_with_large_length_string( arg_name="out_spec_", arg_str_val=config.aot_inductor.serialized_out_spec, +======= + self.prefix.writeline( + f'in_spec_ = R"({config.aot_inductor.serialized_in_spec})";' + ) + self.prefix.writeline( + f'out_spec_ = R"({config.aot_inductor.serialized_out_spec})";' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) for idx, output in enumerate(V.graph.graph_outputs): @@ -934,12 +1023,21 @@ def codegen_const_run_driver(self): """ self.prefix.splice( +<<<<<<< HEAD f""" std::unordered_map {self.aoti_model_class_name}::const_run_impl( DeviceStreamType stream, AOTIProxyExecutorHandle proxy_executor, bool initialization ) {{ +======= + """ + std::unordered_map AOTInductorModel::const_run_impl( + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor, + bool initialization + ) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ ) if not config.aot_inductor.use_runtime_constant_folding: @@ -1088,7 +1186,10 @@ def generate_return(self, output_refs: list[str]): output_buffer = V.graph.graph_outputs[idx] if isinstance(output_buffer, ir.BaseView): output_storage = output_buffer.unwrap_view() +<<<<<<< HEAD assert isinstance(output_storage, (ir.BaseView, ir.MutableBox)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(output_storage.data, ir.ConstantBuffer): is_constant_buffer = True @@ -1122,7 +1223,11 @@ def generate_return(self, output_refs: list[str]): def generate_before_suffix(self, result): if not V.graph.is_const_graph: if V.graph.aot_mode: +<<<<<<< HEAD result.writeline(f"}} // {self.aoti_model_class_name}::run_impl") +======= + result.writeline("} // AOTInductorModel::run_impl") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: result.writeline("} // inductor_entry_impl") @@ -1130,7 +1235,11 @@ def generate_end(self, result): """Generates the end of the code block, and any code needed to call it.""" if V.graph.aot_mode: if V.graph.is_const_graph: +<<<<<<< HEAD result.writeline(f"}} // {self.aoti_model_class_name}::_const_run_impl") +======= + result.writeline("} // AOTInductorModel::_const_run_impl") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: result.writeline("} // namespace torch::aot_inductor\n\n\n") return @@ -1247,7 +1356,10 @@ def generate_c_shim_extern_kernel_call( device: str, *, debug_args: Optional[list[str]] = None, +<<<<<<< HEAD debug_handle: Optional[int] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: """debug_args kwarg allows CppWrapperCpuArrayRef to pass in wrapped arguments in place of args while preserving debug printer output.""" @@ -1264,16 +1376,26 @@ def generate_c_shim_extern_kernel_call( ] with debug_printer_manager: shim_fn = self.get_c_shim_func_name(kernel, device) +<<<<<<< HEAD self.write_provenance_debug_handle(shim_fn, debug_handle) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shim_fn_codes = ( f"AOTI_TORCH_ERROR_CODE_CHECK({shim_fn}({', '.join(args)}));" ) if enable_kernel_profile: +<<<<<<< HEAD debug_handle_str = "" if debug_handle is None else f":{debug_handle}" shim_fn_codes = textwrap.dedent( f""" {{ RAIIAtenRecordFunctionHandle record_{shim_fn}_("{shim_fn}{debug_handle_str}", nullptr); +======= + shim_fn_codes = textwrap.dedent( + f""" + {{ + RECORD_FUNCTION("{shim_fn}", c10::ArrayRef()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {shim_fn_codes} }} """ @@ -1296,6 +1418,7 @@ def generate_c_shim_extern_kernel_alloc( args = [*args, f"&{output_handle_name}"] device = d.type if (d := extern_kernel.get_device()) else self.device +<<<<<<< HEAD debug_handle = None if config.trace.provenance_tracking_level != 0: @@ -1318,6 +1441,13 @@ def generate_c_shim_extern_kernel_alloc( f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_delete_tensor_object({output_handle_name}));" ) elif not is_inplace: +======= + self.generate_c_shim_extern_kernel_call( + extern_kernel.get_kernel_name(), args, device + ) + + if not is_inplace: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.writeline(f"RAIIAtenTensorHandle {name}({output_handle_name});") def _generate_extern_kernel_alloc_helper(self, extern_kernel, args): @@ -1361,6 +1491,7 @@ def generate_c_shim_fallback_kernel( raise NotImplementedError(f"unsupported type of {output=}") args = args + output_args device = d.type if (d := fallback_kernel.get_device()) else self.device +<<<<<<< HEAD debug_handle = None if config.trace.provenance_tracking_level != 0: @@ -1369,11 +1500,16 @@ def generate_c_shim_fallback_kernel( fallback_kernel.cpp_kernel_name, # type: ignore[arg-type] is_extern=True, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.generate_c_shim_extern_kernel_call( fallback_kernel.cpp_kernel_name, # type: ignore[arg-type] args, device, +<<<<<<< HEAD debug_handle=debug_handle, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) for raii_handle in output_raii_handles: self.writeline(raii_handle) @@ -1385,7 +1521,10 @@ def _generate_extern_kernel_out_helper( out_view: Optional[str], args: list[str], device: str, +<<<<<<< HEAD debug_handle: Optional[int] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: if out_view: out_name = f"{out}_as_strided" @@ -1394,9 +1533,13 @@ def _generate_extern_kernel_out_helper( else: args.insert(0, out) +<<<<<<< HEAD self.generate_c_shim_extern_kernel_call( kernel, args, device, debug_handle=debug_handle ) +======= + self.generate_c_shim_extern_kernel_call(kernel, args, device) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def generate_scatter_fallback( self, @@ -1506,6 +1649,7 @@ def codegen_dynamic_scalar(self, node): # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again self.unbacked_symbol_decls.add(str(node.sym)) +<<<<<<< HEAD def codegen_dynamic_select_index(self, node): index_cpp_str = self.val_to_arg_str_for_prim_type(node.index, int) @@ -1520,6 +1664,8 @@ def codegen_dynamic_select_index(self, node): # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again self.unbacked_symbol_decls.add(str(node.unbacked_offset_symbol)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def make_buffer_free(self, buffer): return ( "" @@ -1536,7 +1682,11 @@ def codegen_exact_buffer_reuse(self, old_name: str, new_name: str, del_line: str def generate_profiler_mark_wrapper_call(self, stack): self.wrapper_call.writeline( +<<<<<<< HEAD 'RAIIAtenRecordFunctionHandle record_inductor_wrapper_call_("inductor_wrapper_call", nullptr);' +======= + 'RECORD_FUNCTION("inductor_wrapper_call", c10::ArrayRef());' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def generate_start_graph(self): @@ -1593,6 +1743,7 @@ def codegen_int_array_var( # This is why writeline needs to explicitly passed in as a parameter. var = f"int_array_{next(self.int_array_id)}" ctype = "int64_t" +<<<<<<< HEAD if int_array == "{}": # An array of unknown bound cannot be initialized with {}. if known_statically: @@ -1612,6 +1763,14 @@ def codegen_int_array_var( writeline(f"static const {ctype} {var}[] = {int_array};") else: writeline(f"const {ctype} {var}[] = {int_array};") +======= + if var not in self.declared_int_array_vars: + self.declared_int_array_vars.add(var) + if known_statically: + writeline(f"static constexpr {ctype} {var}[] = {int_array};") + else: + writeline(f"const {ctype} {var}[] = {int_array};") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return var def make_buffer_allocation(self, buffer): @@ -1622,11 +1781,18 @@ def make_buffer_allocation(self, buffer): buffer.get_size(), buffer.get_stride(), V.graph.get_allocation_size(buffer), +<<<<<<< HEAD buffer.get_is_pinned(), ) def make_allocation( self, name, device, dtype, shape, stride, allocation_shape=None, is_pinned=False +======= + ) + + def make_allocation( + self, name, device, dtype, shape, stride, allocation_shape=None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): if allocation_shape is None: allocation_shape = shape @@ -1678,9 +1844,14 @@ def make_allocation( ] self.wrapper_call.writeline(f"AtenTensorHandle {handle_name};") +<<<<<<< HEAD pinned_str = "_pinned" if is_pinned else "" self.wrapper_call.writeline( f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided{pinned_str}({', '.join(args)}));" +======= + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided({', '.join(args)}));" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if allocation_size != size: @@ -1695,6 +1866,7 @@ def make_allocation( self.wrapper_call.writeline( f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_as_strided({', '.join(args)}));" ) +<<<<<<< HEAD self.wrapper_call.writeline( f"wrap_with_raii_handle_if_needed({old_handle_name});" ) @@ -1704,6 +1876,12 @@ def make_allocation( def codegen_alloc_from_pool( self, name, offset, dtype, shape, stride ) -> tuple[str, list[str]]: +======= + + return f"RAIIAtenTensorHandle {name}({handle_name});" + + def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) size = self.codegen_shape_tuple(shape) stride = self.codegen_shape_tuple(stride) tmp_name = f"tmp_tensor_handle_{next(self.tmp_tensor_id)}" @@ -1720,6 +1898,7 @@ def codegen_alloc_from_pool( ), f"&{tmp_name}", ] +<<<<<<< HEAD # We return the lines instead of writing here because writing here is bug prune. # If you write aoti_torch__alloc_from_pool lines, you must write the RAIIAtenTensorHandle # as well, otherwise you get memory leaks @@ -1728,6 +1907,13 @@ def codegen_alloc_from_pool( f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool({', '.join(args)}));", ] return f"RAIIAtenTensorHandle({tmp_name})", allocations_to_write +======= + self.wrapper_call.writeline(f"AtenTensorHandle {tmp_name};") + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool({', '.join(args)}));" + ) + return f"RAIIAtenTensorHandle({tmp_name})" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def codegen_reinterpret_view( self, @@ -1842,7 +2028,11 @@ def create_new_tensor_handle() -> tuple[str, list[str]]: # ``` return final_tensor_str +<<<<<<< HEAD def codegen_device_copy(self, src, dst, non_blocking: Union[bool, str]): +======= + def codegen_device_copy(self, src, dst, non_blocking: bool): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """This function is overridden by cpp_wrapper_cpu_array_ref, so we don't need to handle cases where dst is not an AtenTensorHandle.""" self.writeline( @@ -1947,9 +2137,13 @@ def codegen_subgraph(self, subgraph, outer_inputs, outer_outputs): finally: self.pop_codegened_graph() +<<<<<<< HEAD def codegen_while_loop(self, while_loop, stack_output=False): if stack_output: raise NotImplementedError("NYI cpp wrapper for while_loop_stack_output") +======= + def codegen_while_loop(self, while_loop): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) is_bool_pred = isinstance( while_loop.cond_subgraph.graph.graph_outputs[0], ir.ShapeAsConstantBuffer ) @@ -2310,7 +2504,11 @@ def generate_scoped_gil_acquire(self, declarations_before_scope, lines_in_scope) scoped_lines.writeline("{") with scoped_lines.indent(): +<<<<<<< HEAD scoped_lines.writeline("py::gil_scoped_acquire_simple acquire;") +======= + scoped_lines.writeline("py::gil_scoped_acquire acquire;") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) scoped_lines.writelines(lines_in_scope.split("\n")) scoped_lines.writelines("}") return scoped_lines._lines @@ -2665,7 +2863,11 @@ def generate_fallback_kernel_with_runtime_lookup_aot( "AtenTensorHandle", tensor_call_args, force_mutable=True ) +<<<<<<< HEAD extern_kernel_node_index = len(V.extern_kernel_nodes) - 1 +======= + extern_kernel_node_index = len(V.graph.extern_kernel_nodes) - 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.writeline( f"aoti_torch_proxy_executor_call_function(proxy_executor, " f"{extern_kernel_node_index}, " diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py b/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py index 63c5bc2debe8b..86e071ae6c3f5 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py @@ -297,7 +297,11 @@ def write_wrapper_decl(self): # Weights are promoted in the JIT mode num_args = len(V.graph.graph_inputs) + len(V.graph.constants) # release GIL to support multiple instances inference (in different threads of the same process) +<<<<<<< HEAD self.prefix.splice("py::gil_scoped_release_simple release;") +======= + self.prefix.splice("py::gil_scoped_release release;") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.prefix.splice( f""" @@ -409,7 +413,10 @@ def use_thread_local_cached_output_tensor(idx, output): output_buffer = V.graph.graph_outputs[idx] if isinstance(output_buffer, ir.BaseView): output_storage = output_buffer.unwrap_view() +<<<<<<< HEAD assert isinstance(output_storage, (ir.BaseView, ir.MutableBox)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(output_storage.data, ir.ConstantBuffer): is_constant_buffer = True @@ -565,6 +572,7 @@ def make_buffer_allocation(self, buffer): buffer.get_size(), buffer.get_stride(), buffer if self.can_stack_allocate_buffer(buffer) else None, +<<<<<<< HEAD buffer.get_is_pinned(), ) @@ -577,6 +585,12 @@ def make_allocation( stride, buffer_if_can_stack_allocate=None, is_pinned=False, +======= + ) + + def make_allocation( + self, name, device, dtype, shape, stride, buffer_if_can_stack_allocate=None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): orig_stride = stride device_str = self.codegen_device(device) @@ -623,9 +637,14 @@ def make_allocation( ] self.wrapper_call.writeline(f"AtenTensorHandle {name}_handle;") +<<<<<<< HEAD pinned_str = "_pinned" if is_pinned else "" self.wrapper_call.writeline( f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided{pinned_str}({', '.join(args)}));" +======= + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided({', '.join(args)}));" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return f"RAIIAtenTensorHandle {name}({name}_handle);" @@ -771,7 +790,11 @@ def generate_fallback_kernel_with_runtime_lookup( buf_name, python_kernel_name, get_args, op_overload, raw_args, outputs ) +<<<<<<< HEAD def codegen_device_copy(self, src, dst, non_blocking: Union[bool, str]): +======= + def codegen_device_copy(self, src, dst, non_blocking: bool): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # aoti_torch_tensor_copy_ takes AtenTensorHandle as input, # while stack-allocation results in ArrayRefTensor # so disable stack allocation here diff --git a/torch/_inductor/codegen/cpp_wrapper_gpu.py b/torch/_inductor/codegen/cpp_wrapper_gpu.py index 24b87fa8fa490..cbd8b2187cb13 100644 --- a/torch/_inductor/codegen/cpp_wrapper_gpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_gpu.py @@ -211,17 +211,24 @@ def generate_launch_kernel(self, prefix, wrapper, kernel_var_name, params): ] arg_types = [arg_type_loookup[name] for name in call_args] arg_signatures = [triton_meta["signature"][name] for name in call_args] +<<<<<<< HEAD scratch_spaces = { name: params[name] for name in ["global_scratch", "profile_scratch"] if params.get(name, None) is not None } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) call_args_str = wrapper.generate_args_decl( prefix, call_args, arg_types, arg_signatures, +<<<<<<< HEAD scratch_spaces=scratch_spaces, +======= + workspace_size=params.get("global_scratch") or 0, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) prefix.writeline(f"void* kernel_args_[] = {{{call_args_str}}};") launch_kernel_args = [ @@ -234,8 +241,11 @@ def generate_launch_kernel(self, prefix, wrapper, kernel_var_name, params): "kernel_args_", "stream_", ] +<<<<<<< HEAD if wrapper.device == "xpu": launch_kernel_args.append(str(params["threads_per_warp"])) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) prefix.writeline(f"launchKernel({', '.join(launch_kernel_args)});") @@ -461,7 +471,11 @@ def generate_args_decl( arg_types, arg_signatures, is_triton_kernel=True, +<<<<<<< HEAD scratch_spaces: Optional[dict[str, int]] = None, +======= + workspace_size=0, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): """ Generates any declarations of args to pass into a kernel call, and then returns the arg names. @@ -579,6 +593,7 @@ def process_args(arg, arg_type, arg_signature=None): ): process_args(arg, arg_type, arg_signature) +<<<<<<< HEAD for scratch_name, workspace_size in (scratch_spaces or {}).items(): if ( is_triton_kernel @@ -599,6 +614,24 @@ def process_args(arg, arg_type, arg_signature=None): scratch_def, scratch_var = scratch code.writelines([maybe_hipify_code_wrapper(x) for x in scratch_def]) new_args.append(f"&{scratch_var}") +======= + if ( + is_triton_kernel + and ( + global_scratch := self.device_codegen.cpp_global_scratch( + next(self.arg_var_id), + workspace=TritonScratchWorkspace( + size=workspace_size, + generate_dtype_str=(lambda: self.codegen_dtype(torch.uint8)), + ), + ) + ) + is not None + ): + global_scratch_def, global_scratch_var = global_scratch + code.writelines([maybe_hipify_code_wrapper(x) for x in global_scratch_def]) + new_args.append(f"&{global_scratch_var}") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ", ".join(new_args) diff --git a/torch/_inductor/codegen/cpp_wrapper_mps.py b/torch/_inductor/codegen/cpp_wrapper_mps.py index aea4470f1c964..2321ea9a50dba 100644 --- a/torch/_inductor/codegen/cpp_wrapper_mps.py +++ b/torch/_inductor/codegen/cpp_wrapper_mps.py @@ -3,6 +3,7 @@ import sympy import torch +<<<<<<< HEAD from torch.utils._ordered_set import OrderedSet from ..ir import GraphPartitionSignature @@ -21,6 +22,16 @@ def __init__(self) -> None: super().__init__() self._used_kernel_names: OrderedSet[str] = OrderedSet() +======= + +from ..ir import GraphPartitionSignature +from ..virtualized import V +from .cpp_wrapper_gpu import CppWrapperGpu +from .wrapper import PythonWrapperCodegen + + +class CppWrapperMps(CppWrapperGpu): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @staticmethod def create( is_subgraph: bool, @@ -34,6 +45,7 @@ def _generate_kernel_call_helper( self, kernel_name: str, call_args: list[str], +<<<<<<< HEAD *, device: Optional[torch.device] = None, triton: bool = True, @@ -43,10 +55,15 @@ def _generate_kernel_call_helper( triton_meta: Optional[dict[str, Any]] = None, graph_name: str = "", original_fxnode_name: Optional[str] = None, +======= + arg_types: Optional[list[type]] = None, + **kwargs: dict[str, Any], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: """ Generates MPS kernel call code. It should look something like: ``` +<<<<<<< HEAD get_mps_lib_0()->runCommandBlock([&] { get_mps_lib_0()->startEncoding(); aoti_torch_mps_set_arg(get_mps_lib_0_handle(), 0, buf0); @@ -73,17 +90,38 @@ def _generate_kernel_call_helper( assert device.type == "mps" +======= + auto mps_lib_0_func = mps_lib_0.getKernelFunction("generated_kernel"); + auto mps_lib_0_func_handle = AOTIMetalKernelFunctionHandle(mps_lib_0_func.get()); + mps_lib_0_func->runCommandBlock([&] { + mps_lib_0_func->startEncoding(); + aoti_torch_mps_set_arg(mps_lib_0_func_handle, 0, buf0); + aoti_torch_mps_set_arg(mps_lib_0_func_handle, 1, arg0_1); + ... + mps_lib_0_func->dispatch(9); + }); + ``` + """ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert arg_types is not None new_args = [] for idx, (arg, arg_type) in enumerate(zip(call_args[:-2], arg_types[:-2])): if isinstance(arg_type, torch.dtype): new_args.append( +<<<<<<< HEAD f"aoti_torch_mps_set_arg_tensor(get_{kernel_name}_handle(), {idx}, {arg});" ) elif arg_type in (int, sympy.core.symbol.Symbol): new_args.append( f"aoti_torch_mps_set_arg_int(get_{kernel_name}_handle(), {idx}, {arg});" +======= + f"aoti_torch_mps_set_arg_tensor({kernel_name}_handle, {idx}, {arg});\n" + ) + elif arg_type in (int, sympy.core.symbol.Symbol): + new_args.append( + f"aoti_torch_mps_set_arg_int({kernel_name}_handle, {idx}, {arg});\n" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) else: raise NotImplementedError( @@ -94,11 +132,17 @@ def _generate_kernel_call_helper( if threads is None: raise NotImplementedError("No threads or group_size provided") elif group_size is None: +<<<<<<< HEAD new_args.append(f"get_{kernel_name}()->dispatch({threads});\n") else: new_args.append( f"get_{kernel_name}()->dispatch({threads}, {group_size});\n" ) +======= + new_args.append(f"{kernel_name}->dispatch({threads});\n") + else: + new_args.append(f"{kernel_name}->dispatch({threads}, {group_size});\n") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # debug printer related logic for cpp kernel type. debug_printer_manager = V.graph.wrapper_code.debug_printer @@ -110,6 +154,7 @@ def _generate_kernel_call_helper( "cpp", ) with debug_printer_manager: +<<<<<<< HEAD self.write_mps_kernel_call(kernel_name, new_args) def write_mps_kernel_call(self, name: str, call_args: list[str]) -> None: @@ -121,6 +166,21 @@ def write_mps_kernel_call(self, name: str, call_args: list[str]) -> None: for call_arg in call_args: self.writeline(f" {call_arg}") self.writeline("});") +======= + self.writeline(self.wrap_kernel_call(kernel_name, new_args)) + + def wrap_kernel_call(self, name: str, call_args: list[str]) -> str: + lib_name = name[: -len("_func")] + calling_args = " ".join(call_args) + return f""" + auto {name} = {lib_name}.getKernelFunction("generated_kernel"); + auto {name}_handle = AOTIMetalKernelFunctionHandle({name}.get()); + {name}->runCommandBlock([&] {{ + {name}->startEncoding(); + {calling_args} + }}); + """ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @staticmethod def get_device_include_path(device: str) -> str: @@ -129,6 +189,7 @@ def get_device_include_path(device: str) -> str: "#include \n" "#include " ) +<<<<<<< HEAD def codegen_additional_funcs(self) -> None: """ @@ -178,3 +239,5 @@ def codegen_additional_funcs(self) -> None: ) self.prefix.writeline(" return handle;") self.prefix.writeline("}") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py b/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py index 67828622fde59..dddbb1b81a00f 100644 --- a/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py +++ b/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py @@ -144,9 +144,14 @@ def codegen_template( assert all(isinstance(n, ComputedBuffer) for n in epilogue_ir_nodes), ( "Epilogue nodes must all be instances of ir.ComputedBuffer" ) +<<<<<<< HEAD kernel, render = ctb.make_kernel_render( # type: ignore[misc] ctb, epilogue_nodes=epilogue_nodes ) +======= + kernel, render = ctb.make_kernel_render(ctb, epilogue_nodes=epilogue_nodes) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with kernel: for node in [template_node, *epilogue_nodes]: node.mark_run() diff --git a/torch/_inductor/codegen/cuda/cuda_kernel.py b/torch/_inductor/codegen/cuda/cuda_kernel.py index 0a9c6b0ca4e5f..8f6ea39ef9afe 100644 --- a/torch/_inductor/codegen/cuda/cuda_kernel.py +++ b/torch/_inductor/codegen/cuda/cuda_kernel.py @@ -29,7 +29,10 @@ IRNode, Layout, PrimitiveInfoType, +<<<<<<< HEAD ShapeAsConstantBuffer, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TensorBox, ) from ...utils import sympy_product @@ -177,9 +180,12 @@ def get_ld(node) -> Union[Expr, int]: def get_dynamic_shape_args(self) -> list[Union[Expr, int]]: return [*self.get_layout_args(), *self.size_args] +<<<<<<< HEAD def get_offset_args(self) -> list[Expr]: return [node.get_layout().offset for node in self.named_nodes.values()] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @staticmethod def find_ld_idx(node: IRNode) -> int: strides = node.get_stride() @@ -267,7 +273,10 @@ def def_kernel( In this case, the `input_reorder` would be [2, 0, 1]. additional_size_args: Additional size arguments for epilogue inputs """ +<<<<<<< HEAD # NB: name order matters here, it's used to match up offsets +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) names = [x.strip() for x in names_str.strip().split(",")] if len(inputs) + len(outputs) != len(names): raise RuntimeError( @@ -289,7 +298,10 @@ def def_kernel( free_symbols: OrderedSet[Expr] = OrderedSet() for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs): if node is not None: +<<<<<<< HEAD # NB: named nodes must be populated in the order of names +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.named_nodes[name] = node self.args.output_buffers[node.get_name()] = name @@ -311,17 +323,25 @@ def def_kernel( size_vars.extend(str(s) for s in free_symbols) self.size_args.extend(free_symbols) size_args = [f"const int {s}" for s in size_vars] +<<<<<<< HEAD offset_args = [f"const int {name}_offset" for name in self.named_nodes.keys()] +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) runtime_arg_decls = ",".join( [f"{arg.ty} {arg.name}" for arg in self.runtime_arg_info] ) if runtime_arg_decls: runtime_arg_decls += ", " +<<<<<<< HEAD signature = ( f"int {self.kernel_name}({', '.join(arg_defs + size_args + offset_args)},\ {runtime_arg_decls}{self._EXTRA_CPP_ARGS})" ) +======= + signature = f"int {self.kernel_name}({', '.join(arg_defs + size_args)}, {runtime_arg_decls}{self._EXTRA_CPP_ARGS})" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.signature = signature return signature @@ -354,6 +374,7 @@ def call_kernel( _, call_args, _, arg_types = self.args.python_argdefs() dynamic_shape_args = self.get_dynamic_shape_args() +<<<<<<< HEAD offset_args = self.get_offset_args() call_args.extend(dynamic_shape_args) # type: ignore[arg-type] call_args.extend(offset_args) # type: ignore[arg-type] @@ -361,6 +382,12 @@ def call_kernel( call_args.append(str(arg)) arg_types.extend("const int" for _ in dynamic_shape_args) arg_types.extend("const int" for _ in offset_args) +======= + call_args.extend(dynamic_shape_args) # type: ignore[arg-type] + for arg in self.runtime_arg_values: + call_args.append(arg) + arg_types.extend("int" for _ in dynamic_shape_args) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for arg in self.runtime_arg_info: arg_types.append(arg.ty) # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar @@ -436,6 +463,18 @@ def max_valid_index(self, node: IRNode, default=-1): max_valid_offset += (node.get_size()[i] - 1) * node.get_stride()[i] return max_valid_offset +<<<<<<< HEAD +======= + def offset(self, node: IRNode) -> str: + """ + Generates code which represents offset of a given node. + """ + + if node is None: + return "0" + return str(node.get_layout().offset) # type: ignore[union-attr] + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def ptr(self, node: IRNode) -> str: """ Generates code which represents pointer of a given node. @@ -446,7 +485,12 @@ def ptr(self, node: IRNode) -> str: arg_name = self.arg_name(node) if arg_name is None: return "nullptr" +<<<<<<< HEAD return f"{arg_name} + {arg_name}_offset" +======= + offset = self.offset(node) + return arg_name if offset == "0" else f"{arg_name} + {offset}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def size( self, @@ -632,26 +676,37 @@ def hash_key(self) -> str: """ Return kernel hash key that does not depend on swizzle. """ +<<<<<<< HEAD swizzle_str: str = ( str(self.info_kwargs.get("swizzle")) if isinstance(self.info_kwargs, dict) else "None" ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return "-".join( [ self.category, self.bmreq.hash_key, +<<<<<<< HEAD swizzle_str, +======= + str(self.info_dict().get("swizzle")), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] ) def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]: +<<<<<<< HEAD """ Information returned here is logged to the autotune log file when that is enabled. In general, we should avoid calling this function as it is expensive to compute, and can add up very fast. """ +======= + """Information returned here is logged to the autotune log file when that is enabled.""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.info_kwargs is not None and "op" in self.info_kwargs: op: Any = self.info_kwargs["op"] return { @@ -672,7 +727,11 @@ def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType else: return {"backend": "CUDA", "op_type": "unknown"} +<<<<<<< HEAD def output_node(self) -> Union[TensorBox, ShapeAsConstantBuffer]: +======= + def output_node(self) -> TensorBox: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.bmreq.update_workspace_size() return TensorBox.create( CUDATemplateBuffer( diff --git a/torch/_inductor/codegen/cuda/cuda_template.py b/torch/_inductor/codegen/cuda/cuda_template.py index 4aa0aeb46e077..59d52039085c5 100644 --- a/torch/_inductor/codegen/cuda/cuda_template.py +++ b/torch/_inductor/codegen/cuda/cuda_template.py @@ -3,15 +3,23 @@ import hashlib import itertools from dataclasses import dataclass +<<<<<<< HEAD from typing import Any, Optional, TYPE_CHECKING, Union +======= +from typing import Any, Optional, TYPE_CHECKING +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from typing_extensions import override from unittest.mock import patch import sympy import torch +<<<<<<< HEAD from torch._inductor import config from torch._inductor.utils import clear_on_fresh_cache, Placeholder +======= +from torch._inductor.utils import Placeholder +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._logging import getArtifactLogger from ...autotune_process import CUDABenchmarkRequest, TensorMeta @@ -39,12 +47,17 @@ class ArgInfo: ty: str +<<<<<<< HEAD @clear_on_fresh_cache class CUDATemplate(KernelTemplate): index_counter = itertools.count() # dict of cache key to (code, size_args) code_cache: dict[str, tuple[str, tuple[int, ...], tuple[int, ...]]] = {} cache_clear = staticmethod(code_cache.clear) +======= +class CUDATemplate(KernelTemplate): + index_counter = itertools.count() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __init__( self, @@ -54,15 +67,25 @@ def __init__( input_reorder: Optional[list[int]] = None, ) -> None: """ +<<<<<<< HEAD Baseclass for CUDA C++ Templates, derived from KernelTemplate. Not to be instantiated directly. +======= + + Baseclass for CUDA C++ Templates, derived from KernelTemplate. Not to be instantiated directly. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Args: name (str): The name of the CUDATemplate object. input_nodes (List[IRNode]): A list of input IRNodes. layout (Layout): The layout of the output buffer / tensor. +<<<<<<< HEAD input_reorder (Optional[List[int]]): An optional list that specifies the order of the input nodes. +======= + input_reorder (Optional[List[int]]): An optional list that specifies the order of the input nodes. + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ super().__init__(name) self.input_nodes = input_nodes @@ -70,15 +93,19 @@ def __init__( self.input_reorder = input_reorder self.layout = layout +<<<<<<< HEAD @classmethod @functools.lru_cache(None) def _template_from_string(cls, source: str) -> Any: return KernelTemplate._template_from_string(source) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @staticmethod def supports_epilogue_fusion(op: GemmOperation) -> bool: return False +<<<<<<< HEAD def make_key(self, name: str, input_key: str, layout_repr: str) -> str: """ Make a key for the code cache. The idea of the method is to cache @@ -128,6 +155,32 @@ def generate_code_and_args( runtime_arg_values=self.get_runtime_arg_values(**kwargs), ) with patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)): +======= + def generate( # type: ignore[override] + self, + description, + **kwargs, + ) -> CUDATemplateCaller: + """ + Generates the CUDA template caller object for the given GEMM template and operation. This CUDATemplateCaller + may be used to call and benchmark the generated CUDA kernel in a standalone manner to enable Autotuning. + + Args: + kwargs: Additional keyword arguments. + + Returns: + A CUDATemplateCaller object representing the generated CUDA template caller. + """ + kernel_name = str(Placeholder.KERNEL_NAME) + with ( + patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)), + CUDATemplateKernel( + kernel_name=kernel_name, + runtime_arg_info=self.get_runtime_arg_info(), + runtime_arg_values=self.get_runtime_arg_values(**kwargs), + ) as kernel, + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) code = self.render(kernel=kernel, **kwargs) _, call_args, _, _ = kernel.args.python_argdefs() autotuning_log.debug("Generated Code:\n%s", code) @@ -152,6 +205,7 @@ def generate_code_and_args( ) V.graph.sizevars.size_hints(map(sympy.expand, call_args[len(expected_args) :])) size_args = V.graph.sizevars.size_hints(kernel.get_dynamic_shape_args()) +<<<<<<< HEAD offset_args = V.graph.sizevars.size_hints(kernel.get_offset_args()) if key is not None: @@ -194,6 +248,10 @@ def generate( # type: ignore[override] ) # not caching since kernel name is needed below +======= + extra_args = tuple(list(size_args) + self.get_runtime_arg_values(**kwargs)) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kernel_hash = hashlib.sha256(code.encode("utf-8")).hexdigest()[:8] kernel_name = f"cutlass_{kernel_hash}" code = code.replace(self.name, kernel_name) @@ -201,8 +259,13 @@ def generate( # type: ignore[override] # create the BenchmarkRequest bmreq = CUDABenchmarkRequest( kernel_name=kernel_name, +<<<<<<< HEAD input_tensor_meta=input_tensor_meta, output_tensor_meta=output_tensor_meta, +======= + input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes), + output_tensor_meta=TensorMeta.from_irnodes(self.output_node), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) extra_args=extra_args, source_code=code, ) diff --git a/torch/_inductor/codegen/cuda/cutlass_cache.py b/torch/_inductor/codegen/cuda/cutlass_cache.py index 519125888c16c..56b3fd81df9ea 100644 --- a/torch/_inductor/codegen/cuda/cutlass_cache.py +++ b/torch/_inductor/codegen/cuda/cutlass_cache.py @@ -1,7 +1,10 @@ # mypy: allow-untyped-defs import functools import hashlib +<<<<<<< HEAD import inspect +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import json import logging import os @@ -10,7 +13,10 @@ import torch._inductor.config as config from torch._inductor.codecache import cutlass_key +<<<<<<< HEAD from torch._inductor.codegen.cuda import cutlass_utils, serialization +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch, get_cuda_version from torch._inductor.codegen.cuda.serialization import get_cutlass_operation_serializer from torch._inductor.runtime.cache_dir_utils import cache_dir @@ -29,6 +35,7 @@ def get_config_request_key( instantiation_level: str, ) -> str: """ +<<<<<<< HEAD Return a key for the full ops, based on cutlass key, arch, cuda version, instantiation level, and serialization.py file hash. """ @@ -41,14 +48,21 @@ def get_file_hash(file_module): serialization_hash = get_file_hash(serialization) cutlass_utils_hash = get_file_hash(cutlass_utils) +======= + Return a key for the full ops, based on cutlass key, arch, cuda version, and instantiation level. + """ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) hash_target = "-".join( [ cutlass_key().hex(), arch, cuda_version, instantiation_level, +<<<<<<< HEAD serialization_hash, cutlass_utils_hash, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] ) return hashlib.sha256(hash_target.encode("utf-8")).hexdigest()[0:8] diff --git a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py index 605b93dff5926..db158bacf042d 100644 --- a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py +++ b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py @@ -2,7 +2,10 @@ from sympy import Expr +<<<<<<< HEAD import torch._inductor.config as config +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.ir import ( ComputedBuffer, InputBuffer, @@ -28,6 +31,30 @@ import textwrap from typing import Union +<<<<<<< HEAD +======= + from cutlass.backend.c_types import ( # type: ignore[import-untyped, import-not-found] + EmptyByte, + ) + from cutlass.backend.epilogue import ( # type: ignore[import-untyped, import-not-found] + dtype2ctype, + ) + from cutlass.backend.evt import ( # type: ignore[import-untyped, import-not-found] + EpilogueFunctorVisitor, + ) + from cutlass.backend.evt.backend.emitter_base import ( # type: ignore[import-untyped, import-not-found] + FusionCallbacks, + ) + from cutlass.backend.evt.backend.sm90_emitter import ( # type: ignore[import-untyped, import-not-found] + CollectiveEpilogue, + ) + from cutlass.backend.evt.frontend import ( # type: ignore[import-untyped, import-not-found] + PythonASTFrontend, + ) + from cutlass.backend.evt.ir.tensor import ( # type: ignore[import-untyped, import-not-found] + Tensor as CutlassTensor, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from cutlass_library import ( DataType, EpilogueScheduleType, @@ -35,6 +62,7 @@ TileDescription, ) +<<<<<<< HEAD if config.is_fbcode(): import python_cutlass # type: ignore[import-untyped, import-not-found] # noqa: F401 else: @@ -61,15 +89,26 @@ def new_name(self, name: str) -> str: def get(self, name: str) -> str: return self.buf_renames.get(name, name) +======= + from torch._inductor.codegen.cuda import cuda_env + from torch._inductor.utils import IndentedBuffer + + _CUTLASS_C_DTYPES = OrderedSet(dtype2ctype.values()) # type: ignore[var-annotated] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def create_example_tensors( var_name_to_buffer_name: dict[str, str], name_to_buffer: dict[str, Buffer], size_hint_fn: Callable[[Union[Expr, int]], int], +<<<<<<< HEAD ) -> dict[str, python_cutlass.backend.evt.ir.tensor.Tensor]: def cutlass_tensor_from_buffer( buffer: Buffer, ) -> python_cutlass.backend.evt.ir.tensor.Tensor: +======= + ) -> dict[str, CutlassTensor]: + def cutlass_tensor_from_buffer(buffer: Buffer) -> CutlassTensor: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shape = buffer.get_layout().size stride = buffer.get_layout().stride shape = tuple(size_hint_fn(x) for x in shape) @@ -84,11 +123,19 @@ def cutlass_tensor_from_buffer( non-contiguous layout, received stride: {stride} and shape: {shape}" ) +<<<<<<< HEAD return python_cutlass.backend.evt.ir.tensor.Tensor( shape=shape, layout_tag=( LayoutType.RowMajor if is_row_major else LayoutType.ColumnMajor ), +======= + return CutlassTensor( + shape=shape, + layout_tag=LayoutType.RowMajor + if is_row_major + else LayoutType.ColumnMajor, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) element=torch_dtype_to_cutlass_type(buffer.get_layout().dtype), ) @@ -99,7 +146,11 @@ def cutlass_tensor_from_buffer( def trace( fn_src: str, +<<<<<<< HEAD example_tensors: dict[str, python_cutlass.backend.evt.ir.tensor.Tensor], +======= + example_tensors: dict[str, CutlassTensor], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) accum_type: DataType, output_type: DataType, tile_description: TileDescription, @@ -107,6 +158,7 @@ def trace( name_to_buffer: dict[str, Buffer], size_hint_fn: Callable[[Union[Expr, int]], int], **kwargs: dict[str, Any], +<<<<<<< HEAD ) -> tuple[str, str, str, EVTArgRenames]: cuda_arch = int(cuda_env.get_cuda_arch()) # type: ignore[arg-type] assert cuda_arch >= 90, "Only SM90+ is supported for EVT" @@ -133,26 +185,54 @@ def trace( epilogue_functor, name_to_buffer, size_hint_fn ) return evt_name, evt_args, evt_code, arg_renames +======= + ) -> tuple[str, str, str]: + cuda_arch = int(cuda_env.get_cuda_arch()) # type: ignore[arg-type] + assert cuda_arch >= 90, "Only SM90+ is supported for EVT" + epilogue_functor = _trace(fn_src, example_tensors, cuda_arch, **kwargs) + visitor = EpilogueFunctorVisitor(cuda_arch, epilogue_functor) + fusion_callbacks = FusionCallbacks(visitor.graph, cuda_arch, emit_CD=False) + collective_epilogue = CollectiveEpilogue( + tile_description, + epilogue_schedule, + accum_type, + output_type, + fusion_callbacks, + ) + evt_name, evt_code = collective_epilogue.emit() + evt_args = _render_argument_type(epilogue_functor, name_to_buffer, size_hint_fn) + return evt_name, evt_args, evt_code +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Based off of # https://github.com/NVIDIA/cutlass/blob/df18f5e4f5de76bed8be1de8e4c245f2f5ec3020/python/cutlass/epilogue/epilogue.py#L117 # This is modified to enable directly passing the source code of the epilogue vs getting it from a bona-fide python function # The reason for this is that inspect.getsource does not work with functions defined at runtime via exec/eval def _trace( +<<<<<<< HEAD fn_src: str, example_tensors: dict[str, python_cutlass.backend.evt.ir.tensor.Tensor], cc: int, **kwargs: Any, ) -> EpilogueFunctor: class EpilogueFunctor(python_cutlass.backend.evt.frontend.PythonASTFrontend): +======= + fn_src: str, example_tensors: dict[str, CutlassTensor], cc: int, **kwargs: Any + ) -> EpilogueFunctor: + class EpilogueFunctor(PythonASTFrontend): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __init__(self, cc: int, **kwargs: Any): self.source = textwrap.dedent(fn_src) super().__init__(cc, **kwargs) +<<<<<<< HEAD def parse( self, example_inputs: dict[str, python_cutlass.backend.evt.ir.tensor.Tensor], ) -> None: +======= + def parse(self, example_inputs: dict[str, CutlassTensor]) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.example_inputs = example_inputs self.ast = ast.parse(self.source) self.visit(self.ast) @@ -166,6 +246,7 @@ def _render_argument_type( epilogue_functor: EpilogueFunctor, name_to_buffer: dict[str, Buffer], size_hint_fn: Callable[[Union[Expr, int]], int], +<<<<<<< HEAD ) -> tuple[str, EVTArgRenames]: epilogue_thread_type = epilogue_functor.epilogue_thread_type arg_renames = EVTArgRenames() @@ -176,6 +257,17 @@ def is_nested_visitor_type(t: type) -> bool: "python_cutlass.backend.c_types.visitor_factory..VisitorType", "cutlass.backend.c_types.visitor_factory..VisitorType", } +======= + ) -> str: + epilogue_thread_type = epilogue_functor.epilogue_thread_type + + # Fragile, but this is the only way to guarantee t is expected type because t is a local class + def is_nested_visitor_type(t: type) -> bool: + return ( + ".".join([t.__module__, t.__qualname__]) + == "cutlass.backend.c_types.visitor_factory..VisitorType" + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) buffer = IndentedBuffer() with buffer.set_tabwidth(2): @@ -187,9 +279,13 @@ def render_argument_type(name: str, t: CutlassArgType) -> None: fields = [ ( fname, +<<<<<<< HEAD _get_arg_from_node( ty, name_to_buffer[name], size_hint_fn, arg_renames ), +======= + _get_arg_from_node(ty, name_to_buffer[name], size_hint_fn), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) for fname, ty in t._fields_ ] @@ -220,6 +316,7 @@ def render_thread_type(name: str, t: CutlassArgType) -> None: render_argument_type("thread", epilogue_thread_type) buffer.writeline("}") +<<<<<<< HEAD return buffer.getvalue(), arg_renames def _get_arg_from_node( @@ -227,16 +324,29 @@ def _get_arg_from_node( node: Buffer, size_hint_fn: Callable[[Union[Expr, int]], int], arg_renames: EVTArgRenames, +======= + return buffer.getvalue() + + def _get_arg_from_node( + arg_ty: type, node: Buffer, size_hint_fn: Callable[[Union[Expr, int]], int] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> str: from ..cuda_template import CUTLASSTemplate # Today, arguments are either a pointer to the # node's memory, a stride tuple, the datatype # Once again, need to check for local class type for stride tuple +<<<<<<< HEAD if str(arg_ty) in { ".TupleType'>", ".TupleType'>", }: +======= + if ( + str(arg_ty) + == ".TupleType'>" + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DEFAULT_STRIDE_LEN = 3 assert len(node.get_layout().stride) <= DEFAULT_STRIDE_LEN stride = [size_hint_fn(x) for x in node.get_layout().stride] @@ -255,13 +365,21 @@ def render_stride(x: int) -> str: return f"{{{', '.join([render_stride(x) for x in stride])}}}" elif issubclass(arg_ty, ctypes.c_void_p): +<<<<<<< HEAD name = arg_renames.new_name(node.get_name()) return f"({CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]}*) ({name} + {name}_offset)" +======= + return f"({CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]}*) {node.get_name()}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif ( arg_ty in _CUTLASS_C_DTYPES ): # Assumption: this is the element dtype, this holds for all cutlass ir nodes currently return f"{CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]}(0)" +<<<<<<< HEAD elif issubclass(arg_ty, python_cutlass.backend.c_types.EmptyByte): +======= + elif issubclass(arg_ty, EmptyByte): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return "{}" raise NotImplementedError(f"Unsupported arg type: {arg_ty}") diff --git a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py index 95af1a968a97c..6f834156568e7 100644 --- a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py +++ b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py @@ -396,7 +396,11 @@ def emit(self, operation): "align_a": str(operation.A.alignment), "align_b": str(operation.B.alignment), "align_c": str(operation.C.alignment), +<<<<<<< HEAD "align_d": str(operation.D.alignment), +======= + "align_d": str(operation.C.alignment), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "transform_a": ComplexTransformTag[operation.A.complex_transform], "transform_b": ComplexTransformTag[operation.B.complex_transform], "math_operation": MathOperationTag[ diff --git a/torch/_inductor/codegen/cuda/cutlass_presets.py b/torch/_inductor/codegen/cuda/cutlass_presets.py index 346be534e82e6..a401ab8015579 100644 --- a/torch/_inductor/codegen/cuda/cutlass_presets.py +++ b/torch/_inductor/codegen/cuda/cutlass_presets.py @@ -20,6 +20,7 @@ def gen_cutlass_presets() -> dict[int, dict[str, list[str]]]: if arch == "90": preset = presets[0] preset["0"] = [ +<<<<<<< HEAD r"cutlass3x_sm90_tensorop_.*_128x128x64_2x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", r"cutlass3x_sm90_tensorop_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", r"cutlass3x_sm90_tensorop_.*_128x256x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", @@ -82,6 +83,222 @@ def gen_cutlass_presets() -> dict[int, dict[str, list[str]]]: r"cutlass3x_sm90_tensorop_.*_128x256x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", r"cutlass3x_sm90_tensorop_.*_256x192x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", r"cutlass3x_sm90_tensorop_.*_64x16x64_1x1x1_0_.*_align.*_warpspecialized_epi_nosmem", +======= + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_64x256x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_256x128x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + ] + preset["1111"] = [ + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + ] + preset["2222"] = [ + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_256x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_2x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_256x128x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_256x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_1x1x1_0_.*_align.*_cpasync_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_1x1x1_0_.*_align.*", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_1x2x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_1x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_2x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + ] + preset["3333"] = [ + r"cutlass3x_sm90_tensorop_s64x48x16gemm_.*_64x48x64_1x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_1x1x1_0_.*_align.*_cpasync_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x4x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_2x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_4x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_2x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_4x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_2x2x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_1x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_4x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x4x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_1x2x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_256x128x64_1x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x192x16gemm_.*_256x192x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_256x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_1x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_2x2x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x192x16gemm_.*_256x192x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_1x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + ] + preset["4444"] = [ + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_2x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x128_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x128_1x1x1_0_.*_align.*_cpasync_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x128_1x8x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_2x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x128_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_1x1x1_0_.*_align.*_cpasync_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x160x16gemm_.*_256x160x64_2x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_1x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x192x16gemm_.*_64x192x64_4x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_1x1x1_0_.*_align.*_cpasync_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x128_1x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_256x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x160x16gemm_.*_256x160x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_2x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x160x16gemm_.*_256x160x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x128_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_2x2x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x128_4x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_1x2x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x128_2x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x128_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x128_2x4x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + ] + preset["5555"] = [ + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x128_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x128_2x4x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x32x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_1x1x1_0_.*_align.*_cpasync_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_128x32x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x256_1x2x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x128_1x4x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_1x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x128_2x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x160x16gemm_.*_256x160x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_1x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_256x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x128_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x8x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x128_1x1x1_0_.*_align.*_cpasync_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x128_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_1x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x128_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x128_1x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_1x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x128_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x192x16gemm_.*_128x192x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x128_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_2x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_1x1x1_0_.*_align.*_cpasync_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x128_2x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_1x2x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_2x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x128_4x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_1x8x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x160x16gemm_.*_256x160x64_1x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x256_1x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x192x16gemm_.*_256x192x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] return presets diff --git a/torch/_inductor/codegen/cuda/cutlass_utils.py b/torch/_inductor/codegen/cuda/cutlass_utils.py index 7ca33ea779cc7..94d2bea2f0a89 100644 --- a/torch/_inductor/codegen/cuda/cutlass_utils.py +++ b/torch/_inductor/codegen/cuda/cutlass_utils.py @@ -13,9 +13,13 @@ import sympy import torch +<<<<<<< HEAD from torch._inductor.runtime.runtime_utils import dynamo_timed from torch._inductor.utils import clear_on_fresh_cache from torch.utils._ordered_set import OrderedSet +======= +from torch._inductor.utils import clear_on_fresh_cache +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from ... import config from ...ir import Layout @@ -28,15 +32,19 @@ log = logging.getLogger(__name__) CUTLASS_OPERATION_KIND: str = "gemm" +<<<<<<< HEAD ACCUMULATOR_DTYPES: OrderedSet[torch.dtype] = OrderedSet([torch.float, torch.int32]) XW_DTYPES: OrderedSet[torch.dtype] = OrderedSet( [torch.half, torch.bfloat16, torch.float8_e4m3fn, torch.int8] ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @atexit.register def move_cutlass_compiled_cache() -> None: """Move CUTLASS compiled cache file to the cache directory if it exists.""" +<<<<<<< HEAD if not try_import_cutlass.cache_info().currsize > 0: return @@ -54,6 +62,19 @@ def move_cutlass_compiled_cache() -> None: try: filename = os.path.basename(python_cutlass.CACHE_FILE) shutil.move(python_cutlass.CACHE_FILE, os.path.join(cache_dir(), filename)) +======= + if "cutlass" not in sys.modules: + return + + import cutlass # type: ignore[import-not-found] + + if not os.path.exists(cutlass.CACHE_FILE): + return + + try: + filename = os.path.basename(cutlass.CACHE_FILE) + shutil.move(cutlass.CACHE_FILE, os.path.join(cache_dir(), filename)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) log.debug("Moved CUTLASS compiled cache file to %s", cache_dir()) except OSError as e: log.warning("Failed to move CUTLASS compiled cache file: %s", str(e)) @@ -73,13 +94,24 @@ def try_import_cutlass() -> bool: """ We want to support three ways of passing in CUTLASS: 1. fbcode, handled by the internal build system. +<<<<<<< HEAD 2. User specifies cutlass_dir. The default is ../third_party/cutlass/, +======= + 2. pip install nvidia-cutlass, which provides the cutlass_library package + and the header files in the cutlass_library/source directory. + 3. User specifies cutlass_dir. The default is ../third_party/cutlass/, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) which is the directory when developers build from source. """ if config.is_fbcode(): try: +<<<<<<< HEAD import cutlass_library # type: ignore[import-not-found] import python_cutlass # type: ignore[import-not-found] # noqa: F401 +======= + import cutlass # type: ignore[import-not-found] + import cutlass_library # type: ignore[import-not-found] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) except ImportError as e: log.warning( "Failed to import CUTLASS packages in fbcode: %s, ignoring the CUTLASS backend.", @@ -89,6 +121,37 @@ def try_import_cutlass() -> bool: return True +<<<<<<< HEAD +======= + try: + import cutlass # type: ignore[import-not-found] # noqa: F811 + import cutlass_library # type: ignore[import-not-found] # noqa: F811 + + cutlass_minor_vesion = int(cutlass.__version__.split(".")[1]) + if cutlass_minor_vesion < 7: + log.warning("CUTLASS version < 3.7 is not recommended.") + + log.debug( + "Found cutlass_library in python search path, overriding config.cuda.cutlass_dir" + ) + cutlass_library_dir = os.path.dirname(cutlass_library.__file__) + assert os.path.isdir(cutlass_library_dir), ( + f"{cutlass_library_dir} is not a directory" + ) + config.cuda.cutlass_dir = os.path.abspath( + os.path.join( + cutlass_library_dir, + "source", + ) + ) + + return True + except ModuleNotFoundError: + log.debug( + "cutlass_library not found in sys.path, trying to import from config.cuda.cutlass_dir" + ) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Copy CUTLASS python scripts to a temp dir and add the temp dir to Python search path. # This is a temporary hack to avoid CUTLASS module naming conflicts. # TODO(ipiszy): remove this hack when CUTLASS solves Python scripts packaging structure issues. @@ -128,7 +191,11 @@ def path_join(path0, path1): if tmp_cutlass_full_path not in sys.path: def link_and_append(dst_link, src_path, parent_dir): +<<<<<<< HEAD if os.path.lexists(dst_link): +======= + if os.path.exists(dst_link): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert os.path.islink(dst_link), ( f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again." ) @@ -156,7 +223,11 @@ def link_and_append(dst_link, src_path, parent_dir): ) try: +<<<<<<< HEAD import cutlass # noqa: F401, F811 +======= + import cutlass # noqa: F401 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import cutlass_library.generator # noqa: F401 import cutlass_library.library # noqa: F401 import cutlass_library.manifest # noqa: F401 @@ -287,10 +358,16 @@ def gen_ops() -> dict[Any, Any]: """ Generates all supported CUTLASS operations. """ +<<<<<<< HEAD with dynamo_timed("cutlass_utils.gen_ops"): arch = get_cuda_arch() version = get_cuda_version() return _gen_ops_cached(arch, version) +======= + arch = get_cuda_arch() + version = get_cuda_version() + return _gen_ops_cached(arch, version) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DTYPE_TO_CUTLASS_TYPE = { @@ -356,10 +433,13 @@ def get_accumulator_dtype( Given a pair of input torch dtypes, returns the inferred accumulator torch dtype. """ +<<<<<<< HEAD assert OrderedSet(input_torch_dtypes) <= XW_DTYPES, ( f"{input_torch_dtypes=} is not supported" ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if len(input_torch_dtypes) != 2: return None @@ -380,6 +460,7 @@ def get_accumulator_dtype( torch_dtype = dtype0 if torch_dtype in (torch.float16, torch.bfloat16, torch.float, torch.float8_e4m3fn): +<<<<<<< HEAD accumulator_dtype = torch.float elif torch_dtype == torch.int8: accumulator_dtype = torch.int32 @@ -390,6 +471,12 @@ def get_accumulator_dtype( f"{accumulator_dtype=} is not supported" ) return accumulator_dtype +======= + return torch.float + if torch_dtype == torch.int8: + return torch.int32 + raise NotImplementedError(f"Unsupported data types: {input_torch_dtypes=}") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @functools.lru_cache(32) @@ -464,9 +551,13 @@ def __enter__(self, *args, **kwargs): _compile_method_orig = torch._inductor.codecache.CUDACodeCache.compile +<<<<<<< HEAD def my_compile( source_code, dst_file_ext, extra_args: Optional[list[str]] = None ): +======= + def my_compile(source_code, dst_file_ext): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.sources.append(source_code) return _compile_method_orig(source_code, dst_file_ext) diff --git a/torch/_inductor/codegen/cuda/device_op_overrides.py b/torch/_inductor/codegen/cuda/device_op_overrides.py index 147515e0decfe..aeb0ec4b85fde 100644 --- a/torch/_inductor/codegen/cuda/device_op_overrides.py +++ b/torch/_inductor/codegen/cuda/device_op_overrides.py @@ -4,6 +4,10 @@ import torch +<<<<<<< HEAD +======= +from ...utils import triton_version_uses_attrs_dict +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from ..common import ( DeviceOpOverrides, register_device_op_overrides, @@ -332,6 +336,7 @@ def cpp_kernel_type(self) -> str: def cpp_device_ptr(self) -> str: return "CUdeviceptr" +<<<<<<< HEAD def cpp_scratch( self, idx: int, workspace: TritonScratchWorkspace, prefix: Optional[str] = None ) -> Optional[tuple[list[str], str]]: @@ -359,6 +364,36 @@ def cpp_scratch( ) else: return [f"CUdeviceptr {var_name} = 0;"], var_name +======= + def cpp_global_scratch( + self, idx: int, workspace: TritonScratchWorkspace + ) -> Optional[tuple[list[str], str]]: + if triton_version_uses_attrs_dict(): + var_name = f"global_scratch_{idx}" + if workspace.size > 0: + size_array = f"int64_t {var_name}_size[] = {{{workspace.size}}};" + stride_array = f"int64_t {var_name}_stride[] = {{1}};" + device_type = "cached_torch_device_type_cuda" + device_idx = "device_idx_" + + return ( + [ + f"{size_array}", + f"{stride_array}", + f"AtenTensorHandle {var_name}_handle;", + ( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, {var_name}_size, {var_name}_stride, " + f"{workspace.generate_dtype_str()}, {device_type}, {device_idx}, &{var_name}_handle));" + ), + f"RAIIAtenTensorHandle {var_name}_tensor({var_name}_handle);", + f"CUdeviceptr {var_name} = reinterpret_cast({var_name}_tensor.data_ptr());", + ], + var_name, + ) + else: + return [f"CUdeviceptr {var_name} = 0;"], var_name + return None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) register_device_op_overrides("cuda", CUDADeviceOpOverrides()) diff --git a/torch/_inductor/codegen/cuda/gemm_template.py b/torch/_inductor/codegen/cuda/gemm_template.py index d37e16768adb2..56016fdc511b5 100644 --- a/torch/_inductor/codegen/cuda/gemm_template.py +++ b/torch/_inductor/codegen/cuda/gemm_template.py @@ -10,10 +10,14 @@ import torch import torch.utils._pytree as pytree +<<<<<<< HEAD from torch._inductor.autotune_process import TensorMeta from torch._inductor.codegen.cuda.cutlass_cache import maybe_fetch_ops from torch._inductor.codegen.wrapper import PythonWrapperCodegen from torch._inductor.runtime.runtime_utils import dynamo_timed +======= +from torch._inductor.codegen.cuda.cutlass_cache import maybe_fetch_ops +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.scheduler import BaseSchedulerNode from torch._inductor.select_algorithm import create_inputs_key from torch._inductor.utils import clear_on_fresh_cache @@ -37,6 +41,7 @@ from .cuda_template import CUTLASSTemplate from .cutlass_presets import gen_cutlass_presets from .cutlass_python_evt import CutlassEVTCodegen, scaled_mm_evt +<<<<<<< HEAD from .cutlass_utils import ( ACCUMULATOR_DTYPES, dtype_match, @@ -47,6 +52,12 @@ GemmOperation = Any EVTArgRenames = Any +======= +from .cutlass_utils import torch_dtype_to_cutlass_type + + +GemmOperation = Any +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) log = logging.getLogger(__name__) @@ -565,6 +576,7 @@ def _add_cutlass_gemm_choices( """ ops = self.gen_ops() +<<<<<<< HEAD # pre-computation layout_repr: str = str(layout) @@ -602,6 +614,21 @@ def _add_cutlass_gemm_choices( [node.get_layout() for node in input_nodes], [node.get_stride() for node in input_nodes], ) +======= + for name, op in ops: + for swizzle in inductor_cuda_config.cutlass_max_profiling_swizzle_options: + description = f"{name} swizzle={swizzle}" + self.maybe_append_choice( + choices, description=description, op=op, swizzle=swizzle + ) + + if len(ops) == 0: + input_layouts = [node.get_layout() for node in input_nodes] + input_strides = [node.get_stride() for node in input_nodes] + output_layout = layout + warning_msg = f"No suitable Cutlass GEMM configs found, fallbacks used ( {len(ops)=}, {output_layout=}, {input_layouts=}, {input_strides=} )" # noqa: B950 + log.warning(warning_msg) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) log.debug( "Added %d Cutlass gemm configs.", len(ops), @@ -689,6 +716,7 @@ def layout_match( return CUTLASSGemmTemplate.cutlass_layout(torch_layout) == cutlass_layout @staticmethod +<<<<<<< HEAD def set_layout(tensor_desc: "TensorDescription", torch_layout: ir.Layout) -> None: # type: ignore[name-defined] # noqa: F821 """ Helper method: Sets the layout of a given tensor description to match the given torch layout @@ -698,6 +726,8 @@ def set_layout(tensor_desc: "TensorDescription", torch_layout: ir.Layout) -> Non tensor_desc.layout = CUTLASSGemmTemplate.cutlass_layout(torch_layout) @staticmethod +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def set_alignment(torch_layout, op_element) -> bool: """ Helper method to update the alignment of a given CUTLASS GEMM op operand's element. @@ -838,6 +868,7 @@ def _dtype_match( return True +<<<<<<< HEAD @classmethod def global_filter_ops( cls, @@ -885,6 +916,8 @@ def global_filter_ops( return ops +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def filter_op( self, op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 @@ -902,6 +935,19 @@ def filter_op( have been mutated. """ +<<<<<<< HEAD +======= + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib # type: ignore[import] + + # Skip simt kernels + if ( + op.tile_description.math_instruction.opcode_class + == cutlass_lib.OpcodeClass.Simt + ): + return None + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if op.gemm_kind not in self._get_supported_ops(): return None @@ -916,6 +962,16 @@ def filter_op( if not self._dtype_match(op): return None +<<<<<<< HEAD +======= + # Filter ops by input layouts. + if not ( + self.layout_match(X.get_layout(), op.A.layout) + and self.layout_match(W.get_layout(), op.B.layout) + ): + return None + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Filter ops by alignment. if not self._alignment_match(op): log.debug( @@ -923,6 +979,7 @@ def filter_op( ) return None +<<<<<<< HEAD # only use stream k for static shape if op.tile_scheduler.name == "StreamK": static_shape = PythonWrapperCodegen.statically_known_list_of_ints_or_none( @@ -938,6 +995,11 @@ def filter_op( self.set_layout(op.A, X.get_layout()) self.set_layout(op.B, W.get_layout()) +======= + # Update op. + op = copy.deepcopy(op) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Set output layout. op.D.layout = CUTLASSGemmTemplate.cutlass_layout(self.output_node.get_layout()) @@ -1024,8 +1086,12 @@ def gen_ops(self) -> "list[tuple[str, cutlass_gemm_op.GemmOperation]]": # type: log.debug("Using cached ops for %s", self.cache_key) return self.filtered_ops_cache[self.cache_key] +<<<<<<< HEAD with dynamo_timed("CUTLASSGemmTemplate.maybe_fetch_ops"): maybe_ops = maybe_fetch_ops() +======= + maybe_ops = maybe_fetch_ops() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if maybe_ops is None: log.debug("Cannot fetch ops from cache, generating ops from scratch") full_ops = cutlass_utils.gen_ops() @@ -1034,8 +1100,11 @@ def gen_ops(self) -> "list[tuple[str, cutlass_gemm_op.GemmOperation]]": # type: log.debug("Using cached ops from cache") ops = maybe_ops +<<<<<<< HEAD ops = self.global_filter_ops(ops) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) res: dict[str, cutlass_gemm_op.GemmOperation] = {} start_time = time.time() for op in ops: @@ -1168,10 +1237,13 @@ def render( # type: ignore[override] op = self.swap_XW(op) should_swap_xw = True +<<<<<<< HEAD name_to_buffer = {node.get_name(): node for node in self.input_nodes} # handle the fake output buffer during lowering name_to_buffer[Y.get_name()] = Y # type: ignore[assignment] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if epilogue_nodes or is_scaled_mm: if epilogue_nodes: ( @@ -1183,6 +1255,7 @@ def render( # type: ignore[override] Y.get_name(), epilogue_nodes, V.kernel.removed_buffers ) +<<<<<<< HEAD # TODO: mlazos remove this by returning buffer metadata from # ir_to_evt_python code for name, buf in ( @@ -1192,6 +1265,14 @@ def render( # type: ignore[override] name_to_buffer[name] = buf # type: ignore[assignment] D_output_name = var_name_to_buffer_name["D"] +======= + D_output_name = var_name_to_buffer_name["D"] + name_to_buffer = V.graph.name_to_buffer | V.graph.graph_inputs + for name in V.graph.constants.keys(): + name_to_buffer[name] = V.graph.add_tensor_constant( + V.graph.constants[name], name + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) D_output_buffer = name_to_buffer[D_output_name] Y = D_output_buffer # type: ignore[assignment] # Interestingly, I don't think the rest of the layout matters here since we @@ -1232,11 +1313,18 @@ def render( # type: ignore[override] ) assert acc_dtype, "Could not determine accumulator dtype" +<<<<<<< HEAD evt_name, evt_args, evt_code, evt_arg_renames = self._render_evt( op, evt_py_code, var_name_to_buffer_name, name_to_buffer, +======= + evt_name, evt_args, evt_code = self._render_evt( + op, + evt_py_code, + var_name_to_buffer_name, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Y.get_dtype(), acc_dtype, ) @@ -1249,9 +1337,12 @@ def render( # type: ignore[override] Y, *extra_inputs, ] +<<<<<<< HEAD input_names = [evt_arg_renames.get(name) for name in input_names] output_names = [evt_arg_renames.get(name) for name in output_names] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) names_str = ",".join( ["X", "W", "Bias", *input_names, "Y", *output_names, *extra_names] ) @@ -1272,6 +1363,7 @@ def render( # type: ignore[override] instance_definition, instance_type = self._define_gemm_instance(op, evt_name) +<<<<<<< HEAD options = { "alpha": self.alpha, "beta": self.beta, @@ -1293,6 +1385,29 @@ def render( # type: ignore[override] "op_conf_name": op.configuration_name(), "epilogue_visitor_tree": evt_code, } +======= + options = dict( + alpha=self.alpha, + beta=self.beta, + X=X, + W=W, + Y=Y, + kernel_call_signature=kernel_call_signature, + Bias=Bias, + epilogue_template=epilogue_template, + argument_template=argument_template, + should_swap_xw=should_swap_xw, + template=self, + kernel=kernel, + instance_definition=instance_definition, + instance_type=instance_type, + input_reorder=self.input_reorder, + epilogue_args=evt_args, + test_call_statement=test_call_statement, + op_conf_name=op.configuration_name(), + epilogue_visitor_tree=evt_code, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) options.update(dict(zip(extra_names, extra_inputs))) res = self._template_from_string(self._get_template()).render(**options) if inductor_cuda_config.generate_test_runner and not is_dynamic(X, W, Y, Bias): @@ -1328,17 +1443,27 @@ def test_call_statement( f"(({arg_type}){arg_name}_data.get())" for arg_type, arg_name in zip(arg_types, arg_names) ] +<<<<<<< HEAD return f"{kernel.kernel_name}({', '.join(arguments)}, M, N, K, B, lda, ldb, ldc, ldd, 0, 0, 0, swizzle, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);" # noqa: B950 +======= + return f"{kernel.kernel_name}({', '.join(arguments)}, M, N, K, B, lda, ldb, ldc, ldd, swizzle, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);" # noqa: B950 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _render_evt( self, op: GemmOperation, evt_py_code: str, buffer_renames: dict[str, str], +<<<<<<< HEAD name_to_buffer: dict[str, Buffer], output_dtype: torch.dtype, accumulator_dtype: torch.dtype, ) -> tuple[str, str, str, EVTArgRenames]: # type: ignore[name-defined] # noqa: F821 +======= + output_dtype: torch.dtype, + accumulator_dtype: torch.dtype, + ) -> tuple[str, str, str]: # type: ignore[name-defined] # noqa: F821 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise NotImplementedError("_render_evt in CUTLASSGemmTemplate not implemented") @@ -1497,6 +1622,7 @@ def _render_evt( op: GemmOperation, evt_py_code: str, var_name_to_buffer_name: dict[str, str], +<<<<<<< HEAD name_to_buffer: dict[str, Buffer], output_dtype: torch.dtype, accumulator_dtype: torch.dtype, @@ -1506,12 +1632,35 @@ def _render_evt( acc_dtype = torch_dtype_to_cutlass_type(accumulator_dtype) output_dtype = torch_dtype_to_cutlass_type(output_dtype) +======= + output_dtype: torch.dtype, + accumulator_dtype: torch.dtype, + ) -> tuple[str, str, str]: # type: ignore[name-defined] # noqa: F821 + from .cutlass_lib_extensions.evt_extensions import create_example_tensors, trace + + name_to_buffer = V.graph.name_to_buffer | V.graph.graph_inputs + + for name in V.graph.constants.keys(): + name_to_buffer[name] = V.graph.add_tensor_constant( + V.graph.constants[name], name + ) + + # handle the fake output buffer during lowering + name_to_buffer[self.output_node.get_name()] = self.output_node # type: ignore[assignment] + + acc_dtype = torch_dtype_to_cutlass_type(accumulator_dtype) + output_dtype = torch_dtype_to_cutlass_type(output_dtype) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) examples = create_example_tensors( var_name_to_buffer_name, name_to_buffer, # type: ignore[arg-type] V.graph.sizevars.size_hint, ) +<<<<<<< HEAD evt_name, evt_args, evt_code, arg_renames = trace( +======= + evt_name, evt_args, evt_code = trace( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) evt_py_code, examples, acc_dtype, @@ -1526,7 +1675,10 @@ def _render_evt( evt_name, evt_args, evt_code, +<<<<<<< HEAD arg_renames, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def _shape_match( @@ -1674,6 +1826,7 @@ def render_gemm_arguments( tensors. This operation also implies the M and N dimensions of Bias and GEMM output to be swapped before the function call. """ +<<<<<<< HEAD options = { "alpha": alpha, "beta": beta, @@ -1687,6 +1840,21 @@ def render_gemm_arguments( "N": "N", "epilogue_args": epilogue_args, } +======= + options = dict( + alpha=alpha, + beta=beta, + X=X, + W=W, + Y=Y, + Bias=Bias, + template=self, + kernel=kernel, + M="M", + N="N", + epilogue_args=epilogue_args, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert epilogue_template is not None if should_swap_xw: @@ -1965,6 +2133,7 @@ def render_gemm_arguments( tensors. This operation also implies the M and N dimensions of Bias and GEMM output to be swapped before the function call. """ +<<<<<<< HEAD options = { "instance_type": instance_type, "alpha": alpha, @@ -1980,6 +2149,23 @@ def render_gemm_arguments( "N": "N", "epilogue_args": epilogue_args, } +======= + options = dict( + instance_type=instance_type, + alpha=alpha, + beta=beta, + X=X, + W=W, + Y=Y, + Bias=Bias, + Meta=Meta, + template=self, + kernel=kernel, + M="M", + N="N", + epilogue_args=epilogue_args, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if epilogue_template is None: arguments = self._template_from_string(argument_template).render( diff --git a/torch/_inductor/codegen/cuda/serialization.py b/torch/_inductor/codegen/cuda/serialization.py index a17f04b0a1b5a..d87c52341f6a2 100644 --- a/torch/_inductor/codegen/cuda/serialization.py +++ b/torch/_inductor/codegen/cuda/serialization.py @@ -1,8 +1,16 @@ # mypy: allow-untyped-defs +<<<<<<< HEAD import functools import json from enum import Enum from typing import Any, Optional +======= +import enum +import functools +import json +from enum import Enum +from typing import Optional +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass @@ -29,14 +37,25 @@ class CUTLASSOperationSerializer: ] @classmethod +<<<<<<< HEAD def serialize(cls, operation: "GemmOperation") -> str: # type: ignore[name-defined] # noqa: F821 +======= + def serialize(cls, operation: "GemmOperation"): # type: ignore[name-defined] # noqa: F821 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Serialize a GEMM operation to JSON string. Args: operation: GemmOperation object +<<<<<<< HEAD Returns: str: JSON string representation of the operation +======= + indent: JSON indentation spaces + + Returns: + str: JSON representation of the operation +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ assert operation.__class__.__qualname__ == "GemmOperation", ( "Only GemmOperation objects are supported via the main API" @@ -57,7 +76,11 @@ def deserialize(cls, json_str: str) -> "GemmOperation": # type: ignore[name-def return cls._json_to_gemm_operation(json_dict) @classmethod +<<<<<<< HEAD def _gemm_operation_to_json(cls, operation: "GemmOperation") -> dict[str, Any]: # type: ignore[name-defined] # noqa: F821 +======= + def _gemm_operation_to_json(cls, operation): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Convert GemmOperation to JSON-serializable dict. Args: @@ -119,7 +142,11 @@ def _gemm_operation_to_json(cls, operation: "GemmOperation") -> dict[str, Any]: return result @classmethod +<<<<<<< HEAD def _json_to_gemm_operation(cls, json_dict: dict[str, Any]) -> "GemmOperation": # type: ignore[name-defined] # noqa: F821 +======= + def _json_to_gemm_operation(cls, json_dict): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Convert JSON dict to GemmOperation object. Args: @@ -144,9 +171,15 @@ def _json_to_gemm_operation(cls, json_dict: dict[str, Any]) -> "GemmOperation": gemm_kind = cls._json_to_enum(json_dict["gemm_kind"], GemmKind) arch = json_dict["arch"] tile_description = cls._json_to_tile_description(json_dict["tile_description"]) +<<<<<<< HEAD A = cls._json_to_tensor_description(json_dict.get("A"), "A") B = cls._json_to_tensor_description(json_dict.get("B"), "B") C = cls._json_to_tensor_description(json_dict.get("C"), "C") +======= + A = cls._json_to_tensor_description(json_dict.get("A")) + B = cls._json_to_tensor_description(json_dict.get("B")) + C = cls._json_to_tensor_description(json_dict.get("C")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) element_epilogue = cls._json_to_enum(json_dict["element_epilogue"], DataType) # Get optional parameters with defaults @@ -157,7 +190,11 @@ def _json_to_gemm_operation(cls, json_dict: dict[str, Any]) -> "GemmOperation": swizzling_functor = cls._json_to_enum( json_dict.get("swizzling_functor"), SwizzlingFunctor ) +<<<<<<< HEAD D = cls._json_to_tensor_description(json_dict.get("D"), "D") +======= + D = cls._json_to_tensor_description(json_dict.get("D")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kernel_schedule = cls._json_to_enum( json_dict.get("kernel_schedule"), KernelScheduleType ) @@ -181,7 +218,11 @@ def _json_to_gemm_operation(cls, json_dict: dict[str, Any]) -> "GemmOperation": if "ScaleFactorD" in json_dict and "ScaleFactorVectorSize" in json_dict: ScaleFactorD = { "tensor": cls._json_to_tensor_description( +<<<<<<< HEAD json_dict.get("ScaleFactorD"), "ScaleFactorD" +======= + json_dict.get("ScaleFactorD") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), "vector_size": json_dict.get("ScaleFactorVectorSize"), } @@ -218,26 +259,70 @@ def _json_to_gemm_operation(cls, json_dict: dict[str, Any]) -> "GemmOperation": return operation @classmethod +<<<<<<< HEAD @functools.lru_cache(None) def _tile_description_to_json(cls, tile_desc: "TileDescription") -> str: # type: ignore[name-defined] # noqa: F821 """ Convert TileDescription to JSON string. +======= + def _tile_description_to_json(cls, tile_desc): + """ + Convert TileDescription to JSON dict. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Args: tile_desc: TileDescription object Returns: +<<<<<<< HEAD str: JSON string representation """ +======= + dict: Dictionary representation + """ + if tile_desc is None: + return None + + # Create a dictionary for math_instruction if it exists + math_instruction_dict = None + if ( + hasattr(tile_desc, "math_instruction") + and tile_desc.math_instruction is not None + ): + math_instruction = tile_desc.math_instruction + math_instruction_dict = { + "instruction_shape": math_instruction.instruction_shape, + "element_a": cls._enum_to_json(math_instruction.element_a), + "element_b": cls._enum_to_json(math_instruction.element_b), + "element_accumulator": cls._enum_to_json( + math_instruction.element_accumulator + ), + "opcode_class": cls._enum_to_json(math_instruction.opcode_class), + "math_operation": cls._enum_to_json(math_instruction.math_operation), + } + + # Add element_scale_factor if it exists + if ( + hasattr(math_instruction, "element_scale_factor") + and math_instruction.element_scale_factor is not None + ): + math_instruction_dict["element_scale_factor"] = cls._enum_to_json( + math_instruction.element_scale_factor + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Create the main dictionary with field names matching TileDescription constructor parameters result = { "threadblock_shape": tile_desc.threadblock_shape, "stages": tile_desc.stages, "warp_count": tile_desc.warp_count, +<<<<<<< HEAD "math_instruction": cls._math_instruction_to_json( tile_desc.math_instruction ), +======= + "math_instruction": math_instruction_dict, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "min_compute": tile_desc.minimum_compute_capability, # Store as min_compute for constructor "max_compute": tile_desc.maximum_compute_capability, # Store as max_compute for constructor "cluster_shape": tile_desc.cluster_shape, @@ -251,6 +336,7 @@ def _tile_description_to_json(cls, tile_desc: "TileDescription") -> str: # type ): result["tile_shape"] = tile_desc.tile_shape +<<<<<<< HEAD return json.dumps(result) @classmethod @@ -258,6 +344,12 @@ def _tile_description_to_json(cls, tile_desc: "TileDescription") -> str: # type def _json_to_tile_description( cls, json_dict: Optional[str] ) -> Optional["TileDescription"]: # type: ignore[name-defined] # noqa: F821 +======= + return result + + @classmethod + def _json_to_tile_description(cls, json_dict): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Convert JSON dict to TileDescription object. @@ -270,6 +362,7 @@ def _json_to_tile_description( if json_dict is None: return None +<<<<<<< HEAD tile_dict = json.loads(json_dict) from cutlass_library.library import TileDescription @@ -297,18 +390,95 @@ def _json_to_tile_description( max_compute=max_compute, cluster_shape=cluster_shape, explicit_vector_sizes=tile_dict.get("explicit_vector_sizes"), +======= + from cutlass_library import DataType + from cutlass_library.library import ( + MathInstruction, + MathOperation, + OpcodeClass, + TileDescription, + ) + + # First, reconstruct the math_instruction if it exists + math_instruction_obj = None + if ( + "math_instruction" in json_dict + and json_dict["math_instruction"] is not None + ): + mi_dict = json_dict["math_instruction"] + + # Convert string enum names back to enum values + element_a = cls._json_to_enum(mi_dict["element_a"], DataType) + element_b = cls._json_to_enum(mi_dict["element_b"], DataType) + element_acc = cls._json_to_enum(mi_dict["element_accumulator"], DataType) + + # Get the opcode_class enum + opcode_class = cls._json_to_enum(mi_dict["opcode_class"], OpcodeClass) + + # Get the math_operation enum + math_op = cls._json_to_enum(mi_dict["math_operation"], MathOperation) + + # Create the MathInstruction object + math_instruction_obj = MathInstruction( + instruction_shape=mi_dict["instruction_shape"], + element_a=element_a, + element_b=element_b, + element_accumulator=element_acc, + opcode_class=opcode_class, + math_operation=math_op, + ) + + # Add element_scale_factor if it exists + if ( + "element_scale_factor" in mi_dict + and mi_dict["element_scale_factor"] is not None + ): + math_instruction_obj.element_scale_factor = cls._json_to_enum( + mi_dict["element_scale_factor"], DataType + ) + + # Get compute capability values, checking both naming conventions + min_compute = json_dict.get( + "min_compute", json_dict.get("minimum_compute_capability") + ) + max_compute = json_dict.get( + "max_compute", json_dict.get("maximum_compute_capability") + ) + + # Get cluster shape with default value + cluster_shape = json_dict.get("cluster_shape", [1, 1, 1]) + + # Create the TileDescription object + tile_desc = TileDescription( + threadblock_shape=json_dict["threadblock_shape"], + stages=json_dict["stages"], + warp_count=json_dict["warp_count"], + math_instruction=math_instruction_obj, + min_compute=min_compute, + max_compute=max_compute, + cluster_shape=cluster_shape, + explicit_vector_sizes=json_dict.get("explicit_vector_sizes"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # Set tile_shape if it exists and differs from threadblock_shape if ( +<<<<<<< HEAD "tile_shape" in tile_dict and tile_dict["tile_shape"] != tile_dict["threadblock_shape"] ): tile_desc.tile_shape = tile_dict["tile_shape"] +======= + "tile_shape" in json_dict + and json_dict["tile_shape"] != json_dict["threadblock_shape"] + ): + tile_desc.tile_shape = json_dict["tile_shape"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return tile_desc @classmethod +<<<<<<< HEAD @functools.lru_cache(None) def _math_instruction_to_json( cls, @@ -401,23 +571,36 @@ def _tensor_description_to_json( tensor_desc: Optional["TensorDescription"], # type: ignore[name-defined] # noqa: F821 ) -> Optional[str]: """Convert TensorDescription to JSON string. +======= + def _tensor_description_to_json(cls, tensor_desc): + """Convert TensorDescription to JSON dict. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Args: tensor_desc: TensorDescription object Returns: +<<<<<<< HEAD Optional[str]: JSON string representation or None +======= + dict: Dictionary representation +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ if tensor_desc is None: return None +<<<<<<< HEAD result = { +======= + return { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "element": cls._enum_to_json(tensor_desc.element), "layout": cls._enum_to_json(tensor_desc.layout), "alignment": tensor_desc.alignment, "complex_transform": cls._enum_to_json(tensor_desc.complex_transform), } +<<<<<<< HEAD return json.dumps(result) @classmethod @@ -441,6 +624,21 @@ def _json_to_tensor_description( tensor_dict = json.loads(json_dict) +======= + @classmethod + def _json_to_tensor_description(cls, tensor_json): + """Convert JSON dict to TensorDescription object. + + Args: + tensor_json: Dictionary representation + + Returns: + TensorDescription: Reconstructed object + """ + if tensor_json is None: + return None + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from cutlass_library import DataType from cutlass_library.library import ( ComplexTransform, @@ -448,44 +646,73 @@ def _json_to_tensor_description( TensorDescription, ) +<<<<<<< HEAD element = cls._json_to_enum(tensor_dict["element"], DataType) layout = cls._json_to_enum(tensor_dict["layout"], LayoutType) alignment = tensor_dict["alignment"] complex_transform = cls._json_to_enum( tensor_dict["complex_transform"], ComplexTransform +======= + element = cls._json_to_enum(tensor_json["element"], DataType) + layout = cls._json_to_enum(tensor_json["layout"], LayoutType) + alignment = tensor_json["alignment"] + complex_transform = cls._json_to_enum( + tensor_json["complex_transform"], ComplexTransform +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return TensorDescription(element, layout, alignment, complex_transform) @classmethod +<<<<<<< HEAD @functools.lru_cache(None) def _enum_to_json(cls, enum_value: Optional[Enum]) -> Optional[str]: """Convert enum value to JSON string. +======= + def _enum_to_json(cls, enum_value): + """Convert enum value to JSON dict. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Args: enum_value: Enum value Returns: +<<<<<<< HEAD Optional[str]: JSON string representation or None +======= + dict: Dictionary representation +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ if enum_value is None: return None +<<<<<<< HEAD result = { +======= + assert isinstance(enum_value, enum.Enum) + return { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "type": enum_value.__class__.__name__, "name": enum_value.name, } +<<<<<<< HEAD return json.dumps(result) @classmethod @functools.lru_cache(None) def _json_to_enum(cls, json_dict: Optional[str], enum_class: Any) -> Optional[Enum]: """Convert JSON string to enum value. +======= + @classmethod + def _json_to_enum(cls, json_dict, enum_class): + """Convert JSON dict to enum value. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Format: {name: "EnumName", value: 1} Args: +<<<<<<< HEAD json_dict: JSON string representation enum_class: Target enum class @@ -498,6 +725,17 @@ def _json_to_enum(cls, json_dict: Optional[str], enum_class: Any) -> Optional[En enum_dict = json.loads(json_dict) return enum_class[enum_dict["name"]] +======= + json_dict: Dictionary representation + enum_class: Target enum class + + Returns: + Reconstructed enum value + """ + if json_dict is None or json_dict.get("name", "None") == "None": + return None + return enum_class[json_dict["name"]] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @functools.lru_cache(1) diff --git a/torch/_inductor/codegen/cuda_combined_scheduling.py b/torch/_inductor/codegen/cuda_combined_scheduling.py index cb497284d52f5..b5255f6707df1 100644 --- a/torch/_inductor/codegen/cuda_combined_scheduling.py +++ b/torch/_inductor/codegen/cuda_combined_scheduling.py @@ -11,7 +11,10 @@ SchedulerNode, ) from .cuda.cuda_cpp_scheduling import CUDACPPScheduling +<<<<<<< HEAD from .cutedsl.cutedsl_scheduling import CuteDSLScheduling +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .rocm.rocm_cpp_scheduling import ROCmCPPScheduling from .triton import TritonScheduling @@ -45,7 +48,10 @@ def __init__(self, scheduler: Optional[Scheduler]) -> None: self._triton_scheduling = TritonScheduling(scheduler) self._cuda_cpp_scheduling = CUDACPPScheduling(scheduler) self._rocm_cpp_scheduling = ROCmCPPScheduling(scheduler) +<<<<<<< HEAD self._cutedsl_scheduling = CuteDSLScheduling(scheduler) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_backend_features(self, device: torch.device) -> OrderedSet[BackendFeature]: return self._triton_scheduling.get_backend_features(device) @@ -55,8 +61,11 @@ def choose_node_backend(self, node: BaseSchedulerNode) -> BaseScheduling: return self._cuda_cpp_scheduling if self._rocm_cpp_scheduling.is_rocm_cpp_template(node): return self._rocm_cpp_scheduling +<<<<<<< HEAD if self._cutedsl_scheduling.is_cutedsl_template(node): return self._cutedsl_scheduling +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self._triton_scheduling def can_fuse_vertical( @@ -68,11 +77,14 @@ def can_fuse_vertical( node1 ) or self._cuda_cpp_scheduling.is_cuda_cpp_template(node2): return False +<<<<<<< HEAD # CuteDSL doesn't support vertical fusion currently elif self._cutedsl_scheduling.is_cutedsl_template( node1 ) or self._cutedsl_scheduling.is_cutedsl_template(node2): return False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self._triton_scheduling.can_fuse_vertical(node1, node2) def can_fuse_horizontal( @@ -83,10 +95,13 @@ def can_fuse_horizontal( return self._cuda_cpp_scheduling.can_fuse_horizontal( node1, node2 ) # always False at the moment +<<<<<<< HEAD if self._cutedsl_scheduling.is_cutedsl_template(node): return self._cutedsl_scheduling.can_fuse_horizontal( node1, node2 ) # always False at the moment +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self._triton_scheduling.can_fuse_horizontal(node1, node2) def group_fn( @@ -111,6 +126,7 @@ def codegen_template( return self._rocm_cpp_scheduling.codegen_template( template_node, epilogue_nodes, prologue_nodes ) +<<<<<<< HEAD elif self._cutedsl_scheduling.is_cutedsl_template(template_node): # TODO remove this when we add epilogue support assert not epilogue_nodes @@ -118,6 +134,8 @@ def codegen_template( return self._cutedsl_scheduling.codegen_template( template_node, epilogue_nodes, prologue_nodes ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: return self._triton_scheduling.codegen_template( template_node, epilogue_nodes, prologue_nodes @@ -144,6 +162,7 @@ def benchmark_codegened_module(self, module): return self._triton_scheduling.benchmark_codegened_module(module) def generate_kernel_code_from_nodes( +<<<<<<< HEAD self, nodes: Sequence[Any], benchmark_kernel: bool = False, @@ -151,6 +170,12 @@ def generate_kernel_code_from_nodes( ) -> str: return self._triton_scheduling.generate_kernel_code_from_nodes( nodes, benchmark_kernel, hint_override=hint_override +======= + self, nodes: Sequence[Any], benchmark_kernel: bool = False + ) -> str: + return self._triton_scheduling.generate_kernel_code_from_nodes( + nodes, benchmark_kernel +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def benchmark_combo_kernel( diff --git a/torch/_inductor/codegen/halide.py b/torch/_inductor/codegen/halide.py index f477d16cc7668..5897848e3d075 100644 --- a/torch/_inductor/codegen/halide.py +++ b/torch/_inductor/codegen/halide.py @@ -54,7 +54,10 @@ from collections.abc import Sequence from ..ops_handler import ReductionType, StoreMode +<<<<<<< HEAD from ..shape_propagation import BlockShapeType +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) log = logging.getLogger(__name__) @@ -557,7 +560,10 @@ def masked(mask, body, other): f"hl.cast({result.name}.type(), {halide_constant(other)})", [], bounds=ValueRanges.wrap(other), +<<<<<<< HEAD shape=result.shape, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # TODO(jansel): look into removing the where in the same places triton does return ops.where(new_mask, result, other) @@ -566,10 +572,13 @@ def masked(mask, body, other): def frexp(x): raise NotImplementedError("frexp") +<<<<<<< HEAD @staticmethod def device_assert_async(cond, msg): raise NotImplementedError("device_assert_async") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) HalideOverrides._initialize_pointwise_overrides("halide") @@ -582,9 +591,14 @@ def __init__( name, bounds: ValueRanges[Any], dtype: Optional[torch.dtype] = None, +<<<<<<< HEAD shape: BlockShapeType = None, ) -> None: super().__init__(name, bounds, dtype, shape=shape) +======= + ) -> None: + super().__init__(name, bounds, dtype) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.used_dims: Optional[list[sympy.Symbol]] = None def update_on_args(self, name, args, kwargs): @@ -650,12 +664,21 @@ def eq(left, right): if V.graph.sizevars.statically_known_equals(left, right): return True try: +<<<<<<< HEAD a = V.graph.sizevars.size_hint_or_throw(left) b = V.graph.sizevars.size_hint_or_throw(right) except TypeError: # unbacked symints return False if a == b: V.graph.sizevars.check_equals(left, right) +======= + a = V.graph.sizevars.size_hint(left) + b = V.graph.sizevars.size_hint(right) + except TypeError: # unbacked symints + return False + if a == b: + V.graph.sizevars.guard_equals(left, right) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return a == b @@ -663,15 +686,24 @@ def lt(left, right): if V.graph.sizevars.statically_known_lt(left, right): return True try: +<<<<<<< HEAD a = V.graph.sizevars.size_hint_or_throw(left) b = V.graph.sizevars.size_hint_or_throw(right) +======= + a = V.graph.sizevars.size_hint(left) + b = V.graph.sizevars.size_hint(right) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) except TypeError: # unbacked symints gcd = sympy.gcd(left, right) if gcd == left: return left != right return False if a < b: +<<<<<<< HEAD V.graph.sizevars.check_lt(left, right) +======= + V.graph.sizevars.guard_lt(left, right) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return a < b @@ -709,9 +741,15 @@ def __init__( def dtype_to_str(self, dtype: torch.dtype) -> str: return halide_type(dtype) +<<<<<<< HEAD def create_cse_var(self, name, bounds=None, dtype=None, shape=None): self.body.writeline(f"{name} = hl.Func({name!r})") return HalideCSEVariable(name, bounds, dtype, shape) +======= + def create_cse_var(self, name, bounds=None, dtype=None): + self.body.writeline(f"{name} = hl.Func({name!r})") + return HalideCSEVariable(name, bounds, dtype) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def finalize_indexing(self, indices: Sequence[sympy.Expr]): """ @@ -1203,13 +1241,20 @@ def reduction( assert isinstance(value, HalideCSEVariable) and value.used_dims is not None reduction_vars = OrderedSet(self.reduction_renames) result_var = self.newfunc( +<<<<<<< HEAD [v for v in value.used_dims if v not in reduction_vars], +======= + [v for v in value.used_dims if v not in reduction_vars] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if reduction_vars - OrderedSet(value.used_dims): value = self.genfunc( f"{value}", self.sort_used_dims(OrderedSet((*value.used_dims, *reduction_vars))), +<<<<<<< HEAD shape=value.shape, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) value_str = value.subs_str(self.reduction_renames) default = ir.Reduction.default_accumulator(reduction_type, src_dtype) @@ -1299,9 +1344,13 @@ def scan( else: values.append( self.genfunc( +<<<<<<< HEAD f"{value}", [*value.used_dims, [*self.reduction_renames][:1]], shape=value.shape, +======= + f"{value}", [*value.used_dims, [*self.reduction_renames][:1]] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) ) all_used_dims.update(value.used_dims) @@ -1365,6 +1414,7 @@ def maybe_tuple(x): return tuple(unpack_vars) def genfunc( +<<<<<<< HEAD self, line, used_dims, @@ -1373,12 +1423,22 @@ def genfunc( shape: BlockShapeType = None, ) -> HalideCSEVariable: var = self.cse.generate(self.body, line, bounds=bounds, shape=shape) +======= + self, line, used_dims, *, bounds=ValueRanges.unknown() + ) -> HalideCSEVariable: + var = self.cse.generate(self.body, line, bounds=bounds) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(var, HalideCSEVariable) var.used_dims = used_dims return var +<<<<<<< HEAD def newfunc(self, used_dims, *, shape: BlockShapeType = None) -> HalideCSEVariable: var = self.cse.newvar(shape=shape) +======= + def newfunc(self, used_dims) -> HalideCSEVariable: + var = self.cse.newvar() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(var, HalideCSEVariable) var.used_dims = used_dims return var diff --git a/torch/_inductor/codegen/memory_planning.py b/torch/_inductor/codegen/memory_planning.py index 12d7500975e5b..e33de61e838ff 100644 --- a/torch/_inductor/codegen/memory_planning.py +++ b/torch/_inductor/codegen/memory_planning.py @@ -10,7 +10,10 @@ import sympy import torch +<<<<<<< HEAD from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.utils._ordered_set import OrderedSet from .. import config @@ -143,6 +146,7 @@ class Allocation(AllocationTreeNode): allocated: bool = False pool: Optional[AllocationPool] = None offset: Optional[sympy.Expr] = None +<<<<<<< HEAD earliest_available: Optional[float] = None def __post_init__(self) -> None: @@ -154,6 +158,8 @@ def __post_init__(self) -> None: if has_unbacked_sym: self.earliest_available = self.get_live_ranges().begin +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @property def device(self): @@ -198,9 +204,12 @@ def __repr__(self): f"offset={self.offset})" ) +<<<<<<< HEAD def get_earliest_available(self): return self.earliest_available +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclasses.dataclass class Empty(AllocationTreeNode): @@ -392,6 +401,7 @@ class AllocationPool: names_to_del: list[str] = dataclasses.field(default_factory=list) creation_cache: dict[str, str] = dataclasses.field(default_factory=dict) +<<<<<<< HEAD def __post_init__(self) -> None: for block in self.root.allocations: if isinstance(block, Allocation): @@ -412,6 +422,16 @@ def allocate(self, block: Allocation, is_last: bool): is_last = self.can_expand and is_last if self.root.allocate(block, is_last): self.update_restrict_live_range(block) +======= + def allocate(self, block: Allocation, is_last: bool): + if self.restrict_live_range and not self.restrict_live_range.contains( + block.live_range + ): + return False + + is_last = self.can_expand and is_last + if self.root.allocate(block, is_last): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return True if is_last: @@ -419,6 +439,7 @@ def allocate(self, block: Allocation, is_last: bool): return False +<<<<<<< HEAD def update_restrict_live_range(self, block: Allocation): if block_earliest_available := block.get_earliest_available(): if self.restrict_live_range is None: @@ -435,6 +456,11 @@ def allocate_at_end(self, block): block.mark_allocated() self.root = TemporalSplit([SpatialSplit(self.root, TemporalSplit([block]))]) self.update_restrict_live_range(block) +======= + def allocate_at_end(self, block): + block.mark_allocated() + self.root = TemporalSplit([SpatialSplit(self.root, TemporalSplit([block]))]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return True def finalize(self, name): @@ -448,6 +474,10 @@ def codegen_create(self, wrapper, code: IndentedBuffer): nbytes = self.root.get_symbolic_size() for block in self.root.allocations: if isinstance(block, Allocation) and nbytes == block.get_symbolic_size(): +<<<<<<< HEAD +======= + # optimization: fuse first allocation and pool creation +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) node = block.node code.writeline( wrapper.make_allocation( @@ -458,6 +488,10 @@ def codegen_create(self, wrapper, code: IndentedBuffer): stride=tuple(node.get_stride()), ) ) +<<<<<<< HEAD +======= + self.creation_cache[block.codegen_alloc_from_pool(wrapper)] = self.name +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return else: code.writeline( @@ -615,10 +649,14 @@ def codegen(self, code: IndentedBuffer): pool.codegen_create(self.wrapper, code) pool.names_to_del.extend(self.group.names) +<<<<<<< HEAD alloc_from_pool, allocation_lines_to_write = allocation.codegen_alloc_from_pool( self.wrapper ) code.writelines(allocation_lines_to_write) +======= + alloc_from_pool = allocation.codegen_alloc_from_pool(self.wrapper) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if alloc_from_pool in pool.creation_cache: code.writeline( self.wrapper.make_tensor_alias( diff --git a/torch/_inductor/codegen/mps.py b/torch/_inductor/codegen/mps.py index 32e45bfde48d2..5d1f2474ce798 100644 --- a/torch/_inductor/codegen/mps.py +++ b/torch/_inductor/codegen/mps.py @@ -366,7 +366,11 @@ def randint64( @staticmethod def round(x: CSEVariable) -> str: +<<<<<<< HEAD return f"metal::rint({x})" +======= + return f"metal::round({x})" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @staticmethod def pow(a: CSEVariable, b: CSEVariable) -> str: @@ -421,8 +425,11 @@ def _initialize_special_ops(cls) -> None: # Binary special ops for name in [ "polygamma", +<<<<<<< HEAD "igamma", "igammac", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "zeta", ]: setattr(cls, name, functools.partialmethod(cls._special_binary, name=name)) @@ -435,10 +442,13 @@ def _initialize_special_ops(cls) -> None: "chebyshev_polynomial_w", "hermite_polynomial_h", "hermite_polynomial_he", +<<<<<<< HEAD "shifted_chebyshev_polynomial_t", "shifted_chebyshev_polynomial_u", "shifted_chebyshev_polynomial_v", "shifted_chebyshev_polynomial_w", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ]: setattr( cls, @@ -537,7 +547,11 @@ def _new_idxvar( var_def = "threadgroup " if is_threadgroup else "" var_def += f"{dtype} {var_name}" if elem_count: +<<<<<<< HEAD var_def += f"[{self.sexpr(elem_count)}]" +======= + var_def += f"[{elem_count}]" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if default_value is not None: assert not is_threadgroup, "Thread group var can not have default value" var_def += f" = {default_value}" @@ -587,6 +601,7 @@ def _unwrap_helper(res3: CSEVariable) -> tuple[CSEVariable, ...]: if reduction_idx: reduction_idx += " + " reduction_idx += f"{rd.name} * {acc_buf_size}" +<<<<<<< HEAD if isinstance(rd.numel, sympy.Integer): acc_buf_size *= rd.numel @@ -602,6 +617,10 @@ def _unwrap_helper(res3: CSEVariable) -> tuple[CSEVariable, ...]: if isinstance(acc_buf_size, sympy.Integer) else self.simd_group_size ) +======= + acc_buf_size *= rd.numel + acc_buf_size = min(acc_buf_size, self.max_threadgroup_size) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if reduction_type == "any": acc = self._new_idxvar(dtype) @@ -625,7 +644,13 @@ def _unwrap_helper(res3: CSEVariable) -> tuple[CSEVariable, ...]: if reduction_type in ["prod", "sum"]: acc_dtype = DTYPE_TO_COMPUTATION_DTYPE[src_dtype] +<<<<<<< HEAD acc_buf = self._new_idxvar(acc_dtype, shmem_buf_size) +======= + acc_buf = self._new_idxvar( + acc_dtype, ceildiv(acc_buf_size, self.simd_group_size) + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not self.multistage_reduction_entry: val = value else: @@ -636,6 +661,7 @@ def _unwrap_helper(res3: CSEVariable) -> tuple[CSEVariable, ...]: acc_dtype, default_value=default_val, is_threadgroup=False ) self.compute.splice(f"{val} {reduction_op}= {value};") +<<<<<<< HEAD return self.cse.generate( self.stores, @@ -696,6 +722,55 @@ def _unwrap_helper(res3: CSEVariable) -> tuple[CSEVariable, ...]: self.stores, f"c10::metal::threadgroup_{reduction_type}({data_acc_buf}, {idx_acc_buf}, " f"{val}, {idx_val}, {reduction_idx}, {acc_buf_size_str})", +======= + return self.cse.generate( + self.stores, + f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {val}, {reduction_idx}, {acc_buf_size})", + dtype=DTYPE_TO_COMPUTATION_DTYPE[dtype], + ) + if reduction_type in ["max", "min", "argmin", "argmax"]: + acc_buf = self._new_idxvar(src_dtype, acc_buf_size) + acc_thread_var = f"{acc_buf}[{reduction_idx}]" + src_metal_type = DTYPE_TO_METAL[src_dtype] + if not self.multistage_reduction_entry: + self.compute.splice( + f"{acc_thread_var} = static_cast<{src_metal_type}>({value});" + ) + return self.cse.generate( + self.stores, + f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})", + dtype=dtype, + ) + lim_fn = "lowest" if reduction_type.endswith("max") else "max" + self.indexing_code.writeline( + f"{acc_thread_var} = ::metal::numeric_limits<{src_metal_type}>::{lim_fn}();" + ) + if reduction_type.startswith("arg"): + idx_var = next( + t for t in self.range_tree_nodes.values() if t.is_reduction + ) + idx_acc_buf = self._new_idxvar(torch.long, acc_buf_size) + cmp_op = ">" if reduction_type == "argmax" else "<" + idx_thread_var = f"{idx_acc_buf}[{reduction_idx}]" + self.indexing_code.splice(f"{idx_thread_var} = -1;") + self.compute.splice(f""" + if ({value} {cmp_op} {acc_thread_var}) {{ + {acc_thread_var} = {value}; + {idx_thread_var} = {idx_var.name}; + }} + """) + return self.cse.generate( + self.stores, + f"{idx_acc_buf}[c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})]", + dtype=dtype, + ) + self.compute.writeline( + f"{acc_thread_var} = ::c10::metal::{reduction_type}({acc_thread_var}, {value});" + ) + return self.cse.generate( + self.stores, + f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dtype=dtype, ) if reduction_type == "welford_reduce": @@ -704,7 +779,11 @@ def _unwrap_helper(res3: CSEVariable) -> tuple[CSEVariable, ...]: self.compute.splice(f"{acc_buf}[{reduction_idx}] = {value};") wf_res = self.cse.generate( self.compute, +<<<<<<< HEAD f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size_str})", +======= + f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dtype=torch.float32, ) return _unwrap_helper(wf_res) @@ -735,7 +814,11 @@ def _unwrap_helper(res3: CSEVariable) -> tuple[CSEVariable, ...]: self.compute.writeline(f"{acc_thread_var} = {inp_value};") wf_res = self.cse.generate( self.stores if self.multistage_reduction_entry else self.compute, +<<<<<<< HEAD f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size_str})", +======= + f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dtype=torch.float32, ) return _unwrap_helper(wf_res) @@ -745,14 +828,19 @@ def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry) -> None: index_expr = self.rename_indexing(entry.expr) index_str = self.sexpr(index_expr) # type: ignore[misc] +<<<<<<< HEAD if not entry.is_reduction or ( isinstance(entry.root.numel, sympy.Integer) and entry.root.numel <= self.max_threadgroup_size ): +======= + if not entry.is_reduction or entry.root.numel <= self.max_threadgroup_size: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.indexing_code.writeline( f"{self.index_dtype} {entry.name} = {index_str};" ) return +<<<<<<< HEAD acc_size = ( entry.root.numel @@ -760,10 +848,13 @@ def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry) -> None: else sympy.Symbol(f"{entry.root.prefix}numel", integer=True, positive=True) ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.multistage_reduction_entry.append(entry) # When reducing the tensor whose size exceeds max threadgroup size # loop over extra indices per reduction thread and perform part of the operation # using values in the shared memory +<<<<<<< HEAD # Use floats so that it doesn't do integer division loop_size = (acc_size + float(self.max_threadgroup_size - 1)) // float( @@ -790,6 +881,21 @@ def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry) -> None: or loop_size * self.max_threadgroup_size != acc_size ): self.body.writeline(f"if ({entry.name} >= {acc_size}) break;") +======= + loop_size = ( + entry.root.numel + self.max_threadgroup_size - 1 + ) // self.max_threadgroup_size + self.body.writeline( + f"for(auto {entry.name}_cnt = 0; {entry.name}_cnt < {loop_size}; ++{entry.name}_cnt) {{" + ) + with self.body.indent(): + self.body.writeline( + f"{self.index_dtype} {entry.name} = {loop_size} * {index_str} + {entry.name}_cnt;" + ) + # Check that reduction is performed only within tensor boundary + if loop_size * self.max_threadgroup_size != entry.root.numel: + self.body.writeline(f"if ({entry.name} >= {entry.root.numel}) break;") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def codegen_body(self) -> None: """ @@ -858,6 +964,7 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: total_reduction_size = math.prod( t.numel for t in self.range_trees if t.is_reduction ) +<<<<<<< HEAD # If using dynamic shapes, set the threadgroup size to be the # max possible size threadgroup_size = ( @@ -865,6 +972,9 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: if isinstance(total_reduction_size, sympy.Integer) else self.max_threadgroup_size ) +======= + threadgroup_size = min(total_reduction_size, self.max_threadgroup_size) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) code.writeline( f"[[max_total_threads_per_threadgroup({threadgroup_size})]]" ) @@ -888,6 +998,7 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: code.writeline(f"constant {dtype_str}* {inner},") for outer, inner in self.args.sizevars.items(): code.writeline(f"constant long& {inner},") +<<<<<<< HEAD # Write dynamic values as inputs for idx_var in idx_vars: @@ -896,6 +1007,8 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: else: code.writeline(f"constant long& {idx_var.prefix}numel,") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert len(idx_vars) < 4, "Up to 3 index variables are supported" thread_pos_dtype = ( f"uint{len(idx_vars)}" if len(idx_vars) > 1 else "uint" @@ -930,9 +1043,13 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: return code.getvalue() def call_kernel(self, name: str, node: Any = None) -> None: +<<<<<<< HEAD """ Codegens a call to this kernel """ +======= + """Codegen a call to this kernel""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) wrapper = V.graph.wrapper_code # Make sure sizevars has been computed for v in self.args.sizevars.keys(): @@ -946,6 +1063,7 @@ def call_kernel(self, name: str, node: Any = None) -> None: args = [*self.args.output_buffers.keys(), *self.args.input_buffers.keys()] args = [arg for arg in args if arg not in self.removed_buffers] args += [str(v) for v in self.args.sizevars.keys()] +<<<<<<< HEAD arg_types = [arg_name_to_type[arg] for arg in args] # Add any dynamic ints as inputs @@ -962,6 +1080,10 @@ def call_kernel(self, name: str, node: Any = None) -> None: args.append(str(expr)) arg_types.append(int) +======= + + arg_types = [arg_name_to_type[arg] for arg in args] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) expr_printer = self.cexpr if V.graph.cpp_wrapper else self.pexpr def format_threads(threads: list[str], kwarg: str) -> str: @@ -1008,7 +1130,11 @@ def format_threads(threads: list[str], kwarg: str) -> str: wrapper.generate_kernel_call( name, args, +<<<<<<< HEAD device=torch.device("mps"), +======= + device=torch.device("cpu"), # TODO: Fix me, MPS does not expose streams now +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) triton=False, arg_types=arg_types, ) @@ -1054,15 +1180,26 @@ def define_kernel( # Either using MultiKernel concept or overriding SIMDScheduling.codegen_node_scheduling mps_lib_name = f"mps_lib_{wrapper.next_kernel_suffix()}" +<<<<<<< HEAD kernel_name = f"{mps_lib_name}" wrapper.src_to_kernel[src_code] = kernel_name +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if V.graph.cpp_wrapper: src_code = ( f"at::native::mps::DynamicMetalShaderLibrary {mps_lib_name}" + src_code ) +<<<<<<< HEAD + +======= + kernel_name = f"{mps_lib_name}_func" + else: + kernel_name = f"{mps_lib_name}.generated_kernel" + wrapper.src_to_kernel[src_code] = kernel_name +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) metadata_comment = f"{origins}\n{detailed_origins}" wrapper.define_kernel(mps_lib_name, src_code, metadata_comment, gpu=False) diff --git a/torch/_inductor/codegen/multi_kernel.py b/torch/_inductor/codegen/multi_kernel.py index c7ac48ba0231c..b916f6666ac47 100644 --- a/torch/_inductor/codegen/multi_kernel.py +++ b/torch/_inductor/codegen/multi_kernel.py @@ -4,7 +4,10 @@ import os import pathlib +<<<<<<< HEAD from torch._inductor.ir import MultiTemplateBuffer +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.metrics import get_metric_table, is_metric_table_enabled from torch.utils._ordered_set import OrderedSet @@ -46,9 +49,12 @@ def define_kernel(self, kernels): We should name the multi-kernel differently in these 2 cases. """ +<<<<<<< HEAD # Prevent circular import from ..select_algorithm import TritonTemplateKernel +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kernel_names = tuple(k.kernel_name for k in kernels) if kernel_names in self.subkernel_to_kernel_name: return self.subkernel_to_kernel_name[kernel_names] @@ -62,6 +68,7 @@ def define_kernel(self, kernels): # the second pass of cpp-wrapper. return multi_kernel_name +<<<<<<< HEAD arg_index: dict[int, list[slice]] = {} _, call_args, _, arg_types = kernels[0].args.python_argdefs() if isinstance(kernels[0], TritonTemplateKernel) and isinstance( @@ -93,13 +100,21 @@ def define_kernel(self, kernels): slice_reprs = ", ".join(repr(s) for s in slice_list) buf.writeline(f" {key}: [{slice_reprs}],") buf.writeline("}") +======= + buf = self.kernel_defs + buf.writeline("") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) buf.writeline( f"{multi_kernel_name} = async_compile.multi_kernel({multi_kernel_name!r}, [" ) with buf.indent(): for name in kernel_names: buf.writeline(f"{name},") +<<<<<<< HEAD buf.writeline(f"], arg_index=arg_index, shape_specialize={shape_specialize})") +======= + buf.writeline("])") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if config.triton.autotune_at_compile_time: V.graph.wrapper_code.src_to_kernel["\n".join(kernel_names)] = ( @@ -168,9 +183,12 @@ def call_kernel(self, kernel_name): Collect the union of arguments from all subkernels as the arguments for the multi-kernel. """ +<<<<<<< HEAD # Prevent circular import from ..select_algorithm import TritonTemplateKernel +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert kernel_name == self.kernel_name V.graph.wrapper_code.write_triton_header_once() _, call_args, _, arg_types = self.kernels[0].args.python_argdefs() @@ -184,6 +202,7 @@ def call_kernel(self, kernel_name): # the fast kernel directly kernel_name = MultiKernelCall.lookup_choice(self.kernel_name) +<<<<<<< HEAD if isinstance(self.kernels[0], TritonTemplateKernel) and isinstance( self.kernels[0].output_node, MultiTemplateBuffer ): @@ -204,10 +223,15 @@ def call_kernel(self, kernel_name): self.kernels[0].add_numel_to_call_args(kernel_name, call_args, arg_types) multi_call_args = call_args multi_call_arg_types = arg_types +======= + # numels for all subkernels should be the same. Use kernels[0] here + self.kernels[0].add_numel_to_call_args(kernel_name, call_args, arg_types) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for ws in self.kernels[0].args.workspace_args: V.graph.wrapper_code.generate_workspace_allocation(ws) +<<<<<<< HEAD if V.graph.cpp_wrapper: # We have already selected the best kernel at compile time # so we only have one set of call args. NB: this currently @@ -220,6 +244,13 @@ def call_kernel(self, kernel_name): V.graph.wrapper_code.generate_kernel_call( kernel_name, multi_call_args, arg_types=multi_call_arg_types ) +======= + V.graph.wrapper_code.generate_kernel_call( + kernel_name, + call_args, + arg_types=arg_types, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for ws in reversed(self.kernels[0].args.workspace_args): V.graph.wrapper_code.generate_workspace_deallocation(ws) @@ -266,7 +297,11 @@ class MultiKernelCall: This class is called at run time to actually run the kernel """ +<<<<<<< HEAD def __init__(self, multi_kernel_name, kernels, arg_index, shape_specialize=False): +======= + def __init__(self, multi_kernel_name, kernels): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert len(kernels) >= 2 self._kernels = kernels self.multi_kernel_name = multi_kernel_name @@ -276,7 +311,10 @@ def __init__(self, multi_kernel_name, kernels, arg_index, shape_specialize=False ) == "1" or is_metric_table_enabled("persistent_red_perf") self.picked_kernel = None +<<<<<<< HEAD self.arg_index = arg_index +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if config.triton.multi_kernel > 1: # manually force a subkernel to ease perf testing picked_by_config = config.triton.multi_kernel - 2 @@ -287,6 +325,7 @@ def __init__(self, multi_kernel_name, kernels, arg_index, shape_specialize=False self._recorded = False +<<<<<<< HEAD # This means for each unique shape we will do a separate assessment # for which kernel is the best. This is particularly useful for matmul # kernels where the best kernel can vary based on very small differences @@ -294,6 +333,8 @@ def __init__(self, multi_kernel_name, kernels, arg_index, shape_specialize=False self._shape_specialize = shape_specialize self._shape_cache = {} +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def cache_file_path(self): key = code_hash( ",".join( @@ -351,15 +392,22 @@ def benchmark_sub_kernels(self, *args, **kwargs): be picked. """ +<<<<<<< HEAD def wrap_fn(kernel, index): def inner(): filtered_args = self._get_filtered_args(args, index) args_clone, kwargs_clone = kernel.clone_args(*filtered_args, **kwargs) +======= + def wrap_fn(kernel): + def inner(): + args_clone, kwargs_clone = kernel.clone_args(*args, **kwargs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return kernel.run(*args_clone, **kwargs_clone) return inner return [ +<<<<<<< HEAD benchmarker.benchmark_gpu(wrap_fn(kernel, index), rep=40) for index, kernel in enumerate(self.kernels) ] @@ -379,6 +427,12 @@ def _get_filtered_args(self, args, index): return args return [item for s in self.arg_index[index] for item in args[s]] +======= + benchmarker.benchmark_gpu(wrap_fn(kernel), rep=40) + for kernel in self.kernels + ] + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # record_choice and lookup_choice are helper functions for cpp-wrapper # codegen. The first pass use record_choice to keep the choice and # the second pass do lookup by calling lookup_choice. @@ -415,6 +469,7 @@ def lookup_choice(multi_kernel_name: str) -> str: return V.graph.multi_kernel_to_choice[multi_kernel_name] def run(self, *args, **kwargs): +<<<<<<< HEAD if self._shape_specialize: cache_key = self._get_shape_cache_key(*args, **kwargs) cached_choice = self._get_cached_shape_choice(cache_key) @@ -429,6 +484,8 @@ def run(self, *args, **kwargs): else: self._select_kernel_by_shape(*args, **kwargs) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.picked_kernel is None: timings = self.benchmark_sub_kernels(*args, **kwargs) self.picked_kernel = timings.index(min(timings)) @@ -444,7 +501,10 @@ def run(self, *args, **kwargs): get_metric_table("persistent_red_perf").add_row( functools.partial(self._metrics_table_row, timings) ) +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not self.disable_cache: self.store_cache() @@ -455,6 +515,7 @@ def run(self, *args, **kwargs): ) assert picked_kernel_name is not None self.record_choice(self.multi_kernel_name, picked_kernel_name) +<<<<<<< HEAD run = self.kernels[self.picked_kernel].run # type: ignore[method-assign] filtered_args = self._get_filtered_args(args, self.picked_kernel) @@ -491,6 +552,10 @@ def _select_kernel_by_shape(self, *args, **kwargs): timings = self.benchmark_sub_kernels(*args, **kwargs) self.picked_kernel = timings.index(min(timings)) self._cache_shape_choice(shape_key, self.picked_kernel) +======= + self.run = self.kernels[self.picked_kernel].run # type: ignore[method-assign] + self.run(*args, **kwargs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _metrics_table_row(self, timings): def get_kernel_path(k): diff --git a/torch/_inductor/codegen/rocm/ck_conv_template.py b/torch/_inductor/codegen/rocm/ck_conv_template.py index 032b0491a34fd..01906f83432b2 100644 --- a/torch/_inductor/codegen/rocm/ck_conv_template.py +++ b/torch/_inductor/codegen/rocm/ck_conv_template.py @@ -2,6 +2,7 @@ import copy import logging import random +<<<<<<< HEAD from typing import Any from typing_extensions import override @@ -9,6 +10,11 @@ from .rocm_template import ArgInfo +======= + +from torch._inductor.virtualized import V + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: import ck4inductor # type: ignore[import] @@ -287,8 +293,11 @@ def globals(self) -> IndentedBuffer: using ConvolutionForwardSpecialization = ck::tensor_operation::device::ConvolutionForwardSpecialization; +<<<<<<< HEAD using OutElementOp = PassThrough; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace ck { namespace utils { namespace conv { @@ -612,6 +621,7 @@ def size_args(self): right_pads_0, right_pads_1, ) +<<<<<<< HEAD @override def get_runtime_arg_info(self) -> list[ArgInfo]: @@ -623,3 +633,5 @@ def get_runtime_arg_values(self, **kwargs: Any) -> list[Any]: Helper method to retrieve runtime args from generate kwargs """ return [] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_inductor/codegen/rocm/ck_template.py b/torch/_inductor/codegen/rocm/ck_template.py index b1eaf5c228eed..1b3eaac41dfe7 100644 --- a/torch/_inductor/codegen/rocm/ck_template.py +++ b/torch/_inductor/codegen/rocm/ck_template.py @@ -21,10 +21,15 @@ class CKTemplate(ROCmTemplate): torch.bfloat16: "BF16", torch.int32: "I32", torch.int8: "I8", +<<<<<<< HEAD torch.float8_e4m3fnuz: "F8", # gfx94 torch.float8_e4m3fn: "F8", # gfx95 torch.float8_e5m2fnuz: "BF8", # gfx94 torch.float8_e5m2: "BF8", # gfx95 +======= + torch.float8_e4m3fnuz: "F8", + torch.float8_e5m2fnuz: "BF8", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } def header(self) -> IndentedBuffer: diff --git a/torch/_inductor/codegen/rocm/ck_tile_template.py b/torch/_inductor/codegen/rocm/ck_tile_template.py index 70d31d635cc36..391b0aa3ff22e 100644 --- a/torch/_inductor/codegen/rocm/ck_tile_template.py +++ b/torch/_inductor/codegen/rocm/ck_tile_template.py @@ -16,10 +16,15 @@ class CKTileTemplate(ROCmTemplate): torch.bfloat16: "BF16", torch.int32: "I32", torch.int8: "I8", +<<<<<<< HEAD torch.float8_e4m3fnuz: "F8", # gfx94 torch.float8_e4m3fn: "F8", # gfx95 torch.float8_e5m2fnuz: "BF8", # gfx94 torch.float8_e5m2: "BF8", # gfx95 +======= + torch.float8_e4m3fnuz: "F8", + torch.float8_e5m2fnuz: "BF8", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } ck_dtype_to_size = { diff --git a/torch/_inductor/codegen/rocm/ck_tile_universal_gemm_template.py b/torch/_inductor/codegen/rocm/ck_tile_universal_gemm_template.py index b18010bda9086..89444bbee7520 100644 --- a/torch/_inductor/codegen/rocm/ck_tile_universal_gemm_template.py +++ b/torch/_inductor/codegen/rocm/ck_tile_universal_gemm_template.py @@ -242,6 +242,7 @@ class CKTileGemmTemplate(CKTileTemplate): constexpr auto TileK = {{instance_namespace}}::TileK; constexpr auto kPrefetchStages = BaseGemmPipeline::PrefetchStages; +<<<<<<< HEAD const auto BiasTerms = std::array (); const auto BiasStrides = std::array (); @@ -249,13 +250,23 @@ class CKTileGemmTemplate(CKTileTemplate): {X}, {W}, BiasTerms, +======= + auto kargs = ck_tile::GemmKernelArgs { + X, + W, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Y, M, N, K, +<<<<<<< HEAD {LDA}, {LDB}, BiasStrides, +======= + LDA, + LDB, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) LDC, kBatch }; @@ -693,6 +704,7 @@ def render_epilogue(epilogue_type): elif epilogue_type == "CShuffle": return r""" constexpr auto kMemoryOperation = ck_tile::memory_operation_enum::set; +<<<<<<< HEAD using DsDataType = ck_tile::tuple<>; // no bias terms for vanilla GEMM using DsLayout = ck_tile::tuple<>; constexpr auto ELayout = CLayout; @@ -705,6 +717,13 @@ def render_epilogue(epilogue_type): DsLayout, ELayout, CDEElementWise, +======= + using EpilogueProblem = ck_tile::CShuffleEpilogueProblem>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GemmPipelineProblem::kBlockSize, TileM, TileN, diff --git a/torch/_inductor/codegen/rocm/rocm_benchmark_request.py b/torch/_inductor/codegen/rocm/rocm_benchmark_request.py index df4982988aa15..79663c42f3448 100644 --- a/torch/_inductor/codegen/rocm/rocm_benchmark_request.py +++ b/torch/_inductor/codegen/rocm/rocm_benchmark_request.py @@ -96,7 +96,11 @@ def update_workspace_size(self) -> None: return self.ensure_dll_loaded() unique_input_count = len( +<<<<<<< HEAD dict.fromkeys(meta.name for meta in self.input_tensor_meta) +======= + {meta.name for meta in self.input_tensor_meta} # noqa: set_linter +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) args = [c_void_p(None) for _ in range(unique_input_count + 1)] stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream) diff --git a/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py b/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py index 9288f73954ff3..cb82b294b9320 100644 --- a/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py +++ b/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py @@ -86,7 +86,11 @@ def codegen_template( _, (_numel, rnumel) = template_node.group assert rnumel == 1 ctb: ROCmTemplateBuffer = cast(ROCmTemplateBuffer, template_node.node) +<<<<<<< HEAD kernel, render = ctb.make_kernel_render(ctb) # type: ignore[misc] +======= + kernel, render = ctb.make_kernel_render(ctb) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with kernel: template_node.mark_run() src_code = render() diff --git a/torch/_inductor/codegen/rocm/rocm_kernel.py b/torch/_inductor/codegen/rocm/rocm_kernel.py index 5b90823b7f41c..eb5fdc27ebc67 100644 --- a/torch/_inductor/codegen/rocm/rocm_kernel.py +++ b/torch/_inductor/codegen/rocm/rocm_kernel.py @@ -7,6 +7,7 @@ from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu from torch._inductor.utils import do_bench_using_profiling +<<<<<<< HEAD from ...ir import ( Buffer, ChoiceCaller, @@ -16,6 +17,9 @@ ShapeAsConstantBuffer, TensorBox, ) +======= +from ...ir import Buffer, ChoiceCaller, IRNode, Layout, PrimitiveInfoType, TensorBox +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from ...virtualized import V from ..common import Kernel, OpOverrides, WorkspaceArg, WorkspaceZeroMode from ..cpp_utils import CppPrinter @@ -284,7 +288,11 @@ def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType **dict(self.info_kwargs["op"].dict_items()), # type: ignore[union-attr, index] } +<<<<<<< HEAD def output_node(self) -> Union[TensorBox, ShapeAsConstantBuffer]: +======= + def output_node(self) -> TensorBox: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.bmreq.update_workspace_size() return TensorBox.create( ROCmTemplateBuffer( diff --git a/torch/_inductor/codegen/rocm/rocm_utils.py b/torch/_inductor/codegen/rocm/rocm_utils.py index 36871ac5c7f8f..0a762b5c8b1be 100644 --- a/torch/_inductor/codegen/rocm/rocm_utils.py +++ b/torch/_inductor/codegen/rocm/rocm_utils.py @@ -11,7 +11,10 @@ torch.float16: "uint16_t", torch.float8_e4m3fnuz: "uint8_t", torch.float8_e5m2fnuz: "uint8_t", +<<<<<<< HEAD torch.float8_e4m3fn: "uint8_t", torch.float8_e5m2: "uint8_t", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.bfloat16: "uint16_t", } diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index d73db7ed2a227..807ca6885aba7 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -18,7 +18,10 @@ import torch import torch._logging +<<<<<<< HEAD from torch._inductor.ir import MultiTemplateBuffer +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.tiling_utils import analyze_memory_coalescing from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols from torch.fx.immutable_collections import immutable_dict @@ -41,7 +44,10 @@ if TYPE_CHECKING: from ..ir import IRNode +<<<<<<< HEAD from ..debug import set_kernel_post_grad_provenance_tracing +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from ..optimize_indexing import indexing_dtype_strength_reduction from ..runtime.runtime_utils import green_text, yellow_text from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse @@ -52,6 +58,10 @@ IndentedBuffer, Placeholder, prefix_is_reduction, +<<<<<<< HEAD +======= + set_kernel_post_grad_provenance_tracing, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sympy_index_symbol, sympy_product, sympy_subs, @@ -408,7 +418,10 @@ def __init__( else self.should_use_cooperative_reduction() ) self.tiling_scores: Optional[dict[str, sympy.Expr]] = tiling_scores +<<<<<<< HEAD self.tiling: dict[str, sympy.Expr] = tiling +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.persistent_reduction: bool = ( override_persistent_reduction if override_persistent_reduction is not None @@ -757,6 +770,7 @@ def is_compatible( def split_and_set_ranges( self, lengths: Sequence[Sequence[sympy.Expr]] ) -> list[list[sympy.Expr]]: +<<<<<<< HEAD """ Split and set iteration ranges for the kernel based on the provided lengths. @@ -776,16 +790,23 @@ def split_and_set_ranges( # If we're not inside a reduction loop, set all reduction dimensions to 1 # This effectively disables reduction dimensions when not needed +======= + tiling = {rt.prefix: rt.numel for rt in self.range_trees} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not self.inside_reduction: for prefix in tiling: if prefix_is_reduction(prefix): tiling[prefix] = sympy.S.One +<<<<<<< HEAD # Extract the values from the tiling dictionary to create groups groups = [*tiling.values()] # Map the kernel's group structure to the node's sizes and set the ranges # using the set_ranges method, returning the resulting iteration variables +======= + groups = [*tiling.values()] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.map_kernel_groups_to_node_sizes(groups, lengths, self.set_ranges) @classmethod @@ -972,6 +993,7 @@ def _map_tuple_or_scalar(fn, value): return tuple(map(fn, value)) return fn(value) +<<<<<<< HEAD def estimate_flops(self) -> Optional[int]: flops = [ node.estimate_flops() @@ -979,6 +1001,8 @@ def estimate_flops(self) -> Optional[int]: ] return sum(filter(None, flops)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def estimate_kernel_num_bytes(self): """ Try the best to estimate the total size (in bytes) of the @@ -1423,6 +1447,7 @@ def can_use_32bit_indexing( if buf.has_tensor_output() ] +<<<<<<< HEAD for buf in buffers: if not buf.has_tensor_output() and isinstance(buf, ir.MutationOutput): mutated_bufs = buf.get_mutation_buffers() @@ -1432,14 +1457,22 @@ def can_use_32bit_indexing( if buf.has_tensor_output() ] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not all(expr_fits_within_32bit(size) for size in buf_sizes): return False # Only install guards for 32-bit indexing as there is no correctness # issue with using 64-bit for everything +<<<<<<< HEAD V.graph.sizevars.check_leq(numel, int_max) # type: ignore[arg-type] for size in buf_sizes: V.graph.sizevars.check_leq(size, int_max) # type: ignore[arg-type] +======= + V.graph.sizevars.guard_leq(numel, int_max) # type: ignore[arg-type] + for size in buf_sizes: + V.graph.sizevars.guard_leq(size, int_max) # type: ignore[arg-type] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return True def codegen_node_schedule(self, kernel_features: SIMDKernelFeatures): @@ -1459,17 +1492,28 @@ def codegen_node_schedule(self, kernel_features: SIMDKernelFeatures): for kernel in kernels: self.codegen_node_schedule_with_kernel(node_schedule, kernel) MultiKernel.merge_workspaces_inplace(kernels) +<<<<<<< HEAD debug_handles: list[tuple[str, Optional[int]]] = [] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for kernel in kernels: with V.set_kernel_handler(kernel): src_code = kernel.codegen_kernel() kernel_name = self.define_kernel(src_code, node_schedule, kernel) +<<<<<<< HEAD if config.trace.provenance_tracking_level != 0: debug_handle = set_kernel_post_grad_provenance_tracing( node_schedule, # type: ignore[arg-type] kernel_name, ) debug_handles.append((kernel_name, debug_handle)) +======= + if config.trace.enabled: + set_kernel_post_grad_provenance_tracing( + node_schedule, # type: ignore[arg-type] + kernel_name, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) log.debug("Generating kernel code with kernel_name: %s", kernel_name) kernel.kernel_name = kernel_name kernel.code_hash = code_hash(src_code) @@ -1486,10 +1530,13 @@ def codegen_node_schedule(self, kernel_features: SIMDKernelFeatures): node.mark_run() self.codegen_comment(node_schedule) +<<<<<<< HEAD for kernel_name, debug_handle in debug_handles: V.graph.wrapper_code.write_provenance_debug_handle( kernel_name, debug_handle ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) final_kernel.call_kernel(final_kernel.kernel_name) if config.nan_asserts: @@ -1501,7 +1548,11 @@ def codegen_node_schedule(self, kernel_features: SIMDKernelFeatures): V.graph.inplaced_to_remove |= final_kernel.inplaced_to_remove if ( +<<<<<<< HEAD V.graph.wrapper_code.supports_intermediate_hooks # type: ignore[has-type] +======= + V.graph.wrapper_code.supports_intermediate_hooks +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) and config.generate_intermediate_hooks ): # Not every node in the schedule will actually be live on output; @@ -1565,6 +1616,7 @@ def codegen_node_schedule_with_kernel(self, node_schedule, kernel): index_vars = kernel.split_and_set_ranges(node.get_ranges()) node.codegen(index_vars) +<<<<<<< HEAD def _codegen_single_template( self, kernel, @@ -1578,6 +1630,20 @@ def _codegen_single_template( """ Helper method to codegen a single template kernel variant """ +======= + def codegen_template( + self, template_node, epilogue_nodes, prologue_nodes, *, only_gen_src_code=False + ) -> Optional[str]: + """ + Codegen a triton template + + If `only_gen_src_code` the src code will be returned instead of codegen'd into the wrapper + """ + _, (_numel, rnumel) = template_node.group + assert rnumel == 1 + kernel, render = template_node.node.make_kernel_render(template_node.node) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) buf_name_to_prologue_group = {} template_reads = template_node.used_buffer_names() prologue_group = [] @@ -1641,9 +1707,13 @@ def _codegen_single_template( kernel.cse.invalidate(OrderedSet()) if not isinstance(partial_code, str): +<<<<<<< HEAD # This is used to calculate flops in TritonTemplateKernels with ir.IRNode.current_origins(template_node.node.origins): partial_code.finalize_hook("") +======= + partial_code.finalize_hook("") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) partial_code.finalize_hook("", strict=False) # finalize must be called after adding epilogue above @@ -1655,6 +1725,7 @@ def _codegen_single_template( partial_code.finalize_hook(subgraph_name, strict=False) with kernel.set_subgraph_body(""): +<<<<<<< HEAD if not isinstance(partial_code, str): partial_code.finalize_hook("") @@ -1665,6 +1736,13 @@ def _codegen_single_template( # Note: some of these hooks may have been registered by a kernel subclass src_code = partial_code.finalize_remaining() +======= + if isinstance(partial_code, str): + src_code = partial_code + else: + partial_code.finalize_hook("") + src_code = partial_code.code +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) node_schedule = [*prologue_nodes, template_node, *epilogue_nodes] if config.benchmark_kernel: @@ -1678,6 +1756,7 @@ def _codegen_single_template( if only_gen_src_code: return src_code +<<<<<<< HEAD kernel.kernel_name = self.define_kernel(src_code, node_schedule, kernel) if config.trace.provenance_tracking_level != 0: @@ -1785,6 +1864,20 @@ def codegen_template( V.graph.inplaced_to_remove |= kernel.inplaced_to_remove self.free_buffers_in_scheduler() return None +======= + kernel_name = self.define_kernel(src_code, node_schedule, kernel) + + if config.trace.enabled: + set_kernel_post_grad_provenance_tracing(node_schedule, kernel_name) + + self.codegen_comment(node_schedule) + kernel.call_kernel(kernel_name, template_node.node) + + V.graph.removed_buffers |= kernel.removed_buffers + V.graph.inplaced_to_remove |= kernel.inplaced_to_remove + self.free_buffers_in_scheduler() + return None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def codegen_sync(self): V.graph.wrapper_code.writeline(V.graph.device_ops.synchronize()) @@ -1865,7 +1958,11 @@ def codegen_combo_kernel(self, combo_kernel_node): for src_code, kernel, _ in kernel_code_list: kernel_name = self.define_kernel(src_code, [combo_kernel_node], kernel) # dump provenance node info for ComboKernelNode/ForeachKernel type +<<<<<<< HEAD if config.trace.provenance_tracking_level != 0: +======= + if config.trace.enabled: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) set_kernel_post_grad_provenance_tracing( combo_kernel_node.snodes, kernel_name ) @@ -2007,7 +2104,11 @@ def collapse_ranges(ranges: Sequence[sympy.Expr]) -> sympy.Expr: @classmethod def create_tiling( cls, pw_tiling: Sequence[sympy.Expr], reduction_tiling: Sequence[sympy.Expr] +<<<<<<< HEAD ) -> immutable_dict[str, sympy.Expr]: +======= + ) -> dict[str, sympy.Expr]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Create a tiling dict from pointwise and reduction splits. """ @@ -2022,7 +2123,11 @@ def create_partial_tiling( cls, tiling: Sequence[sympy.Expr], is_pointwise: bool, +<<<<<<< HEAD ) -> immutable_dict[str, sympy.Expr]: +======= + ) -> dict[str, sympy.Expr]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return cls.create_tiling( tiling if is_pointwise else [], tiling if not is_pointwise else [], @@ -2034,7 +2139,11 @@ def complete_partial_tiling( tiling: dict[str, sympy.Expr], numel: sympy.Expr, reduction_numel: sympy.Expr, +<<<<<<< HEAD ) -> immutable_dict[str, sympy.Expr]: +======= + ) -> dict[str, sympy.Expr]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Given a tiling for only pointwise or reduction dimensions, adds the missing one. """ @@ -2055,7 +2164,11 @@ def get_nd_tilings( node_schedule, pointwise_numel, reduction_numel, +<<<<<<< HEAD ) -> list[immutable_dict[str, sympy.Expr]]: +======= + ) -> list[dict[str, tuple[sympy.Expr]]]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Creates N-dimensional tiling candidates, attempting to simplify loads/stores by tiling the kernel into higher dimensions. @@ -2063,7 +2176,11 @@ def get_nd_tilings( Returns a list of tilings ranked by dimensionality. """ is_pointwise = reduction_numel == 1 +<<<<<<< HEAD tilings = OrderedSet[immutable_dict[str, sympy.Expr]]() +======= + tilings = OrderedSet[dict[str, sympy.Expr]]() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for node in EnableReduction.filter(node_schedule): if not isinstance(node, scheduler.SchedulerNode): continue @@ -2328,7 +2445,11 @@ def process_node_vars( ) ) +<<<<<<< HEAD tilings: list[tuple[CandidateTiling, immutable_dict[str, sympy.Expr]]] = [] +======= + tilings: list[tuple[CandidateTiling, dict[str, sympy.Expr]]] = [] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (pw_split, pw_score), (red_split, red_score) in score_split: candidate = CandidateTiling( cls.create_tiling(pw_split, red_split), @@ -2557,9 +2678,13 @@ def flush(self): def ready_to_flush(self) -> bool: return False +<<<<<<< HEAD def generate_kernel_code_from_nodes( self, nodes, benchmark_kernel=False, hint_override: Optional[int] = None ): +======= + def generate_kernel_code_from_nodes(self, nodes, benchmark_kernel=False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not any(n.is_template() for n in nodes): _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group node_schedule = self.generate_node_schedule(nodes, numel, rnumel) @@ -2584,7 +2709,10 @@ def generate_kernel_code_from_nodes( epilogue, prologue, only_gen_src_code=True, +<<<<<<< HEAD hint_override=hint_override, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_") diff --git a/torch/_inductor/codegen/subgraph.py b/torch/_inductor/codegen/subgraph.py index 374186c2e2426..bdb108d21a0cb 100644 --- a/torch/_inductor/codegen/subgraph.py +++ b/torch/_inductor/codegen/subgraph.py @@ -1,6 +1,10 @@ import itertools import logging +<<<<<<< HEAD from typing import Any, Callable, Union +======= +from typing import Any, Callable +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch._inductor.config as config @@ -132,7 +136,11 @@ def hash_key(self) -> str: ] ) +<<<<<<< HEAD def output_node(self) -> Union[ir.TensorBox, ir.ShapeAsConstantBuffer]: +======= + def output_node(self) -> ir.TensorBox: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ir.TensorBox.create( ir.SubgraphBuffer( layout=self.layout, @@ -168,6 +176,10 @@ class SubgraphTemplate(KernelTemplate): def __init__( self, name: str, +<<<<<<< HEAD +======= + make_fx_graph: Callable[..., Any], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): """ Initialize a subgraph template. @@ -176,6 +188,7 @@ def __init__( name: The name of this template graph: The FX graph """ +<<<<<<< HEAD super().__init__(name=name) def generate( # type: ignore[override] @@ -185,6 +198,15 @@ def generate( # type: ignore[override] layout: Layout, make_fx_graph: Callable[..., Any], description: str = "", +======= + self.name = f"{name}_{next(SubgraphTemplate.index_counter)}" + self.make_fx_graph = make_fx_graph + + def generate( # type: ignore[override] + self, + input_nodes: list[Buffer], + layout: Layout, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) **kwargs: Any, ) -> SubgraphChoiceCaller: """ @@ -201,9 +223,17 @@ def generate( # type: ignore[override] """ return SubgraphChoiceCaller( +<<<<<<< HEAD name=f"{name}_{next(SubgraphTemplate.index_counter)}", input_nodes=input_nodes, layout=layout, description=description, make_fx_graph=make_fx_graph, +======= + name=self.name, + input_nodes=input_nodes, + layout=layout, + description="", + make_fx_graph=self.make_fx_graph, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 17a336cc3cf2e..5b15d68931283 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -26,7 +26,11 @@ from torch._prims_common import is_integer_dtype from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing +<<<<<<< HEAD from torch.utils._triton import has_triton_package, has_triton_stable_tma_api +======= +from torch.utils._triton import has_triton_package +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT from ...utils._sympy.value_ranges import ValueRanges @@ -105,7 +109,10 @@ from torch._inductor.dtype_propagation import DtypePropagationOpsHandler from ..ir import IRNode +<<<<<<< HEAD from .common import BlockShapeType +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .simd_kernel_features import SIMDKernelFeatures _T = TypeVar("_T") @@ -210,7 +217,10 @@ class IndexingOptions: expand_str: Optional[str] _has_rindex: bool index: sympy.Expr +<<<<<<< HEAD expand_shape: Optional[Sequence[Union[int, str]]] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def has_mask(self) -> bool: return bool(self.mask_vars) @@ -238,6 +248,7 @@ def mask_str(self) -> str: @dataclasses.dataclass +<<<<<<< HEAD class BlockDescriptorOptions: """ This is a base class that describes a block descriptor used in Triton kernels. @@ -245,6 +256,9 @@ class BlockDescriptorOptions: or a block pointer (with BlockPtrOptions). """ +======= +class BlockPtrOptions: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) params: BlockParameters constant_offset: sympy.Expr order: list[int] @@ -270,17 +284,70 @@ def strides(self) -> list[sympy.Expr]: def offsets(self) -> list[sympy.Expr]: return self.params.offsets +<<<<<<< HEAD @classmethod def create( cls, +======= + def codegen_broadcast_and_reshape( + self, + value: str, + initial_shape: Sequence[sympy.Expr], + final_shape: Sequence[sympy.Expr], + allow_implicit: bool, + ) -> str: + """ + Generate a broadcast and a reshape for the block pointer. + This restores stride-0 dimensions which were removed from the block pointer. + """ + + # Reshape to add singletons. + pre_broadcast_shape = [ + sympy.S.One if is_broadcasting else dim + for dim, is_broadcasting in zip( + self.broadcast_shape, self.broadcasting_dims + ) + ] + value = triton_reshape(value, initial_shape, pre_broadcast_shape) + + # Broadcast singletons. + # For loads, we can often implicitly broadcast singleton dimensions. + # We need an explicit broadcast for stores, or if the final reshape does more + # than add singletons. + sizevars = V.graph.sizevars + supports_implicit_broadcast = allow_implicit and ( + len(pre_broadcast_shape) == len(final_shape) + and all( + sizevars.statically_known_equals(pre_dim, 1) + or sizevars.statically_known_equals(pre_dim, post_dim) + for pre_dim, post_dim in zip(pre_broadcast_shape, final_shape) + ) + ) + + if any(self.broadcasting_dims) and not supports_implicit_broadcast: + value = f"tl.broadcast_to({value}, {V.kernel.index_to_str(self.broadcast_shape)})" + + # Reshape to the final shape. + value = triton_reshape(value, self.broadcast_shape, final_shape) + + return value + + @staticmethod + def create( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) *, params: BlockParameters, constant_offset: sympy.Expr, range_trees: list[IterationRangesRoot], mask_vars: OrderedSet[str], get_max_block: Callable[[str], int], +<<<<<<< HEAD ) -> BlockDescriptorOptions: """Helper to create a BlockDescriptorOptions instance""" +======= + ) -> BlockPtrOptions: + """Helper to create a BlockPtrOptions instance""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sizevars = V.graph.sizevars @@ -318,6 +385,7 @@ def lookup_size(exprs: Iterable[sympy.Expr]) -> list[sympy.Expr]: # Combine all removable dims. removable_dims = [any(dims) for dims in zip(singleton_dims, broadcasting_dims)] +<<<<<<< HEAD # Remove singleton_dims from broadcasting_dims so that # broadcast_shape and broadcasting_dims have the same length broadcasting_dims = [ @@ -326,6 +394,8 @@ def lookup_size(exprs: Iterable[sympy.Expr]) -> list[sympy.Expr]: if not is_singleton ] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def remove_dims(it): """Removes any broadcasting or singleton dims from a given sequence""" return [ @@ -354,7 +424,11 @@ def remove_dims(it): # Need to expand rank to match the rank used inside the reduction loop final_shape += [sympy.S.One] * reduction_ndim +<<<<<<< HEAD result = cls( +======= + result = BlockPtrOptions( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) params=params, constant_offset=V.graph.sizevars.lookup_precomputed_size(constant_offset), order=list(reversed(range(len(params.shape)))), @@ -375,10 +449,47 @@ def replace_offset( roffset = TritonSymbols.block_offsets[symt] return sympy_subs(expr, {roffset: replacement}) +<<<<<<< HEAD def remove_roffsets(self, expr: sympy.Expr) -> sympy.Expr: for symt in TritonSymbols.reduction_types: expr = self.replace_offset(expr, sympy.Integer(0), symt) return expr +======= + def format(self, name: str, roffset=True) -> str: + """ + Codegen a call to tl.make_block_ptr() + + Args: + name: variable name for pointer + roffset: should rn_offset be included in offsets=..., for use with tl.advance() + + Returns: + "tl.make_block_ptr(...)" + """ + + def remove_roffsets(expr: sympy.Expr) -> sympy.Expr: + for symt in TritonSymbols.reduction_types: + expr = self.replace_offset(expr, sympy.Integer(0), symt) + return expr + + f = V.kernel.index_to_str + offsets = [*self.offsets] + if not roffset: + offsets = [remove_roffsets(offset) for offset in offsets] + args = [ + ( + f"{name} + ({f(self.constant_offset)})" + if self.constant_offset != 0 + else name + ), + f"shape={f(self.shape)}", + f"strides={f(self.strides)}", + f"block_shape={f(self.block_shape)}", + f"order={f(self.order)}", + f"offsets={f(offsets)}", + ] + return f"tl.make_block_ptr({', '.join(args)})" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def compute_boundary_check( self, @@ -436,6 +547,7 @@ def boundary_check(self) -> list[int]: assert self._boundary_check is not None return self._boundary_check +<<<<<<< HEAD def has_indirect(self) -> bool: return False # block_ptr can't do indirect indexing @@ -572,6 +684,8 @@ def format(self, name: str, roffset=True) -> str: ] return f"tl.make_block_ptr({', '.join(args)})" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def advance_roffset(self, symt: SymT) -> sympy.Expr: """ Codegen string to pass to tl.advance(name, ...). @@ -591,6 +705,27 @@ def advance_roffset(self, symt: SymT) -> sympy.Expr: ] return advance +<<<<<<< HEAD +======= + def has_indirect(self) -> bool: + return False # block_ptr can't do indirect indexing + + def has_rindex(self) -> bool: + return any( + free_symbol_is_type(expr, TritonSymbols.reduction_types) + for expr in self.block_shape + ) + + def has_rmask(self) -> bool: + return self.has_rindex() + + def has_tmpmask(self) -> bool: + return False # block_ptr can't do indirect indexing + + def has_mask(self) -> bool: + return bool(self.boundary_check()) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def triton_reshape( value: str, old_shape: Sequence[sympy.Expr], new_shape: Sequence[sympy.Expr] @@ -836,6 +971,7 @@ def low_precision_fp_var(var: Union[CSEVariable, Any]) -> bool: class TritonCSEVariable(CSEVariable): +<<<<<<< HEAD def __init__( self, name: str, @@ -849,6 +985,13 @@ def __init__( assert dtype is not None, "TritonCSEVariable must have dtype" # TODO: uncomment this and fix the few failures left # assert shape is not None, "TritonCSEVariable must have shape" +======= + def __init__(self, name, bounds: ValueRanges[Any], dtype: torch.dtype) -> None: + super().__init__(name, bounds, dtype) + # We'll use this to track which masks the variable needs when used for indirect indexing + self.mask_vars: OrderedSet[str] = OrderedSet() + assert dtype is not None, "TritonCSEVariable must have dtype" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def update_on_args(self, name, args, kwargs): for arg in args: @@ -969,6 +1112,7 @@ def _get_min_elements_per_thread( if dtype == torch.bool: return f"({x} != 0)" +<<<<<<< HEAD elif dtype == torch.uint8 and ( src_dtype is not None and src_dtype.is_floating_point or src_dtype is None ): @@ -977,6 +1121,12 @@ def _get_min_elements_per_thread( # optimization - if source type is known and it's not a floating type, then # do not apply conversion to the intermediate type. return f"{x}.to(tl.int16).to(tl.uint8)" +======= + elif dtype == torch.uint8: + # to work around llvm uint conversion semantics + # that produces 0's for negative values + return f"{x}.to(tl.int8).to(tl.uint8)" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if use_compute_types: out_dtype = triton_compute_type(dtype) @@ -1062,9 +1212,15 @@ def exp(x): more details. """ if config.use_fast_math: +<<<<<<< HEAD return f"tl_math.exp({x})" else: return f"libdevice.exp({x})" +======= + return f"libdevice.exp2({x} * {TritonOverrides._LOG_2_E})" + else: + return f"tl_math.exp({x})" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @staticmethod @maybe_upcast_float32() @@ -1101,6 +1257,7 @@ def relu(x): @staticmethod def minimum(a, b): +<<<<<<< HEAD if torch.version.hip: return f"tl.minimum({a}, {b}, tl.PropagateNan.ALL)" else: @@ -1112,6 +1269,13 @@ def maximum(a, b): return f"tl.maximum({a}, {b}, tl.PropagateNan.ALL)" else: return f"triton_helpers.maximum({a}, {b})" +======= + return f"triton_helpers.minimum({a}, {b})" + + @staticmethod + def maximum(a, b): + return f"triton_helpers.maximum({a}, {b})" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @staticmethod def where(a, b, c): @@ -1297,10 +1461,14 @@ def load_seed(name, offset): @staticmethod @maybe_upcast_float32() def rsqrt(x): +<<<<<<< HEAD if torch.version.hip: return f"tl.rsqrt({x})" else: return f"libdevice.rsqrt({x})" +======= + return f"libdevice.rsqrt({x})" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @staticmethod @maybe_upcast_float32() @@ -1465,9 +1633,13 @@ def constant(cls, value, dtype): @classmethod def index_expr(cls, expr, dtype): +<<<<<<< HEAD indexing = V.kernel.indexing( expr, block_ptr=False, tma_compatibility_checker=None ) +======= + indexing = V.kernel.indexing(expr, block_ptr=False) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(indexing, IndexingOptions) # Our sympy expr printing casts to the current kernel index dtype. @@ -1484,7 +1656,10 @@ def index_expr(cls, expr, dtype): indexing.index_str, bounds=get_bounds_index_expr(expr), dtype=dtype, +<<<<<<< HEAD shape=indexing.expand_shape, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) finally: config.test_configs.runtime_triton_dtype_assert = orig @@ -1494,7 +1669,10 @@ def index_expr(cls, expr, dtype): V.kernel.compute, cls.to_dtype(var, dtype), dtype=upcast_compute_type(dtype), +<<<<<<< HEAD shape=var.shape, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) else: # TODO: we are not always consistent in enforcing that the output of the index expr printing @@ -1513,7 +1691,10 @@ def index_expr(cls, expr, dtype): V.kernel.compute, cls.to_dtype(var, index_dtype), dtype=index_dtype, +<<<<<<< HEAD shape=var.shape, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) var.mask_vars = indexing.mask_vars @@ -1526,7 +1707,10 @@ def masked(mask, body, other): V.kernel.compute, f"{mask}.to(tl.int1)", dtype=torch.bool, +<<<<<<< HEAD shape=mask.shape, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) nodes = body.graph.find_nodes(op="output") @@ -1557,7 +1741,10 @@ def masked(mask, body, other): f"tl.full({result}.shape, {constant_repr(other)}, {result}.dtype)", bounds=ValueRanges.wrap(other), dtype=result.dtype, +<<<<<<< HEAD shape=result.shape, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) ret = ops.where(new_mask, result, other) else: @@ -1579,18 +1766,26 @@ def frexp(x): if cse_val := V.kernel.cse.try_get(cache_key): return cse_val +<<<<<<< HEAD mantissa = V.kernel.cse.newvar(dtype=x.dtype, shape=x.shape) exponent = V.kernel.cse.newvar(dtype=torch.int32, shape=x.shape) +======= + mantissa = V.kernel.cse.newvar(dtype=x.dtype) + exponent = V.kernel.cse.newvar(dtype=torch.int32) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) V.kernel.compute.writeline( f"{mantissa}, {exponent} = triton_helpers.frexp({x})" ) V.kernel.cse.put(cache_key, (mantissa, exponent)) return (mantissa, exponent) +<<<<<<< HEAD @staticmethod def device_assert_async(cond, msg): return f"tl.device_assert({cond}, {repr(msg)})" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class HelperFunctions: """An ordered set of helper functions.""" @@ -1711,6 +1906,7 @@ def augment_key(self, cache_key: str) -> Union[str, tuple[str, str]]: return cache_key +<<<<<<< HEAD @dataclasses.dataclass class TMACompatibilityChecker: """ @@ -1885,11 +2081,17 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]): triton kernel programmatically """ +======= +class TritonKernel(SIMDKernel[TritonCSEVariable]): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) overrides = TritonKernelOverrides # type: ignore[assignment] helper_functions: HelperFunctions kexpr: Callable[[sympy.Expr], str] = texpr allow_block_ptr = True +<<<<<<< HEAD tma_compatibility_checker_cls = TMACompatibilityChecker +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __init__( self, @@ -1897,7 +2099,10 @@ def __init__( min_elem_per_thread=0, optimize_mask=True, fixed_config: Optional[FixedTritonConfig] = None, +<<<<<<< HEAD hint_override: Optional[int] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) **kwargs, ) -> None: self.optimize_mask: bool = optimize_mask @@ -1914,8 +2119,11 @@ def __init__( self.pointer_advancements: dict[SymT, dict[str, list[sympy.Expr]]] = ( collections.defaultdict(dict) ) +<<<<<<< HEAD self.tma_min_block_sizes = dict[str, int]() self.hint_override = hint_override +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._load_counts: collections.Counter[str] = collections.Counter() # A set of autotuning hints to pass as part of triton_meta @@ -2057,7 +2265,10 @@ def indexing( dense_indexing=False, override_mask=None, block_ptr=False, +<<<<<<< HEAD tma_compatibility_checker: Optional[TMACompatibilityChecker] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): """ Compute the index and mask to pass to tl.load() or tl.store() @@ -2118,6 +2329,7 @@ def indexing( dense_mask_vars.add(f"{tree.prefix}mask") if ( +<<<<<<< HEAD ( (block_ptr and self.allow_block_ptr and config.triton.use_block_ptr) or ( @@ -2125,6 +2337,11 @@ def indexing( and tma_compatibility_checker.can_use_tma() ) ) +======= + block_ptr + and self.allow_block_ptr + and config.triton.use_block_ptr +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) and not override_mask and not self._load_mask and len(mask_vars - dense_mask_vars) == 0 @@ -2254,7 +2471,11 @@ def match_mod_div_block( offsets=block_offsets, ) +<<<<<<< HEAD def match_block_subexpr( +======= + def match_block_pointer_subexpr( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) expr: sympy.Expr, range_tree: IterationRangesRoot ) -> Optional[BlockParameters]: """ @@ -2270,7 +2491,11 @@ def match_block_subexpr( return None +<<<<<<< HEAD def match_block_expr() -> Optional[BlockDescriptorOptions]: +======= + def match_block_pointer() -> Optional[BlockPtrOptions]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) index_relative_to_xyr_index = sympy_subs( index, {v: t.expr for v, t in self.range_tree_nodes.items()} ) @@ -2296,7 +2521,11 @@ def match_block_expr() -> Optional[BlockDescriptorOptions]: return None # Match the subexpression for this range tree. +<<<<<<< HEAD params = match_block_subexpr(subexpr, tree) +======= + params = match_block_pointer_subexpr(subexpr, tree) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if params is None: return None block_params += params @@ -2304,6 +2533,7 @@ def match_block_expr() -> Optional[BlockDescriptorOptions]: # Collect leftover terms as a constant offset. offset = index_relative_to_xyr_index - sum(index_subexprs) +<<<<<<< HEAD # Form the block pointer or TMA descriptor. self.filter_masks(mask_vars) @@ -2313,6 +2543,11 @@ def match_block_expr() -> Optional[BlockDescriptorOptions]: else TensorDescriptorOptions ) options = options_class.create( +======= + # Form the block pointer. + self.filter_masks(mask_vars) + return BlockPtrOptions.create( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) params=block_params, constant_offset=offset, range_trees=range_trees, @@ -2320,6 +2555,7 @@ def match_block_expr() -> Optional[BlockDescriptorOptions]: get_max_block=self.max_block, ) +<<<<<<< HEAD if options_class == TensorDescriptorOptions: nonlocal tma_compatibility_checker tma_compatibility_checker = cast( @@ -2334,15 +2570,25 @@ def match_block_expr() -> Optional[BlockDescriptorOptions]: # Return a block pointer, if indexing matches the pattern. options = match_block_expr() +======= + # Return a block pointer, if indexing matches the pattern. + options = match_block_pointer() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if options is not None: return options expand_str = None +<<<<<<< HEAD expand_shape: BlockShapeType = None index_str = self.index_to_str(index) if isinstance(index, sympy.Integer): expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str() expand_shape = None if copy_shape else tuple(self.dense_size_list()) +======= + index_str = self.index_to_str(index) + if isinstance(index, sympy.Integer): + expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) index_str = f"tl.full({expand_str}, {index_str}, tl.int32)" if self.fixed_config and not self._has_constant_xmask(): mask_vars = OrderedSet(["xmask"]) @@ -2350,6 +2596,7 @@ def match_block_expr() -> Optional[BlockDescriptorOptions]: mask_vars = OrderedSet() if self._load_mask: mask_vars.add(self._load_mask) +<<<<<<< HEAD return IndexingOptions( index_str, mask_vars, @@ -2362,18 +2609,27 @@ def match_block_expr() -> Optional[BlockDescriptorOptions]: if need_dense and not have_dense: expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str() expand_shape = None if copy_shape else tuple(self.dense_size_list()) +======= + return IndexingOptions(index_str, mask_vars, expand_str, has_rindex, index) + + if need_dense and not have_dense: + expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) index_str = f"tl.broadcast_to({index_str}, {expand_str})" mask_vars = dense_mask_vars elif not have_loop_vars and copy_shape: index_str = f"tl.broadcast_to({index_str}, {copy_shape}.shape)" mask_vars = dense_mask_vars +<<<<<<< HEAD if expand_shape is None: if need_dense or have_dense: expand_shape = None if copy_shape else tuple(self.dense_size_list()) else: expand_shape = () +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if override_mask: mask_vars = OrderedSet([override_mask]) @@ -2382,6 +2638,7 @@ def match_block_expr() -> Optional[BlockDescriptorOptions]: self.filter_masks(mask_vars) +<<<<<<< HEAD return IndexingOptions( index_str, mask_vars, @@ -2415,11 +2672,28 @@ def codegen_block_ptr( else: other = f", boundary_check={check!r}" +======= + return IndexingOptions(index_str, mask_vars, expand_str, has_rindex, index) + + def codegen_block_ptr( + self, name: str, var: str, indexing: BlockPtrOptions, other="" + ) -> tuple[str, str]: + check = indexing.boundary_check() + if not check: + # workaround https://github.com/triton-lang/triton/issues/2813 + other = "" + elif other: + assert other == ", other=0.0" + other = f", boundary_check={check!r}, padding_option='zero'" + else: + other = f", boundary_check={check!r}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if ( self.inside_reduction and self.range_trees[-1].is_loop and indexing.has_rindex() ): +<<<<<<< HEAD block_descriptor_id = next(self.block_ptr_id) if isinstance(indexing, BlockPtrOptions): block_descriptor = f"block_ptr{block_descriptor_id}" @@ -2457,12 +2731,47 @@ def codegen_block_ptr( else: block_descriptor = indexing.format(var) return block_descriptor, other +======= + block_ptr = f"block_ptr{next(self.block_ptr_id)}" + self.body.writeline( + DeferredLine( + name, f"{block_ptr} = {indexing.format(var, roffset=False)}" + ) + ) + # Store for later use. If the buffer is removed the below advancements + # are no longer necessary + self.block_ptr_to_buffer[block_ptr] = name + + # Generate block pointer advancements, for later use. + for symt in TritonSymbols.reduction_types: + advance_offsets = indexing.advance_roffset(symt) + + # Ignore identity advancements. + if all( + V.graph.sizevars.statically_known_equals(offset, sympy.Integer(0)) + for offset in advance_offsets + ): + continue + + advancements = self.pointer_advancements[symt] + assert block_ptr not in advancements, ( + "duplicate advancement for pointer '{block_ptr}' at type '{symt}'" + ) + advancements[block_ptr] = advance_offsets + else: + block_ptr = indexing.format(var) + return block_ptr, other +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def codegen_block_ptr_store_line(self, name, indexing, block_ptr, value, other=""): # Stores require an explicit broadcast. We do this in two phases: # 1. Broadcast the operand to the final shape of the range trees, e.g. [ZBLOCK, # YBLOCK, XBLOCK]. This protects against implicit broadcasting from loads. +<<<<<<< HEAD # 2. In case the block pointer / tma descriptor has different dimensionality, broadcast/reshape the +======= + # 2. In case the block pointer has different dimensionality, broadcast/reshape the +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # result to the shape of the pointer. value = f"tl.broadcast_to({value}, {indexing.final_shape})" @@ -2479,9 +2788,13 @@ def codegen_block_ptr_store_line(self, name, indexing, block_ptr, value, other=" # workaround https://github.com/triton-lang/triton/issues/2814 value = f"{value}.to({triton_store_type(V.graph.get_dtype(name))})" +<<<<<<< HEAD if isinstance(indexing, BlockPtrOptions): return f"tl.store({block_ptr}, {value}{other})" return f"{block_ptr}.store({V.kernel.index_to_str(indexing.offsets)}, {value})" +======= + return f"tl.store({block_ptr}, {value}{other})" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def check_bounds( self, @@ -2494,7 +2807,11 @@ def check_bounds( return assert isinstance(expr, sympy.Expr) +<<<<<<< HEAD indexing = self.indexing(expr, block_ptr=False, tma_compatibility_checker=None) +======= + indexing = self.indexing(expr, block_ptr=False) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(indexing, IndexingOptions) index_str = indexing.index_str @@ -2525,15 +2842,19 @@ def get_load_buffer(self, indexing): return self.loads def load(self, name: str, index: sympy.Expr): +<<<<<<< HEAD """ Load from the memory location 'name', offset by some indexing expression 'index'. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) var = self.args.input(name) load_counts = self._load_counts load_counts[name] += 1 make_line: Callable[[str], Union[str, DelayReplaceLine]] = identity indirect_indexing = self.is_indirect_indexing(index) original_index = index +<<<<<<< HEAD dtype = V.graph.get_dtype(name) indexing = self.indexing( index, @@ -2542,6 +2863,9 @@ def load(self, name: str, index: sympy.Expr): self, dtype, for_store=False ), ) +======= + indexing = self.indexing(index, block_ptr=True) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) has_rindex = indexing.has_rindex() has_tmpmask = indexing.has_tmpmask() @@ -2606,7 +2930,11 @@ def decide_later(): cachemod = ", cache_modifier='.cg'" append_broadcast = None +<<<<<<< HEAD shape: BlockShapeType = None +======= + dtype = V.graph.get_dtype(name) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if should_unwrap_unspec_arg(name): line = var @@ -2614,6 +2942,7 @@ def decide_later(): # see triton_utils.py:signature_of if dtype in (torch.float16, torch.bfloat16): dtype = torch.float32 +<<<<<<< HEAD shape = () else: @@ -2636,6 +2965,21 @@ def decide_later(): else: line = f"tl.load({var} + ({indexing.index_str}), {indexing.mask_str}{ep}{other}{cachemod})" shape = indexing.expand_shape +======= + + else: + if isinstance(indexing, BlockPtrOptions): + block_ptr, other = self.codegen_block_ptr(name, var, indexing, other) + line = f"tl.load({block_ptr}{other}{ep}{cachemod})" + line = indexing.codegen_broadcast_and_reshape( + line, indexing.block_shape, indexing.final_shape, True + ) + elif isinstance(original_index, sympy.Integer): + line = f"tl.load({var} + ({original_index}))" + append_broadcast = indexing.expand_str + else: + line = f"tl.load({var} + ({indexing.index_str}), {indexing.mask_str}{ep}{other}{cachemod})" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if ( dtype in (torch.float16, torch.bfloat16) @@ -2651,9 +2995,13 @@ def decide_later(): dtype = torch.bool load_buffer = self.get_load_buffer(indexing) +<<<<<<< HEAD result_var = self.cse.generate( load_buffer, make_line(line), dtype=dtype, shape=shape ) +======= + result_var = self.cse.generate(load_buffer, make_line(line), dtype=dtype) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if result_var.use_count > 1: load_counts[name] -= 1 # don't double count cache hit assert isinstance(result_var, TritonCSEVariable) @@ -2661,9 +3009,13 @@ def decide_later(): if append_broadcast: line = f"tl.broadcast_to({result_var}, {append_broadcast})" +<<<<<<< HEAD result_var = self.cse.generate( load_buffer, line, dtype=dtype, shape=indexing.expand_shape ) +======= + result_var = self.cse.generate(load_buffer, line, dtype=dtype) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if indexing.mask_vars: if dtype.is_floating_point: zero = "0.0" @@ -2675,9 +3027,13 @@ def decide_later(): constant_repr(self._load_other) if self._load_other else zero ) line = f"tl.where({indexing.mask_str}, {result_var}, {other_val})" +<<<<<<< HEAD result_var = self.cse.generate( load_buffer, line, dtype=dtype, shape=result_var.shape ) +======= + result_var = self.cse.generate(load_buffer, line, dtype=dtype) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not self.inside_reduction or (not indexing.has_rmask() and not has_rindex): self.outside_loop_vars.add(result_var) @@ -2689,6 +3045,7 @@ def store( ) -> None: var = self.args.output(name) original_index = index +<<<<<<< HEAD dtype = V.graph.get_dtype(name) tma_compatibility_checker = None @@ -2702,6 +3059,9 @@ def store( block_ptr=mode is None, tma_compatibility_checker=tma_compatibility_checker, ) +======= + indexing = self.indexing(index, dense_indexing=True, block_ptr=mode is None) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Guard against write-after-read corruption in triton. # See # https://github.com/triton-lang/triton/issues/1615 @@ -2714,11 +3074,19 @@ def store( if is_inplace and is_broadcasted: self.stores.writeline(DeferredLine(name, "tl.debug_barrier()")) +<<<<<<< HEAD if isinstance(indexing, (BlockPtrOptions, TensorDescriptorOptions)): block_descriptor, other = self.codegen_block_ptr(name, var, indexing) # block_ptr / tma descriptor stores don't do implicit casting line = self.codegen_block_ptr_store_line( name, indexing, block_descriptor, value, other +======= + if isinstance(indexing, BlockPtrOptions): + block_ptr, other = self.codegen_block_ptr(name, var, indexing) + # block_ptr stores don't do implicit casting + line = self.codegen_block_ptr_store_line( + name, indexing, block_ptr, value, other +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) elif mode is None: line = f"tl.store({var} + ({indexing.index_str}), {value}, {indexing.mask_str})" @@ -2747,6 +3115,7 @@ def guard_cooperative_store(self, name, buffer): buffer.writeline(DeferredLine(name, f"if rsplit_id == ({idx} % RSPLIT):")) return buffer.indent() +<<<<<<< HEAD def _combine_masks(self, *variables: Optional[CSEVariable]): masks = None for elem in variables: @@ -2759,6 +3128,8 @@ def _combine_masks(self, *variables: Optional[CSEVariable]): masks = masks | elem.mask_vars return masks +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def bucketize( self, values: CSEVariable, @@ -2806,12 +3177,17 @@ def bucketize( f"{sorter_indices}, " ")", dtype=indexing_dtype, # type: ignore[attr-defined] +<<<<<<< HEAD shape=values.shape, ) masks = self._combine_masks(values, boundary_indices, sorter_indices) result.mask_vars = masks # type: ignore[attr-defined] +======= + ) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return result def reduction_resize(self, value) -> str: @@ -2823,6 +3199,7 @@ def reduction_resize(self, value) -> str: sizes = [":"] * (ndims - nreduce) + ["None"] * nreduce return f"{value}[{', '.join(sizes)}]" +<<<<<<< HEAD def reduction_resize_and_shape(self, value, shape) -> tuple[str, BlockShapeType]: ndims = self.triton_tensor_ndim() if ndims == 1: @@ -2838,6 +3215,9 @@ def reduction_resize_and_shape(self, value, shape) -> tuple[str, BlockShapeType] def reduction_collapse_dims( self, buffer, value: CSEVariable, dtype: torch.dtype ) -> CSEVariable: +======= + def reduction_collapse_dims(self, buffer, value: str, dtype: torch.dtype) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Reshape to RBLOCK, collapsing all reduction dims. """ @@ -2848,11 +3228,18 @@ def reduction_collapse_dims( target_ndim = self.triton_tensor_ndim() - self.num_reduction_dims initial_shape = self.dense_size_list() target_shape = initial_shape[:target_ndim] + ["RBLOCK"] +<<<<<<< HEAD return self.cse.generate( buffer, triton_reshape(str(value), initial_shape, target_shape), dtype=dtype, shape=tuple(target_shape), +======= + return str( + self.cse.generate( + buffer, triton_reshape(value, initial_shape, target_shape), dtype=dtype + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def reduction( @@ -2903,7 +3290,10 @@ def maybe_upcast(value: CSEVariable) -> CSEVariable: self.compute, f"tl.broadcast_to({v}, {dense_size_str})", dtype=v.dtype, +<<<<<<< HEAD shape=tuple(self.dense_size_list()), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), value, ) @@ -2913,9 +3303,15 @@ def maybe_upcast(value: CSEVariable) -> CSEVariable: def final_reduction( buffer, +<<<<<<< HEAD value: CSEVariable, result_type: Optional[torch.dtype], ) -> tuple[str, Optional[torch.dtype], BlockShapeType]: +======= + value: str, + result_type: Optional[str], + ) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Helper to generate a reduction call, e.g. tl.sum. """ @@ -2924,6 +3320,7 @@ def final_reduction( value = self.reduction_collapse_dims(buffer, value, dtype) if reduction_type in ("max", "min"): +<<<<<<< HEAD result, shape = self.reduction_resize_and_shape( f"{module}.{reduction_type}2({value}, {dim})", value.shape ) @@ -2944,11 +3341,35 @@ def final_reduction_define( result_var: CSEVariable, value: CSEVariable, result_type: Optional[torch.dtype], +======= + value = self.reduction_resize( + f"{module}.{reduction_type}2({value}, {dim})" + ) + else: + value = self.reduction_resize( + f"{module}.{reduction_type}({value}, {dim})" + ) + + if result_type is not None: + value = f"{value}.to({result_type})" + + return value + + def final_reduction_define( + buffer, + result_var: str, + value: str, + result_type: Optional[str], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: """ Generate a reduction and assign it to an existing variable. """ +<<<<<<< HEAD value, _, _ = final_reduction(buffer, value, result_type) +======= + value = final_reduction(buffer, value, result_type) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) buffer.splice(f"{result_var} = {value}") def final_argreduce(buffer, result_var, value, index): @@ -2967,11 +3388,15 @@ def final_argreduce(buffer, result_var, value, index): acc_type = triton_acc_type(src_dtype) torch_acc_type = upcast_acc_dtype(src_dtype) +<<<<<<< HEAD result_shape = list(self.dense_size_list()) result_shape[dim] = "1" result_var: Any = self.cse.newvar( dtype=torch_acc_type, shape=tuple(result_shape) ) +======= + result_var: Any = self.cse.newvar(dtype=torch_acc_type) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) result_var.mask_vars = OrderedSet( var for var in masks if not prefix_is_reduction(var[0]) ) @@ -2988,10 +3413,14 @@ def where_cond(tval, fval): def _mask_value(value, default) -> CSEVariable: return self.cse.generate( +<<<<<<< HEAD self.compute, where_cond(value, default), dtype=value.dtype, shape=value.shape if value.shape is not None else default.shape, +======= + self.compute, where_cond(value, default), dtype=value.dtype +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) masked_value: Union[CSEVariable, Sequence[CSEVariable]] @@ -3005,14 +3434,20 @@ def _mask_value(value, default) -> CSEVariable: masked_value = _mask_value(value, default) if reduction_type in ("argmax", "argmin"): +<<<<<<< HEAD assert isinstance(masked_value, CSEVariable) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) accumulator_dtype = V.kernel.get_index_dtype_as_torch_dtype() accumulator_index = str( self.cse.generate( self.compute, f"tl.broadcast_to({reduction_range_prefix}index, {masked_value}.shape)", dtype=accumulator_dtype, +<<<<<<< HEAD shape=masked_value.shape, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) ) root_op = {"argmax": "max", "argmin": "min"}[reduction_type] @@ -3035,8 +3470,13 @@ def _mask_value(value, default) -> CSEVariable: assert isinstance(masked_value, Sequence) (mean, m2, weight) = masked_value result_var = tuple( +<<<<<<< HEAD self.cse.generate(self.compute, value, dtype=dtype, shape=shape) for value, shape in self._welford( +======= + self.cse.generate(self.compute, value, dtype=dtype) + for value in self._welford( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.compute, mean, m2, weight, dim, dtype ) ) @@ -3046,6 +3486,7 @@ def _mask_value(value, default) -> CSEVariable: result_var = self.prepare_softmax_twopass_fallback(dtype, value) else: assert isinstance(masked_value, CSEVariable) +<<<<<<< HEAD _result, _dtype, _shape = final_reduction( self.compute, masked_value, masked_value.dtype ) @@ -3058,6 +3499,15 @@ def _mask_value(value, default) -> CSEVariable: dtype=torch_acc_type, shape=tuple(self.dense_size_list()), ) +======= + result_var = self.cse.generate( + self.compute, + final_reduction(self.compute, str(masked_value), None), + dtype=masked_value.dtype, + ) + else: + accumulator = self.cse.namedvar(f"_{result_var}", dtype=torch_acc_type) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) default = ir.Reduction.default_accumulator(reduction_type, src_dtype) default = self._map_tuple_or_scalar(constant_repr, default) if not isinstance(default, tuple): @@ -3124,7 +3574,11 @@ def _mask_value(value, default) -> CSEVariable: # reduce. Similar to the final reduction for coopereative # reduction result_max = result_var +<<<<<<< HEAD result_sum = self.cse.newvar(dtype=dtype, shape=result_max.shape) +======= + result_sum = self.cse.newvar(dtype=dtype) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) result_var = self.online_softmax_reduce_final_reduction( self.post_loop_combine, @@ -3149,6 +3603,7 @@ def _mask_value(value, default) -> CSEVariable: # to # tmp5 = triton_helpers.max(_tmp5.to(tl.int8), 1)[:, None].to(tl.int1) # which is needed because tl.reduce doesn't support tl.int1 +<<<<<<< HEAD accumulator = self.cse.generate( self.post_loop_combine, f"{accumulator}.to(tl.int8)", @@ -3159,6 +3614,20 @@ def _mask_value(value, default) -> CSEVariable: final_reduction_define( self.post_loop_combine, result_var, accumulator, None ) +======= + accumulator_casted_str = f"{accumulator}.to(tl.int8)" + result_type = triton_compute_type(dtype) + final_reduction_define( + self.post_loop_combine, + str(result_var), + accumulator_casted_str, + result_type, + ) + else: + final_reduction_define( + self.post_loop_combine, str(result_var), str(accumulator), None + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.cooperative_reduction: default = ir.Reduction.default_accumulator(reduction_type, src_dtype) @@ -3211,7 +3680,10 @@ def _mask_value(value, default) -> CSEVariable: ) elif reduction_type == "online_softmax_reduce": result_max, result_sum = result_var +<<<<<<< HEAD assert isinstance(default, Sequence) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) peer_max = self.codegen_cooperative_reduction_peer_combine( result_max, upcast_acc_dtype(src_dtype), default[0] ) @@ -3231,7 +3703,13 @@ def _mask_value(value, default) -> CSEVariable: peers = self.codegen_cooperative_reduction_peer_combine( result_var, upcast_acc_dtype(src_dtype), default ) +<<<<<<< HEAD final_reduction_define(self.post_loop_store, result_var, peers, None) +======= + final_reduction_define( + self.post_loop_store, str(result_var), peers, None + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) exit_stack.close() self.cse.reduction_cache[cache_key] = result_var @@ -3291,6 +3769,7 @@ def _welford(self, buffer, mean, m2, weight, dim, dtype: torch.dtype): for value in (mean, m2, weight) ) welford = f"triton_helpers.welford({mean}, {m2}, {weight}, {dim})" +<<<<<<< HEAD def reduced_shape(shape): return tuple(shape[0:dim] + shape[dim + 1 :]) @@ -3305,6 +3784,13 @@ def reduced_shape(shape): self.reduction_resize_and_shape(value, value.shape) for value in welford_results ) +======= + welford_results = [str(self.cse.newvar(dtype=dtype)) for _ in range(3)] + buffer.writeline(f"{', '.join(welford_results)} = {welford}") + + result_values = tuple(self.reduction_resize(value) for value in welford_results) + return result_values +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def welford_reduce( self, result_var, reduction_type, value, where_cond, acc_type, dtype @@ -3312,6 +3798,7 @@ def welford_reduce( """Helper to codegen a welford reduction""" dim = self.triton_tensor_ndim() - self.num_reduction_dims +<<<<<<< HEAD accumulator = TritonCSEVariable( f"{result_var}_mean", shape=tuple(self.dense_size_list()), @@ -3330,6 +3817,11 @@ def welford_reduce( dtype=acc_type, bounds=ValueRanges.unknown(), ) +======= + accumulator = f"{result_var}_mean" + accumulator_m2 = f"{result_var}_m2" + accumulator_weight = f"{result_var}_weight" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.body.writeline( f"{accumulator} = tl.zeros({self.dense_size_str()}, {acc_type})" ) @@ -3366,11 +3858,21 @@ def welford_reduce( """ ) result_mean = result_var +<<<<<<< HEAD return self.welford_reduce_final_reduction( self.post_loop_combine, result_mean, None, None, +======= + result_m2 = self.cse.newvar(dtype=dtype) + result_weight = self.cse.newvar(dtype=dtype) + return self.welford_reduce_final_reduction( + self.post_loop_combine, + result_mean, + result_m2, + result_weight, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) accumulator, accumulator_m2, accumulator_weight, @@ -3391,6 +3893,7 @@ def welford_reduce_final_reduction( dtype, ): """Helper to codegen call to triton_helpers.welford""" +<<<<<<< HEAD values = list(self._welford(buffer, mean, m2, weight, dim, dtype)) result_exprs = [result_mean, result_m2, result_weight] @@ -3401,10 +3904,19 @@ def welford_reduce_final_reduction( buffer.splice(f"{result_expr} = {value}") return tuple(result_exprs) +======= + values = self._welford(buffer, mean, m2, weight, dim, dtype) + result_exprs = [result_mean, result_m2, result_weight] + for result_expr, value in zip(result_exprs, values): + buffer.splice(f"{result_expr} = {value}") + + return result_mean, result_m2, result_weight +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def online_softmax_reduce_final_reduction( self, buffer, result_max, result_sum, peer_max, peer_sum, dim, dtype ): +<<<<<<< HEAD accumulator_max = self.reduction_collapse_dims(buffer, peer_max, dtype) accumulator_sum = self.reduction_collapse_dims(buffer, peer_sum, dtype) buffer.splice( @@ -3415,6 +3927,13 @@ def online_softmax_reduce_final_reduction( {result_sum} = {self.reduction_resize(f"{result_sum}")} """ ) +======= + values = self._online_softmax_reduce(buffer, peer_max, peer_sum, dim, dtype) + result_exprs = [result_max, result_sum] + for result_expr, value in zip(result_exprs, values): + buffer.splice(f"{result_expr} = {value}") + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return result_max, result_sum def max_rsplit(self): @@ -3424,7 +3943,11 @@ def max_rsplit(self): def codegen_cooperative_reduction_peer_combine( self, result_var, dtype, default_val +<<<<<<< HEAD ) -> CSEVariable: +======= + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Generate code to save a [XBLOCK, RSPLIT] temporary workspace, where each thread block writes a different column. After the barrier, every thread block loads the completed value so that it can compute the final @@ -3443,6 +3966,7 @@ def codegen_cooperative_reduction_peer_combine( """, strip=True, ) +<<<<<<< HEAD peers = self.create_cse_var( f"{result_var}_peers", shape=["XBLOCK", "RSPLIT"], @@ -3454,6 +3978,13 @@ def codegen_cooperative_reduction_peer_combine( f"rsplit_mask, eviction_policy='evict_first', other=triton_helpers.if_mask(rsplit_mask, {constant_repr(default_val)}))" ) return peers +======= + self.post_loop_store.writeline( + f"{result_var}_peers = tl.load({result_var}_ws + (xindex * RSPLIT + rsplit_arange), " + f"rsplit_mask, eviction_policy='evict_first', other=triton_helpers.if_mask(rsplit_mask, {constant_repr(default_val)}))" + ) + return f"{result_var}_peers" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def store_reduction( self, @@ -3463,6 +3994,7 @@ def store_reduction( ): assert self.inside_reduction self.inside_reduction = False +<<<<<<< HEAD dtype = V.graph.get_dtype(name) indexing = self.indexing( index, @@ -3471,6 +4003,9 @@ def store_reduction( kernel=self, dtype=dtype, for_store=True ), ) +======= + indexing = self.indexing(index, block_ptr=True) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.inside_reduction = True var = self.args.output(name) @@ -3480,7 +4015,11 @@ def store_reduction( self.guard_cooperative_store(name, self.post_loop_store) ) +<<<<<<< HEAD if isinstance(indexing, (BlockPtrOptions, TensorDescriptorOptions)): +======= + if isinstance(indexing, BlockPtrOptions): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.post_loop_store.writeline( DeferredLine( name, @@ -3504,9 +4043,13 @@ def store_reduction( exit_stack.close() +<<<<<<< HEAD def _lift_helper( self, fn, values: tuple[CSEVariable, ...], dtypes: tuple[torch.dtype, ...] ) -> str: +======= + def _lift_helper(self, fn, num_args, dtypes: tuple[torch.dtype, ...]) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Lift IR function for scan operations into a triton function # in the global namespace helper = IndentedBuffer() @@ -3514,10 +4057,14 @@ def _lift_helper( cse = CSE() args = [ +<<<<<<< HEAD tuple( cse.namedvar(f"arg{i}_{n}", dtype=dtype, shape=value.shape) for n, (value, dtype) in enumerate(zip(values, dtypes)) ) +======= + tuple(cse.namedvar(f"arg{i}_{n}", dtype=dtypes[n]) for n in range(num_args)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for i in range(2) ] signature = ", ".join(str(x) for x in itertools.chain.from_iterable(args)) @@ -3532,9 +4079,13 @@ def _lift_helper( helper_name = "_triton_helper_fn" from torch._inductor.dtype_propagation import DtypePropagationOpsHandler +<<<<<<< HEAD from torch._inductor.shape_propagation import ShapePropagationOpsHandler shape_handler = ShapePropagationOpsHandler() +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dtype_handler = DtypePropagationOpsHandler() class CSEProxy(DefaultHandler): @@ -3549,16 +4100,22 @@ def _default( name, )(*args, **kwargs) +<<<<<<< HEAD output_shape = getattr( shape_handler, name, )(*args, **kwargs) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return cse.generate( helper, getattr(overrides, name)(*args, **kwargs), dtype=output_dtype, +<<<<<<< HEAD shape=output_shape, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) with helper.indent(), V.set_ops_handler(CSEProxy()): @@ -3576,9 +4133,12 @@ def scan( ], values: tuple[CSEVariable, ...], ) -> tuple[CSEVariable, ...]: +<<<<<<< HEAD """ Perform an associative scan on 'values'. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert self.inside_reduction assert not self.cooperative_reduction, "TODO" masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees) @@ -3591,7 +4151,11 @@ def scan( dtypes = tuple(upcast_compute_type(dtype) for dtype in dtypes) cse_compute = functools.partial(self.cse.generate, self.compute) +<<<<<<< HEAD combine_helper_fn = self._lift_helper(combine_fn, values, dtypes) +======= + combine_helper_fn = self._lift_helper(combine_fn, len(values), dtypes) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dim = self.triton_tensor_ndim() - self.num_reduction_dims for value, dtype in zip(values, dtypes): @@ -3599,19 +4163,26 @@ def scan( self.compute, f"{value}.to({triton_compute_type(dtype)})", dtype=dtype, +<<<<<<< HEAD shape=value.shape, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) value = self.cse.generate( self.compute, f"tl.broadcast_to({value_dtype}, {self.dense_size_str()})", dtype=dtype, +<<<<<<< HEAD shape=tuple(self.dense_size_list()), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) broadcasted_values.append(value) acc_type = triton_acc_type(dtype) if not self.persistent_reduction: +<<<<<<< HEAD reduced_size = self.dense_size_list() reduced_size[-1] = "1" accumulator = self.cse.newvar(dtype=dtype, shape=reduced_size) @@ -3620,6 +4191,16 @@ def scan( default = "float('nan')" if dtype.is_floating_point else "-1" self.body.writeline( f"{accumulator} = tl.full({reduced_size_str}, {default}, {acc_type})" +======= + accumulator = self.cse.newvar(dtype=dtype) + reduced_size = self.dense_size_list() + reduced_size[-1] = "1" + reduced_size = f"[{', '.join(reduced_size)}]" + + default = "float('nan')" if dtype.is_floating_point else "-1" + self.body.writeline( + f"{accumulator} = tl.full({reduced_size}, {default}, {acc_type})" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) accumulators.append(accumulator) @@ -3632,10 +4213,14 @@ def cse_multiple(line, values, masks, dtypes): cache_keys = [f"{line}, {i}, {masks}" for i in range(n)] if all(self.cse.contains(cache_key) for cache_key in cache_keys): return [self.cse.get(cache_key) for cache_key in cache_keys] +<<<<<<< HEAD result_vars = [ self.cse.newvar(dtype=dtype, shape=value.shape) for (dtype, value) in zip(dtypes, values) ] +======= + result_vars = [self.cse.newvar(dtype=_dtype) for _dtype in dtypes] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.compute.writeline( f"{csv(result_vars)} = {line}", ) @@ -3647,7 +4232,11 @@ def cse_multiple(line, values, masks, dtypes): partial_scan_vars = cse_multiple( f"tl.associative_scan(({csv(broadcasted_values)}), {dim}, {combine_helper_fn})", +<<<<<<< HEAD broadcasted_values, +======= + values, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) masks, dtypes, ) @@ -3656,6 +4245,7 @@ def cse_multiple(line, values, masks, dtypes): # tl.reduce doesn't work for non-commutative operators, so instead # of repeating the scan op as a reduction, we use sum to select the # last scan value +<<<<<<< HEAD def _partial_scan_shape(var): if var.shape is None: return None @@ -3664,11 +4254,16 @@ def _partial_scan_shape(var): shape[-1] = "1" return shape +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) partial_reduce_vars = [ cse_compute( f"triton_helpers.select_one(({partial_scan_var}), rbase == (RBLOCK - 1), dim=-1, keep_dims=True)", dtype=upcast_compute_type(partial_scan_var.dtype), +<<<<<<< HEAD shape=_partial_scan_shape(partial_scan_var), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) for partial_scan_var in partial_scan_vars ] @@ -3678,7 +4273,10 @@ def _partial_scan_shape(var): cse_compute( f"tl.where(roffset > 0, {full_scan}, {partial_scan})", dtype=partial_scan.dtype, +<<<<<<< HEAD shape=partial_scan.shape, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) for full_scan, partial_scan in zip(full_scan_vars, partial_scan_vars) ] @@ -3721,9 +4319,13 @@ def sort( assert len(dtypes) == len(values) broadcasted_values = [ cse_compute( +<<<<<<< HEAD f"tl.broadcast_to({value}, {self.dense_size_str()})", dtype=dtypes[i], shape=tuple(self.dense_size_list()), +======= + f"tl.broadcast_to({value}, {self.dense_size_str()})", dtype=dtypes[i] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) for i, value in enumerate(values) ] @@ -3731,6 +4333,7 @@ def sort( def csv(values): return " ".join(f"{value}," for value in values) +<<<<<<< HEAD def cse_multiple(line, broadcasted_values, masks, dtypes): n = len(broadcasted_values) cache_keys = [f"{line}, {i}, {masks}" for i in range(n)] @@ -3740,6 +4343,13 @@ def cse_multiple(line, broadcasted_values, masks, dtypes): self.cse.newvar(dtype=dtype, shape=value.shape) for dtype, value in zip(dtypes, broadcasted_values) ] # type: ignore[attr-defined] +======= + def cse_multiple(line, n, masks, dtypes): + cache_keys = [f"{line}, {i}, {masks}" for i in range(n)] + if all(self.cse.contains(cache_key) for cache_key in cache_keys): + return [self.cse.get(cache_key) for cache_key in cache_keys] + result_vars = [self.cse.newvar(dtype=dtypes[i]) for i in range(n)] # type: ignore[attr-defined] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.compute.writeline( f"{csv(result_vars)} = {line}", ) @@ -3757,7 +4367,11 @@ def cse_multiple(line, broadcasted_values, masks, dtypes): f"triton_helpers.sort_with_index({broadcasted_values[0]}, {broadcasted_values[1]}," f" {rnumel}, {dim}, stable={stable}, descending={descending})" ) +<<<<<<< HEAD result_vars = cse_multiple(line, broadcasted_values, masks, dtypes) +======= + result_vars = cse_multiple(line, len(values), masks, dtypes) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: raise AssertionError("Unhandled sort") @@ -3797,9 +4411,14 @@ def codegen_body(self): loop_end = ( "rsplit_end" if self.cooperative_reduction else f"{prefix}numel" ) +<<<<<<< HEAD num_stages = ", num_stages = 2" if torch.version.hip else "" self.body.writeline( f"for {prefix}offset in tl.range({loop_start}, {loop_end}, {prefix.upper()}BLOCK{num_stages}):" +======= + self.body.writeline( + f"for {prefix}offset in range({loop_start}, {loop_end}, {prefix.upper()}BLOCK):" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) with self.body.indent(offset=level + 1): self.iteration_ranges_codegen_header(tree, self.body) @@ -3897,13 +4516,21 @@ def codegen_kernel_benchmark(self, num_gb): buf = V.graph.try_get_buffer(arg_name) if buf: result.writeline( +<<<<<<< HEAD f"{var_name} = rand_strided({V.graph.sizevars.size_hints(buf.get_size(), hint_override=self.hint_override)}, {V.graph.sizevars.size_hints(buf.get_stride(), hint_override=self.hint_override)}, device='{buf.get_device()}', dtype={buf.get_dtype()})" # noqa: B950 line too long +======= + f"{var_name} = rand_strided({V.graph.sizevars.size_hints(buf.get_size())}, {V.graph.sizevars.size_hints(buf.get_stride())}, device='{buf.get_device()}', dtype={buf.get_dtype()})" # noqa: B950 line too long +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) elif arg_name in V.graph.constants: # note that random seed is put in V.graph.constants const_tensor = V.graph.constants[arg_name] result.writeline( +<<<<<<< HEAD f"{var_name} = rand_strided({V.graph.sizevars.size_hints(const_tensor.size(), hint_override=self.hint_override)}, {V.graph.sizevars.size_hints(const_tensor.stride(), hint_override=self.hint_override)}, device='{const_tensor.device}', dtype={const_tensor.dtype})" # type: ignore[arg-type] # noqa: B950 line too long +======= + f"{var_name} = rand_strided({V.graph.sizevars.size_hints(const_tensor.size())}, {V.graph.sizevars.size_hints(const_tensor.stride())}, device='{const_tensor.device}', dtype={const_tensor.dtype})" # type: ignore[arg-type] # noqa: B950 line too long +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) elif isinstance(arg_sig, SizeArg): symval_hint = V.graph.sizevars.size_hint(arg_sig.expr) @@ -4035,12 +4662,16 @@ def inductor_meta_common(): ) return inductor_meta +<<<<<<< HEAD def codegen_kernel(self, name=None) -> str: """ Convert the TritonKernel from Inductor SIMD IR to triton code, including inductor triton heuristics, imports, metadata, and benchmarking infra. """ +======= + def codegen_kernel(self, name=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) code = IndentedBuffer() size_hints = {} @@ -4170,8 +4801,13 @@ def add_constexpr_arg(arg_name): optimize_mem = V.graph.is_inference or V.graph.is_backward inductor_meta = { +<<<<<<< HEAD "grid_type": self._get_grid_type().__name__, # Triton will not accept an OrderedSet for autotune_hints +======= + # Triton will not accept an OrderedSet for autotune_hints + "grid_type": self._get_grid_type().__name__, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "autotune_hints": set(self.autotune_hints), # noqa: set_linter "kernel_name": str(Placeholder.DESCRIPTIVE_NAME), "mutated_arg_names": mutated_args, @@ -4181,6 +4817,7 @@ def add_constexpr_arg(arg_name): "num_reduction": self.num_reduction, **self.inductor_meta_common(), } +<<<<<<< HEAD # Bail on 3d tiling, which has more complicated coalesce patterns looped_red = V.kernel.features.is_reduction() and not self.persistent_reduction @@ -4234,18 +4871,27 @@ def add_constexpr_arg(arg_name): if self.tma_min_block_sizes: inductor_meta["tma_min_block_sizes"] = self.tma_min_block_sizes +======= + if self.tiling_scores: + inductor_meta["tiling_scores"] = self.tiling_scores + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.cooperative_reduction: inductor_meta["persistent_reduction"] = self.persistent_reduction num_gb = None if config.benchmark_kernel or config.profile_bandwidth: num_gb = self.estimate_kernel_num_bytes() / 1e9 +<<<<<<< HEAD if num_gb is not None: inductor_meta["kernel_num_gb"] = num_gb if config.benchmark_kernel: flops = self.estimate_flops() if flops is not None: inductor_meta["kernel_flop"] = flops +======= + inductor_meta["kernel_num_gb"] = num_gb +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) triton_meta["configs"] = [config_of(signature)] @@ -4730,11 +5376,14 @@ def define_kernel(self, src_code, node_schedule, kernel): kernel_name = "_".join( ["triton", kernel_category, fused_name, wrapper.next_kernel_suffix()] ) +<<<<<<< HEAD if config.aot_inductor.model_name_for_generated_files: # When AOTI compiles multiple submodules, we need to use the model name to # distinguish kernel related symbols. kernel_name = f"{config.aot_inductor.model_name_for_generated_files}_{kernel_name}" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # use the original src_code as the key wrapper.src_to_kernel[src_code] = kernel_name subs_name = kernel_name if config.triton.unique_kernel_names else "triton_" diff --git a/torch/_inductor/codegen/triton_split_scan.py b/torch/_inductor/codegen/triton_split_scan.py index b36d26ec08bf6..c634d0f6959f6 100644 --- a/torch/_inductor/codegen/triton_split_scan.py +++ b/torch/_inductor/codegen/triton_split_scan.py @@ -86,9 +86,12 @@ def reduction(self, dtype, src_dtype, reduction_type, value): raise NotImplementedError("NYI TritonSplitDimKernel reductions") def scan(self, dtypes, combine_fn, values): +<<<<<<< HEAD """ Perform an associative scan on 'values'. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import triton.language as tl (dtype,) = dtypes @@ -126,6 +129,7 @@ def scan(self, dtypes, combine_fn, values): scratch_base: Union[str, TritonCSEVariable] scratch_base, offset = self.args.workspace(nbytes=nbytes, zero_fill=True) if offset != 0: +<<<<<<< HEAD scratch_base = cse_load( f"{scratch_base} + {self.index_to_str(offset)}", shape=() ) @@ -136,6 +140,13 @@ def scan(self, dtypes, combine_fn, values): f"{scratch_base}.to(tl.pointer_type({scratch_type})) + xoffset * " f"{scratch_elems_per_block} * {runtime_rblocks}", shape=(), +======= + scratch_base = cse_load(f"{scratch_base} + {self.index_to_str(offset)}") + runtime_rblocks = cse_load(f"tl.num_programs({self.range_trees[-1].index})") + scratch_base = cse_load( + f"{scratch_base}.to(tl.pointer_type({scratch_type})) + xoffset * " + f"{scratch_elems_per_block} * {runtime_rblocks}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees) @@ -145,11 +156,15 @@ def scan(self, dtypes, combine_fn, values): value = cse_compute( f"{value}.to({compute_type})", dtype=dtype, +<<<<<<< HEAD shape=value.shape, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) value = cse_compute( f"tl.broadcast_to({value}, {self.dense_size_str()})", dtype=dtype, +<<<<<<< HEAD shape=self.dense_size_list(), ) @@ -158,15 +173,28 @@ def scan(self, dtypes, combine_fn, values): assert dim == 0, "" shape = list(self.dense_size_list()) del shape[dim] +======= + ) + + combine_helper_fn = self._lift_helper(combine_fn, 1, (dtype,)) + dim = self.triton_tensor_ndim() - 1 + assert dim == 0, "" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) block_sum = cse_compute( f"tl.reduce({value}, {dim}, {combine_helper_fn})", dtype=dtype, +<<<<<<< HEAD shape=shape, ) exclusive_prefix = self.cse.newvar( dtype=dtype, shape=shape, +======= + ) + exclusive_prefix = self.cse.newvar( + dtype=dtype, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if element_nbits == 64: self.compute.splice( @@ -202,18 +230,27 @@ def scan(self, dtypes, combine_fn, values): block_scan = cse_compute( f"tl.associative_scan({value}, {dim}, {combine_helper_fn})", dtype=dtype, +<<<<<<< HEAD shape=shape, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) combined_result = cse_compute( f"{combine_helper_fn}({exclusive_prefix}, {block_scan})", dtype=dtype, +<<<<<<< HEAD shape=shape, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return ( cse_compute( f"tl.where(roffset == 0, {block_scan}, {combined_result})", dtype=dtype, +<<<<<<< HEAD shape=block_scan.shape, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), ) diff --git a/torch/_inductor/codegen/triton_utils.py b/torch/_inductor/codegen/triton_utils.py index d97988f684c00..9befbdf696d97 100644 --- a/torch/_inductor/codegen/triton_utils.py +++ b/torch/_inductor/codegen/triton_utils.py @@ -71,8 +71,11 @@ def signature_of(arg: KernelArgType, *, size_dtype: Optional[str]) -> str: return "constexpr" elif isinstance(arg.expr, (float, sympy.Float)): return "fp32" +<<<<<<< HEAD elif isinstance(arg.expr, bool): return "i1" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # if this is a integer if size_dtype == "tl.int32": @@ -83,7 +86,11 @@ def signature_of(arg: KernelArgType, *, size_dtype: Optional[str]) -> str: # no hint: we'll see if we know that this is a 32-bit int, and guard if possible. int_max = torch.iinfo(torch.int32).max if expr_fits_within_32bit(arg.expr): +<<<<<<< HEAD V.graph.sizevars.check_leq(arg.expr, int_max) +======= + V.graph.sizevars.guard_leq(arg.expr, int_max) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return "i32" else: return "i64" diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 4aa7037618b99..4fd38cdcdb849 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -40,7 +40,10 @@ from .. import async_compile, config, ir from ..codecache import output_code_log +<<<<<<< HEAD from ..debug import set_kernel_post_grad_provenance_tracing +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from ..ir import IRNode, ReinterpretView from ..runtime import triton_heuristics from ..runtime.hints import DeviceProperties @@ -48,11 +51,18 @@ cache_on_self, DelayReplaceLine, get_benchmark_name, +<<<<<<< HEAD get_dtype_size, IndentedBuffer, is_codegen_graph_partition_subgraph, is_using_cudagraph_partition, LineContext, +======= + IndentedBuffer, + is_codegen_graph_partition_subgraph, + LineContext, + set_kernel_post_grad_provenance_tracing, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sympy_product, sympy_str, sympy_subs, @@ -228,6 +238,7 @@ def writeline(line: str, example_grid: Optional[str] = None): key=lambda x: len(x[1].kwargs), reverse=True, ): +<<<<<<< HEAD guardslist = [] if c.kwargs: # Remove AMD specific kwargs. @@ -240,6 +251,13 @@ def writeline(line: str, example_grid: Optional[str] = None): guardslist.append(f"meta['{kwarg}'] == {c.kwargs[kwarg]}") if guardslist: guards = " and ".join(guardslist) +======= + if c.kwargs: + guards = [ + f"meta['{name}'] == {val}" for name, val in c.kwargs.items() + ] + guards = " and ".join(guards) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: guards = "True" # for configs with empty kwargs grid, example_grid = determine_grid(grid, example_grid) @@ -261,7 +279,10 @@ def user_defined_triton_kernel_transitive_closure_source_code(kernel) -> str: compile_wrapper.splice(kernel.src, strip=True) # Also include any possible kernel being called indirectly +<<<<<<< HEAD import triton +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from triton import JITFunction # type: ignore[name-defined, attr-defined] from triton.language import constexpr # type: ignore[name-defined] @@ -290,6 +311,7 @@ def traverse(cur_kernel): compile_wrapper.splice(symbol.src, strip=True) symbols_included.add(symbol_name) traverse(symbol) +<<<<<<< HEAD elif hasattr(triton, "constexpr_function") and isinstance( symbol, triton.runtime.jit.ConstexprFunction ): @@ -298,6 +320,8 @@ def traverse(cur_kernel): compile_wrapper.splice(symbol.src, strip=True) symbols_included.add(symbol_name) traverse(symbol) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif isinstance(symbol, (int, str, bool, constexpr)): compile_wrapper.newline() if isinstance(symbol, constexpr): @@ -340,7 +364,11 @@ def traverse(cur_kernel): @dataclasses.dataclass class SymbolicCallArg: +<<<<<<< HEAD inner: sympy.Symbol +======= + inner: str +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # the original symbolic expression represented by inner inner_expr: sympy.Expr @@ -496,19 +524,28 @@ def codegen(self, code: IndentedBuffer) -> None: else: kernel_name = node.get_kernel_name() device = d.type if (d := node.get_device()) else V.graph.device_type +<<<<<<< HEAD provenance_debug_handle: Optional[int] = None # set provenance tracing kernel mapping for ExternKernel types if config.trace.provenance_tracking_level != 0: provenance_debug_handle = set_kernel_post_grad_provenance_tracing( node, kernel_name, is_extern=True ) +======= + # set provenance tracing kernel mapping for ExternKernel types + if config.trace.enabled: + set_kernel_post_grad_provenance_tracing(node, kernel_name, is_extern=True) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.wrapper._generate_extern_kernel_out_helper( kernel_name, node.codegen_reference(), node.output_view.codegen_reference() if node.output_view else None, args, device, +<<<<<<< HEAD provenance_debug_handle, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: @@ -608,6 +645,7 @@ def __str__(self) -> str: return f"{type(self).__name__}({', '.join(args)})" +<<<<<<< HEAD class EfficientPeakEstimate: def __init__(self): from ..memory import estimate_peak_memory, get_freeable_input_buf @@ -648,10 +686,13 @@ def update_peak_between(self, line_a: FreeIfNotReusedLine, line_b: AllocateLine) ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclasses.dataclass class AllocateLine(MemoryPlanningLine): node: BufferLike +<<<<<<< HEAD def __post_init__(self): assert V.graph.scheduler.current_node is not None self.scheduler_node_index = V.graph.scheduler.nodes.index( @@ -666,6 +707,8 @@ def should_reuse_buffer(self, free_line: FreeIfNotReusedLine, size: int) -> bool new_peak_memory = size + peak_memory_in_range return new_peak_memory <= overall_peak_memory +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: if self.node.get_name() in V.graph.removed_buffers: return NullLine(self.wrapper) @@ -674,6 +717,7 @@ def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: key = buffer_reuse_key(self.node) if config.allow_buffer_reuse and key in state: free_line = state.pop(key) +<<<<<<< HEAD size = V.graph.sizevars.size_hint( V.graph.get_allocation_storage_size(self.node), fallback=0 ) * get_dtype_size(self.node.get_dtype()) @@ -684,6 +728,10 @@ def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: else: state.push(key, free_line) return self +======= + free_line.is_reused = True + return ReuseLine(self.wrapper, free_line.node, self.node) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.node.get_device_or_error().type == "cpu": static_shape = self.wrapper.static_shape_for_buffer_or_none(self.node) @@ -708,12 +756,15 @@ class FreeIfNotReusedLine(MemoryPlanningLine): node: BufferLike is_reused: bool = False +<<<<<<< HEAD def __post_init__(self): assert V.graph.scheduler.current_node is not None self.scheduler_node_index = V.graph.scheduler.nodes.index( V.graph.scheduler.current_node ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: if len(self.node.get_inputs_that_alias_output()) > 0: return self @@ -1053,12 +1104,18 @@ def write_header(self) -> None: aot_config_comment = "" if context is not None and context.aot_graph_name is not None: aot_config_comment = f"# AOT ID: {context.aot_graph_name}" +<<<<<<< HEAD inductor_debug_utils = "" if int(config.aot_inductor.debug_intermediate_value_printer) > 0: inductor_debug_utils = "from torch._inductor.codegen.debug_utils import _print_debugging_tensor_value_info" elif torch._inductor.config.test_configs.track_memory_lifecycle: inductor_debug_utils = "from torch._inductor.runtime.debug_utils import tracked_empty_strided\n" +======= + aot_inductor_debug_utils = "" + if int(config.aot_inductor.debug_intermediate_value_printer) > 0: + aot_inductor_debug_utils = "from torch._inductor.codegen.debug_utils import _print_debugging_tensor_value_info" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.imports.splice( f""" {aot_config_comment} @@ -1076,7 +1133,11 @@ def write_header(self) -> None: from torch import device, empty_strided from {async_compile.__name__} import AsyncCompile from torch._inductor.select_algorithm import extern_kernels +<<<<<<< HEAD {inductor_debug_utils} +======= + {aot_inductor_debug_utils} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """, strip=True, ) @@ -1088,10 +1149,15 @@ def write_header(self) -> None: assert_size_stride = torch._C._dynamo.guards.assert_size_stride assert_alignment = torch._C._dynamo.guards.assert_alignment empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu +<<<<<<< HEAD empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia +======= + empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda + empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor alloc_from_pool = torch.ops.inductor._alloc_from_pool async_compile = AsyncCompile() @@ -1152,6 +1218,7 @@ def write_triton_header_once(self) -> None: ) def write_get_raw_stream_header(self) -> None: +<<<<<<< HEAD import_get_raw_stream_str = V.graph.device_ops.import_get_raw_stream_as( "get_raw_stream" ) @@ -1161,6 +1228,16 @@ def write_get_raw_stream_header(self) -> None: if not V.graph.cpp_wrapper: if not self.imports.contains(import_get_raw_stream_str): self.imports.writeline(import_get_raw_stream_str) +======= + if config.triton.autotune_at_compile_time: + self.kernel_autotune_calls.writeline( + V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") + ) + if not V.graph.cpp_wrapper: + self.imports.writeline( + V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @cache_on_self def write_get_raw_stream_header_once(self) -> None: @@ -1288,6 +1365,7 @@ def write_prefix(self) -> None: self.write_args(graph_input_names) self.codegen_inputs() +<<<<<<< HEAD # avoid duplicating asserts for both partition functions and # the call function when using cudagraph partition @@ -1296,6 +1374,9 @@ def write_prefix(self) -> None: and (not is_codegen_graph_partition_subgraph(self)) ): self.codegen_input_size_and_nan_asserts() +======= + self.codegen_input_size_and_nan_asserts() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def codegen_input_size_and_nan_asserts(self) -> None: if config.size_asserts: @@ -1307,7 +1388,11 @@ def codegen_input_size_and_nan_asserts(self) -> None: # that stream caching happens per graph instance. this # is important for nested subgraph codegening. def write_get_raw_stream(self, device_idx: int, graph_name: str) -> str: +<<<<<<< HEAD self.write_get_raw_stream_header() +======= + self.write_get_raw_stream_header_once() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) name = f"stream{device_idx}" if config.triton.autotune_at_compile_time: self.kernel_autotune_calls.writeline( @@ -1350,6 +1435,12 @@ def codegen_device_guard_enter(self, device_idx: int) -> None: f"with {V.graph.device_ops.device_guard(device_idx)}:" ) self.kernel_autotune_calls.do_indent() +<<<<<<< HEAD +======= + self.kernel_autotune_calls.writeline( + V.graph.device_ops.set_device(device_idx) + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if is_codegen_graph_partition_subgraph(self): # Need get_raw_stream for subgraph self.write_get_raw_stream_header() @@ -1453,13 +1544,19 @@ def _generate_extern_kernel_out_helper( out_view: Optional[str], args: list[str], device: str, +<<<<<<< HEAD debug_handle: Optional[int] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: # add debug printer code for triton kernel calls at (jit) inductor level debug_printer_manager = V.graph.wrapper_code.debug_printer debug_printer_manager.set_printer_args(args, kernel, None, None, "extern") args.append(f"out={out_view if out_view else out}") +<<<<<<< HEAD self.write_provenance_debug_handle(kernel, debug_handle) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with debug_printer_manager: self.writeline(f"{kernel}({', '.join(args)})") @@ -1737,8 +1834,11 @@ def run_wrapper_ir_passes(self, is_inference: bool): if is_inference and config.memory_planning: self.memory_plan() else: +<<<<<<< HEAD if config.allow_buffer_reuse: self.estimate_peak = EfficientPeakEstimate() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.memory_plan_reuse() def codegen_input_symbol_assignment( @@ -1830,8 +1930,12 @@ def ensure_size_computed(self, sym: sympy.Symbol): return self.computed_sizes.add(sym) expr = V.graph.sizevars.inv_precomputed_replacements[sym] +<<<<<<< HEAD arg = SymbolicCallArg(sym, expr) self.writeline(SymbolicCallArgLine(self, arg, V.graph)) +======= + self.writeline(f"{sym} = {pexpr(expr)}") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def finalize_prefix(self): pass @@ -1859,9 +1963,13 @@ def codegen_python_shape_tuple(self, shape: Sequence[Expr]) -> str: def codegen_shape_tuple(self, shape: Sequence[Expr]) -> str: return self.codegen_python_shape_tuple(shape) +<<<<<<< HEAD def codegen_alloc_from_pool( self, name, offset, dtype, shape, stride ) -> tuple[str, list[str]]: +======= + def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return "alloc_from_pool({})".format( ", ".join( [ @@ -1872,7 +1980,11 @@ def codegen_alloc_from_pool( self.codegen_python_shape_tuple(stride), ] ) +<<<<<<< HEAD ), [] +======= + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def codegen_reinterpret_view( self, @@ -1903,11 +2015,16 @@ def codegen_reinterpret_view( f"reinterpret_tensor({data.get_name()}, {size}, {stride}, {offset})" ) +<<<<<<< HEAD def codegen_device_copy(self, src, dst, non_blocking: Union[bool, str]): +======= + def codegen_device_copy(self, src, dst, non_blocking: bool): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.writeline(f"{dst}.copy_({src}, {non_blocking})") def codegen_multi_output(self, node: ir.MultiOutput): result_name = node.get_name() +<<<<<<< HEAD arg_name = node.input_name(0) self.writeline(MultiOutputLine(self, result_name, arg_name, node.indices)) @@ -1919,6 +2036,11 @@ def codegen_dynamic_select_index(self, node): # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again self.unbacked_symbol_decls.add(str(node.unbacked_offset_symbol)) +======= + arg_name = node.inputs[0].get_name() + self.writeline(MultiOutputLine(self, result_name, arg_name, node.indices)) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def codegen_dynamic_scalar(self, node): (data,) = (t.codegen_reference() for t in node.inputs) if len(node.keypath) == 0: @@ -2301,7 +2423,10 @@ def rename_sizes_for_launcher(expr: Union[int, sympy.Expr]) -> sympy.Expr: "config": config_to_dict(cfg), "python": [*map(pexpr, grid)], "cpp": [*map(cexpr, grid)], +<<<<<<< HEAD "python_slow": [*map(pexpr, grid)], +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } ) inductor_meta = { @@ -2373,10 +2498,16 @@ def rename_sizes_for_launcher(expr: Union[int, sympy.Expr]) -> sympy.Expr: return name, triton_meta, extra_launcher_call_args def generate_numel_expr(self, kernel_name: str, tree, suffix: Optional[str] = None): +<<<<<<< HEAD sym_name = f"{kernel_name}_{tree.prefix}numel" if suffix is not None: sym_name += f"_{suffix}" sym = sympy.Symbol(sym_name, is_integer=True, is_positive=True) +======= + expr = f"{kernel_name}_{tree.prefix}numel" + if suffix is not None: + expr += f"_{suffix}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # We can get symbolic expressions here, like s0*64 # It is fine to have them here, but we need to handle them correctly as their own type @@ -2385,7 +2516,11 @@ def generate_numel_expr(self, kernel_name: str, tree, suffix: Optional[str] = No # This is handled in `generate_args_decl` which has a correct comment of: TODO: only works for # constant now, need type info. I agree, this needs type info, and while this is not true type info # it suffices as a type hint for the purposes of producing the correct code for this type. +<<<<<<< HEAD arg = SymbolicCallArg(sym, tree.numel) +======= + arg = SymbolicCallArg(expr, tree.numel) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.writeline(SymbolicCallArgLine(self, arg, V.graph)) return arg @@ -2609,7 +2744,10 @@ def generate_kernel_call( raw_args=None, triton_meta=None, original_fxnode_name=None, +<<<<<<< HEAD debug_handle: Optional[int] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): """ Generates kernel call code. @@ -2629,7 +2767,10 @@ def generate_kernel_call( ) device = device or V.graph.get_current_device_or_throw() +<<<<<<< HEAD self.write_provenance_debug_handle(kernel_name, debug_handle) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.writeline( KernelCallLine( self, @@ -2661,6 +2802,7 @@ def _generate_kernel_call_helper( original_fxnode_name=None, ): device = device or V.graph.get_current_device_or_throw() +<<<<<<< HEAD if not triton and device.type != "cuda": if device.type == "cpu": self.writeline(self.wrap_kernel_call(kernel_name, call_args)) @@ -2671,6 +2813,10 @@ def _generate_kernel_call_helper( ) else: raise RuntimeError(f"device {device.type} nyi") +======= + if not (triton or device.type != "cpu"): + self.writeline(self.wrap_kernel_call(kernel_name, call_args)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return call_args_str = self.prepare_triton_kernel_call(call_args) @@ -2804,6 +2950,7 @@ def infer_arg_by_inputs(raw_keys, raw_args, idx, reused_args): arg_str = self.generate_example_arg_value(arg, arg_type, raw_arg) all_args.append(arg_str if key is None else f"{key}={arg_str}") +<<<<<<< HEAD # Make sure kernel launch under a device guard because models don't always run on device 0 self.kernel_autotune_calls.writeline( f"with {V.graph.device_ops.device_guard(device.index)}:" @@ -2814,6 +2961,11 @@ def infer_arg_by_inputs(raw_keys, raw_args, idx, reused_args): ) self.kernel_autotune_calls.do_unindent() +======= + self.kernel_autotune_calls.writeline( + f"{kernel_name}.run({', '.join(all_args)}, stream={stream_name})" + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.kernel_autotune_calls.writeline( DelayReplaceLine("", get_autotune_deletion_call, "") ) @@ -2880,6 +3032,7 @@ def make_buffer_allocation(self, buffer: BufferLike): shape = tuple(buffer.get_size()) allocation_shape = tuple(V.graph.get_allocation_size(buffer)) stride = tuple(buffer.get_stride()) +<<<<<<< HEAD is_pinned = buffer.get_is_pinned() return self.make_allocation( buffer.get_name(), device, dtype, shape, stride, allocation_shape, is_pinned @@ -2895,6 +3048,14 @@ def write_memory_track_allocation_once(self): def make_allocation( self, name, device, dtype, shape, stride, allocation_shape=None, is_pinned=False +======= + return self.make_allocation( + buffer.get_name(), device, dtype, shape, stride, allocation_shape + ) + + def make_allocation( + self, name, device, dtype, shape, stride, allocation_shape=None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): if allocation_shape is None: allocation_shape = shape @@ -2904,6 +3065,7 @@ def make_allocation( allocation_shape ) codegen_stride_tuple = self.codegen_python_shape_tuple(stride) +<<<<<<< HEAD if torch._inductor.config.test_configs.track_memory_lifecycle: out = ( f"{name} = tracked_empty_strided(" @@ -2921,6 +3083,9 @@ def make_allocation( f"{dtype})" ) elif device.type in ("cpu", "cuda", "xpu", "mtia"): +======= + if device.type in ("cpu", "cuda", "xpu"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # optimized path for faster allocations, saving ~2us versus the stuff below out = ( f"{name} = empty_strided_{device.type}(" @@ -2956,6 +3121,7 @@ def make_free_by_names(self, names_to_del: list[str]): def codegen_exact_buffer_reuse(self, old_name: str, new_name: str, del_line: str): return f"{self.declare_maybe_reference}{new_name} = {old_name}{del_line}{self.ending} {self.comment} reuse" +<<<<<<< HEAD def write_provenance_debug_handle( self, kernel_name, @@ -2966,6 +3132,8 @@ def write_provenance_debug_handle( f"{self.comment} [Provenance debug handles] {kernel_name}:{debug_handle}" ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def make_buffer_reuse(self, old: BufferLike, new: BufferLike, delete_old: bool): assert old.get_dtype() == new.get_dtype() old_name = old.get_name() @@ -3371,6 +3539,7 @@ def codegen_conditional(self, conditional): self.codegen_subgraph(conditional.false_subgraph, outer_inputs, name) self.writeline(ExitSubgraphLine(self)) +<<<<<<< HEAD def codegen_while_loop(self, while_loop, stack_output): """while_loop is codegened as a host side while_loop""" @@ -3383,6 +3552,9 @@ def codegen_subgraph(subgraph, outer_inputs, outer_outputs): subgraph, outer_inputs, outer_outputs ) +======= + def codegen_while_loop(self, while_loop): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) name = while_loop.get_name() outer_carried_inputs = [ buf.codegen_reference() for buf in while_loop.carried_inputs @@ -3391,6 +3563,7 @@ def codegen_subgraph(subgraph, outer_inputs, outer_outputs): buf.codegen_reference() for buf in while_loop.additional_inputs ] +<<<<<<< HEAD ckp_offset = len(outer_carried_inputs) self.writeline(f"{name} = [None] * {len(outer_carried_inputs)}") if stack_output: @@ -3398,6 +3571,9 @@ def codegen_subgraph(subgraph, outer_inputs, outer_outputs): f"{name}.extend([[] for _ in range({len(outer_carried_inputs)})])" ) +======= + self.writeline(f"{name} = [None] * {len(outer_carried_inputs)}") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for i, inp in enumerate(outer_carried_inputs): # set the initial state before the loop self.writeline(f"{name}[{i}] = {inp}") @@ -3414,6 +3590,7 @@ def codegen_subgraph(subgraph, outer_inputs, outer_outputs): # the carried_inputs part of the inputs, the additional ones # are passed in as they're before. body_outer_outputs = body_outer_inputs[: len(outer_carried_inputs)] +<<<<<<< HEAD # Check condition at the beginning and set up flag codegen_subgraph( while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs @@ -3469,6 +3646,34 @@ def codegen_subgraph(subgraph, outer_inputs, outer_outputs): f"{name}[{i}] = torch.stack({name}[{i + ckp_offset}], dim=0)" ) self.writeline(ExitSubgraphLine(self)) +======= + + self.writeline("while True:") + self.writeline(EnterSubgraphLine(self, while_loop.cond_subgraph.graph)) + + if V.graph.aot_mode: + self.codegen_subgraph_by_inlining( + while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs + ) + else: + self.codegen_subgraph_with_flattened_outputs( + while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs + ) + self.writeline( + f"if not {cond_outer_outputs[0]}: break" + ) # condition doesn't hold + self.writeline(ExitSubgraphLine(self)) + self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph)) + if V.graph.aot_mode: + self.codegen_subgraph_by_inlining( + while_loop.body_subgraph, body_outer_inputs, body_outer_outputs + ) + else: + self.codegen_subgraph_with_flattened_outputs( + while_loop.body_subgraph, body_outer_inputs, body_outer_outputs + ) + self.writeline(ExitSubgraphLine(self)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @staticmethod def statically_known_int_or_none(x): diff --git a/torch/_inductor/codegen/wrapper_fxir.py b/torch/_inductor/codegen/wrapper_fxir.py index 29905b11f3b97..2078a2fc33317 100644 --- a/torch/_inductor/codegen/wrapper_fxir.py +++ b/torch/_inductor/codegen/wrapper_fxir.py @@ -4,37 +4,56 @@ import operator import textwrap from collections import Counter +<<<<<<< HEAD from collections.abc import Sequence +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from typing import Any, Callable, Optional, Union import sympy import torch +<<<<<<< HEAD from torch._export.passes._node_metadata_hook import ( _node_metadata_hook, _set_node_metadata_hook, ) from torch._export.utils import _detect_fake_mode_from_gm +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._higher_order_ops.triton_kernel_wrap import ( TraceableTritonKernelWrapper, tracing_triton_hopifier_singleton, triton_kernel_wrapper_mutation, ) +<<<<<<< HEAD from torch._inductor.codecache import LambdaFuture, PyCodeCache from torch._inductor.runtime.triton_heuristics import CachingAutotuner from torch._inductor.select_algorithm import extern_kernels # noqa: F401 from torch._inductor.utils import convert_shape_to_symint, sympy_product +======= +from torch._inductor.codecache import PyCodeCache +from torch._inductor.runtime.triton_heuristics import CachingAutotuner +from torch._inductor.select_algorithm import extern_kernels # noqa: F401 +from torch._inductor.utils import sympy_product +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.virtualized import V from torch._library.triton import wrap_triton from torch.fx import GraphModule from torch.utils import _pytree as pytree from torch.utils._sympy.functions import FloorDiv +<<<<<<< HEAD from torch.utils._sympy.interp import _run_sympy_handler, sympy_interp from torch.utils._sympy.reference import OptimizedPythonReferenceAnalysis from .. import config, ir from ..runtime.triton_compat import Config from ..utils import LineContext +======= + +from .. import config, ir +from ..utils import convert_shape_to_symint, convert_to_symint, LineContext +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .common import ( CodegenSymbol, FileBackedGraphModule, @@ -101,6 +120,7 @@ class TritonKernel: wrapped: TraceableTritonKernelWrapper +<<<<<<< HEAD def replace_floor_div(expr: sympy.Expr) -> sympy.Expr: """ Replace sympy.floor with FloorDiv. @@ -133,6 +153,8 @@ def replace_floor_div(expr: sympy.Expr) -> sympy.Expr: return sympy.floor(expr) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class WrapperFxCodegen(PythonWrapperCodegen): """ Backend to generate wrapper code as an FX IR graph. @@ -197,8 +219,11 @@ def __post_init__(self) -> None: ] = {} # Symbol table for codegen. self.kernels: dict[str, TritonKernel] = {} # Table to store Triton kernels. self._unique_symbol_ids: Counter[str] = Counter() +<<<<<<< HEAD self.tracer = torch.fx.proxy.GraphAppendingTracer(graph) self.expr_to_proxy: dict[sympy.Expr, torch.fx.Proxy] = {} +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _import_kernel(self, code: str, kernel_name: str) -> CachingAutotuner: """ @@ -208,9 +233,12 @@ def _import_kernel(self, code: str, kernel_name: str) -> CachingAutotuner: mod = PyCodeCache.load(module_code) kernel = getattr(mod, kernel_name) +<<<<<<< HEAD if isinstance(kernel, LambdaFuture): kernel = kernel.result() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not isinstance(kernel, CachingAutotuner): raise NotImplementedError( textwrap.dedent(f""" @@ -236,6 +264,17 @@ def _fake_tensor( device=device, ) +<<<<<<< HEAD +======= + def _create_meta_from_buffer( + self, node: torch.fx.Node, buffer: CodegenBuffer + ) -> None: + name = buffer.get_name() + assert name + node.name = name + node.meta["val"] = buffer.get_example() + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _create_as_strided( self, input_node: torch.fx.Node, @@ -247,9 +286,15 @@ def _create_as_strided( torch.as_strided, args=( input_node, +<<<<<<< HEAD self._generate_sym_nodes(size), self._generate_sym_nodes(stride), self._generate_sym_node(offset), +======= + convert_shape_to_symint(size), + convert_shape_to_symint(stride), + convert_to_symint(offset), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), ) @@ -298,6 +343,7 @@ def _generate_graph_inputs(self) -> None: """ Converts graph inputs to FX placeholders. """ +<<<<<<< HEAD for node in V.graph.module.graph.find_nodes(op="placeholder"): # type: ignore[operator, union-attr] name = node.name @@ -368,6 +414,18 @@ def _generate_graph_constants(self) -> None: node.meta["val"] = value setattr(self.gm, name, value) self.buffer_to_node[name] = node +======= + for name, ir_node in V.graph.graph_inputs.items(): + # Introduce a new symbol for constant inputs. + buffer = ( + SymbolBuffer(sympy.Symbol(name, is_integer=True)) + if isinstance(ir_node, (int, float, sympy.Integer, sympy.Float)) + else self._get_buffer(ir_node) + ) + node = self.gm.graph.placeholder(buffer.get_name()) + self._create_meta_from_buffer(node, buffer) + self._record_allocation(buffer, node) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _generate_buffer(self, node: ir.IRNode) -> Optional[torch.fx.Node]: """ @@ -380,7 +438,11 @@ def generate_to_buffer(node: ir.IRNode) -> Optional[BufferLike]: return node elif isinstance(node, ir.NoneAsConstantBuffer): return None +<<<<<<< HEAD elif isinstance(node, ir.MutableBox): +======= + elif isinstance(node, ir.StorageBox): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return generate_to_buffer(node.data) elif isinstance(node, ir.ReinterpretView): # We need to introduce a new symbol if the output is a ReinterpretView. @@ -429,6 +491,7 @@ def generate(self) -> torch.fx.GraphModule: Main entrypoint for FX codegen. """ self._generate_graph_inputs() +<<<<<<< HEAD self._generate_graph_constants() fake_mode = _detect_fake_mode_from_gm(self.gm) @@ -457,11 +520,33 @@ def generate(self) -> torch.fx.GraphModule: """ ) ) +======= + + # Generate FX IR from Wrapper IR lines. + for line in self.lines: + if isinstance(line, WrapperLine): + line.codegen_fx(self)(line) + elif isinstance(line, LineContext): + # Ignore line context in FX IR. + pass + else: + raise NotImplementedError( + textwrap.dedent( + f""" + Found line of unrecognized type '{type(line)}': + '{line}' + + FX conversion only supports Wrapper IR lines. + """ + ) + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._generate_output() self.gm.recompile() return self.gm +<<<<<<< HEAD def _sympy_interp(self, expr: sympy.Expr) -> torch.fx.Proxy: # hash cons if expr in self.expr_to_proxy: @@ -512,6 +597,8 @@ def _generate_sym_nodes( ) -> list[Union[int, torch.fx.Node]]: return [self._generate_sym_node(s) for s in shape] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _generate_allocate(self, line: WrapperLine) -> None: assert isinstance(line, AllocateLine) buffer = line.node @@ -520,8 +607,13 @@ def _generate_allocate(self, line: WrapperLine) -> None: device = buffer.get_device() dtype = buffer.get_dtype() +<<<<<<< HEAD shape = self._generate_sym_nodes(buffer.get_size()) stride = self._generate_sym_nodes(buffer.get_stride()) +======= + shape = convert_shape_to_symint(buffer.get_size()) + stride = convert_shape_to_symint(buffer.get_stride()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) node = self.gm.graph.call_function( torch.empty_strided, @@ -530,6 +622,10 @@ def _generate_allocate(self, line: WrapperLine) -> None: ) assert name node.name = name +<<<<<<< HEAD +======= + self._create_meta_from_buffer(node, buffer) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._record_allocation(buffer, node) def _generate_comment(self, line: WrapperLine) -> None: @@ -600,6 +696,10 @@ def _generate_reinterpret_helper( # Map ReinterpretView to as_strided. result_node = self._create_as_strided(input_node, size, stride, offset) result_node.name = name +<<<<<<< HEAD +======= + result_node.meta["val"] = layout.get_example() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._record_allocation(result_buffer, result_node) def _generate_reuse(self, line: WrapperLine) -> None: @@ -622,6 +722,10 @@ def _generate_reuse(self, line: WrapperLine) -> None: or old.get_offset() != offset ): result_node = self._create_as_strided(old_node, size, stride, offset) +<<<<<<< HEAD +======= + self._create_meta_from_buffer(result_node, new) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._record_allocation(new, result_node) @@ -636,6 +740,7 @@ def _generate_reuse(self, line: WrapperLine) -> None: def _generate_multi_output(self, line: WrapperLine) -> None: assert isinstance(line, MultiOutputLine) +<<<<<<< HEAD arg_node = self.buffer_to_node[line.arg_name] # For non-tuple / non-list outputs, map the @@ -644,12 +749,20 @@ def _generate_multi_output(self, line: WrapperLine) -> None: self.buffer_to_node[line.result_name] = arg_node return +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Extract the index for tuple access. inds = line.indices[0][1:] assert len(inds) == 1, f"Cannot convert {inds} to an index." idx = inds[0] +<<<<<<< HEAD + node = self.gm.graph.call_function(operator.getitem, args=(arg_node, idx)) +======= + arg_node = self.buffer_to_node[line.arg_name] node = self.gm.graph.call_function(operator.getitem, args=(arg_node, idx)) + node.meta["val"] = arg_node.meta["val"][idx] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) node.name = line.result_name self.buffer_to_node[line.result_name] = node @@ -672,10 +785,13 @@ def _generate_triton_call(self, line: WrapperLine) -> None: call_args = self._lookup_args(line.call_args) kernel = self.kernels[line.kernel_name] tuner = kernel.tuner +<<<<<<< HEAD # Use python_slow mode instead of python mode to avoid # the round to neginf behaviour, which is not the convention # in other languages. tuner.grid_mode = "python_slow" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Optionally autotune the kernels. # The FX backend currently only supports compile-time tuning. @@ -713,6 +829,7 @@ def node_to_tuning_arg(arg: Any) -> Any: kernel_name, ) +<<<<<<< HEAD triton_meta = tuner.triton_meta signature = triton_meta["signature"] @@ -761,6 +878,46 @@ def add_constants_to_call_args( call_kwargs = { name: self._generate_sym_node(val) for name, val in call_kwargs.items() } +======= + kernel_config = tuner.compile_results[0].config + call_args, grid = tuner._interpret_args_grid(call_args, kernel_config) + call_kwargs = dict(zip(tuner.triton_meta["signature"], call_args)) + call_kwargs.update(kernel_config.kwargs) + + def replace_floor_div(expr: sympy.Expr) -> sympy.Expr: + """ + Converts floor(x / c) to x // c. + """ + if isinstance(expr, sympy.core.mul.Mul) and isinstance( + expr.args[0], sympy.Rational + ): + # Only the first argument of a Mul can be a Rational. + frac = expr.args[0] + numerator = sympy_product(expr.args[1:]) * frac.numerator + denominator = frac.denominator + + # Sanity check the results. + new_expr = numerator / denominator + assert V.graph.sizevars.statically_known_equals(new_expr, expr), ( + f"Unsound replacement: '{new_expr}' != '{expr}'" + ) + + return FloorDiv(numerator, denominator) + else: + return sympy.floor(expr) + + def expr_to_symint(expr: Union[int, sympy.Expr]) -> Union[int, sympy.Expr]: + return ( + convert_to_symint(expr.replace(sympy.floor, replace_floor_div)) + if isinstance(expr, sympy.Expr) + else expr + ) + + # Convert sympy expressions to symints. + # Use FloorDiv over sympy.floor, so we can get nicer Python code from FX. + wrapper_grid = [tuple(expr_to_symint(dim) for dim in grid)] + call_kwargs = {name: expr_to_symint(val) for name, val in call_kwargs.items()} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Store non-graphable kwargs in the side table. ( @@ -801,7 +958,10 @@ def _generate_extern_kernel_common( """ # Get FX nodes corresponding to the call args. +<<<<<<< HEAD assert ir.is_node_sequence(kernel.inputs) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tensor_nodes = tuple(self._generate_buffer(arg) for arg in kernel.inputs) args = tensor_nodes + tuple(kernel.constant_args) @@ -818,11 +978,22 @@ def _generate_extern_kernel_common( else: raise NotImplementedError(f"Unrecognized output layout: {kernel.layout}") +<<<<<<< HEAD fx_node = self.gm.graph.call_function( kernel.op_overload, # type: ignore[arg-type] args=args, kwargs=kwargs, ) +======= + # Look up the kernel function from its name. + kernel_name = kernel.get_kernel_name() + module_name, kernel_name = kernel_name.split(".", 1) + op = globals()[module_name] # E.g. extern_kernels, aten, etc. + for subname in kernel_name.split("."): + op = getattr(op, subname) # E.g. extern_kernels.addmm + + fx_node = self.gm.graph.call_function(op, args=args, kwargs=kwargs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Assign the result to the given name. if result_buffer: @@ -832,6 +1003,17 @@ def _generate_extern_kernel_common( fx_node.name = result_buffer self.buffer_to_node[result_buffer] = fx_node +<<<<<<< HEAD +======= + arg_tensors = [ + arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg + for arg in args + ] + + # Run the operation to propagate metadata. + fx_node.meta["val"] = op(*arg_tensors, **kwargs) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _generate_kernel_call(self, line: WrapperLine) -> None: assert isinstance(line, KernelCallLine) if not line.triton: @@ -854,8 +1036,12 @@ def _generate_kernel_definition(self, line: WrapperLine) -> None: def _generate_symbolic_call_arg(self, line: WrapperLine) -> None: assert isinstance(line, SymbolicCallArgLine) +<<<<<<< HEAD # Store the arg: expr mapping for later use. arg = line.arg inner_expr_proxy = self._sympy_interp(arg.inner_expr) self.expr_to_proxy[arg.inner] = inner_expr_proxy +======= + # No need for an FX node, as we will pass the arg to kernels via a SymInt. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_inductor/codegen/xpu/device_op_overrides.py b/torch/_inductor/codegen/xpu/device_op_overrides.py index 5d538ec20ca21..2c5d9f77f6d00 100644 --- a/torch/_inductor/codegen/xpu/device_op_overrides.py +++ b/torch/_inductor/codegen/xpu/device_op_overrides.py @@ -58,10 +58,17 @@ def cpp_kernel_type(self) -> str: def cpp_device_ptr(self) -> str: return "void *" +<<<<<<< HEAD def cpp_scratch( self, idx: int, workspace: TritonScratchWorkspace, prefix: Optional[str] = None ) -> Optional[tuple[list[str], str]]: return [f"void *global_scratch_{idx} = 0;"], f"global_scratch_{idx}" +======= + def cpp_global_scratch( + self, idx: int, workspace: TritonScratchWorkspace + ) -> Optional[tuple[list[str], str]]: + return None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) register_device_op_overrides("xpu", XPUDeviceOpOverrides()) diff --git a/torch/_inductor/comm_analysis.py b/torch/_inductor/comm_analysis.py index c24cf336e66a3..96975398031d4 100644 --- a/torch/_inductor/comm_analysis.py +++ b/torch/_inductor/comm_analysis.py @@ -1,14 +1,20 @@ import functools +<<<<<<< HEAD import logging import math from enum import IntEnum from typing import Optional +======= +import math +from enum import IntEnum +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import sympy import torch from . import ir +<<<<<<< HEAD from .utils import get_dtype_size, snode_args_kwargs, sympy_product from .virtualized import V @@ -16,11 +22,20 @@ log = logging.getLogger(__name__) +======= +from .utils import get_dtype_size, sympy_product +from .virtualized import V + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class NCCL_COLL(IntEnum): ALL_REDUCE = 0 ALL_GATHER = 1 REDUCE_SCATTER = 2 +<<<<<<< HEAD ALL_TO_ALL = 3 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class NVIDIA_GPU_TYPE(IntEnum): @@ -55,8 +70,11 @@ def get_collective_type(node: ir.IRNode) -> NCCL_COLL: return NCCL_COLL.ALL_GATHER elif "reduce_scatter" in kernel_name: return NCCL_COLL.REDUCE_SCATTER +<<<<<<< HEAD elif "torch.ops._dtensor.shard_dim_alltoall.default" in kernel_name: return NCCL_COLL.ALL_TO_ALL +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: raise ValueError(f"Unsupported collective kernel: {kernel_name}") @@ -75,7 +93,11 @@ def get_collective_input_size_bytes(node: ir.IRNode) -> int: def get_collective_group_size(node: ir.IRNode) -> int: +<<<<<<< HEAD if isinstance(node, ir._CollectiveKernel) and not isinstance(node, ir._WaitKernel): +======= + if type(node) == ir._CollectiveKernel: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.distributed.distributed_c10d import _get_group_size_by_name return _get_group_size_by_name(node.constant_args[-1]) @@ -166,6 +188,7 @@ class NCCL_PROTO(IntEnum): ] +<<<<<<< HEAD def estimate_nccl_collective_runtime_nccl_estimator(snode) -> Optional[float]: # type: ignore[no-untyped-def] kernel = snode.node assert kernel is not None @@ -213,6 +236,11 @@ def estimate_nccl_collective_runtime_nccl_estimator(snode) -> Optional[float]: def estimate_nccl_collective_runtime(node: ir.IRNode) -> float: """ Returns estimated NCCL collective runtime in milliseconds (ms). +======= +def estimate_nccl_collective_runtime(node: ir.IRNode) -> float: + """ + Returns estimated NCCL collective runtime in nanoseconds (ns). +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) The following heuristics are copied from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc. We aim to estimate the runtime as accurately as possible. @@ -272,8 +300,11 @@ def estimate_nccl_collective_runtime(node: ir.IRNode) -> float: if coll == NCCL_COLL.ALL_REDUCE: nsteps = 2 * (nRanks - 1) +<<<<<<< HEAD elif coll == NCCL_COLL.ALL_TO_ALL: nsteps = 2 * (nRanks - 1) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER): nsteps = nRanks - 1 @@ -291,7 +322,11 @@ def estimate_nccl_collective_runtime(node: ir.IRNode) -> float: nInterSteps = 2 * nNodes else: nInterSteps = 0 +<<<<<<< HEAD elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER, NCCL_COLL.ALL_TO_ALL): +======= + elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) nInterSteps = nNodes - 1 # First compute latency in us; then at the end, convert it to ns @@ -310,9 +345,13 @@ def estimate_nccl_collective_runtime(node: ir.IRNode) -> float: # =============== final result =============== transport_ns = tensor_storage_size_GB / bandwidth_GB_per_ns +<<<<<<< HEAD ns = transport_ns + latency_ns ms = ns / 1e6 return ms +======= + return transport_ns + latency_ns +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ################################################################################################################ diff --git a/torch/_inductor/comm_lowering.py b/torch/_inductor/comm_lowering.py index e46909432f17e..3da0006249ac2 100644 --- a/torch/_inductor/comm_lowering.py +++ b/torch/_inductor/comm_lowering.py @@ -1,5 +1,9 @@ # mypy: allow-untyped-defs import logging +<<<<<<< HEAD +======= +from typing import cast +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch.utils._pytree as pytree @@ -113,12 +117,19 @@ def realize_as_comm_buffer( def _get_data(x: ir.TensorBox) -> ir.IRNode: if isinstance(x.data, ir.BaseView): # TensorBox -> *View -> StorageBox -> IRNode +<<<<<<< HEAD node = x.data.unwrap_view() assert isinstance(node, (ir.BaseView, ir.MutableBox)) return node.data elif isinstance(x.data, ir.StorageBox): # TensorBox -> StorageBox -> IRNode return x.data.data +======= + return x.data.unwrap_view().data + elif isinstance(x.data, ir.StorageBox): + # TensorBox -> StorageBox -> IRNode + return cast(ir.Buffer, x.data.data) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: raise AssertionError( "Expect the data attr of a `TensorBox` to be either " @@ -209,6 +220,7 @@ def _all_reduce(inp: ir.TensorBox, reduce_op: str, group_name: str) -> ir.Tensor inp.realize() V.graph.no_fuse_buffer_names.add(inp.get_name()) inp = ir.ExternKernel.require_contiguous(inp) +<<<<<<< HEAD # Because we are lowering as inplace c10d.all_reduce_, we should generate # _AllReduce_Kernel instead of _AllReduceKernel. ir._AllReduce_Kernel.create_inplace( @@ -218,6 +230,12 @@ def _all_reduce(inp: ir.TensorBox, reduce_op: str, group_name: str) -> ir.Tensor group_name, # type: ignore[arg-type] ) return inp # type: ignore[return-value] +======= + ir._CollectiveKernel.create_inplace( + c10d.all_reduce_.default, inp, reduce_op, group_name + ) + return inp +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register_comm_lowering(c10d.all_reduce_) # type: ignore[misc] def _all_reduce_( @@ -233,6 +251,7 @@ def _all_reduce_( # Lower as c10d.all_reduce_ inp = ir.ExternKernel.require_contiguous(inp) +<<<<<<< HEAD ir._AllReduce_Kernel.create_inplace( c10d.all_reduce_.default, inp, # type: ignore[arg-type] @@ -240,6 +259,12 @@ def _all_reduce_( group_name, # type: ignore[arg-type] ) return inp # type: ignore[return-value] +======= + ir._CollectiveKernel.create_inplace( + c10d.all_reduce_.default, inp, reduce_op, group_name + ) + return inp +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register_comm_lowering(c10d.all_reduce_coalesced) def _all_reduce_coalesced(inputs, reduce_op, group_name): @@ -262,6 +287,7 @@ def _all_reduce_coalesced_(inputs, reduce_op, group_name): ) return inputs +<<<<<<< HEAD def _create_out_of_place(kernel, inputs, *args) -> ir.IRNode: node = ir._CollectiveKernel.create_out_of_place(kernel, inputs, *args) assert isinstance(node, ir.IRNode) @@ -274,6 +300,17 @@ def _all_gather_into_tensor(inp, group_size, group_name): inp, group_size, group_name, +======= + @register_comm_lowering(c10d.all_gather_into_tensor) + def _all_gather_into_tensor(inp, group_size, group_name): + return ir.TensorBox.create( + ir._CollectiveKernel.create_out_of_place( + c10d.all_gather_into_tensor.default, + inp, + group_size, + group_name, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @register_comm_lowering(c10d.all_gather_into_tensor_coalesced) @@ -301,12 +338,23 @@ def _all_gather_into_tensor_out(inp, group_size, group_name, *, out): @register_comm_lowering(c10d.reduce_scatter_tensor) def _reduce_scatter_tensor(inp, reduce_op, group_size, group_name): +<<<<<<< HEAD return _create_out_of_place( c10d.reduce_scatter_tensor.default, inp, reduce_op, group_size, group_name, +======= + return ir.TensorBox.create( + ir._CollectiveKernel.create_out_of_place( + c10d.reduce_scatter_tensor.default, + inp, + reduce_op, + group_size, + group_name, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @register_comm_lowering(c10d.reduce_scatter_tensor_coalesced) @@ -324,12 +372,23 @@ def _reduce_scatter_tensor_coalesced(inputs, reduce_op, group_size, group_name): @register_comm_lowering(c10d.all_to_all_single) def _all_to_all_single(inp, output_split_sizes, input_split_sizes, group_name): +<<<<<<< HEAD return _create_out_of_place( c10d.all_to_all_single.default, inp, output_split_sizes, input_split_sizes, group_name, +======= + return ir.TensorBox.create( + ir._CollectiveKernel.create_out_of_place( + c10d.all_to_all_single.default, + inp, + output_split_sizes, + input_split_sizes, + group_name, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @register_comm_lowering(c10d.broadcast) @@ -349,12 +408,23 @@ def _broadcast_(inp, src, group_name): @register_comm_lowering(torch.ops._dtensor.shard_dim_alltoall) def _shard_dim_alltoall(inp, gather_dim, shard_dim, group_name): +<<<<<<< HEAD return _create_out_of_place( torch.ops._dtensor.shard_dim_alltoall.default, inp, gather_dim, shard_dim, group_name, +======= + return ir.TensorBox.create( + ir._CollectiveKernel.create_out_of_place( + torch.ops._dtensor.shard_dim_alltoall.default, + inp, + gather_dim, + shard_dim, + group_name, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @register_comm_lowering(c10d.wait_tensor) diff --git a/torch/_inductor/comms.py b/torch/_inductor/comms.py index fa8bb30f238cf..e07a6a3fdfac5 100644 --- a/torch/_inductor/comms.py +++ b/torch/_inductor/comms.py @@ -4,14 +4,21 @@ import heapq import importlib +<<<<<<< HEAD import itertools +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import logging import operator import sys import time from collections import defaultdict from dataclasses import dataclass +<<<<<<< HEAD from typing import Any, Optional, TYPE_CHECKING, Union +======= +from typing import Any, TYPE_CHECKING +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch from torch._logging import trace_structured @@ -20,6 +27,7 @@ from . import config, ir from .dependencies import WeakDep +<<<<<<< HEAD if TYPE_CHECKING: @@ -33,6 +41,9 @@ get_freeable_input_buf, SNodeMemory, ) +======= +from .memory import estimate_peak_memory, FreeableInputBuffer, get_freeable_input_buf +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .utils import ( contains_collective, contains_wait, @@ -52,6 +63,7 @@ from torch._inductor.scheduler import BaseSchedulerNode +<<<<<<< HEAD def align_runtime_estimations_across_all_distributed_ranks( snodes: list[BaseSchedulerNode], ): @@ -74,6 +86,8 @@ def align_runtime_estimations_across_all_distributed_ranks( snodes[i].override_estimated_runtime = median_runtime_estimations[i] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def sink_waits(snodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]: """ Greedily schedules waits as late as possible. @@ -145,6 +159,7 @@ def reorder_communication_preserving_peak_memory( reordered_snodes, node_stats = ( _reorder_communication_preserving_peak_memory_internal(snodes) ) +<<<<<<< HEAD return reordered_snodes @@ -624,6 +639,8 @@ def is_groupable( curr = _next[curr] # type: ignore[assignment] node_stats = stats +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) improvement = {snode: node_stats[snode].improvement for snode in node_stats} total_improvement = sum([improvement[snode] for snode in improvement]) total_moves = sum([node_stats[snode].moves for snode in node_stats]) @@ -639,12 +656,16 @@ def is_groupable( "improvement", "limiting factor", "moves", +<<<<<<< HEAD "grouped", "grouped_info", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] rows = [ [ node_summary(snode), +<<<<<<< HEAD node_info.initial_exposed, node_info.final_exposed, node_info.improvement, @@ -654,6 +675,15 @@ def is_groupable( node_info.grouped_info, ] for snode, node_info in node_stats.items() +======= + node_reorder_info.initial_exposed, + node_reorder_info.final_exposed, + node_reorder_info.improvement, + node_reorder_info.limiting_factor, + node_reorder_info.moves, + ] + for snode, node_reorder_info in node_stats.items() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] if importlib.util.find_spec("tabulate"): from tabulate import tabulate @@ -668,6 +698,7 @@ def is_groupable( ) reorder_log_str += str(headers) + "\n" reorder_log_str += "\n".join(map(str, rows)) +<<<<<<< HEAD new_snodes = _group_nodes(_head, None) assert len(new_snodes) == original_snodes_num @@ -677,6 +708,8 @@ def is_groupable( reorder_log_str += f"\n peak_memory_before:{peak_memory}" reorder_log_str += f"\n peak_memory_after:{new_peak_memory}" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) overlap_log.info(reorder_log_str) trace_structured( "artifact", @@ -687,7 +720,116 @@ def is_groupable( payload_fn=lambda: reorder_log_str, ) +<<<<<<< HEAD return new_snodes, stats +======= + return reordered_snodes + + +@dataclass +class ReorderInfo: + """ + Debug info describing how an individual snode was reordered + """ + + initial_exposed: float = -1 + final_exposed: float = -1 + limiting_factor: str = "None" + moves: int = 0 + + @property + def improvement(self): + return self.initial_exposed - self.final_exposed + + +def _reorder_communication_preserving_peak_memory_internal( + snodes: list[BaseSchedulerNode], +) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, ReorderInfo]]: + """ + Internal testing helper that also returns debug info. + Returns: + - reordered snodes list + - dict {snode: ReorderInfo} + """ + # heuristic to avoid degenerating to quadratic time + MOVE_LIMIT = len(snodes) * 100 + total_moves = 0 + # TODO - experiment with whether this limit is useful, setting `len(snodes)` disables it + PER_COLLECTIVE_PREFETCH_LIMIT = len(snodes) + if config.reorder_prefetch_limit is not None: + PER_COLLECTIVE_PREFETCH_LIMIT = config.reorder_prefetch_limit + graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys()) + graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names()) + name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = get_freeable_input_buf( + snodes, graph_inputs + ) + peak_memory, curr_memory = estimate_peak_memory( + snodes, name_to_freeable_input_buf, graph_outputs + ) + runtimes = {snode: estimate_op_runtime(snode) for snode in snodes} + + # debug stats + stats: dict[BaseSchedulerNode, ReorderInfo] = {} + + def exposed_communication_time(collective_snode, remaining_snodes): + # assumes a linear schedule and computes the overlap of the collective with the remaining nodes + comm_time = estimate_op_runtime(collective_snode) + compute_time = 0.0 + for snode in remaining_snodes: + if contains_collective(snode): + continue + if contains_wait(snode): + # TODO - if the wait is for a collective that started before this collective or on another stream, + # we can ignore it. Otherwise, it's the end of the road for overlap opportunities + break + + compute_time += runtimes[snode] + return max(0, comm_time - compute_time) + + for i, snode in enumerate(snodes): + if contains_collective(snode): + reorder_info = stats[snode] = ReorderInfo() + reorder_info.initial_exposed = reorder_info.final_exposed = ( + exposed_communication_time(snode, snodes[i + 1 :]) + ) + if total_moves >= MOVE_LIMIT: + reorder_info.limiting_factor = "move limit" + continue + for j in range(i - 1, -1, -1): + prev_snode = snodes[j] + if j < max(0, i - PER_COLLECTIVE_PREFETCH_LIMIT): + reorder_info.limiting_factor = "prefetch limit" + break + if contains_collective(prev_snode): + reorder_info.limiting_factor = "collective ordering" + break + dep_names = OrderedSet([s.name for s in snode.unmet_dependencies]) + if any( + o.get_name() in dep_names for o in prev_snode.get_outputs() + ) and not contains_wait(prev_snode): + reorder_info.limiting_factor = "data dependency" + break + if peak_memory - curr_memory[j] < curr_memory[j - 1] - curr_memory[j]: + reorder_info.limiting_factor = "peak memory" + break + if reorder_info.final_exposed > runtimes[snode]: + reorder_info.limiting_factor = "sufficient overlapping" + break + reorder_info.moves += 1 + total_moves += 1 + tmp = snodes[j] + snodes[j] = snodes[j + 1] + snodes[j + 1] = tmp + # swapping nodes j and j+1 affects curr memory at j only + j_plus_one_alloc = curr_memory[j + 1] - curr_memory[j] + j_alloc = curr_memory[j] - curr_memory[j - 1] + curr_memory[j] = curr_memory[j] - j_alloc + j_plus_one_alloc + reorder_info.final_exposed = exposed_communication_time( + snode, snodes[j + 1 :] + ) + + return snodes, stats +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _schedule_for_comm( @@ -861,13 +1003,18 @@ def decide_global_ordering_of_comms( # Enforce ordering by making previous comm a `WeakDep` dependency of the next comm mutating_buf = next(iter(comm_nodes[i].get_buffer_names())) for buf in comm_nodes[i - 1].get_buffer_names(): +<<<<<<< HEAD comm_nodes[i].add_fake_dep( WeakDep(buf, mutating_buf=mutating_buf, is_fake=True) ) +======= + comm_nodes[i].add_fake_dep(WeakDep(buf, mutating_buf=mutating_buf)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return nodes +<<<<<<< HEAD @dataclass class SinkWaitInfo: grouped: int = 0 @@ -1246,6 +1393,8 @@ def sink_waits_iterative( return _sink_waits_iterative_internal(snodes)[0] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def estimate_op_runtime(snode: BaseSchedulerNode) -> float: """ Returns estimated op runtime in nanoseconds (ns) @@ -1263,9 +1412,13 @@ def node_summary(snode): if len(snodes) == 1: detail = "" if isinstance(snode.node, (ir.ExternKernelOut, ir._CollectiveKernel)): +<<<<<<< HEAD outs_str = f"outs:{[o.get_name() for o in snode.get_outputs()]}" ins_str = f"ins:{[d.name for d in snode.unmet_dependencies]}" detail = f" {snode.get_name()} ({snode.node.python_kernel_name})\n {outs_str}\n ({ins_str})" +======= + detail = f" ({snode.node.python_kernel_name})" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) layouts = [child.node.get_output_spec() for child in snode.get_nodes()] out_tensor_info = ",".join( [ @@ -1622,10 +1775,21 @@ def remove_unused_getitem(g): CallFunction( torch.ops.fsdp.all_gather_copy_in.default, KeywordArg("all_gather_inputs"), +<<<<<<< HEAD KeywordArg("all_gather_output"), KeywordArg("inp_split_sizes"), KeywordArg("all_gather_input_numel"), KeywordArg("rank"), +======= + KeywordArg("inp_split_sizes"), + KeywordArg("all_gather_input_numel"), + KeywordArg("world_size"), + KeywordArg("rank"), + KeywordArg("dtype"), + KeywordArg("device"), + KeywordArg("group_name_inner"), + KeywordArg("allocate_memory_from_process_group"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), KeywordArg("item_idx"), ), @@ -1658,10 +1822,21 @@ def repl( repl, [ kwargs["all_gather_inputs"], +<<<<<<< HEAD kwargs["all_gather_output"], kwargs["inp_split_sizes"], kwargs["all_gather_input_numel"], kwargs["rank"], +======= + kwargs["inp_split_sizes"], + kwargs["all_gather_input_numel"], + kwargs["world_size"], + kwargs["rank"], + kwargs["dtype"], + kwargs["device"], + kwargs["group_name_inner"], + kwargs["allocate_memory_from_process_group"], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kwargs["group_size"], kwargs["group_name"], ], @@ -1840,7 +2015,11 @@ def _create_group_node(snodes_to_group): mutating_buf = next(iter(ag_group_node.get_buffer_names())) for o in prev_ag_wait.get_outputs(): ag_group_node.add_fake_dep( +<<<<<<< HEAD WeakDep(o.get_name(), mutating_buf=mutating_buf, is_fake=True) +======= + WeakDep(o.get_name(), mutating_buf=mutating_buf) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) prev_ag_wait = wait_group_node @@ -1852,7 +2031,11 @@ def _create_group_node(snodes_to_group): mutating_buf = next(iter(rs_group_node.get_buffer_names())) for o in prev_rs_wait.get_outputs(): rs_group_node.add_fake_dep( +<<<<<<< HEAD WeakDep(o.get_name(), mutating_buf=mutating_buf, is_fake=True) +======= + WeakDep(o.get_name(), mutating_buf=mutating_buf) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) prev_rs_wait = wait_group_node diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 9e46613300456..fa86b19993aa5 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -1,7 +1,10 @@ from __future__ import annotations import contextlib +<<<<<<< HEAD import copy +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import enum import functools import io @@ -15,7 +18,10 @@ from abc import ABC, abstractmethod from collections import defaultdict from contextlib import AbstractContextManager +<<<<<<< HEAD from dataclasses import dataclass +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from inspect import currentframe from itertools import count from operator import attrgetter @@ -23,7 +29,11 @@ from typing_extensions import Never, override, ParamSpec, Protocol, TypedDict, Unpack from unittest import mock +<<<<<<< HEAD import torch._inductor.async_compile +======= +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch.fx import torch.utils._pytree as pytree from functorch.compile import min_cut_rematerialization_partition @@ -54,7 +64,10 @@ ) from torch._functorch.aot_autograd import ( aot_export_module, +<<<<<<< HEAD GraphOutputName, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) make_boxed_func, SerializableAOTDispatchCompiler, ) @@ -65,11 +78,15 @@ log_cudagraph_skip_and_bump_counter, PlaceholderInfo, ) +<<<<<<< HEAD from torch._inductor.custom_graph_pass import CustomPartitionerFn from torch._inductor.debug import ( create_mapping_pre_post_grad_nodes, save_args_for_compile_fx_inner, ) +======= +from torch._inductor.debug import save_args_for_compile_fx_inner +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.output_code import ( CompiledAOTI, CompiledFxGraph, @@ -155,8 +172,11 @@ def log_optimus_to_scuba(*args: object, **kwargs: object) -> None: from torch._inductor.fb.utils import log_optimus_to_scuba, time_and_log if TYPE_CHECKING: +<<<<<<< HEAD import types +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._functorch._aot_autograd.schemas import ( FQN, GraphInputName, @@ -173,6 +193,7 @@ class FxCompileMode(enum.Enum): SUBPROCESS = 2 +<<<<<<< HEAD @dataclass class FxCompileConfig: mode: FxCompileMode @@ -192,13 +213,27 @@ def _fx_compile_mode_default() -> FxCompileConfig: if value.lower().startswith("progressive+"): use_progressive = True value = value[12:] +======= +# Return compile mode and use_async flag +def _fx_compile_mode_default() -> tuple[FxCompileMode, bool]: + name = "TORCHINDUCTOR_FX_COMPILE_MODE" + value = os.environ.get(name) + if value is None: + return FxCompileMode.NORMAL, False + + use_async = False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if value.lower().startswith("async+"): use_async = True value = value[6:] try: value = value.upper() +<<<<<<< HEAD return FxCompileConfig(FxCompileMode[value], use_async, use_progressive) +======= + return FxCompileMode[value], use_async +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) except KeyError: import logging @@ -211,6 +246,7 @@ def _fx_compile_mode_default() -> FxCompileConfig: ) # Remove from the environment so subprocesses don't ALSO complain. os.environ.pop(name) +<<<<<<< HEAD return FxCompileConfig(FxCompileMode.NORMAL, False, False) @@ -225,6 +261,12 @@ def _get_progression_configs() -> list[dict[str, Any]]: fx_compile_mode = _fx_compile_config.mode fx_compile_async = _fx_compile_config.use_async fx_compile_progressive = _fx_compile_config.use_progressive +======= + return FxCompileMode.NORMAL, False + + +fx_compile_mode, fx_compile_async = _fx_compile_mode_default() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) log = logging.getLogger(__name__) perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") @@ -363,6 +405,7 @@ def find_smallest_i(graph: fx.Graph, prefix: str) -> int: continue gm_target = attrgetter(target_name)(gm) model_target = attrgetter(target_name)(mod) +<<<<<<< HEAD if isinstance(gm_target, FakeScriptObject): if ( isinstance(model_target, FakeScriptObject) @@ -370,6 +413,9 @@ def find_smallest_i(graph: fx.Graph, prefix: str) -> int: ): continue elif ( +======= + if ( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.equal(gm_target, model_target) and gm_target.dtype == model_target.dtype ): @@ -437,7 +483,11 @@ def _unlift_graph( from torch.export._unlift import _unlift +<<<<<<< HEAD outputs: tuple[torch.fx.Node, ...] = tuple(gm.graph.output_node().args[0]) # type: ignore[arg-type] +======= + outputs = list(gm.graph.nodes)[-1].args[0] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mutated_outputs = [] buffer_mutations = graph_signature.buffers_to_mutate user_input_mutations = graph_signature.user_inputs_to_mutate @@ -446,11 +496,18 @@ def _unlift_graph( value: Optional[Union[FQN, GraphInputName]] = None if idx < len(buffer_mutations) + len(user_input_mutations) + len(output_tokens): +<<<<<<< HEAD name = GraphOutputName(out.name) if name in buffer_mutations: value = buffer_mutations[name] elif name in user_input_mutations: value = user_input_mutations[name] +======= + if out.name in buffer_mutations: + value = buffer_mutations[out.name] + elif out.name in user_input_mutations: + value = user_input_mutations[out.name] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mutated_outputs.append(value) @@ -460,6 +517,11 @@ def _unlift_graph( mutated_outputs, pytree.LeafSpec(), None, +<<<<<<< HEAD +======= + state_dict, + {}, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return unlifted_gm @@ -728,7 +790,10 @@ class _CompileFxKwargs(TypedDict, total=False): layout_opt: Optional[bool] extern_node_serializer: Optional[Callable[[list[ExternKernelNode]], Any]] boxed_forward_device_index: Optional[BoxedDeviceIndex] +<<<<<<< HEAD fx_wrapper: bool +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class _CompileFxCallable(Protocol): @@ -750,7 +815,10 @@ def compile_fx_inner( kwargs.setdefault("is_backward", False) kwargs.setdefault("graph_id", None) kwargs.setdefault("cpp_wrapper", False) +<<<<<<< HEAD kwargs.setdefault("fx_wrapper", False) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kwargs.setdefault("is_inference", False) kwargs.setdefault("boxed_forward_device_index", None) kwargs.setdefault("layout_opt", None) @@ -846,9 +914,13 @@ def _compile_fx_inner( backends_support_caching = all( backend.supports_caching for backend in ( +<<<<<<< HEAD get_wrapper_codegen_for_device( device.type, config.cpp_wrapper, config.fx_wrapper ) +======= + get_wrapper_codegen_for_device(device.type, config.cpp_wrapper) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for device in get_all_devices(gm) ) if backend is not None @@ -862,7 +934,10 @@ def _compile_fx_inner( and (config.fx_graph_cache or fx_graph_remote_cache) and not aot_mode and backends_support_caching +<<<<<<< HEAD and not torch._functorch.config.bundled_autograd_cache +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) local = config.fx_graph_cache remote = fx_graph_remote_cache @@ -920,6 +995,7 @@ def _compile_fx_inner( else: log.debug("Failed to generate FX cache key") +<<<<<<< HEAD if torch._functorch.config.bundled_autograd_cache: assert mb_compiled_graph is None assert cache_info is None @@ -951,6 +1027,12 @@ def _compile_fx_inner( # (this can happen either because cache was disabled, or we # determined the input is uncacheable) elif cache_info is None or cache_info["cache_state"] == "bypass": +======= + # CACHE BYPASS: Compile the graph, don't save it to the cache + # (this can happen either because cache was disabled, or we + # determined the input is uncacheable) + if cache_info is None or cache_info["cache_state"] == "bypass": +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert mb_compiled_graph is None log.debug( "FX cache bypass reason: %s", @@ -1066,6 +1148,34 @@ def _compile_fx_inner( log.debug("FX codegen and compilation took %.3fs", time.time() - start) +<<<<<<< HEAD +======= + # Dump provenance artifacts for debugging trace + provenance_info = V.debug.log_inductor_triton_kernel_to_post_grad_node_info() + # provenance_info might be None if config.trace.enabled is not set + if provenance_info: + ( + debug_info, + node_mappings, + ) = provenance_info + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "inductor_generated_kernel_to_post_grad_nodes", + "encoding": "json", + }, + payload_fn=lambda: json.dumps(debug_info), + ) + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "inductor_provenance_tracking_node_mappings", + "encoding": "json", + }, + payload_fn=lambda: json.dumps(node_mappings), + ) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This message is for printing overview information of inductor mm counts, shapes,etc after lowering if log.isEnabledFor(logging.INFO): mm_table_data = [] @@ -1172,7 +1282,10 @@ def codegen_and_compile( is_backward: bool = graph_kwargs.get("is_backward", False) graph_id: Optional[int] = graph_kwargs.get("graph_id", None) cpp_wrapper: bool = graph_kwargs.get("cpp_wrapper", False) +<<<<<<< HEAD fx_wrapper: bool = graph_kwargs.get("fx_wrapper", False) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) aot_mode: bool = V.aot_compilation is_inference: bool = graph_kwargs.get("is_inference", False) extern_node_serializer: Optional[Callable[[list[ExternKernelNode]], Any]] = ( @@ -1308,6 +1421,7 @@ def codegen_and_compile( }, payload_fn=lambda: inductor_post_grad_graph_str, ) +<<<<<<< HEAD if config.trace.provenance_tracking_level != 0: provenance_tracking_json = ( torch.fx.traceback.get_graph_provenance_json(gm.graph) @@ -1317,6 +1431,22 @@ def codegen_and_compile( torch._inductor.debug._pre_grad_graph_id, provenance_tracking_json, ) +======= + if config.trace.enabled: + provenance_tracking_json = ( + torch.fx.traceback.get_graph_provenance_json(gm.graph) + ) + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "inductor_post_to_pre_grad_nodes", + "encoding": "json", + }, + payload_fn=lambda: json.dumps(provenance_tracking_json), + ) + torch._inductor.debug._inductor_post_to_pre_grad_nodes = ( + provenance_tracking_json +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) metrics_context = get_metrics_context() @@ -1375,12 +1505,17 @@ def codegen_and_compile( is_inference=is_inference, is_backward=is_backward, is_const_graph=True, +<<<<<<< HEAD fx_wrapper=fx_wrapper, ) with ( V.set_graph_handler(const_graph), V.set_extern_kernel_nodes([]), ): +======= + ) + with V.set_graph_handler(const_graph): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert cpp_wrapper, "AOT mode only supports C++ wrapper" const_graph.run() const_wrapper_code, const_kernel_code = ( @@ -1409,14 +1544,21 @@ def codegen_and_compile( ), const_module=const_graph, inputs_to_check=inputs_to_check, +<<<<<<< HEAD fx_wrapper=fx_wrapper, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) metrics_helper = metrics.CachedMetricsHelper() # We are going to start code generating runtime asserts, so make sure # you don't start adding new ones in the lowering process graph.freeze_runtime_asserts() +<<<<<<< HEAD with V.set_graph_handler(graph), V.set_extern_kernel_nodes([]): +======= + with V.set_graph_handler(graph): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) graph.run(*example_inputs) output_strides: list[Optional[tuple[_StrideExprStr, ...]]] = [] if graph.graph_outputs is not None: @@ -1447,6 +1589,7 @@ def codegen_and_compile( with dynamo_timed( "GraphLowering.compile_to_fn", log_pt2_compile_event=True ): +<<<<<<< HEAD if graph.aot_mode and graph.fx_wrapper: assert not graph.cpp_wrapper compiled_fn = graph.codegen()[0].gm # type: ignore[attr-defined] @@ -1456,6 +1599,9 @@ def codegen_and_compile( ) elif graph.aot_mode: +======= + if graph.aot_mode: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .codecache import AotCodeCompiler assert graph.cpp_wrapper, ( @@ -1471,9 +1617,17 @@ def codegen_and_compile( ) serialized_extern_kernel_nodes = None +<<<<<<< HEAD if V.extern_kernel_nodes: serialized_extern_kernel_nodes = ( graph.extern_node_serializer(V.extern_kernel_nodes) +======= + if graph.extern_kernel_nodes: + serialized_extern_kernel_nodes = ( + graph.extern_node_serializer( + graph.extern_kernel_nodes + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) output_code_log.debug( "Serialized Extern Kernel Nodes: \n%s", @@ -1508,6 +1662,7 @@ def codegen_and_compile( compiled_module, "runner", None ) +<<<<<<< HEAD # Dump provenance artifacts for debugging trace inductor_provenance_tracking_node_mappings = None inductor_kernel_stack_trace_str = None @@ -1536,6 +1691,8 @@ def codegen_and_compile( ) node_runtimes = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if inductor_metrics_log.isEnabledFor(logging.INFO): num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes() metrics.num_bytes_accessed += num_bytes @@ -1550,6 +1707,7 @@ def codegen_and_compile( }, ) +<<<<<<< HEAD # Collect and dump op runtimes and tensor metadata for TLParse if config.log_tlparse: _, _, node_runtimes = graph.count_bytes() @@ -1558,6 +1716,8 @@ def codegen_and_compile( # Collect and dump collective-op schedule for external diagnostics torch._inductor.debug.log_collective_schedule(graph.scheduler.nodes) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if ( cudagraphs and config.triton.cudagraph_skip_dynamic_graphs @@ -1594,9 +1754,13 @@ def codegen_and_compile( V.graph.disable_cudagraphs_reason = disable if V.aot_compilation: +<<<<<<< HEAD assert isinstance( compiled_fn, (str, list, torch.fx.GraphModule) ), type(compiled_fn) +======= + assert isinstance(compiled_fn, (str, list)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return CompiledAOTI(compiled_fn) # TODO: Hoist this above V.aot_compilation @@ -1613,6 +1777,7 @@ def codegen_and_compile( self._compile_stats[type(self)].codegen_and_compile += 1 +<<<<<<< HEAD if ( torch._inductor.debug.RECORD_GRAPH_EXECUTION and torch._inductor.debug.GRAPH_COMPILE_IDS is not None @@ -1626,6 +1791,8 @@ def codegen_and_compile( compile_id ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return CompiledFxGraph( compiled_fn, graph, @@ -1642,8 +1809,11 @@ def codegen_and_compile( runnable_graph_str, inductor_post_grad_graph_str, compiled_fn_runner, +<<<<<<< HEAD inductor_provenance_tracking_node_mappings, inductor_kernel_stack_trace_str, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @@ -1677,6 +1847,7 @@ def fx_codegen_and_compile( ) scheme = _AsyncFxCompile(scheme) +<<<<<<< HEAD if fx_compile_progressive: from .compile_fx_async import _ProgressiveFxCompile from .compile_fx_ext import _OutOfProcessFxCompile @@ -1692,6 +1863,8 @@ def fx_codegen_and_compile( scheme = _ProgressiveFxCompile(fast_scheme, scheme, progression_configs) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs) @@ -1891,18 +2064,31 @@ def compile_fx_aot( model_: GraphModule, example_inputs_: list[InputType], inner_compile: _CompileFxCallable = compile_fx_inner, +<<<<<<< HEAD config_patches: Optional[dict[str, Any]] = None, ) -> Union[list[Union[str, Weights]], str, GraphModule]: +======= + config_patches: Optional[dict[str, str]] = None, +) -> Union[list[Union[str, Weights]], str]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(model_, GraphModule), model_ # [See NOTE] Unwrapping subclasses AOT unwrap_tensor_subclass_parameters(model_) +<<<<<<< HEAD config_patches: dict[str, Any] = copy.deepcopy(config_patches or {}) if not (config_patches.get("fx_wrapper", False) or config.fx_wrapper): # If fx_wrapper is not set, then set cpp_wrapper config_patches["cpp_wrapper"] = True +======= + config_patches: dict[str, Any] = ( + {"cpp_wrapper": True} + if config_patches is None + else {**config_patches, "cpp_wrapper": True} + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) output_path = config_patches.get( "aot_inductor.output_path", config.aot_inductor.output_path @@ -1921,10 +2107,13 @@ def compile_fx_aot( "aot_inductor.output_path": code_hash(model_.code), } +<<<<<<< HEAD from .utils import maybe_aoti_standalone_config config_patches = maybe_aoti_standalone_config(config_patches) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) extern_node_serializer = config_patches.pop("extern_node_serializer", None) saved_compile_id = model_.meta.get("dynamo_compile_id", None) saved_compile_context = torch._guards.CompileContext(saved_compile_id) @@ -1994,7 +2183,11 @@ def fw_compiler_freezing( idx for idx, n in enumerate(model_outputs) if isinstance(n, torch.fx.Node) ] +<<<<<<< HEAD static_input_idxs: list[Any] = [] +======= + static_input_idxs = [] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # constant params will be real tensors, not fake tracing_context = torch._guards.TracingContext.try_get() unwrapped_args_offsets = [0] @@ -2095,6 +2288,7 @@ def get_cuda_device_context(gm: torch.fx.GraphModule) -> AbstractContextManager[ ) +<<<<<<< HEAD def partition_fn( gm: GraphModule, joint_inputs: Sequence[object], @@ -2379,6 +2573,8 @@ def run_pre_grad_passes( return model_ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def compile_fx( model_: GraphModule, example_inputs_: Sequence[InputType], @@ -2398,12 +2594,15 @@ def compile_fx( NB: This function TAKES OWNERSHIP of the input ``model_`` and can potentially mutate it! Make a copy if you need to preserve the original GraphModule. """ +<<<<<<< HEAD # Wake up the AsyncCompile subproc pool as early as possible (if there's cuda). if any( isinstance(e, torch.Tensor) and e.device.type in ("cuda", "xpu") for e in example_inputs_ ): torch._inductor.async_compile.AsyncCompile.wakeup() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Some arguments trigger a recursive call to compile_fx. Handle these # short circuits first, before anything else @@ -2420,15 +2619,22 @@ def compile_fx( ) # TODO: This probably shouldn't be a recursive call +<<<<<<< HEAD if config.cpp_wrapper or config.fx_wrapper: cpp_wrapper_config = config.cpp_wrapper fx_wrapper_config = config.fx_wrapper +======= + if config.cpp_wrapper: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with ( config.patch( { "cpp_wrapper": False, # reset to break recursive call to compile_fx +<<<<<<< HEAD "fx_wrapper": False, # reset to break recursive call to compile_fx +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) **get_cpp_wrapper_config(), } ), @@ -2474,11 +2680,15 @@ def compile_fx( return compile_fx( patched_mod, fake_args, +<<<<<<< HEAD inner_compile=functools.partial( inner_compile, cpp_wrapper=cpp_wrapper_config, fx_wrapper=fx_wrapper_config, ), +======= + inner_compile=functools.partial(inner_compile, cpp_wrapper=True), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) decompositions=decompositions, ignore_shape_env=ignore_shape_env, ) @@ -2512,10 +2722,14 @@ def compile_fx( with ( _use_lazy_graph_module(dynamo_config.use_lazy_graph_module), enable_python_dispatcher(), +<<<<<<< HEAD torch.fx.traceback.preserve_node_meta( config.trace.provenance_tracking_level == 1 ), torch._inductor.debug.reset_provenance_globals(), +======= + torch.fx.traceback.preserve_node_meta(config.trace.enabled), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): # Pre-grad passes cannot be run if we weren't given a GraphModule. # Dynamo will always produce a GraphModule, but this handles cases @@ -2523,7 +2737,47 @@ def compile_fx( # having AOTAutograd trace it. # TODO: Get rid of this? if isinstance(model_, GraphModule): +<<<<<<< HEAD model_ = run_pre_grad_passes(model_, example_inputs_) +======= + # "before_pre_grad_graph" is used in inductor provenance + # tracking highlighter front-end. + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "before_pre_grad_graph", + "encoding": "string", + }, + payload_fn=lambda: model_.print_readable( + print_output=False, include_stride=True, include_device=True + ) + + f"\n\n # graph id: {id(model_.graph)}", + ) + pre_grad_graphs_log.debug( + "%s", + lazy_format_graph_code( + "BEFORE PRE GRAD", + model_, + include_stride=True, + include_device=True, + colored=True, + ), + ) + torch._inductor.debug._pre_grad_graph_id = id(model_.graph) + + model_ = _recursive_pre_grad_passes(model_, example_inputs_) + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "after_pre_grad_graph", + "encoding": "string", + }, + payload_fn=lambda: model_.print_readable( + print_output=False, include_stride=True, include_device=True + ) + + f"\n\n # graph id: {id(model_.graph)}", + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO: Move this before recursive pre-grad passes # NB: This short circuit never occurs for Dynamo produced graphs @@ -2539,7 +2793,24 @@ def compile_fx( num_example_inputs = len(example_inputs_) +<<<<<<< HEAD compiler_config_extra = create_compiler_config_extra(config) +======= + # Although cudagraphs may have been enabled via config, various + # conditions (which are tested within the bowels of Inductor) may + # force cudagraphs to be disabled. This mutable box lets us retrieve + # the final determination if cudagraphs actually can be used or not. + cudagraphs = BoxedBool(config.triton.cudagraphs) + + # See [Backward Generation Handling] + forward_device = BoxedDeviceIndex(None) + + # TODO: The modern style is to use CompileId from TracingContext to + # identify Inductor compilation. However, this CompileId cannot + # uniquely identify multiple Inductor compilations that arise from + # DDPOptimizer + graph_id = next(_graph_counter) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) decompositions = ( decompositions if decompositions is not None else select_decomp_table() @@ -2551,6 +2822,7 @@ def fw_compiler_base( is_inference: bool, ) -> OutputCode: with dynamo_utils.dynamo_timed("compile_fx..fw_compiler_base"): +<<<<<<< HEAD if isinstance(model_, GraphModule): num_orig_model_outputs = get_num_model_outputs(model_) else: @@ -2563,6 +2835,85 @@ def fw_compiler_base( compiler_config_extra=compiler_config_extra, inner_compile=inner_compile, is_inference=is_inference, +======= + if is_inference: + # partition_fn won't be called + _recursive_joint_graph_passes(gm) + + fixed = torch._inductor.utils.num_fw_fixed_arguments( + num_example_inputs, len(example_inputs) + ) + + model_outputs_node = output_node(gm) + if config.keep_output_stride: + model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args) + num_model_outputs = len(model_outputs) + + context = torch._guards.TracingContext.try_get() + # See Note [User Outputs in the inductor graph] + if context is not None and context.fw_metadata and not is_inference: + original_output_start_index = ( + context.fw_metadata.num_mutated_inp_runtime_indices + ) + else: + original_output_start_index = 0 + + if isinstance(model_, GraphModule): + *_, orig_model_outputs_node = model_.graph.nodes + assert orig_model_outputs_node.op == "output" + orig_model_outputs, _ = pytree.tree_flatten( + orig_model_outputs_node.args + ) + num_orig_model_outputs = len(orig_model_outputs) + else: + num_orig_model_outputs = num_model_outputs + + assert num_orig_model_outputs <= num_model_outputs + + # Note [User Outputs in the inductor graph] + # We makes the following assumption + # For inference + # len(orig_model_outputs) == len(model_outputs) + # For training + # len(orig_model_outputs) <= len(model_outputs) + # During training, most of the time the model_outputs starts with + # original module's outputs followed by saved activations. + # But this can be not true if the model have inplace updated tensors. + # AOTAutograd will make those tensors being returned before the original + # module's output. + # To make things safe, we'll use original_output_start_index field + # set by AOTAutograd to decide where the original module outputs start. + orig_output_end_idx = ( + original_output_start_index + num_orig_model_outputs + ) + # Sanity check: we are about to splice out the "user" outputs from the full set + # of "graph" outputs. Make sure we're within bounds. + assert orig_output_end_idx <= num_model_outputs + + model_outputs_node.meta["user_visible_output_idxs"] = [ + idx + for idx in range( + original_output_start_index, orig_output_end_idx + ) + if isinstance(model_outputs[idx], torch.fx.Node) + ] + else: + model_outputs_node.meta["user_visible_output_idxs"] = [] + + # We also mark the invoke_subgraph outputs as user_visible to + # force the outputs of invoke_subgraph subgraph to follow the + # original strides + _recursive_record_user_visible_output_idxs(gm) + + return inner_compile( + gm, + example_inputs, + static_input_idxs=get_static_input_idxs(fixed), + cudagraphs=cudagraphs, + graph_id=graph_id, + is_inference=is_inference, + boxed_forward_device_index=forward_device, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) fw_compiler: Callable[[GraphModule, Sequence[InputType]], OutputCode] = ( @@ -2576,9 +2927,15 @@ def fw_compiler_base( dynamo_model=model_, num_example_inputs=num_example_inputs, inner_compile=inner_compile, +<<<<<<< HEAD cudagraphs=compiler_config_extra.cudagraphs, graph_id=compiler_config_extra.graph_id, forward_device=compiler_config_extra.forward_device, +======= + cudagraphs=cudagraphs, + graph_id=graph_id, + forward_device=forward_device, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) else: inference_compiler = functools.partial(fw_compiler_base, is_inference=True) @@ -2586,10 +2943,41 @@ def fw_compiler_base( OutputCode, inference_compiler ) +<<<<<<< HEAD +======= + def partition_fn( + gm: GraphModule, + joint_inputs: Sequence[object], + **kwargs: object, + ) -> tuple[GraphModule, GraphModule]: + cuda_context = get_cuda_device_context(gm) + with cuda_context: + # We can skip the invoke_subgraph because the + # entire_partition_fn is called recursively for invoke_subgraph + # in partitioning. + _recursive_joint_graph_passes(gm, skip_invoke_subgraph=True) + + static_lifetime_input_indices: Optional[list[int]] = kwargs.pop( # type: ignore[assignment] + "static_lifetime_input_indices", None + ) + + with dynamo_utils.dynamo_timed( + "min_cut_rematerialization_partition", log_pt2_compile_event=True + ): + return min_cut_rematerialization_partition( + gm, + joint_inputs, + compiler="inductor", + static_lifetime_input_indices=static_lifetime_input_indices, + **kwargs, + ) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @compile_time_strobelight_meta(phase_name="backward") def bw_compiler( gm: GraphModule, example_inputs: Sequence[InputType] ) -> OutputCode: +<<<<<<< HEAD with ( dynamo_utils.dynamo_timed("compile_fx..bw_compiler"), ): @@ -2599,6 +2987,40 @@ def bw_compiler( compiler_config_extra=compiler_config_extra, inner_compile=inner_compile, ) +======= + from torch._dynamo.convert_frame import compile_lock + + with ( + dynamo_utils.dynamo_timed("compile_fx..bw_compiler"), + compile_lock, + ): + model_outputs_node = output_node(gm) + if config.bw_outputs_user_visible: + model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args) + model_outputs_node.meta["user_visible_output_idxs"] = [ + idx + for idx, n in enumerate(model_outputs) + if isinstance(n, torch.fx.Node) + ] + else: + model_outputs_node.meta["user_visible_output_idxs"] = [] + + fixed = count_tangents(gm) + with ( + config.patch(get_cpp_wrapper_config()) + if config.cpp_wrapper + else contextlib.nullcontext() + ): + return inner_compile( + gm, + example_inputs, + static_input_idxs=list(range(fixed)), + cudagraphs=cudagraphs, + is_backward=True, + graph_id=graph_id, + boxed_forward_device_index=forward_device, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bw_compiler = SerializableAOTDispatchCompiler(OutputCode, bw_compiler) @@ -2611,10 +3033,13 @@ def bw_compiler( ) if V.aot_compilation: +<<<<<<< HEAD from .utils import is_valid_aoti_model_name is_valid_aoti_model_name() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with functorch_config.patch(unlift_effect_tokens=True): gm, graph_signature = aot_export_module( model_, @@ -2636,7 +3061,10 @@ def bw_compiler( if node.op == "get_attr" and "val" not in node.meta: target = attrgetter(node.target)(gm) if isinstance(target, torch.Tensor): +<<<<<<< HEAD assert fake_mode is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) node.meta["val"] = fake_mode.from_tensor( target, static_shapes=True ) @@ -2685,8 +3113,13 @@ def bw_compiler( decompositions=decompositions, partition_fn=partition_fn, keep_inference_input_mutations=True, +<<<<<<< HEAD cudagraphs=compiler_config_extra.cudagraphs, boxed_forward_device_index=compiler_config_extra.forward_device, +======= + cudagraphs=cudagraphs, + boxed_forward_device_index=forward_device, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ignore_shape_env=ignore_shape_env, )(model_, example_inputs_) except ShortenTraceback as e: diff --git a/torch/_inductor/compile_fx_async.py b/torch/_inductor/compile_fx_async.py index 05c896ae86448..07da98379e52f 100644 --- a/torch/_inductor/compile_fx_async.py +++ b/torch/_inductor/compile_fx_async.py @@ -1,6 +1,9 @@ from __future__ import annotations +<<<<<<< HEAD from collections import deque +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from dataclasses import dataclass from typing import Any, Callable, Optional, TYPE_CHECKING from typing_extensions import final, override @@ -12,10 +15,13 @@ from .output_code import complex_memory_overlap as complex_memory_overlap # noqa: F401 +<<<<<<< HEAD # When async compile works with cache, remove the disabling below BUG_CACHES_DONT_WORK_WITH_ASYNC = True +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if TYPE_CHECKING: from collections.abc import Sequence from concurrent.futures import Future @@ -33,6 +39,7 @@ class _PostCompileData: graph_kwargs: _CompileFxKwargs +<<<<<<< HEAD @dataclass class ProgressiveCompilationState: progression_futures: deque[Future[_WireProtocolPickledOutput]] @@ -75,11 +82,17 @@ def switch_to_progression_stage(self, stage_index: int) -> tuple[OutputCode, boo return optimized_output_code, should_clear_state +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # _AsyncOutputCode handles the actual management of waiting for an # out-of-process compile to finish and then switching over to it. @final class _AsyncOutputCode(OutputCode): +<<<<<<< HEAD _eager_fn: Optional[Callable[..., Any]] +======= + _eager_forward: Optional[Callable[..., Any]] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _output_code: Optional[OutputCode] _future: Optional[Future[_WireProtocolPickledOutput]] _callback: Callable[[_WireProtocolPickledOutput], OutputCode] @@ -88,16 +101,26 @@ class _AsyncOutputCode(OutputCode): def __init__( self, +<<<<<<< HEAD # eager_fn is run until the future is finished. eager_fn: Callable[..., Any], +======= + # eager_forward is run until the future is finished. + eager_forward: Callable[..., Any], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # this responds with the result of the out-of-process compile when it's # ready. future: Future[_WireProtocolPickledOutput], # this callback gets called to turn the _WireProtocolPickledOutput into an OutputCode callback: Callable[[_WireProtocolPickledOutput], OutputCode], ) -> None: +<<<<<<< HEAD self._eager_fn = eager_fn self._boxed_call = getattr(eager_fn, "_boxed_call", False) +======= + self._eager_forward = eager_forward + self._boxed_call = getattr(eager_forward, "_boxed_call", False) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._output_code = None self._future = future @@ -106,11 +129,19 @@ def __init__( @override def __call__(self, *args: Any) -> Any: if self._future is not None and self._future.done(): +<<<<<<< HEAD args = self._switch_to_compiled_fn(args) if eager_fn := self._eager_fn: _AsyncFxCompile._stat_eager_runs += 1 return eager_fn(*args) +======= + args = self._switch_to_compiled_forward(args) + + if eager_forward := self._eager_forward: + _AsyncFxCompile._stat_eager_runs += 1 + return eager_forward(*args) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: _AsyncFxCompile._stat_compiled_runs += 1 @@ -118,7 +149,11 @@ def __call__(self, *args: Any) -> Any: return self._output_code.__call__(*args) # Takes and returns the args (converted to the "right" boxed mode) +<<<<<<< HEAD def _switch_to_compiled_fn(self, args: tuple[Any, ...]) -> tuple[Any, ...]: +======= + def _switch_to_compiled_forward(self, args: tuple[Any, ...]) -> tuple[Any, ...]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert self._future is not None # TODO: If the future ended in an exception do we want to continue @@ -134,7 +169,11 @@ def _switch_to_compiled_fn(self, args: tuple[Any, ...]) -> tuple[Any, ...]: ) self._output_code = output_code +<<<<<<< HEAD self._eager_fn = None +======= + self._eager_forward = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) boxed_call = getattr(output_code, "_boxed_call", False) if self._boxed_call != boxed_call: @@ -155,7 +194,11 @@ def post_compile( constants: CompiledFxGraphConstants, graph_kwargs: _CompileFxKwargs, ) -> None: +<<<<<<< HEAD if self._eager_fn is not None: +======= + if self._eager_forward is not None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._post_compile_data = _PostCompileData( example_inputs, constants, graph_kwargs ) @@ -218,7 +261,11 @@ def codegen_and_compile( _AsyncFxCompile._stat_bg_started += 1 f = self._compile._send_to_child_async(inputs) +<<<<<<< HEAD # This is called by _switch_to_compiled_fn() when f has a result... +======= + # This is called by _switch_to_compiled_forward() when f has a result... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def callback(pickled_output: _WireProtocolPickledOutput) -> OutputCode: _AsyncFxCompile._stat_bg_finished += 1 output = pickled_output.deserialize(constants) @@ -226,6 +273,7 @@ def callback(pickled_output: _WireProtocolPickledOutput) -> OutputCode: return output.graph return _AsyncOutputCode(eager_output_code, f, callback) +<<<<<<< HEAD # _ProgressiveOutputCode handles running a fast compile first, then hot-swapping @@ -396,3 +444,5 @@ def callback(pickled_output: _WireProtocolPickledOutput) -> OutputCode: return output.graph return _ProgressiveOutputCode(fast_output_code, progression_futures, callback) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_inductor/compile_worker/subproc_pool.py b/torch/_inductor/compile_worker/subproc_pool.py index 6342fc7e0fcd7..c5799ce9b42a2 100644 --- a/torch/_inductor/compile_worker/subproc_pool.py +++ b/torch/_inductor/compile_worker/subproc_pool.py @@ -13,7 +13,11 @@ import typing from concurrent.futures import Future, ProcessPoolExecutor from concurrent.futures.process import BrokenProcessPool +<<<<<<< HEAD from enum import Enum, IntEnum +======= +from enum import Enum +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from typing import Any, Callable, IO, Optional, TypeVar from typing_extensions import Never, ParamSpec @@ -27,8 +31,12 @@ TrackedProcessPoolExecutor, ) from torch._inductor.compile_worker.utils import _async_compile_initializer +<<<<<<< HEAD from torch._inductor.utils import get_ld_library_path, python_subprocess_env from torch._utils_internal import find_compile_subproc_binary +======= +from torch._inductor.utils import get_ld_library_path +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) log = logging.getLogger(__name__) @@ -37,6 +45,7 @@ _T = TypeVar("_T") +<<<<<<< HEAD class MsgHeader(IntEnum): ERROR = 0 SHUTDOWN = 1 @@ -73,6 +82,33 @@ def _recv_msg(read_pipe: IO[bytes]) -> tuple[MsgHeader, int, bytes]: msg_header, job_id, length = _unpack_msg(read_pipe.read(msg_bytes)) data = read_pipe.read(length) if length > 0 else b"" return msg_header, job_id, data +======= +def _pack_msg(job_id: int, length: int) -> bytes: + return struct.pack("nn", job_id, length) + + +def _unpack_msg(data: bytes) -> tuple[int, int]: + if not data: + return -1, -1 + return struct.unpack("nn", data) + + +msg_bytes = len(_pack_msg(0, 0)) + + +def _send_msg(write_pipe: IO[bytes], job_id: int, job_data: bytes = b"") -> None: + length = len(job_data) + write_pipe.write(_pack_msg(job_id, length)) + if length > 0: + write_pipe.write(job_data) + write_pipe.flush() + + +def _recv_msg(read_pipe: IO[bytes]) -> tuple[int, bytes]: + job_id, length = _unpack_msg(read_pipe.read(msg_bytes)) + data = read_pipe.read(length) if length > 0 else b"" + return job_id, data +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class _SubprocExceptionInfo: @@ -91,6 +127,7 @@ class SubprocException(Exception): Thrown when a job in a subprocess raises an Exception. """ +<<<<<<< HEAD def __init__(self, details: str, name: str = "") -> None: self.details = details super().__init__( @@ -99,6 +136,10 @@ def __init__(self, details: str, name: str = "") -> None: def with_name(self, name: str) -> "SubprocException": return SubprocException(self.details, name) +======= + def __init__(self, details: str) -> None: + super().__init__(f"An exception occurred in a subprocess:\n\n{details}") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class SubprocPickler: @@ -144,11 +185,14 @@ def __init__( cmd = [ sys.executable, entry, +<<<<<<< HEAD ] if (binary := find_compile_subproc_binary()) is not None: cmd = [binary] args = [ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f"--pickler={self.pickler.__class__.__module__}.{self.pickler.__class__.__name__}", f"--kind={self.kind.value}", f"--workers={nprocs}", @@ -157,6 +201,7 @@ def __init__( f"--write-fd={str(subproc_write_fd)}", f"--torch-key={torch_key_str}", ] +<<<<<<< HEAD cmd.extend(args) log_path = None self.log_file = None @@ -171,17 +216,35 @@ def __init__( if log_path: self.log_file = open(log_path, "w") +======= + local = False + if config.worker_suppress_logging: + log.info("Suppressing compile worker output due to config") + local = True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.process = subprocess.Popen( cmd, env={ +<<<<<<< HEAD **python_subprocess_env(), # Safeguard against creating a SubprocPool in the subprocess. +======= + **os.environ, + # We need to set the PYTHONPATH so the subprocess can find torch. + "PYTHONPATH": os.environ.get( + "TORCH_CUSTOM_PYTHONPATH", os.pathsep.join(sys.path) + ), + # We don't want to re-warm the pool when the subprocess imports + # torch._inductor.codecache since the warming process is what + # creates the SubprocPool in the first place. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "TORCH_WARM_POOL": "0", # Some internal usages need a modified LD_LIBRARY_PATH. "LD_LIBRARY_PATH": get_ld_library_path(), }, pass_fds=(subproc_read_fd, subproc_write_fd), +<<<<<<< HEAD stdout=self.log_file, stderr=self.log_file, ) @@ -189,6 +252,13 @@ def __init__( self.read_thread = threading.Thread( target=self._read_thread, name="InductorSubproc", daemon=True ) +======= + stdout=subprocess.DEVNULL if local else None, + stderr=subprocess.DEVNULL if local else None, + ) + self.write_lock = threading.Lock() + self.read_thread = threading.Thread(target=self._read_thread, daemon=True) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.futures_lock = threading.Lock() self.pending_futures: dict[int, Future[Any]] = {} @@ -211,6 +281,7 @@ def submit( job_id = next(self.job_id_count) self.pending_futures[job_id] = future = Future() future.set_running_or_notify_cancel() +<<<<<<< HEAD self._send(MsgHeader.JOB, job_id, job_data) return future @@ -219,10 +290,18 @@ def _send(self, msg_header: MsgHeader, job_id: int = -1, data: bytes = b"") -> N if not self.running: raise RuntimeError("Attempting to use a closed pool") _send_msg(self.write_pipe, msg_header, job_id, data) +======= + with self.write_lock: + if not self.running: + raise RuntimeError("submit() on closed pool") + _send_msg(self.write_pipe, job_id, job_data) + return future +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _read_thread(self) -> None: while True: data = b"" +<<<<<<< HEAD job_id = -1 try: msg_header, job_id, data = _recv_msg(self.read_pipe) @@ -233,6 +312,17 @@ def _read_thread(self) -> None: msg_header = MsgHeader.ERROR if msg_header != MsgHeader.JOB: +======= + try: + job_id, data = _recv_msg(self.read_pipe) + except Exception: + # Something went wrong during the read. There's no way we have a + # valid job_id. + log.exception("failure in subproc_pool._recv_msg") + job_id = -1 + + if job_id < 0: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # read_pipe returned None or got exception if self.running: log.warning("SubprocPool unclean exit") @@ -265,23 +355,32 @@ def _read_thread(self) -> None: self.pending_futures[job_id].set_result(result) del self.pending_futures[job_id] +<<<<<<< HEAD def quiesce(self) -> None: self._send(MsgHeader.QUIESCE) def wakeup(self) -> None: self._send(MsgHeader.WAKEUP) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def shutdown(self) -> None: try: with self.write_lock: if not self.running: return self.running = False +<<<<<<< HEAD _send_msg(self.write_pipe, MsgHeader.SHUTDOWN) self.write_pipe.close() self.process.wait(300) if self.log_file: self.log_file.close() +======= + _send_msg(self.write_pipe, -1) + self.write_pipe.close() + self.process.wait(300) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) except OSError as e: log.warning("Ignored OSError in pool shutdown: %s", e) finally: @@ -309,6 +408,7 @@ def __init__( self.write_pipe = write_pipe self.write_lock = threading.Lock() self.nprocs = nprocs +<<<<<<< HEAD self.pool: Optional[ProcessPoolExecutor] = None self.running = True @@ -328,17 +428,47 @@ def _quiesce(self) -> None: if self.pool is not None: self.pool.shutdown(wait=False) self.pool = None +======= + self.pool = self._new_pool(nprocs, True) + self.running = True + + def _new_pool(self, nprocs: int, warm: bool) -> ProcessPoolExecutor: + pool = TrackedProcessPoolExecutor( + nprocs, + mp_context=multiprocessing.get_context(self.kind.value), + initializer=functools.partial(_async_compile_initializer, os.getpid()), + ) + multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize) + if warm: + _warm_process_pool(pool, nprocs) + return pool + + def main(self) -> None: + while True: + job_id, data = _recv_msg(self.read_pipe) + if job_id < 0: + return self._shutdown() + self.submit(job_id, data) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _shutdown(self) -> None: with self.write_lock: self.running = False try: +<<<<<<< HEAD _send_msg(self.write_pipe, MsgHeader.SHUTDOWN) +======= + _send_msg(self.write_pipe, -1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.write_pipe.close() except BrokenPipeError: pass # parent process already shutdown self.read_pipe.close() +<<<<<<< HEAD self._quiesce() +======= + self.pool.shutdown() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def submit(self, job_id: int, data: bytes) -> None: while self.running: @@ -349,7 +479,11 @@ def submit(self, job_id: int, data: bytes) -> None: # If any subprocess in the pool crashes, we get a BrokenProcessPool # exception and the whole pool becomes unusable. Handle crashes by # recreating the pool and resubmitting. +<<<<<<< HEAD self.pool = None +======= + self.pool = self._new_pool(self.nprocs, False) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _submit_inner(self, job_id: int, data: bytes) -> None: def callback(fut: Future[Any]) -> None: @@ -363,17 +497,24 @@ def callback(fut: Future[Any]) -> None: assert isinstance(result, bytes) with self.write_lock: if self.running: +<<<<<<< HEAD _send_msg(self.write_pipe, MsgHeader.JOB, job_id, result) return self._start_pool() assert self.pool is not None +======= + _send_msg(self.write_pipe, job_id, result) + return + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) future = self.pool.submit( functools.partial(SubprocMain.do_job, self.pickler, data) ) future.add_done_callback(callback) +<<<<<<< HEAD def _start_pool(self) -> None: if self.pool is not None: return @@ -388,6 +529,8 @@ def _start_pool(self) -> None: ) _warm_process_pool(self.pool, self.nprocs) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @staticmethod def do_job(pickler: SubprocPickler, data: bytes) -> bytes: # do the pickle/unpickle in the sub-subproc diff --git a/torch/_inductor/compile_worker/utils.py b/torch/_inductor/compile_worker/utils.py index a54fa308d3fd3..2719c52bb686a 100644 --- a/torch/_inductor/compile_worker/utils.py +++ b/torch/_inductor/compile_worker/utils.py @@ -23,8 +23,11 @@ def in_toplevel_process() -> bool: # This function cannot be an inner function since otherwise mp_context="spawn" would # not work for ProcessPoolExecutor since inner functions cannot be pickled. def _async_compile_initializer(orig_ppid: int) -> None: +<<<<<<< HEAD import torch._C +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def run() -> None: while True: sleep(1) @@ -38,9 +41,12 @@ def run() -> None: # Ignore Ctrl-C (i.e. SIGINT) sent to pool workers to avoid meaningless log spam. signal.signal(signal.SIGINT, signal.SIG_IGN) +<<<<<<< HEAD # Install a crash handler to print out the stacktrace for SEGV torch._C._initCrashHandler() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Set a bit to distinguish async_compile subprocesses from the toplevel process. global _IN_TOPLEVEL_PROCESS _IN_TOPLEVEL_PROCESS = False diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 857272df14c94..861b62a19454b 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -40,7 +40,11 @@ def bundle_triton_into_fx_graph_cache_default() -> Optional[bool]: def static_cuda_launcher_default() -> bool: +<<<<<<< HEAD STATIC_CUDA_LAUNCHER_VERSION = 2 +======= + STATIC_CUDA_LAUNCHER_VERSION = 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if "TORCHINDUCTOR_USE_STATIC_CUDA_LAUNCHER" in os.environ: return os.environ.get("TORCHINDUCTOR_USE_STATIC_CUDA_LAUNCHER") == "1" @@ -81,11 +85,14 @@ def prologue_fusion_enabled() -> bool: # Whether to enable printing the source code for each future verbose_progress = False +<<<<<<< HEAD # Configurable compile worker logging path for subproc_pool worker_log_path = ( "/logs/dedicated_log_torch_compile_worker_rank" if is_fbcode() else None ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # precompilation timeout precompilation_timeout_seconds: int = 60 * 60 @@ -96,8 +103,11 @@ def prologue_fusion_enabled() -> bool: default=True, ) +<<<<<<< HEAD remote_gemm_autotune_cache: bool = False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # use remote fx aot graph codegen cache # False: Disables the cache # True: Enables the cache @@ -145,8 +155,17 @@ def prologue_fusion_enabled() -> bool: # None: Not set -- Off for OSS, JustKnobs based for internal bundled_autotune_remote_cache: Optional[bool] = bundled_autotune_remote_cache_default() +<<<<<<< HEAD # See torch.compiler.config.force_disable_caches force_disable_caches: bool = Config(alias="torch.compiler.config.force_disable_caches") +======= +# Force disabled all inductor level caching -- This will override any other caching flag +force_disable_caches: bool = Config( + justknob="pytorch/remote_cache:force_disable_caches", + env_name_force="TORCHINDUCTOR_FORCE_DISABLE_CACHES", + default=False, +) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Unsafe way to skip dynamic shape guards to get faster cache load unsafe_skip_cache_dynamic_shape_guards: bool = False @@ -187,8 +206,11 @@ def prologue_fusion_enabled() -> bool: os.environ.get("TORCHINDUCTOR_CPP_WRAPPER_BUILD_SEPARATE", "0") == "1" ) +<<<<<<< HEAD fx_wrapper: bool = os.environ.get("TORCHINDUCTOR_FX_WRAPPER", "0") == "1" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Controls automatic precompiling of common include files for codecache.CppCodeCache # (i.e. for cpp_wrapper mode and for cpp kernels on CPU). AOTI header precompiling is # controlled by a separate flag. @@ -266,12 +288,18 @@ def prologue_fusion_enabled() -> bool: post_grad_custom_pre_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None post_grad_custom_post_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None +<<<<<<< HEAD # Allow users to pass in custom partition function custom_partitioner_fn: torch._inductor.custom_graph_pass.CustomPartitionerFnType = None # Registers a custom joint graph pass. joint_custom_pre_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None joint_custom_post_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None +======= +# Registers a custom joint graph pass. +joint_custom_pre_pass: Optional[Callable[[torch.fx.Graph], None]] = None +joint_custom_post_pass: Optional[Callable[[torch.fx.Graph], None]] = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Registers a custom pregrad pass. Note that the pre-grad IR is 1. # non-functional, 2. non-normalized, and 3. prone to change. Ideally we should @@ -392,6 +420,7 @@ def prologue_fusion_enabled() -> bool: # enable operator reordering for peak memory optimization reorder_for_peak_memory = True +<<<<<<< HEAD reorder_iterative_debug_memory_recompute: bool = False reorder_iterative_debug_limit_to_reorder: Optional[int] = ( None @@ -412,12 +441,17 @@ def prologue_fusion_enabled() -> bool: None ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # runtime estimation function for ops # for built-in estimation function, pass in "default"; for user-defined estimation function, pass in the function handle estimate_op_runtime = "default" +<<<<<<< HEAD runtime_estimations_mms_benchmark: bool = False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # unit: GB/s, uni-directional P2P bandwidth per card # default value is NVLink intra_node_bw = 300 @@ -445,6 +479,7 @@ def prologue_fusion_enabled() -> bool: # enable slow autotuning passes to select gemm algorithms max_autotune_gemm = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_GEMM") == "1" +<<<<<<< HEAD # Modifies the number of autotuning choices displayed, set to None for all autotune_num_choices_displayed: Optional[int] = 10 @@ -465,11 +500,22 @@ def prologue_fusion_enabled() -> bool: == "1" ) +======= +# disable decomposek autotune choice for gemm +disable_decompose_k = os.environ.get("TORCHINDUCTOR_DISABLE_DECOMPOSE_K") == "1" + +# Modifies the number of autotuning choices displayed, set to None for all +autotune_num_choices_displayed: Optional[int] = 10 + +# enable inductor graph partition to allow multiple inductor graphs for the same dynamo graph +graph_partition = False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # force cublas and triton to use the same precision; cublas supports TF32 for matmul operations # when m, n, k are multiples of 16, 16, 8, whereas triton supports TF32 for matmul operations # for any combinations of m, n, k, regardless of their alignment. setting this flag will ensure # that triton does not use TF32 wherever cublas would not use TF32 +<<<<<<< HEAD # DEPRECATED. cuBLAS no longer has the above alignment requirements. will remove in the future. force_same_precision: bool = Config( justknob="pytorch/compiler:force_same_precision", @@ -485,11 +531,22 @@ def prologue_fusion_enabled() -> bool: # Specify candidate backends for gemm autotune. # Possible choices are combinations of: ATen, Triton, CUTLASS, CK, CKTILE, CPP. +======= +force_same_precision = ( + True if is_fbcode() else os.environ.get("TORCHINDUCTOR_FORCE_SAME_PRECISION") == "1" +) + +# Specify candidate backends for gemm autotune. +# Possible choices are combinations of: ATen, Triton, CUTLASS, CK, CPP. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # ATen: default Pytorch ATen kernels. # Triton: Triton templates defined in torch inductor (AMD and NVidia GPUs). # CUTLASS: Cutlass templates and kernels (NVidia GPUs only). # CK: Composable Kernel templates and kernels (AMD Instinct GPUs only). +<<<<<<< HEAD # CKTILE: Composable Kernel templates and kernels, new API (AMD Instinct GPUs only). +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # CPP: CPP templates and kernels for CPU. max_autotune_gemm_backends = os.environ.get( "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS", "ATEN,TRITON,CPP" @@ -525,8 +582,13 @@ def prologue_fusion_enabled() -> bool: # that can appear in the input shapes (e.g., in autotuning) unbacked_symint_fallback = 8192 +<<<<<<< HEAD # DEPRECATED. This setting is ignored. search_autotune_cache = False +======= +# enable searching global and local cache regardless of `max_autotune` +search_autotune_cache = os.environ.get("TORCHINDUCTOR_SEARCH_AUTOTUNE_CACHE") == "1" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) save_args = os.environ.get("TORCHINDUCTOR_SAVE_ARGS") == "1" @@ -562,11 +624,14 @@ def prologue_fusion_enabled() -> bool: # Specify a list of comma separated optimizations to use learned heuristics for autoheuristic_use = os.environ.get("TORCHINDUCTOR_AUTOHEURISTIC_USE", "mixed_mm") +<<<<<<< HEAD # If set to 1, will run a JIT post compile hook if one is set. run_jit_post_compile_hook = ( os.environ.get("TORCHINDUCTOR_RUN_JIT_POST_COMPILE_HOOK", "0") == "1" ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def run_autoheuristic(name: str) -> bool: return collect_autoheuristic(name) or use_autoheuristic(name) @@ -612,9 +677,12 @@ def use_autoheuristic(name: str) -> bool: # Threshold to prevent excessive accumulation of ops in one buffer during lowering realize_acc_reads_threshold = 8 +<<<<<<< HEAD realize_acc_reads_size_threshold: Optional[int] = ( None # TODO(xuanzh): harden this to make it non optional ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # fallback to eager for random/dropout, this is slow but useful for debugging fallback_random = False @@ -718,7 +786,11 @@ def use_autoheuristic(name: str) -> bool: # for all except for foreach, 2 - enable for all combo_kernel_allow_mixed_sizes = 1 # Enable dynamic shapes for foreach kernels +<<<<<<< HEAD combo_kernel_foreach_dynamic_shapes = True +======= +combo_kernel_foreach_dynamic_shapes = False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # constant folding on the joint graph joint_graph_constant_folding = True @@ -773,10 +845,13 @@ def decide_worker_start_method() -> str: worker_start_method: str = decide_worker_start_method() +<<<<<<< HEAD # Threshold to decide if a kernel has small memory access in bytes # Default value is 16 MB which is arbitrarily selected. small_memory_access_threshold: int = 16777216 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Whether to log from subprocess workers that are launched. worker_suppress_logging: bool = Config( justknob="pytorch/compiler:worker_suppress_logging", @@ -784,12 +859,15 @@ def decide_worker_start_method() -> str: default=True, ) +<<<<<<< HEAD # Log per-operation runtime estimates for TLParse analysis. log_tlparse: bool = Config( env_name_force="LOG_TLPARSE", default=False, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Flags to turn on all_reduce fusion. These 2 flags should be automatically turned # on by DDP and should not be set by the users. _fuse_ddp_communication = False @@ -874,6 +952,7 @@ def decide_compile_threads() -> int: # TODO: Set directly after internal rollout. compile_threads: Optional[int] = None if is_fbcode() else decide_compile_threads() +<<<<<<< HEAD # Whether to quiesce the Triton-compile subprocess pool at the end of each compilation. quiesce_async_compile_pool: bool = Config( justknob="pytorch/inductor:quiesce_async_compile_pool", @@ -881,6 +960,8 @@ def decide_compile_threads() -> int: default=False, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Whether or not to enable statically launching CUDA kernels # compiled by triton (instead of using triton's own launcher) use_static_cuda_launcher: bool = static_cuda_launcher_default() @@ -929,6 +1010,7 @@ def decide_compile_threads() -> int: ) pad_channels_last = False +<<<<<<< HEAD # Control if we will do padding on dynamic shapes pad_dynamic_shapes = False @@ -938,6 +1020,11 @@ def decide_compile_threads() -> int: # Control if we will expand the dimension of pointwise nodes to fuse expand_dimension_for_pointwise_nodes = False +======= +# Disable comprehensive padding on the CPU +disable_padding_cpu = True + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # The width of comprehensive padding, in bytes. # CUDA max memory transaction size is 128 bytes for a warp. padding_alignment_bytes = 128 @@ -1052,6 +1139,7 @@ def decide_compile_threads() -> int: annotate_training: bool = os.environ.get("TORCHINDUCTOR_ANNOTATE_TRAINING", "0") == "1" # Enable caching codegen of triton templates. +<<<<<<< HEAD enable_caching_generated_triton_templates: bool = True # Lookup table for overriding autotune configs based on hash of Triton source code @@ -1074,15 +1162,21 @@ def get_worker_log_path() -> Optional[str]: env_name_force="TORCHINDUCTOR_WORKER_LOGPATH", default="", ) +======= +enable_caching_generated_triton_templates: bool = False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # config specific to codegen/cpp.py class cpp: +<<<<<<< HEAD """ Settings for cpp backend. This class provides a centralized location for managing cpp backend settings. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # set to torch.get_num_threads() threads = -1 @@ -1099,7 +1193,11 @@ class cpp: dynamic_threads = os.environ.get("TORCHINDUCTOR_CPP_DYNAMIC_THREADS", "0") == "1" simdlen: Optional[int] = None +<<<<<<< HEAD min_chunk_size = int(os.environ.get("TORCHINDUCTOR_CPP_MIN_CHUNK_SIZE", "512")) +======= + min_chunk_size = int(os.environ.get("TORCHINDUCTOR_CPP_MIN_CHUNK_SIZE", "4096")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cxx: tuple[Literal[None], str] = ( None, # download gcc12 from conda-forge if conda is installed @@ -1198,6 +1296,7 @@ class cpp: # Use a small dequant buffer for wgt of woq int4 size as: [q_group_size, Nr] use_small_dequant_buffer = False +<<<<<<< HEAD force_inline_kernel = ( os.environ.get("TORCHINDUCTOR_CPP_FORCE_INLINE_KERNEL", "0") == "1" ) @@ -1213,6 +1312,11 @@ class triton: Config specific to codegen/triton.py """ +======= + +# config specific to codegen/triton.py +class triton: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Use cudagraphs on output code cudagraphs = os.environ.get("TORCHINDUCTOR_CUDAGRAPHS") == "1" @@ -1254,6 +1358,7 @@ class triton: # instead of recording and executing cudagraphs force_cudagraphs_warmup = False +<<<<<<< HEAD # If False (default), torch.compile skips cudagraph for a graph if it # contains cudagraph-unsafe ops. If True, we require that all cuda ops # be captured into cudagraph. If this is not possible, this will raise @@ -1263,6 +1368,8 @@ class triton: default=False, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # assertions on the fast path fast_path_cudagraph_asserts = False @@ -1391,11 +1498,16 @@ class triton: # So far we see a fixed 8 spilled registers for kernels using sin/cos. # Raise the threshold to 16 to be safe. # We should revisit this once we understand more of the source of register spills. +<<<<<<< HEAD spill_threshold: int = 32 if torch.version.hip else 16 +======= + spill_threshold: int = 16 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Generate code containing the newer tl.make_block_ptr() API for loads/store use_block_ptr = False +<<<<<<< HEAD # (Experimental) # Generate code using the tl.make_tensor_descriptor() API for loads/store # [Note: TMA API Restrictions] Currently the TMA API requires the following: @@ -1409,6 +1521,8 @@ class triton: # can be satisfied, along with any existing requirements for index expressions use_tensor_descriptor = False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Inject a bug into our relu implementation; useful for testing our repro # extraction and minification functionality. # Valid values: "compile_error", "runtime_error", "accuracy" @@ -1431,6 +1545,7 @@ class triton: # Note: it may also need to be used with config.compile_threads = 1 disallow_failing_autotune_kernels_TESTING_ONLY = False +<<<<<<< HEAD # specify number of splits to autotune on for decompose_k. 0 disables decompose_k num_decompose_k_splits = int( os.environ.get("TORCHINDUCTOR_NUM_DECOMPOSE_K_SPLITS", "10") @@ -1442,6 +1557,8 @@ class triton: os.environ.get("TORCHINDUCTOR_DECOMPOSE_K_THRESHOLD", "32") ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class aot_inductor: """ @@ -1493,12 +1610,17 @@ class aot_inductor: # rather than embedded into the data section. Needed to support 1B+ parameter models force_mmap_weights: bool = False +<<<<<<< HEAD # Default value of use_consts_asm_build is True, it will build by assembly language. # When the value is False, it will build by c++ language. use_consts_asm_build = True package: bool = False package_cpp_only: Optional[bool] = None +======= + package: bool = False + package_cpp_only: bool = False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Dictionary of metadata users might want to save to pass to the runtime. # TODO: Move this somewhere else, since it's no longer really a config @@ -1537,11 +1659,14 @@ class aot_inductor: # but performance for that interface may be degraded. use_minimal_arrayref_interface: bool = False +<<<<<<< HEAD # Set to True if we want to use Pytorch's CUDACachingAllocator for weight management weight_use_caching_allocator: bool = ( os.environ.get("AOT_INDUCTOR_WEIGHT_USE_CACHING_ALLOCATOR", "0") == "1" ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Experimental. Flag to control whether to include weight in .so package_constants_in_so: bool = True @@ -1552,11 +1677,16 @@ class aot_inductor: precompile_headers: bool = not is_fbcode() # Embed generated kernel binary files into model.so +<<<<<<< HEAD embed_kernel_binary: Optional[bool] = None +======= + embed_kernel_binary: bool = False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Generate kernel files that support multiple archs # For CUDA, this means generating fatbin files for kernels, and the fatbin files # contains PTX and SASS for the current architecture. +<<<<<<< HEAD emit_multi_arch_kernel: Optional[bool] = None # If not None, the generated files with use this name in file stem. @@ -1568,6 +1698,12 @@ class aot_inductor: # If compile_standalone, the aoti model class name is f"AOTInductorModel{name}" # # This name can only contain letters, numbers, and underscores. +======= + emit_multi_arch_kernel: bool = False + + # If not None, the generated files with use this name in file stem. + # If None, we will use a hash to name files. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) model_name_for_generated_files: Optional[str] = None # Custom ops that have implemented C shim wrappers, defined as an op to C shim declaration dict @@ -1575,11 +1711,14 @@ class aot_inductor: # custom op libs that have implemented C shim wrappers custom_op_libs: Optional[list[str]] = None +<<<<<<< HEAD compile_standalone: bool = False # Whether to enable link-time-optimization enable_lto = os.environ.get("AOT_INDUCTOR_ENABLE_LTO", "0") == "1" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class cuda: """Settings for cuda backend, today this consists of cutlass""" @@ -1611,11 +1750,19 @@ class cuda: # Path to the CUTLASS repo root directory. # The default path only works under PyTorch local development environment. +<<<<<<< HEAD cutlass_dir = os.path.realpath( os.environ.get( "TORCHINDUCTOR_CUTLASS_DIR", os.path.join(os.path.dirname(torch.__file__), "../third_party/cutlass/"), ) +======= + cutlass_dir = os.environ.get( + "TORCHINDUCTOR_CUTLASS_DIR", + os.path.abspath( + os.path.join(os.path.dirname(torch.__file__), "../third_party/cutlass/") + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # Configures the maximum number of CUTLASS configs to profile in max_autotune. @@ -1712,9 +1859,12 @@ class cuda: # Use this to overwrite and handle cache pollution binary_remote_cache_force_write: bool = False +<<<<<<< HEAD # Enable caching codegen of cuda templates. enable_caching_codegen: bool = True +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class rocm: # Offload arch list for device code compilation, e.g. ["gfx90a", "gfx942"]. @@ -1723,11 +1873,15 @@ class rocm: # Enable the CK backend for CDNA2 and CDNA3 only (for now) # Processor name reference: https://llvm.org/docs/AMDGPUUsage.html#processors +<<<<<<< HEAD ck_supported_arch: list[Literal["gfx90a", "gfx942", "gfx950"]] = [ "gfx90a", "gfx942", "gfx950", ] +======= + ck_supported_arch: list[str] = ["gfx90a", "gfx942"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Optimization level, use to balance compilation speed and runtime performance. # The type will not necessarily be comprehensive and won't be enforced at runtime. @@ -1785,9 +1939,12 @@ class rocm: # The threshold at which we trigger a splitK config - K // max(M,N) has to be greater than this split_k_threshold: int = 16 +<<<<<<< HEAD # The threshold at which we trigger a contiguous subgraph transformation contiguous_threshold: int = 16 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Backend to use for CPU codegen either "cpp" or "triton" (experimental) or "halide" (experimental) cpu_backend: Literal["cpp", "triton", "halide"] = "cpp" @@ -1887,6 +2044,7 @@ class trace: log_autotuning_results = os.environ.get("LOG_AUTOTUNE_RESULTS", "0") == "1" +<<<<<<< HEAD # Save mapping info from inductor generated kernel to post_grad/pre_grad fx nodes # Levels: # 0 - disabled (default) @@ -1900,6 +2058,10 @@ class trace: "INDUCTOR_PROVENANCE", os.environ.get("TORCH_COMPILE_DEBUG", "0") ) ) +======= + # Save mapping info from inductor generated triton kernel to post_grad fx nodes + log_inductor_triton_kernel_to_post_grad_node_info: bool = True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _save_config_ignore: list[str] = [ @@ -1927,8 +2089,11 @@ class trace: # see CustomGraphPass; these are handled specially "post_grad_custom_post_pass", "post_grad_custom_pre_pass", +<<<<<<< HEAD "joint_custom_pre_pass", "joint_custom_post_pass", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "_fuse_ddp_communication_passes", "_pre_fusion_custom_pass", # tests assume that changes here don't invalidate cache @@ -1959,12 +2124,15 @@ class test_configs: graphsafe_rng_func_ignores_fallback_random = False +<<<<<<< HEAD track_memory_lifecycle: Optional[Literal["assert", "log"]] = None # If set to True, AOTI-generated CMakelists.txt will still use libtorch # for unit testing use_libtorch = False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if TYPE_CHECKING: from torch.utils._config_typing import * # noqa: F401, F403 diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index e2cb445ed1080..cf9968c8db96b 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -2,7 +2,10 @@ # The design document please check this RFC: https://github.com/pytorch/pytorch/issues/124245 import copy +<<<<<<< HEAD import ctypes +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import errno import functools import json @@ -19,7 +22,11 @@ import textwrap import warnings from collections.abc import Sequence +<<<<<<< HEAD from ctypes import cdll, wintypes +======= +from ctypes import cdll +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from ctypes.util import find_library from pathlib import Path from typing import Any, Optional, Union @@ -142,6 +149,7 @@ def check_compiler_exist_windows(compiler: str) -> None: pass +<<<<<<< HEAD class WinPeFileVersionInfo: def __init__(self, file_path: str) -> None: self.file_path = file_path @@ -337,6 +345,12 @@ def get_cpp_compiler() -> str: compiler = normalize_path_separator(compiler) check_compiler_exist_windows(compiler) check_msvc_cl_language_id(compiler) +======= +def get_cpp_compiler() -> str: + if _IS_WINDOWS: + compiler = os.environ.get("CXX", "cl") + check_compiler_exist_windows(compiler) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: if config.is_fbcode(): return build_paths.cc @@ -754,11 +768,14 @@ def _get_os_related_cpp_cflags(cpp_compiler: str) -> list[str]: "wd4067", "wd4068", "EHsc", +<<<<<<< HEAD # For Intel oneAPI, ref: https://learn.microsoft.com/en-us/cpp/build/reference/zc-cplusplus?view=msvc-170 "Zc:__cplusplus", # Enable max compatible to msvc for oneAPI headers. # ref: https://github.com/pytorch/pytorch/blob/db38c44ad639e7ada3e9df2ba026a2cb5e40feb0/cmake/public/utils.cmake#L352-L358 # noqa: B950 "permissive-", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] else: cflags = ["Wno-unused-variable", "Wno-unknown-pragmas"] @@ -769,6 +786,7 @@ def _get_os_related_cpp_cflags(cpp_compiler: str) -> list[str]: else "Wno-ignored-optimization-argument" ) cflags.append(ignored_optimization_argument) +<<<<<<< HEAD if _is_gcc(cpp_compiler): # Issue all the warnings demanded by strict ISO C and ISO C++. # Ref: https://github.com/pytorch/pytorch/issues/153180#issuecomment-2986676878 @@ -812,10 +830,33 @@ def _get_ffast_math_flags() -> list[str]: if is_gcc(): flags.append("fexcess-precision=fast") +======= + return cflags + + +def _get_ffast_math_flags() -> list[str]: + # ffast-math is equivalent to these flags as in + # https://github.com/gcc-mirror/gcc/blob/4700ad1c78ccd7767f846802fca148b2ea9a1852/gcc/opts.cc#L3458-L3468 + # however gcc<13 sets the FTZ/DAZ flags for runtime on x86 even if we have + # -ffast-math -fno-unsafe-math-optimizations because the flags for runtime + # are added by linking in crtfastmath.o. This is done by the spec file which + # only does globbing for -ffast-math. + flags = [ + "fno-trapping-math", + "funsafe-math-optimizations", + "ffinite-math-only", + "fno-signed-zeros", + "fno-math-errno", + ] + + if is_gcc(): + flags.append("fexcess-precision=fast") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return flags +<<<<<<< HEAD def _get_inductor_debug_symbol_cflags() -> tuple[list[str], list[str]]: """ When we turn on generate debug symbol. @@ -863,6 +904,26 @@ def _get_optimization_cflags( if _IS_WINDOWS: pass else: +======= +def _get_optimization_cflags( + cpp_compiler: str, min_optimize: bool = False +) -> list[str]: + if _IS_WINDOWS: + return ["O1" if min_optimize else "O2"] + else: + wrapper_opt_level = config.aot_inductor.compile_wrapper_opt_level + cflags = ( + ["O0", "g"] + if config.aot_inductor.debug_compile + else [wrapper_opt_level if min_optimize else "O3", "DNDEBUG"] + ) + cflags += _get_ffast_math_flags() + cflags.append("fno-finite-math-only") + if not config.cpp.enable_unsafe_math_opt_flag: + cflags.append("fno-unsafe-math-optimizations") + cflags.append(f"ffp-contract={config.cpp.enable_floating_point_contract_flag}") + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if sys.platform != "darwin": # on macos, unknown argument: '-fno-tree-loop-vectorize' if _is_gcc(cpp_compiler): @@ -875,6 +936,7 @@ def _get_optimization_cflags( else: cflags.append("march=native") +<<<<<<< HEAD if config.aot_inductor.enable_lto and _is_clang(cpp_compiler): cflags.append("flto=thin") @@ -882,6 +944,12 @@ def _get_optimization_cflags( def _get_shared_cflags(do_link: bool) -> list[str]: +======= + return cflags + + +def _get_shared_cflag(do_link: bool) -> list[str]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if _IS_WINDOWS: """ MSVC `/MD` using python `ucrtbase.dll` lib as runtime. @@ -911,29 +979,42 @@ def get_cpp_options( libraries: list[str] = [] passthrough_args: list[str] = [] +<<<<<<< HEAD opt_cflags, opt_ldflags = _get_optimization_cflags(cpp_compiler, min_optimize) cflags = ( opt_cflags + _get_shared_cflags(do_link) +======= + cflags = ( + _get_shared_cflag(do_link) + + _get_optimization_cflags(cpp_compiler, min_optimize) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) + _get_warning_all_cflag(warning_all) + _get_cpp_std_cflag() + _get_os_related_cpp_cflags(cpp_compiler) ) +<<<<<<< HEAD definitions += _get_os_related_cpp_definitions(cpp_compiler) if not _IS_WINDOWS and config.aot_inductor.enable_lto and _is_clang(cpp_compiler): ldflags.append("fuse-ld=lld") ldflags.append("flto=thin") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) passthrough_args.append(" ".join(extra_flags)) return ( definitions, include_dirs, cflags, +<<<<<<< HEAD ldflags + opt_ldflags, +======= + ldflags, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) libraries_dirs, libraries, passthrough_args, @@ -996,6 +1077,16 @@ def __init__( self._finalize_options() +<<<<<<< HEAD +======= +def _get_glibcxx_abi_build_flags() -> list[str]: + if not _IS_WINDOWS: + return ["-D_GLIBCXX_USE_CXX11_ABI=" + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))] + else: + return [] + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _get_torch_cpp_wrapper_definition() -> list[str]: return ["TORCH_INDUCTOR_CPP_WRAPPER", "STANDALONE_TORCH_HEADER"] @@ -1321,6 +1412,7 @@ def _get_openmp_args( return cflags, ldflags, include_dir_paths, lib_dir_paths, libs, passthrough_args +<<<<<<< HEAD def _get_libstdcxx_args() -> tuple[list[str], list[str]]: """ For fbcode cpu case, we should link stdc++ instead assuming the binary where dlopen is executed is built with dynamic stdc++. @@ -1334,6 +1426,8 @@ def _get_libstdcxx_args() -> tuple[list[str], list[str]]: return lib_dir_paths, libs +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_mmap_self_macro(use_mmap_weights: bool) -> list[str]: macros = [] if use_mmap_weights: @@ -1349,6 +1443,7 @@ def get_cpp_torch_options( use_relative_path: bool, use_mmap_weights: bool, ) -> tuple[list[str], list[str], list[str], list[str], list[str], list[str], list[str]]: +<<<<<<< HEAD """ This function is used to get the build args of torch related build options. 1. Torch include_directories, libraries, libraries_directories. @@ -1358,6 +1453,8 @@ def get_cpp_torch_options( 5. MISC 6. Return the build args """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) definitions: list[str] = [] include_dirs: list[str] = [] cflags: list[str] = [] @@ -1394,6 +1491,10 @@ def get_cpp_torch_options( omp_passthrough_args, ) = _get_openmp_args(cpp_compiler) +<<<<<<< HEAD +======= + cxx_abi_passthrough_args = _get_glibcxx_abi_build_flags() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fb_macro_passthrough_args = _use_fb_internal_macros() mmap_self_macros = get_mmap_self_macro(use_mmap_weights) @@ -1416,7 +1517,14 @@ def get_cpp_torch_options( libraries_dirs = python_libraries_dirs + torch_libraries_dirs + omp_lib_dir_paths libraries = torch_libraries + omp_lib passthrough_args = ( +<<<<<<< HEAD sys_libs_passthrough_args + isa_ps_args_build_flags + omp_passthrough_args +======= + sys_libs_passthrough_args + + isa_ps_args_build_flags + + cxx_abi_passthrough_args + + omp_passthrough_args +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return ( @@ -1538,6 +1646,7 @@ def get_cpp_torch_device_options( aot_mode: bool = False, compile_only: bool = False, ) -> tuple[list[str], list[str], list[str], list[str], list[str], list[str], list[str]]: +<<<<<<< HEAD """ This function is used to get the build args of device related build options. 1. Device include_directories, libraries, libraries_directories. @@ -1545,6 +1654,8 @@ def get_cpp_torch_device_options( 3. MISC 4. Return the build args """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) definitions: list[str] = [] include_dirs: list[str] = [] cflags: list[str] = [] @@ -1564,8 +1675,11 @@ def get_cpp_torch_device_options( include_dirs = cpp_extension.include_paths(device_type) libraries_dirs = cpp_extension.library_paths(device_type) +<<<<<<< HEAD if not config.is_fbcode(): libraries += ["c10"] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if device_type == "cuda": definitions.append(" USE_ROCM" if torch.version.hip else " USE_CUDA") @@ -1584,6 +1698,7 @@ def get_cpp_torch_device_options( if device_type == "xpu": definitions.append(" USE_XPU") +<<<<<<< HEAD xpu_error_string = ( "Intel GPU driver is not properly installed, please follow the instruction " "in https://github.com/pytorch/pytorch?tab=readme-ov-file#intel-gpu-support." @@ -1602,6 +1717,16 @@ def get_cpp_torch_device_options( if not find_library("ze_loader"): raise OSError(xpu_error_string) +======= + # Suppress multi-line comment warnings in sycl headers + cflags += ["Wno-comment"] + libraries += ["c10_xpu", "sycl", "ze_loader", "torch_xpu"] + if not find_library("ze_loader"): + raise OSError( + "Intel GPU driver is not properly installed, please follow the instruction " + "in https://github.com/pytorch/pytorch?tab=readme-ov-file#intel-gpu-support." + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if device_type == "mps": definitions.append(" USE_MPS") @@ -1615,6 +1740,7 @@ def get_cpp_torch_device_options( # Only add link args, when compile_only is false. passthrough_args = ["-Wl,-Bstatic -lcudart_static -Wl,-Bdynamic"] +<<<<<<< HEAD if device_type == "cpu": ( stdcxx_lib_dir_paths, @@ -1623,6 +1749,8 @@ def get_cpp_torch_device_options( libraries_dirs += stdcxx_lib_dir_paths libraries += stdcxx_libs +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if config.aot_inductor.custom_op_libs: libraries += config.aot_inductor.custom_op_libs @@ -1802,9 +1930,12 @@ def __init__( self._aot_mode: bool = False self._name = name +<<<<<<< HEAD self._target_name = ( config.aot_inductor.model_name_for_generated_files or "aoti_model" ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Code start here, initial self internal variables firstly. self._build_option = BuildOption @@ -1857,8 +1988,12 @@ def __init__( if isinstance(sources, str): sources = [sources] +<<<<<<< HEAD # Use relative paths only when requested (typically for remote builds) if config.is_fbcode() and self._use_relative_path: +======= + if config.is_fbcode() and (not self._aot_mode or self._use_relative_path): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Will create another temp directory for building, so do NOT use the # absolute path. self._orig_source_paths = list(sources) @@ -2031,6 +2166,7 @@ def save_compile_cmd_to_cmake( """ definitions = " ".join(self._build_option.get_definitions()) +<<<<<<< HEAD target_library_type = ( "STATIC" if config.aot_inductor.compile_standalone else "SHARED" ) @@ -2089,6 +2225,30 @@ def save_compile_cmd_to_cmake( """ ) +======= + contents = textwrap.dedent( + f""" + cmake_minimum_required(VERSION 3.27 FATAL_ERROR) + project(aoti_model LANGUAGES CXX) + set(CMAKE_CXX_STANDARD 17) + + # May need to point CMAKE_PREFIX_PATH to the right torch location + find_package(Torch REQUIRED) + + # Set a shared library target + add_library(aoti_model SHARED) + + # Add macro definitions + target_compile_definitions(aoti_model PRIVATE {definitions}) + + # Add compile flags + target_compile_options(aoti_model PRIVATE {self._cflags_args}) + # Backend specific flags + target_compile_options(aoti_model PRIVATE {self._passthrough_parameters_args} -c) + + """ + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if device_type == "cuda" and torch.version.hip is None: from torch._inductor.codecache import _nvcc_arch_as_compile_option @@ -2096,11 +2256,15 @@ def save_compile_cmd_to_cmake( contents += textwrap.dedent( f""" enable_language(CUDA) +<<<<<<< HEAD set(CMAKE_CUDA_STANDARD 17) find_package(CUDAToolkit REQUIRED) target_include_directories({self._target_name} PRIVATE ${{CUDAToolkit_INCLUDE_DIRS}}) target_compile_definitions({self._target_name} PRIVATE USE_CUDA) target_link_libraries({self._target_name} PRIVATE cuda CUDA::cudart_static) +======= + find_package(CUDAToolkit REQUIRED) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) find_program(OBJCOPY_EXECUTABLE objcopy) if(NOT OBJCOPY_EXECUTABLE) @@ -2129,7 +2293,11 @@ def save_compile_cmd_to_cmake( add_custom_command( OUTPUT ${{FATBIN_FILE}} COMMAND ${{CUDAToolkit_NVCC_EXECUTABLE}} --fatbin ${{PTX_FILE}} -o ${{FATBIN_FILE}} ${{NVCC_GENCODE_FLAGS}} +<<<<<<< HEAD -gencode arch=compute_{current_arch},code=compute_{current_arch} +======= + -gencode arch=compute_80,code=compute_80 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) -gencode arch=compute_{current_arch},code=sm_{current_arch} DEPENDS ${{PTX_FILE}} ) @@ -2164,7 +2332,11 @@ def save_src_to_cmake(self, cmake_path: str, src_path: str) -> None: # Remove the directory part of file_path src_path = "${CMAKE_CURRENT_SOURCE_DIR}/" + Path(src_path).name with open(cmake_path, "a") as f: +<<<<<<< HEAD f.write(f"target_sources({self._target_name} PRIVATE {src_path})\n") +======= + f.write(f"target_sources(aoti_model PRIVATE {src_path})\n") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def save_kernel_asm_to_cmake(self, cmake_path: str, asm_files: list[str]) -> None: # TODO: make this work beyond CUDA @@ -2178,6 +2350,7 @@ def save_kernel_asm_to_cmake(self, cmake_path: str, asm_files: list[str]) -> Non """ ) f.write(contents) +<<<<<<< HEAD if asm_files: f.write(f"add_dependencies({self._target_name} ${{KERNEL_TARGETS}})\n") f.write( @@ -2192,15 +2365,30 @@ def save_link_cmd_to_cmake(self, cmake_path: str) -> None: # When compile_standalone is True, do not link with libtorch return +======= + f.write("add_dependencies(aoti_model ${KERNEL_TARGETS})\n") + f.write( + "target_link_libraries(aoti_model PRIVATE ${KERNEL_OBJECT_FILES})\n" + ) + + def save_link_cmd_to_cmake(self, cmake_path: str) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) lflags = " ".join(self._build_option.get_ldflags()) libs = " ".join(self._build_option.get_libraries()) contents = textwrap.dedent( f""" # Add linker flags +<<<<<<< HEAD target_link_options({self._target_name} PRIVATE {lflags}) # Add libraries target_link_libraries({self._target_name} PRIVATE {libs}) +======= + target_link_options(aoti_model PRIVATE {lflags}) + + # Add libraries + target_link_libraries(aoti_model PRIVATE {libs}) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ ) @@ -2209,6 +2397,7 @@ def save_link_cmd_to_cmake(self, cmake_path: str) -> None: ) with open(cmake_path, "a") as f: f.write(contents) +<<<<<<< HEAD def run_asm_build_object(src: str, target: str, cwd: str) -> None: @@ -2239,3 +2428,5 @@ def get_command_line(asm_cc: str, src: str, target: str) -> str: target=normalize_path_separator(target), ) run_compile_cmd(cmd, cwd=normalize_path_separator(cwd)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_inductor/cpu_vec_isa.py b/torch/_inductor/cpu_vec_isa.py index f2fd105e6a961..e2b297b4cabd1 100644 --- a/torch/_inductor/cpu_vec_isa.py +++ b/torch/_inductor/cpu_vec_isa.py @@ -11,7 +11,10 @@ import torch from torch._inductor import config +<<<<<<< HEAD from torch._inductor.utils import python_subprocess_env +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _IS_WINDOWS = sys.platform == "win32" @@ -132,7 +135,16 @@ def check_build(self, code: str) -> bool: ], cwd=output_dir, stderr=subprocess.DEVNULL, +<<<<<<< HEAD env=python_subprocess_env(), +======= + env={ + **os.environ, + "PYTHONPATH": os.environ.get( + "TORCH_CUSTOM_PYTHONPATH", os.pathsep.join(sys.path) + ), + }, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) except Exception: return False @@ -200,13 +212,17 @@ class VecAVX512(VecISA): else "/arch:AVX512" ) # TODO: use cflags _dtype_nelements = {torch.float: 16, torch.bfloat16: 32, torch.float16: 32} +<<<<<<< HEAD _is_avx512_bf16_supported = False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __str__(self) -> str: return "avx512" __hash__: Callable[[VecISA], Any] = VecISA.__hash__ # type: ignore[assignment] +<<<<<<< HEAD _avx512_bf16_code = """ #include #include @@ -245,13 +261,18 @@ def build_arch_flags(self) -> str: else: return self._arch_flags +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclasses.dataclass class VecAMX(VecAVX512): _arch_flags = VecAVX512._arch_flags + " -mamx-tile -mamx-bf16 -mamx-int8" +<<<<<<< HEAD # check amx_fp16 separately since it is not always supported when amx is supported # amx_fp16 intrinsic compilation need gcc >=13 on platforms which support amx_fp16 _is_amx_fp16_supported = False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __str__(self) -> str: return super().__str__() + " amx_tile" @@ -279,14 +300,18 @@ def __str__(self) -> str: } """ +<<<<<<< HEAD _amx_fp16_code = _amx_code.replace("_tile_dpbf16ps", "_tile_dpfp16ps") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @functools.cache # noqa: B019 def __bool__(self) -> bool: if super().__bool__(): if config.is_fbcode(): return False if self.check_build(VecAMX._amx_code) and torch.cpu._init_amx(): +<<<<<<< HEAD # check amx-fp16 as well when check amx if torch.cpu._is_amx_fp16_supported(): # save _arch_flags @@ -315,6 +340,11 @@ def build_arch_flags(self) -> str: extra_flags += " -mamx-fp16" return self._arch_flags + extra_flags +======= + return True + return False + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclasses.dataclass class VecAVX2(VecISA): diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py index 3b3dea909cd24..0ab4ef52cbaa8 100644 --- a/torch/_inductor/cudagraph_trees.py +++ b/torch/_inductor/cudagraph_trees.py @@ -90,7 +90,10 @@ from torch._guards import CompileId from torch._inductor.utils import InputType +<<<<<<< HEAD from torch.cuda import _POOL_HANDLE +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.types import _bool StorageWeakRefPointer = int @@ -818,7 +821,11 @@ def __init__( id: GraphID, parent: Optional[CUDAGraphNode], inputs: list[InputType], +<<<<<<< HEAD cuda_graphs_pool: _POOL_HANDLE, +======= + cuda_graphs_pool: tuple[int, int], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device_index: int, stack_traces: Optional[StackTraces], stream: torch.cuda.Stream, @@ -1229,7 +1236,10 @@ def all_outputs_are_dead(self) -> bool: def _record(self, model: ModelType, inputs: list[InputType]) -> OutputType: "Record the model" +<<<<<<< HEAD assert self.graph is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def static_input_iter() -> Generator[torch.Tensor, None, None]: for i in self.wrapped_function.static_input_idxs: @@ -1312,11 +1322,21 @@ def _add_first_outputs( self.output_storage_alias.append(UnaliasedStorage) continue +<<<<<<< HEAD torch._check( o.is_cuda or o.untyped_storage().data_ptr() == 0, lambda: ( "Expected all cuda outputs in cuda graph recording. Non cuda output " f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}" +======= + ( + torch._check( + o.is_cuda or o.untyped_storage().data_ptr() == 0, + lambda: ( + "Expected all cuda outputs in cuda graph recording. Non cuda output " + f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}" + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), ) diff --git a/torch/_inductor/cudagraph_utils.py b/torch/_inductor/cudagraph_utils.py index effed470548cb..e02e45b32e3ca 100644 --- a/torch/_inductor/cudagraph_utils.py +++ b/torch/_inductor/cudagraph_utils.py @@ -10,11 +10,17 @@ from torch._inductor.utils import GraphPartitionMap, InputType from torch.utils._ordered_set import OrderedSet +<<<<<<< HEAD from .utils import is_using_cudagraph_partition if TYPE_CHECKING: from collections.abc import Sequence, Set as AbstractSet +======= + +if TYPE_CHECKING: + from collections.abc import Sequence +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") @@ -110,8 +116,12 @@ def format_default_skip_message(reason: str) -> str: def get_mutation_stack_trace( +<<<<<<< HEAD placeholders: Sequence[PlaceholderInfo], mutation_indices: Union[AbstractSet[int], Sequence[int]], +======= + placeholders: Sequence[PlaceholderInfo], mutation_indices: Sequence[int] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> str: stack_trace: Optional[str] = "" @@ -173,8 +183,12 @@ def check_multiple_devices_or_any_cpu_nodes( # meta tensors are supported since there is no compute device_node_mapping.pop(torch.device("meta"), None) +<<<<<<< HEAD # dynamo cudagraph does not support graph partition if is_using_cudagraph_partition(): +======= + if torch._inductor.config.graph_partition: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # graph partition supports splitting on cpu op. So we can ignore cpu nodes. device_node_mapping.pop(torch.device("cpu"), None) @@ -204,10 +218,13 @@ def check_lowering_disable_cudagraph( def log_cudagraph_skip_and_bump_counter(msg: str) -> None: perf_hint_log.warning(msg) counters["inductor"]["cudagraph_skips"] += 1 +<<<<<<< HEAD if torch._inductor.config.triton.cudagraph_or_error: raise RuntimeError(msg) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) metrics_context = get_metrics_context() if metrics_context.in_progress(): metrics_context.set("cudagraph_skip_reason", msg, overwrite=True) diff --git a/torch/_inductor/custom_graph_pass.py b/torch/_inductor/custom_graph_pass.py index 413a224724fd5..ed26a65840d6c 100644 --- a/torch/_inductor/custom_graph_pass.py +++ b/torch/_inductor/custom_graph_pass.py @@ -1,6 +1,9 @@ import hashlib from abc import ABC, abstractmethod +<<<<<<< HEAD from collections.abc import Sequence +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from functools import lru_cache from typing import Any, Callable, Optional, Union from typing_extensions import TypeAlias @@ -103,6 +106,7 @@ def get_hash_for_files(paths: tuple[str], extra: str = "") -> bytes: hasher.update(path.encode("utf-8")) hasher.update(f.read()) return hasher.digest() +<<<<<<< HEAD class CustomPartitionerFn(ABC): @@ -158,3 +162,5 @@ def uuid(self) -> Optional[Any]: CustomPartitionerFnType: TypeAlias = Optional[CustomPartitionerFn] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_inductor/debug.py b/torch/_inductor/debug.py index e9df7119bb752..89d4ef1803d92 100644 --- a/torch/_inductor/debug.py +++ b/torch/_inductor/debug.py @@ -13,7 +13,11 @@ import pstats import shutil import traceback +<<<<<<< HEAD from collections.abc import Iterator, Sequence +======= +from collections.abc import Iterator +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from typing import Any, Callable, IO, Optional, Union from unittest.mock import patch @@ -22,10 +26,14 @@ from torch import fx as fx from torch._dynamo.repro.after_aot import save_graph_repro from torch._dynamo.utils import get_debug_dir +<<<<<<< HEAD from torch._inductor import utils from torch._logging import getArtifactLogger from torch._logging._internal import trace_structured from torch._utils_internal import signpost_event +======= +from torch._logging import getArtifactLogger +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.fx.graph_module import GraphModule from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata from torch.fx.passes.tools_common import legalize_graph @@ -34,7 +42,10 @@ from torch.utils._pytree import tree_map from . import config, ir # noqa: F811, this is needed +<<<<<<< HEAD from .ir import ExternKernel +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .scheduler import ( BaseSchedulerNode, FusedSchedulerNode, @@ -47,11 +58,14 @@ log = logging.getLogger(__name__) +<<<<<<< HEAD # Graph execution tracking for debugging GRAPH_EXECUTION_ORDER: Optional[list[dict[str, object]]] = None RECORD_GRAPH_EXECUTION: bool = False GRAPH_COMPILE_IDS: Optional[dict[int, Optional[str]]] = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ir_pre_fusion_log = getArtifactLogger(__name__, "ir_pre_fusion") ir_post_fusion_log = getArtifactLogger(__name__, "ir_post_fusion") SchedulerNodeList = list[Any] @@ -321,6 +335,7 @@ def enable_aot_logging() -> Iterator[None]: # Used for provenance tracking # They are not stored in DebugContext because they are not set in # _inductor_triton_kernel_to_post_grad_node_info's Debug Context +<<<<<<< HEAD _inductor_post_to_pre_grad_nodes: dict[str, dict[str, list[str]]] = {} _inductor_triton_kernel_to_post_grad_node_info: dict[str, list[str]] = {} _pre_grad_graph_id: Optional[int] = None @@ -375,11 +390,21 @@ def reset_provenance_globals() -> Iterator[None]: _inductor_pre_grad_node_stack_trace = ( original_inductor_pre_grad_node_stack_trace ) +======= +_inductor_post_to_pre_grad_nodes: dict[str, Any] = {} +_pre_grad_graph_id: Optional[int] = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class DebugContext: _counter = itertools.count() +<<<<<<< HEAD +======= + # Used for provenance tracking + _inductor_triton_kernel_to_post_grad_node_info: dict[str, list[str]] = {} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @staticmethod def create_debug_dir(folder_name: str) -> Optional[str]: debug_dir = config.trace.debug_dir or get_debug_dir() @@ -615,6 +640,28 @@ def draw_orig_fx_graph( def output_code(self, filename: str, extension: str = "py") -> None: shutil.copy(filename, self.filename(f"output_code.{extension}")) +<<<<<<< HEAD +======= + def log_inductor_triton_kernel_to_post_grad_node_info( + self, filename: str = "inductor_generated_kernel_to_post_grad_nodes.json" + ) -> tuple[dict[str, list[str]], dict[str, Any]]: + debug_info = {} + with self.fopen(filename, "w") as fd: + log.info("Writing provenance tracing debugging info to %s", fd.name) + debug_info = DebugContext._inductor_triton_kernel_to_post_grad_node_info + json.dump(debug_info, fd) + node_mapping = {} + if _pre_grad_graph_id: + with self.fopen( + "inductor_provenance_tracking_node_mappings.json", "w" + ) as fd: + node_mapping = create_node_mapping( + _pre_grad_graph_id, _inductor_post_to_pre_grad_nodes, debug_info + ) + json.dump(node_mapping, fd) + return debug_info, node_mapping + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def log_autotuning_results( self, name: str, @@ -720,6 +767,7 @@ def log_ir_post_fusion(nodes: SchedulerNodeList) -> None: V.debug.ir_post_fusion(nodes) +<<<<<<< HEAD def _dump_collective_schedule(schedule: list[Union[str, None]]) -> None: try: trace_structured( @@ -846,6 +894,8 @@ def record_and_log_graph_execution_order() -> Iterator[None]: GRAPH_COMPILE_IDS = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclasses.dataclass class TensorMetadataHolder: tensor_metadata: TensorMetadata @@ -855,6 +905,7 @@ class TensorMetadataHolder: save_args_cnt = itertools.count() +<<<<<<< HEAD def create_mapping_pre_post_grad_nodes( pre_grad_graph_id: Optional[int], post_to_pre_grad_nodes_json: dict[str, Any], @@ -869,19 +920,67 @@ def create_mapping_pre_post_grad_nodes( "postToPre": {}, } +======= +def create_node_mapping( + pre_grad_graph_id: int, + post_to_pre_grad_nodes_json: dict[str, Any], + triton_kernel_to_post_grad_json: dict[str, Any], +) -> dict[str, dict[str, Any]]: + """Create bidirectional mappings between: + + - pre_grad graph nodes and post_grad graph code nodes, and vice versa + - triton kernel name and post_grad graph code nodes, and vice versa + """ + + # return a dummy dict if there's any error + empty_return: dict[str, dict[str, Any]] = { + "preToPost": {}, + "postToPre": {}, + "cppCodeToPost": {}, + "postToCppCode": {}, + } + + log.info("Creating node mappings for provenance tracking") + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not isinstance(post_to_pre_grad_nodes_json, dict): log.error("Provenance tacking error: post_to_pre_grad_nodes_json is not a dict") return empty_return +<<<<<<< HEAD if not isinstance(pre_grad_graph_id, int): # pre_grad_graph_id may be empty if there's no pre_grad graph # and there's only a backward graph from backward pass engine +======= + if not isinstance(triton_kernel_to_post_grad_json, dict): + log.error( + "Provenance tacking error: triton_kernel_to_post_grad_json is not a dict" + ) + return empty_return + + if not isinstance(pre_grad_graph_id, int): + log.error("Provenance tacking error: pre_grad_graph_id is not an int") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return empty_return pre_to_post: dict[str, Any] = collections.defaultdict(OrderedSet) post_to_pre: dict[str, Any] = collections.defaultdict(OrderedSet) +<<<<<<< HEAD try: +======= + post_to_cpp_code: dict[str, Any] = collections.defaultdict(OrderedSet) + + try: + for outer_key, node_array in triton_kernel_to_post_grad_json.items(): + if not isinstance(node_array, list): + log.error( + "Provenance tacking error: triton_kernel_to_post_grad_json value is not a list" + ) + return empty_return + for curr_node in node_array: + post_to_cpp_code[curr_node].add(outer_key) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def check_format(node: dict[str, Any]) -> bool: if not isinstance(node, dict): @@ -931,6 +1030,7 @@ def convert_sets_to_lists(d: dict[str, Any]) -> None: # convert to list because set is not JSON serializable convert_sets_to_lists(pre_to_post) convert_sets_to_lists(post_to_pre) +<<<<<<< HEAD return { "preToPost": pre_to_post, "postToPre": post_to_pre, @@ -991,12 +1091,19 @@ def convert_sets_to_lists(d: dict[str, Any]) -> None: # convert to list because set is not JSON serializable convert_sets_to_lists(post_to_cpp_code) return { +======= + convert_sets_to_lists(post_to_cpp_code) + return { + "preToPost": pre_to_post, + "postToPre": post_to_pre, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "cppCodeToPost": triton_kernel_to_post_grad_json, "postToCppCode": post_to_cpp_code, } except Exception as e: # Since this is just logging code, it should never interfere with regular # program execution, so we use this try-except to guard against any error +<<<<<<< HEAD signpost_event( "inductor", "provenance_tracking_error", @@ -1167,6 +1274,18 @@ def set_kernel_post_grad_provenance_tracing( return None +======= + log.error("Unexpected error in create_node_mapping: %s", e) + log.error("post_to_pre_grad_nodes_json: %s", post_to_pre_grad_nodes_json) + log.error( + "triton_kernel_to_post_grad_json: %s", triton_kernel_to_post_grad_json + ) + log.error("pre_grad_graph_id: %s", pre_grad_graph_id) + log.error(traceback.format_exc()) + return empty_return + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def save_args_for_compile_fx_inner(*args: Any, **kwargs: Any) -> None: """ This function is used to save arguments for a compile_fx_inner function call @@ -1250,7 +1369,11 @@ def aot_inductor_minifier_wrapper( use_minifier = config.aot_inductor.dump_aoti_minifier +<<<<<<< HEAD gm = exported_program.module(check_guards=False) +======= + gm = exported_program.module() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(gm, torch.fx.GraphModule) args, kwargs = exported_program.example_inputs @@ -1279,7 +1402,11 @@ def aot_inductor_minifier_wrapper( tuple_inputs = tuple(flat_example_inputs) flattened_ep = torch.export.export(gm_copy, tuple_inputs, strict=False) func( +<<<<<<< HEAD flattened_ep.module(check_guards=False), +======= + flattened_ep.module(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tuple_inputs, inductor_configs=config_copy, package_path=package_path, diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index eebe6c974e173..f6682cef2aa9b 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -34,7 +34,15 @@ ELEMENTWISE_TYPE_PROMOTION_KIND, type_to_dtype, ) +<<<<<<< HEAD from torch.fx.experimental.symbolic_shapes import guard_or_false, statically_known_true +======= +from torch.fx.experimental.symbolic_shapes import ( + guard_or_false, + guard_size_oblivious, + statically_known_true, +) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from . import config, inductor_prims from .utils import ( @@ -158,6 +166,22 @@ def _embedding_dense_backward( ) +<<<<<<< HEAD +======= +# TODO: for now, inductor doesn't handle asserts +# because the condition is symbol -> tensor in the graph. +@register_decomposition([aten._assert_async.msg]) +def assert_async_msg_decomp(tensor: torch.Tensor, msg: str) -> None: + return + + +# Following `assert_async_msg_decomp` and implement as non-op. +@register_decomposition([aten._functional_assert_async.msg]) +def functional_assert_async_msg_decomp(tensor: torch.Tensor, msg: str) -> None: + return + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register_decomposition([aten.sym_constrain_range_for_size.default]) def sym_constrain_range_for_size( symbol: torch.SymInt, @@ -350,7 +374,11 @@ def mm( and guard_or_false((torch.numel(self) + torch.numel(input2)) <= 32) ): counters["inductor"]["decompose_mm"] += 1 +<<<<<<< HEAD return self * input2 +======= + return torch.cat([self[i, :] * input2 for i in range(self.size(0))]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if statically_known_true(self.size(0) == 1) and statically_known_true( input2.size(-1) == 1 ): @@ -387,10 +415,17 @@ def non_empty_tensor(x: torch.Tensor) -> bool: # runtime assert forcing u0 to be zero. So if this hasn't happened, # we know that the unbacked SymInt has appropriate size and there are # no problems. +<<<<<<< HEAD if len(x.shape) == 1 and guard_or_false(x.shape[0] == 0): return False if dim < len(x.shape) and guard_or_false(x.shape[dim] == 0): +======= + if len(x.shape) == 1 and guard_size_oblivious(x.shape[0] == 0): + return False + + if dim < len(x.shape) and guard_size_oblivious(x.shape[dim] == 0): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return False return True @@ -459,6 +494,7 @@ def add( y_is_complex_tensor = torch.is_tensor(y) and y.is_complex() if not x_is_complex_tensor or not y_is_complex_tensor: return NotImplemented +<<<<<<< HEAD output_size_zero = False if x.ndim == 0 and y.ndim == 0: @@ -469,6 +505,8 @@ def add( if y.ndim == 0: y = y.reshape(1) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) z = y if alpha is not None: z = alpha * y @@ -500,9 +538,12 @@ def reshape_tensor_complex(tensor: torch.Tensor) -> torch.Tensor: x_reshaped = reshape_tensor_complex(x.view(x.real.dtype)) z_reshaped = reshape_tensor_complex(z.view(y.real.dtype)) result = torch.flatten(x_reshaped + z_reshaped, start_dim=-2).view(complex_type) +<<<<<<< HEAD if output_size_zero: return result[0] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return result @@ -576,6 +617,7 @@ def view_copy_dtype( return self.to(dtype).clone() +<<<<<<< HEAD def _get_shape_permutation_like( self: torch.Tensor, ) -> tuple[utils.ShapeType, utils.StrideType]: @@ -587,6 +629,51 @@ def _get_shape_permutation_like( permutation[l] = p return (shape, permutation) +======= +def get_like_layout( + tensor: torch.Tensor, + memory_format: Optional[torch.memory_format] = None, +) -> torch.memory_format: + # TODO: _to_copy tensor to stride permutation + if memory_format is torch.preserve_format or memory_format is None: + return utils.suggest_memory_format(tensor) + else: + return memory_format + + +@register_decomposition(aten.rand_like) +def rand_like( + self: torch.Tensor, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + memory_format: Optional[torch.memory_format] = None, + **kwargs: Any, +) -> torch.Tensor: + return torch.rand( + [*self.size()], + dtype=dtype or self.dtype, + device=device or self.device, + **kwargs, + ).to(memory_format=get_like_layout(self, memory_format)) + + +@register_decomposition(aten.randn_like) +def randn_like( + self: torch.Tensor, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + memory_format: Optional[torch.memory_format] = None, + **kwargs: Any, +) -> torch.Tensor: + return torch.randn( + [*self.size()], + dtype=dtype or self.dtype, + device=device or self.device, + **kwargs, + ).to(memory_format=get_like_layout(self, memory_format)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register_decomposition(aten.full_like) @@ -601,6 +688,7 @@ def full_like( requires_grad: bool = False, memory_format: torch.memory_format = torch.preserve_format, ) -> torch.Tensor: +<<<<<<< HEAD dtype = self.dtype if dtype is None else dtype layout = self.layout if layout is None else layout device = self.device if device is None else device @@ -679,13 +767,63 @@ def randn_like(self: torch.Tensor, **kwargs: Any) -> torch.Tensor: @register_decomposition(aten.randint_like.default) def randint_like(self: torch.Tensor, high: int, **kwargs: Any) -> torch.Tensor: return _rand_like(functools.partial(aten.randint.low, 0, high), self, **kwargs) +======= + return torch.full( + [*self.size()], + fill_value, + dtype=dtype or self.dtype, + layout=layout or self.layout, + device=device or self.device, + requires_grad=requires_grad, + ).to(memory_format=get_like_layout(self, memory_format)) + + +@register_decomposition(aten.randint_like.default) +def randint_like( + self: torch.Tensor, + high: int, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + memory_format: Optional[torch.memory_format] = None, + **kwargs: Any, +) -> torch.Tensor: + return aten.randint.low( + 0, + high, + [*self.size()], + dtype=dtype or self.dtype, + device=device or self.device, + **kwargs, + ).to(memory_format=get_like_layout(self, memory_format)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register_decomposition(aten.randint_like.low_dtype) def randint_like_low( +<<<<<<< HEAD self: torch.Tensor, low: int, high: int, **kwargs: Any ) -> torch.Tensor: return _rand_like(functools.partial(aten.randint.low, low, high), self, **kwargs) +======= + self: torch.Tensor, + low: int, + high: int, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + memory_format: Optional[torch.memory_format] = None, + **kwargs: Any, +) -> torch.Tensor: + return aten.randint.low( + low, + high, + [*self.size()], + dtype=dtype or self.dtype, + device=device or self.device, + **kwargs, + ).to(memory_format=get_like_layout(self, memory_format)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register_decomposition(aten.randint.default) @@ -701,7 +839,11 @@ def randint( def linear_dynamic_fp16_unpacked_weight( input: torch.Tensor, weight: torch.Tensor, +<<<<<<< HEAD bias: Optional[torch.Tensor] = None, +======= + bias: torch.Tensor, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> torch.Tensor: packed_weight = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(weight) return torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight( @@ -1150,6 +1292,7 @@ def rrelu_with_noise_functional( else: negative_slope = (lower + upper) / 2 return aten.leaky_relu(self, negative_slope), torch.Tensor() +<<<<<<< HEAD @register_decomposition(aten.repeat_interleave.Tensor) @@ -1172,3 +1315,5 @@ def repeat_interleave_Tensor( return torch.searchsorted( cumsum, pos, out_int32=(repeat.dtype == torch.int32), right=True ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py index 835ea182f8e80..66808a0b7a66e 100644 --- a/torch/_inductor/dependencies.py +++ b/torch/_inductor/dependencies.py @@ -11,7 +11,10 @@ import sympy import torch +<<<<<<< HEAD from torch._inductor.utils import get_free_symbols +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.fx.experimental.symbolic_shapes import free_symbols, free_unbacked_symbols from torch.utils._ordered_set import OrderedSet @@ -40,12 +43,15 @@ class Dep(abc.ABC): index: sympy.Expr @abc.abstractmethod +<<<<<<< HEAD def get_free_symbol_uses( self, unbacked_only: bool = False ) -> OrderedSet[sympy.Symbol]: pass @abc.abstractmethod +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def rename(self, renames: dict[str, str]) -> Self: pass @@ -77,6 +83,7 @@ class MemoryDep(Dep): size: tuple[sympy.Expr, ...] mode: Optional[str] = None +<<<<<<< HEAD def get_free_symbol_uses( self, unbacked_only: bool = False ) -> OrderedSet[sympy.Symbol]: @@ -86,6 +93,8 @@ def get_free_symbol_uses( | get_free_symbols(self.var_names, unbacked_only) ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __repr__(self) -> str: maybe_mode = "" if self.mode is not None: @@ -323,11 +332,14 @@ def rename(self, renames: dict[str, str]) -> "StarDep": return StarDep(renames[self.name], self.mode) return self +<<<<<<< HEAD def get_free_symbol_uses( self, unbacked_only: bool = False ) -> OrderedSet[sympy.Symbol]: return OrderedSet() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def numbytes_hint(self) -> int: try: return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size( @@ -363,6 +375,7 @@ class WeakDep(Dep): name: str # Buffer that is doing the mutation mutating_buf: str +<<<<<<< HEAD # WeakDep's are also used to add dependencies to prevent some specific reordering, # E.g. collectives global ordering. # But if other pass guarantees proper ordering by its logic, @@ -374,6 +387,8 @@ def get_free_symbol_uses( self, unbacked_only: bool = False ) -> OrderedSet[sympy.Symbol]: return OrderedSet() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @property def index(self) -> sympy.Expr: @@ -384,7 +399,11 @@ def get_numel(self) -> sympy.Expr: def rename(self, renames: dict[str, str]) -> "WeakDep": if self.name in renames: +<<<<<<< HEAD return WeakDep(renames[self.name], self.mutating_buf, self.is_fake) +======= + return WeakDep(renames[self.name], self.mutating_buf) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self def numbytes_hint(self) -> int: @@ -472,6 +491,7 @@ def buffer_names(self, ignore_integer_index: bool = True) -> OrderedSet[str]: names.add(dep.name) return names +<<<<<<< HEAD def get_free_symbol_uses( self, unbacked_only: bool = False ) -> OrderedSet[sympy.Symbol]: @@ -481,6 +501,8 @@ def get_free_symbol_uses( result |= dep.get_free_symbol_uses(unbacked_only) return result +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class _RecordLoadStoreInner(V.MockHandler): # type: ignore[name-defined] def __init__(self, var_ranges: VarRanges, normalize: bool) -> None: @@ -622,12 +644,21 @@ def index_vars_no_squeeze( def index_vars_squeeze( *argsizes: Sequence[sympy.Expr], prefix: str = "d" +<<<<<<< HEAD ) -> tuple[list[Sequence[sympy.Expr]], VarRanges]: from .ir import SqueezeView var_ranges, add_var = var_builder(prefix) args: list[Sequence[sympy.Expr]] = [] new_sizes: list[Sequence[sympy.Expr]] = [] +======= +) -> tuple[list[list[sympy.Expr]], VarRanges]: + from .ir import SqueezeView + + var_ranges, add_var = var_builder(prefix) + args: list[list[sympy.Expr]] = [] + new_sizes: list[list[sympy.Expr]] = [] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for size in argsizes: new_size, reindex = SqueezeView.squeezer(size) new_sizes.append(new_size) @@ -648,10 +679,14 @@ def extract_read_writes( if isinstance(fn, LoopBody): inner = extract_loop_body_with_args( +<<<<<<< HEAD fn, [*args, *hidden_args], # type: ignore[list-item] var_ranges, normalize, +======= + fn, [*args, *hidden_args], var_ranges, normalize +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) else: # Slow path tracing the function diff --git a/torch/_inductor/dtype_propagation.py b/torch/_inductor/dtype_propagation.py index d80caa1e2b72c..167fd21632edd 100644 --- a/torch/_inductor/dtype_propagation.py +++ b/torch/_inductor/dtype_propagation.py @@ -373,10 +373,13 @@ def placeholder(self, index: int) -> torch.dtype: f"{type(self).__name__}: ops.placeholder should not appear here" ) +<<<<<<< HEAD @staticmethod def device_assert_async(cond, msg: str) -> torch.dtype: return torch.bool +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if TYPE_CHECKING: diff --git a/torch/_inductor/exc.py b/torch/_inductor/exc.py index a46663ed8f8c0..b35eb7f6f4b01 100644 --- a/torch/_inductor/exc.py +++ b/torch/_inductor/exc.py @@ -92,9 +92,12 @@ def __init__(self, cmd: list[str], output: str) -> None: if isinstance(output, bytes): output = output.decode("utf-8") +<<<<<<< HEAD self.cmd = cmd self.output = output +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__( textwrap.dedent( """ @@ -111,9 +114,12 @@ def __init__(self, cmd: list[str], output: str) -> None: .format(cmd=" ".join(cmd), output=output) ) +<<<<<<< HEAD def __reduce__(self) -> tuple[type, tuple[list[str], str]]: return (self.__class__, (self.cmd, self.output)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class CUDACompileError(CppCompileError): pass diff --git a/torch/_inductor/fuzzer.py b/torch/_inductor/fuzzer.py index 8149bc7e98e79..5803a0a6f3e1e 100644 --- a/torch/_inductor/fuzzer.py +++ b/torch/_inductor/fuzzer.py @@ -23,8 +23,12 @@ ) import torch +<<<<<<< HEAD from functorch.compile import min_cut_rematerialization_partition from torch._inductor.custom_graph_pass import CustomGraphPass, CustomPartitionerFn +======= +from torch._inductor.custom_graph_pass import CustomGraphPass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.scheduler import BaseSchedulerNode from torch.utils._config_module import _ConfigEntry, ConfigModule from torch.utils._ordered_set import OrderedSet @@ -75,6 +79,7 @@ def uuid(self) -> Optional[Any]: return None +<<<<<<< HEAD class DummyPartitionerFn(CustomPartitionerFn): """ A Dummy partitioner function to be used by ConfigFuzzer @@ -89,6 +94,8 @@ def uuid(self) -> Optional[Any]: return None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T = TypeVar("T") @@ -99,7 +106,10 @@ class TypeExemplars: TYPE_EXEMPLARS: dict[str, Any] = { CustomGraphPass.__name__: DummyPass(), +<<<<<<< HEAD CustomPartitionerFn.__name__: DummyPartitionerFn(), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.fx.graph.Graph.__name__: torch.fx.graph.Graph(), BaseSchedulerNode.__name__: BaseSchedulerNode(None), # type: ignore[arg-type] } @@ -515,7 +525,10 @@ def keys(self) -> KeysView[ComboType]: "joint_custom_post_pass": DEFAULT, # Typing "joint_custom_pre_pass": DEFAULT, # Typing "pre_grad_custom_pass": DEFAULT, # Typing +<<<<<<< HEAD "custom_partitioner_fn": DEFAULT, # Typing +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, "torch._dynamo.config": { "traceable_tensor_subclasses": DEFAULT, # Typing @@ -523,7 +536,10 @@ def keys(self) -> KeysView[ComboType]: "compiled_autograd_kwargs_override": DEFAULT, # Typing "fail_on_recompile_limit_hit": DEFAULT, # fails in combo with suppress_errors "suppress_errors": DEFAULT, +<<<<<<< HEAD "caching_precompile": False, # Required +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, } diff --git a/torch/_inductor/fx_passes/b2b_gemm.py b/torch/_inductor/fx_passes/b2b_gemm.py index ff434ccba0952..e494614498bdc 100644 --- a/torch/_inductor/fx_passes/b2b_gemm.py +++ b/torch/_inductor/fx_passes/b2b_gemm.py @@ -1,7 +1,10 @@ # mypy: allow-untyped-defs import functools from collections import deque +<<<<<<< HEAD from typing import Union +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch from torch.utils._ordered_set import OrderedSet @@ -13,7 +16,10 @@ FixedLayout, FlexibleLayout, InputBuffer, +<<<<<<< HEAD ShapeAsConstantBuffer, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) StorageBox, Subgraph, TensorBox, @@ -495,12 +501,19 @@ def convert_output_node_to_buffer(output): "The output node for B2B-GEMM's subgraph must be a StorageBox, but got: ", type(output_buffer), ) +<<<<<<< HEAD device = output_buffer.data.get_device() assert device is not None subgraph_buffer = ComputedBuffer( name=None, layout=FlexibleLayout( device=device, +======= + subgraph_buffer = ComputedBuffer( + name=None, + layout=FlexibleLayout( + device=output_buffer.data.get_device(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dtype=output_buffer.data.get_dtype(), size=output_buffer.data.get_size(), ), @@ -516,7 +529,11 @@ def convert_output_node_to_buffer(output): def create_placeholder( name: str, dtype: torch.dtype, device: torch.device +<<<<<<< HEAD ) -> Union[TensorBox, ShapeAsConstantBuffer]: +======= +) -> TensorBox: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Creates a placeholder input buffers for producing subgraph_output """ @@ -542,11 +559,16 @@ def tuned_b2b_gemm( A.get_dtype(), [A.shape[0], C.shape[1]], # type: ignore[index] ) +<<<<<<< HEAD placeholders = [ create_placeholder("inner_mm", A.get_dtype(), A.get_device_or_error()) ] subgraph_buffer = build_subgraph_buffer( placeholders, # type: ignore[arg-type, list-item] +======= + subgraph_buffer = build_subgraph_buffer( + [create_placeholder("inner_mm", A.get_dtype(), A.get_device_or_error())], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) subgraph, ) choices: list[TritonTemplateCaller] = [] diff --git a/torch/_inductor/fx_passes/decompose_mem_bound_mm.py b/torch/_inductor/fx_passes/decompose_mem_bound_mm.py index 31c6dae82fdbe..b3ccf1cd8d11d 100644 --- a/torch/_inductor/fx_passes/decompose_mem_bound_mm.py +++ b/torch/_inductor/fx_passes/decompose_mem_bound_mm.py @@ -4,10 +4,14 @@ import torch from torch import Tensor from torch._dynamo.utils import counters, is_node_meta_valid +<<<<<<< HEAD from torch.fx.experimental.symbolic_shapes import ( statically_known_false, statically_known_true, ) +======= +from torch.fx.experimental.symbolic_shapes import statically_known_true +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .. import config from ..pattern_matcher import Arg, CallFunction, Match, register_graph_pattern @@ -18,6 +22,7 @@ log = logging.getLogger(__name__) # TODO: need a better strategy for decomposing mm +<<<<<<< HEAD # The following two constants are for CUDA device only MIN_FIRST_DIMENSION_DECOMPOSITION = 10240 MAX_OTHER_DIMENSION_DECOMPOSITION = 32 @@ -29,6 +34,13 @@ max_other_dimension_decomposition = MAX_OTHER_DIMENSION_DECOMPOSITION cpu_max_first_dimension_decomposition = CPU_MAX_FIRST_DIMENSION_DECOMPOSITION cpu_max_other_dimension_decomposition = CPU_MAX_OTHER_DIMENSION_DECOMPOSITION +======= +MIN_FIRST_DIMENSION_DECOMPOSITION = 10240 +MAX_OTHER_DIMENSION_DECOMPOSITION = 32 + +min_first_dimension_decomposition = MIN_FIRST_DIMENSION_DECOMPOSITION +max_other_dimension_decomposition = MAX_OTHER_DIMENSION_DECOMPOSITION +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if "decompose_mm_pass" in config.post_grad_fusion_options: min_first_dimension_decomposition = config.post_grad_fusion_options[ "decompose_mm_pass" @@ -36,6 +48,7 @@ max_other_dimension_decomposition = config.post_grad_fusion_options[ "decompose_mm_pass" ].get("max_other_dimension_decomposition", MAX_OTHER_DIMENSION_DECOMPOSITION) +<<<<<<< HEAD cpu_max_first_dimension_decomposition = config.post_grad_fusion_options[ "decompose_mm_pass" ].get( @@ -46,6 +59,8 @@ ].get( "cpu_max_other_dimension_decomposition", CPU_MAX_OTHER_DIMENSION_DECOMPOSITION ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def check_device(a: Tensor, b: Tensor, device="cuda") -> bool: @@ -70,6 +85,7 @@ def should_decompose_bmm(mat1, mat2) -> bool: if mat1.shape[0] < min_first_dimension_decomposition: return False # 2 of m, n, k must be <= MAX_OTHER_DIMENSION_DECOMPOSITION +<<<<<<< HEAD # use bool() to deal with BooleanAtom type if ( bool(mat1.shape[1] < max_other_dimension_decomposition) @@ -84,11 +100,21 @@ def should_decompose_bmm(mat1, mat2) -> bool: mat1.shape[0] <= cpu_max_first_dimension_decomposition and mat2.shape[0] <= cpu_max_first_dimension_decomposition ): +======= + if (mat1.shape[1] < max_other_dimension_decomposition) + ( + mat1.shape[2] < max_other_dimension_decomposition + ) + (mat2.shape[2] < max_other_dimension_decomposition) < 2: + return False + return True + elif check_device(mat1, mat2, device="cpu"): + if mat1.shape[0] == 1 and mat2.shape[0] == 1: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return True return False def should_decompose_mm(mat1, mat2) -> bool: +<<<<<<< HEAD """ Determines whether matrix multiplication (mm) should be decomposed into pointwise operations based on the input matrices' metadata, shapes, device placement, and configuration options. @@ -118,6 +144,8 @@ def should_decompose_mm(mat1, mat2) -> bool: - Designed for use in graph optimization or fusion passes where decomposing large or dynamic matrix multiplications can improve performance or memory usage. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if is_node_meta_valid(mat1) and is_node_meta_valid(mat2): mat1 = mat1.meta["val"] mat2 = mat2.meta["val"] @@ -125,6 +153,7 @@ def should_decompose_mm(mat1, mat2) -> bool: return False if len(mat1.shape) != 2 or len(mat2.shape) != 2: return False +<<<<<<< HEAD # case 1: we skip decompose mm if the input is dynamic shape if not config.post_grad_fusion_options["decompose_mm_pass"].get( "skip_dynamic_shape_dim_check", False @@ -199,6 +228,19 @@ def should_decompose_mm(mat1, mat2) -> bool: ) ) ) +======= + return ( + check_device(mat1, mat2, device="cuda") + and statically_known_true(mat1.shape[0] >= min_first_dimension_decomposition) + and statically_known_true(mat2.shape[0] < max_other_dimension_decomposition) + and statically_known_true(mat2.shape[1] < max_other_dimension_decomposition) + ) or ( + check_device(mat1, mat2, device="cpu") + and statically_known_true(mat1.shape[0] == 1) + and statically_known_true(mat2.shape[0] <= 128) + and statically_known_true(mat2.shape[1] <= 512) + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def print_decompose_pattern(match: Match, inputs: list[torch.fx.Node]): diff --git a/torch/_inductor/fx_passes/fuse_attention.py b/torch/_inductor/fx_passes/fuse_attention.py index 5f449eb496642..dec986281fc7e 100644 --- a/torch/_inductor/fx_passes/fuse_attention.py +++ b/torch/_inductor/fx_passes/fuse_attention.py @@ -18,6 +18,10 @@ log = logging.getLogger(__name__) aten = torch.ops.aten +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _scaled_dot_product_attention = aten.scaled_dot_product_attention @@ -581,6 +585,7 @@ def _sfdp_replacement_20(query, key, value, attn_mask, dropout_p): ) +<<<<<<< HEAD def _sfdp_pattern_24(query, key, value, attention_mask): """ this pattern is for MBartForCausalLM/PLBartForCausalLM. @@ -617,6 +622,8 @@ def _sfdp_replacement_24(query, key, value, attention_mask): ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _sfdp_pattern_21(query, key, value, attn_mask): # for T5 with inplace add query = query.permute([0, 2, 1, 3]) @@ -1038,6 +1045,7 @@ def _get_sfdp_patterns(): {}, _sfdp_params_check, ), +<<<<<<< HEAD ( _sfdp_pattern_24, _sfdp_replacement_24, @@ -1045,6 +1053,8 @@ def _get_sfdp_patterns(): {}, _sfdp_extra_check, ), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] mask_fp32_patterns = ["pattern_16"] if dtype == torch.half: diff --git a/torch/_inductor/fx_passes/group_batch_fusion.py b/torch/_inductor/fx_passes/group_batch_fusion.py index 3f8ebe0a7d57d..e590c63d1af89 100644 --- a/torch/_inductor/fx_passes/group_batch_fusion.py +++ b/torch/_inductor/fx_passes/group_batch_fusion.py @@ -1365,6 +1365,7 @@ def apply_group_batch_fusion(graph: torch.fx.GraphModule, rule: GroupBatchFusion print_output=False, include_stride=True, include_device=True ) +<<<<<<< HEAD name = f"optimus_{str(rule.__class__.__name__)}" if "MTIA" in name: name = f"cff_{str(rule.__class__.__name__)}" @@ -1372,6 +1373,12 @@ def apply_group_batch_fusion(graph: torch.fx.GraphModule, rule: GroupBatchFusion "artifact", metadata_fn=lambda: { "name": name, +======= + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": f"optimus_{str(rule.__class__.__name__)}", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "encoding": "string", }, payload_fn=lambda: graph_str, diff --git a/torch/_inductor/fx_passes/mkldnn_fusion.py b/torch/_inductor/fx_passes/mkldnn_fusion.py index 868eb74824ddd..10e01b52ac955 100644 --- a/torch/_inductor/fx_passes/mkldnn_fusion.py +++ b/torch/_inductor/fx_passes/mkldnn_fusion.py @@ -109,15 +109,22 @@ def pack_linear_weight( # depends on the alignment of internally-stored metadata. # In aot mode, we need to firstly save the packed weight, when loading it, # it will be in a different address which doesn't work. +<<<<<<< HEAD # Disable MKL prepack linear in AOT mode. # Disable MKL prepack linear when batch_size has free symbols. +======= + # Disable MKL prepack linear in AOT mode +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) packed_weight_op = ( mkldnn._reorder_linear_weight if ( is_lp_weight or mkldnn._is_mkldnn_acl_supported() or V.aot_compilation +<<<<<<< HEAD or has_free_symbols(batch_size) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) else torch.ops.mkl._mkl_reorder_linear_weight ) @@ -130,12 +137,16 @@ def pack_linear( ): packed_linear_inputs: tuple[Any, ...] = (input, packed_weight_node) transpose_weight_node = packed_weight_node.args[0] +<<<<<<< HEAD if ( is_lp_weight or mkldnn._is_mkldnn_acl_supported() or V.aot_compilation or has_free_symbols(batch_size) ): +======= + if is_lp_weight or mkldnn._is_mkldnn_acl_supported() or V.aot_compilation: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) packed_linear_inputs += (bias, "none", [], "") packed_linear_op: Callable[..., Any] = mkldnn._linear_pointwise.default else: @@ -1225,6 +1236,10 @@ def is_const_or_cat_by_const(weight): weight_meta_value = linear_node.args[weight_idx].meta.get("val") if input_meta_value is None or weight_meta_value is None: return False +<<<<<<< HEAD +======= + batch_size = input_meta_value.shape[0] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if ( input_meta_value.dtype == torch.float64 or weight_meta_value.dtype == torch.float64 @@ -1234,6 +1249,7 @@ def is_const_or_cat_by_const(weight): torch.bfloat16, torch.float16, ) +<<<<<<< HEAD reduced_f32_matmul_enabled = torch.backends.mkldnn.matmul.fp32_precision in [ # type: ignore[attr-defined] "bf16", "tf32", @@ -1248,6 +1264,14 @@ def is_const_or_cat_by_const(weight): not compute_with_lp and not mkldnn._is_mkldnn_acl_supported() and not torch._C.has_mkl +======= + # on x86, for fp32, mkl should be enabled and batch_size should not be a free symbol. + # on aarch64, use mkldnn op for fp32 as well if acl is enabled + if ( + not is_lp_weight + and not mkldnn._is_mkldnn_acl_supported() + and ((not torch._C.has_mkl) or has_free_symbols(batch_size)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): return False for meta_value in [input_meta_value, weight_meta_value]: @@ -1458,6 +1482,7 @@ def linear(match, *args, **kwargs): torch.bfloat16, torch.float16, ) +<<<<<<< HEAD reduced_f32_matmul_enabled = ( torch.backends.mkldnn.matmul.fp32_precision in ["bf16", "tf32"] # type: ignore[attr-defined] ) @@ -1471,6 +1496,18 @@ def linear(match, *args, **kwargs): ) packed_linear_node = mkldnn_device_op.pack_linear( graph, compute_with_lp, batch_size, input, packed_weight_node, bias +======= + batch_size = input.meta.get("val").shape[0] + if has_free_symbols(batch_size): + assert is_lp_weight or mkldnn._is_mkldnn_acl_supported(), ( + f"only bf16/fp16 weight prepacking supports dynamic shape inputs but got {weight_dtype}" + ) + packed_weight_node = mkldnn_device_op.pack_linear_weight( + graph, is_lp_weight, transpose_weight_node, batch_size + ) + packed_linear_node = mkldnn_device_op.pack_linear( + graph, is_lp_weight, batch_size, input, packed_weight_node, bias +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) linear_node.replace_all_uses_with(packed_linear_node) diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index ba6953c091183..34e2af4f37b76 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -197,6 +197,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool): pass_name = "custom_backend_passes_" + device GraphTransformObserver(gm, pass_name).apply_gm_pass(custom_backend_pass) +<<<<<<< HEAD collectives_bucketing: bool = False if config.bucket_reduce_scatters_fx != "none": from torch._inductor.fx_passes.bucketing import bucket_reduce_scatter @@ -273,6 +274,12 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool): # ./fx_passes/README.md for a discussion of mutation invariants. GraphTransformObserver(gm, "reinplace_inplaceable_ops").apply_graph_pass( functools.partial(reinplace_inplaceable_ops, fake_tensor_updater), +======= + # Keep these last, since they introduces mutation. Look at + # ./fx_passes/README.md for a discussion of mutation invariants. + GraphTransformObserver(gm, "reinplace_inplaceable_ops").apply_graph_pass( + reinplace_inplaceable_ops +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) GraphTransformObserver( gm, "decompose_triton_kernel_wrapper_functional" @@ -1303,6 +1310,7 @@ def decomp(*flat_args): graph_pass.apply(graph) +<<<<<<< HEAD # Remove unused get_attr nodes and their corresponding attributes from the graph module. # When auto_functionalizing a hop, we need to clean up get_attr nodes for _constant_schema # and the auto_functionalized graph module that are no longer referenced. @@ -1342,6 +1350,21 @@ def decomp(*flat_args): assert isinstance(attr_name, str) delattr(graph.owning_module, attr_name) +======= + # We need to remove the get_attr registered for _constant_schema and the + # auto_functioanlized's graph module (it's replaced with original ) when auto_functionalize a hop. + _to_remove = [] + for node in graph.nodes: + if node.op == "get_attr" and len(node.users) == 0: + _to_remove.append(node) + if hasattr(graph.owning_module, node.target) and isinstance( + getattr(graph.owning_module, node.target), torch.fx.GraphModule + ): + delattr(graph.owning_module, node.target) + for node in _to_remove: + graph.erase_node(node) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) graph.lint() for _ in graph.find_nodes( @@ -1488,12 +1511,15 @@ def is_valid_addmm_fusion(match): if not matched: return False # Shape mismatch +<<<<<<< HEAD inp_dtype = inp.meta["val"].dtype # aten cublas integration assumes equal dtypes if inp_dtype != mat1.meta["val"].dtype or inp_dtype != mat2.meta["val"].dtype: return False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return not should_prefer_unfused_addmm(match) @@ -1762,6 +1788,7 @@ def __call__(self, graph: fx.Graph) -> None: movable_constructors = self.find_movable_constructors(graph, constructors) target_device = next(iter(target_devices)) +<<<<<<< HEAD movable_cpu_placeholders = movable_constructors & cpu_placeholders if movable_cpu_placeholders: node = next(iter(reversed(movable_cpu_placeholders))) @@ -1800,6 +1827,19 @@ def __call__(self, graph: fx.Graph) -> None: and x.target != torch.ops.aten.copy_.default, ) last_node = gpu_node +======= + for node in movable_constructors: + if node in cpu_placeholders: + with graph.inserting_after(node): + gpu_node = graph.call_function( + torch.ops.prims.device_put.default, (node, target_device) + ) + node.replace_all_uses_with( + gpu_node, + lambda x: x != gpu_node + and x.target != torch.ops.aten.copy_.default, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # noop elimination if there are other device_put for gpu_node to # target device. Alternatively, we could just move the other device_put @@ -1813,12 +1853,19 @@ def __call__(self, graph: fx.Graph) -> None: for noop in noop_device_puts: noop.replace_all_uses_with(gpu_node) graph.erase_node(noop) +<<<<<<< HEAD movable_constructors -= movable_cpu_placeholders for node in movable_constructors: kwargs = node.kwargs.copy() kwargs["device"] = target_device node.kwargs = kwargs +======= + else: + kwargs = node.kwargs.copy() + kwargs["device"] = target_device + node.kwargs = kwargs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def find_movable_constructors( self, graph: fx.Graph, constructors: list[fx.Node] diff --git a/torch/_inductor/fx_passes/quantization.py b/torch/_inductor/fx_passes/quantization.py index 01f62bdf608ce..365a3f8c257ef 100644 --- a/torch/_inductor/fx_passes/quantization.py +++ b/torch/_inductor/fx_passes/quantization.py @@ -72,6 +72,7 @@ def _get_pattern_output_dtype(match: Match): output_node = pattern_output_nodes[0] assert isinstance(output_node, torch.fx.Node) output_dtype = output_node.meta["val"].dtype +<<<<<<< HEAD assert output_dtype in [ torch.int8, torch.uint8, @@ -79,6 +80,9 @@ def _get_pattern_output_dtype(match: Match): torch.bfloat16, torch.float8_e4m3fn, ] +======= + assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return output_dtype @@ -1530,7 +1534,11 @@ def _find_first_node_in_dequant_pattern(_node): counters["inductor"]["dequant_promotion_matcher_nodes"] += len(match.nodes) +<<<<<<< HEAD def _is_valid_dequant_conv_pattern(dtype, with_dtype_convert): +======= +def _is_valid_dequant_conv_pattern(dtype): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _inner(match): # Here we do some further check to ensure: # 1. It's a conv2d node with dim of 4, since we only support lowering of conv2d now. @@ -1552,7 +1560,11 @@ def _inner(match): assert dtype in [torch.float32, torch.bfloat16] +<<<<<<< HEAD if not with_dtype_convert: +======= + if dtype == torch.float32: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dequant_node = conv_node.args[0] else: convert_to_bf16 = conv_node.args[0] @@ -1567,12 +1579,19 @@ def _inner(match): return _inner +<<<<<<< HEAD def _register_qconv_weight_prepack_pass( pattern, pass_number, dtype=torch.float32, with_dtype_convert=False ): @register_freezing_graph_pattern( pattern, extra_check=_is_valid_dequant_conv_pattern(dtype, with_dtype_convert), +======= +def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32): + @register_freezing_graph_pattern( + pattern, + extra_check=_is_valid_dequant_conv_pattern(dtype), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pass_number=pass_number, ) def qconv_weight_prepack(match: Match, *args, **kwargs): @@ -1592,7 +1611,11 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): assert dtype in [torch.float32, torch.bfloat16] conv_node = match.output_node() assert conv_node.target is aten.convolution.default +<<<<<<< HEAD if not with_dtype_convert: +======= + if dtype == torch.float32: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dequant_node = conv_node.args[0] else: convert_to_bf16 = conv_node.args[0] @@ -1697,7 +1720,11 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): # Erase the original conv node graph.erase_node(conv_node) # Erase the dequant pattern +<<<<<<< HEAD if with_dtype_convert: +======= + if dtype == torch.bfloat16: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) graph.erase_node(convert_to_bf16) # type: ignore[possibly-undefined, arg-type] graph.erase_node(dequant_node) # type: ignore[arg-type] # Erase the dequant per channel pattern @@ -1713,7 +1740,11 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): def _generate_dequant_convolution_node_pattern( +<<<<<<< HEAD _dequant_per_channel_pattern, dtype=torch.float32, with_dtype_convert=False +======= + _dequant_per_channel_pattern, dtype=torch.float32 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): assert dtype in [torch.float32, torch.bfloat16] dequant_convolution_node_pattern = CallFunction( @@ -1721,7 +1752,11 @@ def _generate_dequant_convolution_node_pattern( _may_generate_pattern_with_dtype_convert( get_dequantize_per_tensor_activation_pattern(), KeywordArg("autocast_act_dtype"), +<<<<<<< HEAD with_dtype_convert, +======= + dtype == torch.bfloat16, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), _dequant_per_channel_pattern, KeywordArg("b"), @@ -1735,9 +1770,13 @@ def _generate_dequant_convolution_node_pattern( return dequant_convolution_node_pattern +<<<<<<< HEAD def _generate_qconv_weight_prepack_patterns( dtype=torch.float32, with_dtype_convert=False ): +======= +def _generate_qconv_weight_prepack_patterns(dtype=torch.float32): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert dtype in [torch.float32, torch.bfloat16] return ( _generate_dequant_convolution_node_pattern( @@ -1745,7 +1784,10 @@ def _generate_qconv_weight_prepack_patterns( if dtype == torch.float32 else dequantize_per_channel_to_bf16_weight_pattern, dtype, +<<<<<<< HEAD with_dtype_convert, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), # There is another pattern due to the pass of convert_conv_weights_to_channels_last # https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/_inductor/freezing.py#L338-L362. @@ -1756,7 +1798,10 @@ def _generate_qconv_weight_prepack_patterns( if dtype == torch.float32 else dequantize_per_channel_to_bf16_clone_weight_pattern, dtype, +<<<<<<< HEAD with_dtype_convert, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), ) @@ -1784,11 +1829,15 @@ def _get_linear_node(match, input_dim_exceeds_two, input_contiguous): def _get_linear_dq_node( +<<<<<<< HEAD linear_node, input_index, input_dim_exceeds_two, input_contiguous, with_dtype_convert, +======= + linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): act_reshape_node = None activation_to_bf16_node = None @@ -1797,7 +1846,11 @@ def _get_linear_dq_node( if input_contiguous: act_reshape_node = linear_node.args[input_index] assert act_reshape_node.target is aten.reshape.default +<<<<<<< HEAD if not with_dtype_convert: +======= + if dtype == torch.float32: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # pattern: linear -> reshape -> dequant dequant_node = act_reshape_node.args[0] else: @@ -1808,13 +1861,21 @@ def _get_linear_dq_node( # bmm pattern decomposed from linear when input dim exceeds 2 and not contiguous act_expand_node = linear_node.args[input_index] assert act_expand_node.target is aten.expand.default +<<<<<<< HEAD if not with_dtype_convert: +======= + if dtype == torch.float32: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dequant_node = act_expand_node.args[0] else: activation_to_bf16_node = act_expand_node.args[0] dequant_node = activation_to_bf16_node.args[0] else: +<<<<<<< HEAD if not with_dtype_convert: +======= + if dtype == torch.float32: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # pattern: linear -> dequant dequant_node = linear_node.args[input_index] else: @@ -1824,9 +1885,13 @@ def _get_linear_dq_node( return dequant_node, act_reshape_node, activation_to_bf16_node, act_expand_node +<<<<<<< HEAD def _is_valid_dequant_linear_pattern( dtype, input_dim_exceeds_two, input_contiguous, with_dtype_convert ): +======= +def _is_valid_dequant_linear_pattern(dtype, input_dim_exceeds_two, input_contiguous): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _inner(match): # Check dequant pattern has only 1 user. ( @@ -1842,11 +1907,15 @@ def _inner(match): _, _, ) = _get_linear_dq_node( +<<<<<<< HEAD linear_node, input_index, input_dim_exceeds_two, input_contiguous, with_dtype_convert, +======= + linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) assert dequant_node.target in [ @@ -1908,12 +1977,19 @@ def _register_qlinear_weight_prepack_pass( dtype=torch.float32, input_dim_exceeds_two=False, input_contiguous=True, +<<<<<<< HEAD with_dtype_convert=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): @register_freezing_graph_pattern( pattern, extra_check=_is_valid_dequant_linear_pattern( +<<<<<<< HEAD dtype, input_dim_exceeds_two, input_contiguous, with_dtype_convert +======= + dtype, input_dim_exceeds_two, input_contiguous +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), pass_number=pass_number, ) @@ -1945,11 +2021,15 @@ def qlinear_weight_prepack(match: Match, *args, **kwargs): activation_to_bf16_node, act_expand_node, ) = _get_linear_dq_node( +<<<<<<< HEAD linear_node, input_index, input_dim_exceeds_two, input_contiguous, with_dtype_convert, +======= + linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if input_dim_exceeds_two and not input_contiguous: @@ -2056,7 +2136,11 @@ def qlinear_weight_prepack(match: Match, *args, **kwargs): else: graph.erase_node(act_expand_node) graph.erase_node(wgt_expand_node) # type: ignore[possibly-undefined] +<<<<<<< HEAD if with_dtype_convert: +======= + if dtype == torch.bfloat16: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) graph.erase_node(activation_to_bf16_node) # Erase the dequant pattern graph.erase_node(dequant_node) @@ -2077,7 +2161,10 @@ def _generate_dequant_linear_node_pattern( dtype=torch.float32, input_dim_exceeds_two=False, is_tensor_overload=False, +<<<<<<< HEAD with_dtype_convert=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): assert dtype in [torch.float32, torch.bfloat16] t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype) @@ -2089,7 +2176,11 @@ def _generate_dequant_linear_node_pattern( _may_generate_pattern_with_dtype_convert( get_dequantize_per_tensor_activation_pattern(is_tensor_overload), KeywordArg("autocast_act_dtype"), +<<<<<<< HEAD with_dtype_convert, +======= + dtype == torch.bfloat16, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), KeywordArg("act_reshape_size"), input_dim_exceeds_two, @@ -2106,7 +2197,11 @@ def _generate_dequant_linear_node_pattern( _may_generate_pattern_with_dtype_convert( get_dequantize_per_tensor_activation_pattern(is_tensor_overload), KeywordArg("autocast_act_dtype"), +<<<<<<< HEAD with_dtype_convert, +======= + dtype == torch.bfloat16, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), KeywordArg("act_reshape_size"), input_dim_exceeds_two, @@ -2124,7 +2219,10 @@ def _generate_dequant_bmm_node_pattern( dtype=torch.float32, with_bias=False, is_tensor_overload=False, +<<<<<<< HEAD with_dtype_convert=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): # When activation of linear dim exceed 2 and not contiguous t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype) @@ -2137,7 +2235,11 @@ def _generate_dequant_bmm_node_pattern( _may_generate_pattern_with_dtype_convert( get_dequantize_per_tensor_activation_pattern(is_tensor_overload), KeywordArg("autocast_act_dtype"), +<<<<<<< HEAD with_dtype_convert, +======= + dtype == torch.bfloat16, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), KeywordArg("act_expand_size"), ), @@ -2167,7 +2269,10 @@ def _generate_qlinear_weight_prepack_patterns( input_contiguous=True, with_bias=False, is_tensor_overload=False, +<<<<<<< HEAD with_dtype_convert=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): if input_dim_exceeds_two and not input_contiguous: return _generate_dequant_bmm_node_pattern( @@ -2175,7 +2280,10 @@ def _generate_qlinear_weight_prepack_patterns( dtype, with_bias, is_tensor_overload, +<<<<<<< HEAD with_dtype_convert, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) else: return _generate_dequant_linear_node_pattern( @@ -2183,7 +2291,10 @@ def _generate_qlinear_weight_prepack_patterns( dtype, input_dim_exceeds_two, is_tensor_overload, +<<<<<<< HEAD with_dtype_convert, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @@ -2299,6 +2410,7 @@ def _register_dequant_promotion(): def _register_qconv_weight_prepack(): +<<<<<<< HEAD for dtype, with_dtype_convert in itertools.product( [torch.float32, torch.bfloat16], [True, False] ): @@ -2314,6 +2426,14 @@ def _register_qconv_weight_prepack(): pass_number=1, dtype=dtype, with_dtype_convert=with_dtype_convert, +======= + for dtype in [torch.float32, torch.bfloat16]: + weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype) + for weight_prepack_pattern in weight_prepack_patterns: + # Register to pass_number 1, so we can do dequant promotion in pass_number 0. + _register_qconv_weight_prepack_pass( + weight_prepack_pattern, pass_number=1, dtype=dtype +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @@ -2353,6 +2473,7 @@ def _register_qlinear_weight_prepack(): # | OPT(add) | linear_weight_prepack_cases = itertools.product( +<<<<<<< HEAD [torch.float32, torch.bfloat16], [True, False], [True, False], [True, False] ) @@ -2365,11 +2486,21 @@ def _register_qlinear_weight_prepack(): ) in linear_weight_prepack_cases: if dtype == torch.float32 and with_dtype_convert: continue +======= + [torch.float32, torch.bfloat16], [True, False], [True, False] + ) + + # Step 1: register patterns from mm and addmm + for dtype, input_dim_exceeds_two, is_tensor_overload in linear_weight_prepack_cases: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) weight_prepack_patterns = _generate_qlinear_weight_prepack_patterns( dtype, input_dim_exceeds_two, is_tensor_overload=is_tensor_overload, +<<<<<<< HEAD with_dtype_convert=with_dtype_convert, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) for weight_prepack_pattern in weight_prepack_patterns: # Register to pass_number 1, so we can do dequant promotion in pass_number 0. @@ -2378,7 +2509,10 @@ def _register_qlinear_weight_prepack(): pass_number=1, dtype=dtype, input_dim_exceeds_two=input_dim_exceeds_two, +<<<<<<< HEAD with_dtype_convert=with_dtype_convert, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # Step 2: register patterns from bmm @@ -2386,6 +2520,7 @@ def _register_qlinear_weight_prepack(): # refer to: # https://github.com/pytorch/pytorch/blob/80c07df659362a95da7cd4f3ec367abfdace38c4/torch/_decomp/decompositions.py#L3965-L3968 # in this case, we can convert it back to qlinear +<<<<<<< HEAD for ( dtype, with_bias, @@ -2396,13 +2531,21 @@ def _register_qlinear_weight_prepack(): ): if dtype == torch.float32 and with_dtype_convert: continue +======= + for dtype, with_bias, is_tensor_overload in itertools.product( + [torch.float32, torch.bfloat16], [True, False], [True, False] + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bmm_pattern = _generate_qlinear_weight_prepack_patterns( dtype=dtype, input_dim_exceeds_two=True, input_contiguous=False, with_bias=with_bias, is_tensor_overload=is_tensor_overload, +<<<<<<< HEAD with_dtype_convert=with_dtype_convert, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) _register_qlinear_weight_prepack_pass( bmm_pattern, @@ -2412,7 +2555,10 @@ def _register_qlinear_weight_prepack(): dtype=dtype, input_dim_exceeds_two=True, input_contiguous=False, +<<<<<<< HEAD with_dtype_convert=with_dtype_convert, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) diff --git a/torch/_inductor/fx_passes/reinplace.py b/torch/_inductor/fx_passes/reinplace.py index b67c0dbb729ad..69271cb512180 100644 --- a/torch/_inductor/fx_passes/reinplace.py +++ b/torch/_inductor/fx_passes/reinplace.py @@ -759,6 +759,7 @@ def tensor_with_same_storage_already_reinplaced(arg): graph.erase_node(node) +<<<<<<< HEAD def reinplace_inplaceable_ops( fake_tensor_updater: torch._inductor.fx_utils.FakeTensorUpdater, graph: torch.fx.Graph, @@ -769,5 +770,10 @@ def reinplace_inplaceable_ops( # We run fake_tensor_updater to update the alias information. # Correct alias information is required for `reinplace_inplaceable_ops_core`. fake_tensor_updater.incremental_update() +======= +def reinplace_inplaceable_ops(graph: torch.fx.Graph) -> None: + with enable_python_dispatcher(): + canonicalize_view_scatter_ops(graph) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) reinplace_inplaceable_ops_core(graph) decompose_generalized_scatter(graph) diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py index af3631dc3288d..249623051a434 100644 --- a/torch/_inductor/fx_passes/split_cat.py +++ b/torch/_inductor/fx_passes/split_cat.py @@ -2,7 +2,10 @@ import itertools import logging import operator +<<<<<<< HEAD import os +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from collections import defaultdict from collections.abc import Sequence from typing import Any, Callable, Optional, Union @@ -10,7 +13,11 @@ import torch from torch._dynamo.utils import counters +<<<<<<< HEAD from torch.fx.experimental.symbolic_shapes import free_symbols, guard_or_false +======= +from torch.fx.experimental.symbolic_shapes import free_symbols +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.utils._ordered_set import OrderedSet from ..pattern_matcher import ( @@ -63,7 +70,10 @@ "split_stack_to_cats_pass", "unbind_stack_to_slices_pass", "move_reshape_out_of_split_stack_pass", +<<<<<<< HEAD "einsum_to_pointwise_pass", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] post_grad_pass_names = [ @@ -77,8 +87,11 @@ "move_view_after_cat_aten_pass", ] +<<<<<<< HEAD backend = os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_BACKEND", "inductor") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for pass_name in pre_grad_pass_names: # exclude all passes from the group batch fusion # they do not use pattern matcher @@ -210,7 +223,11 @@ def normalize_split_base( split_node.replace_all_uses_with(new_split_node) new_split_node.meta.update(split_node.meta) graph.erase_node(split_node) +<<<<<<< HEAD counters[backend]["normalization_pass"] += 1 +======= + counters["inductor"]["normalization_pass"] += 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register_graph_pattern( @@ -262,7 +279,11 @@ def remove_split_with_size_one(match: Match, *args, **kwargs): # erase the split node and its child graph.erase_node(user) graph.erase_node(split_node) +<<<<<<< HEAD counters[backend]["remove_split_with_size_one_pass"] += 1 +======= + counters["inductor"]["remove_split_with_size_one_pass"] += 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register_graph_pattern( @@ -302,6 +323,7 @@ def normalize_unbind_default(match: Match, *args, **kwargs): node.replace_all_uses_with(new_node) new_node.meta.update(node.meta) graph.erase_node(node) +<<<<<<< HEAD counters[backend]["normalization_pass"] += 1 @@ -310,6 +332,18 @@ def normalize_unbind_default(match: Match, *args, **kwargs): pass_dict=construct_pattern_matcher_pass("normalization_pass"), ) def normalize_cat_default(match: Match, *args, **kwargs): +======= + counters["inductor"]["normalization_pass"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.cat, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_cat_default(match: Match, *args, **kwargs): + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cat_node = match.nodes[0] graph = match.graph tensors = get_arg_value(cat_node, 0, "tensors") @@ -334,7 +368,11 @@ def normalize_cat_default(match: Match, *args, **kwargs): def is_empty_tensor(x): # special case where torch.cat supports cat'ing with an empty tensor x_shape = x.meta["example_value"].shape +<<<<<<< HEAD return len(x_shape) == 1 and guard_or_false(x_shape[0] == 0) +======= + return len(x_shape) == 1 and guard_size_oblivious(x_shape[0] == 0) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert all( ndim == x.meta["example_value"].dim() or is_empty_tensor(x) for x in tensors @@ -349,7 +387,10 @@ def is_empty_tensor(x): cat_node.args == new_args and cat_node.kwargs == new_kwargs and cat_node.op == "call_function" +<<<<<<< HEAD and cat_node.target == torch.cat +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): return @@ -362,7 +403,11 @@ def is_empty_tensor(x): cat_node.replace_all_uses_with(new_cat_node) new_cat_node.meta.update(cat_node.meta) graph.erase_node(cat_node) +<<<<<<< HEAD counters[backend]["normalization_pass"] += 1 +======= + counters["inductor"]["normalization_pass"] += 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register_graph_pattern( @@ -398,7 +443,11 @@ def normalize_stack_default(match: Match, *args, **kwargs): node.replace_all_uses_with(new_node) new_node.meta.update(node.meta) graph.erase_node(node) +<<<<<<< HEAD counters[backend]["normalization_pass"] += 1 +======= + counters["inductor"]["normalization_pass"] += 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def find_next_users(split_node: torch.fx.Node) -> list[torch.fx.Node]: @@ -659,7 +708,11 @@ def merge_splits( for node in to_remove: graph.erase_node(node) +<<<<<<< HEAD counters[backend]["merge_splits_pass"] += 1 +======= + counters["inductor"]["merge_splits_pass"] += 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class SplitCatSimplifier: @@ -719,7 +772,11 @@ def simplify( transform_params_list, # type: ignore[arg-type] ) self.erase_old_nodes(graph, split_node, next_users) # type: ignore[arg-type] +<<<<<<< HEAD counters[backend]["unbind_stack_pass"] += 1 +======= + counters["inductor"]["unbind_stack_pass"] += 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_user_input_list( self, split_node: torch.fx.Node, next_users: list[torch.fx.Node] @@ -818,9 +875,15 @@ def get_simplified_split_ranges( split_ranges = self.fill_gaps(split_ranges, 0, cumulative_sizes[-1]) if len(split_sections) == len(split_ranges): # Simplification not possible return None +<<<<<<< HEAD counters[backend]["scmerge_split_sections_removed"] = len(split_sections) - len( split_ranges ) +======= + counters["inductor"]["scmerge_split_sections_removed"] = len( + split_sections + ) - len(split_ranges) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return split_ranges def has_non_overlapping_ranges(self, ranges: list[_Range]) -> bool: @@ -928,7 +991,11 @@ def replace_split( [r[1] - r[0] for r in split_ranges], dim=split_dim, ) +<<<<<<< HEAD counters[backend]["scmerge_split_added"] += 1 +======= + counters["inductor"]["scmerge_split_added"] += 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) split_items = [] with graph.inserting_after(new_split): for i in range(len(split_ranges)): @@ -1089,7 +1156,11 @@ def replace_cat( user_inputs_new_transformed_meta, dim=cat_dim, ) +<<<<<<< HEAD counters[backend]["scmerge_cat_added"] += 1 +======= + counters["inductor"]["scmerge_cat_added"] += 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: new_cat_node = user_inputs_new_transformed[-1] new_cat_node.meta["example_value"] = ( @@ -1121,12 +1192,20 @@ def erase_old_nodes( next_users: list[torch.fx.Node], ): to_remove = [split_node] +<<<<<<< HEAD counters[backend]["scmerge_split_removed"] += 1 +======= + counters["inductor"]["scmerge_split_removed"] += 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) to_remove.extend(split_node.users.keys()) for next_user in next_users: if next_user.target not in (torch.cat, torch.stack): continue +<<<<<<< HEAD counters[backend]["scmerge_cat_removed"] += 1 +======= + counters["inductor"]["scmerge_cat_removed"] += 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) to_remove.append(next_user) for node in reversed(to_remove): if len(node.users.keys()) == 0: @@ -1319,7 +1398,11 @@ def merge_split_squeeze( graph.erase_node(squeeze) graph.erase_node(getitem_node) graph.erase_node(split) +<<<<<<< HEAD counters[backend]["split_cat_pass"] += 1 +======= + counters["inductor"]["split_cat_pass"] += 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) getitem_unbind = ListOf( @@ -1579,7 +1662,11 @@ def merge_getitem_cat(match: Match, split_sections: list[int], dim: int): split_node = new_split_node split_sections = new_split_sections +<<<<<<< HEAD counters[backend]["merge_getitem_cat_pass"] += 1 +======= + counters["inductor"]["merge_getitem_cat_pass"] += 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # ############pattern to be optimized is######### @@ -1640,7 +1727,11 @@ def mutate_cat_node(match: Match, split_sections: list[int], dim: int): cat_user.replace_all_uses_with(split_node.args[0]) # type: ignore[arg-type] # remove the cat node graph.erase_node(cat_user) +<<<<<<< HEAD counters[backend]["mutate_cat_pass"] += 1 +======= + counters["inductor"]["mutate_cat_pass"] += 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # case 2: the cat uses some getitems from the split elif is_node_meta_valid(split_node.args[0]): # type: ignore[arg-type] # check the split dim, and construct the slice tuple @@ -1668,7 +1759,11 @@ def mutate_cat_node(match: Match, split_sections: list[int], dim: int): # remove the cat node graph.erase_node(cat_user) +<<<<<<< HEAD counters[backend]["mutate_cat_pass"] += 1 +======= + counters["inductor"]["mutate_cat_pass"] += 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) getitem_split_aten = ListOf( @@ -1703,12 +1798,15 @@ def normalize_split_default_aten(match: Match, *args, **kwargs): return if split_dim < 0: # Normalize split dim split_dim += split_input.meta["val"].dim() +<<<<<<< HEAD # we also need to check the input of the split_node # primals =torch.randn(4096, 300) # split = torch.ops.aten.split.Tensor(primals, 320, 1) -> truncate to 300 automatically # split_2 = torch.ops.aten.split_with_sizes.default(primals, [320], dim = 1) -> runtime error split_input_size = split_input.meta["val"].shape[split_dim] split_size = min(split_size, split_input_size) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) split_section_list = [split_size] * (len(split_node.meta["val"])) new_args = (split_input, split_section_list) new_kwargs = {"dim": split_dim} @@ -1728,7 +1826,11 @@ def normalize_split_default_aten(match: Match, *args, **kwargs): split_node.replace_all_uses_with(new_split_node) new_split_node.meta.update(split_node.meta) graph.erase_node(split_node) +<<<<<<< HEAD counters[backend]["normalization_aten_pass"] += 1 +======= + counters["inductor"]["normalization_aten_pass"] += 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register_graph_pattern( @@ -1769,7 +1871,11 @@ def normalize_split_with_size_default_aten(match: Match, *args, **kwargs): split_node.replace_all_uses_with(new_split_node) new_split_node.meta.update(split_node.meta) graph.erase_node(split_node) +<<<<<<< HEAD counters[backend]["normalization_aten_pass"] += 1 +======= + counters["inductor"]["normalization_aten_pass"] += 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register_graph_pattern( @@ -1792,11 +1898,15 @@ def merge_split_cat_aten(match: Match, *args, **kwargs): for cat_node in list(getitem_nodes[0].users.keys()): cat_dim = get_arg_value(cat_node, 1, "dim") cat_inputs = get_arg_value(cat_node, 0, "tensors") +<<<<<<< HEAD try: cat_input_len = len(cat_inputs) except TypeError: continue if cat_input_len < threshold_to_cat: +======= + if len(cat_inputs) < threshold_to_cat: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue # check split node and cat node has same dim, and all getitem nodes have same parent node parent_to_indices = defaultdict(list) # type: ignore[var-annotated] @@ -1870,7 +1980,11 @@ def merge_split_cat_aten(match: Match, *args, **kwargs): graph.erase_node(getitem_node) if len(split_node.users) == 0: graph.erase_node(split_node) +<<<<<<< HEAD counters[backend]["split_cat_aten_pass"] += 1 +======= + counters["inductor"]["split_cat_aten_pass"] += 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register_graph_pattern( @@ -1930,7 +2044,11 @@ def merge_select_cat_aten(match: Match, *args, **kwargs): for select_node in select_nodes: if len(select_node.users) == 0: graph.erase_node(select_node) +<<<<<<< HEAD counters[backend]["select_cat_aten_pass"] += 1 +======= + counters["inductor"]["select_cat_aten_pass"] += 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register_graph_pattern( @@ -1978,7 +2096,11 @@ def is_empty_tensor(x: torch.fx.Node) -> bool: cat_node.replace_all_uses_with(new_cat_node) new_cat_node.meta.update(cat_node.meta) graph.erase_node(cat_node) +<<<<<<< HEAD counters[backend]["normalization_aten_pass"] += 1 +======= + counters["inductor"]["normalization_aten_pass"] += 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register_graph_pattern( @@ -2039,7 +2161,11 @@ def merge_unbind_stack_aten(match: Match, *args, **kwargs): for select_node in select_nodes: if len(select_node.users) == 0: graph.erase_node(select_node) +<<<<<<< HEAD counters[backend]["unbind_stack_aten_pass"] += 1 +======= + counters["inductor"]["unbind_stack_aten_pass"] += 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def divide_into_consecutive_sublists(indices: list[int]) -> list[list[int]]: @@ -2377,7 +2503,11 @@ def split_cat_to_slices(match: Match, split_sections: list[int], dim: int): cat_inputs = cat_node.args[0] # type: ignore[union-attr] graph.erase_node(cat_node) remove_split_unbind_children(graph, cat_inputs) # type: ignore[arg-type] +<<<<<<< HEAD counters[backend]["split_cat_to_slices_pass"] += 1 +======= + counters["inductor"]["split_cat_to_slices_pass"] += 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue if len(new_cat_args) > 1 and len(new_cat_args) < len(cat_inputs): new_args = (new_cat_args,) @@ -2393,7 +2523,11 @@ def split_cat_to_slices(match: Match, split_sections: list[int], dim: int): # remove the cat node graph.erase_node(cat_node) remove_split_unbind_children(graph, cat_inputs) +<<<<<<< HEAD counters[backend]["split_cat_to_slices_pass"] += 1 +======= + counters["inductor"]["split_cat_to_slices_pass"] += 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # ############pattern to be optimized is######### @@ -2454,7 +2588,11 @@ def unbind_cat_to_view(match: Match, unbind_input: torch.fx.Node, dim: int): cat_inputs = cat_node.args[0] # type: ignore[union-attr] graph.erase_node(cat_node) remove_split_unbind_children(graph, cat_inputs) # type: ignore[arg-type] +<<<<<<< HEAD counters[backend]["unbind_cat_to_view_pass"] += 1 +======= + counters["inductor"]["unbind_cat_to_view_pass"] += 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs): # get the view shape @@ -2474,7 +2612,11 @@ def unbind_cat_to_view(match: Match, unbind_input: torch.fx.Node, dim: int): cat_inputs = cat_node.args[0] # type: ignore[union-attr] graph.erase_node(cat_node) remove_split_unbind_children(graph, cat_inputs) # type: ignore[arg-type] +<<<<<<< HEAD counters[backend]["unbind_cat_to_view_pass"] += 1 +======= + counters["inductor"]["unbind_cat_to_view_pass"] += 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def reshape_cat_node_to_stack( @@ -2624,7 +2766,11 @@ def split_stack_to_cats(match: Match, split_sections: list[int], dim: int): # case 1: only one node in the new cat args, don't need to cat if len(new_cat_args) == 1: reshape_cat_node_to_stack(graph, new_cat_args[0], stack_node, split_dim) +<<<<<<< HEAD counters[backend]["split_stack_to_cats_pass"] += 1 +======= + counters["inductor"]["split_stack_to_cats_pass"] += 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs): with graph.inserting_after(stack_node): @@ -2637,7 +2783,11 @@ def split_stack_to_cats(match: Match, split_sections: list[int], dim: int): new_cat_args_meta, dim=split_dim ) reshape_cat_node_to_stack(graph, cat_node, stack_node, split_dim) +<<<<<<< HEAD counters[backend]["split_stack_to_cats_pass"] += 1 +======= + counters["inductor"]["split_stack_to_cats_pass"] += 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # ############pattern to be optimized is######### @@ -2696,7 +2846,11 @@ def unbind_stack_to_slices(match: Match, unbind_input: torch.fx.Node, dim: int): # case 1: only one node in the new cat args, don't need to cat if len(new_cat_args) == 1: reshape_cat_node_to_stack(graph, new_cat_args[0], stack_node, unbind_dim) +<<<<<<< HEAD counters[backend]["unbind_stack_to_slices_pass"] += 1 +======= + counters["inductor"]["unbind_stack_to_slices_pass"] += 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs): # get the view shape @@ -2711,7 +2865,11 @@ def unbind_stack_to_slices(match: Match, unbind_input: torch.fx.Node, dim: int): new_cat_args_meta, dim=cat_dim ) reshape_cat_node_to_stack(graph, new_cat_node, stack_node, unbind_dim) +<<<<<<< HEAD counters[backend]["unbind_stack_to_slices_pass"] += 1 +======= + counters["inductor"]["unbind_stack_to_slices_pass"] += 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # ############pattern to be optimized is######### @@ -2816,7 +2974,11 @@ def move_reshape_out_of_split_stack(match: Match, *args, **kwargs): # check the input of stack node, and remove nodes that have no users remove_split_unbind_children(graph, stack_inputs) # type: ignore[arg-type] remove_split_unbind_children(graph, split_users) # type: ignore[arg-type] +<<<<<<< HEAD counters[backend]["move_reshape_out_of_split_stack_pass"] += 1 +======= + counters["inductor"]["move_reshape_out_of_split_stack_pass"] += 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs): # decompose the cat args into multiple stack nodes, i.e., we stack @@ -2878,7 +3040,11 @@ def move_reshape_out_of_split_stack(match: Match, *args, **kwargs): graph.erase_node(stack_node) remove_split_unbind_children(graph, stack_inputs) # type: ignore[arg-type] remove_split_unbind_children(graph, split_users) # type: ignore[arg-type] +<<<<<<< HEAD counters[backend]["move_reshape_out_of_split_stack_pass"] += 1 +======= + counters["inductor"]["move_reshape_out_of_split_stack_pass"] += 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) view_getitem_split_aten = ListOf( @@ -2970,6 +3136,7 @@ def move_view_after_cat(match: Match, *args, **kwargs): cat_node.replace_all_uses_with(view_node) view_node.meta.update(cat_node.meta) graph.erase_node(cat_node) +<<<<<<< HEAD counters[backend]["move_view_after_cat_aten_pass"] += 1 @@ -3033,3 +3200,6 @@ def should_replace_einsum(einsum_node) -> bool: if should_replace_einsum(einsum_node): match.replace_by_example(repl, [input, weights]) counters[backend]["einsum_to_pointwise_pass"] += 1 +======= + counters["inductor"]["move_view_after_cat_aten_pass"] += 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_inductor/fx_utils.py b/torch/_inductor/fx_utils.py index c754c0324868e..7abc2551770e1 100644 --- a/torch/_inductor/fx_utils.py +++ b/torch/_inductor/fx_utils.py @@ -1,5 +1,8 @@ # mypy: allow-untyped-defs +<<<<<<< HEAD import contextlib +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import operator from collections import defaultdict from typing import Any, Callable, Optional @@ -89,7 +92,10 @@ def hash_node(self, node: torch.fx.Node): return (node, node.target, id(node.args), id(node.kwargs)) def incremental_update(self): +<<<<<<< HEAD """Update FakeTensors on self.graph. We will try to do the minimum amount of work.""" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) existing_storages: defaultdict[Optional[int], int] = defaultdict(int) for node in self.graph.nodes: existing_storages[get_node_storage(node)] += 1 @@ -97,15 +103,23 @@ def incremental_update(self): def is_intlist_same(new, old): return statically_known_true(sym_eq(new, old)) +<<<<<<< HEAD def is_fake_tensor_same(new, old, *, node): +======= + def is_fake_tensor_same(new, old): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if type(new) != type(old): return False if isinstance(new, (list, tuple)): if len(new) != len(old): return False return all( +<<<<<<< HEAD is_fake_tensor_same(new_i, old_i, node=node) for new_i, old_i in zip(new, old) +======= + is_fake_tensor_same(new_i, old_i) for new_i, old_i in zip(new, old) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if new is None: return old is None @@ -135,6 +149,7 @@ def is_fake_tensor_same(new, old, *, node): if get_storage(new) == get_storage(old): return True +<<<<<<< HEAD def any_user_may_alias(node): if not isinstance(node.meta["val"], torch.Tensor): # analysis too complicated on lists, can support in the future @@ -190,6 +205,14 @@ def any_user_may_alias(node): ): return True +======= + # This is the case where it returns a completely fresh storage that's used nowhere else. + if ( + existing_storages[get_storage(old)] == 1 + and get_storage(new) not in existing_storages + ): + return True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return False def should_process_node(node): @@ -201,16 +224,22 @@ def should_process_node(node): return node.op == "call_function" and ( isinstance(node.target, torch._ops.OpOverload) or node.target == operator.getitem +<<<<<<< HEAD or node.target == torch._inductor.fx_passes.reinplace._generalized_scatter +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) to_process = OrderedSet[int]() for node in self.graph.nodes: +<<<<<<< HEAD # NB: Be very careful about skipping nodes (via continues) here # and ask for a careful review when changing this code. The # consequence for incorrect FakeTensor metadata is difficult-to-debug # silent incorrectness. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if ( self.hash_node(node) in self.processed_hashes and id(node) not in to_process @@ -225,9 +254,14 @@ def should_process_node(node): continue with V.fake_mode, enable_python_dispatcher(): new_fake_tensor = node.target(*args, **kwargs) +<<<<<<< HEAD if "val" in node.meta and is_fake_tensor_same( new_fake_tensor, node.meta["val"], node=node +======= + if "val" in node.meta and is_fake_tensor_same( + new_fake_tensor, node.meta["val"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): continue @@ -314,7 +348,11 @@ def realizes_inputs(node: torch.fx.Node) -> bool: def count_flops_fx(node: torch.fx.Node) -> Optional[int]: +<<<<<<< HEAD if not countable_fx(node) or isinstance(node.target, str): +======= + if isinstance(node.target, str): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return None with FakeTensorMode(allow_non_fake_inputs=True): success, args, kwargs = get_fake_args_kwargs(node) diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index d10dc7a464261..fc58ce93597fb 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -65,7 +65,10 @@ MissingOperatorWithDecomp, MissingOperatorWithoutDecomp, ) +<<<<<<< HEAD from .fx_utils import count_flops_fx +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .ir import ( Constant, DonatedBuffer, @@ -75,7 +78,10 @@ InputBuffer, Pointwise, Reduction, +<<<<<<< HEAD ShapeAsConstantBuffer, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) StorageBox, TensorBox, TorchBindObject, @@ -123,7 +129,10 @@ from torch.fx.graph import Graph from .codegen.wrapper import PythonWrapperCodegen +<<<<<<< HEAD from .dependencies import Dep +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .scheduler import BaseSchedulerNode CompiledModule = Union[ModuleType, FileBackedGraphModule] @@ -312,7 +321,10 @@ def __init__( const_module: Optional[GraphLowering] = None, name: Optional[str] = None, inputs_to_check: Optional[Sequence[int]] = None, +<<<<<<< HEAD fx_wrapper: bool = False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: super().__init__(gm) self.example_inputs = example_inputs @@ -342,7 +354,10 @@ def __init__( shape_env.deferred_runtime_asserts.copy() ) self.bound_unbacked_symbols = OrderedSet[sympy.Symbol]() +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.sizevars = SizeVarAllocator(shape_env) self.graph_input_names: list[str] = [] self.graph_inputs: dict[str, Union[TensorBox, TorchBindObject, sympy.Expr]] = {} @@ -393,6 +408,11 @@ def __init__( self.inplaced_to_remove: OrderedSet[str] = OrderedSet() self.device_ops: DeviceOpOverrides = None # type: ignore[assignment] self.wrapper_code: PythonWrapperCodegen = None # type: ignore[assignment] +<<<<<<< HEAD +======= + # See `ProxyExecutor Design Note` in ir.py for more details + self.extern_kernel_nodes: list[ir.ExternKernelNode] = [] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.extern_node_serializer import extern_node_json_serializer @@ -412,7 +432,10 @@ def __init__( self.creation_time = time.time() self.name = name # type: ignore[assignment] self.cpp_wrapper = cpp_wrapper +<<<<<<< HEAD self.fx_wrapper = fx_wrapper +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # record multi_kernel choice for cpp_wrapper so the second pass knows # which sub-kernel is picked. Copy cpp_wrapper to another variable @@ -487,9 +510,12 @@ def __init__( self.bw_donated_idxs = get_donated_idxs() +<<<<<<< HEAD # Cache for dep size hints to avoid expensive recomputation self.dep_size_hint_cache: dict[Dep, int] = {} +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def freeze_runtime_asserts(self) -> None: self._shape_env.freeze_runtime_asserts() @@ -575,6 +601,7 @@ def has_feature( assert isinstance(feature, BackendFeature), feature return feature in self.get_backend_features(get_device_type(device)) +<<<<<<< HEAD def get_dep_size_hint(self, dep: Dep) -> int: """ Get the size hint for a dependency with caching to avoid expensive recomputation. @@ -592,6 +619,8 @@ def get_dep_size_hint(self, dep: Dep) -> int: self.dep_size_hint_cache[dep] = res return self.dep_size_hint_cache[dep] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_current_device_or_throw(self) -> torch.device: if device := self.current_device: return device @@ -682,6 +711,7 @@ def is_small_channel(n: torch.fx.Node) -> bool: # only grouped convolutions benchmarked as slower in conv samples for inference only if is_inference: +<<<<<<< HEAD flop_counts: dict[str, float] = defaultdict(float) for node in conv_nodes: counted_flops = count_flops_fx(node) @@ -700,6 +730,34 @@ def is_small_channel(n: torch.fx.Node) -> bool: flop_counts[node_type] += counted_flops else: log.debug("Conv inputs meta not found") +======= + from torch.utils.flop_counter import FlopCounterMode + + flop_counts: dict[str, float] = defaultdict(float) + for node in conv_nodes: + success, args, kwargs = torch._inductor.fx_utils.get_fake_args_kwargs( + node + ) + + if success: + with FlopCounterMode(display=False) as flop_counter_mode: + with V.fake_mode: + node.target(*args, **kwargs) + + counted_flops = flop_counter_mode.get_total_flops() + if is_grouped(node): + node_type = "grouped" + elif is_small_channel(node): + node_type = "small" + elif is_in_out_channel(node): + node_type = "in_out" + else: + node_type = "default" + + flop_counts[node_type] += counted_flops + else: + log.debug("Conv inputs meta not found") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # average benchmarked channels last speedup / slowdown, < 1 is speedup. # taken from the set of convolution inputs in benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/ @@ -1046,7 +1104,11 @@ def allocate_non_dup_const_name( def add_tensor_constant( self, data: Tensor, name: Optional[str] = None +<<<<<<< HEAD ) -> Union[TensorBox, ir.ShapeAsConstantBuffer]: +======= + ) -> TensorBox: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) new_name = self.allocate_non_dup_const_name(name, data) return TensorBox.create( ir.ConstantBuffer( @@ -1112,11 +1174,18 @@ def placeholder( return None # See note: Note: [Generator arguments in AOTDispatcher] elif isinstance(example, torch.Generator): +<<<<<<< HEAD assert len(V.graph.current_node.users) == 1 and next( iter(V.graph.current_node.users) ).target in ( torch._prims.rng_prims.graphsafe_run_with_rng_state, torch.ops.higher_order.invoke_subgraph, +======= + assert ( + len(V.graph.current_node.users) == 1 + and next(iter(V.graph.current_node.users)).target + is torch._prims.rng_prims.graphsafe_run_with_rng_state +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) gen = ir.GeneratorState(name=target, device=example.device) self.graph_inputs[target] = gen # type: ignore[assignment] @@ -1156,7 +1225,11 @@ def placeholder( self.graph_inputs[target] = tensor self.graph_input_names.append(target) +<<<<<<< HEAD self.graph_inputs_original[target] = tensor.data.data # type: ignore[union-attr] +======= + self.graph_inputs_original[target] = tensor.data.data +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.current_node.users: # cudagraphs should work with an unused CPU input self.add_device_info(example.device) @@ -1206,9 +1279,13 @@ def call_function(self, target: Callable, args: Any, kwargs: dict[str, Any]) -> error.operator_str(target, args, kwargs), ) +<<<<<<< HEAD tag: Optional[torch._C.Tag] = get_layout_constraint_tag( target, with_default=False ) +======= + tag = get_layout_constraint_tag(target, with_default=False) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if ( tag is None and torch._library.utils.is_builtin(target) @@ -1225,10 +1302,15 @@ def call_function(self, target: Callable, args: Any, kwargs: dict[str, Any]) -> # and identify them one by one. decided_constraint = require_contiguous # type: ignore[assignment] else: +<<<<<<< HEAD default_tag: torch._C.Tag = get_layout_constraint_tag( target, with_default=True ) decided_constraint = tag_to_layout_constraint(default_tag) +======= + tag = get_layout_constraint_tag(target, with_default=True) + decided_constraint = tag_to_layout_constraint(tag) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) make_fallback(target, layout_constraint=decided_constraint) @@ -1302,9 +1384,13 @@ def get_attr( target: str, # type: ignore[override] args: tuple[()], # type: ignore[override] kwargs: dict[str, object], +<<<<<<< HEAD ) -> Union[ Constant, TensorBox, ShapeAsConstantBuffer, ir.Subgraph, TorchBindObject ]: +======= + ) -> Union[Constant, TensorBox, ir.Subgraph, TorchBindObject]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # this is a constant value = getattr_recursive(self.module, target) # type: ignore[arg-type] @@ -1565,10 +1651,15 @@ def debug(msg: str) -> None: ): if ( n.op == "call_function" +<<<<<<< HEAD # this path only for built-in operators and n.target and isinstance(n.target, torch._ops.OpOverload) and torch._library.utils.is_builtin(n.target) +======= + and n.target + not in (operator.getitem, torch._higher_order_ops.invoke_subgraph) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) and ( fallback_node_due_to_unsupported_type(n) or CompilerBisector.disable_subsystem( @@ -1829,7 +1920,11 @@ def debug(msg: str) -> None: shape_env = V.graph.sizevars.shape_env +<<<<<<< HEAD # An input can be unbacked symint i.e.: when mark_unabcked is used. +======= + # An input can an unbacked symint i.e.: when mark_unabcked is used. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # in that case add it to new_unbacked_defs. if ( n.op == "placeholder" @@ -1896,7 +1991,10 @@ def format_new_defs() -> str: V.fake_mode.shape_env.unbacked_renamings.get(s, s) for s in unbacked_bindings.keys() ) +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert new_unbacked_defs >= renamed_unbacked_bindings, ( f"failed {new_unbacked_defs} >= {renamed_unbacked_bindings} (inductor >= fx)\n" f"fx node is: {n.format_node()}\n" @@ -2019,7 +2117,11 @@ def init_wrapper_code( self.device_ops = get_device_op_overrides(self.device_type) wrapper_code_gen_cls = get_wrapper_codegen_for_device( +<<<<<<< HEAD self.device_type, self.cpp_wrapper, self.fx_wrapper +======= + self.device_type, self.cpp_wrapper +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) assert wrapper_code_gen_cls is not None, ( f"Device {self.device_type} not supported" diff --git a/torch/_inductor/index_propagation.py b/torch/_inductor/index_propagation.py index 0dc0a00412a83..fcee6e6a442d4 100644 --- a/torch/_inductor/index_propagation.py +++ b/torch/_inductor/index_propagation.py @@ -65,6 +65,7 @@ def is_constant(self): def __post_init__(self): if _is_constant(self.expr): +<<<<<<< HEAD expr = self.expr if isinstance(expr, sympy.Expr): expr = expr.expand(identity=True) @@ -77,6 +78,9 @@ def __post_init__(self): if self.dtype.is_signed: expr = expr - 2 ** (bits - 1) self.expr = expr +======= + self.expr = dtype_to_type(self.dtype)(self.expr) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class SymPyOps: diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index a454e4f5f77be..9dc68c88a19bf 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -6,27 +6,41 @@ import itertools import logging import operator +<<<<<<< HEAD import os import textwrap import traceback from collections.abc import Container, Generator, Iterable, Iterator, Sequence +======= +import textwrap +import traceback +import typing +from collections.abc import Generator, Iterable, Sequence +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from contextlib import AbstractContextManager, nullcontext from enum import Enum from functools import partial from typing import ( Any, Callable, +<<<<<<< HEAD cast, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ClassVar, Literal, Optional, overload, +<<<<<<< HEAD SupportsFloat, SupportsInt, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TYPE_CHECKING, TypeVar, Union, ) +<<<<<<< HEAD from typing_extensions import ( assert_never, Never, @@ -36,6 +50,9 @@ TypeAlias, TypeIs, ) +======= +from typing_extensions import assert_never, Never, TypeAlias +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from unittest.mock import patch import sympy @@ -50,7 +67,10 @@ from torch._export.serde.serialize import GraphModuleSerializer from torch._higher_order_ops.auto_functionalize import can_auto_functionalize from torch._inductor import metrics +<<<<<<< HEAD from torch._inductor.utils import get_free_symbols +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._prims_common import ( compute_required_storage_length, is_boolean_dtype, @@ -64,12 +84,22 @@ compute_unbacked_bindings, free_symbols, free_unbacked_symbols, +<<<<<<< HEAD rebind_unbacked, resolve_unbacked_bindings, ShapeEnv, SymTypes, ) from torch.fx.node import Node +======= + IterateExprs, + rebind_unbacked, + resolve_unbacked_bindings, + ShapeEnv, + statically_known_true, + SymTypes, +) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.functions import CleanDiv, FloorDiv, ModularIndexing from torch.utils._sympy.symbol import SymT @@ -80,7 +110,10 @@ CodegenSymbol, get_scheduling_for_device, index_prevent_reordering, +<<<<<<< HEAD Kernel, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) from .dependencies import ( Dep, @@ -121,11 +154,17 @@ if TYPE_CHECKING: from torch._library.fake_class_registry import FakeScriptObject +<<<<<<< HEAD from torch.fx.experimental.symbolic_shapes import SympyBoolean from torch.fx.node import Argument from .codegen.cuda.cuda_template import CUDATemplate from .codegen.wrapper import PythonWrapperCodegen +======= + from torch.fx.node import Node + + from .codegen.cuda.cuda_template import CUDATemplate +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .graph import GraphLowering from .utils import IndentedBuffer @@ -143,7 +182,10 @@ has_triton = False +<<<<<<< HEAD _P = ParamSpec("_P") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _T = TypeVar("_T") _U = TypeVar("_U") _V = TypeVar("_V") @@ -151,15 +193,21 @@ _IntLike: TypeAlias = Union[int, Expr] _NumLike: TypeAlias = Union[int, float, Expr] +<<<<<<< HEAD _OpOverloads: TypeAlias = Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) log = logging.getLogger(__name__) indent = functools.partial(textwrap.indent, prefix=" ") aten = torch.ops.aten +<<<<<<< HEAD autotune_warmup = int(os.getenv("TORCH_AUTOTUNE_WARMUP", 25)) autotune_rep = int(os.getenv("TORCH_AUTOTUNE_REP", 100)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ [Note: Inductor IR] Inductor's IR is produced by executing 'lowering' code (see lowering.py). Each @@ -272,7 +320,11 @@ def _check_tensorbox(nodes: Optional[_NodeOrNodes]) -> None: def ops_wrapper(name: str) -> Callable[..., OpsValue]: +<<<<<<< HEAD assert isinstance(name, str), type(name) +======= + assert isinstance(name, str) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def fn(*args: object, **kwargs: object) -> OpsValue: return getattr(ops, name)(*args, **kwargs) @@ -308,6 +360,16 @@ def reindex(index: Sequence[_T]) -> Sequence[_V]: return reindex +<<<<<<< HEAD +======= +def get_free_symbols(x: IterateExprs, unbacked_only: bool) -> OrderedSet[sympy.Symbol]: + if unbacked_only: + return free_unbacked_symbols(x) + else: + return free_symbols(x) + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) NHWC_STRIDE_ORDER = [3, 0, 2, 1] NHWDC_STRIDE_ORDER = [4, 0, 3, 2, 1] @@ -318,7 +380,11 @@ def get_fill_order( """ Convert strides to fill order (argsort) """ +<<<<<<< HEAD if shape_env is None or all(isinstance(s, (int, sympy.Integer)) for s in seq): +======= + if shape_env is None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sorted_idx: Sequence[int] = argsort(seq) else: # argsort_sym handles unbacked symints (with the help of the shape_env) @@ -425,7 +491,11 @@ def is_triton(x: Union[IRNode, torch.device, None, str]) -> bool: return False from .codegen.triton import TritonScheduling +<<<<<<< HEAD assert isinstance(device_scheduling, type), type(device_scheduling) +======= + assert isinstance(device_scheduling, type) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return issubclass(device_scheduling, TritonScheduling) @@ -433,6 +503,7 @@ def is_cpu(x: Union[IRNode, torch.device, None, str]) -> bool: return get_device_type(x) == "cpu" +<<<<<<< HEAD def is_aligned_realized_tensor_hint( x: Union[Buffer, TensorBox], alignment: int ) -> bool: @@ -443,6 +514,10 @@ def is_aligned_realized_tensor_hint( or free_unbacked_symbols(x.get_stride()) or free_unbacked_symbols(x.get_size()) ): +======= +def is_aligned_realized_tensor(x: Union[Buffer, TensorBox], alignment: int) -> bool: + if not isinstance(x, IRNode) or x.maybe_get_stride() is None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return False aligned_strides = all( @@ -467,7 +542,11 @@ def significant_strides_equal( """ assert len(shape) == len(strides1) and len(strides1) == len(strides2) for dim, s1, s2 in zip(shape, strides1, strides2): +<<<<<<< HEAD if V.graph.sizevars.statically_known_leq(dim, 1): +======= + if V.graph.sizevars.statically_known_leq(dim, 1): # type: ignore[arg-type] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue if not V.graph.sizevars.statically_known_equals( @@ -481,9 +560,15 @@ def significant_strides_equal( def try_match_insignificant_strides( +<<<<<<< HEAD tensor: IRNode, strides: Sequence[Union[int, torch.SymInt]], ) -> IRNode: +======= + tensor: Union[TensorBox, BaseView], + strides: Sequence[Union[int, torch.SymInt]], +) -> Union[TensorBox, BaseView]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Tries to match the strides of the tensor to those in the meta_strides. Strides of insignificant dimensions - size 0 or 1 - will be updated. @@ -497,7 +582,11 @@ def try_match_insignificant_strides( V.graph.sizevars.statically_known_equals(s1, s2) for s1, s2 in zip(strides, tensor.get_stride()) ): +<<<<<<< HEAD return tensor +======= + return tensor # type: ignore[arg-type] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not significant_strides_equal(strides, tensor.get_stride(), tensor.get_size()): return tensor @@ -505,7 +594,11 @@ def try_match_insignificant_strides( storage, old_layout = as_storage_and_layout(tensor) new_stride = [*old_layout.stride] for i, s in enumerate(tensor.get_size()): +<<<<<<< HEAD if V.graph.sizevars.statically_known_leq(s, 1): +======= + if V.graph.sizevars.statically_known_leq(s, 1): # type: ignore[arg-type] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) new_stride[i] = strides[i] new_layout = FixedLayout( @@ -514,7 +607,10 @@ def try_match_insignificant_strides( old_layout.size, new_stride, old_layout.offset, +<<<<<<< HEAD old_layout.is_pinned, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return TensorBox(ReinterpretView(data=storage, layout=new_layout)) @@ -529,7 +625,11 @@ def gm_original_output_strides(gm: torch.fx.GraphModule) -> None: record_original_output_strides(gm) +<<<<<<< HEAD def get_symbolic_inputs(inputs: Sequence[IRNode]) -> list[Expr]: +======= +def get_symbolic_inputs(inputs: list[Buffer]) -> list[Expr]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sym_vars: OrderedSet[Expr] = OrderedSet() for inp in inputs: sym_vars |= get_free_symbols(inp.get_size(), unbacked_only=False) @@ -539,6 +639,7 @@ def get_symbolic_inputs(inputs: Sequence[IRNode]) -> list[Expr]: class IRNode: +<<<<<<< HEAD """Base class for all intermediate representation (IR) nodes in TorchInductor. Note: @@ -546,11 +647,16 @@ class IRNode: and must be overridden by concrete subclasses. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _current_origins: ClassVar[OrderedSet[Any]] = OrderedSet() # NB: These are kinda weird, origins: OrderedSet[Any] = dataclasses.field(init=False) +<<<<<<< HEAD # traces back to where the IRNode is created in Inductor +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) traceback: Optional[list[str]] = dataclasses.field(init=False) origin_node: Optional[torch.fx.Node] = dataclasses.field(init=False) @@ -564,6 +670,7 @@ def current_origins(origins: OrderedSet[Node]) -> Generator[None, None, None]: finally: IRNode._current_origins = old +<<<<<<< HEAD @staticmethod def is_realized_node(node: IRNode) -> bool: return isinstance( @@ -577,6 +684,8 @@ def is_realized_node(node: IRNode) -> bool: ), ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _post_init_setattr(self, attr: str, value: Any) -> None: # Intended for use in __post_init__ for enforcing an invariant on a dataclass # If you must, can also be used for setting provenance info @@ -584,8 +693,12 @@ def _post_init_setattr(self, attr: str, value: Any) -> None: object.__setattr__(self, attr, value) def __post_init__(self) -> None: +<<<<<<< HEAD origins = OrderedSet(self._current_origins) self._post_init_setattr("origins", origins) +======= + self._post_init_setattr("origins", OrderedSet(self._current_origins)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._post_init_setattr( "traceback", traceback.format_stack() if config.debug_ir_traceback else None ) @@ -603,6 +716,7 @@ def get_origin_node(self) -> Optional[torch.fx.Node]: def get_defining_op(self) -> Optional[Operation]: return None +<<<<<<< HEAD def get_stack_traces(self) -> OrderedSet[str]: # Return stack traces to user model code # A single IRNode could correspond to multiple lines of code @@ -634,11 +748,14 @@ def get_stack_traces(self) -> OrderedSet[str]: stack_traces.add(stack_trace) return stack_traces +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def common_repr(self, shorten: bool = True) -> Sequence[str]: origins = f"origins={getattr(self, 'origins', '')}" if shorten and len(origins) > 64: # this can get *very* long origins = f"{origins[:61]}..." +<<<<<<< HEAD if not self.get_stack_traces(): return [origins] @@ -648,6 +765,9 @@ def common_repr(self, shorten: bool = True) -> Sequence[str]: stack_trace_str += stack_trace.split("\n") stack_trace_str.append("}") return [origins] + stack_trace_str +======= + return [origins] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def str_helper( self, lines: Sequence[object], shorten: bool = True, multiline: bool = True @@ -788,6 +908,7 @@ def freeze_layout(self) -> None: raise NotImplementedError(type(self).__name__) def freeze_layout_with_stride_order( +<<<<<<< HEAD self, order: Sequence[int], allow_padding: bool = False ) -> None: raise NotImplementedError(type(self).__name__) @@ -800,6 +921,20 @@ def freeze_layout_with_same_order(self, stride: Sequence[_IntLike]) -> None: def freeze_layout_with_exact_strides( self, exact_strides: Sequence[_IntLike], allow_padding: bool = False +======= + self, order: list[int], allow_padding: bool = False + ) -> None: + raise NotImplementedError(type(self).__name__) + + def freeze_layout_with_fill_order(self, order: list[int]) -> None: + raise NotImplementedError(type(self).__name__) + + def freeze_layout_with_same_order(self, stride: list[_IntLike]) -> None: + raise NotImplementedError(type(self).__name__) + + def freeze_layout_with_exact_strides( + self, exact_strides: list[_IntLike], allow_padding: bool = False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: raise NotImplementedError(type(self).__name__) @@ -823,7 +958,11 @@ def get_free_symbol_uses( def get_reduction_type(self) -> Optional[str]: raise NotImplementedError(type(self).__name__) +<<<<<<< HEAD def get_reduction_size(self) -> Sequence[Expr]: +======= + def get_reduction_size(self) -> Sequence[sympy.Expr]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise NotImplementedError(type(self).__name__) def is_extern(self) -> bool: @@ -973,9 +1112,13 @@ def get_pointwise_size(self) -> Sequence[Expr]: return self.ranges @classmethod +<<<<<<< HEAD def create( cls, *args: Any, **kwargs: Any ) -> Union[TensorBox, ShapeAsConstantBuffer]: +======= + def create(cls, *args: Any, **kwargs: Any) -> TensorBox: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) origin_node = kwargs.pop("origin_node", None) tb = kwargs.pop("traceback", None) r = cls(*args, **kwargs) @@ -1042,7 +1185,11 @@ def get_read_names(self) -> OrderedSet[str]: def num_reads(self) -> int: return len(self.inner_fn_opcount().read_buffers) +<<<<<<< HEAD def get_reduction_size(self) -> Sequence[Expr]: +======= + def get_reduction_size(self) -> Sequence[sympy.Expr]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise NotImplementedError( f"get_reduction_size() is not implemented by {type(self)}!" ) @@ -1094,10 +1241,14 @@ def constant_to_device(self, device: torch.device) -> IRNode: loader = self.make_loader() loader = patch.object(ConstantBuffer, "override_device", device)(loader) return Pointwise( +<<<<<<< HEAD device=device, dtype=self.dtype, inner_fn=loader, ranges=self.ranges, +======= + device=device, dtype=self.dtype, inner_fn=loader, ranges=self.ranges +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @@ -1124,7 +1275,11 @@ def store_output( output_name: Optional[str], indexer: Callable[[Sequence[Expr]], Never], vars: Sequence[Expr], +<<<<<<< HEAD ) -> Any: +======= + ) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) loader = self.make_loader() if output_name is None: output_name = "unnamed" @@ -1227,7 +1382,11 @@ def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol *(get_free_symbols(e, unbacked_only) for e in self.reduction_ranges) ) +<<<<<<< HEAD def get_reduction_size(self) -> Sequence[Expr]: +======= + def get_reduction_size(self) -> Sequence[sympy.Expr]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.reduction_ranges def get_reduction_type(self) -> Optional[str]: @@ -1246,7 +1405,11 @@ def store_reduction( self.reduction_type, self.inner_fn(vars, reduction_vars), ) +<<<<<<< HEAD ops.store_reduction(output_name or "unnamed", indexer(vars), value) +======= + return ops.store_reduction(output_name or "unnamed", indexer(vars), value) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def index_length(self) -> int: return len(self.ranges) + len(self.reduction_ranges) @@ -1283,7 +1446,11 @@ def num_splits( device: torch.device, dst_dtype: torch.dtype, src_dtype: torch.dtype, +<<<<<<< HEAD inner_fn: Callable[_P, OpsValue], +======= + inner_fn: Callable[..., OpsValue], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ranges: Sequence[_IntLike], reduction_ranges: Sequence[_IntLike], reduction_type: Union[ReductionType, Literal["scan"]], @@ -1375,12 +1542,19 @@ def inner_reduction_splits( ) def get_read_indices(r: Reduction) -> tuple[Sequence[Expr], bool]: +<<<<<<< HEAD device = r.get_device() assert device is not None cb = ComputedBuffer( name=None, layout=FlexibleLayout( device=device, +======= + cb = ComputedBuffer( + name=None, + layout=FlexibleLayout( + device=r.get_device(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dtype=r.get_dtype(), size=r.get_size(), ), @@ -1449,7 +1623,13 @@ def _unroll_reduction_fn( src_dtype: torch.dtype, ) -> Callable[[Sequence[_IntLike]], OpsValue]: """Convert inner_fn from a reduction to an pointwise""" +<<<<<<< HEAD reduction_ranges = V.graph.sizevars.guard_int_seq(reduction_ranges) +======= + reduction_ranges = [ + V.graph.sizevars.evaluate_static_shape(x) for x in reduction_ranges + ] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) combine_fn = get_reduction_combine_fn(reduction_type, src_dtype) @@ -1466,10 +1646,19 @@ def fn(index: Sequence[_IntLike]) -> Any: value_fn: Callable[[Sequence[_IntLike], Sequence[_IntLike]], Any] if reduction_type in ("argmin", "argmax"): +<<<<<<< HEAD flatten_index = _fixed_indexer( reduction_ranges, FlexibleLayout.contiguous_strides(reduction_ranges), ) +======= + flatten_index = FixedLayout( + None, # type: ignore[arg-type] + None, # type: ignore[arg-type] + reduction_ranges, + FlexibleLayout.contiguous_strides(reduction_ranges), + ).make_indexer() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def value_fn( index: Sequence[_IntLike], rindex: Sequence[_IntLike] @@ -1497,7 +1686,11 @@ def create( reduction_type: ReductionType, reduction_hint: ReductionHint = ReductionHint.DEFAULT, input_node: Optional[IRNode] = None, +<<<<<<< HEAD ) -> Union[TensorBox, ShapeAsConstantBuffer]: +======= + ) -> TensorBox: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges)) if reduction_numel == 0: @@ -1508,10 +1701,17 @@ def py_cnst(val: object) -> Union[bool, float, int]: if dst_dtype == torch.bool: return bool(val) elif dst_dtype.is_floating_point: +<<<<<<< HEAD assert isinstance(val, SupportsFloat), type(val) return float(val) else: assert isinstance(val, SupportsInt), type(val) +======= + assert isinstance(val, typing.SupportsFloat) + return float(val) + else: + assert isinstance(val, typing.SupportsInt) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return int(val) rtypes_to_inits = { @@ -1601,10 +1801,16 @@ def _maybe_increase_split(split: int) -> int: reduction_hint = hint if split == -1: assert input_node is not None +<<<<<<< HEAD with patch.object(FlexibleLayout, "allow_indexing", True): new_ranges, new_reduction_ranges = extract_input_node_reduction_ranges( input_node ) +======= + new_ranges, new_reduction_ranges = extract_input_node_reduction_ranges( + input_node + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert new_ranges is not None assert new_reduction_ranges is not None return cls.create_multilayer_existing_ranges( @@ -1781,7 +1987,11 @@ def body() -> OpsValue: @classmethod def _multilayer_wrap_loader_existing_ranges( cls, +<<<<<<< HEAD loader: Callable[[Sequence[Expr], Sequence[Expr]], OpsValue], +======= + loader: Callable[[Sequence[sympy.Expr], Sequence[sympy.Expr]], OpsValue], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) original_ranges: Sequence[Expr], original_reduction_ranges: Sequence[Expr], new_ranges: Sequence[Integer], @@ -1795,8 +2005,13 @@ def _multilayer_wrap_loader_existing_ranges( ) def wrapper_fn( +<<<<<<< HEAD merged_index: Sequence[Expr], new_reduction_index: Sequence[Expr], +======= + merged_index: Sequence[sympy.Expr], + new_reduction_index: Sequence[sympy.Expr], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> OpsValue: original_idx = merged_index[: len(original_ranges)] new_index = merged_index[len(original_ranges) :] @@ -1821,7 +2036,11 @@ def create_multilayer_helper( reduction_type: ReductionType, split: _IntLike, reduction_hint: ReductionHint, +<<<<<<< HEAD ) -> Union[TensorBox, ShapeAsConstantBuffer]: +======= + ) -> TensorBox: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Break a large reduction up into multiple smaller reductions recursively @@ -1884,7 +2103,11 @@ def create_multilayer( split: _IntLike, reduction_hint: ReductionHint, input_node: Optional[IRNode] = None, +<<<<<<< HEAD ) -> Union[TensorBox, ShapeAsConstantBuffer]: +======= + ) -> TensorBox: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Break a large reduction up into multiple smaller reductions recursively @@ -1930,7 +2153,11 @@ def create_multilayer_existing_ranges( new_reduction_ranges: list[Integer], reduction_type: ReductionType, reduction_hint: ReductionHint, +<<<<<<< HEAD ) -> Union[TensorBox, ShapeAsConstantBuffer]: +======= + ) -> TensorBox: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Break a large reduction up into multiple smaller reductions recursively @@ -1957,6 +2184,7 @@ def create_multilayer_existing_ranges( ) +<<<<<<< HEAD def _fixed_indexer( size: Sequence[int], stride: Optional[Sequence[int]] = None, @@ -1977,6 +2205,9 @@ def indexer(index: Sequence[int]) -> int: INNER_FN_TY: TypeAlias = Callable[[Sequence[Expr], Sequence[Expr]], OpsValue] +======= +INNER_FN_TY = Callable[[Sequence[Expr], Sequence[Expr]], OpsValue] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class MultiOutputReduction(Reduction): @@ -2025,14 +2256,22 @@ def store_reduction( indexer: Callable[[Sequence[Expr]], Never], vars: Sequence[Expr], reduction_vars: Sequence[Symbol], +<<<<<<< HEAD ) -> Any: +======= + ) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) values = ops.reduction( self.dtype, self.src_dtype, self.reduction_type, self.inner_fn(vars, reduction_vars), ) +<<<<<<< HEAD assert isinstance(values, (tuple, list)), type(values) +======= + assert isinstance(values, (tuple, list)), f"{type(values)}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) value = values[self.output_index] return ops.store_reduction(output_name or "unnamed", indexer(vars), value) @@ -2050,7 +2289,11 @@ def create( # type: ignore[override] num_output: int, reduction_hint: ReductionHint = ReductionHint.DEFAULT, input_node: Optional[IRNode] = None, +<<<<<<< HEAD ) -> Sequence[Union[TensorBox, ShapeAsConstantBuffer]]: +======= + ) -> Sequence[TensorBox]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Create the reduction disregarding splitting. """ @@ -2062,7 +2305,11 @@ def create( # type: ignore[override] inner_fn, ranges, reduction_ranges, +<<<<<<< HEAD "online_softmax_reduce", +======= + "online_softmax_reduce", # type: ignore[arg-type] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) src_dtype, reduction_hint, output_idx, @@ -2086,12 +2333,20 @@ def create( # type: ignore[override] reduction_ranges: list[Integer], reduction_type: ReductionType, reduction_hint: ReductionHint = ReductionHint.DEFAULT, +<<<<<<< HEAD ) -> Sequence[Union[TensorBox, ShapeAsConstantBuffer]]: +======= + ) -> Sequence[TensorBox]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert reduction_type in ("welford_reduce", "welford_combine") reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges)) +<<<<<<< HEAD def const(val: int) -> Union[TensorBox, ShapeAsConstantBuffer]: +======= + def const(val: int) -> TensorBox: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def inner_fn(idx: Sequence[Expr]) -> OpsValue: return ops.constant( val, @@ -2115,7 +2370,11 @@ def inner_fn(idx: Sequence[Expr]) -> OpsValue: def copy( loader: Callable[[Sequence[Expr], Sequence[Expr]], OpsValue], +<<<<<<< HEAD ) -> Union[TensorBox, ShapeAsConstantBuffer]: +======= + ) -> TensorBox: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def inner_fn(idx: Sequence[Expr]) -> OpsValue: reduction_index = [sympy.S.Zero for _ in reduction_ranges] return loader(idx, reduction_index) @@ -2214,7 +2473,11 @@ def create_multilayer( # type: ignore[override] reduction_type: ReductionType, split: _IntLike, reduction_hint: ReductionHint, +<<<<<<< HEAD ) -> Sequence[Union[TensorBox, ShapeAsConstantBuffer]]: +======= + ) -> Sequence[TensorBox]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Break a large reduction up into multiple smaller reductions recursively @@ -2335,7 +2598,11 @@ def store_reduction( indexer: Callable[[Sequence[_IntLike]], Never], vars: Sequence[Expr], scan_vars: Sequence[Symbol], +<<<<<<< HEAD ) -> Any: +======= + ) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) idx = self.reindex(vars, scan_vars) values = tuple(inner_fn(idx) for inner_fn in self.inner_fns) result = ops.scan(self.dtypes, self.combine_fn, values) @@ -2347,7 +2614,11 @@ def get_reduction_type(self) -> Optional[str]: # return self.scan_op return "custom" +<<<<<<< HEAD def get_reduction_size(self) -> Sequence[Expr]: +======= + def get_reduction_size(self) -> Sequence[sympy.Expr]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.scan_ranges def get_size(self) -> Sequence[Expr]: @@ -2385,7 +2656,11 @@ def create( # type: ignore[override] # Whether we have the option to fallback to aten can_fallback_to_aten: bool = True, **kwargs: Any, +<<<<<<< HEAD ) -> Sequence[Optional[Union[TensorBox, ShapeAsConstantBuffer]]]: +======= + ) -> Sequence[Optional[TensorBox]]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pointwise_ranges = [*size[:axis], *size[axis + 1 :]] scan_ranges = [size[axis]] @@ -2541,7 +2816,11 @@ def store_reduction( indexer: Callable[[Sequence[Expr]], Expr], vars: Sequence[Expr], reduction_vars: Sequence[Expr], +<<<<<<< HEAD ) -> Any: +======= + ) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) idx = self.reindex(vars, reduction_vars) values = tuple(inner_fn(idx) for inner_fn in self.inner_fns) result = ops.sort(self.dtypes, values, self.stable, self.descending) @@ -2588,7 +2867,11 @@ def create( # type: ignore[override] descending: bool, reduction_hint: ReductionHint = ReductionHint.DEFAULT, **kwargs: Any, +<<<<<<< HEAD ) -> Sequence[Optional[Union[TensorBox, ShapeAsConstantBuffer]]]: +======= + ) -> Sequence[Optional[TensorBox]]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pointwise_ranges = [*size[:axis], *size[axis + 1 :]] sort_ranges = [size[axis]] @@ -2752,8 +3035,13 @@ def is_unaligned(node: IRNode) -> bool: if isinstance(node, ReinterpretView): layout = node.layout +<<<<<<< HEAD has_unaligned_layout = not V.graph.sizevars.statically_known_multiple_of( layout.offset * get_dtype_size(layout.dtype), GPU_ALIGN_BYTES +======= + has_unaligned_layout = not statically_known_true( + layout.offset * get_dtype_size(layout.dtype) % GPU_ALIGN_BYTES == 0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return is_unaligned(node.data) or has_unaligned_layout @@ -2820,6 +3108,7 @@ def has_exceeded_max_reads(self) -> bool: def realize(self) -> Optional[str]: return self.data.realize() +<<<<<<< HEAD def realize_hint(self) -> None: self.data.realize_hint() @@ -2832,6 +3121,19 @@ def is_extern(self) -> bool: def is_module_buffer(self) -> bool: assert isinstance(self.data, BaseView), type(self.data) return self.data.is_module_buffer() +======= + def realize_hint(self): # type: ignore[no-untyped-def] + return self.data.realize_hint() + + def get_storage_numel(self): # type: ignore[no-untyped-def] + return self.data.get_storage_numel() + + def is_extern(self) -> bool: + return self.data.is_extern() # type: ignore[attr-defined] + + def is_module_buffer(self) -> bool: + return self.data.is_module_buffer() # type: ignore[attr-defined] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_read_names(self) -> OrderedSet[str]: return self.data.get_read_names() @@ -2840,10 +3142,17 @@ def get_reads(self) -> OrderedSet[Dep]: with patch.object(FlexibleLayout, "allow_indexing", True): return extract_read_writes( self.make_loader(), +<<<<<<< HEAD self.get_size(), ).reads def unwrap_view(self) -> IRNode: +======= + self.get_size(), # type: ignore[arg-type] + ).reads + + def unwrap_view(self): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x: IRNode = self while isinstance(x, BaseView): x = x.data @@ -2863,6 +3172,7 @@ def constant_to_device(self, device: torch.device) -> IRNode: @ir_dataclass class ExpandView(BaseView): +<<<<<<< HEAD size: Sequence[Expr] @staticmethod @@ -2870,6 +3180,15 @@ def _normalize_size(x: IRNode, new_size: Sequence[_IntLike]) -> Sequence[_IntLik """Replace `-1` with correct sizes""" sizevars = V.graph.sizevars new_size = [sympy.expand(s) for s in new_size] +======= + size: list[Expr] + + @staticmethod + def _normalize_size(x, new_size): # type: ignore[no-untyped-def] + """Replace `-1` with correct sizes""" + sizevars = V.graph.sizevars + new_size = list(map(sympy.expand, new_size)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) old_size = x.get_size() old_size = [None] * (len(new_size) - len(old_size)) + list(old_size) assert len(new_size) == len(old_size) @@ -2877,8 +3196,13 @@ def _normalize_size(x: IRNode, new_size: Sequence[_IntLike]) -> Sequence[_IntLik if new_size[i] == -1: assert old_size[i] is not None new_size[i] = old_size[i] +<<<<<<< HEAD elif old_size[i] is None or V.graph.sizevars.is_size_one_or_false( old_size[i] +======= + elif old_size[i] is None or V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(old_size[i], 1), size_oblivious=True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): pass else: @@ -2893,7 +3217,11 @@ def _normalize_size(x: IRNode, new_size: Sequence[_IntLike]) -> Sequence[_IntLik return new_size @classmethod +<<<<<<< HEAD def create(cls, x: IRNode, new_size: Sequence[_IntLike]) -> BaseView: +======= + def create(cls, x, new_size): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) new_size = cls._normalize_size(x, new_size) if is_storage_and_layout(x): @@ -2904,7 +3232,13 @@ def create(cls, x: IRNode, new_size: Sequence[_IntLike]) -> BaseView: for stride, size in zip(old_layout.stride, old_layout.size): new_stride.append( stride +<<<<<<< HEAD if not V.graph.sizevars.is_size_one_or_false(size) +======= + if not V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(size, 1), size_oblivious=True + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else sympy.S.Zero ) new_layout = FixedLayout( @@ -2913,7 +3247,10 @@ def create(cls, x: IRNode, new_size: Sequence[_IntLike]) -> BaseView: list(new_size), new_stride, old_layout.offset, +<<<<<<< HEAD old_layout.is_pinned, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return ReinterpretView(data=storage, layout=new_layout) @@ -2922,16 +3259,24 @@ def create(cls, x: IRNode, new_size: Sequence[_IntLike]) -> BaseView: def get_size(self) -> Sequence[Expr]: return self.size +<<<<<<< HEAD def make_reindexer( self, ) -> Callable[[Sequence[Expr]], Sequence[Expr]]: +======= + def make_reindexer(self): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) target = self.get_size() actual = self.data.get_size() skip = len(target) - len(actual) +<<<<<<< HEAD def reindex( index: Sequence[Expr], ) -> Sequence[Expr]: +======= + def reindex(index): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) index = list(index[skip:]) assert len(index) == len(actual) for i in range(len(actual)): @@ -2948,7 +3293,11 @@ class PermuteView(BaseView): dims: list[Expr] @classmethod +<<<<<<< HEAD def create(cls, x: IRNode, dims: Sequence[int]) -> BaseView: +======= + def create(cls, x, dims): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dims = cls._map_neg_dims(dims) assert OrderedSet(dims) == OrderedSet(range(len(dims))) @@ -2960,14 +3309,21 @@ def create(cls, x: IRNode, dims: Sequence[int]) -> BaseView: [old_layout.size[i] for i in dims], [old_layout.stride[i] for i in dims], old_layout.offset, +<<<<<<< HEAD old_layout.is_pinned, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return ReinterpretView(data=storage, layout=new_layout) return PermuteView(data=x, dims=dims) @classmethod +<<<<<<< HEAD def _map_neg_dims(cls, dims: Sequence[int]) -> list[int]: +======= + def _map_neg_dims(cls, dims): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return [dim if dim >= 0 else len(dims) + dim for dim in dims] def get_size(self) -> Sequence[Expr]: @@ -2977,16 +3333,24 @@ def get_size(self) -> Sequence[Expr]: size = self.data.get_size() return [size[i] for i in self.dims] +<<<<<<< HEAD def make_reindexer( self, ) -> Callable[[Sequence[Expr]], Sequence[Expr]]: +======= + def make_reindexer(self): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inv = {j: i for i, j in enumerate(self.dims)} inv = [inv[i] for i in range(len(self.dims))] assert OrderedSet(inv) == OrderedSet(range(len(self.dims))) +<<<<<<< HEAD def reindex( index: Sequence[Expr], ) -> Sequence[Expr]: +======= + def reindex(index): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return [index[i] for i in inv] return reindex @@ -2995,13 +3359,21 @@ def reindex( @ir_dataclass class SqueezeView(BaseView): @classmethod +<<<<<<< HEAD def create(cls, x: IRNode, *, dim: Optional[int] = None) -> IRNode: +======= + def create(cls, x, *, dim=None): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if is_storage_and_layout(x): storage, old_layout = as_storage_and_layout(x) new_size = [] new_stride = [] if dim is not None: +<<<<<<< HEAD assert isinstance(dim, int), type(dim) +======= + assert isinstance(dim, int), "expected integer dim argument" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert 0 <= dim and dim < len(old_layout.size) for i, (size, stride) in enumerate(zip(old_layout.size, old_layout.stride)): @@ -3022,7 +3394,10 @@ def create(cls, x: IRNode, *, dim: Optional[int] = None) -> IRNode: new_size, new_stride, old_layout.offset, +<<<<<<< HEAD old_layout.is_pinned, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return ReinterpretView(data=storage, layout=new_layout) @@ -3034,14 +3409,22 @@ def create(cls, x: IRNode, *, dim: Optional[int] = None) -> IRNode: return View.create(x, [s for i, s in enumerate(x.get_size()) if i != dim]) @staticmethod +<<<<<<< HEAD def squeezer( size: Sequence[Expr], ) -> tuple[list[int], Callable[[Sequence[Expr]], tuple[Expr]]]: +======= + def squeezer(size: Sequence[sympy.Expr]): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) new_size = [s for s in size if s != 1] not_one = [i for i, s in enumerate(size) if s != 1] length = len(size) +<<<<<<< HEAD def reindex(index: Sequence[Expr]) -> tuple[Expr]: +======= + def reindex(index: list[sympy.Expr]) -> tuple[sympy.Expr, ...]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert len(index) == len(not_one), f"{index} {not_one}" new_index = [sympy.S.Zero] * length for idx, s in zip(not_one, index): @@ -3050,18 +3433,29 @@ def reindex(index: Sequence[Expr]) -> tuple[Expr]: return new_size, reindex +<<<<<<< HEAD def __init__(self, data: Any) -> None: +======= + def __init__(self, data) -> None: # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise AssertionError("use SqueezeView.create()") @ir_dataclass class GenericView(BaseView): +<<<<<<< HEAD size: Sequence[Expr] reindex: Callable[[Sequence[Expr]], Sequence[Expr]] def make_reindexer( self, ) -> Callable[[Sequence[Expr]], Sequence[Expr]]: +======= + size: list[Expr] + reindex: Callable[..., Any] + + def make_reindexer(self): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.reindex def reindex_str(self) -> str: @@ -3079,12 +3473,16 @@ def __str__(self) -> str: __repr__ = __str__ @classmethod +<<<<<<< HEAD def create( cls, x: IRNode, new_size: Sequence[Expr], reindex: Callable[[Sequence[Expr]], Sequence[Expr]], ) -> BaseView: +======= + def create(cls, x, new_size, reindex): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return cls(data=x, size=list(new_size), reindex=reindex) def get_size(self) -> Sequence[Expr]: @@ -3094,7 +3492,11 @@ def get_size(self) -> Sequence[Expr]: @ir_dataclass class View(GenericView): @staticmethod +<<<<<<< HEAD def handle_negative_index(idx: Expr, size: Expr) -> Expr: +======= + def handle_negative_index(idx, size): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) idx = sympy.expand(idx) size = sympy.expand(size) evaluate_expr = V.graph.sizevars.shape_env.evaluate_expr @@ -3103,8 +3505,13 @@ def handle_negative_index(idx: Expr, size: Expr) -> Expr: return idx @classmethod +<<<<<<< HEAD def create(cls, x: IRNode, new_size: Sequence[Expr]) -> IRNode: # type: ignore[override] assert isinstance(new_size, Sequence), type(new_size) +======= + def create(cls, x, new_size): # type: ignore[no-untyped-def, override] + assert isinstance(new_size, (tuple, list)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) old_size, new_size = cls.resolve_negative_size(x.get_size(), new_size) # Skip pointless views @@ -3120,7 +3527,11 @@ def create(cls, x: IRNode, new_size: Sequence[Expr]) -> IRNode: # type: ignore[ if 0 in new_size: +<<<<<<< HEAD def fake_reindex(index: Any) -> tuple[int, ...]: +======= + def fake_reindex(index): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return tuple([0] * len(old_size)) return cls(data=x, size=list(new_size), reindex=fake_reindex) @@ -3141,7 +3552,10 @@ def fake_reindex(index: Any) -> tuple[int, ...]: new_size, FlexibleLayout.contiguous_strides(new_size), old_layout.offset, +<<<<<<< HEAD old_layout.is_pinned, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return ReinterpretView(data=storage, layout=new_layout) @@ -3149,9 +3563,13 @@ def fake_reindex(index: Any) -> tuple[int, ...]: return cls(data=x, size=list(new_size), reindex=reindex) @staticmethod +<<<<<<< HEAD def resolve_negative_size( old_size: Sequence[Expr], new_size: Sequence[Expr] ) -> tuple[list[Expr], list[Expr]]: +======= + def resolve_negative_size(old_size, new_size): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) new_size = [V.graph.sizevars.simplify(x) for x in new_size] old_size = [V.graph.sizevars.simplify(x) for x in old_size] @@ -3162,7 +3580,11 @@ def resolve_negative_size( new_size[i] = CleanDiv(sympy_product(old_size), sympy_product(new_size)) break +<<<<<<< HEAD V.graph.sizevars.check_equals(sympy_product(old_size), sympy_product(new_size)) +======= + V.graph.sizevars.guard_equals(sympy_product(old_size), sympy_product(new_size)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return old_size, new_size @classmethod @@ -3170,7 +3592,11 @@ def dynamic_reshape_indexer( cls, old_size: Sequence[_IntLike], new_size: Sequence[_IntLike], +<<<<<<< HEAD dense_dim: Optional[int] = None, +======= + dense_dim: Optional[int] = None, # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Callable[[Sequence[_T]], Sequence[_V]]: try: reindex = cls._dynamic_reshape_indexer(old_size, new_size, dense_dim) @@ -3183,11 +3609,15 @@ def dynamic_reshape_indexer( return reindex @staticmethod +<<<<<<< HEAD def _dynamic_reshape_indexer( old_size: Sequence[Expr], new_size: Sequence[Expr], dense_dim: Optional[int] = None, ) -> Callable[[Sequence[Expr]], Sequence[Expr]]: +======= + def _dynamic_reshape_indexer(old_size, new_size, dense_dim: Optional[int] = None): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Perform a reshape entirely by modifying indexing math """ @@ -3223,14 +3653,22 @@ def _dynamic_reshape_indexer( stack_old.append(size_old) # re-add elif size_hint(size_new) == size_hint(size_old): view_expr.append(var) +<<<<<<< HEAD V.graph.sizevars.check_equals(size_new, size_old) +======= + V.graph.sizevars.guard_equals(size_new, size_old) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif size_hint(size_new) < size_hint(size_old): while size_hint(size_new) < size_hint(size_old): var2, size_new2 = stack_new.pop() var = var2 * size_new + var size_new = size_new * size_new2 view_expr.append(var) +<<<<<<< HEAD V.graph.sizevars.check_equals(size_new, size_old) +======= + V.graph.sizevars.guard_equals(size_new, size_old) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif size_hint(size_new) > size_hint(size_old): divisor = sympy.S.One modulus = size_old @@ -3241,18 +3679,30 @@ def _dynamic_reshape_indexer( view_expr.append(ModularIndexing(var, divisor, modulus)) divisor = divisor * modulus size_old = size_old * modulus +<<<<<<< HEAD V.graph.sizevars.check_equals(size_new, size_old) +======= + V.graph.sizevars.guard_equals(size_new, size_old) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: raise AssertionError while stack_old: size_old = stack_old.pop() +<<<<<<< HEAD V.graph.sizevars.check_equals(size_old, 1) +======= + V.graph.sizevars.guard_equals(size_old, 1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) view_expr.append(sympy.S.Zero) while stack_new: var, size_new = stack_new.pop() +<<<<<<< HEAD V.graph.sizevars.check_equals(size_new, 1) +======= + V.graph.sizevars.guard_equals(size_new, 1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if dense_dim is not None and len(new_size) == 1: view_expr.reverse() @@ -3264,9 +3714,13 @@ def _dynamic_reshape_indexer( assert len(view_expr) == len(old_size) +<<<<<<< HEAD def reindex( index: Sequence[Expr], ) -> Sequence[Expr]: +======= + def reindex(index): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert len(index) == len(vars), (len(index), len(vars)) replacements = dict(zip(vars, index)) return tuple(sympy_subs(x, replacements) for x in view_expr) @@ -3305,13 +3759,21 @@ def get_origin_node(self) -> Optional[torch.fx.Node]: return None @property +<<<<<<< HEAD def dtype(self) -> torch.dtype: +======= + def dtype(self): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.layout.dtype def get_size(self) -> Sequence[Expr]: return list(self.layout.size) +<<<<<<< HEAD def get_stride(self) -> Sequence[Expr]: +======= + def get_stride(self): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return list(self.layout.stride) def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]: @@ -3331,7 +3793,11 @@ def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]: def get_layout(self) -> Layout: return self.layout +<<<<<<< HEAD def freeze_layout(self) -> None: +======= + def freeze_layout(self): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pass def get_free_symbol_uses( @@ -3367,7 +3833,11 @@ class DtypeView(BaseView): target_dtype: torch.dtype @classmethod +<<<<<<< HEAD def create(cls, x: IRNode, new_dtype: torch.dtype) -> BaseView: +======= + def create(cls, x, new_dtype): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if is_storage_and_layout(x): storage, old_layout = as_storage_and_layout(x) new_layout = FixedLayout( @@ -3376,7 +3846,10 @@ def create(cls, x: IRNode, new_dtype: torch.dtype) -> BaseView: old_layout.size, old_layout.stride, old_layout.offset, +<<<<<<< HEAD old_layout.is_pinned, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return ReinterpretView(data=storage, layout=new_layout) return DtypeView(data=x, target_dtype=new_dtype) @@ -3387,7 +3860,11 @@ def __str__(self) -> str: __repr__ = __str__ @property +<<<<<<< HEAD def dtype(self) -> torch.dtype: +======= + def dtype(self): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.target_dtype def get_size(self) -> Sequence[Expr]: @@ -3396,7 +3873,11 @@ def get_size(self) -> Sequence[Expr]: def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]: inner = self.data.make_loader() +<<<<<<< HEAD def loader(idx: Sequence[Expr]) -> OpsValue: +======= + def loader(idx): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ops.to_dtype_bitcast(inner(idx), self.target_dtype, self.data.dtype) return loader @@ -3404,9 +3885,13 @@ def loader(idx: Sequence[Expr]) -> OpsValue: class SliceView(View): @classmethod +<<<<<<< HEAD def normalize_start_end( cls, x: IRNode, dim: int, start: int, end: int ) -> tuple[int, int]: +======= + def normalize_start_end(cls, x, dim, start, end): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Normalize start and end such that both are in the range [0, x.get_size()[dim]] and start <= end. @@ -3421,7 +3906,11 @@ def normalize_start_end( min_func = sizevars.evaluate_min max_func = sizevars.evaluate_max +<<<<<<< HEAD def clamp(x: Expr, lower: int, upper: int) -> Expr: +======= + def clamp(x, lower, upper): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) clamped_lower = ( x if sizevars.statically_known_geq(x, lower) else max_func(x, lower) ) @@ -3432,11 +3921,16 @@ def clamp(x: Expr, lower: int, upper: int) -> Expr: ) return clamped_full +<<<<<<< HEAD def clamp_wrap( val: Union[int, None], lower: int, upper: int, default: Union[Expr, int] ) -> Union[Expr, int]: if val is None: # TODO(rec): can this really happen? +======= + def clamp_wrap(val, lower, upper, default): # type: ignore[no-untyped-def] + if val is None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return default val = cls.handle_negative_index(val, dim_size) return clamp(val, lower, upper) @@ -3446,6 +3940,7 @@ def clamp_wrap( return start, end @classmethod +<<<<<<< HEAD def create( # type: ignore[override] cls, x: IRNode, @@ -3457,6 +3952,11 @@ def create( # type: ignore[override] ) -> IRNode: step = sympy.expand(step) assert isinstance(step, Expr) or step > 0, step +======= + def create(cls, x, dim, start, end, step=1, clamp=True): # type: ignore[no-untyped-def, override] + step = sympy.expand(step) + assert isinstance(step, sympy.Expr) or step > 0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: if start == 0 and end >= 2**63 - 1 and step == 1: return x @@ -3484,6 +3984,7 @@ def create( # type: ignore[override] new_size, new_stride, old_layout.offset + old_layout.stride[dim] * start, +<<<<<<< HEAD old_layout.is_pinned, ) return ReinterpretView(data=storage, layout=new_layout) @@ -3491,6 +3992,12 @@ def create( # type: ignore[override] def reindex( index: Sequence[Expr], ) -> Sequence[Expr]: +======= + ) + return ReinterpretView(data=storage, layout=new_layout) + + def reindex(index): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert len(index) == len(new_size), f"wrong ndim {index} {new_size}" index = list(index) index[dim] = index[dim] * step + start @@ -3556,6 +4063,7 @@ def constant_to_device(self, device: torch.device) -> IRNode: def is_contiguous_strides_for_shape( stride: Sequence[_IntLike], shape: Sequence[_IntLike] ) -> bool: +<<<<<<< HEAD expected_stride = 1 expected_stride_max = 1 for x, y in reversed(tuple(zip(shape, stride))): @@ -3571,6 +4079,14 @@ def is_contiguous_strides_for_shape( expected_stride *= x return True +======= + return all( + size == 1 or left == right + for left, right, size in zip( + stride, FlexibleLayout.contiguous_strides(shape), shape + ) + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_align_for_dtype(dtype: torch.dtype) -> int: @@ -3587,6 +4103,7 @@ def get_device(self) -> Optional[torch.device]: def storage_size(self) -> int: raise NotImplementedError(type(self).__name__) +<<<<<<< HEAD def get_free_symbol_uses( self, unbacked_only: bool = False ) -> OrderedSet[sympy.Symbol]: @@ -3602,14 +4119,25 @@ class Layout(OutputSpec): whether it is pinned. """ +======= + +@ir_dataclass +class Layout(OutputSpec): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __init__( self, device: torch.device, dtype: torch.dtype, +<<<<<<< HEAD size: Sequence[Expr], stride: Optional[Sequence[Expr]] = None, offset: Expr = Integer(0), is_pinned: bool = False, +======= + size: list[Expr], + stride: Optional[list[Expr]] = None, + offset: Expr = Integer(0), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: if stride is None: stride = FlexibleLayout.contiguous_strides(size) @@ -3617,12 +4145,18 @@ def __init__( self.dtype = dtype assert len(size) == len(stride), f"size={size}, stride={stride}" assert all(isinstance(s, (Expr, int)) for s in size) +<<<<<<< HEAD self.size = size self.stride = stride self.offset = offset self.is_pinned = is_pinned # is_pinned implies cpu assert (not self.is_pinned) or (self.device.type == "cpu") +======= + self.size: list[Expr] = size + self.stride: list[Expr] = stride + self.offset: Expr = offset +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __str__(self) -> str: offset = "" @@ -3630,12 +4164,18 @@ def __str__(self) -> str: offset = f", offset={self.offset}" device_index_str = "" if self.device.index is None else f":{self.device.index}" +<<<<<<< HEAD is_pinned_str = "" if self.is_pinned: is_pinned_str = f", is_pinned={self.is_pinned}" return ( f"{type(self).__name__}('{self.device.type}{device_index_str}', {self.dtype}, " f"size={self.size}, stride={self.stride}{offset}{is_pinned_str})" +======= + return ( + f"{type(self).__name__}('{self.device.type}{device_index_str}', {self.dtype}, " + f"size={self.size}, stride={self.stride}{offset})" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) __repr__ = __str__ @@ -3650,7 +4190,10 @@ def get_example(self) -> torch.Tensor: convert_shape_to_symint(self.stride), dtype=self.dtype, device=self.device, +<<<<<<< HEAD pin_memory=self.is_pinned, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def is_contiguous(self) -> bool: @@ -3680,7 +4223,11 @@ def is_transposed(self) -> bool: return False return True +<<<<<<< HEAD def is_stride_ordered(self, order: Sequence[int]) -> bool: +======= + def is_stride_ordered(self, order) -> bool: # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert len(self.stride) == len(order) # ignore dimensions of size 1, they dont affect layout @@ -3691,9 +4238,15 @@ def is_stride_ordered(self, order: Sequence[int]) -> bool: ] stride = [self.stride[i] for i in non_1_indices] +<<<<<<< HEAD order: Sequence[int] = [order[i] for i in non_1_indices] def sorted_indices(arr: Sequence[int]) -> Sequence[int]: +======= + order = [order[i] for i in non_1_indices] + + def sorted_indices(arr): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sorted_arr = sorted(arr) return [sorted_arr.index(element) for element in arr] @@ -3715,16 +4268,24 @@ def sorted_indices(arr: Sequence[int]) -> Sequence[int]: return False return True +<<<<<<< HEAD def is_channels_last_stride_ordered(self) -> bool: +======= + def is_channels_last_stride_ordered(self): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # create channels_last order(NCHW, NCDHW, the C is the first order). order = [0] + list(reversed(range(1, len(self.stride) - 1))) order = [len(order)] + order return self.is_stride_ordered(order) @staticmethod +<<<<<<< HEAD def _pad_strides( in_strides: Sequence[int], size: Sequence[Expr], dtype: torch.dtype ) -> Sequence[int]: +======= + def _pad_strides(in_strides, size, dtype): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ The padding does not change stride order but makes sure all strides larger than the threshold are multiple of align. @@ -3744,6 +4305,7 @@ def _pad_strides( ): return in_strides +<<<<<<< HEAD shape_env = V.graph._shape_env if hasattr(V.graph, "_shape_env") else None def contains_unbacked_symints(expr: sympy.Expr | int) -> bool: @@ -3758,6 +4320,20 @@ def contains_unbacked_symints(expr: sympy.Expr | int) -> bool: return in_strides stride_order = get_stride_order(in_strides, shape_env) +======= + # get_stride_order does not work with dynamic shape. Also we can not + # statically decide if a padding is needed or how much padding we should + # do for dynamic shape. + # + # Skip padding the strides for dynamic shape for now. + if not all( + isinstance(s, (int, sympy.Integer)) + for s in itertools.chain(in_strides, size) + ): + return in_strides + + stride_order = get_stride_order(in_strides) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fill_order = stride_order2fill_order(stride_order) new_strides = [0 for _ in range(len(in_strides))] @@ -3769,6 +4345,7 @@ def contains_unbacked_symints(expr: sympy.Expr | int) -> bool: for rank, idx in enumerate(fill_order[1:], start=1): prev_idx = fill_order[rank - 1] stride = new_strides[prev_idx] * size[prev_idx] +<<<<<<< HEAD # Static stride and meets padding conditions OR # Dynamic stride and config.pad_dynamic_shape=True require_padding = ( @@ -3780,6 +4357,13 @@ def contains_unbacked_symints(expr: sympy.Expr | int) -> bool: if require_padding: new_strides[idx] = ceildiv(stride, align) * align padded = True +======= + + if stride > config.padding_stride_threshold and stride % align != 0: + stride = ceildiv(stride, align) * align + padded = True + new_strides[idx] = stride +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not padded: # Consider a tensor with shape [256, 1, 5, 5] @@ -3790,6 +4374,7 @@ def contains_unbacked_symints(expr: sympy.Expr | int) -> bool: metrics.num_comprehensive_padding += 1 return new_strides +<<<<<<< HEAD def pad_strides(self) -> None: assert isinstance(self, FlexibleLayout), type(self) assert self.stride is not None @@ -3799,6 +4384,17 @@ def should_pad_strides(self) -> bool: return config.comprehensive_padding and isinstance(self, FlexibleLayout) def as_fixed(self) -> FixedLayout: +======= + def pad_strides(self): # type: ignore[no-untyped-def] + assert isinstance(self, FlexibleLayout) + assert self.stride is not None + self.stride = self._pad_strides(self.stride, self.size, self.dtype) + + def should_pad_strides(self): # type: ignore[no-untyped-def] + return config.comprehensive_padding and isinstance(self, FlexibleLayout) + + def as_fixed(self): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(self, FixedLayout): return self @@ -3810,7 +4406,10 @@ def as_fixed(self) -> FixedLayout: self.size, self.stride, self.offset, +<<<<<<< HEAD self.is_pinned, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]: @@ -3819,14 +4418,21 @@ def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]: ) return self.as_fixed().make_indexer() +<<<<<<< HEAD def __eq__(self, other: object) -> bool: return ( isinstance(other, Layout) and self.device == other.device +======= + def __eq__(self, other) -> bool: # type: ignore[no-untyped-def] + return ( + self.device == other.device +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) and self.dtype == other.dtype and self.size == other.size and self.stride == other.stride and self.offset == other.offset +<<<<<<< HEAD and self.is_pinned == other.is_pinned ) @@ -3841,6 +4447,12 @@ def get_free_symbol_uses( | get_free_symbols(self.stride, unbacked_only) | get_free_symbols(self.offset, unbacked_only) ) +======= + ) + + def storage_size(self) -> sympy.Expr: + return compute_required_storage_length(self.size, self.stride, self.offset) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class FixedLayout(Layout): @@ -3848,17 +4460,39 @@ class FixedLayout(Layout): def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]: """A closure containing math to read a given element""" +<<<<<<< HEAD return _fixed_indexer(self.size, self.stride, self.offset) class FlexibleLayout(Layout): """A Tensor layout that we are allowed to change""" +======= + + def indexer(index): # type: ignore[no-untyped-def] + assert len(index) == len(self.stride) + assert len(index) == len(self.size) + result = self.offset + for idx, stride, sz in zip(index, self.stride, self.size): + if sz != 1: + result = result + idx * stride + return result + + return indexer + + +class FlexibleLayout(Layout): + """A Tensor layout we are allowed to change""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) allow_indexing = False # WARNING! This doesn't handle zero size tensors correctly @staticmethod +<<<<<<< HEAD def contiguous_strides(sizes: Sequence[int]) -> list[Expr]: +======= + def contiguous_strides(sizes): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if len(sizes) == 0: return [] reversed_strides = [sympy.S.One] @@ -3867,7 +4501,11 @@ def contiguous_strides(sizes: Sequence[int]) -> list[Expr]: return list(reversed(reversed_strides)) @staticmethod +<<<<<<< HEAD def fill_ordered(sizes: Sequence[int], order: Sequence[int]) -> list[Expr]: +======= + def fill_ordered(sizes, order): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Create a stride based on the order the dimensions should be filled in. @@ -3884,7 +4522,11 @@ def fill_ordered(sizes: Sequence[int], order: Sequence[int]) -> list[Expr]: return strides @staticmethod +<<<<<<< HEAD def stride_ordered(sizes: Sequence[int], order: Sequence[int]) -> Sequence[Expr]: +======= + def stride_ordered(sizes, order): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Create a stride based on the sorted order of a permuted range. @@ -3896,9 +4538,13 @@ def stride_ordered(sizes: Sequence[int], order: Sequence[int]) -> Sequence[Expr] return FlexibleLayout.fill_ordered(sizes, fill_order) @staticmethod +<<<<<<< HEAD def stride_ordered_for_memory_format( sizes: Sequence[int], memory_format: torch.memory_format ) -> Sequence[Expr]: +======= + def stride_ordered_for_memory_format(sizes, memory_format): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Create a stride based on a memory format. @@ -3923,9 +4569,13 @@ def stride_ordered_for_memory_format( raise NotImplementedError @staticmethod +<<<<<<< HEAD def same_ordered( sizes: Sequence[int], stride: Sequence[_IntLike] ) -> Sequence[Expr]: +======= + def same_ordered(sizes, stride): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Create a stride that has the same stride order as given stride @@ -3937,9 +4587,13 @@ def same_ordered( fill_order = sorted(range(len(stride)), key=stride.__getitem__) return FlexibleLayout.fill_ordered(sizes, fill_order) +<<<<<<< HEAD def as_stride_order( self, order: Sequence[int], allow_padding: bool = False ) -> FixedLayout: +======= + def as_stride_order(self, order, allow_padding=False): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) new_stride = self.stride_ordered(self.size, order) if self.should_pad_strides() and allow_padding: new_stride = self._pad_strides(new_stride, self.size, self.dtype) @@ -3950,12 +4604,18 @@ def as_stride_order( self.size, new_stride, self.offset, +<<<<<<< HEAD self.is_pinned, ) def as_exact_strides( self, exact_strides: Sequence[_IntLike], allow_padding: bool = False ) -> FixedLayout: +======= + ) + + def as_exact_strides(self, exact_strides, allow_padding=False): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) new_stride = exact_strides if self.should_pad_strides() and allow_padding: new_stride = self._pad_strides(new_stride, self.size, self.dtype) @@ -3966,11 +4626,18 @@ def as_exact_strides( self.size, new_stride, self.offset, +<<<<<<< HEAD self.is_pinned, ) def as_fill_order(self, order: Sequence[int]) -> FixedLayout: new_stride: Sequence[int] = self.fill_ordered(self.size, order) +======= + ) + + def as_fill_order(self, order): # type: ignore[no-untyped-def] + new_stride = self.fill_ordered(self.size, order) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.should_pad_strides(): new_stride = self._pad_strides(new_stride, self.size, self.dtype) return FixedLayout( @@ -3979,10 +4646,16 @@ def as_fill_order(self, order: Sequence[int]) -> FixedLayout: self.size, new_stride, self.offset, +<<<<<<< HEAD self.is_pinned, ) def as_same_order(self, stride: Sequence[_IntLike]) -> FixedLayout: +======= + ) + + def as_same_order(self, stride): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) new_stride = self.same_ordered(self.size, stride) if self.should_pad_strides(): new_stride = self._pad_strides(new_stride, self.size, self.dtype) @@ -3992,6 +4665,7 @@ def as_same_order(self, stride: Sequence[_IntLike]) -> FixedLayout: self.size, new_stride, self.offset, +<<<<<<< HEAD self.is_pinned, ) @@ -4003,11 +4677,20 @@ def __init__( stride_order: Optional[Sequence[Union[int, Integer]]] = None, is_pinned: bool = False, ) -> None: +======= + ) + + def __init__(self, device, dtype, size, stride_order=None) -> None: # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if stride_order: strides = FlexibleLayout.fill_ordered(size, stride_order) else: strides = FlexibleLayout.contiguous_strides(size) +<<<<<<< HEAD super().__init__(device, dtype, size, strides, is_pinned=is_pinned) +======= + super().__init__(device, dtype, size, strides) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class NonOwningLayout(Layout): @@ -4026,7 +4709,11 @@ def __init__(self, view: Union[BaseView, TensorBox]) -> None: def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]: return self.as_fixed().make_indexer() +<<<<<<< HEAD def maybe_guard_aligned(self) -> bool: +======= + def maybe_guard_aligned(self): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) offset = self.view.get_layout().offset if offset == 0: return True @@ -4034,6 +4721,7 @@ def maybe_guard_aligned(self) -> bool: return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT) +<<<<<<< HEAD def get_free_symbol_uses( self, unbacked_only: bool = False ) -> OrderedSet[sympy.Symbol]: @@ -4044,6 +4732,8 @@ def get_free_symbol_uses( assert isinstance(input_buffer, Buffer), type(box) return input_buffer.layout.get_free_symbol_uses(unbacked_only) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class CommBufferType(Enum): SYMM_MEM = "symm_mem" @@ -4083,7 +4773,10 @@ def __init__( size=fixed.size, stride=fixed.stride, offset=fixed.offset, +<<<<<<< HEAD is_pinned=fixed.is_pinned, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) self.comm_buffer_type = comm_buffer_type self.group_name = group_name @@ -4106,7 +4799,11 @@ class NoneLayout(OutputSpec): def storage_size(self) -> int: return 0 +<<<<<<< HEAD def as_fixed(self) -> OutputSpec: +======= + def as_fixed(self): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self def get_device(self) -> Optional[torch.device]: @@ -4118,7 +4815,11 @@ def __init__(self, target: IRNode) -> None: super().__init__( target.get_device_or_error(), target.get_dtype(), +<<<<<<< HEAD target.get_size(), +======= + target.get_size(), # type: ignore[arg-type] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) None, ) self.target = target @@ -4126,18 +4827,30 @@ def __init__(self, target: IRNode) -> None: V.graph.mark_buffer_mutated(name) @property +<<<<<<< HEAD def stride(self) -> Sequence[Expr]: # type: ignore[override] +======= + def stride(self) -> list[Expr]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.real_layout().stride @stride.setter # type: ignore[override] def stride(self, value: Never) -> None: pass # ignore setting of stride +<<<<<<< HEAD def storage_size(self) -> Expr: return self.real_layout().storage_size() def get_buffer(self) -> Buffer: def unwrap_views(target: Any) -> Any: +======= + def storage_size(self) -> sympy.Expr: + return self.real_layout().storage_size() + + def get_buffer(self) -> Buffer: + def unwrap_views(target): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(target, MutationLayoutSHOULDREMOVE): return unwrap_views(target.target) if isinstance(target, BaseView): @@ -4147,6 +4860,7 @@ def unwrap_views(target: Any) -> Any: return target result = unwrap_views(self.target) +<<<<<<< HEAD assert isinstance(result, Buffer), type(result) return result @@ -4159,6 +4873,18 @@ def real_layout(self) -> Layout: def realize_into( cls, src: IRNode, dst: IRNode, unsafe_alias: bool = False ) -> IRNode: +======= + assert isinstance(result, Buffer), ( + "MutationLayoutSHOULDREMOVE must refer to a buffer" + ) + return result + + def real_layout(self): # type: ignore[no-untyped-def] + return self.get_buffer().layout + + @classmethod + def realize_into(cls, src, dst, unsafe_alias=False): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dst.realize() # NOTE: We must realize users of `dst` before we realize `src`, since # realization order determines scheduling order. Otherwise, src's @@ -4177,11 +4903,16 @@ def realize_into( src.realize_hint() if not unsafe_alias: +<<<<<<< HEAD node = Pointwise.create( +======= + src = Pointwise.create( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device=src.get_device(), dtype=src.get_dtype(), inner_fn=src.make_loader(), ranges=[ +<<<<<<< HEAD V.graph.sizevars.check_equals_and_simplify(a, b) for a, b in zip(src.get_size(), dst.get_size()) ], @@ -4196,6 +4927,19 @@ def realize_into( return src.data def as_fixed(self) -> Self: # type: ignore[override] +======= + V.graph.sizevars.guard_equals(a, b) + for a, b in zip(src.get_size(), dst.get_size()) + ], + ).data + + src.realize() + assert isinstance(src.data.layout, FlexibleLayout) + src.data.layout = MutationLayoutSHOULDREMOVE(dst) + return src.data + + def as_fixed(self): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]: @@ -4255,6 +4999,7 @@ def get_layout(self) -> Layout: def get_output_spec(self) -> OutputSpec: return self.layout +<<<<<<< HEAD def get_storage_numel(self) -> int: return self.get_numel() @@ -4262,11 +5007,18 @@ def get_is_pinned(self) -> bool: return self.get_layout().is_pinned def freeze_layout(self) -> None: +======= + def get_storage_numel(self): # type: ignore[no-untyped-def] + return self.get_numel() + + def freeze_layout(self): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(self.layout, Layout) and not isinstance( self.layout, NonOwningLayout ): self.layout = self.layout.as_fixed() +<<<<<<< HEAD def freeze_layout_with_stride_order( self, order: Sequence[int], allow_padding: bool = False ) -> None: @@ -4285,11 +5037,33 @@ def freeze_layout_with_exact_strides( self, exact_strides: Sequence[int], allow_padding: bool = False ) -> None: assert isinstance(self.layout, FlexibleLayout), type(self.layout) +======= + def freeze_layout_with_stride_order(self, order, allow_padding=False) -> None: # type: ignore[no-untyped-def] + assert isinstance(self.layout, FlexibleLayout) + self.layout = self.layout.as_stride_order(order, allow_padding=allow_padding) + + def freeze_layout_with_fill_order(self, order) -> None: # type: ignore[no-untyped-def] + assert isinstance(self.layout, FlexibleLayout) + self.layout = self.layout.as_fill_order(order) + + def freeze_layout_with_same_order(self, stride) -> None: # type: ignore[no-untyped-def] + assert isinstance(self.layout, FlexibleLayout) + self.layout = self.layout.as_same_order(stride) + + def freeze_layout_with_exact_strides( # type: ignore[no-untyped-def] + self, exact_strides, allow_padding=False + ) -> None: + assert isinstance(self.layout, FlexibleLayout) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.layout = self.layout.as_exact_strides( exact_strides, allow_padding=allow_padding ) +<<<<<<< HEAD def is_zero_elements(self) -> bool: +======= + def is_zero_elements(self): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return V.graph.sizevars.statically_known_true(sympy.Eq(self.get_numel(), 0)) def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]: @@ -4297,7 +5071,11 @@ def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]: if self.is_zero_elements(): return partial(nop_loader_fn, dtype=self.get_dtype()) +<<<<<<< HEAD def loader(index: Sequence[Expr]) -> OpsValue: +======= + def loader(index): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) indexer = self.make_indexer() return ops.load(self.name or "unnamed", indexer(index)) @@ -4306,7 +5084,11 @@ def loader(index: Sequence[Expr]) -> OpsValue: def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str: return self.get_name() +<<<<<<< HEAD def decide_layout(self) -> None: +======= + def decide_layout(self): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pass def get_inputs_that_alias_output(self) -> Sequence[str]: @@ -4427,6 +5209,7 @@ def has_tensor_output(self) -> bool: @ir_dataclass(frozen=False) class ComputedBuffer(OperationBuffer): +<<<<<<< HEAD """ Represents a buffer that is computed during kernel execution rather than being an input. """ @@ -4443,6 +5226,9 @@ def force_realize() -> Iterator[None]: yield finally: ComputedBuffer._force_realize = old_value +======= + data: Loops +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_computed_buffer_name(self) -> Optional[str]: """ @@ -4465,6 +5251,7 @@ def get_read_names(self) -> OrderedSet[str]: return self.data.get_read_names() def get_read_writes(self) -> dependencies.ReadWrites: +<<<<<<< HEAD if not isinstance(self.data, (Reduction, Scan, Sort, Pointwise)): return dependencies.ReadWrites( reads=OrderedSet(), @@ -4472,17 +5259,28 @@ def get_read_writes(self) -> dependencies.ReadWrites: index_exprs=OrderedSet(), ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with patch.object(FlexibleLayout, "allow_indexing", True): if self.data.get_reduction_type(): return extract_read_writes( self.get_store_function(), +<<<<<<< HEAD self.data.get_pointwise_size(), self.data.get_reduction_size(), +======= + self.data.get_pointwise_size(), # type: ignore[arg-type] + self.data.get_reduction_size(), # type: ignore[arg-type] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) else: return extract_read_writes( self.get_store_function(), +<<<<<<< HEAD self.data.get_size(), +======= + self.data.get_size(), # type: ignore[arg-type] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def get_free_symbol_uses( @@ -4497,11 +5295,18 @@ def get_free_symbol_uses( # those symbols that establishes a dependency). However, we haven't # started codegen yet so we can't directly reuse that logic. # +<<<<<<< HEAD +======= + # For now, I'm just yoloing with the size of the buffer. Not sure if + # it is enough. + # +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # One thing you might wonder is if this is enough for a ComputedBuffer # denoting a reduction over i0. Empirically, it is enough, but for an # unusual reason: we only need accurate dependencies for item() call, # but it's impossible to end up with a reduction over i0 from an # item() call without a regular non-reduction buffer first. +<<<<<<< HEAD result = self.layout.get_free_symbol_uses( unbacked_only ) | self.data.get_free_symbol_uses(unbacked_only) @@ -4511,27 +5316,45 @@ def get_free_symbol_uses( ): result |= self.get_read_writes().get_free_symbol_uses(unbacked_only) return result +======= + return ( + get_free_symbols(self.get_size(), unbacked_only) + | get_free_symbols(self.get_stride(), unbacked_only) + | get_free_symbols(self.get_offset(), unbacked_only) + | self.data.get_free_symbol_uses(unbacked_only) + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]: if ( not self.get_reduction_type() and self.name not in V.graph.mutated_buffers and self.num_reads() == 0 +<<<<<<< HEAD and not self._force_realize +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): # inline this op rather than generating ops.load() return self.data.make_loader() return super().make_loader() +<<<<<<< HEAD def has_store_function(self) -> bool: return isinstance(self.data, (Reduction, Scan, Sort, Pointwise)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_store_function(self) -> Callable[..., None]: indexer = self.get_layout().as_fixed().make_indexer() if isinstance(self.data, (Reduction, Scan, Sort)): return partial(self.data.store_reduction, self.name, indexer) else: +<<<<<<< HEAD assert isinstance(self.data, Pointwise), type(self.data) +======= + assert isinstance(self.data, Pointwise) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return partial(self.data.store_output, self.name, indexer) def get_fill_order(self) -> Optional[list[int]]: @@ -4585,9 +5408,15 @@ def decide_layout(self) -> None: def get_default_sizes_body( self, ) -> tuple[ +<<<<<<< HEAD tuple[list[Expr], list[Expr]], LoopBody, tuple[list[Expr], list[Expr]], +======= + tuple[list[sympy.Expr], list[sympy.Expr]], + LoopBody, + tuple[list[sympy.Expr], list[sympy.Expr]], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ]: args, var_ranges = dependencies.index_vars_squeeze( self.data.get_pointwise_size(), self.data.get_reduction_size(), prefix="q" @@ -4618,7 +5447,11 @@ def simplify_and_reorder( self, extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None, recompute_sizes_body_func: Optional[Callable[..., Any]] = None, +<<<<<<< HEAD ) -> tuple[tuple[list[Expr], list[Expr]], Optional[LoopBody]]: +======= + ) -> tuple[tuple[list[sympy.Expr], list[sympy.Expr]], LoopBody]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ This is a main place where we do loop transformations in a backend-agnostic way. @@ -4658,8 +5491,13 @@ def simplify_and_reorder( and len(extra_indexing_constraints) == 2 ) extra_indexing_ranges, extra_indexing_expr = extra_indexing_constraints +<<<<<<< HEAD assert isinstance(extra_indexing_ranges, dict), type(extra_indexing_ranges) assert isinstance(extra_indexing_expr, list), type(extra_indexing_expr) +======= + assert isinstance(extra_indexing_ranges, dict) + assert isinstance(extra_indexing_expr, list) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert all(isinstance(f, Expr) for f in extra_indexing_expr) expected_var_ranges = body.var_ranges @@ -4677,6 +5515,7 @@ def simplify_and_reorder( if not V.graph.has_feature(self, BackendFeature.PREFER_STORE_LOOP_ORDER): memory_addrs.extend(body.get_read_exprs()) +<<<<<<< HEAD def simplify_and_reorder( x_vars: Sequence[sympy.Symbol], support_vars: Sequence[sympy.Symbol], @@ -4687,6 +5526,9 @@ def simplify_and_reorder( Callable[[Sequence[int]], Sequence[int]], Callable[[Sequence[int]], Sequence[int]], ]: +======= + def simplify_and_reorder(x_vars, support_vars, sizes, simplify_loops): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sizes, reindex0, reindex1 = self._apply_loop_reordering( x_vars, support_vars, sizes, memory_addrs ) @@ -4739,6 +5581,7 @@ def simplify_and_reorder( return (iter_ranges, reduce_ranges), body @staticmethod +<<<<<<< HEAD def _apply_loop_reordering( index_vars: Sequence[sympy.Symbol], support_vars: Sequence[sympy.Symbol], @@ -4750,6 +5593,15 @@ def _apply_loop_reordering( Callable[[Sequence[int]], Sequence[int]], Callable[[Sequence[int]], Sequence[int]], ]: +======= + def _apply_loop_reordering( # type: ignore[no-untyped-def] + index_vars, + support_vars, + sizes, + memory_addrs, + priority_idx=None, + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Shuffle the order of loops around to hopefully improve performance. """ @@ -4778,7 +5630,11 @@ def _apply_loop_reordering( sizes = [sizes[i] for i in order] return sizes, same_reorder(order), inverse_reorder(order) +<<<<<<< HEAD def get_reduction_size(self) -> Sequence[Expr]: +======= + def get_reduction_size(self) -> Sequence[sympy.Expr]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.data.get_reduction_size() def get_reduction_type(self) -> Optional[str]: @@ -4803,9 +5659,15 @@ class TemplateBuffer(OperationBuffer): def __init__( self, +<<<<<<< HEAD layout: OutputSpec, inputs: Sequence[IRNode], make_kernel_render: Optional[Callable[..., Any]], +======= + layout: Layout, + inputs: Sequence[IRNode], + make_kernel_render: Callable[..., Any], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: super().__init__(name=None, layout=layout) self.inputs = InputsKernel.unwrap_storage(inputs) @@ -4816,11 +5678,19 @@ def __init__( def get_read_writes(self) -> dependencies.ReadWrites: return self.extract_read_writes(normalize=True) +<<<<<<< HEAD def extract_read_writes(self, normalize: bool = False) -> dependencies.ReadWrites: name = self.get_name() indexer = self.get_layout().make_indexer() def dummy(index: Sequence[Any], rindex: Sequence[Any]) -> Any: +======= + def extract_read_writes(self, normalize): # type: ignore[no-untyped-def] + name = self.get_name() + indexer = self.get_layout().make_indexer() + + def dummy(index, rindex): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert len(rindex) == 0 return ops.store(name, indexer(index), "fake") @@ -4829,6 +5699,7 @@ def dummy(index: Sequence[Any], rindex: Sequence[Any]) -> Any: ) for inp in self.inputs: +<<<<<<< HEAD assert isinstance(inp, (ReinterpretView, Buffer)), type(inp) assert isinstance(inp.layout, Layout), type(inp.layout) @@ -4840,11 +5711,25 @@ def dummy(index: Sequence[Any], rindex: Sequence[Any]) -> Any: deps.reads |= dependencies.extract_read_writes( dummy, inp.get_size(), (), normalize=normalize +======= + indexer = inp.layout.make_indexer() + + def dummy(index, rindex): # type: ignore[no-untyped-def] + assert len(rindex) == 0 + ops.load(inp.get_name(), indexer(index)) + + deps.reads |= dependencies.extract_read_writes( + dummy, inp.get_size(), (), normalize=True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ).reads return deps +<<<<<<< HEAD def get_reduction_size(self) -> Sequence[Expr]: +======= + def get_reduction_size(self) -> Sequence[sympy.Expr]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return sympy.S.One def get_reduction_type(self) -> Optional[str]: @@ -4853,6 +5738,7 @@ def get_reduction_type(self) -> Optional[str]: def should_allocate(self) -> bool: return True +<<<<<<< HEAD def simplify_and_reorder( self, extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None, @@ -4862,17 +5748,36 @@ def simplify_and_reorder( ( self.get_size(), [], +======= + def simplify_and_reorder( # type: ignore[no-untyped-def] + self, + extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None, + recompute_sizes_body_func: Optional[Callable[..., Any]] = None, + ): + return ( + ( + self.get_size(), + (), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), None, ) class TritonTemplateBuffer(TemplateBuffer): +<<<<<<< HEAD def __init__( self, layout: Layout, inputs: Sequence[IRNode], make_kernel_render: Optional[Callable[_P, _T]], +======= + def __init__( # type: ignore[no-untyped-def] + self, + layout, + inputs, + make_kernel_render, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mutated_inputs: Optional[Iterable[IRNode]] = None, allowed_prologue_inps: Optional[OrderedSet[str]] = None, ) -> None: @@ -4898,7 +5803,10 @@ def __init__( assert current_node in allowed_set, ( f"Mutated inputs are only allowed for {allowed_set} but got {current_node}" ) +<<<<<<< HEAD assert isinstance(self.inputs[0], IRNode), type(self.inputs[0]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device = self.inputs[0].get_device() self.outputs += [ MutationOutput(NoneLayout(device=device), buf, self) @@ -4973,6 +5881,7 @@ def __init__( # knowing what autotuning is choosing) self.description = description +<<<<<<< HEAD def benchmark(self, *args: Any, out: torch.Tensor) -> float: algo = self.to_callable() benchmark_configs = { @@ -4982,11 +5891,22 @@ def benchmark(self, *args: Any, out: torch.Tensor) -> float: if config.profile_bandwidth_with_do_bench_using_profiling: return do_bench_using_profiling(lambda: algo(*args), **benchmark_configs) return benchmarker.benchmark(algo, args, {"out": out}, **benchmark_configs) +======= + def benchmark(self, *args, out) -> float: # type: ignore[no-untyped-def] + algo = self.to_callable() + if config.profile_bandwidth_with_do_bench_using_profiling: + return do_bench_using_profiling(lambda: algo(*args)) + return benchmarker.benchmark(algo, args, {"out": out}) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def call_name(self) -> str: raise NotImplementedError +<<<<<<< HEAD def to_callable(self) -> Callable[..., Any]: +======= + def to_callable(self): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise NotImplementedError def kernel_hash_key(self) -> str: @@ -4999,7 +5919,11 @@ def kernel_hash_key(self) -> str: def hash_key(self) -> str: raise NotImplementedError +<<<<<<< HEAD def output_node(self) -> Union[TensorBox, ShapeAsConstantBuffer]: +======= + def output_node(self) -> TensorBox: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise NotImplementedError def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]: @@ -5027,8 +5951,13 @@ class MultiTemplateBuffer(TritonTemplateBuffer): def __init__( self, layout: Layout, +<<<<<<< HEAD inputs: Sequence[IRNode], choice_timings_fn: Callable[[Optional[int]], dict[ChoiceCaller, float]], +======= + inputs: list[IRNode], + choice_timings_fn: Callable[[], dict[ChoiceCaller, float]], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) unfiltered_choices: list[ChoiceCaller], allowed_prologue_inps: OrderedSet[str], ) -> None: @@ -5039,7 +5968,11 @@ def __init__( allowed_prologue_inps=allowed_prologue_inps, ) self._choice_timings_fn = choice_timings_fn +<<<<<<< HEAD self._choice_timings: dict[Optional[int], dict[ChoiceCaller, float]] = {} +======= + self._choice_timings: Optional[dict[ChoiceCaller, float]] = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.original_inputs = inputs self._output_plannable = all( isinstance(choice, TritonTemplateCallerBase) @@ -5049,7 +5982,10 @@ def __init__( ) for choice in unfiltered_choices ) +<<<<<<< HEAD self._make_kernel_renders: dict[Optional[int], Any] = {} +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @property def output_plannable(self) -> bool: @@ -5058,6 +5994,7 @@ def output_plannable(self) -> bool: """ return self._output_plannable +<<<<<<< HEAD def choice_timings( self, hint_override: Optional[int] = None ) -> dict[ChoiceCaller, float]: @@ -5070,6 +6007,17 @@ def swap_as_triton_caller(self, caller: TritonTemplateCallerBase) -> Iterator[No assert isinstance( caller, torch._inductor.select_algorithm.TritonTemplateCaller ), type(caller) +======= + @property + def choice_timings(self) -> dict[ChoiceCaller, float]: + if self._choice_timings is None: + self._choice_timings = self._choice_timings_fn() + return self._choice_timings + + @contextlib.contextmanager + def swap_as_triton_caller(self, caller: TritonTemplateCallerBase): # type: ignore[no-untyped-def] + assert isinstance(caller, torch._inductor.select_algorithm.TritonTemplateCaller) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert self.layout == caller.layout render = self.make_kernel_render @@ -5080,13 +6028,18 @@ def swap_as_triton_caller(self, caller: TritonTemplateCallerBase) -> Iterator[No self.make_kernel_render = render def finalize_as_triton_caller(self, caller: TritonTemplateCallerBase) -> None: +<<<<<<< HEAD assert isinstance( caller, torch._inductor.select_algorithm.TritonTemplateCaller ), type(caller) +======= + assert isinstance(caller, torch._inductor.select_algorithm.TritonTemplateCaller) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert self.get_size() == caller.layout.size assert self.get_stride() == caller.layout.stride self.make_kernel_render = caller.get_make_kernel_render() +<<<<<<< HEAD def get_min_choice( self, hint_override: Optional[int] = None ) -> tuple[ChoiceCaller, float]: @@ -5111,6 +6064,19 @@ def __init__( layout: Layout, inputs: Sequence[IRNode], make_kernel_render: Callable[_P, _T], +======= + def get_min_choice(self) -> tuple[ChoiceCaller, float]: + min_choice = min(self.choice_timings, key=self.choice_timings.get) # type: ignore[arg-type] + return (min_choice, self.choice_timings[min_choice]) + + +class CUDATemplateBuffer(TemplateBuffer): + def __init__( # type: ignore[no-untyped-def] + self, + layout, + inputs, + make_kernel_render, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) workspace_size: int, template: CUDATemplate, supports_epilogue_fusion: bool, @@ -5121,7 +6087,11 @@ def __init__( self.template = template self.supports_epilogue_fusion = supports_epilogue_fusion +<<<<<<< HEAD def get_workspace_size(self) -> int: +======= + def get_workspace_size(self): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.workspace_size if self.workspace_size is not None else 0 def emulate_store_fn(self) -> None: @@ -5130,6 +6100,7 @@ def emulate_store_fn(self) -> None: class CppTemplateBuffer(TemplateBuffer): +<<<<<<< HEAD def __init__( self, layout: Layout, @@ -5138,6 +6109,9 @@ def __init__( template: CUDATemplate, choice: Any, ) -> None: +======= + def __init__(self, layout, inputs, make_kernel_render, template, choice) -> None: # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__(layout, inputs, make_kernel_render) self.template = template self.choice = choice @@ -5145,16 +6119,25 @@ def __init__( def get_layout(self) -> Layout: if isinstance(self.layout, MultiOutputLayout): +<<<<<<< HEAD assert isinstance(self.outputs, Iterable), type(self.outputs) first_output = self.outputs[0] assert isinstance(first_output, Buffer), type(first_output) layout = first_output.layout assert isinstance(layout, Layout), type(layout) +======= + assert isinstance(self.outputs, Iterable) + first_output = self.outputs[0] + assert isinstance(first_output, Buffer) + layout = first_output.layout + assert isinstance(layout, Layout) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return layout else: return super().get_layout() +<<<<<<< HEAD class CuteDSLTemplateBuffer(TemplateBuffer): """ Buffer for CuteDSL (CUTLASS Python DSL) template kernels. @@ -5200,12 +6183,21 @@ def input_name(self, i: int) -> str: input = self.inputs[i] assert isinstance(input, IRNode) return input.get_name() +======= +@ir_dataclass(frozen=False) +class InputsKernel(OperationBuffer): + inputs: list[Buffer] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_read_writes(self) -> dependencies.ReadWrites: reads = OrderedSet[dependencies.Dep]() StarDep = dependencies.StarDep for input in self.inputs: +<<<<<<< HEAD if isinstance(input, Sequence): +======= + if isinstance(input, list): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) reads.update(StarDep(x.get_name()) for x in input) elif isinstance(input, ShapeAsConstantBuffer): # Skip creating dependency for symbolics as they're visible globally @@ -5242,6 +6234,7 @@ def unwrap_storage_for_input(cls, x: IRNode) -> IRNode: return cls.unwrap_storage_for_input(x) if isinstance(x, TorchBindObject): return x +<<<<<<< HEAD assert isinstance(x, (Buffer, ReinterpretView)), type(x) return x @@ -5252,6 +6245,16 @@ def unwrap_storage( inputs_new: list[Union[IRNode, Sequence[IRNode]]] = [] for x in inputs: if isinstance(x, Sequence): +======= + assert isinstance(x, (Buffer, ReinterpretView)), x + return x + + @staticmethod + def unwrap_storage(inputs): # type: ignore[no-untyped-def] + inputs_new = [] + for x in inputs: + if isinstance(x, list): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x = [InputsKernel.unwrap_storage_for_input(i) for i in x] else: x = InputsKernel.unwrap_storage_for_input(x) @@ -5264,6 +6267,7 @@ def is_extern(self) -> bool: def num_reads(self) -> int: return 1 +<<<<<<< HEAD def get_free_symbol_uses( self, unbacked_only: bool = False ) -> OrderedSet[sympy.Symbol]: @@ -5276,6 +6280,8 @@ def get_free_symbol_uses( r |= inner_inp.get_free_symbol_uses(unbacked_only) return r +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class NopKernel(InputsKernel): def is_no_op(self) -> bool: @@ -5292,10 +6298,14 @@ class ConcatKernel(NopKernel): """ @classmethod +<<<<<<< HEAD def create(cls, inputs: Sequence[IRNode], dim: int) -> StorageBox: """ Create the concat kernel from inputs """ +======= + def create(cls, inputs, dim): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device = inputs[0].get_device() dtype = inputs[0].get_dtype() new_size = list(inputs[0].get_size()) @@ -5312,12 +6322,20 @@ def create(cls, inputs: Sequence[IRNode], dim: int) -> StorageBox: if j == dim: new_size[j] = new_size[j] + input_size[j] else: +<<<<<<< HEAD new_size[j] = V.graph.sizevars.check_equals_and_simplify( +======= + new_size[j] = V.graph.sizevars.guard_equals( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) new_size[j], input_size[j] ) offsets_end.append(new_size[dim]) +<<<<<<< HEAD output_stride: Sequence[int] = FlexibleLayout.contiguous_strides(new_size) +======= + output_stride = FlexibleLayout.contiguous_strides(new_size) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if config.comprehensive_padding: # Ensure the output stride matches the alignment requirements output_stride = Layout._pad_strides( @@ -5337,7 +6355,11 @@ def create(cls, inputs: Sequence[IRNode], dim: int) -> StorageBox: break any_input_is_storage_and_layout = any(is_storage_and_layout(x) for x in inputs) fx_node_args = V.graph.current_node.args[0] +<<<<<<< HEAD assert isinstance(fx_node_args, list), type(fx_node_args) +======= + assert isinstance(fx_node_args, list) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # If any of the inputs has meta tensor and the meta tensor is in CL format, use CL format for the output if any_input_is_storage_and_layout is False and any( "val" in arg.meta @@ -5349,11 +6371,14 @@ def create(cls, inputs: Sequence[IRNode], dim: int) -> StorageBox: ): output_stride = make_channels_last_strides_for(new_size) +<<<<<<< HEAD is_pinned = all( is_storage_and_layout(x) and x.get_layout().is_pinned for x in inputs ) assert device is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) concat_kernel = ConcatKernel( name=None, layout=FixedLayout( @@ -5361,20 +6386,30 @@ def create(cls, inputs: Sequence[IRNode], dim: int) -> StorageBox: dtype=dtype, size=new_size, stride=output_stride, +<<<<<<< HEAD is_pinned=is_pinned, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), inputs=[], ) kernel = StorageBox(concat_kernel) op_names = [] +<<<<<<< HEAD for i, inp in enumerate(inputs): assert isinstance(inp, (BaseView, MutableBox)), type(inp) input_buffer = cls.realize_into( inp, +======= + for i in range(len(inputs)): + input_buffer = cls.realize_into( + inputs[i], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SliceView.create( kernel, dim, offsets_start[i], offsets_end[i], clamp=False ), ) +<<<<<<< HEAD assert isinstance(input_buffer, Buffer), type(input_buffer) assert isinstance(concat_kernel.inputs, list), type(concat_kernel.inputs) concat_kernel.inputs.append(input_buffer) @@ -5389,6 +6424,18 @@ def create(cls, inputs: Sequence[IRNode], dim: int) -> StorageBox: and input_unwrapped.is_input_buffer() and (dev := inp.get_device()) is not None and is_gpu(dev.type) +======= + concat_kernel.inputs.append(input_buffer) + + if isinstance(inputs[i].data, BaseView): + input_unwrapped = inputs[i].data.unwrap_view() + else: + input_unwrapped = inputs[i].data + + if ( + input_unwrapped.is_input_buffer() + and is_gpu(inputs[i].get_device().type) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) and not is_dynamic(input_buffer) ): op_names.append(input_buffer.get_operation_name()) @@ -5403,14 +6450,21 @@ def create(cls, inputs: Sequence[IRNode], dim: int) -> StorageBox: return kernel @classmethod +<<<<<<< HEAD def can_realize_into_without_copy( cls, src: IRNode, dst: Optional[IRNode] = None ) -> bool: +======= + def can_realize_into_without_copy(cls, src, dst=None): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(src, TensorBox): # unwrap a TensorBox return cls.can_realize_into_without_copy(src.data, dst) +<<<<<<< HEAD assert isinstance(src, (BaseView, StorageBox)), type(src) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(src.data, MultiTemplateBuffer): if ( not isinstance(src.data.layout, FixedLayout) @@ -5432,6 +6486,7 @@ def can_realize_into_without_copy( for s1, s2 in zip(src.get_stride(), dst.get_stride()) ) +<<<<<<< HEAD return ( hasattr(src.data, "layout") and isinstance(src.data.layout, FlexibleLayout) @@ -5445,6 +6500,14 @@ def get_free_symbol_uses( @classmethod def realize_into(cls, src: IRNode, dst: IRNode) -> IRNode: +======= + return isinstance(src.data.layout, FlexibleLayout) and not isinstance( + src.data, ExternKernelAlloc + ) + + @classmethod + def realize_into(cls, src, dst): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Attempt to turn this into a ReinterpretView rather than assert. # This has concessions around layout, as as_storage_and_layout # can cause us to go from flexible to fixed layout. @@ -5452,7 +6515,11 @@ def realize_into(cls, src: IRNode, dst: IRNode) -> IRNode: if is_storage_and_layout(dst): storage, layout = as_storage_and_layout(dst) dst = ReinterpretView(data=storage, layout=layout) +<<<<<<< HEAD assert isinstance(dst, ReinterpretView), type(dst) +======= + assert isinstance(dst, ReinterpretView), dst +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(src, TensorBox): # unwrap a TensorBox return cls.realize_into(src.data, dst) @@ -5470,7 +6537,11 @@ def realize_into(cls, src: IRNode, dst: IRNode) -> IRNode: dtype=src.get_dtype(), inner_fn=src.make_loader(), ranges=[ +<<<<<<< HEAD V.graph.sizevars.check_equals_and_simplify(a, b) +======= + V.graph.sizevars.guard_equals(a, b) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for a, b in zip(src.get_size(), dst.get_size()) ], ) @@ -5482,12 +6553,16 @@ def should_allocate(self) -> bool: @ir_dataclass(frozen=False) class ExternKernel(InputsKernel): +<<<<<<< HEAD """ A class that represents Kernels which are not directly lowered to Inductor Loop Level IR, such as custom operators, or aten operators which we fallback to. """ constant_args: Sequence[Any] = () +======= + constant_args: tuple[Any, ...] = () +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kwargs: dict[str, Any] = dataclasses.field(default_factory=dict) output_view: Optional[ReinterpretView] = None python_kernel_name: Optional[str] = None @@ -5497,17 +6572,25 @@ class ExternKernel(InputsKernel): ordered_kwargs_for_cpp_kernel: Iterable[str] = dataclasses.field( default_factory=list ) +<<<<<<< HEAD op_overload: Optional[_OpOverloads] = None arg_properties: Optional[list[dict[str, Any]]] = None allarg_properties: dict[str, dict[str, Any]] = dataclasses.field( default_factory=dict ) +======= + op_overload: Optional[ + Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator] + ] = None + arg_properties: Optional[list[dict[str, Any]]] = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kwarg_properties: Optional[dict[str, dict[str, Any]]] = None unbacked_bindings: dict[sympy.Symbol, pytree.KeyPath] = dataclasses.field( default_factory=dict ) mutation_outputs: list[MutationOutput] = dataclasses.field(default_factory=list) +<<<<<<< HEAD def __init__( self, name: Optional[str], @@ -5520,6 +6603,20 @@ def __init__( cpp_kernel_name: Optional[str] = None, ordered_kwargs_for_cpp_kernel: Iterable[str] = (), op_overload: Optional[_OpOverloads] = None, +======= + def __init__( # type: ignore[no-untyped-def] + self, + name, + layout, + inputs, + constant_args=(), + kwargs=None, + output_view=None, + python_kernel_name=None, + cpp_kernel_name=None, + ordered_kwargs_for_cpp_kernel=(), + op_overload=None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: super().__init__( name=name, @@ -5544,7 +6641,11 @@ def get_outputs(self) -> list[Buffer]: def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: return OrderedSet() +<<<<<<< HEAD def collect_arg_kwarg_properties(self) -> None: +======= + def collect_arg_kwarg_properties(self): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # if self.op_overload is torch._ops.OpOverload, we can use its schema to collect additional # information for args and kwargs, e.g. type and default value, to help with the cpp wrapper codegen self.arg_properties = ( @@ -5581,17 +6682,29 @@ def collect_arg_kwarg_properties(self) -> None: else: self.schema_kwargs = [] +<<<<<<< HEAD def decide_layout(self) -> None: +======= + def decide_layout(self): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(self.layout, FlexibleLayout): self.apply_constraint() self.freeze_layout() +<<<<<<< HEAD def codegen_comment(self, wrapper: PythonWrapperCodegen) -> None: +======= + def codegen_comment(self, wrapper) -> None: # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) origin_str, _detailed_origin_str = get_kernel_metadata(self, wrapper) if origin_str: wrapper.make_comment(origin_str) +<<<<<<< HEAD def codegen(self, wrapper: PythonWrapperCodegen) -> None: +======= + def codegen(self, wrapper): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise NotImplementedError def set_cpp_kernel_name(self, cpp_kernel_name: Optional[str] = None) -> None: @@ -5633,6 +6746,7 @@ def set_python_kernel_name(self, python_kernel_name: Optional[str]) -> None: f"{kernel.__module__.replace('._ops.', '.ops.')}.{kernel.__name__}" ) +<<<<<<< HEAD def get_kernel_name(self) -> str: from .codegen.cpp_wrapper_cpu import CppWrapperCpu @@ -5654,6 +6768,18 @@ def get_kernel_name(self) -> str: @staticmethod def copy_input(x: IRNode) -> Union[TensorBox, ShapeAsConstantBuffer]: +======= + def get_kernel_name(self): # type: ignore[no-untyped-def] + device = d.type if (d := self.get_device()) else V.graph.device_type + return ( + V.graph.wrapper_code.get_c_shim_func_name(self.cpp_kernel_name, device) # type: ignore[attr-defined] + if V.graph.cpp_wrapper + else self.python_kernel_name + ) + + @staticmethod + def copy_input(x): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pw = Pointwise.create( device=x.get_device(), dtype=x.get_dtype(), @@ -5666,8 +6792,13 @@ def copy_input(x: IRNode) -> Union[TensorBox, ShapeAsConstantBuffer]: return pw @classmethod +<<<<<<< HEAD def process_kernel( cls, kernel: _OpOverloads, *args: Any, **kwargs: Any +======= + def process_kernel( # type: ignore[no-untyped-def] + cls, kernel, *args, **kwargs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> tuple[ Any, list[Any], @@ -5690,6 +6821,7 @@ def process_kernel( if is_arg_tensor[-1]: tensor_args.append(arg) else: +<<<<<<< HEAD if isinstance(arg, Expr): arg = V.graph.sizevars.shape_env.create_symintnode(arg, hint=None) non_tensor_args.append(arg) @@ -5697,6 +6829,13 @@ def process_kernel( def unflatten_args( new_tensor_args: Sequence[_T], new_non_tensor_args: Sequence[_T] ) -> tuple[list[_T], dict[str, _T]]: +======= + if isinstance(arg, sympy.Expr): + arg = V.graph.sizevars.shape_env.create_symintnode(arg, hint=None) + non_tensor_args.append(arg) + + def unflatten_args(new_tensor_args, new_non_tensor_args): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) result = [] it_tensors = iter(new_tensor_args) it_non_tensors = iter(new_non_tensor_args) @@ -5755,11 +6894,19 @@ def unflatten_args( unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None if shape_env := V.fake_mode.shape_env: node_meta_val = V.current_node.meta.get("val") +<<<<<<< HEAD ctx: AbstractContextManager[None] = nullcontext() if V.current_node.target == torch._higher_order_ops.effects.with_effects: # remove the first effect token in meta["val"] and meta["unbacked_bindings"] node_meta_val = node_meta_val[1] ctx = _remove_effect_token_unbacked_bindings(V.current_node) +======= + ctx = nullcontext() + if V.current_node.target == torch._higher_order_ops.effects.with_effects: + # remove the first effect token in meta["val"] and meta["unbacked_bindings"] + node_meta_val = node_meta_val[1] + ctx = _remove_effect_token_unbacked_bindings(V.current_node) # type: ignore[assignment] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with ctx: rebind_unbacked(shape_env, V.current_node, example_output) @@ -5788,13 +6935,21 @@ def unflatten_args( ) @classmethod +<<<<<<< HEAD def convert_to_reinterpret_view(cls, x: IRNode) -> ReinterpretView: +======= + def convert_to_reinterpret_view(cls, x): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ In order to pass this to an extern kernel we need a ReinterpretView not a View. This allows us to avoid some unneeded copies. """ +<<<<<<< HEAD assert isinstance(x, BaseView), type(x) +======= + assert isinstance(x, BaseView) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(x, ReinterpretView): return x @@ -5808,7 +6963,10 @@ def convert_to_reinterpret_view(cls, x: IRNode) -> ReinterpretView: if ( x_unwrap_view_fx_node is not None and "val" in x_unwrap_view_fx_node.meta +<<<<<<< HEAD and isinstance(x_unwrap_view, (ReinterpretView, Buffer, MutableBox)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) and isinstance(x_unwrap_view.layout, FlexibleLayout) and ( x_unwrap_view_fx_node.meta["val"].is_contiguous( @@ -5826,7 +6984,12 @@ def convert_to_reinterpret_view(cls, x: IRNode) -> ReinterpretView: x_unwrap_view.freeze_layout() index_args, var_ranges = dependencies.index_vars_squeeze( +<<<<<<< HEAD x.get_size(), prefix="r" +======= + x.get_size(), + prefix="r", # type: ignore[arg-type] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) range_vars = index_args[0] index = x.make_indexer()(range_vars) @@ -5850,18 +7013,31 @@ def convert_to_reinterpret_view(cls, x: IRNode) -> ReinterpretView: layout=FixedLayout( device=x.get_device_or_error(), dtype=x.get_dtype(), +<<<<<<< HEAD size=x.get_size(), stride=strides, offset=offset, is_pinned=False, +======= + size=x.get_size(), # type: ignore[arg-type] + stride=strides, + offset=offset, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), ) @classmethod +<<<<<<< HEAD def realize_input(cls, x: IRNode) -> IRNode: if x is None: return NoneAsConstantBuffer() if isinstance(x, (Expr, sympy.logic.boolalg.Boolean, int)): +======= + def realize_input(cls, x): # type: ignore[no-untyped-def] + if x is None: + return NoneAsConstantBuffer() + if isinstance(x, (sympy.Expr, sympy.logic.boolalg.Boolean, int)): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ShapeAsConstantBuffer(expr=x) if isinstance(x, Constant): return V.graph.add_tensor_constant( @@ -5891,7 +7067,11 @@ def realize_input(cls, x: IRNode) -> IRNode: return cls.copy_input(x) @classmethod +<<<<<<< HEAD def require_stride1(cls, x: IRNode) -> IRNode: +======= + def require_stride1(cls, x): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if is_storage_and_layout(x): if len(x.get_stride()) == 0: return x @@ -5901,6 +7081,7 @@ def require_stride1(cls, x: IRNode) -> IRNode: return cls.copy_input(x) @classmethod +<<<<<<< HEAD def require_strides( cls, x: IRNode, @@ -5908,6 +7089,15 @@ def require_strides( exact_strides: Optional[Sequence[_IntLike]] = None, allow_padding: bool = False, ) -> IRNode: +======= + def require_strides( # type: ignore[no-untyped-def] + cls, + x, + order: Optional[Sequence[int]] = None, + exact_strides: Optional[Sequence[_IntLike]] = None, + allow_padding=False, + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert order is not None or exact_strides is not None # Layout generally doesn't matter, but some consuming external ops might have requirements if x.get_numel() in (0, 1) and not exact_strides: @@ -5925,9 +7115,13 @@ def require_strides( # the current size and stride already satisfies this order. # However by freezing it to the required order, the layout will be changed to: # size=[s0, 1, 28, 28], stride=[784, 1, 28, 1]), which is not actually necessary. +<<<<<<< HEAD use_current_stride_order = is_stride_order_storage_and_layout( x, order ) and not free_unbacked_symbols(x.get_layout().stride) +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # fix flexiblelayout to be FixedLayout with stride_order as_storage_and_layout( x, @@ -5935,11 +7129,17 @@ def require_strides( want_contiguous=False, stride_order=( get_stride_order( +<<<<<<< HEAD V.graph.sizevars.size_hints_or_throw( x.get_layout().stride ) ) if use_current_stride_order +======= + V.graph.sizevars.size_hints(x.get_layout().stride) + ) + if is_stride_order_storage_and_layout(x, order) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else order ), allow_padding=allow_padding, @@ -5970,6 +7170,7 @@ def require_strides( if exact_strides is not None else x ) +<<<<<<< HEAD elif isinstance( (mutation_layout := x.get_layout()), MutationLayoutSHOULDREMOVE ): @@ -5985,6 +7186,21 @@ def require_strides( exact_strides and significant_strides_equal( exact_strides, real_layout.stride, x.get_size() +======= + elif isinstance(x.get_layout(), MutationLayoutSHOULDREMOVE): + if isinstance(x.get_layout().real_layout(), FlexibleLayout): + raise AssertionError( + "the MutationLayoutSHOULDREMOVE's real layout shouldn't be FlexibleLayout" + ) + elif isinstance(x.get_layout().real_layout(), FixedLayout) and ( + (order and x.get_layout().real_layout().is_stride_ordered(order)) + or ( + exact_strides + and significant_strides_equal( + exact_strides, + x.get_layout().real_layout().stride, + x.get_size(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) ) ): @@ -6005,9 +7221,14 @@ def require_strides( isinstance(x, TensorBox) and isinstance(x.data, BaseView) and not isinstance(x.data, ReinterpretView) +<<<<<<< HEAD and is_storage_and_layout(unwrap_view := x.unwrap_view()) and hasattr(unwrap_view, "data") and not isinstance(unwrap_view.data, ExternKernelAlloc) +======= + and is_storage_and_layout(x.unwrap_view()) + and not isinstance(x.unwrap_view().data, ExternKernelAlloc) # type: ignore[attr-defined] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): try: x.data = cls.convert_to_reinterpret_view(x.data) @@ -6065,14 +7286,19 @@ def require_strides( return x @classmethod +<<<<<<< HEAD def require_exact_strides( cls, x: IRNode, exact_strides: Sequence[_IntLike], allow_padding: bool = False ) -> IRNode: +======= + def require_exact_strides(cls, x, exact_strides, allow_padding=False): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return cls.require_strides( x, exact_strides=exact_strides, allow_padding=allow_padding ) @classmethod +<<<<<<< HEAD def require_stride_order( cls, x: IRNode, order: Sequence[int], allow_padding: bool = False ) -> IRNode: @@ -6095,6 +7321,32 @@ def is_mkldnn_tensor(x: IRNode) -> bool: return False return name in V.graph.constants and V.graph.constants[name].is_mkldnn +======= + def require_stride_order(cls, x, order, allow_padding=False): # type: ignore[no-untyped-def] + return cls.require_strides(x, order=order, allow_padding=allow_padding) + + @classmethod + def require_channels_last(cls, x): # type: ignore[no-untyped-def] + return cls.require_stride_order(x, NHWC_STRIDE_ORDER) + + @classmethod + def require_channels_last_3d(cls, x): # type: ignore[no-untyped-def] + return cls.require_stride_order(x, NHWDC_STRIDE_ORDER) + + @classmethod + def require_contiguous(cls, x): # type: ignore[no-untyped-def] + def is_mkldnn_tensor(x): # type: ignore[no-untyped-def] + def safe_get_name(x): # type: ignore[no-untyped-def] + try: + return x.get_name() + except (AttributeError, NotImplementedError): + return None + + return ( + safe_get_name(x) in V.graph.constants + and V.graph.constants[safe_get_name(x)].is_mkldnn + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO move this to the more proper places if is_mkldnn_tensor(x): @@ -6105,7 +7357,11 @@ def is_mkldnn_tensor(x: IRNode) -> bool: ) @classmethod +<<<<<<< HEAD def require_contiguous_strides(cls, x: IRNode) -> IRNode: +======= + def require_contiguous_strides(cls, x): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO: combine this with require_contiguous after # https://github.com/pytorch/pytorch/pull/148235 lands. return cls.require_exact_strides( @@ -6115,9 +7371,13 @@ def require_contiguous_strides(cls, x: IRNode) -> IRNode: def apply_constraint(self) -> None: pass +<<<<<<< HEAD def fill_non_provided_args( self, args: Sequence[Any], kwargs: dict[str, Any] ) -> Sequence[Any]: +======= + def fill_non_provided_args(self, args, kwargs): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Previously, we want to maintain forward-compatibility by skipping # default args in the serialized artifacts in fbcode. However, # some of our shim interfaces require default values being OrderedSet. @@ -6126,8 +7386,13 @@ def fill_non_provided_args( # part if we see real FC requirement. More details related to FC # can be found at: # https://docs.google.com/document/d/1FzWm-sHYwmRi3x_g036kOxd99KaYquUsA-L5JwOn8ys/edit?usp=sharing +<<<<<<< HEAD assert isinstance(args, Sequence), type(args) if not isinstance(args, list): +======= + assert isinstance(args, (list, tuple)) + if isinstance(args, tuple): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) args = list(args) assert self.arg_properties, "ExternKernel.arg_properties should not be empty" @@ -6151,7 +7416,11 @@ def fill_non_provided_args( ) return args +<<<<<<< HEAD def codegen_const_args(self, names: Optional[list[str]] = None) -> list[str]: +======= + def codegen_const_args(self, names: Optional[list[str]] = None): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if V.graph.cpp_wrapper: result = [] # Aten ops follow the convention that tensor args are before non-tensor args, @@ -6169,8 +7438,12 @@ def codegen_const_args(self, names: Optional[list[str]] = None) -> list[str]: for i, x in enumerate(self.constant_args): if name_to_arg_properties is not None: +<<<<<<< HEAD assert names is not None prop = name_to_arg_properties.get(names[i]) +======= + prop = name_to_arg_properties.get(names[i]) # type: ignore[index] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) type_ = prop.get("type") if prop else None else: idx = len(self.inputs) + i @@ -6182,9 +7455,15 @@ def codegen_const_args(self, names: Optional[list[str]] = None) -> list[str]: result.append(V.graph.wrapper_code.val_to_arg_str(x, type_)) return result else: +<<<<<<< HEAD return [V.graph.wrapper_code.val_to_arg_str(a) for a in self.constant_args] def codegen_args(self) -> list[str]: +======= + return map(V.graph.wrapper_code.val_to_arg_str, self.constant_args) + + def codegen_args(self): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if V.graph.cpp_wrapper and self.op_overload is not None: # cpp wrapper needs special logic to fill in missing args with default values inputs = self.fill_non_provided_args( @@ -6210,7 +7489,11 @@ def codegen_args(self) -> list[str]: args.extend(self.codegen_const_args()) return args +<<<<<<< HEAD def get_kwargs_value(self, arg_name: str, **kwargs: Any) -> Any: +======= + def get_kwargs_value(self, arg_name, **kwargs): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Given an argument name, queries for values in (in order): 1. any provided kwargs for this function. 2. the class self.kwargs member. @@ -6219,11 +7502,19 @@ def get_kwargs_value(self, arg_name: str, **kwargs: Any) -> Any: return kwargs.get(arg_name) if arg_name in self.kwargs: return self.kwargs.get(arg_name) +<<<<<<< HEAD if (arg := self.allarg_properties.get(arg_name)) is not None: return arg.get("default_value") raise AssertionError(f"{arg_name} not in self.allarg_properties") def codegen_kwargs(self, skip_out: bool = False) -> list[str]: +======= + if self.allarg_properties and arg_name in self.allarg_properties: + return self.allarg_properties.get(arg_name).get("default_value") # type: ignore[union-attr] + raise AssertionError(f"{arg_name} not in self.allarg_properties") + + def codegen_kwargs(self, skip_out=False): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if V.graph.cpp_wrapper: if self.op_overload is not None and len(self.schema_kwargs) == 0: # All the args should have been generated by fill_non_provided_args in codegen_args @@ -6236,11 +7527,22 @@ def codegen_kwargs(self, skip_out: bool = False) -> list[str]: continue v = self.get_kwargs_value(arg_name) +<<<<<<< HEAD if isinstance(v, Expr): kwargs.append(v) else: assert self.allarg_properties is not None type_ = self.allarg_properties.get(arg_name, {}).get("type") +======= + if isinstance(v, sympy.Expr): + kwargs.append(v) + else: + type_ = ( + self.allarg_properties.get(arg_name).get("type") # type: ignore[union-attr] + if self.allarg_properties and arg_name in self.allarg_properties + else None + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kwargs.append(V.graph.wrapper_code.val_to_arg_str(v, type_)) else: kwargs = [ @@ -6260,7 +7562,11 @@ def get_op_name(self) -> str: op_name = "unknown_op" return op_name +<<<<<<< HEAD def codegen_size_asserts(self, wrapper: PythonWrapperCodegen) -> None: +======= + def codegen_size_asserts(self, wrapper) -> None: # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if config.size_asserts and not V.graph.cpp_wrapper: # comparing strides for 0 size tensor is tricky. Ignore them for now. if sympy_product(self.get_size()) == 0: @@ -6272,7 +7578,11 @@ def codegen_size_asserts(self, wrapper: PythonWrapperCodegen) -> None: f"assert_size_stride({self.get_name()}, {size}, {stride}, {op_name!r})" ) +<<<<<<< HEAD def codegen_alignment_asserts(self, wrapper: PythonWrapperCodegen) -> None: +======= + def codegen_alignment_asserts(self, wrapper) -> None: # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if config.alignment_asserts and not V.graph.cpp_wrapper: name = self.get_name() aligned = name not in V.graph.unaligned_buffers @@ -6286,6 +7596,7 @@ def codegen_alignment_asserts(self, wrapper: PythonWrapperCodegen) -> None: f"# buffer {name} (op: {op_name}) is assumed to be not aligned" ) +<<<<<<< HEAD def codegen_memory_tracking(self, wrapper: PythonWrapperCodegen) -> None: """ Track outputs of fallback operators if config.test_configs.track_memory_lifecycle @@ -6298,6 +7609,9 @@ def codegen_memory_tracking(self, wrapper: PythonWrapperCodegen) -> None: wrapper.writeline(f"track_tensor({name}, '{name}')") def get_group_stride(self) -> tuple[list[Sequence[Expr]], list[Expr]]: +======= + def get_group_stride(self): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ get output sizes and strides, for template_codegen """ @@ -6306,7 +7620,11 @@ def get_group_stride(self) -> tuple[list[Sequence[Expr]], list[Expr]]: # iter_ranges = _size of output tensor, reduce_range = [] because no reduction return [_size, []], _stride +<<<<<<< HEAD def canonicalize(self) -> tuple[Expr, Sequence[Expr]]: +======= + def canonicalize(self): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Manually get canonicalization of the output index """ @@ -6345,7 +7663,11 @@ def get_free_symbol_uses( maybe_get_symbols = ( maybe_free_unbacked_symbols if unbacked_only else maybe_free_symbols ) +<<<<<<< HEAD r = InputsKernel.get_free_symbol_uses(self, unbacked_only) +======= + r = OrderedSet[sympy.Symbol]() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for arg in self.constant_args: r |= maybe_get_symbols(arg) for arg in self.kwargs.values(): @@ -6369,6 +7691,7 @@ def __str__(self) -> str: @ir_dataclass(frozen=False) class ExternKernelOut(ExternKernel): +<<<<<<< HEAD def codegen(self, wrapper: PythonWrapperCodegen) -> None: wrapper.generate_extern_kernel_out(self) @@ -6390,6 +7713,27 @@ def __init__( None, layout, unwrapped_inputs, +======= + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + wrapper.generate_extern_kernel_out(self) + + def __init__( # type: ignore[no-untyped-def] + self, + layout, + inputs, + constant_args=(), + kwargs=None, + output_view=None, + python_kernel_name=None, + cpp_kernel_name=None, + ordered_kwargs_for_cpp_kernel=(), + op_overload=None, + ) -> None: + super().__init__( + None, + layout, + self.unwrap_storage(inputs), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) constant_args, kwargs or {}, None, @@ -6426,6 +7770,7 @@ def __init__(self, count: int, device: torch.device) -> None: class ExternKernelAlloc(ExternKernel): +<<<<<<< HEAD def codegen(self, wrapper: PythonWrapperCodegen) -> None: wrapper.generate_extern_kernel_alloc(self) @@ -6446,6 +7791,26 @@ def __init__( None, layout, cast(Sequence[IRNode], unwrapped_inputs), +======= + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + wrapper.generate_extern_kernel_alloc(self) + + def __init__( # type: ignore[no-untyped-def] + self, + layout, + inputs, + constant_args=(), + kwargs=None, + python_kernel_name=None, + cpp_kernel_name=None, + ordered_kwargs_for_cpp_kernel=(), + op_overload=None, + ) -> None: + super().__init__( + None, + layout, + self.unwrap_storage(inputs), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) constant_args, kwargs or {}, None, @@ -6464,7 +7829,11 @@ def __init__( def should_allocate(self) -> bool: return False +<<<<<<< HEAD def apply_constraint(self) -> None: +======= + def apply_constraint(self): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise NotImplementedError @@ -6473,9 +7842,13 @@ class MutationOutput(Buffer): An output buffer that represents the mutation of a pre-existing buffer """ +<<<<<<< HEAD def __init__( self, layout: OutputSpec, mutated_node: IRNode, mutating_node: Operation ) -> None: +======= + def __init__(self, layout, mutated_node, mutating_node: Operation) -> None: # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__(name=None, layout=layout) mutated_node_name = mutated_node.get_name() V.graph.mark_buffer_mutated(mutated_node_name) @@ -6492,6 +7865,7 @@ def get_mutation_names(self) -> Sequence[str]: def should_allocate(self) -> bool: return False +<<<<<<< HEAD def get_mutation_buffers(self) -> Sequence[IRNode]: mutation_names = self.get_mutation_names() return [ @@ -6500,6 +7874,8 @@ def get_mutation_buffers(self) -> Sequence[IRNode]: if buf is not None ] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TMADescriptor(ExternKernel): """ @@ -6535,9 +7911,13 @@ def create( cls._CACHE[key] = cls._create_impl(tensor, tma_meta) return cls._CACHE[key] +<<<<<<< HEAD def __init__( self, tensor: IRNode, inputs: Sequence[Any], constant_args: Sequence[Any] ) -> None: +======= + def __init__(self, tensor: IRNode, inputs, constant_args): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__( None, # link back to the underlying tensor in terms of ownership @@ -6549,7 +7929,11 @@ def __init__( layout=tensor.get_layout(), ) ), +<<<<<<< HEAD cast(Sequence[Buffer], inputs), +======= + inputs, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tuple(constant_args), None, ) @@ -6558,7 +7942,11 @@ def __init__( self.name = V.graph.register_buffer(self) V.graph.register_operation(self) +<<<<<<< HEAD def codegen(self, wrapper: PythonWrapperCodegen) -> None: +======= + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) wrapper.generate_tma_descriptor(self) def get_tensor(self) -> IRNode: @@ -6640,7 +8028,10 @@ def __init__( self.subgraph = V.graph.make_subgraph(self.gm, example_inputs, subgraph_name) +<<<<<<< HEAD assert is_node_sequence(self.inputs) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sym_inputs = get_symbolic_inputs(self.inputs) for sym_inp in sym_inputs: @@ -6653,20 +8044,31 @@ def __init__( with V.set_graph_handler(self.subgraph): # Don't bother autotuning on Triton here +<<<<<<< HEAD with inductor_config.patch( +======= + with inductor_config.patch( # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) max_autotune=False, max_autotune_gemm=False, max_autotune_gemm_backends="ATEN", ): self.subgraph.run(*self.example_inputs) +<<<<<<< HEAD def codegen(self, wrapper: PythonWrapperCodegen) -> None: +======= + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class CodegenGraph: def __init__(self, graph: GraphLowering): self.graph = graph self.name = graph.name +<<<<<<< HEAD assert is_node_sequence(self.inputs) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) outer_inputs = [t.codegen_reference() for t in self.inputs] wrapper.codegen_subgraph_with_flattened_outputs( CodegenGraph(self.subgraph), @@ -6676,7 +8078,11 @@ def __init__(self, graph: GraphLowering): class UserDefinedTritonKernel(ExternKernel): +<<<<<<< HEAD def get_kernel_and_metadata(self) -> tuple[Kernel, Any, list[str], list[str]]: +======= + def get_kernel_and_metadata(self): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from triton.runtime.autotuner import Autotuner from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table @@ -6707,11 +8113,15 @@ def get_kernel_and_metadata(self) -> tuple[Kernel, Any, list[str], list[str]]: kernel = kernel.fn return kernel, configs, restore_value_args, reset_to_zero_args +<<<<<<< HEAD @override def codegen(self, wrapper: PythonWrapperCodegen) -> None: """Overrides the parent member. See https://github.com/pytorch/pytorch/issues/151692""" +======= + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.utils import triton_version_uses_attrs_dict ( @@ -6737,10 +8147,14 @@ def codegen(self, wrapper: PythonWrapperCodegen) -> None: named_args = { k: self.get_kwargs_value(k) for k in self.ordered_kwargs_for_cpp_kernel } +<<<<<<< HEAD assert hasattr(kernel, "arg_names") and hasattr(kernel, "constexprs"), type( kernel ) constexpr_names = OrderedSet(kernel.arg_names[i] for i in kernel.constexprs) +======= + constexpr_names = OrderedSet([kernel.arg_names[i] for i in kernel.constexprs]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) args: list[Any] = [] arg_types: list[Any] = [] @@ -6749,9 +8163,12 @@ def codegen(self, wrapper: PythonWrapperCodegen) -> None: for name, arg in itertools.chain( named_args.items(), zip(itertools.repeat(""), extra_launch_args) ): +<<<<<<< HEAD if name in constexpr_names and triton_version_uses_attrs_dict(): # see #160000 - we don't pass in constexpr args to speed up runtime. continue +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raw_keys_filtered.append(name) raw_args_filtered.append(arg) if isinstance(arg, IRNode): @@ -6809,6 +8226,7 @@ def get_free_symbol_uses( def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: return OrderedSet() +<<<<<<< HEAD def __init__( self, *, @@ -6821,6 +8239,14 @@ def __init__( kwargs: dict[str, IRNode] = {} constant_args: list[IRNode] = [] +======= + def __init__( # type: ignore[no-untyped-def] + self, *, kernel_idx, grid, tma_descriptor_metadata, kernel_args + ) -> None: + inputs = [] + kwargs = {} + constant_args = [] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for k, v in kernel_args.items(): if isinstance(v, TensorBox): t = InputsKernel.unwrap_storage_for_input(self.realize_input(v)) @@ -6835,7 +8261,10 @@ def __init__( assert len(inputs) != 0 self.device = inputs[0].get_device() +<<<<<<< HEAD assert isinstance(inputs, Sequence), type(inputs) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__( None, NoneLayout(device=self.device), @@ -6849,7 +8278,10 @@ def __init__( kernel, configs, _, _ = self.get_kernel_and_metadata() # If we are autotuning, not all arguments will be passed +<<<<<<< HEAD assert hasattr(kernel, "arg_names") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.ordered_kwargs_for_cpp_kernel = [ arg for arg in kernel.arg_names if arg in kernel_args ] @@ -6882,9 +8314,14 @@ class InplaceBernoulliFallback(ExternKernel): This needs to be a custom class to handle mutation properly """ +<<<<<<< HEAD def codegen(self, wrapper: PythonWrapperCodegen) -> None: assert all(isinstance(t, IRNode) for t in self.inputs) (x,) = (cast(IRNode, t).codegen_reference() for t in self.inputs) +======= + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + (x,) = (t.codegen_reference() for t in self.inputs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if V.graph.cpp_wrapper: # Inductor doesn't really support aten Generator, so the Generator kwarg is always NULL here, @@ -6901,14 +8338,22 @@ def should_allocate(self) -> bool: return False def get_mutation_names(self) -> Sequence[str]: +<<<<<<< HEAD return [self.input_name(0)] +======= + return [self.inputs[0].get_name()] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: return OrderedSet() +<<<<<<< HEAD def __init__( self, op_overload: _OpOverloads, x: IRNode, *constant_args: Any ) -> None: +======= + def __init__(self, op_overload, x, *constant_args) -> None: # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__( None, NoneLayout(device=x.get_device()), @@ -6927,7 +8372,11 @@ class InplaceCopyFallback(ExternKernel): This needs to be a custom class to handle mutation properly """ +<<<<<<< HEAD def codegen(self, wrapper: PythonWrapperCodegen) -> None: +======= + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) (dst, src, non_blocking) = self.codegen_args() wrapper.codegen_device_copy(src, dst, non_blocking) @@ -6935,16 +8384,28 @@ def should_allocate(self) -> bool: return False def get_mutation_names(self) -> Sequence[str]: +<<<<<<< HEAD return [self.input_name(0)] +======= + return [self.inputs[0].get_name()] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: return OrderedSet() +<<<<<<< HEAD def __init__( self, layout: OutputSpec, inputs: Sequence[IRNode], constant_args: Sequence[Any], +======= + def __init__( # type: ignore[no-untyped-def] + self, + layout, + inputs, + constant_args, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: super().__init__( None, @@ -6959,9 +8420,13 @@ def __init__( V.graph.register_operation(self) @classmethod +<<<<<<< HEAD def create( cls, dst: IRNode, src: IRNode, non_blocking: bool = False ) -> InplaceCopyFallback: +======= + def create(cls, dst, src, non_blocking: bool = False): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inputs = [cls.realize_input(t) for t in [dst, src]] constant_args = (non_blocking,) result = InplaceCopyFallback( @@ -6977,8 +8442,12 @@ class MutatingFirstArgExternKernel(ExternKernel): This needs to be a custom class to handle mutation properly """ +<<<<<<< HEAD def codegen(self, wrapper: PythonWrapperCodegen) -> None: assert is_node_sequence(self.inputs) +======= + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) argrefs = [ *(t.codegen_reference() for t in self.inputs), *map(repr, self.constant_args), @@ -6991,7 +8460,11 @@ def should_allocate(self) -> bool: return False def get_mutation_names(self) -> Sequence[str]: +<<<<<<< HEAD return [self.input_name(0)] +======= + return [self.inputs[0].get_name()] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: return OrderedSet() @@ -7001,7 +8474,11 @@ def has_side_effects(self) -> bool: class ResizeStorageBytes(MutatingFirstArgExternKernel): +<<<<<<< HEAD def __init__(self, variable: IRNode, new_size: int) -> None: +======= + def __init__(self, variable, new_size) -> None: # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(new_size, int), "TODO: dynamic shapes" super().__init__( None, @@ -7014,12 +8491,19 @@ def __init__(self, variable: IRNode, new_size: int) -> None: V.graph.register_operation(self) self.python_kernel_name = "inductor_ops.resize_storage_bytes_" self.cpp_kernel_name = "torch::inductor::resize_storage_bytes_" +<<<<<<< HEAD assert isinstance(variable, (BaseView, StorageBox, TensorBox)), type(variable) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) V.graph.never_reuse_buffers.add(variable.data.get_name()) class SetSourceTensorKernel(ExternKernelAlloc): +<<<<<<< HEAD def __init__(self, self_tensor: IRNode, storage_tensor: IRNode) -> None: +======= + def __init__(self, self_tensor, storage_tensor) -> None: # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) storage_tensor.freeze_layout() super().__init__( storage_tensor.get_layout(), @@ -7027,9 +8511,12 @@ def __init__(self, self_tensor: IRNode, storage_tensor: IRNode) -> None: python_kernel_name="torch.ops.aten.set_.source_Tensor", op_overload=torch.ops.aten.set_.source_Tensor, ) +<<<<<<< HEAD assert isinstance(self_tensor, (BaseView, StorageBox, TensorBox)), type( self_tensor ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) V.graph.never_reuse_buffers.add(self_tensor.data.get_name()) V.graph.never_reuse_buffers.add(storage_tensor.get_name()) V.graph.never_reuse_buffers.add(self.get_name()) @@ -7040,7 +8527,11 @@ def __init__(self, self_tensor: IRNode, storage_tensor: IRNode) -> None: ] def get_inputs_that_alias_output(self) -> Sequence[str]: +<<<<<<< HEAD return [self.input_name(0), self.input_name(1)] +======= + return [self.inputs[0].get_name(), self.inputs[1].get_name()] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class ScatterFallback(ExternKernel): @@ -7050,7 +8541,11 @@ class ScatterFallback(ExternKernel): It also handle the case `src` being a scalar properly. """ +<<<<<<< HEAD def codegen(self, wrapper: PythonWrapperCodegen) -> None: +======= + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) reduce = self.kwargs["reduce"] if V.graph.cpp_wrapper: # Follow aten/src/ATen/native/ReductionType.h:get_operator_enum @@ -7058,7 +8553,10 @@ def codegen(self, wrapper: PythonWrapperCodegen) -> None: if reduce in get_operator_enum: reduce = get_operator_enum[reduce] +<<<<<<< HEAD assert is_node_sequence(self.inputs) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.src_is_tensor: (x, index, src) = (t.codegen_reference() for t in self.inputs) else: @@ -7077,14 +8575,20 @@ def codegen(self, wrapper: PythonWrapperCodegen) -> None: def should_allocate(self) -> bool: return False +<<<<<<< HEAD def get_mutation_names(self) -> list[str]: inp = self.inputs[0] assert isinstance(inp, IRNode) return [inp.get_name()] +======= + def get_mutation_names(self) -> Sequence[str]: + return [self.inputs[0].get_name()] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: return OrderedSet() +<<<<<<< HEAD def __init__( self, op_overload: _OpOverloads, @@ -7092,6 +8596,15 @@ def __init__( dim: int, index: IRNode, src: IRNode, +======= + def __init__( # type: ignore[no-untyped-def] + self, + op_overload, + x, + dim: int, + index, + src, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) *, reduce: Optional[str] = None, include_self: bool = True, @@ -7126,8 +8639,12 @@ class IndexPutFallback(ExternKernel): This needs to be a custom class to handle mutation and indices properly """ +<<<<<<< HEAD def codegen(self, wrapper: PythonWrapperCodegen) -> None: assert is_node_sequence(self.inputs) +======= + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) (x, values, *valid_indices) = (t.codegen_reference() for t in self.inputs) indices = [] iter_valid_indices = iter(valid_indices) @@ -7145,11 +8662,16 @@ def should_allocate(self) -> bool: return False def get_mutation_names(self) -> Sequence[str]: +<<<<<<< HEAD return [self.input_name(0)] +======= + return [self.inputs[0].get_name()] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: return OrderedSet() +<<<<<<< HEAD def __init__( self, op_overload: torch._ops.OpOverload, @@ -7158,6 +8680,9 @@ def __init__( values: Sequence[Any], accumulate: Any, ) -> None: +======= + def __init__(self, op_overload, x, indices, values, accumulate) -> None: # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.indices = indices valid_indices = [i for i in indices if i is not None] tensors = [self.realize_input(x) for x in [x, values, *valid_indices]] @@ -7171,14 +8696,22 @@ def __init__( cpp_kernel_name=cpp_kernel_name, op_overload=op_overload, ) +<<<<<<< HEAD V.graph.mark_buffer_mutated(self.input_name(0)) +======= + V.graph.mark_buffer_mutated(self.inputs[0].get_name()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.name = V.graph.register_buffer(self) V.graph.register_operation(self) class DeviceCopy(ExternKernelOut): @classmethod +<<<<<<< HEAD def create(cls, x: IRNode, device: torch.device, non_blocking: bool) -> IRNode: +======= + def create(cls, x, device, non_blocking): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if ( not x.is_extern() and all(r in V.graph.constants for r in x.get_read_names()) @@ -7187,6 +8720,7 @@ def create(cls, x: IRNode, device: torch.device, non_blocking: bool) -> IRNode: return x.constant_to_device(device) V.graph.add_device_info(device) +<<<<<<< HEAD x_device = x.get_device() assert x_device is not None V.graph.add_device_info(x_device) @@ -7214,12 +8748,27 @@ def create(cls, x: IRNode, device: torch.device, non_blocking: bool) -> IRNode: x.get_size(), stride, is_pinned=is_destination_pinned, +======= + V.graph.add_device_info(x.get_device()) + + developer_warning("DeviceCopy in input program") + constant_args = (non_blocking,) + return DeviceCopy( + FlexibleLayout( + device=device, + dtype=x.get_dtype(), + size=x.get_size(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), [cls.realize_input(x)], constant_args, ) +<<<<<<< HEAD def codegen(self, wrapper: PythonWrapperCodegen) -> None: +======= + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) args = self.codegen_args() assert len(args) == 2 if self.output_view: @@ -7230,6 +8779,7 @@ def codegen(self, wrapper: PythonWrapperCodegen) -> None: wrapper.codegen_device_copy(args[0], self.codegen_reference(), args[1]) +<<<<<<< HEAD class DynamicSelectStorageOffset(ExternKernel): """ The result of computing a dynamic selection index is determined as follows: when the index in the @@ -7274,6 +8824,8 @@ def codegen(self, wrapper: PythonWrapperCodegen) -> None: wrapper.codegen_dynamic_select_index(self) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class DynamicScalar(ExternKernel): """ The result of a call to aten._local_scalar_dense. @@ -7285,9 +8837,13 @@ def get_reads(self) -> OrderedSet[Dep]: def should_allocate(self) -> bool: return False +<<<<<<< HEAD def __init__( self, sym: sympy.Symbol, keypath: pytree.KeyPath, data: IRNode ) -> None: +======= + def __init__(self, sym, keypath, data) -> None: # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) data.realize() super().__init__( None, NoneLayout(device=torch.device("cpu")), self.unwrap_storage([data]) @@ -7298,7 +8854,11 @@ def __init__( def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: return OrderedSet([self.sym]) +<<<<<<< HEAD def codegen(self, wrapper: PythonWrapperCodegen) -> None: +======= + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) wrapper.codegen_dynamic_scalar(self) @@ -7313,7 +8873,11 @@ def get_reads(self) -> OrderedSet[Dep]: def should_allocate(self) -> bool: return False +<<<<<<< HEAD def __init__(self, scalar: SympyBoolean, msg: str) -> None: +======= + def __init__(self, scalar, msg) -> None: # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__( # Buffer(name, layotu) None, @@ -7327,12 +8891,19 @@ def __init__(self, scalar: SympyBoolean, msg: str) -> None: def has_side_effects(self) -> bool: return True +<<<<<<< HEAD def get_free_symbol_uses( self, unbacked_only: bool = False ) -> OrderedSet[sympy.Symbol]: return get_free_symbols(self.scalar, unbacked_only) def codegen(self, wrapper: PythonWrapperCodegen) -> None: +======= + def get_free_symbol_uses(self, unbacked_only: bool = False): # type: ignore[no-untyped-def] + return get_free_symbols(self.scalar, unbacked_only) + + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not config.scalar_asserts: return # NB: It is EXTREMELY important not to simplify the scalar under assertion here, @@ -7341,10 +8912,14 @@ def codegen(self, wrapper: PythonWrapperCodegen) -> None: # simplify(u0 == 0), you will get True (because we've already runtime assert'ed # that it's true). But we're code generating the actual runtime assert here!! symbol = next(iter(self.get_free_symbol_uses(unbacked_only=False))) +<<<<<<< HEAD if V.graph.fx_wrapper: # TODO fix pass elif V.graph.cpp_wrapper: +======= + if V.graph.cpp_wrapper: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) symbol_str = f"std::to_string({symbol})" sizevar = V.graph.wrapper_code.codegen_cpp_sizevar( self.scalar, simplify=False @@ -7377,6 +8952,7 @@ class FallbackKernel(ExternKernelAlloc): inplace aten ops, and mutating ops that are auto-functionalizable. """ +<<<<<<< HEAD def __init__( self, layout: OutputSpec, @@ -7387,6 +8963,18 @@ def __init__( kwargs: Optional[dict[str, Any]] = None, *, unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None, +======= + def __init__( # type: ignore[no-untyped-def] + self, + layout, + kernel, + tensor_args, + nontensor_args, + unflatten_args, + kwargs=None, + *, + unbacked_bindings=None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: super().__init__( layout, @@ -7396,16 +8984,31 @@ def __init__( ) self.use_runtime_dispatch = False +<<<<<<< HEAD self.unbacked_bindings = unbacked_bindings or {} assert isinstance( kernel, (torch._ops.OpOverload, torch._ops.HigherOrderOperator) +======= + self.unbacked_bindings = unbacked_bindings + + assert isinstance( + kernel, + ( + torch._ops.OpOverload, + torch._ops.HigherOrderOperator, + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), f"Fails to create FallbackKernel for {kernel}: {type(kernel)} not supported" self.op_overload = kernel self.unflatten_args = unflatten_args self.kwargs = {} if kwargs is None else kwargs +<<<<<<< HEAD assert self.python_kernel_name is not None V.graph.warn_fallback(self.python_kernel_name) +======= + V.graph.warn_fallback(self.python_kernel_name) # type: ignore[arg-type] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # args that are aliased self.alias_names: list[str] = [] @@ -7450,10 +9053,17 @@ def __init__( args, kwargs = self.unflatten_args(self.inputs, self.constant_args) +<<<<<<< HEAD def handle_aliasing_and_mutation(info: torch._C.Argument, arg: Any) -> None: # Assertions to make sure we didn't mismatch args if isinstance(info.type, torch.ListType): assert isinstance(arg, (list, tuple)), type(arg) +======= + def handle_aliasing_and_mutation(info, arg) -> None: # type: ignore[no-untyped-def] + # Assertions to make sure we didn't mismatch args + if isinstance(info.type, torch.ListType): + assert isinstance(arg, (list, tuple)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if library_utils.is_tensor_like_type(info.type): # PyTorch also accepts None and scalar types for args marked as "Tensor". # We're not going to check all of them here. @@ -7464,9 +9074,14 @@ def handle_aliasing_and_mutation(info: torch._C.Argument, arg: Any) -> None: if info.alias_info is None: return +<<<<<<< HEAD def add_alias(t: IRNode) -> None: self.alias_names.append(t.get_name()) assert info.alias_info is not None +======= + def add_alias(t) -> None: # type: ignore[no-untyped-def] + self.alias_names.append(t.get_name()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if info.alias_info.is_write: self.mutation_outputs.append( MutationOutput(NoneLayout(device=t.get_device()), t, self) @@ -7495,22 +9110,38 @@ def get_read_writes(self) -> dependencies.ReadWrites: return read_writes +<<<<<<< HEAD def codegen_unbacked_symbol_defs(self, wrapper: PythonWrapperCodegen) -> None: +======= + def codegen_unbacked_symbol_defs(self, wrapper) -> None: # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return wrapper.codegen_unbacked_symbol_defs_for_outputs( self.get_name(), self.outputs, getattr(self, "unbacked_bindings", None) ) +<<<<<<< HEAD def get_unbacked_symbol_defs(self) -> Container[sympy.Symbol]: # type: ignore[override] +======= + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if unbacked_bindings := getattr(self, "unbacked_bindings", None): resolved = resolve_unbacked_bindings( V.graph.sizevars.shape_env, unbacked_bindings ) assert resolved is not None +<<<<<<< HEAD return resolved.keys() else: return OrderedSet() def codegen_args(self) -> list[str]: +======= + return resolved.keys() # type: ignore[return-value] + else: + return OrderedSet() + + def codegen_args(self): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclasses.dataclass class Shim: ref: Any @@ -7518,7 +9149,10 @@ class Shim: def __repr__(self) -> str: return self.ref +<<<<<<< HEAD assert is_node_sequence(self.inputs) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tensor_args = [Shim(x.codegen_reference()) for x in self.inputs] args, kwargs = self.unflatten_args(tensor_args, self.constant_args) if V.graph.cpp_wrapper and isinstance(self.op_overload, torch._ops.OpOverload): @@ -7535,16 +9169,23 @@ def __repr__(self) -> str: return args @staticmethod +<<<<<<< HEAD def find_device( tensor_args: Optional[Sequence[torch.Tensor]], example_output: Sequence[Any] ) -> Any: +======= + def find_device(tensor_args, example_output): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) non_torch_bind_tensor_args = ( [t for t in tensor_args if not isinstance(t, TorchBindObject)] if tensor_args else None ) if non_torch_bind_tensor_args: +<<<<<<< HEAD assert tensor_args +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) devices = [arg.get_device() for arg in tensor_args if arg.get_device()] return devices[0] if isinstance(example_output, torch.Tensor): @@ -7558,17 +9199,25 @@ def find_device( if len(devices) == 1: return devices[0] for device in devices: +<<<<<<< HEAD assert isinstance(device, torch.device) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if is_gpu(device.type): return device return devices[0] return None +<<<<<<< HEAD def has_side_effects(self) -> bool: +======= + def has_side_effects(self): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(self.op_overload, torch._ops.HigherOrderOperator): return False return get_schema_info(self.op_overload).is_mutable() +<<<<<<< HEAD def get_inputs_that_alias_output(self) -> Sequence[str]: assert isinstance( self.op_overload, (torch._ops.OpOverload, torch._ops.HigherOrderOperator) @@ -7589,6 +9238,10 @@ def get_inputs_that_alias_output(self) -> Sequence[str]: return [] else: return self.alias_names +======= + def get_inputs_that_alias_output(self): # type: ignore[no-untyped-def] + return self.alias_names +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_mutation_names(self) -> Sequence[str]: assert len(self.mutation_names) <= 1 @@ -7609,7 +9262,11 @@ def export_extern_kernel_node(self): # type: ignore[no-untyped-def] self.op_overload, ) +<<<<<<< HEAD assert isinstance(self, FallbackKernel), type(self) +======= + assert isinstance(self, FallbackKernel) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) args, kwargs = self.unflatten_args(self.inputs, self.constant_args) args = self.fill_non_provided_args(args, kwargs) ordered_kwargs = [ @@ -7622,6 +9279,7 @@ def export_extern_kernel_node(self): # type: ignore[no-untyped-def] # No need to serialize in the cpp wrapper JIT mode return [*args, *ordered_kwargs] +<<<<<<< HEAD serializer = GraphModuleSerializer(None, []) # type: ignore[arg-type] named_arguments = serializer.serialize_inputs(target, args, kwargs) @@ -7630,6 +9288,13 @@ def handle_single_output( return_type: Union[torch.TensorType, torch.ListType, torch.JitType], output: Union[IRNode, Sequence[IRNode]], ) -> export_schema.Argument: +======= + serializer = GraphModuleSerializer(None, None) # type: ignore[arg-type] + named_arguments = serializer.serialize_inputs(target, args, kwargs) + + # serialize_outputs + def handle_single_output(return_type, output): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(return_type, (torch.TensorType, torch.NoneType)): # For single Tensor or None out = output @@ -7637,7 +9302,10 @@ def handle_single_output( assert len(output) == 1 out = output[0] if isinstance(return_type, torch.TensorType): +<<<<<<< HEAD assert isinstance(out, IRNode) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return export_schema.Argument.create( as_tensor=export_schema.TensorArgument(name=out.get_name()) ) @@ -7647,7 +9315,10 @@ def handle_single_output( elif isinstance(return_type, torch.ListType) and isinstance( return_type.getElementType(), torch.TensorType ): +<<<<<<< HEAD assert isinstance(output, Sequence), type(output) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # For single TensorList return export_schema.Argument.create( as_tensors=[ @@ -7666,7 +9337,10 @@ def handle_single_output( ) ) else: +<<<<<<< HEAD assert isinstance(output, IRNode) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return export_schema.Argument.create( as_optional_tensor=export_schema.OptionalTensorArgument.create( as_tensor=export_schema.TensorArgument( @@ -7680,7 +9354,11 @@ def handle_single_output( raise RuntimeError(f"Unsupported return type {type(return_type)}") if isinstance(target, torch._higher_order_ops.torchbind.CallTorchBind): +<<<<<<< HEAD returns = target.schema(args[0], args[1]).returns +======= + returns = target.schema(args[0], args[1]).returns # type: ignore[union-attr] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: returns = target._schema.returns # type: ignore[union-attr] if len(returns) == 1: @@ -7693,6 +9371,7 @@ def handle_single_output( # For tuple returns, e.g "-> (Tensor, Tensor)" or "-> (Tesnor, Tensor[])" # Not generating output args for self.mutation_outputs output_arguments = [ +<<<<<<< HEAD handle_single_output( return_schema.real_type, # type: ignore[attr-defined] output, @@ -7705,12 +9384,23 @@ def handle_single_output( name=self.get_name(), node=export_schema.Node( target=self.op_overload.name(), +======= + handle_single_output(return_schema.real_type, output) + for return_schema, output in zip(returns, self.outputs) + ] + + node = ExternKernelNode( + name=self.get_name(), + node=export_schema.Node( + target=self.op_overload.name(), # type: ignore[union-attr] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inputs=named_arguments, outputs=output_arguments, metadata={}, ), ) +<<<<<<< HEAD V.extern_kernel_nodes.append(node) return [*args, *ordered_kwargs] @@ -7724,6 +9414,17 @@ def codegen(self, wrapper: PythonWrapperCodegen) -> None: if kernel.namespace == "aten": # Aten Fallback Ops assert isinstance(kernel, torch._ops.OpOverload), type(kernel) +======= + V.graph.extern_kernel_nodes.append(node) + + return [*args, *ordered_kwargs] + + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + kernel = self.op_overload + if kernel.namespace == "aten": # type: ignore[union-attr] + # Aten Fallback Ops + assert isinstance(kernel, torch._ops.OpOverload) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if V.graph.cpp_wrapper: from torchgen.aoti.fallback_ops import inductor_fallback_ops @@ -7735,9 +9436,15 @@ def codegen(self, wrapper: PythonWrapperCodegen) -> None: kernel, ) self.use_runtime_dispatch = True +<<<<<<< HEAD elif kernel.namespace == "_quantized": # Internal Quantized Fallback Ops assert isinstance(kernel, torch._ops.OpOverload), type(kernel) +======= + elif kernel.namespace == "_quantized": # type: ignore[union-attr] + # Internal Quantized Fallback Ops + assert isinstance(kernel, torch._ops.OpOverload) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif V.graph.cpp_wrapper: # For non-aten OpOverload, i.e. custom ops # If the op is in custom_ops_to_c_shims, generate direct function call @@ -7780,9 +9487,12 @@ def is_number(t: torch.JitType) -> bool: self.codegen_comment(wrapper) if self.use_runtime_dispatch: exported_args = self.export_extern_kernel_node() +<<<<<<< HEAD assert self.python_kernel_name is not None assert self.op_overload is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) wrapper.generate_fallback_kernel_with_runtime_lookup( self.get_name(), self.python_kernel_name, @@ -7797,11 +9507,15 @@ def is_number(t: torch.JitType) -> bool: if isinstance(self.layout, Layout): self.codegen_size_asserts(wrapper) self.codegen_alignment_asserts(wrapper) +<<<<<<< HEAD self.codegen_memory_tracking(wrapper) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.codegen_unbacked_symbol_defs(wrapper) @staticmethod +<<<<<<< HEAD def tensor_to_layout(output: torch.Tensor) -> FixedLayout: is_pinned = False try: @@ -7809,11 +9523,15 @@ def tensor_to_layout(output: torch.Tensor) -> FixedLayout: except RuntimeError: # dispatch not implemented pass +======= + def tensor_to_layout(output: torch.Tensor): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return FixedLayout( output.device, output.dtype, convert_shape_to_inductor(output.size()), convert_shape_to_inductor(output.stride()), +<<<<<<< HEAD is_pinned=is_pinned, ) @@ -7826,6 +9544,16 @@ def create(cls, kernel: _OpOverloads, *args: Any, **kwargs: Any) -> FallbackKern else: context = nullcontext() +======= + ) + + @classmethod + def create(cls, kernel, *args, **kwargs): # type: ignore[no-untyped-def] + fake_incorrect_kernels = (aten._fused_moving_avg_obs_fq_helper_functional,) + context: AbstractContextManager[None] = ( + V.graph.fake_mode if kernel not in fake_incorrect_kernels else nullcontext() # type: ignore[assignment] + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with context: ( example_output, @@ -7868,7 +9596,11 @@ def create(cls, kernel: _OpOverloads, *args: Any, **kwargs: Any) -> FallbackKern unbacked_bindings=unbacked_bindings, ) +<<<<<<< HEAD def generate_output(output: Any, indices: list[tuple[Any, int]]) -> Any: +======= + def generate_output(output, indices): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(output, (list, tuple)): return type(output)( generate_output(output[i], indices + [(type(output), i)]) @@ -7903,15 +9635,24 @@ def generate_output(output: Any, indices: list[tuple[Any, int]]) -> Any: return None outputs = generate_output(example_output, []) +<<<<<<< HEAD if isinstance(outputs, (list, tuple)): packed.outputs = outputs elif isinstance(outputs, dict): packed.outputs = tuple(outputs) +======= + if isinstance(outputs, (list, tuple, dict)): + packed.outputs = outputs # type: ignore[assignment] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: packed.outputs = [outputs] return outputs +<<<<<<< HEAD def apply_constraint(self) -> None: +======= + def apply_constraint(self): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return super().apply_constraint() @@ -7924,6 +9665,7 @@ def should_allocate(self) -> bool: def get_inputs_that_alias_output(self) -> Sequence[str]: # Signal to codegen that our output buffer isn't safe to reuse +<<<<<<< HEAD return [self.input_name(0)] def __init__( @@ -7935,6 +9677,19 @@ def __init__( unflatten_args: Callable[..., Any], *, unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None, +======= + return [self.inputs[0].get_name()] + + def __init__( # type: ignore[no-untyped-def] + self, + layout, + kernel, + tensor_args, + nontensor_args, + unflatten_args, + *, + unbacked_bindings=None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: super().__init__( layout, @@ -7946,6 +9701,7 @@ def __init__( ) +<<<<<<< HEAD class MemoryCheckKernel(FallbackKernel): """ Custom kernel for memory checking that generates direct function calls @@ -7971,6 +9727,8 @@ def codegen(self, wrapper: PythonWrapperCodegen) -> None: wrapper.writeline(call) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @ir_dataclass class MultiOutputLayout(OutputSpec): device: torch.device @@ -7980,18 +9738,31 @@ def get_device(self) -> Optional[torch.device]: class MultiOutput(ExternKernel): +<<<<<<< HEAD def codegen(self, wrapper: PythonWrapperCodegen) -> None: +======= + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) wrapper.codegen_multi_output(self) if not self.skip_size_stride_alignment_checks: self.codegen_size_asserts(wrapper) self.codegen_alignment_asserts(wrapper) +<<<<<<< HEAD def __init__( self, layout: OutputSpec, input: IRNode, indices: list[tuple[Any, ...]], skip_size_stride_alignment_checks: bool = False, +======= + def __init__( # type: ignore[no-untyped-def] + self, + layout: OutputSpec, + input, + indices: list[tuple[Any, ...]], + skip_size_stride_alignment_checks=False, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: super().__init__(None, layout, [input], ()) self.name = V.graph.register_buffer(self) @@ -8002,6 +9773,7 @@ def __init__( def get_free_symbol_uses( self, unbacked_only: bool = False ) -> OrderedSet[sympy.Symbol]: +<<<<<<< HEAD input_node = self.inputs[0] assert isinstance(input_node, IRNode), input_node return input_node.get_free_symbol_uses(unbacked_only) @@ -8010,6 +9782,16 @@ def should_allocate(self) -> bool: return len(self.inputs) == 1 and ( isinstance(self.inputs[0], CppTemplateBuffer) # Grouped GEMM ) +======= + return self.inputs[0].get_free_symbol_uses(unbacked_only) + + def should_allocate(self) -> bool: + if len(self.inputs) == 1 and ( + isinstance(self.inputs[0], CppTemplateBuffer) # Grouped GEMM + ): + return True + return False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_inputs_that_alias_output(self) -> Sequence[str]: return [ @@ -8067,6 +9849,7 @@ def freeze_layout(self) -> None: return self.data.freeze_layout() def freeze_layout_with_stride_order( +<<<<<<< HEAD self, order: Sequence[int], allow_padding: bool = False ) -> None: return self.data.freeze_layout_with_stride_order(order, allow_padding) @@ -8079,6 +9862,20 @@ def freeze_layout_with_same_order(self, stride: Sequence[_IntLike]) -> None: def freeze_layout_with_exact_strides( self, exact_strides: Sequence[_IntLike], allow_padding: bool = False +======= + self, order: list[int], allow_padding: bool = False + ) -> None: + return self.data.freeze_layout_with_stride_order(order, allow_padding) + + def freeze_layout_with_fill_order(self, order: list[int]) -> None: + return self.data.freeze_layout_with_fill_order(order) + + def freeze_layout_with_same_order(self, stride: list[_IntLike]) -> None: + return self.data.freeze_layout_with_same_order(stride) + + def freeze_layout_with_exact_strides( + self, exact_strides: list[_IntLike], allow_padding: bool = False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: return self.data.freeze_layout_with_exact_strides(exact_strides, allow_padding) @@ -8097,7 +9894,11 @@ def get_storage_numel(self) -> _IntLike: def get_reduction_type(self) -> Optional[str]: return self.data.get_reduction_type() +<<<<<<< HEAD def get_reduction_size(self) -> Sequence[Expr]: +======= + def get_reduction_size(self) -> Sequence[sympy.Expr]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.data.get_reduction_size() def is_extern(self) -> bool: @@ -8150,7 +9951,11 @@ def get_size(self) -> Sequence[Expr]: return self.data.get_size() @property +<<<<<<< HEAD def dtype(self) -> torch.dtype: +======= + def dtype(self): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.data.dtype def __str__(self) -> str: @@ -8175,37 +9980,63 @@ def __str__(self) -> str: class TensorBox(MutableBox): @staticmethod +<<<<<<< HEAD def create(data: IRNode) -> Union[TensorBox, ShapeAsConstantBuffer]: +======= + def create(data): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(data, ShapeAsConstantBuffer): return data return TensorBox(StorageBox(data)) class StorageBox(MutableBox): +<<<<<<< HEAD """ StorageBox allow in-place mutation of Tensors """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def is_input_buffer(self) -> bool: if isinstance(self.data, (InputBuffer, ReinterpretView)): return self.data.get_name() in V.graph.graph_inputs return False +<<<<<<< HEAD def is_module_buffer(self) -> bool: +======= + def is_module_buffer(self): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ( isinstance(self.data, (ConstantBuffer)) and self.data.get_name() in V.graph.constants ) def realize(self) -> Optional[str]: +<<<<<<< HEAD if IRNode.is_realized_node(self.data): return self.data.get_name() +======= + if isinstance( + self.data, + ( + ComputedBuffer, + InputsKernel, + InputBuffer, + ReinterpretView, + TemplateBuffer, + ), + ): + return self.data.get_name() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(self.data, (Pointwise, Reduction, Scan, Sort)), type( self.data ) origin_node = self.data.get_origin_node() traceback = self.data.get_traceback() +<<<<<<< HEAD device = self.data.get_device() assert device is not None @@ -8216,6 +10047,14 @@ def realize(self) -> Optional[str]: dtype=self.data.get_dtype(), size=self.data.get_size(), is_pinned=False, +======= + self.data = ComputedBuffer( + name=None, + layout=FlexibleLayout( + device=self.data.get_device(), + dtype=self.data.get_dtype(), + size=self.data.get_size(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), data=self.data, ) @@ -8236,15 +10075,19 @@ def realize_hint(self) -> None: ): self.realize() +<<<<<<< HEAD def has_accumulated_enough_reads_by_size(self, threshold: int) -> bool: return ( sum(V.graph.get_dep_size_hint(dep) for dep in self.get_reads()) > threshold ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def has_exceeded_max_reads(self) -> bool: return isinstance(self.data, Pointwise) and ( self.num_reads() > config.realize_acc_reads_threshold or self.has_large_inner_fn() +<<<<<<< HEAD or ( config.realize_acc_reads_size_threshold is not None and self.has_accumulated_enough_reads_by_size( @@ -8254,6 +10097,11 @@ def has_exceeded_max_reads(self) -> bool: ) def should_realize_on_reuse(self, users: int) -> bool: +======= + ) + + def should_realize_on_reuse(self, users): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ A heuristic to decide if we should realize a tensor that is used multiple times. @@ -8275,7 +10123,11 @@ def mark_reuse(self, users: int) -> None: if self.should_realize_on_reuse(users): self.realize() +<<<<<<< HEAD def num_reads(self) -> int: +======= + def num_reads(self): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.data.num_reads() @@ -8302,11 +10154,19 @@ class InvokeSubgraph(ExternKernel): """ subgraph: Optional[Subgraph] = None +<<<<<<< HEAD operands: Optional[Sequence[IRNode]] = None outputs: Optional[Sequence[IRNode]] = None def __init__( self, subgraph: Subgraph, operands: Sequence[IRNode], layout: MultiOutputLayout +======= + operands: Optional[list[TensorBox]] = None + outputs: Optional[list[MultiOutput]] = None + + def __init__( + self, subgraph: Subgraph, operands: list[TensorBox], layout: MultiOutputLayout +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: super().__init__( name=None, @@ -8318,11 +10178,15 @@ def __init__( V.graph.register_operation(self) @classmethod +<<<<<<< HEAD def create( cls, subgraph: Subgraph, *operands: IRNode ) -> list[Union[ShapeAsConstantBuffer, NoneAsConstantBuffer, MultiOutput]]: """For each operand, get a realized input, force it to have the same strides as the subgraph inputs, then use an InvokeSubgraph""" +======= + def create(cls, subgraph: Subgraph, *operands): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .lowering import constrain_to_fake_tensor # TODO(anijain2305) - Support sym expr as operands in future. @@ -8341,11 +10205,19 @@ def create( # Realize the inputs. Also intermediates can have different strides than # the inputs of the subgraph. So, force the intermediates to have same # strides as that of subgraph inputs. +<<<<<<< HEAD operands: list[IRNode] = [cls.realize_input(x) for x in operands] new_operands: list[IRNode] = [] for idx, operand in enumerate(operands): if isinstance(operand, (ShapeAsConstantBuffer, GeneratorState)): +======= + operands = [cls.realize_input(x) for x in operands] + + new_operands = [] + for idx, operand in enumerate(operands): + if isinstance(operand, ShapeAsConstantBuffer): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) new_operands.append(operand) else: new_operands.append( @@ -8374,12 +10246,17 @@ def create( device = operand.get_device() break assert device is not None +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) invoke_subgraph = InvokeSubgraph( subgraph=subgraph, operands=operands, layout=MultiOutputLayout(device=device), ) +<<<<<<< HEAD def create_output( output: IRNode, ind: int ) -> Union[ShapeAsConstantBuffer, NoneAsConstantBuffer, MultiOutput]: @@ -8397,32 +10274,64 @@ def create_output( stride=output.get_stride(), offset=output.get_layout().offset, is_pinned=output.get_layout().is_pinned, +======= + def create_output(output: IRNode, ind: int): # type: ignore[no-untyped-def] + if isinstance(output, (ShapeAsConstantBuffer, NoneAsConstantBuffer)): + return output + else: + return MultiOutput( + FixedLayout( + device=output.get_device(), # type: ignore[arg-type] + dtype=output.get_dtype(), + size=output.get_size(), # type: ignore[arg-type] + stride=output.get_stride(), # type: ignore[arg-type] + offset=output.get_layout().offset, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), invoke_subgraph, # type: ignore[has-type] [(list, ind)], skip_size_stride_alignment_checks=True, ) +<<<<<<< HEAD outs = [create_output(output, i) for i, output in enumerate(outputs)] invoke_subgraph.outputs = outs # type: ignore[assignment] return outs def codegen(self, wrapper: PythonWrapperCodegen) -> None: +======= + outputs = [create_output(output, i) for i, output in enumerate(outputs)] + invoke_subgraph.outputs = outputs + return outputs + + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) wrapper.codegen_invoke_subgraph(self) @ir_dataclass(frozen=False) class Conditional(ExternKernel): predicate: Optional[IRNode] = None +<<<<<<< HEAD operands: Optional[Sequence[IRNode]] = None true_subgraph: Optional[Subgraph] = None false_subgraph: Optional[Subgraph] = None outputs: Optional[Sequence[MultiOutput]] = None +======= + operands: Optional[list[Union[TensorBox, ShapeAsConstantBuffer]]] = None + true_subgraph: Optional[Subgraph] = None + false_subgraph: Optional[Subgraph] = None + outputs: Optional[list[MultiOutput]] = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __init__( self, predicate: IRNode, +<<<<<<< HEAD operands: Sequence[IRNode], +======= + operands: list[Union[TensorBox, ShapeAsConstantBuffer]], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) true_subgraph: Subgraph, false_subgraph: Subgraph, layout: MultiOutputLayout, @@ -8433,7 +10342,11 @@ def __init__( self.true_subgraph = true_subgraph self.false_subgraph = false_subgraph +<<<<<<< HEAD sym_args, tensor_args = _split_by_sym_type([predicate, *operands]) +======= + sym_args, tensor_args = _split_by_sym_type([predicate] + operands) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__( name=None, @@ -8447,6 +10360,7 @@ def __init__( self.name = V.graph.register_buffer(self) V.graph.register_operation(self) +<<<<<<< HEAD @staticmethod def _maybe_expr(s: Union[int, torch.SymInt]) -> Union[int, sympy.Expr]: if isinstance(s, int): @@ -8455,11 +10369,16 @@ def _maybe_expr(s: Union[int, torch.SymInt]) -> Union[int, sympy.Expr]: @classmethod def create( +======= + @classmethod + def create( # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cls, predicate: TensorBox, true_fn: Subgraph, false_fn: Subgraph, operands: list[Union[TensorBox, ShapeAsConstantBuffer]], +<<<<<<< HEAD ) -> Sequence[IRNode]: """Create a Sequence of IRNodes from a conditional statement (see .lowering.cond)""" predicate = cls.realize_input(predicate) @@ -8469,6 +10388,13 @@ def create( assert isinstance(fx_operands, Sequence), type(fx_operands) assert all(isinstance(n, Node) for n in fx_operands) fake_operands = [cast(Node, x).meta["val"] for x in fx_operands] +======= + ): + predicate = cls.realize_input(predicate) + operands = [cls.realize_input(x) for x in operands] + fx_operands = V.graph.current_node.args[-1] + fake_operands = [x.meta["val"] for x in fx_operands] # type: ignore[union-attr] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for subgraph in (true_fn, false_fn): if subgraph.graph is None: @@ -8481,10 +10407,15 @@ def create( with V.set_graph_handler(subgraph.graph): subgraph.graph.run(*fake_operands) +<<<<<<< HEAD assert true_fn.graph is not None assert false_fn.graph is not None true_outputs = true_fn.graph.graph_outputs false_outputs = false_fn.graph.graph_outputs +======= + true_outputs = true_fn.graph.graph_outputs # type: ignore[union-attr] + false_outputs = false_fn.graph.graph_outputs # type: ignore[union-attr] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for name, outputs in (("true_fn", true_outputs), ("false_fn", false_outputs)): if _has_aliased_buffers(true_outputs): @@ -8519,6 +10450,7 @@ def create( unbacked_bindings=unbacked_bindings, ) +<<<<<<< HEAD outputs = [ MultiOutput( FixedLayout( @@ -8530,6 +10462,21 @@ def create( ], offset=output.get_layout().offset, is_pinned=output.get_layout().is_pinned, +======= + def _maybe_expr(s: Union[int, torch.SymInt]) -> Union[int, sympy.expr]: + if isinstance(s, int): + return s + return s.node.expr + + outputs = [ + MultiOutput( + FixedLayout( + device=output.get_device(), + dtype=output.get_dtype(), + size=[_maybe_expr(sz) for sz in merged_output.size()], + stride=[_maybe_expr(sz) for sz in merged_output.stride()], + offset=output.get_layout().offset, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), conditional, [(list, i)], @@ -8544,7 +10491,11 @@ def create( conditional.outputs = outputs # type: ignore[assignment] return outputs +<<<<<<< HEAD def codegen(self, wrapper: PythonWrapperCodegen) -> None: +======= + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) wrapper.codegen_conditional(self) wrapper.codegen_unbacked_symbol_defs_for_outputs( self.get_name(), self.outputs, getattr(self, "unbacked_bindings", {}) @@ -8556,7 +10507,11 @@ def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: V.graph.sizevars.shape_env, unbacked_bindings ) assert resolved is not None +<<<<<<< HEAD return OrderedSet(resolved.keys()) +======= + return resolved.keys() # type: ignore[return-value] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: return OrderedSet() @@ -8577,6 +10532,7 @@ def _split_by_sym_type( @ir_dataclass(frozen=False) class WhileLoop(ExternKernel): +<<<<<<< HEAD """The IR node for while_loop and while_loop_stack_output. It supports input mutation.""" carried_inputs: Optional[Sequence[IRNode]] = None @@ -8594,28 +10550,51 @@ def __init__( layout: MultiOutputLayout, unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]], stack_output: bool, +======= + carried_inputs: Optional[list[Union[TensorBox, ShapeAsConstantBuffer]]] = None + additional_inputs: Optional[list[Union[TensorBox, ShapeAsConstantBuffer]]] = None + cond_subgraph: Optional[Subgraph] = None + body_subgraph: Optional[Subgraph] = None + outputs: Optional[list[MultiOutput]] = None + + def __init__( + self, + carried_inputs: list[Union[TensorBox, ShapeAsConstantBuffer]], + additional_inputs: list[Union[TensorBox, ShapeAsConstantBuffer]], + cond_subgraph: Subgraph, + body_subgraph: Subgraph, + layout: MultiOutputLayout, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: self.carried_inputs = carried_inputs self.additional_inputs = additional_inputs self.cond_subgraph = cond_subgraph self.body_subgraph = body_subgraph +<<<<<<< HEAD sym_args, tensor_args = _split_by_sym_type( [*carried_inputs, *additional_inputs] ) +======= + sym_args, tensor_args = _split_by_sym_type(carried_inputs + additional_inputs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__( name=None, layout=layout, inputs=tensor_args, constant_args=sym_args, ) +<<<<<<< HEAD if unbacked_bindings is not None: self.unbacked_bindings = unbacked_bindings self.stack_output = stack_output +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.name = V.graph.register_buffer(self) V.graph.register_operation(self) +<<<<<<< HEAD # Accidental aliasing can be created due to cse, where the empty buffers we # allocated for backward to use gets csed into the same buffer in function fx_graph_cse. # See test_scan_multiple_layers_gradient for a concrete example. @@ -8666,6 +10645,22 @@ def _require_exact_strides( tensor_boxes: Sequence[IRNode], fake_tensors: list[Union[int, torch.SymInt, torch.Tensor]], ) -> list[IRNode]: +======= + @classmethod + def create( # type: ignore[no-untyped-def] + cls, + cond_fn: Subgraph, + body_fn: Subgraph, + carried_inputs: list[Union[TensorBox, ShapeAsConstantBuffer]], + additional_inputs: list[Union[TensorBox, ShapeAsConstantBuffer]], + ): + from torch._higher_order_ops.utils import check_input_alias_and_mutation + + def _require_exact_strides( + tensor_boxes: list[TensorBox | ShapeAsConstantBuffer], + fake_tensors: list[Union[int, torch.SymInt, torch.Tensor]], + ) -> list[TensorBox | ShapeAsConstantBuffer]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert len(tensor_boxes) == len(fake_tensors) ret = [] for tb, fk in zip(tensor_boxes, fake_tensors): @@ -8686,6 +10681,7 @@ def _require_exact_strides( fake_carried_inputs = [x.meta["val"] for x in fx_carried_inputs] # type: ignore[union-attr] fake_additional_inputs = [x.meta["val"] for x in fx_additional_inputs] # type: ignore[union-attr] +<<<<<<< HEAD carried_inputs_ = [cls.realize_input(x) for x in carried_inputs] carried_inputs_ = WhileLoop._clone_aliased_inputs(carried_inputs_) carried_inputs_ = _require_exact_strides(carried_inputs_, fake_carried_inputs) @@ -8694,11 +10690,23 @@ def _require_exact_strides( additional_inputs_, fake_additional_inputs ) all_inputs = carried_inputs_ + additional_inputs_ +======= + carried_inputs = [cls.realize_input(x) for x in carried_inputs] + carried_inputs = _require_exact_strides(carried_inputs, fake_carried_inputs) + additional_inputs = [cls.realize_input(x) for x in additional_inputs] + additional_inputs = _require_exact_strides( + additional_inputs, fake_additional_inputs + ) + all_inputs = carried_inputs + additional_inputs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for subgraph in (cond_fn, body_fn): if subgraph.graph is None: # create and lower subgraphs +<<<<<<< HEAD assert isinstance(fx_all_inputs, Sequence), type(fx_all_inputs) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) subgraph.graph = V.graph.make_subgraph( gm=subgraph.graph_module, example_inputs=fx_all_inputs, # type: ignore[arg-type] @@ -8717,6 +10725,7 @@ def _require_exact_strides( fake_carried_inputs ) subgraph.graph.graph_outputs = _require_exact_strides( # type: ignore[assignment] +<<<<<<< HEAD subgraph.graph.graph_outputs, fake_carried_inputs, ) @@ -8724,6 +10733,14 @@ def _require_exact_strides( assert cond_fn.graph and body_fn.graph cond_outputs = cond_fn.graph.graph_outputs body_outputs = body_fn.graph.graph_outputs +======= + subgraph.graph.graph_outputs, # type: ignore[arg-type] + fake_carried_inputs, + ) + + cond_outputs = cond_fn.graph.graph_outputs # type: ignore[union-attr] + body_outputs = body_fn.graph.graph_outputs # type: ignore[union-attr] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if _has_aliased_buffers(body_outputs): raise AssertionError( @@ -8745,6 +10762,7 @@ def _require_exact_strides( device = all_inputs[0].get_device() assert device is not None # to make linter happy +<<<<<<< HEAD # make sure carried_inputs_ and body outputs are structurally equivalent assert len(carried_inputs_) == len(body_outputs), ( carried_inputs_, @@ -8777,12 +10795,39 @@ def _guard_list_equals( while_loop = WhileLoop( carried_inputs=carried_inputs_, additional_inputs=additional_inputs_, +======= + # make sure carried_inputs and body outputs are structurally equivalent + assert len(carried_inputs) == len(body_outputs), (carried_inputs, body_outputs) + for i, (op, bo) in enumerate(zip(carried_inputs, body_outputs)): + + def _guard_list_equals( + lhs_exprs: Sequence[Union[int, Any]], + rhs_exprs: Sequence[Union[int, Any]], + ) -> None: + for lhs, rhs in zip(lhs_exprs, rhs_exprs): + V.graph.sizevars.guard_equals(lhs, rhs) + + _guard_list_equals(op.get_size(), bo.get_size()) + _guard_list_equals(op.get_stride(), bo.get_stride()) + # assume all carried_inputs and outputs are on the same device + # as the MultiOutputLayout below requires single device + assert op.get_device() == bo.get_device(), (i, op, bo, device) + assert op.get_dtype() == bo.get_dtype(), (i, op, bo) + assert op.get_layout().offset == bo.get_layout().offset, (i, op, bo) + + while_loop = WhileLoop( + carried_inputs=carried_inputs, + additional_inputs=additional_inputs, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cond_subgraph=cond_fn, body_subgraph=body_fn, # asserted above that there is at least one operand layout=MultiOutputLayout(device=device), +<<<<<<< HEAD unbacked_bindings=unbacked_bindings, stack_output=stack_output, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) assert body_fn.graph is not None and isinstance( @@ -8795,6 +10840,7 @@ def _guard_list_equals( )[3] mutated_idx_set = OrderedSet(mutated_idxs) mutated_inputs = [all_inputs[idx] for idx in mutated_idx_set] +<<<<<<< HEAD # Create all outputs first mutated_inputs_iter = iter(mutated_inputs) @@ -8844,6 +10890,39 @@ def _guard_list_equals( while_loop.outputs.append(multi_out) all_outputs.append(multi_out) +======= + real_outputs = { + idx: out + for idx, out in enumerate(body_outputs) + if idx not in mutated_idx_set + } + real_outputs = [ + MultiOutput( + FixedLayout( + device=output.get_device(), + dtype=output.get_dtype(), + size=output.get_size(), + stride=output.get_stride(), + offset=output.get_layout().offset, + ), + while_loop, + [(list, idx)], + ) + for idx, output in real_outputs.items() + ] + while_loop.outputs = real_outputs + while_loop.mutation_outputs = [ + MutationOutput(inp.layout, inp, while_loop) # type: ignore[union-attr] + for inp in mutated_inputs + ] + + outputs_iter = iter(real_outputs) + mutated_inputs_iter = iter(mutated_inputs) + all_outputs = [ + next(mutated_inputs_iter) if idx in mutated_idx_set else next(outputs_iter) + for idx in range(len(body_outputs)) + ] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for inp, out in zip(carried_inputs, all_outputs): if inp.get_name() in V.graph.graph_inputs: # if a carried input of the while_loop is a graph input, @@ -8854,6 +10933,7 @@ def _guard_list_equals( V.graph.never_reuse_buffers.add(out.get_name()) return all_outputs +<<<<<<< HEAD def codegen(self, wrapper: PythonWrapperCodegen) -> None: wrapper.codegen_while_loop(self, self.stack_output) wrapper.codegen_unbacked_symbol_defs_for_outputs( @@ -8882,6 +10962,23 @@ def __init__( kwargs: Optional[dict[str, Any]] = None, *, unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None, +======= + def codegen(self, wrapper) -> None: # type: ignore[no-untyped-def] + wrapper.codegen_while_loop(self) + + +class EffectfulKernel(FallbackKernel): + def __init__( # type: ignore[no-untyped-def] + self, + layout, + kernel, + tensor_args, + nontensor_args, + unflatten_args, + kwargs=None, + *, + unbacked_bindings=None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: super().__init__( layout, @@ -8919,10 +11016,14 @@ def has_side_effects(self) -> bool: class NonTensorObj(IRNode): +<<<<<<< HEAD def get_free_symbol_uses( self, unbacked_only: bool = False ) -> OrderedSet[sympy.Symbol]: return OrderedSet() +======= + pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @ir_dataclass @@ -8948,8 +11049,12 @@ def get_real_obj(self) -> torch.ScriptObject: def get_buf_bytes(self) -> int: # Returns the sum of all tensors in the flattened object real_script_obj = self.get_real_obj() +<<<<<<< HEAD assert hasattr(real_script_obj, "__obj_flatten__") flat_dict = dict(real_script_obj.__obj_flatten__()) +======= + flat_dict = dict(real_script_obj.__obj_flatten__()) # type: ignore[attr-defined] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) flat_elems = pytree.tree_flatten(flat_dict)[0] flat_sizes = [ x.element_size() * x.numel() @@ -8985,10 +11090,14 @@ def set_cpp_kernel_name(self, cpp_kernel_name: Optional[str] = None) -> None: "Setting cpp kernel needs a valid op_overload" ) kernel = self.op_overload +<<<<<<< HEAD if cpp_kernel_name is not None: self.cpp_kernel_name = cpp_kernel_name else: self.cpp_kernel_name = kernel._schema.name +======= + self.cpp_kernel_name = kernel._schema.name +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.ordered_kwargs_for_cpp_kernel = [ x.name for x in kernel._schema.arguments if x.kwarg_only @@ -9001,12 +11110,17 @@ def set_cpp_kernel_name(self, cpp_kernel_name: Optional[str] = None) -> None: # the constraints, we model collective -> wait_tensor as as two-step # mutation of the input buffers. @classmethod +<<<<<<< HEAD def create_inplace( cls, kernel: _OpOverloads, inputs: Union[IRNode, list[IRNode]], *args: Any, **kwargs: Any, +======= + def create_inplace( # type: ignore[no-untyped-def] + cls, kernel, inputs: Union[TensorBox, list[TensorBox]], *args, **kwargs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: with V.graph.fake_mode: ( @@ -9066,6 +11180,7 @@ def create_inplace( # TODO(yifu): add a pre-grad pass to validate the correctness of collective # usage in the user program. @classmethod +<<<<<<< HEAD def create_out_of_place( cls, kernel: _OpOverloads, @@ -9073,6 +11188,11 @@ def create_out_of_place( *args: Any, **kwargs: Any, ) -> Union[list[MultiOutput], _CollectiveKernel]: +======= + def create_out_of_place( # type: ignore[no-untyped-def] + cls, kernel, inputs: Union[TensorBox, list[TensorBox]], *args, **kwargs + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with V.graph.fake_mode: ( example_output, @@ -9087,7 +11207,10 @@ def create_out_of_place( if isinstance(example_output, list): device = cls.find_device(tensor_args, example_output) +<<<<<<< HEAD assert device is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) packed = cls( MultiOutputLayout(device=device), kernel, @@ -9125,6 +11248,7 @@ def create_out_of_place( return packed +<<<<<<< HEAD class _AllReduce_Kernel(_CollectiveKernel): def __init__( self, @@ -9225,6 +11349,14 @@ def get_volatile_reads(self) -> Sequence[IRNode]: i = inp.inputs[0] assert isinstance(i, IRNode), type(i) return [i] +======= +class _WaitKernel(_CollectiveKernel): + def get_volatile_reads(self): # type: ignore[no-untyped-def] + inp = self.inputs[0] + if isinstance(inp, _CollectiveKernel): + # Out-of-place single-output + return [inp.inputs[0]] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif isinstance(inp, MultiOutput): # This can be two things: # 1. Out-of-place multi-output coll @@ -9242,7 +11374,11 @@ def get_volatile_reads(self) -> Sequence[IRNode]: return [] @classmethod +<<<<<<< HEAD def create_wait(cls, kernel: _OpOverloads, inp: TensorBox) -> None: +======= + def create_wait(cls, kernel, inp: TensorBox) -> None: # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with V.graph.fake_mode: ( _example_output, diff --git a/torch/_inductor/jagged_lowerings.py b/torch/_inductor/jagged_lowerings.py index 83848c5a9612c..9048c75eb334d 100644 --- a/torch/_inductor/jagged_lowerings.py +++ b/torch/_inductor/jagged_lowerings.py @@ -5,7 +5,12 @@ import torch +<<<<<<< HEAD from .ir import Pointwise, ShapeAsConstantBuffer, TensorBox +======= +from .ir import Pointwise, TensorBox +from .lowering import fallback_handler, is_integer_type, register_lowering +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .virtualized import ops @@ -26,7 +31,11 @@ def get_inverse_offsets( offsets: TensorBox, jagged_len: Union[int, sympy.Expr], realize: bool = True, +<<<<<<< HEAD ) -> Union[TensorBox, ShapeAsConstantBuffer]: +======= +) -> TensorBox: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Returns "inverse_offsets" - the inverse of the offsets array. offsets maps batch index (dense) to jagged index (i.e. offset into jagged tensor). @@ -108,9 +117,12 @@ def jagged_idx_to_dense_idx( def register_jagged_ops(): +<<<<<<< HEAD # Avoid circular import by importing here from .lowering import fallback_handler, is_integer_type, register_lowering +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # pyre-ignore[56] @register_lowering(torch.ops.aten._jagged_to_padded_dense_forward.default) def _jagged_to_padded_dense_forward( @@ -118,7 +130,11 @@ def _jagged_to_padded_dense_forward( jagged_offsets: list[TensorBox], max_lengths: list[int], # list of ints/SymInts padding_value: float = 0.0, +<<<<<<< HEAD ) -> Union[TensorBox, ShapeAsConstantBuffer]: +======= + ) -> TensorBox: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device = jagged_values.get_device_or_error() dtype = jagged_values.get_dtype() @@ -188,7 +204,11 @@ def _dense_to_jagged_forward_impl( dense: TensorBox, jagged_offsets: list[TensorBox], jagged_len: Optional[int] = None, +<<<<<<< HEAD ) -> Union[TensorBox, ShapeAsConstantBuffer]: +======= + ) -> TensorBox: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device = dense.get_device_or_error() dtype = dense.get_dtype() @@ -261,7 +281,11 @@ def _dense_to_jagged_forward( dense: TensorBox, jagged_offsets: list[TensorBox], jagged_len: Optional[int] = None, +<<<<<<< HEAD ) -> Union[TensorBox, ShapeAsConstantBuffer]: +======= + ) -> TensorBox: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return _dense_to_jagged_forward_impl( fallback_op=torch.ops.aten._padded_dense_to_jagged_forward.default, dense=dense, diff --git a/torch/_inductor/kernel/__init__.py b/torch/_inductor/kernel/__init__.py index 9668f1b6c6e1d..d410e701aa805 100644 --- a/torch/_inductor/kernel/__init__.py +++ b/torch/_inductor/kernel/__init__.py @@ -1 +1,5 @@ +<<<<<<< HEAD from . import flex, mm, mm_common, mm_plus_mm +======= +from . import mm, mm_common, mm_plus_mm +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_inductor/kernel/bmm.py b/torch/_inductor/kernel/bmm.py index a843c7369fb53..2bb4f4c53ab58 100644 --- a/torch/_inductor/kernel/bmm.py +++ b/torch/_inductor/kernel/bmm.py @@ -1,13 +1,19 @@ # mypy: allow-untyped-defs import logging +<<<<<<< HEAD from typing import TYPE_CHECKING +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch from torch._dynamo.utils import counters from torch._inductor.codegen.rocm.ck_universal_gemm_template import CKGemmTemplate from .. import ir, lowering as L +<<<<<<< HEAD from ..kernel_inputs import MMKernelInputs +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from ..select_algorithm import ( autotune_select_algorithm, ExternKernelChoice, @@ -23,12 +29,25 @@ use_triton_template, ) from ..virtualized import V +<<<<<<< HEAD from .mm_common import _is_static_problem, is_batch_stride_largest_or_zero, mm_args if TYPE_CHECKING: from ..ir import ChoiceCaller +======= +from .mm_common import ( + _is_static_problem, + addmm_epilogue, + is_batch_stride_largest, + mm_args, + mm_config_kwargs, + mm_options, +) + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) log = logging.getLogger(__name__) aten = torch.ops.aten @@ -38,6 +57,16 @@ def bmm_grid(b, m, n, meta, *, cdiv): return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), b, 1) +<<<<<<< HEAD +======= +def _is_large_block_for_cpu(m, n, k): + # Thresholds are experimentally determined to reduce Triton CPU compile times + if m > 128 or n > 128 or k > 128: + return True + return m * n > 2**12 + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bmm_template = TritonTemplate( name="bmm", grid=bmm_grid, @@ -166,6 +195,7 @@ def may_require_contiguous(t, meta_t): meta_mat2 = V.graph.current_node.args[1] mat2 = may_require_contiguous(mat2, meta_mat2) +<<<<<<< HEAD # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that m, n, k, layout, mat1, mat2 = mm_args( mat1, mat2, layout=layout, out_dtype=out_dtype @@ -174,6 +204,11 @@ def may_require_contiguous(t, meta_t): # Create MMKernelInputs for BMM at the top kernel_inputs = MMKernelInputs([mat1, mat2]) +======= + m, n, k, layout, mat1, mat2 = mm_args( + mat1, mat2, layout=layout, out_dtype=out_dtype + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # below is for getting an overview logging info of inductor mms batch_size = mat1.get_size()[0] # Extract batch dimension @@ -189,6 +224,7 @@ def may_require_contiguous(t, meta_t): layout, ) +<<<<<<< HEAD aten_handler: ExternKernelChoice = aten_bmm aten_extra_kwargs = {} if out_dtype: @@ -228,6 +264,47 @@ def may_require_contiguous(t, meta_t): CUTLASS3xGemmTemplate.add_cutlass_gemm_choices( choices, layout, kernel_inputs.nodes() ) # type: ignore[arg-type] +======= + if out_dtype: + assert mat1.get_device().type == "cuda", "out_dtype is only supported for CUDA" + aten_func = aten_bmm_dtype.bind((mat1, mat2), layout, out_dtype=out_dtype) + else: + aten_func = aten_bmm.bind((mat1, mat2), layout) + + # options to tune from + choices = [aten_func] if use_aten_gemm_kernels() else [] + + device_type = ir.get_device_type(mat1) + bmm_configs = V.choices.get_base_mm_configs(device_type) + + dtype = mat1.get_dtype() + if use_triton_template(layout): + # TODO: add out_dtype support for Triton Template + assert out_dtype is None, "out_dtype is not supported for Triton" + for config in bmm_configs( + m, + n, + k, + **mm_config_kwargs(device_type, _is_large_block_for_cpu, dtype.itemsize), + ): + bmm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + ) + _, is_nonzero = _is_static_problem(layout) + batch_stride_largest = is_batch_stride_largest(mat1, mat2, layout) + if ( + batch_stride_largest + and is_nonzero + and use_cutlass_template(layout, m, n, k) + and _use_cutlass_for_op("bmm") + ): + from ..codegen.cuda.gemm_template import CUTLASS3xGemmTemplate + + CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2]) # type: ignore[arg-type] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if use_cpp_bmm_template(layout, mat1, mat2): from ..codegen.cpp_bmm_template import CppBmmTemplate @@ -235,6 +312,7 @@ def may_require_contiguous(t, meta_t): CppBmmTemplate.add_choices( choices, layout, +<<<<<<< HEAD kernel_inputs.nodes(), ) @@ -242,10 +320,20 @@ def may_require_contiguous(t, meta_t): CKGemmTemplate.add_ck_gemm_choices(choices, layout, kernel_inputs.nodes()) return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout) +======= + [mat1, mat2], + ) + + if use_ck_gemm_template(layout, m, n, k): + CKGemmTemplate.add_ck_gemm_choices(choices, layout, [mat1, mat2]) + + return autotune_select_algorithm("bmm", choices, [mat1, mat2], layout) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @L.register_lowering(aten.baddbmm) def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): +<<<<<<< HEAD """ Lowering for autotuning aten.mm with different backends (Aten, Triton, CUTLASS, etc.) """ @@ -257,6 +345,10 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): [inp, mat1, mat2], scalars=dict(alpha=alpha, beta=beta) ) +======= + m, n, k, layout, mat1, mat2, inp = mm_args(mat1, mat2, inp, layout=layout) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # below is for getting an overview logging info of inductor mms batch_size = mat1.get_size()[0] counters["aten_mm_info"][f"aten.baddbmm_{batch_size}_{m}_{n}_{k}"] += 1 @@ -271,6 +363,7 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): inp.get_dtype(), layout, ) +<<<<<<< HEAD name = "baddbmm" # options to tune from choices: list[ChoiceCaller] = [] @@ -290,3 +383,31 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): ) return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout) +======= + + # options to tune from + choices = ( + [aten_baddbmm.bind((inp, mat1, mat2), layout, alpha=alpha, beta=beta)] + if use_aten_gemm_kernels() + else [] + ) + + device_type = ir.get_device_type(mat1) + bmm_configs = V.choices.get_base_mm_configs(device_type) + + if use_triton_template(layout): + for config in bmm_configs( + m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu) + ): + bmm_template.maybe_append_choice( + choices, + input_nodes=(inp, mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + prefix_args=1, + epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta), + epilogue_fn_hash=str(["addmm_epilogue", layout.dtype, alpha, beta]), + ) + + return autotune_select_algorithm("baddbmm", choices, [inp, mat1, mat2], layout) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_inductor/kernel/conv.py b/torch/_inductor/kernel/conv.py index 6b9e9a1a32e7f..ef6e0a58eaa88 100644 --- a/torch/_inductor/kernel/conv.py +++ b/torch/_inductor/kernel/conv.py @@ -29,6 +29,10 @@ use_triton_template, ) from ..virtualized import V +<<<<<<< HEAD +======= +from .mm_common import mm_config_kwargs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if TYPE_CHECKING: @@ -60,6 +64,16 @@ def conv3d_grid(n, c, d, h, w, meta, *, cdiv): ) +<<<<<<< HEAD +======= +def _is_large_block_for_cpu(m, n, k): + # Thresholds are experimentally determined to reduce Triton CPU compile times + if m > 256 or n > 256 or k > 256: + return True + return m * n * k > 2**17 + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) LOOP_BODY_2D = """ idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W @@ -430,17 +444,29 @@ def convolution( dilation = tuple(dilation) output_padding = tuple(output_padding) if not isinstance(groups, int): +<<<<<<< HEAD groups = V.graph.sizevars.guard_int(groups) +======= + groups = V.graph.sizevars.evaluate_static_shape(groups) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(groups, int) # Need use hint for triton template since the template does not # work with a dynamic shape. # +<<<<<<< HEAD # No need to guard_int for dilation and output_padding # since the template is only used when dilation is 1 and output_padding # is 0. stride = tuple(V.graph.sizevars.guard_int_seq(stride)) padding = tuple(V.graph.sizevars.guard_int_seq(padding)) +======= + # No need to evaluate_static_shape for dilation and output_padding + # since the template is only used when dilation is 1 and output_padding + # is 0. + stride = tuple(V.graph.sizevars.evaluate_static_shapes(stride)) + padding = tuple(V.graph.sizevars.evaluate_static_shapes(padding)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kwargs: ConvLayoutParams = { "stride": stride, @@ -460,7 +486,13 @@ def convolution( dim=0, ) +<<<<<<< HEAD out_chan, in_chan, *kernel_shape = V.graph.sizevars.guard_int_seq(weight.get_size()) +======= + out_chan, in_chan, *kernel_shape = V.graph.sizevars.evaluate_static_shapes( + weight.get_size() + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Always convert conv1D to 2D for Intel GPU. # Only conv2D can be converted to channel last layout, @@ -529,18 +561,30 @@ def channels_last_conv(): # apply channels last. if V.graph.layout_opt and ndim == 2: V.graph.num_channels_last_conv += 1 +<<<<<<< HEAD x = ir.ExternKernel.require_channels_last(x) # type: ignore[assignment] # TODO maybe we can convert weights to channels last just once before # running the model. weight = ir.ExternKernel.require_channels_last(weight) # type: ignore[assignment] +======= + x = ir.ExternKernel.require_channels_last(x) + # TODO maybe we can convert weights to channels last just once before + # running the model. + weight = ir.ExternKernel.require_channels_last(weight) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) layout = conv_layout(x, weight, None, **kwargs) else: layout = conv_layout(x, weight, None, **kwargs) req_stride_order = ir.get_stride_order( V.graph.sizevars.size_hints(layout.stride) ) +<<<<<<< HEAD x = ir.ExternKernel.require_stride_order(x, req_stride_order) # type: ignore[assignment] weight = ir.ExternKernel.require_stride_order(weight, req_stride_order) # type: ignore[assignment] +======= + x = ir.ExternKernel.require_stride_order(x, req_stride_order) + weight = ir.ExternKernel.require_stride_order(weight, req_stride_order) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ordered_kwargs_for_cpp_kernel = [ "stride", @@ -558,7 +602,11 @@ def channels_last_conv(): args = [x, weight, bias] bias.realize() bias.freeze_layout() +<<<<<<< HEAD V.graph.sizevars.guard_int_seq(bias.get_size()) +======= + V.graph.sizevars.evaluate_static_shapes(bias.get_size()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) choices = [] if torch._inductor.utils._use_conv_autotune_backend("ATEN"): @@ -595,6 +643,10 @@ def channels_last_conv(): sympy_product([x.get_size()[0], *x.get_size()[2:]]), out_chan, in_chan, +<<<<<<< HEAD +======= + **mm_config_kwargs(device_type, _is_large_block_for_cpu), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): if ndim == 2: conv2d_template.maybe_append_choice( diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py new file mode 100644 index 0000000000000..22751fe100479 --- /dev/null +++ b/torch/_inductor/kernel/flex_attention.py @@ -0,0 +1,2763 @@ +# mypy: allow-untyped-defs +"""Triton Implementation of the flex_attention Kernel""" + +import copy +import logging +import math +from collections.abc import Sequence +from dataclasses import dataclass +from enum import auto, Enum +from typing import Any, Optional, Union + +import sympy + +import torch +from torch._inductor.virtualized import V +from torch.utils._ordered_set import OrderedSet +from torch.utils._pytree import tree_map +from torch.utils._sympy.numbers import int_oo +from torch.utils._sympy.value_ranges import ValueRanges + +from ..ir import ( + Buffer, + ComputedBuffer, + ExternKernel, + FixedLayout, + FlexibleLayout, + get_fill_order, + InputBuffer, + IRNode, + MutationLayoutSHOULDREMOVE, + Scatter, + StorageBox, + Subgraph, + TensorBox, +) +from ..lowering import ( + _full, + check_and_broadcast_indices, + empty, + empty_strided, + expand, + index_output_size_and_inner_fn, + lowerings, + register_lowering, + to_dtype, +) +from ..select_algorithm import ( + autotune_select_algorithm, + realize_inputs, + SymbolicGridFn, + TritonTemplate, +) + + +log = logging.getLogger(__name__) +aten = torch.ops.aten +Expr = sympy.Expr + + +def construct_strides( + sizes: Sequence[int], + fill_order: Sequence[int], +) -> Sequence[int]: + """From a list of sizes and a fill order, construct the strides of the permuted tensor.""" + # Initialize strides + assert len(sizes) == len(fill_order), ( + "Length of sizes must match the length of the fill order" + ) + strides = [0] * len(sizes) + + # Start with stride 1 for the innermost dimension + current_stride = 1 + + # Iterate through the fill order populating strides + for dim in fill_order: + strides[dim] = current_stride + current_stride *= sizes[dim] + + return strides + + +def infer_dense_strides(size: Sequence[int], orig_strides: Sequence[int]): + """This is a mirror of the same function in aten/src/ATen/ExpandUtils.cpp + + Args: + size: The size of the output tensor + orig_strides: The strides of the input tensor + Returns: + List[int]: Dense non-overlapping strides that preserve the input tensor's layout permutation. + The returned strides follow the same stride propagation rules as TensorIterator. This matches + The behavior of empty_like() + """ + fill_order = get_fill_order(orig_strides, V.graph.sizevars.shape_env) + return construct_strides(size, fill_order) + + +@SymbolicGridFn +def flex_attention_grid(batch_size, q_heads, num_queries, d_model, meta, *, cdiv): + """How is this kernel parallelized? + We create a grid of (batch_size * num_heads, ceil_div(n_queries, query_block_size), 1) + Each block is responsible for iterating over blocks of keys and values calculating + the final attention output. + """ + return (cdiv(num_queries, meta["BLOCK_M"]), batch_size * q_heads, 1) + + +def create_placeholder( + name: str, + dtype: torch.dtype, + device: torch.device, + size: Optional[list[int]] = None, +) -> TensorBox: + """Creates a placeholder input buffers for producing subgraph_output.""" + input_buffer = InputBuffer( + name=name, + layout=FixedLayout( + device, + dtype, + size if size else [], + FlexibleLayout.contiguous_strides(size) if size else [], + ), + ) + return TensorBox.create(input_buffer) + + +def maybe_realize(args: list[Optional[IRNode]]): + """Accepts a list of optional IRNodes and returns a list of realized IRNodes""" + return tree_map( + lambda x: ( + realize_inputs(x) + if x is not None and not isinstance(x, sympy.Symbol) + else x + ), + args, + ) + + +def get_float32_precision(): + if ( + torch.get_float32_matmul_precision() == "highest" + or torch.version.hip + or torch.mtia.is_available() + ): + return "'ieee'" + else: + return "'tf32'" + + +def zeros_and_scatter_lowering(shape: list[int], indices, values): + # Always accumulate into fp32 then cast + grad = _full(0, values.get_device(), torch.float32, shape) + assert isinstance(grad, TensorBox) + grad.realize() + x_size = grad.get_size() + values = to_dtype(values, grad.get_dtype()) + indices_loaders = [i.make_loader() if i is not None else None for i in indices] + indices, tensor_indices = check_and_broadcast_indices(indices, grad.get_device()) + # We can use the first one since they are all required to be the same size + tensor_size = list(indices[tensor_indices[0]].get_size()) + indexed_size = [x_size[i] for i in range(len(indices))] + + expected_vals_size, inner_fn = index_output_size_and_inner_fn( + x_size, + indices, + tensor_indices, + tensor_size, + indices_loaders, + indexed_size, + None, + check=True, + ) + + values = expand(values, expected_vals_size) + device = grad.get_device() + assert device is not None + scatter = Scatter( + device=device, + dtype=grad.get_dtype(), + inner_fn=values.make_loader(), + ranges=expected_vals_size, # iter_ranges, + output_indexer=inner_fn, + scatter_mode="atomic_add", + ) + + buffer = ComputedBuffer( + name=grad.data.data.name, # type: ignore[attr-defined] + layout=MutationLayoutSHOULDREMOVE(grad), + data=scatter, + ) + return buffer + + +SubgraphResults = Union[list[Optional[ComputedBuffer]], Optional[ComputedBuffer]] + + +def build_subgraph_module_buffer( + args: list[TensorBox], graph_module: torch.fx.GraphModule +) -> SubgraphResults: + """This function's goal is to take in the required args and produce the subgraph buffer + The subgraph buffer is a ComputedBuffer that will be inlined into the triton template + + Args: + args: The args that are passed into the subgraph. Contains both fixed and lifted inputs. + subgraph: The Subgraph ir for which to produce the output node + """ + from ..subgraph_lowering import PointwiseSubgraphLowering + + pw_subgraph = PointwiseSubgraphLowering( + graph_module, + root_graph_lowering=V.graph, + allowed_mutations=OrderedSet([torch.ops.flex_lib.zeros_and_scatter.default]), + additional_lowerings={ + torch.ops.flex_lib.zeros_and_scatter.default: zeros_and_scatter_lowering + }, + ) + with V.set_graph_handler(pw_subgraph): # type: ignore[arg-type] + pw_subgraph.run(*args) + + # Since we are allowing mutations/buffer creation, we need to register any fresh buffers + # creating during the pointwise subgraph lowering + if len(pw_subgraph.buffers) > 0: + for buffer in pw_subgraph.buffers: + V.graph.register_buffer(buffer) + + def convert_output_node_to_buffer(output_buffer) -> Optional[ComputedBuffer]: + if output_buffer is None: + return None + if isinstance(output_buffer, ComputedBuffer): + # These nodes are coming from the output of zeros_and_scatter + return output_buffer + assert isinstance(output_buffer, TensorBox), ( + "The output node for flex attention's subgraph must be a TensorBox, but got: ", + type(output_buffer), + ) + assert isinstance(output_buffer.data, StorageBox), ( + "The output node for the flex attention subgraph must be a StorageBox, but got: ", + type(output_buffer), + ) + subgraph_buffer = ComputedBuffer( + name=None, + layout=FlexibleLayout( + device=output_buffer.data.get_device(), + dtype=output_buffer.data.get_dtype(), + size=output_buffer.data.get_size(), + ), + data=output_buffer.data.data, # type: ignore[arg-type] + ) + return subgraph_buffer + + return tree_map(convert_output_node_to_buffer, pw_subgraph.graph_outputs) + + +def build_subgraph_buffer(args: list[TensorBox], subgraph: Subgraph) -> SubgraphResults: + return build_subgraph_module_buffer(args, subgraph.graph_module) + + +def get_fwd_subgraph_outputs( + subgraph_buffer: SubgraphResults, mask_graph_buffer: SubgraphResults +) -> list[Optional[ComputedBuffer]]: + subgraph_buffer = ( + subgraph_buffer if isinstance(subgraph_buffer, Sequence) else [subgraph_buffer] + ) + mask_graph_buffer = ( + mask_graph_buffer + if isinstance(mask_graph_buffer, Sequence) + else [mask_graph_buffer] + ) + return [*subgraph_buffer, *mask_graph_buffer] + + +# Inner Triton functions shared by flex_attention & split-k decoding kernels. +compute_next_offset_func = r""" +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset +""" + +get_bounded_indices_func = r""" +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices +""" + + +load_checked_block = r""" +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") +""" + +load_checked_2d = r""" +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_DIM: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_DIM), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_DIM), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +""" + +compute_flex_attention = r""" +{{def_kernel("Q", "K", "V", "LSE", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}} + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = {{stride("Q")}} + stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}} + stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}} + + ZQ = {{size("Q", 0)}} + HQ = {{size("Q", 1)}} + Q_LEN = {{size("Q", 2)}} + ZKV = {{size("K", 0)}} + KV_LEN = {{size("K", 2)}} + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0) + off_zq = tl.program_id(1) // HQ + off_hq = tl.program_id(1) % HQ + + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + {%- if USE_TMA %} + desc_q = tl.make_tensor_descriptor( + base=Q, + shape=[Q_LEN*HQ*ZQ, QK_HEAD_DIM], + strides=[QK_HEAD_DIM, 1], + block_shape=[BLOCK_M, QK_HEAD_DIM_ROUNDED], + ) + desc_v = tl.make_tensor_descriptor( + base=V, + shape=[KV_LEN*ZKV*HQ, V_HEAD_DIM], + strides=[V_HEAD_DIM, 1], + block_shape=[BLOCK_N, V_HEAD_DIM_ROUNDED], + ) + desc_k = tl.make_tensor_descriptor( + base=V, + shape=[KV_LEN*ZKV*HQ, V_HEAD_DIM], + strides=[V_HEAD_DIM, 1], + block_shape=[BLOCK_N, V_HEAD_DIM_ROUNDED], + ) + {%- endif %} + + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} + SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}} + stride_kv_idx_h = {{stride("KV_IDX", 1)}} + stride_kv_idx_m = {{stride("KV_IDX", 2)}} + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + K_block_ptr = None + V_block_ptr = None + Q_block_ptr = None + + if not USE_TMA: + Q_block_ptr = tl.make_block_ptr( + base=Q , + shape=(Q_LEN, QK_HEAD_DIM), + strides=(stride_qm, stride_qk), + offsets=(q_start * BLOCK_M, 0), + block_shape=(BLOCK_M, QK_HEAD_DIM_ROUNDED), + order=(1, 0) + ) + + {%- if USE_TMA %} + q = tl.load_tensor_descriptor( + desc_q, + [(q_start * BLOCK_M).to(tl.int32), 0], + ) + {%- else %} + q = load_checked_block(Q_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM) + {%- endif %} + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + if not USE_TMA: + K_block_ptr = tl.make_block_ptr( + base=K, + shape=(QK_HEAD_DIM, KV_LEN), + strides=(stride_kk, stride_kn), + offsets=(0, kv_start), + block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N), + order=(0, 1) + ) + + V_block_ptr = tl.make_block_ptr( + base=V, + shape=(KV_LEN, V_HEAD_DIM), + strides=(stride_vn, stride_vk), + offsets=(kv_start, 0), + block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED), + order=(1, 0) + ) + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + if not USE_TMA: + K_block_ptr = tl.make_block_ptr( + base=K, + shape=(QK_HEAD_DIM, KV_LEN), + strides=(stride_kk, stride_kn), + offsets=(0, kv_start), + block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V, + shape=(KV_LEN, V_HEAD_DIM), + strides=(stride_vn, stride_vk), + offsets=(kv_start, 0), + block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED), + order=(1, 0) + ) + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. + # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1) // HQ + idx_hq = tl.program_id(1) % HQ + idx_m = offs_m[:, None] + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + {{store_output(("idx_zq", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}} + + if OUTPUT_LOGSUMEXP: + off_hz = tl.program_id(1) + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + """ + + +compute_forward_inner = r""" +@triton.jit +def forward_inner( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + {{gen_defines() | indent_except_first(1)}} + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + start_n, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + start_n, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + if not USE_TMA: + K_block_ptr = tl.advance(K_block_ptr, (0, offset)) + V_block_ptr = tl.advance(V_block_ptr, (offset, 0)) + + + return acc, l_i, m_i + +""" + + +compute_forward_block_mn = r""" +@triton.jit +def forward_block_mn( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + start_n, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + {{gen_defines() | indent_except_first(1)}} + + # -- load k -- + # NB reversed order to since K is transposed + {%- if USE_TMA %} + k = tl.load_tensor_descriptor( # load in row major + desc_k, + [start_n.to(tl.int32) , kv_start], + ) + {%- else %} + k = load_checked_block(K_block_ptr, SAFE_HEAD_DIM, IS_DIVISIBLE) + {%- endif %} + + if USE_TMA: + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + {{ modification( + subgraph_number=0, + output_name="post_mod_scores", + score="qk", + b="off_z", + h="off_h", + m="m", + n="n", + out="qk" + ) | indent_except_first(1) }} + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + {{ modification( + subgraph_number=1, + output_name="mask_mod_output", + score="qk", + b="off_z", + h="off_h", + m="m", + n="n", + ) | indent_except_first(2) }} + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + {%- if USE_TMA %} + v = tl.load_tensor_descriptor( + desc_v, + [kv_start.to(tl.int32) + start_n.to(tl.int32),0], + ) + {%- else %} + v = load_checked_block(V_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM) + {%- endif %} + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +""" + + +flex_attention_template = TritonTemplate( + name="flex_attention", + grid=flex_attention_grid, + source=compute_flex_attention + + compute_forward_inner + + compute_next_offset_func + + compute_forward_block_mn + + load_checked_block + + get_bounded_indices_func, +) + + +def _use_flex_decoding(query, kv_indices, kernel_options, enable_gqa): + """Decide which kernel to use, return true if use flex decoding kernel. + Note: + Since the number of splits is calculated based of the the number of batch and head dims + we need to ensure that the batch and head dims are statically known. Otherwise we just + use the main flex_attention kernel. + """ + force_flex = kernel_options.get("FORCE_USE_FLEX_ATTENTION", False) + short_query_length = V.graph.sizevars.evaluate_expr( + sympy.Lt(query.get_size()[-2], 128) + ) + non_zero_length = V.graph.sizevars.evaluate_expr(sympy.Gt(query.get_size()[-2], 0)) + static_batch = isinstance(query.get_size()[0], (int, sympy.Integer)) + static_num_heads = isinstance(query.get_size()[1], (int, sympy.Integer)) + if enable_gqa: + # in the current flex decoding triton kernel, grouped query heads for the + # same kv head are handled by the same block. So it's hard to support different + # kv num blocks for grouped query heads. We just fall back to main flex_attention + # kernel where each query head is handled by a separate block. + valid_block_mask_num_heads = V.graph.sizevars.evaluate_expr( + sympy.Eq(kv_indices.get_size()[1], 1) + ) + else: + valid_block_mask_num_heads = V.graph.sizevars.evaluate_expr( + sympy.Or( + sympy.Eq(kv_indices.get_size()[1], 1), + sympy.Eq(kv_indices.get_size()[1], query.get_size()[1]), + ) + ) + return ( + not force_flex + and short_query_length + and static_batch + and static_num_heads + and non_zero_length + and valid_block_mask_num_heads + ) + + +_h100_default_config = { + (torch.float32, 64): (128, 32, 4, 3), + (torch.float32, 128): (32, 64, 4, 3), + (torch.float32, 256): (32, 32, 4, 3), + (torch.bfloat16, 64): (128, 128, 4, 3), + (torch.bfloat16, 128): (128, 64, 8, 3), + (torch.bfloat16, 256): (64, 32, 4, 3), + (torch.float16, 64): (128, 128, 4, 3), + (torch.float16, 128): (128, 128, 8, 3), + (torch.float16, 256): (64, 32, 4, 3), +} + +_a100_default_config = { + (torch.float32, 64): (128, 32, 4, 3), + (torch.float32, 128): (128, 32, 4, 3), + (torch.float32, 256): (64, 16, 4, 3), + (torch.bfloat16, 64): (128, 64, 4, 3), + (torch.bfloat16, 128): (128, 64, 8, 3), + (torch.bfloat16, 256): (32, 64, 4, 3), + (torch.float16, 64): (128, 64, 4, 3), + (torch.float16, 128): (128, 64, 8, 3), + (torch.float16, 256): (32, 64, 4, 3), +} + +_rocm_default_config = { + (torch.float32, 64): (128, 32, 4, 1), + (torch.float32, 128): (128, 32, 4, 1), + (torch.float32, 256): (64, 16, 4, 1), + (torch.bfloat16, 64): (128, 64, 8, 1), + (torch.bfloat16, 128): (128, 64, 8, 1), + (torch.bfloat16, 256): (32, 64, 8, 1), + (torch.float16, 64): (128, 64, 8, 1), + (torch.float16, 128): (128, 64, 8, 1), + (torch.float16, 256): (32, 64, 4, 1), +} + + +class Mode(Enum): + fwd = auto() + bwd = auto() + + +def create_num_blocks_fake_generator(sparse_indices): + # The idea here is that we need to create a real tensor with real data + # that's representative for benchmarking. + # For example, returning all zeros for the `kv_num_blocks` input would mean + # that we are computing 0 blocks for each row, which would provide bogus + # autotuning results. + # + # In this case, we choose to use min(16, max_block) blocks, because I + # (Horace) think it'll probably result in pretty representative performance. + # If it's too short then prefetching won't help. If it's too long then + # autotuning will take longer for no good reason. + def create_num_blocks_fake(x) -> torch.Tensor: + num_blocks_for_autotuning = V.graph.sizevars.size_hint(sparse_indices.shape[-1]) + size = [V.graph.sizevars.size_hint(i) for i in x.get_size()] + return torch.full( + size, + num_blocks_for_autotuning, + dtype=x.get_dtype(), + device=x.get_device(), + ) + + return create_num_blocks_fake + + +def create_indices_fake(x) -> torch.Tensor: + size = [V.graph.sizevars.size_hint(i) for i in x.get_size()] + indices = torch.arange(0, size[-1], dtype=x.get_dtype(), device=x.get_device()) + indices = indices.expand(size).contiguous() + return indices + + +from torch._inductor.kernel.flex_decoding import create_flex_decoding_kernel + +from ..codegen.cpp_flex_attention_template import CppFlexAttentionTemplate + + +def check_cpu_supported(): + import os + import sys + + requires_avx2_on_cpu = ( + torch.cpu._is_avx2_supported() and os.getenv("ATEN_CPU_CAPABILITY") != "default" + ) + supported = ( + requires_avx2_on_cpu + and not torch.xpu.is_available() + and not sys.platform == "darwin" + ) + return supported + + +def contiguous_last_dim(x): + """Ensure that realized IR node has a contiguous stride in the last dimension.""" + strides = x.maybe_get_stride() + if strides and strides[-1] != 1: + contiguous_stride_order = list(reversed(range(len(x.get_size())))) + return ExternKernel.require_stride_order(x, contiguous_stride_order) + return x + + +def lower_cpu( + query, + key, + value, + subgraph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, +): + ( + _, # q_length + _, # kv_length + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + SPARSE_Q_BLOCK_SIZE, + SPARSE_KV_BLOCK_SIZE, + mask_graph, + ) = block_mask + + if kernel_options["OUTPUT_LOGSUMEXP"]: + raise NotImplementedError( + "torch.compile on CPU only supports inference and `return_lse` is not supported yet." + ) + if not check_cpu_supported(): + raise NotImplementedError( + "torch.compile on current platform is not supported for CPU." + ) + + fake_buffers: list[Buffer] = [] # noqa: F821 + + # [Note] Handle the case where the split sizes are not statically known. + # The value of cur_qSplitSize and cur_kvSplitSize are decided during runtime. + # We use symbols to represent them during the compilation here. + # They'll be replaced by the string "cur_qSplitSize" and "cur_kvSplitSize" in + # the modification function of the CppFlexAttentionTemplate class. + cur_qSplitSize = V.graph.sizevars.shape_env.create_unbacked_symint().node.expr + cur_kvSplitSize = V.graph.sizevars.shape_env.create_unbacked_symint().node.expr + shape_env = V.graph.sizevars.shape_env + + # We don't know the concrete value of cur_qSplitSize and cur_kvSplitSize during the compilation. + # Mark symbols > 1 to ensure broadcasting is always applied. + # This avoids treating them as equal when `eq(var, 1)` is evaluated in `broadcast_symbolic_shapes`. + shape_env.var_to_range[cur_qSplitSize] = ValueRanges(2, int_oo) + shape_env.var_to_range[cur_kvSplitSize] = ValueRanges(2, int_oo) + + score_dtype = torch.float + placeholder_inps = [ + create_placeholder(name, dtype, query.get_device(), size) + for name, dtype, size in [ + ("score", score_dtype, [cur_qSplitSize, cur_kvSplitSize]), + ("b", torch.int64, []), + ("h", torch.int64, []), + ("q_idx", torch.int64, [cur_qSplitSize, 1]), + ("kv_idx", torch.int64, [1, cur_kvSplitSize]), + ] + ] + subgraph_buffer = build_subgraph_buffer( + placeholder_inps + list(score_mod_other_buffers), subgraph + ) + if subgraph_buffer is not None: + if isinstance(subgraph_buffer, list): + for _buf in subgraph_buffer: + if _buf is not None: + _buf.freeze_layout() + else: + subgraph_buffer.freeze_layout() + mask_graph_placeholder_inps = [ + create_placeholder(name, dtype, query.get_device(), size) + for name, dtype, size in [ + ("score", score_dtype, [cur_qSplitSize, cur_kvSplitSize]), + ("b", torch.int64, []), + ("h", torch.int64, []), + ("q_idx", torch.int64, [cur_qSplitSize, 1]), + ("kv_idx", torch.int64, [1, cur_kvSplitSize]), + ] + ] + + # The original mask_graph works on a scalar and only includes + # the logic of calculating the mask value. + # We need to add the logic of applying the mark to the qk_data tensor + # into the graph for the later codegen of this part. + # Example: + # mask_graph: + # def mask_fn(b, h, q_idx, kv_idx): + # mask = q_idx >= kv_idx + # return mask + # The converted_mask_graph should be: + # def converted_mask_fn(qk_data, b, h, q_idx, kv_idx): + # mask = q_idx >= kv_idx + # qk_data = torch.where(mask, qk_data, torch.full_like(qk_data, -float("inf"))) + # return qk_data + def convert_mask_graph_module(mask_graph): + gm = copy.deepcopy(mask_graph.graph_module) + graph = gm.graph + # Add qk_data as the first input + with graph.inserting_before(next(iter(graph.nodes))): + qk_data_node = graph.placeholder("qk_data") + + # Find the node that returns the mask + output_node = None + for node in graph.nodes: + if node.op == "output": + output_node = node + break + + # Get the mask node + assert output_node is not None + mask_node = output_node.args[0] + + size_node = [cur_qSplitSize, cur_kvSplitSize] + # Create a new node for torch.full + with graph.inserting_after(mask_node): + full_node = graph.call_function( + torch.full, + args=(size_node, -float("inf")), + kwargs={"dtype": score_dtype}, + ) + + # Create a new node for torch.where + with graph.inserting_after(full_node): + where_node = graph.call_function( + torch.ops.aten.where, args=(mask_node, qk_data_node, full_node) + ) + + # Update the output node to return the result of torch.where + output_node.args = (where_node,) + + graph.lint() + converted = torch.fx.GraphModule(gm, graph) + return converted + + converted_mask_graph_module = convert_mask_graph_module(mask_graph) + + mask_graph_buffer = build_subgraph_module_buffer( + mask_graph_placeholder_inps + list(mask_mod_other_buffers), + converted_mask_graph_module, + ) + + # Clear the pending fresh unbacked symbols that are created for cur_qSplitSize and cur_kvSplitSize in the current kernel. + pending = V.graph.sizevars.shape_env.pending_fresh_unbacked_symbols + V.graph.sizevars.shape_env.pending_fresh_unbacked_symbols = [ + x for x in pending if x not in (cur_qSplitSize, cur_kvSplitSize) + ] + + buffer_list = ( + placeholder_inps + + list(score_mod_other_buffers) + + mask_graph_placeholder_inps + + list(mask_mod_other_buffers) + ) + for item in buffer_list: + if isinstance(item, TensorBox): + fake_buffers.append(item.data.data) # type: ignore[attr-defined] + + # CPU kernel requires last dim to be contiguous + query, key, value = map(contiguous_last_dim, [query, key, value]) + + ( + query, + key, + value, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + ) = maybe_realize( + [ + query, + key, + value, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + ] + ) + + if len(OrderedSet([query.get_name(), key.get_name(), value.get_name()])) != 3: + raise NotImplementedError( + "Unsupported for now if query, key, value are the same buffer." + ) + if query.get_dtype() not in [torch.float, torch.bfloat16, torch.float16]: + raise NotImplementedError( + "`torch.float` , `torch.float16` and `torch.bfloat16` are supported in FlexAttention for CPU device. " + f"Found input tensors are `{query.get_dtype()}`." + ) + score_mod_other_buffers = maybe_realize(score_mod_other_buffers) + mask_mod_other_buffers = maybe_realize(mask_mod_other_buffers) + Bq, Hq, seq_len_q, qk_head_dim = query.get_size() + Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size() + B = Bq + + # Construct output layout with strides matching the query. + out_size = [B, Hq, seq_len_q, v_head_dim] + out_strides = infer_dense_strides(out_size, query.get_stride()) + + layout = FixedLayout( + query.get_device(), + query.get_dtype(), + [B, Hq, seq_len_q, v_head_dim], + stride=[sympy.sympify(s) for s in out_strides], + ) + _choices: list[Any] = [] + input_nodes = [query, key, value, kv_num_blocks, kv_indices] + if not full_kv_num_blocks: + no_full_kv_block = True + else: + no_full_kv_block = False + input_nodes += [full_kv_num_blocks] + input_nodes += [full_kv_indices] + has_other_buffer = False + kernel_input_name_to_buffer = {} + if score_mod_other_buffers or mask_mod_other_buffers: + has_other_buffer = True + + for prefix, buffers in [ + ("score_others", score_mod_other_buffers), + ("mask_others", mask_mod_other_buffers), + ]: + kernel_input_name_to_buffer.update( + {f"{prefix}_{i}": buf for i, buf in enumerate(buffers)} + ) + input_nodes += [ + value + for value in kernel_input_name_to_buffer.values() + if not isinstance(value, sympy.Symbol) + ] + + skip_mask_score = kernel_options.get("SKIP_MASK_SCORE", False) + # Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards. + SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE) + SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE) + assert V.graph.sizevars.evaluate_expr( + sympy.Le(seq_len_q, sympy.Mul(kv_indices.get_size()[-2], SPARSE_Q_BLOCK_SIZE)) + ), ( + "Q seqlen must be smaller than the block_mask size in the Q dimension, considering pass a larger block_mask." + ) + assert V.graph.sizevars.evaluate_expr( + sympy.Le(seq_len_kv, sympy.Mul(kv_indices.get_size()[-1], SPARSE_KV_BLOCK_SIZE)) + ), ( + "KV seqlen must be smaller than the block_mask size in the KV dimension, considering pass a larger block_mask." + ) + CppFlexAttentionTemplate.add_choices( + choices=_choices, + input_nodes=input_nodes, + layout=layout, + scale=scale, + score_mod=None if skip_mask_score else subgraph_buffer, + mask_mod=None if skip_mask_score else mask_graph_buffer, + kv_block_size=SPARSE_KV_BLOCK_SIZE, + q_block_size=SPARSE_Q_BLOCK_SIZE, + has_other_buffer=has_other_buffer, + no_full_kv_block=no_full_kv_block, + fake_buffers=fake_buffers, + len_score_other=len(score_mod_other_buffers), + len_mask_other=len(mask_mod_other_buffers), + kernel_input_name_to_buffer=kernel_input_name_to_buffer, + block_vars=(cur_qSplitSize, cur_kvSplitSize), + ) + inputs_for_autotuning = [ + query, + key, + value, + ] + res = autotune_select_algorithm( + "flex_attention", + _choices, + inputs_for_autotuning, + layout, + ) + + # need subgraph inputs and outputs to analyze all symints used in flex attention + res.data.data.subgraph_inps = list(score_mod_other_buffers) + list( + mask_mod_other_buffers + ) + res.data.data.subgraph_outs = get_fwd_subgraph_outputs( + subgraph_buffer, mask_graph_buffer + ) + + return (res,) + + +def is_power_of_2(n): + return n != 0 and ((n & (n - 1)) == 0) + + +def next_power_of_two(n): + if n <= 0: + return 1 + return 2 ** math.ceil(math.log2(n)) + + +def set_head_dim_values( + kernel_options: dict[str, Any], qk_head_dim, v_head_dim, graph_sizevars +): + """ + Mutates kernel options, adding head dimension calculations. + + Args: + kernel_options: Dictionary to populate with options + qk_head_dim: Query/Key head dimension + v_head_dim: Value head dimension + graph_sizevars: Graph size variables object with evaluate_static_shape method + + """ + # QK dimensions + qk_head_dim_static = graph_sizevars.evaluate_static_shape(qk_head_dim) + kernel_options.setdefault("QK_HEAD_DIM", qk_head_dim_static) + kernel_options.setdefault( + "QK_HEAD_DIM_ROUNDED", next_power_of_two(qk_head_dim_static) + ) + + # V dimensions + v_head_dim_static = graph_sizevars.evaluate_static_shape(v_head_dim) + kernel_options.setdefault("V_HEAD_DIM", v_head_dim_static) + kernel_options.setdefault( + "V_HEAD_DIM_ROUNDED", next_power_of_two(v_head_dim_static) + ) + + # Safety flag + kernel_options.setdefault( + "SAFE_HEAD_DIM", + is_power_of_2(qk_head_dim_static) and is_power_of_2(v_head_dim_static), + ) + + +@register_lowering(torch.ops.higher_order.flex_attention, type_promotion_kind=None) +def flex_attention( + query, + key, + value, + subgraph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, +): + if query.get_device().type == "cpu": + return lower_cpu( + query, + key, + value, + subgraph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + + # below is cuda path if device is not cpu + # tl.dot does not support embedding size less than 16 + small_dqk = V.graph.sizevars.evaluate_expr(sympy.Lt(query.get_size()[-1], 16)) + small_dv = V.graph.sizevars.evaluate_expr(sympy.Lt(value.get_size()[-1], 16)) + if small_dqk or small_dv: + raise NotImplementedError( + f"NYI: embedding dimension of the query, key, and value must be " + f"at least 16 but got E={query.get_size()[-1]} and Ev={value.get_size()[-1]}" + ) + + ( + _, # q_length + _, # kv_length + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + SPARSE_Q_BLOCK_SIZE, + SPARSE_KV_BLOCK_SIZE, + mask_graph, + ) = block_mask + + placeholder_inps = [ + create_placeholder(name, dtype, query.get_device()) + for name, dtype in [ + ("score", query.get_dtype()), + ("b", torch.int32), + ("h", torch.int32), + ("m", torch.int32), + ("n", torch.int32), + ] + ] + subgraph_buffer = build_subgraph_buffer( + placeholder_inps + list(score_mod_other_buffers), subgraph + ) + + mask_graph_placeholder_inps = [ + create_placeholder(name, dtype, query.get_device()) + for name, dtype in [ + ("b", torch.int32), + ("h", torch.int32), + ("m", torch.int32), + ("n", torch.int32), + ] + ] + mask_graph_buffer = build_subgraph_buffer( + mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph + ) + + kernel_options = dict(kernel_options) + # Mark symbols in custom kernel options as static shapes and add guards. + kernel_options = { + k: V.graph.sizevars.evaluate_static_shape(v) + if isinstance(v, sympy.Symbol) + else v + for k, v in kernel_options.items() + } + kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision()) + enable_gqa = V.graph.sizevars.evaluate_expr( + sympy.Ne(query.get_size()[1], key.get_size()[1]), + ) + if _use_flex_decoding(query, kv_indices, kernel_options, enable_gqa): + return create_flex_decoding_kernel( + query, + key, + value, + block_mask, + scale, + kernel_options, + subgraph_buffer, + mask_graph_buffer, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + + ( + query, + key, + value, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + ) = maybe_realize( + [ + query, + key, + value, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + ] + ) + + score_mod_other_buffers = maybe_realize(score_mod_other_buffers) + mask_mod_other_buffers = maybe_realize(mask_mod_other_buffers) + + Bq, Hq, seq_len_q, qk_head_dim = query.get_size() + Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size() + assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), ( + f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}" + ) + assert V.graph.sizevars.evaluate_expr(sympy.Gt(seq_len_q, 0)), ( + "Query length must be greater than 0" + ) + assert V.graph.sizevars.evaluate_expr(sympy.Gt(seq_len_kv, 0)), ( + "Key length must be greater than 0" + ) + + B = Bq + + if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0: + kernel_options.setdefault("IS_DIVISIBLE", False) + else: + kernel_options.setdefault("IS_DIVISIBLE", True) + + # NB it is okay that the v_head_dim is different + # We are using these to match fill order of the output. + q_strides = query.get_stride() + # Construct output layout with strides matching the query. + out_size = [B, Hq, seq_len_q, v_head_dim] + out_strides = infer_dense_strides(out_size, q_strides) + + layout = FixedLayout( + query.get_device(), + query.get_dtype(), + [B, Hq, seq_len_q, v_head_dim], + stride=[sympy.sympify(s) for s in out_strides], + ) + # see NOTE:[TritonTemplates with multiple outputs] + logsumexp_shape = [B, Hq, seq_len_q] + logsumexp = empty_strided( + logsumexp_shape, + None, + dtype=torch.float32, # The logsumexp is always stored in fp32 regardless of the input dtype + device=query.get_device(), + ) + kernel_options.setdefault("SM_SCALE", scale) + + # Determine GQA broadcast factor. + gqa_shared_heads = Hq // Hkv + kernel_options.setdefault("GQA_SHARED_HEADS", gqa_shared_heads) + + # Inside of Triton kernel, only apply partial masking if partial blocks are computed. + # full_kv_num_blocks is None if partial blocks are not computed + has_full_blocks = full_kv_num_blocks is not None + kernel_options.setdefault("HAS_FULL_BLOCKS", has_full_blocks) + if not has_full_blocks: + full_kv_num_blocks, full_kv_indices = ( + empty(0, device=query.get_device()) for _ in range(2) + ) + + set_head_dim_values(kernel_options, qk_head_dim, v_head_dim, V.graph.sizevars) + + choices: list[Any] = [] + + dtype = query.get_dtype() + head_dim = V.graph.sizevars.evaluate_static_shape(query.get_size()[-1]) + configs = V.choices.get_flex_attention_fwd_configs(head_dim, dtype) + + # Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards. + SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE) + SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE) + + # Note, we don't need to pass in the captured buffers explicitly + # because they're implicitly added by the score_mod function + # We do need to explicitly pass it in for autotuning though. + original_kernel_options = kernel_options.copy() + # Default config for warp specialization + num_consumer_groups, num_buffers_warp_spec = 0, 0 + + for conf in configs: + if ( + SPARSE_KV_BLOCK_SIZE % conf.block_n != 0 + or SPARSE_Q_BLOCK_SIZE % conf.block_m != 0 + ): + if len(configs) == 1: + raise ValueError( + f"Q and KV block size must be divisible by BLOCK_M and BLOCK_N. We " + f"got Q_BLOCK_SIZE={SPARSE_Q_BLOCK_SIZE} and KV_BLOCK_SIZE={SPARSE_KV_BLOCK_SIZE}." + ) + continue + + cur_kernel_options = original_kernel_options.copy() + # Performance tuning + # Triton parameters + # Remove prefix for forward kernels options and delete backward kernel options. + for k in list(cur_kernel_options.keys()): + if k.startswith("fwd_"): + v = cur_kernel_options.pop(k) + cur_kernel_options[k[4:]] = v + if k.startswith("bwd_"): + cur_kernel_options.pop(k) + cur_kernel_options.setdefault("num_stages", conf.num_stages) + cur_kernel_options.setdefault("num_warps", conf.num_warps) + if cur_kernel_options.get("num_consumer_groups", False): + cur_kernel_options.setdefault("num_consumer_groups", num_consumer_groups) + cur_kernel_options.setdefault( + "num_buffers_warp_spec", num_buffers_warp_spec + ) + + # Disabling TMA by default, only explicit kernel_options supported for now + cur_kernel_options.setdefault("USE_TMA", False) + + cur_kernel_options.setdefault("BLOCK_M", conf.block_m) + cur_kernel_options.setdefault("BLOCK_N", conf.block_n) + # Blocksparse options + cur_kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE) + cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE) + + # ROCm specific kernargs + for attrib in ["kpack", "matrix_instr_nonkdim", "waves_per_eu"]: + if hasattr(conf, attrib): + cur_kernel_options[attrib] = getattr(conf, attrib) + + error = flex_attention_template.maybe_append_choice( + choices=choices, + input_nodes=[ + query, + key, + value, + logsumexp, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + ], + layout=layout, + subgraphs=[ + subgraph_buffer, + mask_graph_buffer, + ], + mutated_inputs=[ + logsumexp, + ], + call_sizes=query.get_size(), + **cur_kernel_options, + ) + if error is not None and len(configs) == 1: + raise error + inputs_for_autotuning = ( + [ + query, + key, + value, + logsumexp, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + ] + + list(score_mod_other_buffers) + + list(mask_mod_other_buffers) + ) + input_gen_fns = { + 4: create_num_blocks_fake_generator(kv_indices), + 5: create_indices_fake, + 6: create_num_blocks_fake_generator(full_kv_indices), + 7: create_indices_fake, + } + + out = autotune_select_algorithm( + "flex_attention", + choices, + # Need to filter out symbols since there is an invariant + # that all input_nodes are of type IRNode + [x for x in inputs_for_autotuning if isinstance(x, torch._inductor.ir.IRNode)], + layout, + input_gen_fns=input_gen_fns, + ) + + # need subgraph inputs and outputs to analyze all symints used in flex attention + out.data.data.subgraph_inps = list(score_mod_other_buffers) + list( + mask_mod_other_buffers + ) + out.data.data.subgraph_outs = get_fwd_subgraph_outputs( + subgraph_buffer, mask_graph_buffer + ) + + return (out, logsumexp) + + +# ---------------------------- Backward HOP Implementation ---------------------------- + + +def flex_attention_backward_grid( + batch_size, q_heads, num_queries, d_model, kv_heads, num_key_value, meta +): + """How is this kernel parallelized? + Currently this is only parallelizing over batch* kv_heads, but we can, and want to + parallelize over ceil_div(q_heads//kv_heads * num_key_value, key_value_block_size). + To do this will either require atomic updates to some grad values or to have a two pass kernel design. + """ + import triton + + return ( + triton.cdiv(num_queries, meta["BLOCK_M2"]) * (q_heads // kv_heads) + + triton.cdiv(num_key_value, meta["BLOCK_N1"]), + 1, + batch_size * kv_heads, + ) + + +flex_attention_backward_template = TritonTemplate( + name="flex_attention_backward", + grid=flex_attention_backward_grid, + source=r""" +{{def_kernel("Q", "K", "V", "LSE", "DELTA", "DO", "DQ", "DV", "KV_NUM_BLKS", "KV_IDX", "Q_NUM_BLKS", "Q_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX", "FULL_Q_NUM_BLKS", "FULL_Q_IDX")}} + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = {{stride("Q")}} + stride_kz, stride_kh, stride_kn, stride_kd = {{stride("K")}} + stride_vz, stride_vh, stride_vn, stride_vd = {{stride("V")}} + stride_doz, stride_doh, stride_dom, stride_dod = {{stride("DO")}} + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = {{stride("DQ")}} + stride_dvz, stride_dvh, stride_dvm, stride_dvd = {{stride("DV")}} + + ZQ = {{size("Q", 0)}} + HQ = {{size("Q", 1)}} + HKV = {{size("K", 1)}} + Q_LEN = {{size("Q", 2)}} + ZKV = {{size("K", 0)}} + KV_LEN = {{size("K", 2)}} + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_hz = tl.program_id(2) + off_zq = off_hz // HKV # q batch idx + off_hkv = off_hz % HKV # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} + SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}} + stride_kv_idx_h = {{stride("KV_IDX", 1)}} + stride_kv_idx_m = {{stride("KV_IDX", 2)}} + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + {{gen_argdefs()}}, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + {{gen_argdefs()}}, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = {{stride("Q_NUM_BLKS", 1)}} + stride_q_idx_h = {{stride("Q_IDX", 1)}} + stride_q_idx_n = {{stride("Q_IDX", 2)}} + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + {{gen_argdefs()}}, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + {{gen_argdefs()}}, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + {{store_output(("off_zq", "off_hkv", "index_n", "index_k"), "dk", "mask", indent_width=8)}} + +@triton.jit +def bwd_dq_inner( + {{gen_argdefs()}}, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + {{gen_defines() | indent_except_first(1) }} + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = {{size("Q", 2)}} + KV_LEN = {{size("K", 2)}} + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + if not IS_DIVISIBLE: + if hi >= 1: + for start_n in range(0, hi - 1): + dq = bwd_dq_block_mn( + {{gen_argdefs()}}, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + dq = bwd_dq_block_mn( + {{gen_argdefs()}}, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + else: + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + {{gen_argdefs()}}, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + {{gen_argdefs()}}, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, +): + {{gen_defines() | indent_except_first(1)}} + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if CHECK_BLOCK_BOUNDARY else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds prior to the last loop + m = get_bounded_indices(offs_m2[:, None], Q_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None) + + {{ modification( + subgraph_number=0, + output_name="post_mod_scores", + score="qk", + b="off_z", + h="off_hq", + m="m", + n="n", + out="qk" + ) | indent_except_first(1) }} + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + {{ modification( + subgraph_number=2, + output_name="mask_mod_output", + score="qk", + b="off_z", + h="off_hq", + m="m", + n="n", + ) | indent_except_first(2) }} + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False) + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + {{ modification( + subgraph_number=1, + output_name = "grad_scores", + score="pre_mod_scores", + b="off_z", + h="off_hq", + m="m", + n="n", + grad_score_mod="ds" + ) | indent_except_first(1) }} + if CHECK_BLOCK_BOUNDARY: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = offs_m2[:, None] < Q_LEN and offs_n2[None, :] < KV_LEN + {{ modification( + subgraph_number=3, + output_name=None, + mask="scatter_mask", + score="pre_mod_scores", + b="off_z", + h="off_hq", + m="m", + n="n", + grad_score_mod="ds" + ) | indent_except_first(2) }} + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False) + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + {{gen_argdefs()}}, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + {{gen_defines() | indent_except_first(1) }} + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = {{size("Q", 2)}} + KV_LEN = {{size("K", 2)}} + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + if not IS_DIVISIBLE: + if hi >= 1: + for start_m in range(0, hi - 1): + dk, dv = bwd_dkdv_block_mn( + {{gen_argdefs()}}, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + + offs_m1 += offset + + dk, dv = bwd_dkdv_block_mn( + {{gen_argdefs()}}, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + else: + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + {{gen_argdefs()}}, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + {{gen_argdefs()}}, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, +): + {{gen_defines() | indent_except_first(1) }} + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if CHECK_BLOCK_BOUNDARY else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds prior to the last loop + n = get_bounded_indices(offs_n1[:, None], KV_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None) + + pre_mod_scores = qkT + {{ modification( + subgraph_number=0, + output_name="post_mod_scores", + score="qkT", + b="off_z", + h="off_hq", + m="m", + n="n", + out="qkT" + ) | indent_except_first(1) }} + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n1[:, None] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + {{ modification( + subgraph_number=2, + output_name="mask_mod_output", + score="qkT", + b="off_z", + h="off_hq", + m="m", + n="n", + ) | indent_except_first(2) }} + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False) + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + {{ modification( + subgraph_number=1, + output_name = "grad_scores", + score="pre_mod_scores", + b="off_z", + h="off_hq", + m="m", + n="n", + grad_score_mod="dsT" + ) | indent_except_first(1) }} + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = offs_m1[None, :] < Q_LEN and offs_n1[:, None] < KV_LEN + {{ modification( + subgraph_number=3, + output_name=None, + mask="scatter_mask", + score="pre_mod_scores", + b="idx_b", + h="idx_h", + m="idx_m", + n="idx_n", + grad_score_mod="dsT" + ) | indent_except_first(2) }} + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + if CHECK_BLOCK_BOUNDARY: + grad_scores = tl.where(offs_n1[:, None] < KV_LEN, grad_scores, 0.0) + + dsT = grad_scores + if not IS_FULL_BLOCKS: + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False) + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + """ + + compute_next_offset_func + + get_bounded_indices_func + + load_checked_2d, +) + + +def validate_joint_graph(joint_graph: torch.fx.Graph): + """We do some pre lowering graph checks in order to raise nicer error messages""" + for node in joint_graph.nodes: + if ( + node.op == "call_function" + and node.target == torch.ops.flex_lib.zeros_and_scatter.default + ): + for user in node.users: + if user.op != "output": + raise NotImplementedError( + "Using multiple indexing operations on the same tensor that requires gradients " + "in a score_mod function is not currently supported. " + "This typically happens when indexing the same tensor multiple times, like:\n\n" + " def score_mod(score, b, h, q_idx, kv_idx):\n" + " return score + bias[q_idx] + bias[kv_idx] # bias used twice!\n\n" + "A valid workaround is to clone() the tensors that will be indexed multiple times. For example:\n\n" + " bias1 = bias.clone()\n" + " def score_mod(score, b, h, q_idx, kv_idx):\n" + " return score + bias[q_idx] + bias1[kv_idx]\n\n" + "Note that this solution will use additional memory." + ) + return + + +@dataclass(frozen=True) +class JointOutputResult: + """Results from processing joint outputs.""" + + grad_input: ComputedBuffer + captured_grads_compute: list[ComputedBuffer] + captured_grads: list[Optional[TensorBox]] + mutated_grads: list[TensorBox] + + +def process_joint_outputs( + all_joint_outputs: SubgraphResults, num_placeholders: int +) -> JointOutputResult: + """Process joint outputs and extract various buffers needed for lowering + + Args: + all_joint_outputs: List of all the outputs from build_subgraphs + num_placeholders: The number of placeholder inputs, used to skip over unused backward compute buffers + + Returns: + JointOutputResult containing processed buffers and gradients + """ + assert isinstance(all_joint_outputs, list) + assert all_joint_outputs[0] is not None, ( + "joint_subgraph_buffer is None - this is a bug!" + ) + + joint_buffer = all_joint_outputs[0] + other_grads = all_joint_outputs[num_placeholders - 1 :] + + # outer_grads has the structure: Len(other_buffer_grads) if buffer doesn't require grad than it will be None + # We only grab the buffers that require grad for inlining into kernel + grads_compute = [buf for buf in other_grads if buf is not None] + + def get_out(buf): + if buf is None: + return None + assert isinstance(buf, ComputedBuffer) + assert buf.name is not None + return TensorBox.create(V.graph.get_buffer(buf.name)) + + grads_out = [get_out(x) for x in other_grads] + mutated_grads = [buf for buf in grads_out if buf is not None] + + return JointOutputResult( + grad_input=joint_buffer, + captured_grads_compute=grads_compute, + captured_grads=grads_out, + mutated_grads=mutated_grads, + ) + + +# TODO: We probably also need a layout constraint? +@register_lowering( + torch.ops.higher_order.flex_attention_backward, type_promotion_kind=None +) +def flex_attention_backward(*args, **kwargs): + ( + query, + key, + value, + out, + logsumexp, + grad_out, + grad_logsumexp, + fw_graph, + joint_graph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) = args + ( + _, # q_length + _, # kv_length + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + SPARSE_Q_BLOCK_SIZE, + SPARSE_KV_BLOCK_SIZE, + mask_graph, + ) = block_mask + + ( + query, + key, + value, + grad_out, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + ) = maybe_realize( + [ + query, + key, + value, + grad_out, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + ] + ) + + device = query.get_device() + dtype = query.get_dtype() + Bq, Hq, seq_len_q, qk_head_dim = query.get_size() + Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size() + + assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), ( + f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}" + ) + + kernel_options = dict(kernel_options) + # Mark symbols in custom kernel options as static shapes and add guards. + kernel_options = { + k: V.graph.sizevars.evaluate_static_shape(v) + if isinstance(v, sympy.Symbol) + else v + for k, v in kernel_options.items() + } + kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision()) + if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0: + kernel_options.setdefault("IS_DIVISIBLE", False) + else: + kernel_options.setdefault("IS_DIVISIBLE", True) + + fwd_placeholder_inps = [ + create_placeholder(name, dtype, device) + for name, dtype in [ + ("score", dtype), + ("b", torch.int32), + ("h", torch.int32), + ("m", torch.int32), + ("n", torch.int32), + ] + ] + fw_subgraph_buffer = build_subgraph_buffer( + fwd_placeholder_inps + list(score_mod_other_buffers), fw_graph + ) + + joint_placeholder_inps = fwd_placeholder_inps + [ + create_placeholder("grad_score_mod", dtype, device) + ] + # Sometimes we have weird unused nodes here + joint_graph.graph_module.graph.eliminate_dead_code() + + # It is hard to raise nice errors for some joint graphs during subgraph lowering + # This lets us do some checks before attempting to lower + validate_joint_graph(joint_graph.graph_module.graph) + + all_joint_outputs = build_subgraph_buffer( + joint_placeholder_inps + list(score_mod_other_buffers), + joint_graph, + ) + + joint_outputs = process_joint_outputs( + all_joint_outputs, len(joint_placeholder_inps) + ) + + mask_graph_placeholder_inps = [ + create_placeholder(name, dtype, query.get_device()) + for name, dtype in [ + ("b", torch.int32), + ("h", torch.int32), + ("m", torch.int32), + ("n", torch.int32), + ] + ] + mask_graph_buffer = build_subgraph_buffer( + mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph + ) + + mask_graph_buffer = mask_graph_buffer + + # Construct layout with stride order matching K + key_size = [Bq, Hkv, seq_len_kv, qk_head_dim] + key_strides = infer_dense_strides(key_size, key.get_stride()) + + layout_broadcasted_k = FixedLayout( + key.get_device(), + key.get_dtype(), + key_size, + stride=[sympy.sympify(s) for s in key_strides], + ) + + # Create delta which will is needed for the bwd's kernel + grad_lse_exp2 = lowerings[aten.mul](grad_logsumexp, 1 / math.log(2)) + mul_delta = lowerings[aten.mul](out, grad_out) + delta = lowerings[aten.sum](mul_delta, axis=-1) + delta = lowerings[aten.sub](delta, grad_lse_exp2) + delta = ExternKernel.require_contiguous(delta) + + grad_lse_exp2, delta = maybe_realize([grad_lse_exp2, delta]) + + # # see NOTE:[TritonTemplates with multiple outputs] + query_size = [Bq, Hq, seq_len_q, qk_head_dim] + grad_query_strides = infer_dense_strides(query_size, query.get_stride()) + grad_query = empty_strided( + query_size, + stride=[sympy.sympify(s) for s in grad_query_strides], + dtype=query.get_dtype(), + device=query.get_device(), + ) + + # Construct output layout with stride order matching value + value_size = [Bq, Hkv, seq_len_kv, v_head_dim] + value_strides = infer_dense_strides(value_size, value.get_stride()) + + broadcasted_grad_value = empty_strided( + value_size, + stride=[sympy.sympify(s) for s in value_strides], + dtype=value.get_dtype(), + device=value.get_device(), + ) + + kernel_options.setdefault("SM_SCALE", scale) + + # Determine GQA factor + gqa_shared_heads = Hq // Hkv + kernel_options.setdefault("GQA_SHARED_HEADS", gqa_shared_heads) + + # Inside of Triton kernel, only apply partial masking if partial blocks are computed. + # full_kv_num_blocks is torch.zeros([1, 1, 1]) if partial blocks are not computed. + has_full_blocks = full_kv_num_blocks is not None + kernel_options.setdefault("HAS_FULL_BLOCKS", has_full_blocks) + if not has_full_blocks: + full_kv_num_blocks, full_kv_indices, full_q_num_blocks, full_q_indices = ( + empty(0, device=query.get_device()) for _ in range(4) + ) + + set_head_dim_values(kernel_options, qk_head_dim, v_head_dim, V.graph.sizevars) + + SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE) + SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE) + + choices: list[Any] = [] + + dtype = query.get_dtype() + head_dim = V.graph.sizevars.evaluate_static_shape(query.get_size()[-1]) + configs = V.choices.get_flex_attention_bwd_configs(head_dim, dtype) + + # Default config for warp specialization + num_consumer_groups, num_buffers_warp_spec = 0, 0 + + original_kernel_options = kernel_options.copy() + for conf in configs: + if ( + SPARSE_KV_BLOCK_SIZE % conf.block_m != 0 + or SPARSE_Q_BLOCK_SIZE % conf.block_m != 0 + or SPARSE_KV_BLOCK_SIZE % conf.block_n != 0 + or SPARSE_Q_BLOCK_SIZE % conf.block_n != 0 + ): + continue + + # Performance tuning + # Triton heuristics + cur_kernel_options = original_kernel_options.copy() + # Remove prefix for backward kernels options and delete forward kernel options. + for k in list(cur_kernel_options.keys()): + if k.startswith("bwd_"): + v = cur_kernel_options.pop(k) + cur_kernel_options[k[4:]] = v + if k.startswith("fwd_"): + cur_kernel_options.pop(k) + cur_kernel_options.setdefault("num_warps", conf.num_warps) + cur_kernel_options.setdefault("num_stages", conf.num_stages) + + if cur_kernel_options.get("num_consumer_groups", False): + cur_kernel_options.setdefault("num_consumer_groups", num_consumer_groups) + cur_kernel_options.setdefault( + "num_buffers_warp_spec", num_buffers_warp_spec + ) + + cur_kernel_options.setdefault("BLOCK_M1", conf.block_m) + cur_kernel_options.setdefault("BLOCK_N1", conf.block_n) + cur_kernel_options.setdefault("BLOCK_M2", conf.block_n) + cur_kernel_options.setdefault("BLOCK_N2", conf.block_m) + + # Blocksparse options + cur_kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE) + cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE) + + # ROCm specific kernargs + for attrib in ["kpack", "matrix_instr_nonkdim", "waves_per_eu"]: + if hasattr(conf, attrib): + cur_kernel_options[attrib] = getattr(conf, attrib) + + flex_attention_backward_template.maybe_append_choice( + choices=choices, + input_nodes=[ + query, + key, + value, + logsumexp, + delta, + grad_out, + grad_query, + broadcasted_grad_value, + kv_num_blocks, + kv_indices, + q_num_blocks, + q_indices, + full_kv_num_blocks, + full_kv_indices, + full_q_num_blocks, + full_q_indices, + ], + layout=layout_broadcasted_k, # We use store_output only for grad_key + subgraphs=[ + fw_subgraph_buffer, + joint_outputs.grad_input, + mask_graph_buffer, + joint_outputs.captured_grads_compute, + ], + mutated_inputs=[ + grad_query, + broadcasted_grad_value, + *joint_outputs.mutated_grads, + ], + call_sizes=query.get_size() + key.get_size()[1:3], + **cur_kernel_options, + ) + inputs_for_autotuning = ( + [ + query, + key, + value, + logsumexp, + delta, + grad_out, + grad_query, + broadcasted_grad_value, + kv_num_blocks, + kv_indices, + q_num_blocks, + q_indices, + full_kv_num_blocks, + full_kv_indices, + full_q_num_blocks, + full_q_indices, + ] + + list(score_mod_other_buffers) + + list(mask_mod_other_buffers) + + joint_outputs.mutated_grads + ) + input_gen_fns = { + 8: create_num_blocks_fake_generator(kv_indices), # kv_num_blocks + 9: create_indices_fake, + 10: create_num_blocks_fake_generator(q_indices), # q_num_blocks + 11: create_indices_fake, + 12: create_num_blocks_fake_generator(full_kv_indices), # full_kv_num_blocks + 13: create_indices_fake, + 14: create_num_blocks_fake_generator(full_q_indices), # full_q_num_blocks + 15: create_indices_fake, + } + + broadcasted_grad_key = autotune_select_algorithm( + "flex_attention_backward", + choices, + [x for x in inputs_for_autotuning if isinstance(x, torch._inductor.ir.IRNode)], + layout_broadcasted_k, + input_gen_fns=input_gen_fns, + ) # [Bq, Hkv, seq_len_kv, k_head_dim] + + # need subgraph inputs and outputs to analyze all symints used in flex attention + broadcasted_grad_key.data.data.subgraph_inps = list(score_mod_other_buffers) + list( + mask_mod_other_buffers + ) + broadcasted_grad_key.data.data.subgraph_outs = get_bwd_subgraph_outputs( + fw_subgraph_buffer, mask_graph_buffer, joint_outputs + ) + + if V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv)): + grad_key = broadcasted_grad_key + grad_value = broadcasted_grad_value + else: + assert V.graph.sizevars.evaluate_expr(sympy.Gt(Bq, 1) & sympy.Eq(Bkv, 1)), ( + f"Bq and Bkv must broadcastable. " + f"Got Bq={V.graph.sizevars.evaluate_expr(Bq)} " + f"and Bkv={V.graph.sizevars.evaluate_expr(Bkv)}" + ) + grad_key = lowerings[aten.sum](broadcasted_grad_key, axis=0, keepdims=True) + grad_value = lowerings[aten.sum](broadcasted_grad_value, axis=0, keepdims=True) + + return (grad_query, grad_key, grad_value, tuple(joint_outputs.captured_grads)) + + +def get_bwd_subgraph_outputs( + subgraph_buffer: SubgraphResults, + mask_graph_buffer: SubgraphResults, + joint_outputs: JointOutputResult, +) -> list[Optional[Union[ComputedBuffer, TensorBox]]]: + subgraph_buffer = ( + subgraph_buffer if isinstance(subgraph_buffer, Sequence) else [subgraph_buffer] + ) + mask_graph_buffer = ( + mask_graph_buffer + if isinstance(mask_graph_buffer, Sequence) + else [mask_graph_buffer] + ) + joint_output_buffers = [ + joint_outputs.grad_input, + *joint_outputs.captured_grads_compute, + *joint_outputs.captured_grads, + *joint_outputs.mutated_grads, + ] + + return [*subgraph_buffer, *mask_graph_buffer, *joint_output_buffers] diff --git a/torch/_inductor/kernel/flex_decoding.py b/torch/_inductor/kernel/flex_decoding.py new file mode 100644 index 0000000000000..0c663421a036e --- /dev/null +++ b/torch/_inductor/kernel/flex_decoding.py @@ -0,0 +1,628 @@ +# mypy: allow-untyped-defs +"""Triton Implementation of the flex_attention Kernel for short query length (FlexDecoding)""" + +from typing import Any + +import sympy + +import torch +from torch._inductor.virtualized import V + +from .. import ir +from ..ir import FixedLayout, FlexibleLayout +from ..lowering import empty, empty_strided, lowerings +from ..runtime.runtime_utils import is_power_of_2, next_power_of_2 +from ..select_algorithm import autotune_select_algorithm, SymbolicGridFn, TritonTemplate +from .flex_attention import ( + compute_forward_block_mn, + compute_forward_inner, + compute_next_offset_func, + create_indices_fake, + create_num_blocks_fake_generator, + get_bounded_indices_func, + get_fwd_subgraph_outputs, + load_checked_2d, + load_checked_block, + maybe_realize, +) + + +aten = torch.ops.aten +prims = torch.ops.prims + + +@SymbolicGridFn +def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, meta): + """How is this kernel parallelized? + We create a grid of (batch_size * kv_heads, SPLIT_KV, 1) + Each block is responsible for iterating over blocks of keys and values calculating + the local output for their tile of keys and values over all full length of query. + groups of SPLIT_KV blocks then combine their output to produce the final result. + """ + + return (batch_size * kv_heads, meta["SPLIT_KV"], 1) + + +flex_decoding_template = TritonTemplate( + name="flex_decoding", + grid=flex_decoding_grid, + source=r""" + {{def_kernel("Q", "K", "V", "M", "L", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}} + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = {{stride("Q")}} + stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}} + stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}} + stride_mz, stride_mt, stride_mh, stride_mm = {{stride("M")}} + stride_lz, stride_lt, stride_lh, stride_lm = {{stride("L")}} + + + Z = {{size("Q", 0)}} + ZKV = {{size("K", 0)}} + HKV = {{size("Q", 1)}} + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = {{size("Q", 3)}} + KV_LEN = {{size("K", 2)}} + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0) % HKV + off_t = tl.program_id(1) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} + SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = {{stride("KV_NUM_BLKS")}} + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = {{stride("KV_IDX")}} + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. + block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Apply both score_mod and mask_mod + + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + indices_idx = block_n_start // SPARSE_KV_MULTIPLE + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + K_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(QK_HEAD_DIM, KV_LEN), # (d, N) + strides=(stride_kk, stride_kn), + offsets=(0, off_n), + block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V + v_offset, + shape=(KV_LEN, V_HEAD_DIM), + strides=(stride_vn, stride_vk), + offsets=(off_n, 0), + block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED), + order=(1, 0) + ) + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, None, None, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + None, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. + block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = block_n_start // SPARSE_KV_MULTIPLE + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + K_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(QK_HEAD_DIM, KV_LEN), # (d, N) + strides=(stride_kk, stride_kn), + offsets=(0, off_n), + block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V + v_offset, + shape=(KV_LEN, V_HEAD_DIM), + strides=(stride_vn, stride_vk), + offsets=(off_n, 0), + block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED), + order=(1, 0) + ) + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, None, None, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + None, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + {{store_output(("idx_z", "idx_t", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}} + """ + + compute_forward_inner + + get_bounded_indices_func + + load_checked_block + + load_checked_2d + + compute_next_offset_func + + compute_forward_block_mn, +) + + +def get_split_k(B: int, H: int, Mk: int) -> int: + num_SM = torch.cuda.get_device_properties("cuda").multi_processor_count + bh = max(B * H, 1) # NOTE: Handle B*h=0 case + assert isinstance(bh, (int, sympy.Integer)), "B and H must be concrete integers" + split_k = num_SM // bh * 2 # Each SM should at least get one block. + # TODO: workload evening at runtime for splits fully masked out. + # Before we have runtime workload evening, assign 2 splits per SM. + split_k = max(split_k, 1) + + return split_k + + +def create_flex_decoding_kernel(*args, **kwargs): + from .flex_attention import set_head_dim_values + + ( + query, + key, + value, + block_mask, + scale, + kernel_options, + score_mod_subgraph, + mask_mod_subgraph, + score_mod_other_buffers, + mask_mod_other_buffers, + ) = args + ( + _, # q_length + _, # kv_length + kv_num_blocks, + kv_indices, + full_kv_num_blocks, # full_kv_num_blocks, + full_kv_indices, # full_kv_indices, + _, # q_num_blocks + _, # q_indices + _, # full_q_num_blocks, + _, # full_q_indices, + _, # SPARSE_Q_BLOCK_SIZE, + SPARSE_KV_BLOCK_SIZE, + _, + ) = block_mask + + Bq, Hq, seq_len_q, qk_head_dim = query.get_size() + Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size() + + assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), ( + f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}" + ) + + B = Bq + kernel_options = dict(kernel_options) + # Mark symbols in custom kernel options as static shapes and add guards. + kernel_options = { + k: V.graph.sizevars.evaluate_static_shape(v) + if isinstance(v, sympy.Symbol) + else v + for k, v in kernel_options.items() + } + + # TODO: Fix flex decoding non-divisible case! + if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0: + kernel_options.setdefault("IS_DIVISIBLE", False) + else: + kernel_options.setdefault("IS_DIVISIBLE", True) + + # Calculate GQA head sharing + gqa_shared_heads = Hq // Hkv + if not is_power_of_2(gqa_shared_heads): + raise ValueError( + "Number of shared query heads sharing the same KV head must be power of 2. " + ) + kernel_options.setdefault("GQA_SHARED_HEADS", gqa_shared_heads) + + # Determine if there are "full" blocks where we only need to apply score_mod, and can skip mask_mod + has_full_blocks = full_kv_num_blocks is not None + kernel_options.setdefault("HAS_FULL_BLOCKS", has_full_blocks) + if not has_full_blocks: + # Create a plackeholder full block list in case it is empty + full_kv_num_blocks, full_kv_indices = ( + empty(0, device=query.get_device()) for _ in range(2) + ) + + ( + query, + key, + value, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + ) = maybe_realize( + [ + query, + key, + value, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + ] + ) + score_mod_other_buffers = maybe_realize(score_mod_other_buffers) + mask_mod_other_buffers = maybe_realize(mask_mod_other_buffers) + + choices: list[Any] = [] + dtype = key.get_dtype() + head_dim = V.graph.sizevars.evaluate_static_shape(key.get_size()[-1]) + configs = V.choices.get_flex_decode_configs(head_dim, dtype) + + # TODO: fix autotuning. + + kernel_options.setdefault("SM_SCALE", scale) + kernel_options.setdefault("SPLIT_KV", get_split_k(B, Hkv, seq_len_kv)) + MAX_SPLIT_KV = kernel_options["SPLIT_KV"] + + # create config dependent intermediate buffers + buf_ACC_shape = [B, MAX_SPLIT_KV, Hq, seq_len_q, v_head_dim] + buf_ML_shape = buf_ACC_shape[:-1] + buf_M = empty_strided( + buf_ML_shape, + None, + dtype=torch.float32, # The rowmax is always stored in fp32 regardless of the input dtype + device=query.get_device(), + ) + buf_L = empty_strided( + buf_ML_shape, + None, + dtype=torch.float32, # The intermediate sumexp is always stored in fp32 regardless of the input dtype + device=query.get_device(), + ) + + layout_acc = FixedLayout( + query.get_device(), + torch.float32, + buf_ACC_shape, + FlexibleLayout.contiguous_strides(buf_ACC_shape), + ) + + set_head_dim_values(kernel_options, qk_head_dim, v_head_dim, V.graph.sizevars) + + kernel_options.setdefault( + "BLOCK_M", + ( + # m + # if V.graph.sizevars.evaluate_expr(sympy.Lt(query.get_size()[-2], 0)) + # else # Always use a BLOCK_M > 16 before Triton fix https://github.com/triton-lang/triton/pull/4061 is in pin + max( + next_power_of_2( + V.graph.sizevars.size_hint( + seq_len_q, + fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type] + ) + * gqa_shared_heads + ), + 16, + ) + ), + ) + + query = ir.ExternKernel.realize_input(query) + stride_b, stride_hq, stride_seq_len_q, stride_qk_head_dim = query.get_stride() + + # Reshape query for GQA: [B, Hq, Mq, D] -> [B, Hkv, G, Mq, D] + gqa_query_shape = (B, Hkv, gqa_shared_heads, seq_len_q, qk_head_dim) + gqa_query_stride = ( + stride_b, + stride_hq * gqa_shared_heads, + stride_hq, + stride_seq_len_q, + stride_qk_head_dim, + ) + query = lowerings[aten.as_strided](query, gqa_query_shape, gqa_query_stride) + + V.graph.sizevars.guard_leq( + seq_len_q * gqa_shared_heads, sympy.Integer(kernel_options["BLOCK_M"]) + ) + + kernel_options.setdefault( + "SAFE_M_BOUNDARY", + ((seq_len_q * gqa_shared_heads) % kernel_options["BLOCK_M"]) == 0, + ) + # TODO: This feels sketchy + kernel_options.setdefault("SAFE_N_BOUNDARY", True) + # Mark SPARSE_KV_BLOCK_SIZE as static shapes and add guards. + SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE) + + original_kernel_options = kernel_options.copy() + # Note, we don't need to pass in the captured buffers explicitly + # because they're implicitly added by the score_mod function + # We do need to explicitly pass it in for autotuning though. + + # Default config for warp specialization + num_consumer_groups, num_buffers_warp_spec = 0, 0 + + for conf in configs: + if SPARSE_KV_BLOCK_SIZE % conf.block_n != 0: + continue + + cur_kernel_options = original_kernel_options.copy() + # Remove prefix for forward kernels options and delete backward kernel options. + for k in list(cur_kernel_options.keys()): + if k.startswith("fwd_"): + v = cur_kernel_options.pop(k) + cur_kernel_options[k[4:]] = v + if k.startswith("bwd_"): + cur_kernel_options.pop(k) + # Performance tuning + cur_kernel_options.setdefault("BLOCK_N", conf.block_n) + cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE) + cur_kernel_options.setdefault("num_warps", conf.num_warps) + cur_kernel_options.setdefault("num_stages", conf.num_stages) + + if cur_kernel_options.get("num_consumer_groups", False): + cur_kernel_options.setdefault("num_consumer_groups", num_consumer_groups) + cur_kernel_options.setdefault( + "num_buffers_warp_spec", num_buffers_warp_spec + ) + + # Set default to False + cur_kernel_options.setdefault("USE_TMA", False) + + # Add ROCm-specific parameters if they exist in the config + for attrib in ["kpack", "matrix_instr_nonkdim", "waves_per_eu"]: + if hasattr(conf, attrib): + cur_kernel_options[attrib] = getattr(conf, attrib) + + flex_decoding_template.maybe_append_choice( + choices=choices, + input_nodes=[ + query, + key, + value, + buf_M, + buf_L, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + ], + layout=layout_acc, + subgraphs=[ + score_mod_subgraph, + mask_mod_subgraph, + ], + mutated_inputs=[buf_M, buf_L], + call_sizes=query.get_size(), + **cur_kernel_options, + ) + + inputs_for_flex_decoding = ( + [ + query, + key, + value, + buf_M, + buf_L, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + ] + + list(score_mod_other_buffers) + + list(mask_mod_other_buffers) + ) + + input_gen_fns = { + 5: create_num_blocks_fake_generator(kv_indices), + 6: create_indices_fake, + 7: create_num_blocks_fake_generator(full_kv_indices), + 8: create_indices_fake, + } + + buf_ACC = autotune_select_algorithm( + "flex_decoding", + choices, + inputs_for_flex_decoding, + layout_acc, + input_gen_fns=input_gen_fns, + ) + + # need subgraph inputs and outputs to analyze all symints used in flex attention + buf_ACC.data.data.subgraph_inps = list(score_mod_other_buffers) + list( + mask_mod_other_buffers + ) + buf_ACC.data.data.subgraph_outs = get_fwd_subgraph_outputs( + score_mod_subgraph, mask_mod_subgraph + ) + + # Reduction + + g_M = lowerings[aten.max](buf_M, dim=1, keepdim=True)[0] + # See [Note] Handle fully masked out rows: + # g_M Is the global max among split kv blocks. + masked_rows = lowerings[aten.eq](g_M, -float("inf")) + adj_M = lowerings[aten.sub](buf_M, g_M) + adj_M = lowerings[aten.where](masked_rows, 0, adj_M) + alpha = lowerings[aten.exp2](adj_M) + + buf_L = lowerings[aten.mul](buf_L, alpha) + g_L = lowerings[aten.sum](buf_L, axis=1) + masked_rows_squeezed = lowerings[aten.squeeze](masked_rows, dim=1) + g_L = lowerings[aten.where](masked_rows_squeezed, 1.0, g_L) + logsumexp = lowerings[aten.log2](g_L) + logsumexp = lowerings[aten.add](logsumexp, lowerings[aten.squeeze](g_M, dim=1)) + + alpha_unseq = lowerings[aten.unsqueeze](alpha, 4) + buf_ACC = lowerings[aten.mul](buf_ACC, alpha_unseq) + output = lowerings[aten.sum](buf_ACC, axis=1) + L_unseq = lowerings[aten.unsqueeze](g_L, 3) + output = lowerings[aten.div](output, L_unseq) + output = lowerings[prims.convert_element_type](output, query.get_dtype()) + + return ( + output, + logsumexp, + ) diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index a597107510e78..88ad7e5c5ccd4 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -3,6 +3,11 @@ import logging from typing import Any, Optional +<<<<<<< HEAD +======= +import sympy + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch from torch._dynamo.utils import counters from torch._inductor.autoheuristic.autoheuristic import AutoHeuristicSelectAlgorithm @@ -13,11 +18,15 @@ mm_operations, ) from torch._inductor.codegen.cpp_gemm_template import CppGemmTemplate +<<<<<<< HEAD from torch._inductor.remote_gemm_autotune_cache import gen_best_config +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.virtualized import V from torch.fx.experimental.proxy_tensor import make_fx from torch.torch_version import TorchVersion +<<<<<<< HEAD from .. import config as inductor_config from ..codegen.cuda.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate from ..codegen.rocm.ck_tile_universal_gemm_template import CKTileGemmTemplate @@ -26,6 +35,20 @@ from ..ir import Buffer, ChoiceCaller, FlexibleLayout, is_triton, Layout from ..kernel_inputs import MMKernelInputs from ..lowering import add_layout_constraint, constrain_to_fx_strides, register_lowering +======= +from .. import config as inductor_config, ir +from ..codegen.cuda.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate +from ..codegen.rocm.ck_tile_universal_gemm_template import CKTileGemmTemplate +from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate +from ..codegen.subgraph import SubgraphTemplate +from ..ir import FlexibleLayout, is_triton +from ..lowering import ( + add_layout_constraint, + constrain_to_fx_strides, + lowerings as L, + register_lowering, +) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from ..select_algorithm import ( autotune_select_algorithm, ExternKernelChoice, @@ -34,6 +57,11 @@ ) from ..utils import ( _use_cutlass_for_op, +<<<<<<< HEAD +======= + get_k_splits, + get_tma_workspace_arg, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) use_aten_gemm_kernels, use_ck_gemm_template, use_ck_tile_gemm_template, @@ -43,7 +71,22 @@ use_triton_template, use_triton_tma_template, ) +<<<<<<< HEAD from .mm_common import _is_static_problem, mm_args, mm_grid, persistent_mm_grid +======= +from .mm_common import ( + _is_static_problem, + addmm_epilogue, + mm_args, + mm_config_kwargs, + mm_grid, + mm_options, + persistent_mm_grid, + persistent_mm_options, + scale_mm_epilogue, + scaled_mm_options, +) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: @@ -92,11 +135,19 @@ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) +<<<<<<< HEAD if ((stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1)) and (M >= BLOCK_M and K > 1): offs_a_m = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) else: offs_a_m = rm % M if ((stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1)) and (N >= BLOCK_N and K > 1): +======= + if ((stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1)) and M >= BLOCK_M: + offs_a_m = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + else: + offs_a_m = rm % M + if ((stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1)) and N >= BLOCK_N: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) offs_b_n = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) else: offs_b_n = rn % N @@ -243,11 +294,18 @@ rk_for_mask = tl.arange(0, BLOCK_K) acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) +<<<<<<< HEAD {%- if TMA_EXPERIMENTAL_API %} +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) workspace_base = ws_ptr + start_pid * 2 * TMA_SIZE a_desc_ptr = workspace_base b_desc_ptr = workspace_base + TMA_SIZE +<<<<<<< HEAD +======= + {%- if TMA_EXPERIMENTAL_API %} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) triton.language.extra.cuda.experimental_device_tensormap_create2d( desc_ptr=a_desc_ptr, global_address=A, @@ -266,6 +324,7 @@ tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) +<<<<<<< HEAD {%- else %} stride_am = {{stride("A", 0)}} stride_ak = {{stride("A", 1)}} @@ -275,12 +334,25 @@ base=A, shape=[M, K] if A_ROW_MAJOR else [K, M], strides=[stride_am, 1] if A_ROW_MAJOR else [stride_ak, 1], +======= + a_desc = a_desc_ptr + b_desc = b_desc_ptr + {%- else %} + a_desc = triton.language.make_tensor_descriptor( + base=A, + shape=[M, K] if A_ROW_MAJOR else [K, M], + strides=[K, 1] if A_ROW_MAJOR else [M, 1], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) block_shape=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], ) b_desc = triton.language.make_tensor_descriptor( base=B, shape=[K, N] if B_ROW_MAJOR else [N, K], +<<<<<<< HEAD strides=[stride_bk, 1] if B_ROW_MAJOR else [stride_bn, 1], +======= + strides=[N, 1] if B_ROW_MAJOR else [K, 1], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) block_shape=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], ) {%- endif %} @@ -307,13 +379,21 @@ {%- if TMA_EXPERIMENTAL_API %} a = tl._experimental_descriptor_load( +<<<<<<< HEAD a_desc_ptr, +======= + a_desc, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) [rm, rk] if A_ROW_MAJOR else [rk, rm], [BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], A.dtype.element_ty, ) b = tl._experimental_descriptor_load( +<<<<<<< HEAD b_desc_ptr, +======= + b_desc, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) [rk, rn] if B_ROW_MAJOR else [rn, rk], [BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], B.dtype.element_ty, @@ -425,11 +505,18 @@ def apply_scaling( k_tiles = tl.cdiv(K, BLOCK_K) num_tiles = num_pid_m * num_pid_n +<<<<<<< HEAD {%- if TMA_EXPERIMENTAL_API %} +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) workspace_base = ws_ptr + start_pid * 2 * TMA_SIZE a_desc_ptr = workspace_base b_desc_ptr = workspace_base + TMA_SIZE +<<<<<<< HEAD +======= + {%- if TMA_EXPERIMENTAL_API %} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) triton.language.extra.cuda.experimental_device_tensormap_create2d( desc_ptr=a_desc_ptr, global_address=A, @@ -448,6 +535,7 @@ def apply_scaling( tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) +<<<<<<< HEAD {%- else %} stride_am = {{stride("A", 0)}} stride_bn = {{stride("B", 1)}} @@ -455,12 +543,25 @@ def apply_scaling( base=A, shape=[M, K], strides=[stride_am, 1], +======= + a_desc = a_desc_ptr + b_desc = a_desc_ptr + {%- else %} + a_desc = triton.language.make_tensor_descriptor( + base=A, + shape=[M, K], + strides=[K, 1], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) block_shape=[BLOCK_M, BLOCK_K], ) b_desc = triton.language.make_tensor_descriptor( base=B, shape=[N, K], +<<<<<<< HEAD strides=[stride_bn, 1], +======= + strides=[K, 1], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) block_shape=[BLOCK_N, BLOCK_K], ) {%- endif %} @@ -552,6 +653,7 @@ def lazy_register_extern_choice(fn): return ExternKernelChoice(fn) +<<<<<<< HEAD aten_mm = ExternKernelChoice(torch.mm, "at::mm_out", op_overload=aten.mm.out) aten_addmm = ExternKernelChoice( @@ -561,12 +663,24 @@ def lazy_register_extern_choice(fn): aten__int_mm = ExternKernelChoice( torch._int_mm, "at::_int_mm_out", op_overload=aten._int_mm.out ) +======= +aten_mm = ExternKernelChoice(torch.mm, "at::mm_out") + +aten_addmm = ExternKernelChoice( + torch.addmm, "at::addmm_out", op_overload=aten.addmm.default +) + +aten__int_mm = ExternKernelChoice(torch._int_mm, "at::_int_mm_out") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) aten__sparse_semi_structured_mm = ExternKernelChoice( torch._sparse_semi_structured_mm, "at::_sparse_semi_structured_mm", has_out_variant=False, +<<<<<<< HEAD op_overload=aten._sparse_semi_structured_mm.default, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) aten__fp8_mm = ExternKernelChoice( @@ -578,13 +692,35 @@ def _is_int8_mat(mat): return mat.get_dtype() in (torch.int8, torch.uint8) +<<<<<<< HEAD +======= +def _is_large_block_for_cpu(m, n, k): + # Thresholds are experimentally determined to reduce Triton CPU compile times + return m * n > 2**13 + + +@functools.lru_cache +def using_b200() -> bool: + """Returns true if the device is a NVIDIA B200, otherwise returns false.""" + if not torch.cuda.is_available(): + return False + # compute capability 10.0 or 10.0a is NVIDIA B200 + device_properties = torch.cuda.get_device_properties(torch.cuda.current_device()) + return device_properties.major == 10 + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1): """ Giving torch.addmm a 1D tensor calls a different (faster) cublasLt kernel under the hood. There are a few shapes where this is slower, but they are rare. """ +<<<<<<< HEAD if (inp.stride(0) == 0 and inp.size(0) != 0) or inp.size(0) == 1: +======= + if inp.stride(0) == 0 or inp.size(0) == 1: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return torch.addmm(inp[0], mat1, mat2, out=out, alpha=alpha, beta=beta) return torch.addmm(inp, mat1, mat2, out=out, alpha=alpha, beta=beta) @@ -632,6 +768,7 @@ def decomposeK(a, b, k_splits): return reduced_buf.to(a.dtype) +<<<<<<< HEAD class DecomposeKSugraphTemplate(SubgraphTemplate): def __init__(self): super().__init__( @@ -720,11 +857,14 @@ def contiguous_addmm(inp, a, b): ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register_lowering(aten.mm, type_promotion_kind=None) def tuned_mm(mat1, mat2, *, layout=None): """ Lowering for autotuning aten.mm with different backends (Aten, Triton, CUTLASS, etc.) """ +<<<<<<< HEAD # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout) static_shape, is_nonzero = _is_static_problem(layout) @@ -733,6 +873,12 @@ def tuned_mm(mat1, mat2, *, layout=None): # Create MMKernelInputs for standard MM at the top kernel_inputs = MMKernelInputs([mat1, mat2]) +======= + m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout) + device_type = ir.get_device_type(mat1) + name = "mm" + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # below is for getting an overview logging info of inductor mms counters["aten_mm_info"][f"aten.mm_{m}_{n}_{k}"] += 1 log.info( @@ -750,6 +896,7 @@ def tuned_mm(mat1, mat2, *, layout=None): aten_layout = FlexibleLayout( device=layout.device, dtype=layout.dtype, size=layout.size ) +<<<<<<< HEAD choices: list[ChoiceCaller] = [] if use_aten_gemm_kernels(): choices.extend( @@ -781,12 +928,103 @@ def tuned_mm(mat1, mat2, *, layout=None): kernel_inputs, layout, [mm_contiguous_subgraph_template], "mm" ) ) +======= + + # options to tune from + choices = ( + [aten_mm.bind((mat1, mat2), aten_layout)] if use_aten_gemm_kernels() else [] + ) + static_shape, is_nonzero = _is_static_problem(layout) + + mm_configs = V.choices.get_base_mm_configs(device_type) + persistent_mm_configs = V.choices.get_persistent_mm_configs(device_type) + extra_mm_configs = V.choices.get_extra_mm_configs(device_type) + + dtype = mat1.get_dtype() + if is_nonzero and use_triton_template(layout): + for config in mm_configs( + m, + n, + k, + **mm_config_kwargs(device_type, _is_large_block_for_cpu, dtype.itemsize), + ): + mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + ) + + if use_triton_tma_template(mat1, mat2): + for config in persistent_mm_configs( + m, + n, + k, + **mm_config_kwargs( + device_type, _is_large_block_for_cpu, dtype.itemsize + ), + ): + persistent_tma_mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + workspace_arg=get_tma_workspace_arg( + num_tma_descriptors=2, + device=mat1.get_device(), + ), + **mm_options(config, m, n, k, layout), + **persistent_mm_options(mat1, mat2), + ) + + from torch._inductor.ir import get_free_symbols + + # Only do split-k optimization if K is much larger than m, n and m, n are small + # and if there aren't any unbacked symbols + unbacked_symbols = any( + len(get_free_symbols(itr, unbacked_only=True)) > 0 + for itr in ( + mat1.get_size(), + mat1.get_stride(), + mat2.get_size(), + mat2.get_stride(), + ) + ) + if use_decompose_k_choice(m, n, k) and not unbacked_symbols: + from torch._dispatch.python import enable_python_dispatcher + + from ..decomposition import select_decomp_table + + k_splits = get_k_splits(m, n, k) + for k_split in k_splits: + if not V.graph.sizevars.statically_known_true( + sympy.Eq(sympy.Mod(k, k_split), 0) + ): + continue + + with enable_python_dispatcher(): + decompositions = select_decomp_table() + + decompose_k_subgraph_template = SubgraphTemplate( + name=f"decompose_k_mm_{k_split}_split", + make_fx_graph=make_fx( + functools.partial(decomposeK, k_splits=k_split), + decompositions, + ), + ) + + decompose_k_subgraph_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if ( is_nonzero and use_cutlass_template(layout, m, n, k) and _use_cutlass_for_op("mm") ): +<<<<<<< HEAD CUTLASS3xGemmTemplate.add_cutlass_gemm_choices( choices, layout, kernel_inputs.nodes() ) @@ -795,12 +1033,24 @@ def tuned_mm(mat1, mat2, *, layout=None): CKGemmTemplate.add_ck_gemm_choices(choices, layout, kernel_inputs.nodes()) if is_nonzero and use_ck_tile_gemm_template(layout, m, n, k): CKTileGemmTemplate.add_choices(choices, layout, kernel_inputs.nodes()) +======= + CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2]) + + if is_nonzero and use_ck_gemm_template(layout, m, n, k): + CKGemmTemplate.add_ck_gemm_choices(choices, layout, [mat1, mat2]) + if is_nonzero and use_ck_tile_gemm_template(layout, m, n, k): + CKTileGemmTemplate.add_choices(choices, layout, [mat1, mat2]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if use_cpp_gemm_template(layout, mat1, mat2): CppGemmTemplate.add_choices( choices, layout, +<<<<<<< HEAD kernel_inputs.nodes(), +======= + [mat1, mat2], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) input_nodes = [mat1, mat2] @@ -814,6 +1064,7 @@ def tuned_mm(mat1, mat2, *, layout=None): if use_aten_gemm_kernels(): always_included.append("extern_mm") num_choices_before_extra_configs = len(choices) +<<<<<<< HEAD choices.extend( V.choices.get_mm_configs( # TODO(coconutruben): remove once we deprecate ah @@ -825,6 +1076,17 @@ def tuned_mm(mat1, mat2, *, layout=None): "mm-ah", ) ) +======= + for config in extra_mm_configs( + m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu) + ): + mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # using AutoHeuristic for ranking ah_choices = mm_autoheuristic( @@ -853,6 +1115,7 @@ def tuned_mm(mat1, mat2, *, layout=None): choices = choices[:num_choices_before_extra_configs] for k in inductor_config.external_matmul: +<<<<<<< HEAD choices.append( lazy_register_extern_choice(k).bind(kernel_inputs.nodes(), layout) ) @@ -870,15 +1133,27 @@ def tuned_mm(mat1, mat2, *, layout=None): layout, best_config_future=best_config_future, ) +======= + choices.append(lazy_register_extern_choice(k).bind((mat1, mat2), layout)) + + return autotune_select_algorithm(name, choices, [mat1, mat2], layout) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register_lowering(aten._int_mm, type_promotion_kind=None) def tuned_int_mm(mat1, mat2, *, layout=None): +<<<<<<< HEAD # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that m, n, k, layout, mat1, mat2 = mm_args( mat1, mat2, layout=layout, out_dtype=torch.int32 ) name = "int_mm" +======= + m, n, k, layout, mat1, mat2 = mm_args( + mat1, mat2, layout=layout, out_dtype=torch.int32 + ) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # below is for getting an overview logging info of inductor mms counters["aten_mm_info"][f"aten._int_mm_{m}_{n}_{k}"] += 1 log.info( @@ -891,6 +1166,7 @@ def tuned_int_mm(mat1, mat2, *, layout=None): layout, ) +<<<<<<< HEAD static_shape, is_nonzero = _is_static_problem(layout) use_cutlass = static_shape and is_nonzero and use_cutlass_template(layout, m, n, k) choices: list[ChoiceCaller] = [] @@ -920,10 +1196,41 @@ def tuned_int_mm(mat1, mat2, *, layout=None): ) return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout) +======= + device_type = ir.get_device_type(mat1) + + static_shape, is_nonzero = _is_static_problem(layout) + use_cutlass = static_shape and is_nonzero and use_cutlass_template(layout, m, n, k) + + choices = ( + [aten__int_mm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else [] + ) + + if use_cutlass and _use_cutlass_for_op("int_mm"): + CUTLASS3xGemmTemplate.add_cutlass_gemm_choices( + choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True + ) + + int8_mm_configs = V.choices.get_int8_mm_configs(device_type) + + if is_nonzero and use_triton_template(layout, enable_int32=True): + for config in int8_mm_configs( + m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu) + ): + mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + ) + + return autotune_select_algorithm("int_mm", choices, [mat1, mat2], layout) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register_lowering(aten.addmm, type_promotion_kind=None) def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): +<<<<<<< HEAD """ Lowering for autotuning aten.addmm with different backends (Aten, Triton, CUTLASS, etc.) """ @@ -936,6 +1243,11 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): [inp_expanded, mat1, mat2], scalars=dict(alpha=alpha, beta=beta) ) choices: list[ChoiceCaller] = [] +======= + device_type = ir.get_device_type(mat1) + m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout) + static_shape, is_nonzero = _is_static_problem(layout) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # below is for getting an overview logging info of inductor mms counters["aten_mm_info"][f"aten.addmm_{m}_{n}_{k}"] += 1 @@ -948,7 +1260,11 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): mat2.get_dtype(), layout, ) +<<<<<<< HEAD aten_layout = layout +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (not is_nonzero) or ( not (inductor_config.max_autotune or inductor_config.max_autotune_gemm) ): @@ -957,6 +1273,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): from torch._inductor.ir import FixedLayout, FlexibleLayout if isinstance(layout, FixedLayout): +<<<<<<< HEAD aten_layout = FlexibleLayout( device=layout.device, dtype=layout.dtype, size=layout.size ) @@ -1025,29 +1342,133 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): "addmm", ) ) +======= + layout = FlexibleLayout( + device=layout.device, dtype=layout.dtype, size=layout.size + ) + choices = ( + [ + aten_addmm.bind( + (inp, mat1, mat2), + layout, + alpha=alpha, + beta=beta, + ) + ] + if use_aten_gemm_kernels() + else [] + ) + return autotune_select_algorithm("addmm", choices, [inp, mat1, mat2], layout) + + choices = ( + [ + aten_addmm.bind( + (inp_expanded, mat1, mat2), + layout, + alpha=alpha, + beta=beta, + ) + ] + if use_aten_gemm_kernels() + else [] + ) + + if ( + use_aten_gemm_kernels() + and inp_expanded.get_stride()[0] == 0 + and inp_expanded.get_device().type == "cuda" + and inductor_config.triton.autotune_cublasLt + ): + # unexpand inp to make sure fused addmm from cublasLt is used + choices.insert( + 0, + aten_bias_addmm.bind( + (inp_expanded, mat1, mat2), layout, alpha=alpha, beta=beta + ), + ) + + mm_configs = V.choices.get_base_mm_configs(device_type) + persistent_mm_configs = V.choices.get_persistent_mm_configs(device_type) + + dtype = mat1.get_dtype() + if is_nonzero and use_triton_template(layout): + for config in mm_configs( + m, + n, + k, + **mm_config_kwargs(device_type, _is_large_block_for_cpu, dtype.itemsize), + ): + mm_template.maybe_append_choice( + choices, + input_nodes=(inp_expanded, mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + prefix_args=1, + epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta), + epilogue_fn_hash=str(["addmm_epilogue", layout.dtype, alpha, beta]), + ) + + if use_triton_tma_template(mat1, mat2): + for config in persistent_mm_configs( + m, + n, + k, + **mm_config_kwargs( + device_type, _is_large_block_for_cpu, dtype.itemsize + ), + ): + persistent_tma_mm_template.maybe_append_choice( + choices, + input_nodes=(inp_expanded, mat1, mat2), + layout=layout, + workspace_arg=get_tma_workspace_arg( + num_tma_descriptors=2, + device=mat1.get_device(), + ), + **mm_options(config, m, n, k, layout), + **persistent_mm_options(mat1, mat2), + prefix_args=1, + epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta), + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if ( is_nonzero and use_cutlass_template(layout, m, n, k) +<<<<<<< HEAD and _use_cutlass_for_op(name) +======= + and _use_cutlass_for_op("addmm") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): CUTLASS3xGemmTemplate.add_cutlass_gemm_choices( choices, layout, +<<<<<<< HEAD # reorder here because CUTLASS expects (x, w, bias) but torch # is bias, x, w kernel_inputs.nodes(reorder=[1, 2, 0]), alpha=alpha, beta=beta, +======= + [mat1, mat2, inp_expanded], + alpha=alpha, + beta=beta, + input_reorder=[2, 0, 1], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if is_nonzero and use_ck_gemm_template(layout, m, n, k): CKGemmTemplate.add_ck_gemm_choices( choices, layout, +<<<<<<< HEAD # reorder here because CK expects (x, w, bias) but torch # is bias, x, w kernel_inputs.nodes(reorder=[1, 2, 0]), +======= + [mat1, mat2, inp_expanded], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) alpha=alpha, beta=beta, input_reorder=[2, 0, 1], @@ -1057,13 +1478,23 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): CppGemmTemplate.add_choices( choices, layout, +<<<<<<< HEAD kernel_inputs.nodes(), +======= + [inp_expanded, mat1, mat2], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) alpha=alpha, beta=beta, has_bias=True, ) +<<<<<<< HEAD return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout) +======= + return autotune_select_algorithm( + "addmm", choices, [inp_expanded, mat1, mat2], layout + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register_lowering(aten._sparse_semi_structured_mm, type_promotion_kind=None) @@ -1072,13 +1503,22 @@ def tuned_sparse_semi_structured_mm( ): from torch._inductor.select_algorithm import realize_inputs +<<<<<<< HEAD # TODO(coconturuben): support V.choices.get_mm_configs for sparse_semi_structured_mm +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mat1, mat1_meta, mat2 = realize_inputs(mat1, mat1_meta, mat2) m1, k1 = mat1.get_size() m2, _ = mat1_meta.get_size() k2, n = mat2.get_size() +<<<<<<< HEAD m = V.graph.sizevars.check_equals_and_simplify(m1, m2) k = V.graph.sizevars.check_equals_and_simplify(2 * k1, k2) +======= + m = V.graph.sizevars.guard_equals(m1, m2) + k = V.graph.sizevars.guard_equals(2 * k1, k2) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if layout is None: from torch._inductor.ir import FixedLayout @@ -1111,7 +1551,11 @@ def tuned_sparse_semi_structured_mm( ) return autotune_select_algorithm( +<<<<<<< HEAD "sparse_semi_structured_mm", choices, (mat1, mat1_meta, mat2), layout +======= + "sparse_semi_structured_mm", choices, [mat1, mat1_meta, mat2], layout +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @@ -1145,7 +1589,10 @@ def tuned_scaled_mm( Returns: Tensor: The result of the scaled matrix multiplication """ +<<<<<<< HEAD # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) m, n, k, layout, mat_a, mat_b = mm_args( mat_a, mat_b, layout=layout, out_dtype=out_dtype ) @@ -1160,11 +1607,17 @@ def tuned_scaled_mm( mat_b.get_dtype(), layout, ) +<<<<<<< HEAD name = "scaled_mm" +======= + + device_type = ir.get_device_type(mat_a) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) check_supported_striding(mat_a, mat_b) scale_a_real, scale_b_real = realize_inputs(scale_a, scale_b) +<<<<<<< HEAD input_nodes: list[Any] if not bias: @@ -1226,23 +1679,145 @@ def tuned_scaled_mm( kwarg_overrides={mm_template.uid: overriders}, ) ) +======= + input_nodes: tuple[Any, ...] + + if not bias: + input_nodes = (mat_a, mat_b, scale_a_real, scale_b_real) + else: + bias_real = realize_inputs(bias) + input_nodes = (mat_a, mat_b, scale_a_real, scale_b_real, bias_real) + + aten_choice = aten__fp8_mm.bind( + input_nodes, layout, out_dtype=out_dtype, use_fast_accum=use_fast_accum + ) + + choices = [] + if use_aten_gemm_kernels(): + choices.append(aten_choice) + + # We dont have triton lowerings for the MX variants yet + if scale_a.dtype != torch.float32: + return autotune_select_algorithm("scaled_mm", choices, input_nodes, layout) + + _, is_nonzero = _is_static_problem(layout) + + scaled_mm_configs = V.choices.get_scaled_mm_configs(device_type) + scaled_persistent_mm_configs = V.choices.get_scaled_persistent_mm_configs( + device_type + ) + + if is_nonzero and use_triton_template(layout, enable_float8=True): + triton_input_nodes: tuple[Any, ...] + if bias and len(mat_b.get_size()) == len(bias.get_size()) + 1: + # Need to unsqueeze bias from [N] -> [1, N] + triton_bias = L[aten.unsqueeze](bias, 0) + else: + triton_bias = bias + + if len(scale_a.get_size()) == 0 or len(scale_b.get_size()) == 0: + assert len(scale_a.get_size()) == len(scale_b.get_size()) + # Need to unsqueeze scale from [] -> [1, 1] + triton_scale_a = L[aten.unsqueeze](L[aten.unsqueeze](scale_a, 0), 1) + triton_scale_b = L[aten.unsqueeze](L[aten.unsqueeze](scale_b, 0), 1) + else: + triton_scale_a = scale_a + triton_scale_b = scale_b + + if bias: + triton_input_nodes = ( + mat_a, + mat_b, + triton_scale_a, + triton_scale_b, + triton_bias, + ) + suffix_args = 3 + else: + triton_input_nodes = (mat_a, mat_b, triton_scale_a, triton_scale_b) + suffix_args = 2 + + # TODO (paulzhan): There is no template that exists for bias and TMA + # Don't run tma template currently if bias exists + if use_triton_tma_template(mat_a, mat_b) and not bias: + for config in scaled_persistent_mm_configs(m, n, k): + kwargs = scaled_mm_options( + config, + m, + n, + k, + layout, + scale_a, + scale_b, + use_fast_accum, + device_tma=True, + ) + scaled_mm_device_tma_template.maybe_append_choice( + choices, + input_nodes=triton_input_nodes, + layout=layout, + workspace_arg=get_tma_workspace_arg( + num_tma_descriptors=2, + device=mat_a.get_device(), + ), + **kwargs, + ) + + for config in scaled_mm_configs(m, n, k): + if V.graph.sizevars.guard_or_false(sympy.Le(k, 16)): + # Triton crashes however uncommon for real workloads + continue + + # On NVIDIA B200 GPUs, K dim must be >= 32 for tcgen05.mma.kind::f8f6f4.* PTX instruction to be valid + # source: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-shape + if using_b200() and V.graph.sizevars.guard_or_false(sympy.Lt(k, 32)): + continue + + kwargs = scaled_mm_options( + config, m, n, k, layout, scale_a, scale_b, use_fast_accum + ) + # possibly appends a TritonTemplateCaller to choices + mm_template.maybe_append_choice( + choices, + input_nodes=triton_input_nodes, + layout=layout, + **kwargs, + suffix_args=suffix_args, + epilogue_fn=scale_mm_epilogue(), + epilogue_fn_hash="scale_mm_epilogue", + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if ( is_nonzero and use_cutlass_template(layout, m, n, k) +<<<<<<< HEAD and _use_cutlass_for_op(name) +======= + and _use_cutlass_for_op("scaled_mm") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): CUTLASS3xGemmTemplate.add_cutlass_gemm_choices( choices, layout, +<<<<<<< HEAD kernel_inputs.nodes(), # type: ignore[arg-type] +======= + input_nodes, # type: ignore[arg-type] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) use_fast_accum=use_fast_accum, # type: ignore[arg-type] ) if is_nonzero and use_ck_gemm_template(layout, m, n, k): +<<<<<<< HEAD CKGemmTemplate.add_ck_gemm_choices(choices, layout, kernel_inputs.nodes()) return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout) +======= + CKGemmTemplate.add_ck_gemm_choices(choices, layout, input_nodes) + + return autotune_select_algorithm("scaled_mm", choices, input_nodes, layout) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @functools.cache diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index 228492fd9a1e5..76ed38d45958c 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -3,13 +3,25 @@ from collections.abc import Sequence from typing import Any +<<<<<<< HEAD +======= +import sympy + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch from torch._inductor.select_algorithm import realize_inputs, SymbolicGridFn from torch._inductor.utils import sympy_product from torch._inductor.virtualized import V +<<<<<<< HEAD +from ..codegen.wrapper import PythonWrapperCodegen +from ..ir import _IntLike, Layout, TensorBox +======= +from .. import config as inductor_config from ..codegen.wrapper import PythonWrapperCodegen from ..ir import _IntLike, Layout, TensorBox +from ..utils import get_num_sms, TMA_DESCRIPTOR_SIZE +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) log = logging.getLogger(__name__) @@ -45,6 +57,99 @@ def acc_type(dtype): return f"tl.{dtype}".replace("torch.", "") +<<<<<<< HEAD +======= +def mm_options(config, sym_m, sym_n, sym_k, layout): + """ + Common options to matmul triton templates. + """ + even_k_symbolic = ( + # it isn't worth guarding on this + sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"] + ) + allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and ( + not inductor_config.force_same_precision + or ((sym_m % 16) == 0 and (sym_n % 16) == 0 and (sym_k % 8) == 0) + ) + options_dict = dict( + EVEN_K=even_k_symbolic, + ALLOW_TF32=allow_tf32, + USE_FAST_ACCUM=False, # Option for _scaled_mm + ACC_TYPE=acc_type(layout.dtype), + num_stages=config.num_stages, + num_warps=config.num_warps, + **config.kwargs, + ) + + # If GROUP_M not specified then default to 8 + if "GROUP_M" not in config.kwargs: + group_m = config.kwargs.get("GROUP_M", 8) + options_dict["GROUP_M"] = group_m + + return options_dict + + +def tma_options() -> dict[str, Any]: + from torch.utils._triton import has_triton_stable_tma_api + + return {"TMA_EXPERIMENTAL_API": not has_triton_stable_tma_api()} + + +def persistent_mm_options(mat1, mat2): + res = dict( + A_ROW_MAJOR=not mat1.layout.is_transposed(), + B_ROW_MAJOR=not mat2.layout.is_transposed(), + NUM_SMS=get_num_sms(), + TMA_SIZE=TMA_DESCRIPTOR_SIZE, + ) + res.update(tma_options()) + return res + + +def scaled_mm_options( # type: ignore[no-untyped-def] + config, # triton.Config + sym_m: sympy.core.numbers.Integer, + sym_n: sympy.core.numbers.Integer, + sym_k: sympy.core.numbers.Integer, + layout: Layout, + scale_a, + scale_b, + use_fast_accum: bool, + device_tma: bool = False, +) -> dict[str, Any]: + def are_compatible_scales(size_a, size_b) -> bool: + # Same sized scales are compatible + if len(size_a) == len(size_b): + return True + + # Both need to be scalars or len(1) tensors + if len(size_a) <= 1 and len(size_b) <= 1: + return True + + return False + + size_a, size_b = scale_a.get_size(), scale_b.get_size() + assert are_compatible_scales(size_a, size_b), ( + "Expect scale_a and scale_b to be either both scalars (including single-element tensors) " + f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}." + ) + + mm_template_options = mm_options(config, sym_m, sym_n, sym_k, layout) + + mm_template_options["ACC_TYPE"] = "tl.float32" + mm_template_options["USE_FAST_ACCUM"] = use_fast_accum + mm_template_options["SCALING_ROWWISE"] = len(size_a) == 2 + + if device_tma: + mm_template_options["TMA_SIZE"] = TMA_DESCRIPTOR_SIZE + mm_template_options["NUM_SMS"] = get_num_sms() + + mm_template_options.update(tma_options()) + + return mm_template_options + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def mm_args( mat1, mat2, @@ -63,10 +168,17 @@ def mm_args( *b2, n, k2 = mat2.get_size() else: *b2, k2, n = mat2.get_size() +<<<<<<< HEAD b = [V.graph.sizevars.check_equals_and_simplify(a, b) for a, b in zip(b1, b2)] if use_4x2_dim: k2 = k2 * 2 k = V.graph.sizevars.check_equals_and_simplify(k1, k2) +======= + b = [V.graph.sizevars.guard_equals(a, b) for a, b in zip(b1, b2)] + if use_4x2_dim: + k2 = k2 * 2 + k = V.graph.sizevars.guard_equals(k1, k2) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if layout is None: from torch._inductor.ir import FixedLayout @@ -87,6 +199,23 @@ def mm_args( return [m, n, k, layout, mat1, mat2, *others] +<<<<<<< HEAD +======= +def mm_config_kwargs(device, exclude_condition, dtype_size=None): + if device == "cpu": + return { + "scale": 0.5, + "exclude": exclude_condition, + } + + if dtype_size and inductor_config.max_autotune_gemm_search_space == "EXHAUSTIVE": + return { + "dtype_size": dtype_size, + } + return {} + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def addmm_epilogue(dtype, alpha, beta): def epilogue(acc, bias): if alpha != 1: @@ -180,7 +309,11 @@ def has_zero_dim(size: Sequence[_IntLike]) -> bool: ) +<<<<<<< HEAD def is_batch_stride_largest_or_zero(mat1, mat2, layout) -> bool: +======= +def is_batch_stride_largest(mat1, mat2, layout) -> bool: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Checking if the batch stride is the largest in the stride. """ @@ -188,7 +321,11 @@ def is_batch_stride_largest_or_zero(mat1, mat2, layout) -> bool: strides = [mat1.get_stride(), mat2.get_stride(), layout.stride] for size, stride in zip(sizes, strides): assert len(size) == len(stride) == 3, "Expect 3D tensors" +<<<<<<< HEAD if stride[0] != 0 and stride[0] != sympy_product(size[1:]): +======= + if stride[0] != sympy_product(size[1:]): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return False return True diff --git a/torch/_inductor/kernel/mm_plus_mm.py b/torch/_inductor/kernel/mm_plus_mm.py index 60e1b01a5b032..f633d9317cdd3 100644 --- a/torch/_inductor/kernel/mm_plus_mm.py +++ b/torch/_inductor/kernel/mm_plus_mm.py @@ -1,11 +1,17 @@ # mypy: allow-untyped-defs +<<<<<<< HEAD import logging from typing import TYPE_CHECKING import torch from ..kernel_inputs import MMKernelInputs +======= +import torch + +from .. import ir +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from ..lowering import lowerings from ..select_algorithm import ( autotune_select_algorithm, @@ -14,6 +20,7 @@ ) from ..utils import use_aten_gemm_kernels, use_triton_template from ..virtualized import V +<<<<<<< HEAD from .mm_common import mm_args, mm_grid @@ -22,6 +29,11 @@ log = logging.getLogger(__name__) +======= +from .mm_common import mm_args, mm_grid, mm_options + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) aten = torch.ops.aten aten_mm_plus_mm = ExternKernelChoice( @@ -127,9 +139,15 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None): """ Computes mm(mat1, mat2) + mm(mat3, mat4) """ +<<<<<<< HEAD # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that m1, n1, k1, layout1, mat1, mat2 = mm_args(mat1, mat2, layout=layout) m2, n2, _, layout2, mat3, mat4 = mm_args(mat3, mat4, layout=layout) +======= + m1, n1, k1, layout1, mat1, mat2 = mm_args(mat1, mat2, layout=layout) + m2, n2, _, layout2, mat3, mat4 = mm_args(mat3, mat4, layout=layout) + device_type = ir.get_device_type(mat1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Optimization is optional, because we can always just not do the fusion if ( @@ -148,6 +166,7 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None): lowerings[aten.mm](mat1, mat2), lowerings[aten.mm](mat3, mat4) ) +<<<<<<< HEAD # Create MMKernelInputs for MM Plus MM (matrices are at indices 0, 1 for first pair) # Note: This is a special case with 4 matrices, but we use the first pair for M, N, K extraction kernel_inputs = MMKernelInputs([mat1, mat2, mat3, mat4], mat1_idx=0, mat2_idx=1) @@ -172,4 +191,29 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None): return autotune_select_algorithm( "mm_plus_mm", choices, kernel_inputs.nodes(), layout1 +======= + assert layout1 == layout2 + # options to tune from + choices = ( + [aten_mm_plus_mm.bind((mat1, mat2, mat3, mat4), layout1)] + if use_aten_gemm_kernels() + else [] + ) + + mm_configs = V.choices.get_mm_plus_mm_configs(device_type) + if use_triton_template(layout1): + for config in mm_configs(): + # see https://github.com/triton-lang/triton/issues/1298 + # BLOCK_K = K causes llvm error + if V.graph.sizevars.statically_known_lt(config.kwargs["BLOCK_K"], k1): + mm_plus_mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2, mat3, mat4), + layout=layout1, + **mm_options(config, m1, n1, k1, layout1), + ) + + return autotune_select_algorithm( + "mm_plus_mm", choices, [mat1, mat2, mat3, mat4], layout1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) diff --git a/torch/_inductor/kernel/mm_scaled_grouped.py b/torch/_inductor/kernel/mm_scaled_grouped.py new file mode 100644 index 0000000000000..ad34ea0210b51 --- /dev/null +++ b/torch/_inductor/kernel/mm_scaled_grouped.py @@ -0,0 +1,741 @@ +# mypy: allow-untyped-defs +import logging +from dataclasses import dataclass +from typing import Any, Optional + +import torch +from torch._dynamo.utils import counters +from torch._inductor.runtime.triton_compat import tl +from torch._inductor.virtualized import V +from torch.utils._triton import has_triton + +from ..ir import ChoiceCaller, Layout, TensorBox +from ..lowering import register_lowering +from ..select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + realize_inputs, + TritonTemplate, +) +from ..utils import ( + get_gpu_shared_memory, + get_num_sms, + has_free_symbols, + use_aten_gemm_kernels, +) +from .mm_common import ( + _is_static_problem, + check_supported_striding, + persistent_grouped_mm_grid, +) + + +log = logging.getLogger(__name__) +aten = torch.ops.aten + + +@dataclass +class Config: + kwargs: dict[str, int] + num_stages: int + num_warps: int + + +_NV_CONFIGS = [ + Config( + { + "BLOCK_M": block_size_m, + "BLOCK_N": block_size_n, + "BLOCK_K": block_size_k, + "NUM_CONSUMER_GROUPS": 1, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + for block_size_m in [16, 32, 64, 128] + for block_size_n in [64, 128, 256] + for block_size_k in [64, 128, 256] + for num_stages in [3, 4] + for num_warps in [4, 8] +] + + +def grouped_mm_configs(): + return _NV_CONFIGS + + +def early_config_prune(g, m, configs, named_args): + dtsize = 1 + pruned_configs = [] + for config in configs: + kw = config.kwargs + BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps, num_consumer_groups = ( + kw["BLOCK_M"], + kw["BLOCK_N"], + kw["BLOCK_K"], + config.num_stages, + config.num_warps, + getattr(config, "num_consumer_groups", 0), + ) + + # 1. Prune NV configs depending on g and m. + if not has_free_symbols((g, m)): + a_is_2d, b_is_2d = named_args["A_IS_2D"], named_args["B_IS_2D"] + m_avg = m // g if a_is_2d and not b_is_2d else m + if m_avg <= 16: + if BLOCK_M > 32: + continue + elif m_avg <= 32: + if BLOCK_M > 64: + continue + elif m_avg <= 64: + if BLOCK_M <= 16: + continue + else: + if BLOCK_M <= 32: + continue + + # 2. make sure we have enough smem + max_shared_memory = get_gpu_shared_memory() + + required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize + if required_shared_memory > max_shared_memory: + continue + + use_warp_specialization = num_consumer_groups >= 1 + + # 3. make sure we can partition for ws + if use_warp_specialization: + if num_warps != 4: + continue + + # "tritongpu-warp-spec-data-partition" + m_slice = BLOCK_M // num_consumer_groups + n_slice = BLOCK_N // num_consumer_groups + if m_slice < 64 and n_slice < 256: + continue + + pruned_configs.append(config) + + return pruned_configs + + +triton_grouped_mm_source = r""" +{%- if SCALED %} +{%- if A_IS_2D or B_IS_2D %} +{{def_kernel("a_ptr", "b_ptr", "scale_a_ptr", "scale_b_ptr", "offsets_ptr")}} +{%- else %} +{{def_kernel("a_ptr", "b_ptr", "scale_a_ptr", "scale_b_ptr")}} +{%- endif %} +{%- else %} +{%- if A_IS_2D or B_IS_2D %} +{{def_kernel("a_ptr", "b_ptr", "offsets_ptr")}} +{%- else %} +{{def_kernel("a_ptr", "b_ptr")}} +{%- endif %} +{%- endif %} + tidx = tl.program_id(0) + +{%- set M_IS_VARYING = A_IS_2D and not B_IS_2D %} +{%- set N_IS_VARYING = not A_IS_2D and B_IS_2D %} +{%- set K_IS_VARYING = A_IS_2D and B_IS_2D %} + +{%- if A_IS_2D %} +{%- if B_IS_2D %} + G = {{size("offsets_ptr", 0)}} +{%- else %} + G = {{size("b_ptr", 0)}} +{%- endif %} +{%- else %} +{%- if B_IS_2D %} + G = {{size("a_ptr", 0)}} +{%- else %} + G = {{size("a_ptr", 0)}} +{%- endif %} +{%- endif %} + + # the b_ptr tensor is given with its last two dims transposed, revert here + + M = {{size("a_ptr", -2)}} + N = {{size("b_ptr", -1)}} + K = {{size("a_ptr", -1)}} + + A_STRIDE_M = {{stride("a_ptr", -2)}} + A_STRIDE_K = {{stride("a_ptr", -1)}} +{%- if not A_IS_2D %} + A_STRIDE_G = {{stride("a_ptr", 0)}} +{%- if SCALED %} + SCALE_A_STRIDE_G = {{stride("scale_a_ptr", 0)}} +{%- endif %} +{%- endif %} + B_STRIDE_N = {{stride("b_ptr", -1)}} + B_STRIDE_K = {{stride("b_ptr", -2)}} +{%- if not B_IS_2D %} + B_STRIDE_G = {{stride("b_ptr", 0)}} +{%- if SCALED %} + SCALE_B_STRIDE_G = {{stride("scale_b_ptr", 0)}} +{%- endif %} +{%- endif %} + +{%- if USE_TMA_LOAD %} +{%- if USE_EXPERIMENTAL_MAKE_TENSOR_DESCRIPTOR %} + a_desc = tl._experimental_make_tensor_descriptor( +{%- else %} + a_desc = tl.make_tensor_descriptor( +{%- endif %} + a_ptr, +{%- if A_IS_2D %} + shape=[M, K], + # fixme: strides=[A_STRIDE_M, A_STRIDE_K], + strides=[{{stride("a_ptr", -2)}}, {{stride("a_ptr", -1)}}], + block_shape=[BLOCK_M, BLOCK_K], +{%- else %} + shape=[G, M, K], + # fixme: strides=[A_STRIDE_G, A_STRIDE_M, A_STRIDE_K], + strides=[{{stride("a_ptr", 0)}}, {{stride("a_ptr", -2)}}, {{stride("a_ptr", -1)}}], + block_shape=[1, BLOCK_M, BLOCK_K], +{%- endif %} + ) + +{%- if USE_EXPERIMENTAL_MAKE_TENSOR_DESCRIPTOR %} + b_desc = tl._experimental_make_tensor_descriptor( +{%- else %} + b_desc = tl.make_tensor_descriptor( +{%- endif %} + b_ptr, +{%- if B_IS_2D %} + shape=[N, K], + # fixme: strides=[B_STRIDE_N, B_STRIDE_K], + strides=[{{stride("b_ptr", -1)}}, {{stride("b_ptr", -2)}}], + block_shape=[BLOCK_N, BLOCK_K], +{%- else %} + shape=[G, N, K], + # fixme: strides=[B_STRIDE_G, B_STRIDE_N, B_STRIDE_K], + strides=[{{stride("b_ptr", 0)}}, {{stride("b_ptr", -1)}}, {{stride("b_ptr", -2)}}], + block_shape=[1, BLOCK_N, BLOCK_K], +{%- endif %} + ) +{%- endif %} + +{%- if M_IS_VARYING %} + m_end_offset = 0 +{%- endif %} +{%- if N_IS_VARYING %} + n_end_offset = 0 +{%- endif %} +{%- if K_IS_VARYING %} + k_end_offset = 0 +{%- endif %} + iterated_tiles = 0 + for g in tl.range(G): +{%- if M_IS_VARYING %} + # Move across groups + m_start_offset = m_end_offset + m_end_offset = tl.load(offsets_ptr + g) + m_size = m_end_offset - m_start_offset +{%- if SCALED %} + m_scale_start_offset = m_start_offset +{%- endif %} +{%- else %} + m_start_offset = 0 + m_size = M +{%- if SCALED %} + m_scale_start_offset = g * M +{%- endif %} +{%- endif %} + +{%- if N_IS_VARYING %} + # Move across groups + n_start_offset = n_end_offset + n_end_offset = tl.load(offsets_ptr + g) + n_size = n_end_offset - n_start_offset +{%- if SCALED %} + n_scale_start_offset = n_start_offset +{%- endif %} +{%- else %} + n_start_offset = 0 + n_size = N +{%- if SCALED %} + n_scale_start_offset = g * N +{%- endif %} +{%- endif %} + + if m_size > 0 and n_size > 0: +{%- if K_IS_VARYING %} + # Move across groups + k_start_offset = k_end_offset + k_end_offset = tl.load(offsets_ptr + g) + k_size = k_end_offset - k_start_offset +{%- else %} + k_start_offset = 0 + k_size = K +{%- endif %} + + num_m_tiles = tl.cdiv(m_size, BLOCK_M) + num_n_tiles = tl.cdiv(n_size, BLOCK_N) + num_tiles = num_m_tiles * num_n_tiles + + # Move across tiles + while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles: + gidx = tidx - iterated_tiles + # Split M first and N second. + tile_m_idx = gidx % num_m_tiles + tile_n_idx = gidx // num_m_tiles + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + +{%- if USE_TMA_LOAD %} + m_offset = (m_start_offset + tile_m_idx * BLOCK_M).to(tl.int32) + n_offset = (n_start_offset + tile_n_idx * BLOCK_N).to(tl.int32) + + for k_offset in range(0, k_size, BLOCK_K): +{%- if A_IS_2D %} + a = a_desc.load([m_offset, k_start_offset + k_offset]) +{%- else %} + a = a_desc.load([g, m_offset, k_start_offset + k_offset]).reshape(BLOCK_M, BLOCK_K) +{%- endif %} +{%- if B_IS_2D %} + b = b_desc.load([n_offset, k_start_offset + k_offset]) +{%- else %} + b = b_desc.load([g, n_offset, k_start_offset + k_offset]).reshape(BLOCK_N, BLOCK_K) +{%- endif %} + +{%- if K_IS_VARYING %} + if k_offset + BLOCK_K > k_size: + group_offs_k = k_offset + tl.arange(0, BLOCK_K) + a = tl.where(group_offs_k < k_size, a, 0) + b = tl.where(group_offs_k < k_size, b, 0) +{%- endif %} + +{%- if USE_FAST_ACCUM %} + accumulator = tl.dot(a, b.T, accumulator) +{%- else %} + accumulator += tl.dot(a, b.T) +{%- endif %} +{%- else %} + offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M) + offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N) + offs_k = k_start_offset + tl.arange(0, BLOCK_K) + a_ptrs = ( + a_ptr +{%- if not A_IS_2D %} + + g * A_STRIDE_G +{%- endif %} + + (m_start_offset + offs_am[:, None]) * A_STRIDE_M + + offs_k[None, :] * A_STRIDE_K + ) + b_ptrs = ( + b_ptr +{%- if not B_IS_2D %} + + g * B_STRIDE_G +{%- endif %} + + (n_start_offset + offs_bn[:, None]) * B_STRIDE_N + + offs_k[None, :] * B_STRIDE_K + ) + for k_offset in range(0, k_size, BLOCK_K): + a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size) + b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size) + if k_offset + BLOCK_K > k_size: + group_offs_k = k_offset + tl.arange(0, BLOCK_K) + a = tl.where(group_offs_k < k_size, a, 0) + b = tl.where(group_offs_k < k_size, b, 0) +{%- if USE_FAST_ACCUM %} + accumulator = tl.dot(a, b.T, accumulator) +{%- else %} + accumulator += tl.dot(a, b.T) +{%- endif %} + a_ptrs += BLOCK_K + b_ptrs += BLOCK_K +{%- endif %} + + offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M) + offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N) +{%- if SCALED %} + scale_a = tl.load( + scale_a_ptr +{%- if A_IS_2D %} + + m_scale_start_offset +{%- else %} + + g * SCALE_A_STRIDE_G +{%- endif %} + + offs_am[:, None], + mask=offs_am[:, None] < m_size, + ) + scale_b = tl.load( + scale_b_ptr +{%- if B_IS_2D %} + + n_scale_start_offset +{%- else %} + + g * SCALE_B_STRIDE_G +{%- endif %} + + offs_bn[None, :], + mask=offs_bn[None, :] < n_size, + ) + c = accumulator.to(tl.float32) * scale_a * scale_b +{%- else %} + c = accumulator.to(tl.float32) +{%- endif %} + +{%- if M_IS_VARYING %} + idx_m = (m_start_offset + offs_am[:, None]) +{%- else %} + idx_m = offs_am[:, None] +{%- endif %} +{%- if N_IS_VARYING %} + idx_n = (n_start_offset + offs_bn[None, :]) +{%- else %} + idx_n = offs_bn[None, :] +{%- endif %} + mask = offs_am[:, None] < m_size and offs_bn[None, :] < n_size +{%- if M_IS_VARYING or N_IS_VARYING %} + {{store_output(("idx_m", "idx_n"), "c", "mask", indent_width=16)}} +{%- else %} + {{store_output(("g", "idx_m", "idx_n"), "c", "mask", indent_width=16)}} +{%- endif %} + tidx += NUM_SMS + + iterated_tiles += num_tiles +""" + + +triton_grouped_mm_template = TritonTemplate( + name="grouped_mm", + grid=persistent_grouped_mm_grid, + source=triton_grouped_mm_source, +) + +triton_scaled_grouped_mm_template = TritonTemplate( + name="scaled_grouped_mm", + grid=persistent_grouped_mm_grid, + source=triton_grouped_mm_source, +) + + +def grouped_mm_args( + mat1: TensorBox, + mat2: TensorBox, + offs: Optional[TensorBox], + layout=None, + out_dtype=None, +): + mat1, mat2 = realize_inputs(mat1, mat2) + if offs is not None: + realize_inputs(offs) + mat1_size = mat1.get_size() + mat2_size = mat2.get_size() + + m1dim, m2dim = len(mat1_size), len(mat2_size) + + assert m1dim == 2 or m1dim == 3 + assert m2dim == 2 or m2dim == 3 + + if layout is None: + from torch._inductor.ir import FixedLayout + + if out_dtype is None: + out_dtype = mat1.get_dtype() + + dims = [] + if m1dim == 2: + if m2dim == 2: + assert offs is not None + dims = [offs.get_size()[0], mat1_size[0], mat2_size[1]] + else: + dims = [mat1_size[0], mat2_size[-1]] + else: + if m2dim == 2: + dims = [mat1_size[1], mat2_size[1]] + else: + dims = [mat1_size[0], mat1_size[1], mat2_size[-1]] + layout = FixedLayout( + mat1.get_device(), + out_dtype, + dims, + ) + else: + assert out_dtype is None, "out_dtype is ignored if layout is specified." + + return (mat1_size, mat2_size, layout, mat1, mat2, offs) + + +aten__grouped_mm = ExternKernelChoice( + torch._grouped_mm, + "at::_grouped_mm", + op_overload=aten._grouped_mm, + has_out_variant=False, +) + + +aten__scaled_grouped_mm = ExternKernelChoice( + torch._scaled_grouped_mm, + "at::_scaled_grouped_mm", + op_overload=aten._scaled_grouped_mm, + has_out_variant=False, +) + + +def can_use_triton_kernel( + mat_a: TensorBox, + mat_b: TensorBox, + offs: Optional[TensorBox], + bias: Optional[TensorBox], + scale_result: Optional[TensorBox], +) -> bool: + if not ( + torch.cuda.is_available() + and torch.cuda.get_device_capability() >= (9, 0) + and not torch.version.hip + ): + return False + if not has_triton(): + return False + + # The _grouped_mm()/_scaled_grouped_mm() operator do not support + # bias nor scale_result yet. + if bias is not None: + return False + if scale_result is not None: + return False + + if len(mat_a.get_size()) == 2 or len(mat_b.get_size()) == 2: + return offs is not None + else: + return offs is None + + +def create_offsets(x, m1_size, m2_size, offs_size): + m1_is_2d = len(m1_size) == 2 + m2_is_2d = len(m2_size) == 2 + if m1_is_2d: + if m2_is_2d: + k = V.graph.sizevars.size_hint(m1_size[1]) + noffs = V.graph.sizevars.size_hint(offs_size[0]) + step = k / noffs + return torch.linspace( + step, k, noffs, dtype=x.get_dtype(), device=x.get_device() + ) + + else: + m = V.graph.sizevars.size_hint(m1_size[0]) + noffs = V.graph.sizevars.size_hint(offs_size[0]) + step = m / noffs + return torch.linspace( + step, m, noffs, dtype=x.get_dtype(), device=x.get_device() + ) + else: + if m2_is_2d: + n = V.graph.sizevars.size_hint(m2_size[0]) + noffs = V.graph.sizevars.size_hint(offs_size[0]) + step = n / noffs + return torch.linspace( + step, n, noffs, dtype=x.get_dtype(), device=x.get_device() + ) + else: + return None + + +def _tuned_grouped_mm_common( + operator_name: str, + algorithm_name: str, + extern_kernel_choice: ExternKernelChoice, + kernel_template: TritonTemplate, + mat_a: TensorBox, + mat_b: TensorBox, + scale_a: Optional[TensorBox] = None, + scale_b: Optional[TensorBox] = None, + offs: Optional[TensorBox] = None, + bias: Optional[TensorBox] = None, + scale_result: Optional[TensorBox] = None, + out_dtype: Optional[torch.dtype] = None, + use_fast_accum: Optional[bool] = None, + layout: Optional[Layout] = None, +) -> TensorBox: + assert (scale_a is None) == (scale_b is None) + assert scale_result is None or scale_a is not None + + m1_size, m2_size, layout, mat_a, mat_b, offs = grouped_mm_args( + mat_a, mat_b, offs, layout=layout, out_dtype=out_dtype + ) + counters["aten_mm_info"][operator_name] += 1 + log_message = f"Tuned {operator_name}: mat1_shape=%s, mat2_shape=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s" + log.info( + log_message, + m1_size, + m2_size, + mat_a.get_dtype(), + mat_b.get_dtype(), + layout, + ) + + if scale_a is not None and scale_b is not None: + check_supported_striding(mat_a, mat_b) + + # workaround for Inductor not supporting optional tensor input arguments + input_nodes: list[Any] = [mat_a, mat_b] + if scale_a is not None: + input_nodes.append(realize_inputs(scale_a)) + if scale_b is not None: + input_nodes.append(realize_inputs(scale_b)) + if offs is not None: + input_nodes.append(realize_inputs(offs)) + + if use_fast_accum is None: + aten_choice = extern_kernel_choice.bind( + input_nodes, + layout, + out_dtype=out_dtype, + ) + else: + aten_choice = extern_kernel_choice.bind( + input_nodes, + layout, + out_dtype=out_dtype, + use_fast_accum=use_fast_accum, + ) + if use_fast_accum is None: + use_fast_accum = False + + choices: list[ChoiceCaller] = [] + if use_aten_gemm_kernels(): + choices.append(aten_choice) + + _, is_nonzero = _is_static_problem(layout) + + # Checking only for the equality of corresponding dims of + # multiplicands here, relying on meta function checks for + # everything else. + if is_nonzero and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result): + scaled = scale_a is not None + if len(m1_size) == 2: + if len(m2_size) == 2: + m, k1 = m1_size + k2, _ = m2_size + g = offs.get_size()[0] + V.graph.sizevars.guard_equals(k1, k2) + a_is_2d, b_is_2d = True, True + else: + g1 = offs.layout.size[0] + m, k1 = m1_size + g2, k2, _ = m2_size + g = V.graph.sizevars.guard_equals(g1, g2) + V.graph.sizevars.guard_equals(k1, k2) + a_is_2d, b_is_2d = True, False + else: + if len(m2_size) == 2: + g1 = offs.layout.size[0] + g2, m, k1 = m1_size + k2, _ = m2_size + g = V.graph.sizevars.guard_equals(g1, g2) + V.graph.sizevars.guard_equals(k1, k2) + a_is_2d, b_is_2d = False, True + else: + g1, m, k1 = m1_size + g2, k2, _ = m2_size + g = V.graph.sizevars.guard_equals(g1, g2) + V.graph.sizevars.guard_equals(k1, k2) + a_is_2d, b_is_2d = False, False + + triton_has_make_tensor_descriptor = hasattr(tl, "make_tensor_descriptor") + triton_has_experimental_make_tensor_descriptor = hasattr( + tl, "_experimental_make_tensor_descriptor" + ) + use_tma_load = ( + triton_has_make_tensor_descriptor + or triton_has_experimental_make_tensor_descriptor + ) + # The make_tensor_descriptor imposes this additional limitation. + use_tma_load = use_tma_load and ( + mat_a.get_stride()[-1] == 1 and mat_b.get_stride()[-2] == 1 + ) + + kwargs = { + "SCALED": scaled, + "A_IS_2D": a_is_2d, + "B_IS_2D": b_is_2d, + "USE_FAST_ACCUM": use_fast_accum, + "NUM_SMS": get_num_sms(), + "USE_TMA_LOAD": use_tma_load, + "USE_EXPERIMENTAL_MAKE_TENSOR_DESCRIPTOR": triton_has_experimental_make_tensor_descriptor, + } + + for config in early_config_prune(g, m, grouped_mm_configs(), kwargs): + kernel_template.maybe_append_choice( + choices, + input_nodes=input_nodes, + layout=layout, + num_stages=config.num_stages, + num_warps=config.num_warps, + **kwargs, + **config.kwargs, + ) + + input_gen_fns = { + 4: lambda x: create_offsets( + x, m1_size, m2_size, offs.get_size() if offs is not None else None + ), + } + return autotune_select_algorithm( + algorithm_name, choices, input_nodes, layout, input_gen_fns=input_gen_fns + ) + + +@register_lowering(aten._grouped_mm.default, type_promotion_kind=None) +def tuned_grouped_mm( + mat_a: TensorBox, + mat_b: TensorBox, + offs: Optional[TensorBox] = None, + bias: Optional[TensorBox] = None, + out_dtype: Optional[torch.dtype] = None, + layout: Optional[Layout] = None, +) -> TensorBox: + """Auto-tuning for _grouped_mm() operator.""" + + return _tuned_grouped_mm_common( + "aten._grouped_mm.default", + "grouped_mm", + aten__grouped_mm, + triton_grouped_mm_template, + mat_a, + mat_b, + None, + None, + offs, + bias, + None, + out_dtype, + None, + layout, + ) + + +@register_lowering(aten._scaled_grouped_mm.default, type_promotion_kind=None) +def tuned_scaled_grouped_mm( + mat_a: TensorBox, + mat_b: TensorBox, + scale_a: TensorBox, + scale_b: TensorBox, + offs: Optional[TensorBox] = None, + bias: Optional[TensorBox] = None, + scale_result: Optional[TensorBox] = None, + out_dtype: Optional[torch.dtype] = None, + use_fast_accum: bool = False, + layout: Optional[Layout] = None, +) -> TensorBox: + """Auto-tuning for _scaled_grouped_mm() operator.""" + + return _tuned_grouped_mm_common( + "aten._scaled_grouped_mm.default", + "scaled_grouped_mm", + aten__scaled_grouped_mm, + triton_scaled_grouped_mm_template, + mat_a, + mat_b, + scale_a, + scale_b, + offs, + bias, + scale_result, + out_dtype, + use_fast_accum, + layout, + ) diff --git a/torch/_inductor/loop_body.py b/torch/_inductor/loop_body.py index 5ae38810fa134..882a0cc8bb636 100644 --- a/torch/_inductor/loop_body.py +++ b/torch/_inductor/loop_body.py @@ -223,6 +223,7 @@ def merge_loops(self) -> LoopBody: ) return new_body2 +<<<<<<< HEAD def expand_dimension_for_pointwise_node( self, dimension: int, new_range: int ) -> LoopBody: @@ -270,6 +271,8 @@ def new_body(*indices: Sequence[sympy.Expr]) -> Any: ) return new_body +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def reorder_iter_loops(self, new_order) -> LoopBody: """ Reorder iteration loops and return a new LoopBody. diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index d05bdd1354694..0c8dac02a43c6 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -26,7 +26,10 @@ from torch._dynamo.utils import counters from torch._higher_order_ops.associative_scan import associative_scan_op from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation +<<<<<<< HEAD from torch._library.utils import get_layout_constraint_tag +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._prims_common import ( canonicalize_dim, canonicalize_dims, @@ -41,11 +44,15 @@ Number, ) from torch.fx.experimental.sym_node import magic_methods, method_to_operator +<<<<<<< HEAD from torch.fx.experimental.symbolic_shapes import ( free_unbacked_symbols, has_free_unbacked_symbols, resolve_unbacked_bindings, ) +======= +from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.functions import CeilDiv, FloorDiv, Identity, ModularIndexing @@ -53,19 +60,28 @@ from . import config, inductor_prims, ir, test_operators # NOQA: F401 from .decomposition import decompositions, get_decompositions from .ir import ( +<<<<<<< HEAD BaseView, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DtypeView, ExpandView, IndexingConstant, IRNode, is_triton, +<<<<<<< HEAD MutableBox, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) OnlineSoftmaxReduction, ops_wrapper, PermuteView, Pointwise, Reduction, +<<<<<<< HEAD ShapeAsConstantBuffer, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SqueezeView, TensorBox, validate_ir, @@ -164,10 +180,13 @@ def maybe_layout_constraints(fn: Callable[..., Any]) -> Optional[Callable[..., A if not isinstance(fn, torch._ops.OpOverload): # Only OpOverloads have layout constraints. return None +<<<<<<< HEAD if maybe_layout_tag := get_layout_constraint_tag(fn, with_default=False): return tag_to_layout_constraint(maybe_layout_tag) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if fn in _maybe_layout_constraints: return _maybe_layout_constraints[fn] return None @@ -176,7 +195,11 @@ def maybe_layout_constraints(fn: Callable[..., Any]) -> Optional[Callable[..., A def tag_to_layout_constraint(tag): if tag == torch._C.Tag.needs_exact_strides: return constrain_to_fake_tensors +<<<<<<< HEAD if tag == torch._C.Tag.needs_contiguous_strides: # type: ignore[attr-defined] +======= + if tag == torch._C.Tag.needs_contiguous_strides: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return require_contiguous_strides if tag == torch._C.Tag.needs_fixed_stride_order: return constrain_to_fx_strides @@ -314,6 +337,7 @@ def in_namespace(op, namespace): return False +<<<<<<< HEAD def maybe_copy_cpu_scalar(x: TensorBox, device: torch.device) -> TensorBox: """ Copy cpu scalar if doesn't not match with given `device` @@ -334,6 +358,8 @@ def maybe_copy_cpu_scalar(x: TensorBox, device: torch.device) -> TensorBox: return x +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def transform_args( args: list[Any], kwargs: dict[str, Any], @@ -341,10 +367,13 @@ def transform_args( type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND], convert_input_to_bool: bool, ) -> tuple[list[Any], dict[str, Any]]: +<<<<<<< HEAD """ Transforms arguments for broadcasting and type promotion """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) args_indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)] kwargs_indices = [k for k, v in kwargs.items() if isinstance(v, TensorBox)] # check that there's something to transform @@ -372,12 +401,15 @@ def transform_args( args[args_indices[0]] if args_indices else kwargs[kwargs_indices[0]] ).get_device() +<<<<<<< HEAD for i in args_indices: args[i] = maybe_copy_cpu_scalar(args[i], device) for k in kwargs_indices: kwargs[k] = maybe_copy_cpu_scalar(kwargs[k], device) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # sometimes args are an immutable list so we can't mutate them def promote(arg): if isinstance(arg, TensorBox): @@ -527,12 +559,25 @@ def broadcast_symbolic_shapes(a, b): """ output = [] for x, y in itertools.zip_longest(reversed(a), reversed(b), fillvalue=sympy.S.One): +<<<<<<< HEAD if V.graph.sizevars.is_size_one_or_false(y): output.append(x) elif V.graph.sizevars.is_size_one_or_false(x): output.append(y) else: V.graph.sizevars.check_equals(x, y) +======= + if V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(y, 1), size_oblivious=True + ): + output.append(x) + elif V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(x, 1), size_oblivious=True + ): + output.append(y) + else: + V.graph.sizevars.guard_equals(x, y) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if len(sympy.expand(y).free_symbols) < len(sympy.expand(x).free_symbols): output.append(y) # prefer shorter formula else: @@ -735,9 +780,13 @@ def inner(*inputs: list[list[TensorBox]], alpha=1): return inner +<<<<<<< HEAD def to_dtype( x: Union[TensorBox, ShapeAsConstantBuffer], dtype: torch.dtype, copy: bool = False ): +======= +def to_dtype(x: TensorBox, dtype: torch.dtype, copy=False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) src_dtype = x.get_dtype() if src_dtype == dtype: return clone(x) if copy else x @@ -974,10 +1023,32 @@ def broadcast_tensors(*inputs): outputs = [] for x in inputs: sizes = x.get_size() +<<<<<<< HEAD if len(sizes) != len(target) or any( V.graph.sizevars.is_size_one_or_false(a) != V.graph.sizevars.is_size_one_or_false(b) +======= + if len(sizes) != len(target) or any( + ( + ( + V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(a, 1), size_oblivious=True + ) + and not V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(b, 1), size_oblivious=True + ) + ) + or ( + not V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(a, 1), size_oblivious=True + ) + and V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(b, 1), size_oblivious=True + ) + ) + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for a, b in zip(sizes, target) ): x = expand(x, target) @@ -1001,16 +1072,29 @@ def squeeze(x, dim=None): return TensorBox(SqueezeView.create(x.data)) dim = ( +<<<<<<< HEAD V.graph.sizevars.guard_int(dim) if isinstance(dim, (int, sympy.Expr)) else tuple(V.graph.sizevars.guard_int(d) for d in dim) +======= + V.graph.sizevars.evaluate_static_shape(dim) + if isinstance(dim, (int, sympy.Expr)) + else tuple(V.graph.sizevars.evaluate_static_shape(d) for d in dim) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) dim = canonicalize_dims(len(x.get_size()), dim) # type: ignore[call-overload] dims = OrderedSet((dim,) if not isinstance(dim, tuple) else dim) new_shape = [] for d, s in enumerate(x.get_size()): +<<<<<<< HEAD if not (d in dims and V.graph.sizevars.guard_or_false(sympy.Eq(s, 1))): +======= + if not ( + d in dims + and V.graph.sizevars.evaluate_expr(sympy.Eq(s, 1), size_oblivious=True) + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) new_shape.append(s) # squeeze does nothing if the size isn't 1 @@ -1182,7 +1266,13 @@ def inner_fn(index): @register_lowering(aten._unsafe_view, type_promotion_kind=None) @register_lowering(aten.view, type_promotion_kind=None) @register_lowering(aten.reshape, type_promotion_kind=None) +<<<<<<< HEAD def view(x: TensorBox, sizes: Sequence[sympy.Expr]) -> TensorBox: +======= +def view(x, sizes): + assert isinstance(x, TensorBox) + assert isinstance(sizes, (list, tuple)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return TensorBox(View.create(x.data, sizes)) @@ -1310,7 +1400,11 @@ def quantized_decomposed_quantize_per_channel( quant_min: int, quant_max: int, dtype: torch.dtype, +<<<<<<< HEAD ) -> Union[TensorBox, ShapeAsConstantBuffer]: +======= +) -> TensorBox: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert len(scales.get_size()) == 1, "expect scales 1 dim" assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim" @@ -1352,6 +1446,7 @@ def inner_fn(idx): ) +<<<<<<< HEAD def _assert_async(cond, msg): cond.realize() cond = to_dtype(cond, torch.bool) @@ -1380,6 +1475,8 @@ def lower_assert_functional_async(cond, msg): return _assert_async(cond, msg) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register_lowering( quantized_decomposed.dequantize_per_channel, type_promotion_kind=None ) @@ -1393,7 +1490,11 @@ def quantized_decomposed_dequantize_per_channel( dtype: torch.dtype, *, out_dtype: Optional[torch.dtype] = None, +<<<<<<< HEAD ) -> Union[TensorBox, ShapeAsConstantBuffer]: +======= +) -> TensorBox: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert len(scales.get_size()) == 1, "expect scales 1 dim" assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim" assert input.get_dtype() == dtype, ( @@ -1443,7 +1544,11 @@ def quantized_decomposed_quantize_per_tensor_default( quant_min: int, quant_max: int, dtype: torch.dtype, +<<<<<<< HEAD ) -> Union[TensorBox, ShapeAsConstantBuffer]: +======= +) -> TensorBox: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if input.get_dtype() == torch.bfloat16: input = to_dtype(input, torch.float32) assert input.get_dtype() == torch.float32, ( @@ -1484,7 +1589,11 @@ def quantized_decomposed_dequantize_per_tensor_default( dtype: torch.dtype, *, out_dtype: Optional[torch.dtype] = None, +<<<<<<< HEAD ) -> Union[TensorBox, ShapeAsConstantBuffer]: +======= +) -> TensorBox: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert input.get_dtype() == dtype, ( f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}" ) @@ -1521,7 +1630,11 @@ def quantized_decomposed_quantize_per_tensor_tensor( quant_min: int, quant_max: int, dtype: torch.dtype, +<<<<<<< HEAD ) -> Union[TensorBox, ShapeAsConstantBuffer]: +======= +) -> TensorBox: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if input.get_dtype() == torch.bfloat16: input = to_dtype(input, torch.float32) assert input.get_dtype() == torch.float32, ( @@ -1571,7 +1684,11 @@ def quantized_decomposed_dequantize_per_tensor_tensor( dtype: torch.dtype, *, out_dtype: Optional[torch.dtype] = None, +<<<<<<< HEAD ) -> Union[TensorBox, ShapeAsConstantBuffer]: +======= +) -> TensorBox: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert len(scale.get_size()) == 0 or ( len(scale.get_size()) == 1 and scale.get_size()[0] == 1 ), "expect scale as scalar tensor" @@ -1804,6 +1921,7 @@ def diagonal_scatter(input, src, offset: int = 0, dim1: int = 0, dim2: int = 1): @register_lowering(aten.select, type_promotion_kind=None) def select(x, dim, idx): +<<<<<<< HEAD idx = sympy.expand(idx) size = sympy.expand(x.get_size()[dim]) actual_index = None @@ -1858,6 +1976,10 @@ def select(x, dim, idx): del new_size[dim] del new_stride[dim] return as_strided(x, new_size, new_stride, new_storage_offset) +======= + idx = View.handle_negative_index(idx, x.get_size()[dim]) + return squeeze(slice_(x, dim, idx, idx + 1), dim) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register_lowering(aten.split, type_promotion_kind=None) @@ -1869,7 +1991,13 @@ def split(x, sizes, dim=0): # by computing what the actual size of each chunk should be. if not isinstance(sizes, (list, tuple)): x_size = x.get_size()[dim] +<<<<<<< HEAD chunks = V.graph.sizevars.guard_int(FloorDiv(x_size + sizes - 1, sizes)) +======= + chunks = V.graph.sizevars.evaluate_static_shape( + FloorDiv(x_size + sizes - 1, sizes) + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sizes_ = [sizes] * chunks # The last chunk might have a smaller size than the rest. sizes_[-1] = x_size - (chunks - 1) * sizes @@ -1895,7 +2023,11 @@ def split_with_sizes(x, sizes, dim=0): @register_lowering(aten.unbind, type_promotion_kind=None) def unbind(x, dim=0): dim = _validate_dim(x, dim, 0) +<<<<<<< HEAD x_size = V.graph.sizevars.guard_int(x.get_size()[dim]) +======= + x_size = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) result = [select(x, dim, i) for i in range(x_size)] return result @@ -1911,8 +2043,13 @@ def unfold(x, dimension, size, step): dim_size = sizes[dim] sizevars = V.graph.sizevars +<<<<<<< HEAD sizevars.check_leq(size, dim_size) sizevars.check_lt(0, step) # type: ignore[arg-type] +======= + sizevars.guard_leq(size, dim_size) + sizevars.guard_lt(0, step) # type: ignore[arg-type] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) new_dim_size = FloorDiv(dim_size - size, step) + 1 if sizevars.size_hint_or_throw(dim_size) > 0: @@ -1959,7 +2096,11 @@ def _validate_dim(x, dim, offset=0): def glu(x, dim=-1): dim = _validate_dim(x, dim, 0) # TODO: don't guard on static shape here +<<<<<<< HEAD new_len = V.graph.sizevars.guard_int(x.get_size()[dim]) // 2 +======= + new_len = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim]) // 2 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) a = slice_(x, dim, 0, new_len) b = slice_(x, dim, new_len, new_len * 2) return mul(a, sigmoid(b)) @@ -2363,7 +2504,11 @@ def searchsorted( right: bool = False, side: Optional[str] = None, sorter: Optional[TensorBox] = None, +<<<<<<< HEAD ) -> Union[TensorBox, ShapeAsConstantBuffer]: +======= +) -> TensorBox: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) validate_bucketize = lambda tb: V.graph.has_feature( # noqa: E731 tb, BackendFeature.BUCKETIZE ) @@ -2635,8 +2780,12 @@ def apply_constraint(idx, arg, fx_arg): if len(arg.get_size()) not in (3, 4): return arg +<<<<<<< HEAD is_aligned_tensor = ir.is_aligned_realized_tensor_hint(arg, ALIGNMENT) if is_aligned_tensor: +======= + if ir.is_aligned_realized_tensor(arg, ALIGNMENT): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ir.try_match_insignificant_strides( ir.ExternKernel.realize_input(arg), meta_stride_expr ) @@ -2644,7 +2793,11 @@ def apply_constraint(idx, arg, fx_arg): if ( isinstance(arg, IRNode) and arg.maybe_get_stride() is not None +<<<<<<< HEAD and is_aligned_tensor +======= + and ir.is_aligned_realized_tensor(arg, ALIGNMENT) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): return ir.try_match_insignificant_strides( ir.ExternKernel.realize_input(arg), meta_stride_expr @@ -2688,7 +2841,11 @@ def apply_constraint(idx, arg, fx_arg): return ir.ExternKernel.require_exact_strides(arg, out_strides) +<<<<<<< HEAD if is_aligned_tensor: +======= + if ir.is_aligned_realized_tensor(arg, ALIGNMENT): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ir.try_match_insignificant_strides( ir.ExternKernel.realize_input(arg), meta_stride_expr ) @@ -2696,7 +2853,11 @@ def apply_constraint(idx, arg, fx_arg): if ( isinstance(arg, IRNode) and arg.maybe_get_stride() is not None +<<<<<<< HEAD and is_aligned_tensor +======= + and ir.is_aligned_realized_tensor(arg, ALIGNMENT) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): return ir.try_match_insignificant_strides( ir.ExternKernel.realize_input(arg), meta_stride_expr @@ -2930,7 +3091,10 @@ def is_aligned(x): # index_reduce requires fallback when use_scatter_fallback(...) returns True make_fallback(aten.index_reduce) +<<<<<<< HEAD make_fallback(aten.repeat_interleave.Tensor, override_decomp=True) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Register with type_promotion_kind None. @@ -3012,8 +3176,13 @@ def select_scatter(x, src, dim: int, index: int): dim = _validate_dim(x, dim, 0) if V.graph.sizevars.evaluate_expr(sympy.Lt(index, 0)): index = index + x.get_size()[dim] +<<<<<<< HEAD V.graph.sizevars.check_leq(0, index) # type: ignore[arg-type] V.graph.sizevars.check_lt(index, x.get_size()[dim]) # type: ignore[arg-type] +======= + V.graph.sizevars.guard_leq(0, index) # type: ignore[arg-type] + V.graph.sizevars.guard_lt(index, x.get_size()[dim]) # type: ignore[arg-type] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) src = expand(unsqueeze(src, dim), x.get_size()) src_loader = src.make_loader() @@ -3037,7 +3206,11 @@ def inner_fn(idx): @register_lowering(aten.slice_scatter, type_promotion_kind=None) def slice_scatter(x, src, dim=0, start=None, end=None, step=1): +<<<<<<< HEAD src = to_dtype(src, x.get_dtype()) +======= + assert x.get_dtype() == src.get_dtype() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x_loader = x.make_loader() dim = _validate_dim(x, dim, 0) dim_size = x.get_size()[dim] @@ -3184,6 +3357,11 @@ def long_tensor(data): @register_lowering(aten._local_scalar_dense) def _local_scalar_dense(data): +<<<<<<< HEAD +======= + from torch.fx.experimental.symbolic_shapes import resolve_unbacked_bindings + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This is interesting! Most lowerings return tensors, so you can just # return the buffer you allocated and it will get used (or not used, if # it's dead.) But _local_scalar_dense (aka item) returns an int, @@ -3274,6 +3452,10 @@ def inner_fn(index): ) +<<<<<<< HEAD +======= +@register_lowering(aten.full_like, type_promotion_kind=None) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def full_like(x, fill_value, **kwargs): return create_tensor_like(tensor_constructor(fill_value))(x, **kwargs) @@ -3717,7 +3899,10 @@ def index_put_as_masked_fill(self, indices, value, accumulate): def index_put_fallback(self, indices, values, accumulate): +<<<<<<< HEAD assert isinstance(V.graph.current_node.target, torch._ops.OpOverload) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ir.IndexPutFallback(V.graph.current_node.target, self, indices, values, accumulate) return self @@ -3841,10 +4026,15 @@ def indice_slice_from_randperm(indice): values = expand(values, expected_vals_size) # all guards are set above during broadcast_tensors and expand +<<<<<<< HEAD device = self.get_device() assert device is not None scatter = ir.Scatter( device=device, +======= + scatter = ir.Scatter( + device=self.get_device(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dtype=self.get_dtype(), inner_fn=values.make_loader(), ranges=expected_vals_size, # iter_ranges, @@ -4064,6 +4254,7 @@ def backend_reduce_str(reduce): assert reduce is None return None +<<<<<<< HEAD device = self.get_device() assert device is not None @@ -4071,6 +4262,12 @@ def backend_reduce_str(reduce): # zero out the corresponding elements first zero_out = ir.Scatter( device=device, +======= + if not include_self: + # zero out the corresponding elements first + zero_out = ir.Scatter( + device=self.get_device(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dtype=self.get_dtype(), inner_fn=lambda index: ops.constant(0, self.get_dtype()), ranges=index.get_size(), @@ -4089,7 +4286,11 @@ def backend_reduce_str(reduce): # self[i][index[i][j][k]][k] += src[i][j][k] # if dim == 1 # self[i][j][index[i][j][k]] += src[i][j][k] # if dim == 2 scatter = ir.Scatter( +<<<<<<< HEAD device=device, +======= + device=self.get_device(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dtype=self.get_dtype(), inner_fn=fn, ranges=index.get_size(), @@ -4120,7 +4321,11 @@ def upsample_nearestnd( x_loader = x.make_loader() i_sizes = x.get_size()[-n:] batch = x.get_size()[:-n] +<<<<<<< HEAD i_sizes = [V.graph.sizevars.guard_int(i) for i in i_sizes] +======= + i_sizes = [V.graph.sizevars.evaluate_static_shape(i) for i in i_sizes] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert len(scales_x) == n o_sizes = output_size @@ -4452,10 +4657,17 @@ def pooling_size(x, i, kernel_size, stride, padding, ceil_mode, *, dilation=None if V.graph.sizevars.size_hint((x_alt - 1) * stride[i] - x - padding[i]) >= 0: # Sliding windows must start within the input or left padding x_alt -= 1 # type: ignore[assignment] +<<<<<<< HEAD V.graph.sizevars.check_leq(0, x_alt * stride[i] - x - padding[i]) # type: ignore[arg-type] if V.graph.sizevars.size_hint(x_out - x_alt) == 0: # ceil mode is actually a no-op, lets guard on that V.graph.sizevars.check_equals(x_out, x_alt) +======= + V.graph.sizevars.guard_leq(0, x_alt * stride[i] - x - padding[i]) # type: ignore[arg-type] + if V.graph.sizevars.size_hint(x_out - x_alt) == 0: + # ceil mode is actually a no-op, lets guard on that + V.graph.sizevars.guard_equals(x_out, x_alt) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ceil_mode = False else: x_out = x_alt @@ -4562,10 +4774,17 @@ def fn_inner(idx, reduction_idx): ranges=new_size, reduction_ranges=kernel_size, ) +<<<<<<< HEAD if isinstance(result.data.data, Reduction): # type: ignore[attr-defined, union-attr] # Only realize if reduction isn't unrolled result.realize() if isinstance(offsets.data.data, Reduction): # type: ignore[attr-defined, union-attr] +======= + if isinstance(result.data.data, Reduction): # type: ignore[attr-defined] + # Only realize if reduction isn't unrolled + result.realize() + if isinstance(offsets.data.data, Reduction): # type: ignore[attr-defined] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Only realize if reduction isn't unrolled offsets.realize() @@ -4615,7 +4834,11 @@ def _pool_offsets_to_indices( [Sequence[Union[int, torch.SymInt]], Sequence[Union[int, torch.SymInt]]], torch._inductor.virtualized.OpsValue, ], +<<<<<<< HEAD ) -> Union[TensorBox, ShapeAsConstantBuffer]: +======= +) -> TensorBox: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) n_dim = len(kernel_size) offsets_loader = offsets.make_loader() window_size = sympy.sympify(functools.reduce(operator.mul, kernel_size)) @@ -4748,12 +4971,19 @@ def max_pool2d_with_indices_backward( x_stride: Optional[Sequence[Any]] if isinstance(x, TensorBox) and isinstance(x.data.data, Pointwise): # type: ignore[attr-defined] data = x.data.data # type: ignore[attr-defined] +<<<<<<< HEAD device = data.get_device() assert device is not None x_buffer = ir.ComputedBuffer( name=None, layout=ir.FlexibleLayout( device=device, +======= + x_buffer = ir.ComputedBuffer( + name=None, + layout=ir.FlexibleLayout( + device=data.get_device(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dtype=data.get_dtype(), size=data.get_size(), ), @@ -5018,8 +5248,13 @@ def _adaptive_avg_pool2d(x, output_size): *batch, h_in, w_in = x.get_size() +<<<<<<< HEAD h_in = V.graph.sizevars.guard_int(h_in) w_in = V.graph.sizevars.guard_int(w_in) +======= + h_in = V.graph.sizevars.evaluate_static_shape(h_in) + w_in = V.graph.sizevars.evaluate_static_shape(w_in) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) h_out, w_out = output_size @@ -5093,8 +5328,13 @@ def adaptive_max_pool2d(x, output_size): *batch, h_in, w_in = x.get_size() +<<<<<<< HEAD h_in = V.graph.sizevars.guard_int(h_in) w_in = V.graph.sizevars.guard_int(w_in) +======= + h_in = V.graph.sizevars.evaluate_static_shape(h_in) + w_in = V.graph.sizevars.evaluate_static_shape(w_in) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) h_out, w_out = output_size @@ -5171,6 +5411,7 @@ def _fractional_pooling_offsets(samples, in_sz, out_sz, kernel_sz, dim, ndims): samples_loader = samples.make_loader() def load(prefix, i): +<<<<<<< HEAD # Handle indexing for samples tensor correctly for different input dimensions # samples tensor always has shape (N, C, 2) for fractional_max_pool2d where: # - N=1 for 3D inputs (C,H,W), N=batch_size for 4D inputs (N,C,H,W) @@ -5193,6 +5434,9 @@ def load(prefix, i): else: # Fallback for unexpected tensor shapes sample = samples_loader([*prefix, ndims - 1 - dim]) +======= + sample = samples_loader([*prefix, ndims - 1 - dim]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) i_expr = ops.index_expr(i, samples.get_dtype()) diff = ops.index_expr(in_sz - kernel_sz, torch.int64) out_sz_expr = ops.index_expr(out_sz - 1, torch.int64) @@ -5271,11 +5515,17 @@ def increments_to_index(idx, reduction_idx): ranges=new_size, reduction_ranges=kernel_size, ) +<<<<<<< HEAD assert isinstance(result, TensorBox), result if isinstance(result.data.data, Reduction): # type: ignore[attr-defined] # Only realize if reduction isn't unrolled result.realize() assert isinstance(offsets, TensorBox), offsets +======= + if isinstance(result.data.data, Reduction): # type: ignore[attr-defined] + # Only realize if reduction isn't unrolled + result.realize() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(offsets.data.data, Reduction): # type: ignore[attr-defined] # Only realize if reduction isn't unrolled offsets.realize() @@ -5293,8 +5543,13 @@ def upsample_nearest2d_backward( x.realize_hint() *_batch, inp_h, inp_w = x.get_size() +<<<<<<< HEAD inp_h = V.graph.sizevars.guard_int(inp_h) inp_w = V.graph.sizevars.guard_int(inp_w) +======= + inp_h = V.graph.sizevars.evaluate_static_shape(inp_h) + inp_w = V.graph.sizevars.evaluate_static_shape(inp_w) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) *_batch, out_h, out_w = input_size @@ -5960,7 +6215,11 @@ def inner(x, axis=None, keepdims=False, *, dtype=None): ) result = Reduction.create(reduction_type=reduction_type, input_node=x, **kwargs) if isinstance( +<<<<<<< HEAD result.data.data, # type: ignore[attr-defined, attr-type, union-attr] +======= + result.data.data, # type: ignore[attr-defined] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Reduction, ): # Only realize if reduction isn't unrolled result.realize() @@ -6206,14 +6465,22 @@ def mutate_to(changed, val, unsafe_alias=False): if not isinstance(val, ir.StorageBox): # introduce a copy to handle views +<<<<<<< HEAD node = Pointwise.create( +======= + val = Pointwise.create( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device=changed.get_device(), dtype=changed.get_dtype(), inner_fn=val.make_loader(), ranges=changed.get_size(), +<<<<<<< HEAD ) assert isinstance(node, (BaseView, MutableBox)) val = node.data +======= + ).data +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(val, ir.StorageBox) if isinstance(changed_data, ir.StorageBox) and not ( @@ -6324,9 +6591,13 @@ def div_prim(a, b): if is_integral: return truncdiv(a, b) +<<<<<<< HEAD # Disable CPU optimization to avoid precision issues. # see https://github.com/pytorch/pytorch/issues/157959 if (divisor := get_constant_value(b)) is not None and a.get_device().type != "cpu": +======= + if (divisor := get_constant_value(b)) is not None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Replace divide by constant with multiply by reciprocal if divisor.value == 0: reciprocal = math.copysign(float("inf"), divisor.value) @@ -6788,7 +7059,11 @@ def make_triton_fallback(op): register_foreach_pointwise(aten._foreach_clamp_max.Scalar, minimum) register_foreach_pointwise(aten._foreach_reciprocal, reciprocal) register_foreach_pointwise(aten._foreach_sign, sign) +<<<<<<< HEAD foreach_copy = register_foreach_pointwise(aten._foreach_copy, copy) +======= +register_foreach_pointwise(aten._foreach_copy, copy) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # these are only encountered as outputs of the graph @@ -6827,9 +7102,12 @@ def fn(*args, **kwargs): register_foreach_inplace( aten._foreach_div_.Scalar, aten._foreach_div.Scalar, foreach_div_scalar ) +<<<<<<< HEAD register_foreach_inplace( aten._foreach_copy_.default, aten._foreach_copy.default, foreach_copy ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def register_inplace(aten_op, outplace_op): @@ -6892,9 +7170,13 @@ def sym_size(a, dim): # int, but you KNOW that int must always be a constant, # then you do not need trace that call at all (and just # constant propagate the integer as is.) +<<<<<<< HEAD assert isinstance(val, torch.SymInt), ( f"Expect val to be torch.SymInt but got val={val}" ) +======= + assert isinstance(val, torch.SymInt) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return val.node.expr @@ -6902,9 +7184,13 @@ def sym_size(a, dim): def sym_stride(a, dim): val = V.graph.current_node.meta["val"] # See Note [Can val be an int?] +<<<<<<< HEAD assert isinstance(val, torch.SymInt), ( f"Expect val to be torch.SymInt but got val={val}" ) +======= + assert isinstance(val, torch.SymInt) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return val.node.expr @@ -6987,6 +7273,7 @@ def resize(x, size, *, memory_format=None): and torch.utils.deterministic.fill_uninitialized_memory # type: ignore[attr-defined] ): if is_float_dtype(dtype): +<<<<<<< HEAD uninitialized_val = float("nan") elif is_integer_dtype(dtype): uninitialized_val = torch.iinfo(dtype).max @@ -6998,6 +7285,19 @@ def resize(x, size, *, memory_format=None): if V.graph.sizevars.statically_known_equals(old_numel, 0): # type: ignore[arg-type] return full(size, uninitialized_val, dtype=dtype, device=device) +======= + uninitalized_val = float("nan") + elif is_integer_dtype(dtype): + uninitalized_val = torch.iinfo(dtype).max + else: + uninitalized_val = True + else: + # using zero as that is what empty does + uninitalized_val = 0.0 + + if V.graph.sizevars.statically_known_equals(old_numel, 0): # type: ignore[arg-type] + return full(size, uninitalized_val, dtype=dtype, device=device) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x_flat = as_strided( x, @@ -7017,7 +7317,11 @@ def inner_fn(idx): flat_index_expr = ops.index_expr(flat_index, torch.int64) limit = ops.index_expr(old_numel, torch.int64) mask = ops.lt(flat_index_expr, limit) +<<<<<<< HEAD return ops.masked(mask, lambda: flat_loader([flat_index]), uninitialized_val) +======= + return ops.masked(mask, lambda: flat_loader([flat_index]), uninitalized_val) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out = Pointwise.create( device=device, dtype=dtype, inner_fn=inner_fn, ranges=list(size) @@ -7065,7 +7369,11 @@ def cond(pred, true_fn, false_fn, operands): @register_lowering(torch.ops.higher_order.while_loop, type_promotion_kind=None) +<<<<<<< HEAD def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs, stack_output=False): +======= +def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if any( isinstance(x, IRNode) and is_triton(x) for x in carried_inputs + additional_inputs @@ -7085,6 +7393,7 @@ def _map_output(out: Any): else: raise RuntimeError(f"NYI unsupported output type: {type(out)}") +<<<<<<< HEAD result = ir.WhileLoop.create( cond_fn, body_fn, carried_inputs, additional_inputs, stack_output ) @@ -7101,6 +7410,16 @@ def _map_output(out: Any): def invoke_subgraph(subgraph_fn: ir.Subgraph, identifier: str, *operands): result = ir.InvokeSubgraph.create(subgraph_fn, *operands) return list(map(TensorBox.create, result)) # type: ignore[call-overload] +======= + result = ir.WhileLoop.create(cond_fn, body_fn, carried_inputs, additional_inputs) + return list(map(_map_output, result)) + + +@register_lowering(torch.ops.higher_order.invoke_subgraph, type_promotion_kind=None) +def invoke_subgraph(subgraph_fn: ir.Subgraph, identifier: str, *operands): + result = ir.InvokeSubgraph.create(subgraph_fn, *operands) + return list(map(TensorBox.create, result)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register_lowering(torch._higher_order_ops.invoke_quant, type_promotion_kind=None) diff --git a/torch/_inductor/memory.py b/torch/_inductor/memory.py index 27ca4415c8f0e..505069d80fc71 100644 --- a/torch/_inductor/memory.py +++ b/torch/_inductor/memory.py @@ -4,14 +4,23 @@ import dataclasses import heapq import logging +<<<<<<< HEAD from typing import Callable, Optional, TYPE_CHECKING, TypedDict, Union from torch._environment import is_fbcode +======= +from typing import Callable, TYPE_CHECKING, TypedDict, Union + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._utils_internal import signpost_event from torch.utils._ordered_set import OrderedSet from .ir import MultiOutputLayout, NoneLayout +<<<<<<< HEAD from .utils import get_dtype_size +======= +from .utils import get_dtype_size, is_wait +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .virtualized import V @@ -76,11 +85,30 @@ def get_freeable_input_buf( Create and keep track of all input buffers that can be freed during the program Returns: +<<<<<<< HEAD A dictionary containing all freeable input buffers, keyed by their names. """ def _dep_size_hint(dep: Dep) -> int: return V.graph.get_dep_size_hint(dep) +======= + A dictionary containing all freeble input buffers, keyed by their names. + """ + + # this function is copied from torch/_inductor/scheduler.py + # TODO: would be nice to remove the try/except block for both places + def _dep_size_hint(dep: Dep) -> int: + res = 0 + try: + if not dep.has_unbacked_symbols(): + res = dep.numbytes_hint() + except KeyError: + # In at least one test (test/inductor/test_torchbind.py) we + # create a StarDep that doesn't exist in the graph and calling + # `has_unbacked_symbols()` throws an error. + pass + return res +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # get freeable input buffers' successor nodes and their sizes # note that different deps can have the same name, so we use name as keys @@ -88,6 +116,7 @@ def _dep_size_hint(dep: Dep) -> int: collections.defaultdict(OrderedSet) ) dep_name_to_size: dict[str, int] = dict() +<<<<<<< HEAD for node in nodes: for dep in node.read_writes.reads: @@ -102,6 +131,15 @@ def _dep_size_hint(dep: Dep) -> int: ): dep_name_to_succ_nodes[dep.name].add(node) dep_name_to_size[dep.name] = _dep_size_hint(dep) +======= + for node in nodes: + for dep in node.read_writes.reads: + if dep.name in graph_inputs and not dep.name.startswith( + ("primals_", "arg", "fwd_rng_state", "bwd_rng_state") + ): + dep_name_to_succ_nodes[dep.name].add(node) + dep_name_to_size[dep.name] = _dep_size_hint(dep) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # create FreeableInputBuffer objects and add them to the returned dictionary name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = dict() @@ -132,6 +170,7 @@ def compute_size_for_scheduler_buffer( buf1: at creation, 0 bytes allocated, when deleted, 10 bytes freed buf2: at creation, 0 bytes allocated, when deleted, 20 bytes freed +<<<<<<< HEAD When an operation mutates a buffer in-place, the scheduler creates a new buffer name to track the "before" and "after" states, even though they share the same memory. @@ -154,6 +193,8 @@ def compute_size_for_scheduler_buffer( The only memory events are the creation prior to op0, and the deletion following buf1. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Returns: A dictionary mapping a scheduler buffer to a tuple of (size_alloc, size_free). """ @@ -165,12 +206,33 @@ def compute_size_for_scheduler_buffer( def _compute_and_update_buf_size( sched_buf: SchedulerBuffer, user_of_MultiOutputLayout: bool = False ) -> int: +<<<<<<< HEAD if sched_buf.get_name() in V.graph.scheduler.mutation_real_name: sched_buf_to_size[sched_buf.get_name()] = (0, 0) return 0 elif isinstance(sched_buf.node.layout, NoneLayout): sched_buf_to_size[sched_buf.get_name()] = (0, 0) return 0 +======= + if isinstance(sched_buf.node.layout, NoneLayout): + _size = 0 + # for a wait tensor op, its schedulerBuffer NoneLayout layout. However, + # the schedulerBuffer is treated as a mutation of the collective output + # so it needs to inherit the size of the collectives + if ( + sched_buf.defining_op + and is_wait(sched_buf.defining_op.node) + and sched_buf.get_mutations() + ): + mutated_buf_name = sched_buf.get_mutations()[0] + _size = ( + sched_buf_to_size[mutated_buf_name][1] + if mutated_buf_name in sched_buf_to_size + else 0 + ) + sched_buf_to_size[sched_buf.get_name()] = (_size, _size) + return _size +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif isinstance(sched_buf.node.layout, MultiOutputLayout): size_alloc = 0 for user in sched_buf.users: @@ -223,6 +285,7 @@ def assign_memory_planning_info_for_scheduler_buffers( for dep in node.unmet_dependencies: dep_name_to_succ_nodes[dep.name].add(node) +<<<<<<< HEAD # iterate in reverse, so dependencies are picked up transitively. for mutating_buf_name, real_buf_name in reversed( V.graph.scheduler.mutation_real_name.items() @@ -231,6 +294,8 @@ def assign_memory_planning_info_for_scheduler_buffers( mutating_buf_name ] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # populate the MemoryPlanningInfoForBuffer attribute to each scheduler buffer # note: there are scheduler buffers not in dep_name_to_succ_nodes (e.g., graph outputs) for buf_name in name_to_buf.keys(): @@ -250,6 +315,7 @@ def assign_memory_planning_info_for_scheduler_nodes( """ Assign to each scheduler node its predecessor and successor nodes. """ +<<<<<<< HEAD node_to_pred_nodes: dict[BaseSchedulerNode, OrderedSet[BaseSchedulerNode]] = ( collections.defaultdict(OrderedSet) @@ -261,11 +327,29 @@ def assign_memory_planning_info_for_scheduler_nodes( # collect all predecessors using existing successor mappings for node in nodes: +======= + from .scheduler import SchedulerBuffer + + for index, node in enumerate(nodes): + size_alloc = sum(buffer.mpi_buffer.size_alloc for buffer in node.get_outputs()) + pred_buffers = OrderedSet[Union[SchedulerBuffer, FreeableInputBuffer]]() + for dep in node.read_writes.reads: + if dep.name in name_to_buf and dep in node.unmet_dependencies: + pred_buffers.add(name_to_buf[dep.name]) + elif dep.name in name_to_freeable_input_buf: + pred_buffers.add(name_to_freeable_input_buf[dep.name]) + pred_nodes = OrderedSet( + name_to_fused_node[pred_buffer.defining_op_name()] + for pred_buffer in pred_buffers + if (isinstance(pred_buffer, SchedulerBuffer)) + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) succ_nodes = OrderedSet( succ_node for buffer in node.get_outputs() for succ_node in buffer.mpi_buffer.succ_nodes ) +<<<<<<< HEAD node_to_succ_nodes[node] = succ_nodes # For each successor, add current node as its predecessor @@ -297,10 +381,18 @@ def assign_memory_planning_info_for_scheduler_nodes( size=size_alloc, pred_buffers=node_to_pred_buffers[node], pred_nodes=node_to_pred_nodes[node], +======= + node.mpi_node = MemoryPlanningInfoForNode( + index=index, + size=size_alloc, + pred_buffers=pred_buffers, + pred_nodes=pred_nodes, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) succ_nodes=succ_nodes, ) +<<<<<<< HEAD # map each scheduler buffer to its size, start step, and end step @dataclasses.dataclass class BufferInfo: @@ -400,6 +492,8 @@ def _get_end_step_and_snode( return buf_info_list, node_to_step, buf_to_snode_last_use +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def estimate_peak_memory( nodes: list[BaseSchedulerNode], name_to_freeable_input_buf: dict[str, FreeableInputBuffer], @@ -414,9 +508,74 @@ def estimate_peak_memory( List[int]: memory usage at each node (or each step). """ +<<<<<<< HEAD buf_info_list, _, _ = compute_memory_timeline( nodes, name_to_freeable_input_buf, graph_outputs ) +======= + # map each scheduler buffer to its size, start step, and end step + @dataclasses.dataclass + class BufferInfo: + buffer: Union[SchedulerBuffer, FreeableInputBuffer] + size_alloc: int + size_free: int + start_step: int + end_step: int + + # get the execution step of each node, this will be used to determine + # the end_step of buffers + node_to_step: dict[BaseSchedulerNode, int] = { + node: step for step, node in enumerate(nodes) + } + + # get buffers' size and liveliness information + buf_info_list: list[BufferInfo] = [] + # 1. for freeable input buffers + for buf_name, input_buf in name_to_freeable_input_buf.items(): + end_step = ( + len(nodes) - 1 + if buf_name in graph_outputs + else max( + node_to_step[succ_node] for succ_node in input_buf.mpi_buffer.succ_nodes + ) + ) + buf_info_list.append( + BufferInfo( + input_buf, + input_buf.mpi_buffer.size_free, + input_buf.mpi_buffer.size_free, + 0, + end_step, + ) + ) + + # 2. for scheduler buffers + for step, node in enumerate(nodes): + for sched_buf in node.get_outputs(): + # note: it is possible for a non-graph-output sched_buf to have no succ_nodes and + # to be only used by its defining op (e.g., due to fusion when all consumers of + # the buffer are fused with its defining op). In such cases, end_step is step. + end_step = ( + len(nodes) - 1 + if sched_buf.get_name() in graph_outputs + else max( + [ + node_to_step[succ_node] + for succ_node in sched_buf.mpi_buffer.succ_nodes + ], + default=step, + ) + ) + buf_info_list.append( + BufferInfo( + sched_buf, + sched_buf.mpi_buffer.size_alloc, + sched_buf.mpi_buffer.size_free, + step, + end_step, + ) + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # incremental memory changes at each step memory = [0 for _ in range(len(nodes) + 1)] @@ -438,6 +597,7 @@ def estimate_peak_memory( return (max_memory, memories_at_nodes) +<<<<<<< HEAD @dataclasses.dataclass class SNodeMemory: size_alloc: int @@ -505,6 +665,8 @@ def estimate_peak_memory_allocfree( ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def topological_sort_lpmf( nodes: list[BaseSchedulerNode], name_to_freeable_input_buf: dict[str, FreeableInputBuffer], @@ -518,7 +680,11 @@ def topological_sort_lpmf( Buffer memory optimization for video codec application modeled in Simulink https://www.cs.york.ac.uk/rts/docs/DAC-1964-2006/PAPERS/2006/DAC06/PDFFILES/P0689.PDF +<<<<<<< HEAD The algorithm maintains the max memory so far. +======= + The algorithm maintain the max memory so far. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) At every iteration, for each scheduleable node, it computes: - how much memory needs to be allocated for the output buffers of this node; - how much memory can be freed as a result of executing this node. @@ -744,6 +910,7 @@ def visit(n: BaseSchedulerNode) -> None: return result +<<<<<<< HEAD def validate_graph_acyclic(nodes: list[BaseSchedulerNode]) -> None: """ Validate that the graph is acyclic by checking predecessor relationships. @@ -831,6 +998,8 @@ def validate_unique_buffer_names( ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def prepare_planning_info( nodes: list[BaseSchedulerNode], name_to_buf: dict[str, SchedulerBuffer], @@ -887,6 +1056,7 @@ def reorder_for_peak_memory( graph_outputs, ) +<<<<<<< HEAD # Validate planning info before proceeding with reordering try: validate_graph_acyclic(nodes) @@ -896,6 +1066,8 @@ def reorder_for_peak_memory( if not is_fbcode(): # TODO: remove after ensuring OSS side is safe raise +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # keep track of the peak memory estimates of different methods peak_memory_diff_methods: list[PeakMemoryResult] = [] peak_memory_diff_methods.append( @@ -922,8 +1094,11 @@ def reorder_for_peak_memory( torch_log.info("%s peak memory: %d", method.__name__, peak_memory) except Exception as e: torch_log.error("Failed to reorder for %s: %s", method.__name__, e) +<<<<<<< HEAD if not is_fbcode(): # TODO: remove after ensuring OSS side is safe raise +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) signpost_event( category="inductor", diff --git a/torch/_inductor/mkldnn_ir.py b/torch/_inductor/mkldnn_ir.py index 866c22abd0699..4a861d04137bb 100644 --- a/torch/_inductor/mkldnn_ir.py +++ b/torch/_inductor/mkldnn_ir.py @@ -1,11 +1,19 @@ # mypy: allow-untyped-defs from collections.abc import Sequence +<<<<<<< HEAD from typing import Any, Optional, Union +======= +from typing import Any, Optional +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import sympy import torch +<<<<<<< HEAD from torch._prims_common import make_channels_last_strides_for, StrideType +======= +from torch._prims_common import make_channels_last_strides_for +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.utils._ordered_set import OrderedSet from .ir import ( @@ -14,7 +22,10 @@ FlexibleLayout, get_device_type, ir_node_to_tensor, +<<<<<<< HEAD IRNode, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) is_contiguous_storage_and_layout, Layout, may_convert_to_optional, @@ -22,7 +33,10 @@ MultiOutputLayout, MutationOutput, NoneLayout, +<<<<<<< HEAD ShapeAsConstantBuffer, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TensorBox, ) from .utils import convert_shape_to_inductor, pad_listlike, SUPPORTED_MKLDNN_DEVICES @@ -177,7 +191,11 @@ def _original_deconv_weight_size( if ( dynamic_shapes or get_device_type(x) == "xpu" ) and is_contiguous_storage_and_layout(x): +<<<<<<< HEAD output_stride: StrideType = FlexibleLayout.contiguous_strides(output_size) +======= + output_stride = FlexibleLayout.contiguous_strides(output_size) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Currently we don't support channel last for the situation that stride of input's batch dim is 0, # eg. input_size = (1, 1280, 64, 64), but input_stride=(0, 1, 81920, 1280). # So we use NCHW hear instead. @@ -513,13 +531,17 @@ def __init__( inputs, constant_args=(), ) -> None: +<<<<<<< HEAD self.device_type = get_device_type(inputs[0]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__( layout, inputs, constant_args, None, op_overload=torch.ops.mkldnn._convolution_transpose_pointwise.default, +<<<<<<< HEAD cpp_kernel_name=f"aoti_torch_{self.device_type}_mkldnn__convolution_transpose_pointwise", ) @@ -527,6 +549,13 @@ def codegen(self, wrapper): wrapper.include_extra_header( f"torch/csrc/inductor/aoti_torch/c/shim_{self.device_type}.h" ) +======= + cpp_kernel_name="aoti_torch_cpu_mkldnn__convolution_transpose_pointwise", + ) + + def codegen(self, wrapper): + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().codegen(wrapper) @classmethod @@ -593,7 +622,10 @@ def __init__( - const_args is: [bias, stride, padding, dilation, groups, x_scale, x_zp, o_scale, o_zp, fp32_output, unary_attr, unary_scalars, unary_algorithm] """ +<<<<<<< HEAD self.device_type = get_device_type(inputs[0]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.has_bias = len(inputs) == 5 super().__init__( layout, @@ -601,6 +633,7 @@ def __init__( constant_args, None, op_overload=torch.ops.onednn.qconv_pointwise.default, +<<<<<<< HEAD cpp_kernel_name=f"aoti_torch_{self.device_type}__qconv_pointwise_tensor", ) @@ -608,6 +641,13 @@ def codegen(self, wrapper): wrapper.include_extra_header( f"torch/csrc/inductor/aoti_torch/c/shim_{self.device_type}.h" ) +======= + cpp_kernel_name="aoti_torch_cpu__qconv_pointwise_tensor", + ) + + def codegen(self, wrapper): + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().codegen(wrapper) if isinstance(self.layout, Layout): self.codegen_size_asserts(wrapper) @@ -616,8 +656,13 @@ def codegen(self, wrapper): def create( cls, qx: "TensorBox", +<<<<<<< HEAD x_scale: Union["ShapeAsConstantBuffer", "TensorBox"], x_zero_point: Union["ShapeAsConstantBuffer", "TensorBox"], +======= + x_scale: "TensorBox", + x_zero_point: "TensorBox", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) qw: "TensorBox", # qw w_scale: "TensorBox", w_zero_point: "TensorBox", @@ -652,7 +697,11 @@ def create( groups, transposed, output_padding, +<<<<<<< HEAD [x_scale, x_zero_point, w_scale, w_zero_point], # type: ignore[list-item] +======= + [x_scale, x_zero_point, w_scale, w_zero_point], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # swap padding and stride to align with functional conv arg order if bias is None: @@ -700,7 +749,10 @@ def __init__( - const_args [b, stride, padding, dilation, groups, o_scale, o_zp, output_dtype, accum_scale, accum_zp, binary_attr, alpha, unary_attr, unary_scalars, unary_algorithm] """ +<<<<<<< HEAD self.device_type = get_device_type(inputs[0]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.has_bias = len(inputs) == 8 self.idx_for_inplace_sum = 6 super().__init__( @@ -709,6 +761,7 @@ def __init__( constant_args, None, op_overload=torch.ops.onednn.qconv2d_pointwise.binary, +<<<<<<< HEAD cpp_kernel_name=( f"aoti_torch_{self.device_type}__qconv2d_pointwise_binary_tensor" ), @@ -718,12 +771,23 @@ def codegen(self, wrapper): wrapper.include_extra_header( f"torch/csrc/inductor/aoti_torch/c/shim_{self.device_type}.h" ) +======= + cpp_kernel_name=("aoti_torch_cpu__qconv2d_pointwise_binary_tensor"), + ) + + def codegen(self, wrapper): + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().codegen(wrapper) if isinstance(self.layout, Layout): self.codegen_size_asserts(wrapper) def get_mutation_names(self) -> Sequence[str]: +<<<<<<< HEAD return [self.input_name(self.idx_for_inplace_sum)] +======= + return [self.inputs[self.idx_for_inplace_sum].get_name()] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: return OrderedSet() @@ -845,10 +909,17 @@ def create(cls, x, packed_w, orig_w, B, batch_size): else: constant_args.insert(0, None) +<<<<<<< HEAD device = x.get_device() assert device is not None return MKLPackedLinear( layout=FixedLayout(device, x.get_dtype(), output_size, output_stride), +======= + return MKLPackedLinear( + layout=FixedLayout( + x.get_device(), x.get_dtype(), output_size, output_stride + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inputs=inputs, constant_args=constant_args, ) @@ -861,13 +932,17 @@ def __init__( inputs, constant_args=(), ) -> None: +<<<<<<< HEAD self.device_type = get_device_type(inputs[0]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__( layout, inputs, constant_args, None, op_overload=torch.ops.mkldnn._linear_pointwise.default, +<<<<<<< HEAD cpp_kernel_name=f"aoti_torch_{self.device_type}__linear_pointwise", ) @@ -875,6 +950,13 @@ def codegen(self, wrapper): wrapper.include_extra_header( f"torch/csrc/inductor/aoti_torch/c/shim_{self.device_type}.h" ) +======= + cpp_kernel_name="aoti_torch_cpu__linear_pointwise", + ) + + def codegen(self, wrapper): + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().codegen(wrapper) @classmethod @@ -893,12 +975,18 @@ def create(cls, x, w, B, attr, scalars, algorithm): else: constant_args.insert(0, None) +<<<<<<< HEAD device = x.get_device() assert device is not None packed = LinearUnary( layout=FixedLayout( device=device, +======= + packed = LinearUnary( + layout=FixedLayout( + device=x.get_device(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dtype=x.get_dtype(), size=output_size, ), @@ -920,13 +1008,17 @@ def __init__( inputs, constant_args=(), ) -> None: +<<<<<<< HEAD self.device_type = get_device_type(inputs[0]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__( layout, inputs, constant_args, None, op_overload=torch.ops.mkldnn._linear_pointwise.binary, +<<<<<<< HEAD cpp_kernel_name=f"aoti_torch_{self.device_type}__linear_pointwise_binary", ) @@ -934,6 +1026,13 @@ def codegen(self, wrapper): wrapper.include_extra_header( f"torch/csrc/inductor/aoti_torch/c/shim_{self.device_type}.h" ) +======= + cpp_kernel_name="aoti_torch_cpu__linear_pointwise_binary", + ) + + def codegen(self, wrapper): + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().codegen(wrapper) @classmethod @@ -953,11 +1052,17 @@ def create(cls, x, y, w, B, attr): else: constant_args.insert(0, B) +<<<<<<< HEAD device = x.get_device() assert device is not None packed = LinearBinary( layout=FixedLayout( device=device, +======= + packed = LinearBinary( + layout=FixedLayout( + device=x.get_device(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dtype=x.get_dtype(), size=output_size, ), @@ -988,7 +1093,10 @@ def __init__( - const_args is: [bias, x_scale, x_zp, o_scale, o_zp, fp32_output, unary_attr, unary_scalars, unary_algorithm] """ +<<<<<<< HEAD self.device_type = get_device_type(inputs[0]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.has_bias = has_bias super().__init__( layout, @@ -996,6 +1104,7 @@ def __init__( constant_args, None, op_overload=(torch.ops.onednn.qlinear_pointwise.tensor), +<<<<<<< HEAD cpp_kernel_name=( f"aoti_torch_{self.device_type}__qlinear_pointwise_tensor" ), @@ -1005,6 +1114,13 @@ def codegen(self, wrapper): wrapper.include_extra_header( f"torch/csrc/inductor/aoti_torch/c/shim_{self.device_type}.h" ) +======= + cpp_kernel_name=("aoti_torch_cpu__qlinear_pointwise_tensor"), + ) + + def codegen(self, wrapper): + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().codegen(wrapper) if isinstance(self.layout, Layout): @@ -1076,7 +1192,10 @@ def __init__( - const_args is: [bias, o_scale, o_zp, fp32_output, binary_attr, alpha, unary_attr, unary_scalars, unary_algorithm] """ +<<<<<<< HEAD self.device_type = get_device_type(inputs[0]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.has_bias = has_bias self.idx_for_inplace_sum = 6 super().__init__( @@ -1085,6 +1204,7 @@ def __init__( constant_args, None, op_overload=(torch.ops.onednn.qlinear_pointwise.binary_tensor), +<<<<<<< HEAD cpp_kernel_name=f"aoti_torch_{self.device_type}__qlinear_pointwise_binary_tensor", ) @@ -1092,6 +1212,13 @@ def codegen(self, wrapper): wrapper.include_extra_header( f"torch/csrc/inductor/aoti_torch/c/shim_{self.device_type}.h" ) +======= + cpp_kernel_name="aoti_torch_cpu__qlinear_pointwise_binary_tensor", + ) + + def codegen(self, wrapper): + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().codegen(wrapper) if isinstance(self.layout, Layout): self.codegen_size_asserts(wrapper) @@ -1099,9 +1226,13 @@ def codegen(self, wrapper): def get_mutation_names(self) -> Sequence[str]: binary_post_op = self.constant_args[-5] if binary_post_op == "sum": +<<<<<<< HEAD input = self.inputs[self.idx_for_inplace_sum] assert isinstance(input, IRNode) return [input.get_name()] +======= + return [self.inputs[self.idx_for_inplace_sum].get_name()] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: return [] @@ -1252,10 +1383,15 @@ def create( train, ] +<<<<<<< HEAD device = x.get_device() assert device is not None packed = MkldnnRnnLayer( MultiOutputLayout(device=device), +======= + packed = MkldnnRnnLayer( + MultiOutputLayout(device=x.get_device()), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inputs=inputs, constant_args=constant_args, ) @@ -1276,7 +1412,11 @@ def get_strides_of_lstm_output(output_shape, batch_first): output_ir = [ MultiOutput( FixedLayout( +<<<<<<< HEAD x.get_device(), # type: ignore[arg-type] +======= + x.get_device(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x.get_dtype(), output_size, output_stride, diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py index 3b3a7b072534a..442a52e1f9e87 100644 --- a/torch/_inductor/mkldnn_lowerings.py +++ b/torch/_inductor/mkldnn_lowerings.py @@ -1,6 +1,10 @@ # mypy: allow-untyped-defs import functools +<<<<<<< HEAD from typing import Optional, Union +======= +from typing import Optional +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch.utils._pytree as pytree @@ -35,6 +39,7 @@ def create_int8_compensation( x_scale: ir.TensorBox, x_zp: ir.TensorBox, w_scale: ir.TensorBox, +<<<<<<< HEAD ) -> tuple[ bool, Union[ir.TensorBox, ir.ShapeAsConstantBuffer], @@ -42,13 +47,25 @@ def create_int8_compensation( ]: x_w_scale: Optional[Union[ir.TensorBox, ir.ShapeAsConstantBuffer]] = None use_int8_fast_compensation_path = all( +======= +) -> tuple[bool, ir.TensorBox, Optional[ir.TensorBox]]: + use_int8_fast_compensation_path = False + weight_compens = None + x_w_scale = None + if all( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) isinstance(item, ir.TensorBox) and item.get_name() in V.graph.constants and hasattr(item.data, "data") and isinstance(item.data.data, ir.ConstantBuffer) for item in [x_scale, x_zp, w_scale] +<<<<<<< HEAD ) if use_int8_fast_compensation_path: +======= + ): + use_int8_fast_compensation_path = True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x_w_scale_tensor = ( V.graph.constants[x_scale.get_name()] * V.graph.constants[w_scale.get_name()] @@ -70,7 +87,11 @@ def create_int8_compensation( weight_compens_tensor, name=packed_weight.get_name() + "_BMatrixCompens", ) +<<<<<<< HEAD return ( # type: ignore[return-type] +======= + return ( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) use_int8_fast_compensation_path, weight_compens, x_w_scale, @@ -147,12 +168,21 @@ def grouped_gemm_lowering( choices: list[ChoiceCaller] = [] *_, layout, x, _ = mm_args(x, permute(w[0], [1, 0]), layout=layout) +<<<<<<< HEAD kwargs = { "has_bias": [bias is not None for bias in b], "trans_w": True, "epilogue_creator": None, "act_mapping": dict.fromkeys(range(num_gemm), x), } +======= + kwargs = dict( + has_bias=[bias is not None for bias in b], + trans_w=True, + epilogue_creator=None, + act_mapping=dict.fromkeys(range(num_gemm), x), + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) input_nodes = [x, *w] input_nodes.extend([bias for bias in b if bias is not None]) @@ -184,7 +214,11 @@ def grouped_gemm_lowering( if len(x_size) > 2: for gemm_idx in range(num_gemm): return_tensors[gemm_idx] = view( +<<<<<<< HEAD return_tensors[gemm_idx], # type: ignore[arg-type] +======= + return_tensors[gemm_idx], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) (*x_size[:-1], return_tensors[gemm_idx].get_size()[-1]), ) return return_tensors @@ -341,7 +375,11 @@ def linear_unary( # GEMM template needs 2D input, normalize input shape here x = view(x, [-1, x_size[-1]]) if b is not None: +<<<<<<< HEAD b = ir.ExternKernel.realize_input(b) # type: ignore[assignment] +======= + b = ir.ExternKernel.realize_input(b) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) choices: list[ChoiceCaller] = [] if config.max_autotune or config.max_autotune_gemm: transposed_w = permute(w, [1, 0]) @@ -353,6 +391,7 @@ def epilogue_creator(buf): buf, attr, scalars=scalars, algorithm=algorithm ) +<<<<<<< HEAD kwargs = { "has_bias": b is not None, "trans_w": True, @@ -360,6 +399,13 @@ def epilogue_creator(buf): None if attr == "none" else epilogue_creator ), } +======= + kwargs = dict( + has_bias=b is not None, + trans_w=True, + epilogue_creator=None if attr == "none" else epilogue_creator, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if b is not None: kwargs["input_indices"] = [2, 0, 1] # type: ignore[assignment] CppGemmTemplate.add_choices( @@ -406,7 +452,11 @@ def linear_binary( if len(y_size) > 2: y = view(y, [-1, y_size[-1]]) if b is not None: +<<<<<<< HEAD b = ir.ExternKernel.realize_input(b) # type: ignore[assignment] +======= + b = ir.ExternKernel.realize_input(b) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) choices: list[ChoiceCaller] = [] if config.max_autotune or config.max_autotune_gemm: transposed_w = permute(w, [1, 0]) @@ -418,12 +468,20 @@ def linear_binary( def epilogue_creator(buf): return create_epilogue_with_attr(buf, attr, other=y) +<<<<<<< HEAD kwargs = { "has_bias": b is not None, "trans_w": True, "epilogue_creator": epilogue_creator, } +======= + kwargs = dict( + has_bias=b is not None, + trans_w=True, + epilogue_creator=epilogue_creator, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kwargs["input_indices"] = [0, 2, 1] if b is None else [3, 0, 2, 1] CppGemmTemplate.add_choices( choices, @@ -634,8 +692,13 @@ def qconvolution_binary( return TensorBox.create( mkldnn_ir.QConvPointWiseBinaryPT2E.create( x, +<<<<<<< HEAD x_scale, # type: ignore[arg-type] x_zp, # type: ignore[arg-type] +======= + x_scale, + x_zp, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) packed_weight, w_scale, w_zp, @@ -675,8 +738,13 @@ def qlinear_unary( algorithm, layout=None, ): +<<<<<<< HEAD assert packed_weight.get_dtype() in [torch.int8, torch.float8_e4m3fn], ( "Only int8 and e4m3fn weights are supported by oneDNN qlinear." +======= + assert packed_weight.get_dtype() is torch.int8, ( + "Only int8 weights are supported by oneDNN qlinear." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) x_size = x.get_size() if len(x_size) > 2: @@ -732,7 +800,11 @@ def qlinear_unary( ): # W_zp might be a ConstantBuffer with int64, convert it to int32 w_zp_tensor = V.graph.constants[w_zp.get_name()].to(torch.int32) +<<<<<<< HEAD w_zp = V.graph.add_tensor_constant( # type: ignore[assignment] +======= + w_zp = V.graph.add_tensor_constant( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.tensor(w_zp_tensor, dtype=torch.int32), name=w_zp.get_name() ) @@ -1035,7 +1107,11 @@ def qlinear_binary( ir.ConstantBuffer, ): w_zp_tensor = V.graph.constants[w_zp.get_name()].to(torch.int32) +<<<<<<< HEAD w_zp = V.graph.add_tensor_constant( # type: ignore[assignment] +======= + w_zp = V.graph.add_tensor_constant( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.tensor(w_zp_tensor, dtype=torch.int32), name=w_zp.get_name() ) if binary_attr == "sum": diff --git a/torch/_inductor/ops_handler.py b/torch/_inductor/ops_handler.py index a52257c61480c..3642cedd39e1a 100644 --- a/torch/_inductor/ops_handler.py +++ b/torch/_inductor/ops_handler.py @@ -31,7 +31,10 @@ "prod", "sum", "xor_sum", +<<<<<<< HEAD "online_softmax_reduce", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] @@ -706,9 +709,12 @@ def placeholder(self, index: int) -> T: """This is a fake op used in analysis but not codegen""" raise NotImplementedError +<<<<<<< HEAD def device_assert_async(self, cond: T, msg: str) -> T: raise NotImplementedError +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _ignore_op_re = re.compile(r"_.*|paren").fullmatch @@ -791,9 +797,12 @@ def {target}(self, {", ".join(args)}): if target in OP_NAMES: setattr(cls, target, impl) +<<<<<<< HEAD def device_assert_async(self, cond, msg): return None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DefaultHandler._init_cls() @@ -939,9 +948,12 @@ def sort(dtypes, values, stable, descending): def indirect_indexing(index_var, size, check=True, wrap_neg=True) -> sympy.Symbol: return sympy_index_symbol(str(index_var)) +<<<<<<< HEAD def device_assert_async(self, cond, msg): return None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class KernelFormatterHandler(DefaultHandler): def __init__(self, parent_handler: OpsHandler[Any]): @@ -1008,9 +1020,12 @@ def getvalue(self, result): self._output.writeline(f"return {result}") return self._output.getvalue() +<<<<<<< HEAD def device_assert_async(self, cond, msg: str): return f"ops.device_assert_async({cond}, {msg})" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class WrapperHandler(DefaultHandler): def __init__(self, inner: OpsHandler[Any]): diff --git a/torch/_inductor/output_code.py b/torch/_inductor/output_code.py index 955c00c51d0b9..66c2d27ca4335 100644 --- a/torch/_inductor/output_code.py +++ b/torch/_inductor/output_code.py @@ -41,16 +41,24 @@ ) from torch._inductor.freezing_utils import has_frozen_params, is_frozen_param from torch._inductor.utils import ( +<<<<<<< HEAD _unstable_customized_partition_wrapper, align_inputs_from_check_idxs, BoxedBool, CUDAGraphWrapperMetadata, +======= + align_inputs_from_check_idxs, + BoxedBool, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GraphPartitionMap, InputType, output_node, set_tracing_context_output_strides, ) +<<<<<<< HEAD from torch.autograd.profiler import record_function +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.utils._ordered_set import OrderedSet from . import config @@ -424,8 +432,11 @@ class CompiledFxGraph(OutputCode): # fx graph. The expression must be generated by: # ShapeEnv.produce_guards_expression() guards_expr: Optional[str] +<<<<<<< HEAD inductor_provenance_mapping_str: Optional[str] inductor_provenance_stack_traces_str: Optional[str] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cudagraph_info: Optional[CudagraphCachedInfo] partition_maps: Optional[list[GraphPartitionMap]] @@ -452,8 +463,11 @@ def __init__( runnable_graph_str: str, inductor_post_grad_graph_str: str, compiled_fn_runner: Optional[Any] = None, +<<<<<<< HEAD inductor_provenance_mapping_str: Optional[str] = None, inductor_provenance_stack_traces_str: Optional[str] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: self.current_callable = current_callable self.compiled_fn_runner = compiled_fn_runner @@ -468,8 +482,11 @@ def __init__( self.source_code = f.read() self.runnable_graph_str = runnable_graph_str self.inductor_post_grad_graph_str = inductor_post_grad_graph_str +<<<<<<< HEAD self.inductor_provenance_mapping_str = inductor_provenance_mapping_str self.inductor_provenance_stack_traces_str = inductor_provenance_stack_traces_str +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.cache_linemap = graph.cache_linemap # TODO - ordered set self.device_types = OrderedSet(graph.device_types) @@ -589,6 +606,7 @@ def __del__(self) -> None: def __call__(self, inputs: Sequence[Any]) -> Any: assert self.current_callable is not None +<<<<<<< HEAD if ( torch._inductor.debug.RECORD_GRAPH_EXECUTION @@ -611,6 +629,10 @@ def __call__(self, inputs: Sequence[Any]) -> Any: f"## Call CompiledFxGraph {self._fx_graph_cache_key} ##" ): return self.current_callable(inputs) +======= + try: + return self.current_callable(inputs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) finally: get_runtime_metrics_context().finish() AutotuneCacheBundler.end_compile() @@ -630,6 +652,7 @@ def post_compile( This runs whether or not we have a cache hit, and always runs directly after we get a CompiledFxGraph. The results of this function are *not* saved in the cache itself. """ +<<<<<<< HEAD if config.graph_partition and _unstable_customized_partition_wrapper.wrapper: # Mechanically apply user-specified cudagraph wrappers without modification assert self.recursively_apply_fns is not None @@ -647,6 +670,8 @@ def post_compile( self.recursively_apply_fns(customized_wrappers_with_metadata) return +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) set_tracing_context_output_strides(example_inputs, self) assert graph_kwargs["cudagraphs"] is not None assert graph_kwargs["is_backward"] is not None @@ -765,7 +790,11 @@ class CompiledAOTI(OutputCode): Class holding an AOTInductor compiled so. """ +<<<<<<< HEAD filename: Union[str, list[Union[str, Weights]], torch.fx.GraphModule] +======= + filename: Union[str, list[Union[str, Weights]]] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __call__(self, inputs: Sequence[Any]) -> Any: raise NotImplementedError("NYI") diff --git a/torch/_inductor/package/package.py b/torch/_inductor/package/package.py index bd11d033cadb3..7a996f59e47bb 100644 --- a/torch/_inductor/package/package.py +++ b/torch/_inductor/package/package.py @@ -105,7 +105,11 @@ def load_package( run_single_threaded: bool = False, num_runners: int = 1, device_index: int = -1, +<<<<<<< HEAD ) -> AOTICompiledModel: +======= +) -> AOTICompiledModel: # type: ignore[type-arg] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: pt2_contents = load_pt2( path, diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index e8210f1e80f81..28f0250dc6eaa 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -63,11 +63,18 @@ from torch._prims_common import is_integer_dtype from torch._subclasses.fake_tensor import unset_fake_temporarily from torch.fx.experimental.proxy_tensor import make_fx +<<<<<<< HEAD from torch.fx.experimental.symbolic_shapes import guard_or_false, statically_known_true from torch.fx.graph_module import _get_attr from torch.fx.immutable_collections import immutable_dict, immutable_list from torch.fx.passes.graph_transform_observer import GraphTransformObserver from torch.fx.traceback import preserve_node_meta +======= +from torch.fx.experimental.symbolic_shapes import statically_known_true +from torch.fx.graph_module import _get_attr +from torch.fx.immutable_collections import immutable_dict, immutable_list +from torch.fx.passes.graph_transform_observer import GraphTransformObserver +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.utils._ordered_set import OrderedSet from .._functorch import config as functorch_config @@ -87,8 +94,11 @@ Constant = Any NodeOrConstant = Union[Constant, torch.fx.Node] +<<<<<<< HEAD backend = os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_BACKEND", "inductor") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class SearchFn(Protocol): __name__: str @@ -130,7 +140,11 @@ def _transfer_meta( # transfer metadata after pattern matching occurs. # skip "val" and "tensor_meta" because this info is too specific; it's unlikely # to remain accurate after pattern matching has occurred. +<<<<<<< HEAD if config.trace.provenance_tracking_level == 1: +======= + if config.trace.enabled: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # We handle "from_node" field of the node meta specially to record that the new node comes from the old_node. new_from_node = new_meta.get("from_node", []).copy() new_from_node.append(NodeSource(old_node, pass_name, NodeSourceAction.REPLACE)) @@ -146,8 +160,11 @@ def _transfer_meta( for k, v in old_node.meta.items() if k in torch.fx.proxy._COPY_META_FIELDS ) +<<<<<<< HEAD if "stack_trace" in old_node.meta: new_meta["stack_trace"] = old_node.meta["stack_trace"] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class Match: @@ -323,12 +340,16 @@ def record(node: torch.fx.Node, val: Any) -> None: ] else: +<<<<<<< HEAD example_vals = torch.fx.map_arg( args, lambda arg: arg.meta["val"] if "val" in arg.meta else arg.meta["example_value"], ) +======= + example_vals = torch.fx.map_arg(args, lambda arg: arg.meta["val"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) replacement = trace_fn(replacement_fn, example_vals) if len(self.nodes) == 1: for n in replacement.graph.nodes: @@ -548,7 +569,11 @@ def __init__( fns = [fns] if callable(fns) or isinstance(fns, str) else list(fns) for fn in fns: if isinstance(fn, torch._ops.OpOverloadPacket): +<<<<<<< HEAD fns.extend(getattr(fn, overload) for overload in fn.overloads()) # noqa: B909 +======= + fns.extend(getattr(fn, overload) for overload in fn.overloads()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.fns = fns self.fns_set = OrderedSet(fns) @@ -1177,6 +1202,7 @@ def run_node(self, node: torch.fx.Node) -> Any: raise NotImplementedError( f"NYI: replacement_graph.{target} is not a graph module. Got {sub_gm}." ) +<<<<<<< HEAD assert graph.owning_module is not None graph_name = None for n, mod in graph.owning_module.named_modules(): @@ -1189,6 +1215,14 @@ def run_node(self, node: torch.fx.Node) -> Any: graph.owning_module, target ) graph.owning_module.register_module(graph_name, sub_gm) +======= + + assert graph.owning_module is not None + _, graph_name = unique_graph_name_with_root( + graph.owning_module, str(target) + ) + graph.owning_module.register_module(graph_name, sub_gm) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return graph.get_attr(graph_name) raise NotImplementedError(f"unhandled {node}") @@ -1306,9 +1340,13 @@ def replace( for user in old_uses: idx = maybe_getitem(user) if idx is None: +<<<<<<< HEAD raise AssertionError( "Deleted index from getitem, did you erase the index and not properly replace it?" ) +======= + raise AssertionError("can't handle") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) replace(user, new[idx]) graph.erase_node(old) @@ -1434,9 +1472,13 @@ def check_fn(match: Match) -> bool: ) sym_args: list[torch.SymInt] = [] +<<<<<<< HEAD fake_mode = torch._dynamo.utils.detect_fake_mode(args) assert fake_mode is not None with fake_mode: +======= + with torch._dynamo.utils.detect_fake_mode(args): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for i, grad in enumerate(requires_grad): if isinstance(args[i], torch.Tensor): if grad and is_integer_dtype(args[i].dtype): @@ -1978,12 +2020,20 @@ def apply(self, gm: Union[torch.fx.GraphModule, torch.fx.Graph]) -> int: continue if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name: log.warning("%s%s %s %s", node, node.args, m, entry.pattern) +<<<<<<< HEAD if is_match(m) and guard_or_false(entry.extra_check(m)): count += 1 entry.apply(m, graph, node) counters[backend]["pattern_matcher_count"] += 1 counters[backend]["pattern_matcher_nodes"] += len(m.nodes) +======= + if is_match(m) and entry.extra_check(m): + count += 1 + entry.apply(m, graph, node) + counters["inductor"]["pattern_matcher_count"] += 1 + counters["inductor"]["pattern_matcher_nodes"] += len(m.nodes) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return count def clear(self) -> None: @@ -2108,7 +2158,11 @@ def fwd_only( ) -> torch.fx.GraphModule: """Build a normalized inference graph, for use with fx_to_pattern""" # TODO - look into using aot autograd, asserting no mutating ops here +<<<<<<< HEAD with enable_python_dispatcher(), preserve_node_meta(): +======= + with enable_python_dispatcher(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) decompositions = ( get_decomp_fn() if get_decomp_fn is not None else select_decomp_table() ) @@ -2219,13 +2273,21 @@ def init_once_fakemode(fn: Callable[..., Any]) -> Callable[[], Any]: @functools.cache @functools.wraps(fn) def lazy_init() -> Any: +<<<<<<< HEAD counters_ref = counters[backend].copy() +======= + counters_ref = counters["inductor"].copy() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with torch._guards.tracing(None), unset_fake_temporarily(), FakeTensorMode(): result = fn() # clear view matches encountered during tracing +<<<<<<< HEAD counters[backend] = counters_ref +======= + counters["inductor"] = counters_ref +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return result diff --git a/torch/_inductor/remote_cache.py b/torch/_inductor/remote_cache.py index 1304ce79b86ed..ff55b19d756c6 100644 --- a/torch/_inductor/remote_cache.py +++ b/torch/_inductor/remote_cache.py @@ -170,6 +170,7 @@ def get(self, key: str) -> Optional[_T]: try: result = self._get(key, sample) cache_stats.get(type(self).__name__, result) +<<<<<<< HEAD except Exception as e: cache_stats.exception(type(self).__name__) if sample: @@ -177,6 +178,12 @@ def get(self, key: str) -> Optional[_T]: raise finally: self._log_sample(sample) +======= + except Exception: + cache_stats.exception(type(self).__name__) + raise + self._log_sample(sample) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return result # Add `value` to the cache with the key `key`. Note that `None` is not a @@ -189,6 +196,7 @@ def put(self, key: str, value: _T) -> None: try: self._put(key, value, sample) cache_stats.put(type(self).__name__) +<<<<<<< HEAD except Exception as e: cache_stats.exception(type(self).__name__) if sample: @@ -196,6 +204,12 @@ def put(self, key: str, value: _T) -> None: raise finally: self._log_sample(sample) +======= + except Exception: + cache_stats.exception(type(self).__name__) + raise + self._log_sample(sample) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Used to convert data from the cache into structured data. def _decode(self, data: _U, sample: Optional[Sample]) -> _T: # type: ignore[override] diff --git a/torch/_inductor/runtime/autotune_cache.py b/torch/_inductor/runtime/autotune_cache.py index 88b9c80c77146..7a580ced1ae48 100644 --- a/torch/_inductor/runtime/autotune_cache.py +++ b/torch/_inductor/runtime/autotune_cache.py @@ -35,7 +35,10 @@ from typing_extensions import override import torch +<<<<<<< HEAD from torch._dynamo.precompile_context import PrecompileContext +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.runtime.runtime_utils import cache_dir from torch.compiler._cache import ( CacheArtifact, @@ -126,7 +129,10 @@ def create( ) -> Optional[AutotuneCache]: cache = AutotuneCache(configs_hash) key = AutotuneCache._prepare_key(filename) +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cache._setup_local_cache(inductor_meta, os.path.dirname(filename), key) cache._setup_remote_autotune_cache(inductor_meta, key) if cache.local_cache or cache.remote_cache: @@ -302,10 +308,13 @@ def save( CacheArtifactManager.record_artifact( AutotuneCacheArtifact.type(), autotune_artifact_key, data ) +<<<<<<< HEAD if torch._dynamo.config.caching_precompile: PrecompileContext.record_artifact( AutotuneCacheArtifact.type(), autotune_artifact_key, data ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if log.isEnabledFor(logging.DEBUG): type_str = "coordesc" if found_by_coordesc else "heuristic" @@ -631,10 +640,13 @@ def _get(self, key: str, sample: Optional[Sample]) -> Optional[JsonDataTy]: CacheArtifactManager.record_artifact( AutotuneCacheArtifact.type(), autotune_artifact_key, result ) +<<<<<<< HEAD if torch._dynamo.config.caching_precompile: PrecompileContext.record_artifact( AutotuneCacheArtifact.type(), autotune_artifact_key, result ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return result @override diff --git a/torch/_inductor/runtime/benchmarking.py b/torch/_inductor/runtime/benchmarking.py index 95b1ba64d1580..534112eda5513 100644 --- a/torch/_inductor/runtime/benchmarking.py +++ b/torch/_inductor/runtime/benchmarking.py @@ -3,7 +3,11 @@ from functools import cached_property, wraps from itertools import chain from statistics import median +<<<<<<< HEAD from typing import Any, Callable, Optional, Union +======= +from typing import Any, Callable +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from typing_extensions import Concatenate, ParamSpec, Self, TypeVar import torch @@ -173,7 +177,11 @@ def benchmark_gpu(self: Self, _callable: Callable[[], Any], **kwargs: Any) -> fl return self.triton_do_bench(_callable, **kwargs, return_mode="median") +<<<<<<< HEAD class InductorBenchmarker(TritonBenchmarker): # noqa: docstring_linter +======= +class InductorBenchmarker(TritonBenchmarker): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @cached_property def L2_cache_size(self: Self) -> int: """Get the L2 cache size, in bytes, of the current device.""" @@ -205,17 +213,26 @@ def get_event_pairs_min_timing( ) @time_and_count +<<<<<<< HEAD def benchmark_gpu( # type: ignore[override] +======= + def benchmark_gpu( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self: Self, _callable: Callable[[], Any], estimation_iters: int = 5, memory_warmup_iters: int = 100, benchmark_iters: int = 100, max_benchmark_duration: int = 25, +<<<<<<< HEAD return_mode: str = "min", grad_to_none: Optional[list[torch.Tensor]] = None, **kwargs: Any, ) -> Union[float, list[float]]: +======= + **kwargs: Any, + ) -> float: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Benchmark a GPU callable using a custom benchmarking implementation. Arguments: @@ -233,6 +250,7 @@ def benchmark_gpu( # type: ignore[override] of `memory_warmup_iters` and `benchmark_iters`, along with the estimated runtime of `_callable` and various other factors, and we then shrink `benchmark_iters` to fit in the allotted maximum duration. +<<<<<<< HEAD - return_mode: Return mode for benchmark results. Options are "min" (default), "all" (returns all measurements). - grad_to_none: Optionally, a list of tensors whose gradients should be cleared @@ -242,6 +260,12 @@ def benchmark_gpu( # type: ignore[override] Returns: - If return_mode="min": The minimum runtime of `_callable`, in milliseconds. - If return_mode="all": List of all runtime measurements, in milliseconds. +======= + - **kwargs: Additional kwargs that may be passed to the fallback. + + Returns: + - The minimum runtime of `_callable`, in milliseconds. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ # we don't want any outside errors propagating into benchmarking torch.cuda.synchronize() @@ -257,10 +281,13 @@ def benchmark_gpu( # type: ignore[override] # estimate the runtime of `_callable` event_pairs = self.get_event_pairs(estimation_iters) for start_event, end_event in event_pairs: +<<<<<<< HEAD # Clear gradients before timing (matches triton.testing.do_bench) if grad_to_none is not None: for x in grad_to_none: x.grad = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) buffer.zero_() start_event.record() _callable() @@ -280,20 +307,28 @@ def benchmark_gpu( # type: ignore[override] # benchmark `_callable` event_pairs = self.get_event_pairs(benchmark_iters) for start_event, end_event in event_pairs: +<<<<<<< HEAD # Clear gradients before timing (matches triton.testing.do_bench) if grad_to_none is not None: for x in grad_to_none: x.grad = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) buffer.zero_() start_event.record() _callable() end_event.record() torch.cuda.synchronize() +<<<<<<< HEAD +======= + benchmarked_timing = self.get_event_pairs_min_timing(event_pairs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # explicitly delete the buffer, sometimes helps memory # footprint metrics in OSS Inductor performance benchmarks del buffer +<<<<<<< HEAD # Return based on the requested mode if return_mode == "all": # Get all timings from event pairs @@ -311,6 +346,11 @@ def benchmark_gpu( # type: ignore[override] raise ValueError( f"Unsupported return_mode: {return_mode}. Use 'min' or 'all'." ) +======= + # return the minimum of `estimated_timing` and `benchmarked_timing`, + # we just want the minimum timing overall so we might as well check both + return min(estimated_timing, benchmarked_timing) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) benchmarker = ( diff --git a/torch/_inductor/runtime/compile_tasks.py b/torch/_inductor/runtime/compile_tasks.py index 1851e447e1950..2667c99414d1a 100644 --- a/torch/_inductor/runtime/compile_tasks.py +++ b/torch/_inductor/runtime/compile_tasks.py @@ -10,8 +10,11 @@ from types import ModuleType from typing import Any, Callable, TYPE_CHECKING +<<<<<<< HEAD from torch._utils_internal import log_triton_builds +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if TYPE_CHECKING: from torch._inductor.runtime.triton_heuristics import CachingAutotuner @@ -40,7 +43,11 @@ def _reload_python_module( def _set_triton_ptxas_path() -> None: if os.environ.get("TRITON_PTXAS_PATH") is not None: return +<<<<<<< HEAD ptxas = Path(__file__).absolute().parents[2] / "bin" / "ptxas" +======= + ptxas = Path(__file__).absolute().parents[1] / "bin" / "ptxas" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not ptxas.exists(): return if ptxas.is_file() and os.access(ptxas, os.X_OK): @@ -59,6 +66,7 @@ def _worker_compile_triton( from torch._inductor import config with config.patch(extra_config): +<<<<<<< HEAD fail = None try: start_ns = time.time_ns() @@ -74,3 +82,13 @@ def _worker_compile_triton( raise finally: log_triton_builds(fail=fail) +======= + start_ns = time.time_ns() + kernel = load_kernel() + kernel.precompile(warm_cache_only=True) + elapsed_ns = time.time_ns() - start_ns + kernel.prepare_for_pickle() + # We can release this memory in the compile subprocesses: + linecache.clearcache() + return kernel, elapsed_ns // 1000 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_inductor/runtime/coordinate_descent_tuner.py b/torch/_inductor/runtime/coordinate_descent_tuner.py index 26b3bcf5cc5cf..342d11fc2c82e 100644 --- a/torch/_inductor/runtime/coordinate_descent_tuner.py +++ b/torch/_inductor/runtime/coordinate_descent_tuner.py @@ -248,10 +248,14 @@ def autotune( log.debug("= Do coordinate descent tuning for %s =", self.name) log.debug( +<<<<<<< HEAD "%s: Baseline Config %s, baseline timing %f", self.name, baseline_config, baseline_timing, +======= + "Baseline Config %s, baseline timing %f", baseline_config, baseline_timing +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) improved = True best_config = baseline_config @@ -293,17 +297,28 @@ def autotune( if improved: msg = red_text( +<<<<<<< HEAD "%s: Coordinate descend tuning found improvement of %.3fx by looking in all directions." ) log.debug( msg, self.name, +======= + "Coordinate descend tuning found improvement of %.3fx by looking in all directions." + ) + log.debug( + msg, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) old_best_timing / best_timing, ) log.debug( +<<<<<<< HEAD "%s: Improve from %s %f -> %s %f, %.3fx", self.name, +======= + "Improve from %s %f -> %s %f, %.3fx", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) baseline_config, baseline_timing, best_config, diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py index a1a0a792c9b84..fe3371a7bc894 100644 --- a/torch/_inductor/runtime/hints.py +++ b/torch/_inductor/runtime/hints.py @@ -13,7 +13,11 @@ # The following maximums only apply to runtime autotuning, when using FixedTritonConfig one may see larger values # NOTE: if these fail asserts submit a PR to increase them TRITON_MAX_BLOCK = { +<<<<<<< HEAD "X": 8192, +======= + "X": 4096, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "Y": 1024, "Z": 1024, "R0_": 4096 * 16, # * 16 is multi-kernel only @@ -153,8 +157,14 @@ def create(cls, device) -> DeviceProperties: except AttributeError: if device_type == "xpu": multi_processor_count = props.gpu_subslice_count +<<<<<<< HEAD elif device_type == "mtia": multi_processor_count = 64 +======= + elif device_type == "mps": + # TODO: Fetch the actual value from ioreg + multi_processor_count = 8 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: raise return cls( diff --git a/torch/_inductor/runtime/static_cuda_launcher.py b/torch/_inductor/runtime/static_cuda_launcher.py index bfea6fc119d96..a795158957672 100644 --- a/torch/_inductor/runtime/static_cuda_launcher.py +++ b/torch/_inductor/runtime/static_cuda_launcher.py @@ -54,6 +54,7 @@ def __init__(self, kernel: CompiledKernel) -> None: launch_enter = triton_knobs.runtime.launch_enter_hook launch_exit = triton_knobs.runtime.launch_exit_hook +<<<<<<< HEAD def hook_is_empty(hook: Any) -> bool: if hook is None: return True @@ -67,6 +68,9 @@ def hook_is_empty(hook: Any) -> bool: return False if not hook_is_empty(launch_enter) or not hook_is_empty(launch_exit): +======= + if launch_enter is not None or launch_exit is not None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise NotImplementedError( "We don't support launch enter or launch exit hooks" ) @@ -75,6 +79,7 @@ def hook_is_empty(hook: Any) -> bool: kernel.shared if hasattr(kernel, "shared") else kernel.metadata.shared ) +<<<<<<< HEAD def needs_scratch_arg(scratch_name: str, param_name: str) -> bool: if hasattr(kernel.metadata, param_name): if getattr(kernel.metadata, param_name) > 0: @@ -90,6 +95,18 @@ def needs_scratch_arg(scratch_name: str, param_name: str) -> bool: self.has_global_scratch = needs_scratch_arg("Global", "global_scratch_size") # same situation for profile scratch - triton-lang/triton#7258 self.has_profile_scratch = needs_scratch_arg("Profile", "profile_scratch_size") +======= + # Newer triton versions pass an extra global scratch parameter to the compiled cuda kernel. + # Inductor never uses this field or enables it, but we still have to pass + # an extra None into the set of params if its enabled + if hasattr(kernel.metadata, "global_scratch_size"): + if kernel.metadata.global_scratch_size > 0: + raise NotImplementedError("Global scratch not yet supported") + else: + self.has_global_scratch = True + else: + self.has_global_scratch = False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.arg_tys = self.arg_ty_from_signature(kernel.src) self.function: Optional[int] = ( @@ -231,12 +248,21 @@ def run( # thing, it should always match. # Get rid of constants before passing to cubin launcher +<<<<<<< HEAD # Add a None if triton wants extra parameters for scratch spaces arg_tys = self.arg_tys for has_scratch in [self.has_global_scratch, self.has_profile_scratch]: if has_scratch: arg_tys = arg_tys + "O" args = (*args, None) +======= + # Add a None if triton wants an extra parameter to the cubin + if self.has_global_scratch: + arg_tys = self.arg_tys + "O" + args = (*args, None) + else: + arg_tys = self.arg_tys +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert len(args) == len(arg_tys) # TODO: can handle grid functions here or in C++, so diff --git a/torch/_inductor/runtime/triton_compat.py b/torch/_inductor/runtime/triton_compat.py index 645e0f4c8903d..b0e02e880f0c3 100644 --- a/torch/_inductor/runtime/triton_compat.py +++ b/torch/_inductor/runtime/triton_compat.py @@ -87,6 +87,7 @@ def _triton_config_has(param_name: str) -> bool: except ImportError: knobs = None +<<<<<<< HEAD try: from triton.runtime.cache import triton_key # type: ignore[attr-defined] except ImportError: @@ -98,6 +99,11 @@ def _triton_config_has(param_name: str) -> bool: "_semantic" in inspect.signature(triton.language.core.view).parameters ) HAS_TRITON = True +======= + builtins_use_semantic_kwarg = ( + "_semantic" in inspect.signature(triton.language.core.view).parameters + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: def _raise_error(*args: Any, **kwargs: Any) -> Any: @@ -134,8 +140,11 @@ def constexpr(val: Any) -> Any: dtype = Any HAS_WARP_SPEC = False +<<<<<<< HEAD triton_key = _raise_error HAS_TRITON = False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def cc_warp_size(cc: Union[str, int]) -> int: @@ -172,5 +181,8 @@ class autograd_profiler: # type: ignore[no-redef] "triton", "cc_warp_size", "knobs", +<<<<<<< HEAD "triton_key", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] diff --git a/torch/_inductor/runtime/triton_helpers.py b/torch/_inductor/runtime/triton_helpers.py index e003615b218fd..29ae303fb8cbb 100644 --- a/torch/_inductor/runtime/triton_helpers.py +++ b/torch/_inductor/runtime/triton_helpers.py @@ -2,6 +2,10 @@ # mypy: allow-untyped-defs import math as pymath import warnings +<<<<<<< HEAD +======= +from functools import wraps +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from typing import Any, Callable, TypeVar from .triton_compat import ( # noqa: F401 @@ -79,7 +83,11 @@ def div_floor_integer(a, b): def remainder_integer(a, b): # NOTE: a % b matches C division, not floor division remainder = a % b +<<<<<<< HEAD return tl.where((remainder != 0) & ((a < 0) != (b < 0)), remainder + b, remainder) +======= + return tl.where(remainder != 0 and ((a < 0) != (b < 0)), remainder + b, remainder) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @triton.jit @@ -130,9 +138,15 @@ def minimum_with_index(a_value, a_index, b_value, b_index): if is_floating(a_value): a_isnan = a_value != a_value b_isnan = b_value != b_value +<<<<<<< HEAD mask |= a_isnan & (not b_isnan) # Consider NaNs as equal equal |= a_isnan & b_isnan +======= + mask |= a_isnan and not b_isnan + # Consider NaNs as equal + equal |= a_isnan and b_isnan +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Prefer lowest index if values are equal mask |= equal & (a_index < b_index) @@ -146,9 +160,15 @@ def maximum_with_index(a_value, a_index, b_value, b_index): if is_floating(a_value): a_isnan = a_value != a_value b_isnan = b_value != b_value +<<<<<<< HEAD mask |= a_isnan & (not b_isnan) # Consider NaNs as equal equal |= a_isnan & b_isnan +======= + mask |= a_isnan and not b_isnan + # Consider NaNs as equal + equal |= a_isnan and b_isnan +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Prefer lowest index if values are equal mask |= equal & (a_index < b_index) @@ -168,15 +188,25 @@ def max_with_index(value, index, dim): @triton.jit def exp(x, use_fast_math: tl.constexpr): if use_fast_math: +<<<<<<< HEAD return math.exp(x) else: return libdevice.exp(x) +======= + return libdevice.exp2(x * _LOG_2_E) + else: + return math.exp(x) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @triton.jit def online_softmax_reduce(lhs_max, lhs_sum, dim, use_fast_math: tl.constexpr): out_max = max2(lhs_max, dim) +<<<<<<< HEAD out_max_keepdim = tl.expand_dims(out_max, dim) +======= + out_max_keepdim = out_max[:, None] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) delta = tl.where(out_max_keepdim == float("-inf"), 0, lhs_max - out_max_keepdim) out_sum = tl.sum(lhs_sum * exp(delta, use_fast_math), dim) return out_max, out_sum @@ -314,8 +344,13 @@ def bucketize_binary_search( while full_range > 1: mid = (high + low) // 2 mask = ( +<<<<<<< HEAD (mid * BOUNDARIES_STRIDE + boundary_indices) < BOUNDARIES_UNDERLYING_NUMEL ).logical_and(mid < BOUNDARIES_SIZE) +======= + mid * BOUNDARIES_STRIDE + boundary_indices + ) < BOUNDARIES_UNDERLYING_NUMEL and mid < BOUNDARIES_SIZE +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mid_indices = ( mid if sorter_ptr is None or SORTER_STRIDE is None @@ -564,6 +599,7 @@ def _compare_and_swap_with_index( # actual compare-and-swap ix = x.to(idtype, bitcast=True) +<<<<<<< HEAD # sort treats nan as having the higher value. comparisons with nan always return False. # to align with sort semantics, we need to update descending to check if right_isnan, # and ascending to check if left_isnan. @@ -592,6 +628,16 @@ def _compare_and_swap_with_index( if is_floating(left): eq = eq | (left_isnan & right_isnan) cond = cond | (eq & (left_idx > right_idx)) +======= + if descending: + cond = left < right + else: + cond = left > right + + if stable: + # When stable sorting, tie break by index + cond = cond | ((left == right) & (left_idx > right_idx)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cond = (right_valid_mask > left_valid_mask) | ( (right_valid_mask == left_valid_mask) & cond @@ -722,9 +768,16 @@ def triton_builtin(f: Callable[..., _T]) -> Callable[..., _T]: """ if builtins_use_semantic_kwarg: # support Triton before and after https://github.com/triton-lang/triton/pull/7054 +<<<<<<< HEAD # and after https://github.com/triton-lang/triton/pull/7239 def wrapper(*args, _semantic, **kwargs): kwargs["_builder"] = _semantic +======= + @wraps(f) + def wrapper(*args, **kwargs): + kwargs["_builder"] = kwargs["_semantic"] + del kwargs["_semantic"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f(*args, **kwargs) else: wrapper = f # type: ignore[assignment] diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 1de1f9a595c9e..7512e94a7211e 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -31,7 +31,10 @@ import torch from torch._dynamo.utils import set_feature_use +<<<<<<< HEAD from torch._environment import is_fbcode +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._prims_common import compute_required_storage_length from torch.utils._ordered_set import OrderedSet @@ -53,6 +56,10 @@ ) from .runtime_utils import ( ceildiv, +<<<<<<< HEAD +======= + compilation_callback, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) conditional_product, create_bandwidth_info_str, dynamo_timed, @@ -82,6 +89,7 @@ ) +<<<<<<< HEAD class InductorConfig(Config): """Inductor-specific Triton config with additional control flags""" @@ -90,6 +98,8 @@ def __init__(self, *args, dynamic_scale_rblock=True, **kwargs): self.dynamic_scale_rblock = dynamic_scale_rblock +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class NoTritonConfigsError(RuntimeError): pass @@ -106,6 +116,7 @@ class NoTritonConfigsError(RuntimeError): log = logging.getLogger(__name__) +<<<<<<< HEAD triton_name_sub = re.compile(r"^def [^(]+\(") @@ -134,6 +145,8 @@ def lookup_autotune_config(size_hints, fn) -> Optional[Config]: return cached_config +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_total_reduction_numel(numels: dict[str, int]) -> int: return conditional_product( @@ -205,7 +218,12 @@ def _dump_launch_params(args, kwargs, launcher, kernel_name, grid): call_kwargs[k] = v else: call_kwargs[k] = v +<<<<<<< HEAD call_kwargs.update(launcher.config.kwargs) +======= + if not triton_version_uses_attrs_dict(): + call_kwargs.update(launcher.config.kwargs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) call_kwargs["num_warps"] = launcher.config.num_warps call_kwargs["num_stages"] = launcher.config.num_stages if HAS_WARP_SPEC: @@ -317,9 +335,13 @@ def __init__( [] if reset_to_zero_arg_names is None else reset_to_zero_arg_names ) self.optimize_mem = optimize_mem +<<<<<<< HEAD cached_config = lookup_autotune_config(size_hints, fn) self.configs = [cached_config] if cached_config else configs +======= + self.configs = configs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.heuristic_type = heuristic_type self.custom_kernel = custom_kernel self.cuda_kernel_saved = False @@ -374,9 +396,12 @@ def __init__( self.compile_id: Optional[CompileId] = None self.is_backward = False +<<<<<<< HEAD # Mode for launch grid calculation self.grid_mode: Literal["python", "python_slow", "cpp"] = "python" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def is_statically_launchable(self): """ Checks if every compiled kernel is statically launchable, which @@ -585,7 +610,11 @@ def _dynamic_scale_rblock(self): assert hasattr(self, "_reload_kernel") assert callable(self._reload_kernel) self.fn = self._reload_kernel().fn +<<<<<<< HEAD self.compile_results.append(self._precompile_config(new_config)) # noqa: B909 +======= + self.compile_results.append(self._precompile_config(new_config)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._make_launchers() @@ -600,6 +629,7 @@ def _make_launchers(self): # load binary to the correct device with DeviceGuard(device_interface, self.triton_meta["device"]): # need to initialize context +<<<<<<< HEAD with dynamo_timed( "CachingAutotuner.synchronize", # Deliberately avoid overloading pt2_compile_events: @@ -607,6 +637,9 @@ def _make_launchers(self): ): device_interface.synchronize(device_interface.current_device()) +======= + device_interface.synchronize(device_interface.current_device()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) launchers = [] exc = None for result in self.compile_results: @@ -619,7 +652,11 @@ def _make_launchers(self): raise RuntimeError(f"No valid triton configs. {type(exc).__name__}: {exc}") self.launchers = launchers +<<<<<<< HEAD def prepare_for_pickle(self) -> tuple[Any, Any, Any, Any, Any, Any]: +======= + def prepare_for_pickle(self) -> tuple[Any, Any, Any, Any, Any]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Drop stuff from triton.JITFunction that does not pickle. This must be called after precompile so that these things are no longer needed. Returns a tuple of old values @@ -630,13 +667,17 @@ def prepare_for_pickle(self) -> tuple[Any, Any, Any, Any, Any, Any]: self.fn.used_global_vals, self.fn.repr, self.launchers, +<<<<<<< HEAD getattr(self.fn, "_hash_lock", None), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) self.fn.fn = None self.fn.__globals__ = None self.fn.used_global_vals = None self.fn.repr = _ConstRepr(self.fn.repr(self.fn)) self.launchers = [] +<<<<<<< HEAD self.fn._hash_lock = None return old_values @@ -657,6 +698,10 @@ def restore_after_unpickle( # _hash_lock to be a valid RLock self.fn._hash_lock = threading.RLock() +======= + return old_values + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def prepare_for_caching(self) -> None: """ Statically Launched CUDA Kernels have a raw cubin on them @@ -784,6 +829,7 @@ def _precompile_config(self, cfg: Config) -> CompileResult[_KernelType]: compile_meta, ) raise +<<<<<<< HEAD # Simulate JIT Hook call if ( @@ -814,6 +860,8 @@ def _precompile_config(self, cfg: Config) -> CompileResult[_KernelType]: except Exception: log.exception("jit_post_compile_hook failed") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TritonBundler.put( triton_hash_to_path_key(binary.hash), self.triton_meta.get("device", 0) ) @@ -830,6 +878,31 @@ def _precompile_config(self, cfg: Config) -> CompileResult[_KernelType]: return TritonCompileResult(binary, cfg, compile_meta, self.inductor_meta) +<<<<<<< HEAD +======= + def _get_args_with_constexprs(self, args, launcher): + """ + `args` is passed in with only the non-constexpr args (because the constexpr arg values + depend on the config). However, in later triton versions, the constexpr args need to be + added into the args list. + """ + if triton_version_uses_attrs_dict(): + # first: aggregate the constexpr args in (index, val) pairs + # so we can sort them by index. + constexpr_args: list[tuple[int, Any]] = [] + for arg_name, arg_val in launcher.config.kwargs.items(): + if arg_name in self.fn.arg_names: + constexpr_args.append((self.fn.arg_names.index(arg_name), arg_val)) + + constexpr_args.sort() + new_args = [*args] + for arg_idx, arg_val in constexpr_args: + new_args.insert(arg_idx, arg_val) + + return new_args + return args + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def bench(self, launcher, *args, with_profiler=False, **kwargs): """Measure the performance of a given launcher""" # we don't skip configs with spilled registers when auto-tuning custom @@ -838,7 +911,11 @@ def bench(self, launcher, *args, with_profiler=False, **kwargs): # for some (complicated) custom Triton kernels, a register-spilling # config may yield the best latency. if not self.custom_kernel and launcher.n_spills > self.inductor_meta.get( +<<<<<<< HEAD "spill_threshold", 32 if torch.version.hip else 16 +======= + "spill_threshold", 16 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): log.debug( "Skip config %s because of register spilling: %d", @@ -858,6 +935,7 @@ def kernel_call(): ) # reset to zero before evaluating any config self.reset_to_zero_args(*args, **kwargs) +<<<<<<< HEAD if autograd_profiler._is_profiler_enabled: profiler_kwargs = self.get_profiler_kwargs(stream, launcher) with torch._C._profiler._RecordFunctionFast( @@ -881,6 +959,17 @@ def kernel_call(): # only use profiler when not already in a profiler instance if with_profiler and not autograd_profiler._is_profiler_enabled: +======= + args_with_constexprs = self._get_args_with_constexprs(cloned_args, launcher) + launcher( + *args_with_constexprs, + **cloned_kwargs, + stream=stream, + ) + self.restore_args_from_cpu(cpu_copies) + + if with_profiler: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.utils import do_bench_using_profiling return do_bench_using_profiling(kernel_call, warmup=10, rep=40) @@ -1013,11 +1102,18 @@ def benchmark_all_configs(self, *args, **kwargs): log_waitcounter=True, waitcounter_name_override="triton_autotuner", ), +<<<<<<< HEAD # Temporarily disable due to spam # compilation_callback.callback_handler.install_callbacks( # compilation_callback.CallbackTrigger.TRITON_AUTOTUNING, # str(self.compile_id), # ), +======= + compilation_callback.callback_handler.install_callbacks( + compilation_callback.CallbackTrigger.TRITON_AUTOTUNING, + str(self.compile_id), + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): timings = { launcher: self.bench(launcher, *args, **kwargs) @@ -1098,6 +1194,7 @@ def save_gpu_kernel(self, stream, launcher): "def_args": launcher.def_args, "call_args": launcher.call_args, "global_scratch": launcher.global_scratch, +<<<<<<< HEAD "profile_scratch": launcher.profile_scratch, } if self.device_props.type == "xpu": @@ -1109,6 +1206,9 @@ def save_gpu_kernel(self, stream, launcher): launcher.bin.metadata, "threads_per_warp", 32 ) +======= + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.codecache import CudaKernelParamCache bin_type = {"hip": "hsaco", "xpu": "spv"}.get(self.device_props.type, "cubin") @@ -1140,6 +1240,7 @@ def coordinate_descent_tuning(self, launcher, *args, **kwargs): # skip triton template return launcher +<<<<<<< HEAD with dynamo_timed( "CachingAutotuner.coordinate_descent_tuning", # These generate too many pt2_compile_event logs: @@ -1154,6 +1255,8 @@ def coordinate_descent_tuning(self, launcher, *args, **kwargs): return self._coordinate_descent_tuning(launcher, *args, **kwargs) def _coordinate_descent_tuning(self, launcher, *args, **kwargs): +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) config2launcher = {launcher.config: launcher} # TODO: should we just load the kernels ahead of time if we know we're going to call this? @@ -1210,6 +1313,7 @@ def benchmark_one_config(config): config2launcher[best_config] = self._precompile_config( best_config ).make_launcher() +<<<<<<< HEAD fn_hash = generate_lookup_hash_from_source_code( str(self.size_hints), self.fn.src @@ -1239,6 +1343,10 @@ def get_profiler_kwargs(self, stream, launcher): ret["kernel_num_gb"] = self.inductor_meta["kernel_num_gb"] return ret +======= + return config2launcher[best_config] + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def run( self, *args, @@ -1282,10 +1390,14 @@ def alloc_fn(size: int, align: int, stream: Optional[int]): if launcher.store_cubin and (not benchmark_run or not self.cuda_kernel_saved): self.save_gpu_kernel(stream, launcher) +<<<<<<< HEAD # PyTorch execution trace replay calls CachingAutotuner::run() instead of calls launcher # so _RecordFunctionFast need to capture the args into CachingAutotuner::run() # make a copy here to avoid mutating the original args args_without_constexprs = tuple(args) +======= + args = self._get_args_with_constexprs(args, launcher) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.dump_launch_params: new_args, grid = self._interpret_args_grid(args, launcher.config) @@ -1294,11 +1406,31 @@ def alloc_fn(size: int, align: int, stream: Optional[int]): # it is faster than entering and exiting a context manager, even if the context # manager is a nullcontext. if autograd_profiler._is_profiler_enabled: +<<<<<<< HEAD profiler_kwargs = self.get_profiler_kwargs(stream, launcher) with torch._C._profiler._RecordFunctionFast( self.inductor_meta.get("kernel_name", "triton kernel"), args_without_constexprs, +======= + kernel_kwargs_str = ",".join( + f"{k}={v}" for (k, v) in launcher.config.kwargs.items() + ) + + profiler_kwargs = { + "kernel_file": (self.filename or ""), + "kernel_hash": self.kernel_hash, + "kernel_backend": "triton", + "stream": stream, + "num_warps": launcher.config.num_warps, + "num_stages": launcher.config.num_stages, + "kernel_kwargs": kernel_kwargs_str, + } + + with torch._C._profiler._RecordFunctionFast( + self.inductor_meta.get("kernel_name", "triton kernel"), + args, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) profiler_kwargs, ): return launcher( @@ -1316,6 +1448,7 @@ def alloc_fn(size: int, align: int, stream: Optional[int]): def _interpret_args_grid( self, args: tuple[Any, ...], cfg: Config ) -> tuple[tuple[Any, ...], tuple[int, int, int]]: +<<<<<<< HEAD if triton_version_uses_attrs_dict(): def filtered_signature() -> list[str]: @@ -1337,6 +1470,13 @@ def filtered_signature() -> list[str]: zip( [ *filtered_signature(), +======= + grid = GridExpr.from_meta(self.inductor_meta, cfg).eval_slow( + dict( + zip( + [ + *self.triton_meta["signature"].keys(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) *self.inductor_meta.get("extra_launcher_args", ()), ], args, @@ -1357,10 +1497,13 @@ def __call__(self, _=None) -> str: class CompileResult(Generic[_T]): +<<<<<<< HEAD """ Base class representing compiled result. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __init__( self, kernel: _T, @@ -1424,6 +1567,7 @@ def _get_arg_lists( ) none_args = none_args.difference(OrderedSet(compile_meta["signature"].keys())) +<<<<<<< HEAD def _convert_constant(constant): if isinstance(constant, str): return "r'" + constant + "'" @@ -1448,6 +1592,23 @@ def _convert_constant(constant): repl = { k: _convert_constant(compile_meta["constants"].get(k)) for k in implicit_constants +======= + if triton_version_uses_attrs_dict(): + call_args = arg_names + def_args = arg_names + if ( + "num_warps" in compile_meta["constants"] + or "num_stages" in compile_meta["constants"] + ): + # num_warps/num_stages are special implicit args that are not in the signature + # see test_triton_kernel_special_params + def_args = [ + arg for arg in def_args if arg not in ("num_warps", "num_stages") + ] + repl = { + k: str(compile_meta["constants"].get(k)) + for k in ("num_warps", "num_stages") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } call_args = [repl.get(arg, arg) for arg in call_args] else: @@ -1727,8 +1888,11 @@ def make_launcher(self) -> LauncherType: import math as math_lib +<<<<<<< HEAD import triton as triton_lib +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch as torch_lib scope = { @@ -1763,7 +1927,10 @@ def make_launcher(self) -> LauncherType: "runner": get_first_attr(binary, "run", "c_wrapper"), "math": math_lib, "torch": torch_lib, +<<<<<<< HEAD "triton": triton_lib, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } if not hasattr(binary, "launch_metadata"): @@ -1832,6 +1999,7 @@ def make_launcher(self) -> LauncherType: launcher.def_args = def_args launcher.call_args = call_args kernel_metadata = getattr(self.kernel, "metadata", None) +<<<<<<< HEAD # for the scratch arguments: None indicates that the kernel doesn't # take any scratch argument; otherwise a number indicates the number @@ -1849,6 +2017,11 @@ def make_launcher(self) -> LauncherType: ) launcher.global_scratch = global_scratch launcher.profile_scratch = profile_scratch +======= + launcher.global_scratch = getattr( + kernel_metadata, "global_scratch_size", None + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return launcher @@ -2163,9 +2336,12 @@ def triton_config( num_stages=1, num_elements_per_warp=256, min_elem_per_thread=0, +<<<<<<< HEAD num_warps=None, matrix_instr=None, waves_per_eu=None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Config: """ Construct a pointwise triton config with some adjustment heuristics @@ -2222,11 +2398,17 @@ def triton_config( ): z *= 2 +<<<<<<< HEAD # Calculate num_warps if they are not hard passed to config if num_warps is None: num_warps = _num_warps( conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1 ) +======= + num_warps = _num_warps( + conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1 + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # we are going to arrive at 2 warps only if bs was too small due to # numel being too small. However to workaround some ptx bugs we still # want at least 4 warps if there's enough elements per thread @@ -2256,6 +2438,7 @@ def triton_config( cfg["ZBLOCK"] = z check_max_block(cfg) check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel) +<<<<<<< HEAD config = Config(cfg, num_warps=num_warps, num_stages=num_stages) if torch.version.hip: @@ -2265,6 +2448,9 @@ def triton_config( config.kwargs["waves_per_eu"] = waves_per_eu return config +======= + return Config(cfg, num_warps=num_warps, num_stages=num_stages) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, int]: @@ -2312,8 +2498,11 @@ def triton_config_reduction( num_stages=1, num_warps=None, register_intensive=False, +<<<<<<< HEAD waves_per_eu=None, dynamic_scale_rblock=True, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Config: """ Construct a reduction triton config with some adjustment heuristics @@ -2357,6 +2546,7 @@ def total_numel() -> int: cfg = _get_config({"x": x, **rnumels}) check_max_block(cfg) check_config(cfg, xnumel=size_hints["x"]) +<<<<<<< HEAD config = InductorConfig( cfg, num_warps=num_warps, @@ -2369,6 +2559,9 @@ def total_numel() -> int: config.kwargs["waves_per_eu"] = waves_per_eu return config +======= + return Config(cfg, num_warps=num_warps, num_stages=num_stages) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _get_config(numels: dict[str, int]) -> dict[str, int]: @@ -2380,7 +2573,11 @@ def _get_config(numels: dict[str, int]) -> dict[str, int]: def triton_config_tiled_reduction( +<<<<<<< HEAD size_hints, x, y, r, num_stages=1, register_intensive=False, waves_per_eu=None +======= + size_hints, x, y, r, num_stages=1, register_intensive=False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): """ Construct a tile reduction triton config with some adjustment @@ -2417,6 +2614,7 @@ def total_numel() -> int: ) check_config(cfg, xnumel=size_hints["x"], ynumel=size_hints["y"]) check_max_block(cfg) +<<<<<<< HEAD config = Config(cfg, num_warps=num_warps, num_stages=num_stages) if torch.version.hip: if waves_per_eu is not None: @@ -2467,6 +2665,9 @@ def _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs: list[Conf ) return new_configs return configs +======= + return Config(cfg, num_warps=num_warps, num_stages=num_stages) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def pointwise( @@ -2510,6 +2711,7 @@ def pointwise( triton_config_with_settings( size_hints, bs // 2, num_elements_per_warp=64 ), +<<<<<<< HEAD triton_config_with_settings( size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2 ), @@ -2542,6 +2744,13 @@ def pointwise( if ( disable_pointwise_autotuning(inductor_meta) or (torch.version.hip is None and tile_hint == TileHint.SQUARE) +======= + *hinted_configs, + ] + if len(size_hints) == 2: + if ( + disable_pointwise_autotuning(inductor_meta) or tile_hint == TileHint.SQUARE +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) and not ( inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise") @@ -2550,6 +2759,7 @@ def pointwise( else: configs = [ triton_config_with_settings(size_hints, 32, 32), +<<<<<<< HEAD triton_config_with_settings( size_hints, 64, 32 ), # better for some kernels @@ -2563,10 +2773,16 @@ def pointwise( triton_config_with_settings( size_hints, 32, 512 ), # +30% for some kernels +======= + triton_config_with_settings(size_hints, 64, 64), # ~8% better for fp16 + triton_config_with_settings(size_hints, 256, 16), + triton_config_with_settings(size_hints, 16, 256), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) triton_config_with_settings(size_hints, bs, 1), triton_config_with_settings(size_hints, 1, bs), *hinted_configs, ] +<<<<<<< HEAD if torch.version.hip: configs += [ # add here ] @@ -2580,6 +2796,8 @@ def pointwise( Config({"XBLOCK":512, "YBLOCK": 64}, num_warps=8), # wri0: 58us: triton_poi_fused_clone_53 ] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if len(size_hints) == 3: if disable_pointwise_autotuning(inductor_meta): configs = [triton_config_with_settings(size_hints, 16, 16, 16)] @@ -2597,9 +2815,12 @@ def pointwise( if not configs: raise NotImplementedError(f"size_hints: {size_hints}") +<<<<<<< HEAD configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return cached_autotune( size_hints, configs, @@ -2611,13 +2832,18 @@ def pointwise( def _reduction_configs( +<<<<<<< HEAD *, size_hints: dict[str, int], inductor_meta: dict[str, Any], num_dynamic=0 +======= + *, size_hints: dict[str, int], inductor_meta: dict[str, Any] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> list[Config]: reduction_hint = inductor_meta.get("reduction_hint", None) # Convert reductions to 1D, to simplify heuristics. rnumel = get_total_reduction_numel(size_hints) +<<<<<<< HEAD # Is max autotune enabled max_autotune_enabled = inductor_meta.get("max_autotune") or inductor_meta.get( "max_autotune_pointwise" @@ -2629,6 +2855,15 @@ def _reduction_configs( "num_reduction", 0 ) if size_hints["x"] >= 1024 and loads_and_red >= 10: +======= + register_intensive = False + MAX_R0_BLOCK = 2048 + if ( + size_hints["x"] >= 1024 + and inductor_meta.get("num_load", 0) + inductor_meta.get("num_reduction", 0) + >= 10 + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # A heuristics to reduce R0_BLOCK if a kernel potentially need many registers. # Consider load and reduction since load need move data into registers and # reduction needs an accumulator. @@ -2644,6 +2879,7 @@ def _reduction_configs( MAX_R0_BLOCK = 1024 register_intensive = True +<<<<<<< HEAD def make_config( x, r, @@ -2653,6 +2889,9 @@ def make_config( dynamic_scale_rblock=True, waves_per_eu=None, ): +======= + def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # For 3D case with tiling scores, create an adapted version if "y" in size_hints: assert "tiling_scores" in inductor_meta @@ -2664,7 +2903,10 @@ def make_config( num_warps=num_warps, num_stages=num_stages, register_intensive=register_intensive, +<<<<<<< HEAD waves_per_eu=waves_per_eu, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) else: # For other cases, use the original function @@ -2675,6 +2917,7 @@ def make_config( num_warps=num_warps, num_stages=num_stages, register_intensive=register_intensive, +<<<<<<< HEAD waves_per_eu=waves_per_eu, dynamic_scale_rblock=dynamic_scale_rblock, ) @@ -2727,16 +2970,25 @@ def outer_config_opt(): register_intensive=register_intensive, ) +======= + ) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) contiguous_config = make_config( 1, min(rnumel, MAX_R0_BLOCK), register_intensive=register_intensive, ) +<<<<<<< HEAD +======= + outer_config = make_config(64, 8, register_intensive=register_intensive) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tiny_config = make_config( 2 * (256 // rnumel) if rnumel <= 256 else 1, min(rnumel, MAX_R0_BLOCK), register_intensive=register_intensive, ) +<<<<<<< HEAD outer_config = make_config(64, 8, register_intensive=register_intensive) # TODO (paulzhan): Test heuristic on AMD and internal testing @@ -2795,6 +3047,34 @@ def outer_config_opt(): ) return result_configs +======= + # For 3d tiling, default to more autotuning initially + if "y" in size_hints: + pass + elif inductor_meta.get("max_autotune") or inductor_meta.get( + "max_autotune_pointwise" + ): + pass # skip all these cases + elif reduction_hint == ReductionHint.INNER: + return [contiguous_config] + elif reduction_hint == ReductionHint.OUTER: + return [outer_config] + elif reduction_hint == ReductionHint.OUTER_TINY: + return [tiny_config] + if disable_pointwise_autotuning(inductor_meta): + return [make_config(32, 128)] + return [ + contiguous_config, + outer_config, + tiny_config, + make_config(64, 64), + make_config(8, 512), + # halve the XBLOCK/Rn_BLOCK compared to outer_config + # TODO: this may only be beneficial when each iteration of the reduction + # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72 + make_config(64, 4, num_warps=8), + ] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def match_target_block_product( @@ -2852,7 +3132,10 @@ def adapt_config_for_tiling( num_stages=1, register_intensive=False, persistent_reduction=False, +<<<<<<< HEAD waves_per_eu=None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Config: """ Create an adapted configuration based on tiling scores, @@ -2871,7 +3154,10 @@ def adapt_config_for_tiling( block_sizes["r0_"], num_stages=num_stages, register_intensive=register_intensive, +<<<<<<< HEAD waves_per_eu=waves_per_eu, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @@ -2890,6 +3176,7 @@ def reduction( assert triton_meta is not None +<<<<<<< HEAD num_dynamic = 0 for k in triton_meta["signature"].keys(): if "ks" in k: @@ -2900,6 +3187,9 @@ def reduction( ) configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs) +======= + configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return cached_autotune( size_hints, configs=configs, @@ -2946,7 +3236,10 @@ def cooperative_reduction( config.kwargs["RSPLIT"] = split # TODO(jansel): add more configs in max_autotune +<<<<<<< HEAD configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return cached_autotune( size_hints, configs=configs, @@ -2964,6 +3257,7 @@ def _persistent_reduction_configs( ): xnumel = size_hints["x"] rnumel = get_total_reduction_numel(size_hints) +<<<<<<< HEAD loads_and_stores = inductor_meta.get("num_load", 0) + inductor_meta.get( "num_store", 0 ) @@ -2983,6 +3277,19 @@ def _persistent_reduction_configs( configs = [ triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True) for xblock in xblock_vals +======= + + MAX_PERSISTENT_BLOCK_NUMEL = 4096 + max_autotune_enabled = not disable_pointwise_autotuning(inductor_meta) or ( + inductor_meta.get("max_autotune") + or inductor_meta.get("max_autotune_pointwise") + ) + + if "y" not in size_hints: + configs = [ + triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True) + for xblock in (1, 8, 32, 128) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if xblock == 1 or (rnumel * xblock <= MAX_PERSISTENT_BLOCK_NUMEL and xblock <= xnumel) ] @@ -2990,7 +3297,11 @@ def _persistent_reduction_configs( configs = [] assert "tiling_scores" in inductor_meta x_y_scores = {dim: inductor_meta["tiling_scores"][dim] for dim in ("x", "y")} +<<<<<<< HEAD for target_block_size in xblock_vals: +======= + for target_block_size in (1, 8, 32, 64, 128): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if target_block_size * rnumel > MAX_PERSISTENT_BLOCK_NUMEL: continue @@ -3003,6 +3314,10 @@ def _persistent_reduction_configs( ) ) +<<<<<<< HEAD +======= + # defer to more autotuning, initially +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tiny_configs = [ triton_config_reduction( size_hints, @@ -3011,6 +3326,7 @@ def _persistent_reduction_configs( ) ] +<<<<<<< HEAD # defer to more autotuning, initially if "y" in size_hints: pass @@ -3036,16 +3352,32 @@ def _persistent_reduction_configs( ) ] +======= + if "y" in size_hints: + pass + # TODO(jansel): we should be able to improve these heuristics + elif not max_autotune_enabled: # Don't filter if tuning enabled + if reduction_hint == ReductionHint.INNER and rnumel >= 256: + configs = configs[:1] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif reduction_hint == ReductionHint.OUTER: configs = configs[-1:] elif reduction_hint == ReductionHint.OUTER_TINY: configs = tiny_configs else: +<<<<<<< HEAD # If autotune is enabled append tiny configs for conf in tiny_configs: if conf not in configs: configs.append(conf) +======= + if max_autotune_enabled: + for conf in tiny_configs: + if conf not in configs: + configs.append(conf) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for c in configs: # we don't need Rn_BLOCK for persistent reduction for prefix in size_hints: @@ -3072,6 +3404,7 @@ def persistent_reduction( configs = _persistent_reduction_configs(size_hints, reduction_hint, inductor_meta) +<<<<<<< HEAD # This key is not added to the inductor meta as its clear from the heuristic # choice that it is persistent. Add it and remove it below so that persistent # configs can be filtered appropriately by _maybe_filter_configs_for_tma_restrictions @@ -3080,6 +3413,8 @@ def persistent_reduction( configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs) inductor_meta.pop(persistent_reduction_key) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return cached_autotune( size_hints, configs, @@ -3116,7 +3451,10 @@ def split_scan( if var.startswith("R") and cfg.kwargs[var] < min_rblock: cfg.kwargs[var] = min_rblock +<<<<<<< HEAD configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return cached_autotune( size_hints, configs=configs, @@ -3242,10 +3580,16 @@ def foreach(triton_meta, filename=None, inductor_meta=None): Compile a triton foreach kernel """ configs = [] +<<<<<<< HEAD # Naive autotuning path for num_warps if disable_pointwise_autotuning(inductor_meta) and not ( inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise") +======= + if disable_pointwise_autotuning(inductor_meta) and not ( + inductor_meta.get("max_autotune") or + inductor_meta.get("max_autotune_pointwise") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): configs.append(triton.Config({}, num_stages=1, num_warps=8)) else: @@ -3261,20 +3605,31 @@ def foreach(triton_meta, filename=None, inductor_meta=None): filename=filename, ) +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclasses.dataclass class GridExpr: """Generate code for grid size expressions in launcher""" inductor_meta: dict[str, Any] +<<<<<<< HEAD mode: Literal["python", "cpp", "python_slow"] = "python" +======= + mode: Literal["python", "cpp"] = "python" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) prefix: list[str] = dataclasses.field(default_factory=list) x_grid: Union[str, int] = 1 y_grid: Union[str, int] = 1 z_grid: Union[str, int] = 1 def __post_init__(self) -> None: +<<<<<<< HEAD assert self.mode in ("python", "cpp", "python_slow") +======= + assert self.mode in ("python", "cpp") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def generate(self, meta: dict[str, int]) -> None: raise NotImplementedError @@ -3286,6 +3641,7 @@ def ceildiv( return numel if isinstance(numel, int) and isinstance(block, int): return ceildiv(numel, block) # constant fold +<<<<<<< HEAD # This trick only works in python, where # negative integer division is floored if self.mode == "python": @@ -3295,6 +3651,11 @@ def ceildiv( elif self.mode == "python_slow": return f"(({numel} + {block} - 1) // ({block}))" # For cpp code gen +======= + if self.mode == "python": + return f"-(({numel}) // -({block}))" + # trick above doesn't work in C++ due to rounding differences +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"(({numel} + ({block} - 1)) / ({block}))" def maximum(self, seq: list[Union[int, str]]) -> Union[int, str]: @@ -3302,7 +3663,11 @@ def maximum(self, seq: list[Union[int, str]]) -> Union[int, str]: items = self._constant_fold(max, seq) if len(items) <= 1: return items[0] +<<<<<<< HEAD if self.mode in ("python", "python_slow"): +======= + if self.mode == "python": +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"max({', '.join(map(str, items))})" return functools.reduce(lambda x, y: f"std::max({x}, {y})", items) @@ -3325,7 +3690,11 @@ def _constant_fold( def assign_tmp(self, name: str, expr: Union[str, int]) -> str: # Grid functions are one per kernel, so name collisions are fine +<<<<<<< HEAD if self.mode in ("python", "python_slow"): +======= + if self.mode == "python": +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"{name} = {expr}" if self.mode == "cpp": return f"uint32_t {name} = {expr};" @@ -3335,7 +3704,11 @@ def assign_tmp(self, name: str, expr: Union[str, int]) -> str: def from_meta( inductor_meta: dict[str, Any], cfg: Union[Config, dict[str, int]], +<<<<<<< HEAD mode: Literal["python", "cpp", "python_slow"] = "python", +======= + mode: Literal["python", "cpp"] = "python", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> GridExpr: grid_cls = globals()[inductor_meta["grid_type"]] assert issubclass(grid_cls, GridExpr) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 41dbd9e14ad9b..a545e82896306 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -1,7 +1,10 @@ from __future__ import annotations import collections +<<<<<<< HEAD import contextlib +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import dataclasses import functools import inspect @@ -16,6 +19,7 @@ import typing from collections import Counter, defaultdict from typing import Any, Callable, Generic, Optional, TYPE_CHECKING, TypeVar, Union +<<<<<<< HEAD from typing_extensions import ParamSpec, TypeAlias @@ -25,20 +29,34 @@ import weakref +======= + + +if TYPE_CHECKING: + from collections.abc import Sequence + from types import ModuleType + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import sympy import torch import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +<<<<<<< HEAD import torch.utils._pytree as pytree from torch._dynamo.utils import counters, dynamo_timed from torch._inductor.codecache import LambdaFuture, PyCodeCache from torch._inductor.ir import TritonTemplateCallerBase +======= +from torch._dynamo.utils import counters, dynamo_timed +from torch._inductor.codecache import LambdaFuture, PyCodeCache +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.metrics import get_metric_table, is_metric_table_enabled from torch.fx.experimental.symbolic_shapes import free_symbols from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT from torch.utils._triton import has_triton +<<<<<<< HEAD from . import comms, config, config_comms, dependencies, ir, metrics from .analyze_preserves_zero_mask import can_codegen_without_upcasts from .codegen.common import BackendFeature, get_scheduling_for_device, Kernel @@ -49,6 +67,15 @@ from .dependencies import Dep, MemoryDep, StarDep, WeakDep from .exc import GPUTooOldForTriton, TritonMissing from .fx_utils import count_flops_fx +======= +from . import comms, config, dependencies, ir, metrics +from .analyze_preserves_zero_mask import can_codegen_without_upcasts +from .codegen.common import BackendFeature, get_scheduling_for_device, Kernel +from .comm_analysis import estimate_nccl_collective_runtime +from .dependencies import Dep, MemoryDep, StarDep, WeakDep +from .exc import GPUTooOldForTriton, TritonMissing +from .fx_utils import count_flops_fx, countable_fx +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .ir import ( get_device_type, GraphPartitionSignature, @@ -61,7 +88,10 @@ from .runtime.runtime_utils import green_text, red_text from .sizevars import SimplifyIndexing from .utils import ( +<<<<<<< HEAD _unstable_customized_partition_wrapper, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cache_on_self, cmp, device_need_guard, @@ -76,7 +106,10 @@ is_multi_outputs_template, is_output_of_multi_outputs_template, is_wait, +<<<<<<< HEAD maybe_log_cudagraph_partition, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sympy_product, ) from .virtualized import V @@ -89,6 +122,7 @@ __name__, "compute_dependencies" ) +<<<<<<< HEAD PartitionType: TypeAlias = list["BaseSchedulerNode"] _T = TypeVar("_T") _P = ParamSpec("_P") @@ -114,6 +148,9 @@ def register_should_partition_rule( """ assert isinstance(op, torch._ops.OpOverload) _custom_should_partition_fns[op] = func +======= +PartitionType = list["BaseSchedulerNode"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclasses.dataclass @@ -241,7 +278,10 @@ class BaseSchedulerNode: min_order: int max_order: int mpi_node: MemoryPlanningInfoForNode +<<<<<<< HEAD override_estimated_runtime: Optional[float] = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __init__(self, scheduler: Scheduler) -> None: self.scheduler: Scheduler = scheduler @@ -268,6 +308,7 @@ def _init_from_node(self, node: ir.Operation) -> None: buf.get_name(): buf for buf in self.outputs } +<<<<<<< HEAD # mutation_renames for the current node. Due to potential # more mutations happening later, this can be different # to Scheduler.mutation_renames. Also this dict should be small @@ -275,6 +316,8 @@ def _init_from_node(self, node: ir.Operation) -> None: # node is stored here. self.mutation_renames: dict[str, str] = {} +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __repr__(self) -> str: return f"{type(self).__name__}(name={self.get_name()!r})" @@ -334,6 +377,7 @@ def log_details(self) -> None: def reorder_loops_by_dep_pair( self, self_dep: MemoryDep, other_dep: MemoryDep +<<<<<<< HEAD ) -> bool: return False @@ -344,6 +388,13 @@ def update_mutated_names(self, renames: dict[str, str]) -> None: if name in renames } self.set_read_writes(self.read_writes.rename(self.mutation_renames)) +======= + ) -> None: + return + + def update_mutated_names(self, renames: dict[str, str]) -> None: + self.set_read_writes(self.read_writes.rename(renames)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def add_fake_dep(self, dep: Dep) -> None: self.set_read_writes(self.read_writes.with_read(dep)) @@ -648,15 +699,22 @@ def codegen_originating_info( out_lines.append(op_info_str) if "stack_trace" in o.meta: stack_trace = f"{o.meta['stack_trace']}" +<<<<<<< HEAD stack_trace_last_line = stack_trace.rsplit("|", maxsplit=1)[-1] +======= + stack_trace_last_line = stack_trace.split("|")[-1] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out_lines.append( "#pragma CMT " + stack_trace_last_line.replace("{", "{{") .replace("}", "}}") .replace("\n", "\\") +<<<<<<< HEAD .replace( "\\", "\\\\" ) # For windows safe path, avoid for example \x, \U. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) out_lines.append("#pragma CMT END ORIGIN") out_lines.append("") @@ -844,6 +902,7 @@ def estimate_flops(self) -> int | None: fx_node = self.node.get_origin_node() if fx_node is None: return None +<<<<<<< HEAD flops = count_flops_fx(fx_node) if flops is None: @@ -863,6 +922,21 @@ def get_estimated_runtime(self) -> float: def _get_estimated_runtime(self) -> float: """ Returns estimated op runtime in milliseconds (ms) +======= + if not countable_fx(fx_node): + return None + + flops = count_flops_fx(fx_node) + + resolved_flops = V.graph.sizevars.size_hints((flops,), fallback=0)[0] + counters["inductor"]["flop_count"] += resolved_flops + return resolved_flops + + @cache_on_self + def get_estimated_runtime(self) -> float: + """ + Returns estimated op runtime in nanoseconds (ns) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ buf = self.get_nodes()[0].get_outputs()[0] layout = buf.node.get_output_spec() @@ -874,6 +948,7 @@ def _get_estimated_runtime(self) -> float: if is_collective(self.node): assert isinstance(self.node, ir.IRNode) try: +<<<<<<< HEAD if config_comms.runtime_estimations_use_nccl_lib_estimations: cache_key = get_estimate_runtime_cache_key_from_snode(self) cache = get_estimate_runtime_cache() @@ -889,6 +964,8 @@ def _get_estimated_runtime(self) -> float: cache.set_value(cache_key, value=ms) return ms +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return estimate_nccl_collective_runtime(self.node) except ValueError as e: # We don't know how to estimate runtime for this collective, @@ -907,10 +984,13 @@ def _get_estimated_runtime(self) -> float: # since it doesn't take extra time to get the result after the collective is completed. return 0 +<<<<<<< HEAD ret = maybe_estimate_runtime_benchmark(self) if ret is not None: return ret +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dtype = buf.node.maybe_get_dtype() try: gpu_memory_bandwidth = get_gpu_dram_gbps() @@ -931,9 +1011,13 @@ def _get_estimated_runtime(self) -> float: if flops_est == 0 or flops_est is None: # no flops estimate, so fall back to memory estimate +<<<<<<< HEAD ns = self.get_read_write_buffers_sizes() / gpu_memory_bandwidth ms = ns / 1e6 return ms +======= + return self.get_read_write_buffers_sizes() / gpu_memory_bandwidth +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO(xmfan): find a better heuristic to model FLOPS/latency relationship factor = 1.0 @@ -942,10 +1026,15 @@ def _get_estimated_runtime(self) -> float: compute_time = (factor * flops_est / gpu_flops) * 1e9 transfer_time = counted_bytes / gpu_memory_bandwidth +<<<<<<< HEAD # Return estimated runtime in milliseconds ns = max(compute_time, transfer_time) ms = ns / 1e6 return ms +======= + # Return estimated runtime in nanoseconds + return max(compute_time, transfer_time) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_template_node(self) -> Optional[ir.TemplateBuffer]: return None @@ -970,6 +1059,7 @@ def get_prologue_template_epilogue( return prologue, template_node, epilogue +<<<<<<< HEAD @functools.cache def get_estimate_runtime_cache() -> torch._inductor.codecache.LocalCache: return torch._inductor.codecache.LocalCache() @@ -1041,6 +1131,8 @@ def maybe_estimate_runtime_benchmark(snode: BaseSchedulerNode) -> Optional[float return ms +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class WhyNoFuse: # TODO when we drop support for Python < 3.10, we can use # @dataclass(slots=True) instead of manually specifying __slots__. @@ -1155,11 +1247,14 @@ def __init__(self, scheduler: Scheduler, node: ir.Operation) -> None: class SchedulerNode(BaseSchedulerNode): +<<<<<<< HEAD """ A SchedulerNode is a node for scheduling that encapsulates either a ComputedBuffer or a TemplateBuffer. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _sizes: tuple[Sequence[sympy.Expr], ...] _body: LoopBody @@ -1175,6 +1270,7 @@ def __init__( def _compute_attrs( self, extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None, +<<<<<<< HEAD recompute_sizes_body_func: Optional[Callable[_P, _T]] = None, ) -> None: assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)) @@ -1183,6 +1279,15 @@ def _compute_attrs( recompute_sizes_body_func=recompute_sizes_body_func, ) self._body = body # type: ignore[assignment] +======= + recompute_sizes_body_func: Optional[Callable[..., Any]] = None, + ) -> None: + assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)) + self._sizes, self._body = self.node.simplify_and_reorder( + extra_indexing_constraints=extra_indexing_constraints, + recompute_sizes_body_func=recompute_sizes_body_func, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device = self.node.get_device_or_error() group_fn = self.scheduler.get_backend(device).group_fn @@ -1229,9 +1334,13 @@ def refresh_dependencies( self.set_read_writes( dependencies.extract_read_writes( self._body, *self._sizes, normalize=normalize +<<<<<<< HEAD ) .with_read(fake_deps) .rename(self.mutation_renames) +======= + ).with_read(fake_deps) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) self.pointwise_read_writes.clear_cache(self) @@ -1253,6 +1362,7 @@ def apply_new_loop_order(self, new_order: Sequence[int]) -> None: self.refresh_dependencies(normalize=False, need_clear_tiling_cache=True) +<<<<<<< HEAD def expand_dimension_for_pointwise_node( self, dimension: int, new_range: int ) -> None: @@ -1270,6 +1380,8 @@ def expand_dimension_for_pointwise_node( # Need normalize the prefix name to facilitate finding common dependencies self.refresh_dependencies(normalize=True, need_clear_tiling_cache=True) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def merge_loops(self) -> None: self._body = self._body.merge_loops() self._sizes = self._body.sizes @@ -1284,7 +1396,11 @@ def merge_loops(self) -> None: def reorder_loops_by_dep_pair( self, self_dep: MemoryDep, other_dep: MemoryDep +<<<<<<< HEAD ) -> bool: +======= + ) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) new_order = None self_sizes = self._sizes[0] if len(self_sizes) == self_dep.num_vars == other_dep.num_vars: @@ -1296,13 +1412,19 @@ def reorder_loops_by_dep_pair( "Reorder loops for %s with order %s", self.get_name(), new_order ) self.apply_new_loop_order(new_order) +<<<<<<< HEAD return True +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: loop_ordering_log.debug( "Don't reordering %s because we can not decide the suitable loop order", self.get_name(), ) +<<<<<<< HEAD return False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def debug_str_extra(self) -> str: name = self.get_name() @@ -1368,6 +1490,7 @@ def ranges_from_index_vars( return var_ranges def codegen(self, index_vars: Sequence[Sequence[sympy.Expr]]) -> None: +<<<<<<< HEAD """ Generate code for this node using the provided index variables. @@ -1379,6 +1502,8 @@ def codegen(self, index_vars: Sequence[Sequence[sympy.Expr]]) -> None: index_vars: A sequence of sequences of sympy expressions representing the index variables for each dimension of the computation. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) var_ranges = self.ranges_from_index_vars(index_vars) try: with ( @@ -1448,6 +1573,7 @@ def _get_atomic_add_buffers(self) -> OrderedSet[str]: ) return buffers_store_as_atomic_add +<<<<<<< HEAD @cache_on_self def has_side_effects(self) -> bool: # self._body is None sometimes that's why this check was added @@ -1455,6 +1581,8 @@ def has_side_effects(self) -> bool: return True return super().has_side_effects() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def refresh_group_node_dependencies( group_snode: Union[FusedSchedulerNode, GroupedSchedulerNode], @@ -1559,6 +1687,7 @@ def estimate_flops(self) -> int | None: def reorder_loops_by_dep_pair( self, self_dep: MemoryDep, other_dep: MemoryDep +<<<<<<< HEAD ) -> bool: """ Return true if a loop reordering is performed. @@ -1566,6 +1695,12 @@ def reorder_loops_by_dep_pair( if self.is_template(): # We can not really reorder loops for a triton template return False +======= + ) -> None: + if self.is_template(): + # We can not really reorder loops for a triton template + return +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self_sizes = None for snode in self.snodes: assert isinstance(snode, SchedulerNode) @@ -1573,7 +1708,11 @@ def reorder_loops_by_dep_pair( loop_ordering_log.debug( "Can not reorder fused node due to different sizes" ) +<<<<<<< HEAD return False +======= + return +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self_sizes = snode._sizes[0] new_order = None @@ -1586,7 +1725,11 @@ def reorder_loops_by_dep_pair( "Dont reordering fused node %s because we can not decide the suitable loop order", self.get_name(), ) +<<<<<<< HEAD return False +======= + return +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) metrics.num_loop_reordering += 1 loop_ordering_log.debug( "Reorder loops for fused node %s with order %s", self.get_name(), new_order @@ -1596,7 +1739,10 @@ def reorder_loops_by_dep_pair( snode.apply_new_loop_order(new_order) refresh_group_node_dependencies(self) +<<<<<<< HEAD return True +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __init__(self, scheduler: Scheduler, snodes: list[BaseSchedulerNode]) -> None: super().__init__(scheduler) @@ -1728,12 +1874,15 @@ def debug_str(self) -> str: return buf.getrawvalue().rstrip() +<<<<<<< HEAD @cache_on_self def has_side_effects(self) -> bool: if self.snodes is not None: return any(node.has_side_effects() for node in self.snodes) return super().has_side_effects() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class ForeachKernelSchedulerNode(FusedSchedulerNode): """ @@ -2078,6 +2227,7 @@ def create(cls, snodes: list[BaseSchedulerNode]) -> GroupedSchedulerNode: scheduler.name_to_fused_node[grouped_snode.get_name()] = grouped_snode return grouped_snode +<<<<<<< HEAD def __init__( self, scheduler: Scheduler, @@ -2092,15 +2242,23 @@ def __init__( # Reusing calculation of grouped unmed_dependencies etc. # No fusion logic in this case. self.temp_grouping = temp_grouping +======= + def __init__(self, scheduler: Scheduler, snodes: list[BaseSchedulerNode]) -> None: + super().__init__(scheduler) + init_group_node(self, scheduler, snodes) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def unpack(self) -> list[BaseSchedulerNode]: """ Do fusion among nodes within this GroupedSchedulerNode, and then unpack this GroupedSchedulerNode into regular nodes. """ +<<<<<<< HEAD if self.temp_grouping: return self.snodes +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for snode in self.snodes: self.scheduler.name_to_fused_node[snode.get_name()] = snode del self.scheduler.name_to_fused_node[self.get_name()] @@ -2157,7 +2315,11 @@ def can_fuse(cls, producer: BaseSchedulerNode, consumer: BaseSchedulerNode) -> b def pick_loop_order( stride_lengths: list[list[int]], sizes: Sequence[sympy.Expr], +<<<<<<< HEAD priority_idx: Sequence[int] = (), +======= + priority_idx: tuple[int, ...] = (), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> list[int]: """ A heuristic to decide loop iteration orders. This has not been well @@ -2235,22 +2397,34 @@ def merge(self, other: NodeUser) -> NodeUser: _post_grad_graph_counter = itertools.count() +<<<<<<< HEAD def used_non_deterministic_runtime_estimations() -> bool: return config.runtime_estimations_mms_benchmark +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class Scheduler: """ A Scheduler is a graph of BaseSchedulerNodes. It is responsible for optimizations such as fusion, reorder, and graph partition. """ +<<<<<<< HEAD +======= + __dep_size_hint_cache: dict[Dep, int] + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __init__(self, nodes: list[ir.Operation]) -> None: with dynamo_timed("Scheduler.__init__"): self._init(nodes) def _init(self, nodes: list[ir.Operation]) -> None: super().__init__() +<<<<<<< HEAD +======= + self.__dep_size_hint_cache = {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) V.graph.scheduler = self self.backends: dict[torch.device, BaseScheduling] = {} self.post_grad_graph_id = next(_post_grad_graph_counter) @@ -2266,16 +2440,22 @@ def _init(self, nodes: list[ir.Operation]) -> None: ) self.nodes = [self.create_scheduler_node(n) for n in nodes] +<<<<<<< HEAD self.current_node: Optional[BaseSchedulerNode] = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.update_zero_dim_cpu_tensor() # some new constants could have been created above self.available_buffer_names.update(V.graph.constants.keys()) for node in self.nodes: node.prune_deps() +<<<<<<< HEAD # See [Note: Graph Partition Device Contexts] self.default_device_context: Optional[torch.device] = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.name_to_donated_buffer: dict[str, SchedulerDonatedBuffer] = ( self.get_donated_buffers() ) @@ -2328,6 +2508,7 @@ def _init(self, nodes: list[ir.Operation]) -> None: self.logged_slow_fusion = OrderedSet[tuple[str, str]]() if config._pre_fusion_custom_pass is not None: self.nodes = config._pre_fusion_custom_pass(self.nodes) +<<<<<<< HEAD self.nodes = self.fuse_nodes(self.nodes) if config._post_fusion_custom_pass is not None: @@ -2342,6 +2523,15 @@ def _init(self, nodes: list[ir.Operation]) -> None: log_waitcounter=True, ): self.create_combo_kernel_nodes(num_ck_nodes=None) +======= + self.nodes = self.fuse_nodes(self.nodes) + if config._post_fusion_custom_pass is not None: + self.nodes = config._post_fusion_custom_pass(self.nodes) + self.merge_loops() + self.finalize_multi_template_buffers() + if config.combo_kernels: + self.create_combo_kernel_nodes(num_ck_nodes=None) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Peak memory pass and overlap pass must run last, otherwise # other reordering passes could undo their effects. @@ -2356,6 +2546,7 @@ def _init(self, nodes: list[ir.Operation]) -> None: OrderedSet(V.graph.get_output_names()), ) if config.reorder_for_compute_comm_overlap: +<<<<<<< HEAD if not config.reorder_for_peak_memory: from .memory import assign_memory_planning_info_for_scheduler_buffers @@ -2397,14 +2588,23 @@ def _init(self, nodes: list[ir.Operation]) -> None: torch._inductor.config.graph_partition and torch._inductor.config.triton.cudagraphs ): +======= + self.nodes = comms.reorder_compute_and_comm_for_overlap(self.nodes) + self.process_grouped_nodes() + + if torch._inductor.config.graph_partition: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.nodes = self.maybe_reorder_for_minimizing_partition(self.nodes) self.nodes = self.reorder_for_partition_with_simple_dependency(self.nodes) self.compute_last_usage() +<<<<<<< HEAD if torch._inductor.config.test_configs.track_memory_lifecycle: self.insert_memory_check_nodes() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) log_ir_post_fusion(self.nodes) V.debug.graph_diagram(self.nodes) self.debug_draw_graph() @@ -2511,7 +2711,13 @@ def compute_dependencies(self) -> None: mutation properly. """ +<<<<<<< HEAD class DedupList(Generic[_T]): +======= + T = TypeVar("T") + + class DedupList(Generic[T]): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ This data structure behaves like a list except it makes sure the elements remain unique. @@ -2523,19 +2729,32 @@ class DedupList(Generic[_T]): def __init__( self, +<<<<<<< HEAD items: Optional[list[_T]] = None, membership: Optional[OrderedSet[_T]] = None, +======= + items: Optional[list[T]] = None, + membership: Optional[OrderedSet[T]] = None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: self.items = items or [] self.membership = membership or OrderedSet() +<<<<<<< HEAD def append(self, node_user: _T) -> None: +======= + def append(self, node_user: T) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if node_user in self.membership: return self.items.append(node_user) self.membership.add(node_user) +<<<<<<< HEAD def __add__(self, other: DedupList[_T]) -> DedupList[_T]: +======= + def __add__(self, other: DedupList[T]) -> DedupList[T]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) new_membership = OrderedSet.union(self.membership, other.membership) new_items = self.items + [ x for x in other.items if x not in self.membership @@ -2610,11 +2829,20 @@ def add_user( for fs in s.free_symbols: unbacked_symbol_to_origin_node[fs] = None +<<<<<<< HEAD has_non_input_unbacked_defs = False for node in self.nodes: assert node.node is not None # unbacked symbols don't follow ordinary buffer dependencies, so # we track their def/uses separately +======= + for node in self.nodes: + log.debug("scheduling %s", node.node) + + # unbacked symbols don't follow ordinary buffer dependencies, so + # we track their def/uses separately + assert node.node is not None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) unbacked_symbol_defs = sorted( node.node.get_unbacked_symbol_defs(), key=lambda x: x.name ) @@ -2623,6 +2851,7 @@ def add_user( # Pick the first definer as canonical. There may be multiple # because if a MultiOutputLayout buffer propagates an unbacked # symint to multiple outputs, they will all claim to def it. +<<<<<<< HEAD has_non_input_unbacked_defs = True if s not in unbacked_symbol_to_origin_node: unbacked_symbol_to_origin_node[s] = node.get_name() @@ -2645,6 +2874,22 @@ def add_user( if (r := unbacked_symbol_to_origin_node[s]) is not None: for buf in self.name_to_node[r].get_outputs(): node.add_fake_dep(StarDep(buf.get_name())) +======= + if s not in unbacked_symbol_to_origin_node: + unbacked_symbol_to_origin_node[s] = node.get_name() + + unbacked_symbol_uses = sorted( + node.node.get_free_symbol_uses(unbacked_only=True), key=lambda x: x.name + ) + # if a kernel takes unbacked symints, register dependencies + for s in unbacked_symbol_uses: + assert s in unbacked_symbol_to_origin_node, ( + f"{s} not in {unbacked_symbol_to_origin_node}" + ) + if (r := unbacked_symbol_to_origin_node[s]) is not None: + for buf in self.name_to_node[r].get_outputs(): + node.add_fake_dep(StarDep(buf.get_name())) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if ( len(node.read_writes.writes) == 1 @@ -2699,6 +2944,7 @@ def add_user( add_user(buf_name, OutputNode(StarDep(buf_name))) # make sure unbacked symints aren't dead-code-eliminated +<<<<<<< HEAD if has_non_input_unbacked_defs: for out in V.graph.graph_outputs: for s in out.get_free_symbol_uses(unbacked_only=True): @@ -2713,6 +2959,19 @@ def add_user( s, ) add_user(buf_name, OutputNode(StarDep(buf_name))) +======= + for out in V.graph.graph_outputs: + for s in out.get_free_symbol_uses(unbacked_only=True): + assert s in unbacked_symbol_to_origin_node, ( + f"{s} not in {unbacked_symbol_to_origin_node.keys()}" + ) + if r := unbacked_symbol_to_origin_node[s]: + for buf_name in self.name_to_node[r].get_buffer_names(): + log.debug( + "scheduling output %s for unbacked symint %s", buf_name, s + ) + add_user(buf_name, OutputNode(StarDep(buf_name))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # make sure input mutation isn't dead-code-eliminated for name in self.mutation_renames: @@ -2750,6 +3009,7 @@ def add_user( compute_dependencies_log.debug("BUFFER USER LIST\n") compute_dependencies_log.debug("===== AFTER SCHEDULING =====\n%s", str) +<<<<<<< HEAD def insert_memory_check_nodes(self) -> None: from .memory import ( assign_memory_planning_info_for_scheduler_buffers, @@ -2827,6 +3087,8 @@ def construct_mem_check_node( self.nodes = new_nodes +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def dead_node_elimination(self) -> None: """ Remove any nodes without users @@ -2962,10 +3224,17 @@ def compute_ancestors(self) -> None: node.max_order = order def merge_loops(self) -> None: +<<<<<<< HEAD if not config.loop_ordering_after_fusion: return for node in self.nodes: +======= + for node in self.nodes: + if not config.loop_ordering_after_fusion: + continue + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Even for CPU, if we are using the halide backend, we still need # the merge loops steps below if not isinstance(node, (SchedulerNode, FusedSchedulerNode)) or ( @@ -3047,10 +3316,14 @@ def benchmark_fused_nodes( return backend.benchmark_fused_nodes(nodes) def generate_kernel_code_from_nodes( +<<<<<<< HEAD self, nodes: Sequence[BaseSchedulerNode], benchmark_kernel: bool, hint_override: Optional[int] = None, +======= + self, nodes: Sequence[BaseSchedulerNode], benchmark_kernel: bool +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> str: """ Benchmark fused list of nodes and return the execution time @@ -3061,9 +3334,13 @@ def generate_kernel_code_from_nodes( self.current_device = device backend = self.get_backend(device) with dynamo_timed("benchmark_fused_nodes"): +<<<<<<< HEAD return backend.generate_kernel_code_from_nodes( nodes, benchmark_kernel, hint_override=hint_override ) +======= + return backend.generate_kernel_code_from_nodes(nodes, benchmark_kernel) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def benchmark_codegened_module( self, module: ModuleType, device: torch.device @@ -3125,7 +3402,11 @@ def replace_operation_buffer( min_node_unfused = next( ( timing +<<<<<<< HEAD for timing in multi_node.choice_timings() +======= + for timing in multi_node.choice_timings +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance( timing, torch._inductor.select_algorithm.ExternKernelCaller, @@ -3137,6 +3418,7 @@ def replace_operation_buffer( min_node_unfused, torch._inductor.ir.TritonTemplateCallerBase, ): +<<<<<<< HEAD if config.multi_kernel_hints: callers: dict[Optional[int], TritonTemplateCallerBase] = {} callers[None] = min_node_unfused @@ -3158,6 +3440,13 @@ def replace_operation_buffer( out_tensorbox = min_node_unfused.output_node() out_storage = out_tensorbox.data # type: ignore[union-attr] +======= + node.node.finalize_as_triton_caller(min_node_unfused) + continue + + out_tensorbox = min_node_unfused.output_node() + out_storage = out_tensorbox.data +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(out_storage, ir.StorageBox) out_buffer = out_storage.data assert isinstance(out_buffer, ir.OperationBuffer) @@ -3276,10 +3565,17 @@ def log_fusion(ms_fused: float, ms1: float, ms2: float) -> None: async_compile = torch._inductor.async_compile.AsyncCompile() def compile_kernel( +<<<<<<< HEAD nodes: Sequence[BaseSchedulerNode], hint_override: Optional[int] = None ) -> tuple[Optional[LambdaFuture], ModuleType]: src_code = self.generate_kernel_code_from_nodes( nodes, benchmark_kernel=True, hint_override=hint_override +======= + nodes: Sequence[BaseSchedulerNode], + ) -> tuple[Optional[LambdaFuture], ModuleType]: + src_code = self.generate_kernel_code_from_nodes( + nodes, benchmark_kernel=True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) mod = PyCodeCache.load(src_code) if not async_compile.use_process_pool(): @@ -3301,6 +3597,7 @@ def compile_kernel( ) assert isinstance(multi_node, ir.MultiTemplateBuffer) +<<<<<<< HEAD hint_override_best_fusion_choice: dict[ Optional[int], TritonTemplateCallerBase ] = {} @@ -3353,6 +3650,10 @@ def compile_kernel( # Eagerly compile and benchmark non-template nodes choice_timings = multi_node.choice_timings() +======= + # Eagerly compile and benchmark non-template nodes + choice_timings = multi_node.choice_timings +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _, ms1 = multi_node.get_min_choice() ms2, path2 = ( self.benchmark_fused_nodes(node_list_2) @@ -3427,6 +3728,7 @@ def benchmark_when_ready() -> bool: log_fusion(min_ms_fused, ms1, ms2) if min_ms_fused < (ms1 + ms2) and ms_fused_choice is not None: +<<<<<<< HEAD if config.multi_kernel_hints: hint_override_best_fusion_choice[None] = ms_fused_choice multi_node.finalize_as_triton_callers( @@ -3436,6 +3738,10 @@ def benchmark_when_ready() -> bool: multi_node.finalize_as_triton_caller(ms_fused_choice) multi_node._choice_timings[None] = new_timings +======= + multi_node.finalize_as_triton_caller(ms_fused_choice) + multi_node._choice_timings = new_timings +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return True else: return False @@ -3835,6 +4141,7 @@ def _find_single_user_inputs( return True return False +<<<<<<< HEAD def fusion_accumulate_large_reads( self, node1: BaseSchedulerNode, node2: BaseSchedulerNode, threshold: int ) -> bool: @@ -3843,6 +4150,8 @@ def fusion_accumulate_large_reads( ) return sum(self.dep_size_hint(dep) for dep in all_reads) > threshold +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def are_long_distant_nodes( self, node1: BaseSchedulerNode, node2: BaseSchedulerNode ) -> bool: @@ -3939,11 +4248,14 @@ def shared_data_after_reordering_loop( Right now just greedily reorder the loop of node1 to be compatible with node2, but ideally we should have some heuristics to reorder the loop for node2 to be compatible with node1 if that's more efficient. +<<<<<<< HEAD Return the amount of shared data re-computed in this method. If no such recomputation happens, return -1 (not return 0 since 0 is a valid amount of shared data). +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ # TODO Don't do loop reordering for CPU for now. @@ -3951,14 +4263,22 @@ def shared_data_after_reordering_loop( if not config.loop_ordering_after_fusion or any( n.is_cpu() for n in [node1, node2] ): +<<<<<<< HEAD return -1 +======= + return 0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) node1_buffer_names = node1.read_writes.buffer_names() node2_buffer_names = node2.read_writes.buffer_names() # Fast path: no common buffers. common_buffer_names = node1_buffer_names & node2_buffer_names if not common_buffer_names: +<<<<<<< HEAD return -1 +======= + return 0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) node1_name2dep = {dep.name: dep for dep in node1.read_writes.reads_and_writes()} node2_name2dep = {dep.name: dep for dep in node2.read_writes.reads_and_writes()} @@ -3981,13 +4301,21 @@ def shared_data_after_reordering_loop( ) if len(candidates) == 0: +<<<<<<< HEAD return -1 +======= + return 0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Pick the largest buffer to guide the loop reordering _numel, lhs_dep, rhs_dep = max(candidates, key=operator.itemgetter(0)) if not isinstance(lhs_dep, MemoryDep) or not isinstance(rhs_dep, MemoryDep): +<<<<<<< HEAD return -1 +======= + return 0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if lhs_dep.num_vars != rhs_dep.num_vars: # this can happen due to we don't merge loops. @@ -3996,6 +4324,7 @@ def shared_data_after_reordering_loop( # normalization (merging loops) if lhs_dep.normalize() == rhs_dep.normalize(): return self.dep_size_hint(lhs_dep) +<<<<<<< HEAD return -1 reordered = False @@ -4004,6 +4333,15 @@ def shared_data_after_reordering_loop( reordered = node1.reorder_loops_by_dep_pair(lhs_dep, rhs_dep) elif not node2.is_reduction(): reordered = node2.reorder_loops_by_dep_pair(rhs_dep, lhs_dep) +======= + return 0 + + # Only reorder loops for pointwise for now + if not node1.is_reduction(): + node1.reorder_loops_by_dep_pair(lhs_dep, rhs_dep) + elif not node2.is_reduction(): + node2.reorder_loops_by_dep_pair(rhs_dep, lhs_dep) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: loop_ordering_log.debug( "Don't reorder loops since both nodes are reductions: %s v.s. %s", @@ -4011,7 +4349,11 @@ def shared_data_after_reordering_loop( node2.get_name(), ) +<<<<<<< HEAD return self.score_fusion_memory(node1, node2) if reordered else -1 +======= + return self.score_fusion_memory(node1, node2) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def unfusable_node(self, node: BaseSchedulerNode) -> bool: """ @@ -4082,6 +4424,7 @@ def low_prec_fp(dtype: torch.dtype) -> bool: return True +<<<<<<< HEAD def get_expand_dim_for_pointwise_nodes( self, node1: BaseSchedulerNode, node2: BaseSchedulerNode ) -> Optional[tuple[int, SchedulerNode, sympy.Expr]]: @@ -4180,11 +4523,17 @@ def has_reusable_buffer(node: BaseSchedulerNode) -> bool: else: return None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def can_fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: """ Determine if it is possible to combine node1 and node2 into a single fused node. """ +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if node1 is node2: return False @@ -4234,7 +4583,11 @@ def can_fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: allowed_prologue_inps = template.get_allowed_prologue_inps() unsupported_prologue_args = ( +<<<<<<< HEAD OrderedSet(inp.get_name() for inp in template.inputs) # type: ignore[union-attr] +======= + OrderedSet(inp.get_name() for inp in template.inputs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - allowed_prologue_inps ) @@ -4288,6 +4641,10 @@ def can_fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: ): why("fusion for buffer explicit disabled") return False +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device = node1.get_device() device2 = node2.get_device() if device != device2: @@ -4300,6 +4657,7 @@ def can_fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: shared_data_score < config.score_fusion_memory_threshold and config.loop_ordering_after_fusion ): +<<<<<<< HEAD new_shared_data_score = self.shared_data_after_reordering_loop(node1, node2) if new_shared_data_score >= 0: shared_data_score = new_shared_data_score @@ -4310,6 +4668,9 @@ def can_fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: (expand_dim, smaller_node, expand_size) = expand_analysis smaller_node.expand_dimension_for_pointwise_node(expand_dim, expand_size) shared_data_score = self.score_fusion_memory(node1, node2) +======= + shared_data_score = self.shared_data_after_reordering_loop(node1, node2) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if loop_ordering_log.isEnabledFor(logging.DEBUG): loop_ordering_log.debug( @@ -4363,7 +4724,11 @@ def can_fuse_vertical( if remaining: for rd in remaining: if self.fusable_read_and_write(rd, cd): +<<<<<<< HEAD remaining.remove(rd) # noqa: B909 +======= + remaining.remove(rd) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) remaining_deps = OrderedSet( dep.name @@ -4459,7 +4824,24 @@ def fusable_read_and_write(self, read: Dep, write: MemoryDep) -> bool: return False def dep_size_hint(self, dep: Dep) -> int: +<<<<<<< HEAD return V.graph.get_dep_size_hint(dep) +======= + res = 0 + if dep not in self.__dep_size_hint_cache: + try: + if not dep.has_unbacked_symbols(): + res = dep.numbytes_hint() + except KeyError: + # In at least one test (test/inductor/test_torchbind.py) we + # create a StarDep that doesn't exist in the graph and calling + # `has_unbacked_symbols()` throws an error. + pass + self.__dep_size_hint_cache[dep] = res + else: + res = self.__dep_size_hint_cache[dep] + return res +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def score_fusion_memory( self, node1: BaseSchedulerNode, node2: BaseSchedulerNode @@ -4646,6 +5028,7 @@ def can_buffer_be_removed_through_fusion( and name not in self.mutation_real_name ) +<<<<<<< HEAD def should_partition( self, node: BaseSchedulerNode, should_log: bool = False ) -> bool: @@ -4710,6 +5093,29 @@ def noop_log(msg: str, node: Optional[BaseSchedulerNode]) -> None: if is_cudagraph_unsafe_op(node.node): log_partition_reason("CUDAGraph-unsafe custom ops", node=node) +======= + def should_partition(self, node: BaseSchedulerNode) -> bool: + """Return True if we should partition the inductor graph on this node""" + if isinstance(node, FusedSchedulerNode): + return any(self.should_partition(snode) for snode in node.snodes) + + if not node.is_gpu(): + return True + + if node.node is None: + return True + + if isinstance(node.node, ir.DeviceCopy): + return True + + if isinstance(node.node, ir.Conditional): + return True + + if getattr(node.node, "unbacked_bindings", None): + return True + + if is_cudagraph_unsafe_op(node.node): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return True return False @@ -5179,7 +5585,11 @@ def graph_partition( cur_partition: PartitionType = [] skip_cudagraphs = [] for node in self.nodes: +<<<<<<< HEAD should_partition = self.should_partition(node, should_log=True) +======= + should_partition = self.should_partition(node) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if cur_partition and skip_cudagraph != should_partition: partitions.append(cur_partition) skip_cudagraphs.append(skip_cudagraph) @@ -5249,6 +5659,7 @@ def _codegen_partition_wrapper( [node.get_name() for node in signature.output_nodes] ) +<<<<<<< HEAD def use_default_device_context( self, partitions: list[PartitionType], signatures: list[GraphPartitionSignature] ) -> contextlib.AbstractContextManager[None]: @@ -5323,6 +5734,8 @@ def all_on_target_device( self.default_device_context = cudagraph_partition_device +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _codegen_partitions(self) -> None: """ Split nodes into partitions and codegen each partition into separate functions. @@ -5331,6 +5744,7 @@ def _codegen_partitions(self) -> None: """ partitions, signatures = self.graph_partition() +<<<<<<< HEAD if len(partitions) > 1: msg = f"cudagraph partition into {len(partitions)} partitions" maybe_log_cudagraph_partition(msg=msg, prefix="") @@ -5345,6 +5759,17 @@ def _codegen_partitions(self) -> None: self._codegen(partition) else: self._codegen_partition_wrapper(partition, signature) +======= + for partition, signature in zip(partitions, signatures): + assert len(partition) >= 1, ( + f"Each partition must have at least one node but found {len(partition)}" + ) + + if signature.skip_cudagraph: + self._codegen(partition) + else: + self._codegen_partition_wrapper(partition, signature) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) num_partitions = next(self._graph_partition_counter) V.graph.wrapper_code.set_all_partition_names(num_partitions) @@ -5377,11 +5802,15 @@ def _codegen(self, nodes: list[BaseSchedulerNode]) -> None: ) seen.add(key) +<<<<<<< HEAD self.current_device = self.default_device_context if self.default_device_context and config.triton.autotune_at_compile_time: V.graph.wrapper_code.write_get_raw_stream_header() +======= + self.current_device = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for node in nodes: if log.isEnabledFor(logging.DEBUG): try: @@ -5415,7 +5844,10 @@ def _codegen(self, nodes: list[BaseSchedulerNode]) -> None: assert device.index is not None, "device should have an index" V.graph.wrapper_code.codegen_device_guard_enter(device.index) +<<<<<<< HEAD self.current_node = node +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.buffer_names_to_free.update(node.last_usage) if node.is_template(): @@ -5460,6 +5892,7 @@ def _codegen(self, nodes: list[BaseSchedulerNode]) -> None: ): self.flush() +<<<<<<< HEAD if self.current_device != self.default_device_context: # when default_device_context is not None, we are codegen # for graph partitions and all nodes must be on @@ -5469,6 +5902,12 @@ def _codegen(self, nodes: list[BaseSchedulerNode]) -> None: # exit the outermost CUDA device guard. this is # important for nested indentation codegen-ing. V.graph.wrapper_code.codegen_device_guard_exit() +======= + if self.current_device and device_need_guard(self.current_device.type): + # exit the outermost CUDA device guard. this is + # important for nested indentation codegen-ing. + V.graph.wrapper_code.codegen_device_guard_exit() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.flush() @@ -5657,10 +6096,14 @@ def codegen_template( raise NotImplementedError def generate_kernel_code_from_nodes( +<<<<<<< HEAD self, nodes: Sequence[BaseSchedulerNode], benchmark_kernel: bool, hint_override: Optional[int] = None, +======= + self, nodes: Sequence[BaseSchedulerNode], benchmark_kernel: bool +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> str: """ Generate a kernel given a list of pre-fused nodes. diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index ac8daee16417a..d481ca73b27bf 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -27,6 +27,7 @@ import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools from torch._dynamo.device_interface import get_interface_for_device from torch._dynamo.testing import rand_strided +<<<<<<< HEAD from torch._dynamo.utils import ( counters, dynamo_timed, @@ -35,6 +36,9 @@ preserve_rng_state, ) from torch._inductor.await_utils import await_sync +======= +from torch._dynamo.utils import counters, dynamo_timed, identity, preserve_rng_state +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.utils import clear_on_fresh_cache from torch.utils._filelock import FileLock from torch.utils._ordered_set import OrderedSet @@ -61,14 +65,20 @@ from .codegen.triton import ( gen_common_triton_imports, texpr, +<<<<<<< HEAD TMACompatibilityChecker, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TritonKernel, TritonScheduling, ) from .codegen.triton_utils import config_of, equal_1_arg_indices, signature_to_meta from .codegen.wrapper import pexpr from .exc import CUDACompileError +<<<<<<< HEAD from .fx_utils import count_flops_fx +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .ir import ChoiceCaller, PrimitiveInfoType from .ops_handler import StoreMode from .runtime.benchmarking import benchmarker @@ -174,6 +184,7 @@ class PartialRender: of replacements after the initial render. """ +<<<<<<< HEAD HookFn = Callable[[], str] def __init__( @@ -207,6 +218,14 @@ def finalize_hook(self, hook_key: str, strict: bool = True) -> None: NOTE: Will **error** if the hook has already been finalized. """ +======= + def __init__(self, code, replacement_hooks) -> None: + super().__init__() + self.code = code + self.replacement_hooks = replacement_hooks + + def finalize_hook(self, hook_key: str, strict=True) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if hook_key not in self.replacement_hooks: if strict: raise RuntimeError( @@ -214,6 +233,7 @@ def finalize_hook(self, hook_key: str, strict: bool = True) -> None: ) else: return +<<<<<<< HEAD hook = self.replacement_hooks[hook_key] assert hook is not None, f"Hook key {hook_key} can only be called once" @@ -243,6 +263,17 @@ def finalize_all(self) -> str: """ for key in self.replacement_hooks: self.finalize_hook(key) +======= + assert self.replacement_hooks[hook_key] is not None, ( + "hook_key can only be called once" + ) + self.code = self.code.replace(hook_key, self.replacement_hooks[hook_key]()) + self.replacement_hooks[hook_key] = None + + def finalize_all(self) -> str: + for key, fn in self.replacement_hooks.items(): + self.code = self.code.replace(key, fn()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.code @@ -263,7 +294,11 @@ class SubgraphInfo: # only copied over if not None range_trees: Optional[list["IterationRangesRoot"]] = None +<<<<<<< HEAD numels: Optional[dict[str, sympy.Expr]] = None +======= + numels = None # type: ignore[var-annotated] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __post_init__(self): self.only_copy_if_non_none_fields = ("range_trees", "numels") @@ -295,8 +330,12 @@ def load(self, name: str, index: sympy.Expr): if name not in self.fixed_inputs: index_str = self._process_indexing(index) var = self._add_kernel_input(name) +<<<<<<< HEAD buffer = V.graph.get_buffer(name) var_dtype = buffer.dtype +======= + var_dtype = V.graph.get_buffer(name).dtype +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) line = f"tl.load({var} + {index_str})" if ( @@ -306,6 +345,7 @@ def load(self, name: str, index: sympy.Expr): line += ".to(tl.float32)" var_dtype = torch.float32 +<<<<<<< HEAD out = self.kernel.cse.generate( self.kernel.compute, line, dtype=var_dtype, shape=() ) @@ -316,6 +356,13 @@ def load(self, name: str, index: sympy.Expr): f"({self.fixed_inputs[name]})", dtype=torch.float32, shape=(), +======= + out = self.kernel.cse.generate(self.kernel.compute, line, dtype=var_dtype) + return out + + return self.kernel.cse.generate( + self.kernel.compute, f"({self.fixed_inputs[name]})", dtype=torch.float32 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def indirect_indexing(self, index_var: str, size, check, wrap_neg=True): @@ -354,6 +401,7 @@ def _process_indexing(self, index): class TritonTemplateKernel(TritonKernel): +<<<<<<< HEAD """ A specialized kernel class for Triton templates that handles code generation for templated Triton kernels. @@ -363,6 +411,8 @@ class TritonTemplateKernel(TritonKernel): arguments, and prologue/epilogue fusion. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __init__( self, kernel_name, @@ -383,7 +433,10 @@ def __init__( subgraphs: Optional[list[ir.ComputedBuffer]] = None, workspace_arg: Optional[WorkspaceArg] = None, prologue_loads_all_inputs=False, +<<<<<<< HEAD hint_override: Optional[int] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: numel = sympy_product(output_node.get_size()) super().__init__( @@ -392,7 +445,10 @@ def __init__( "r0_": sympy.S.One, }, features=SIMDKernelFeatures([], numel), +<<<<<<< HEAD hint_override=hint_override, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) self.input_nodes = input_nodes self.output_node = output_node @@ -465,9 +521,12 @@ def __init__( # by adding all inputs. self.prologue_loads_all_inputs = prologue_loads_all_inputs +<<<<<<< HEAD # Extra functions to be exposed during partial template rendering. self.extra_template_env_fns: list[Callable[..., Any]] = [] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def input_dependent_preserved_state(self) -> str: # Not adding self.args.output_buffers on purpose. But we do not need to reproduce it on a cache hit. # (never accessed). @@ -557,12 +616,17 @@ def estimate_kernel_num_bytes(self): ninplace_args = len(unique(self.args.inplace_buffers.values())) num_bytes = [] for i, inp in enumerate(itertools.chain(self.input_nodes, (self.output_node,))): +<<<<<<< HEAD size = V.graph.sizevars.size_hints(inp.get_size(), fallback=0) +======= + size = V.graph.sizevars.size_hints(inp.get_size()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) numel = functools.reduce(operator.mul, size, 1) dtype_size = get_dtype_size(inp.get_dtype()) num_bytes.append(numel * dtype_size * (1 + int(i < ninplace_args))) return sum(num_bytes) +<<<<<<< HEAD def estimate_flops(self) -> int: for node in self.input_nodes: for fx_node in node._current_origins: @@ -571,6 +635,8 @@ def estimate_flops(self) -> int: return V.graph.sizevars.size_hint(f, fallback=0) return 0 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def jit_lines(self): if self.use_jit: return "@triton.jit" @@ -603,17 +669,24 @@ def jit_lines(self): inductor_meta = { "kernel_name": str(Placeholder.DESCRIPTIVE_NAME), +<<<<<<< HEAD **self.inductor_meta_common(), +======= + **TritonKernel.inductor_meta_common(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) **FixedGrid.setup_grid_as_args(), } if config.profile_bandwidth or config.benchmark_kernel: num_gb = self.estimate_kernel_num_bytes() / 1e9 inductor_meta["kernel_num_gb"] = num_gb +<<<<<<< HEAD if config.benchmark_kernel: flops = self.estimate_flops() inductor_meta["kernel_flop"] = flops inductor_meta["config_args"] = self.meta +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template_args = f""" num_stages={self.num_stages}, @@ -641,7 +714,12 @@ def hook(): arg_defs, *_ = self.args.python_argdefs() return f"{', '.join(x.full_name() for x in arg_defs)}" +<<<<<<< HEAD return self._register_hook("", hook, allow_overwriting=True) +======= + self.render_hooks[""] = hook + return "" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def gen_defines(self): return self.defines @@ -719,7 +797,13 @@ def hook(): code.splice(renames.getvalue()) return code.getvalue() +<<<<<<< HEAD return self._register_hook("", hook) +======= + assert "" not in self.render_hooks + self.render_hooks[""] = hook + return "" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def size(self, name: str, index: int): """ @@ -1012,7 +1096,13 @@ def hook(): return textwrap.indent(self.body.getvalue(), " " * indent_width).strip() +<<<<<<< HEAD return self._register_hook(hook_key, hook) +======= + assert hook_key not in self.render_hooks + self.render_hooks[hook_key] = hook + return hook_key +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def store_output( self, @@ -1020,7 +1110,10 @@ def store_output( val: str, mask: Optional[str] = None, indent_width: int = 4, +<<<<<<< HEAD val_shape: Optional[list[str]] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): """Stores the final output and appends any epilogue fusions if the buffer hasn't been optimized away. @@ -1071,9 +1164,13 @@ def store_output( if "ACC_TYPE" in self.meta else torch.float32 ) +<<<<<<< HEAD epilogue_args = [ V.kernel.cse.namedvar(val, dtype=acc_dtype, shape=val_shape) ] +======= + epilogue_args = [V.kernel.cse.namedvar(val, dtype=acc_dtype)] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for input_node in itertools.chain( self.input_nodes[: self.prefix_args], self.input_nodes[len(self.input_nodes) - self.suffix_args :], @@ -1097,6 +1194,7 @@ def hook(): return textwrap.indent(self.body.getvalue(), " " * indent_width).strip() +<<<<<<< HEAD return self._register_hook("", hook) def _register_hook( @@ -1141,6 +1239,11 @@ def _register_extra_template_env_fns(self, *fns: Callable[..., Any]): function. """ self.extra_template_env_fns.extend(fns) +======= + assert "" not in self.render_hooks + self.render_hooks[""] = hook + return "" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def render(self, template, kwargs, record_input_dependent_tracked_event=False): if record_input_dependent_tracked_event: @@ -1160,7 +1263,10 @@ def render(self, template, kwargs, record_input_dependent_tracked_event=False): self.modification, self.gen_argdefs, self.gen_defines, +<<<<<<< HEAD *self.extra_template_env_fns, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] } return PartialRender( @@ -1192,7 +1298,10 @@ def indexing( copy_shape=None, override_mask=None, block_ptr=False, +<<<<<<< HEAD tma_compatibility_checker: Optional[TMACompatibilityChecker] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): """ Override the default indexing to use our custom mask and force @@ -1206,12 +1315,16 @@ def indexing( copy_shape=self.template_out, override_mask=self.template_mask, block_ptr=block_ptr, +<<<<<<< HEAD tma_compatibility_checker=tma_compatibility_checker, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def codegen_range_tree(self): pass # ignore default codegen +<<<<<<< HEAD def additional_call_args_and_types(self): if isinstance(self.grid_fn, SymbolicGridFn): grid_args = self.grid_fn.sympy_call(*self.call_sizes, self.meta) @@ -1223,15 +1336,26 @@ def additional_call_args_and_types(self): return (grid_args, map(type, grid_args)) return ((), ()) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def call_kernel(self, name: str, node: Optional[ir.IRNode] = None): wrapper = V.graph.wrapper_code _, call_args, _, arg_types = self.args.python_argdefs() +<<<<<<< HEAD additional_call_args, additional_arg_types = ( self.additional_call_args_and_types() ) if not additional_call_args: +======= + grid_args = () + if isinstance(self.grid_fn, SymbolicGridFn): + grid_args = self.grid_fn.sympy_call(*self.call_sizes, self.meta) + elif all(isinstance(x, (int, sympy.Integer)) for x in self.call_sizes): + grid_args = self.grid_fn(*map(int, self.call_sizes), self.meta) + else: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert not V.graph.cpp_wrapper, "cpp_wrapper requires SymbolicGridFn" wrapper.add_import_once(f"import {self.grid_fn.__module__}") meta = wrapper.add_meta_once(self.meta) @@ -1240,9 +1364,15 @@ def call_kernel(self, name: str, node: Optional[ir.IRNode] = None): f"*{fn_name}({', '.join(map(pexpr, self.call_sizes))}, {meta})" ) arg_types.append(None) +<<<<<<< HEAD call_args.extend(additional_call_args) arg_types.extend(additional_arg_types) +======= + assert len(grid_args) in (0, 3), "grid_fn should return 3 values" + call_args.extend(grid_args) + arg_types.extend(map(type, grid_args)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.workspace_arg is not None: wrapper.generate_workspace_allocation(self.workspace_arg) @@ -1317,7 +1447,11 @@ def make_key( input_nodes: tuple[ir.IRNode], num_stages: int, num_warps: int, +<<<<<<< HEAD call_sizes: Sequence[sympy.core.symbol.Symbol], +======= + call_sizes: list[sympy.core.symbol.Symbol], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) prefix_args: int, suffix_args: int, epilogue_fn: Optional[Callable[..., Any]], @@ -1328,7 +1462,10 @@ def make_key( num_consumer_groups: int, num_buffers_warp_spec: int, kwargs: dict[str, Any], +<<<<<<< HEAD hint_override: Optional[int] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Optional[str]: def layout_key(layout: ir.Layout) -> str: assert not isinstance(layout, ir.FlexibleLayout) @@ -1379,7 +1516,10 @@ def has_flexible_layout() -> bool: "num_buffers_warp_spec": num_buffers_warp_spec, "epilogue_fn_hash": epilogue_fn_hash, "kwargs": kwargs, +<<<<<<< HEAD "hint_override": hint_override, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } ) @@ -1443,11 +1583,14 @@ def __init__( # was not used are the same. test_cache = False +<<<<<<< HEAD @property def uid(self) -> str: # unique by prefixing with triton return f"triton::{self.name}" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def maybe_append_choice( self, choices: list[Any], **kwargs: Any ) -> Optional[NotImplementedError]: @@ -1460,9 +1603,13 @@ def maybe_append_choice( """ try: +<<<<<<< HEAD choice = self.generate(generate_with_caching=True, **kwargs) if choice is not None: choices.append(choice) +======= + choices.append(self.generate(generate_with_caching=True, **kwargs)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return None except NotImplementedError as e: log.info( @@ -1479,7 +1626,11 @@ def generate_and_load( input_nodes: tuple[ir.IRNode], num_stages: int, num_warps: int, +<<<<<<< HEAD call_sizes: Sequence[sympy.core.symbol.Symbol], +======= + call_sizes: list[sympy.core.symbol.Symbol], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) prefix_args: int, suffix_args: int, epilogue_fn: Optional[Callable[..., Any]], @@ -1491,7 +1642,10 @@ def generate_and_load( layout: ir.Layout, kwargs: dict[str, Any], generate_with_caching, +<<<<<<< HEAD hint_override: Optional[int] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Optional[GenerateAndLoadResult]: """Generate the python code and load it into the current process""" caching_enabled = ( @@ -1523,12 +1677,17 @@ def generate_and_load( for name, val in kwargs.items(): defines.write(f"{name} : tl.constexpr = {val}\n") +<<<<<<< HEAD +======= + defines = defines.getvalue() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fake_out = ir.Buffer(name="buf_out", layout=layout) kernel_name = f"triton_{self.name}" numel = sympy_product(layout.size) buffers = itertools.chain(input_nodes, (fake_out,)) +<<<<<<< HEAD if TritonScheduling.can_use_32bit_indexing(numel, buffers): index_dtype = "tl.int32" @@ -1538,6 +1697,12 @@ def generate_and_load( # Add index dtype to defines so it's available in the template defines.write(f"INDEX_DTYPE : tl.constexpr = {index_dtype}\n") defines = defines.getvalue() +======= + if not TritonScheduling.can_use_32bit_indexing(numel, buffers): + raise NotImplementedError( + "64-bit indexing is not yet implemented for triton templates" + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kernel_options = { "input_nodes": input_nodes, @@ -1568,7 +1733,10 @@ def make_kernel(): output_node=fake_out, workspace_arg=workspace_arg, use_jit=False, +<<<<<<< HEAD hint_override=hint_override, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) **kernel_options, ) @@ -1685,10 +1853,16 @@ def generate( # type: ignore[override] epilogue_fn_hash: Optional[str] = None, subgraphs: Optional[list[ir.Buffer]] = None, mutated_inputs: Optional[list[ir.IRNode]] = None, +<<<<<<< HEAD call_sizes: Optional[Sequence[sympy.core.symbol.Symbol]] = None, workspace_arg: Optional[WorkspaceArg] = None, generate_with_caching=False, hint_override: Optional[int] = None, +======= + call_sizes: Optional[list[sympy.core.symbol.Symbol]] = None, + workspace_arg: Optional[WorkspaceArg] = None, + generate_with_caching=False, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) **kwargs, ): """This function generates a TritonTemplateCaller @@ -1733,7 +1907,10 @@ def generate( # type: ignore[override] layout, kwargs, generate_with_caching and self._cache_codegen_enabled_for_template, +<<<<<<< HEAD hint_override=hint_override, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # May happen as result of dev by 0. @@ -1755,7 +1932,10 @@ def generate( # type: ignore[override] extra_args = V.graph.sizevars.size_hints( map(sympy.expand, result.kernel_args_sizevars_keys), fallback=config.unbacked_symint_fallback, +<<<<<<< HEAD hint_override=hint_override, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) kernel_hash_name = f"triton_{self.name}_{next(self.index_counter)}" @@ -1779,14 +1959,21 @@ def generate( # type: ignore[override] options = result.kernel_options +<<<<<<< HEAD def make_kernel_render(out_node, hint_override: Optional[int] = None): +======= + def make_kernel_render(out_node): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert result is not None kernel = self.kernel_type( kernel_name=str(Placeholder.KERNEL_NAME), output_node=out_node, workspace_arg=workspace_arg, use_jit=False, +<<<<<<< HEAD hint_override=hint_override, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) **options, ) render = functools.partial( @@ -1802,7 +1989,10 @@ def make_kernel_render(out_node, hint_override: Optional[int] = None): *V.graph.sizevars.size_hints( call_sizes, fallback=config.unbacked_symint_fallback, +<<<<<<< HEAD hint_override=hint_override, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), kwargs, ) @@ -1854,7 +2044,10 @@ def make_kernel_render(out_node, hint_override: Optional[int] = None): mutated_inputs=mutated_inputs, workspace_arg=workspace_arg, allowed_prologue_inps=result.prologue_supported_inputs, +<<<<<<< HEAD hint_override=hint_override, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @@ -1914,6 +2107,7 @@ def bind( self, input_nodes, layout, kwargs, has_out_variant=self.has_out_variant ) +<<<<<<< HEAD @property def uid(self) -> str: # unique by prefixing with aten @@ -1944,6 +2138,8 @@ def maybe_append_choice( choices.append(self.bind(input_nodes=input_nodes, layout=layout, **kwargs)) return None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TritonTemplateCaller(ir.TritonTemplateCallerBase): def __init__( @@ -1960,7 +2156,10 @@ def __init__( mutated_inputs=None, workspace_arg: Optional[WorkspaceArg] = None, allowed_prologue_inps: Optional[OrderedSet[str]] = None, +<<<<<<< HEAD hint_override: Optional[int] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: super().__init__(name, input_nodes, layout, description) self.make_kernel_render = make_kernel_render @@ -1980,7 +2179,10 @@ def __init__( self.allowed_prologue_inps = ( allowed_prologue_inps if allowed_prologue_inps is not None else OrderedSet() ) +<<<<<<< HEAD self.hint_override = hint_override +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def benchmark(self, *args, out): assert self.bmreq is not None @@ -2097,7 +2299,11 @@ def output_node(self): assert self.choice.op_overload is not None, ( "Please provide an op_overload to use ir.FallbackKernel" ) +<<<<<<< HEAD inner: ir.IRNode = ir.FallbackKernel.create( +======= + inner = ir.FallbackKernel.create( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.choice.op_overload, *self.input_nodes, **self.kwargs ) elif self.choice.kernel_creator is not None: @@ -2303,6 +2509,7 @@ def create_precompile_key( None, ] +<<<<<<< HEAD # Args to PreprocessingFunctions # choices: list of ChoiceCaller objects to preprocess # Returns: modified list of ChoiceCaller objects @@ -2336,6 +2543,8 @@ def filter_choices_by_desc_regex(choices: list[ChoiceCaller]) -> list[ChoiceCall ] return choices +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class AlgorithmSelectorCache(PersistentCache): """ @@ -2356,6 +2565,7 @@ def __init__(self, *args, **kwargs) -> None: # first to benchmark it. share a single precompilation function for all lowerings # of a particular key self.precompile_cache: dict[str, Callable[[], None]] = {} +<<<<<<< HEAD # cache for prescreening results to ensure deterministic candidate selection self.prescreening_cache: dict[str, OrderedSet[str]] = {} # list of callbacks that are called after benchmarking @@ -2375,6 +2585,15 @@ def _register_default_preprocessing_fns(self): self.add_preprocessing_fn(filter_choices_by_name_regex) self.add_preprocessing_fn(filter_choices_by_desc_regex) +======= + # list of callbacks that are called after benchmarking + self.feedback_saver_fns: list[FeedbackFunction] = [] + # cache for prescreening results to ensure deterministic candidate selection + self.prescreening_cache: dict[str, OrderedSet[str]] = {} + + clear_on_fresh_cache(self) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def cache_clear(self) -> None: self.precompile_cache.clear() self.prescreening_cache.clear() @@ -2393,6 +2612,7 @@ def __call__( input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]] = None, precompilation_timeout_seconds: int = 60 * 60, return_multi_template=False, +<<<<<<< HEAD best_config_future=None, ): from .codegen.cuda.cuda_kernel import CUDATemplateCaller @@ -2401,6 +2621,11 @@ def __call__( for preprocessing_fn in self.preprocessing_fns: choices = preprocessing_fn(choices) +======= + ): + from .codegen.cuda.cuda_kernel import CUDATemplateCaller + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Templates selected with input_gen_fns require specific input data to avoid IMA # Passing custom input gen fns to benchmark_fusion NYI, so skip deferred template selection # TODO(jgong5): support multi-template on CPU @@ -2409,6 +2634,28 @@ def __call__( # TODO - assert that we have not mutating kernels here +<<<<<<< HEAD +======= + if config.test_configs.autotune_choice_name_regex is not None: + choices = [ + c + for c in choices + if re.search( + config.test_configs.autotune_choice_name_regex, + c.name, + ) + ] + if config.test_configs.autotune_choice_desc_regex is not None: + choices = [ + c + for c in choices + if re.search( + config.test_configs.autotune_choice_desc_regex, + c.description, + ) + ] + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if mm_file_name := get_mm_log_filename(): M, K = input_nodes[-2].get_size()[:2] N = input_nodes[-1].get_size()[-1] @@ -2431,6 +2678,7 @@ def __call__( # CUDATemplateCaller still needs to go through autotuning process to retrieve workspace size. return choices[0].output_node() +<<<<<<< HEAD inputs_key = create_inputs_key(input_nodes) # TODO(nmacchioni): remove this hacky way to tell if we ran benchmarking @@ -2451,12 +2699,22 @@ def benchmark(choices, hint_override: Optional[int] = None): return benchmark_fn(choices) def autotune(choices, hint_override: Optional[int] = None): +======= + @functools.cache + def make_benchmark_fn(): + return self.make_benchmark_fn(choices, input_nodes, layout, input_gen_fns) + + inputs_key = create_inputs_key(input_nodes) + + def autotune(choices): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) log.debug("Starting autotuning") with dynamo_timed( f"{name}_template_autotuning", log_pt2_compile_event=True, dynamo_compile_column_us="compile_time_autotune_time_us", +<<<<<<< HEAD metadata=_autotune_metadata(input_nodes), ): benchmark_results = benchmark(choices, hint_override=hint_override) @@ -2465,12 +2723,34 @@ def autotune(choices, hint_override: Optional[int] = None): f"{name}_template_autotuning", benchmark_results ) return benchmark_results +======= + metadata={ + "autotune_strides": ", ".join( + [str(n.get_stride()) for n in input_nodes] + ), + "autotune_dtypes": ", ".join( + [str(n.get_dtype()) for n in input_nodes] + ), + "autotune_shape": ", ".join( + ["x".join(map(str, n.get_size())) for n in input_nodes] + ), + "autotune_offset": ", ".join( + [str(n.get_layout().offset) for n in input_nodes] + ), + }, + ): + return make_benchmark_fn()(choices) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if config.autotune_in_subproc: # Initialize the suprocess pool so it will warmup early. torch._inductor.autotune_process.get_tuning_process_pool() +<<<<<<< HEAD def do_autotuning(choices, precompile_fn, hint_override: Optional[int] = None): +======= + def do_autotuning(choices, precompile_fn): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) precompile_start_ts = time.time() with dynamo_timed( f"{name}_template_precompiling", @@ -2491,8 +2771,12 @@ def do_autotuning(choices, precompile_fn, hint_override: Optional[int] = None): candidates, name, inputs_key, +<<<<<<< HEAD lambda choices: autotune(choices, hint_override=hint_override), hint_override=hint_override, +======= + autotune, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) choices = self.prune_choices_postscreen( choices, timings, name, inputs_key, self.prescreening_cache @@ -2501,6 +2785,7 @@ def do_autotuning(choices, precompile_fn, hint_override: Optional[int] = None): log.debug("Prescreening elapsed time: %.02fs", prescreening_elapse) autotune_start_ts = time.time() +<<<<<<< HEAD if best_config_future is not None: best_config = await_sync(best_config_future) @@ -2530,12 +2815,18 @@ def do_autotuning(choices, precompile_fn, hint_override: Optional[int] = None): ] log.info("Filtered to %d choices based on best_config", len(choices)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) timings = self.lookup( choices, name, inputs_key, +<<<<<<< HEAD lambda choices: autotune(choices, hint_override=hint_override), hint_override=hint_override, +======= + autotune, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) autotune_elapse = time.time() - autotune_start_ts @@ -2546,8 +2837,16 @@ def do_autotuning(choices, precompile_fn, hint_override: Optional[int] = None): ): raise NoValidChoicesError +<<<<<<< HEAD if ( has_autotuned +======= + if make_benchmark_fn.cache_info().currsize: + counters["inductor"]["select_algorithm_autotune"] += 1 + + if ( + make_benchmark_fn.cache_info().currsize +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) or log.getEffectiveLevel() == logging.DEBUG or config.trace.log_autotuning_results ): @@ -2558,7 +2857,10 @@ def do_autotuning(choices, precompile_fn, hint_override: Optional[int] = None): autotune_elapse, precompile_elapse, prescreening_elapse, +<<<<<<< HEAD hint_override=hint_override, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def profiler_bench_function(): @@ -2570,7 +2872,13 @@ def profiler_bench_function(): profile_bandwidth_with_do_bench_using_profiling=True, autotune_in_subproc=False, ): +<<<<<<< HEAD return benchmark(choices) +======= + return self.make_benchmark_fn( + choices, input_nodes, layout, input_gen_fns + )(choices) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for feedback_fn in self.feedback_saver_fns: # re-benchmarking the same choices with profiler is a bit expensive, so pass it in as a thunk. @@ -2593,6 +2901,7 @@ def profiler_bench_function(): if return_multi_template and (config.max_autotune or config.max_autotune_gemm): +<<<<<<< HEAD def get_timings(hint_override: Optional[int] = None): filtered_choices = [ c @@ -2603,6 +2912,10 @@ def get_timings(hint_override: Optional[int] = None): timings = do_autotuning( filtered_choices, precompile_fn, hint_override=hint_override ) +======= + def get_timings(): + timings = do_autotuning(choices, precompile_fn) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) min_extern_choice = float("inf") for choice, timing in timings.items(): if isinstance(choice, ExternKernelCaller): @@ -2621,7 +2934,11 @@ def get_timings(hint_override: Optional[int] = None): # We take the union of allowed prologue inputs from all choices, # and, within benchmark fusion, don't allow prologue fusion for +<<<<<<< HEAD # choices which don't support the whole union. +======= + # choices which dont support the whole union. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) allowed_prologue_inps: OrderedSet[str] = OrderedSet() for c in choices: if isinstance(c, TritonTemplateCaller): @@ -2716,6 +3033,14 @@ def no_op(*args, **kwargs): log.debug("Found all %d timings in cache, returning no_op", len(timings)) return no_op +<<<<<<< HEAD +======= + if config.search_autotune_cache and not ( + config.max_autotune or config.max_autotune_gemm + ): + return no_op + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) precompile_key = create_precompile_key(name, inputs_key, choices) if precompile_func := self.precompile_cache.get(precompile_key): log.debug("Precompile function found in cache, returning it") @@ -2793,16 +3118,22 @@ def on_complete(future): def wait_on_futures(): log.debug("Waiting on futures") counters["inductor"]["select_algorithm_precompile"] += 1 +<<<<<<< HEAD exceptions: list[tuple[ChoiceCaller, BaseException]] = [] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for future in as_completed( futures, timeout=precompilation_timeout_seconds, ): if e := future.exception(): +<<<<<<< HEAD counters["inductor"][ "select_algorithm_num_precompilation_exceptions" ] += 1 exceptions.append((futures[future], e)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.codegen.cuda.cuda_kernel import ( CUDATemplateCaller, ) @@ -2814,6 +3145,7 @@ def wait_on_futures(): "Exception %s for benchmark choice %s", e, futures[future], +<<<<<<< HEAD exc_info=e, ) else: @@ -2822,6 +3154,13 @@ def wait_on_futures(): e, futures[future], exc_info=e, +======= + exc_info=True, + ) + else: + log.error( + "Exception %s for benchmark choice %s", e, futures[future] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) else: counters["inductor"]["select_algorithm_num_precompiles"] += 1 @@ -2830,8 +3169,11 @@ def wait_on_futures(): futures.get(future), elapsed_times.get(future), ) +<<<<<<< HEAD if exceptions: _log_autotune_exceptions(exceptions) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) executor.shutdown(wait=True) @@ -2846,7 +3188,10 @@ def get_inputs( input_nodes: list[ir.IRNode], layout: ir.Layout, input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]], +<<<<<<< HEAD hint_override: Optional[int] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> AutotuneArgs: """ Factory method to create AutotuneArgs from a list of ChoiceCallers. @@ -2856,9 +3201,13 @@ def get_inputs( # de-duplicate args unique_example_inputs = { +<<<<<<< HEAD x.get_name(): input_gen_fns.get( i, lambda x: cls.benchmark_example_value(x, hint_override=hint_override) )(x) +======= + x.get_name(): input_gen_fns.get(i, cls.benchmark_example_value)(x) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for i, x in enumerate(input_nodes) } example_inputs = list(unique_example_inputs.values()) @@ -2871,23 +3220,36 @@ def get_inputs( V.graph.sizevars.size_hints( input_node.get_size(), fallback=config.unbacked_symint_fallback, +<<<<<<< HEAD hint_override=hint_override, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), V.graph.sizevars.size_hints( input_node.get_stride(), fallback=config.unbacked_symint_fallback, +<<<<<<< HEAD hint_override=hint_override, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), V.graph.sizevars.size_hint( input_node.get_layout().offset, fallback=config.unbacked_symint_fallback, +<<<<<<< HEAD hint_override=hint_override, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), ) ) for input_node in input_nodes ] +<<<<<<< HEAD out = cls.benchmark_example_value(layout, hint_override=hint_override) +======= + out = cls.benchmark_example_value(layout) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out_extern = torch.as_strided( out, out.size(), out.stride(), V.graph.sizevars.size_hint(layout.offset) ) @@ -2910,11 +3272,19 @@ def benchmark_choice( ) -> float: is_extern = isinstance(choice, (ExternKernelCaller, SubgraphChoiceCaller)) benchmark_tensors = autotune_args.get_benchmark_tensors(is_extern) +<<<<<<< HEAD inputs, output = benchmark_tensors.unpack() output.zero_() result = choice.benchmark(*inputs, out=output) device_type = next( (tensor.device.type for tensor in inputs if is_gpu(tensor.device.type)), +======= + inpts, output = benchmark_tensors.unpack() + output.zero_() + result = choice.benchmark(*inpts, out=output) + device_type = next( + (tensor.device.type for tensor in inpts if is_gpu(tensor.device.type)), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "cuda", ) device_interface = get_interface_for_device(device_type) @@ -2996,11 +3366,16 @@ def benchmark_in_current_process( input_nodes: list[ir.IRNode], layout: ir.Layout, input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]], +<<<<<<< HEAD hint_override: Optional[int] = None, ) -> dict[ChoiceCaller, float]: inputs = cls.get_inputs( choices, input_nodes, layout, input_gen_fns, hint_override=hint_override ) +======= + ) -> dict[ChoiceCaller, float]: + inputs = cls.get_inputs(choices, input_nodes, layout, input_gen_fns) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return cls.benchmark_choices(choices, inputs) @classmethod @@ -3010,7 +3385,10 @@ def benchmark_in_sub_process( input_nodes: list[ir.IRNode], layout: ir.Layout, input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]], +<<<<<<< HEAD hint_override: Optional[int] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): from . import autotune_process @@ -3020,7 +3398,11 @@ def benchmark_in_sub_process( triton = [c for c in choices if not isinstance(c, ExternKernelCaller)] timings = cls.benchmark_in_current_process( +<<<<<<< HEAD extern, input_nodes, layout, input_gen_fns, hint_override=hint_override +======= + extern, input_nodes, layout, input_gen_fns +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) timings.update(autotune_process.benchmark_in_sub_process(triton)) # type: ignore[arg-type] return timings @@ -3032,7 +3414,10 @@ def make_benchmark_fn( input_nodes: list[ir.IRNode], layout: ir.Layout, input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]], +<<<<<<< HEAD hint_override: Optional[int] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): if DEBUG: print(f"{len(choices)} tuning requests:") @@ -3043,7 +3428,10 @@ def make_benchmark_fn( input_nodes=input_nodes, layout=layout, input_gen_fns=input_gen_fns, +<<<<<<< HEAD hint_override=hint_override, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) else: return functools.partial( @@ -3051,7 +3439,10 @@ def make_benchmark_fn( input_nodes=input_nodes, layout=layout, input_gen_fns=input_gen_fns, +<<<<<<< HEAD hint_override=hint_override, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @staticmethod @@ -3210,7 +3601,10 @@ def log_results( elapse: float, precompile_elapse: float, prescreening_elapse: Optional[float] = None, +<<<<<<< HEAD hint_override: Optional[int] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): V.debug.log_autotuning_results( name, input_nodes, timings, elapse, precompile_elapse @@ -3225,7 +3619,10 @@ def log_results( V.graph.sizevars.size_hints( n.get_size(), fallback=config.unbacked_symint_fallback, # type: ignore[arg-type] +<<<<<<< HEAD hint_override=hint_override, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), ) ) @@ -3313,7 +3710,11 @@ def get_choice_info(choice): ) @staticmethod +<<<<<<< HEAD def benchmark_example_value(node, hint_override: Optional[int] = None): +======= + def benchmark_example_value(node): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Convert an ir.Buffer into a concrete torch.Tensor we can use for benchmarking. @@ -3332,12 +3733,18 @@ def benchmark_example_value(node, hint_override: Optional[int] = None): V.graph.sizevars.size_hints( node.get_size(), fallback=config.unbacked_symint_fallback, +<<<<<<< HEAD hint_override=hint_override, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), V.graph.sizevars.size_hints( node.get_stride(), fallback=config.unbacked_symint_fallback, +<<<<<<< HEAD hint_override=hint_override, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), node.get_device(), node.get_dtype(), @@ -3345,7 +3752,10 @@ def benchmark_example_value(node, hint_override: Optional[int] = None): V.graph.sizevars.size_hints( V.graph.get_allocation_size(node), fallback=config.unbacked_symint_fallback, +<<<<<<< HEAD hint_override=hint_override, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), ) @@ -3400,6 +3810,7 @@ def key_of(node): def add_feedback_saver(self, fn: FeedbackFunction): self.feedback_saver_fns.append(fn) +<<<<<<< HEAD def clear_feedback_savers(self): self.feedback_saver_fns = [] @@ -3417,10 +3828,13 @@ def clear_preprocessing_fns(self, clear_defaults: bool = False): if not clear_defaults: self._register_default_preprocessing_fns() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _ALGORITHM_SELECTOR_CACHE: Optional[AlgorithmSelectorCache] = None +<<<<<<< HEAD def get_algorithm_selector_cache() -> AlgorithmSelectorCache: """Get the global algorithm selector cache, creating it if it doesn't exist.""" global _ALGORITHM_SELECTOR_CACHE @@ -3431,6 +3845,12 @@ def get_algorithm_selector_cache() -> AlgorithmSelectorCache: def autotune_select_algorithm(*args, **kwargs): cache = get_algorithm_selector_cache() +======= +def autotune_select_algorithm(*args, **kwargs): + global _ALGORITHM_SELECTOR_CACHE + if _ALGORITHM_SELECTOR_CACHE is None: + _ALGORITHM_SELECTOR_CACHE = AlgorithmSelectorCache() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if "return_multi_template" not in kwargs: kwargs["return_multi_template"] = ( @@ -3440,12 +3860,17 @@ def autotune_select_algorithm(*args, **kwargs): if "precompilation_timeout_seconds" not in kwargs: kwargs["precompilation_timeout_seconds"] = config.precompilation_timeout_seconds +<<<<<<< HEAD return cache(*args, **kwargs) +======= + return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def add_feedback_saver( fn: FeedbackFunction, ): +<<<<<<< HEAD cache = get_algorithm_selector_cache() cache.add_feedback_saver(fn) @@ -3489,6 +3914,12 @@ def clear_preprocessing_fns(clear_defaults: bool = False): """ cache = get_algorithm_selector_cache() cache.clear_preprocessing_fns(clear_defaults) +======= + global _ALGORITHM_SELECTOR_CACHE + if _ALGORITHM_SELECTOR_CACHE is None: + _ALGORITHM_SELECTOR_CACHE = AlgorithmSelectorCache() + _ALGORITHM_SELECTOR_CACHE.add_feedback_saver(fn) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def realize_inputs(*args): @@ -3527,6 +3958,7 @@ def sympy_call(self, *args, **kwargs): return self.fn(*args, **kwargs, **self.kwargs_sym) +<<<<<<< HEAD def _autotune_metadata(input_nodes): """Helper function to extract autotune metadata from input nodes.""" return { @@ -3667,5 +4099,7 @@ def _log_autotune_exceptions( pass +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # ensure lowering is imported so that `extern_kernels.*` is populated from . import lowering # noqa: F401 diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py index 8727777b562b2..06c81a9689742 100644 --- a/torch/_inductor/sizevars.py +++ b/torch/_inductor/sizevars.py @@ -8,11 +8,15 @@ import sympy from sympy import Expr +<<<<<<< HEAD from torch.fx.experimental.symbolic_shapes import ( free_symbols, has_free_unbacked_symbols, ShapeEnv, ) +======= +from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols, ShapeEnv +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.functions import FloorDiv, ModularIndexing from torch.utils._sympy.symbol import symbol_is_type, SymT @@ -59,6 +63,7 @@ def statically_known_true( # lifting and in some cases we should be directly passing through to ShapeEnv, # but there is some extra inductor logic that needs to be handled here class SizeVarAllocator: +<<<<<<< HEAD """ A class that manages symbolic size variables and their relationships. @@ -68,6 +73,8 @@ class SizeVarAllocator: calculations for tensor operations. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __init__(self, shape_env=None) -> None: super().__init__() if shape_env is None: @@ -319,7 +326,11 @@ def prune(index): # Note - [On Statically Known] # The statically_known_* family of functions below NEVER guard, they could return True if the # asked questions can be answered without guarding otherwise they return False. +<<<<<<< HEAD # Those are similar to statically_known_true in symbolic_shapes.py but operate on sympy +======= + # Those are similar to statically_known_true in symbolic_shapes but operate on sympy +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # expressions instead of symnodes. def statically_known_true(self, expr: Union[sympy.Basic, bool]) -> bool: """ @@ -336,9 +347,13 @@ def statically_known_equals( """ return self.statically_known_true(sympy.Eq(left, right)) # type: ignore[arg-type] +<<<<<<< HEAD def statically_known_list_equals( self, left: Sequence[Expr], right: Sequence[Expr] ) -> bool: +======= + def statically_known_list_equals(self, left: list[Expr], right: list[Expr]) -> bool: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Returns a bool indicating if it is sound to optimize as if left and right lists are equal. """ @@ -380,16 +395,23 @@ def statically_known_multiple_of( """ Return a bool indicating if it is sound to optimize for the numerator being a multiple of the denominator. """ +<<<<<<< HEAD # The reason we skip compute here is to avoid the cost of trying to eval this symbolically. # see https://github.com/sympy/sympy/issues/28200 +======= + # The reason we skip unbacked here is that we want to avoid the cost of trying to eval this symbolically. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if has_free_unbacked_symbols(numerator) or has_free_unbacked_symbols( denominator ): return False +<<<<<<< HEAD if len(free_symbols(numerator)) > 20: return False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) expr = sympy.Eq(numerator % denominator, 0) return self.statically_known_true(expr) # type: ignore[arg-type] @@ -399,6 +421,7 @@ def statically_known_power_of_2(self, expr: Expr) -> bool: """ return isinstance(expr, sympy.Integer) and is_power_of_2(int(expr)) +<<<<<<< HEAD # The expect/check functions require you to ALREADY KNOW that a particular # condition holds. They are similar to expect_true in symbolic_shapes.py and # torch.check but operates on sympy expressions instead of symnodes. @@ -449,6 +472,64 @@ def check_lt(self, left: Expr, right: Expr) -> None: # Similar to the functions guard_or_false/guard_or_true in symbolic_shapes.py # but operates on sympy expressions instead of symnodes. see Note [guard_or_]. +======= + # The guard functions require you to ALREADY KNOW that a particular + # condition holds. If you don't know (you want to guard on an expression + # being a particular value, and then get access to that value), use + # the evaluate functions. + + def guard_equals(self, left: Expr, right: Expr) -> Expr: + if isinstance(left, Expr): + left = sympy_subs(left, self.inv_precomputed_replacements) # type: ignore[arg-type] + if isinstance(right, Expr): + right = sympy_subs(right, self.inv_precomputed_replacements) # type: ignore[arg-type] + + expr = sympy.Eq(left, right) + static_expr = self.shape_env._maybe_evaluate_static(expr) + + if static_expr is not None: + assert bool(static_expr) + return left + + assert self.shape_env.guard_or_defer_runtime_assert(expr, "guard_equals") + return left + + def guard_leq(self, left: Expr, right: Expr) -> None: + return self.guard_lt(left, right + 1) + + def guard_lt(self, left: Expr, right: Expr) -> None: + expr = sympy.Lt(left, right) + static_expr = self.shape_env._maybe_evaluate_static(expr) + + if static_expr is not None: + assert bool(static_expr) + return + + assert self.shape_env.guard_or_defer_runtime_assert(expr, "guard_lt") + + def guarded_order(self, seq): + """ + Return the order of a sequence as a permutation of range(len(seq)) and guard on that order not changing. + """ + seq = [*map(self.remove_precomputed_replacements, seq)] + seq = [ + (self.size_hint_or_throw(var), orig_idx, var) + for orig_idx, var in enumerate(seq) + ] + seq.sort() + order = [-1] * len(seq) + last_var = None + for new_index, (_, orig_index, var) in enumerate(seq): + order[orig_index] = new_index + if last_var is not None: + self.guard_leq(last_var, var) + last_var = var + return order + + # Similar to the functions guard_or_false/guard_or_true in symbolic_shapes but operates on sympy + # expressions instead of symnodes. see Note [guard_or_]. + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def guard_or_false(self, left): return self.evaluate_expr(left, fallback_value=False) @@ -476,6 +557,7 @@ def evaluate_expr( fallback_value=fallback_value, ) +<<<<<<< HEAD def is_size_one_or_false(self, size: Expr) -> bool: """Return True if size equals 1. @@ -483,6 +565,8 @@ def is_size_one_or_false(self, size: Expr) -> bool: """ return self.guard_or_false(sympy.Eq(size, 1)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def evaluate_min(self, left: Expr, right: Expr) -> Expr: """return the smaller of left and right, and guard on that choice""" if isinstance(left, Expr): @@ -506,10 +590,17 @@ def evaluate_min(self, left: Expr, right: Expr) -> Expr: f"evaluate_min({left}, {right}) with unbacked symints" ) from None if lv <= rv: +<<<<<<< HEAD self.check_leq(left, right) return left else: self.check_leq(right, left) +======= + self.guard_leq(left, right) + return left + else: + self.guard_leq(right, left) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return right def evaluate_max(self, left: Expr, right: Expr) -> Expr: @@ -519,6 +610,7 @@ def evaluate_max(self, left: Expr, right: Expr) -> Expr: min_val = self.evaluate_min(left, right) return right if min_val is left else left +<<<<<<< HEAD def guard_int(self, expr: Union[Expr, int]) -> int: """ Similar to guard_int in symbolic_shapes.py, except this function works with SymPy @@ -537,15 +629,30 @@ def guard_int_seq(self, left: Sequence[Union[Expr, int]]) -> list[int]: Apply guard_int on a sequence of inputs. """ return [self.guard_int(x) for x in left] +======= + def evaluate_static_shape(self, left: Union[Expr, int]) -> int: + if isinstance(left, int): + return left + right = self.size_hint_or_throw(left) + self.guard_equals(left, sympy.Integer(right)) + return int(right) + + def evaluate_static_shapes(self, left: Sequence[Union[Expr, int]]) -> list[int]: + return [self.evaluate_static_shape(x) for x in left] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def remove_precomputed_replacements(self, expr: Expr) -> Expr: if any(symbol_is_type(s, SymT.PRECOMPUTED_SIZE) for s in expr.free_symbols): # type: ignore[attr-defined] return sympy_subs(expr, self.inv_precomputed_replacements) # type: ignore[arg-type] return expr +<<<<<<< HEAD def symbolic_hint( self, expr: Union[Expr, int], hint_override: Optional[int] = None ) -> Union[Expr, int]: +======= + def symbolic_hint(self, expr: Union[Expr, int]) -> Union[Expr, int]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(expr, int): return expr # Substitute all hints into expr, but leave unbacked symints alone @@ -559,14 +666,18 @@ def symbolic_hint( return int(expr) # type: ignore[return-value] except TypeError: return expr # inf/nan/I +<<<<<<< HEAD if hint_override: return hint_override +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) expr = self.remove_precomputed_replacements(expr) return sympy_subs(expr, self.var_to_val) def size_hint( +<<<<<<< HEAD self, expr: Union[Expr, int], *, @@ -574,6 +685,11 @@ def size_hint( hint_override: Optional[int] = None, ) -> int: out = self.symbolic_hint(expr, hint_override=hint_override) +======= + self, expr: Union[Expr, int], *, fallback: Optional[int] = None + ) -> int: + out = self.symbolic_hint(expr) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not isinstance(out, (int, sympy.Integer)) and fallback is not None: # Use the provided heuristic fallback hint unbacked_sym_vrs = { @@ -594,7 +710,10 @@ def size_hint( raise def size_hint_or_throw(self, expr: Union[Expr, int]) -> int: +<<<<<<< HEAD # Like size_hint but there's no fallback for unbacked symints, so it throws. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out = self.symbolic_hint(expr) try: return int(out) @@ -607,6 +726,7 @@ def size_hints( exprs: Iterable[Union[Expr, int]], *, fallback: Optional[int] = None, +<<<<<<< HEAD hint_override: Optional[int] = None, ) -> tuple[int, ...]: return tuple( @@ -620,6 +740,10 @@ def size_hints_or_throw( ) -> tuple[int, ...]: # Like size_hints but there's no fallback for unbacked symints, so it throws. return tuple(self.size_hint_or_throw(x) for x in exprs) +======= + ) -> tuple[int, ...]: + return tuple(self.size_hint(x, fallback=fallback) for x in exprs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _lru_cache(self, fn, maxsize=None): """ @@ -759,7 +883,11 @@ def atomically_apply_size_hint( } return expr.subs(size_dict) +<<<<<<< HEAD def offset_var(self, index: Expr, vars: Sequence[sympy.Symbol]) -> Expr: +======= + def offset_var(self, index: Expr, vars: list[sympy.Symbol]) -> Expr: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Extract offset part of an indexing expression""" index = self.simplify(index) return sympy_subs(index, {v: sympy.S.Zero for v in vars if v != 0}) diff --git a/torch/_inductor/standalone_compile.py b/torch/_inductor/standalone_compile.py index 88f635426bfd9..8ddf1fa76424c 100644 --- a/torch/_inductor/standalone_compile.py +++ b/torch/_inductor/standalone_compile.py @@ -10,7 +10,10 @@ import torch.fx from torch._dynamo.utils import dynamo_timed +<<<<<<< HEAD from torch._inductor.cpp_builder import normalize_path_separator +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.cudagraph_utils import BoxedDeviceIndex from torch._inductor.runtime.cache_dir_utils import temporary_cache_dir from torch._inductor.utils import BoxedBool, InputType @@ -117,7 +120,10 @@ def save( def load( *, path: str, format: Literal["binary", "unpacked"] = "binary" ) -> CompiledArtifact: +<<<<<<< HEAD path = normalize_path_separator(path) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with dynamo_timed("CompiledArtifact.load"): if format == "binary": # can't assert that it is a file since it might not exist yet @@ -205,7 +211,10 @@ def standalone_compile( # Reuse fake_mode from the TracingContext. # NB: The TracingContext only exists if we're currently in a torch.compile backend. context = torch._guards.TracingContext.get() +<<<<<<< HEAD assert context.fake_mode is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fake_mode = context.fake_mode elif dynamic_shapes == "from_graph": fake_mode = FakeTensorMode(shape_env=ShapeEnv()) @@ -216,13 +225,18 @@ def standalone_compile( last_node = next(iter(reversed(gm.graph.nodes))) assert last_node.op == "output" assert len(last_node.args) == 1 +<<<<<<< HEAD def handle_node(node: torch.fx.Node) -> None: nonlocal fake_mode +======= + for node in last_node.args[0]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if "example_value" in node.meta: maybe_tensor = node.meta["example_value"] if isinstance(maybe_tensor, torch._subclasses.fake_tensor.FakeTensor): fake_mode = maybe_tensor.fake_mode +<<<<<<< HEAD # If gm came from Dynamo, then last_node.args[0] is always a list, # even in single-Tensor returns. @@ -236,6 +250,8 @@ def handle_node(node: torch.fx.Node) -> None: for node in last_node.args[0]: handle_node(node) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: raise ValueError( f"standalone_compile got unsupported `dynamic_shapes` value: dynamic_shapes={dynamic_shapes}." diff --git a/torch/_inductor/subgraph_lowering.py b/torch/_inductor/subgraph_lowering.py index 180a9d0eba801..cb2a0667ef89b 100644 --- a/torch/_inductor/subgraph_lowering.py +++ b/torch/_inductor/subgraph_lowering.py @@ -13,7 +13,10 @@ from . import ir from .exc import SubgraphLoweringException +<<<<<<< HEAD from .graph import GraphLowering +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .ops_handler import SimpleCSEHandler from .virtualized import ops, V, WrapperHandler @@ -33,7 +36,11 @@ class PointwiseSubgraphLowering(torch.fx.Interpreter): """ graph_outputs: Optional[list[ir.IRNode]] +<<<<<<< HEAD root_graph: GraphLowering +======= + root_graph: torch._inductor.graph.GraphLowering +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _current_op: Optional[TargetType] # For backwards of buffer_grads with scatters we allow mutations allowed_mutations: Optional[OrderedSet[OpOverload]] @@ -44,7 +51,11 @@ class PointwiseSubgraphLowering(torch.fx.Interpreter): def __init__( self, gm: torch.fx.GraphModule, +<<<<<<< HEAD root_graph_lowering: GraphLowering, +======= + root_graph_lowering: torch._inductor.graph.GraphLowering, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) allowed_mutations: Optional[OrderedSet[OpOverload]] = None, additional_lowerings: Optional[LoweringDict] = None, ) -> None: @@ -87,7 +98,12 @@ def mark_buffer_mutated(self, name: str) -> None: def register_buffer(self, buffer: ir.Buffer, *, set_name: bool = False) -> str: if self._approved_mutator(): +<<<<<<< HEAD name = self.root_graph.register_buffer(buffer, set_name=set_name) +======= + name = self.qualify_name(f"buf{len(self.buffers)}") + self.buffers.append(buffer) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return name else: raise SubgraphLoweringException( diff --git a/torch/_inductor/template_heuristics.py b/torch/_inductor/template_heuristics.py new file mode 100644 index 0000000000000..dfd37523a3702 --- /dev/null +++ b/torch/_inductor/template_heuristics.py @@ -0,0 +1,1180 @@ +from __future__ import annotations + +import dataclasses +import itertools +import math +from functools import partial +from threading import Lock +from typing import Any, Callable, TYPE_CHECKING + +import torch +from torch.utils._ordered_set import OrderedSet + +from . import config +from .utils import get_backend_num_stages +from .virtualized import V + + +if TYPE_CHECKING: + from collections.abc import Generator + + from triton import Config as TritonConfig + + +# Gemm Configs +@dataclasses.dataclass +class BaseConfig: + """ + Base Gemm configuration used for most backends (CPU, CUDA) + """ + + block_m: int + block_n: int + block_k: int + num_stages: int + num_warps: int + + +@dataclasses.dataclass +class GemmConfig(BaseConfig): + """ + Gemm configuration used for most backends (CPU, CUDA) + """ + + group_m: int = 8 + + +ConvConfig = BaseConfig + + +# FlexAttention Configs +@dataclasses.dataclass +class FlexConfig: + """ + Base Config class for flex attention + - FlexAttn forward, backward and flex decode will use this + + NOTE: + For flex_attn bwd block_m and block_n are reused for block_m1, block_m2, block_n1, block_n2 + + """ + + block_m: int + block_n: int + num_stages: int + num_warps: int + + +@dataclasses.dataclass +class FlexDecodeConfig: + """ + Config class for flex decoding + """ + + block_n: int + num_stages: int + num_warps: int + + +# ROCm classes +@dataclasses.dataclass +class ROCmGemmConfig(GemmConfig): + """ + ROCm subclass for GEMMs, with AMD backend specific tuneable kernargs + """ + + matrix_instr_nonkdim: int = 16 + waves_per_eu: int = 0 + kpack: int = 2 + + +@dataclasses.dataclass +class ROCmConvConfig(ConvConfig): + """ + ROCm subclass for Conv, with AMD backend specific tuneable kernargs + """ + + matrix_instr_nonkdim: int = 16 + waves_per_eu: int = 0 + kpack: int = 2 + + +@dataclasses.dataclass +class ROCmFlexConfig(FlexConfig): + """ + ROCm subclass for FlexAttn, with AMD backend specific tuneable kernargs + """ + + matrix_instr_nonkdim: int = 0 + waves_per_eu: int = 0 + kpack: int = 2 + + +@dataclasses.dataclass +class ROCmFlexDecodeConfig(FlexDecodeConfig): + """ + ROCm subclass for FlexDecode, with AMD backend specific tuneable kernargs + """ + + matrix_instr_nonkdim: int = 0 + waves_per_eu: int = 0 + kpack: int = 2 + + +class BaseHeuristicSingleton(type): + """ + Thread-safe implementation of single to be used in the config heuristic subclasses + to ensure heavy __init__ calls are not repeatedly run + """ + + _instances: dict[type[Any], Any] = {} + _lock: Lock = Lock() + + def __call__( + cls: BaseHeuristicSingleton, *args: Any, **kwargs: Any + ) -> BaseConfigHeuristic: + with cls._lock: + if cls not in cls._instances: + instance = super().__call__() + cls._instances[cls] = instance + return cls._instances[cls] + + +class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton): + """ + Base class for mm_configs, device specific triton kernels config inherit from here + """ + + def __init__(self) -> None: + # List of dictionaries to store the kernel configs. Configs that evaluate to true + # will be utilised on the target platform. The configs are as follows: + # (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps) + self.mm_configs: list[BaseConfig] = [ + GemmConfig(32, 32, 16, 1, 2), + GemmConfig(32, 32, 128, 2, 4), + GemmConfig(32, 64, 32, 5, 8), + GemmConfig(64, 32, 32, 5, 8), + GemmConfig(64, 32, 128, 5, 4), + GemmConfig(64, 64, 16, 2, 4), + GemmConfig(64, 64, 32, 2, 4), + GemmConfig(64, 64, 64, 3, 8), + GemmConfig(64, 64, 128, 5, 4), + GemmConfig(64, 128, 32, 3, 4), + GemmConfig(64, 128, 32, 4, 8), + GemmConfig(64, 128, 64, 3, 4), + GemmConfig(64, 128, 128, 4, 4), + GemmConfig(128, 64, 32, 3, 4), + GemmConfig(128, 64, 32, 4, 8), + GemmConfig(128, 128, 32, 2, 8), + GemmConfig(128, 128, 32, 3, 4), + GemmConfig(128, 128, 64, 3, 4), + GemmConfig(128, 128, 64, 5, 8), + ] + + # Exhaustive search for mm configs + self.exhaustive_configs: list[BaseConfig] = [ + GemmConfig(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps, group_m) + for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product( + [16, 32, 64, 128, 256], repeat=3 + ) + for num_stages in [1, 2, 3, 4, 5] + for num_warps in [2, 4, 8] + for group_m in [8] + ] + + # these are only used in tuned_mm when AutoHeuristic is enabled + # the idea is that when AutoHeuristic collects data to learn a heuristic, more configs are autotuned + # when the learned heuristic is used, the learned heuristic reduces the number of configs down to 10 + # which saves compilation time (since less configs are autotuned) and potentially increase performance + # because the learned heuristic might predict a config that is not part mm_configs + self.extra_mm_configs: list[BaseConfig] = [ + GemmConfig(16, 32, 16, 3, 2), + GemmConfig(16, 32, 32, 4, 2), + GemmConfig(16, 32, 32, 5, 2), + GemmConfig(64, 64, 128, 3, 4), + GemmConfig(128, 64, 32, 2, 2), + GemmConfig(128, 64, 64, 3, 8), + GemmConfig(128, 64, 128, 4, 8), + GemmConfig(128, 128, 32, 4, 4), + GemmConfig(128, 128, 64, 3, 8), + GemmConfig(128, 128, 64, 5, 4), + ] + + self.int8_mm_configs: list[BaseConfig] = [ + GemmConfig(64, 64, 32, 2, 4), + GemmConfig(64, 128, 32, 3, 4), + GemmConfig(128, 64, 32, 3, 4), + GemmConfig(64, 128, 32, 4, 8), + GemmConfig(128, 64, 32, 4, 8), + GemmConfig(64, 32, 32, 5, 8), + GemmConfig(32, 64, 32, 5, 8), + GemmConfig(128, 128, 32, 2, 8), + GemmConfig(64, 64, 64, 3, 8), + GemmConfig(128, 256, 128, 3, 8), + GemmConfig(256, 128, 128, 3, 8), + ] + + self.mixed_mm_configs: list[BaseConfig] = [ + GemmConfig(16, 128, 256, 3, 4), + GemmConfig(16, 128, 256, 5, 8), + ] + + self.persistent_mm_configs: list[BaseConfig] = [ + GemmConfig(128, 256, 64, 3, 8), + GemmConfig(128, 128, 64, 3, 8), + GemmConfig(128, 128, 128, 3, 8), + GemmConfig(128, 128, 128, 3, 4), + GemmConfig(128, 128, 64, 4, 8), + GemmConfig(128, 128, 64, 5, 8), + GemmConfig(256, 128, 64, 4, 8), + GemmConfig(128, 128, 64, 5, 4), + ] + + self.scaled_mm_configs: list[BaseConfig] = [ + GemmConfig(128, 256, 32, 3, 8), + GemmConfig(256, 128, 32, 3, 8), + GemmConfig(256, 64, 32, 4, 4), + GemmConfig(64, 256, 32, 4, 4), + GemmConfig(128, 128, 32, 4, 4), + GemmConfig(128, 64, 32, 4, 4), + GemmConfig(64, 128, 32, 4, 4), + GemmConfig(128, 32, 32, 4, 4), + GemmConfig(64, 32, 32, 5, 2), + GemmConfig(256, 128, 128, 3, 8), + GemmConfig(256, 64, 128, 4, 4), + GemmConfig(64, 256, 128, 4, 4), + GemmConfig(128, 128, 128, 4, 4), + GemmConfig(128, 64, 64, 4, 4), + GemmConfig(64, 128, 64, 4, 4), + GemmConfig(128, 32, 64, 4, 4), + GemmConfig(64, 32, 64, 5, 2), + GemmConfig(16, 32, 32, 2, 2), + GemmConfig(16, 64, 32, 2, 2), + GemmConfig(16, 128, 32, 2, 4), + GemmConfig(16, 256, 32, 2, 4), + GemmConfig(16, 32, 64, 2, 2), + GemmConfig(16, 64, 64, 2, 2), + GemmConfig(16, 128, 64, 2, 4), + GemmConfig(16, 256, 64, 2, 4), + GemmConfig(32, 32, 32, 2, 2), + GemmConfig(32, 64, 32, 2, 2), + GemmConfig(32, 128, 32, 2, 4), + GemmConfig(32, 256, 32, 2, 4), + GemmConfig(32, 32, 64, 2, 2), + GemmConfig(32, 64, 64, 2, 2), + GemmConfig(32, 128, 64, 2, 4), + GemmConfig(32, 256, 64, 2, 4), + GemmConfig(16, 32, 32, 3, 2), + GemmConfig(16, 64, 32, 3, 2), + GemmConfig(16, 128, 32, 3, 4), + GemmConfig(16, 256, 32, 3, 4), + GemmConfig(16, 32, 64, 3, 2), + GemmConfig(16, 64, 64, 3, 2), + GemmConfig(16, 128, 64, 3, 4), + GemmConfig(16, 256, 64, 3, 4), + GemmConfig(32, 32, 32, 3, 2), + GemmConfig(32, 64, 32, 3, 2), + GemmConfig(32, 128, 32, 3, 4), + GemmConfig(32, 256, 32, 3, 4), + GemmConfig(32, 32, 64, 3, 2), + GemmConfig(32, 64, 64, 3, 2), + GemmConfig(32, 128, 64, 3, 4), + GemmConfig(32, 256, 64, 3, 4), + GemmConfig(16, 32, 32, 4, 2), + GemmConfig(16, 64, 32, 4, 2), + GemmConfig(16, 128, 32, 4, 4), + GemmConfig(16, 256, 32, 4, 4), + GemmConfig(16, 32, 64, 4, 2), + GemmConfig(16, 64, 64, 4, 2), + GemmConfig(16, 128, 64, 4, 4), + GemmConfig(16, 256, 64, 4, 4), + GemmConfig(32, 32, 32, 4, 2), + GemmConfig(32, 64, 32, 4, 2), + GemmConfig(32, 128, 32, 4, 4), + GemmConfig(32, 256, 32, 4, 4), + GemmConfig(32, 32, 64, 4, 2), + GemmConfig(32, 64, 64, 4, 2), + GemmConfig(32, 128, 64, 4, 4), + GemmConfig(32, 256, 64, 4, 4), + GemmConfig(16, 32, 32, 5, 2), + GemmConfig(16, 64, 32, 5, 2), + GemmConfig(16, 128, 32, 5, 4), + GemmConfig(16, 256, 32, 5, 4), + GemmConfig(16, 32, 64, 5, 2), + GemmConfig(16, 64, 64, 5, 2), + GemmConfig(16, 128, 64, 5, 4), + GemmConfig(16, 256, 64, 5, 4), + GemmConfig(32, 32, 32, 5, 2), + GemmConfig(32, 64, 32, 5, 2), + GemmConfig(32, 128, 32, 5, 4), + GemmConfig(32, 256, 32, 5, 4), + GemmConfig(32, 32, 64, 5, 2), + GemmConfig(32, 64, 64, 5, 2), + GemmConfig(32, 128, 64, 5, 4), + GemmConfig(32, 256, 64, 5, 4), + GemmConfig(16, 32, 32, 6, 2), + GemmConfig(16, 64, 32, 6, 2), + GemmConfig(16, 128, 32, 6, 4), + GemmConfig(16, 256, 32, 6, 4), + GemmConfig(16, 32, 64, 6, 2), + GemmConfig(16, 64, 64, 6, 2), + GemmConfig(16, 128, 64, 6, 4), + GemmConfig(16, 256, 64, 6, 4), + GemmConfig(32, 32, 32, 6, 2), + GemmConfig(32, 64, 32, 6, 2), + GemmConfig(32, 128, 32, 6, 4), + GemmConfig(32, 256, 32, 6, 4), + GemmConfig(32, 32, 64, 6, 2), + GemmConfig(32, 64, 64, 6, 2), + GemmConfig(32, 128, 64, 6, 4), + GemmConfig(32, 256, 64, 6, 4), + ] + + self.scaled_persistent_mm_configs: list[BaseConfig] = [ + GemmConfig(128, 128, 64, 3, 8), + GemmConfig(128, 128, 128, 3, 8), + GemmConfig(128, 128, 128, 4, 8), + GemmConfig(128, 128, 128, 4, 4), + GemmConfig(128, 128, 128, 3, 4), + GemmConfig(128, 128, 128, 5, 4), + GemmConfig(128, 128, 128, 5, 8), + GemmConfig(128, 128, 128, 6, 8), + GemmConfig(128, 128, 64, 4, 8), + ] + + # TODO: Unify with other gemm patterns, mm_plus_mm currently follows + # slightly different pattern than rest + self.mm_plus_mm_configs: list[BaseConfig] = [ + GemmConfig(64, 64, 32, 2, 4), + GemmConfig(64, 64, 32, 3, 8), + GemmConfig(64, 64, 32, 4, 16), + GemmConfig(64, 32, 32, 4, 8), + GemmConfig(32, 64, 32, 4, 8), + GemmConfig(128, 128, 32, 1, 8), + GemmConfig(64, 64, 64, 1, 8), + GemmConfig(32, 32, 128, 1, 8), + GemmConfig(64, 64, 16, 2, 4), + GemmConfig(32, 32, 16, 1, 2), + ] + + self.conv_configs: list[BaseConfig] = [ + ConvConfig(64, 256, 16, 2, 4), + ConvConfig(256, 64, 16, 2, 4), + ConvConfig(1024, 16, 16, 1, 8), + ConvConfig(128, 128, 32, 2, 8), + ConvConfig(64, 64, 32, 2, 4), + ConvConfig(64, 256, 32, 2, 8), + ConvConfig(256, 64, 32, 2, 8), + ] + + self.flex_attn_fwd_autotune_configs: list[FlexConfig] = [ + FlexConfig(128, 64, 3, 4), + FlexConfig(128, 128, 3, 4), + FlexConfig(128, 128, 2, 8), + FlexConfig(64, 128, 3, 4), + FlexConfig(64, 64, 3, 4), + ] + + self.flex_attn_bwd_autotune_configs: list[FlexConfig] = [ + FlexConfig(BLOCK1, BLOCK2, s, w) + for BLOCK1 in [32, 64] + for BLOCK2 in [32, 64, 128] + for s in [1, 3, 4, 5] # num_stages + for w in ([4, 8] if BLOCK1 >= 128 or BLOCK2 >= 128 else [4]) + if BLOCK2 % BLOCK1 == 0 + ] + + self.flex_decode_autotune_configs: list[FlexDecodeConfig] = [ + FlexDecodeConfig(64, 3, 2), + FlexDecodeConfig(32, 3, 2), + FlexDecodeConfig(128, 3, 2), + ] + + self.exhaustive_flex_attn_fwd_configs: list[FlexConfig] = [ + FlexConfig(BLOCK_M, BLOCK_N, num_stages, num_warps) + for BLOCK_M in [16, 32, 64, 128] + for BLOCK_N in [32, 64, 128] + for num_stages in [1, 3, 4, 5] + for num_warps in [2, 4, 8] + ] + + self.exhaustive_flex_attn_bwd_configs: list[FlexConfig] = [ + FlexConfig(BLOCK1, BLOCK2, num_stages, num_warps) + for BLOCK1 in [16, 32, 64, 128] + for BLOCK2 in [16, 32, 64, 128] + for num_stages in [1, 3, 4, 5] + for num_warps in [2, 4, 8] + if BLOCK2 % BLOCK1 == 0 + ] + + self.exhaustive_flex_decode_configs: list[FlexDecodeConfig] = [ + FlexDecodeConfig(block_n, num_stages, num_warps) + for block_n in [16, 32, 64, 128] + for num_stages in [1, 3, 4, 5] + for num_warps in [2, 4, 8] + ] + + def _finalize_mm_configs( + self, + configs: list[BaseConfig], + ) -> Generator[TritonConfig, None, None]: + """ + Finalizes configs after scaling, applying additional constraints. + """ + used: OrderedSet[tuple[int, ...]] = OrderedSet() + + max_mm_configs = config.test_configs.max_mm_configs + + for conf in configs: + # Each warp computes a 16x16 tile = 256 elements + num_warps = min(conf.num_warps, conf.block_m * conf.block_n // 256) + + # Construct key for finding duplicate configs + key: tuple[int, ...] = ( + conf.block_m, + conf.block_n, + conf.block_k, + conf.num_stages, + num_warps, + ) + + # Check if gemm specific arg exists - add to key if does + group_m = getattr(conf, "group_m", None) + if group_m is not None: + key += (group_m,) + + if key not in used and ( + max_mm_configs is None or len(used) < max_mm_configs + ): + used.add(key) + kwargs = { + "BLOCK_M": conf.block_m, + "BLOCK_N": conf.block_n, + "BLOCK_K": conf.block_k, + "num_stages": conf.num_stages, + "num_warps": num_warps, + } + if group_m is not None: + kwargs["GROUP_M"] = group_m + yield self.triton_config(**kwargs) + + def _scale_mm_configs( + self, + m: int, + n: int, + k: int, + configs: list[BaseConfig], + scale: float, + has_int8_tensor: bool, + exclude: Callable[[int, int, int], bool], + ) -> list[BaseConfig]: + """ + Scales and filters matrix multiplication configs based on input size. + """ + from .runtime.runtime_utils import next_power_of_2 + + min_block_size = 16 + min_block_size_k = 32 if has_int8_tensor else 16 + + m = max( + next_power_of_2( + V.graph.sizevars.size_hint( + m, + fallback=config.unbacked_symint_fallback, # type: ignore[arg-type] + ) + ), + min_block_size, + ) + n = max( + next_power_of_2( + V.graph.sizevars.size_hint( + n, + fallback=config.unbacked_symint_fallback, # type: ignore[arg-type] + ) + ), + min_block_size, + ) + k = max( + next_power_of_2( + V.graph.sizevars.size_hint( + k, + fallback=config.unbacked_symint_fallback, # type: ignore[arg-type] + ) + ), + min_block_size_k, + ) + + scaled_configs = [] + for c in configs: + scaled_config = dataclasses.replace( + c, + block_m=max(min(int(c.block_m * scale), m), min_block_size), + block_n=max(min(int(c.block_n * scale), n), min_block_size), + block_k=max(min(int(c.block_k * scale), k), min_block_size_k), + ) + + if not exclude( + scaled_config.block_m, scaled_config.block_n, scaled_config.block_k + ): + scaled_configs.append(scaled_config) + + return scaled_configs + + def _prune_exhaustive_configs( + self, + configs: list[BaseConfig], + dtype_size: int, + ) -> list[BaseConfig]: + import torch + + pruned_configs = [] + for gemm_config in configs: + device = torch.cuda.current_device() + props = torch.cuda.get_device_properties(device) + sm_available = props.shared_memory_per_block_optin # type: ignore[attr-defined] + NUM_REG = 255 + + acc_regs = math.ceil( + gemm_config.block_m * gemm_config.block_n / (gemm_config.num_warps * 32) + ) + + shared_mem_accum = dtype_size * ( + gemm_config.block_m * gemm_config.block_k + + gemm_config.block_n * gemm_config.block_k + ) + + # Will use more shared memory than available + if shared_mem_accum * gemm_config.num_stages > sm_available: + continue + # Lower bound for register spillage, if exceeds the kernel will certainly spill + elif acc_regs > NUM_REG: + continue + + pruned_configs.append(gemm_config) + + return pruned_configs + + def preprocess_mm_configs( + self, + m: int, + n: int, + k: int, + configs: list[BaseConfig], + has_int8_tensor: bool = False, + scale: int = 1, + exclude: Callable[[int, int, int], bool] = lambda m, n, k: False, + dtype_size: int = 0, + ) -> Generator[TritonConfig, None, None]: + scaled_configs = self._scale_mm_configs( + m, n, k, configs, scale, has_int8_tensor, exclude + ) + + if config.max_autotune_gemm_search_space == "EXHAUSTIVE": + assert dtype_size > 0, "dtype_size must be provided for exhaustive search" + scaled_configs = self._prune_exhaustive_configs(scaled_configs, dtype_size) + return self._finalize_mm_configs(scaled_configs) + + def triton_config( + self, num_stages: int, num_warps: int, **kwargs: Any + ) -> TritonConfig: + from triton import Config as TritonConfig # type: ignore[attr-defined] + + return TritonConfig(kwargs, num_stages=num_stages, num_warps=num_warps) + + def get_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + return partial(self.preprocess_mm_configs, configs=self.mm_configs) + + def get_exhaustive_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + return partial(self.preprocess_mm_configs, configs=self.exhaustive_configs) + + def get_extra_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + return partial(self.preprocess_mm_configs, configs=self.extra_mm_configs) + + def get_int8_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + return partial(self.preprocess_mm_configs, configs=self.int8_mm_configs) + + def get_mixed_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + mm_configs = ( + self.mm_configs + self.mixed_mm_configs + if config.max_autotune_gemm_search_space == "EXHAUSTIVE" + else self.mm_configs + ) + return partial(self.preprocess_mm_configs, configs=mm_configs) + + def get_persistent_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + persistent_mm_configs = ( + self.exhaustive_configs + if config.max_autotune_gemm_search_space == "EXHAUSTIVE" + else self.persistent_mm_configs + ) + + # num_warps=2 not safe for TMA + persistent_mm_configs = [ + config for config in persistent_mm_configs if config.num_warps != 2 + ] + return partial(self.preprocess_mm_configs, configs=persistent_mm_configs) + + def get_scaled_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + return partial(self.preprocess_mm_configs, configs=self.scaled_mm_configs) + + def get_scaled_persistent_mm_configs( + self, + ) -> partial[Generator[TritonConfig, None, None]]: + return partial( + self.preprocess_mm_configs, configs=self.scaled_persistent_mm_configs + ) + + def get_mm_plus_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + return partial(self._finalize_mm_configs, configs=self.mm_plus_mm_configs) + + def get_conv_configs(self) -> partial[Generator[TritonConfig, None, None]]: + return partial(self.preprocess_mm_configs, configs=self.conv_configs) + + # Flex attn helpers + def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]: + flex_attn_fwd_configs: list[FlexConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_fwd_configs + flex_attn_fwd_configs += self.flex_attn_fwd_autotune_configs + + if head_dim <= 256: + if dtype == torch.float32: + default_config = FlexConfig(64, 64, 3, 4) + else: + default_config = FlexConfig(128, 64, 3, 4) + else: + if dtype == torch.float32: + default_config = FlexConfig(32, 16, 3, 4) + else: + default_config = FlexConfig(64, 32, 3, 4) + + if default_config not in flex_attn_fwd_configs: + flex_attn_fwd_configs.append(default_config) + + return flex_attn_fwd_configs + + def get_flex_attn_bwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]: + flex_attn_bwd_configs: list[FlexConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_bwd_configs + flex_attn_bwd_configs += self.flex_attn_bwd_autotune_configs + + default_config = FlexConfig(16, 16, 1, 4) + + if default_config not in flex_attn_bwd_configs: + flex_attn_bwd_configs.append(default_config) + + return flex_attn_bwd_configs + + def get_flex_decode_configs( + self, head_dim: int, dtype: Any + ) -> list[FlexDecodeConfig]: + flex_decode_configs: list[FlexDecodeConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_decode_configs + flex_decode_configs += self.flex_decode_autotune_configs + + default_config = FlexDecodeConfig(block_n=64, num_stages=1, num_warps=2) + + if default_config not in flex_decode_configs: + flex_decode_configs.append(default_config) + + return flex_decode_configs + + +class CPUConfigHeuristic(BaseConfigHeuristic): + pass + + +class CUDAConfigHeuristic(BaseConfigHeuristic): + """ + Child class for CUDA device specific gemm/flex attention/conv/ configs. + """ + + def __init__(self) -> None: + super().__init__() + + self.h100_default_flex_config = { + (torch.float32, 64): FlexConfig(128, 32, 3, 4), + (torch.float32, 128): FlexConfig(32, 64, 3, 4), + (torch.float32, 256): FlexConfig(32, 32, 3, 4), + (torch.bfloat16, 64): FlexConfig(128, 128, 3, 4), + (torch.bfloat16, 128): FlexConfig(128, 64, 3, 8), + (torch.bfloat16, 256): FlexConfig(64, 32, 3, 4), + (torch.float16, 64): FlexConfig(128, 128, 3, 4), + (torch.float16, 128): FlexConfig(128, 128, 3, 8), + (torch.float16, 256): FlexConfig(64, 32, 3, 4), + } + + self.a100_default_flex_config = { + (torch.float32, 64): FlexConfig(128, 32, 3, 4), + (torch.float32, 128): FlexConfig(128, 32, 3, 4), + (torch.float32, 256): FlexConfig(64, 16, 3, 4), + (torch.bfloat16, 64): FlexConfig(128, 64, 3, 4), + (torch.bfloat16, 128): FlexConfig(128, 64, 3, 8), + (torch.bfloat16, 256): FlexConfig(32, 64, 3, 4), + (torch.float16, 64): FlexConfig(128, 64, 3, 4), + (torch.float16, 128): FlexConfig(128, 64, 3, 8), + (torch.float16, 256): FlexConfig(32, 64, 3, 4), + } + + def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]: + capability = torch.cuda.get_device_capability() + flex_attn_fwd_configs: list[FlexConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_fwd_configs + flex_attn_fwd_configs += self.flex_attn_fwd_autotune_configs + + if head_dim <= 256: + if dtype == torch.float32: + default_config = FlexConfig(64, 64, 3, 4) + else: + default_config = FlexConfig(128, 64, 3, 4) + if capability >= (9, 0): + default_config = self.h100_default_flex_config.get( + (dtype, head_dim), default_config + ) + elif capability >= (8, 0): + default_config = self.a100_default_flex_config.get( + (dtype, head_dim), default_config + ) + else: + if dtype == torch.float32: + default_config = FlexConfig(32, 16, 3, 4) + else: + default_config = FlexConfig(64, 32, 3, 4) + + if default_config not in flex_attn_fwd_configs: + flex_attn_fwd_configs.append(default_config) + + return flex_attn_fwd_configs + + def get_flex_attn_bwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]: + capability = torch.cuda.get_device_capability() + + flex_attn_bwd_configs: list[FlexConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_bwd_configs + flex_attn_bwd_configs += self.flex_attn_bwd_autotune_configs + + if dtype == torch.float32: + default_config = FlexConfig(16, 16, 1, 4) + elif head_dim <= 256 and capability >= (9, 0): # H100 + if head_dim == 64: + default_config = FlexConfig(64, 64, 3, 4) + elif head_dim == 128: + default_config = FlexConfig(64, 128, 3, 8) + else: + default_config = FlexConfig(64, 64, 2, 4) + elif capability >= (8, 0): # A100 + if head_dim == 64: + default_config = FlexConfig(32, 128, 3, 4) + elif head_dim == 128: + # SM86/89 have smaller shared memory sizes + num_stages = 3 if capability[1] == 0 else 2 + default_config = FlexConfig(64, 64, num_stages, 4) + else: + default_config = FlexConfig(64, 64, 2, 4) + else: # modest hardware or extremely large head_dim + default_config = FlexConfig(16, 16, 1, 4) + + if default_config not in flex_attn_bwd_configs: + flex_attn_bwd_configs.append(default_config) + + return flex_attn_bwd_configs + + def get_flex_decode_configs( + self, head_dim: int, dtype: Any + ) -> list[FlexDecodeConfig]: + capability = torch.cuda.get_device_capability() + + default_config = FlexDecodeConfig(64, 1, 2) + + flex_decode_configs: list[FlexDecodeConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_decode_configs + flex_decode_configs += self.flex_decode_autotune_configs + + if capability >= (9, 0): # sm_90+ + if head_dim > 128 and dtype == torch.float32: + default_config = FlexDecodeConfig(64, 1, 2) + else: + default_config = FlexDecodeConfig(64, 3, 2) + else: + default_config = FlexDecodeConfig(64, 1, 2) + + if default_config not in flex_decode_configs: + flex_decode_configs.append(default_config) + + return flex_decode_configs + + +class ROCmConfigHeuristic(BaseConfigHeuristic): + """ + Child class for ROCm specific gemm/flex attention/conv/ configs. + """ + + def __init__(self) -> None: + super().__init__() + + self.default_num_stages = get_backend_num_stages() + + self.mm_configs: list[BaseConfig] = [ + ROCmGemmConfig( + 16, 16, 256, self.default_num_stages, 4, group_m=4, waves_per_eu=2 + ), + ROCmGemmConfig(32, 16, 256, self.default_num_stages, 4, group_m=4), + ROCmGemmConfig( + 32, 32, 16, self.default_num_stages, 4, group_m=8, waves_per_eu=2 + ), + ROCmGemmConfig(32, 32, 128, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(32, 64, 64, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig( + 64, 16, 128, self.default_num_stages, 4, group_m=8, waves_per_eu=2 + ), + ROCmGemmConfig(64, 32, 32, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(64, 32, 64, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(64, 32, 64, self.default_num_stages, 8, group_m=8), + ROCmGemmConfig(64, 32, 128, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(64, 64, 16, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(64, 64, 64, self.default_num_stages, 4, group_m=4), + ROCmGemmConfig(64, 64, 128, self.default_num_stages, 8, group_m=16), + ROCmGemmConfig(64, 64, 256, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig( + 64, 128, 32, self.default_num_stages, 4, group_m=4, waves_per_eu=2 + ), + ROCmGemmConfig(64, 128, 32, self.default_num_stages, 8, group_m=8), + ROCmGemmConfig(64, 128, 64, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig(64, 128, 128, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig(128, 32, 32, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(128, 32, 64, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig( + 128, 64, 32, self.default_num_stages, 4, group_m=8, waves_per_eu=2 + ), + ROCmGemmConfig(128, 64, 64, self.default_num_stages, 4, group_m=16), + ROCmGemmConfig(128, 64, 128, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig( + 128, 128, 32, self.default_num_stages, 4, group_m=16, waves_per_eu=2 + ), + ROCmGemmConfig(128, 128, 32, self.default_num_stages, 8, group_m=16), + ROCmGemmConfig( + 128, 128, 32, self.default_num_stages, 8, group_m=16, waves_per_eu=2 + ), + ROCmGemmConfig(128, 128, 64, self.default_num_stages, 4, group_m=16), + ROCmGemmConfig(128, 128, 64, self.default_num_stages, 8, group_m=8), + ROCmGemmConfig(128, 128, 128, self.default_num_stages, 8, group_m=16), + ROCmGemmConfig( + 128, 256, 32, self.default_num_stages, 4, group_m=16, waves_per_eu=2 + ), + ROCmGemmConfig(128, 256, 64, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig(256, 64, 64, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig( + 256, 128, 32, self.default_num_stages, 4, group_m=4, waves_per_eu=2 + ), + ROCmGemmConfig(256, 128, 32, self.default_num_stages, 8, group_m=16), + ROCmGemmConfig(256, 128, 64, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig(256, 256, 64, self.default_num_stages, 8, group_m=4), + ] + + # Exhaustive search for mm configs + self.exhaustive_configs: list[BaseConfig] = [ + ROCmGemmConfig( + BLOCK_M, + BLOCK_N, + BLOCK_K, + num_stages, + num_warps, + group_m, + matrix_instr_nonkdim, + waves_per_eu, + kpack, + ) + for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product( + [16, 32, 64, 128, 256], repeat=3 + ) + for num_stages in [1, self.default_num_stages] + for num_warps in [4, 8] + for group_m in [4, 8, 16] + for matrix_instr_nonkdim in [0, 16] + for waves_per_eu in [0, 2] + for kpack in [2] + ] + + self.default_flex_config = { + (torch.float32, 64): ROCmFlexConfig(128, 32, 1, 4), + (torch.float32, 128): ROCmFlexConfig(128, 32, 1, 4), + (torch.float32, 256): ROCmFlexConfig(64, 16, 1, 4), + (torch.bfloat16, 64): ROCmFlexConfig(128, 64, 1, 8), + (torch.bfloat16, 128): ROCmFlexConfig(128, 64, 1, 8), + (torch.bfloat16, 256): ROCmFlexConfig(32, 64, 1, 8), + (torch.float16, 64): ROCmFlexConfig(128, 64, 1, 8), + (torch.float16, 128): ROCmFlexConfig(128, 64, 1, 8), + (torch.float16, 256): ROCmFlexConfig(32, 64, 1, 4), + } + + self.flex_attn_fwd_autotune_configs: list[FlexConfig] = [ + ROCmFlexConfig(BLOCK1, BLOCK2, 1, w) + for BLOCK1 in [16, 64, 128] + for BLOCK2 in [16, 32, 64, 128] + for w in [4, 8] + ] + + self.flex_attn_bwd_autotune_configs: list[FlexConfig] = [ + ROCmFlexConfig(BLOCK1, BLOCK2, 1, w, mfma) + for BLOCK1 in [16, 32, 64] + for BLOCK2 in [32, 64, 128] + for w in ([4, 8] if BLOCK1 >= 128 or BLOCK2 >= 128 else [4]) + for mfma in [0, 16] + if BLOCK2 % BLOCK1 == 0 + ] + + self.flex_decode_autotune_configs: list[FlexDecodeConfig] = [ + ROCmFlexDecodeConfig(32, 1, 4), + ROCmFlexDecodeConfig(64, 1, 4), + ROCmFlexDecodeConfig(128, 1, 4), + ROCmFlexDecodeConfig(32, 1, 8), + ROCmFlexDecodeConfig(64, 1, 8), + ROCmFlexDecodeConfig(128, 1, 8), + ] + + self.exhaustive_flex_attn_fwd_configs: list[FlexConfig] = [ + ROCmFlexConfig(BLOCK_M, BLOCK_N, num_stages, num_warps, mfma, wpeu) + for BLOCK_M in [16, 32, 64, 128] + for BLOCK_N in [32, 64, 128] + for num_stages in [1, 2] + for num_warps in [2, 4, 8] + for mfma in [0, 16] + for wpeu in [0, int(8 // num_warps)] + ] + + self.exhaustive_flex_attn_bwd_configs: list[FlexConfig] = [ + ROCmFlexConfig(BLOCK1, BLOCK2, num_stages, num_warps, mfma, wpeu) + for BLOCK1 in [16, 32, 64, 128] + for BLOCK2 in [16, 32, 64, 128] + for num_stages in [1, 2] + for num_warps in [2, 4, 8] + for mfma in [0, 16] + for wpeu in [0, int(8 // num_warps)] + if BLOCK2 % BLOCK1 == 0 + ] + + self.exhaustive_flex_decode_configs: list[FlexDecodeConfig] = [ + ROCmFlexDecodeConfig(block_n, num_stages, num_warps, mfma, wpeu, kpack=2) + for block_n in [16, 32, 64, 128] + for num_stages in [1, 2] + for num_warps in [2, 4, 8] + for mfma in [0, 16] + for wpeu in [0, int(8 // num_warps)] + ] + + def _filter_configs( + self, configs: list[BaseConfig], new_num_stages: int + ) -> list[BaseConfig]: + # TODO: _filter_configs can be removed once backend specific configs are added + # for all methods + for c in configs: + c.num_stages = self.default_num_stages + return configs + + def _finalize_mm_configs( + self, + configs: list[BaseConfig], + ) -> Generator[TritonConfig, None, None]: + """ + Finalizes configs after scaling, applying additional constraints. + """ + used: OrderedSet[tuple[int, ...]] = OrderedSet() + + max_mm_configs = config.test_configs.max_mm_configs + + for conf in configs: + # Each warp computes a 16x16 tile = 256 elements + conf.num_warps = min(conf.num_warps, conf.block_m * conf.block_n // 256) + + # Defaults for AMD triton backend kern args if not set + matrix_instr_nonkdim = getattr(conf, "matrix_instr_nonkdim", 16) + waves_per_eu = getattr(conf, "waves_per_eu", 0) + kpack = getattr(conf, "kpack", 2) + + if matrix_instr_nonkdim != 0 and ( + conf.block_m % matrix_instr_nonkdim != 0 + or conf.block_n % matrix_instr_nonkdim != 0 + ): + # block_m and block_n must be a multiple of matrix_instr_nonkdim + continue + + # Construct key for finding duplicate configs + key: tuple[int, ...] = ( + conf.block_m, + conf.block_n, + conf.block_k, + conf.num_stages, + conf.num_warps, + waves_per_eu, + matrix_instr_nonkdim, + kpack, + ) + + # Check if gemm specific arg exists - add to key if does + group_m = getattr(conf, "group_m", None) + if group_m is not None: + key += (group_m,) + + if waves_per_eu != 0: + waves_per_eu = int(8 // conf.num_warps) + + if key not in used and ( + max_mm_configs is None or len(used) < max_mm_configs + ): + used.add(key) + kwargs = { + "BLOCK_M": conf.block_m, + "BLOCK_N": conf.block_n, + "BLOCK_K": conf.block_k, + "num_stages": conf.num_stages, + "num_warps": conf.num_warps, + "matrix_instr_nonkdim": matrix_instr_nonkdim, + "waves_per_eu": waves_per_eu, + "kpack": kpack, + } + if group_m is not None: + kwargs["GROUP_M"] = group_m + yield self.triton_config(**kwargs) + + def get_extra_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + filtered_configs = self._filter_configs( + self.extra_mm_configs, self.default_num_stages + ) + return partial(self.preprocess_mm_configs, configs=filtered_configs) + + def get_int8_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + filtered_configs = self._filter_configs( + self.int8_mm_configs, self.default_num_stages + ) + return partial(self.preprocess_mm_configs, configs=filtered_configs) + + def get_mixed_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + mm_configs = ( + self.mm_configs + self.mixed_mm_configs + if config.max_autotune_gemm_search_space == "EXHAUSTIVE" + else self.mm_configs + ) + filtered_configs = self._filter_configs(mm_configs, self.default_num_stages) + return partial(self.preprocess_mm_configs, configs=filtered_configs) + + def get_persistent_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + filtered_configs = self._filter_configs( + self.persistent_mm_configs, self.default_num_stages + ) + return partial(self.preprocess_mm_configs, configs=filtered_configs) + + def get_scaled_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + filtered_configs = self._filter_configs( + self.scaled_mm_configs, self.default_num_stages + ) + return partial(self.preprocess_mm_configs, configs=filtered_configs) + + def get_scaled_persistent_mm_configs( + self, + ) -> partial[Generator[TritonConfig, None, None]]: + filtered_configs = self._filter_configs( + self.scaled_persistent_mm_configs, self.default_num_stages + ) + return partial(self.preprocess_mm_configs, configs=filtered_configs) + + def get_mm_plus_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: + filtered_configs = self._filter_configs(self.mm_plus_mm_configs, 1) + return partial(self._finalize_mm_configs, configs=filtered_configs) + + def get_conv_configs(self) -> partial[Generator[TritonConfig, None, None]]: + filtered_configs = self._filter_configs( + self.conv_configs, self.default_num_stages + ) + return partial(self.preprocess_mm_configs, configs=filtered_configs) + + def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]: + flex_attn_fwd_configs: list[FlexConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_fwd_configs + flex_attn_fwd_configs += self.flex_attn_fwd_autotune_configs + + if head_dim <= 256: + if dtype == torch.float32: + default_config = ROCmFlexConfig(64, 64, 1, 4) + else: + default_config = ROCmFlexConfig(128, 64, 1, 8) + default_config = self.default_flex_config.get( + (dtype, head_dim), default_config + ) + else: + if dtype == torch.float32: + default_config = ROCmFlexConfig(32, 16, 1, 4) + else: + default_config = ROCmFlexConfig(64, 32, 1, 4) + + if default_config not in flex_attn_fwd_configs: + flex_attn_fwd_configs.append(default_config) + + return flex_attn_fwd_configs + + def get_flex_attn_bwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]: + flex_attn_bwd_configs: list[FlexConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_attn_bwd_configs + flex_attn_bwd_configs += self.flex_attn_bwd_autotune_configs + + if dtype == torch.float32: + default_config = ROCmFlexConfig(16, 16, 1, 4) + elif head_dim <= 256: + if head_dim == 64: + default_config = ROCmFlexConfig(64, 64, 1, 4) + elif head_dim == 128: + default_config = ROCmFlexConfig(64, 128, 1, 8) + else: + default_config = ROCmFlexConfig(64, 64, 1, 4) + else: + default_config = ROCmFlexConfig(16, 16, 1, 4) + + if default_config not in flex_attn_bwd_configs: + flex_attn_bwd_configs.append(default_config) + + return flex_attn_bwd_configs + + def get_flex_decode_configs( + self, head_dim: int, dtype: Any + ) -> list[FlexDecodeConfig]: + flex_decode_configs: list[FlexDecodeConfig] = [] + + if config.max_autotune: + if config.max_autotune_flex_search_space == "EXHAUSTIVE": + return self.exhaustive_flex_decode_configs + flex_decode_configs += self.flex_decode_autotune_configs + + default_config = ROCmFlexDecodeConfig(64, 1, 4) + + if default_config not in flex_decode_configs: + flex_decode_configs.append(default_config) + + return flex_decode_configs + + +class XPUConfigHeuristic(BaseConfigHeuristic): + """ + Placeholder child class for XPU specific overrides. + """ diff --git a/torch/_inductor/test_operators.py b/torch/_inductor/test_operators.py index d3d2705f8c788..a66d056e8e443 100644 --- a/torch/_inductor/test_operators.py +++ b/torch/_inductor/test_operators.py @@ -5,6 +5,7 @@ from torch.autograd import Function +<<<<<<< HEAD _test_lib_def = torch.library.Library("_inductor_test", "DEF") _test_lib_def.define("realize(Tensor self) -> Tensor", tags=torch.Tag.pt2_compliant_tag) @@ -26,3 +27,27 @@ def backward(ctx: Any, *grad_output: Any) -> Any: def realize(x: Tensor) -> Tensor: return Realize.apply(x) +======= +if not torch._running_with_deploy(): + _test_lib_def = torch.library.Library("_inductor_test", "DEF") + _test_lib_def.define( + "realize(Tensor self) -> Tensor", tags=torch.Tag.pt2_compliant_tag + ) + + _test_lib_impl = torch.library.Library("_inductor_test", "IMPL") + for dispatch_key in ("CPU", "CUDA", "MPS", "Meta"): + _test_lib_impl.impl("realize", lambda x: x.clone(), dispatch_key) + + class Realize(Function): + @staticmethod + def forward(ctx: object, x: Tensor) -> Tensor: + return torch.ops._inductor_test.realize(x) + + @staticmethod + # types need to stay consistent with _SingleLevelFunction + def backward(ctx: Any, *grad_output: Any) -> Any: + return grad_output[0] + + def realize(x: Tensor) -> Tensor: + return Realize.apply(x) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_inductor/triton_bundler.py b/torch/_inductor/triton_bundler.py index b210dbff5c849..d72514298b395 100644 --- a/torch/_inductor/triton_bundler.py +++ b/torch/_inductor/triton_bundler.py @@ -183,9 +183,20 @@ def put_static_autotuner(cls, key: str, kernel: "CachingAutotuner") -> None: # new_kernel, ) ) +<<<<<<< HEAD # Put the values back since we need it to use now kernel.restore_after_unpickle(old_values) +======= + # Put the values back since we need it to use now + ( + kernel.fn.fn, + kernel.fn.__globals__, + kernel.fn.used_global_vals, + kernel.fn.repr, + kernel.launchers, + ) = old_values +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @classmethod def collect_static_autotuners( diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 29a690aa1080b..da3135af6918a 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -18,11 +18,15 @@ import shutil import statistics import sys +<<<<<<< HEAD import sysconfig +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import tempfile import textwrap import time import unittest +<<<<<<< HEAD from collections.abc import ( Collection, Generator, @@ -31,6 +35,9 @@ MutableMapping, MutableSet, ) +======= +from collections.abc import Collection, Iterator, Mapping, MutableMapping, MutableSet +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from datetime import datetime from io import StringIO from typing import ( @@ -59,16 +66,23 @@ import sympy import torch +<<<<<<< HEAD import torch.utils._pytree as pytree from torch._inductor.analysis.device_info import datasheet_tops from torch._inductor.runtime.hints import DeviceProperties from torch.utils._dtype_abbrs import dtype_abbrs from torch.utils._ordered_set import OrderedSet from torch.utils._pytree import tree_flatten, tree_map_only +======= +from torch._inductor.runtime.hints import DeviceProperties +from torch.utils._ordered_set import OrderedSet +from torch.utils._pytree import tree_map_only +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) OPTIMUS_EXCLUDE_POST_GRAD = [ "activation_quantization_aten_pass", +<<<<<<< HEAD "inductor_autotune_lookup_table", ] @@ -80,23 +94,47 @@ ) +======= +] + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if TYPE_CHECKING: from collections.abc import Iterable, Sequence, ValuesView from torch import SymBool, SymFloat, SymInt from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND from torch.fx import GraphModule +<<<<<<< HEAD +======= + from torch.fx.experimental.symbolic_shapes import ShapeEnv +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.fx.node import Node from .codegen.common import WorkspaceArg from .codegen.wrapper import PythonWrapperCodegen from .graph import GraphLowering +<<<<<<< HEAD from .ir import Buffer, ExternKernel, IRNode, Layout, Operation, ReinterpretView +======= + from .ir import ( + Buffer, + ExternKernel, + ExternKernelOut, + IRNode, + Layout, + Operation, + ReinterpretView, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .output_code import CompiledFxGraph from .scheduler import BaseSchedulerNode, SchedulerBuffer +<<<<<<< HEAD GPU_TYPES = ["cuda", "mps", "xpu", "mtia"] +======= +GPU_TYPES = ["cuda", "mps", "xpu"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T = TypeVar("T") @@ -133,8 +171,11 @@ def get_gpu_type() -> str: _IS_WINDOWS = sys.platform == "win32" log = logging.getLogger(__name__) +<<<<<<< HEAD perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _T = TypeVar("_T") VarRanges = dict[sympy.Expr, sympy.Expr] @@ -254,10 +295,14 @@ def fp8_bench(fn: Callable[[], Any], warmup: int = 25, rep: int = 100) -> float: [ event for event in p.events() +<<<<<<< HEAD if ( event.device_type == DeviceType.CUDA and re.match(r"fused_abs_max_\d", event.name) is not None ) +======= + if event.device_type == DeviceType.CUDA and "fused_abs_max_0" in event.name +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] ) if filtered_events: @@ -726,6 +771,7 @@ def get_kernel_metadata( node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel], wrapper: PythonWrapperCodegen, ) -> tuple[str, str]: +<<<<<<< HEAD """ Retrieves metadata information for a kernel. Args: @@ -740,6 +786,8 @@ def get_kernel_metadata( - The second string represent the kernel's detailed metadata. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) all_origins = aggregate_origins(node_schedule) inductor_nodes = [origin for origin in all_origins if origin.op == "call_function"] @@ -784,6 +832,7 @@ def get_kernel_metadata( # print the aot_autograd graph fragment if single_graph is not None: +<<<<<<< HEAD from . import ir detailed_metadata.append(f"{wrapper.comment} Graph fragment:") @@ -863,6 +912,13 @@ def stringfy_layout(layout: ir.Layout | None) -> str: ) detailed_metadata.append(f"{wrapper.comment} return {','.join(all_writes)}") +======= + detailed_metadata.append(f"{wrapper.comment} Graph fragment:") + for n in inductor_nodes: + # TODO(future): maybe refactor torch/fx/graph.py to make it easy to + # generate python code for graph fragments + detailed_metadata.append(f"{wrapper.comment} {n.format_node()}") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return metadata, "\n".join(detailed_metadata) @@ -889,7 +945,13 @@ def dominated_nodes( def gather_origins( args: Sequence[IRNode], kwargs: dict[str, IRNode] +<<<<<<< HEAD ) -> OrderedSet[torch.fx.Node]: +======= +) -> OrderedSet[IRNode]: + import itertools + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from . import ir def is_unrealized_node(n: IRNode) -> bool: @@ -897,6 +959,7 @@ def is_unrealized_node(n: IRNode) -> bool: return is_unrealized_node(n.data) if isinstance(n, ir.StorageBox): return is_unrealized_node(n.data) +<<<<<<< HEAD return isinstance(n, ir.IRNode) and not isinstance( n, ( @@ -914,6 +977,13 @@ def is_unrealized_node(n: IRNode) -> bool: args_flatten, _ = tree_flatten(args) args_origins = [val.origins for val in args_flatten if is_unrealized_node(val)] return OrderedSet(itertools.chain(*args_origins, *kwargs_origins)) +======= + return isinstance(n, ir.IRNode) and isinstance(n, ir.Pointwise) + + kwarg_origins = [val.origins for val in kwargs.values() if is_unrealized_node(val)] + arg_origins = [arg.origins for arg in args if is_unrealized_node(arg)] + return OrderedSet(itertools.chain(*arg_origins, *kwarg_origins)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def sympy_str(expr: sympy.Expr) -> str: @@ -1091,7 +1161,11 @@ def get_first_incompatible_cudagraph_node( if ( not torch._inductor.config.graph_partition and isinstance(node.target, torch._ops.OpOverload) +<<<<<<< HEAD and torch._C.Tag.cudagraph_unsafe in node.target.tags # type: ignore[attr-defined] +======= + and torch._C.Tag.cudagraph_unsafe in node.target.tags +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): # skip cudagraph if a cudagraph_unsafe op is detected. # graph_partition helps by splitting on this cudagraph_unsafe @@ -1199,17 +1273,25 @@ def fresh_cache( """ clear_caches() +<<<<<<< HEAD from torch._inductor.cpp_builder import normalize_path_separator inductor_cache_dir = normalize_path_separator(tempfile.mkdtemp(dir=dir)) +======= + inductor_cache_dir = tempfile.mkdtemp(dir=dir) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: with mock.patch.dict( os.environ, {"TORCHINDUCTOR_CACHE_DIR": inductor_cache_dir} ): log.debug("Using inductor cache dir %s", inductor_cache_dir) +<<<<<<< HEAD triton_cache_dir = normalize_path_separator( os.path.join(inductor_cache_dir, "triton") ) +======= + triton_cache_dir = os.path.join(inductor_cache_dir, "triton") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with mock.patch.dict(os.environ, {"TRITON_CACHE_DIR": triton_cache_dir}): yield if isinstance(cache_entries, dict): @@ -1232,7 +1314,10 @@ def fresh_cache( # Let's not fail if we can't clean up the temp dir. Also note that for # Windows, we can't delete the loaded modules because the module binaries # are open. +<<<<<<< HEAD ignore_errors=is_windows(), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) onerror=lambda func, path, exc_info: log.warning( "Failed to remove temporary cache dir at %s", inductor_cache_dir, @@ -1459,9 +1544,12 @@ def __add__(self, other: Self) -> IndentedBuffer: res.writelines(other._lines) return res +<<<<<<< HEAD def contains(self, new_line: Union[DeferredLineBase, LineContext, str]) -> bool: return new_line in self._lines +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class FakeIndentedBuffer(IndentedBuffer): def __init__(self) -> None: @@ -1565,6 +1653,7 @@ def is_big_gpu(index_or_device: Union[int, torch.device] = 0) -> bool: @functools.lru_cache def get_max_num_sms() -> int: +<<<<<<< HEAD if torch.xpu.is_available(): return torch.xpu.get_device_properties().gpu_subslice_count return torch.cuda.get_device_properties("cuda").multi_processor_count @@ -1585,6 +1674,14 @@ def get_num_sms() -> int: # TODO we need to properly guard on this global if torch.xpu.is_available(): return get_max_num_sms() +======= + return torch.cuda.get_device_properties("cuda").multi_processor_count + + +def get_num_sms() -> int: + """Handle experimental carveout if set otherwise return hardware SM count""" + # TODO we need to properly guard on this global +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) carveout = torch._C._get_sm_carveout_experimental() return get_max_num_sms() - (carveout if carveout is not None else 0) @@ -1638,11 +1735,15 @@ def _use_conv_autotune_backend(backend: str) -> bool: def use_triton_template( +<<<<<<< HEAD layout: Layout, *, enable_int32: bool = False, enable_float8: bool = False, check_max_autotune: bool = True, +======= + layout: Layout, *, enable_int32: bool = False, enable_float8: bool = False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> bool: from .codegen.common import BackendFeature, has_backend_feature @@ -1659,13 +1760,18 @@ def use_triton_template( ) or (layout.device.type == "cpu" and layout.dtype in layout_dtypes) ) +<<<<<<< HEAD # some callers handle max-autotune checking externally and (config.max_autotune or config.max_autotune_gemm or not check_max_autotune) +======= + and (config.max_autotune or config.max_autotune_gemm) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) and _use_autotune_backend("TRITON") and has_backend_feature(layout.device, BackendFeature.TRITON_TEMPLATES) ) +<<<<<<< HEAD def can_use_tma(*matrices: IRNode, add_guards: bool = False) -> bool: """ Return True iff *all* supplied tensors satisfy the CUDA-12.9 TMA constraints @@ -1742,10 +1848,36 @@ def _is_tma_compatible_default(x: IRNode) -> bool: # FP8 special case: inner ≄ 32 if dtype == torch.float8_e4m3fn and not V.graph.sizevars.statically_known_geq( +======= +def use_triton_tma_template(*matrices: IRNode) -> bool: + from torch.utils._triton import has_triton_stable_tma_api, has_triton_tma_device + + from .virtualized import V + + def _is_tma_compatible(x: IRNode) -> bool: + if len(x.get_size()) != 2: + return False + + dtype = x.get_dtype() + if dtype not in (torch.float16, torch.bfloat16, torch.float8_e4m3fn): + return False + + layout = x.get_layout() + transposed = layout.is_transposed() + if not (layout.is_contiguous() or transposed): + return False + + inner_dim = layout.size[1] + if transposed: + inner_dim = layout.size[0] + + if dtype == torch.float8_e4m3fn and V.graph.sizevars.statically_known_lt( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inner_dim, 32 ): return False +<<<<<<< HEAD return True def _is_tma_compatible_xpu(x: IRNode) -> bool: @@ -1774,6 +1906,19 @@ def use_triton_tma_template(*matrices: IRNode, add_guards: bool = False) -> bool all(len(m.get_size()) == 2 for m in matrices) and can_use_tma(*matrices, add_guards=add_guards) and config.triton.enable_persistent_tma_matmul +======= + inner_bytes = inner_dim * dtype.itemsize + return V.graph.sizevars.statically_known_multiple_of(inner_bytes, TMA_ALIGNMENT) + + if has_triton_stable_tma_api() and config.cpp_wrapper: + # TODO(dberard) remove this when we get AOTI support for new TMA APIs (#155047) + return False + + return ( + config.triton.enable_persistent_tma_matmul + and has_triton_tma_device() + and all(_is_tma_compatible(m) for m in matrices) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @@ -1802,9 +1947,14 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool: if not try_import_cutlass(): log.warning( "Failed to import CUTLASS lib. Please check whether " +<<<<<<< HEAD "_inductor.config.cuda.cutlass_dir %s is set correctly. " "Skipping CUTLASS backend for now.", config.cuda.cutlass_dir, +======= + "_inductor.config.cuda.cutlass_dir is set correctly. " + "Skipping CUTLASS backend for now." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return False return res @@ -1818,6 +1968,7 @@ def _use_cutlass_for_op(op_name: str) -> bool: return op_name.upper() in [x.strip() for x in enabled_ops.split(",")] +<<<<<<< HEAD _IntLike: TypeAlias = Union[int, sympy.Expr] @@ -1830,6 +1981,24 @@ def use_decompose_k_choice(m: _IntLike, n: _IntLike, k: _IntLike) -> bool: return ( not torch.version.hip and V.graph.sizevars.statically_known_true( +======= +decompose_k_threshold = 32 + +# To limit compile time +k_splits_limit = 5 + +# Hand-tuned +default_k_splits = [16, 32, 64, 128, 256] + +_IntLike: TypeAlias = Union[int, sympy.Expr] + + +def use_decompose_k_choice(m: _IntLike, n: _IntLike, k: _IntLike) -> bool: + from torch._inductor.virtualized import V + + return ( + V.graph.sizevars.statically_known_true( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sympy.And( sympy.Ge(k, decompose_k_threshold * m), sympy.Ge(k, decompose_k_threshold * n), @@ -1837,6 +2006,7 @@ def use_decompose_k_choice(m: _IntLike, n: _IntLike, k: _IntLike) -> bool: ) and not V.graph.aot_mode # TODO: Support AOTI for decomposeK and not V.graph.cpp_wrapper +<<<<<<< HEAD ) @@ -1861,11 +2031,15 @@ def use_contiguous(m: _IntLike, n: _IntLike, k: _IntLike) -> bool: ) and not V.graph.aot_mode and not V.graph.cpp_wrapper +======= + and not config.disable_decompose_k +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @functools.cache def get_k_splits(m: _IntLike, n: _IntLike, k: _IntLike) -> list[int]: +<<<<<<< HEAD # To limit compile time k_splits_limit = config.triton.num_decompose_k_splits @@ -1876,6 +2050,11 @@ def get_k_splits(m: _IntLike, n: _IntLike, k: _IntLike) -> list[int]: return default_k_splits elif k_splits_limit == 0: return [] +======= + # If k is a sympy expression, we can't do any splitting + if isinstance(k, sympy.Expr) and not k.is_number: + return default_k_splits +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (isinstance(m, sympy.Expr) and not m.is_number) or ( isinstance(n, sympy.Expr) and not n.is_number @@ -1915,10 +2094,22 @@ def get_k_splits(m: _IntLike, n: _IntLike, k: _IntLike) -> list[int]: if config.max_autotune_gemm_search_space == "EXHAUSTIVE": return pow_of_2_divisors + mul_of_32_divisors + rest_of_splits +<<<<<<< HEAD best_splits = pow_of_2_divisors + mul_of_32_divisors + rest_of_splits # Otherwise, conform results to k_splits_limit return best_splits[:k_splits_limit] +======= + # If the # of power of 2 divisors are greater than k_splits_limit, return all + # This should be ok for compile time, all perfect squares between 128 and min(k / m, k / n) + # should never be a massive amount + if len(pow_of_2_divisors) >= k_splits_limit: + return pow_of_2_divisors + else: + best_splits = pow_of_2_divisors + mul_of_32_divisors + rest_of_splits + # Otherwise, conform results to k_splits_limit + return best_splits[:k_splits_limit] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @functools.cache @@ -2193,6 +2384,10 @@ def call(self, *args: Any, **kwargs: Any) -> None: self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen() ) # Skip all the actual compiling. +<<<<<<< HEAD +======= + nonlocal save_output_code +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) save_output_code(wrapper_code.value) if kernel_code: save_output_code(kernel_code.value) @@ -2388,6 +2583,7 @@ def get_backend_num_stages() -> int: @functools.cache +<<<<<<< HEAD def get_device_tflops(dtype: torch.dtype) -> float: """ We don't want to throw errors in this function. First check to see if the device is in device_info.py, @@ -2404,6 +2600,11 @@ def get_device_tflops(dtype: torch.dtype) -> float: 0, ) +======= +def get_device_tflops(dtype: torch.dtype) -> int: + from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert dtype in (torch.float16, torch.bfloat16, torch.float32) if inspect.signature(get_max_simd_tflops).parameters.get("clock_rate"): @@ -2411,7 +2612,11 @@ def get_device_tflops(dtype: torch.dtype) -> float: from torch._utils_internal import max_clock_rate sm_clock = max_clock_rate() +<<<<<<< HEAD if dtype in (torch.float16, torch.bfloat16) and SM80OrLater: +======= + if dtype in (torch.float16, torch.bfloat16): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return get_max_tensorcore_tflops(dtype, sm_clock) if torch.backends.cuda.matmul.allow_tf32: @@ -2419,7 +2624,11 @@ def get_device_tflops(dtype: torch.dtype) -> float: else: return get_max_simd_tflops(torch.float32, sm_clock) else: +<<<<<<< HEAD if dtype in (torch.float16, torch.bfloat16) and SM80OrLater: +======= + if dtype in (torch.float16, torch.bfloat16): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return get_max_tensorcore_tflops(dtype) if torch.backends.cuda.matmul.allow_tf32: @@ -2554,7 +2763,11 @@ def is_output_of_multi_outputs_template( return ( isinstance(input_buf, ir.MultiOutput) and len(input_buf.inputs) == 1 +<<<<<<< HEAD and is_multi_outputs_template(input_buf.inputs[0]) # type: ignore[arg-type] +======= + and is_multi_outputs_template(input_buf.inputs[0]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @@ -2568,9 +2781,13 @@ def is_collective( from . import ir return ( +<<<<<<< HEAD isinstance(node, ir._CollectiveKernel) and not isinstance(node, ir._WaitKernel) and (op is None or node.op_overload is op) +======= + type(node) == ir._CollectiveKernel and (op is None or node.op_overload is op) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) or ( # TODO: this is a temporary solution to ensure that we can identify torchrec's # communication ops. But in order to allow better communication and computation @@ -2896,9 +3113,16 @@ def maybe_get_suppress_shape_guards_ctx() -> contextlib.AbstractContextManager[N return contextlib.nullcontext() # In standalone inductor compile mode, we might not have a shape_env attached to the fake mode +<<<<<<< HEAD if not tracing_context.fake_mode or not tracing_context.fake_mode.shape_env: return contextlib.nullcontext() shape_env = tracing_context.fake_mode.shape_env +======= + shape_env = tracing_context.fake_mode.shape_env + if not shape_env: + return contextlib.nullcontext() + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return shape_env.suppress_guards() @@ -3042,6 +3266,7 @@ def expr_fits_within_32bit(e: sympy.Expr) -> bool: # (e.g., via ValueRanges) that it is still in bounds if V.graph.sizevars.statically_known_true(e <= int_max): return True +<<<<<<< HEAD # AOTI doesn't guard on < 2**32, so checking hints isn't a viable option, # in case the hinted value is < 2**32, but the allowed range is larger. @@ -3062,6 +3287,8 @@ def expr_fits_within_32bit(e: sympy.Expr) -> bool: # so this could potentially have int64 values return False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Otherwise, the hint MUST exist and be in range return has_hint(e) and size_hint(e) <= int_max @@ -3315,6 +3542,45 @@ def get_donated_idxs() -> Optional[list[int]]: return None +<<<<<<< HEAD +======= +def set_kernel_post_grad_provenance_tracing( + node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernelOut], + kernel_name: str, + is_extern: bool = False, +) -> None: + from .codegen.simd_kernel_features import DisableReduction, EnableReduction + from .ir import ExternKernelOut + from .virtualized import V + + if is_extern: + assert isinstance(node_schedule, ExternKernelOut) + curr_node_info = ( + V.debug._inductor_triton_kernel_to_post_grad_node_info.setdefault( + kernel_name, [] + ) + ) + curr_node_info.extend( + origin.name + for origin in node_schedule.origins + if origin.name not in curr_node_info + ) + else: + assert isinstance(node_schedule, list) + for snode in node_schedule: + if snode not in (EnableReduction, DisableReduction): + if snode.node is not None: + curr_node_info = V.debug._inductor_triton_kernel_to_post_grad_node_info.setdefault( + kernel_name, [] + ) + curr_node_info.extend( + origin.name + for origin in snode.node.origins + if origin.name not in curr_node_info + ) + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TritonAttrsDescriptorVersion(enum.Enum): V0_NO_TRITON = 0 V1_COMPILER = 1 # triton.compiler.compiler.AttrsDescriptor @@ -3367,7 +3633,11 @@ def is_cudagraph_unsafe_op(node: Operation) -> bool: if ( isinstance(node.op_overload, torch._ops.OpOverload) +<<<<<<< HEAD and torch._C.Tag.cudagraph_unsafe in node.op_overload.tags # type: ignore[attr-defined] +======= + and torch._C.Tag.cudagraph_unsafe in node.op_overload.tags +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): return True @@ -3396,6 +3666,7 @@ def is_codegen_graph_partition_subgraph(wrapper: PythonWrapperCodegen) -> bool: ) +<<<<<<< HEAD def is_using_cudagraph_partition() -> bool: return ( torch._inductor.config.triton.cudagraphs @@ -3403,6 +3674,8 @@ def is_using_cudagraph_partition() -> bool: ) and torch._inductor.config.graph_partition +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def dtype_from_size(size: int) -> torch.dtype: from .virtualized import V @@ -3439,6 +3712,7 @@ def is_mkldnn_fp16_supported(device_type: str) -> bool: # match "xpu", "xpu:0", "xpu:1", etc. return True return False +<<<<<<< HEAD def tabulate_2d(elements: Sequence[Sequence[T]], headers: Sequence[T]) -> str: @@ -3705,3 +3979,5 @@ def to_real_tensor(e: Any) -> Any: flat_args = [to_real_tensor(a) for a in flat_args] args, kwargs = pytree.tree_unflatten(flat_args, flat_args_pytree_spec) return args, kwargs +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py index ea1073f88b714..3347b38a0ec00 100644 --- a/torch/_inductor/virtualized.py +++ b/torch/_inductor/virtualized.py @@ -80,7 +80,10 @@ from torch._inductor.codegen.cpp_utils import LocalBufferContext from torch._inductor.debug import DebugContext from torch._inductor.graph import GraphLowering +<<<<<<< HEAD from torch._inductor.ir import ExternKernelNode +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._inductor.loop_body import InterpreterShim from torch._subclasses import FakeTensorMode @@ -184,9 +187,12 @@ def get_index_dtype_as_torch_dtype(self): "ops", cast(type[OpsHandler[Any]], MockHandler) ) _graph: Virtualized[GraphLowering] = Virtualized("graph", NullHandler) +<<<<<<< HEAD _extern_kernel_nodes: Virtualized[list[ExternKernelNode]] = Virtualized( "extern_kernel_nodes", NullHandler ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _real_inputs: Virtualized[list[torch.Tensor]] = Virtualized("real_inputs", NullHandler) _fake_mode: Virtualized[FakeTensorMode] = Virtualized("fake_mode", NullHandler) _kernel: Virtualized[NullKernelHandler] = Virtualized( @@ -347,9 +353,12 @@ class _V: ) get_ops_handler: Callable[[], OpsHandler[Any]] = _ops._get_handler set_graph_handler: Callable[[GraphLowering], Any] = _graph._set_handler +<<<<<<< HEAD set_extern_kernel_nodes: Callable[[list[ExternKernelNode]], Any] = ( _extern_kernel_nodes._set_handler ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) set_real_inputs: Callable[[Any], Any] = _real_inputs._set_handler get_real_inputs: Callable[[], Any] = _real_inputs._get_handler set_fake_mode: Callable[[Any], Any] = _fake_mode._set_handler @@ -376,6 +385,7 @@ def graph(self) -> GraphLowering: return _graph._get_handler() @property +<<<<<<< HEAD def extern_kernel_nodes(self) -> list[ExternKernelNode]: """ The extern_kernel_nodes needed for the entire graph, including the @@ -385,6 +395,8 @@ def extern_kernel_nodes(self) -> list[ExternKernelNode]: return _extern_kernel_nodes._get_handler() @property +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def real_inputs(self): """non-fake example inputs""" return _real_inputs._get_handler() diff --git a/torch/_inductor/wrapper_benchmark.py b/torch/_inductor/wrapper_benchmark.py index 9a527471c8cc0..18abfbb085323 100644 --- a/torch/_inductor/wrapper_benchmark.py +++ b/torch/_inductor/wrapper_benchmark.py @@ -1,8 +1,15 @@ import argparse +<<<<<<< HEAD import datetime import tempfile from collections import defaultdict from dataclasses import dataclass +======= +import dataclasses +import datetime +import tempfile +from collections import defaultdict +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from types import ModuleType from typing import Any, Optional, Protocol @@ -159,7 +166,11 @@ def get_info_str( ) +<<<<<<< HEAD @dataclass +======= +@dataclasses.dataclass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class ProfileEvent: category: str key: str @@ -176,10 +187,13 @@ def parse_profile_event_list( nruns: int, device_name: str, ) -> None: +<<<<<<< HEAD """ Parse and generate a report for an event_list. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_self_device_time( ev: torch.autograd.profiler_util.EventList, ) -> float: @@ -299,10 +313,13 @@ def report() -> None: report() +<<<<<<< HEAD PROFILE_DIR = tempfile.gettempdir() PROFILE_PATH = f"{PROFILE_DIR}/compiled_module_profile.json" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def perf_profile( wall_time_ms: float, times: int, @@ -313,14 +330,22 @@ def perf_profile( with torch.profiler.profile(record_shapes=True) as p: benchmark_compiled_module_fn(times=times, repeat=repeat) +<<<<<<< HEAD path = PROFILE_PATH +======= + path = f"{tempfile.gettempdir()}/compiled_module_profile.json" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) p.export_chrome_trace(path) print(f"Profiling result for a compiled module of benchmark {benchmark_name}:") print(f"Chrome trace for the profile is written to {path}") event_list = p.key_averages(group_by_input_shape=True) print(event_list.table(sort_by="self_device_time_total", row_limit=10)) parse_profile_event_list( +<<<<<<< HEAD benchmark_name, event_list, wall_time_ms, times * repeat, p.use_device or "" +======= + benchmark_name, event_list, wall_time_ms, times * repeat, p.use_device +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @@ -468,6 +493,7 @@ def compiled_module_main( "If None, NCU will use '--set full'." ), ) +<<<<<<< HEAD parser.add_argument( "--times", type=int, @@ -481,13 +507,20 @@ def compiled_module_main( help="Number of repetitions of each benchmark run", ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) args = parser.parse_args() if args.benchmark_kernels: benchmark_all_kernels(benchmark_name, args.benchmark_all_configs) else: +<<<<<<< HEAD times = args.times repeat = args.repeat +======= + times = 10 + repeat = 10 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if torch.cuda.is_available(): torch.cuda.reset_peak_memory_stats() diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index be6d23bbbc53c..0ebe63c70e6df 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -171,7 +171,11 @@ def _qualified_name(obj, mangle_name=True) -> str: # torch.package and TorchScript have separate mangling schemes to avoid # name collisions from multiple packages. To avoid them interfering with +<<<<<<< HEAD # each other, normalize the package managing here. +======= + # each other, normalize the package manging here. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if package_mangling.is_mangled(module_name): module_name = module_name.replace("<", "_") module_name = module_name.replace(">", "_") @@ -382,7 +386,11 @@ def get_closure(fn): # values global in the function. # In Python 3.9 declaring class as global will make it invisible to # `inspect.getsource`, see https://bugs.python.org/issue42666 . +<<<<<<< HEAD # This could be worked around by manually adding it to `global()` dictionary. +======= +# This could be worked around by manualy adding it to `global()` dictionary. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def createResolutionCallbackFromClosure(fn): @@ -469,7 +477,11 @@ def get_annotation_str(annotation): elif isinstance(annotation, ast.Attribute): return ".".join([get_annotation_str(annotation.value), annotation.attr]) elif isinstance(annotation, ast.Subscript): +<<<<<<< HEAD # In Python3.9+ subscript indices are not wrapped in ast.Index +======= + # In Python3.9+ subscript indicies are not wrapped in ast.Index +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) subscript_slice = annotation.slice return f"{get_annotation_str(annotation.value)}[{get_annotation_str(subscript_slice)}]" elif isinstance(annotation, ast.Tuple): @@ -1121,7 +1133,11 @@ def _get_overloaded_methods(method, mod_class): mod_end_fileno = mod_class_fileno + len(get_source_lines_and_file(mod_class)[0]) if not (method_line_no >= mod_class_fileno and method_line_no <= mod_end_fileno): raise AssertionError( +<<<<<<< HEAD "Overloads are not usable when a module is redeclared within the same file: " +======= + "Overloads are not useable when a module is redeclared within the same file: " +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) + str(method) ) return overloads @@ -1271,7 +1287,11 @@ def _get_named_tuple_properties( # [Note: ForwardRef annotations in NamedTuple attributes] # NamedTuple types are slightly different from normal types. # +<<<<<<< HEAD # Normally, annotations are evaluated like this (during jit.script): +======= + # Normally, annotations are evaluted like this (during jit.script): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # 1. Load strings of python code into c++ and parse. # 2. Get annotations as strings # 3. Use the PythonResolver's resolution callback (rcb) to convert @@ -1503,7 +1523,11 @@ def persistent_id(self, obj): # unpicklable if it doesn't contain tensors, as we can just ignore/skip # it. To play it safe, we only do so for common objects that we're sure # don't contain tensors. Feel free to add new types here. Note also that +<<<<<<< HEAD # even if a type isn't listed here this won't block users, since they +======= + # even if a type isn't listed here this won't block users, since thet +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # can just add a __getstate__ or __reduce__ method to their class. if isinstance(obj, LockType): return "" diff --git a/torch/_lazy/extract_compiled_graph.py b/torch/_lazy/extract_compiled_graph.py index 38219a54b30b6..995a28ecb873c 100644 --- a/torch/_lazy/extract_compiled_graph.py +++ b/torch/_lazy/extract_compiled_graph.py @@ -56,9 +56,15 @@ class ReturnValueHandler: r""" When ltc_sync_multi is called on multi tensors, the compiled graph will contain output only for unique tensors - if a tensor appears multiple +<<<<<<< HEAD times in the input to _ltc_sync_multi, only the first occurrence matters. However from python level, we still expect multi tensors returned with duplication +======= + times in the input to _ltc_sync_multi, only the first occurance matters. + + However from python level, we still expect multi tensors returned with duplciation +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) even if the TS graph dedup the output. e.g. for method: def forward(self, a): @@ -123,7 +129,11 @@ def hasDeviceArg(args, kwargs): # To force those tensors on the lazy device, we can not simply override # the device argument since there is no explicit device argument. # What we are doing here is, for the list of covered tensor factory methods +<<<<<<< HEAD # we add a lazy device argument explicitly. +======= + # we add a lazy device argument explicity. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # # TODO: This solution is no ideal since we may miss some factory methods. In future # when we support lazy mode, this method can be replaced by that. @@ -170,7 +180,11 @@ def extract_compiled_graph(model: fx.GraphModule, example_inputs) -> Callable: if len(fallback_ops) > 0: raise RuntimeError( +<<<<<<< HEAD f"Fail to extract the compiled graph because of fallback: {','.join(fallback_ops)}" +======= + f"Fail to extact the compiled graph because of fallback: {','.join(fallback_ops)}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if not isinstance(lazy_out, (tuple, list)): diff --git a/torch/_lazy/metrics.py b/torch/_lazy/metrics.py index 3f676ec1f8ae0..62d557ab84a2d 100644 --- a/torch/_lazy/metrics.py +++ b/torch/_lazy/metrics.py @@ -13,7 +13,11 @@ def counter_names(): def counter_value(name: str): +<<<<<<< HEAD """Return the value of the counter with the specified name""" +======= + """Return the value of the counter with the speficied name""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return torch._C._lazy._counter_value(name) diff --git a/torch/_library/custom_ops.py b/torch/_library/custom_ops.py index 251cdefe0f05d..a74dd6e526348 100644 --- a/torch/_library/custom_ops.py +++ b/torch/_library/custom_ops.py @@ -28,7 +28,12 @@ def custom_op( mutates_args: Union[str, Iterable[str]], device_types: device_types_t = None, schema: Optional[str] = None, +<<<<<<< HEAD ) -> Callable[[Callable[..., object]], "CustomOpDef"]: ... +======= +) -> Callable[[Callable[..., object]], "CustomOpDef"]: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @overload @@ -40,7 +45,12 @@ def custom_op( mutates_args: Union[str, Iterable[str]], device_types: device_types_t = None, schema: Optional[str] = None, +<<<<<<< HEAD ) -> "CustomOpDef": ... +======= +) -> "CustomOpDef": + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @exposed_in("torch.library") @@ -210,7 +220,10 @@ def __init__( self._lib = get_library_allowing_overwrite(self._namespace, self._name) self._register_to_dispatcher(self._tags) self._disabled_kernel: set = set() +<<<<<<< HEAD self._used_triton_kernels: list[Any] = list() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) OPDEFS[self._qualname] = self @property @@ -404,7 +417,11 @@ def register_fake(self, fn: Callable, /) -> Callable: (sizes/strides/storage_offset/device), it specifies what the properties of the output Tensors are. +<<<<<<< HEAD Please see :func:`torch.library.register_fake` for more details. +======= + Please see :func:`torch.library.impl_abstract` for more details. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Args: fn (Callable): The function to register as the FakeTensor @@ -447,10 +464,17 @@ def register_fake(self, fn: Callable, /) -> Callable: >>> >>> @nonzero.register_fake >>> def _(x): +<<<<<<< HEAD >>> # Number of nonzero-elements is data-dependent. >>> # Since we cannot peek at the data in an abstract impl, >>> # we use the ctx object to construct a new symint that >>> # represents the data-dependent size. +======= + >>> # Number of nonzero-elements is data-dependent. + >>> # Since we cannot peek at the data in an abstract impl, + >>> # we use the ctx object to construct a new symint that + >>> # represents the data-dependent size. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> ctx = torch.library.get_ctx() >>> nnz = ctx.new_dynamic_size() >>> shape = [nnz, x.dim()] @@ -560,7 +584,11 @@ def register_autograd( >>> >>> x = torch.randn(3, requires_grad=True) >>> y = numpy_sin(x) +<<<<<<< HEAD >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y)) +======= + >>> grad_x, = torch.autograd.grad(y, x, torch.ones_like(y)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> assert torch.allclose(grad_x, x.cos()) >>> >>> # Example with a keyword-only arg @@ -580,7 +608,11 @@ def register_autograd( >>> >>> x = torch.randn(3, requires_grad=True) >>> y = numpy_mul(x, val=3.14) +<<<<<<< HEAD >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y)) +======= + >>> grad_x, = torch.autograd.grad(y, x, torch.ones_like(y)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> assert torch.allclose(grad_x, torch.full_like(x, 3.14)) """ @@ -596,6 +628,13 @@ def register_autograd( self._setup_context_fn = setup_context def _register_to_dispatcher(self, tags: Sequence[_C.Tag]) -> None: +<<<<<<< HEAD +======= + if torch._running_with_deploy(): + utils.warn_deploy(stacklevel=5) + return + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) lib = self._lib schema_str = self._name + self._schema cpp_schema = _C.parse_schema(schema_str) @@ -914,7 +953,11 @@ def get_library_allowing_overwrite( def _maybe_get_opdef( +<<<<<<< HEAD op: Union[CustomOpDef, _ops.OpOverload, str], +======= + op: Union[CustomOpDef, _ops.OpOverload, str] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Optional[CustomOpDef]: if isinstance(op, CustomOpDef): return op diff --git a/torch/_library/fake_class_registry.py b/torch/_library/fake_class_registry.py index 68208d0be4a86..80ef78d75152c 100644 --- a/torch/_library/fake_class_registry.py +++ b/torch/_library/fake_class_registry.py @@ -137,12 +137,17 @@ def maybe_to_fake_obj( # x.__obj_flatten__() could be calling some tensor operations inside but we don't # want to call these ops in surrounding dispatch modes when executing it. # Otherwise, for example, the fake tensor modes will error out when the tensors inside +<<<<<<< HEAD # script object execute some operations like clone if allow_non_fake_input flag is set. +======= + # script obeject execute some operations like clone if allow_non_fake_input flag is set. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with _disable_current_modes(): flat_x = x.__obj_flatten__() # type: ignore[attr-defined] _check_valid_flat_script_obj(flat_x) +<<<<<<< HEAD with fake_mode: from torch._higher_order_ops.utils import _tensor_storage @@ -181,6 +186,13 @@ def maybe_to_fake_obj( ), flat_x, ) +======= + fake_flattened = pytree.tree_map_only( + torch.Tensor, + lambda t: fake_mode.from_tensor(t), + flat_x, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fake_x = _find_fake_class_for_script_object(x).__obj_unflatten__(fake_flattened) @@ -271,8 +283,13 @@ def pop(self): def size(self): return len(self.queue) +<<<<<<< HEAD In this example, the original TensorQeue need to add a __obj_flatten__ method to the class TensorQueue and the flattened result is passed into FakeTensorQueue's +======= + In this example, the original TensorQeue need to addd a __obj_flatten__ method + to the class TensorQueue and the flattend result is passed into FakeTensorQueue's +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __obj_unflatten__ as inputs to create a fake class. This protocol allows pytorch to look at the contents of the script object and properly handle them in the subsystems like dynamo, aot_aotugrad or more. @@ -281,7 +298,11 @@ def size(self): def inner(fake_class: HasStaticMethodFromReal): ns, name = parse_namespace(qualname) +<<<<<<< HEAD # This also checks whether the referred torch::class_ exists. +======= + # This also checks whether the refered torch::class_ exists. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch._C._get_custom_class_python_wrapper(ns, name) from_method = getattr(fake_class, _CONVERT_FROM_REAL_NAME, None) diff --git a/torch/_library/fake_profile.py b/torch/_library/fake_profile.py index d480f66626806..b4b938a27dd73 100644 --- a/torch/_library/fake_profile.py +++ b/torch/_library/fake_profile.py @@ -102,7 +102,11 @@ def unsafe_generate_fake_kernels(op_profiles: dict[str, set[OpProfile]]) -> Gene an output with the same metadata as in the recorded profile. If a profile doesn't exist then an exception will be thrown. +<<<<<<< HEAD The fake kernel generation is considered unsafe because it relies on the +======= + The fake kernel generation is considerd unsafe because it relies on the +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) rigid, pre-defined operator profiles that do not account for potential variations in output behavior. Specifically, the generated kernels assume a fixed relationship between input and output ranks. However, in reality, it's diff --git a/torch/_library/infer_schema.py b/torch/_library/infer_schema.py index 512bd5835bd95..c53a21c93edb5 100644 --- a/torch/_library/infer_schema.py +++ b/torch/_library/infer_schema.py @@ -150,13 +150,21 @@ def unstringify_type(ty: Union[type[object], str]) -> tuple[typing.Any, bool]: "the arguments that are mutated or the string 'unknown'. " ) if schema_type.startswith("Tensor"): +<<<<<<< HEAD schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor') :]}" +======= + schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif name in mutates_args: if not schema_type.startswith("Tensor"): error_fn( f"Parameter {name} is in mutable_args but only Tensors or collections of Tensors can be mutated" ) +<<<<<<< HEAD schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor') :]}" +======= + schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) seen_args.add(name) if param.default is inspect.Parameter.empty: params.append(f"{schema_type} {name}") diff --git a/torch/_library/simple_registry.py b/torch/_library/simple_registry.py index bf25cde9cb531..7f90b6480949f 100644 --- a/torch/_library/simple_registry.py +++ b/torch/_library/simple_registry.py @@ -28,10 +28,16 @@ def __init__(self): self._data = {} def find(self, qualname: str) -> "SimpleOperatorEntry": +<<<<<<< HEAD res = self._data.get(qualname, None) if res is None: self._data[qualname] = res = SimpleOperatorEntry(qualname) return res +======= + if qualname not in self._data: + self._data[qualname] = SimpleOperatorEntry(qualname) + return self._data[qualname] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) singleton: SimpleLibraryRegistry = SimpleLibraryRegistry() diff --git a/torch/_library/triton.py b/torch/_library/triton.py index 741b341f7e210..d482dd30bb194 100644 --- a/torch/_library/triton.py +++ b/torch/_library/triton.py @@ -1,6 +1,10 @@ +<<<<<<< HEAD import ast import contextlib import inspect +======= +import contextlib +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import threading from collections.abc import Generator, Iterable from typing import Any, Callable, Optional, Union @@ -11,6 +15,7 @@ from .infer_schema import infer_schema +<<<<<<< HEAD triton_ops_to_kernels: dict[str, list[object]] = {} @@ -84,6 +89,8 @@ def visit_Call(self, node: ast.Call) -> None: return find_triton_kernels(fn) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @exposed_in("torch.library") def triton_op( name: str, @@ -230,6 +237,7 @@ def functional_decomp( # type: ignore[no-untyped-def] if custom_triton_ops_decomposition_disabled(): return mode.__torch_dispatch__(op, types, args, kwargs) else: +<<<<<<< HEAD # TODO: https://github.com/pytorch/pytorch/issues/160333 # We should deduplicate the unrecognized_types logic. import torch._subclasses @@ -252,6 +260,11 @@ def functional_decomp( # type: ignore[no-untyped-def] triton_kernels = get_inner_triton_kernels(fn) triton_ops_to_kernels[name] = triton_kernels +======= + with mode: + return fn(*args, **kwargs) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) result.register_torch_dispatch(FunctionalTensorMode, functional_decomp) return result diff --git a/torch/_library/utils.py b/torch/_library/utils.py index 59a316acc69af..c4cff5801d284 100644 --- a/torch/_library/utils.py +++ b/torch/_library/utils.py @@ -2,8 +2,14 @@ import dataclasses import inspect import sys +<<<<<<< HEAD from collections.abc import Iterable, Iterator from typing import Any, Callable, Literal, Optional, overload, Union +======= +import warnings +from collections.abc import Iterable, Iterator +from typing import Any, Callable, Union +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch.utils._pytree as pytree @@ -11,6 +17,18 @@ from torch._ops import OpOverload +<<<<<<< HEAD +======= +def warn_deploy(stacklevel=3): + warnings.warn( + "Python torch.library APIs do nothing under torch::deploy (multipy). " + "Please instead use C++ custom operator registration APIs.", + RuntimeWarning, + stacklevel=stacklevel, + ) + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclasses.dataclass class Kernel: """Models a (function, source location)""" @@ -432,7 +450,11 @@ def check_one(info, was_mutated): f"{self.op._name}: for argument '{info.name}': the operator's schema " f"{self.op._schema} specified that " f"the operator {'mutates' if info.is_write else 'does not mutate'} " +<<<<<<< HEAD f"the argument, but this seems to be empirically wrong. " +======= + f"the argument, but this seems to be emperically wrong. " +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f"Please make the schema and operator behavior consistent. " f"You can specify that an operator mutates a Tensor by " f"e.g. changing its schema type from 'Tensor name' to 'Tensor(a!) name'" @@ -501,6 +523,7 @@ def mutated_args_kwargs(schema: _C.FunctionSchema) -> tuple[list[int], list[str] ] +<<<<<<< HEAD # Case 1: with_default=True (or omitted). Return type is guaranteed to be a Tag. @overload def get_layout_constraint_tag( @@ -515,6 +538,8 @@ def get_layout_constraint_tag( ) -> Optional[_C.Tag]: ... +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_layout_constraint_tag(fn, *, with_default=True): for tag in tags_by_priority: if tag in fn.tags: diff --git a/torch/_linalg_utils.py b/torch/_linalg_utils.py index 43c8b65767e00..e4f07ff67c695 100644 --- a/torch/_linalg_utils.py +++ b/torch/_linalg_utils.py @@ -8,7 +8,11 @@ def is_sparse(A): +<<<<<<< HEAD """Check if tensor A is a sparse COO tensor. All other sparse storage formats (CSR, CSC, etc...) will return False.""" +======= + """Check if tensor A is a sparse tensor""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(A, torch.Tensor): return A.layout == torch.sparse_coo diff --git a/torch/_lobpcg.py b/torch/_lobpcg.py index a3f57411b8f54..2d0e0f95422f3 100644 --- a/torch/_lobpcg.py +++ b/torch/_lobpcg.py @@ -57,7 +57,11 @@ def _polynomial_coefficients_given_roots(roots): # So the code below tries to circumvent the explicit root finding by series # of operations on memory copies imitating the Horner's method. # The memory copies are required to construct nodes in the computational graph +<<<<<<< HEAD # by exploiting the explicit (not in-place, separate node for each step) +======= + # by exploting the explicit (not in-place, separate node for each step) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # recursion of the Horner's method. # Needs more memory, O(... * k^2), but with only O(... * k^2) complexity. poly_coeffs_new = poly_coeffs.clone() if roots.requires_grad else poly_coeffs @@ -80,7 +84,11 @@ def _polynomial_value(poly, x, zero_power, transition): poly[..., i] = (a_{i_0}, ..., a{i_n} (==1)), and poly(x) = poly[..., 0] * zero_power + ... + poly[..., n] * x^n +<<<<<<< HEAD x (Tensor): the value (possible batched) to evaluate the polynomial `poly` at. +======= + x (Tensor): the value (possible batched) to evalate the polynomial `poly` at. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) zero_power (Tensor): the representation of `x^0`. It is application-specific. @@ -168,7 +176,11 @@ def _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest): # of the characteristic polynomial. chr_poly_D = _polynomial_coefficients_given_roots(D) +<<<<<<< HEAD # the code below finds the explicit solution to the Sylvester equation +======= + # the code belows finds the explicit solution to the Sylvester equation +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # U_ortho^T A U_ortho dX - dX D = -U_ortho^T A U # and incorporates it into the whole gradient stored in the `res` variable. # @@ -391,17 +403,23 @@ def lobpcg( we do the following symmetrization map: `A -> (A + A.t()) / 2`. The map is performed only when the `A` requires gradients. +<<<<<<< HEAD .. warning:: LOBPCG algorithm is not applicable when the number of `A`'s rows is smaller than 3x the number of requested eigenpairs `n`. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Args: A (Tensor): the input tensor of size :math:`(*, m, m)` +<<<<<<< HEAD k (integer, optional): the number of requested eigenpairs. Default is the number of :math:`X` columns (when specified) or `1`. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) B (Tensor, optional): the input tensor of size :math:`(*, m, m)`. When not specified, `B` is interpreted as identity matrix. @@ -411,6 +429,7 @@ def lobpcg( initial approximation of eigenvectors. X must be a dense tensor. +<<<<<<< HEAD n (integer, optional): if :math:`X` is not specified then `n` specifies the size of the generated random approximation of eigenvectors. Default value for `n` @@ -426,6 +445,21 @@ def lobpcg( the current approximation of eigenpairs is returned. For infinite iteration but until convergence criteria is met, use `-1`. +======= + iK (tensor, optional): the input tensor of size :math:`(*, m, + m)`. When specified, it will be used as preconditioner. + + k (integer, optional): the number of requested + eigenpairs. Default is the number of :math:`X` + columns (when specified) or `1`. + + n (integer, optional): if :math:`X` is not specified then `n` + specifies the size of the generated random + approximation of eigenvectors. Default value for `n` + is `k`. If :math:`X` is specified, the value of `n` + (when specified) must be the number of :math:`X` + columns. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tol (float, optional): residual tolerance for stopping criterion. Default is `feps ** 0.5` where `feps` is @@ -441,6 +475,15 @@ def lobpcg( description of the function above. Default is "ortho". +<<<<<<< HEAD +======= + niter (int, optional): maximum number of iterations. When + reached, the iteration process is hard-stopped and + the current approximation of eigenpairs is returned. + For infinite iteration but until convergence criteria + is met, use `-1`. + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tracker (callable, optional) : a function for tracing the iteration process. When specified, it is called at each iteration step with LOBPCG instance as an diff --git a/torch/_logging/__init__.py b/torch/_logging/__init__.py index d0fdebb23bde9..0972ce187d5a3 100644 --- a/torch/_logging/__init__.py +++ b/torch/_logging/__init__.py @@ -12,7 +12,10 @@ dtrace_structured, get_structured_logging_overhead, getArtifactLogger, +<<<<<<< HEAD hide_warnings, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) LazyString, set_logs, trace_structured, diff --git a/torch/_logging/_internal.py b/torch/_logging/_internal.py index a418fe3b60970..d5e72e15fad3f 100644 --- a/torch/_logging/_internal.py +++ b/torch/_logging/_internal.py @@ -1,5 +1,8 @@ # mypy: allow-untyped-defs +<<<<<<< HEAD import contextlib +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import functools import hashlib import importlib.util @@ -13,7 +16,10 @@ import sys import tempfile import time +<<<<<<< HEAD import warnings +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from collections import defaultdict from dataclasses import dataclass, field from typing import Any, Callable, Generic, Optional, Union @@ -726,6 +732,7 @@ def _invalid_settings_err_msg(settings, verbose=False): return msg +<<<<<<< HEAD def process_env_var_string_for_windows(env_var_str: str) -> str: """ When we setup logging config as guide: https://docs.pytorch.org/docs/stable/logging.html @@ -769,6 +776,10 @@ def remove_outer_quotes(s: str) -> str: def _parse_log_settings(settings): settings = process_env_var_string_for_windows(settings) +======= +@functools.lru_cache +def _parse_log_settings(settings): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if settings == "": return {} @@ -1199,6 +1210,7 @@ def warning_once(logger_obj, *args, **kwargs) -> None: logger_obj.warning(*args, **kwargs) +<<<<<<< HEAD def safe_grad_filter(message, category, filename, lineno, file=None, line=None) -> bool: return "The .grad attribute of a Tensor" not in str(message) @@ -1238,6 +1250,8 @@ def _showwarning(*args, **kwargs): warnings.showwarning = prior +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class LazyString(Generic[_P]): def __init__( self, func: Callable[_P, str], *args: _P.args, **kwargs: _P.kwargs @@ -1289,7 +1303,10 @@ def trace_structured_artifact( name: str, # this will go in metadata encoding: str, payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None, +<<<<<<< HEAD compile_id: Optional[CompileId] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: trace_structured( "artifact", @@ -1298,7 +1315,10 @@ def trace_structured_artifact( "encoding": encoding, }, payload_fn=payload_fn, +<<<<<<< HEAD compile_id=compile_id, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @@ -1321,12 +1341,17 @@ def trace_structured( payload is an arbitrary string, which can be arbitrarily long (but expected to have newlines so no lines are too long) """ +<<<<<<< HEAD assert name not in [ +======= + assert "name" not in [ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "rank", "compiled_autograd_id", "frame_id", "frame_compile_id", "attempt", +<<<<<<< HEAD "severity", "timestamp", "pathname", @@ -1338,6 +1363,15 @@ def trace_structured( assert callable(payload_fn), ( f"payload_fn should be callable, but got {type(payload_fn)}" ) +======= + ] + assert callable( + metadata_fn + ), f"metadata_fn should be callable, but got {type(metadata_fn)}" + assert callable( + payload_fn + ), f"payload_fn should be callable, but got {type(payload_fn)}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # trace_log never propagates and is ALWAYS DEBUG, so also check that there # are handlers instead of checking the log level if trace_log.handlers: diff --git a/torch/_logging/structured.py b/torch/_logging/structured.py index 4eae33227e618..a4f7e8c458ca9 100644 --- a/torch/_logging/structured.py +++ b/torch/_logging/structured.py @@ -1,7 +1,10 @@ """ Utilities for converting data types into structured JSON for dumping. """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import inspect import os import traceback diff --git a/torch/_lowrank.py b/torch/_lowrank.py index 182883cfc5e59..b319439b3044d 100644 --- a/torch/_lowrank.py +++ b/torch/_lowrank.py @@ -27,7 +27,11 @@ def get_approximate_basis( .. note:: For an adequate approximation of a k-rank matrix :math:`A`, where k is not known in advance but could be estimated, the number of :math:`Q` columns, q, can be +<<<<<<< HEAD chosen according to the following criteria: in general, +======= + choosen according to the following criteria: in general, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) :math:`k <= q <= min(2*k, m, n)`. For large low-rank matrices, take :math:`q = k + 5..10`. If k is relatively small compared to :math:`min(m, n)`, choosing @@ -100,7 +104,11 @@ def svd_lowrank( .. note:: For an adequate approximation of a k-rank matrix :math:`A`, where k is not known in advance but could be estimated, the number of :math:`Q` columns, q, can be +<<<<<<< HEAD chosen according to the following criteria: in general, +======= + choosen according to the following criteria: in general, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) :math:`k <= q <= min(2*k, m, n)`. For large low-rank matrices, take :math:`q = k + 5..10`. If k is relatively small compared to :math:`min(m, n)`, choosing diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 9202b4da41d28..ca7b00849f4e4 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -1,8 +1,15 @@ # mypy: allow-untyped-defs import math +<<<<<<< HEAD from collections.abc import Sequence from enum import Enum from functools import wraps +======= +import operator +from collections.abc import Sequence +from enum import Enum +from functools import reduce, wraps +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from typing import Callable, Optional, TypeVar, Union from typing_extensions import ParamSpec @@ -16,15 +23,31 @@ meta_table, ) from torch._ops import OpOverload +<<<<<<< HEAD from torch._prims import _prim_elementwise_meta, ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND +======= +from torch._prims import ( + _prim_elementwise_meta, + ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND, + view_of, +) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._prims_common import ( BoolLike, corresponding_complex_dtype, corresponding_real_dtype, +<<<<<<< HEAD +======= + definitely_contiguous, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elementwise_dtypes, ELEMENTWISE_TYPE_PROMOTION_KIND, FloatLike, IntLike, +<<<<<<< HEAD +======= + is_contiguous, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) make_contiguous_strides_for, Number, suggest_memory_format, @@ -195,6 +218,171 @@ def linalg_cross(self, other, *, dim=-1): return self.new_empty(out_shape) +<<<<<<< HEAD +======= +# This function is python match of computeStride_impl in TensorUtils.cpp +def _compute_stride(old_shape, old_stride, new_shape, size_oblivious=False): + from torch.fx.experimental.symbolic_shapes import ( + guard_or_false, + guard_or_true, + sym_eq, + ) + + def maybe_guard_or_false(x): + if size_oblivious: + return guard_or_false(x) + + return x + + def maybe_guard_or_true(x): + if size_oblivious: + return guard_or_true(x) + + return x + + if len(old_shape) == 0: + return [1] * len(new_shape) + + numel = reduce(operator.mul, old_shape, 1) + zero_numel = maybe_guard_or_false(numel == 0) + if zero_numel and maybe_guard_or_false(sym_eq(old_shape, new_shape)): + return old_stride + + new_stride = [0] * len(new_shape) + + if zero_numel: + for view_d in range(len(new_shape) - 1, -1, -1): + if view_d == len(new_shape) - 1: + new_stride[view_d] = 1 + else: + new_stride[view_d] = ( + max(new_shape[view_d + 1], 1) * new_stride[view_d + 1] + ) + return new_stride + + view_d = len(new_shape) - 1 + chunk_base_stride = old_stride[-1] + tensor_numel = 1 + view_numel = 1 + + for tensor_d in range(len(old_shape) - 1, -1, -1): + tensor_numel *= old_shape[tensor_d] + + if tensor_d == 0 or ( + maybe_guard_or_true(old_shape[tensor_d - 1] != 1) + and maybe_guard_or_true( + old_stride[tensor_d - 1] != tensor_numel * chunk_base_stride + ) + ): + while view_d >= 0 and ( + maybe_guard_or_true(view_numel < tensor_numel) + or maybe_guard_or_false(new_shape[view_d] == 1) + ): + new_stride[view_d] = view_numel * chunk_base_stride + view_numel *= new_shape[view_d] + view_d -= 1 + + if maybe_guard_or_true(view_numel != tensor_numel): + return None + + if tensor_d > 0: + chunk_base_stride = old_stride[tensor_d - 1] + tensor_numel = 1 + view_numel = 1 + if view_d != -1: + return None + return new_stride + + +def _view_has_unbacked_input(a, shape): + from torch.fx.experimental.symbolic_shapes import has_hint + + return ( + any(not has_hint(s) for s in a.size()) + or any(not has_hint(s) for s in a.stride()) + or any(not has_hint(s) for s in shape) + ) + + +def _view_unbacked_meta(a, shape, size_oblivious_enabled=True): + from torch.fx.experimental.symbolic_shapes import guard_or_false, sym_eq + + # Creates a valid shape + shape = utils.extract_shape_from_varargs(shape, validate=False) + + # Reshape may be given a shape with a -1 length + # This indicates that the dimension's length should be inferred + shape = utils.infer_size(shape, a.numel()) + + # Special-cases reshaping zero dim tensors + if a.ndim == 0: + _a = a + for length in shape: + torch._check(length == 1) + _a = torch._refs.unsqueeze(_a, -1) + if _a is a: + return view_of(a) + else: + return _a + + # Special-cases reshaping to zero dim tensors + if len(shape) == 0: + _a = a + for length in a.shape: + torch._check(length == 1) + _a = torch._refs.squeeze(_a, -1) + if _a is a: + return view_of(a) + else: + return _a + + shape_numel = reduce(operator.mul, shape, 1) + + torch._check( + a.numel() == shape_numel, + lambda: f"Could not reshape a tensor with shape {a.shape} as a tensor with shape {shape}!", + ) + + if len(shape) == len(a.shape) and guard_or_false(sym_eq(shape, a.shape)): + return view_of(a) + + if definitely_contiguous(a) if size_oblivious_enabled else is_contiguous(a): + strides = utils.make_contiguous_strides_for(shape) + return a.as_strided(shape, strides) + + new_strides = _compute_stride( + a.size(), a.stride(), shape, size_oblivious=size_oblivious_enabled + ) + + if new_strides is not None: + return a.as_strided(shape, new_strides) + + # If we fail to do size oblivious view, and backed_size_oblivious was on, + # then we redo everything by looking at hints and guarding instead of failing. + # Also if the expression has unbacked symbols, then we run again with size_oblivious_enabled=False + # to throw a data dependent error. + + if size_oblivious_enabled and ( + torch.fx.experimental._config.backed_size_oblivious + or _view_has_unbacked_input(a, shape) + ): + return _view_unbacked_meta(a, shape, size_oblivious_enabled=False) + + msg = f"Cannot view a tensor with shape {a.shape} and strides {a.stride()} as a tensor with shape {shape}!" + raise ValueError(msg) + + +@register_meta(aten.view.default) +def _view_meta(a, *shape): + if torch.fx.experimental._config.backed_size_oblivious or _view_has_unbacked_input( + a, shape + ): + return _view_unbacked_meta(a, shape) + else: + return torch._refs._reshape_view_helper(a, *shape, allow_copy=False) + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register_meta(aten.linalg_matrix_exp) @out_wrapper() def linalg_matrix_exp(self): @@ -2373,10 +2561,16 @@ def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int ret_shape.append( _formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]) ) +<<<<<<< HEAD from torch.fx.experimental.symbolic_shapes import sym_or torch._check( sym_or(*[x > 0 for x in ret_shape[2:]]), +======= + + torch._check( + any(x > 0 for x in ret_shape[2:]), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) lambda: f"Given input size per channel: {list(dims)}. " f"Calculated output size per channel: {ret_shape[2:]}. " f"Output size is too small", @@ -2404,7 +2598,11 @@ def meta_miopen_batch_norm( out_shape = input_tensor.shape # If tensor is provided for running_mean and running_var then use this. If these are not +<<<<<<< HEAD # provided then we return the shape of weight tensor. Similar to how this is handled in the decomposition +======= + # provded then we return the shape of weight tensor. Similar to how this is handled in the decomposition +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) save_mean_shape = running_mean.shape if running_mean is not None else weight.shape save_var_shape = running_var.shape if running_var is not None else weight.shape @@ -2552,6 +2750,7 @@ def meta_qconv_pointwise( groups, None, ) +<<<<<<< HEAD if output_dtype is None: output_dtype = x.dtype assert output_dtype in [ @@ -2570,6 +2769,12 @@ def meta_qconv_pointwise( 4: torch.channels_last, 5: torch.channels_last_3d, }[len(shape_out)] +======= + assert output_dtype in [torch.float32, torch.bfloat16, torch.uint8, torch.int8] + out = x.new_empty(shape_out, dtype=output_dtype) + assert len(shape_out) in [3, 4], "only conv1d/2d are supported" + format = torch.channels_last if len(shape_out) == 4 else torch.contiguous_format +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out = out.to(memory_format=format) return out @@ -2621,6 +2826,7 @@ def meta_qlinear_pointwise( output_shape = list(x.shape) # The weight has been transposed during the qlinear weight prepack process. output_shape[-1] = w.shape[1] +<<<<<<< HEAD assert output_dtype in [ torch.float32, torch.bfloat16, @@ -2628,6 +2834,9 @@ def meta_qlinear_pointwise( torch.uint8, torch.float8_e4m3fn, ] +======= + assert output_dtype in [torch.float32, torch.bfloat16, torch.int8, torch.uint8] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out = x.new_empty(output_shape, dtype=output_dtype) return out @@ -2658,6 +2867,7 @@ def meta_qlinear_pointwise_binary( output_shape = list(x.shape) # The weight has been transposed during the qlinear weight prepack process. output_shape[-1] = w.shape[1] +<<<<<<< HEAD assert output_dtype in [ torch.float32, torch.bfloat16, @@ -2665,6 +2875,9 @@ def meta_qlinear_pointwise_binary( torch.int8, torch.float8_e4m3fn, ] +======= + assert output_dtype in [torch.float32, torch.bfloat16, torch.uint8, torch.int8] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out = x.new_empty(output_shape, dtype=output_dtype) return out @@ -3312,12 +3525,17 @@ def meta_repeat_interleave_Tensor(repeats, output_size=None): def meta_complex(real, imag): assert real.dtype.is_floating_point assert imag.dtype.is_floating_point +<<<<<<< HEAD result = elementwise_meta( real.to(corresponding_complex_dtype(real.dtype)), imag.to(corresponding_complex_dtype(imag.dtype)), type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, ) return result +======= + out_shape = _broadcast_shapes(real.shape, imag.shape) + return real.new_empty(out_shape, dtype=corresponding_complex_dtype(real.dtype)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register_meta([aten.nonzero_static.default, aten.nonzero_static.out]) @@ -3455,9 +3673,15 @@ def _restride_src(self): return self.as_strided(shape, strides) out = self.new_empty(before_shape + replacement_shape + after_shape) +<<<<<<< HEAD from torch.fx.experimental.symbolic_shapes import guard_or_false if guard_or_false(self.numel() == 0): +======= + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if guard_size_oblivious(self.numel() == 0): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # No need to worry about the output strides if self is empty. return out @@ -3798,7 +4022,11 @@ def kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( + kai_num_bytes_bias ) +<<<<<<< HEAD # This function returns size of these datatypes stored as enum. We modify it to just return bf16 datatype +======= + # This funtion retuns size of these datatypes stored as enum. We modify it to just return bf16 datatype +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # https://gitlab.arm.com/kleidi/kleidiai/-/blob/main/kai/kai_common.h?ref_type=heads#L55 def kai_get_bf16_datatype_size_in_bytes(): return 2 # 2 bytes @@ -4194,6 +4422,7 @@ def is_booleanic(arg): return self +<<<<<<< HEAD @register_meta( [ aten.add.Scalar, @@ -4206,6 +4435,8 @@ def meta_binop_alpha(self, other, alpha=1): ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register_meta([aten.round.default, aten.round.decimals]) def meta_round(self, **kwargs): return elementwise_meta( @@ -4346,6 +4577,14 @@ def meta_index_put_(self, indices, values, accumulate=False): return self +<<<<<<< HEAD +======= +@register_meta(aten.alias.default) +def meta_alias(self): + return self.view(self.shape) + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None, out_dtype=None): torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor") torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor") @@ -4768,7 +5007,11 @@ def unpack(name, val): else: torch._check( False, +<<<<<<< HEAD lambda: "Unsupported memory format. Supports only ChannelsLast, Contiguous", +======= + lambda: "Unsupport memory format. Supports only ChannelsLast, Contiguous", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode) @@ -4890,7 +5133,11 @@ def meta_fractional_max_pool2d(self, kernel_size, output_size, random_samples): torch._check( self.size(d) > 0, f"fractional_max_pool2d: Expected input to have non-zero " +<<<<<<< HEAD f" size for non-batch dimensions, but got {self.size()} with dimension {d} empty", +======= + f" size for non-batch dimenions, but got {self.size()} with dimension {d} empty", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # the check and message are out of sync, but this matches the structured meta @@ -5396,6 +5643,42 @@ def meta_zeros( ) +<<<<<<< HEAD +======= +@register_meta(aten.select.int) +def meta_select(self, dim, index): + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + ndim = self.dim() + torch._check_index( + ndim != 0, + lambda: "select() cannot be applied to a 0-dim tensor.", + ) + + dim = dim if dim >= 0 else dim + ndim + size = self.size(dim) + + torch._check_index( + not ( + guard_size_oblivious(-index > size) or guard_size_oblivious(index >= size) + ), + lambda: f"select(): index {index} out of range for tensor of size " + f"{self.size()} at dimension {dim}", + ) + + index = index if index >= 0 else index + size + + new_size = list(self.size()) + new_stride = list(self.stride()) + + new_storage_offset = self.storage_offset() + index * new_stride[dim] + del new_size[dim] + del new_stride[dim] + + return self.as_strided(new_size, new_stride, new_storage_offset) + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register_meta(aten.select_scatter.default) def meta_select_scatter(self, src, dim, index): return utils.clone_preserve_strides(self) @@ -5442,10 +5725,17 @@ def gather_shape_check(self, dim, index): @register_meta(aten.gather.default) def meta_gather(self, dim, index, sparse_grad=False): +<<<<<<< HEAD from torch.fx.experimental.symbolic_shapes import guard_or_false wrapped_dim = maybe_wrap_dim(dim, self.dim()) is_index_empty = guard_or_false(index.numel() == 0) +======= + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + wrapped_dim = maybe_wrap_dim(dim, self.dim()) + is_index_empty = guard_size_oblivious(index.numel() == 0) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not is_index_empty: torch._check( index.dtype == torch.long or index.dtype == torch.int, @@ -5484,9 +5774,15 @@ def get_operator_enum(reduce_, use_new_options=False): # From aten/src/ATen/native/ScatterGatherChecks.h def scatter_gather_dtype_check(method_name, self, index, src_opt=None): +<<<<<<< HEAD from torch.fx.experimental.symbolic_shapes import guard_or_true if guard_or_true(index.numel() != 0): +======= + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if guard_size_oblivious(index.numel() != 0): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch._check( index.dtype == torch.long or index.dtype == torch.int, lambda: f"{method_name}(): Expected dtype int32/int64 for index", @@ -5505,9 +5801,15 @@ def ensure_nonempty_dim(dim): # From aten/src/ATen/native/ScatterGatherChecks.h def scatter_shape_check(self, dim, index, src_opt=None): +<<<<<<< HEAD from torch.fx.experimental.symbolic_shapes import guard_or_false if guard_or_false(index.numel() == 0): +======= + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if guard_size_oblivious(index.numel() == 0): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return torch._check( ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()), @@ -5648,7 +5950,11 @@ def meta__scaled_dot_product_flash_attention( # it's possible we'll need to have some special handling in inductor for sdpa # See [Note] BC breaking change to flash seed/offset if torch.version.hip and torch.cuda.is_available(): +<<<<<<< HEAD # Maintain old path on AMD +======= + # Maintian old path on AMD +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) seed = torch.empty((), dtype=torch.long, device="meta") offset = torch.empty((), dtype=torch.long, device="meta") else: @@ -5710,7 +6016,11 @@ def meta__scaled_dot_product_cudnn_attention( res = alloc_with_matching_layout(query, res_shape) logsum_exp = torch.empty( +<<<<<<< HEAD (B, H, S_Q, 1), +======= + (B, H, S_Q), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dtype=torch.float, device=query.device, ) @@ -5855,21 +6165,41 @@ def meta__scaled_dot_product_flash_attention_for_cpu_backward( scale: Optional[float] = None, ): # cpus's grad layout is different from cuda's, +<<<<<<< HEAD # i.e. (batch_size, seq_len, num_heads, head_dim) grad_q = torch.empty_permuted( query.size(), +======= + # i.e. (batch_size, seq_len,num_heads, head_dim) + batch_size = query.size(0) + num_heads = query.size(1) + head_dim = query.size(3) + len_q = query.size(2) + len_k = key.size(2) + + grad_q = torch.empty_permuted( + (batch_size, num_heads, len_q, head_dim), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) (0, 2, 1, 3), dtype=query.dtype, device=query.device, ) grad_k = torch.empty_permuted( +<<<<<<< HEAD key.size(), +======= + (batch_size, num_heads, len_k, head_dim), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) (0, 2, 1, 3), dtype=key.dtype, device=key.device, ) grad_v = torch.empty_permuted( +<<<<<<< HEAD value.size(), +======= + (batch_size, num_heads, len_k, head_dim), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) (0, 2, 1, 3), dtype=value.dtype, device=value.device, @@ -5878,6 +6208,7 @@ def meta__scaled_dot_product_flash_attention_for_cpu_backward( return grad_q, grad_k, grad_v +<<<<<<< HEAD @register_meta([aten._scaled_dot_product_attention_math_for_mps]) def meta__scaled_dot_product_attention_math_for_mps( query: Tensor, @@ -5933,6 +6264,8 @@ def sdpa_vector_2pass_mps(): return sdpa_vector_fast_mps() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register_meta([aten._scaled_dot_product_efficient_attention]) def meta__scaled_dot_product_efficient_attention( query: Tensor, @@ -6132,7 +6465,11 @@ def meta__flash_attention_forward( # See [Note] BC breaking change to flash seed/offset seed, offset = None, None if torch.version.hip and torch.cuda.is_available(): +<<<<<<< HEAD # Maintain old path on AMD +======= + # Maintian old path on AMD +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) seed = torch.empty((), dtype=torch.long, device="meta") offset = torch.empty((), dtype=torch.long, device="meta") else: @@ -6344,7 +6681,11 @@ def has_zero_dim(tensor_2d): ) torch._check( mat2.size(0) % 16 == 0 and mat2.size(1) % 16 == 0, +<<<<<<< HEAD lambda: f"Expected both dimensions of mat2 to be divisible by 16 but got {mat2.shape}", +======= + lambda: f"Expected both dimensions of mat2 to be divisble by 16 but got {mat2.shape}", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # determine scaling type and check input dimensions (refer to Blas.cpp op) @@ -6471,7 +6812,11 @@ def meta_scatter_reduce__two(self, dim, index, src, reduce, include_self=True): def meta_multinomial(input, num_samples, replacement=False, *, generator=None): torch._check( 0 < input.dim() <= 2, +<<<<<<< HEAD lambda: f"The probability distributions dimensions must be 1 or 2, but got {input.dim()}", +======= + lambda: f"The probabilty distributions dimensions must be 1 or 2, but got {input.dim()}", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if input.dim() == 1: return torch.empty(num_samples, dtype=torch.long, device=input.device) @@ -7097,7 +7442,12 @@ def _amp_foreach_non_finite_check_and_unscale_(self, found_inf, inv_scale): @register_meta([aten.nan_to_num.default, aten.nan_to_num.out]) @out_wrapper() def nan_to_num(self, nan=None, posinf=None, neginf=None): +<<<<<<< HEAD return torch.empty_like(self) +======= + result_size = list(self.size()) + return self.new_empty(result_size) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register_meta(torch.ops.aten.transpose_) @@ -7366,24 +7716,37 @@ def _create_grouped_mm_output_tensor(mat1, mat2, offs, out_dtype): out_size = [offs.size(0), mat1.size(0), mat2.size(1)] else: torch._check( +<<<<<<< HEAD offs.size(0) == mat2.size(0), lambda: "matrix batch sizes have to match" +======= + offs.size(0) == mat2.size(0), "matrix batch sizes have to match" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) out_size = [mat1.size(0), mat2.size(-1)] else: if mat2_is_2d: torch._check( +<<<<<<< HEAD offs.size(0) == mat1.size(0), lambda: "matrix batch sizes have to match" +======= + offs.size(0) == mat1.size(0), "matrix batch sizes have to match" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) out_size = [mat1.size(1), mat2.size(1)] else: # regular bmm +<<<<<<< HEAD torch._check( mat1.size(0) == mat2.size(0), lambda: "batched dimension has to match" ) +======= + torch._check(mat1.size(0) == mat2.size(0), "batched dimension has to match") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out_size = [mat1.size(0), mat1.size(1), mat2.size(-1)] out_dtype = out_dtype or mat1.dtype +<<<<<<< HEAD if torch.version.cuda: alignment = 16 // out_dtype.itemsize size_padded = (out_size[-1] + alignment - 1) // alignment * alignment @@ -7396,6 +7759,15 @@ def _create_grouped_mm_output_tensor(mat1, mat2, offs, out_dtype): ) else: out = torch.empty(out_size, dtype=out_dtype, device=mat1.device) +======= + alignment = 16 // out_dtype.itemsize + size_padded = (out_size[-1] + alignment - 1) // alignment * alignment + if mat1_is_2d == mat2_is_2d: + out_stride = [out_size[1] * size_padded, size_padded, 1] + else: + out_stride = [size_padded, 1] + out = torch.empty_strided(out_size, out_stride, dtype=out_dtype, device=mat1.device) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return out @@ -7421,31 +7793,48 @@ def _meta_grouped_mm_common( # aten/src/ATen/native/cuda/Blas.cpp. if scaled: +<<<<<<< HEAD fp8_dtype = torch.float8_e4m3fnuz if torch.version.hip else torch.float8_e4m3fn torch._check( mat_a.dtype == fp8_dtype and mat_b.dtype == fp8_dtype, lambda: f"Expected inputs of E4M3 FP8 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.", # noqa: B950 +======= + torch._check( + mat_a.dtype == torch.float8_e4m3fn and mat_b.dtype == torch.float8_e4m3fn, + lambda: f"Expected inputs of E4M3 FP8 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) else: torch._check( mat_a.dtype == torch.bfloat16 and mat_b.dtype == torch.bfloat16, +<<<<<<< HEAD lambda: f"Expected inputs of BF16 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.", # noqa: B950 +======= + lambda: f"Expected inputs of BF16 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) torch._check( mat_a.dim() in [2, 3] and mat_b.dim() in [2, 3], +<<<<<<< HEAD lambda: f"Multiplicands must be 2D or 3D but got mat_a.dim()={mat_a.dim()} and mat_b.dim()={mat_b.dim()}", # noqa: B950 +======= + lambda: f"Multiplicands must be 2D or 3D but got mat_a.dim()={mat_a.dim()} and mat_b.dim()={mat_b.dim()}", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) mat_a_is_2d = mat_a.dim() == 2 mat_b_is_2d = mat_b.dim() == 2 +<<<<<<< HEAD if not mat_a_is_2d or not mat_b_is_2d: torch._check( mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match", ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if scaled: def is_row_major(mat): @@ -7458,11 +7847,19 @@ def is_col_major(mat): torch._check( is_row_major(mat_a), +<<<<<<< HEAD lambda: f"Expected mat_a tensor to be row major in the last two dimensions, got strides {mat_a.stride()[-2:]}", # noqa: B950 ) torch._check( is_col_major(mat_b), lambda: f"Expected mat_b tensor to be column major in the last two dimensions, got strides {mat_b.stride()[-2:]}", # noqa: B950 +======= + lambda: f"Expected mat_a tensor to be row major in the last two dimensions, got strides {mat_a.stride()[-2:]}", + ) + torch._check( + is_col_major(mat_b), + lambda: f"Expected mat_b tensor to be column major in the last two dimensions, got strides {mat_b.stride()[-2:]}", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def check_valid_strides(mat_name, mat): @@ -7474,7 +7871,11 @@ def check_valid_strides(mat_name, mat): ): torch._check( mat_stride[end_dim] % alignment == 0, +<<<<<<< HEAD lambda: f"Expected {mat_name} stride along {end_dim} dim to be multiple of 16 bytes, got {mat_stride[end_dim]}.", # noqa: B950 +======= + lambda: f"Expected {mat_name} stride along {end_dim} dim to be multiple of 16 bytes, got {mat_stride[end_dim]}.", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) elif mat_stride[end_dim] == 1 and mat_stride[end_dim - 1] >= max( 1, mat.shape[end_dim] @@ -7494,6 +7895,7 @@ def check_valid_strides(mat_name, mat): if scale_a is not None and scale_b is not None: torch._check( +<<<<<<< HEAD (scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32) or ( scale_a.dtype == torch.float8_e8m0fnu @@ -7509,10 +7911,16 @@ def check_valid_strides(mat_name, mat): def round_up(x, y): """Rounds up x to nearest multiple of y""" return ((x + y - 1) // y) * y +======= + scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32, + lambda: "Both scale_a and scale_b must be float (fp32) tensors, but got scale_a.dtype={scale_a.dtype} and scale_b.dtype={scale_b.dtype}.", # noqa: B950 + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def check_scale(scale_name, scale, mat, scaled_dim, scale_multiplier=1): if mat.dim() == 2: torch._check( +<<<<<<< HEAD scale.is_contiguous(), lambda: f"Expected {scale_name} to be contiguous.", ) @@ -7537,12 +7945,33 @@ def check_scale(scale_name, scale, mat, scaled_dim, scale_multiplier=1): else: torch._check( scale.stride(-1) == 1, +======= + scale.dim() == 1, + lambda: f"Expected {scale_name} to be 1D tensor, but got {scale.dim()}D tensor.", + ) + torch._check( + scale.is_contiguous(), + lambda: f"Expected {scale_name} to be contiguous.", + ) + torch._check( + scale.shape[0] == mat.shape[scaled_dim] * scale_multiplier, + lambda: f"Expected {scale_name} to have {mat.shape[scaled_dim] * scale_multiplier} elements, got {scale.shape[0]} elements.", # noqa: B950 + ) + else: + torch._check( + scale.dim() == 2, + lambda: f"Expected {scale_name} to be 2D tensor, but got {scale.dim()}D tensor.", + ) + torch._check( + scale.stride(1) == 1, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) lambda: f"Expected {scale_name} to be contiguous in the last dimension.", ) torch._check( scale.shape[0] == mat.shape[0], lambda: f"Expected {scale_name} batch dimension to be {mat.shape[0]}, got {scale.shape[0]}.", ) +<<<<<<< HEAD # For MXFP8, 3d tensors have static 'groups' (stack of 2d tensors) so we can know the expected blocked # scale sizes at compile time. if is_mxfp8: @@ -7569,6 +7998,12 @@ def check_scale(scale_name, scale, mat, scaled_dim, scale_multiplier=1): scale.shape[1] == mat.shape[1 + scaled_dim], lambda: f"Expected {scale_name} non-batch dimension to be {mat.shape[1 + scaled_dim]}, got {scale.shape[1]}.", # noqa: B950 ) +======= + torch._check( + scale.shape[1] == mat.shape[1 + scaled_dim], + lambda: f"Expected {scale_name} non-batch dimension to be {mat.shape[1 + scaled_dim]}, got {scale.shape[1]}.", + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) scale_multiplier = ( offs.shape[0] if offs is not None and mat_a_is_2d and mat_b_is_2d else 1 @@ -7616,7 +8051,11 @@ def check_scale(scale_name, scale, mat, scaled_dim, scale_multiplier=1): @register_meta(aten._grouped_mm) @out_wrapper() +<<<<<<< HEAD def meta_grouped_mm( +======= +def grouped_mm( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mat_a: Tensor, mat_b: Tensor, offs: Optional[Tensor] = None, @@ -7635,7 +8074,11 @@ def meta_grouped_mm( ) +<<<<<<< HEAD @register_meta([aten._scaled_grouped_mm]) +======= +@register_meta([aten._scaled_grouped_mm.default]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def meta_scaled_grouped_mm( mat_a: torch.Tensor, mat_b: torch.Tensor, @@ -7695,6 +8138,7 @@ def _constant_pad_nd_meta(input, pad, value=0): f"{l_inp} dimensions.", ) +<<<<<<< HEAD if all(isinstance(p, utils.IntWithoutSymInt) and p <= 0 for p in pad): c_input = input for i in range(l_diff, l_inp): @@ -7709,6 +8153,8 @@ def _constant_pad_nd_meta(input, pad, value=0): return c_input.clone() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) new_shape = list(input_sizes[:l_diff]) for i in range(l_pad): pad_idx = len(pad) - ((i + 1) * 2) diff --git a/torch/_numpy/_dtypes.py b/torch/_numpy/_dtypes.py index e955a47060fff..e61c55b091011 100644 --- a/torch/_numpy/_dtypes.py +++ b/torch/_numpy/_dtypes.py @@ -1,9 +1,15 @@ # mypy: ignore-errors +<<<<<<< HEAD """Define analogs of numpy dtypes supported by pytorch. Define the scalar types and supported dtypes and numpy <--> torch dtype mappings. """ +======= +""" Define analogs of numpy dtypes supported by pytorch. +Define the scalar types and supported dtypes and numpy <--> torch dtype mappings. +""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import builtins import torch diff --git a/torch/_numpy/_dtypes_impl.py b/torch/_numpy/_dtypes_impl.py index feed9c4600501..fd081c14b179b 100644 --- a/torch/_numpy/_dtypes_impl.py +++ b/torch/_numpy/_dtypes_impl.py @@ -1,11 +1,18 @@ # mypy: ignore-errors +<<<<<<< HEAD """Dtypes/scalar type implementations with torch dtypes. +======= +"""Dtypes/scalar type implementaions with torch dtypes. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Here `dtype` is always a torch.dtype, this module knows nothing about scalar types, wrapper dtypes or anything like that. PyTorch only. """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from collections import namedtuple import torch diff --git a/torch/_numpy/_funcs_impl.py b/torch/_numpy/_funcs_impl.py index 19748a08b9dec..43d3b6a18ef6f 100644 --- a/torch/_numpy/_funcs_impl.py +++ b/torch/_numpy/_funcs_impl.py @@ -5,7 +5,10 @@ Things imported from here have numpy-compatible signatures but operate on pytorch tensors. """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Contents of this module ends up in the main namespace via _funcs.py # where type annotations are used in conjunction with the @normalizer decorator. from __future__ import annotations @@ -96,7 +99,11 @@ def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"): else: out_dtype = _dtypes_impl.result_type_impl(*tensors) +<<<<<<< HEAD # cast input arrays if necessary; do not broadcast them against `out` +======= + # cast input arrays if necessary; do not broadcast them agains `out` +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tensors = _util.typecast_tensors(tensors, out_dtype, casting) return tensors @@ -1290,7 +1297,11 @@ def cross(a: ArrayLike, b: ArrayLike, axisa=-1, axisb=-1, axisc=-1, axis=None): def einsum(*operands, out=None, dtype=None, order="K", casting="safe", optimize=False): # Have to manually normalize *operands and **kwargs, following the NumPy signature +<<<<<<< HEAD # We have a local import to avoid polluting the global space, as it will be then +======= + # We have a local import to avoid poluting the global space, as it will be then +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # exported in funcs.py from ._ndarray import ndarray from ._normalizations import ( diff --git a/torch/_numpy/_ndarray.py b/torch/_numpy/_ndarray.py index f192a39dd0296..93adf1bb6970c 100644 --- a/torch/_numpy/_ndarray.py +++ b/torch/_numpy/_ndarray.py @@ -169,6 +169,7 @@ def _upcast_int_indices(index): return index +<<<<<<< HEAD def _has_advanced_indexing(index): """Check if there's any advanced indexing""" return any( @@ -289,6 +290,8 @@ def unsqueeze_fn(x): return index, lambda x: x, lambda x: x +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Used to indicate that a parameter is unspecified (as opposed to explicitly # `None`) class _Unspecified: @@ -588,25 +591,36 @@ def neg_step(i, s): index = neg_step(0, index) index = _util.ndarrays_to_tensors(index) index = _upcast_int_indices(index) +<<<<<<< HEAD # Apply NumPy-compatible indexing conversion index = _numpy_compatible_indexing(index) # Apply NumPy-compatible empty ellipsis behavior index, maybe_squeeze, _ = _numpy_empty_ellipsis_patch(index, tensor.ndim) return maybe_squeeze(ndarray(tensor.__getitem__(index))) +======= + return ndarray(tensor.__getitem__(index)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __setitem__(self, index, value): index = _util.ndarrays_to_tensors(index) index = _upcast_int_indices(index) +<<<<<<< HEAD # Apply NumPy-compatible indexing conversion index = _numpy_compatible_indexing(index) # Apply NumPy-compatible empty ellipsis behavior index, _, maybe_unsqueeze = _numpy_empty_ellipsis_patch(index, self.tensor.ndim) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not _dtypes_impl.is_scalar(value): value = normalize_array_like(value) value = _util.cast_if_needed(value, self.tensor.dtype) +<<<<<<< HEAD return self.tensor.__setitem__(index, maybe_unsqueeze(value)) +======= + return self.tensor.__setitem__(index, value) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) take = _funcs.take put = _funcs.put diff --git a/torch/_numpy/_normalizations.py b/torch/_numpy/_normalizations.py index 82cdb2b0b11b3..0ca9e3b651642 100644 --- a/torch/_numpy/_normalizations.py +++ b/torch/_numpy/_normalizations.py @@ -1,7 +1,12 @@ # mypy: ignore-errors +<<<<<<< HEAD """ "Normalize" arguments: convert array_likes to tensors, dtypes to torch dtypes and so on.""" +======= +""" "Normalize" arguments: convert array_likes to tensors, dtypes to torch dtypes and so on. +""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from __future__ import annotations import functools @@ -43,12 +48,17 @@ NotImplementedType = typing.TypeVar("NotImplementedType") +<<<<<<< HEAD def normalize_array_like(x, parm=None): # codespell:ignore +======= +def normalize_array_like(x, parm=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from ._ndarray import asarray return asarray(x).tensor +<<<<<<< HEAD def normalize_array_like_or_scalar(x, parm=None): # codespell:ignore if _dtypes_impl.is_scalar_or_symbolic(x): return x @@ -72,6 +82,31 @@ def normalize_seq_array_like(x, parm=None): # codespell:ignore def normalize_dtype(dtype, parm=None): # codespell:ignore +======= +def normalize_array_like_or_scalar(x, parm=None): + if _dtypes_impl.is_scalar_or_symbolic(x): + return x + return normalize_array_like(x, parm) + + +def normalize_optional_array_like_or_scalar(x, parm=None): + if x is None: + return None + return normalize_array_like_or_scalar(x, parm) + + +def normalize_optional_array_like(x, parm=None): + # This explicit normalizer is needed because otherwise normalize_array_like + # does not run for a parameter annotated as Optional[ArrayLike] + return None if x is None else normalize_array_like(x, parm) + + +def normalize_seq_array_like(x, parm=None): + return tuple(normalize_array_like(value) for value in x) + + +def normalize_dtype(dtype, parm=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # cf _decorators.dtype_to_torch torch_dtype = None if dtype is not None: @@ -80,6 +115,7 @@ def normalize_dtype(dtype, parm=None): # codespell:ignore return torch_dtype +<<<<<<< HEAD def normalize_not_implemented(arg, parm): # codespell:ignore if arg != parm.default: # codespell:ignore raise NotImplementedError( @@ -88,6 +124,14 @@ def normalize_not_implemented(arg, parm): # codespell:ignore def normalize_axis_like(arg, parm=None): # codespell:ignore +======= +def normalize_not_implemented(arg, parm): + if arg != parm.default: + raise NotImplementedError(f"'{parm.name}' parameter is not supported.") + + +def normalize_axis_like(arg, parm=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from ._ndarray import ndarray if isinstance(arg, ndarray): @@ -95,7 +139,11 @@ def normalize_axis_like(arg, parm=None): # codespell:ignore return arg +<<<<<<< HEAD def normalize_ndarray(arg, parm=None): # codespell:ignore +======= +def normalize_ndarray(arg, parm=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # check the arg is an ndarray, extract its tensor attribute if arg is None: return arg @@ -103,11 +151,19 @@ def normalize_ndarray(arg, parm=None): # codespell:ignore from ._ndarray import ndarray if not isinstance(arg, ndarray): +<<<<<<< HEAD raise TypeError(f"'{parm.name}' must be an array") # codespell:ignore return arg.tensor def normalize_outarray(arg, parm=None): # codespell:ignore +======= + raise TypeError(f"'{parm.name}' must be an array") + return arg.tensor + + +def normalize_outarray(arg, parm=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # almost normalize_ndarray, only return the array, not its tensor if arg is None: return arg @@ -119,11 +175,19 @@ def normalize_outarray(arg, parm=None): # codespell:ignore arg = ndarray(arg) if not isinstance(arg, ndarray): +<<<<<<< HEAD raise TypeError(f"'{parm.name}' must be an array") # codespell:ignore return arg def normalize_casting(arg, parm=None): # codespell:ignore +======= + raise TypeError(f"'{parm.name}' must be an array") + return arg + + +def normalize_casting(arg, parm=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if arg not in ["no", "equiv", "safe", "same_kind", "unsafe"]: raise ValueError( f"casting must be one of 'no', 'equiv', 'safe', 'same_kind', or 'unsafe' (got '{arg}')" @@ -147,10 +211,17 @@ def normalize_casting(arg, parm=None): # codespell:ignore } +<<<<<<< HEAD def maybe_normalize(arg, parm): # codespell:ignore """Normalize arg if a normalizer is registered.""" normalizer = normalizers.get(parm.annotation, None) # codespell:ignore return normalizer(arg, parm) if normalizer else arg # codespell:ignore +======= +def maybe_normalize(arg, parm): + """Normalize arg if a normalizer is registered.""" + normalizer = normalizers.get(parm.annotation, None) + return normalizer(arg, parm) if normalizer else arg +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # ### Return value helpers ### @@ -217,8 +288,13 @@ def wrapped(*args, **kwds): # NB: extra unknown arguments: pass through, will raise in func(*args) below args = ( tuple( +<<<<<<< HEAD maybe_normalize(arg, parm) # codespell:ignore for arg, parm in zip(args, params.values()) # codespell:ignore +======= + maybe_normalize(arg, parm) + for arg, parm in zip(args, params.values()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) + args[len(params.values()) :] ) diff --git a/torch/_numpy/_reductions_impl.py b/torch/_numpy/_reductions_impl.py index 4afc217ebd4b7..8ac3001447c78 100644 --- a/torch/_numpy/_reductions_impl.py +++ b/torch/_numpy/_reductions_impl.py @@ -1,11 +1,18 @@ # mypy: ignore-errors +<<<<<<< HEAD """Implementation of reduction operations, to be wrapped into arrays, dtypes etc +======= +""" Implementation of reduction operations, to be wrapped into arrays, dtypes etc +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) in the 'public' layer. Anything here only deals with torch objects, e.g. "dtype" is a torch.dtype instance etc """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from __future__ import annotations import functools diff --git a/torch/_numpy/_util.py b/torch/_numpy/_util.py index fdb1736a1d0f7..ab59172bb996a 100644 --- a/torch/_numpy/_util.py +++ b/torch/_numpy/_util.py @@ -1,6 +1,11 @@ # mypy: ignore-errors +<<<<<<< HEAD """Assorted utilities, which do not need anything other then torch and stdlib.""" +======= +"""Assorted utilities, which do not need anything other then torch and stdlib. +""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import operator @@ -204,7 +209,11 @@ def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0): Notes ----- +<<<<<<< HEAD This is almost a "tensor_like" coercive function. Does not handle wrapper +======= + This is almost a "tensor_like" coersion function. Does not handle wrapper +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ndarrays (those should be handled in the ndarray-aware layer prior to invoking this function). """ diff --git a/torch/_numpy/random.py b/torch/_numpy/random.py index a3d4a1c73241f..00f322aa24a5a 100644 --- a/torch/_numpy/random.py +++ b/torch/_numpy/random.py @@ -7,7 +7,10 @@ Q: default dtype is float64 in numpy """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from __future__ import annotations import functools diff --git a/torch/_numpy/testing/utils.py b/torch/_numpy/testing/utils.py index cd0d33893ac28..c49186f306c7c 100644 --- a/torch/_numpy/testing/utils.py +++ b/torch/_numpy/testing/utils.py @@ -4,7 +4,10 @@ Utility function to facilitate testing. """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import contextlib import gc import operator @@ -168,7 +171,11 @@ def assert_equal(actual, desired, err_msg="", verbose=True): Examples -------- +<<<<<<< HEAD >>> np.testing.assert_equal([4, 5], [4, 6]) +======= + >>> np.testing.assert_equal([4,5], [4,6]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Traceback (most recent call last): ... AssertionError: @@ -299,12 +306,17 @@ def print_assert_equal(test_string, actual, desired): Examples -------- +<<<<<<< HEAD >>> np.testing.print_assert_equal( ... "Test XYZ of func xyz", [0, 1], [0, 1] ... ) # doctest: +SKIP >>> np.testing.print_assert_equal( ... "Test XYZ of func xyz", [0, 1], [0, 2] ... ) # doctest: +SKIP +======= + >>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 1]) # doctest: +SKIP + >>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 2]) # doctest: +SKIP +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Traceback (most recent call last): ... AssertionError: Test XYZ of func xyz failed @@ -382,9 +394,14 @@ def assert_almost_equal(actual, desired, decimal=7, err_msg="", verbose=True): ACTUAL: 2.3333333333333 DESIRED: 2.33333334 +<<<<<<< HEAD >>> assert_almost_equal( ... np.array([1.0, 2.3333333333333]), np.array([1.0, 2.33333334]), decimal=9 ... ) +======= + >>> assert_almost_equal(np.array([1.0,2.3333333333333]), + ... np.array([1.0,2.33333334]), decimal=9) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Traceback (most recent call last): ... AssertionError: @@ -493,6 +510,7 @@ def assert_approx_equal(actual, desired, significant=7, err_msg="", verbose=True Examples -------- +<<<<<<< HEAD >>> np.testing.assert_approx_equal( ... 0.12345677777777e-20, 0.1234567e-20 ... ) # doctest: +SKIP @@ -506,6 +524,13 @@ def assert_approx_equal(actual, desired, significant=7, err_msg="", verbose=True ... 0.12345672e-20, # doctest: +SKIP ... significant=8, ... ) +======= + >>> np.testing.assert_approx_equal(0.12345677777777e-20, 0.1234567e-20) # doctest: +SKIP + >>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345671e-20, # doctest: +SKIP + ... significant=8) + >>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345672e-20, # doctest: +SKIP + ... significant=8) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Traceback (most recent call last): ... AssertionError: @@ -515,7 +540,11 @@ def assert_approx_equal(actual, desired, significant=7, err_msg="", verbose=True the evaluated condition that raises the exception is +<<<<<<< HEAD >>> abs(0.12345670e-20 / 1e-21 - 0.12345672e-20 / 1e-21) >= 10 ** -(8 - 1) +======= + >>> abs(0.12345670e-20/1e-21 - 0.12345672e-20/1e-21) >= 10**-(8-1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) True """ @@ -790,16 +819,27 @@ def assert_array_equal(x, y, err_msg="", verbose=True, *, strict=False): -------- The first assert does not raise an exception: +<<<<<<< HEAD >>> np.testing.assert_array_equal( ... [1.0, 2.33333, np.nan], [np.exp(0), 2.33333, np.nan] ... ) +======= + >>> np.testing.assert_array_equal([1.0,2.33333,np.nan], + ... [np.exp(0),2.33333, np.nan]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Use `assert_allclose` or one of the nulp (number of floating point values) functions for these cases instead: +<<<<<<< HEAD >>> np.testing.assert_allclose( ... [1.0, np.pi, np.nan], [1, np.sqrt(np.pi) ** 2, np.nan], rtol=1e-10, atol=0 ... ) +======= + >>> np.testing.assert_allclose([1.0,np.pi,np.nan], + ... [1, np.sqrt(np.pi)**2, np.nan], + ... rtol=1e-10, atol=0) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) As mentioned in the Notes section, `assert_array_equal` has special handling for scalars. Here the test checks that each value in `x` is 3: @@ -824,7 +864,11 @@ def assert_array_equal(x, y, err_msg="", verbose=True, *, strict=False): The `strict` parameter also ensures that the array data types match: >>> x = np.array([2, 2, 2]) +<<<<<<< HEAD >>> y = np.array([2.0, 2.0, 2.0], dtype=np.float32) +======= + >>> y = np.array([2., 2., 2.], dtype=np.float32) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> np.testing.assert_array_equal(x, y, strict=True) Traceback (most recent call last): ... @@ -896,11 +940,19 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg="", verbose=True): -------- the first assert does not raise an exception +<<<<<<< HEAD >>> np.testing.assert_array_almost_equal([1.0, 2.333, np.nan], [1.0, 2.333, np.nan]) >>> np.testing.assert_array_almost_equal( ... [1.0, 2.33333, np.nan], [1.0, 2.33339, np.nan], decimal=5 ... ) +======= + >>> np.testing.assert_array_almost_equal([1.0,2.333,np.nan], + ... [1.0,2.333,np.nan]) + + >>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan], + ... [1.0,2.33339,np.nan], decimal=5) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Traceback (most recent call last): ... AssertionError: @@ -912,9 +964,14 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg="", verbose=True): x: torch.ndarray([1.0000, 2.3333, nan], dtype=float64) y: torch.ndarray([1.0000, 2.3334, nan], dtype=float64) +<<<<<<< HEAD >>> np.testing.assert_array_almost_equal( ... [1.0, 2.33333, np.nan], [1.0, 2.33333, 5], decimal=5 ... ) +======= + >>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan], + ... [1.0,2.33333, 5], decimal=5) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Traceback (most recent call last): ... AssertionError: @@ -1070,8 +1127,13 @@ def assert_string_equal(actual, desired): Examples -------- +<<<<<<< HEAD >>> np.testing.assert_string_equal("abc", "abc") # doctest: +SKIP >>> np.testing.assert_string_equal("abc", "abcd") # doctest: +SKIP +======= + >>> np.testing.assert_string_equal('abc', 'abc') # doctest: +SKIP + >>> np.testing.assert_string_equal('abc', 'abcd') # doctest: +SKIP +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Traceback (most recent call last): File "", line 1, in ... @@ -1357,11 +1419,19 @@ def assert_array_almost_equal_nulp(x, y, nulp=1): Examples -------- +<<<<<<< HEAD >>> x = np.array([1.0, 1e-10, 1e-20]) >>> eps = np.finfo(x.dtype).eps >>> np.testing.assert_array_almost_equal_nulp(x, x * eps / 2 + x) # doctest: +SKIP >>> np.testing.assert_array_almost_equal_nulp(x, x * eps + x) # doctest: +SKIP +======= + >>> x = np.array([1., 1e-10, 1e-20]) + >>> eps = np.finfo(x.dtype).eps + >>> np.testing.assert_array_almost_equal_nulp(x, x*eps/2 + x) # doctest: +SKIP + + >>> np.testing.assert_array_almost_equal_nulp(x, x*eps + x) # doctest: +SKIP +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Traceback (most recent call last): ... AssertionError: X and Y are not equal to 1 ULP (max is 2) @@ -1420,7 +1490,11 @@ def assert_array_max_ulp(a, b, maxulp=1, dtype=None): Examples -------- +<<<<<<< HEAD >>> a = np.linspace(0.0, 1.0, 100) +======= + >>> a = np.linspace(0., 1., 100) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> res = np.testing.assert_array_max_ulp(a, np.arcsin(np.sin(a))) # doctest: +SKIP """ @@ -1578,7 +1652,11 @@ def assert_warns(warning_class, *args, **kwargs): >>> import warnings >>> def deprecated_func(num): ... warnings.warn("Please upgrade", DeprecationWarning) +<<<<<<< HEAD ... return num * num +======= + ... return num*num +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> with np.testing.assert_warns(DeprecationWarning): ... assert deprecated_func(4) == 16 >>> # or passing a func @@ -1679,6 +1757,7 @@ def inp(): yield out, inp(), ufmt % (o, o, s, dtype, "out of place") d = inp() yield d, d, ufmt % (o, o, s, dtype, "in place") +<<<<<<< HEAD yield ( out[1:], inp()[:-1], @@ -1702,6 +1781,21 @@ def inp(): dtype, "out of place", ), +======= + yield out[1:], inp()[:-1], ufmt % ( + o + 1, + o, + s - 1, + dtype, + "out of place", + ) + yield out[:-1], inp()[1:], ufmt % ( + o, + o + 1, + s - 1, + dtype, + "out of place", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) yield inp()[:-1], inp()[1:], ufmt % (o, o + 1, s - 1, dtype, "aliased") yield inp()[1:], inp()[:-1], ufmt % (o + 1, o, s - 1, dtype, "aliased") @@ -1717,6 +1811,7 @@ def inp1(): yield d, d, inp2(), bfmt % (o, o, o, s, dtype, "in place1") d = inp2() yield d, inp1(), d, bfmt % (o, o, o, s, dtype, "in place2") +<<<<<<< HEAD yield ( out[1:], inp1()[:-1], @@ -1800,6 +1895,55 @@ def inp1(): dtype, "aliased", ), +======= + yield out[1:], inp1()[:-1], inp2()[:-1], bfmt % ( + o + 1, + o, + o, + s - 1, + dtype, + "out of place", + ) + yield out[:-1], inp1()[1:], inp2()[:-1], bfmt % ( + o, + o + 1, + o, + s - 1, + dtype, + "out of place", + ) + yield out[:-1], inp1()[:-1], inp2()[1:], bfmt % ( + o, + o, + o + 1, + s - 1, + dtype, + "out of place", + ) + yield inp1()[1:], inp1()[:-1], inp2()[:-1], bfmt % ( + o + 1, + o, + o, + s - 1, + dtype, + "aliased", + ) + yield inp1()[:-1], inp1()[1:], inp2()[:-1], bfmt % ( + o, + o + 1, + o, + s - 1, + dtype, + "aliased", + ) + yield inp1()[:-1], inp1()[:-1], inp2()[1:], bfmt % ( + o, + o, + o + 1, + s - 1, + dtype, + "aliased", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @@ -1880,10 +2024,16 @@ class clear_and_catch_warnings(warnings.catch_warnings): -------- >>> import warnings >>> with np.testing.clear_and_catch_warnings( # doctest: +SKIP +<<<<<<< HEAD ... modules=[np.core.fromnumeric] ... ): ... warnings.simplefilter("always") ... warnings.filterwarnings("ignore", module="np.core.fromnumeric") +======= + ... modules=[np.core.fromnumeric]): + ... warnings.simplefilter('always') + ... warnings.filterwarnings('ignore', module='np.core.fromnumeric') +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ... # do something that raises a warning but ignore those in ... # np.core.fromnumeric """ @@ -1981,8 +2131,11 @@ class suppress_warnings: sup = np.testing.suppress_warnings() sup.filter(module=np.ma.core) # module must match exactly +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @sup def some_function(): # do something which causes a warning in np.ma.core diff --git a/torch/_ops.py b/torch/_ops.py index b351aa17dfa76..9d6ac281a1181 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -85,7 +85,11 @@ def __init__(self): # This table allows you to override the behavior of a particular # dispatch key to call a custom Python function, rather than the +<<<<<<< HEAD # ordinary C++ configured behavior. This is the raison d'etre of # codespell:ignore +======= + # ordinary C++ configured behavior. This is the raison d'etre of +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Python dispatcher: to let you program the dispatcher from Python # in case you need something unusual, and don't want to clobber # the existing registrations using the Python operator registration @@ -267,7 +271,10 @@ def resolve_key(op: OperatorBase, k: DispatchKey): # type: ignore[valid-type] DispatchKey.BackendSelect, DispatchKey.AutocastCPU, # type: ignore[attr-defined] DispatchKey.AutocastCUDA, # type: ignore[attr-defined] +<<<<<<< HEAD DispatchKey.AutocastXPU, # type: ignore[attr-defined] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] @@ -298,7 +305,11 @@ def __init__(self, name, *, cacheable=False): self.fallthrough(dispatch_key) # [NOTE] We have to register pre-dispatch key implementation +<<<<<<< HEAD # because sometimes HOP use aot-dispatch tracing to detect certain +======= + # because sometimes HOP use aot-dispatch tracing to detect certaion +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # mutations. This is problematic when we are functionalizing HOP # during pre-dispatch because when the inner tracer starts, it will see # that PreDispatch key is still active. In that case, we just redispatch @@ -416,6 +427,7 @@ def check_overloaded(arg): # TODO(rzou): we should support torch_dispatch calling convention too. result = handler(mode, *args, **kwargs) else: +<<<<<<< HEAD if curr_mode.supports_higher_order_operators: with _pop_mode_temporarily() as mode: return curr_mode.__torch_dispatch__(self, [], args, kwargs) @@ -429,6 +441,12 @@ def check_overloaded(arg): f" {curr_mode}.__torch_dispatch__ or" f" returning NotImplemented when not supported." ) +======= + raise NotImplementedError( + f"There was no rule registered for HOP {self._name} and mode {curr_mode}. " + f"We recommend filing an issue." + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if result is not NotImplemented: return result @@ -467,12 +485,19 @@ def check_overloaded(arg): # All handlers returned NotImplemented raise TypeError( +<<<<<<< HEAD f"HigherOrderOperator '{self._name}' is not supported for the given input types. " f"This typically happens when using custom tensor types or dispatch modes that don't " f"have implementations for this operation.\n\n" f"Current mode: {curr_mode}\n" f"Input types: {[type(a).__name__ for a in overloaded_args]}\n\n" f"To fix this, can add support for '{self._name}' in {curr_mode}'s __torch_dispatch__\n" +======= + f"Multiple dispatch failed for {self._name}. There was no registered that " + f"did not return NotImplemented. Use HOP.py_impl to register some. " + f"Tried mode: {curr_mode}) and subclasses: " + f"{[type(a) for a in overloaded_args]}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) functionality_key = torch._C._to_functionality_key(dispatch_key) # type: ignore[attr-defined] @@ -1116,7 +1141,11 @@ def _dispatch_in_python( f" but no python implementation is found." f" Please file an issue on this when you encounter this error." f" This error can happen when you export or compile the model." +<<<<<<< HEAD f" It can still happen even if a C++ implementation for {dispatch_key}. " +======= + f" It can still happpen even if a C++ implementation for {dispatch_key}. " +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f" has been registered. That's because FakeScriptObject purely lives in python and cannot work " f" with a C++ implementation." ) @@ -1264,7 +1293,11 @@ def overloads(self): def _call_overload_packet_from_python( op: OpOverloadPacket[_P, _T], *args: _P.args, **kwargs: _P.kwargs ) -> _T: +<<<<<<< HEAD # Reuse the torch function handling logic in cpp +======= + # Re-use the torch function handling logic in cpp +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch_function_called, ret = torch._C._maybe_call_torch_function_for_op_packet( op, *args, **kwargs ) @@ -1479,15 +1512,25 @@ def load_library(self, path): Args: path (str): A path to a shared library to load. """ +<<<<<<< HEAD +======= + if torch._running_with_deploy(): + return + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) path = _utils_internal.resolve_library_path(path) with dl_open_guard(): # Import the shared library into the process, thus running its # static (global) initialization code in order to register custom # operators with the JIT. +<<<<<<< HEAD try: ctypes.CDLL(path) except Exception as e: raise OSError(f"Could not load this library: {path}") from e +======= + ctypes.CDLL(path) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.loaded_libraries.add(path) diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py index bb26bbb508bd6..203b788db12d8 100644 --- a/torch/_prims/__init__.py +++ b/torch/_prims/__init__.py @@ -302,7 +302,11 @@ def _backend_select_impl(*args, **kwargs): else: return _prim_impl(*args, **kwargs) +<<<<<<< HEAD name = schema.split("(", maxsplit=1)[0] +======= + name = schema.split("(")[0] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) schema = schema[len(name) :] # register non-functional ops with old custom ops API @@ -2174,7 +2178,11 @@ def _resize_aten(a: Tensor, shape: ShapeType) -> Tensor: _resize_doc = """ Gives a tensor with no elements a new shape, returning the modified tensor. +<<<<<<< HEAD The tensor's strides are contiguous and its values are uninitialized. +======= + The tensor's strides are contiguous and its values are unitialized. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ # TODO: review support arbitrary resizes @@ -2513,11 +2521,15 @@ def _full_aten( ) -> Tensor: # Note that Mypy thinks torch.full can't accept a complex fill_value return torch.full( +<<<<<<< HEAD shape, fill_value, dtype=dtype, device=device, requires_grad=requires_grad, # type: ignore[arg-type] +======= + shape, fill_value, dtype=dtype, device=device, requires_grad=requires_grad # type: ignore[arg-type] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @@ -2560,11 +2572,15 @@ def _full_like_aten( ) -> Tensor: # Note that Mypy thinks torch.full can't accept a complex fill_value return torch.full_like( +<<<<<<< HEAD a, fill_value, dtype=dtype, device=device, requires_grad=requires_grad, # type: ignore[arg-type] +======= + a, fill_value, dtype=dtype, device=device, requires_grad=requires_grad # type: ignore[arg-type] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) diff --git a/torch/_prims/debug_prims.py b/torch/_prims/debug_prims.py index d52462815229b..a9ff1dc3db85d 100644 --- a/torch/_prims/debug_prims.py +++ b/torch/_prims/debug_prims.py @@ -1,5 +1,10 @@ +<<<<<<< HEAD import contextlib from collections.abc import Generator, Sequence +======= +# mypy: allow-untyped-defs +import contextlib +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from typing import Optional import torch @@ -10,7 +15,11 @@ @contextlib.contextmanager +<<<<<<< HEAD def load_tensor_reader(loc: str) -> Generator[None, None, None]: +======= +def load_tensor_reader(loc): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) global LOAD_TENSOR_READER assert LOAD_TENSOR_READER is None # load_tensor is an "op", and we will play merry hell on @@ -26,13 +35,18 @@ def load_tensor_reader(loc: str) -> Generator[None, None, None]: LOAD_TENSOR_READER = None +<<<<<<< HEAD def register_debug_prims() -> None: +======= +def register_debug_prims(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.library.define( "debugprims::load_tensor", "(str name, int[] size, int[] stride, *, ScalarType dtype, Device device) -> Tensor", ) @torch.library.impl("debugprims::load_tensor", "BackendSelect") +<<<<<<< HEAD def load_tensor_factory( name: str, size: Sequence[int], @@ -40,6 +54,9 @@ def load_tensor_factory( dtype: torch.dtype, device: torch.device, ) -> torch.Tensor: +======= + def load_tensor_factory(name, size, stride, dtype, device): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if LOAD_TENSOR_READER is None: from torch._dynamo.testing import rand_strided @@ -56,5 +73,9 @@ def load_tensor_factory( # Unlike the other properties, we will do coercions for dtype # mismatch if r.dtype != dtype: +<<<<<<< HEAD r = clone_input(r, dtype=dtype) # type: ignore[no-untyped-call] +======= + r = clone_input(r, dtype=dtype) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return r diff --git a/torch/_prims/rng_prims.py b/torch/_prims/rng_prims.py index e6ed4a4e3ea6b..10baa175cf810 100644 --- a/torch/_prims/rng_prims.py +++ b/torch/_prims/rng_prims.py @@ -1,5 +1,9 @@ # mypy: allow-untyped-defs +<<<<<<< HEAD from typing import cast, Optional +======= +from typing import Optional +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch.utils._pytree as pytree @@ -69,10 +73,19 @@ def philox_rand_offset( curand4_engine_calls = 4 device_property = torch.cuda.get_device_properties(torch.cuda.current_device()) blocks_per_sm = device_property.max_threads_per_multi_processor // block_size +<<<<<<< HEAD num = cast(int, numel) grid_size = (num + block_size - 1) // block_size grid_size = min(grid_size, device_property.multi_processor_count * blocks_per_sm) return ((num - 1) // (block_size * grid_size * unroll) + 1) * curand4_engine_calls +======= + grid_size = (numel + block_size - 1) // block_size + grid_size = min(grid_size, device_property.multi_processor_count * blocks_per_sm) + offset = ( + (numel - 1) // (block_size * grid_size * unroll) + 1 + ) * curand4_engine_calls + return offset +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def register_philox_rand(): @@ -340,9 +353,15 @@ def impl_cuda(op, *args, rng_state=None, **kwargs): @graphsafe_run_with_rng_state.py_impl(DispatchKey.BackendSelect) def impl_backend_select(op, *args, rng_state=None, **kwargs): device = get_device(args, kwargs) +<<<<<<< HEAD assert device == "cuda", ( f"GraphSafe RNG operations only supported for CUDA, got {device}" ) +======= + assert ( + device == "cuda" + ), f"GraphSafe RNG operations only supported for CUDA, got {device}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return impl_cuda(op, *args, rng_state=rng_state, **kwargs) @graphsafe_run_with_rng_state.py_impl(FakeTensorMode) diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py index 91b0cc1f68d47..63a1230af3de0 100644 --- a/torch/_prims_common/__init__.py +++ b/torch/_prims_common/__init__.py @@ -5,7 +5,11 @@ import typing import warnings from collections.abc import Sequence +<<<<<<< HEAD from contextlib import AbstractContextManager, nullcontext +======= +from contextlib import nullcontext +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from enum import Enum from functools import reduce from typing import ( @@ -33,6 +37,7 @@ import sympy class _WorksWithInt(typing.Protocol): +<<<<<<< HEAD def __add__(self, other: Any) -> typing.Self: ... def __radd__(self, other: Any) -> typing.Self: ... @@ -40,6 +45,19 @@ def __radd__(self, other: Any) -> typing.Self: ... def __mul__(self, other: Any) -> typing.Self: ... def __rmul__(self, other: Any) -> typing.Self: ... +======= + def __add__(self, other: Any) -> typing.Self: + ... + + def __radd__(self, other: Any) -> typing.Self: + ... + + def __mul__(self, other: Any) -> typing.Self: + ... + + def __rmul__(self, other: Any) -> typing.Self: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _IntLikeT = TypeVar("_IntLikeT", bound=_WorksWithInt) @@ -107,18 +125,36 @@ def __rmul__(self, other: Any) -> typing.Self: ... def same_shape(a: ShapeType, b: ShapeType, *, allow_rhs_unbacked=False) -> bool: +<<<<<<< HEAD from torch.fx.experimental.symbolic_shapes import guard_or_true +======= + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if len(a) != len(b): return False for x, y in zip(a, b): if allow_rhs_unbacked: +<<<<<<< HEAD if isinstance(y, torch.SymInt): continue # if we do not know, then they are not the same. if guard_or_true(x != y): +======= + # TODO: We should check that the symbols are consistent + # with each other + if isinstance(y, torch.SymInt): + continue + # NB: Naively, you would not expect to have to do an oblivious guard + # here because there is seemingly no broadcasting here, but in fact we + # use this in some situations to determine if we need to do an expand + # on the tensor because they don't line up, so you can definitely end + # up trying to prove u0 != 1 in this situation. See + # python test/test_proxy_tensor.py -k test_cumsum_unbacked + if guard_size_oblivious(x != y): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return False return True @@ -248,6 +284,7 @@ def check_all_strides( return _check_strides_helper(a, b, only_cuda=only_cuda, significant_only=False) +<<<<<<< HEAD def check_contiguous_sizes_strides(sizes, strides, false_if_dde=False): """ Performs an equality check between actual stride & expected stride (based on composed sizes), @@ -292,6 +329,8 @@ def eval_eager(x): return True +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This function is equivalent to compute_contiguous() from TensorImpl.cpp def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool: """ @@ -302,17 +341,48 @@ def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool: """ from torch.fx.experimental.symbolic_shapes import ( guard_or_false, +<<<<<<< HEAD + guard_size_oblivious, + ) + + maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious +======= + guard_or_true, guard_size_oblivious, + is_nested_int, ) maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious + maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if maybe_guard_or_false(a.numel() < 2): return True +<<<<<<< HEAD return check_contiguous_sizes_strides( a.shape, a.stride(), false_if_dde=false_if_dde ) +======= + expected_stride = 1 + for x, y in reversed(tuple(zip(a.shape, a.stride()))): + # Skips checking strides when a dimension has length 1. + if maybe_guard_or_false(x == 1): + continue + + if maybe_guard_or_true(y != expected_stride): + return False + + # if x is 0 then a is contiguous anyway. So in the check above for non-contiguity condition we can + # can assume x is not 0 in expected_stride equation. This make the check consistent with + # make_contiguous_strides_for. If we make a tensor and used strides from make_contiguous_strides_for + # and then called definitely_contiguous we should get True. + expected_stride *= ( + x if is_nested_int(x) else sym_max(x, 1) + ) # type:ignore[assignment] + + return True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This function is equivalent to compute_channels_last_contiguous_2d() in TensorImpl.cpp @@ -321,6 +391,7 @@ def is_channels_last_contiguous_2d(a: Tensor, false_if_dde=False) -> bool: if a.ndim != 4: return False +<<<<<<< HEAD from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true def eval_eager(x): @@ -328,6 +399,16 @@ def eval_eager(x): maybe_guard_or_false = guard_or_false if false_if_dde else eval_eager maybe_guard_or_true = guard_or_true if false_if_dde else eval_eager +======= + from torch.fx.experimental.symbolic_shapes import ( + guard_or_false, + guard_or_true, + guard_size_oblivious, + ) + + maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious + maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) expected_stride = 1 for idx in (1, 3, 2, 0): @@ -349,6 +430,7 @@ def is_channels_last_contiguous_3d(a: Tensor, false_if_dde=False) -> bool: if a.ndim != 5: return False +<<<<<<< HEAD from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true def eval_eager(x): @@ -356,6 +438,16 @@ def eval_eager(x): maybe_guard_or_false = guard_or_false if false_if_dde else eval_eager maybe_guard_or_true = guard_or_true if false_if_dde else eval_eager +======= + from torch.fx.experimental.symbolic_shapes import ( + guard_or_false, + guard_or_true, + guard_size_oblivious, + ) + + maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious + maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) expected_stride = 1 for idx in (1, 4, 3, 2, 0): @@ -405,22 +497,38 @@ def is_contiguous_for_memory_format( # type: ignore[return] ) +<<<<<<< HEAD def is_contiguous_or_false(a: TensorLikeType) -> bool: +======= +def definitely_contiguous(a: TensorLikeType) -> bool: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return is_contiguous(a, false_if_dde=True) # similar to is_channels_last_contiguous_2d but return false on data dependency. +<<<<<<< HEAD def is_channels_last_contiguous_or_false_2d(a: Tensor) -> bool: +======= +def definitely_channels_last_contiguous_2d(a: Tensor) -> bool: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return is_channels_last_contiguous_2d(a, false_if_dde=True) # similar to is_channels_last_contiguous_3d but return false on data dependency. +<<<<<<< HEAD def is_channels_last_contiguous_or_false_3d(a: Tensor) -> bool: +======= +def definitely_channels_last_contiguous_3d(a: Tensor) -> bool: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return is_channels_last_contiguous_3d(a, false_if_dde=True) # similar to is_contiguous_for_memory_format but return false on data dependency. +<<<<<<< HEAD def is_contiguous_for_memory_format_or_false( # type: ignore[return] +======= +def definitely_contiguous_for_memory_format( # type: ignore[return] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) a: Tensor, *, memory_format: torch.memory_format ) -> bool: return is_contiguous_for_memory_format( @@ -446,6 +554,7 @@ def is_channels_last_contiguous(a: Tensor) -> bool: # similar to is_channels_last_contiguous but return false on data dependency. +<<<<<<< HEAD def is_channels_last_contiguous_or_false(a: Tensor) -> bool: return is_channels_last_contiguous_or_false_2d( a @@ -473,6 +582,37 @@ def _is_non_overlapping_and_dense_or_false(sizes, strides) -> bool: # non-overlapping and "dense" if their stride is one if len(sizes) == 1: return guard_or_false(strides[0] == 1) +======= +def definitely_channels_last_contiguous(a: Tensor) -> bool: + return definitely_channels_last_contiguous_2d( + a + ) or definitely_channels_last_contiguous_3d(a) + + +def is_non_overlapping_and_dense(a: Tensor) -> bool: + """ + True when a tensor is non-overlapping and dense. + + A tensor is non-overlapping and dense when there exists a permutation of + its dimensions that is contiguous. + """ + + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if a.is_sparse: + return False + + # Short-circuits if the tensor is already contiguous or channels-last contiguous + if definitely_contiguous(a) or definitely_channels_last_contiguous(a): + return True + + # The following is equivalent to compute_non_overlapping_and_dense in TensorImpl.cpp + + # Short-circuits for tensors of rank one, which are + # non-overlapping and "dense" if their stride is one + if a.ndim == 1: + return a.stride()[0] == 1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Checks that there exists a permutation of the strides s.t. the tensor would be contiguous # Sorts (length, stride) pairs by stride @@ -485,6 +625,7 @@ class K(NamedTuple): stride: int def __lt__(self, other): +<<<<<<< HEAD # for backed symbols, this is practically a < operation # for unbacked, we return True if < is statically known, # then try to answer this symbolically, with stride ordering semantics @@ -523,6 +664,35 @@ def is_non_overlapping_and_dense(a: Tensor) -> bool: return False return _is_non_overlapping_and_dense_or_false(a.shape, a.stride()) +======= + return guard_size_oblivious(self.stride < other.stride) + + def __gt__(self, other): + return guard_size_oblivious(self.stride > other.stride) + + def __le__(self, other): + return guard_size_oblivious(self.stride <= other.stride) + + def __ge__(self, other): + return guard_size_oblivious(self.stride >= other.stride) + + def __eq__(self, other): + return guard_size_oblivious(self.stride == other.stride) + + lengths_and_strides = sorted(map(K, a.shape, a.stride())) + + expected_stride = 1 + for length, stride in lengths_and_strides: + if guard_size_oblivious(length == 1): + continue + + if guard_size_oblivious(stride != expected_stride): + return False + + expected_stride *= length + + return True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # NOTE: Based on the implementation in TensorIterator.cpp, but note that @@ -536,10 +706,14 @@ def is_non_overlapping_and_dense(a: Tensor) -> bool: def compute_elementwise_output_logical_to_physical_perm( *tensors, _skip_checks=False ) -> list[int]: +<<<<<<< HEAD from torch.fx.experimental.symbolic_shapes import ( guard_or_false, guard_size_oblivious, ) +======= + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not _skip_checks and len(tensors) == 0: msg = "Can't compute elementwise output strides for zero tensors!" @@ -573,6 +747,7 @@ def compute_elementwise_output_logical_to_physical_perm( is_contiguous = True is_channels_last = True for t in tensors: +<<<<<<< HEAD is_contiguous = is_contiguous and is_contiguous_for_memory_format_or_false( t, memory_format=torch.contiguous_format ) @@ -581,6 +756,13 @@ def compute_elementwise_output_logical_to_physical_perm( and is_contiguous_for_memory_format_or_false( t, memory_format=torch.channels_last ) +======= + is_contiguous = is_contiguous and definitely_contiguous_for_memory_format( + t, memory_format=torch.contiguous_format + ) + is_channels_last = is_channels_last and definitely_contiguous_for_memory_format( + t, memory_format=torch.channels_last +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if is_contiguous and not is_channels_last: @@ -595,11 +777,16 @@ def should_swap(idx_a, idx_b): for tensor in tensors: stride_a = tensor.stride()[idx_a] stride_b = tensor.stride()[idx_b] +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if guard_size_oblivious(stride_a == 0) or guard_size_oblivious( stride_b == 0 ): continue +<<<<<<< HEAD if guard_or_false(stride_a == stride_b): if guard_size_oblivious(shape[idx_a] > shape[idx_b]): return 1 @@ -612,6 +799,8 @@ def should_swap(idx_a, idx_b): elif guard_or_false(stride_b == 1): return 1 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if guard_size_oblivious(stride_a < stride_b): return -1 @@ -952,7 +1141,11 @@ def extract_shape(*args, allow_cpu_scalar_tensors: bool) -> Optional[ShapeType]: # Extracts dimensions that might be passed either as a list/tuple or as varargs. # A typical case is Tensor.permute . def extract_dims_from_varargs( +<<<<<<< HEAD dims: Union[DimsSequenceType, tuple[DimsSequenceType, ...]], +======= + dims: Union[DimsSequenceType, tuple[DimsSequenceType, ...]] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> DimsSequenceType: if dims and isinstance(dims[0], Sequence): assert len(dims) == 1 @@ -1274,7 +1467,11 @@ def get_higher_dtype( assert b is None or isinstance(b, (torch.dtype, TensorLike, Number)) def _extract_dtype( +<<<<<<< HEAD x: Optional[Union[torch.dtype, TensorLikeType, NumberType]], +======= + x: Optional[Union[torch.dtype, TensorLikeType, NumberType]] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Optional[torch.dtype]: if x is None: return None @@ -1492,7 +1689,11 @@ class RETURN_TYPE(Enum): # TODO: when NumberType contains the sym types, can simplify this def number_type( +<<<<<<< HEAD x: Union[NumberType, torch.SymInt, torch.SymFloat, torch.SymBool], +======= + x: Union[NumberType, torch.SymInt, torch.SymFloat, torch.SymBool] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> type: if isinstance(x, torch.SymInt): return int @@ -1748,7 +1949,13 @@ def make_contiguous_strides_for( strides = [] for l in reversed(shape): strides.append(multiplier) +<<<<<<< HEAD multiplier *= l if is_nested_int(l) else sym_max(l, 1) # type:ignore[assignment] +======= + multiplier *= ( + l if is_nested_int(l) else sym_max(l, 1) + ) # type:ignore[assignment] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) result = tuple(reversed(strides)) @@ -1898,9 +2105,13 @@ def compute_required_storage_length( >>> # xdoctest: +SKIP(failing) >>> t2 = torch.empty_strided((1, 2, 3), (5, 7, 11)) +<<<<<<< HEAD >>> size = compute_required_storage_length( ... t2.shape, t2.stride(), t2.storage_offset() ... ) +======= + >>> size = compute_required_storage_length(t2.shape, t2.stride(), t2.storage_offset()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> size == t.storage().size() True @@ -1910,6 +2121,7 @@ def compute_required_storage_length( >>> slice.storage().size() 100 +<<<<<<< HEAD >>> compute_required_storage_length( ... slice.shape, slice.stride(), slice.storage_offset() ... ) @@ -1922,6 +2134,16 @@ def compute_required_storage_length( # Note: we are unsafely assuming tensor is not empty here, without # runtime assertions. if guard_or_false(reduce(operator.mul, shape, 1) == 0): +======= + >>> compute_required_storage_length(slice.shape, slice.stride(), slice.storage_offset()) + 40 + + """ + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + # Short-circuits if the shape has no elements + if guard_size_oblivious(reduce(operator.mul, shape, 1) == 0): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return 0 max_offset = sum((x - 1) * y for x, y in zip(shape, strides)) @@ -2129,9 +2351,13 @@ def alert_not_deterministic(caller: str): class CUDARngStateHelper: @staticmethod +<<<<<<< HEAD def get_torch_state_as_tuple( fake_mode: AbstractContextManager[Any] = nullcontext(), ): +======= + def get_torch_state_as_tuple(fake_mode=nullcontext()): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not torch.cuda.is_available(): raise RuntimeError("CUDA not available") diff --git a/torch/_prims_common/wrappers.py b/torch/_prims_common/wrappers.py index e5e5b13f62c7d..37dacdcf3b9f1 100644 --- a/torch/_prims_common/wrappers.py +++ b/torch/_prims_common/wrappers.py @@ -94,7 +94,11 @@ class elementwise_type_promotion_wrapper: Takes two kwargs, type_promoting_args and type_promotion_kind. +<<<<<<< HEAD type_promoting_args must be a string Sequence specifying the argument names of all +======= + type_promoting_args must be a string Sequence specifiying the argument names of all +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) arguments that participate in type promotion (and should be type promoted). If the arg specifies a Sequence-type then every element of the Sequence will participate in type promotion. @@ -316,7 +320,12 @@ def maybe_check_copy_devices(out): and len(result) == len(out_names) # type: ignore[arg-type] ) or ( +<<<<<<< HEAD fn.__name__ == "unbind" and isinstance(result, (list, tuple)) # type: ignore[arg-type] +======= + fn.__name__ == "unbind" + and isinstance(result, (list, tuple)) # type: ignore[arg-type] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) ) # unbind_copy is a special case: see https://github.com/pytorch/pytorch/issues/130829 @@ -341,6 +350,7 @@ def maybe_check_copy_devices(out): assert isinstance(out, TensorLike) # These two operations are done in-place _maybe_resize_out( +<<<<<<< HEAD out, result.shape, # type: ignore[union-attr] maybe_compute_memory_format(result), @@ -350,6 +360,11 @@ def maybe_check_copy_devices(out): copy_to=out, exact_dtype=exact_dtype, ) +======= + out, result.shape, maybe_compute_memory_format(result) # type: ignore[union-attr] + ) + _safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype) # type: ignore[arg-type] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: if fn.__name__ != "unbind": assert isinstance(out, tuple) # type: ignore[arg-type] @@ -390,8 +405,12 @@ def maybe_check_copy_devices(out): params = sorted(params, key=lambda p: p.kind) _fn.__signature__ = inspect.Signature( # type: ignore[attr-defined] +<<<<<<< HEAD parameters=params, return_annotation=return_type, # type: ignore[arg-type] +======= + parameters=params, return_annotation=return_type # type: ignore[arg-type] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) _fn.__annotations__ = dict(getattr(fn, "__annotations__", {})) @@ -406,9 +425,13 @@ def maybe_check_copy_devices(out): # Add an indicator attribute that can be used in special cases # where having a function wrapped by `out_wrapper` is not desirable e.g. # jit +<<<<<<< HEAD _fn._torch_decompositions_out_wrapper = ( # type: ignore[attr-defined] f"This function is wrapped by {out_wrapper.__module__}.out_wrapper" ) +======= + _fn._torch_decompositions_out_wrapper = f"This function is wrapped by {out_wrapper.__module__}.out_wrapper" # type: ignore[attr-defined] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return _fn diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 783e440223796..442f573e250df 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -19,6 +19,11 @@ from torch import sym_float, sym_int from torch._prims_common import ( BoolLike, +<<<<<<< HEAD +======= + definitely_contiguous, + definitely_contiguous_for_memory_format, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DeviceLikeType, Dim, DimsSequenceType, @@ -28,8 +33,11 @@ FloatLike, FloatWithoutSymFloat, IntLike, +<<<<<<< HEAD is_contiguous_for_memory_format_or_false, is_contiguous_or_false, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) is_weakly_lesser_type, Number, NumberType, @@ -385,7 +393,11 @@ def handle_noncontiguous_outputs(input_tlist, output): def _broadcast_shapes(*_shapes): +<<<<<<< HEAD from torch.fx.experimental.symbolic_shapes import guard_or_false, is_nested_int +======= + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shapes = tuple( (x,) if isinstance(x, IntLike) else x @@ -396,12 +408,19 @@ def _broadcast_shapes(*_shapes): if len(shapes) == 0: return None +<<<<<<< HEAD for shape in shapes: if not isinstance(shape, Sequence): raise RuntimeError( "Input shapes should be of type ints, a tuple of ints, or a list of ints, got ", shape, ) +======= + # Type checking + # TODO: make common validations available as utils + for shape in shapes: + assert isinstance(shape, Sequence) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Computes common shape common_shape: list[Union[int, torch.SymInt]] = [ @@ -409,6 +428,7 @@ def _broadcast_shapes(*_shapes): ] * reduce(max, (len(shape) for shape in shapes)) for arg_idx, shape in enumerate(shapes): for idx in range(-1, -1 - len(shape), -1): +<<<<<<< HEAD # NB: handle nested ints specially to avoid invalid guarding on Ne(j0, 1). if is_nested_int(shape[idx]): # Broadcasting is allowed for (j0, 1) or (j0, j0); @@ -422,17 +442,24 @@ def _broadcast_shapes(*_shapes): continue if guard_or_false(common_shape[idx] == 1): +======= + if guard_size_oblivious(common_shape[idx] == 1): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if shape[idx] < 0: raise ValueError( "Attempting to broadcast a dimension with negative length!" ) common_shape[idx] = shape[idx] +<<<<<<< HEAD if not is_nested_int(shape[idx]) and guard_or_false(shape[idx] == 1): # broadcast case . continue else: # If broadcasting is undecided we pick non-broadcast path and add runtime assertion. +======= + elif guard_size_oblivious(shape[idx] != 1): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch._check( common_shape[idx] == shape[idx], lambda: f"Attempting to broadcast a dimension of length {shape[idx]} at {idx}! " @@ -449,6 +476,7 @@ def _maybe_broadcast(*args, preserve_cpu_scalar_tensors=True): *(t.shape if isinstance(t, TensorLike) else None for t in args) ) +<<<<<<< HEAD def should_expand(a: ShapeType, b: ShapeType) -> bool: from torch.fx.experimental.symbolic_shapes import ( guard_or_false, @@ -481,6 +509,8 @@ def should_expand(a: ShapeType, b: ShapeType) -> bool: return False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __maybe_broadcast(x, shape): if x is None: return None @@ -490,7 +520,11 @@ def __maybe_broadcast(x, shape): if preserve_cpu_scalar_tensors and utils.is_cpu_scalar_tensor(x): return x +<<<<<<< HEAD if should_expand(x.shape, common_shape): +======= + if not utils.same_shape(x.shape, common_shape): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return x.expand(common_shape) return x @@ -2289,21 +2323,32 @@ def _reduction( return result +<<<<<<< HEAD def _make_copy_from_view(fn, return_none_on_out_variant=False): +======= +def _make_copy_from_view(fn): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Given a view function (e.g. torch.diagonal) generates its copy variant (e.g. torch.diagonal_copy) """ aten_fn = getattr(aten, fn.__name__) annotations = getattr(fn, "__annotations__", {}) +<<<<<<< HEAD # view ops should not change dtypes, this ensures that the decomp path has # the same error checks as eager. fn = out_wrapper(exact_dtype=True)(aten_fn) +======= + fn = out_wrapper()(aten_fn) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @wraps(fn) def _fn(*args, out=None, **kwargs): result = fn(*args, out=out, **kwargs) +<<<<<<< HEAD if return_none_on_out_variant and out is not None: return None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if out is not None: return result @@ -2824,7 +2869,14 @@ def cat_compute_output_memory_format(inputs): utils.check_same_device(*tensors, allow_cpu_scalar_tensors=False) +<<<<<<< HEAD from torch.fx.experimental.symbolic_shapes import guard_or_false +======= + from torch.fx.experimental.symbolic_shapes import ( + guard_or_false, + guard_size_oblivious, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This is a bit tricky. Naively, you would expect to just pick one # arbitrary tensor and check that all tensors match this tensor. However, @@ -2878,7 +2930,11 @@ def cat_compute_output_memory_format(inputs): # through), and is load bearing for our Inductor lowerings # (which assume that size oblivious tests are OK to determine # if a shape is permissibly zero.) +<<<<<<< HEAD guard_or_false(tensor.shape[0] == 0), +======= + guard_size_oblivious(tensor.shape[0] == 0), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) lambda: f"Number of dimensions of tensors must match. " f"Expected {example.ndim}-D tensors, but got 1-D for " f"tensor number {tensor_idx} in the list", @@ -3032,7 +3088,11 @@ def contiguous( ) # TODO: make logic consistent with aten contiguous +<<<<<<< HEAD if is_contiguous_for_memory_format_or_false(a, memory_format=memory_format): +======= + if definitely_contiguous_for_memory_format(a, memory_format=memory_format): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return a return torch.clone(a, memory_format=memory_format) @@ -3046,7 +3106,11 @@ def dstack(tensors: TensorSequenceType) -> TensorLikeType: @register_decomposition(aten.expand) +<<<<<<< HEAD def expand(a: Tensor, *shape, implicit: bool = False) -> Tensor: +======= +def expand(a: Tensor, *shape) -> Tensor: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.fx.experimental.symbolic_shapes import guard_or_false, sym_or # NOTE: cannot use utils.extract_shape_from_varargs here @@ -3318,8 +3382,11 @@ def native_layer_norm( bias: Optional[Tensor], eps: float, ) -> tuple[Tensor, Tensor, Tensor]: +<<<<<<< HEAD from torch.fx.experimental.symbolic_shapes import sym_eq +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) normalized_ndim = len(normalized_shape) torch._check( normalized_ndim >= 1, @@ -3331,7 +3398,11 @@ def native_layer_norm( # while torch.Size([1, 2, 3]) == (1, 2, 3) is True # therefore we use tuple(normalized_shape) torch._check( +<<<<<<< HEAD weight is None or sym_eq(weight.shape, tuple(normalized_shape)), +======= + weight is None or weight.shape == tuple(normalized_shape), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) lambda: "Expected weight to be of same shape as normalized_shape, but got " + "weight of shape " + str(weight.shape) # type: ignore[union-attr] @@ -3339,7 +3410,11 @@ def native_layer_norm( + str(normalized_shape), ) torch._check( +<<<<<<< HEAD bias is None or sym_eq(bias.shape, tuple(normalized_shape)), +======= + bias is None or bias.shape == tuple(normalized_shape), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) lambda: "Expected bias to be of same shape as normalized_shape, but got " + "bias of shape " + str(bias.shape) # type: ignore[union-attr] @@ -3348,9 +3423,13 @@ def native_layer_norm( ) torch._check( input.ndim >= normalized_ndim +<<<<<<< HEAD and sym_eq( input.shape[(input.ndim - normalized_ndim) :], tuple(normalized_shape) ), +======= + and input.shape[(input.ndim - normalized_ndim) :] == tuple(normalized_shape), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) lambda: "Given normalized_shape=" + str(normalized_shape) + ", expected input with shape " @@ -3904,7 +3983,11 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorL else: return _a +<<<<<<< HEAD if is_contiguous_or_false(a): +======= + if definitely_contiguous(a): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Special-cases for nd_to_1d if len(shape) == 1 and a.ndim > 1: return torch.as_strided(a, [a.numel()], [1]) @@ -4091,15 +4174,24 @@ def unflatten(a: TensorLikeType, dim: int, sizes: ShapeType) -> TensorLikeType: @register_decomposition(aten.unbind) def unbind(t: TensorLikeType, dim: int = 0) -> TensorSequenceType: +<<<<<<< HEAD +======= + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dim = utils.canonicalize_dim(t.ndim, dim) torch._check_index( len(t.shape) > 0, lambda: "Dimension specified as 0 but tensor has no dimensions", ) +<<<<<<< HEAD # Note: t.shape[dim] can't be dynamic or unbacked, even if we use guard_or_false here we will fail # later in the split since t.shape[dim] control the number of output tensors. if t.shape[dim] == 0: +======= + if guard_size_oblivious(t.shape[dim] == 0): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return () else: return tuple( @@ -4230,7 +4322,11 @@ def index_select(x: TensorLike, dim: int, index: TensorLike): @register_decomposition(aten.squeeze.dims) def squeeze(a: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType: +<<<<<<< HEAD from torch.fx.experimental.symbolic_shapes import guard_or_false +======= + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if dim is None: dims = tuple(idx for idx, size in enumerate(a.shape) if size == 1) @@ -4245,8 +4341,12 @@ def squeeze(a: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType return prims.view_of(a) # Note: squeeze does not modify tensors when the given dim is not a dimension of length 1 +<<<<<<< HEAD # would it be better if we just not allow 1 for unbacked at runtiume? dims = tuple(d for d in dims if guard_or_false(a.shape[d] == 1)) +======= + dims = tuple(d for d in dims if guard_size_oblivious(a.shape[d] == 1)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if len(dims) == 0: return prims.view_of(a) if len(dims) == 1: @@ -6123,7 +6223,11 @@ def bucketize( if n_boundaries == 0: return torch.zeros_like(a) # We are trying to find the bucket (defined by pairs of consecutive elements of `boundaries`) +<<<<<<< HEAD # each element of `a` belongs to. We use binary search to achieve logarithmic complexity, +======= + # each element of `a` belongs to. We use binary search to achieve logarithimic complexity, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # but each step of the search is done "in parallel" over all elements of `a` # can't use int32 as indexes, so we have to do all computations with int64 and convert at the end start = torch.zeros(a.shape, device=a.device, dtype=torch.int64) @@ -6548,7 +6652,11 @@ def select_scatter(x: TensorLikeType, src: TensorLikeType, dim: int, index: int) permute_copy = _make_copy_from_view(aten.permute) t_copy = _make_copy_from_view(aten.t) transpose_copy = _make_copy_from_view(aten.transpose) +<<<<<<< HEAD unbind_copy = _make_copy_from_view(aten.unbind, return_none_on_out_variant=True) +======= +unbind_copy = _make_copy_from_view(aten.unbind) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) unsqueeze_copy = _make_copy_from_view(aten.unsqueeze) view_copy = _make_copy_from_view(aten.view) diff --git a/torch/_refs/fft.py b/torch/_refs/fft.py index e12e4c8e603ba..d08e92c9de089 100644 --- a/torch/_refs/fft.py +++ b/torch/_refs/fft.py @@ -313,8 +313,12 @@ def _canonicalize_fft_shape_and_dim_args( # Translate any -1 values in shape to the default length ret_shape = tuple( +<<<<<<< HEAD s if s != -1 else input_sizes[d] for (s, d) in zip(shape, ret_dims) # type: ignore[possibly-undefined] +======= + s if s != -1 else input_sizes[d] for (s, d) in zip(shape, ret_dims) # type: ignore[possibly-undefined] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) elif dim is None: # No shape, no dim diff --git a/torch/_refs/linalg/__init__.py b/torch/_refs/linalg/__init__.py index 28711c2c5485f..47f103b0029d7 100644 --- a/torch/_refs/linalg/__init__.py +++ b/torch/_refs/linalg/__init__.py @@ -180,7 +180,11 @@ def vector_norm( if keepdim or x.ndim == 0: return to_result_dtype(x).contiguous() elif dim is None: +<<<<<<< HEAD return to_result_dtype(x).flatten()[0] +======= + return x.flatten()[0] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: new_shape = [s for d, s in enumerate(x.shape) if d not in dim] return to_result_dtype(x.view(new_shape)).contiguous() diff --git a/torch/_refs/nn/functional/__init__.py b/torch/_refs/nn/functional/__init__.py index 89ead281d9478..589491312a854 100644 --- a/torch/_refs/nn/functional/__init__.py +++ b/torch/_refs/nn/functional/__init__.py @@ -760,7 +760,11 @@ def _nll_loss_nd( batch_size = input.shape[0] loss = -input[torch.arange(batch_size), target] * current_weight else: +<<<<<<< HEAD # 3D case (N batch size, C classes, K dimensions) +======= + # 3D case (N batch size, C classe, K dimensions) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # input (N batch size, C classes, K) batch_size = input.shape[0] extent = input.shape[2] diff --git a/torch/_strobelight/cli_function_profiler.py b/torch/_strobelight/cli_function_profiler.py index 80108dc99186b..5e56a07bb7681 100644 --- a/torch/_strobelight/cli_function_profiler.py +++ b/torch/_strobelight/cli_function_profiler.py @@ -59,7 +59,11 @@ class StrobelightCLIFunctionProfiler: StrobelightCLIFunctionProfiler can be used to profile a python function and generate a strobelight link with the results. It works on meta servers but +<<<<<<< HEAD does not requires an fbcode target. +======= + does not requries an fbcode target. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) When stop_at_error is false(default), error during profiling does not prevent the work function from running. @@ -310,7 +314,11 @@ def strobelight( profiler = StrobelightCLIFunctionProfiler(**kwargs) def strobelight_inner( +<<<<<<< HEAD work_function: Callable[_P, _R], +======= + work_function: Callable[_P, _R] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Callable[_P, Optional[_R]]: @functools.wraps(work_function) def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]: diff --git a/torch/_strobelight/compile_time_profiler.py b/torch/_strobelight/compile_time_profiler.py index 436f9a2c8b594..e62e182ece01a 100644 --- a/torch/_strobelight/compile_time_profiler.py +++ b/torch/_strobelight/compile_time_profiler.py @@ -127,7 +127,11 @@ def enable(cls, profiler_class: Any = StrobelightCLIFunctionProfiler) -> None: if not shutil.which("strobeclient"): logger.info( +<<<<<<< HEAD "strobeclient not found, can't enable compile time strobelight profiling, seems" +======= + "strobeclient not found, cant enable compile time strobelight profiling, seems" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "like you are not on a FB machine." ) return diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py index cefff832c5fdd..6193020c5fb1b 100644 --- a/torch/_subclasses/fake_impls.py +++ b/torch/_subclasses/fake_impls.py @@ -3,14 +3,19 @@ import functools import itertools import math +<<<<<<< HEAD import operator import sys from functools import reduce +======= +import sys +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from typing import Callable, Union import torch import torch._custom_op import torch._logging +<<<<<<< HEAD import torch._prims_common as utils from torch._dispatch.python import no_python_dispatcher from torch._ops import OpOverload @@ -24,6 +29,17 @@ is_float_dtype, is_integer_dtype, make_contiguous_strides_for, +======= +from torch._dispatch.python import no_python_dispatcher +from torch._ops import OpOverload +from torch._prims_common import ( + definitely_contiguous_for_memory_format, + elementwise_dtypes, + ELEMENTWISE_TYPE_PROMOTION_KIND, + is_boolean_dtype, + is_float_dtype, + is_integer_dtype, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) from torch._subclasses.fake_tensor import ( DataDependentOutputException, @@ -135,9 +151,15 @@ def _is_tensor_constructor(func: OpOverload): def register_op_impl(run_impl_check: Union[Callable[[OpOverload], bool], OpOverload]): def impl_decorator(op_impl): if isinstance(run_impl_check, OpOverload): +<<<<<<< HEAD assert run_impl_check not in op_implementations_dict, ( f"duplicate registration: {run_impl_check}" ) +======= + assert ( + run_impl_check not in op_implementations_dict + ), f"duplicate registration: {run_impl_check}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) op_implementations_dict[run_impl_check] = op_impl elif isinstance(run_impl_check, (list, tuple)): for op in run_impl_check: @@ -237,7 +259,11 @@ def stride_incorrect_op(op): # These operators have meta implementations with incorrect strides @register_op_impl(stride_incorrect_op) def wordaround_stride_incorrect_op(fake_mode, func, *args, **kwargs): +<<<<<<< HEAD # This is a workaround for meta implementations with incorrect strides +======= + # This is a workaround for meta implmentations with incorrect strides +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def is_symbolic(x): if isinstance(x, FakeTensor): @@ -365,6 +391,7 @@ def unique2( return _unique(fake_mode, func, arg, None, sorted, return_inverse, return_counts) +<<<<<<< HEAD @register_op_impl(aten.select.int) def meta_select(fake_mode, func, self, dim, index): from torch.fx.experimental.symbolic_shapes import guard_or_false @@ -407,6 +434,8 @@ def meta_select(fake_mode, func, self, dim, index): return self.as_strided(new_size, new_stride, new_storage_offset) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register_op_impl(aten.unique_dim.default) def unique_dim( fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False @@ -437,6 +466,7 @@ def _(fake_mode, func, arg, return_inverse=False, return_counts=False, dim=None) ) +<<<<<<< HEAD # This function is python match of computeStride_impl in TensorUtils.cpp def _compute_stride(old_shape, old_stride, new_shape, size_oblivious=False): from torch.fx.experimental.symbolic_shapes import ( @@ -615,6 +645,8 @@ def _view_meta_copy(fake_mode, func, a, *shape, out=None): ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register_op_impl(aten.repeat_interleave.Tensor) def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None): if output_size is None: @@ -801,6 +833,7 @@ def assert_tensor_metadata( layout=None, ) -> None: if sizes is not None: +<<<<<<< HEAD assert t.size() == sizes, ( f"Tensor sizes mismatch! Expected: {sizes}, Got: {t.size()}" ) @@ -820,6 +853,27 @@ def assert_tensor_metadata( assert t.device == device, ( f"Tensor device mismatch! Expected: {device}, Got: {t.device}" ) +======= + assert ( + t.size() == sizes + ), f"Tensor sizes mismatch! Expected: {sizes}, Got: {t.size()}" + if strides is not None: + assert ( + t.stride() == strides + ), f"Tensor strides mismatch! Expected: {strides}, Got: {t.stride()}" + if dtype is not None: + assert ( + t.dtype == dtype + ), f"Tensor dtype mismatch! Expected: {dtype}, Got: {t.dtype}" + if layout is not None: + assert ( + t.layout == layout + ), f"Tensor layout mismatch! Expected: {layout}, Got: {t.layout()}" + if device is not None: + assert ( + t.device == device + ), f"Tensor device mismatch! Expected: {device}, Got: {t.device}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # NB: this must be ordered after local_scalar_dense @@ -1246,8 +1300,12 @@ def slow(msg): # compute_fast_setup_type definitely_contiguous = True definitely_channels_last = True +<<<<<<< HEAD # TODO: is_non-overlapping_and_dense not bound from Python +======= + # TODO: is_non-overlapping_and_dense (not bound from Python +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # no inplace, no out, everything defined if is_noncontiguous_supported(common_device): @@ -1256,13 +1314,21 @@ def slow(msg): continue definitely_contiguous = ( definitely_contiguous +<<<<<<< HEAD and is_contiguous_for_memory_format_or_false( +======= + and definitely_contiguous_for_memory_format( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) op, memory_format=torch.contiguous_format ) ) definitely_channels_last = ( definitely_channels_last +<<<<<<< HEAD and is_contiguous_for_memory_format_or_false( +======= + and definitely_contiguous_for_memory_format( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) op, memory_format=torch.channels_last ) ) @@ -1318,9 +1384,13 @@ def get_fast_op_impls(): register_fast_op_impl(torch.ops.aten.sub.Tensor)( make_fast_binary_impl(torch._refs.sub) ) +<<<<<<< HEAD register_fast_op_impl(torch.ops.aten.mul.Tensor)( make_fast_binary_impl(torch._refs.mul) ) # type: ignore[has-type] +======= + register_fast_op_impl(torch.ops.aten.mul.Tensor)(make_fast_binary_impl(torch._refs.mul)) # type: ignore[has-type] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) register_fast_op_impl(torch.ops.aten.div.Tensor)( make_fast_binary_impl( torch._refs.div, diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 5767f6a1d0c1e..44312ed8c9eca 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -496,9 +496,15 @@ def from_meta_and_device( pytype: Optional[type[torch.Tensor]] = None, dispatch_keys: Optional[torch.DispatchKeySet] = None, ) -> FakeTensor: +<<<<<<< HEAD assert t.device.type == "meta", ( f"tensor's device must be `meta`, got {t.device.type} instead" ) +======= + assert ( + t.device.type == "meta" + ), f"tensor's device must be `meta`, got {t.device.type} instead" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This is a bit abusive (this is not the "real" tensor) but whatever, # the meta tensor should be fresh so there's no way to get it wrong maybe_memo = self._get_memo(t) @@ -889,11 +895,14 @@ def _find_common_device( aten._foreach_copy.default, ) +<<<<<<< HEAD # list of ops not using zero dim cpu tensor logic to align with the eager mode. bypass_zero_dim_cpu_tensor_check_ops = ordered_set( aten.nextafter.default, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def check_cpu_device(device: torch.device) -> bool: return device.type == "cpu" @@ -917,6 +926,7 @@ def merge_devices(t: object) -> None: is_cpu_zero_dim = t_is_cpu_zero_dim return +<<<<<<< HEAD is_bypass_zero_dim_cpu_tensor_check_op = ( func in bypass_zero_dim_cpu_tensor_check_ops ) @@ -928,6 +938,15 @@ def merge_devices(t: object) -> None: # current device is from cpu 0 dim tensor, overwrite if is_cpu_zero_dim and not is_bypass_zero_dim_cpu_tensor_check_op: +======= + # mismatching devices ! + # if current tensor is cpu 0 dim, defer to existing device + if t_is_cpu_zero_dim: + return + + # current device is from cpu 0 dim tensor, overwrite + if is_cpu_zero_dim: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) common_device = t.device is_cpu_zero_dim = t_is_cpu_zero_dim return @@ -940,6 +959,7 @@ def merge_devices(t: object) -> None: if any(map(check_cpu_device, (common_device, t.device))): return +<<<<<<< HEAD # if prefer_device_type is set, prefer that device type over others prefer_device_type = torch._functorch.config.fake_tensor_prefer_device_type if prefer_device_type is not None: @@ -955,6 +975,8 @@ def merge_devices(t: object) -> None: # Keep the existing preferred device type return +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # mismatching devices of non-zero dim tensors, throw # This might be valid behavior and need to be explicitly modeled, e.g. reshape_as raise RuntimeError( @@ -1618,10 +1640,14 @@ def _validate_cache_key( if torch.Tag.dynamic_output_shape in func.tags: if func is aten.index.Tensor: _, new_kwargs = normalize_function( # type: ignore[misc] +<<<<<<< HEAD func, args=args, # type: ignore[arg-type] kwargs=kwargs, # type: ignore[arg-type] normalize_to_only_use_kwargs=True, +======= + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True # type: ignore[arg-type] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) for index in new_kwargs["indices"]: # index calls nonzero for bool or int8 tensors, and @@ -1672,6 +1698,7 @@ def _prep_args_for_hash( convert FakeTensors into metadata. Raises _BypassDispatchCache to signal unsupported cases that should bypass caching. """ +<<<<<<< HEAD from torch._higher_order_ops.auto_functionalize import ( FunctionalCallableWithEpilogue, ) @@ -1681,6 +1708,10 @@ def _prep_args_for_hash( result.append(type(args)) result.append(f"length_{len(args)}") +======= + from torch._higher_order_ops.utils import FunctionalizeCtxWrapper + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(args, dict): self._prep_args_for_hash(result, args.keys(), state, id_hashed_objects) self._prep_args_for_hash(result, args.values(), state, id_hashed_objects) @@ -1719,10 +1750,13 @@ def _prep_args_for_hash( # functional wrapper is destroyed after fake tensor prop. We # need to put the finalizer on the subgraph. id_hashed_objects.append(arg.subgraph) +<<<<<<< HEAD elif isinstance(arg, FunctionalCallableWithEpilogue): result.append(type(arg)) result.append(hash(arg)) id_hashed_objects.append(arg.orig_callable) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: # It's important to capture the type of the arg since, e.g., 1 and 1.0 # hash to the same value, but can produce different dtypes for the @@ -2167,7 +2201,13 @@ def _check_fake_real_vals(fake: Any, real: Any) -> None: try: _check_fake_real_vals(s_fake, s_real) except MetadataMismatchError as exc: +<<<<<<< HEAD if torch._functorch.config.generate_fake_kernels_from_real_mismatches: +======= + if ( + torch._functorch.config.generate_fake_kernels_from_real_mismatches + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dtrace_structured( "mismatched_fake_kernel", metadata_fn=lambda: { @@ -2340,9 +2380,15 @@ def _dispatch_impl( and not flat_arg_fake_tensors and not device_conversion_skip_const_prop ): +<<<<<<< HEAD assert all(t.constant is not None for t in flat_arg_fake_tensors), ( f"{func} should not have fake inputs without constants" ) +======= + assert all( + t.constant is not None for t in flat_arg_fake_tensors + ), f"{func} should not have fake inputs without constants" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const_flat_args = [ a.constant if self.is_our_fake(a) else a for a in flat_args ] @@ -2394,7 +2440,11 @@ def _dispatch_impl( # (aot autograd, torchdynamo) where each operation is run consecutively. # Because each operation is run in order, we can trace out and support # sequences like: x = torch.tensor(0.); y = x.add_(1) +<<<<<<< HEAD # Whenever a constant is written to but with inputs that cannot be evaluated +======= + # Whenver a constant is written to but with inputs that cannot be evaluated +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # statically, such as random_(), we invalidate all constants that alias the input # We will rely on functionalization for use of fake tensors constants as persistent # objects on an FX Graph. @@ -2567,7 +2617,13 @@ def go(t: object, real_t: Tensor) -> None: if real_out is not nil: # cross check fake/real outputs, and optionally override fake kernel mismatches +<<<<<<< HEAD if not torch._functorch.config.generate_fake_kernels_from_real_mismatches: +======= + if ( + not torch._functorch.config.generate_fake_kernels_from_real_mismatches + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._maybe_infer_fake_kernel_from_pytree_out( func, (args, kwargs), @@ -2617,11 +2673,15 @@ def go(t: object, real_t: Tensor) -> None: # If there's a Python meta, prefer that over the decomposition from torch._decomp import meta_table as meta_table +<<<<<<< HEAD if ( func not in meta_table and not self.cpp_meta_supports_symint(func) and not (has_symbolic_sizes and func in self._view_fake_tensor_impl_ops) ): +======= + if func not in meta_table and not self.cpp_meta_supports_symint(func): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._decomp import decomposition_table # Prefer Python decompositions over C++ ones @@ -2929,10 +2989,13 @@ def create_symbolic_nested_int( aten._sparse_coo_tensor_with_dims_and_tensors.default, ) +<<<<<<< HEAD _view_fake_tensor_impl_ops = ordered_set( aten.view.default, aten._unsafe_view.default ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def cpp_meta_supports_symint(self, func: OpOverload) -> bool: if torch.Tag.view_copy in func.tags: return True @@ -2959,10 +3022,14 @@ def invalidate_written_to_constants( schema_info = get_schema_info(func) if any_constant and schema_info.is_mutable(): _, new_kwargs = normalize_function( # type: ignore[misc] +<<<<<<< HEAD func, args=args, # type: ignore[arg-type] kwargs=kwargs, # type: ignore[arg-type] normalize_to_only_use_kwargs=True, +======= + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True # type: ignore[arg-type] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) for k, v in new_kwargs.items(): k = k if (k != "input" or schema_info.has_argument(k)) else "self" @@ -2986,9 +3053,15 @@ def from_tensor( if static_shapes is None: static_shapes = self.static_shapes if static_shapes: +<<<<<<< HEAD assert symbolic_context is None, ( "cannot set both static_shapes and symbolic_context" ) +======= + assert ( + symbolic_context is None + ), "cannot set both static_shapes and symbolic_context" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shape_env = None return self.fake_tensor_converter.from_real_tensor( self, diff --git a/torch/_subclasses/fake_utils.py b/torch/_subclasses/fake_utils.py index bd481c87cf6f3..9a5fd8df36720 100644 --- a/torch/_subclasses/fake_utils.py +++ b/torch/_subclasses/fake_utils.py @@ -102,7 +102,11 @@ def is_sdpa_error(func, idx, e): def try_convert_fake_to_real( +<<<<<<< HEAD ten_list: list[Union[FakeTensor, Any]], +======= + ten_list: list[Union[FakeTensor, Any]] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> list[Union[FakeTensor, torch.Tensor, Any]]: """ Attempt to convert fake tensors to a corresponding real tensor with the correct underlying storage by looking up @@ -266,9 +270,15 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): if fake_r is not None: r_flat = pytree.tree_leaves(r) f_flat = pytree.tree_leaves(fake_r) +<<<<<<< HEAD assert len(f_flat) == len(r_flat), ( f"{context} mismatch in number of returns {len(f_flat)} != {len(r_flat)}" ) +======= + assert len(f_flat) == len( + r_flat + ), f"{context} mismatch in number of returns {len(f_flat)} != {len(r_flat)}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.check_aliasing: _check_alias_info( @@ -279,9 +289,15 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): zip(pytree.tree_leaves(r), pytree.tree_leaves(fake_r)) ): r_is_ten = isinstance(r_out, torch.Tensor) +<<<<<<< HEAD assert r_is_ten == isinstance(f_out, torch.Tensor), ( f"{context} mismatched number of tensor outputs" ) +======= + assert r_is_ten == isinstance( + f_out, torch.Tensor + ), f"{context} mismatched number of tensor outputs" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if r_is_ten: try: _check_fake_real_tensors( diff --git a/torch/_subclasses/functional_tensor.py b/torch/_subclasses/functional_tensor.py index 28cc3070affc3..ef9ff185d56ba 100644 --- a/torch/_subclasses/functional_tensor.py +++ b/torch/_subclasses/functional_tensor.py @@ -67,7 +67,11 @@ class FunctionalTensor(torch.Tensor): # later, as long as it doesn't break anything). # FunctionalTensorWrapper copies **all** dispatch keys from the inner tensor # to the wrapper, excluding functorch and python dispatch keys. +<<<<<<< HEAD # Here I'm trying to reuse the keyset the functorch wrapper subclasses copy, +======= + # Here I'm trying to re-use the keyset the functorch wrapper subclasses copy, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # except that they don't include ZeroTensor so I'm manually adding it in. _extra_dispatch_keys = torch._C._additional_keys_to_prop_for_wrapper_tensors.add( torch._C.DispatchKey.ZeroTensor @@ -488,7 +492,11 @@ def unwrap(x): - FunctionalTensor._extra_dispatch_keys ) +<<<<<<< HEAD # All we want to do here is reuse the existing C++ functionalization logic. +======= + # All we want to do here is re-use the existing C++ functionalization logic. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This requires swizzling our TLS dispatch keys so that the Functionalize key is active. with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set): try: diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py index b73ee9abfc33a..99ff75851ee89 100644 --- a/torch/_subclasses/meta_utils.py +++ b/torch/_subclasses/meta_utils.py @@ -5,6 +5,10 @@ import functools import threading import typing +<<<<<<< HEAD +======= +import warnings +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import weakref from abc import abstractmethod from contextlib import AbstractContextManager, contextmanager @@ -80,7 +84,12 @@ def safe_is_leaf(t: Union[MetaTensorDesc, torch.Tensor]) -> bool: def safe_grad(t: _TensorLikeT) -> Optional[_TensorLikeT]: +<<<<<<< HEAD with torch._logging.hide_warnings(torch._logging._internal.safe_grad_filter): +======= + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "The .grad attribute of a Tensor") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return t.grad @@ -355,9 +364,13 @@ def describe_tensor( maybe_functorch_stack = None if is_functorch_wrapped: +<<<<<<< HEAD with ( torch._functorch.pyfunctorch.temporarily_clear_interpreter_stack() ) as maybe_functorch_stack: +======= + with torch._functorch.pyfunctorch.temporarily_clear_interpreter_stack() as maybe_functorch_stack: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pass attrs = None @@ -417,7 +430,10 @@ def describe_tensor( stride=stride, storage_offset=storage_offset, dynamo_dynamic_indices=list(getattr(t, "_dynamo_dynamic_indices", set())), +<<<<<<< HEAD dynamo_hint_overrides=getattr(t, "_dynamo_hint_overrides", {}), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sparse_dim=( t.sparse_dim() if t.is_sparse or is_sparse_compressed(t) else None ), @@ -518,7 +534,12 @@ def apply( new_base: _TensorT, symint_visitor_fn: Optional[Callable[[int], int]] = None, tensor_visitor_fn: Optional[Callable[[torch.Tensor], _TensorT]] = None, +<<<<<<< HEAD ) -> _TensorT: ... +======= + ) -> _TensorT: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @staticmethod def from_tensor(t: torch.Tensor) -> ViewFunc: @@ -574,7 +595,12 @@ def apply( class _MetaTensorCallback(Protocol, Generic[_TensorT_cov]): def __call__( self, arg: Callable[[], torch.Tensor], /, *, device: Union[torch.device, str] +<<<<<<< HEAD ) -> _TensorT_cov: ... +======= + ) -> _TensorT_cov: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class _MetaTensorCallbackKwargs(TypedDict, total=False): @@ -591,7 +617,12 @@ def __call__( arg: Callable[[], torch.Tensor], /, **kwargs: Unpack[_MetaTensorCallbackKwargs], +<<<<<<< HEAD ) -> _TensorT_cov: ... +======= + ) -> _TensorT_cov: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclass(frozen=True) @@ -615,7 +646,10 @@ class MetaTensorDesc(Generic[_TensorT]): # defined on NJT size: tuple[int, ...] dynamo_dynamic_indices: list[int] +<<<<<<< HEAD dynamo_hint_overrides: dict[int, int] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) layout: torch.layout = torch.strided is_inference: bool = False @@ -784,9 +818,15 @@ def __init__(self, *, copy_data: bool = False) -> None: ] = weakref.WeakValueDictionary() # Maps MetaTensorId to torch.Tensor (typically a meta tensor or # FakeTensor) +<<<<<<< HEAD self.tensor_memo: weakref.WeakValueDictionary[MetaTensorId, _TensorT] = ( weakref.WeakValueDictionary() ) +======= + self.tensor_memo: weakref.WeakValueDictionary[ + MetaTensorId, _TensorT + ] = weakref.WeakValueDictionary() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.hit = 0 self.miss = 0 self.del_hook = None @@ -958,7 +998,10 @@ def sym_sizes_strides_storage_offset( [d in t.dynamo_dynamic_indices for d in range(t.ndim)], src, symbolic_context=symbolic_context, +<<<<<<< HEAD hint_overrides=t.dynamo_hint_overrides, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) else: return (t.size, t.stride, t.storage_offset) @@ -1644,7 +1687,11 @@ def is_c_of_r( with torch.enable_grad(): r = view_from_base(base, t) +<<<<<<< HEAD # NB: We don't actually faithfully replicate +======= + # NB: We don't actaully faithfully replicate +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # autograd connectivity, but that doesn't matter # today. See following for more info: # https://gist.github.com/soulitzer/e03f015b314c3f5fcf80888c69390913 @@ -1670,8 +1717,11 @@ def is_c_of_r( torch._C.DispatchKey.ADInplaceOrView, old_exclude ) +<<<<<<< HEAD r.fake_device = t.device # type: ignore[attr-defined] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: is_leaf = t.is_leaf @@ -1774,9 +1824,15 @@ def is_c_of_r( # subclasses. Relevant test is # DynamicShapesFunctionTests::test_add_dynamic_shapes in # test/dynamo/test_dynamic_shapes.py +<<<<<<< HEAD maybe_fake_mgr: AbstractContextManager[None] = ( contextlib.nullcontext() ) +======= + maybe_fake_mgr: AbstractContextManager[ + None + ] = contextlib.nullcontext() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._subclasses.fake_tensor import ( in_kernel_invocation_manager, maybe_get_fake_mode, diff --git a/torch/_tensor.py b/torch/_tensor.py index 6cebed28b8b0d..4168acbc3e1e3 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -6,8 +6,12 @@ from collections import OrderedDict from copy import deepcopy from numbers import Number +<<<<<<< HEAD from typing import Any, Callable, cast, Optional, TypeVar, Union from typing_extensions import Concatenate, ParamSpec +======= +from typing import Any, Callable, cast, Optional, Union +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch._C as _C @@ -28,6 +32,7 @@ ) +<<<<<<< HEAD _P = ParamSpec("_P") _TensorLike = TypeVar("_TensorLike", bound=_C.TensorBase) @@ -43,6 +48,18 @@ def wrapped(self: _TensorLike, *args: _P.args, **kwargs: _P.kwargs) -> "Tensor": if has_torch_function(sargs): return handle_torch_function(wrapped, sargs, *sargs, **kwargs) return f(self, *args, **kwargs) +======= +def _handle_torch_function_and_wrap_type_error_to_not_implemented(f): + assigned = functools.WRAPPER_ASSIGNMENTS + + @functools.wraps(f, assigned=assigned) + def wrapped(*args, **kwargs): + try: + # See https://github.com/pytorch/pytorch/issues/75462 + if has_torch_function(args): + return handle_torch_function(wrapped, args, *args, **kwargs) + return f(*args, **kwargs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) except TypeError: return NotImplemented @@ -330,7 +347,11 @@ def _reduce_ex_internal(self, proto): torch.serialization._serialization_tls.materialize_fake_tensors ) +<<<<<<< HEAD if self.device.type in ["xla", "maia", "mtia"] or ( +======= + if self.device.type in ["xla", "maia"] or ( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) not torch._C._has_storage(self) and self.device.type == torch._C._get_privateuse1_backend_name() ): @@ -343,6 +364,37 @@ def _reduce_ex_internal(self, proto): torch._utils._rebuild_device_tensor_from_cpu_tensor, (cpu_tensor, self.dtype, str(self.device), self.requires_grad), ) +<<<<<<< HEAD +======= + # Legacy comment that does not hold anymore. + # Note: Numpy array is chosen to be the rebuild component for XLA, MTIA, MAIA Tensors. + # We considered a few options: + # 1. CPU tensor can't be used here. + # Otherwise in torch.load CPU storage is reconstructed with randomly + # initialized data, moved onto backend device, and then storage is updated + # to the serialized content. This works perfectly for CPU/CUDA but not these backends; + # their tensors are disconnected with storage so they don't get the update. + # 2. Python list is not a good fit due to performance reason. + # `tolist()` converts every single element in the tensor into python objects + # and serialize them one by one. + if self.device.type in ["mtia"]: + # Convert BFloat16 tesors to Float32 before conversion to numpy, as numpy doesn't + # support BFloat16. The rebuild tensor from numpy takes in the original self.dtype, + # this would reconstruct the BFloat16 tensor from numpy. + if skip_data: + raise RuntimeError( + "Cannot serialize tensors on backends with no storage under skip_data context manager" + ) + numpy_tensor = ( + self.cpu().numpy() + if self.dtype != torch.bfloat16 + else self.cpu().to(torch.float32).numpy() + ) + return ( + torch._utils._rebuild_device_tensor_from_numpy, + (numpy_tensor, self.dtype, str(self.device), self.requires_grad), + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.device.type == "meta": # NB: This implementation BREAKS storage sharing. Current # hypothesis is that no one cares for meta tensors. @@ -1071,11 +1123,19 @@ def unique_consecutive(self, return_inverse=False, return_counts=False, dim=None ) @_handle_torch_function_and_wrap_type_error_to_not_implemented +<<<<<<< HEAD def __rsub__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor": return _C._VariableFunctions.rsub(self, other) @_handle_torch_function_and_wrap_type_error_to_not_implemented def __rdiv__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor": +======= + def __rsub__(self, other): + return _C._VariableFunctions.rsub(self, other) + + @_handle_torch_function_and_wrap_type_error_to_not_implemented + def __rdiv__(self, other): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.reciprocal() * other __rtruediv__ = __rdiv__ @@ -1090,13 +1150,20 @@ def __rdiv__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor _C.TensorBase.pow ), ) +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __ipow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented( _C.TensorBase.pow_ ) @_handle_torch_function_and_wrap_type_error_to_not_implemented +<<<<<<< HEAD def __rmod__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor": +======= + def __rmod__(self, other): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return torch.remainder(other, self) def __format__(self, format_spec): @@ -1109,6 +1176,7 @@ def __format__(self, format_spec): return object.__format__(self, format_spec) @_handle_torch_function_and_wrap_type_error_to_not_implemented +<<<<<<< HEAD def __rpow__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor": return torch.pow(other, self) @@ -1136,6 +1204,29 @@ def __rrshift__( @_handle_torch_function_and_wrap_type_error_to_not_implemented def __rmatmul__(self, other: "Tensor") -> "Tensor": +======= + def __rpow__(self, other): + return torch.pow(other, self) + + @_handle_torch_function_and_wrap_type_error_to_not_implemented + def __floordiv__(self, other): + return torch.floor_divide(self, other) + + @_handle_torch_function_and_wrap_type_error_to_not_implemented + def __rfloordiv__(self, other): + return torch.floor_divide(other, self) + + @_handle_torch_function_and_wrap_type_error_to_not_implemented + def __rlshift__(self, other): + return torch.bitwise_left_shift(other, self) + + @_handle_torch_function_and_wrap_type_error_to_not_implemented + def __rrshift__(self, other): + return torch.bitwise_right_shift(other, self) + + @_handle_torch_function_and_wrap_type_error_to_not_implemented + def __rmatmul__(self, other): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return torch.matmul(other, self) __pos__ = _C.TensorBase.positive @@ -1659,6 +1750,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): __torch_dispatch__ = _C._disabled_torch_dispatch_impl +<<<<<<< HEAD def __dlpack__( self, *, @@ -1667,6 +1759,9 @@ def __dlpack__( dl_device: Optional[tuple[enum.IntEnum, int]] = None, copy: Optional[bool] = None, ): +======= + def __dlpack__(self, stream=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Creates a DLpack `capsule https://data-apis.org/array-api/latest/design_topics/data_interchange.html#data-interchange`_ of the current tensor to be exported to other libraries. @@ -1677,6 +1772,7 @@ def __dlpack__( Args: stream (integer or None): An optional Python integer representing a +<<<<<<< HEAD pointer to a CUDA stream. The current stream is synchronized with this stream before the capsule is created, and since the capsule shares its storage with the tensor this make it safe to access from @@ -1704,11 +1800,23 @@ def __dlpack__( "copy": copy, } return handle_torch_function(Tensor.__dlpack__, (self,), *args, **kwargs) +======= + pointer to a CUDA stream. The current stream is synchronized with + this stream before the capsule is created, and since the capsule + shares its storage with the tensor this make it safe to access from + both streams. If None or -1 is passed then no synchronization is performed. + If 1 (on CUDA) or 0 (on ROCM) then the default stream is used for + synchronization. + """ + if has_torch_function_unary(self): + return handle_torch_function(Tensor.__dlpack__, (self,), self, stream) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # DLPack capsules can't capture all of PyTorch's semantics, # so we prohibit exporting tensors that would lose their properties like # requires_grad and having the conjugate bit set. if self.requires_grad: +<<<<<<< HEAD raise BufferError( "Can't export tensors that require gradient, use tensor.detach()" ) @@ -1729,10 +1837,23 @@ def __dlpack__( f"Current device: {torch.cuda.current_device()}." ) +======= + raise RuntimeError( + "Can't export tensors that require gradient, use tensor.detach()" + ) + if self.is_conj(): + raise RuntimeError("Can't export tensors with the conjugate bit set") + if self.layout != torch.strided: + raise RuntimeError( + "Can't export tensors with layout other than torch.strided" + ) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if stream is not None and type(stream) is not int: # Stream pointers in CUDA/ROCm are uniquely numbered and can # be retrieved from their integer value. raise TypeError("stream must be ``int`` or ``none``") +<<<<<<< HEAD elif self.device.type == "cuda" and stream != -1: # NB: This logic handles the special case values for default # streams and must be kept in sync with from_dlpack in @@ -1762,6 +1883,25 @@ def __dlpack__( elif self.device.type == "cpu": assert stream is None, "stream should be None on cpu." +======= + elif stream is not None and stream != -1: + if self.device.type == "cuda": + # NB: This logic handles the special case values for default + # streams and must be kept in sync with from_dlpack in + # torch/utils/dlpack.py + if stream == 1 and torch.version.hip is None: + stream = torch.cuda.default_stream() + elif stream == 0 and torch.version.hip is not None: + stream = torch.cuda.default_stream() + else: + stream = torch.cuda.ExternalStream(stream) + # Only synchronize on different streams + sync_stream = torch.cuda.current_stream() + if stream != sync_stream: + event = torch.cuda.Event() + event.record(sync_stream) + stream.wait_event(event) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.device.type == "xla": import torch_xla import torch_xla.utils.dlpack as xla_dlpack @@ -1773,6 +1913,7 @@ def __dlpack__( raise RuntimeError( "Can't export to dlpack an XLA tensor that is not on CUDA." ) +<<<<<<< HEAD # Does not support DLPack 1.0, yet. return xla_dlpack.to_dlpack(self) @@ -1782,6 +1923,10 @@ def __dlpack__( return _C._to_dlpack(self, dl_device=dl_device, copy=copy) return _C._to_dlpack_versioned(self, dl_device=dl_device, copy=copy) +======= + return xla_dlpack.to_dlpack(self) + return torch.to_dlpack(self) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __dlpack_device__(self) -> tuple[enum.IntEnum, int]: if has_torch_function_unary(self): @@ -1795,9 +1940,15 @@ def __dlpack_device__(self) -> tuple[enum.IntEnum, int]: if torch_device_type == "cuda" and torch.version.hip is not None: device_type = DLDeviceType.kDLROCM elif torch_device_type == "cpu" and self.is_pinned(): +<<<<<<< HEAD device_type = DLDeviceType.kDLCUDAHost elif torch_device_type == "cuda": device_type = DLDeviceType.kDLCUDA +======= + device_type = DLDeviceType.kDLCPUPinned + elif torch_device_type == "cuda": + device_type = DLDeviceType.kDLGPU +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif torch_device_type == "cpu": device_type = DLDeviceType.kDLCPU elif torch_device_type == "xpu": @@ -1813,9 +1964,13 @@ def __dlpack_device__(self) -> tuple[enum.IntEnum, int]: ): raise ValueError(f"Unknown device type {torch_device_type} for Dlpack") +<<<<<<< HEAD device_type = DLDeviceType.kDLCUDA elif torch_device_type == "mps": device_type = DLDeviceType.kDLMetal +======= + device_type = DLDeviceType.kDLGPU +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: raise ValueError(f"Unknown device type {torch_device_type} for Dlpack") return (device_type, idx) diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 33a2184b71f5e..c74f3d57da7d0 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -1215,9 +1215,15 @@ def add_docstr_all(method: str, docstr: str) -> None: Args: src (Tensor): the source tensor to copy from +<<<<<<< HEAD non_blocking (bool, optional): if ``True`` and this copy is between CPU and GPU, the copy may occur asynchronously with respect to the host. For other cases, this argument has no effect. Default: ``False`` +======= + non_blocking (bool): if ``True`` and this copy is between CPU and GPU, + the copy may occur asynchronously with respect to the host. For other + cases, this argument has no effect. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """, ) @@ -1383,9 +1389,15 @@ def add_docstr_all(method: str, docstr: str) -> None: then no copy is performed and the original object is returned. Args: +<<<<<<< HEAD device (:class:`torch.device`, optional): The destination GPU device. Defaults to the current CUDA device. non_blocking (bool, optional): If ``True`` and the source is in pinned memory, +======= + device (:class:`torch.device`): The destination GPU device. + Defaults to the current CUDA device. + non_blocking (bool): If ``True`` and the source is in pinned memory, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) the copy will be asynchronous with respect to the host. Otherwise, the argument has no effect. Default: ``False``. {memory_format} @@ -1403,9 +1415,15 @@ def add_docstr_all(method: str, docstr: str) -> None: then no copy is performed and the original object is returned. Args: +<<<<<<< HEAD device (:class:`torch.device`, optional): The destination MTIA device. Defaults to the current MTIA device. non_blocking (bool, optional): If ``True`` and the source is in pinned memory, +======= + device (:class:`torch.device`): The destination MTIA device. + Defaults to the current MTIA device. + non_blocking (bool): If ``True`` and the source is in pinned memory, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) the copy will be asynchronous with respect to the host. Otherwise, the argument has no effect. Default: ``False``. {memory_format} @@ -1423,9 +1441,15 @@ def add_docstr_all(method: str, docstr: str) -> None: then no copy is performed and the original object is returned. Args: +<<<<<<< HEAD device (:class:`torch.device`, optional): The destination IPU device. Defaults to the current IPU device. non_blocking (bool, optional): If ``True`` and the source is in pinned memory, +======= + device (:class:`torch.device`): The destination IPU device. + Defaults to the current IPU device. + non_blocking (bool): If ``True`` and the source is in pinned memory, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) the copy will be asynchronous with respect to the host. Otherwise, the argument has no effect. Default: ``False``. {memory_format} @@ -1443,9 +1467,15 @@ def add_docstr_all(method: str, docstr: str) -> None: then no copy is performed and the original object is returned. Args: +<<<<<<< HEAD device (:class:`torch.device`, optional): The destination XPU device. Defaults to the current XPU device. non_blocking (bool, optional): If ``True`` and the source is in pinned memory, +======= + device (:class:`torch.device`): The destination XPU device. + Defaults to the current XPU device. + non_blocking (bool): If ``True`` and the source is in pinned memory, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) the copy will be asynchronous with respect to the host. Otherwise, the argument has no effect. Default: ``False``. {memory_format} @@ -1612,7 +1642,11 @@ def add_docstr_all(method: str, docstr: str) -> None: Arguments: fill_value (Scalar): the fill value +<<<<<<< HEAD wrap (bool, optional): the diagonal 'wrapped' after N columns for tall matrices. Default: ``False`` +======= + wrap (bool): the diagonal 'wrapped' after N columns for tall matrices. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Example:: @@ -2526,7 +2560,11 @@ def add_docstr_all(method: str, docstr: str) -> None: row of ``source`` is multiplied by the ``j``\ th row of :attr:`self`. If :obj:`include_self="True"`, the values in the :attr:`self` tensor are included in the reduction, otherwise, rows in the :attr:`self` tensor that are accumulated +<<<<<<< HEAD to are treated as if they were filled with the reduction identities. +======= +to are treated as if they were filled with the reduction identites. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) The :attr:`dim`\ th dimension of ``source`` must have the same size as the length of :attr:`index` (which must be a vector), and all other dimensions must @@ -3640,7 +3678,11 @@ def callable(a, b) -> number tensor. Pad the out tensor with `fill_value` if the `size` is larger than total number of non-zero elements, truncate out tensor if `size` is smaller. The size must be a non-negative integer. +<<<<<<< HEAD fill_value (int, optional): the value to fill the output tensor with when `size` is larger +======= + fill_value (int): the value to fill the output tensor with when `size` is larger +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) than the total number of non-zero elements. Default is `-1` to represent invalid index. @@ -3848,7 +3890,11 @@ def callable(a, b) -> number Args: index (LongTensor): the indices into self source (Tensor): the tensor containing values to copy from +<<<<<<< HEAD accumulate (bool, optional): whether to accumulate into self. Default: ``False`` +======= + accumulate (bool): whether to accumulate into self +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Example:: @@ -4394,12 +4440,20 @@ def callable(a, b) -> number This is the reverse operation of the manner described in :meth:`~Tensor.gather`. +<<<<<<< HEAD It is also required that ``index.size(d) <= src.size(d)`` for all dimensions ``d``, and that ``index.size(d) <= self.size(d)`` for all dimensions ``d != dim``. Note that ``input`` and ``index`` do not broadcast against each other for NPUs, so when running on NPUs, :attr:`input` and :attr:`index` must have the same number of dimensions. Standard broadcasting occurs in all other cases. +======= +:attr:`self`, :attr:`index` and :attr:`src` (if it is a Tensor) should all have +the same number of dimensions. It is also required that +``index.size(d) <= src.size(d)`` for all dimensions ``d``, and that +``index.size(d) <= self.size(d)`` for all dimensions ``d != dim``. +Note that ``index`` and ``src`` do not broadcast. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Moreover, as for :meth:`~Tensor.gather`, the values of :attr:`index` must be between ``0`` and ``self.size(dim) - 1`` inclusive. @@ -4526,8 +4580,11 @@ def callable(a, b) -> number dimensions. It is also required that ``index.size(d) <= src.size(d)`` for all dimensions ``d``, and that ``index.size(d) <= self.size(d)`` for all dimensions ``d != dim``. Note that ``index`` and ``src`` do not broadcast. +<<<<<<< HEAD When :attr:`index` is empty, we always return the original tensor without further error checking. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Note: {forward_reproducibility_note} diff --git a/torch/_tensor_str.py b/torch/_tensor_str.py index c9262e1b2ee0a..77bac803ee416 100644 --- a/torch/_tensor_str.py +++ b/torch/_tensor_str.py @@ -178,6 +178,7 @@ def __init__(self, tensor): self.int_mode = False break +<<<<<<< HEAD self.sci_mode = ( nonzero_finite_max / nonzero_finite_min > 1000.0 or nonzero_finite_max > 1.0e8 @@ -190,6 +191,16 @@ def __init__(self, tensor): # in int_mode for floats, all numbers are integers, and we append a decimal to nonfinites # to indicate that the tensor is of floating type. add 1 to the len to account for this. if self.sci_mode: +======= + if self.int_mode: + # in int_mode for floats, all numbers are integers, and we append a decimal to nonfinites + # to indicate that the tensor is of floating type. add 1 to the len to account for this. + if ( + nonzero_finite_max / nonzero_finite_min > 1000.0 + or nonzero_finite_max > 1.0e8 + ): + self.sci_mode = True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for value in nonzero_finite_vals: value_str = f"{{:.{PRINT_OPTS.precision}e}}".format(value) self.max_width = max(self.max_width, len(value_str)) @@ -199,7 +210,16 @@ def __init__(self, tensor): self.max_width = max(self.max_width, len(value_str) + 1) else: # Check if scientific representation should be used. +<<<<<<< HEAD if self.sci_mode: +======= + if ( + nonzero_finite_max / nonzero_finite_min > 1000.0 + or nonzero_finite_max > 1.0e8 + or nonzero_finite_min < 1.0e-4 + ): + self.sci_mode = True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for value in nonzero_finite_vals: value_str = f"{{:.{PRINT_OPTS.precision}e}}".format(value) self.max_width = max(self.max_width, len(value_str)) @@ -208,6 +228,12 @@ def __init__(self, tensor): value_str = f"{{:.{PRINT_OPTS.precision}f}}".format(value) self.max_width = max(self.max_width, len(value_str)) +<<<<<<< HEAD +======= + if PRINT_OPTS.sci_mode is not None: + self.sci_mode = PRINT_OPTS.sci_mode + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def width(self): return self.max_width diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 68c3fe31c5bf0..2730deb25d85b 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -104,11 +104,14 @@ def merge_dicts(*dicts): """ }, { +<<<<<<< HEAD "opt_dim_without_none": """ dim (int, optional): the dimension to reduce. If omitted, all dimensions are reduced. Explicit ``None`` is not supported. """ }, { +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "keepdim_details": """If :attr:`keepdim` is ``True``, the output tensor is of the same size as :attr:`input` except in the dimension :attr:`dim` where it is of size 1. Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in @@ -1006,8 +1009,12 @@ def merge_dicts(*dicts): tensor is constructed using :func:`torch.from_numpy`. If :attr:`data` is a CuPy array, the returned tensor will be located on the same device as the CuPy array unless +<<<<<<< HEAD specifically overwritten by :attr:`device` or a default device. The device of the CuPy array is inferred from the pointer of the array using `cudaPointerGetAttributes` unless :attr:`device` is provided with an explicit device index. +======= +specifically overwritten by :attr:`device` or a default device. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .. seealso:: @@ -2783,12 +2790,15 @@ def merge_dicts(*dicts): result of this operation to :attr:`input`. To create a tensor without an autograd relationship to :attr:`input` see :meth:`~Tensor.detach`. +<<<<<<< HEAD In addition, when ``torch.preserve_format`` is used: If the input tensor is dense (i.e., non-overlapping strided), its memory format (including strides) is retained. Otherwise (e.g., a non-dense view like a stepped slice), the output is converted to the dense (contiguous) format. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Args: {input} @@ -4587,8 +4597,11 @@ def merge_dicts(*dicts): It is also required that ``index.size(d) <= input.size(d)`` for all dimensions ``d != dim``. :attr:`out` will have the same shape as :attr:`index`. Note that ``input`` and ``index`` do not broadcast against each other. +<<<<<<< HEAD When :attr:`index` is empty, we always return an empty output with the same shape without further error checking. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Args: input (Tensor): the source tensor @@ -4742,8 +4755,12 @@ def merge_dicts(*dicts): edge_order (``int``, optional): 1 or 2, for `first-order `_ or `second-order `_ +<<<<<<< HEAD estimation of the boundary ("edge") values, respectively. Note that when :attr:`edge_order` is specified, each dimension size of :attr:`input` should be at least edge_order+1 +======= + estimation of the boundary ("edge") values, respectively. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Examples:: @@ -5021,6 +5038,7 @@ def merge_dicts(*dicts): ) add_docstr( +<<<<<<< HEAD torch.hash_tensor, r""" hash_tensor(input, *, mode=0) -> Tensor @@ -5075,6 +5093,8 @@ def merge_dicts(*dicts): ) add_docstr( +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.histc, r""" histc(input, bins=100, min=0, max=0, *, out=None) -> Tensor @@ -5184,7 +5204,11 @@ def merge_dicts(*dicts): If :attr:`bins` is a sequence of N 1D tensors, it explicitly specifies the N sequences of bin edges. Each 1D tensor should contain a strictly increasing sequence with at least one element. A sequence of K bin edges defines K-1 bins, explicitly specifying +<<<<<<< HEAD the left and right edges of all bins. Every bin is inclusive of its left edge. Only +======= +the left and right edges of all bins. Every bin is exclusive of its left edge. Only +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) the rightmost bin is inclusive of its right edge. If :attr:`bins` is a sequence of N ints, it specifies the number of equal-width bins @@ -5555,13 +5579,18 @@ def merge_dicts(*dicts): add_docstr( torch.is_floating_point, r""" +<<<<<<< HEAD is_floating_point(input: Tensor) -> bool +======= +is_floating_point(input) -> (bool) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Returns True if the data type of :attr:`input` is a floating point data type i.e., one of ``torch.float64``, ``torch.float32``, ``torch.float16``, and ``torch.bfloat16``. Args: {input} +<<<<<<< HEAD Example:: @@ -5573,19 +5602,26 @@ def merge_dicts(*dicts): True >>> torch.is_floating_point(torch.tensor([1, 2, 3], dtype=torch.complex64)) False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """.format(**common_args), ) add_docstr( torch.is_complex, r""" +<<<<<<< HEAD is_complex(input: Tensor) -> bool +======= +is_complex(input) -> (bool) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Returns True if the data type of :attr:`input` is a complex data type i.e., one of ``torch.complex64``, and ``torch.complex128``. Args: {input} +<<<<<<< HEAD Example:: @@ -5597,6 +5633,8 @@ def merge_dicts(*dicts): False >>> torch.is_complex(torch.tensor([1.0, 2.0, 3.0], dtype=torch.float16)) False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """.format(**common_args), ) @@ -5677,11 +5715,19 @@ def merge_dicts(*dicts): >>> torch.is_nonzero(torch.tensor([1, 3, 5])) Traceback (most recent call last): ... +<<<<<<< HEAD RuntimeError: Boolean value of Tensor with more than one value is ambiguous >>> torch.is_nonzero(torch.tensor([])) Traceback (most recent call last): ... RuntimeError: Boolean value of Tensor with no values is ambiguous +======= + RuntimeError: bool value of Tensor with more than one value is ambiguous + >>> torch.is_nonzero(torch.tensor([])) + Traceback (most recent call last): + ... + RuntimeError: bool value of Tensor with no values is ambiguous +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """.format(**common_args), ) @@ -6602,6 +6648,7 @@ def merge_dicts(*dicts): Returns the maximum value of all elements in the ``input`` tensor. +<<<<<<< HEAD .. note:: The difference between ``max``/``min`` and ``amax``/``amin`` is: - ``amax``/``amin`` supports reducing on multiple dimensions, @@ -6614,6 +6661,8 @@ def merge_dicts(*dicts): - If reduce over all dimensions(no dim specified), gradients evenly distribute between equally ``max``/``min`` values. - If reduce over one specified axis, only propagate to the indexed element. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Args: {input} @@ -6646,7 +6695,11 @@ def merge_dicts(*dicts): Args: {input} +<<<<<<< HEAD {opt_dim_without_none} +======= + {opt_dim_all_reduce} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {opt_keepdim} Keyword args: @@ -6677,7 +6730,11 @@ def merge_dicts(*dicts): See :func:`torch.maximum`. +<<<<<<< HEAD """.format(**single_dim_common), +======= +""".format(**multi_dim_common), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) add_docstr( @@ -6752,6 +6809,7 @@ def merge_dicts(*dicts): - ``amax``/``amin`` supports reducing on multiple dimensions, - ``amax``/``amin`` does not return indices. +<<<<<<< HEAD Both ``amax``/``amin`` evenly distribute gradients between equal values when there are multiple input elements with the same minimum or maximum value. @@ -6759,6 +6817,11 @@ def merge_dicts(*dicts): - If reduce over all dimensions(no dim specified), gradients evenly distribute between equally ``max``/``min`` values. - If reduce over one specified axis, only propagate to the indexed element. +======= + Both ``max``/``min`` and ``amax``/``amin`` evenly distribute gradients between equal values + when there are multiple input elements with the same minimum or maximum value. + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {keepdim_details} Args: @@ -7148,7 +7211,11 @@ def merge_dicts(*dicts): {opt_keepdim} Keyword arguments: +<<<<<<< HEAD interpolation (str, optional): interpolation method to use when the desired quantile lies between two data points. +======= + interpolation (str): interpolation method to use when the desired quantile lies between two data points. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Can be ``linear``, ``lower``, ``higher``, ``midpoint`` and ``nearest``. Default is ``linear``. {out} @@ -7236,6 +7303,7 @@ def merge_dicts(*dicts): Returns the minimum value of all elements in the :attr:`input` tensor. +<<<<<<< HEAD .. note:: The difference between ``max``/``min`` and ``amax``/``amin`` is: - ``amax``/``amin`` supports reducing on multiple dimensions, @@ -7248,6 +7316,8 @@ def merge_dicts(*dicts): - If reduce over all dimensions(no dim specified), gradients evenly distribute between equally ``max``/``min`` values. - If reduce over one specified axis, only propagate to the indexed element. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Args: {input} @@ -7280,7 +7350,11 @@ def merge_dicts(*dicts): Args: {input} +<<<<<<< HEAD {opt_dim_without_none} +======= + {opt_dim_all_reduce} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {opt_keepdim} Keyword args: @@ -7376,6 +7450,7 @@ def merge_dicts(*dicts): - ``amax``/``amin`` supports reducing on multiple dimensions, - ``amax``/``amin`` does not return indices. +<<<<<<< HEAD Both ``amax``/``amin`` evenly distribute gradients between equal values when there are multiple input elements with the same minimum or maximum value. @@ -7383,6 +7458,11 @@ def merge_dicts(*dicts): - If reduce over all dimensions(no dim specified), gradients evenly distribute between equally ``max``/``min`` values. - If reduce over one specified axis, only propagate to the indexed element. +======= + Both ``max``/``min`` and ``amax``/``amin`` evenly distribute gradients between equal values + when there are multiple input elements with the same minimum or maximum value. + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {keepdim_details} Args: @@ -7592,6 +7672,7 @@ def merge_dicts(*dicts): N-dimensional (where N > 2), then a batched matrix multiply is returned. If the first argument is 1-dimensional, a 1 is prepended to its dimension for the purpose of the batched matrix multiply and removed after. If the second argument is 1-dimensional, a +<<<<<<< HEAD 1 is appended to its dimension for the purpose of the batched matrix multiply and removed after. The first N-2 dimensions of each argument, the batch dimensions, are @@ -7603,6 +7684,19 @@ def merge_dicts(*dicts): tensor, the batch dimensions are :math:`(j \times 1)` and :math:`(k)`, and the matrix dimensions are :math:`(n \times m)` and :math:`(m \times p)`. :attr:`out` will be a :math:`(j \times k \times n \times p)` tensor. +======= + 1 is appended to its dimension for the purpose of the batched matrix multiple and removed after. + The non-matrix (i.e. batch) dimensions are :ref:`broadcasted ` (and thus + must be broadcastable). For example, if :attr:`input` is a + :math:`(j \times 1 \times n \times n)` tensor and :attr:`other` is a :math:`(k \times n \times n)` + tensor, :attr:`out` will be a :math:`(j \times k \times n \times n)` tensor. + + Note that the broadcasting logic only looks at the batch dimensions when determining if the inputs + are broadcastable, and not the matrix dimensions. For example, if :attr:`input` is a + :math:`(j \times 1 \times n \times m)` tensor and :attr:`other` is a :math:`(k \times m \times p)` + tensor, these inputs are valid for broadcasting even though the final two dimensions (i.e. the + matrix dimensions) are different. :attr:`out` will be a :math:`(j \times k \times n \times p)` tensor. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) This operation has support for arguments with :ref:`sparse layouts`. In particular the matrix-matrix (both arguments 2-dimensional) supports sparse arguments with the same restrictions @@ -7715,7 +7809,11 @@ def merge_dicts(*dicts): Args: {input} +<<<<<<< HEAD other (Tensor or Number): the tensor or number to multiply input by. +======= + other (Tensor or Number) - the tensor or number to multiply input by. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Keyword args: {out} @@ -9058,7 +9156,11 @@ def merge_dicts(*dicts): Keyword args: {generator} {out} +<<<<<<< HEAD dtype (torch.dtype, optional): the desired data type of returned tensor. Default: if ``None``, +======= + dtype (`torch.dtype`, optional) - the desired data type of returned tensor. Default: if ``None``, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) this function returns a tensor with dtype ``torch.int64``. {layout} {device} @@ -9647,6 +9749,7 @@ def merge_dicts(*dicts): ) add_docstr( +<<<<<<< HEAD torch.segment_reduce, r""" segment_reduce(data: Tensor, reduce: str, *, lengths: Tensor | None = None, indices: Tensor | None = None, offsets: Tensor | None = None, axis: _int = 0, unsafe: _bool = False, initial: Number | _complex | None = None) -> Tensor # noqa: B950 @@ -9675,6 +9778,8 @@ def merge_dicts(*dicts): ) add_docstr( +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.select, r""" select(input, dim, index) -> Tensor @@ -10009,7 +10114,11 @@ def merge_dicts(*dicts): add_docstr( torch.sort, r""" +<<<<<<< HEAD sort(input, dim=-1, descending=False, *, stable=False, out=None) -> (Tensor, LongTensor) +======= +sort(input, dim=-1, descending=False, stable=False, *, out=None) -> (Tensor, LongTensor) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Sorts the elements of the :attr:`input` tensor along a given dimension in ascending order by value. @@ -10030,10 +10139,17 @@ def merge_dicts(*dicts): {input} dim (int, optional): the dimension to sort along descending (bool, optional): controls the sorting order (ascending or descending) +<<<<<<< HEAD Keyword args: stable (bool, optional): makes the sorting routine stable, which guarantees that the order of equivalent elements is preserved. +======= + stable (bool, optional): makes the sorting routine stable, which guarantees that the order + of equivalent elements is preserved. + +Keyword args: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out (tuple, optional): the output tuple of (`Tensor`, `LongTensor`) that can be optionally given to be used as output buffers @@ -10074,7 +10190,11 @@ def merge_dicts(*dicts): add_docstr( torch.argsort, r""" +<<<<<<< HEAD argsort(input, dim=-1, descending=False, *, stable=False) -> Tensor +======= +argsort(input, dim=-1, descending=False, stable=False) -> Tensor +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Returns the indices that sort a tensor along a given dimension in ascending order by value. @@ -10090,8 +10210,11 @@ def merge_dicts(*dicts): {input} dim (int, optional): the dimension to sort along descending (bool, optional): controls the sorting order (ascending or descending) +<<<<<<< HEAD Keyword args: +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) stable (bool, optional): controls the relative order of equivalent elements Example:: @@ -10168,7 +10291,11 @@ def merge_dicts(*dicts): subtracted by the number before it denotes the number of elements or blocks in a given compressed dimension. plain_indices (array_like): Plain dimension (column or row) +<<<<<<< HEAD coordinates of each element or block in values. (B+1)-dimensional +======= + co-ordinates of each element or block in values. (B+1)-dimensional +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tensor with the same length as values. values (array_list): Initial values for the tensor. Can be a list, @@ -10236,7 +10363,11 @@ def merge_dicts(*dicts): starts. Each successive number in the tensor subtracted by the number before it denotes the number of elements in a given row. +<<<<<<< HEAD col_indices (array_like): Column coordinates of each element in +======= + col_indices (array_like): Column co-ordinates of each element in +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) values. (B+1)-dimensional tensor with the same length as values. values (array_list): Initial values for the tensor. Can be a list, @@ -10299,7 +10430,11 @@ def merge_dicts(*dicts): starts. Each successive number in the tensor subtracted by the number before it denotes the number of elements in a given column. +<<<<<<< HEAD row_indices (array_like): Row coordinates of each element in +======= + row_indices (array_like): Row co-ordinates of each element in +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) values. (B+1)-dimensional tensor with the same length as values. values (array_list): Initial values for the tensor. Can be a list, @@ -10362,7 +10497,11 @@ def merge_dicts(*dicts): given row block starts. Each successive number in the tensor subtracted by the number before it denotes the number of blocks in a given row. +<<<<<<< HEAD col_indices (array_like): Column block coordinates of each block +======= + col_indices (array_like): Column block co-ordinates of each block +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) in values. (B+1)-dimensional tensor with the same length as values. values (array_list): Initial values for the tensor. Can be a list, @@ -10430,7 +10569,11 @@ def merge_dicts(*dicts): column starts. Each successive number in the tensor subtracted by the number before it denotes the number of elements in a given column. +<<<<<<< HEAD row_indices (array_like): Row block coordinates of each block in +======= + row_indices (array_like): Row block co-ordinates of each block in +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) values. (B+1)-dimensional tensor with the same length as values. values (array_list): Initial blocks for the tensor. Can be a list, @@ -12250,12 +12393,15 @@ def merge_dicts(*dicts): Floating point and complex tensors are filled with NaN, and integer tensors are filled with the maximum value. +<<<<<<< HEAD When ``torch.preserve_format`` is used: If the input tensor is dense (i.e., non-overlapping strided), its memory format (including strides) is retained. Otherwise (e.g., a non-dense view like a stepped slice), the output is converted to the dense format. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Args: {input} @@ -12419,6 +12565,7 @@ def merge_dicts(*dicts): {device} {requires_grad} {memory_format} +<<<<<<< HEAD Example:: @@ -12437,6 +12584,8 @@ def merge_dicts(*dicts): tensor([[-1., -1., -1., -1.], [-1., -1., -1., -1.], [-1., -1., -1., -1.]], dtype=torch.float64) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """.format(**factory_like_common_args), ) @@ -12632,8 +12781,13 @@ def merge_dicts(*dicts): add_docstr( torch.hamming_window, """ +<<<<<<< HEAD hamming_window(window_length, *, dtype=None, layout=None, device=None, pin_memory=False, \ requires_grad=False) -> Tensor +======= +hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46, *, dtype=None, \ +layout=torch.strided, device=None, requires_grad=False) -> Tensor +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ + r""" Hamming window function. @@ -12661,6 +12815,7 @@ def merge_dicts(*dicts): + r""" Arguments: window_length (int): the size of returned window +<<<<<<< HEAD Keyword args: {dtype} Only floating point types are supported. @@ -12683,12 +12838,19 @@ def merge_dicts(*dicts): window_length (int): the size of returned window periodic (bool): If True, returns a window to be used as periodic function. If False, return a symmetric window. +======= + periodic (bool, optional): If True, returns a window to be used as periodic + function. If False, return a symmetric window. + alpha (float, optional): The coefficient :math:`\alpha` in the equation above + beta (float, optional): The coefficient :math:`\beta` in the equation above +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Keyword args: {dtype} Only floating point types are supported. layout (:class:`torch.layout`, optional): the desired layout of returned window tensor. Only ``torch.strided`` (dense layout) is supported. {device} +<<<<<<< HEAD {pin_memory} {requires_grad} @@ -12737,6 +12899,8 @@ def merge_dicts(*dicts): ``torch.strided`` (dense layout) is supported. {device} {pin_memory} +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {requires_grad} Returns: diff --git a/torch/_utils.py b/torch/_utils.py index 9bd062cb5cec6..1f7a5816c3f37 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -159,7 +159,11 @@ def _get_restore_location(device): # serialization), and the state dict saves "data" only, thus # stripping the backward hooks. In some cases, hooks are # essential to the well-functioning of a model (e.g., DDP), +<<<<<<< HEAD # but DDP already manages re-adding the hooks! +======= +# but DDP already manages readding the hooks! +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # # - We didn't serialize them in many cases. Prior to #10220, we # were dropping backward hooks in ForkingPickler. We "fixed" this @@ -190,7 +194,11 @@ def _rebuild_tensor(storage, storage_offset, size, stride): def get_tensor_metadata(tensor): # Tensor's Metadata for serializing. +<<<<<<< HEAD # Currently, this only returns a dict[string, bool] specifying whether +======= + # Currently, this only returns a dict[string, bool] specifing whether +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # `conj` or `neg` bit is set. assert isinstance(tensor, torch.Tensor) return torch._C._get_tensor_metadata(tensor) # type: ignore[attr-defined] @@ -499,7 +507,11 @@ def _rebuild_parameter_with_state(data, requires_grad, backward_hooks, state): def _get_obj_state(obj): # Get the state of the python subclass +<<<<<<< HEAD # This loosely mimics the function on the object class but since Tensor do not inherit +======= + # This loosely mimicks the function on the object class but since Tensor do not inherit +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # from it, we cannot call that function directly # https://github.com/python/cpython/blob/c83919bd635f4433f1c6ae8504996a9fe3c215e5/Objects/typeobject.c#L4891 # Note that starting with Python 3.11, this `__getstate__` is always defined and thus diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index f20a88ce85402..8da2ebcc27015 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -4,7 +4,10 @@ import os import sys import tempfile +<<<<<<< HEAD import typing_extensions +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from typing import Any, Callable, Optional, TypeVar from typing_extensions import ParamSpec @@ -29,6 +32,7 @@ StrobelightCompileTimeProfiler.enable() # this arbitrary-looking assortment of functionality is provided here +<<<<<<< HEAD # to have a central place for overridable behavior. The motivating # use is the FB build environment, where this source file is replaced # by an equivalent. @@ -37,6 +41,22 @@ torch_parent = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) else: torch_parent = os.path.dirname(os.path.dirname(__file__)) +======= +# to have a central place for overrideable behavior. The motivating +# use is the FB build environment, where this source file is replaced +# by an equivalent. + +if torch._running_with_deploy(): + # __file__ is meaningless in the context of frozen torch used in torch deploy. + # setting empty torch_parent should allow below functions to operate without crashing, + # but it's unclear if there is a valid use case for them in the context of deploy. + torch_parent = "" +else: + if os.path.basename(os.path.dirname(__file__)) == "shared": + torch_parent = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + else: + torch_parent = os.path.dirname(os.path.dirname(__file__)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_file_path(*path_components: str) -> str: @@ -117,10 +137,13 @@ def signpost_event(category: str, name: str, parameters: dict[str, Any]): log.info("%s %s: %r", category, name, parameters) +<<<<<<< HEAD def add_mlhub_insight(category: str, insight: str, insight_description: str): pass +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def log_compilation_event(metrics): log.info("%s", metrics) @@ -137,10 +160,13 @@ def log_export_usage(**kwargs): pass +<<<<<<< HEAD def log_draft_export_usage(**kwargs): pass +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def log_trace_structured_event(*args, **kwargs) -> None: pass @@ -215,9 +241,12 @@ def is_fb_unit_test() -> bool: @functools.cache def max_clock_rate(): +<<<<<<< HEAD """ unit: MHz """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not torch.version.hip: from triton.testing import nvsmi @@ -288,6 +317,7 @@ def record_chromium_event_internal( def profiler_allow_cudagraph_cupti_lazy_reinit_cuda12(): return True +<<<<<<< HEAD def deprecated(): @@ -365,3 +395,5 @@ def find_compile_subproc_binary() -> Optional[str]: Allows overriding the binary used for subprocesses """ return None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/_weights_only_unpickler.py b/torch/_weights_only_unpickler.py index 9382a5500e0ee..7697f67f2f3f8 100644 --- a/torch/_weights_only_unpickler.py +++ b/torch/_weights_only_unpickler.py @@ -403,12 +403,18 @@ def load(self): func not in _get_allowed_globals().values() and func not in _get_user_allowed_globals().values() ): +<<<<<<< HEAD error_msg = ( f"Trying to call reduce for unrecognized function {func}" ) if hasattr(func, "__self__"): error_msg += f" which belongs to {func.__self__}" raise UnpicklingError(error_msg) +======= + raise UnpicklingError( + f"Trying to call reduce for unrecognized function {func}" + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) result = func(*args) if func in torch._tensor_classes and "sparse" in func.__module__: _sparse_tensors_to_validate.append(result) @@ -520,7 +526,11 @@ def load(self): elif key[0] == BINPERSID[0]: pid = self.stack.pop() # Only allow persistent load of storage +<<<<<<< HEAD if type(pid) is not tuple and type(pid) is not int: +======= + if type(pid) is not tuple and not type(pid) is not int: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise UnpicklingError( f"persistent_load id must be tuple or int, but got {type(pid)}" ) diff --git a/torch/accelerator/__init__.py b/torch/accelerator/__init__.py index 4d1a78df1f74c..da92fb567e755 100644 --- a/torch/accelerator/__init__.py +++ b/torch/accelerator/__init__.py @@ -8,6 +8,7 @@ import torch from ._utils import _device_t, _get_device_index +<<<<<<< HEAD from .memory import ( empty_cache, max_memory_allocated, @@ -18,6 +19,8 @@ reset_accumulated_memory_stats, reset_peak_memory_stats, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __all__ = [ @@ -25,6 +28,7 @@ "current_device_idx", # deprecated "current_device_index", "current_stream", +<<<<<<< HEAD "empty_cache", "device_count", "device_index", @@ -36,6 +40,11 @@ "memory_stats", "reset_accumulated_memory_stats", "reset_peak_memory_stats", +======= + "device_count", + "device_index", + "is_available", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "set_device_idx", # deprecated "set_device_index", "set_stream", @@ -137,6 +146,7 @@ def current_device_index() -> int: category=FutureWarning, )(current_device_index) +<<<<<<< HEAD current_device_idx.__doc__ = r""" (Deprecated) Return the index of a currently selected device for the current :ref:`accelerator`. @@ -149,6 +159,8 @@ def current_device_index() -> int: and will be removed in a future PyTorch release. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def set_device_index(device: _device_t, /) -> None: r"""Set the current device index to a given device. @@ -168,6 +180,7 @@ def set_device_index(device: _device_t, /) -> None: category=FutureWarning, )(set_device_index) +<<<<<<< HEAD set_device_idx.__doc__ = r""" (Deprecated) Set the current device index to a given device. @@ -181,6 +194,8 @@ def set_device_index(device: _device_t, /) -> None: and will be removed in a future PyTorch release. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def current_stream(device: _device_t = None, /) -> torch.Stream: r"""Return the currently selected stream for a given device. diff --git a/torch/amp/autocast_mode.py b/torch/amp/autocast_mode.py index c758d47fc8150..b0540d051285c 100644 --- a/torch/amp/autocast_mode.py +++ b/torch/amp/autocast_mode.py @@ -43,9 +43,13 @@ def decorate_autocast(*args, **kwargs): with autocast_instance: return func(*args, **kwargs) +<<<<<<< HEAD decorate_autocast.__script_unsupported = ( # type: ignore[attr-defined] "@autocast() decorator is not supported in script mode" ) +======= + decorate_autocast.__script_unsupported = "@autocast() decorator is not supported in script mode" # type: ignore[attr-defined] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return decorate_autocast @@ -90,9 +94,15 @@ class autocast: class AutocastModel(nn.Module): ... +<<<<<<< HEAD @torch.autocast(device_type="cuda") def forward(self, input): ... +======= + @torch.autocast(device_type="cuda") + def forward(self, input): + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Floating-point Tensors produced in an autocast-enabled region may be ``float16``. After returning to an autocast-disabled region, using them with floating-point @@ -154,11 +164,17 @@ class TestModel(nn.Module): def __init__(self, input_size, num_classes): super().__init__() self.fc1 = nn.Linear(input_size, num_classes) +<<<<<<< HEAD def forward(self, x): return self.fc1(x) +======= + def forward(self, x): + return self.fc1(x) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) input_size = 2 num_classes = 2 model = TestModel(input_size, num_classes).eval() @@ -324,7 +340,11 @@ def __init__( elif self.device == self.custom_backend_name: supported_dtype = self.custom_device_mod.get_amp_supported_dtype() if self.fast_dtype not in supported_dtype: +<<<<<<< HEAD error_message = f"In {self.custom_backend_name} autocast, but the target dtype {self.fast_dtype} is not supported. " +======= + error_message = f"In {self.custom_backend_name} autocast, but the target dtype is not supported. " +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) error_message += f"Disabling autocast.\n {self.custom_backend_name} Autocast only supports dtypes of " error_message += ( ", ".join(str(dtype) for dtype in supported_dtype) + " currently." @@ -397,10 +417,14 @@ def __enter__(self): self._enabled, self._cache_enabled, ) +<<<<<<< HEAD mode.__torch_function__(torch.amp._enter_autocast, (), args) return self return self +======= + return mode.__torch_function__(torch.amp._enter_autocast, (), args) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # type: ignore[override] if torch._jit_internal.is_scripting(): @@ -423,10 +447,14 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # type: ignore[ov mode, torch.fx.experimental.proxy_tensor.PreDispatchTorchFunctionMode, ): +<<<<<<< HEAD mode.__torch_function__(torch.amp._exit_autocast, (), ()) # This is very important because the above line actually doesn't # run exit code so it end up swallowing exceptions. return False +======= + return mode.__torch_function__(torch.amp._exit_autocast, (), ()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return False def __call__(self, func): diff --git a/torch/amp/grad_scaler.py b/torch/amp/grad_scaler.py index 54314b034d15e..e515e5474e3f0 100644 --- a/torch/amp/grad_scaler.py +++ b/torch/amp/grad_scaler.py @@ -134,8 +134,12 @@ def __init__( if self._device == "cuda": if enabled and torch.cuda.amp.common.amp_definitely_not_available(): warnings.warn( +<<<<<<< HEAD "torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.", stacklevel=2, +======= + "torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) self._enabled = False @@ -176,6 +180,7 @@ def _lazy_init_scale_growth_tracker(self, dev: torch.device) -> None: ) @overload +<<<<<<< HEAD def scale(self, outputs: torch.Tensor) -> torch.Tensor: ... @overload @@ -186,6 +191,22 @@ def scale(self, outputs: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]: @overload def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]: ... +======= + def scale(self, outputs: torch.Tensor) -> torch.Tensor: + ... + + @overload + def scale(self, outputs: list[torch.Tensor]) -> list[torch.Tensor]: + ... + + @overload + def scale(self, outputs: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]: + ... + + @overload + def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def scale( self, @@ -455,9 +476,15 @@ def step( if optimizer_state["stage"] is OptState.READY: self.unscale_(optimizer) +<<<<<<< HEAD assert len(optimizer_state["found_inf_per_device"]) > 0, ( "No inf checks were recorded for this optimizer." ) +======= + assert ( + len(optimizer_state["found_inf_per_device"]) > 0 + ), "No inf checks were recorded for this optimizer." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs) @@ -501,10 +528,15 @@ def update(self, new_scale: Optional[Union[float, torch.Tensor]] = None) -> None if isinstance(new_scale, float): self._scale.fill_(new_scale) else: +<<<<<<< HEAD reason = ( "new_scale should be a float or a 1-element torch.cuda.FloatTensor or " "torch.FloatTensor with requires_grad=False." ) +======= + reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor or \ + torch.FloatTensor with requires_grad=False." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert new_scale.device.type == self._device, reason assert new_scale.numel() == 1, reason assert new_scale.requires_grad is False, reason @@ -682,9 +714,15 @@ def _check_inf_per_device(self, optimizer: torch.optim.Optimizer) -> dict[str, A dummy_inv_scale = torch.full((), 1.0, dtype=torch.float32, device=_scale.device) found_inf = torch.full((), 0.0, dtype=torch.float32, device=_scale.device) +<<<<<<< HEAD self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = ( self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True) ) +======= + self._per_optimizer_states[id(optimizer)][ + "found_inf_per_device" + ] = self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] diff --git a/torch/ao/nn/intrinsic/qat/modules/linear_relu.py b/torch/ao/nn/intrinsic/qat/modules/linear_relu.py index 8446468dddcff..572ec5bd85792 100644 --- a/torch/ao/nn/intrinsic/qat/modules/linear_relu.py +++ b/torch/ao/nn/intrinsic/qat/modules/linear_relu.py @@ -1,11 +1,16 @@ +<<<<<<< HEAD from __future__ import annotations from typing import Optional +======= +# mypy: allow-untyped-defs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch.ao.nn.intrinsic as nni import torch.ao.nn.qat as nnqat import torch.nn.functional as F +<<<<<<< HEAD from torch.ao.nn.intrinsic.modules.fused import _FusedModule @@ -13,6 +18,11 @@ class LinearReLU(nnqat.Linear, _FusedModule): +======= + + +class LinearReLU(nnqat.Linear, nni._FusedModule): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r""" A LinearReLU module fused from Linear and ReLU modules, attached with FakeQuantize modules for weight, used in @@ -36,6 +46,7 @@ class LinearReLU(nnqat.Linear, _FusedModule): torch.Size([128, 30]) """ +<<<<<<< HEAD _FLOAT_MODULE = nni.LinearReLU def __init__( @@ -59,6 +70,21 @@ def from_float( return super().from_float(mod, use_precomputed_fake_quant) # type: ignore[no-untyped-call,no-any-return] def to_float(self) -> nni.LinearReLU: +======= + _FLOAT_MODULE = nni.LinearReLU # type: ignore[assignment] + + def __init__(self, in_features, out_features, bias=True, qconfig=None): + super().__init__(in_features, out_features, bias, qconfig) + + def forward(self, input): + return F.relu(F.linear(input, self.weight_fake_quant(self.weight), self.bias)) + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): + return super().from_float(mod, use_precomputed_fake_quant) + + def to_float(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) linear = torch.nn.Linear( self.in_features, self.out_features, self.bias is not None ) @@ -66,4 +92,8 @@ def to_float(self) -> nni.LinearReLU: if self.bias is not None: linear.bias = torch.nn.Parameter(self.bias.detach()) relu = torch.nn.ReLU() +<<<<<<< HEAD return torch.ao.nn.intrinsic.LinearReLU(linear, relu) # type: ignore[no-untyped-call] +======= + return torch.ao.nn.intrinsic.LinearReLU(linear, relu) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py b/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py index a9566b268f088..3a0df9122a20c 100644 --- a/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py +++ b/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py @@ -1,6 +1,10 @@ +<<<<<<< HEAD from typing import Any from typing_extensions import Self +======= +# mypy: allow-untyped-defs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch.ao.nn.intrinsic as nni import torch.ao.nn.quantized.dynamic as nnqd @@ -30,6 +34,7 @@ class LinearReLU(nnqd.Linear): torch.Size([128, 30]) """ +<<<<<<< HEAD _FLOAT_MODULE = nni.LinearReLU def __init__( @@ -39,6 +44,11 @@ def __init__( bias: bool = True, dtype: torch.dtype = torch.qint8, ) -> None: +======= + _FLOAT_MODULE = nni.LinearReLU # type: ignore[assignment] + + def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__(in_features, out_features, bias, dtype) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -55,6 +65,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: raise RuntimeError("Unsupported dtype on dynamic quantized linear relu!") return Y.to(x.dtype) +<<<<<<< HEAD def _get_name(self) -> str: return "DynamicQuantizedLinearReLU" @@ -62,10 +73,21 @@ def _get_name(self) -> str: def from_float( cls, mod: torch.nn.Module, use_precomputed_fake_quant: bool = False ) -> Self: +======= + def _get_name(self): + return "DynamicQuantizedLinearReLU" + + @classmethod + def from_float(cls, mod, use_precomputed_fake_quant=False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return super().from_float( mod, use_precomputed_fake_quant=use_precomputed_fake_quant ) @classmethod +<<<<<<< HEAD def from_reference(cls, ref_qlinear_relu: Any) -> Self: # type: ignore[override] +======= + def from_reference(cls, ref_qlinear_relu): # type: ignore[override] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return super().from_reference(ref_qlinear_relu[0]) diff --git a/torch/ao/nn/qat/modules/conv.py b/torch/ao/nn/qat/modules/conv.py index 4a193fa6763cd..4155475fc248f 100644 --- a/torch/ao/nn/qat/modules/conv.py +++ b/torch/ao/nn/qat/modules/conv.py @@ -1,5 +1,9 @@ # mypy: allow-untyped-defs +<<<<<<< HEAD from typing import ClassVar, Literal, Union +======= +from typing import ClassVar, Union +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch.nn as nn @@ -26,7 +30,11 @@ def __init__( output_padding: tuple[int, ...], groups: int, bias: bool, +<<<<<<< HEAD padding_mode: Literal["zeros", "reflect", "replicate", "circular"], +======= + padding_mode: str, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) qconfig=None, device=None, dtype=None, @@ -148,7 +156,11 @@ def __init__( dilation: _size_1_t = 1, groups: int = 1, bias: bool = True, +<<<<<<< HEAD padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", +======= + padding_mode: str = "zeros", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) qconfig=None, device=None, dtype=None, @@ -210,7 +222,11 @@ def __init__( dilation: _size_2_t = 1, groups: int = 1, bias: bool = True, +<<<<<<< HEAD padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", +======= + padding_mode: str = "zeros", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) qconfig=None, device=None, dtype=None, @@ -275,7 +291,11 @@ def __init__( dilation: _size_3_t = 1, groups: int = 1, bias: bool = True, +<<<<<<< HEAD padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", +======= + padding_mode: str = "zeros", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) qconfig=None, device=None, dtype=None, diff --git a/torch/ao/nn/quantizable/modules/activation.py b/torch/ao/nn/quantizable/modules/activation.py index d9f5e4ff4c86c..a0cdfafc4b150 100644 --- a/torch/ao/nn/quantizable/modules/activation.py +++ b/torch/ao/nn/quantizable/modules/activation.py @@ -214,7 +214,11 @@ def dequantize(self): fp.bias_v = nn.Parameter(self.bias_v.dequantize()) # Set the linear weights +<<<<<<< HEAD # Note: Because the linear layers are quantized, mypy does not know how +======= + # Note: Because the linear layers are quantized, mypy does not nkow how +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # to deal with them -- might need to ignore the typing checks. # for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969 w, b = self.out_proj._weight_bias() # type: ignore[operator, has-type] diff --git a/torch/ao/nn/quantized/dynamic/modules/conv.py b/torch/ao/nn/quantized/dynamic/modules/conv.py index a079f31f62e45..ec1b99108a29f 100644 --- a/torch/ao/nn/quantized/dynamic/modules/conv.py +++ b/torch/ao/nn/quantized/dynamic/modules/conv.py @@ -2,7 +2,11 @@ r"""Dynamically quantized convolution modules.""" import warnings +<<<<<<< HEAD from typing import ClassVar, Literal, Optional +======= +from typing import ClassVar, Optional +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch.ao.nn.quantized as nnq @@ -62,7 +66,11 @@ def __init__( dilation: _size_1_t = 1, groups: int = 1, bias: bool = True, +<<<<<<< HEAD padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", +======= + padding_mode: str = "zeros", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device=None, dtype=None, reduce_range=True, diff --git a/torch/ao/nn/quantized/modules/conv.py b/torch/ao/nn/quantized/modules/conv.py index 592c5893d113a..dcf45484b9297 100644 --- a/torch/ao/nn/quantized/modules/conv.py +++ b/torch/ao/nn/quantized/modules/conv.py @@ -1,7 +1,11 @@ # mypy: allow-untyped-defs r"""Quantized convolution modules.""" +<<<<<<< HEAD from typing import ClassVar, Literal, Optional +======= +from typing import ClassVar, Optional +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch.ao.nn.intrinsic as nni @@ -401,7 +405,11 @@ def __init__( dilation: _size_1_t = 1, groups: int = 1, bias: bool = True, +<<<<<<< HEAD padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", +======= + padding_mode: str = "zeros", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device=None, dtype=None, ): diff --git a/torch/ao/nn/quantized/modules/rnn.py b/torch/ao/nn/quantized/modules/rnn.py index 5040b8c97d050..7ea1837b7c0f3 100644 --- a/torch/ao/nn/quantized/modules/rnn.py +++ b/torch/ao/nn/quantized/modules/rnn.py @@ -1,5 +1,9 @@ +<<<<<<< HEAD from typing import Any +======= +# mypy: allow-untyped-defs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch @@ -36,11 +40,19 @@ class LSTM(torch.ao.nn.quantizable.LSTM): _FLOAT_MODULE = torch.ao.nn.quantizable.LSTM # type: ignore[assignment] +<<<<<<< HEAD def _get_name(self) -> str: return "QuantizedLSTM" @classmethod def from_float(cls, *args: Any, **kwargs: Any) -> None: +======= + def _get_name(self): + return "QuantizedLSTM" + + @classmethod + def from_float(cls, *args, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # The whole flow is float -> observed -> quantized # This class does observed -> quantized only raise NotImplementedError( @@ -50,7 +62,11 @@ def from_float(cls, *args: Any, **kwargs: Any) -> None: ) @classmethod +<<<<<<< HEAD def from_observed(cls: type["LSTM"], other: torch.ao.nn.quantizable.LSTM) -> "LSTM": +======= + def from_observed(cls, other): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(other, cls._FLOAT_MODULE) # type: ignore[has-type] converted = torch.ao.quantization.convert( other, inplace=False, remove_qconfig=True diff --git a/torch/ao/nn/quantized/reference/modules/conv.py b/torch/ao/nn/quantized/reference/modules/conv.py index de2ea9c6da8d0..bd7450b7ee610 100644 --- a/torch/ao/nn/quantized/reference/modules/conv.py +++ b/torch/ao/nn/quantized/reference/modules/conv.py @@ -1,5 +1,9 @@ # mypy: allow-untyped-defs +<<<<<<< HEAD from typing import Any, Literal, Optional +======= +from typing import Any, Optional +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch.nn as nn @@ -62,7 +66,11 @@ def __init__( dilation: _size_1_t = 1, groups: int = 1, bias: bool = True, +<<<<<<< HEAD padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", +======= + padding_mode: str = "zeros", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device=None, dtype=None, weight_qparams: Optional[dict[str, Any]] = None, @@ -282,7 +290,11 @@ def __init__( groups: int = 1, bias: bool = True, dilation: _size_1_t = 1, +<<<<<<< HEAD padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", +======= + padding_mode: str = "zeros", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device=None, dtype=None, weight_qparams: Optional[dict[str, Any]] = None, diff --git a/torch/ao/nn/sparse/quantized/utils.py b/torch/ao/nn/sparse/quantized/utils.py index ccf85e68d84ff..1399b60f87b1d 100644 --- a/torch/ao/nn/sparse/quantized/utils.py +++ b/torch/ao/nn/sparse/quantized/utils.py @@ -15,7 +15,11 @@ def _is_valid_linear_block_sparse_pattern( # This is a stop-gap measure as current flow does not allow module # specific block sparse pattern. +<<<<<<< HEAD # In fact there is no way to convey sparse pattern via module config +======= +# Infact there is no way to convey sparse pattern via module config +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # of quantization flow. Thus using the global context to convey # sparsity pattern. # Once the flow supports it, this should be removed. diff --git a/torch/ao/ns/fx/graph_passes.py b/torch/ao/ns/fx/graph_passes.py index bc30a014c195a..08a9dadee627f 100644 --- a/torch/ao/ns/fx/graph_passes.py +++ b/torch/ao/ns/fx/graph_passes.py @@ -1124,7 +1124,11 @@ def load_arg(a): # (prev_node_c+) -> (logger_c_input)? -> node_start_c -> ... -> node_end_c -> logger_c # # Note: node_start_c may be the same node as node_end_c, or they +<<<<<<< HEAD # may have nodes in between. +======= + # may have nodes inbetween. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: env_c[node_b.name] = graph_c.node_copy(node_b, load_arg) diff --git a/torch/ao/ns/fx/qconfig_multi_mapping.py b/torch/ao/ns/fx/qconfig_multi_mapping.py index 530d5ce52d998..47ce1c0dee937 100644 --- a/torch/ao/ns/fx/qconfig_multi_mapping.py +++ b/torch/ao/ns/fx/qconfig_multi_mapping.py @@ -109,7 +109,11 @@ def _handle_list_size_mismatch( target_qconfigs_dict[key] = None break +<<<<<<< HEAD # insert copies of this new QConfigMapping until all entries +======= + # insert copies of this new QConfigMapping until all entires +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # in qconfig_list can fit among the QConfigMappings while len(qconfig_list) > len(self.qconfig_mappings_list): self.qconfig_mappings_list.append(copy.deepcopy(new_qconfig_mapping)) diff --git a/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py b/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py index ef6a35686c7d6..ba4eb2fc7ad40 100644 --- a/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py +++ b/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py @@ -159,7 +159,11 @@ def hook(module, input) -> None: if data is None: out_data = [ 0 for _ in range(0, len(features)) +<<<<<<< HEAD ] # create one in case of 1st forward +======= + ] # create one incase of 1st forward +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.state[name]["mask"] = [0 for _ in range(0, len(features))] else: out_data = data # a list diff --git a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/README.md b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/README.md index 234a573029f80..b27087a2fa07c 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/README.md +++ b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/README.md @@ -14,7 +14,11 @@ The [DataNormSparsifier](https://github.com/pytorch/pytorch/blob/main/torch/ao/p 3. Norm: L1 and L2 ## Dataset +<<<<<<< HEAD The benchmarks are created for the dlrm model on the Kaggle CriteoDataset which can be downloaded from [here](https://ailab.criteo.com/ressources/) or [here](https://figshare.com/articles/dataset/Kaggle_Display_Advertising_Challenge_dataset/5732310/1). +======= +The benchmarks are created for the dlrm model on the Kaggle CriteoDataset which can be downloaded from [here](https://ailab.criteo.com/ressources/) or [here](https://figshare.com/articles/dataset/Kaggle_Display_Advertising_Challenge_dataset/5732310/1). +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ## Results 1. **Disk Usage**: Introducing sparsity in the embeddings reduces file size after compression. The compressed model size goes down from 1.9 GB to 150 MB after 100% sparsity. @@ -34,7 +38,11 @@ The takeaway is that the dlrm model with sparse coo tensor is slower (roughly 2x ## Setup The benchmark codes depend on the [DLRM codebase](https://github.com/facebookresearch/dlrm). 1. Clone the dlrm git repository +<<<<<<< HEAD 2. Download the dataset from [here](https://ailab.criteo.com/ressources/) or [here](https://figshare.com/articles/dataset/Kaggle_Display_Advertising_Challenge_dataset/5732310/1) +======= +2. Download the dataset from [here](https://ailab.criteo.com/ressources/) or [here](https://figshare.com/articles/dataset/Kaggle_Display_Advertising_Challenge_dataset/5732310/1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) 3. The DLRM model can be trained using the following script ``` # Make sure you go into the file and make sure that the path to dataset is correct. diff --git a/torch/ao/pruning/_experimental/data_sparsifier/lightning/tests/test_callbacks.py b/torch/ao/pruning/_experimental/data_sparsifier/lightning/tests/test_callbacks.py index 442639be9b214..b278f3c3d165a 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/lightning/tests/test_callbacks.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/lightning/tests/test_callbacks.py @@ -199,7 +199,11 @@ def _check_on_train_epoch_start(self, pl_module, callback): do not want as the config of each layer changes after .step() +<<<<<<< HEAD Hence, we need to dump and restore the state_dict() every time because we're +======= + Hence, we need to dump and restore the state_dict() everytime because we're +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) copying the model after each epoch. Hence, it is essential to make sure that the sparsifier's state_dict() is being correctly dumped and restored. diff --git a/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py b/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py index 680ecd9f139e3..d6f1761292580 100644 --- a/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py +++ b/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py @@ -11,7 +11,11 @@ class FPGMPruner(BaseStructuredSparsifier): r"""Filter Pruning via Geometric Median (FPGM) Structured Pruner +<<<<<<< HEAD This sparsifier prune filter (row) in a tensor according to distances among filters according to +======= + This sparsifier prune fliter (row) in a tensor according to distances among filters according to +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) `Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration `_. This sparsifier is controlled by three variables: diff --git a/torch/ao/pruning/_experimental/pruner/lstm_saliency_pruner.py b/torch/ao/pruning/_experimental/pruner/lstm_saliency_pruner.py index f904cc3ab8c4c..1cdb3f550d516 100644 --- a/torch/ao/pruning/_experimental/pruner/lstm_saliency_pruner.py +++ b/torch/ao/pruning/_experimental/pruner/lstm_saliency_pruner.py @@ -1,3 +1,4 @@ +<<<<<<< HEAD from typing import Any, cast import torch @@ -5,6 +6,14 @@ from .base_structured_sparsifier import BaseStructuredSparsifier from .parametrization import FakeStructuredSparsity +======= +# mypy: allow-untyped-defs +from typing import cast + +import torch + +from .base_structured_sparsifier import BaseStructuredSparsifier, FakeStructuredSparsity +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class LSTMSaliencyPruner(BaseStructuredSparsifier): @@ -26,7 +35,11 @@ class LSTMSaliencyPruner(BaseStructuredSparsifier): This applies to both weight_ih_l{k} and weight_hh_l{k}. """ +<<<<<<< HEAD def update_mask(self, module: nn.Module, tensor_name: str, **kwargs: Any) -> None: +======= + def update_mask(self, module, tensor_name, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) weights = getattr(module, tensor_name) for p in getattr(module.parametrizations, tensor_name): diff --git a/torch/ao/pruning/_experimental/pruner/saliency_pruner.py b/torch/ao/pruning/_experimental/pruner/saliency_pruner.py index 1a97cff7ab231..ea02f67623eb8 100644 --- a/torch/ao/pruning/_experimental/pruner/saliency_pruner.py +++ b/torch/ao/pruning/_experimental/pruner/saliency_pruner.py @@ -7,7 +7,11 @@ class SaliencyPruner(BaseStructuredSparsifier): Prune rows based on the saliency (L1 norm) of each row. This pruner works on N-Dimensional weight tensors. +<<<<<<< HEAD For each row, we will calculate the saliency, which is the sum the L1 norm of all weights in that row. +======= + For each row, we will calculate the saliency, whic is the sum the L1 norm of all weights in that row. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) We expect that the resulting saliency vector has the same shape as our mask. We then pick elements to remove until we reach the target sparsity_level. """ diff --git a/torch/ao/pruning/scheduler/lambda_scheduler.py b/torch/ao/pruning/scheduler/lambda_scheduler.py index 5588c157161a0..cdfa7031e1868 100644 --- a/torch/ao/pruning/scheduler/lambda_scheduler.py +++ b/torch/ao/pruning/scheduler/lambda_scheduler.py @@ -1,7 +1,12 @@ +<<<<<<< HEAD import warnings from typing import Callable, Union from torch.ao.pruning.sparsifier.base_sparsifier import BaseSparsifier +======= +# mypy: allow-untyped-defs +import warnings +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .base_scheduler import BaseScheduler @@ -32,6 +37,7 @@ class LambdaSL(BaseScheduler): >>> scheduler.step() """ +<<<<<<< HEAD def __init__( self, sparsifier: BaseSparsifier, @@ -39,6 +45,9 @@ def __init__( last_epoch: int = -1, verbose: bool = False, ) -> None: +======= + def __init__(self, sparsifier, sl_lambda, last_epoch=-1, verbose=False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.sparsifier = sparsifier if not isinstance(sl_lambda, list) and not isinstance(sl_lambda, tuple): @@ -49,9 +58,15 @@ def __init__( f"Expected {len(sparsifier.groups)} lr_lambdas, but got {len(sl_lambda)}" ) self.sl_lambdas = list(sl_lambda) +<<<<<<< HEAD super().__init__(sparsifier, last_epoch, verbose) # type: ignore[no-untyped-call] def get_sl(self) -> list[float]: +======= + super().__init__(sparsifier, last_epoch, verbose) + + def get_sl(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not self._get_sl_called_within_step: warnings.warn( "To get the last sparsity level computed by the scheduler, " diff --git a/torch/ao/pruning/sparsifier/utils.py b/torch/ao/pruning/sparsifier/utils.py index 47185aeea5274..71ee52d358614 100644 --- a/torch/ao/pruning/sparsifier/utils.py +++ b/torch/ao/pruning/sparsifier/utils.py @@ -98,7 +98,11 @@ def get_arg_info_from_tensor_fqn(model: nn.Module, tensor_fqn: str) -> dict[str, # string manip to split tensor_fqn into module_fqn and tensor_name # if tensor_fqn is 'weight' then module_fqn and tensor_name are '' and 'weight' # if tensor_fqn is 'linear.weight' then module_fqn and tensor_name are 'linear' and 'weight' +<<<<<<< HEAD tensor_name = tensor_fqn.rsplit(".", maxsplit=1)[-1] +======= + tensor_name = tensor_fqn.split(".")[-1] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) module_fqn = tensor_fqn[: -len(tensor_name) - ("." in tensor_fqn)] module = fqn_to_module(model, module_fqn) diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py index f50b9d6cd137e..90b300fdb650f 100644 --- a/torch/ao/quantization/__init__.py +++ b/torch/ao/quantization/__init__.py @@ -1,6 +1,9 @@ # mypy: allow-untyped-defs +<<<<<<< HEAD import sys +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from typing import Callable, Optional, Union import torch @@ -33,6 +36,7 @@ # ensure __module__ is set correctly for public APIs +<<<<<<< HEAD if sys.version_info < (3, 12): ObserverOrFakeQuantize = Union[ObserverBase, FakeQuantizeBase] ObserverOrFakeQuantize.__module__ = "torch.ao.quantization" @@ -43,6 +47,10 @@ "ObserverOrFakeQuantize", Union[ObserverBase, FakeQuantizeBase] ) +======= +ObserverOrFakeQuantize = Union[ObserverBase, FakeQuantizeBase] +ObserverOrFakeQuantize.__module__ = "torch.ao.quantization" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for _f in [ compare_results, extract_results_from_loggers, diff --git a/torch/ao/quantization/backend_config/observation_type.py b/torch/ao/quantization/backend_config/observation_type.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/torch/ao/quantization/experimental/adaround_loss.py b/torch/ao/quantization/experimental/adaround_loss.py index 9b0ce6a32f14d..cb6fc2d7a6011 100644 --- a/torch/ao/quantization/experimental/adaround_loss.py +++ b/torch/ao/quantization/experimental/adaround_loss.py @@ -54,7 +54,11 @@ def rounding_regularization( 1 + np.cos(rel_iter * np.pi) ) +<<<<<<< HEAD # A rectified sigmoid for soft-quantization as formulated [23] in https://arxiv.org/pdf/2004.10568.pdf +======= + # A rectified sigmoid for soft-quantization as formualted [23] in https://arxiv.org/pdf/2004.10568.pdf +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) h_alpha = torch.clamp( torch.sigmoid(V) * (ADAROUND_ZETA - ADAROUND_GAMMA) + ADAROUND_GAMMA, min=0, diff --git a/torch/ao/quantization/experimental/adaround_optimization.py b/torch/ao/quantization/experimental/adaround_optimization.py index fd2d8124bb701..767d98610d766 100644 --- a/torch/ao/quantization/experimental/adaround_optimization.py +++ b/torch/ao/quantization/experimental/adaround_optimization.py @@ -107,7 +107,11 @@ def get_data_inp_out( ) if torch.cuda.is_available(): # Somehow, we need to move the model continuously +<<<<<<< HEAD # Otherwise, the model will be lowered to CPU mysteriously +======= + # Otherwise, the model will be lowered to CPU misteriously +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.model = self.model.cuda() self.q_model = self.q_model.cuda() for data_ in data: diff --git a/torch/ao/quantization/fx/README.md b/torch/ao/quantization/fx/README.md index cd380977b2aa5..7a82e107168c2 100644 --- a/torch/ao/quantization/fx/README.md +++ b/torch/ao/quantization/fx/README.md @@ -296,7 +296,11 @@ BackendConfig(nniqat.LinearReLU) Pattern in this case is the same as before, it defines the pattern for the subgraph we are dealing with +<<<<<<< HEAD `set_observation_type`: sets the observation type for the pattern, currently only two types: +======= +`set_observation_type`: sets the observation type for the patter, currently only two types: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) `OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT` means the output observer instance will be different from the input, which is the most common type of observer placement. diff --git a/torch/ao/quantization/fx/_model_report/README.md b/torch/ao/quantization/fx/_model_report/README.md index fa4f142aa23cf..cffc827a36ec3 100644 --- a/torch/ao/quantization/fx/_model_report/README.md +++ b/torch/ao/quantization/fx/_model_report/README.md @@ -8,10 +8,17 @@ ModelReport Most detectors require a **traceable GraphModule**, but some (ex. `PerChannelDetector`) require just an `nn.Module`. #### Typical Fx Workflow +<<<<<<< HEAD - Initialize model → Prepare model → Calibrate model → Convert model → ... #### Fx Workflow with ModelReport - Initialize model → Prepare model → **Add detector observers** → Calibrate model → **Generate report** → **Remove detector observers** → Convert model → ... +======= +- Initialize model → Prepare model → Callibrate model → Convert model → ... + +#### Fx Workflow with ModelReport +- Initialize model → Prepare model → **Add detector observers** → Callibrate model → **Generate report** → **Remove detector observers** → Convert model → ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) > āš ļø **You can only prepare and remove observers once with a given ModelReport Instance**: Be very careful here! @@ -23,7 +30,11 @@ This snippet should be ready to copy, paste, and use with the exception of a few # prep model qconfig_mapping = torch.ao.quantization.get_default_qconfig_mapping() model = Model() # TODO define model +<<<<<<< HEAD example_input = torch.randn((*args)) # TODO get example data for calibration +======= +example_input = torch.randn((*args)) # TODO get example data for callibration +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) prepared_model = quantize_fx.prepare_fx(model, qconfig_mapping, example_input) # create ModelReport instance and insert observers @@ -31,8 +42,13 @@ detector_set = set([DynamicStaticDetector()]) # TODO add all desired detectors model_report = ModelReport(model, detector_set) ready_for_callibrate = model_report.prepare_detailed_callibration() +<<<<<<< HEAD # calibrate model and generate report ready_for_callibrate(example_input) # TODO run calibration of model with relevant data +======= +# callibrate model and generate report +ready_for_callibrate(example_input) # TODO run callibration of model with relevant data +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) reports = model_report.generate_model_report(remove_inserted_observers=True) for report_name in report.keys(): text_report, report_dict = reports[report_name] @@ -46,7 +62,11 @@ mod_rep_visualizer.generate_table_visualization() # shows collected data as a ta ``` There is a tutorial in the works that will walk through a full usage of the ModelReport API. +<<<<<<< HEAD This tutorial will show the ModelReport API being used on toy model in both an Fx Graph Mode workflow and an alternative workflow with just a traceable model. +======= +This tutorial will show the ModelReport API being used on toy model in both an Fx Graph Mode workflow and an alterative workflow with just a traceable model. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) This README will be updated with a link to the tutorial upon completion of the tutorial. # Key Modules Overview @@ -60,7 +80,11 @@ There are three primary methods to be familiar with when using the ModelReport c This is so that we can keep track of where we want to insert observers on a detector by detector basis and also keep track of which detectors to generate reports for. - `prepare_detailed_calibration(self)` → `GraphModule` inserts observers into the locations specified by each detector in the model. It then returns the GraphModule with the detectors inserted into both the regular module structure as well as the node structure. +<<<<<<< HEAD - `generate_model_report(self, remove_inserted_observers: bool)` → `Dict[str, Tuple[str, Dict]]` uses calibrated GraphModule to optionally removes inserted observers, and generate, for each detector the ModelReport instance was initialized with: +======= +- `generate_model_report(self, remove_inserted_observers: bool)` → `Dict[str, Tuple[str, Dict]]` uses callibrated GraphModule to optionally removes inserted observers, and generate, for each detector the ModelReport instance was initialized with: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - A string-based report that is easily digestable and actionable explaining the data collected by relevant observers for that detector - A dictionary containing statistics collected by the relevant observers and values calculated by the detector for further analysis or plotting @@ -107,7 +131,11 @@ For both of the two things listed above, you can filter the data by either `modu To get a list of all the modules or features, you can call `mod_rep_visualizer.get_all_unique_module_fqns()` and `mod_rep_visualizer.get_all_unique_feature_names()` respectively. For the features, because some features are not plottable, you can set the flag to only get plottable features +<<<<<<< HEAD in the aforementioned `get_all_unique_feature_names` method. +======= +in the aformentioned `get_all_unique_feature_names` method. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ## Detector Overview @@ -152,7 +180,11 @@ The statistics collected by the `ModelReportObserver` include: - Ratio of 100th percentile to some *n*th percentile - Number of constant value batches to pass through each channel +<<<<<<< HEAD After the `ModelReportObserver` collects the statistics above during the calibration process, the detectors then extract the information they need to generate their reports from the relevant observers. +======= +After the `ModelReportObserver` collects the statistics above during the callibration process, the detectors then extract the information they need to generate their reports from the relevant observers. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ### Using Your Own Observer diff --git a/torch/ao/quantization/fx/_model_report/model_report.py b/torch/ao/quantization/fx/_model_report/model_report.py index 04035b41bf68e..6a46cf443876b 100644 --- a/torch/ao/quantization/fx/_model_report/model_report.py +++ b/torch/ao/quantization/fx/_model_report/model_report.py @@ -36,7 +36,11 @@ class ModelReport: - Suggestions for outlier detection for all layers (Graph Modules) The ModelReport class has the primary functionality of inserting observers (primarily the ModelReportObserver) +<<<<<<< HEAD where needed for each detector to gather the information it needs, and then after calibration, the ModelReport +======= + where needed for each detector to gather the information it needs, and then after callibration, the ModelReport +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class compiles the report generated by each Detector class into a single report to return to the user. It also has the capability to remove all the observers it inserted as well. @@ -70,7 +74,11 @@ class compiles the report generated by each Detector class into a single report 1.) Initialize ModelReport object with reports of interest by passing in initialized detector objects and model 2.) Prepare your model with prepare_fx 3.) Call model_report.prepare_detailed_calibration to add relevant observers +<<<<<<< HEAD 4.) Calibrate your model with data +======= + 4.) Callibrate your model with data +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) 5.) Call model_report.generate_report on your model to generate report and optionally remove added observers Optional 6.) Call model_report.generate_visualizer to get a ModelReportVisualizer instance @@ -102,7 +110,11 @@ class compiles the report generated by each Detector class into a single report ... ) >>> tracer_reporter = ModelReport(graph_module, tracer_detector_set) +<<<<<<< HEAD >>> # now we insert the observers and calibrate the model +======= + >>> # now we insert the observers and callibrate the model +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> tracer_model_with_observers = tracer_reporter.prepare_detailed_calibration() >>> for i in range(num_callibration_batches): >>> example_input = get_callibration_input() @@ -179,7 +191,11 @@ def prepare_detailed_calibration(self) -> GraphModule: # if already prepared once, cannot prepare again if self._prepared_flag: raise ValueError( +<<<<<<< HEAD "Already ran preparing detailed calibration. Run the report generation next after calibration." +======= + "Already ran preparing detailed callibration. Run the report generation next after callibration." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # loop through each detector, find where placements should be, and keep track @@ -271,7 +287,11 @@ def generate_model_report( Generates all the requested reports. Note: +<<<<<<< HEAD You should have calibrated the model with relevant data before calling this +======= + You should have callibrated the model with relevant data before calling this +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) The reports generated are specified by the desired_reports specified in desired_reports @@ -286,12 +306,21 @@ def generate_model_report( Note: Throws exception if we try to generate report on model we already removed observers from +<<<<<<< HEAD Throws exception if we try to generate report without preparing for calibration """ # if we haven't prepped model for calibration, then we shouldn't generate report yet if not self._prepared_flag: raise Exception( # noqa: TRY002 "Cannot generate report without preparing model for calibration" +======= + Throws exception if we try to generate report without preparing for callibration + """ + # if we haven't prepped model for callibration, then we shouldn't generate report yet + if not self._prepared_flag: + raise Exception( # noqa: TRY002 + "Cannot generate report without preparing model for callibration" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # if we already removed the observers, we cannot generate report @@ -546,12 +575,21 @@ def _generate_module_fqn_to_detector_info_mapping( Note: Throws exception if we try to generate mapping on model we already removed observers from +<<<<<<< HEAD Throws exception if we try to generate mapping without preparing for calibration """ # if we haven't prepped model for calibration, then we shouldn't generate mapping yet if not self._prepared_flag: raise Exception( # noqa: TRY002 "Cannot generate report without preparing model for calibration" +======= + Throws exception if we try to generate mapping without preparing for callibration + """ + # if we haven't prepped model for callibration, then we shouldn't generate mapping yet + if not self._prepared_flag: + raise Exception( # noqa: TRY002 + "Cannot generate report without preparing model for callibration" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # if we already removed the observers, we cannot mapping @@ -600,7 +638,11 @@ def generate_qconfig_mapping(self) -> QConfigMapping: Note: Throws exception if we try to generate mapping on model we already removed observers from +<<<<<<< HEAD Throws exception if we try to generate mapping without preparing for calibration +======= + Throws exception if we try to generate mapping without preparing for callibration +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ # get the mapping info detector_qconfig_info_combined = ( diff --git a/torch/ao/quantization/fx/_model_report/model_report_visualizer.py b/torch/ao/quantization/fx/_model_report/model_report_visualizer.py index 63d31171bbe76..e2822ff3e92ec 100644 --- a/torch/ao/quantization/fx/_model_report/model_report_visualizer.py +++ b/torch/ao/quantization/fx/_model_report/model_report_visualizer.py @@ -63,7 +63,11 @@ class ModelReportVisualizer: 1.) Initialize ModelReport object with reports of interest by passing in initialized detector objects 2.) Prepare your model with prepare_fx 3.) Call model_report.prepare_detailed_calibration on your model to add relevant observers +<<<<<<< HEAD 4.) Calibrate your model with data +======= + 4.) Callibrate your model with data +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) 5.) Call model_report.generate_report on your model to generate report and optionally remove added observers 6.) Use output of model_report.generate_report to initialize ModelReportVisualizer instance 7.) Use instance to view different views of data as desired, applying filters as needed diff --git a/torch/ao/quantization/fx/convert.py b/torch/ao/quantization/fx/convert.py index dc51ab943bc5b..02d6d54c970f6 100644 --- a/torch/ao/quantization/fx/convert.py +++ b/torch/ao/quantization/fx/convert.py @@ -94,7 +94,10 @@ def _replace_observer_with_quantize_dequantize_node_decomposed( modules: dict[str, torch.nn.Module], node_name_to_scope: dict[str, tuple[str, type]], node_name_to_qconfig: dict[str, QConfigAny], +<<<<<<< HEAD model_device: Optional[torch.device] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: """Replace activation_post_process module call node with quantize and dequantize node working with decomposed Tensor @@ -211,11 +214,15 @@ def add_dequantize_op_kwargs(dequantize_op, input_node): # sure that the default overload can be used. # TODO: maybe need more complex attr name here qparam_node = create_getattr_from_value( +<<<<<<< HEAD model, graph, module_path + prefix + key, value_or_node, model_device, +======= + model, graph, module_path + prefix + key, value_or_node +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) quantize_op_inputs.append(qparam_node) else: @@ -367,7 +374,10 @@ def _replace_observer_with_quantize_dequantize_node( modules: dict[str, torch.nn.Module], node_name_to_scope: dict[str, tuple[str, type]], node_name_to_qconfig: dict[str, QConfigAny], +<<<<<<< HEAD model_device: Optional[torch.device] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: """Replace activation_post_process module call node with quantize and dequantize node @@ -448,11 +458,15 @@ def _replace_observer_with_quantize_dequantize_node( # For scale and zero_point values we register them as buffers in the root module. # TODO: maybe need more complex attr name here qparam_node = create_getattr_from_value( +<<<<<<< HEAD model, graph, module_path + prefix + key, value_or_node, model_device, +======= + model, graph, module_path + prefix + key, value_or_node +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) quantize_op_inputs.append(qparam_node) else: @@ -750,7 +764,10 @@ def convert_weighted_module( backend_config: BackendConfig, is_decomposed: bool = False, is_reference: bool = False, +<<<<<<< HEAD model_device: Optional[torch.device] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: """Convert a weighted module to reference quantized module in the model If the QConfig of a QAT module is not set, the module will still be converted to @@ -839,10 +856,14 @@ def convert_weighted_module( is_ptq = weight_post_process is None if is_ptq: weight_post_process = qconfig.weight() # type: ignore[union-attr, operator] +<<<<<<< HEAD if model_device is not None: device = model_device else: device = assert_and_get_unique_device(float_module) +======= + device = assert_and_get_unique_device(float_module) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if device: weight_post_process.to(device) @@ -1158,7 +1179,10 @@ def convert( qat_module_classes = get_qat_module_classes(backend_config) fused_module_classes = get_fused_module_classes(backend_config) statically_quantized_custom_module_nodes: set[Node] = set() +<<<<<<< HEAD model_device = assert_and_get_unique_device(model) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for node in list(model.graph.nodes): if node.op == "placeholder": @@ -1212,7 +1236,10 @@ def convert( modules, node_name_to_scope, node_name_to_qconfig, +<<<<<<< HEAD model_device, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) else: _replace_observer_with_quantize_dequantize_node( @@ -1221,7 +1248,10 @@ def convert( modules, node_name_to_scope, node_name_to_qconfig, +<<<<<<< HEAD model_device, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) elif isinstance(mod, DeQuantStub): _replace_observer_or_dequant_stub_with_dequantize_node( @@ -1251,7 +1281,10 @@ def convert( backend_config, is_decomposed, is_reference, +<<<<<<< HEAD model_device, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) elif type_before_parametrizations(mod) in custom_module_classes: convert_custom_module( diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py index e70a078630d9d..ba57b639aa5df 100644 --- a/torch/ao/quantization/fx/prepare.py +++ b/torch/ao/quantization/fx/prepare.py @@ -478,7 +478,10 @@ def _insert_obs_or_fq( model: torch.nn.Module, named_modules: dict[str, torch.nn.Module], graph: Graph, +<<<<<<< HEAD model_device: Optional[torch.device] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Node: """ Attaches `obs_or_fq` to `model`, and creates a node which calls @@ -486,8 +489,12 @@ def _insert_obs_or_fq( obs_or_fq: an instance of Observer or FakeQuantize module """ +<<<<<<< HEAD if model_device is None: model_device = assert_and_get_unique_device(model) +======= + model_device = assert_and_get_unique_device(model) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if model_device: obs_or_fq.to(model_device) # add obs_or_fq module as attribute @@ -807,7 +814,10 @@ def _maybe_insert_input_observer_for_arg_or_kwarg( obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], is_qat: bool, backend_config: Optional[BackendConfig] = None, +<<<<<<< HEAD model_device: Optional[torch.device] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Argument: """ Given a `node` and an `arg`, inserts an input observer between @@ -830,7 +840,10 @@ def _maybe_insert_input_observer_for_arg_or_kwarg( obs_or_fq_map, is_qat, backend_config, +<<<<<<< HEAD model_device, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) new_arg_to_return.append(new_inner_arg) return type(arg)(new_arg_to_return) @@ -949,12 +962,16 @@ def _maybe_insert_input_observer_for_arg_or_kwarg( obs_or_fq_map[(arg, node)] = arg_as_input_act_obs_or_fq if existing_obs_node is None: new_obs_node = _insert_obs_or_fq( +<<<<<<< HEAD arg, arg_as_input_act_obs_or_fq, model, named_modules, graph, model_device, +======= + arg, arg_as_input_act_obs_or_fq, model, named_modules, graph +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # override this arg to be the observed arg new_arg = new_obs_node @@ -975,7 +992,10 @@ def _maybe_insert_input_observers_for_node( obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], is_qat: bool, backend_config: Optional[BackendConfig] = None, +<<<<<<< HEAD model_device: Optional[torch.device] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: """ If needed, inserts observers to the input args and kwargs of `node`. @@ -1007,7 +1027,10 @@ def _maybe_insert_input_observers_for_node( obs_or_fq_map, is_qat, backend_config, +<<<<<<< HEAD model_device, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) new_args.append(new_arg) @@ -1025,7 +1048,10 @@ def _maybe_insert_input_observers_for_node( obs_or_fq_map, is_qat, backend_config, +<<<<<<< HEAD model_device, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) new_kwargs[k] = new_kwarg @@ -1119,7 +1145,11 @@ def _maybe_insert_output_observer_for_node( ) target_dtype, target_is_dynamic = _get_dtype_and_is_dynamic(output_act_obs_or_fq) # uncomment after we support reuse_input_obs_or_fq properly by having separate +<<<<<<< HEAD # implementations for this key instead of reusing the input_output_share_observers +======= + # implemntations for this key instead of reusing the input_output_share_observers +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # code # reuse_input_obs_or_fq = node.meta["target_dtype_info"].get("reuse_input_obs_or_fq", False) # for now we set this to False since reuse_input_obs_or_fq for @@ -1129,7 +1159,11 @@ def _maybe_insert_output_observer_for_node( reuse_input_obs_or_fq = False # Note: prev_output_dtype = torch.float and prev_output_is_dynamic=False +<<<<<<< HEAD # because the prev_output is the output of an fp32 op, although technically +======= + # because the prev_output is the output of an fp32 op, althought technically +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # we should get the dtype of the output from node.meta["val"] in the future # if we deprecate fx graph mode quantization needs_obs_or_fq = _needs_obs_or_fq( @@ -1675,7 +1709,10 @@ def insert_observers_for_model( outputs_seen_counter = 0 results_node = None obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize] = {} +<<<<<<< HEAD model_device = assert_and_get_unique_device(model) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO: change this to insert obs/fq by pattern instead of by node for node in nodes_before_observation: @@ -1779,7 +1816,10 @@ def insert_observers_for_model( obs_or_fq_map, is_qat, backend_config, +<<<<<<< HEAD model_device, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # insert equalization input observers if needed @@ -2016,7 +2056,11 @@ def prepare( same as input_quantized_idxs configuration provided for the standalone module standalone_module_output_quantized_idxs(List[Int]): a list of +<<<<<<< HEAD indices for the graph output that is quantized +======= + indexs for the graph output that is quantized +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) same as input_quantized_idxs configuration provided for the standalone module """ diff --git a/torch/ao/quantization/fx/utils.py b/torch/ao/quantization/fx/utils.py index f8445da5fea19..3b10a52e00f99 100644 --- a/torch/ao/quantization/fx/utils.py +++ b/torch/ao/quantization/fx/utils.py @@ -190,7 +190,11 @@ def get_attr_name(i: int): def collect_producer_nodes(node: Node) -> Optional[list[Node]]: +<<<<<<< HEAD r"""Starting from a target node, trace back until we hit input or +======= + r"""Starting from a target node, trace back until we hit inpu or +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) getattr node. This is used to extract the chain of operators starting from getattr to the target node, for example def forward(self, x): @@ -254,11 +258,15 @@ def assert_and_get_unique_device(module: torch.nn.Module) -> Any: def create_getattr_from_value( +<<<<<<< HEAD module: torch.nn.Module, graph: Graph, prefix: str, value: Any, device: Optional[torch.device] = None, +======= + module: torch.nn.Module, graph: Graph, prefix: str, value: Any +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Node: """ Given a value of any type, creates a getattr node corresponding to the value and @@ -266,8 +274,12 @@ def create_getattr_from_value( """ get_new_attr_name = get_new_attr_name_with_prefix(prefix) attr_name = get_new_attr_name(module) +<<<<<<< HEAD if device is None: device = assert_and_get_unique_device(module) +======= + device = assert_and_get_unique_device(module) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) new_value = ( value.detach().clone() if isinstance(value, torch.Tensor) diff --git a/torch/ao/quantization/observer.py b/torch/ao/quantization/observer.py index 7b56fbe7232cb..699863bb2ae36 100644 --- a/torch/ao/quantization/observer.py +++ b/torch/ao/quantization/observer.py @@ -358,7 +358,11 @@ def _calculate_qparams( # Functionally equivalent to 'determine_qparams' in utils.py. Observers must be torchscriptable however and qscheme # as far as I can tell is not allowed to passed as a parameter in torchscript functions. This makes refactoring observer # to use this utility a massive pain and very gross. For now Im opting just to duplicate as this code +<<<<<<< HEAD # seems unlikely to change (last update over 1 year ago) and when torchscript is fully deprecated we can refactor. +======= + # seems unlikey to change (last update over 1 year ago) and when torchscript is fully deprecated we can refactor. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO(jakeszwe, jerryzh168) if not check_min_max_valid(min_val, max_val): return torch.tensor([1.0], device=min_val.device.type), torch.tensor( @@ -1241,7 +1245,11 @@ def _combine_histograms( # If the orig hist only has one value (i.e., the min and max are the same) # we can just add it into new histogram if orig_min == orig_max: +<<<<<<< HEAD bin_value = torch.sum(orig_hist) +======= + bin_value = torch.sum(update_hist) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) transformed_orig_hist = ( torch.histc(orig_min, bins=self.bins, min=update_min, max=update_max) # type: ignore[arg-type] * bin_value @@ -1866,7 +1874,11 @@ def convert(self, model: torch.fx.GraphModule, observer_node: Node): Converts the observer node in the graph into its quantized representation Args: +<<<<<<< HEAD model: graph module to convert the observer node in +======= + model: graph module to conver the observer node in +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) observer_node: the observer node to convert """ from torch.ao.quantization.fx.utils import create_getattr_from_value @@ -1902,6 +1914,7 @@ def convert(self, model: torch.fx.GraphModule, observer_node: Node): else: scale, zero_point = self.calculate_qparams() scale_node = create_getattr_from_value( +<<<<<<< HEAD model, model.graph, "_scale", @@ -1914,6 +1927,12 @@ def convert(self, model: torch.fx.GraphModule, observer_node: Node): "_zero_point", zero_point, zero_point.device if isinstance(zero_point, torch.Tensor) else None, +======= + model, model.graph, "_scale", scale + ) + zero_point_node = create_getattr_from_value( + model, model.graph, "_zero_point", zero_point +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) q_node = model.graph.call_function( diff --git a/torch/ao/quantization/pt2e/_affine_quantization.py b/torch/ao/quantization/pt2e/_affine_quantization.py index e4eac6f6cc776..e6ce89746eeb0 100644 --- a/torch/ao/quantization/pt2e/_affine_quantization.py +++ b/torch/ao/quantization/pt2e/_affine_quantization.py @@ -1,6 +1,10 @@ # copied from https://github.com/pytorch/ao/blob/main/torchao/quantization/observer.py # and https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py +<<<<<<< HEAD # PLEASE DON'T MODIFY THIS FILE SO THAT WE DON'T GET OUT OF SYNC +======= +# PLESE DON'T MODIFY THIS FILE SO THAT WE DON'T GET OUT OF SYNC +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import logging from abc import ABCMeta from typing import Any, Optional, Union @@ -469,7 +473,11 @@ def _quantize_affine_no_dtype_cast( 1. figure out the dimension for reduction based on block_size, also reshape the input to align with the shape after reduction 2. quantize the input based on the quantization parameters scale and zero_point and args like zero_point_domain +<<<<<<< HEAD 3. reshape the quantized result to original shape +======= + 3. reshape the quantized result to origianl shape +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ # TODO: validations # TODO: validate scale/zero_point dimensions are compatible with block_size @@ -619,7 +627,11 @@ def _dequantize_affine_no_dtype_check( 1. figure out the dimension for reduction based on block_size, also reshape the input to align with the shape after reduction 2. dequantize the input based on the quantization parameters scale and zero_point and args like zero_point_domain +<<<<<<< HEAD 3. reshape the quantized result to original shape and change dtype to the output_dtype +======= + 3. reshape the quantized result to origianl shape and change dtype to the output_dtype +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ assert len(block_size) == input.dim(), ( f"Got input dim:{input.dim()}, block_size: {block_size}" diff --git a/torch/ao/quantization/pt2e/port_metadata_pass.py b/torch/ao/quantization/pt2e/port_metadata_pass.py index aab4c435c872f..5563a5fbf7a7c 100644 --- a/torch/ao/quantization/pt2e/port_metadata_pass.py +++ b/torch/ao/quantization/pt2e/port_metadata_pass.py @@ -177,19 +177,32 @@ class PortNodeMetaForQDQ(PassBase): - Example 1: - Original: [Conv -> AvgPool -> Linear] - Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> Linear -> Q -> DQ] +<<<<<<< HEAD - Inner brackets specify which nodes Q/DQ inherit metadata from +======= + - Inner brackets specify which nodes Q/DQ inherit metdata from +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - [Q-> [DQ -> Conv -> Q] -> [DQ -> AvgPool -> Q] -> [DQ -> Linear -> Q] -> DQ] - Note first Q and last DQ do not inherit metadata from any nodes - Example 2: - Original: [Conv -> AvgPool -> Linear] - AvgPool is not quantized - Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> Linear -> Q -> DQ] +<<<<<<< HEAD - Inner brackets specify which nodes Q/DQ inherit metadata from - [Q-> [DQ -> Conv -> Q] -> DQ -> [AvgPool] -> Q -> [DQ -> Linear -> Q] -> DQ] - Note DQ and Q nodes around AvgPool do not inherit metadata from AvgPool because AvgPool was not supposed to be quantized. Metadata porting relies on quantization_annotation on the nodes (in this case AvgPool node) to conclude if the node or pattern was supposed to be quantized. And subsequently decide if the preceding Q, if any, should +======= + - Inner brackets specify which nodes Q/DQ inherit metdata from + - [Q-> [DQ -> Conv -> Q] -> DQ -> [AvgPool] -> Q -> [DQ -> Linear -> Q] -> DQ] + - Note DQ and Q nodes around AvgPool do not inherit metadata from AvgPool because + AvgPool was not supposed to be quantized. Metadata porting relies on quantization_annotation + on the nodes (in this case AvgPool node) to conclude if the node or patter was + supposed to be quantized. And subsequntly decide if the preceding Q, if any, should +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inherit metadata from AvgPool. - Dynamically quantized patterns: - Input that are dynamically quantized have choose_qparams, quantize and dequantize nodes diff --git a/torch/ao/quantization/pt2e/prepare.py b/torch/ao/quantization/pt2e/prepare.py index 57ff311521015..314f97a01708c 100644 --- a/torch/ao/quantization/pt2e/prepare.py +++ b/torch/ao/quantization/pt2e/prepare.py @@ -22,7 +22,10 @@ QuantizationSpecBase, SharedQuantizationSpec, ) +<<<<<<< HEAD from torch.ao.quantization.utils import _assert_and_get_unique_device +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.fx import Graph, GraphModule, Node from torch.fx.node import Argument @@ -276,7 +279,11 @@ def _get_edge_or_node_to_group_id( _update_shared_with(input_edge, qspec, shared_with_map) +<<<<<<< HEAD # now that we get the sharing relations between all edges and nodes, we can assign group ids +======= + # now that we get the sharing relations between all edges and nodes, we can assingn group ids +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cur_group_id = 0 edge_or_node_to_group_id: dict[EdgeOrNode, int] = {} for edge_or_node in shared_with_map.keys(): @@ -320,7 +327,10 @@ def _maybe_insert_input_observer_for_arg_or_kwarg( named_modules: dict[str, torch.nn.Module], obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], is_qat: bool, +<<<<<<< HEAD model_device: Optional[torch.device] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Argument: """ Given a `node` and an `arg`, inserts an input observer between @@ -339,7 +349,10 @@ def _maybe_insert_input_observer_for_arg_or_kwarg( named_modules, obs_or_fq_map, is_qat, +<<<<<<< HEAD model_device, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) new_arg_to_return.append(new_inner_arg) return type(arg)(new_arg_to_return) @@ -393,12 +406,16 @@ def _maybe_insert_input_observer_for_arg_or_kwarg( assert isinstance(model.graph, Graph) new_arg = _insert_obs_or_fq( +<<<<<<< HEAD arg, input_edge_obs_or_fq, model, named_modules, model.graph, model_device, +======= + arg, input_edge_obs_or_fq, model, named_modules, model.graph +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return new_arg @@ -410,7 +427,10 @@ def _maybe_insert_input_observers_for_node( named_modules: dict[str, torch.nn.Module], obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], is_qat: bool, +<<<<<<< HEAD model_device: Optional[torch.device] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: """ If needed, inserts observers to the input args and kwargs of `node`. @@ -437,7 +457,10 @@ def _maybe_insert_input_observers_for_node( named_modules, obs_or_fq_map, is_qat, +<<<<<<< HEAD model_device, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) new_args.append(new_arg) @@ -462,17 +485,24 @@ def _maybe_insert_output_observer_for_node( graph: Graph, obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], is_qat: bool, +<<<<<<< HEAD model_device: Optional[torch.device] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Optional[Node]: if node in obs_or_fq_map: output_act_obs_or_fq = obs_or_fq_map[node] new_output = _insert_obs_or_fq( +<<<<<<< HEAD node, output_act_obs_or_fq, model, named_modules, graph, model_device, +======= + node, output_act_obs_or_fq, model, named_modules, graph +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # propagate numeric debug handle from original node to observer/fake_quant node if ( @@ -495,7 +525,10 @@ def _maybe_insert_input_and_output_observers_for_node( model: torch.fx.GraphModule, obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], is_qat: bool, +<<<<<<< HEAD model_device: Optional[torch.device] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): this_node_quantization_annotation = ( node.meta["quantization_annotation"] @@ -513,7 +546,10 @@ def _maybe_insert_input_and_output_observers_for_node( named_modules, obs_or_fq_map, is_qat, +<<<<<<< HEAD model_device, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) output_is_a_tensor = "val" in node.meta and isinstance(node.meta["val"], FakeTensor) @@ -522,6 +558,7 @@ def _maybe_insert_input_and_output_observers_for_node( # this returns the new observer node if it was needed maybe_output_obs_node = _maybe_insert_output_observer_for_node( +<<<<<<< HEAD node, model, named_modules, @@ -529,6 +566,9 @@ def _maybe_insert_input_and_output_observers_for_node( obs_or_fq_map, is_qat, model_device, +======= + node, model, named_modules, model.graph, obs_or_fq_map, is_qat +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if maybe_output_obs_node is None: @@ -576,16 +616,23 @@ def prepare( ) if obs_or_fq_callback: obs_or_fq_callback(model, obs_or_fq_map) +<<<<<<< HEAD model_device = _assert_and_get_unique_device(model) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for node in nodes_before_observation: # TODO: simplify logic for inserting observers _maybe_insert_input_and_output_observers_for_node( +<<<<<<< HEAD node, model, obs_or_fq_map, is_qat, model_device, +======= + node, model, obs_or_fq_map, is_qat +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) model = GraphModule(model, model.graph) diff --git a/torch/ao/quantization/pt2e/qat_utils.py b/torch/ao/quantization/pt2e/qat_utils.py index b9ce762896f1f..898b63a018749 100644 --- a/torch/ao/quantization/pt2e/qat_utils.py +++ b/torch/ao/quantization/pt2e/qat_utils.py @@ -876,7 +876,11 @@ def _fold_conv_bn_qat(m: GraphModule) -> GraphModule: m, F.conv_transpose2d, _quantized_conv2d_bn_example_inputs, is_cuda=is_cuda ) +<<<<<<< HEAD # remove in place add from batchnorm tracking training stats +======= + # remove in place add from batchnorm tracking traning stats +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for node in m.graph.nodes: if ( node.target == torch.ops.aten.add_.Tensor diff --git a/torch/ao/quantization/pt2e/representation/rewrite.py b/torch/ao/quantization/pt2e/representation/rewrite.py index 5a757a700498d..34bbde42351da 100644 --- a/torch/ao/quantization/pt2e/representation/rewrite.py +++ b/torch/ao/quantization/pt2e/representation/rewrite.py @@ -300,7 +300,11 @@ def _reference_quantized_conv2d( # Out_(i, j)_fp32 = ((X_scale * W_scale) * Sum_(over k)[(X_(i, k)_fp32 - X_zp) * (W_(i, k)_fp32 - W_zp)]) + bias_(i)_fp32 # In order to addition of bias_(i)_fp32 inside, we must do # Out_(i, j)_fp32 = (X_scale * W_scale) * (Sum_(over k)[(X_(i, k)_fp32 - X_zp) * (W_(i, k)_fp32 - W_zp)] + (1 / (X_scale * W_scale)) * bias_(i)_fp32)W_scale # noqa: B950 +<<<<<<< HEAD # Note we had to multiply bias_fp32 with X_scale * W_scale = bias_scale +======= + # Note we had to multiply bias_fp32 qith X_scale * W_scale = bias_scale +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Thus bias quantization to int32 must be with X_scale * W_scale bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale) @@ -436,7 +440,11 @@ def _reference_quantized_add( x_fp32 = (x_i8 - x_zero_point) * x_scale (3) y_fp32 = (y_i8 - y_zero_point) * y_scale (4) +<<<<<<< HEAD # applying the above formula to the out_i8 equation we can get the following: +======= + # applying the above fomula to the out_i8 equation we can get the following: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out_i8 = out_fp32 / out_scale + out_zero_point # (1) = (x_f32 + y_f32) / out_scale + out_zero_point # applying (2) to substitute out_fp32 with x_fp32 + y_fp32 = ((x_i8 - x_zero_point) * x_scale + (y_i8 - y_zero_point) * y_scale) / out_scale + out_zero_point # apply (3) and (4) diff --git a/torch/ao/quantization/pt2e/utils.py b/torch/ao/quantization/pt2e/utils.py index 699a4c384837d..13d98e729411c 100644 --- a/torch/ao/quantization/pt2e/utils.py +++ b/torch/ao/quantization/pt2e/utils.py @@ -361,7 +361,11 @@ def _get_aten_graph_module_for_pattern( example_inputs, kwargs, strict=True, +<<<<<<< HEAD ).module(check_guards=False) +======= + ).module() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) aten_pattern.graph.eliminate_dead_code() # type: ignore[operator, union-attr] aten_pattern.recompile() # type: ignore[operator] diff --git a/torch/ao/quantization/qconfig.py b/torch/ao/quantization/qconfig.py index 94dfdb8c7626a..ed46e5ea0e65f 100644 --- a/torch/ao/quantization/qconfig.py +++ b/torch/ao/quantization/qconfig.py @@ -1,6 +1,9 @@ # mypy: allow-untyped-defs import copy +<<<<<<< HEAD import sys +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import warnings from collections import namedtuple from typing import Any, Optional, Union @@ -568,6 +571,7 @@ def _assert_valid_qconfig(qconfig: Optional[QConfig], mod: torch.nn.Module) -> N ) +<<<<<<< HEAD if sys.version_info < (3, 12): QConfigAny = Optional[QConfig] QConfigAny.__module__ = "torch.ao.quantization.qconfig" @@ -575,6 +579,10 @@ def _assert_valid_qconfig(qconfig: Optional[QConfig], mod: torch.nn.Module) -> N from typing import TypeAliasType QConfigAny = TypeAliasType("QConfigAny", Optional[QConfig]) +======= +QConfigAny = Optional[QConfig] +QConfigAny.__module__ = "torch.ao.quantization.qconfig" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _add_module_to_qconfig_obs_ctr( diff --git a/torch/ao/quantization/quantize_fx.py b/torch/ao/quantization/quantize_fx.py index c59d35c573505..1cf91922353e2 100644 --- a/torch/ao/quantization/quantize_fx.py +++ b/torch/ao/quantization/quantize_fx.py @@ -185,7 +185,11 @@ def _prepare_standalone_module_fx( same as input_quantized_idxs configuration provided for the standalone module * `standalone_module_output_quantized_idxs(List[Int])`: a list of +<<<<<<< HEAD indices for the graph output that is quantized +======= + indexs for the graph output that is quantized +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) same as input_quantized_idxs configuration provided for the standalone module diff --git a/torch/ao/quantization/quantize_pt2e.py b/torch/ao/quantization/quantize_pt2e.py index 169e2905ddbdc..136b49cd5853b 100644 --- a/torch/ao/quantization/quantize_pt2e.py +++ b/torch/ao/quantization/quantize_pt2e.py @@ -76,7 +76,11 @@ def calibrate(model, data_loader): # Step 1. program capture # NOTE: this API will be updated to torch.export API in the future, but the captured +<<<<<<< HEAD # result should mostly stay the same +======= + # result shoud mostly stay the same +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) m = torch.export.export_for_training(m, *example_inputs).module() # we get a model with aten ops @@ -153,7 +157,11 @@ def train_loop(model, train_data): # Step 1. program capture # NOTE: this API will be updated to torch.export API in the future, but the captured +<<<<<<< HEAD # result should mostly stay the same +======= + # result shoud mostly stay the same +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) m = torch.export.export_for_training(m, *example_inputs).module() # we get a model with aten ops @@ -218,7 +226,11 @@ def convert_pt2e( Args: * `model` (torch.fx.GraphModule): calibrated/trained model +<<<<<<< HEAD * `use_reference_representation` (bool): boolean flag to indicate whether to produce reference representation or not +======= + * `use_reference_representation` (bool): boolean flag to indicate whether to produce referece representation or not +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) * `fold_quantize` (bool): boolean flag for whether fold the quantize op or not Returns: diff --git a/torch/ao/quantization/quantizer/quantizer.py b/torch/ao/quantization/quantizer/quantizer.py index 9884cb1990f07..545e6b44199a9 100644 --- a/torch/ao/quantization/quantizer/quantizer.py +++ b/torch/ao/quantization/quantizer/quantizer.py @@ -111,7 +111,11 @@ class DerivedQuantizationSpec(QuantizationSpecBase): @dataclass class QuantizationAnnotation: +<<<<<<< HEAD """How are input argument or output should be quantized, +======= + """How are input arguemnt or output should be quantized, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) expressed as QuantizationSpec, this corresponds to how a Tensor in the operator Graph is observed (PTQ) or fake quantized (QAT) """ diff --git a/torch/ao/quantization/quantizer/utils.py b/torch/ao/quantization/quantizer/utils.py index 04fefb7e463bc..e4013b0080623 100644 --- a/torch/ao/quantization/quantizer/utils.py +++ b/torch/ao/quantization/quantizer/utils.py @@ -28,7 +28,11 @@ def _node_only_used_for_sym_size(node: Node, partition_nodes: list[Node]): This utility is used to handle cases when dynami_shape=True tracing leads to symint nodes in the pattern of linear module. In those cases, we need to distinguish between the nodes that are in input for just extracting value of +<<<<<<< HEAD some dimensions (and symint nodes) vs. the one that is activation. +======= + some dimentions (and symint nodes) vs. the one that is activation. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) For example: graph(x, y, weight): size_0 = torch.ops.aten.sym_size([x], [0]) diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer.py b/torch/ao/quantization/quantizer/xnnpack_quantizer.py index 6005152a4d73f..edc674ccb8e94 100644 --- a/torch/ao/quantization/quantizer/xnnpack_quantizer.py +++ b/torch/ao/quantization/quantizer/xnnpack_quantizer.py @@ -245,7 +245,11 @@ def not_module_type_or_name_filter(n: Node) -> bool: class XNNPACKQuantizer(Quantizer): """ !!! DEPRECATED !!! +<<<<<<< HEAD XNNPACKQuantizer is a marked as deprecated. It will be removed in the future. +======= + XNNPACKQuantizer is a marked as deprected. It will be removed in the future. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) It has been moved to executorch.backends.xnnpack.quantizer.xnnpack_quantizer.XNNPACKQuantizer. Please use the new quantizer instead. """ diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py index f8ac0a7727de3..16999eb678b9d 100644 --- a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py +++ b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py @@ -422,7 +422,11 @@ def _annotate_conv_bn( filter_fn: Optional[Callable[[Node], bool]] = None, ) -> Optional[list[list[Node]]]: """ +<<<<<<< HEAD Find conv + batchnorm partitions +======= + Find conv + batchnorm parititions +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv. """ return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=False) @@ -435,7 +439,11 @@ def _annotate_conv_bn_relu( filter_fn: Optional[Callable[[Node], bool]] = None, ) -> Optional[list[list[Node]]]: """ +<<<<<<< HEAD Find conv + batchnorm + relu partitions +======= + Find conv + batchnorm + relu parititions +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv. """ return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=True) @@ -448,7 +456,11 @@ def _annotate_conv_transpose_bn( filter_fn: Optional[Callable[[Node], bool]] = None, ) -> Optional[list[list[Node]]]: """ +<<<<<<< HEAD Find conv_transpose + batchnorm partitions +======= + Find conv_transpose + batchnorm parititions +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv. """ return _do_annotate_conv_bn( @@ -463,7 +475,11 @@ def _annotate_conv_transpose_bn_relu( filter_fn: Optional[Callable[[Node], bool]] = None, ) -> Optional[list[list[Node]]]: """ +<<<<<<< HEAD Find conv_transpose + batchnorm + relu partitions +======= + Find conv_transpose + batchnorm + relu parititions +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv. """ return _do_annotate_conv_bn( diff --git a/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py b/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py index eff97dbcf27da..f707a83e9bd29 100644 --- a/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py @@ -85,7 +85,11 @@ def __init__(self) -> None: overrides. We keep the annotate methods but make the function body empty, aiming to let `_generate_qdq_quantized_model` generate qdq around op and graph execute on fp32 dtype for +<<<<<<< HEAD unsupported operators. +======= + unspported operators. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ def _annotate_qat_conv2d_fusion_pattern( diff --git a/torch/ao/quantization/utils.py b/torch/ao/quantization/utils.py index e93cd3fdb7cbd..3533f57557714 100644 --- a/torch/ao/quantization/utils.py +++ b/torch/ao/quantization/utils.py @@ -4,7 +4,10 @@ """ import functools +<<<<<<< HEAD import sys +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import warnings from collections import OrderedDict from inspect import getfullargspec, signature @@ -16,6 +19,7 @@ from torch.nn.utils.parametrize import is_parametrized +<<<<<<< HEAD if sys.version_info < (3, 12): NodePattern = Union[tuple[Node, Node], tuple[Node, tuple[Node, Node]], Any] NodePattern.__module__ = "torch.ao.quantization.utils" @@ -26,6 +30,10 @@ "NodePattern", Union[tuple[Node, Node], tuple[Node, tuple[Node, Node]], Any] ) +======= +NodePattern = Union[tuple[Node, Node], tuple[Node, tuple[Node, Node]], Any] +NodePattern.__module__ = "torch.ao.quantization.utils" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This is the Quantizer class instance from torch/quantization/fx/quantize.py. # Define separately to prevent circular imports. @@ -37,6 +45,7 @@ # Type for fusion patterns, it can be more complicated than the following actually, # see pattern.md for docs # TODO: not sure if typing supports recursive data types +<<<<<<< HEAD if sys.version_info < (3, 12): Pattern = Union[ @@ -58,6 +67,12 @@ Any, ], ) +======= +Pattern = Union[ + Callable, tuple[Callable, Callable], tuple[Callable, tuple[Callable, Callable]], Any +] +Pattern.__module__ = "torch.ao.quantization.utils" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO: maybe rename this to MatchInputNode @@ -642,7 +657,11 @@ def validate_qmin_qmax(quant_min: int, quant_max: int) -> None: # Functionally equivalent to '_calculate_qparams' in observer.py. Observers must be torchscriptable however and qscheme # as far as I can tell is not allowed to passed as a parameter in torchscript functions. This makes refactoring observer +<<<<<<< HEAD # to use this utility a massive pain and very gross. For now Im opting just to duplicate as this code seems unlikely to change +======= +# to use this utility a massive pain and very gross. For now Im opting just to duplicate as this code seems unlikey to change +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # (last update over 1 year ago) and when torchscript is fully deprecated we can refactor. TODO(jakeszwe, jerryzh168) def determine_qparams( min_val: torch.Tensor, diff --git a/torch/autograd/_functions/utils.py b/torch/autograd/_functions/utils.py index 1e74e21d3cef2..0b8ce3d55120d 100644 --- a/torch/autograd/_functions/utils.py +++ b/torch/autograd/_functions/utils.py @@ -1,4 +1,9 @@ # mypy: allow-untyped-defs +<<<<<<< HEAD +======= +import operator +from functools import reduce +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def maybe_view(tensor, size, check_same_size=True): @@ -24,3 +29,41 @@ def maybe_unexpand(tensor, old_size, check_same_size=True): for dim in expanded_dims: tensor = tensor.sum(dim, keepdim=True) return tensor +<<<<<<< HEAD +======= + + +# Check whether the op enable broadcasting, and whether it is supported by ONNX. +# If dims1 and dims2 are different, then broadcast is True. +# We always assume the combination of dims1 and dims2 is broadcastable. +# The following types of broadcasting are supported in ONNX: +# 1) Only one element in dims2, such as dims2 = [1, 1] +# 2) dims2 is suffix of dims1, such as dims1 = [2, 3, 4], and dims2 = [3, 4] +# Details can be found here: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Gemm +def check_onnx_broadcast(dims1, dims2): + broadcast = False + supported = True + len1 = len(dims1) + len2 = len(dims2) + + numel2 = reduce(operator.mul, dims2) + if len1 < len2: + broadcast = True + if numel2 != 1: + supported = False + elif len1 > len2: + broadcast = True + if numel2 != 1 and dims1[len1 - len2 :] != dims2: + supported = False + else: + if dims1 != dims2: + broadcast = True + if numel2 != 1: + supported = False + + if not supported: + raise ValueError( + f"Numpy style broadcasting is not supported in ONNX. Input dims are: {dims1}, {dims2}" + ) + return broadcast +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/autograd/anomaly_mode.py b/torch/autograd/anomaly_mode.py index 0277f1b75541f..d7884eb7b6af8 100644 --- a/torch/autograd/anomaly_mode.py +++ b/torch/autograd/anomaly_mode.py @@ -1,6 +1,9 @@ # mypy: allow-untyped-defs r"""Autograd anomaly mode.""" +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import warnings import torch @@ -32,7 +35,10 @@ class detect_anomaly: ... @staticmethod ... def forward(ctx, inp): ... return inp.clone() +<<<<<<< HEAD ... +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ... @staticmethod ... def backward(ctx, gO): ... # Error during the backward pass diff --git a/torch/autograd/function.py b/torch/autograd/function.py index ac3aad9f93b59..d235c253c6bf4 100644 --- a/torch/autograd/function.py +++ b/torch/autograd/function.py @@ -4,8 +4,13 @@ import itertools import warnings from collections import OrderedDict +<<<<<<< HEAD from typing import Any, Callable, Optional, TypeVar from typing_extensions import Concatenate, deprecated, ParamSpec +======= +from typing import Any, Optional +from typing_extensions import deprecated +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch._C as _C @@ -29,10 +34,13 @@ # This is incremented in FunctionMeta during class definition AUTOGRAD_FUNCTION_COUNTER = itertools.count() +<<<<<<< HEAD _T = TypeVar("_T") _R = TypeVar("_R") _P = ParamSpec("_P") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Formerly known as: _ContextMethodMixin class FunctionCtx: @@ -370,7 +378,10 @@ def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: def forward(*args: Any, **kwargs: Any) -> Any: pass +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @staticmethod def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None: pass @@ -599,6 +610,7 @@ def _is_setup_context_defined(fn): return fn != _SingleLevelFunction.setup_context +<<<<<<< HEAD def once_differentiable( fn: Callable[Concatenate[_T, _P], _R], ) -> Callable[Concatenate[_T, _P], _R]: @@ -606,6 +618,13 @@ def once_differentiable( def wrapper(ctx: _T, *args: _P.args, **kwargs: _P.kwargs) -> _R: with torch.no_grad(): outputs = fn(ctx, *args, **kwargs) +======= +def once_differentiable(fn): + @functools.wraps(fn) + def wrapper(ctx, *args): + with torch.no_grad(): + outputs = fn(ctx, *args) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not torch.is_grad_enabled(): return outputs @@ -626,14 +645,22 @@ def wrapper(ctx: _T, *args: _P.args, **kwargs: _P.kwargs) -> _R: return outputs if not isinstance(outputs, tuple): +<<<<<<< HEAD outputs_ = (outputs,) else: outputs_ = outputs +======= + outputs = (outputs,) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) err_fn = _functions.DelayedError( b"trying to differentiate twice a function that was marked " b"with @once_differentiable", +<<<<<<< HEAD len(outputs_), +======= + len(outputs), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # Create aliases of each output that has requires_grad=True. We need @@ -645,7 +672,11 @@ def fake_requires_grad(var): var.requires_grad = True return var +<<<<<<< HEAD return err_fn(*[fake_requires_grad(v) for v in outputs_]) # type: ignore[return-value] +======= + return err_fn(*[fake_requires_grad(v) for v in outputs]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return wrapper @@ -775,7 +806,10 @@ class NestedIOFunction(Function): This class is here only for backward compatibility reasons. Use :class:`Function` instead of this for any new use case. """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # The 'type: ignore' statements are needed here because these functions are declared as '@staticmethod' in the # superclass (Function) but are instance methods here, which mypy reports as incompatible. diff --git a/torch/autograd/grad_mode.py b/torch/autograd/grad_mode.py index e92f38b3af38b..47e92205dc3eb 100644 --- a/torch/autograd/grad_mode.py +++ b/torch/autograd/grad_mode.py @@ -210,6 +210,7 @@ def clone(self) -> "set_grad_enabled": class inference_mode(_DecoratorContextManager): +<<<<<<< HEAD r"""Context manager that enables or disables inference mode. InferenceMode is analogous to :class:`~no_grad` and should be used @@ -221,11 +222,24 @@ class inference_mode(_DecoratorContextManager): recorded by autograd. This context manager is thread-local; it does not affect computation +======= + r"""Context-manager that enables or disables inference mode. + + InferenceMode is a context manager analogous to :class:`~no_grad` + to be used when you are certain your operations will have no interactions + with autograd (e.g., model training). Code run under this mode gets better + performance by disabling view tracking and version counter bumps. Note that + unlike some other mechanisms that locally enable or disable grad, + entering inference_mode also disables to :ref:`forward-mode AD `. + + This context manager is thread local; it will not affect computation +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) in other threads. Also functions as a decorator. .. note:: +<<<<<<< HEAD Inference mode is one of several mechanisms that can locally enable or disable gradients. See :ref:`locally-disable-grad-doc` for a comparison. If avoiding the use of tensors created in inference mode @@ -241,6 +255,16 @@ class inference_mode(_DecoratorContextManager): mode (bool or function): Either a boolean flag to enable or disable inference mode, or a Python function to decorate with inference mode enabled. +======= + Inference mode is one of several mechanisms that can enable or + disable gradients locally see :ref:`locally-disable-grad-doc` for + more information on how they compare. + + Args: + mode (bool or function): Either a boolean flag whether to enable or + disable inference mode or a Python function to decorate with + inference mode enabled +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Example:: >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD) @@ -250,7 +274,11 @@ class inference_mode(_DecoratorContextManager): ... y = x * x >>> y.requires_grad False +<<<<<<< HEAD >>> # xdoctest: +SKIP("want string isn't quite right") +======= + >>> # xdoctest: +SKIP("want string isnt quite right") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> y._version Traceback (most recent call last): File "", line 1, in diff --git a/torch/autograd/gradcheck.py b/torch/autograd/gradcheck.py index 6dacdfe8b9462..6a17de72b3352 100644 --- a/torch/autograd/gradcheck.py +++ b/torch/autograd/gradcheck.py @@ -1947,7 +1947,11 @@ def _fast_gradcheck( # Note [VarArg of Tensors] # ~~~~~~~~~~~~~~~~~~~~~~~~ +<<<<<<< HEAD # 'func' accepts a vararg of tensors, which isn't expressible in the type system at the moment. +======= +# 'func' accepts a vararg of tensors, which isn't expressable in the type system at the moment. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # If https://mypy.readthedocs.io/en/latest/additional_features.html?highlight=callable#extended-callable-types is accepted, # the '...' first argument of Callable can be replaced with VarArg(Tensor). # For now, we permit any input. @@ -2036,6 +2040,7 @@ def gradcheck( ``True`` if all differences satisfy allclose condition """ +<<<<<<< HEAD assert check_forward_ad or check_backward_ad, ( "Expected at least one of check_forward_ad or check_backward_ad to be True" ) @@ -2045,6 +2050,17 @@ def gradcheck( assert not (check_batched_forward_grad and not check_forward_ad), ( "Setting check_batched_forward_grad=True requires check_forward_ad to be True" ) +======= + assert ( + check_forward_ad or check_backward_ad + ), "Expected at least one of check_forward_ad or check_backward_ad to be True" + assert not ( + check_batched_grad and not check_backward_ad + ), "Setting check_batched_grad=True requires check_backward_ad to be True" + assert not ( + check_batched_forward_grad and not check_forward_ad + ), "Setting check_batched_forward_grad=True requires check_forward_ad to be True" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) args = locals().copy() args.pop("raise_exception") if not raise_exception: @@ -2189,6 +2205,7 @@ def gradgradcheck( Returns: True if all differences satisfy allclose condition """ +<<<<<<< HEAD assert check_fwd_over_rev or check_rev_over_rev, ( "Expected at least one of check_fwd_over_rev or check_rev_over_rev to be True" ) @@ -2198,6 +2215,17 @@ def gradgradcheck( assert not (check_batched_grad and not check_rev_over_rev), ( "Setting check_batched_grad=True requires check_rev_over_rev to be True" ) +======= + assert ( + check_fwd_over_rev or check_rev_over_rev + ), "Expected at least one of check_fwd_over_rev or check_rev_over_rev to be True" + assert not ( + check_undefined_grad and not check_rev_over_rev + ), "Setting check_undefined_grad=True requires check_rev_over_rev to be True" + assert not ( + check_batched_grad and not check_rev_over_rev + ), "Setting check_batched_grad=True requires check_rev_over_rev to be True" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO: do we want to test this too? # assert not (check_batched_forward_grad and not check_fwd_over_rev), ( # "Setting check_batched_forward_grad=True requires check_fwd_over_rev to be True") diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py index 4b2707b65d0f1..a48ee1a794e9e 100644 --- a/torch/autograd/graph.py +++ b/torch/autograd/graph.py @@ -194,9 +194,12 @@ class GradientEdge(NamedTuple): node: Node output_nr: int +<<<<<<< HEAD # This token can be used to ensure the graph stays alive when it cannot be # done via the node field ownership_token: Optional[Node] = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_gradient_edge(tensor: torch.Tensor) -> GradientEdge: @@ -212,6 +215,7 @@ def get_gradient_edge(tensor: torch.Tensor) -> GradientEdge: ) grad_fn = _get_grad_fn_or_grad_acc(tensor) +<<<<<<< HEAD # Python-based Node are owned by the C++ side meaning the python grad_fn # object we hold here does NOT keep the C++ graph alive. # Create an ownership token by creating a new C++ node that own the graph @@ -224,6 +228,11 @@ def get_gradient_edge(tensor: torch.Tensor) -> GradientEdge: # Note that output_nr default to 0 which is the right value # for the AccumulateGrad node. return GradientEdge(grad_fn, tensor.output_nr, ownership_token=token) +======= + # Note that output_nr default to 0 which is the right value + # for the AccumulateGrad node. + return GradientEdge(grad_fn, tensor.output_nr) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def increment_version(tensor: Union[torch.Tensor, Iterable[torch.Tensor]]) -> None: @@ -253,7 +262,11 @@ class saved_tensors_hooks: Use this context-manager to define how intermediary results of an operation should be packed before saving, and unpacked on retrieval. +<<<<<<< HEAD In that context, the ``pack_hook`` function will be called every time an +======= + In that context, the ``pack_hook`` function will be called everytime an +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) operation saves a tensor for backward (this includes intermediary results saved using :func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but @@ -521,9 +534,15 @@ def get_inner_hook(idx: int) -> Callable[[torch.Tensor], None]: def inner_hook(grad: torch.Tensor) -> None: nonlocal count, nb_calls, buffer, fn id = torch._C._current_graph_task_id() +<<<<<<< HEAD assert id != -1, ( "expected this hook to be called inside a backward call" ) +======= + assert ( + id != -1 + ), "expected this hook to be called inside a backward call" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) count[id] = count.get(id, 0) buffer[id] = buffer.get(id, [None] * len_tensors) @@ -732,9 +751,15 @@ def clear(self) -> None: @contextlib.contextmanager +<<<<<<< HEAD def allow_mutation_on_saved_tensors() -> Generator[ _AllowMutationOnSavedContext, None, None ]: +======= +def allow_mutation_on_saved_tensors() -> ( + Generator[_AllowMutationOnSavedContext, None, None] +): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Context manager under which mutating tensors saved for backward is allowed. Under this context manager, tensors saved for backward are cloned on mutation, diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py index c1ae4d8561fdb..eca885c900fb9 100644 --- a/torch/autograd/profiler.py +++ b/torch/autograd/profiler.py @@ -95,7 +95,10 @@ def _run_on_profiler_stop(): @dataclass class _ProfilerStats: "Profiler timing and stats used by developers to catch issues/regressions" +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) profiling_window_duration_sec: float = 0 number_of_events: int = 0 profiler_prepare_call_duration_us: int = 0 @@ -108,9 +111,12 @@ class _ProfilerStats: class profile: """Context manager that manages autograd profiler state and holds a summary of results. +<<<<<<< HEAD .. note:: This is the backend, most people should use :mod:`torch.profiler` instead. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Under the hood it just records events of functions being executed in C++ and exposes those events to Python. You can wrap any code into it and it will only report runtime of PyTorch functions. @@ -255,9 +261,15 @@ def __init__( self.custom_trace_id_callback = custom_trace_id_callback self.trace_id = "" if not self.use_cpu: +<<<<<<< HEAD assert use_kineto, ( "Device-only events supported only with Kineto (use_kineto=True)" ) +======= + assert ( + use_kineto + ), "Device-only events supported only with Kineto (use_kineto=True)" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.use_device is not None: VALID_DEVICE_OPTIONS = ["cuda", "xpu", "mtia", "hpu"] @@ -294,6 +306,7 @@ def __init__( else: self.kineto_activities.add(ProfilerActivity.CUDA) elif self.use_device == "xpu": +<<<<<<< HEAD assert use_kineto and ProfilerActivity.XPU in _supported_activities(), ( "Legacy XPU profiling is not supported. Requires use_kineto=True on XPU devices." ) @@ -307,22 +320,49 @@ def __init__( assert use_kineto and ProfilerActivity.HPU in _supported_activities(), ( "Legacy HPU profiling is not supported. Requires use_kineto=True on HPU devices." ) +======= + assert ( + use_kineto and ProfilerActivity.XPU in _supported_activities() + ), "Legacy XPU profiling is not supported. Requires use_kineto=True on XPU devices." + self.kineto_activities.add(ProfilerActivity.XPU) + elif self.use_device == "mtia": + assert ( + use_kineto and ProfilerActivity.MTIA in _supported_activities() + ), "Legacy MTIA profiling is not supported. Requires use_kineto=True on MTIA devices." + self.kineto_activities.add(ProfilerActivity.MTIA) + elif self.use_device == "hpu": + assert ( + use_kineto and ProfilerActivity.HPU in _supported_activities() + ), "Legacy HPU profiling is not supported. Requires use_kineto=True on HPU devices." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.kineto_activities.add(ProfilerActivity.HPU) elif self.use_device is not None and self.use_device != "privateuseone": if ( not use_kineto or ProfilerActivity.PrivateUse1 not in _supported_activities() ): +<<<<<<< HEAD assert self.use_cpu, ( "Legacy custombackend profiling requires use_cpu=True" ) +======= + assert ( + self.use_cpu + ), "Legacy custombackend profiling requires use_cpu=True" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.profiler_kind = ProfilerState.KINETO_PRIVATEUSE1_FALLBACK else: self.kineto_activities.add(ProfilerActivity.PrivateUse1) +<<<<<<< HEAD assert len(self.kineto_activities) > 0, ( "No activities specified for the profiler" ) +======= + assert ( + len(self.kineto_activities) > 0 + ), "No activities specified for the profiler" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def default_trace_id(self): # Generate a UUID @@ -403,7 +443,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): ) # If we plan to accumulate events we should post process the function events +<<<<<<< HEAD # right away to retain the state across multiple start/stop calls +======= + # right away to retain the state across mulitple start/stop calls +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.acc_events: self._ensure_function_events() return False @@ -586,10 +630,14 @@ def _device_memory_usage(mem_record): device_corr_map: dict[int, list[FunctionEvent]] = {} max_evt_id = 0 for kineto_event in result.events(): +<<<<<<< HEAD if ( _filter_name(kineto_event.name()) or getattr(kineto_event, "is_hidden_event", lambda: False)() ): +======= + if _filter_name(kineto_event.name()): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue rel_start_ns = kineto_event.start_ns() - trace_start_ns rel_end_ns = kineto_event.end_ns() - trace_start_ns @@ -745,12 +793,20 @@ class record_function(_ContextDecorator): >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER) >>> x = torch.randn((1, 1), requires_grad=True) >>> with torch.autograd.profiler.profile() as prof: +<<<<<<< HEAD ... y = x**2 ... with torch.autograd.profiler.record_function( ... "label-z" ... ): # label the block ... z = y**3 ... y.backward() +======= + ... y = x ** 2 + ... with torch.autograd.profiler.record_function("label-z"): # label the block + ... z = y ** 3 + ... y.backward() + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> # xdoctest: +IGNORE_WANT >>> # NOTE: some columns were removed for brevity >>> print(prof.key_averages().table(sort_by="self_cpu_time_total")) @@ -857,7 +913,11 @@ class emit_itt: The Instrumentation and Tracing Technology (ITT) API enables your application to generate and control the collection of trace data during its execution across different Intel tools. This context manager is to annotate Intel(R) VTune Profiling trace. With help of this context manager, +<<<<<<< HEAD you will be able to see labeled ranges in Intel(R) VTune Profiler GUI. +======= + you will be able to see labled ranges in Intel(R) VTune Profiler GUI. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .. warning: This context manager should not be called recursively, i.e. at most one diff --git a/torch/autograd/profiler_util.py b/torch/autograd/profiler_util.py index b789aab11c663..7182f8e56747f 100644 --- a/torch/autograd/profiler_util.py +++ b/torch/autograd/profiler_util.py @@ -126,9 +126,15 @@ def _populate_cpu_children(self): current_events.pop() else: parent.append_cpu_child(event) +<<<<<<< HEAD assert event.cpu_parent is None, ( f"There is already a CPU parent event for {event.key}" ) +======= + assert ( + event.cpu_parent is None + ), f"There is already a CPU parent event for {event.key}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) event.set_cpu_parent(parent) break @@ -173,7 +179,10 @@ def table( max_shapes_column_width=80, header=None, top_level_events_only=False, +<<<<<<< HEAD time_unit=None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): """Print an EventList as a nicely formatted table. @@ -190,8 +199,11 @@ def table( display events at top level like top-level invocation of python `lstm`, python `add` or other functions, nested events like low-level cpu/cuda/xpu ops events are omitted for profiler result readability. +<<<<<<< HEAD time_unit(str, optional): A time unit to be used for all values in the table. Valid options are: ``s``, ``ms`` and ``us``. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Returns: A string containing the table. @@ -207,7 +219,10 @@ def table( profile_memory=self._profile_memory, with_flops=self._with_flops, top_level_events_only=top_level_events_only, +<<<<<<< HEAD time_unit=time_unit, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def export_chrome_trace(self, path): @@ -836,7 +851,10 @@ def _build_table( with_flops=False, profile_memory=False, top_level_events_only=False, +<<<<<<< HEAD time_unit=None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): """Print a summary of events (which can be a list of FunctionEvent or FunctionEventAvg).""" if len(events) == 0: @@ -1044,6 +1062,7 @@ def trim_path(path, src_column_width): path = "..." + path[3:] return path +<<<<<<< HEAD def override_time_unit(time_us, default_str, time_unit): US_IN_SECOND = 1000.0 * 1000.0 US_IN_MS = 1000.0 @@ -1056,6 +1075,8 @@ def override_time_unit(time_us, default_str, time_unit): else: return default_str +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) event_limit = 0 for evt in events: if event_limit == row_limit: @@ -1089,6 +1110,7 @@ def override_time_unit(time_us, default_str, time_unit): row_values += [ # Self CPU total %, 0 for async events. evt.self_cpu_percent, +<<<<<<< HEAD override_time_unit( evt.self_cpu_time_total, evt.self_cpu_time_total_str, time_unit ), # Self CPU total @@ -1100,6 +1122,13 @@ def override_time_unit(time_us, default_str, time_unit): override_time_unit( evt.cpu_time, evt.cpu_time_str, time_unit ), # CPU time avg +======= + evt.self_cpu_time_total_str, # Self CPU total + # CPU total %, 0 for async events. + evt.total_cpu_percent, + evt.cpu_time_total_str, # CPU total + evt.cpu_time_str, # CPU time avg +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] if has_device_time: evt.total_device_percent = _format_time_share( @@ -1107,6 +1136,7 @@ def override_time_unit(time_us, default_str, time_unit): ) row_values.extend( [ +<<<<<<< HEAD override_time_unit( evt.self_device_time_total, evt.self_device_time_total_str, @@ -1120,6 +1150,13 @@ def override_time_unit(time_us, default_str, time_unit): override_time_unit( evt.device_time, evt.device_time_str, time_unit ), # device time avg +======= + evt.self_device_time_total_str, + # device time total % + evt.total_device_percent, + evt.device_time_total_str, + evt.device_time_str, # device time avg +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] ) if profile_memory: @@ -1172,6 +1209,7 @@ def override_time_unit(time_us, default_str, time_unit): append(row_format.format(*empty_headers)) append(header_sep) +<<<<<<< HEAD append( f"Self CPU time total: {override_time_unit(sum_self_cpu_time_total, _format_time(sum_self_cpu_time_total), time_unit)}" ) @@ -1179,5 +1217,12 @@ def override_time_unit(time_us, default_str, time_unit): append( f"Self {use_device.upper() if use_device is not None else 'None'} " f"time total: {override_time_unit(sum_self_device_time_total, _format_time(sum_self_device_time_total), time_unit)}" +======= + append(f"Self CPU time total: {_format_time(sum_self_cpu_time_total)}") + if has_device_time: + append( + f"Self {use_device.upper() if use_device is not None else 'None'} " + f"time total: {_format_time(sum_self_device_time_total)}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return "".join(result) diff --git a/torch/backends/__init__.py b/torch/backends/__init__.py index c02a8c36fd08b..ddfa2af21e95e 100644 --- a/torch/backends/__init__.py +++ b/torch/backends/__init__.py @@ -1,10 +1,16 @@ # mypy: allow-untyped-defs +<<<<<<< HEAD import sys import types from contextlib import contextmanager import torch +======= +import types +from contextlib import contextmanager + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # The idea for this parameter is that we forbid bare assignment # to torch.backends..enabled and friends when running our @@ -60,6 +66,7 @@ def __getattr__(self, attr): return self.m.__getattribute__(attr) +<<<<<<< HEAD class _FP32Precision: def __init__(self, backend, op): self.backend = backend @@ -124,6 +131,8 @@ def __init__(self, m, name): sys.modules[__name__] = GenericModule(sys.modules[__name__], __name__) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.backends import ( cpu as cpu, cuda as cuda, @@ -131,12 +140,18 @@ def __init__(self, m, name): cusparselt as cusparselt, kleidiai as kleidiai, mha as mha, +<<<<<<< HEAD miopen as miopen, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mkl as mkl, mkldnn as mkldnn, mps as mps, nnpack as nnpack, openmp as openmp, +<<<<<<< HEAD opt_einsum as opt_einsum, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) quantized as quantized, ) diff --git a/torch/backends/_nnapi/serializer.py b/torch/backends/_nnapi/serializer.py index 5c5d1a1885f31..9506c1b12a7e7 100644 --- a/torch/backends/_nnapi/serializer.py +++ b/torch/backends/_nnapi/serializer.py @@ -201,7 +201,11 @@ class DimOrder(enum.Enum): class Operand(NamedTuple): +<<<<<<< HEAD """Representation of an NNAPI operand.""" +======= + """Represenation of an NNAPI operand.""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # NNAPI operand type. One of NNAPI_OperandCode. # TODO: Make this an enum. diff --git a/torch/backends/cuda/__init__.py b/torch/backends/cuda/__init__.py index 87327428461a2..124e92a7df174 100644 --- a/torch/backends/cuda/__init__.py +++ b/torch/backends/cuda/__init__.py @@ -135,8 +135,11 @@ def __getattr__(self, name): return torch._C._get_cublas_allow_bf16_reduced_precision_reduction() elif name == "allow_fp16_accumulation": return torch._C._get_cublas_allow_fp16_accumulation() +<<<<<<< HEAD elif name == "fp32_precision": return torch._C._get_fp32_precision_getter("cuda", "matmul") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise AttributeError("Unknown attribute " + name) def __setattr__(self, name, value): @@ -148,8 +151,11 @@ def __setattr__(self, name, value): return torch._C._set_cublas_allow_bf16_reduced_precision_reduction(value) elif name == "allow_fp16_accumulation": return torch._C._set_cublas_allow_fp16_accumulation(value) +<<<<<<< HEAD elif name == "fp32_precision": return torch._C._set_fp32_precision_setter("cuda", "matmul", value) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise AttributeError("Unknown attribute " + name) @@ -162,7 +168,11 @@ def __setattr__(self, name, value): def preferred_linalg_library( +<<<<<<< HEAD backend: Union[None, str, torch._C._LinalgBackend] = None, +======= + backend: Union[None, str, torch._C._LinalgBackend] = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> torch._C._LinalgBackend: r""" Override the heuristic PyTorch uses to choose between cuSOLVER and MAGMA for CUDA linear algebra operations. @@ -210,7 +220,11 @@ def preferred_linalg_library( elif isinstance(backend, str): if backend not in _LinalgBackends: raise RuntimeError( +<<<<<<< HEAD f"Unknown input value. Choose from: {_LinalgBackends_str}." +======= + "Unknown input value. " f"Choose from: {_LinalgBackends_str}." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) torch._C._set_linalg_preferred_backend(_LinalgBackends[backend]) elif isinstance(backend, torch._C._LinalgBackend): @@ -233,7 +247,11 @@ def preferred_linalg_library( def preferred_blas_library( +<<<<<<< HEAD backend: Union[None, str, torch._C._BlasBackend] = None, +======= + backend: Union[None, str, torch._C._BlasBackend] = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> torch._C._BlasBackend: r""" Override the library PyTorch uses for BLAS operations. Choose between cuBLAS, cuBLASLt, and CK [ROCm-only]. @@ -265,7 +283,11 @@ def preferred_blas_library( elif isinstance(backend, str): if backend not in _BlasBackends: raise RuntimeError( +<<<<<<< HEAD f"Unknown input value. Choose from: {_BlasBackends_str}." +======= + "Unknown input value. " f"Choose from: {_BlasBackends_str}." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) torch._C._set_blas_preferred_backend(_BlasBackends[backend]) elif isinstance(backend, torch._C._BlasBackend): @@ -288,13 +310,21 @@ def preferred_blas_library( def preferred_rocm_fa_library( +<<<<<<< HEAD backend: Union[None, str, torch._C._ROCmFABackend] = None, +======= + backend: Union[None, str, torch._C._ROCmFABackend] = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> torch._C._ROCmFABackend: r""" [ROCm-only] Override the backend PyTorch uses in ROCm environments for Flash Attention. Choose between AOTriton and CK +<<<<<<< HEAD .. warning:: This flag is experimental and subject to change. +======= + .. warning:: This flag is experimeental and subject to change. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) When Flash Attention is enabled and desired, PyTorch defaults to using AOTriton as the backend. This flag (a :class:`str`) allows users to override this backend to use composable_kernel @@ -316,13 +346,21 @@ def preferred_rocm_fa_library( elif isinstance(backend, str): if backend not in _ROCmFABackends: raise RuntimeError( +<<<<<<< HEAD f"Unknown input value. Choose from: {_ROCmFABackends_str}." +======= + "Unknown input value. " f"Choose from: {_ROCmFABackends_str}." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) torch._C._set_rocm_fa_preferred_backend(_ROCmFABackends[backend]) elif isinstance(backend, torch._C._ROCmFABackend): torch._C._set_rocm_fa_preferred_backend(backend) else: +<<<<<<< HEAD raise ValueError(f"Unknown input value. Choose from: {_ROCmFABackends_str}.") +======= + raise ValueError("Unknown input value. " f"Choose from: {_ROCmFABackends_str}.") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return torch._C._get_rocm_fa_preferred_backend() diff --git a/torch/backends/cudnn/__init__.py b/torch/backends/cudnn/__init__.py index 9c155de7c04b0..df2af4603424d 100644 --- a/torch/backends/cudnn/__init__.py +++ b/torch/backends/cudnn/__init__.py @@ -6,6 +6,7 @@ from typing import Optional import torch +<<<<<<< HEAD from torch.backends import ( __allow_nonbracketed_mutation, _FP32Precision, @@ -14,6 +15,9 @@ ContextProp, PropModule, ) +======= +from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: @@ -135,7 +139,10 @@ def set_flags( _benchmark_limit=None, _deterministic=None, _allow_tf32=None, +<<<<<<< HEAD _fp32_precision="none", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): orig_flags = ( torch._C._get_cudnn_enabled(), @@ -143,7 +150,10 @@ def set_flags( None if not is_available() else torch._C._cuda_get_cudnn_benchmark_limit(), torch._C._get_cudnn_deterministic(), torch._C._get_cudnn_allow_tf32(), +<<<<<<< HEAD torch._C._get_fp32_precision_getter("cuda", "all"), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if _enabled is not None: torch._C._set_cudnn_enabled(_enabled) @@ -155,8 +165,11 @@ def set_flags( torch._C._set_cudnn_deterministic(_deterministic) if _allow_tf32 is not None: torch._C._set_cudnn_allow_tf32(_allow_tf32) +<<<<<<< HEAD if _fp32_precision is not None: torch._C._set_fp32_precision_setter("cuda", "all", _fp32_precision) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return orig_flags @@ -167,6 +180,7 @@ def flags( benchmark_limit=10, deterministic=False, allow_tf32=True, +<<<<<<< HEAD fp32_precision="none", ): with __allow_nonbracketed_mutation(): @@ -177,6 +191,12 @@ def flags( deterministic, allow_tf32, fp32_precision, +======= +): + with __allow_nonbracketed_mutation(): + orig_flags = set_flags( + enabled, benchmark, benchmark_limit, deterministic, allow_tf32 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) try: yield @@ -211,12 +231,15 @@ def __init__(self, m, name): allow_tf32 = ContextProp( torch._C._get_cudnn_allow_tf32, torch._C._set_cudnn_allow_tf32 ) +<<<<<<< HEAD conv = _FP32Precision("cuda", "conv") rnn = _FP32Precision("cuda", "rnn") fp32_precision = ContextProp( _get_fp32_precision_getter("cuda", "all"), _set_fp32_precision_setter("cuda", "all"), ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This is the sys.modules replacement trick, see diff --git a/torch/backends/cusparselt/__init__.py b/torch/backends/cusparselt/__init__.py index 9d3d9a8a01636..92e39d670e5ec 100644 --- a/torch/backends/cusparselt/__init__.py +++ b/torch/backends/cusparselt/__init__.py @@ -1,3 +1,7 @@ +<<<<<<< HEAD +======= +# mypy: allow-untyped-defs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from typing import Optional import torch @@ -19,7 +23,11 @@ if _cusparselt is not None: +<<<<<<< HEAD def _init() -> bool: +======= + def _init(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) global __cusparselt_version global __MAX_ALG_ID if __cusparselt_version is None: @@ -34,7 +42,11 @@ def _init() -> bool: else: +<<<<<<< HEAD def _init() -> bool: +======= + def _init(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return False diff --git a/torch/backends/mkl/__init__.py b/torch/backends/mkl/__init__.py index ae16922761afe..49db4bd485954 100644 --- a/torch/backends/mkl/__init__.py +++ b/torch/backends/mkl/__init__.py @@ -30,7 +30,10 @@ class verbose: .. code-block:: python import torch +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) model(data) with torch.backends.mkl.verbose(torch.backends.mkl.VERBOSE_ON): model(data) @@ -48,9 +51,15 @@ def __enter__(self): if self.enable == VERBOSE_OFF: return st = torch._C._verbose.mkl_set_verbose(self.enable) +<<<<<<< HEAD assert st, ( "Failed to set MKL into verbose mode. Please consider to disable this verbose scope." ) +======= + assert ( + st + ), "Failed to set MKL into verbose mode. Please consider to disable this verbose scope." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self def __exit__(self, exc_type, exc_val, exc_tb): diff --git a/torch/backends/mkldnn/__init__.py b/torch/backends/mkldnn/__init__.py index ae76a9f20c46f..e46dd55ac49d4 100644 --- a/torch/backends/mkldnn/__init__.py +++ b/torch/backends/mkldnn/__init__.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING import torch +<<<<<<< HEAD from torch.backends import ( __allow_nonbracketed_mutation, _FP32Precision, @@ -12,6 +13,9 @@ ContextProp, PropModule, ) +======= +from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def is_available(): @@ -43,7 +47,10 @@ class verbose: .. code-block:: python import torch +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) model(data) with torch.backends.mkldnn.verbose(torch.backends.mkldnn.VERBOSE_ON): model(data) @@ -62,9 +69,15 @@ def __enter__(self): if self.level == VERBOSE_OFF: return st = torch._C._verbose.mkldnn_set_verbose(self.level) +<<<<<<< HEAD assert st, ( "Failed to set MKLDNN into verbose mode. Please consider to disable this verbose scope." ) +======= + assert ( + st + ), "Failed to set MKLDNN into verbose mode. Please consider to disable this verbose scope." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -72,14 +85,21 @@ def __exit__(self, exc_type, exc_val, exc_tb): return False +<<<<<<< HEAD def set_flags( _enabled=None, _deterministic=None, _allow_tf32=None, _fp32_precision="none" ): +======= +def set_flags(_enabled=None, _deterministic=None, _allow_tf32=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) orig_flags = ( torch._C._get_mkldnn_enabled(), torch._C._get_mkldnn_deterministic(), torch._C._get_onednn_allow_tf32(), +<<<<<<< HEAD torch._C._get_fp32_precision_getter("mkldnn", "all"), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if _enabled is not None: torch._C._set_mkldnn_enabled(_enabled) @@ -87,15 +107,24 @@ def set_flags( torch._C._set_mkldnn_deterministic(_deterministic) if _allow_tf32 is not None: torch._C._set_onednn_allow_tf32(_allow_tf32) +<<<<<<< HEAD if _fp32_precision is not None: torch._C._set_fp32_precision_setter("mkldnn", "all", _fp32_precision) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return orig_flags @contextmanager +<<<<<<< HEAD def flags(enabled=False, deterministic=False, allow_tf32=True, fp32_precision="none"): with __allow_nonbracketed_mutation(): orig_flags = set_flags(enabled, deterministic, allow_tf32, fp32_precision) +======= +def flags(enabled=False, deterministic=False, allow_tf32=True): + with __allow_nonbracketed_mutation(): + orig_flags = set_flags(enabled, deterministic, allow_tf32) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: yield finally: @@ -117,6 +146,7 @@ def is_available(self): allow_tf32 = ContextProp( torch._C._get_onednn_allow_tf32, torch._C._set_onednn_allow_tf32 ) +<<<<<<< HEAD matmul = _FP32Precision("mkldnn", "matmul") conv = _FP32Precision("mkldnn", "conv") rnn = _FP32Precision("mkldnn", "rnn") @@ -124,6 +154,8 @@ def is_available(self): _get_fp32_precision_getter("mkldnn", "all"), _set_fp32_precision_setter("generic", "all"), ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if TYPE_CHECKING: diff --git a/torch/backends/mps/__init__.py b/torch/backends/mps/__init__.py index 5c3c507428cff..c267eaf388d0d 100644 --- a/torch/backends/mps/__init__.py +++ b/torch/backends/mps/__init__.py @@ -1,3 +1,7 @@ +<<<<<<< HEAD +======= +# mypy: allow-untyped-defs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from functools import lru_cache as _lru_cache from typing import Optional, TYPE_CHECKING @@ -5,6 +9,7 @@ from torch.library import Library as _Library +<<<<<<< HEAD __all__ = [ "get_core_count", "get_name", @@ -13,6 +18,9 @@ "is_macos13_or_newer", "is_macos_or_newer", ] +======= +__all__ = ["is_built", "is_available", "is_macos13_or_newer", "is_macos_or_newer"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def is_built() -> bool: @@ -43,6 +51,7 @@ def is_macos13_or_newer(minor: int = 0) -> bool: return torch._C._mps_is_on_macos_or_newer(13, minor) +<<<<<<< HEAD @_lru_cache def get_name() -> str: r"""Return Metal device name""" @@ -64,6 +73,12 @@ def get_core_count() -> int: def _init() -> None: +======= +_lib: Optional[_Library] = None + + +def _init(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r"""Register prims as implementation of var_mean and group_norm.""" global _lib diff --git a/torch/backends/xeon/run_cpu.py b/torch/backends/xeon/run_cpu.py index fe263858abb74..9d1721a1057bf 100644 --- a/torch/backends/xeon/run_cpu.py +++ b/torch/backends/xeon/run_cpu.py @@ -119,7 +119,11 @@ Memory allocator ---------------- +<<<<<<< HEAD "--enable-tcmalloc" and "--enable-jemalloc" can be used to enable different memory allocator. +======= +"--enable-tcmalloc" and "--enable-jemalloc" can be used to enable different memory allcator. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ @@ -262,11 +266,17 @@ def numa_aware_check(self, core_list): class _Launcher: r"""Class for launcher.""" +<<<<<<< HEAD msg_lib_notfound = ( f"Unable to find the {{0}} library file lib{{1}}.so in $CONDA_PREFIX/lib or $VIRTUAL_ENV/lib \ or /.local/lib/ or /usr/local/lib/ or /usr/local/lib64/ or /usr/lib or /usr/lib64 or \ {expanduser('~')}/.local/lib/ so the LD_PRELOAD environment variable will not be set." ) +======= + msg_lib_notfound = f"Unable to find the {{0}} library file lib{{1}}.so in $CONDA_PREFIX/lib or $VIRTUAL_ENV/lib \ +or /.local/lib/ or /usr/local/lib/ or /usr/local/lib64/ or /usr/lib or /usr/lib64 or \ +{expanduser('~')}/.local/lib/ so the LD_PRELOAD environment variable will not be set." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __init__(self) -> None: self.cpuinfo = _CPUinfo() @@ -613,12 +623,22 @@ def launch(self, args): args.rank == -1 ): # sequentially assign ncores_per_instance to ninstances core_list = cores[ +<<<<<<< HEAD i * args.ncores_per_instance : (i + 1) +======= + i + * args.ncores_per_instance : (i + 1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) * args.ncores_per_instance ] else: # assign ncores_per_instance from rank core_list = cores[ +<<<<<<< HEAD args.rank * args.ncores_per_instance : (args.rank + 1) +======= + args.rank + * args.ncores_per_instance : (args.rank + 1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) * args.ncores_per_instance ] @@ -626,9 +646,15 @@ def launch(self, args): if local_size > 1: total_num_cores = len(core_list) cores_per_rank = total_num_cores // local_size +<<<<<<< HEAD assert cores_per_rank >= 1, ( "At least one core needs to be assigned to each rank" ) +======= + assert ( + cores_per_rank >= 1 + ), "At least one core needs to be assigned to each rank" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) core_list = core_list[ cores_per_rank * local_rank : cores_per_rank * (local_rank + 1) ] diff --git a/torch/compiler/__init__.py b/torch/compiler/__init__.py index 08ec23b748eb5..9379010c85cb8 100644 --- a/torch/compiler/__init__.py +++ b/torch/compiler/__init__.py @@ -1,6 +1,10 @@ # mypy: allow-untyped-defs +<<<<<<< HEAD import io from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union +======= +from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from typing_extensions import ParamSpec import torch @@ -24,7 +28,10 @@ "set_stance", "set_enable_guard_collectives", "cudagraph_mark_step_begin", +<<<<<<< HEAD "load_compiled_function", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "wrap_numpy", "is_compiling", "is_dynamo_compiling", @@ -41,8 +48,11 @@ _P = ParamSpec("_P") _R = TypeVar("_R") +<<<<<<< HEAD FuncType = Callable[..., Any] F = TypeVar("F", bound=FuncType) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def compile(*args, **kwargs): @@ -127,7 +137,10 @@ def allow_in_graph(fn): torch.compiler.allow_in_graph(my_custom_function) +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch.compile(...) def fn(x): x = torch.add(x, 1) @@ -135,7 +148,10 @@ def fn(x): x = torch.add(x, 1) return x +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fn(...) Will capture a single graph containing ``my_custom_function()``. @@ -256,10 +272,14 @@ def disable(fn=None, recursive=True, *, reason=None): def set_stance( +<<<<<<< HEAD stance: str = "default", *, skip_guard_eval_unsafe: bool = False, force_backend: Union[str, Callable[..., Any], None] = None, +======= + stance: str = "default", *, skip_guard_eval_unsafe=False, force_backend=None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): """ Set the current stance of the compiler. @@ -269,15 +289,23 @@ def set_stance( .. code-block:: python @torch.compile +<<<<<<< HEAD def foo(x): ... +======= + def foo(x): + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch.compiler.set_stance("force_eager") def bar(): # will not be compiled foo(...) +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bar() with torch.compiler.set_stance("force_eager"): @@ -362,7 +390,11 @@ def set_enable_guard_collectives(enabled: bool): from torch._dynamo.eval_frame import guard_collectives_hook if enabled: +<<<<<<< HEAD return set_guard_complete_hook(guard_collectives_hook) is not None # type: ignore[arg-type] +======= + return set_guard_complete_hook(guard_collectives_hook) is not None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: return set_guard_complete_hook(None) is not None @@ -385,7 +417,10 @@ def cudagraph_mark_step_begin(): def rand_foo(): return torch.rand([4], device="cuda") +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for _ in range(5): torch.compiler.cudagraph_mark_step_begin() rand_foo() + rand_foo() @@ -641,6 +676,7 @@ def nested_compile_region(fn=None): ) return _mark_compile_region(fn) +<<<<<<< HEAD def load_compiled_function(file: io.IOBase) -> Callable[..., Any]: @@ -661,3 +697,5 @@ def load_compiled_function(file: io.IOBase) -> Callable[..., Any]: data = file.read() return AOTCompiledFunction.deserialize(data) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/compiler/_cache.py b/torch/compiler/_cache.py index 054ab1bb9fb2c..6c6e6596cc5a4 100644 --- a/torch/compiler/_cache.py +++ b/torch/compiler/_cache.py @@ -72,9 +72,15 @@ class CacheArtifactFactory: @classmethod def register(cls, artifact_cls: type[CacheArtifact]) -> type[CacheArtifact]: artifact_type_key = artifact_cls.type() +<<<<<<< HEAD assert artifact_cls.type() not in cls._artifact_types, ( f"Artifact of type={artifact_type_key} already registered in mega-cache artifact factory" ) +======= + assert ( + artifact_cls.type() not in cls._artifact_types + ), f"Artifact of type={artifact_type_key} already registered in mega-cache artifact factory" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cls._artifact_types[artifact_type_key] = artifact_cls setattr( CacheInfo, @@ -85,9 +91,15 @@ def register(cls, artifact_cls: type[CacheArtifact]) -> type[CacheArtifact]: @classmethod def _get_artifact_type(cls, artifact_type_key: str) -> type[CacheArtifact]: +<<<<<<< HEAD assert artifact_type_key in cls._artifact_types, ( f"Artifact of type={artifact_type_key} not registered in mega-cache artifact factory" ) +======= + assert ( + artifact_type_key in cls._artifact_types + ), f"Artifact of type={artifact_type_key} not registered in mega-cache artifact factory" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return cls._artifact_types[artifact_type_key] @classmethod @@ -135,10 +147,13 @@ def pgo_artifacts(self) -> list[str]: # type: ignore[empty-body] def precompile_aot_autograd_artifacts(self) -> list[str]: # type: ignore[empty-body] ... +<<<<<<< HEAD @property def precompile_dynamo_artifacts(self) -> list[str]: # type: ignore[empty-body] ... +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def add(self, artifact: CacheArtifact) -> None: self.artifacts[artifact.type()].append(artifact.key) @@ -186,21 +201,35 @@ class CacheArtifactManager: - Call CacheArtifactManager.deserialize to hot load the cache artifacts on a potentially different process +<<<<<<< HEAD NOTE: There's no FB/FC guarantees, results of cache artifacts will not be +======= + NOTE: There's no FB/FC guarentees, results of cache artifacts will not be +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) used unless code version matches. """ # Protected by the compile_lock _new_cache_artifacts: CacheArtifactsResult = defaultdict(list) +<<<<<<< HEAD # Keep a separate seen artifacts list to make avoid unnecessary duplicates +======= + # Keep a seperate seen artifacts list to make avoid unnecessary duplicates +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This list will not be cleared between serialize() calls _seen_artifacts: OrderedSet[CacheArtifact] = OrderedSet() # When serialize() is called, artifacts are transferred from _cache_artifacts to # internal data structure of the _serializer # This allows us to only pay the cost of serialization if serialize() is called +<<<<<<< HEAD _serializer: AppendingByteSerializer[tuple[str, list[CacheArtifact]]] = ( AppendingByteSerializer(serialize_fn=_serialize_single_cache) ) +======= + _serializer: AppendingByteSerializer[ + tuple[str, list[CacheArtifact]] + ] = AppendingByteSerializer(serialize_fn=_serialize_single_cache) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _cache_info: CacheInfo = CacheInfo() @classmethod diff --git a/torch/compiler/config.py b/torch/compiler/config.py index d30f6c66f29e9..86da330e3a30b 100644 --- a/torch/compiler/config.py +++ b/torch/compiler/config.py @@ -59,6 +59,7 @@ consistent profiles across all ranks. """ +<<<<<<< HEAD pgo_extra_read_key: Optional[str] = Config( env_name_default="TORCH_COMPILE_STICKY_PGO_READ", default=None ) @@ -72,6 +73,8 @@ and merges it with the default state. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cache_key_tag: str = Config(env_name_default="TORCH_COMPILE_CACHE_KEY_TAG", default="") """ @@ -79,6 +82,7 @@ A common use case for such a tag is to break caches. """ +<<<<<<< HEAD force_disable_caches: bool = Config( justknob="pytorch/remote_cache:force_disable_caches", env_name_force=[ @@ -91,6 +95,8 @@ Force disables all caching -- This will take precedence over and override any other caching flag """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dynamic_sources: str = Config( env_name_default="TORCH_COMPILE_DYNAMIC_SOURCES", default="" ) @@ -113,6 +119,7 @@ and force_parameter_static_shapes. """ +<<<<<<< HEAD # force a python GC before recording cudagraphs force_cudagraph_gc: bool = Config(env_name_default="TORCH_CUDAGRAPH_GC", default=False) """ @@ -121,4 +128,6 @@ """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) install_config_module(sys.modules[__name__]) diff --git a/torch/cpu/__init__.py b/torch/cpu/__init__.py index b42b7f0ff54bd..b5713ec207808 100644 --- a/torch/cpu/__init__.py +++ b/torch/cpu/__init__.py @@ -27,6 +27,11 @@ "Event", ] +<<<<<<< HEAD +======= +_device_t = Union[_device, str, int, None] + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _is_avx2_supported() -> bool: r"""Returns a bool indicating if CPU supports AVX2.""" @@ -73,7 +78,11 @@ def is_available() -> bool: return True +<<<<<<< HEAD def synchronize(device: torch.types.Device = None) -> None: +======= +def synchronize(device: _device_t = None) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r"""Waits for all kernels in all streams on the CPU device to complete. Args: @@ -119,7 +128,11 @@ def wait(self, stream=None) -> None: _current_stream = _default_cpu_stream +<<<<<<< HEAD def current_stream(device: torch.types.Device = None) -> Stream: +======= +def current_stream(device: _device_t = None) -> Stream: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r"""Returns the currently selected :class:`Stream` for a given device. Args: @@ -179,7 +192,11 @@ def device_count() -> int: return 1 +<<<<<<< HEAD def set_device(device: torch.types.Device) -> None: +======= +def set_device(device: _device_t) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r"""Sets the current device, in CPU we do nothing. N.B. This function only exists to facilitate device-agnostic code diff --git a/torch/csrc/Device.cpp b/torch/csrc/Device.cpp index 53aca5ae8e31b..0bfd683ffdbff 100644 --- a/torch/csrc/Device.cpp +++ b/torch/csrc/Device.cpp @@ -141,9 +141,15 @@ static PyObject* THPDevice_rc(PyObject* a, PyObject* b, int op) { case Py_LE: case Py_GT: case Py_GE: +<<<<<<< HEAD TORCH_CHECK_TYPE(false, "comparison not implemented"); default: TORCH_CHECK_TYPE(false, "unexpected comparison op"); +======= + throw torch::TypeError("comparison not implemented"); + default: + throw torch::TypeError("unexpected comparison op"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } END_HANDLE_TH_ERRORS } diff --git a/torch/csrc/DeviceAccelerator.cpp b/torch/csrc/DeviceAccelerator.cpp index dc3da8881a715..d4b849501d56a 100644 --- a/torch/csrc/DeviceAccelerator.cpp +++ b/torch/csrc/DeviceAccelerator.cpp @@ -72,6 +72,7 @@ void initModule(PyObject* module) { torch::utils::maybe_initialize_device(device_type); return at::accelerator::maybeExchangeDevice(device_index); }); +<<<<<<< HEAD m.def("_accelerator_isAllocatorInitialized", []() { const auto device_type = at::accelerator::getAccelerator(true).value(); @@ -136,6 +137,8 @@ void initModule(PyObject* module) { m.def("_accelerator_resetPeakStats", [](c10::DeviceIndex device_index) { at::accelerator::resetPeakStats(device_index); }); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } // namespace torch::accelerator diff --git a/torch/csrc/Exceptions.cpp b/torch/csrc/Exceptions.cpp index 77085a9463999..676504e5c0af1 100644 --- a/torch/csrc/Exceptions.cpp +++ b/torch/csrc/Exceptions.cpp @@ -228,6 +228,20 @@ std::string processErrorMsg(std::string str) { return str; } +<<<<<<< HEAD +======= +static std::string formatMessage(const char* format, va_list fmt_args) { + constexpr size_t ERROR_BUF_SIZE = 1024; + std::string error_buf(ERROR_BUF_SIZE, '\0'); + auto res = vsnprintf(error_buf.data(), ERROR_BUF_SIZE, format, fmt_args); + if (res < 0) { + res = 0; + } + error_buf.resize(res); + return error_buf; +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void translate_exception_to_python(const std::exception_ptr& e_ptr) { try { TORCH_INTERNAL_ASSERT( @@ -239,6 +253,16 @@ void translate_exception_to_python(const std::exception_ptr& e_ptr) { CATCH_ALL_ERRORS(return) } +<<<<<<< HEAD +======= +TypeError::TypeError(const char* format, ...) { + va_list fmt_args{}; + va_start(fmt_args, format); + msg = formatMessage(format, fmt_args); + va_end(fmt_args); +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void PyWarningHandler::InternalHandler::process(const c10::Warning& warning) { warning_buffer_.push_back(warning); } diff --git a/torch/csrc/Exceptions.h b/torch/csrc/Exceptions.h index 60a7bb644df01..84bdaa5954510 100644 --- a/torch/csrc/Exceptions.h +++ b/torch/csrc/Exceptions.h @@ -74,7 +74,10 @@ inline void PyErr_SetString(PyObject* type, const std::string& message) { _CATCH_GENERIC_ERROR(TypeError, PyExc_TypeError, retstmnt) \ _CATCH_GENERIC_ERROR( \ NotImplementedError, PyExc_NotImplementedError, retstmnt) \ +<<<<<<< HEAD _CATCH_GENERIC_ERROR(BufferError, PyExc_BufferError, retstmnt) \ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _CATCH_GENERIC_ERROR(SyntaxError, PyExc_SyntaxError, retstmnt) \ _CATCH_GENERIC_ERROR(LinAlgError, THPException_LinAlgError, retstmnt) \ _CATCH_GENERIC_ERROR( \ @@ -284,12 +287,28 @@ struct PyTorchError : public std::exception { std::string msg; }; +<<<<<<< HEAD // Translates to Python TypeError struct TypeError : public PyTorchError { TORCH_PYTHON_API TypeError() = default; TORCH_PYTHON_API TypeError(std::string msg_) : PyTorchError(std::move(msg_)) {} using PyTorchError::PyTorchError; +======= +// Declare a printf-like function on gcc & clang +// The compiler can then warn on invalid format specifiers +#ifdef __GNUC__ +#define TORCH_FORMAT_FUNC(FORMAT_INDEX, VA_ARGS_INDEX) \ + __attribute__((format(printf, FORMAT_INDEX, VA_ARGS_INDEX))) +#else +#define TORCH_FORMAT_FUNC(FORMAT_INDEX, VA_ARGS_INDEX) +#endif + +// Translates to Python TypeError +struct TypeError : public PyTorchError { + using PyTorchError::PyTorchError; + TORCH_PYTHON_API TypeError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) PyObject* python_type() override { return PyExc_TypeError; } diff --git a/torch/csrc/Generator.cpp b/torch/csrc/Generator.cpp index d99d41ae3d351..d93ea69927843 100644 --- a/torch/csrc/Generator.cpp +++ b/torch/csrc/Generator.cpp @@ -82,11 +82,17 @@ static PyObject* THPGenerator_setState(PyObject* _self, PyObject* _new_state) { HANDLE_TH_ERRORS if (!THPVariable_Check(_new_state)) { +<<<<<<< HEAD TORCH_CHECK_TYPE( false, fmt::format( "expected a torch.ByteTensor, but got {}", Py_TYPE(_new_state)->tp_name)); +======= + throw torch::TypeError( + "expected a torch.ByteTensor, but got %s", + Py_TYPE(_new_state)->tp_name); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } auto self = (THPGenerator*)_self; auto& gen = self->cdata; @@ -382,10 +388,15 @@ PyObject* THPGenerator_Wrap(const Generator& gen) { at::Generator THPGenerator_Unwrap(PyObject* state) { if (!Py_IS_TYPE(state, &THPGeneratorType)) { +<<<<<<< HEAD TORCH_CHECK_TYPE( false, fmt::format( "expected a Generator, but got {}", Py_TYPE(state)->tp_name)); +======= + throw torch::TypeError( + "expected a Generator, but got %s", Py_TYPE(state)->tp_name); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } return reinterpret_cast(state)->cdata; } diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index bf615360b657d..4d1adb606725b 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -2,7 +2,10 @@ #include #include #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #ifndef _MSC_VER @@ -20,6 +23,10 @@ #include #include #include +<<<<<<< HEAD +======= +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include @@ -71,7 +78,10 @@ #include #include #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include @@ -138,8 +148,11 @@ #include #endif +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace py = pybind11; static PyObject* module; @@ -283,8 +296,12 @@ static PyObject* THPModule_crashIfvptrUBSAN(PyObject* module, PyObject* noarg) { virtual ~Baz() = default; }; Baz x{}; +<<<<<<< HEAD // Purposely cast through `void*` so there's no fixups applied. // NOLINTNEXTLINE(bugprone-casting-through-void,-warnings-as-errors) +======= + // NOLINTNEXTLINE(bugprone-casting*) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto y = static_cast(static_cast(&x)); auto rc = y->bar(); return THPUtils_packInt32(rc); @@ -409,10 +426,17 @@ static PyObject* THPModule_swap_tensor_impl(PyObject* _unused, PyObject* args) { // associated with the TensorImpl. Swap this field as well. std::optional mb_obj_a = a->cdata->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( +<<<<<<< HEAD /*ignore_hermetic_tls=*/false); std::optional mb_obj_b = b->cdata->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( /*ignore_hermetic_tls=*/false); +======= + getPyInterpreter(), /*ignore_hermetic_tls=*/false); + std::optional mb_obj_b = + b->cdata->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( + getPyInterpreter(), /*ignore_hermetic_tls=*/false); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_INTERNAL_ASSERT( mb_obj_a.has_value() && mb_obj_b.has_value(), "Both tensors should have PyObjects tagged by the current python interpreter"); @@ -422,8 +446,15 @@ static PyObject* THPModule_swap_tensor_impl(PyObject* _unused, PyObject* args) { a->cdata = b->cdata; b->cdata = tmp; +<<<<<<< HEAD a->cdata->unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(a_); b->cdata->unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(b_); +======= + a->cdata->unsafeGetTensorImpl()->pyobj_slot()->init_pyobj( + getPyInterpreter(), a_, c10::impl::PyInterpreterStatus::TAGGED_BY_US); + b->cdata->unsafeGetTensorImpl()->pyobj_slot()->init_pyobj( + getPyInterpreter(), b_, c10::impl::PyInterpreterStatus::TAGGED_BY_US); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Py_RETURN_NONE; END_HANDLE_TH_ERRORS @@ -586,11 +617,16 @@ static PyObject* THPModule_getCpuCapability( END_HANDLE_TH_ERRORS } +<<<<<<< HEAD namespace { template void DLPack_Capsule_Destructor(PyObject* data) { if (C10_LIKELY(!PyCapsule_IsValid(data, at::DLPackTraits::capsule))) { +======= +static void DLPack_Capsule_Destructor(PyObject* data) { + if (C10_LIKELY(!PyCapsule_IsValid(data, "dltensor"))) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // early out, see DLPack spec: if a consuming library sets the capsule // name to something else, they own it and we don't need to do anything return; @@ -600,6 +636,7 @@ void DLPack_Capsule_Destructor(PyObject* data) { // since consuming libraries should rename the capsule according to spec. // Note that this cannot set a python error (we checked validity above), // so we don't need to handle python error state here. +<<<<<<< HEAD T* tensor = (T*)PyCapsule_GetPointer(data, at::DLPackTraits::capsule); // the dlMTensor has not been consumed, call deleter ourselves. // DLPack spec mentions that deleter may be NULL, but deleter from @@ -661,6 +698,25 @@ static PyObject* THPModule_toDLPackVersioned( return THPModule_toDLPackImpl(self, args, kwargs); } +======= + DLManagedTensor* dlMTensor = + (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor"); + // the dlMTensor has not been consumed, call deleter ourselves. + // DLPack spec mentions that deleter may be NULL, but deleter from + // `at::toDLPack` is never NULL, so no need for an additional check here. + dlMTensor->deleter(dlMTensor); + END_HANDLE_TH_ERRORS_RET() +} + +static PyObject* THPModule_toDLPack(PyObject* _unused, PyObject* data) { + HANDLE_TH_ERRORS + TORCH_CHECK(THPVariable_Check(data), "data must be a Tensor"); + DLManagedTensor* dlMTensor = at::toDLPack(THPVariable_Unpack(data)); + return PyCapsule_New(dlMTensor, "dltensor", DLPack_Capsule_Destructor); + END_HANDLE_TH_ERRORS +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static PyObject* THPModule_fromDLPack(PyObject* _unused, PyObject* data) { using namespace torch::autograd; HANDLE_TH_ERRORS @@ -669,6 +725,7 @@ static PyObject* THPModule_fromDLPack(PyObject* _unused, PyObject* data) { END_HANDLE_TH_ERRORS } +<<<<<<< HEAD static PyObject* THPModule_torchDeviceToDLDevice( PyObject* _unused, PyObject* data) { @@ -691,6 +748,8 @@ static PyObject* THPModule_torchDeviceToDLDevice( END_HANDLE_TH_ERRORS } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static PyObject* THModule_getCppBacktrace(PyObject* _unused, PyObject* args) { HANDLE_TH_ERRORS size_t frames_to_skip = 0; @@ -739,12 +798,18 @@ static PyObject* THPModule_setAllowTF32CuDNN(PyObject* _unused, PyObject* arg) { } static PyObject* THPModule_allowTF32CuDNN(PyObject* _unused, PyObject* noargs) { +<<<<<<< HEAD HANDLE_TH_ERRORS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (at::globalContext().allowTF32CuDNN()) Py_RETURN_TRUE; else Py_RETURN_FALSE; +<<<<<<< HEAD END_HANDLE_TH_ERRORS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } static PyObject* THPModule_setFloat32MatmulPrecision( @@ -765,7 +830,10 @@ static PyObject* THPModule_setFloat32MatmulPrecision( static PyObject* THPModule_float32MatmulPrecision( PyObject* _unused, PyObject* noargs) { +<<<<<<< HEAD HANDLE_TH_ERRORS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::string s = "highest"; auto p = at::globalContext().float32MatmulPrecision(); if (p == at::Float32MatmulPrecision::HIGH) { @@ -774,7 +842,10 @@ static PyObject* THPModule_float32MatmulPrecision( s = "medium"; } return THPUtils_packString(s); +<<<<<<< HEAD END_HANDLE_TH_ERRORS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } static PyObject* THPModule_setSDPPriorityOrder( PyObject* _unused, @@ -1172,6 +1243,7 @@ static PyObject* THPModule_benchmarkCuDNN(PyObject* _unused, PyObject* noargs) { Py_RETURN_FALSE; } +<<<<<<< HEAD static PyObject* THPModule_setImmediateMiopen( PyObject* _unused, PyObject* arg) { @@ -1195,6 +1267,8 @@ static PyObject* THPModule_immediateMiopen( Py_RETURN_FALSE; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static PyObject* THPModule_setAllowTF32CuBLAS( PyObject* _unused, PyObject* arg) { @@ -1212,12 +1286,18 @@ static PyObject* THPModule_setAllowTF32CuBLAS( static PyObject* THPModule_allowTF32CuBLAS( PyObject* _unused, PyObject* noargs) { +<<<<<<< HEAD HANDLE_TH_ERRORS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (at::globalContext().allowTF32CuBLAS()) { Py_RETURN_TRUE; } Py_RETURN_FALSE; +<<<<<<< HEAD END_HANDLE_TH_ERRORS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } static PyObject* THPModule_setAllowFP16ReductionCuBLAS( @@ -1349,7 +1429,10 @@ static PyObject* THPModule_setQEngine(PyObject* /* unused */, PyObject* arg) { "but got ", THPUtils_typename(arg)); auto qengine = THPUtils_unpackLong(arg); +<<<<<<< HEAD // NOLINTNEXTLINE(clang-analyzer-optin.core.EnumCastOutOfRange) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) at::globalContext().setQEngine(static_cast(qengine)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS @@ -1363,7 +1446,11 @@ static PyObject* THPModule_qEngine(PyObject* _unused, PyObject* noargs) { static PyObject* THPModule_supportedQEngines( PyObject* _unused, PyObject* noargs) { +<<<<<<< HEAD const auto& qengines = at::globalContext().supportedQEngines(); +======= + auto qengines = at::globalContext().supportedQEngines(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto list = THPObjectPtr(PyList_New(static_cast(qengines.size()))); if (!list) @@ -1665,8 +1752,11 @@ static std::initializer_list TorchMethods = { {"_set_onednn_allow_tf32", THPModule_setAllowTF32OneDNN, METH_O, nullptr}, {"_get_cudnn_benchmark", THPModule_benchmarkCuDNN, METH_NOARGS, nullptr}, {"_set_cudnn_benchmark", THPModule_setBenchmarkCuDNN, METH_O, nullptr}, +<<<<<<< HEAD {"_get_miopen_immediate", THPModule_immediateMiopen, METH_NOARGS, nullptr}, {"_set_miopen_immediate", THPModule_setImmediateMiopen, METH_O, nullptr}, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"_get_cudnn_deterministic", THPModule_deterministicCuDNN, METH_NOARGS, @@ -1767,6 +1857,7 @@ static std::initializer_list TorchMethods = { THPModule_are_vmap_fallback_warnings_enabled, METH_NOARGS, nullptr}, +<<<<<<< HEAD {"_to_dlpack", castPyCFunctionWithKeywords(THPModule_toDLPack), METH_VARARGS | METH_KEYWORDS, @@ -1780,6 +1871,10 @@ static std::initializer_list TorchMethods = { THPModule_torchDeviceToDLDevice, METH_O, nullptr}, +======= + {"_to_dlpack", THPModule_toDLPack, METH_O, nullptr}, + {"_from_dlpack", THPModule_fromDLPack, METH_O, nullptr}, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"_get_cpp_backtrace", THModule_getCppBacktrace, METH_VARARGS, nullptr}, {"_rename_privateuse1_backend", THModule_rename_privateuse1_backend, @@ -1857,7 +1952,10 @@ static std::initializer_list TorchMethods = { {nullptr, nullptr, 0, nullptr}}; #ifdef USE_CUDA +<<<<<<< HEAD // NOLINTBEGIN(misc-use-internal-linkage) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void THCPStream_init(PyObject* module); void THCPEvent_init(PyObject* module); void THCPGraph_init(PyObject* module); @@ -1866,7 +1964,10 @@ PyMethodDef* THCPModule_methods(); namespace torch::cuda { void initModule(PyObject* module); } // namespace torch::cuda +<<<<<<< HEAD // NOLINTEND(misc-use-internal-linkage) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif #ifdef USE_XPU @@ -1908,6 +2009,7 @@ class WeakTensorRef { } }; +<<<<<<< HEAD namespace { using SigHandler = void (*)(int); @@ -1968,6 +2070,8 @@ void _initCrashHandler() { } // anonymous namespace +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) extern "C" TORCH_PYTHON_API PyObject* initModule(); // separate decl and defn for msvc error C2491 PyObject* initModule() { @@ -2077,7 +2181,10 @@ PyObject* initModule() { torch::instruction_counter::initModule(module); torch::initVerboseBindings(module); ASSERT_TRUE(THPStorage_init(module)); +<<<<<<< HEAD torch::functionalization::initModule(module); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #ifdef USE_CUDA // This will only initialise base classes and attach them to library namespace @@ -2147,7 +2254,10 @@ PyObject* initModule() { }); auto py_module = py::reinterpret_borrow(module); +<<<<<<< HEAD py_module.def("_initCrashHandler", &_initCrashHandler); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py_module.def("_demangle", &c10::demangle); py_module.def("_log_api_usage_once", &LogAPIUsageOnceFromPython); py_module.def("_log_api_usage_metadata", &LogAPIUsageMetadataFromPython); @@ -2194,7 +2304,11 @@ Call this whenever a new thread is created in order to propagate values from py_module.def("_storage_Use_Count", [](size_t storage_impl_ptr) { // NOLINTNEXTLINE(performance-no-int-to-ptr) c10::StorageImpl* storage_impl = (c10::StorageImpl*)storage_impl_ptr; +<<<<<<< HEAD return c10::raw::intrusive_ptr::use_count(storage_impl); +======= + return c10::raw::weak_intrusive_ptr::use_count(storage_impl); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }); ASSERT_TRUE( @@ -2204,8 +2318,11 @@ Call this whenever a new thread is created in order to propagate values from set_module_attr("_has_kleidiai", at::hasKleidiAI() ? Py_True : Py_False)); ASSERT_TRUE( set_module_attr("has_lapack", at::hasLAPACK() ? Py_True : Py_False)); +<<<<<<< HEAD ASSERT_TRUE(set_module_attr( "_has_eigen_sparse", at::hasEigenSparse() ? Py_True : Py_False)); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py_module.def("_valgrind_supported_platform", []() { #if defined(USE_VALGRIND) @@ -2475,6 +2592,7 @@ Call this whenever a new thread is created in order to propagate values from }); py_module.def( +<<<<<<< HEAD "_get_fp32_precision_getter", [](const std::string& backend, const std::string& op) { return at::globalContext().float32Precision(backend, op); @@ -2490,6 +2608,8 @@ Call this whenever a new thread is created in order to propagate values from }); py_module.def( +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "_stash_obj_in_tls", [](const std::string& key, py::handle arg) { at::impl::ThreadLocalPythonObjects::get_state().set( key, @@ -2606,6 +2726,33 @@ Call this whenever a new thread is created in order to propagate values from ASSERT_TRUE(set_module_attr("_GLIBCXX_USE_CXX11_ABI", Py_True)); +<<<<<<< HEAD +======= +// See note [Pybind11 ABI constants] +#define SET_STR_DEFINE(name) \ + ASSERT_TRUE(set_module_attr("_" #name, THPUtils_packString(name))) + +#ifdef PYBIND11_COMPILER_TYPE + SET_STR_DEFINE(PYBIND11_COMPILER_TYPE); +#else + ASSERT_TRUE( + set_module_attr("_" C10_STRINGIZE(PYBIND11_COMPILER_TYPE), Py_None)); +#endif + +#ifdef PYBIND11_STDLIB + SET_STR_DEFINE(PYBIND11_STDLIB); +#else + ASSERT_TRUE(set_module_attr("_" C10_STRINGIZE(PYBIND11_STDLIB), Py_None)); +#endif + +#ifdef PYBIND11_BUILD_ABI + SET_STR_DEFINE(PYBIND11_BUILD_ABI); +#else + ASSERT_TRUE(set_module_attr("_" C10_STRINGIZE(PYBIND11_BUILD_ABI), Py_None)); +#endif +#undef SET_STR_DEFINE + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py_module.def( "_set_conj", [](const at::Tensor& x, bool conj) { x._set_conj(conj); }); py_module.def( @@ -2762,8 +2909,11 @@ Call this whenever a new thread is created in order to propagate values from #ifdef USE_KINETO torch::global_kineto_init(); #endif +<<<<<<< HEAD auto nativert_module = py_module.def_submodule("_nativert"); torch::nativert::initModelRunnerPybind(nativert_module); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return module; END_HANDLE_TH_ERRORS } diff --git a/torch/csrc/PyInterpreter.cpp b/torch/csrc/PyInterpreter.cpp index e6016a7721e8b..3be2b5b0c81de 100644 --- a/torch/csrc/PyInterpreter.cpp +++ b/torch/csrc/PyInterpreter.cpp @@ -82,8 +82,11 @@ struct ConcretePyInterpreterVTable final bool is_contiguous(const c10::TensorImpl* self, at::MemoryFormat) const override; +<<<<<<< HEAD c10::SymBool sym_is_contiguous(const c10::TensorImpl* self, at::MemoryFormat) const override; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool is_strides_like(const c10::TensorImpl* self, at::MemoryFormat) const override; bool is_non_overlapping_and_dense(const c10::TensorImpl* self) const override; @@ -478,6 +481,7 @@ bool ConcretePyInterpreterVTable::is_contiguous( return PyObject_IsTrue(out.ptr()); } +<<<<<<< HEAD c10::SymBool ConcretePyInterpreterVTable::sym_is_contiguous( const c10::TensorImpl* self, at::MemoryFormat memory_format) const { @@ -505,6 +509,8 @@ c10::SymBool ConcretePyInterpreterVTable::sym_is_contiguous( : c10::SymBool{py::cast(out)}; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool ConcretePyInterpreterVTable::is_strides_like( const c10::TensorImpl* self, at::MemoryFormat memory_format) const { @@ -615,7 +621,11 @@ static void set_tensor_attr_with_capsule( py::capsule& capsule, const char* attr_name) { std::optional mb_obj = tensor->pyobj_slot()->check_pyobj( +<<<<<<< HEAD /*ignore_hermetic_tls=*/false); +======= + getPyInterpreter(), /*ignore_hermetic_tls=*/false); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_CHECK( mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value"); auto obj = mb_obj.value(); @@ -1016,3 +1026,10 @@ py::handle getTorchApiFunction(const c10::OperatorHandle& op) { c10::impl::PyInterpreter* getPyInterpreter() { return torch::detail::self_interpreter.get(); } +<<<<<<< HEAD +======= + +bool isMainPyInterpreter() { + return torch::detail::self_interpreter.is_main_interpreter(); +} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/csrc/PyInterpreter.h b/torch/csrc/PyInterpreter.h index 0ff9f79d02c27..f5fd74cfe5bcc 100644 --- a/torch/csrc/PyInterpreter.h +++ b/torch/csrc/PyInterpreter.h @@ -10,4 +10,8 @@ TORCH_PYTHON_API py::handle getTorchApiFunction(const c10::OperatorHandle& op); // TODO: Move these to a proper namespace TORCH_PYTHON_API c10::impl::PyInterpreter* getPyInterpreter(); +<<<<<<< HEAD TORCH_PYTHON_API void initializeGlobalPyInterpreter(); +======= +TORCH_PYTHON_API bool isMainPyInterpreter(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/csrc/Storage.cpp b/torch/csrc/Storage.cpp index 08112b41aaaed..176d78ba1816a 100644 --- a/torch/csrc/Storage.cpp +++ b/torch/csrc/Storage.cpp @@ -35,6 +35,10 @@ PyTypeObject* THPStorageClass = nullptr; PyObject* THPStorage_NewWithStorage( PyTypeObject* type, c10::Storage _storage, +<<<<<<< HEAD +======= + c10::impl::PyInterpreterStatus status, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool allow_preexisting_pyobj) { TORCH_CHECK( PyType_IsSubtype(type, &THPStorageType), @@ -42,7 +46,11 @@ PyObject* THPStorage_NewWithStorage( "Storage is not possible. Make sure your class inherits from Storage."); auto maybe_pyobj = _storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj( +<<<<<<< HEAD /*ignore_hermetic_tls=*/false); +======= + getPyInterpreter(), /*ignore_hermetic_tls=*/false); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (maybe_pyobj.has_value() && maybe_pyobj.value()) { TORCH_CHECK( allow_preexisting_pyobj, @@ -77,7 +85,12 @@ PyObject* THPStorage_NewWithStorage( if (!c10::impl::HermeticPyObjectTLS::get_state()) { s->is_hermetic = false; const auto& storage = THPStorage_Unpack(s); +<<<<<<< HEAD storage.unsafeGetStorageImpl()->pyobj_slot()->init_pyobj(obj); +======= + storage.unsafeGetStorageImpl()->pyobj_slot()->init_pyobj( + getPyInterpreter(), obj, status); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else { s->is_hermetic = true; } @@ -89,12 +102,37 @@ PyObject* THPStorage_NewWithStorage( PyObject* THPStorage_Wrap(c10::Storage storage) { c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl(); if (c10::impl::HermeticPyObjectTLS::get_state()) { +<<<<<<< HEAD return THPStorage_NewWithStorage(THPStorageClass, std::move(storage)); } c10::impl::PyObjectSlot* pyobj_slot = storage_impl->pyobj_slot(); std::optional maybe_pyobj = pyobj_slot->check_pyobj( /*ignore_hermetic_tls=*/false); +======= + return THPStorage_NewWithStorage( + THPStorageClass, + std::move(storage), + c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); + } + c10::impl::PyObjectSlot* pyobj_slot = storage_impl->pyobj_slot(); + + // If the StorageImpl has a PyObject that is managed by a different + // interpreter than the current one, create a new StorageImpl that points to + // the same data and then create the Python storage from that. + // NOTE: This is only supposed to happen in MultiPy // codespell:ignore + if (pyobj_slot->has_pyobj_nonhermetic() && + !pyobj_slot->check_interpreter(getPyInterpreter())) { + return THPStorage_NewWithStorage( + THPStorageClass, + c10::newStorageImplFromRefcountedDataPtr(storage), + c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); + } + std::optional maybe_pyobj = pyobj_slot->check_pyobj( + getPyInterpreter(), /*ignore_hermetic_tls=*/false); + c10::impl::PyInterpreterStatus status = + c10::impl::PyInterpreterStatus::TAGGED_BY_US; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (maybe_pyobj.has_value()) { auto obj = *maybe_pyobj; if (obj) { @@ -113,8 +151,20 @@ PyObject* THPStorage_Wrap(c10::Storage storage) { return obj; } } +<<<<<<< HEAD } return THPStorage_NewWithStorage(THPStorageClass, std::move(storage)); +======= + status = c10::impl::PyInterpreterStatus::TAGGED_BY_US; + } else { + if (storage.use_count() <= 1) { + status = c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED; + } else { + status = c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED; + } + } + return THPStorage_NewWithStorage(THPStorageClass, std::move(storage), status); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } static bool THPStorage_isPreservable(THPStorage* self) { @@ -128,7 +178,12 @@ static bool THPStorage_isPreservable(THPStorage* self) { } if (storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj( +<<<<<<< HEAD /*ignore_hermetic_tls=*/true) != (PyObject*)self) { +======= + getPyInterpreter(), /*ignore_hermetic_tls=*/true) != + (PyObject*)self) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return false; } if (storage.use_count() <= 1) { @@ -146,10 +201,18 @@ static bool THPStorage_tryPreserve(THPStorage* self) { c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl(); auto maybe_pyobj = storage_impl->pyobj_slot()->check_pyobj( +<<<<<<< HEAD /*ignore_hermetic_tls=*/true); // NOTE: It is possible to just set the PyObjectSlot here, but the point is // that we should have already set PyObjectSlot when the storage PyObject // was created. +======= + getPyInterpreter(), + /*ignore_hermetic_tls=*/true); + // NOTE: It is possible to just set the PyObjectSlot here, but the point is + // that we should have already set PyObjectSlot when the storage PyObject was + // created. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_INTERNAL_ASSERT( maybe_pyobj.has_value(), "Trying to preserve a Python storage whose PyObjectSlot does not have a PyObject"); @@ -357,7 +420,12 @@ static PyObject* THPStorage_pynew( at::DataPtr(), allocator, /*resizable=*/true, +<<<<<<< HEAD device_opt)); +======= + device_opt), + c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // torch.Storage(size, *, ...) } else if (r.idx == 1) { @@ -370,7 +438,12 @@ static PyObject* THPStorage_pynew( at::DataPtr(), allocator, /*resizable=*/true, +<<<<<<< HEAD device_opt)); +======= + device_opt), + c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // torch.Storage(sequence, *, ...) } else if (r.idx == 2) { @@ -394,7 +467,12 @@ static PyObject* THPStorage_pynew( at::DataPtr(), allocator, /*resizable=*/true, +<<<<<<< HEAD device_opt)); +======= + device_opt), + c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) THPObjectPtr item; try { const auto& storage = THPStorage_Unpack(self); @@ -490,8 +568,15 @@ static PyObject* THPStorage_get(THPStorage* self, PyObject* index) { /* resizable */ false, device_opt); +<<<<<<< HEAD PyObject* _ret = THPStorage_NewWithStorage(Py_TYPE(self), std::move(new_storage_impl)); +======= + PyObject* _ret = THPStorage_NewWithStorage( + Py_TYPE(self), + std::move(new_storage_impl), + c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return _ret; } diff --git a/torch/csrc/Storage.h b/torch/csrc/Storage.h index 698cd80548efa..1818d6926df78 100644 --- a/torch/csrc/Storage.h +++ b/torch/csrc/Storage.h @@ -19,6 +19,10 @@ TORCH_PYTHON_API PyObject* THPStorage_Wrap(c10::Storage storage); TORCH_PYTHON_API PyObject* THPStorage_NewWithStorage( PyTypeObject* type, c10::Storage _storage, +<<<<<<< HEAD +======= + c10::impl::PyInterpreterStatus status, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool allow_preexisting_pyobj = false); TORCH_PYTHON_API extern PyTypeObject* THPStorageClass; diff --git a/torch/csrc/StorageMethods.cpp b/torch/csrc/StorageMethods.cpp index da64bcfbd5008..0fc08c208ff6a 100644 --- a/torch/csrc/StorageMethods.cpp +++ b/torch/csrc/StorageMethods.cpp @@ -390,7 +390,14 @@ static PyObject* THPStorage_fromFile( storage->set_nbytes(actual_nbytes); } +<<<<<<< HEAD return THPStorage_NewWithStorage(THPStorageClass, std::move(storage)); +======= + return THPStorage_NewWithStorage( + THPStorageClass, + std::move(storage), + c10::impl::PyInterpreterStatus::TAGGED_BY_US); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) END_HANDLE_TH_ERRORS } diff --git a/torch/csrc/StorageSharing.cpp b/torch/csrc/StorageSharing.cpp index e58865bb60a8a..e22125cd08176 100644 --- a/torch/csrc/StorageSharing.cpp +++ b/torch/csrc/StorageSharing.cpp @@ -86,7 +86,12 @@ static PyObject* THPStorage_pyNewFilenameStorage( THManagedMapAllocator::makeDataPtr( "", handle.c_str(), flags, static_cast(size)), /*allocator=*/nullptr, +<<<<<<< HEAD /*resizable=*/false)); +======= + /*resizable=*/false), + c10::impl::PyInterpreterStatus::TAGGED_BY_US); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) END_HANDLE_TH_ERRORS } @@ -181,7 +186,12 @@ static PyObject* THPStorage_newSharedFilename( THManagedMapAllocator::makeDataPtr( manager_handle, object_handle, flags, size), /*allocator=*/nullptr, +<<<<<<< HEAD /*resizable=*/false)); +======= + /*resizable=*/false), + c10::impl::PyInterpreterStatus::TAGGED_BY_US); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) END_HANDLE_TH_ERRORS } @@ -195,7 +205,13 @@ static PyObject* THPStorage_pyNewFdStorage(PyObject* _unused, PyObject* args) { return nullptr; } return THPStorage_NewWithStorage( +<<<<<<< HEAD THPStorageClass, at::new_shm_fd_storage(size)); +======= + THPStorageClass, + at::new_shm_fd_storage(size), + c10::impl::PyInterpreterStatus::TAGGED_BY_US); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) END_HANDLE_TH_ERRORS } @@ -274,7 +290,12 @@ static PyObject* THPStorage_newSharedFd(PyObject* _unused, PyObject* args) { at::MapAllocator::makeDataPtr( at::WITH_FD, "", fd, flags, size, nullptr), /*allocator=*/nullptr, +<<<<<<< HEAD /*resizable=*/false)); +======= + /*resizable=*/false), + c10::impl::PyInterpreterStatus::TAGGED_BY_US); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) END_HANDLE_TH_ERRORS } @@ -555,7 +576,14 @@ static PyObject* THPStorage_newSharedCuda(PyObject* _unused, PyObject* args) { base->set_resizable(false); base->set_received_cuda(true); +<<<<<<< HEAD return THPStorage_NewWithStorage(THPStorageClass, std::move(base)); +======= + return THPStorage_NewWithStorage( + THPStorageClass, + std::move(base), + c10::impl::PyInterpreterStatus::TAGGED_BY_US); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #else TORCH_CHECK(false, "CUDA is not available"); #endif diff --git a/torch/csrc/api/include/torch/python.h b/torch/csrc/api/include/torch/python.h index 4878b1cc851a7..9d22544c6797f 100644 --- a/torch/csrc/api/include/torch/python.h +++ b/torch/csrc/api/include/torch/python.h @@ -26,7 +26,11 @@ inline Device py_object_to_device(py::object object) { if (THPDevice_Check(obj)) { return reinterpret_cast(obj)->device; } +<<<<<<< HEAD TORCH_CHECK_TYPE(false, "Expected device"); +======= + throw TypeError("Expected device"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } inline Dtype py_object_to_dtype(py::object object) { @@ -34,7 +38,11 @@ inline Dtype py_object_to_dtype(py::object object) { if (THPDtype_Check(obj)) { return reinterpret_cast(obj)->scalar_type; } +<<<<<<< HEAD TORCH_CHECK_TYPE(false, "Expected dtype"); +======= + throw TypeError("Expected dtype"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } template diff --git a/torch/csrc/api/src/optim/sgd.cpp b/torch/csrc/api/src/optim/sgd.cpp index 821587e439375..2438af3f0330f 100644 --- a/torch/csrc/api/src/optim/sgd.cpp +++ b/torch/csrc/api/src/optim/sgd.cpp @@ -84,7 +84,11 @@ Tensor SGD::step(LossClosure closure) { Tensor buf; auto param_state = state_.find(p.unsafeGetTensorImpl()); if (param_state == state_.end()) { +<<<<<<< HEAD buf = d_p.detach().clone(); +======= + buf = torch::clone(d_p).detach(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto state = std::make_unique(); state->momentum_buffer(buf); state_[p.unsafeGetTensorImpl()] = std::move(state); diff --git a/torch/csrc/api/src/serialize.cpp b/torch/csrc/api/src/serialize.cpp index fae54d1248476..01a981e3cd8e0 100644 --- a/torch/csrc/api/src/serialize.cpp +++ b/torch/csrc/api/src/serialize.cpp @@ -1,4 +1,8 @@ #include +<<<<<<< HEAD +======= +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 2b7e7760754d4..107986731cf16 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -894,8 +894,12 @@ Tensor logcumsumexp_backward( // Reference: // https://github.com/tensorflow/tensorflow/blob/2a5910906a0e0f3dbc186ff9db6386d81a63448c/tensorflow/python/ops/math_grad.py#L1832-L1863 +<<<<<<< HEAD auto scalar_min = AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( at::ScalarType::Half, +======= + auto scalar_min = AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) at::ScalarType::BFloat16, at::typeMetaToScalarType(grad.dtype()), "logcumsumexp_backward", @@ -1078,7 +1082,11 @@ std::vector cat_tensors_backward( auto& shape = sizes[i]; // If input was empty tensor, gradInput should be empty tensor. if (shape.size() == 1) { +<<<<<<< HEAD if (TORCH_GUARD_OR_FALSE(shape[0].sym_eq(0))) { +======= + if (TORCH_GUARD_SIZE_OBLIVIOUS(shape[0].sym_eq(0))) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) grad_inputs[i] = at::zeros({0}, grad_val.options()); continue; } @@ -3452,11 +3460,16 @@ std::tuple linalg_svd_jvp( const auto V = Vh.mH(); // dP = U^H dA V +<<<<<<< HEAD // U^H (dA V) is O(km(n + k)) // (U^H dA) V is O(kn(m + k)) // So prefer U^H (dA V) if m < n auto dP = m < n ? at::matmul(U.mH(), at::matmul(dA, V)) : at::matmul(at::matmul(U.mH(), dA), V); +======= + auto dP = m >= n ? at::matmul(U.mH(), at::matmul(dA, V)) + : at::matmul(at::matmul(U.mH(), dA), V); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto dS = is_complex ? at::real(dP.diagonal(0, -2, -1)) : dP.diagonal(0, -2, -1); @@ -5026,6 +5039,7 @@ std::tuple layer_norm_double_backward( return std::tuple{gI, gG, ggO}; } +<<<<<<< HEAD std::tuple infinitely_differentiable_native_rms_norm_backward( const Tensor& dY, const Tensor& drstd, @@ -5123,6 +5137,8 @@ std::tuple infinitely_differentiable_native_rms_norm_backward( return std::make_tuple(dX, dgamma); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::tuple infinitely_differentiable_native_group_norm_backward( const Tensor& dY, @@ -6477,6 +6493,7 @@ Tensor layer_norm_jvp( bias_t.defined() ? bias_t.view(view_size_affine) : bias_t); } +<<<<<<< HEAD Tensor rms_norm_jvp( const Tensor& input_p, const Tensor& input_t, @@ -6569,6 +6586,8 @@ Tensor rms_norm_rstd_jvp( return rstd_t; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Tensor group_norm_jvp( const Tensor& input_p, const Tensor& input_t, diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index 96864e165a95a..1dcb77fd3e6ea 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -35,7 +35,13 @@ TORCH_API Tensor toNonOptFwGrad(const std::optional& t); TORCH_API Tensor toNonOptPrimal(const std::optional& t); TORCH_API Tensor toNonOptTensor(const std::optional& t); +<<<<<<< HEAD inline std::optional wrap_opt_if(const Tensor& t, const bool cond) { +======= +TORCH_API inline std::optional wrap_opt_if( + const Tensor& t, + const bool cond) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) using OptTensor = std::optional; return cond ? OptTensor(t) : static_cast(std::nullopt); } @@ -826,6 +832,7 @@ std::tuple layer_norm_double_backward( c10::SymIntArrayRef normalized_shape, std::array output_mask); +<<<<<<< HEAD std::tuple infinitely_differentiable_native_rms_norm_backward( const Tensor& dY, const Tensor& drstd, @@ -835,6 +842,8 @@ std::tuple infinitely_differentiable_native_rms_norm_backward( const std::optional& weight_opt, std::array grad_input_mask); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::tuple householder_product_backward( const Tensor& grad, const Tensor& result, @@ -974,6 +983,7 @@ Tensor layer_norm_jvp( const Tensor& saved_invstd, c10::SymIntArrayRef normalized_shape); +<<<<<<< HEAD Tensor rms_norm_jvp( const Tensor& input_p, const Tensor& input_t, @@ -988,6 +998,8 @@ Tensor rms_norm_rstd_jvp( const Tensor& saved_rstd, IntArrayRef normalized_shape); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Tensor group_norm_jvp( const Tensor& input_p, const Tensor& input_t, diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index f0024f8f0b070..d063b9f526d0d 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -979,6 +979,7 @@ static void validate_outputs_impl( } if (grad.device() != metadata.device()) { +<<<<<<< HEAD if (grad.dim() == 0) { grad = grad.to(metadata.device()); } else { @@ -986,6 +987,15 @@ static void validate_outputs_impl( // should be eventually removed if (!(metadata.is_tensor_subclass() || grad.unsafeGetTensorImpl()->is_python_dispatch())) { +======= + // quick hack for: https://github.com/pytorch/pytorch/issues/65016 but + // should be eventually removed + if (!(metadata.is_tensor_subclass() || + grad.unsafeGetTensorImpl()->is_python_dispatch())) { + if (grad.dim() == 0) { + grad = grad.to(metadata.device()); + } else { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::stringstream ss; ss << "invalid gradient at index " << i << " - expected device "; ss << metadata.device() << " but got " << grad.device(); diff --git a/torch/csrc/autograd/functions/basic_ops.cpp b/torch/csrc/autograd/functions/basic_ops.cpp index d461c638df12a..548977e276e0b 100644 --- a/torch/csrc/autograd/functions/basic_ops.cpp +++ b/torch/csrc/autograd/functions/basic_ops.cpp @@ -57,7 +57,17 @@ auto UndefinedGrad::apply(variable_list&& inputs) -> variable_list { auto UndefinedGradBackward::apply(variable_list&& output_grads) -> variable_list { +<<<<<<< HEAD return tensor_list(output_grads.size()); +======= + tensor_list input_grads; + output_grads.reserve(input_grads.size()); + for (auto& grad : output_grads) { + (void)grad; // Suppress unused variable warning + input_grads.emplace_back(); + } + return input_grads; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } auto Identity::apply(variable_list&& grads) -> variable_list { diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index 380060501882f..0b22914c2d997 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -307,11 +307,15 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) { e.activityType() == (uint8_t)libkineto::ActivityType::GPU_USER_ANNOTATION; }) +<<<<<<< HEAD .def("nbytes", [](const KinetoEvent& e) { return e.nBytes(); }) // whether the event is hidden .def("is_hidden_event", [](const KinetoEvent& e) { return e.isHiddenEvent(); }); +======= + .def("nbytes", [](const KinetoEvent& e) { return e.nBytes(); }); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) m.def("_soft_assert_raises", &setSoftAssertRaises); m.def("_get_sequence_nr", &at::sequence_number::peek); diff --git a/torch/csrc/autograd/profiler_kineto.cpp b/torch/csrc/autograd/profiler_kineto.cpp index 6880caddc8d25..997145edcde87 100644 --- a/torch/csrc/autograd/profiler_kineto.cpp +++ b/torch/csrc/autograd/profiler_kineto.cpp @@ -936,10 +936,13 @@ bool KinetoEvent::hasKwinputs() const { return !kwinputs_.empty(); } +<<<<<<< HEAD bool KinetoEvent::isHiddenEvent() const { return result_ && result_->hidden_; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const std::unordered_map KinetoEvent::kwinputs() const { return kwinputs_; diff --git a/torch/csrc/autograd/profiler_kineto.h b/torch/csrc/autograd/profiler_kineto.h index 34d65a0b8dd6b..6210ed06a5ce6 100644 --- a/torch/csrc/autograd/profiler_kineto.h +++ b/torch/csrc/autograd/profiler_kineto.h @@ -37,7 +37,10 @@ struct TORCH_API KinetoEvent { bool hasConcreteInputs() const; const c10::ArrayRef concreteInputs() const; bool hasKwinputs() const; +<<<<<<< HEAD bool isHiddenEvent() const; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const std::unordered_map kwinputs() const; uint64_t flops() const; int64_t sequenceNr() const; diff --git a/torch/csrc/autograd/profiler_python.cpp b/torch/csrc/autograd/profiler_python.cpp index 78a0c6eeec7ac..7fa892b391283 100644 --- a/torch/csrc/autograd/profiler_python.cpp +++ b/torch/csrc/autograd/profiler_python.cpp @@ -674,14 +674,18 @@ struct ThreadLocalResults { CallTypeHelper::tuple_type trace_keys_; AppendOnlyList exit_times_; AppendOnlyList c_exit_times_; +<<<<<<< HEAD int active_frames_{0}; int remaining_start_frames_{0}; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; // ============================================================================ // == Tracing implementation ================================================== // ============================================================================ +<<<<<<< HEAD #define IS_PYTHON_3_12 (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION == 12) #if IS_PYTHON_3_12 // forward declarations @@ -693,6 +697,8 @@ static PyObject* c_call_callback( PyObject* kwnames); #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class PythonTracer final : public python_tracer::PythonTracerBase { public: PythonTracer(torch::profiler::impl::RecordQueue* queue); @@ -704,7 +710,11 @@ class PythonTracer final : public python_tracer::PythonTracerBase { PyFrameObject* frame, int what, PyObject* arg); +<<<<<<< HEAD void register_gc_callback() override; +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void stop() override; void restart() override; std::vector> getEvents( @@ -723,8 +733,11 @@ class PythonTracer final : public python_tracer::PythonTracerBase { PyFrameObject* frame, bool is_startup_frame); +<<<<<<< HEAD static PyObject* gc_event_callback(PyObject* self, PyObject* args); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void recordCCall( ThreadLocalResults& tls, PyFrameObject* frame, @@ -733,9 +746,16 @@ class PythonTracer final : public python_tracer::PythonTracerBase { const std::vector interpreterThreads() const; +<<<<<<< HEAD std::atomic active_lock_{false}; bool active_{false}; bool gc_callback_registered_{false}; +======= + PyObject* get_callable_from_frame(PyFrameObject* frame); + + std::atomic active_lock_{false}; + bool active_{false}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch::profiler::impl::RecordQueue* queue_; PyInterpreterState* interpreter_{nullptr}; @@ -745,6 +765,7 @@ class PythonTracer final : public python_tracer::PythonTracerBase { std::vector start_frames_; std::deque thread_local_results_; ValueCache value_cache_; +<<<<<<< HEAD #if IS_PYTHON_3_12 friend PyObject* c_call_callback( @@ -963,6 +984,10 @@ static void unregisterMonitoringCallback() { } #endif +======= +}; + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const std::vector PythonTracer::interpreterThreads() const { pybind11::gil_scoped_acquire gil; std::vector out; @@ -976,6 +1001,7 @@ const std::vector PythonTracer::interpreterThreads() const { return out; } +<<<<<<< HEAD // we are only registering on main thread while holding GIL so this should be // safe static PyObject* py_gc_callback = nullptr; @@ -997,6 +1023,8 @@ PyObject* PythonTracer::gc_event_callback(PyObject* self, PyObject* args) { Py_RETURN_NONE; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue) : queue_(queue), @@ -1026,8 +1054,12 @@ PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue) PyThreadState_Swap(thread_state); thread_local_results_.emplace_back(thread_state, &value_cache_, this); +<<<<<<< HEAD auto& tls = thread_local_results_.back(); auto* ctx = tls.ctx_; +======= + auto* ctx = thread_local_results_.back().ctx_; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // When we begin profiling there are already frames on the Python // interpreter stack. To ensure a complete trace, we must push calls @@ -1049,7 +1081,21 @@ PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue) } for (auto it = current_stack.rbegin(); it != current_stack.rend(); it++) { +<<<<<<< HEAD recordPyCall(tls, it->get(), true); +======= + recordPyCall(thread_local_results_.back(), it->get(), true); + PyFrameObject* frame = it->get(); + PyObject* callable = get_callable_from_frame(frame); + if (callable) { + // If the frame has a callable, record it as a C call since + // PyEval_GetFrame only gets the python frame. We need to record this C + // call so that when exiting the profiler we don't have a mismatched C + // call. + recordCCall(thread_local_results_.back(), it->get(), callable, true); + } + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto frame_refcount = Py_REFCNT(it->get()); // We hold one reference in `current_stack`, and the interpreter holds @@ -1057,13 +1103,17 @@ PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue) TORCH_INTERNAL_ASSERT(frame_refcount >= 2, frame_refcount); } +<<<<<<< HEAD tls.remaining_start_frames_ = tls.active_frames_; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Note: // This profile will not compose with other CPython profilers, and // cannot be round tripped via `sys.settrace(sys.gettrace())` PyEval_SetProfile(PythonTracer::pyProfileFn, (PyObject*)ctx); } +<<<<<<< HEAD #if IS_PYTHON_3_12 registerMonitoringCallback(); #endif @@ -1129,14 +1179,19 @@ void PythonTracer::register_gc_callback() { Py_DECREF(callbacks); Py_DECREF(gc_module); PyGILState_Release(gstate); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } void PythonTracer::stop() { gil_and_restore_thread gil; +<<<<<<< HEAD if (gc_callback_registered_) { unregister_gc_callback(); gc_callback_registered_ = false; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (active_) { for (const auto thread_state : interpreterThreads()) { if (thread_state->c_profilefunc == &PythonTracer::pyProfileFn) { @@ -1145,10 +1200,13 @@ void PythonTracer::stop() { } } +<<<<<<< HEAD #if IS_PYTHON_3_12 unregisterMonitoringCallback(); #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto lock_returned = active_lock_.compare_exchange_strong(active_, false); active_ = false; SOFT_ASSERT(lock_returned, "Failed to return python tracer lock."); @@ -1172,9 +1230,12 @@ void PythonTracer::restart() { PyEval_SetProfile(PythonTracer::pyProfileFn, (PyObject*)ctx); } } +<<<<<<< HEAD #if IS_PYTHON_3_12 registerMonitoringCallback(); #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // NOLINTNEXTLINE(bugprone-exception-escape) @@ -1237,7 +1298,10 @@ void PythonTracer::recordPyCall( const auto time = c10::getApproximateTime(); is_startup_frame ? start_frames_.push_back({key, time}) : queue_->getSubqueue()->emplace_py_call(key, time); +<<<<<<< HEAD ++tls.active_frames_; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } void PythonTracer::recordCCall( @@ -1257,7 +1321,30 @@ void PythonTracer::recordCCall( auto key = tls.intern( arg, (void*)(fn->m_ml), frame); queue_->getSubqueue()->emplace_py_call(key, c10::getApproximateTime()); +<<<<<<< HEAD ++tls.active_frames_; +======= +} + +PyObject* PythonTracer::get_callable_from_frame(PyFrameObject* frame) { + if (frame == nullptr) { + return nullptr; + } + // Get the code object associated with the frame + auto code = THPCodeObjectPtr(PyFrame_GetCode(frame)); + if (code == nullptr) { + return nullptr; + } + // Get the function name (if needed) + auto name = THPUtils_unpackStringView(code->co_name).data(); + // To get the function object, you will need to look in the globals or the + // frame's f_globals + PyObject* func = PyDict_GetItemString(PyFrame_GetGlobals(frame), name); + if (func) { + Py_INCREF(func); // Make sure the returned function has a reference + } + return func; // Returns a PyObject* (the function) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // ============================================================================ @@ -1361,7 +1448,13 @@ class PostProcess { state.exits_.top().t_ < enter.enter_t_) { auto& exit = state.exits_.top(); auto& tstack = stacks[exit.python_tid_]; +<<<<<<< HEAD pop(tstack, exit.t_); +======= + if (!tstack.empty()) { + pop(tstack, exit.t_); + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) state.exits_.pop(); } out.push_back(Result::create( @@ -1553,6 +1646,7 @@ int PythonTracer::pyProfileFn( local_results.active_tracer_->recordCCall(local_results, frame, arg); break; +<<<<<<< HEAD case PyTrace_RETURN: local_results.exit_times_.emplace_back(c10::getApproximateTime()); local_results.active_frames_--; @@ -1560,15 +1654,24 @@ int PythonTracer::pyProfileFn( local_results.remaining_start_frames_) { local_results.remaining_start_frames_ = local_results.active_frames_; } +======= + case PyTrace_EXCEPTION: + case PyTrace_RETURN: + local_results.exit_times_.emplace_back(c10::getApproximateTime()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) break; case PyTrace_C_EXCEPTION: case PyTrace_C_RETURN: +<<<<<<< HEAD if (local_results.active_frames_ > local_results.remaining_start_frames_) { local_results.c_exit_times_.emplace_back(c10::getApproximateTime()); local_results.active_frames_--; } +======= + local_results.c_exit_times_.emplace_back(c10::getApproximateTime()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) break; } return 0; diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp index 14591bc1fb4a1..85e227db2b6e1 100644 --- a/torch/csrc/autograd/python_function.cpp +++ b/torch/csrc/autograd/python_function.cpp @@ -793,6 +793,7 @@ static void _get_tensors_to_save( if (is_executable) { // TODO: We should really just ALWAYS throw an error here, but // doing so will break some internal tests. We should fix those. +<<<<<<< HEAD TORCH_CHECK_TYPE( false, fmt::format( @@ -804,6 +805,16 @@ static void _get_tensors_to_save( } } Py_CLEAR(self->to_save); +======= + throw torch::TypeError( + "save_for_backward can only save variables, but argument %ld is of " + "type %s", + i, + Py_TYPE(obj)->tp_name); + } + } + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } // Save any variables that requested by to_save @@ -811,7 +822,11 @@ static void _save_variables( const std::vector>& tensors_to_save, const std::shared_ptr& cdata_ptr, THPFunction* self) { +<<<<<<< HEAD if (tensors_to_save.size() == 0) +======= + if (!self->to_save) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return; size_t num_saved = tensors_to_save.size(); self->saved_variables.clear(); @@ -824,6 +839,11 @@ static void _save_variables( self->saved_variables.emplace_back(opt_tensor.value(), is_output); } } +<<<<<<< HEAD +======= + // Free .to_save + Py_CLEAR(self->to_save); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // Mark requires_grad = 0 on non-differentiable variables (as per @@ -1053,8 +1073,12 @@ void _trace_post_record( } } } +<<<<<<< HEAD py::object onnx_globals = py::module::import("torch.onnx._internal.torchscript_exporter._globals"); +======= + py::object onnx_globals = py::module::import("torch.onnx._globals"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py::bool_ is_in_onnx_export = py::module::import("torch.onnx.__init__").attr("is_in_onnx_export"); py::bool_ is_autograd_inlining_enabled = diff --git a/torch/csrc/autograd/python_legacy_variable.cpp b/torch/csrc/autograd/python_legacy_variable.cpp index ee00008c94bb9..1b60ccf7155fb 100644 --- a/torch/csrc/autograd/python_legacy_variable.cpp +++ b/torch/csrc/autograd/python_legacy_variable.cpp @@ -1,7 +1,10 @@ #include #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include @@ -58,9 +61,14 @@ static PyObject* THPVariable_pynew( !is_volatile || !requires_grad, "Variable can't be volatile and require_grad at the same time!"); if (grad_fn && !THPFunction_Check(grad_fn)) { +<<<<<<< HEAD TORCH_CHECK_TYPE( false, "_grad_fn has to be a Function object or None, but got ", +======= + throw TypeError( + "_grad_fn has to be a Function object or None, but got %s", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Py_TYPE(grad_fn)->tp_name); } Variable var; @@ -76,10 +84,15 @@ static PyObject* THPVariable_pynew( } else if (THPVariable_Check(data)) { var = THPVariable_Unpack(data).detach(); } else { +<<<<<<< HEAD TORCH_CHECK_TYPE( false, "Variable data has to be a tensor, but got ", Py_TYPE(data)->tp_name); +======= + throw torch::TypeError( + "Variable data has to be a tensor, but got %s", Py_TYPE(data)->tp_name); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // We set `tensor`'s `allow_tensor_metadata_change` to true here, because we // want to allow the following use case for backward compatibility: diff --git a/torch/csrc/autograd/python_torch_functions_manual.cpp b/torch/csrc/autograd/python_torch_functions_manual.cpp index 79739b6e459d2..df41a9905948d 100644 --- a/torch/csrc/autograd/python_torch_functions_manual.cpp +++ b/torch/csrc/autograd/python_torch_functions_manual.cpp @@ -622,6 +622,7 @@ void initTorchFunctions(PyObject* module) { return impl->was_inductor_storage_resized(); }); py_module.def( +<<<<<<< HEAD "_functionalize_inductor_storage_resized_counter", [](const at::Tensor& t) { TORCH_INTERNAL_ASSERT( @@ -630,6 +631,8 @@ void initTorchFunctions(PyObject* module) { return impl->inductor_storage_resized_counter(); }); py_module.def( +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "_functionalize_are_all_mutations_hidden_from_autograd", [](const at::Tensor& t) { TORCH_INTERNAL_ASSERT( @@ -644,6 +647,18 @@ void initTorchFunctions(PyObject* module) { at::functionalization::impl::isFunctionalTensor(t)); at::functionalization::impl::mark_mutation_hidden_from_autograd(t); }); +<<<<<<< HEAD +======= + py_module.def( + "_functionalize_apply_view_metas", + [](const at::Tensor& tensor, const at::Tensor& base) { + TORCH_INTERNAL_ASSERT( + at::functionalization::impl::isFunctionalTensor(tensor)); + auto impl = + at::functionalization::impl::unsafeGetFunctionalWrapper(tensor); + return impl->apply_view_metas(base); + }); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py_module.def("_functionalize_is_symbolic", [](const at::Tensor& t) { TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t)); auto impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t); @@ -697,11 +712,14 @@ void initTorchFunctions(PyObject* module) { auto t_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t); return t_impl->has_data_mutation(); }); +<<<<<<< HEAD py_module.def("_functionalize_mutation_counter", [](const at::Tensor& t) { TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t)); auto t_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t); return t_impl->mutation_counter(); }); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py_module.def( "_functionalize_get_storage_size", [](const at::Tensor& t, bool before) { TORCH_INTERNAL_ASSERT( @@ -711,10 +729,17 @@ void initTorchFunctions(PyObject* module) { auto size = wrapper->get_storage_size(/*before=*/before); return size; }); +<<<<<<< HEAD py_module.def("_functionalize_mark_storage_changed", [](const at::Tensor& t) { TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t)); auto wrapper = at::functionalization::impl::unsafeGetFunctionalWrapper(t); wrapper->mark_storage_changed(); +======= + py_module.def("_functionalize_set_storage_changed", [](const at::Tensor& t) { + TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t)); + auto wrapper = at::functionalization::impl::unsafeGetFunctionalWrapper(t); + wrapper->set_storage_changed(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }); py_module.def("_functionalize_was_storage_changed", [](const at::Tensor& t) { TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t)); @@ -722,6 +747,7 @@ void initTorchFunctions(PyObject* module) { return wrapper->was_storage_changed(); }); py_module.def( +<<<<<<< HEAD "_functionalize_storage_changed_counter", [](const at::Tensor& t) { TORCH_INTERNAL_ASSERT( at::functionalization::impl::isFunctionalTensor(t)); @@ -730,6 +756,8 @@ void initTorchFunctions(PyObject* module) { return t_impl->storage_changed_counter(); }); py_module.def( +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "_functionalize_unsafe_set", [](at::Tensor& dst, const at::Tensor& src) { // Forcefully/unsafely dumps src.storage into dst. // This API is technically and not specific to functionalization diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 712719304ad63..0aaa71e942ae2 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -157,7 +157,11 @@ void pushPyOutToStack( const char* msg) { TORCH_CHECK( PyGILState_Check(), "GIL must be held before you call pushPyOutToStack"); +<<<<<<< HEAD const auto& schema_returns = op.schema().returns(); +======= + auto schema_returns = op.schema().returns(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const auto num_returns = schema_returns.size(); if (num_returns == 0) { // Check that we got a None return from Python. Anything else is an error. @@ -209,8 +213,13 @@ PyObject* ParameterClass = nullptr; static PyObject* THPVariable_NewWithVar( PyTypeObject* type, const at::TensorBase& _var, +<<<<<<< HEAD bool allow_preexisting_pyobj = false, std::optional has_torch_dispatch_if_known = std::nullopt); +======= + c10::impl::PyInterpreterStatus status, + bool allow_preexisting_pyobj = false); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // clang-tidy gets confused by static const static const char* VOLATILE_WARNING = @@ -261,12 +270,24 @@ PyObject* THPVariable_Wrap(const at::TensorBase& var) { } if (c10::impl::HermeticPyObjectTLS::get_state()) { +<<<<<<< HEAD return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var); +======= + return THPVariable_NewWithVar( + (PyTypeObject*)THPVariableClass, + var, + c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } std::optional mb_obj = var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( +<<<<<<< HEAD /*ignore_hermetic_tls=*/false); +======= + getPyInterpreter(), /*ignore_hermetic_tls=*/false); + c10::impl::PyInterpreterStatus status{}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (mb_obj.has_value()) { auto obj = *mb_obj; if (obj) { @@ -291,6 +312,7 @@ PyObject* THPVariable_Wrap(const at::TensorBase& var) { // (https://github.com/pytorch/pytorch/pull/56017). Prior to this PR // being a thing, the PyObject field will get cleared when all references // to the Python object are removed. +<<<<<<< HEAD } if (C10_LIKELY(var.device().type() != c10::kXLA)) { @@ -302,6 +324,29 @@ PyObject* THPVariable_Wrap(const at::TensorBase& var) { } return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var); +======= + status = c10::impl::PyInterpreterStatus::TAGGED_BY_US; + } else { + // Assumption: if a Tensor has been shared across threads, this induces + // a refcount bump. Therefore, if the use count 1, we are the sole thread + // with access to this tensor and no race is possible. + if (var.use_count() <= 1) { + status = c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED; + } else { + status = c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED; + } + } + + if (C10_LIKELY(var.device().type() != c10::kXLA)) { + return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var, status); + } + + if (auto clazz = getPythonTensorClass(var.device())) { + return THPVariable_NewWithVar((PyTypeObject*)clazz, var, status); + } + + return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var, status); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } static bool isResurrectable(THPVariable* self) { @@ -330,7 +375,12 @@ static bool isResurrectable(THPVariable* self) { } // Check if this is hermetic. If it is, no resurrection. if (tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( +<<<<<<< HEAD /*ignore_hermetic_tls=*/false) != (PyObject*)self) { +======= + getPyInterpreter(), /*ignore_hermetic_tls=*/false) != + (PyObject*)self) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return false; } return true; @@ -356,6 +406,10 @@ static bool THPVariable_tryResurrect(THPVariable* self) { c10::TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl(); auto maybe_pyobj = tensor_impl->pyobj_slot()->check_pyobj( +<<<<<<< HEAD +======= + getPyInterpreter(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) /*ignore_hermetic_tls=*/false); TORCH_INTERNAL_ASSERT( @@ -571,7 +625,14 @@ static PyObject* THPVariable_as_subclass( // stack torch_dispatch_mode::StashTorchDispatchStackGuard td_g; c10::impl::DisablePythonDispatcher dpd_g; +<<<<<<< HEAD return THPVariable_NewWithVar((PyTypeObject*)cls, self.alias()); +======= + return THPVariable_NewWithVar( + (PyTypeObject*)cls, + self.alias(), + c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) END_HANDLE_TH_ERRORS } @@ -623,6 +684,7 @@ static PyObject* THPVariable_make_subclass( data.unsafeGetTensorImpl()->_change_backend_component_keys(r.device(6)); } +<<<<<<< HEAD return THPVariable_NewWithVar((PyTypeObject*)cls, data); END_HANDLE_TH_ERRORS } @@ -686,6 +748,15 @@ static Tensor make_tensor_for_subclass_helper( return tensor; } +======= + return THPVariable_NewWithVar( + (PyTypeObject*)cls, + data, + c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); + END_HANDLE_TH_ERRORS +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static PyObject* THPVariable_make_wrapper_subclass( PyObject*, PyObject* args, @@ -753,6 +824,7 @@ static PyObject* THPVariable_make_wrapper_subclass( // don't bother releasing GIL here, as we are not allocating any nontrivial // data +<<<<<<< HEAD auto sym_sizes = r.symintlist(1); auto sym_strides_own = r.symintlistOptional(2); Tensor tensor = make_tensor_for_subclass_helper( @@ -767,6 +839,71 @@ static PyObject* THPVariable_make_wrapper_subclass( if (sizes_strides_policy.has_value()) { tensor.unsafeGetTensorImpl()->set_python_custom_sizes_strides( parseSizesStridesPolicyArgument(*sizes_strides_policy)); +======= + Tensor tensor; + + { + AutoDispatchBelowADInplaceOrView guard{}; // TODO: Remove. + tracer::impl::NoTracerDispatchMode tracer_guard{}; + + auto sym_sizes = r.symintlist(1); + auto sym_strides_own = r.symintlistOptional(2); + auto sym_strides = + static_cast>(sym_strides_own); + auto sym_storage_offset = r.toSymIntOptional(3); + + c10::SymInt size_bytes; + auto dtype_itemsize = static_cast(options.dtype().itemsize()); + auto storage_size = r.toSymIntOptional(14); + + if (storage_size.has_value()) { + size_bytes = storage_size.value(); + } else if (sym_strides.has_value()) { + size_bytes = at::detail::computeStorageNbytes( + sym_sizes, + sym_strides.value(), + dtype_itemsize, + sym_storage_offset.value_or(0)); + } else { + size_bytes = at::detail::computeStorageNbytesContiguous( + sym_sizes, dtype_itemsize, sym_storage_offset.value_or(0)); + } + + // We use storages **only** to track aliasing of subclasses during tracing. + // The actual data pointers are not valid. + Storage storage{ + Storage::use_byte_size_t{}, + size_bytes, + /*allocator=*/c10::GetAllocator(c10::kMeta), + /*resizable=*/true}; + // TODO: constructor should probably accept data pointer + storage.set_data_ptr_noswap(at::DataPtr{nullptr, r.device(7)}); + + auto keys = c10::DispatchKeySet({options.computeDispatchKey()}); + if (auto mb_extra_keys = r.toDispatchKeySetOptional(13)) { + keys = keys | *mb_extra_keys; + } + tensor = at::detail::make_tensor( + std::move(storage), keys, options.dtype()); + + TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl(); + + if (sym_strides.has_value()) { + tensor_impl->set_sizes_and_strides( + sym_sizes, sym_strides.value(), sym_storage_offset); + } else { + TORCH_CHECK( + !sym_storage_offset.has_value(), + "setting storage offset without stride not supported"); + tensor_impl->generic_set_sizes_contiguous(sym_sizes); + } + + const auto sizes_strides_policy = r.stringViewOptional(10); + if (sizes_strides_policy.has_value()) { + tensor.unsafeGetTensorImpl()->set_python_custom_sizes_strides( + parseSizesStridesPolicyArgument(*sizes_strides_policy)); + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } tensor.set_requires_grad(r.toBool(9)); @@ -781,6 +918,7 @@ static PyObject* THPVariable_make_wrapper_subclass( return THPVariable_NewWithVar( (PyTypeObject*)cls, tensor, +<<<<<<< HEAD // false is the default /*allow_preexisting_pyobj=*/false, // we checked __torch_dispatch__ above; avoid checking again. @@ -848,6 +986,9 @@ static PyObject* THPVariable_make_dtensor( // we know DTensor has __torch_dispatch__ and we double-checked // above; avoid checking again. /*has_torch_dispatch_if_known=*/true); +======= + c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) END_HANDLE_TH_ERRORS } @@ -1741,10 +1882,13 @@ static PyMethodDef extra_methods[] = { castPyCFunctionWithKeywords(THPVariable_make_wrapper_subclass), METH_STATIC | METH_VARARGS | METH_KEYWORDS, nullptr}, +<<<<<<< HEAD {"_make_dtensor", castPyCFunctionWithKeywords(THPVariable_make_dtensor), METH_STATIC | METH_VARARGS | METH_KEYWORDS, nullptr}, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"_fix_weakref", THPVariable_fix_weakref, METH_NOARGS, nullptr}, {"_view_func", castPyCFunctionWithKeywords(THPVariable_view_func), @@ -1879,6 +2023,10 @@ PyObject* THPVariable_pynew( return THPVariable_NewWithVar( type, tensor, +<<<<<<< HEAD +======= + c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) /*allow_preexisting_pyobj=*/true); END_HANDLE_TH_ERRORS } @@ -1931,7 +2079,12 @@ static int THPVariable_subclass_clear(THPVariable* self) { if (!self->cdata.unsafeIsBorrowed() && tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( +<<<<<<< HEAD /*ignore_hermetic_tls=*/false) == (PyObject*)self) { +======= + getPyInterpreter(), /*ignore_hermetic_tls=*/false) == + (PyObject*)self) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // TODO: empirically, on OS X this assert appears to be untrue // In test_py_tensors_multi_async_call - ProcessGroupRpcTestWithSpawn // distributed/rpc/test_process_group_agent.py @@ -2103,6 +2256,7 @@ static void THPVariable_subclass_dealloc(PyObject* self) { Py_DECREF(type); } +<<<<<<< HEAD // Creates a new Python object for a Variable. static PyObject* THPVariable_NewWithVar( PyTypeObject* type, @@ -2112,13 +2266,34 @@ static PyObject* THPVariable_NewWithVar( // Make sure that the reinterpret into a THPVariable* will be valid TORCH_CHECK( type == &THPVariableType || PyType_IsSubtype(type, &THPVariableType), +======= +// Creates a new Python object for a Variable. The status parameter +// specifies what the interpreter tag status on the object is; for +// example, if you ran check_pyobj, the return optional of this object +// tells you if the tensor was already tagged or not so you can pass +// TAGGED_BY_US or MAYBE_UNINITIALIZED; in other cases, you know where +// var came from and can directly assert that it's DEFINITELY_UNINITIALIZED. +// It's ALWAYS safe (albeit slower) to call this with MAYBE_UNINITIALIZED. +static PyObject* THPVariable_NewWithVar( + PyTypeObject* type, + const at::TensorBase& _var, + c10::impl::PyInterpreterStatus status, + bool allow_preexisting_pyobj) { + // Make sure that the reinterpret into a THPVariable* will be valid + TORCH_CHECK( + PyType_IsSubtype(type, &THPVariableType), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "Creating a Tensor subclass from a class ", "that does not inherit from Tensor is not possible. Make sure your class inherits from Tensor."); // This function overwrite the Tensor's pyobj field without extra checks // Make sure it is not set otherwise we would leak memory auto mb_obj = _var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( +<<<<<<< HEAD /*ignore_hermetic_tls=*/false); +======= + getPyInterpreter(), /*ignore_hermetic_tls=*/false); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Under some circumstances, we may attempt to create a new Python // object for a variable that already has a Python object. The most common @@ -2200,10 +2375,16 @@ static PyObject* THPVariable_NewWithVar( // Normal codepath v->cdata = MaybeOwned::owned(Variable(_var)); const auto& var = THPVariable_Unpack(v); +<<<<<<< HEAD var.unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(obj); if (has_torch_dispatch_if_known.has_value() ? *has_torch_dispatch_if_known : check_has_torch_dispatch(obj)) { +======= + var.unsafeGetTensorImpl()->pyobj_slot()->init_pyobj( + getPyInterpreter(), obj, status); + if (check_has_torch_dispatch(obj)) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) var.unsafeGetTensorImpl()->set_python_dispatch(true); } } diff --git a/torch/csrc/autograd/python_variable_indexing.cpp b/torch/csrc/autograd/python_variable_indexing.cpp index e618ee703378f..e568cacdb99a5 100644 --- a/torch/csrc/autograd/python_variable_indexing.cpp +++ b/torch/csrc/autograd/python_variable_indexing.cpp @@ -29,7 +29,10 @@ #include #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) using namespace at; using namespace torch::autograd::utils; @@ -61,6 +64,7 @@ Py_ssize_t THPVariable_length(PyObject* self) { // and tuples of those types. We also handle bools as if they were a // Variable[ByteTensor]. +<<<<<<< HEAD // We only go one deep, because that's all torchdim needs (it supports // a tuple/list of FCDs which triggers a split behavior, but you can // only do it at the top level) and it's all the dispatcher will do @@ -88,6 +92,8 @@ static bool sequence_has_torch_function(PyObject* seq) { return false; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static int64_t count_specified_dimensions(PyObject* index) { // Count the number of indexed dimensions (everything but ellipsis and None) // -1 is a sentinel for __torch_function__ @@ -95,10 +101,15 @@ static int64_t count_specified_dimensions(PyObject* index) { auto size = PyTuple_GET_SIZE(index); for (Py_ssize_t i = 0; i < size; i++) { PyObject* obj = PyTuple_GET_ITEM(index, i); +<<<<<<< HEAD if (check_has_torch_function(obj)) { return -1; } +======= + if (check_has_torch_function(obj)) + return -1; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (THPVariable_Check(obj)) { const auto& var = THPVariable_Unpack(obj); const auto& var_scalar_type = var.scalar_type(); @@ -107,6 +118,7 @@ static int64_t count_specified_dimensions(PyObject* index) { } else { count++; } +<<<<<<< HEAD } else { // Check sequences for __torch_function__ (top-level only) if (PySequence_Check(obj)) { @@ -118,6 +130,12 @@ static int64_t count_specified_dimensions(PyObject* index) { obj != Py_False) { count++; } +======= + } else if ( + obj != Py_None && obj != Py_Ellipsis && obj != Py_True && + obj != Py_False) { + count++; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } return count; @@ -160,12 +178,19 @@ inline Variable valueToTensor( } else if (torch::is_symbool(value)) { scalar = Scalar(py::cast(py::handle(value))); } else { +<<<<<<< HEAD TORCH_CHECK_TYPE( false, "can't assign a ", Py_TYPE(value)->tp_name, " to a ", torch::utils::options_to_string(options)); +======= + throw TypeError( + "can't assign a %s to a %s", + Py_TYPE(value)->tp_name, + torch::utils::options_to_string(options).c_str()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // lift_fresh is supposed to be used in situations where you are guaranteed to // get a plain Tensor which is not true for cpu device but not for non cpu @@ -434,7 +459,11 @@ PyObject* THPVariable_getitem(PyObject* self, PyObject* index) { variable_list variableIndices; int64_t specified_dims = count_specified_dimensions(holder.get()); if (specified_dims == -1) { +<<<<<<< HEAD return handle_torch_function_indexing(self, index); +======= + return handle_torch_function_indexing(self, holder.get()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } Variable sliced = applySlicing( self_, @@ -482,7 +511,11 @@ static void dispatch_set_item( int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) { HANDLE_TH_ERRORS if (py_value == nullptr) { +<<<<<<< HEAD TORCH_CHECK_TYPE(false, "Tensor does not support deleting items"); +======= + throw TypeError("Tensor does not support deleting items"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } if ((check_has_torch_function(self)) || (check_has_torch_function(py_value))) { @@ -495,7 +528,11 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) { if (self_.layout() == kSparse || self_.layout() == kSparseCsr || self_.layout() == kSparseCsc || self_.layout() == kSparseBsr || self_.layout() == kSparseBsc) { +<<<<<<< HEAD TORCH_CHECK_TYPE(false, "Cannot assign to a sparse tensor"); +======= + throw TypeError("Cannot assign to a sparse tensor"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } OptionalDeviceGuard device_guard(device_of(self_)); at::Device self_device = self_.device(); diff --git a/torch/csrc/autograd/saved_variable.cpp b/torch/csrc/autograd/saved_variable.cpp index 0124a0212bc61..d5b1ca35ba247 100644 --- a/torch/csrc/autograd/saved_variable.cpp +++ b/torch/csrc/autograd/saved_variable.cpp @@ -39,11 +39,16 @@ SavedVariable::SavedVariable( // follow. TORCH_CHECK( !variable.is_inference(), +<<<<<<< HEAD "Inference tensors cannot be saved for backward. Please do not use " "Tensors created in inference mode in computation tracked by autograd. " "To work around this, you can make a clone to get a normal tensor and " "use it in autograd, or use `torch.no_grad()` instead of " "`torch.inference_mode()`."); +======= + "Inference tensors cannot be saved for backward. To work around " + "you can make a clone to get a normal tensor and use it in autograd.") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) was_default_constructed_ = false; saved_version_ = variable._version(); diff --git a/torch/csrc/cuda/CUDAPluggableAllocator.cpp b/torch/csrc/cuda/CUDAPluggableAllocator.cpp index 3fbe6f906db4e..01de9f8acb97e 100644 --- a/torch/csrc/cuda/CUDAPluggableAllocator.cpp +++ b/torch/csrc/cuda/CUDAPluggableAllocator.cpp @@ -7,6 +7,26 @@ namespace torch::cuda::CUDAPluggableAllocator { +<<<<<<< HEAD +======= +CUDAPluggableAllocatorDeleterContext::CUDAPluggableAllocatorDeleterContext( + std::function free_fn, + void* data, + size_t size, + int device, + cudaStream_t stream) + : free_fn_(std::move(free_fn)), + data_(data), + size_(size), + device_(device), + stream_(stream) {} + +void CUDAPluggableAllocatorDeleterContext::free() { + free_fn_(data_, size_, device_, stream_); + delete this; +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int device_count = 0; void custom_raw_deleter(void* ptr); @@ -24,8 +44,13 @@ _AllocationMetadata::_AllocationMetadata( // This avoids having to link against libtorch for C++ based custom allocators // And also use this from python CUDAPluggableAllocator::CUDAPluggableAllocator( +<<<<<<< HEAD std::function alloc_fn, std::function free_fn) +======= + std::function alloc_fn, + std::function free_fn) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) : alloc_fn_(std::move(alloc_fn)), free_fn_(std::move(free_fn)) {} CUDAPluggableAllocator::CUDAPluggableAllocator(CUDAPluggableAllocator& other) @@ -97,8 +122,15 @@ c10::DataPtr CUDAPluggableAllocator::allocate(size_t size) { C10_CUDA_CHECK(c10::cuda::GetDevice(&device)); cudaStream_t stream = c10::cuda::getCurrentCUDAStream(device); void* r = this->malloc(size, device, stream); +<<<<<<< HEAD c10::DataPtr data_ptr = { r, r, raw_deleter(), c10::Device(c10::DeviceType::CUDA, device)}; +======= + auto* ctx = new CUDAPluggableAllocatorDeleterContext( + free_fn_, r, size, device, stream); + c10::DataPtr data_ptr = { + r, ctx, raw_deleter(), c10::Device(c10::DeviceType::CUDA, device)}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return data_ptr; } @@ -363,8 +395,13 @@ getCurrentAllocator() { // TODO: add more functions in the argument std::shared_ptr createCustomAllocator( +<<<<<<< HEAD std::function alloc_fn, std::function free_fn) { +======= + std::function alloc_fn, + std::function free_fn) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::shared_ptr allocator( new CUDAPluggableAllocator(std::move(alloc_fn), std::move(free_fn))); allocator->init(device_count); @@ -381,8 +418,13 @@ void changeCurrentAllocator( current_custom_allocator = allocator; } +<<<<<<< HEAD void custom_raw_deleter(void* ptr) { current_custom_allocator->raw_delete(ptr); +======= +void custom_raw_deleter(void* ctx) { + reinterpret_cast(ctx)->free(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } // namespace torch::cuda::CUDAPluggableAllocator diff --git a/torch/csrc/cuda/CUDAPluggableAllocator.h b/torch/csrc/cuda/CUDAPluggableAllocator.h index d4f73117eca61..912ad43a61e02 100644 --- a/torch/csrc/cuda/CUDAPluggableAllocator.h +++ b/torch/csrc/cuda/CUDAPluggableAllocator.h @@ -11,6 +11,35 @@ namespace torch::cuda::CUDAPluggableAllocator { +<<<<<<< HEAD +======= +using MallocFuncType = void*(size_t, int, cudaStream_t); +using FreeFuncType = void(void*, size_t, int, cudaStream_t); + +// A CUDAPluggableAllocatorDeleterContext object is used as the `ctx` +// argument for DataPtr. We need context because a user can use +// multiple allocators in the same PyTorch program, and +// the allocators can have different free functions, such as: +// free, cudaFree, cudaFreeAsync, ncclMemFree etc. +struct TORCH_CUDA_CPP_API CUDAPluggableAllocatorDeleterContext { + explicit CUDAPluggableAllocatorDeleterContext( + std::function free_fn, + void* data, + size_t size, + int device, + cudaStream_t stream); + + void free(); + + private: + std::function free_fn_; + void* data_; + size_t size_; + int device_; + cudaStream_t stream_{}; +}; + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #if defined(USE_ROCM) using streamType = c10::hip::HIPStream; #else @@ -23,8 +52,13 @@ getCurrentAllocator(); TORCH_CUDA_CPP_API std::shared_ptr< c10::cuda::CUDACachingAllocator::CUDAAllocator> createCustomAllocator( +<<<<<<< HEAD std::function alloc_fn, std::function free_fn); +======= + std::function alloc_fn, + std::function free_fn); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_CUDA_CPP_API void changeCurrentAllocator( const std::shared_ptr& allocator); @@ -43,8 +77,13 @@ struct _AllocationMetadata { struct TORCH_CUDA_CPP_API CUDAPluggableAllocator : public c10::cuda::CUDACachingAllocator::CUDAAllocator { CUDAPluggableAllocator( +<<<<<<< HEAD std::function alloc_fn, std::function free_fn); +======= + std::function alloc_fn, + std::function free_fn); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CUDAPluggableAllocator(CUDAPluggableAllocator& other); CUDAPluggableAllocator(CUDAPluggableAllocator&& other) = delete; @@ -147,8 +186,13 @@ struct TORCH_CUDA_CPP_API CUDAPluggableAllocator void copy_data(void* dest, const void* src, std::size_t count) const final; protected: +<<<<<<< HEAD std::function alloc_fn_; std::function free_fn_; +======= + std::function alloc_fn_; + std::function free_fn_; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::function init_fn_; std::function reset_fn_; std::function memory_fraction_fn_; diff --git a/torch/csrc/cuda/Graph.cpp b/torch/csrc/cuda/Graph.cpp index 2a551ae28e96d..aa67948eb5c1f 100644 --- a/torch/csrc/cuda/Graph.cpp +++ b/torch/csrc/cuda/Graph.cpp @@ -101,6 +101,7 @@ void THCPGraph_init(PyObject* module) { // compile error. return reinterpret_cast(graph); }, +<<<<<<< HEAD py::call_guard()) .def( "raw_cuda_graph_exec", @@ -112,5 +113,7 @@ void THCPGraph_init(PyObject* module) { // compile error. return reinterpret_cast(graph_exec); }, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py::call_guard()); } diff --git a/torch/csrc/cuda/MemPool.cpp b/torch/csrc/cuda/MemPool.cpp index b651a4b5e68aa..350ed4854e48f 100644 --- a/torch/csrc/cuda/MemPool.cpp +++ b/torch/csrc/cuda/MemPool.cpp @@ -16,12 +16,24 @@ void THCPMemPool_init(PyObject* module) { .def( py::init([](c10::cuda::CUDACachingAllocator::CUDAAllocator* allocator, bool is_user_created, +<<<<<<< HEAD bool use_on_oom) { torch::utils::device_lazy_init(at::kCUDA); return std::make_shared<::c10::cuda::MemPool>( allocator, is_user_created, use_on_oom); })) .def_property_readonly("id", &::c10::cuda::MemPool::id) +======= + bool use_on_oom, + bool symmetric) { + torch::utils::device_lazy_init(at::kCUDA); + return std::make_shared<::c10::cuda::MemPool>( + allocator, is_user_created, use_on_oom, symmetric); + })) + .def_property_readonly("id", &::c10::cuda::MemPool::id) + .def_property_readonly( + "is_symmetric", &::c10::cuda::MemPool::is_symmetric) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .def_property_readonly("allocator", &::c10::cuda::MemPool::allocator) .def("use_count", &::c10::cuda::MemPool::use_count); } diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index 3a8929110e8b4..58e45382c6fab 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -5,6 +5,10 @@ #include #include #include +<<<<<<< HEAD +======= +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include @@ -907,8 +911,11 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* arg) { py::str release_lock_on_malloc_s = "release_lock_on_cudamalloc"; py::str pinned_use_host_register_s = "pinned_use_cuda_host_register"; py::str roundup_power2_divisions_s = "roundup_power2_divisions"; +<<<<<<< HEAD py::str graph_capture_record_stream_reuse_s = "graph_capture_record_stream_reuse"; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) allocator_settings[last_allocator_settings_s] = snapshot.config_metadata.last_allocator_settings; @@ -924,8 +931,11 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* arg) { snapshot.config_metadata.release_lock_on_malloc; allocator_settings[pinned_use_host_register_s] = snapshot.config_metadata.pinned_use_host_register; +<<<<<<< HEAD allocator_settings[graph_capture_record_stream_reuse_s] = snapshot.config_metadata.graph_capture_record_stream_reuse; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) unsigned int roundup_key = 1; py::dict roundup_settings; for (const auto& v : snapshot.config_metadata.roundup_power2_divisions) { @@ -1020,6 +1030,37 @@ PyObject* THCPModule_cudaGetSyncDebugMode(PyObject* self, PyObject* noargs) { END_HANDLE_TH_ERRORS } +<<<<<<< HEAD +======= +std::string uuid_to_string(const char* uuid_bytes) { + // UUIDs are a 128-bit label. CUDA and HIP store this as char[16]. + // For string representation, the code here expands this to + // 8-4-4-4-12 hex format, so each byte becomes 2 hex characters. + return fmt::format( + "{:02x}{:02x}{:02x}{:02x}-" + "{:02x}{:02x}-" + "{:02x}{:02x}-" + "{:02x}{:02x}-" + "{:02x}{:02x}{:02x}{:02x}{:02x}{:02x}", + (uint8_t)uuid_bytes[0], + (uint8_t)uuid_bytes[1], + (uint8_t)uuid_bytes[2], + (uint8_t)uuid_bytes[3], + (uint8_t)uuid_bytes[4], + (uint8_t)uuid_bytes[5], + (uint8_t)uuid_bytes[6], + (uint8_t)uuid_bytes[7], + (uint8_t)uuid_bytes[8], + (uint8_t)uuid_bytes[9], + (uint8_t)uuid_bytes[10], + (uint8_t)uuid_bytes[11], + (uint8_t)uuid_bytes[12], + (uint8_t)uuid_bytes[13], + (uint8_t)uuid_bytes[14], + (uint8_t)uuid_bytes[15]); +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) //////////////////////////////////////////////////////////////////////////////// // Cuda module initialization //////////////////////////////////////////////////////////////////////////////// @@ -1274,16 +1315,24 @@ static void registerCudaPluggableAllocator(PyObject* module) { self.set_release_pool(func); }); m.def("_cuda_customAllocator", [](uint64_t malloc_ptr, uint64_t free_ptr) { +<<<<<<< HEAD using MallocFuncType = void*(size_t, int, cudaStream_t); using FreeFuncType = void(void*, size_t, int, cudaStream_t); +======= + using namespace torch::cuda::CUDAPluggableAllocator; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::function malloc_fn = // NOLINTNEXTLINE(performance-no-int-to-ptr) reinterpret_cast(malloc_ptr); std::function free_fn = // NOLINTNEXTLINE(performance-no-int-to-ptr) reinterpret_cast(free_ptr); +<<<<<<< HEAD return torch::cuda::CUDAPluggableAllocator::createCustomAllocator( malloc_fn, free_fn); +======= + return createCustomAllocator(malloc_fn, free_fn); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }); // NOLINTNEXTLINE(bugprone-unused-raii) diff --git a/torch/csrc/cuda/memory_snapshot.cpp b/torch/csrc/cuda/memory_snapshot.cpp index 3c96d5c5908dd..2d36713f6e304 100644 --- a/torch/csrc/cuda/memory_snapshot.cpp +++ b/torch/csrc/cuda/memory_snapshot.cpp @@ -458,8 +458,11 @@ std::string _memory_snapshot_pickled() { IValue release_lock_on_malloc_s = "release_lock_on_cudamalloc"; IValue pinned_use_host_register_s = "pinned_use_cuda_host_register"; IValue roundup_power2_divisions_s = "roundup_power2_divisions"; +<<<<<<< HEAD IValue graph_capture_record_stream_reuse_s = "graph_capture_record_stream_reuse"; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) allocator_settings.insert( last_allocator_settings_s, @@ -480,9 +483,12 @@ std::string _memory_snapshot_pickled() { allocator_settings.insert( pinned_use_host_register_s, snapshot.config_metadata.pinned_use_host_register); +<<<<<<< HEAD allocator_settings.insert( graph_capture_record_stream_reuse_s, snapshot.config_metadata.graph_capture_record_stream_reuse); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) unsigned int roundup_key = 1; auto roundup_settings = new_dict(); for (const auto& v : snapshot.config_metadata.roundup_power2_divisions) { diff --git a/torch/csrc/cuda/shared/nvtx.cpp b/torch/csrc/cuda/shared/nvtx.cpp index d28e8ae222eaa..bca230c04d74a 100644 --- a/torch/csrc/cuda/shared/nvtx.cpp +++ b/torch/csrc/cuda/shared/nvtx.cpp @@ -3,11 +3,15 @@ #endif #ifndef ROCM_ON_WINDOWS +<<<<<<< HEAD #ifdef TORCH_CUDA_USE_NVTX3 #include #else // TORCH_CUDA_USE_NVTX3 #include #endif // TORCH_CUDA_USE_NVTX3 +======= +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #else // ROCM_ON_WINDOWS #include #endif // ROCM_ON_WINDOWS @@ -54,11 +58,15 @@ static void* device_nvtxRangeStart(const char* msg, std::intptr_t stream) { void initNvtxBindings(PyObject* module) { auto m = py::handle(module).cast(); +<<<<<<< HEAD #ifdef TORCH_CUDA_USE_NVTX3 auto nvtx = m.def_submodule("_nvtx", "nvtx3 bindings"); #else auto nvtx = m.def_submodule("_nvtx", "libNvToolsExt.so bindings"); #endif +======= + auto nvtx = m.def_submodule("_nvtx", "nvtx3 bindings"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) nvtx.def("rangePushA", nvtxRangePushA); nvtx.def("rangePop", nvtxRangePop); nvtx.def("rangeStartA", nvtxRangeStartA); diff --git a/torch/csrc/deploy/README.md b/torch/csrc/deploy/README.md new file mode 100644 index 0000000000000..2d40ca8361ff4 --- /dev/null +++ b/torch/csrc/deploy/README.md @@ -0,0 +1,2 @@ +# torch::deploy has been moved to pytorch/multipy +Please check out [https://github.com/pytorch/multipy](https://github.com/pytorch/multipy) to find the new home for torch::deploy. diff --git a/torch/csrc/distributed/c10d/Backend.hpp b/torch/csrc/distributed/c10d/Backend.hpp index 655e0a5578c29..504d6060dd330 100644 --- a/torch/csrc/distributed/c10d/Backend.hpp +++ b/torch/csrc/distributed/c10d/Backend.hpp @@ -46,8 +46,11 @@ class TORCH_API Backend : public torch::CustomClassHolder { // backend name // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) const std::string backend; +<<<<<<< HEAD std::string group_name; std::vector global_ranks_in_group; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; explicit Backend(int rank, int size); @@ -79,6 +82,7 @@ class TORCH_API Backend : public torch::CustomClassHolder { return false; } +<<<<<<< HEAD virtual void setTimeout(std::chrono::milliseconds timeout) { TORCH_CHECK( false, @@ -86,6 +90,8 @@ class TORCH_API Backend : public torch::CustomClassHolder { "Backend ", getBackendName(), " does not support setting timeout")); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) virtual void startCoalescing() { TORCH_CHECK( false, @@ -107,6 +113,7 @@ class TORCH_API Backend : public torch::CustomClassHolder { TORCH_INTERNAL_ASSERT(false, "getBackendName is not implemented."); } +<<<<<<< HEAD // Subclasses must override this method to return the backend name virtual c10::intrusive_ptr getBackendOptions() { TORCH_CHECK( @@ -117,6 +124,8 @@ class TORCH_API Backend : public torch::CustomClassHolder { " does not implement getBackendOptions.")); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) virtual c10::intrusive_ptr broadcast( std::vector& /* tensors */, const BroadcastOptions& /* opts */ = BroadcastOptions()) { @@ -391,6 +400,7 @@ class TORCH_API Backend : public torch::CustomClassHolder { " is missing implementation of enableCollectivesTiming."); } +<<<<<<< HEAD virtual c10::intrusive_ptr split( const c10::intrusive_ptr& store, const std::vector& ranks, @@ -414,6 +424,8 @@ class TORCH_API Backend : public torch::CustomClassHolder { " is missing implementation of merge."); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool hasHooks() const { return onCompletionHook_ != nullptr; } diff --git a/torch/csrc/distributed/c10d/FakeProcessGroup.hpp b/torch/csrc/distributed/c10d/FakeProcessGroup.hpp index dc3c4889057c8..1718f0afc44de 100644 --- a/torch/csrc/distributed/c10d/FakeProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/FakeProcessGroup.hpp @@ -20,6 +20,7 @@ class FakeWork : public Work { class FakeProcessGroup : public Backend { public: +<<<<<<< HEAD struct Options : Backend::Options { explicit Options() : Backend::Options("fake") {} @@ -39,6 +40,9 @@ class FakeProcessGroup : public Backend { c10::intrusive_ptr getBackendOptions() override { return c10::static_intrusive_pointer_cast(options_); } +======= + FakeProcessGroup(int rank, int size) : Backend(rank, size) {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) c10::intrusive_ptr broadcast( std::vector& /* tensors */, @@ -212,9 +216,12 @@ class FakeProcessGroup : public Backend { const BarrierOptions& /* opts */ = BarrierOptions()) override { return c10::make_intrusive(); } +<<<<<<< HEAD private: c10::intrusive_ptr options_; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; } // namespace c10d diff --git a/torch/csrc/distributed/c10d/FileStore.cpp b/torch/csrc/distributed/c10d/FileStore.cpp index 7b0fc862e680d..a7a4a50069fc5 100644 --- a/torch/csrc/distributed/c10d/FileStore.cpp +++ b/torch/csrc/distributed/c10d/FileStore.cpp @@ -33,11 +33,15 @@ #define LOCK_SH 0x00000010 #define LOCK_UN 0x00000100 +<<<<<<< HEAD #if defined(_WIN32) && defined(USE_ROCM) static #endif int flock_(int fd, int op) { +======= +int flock_(int fd, int op) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) HANDLE hdl = (HANDLE)_get_osfhandle(fd); DWORD low = 1, high = 0; OVERLAPPED offset = {0, 0, 0, 0, NULL}; diff --git a/torch/csrc/distributed/c10d/FlightRecorder.cpp b/torch/csrc/distributed/c10d/FlightRecorder.cpp index 2384448a06e75..102db126c6a5c 100644 --- a/torch/csrc/distributed/c10d/FlightRecorder.cpp +++ b/torch/csrc/distributed/c10d/FlightRecorder.cpp @@ -1,5 +1,8 @@ #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace c10d { @@ -40,7 +43,11 @@ DebugInfoWriter& DebugInfoWriter::getWriter(int rank) { auto cacheDirPath = std::filesystem::path(homeDir + "/.cache/torch"); // Create the .cache directory if it doesn't exist std::filesystem::create_directories(cacheDirPath); +<<<<<<< HEAD auto defaultLocation = cacheDirPath / "comm_lib_trace_rank_"; +======= + auto defaultLocation = cacheDirPath / "nccl_trace_rank_"; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // For internal bc compatibility, we keep the old the ENV check. std::string fileNamePrefix = getCvarString( diff --git a/torch/csrc/distributed/c10d/FlightRecorder.hpp b/torch/csrc/distributed/c10d/FlightRecorder.hpp index b0974495a87a9..9295d1c8bbe6a 100644 --- a/torch/csrc/distributed/c10d/FlightRecorder.hpp +++ b/torch/csrc/distributed/c10d/FlightRecorder.hpp @@ -20,10 +20,17 @@ namespace c10d { // (minor when adding fields, major when changing existing fields) // Also update both JSON and Pickle dumps to make use of the newly defined // field(s). +<<<<<<< HEAD DEFINE_CONSTANT(version_val, "2.10") DEFINE_CONSTANT(entries_key, "entries") DEFINE_CONSTANT(nccl_comm_key, "nccl_comm_state") DEFINE_CONSTANT(comm_lib_version_key, "comm_lib_version") +======= +DEFINE_CONSTANT(version_val, "2.9") +DEFINE_CONSTANT(entries_key, "entries") +DEFINE_CONSTANT(nccl_comm_key, "nccl_comm_state") +DEFINE_CONSTANT(nccl_version_key, "nccl_version") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DEFINE_CONSTANT(version_key, "version") DEFINE_CONSTANT(pg_config_key, "pg_config") DEFINE_CONSTANT(pg_status_key, "pg_status") @@ -179,7 +186,11 @@ struct FlightRecorder { std::map> all_pg_status_ = {}; std::map, std::vector> pg_name_to_ranks_ = {}; +<<<<<<< HEAD std::string comm_lib_version_; +======= + std::string nccl_version_; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::optional record( size_t pg_id, @@ -200,7 +211,11 @@ struct FlightRecorder { const std::tuple& pg_name, std::vector ranks); +<<<<<<< HEAD void record_accelerator_version(const std::string comm_lib_version); +======= + void record_accelerator_version(const std::string nccl_version); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void update_state(Entry& r); diff --git a/torch/csrc/distributed/c10d/FlightRecorderDetail.hpp b/torch/csrc/distributed/c10d/FlightRecorderDetail.hpp index 473372fd44b4c..18ad7d52bcab7 100644 --- a/torch/csrc/distributed/c10d/FlightRecorderDetail.hpp +++ b/torch/csrc/distributed/c10d/FlightRecorderDetail.hpp @@ -128,12 +128,20 @@ void FlightRecorder::record_pg_ranks( template void FlightRecorder::record_accelerator_version( +<<<<<<< HEAD const std::string comm_lib_version) { +======= + const std::string nccl_version) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (!enabled_) { return; } std::lock_guard guard(mutex_); +<<<<<<< HEAD comm_lib_version_ = std::move(comm_lib_version); +======= + nccl_version_ = std::move(nccl_version); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } template @@ -425,7 +433,11 @@ std::string FlightRecorder::dump_json( bool onlyActive) { json result; result[version_key_str] = version_val_str; +<<<<<<< HEAD result[comm_lib_version_key_str] = comm_lib_version_; +======= + result[nccl_version_key_str] = nccl_version_; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) result[pg_config_key_str] = getPgConfigJson(); result[pg_status_key_str] = getPgStatusJson(); @@ -522,7 +534,11 @@ std::string FlightRecorder::dump( // common values result.insert(version_key, version_val); result.insert(pg_config_key, getPgConfig()); +<<<<<<< HEAD result.insert(comm_lib_version_key_str, comm_lib_version_); +======= + result.insert(nccl_version_key_str, nccl_version_); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) result.insert(pg_status_key, getPgStatus()); // collective trace diff --git a/torch/csrc/distributed/c10d/Functional.cpp b/torch/csrc/distributed/c10d/Functional.cpp index 99b0c7d17bf7e..5997d6b6aef7f 100644 --- a/torch/csrc/distributed/c10d/Functional.cpp +++ b/torch/csrc/distributed/c10d/Functional.cpp @@ -30,6 +30,7 @@ c10d::ReduceOp to_reduce_op(const std::string& reduce_op) { return it->second; } +<<<<<<< HEAD at::Tensor allocate_all_gather_output( const at::Tensor& input, int64_t group_size) { @@ -65,6 +66,8 @@ at::Tensor allocate_reduce_scatter_output( namespace c10d { +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) at::Tensor& all_reduce_( at::Tensor& input, // NOLINTNEXTLINE(performance-unnecessary-value-param) @@ -85,6 +88,7 @@ at::Tensor all_reduce( const at::Tensor& input, std::string reduce_op, std::string group_name) { +<<<<<<< HEAD if (input.is_complex()) { TORCH_CHECK( // TODO - ideally use 'to_reduce_op' helper but it currently errors on @@ -100,6 +104,10 @@ at::Tensor all_reduce( auto output_ret = all_reduce_(output, std::move(reduce_op), std::move(group_name)); return input.is_complex() ? at::view_as_complex(output_ret) : output_ret; +======= + auto output = input.clone(at::MemoryFormat::Contiguous); + return all_reduce_(output, std::move(reduce_op), std::move(group_name)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } std::vector all_reduce_coalesced_( @@ -133,6 +141,20 @@ std::vector all_reduce_coalesced( outputs, std::move(reduce_op), std::move(group_name)); } +<<<<<<< HEAD +======= +at::Tensor allocate_all_gather_output( + const at::Tensor& input, + int64_t group_size) { + TORCH_CHECK(input.is_contiguous()); + auto output_size = input.sizes().vec(); + output_size[0] *= group_size; + return at::empty( + output_size, + at::TensorOptions().dtype(input.dtype()).device(input.device())); +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::vector all_gather_into_tensor_coalesced( std::vector inputs, int64_t group_size, @@ -158,11 +180,17 @@ at::Tensor all_gather_into_tensor( int64_t group_size, std::string group_name) { TORCH_CHECK(input.is_contiguous()); +<<<<<<< HEAD auto real_input = input.is_complex() ? at::view_as_real(input) : input; std::vector inputs{real_input}; auto output = all_gather_into_tensor_coalesced( inputs, group_size, std::move(group_name))[0]; return input.is_complex() ? at::view_as_complex(output) : output; +======= + std::vector inputs{input}; + return all_gather_into_tensor_coalesced( + inputs, group_size, std::move(group_name))[0]; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } at::Tensor& all_gather_into_tensor_out( @@ -179,6 +207,25 @@ at::Tensor& all_gather_into_tensor_out( return output; } +<<<<<<< HEAD +======= +at::Tensor allocate_reduce_scatter_output( + const at::Tensor& input, + const int64_t group_size) { + TORCH_CHECK(input.is_contiguous()); + auto output_size = input.sizes().vec(); + if (output_size[0] % group_size != 0) { + LOG(WARNING) << "The first dimension of the reduce_scatter input (" + << output_size[0] << ") is not divisible by the group size (" + << group_size << ")."; + } + output_size[0] /= group_size; + return at::empty( + output_size, + at::TensorOptions().dtype(input.dtype()).device(input.device())); +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::vector reduce_scatter_tensor_coalesced( std::vector inputs, // NOLINTNEXTLINE(performance-unnecessary-value-param) @@ -209,12 +256,15 @@ at::Tensor reduce_scatter_tensor( int64_t group_size, std::string group_name) { TORCH_CHECK(input.is_contiguous()); +<<<<<<< HEAD if (input.is_complex()) { auto real_input = at::view_as_real(input); std::vector inputs{real_input}; return at::view_as_complex(reduce_scatter_tensor_coalesced( inputs, std::move(reduce_op), group_size, std::move(group_name))[0]); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::vector inputs{input}; return reduce_scatter_tensor_coalesced( inputs, std::move(reduce_op), group_size, std::move(group_name))[0]; @@ -222,6 +272,7 @@ at::Tensor reduce_scatter_tensor( at::Tensor all_to_all_single( const at::Tensor& input, +<<<<<<< HEAD c10::SymIntArrayRef _output_split_sizes, c10::SymIntArrayRef _input_split_sizes, // NOLINTNEXTLINE(performance-unnecessary-value-param) @@ -237,6 +288,12 @@ at::Tensor all_to_all_single( input_split_sizes.emplace_back(size.expect_int()); } +======= + std::vector output_split_sizes, + std::vector input_split_sizes, + // NOLINTNEXTLINE(performance-unnecessary-value-param) + std::string group_name) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_CHECK(input.is_contiguous()); std::vector output_sizes = input.sizes().vec(); output_sizes[0] = std::accumulate( @@ -258,8 +315,12 @@ at::Tensor all_to_all_single( at::Tensor& broadcast_(at::Tensor& input, int64_t src, std::string group_name) { c10d::BroadcastOptions opts; opts.rootRank = src; +<<<<<<< HEAD auto input_real = input.is_complex() ? at::view_as_real(input) : input; std::vector inputs{input_real}; +======= + std::vector inputs{input}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto group = c10d::resolve_process_group(group_name); auto work = group->broadcast(inputs, opts); @@ -275,68 +336,108 @@ at::Tensor broadcast( return broadcast_(output, src, std::move(group_name)); } +<<<<<<< HEAD } // namespace c10d +======= +} // namespace +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_LIBRARY(_c10d_functional, m) { m.def( "all_reduce(Tensor input, str reduce_op, str group_name) -> Tensor", torch::dispatch( +<<<<<<< HEAD c10::DispatchKey::CompositeExplicitAutograd, c10d::all_reduce), +======= + c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {at::Tag::pt2_compliant_tag}); m.def( "all_reduce_(Tensor(a!) input, str reduce_op, str group_name) -> Tensor(a!)", torch::dispatch( +<<<<<<< HEAD c10::DispatchKey::CompositeExplicitAutograd, c10d::all_reduce_), +======= + c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce_), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {at::Tag::pt2_compliant_tag}); m.def( "all_reduce_coalesced(Tensor[] inputs, str reduce_op, str group_name) -> Tensor[]", torch::dispatch( +<<<<<<< HEAD c10::DispatchKey::CompositeExplicitAutograd, c10d::all_reduce_coalesced), +======= + c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce_coalesced), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {at::Tag::pt2_compliant_tag}); m.def( "all_reduce_coalesced_(Tensor[](a!) inputs, str reduce_op, str group_name) -> Tensor[](a!)", torch::dispatch( +<<<<<<< HEAD c10::DispatchKey::CompositeExplicitAutograd, c10d::all_reduce_coalesced_), +======= + c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce_coalesced_), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {at::Tag::pt2_compliant_tag}); m.def( "all_gather_into_tensor_out(Tensor input, int group_size, str group_name, *, Tensor(a!) out) -> Tensor(a!)", torch::dispatch( c10::DispatchKey::CompositeExplicitAutograd, +<<<<<<< HEAD c10d::all_gather_into_tensor_out), +======= + ::all_gather_into_tensor_out), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {at::Tag::pt2_compliant_tag, at::Tag::needs_contiguous_strides}); m.def( "all_gather_into_tensor(Tensor input, int group_size, str group_name) -> Tensor", torch::dispatch( c10::DispatchKey::CompositeExplicitAutograd, +<<<<<<< HEAD c10d::all_gather_into_tensor), +======= + ::all_gather_into_tensor), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {at::Tag::pt2_compliant_tag, at::Tag::needs_contiguous_strides}); m.def( "all_gather_into_tensor_coalesced(Tensor[] inputs, int group_size, str group_name) -> Tensor[]", torch::dispatch( c10::DispatchKey::CompositeExplicitAutograd, +<<<<<<< HEAD c10d::all_gather_into_tensor_coalesced), +======= + ::all_gather_into_tensor_coalesced), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {at::Tag::pt2_compliant_tag, at::Tag::needs_contiguous_strides}); m.def( "reduce_scatter_tensor(Tensor input, str reduce_op, int group_size, str group_name) -> Tensor", torch::dispatch( +<<<<<<< HEAD c10::DispatchKey::CompositeExplicitAutograd, c10d::reduce_scatter_tensor), +======= + c10::DispatchKey::CompositeExplicitAutograd, ::reduce_scatter_tensor), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {at::Tag::pt2_compliant_tag, at::Tag::needs_contiguous_strides}); m.def( "reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduce_op, int group_size, str group_name) -> Tensor[]", torch::dispatch( c10::DispatchKey::CompositeExplicitAutograd, +<<<<<<< HEAD c10d::reduce_scatter_tensor_coalesced), +======= + ::reduce_scatter_tensor_coalesced), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {at::Tag::pt2_compliant_tag, at::Tag::needs_contiguous_strides}); m.def( @@ -346,19 +447,31 @@ TORCH_LIBRARY(_c10d_functional, m) { "SymInt[] input_split_sizes, " "str group_name) -> Tensor", torch::dispatch( +<<<<<<< HEAD c10::DispatchKey::CompositeExplicitAutograd, c10d::all_to_all_single), +======= + c10::DispatchKey::CompositeExplicitAutograd, ::all_to_all_single), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {at::Tag::pt2_compliant_tag, at::Tag::needs_contiguous_strides}); m.def( "broadcast(Tensor input, int src, str group_name) -> Tensor", +<<<<<<< HEAD torch::dispatch( c10::DispatchKey::CompositeExplicitAutograd, c10d::broadcast), +======= + torch::dispatch(c10::DispatchKey::CompositeExplicitAutograd, ::broadcast), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {at::Tag::pt2_compliant_tag}); m.def( "broadcast_(Tensor(a!) input, int src, str group_name) -> Tensor(a!)", torch::dispatch( +<<<<<<< HEAD c10::DispatchKey::CompositeExplicitAutograd, c10d::broadcast_), +======= + c10::DispatchKey::CompositeExplicitAutograd, ::broadcast_), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {at::Tag::pt2_compliant_tag}); m.def( @@ -375,6 +488,7 @@ class AllToAllSingle : public torch::autograd::Function { torch::autograd::AutogradContext* ctx, const at::Tensor& input, // NOLINTNEXTLINE(performance-unnecessary-value-param) +<<<<<<< HEAD at::SymIntArrayRef output_split_sizes, // NOLINTNEXTLINE(performance-unnecessary-value-param) at::SymIntArrayRef input_split_sizes, @@ -383,21 +497,42 @@ class AllToAllSingle : public torch::autograd::Function { // swap sizes for backwards pass ctx->saved_data["output_split_sizes"] = input_split_sizes.vec(); ctx->saved_data["input_split_sizes"] = output_split_sizes.vec(); +======= + std::vector output_split_sizes, + // NOLINTNEXTLINE(performance-unnecessary-value-param) + std::vector input_split_sizes, + // NOLINTNEXTLINE(performance-unnecessary-value-param) + std::string group_name) { + // swap sizes for backwards pass + ctx->saved_data["output_split_sizes"] = input_split_sizes; + ctx->saved_data["input_split_sizes"] = output_split_sizes; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ctx->saved_data["group_name"] = group_name; return c10::Dispatcher::singleton() .findSchemaOrThrow("_c10d_functional::all_to_all_single", "") +<<<<<<< HEAD .typed() +======= + .typed() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .call(input, output_split_sizes, input_split_sizes, group_name); } static torch::autograd::variable_list backward( torch::autograd::AutogradContext* ctx, const torch::autograd::variable_list& grad_out_list) { +<<<<<<< HEAD std::vector output_split_sizes = ctx->saved_data["output_split_sizes"].toSymIntVector(); std::vector input_split_sizes = ctx->saved_data["input_split_sizes"].toSymIntVector(); +======= + const std::vector& output_split_sizes = + ctx->saved_data["output_split_sizes"].toIntVector(); + const std::vector& input_split_sizes = + ctx->saved_data["input_split_sizes"].toIntVector(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const std::string& group_name = ctx->saved_data["group_name"].toStringRef(); DCHECK(grad_out_list.size() == 1); @@ -406,7 +541,11 @@ class AllToAllSingle : public torch::autograd::Function { auto out = c10::Dispatcher::singleton() .findSchemaOrThrow("_c10d_functional::all_to_all_single", "") +<<<<<<< HEAD .typed() +======= + .typed() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .call(grad_out, output_split_sizes, input_split_sizes, group_name); // do an explicit wait to avoid cuda stream issues @@ -422,8 +561,13 @@ class AllToAllSingle : public torch::autograd::Function { at::Tensor all_to_all_single_autograd( const at::Tensor& input, +<<<<<<< HEAD at::SymIntArrayRef output_split_sizes, at::SymIntArrayRef input_split_sizes, +======= + const std::vector& output_split_sizes, + const std::vector& input_split_sizes, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const std::string& group_name) { return AllToAllSingle::apply( input, output_split_sizes, input_split_sizes, group_name); @@ -445,7 +589,11 @@ class ReduceScatterTensor return c10::Dispatcher::singleton() .findSchemaOrThrow("_c10d_functional::reduce_scatter_tensor", "") +<<<<<<< HEAD .typed() +======= + .typed() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .call(input, reduce_op, group_size, group_name); } @@ -461,7 +609,11 @@ class ReduceScatterTensor auto out = c10::Dispatcher::singleton() .findSchemaOrThrow("_c10d_functional::all_gather_into_tensor", "") +<<<<<<< HEAD .typed() +======= + .typed() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .call(grad_out, group_size, group_name); // do an explicit wait to avoid cuda stream issues @@ -501,7 +653,11 @@ class AllGatherIntoTensor return c10::Dispatcher::singleton() .findSchemaOrThrow("_c10d_functional::all_gather_into_tensor", "") +<<<<<<< HEAD .typed() +======= + .typed() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .call(input, group_size, group_name); } @@ -517,7 +673,11 @@ class AllGatherIntoTensor auto out = c10::Dispatcher::singleton() .findSchemaOrThrow("_c10d_functional::reduce_scatter_tensor", "") +<<<<<<< HEAD .typed() +======= + .typed() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .call(grad_out, "sum", group_size, group_name); // do an explicit wait to avoid cuda stream issues @@ -594,10 +754,14 @@ at::Tensor shard_dim_alltoall( input_sizes.insert(input_sizes.begin() + shard_dim, group_size); auto tensor_reshaped = input.view(input_sizes); +<<<<<<< HEAD auto tensor_shard_contig = tensor_reshaped.movedim(shard_dim, 0).contiguous(); auto tensor_for_comm = input.is_complex() ? at::view_as_real(tensor_shard_contig) : tensor_shard_contig; +======= + auto tensor_for_comm = tensor_reshaped.movedim(shard_dim, 0).contiguous(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto recv_tensor = at::empty_like(tensor_for_comm); std::vector out_split_sizes; @@ -619,8 +783,12 @@ at::Tensor shard_dim_alltoall( // view/reshape it back to the expected output shape output_sizes[shard_dim] /= group_size; output_sizes[gather_dim] *= group_size; +<<<<<<< HEAD return input.is_complex() ? at::view_as_complex(output).view(output_sizes) : output.view(output_sizes); +======= + return output.view(output_sizes); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } // namespace diff --git a/torch/csrc/distributed/c10d/Functional.hpp b/torch/csrc/distributed/c10d/Functional.hpp index 553ba296cc52c..ed3172e4bffee 100644 --- a/torch/csrc/distributed/c10d/Functional.hpp +++ b/torch/csrc/distributed/c10d/Functional.hpp @@ -1,6 +1,7 @@ #pragma once #include +<<<<<<< HEAD namespace c10d { @@ -76,3 +77,5 @@ C10_EXPORT at::Tensor broadcast( std::string group_name); } // namespace c10d +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index 8074cc98a04f1..e995009dd2849 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -1,5 +1,8 @@ #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include @@ -40,6 +43,7 @@ NCCLComm::NCCLComm(NCCLComm&& other) { std::swap(deviceIndex_, other.deviceIndex_); } +<<<<<<< HEAD void NCCLComm::setUniqueHash(ncclUniqueId ncclId) { const uint8_t* bytes = reinterpret_cast(&ncclId); @@ -58,6 +62,10 @@ void NCCLComm::setUniqueHash(std::string hash) { std::string NCCLComm::getUniqueHash() { return uniqueHash_; +======= +ncclUniqueId NCCLComm::getNcclId() { + return ncclId_; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } std::shared_ptr NCCLComm::create( @@ -70,7 +78,11 @@ std::shared_ptr NCCLComm::create( C10D_NCCL_CHECK( ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank), std::nullopt); +<<<<<<< HEAD comm->setUniqueHash(commId); +======= + comm->ncclId_ = commId; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) comm->rank_ = rank; comm->deviceIndex_ = deviceIndex; comm->initialized_ = true; @@ -95,7 +107,11 @@ std::shared_ptr NCCLComm::create( ncclCommInitRankConfig( &(comm->ncclComm_), numRanks, commId, rank, &config), std::nullopt); +<<<<<<< HEAD comm->setUniqueHash(commId); +======= + comm->ncclId_ = commId; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) comm->rank_ = rank; comm->deviceIndex_ = deviceIndex; // Under blocking mode, comm is initialized immediately after NCCL init @@ -129,7 +145,11 @@ std::shared_ptr NCCLComm::create_scalable( // Only the first ncclUniqueId will be used to create the // communicator hash id, which is used to identify the communicator // in the log file and in the replay tool. +<<<<<<< HEAD comm->setUniqueHash(commIds[0]); +======= + comm->ncclId_ = commIds[0]; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) comm->rank_ = rank; comm->deviceIndex_ = deviceIndex; comm->initialized_ = !comm->nonBlocking_; @@ -195,12 +215,24 @@ std::optional NCCLComm::getNcclCommFailureReason() const { return commFailureReason_; } +<<<<<<< HEAD #if defined(NCCL_HAS_COMM_SPLIT) +======= +// TODO: why do we have `!defined(FBCODE_CAFFE2)` here? +#if defined(NCCL_HAS_COMM_SPLIT) && !defined(FBCODE_CAFFE2) +// last argument to split() API is not used to support +// multiple implementations +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::shared_ptr NCCLComm::split( NCCLComm* source, int color_id, int rank, +<<<<<<< HEAD ncclConfig_t& config) { +======= + ncclConfig_t& config, + std::vector& ranks_ull) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_CHECK( color_id >= NCCL_SPLIT_NOCOLOR, "Color must be a non-negative value or NCCL_SPLIT_NOCOLOR (-1)" @@ -250,9 +282,12 @@ std::shared_ptr NCCLComm::split( // Child comm should be on the same device as parent comm comm->deviceIndex_ = source->deviceIndex_; comm->nonBlocking_ = config.blocking == 0; +<<<<<<< HEAD comm->setUniqueHash( source->getUniqueHash() + ":" + std::to_string(source->ncclCommSplitCounter_)); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) LOG(INFO) << "Rank " << source->rank_ << ": created child comm " << comm->repr() << " with color_id " << color_id; return comm; @@ -573,6 +608,7 @@ size_t hashTensors(const std::vector& tensors) { return hash; } +<<<<<<< HEAD // NCCL uses Non-negative int to represent in-group according to API // requirement. We take a list of ranks and generate a hash value based on the // list and ensure its range of 32-bit int. @@ -594,6 +630,8 @@ int genNcclSplitColor(const std::vector& ranks) { return color; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Default value: 30 minutes int nccl_nonblocking_timeout() { static int timeout = -2; // -2 means not initialized diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index fcd55b6a655ef..e3485c056d1ba 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -14,6 +14,10 @@ #include #include #include +<<<<<<< HEAD +======= +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include constexpr int64_t kCommInitBusyWaitMillis = 2; @@ -231,7 +235,10 @@ static std::map ncclDataType = { }; TORCH_API size_t hashTensors(const std::vector& tensors); +<<<<<<< HEAD TORCH_API int genNcclSplitColor(const std::vector& ranks); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_API std::string getNcclVersion(); TORCH_API std::tuple getNcclVersionTuple(); TORCH_API int getNcclVersionNumber(); @@ -259,10 +266,13 @@ class NCCLComm { ~NCCLComm() noexcept; +<<<<<<< HEAD void setUniqueHash(ncclUniqueId ncclId); void setUniqueHash(std::string hash); std::string getUniqueHash(); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static std::shared_ptr create( int numRanks, int rank, @@ -291,13 +301,22 @@ class NCCLComm { NCCLComm* source, int color_id, int rank, +<<<<<<< HEAD ncclConfig_t& config); +======= + ncclConfig_t& config, + std::vector& ranks_ull); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif // NCCL_HAS_COMM_SPLIT #if (defined(IS_NCCLX) || defined(USE_ROCM)) && defined(NCCL_COMM_DUMP) std::unordered_map ncclCommDump(); #endif +<<<<<<< HEAD +======= + ncclUniqueId getNcclId(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) at::DeviceIndex getDeviceIndex(); // Must not be copyable @@ -357,8 +376,13 @@ class NCCLComm { friend class ProcessGroupNCCL; protected: +<<<<<<< HEAD // Unique hash for this communicator. std::string uniqueHash_; +======= + // Unique nccl_id for this communicator. + ncclUniqueId ncclId_{}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool aborted_{false}; uint64_t ncclCommSplitCounter_{0}; ncclResult_t ncclAsyncErr_{ncclSuccess}; diff --git a/torch/csrc/distributed/c10d/ProcessGroup.cpp b/torch/csrc/distributed/c10d/ProcessGroup.cpp index e57e2c2a8d417..bc2e5559a2684 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.cpp @@ -4,7 +4,11 @@ #include #include +<<<<<<< HEAD #include +======= +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include @@ -158,6 +162,7 @@ void ProcessGroup::release_resources() { backendTypeToBackend_.clear(); } +<<<<<<< HEAD c10::intrusive_ptr ProcessGroup::splitGroup( const std::vector& ranks, const std::optional& timeout, @@ -259,6 +264,8 @@ c10::intrusive_ptr ProcessGroup::mergeRemoteGroup( return newGroup; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace c10d namespace { @@ -428,6 +435,7 @@ bool allow_inflight_collective_as_graph_input() { .allow_inflight_collective_as_graph_input(); } +<<<<<<< HEAD c10::intrusive_ptr& currentProcessGroup() { thread_local static c10::intrusive_ptr pg = nullptr; return pg; @@ -437,4 +445,6 @@ void setProcessGroup(c10::intrusive_ptr pg) { currentProcessGroup() = std::move(pg); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace c10d diff --git a/torch/csrc/distributed/c10d/ProcessGroup.hpp b/torch/csrc/distributed/c10d/ProcessGroup.hpp index 4fb2d566e9a76..d9a3d17da0564 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.hpp @@ -71,6 +71,7 @@ C10_EXPORT bool allow_inflight_collective_as_graph_input(); // class TORCH_API ProcessGroup : public torch::CustomClassHolder { public: +<<<<<<< HEAD struct TORCH_API MergeOptions : torch::CustomClassHolder { explicit MergeOptions( const std::chrono::milliseconds timeout = kProcessGroupDefaultTimeout, @@ -86,6 +87,8 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { std::optional group_desc; }; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) enum BackendType : uint8_t { UNDEFINED = 0, GLOO = 1, @@ -179,6 +182,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { return false; } +<<<<<<< HEAD virtual void setTimeout(std::chrono::milliseconds timeout) { for (auto& backend : backendTypeToBackend_) { backend.second->setTimeout(timeout); @@ -189,6 +193,8 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { return splitCounter_++; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) virtual void startCoalescing(c10::DeviceType deviceType) { // only nccl has implemented startCoalescing so only execute for nccl // backends @@ -967,10 +973,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { return bound_device_id_; } +<<<<<<< HEAD c10::intrusive_ptr getStore() const { return store_; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void setBoundDeviceId(std::optional device) { if (device) { TORCH_CHECK(device->has_index(), "setBoundDeviceId must have an index"); @@ -978,6 +987,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { bound_device_id_ = device; } +<<<<<<< HEAD // This creates a new subgroup using the specified ranks. // The current rank must be included in the list of new_ranks. virtual c10::intrusive_ptr splitGroup( @@ -994,6 +1004,8 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const MergeOptions& opts, const int& size); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) protected: // Implementations of this interface need to call this to setup // appropriate logging etc. @@ -1007,7 +1019,10 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) BackendType backendType_; std::string pg_desc_; +<<<<<<< HEAD int64_t splitCounter_; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Debug level setting. It is parsed once when ProcessGroup is constructed and // remains the same across use of this process group. @@ -1024,8 +1039,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { std::optional bound_device_id_; }; +<<<<<<< HEAD // Thread local functions for managing the currently active process group. TORCH_API c10::intrusive_ptr& currentProcessGroup(); TORCH_API void setProcessGroup(c10::intrusive_ptr processGroup); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace c10d diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp index c7f21d62e24e6..9cc53cbfd878e 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp @@ -299,7 +299,10 @@ ProcessGroupGloo::AsyncWork::AsyncWork( std::vector> outputTensors, OpType opType, uint64_t seq, +<<<<<<< HEAD std::chrono::milliseconds timeout, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const char* profilingTitle, const std::optional>& inputTensors) // Profiler: Pass nullptr as profilingTitle to parent constructor to @@ -307,7 +310,10 @@ ProcessGroupGloo::AsyncWork::AsyncWork( // correct timestamps for work that is asynchronously executed. : Work(-1, opType, nullptr, inputTensors), context_(std::move(context)), +<<<<<<< HEAD timeout_(timeout == kUnsetTimeout ? context_->getTimeout() : timeout), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) outputTensors_(std::move(outputTensors)), future_(createFutureAsOutput(outputTensors_)), seq_(seq) { @@ -528,16 +534,27 @@ std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: // use. Note: if the hostname does not resolve to an address (e.g. // because of misconfigured /etc/hosts file), this will not work. const auto hostNameMax = sysconf(_SC_HOST_NAME_MAX); +<<<<<<< HEAD std::string hostname(hostNameMax, '\0'); auto rv = gethostname(hostname.data(), hostNameMax); +======= + auto hostname = std::unique_ptr(new char[hostNameMax]); + auto rv = gethostname(hostname.get(), hostNameMax); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (rv != 0) { C10_THROW_ERROR(DistBackendError, c10::utils::str_error(errno)); } // Use this machine's hostname if it resolves to an address. +<<<<<<< HEAD if (doesHostnameResolveToUsableAddress(hostname.data())) { return ::c10d::GlooDeviceFactory::makeDeviceForHostname( hostname.data(), lazyInit); +======= + if (doesHostnameResolveToUsableAddress(hostname.get())) { + return ::c10d::GlooDeviceFactory::makeDeviceForHostname( + hostname.get(), lazyInit); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // Otherwise, use the loopback address. @@ -697,6 +714,7 @@ const std::vector& ProcessGroupGloo::groupRanks() const { return options_->global_ranks_in_group; } +<<<<<<< HEAD c10::intrusive_ptr ProcessGroupGloo::split( const c10::intrusive_ptr& store, const std::vector& ranks, @@ -735,6 +753,8 @@ c10::intrusive_ptr ProcessGroupGloo::merge( return c10::static_intrusive_pointer_cast(pg); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void ProcessGroupGloo::enqueue(c10::intrusive_ptr work) { std::unique_lock lock(workMutex_); pgStatus_->lastEnqueuedSeq = static_cast(work->seq_); @@ -777,14 +797,21 @@ class AsyncBroadcastWork : public ProcessGroupGloo::AsyncWork { int rootRank, int rootTensor, uint32_t tag, +<<<<<<< HEAD uint64_t seq, std::chrono::milliseconds timeout) +======= + uint64_t seq) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) : ProcessGroupGloo::AsyncWork( std::move(context), {inputs}, OpType::BROADCAST, seq, +<<<<<<< HEAD timeout, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "gloo:broadcast", inputs), inputs(inputs), @@ -797,15 +824,22 @@ class AsyncBroadcastWork : public ProcessGroupGloo::AsyncWork { const int rootTensor; const uint32_t tag; +<<<<<<< HEAD void broadcast(at::Tensor tensor) { if (tensor.is_complex()) { tensor = at::view_as_real(tensor); } +======= + void broadcast(at::Tensor& tensor) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const auto& scalarType = tensor.scalar_type(); gloo::BroadcastOptions opts(context_); opts.setRoot(rootRank); opts.setTag(tag); +<<<<<<< HEAD opts.setTimeout(timeout_); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GENERATE_ALL_TYPES(scalarType, setOutput, opts, tensor); gloo::broadcast(opts); } @@ -839,6 +873,7 @@ class AsyncBroadcastCUDAWork : public AsyncBroadcastWork { int rootRank, int rootTensor, uint32_t tag, +<<<<<<< HEAD uint64_t seq, std::chrono::milliseconds timeout) : AsyncBroadcastWork( @@ -849,6 +884,10 @@ class AsyncBroadcastCUDAWork : public AsyncBroadcastWork { tag, seq, timeout) { +======= + uint64_t seq) + : AsyncBroadcastWork(context, inputs, rootRank, rootTensor, tag, seq) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) initializeStreamsEvents(inputs, streams, events); // Create pinned host side tensors. @@ -925,6 +964,7 @@ c10::intrusive_ptr ProcessGroupGloo::broadcast( ++seq_; if (device.type() == at::kCPU) { work = c10::make_intrusive( +<<<<<<< HEAD std::move(context), inputs, opts.rootRank, @@ -941,6 +981,12 @@ c10::intrusive_ptr ProcessGroupGloo::broadcast( tag, seq_, opts.timeout); +======= + std::move(context), inputs, opts.rootRank, opts.rootTensor, tag, seq_); + } else if (device.type() == at::kCUDA) { + work = c10::make_intrusive( + std::move(context), inputs, opts.rootRank, opts.rootTensor, tag, seq_); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else { TORCH_CHECK(false, "Invalid backend"); } @@ -985,7 +1031,11 @@ c10::intrusive_ptr ProcessGroupGloo::allreduce( ++seq_; work = GlooAllreduceRegistry()->Create( +<<<<<<< HEAD device.type(), context, inputs, opts.reduceOp, tag, seq_, opts.timeout); +======= + device.type(), context, inputs, opts.reduceOp, tag, seq_); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) enqueue(work); return work; @@ -996,16 +1046,27 @@ static c10::intrusive_ptr makeAllreduceCPUWork( std::vector& inputs, ReduceOp reduceOp, uint32_t tag, +<<<<<<< HEAD uint64_t seq, std::chrono::milliseconds timeout) { +======= + uint64_t seq) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto layout = inputs[0].layout(); if (layout == c10::kStrided) { return c10::make_intrusive( +<<<<<<< HEAD std::move(context), inputs, reduceOp, tag, seq, timeout); } else if (layout == c10::kSparse) { return c10::make_intrusive( std::move(context), inputs, tag, seq, timeout); +======= + std::move(context), inputs, reduceOp, tag, seq); + } else if (layout == c10::kSparse) { + return c10::make_intrusive( + std::move(context), inputs, tag, seq); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else { TORCH_CHECK(false, "ProcessGroupGloo::allreduce: unsupported layout"); } @@ -1020,8 +1081,12 @@ C10_DEFINE_TYPED_REGISTRY( std::vector&, ReduceOp, uint32_t, +<<<<<<< HEAD uint64_t, std::chrono::milliseconds) +======= + uint64_t) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) C10_REGISTER_TYPED_CREATOR( GlooAllreduceRegistry, @@ -1086,7 +1151,11 @@ c10::intrusive_ptr ProcessGroupGloo::allreduce_coalesced( if (device.type() == c10::kCPU) { if (layout == c10::kStrided) { work = c10::make_intrusive( +<<<<<<< HEAD std::move(context), tensors, opts.reduceOp, tag, seq_, opts.timeout); +======= + std::move(context), tensors, opts.reduceOp, tag, seq_); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else { invalidArgument("unsupported layout"); } @@ -1108,14 +1177,21 @@ class AsyncReduceWork : public ProcessGroupGloo::AsyncWork { int rootTensor, ReduceOp reduceOp, uint32_t tag, +<<<<<<< HEAD uint64_t seq, std::chrono::milliseconds timeout) +======= + uint64_t seq) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) : ProcessGroupGloo::AsyncWork( std::move(context), {inputs}, OpType::REDUCE, seq, +<<<<<<< HEAD timeout, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "gloo:reduce", inputs), inputs(inputs), @@ -1131,6 +1207,7 @@ class AsyncReduceWork : public ProcessGroupGloo::AsyncWork { const uint32_t tag; void reduce(std::vector& tensors) { +<<<<<<< HEAD auto tensor = tensors[0]; if (tensor.is_complex()) { TORCH_CHECK( @@ -1147,6 +1224,14 @@ class AsyncReduceWork : public ProcessGroupGloo::AsyncWork { opts.setReduceFunction(getFunction(scalarType, reduceOp)); opts.setTimeout(timeout_); GENERATE_ALL_TYPES(scalarType, setOutput, opts, tensor); +======= + const auto& scalarType = tensors[0].scalar_type(); + gloo::ReduceOptions opts(context_); + opts.setRoot(rootRank); + opts.setTag(tag); + opts.setReduceFunction(getFunction(scalarType, reduceOp)); + GENERATE_ALL_TYPES(scalarType, setOutput, opts, tensors[0]); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) gloo::reduce(opts); // Gloo doesn't support AVG so we use SUM + division. @@ -1191,8 +1276,12 @@ class AsyncReduceCUDAWork : public AsyncReduceWork { int rootTensor, ReduceOp reduceOp, uint32_t tag, +<<<<<<< HEAD uint64_t seq, std::chrono::milliseconds timeout) +======= + uint64_t seq) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) : AsyncReduceWork( context, inputs, @@ -1200,8 +1289,12 @@ class AsyncReduceCUDAWork : public AsyncReduceWork { rootTensor, std::move(reduceOp), tag, +<<<<<<< HEAD seq, timeout) { +======= + seq) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) initializeStreamsEvents(inputs, streams, events); // Kick off copy from CUDA tensors to pinned CPU tensors. @@ -1284,8 +1377,12 @@ c10::intrusive_ptr ProcessGroupGloo::reduce( opts.rootTensor, opts.reduceOp, tag, +<<<<<<< HEAD seq_, opts.timeout); +======= + seq_); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else if (device.type() == at::kCUDA) { work = c10::make_intrusive( std::move(context), @@ -1294,8 +1391,12 @@ c10::intrusive_ptr ProcessGroupGloo::reduce( opts.rootTensor, opts.reduceOp, tag, +<<<<<<< HEAD seq_, opts.timeout); +======= + seq_); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else { TORCH_CHECK(false, "Invalid backend"); } @@ -1312,14 +1413,21 @@ class AsyncAllgatherWork : public ProcessGroupGloo::AsyncWork { std::vector>& outputs, std::vector& inputs, uint32_t tag, +<<<<<<< HEAD uint64_t seq, std::chrono::milliseconds timeout) +======= + uint64_t seq) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) : ProcessGroupGloo::AsyncWork( std::move(context), outputs, OpType::ALLGATHER, seq, +<<<<<<< HEAD timeout, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "gloo:all_gather", inputs), outputs(outputs), @@ -1336,7 +1444,10 @@ class AsyncAllgatherWork : public ProcessGroupGloo::AsyncWork { const auto& scalarType = inputs[0].scalar_type(); gloo::AllgatherOptions opts(context_); opts.setTag(tag); +<<<<<<< HEAD opts.setTimeout(timeout_); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Use single flattened input tensor. at::Tensor flatInputTensor = flattenDenseTensors(inputs); @@ -1345,8 +1456,12 @@ class AsyncAllgatherWork : public ProcessGroupGloo::AsyncWork { // Use single flat output tensor. // The first dimension corresponds to the index into outputs[N], // so copying into the actual output later is easy. +<<<<<<< HEAD at::Tensor flatOutputTensor = newLikeFlat(outputs[0], /*preserve_strides*/ false); +======= + at::Tensor flatOutputTensor = newLikeFlat(outputs[0]); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GENERATE_ALL_TYPES(scalarType, setOutput, opts, flatOutputTensor); gloo::allgather(opts); @@ -1363,7 +1478,11 @@ class AsyncAllgatherWork : public ProcessGroupGloo::AsyncWork { } const std::vector getOutputTensors() override { +<<<<<<< HEAD return {newLikeFlat(outputs[0], /*preserve_strides*/ false)}; +======= + return {newLikeFlat(outputs[0])}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } void run() override { @@ -1380,9 +1499,14 @@ class AsyncAllgatherCUDAWork : public AsyncAllgatherWork { std::vector>& outputs, std::vector& inputs, uint32_t tag, +<<<<<<< HEAD uint64_t seq, std::chrono::milliseconds timeout) : AsyncAllgatherWork(context, outputs, inputs, tag, seq, timeout) { +======= + uint64_t seq) + : AsyncAllgatherWork(context, outputs, inputs, tag, seq) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) initializeStreamsEvents(inputs, inputStreams, inputEvents); initializeStreamsEvents(outputs, outputStreams, outputEvents); @@ -1498,8 +1622,12 @@ c10::intrusive_ptr ProcessGroupGloo::reduce_scatter_tensor_coalesced( std::vector inp = {buffers[i]}; AllreduceOptions arOpts; arOpts.reduceOp = opts.reduceOp; +<<<<<<< HEAD arOpts.timeout = opts.timeout; works.push_back(allreduce(inp, arOpts)); +======= + works.push_back(allreduce(inp)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } return c10::make_intrusive( [rank, worldSize, buffers, outputTensors, works = std::move(works)]() { @@ -1577,10 +1705,17 @@ c10::intrusive_ptr ProcessGroupGloo::allgather( ++seq_; if (device.type() == at::kCPU) { work = c10::make_intrusive( +<<<<<<< HEAD std::move(context), outputs, inputs, tag, seq_, opts.timeout); } else if (device.type() == at::kCUDA) { work = c10::make_intrusive( std::move(context), outputs, inputs, tag, seq_, opts.timeout); +======= + std::move(context), outputs, inputs, tag, seq_); + } else if (device.type() == at::kCUDA) { + work = c10::make_intrusive( + std::move(context), outputs, inputs, tag, seq_); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else { TORCH_CHECK(false, "Invalid backend"); } @@ -1597,14 +1732,21 @@ class AsyncAllgatherCoalescedWork : public ProcessGroupGloo::AsyncWork { std::vector>& output_lists, std::vector& input_list, uint32_t tag, +<<<<<<< HEAD uint64_t seq, std::chrono::milliseconds timeout) +======= + uint64_t seq) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) : ProcessGroupGloo::AsyncWork( std::move(context), output_lists, OpType::ALLGATHER_COALESCED, seq, +<<<<<<< HEAD timeout, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "gloo:all_gather", input_list), output_lists(output_lists), @@ -1623,7 +1765,10 @@ class AsyncAllgatherCoalescedWork : public ProcessGroupGloo::AsyncWork { const auto& scalarType = input_list[0].scalar_type(); gloo::AllgatherOptions opts(context_); opts.setTag(tag); +<<<<<<< HEAD opts.setTimeout(timeout_); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Use single flattened input tensor. at::Tensor flatInputTensor = flattenDenseTensors(input_list); @@ -1659,7 +1804,11 @@ class AsyncAllgatherCoalescedWork : public ProcessGroupGloo::AsyncWork { } const std::vector getOutputTensors() override { +<<<<<<< HEAD return {newLikeFlat(output_lists[0], /*preserve_strides*/ false)}; +======= + return {newLikeFlat(output_lists[0])}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } void run() override { @@ -1672,7 +1821,11 @@ class AsyncAllgatherCoalescedWork : public ProcessGroupGloo::AsyncWork { c10::intrusive_ptr ProcessGroupGloo::allgather_coalesced( std::vector>& output_lists, std::vector& input_list, +<<<<<<< HEAD const AllgatherOptions& opts) { +======= + const AllgatherOptions& /* unused */) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static auto invalidArgument = [](const std::string& msg) { TORCH_CHECK(false, "ProcessGroupGloo::allgather_coalesced: " + msg); }; @@ -1720,7 +1873,11 @@ c10::intrusive_ptr ProcessGroupGloo::allgather_coalesced( auto context = getContext(tag); ++seq_; auto work = c10::make_intrusive( +<<<<<<< HEAD std::move(context), output_lists, input_list, tag, seq_, opts.timeout); +======= + std::move(context), output_lists, input_list, tag, seq_); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) enqueue(work); return work; } @@ -1750,14 +1907,21 @@ class AsyncGatherWork : public ProcessGroupGloo::AsyncWork { std::vector& inputs, int root, uint32_t tag, +<<<<<<< HEAD uint64_t seq, std::chrono::milliseconds timeout) +======= + uint64_t seq) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) : ProcessGroupGloo::AsyncWork( std::move(context), outputs, OpType::GATHER, seq, +<<<<<<< HEAD timeout, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "gloo:gather", inputs), outputs(outputs), @@ -1777,19 +1941,30 @@ class AsyncGatherWork : public ProcessGroupGloo::AsyncWork { gloo::GatherOptions opts(context_); opts.setRoot(root); opts.setTag(tag); +<<<<<<< HEAD opts.setTimeout(timeout_); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Set single temporary tensor on root process. // This is later scattered to the separate output tensors. at::Tensor flatOutputTensor; if (context_->rank == root) { +<<<<<<< HEAD flatOutputTensor = newLikeFlat(outputs[0], /*preserve_strides*/ false); +======= + flatOutputTensor = newLikeFlat(outputs[0]); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GENERATE_ALL_TYPES(scalarType, setOutput, opts, flatOutputTensor); } // Set single input tensor on all processes. +<<<<<<< HEAD at::Tensor flatInputTensor = flattenDenseTensors(inputs[0]); GENERATE_ALL_TYPES(scalarType, setInput, opts, flatInputTensor); +======= + GENERATE_ALL_TYPES(scalarType, setInput, opts, inputs[0]); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) gloo::gather(opts); // Unflatten into output tensors on root process. @@ -1806,8 +1981,12 @@ class AsyncGatherWork : public ProcessGroupGloo::AsyncWork { const std::vector getOutputTensors() override { return outputs.empty() ? std::vector{} +<<<<<<< HEAD : std::vector{newLikeFlat( outputs[0], /*preserve_strides*/ false)}; +======= + : std::vector{newLikeFlat(outputs[0])}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } void run() override { @@ -1828,9 +2007,14 @@ class AsyncGatherCUDAWork : public AsyncGatherWork { std::vector& inputs, int root, uint32_t tag, +<<<<<<< HEAD uint64_t seq, std::chrono::milliseconds timeout) : AsyncGatherWork(context, outputs, inputs, root, tag, seq, timeout) { +======= + uint64_t seq) + : AsyncGatherWork(context, outputs, inputs, root, tag, seq) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) initializeStreamsEvents(inputs, inputStreams, inputEvents); initializeStreamsEvents(outputs, outputStreams, outputEvents); @@ -1948,6 +2132,7 @@ c10::intrusive_ptr ProcessGroupGloo::gather( ++seq_; if (device.type() == at::kCPU) { work = c10::make_intrusive( +<<<<<<< HEAD std::move(context), outputs, inputs, @@ -1964,6 +2149,12 @@ c10::intrusive_ptr ProcessGroupGloo::gather( tag, seq_, opts.timeout); +======= + std::move(context), outputs, inputs, opts.rootRank, tag, seq_); + } else if (device.type() == at::kCUDA) { + work = c10::make_intrusive( + std::move(context), outputs, inputs, opts.rootRank, tag, seq_); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else { TORCH_CHECK(false, "Invalid backend"); } @@ -1981,14 +2172,21 @@ class AsyncScatterWork : public ProcessGroupGloo::AsyncWork { std::vector>& inputs, int root, uint32_t tag, +<<<<<<< HEAD uint64_t seq, std::chrono::milliseconds timeout) +======= + uint64_t seq) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) : ProcessGroupGloo::AsyncWork( std::move(context), {outputs}, OpType::SCATTER, seq, +<<<<<<< HEAD timeout, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "gloo:scatter", !inputs.empty() ? std::optional>(inputs[0]) : std::nullopt), @@ -2009,7 +2207,10 @@ class AsyncScatterWork : public ProcessGroupGloo::AsyncWork { gloo::ScatterOptions opts(context_); opts.setRoot(root); opts.setTag(tag); +<<<<<<< HEAD opts.setTimeout(timeout_); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Set list of input tensors on root process if (context_->rank == root) { @@ -2023,8 +2224,12 @@ class AsyncScatterWork : public ProcessGroupGloo::AsyncWork { const std::vector getInputTensors() override { return inputs.empty() ? std::vector{} +<<<<<<< HEAD : std::vector{newLikeFlat( inputs[0], /*preserve_strides*/ false)}; +======= + : std::vector{newLikeFlat(inputs[0])}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } const std::vector getOutputTensors() override { @@ -2044,9 +2249,14 @@ class AsyncScatterCUDAWork : public AsyncScatterWork { std::vector>& inputs, int root, uint32_t tag, +<<<<<<< HEAD uint64_t seq, std::chrono::milliseconds timeout) : AsyncScatterWork(context, outputs, inputs, root, tag, seq, timeout) { +======= + uint64_t seq) + : AsyncScatterWork(context, outputs, inputs, root, tag, seq) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) initializeStreamsEvents(inputs, inputStreams, inputEvents); initializeStreamsEvents(outputs, outputStreams, outputEvents); @@ -2161,6 +2371,7 @@ c10::intrusive_ptr ProcessGroupGloo::scatter( ++seq_; if (device.type() == at::kCPU) { work = c10::make_intrusive( +<<<<<<< HEAD std::move(context), outputs, inputs, @@ -2177,6 +2388,12 @@ c10::intrusive_ptr ProcessGroupGloo::scatter( tag, seq_, opts.timeout); +======= + std::move(context), outputs, inputs, opts.rootRank, tag, seq_); + } else if (device.type() == at::kCUDA) { + work = c10::make_intrusive( + std::move(context), outputs, inputs, opts.rootRank, tag, seq_); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else { TORCH_CHECK(false, "Invalid backend"); } @@ -2218,8 +2435,12 @@ c10::intrusive_ptr ProcessGroupGloo::reduce_scatter( std::vector inp = {buffers[i]}; AllreduceOptions arOpts; arOpts.reduceOp = opts.reduceOp; +<<<<<<< HEAD arOpts.timeout = opts.timeout; works.push_back(allreduce(inp, arOpts)); +======= + works.push_back(allreduce(inp)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } return c10::make_intrusive( [worldSize, works = std::move(works)]() { @@ -2240,14 +2461,21 @@ class AsyncAlltoallWork : public ProcessGroupGloo::AsyncWork { std::vector& outputCounts, std::vector& inputCounts, uint32_t tag, +<<<<<<< HEAD uint64_t seq, std::chrono::milliseconds timeout) +======= + uint64_t seq) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) : ProcessGroupGloo::AsyncWork( std::move(context), {{outputTensor}}, OpType::ALLTOALL, seq, +<<<<<<< HEAD timeout, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "gloo:all_to_all", std::optional>({inputTensor})), outputTensor(outputTensor), @@ -2268,7 +2496,10 @@ class AsyncAlltoallWork : public ProcessGroupGloo::AsyncWork { // Gloo alltoall gloo::AlltoallOptions opts(context_); opts.setTag(tag); +<<<<<<< HEAD opts.setTimeout(timeout_); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GENERATE_ALL_TYPES(scalarType, setInput, opts, inputTensor); GENERATE_ALL_TYPES(scalarType, setOutput, opts, outputTensor); gloo::alltoall(opts); @@ -2286,7 +2517,10 @@ class AsyncAlltoallWork : public ProcessGroupGloo::AsyncWork { outputCounts, outputTensor, &recvCounts, &recvOffsets); gloo::AlltoallvOptions opts(context_); opts.setTag(tag); +<<<<<<< HEAD opts.setTimeout(timeout_); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GENERATE_ALL_TYPES(scalarType, setInput, opts, inputTensor, sendCounts); GENERATE_ALL_TYPES(scalarType, setOutput, opts, outputTensor, recvCounts); gloo::alltoallv(opts); @@ -2315,8 +2549,12 @@ class AsyncAlltoallCUDAWork : public AsyncAlltoallWork { std::vector& outputCounts, std::vector& inputCounts, uint32_t tag, +<<<<<<< HEAD uint64_t seq, std::chrono::milliseconds timeout) +======= + uint64_t seq) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) : AsyncAlltoallWork( context, outputTensor, @@ -2324,8 +2562,12 @@ class AsyncAlltoallCUDAWork : public AsyncAlltoallWork { outputCounts, inputCounts, tag, +<<<<<<< HEAD seq, timeout) { +======= + seq) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) initializeStreamsEvents({inputTensor}, inputStreams, inputEvents); initializeStreamsEvents({outputTensor}, outputStreams, outputEvents); @@ -2376,7 +2618,11 @@ c10::intrusive_ptr ProcessGroupGloo::alltoall_base( at::Tensor& inputTensor, std::vector& outputCounts, std::vector& inputCounts, +<<<<<<< HEAD const AllToAllOptions& opts) { +======= + const AllToAllOptions& /* unused */) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static auto invalidArgument = [](const std::string& msg) { TORCH_CHECK(false, "ProcessGroupGloo::alltoall_base: " + msg); }; @@ -2405,8 +2651,12 @@ c10::intrusive_ptr ProcessGroupGloo::alltoall_base( outputCounts, inputCounts, tag, +<<<<<<< HEAD seq_, opts.timeout); +======= + seq_); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else if (device.type() == at::kCUDA) { work = c10::make_intrusive( std::move(context), @@ -2415,8 +2665,12 @@ c10::intrusive_ptr ProcessGroupGloo::alltoall_base( outputCounts, inputCounts, tag, +<<<<<<< HEAD seq_, opts.timeout); +======= + seq_); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else { invalidArgument(c10::str("unsupported device type ", device.type())); } @@ -2527,14 +2781,21 @@ class AsyncBarrierWork : public ProcessGroupGloo::AsyncWork { std::shared_ptr context, std::vector> priorWork, uint32_t tag, +<<<<<<< HEAD uint64_t seq, std::chrono::milliseconds timeout) +======= + uint64_t seq) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) : ProcessGroupGloo::AsyncWork( std::move(context), {}, OpType::BARRIER, seq, +<<<<<<< HEAD timeout, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "gloo:barrier", std::nullopt), priorWork(std::move(priorWork)), @@ -2563,7 +2824,10 @@ class AsyncBarrierWork : public ProcessGroupGloo::AsyncWork { gloo::BarrierOptions opts(context_); opts.setTag(tag); +<<<<<<< HEAD opts.setTimeout(timeout_); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) gloo::barrier(opts); } }; @@ -2587,7 +2851,11 @@ c10::intrusive_ptr ProcessGroupGloo::barrier(const BarrierOptions& opts) { auto context = getContext(tag); ++seq_; auto work = c10::make_intrusive( +<<<<<<< HEAD std::move(context), std::move(priorWork), tag, seq_, opts.timeout); +======= + std::move(context), std::move(priorWork), tag, seq_); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) enqueue(work); return work; } diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp index 4297807f2e8b9..0197ba9f32fa8 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp @@ -69,7 +69,10 @@ class TORCH_API ProcessGroupGloo : public Backend { std::vector> outputTensors, OpType opType, uint64_t seq, +<<<<<<< HEAD std::chrono::milliseconds timeout, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const char* profilingTitle = nullptr, const std::optional>& inputTensors = std::nullopt); @@ -100,7 +103,10 @@ class TORCH_API ProcessGroupGloo : public Backend { // work has completed std::optional trace_id_; std::shared_ptr context_; +<<<<<<< HEAD const std::chrono::milliseconds timeout_; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) private: void finishWorkGloo(); @@ -188,10 +194,13 @@ class TORCH_API ProcessGroupGloo : public Backend { } #endif +<<<<<<< HEAD const c10::intrusive_ptr<::c10d::Store>& _getStore() const { return store_; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) protected: c10::intrusive_ptr<::c10d::Store> store_; }; @@ -255,6 +264,11 @@ class TORCH_API ProcessGroupGloo : public Backend { return c10::make_intrusive(timeout); } +<<<<<<< HEAD +======= + std::vector global_ranks_in_group; + std::string group_name; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::vector> devices; int threads; }; @@ -263,10 +277,13 @@ class TORCH_API ProcessGroupGloo : public Backend { return std::string(GLOO_BACKEND_NAME); } +<<<<<<< HEAD bool supportsSplitting() const override { return true; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Helper functions to create a new device object. // They are static functions on this class to keep them logically // separate from the rest of the code base (e.g. torch/csrc/distributed). @@ -300,6 +317,7 @@ class TORCH_API ProcessGroupGloo : public Backend { return options_; } +<<<<<<< HEAD void setTimeout(std::chrono::milliseconds timeout) override { options_->timeout = timeout; for (auto& context : contexts_) { @@ -322,6 +340,8 @@ class TORCH_API ProcessGroupGloo : public Backend { const int& rank, const int& size) override; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const std::vector& groupRanks() const; c10::intrusive_ptr broadcast( diff --git a/torch/csrc/distributed/c10d/ProcessGroupGlooCuda.cpp b/torch/csrc/distributed/c10d/ProcessGroupGlooCuda.cpp index 6e680b41fe8de..46b6b2c4881e1 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGlooCuda.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGlooCuda.cpp @@ -9,18 +9,29 @@ namespace c10d { class AsyncAllreduceCUDADeviceWork : public ProcessGroupGloo::AsyncWork { public: AsyncAllreduceCUDADeviceWork( +<<<<<<< HEAD std::shared_ptr context, std::vector& inputs, ReduceOp reduceOp, uint32_t tag, uint64_t seq, std::chrono::milliseconds timeout) +======= + const std::shared_ptr& context, + std::vector& inputs, + ReduceOp reduceOp, + uint32_t tag, + uint64_t seq) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) : ProcessGroupGloo::AsyncWork( std::move(context), {inputs}, OpType::ALLREDUCE, seq, +<<<<<<< HEAD timeout, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "gloo:all_reduce", inputs), inputs_(inputs), @@ -78,6 +89,7 @@ class AsyncAllreduceCUDAHostWork : public AsyncAllreduceWork { std::vector& inputs, ReduceOp reduceOp, uint32_t tag, +<<<<<<< HEAD uint64_t seq, std::chrono::milliseconds timeout) : AsyncAllreduceWork( @@ -87,6 +99,10 @@ class AsyncAllreduceCUDAHostWork : public AsyncAllreduceWork { tag, seq, timeout) { +======= + uint64_t seq) + : AsyncAllreduceWork(context, inputs, std::move(reduceOp), tag, seq) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) initializeStreamsEvents(inputs, streams, events); // Kick off copy from CUDA tensors to pinned CPU tensors. @@ -135,9 +151,14 @@ class AsyncSparseAllreduceCUDAWork : public AsyncSparseAllreduceWork { const std::shared_ptr& context, std::vector& inputs, uint32_t tag, +<<<<<<< HEAD uint64_t seq, std::chrono::milliseconds timeout) : AsyncSparseAllreduceWork(context, inputs, tag, seq, timeout) { +======= + uint64_t seq) + : AsyncSparseAllreduceWork(context, inputs, tag, seq) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) initializeStreamsEvents(inputs, streams, events); // Kick off copy from CUDA tensors to CPU tensors. @@ -189,13 +210,18 @@ static c10::intrusive_ptr makeAllreduceCUDAWork( std::vector& inputs, ReduceOp reduceOp, uint32_t tag, +<<<<<<< HEAD uint64_t seq, std::chrono::milliseconds timeout) { +======= + uint64_t seq) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto layout = inputs[0].layout(); if (layout == c10::kStrided) { if (context->getDevice()->hasGPUDirect()) { return c10::make_intrusive( +<<<<<<< HEAD std::move(context), inputs, reduceOp, tag, seq, timeout); } else { return c10::make_intrusive( @@ -204,6 +230,16 @@ static c10::intrusive_ptr makeAllreduceCUDAWork( } else if (layout == c10::kSparse) { return c10::make_intrusive( std::move(context), inputs, tag, seq, timeout); +======= + std::move(context), inputs, reduceOp, tag, seq); + } else { + return c10::make_intrusive( + std::move(context), inputs, reduceOp, tag, seq); + } + } else if (layout == c10::kSparse) { + return c10::make_intrusive( + std::move(context), inputs, tag, seq); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else { TORCH_CHECK(false, "ProcessGroupGloo::allreduce: unsupported layout"); } diff --git a/torch/csrc/distributed/c10d/ProcessGroupGlooDetail.hpp b/torch/csrc/distributed/c10d/ProcessGroupGlooDetail.hpp index 442cb490743b2..6002c3c13cd04 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGlooDetail.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGlooDetail.hpp @@ -93,8 +93,12 @@ TORCH_DECLARE_TYPED_REGISTRY( std::vector&, ReduceOp, uint32_t, +<<<<<<< HEAD uint64_t, std::chrono::milliseconds); +======= + uint64_t); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // This function initializes a vector of CUDA streams, one for every // tensor in the input tensor vector, and ensures that these streams are @@ -232,8 +236,13 @@ void setInput(O& opts, at::Tensor& tensor, std::vector& counts) { } template +<<<<<<< HEAD void setOutputs(O& opts, std::vector& tensors, int64_t count) { opts.setOutputs(getDataPointers(tensors), count); +======= +void setOutputs(O& opts, std::vector& tensors) { + opts.setOutputs(getDataPointers(tensors), tensors[0].numel()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } template @@ -270,14 +279,21 @@ class AsyncAllreduceWork : public ProcessGroupGloo::AsyncWork { std::vector& inputs, ReduceOp reduceOp, uint32_t tag, +<<<<<<< HEAD uint64_t seq, std::chrono::milliseconds timeout) +======= + uint64_t seq) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) : ProcessGroupGloo::AsyncWork( std::move(context), {inputs}, OpType::ALLREDUCE, seq, +<<<<<<< HEAD timeout, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "gloo:all_reduce", inputs), inputs(inputs), @@ -289,6 +305,7 @@ class AsyncAllreduceWork : public ProcessGroupGloo::AsyncWork { const uint32_t tag; void allreduce(std::vector& tensors) { +<<<<<<< HEAD auto tensor = tensors[0]; if (tensor.is_complex()) { TORCH_CHECK( @@ -306,6 +323,13 @@ class AsyncAllreduceWork : public ProcessGroupGloo::AsyncWork { // Use tensor.numel() instead of tensors[0].numel() to // get the right number of elements when tensors[0] is complex GENERATE_ALL_TYPES(scalarType, setOutputs, opts, tensors, tensor.numel()); +======= + const auto& scalarType = tensors[0].scalar_type(); + gloo::AllreduceOptions opts(context_); + opts.setReduceFunction(getFunction(scalarType, reduceOp)); + opts.setTag(tag); + GENERATE_ALL_TYPES(scalarType, setOutputs, opts, tensors); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) gloo::allreduce(opts); // Gloo doesn't support AVG so we use SUM + division. @@ -347,6 +371,7 @@ class AsyncAllreduceCoalescedWork : public AsyncAllreduceWork { std::vector& inputs, ReduceOp reduceOp, uint32_t tag, +<<<<<<< HEAD uint64_t seq, std::chrono::milliseconds timeout) : AsyncAllreduceWork( @@ -356,6 +381,10 @@ class AsyncAllreduceCoalescedWork : public AsyncAllreduceWork { tag, seq, timeout) {} +======= + uint64_t seq) + : AsyncAllreduceWork(context, inputs, std::move(reduceOp), tag, seq) {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void run() override { allreduceCoalesced(inputs); @@ -386,14 +415,21 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork { std::shared_ptr context, std::vector& inputs, uint32_t tag, +<<<<<<< HEAD uint64_t seq, std::chrono::milliseconds timeout) +======= + uint64_t seq) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) : ProcessGroupGloo::AsyncWork( std::move(context), {inputs}, OpType::_ALLREDUCE_SPARSE, seq, +<<<<<<< HEAD timeout, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "gloo:sparse_all_reduce", inputs), inputs(inputs), @@ -568,7 +604,10 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork { gloo::AllgatherOptions opts(context_); opts.setOutput(buffer.mutable_data_ptr(), buffer.numel()); opts.setTag(tag); +<<<<<<< HEAD opts.setTimeout(timeout_); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) gloo::allgather(opts); return metadata; @@ -600,7 +639,10 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork { input.numel()); opts.setOutput(output.mutable_data_ptr(), counts); opts.setTag(tag); +<<<<<<< HEAD opts.setTimeout(timeout_); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) gloo::allgatherv(opts); // Compile indices tensor per rank. @@ -646,7 +688,10 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork { GENERATE_ALL_TYPES( valueTensor.scalar_type(), setOutput, opts, output, counts); opts.setTag(tag); +<<<<<<< HEAD opts.setTimeout(timeout_); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) gloo::allgatherv(opts); // Compile values tensor per rank. diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 768cb3b14fab6..98314c4344162 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -68,6 +68,26 @@ inline bool isUnsupportedFloat8(at::ScalarType t) { ); } +<<<<<<< HEAD +======= +bool complexViewAsRealAllowed(const ReduceOp& reduceOp) { + switch (reduceOp) { + // NOLINTNEXTLINE(bugprone-branch-clone) + case ReduceOp::SUM: + return true; + case ReduceOp::AVG: + return true; + case ReduceOp::PREMUL_SUM: + return true; + case ReduceOp::UNUSED: + return true; + default: + return false; + } + return false; +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #ifdef ENABLE_NCCL_PREMUL_SUM_SUPPORT template ncclRedOpRAII unpackPreMulSum( @@ -122,14 +142,21 @@ ncclRedOpRAII getNcclReduceOp( return unpackPreMulSum(reduceOp, comm); case ncclFloat: return unpackPreMulSum(reduceOp, comm); +<<<<<<< HEAD case ncclBfloat16: return unpackPreMulSum(reduceOp, comm); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) case ncclDouble: return unpackPreMulSum(reduceOp, comm); default: C10_THROW_ERROR( +<<<<<<< HEAD TypeError, "PreMulSum Data type must be half, float, bfloat16 or double"); +======= + TypeError, "PreMulSum Data type must be half, float, or double"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ncclRedOp_t{}; } #else @@ -203,6 +230,20 @@ void syncStream( ncclEvent.block(ncclStream); } +<<<<<<< HEAD +======= +// Given a ncclUniqueId, convert it to a string representation that can be put +// in the store. +std::string buildNcclUniqueIdStr(const ncclUniqueId& ncclID) { + const uint8_t* bytes = reinterpret_cast(&ncclID); + std::ostringstream oss; + for (const auto i : c10::irange(NCCL_UNIQUE_ID_BYTES)) { + oss << std::hex << static_cast(bytes[i]); + } + return oss.str(); +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::string getNcclAbortedCommStoreKey(const std::string& ncclIdStr) { return std::string(kNCCLAbortedCommStoreKey) + ":" + ncclIdStr; } @@ -274,12 +315,17 @@ bool shouldAllCommunicatorsRegisterAllTensors() { // - This map has also to be maintained as global variable since the register // hooks are called outside the scope of any PG, thus we need traverse // communicators in all PGs. +<<<<<<< HEAD // MemPoolSet has ids of mempools used with this communicator, and whether they // were registered with window APIs or not using MemPoolSet = std::unordered_set< std::tuple, c10::hash>>; +======= +using MemPoolSet = std:: + unordered_set>; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static std::unordered_map, MemPoolSet> ncclCommMemPoolMap; static std::mutex ncclCommMemPoolMapMutex; @@ -297,6 +343,7 @@ static void cacheAllocatorRegisterHook( std::lock_guard lock(ncclCommMemPoolMapMutex); for (auto& [ncclComm, memPools] : ncclCommMemPoolMap) { if (te.device_ == ncclComm->getDeviceIndex()) { +<<<<<<< HEAD bool symm = false; bool should_register = shouldAllCommunicatorsRegisterAllTensors(); auto it = @@ -314,6 +361,12 @@ static void cacheAllocatorRegisterHook( te.size_, /*errorOnRereg*/ false, /*window*/ symm); +======= + if (shouldAllCommunicatorsRegisterAllTensors() || + memPools.find(te.mempool_) != memPools.end()) { + // NOLINTNEXTLINE(performance-no-int-to-ptr) + ncclComm->registerSegment(reinterpret_cast(te.addr_), te.size_); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } } @@ -330,6 +383,7 @@ static void cacheAllocatorDeregisterHook( std::lock_guard lock(ncclCommMemPoolMapMutex); for (auto& [ncclComm, memPools] : ncclCommMemPoolMap) { if (te.device_ == ncclComm->getDeviceIndex()) { +<<<<<<< HEAD bool symm = false; bool should_register = shouldAllCommunicatorsRegisterAllTensors(); auto it = @@ -343,6 +397,12 @@ static void cacheAllocatorDeregisterHook( if (should_register) { // NOLINTNEXTLINE(performance-no-int-to-ptr) ncclComm->deregisterSegment(reinterpret_cast(te.addr_), symm); +======= + if (shouldAllCommunicatorsRegisterAllTensors() || + memPools.find(te.mempool_) != memPools.end()) { + // NOLINTNEXTLINE(performance-no-int-to-ptr) + ncclComm->deregisterSegment(reinterpret_cast(te.addr_)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } } @@ -383,7 +443,12 @@ static std:: } } for (auto& ncclComm : allNCCLComms) { +<<<<<<< HEAD ncclDumpMap[ncclComm->getUniqueHash()] = ncclComm->ncclCommDump(); +======= + std::string ncclUniqueIDStr = buildNcclUniqueIdStr(ncclComm->getNcclId()); + ncclDumpMap[ncclUniqueIDStr] = ncclComm->ncclCommDump(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } return ncclDumpMap; #else @@ -528,9 +593,17 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL( // DEFAULT_FLAGS = cudaEventDisableTiming. if (cudaEventCacheEnabled) { ncclStartEvent_ = enableTiming +<<<<<<< HEAD ? CUDAEventCache::get(device.index())->create(enableTiming) : nullptr; ncclEndEvent_ = CUDAEventCache::get(device.index())->create(enableTiming); +======= + ? ProcessGroupNCCL::CUDAEventCache::get(device.index()) + ->create(enableTiming) + : nullptr; + ncclEndEvent_ = ProcessGroupNCCL::CUDAEventCache::get(device.index()) + ->create(enableTiming); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else { ncclStartEvent_ = enableTiming ? std::make_shared(cudaEventDefault) @@ -867,6 +940,64 @@ void ProcessGroupNCCL::WorkNCCL::abort() { } } +<<<<<<< HEAD +======= +ProcessGroupNCCL::CUDAEventCache::CUDAEventCache() = default; + +// CUDA event is used to record the start/end of one Work. +// Instead of let the CUDA event gets destroyed, we now reuse it after the Work +// has been erased from workMetaList_. +// This is to avoid the potential deadlock caused by CudaEventDestroy. +std::shared_ptr ProcessGroupNCCL::CUDAEventCache::create( + bool timing) { + // Register the deleter as a callback when the WorkNCCL object is destroyed. + // Each deleter keeps a ref count to the cache object, so that even when + // the thread that creates the cache is gone, the cache object won't be + // destroyed until all the events in the cache are destroyed (ref number drops + // to zero). + auto deleter = [cache = shared_from_this(), + timing](at::cuda::CUDAEvent* event) { + std::lock_guard lock(cache->cacheMutex_); + // We put the event back to the cache deque once the WorkNCCL object is + // destroyed. + cache->eventsArray_[timing ? 1 : 0].push_back(event); + }; + at::cuda::CUDAEvent* event = nullptr; + { + std::lock_guard lock(cacheMutex_); + auto& events = eventsArray_[timing ? 1 : 0]; + // If we still have events in the cache, we reuse it. Otherwise, we create a + // new one. + if (!events.empty()) { + event = events.front(); + events.pop_front(); + } else { + event = new at::cuda::CUDAEvent( + timing ? cudaEventDefault : cudaEventDisableTiming); + } + } + return std::shared_ptr(event, std::move(deleter)); +} + +std::shared_ptr ProcessGroupNCCL:: + CUDAEventCache::get(at::DeviceIndex device) { + // A per-thread singleton of device-to-CUDAEventCache map. + // Map is needed because events cannot be reused across devices. + // Per-thread ownership is needed to support multi-threaded case (instead of + // multi-process case). + static thread_local std:: + map> + cacheDeviceMap; + // Check if device has already been in the map, if not, add a new entry + auto it = cacheDeviceMap.find(device); + if (it == cacheDeviceMap.end()) { + cacheDeviceMap.emplace( + device, std::make_shared()); + } + return cacheDeviceMap[device]; +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static std::atomic process_group_id = 0; constexpr const char* MULTI_DEVICE_ERROR_MSG = @@ -924,8 +1055,11 @@ ProcessGroupNCCL::ProcessGroupNCCL( TORCH_WARN_ONCE( "TORCH_NCCL_AVOID_RECORD_STREAMS is the default now, this environment variable is thus deprecated."); } +<<<<<<< HEAD showSerializationWarning_ = getCvarBool(TORCH_NCCL_SHOW_EAGER_INIT_P2P_SERIALIZATION_WARNING, true); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (blockingWait_) { LOG(INFO) @@ -977,9 +1111,14 @@ ProcessGroupNCCL::ProcessGroupNCCL( const std::string OFF = "OFF"; std::string torch_distributed_debug = getCvarString({"TORCH_DISTRIBUTED_DEBUG"}, OFF.c_str()); +<<<<<<< HEAD LOG(INFO) << logPrefix() << "ProcessGroupNCCL initialization options: " << "size: " << size << ", global rank: " << globalRank() +======= + LOG(INFO) << logPrefix() << "ProcessGroupNCCL initialization options: " + << "size: " << size << ", global rank: " << globalRank() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) << ", TIMEOUT(ms): " << options_->timeout.count() << ", USE_HIGH_PRIORITY_STREAM: " << options_->is_high_priority_stream @@ -1021,7 +1160,10 @@ void ProcessGroupNCCL::eagerConnectSingleDevice(at::Device device) { LOG(INFO) << logPrefix() << "Eagerly connecting nccl backend with device " << device; initNCCLComm(key, device, OpType::ALLREDUCE); +<<<<<<< HEAD eagerInit_ = true; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } bool ProcessGroupNCCL::useNonblocking() { @@ -1075,7 +1217,16 @@ void ProcessGroupNCCL::performNocolorSplit(at::Device device) { LOG(ERROR) << logPrefix() << "No parent communicator exists for nocolor split"; } +<<<<<<< HEAD NCCLComm::split(comm.get(), NCCL_SPLIT_NOCOLOR, rank_, options_->config); +======= + NCCLComm::split( + comm.get(), + NCCL_SPLIT_NOCOLOR, + rank_, + options_->config, + options_->global_ranks_in_group); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif // NCCL_HAS_COMM_SPLIT } @@ -1099,14 +1250,21 @@ ErrorType ProcessGroupNCCL::getError() { return error_; } +<<<<<<< HEAD void ProcessGroupNCCL::registerMemPool(c10::cuda::MemPool* pool, bool symm) { const auto key = std::to_string(pool->device()); +======= +void ProcessGroupNCCL::registerMemPool(c10::cuda::MemPool* pool) { + const auto key = std::to_string(pool->device()); + auto device = at::Device(at::DeviceType::CUDA, pool->device()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) LOG(INFO) << logPrefix() << "Performing NCCL user buffer registration for all buffers in " << "MemPool: " << pool->id() << ", device index: " << key << ", i am " << this; auto ncclComm = getNCCLComm(key); if (ncclComm == nullptr) { +<<<<<<< HEAD C10_THROW_ERROR( DistBackendError, "NCCL communicator has not been initialized before mem pool creation. You can pass `device_id` to init_process_group -- one way of eager initialization -- to work around this issue"); @@ -1115,11 +1273,35 @@ void ProcessGroupNCCL::registerMemPool(c10::cuda::MemPool* pool, bool symm) { std::lock_guard lock(ncclCommMemPoolMapMutex); auto iter = ncclCommMemPoolMap.find(ncclComm); iter->second.insert(std::make_tuple(pool->id(), symm)); +======= + // HACK: currently we are using this function for NVLS + // reductions, and that's why using OpType::ALLREDUCE. + // If we end up using this API for zero-copy P2P, we might + // need to refactor and account for different OpType. + ncclComm = initNCCLComm(key, device, OpType::ALLREDUCE); + } + TORCH_INTERNAL_ASSERT(ncclComm != nullptr); + { + std::lock_guard lock(ncclCommMemPoolMapMutex); + auto iter = ncclCommMemPoolMap.find(ncclComm); + iter->second.insert(pool->id()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // We must ensure we're listening for allocator trace events in order to // register future segments allocated in this pool (this call is idempotent). attachAllocatorHooks(); auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(pool->id()); +<<<<<<< HEAD +======= + // TODO: + // if(pool->is_symmetric()) { + // Allgather to verify len(mempool.snapshot.segments) matches across GPUs + // Allgather to verify mempool.alloc_request_counter matches across GPUs + // add alloc_request_counter per mempool (How many allocations a mempool has + // served during its lifetime) this should guarantee pool is used in a + // symmetric/SPMD manner + // } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (const auto& segmentInfo : snapshot.segments) { TORCH_INTERNAL_ASSERT( segmentInfo.device == pool->device(), @@ -1129,18 +1311,28 @@ void ProcessGroupNCCL::registerMemPool(c10::cuda::MemPool* pool, bool symm) { reinterpret_cast(segmentInfo.address), segmentInfo.total_size, /*errorOnRereg=*/false, // ignores reregistration error +<<<<<<< HEAD /*window*/ symm); // whether to use NCCL symmetric memory +======= + /*window=*/pool->is_symmetric()); // whether to use NCCL symmetric + // memory +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } void ProcessGroupNCCL::deregisterMemPool(c10::cuda::MemPool* pool) { const auto key = std::to_string(pool->device()); +<<<<<<< HEAD +======= + auto device = at::Device(at::DeviceType::CUDA, pool->device()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) LOG(INFO) << logPrefix() << "Performing NCCL user buffer deregistration for all buffers in " << "MemPool: " << pool->id() << ", device index: " << key << ", i am " << this; auto ncclComm = getNCCLComm(key); if (ncclComm == nullptr) { +<<<<<<< HEAD C10_THROW_ERROR( DistBackendError, "NCCL communicator has not been initialized before mem pool creation. You can pass `device_id` to init_process_group -- one way of eager initialization -- to work around this issue"); @@ -1158,6 +1350,19 @@ void ProcessGroupNCCL::deregisterMemPool(c10::cuda::MemPool* pool) { "Trying to unregister not previously registered pool"); symm = std::get<1>(*mempool_it); iter->second.erase(mempool_it); +======= + // HACK: currently we are using this function for NVLS + // reductions, and that's why using OpType::ALLREDUCE. + // If we end up using this API for zero-copy P2P, we might + // need to refactor and account for different OpType. + ncclComm = initNCCLComm(key, device, OpType::ALLREDUCE); + } + TORCH_INTERNAL_ASSERT(ncclComm != nullptr); + { + std::lock_guard lock(ncclCommMemPoolMapMutex); + auto iter = ncclCommMemPoolMap.find(ncclComm); + iter->second.erase(pool->id()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(pool->id()); for (const auto& segmentInfo : snapshot.segments) { @@ -1166,7 +1371,11 @@ void ProcessGroupNCCL::deregisterMemPool(c10::cuda::MemPool* pool) { "Mismatch between CUDA memory segment device and pool's device"); // NOLINTNEXTLINE(performance-no-int-to-ptr) ncclComm->deregisterSegment( +<<<<<<< HEAD reinterpret_cast(segmentInfo.address), symm); +======= + reinterpret_cast(segmentInfo.address), pool->is_symmetric()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } @@ -1256,6 +1465,7 @@ void ProcessGroupNCCL::enableCollectivesTiming() { enableTiming_.store(true); } +<<<<<<< HEAD c10::intrusive_ptr ProcessGroupNCCL::split( const c10::intrusive_ptr& store, const std::vector& ranks, @@ -1308,6 +1518,8 @@ c10::intrusive_ptr ProcessGroupNCCL::merge( return c10::static_intrusive_pointer_cast(pg); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool ProcessGroupNCCL::waitForFutureOrTimeout( std::future& fut, const std::chrono::milliseconds& timeOutMilSec, @@ -1703,8 +1915,11 @@ void ProcessGroupNCCL::HeartbeatMonitor::join() { void ProcessGroupNCCL::HeartbeatMonitor::runLoop() { c10::setThreadName("pt_nccl_heartbt"); +<<<<<<< HEAD STATIC_SCOPED_WAIT_COUNTER( pytorch.ProcessGroupNCCL__HeartbeatMonitor__runLoop); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uint64_t heartBeatCounter = 0ULL; std::string errorMsg; @@ -2032,7 +2247,10 @@ void ProcessGroupNCCL::Watchdog::join() { void ProcessGroupNCCL::Watchdog::run() { c10::setThreadName("pt_nccl_watchdg"); +<<<<<<< HEAD STATIC_SCOPED_WAIT_COUNTER(pytorch.ProcessGroupNCCL__Watchdog__run); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try { VLOG(2) << pg_->logPrefix() << "Process group watchdog thread started!"; @@ -2286,10 +2504,13 @@ void ProcessGroupNCCL::Watchdog::runLoop() { // Work status logging for desync debug desyncDebugger_.logWorkStart(work); +<<<<<<< HEAD // allow watchdog to do an event query on a side thread at::cuda::CUDAGuard device_guard(work.ncclEndEvent_->device_index()); at::cuda::CUDAStreamCaptureModeGuard g{cudaStreamCaptureModeThreadLocal}; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // a work could be started but not completed, so we should not update // lastStartedSeq and lastStartedOpName if the work state is checked // multiple times after the start @@ -2301,6 +2522,13 @@ void ProcessGroupNCCL::Watchdog::runLoop() { pg_->pgStatus_->lastStartedNumelOut = work.numelOut_; } +<<<<<<< HEAD +======= + // allow watchdog to do an event query on a side thread + at::cuda::CUDAGuard device_guard(work.ncclEndEvent_->device_index()); + at::cuda::CUDAStreamCaptureModeGuard g{cudaStreamCaptureModeThreadLocal}; + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Clean up completed work if (work.isCompleted()) { // In case user didn't call `work.wait()` with async collectives, @@ -3011,7 +3239,15 @@ std::shared_ptr ProcessGroupNCCL::initNCCLComm( LOG(INFO) << logPrefix() << "Splitting NCCL communicator from " << parentComm->repr(); ncclComm = NCCLComm::split( +<<<<<<< HEAD parentComm.get(), options_->split_color, rank, options_->config); +======= + parentComm.get(), + options_->split_color, + rank, + options_->config, + options_->global_ranks_in_group); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } } @@ -3217,6 +3453,7 @@ void check_gpu_single_tensor( if (!tensor.is_cuda() || tensor.is_sparse()) { C10_THROW_ERROR(ValueError, "Tensors must be CUDA and dense"); } +<<<<<<< HEAD // Check memory format if (!tensor.is_contiguous(tensor.suggest_memory_format())) { // P2P is a bit relaxed, supporting transfer of a transposed tensor @@ -3226,6 +3463,11 @@ void check_gpu_single_tensor( C10_THROW_ERROR( ValueError, "Tensors for P2P must be non-overlapping and dense"); } +======= + // Skip the following requirements for P2P operations + if (!tensor.is_contiguous(tensor.suggest_memory_format())) { + if (p2p) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_WARN_ONCE( "Detected non-contiguous tensor in P2P operations. It is user " "responsibility to guarantee that source and destination tensors have " @@ -3948,6 +4190,7 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( at::cuda::OptionalCUDAGuard gpuGuard(device); std::string key; +<<<<<<< HEAD int p2pRank = -1, p2pTargetRank = -1; bool isSendRecvSelf = rank_ == peer; // For batch_isend_irecv, ncclGroupStart() would be called upfront @@ -4014,6 +4257,25 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( p2pRank = rank_ <= peer ? 0 : 1; p2pTargetRank = isSendRecvSelf ? 0 : 1 - p2pRank; ncclComm = getNCCLComm(key); +======= + int p2pRank = 0, p2pTargetRank = 0; + bool isSendRecvSelf = false; + // For batch_isend_irecv, ncclGroupStart() would be called upfront + bool batchP2P = ncclActiveGroupCounter_ > 0; + if (batchP2P) { + // For batch P2P, we need to treat it like a collective when selecting + // communicator, because other ranks can call into this batch other than my + // rank and my peer + key = getKeyFromDevice(device); + p2pRank = rank_; + p2pTargetRank = peer; + } else { + // For single P2P, preserve the old two-rank behavior (to avoid perf diff) + key = getKeySendRecv(rank_, peer); + p2pRank = rank_ <= peer ? 0 : 1; + isSendRecvSelf = rank_ == peer; + p2pTargetRank = isSendRecvSelf ? 0 : 1 - p2pRank; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (!coalescing_state_) { // Bump P2P sequence number. @@ -4025,6 +4287,7 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( // coalesced or individual op_id_++; +<<<<<<< HEAD if (ncclComm == nullptr) { // ncclComm should never be a nullptr in eager init mode. // For lazy init mode, isSendRecvSelf is only valid for non-batch @@ -4032,6 +4295,11 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( // argument to be false. ncclComm = initNCCLComm(key, device, opType, p2pRank, isSendRecvSelf && !batchP2P); +======= + std::shared_ptr ncclComm = getNCCLComm(key); + if (ncclComm == nullptr) { + ncclComm = initNCCLComm(key, device, opType, p2pRank, isSendRecvSelf); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } if (coalescing_state_ & CoalActive) { @@ -4392,7 +4660,11 @@ c10::intrusive_ptr ProcessGroupNCCL::allreduce( auto tensor = tensors.back(); if (tensor.is_complex()) { TORCH_CHECK( +<<<<<<< HEAD c10d::isComplexViewAsRealAllowed(opts.reduceOp), +======= + complexViewAsRealAllowed(opts.reduceOp), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "all_reduce does not support", opts.reduceOp, "on complex tensors"); @@ -4586,7 +4858,11 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce( auto tensor = tensors.back(); if (tensor.is_complex()) { TORCH_CHECK( +<<<<<<< HEAD c10d::isComplexViewAsRealAllowed(opts.reduceOp), +======= + complexViewAsRealAllowed(opts.reduceOp), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "reduce does not support", opts.reduceOp, "on complex tensors"); @@ -5064,12 +5340,23 @@ c10::DeviceIndex ProcessGroupNCCL::guessDeviceId() const { // offset wrt the device id if intra-node GPUs are sharded into multiple // dimensions. int devIdx = globalRank() % localDeviceCount_; +<<<<<<< HEAD if (devIdx == 0) { // only log on first rank of each node LOG(WARNING) << c10::str( "Guessing device ID based on global rank. ", "This can cause a hang if rank to GPU mapping is heterogeneous. ", "You can specify device_id in init_process_group()"); } +======= + LOG(WARNING) + << logPrefix() + << c10::str( + " using GPU ", + devIdx, + " as device used by this process is currently unknown. ", + "This can potentially cause a hang if this rank to GPU mapping is incorrect. ", + "You can specify device_id in init_process_group() to force use of a particular device."); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return static_cast(devIdx); } @@ -5436,7 +5723,10 @@ c10::intrusive_ptr ProcessGroupNCCL::gather( TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); auto inputTensor = inputTensors.back(); +<<<<<<< HEAD check_gpu_single_tensor(inputTensor); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::vector outputs; @@ -5758,7 +6048,11 @@ at::Tensor ProcessGroupNCCL::allocateTensor( // Pool is created memPool_ = std::make_unique(allocator); // Register so that we call ncclCommRegister on all new allocations +<<<<<<< HEAD registerMemPool(memPool_.get(), /*symmetric*/ false); +======= + registerMemPool(memPool_.get()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) LOG(INFO) << logPrefix() << "Created memory pool"; } diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index f7a3a28caceb3..95424d55f805c 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -23,7 +23,10 @@ #include #include #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include @@ -44,11 +47,14 @@ namespace c10d { static std::vector TORCH_NCCL_BCAST_UNIQUEID = { "TORCH_NCCL_BCAST_UNIQUEID"}; +<<<<<<< HEAD // Control EagerInit P2P serialization warning static std::vector TORCH_NCCL_SHOW_EAGER_INIT_P2P_SERIALIZATION_WARNING = { "TORCH_NCCL_SHOW_EAGER_INIT_P2P_SERIALIZATION_WARNING"}; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Control whether to always use high priority streams static std::vector TORCH_NCCL_HIGH_PRIORITY = { "TORCH_NCCL_HIGH_PRIORITY"}; @@ -350,10 +356,13 @@ class TORCH_API ProcessGroupNCCL : public Backend { // or timed out. If timeout, exception will be thrown. bool wait(std::chrono::milliseconds timeout = kNoTimeout) override; +<<<<<<< HEAD void blockCurrentStream() override { synchronize(); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void abort() override; // Let current stream wait on the completion of the NCCL work @@ -448,8 +457,13 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Record collective sizes for debug. We only record the size on the first // device as multi-device per process is deprecated +<<<<<<< HEAD size_t numelIn_ = 0; size_t numelOut_ = 0; +======= + size_t numelIn_ = -1; + size_t numelOut_ = -1; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Wrapper method for the static checkForNCCLErrors which can be overridden // for tests. @@ -504,6 +518,26 @@ class TORCH_API ProcessGroupNCCL : public Backend { friend class ProcessGroupNCCL; }; +<<<<<<< HEAD +======= + class CUDAEventCache + : public std::enable_shared_from_this { + public: + CUDAEventCache(); + std::shared_ptr create(bool timing); + static std::shared_ptr get( + at::DeviceIndex device); + + private: + std::mutex cacheMutex_; + // NOTE: We intentionally store raw pointers so that + // we do not attempt to destroy the event objects on process exit, + // because cuda may be gone. + std::array, 2> + eventsArray_; // 0 for timing=false, 1 for timing=true + }; + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) struct Options : Backend::Options { // NOTE: timeout in ProcessGroupNCCL::Options denote the timeout for // operations. This is only used when blockingWait_ is enabled. @@ -525,7 +559,11 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Optional "parent" backend and color to create communicators from // via `ncclCommSplit` +<<<<<<< HEAD c10::intrusive_ptr split_from; +======= + std::shared_ptr split_from; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Color to use for `ncclCommSplit`, values: // * Non-negative value: in group; // * NCCL_SPLIT_NOCOLOR (-1): not in group; @@ -545,6 +583,11 @@ class TORCH_API ProcessGroupNCCL : public Backend { // the int value of `NCCL_SPLIT_NOCOLOR` (-1) instead. int split_color{-2}; #endif +<<<<<<< HEAD +======= + std::vector global_ranks_in_group; + std::string group_name; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; // Helper class related to TORCH_NCCL_DESYNC_DEBUG @@ -786,10 +829,13 @@ class TORCH_API ProcessGroupNCCL : public Backend { return options_; } +<<<<<<< HEAD c10::intrusive_ptr getBackendOptions() override { return c10::static_intrusive_pointer_cast(options_); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const std::string getBackendName() const override { return std::string(NCCL_BACKEND_NAME); } @@ -810,10 +856,13 @@ class TORCH_API ProcessGroupNCCL : public Backend { #endif } +<<<<<<< HEAD void setTimeout(std::chrono::milliseconds timeout) override { options_->timeout = timeout; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void startCoalescing() override; c10::intrusive_ptr endCoalescing() override; @@ -958,6 +1007,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { void enableCollectivesTiming() override; +<<<<<<< HEAD c10::intrusive_ptr split( const c10::intrusive_ptr& store, const std::vector& ranks, @@ -969,6 +1019,8 @@ class TORCH_API ProcessGroupNCCL : public Backend { const int& rank, const int& size) override; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Helper function for iteratively aborting communicators in the provided map void abortCommsFromMap( std::unordered_map>& ncclCommsMap, @@ -1002,7 +1054,11 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Performs NCCL user buffer registration for all buffers in // the given MemPool +<<<<<<< HEAD void registerMemPool(c10::cuda::MemPool* pool, bool symm = false); +======= + void registerMemPool(c10::cuda::MemPool* pool); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Performs NCCL user buffer de-registration for all buffers in // the given MemPool @@ -1092,10 +1148,13 @@ class TORCH_API ProcessGroupNCCL : public Backend { int globalRankStart_; int globalRankStride_; +<<<<<<< HEAD private: bool eagerInit_{false}; bool showSerializationWarning_{true}; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Helper that encapsulates work shared across all collective communication // primitives. The callbacks have the following signatures: // diff --git a/torch/csrc/distributed/c10d/PyProcessGroup.hpp b/torch/csrc/distributed/c10d/PyProcessGroup.hpp index afec6bbe11a9a..b8b4da299aa5a 100644 --- a/torch/csrc/distributed/c10d/PyProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/PyProcessGroup.hpp @@ -151,6 +151,7 @@ class PyProcessGroup : public ProcessGroup { group_desc); } +<<<<<<< HEAD c10::intrusive_ptr splitGroup( const std::vector& ranks, const std::optional& timeout, @@ -181,6 +182,8 @@ class PyProcessGroup : public ProcessGroup { size); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) c10::intrusive_ptr allgather( std::vector>& outputTensors, std::vector& inputTensors, diff --git a/torch/csrc/distributed/c10d/TCPStore.cpp b/torch/csrc/distributed/c10d/TCPStore.cpp index f944815b99fa3..224b608f14721 100644 --- a/torch/csrc/distributed/c10d/TCPStore.cpp +++ b/torch/csrc/distributed/c10d/TCPStore.cpp @@ -423,6 +423,7 @@ void TCPStore::ping() { buffer.flush(); uint32_t returnedNonce = client_->receiveValue(); +<<<<<<< HEAD if (nonce != returnedNonce) { C10_THROW_ERROR( DistNetworkError, @@ -431,6 +432,10 @@ void TCPStore::ping() { nonce, returnedNonce)); } +======= + TORCH_INTERNAL_ASSERT( + nonce == returnedNonce, "Ping failed, invalid nonce returned"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } void TCPStore::_splitSet( diff --git a/torch/csrc/distributed/c10d/Types.hpp b/torch/csrc/distributed/c10d/Types.hpp index 18db14f5cef04..201b30f27d364 100644 --- a/torch/csrc/distributed/c10d/Types.hpp +++ b/torch/csrc/distributed/c10d/Types.hpp @@ -110,8 +110,11 @@ ReduceOp makeNCCLPreMulSum(const T& factor) { return rop; } +<<<<<<< HEAD TORCH_API bool isComplexViewAsRealAllowed(const ReduceOp& reduceOp); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) constexpr auto kUnsetTimeout = std::chrono::milliseconds(-1); struct BroadcastOptions { diff --git a/torch/csrc/distributed/c10d/Utils.hpp b/torch/csrc/distributed/c10d/Utils.hpp index c7a2e3523ae4d..654b7329f9cb8 100644 --- a/torch/csrc/distributed/c10d/Utils.hpp +++ b/torch/csrc/distributed/c10d/Utils.hpp @@ -444,9 +444,13 @@ inline at::Tensor newLikeFlat( sizes, strides, t.options().memory_format(std::nullopt)); } +<<<<<<< HEAD inline at::Tensor newLikeFlat( std::vector& tensors, bool preserve_strides = true) { +======= +inline at::Tensor newLikeFlat(std::vector& tensors) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (tensors.empty()) { TORCH_CHECK(false, "Received an empty list"); } @@ -454,6 +458,7 @@ inline at::Tensor newLikeFlat( at::DeviceGuard gpuGuard(t.device()); std::vector sizes{static_cast(tensors.size())}; sizes.insert(sizes.end(), t.sizes().begin(), t.sizes().end()); +<<<<<<< HEAD if (t.is_contiguous() || !preserve_strides) { // we are checking for memory format, so tensor might // not be contiguous @@ -468,6 +473,9 @@ inline at::Tensor newLikeFlat( strides.insert(strides.end(), t.strides().begin(), t.strides().end()); return at::empty_strided(sizes, strides, t.options()); } +======= + return at::empty(sizes, t.options()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } inline std::vector> getSizes( diff --git a/torch/csrc/distributed/c10d/Work.cpp b/torch/csrc/distributed/c10d/Work.cpp index cdec9185ce537..e10dafac697f4 100644 --- a/torch/csrc/distributed/c10d/Work.cpp +++ b/torch/csrc/distributed/c10d/Work.cpp @@ -1,6 +1,9 @@ #include #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include @@ -101,6 +104,7 @@ bool Work::wait(std::chrono::milliseconds timeout) { return true; } +<<<<<<< HEAD void Work::blockCurrentStream() { // block cuda stream indefinitely until work is completed. std::shared_ptr handle = @@ -110,6 +114,8 @@ void Work::blockCurrentStream() { [handle](c10::ivalue::Future& future) { handle->abort(); }); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void Work::abort() { TORCH_CHECK(false, "Work::abort not implemented."); } diff --git a/torch/csrc/distributed/c10d/Work.hpp b/torch/csrc/distributed/c10d/Work.hpp index 3b743e36d2a05..2e0e85201e297 100644 --- a/torch/csrc/distributed/c10d/Work.hpp +++ b/torch/csrc/distributed/c10d/Work.hpp @@ -110,6 +110,7 @@ class TORCH_API Work : public torch::CustomClassHolder { // virtual bool wait(std::chrono::milliseconds timeout = kNoTimeout); +<<<<<<< HEAD // Blocks the current stream until the work is completed. // This is equivalent to synchronize for CUDA tensors but works for both CPU // tensors and CUDA tensors by using a spinlock CUDA kernel. @@ -117,6 +118,8 @@ class TORCH_API Work : public torch::CustomClassHolder { // If no stream is active it will throw an error. virtual void blockCurrentStream(); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) virtual void abort(); // Returns a Future object that will be associated with the completion of diff --git a/torch/csrc/distributed/c10d/control_plane/Handlers.cpp b/torch/csrc/distributed/c10d/control_plane/Handlers.cpp index 973197ded14fc..e011d4ecf9c56 100644 --- a/torch/csrc/distributed/c10d/control_plane/Handlers.cpp +++ b/torch/csrc/distributed/c10d/control_plane/Handlers.cpp @@ -4,10 +4,14 @@ #include #include #include +<<<<<<< HEAD #include #include #include #include +======= +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace c10d::control_plane { diff --git a/torch/csrc/distributed/c10d/cuda/AsyncMM.cu b/torch/csrc/distributed/c10d/cuda/AsyncMM.cu index 76f58b8338615..99cdc2b7f2dba 100644 --- a/torch/csrc/distributed/c10d/cuda/AsyncMM.cu +++ b/torch/csrc/distributed/c10d/cuda/AsyncMM.cu @@ -8,7 +8,10 @@ // Two warnings in Cutlass included header files C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wset-but-not-used") C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter") +<<<<<<< HEAD C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-variable") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #if !defined(USE_ROCM) && !defined(_WIN32) && defined(CUDA_VERSION) && \ CUDA_VERSION >= 12000 @@ -40,7 +43,10 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-variable") C10_DIAGNOSTIC_POP() C10_DIAGNOSTIC_POP() +<<<<<<< HEAD C10_DIAGNOSTIC_POP() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace { @@ -151,7 +157,11 @@ at::Tensor async_input_mm_impl( reinterpret_cast(b.data_ptr()), stride_B, }, +<<<<<<< HEAD {{}, +======= + {{1, 1}, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) reinterpret_cast(out.data_ptr()), stride_C, reinterpret_cast(out.data_ptr()), diff --git a/torch/csrc/distributed/c10d/cuda/utils.cpp b/torch/csrc/distributed/c10d/cuda/utils.cpp index 44d5242e1401d..c9c0387fec86a 100644 --- a/torch/csrc/distributed/c10d/cuda/utils.cpp +++ b/torch/csrc/distributed/c10d/cuda/utils.cpp @@ -22,7 +22,11 @@ bool deviceSupportsMulticast(int device_idx) { // - Device support: Determined by querying // CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED at runtime. auto driver_api = c10::cuda::DriverAPI::get(); +<<<<<<< HEAD int multicast_supported = 0; +======= + int multicast_supported; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) C10_CUDA_DRIVER_CHECK(driver_api->cuDeviceGetAttribute_( &multicast_supported, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 0189326683585..f6723764cb61a 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1126,11 +1126,14 @@ This class does not support ``__members__`` property.)"); .def_static( "has_multicast_support", &::c10d::symmetric_memory::has_multicast_support) +<<<<<<< HEAD .def_static("set_backend", &::c10d::symmetric_memory::set_backend) .def_static("get_backend", &::c10d::symmetric_memory::get_backend) .def_static( "get_mempool_allocator", &::c10d::symmetric_memory::get_mempool_allocator) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .def_property_readonly("rank", &SymmetricMemory::get_rank) .def_property_readonly("world_size", &SymmetricMemory::get_world_size) .def_property_readonly( @@ -1170,7 +1173,10 @@ This class does not support ``__members__`` property.)"); .def_property_readonly("buffer_size", &SymmetricMemory::get_buffer_size) .def_property_readonly( "signal_pad_size", &SymmetricMemory::get_signal_pad_size) +<<<<<<< HEAD .def_property_readonly("offset", &SymmetricMemory::get_offset) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .def( "get_buffer", &SymmetricMemory::get_buffer, @@ -1202,12 +1208,15 @@ This class does not support ``__members__`` property.)"); py::arg("src_rank"), py::arg("channel") = 0, py::arg("timeout_ms") = 0) +<<<<<<< HEAD .def( "get_remote_tensor", &SymmetricMemory::get_remote_tensor, py::arg("peer"), py::arg("sizes"), py::arg("dtype")) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Util functions that are often used together with symmetric memory but // not necessarily directly on symmetric memory. .def_static( @@ -2073,6 +2082,7 @@ communication mechanism. .def("rank", &::c10d::ProcessGroup::getRank, R"(Get the rank of this process group.)") .def("size", &::c10d::ProcessGroup::getSize, R"(Get the size of this process group.)") .def("name", &::c10d::ProcessGroup::getBackendName, R"(Get the name of this process group.)") +<<<<<<< HEAD .def("get_group_store", &::c10d::ProcessGroup::getStore, R"(Get the store of this process group.)") .def( "split_group", @@ -2103,6 +2113,8 @@ communication mechanism. py::arg("group_name") = std::nullopt, py::arg("group_desc") = std::nullopt, py::call_guard()) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .def( "abort", &::c10d::ProcessGroup::abort, @@ -2126,22 +2138,35 @@ communication mechanism. py::call_guard(), R"(Broadcasts the tensor to all processes in the process group. +<<<<<<< HEAD See :func:`torch.distributed.broadcast` for more details.)") +======= + See :func:`torch.distributed.broadcast for more details.)") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .def( "broadcast", [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, at::Tensor& x, +<<<<<<< HEAD int rootRank, std::optional timeout) { ::c10d::BroadcastOptions opts; opts.rootRank = rootRank; opts.timeout = timeout.value_or(::c10d::kUnsetTimeout); +======= + int rootRank) { + ::c10d::BroadcastOptions opts; + opts.rootRank = rootRank; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::vector tensors = {x}; return self->broadcast(tensors, opts); }, py::arg("tensor"), py::arg("root"), +<<<<<<< HEAD py::arg("timeout") = std::nullopt, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py::call_guard(), R"(Broadcasts the tensor to all processes in the process group. @@ -2159,16 +2184,25 @@ communication mechanism. "allreduce", [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, std::vector& xs, +<<<<<<< HEAD const ::c10d::ReduceOp& op, std::optional timeout) { ::c10d::AllreduceOptions opts; opts.reduceOp = op; opts.timeout = timeout.value_or(::c10d::kUnsetTimeout); +======= + const ::c10d::ReduceOp& op) { + ::c10d::AllreduceOptions opts; + opts.reduceOp = op; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self->allreduce(xs, opts); }, py::arg("tensors"), py::arg("op") = ::c10d::ReduceOp::SUM, +<<<<<<< HEAD py::arg("timeout") = std::nullopt, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py::call_guard(), R"(Allreduces the provided tensors across all processes in the process group. @@ -2178,17 +2212,26 @@ communication mechanism. "allreduce", [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, at::Tensor& x, +<<<<<<< HEAD const ::c10d::ReduceOp& op, std::optional timeout) { ::c10d::AllreduceOptions opts; opts.reduceOp = op; opts.timeout = timeout.value_or(::c10d::kUnsetTimeout); +======= + const ::c10d::ReduceOp& op) { + ::c10d::AllreduceOptions opts; + opts.reduceOp = op; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::vector xs = {x}; return self->allreduce(xs, opts); }, py::arg("tensor"), py::arg("op") = ::c10d::ReduceOp::SUM, +<<<<<<< HEAD py::arg("timeout") = std::nullopt, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py::call_guard(), R"(Allreduces the provided tensors across all processes in the process group. @@ -2218,19 +2261,29 @@ communication mechanism. [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, at::Tensor& x, int rootRank, +<<<<<<< HEAD const ::c10d::ReduceOp& op, std::optional timeout) { ::c10d::ReduceOptions opts; opts.reduceOp = op; opts.rootRank = rootRank; opts.timeout = timeout.value_or(::c10d::kUnsetTimeout); +======= + const ::c10d::ReduceOp& op) { + ::c10d::ReduceOptions opts; + opts.reduceOp = op; + opts.rootRank = rootRank; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::vector xs = {x}; return self->reduce(xs, opts); }, py::arg("tensor"), py::arg("root"), py::arg("op") = ::c10d::ReduceOp::SUM, +<<<<<<< HEAD py::arg("timeout") = std::nullopt, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py::call_guard(), R"(Reduces the provided tensors across all processes in the process group. @@ -2249,6 +2302,7 @@ communication mechanism. "allgather", [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, std::vector& output, +<<<<<<< HEAD at::Tensor& input, std::optional timeout) { std::vector> outputs = {output}; @@ -2264,6 +2318,20 @@ communication mechanism. R"(Allgathers the input tensors from all processes across the process group. See :func:`torch.distributed.all_gather` for more details.)") +======= + at::Tensor& input) { + std::vector> outputs = {output}; + std::vector inputs = {input}; + return self->allgather( + outputs, inputs, ::c10d::AllgatherOptions()); + }, + py::arg("output_tensors"), + py::arg("input_tensor"), + py::call_guard(), + R"(Allgathers the input tensors from all processes across the process group. + + See :func:`torch.distributed.all_gather: for more details.)") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .def( "_allgather_base", &::c10d::ProcessGroup::_allgather_base, @@ -2307,6 +2375,7 @@ communication mechanism. [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, std::vector& output, at::Tensor& input, +<<<<<<< HEAD int rootRank, std::optional timeout) { ::c10d::GatherOptions opts; @@ -2316,13 +2385,22 @@ communication mechanism. if (!output.empty()) { outputs.push_back(output); } +======= + int rootRank) { + ::c10d::GatherOptions opts; + opts.rootRank = rootRank; + std::vector> outputs = {output}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::vector inputs = {input}; return self->gather(outputs, inputs, opts); }, py::arg("output_tensors"), py::arg("input_tensor"), py::arg("root"), +<<<<<<< HEAD py::arg("timeout") = std::nullopt, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py::call_guard(), R"(Gathers the input tensors from all processes across the process group. @@ -2342,6 +2420,7 @@ communication mechanism. [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, at::Tensor& output, std::vector& input, +<<<<<<< HEAD int rootRank, std::optional timeout) { ::c10d::ScatterOptions opts; @@ -2351,13 +2430,22 @@ communication mechanism. if (!input.empty()) { inputs.push_back(input); } +======= + int rootRank) { + ::c10d::ScatterOptions opts; + opts.rootRank = rootRank; + std::vector> inputs = {input}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::vector outputs = {output}; return self->scatter(outputs, inputs, opts); }, py::arg("output_tensor"), py::arg("input_tensors"), py::arg("root"), +<<<<<<< HEAD py::arg("timeout") = std::nullopt, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py::call_guard(), R"(Scatters the input tensors from all processes across the process group. @@ -2377,19 +2465,29 @@ communication mechanism. [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, at::Tensor& output, std::vector& input, +<<<<<<< HEAD const ::c10d::ReduceOp& op, std::optional timeout) { +======= + const ::c10d::ReduceOp& op) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::vector outputs = {output}; std::vector> inputs = {input}; ::c10d::ReduceScatterOptions opts; opts.reduceOp = op; +<<<<<<< HEAD opts.timeout = timeout.value_or(::c10d::kUnsetTimeout); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self->reduce_scatter(outputs, inputs, opts); }, py::arg("output"), py::arg("input"), py::arg("op") = ::c10d::ReduceOp::SUM, +<<<<<<< HEAD py::arg("timeout") = std::nullopt, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py::call_guard(), R"(Reduces and scatters the input tensors from all processes across the process group. @@ -2422,6 +2520,7 @@ communication mechanism. py::call_guard(), R"(Alltoalls the input tensors from all processes across the process group. +<<<<<<< HEAD See :func:`torch.distributed.all_to_all` for more details.)") .def( "alltoall_base", @@ -2444,6 +2543,9 @@ communication mechanism. R"(Alltoalls the input tensors from all processes across the process group. See :func:`torch.distributed.all_to_all` for more details.)") +======= + See :func:`torch.distributed.all_to_all for more details.)") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .def( "alltoall", &::c10d::ProcessGroup::alltoall, @@ -2491,6 +2593,7 @@ communication mechanism. See :func:`torch.distributed.barrier` for more details.)") .def( +<<<<<<< HEAD "barrier", [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, std::optional timeout) { @@ -2505,6 +2608,8 @@ communication mechanism. See :func:`torch.distributed.barrier` for more details.)") .def( +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "_set_sequence_number_for_group", &::c10d::ProcessGroup::setSequenceNumberForGroup, py::call_guard()) @@ -2515,6 +2620,7 @@ communication mechanism. .def( "monitored_barrier", [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, +<<<<<<< HEAD const std::optional& timeout, bool waitAllRanks) { ::c10d::BarrierOptions opts; @@ -2522,18 +2628,30 @@ communication mechanism. return self->monitoredBarrier(opts, waitAllRanks); }, py::arg("timeout") = std::nullopt, +======= + const std::chrono::milliseconds& timeout, + bool waitAllRanks) { + ::c10d::BarrierOptions opts; + opts.timeout = timeout; + return self->monitoredBarrier(opts, waitAllRanks); + }, + py::arg("timeout") = ::c10d::kUnsetTimeout, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py::arg("wait_all_ranks") = false, py::call_guard(), R"(Blocks until all processes in the group enter the call, and then all leave the call together. See :func:`torch.distributed.monitored_barrier` for more details.)") +<<<<<<< HEAD .def( "set_timeout", &::c10d::ProcessGroup::setTimeout, py::arg("timeout"), py::call_guard(), R"(Sets the default timeout for all future operations.)") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .def_property_readonly( "_device_types", &::c10d::ProcessGroup::getDeviceTypes) .def( @@ -2679,10 +2797,13 @@ The hook must have the following signature: return ivalue.toCustomClass<::c10d::ProcessGroup>(); }); +<<<<<<< HEAD // Thread local process group manipulation module.def("_set_process_group", &::c10d::setProcessGroup); module.def("_current_process_group", &::c10d::currentProcessGroup); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py::enum_<::c10d::ProcessGroup::BackendType>( processGroup, "BackendType", @@ -2728,12 +2849,15 @@ The hook must have the following signature: &::c10d::Backend::supportsTimeEstimation, "(test whether the backend supports collective time estimation)") .def( +<<<<<<< HEAD "set_timeout", &::c10d::Backend::setTimeout, py::arg("timeout"), py::call_guard(), R"(Sets the default timeout for all future operations.)") .def( +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "broadcast", &::c10d::Backend::broadcast, py::arg("tensors"), @@ -2743,17 +2867,26 @@ The hook must have the following signature: "broadcast", [](const c10::intrusive_ptr<::c10d::Backend>& self, at::Tensor& x, +<<<<<<< HEAD int rootRank, std::optional timeout) { ::c10d::BroadcastOptions opts; opts.rootRank = rootRank; opts.timeout = timeout.value_or(::c10d::kUnsetTimeout); +======= + int rootRank) { + ::c10d::BroadcastOptions opts; + opts.rootRank = rootRank; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::vector xs = {x}; return self->broadcast(xs, opts); }, py::arg("tensor"), py::arg("root"), +<<<<<<< HEAD py::arg("timeout") = std::nullopt, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py::call_guard()) .def( "allreduce", @@ -2765,32 +2898,50 @@ The hook must have the following signature: "allreduce", [](const c10::intrusive_ptr<::c10d::Backend>& self, std::vector& xs, +<<<<<<< HEAD const ::c10d::ReduceOp& op, std::optional timeout) { ::c10d::AllreduceOptions opts; opts.reduceOp = op; opts.timeout = timeout.value_or(::c10d::kUnsetTimeout); +======= + const ::c10d::ReduceOp& op) { + ::c10d::AllreduceOptions opts; + opts.reduceOp = op; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self->allreduce(xs, opts); }, py::arg("tensors"), py::arg("op") = ::c10d::ReduceOp::SUM, +<<<<<<< HEAD py::arg("timeout") = std::nullopt, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py::call_guard()) .def( "allreduce", [](const c10::intrusive_ptr<::c10d::Backend>& self, at::Tensor& x, +<<<<<<< HEAD const ::c10d::ReduceOp& op, std::optional timeout) { ::c10d::AllreduceOptions opts; opts.reduceOp = op; opts.timeout = timeout.value_or(::c10d::kUnsetTimeout); +======= + const ::c10d::ReduceOp& op) { + ::c10d::AllreduceOptions opts; + opts.reduceOp = op; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::vector xs = {x}; return self->allreduce(xs, opts); }, py::arg("tensor"), py::arg("op") = ::c10d::ReduceOp::SUM, +<<<<<<< HEAD py::arg("timeout") = std::nullopt, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py::call_guard()) .def( "allreduce_coalesced", @@ -2809,19 +2960,29 @@ The hook must have the following signature: [](const c10::intrusive_ptr<::c10d::Backend>& self, at::Tensor& x, int rootRank, +<<<<<<< HEAD const ::c10d::ReduceOp& op, std::chrono::milliseconds timeout) { ::c10d::ReduceOptions opts; opts.reduceOp = op; opts.rootRank = rootRank; opts.timeout = timeout; +======= + const ::c10d::ReduceOp& op) { + ::c10d::ReduceOptions opts; + opts.reduceOp = op; + opts.rootRank = rootRank; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::vector xs = {x}; return self->reduce(xs, opts); }, py::arg("tensor"), py::arg("root"), py::arg("op") = ::c10d::ReduceOp::SUM, +<<<<<<< HEAD py::arg("timeout") = ::c10d::kUnsetTimeout, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py::call_guard()) .def( "allgather", @@ -2841,6 +3002,7 @@ The hook must have the following signature: "allgather", [](const c10::intrusive_ptr<::c10d::Backend>& self, std::vector& output, +<<<<<<< HEAD at::Tensor& input, std::chrono::milliseconds timeout) { std::vector> outputs = {output}; @@ -2852,6 +3014,16 @@ The hook must have the following signature: py::arg("output_tensors"), py::arg("input_tensor"), py::arg("timeout") = ::c10d::kUnsetTimeout, +======= + at::Tensor& input) { + std::vector> outputs = {output}; + std::vector inputs = {input}; + return self->allgather( + outputs, inputs, ::c10d::AllgatherOptions()); + }, + py::arg("output_tensors"), + py::arg("input_tensor"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py::call_guard()) .def( "allgather_coalesced", @@ -2872,6 +3044,7 @@ The hook must have the following signature: [](const c10::intrusive_ptr<::c10d::Backend>& self, std::vector& output, at::Tensor& input, +<<<<<<< HEAD int rootRank, std::chrono::milliseconds timeout) { ::c10d::GatherOptions opts; @@ -2881,13 +3054,22 @@ The hook must have the following signature: if (!output.empty()) { outputs.push_back(output); } +======= + int rootRank) { + ::c10d::GatherOptions opts; + opts.rootRank = rootRank; + std::vector> outputs = {output}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::vector inputs = {input}; return self->gather(outputs, inputs, opts); }, py::arg("output_tensors"), py::arg("input_tensor"), py::arg("root"), +<<<<<<< HEAD py::arg("timeout") = ::c10d::kUnsetTimeout, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py::call_guard()) .def( "scatter", @@ -2901,6 +3083,7 @@ The hook must have the following signature: [](const c10::intrusive_ptr<::c10d::Backend>& self, at::Tensor& output, std::vector& input, +<<<<<<< HEAD int rootRank, std::chrono::milliseconds timeout) { ::c10d::ScatterOptions opts; @@ -2910,13 +3093,22 @@ The hook must have the following signature: if (!input.empty()) { inputs.push_back(input); } +======= + int rootRank) { + ::c10d::ScatterOptions opts; + opts.rootRank = rootRank; + std::vector> inputs = {input}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::vector outputs = {output}; return self->scatter(outputs, inputs, opts); }, py::arg("output_tensor"), py::arg("input_tensors"), py::arg("root"), +<<<<<<< HEAD py::arg("timeout") = ::c10d::kUnsetTimeout, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py::call_guard()) .def( "reduce_scatter", @@ -2930,19 +3122,29 @@ The hook must have the following signature: [](const c10::intrusive_ptr<::c10d::Backend>& self, at::Tensor& output, std::vector& input, +<<<<<<< HEAD const ::c10d::ReduceOp& op, std::chrono::milliseconds timeout) { +======= + const ::c10d::ReduceOp& op) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::vector outputs = {output}; std::vector> inputs = {input}; ::c10d::ReduceScatterOptions opts; opts.reduceOp = op; +<<<<<<< HEAD opts.timeout = timeout; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self->reduce_scatter(outputs, inputs, opts); }, py::arg("output_tensors"), py::arg("input_tensor"), py::arg("op") = ::c10d::ReduceOp::SUM, +<<<<<<< HEAD py::arg("timeout") = ::c10d::kUnsetTimeout, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py::call_guard()) .def( "_reduce_scatter_base", @@ -2965,6 +3167,7 @@ The hook must have the following signature: [](::c10d::Backend& self, at::Tensor& output, at::Tensor& input, +<<<<<<< HEAD std::vector& outputSplitSizes, std::vector& inputSplitSizes, std::chrono::milliseconds timeout) { @@ -2972,12 +3175,25 @@ The hook must have the following signature: opts.timeout = timeout; return self.alltoall_base( output, input, outputSplitSizes, inputSplitSizes, opts); +======= + std::vector outputSplitSizes, + std::vector inputSplitSizes) { + return self.alltoall_base( + output, + input, + outputSplitSizes, + inputSplitSizes, + ::c10d::AllToAllOptions()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, py::arg("output"), py::arg("input"), py::arg("output_split_sizes"), py::arg("input_split_sizes"), +<<<<<<< HEAD py::arg("timeout") = ::c10d::kUnsetTimeout, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py::call_guard()) .def( "alltoall", @@ -3096,11 +3312,15 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`). py::arg("backend"), py::arg("timeout") = kProcessGroupDefaultTimeout) .def_readonly("backend", &::c10d::Backend::Options::backend) +<<<<<<< HEAD .def_readwrite("_timeout", &::c10d::Backend::Options::timeout) .def_readwrite( "global_ranks_in_group", &::c10d::Backend::Options::global_ranks_in_group) .def_readwrite("group_name", &::c10d::Backend::Options::group_name); +======= + .def_readwrite("_timeout", &::c10d::Backend::Options::timeout); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #ifdef USE_C10D_GLOO static const std::string GLOO_SOCKET_IFNAME_ENV = "GLOO_SOCKET_IFNAME"; @@ -3116,7 +3336,16 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`). processGroupGloo, "_Options", backendOptions) .def(py::init<>()) .def_readwrite("_devices", &::c10d::ProcessGroupGloo::Options::devices) +<<<<<<< HEAD .def_readwrite("_threads", &::c10d::ProcessGroupGloo::Options::threads); +======= + .def_readwrite("_threads", &::c10d::ProcessGroupGloo::Options::threads) + .def_readwrite( + "global_ranks_in_group", + &::c10d::ProcessGroupGloo::Options::global_ranks_in_group) + .def_readwrite( + "group_name", &::c10d::ProcessGroupGloo::Options::group_name); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) processGroupGloo .def_static( @@ -3214,7 +3443,14 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`). R"(Create a new ProcessGroupGloo instance.)") .def( "_set_default_timeout", +<<<<<<< HEAD &::c10d::ProcessGroupGloo::setTimeout, +======= + [](const c10::intrusive_ptr<::c10d::ProcessGroupGloo>& self, + std::chrono::milliseconds timeout) { + self->getOptions()->timeout = timeout; + }, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py::arg("timeout"), py::call_guard()) .def_property_readonly( @@ -3311,7 +3547,14 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`). &::c10d::ProcessGroupNCCL::getCommSplitCounter) .def( "_set_default_timeout", +<<<<<<< HEAD &::c10d::ProcessGroupNCCL::setTimeout, +======= + [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self, + std::chrono::milliseconds timeout) { + self->getOptions()->timeout = timeout; + }, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py::arg("timeout"), py::call_guard()) .def( @@ -3344,11 +3587,15 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`). .def( "perform_nocolor_split", &::c10d::ProcessGroupNCCL::performNocolorSplit) +<<<<<<< HEAD .def( "register_mem_pool", &::c10d::ProcessGroupNCCL::registerMemPool, py::arg("pool"), py::arg("symm") = false) +======= + .def("register_mem_pool", &::c10d::ProcessGroupNCCL::registerMemPool) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .def( "deregister_mem_pool", &::c10d::ProcessGroupNCCL::deregisterMemPool) @@ -3413,11 +3660,14 @@ for details. #ifdef NCCL_HAS_NVLS_CTAS .def_readwrite("nvls_ctas", &ncclConfig_t::nvlsCTAs) #endif +<<<<<<< HEAD .def( "unsafe_get_ptr", [](const ncclConfig_t& self) { return reinterpret_cast(&self); }) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .def_property( "net_name", [](const ncclConfig_t& self) { return self.netName; }, @@ -3482,6 +3732,14 @@ Example:: "split_from", &::c10d::ProcessGroupNCCL::Options::split_from) .def_readwrite( "split_color", &::c10d::ProcessGroupNCCL::Options::split_color) +<<<<<<< HEAD +======= + .def_readwrite( + "global_ranks_in_group", + &::c10d::ProcessGroupNCCL::Options::global_ranks_in_group) + .def_readwrite( + "group_name", &::c10d::ProcessGroupNCCL::Options::group_name) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .def( "__copy__", [](const ::c10d::ProcessGroupNCCL::Options& self) { @@ -3520,6 +3778,7 @@ Example:: .def( py::init([](const c10::intrusive_ptr<::c10d::Store>& store, int rank, +<<<<<<< HEAD int size, c10::intrusive_ptr<::c10d::ProcessGroupXCCL::Options> options) { @@ -3563,6 +3822,19 @@ Example:: )") .def("get_xccl_version", [] { return ::c10d::getXcclVersion(); }); +======= + int size) { + // gil_scoped_release is not safe as a call_guard in init. + // https://github.com/pybind/pybind11/issues/5473 + py::gil_scoped_release nogil{}; + + return c10::make_intrusive<::c10d::ProcessGroupXCCL>( + store, rank, size); + }), + py::arg("store"), + py::arg("rank"), + py::arg("size")); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif #ifdef USE_C10D_UCC @@ -3708,6 +3980,7 @@ such as `dist.all_reduce(tensor, async_op=True)`. or timed out. If timeout, exception will be thrown. )") .def( +<<<<<<< HEAD "block_current_stream", &::c10d::Work::blockCurrentStream, py::call_guard(), @@ -3723,6 +3996,8 @@ such as `dist.all_reduce(tensor, async_op=True)`. Work object result asynchronously. )") .def( +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "get_future_result", [](::c10d::Work& work) -> std::shared_ptr { @@ -3821,6 +4096,7 @@ such as `dist.all_reduce(tensor, async_op=True)`. auto fakeProcessGroup = intrusive_ptr_no_gil_destructor_class_<::c10d::FakeProcessGroup>( +<<<<<<< HEAD module, "FakeProcessGroup", backend); intrusive_ptr_class_<::c10d::FakeProcessGroup::Options>( fakeProcessGroup, "Options", backendOptions) @@ -3842,6 +4118,16 @@ such as `dist.all_reduce(tensor, async_op=True)`. c10::make_intrusive<::c10d::FakeProcessGroup::Options>()) .def_property_readonly( "options", &::c10d::FakeProcessGroup::getBackendOptions); +======= + module, "FakeProcessGroup", backend) + .def( + py::init([](int rank, int size) { + return c10::make_intrusive<::c10d::FakeProcessGroup>( + rank, size); + }), + py::arg("rank"), + py::arg("world_size")); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto fakeWork = intrusive_ptr_no_gil_destructor_class_<::c10d::FakeWork>( module, "FakeWork", work) diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp index 752e18c8dbf7d..beff15c9f80bb 100644 --- a/torch/csrc/distributed/c10d/reducer.cpp +++ b/torch/csrc/distributed/c10d/reducer.cpp @@ -964,6 +964,27 @@ void Reducer::all_reduce_bucket(Bucket& bucket) { // do any extra synchronization here. const auto& tensor = bucket.gradients; +<<<<<<< HEAD +======= + // TODO(@egienvalue): remove special case after view ops are fully + // supported on MTIA. + // If the bucket.gradients is on MTIA, bucket.bucket_views_in might not + // point to the same storage as bucket.gradients due to the special + // memory layout. It has to explicitly copy the data back to 1-D gradients. + if (tensor.is_mtia()) { + for (const auto i : c10::irange(bucket.variables.size())) { + const auto offset = bucket.offsets[i]; + const auto length = bucket.lengths[i]; + if (!bucket.bucket_views_in[i].is_alias_of(tensor)) { + tensor + .narrow( + 0, static_cast(offset), static_cast(length)) + .copy_(bucket.bucket_views_in[i].flatten()); + } + } + } + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GradBucket grad_bucket( next_bucket_, buckets_.size(), @@ -1267,8 +1288,17 @@ void Reducer::initialize_bucket_views(Reducer::Bucket& bucket) { auto& v = bucket.variables[i]; const auto offset = bucket.offsets[i]; const auto length = bucket.lengths[i]; +<<<<<<< HEAD if (v.is_non_overlapping_and_dense()) { +======= + // TODO(@egienvalue): remove special case after view ops are fully + // supported on MTIA. + // In general, on MTIA, due to the special memory layout, it doesn't + // support as_strided which creates a view tensor and aten::view will + // create a new tensor on MTIA for now. + if (v.is_non_overlapping_and_dense() && !v.is_mtia()) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // If the param's memory is dense, match its layout, anticipating // the autograd engine (AccumulateGrad) will also create gradients // matching its layout. @@ -1322,8 +1352,17 @@ void Reducer::populate_bucket_views_out( const auto& v = bucket.variables[i]; const auto offset = bucket.offsets[i]; const auto length = bucket.lengths[i]; +<<<<<<< HEAD if (v.is_non_overlapping_and_dense()) { +======= + // TODO(@egienvalue): remove special case after view ops are fully + // supported on MTIA. + // In general, on MTIA, due to the special memory layout, it doesn't + // support as_strided which creates a view tensor and aten::view will + // create a new tensor on MTIA for now. + if (v.is_non_overlapping_and_dense() && !v.is_mtia()) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // If the param's memory is dense, match its layout, anticipating // the autograd engine (AccumulateGrad) will also create gradients // matching its layout. diff --git a/torch/csrc/distributed/c10d/socket.cpp b/torch/csrc/distributed/c10d/socket.cpp index b23722ec384ab..a557ea0d5bada 100644 --- a/torch/csrc/distributed/c10d/socket.cpp +++ b/torch/csrc/distributed/c10d/socket.cpp @@ -193,6 +193,7 @@ class SocketImpl { }; std::string formatSockAddr(const struct ::sockaddr* addr, socklen_t len) { +<<<<<<< HEAD // It can be be very slow to repeatedly hit DNS resolution failure, but its // very helpful to have DNS names in logs by default. So we try to use DNS but // if we hit a transient failure we just disable it for the remainder of the @@ -237,6 +238,40 @@ std::string formatSockAddr(const struct ::sockaddr* addr, socklen_t len) { } } return "?UNKNOWN?"; +======= + char host[NI_MAXHOST], port[NI_MAXSERV]; // NOLINT + + if (int err = ::getnameinfo( + addr, len, host, NI_MAXHOST, port, NI_MAXSERV, NI_NUMERICSERV)) { + C10D_WARNING( + "The hostname of the client socket cannot be retrieved. err={}", err); + + // if we can't resolve the hostname, display the IP address + if (addr->sa_family == AF_INET) { + struct sockaddr_in* psai = (struct sockaddr_in*)&addr; + // NOLINTNEXTLINE(*array*) + char ip[INET_ADDRSTRLEN]; + if (inet_ntop(addr->sa_family, &(psai->sin_addr), ip, INET_ADDRSTRLEN) != + nullptr) { + return fmt::format("{}:{}", ip, psai->sin_port); + } + } else if (addr->sa_family == AF_INET6) { + struct sockaddr_in6* psai = (struct sockaddr_in6*)&addr; + // NOLINTNEXTLINE(*array*) + char ip[INET6_ADDRSTRLEN]; + if (inet_ntop( + addr->sa_family, &(psai->sin6_addr), ip, INET6_ADDRSTRLEN) != + nullptr) { + return fmt::format("[{}]:{}", ip, psai->sin6_port); + } + } + return "?UNKNOWN?"; + } + if (addr->sa_family == AF_INET) { + return fmt::format("{}:{}", host, port); + } + return fmt::format("[{}]:{}", host, port); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } // namespace c10d::detail diff --git a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory-inl.h b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory-inl.h index 0abbc84ebe52a..bfcf374cb6671 100644 --- a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory-inl.h +++ b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory-inl.h @@ -115,20 +115,29 @@ __device__ __forceinline__ void wait_signal(uint32_t* addr) { // Pattern 0: Ensures that all writes to symm_mem buffers from previous // kernels across all devices are visible to the current kernel: // +<<<<<<< HEAD // sync_remote_blocks(...); +======= +// sync_remote_blocks(...); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // __syncthreads(); // // Pattern 1: Ensures that all writes to symm_mem buffers from the current // block are visible to all remote blocks with matching blockIdx: // // __syncthreads(); +<<<<<<< HEAD // sync_remote_blocks(...); +======= +// sync_remote_blocks(...); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // __syncthreads(); // // Pattern 2: Ensures that symm_mem buffers read by the current kernel are safe // for writing by subsequent kernels across all devices. // // __syncthreads(); +<<<<<<< HEAD // sync_remote_blocks(...); template __device__ __forceinline__ void sync_remote_blocks( @@ -153,6 +162,42 @@ __device__ __forceinline__ void sync_remote_blocks( } } }; +======= +// sync_remote_blocks(...); +template +__device__ __forceinline__ void sync_remote_blocks( + uint32_t** signal_pads, + size_t rank, + size_t world_size); + +template <> +__device__ __forceinline__ void sync_remote_blocks( + uint32_t** signal_pads, + size_t rank, + size_t world_size) { + if (threadIdx.x < world_size) { + auto target_rank = threadIdx.x; + put_signal( + signal_pads[target_rank] + blockIdx.x * world_size + rank); + wait_signal( + signal_pads[rank] + blockIdx.x * world_size + target_rank); + } +} + +template <> +__device__ __forceinline__ void sync_remote_blocks( + uint32_t** signal_pads, + size_t rank, + size_t world_size) { + if (threadIdx.x < world_size) { + auto target_rank = threadIdx.x; + put_signal( + signal_pads[target_rank] + blockIdx.x * world_size + rank); + wait_signal( + signal_pads[rank] + blockIdx.x * world_size + target_rank); + } +} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template struct MultimemLdReduce { @@ -256,7 +301,11 @@ __device__ __inline__ T add_bf16x2(T a, T b) { __hip_bfloat16 bf[2]; } _bf2f_a = {.f = 0}, _bf2f_b = {.f = 0}; +<<<<<<< HEAD //__hip_bfloat162 is a struct with two __hip_bfloat16 elements called x and y +======= + //__hip_bfloat162 is a struct wtih two __hip_bfloat16 elements called x and y +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // This typecasts input a and b as bfloat16 and maps to low bits of a float // and does the addition in float _bf2f_a.bf[1] = reinterpret_cast<__hip_bfloat162*>(&a)->x; diff --git a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu index bd1446c579411..f2030d87b5dd8 100644 --- a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu +++ b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu @@ -1,3 +1,4 @@ +<<<<<<< HEAD #include #include #include @@ -6,6 +7,15 @@ #include #include #include +======= +#include +#include +#include +#include + +#include +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include @@ -23,6 +33,7 @@ #define CUDART_SUPPORTS_MULTICAST #endif +<<<<<<< HEAD // add these definitions so that we can compile with CUDA < 12.3 // borrowed from // https://github.com/NVIDIA/nccl/blob/3ea7eedf3b9b94f1d9f99f4e55536dfcbd23c1ca/src/include/p2p.h#L20 @@ -35,6 +46,8 @@ typedef struct CUmemFabricHandle_st { typedef CUmemFabricHandle_v1 CUmemFabricHandle; #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace c10d { namespace symmetric_memory { @@ -47,6 +60,7 @@ AllocationRef::AllocationRef( void* ptr, HandleType handle, size_t block_size, +<<<<<<< HEAD int device_idx, bool is_multicast) : ptr(ptr), @@ -54,6 +68,13 @@ AllocationRef::AllocationRef( block_size(block_size), device_idx(device_idx), is_multicast(is_multicast) {} +======= + int device_idx) + : ptr(ptr), + handle(handle), + block_size(block_size), + device_idx(device_idx) {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AllocationRef::~AllocationRef() { if (is_finalizing()) { @@ -66,12 +87,15 @@ AllocationRef::~AllocationRef() { auto driver_api = c10::cuda::DriverAPI::get(); C10_CUDA_DRIVER_CHECK( driver_api->cuMemUnmap_(reinterpret_cast(ptr), block_size)); +<<<<<<< HEAD #if defined(CUDART_SUPPORTS_MULTICAST) if (is_multicast) { C10_CUDA_DRIVER_CHECK( driver_api->cuMulticastUnbind_(handle, device_idx, 0, block_size)); } #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(handle)); #elif defined(USE_ROCM) C10_HIP_CHECK(hipMemUnmap(reinterpret_cast(ptr), block_size)); @@ -146,6 +170,81 @@ void* CUDASymmetricMemory::get_multicast_ptr() { return mc_addr_; } +<<<<<<< HEAD +======= +at::Tensor CUDASymmetricMemory::get_buffer( + int rank, + c10::IntArrayRef sizes, + c10::ScalarType dtype, + int64_t storage_offset) { + const size_t numel = std::accumulate( + sizes.begin(), + sizes.end(), + static_cast(1), + std::multiplies()); + const auto element_size = c10::elementSize(dtype); + const auto req_size = (numel + storage_offset) * element_size; + TORCH_CHECK( + req_size <= buffer_size_, + "CUDASymmetricMemory::get_buffer: the requested size (", + req_size, + " bytes) exceeds the allocated size (", + buffer_size_, + " bytes)"); + auto data_ptr = reinterpret_cast(buffers_[rank]) + + storage_offset * element_size; + auto device = c10::Device(c10::DeviceType::CUDA, local_device_idx_); + auto options = at::TensorOptions().dtype(dtype).device(device); + return at::for_blob(data_ptr, sizes) + .options(options) + .target_device(device) + .make_tensor(); +} + +at::Tensor CUDASymmetricMemory::get_signal_pad( + int rank, + c10::IntArrayRef sizes, + std::optional dtype, + int64_t storage_offset) { + // If the dtype is unspecified, default it to UInt32, as it + // is the most common type for signaling purposes. + if (!dtype.has_value()) { + dtype = c10::ScalarType::UInt32; + } + + // If the shape is unspecified, treat the signal pad as a 1d tensor. + const auto element_size = c10::elementSize(*dtype); + std::vector shape; + if (!sizes.empty()) { + shape = sizes.vec(); + } else { + shape.push_back(signal_pad_size / element_size); + } + + const size_t numel = std::accumulate( + shape.begin(), + shape.end(), + static_cast(1), + std::multiplies()); + const auto req_size = (numel + storage_offset) * element_size; + TORCH_CHECK( + req_size <= signal_pad_size, + "CUDASymmetricMemory::get_signal_pad: the requested size (", + req_size, + " bytes) exceeds the allocated size (", + signal_pad_size, + " bytes)"); + auto data_ptr = reinterpret_cast(signal_pads_[rank]) + + storage_offset * element_size; + auto device = c10::Device(c10::DeviceType::CUDA, local_device_idx_); + auto options = at::TensorOptions().dtype(*dtype).device(device); + return at::for_blob(data_ptr, shape) + .options(options) + .target_device(device) + .make_tensor(); +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void check_channel(int channel, int world_size) { TORCH_CHECK( channel >= 0, @@ -204,11 +303,15 @@ static __global__ void barrier_kernel( void CUDASymmetricMemory::barrier(int channel, size_t timeout_ms) { check_channel(channel, world_size_); c10::cuda::CUDAGuard guard(local_device_idx_); +<<<<<<< HEAD barrier_kernel<<< 1, at::cuda::warp_size(), 0, at::cuda::getCurrentCUDAStream()>>>( +======= + barrier_kernel<<<1, at::cuda::warp_size(), 0, at::cuda::getCurrentCUDAStream()>>>( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) reinterpret_cast(signal_pads_dev_), channel, rank_, @@ -246,11 +349,15 @@ void CUDASymmetricMemory::put_signal( size_t timeout_ms) { check_channel(channel, world_size_); c10::cuda::CUDAGuard guard(local_device_idx_); +<<<<<<< HEAD put_signal_kernel<<< 1, at::cuda::warp_size(), 0, at::cuda::getCurrentCUDAStream()>>>( +======= + put_signal_kernel<<<1, at::cuda::warp_size(), 0, at::cuda::getCurrentCUDAStream()>>>( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) reinterpret_cast(signal_pads_dev_), dst_rank, channel, @@ -294,11 +401,15 @@ void CUDASymmetricMemory::wait_signal( size_t timeout_ms) { check_channel(channel, world_size_); c10::cuda::CUDAGuard guard(local_device_idx_); +<<<<<<< HEAD wait_signal_kernel<<< 1, at::cuda::warp_size(), 0, at::cuda::getCurrentCUDAStream()>>>( +======= + wait_signal_kernel<<<1, at::cuda::warp_size(), 0, at::cuda::getCurrentCUDAStream()>>>( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) reinterpret_cast(signal_pads_dev_), src_rank, channel, @@ -316,6 +427,7 @@ int CUDASymmetricMemory::get_world_size() { return world_size_; } +<<<<<<< HEAD c10::Device CUDASymmetricMemory::get_device() { return c10::Device(c10::DeviceType::CUDA, local_device_idx_); } @@ -324,6 +436,8 @@ bool CUDASymmetricMemory::world_within_direct_access() { return true; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Block::Block( c10::intrusive_ptr alloc_ref, int device_idx, @@ -338,15 +452,22 @@ Block::Block( signal_pad_offset(signal_pad_offset), default_group_name(std::move(group_name)) {} +<<<<<<< HEAD namespace { using Expandable_Segments_Handle_Type = c10::cuda::CUDACachingAllocator::Expandable_Segments_Handle_Type; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void* CUDASymmetricMemoryAllocator::alloc( size_t size, int device_idx, const std::optional& group_name) { +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) size_t signal_pad_offset = at::round_up(size, 16UL); size_t block_size = signal_pad_offset + signal_pad_size; c10::cuda::CUDAGuard guard(device_idx); @@ -357,6 +478,7 @@ void* CUDASymmetricMemoryAllocator::alloc( prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; // NOLINTNEXTLINE(bugprone-signed-char-misuse) prop.location.id = device_idx; +<<<<<<< HEAD bool has_fabric_support = at::cuda::get_fabric_access(device_idx); LOG(INFO) << "CUDASymmetricMemoryAllocator::alloc: has_fabric_support " << has_fabric_support; if (handle_type_ == Expandable_Segments_Handle_Type::UNSPECIFIED) { @@ -367,6 +489,10 @@ void* CUDASymmetricMemoryAllocator::alloc( } else { prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_FABRIC; } +======= + prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) size_t granularity; auto driver_api = c10::cuda::DriverAPI::get(); @@ -375,10 +501,17 @@ void* CUDASymmetricMemoryAllocator::alloc( block_size = at::round_up(block_size, granularity); HandleType handle; +<<<<<<< HEAD C10_CUDA_DRIVER_CHECK(driver_api->cuMemCreate_(&handle, block_size, &prop, 0)); #elif defined(USE_ROCM) handle_type_ = Expandable_Segments_Handle_Type::POSIX_FD; +======= + C10_CUDA_DRIVER_CHECK( + driver_api->cuMemCreate_(&handle, block_size, &prop, 0)); + +#elif defined(USE_ROCM) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) hipMemAllocationProp prop = {}; prop.type = hipMemAllocationTypePinned; prop.location.type = hipMemLocationTypeDevice; @@ -386,17 +519,25 @@ void* CUDASymmetricMemoryAllocator::alloc( prop.location.id = device_idx; prop.requestedHandleType = hipMemHandleTypePosixFileDescriptor; +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) size_t granularity; C10_HIP_CHECK(hipMemGetAllocationGranularity( &granularity, &prop, hipMemAllocationGranularityRecommended)); block_size = at::round_up(block_size, granularity); HandleType handle; +<<<<<<< HEAD C10_HIP_CHECK(hipMemCreate( reinterpret_cast(&handle), block_size, &prop, 0)); +======= + C10_HIP_CHECK(hipMemCreate(reinterpret_cast(&handle), block_size, &prop, 0)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #else TORCH_CHECK( @@ -444,7 +585,10 @@ struct RendezvousRequest { size_t buffer_size; size_t signal_pad_offset; bool has_multicast_support; +<<<<<<< HEAD char hostname[HOST_NAME_MAX + 1]; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; void validate_rendezvous_requests( @@ -452,6 +596,7 @@ void validate_rendezvous_requests( int world_size) { TORCH_CHECK(reqs.size() == (size_t)world_size); +<<<<<<< HEAD // For NVL72 systems, multiple hosts can be within a single nvlink domain. // Multiple blocks will have same device_idx but they are on different hosts. // Use (hostname, device_idx) pair to uniquely identify each allocation. @@ -461,6 +606,15 @@ void validate_rendezvous_requests( } if (!allow_overlapping_devices() && device_host_pairs.size() < (size_t)world_size) { +======= + std::unordered_set device_indices; + device_indices.reserve(world_size); + for (auto req : reqs) { + device_indices.insert(req.device_idx); + } + if (!allow_overlapping_devices() && + device_indices.size() < (size_t)world_size) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_CHECK( false, "CUDASymmetricMemoryAllocator::rendezvous: ", @@ -498,12 +652,19 @@ static bool check_group_multicast_support( } } +<<<<<<< HEAD template +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static void init_multicast_for_block( HandleType& mc_handle, void*& mc_addr, const c10::intrusive_ptr& block, +<<<<<<< HEAD std::conditional_t ipc_channel, +======= + IpcChannel& ipc_channel, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const std::vector& pids, const c10::intrusive_ptr& store, int rank, @@ -511,6 +672,7 @@ static void init_multicast_for_block( #if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) && \ defined(CUDART_SUPPORTS_MULTICAST) auto driver_api = c10::cuda::DriverAPI::get(); +<<<<<<< HEAD auto handleType = use_fabric_handle ? CU_MEM_HANDLE_TYPE_FABRIC : CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; @@ -521,6 +683,12 @@ static void init_multicast_for_block( CUmulticastObjectProp mc_prop{}; mc_prop.numDevices = world_size; mc_prop.handleTypes = handleType; +======= + if (rank == 0) { + CUmulticastObjectProp mc_prop{}; + mc_prop.numDevices = world_size; + mc_prop.handleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mc_prop.size = block->block_size; // create a multicast object, which acts as a handle that allows multiple @@ -537,6 +705,7 @@ static void init_multicast_for_block( << "\". Gracefully skipping multicast initialization. " << "However, this is unexpected. Please report the issue on GitHub."; // Allow peers gracefully skip multicast initialization by sending -1 +<<<<<<< HEAD // TODO: allow graceful skip for fabric if constexpr (!use_fabric_handle) { ipc_channel.broadcast_fds(rank, 0, pids, -1); @@ -577,6 +746,30 @@ static void init_multicast_for_block( C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_( &mc_handle, (void*)&(mc_handles[0]), CU_MEM_HANDLE_TYPE_FABRIC)); } +======= + ipc_channel.broadcast_fds(rank, 0, pids, -1); + return; + } + + int mc_fd; + // using the CUDA Driver API to export a multicast object into a POSIX file descriptor. + C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_( + &mc_fd, mc_handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0)); + ipc_channel.broadcast_fds(rank, 0, pids, mc_fd); + // Ref count is incremented as soon as SCM_RIGHTS send happens + close(mc_fd); + } else { + int mc_fd = ipc_channel.broadcast_fds(rank, 0, pids, -1); + if (mc_fd == -1) { + return; + } + // Convert back to a handle from the broadcasted POSIX file descriptor. + C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_( + &mc_handle, + (void*)(uintptr_t)mc_fd, + CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); + close(mc_fd); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // All rank adds their physical allocation to the multicast object @@ -590,6 +783,7 @@ static void init_multicast_for_block( #endif } +<<<<<<< HEAD namespace { template c10::intrusive_ptr make_symm_mem( @@ -757,6 +951,12 @@ c10::intrusive_ptr make_symm_mem( c10::intrusive_ptr CUDASymmetricMemoryAllocator::rendezvous( void* ptr, const std::optional& group_name) { +======= +c10::intrusive_ptr CUDASymmetricMemoryAllocator::rendezvous( + void* ptr, + const std::optional& group_name) { + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto block = find_block(ptr); if (block == nullptr) { return nullptr; @@ -784,6 +984,7 @@ c10::intrusive_ptr CUDASymmetricMemoryAllocator::rendezvous( return it->second; } +<<<<<<< HEAD auto group_info = get_group_info(group_name_); TORCH_INTERNAL_ASSERT( @@ -792,6 +993,113 @@ c10::intrusive_ptr CUDASymmetricMemoryAllocator::rendezvous( handle_type_ == Expandable_Segments_Handle_Type::FABRIC_HANDLE; auto symm_mem = use_fabric ? make_symm_mem(ptr, block, group_info) : make_symm_mem(ptr, block, group_info); +======= + c10::cuda::CUDAGuard guard(block->device_idx); + + // Currently, IpcChannel is using a file based socket for inter-process communication + IpcChannel ipc_channel; + auto group_info = get_group_info(group_name_); + auto store = group_info.store; + int rank = group_info.rank; + int world_size = group_info.world_size; + int block_fd; + +#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) + auto driver_api = c10::cuda::DriverAPI::get(); + // using the CUDA Driver API to export a GPU memory block as a + // POSIX file descriptor (FD), so it can be shared across processes via IPC. + C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_( + &block_fd, + block->alloc_ref->handle, + CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, + 0)); +#elif defined (USE_ROCM) + C10_HIP_CHECK(hipMemExportToShareableHandle( + &block_fd, block->alloc_ref->handle, hipMemHandleTypePosixFileDescriptor, 0)); +#else + TORCH_CHECK( + false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); +#endif + + auto local_req = RendezvousRequest{ + .device_idx = block->device_idx, + .pid = getpid(), + .block_size = block->block_size, + .buffer_size = block->buffer_size, + .signal_pad_offset = block->signal_pad_offset, + .has_multicast_support = device_has_multicast_support(block->device_idx)}; + auto reqs = storeExchange.all_gather(store, rank, world_size, local_req); + validate_rendezvous_requests(reqs, world_size); + + std::vector pids(world_size); + for (int r = 0; r < world_size; ++r) { + pids[r] = reqs[r].pid; + } + auto imported_fds = ipc_channel.all_gather_fds(rank, pids, block_fd); + + std::vector handles(world_size); + std::vector buffers(world_size, nullptr); + std::vector signal_pads(world_size, nullptr); + + for (int r = 0; r < world_size; ++r) { + if (r == rank) { + handles[r] = block->alloc_ref->handle; + buffers[r] = ptr; + signal_pads[r] = (void*)((uintptr_t)ptr + block->signal_pad_offset); + continue; + } + // This api imports a GPU memory allocation that was previously exported as a file + // descriptor and it returns a memory handle. +#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) + C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_( + &handles[r], + (void*)(uintptr_t)imported_fds[r], + CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); +#elif defined (USE_ROCM) + C10_HIP_CHECK(hipMemImportFromShareableHandle( + &handles[r], + (void*)(uintptr_t)&(imported_fds[r]), + hipMemHandleTypePosixFileDescriptor)); +#else + TORCH_CHECK( + false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); +#endif + map_block(&buffers[r], handles[r], block->block_size, block->device_idx); + signal_pads[r] = (void*)((uintptr_t)buffers[r] + block->signal_pad_offset); + close(imported_fds[r]); + } + storeExchange.barrier(store, rank, world_size); + close(block_fd); + + HandleType mc_handle{}; + void* mc_addr = nullptr; + bool group_has_multicast_support = check_group_multicast_support(reqs); + if (!allow_overlapping_devices() && group_has_multicast_support) { + init_multicast_for_block( + mc_handle, mc_addr, block, ipc_channel, pids, store, rank, world_size); + } + + std::vector> alloc_refs; + for (int r = 0; r < world_size; ++r) { + if (r == rank) { + alloc_refs.emplace_back(block->alloc_ref); + continue; + } + alloc_refs.push_back(c10::make_intrusive( + buffers[r], handles[r], block->block_size, block->device_idx)); + } + + auto symm_mem = c10::make_intrusive( + std::move(alloc_refs), + std::move(buffers), + std::move(signal_pads), + mc_handle, + mc_addr, + block->buffer_size, + block->device_idx, + group_info.rank, + group_info.world_size); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) block->symm_mems[group_name_] = symm_mem; return symm_mem; } @@ -800,6 +1108,7 @@ bool CUDASymmetricMemoryAllocator::has_multicast_support(int device_idx) { return device_has_multicast_support(device_idx); } +<<<<<<< HEAD c10::DeviceType CUDASymmetricMemoryAllocator::supported_device_type() { return c10::DeviceType::CUDA; } @@ -808,6 +1117,8 @@ std::string CUDASymmetricMemoryAllocator::name() { return "CUDA"; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) c10::intrusive_ptr CUDASymmetricMemoryAllocator::find_block(void* ptr) { std::shared_lock lock(mutex_); auto it = ptr_to_block_.find(ptr); @@ -819,6 +1130,7 @@ c10::intrusive_ptr CUDASymmetricMemoryAllocator::find_block(void* ptr) { struct RegisterCUDASymmetricMemoryAllocator { RegisterCUDASymmetricMemoryAllocator() { +<<<<<<< HEAD auto allocator = c10::make_intrusive(); // Query backend used for CUDA tensor // "CUDA" backend stands for this implementation @@ -828,6 +1140,14 @@ struct RegisterCUDASymmetricMemoryAllocator { } else { // Register availability in case `set_backend` is called dynamically register_availability("CUDA", allocator); +======= + // Query backend used for CUDA tensor + // "CUDA" backend stands for this implementation + if (getSymmMemBackendCUDA() == "CUDA") { + register_allocator( + c10::DeviceType::CUDA, + c10::make_intrusive()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } }; diff --git a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.hpp b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.hpp index 39a6122bcdb27..cf47341e6ac4f 100644 --- a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.hpp +++ b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.hpp @@ -1,7 +1,10 @@ #pragma once #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include @@ -15,14 +18,21 @@ struct AllocationRef : public c10::intrusive_ptr_target { HandleType handle; size_t block_size; int device_idx; +<<<<<<< HEAD bool is_multicast; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AllocationRef( void* ptr, HandleType handle, size_t block_size, +<<<<<<< HEAD int device_idx, bool is_multicast = false); +======= + int device_idx); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ~AllocationRef(); }; @@ -52,14 +62,32 @@ class CUDASymmetricMemory : public SymmetricMemory { bool has_multicast_support() override; void* get_multicast_ptr() override; +<<<<<<< HEAD +======= + at::Tensor get_buffer( + int rank, + c10::IntArrayRef sizes, + c10::ScalarType dtype, + int64_t storage_offset) override; + + at::Tensor get_signal_pad( + int rank, + c10::IntArrayRef sizes, + std::optional dtype, + int64_t storage_offset) override; + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void barrier(int channel, size_t timeout_ms) override; void put_signal(int dst_rank, int channel, size_t timeout_ms) override; void wait_signal(int src_rank, int channel, size_t timeout_ms) override; int get_rank() override; int get_world_size() override; +<<<<<<< HEAD c10::Device get_device() override; bool world_within_direct_access() override; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) private: std::vector> alloc_refs_; @@ -108,17 +136,23 @@ class CUDASymmetricMemoryAllocator : public SymmetricMemoryAllocator { void* ptr, const std::optional& group_name) override; bool has_multicast_support(int device_idx) override; +<<<<<<< HEAD c10::DeviceType supported_device_type() override; std::string name() override; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) private: c10::intrusive_ptr find_block(void* ptr); std::shared_mutex mutex_; std::unordered_map> ptr_to_block_; +<<<<<<< HEAD c10::cuda::CUDACachingAllocator::Expandable_Segments_Handle_Type handle_type_ = c10::cuda::CUDACachingAllocator:: Expandable_Segments_Handle_Type::UNSPECIFIED; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; } // namespace c10d::symmetric_memory diff --git a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu index 572c5a8fd369d..c0f376f68c324 100644 --- a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu +++ b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu @@ -104,8 +104,12 @@ void init_elementwise_launch_config( size_t max_num_blocks, size_t max_num_threads, int& num_blocks, +<<<<<<< HEAD int& num_threads, int world_size) { +======= + int& num_threads) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Align to preserve alignment in each split const size_t aligned_numel = at::round_up(numel, alignment * splits); const size_t numel_per_split = aligned_numel / splits; @@ -113,11 +117,17 @@ void init_elementwise_launch_config( if (numel_per_split <= max_num_threads * numel_per_thread) { num_blocks = 1; +<<<<<<< HEAD num_threads = at::ceil_div(numel_per_split, numel_per_thread); // `sync_remote_blocks` maps threads to peers, so we need to make sure there // are enough threads num_threads = max(num_threads, world_size); num_threads = at::round_up(num_threads, at::cuda::warp_size()); +======= + num_threads = at::round_up( + at::ceil_div(numel_per_split, numel_per_thread), + static_cast(at::cuda::warp_size())); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else { num_blocks = std::min( at::ceil_div(numel_per_split, max_num_threads * numel_per_thread), @@ -137,7 +147,11 @@ static __global__ void multimem_all_reduce_kernel( static_assert(alignment % sizeof(T) == 0); constexpr size_t numel_per_thread = alignment / sizeof(T); +<<<<<<< HEAD sync_remote_blocks(signal_pads, rank, world_size); +======= + sync_remote_blocks(signal_pads, rank, world_size); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __syncthreads(); const size_t numel_per_rank = @@ -155,7 +169,11 @@ static __global__ void multimem_all_reduce_kernel( } __syncthreads(); +<<<<<<< HEAD sync_remote_blocks(signal_pads, rank, world_size); +======= + sync_remote_blocks(signal_pads, rank, world_size); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } at::Tensor multimem_all_reduce_( @@ -188,8 +206,12 @@ at::Tensor multimem_all_reduce_( 8, 1024, num_blocks, +<<<<<<< HEAD num_threads, symm_mem->get_world_size()); +======= + num_threads); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AT_DISPATCH_FLOAT_AND_BFLOAT16( input.scalar_type(), "multimem_all_reduce_", [&]() { @@ -223,7 +245,11 @@ static __global__ void multimem_one_shot_all_reduce_kernel( static_assert(alignment % sizeof(T) == 0); constexpr size_t numel_per_thread = alignment / sizeof(T); +<<<<<<< HEAD sync_remote_blocks(signal_pads, rank, world_size); +======= + sync_remote_blocks(signal_pads, rank, world_size); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __syncthreads(); auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread; @@ -234,7 +260,11 @@ static __global__ void multimem_one_shot_all_reduce_kernel( } __syncthreads(); +<<<<<<< HEAD sync_remote_blocks(signal_pads, rank, world_size); +======= + sync_remote_blocks(signal_pads, rank, world_size); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } at::Tensor multimem_one_shot_all_reduce_out( @@ -275,8 +305,12 @@ at::Tensor multimem_one_shot_all_reduce_out( 8, 1024, num_blocks, +<<<<<<< HEAD num_threads, symm_mem->get_world_size()); +======= + num_threads); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AT_DISPATCH_FLOAT_AND_BFLOAT16( input.scalar_type(), "multimem_one_shot_all_reduce", [&]() { @@ -316,7 +350,11 @@ static __global__ void multimem_all_gather_kernel( uint32_t** signal_pads, size_t rank, size_t world_size) { +<<<<<<< HEAD sync_remote_blocks(signal_pads, rank, world_size); +======= + sync_remote_blocks(signal_pads, rank, world_size); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __syncthreads(); const size_t start = bytes_per_rank * rank; @@ -329,7 +367,11 @@ static __global__ void multimem_all_gather_kernel( } __syncthreads(); +<<<<<<< HEAD sync_remote_blocks(signal_pads, rank, world_size); +======= + sync_remote_blocks(signal_pads, rank, world_size); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } at::Tensor multimem_all_gather_out( @@ -383,8 +425,12 @@ at::Tensor multimem_all_gather_out( 8, 1024, num_blocks, +<<<<<<< HEAD num_threads, symm_mem->get_world_size()); +======= + num_threads); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DISPATCH_ALIGNMENTS_16_8_4(alignment, [&]() { multimem_all_gather_kernel @@ -408,6 +454,10 @@ at::Tensor multimem_all_gather_out( // count to 512 to prevent/alleviate register spill. constexpr size_t one_shot_all_reduce_max_num_blocks = 24; constexpr size_t one_shot_all_reduce_max_num_threads = 512; +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template static __launch_bounds__(one_shot_all_reduce_max_num_threads) __global__ void one_shot_all_reduce_kernel( @@ -431,7 +481,11 @@ static __launch_bounds__(one_shot_all_reduce_max_num_threads) __global__ } } // TODO make it sync with one block for no-copy case +<<<<<<< HEAD sync_remote_blocks(signal_pads, rank, world_size); +======= + sync_remote_blocks(signal_pads, rank, world_size); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __syncthreads(); for (size_t i = offset; i < numel; i += stride) { @@ -441,7 +495,11 @@ static __launch_bounds__(one_shot_all_reduce_max_num_threads) __global__ } __syncthreads(); +<<<<<<< HEAD sync_remote_blocks(signal_pads, rank, world_size); +======= + sync_remote_blocks(signal_pads, rank, world_size); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } at::Tensor one_shot_all_reduce_out_impl( @@ -499,8 +557,12 @@ at::Tensor one_shot_all_reduce_out_impl( one_shot_all_reduce_max_num_blocks, one_shot_all_reduce_max_num_threads, num_blocks, +<<<<<<< HEAD num_threads, symm_mem->get_world_size()); +======= + num_threads); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AT_DISPATCH_FLOAT_AND_BFLOAT16( input.scalar_type(), "one_shot_all_reduce", [&]() { @@ -567,6 +629,7 @@ at::Tensor one_shot_all_reduce_copy( input, local_input, reduce_op, group_name, out); } +<<<<<<< HEAD #if defined(USE_ROCM) constexpr size_t two_shot_all_reduce_max_num_blocks = 64; constexpr size_t two_shot_all_reduce_max_num_threads = 128; @@ -574,6 +637,11 @@ constexpr size_t two_shot_all_reduce_max_num_threads = 128; constexpr size_t two_shot_all_reduce_max_num_blocks = 24; constexpr size_t two_shot_all_reduce_max_num_threads = 1024; #endif +======= +constexpr size_t two_shot_all_reduce_max_num_blocks = 24; +constexpr size_t two_shot_all_reduce_max_num_threads = 1024; + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template < typename T, int alignment, @@ -594,7 +662,11 @@ static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__ constexpr size_t numel_per_thread = alignment / sizeof(T); int32_t N_last_dim = last_dim_size / world_size; // used only for split_last_dim reduce_scatter +<<<<<<< HEAD sync_remote_blocks(signal_pads, rank, world_size); +======= + sync_remote_blocks(signal_pads, rank, world_size); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __syncthreads(); const size_t numel_per_rank = @@ -626,7 +698,11 @@ static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__ } __syncthreads(); +<<<<<<< HEAD sync_remote_blocks(signal_pads, rank, world_size); +======= + sync_remote_blocks(signal_pads, rank, world_size); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if constexpr (reduce_scatter) { return; } @@ -637,16 +713,22 @@ static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__ for (size_t step = 0; step < k_world_size; ++step) { size_t remote_rank = (rank + step) % k_world_size; size_t remote_start = numel_per_rank * remote_rank; +<<<<<<< HEAD #if defined (USE_ROCM) tmp[step] = at::native::memory::ld_vec( input_ptrs[remote_rank] + input_offset + min(remote_start + i, numel-1)); #else +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (remote_start + i >= numel) { continue; } tmp[step] = at::native::memory::ld_vec( input_ptrs[remote_rank] + input_offset + remote_start + i); +<<<<<<< HEAD #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } #pragma unroll k_world_size for (size_t step = 0; step < k_world_size; ++step) { @@ -661,7 +743,11 @@ static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__ // need to make sure all blocks exit simultaneously so that the data // is not corrupted by the subsequent kernels __syncthreads(); +<<<<<<< HEAD sync_remote_blocks(signal_pads, rank, world_size); +======= + sync_remote_blocks(signal_pads, rank, world_size); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } template @@ -676,7 +762,11 @@ static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__ static_assert(alignment % sizeof(T) == 0); constexpr size_t numel_per_thread = alignment / sizeof(T); +<<<<<<< HEAD sync_remote_blocks(signal_pads, rank, world_size); +======= + sync_remote_blocks(signal_pads, rank, world_size); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __syncthreads(); const size_t numel_per_rank = @@ -699,7 +789,11 @@ static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__ } __syncthreads(); +<<<<<<< HEAD sync_remote_blocks(signal_pads, rank, world_size); +======= + sync_remote_blocks(signal_pads, rank, world_size); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } at::Tensor two_shot_all_reduce_impl( @@ -755,8 +849,12 @@ at::Tensor two_shot_all_reduce_impl( two_shot_all_reduce_max_num_blocks, two_shot_all_reduce_max_num_threads, num_blocks, +<<<<<<< HEAD num_threads, symm_mem->get_world_size()); +======= + num_threads); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (!output.has_value()) { AT_DISPATCH_FLOAT_AND_BFLOAT16( @@ -903,8 +1001,12 @@ at::Tensor reduce_scatter_out( two_shot_all_reduce_max_num_blocks, two_shot_all_reduce_max_num_threads, num_blocks, +<<<<<<< HEAD num_threads, symm_mem->get_world_size()); +======= + num_threads); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (split_last_dim) { AT_DISPATCH_FLOAT_AND_BFLOAT16( input.scalar_type(), "two_shot_all_reduce", [&]() { diff --git a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryTypes.hpp b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryTypes.hpp index daf273446ef3a..4f849f7e51e1a 100644 --- a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryTypes.hpp +++ b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryTypes.hpp @@ -1,5 +1,6 @@ #pragma once +<<<<<<< HEAD #include namespace c10d::symmetric_memory { @@ -13,6 +14,11 @@ constexpr int symm_max_nblocks = 32; // channels. Each signal is 32 bits, which is the minimum unit for atomic cas. constexpr size_t signal_pad_size = symm_max_nblocks * max_cuda_p2p_domain_size * sizeof(uint32_t); +======= +namespace c10d::symmetric_memory { + +constexpr size_t signal_pad_size = 2048; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) using HandleType = CUmemGenericAllocationHandle; diff --git a/torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu b/torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu index 0eda605fad6fb..6efff60301900 100644 --- a/torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu +++ b/torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu @@ -93,6 +93,85 @@ class NCCLSymmetricMemory : public SymmetricMemory { return nullptr; } +<<<<<<< HEAD +======= + // TODO: This is up for change. + at::Tensor get_buffer( + int rank, + c10::IntArrayRef sizes, + c10::ScalarType dtype, + int64_t storage_offset) { + // TODO: deduplicate + const size_t numel = std::accumulate( + sizes.begin(), + sizes.end(), + static_cast(1), + std::multiplies()); + const auto element_size = c10::elementSize(dtype); + const auto req_size = (numel + storage_offset) * element_size; + TORCH_CHECK( + req_size <= buffer_size_, + "NCCLSymmetricMemory::get_buffer: the requested size (", + req_size, + " bytes) exceeds the allocated size (", + buffer_size_, + " bytes)"); + auto data_ptr = reinterpret_cast(buffers_[rank]) + + storage_offset * element_size; + auto device = c10::Device(c10::DeviceType::CUDA, device_idx_); + auto options = at::TensorOptions().dtype(dtype).device(device); + return at::for_blob(data_ptr, sizes) + .options(options) + .target_device(device) + .make_tensor(); + } + + // TODO: This is up for change. + at::Tensor get_signal_pad( + int rank, + c10::IntArrayRef sizes, + std::optional dtype, + int64_t storage_offset) override { + // TODO: deduplicate + // If the dtype is unspecified, default it to UInt32, as it + // is the most common type for signaling purposes. + if (!dtype.has_value()) { + dtype = c10::ScalarType::UInt32; + } + + // If the shape is unspecified, treat the signal pad as a 1d tensor. + const auto element_size = c10::elementSize(*dtype); + std::vector shape; + if (!sizes.empty()) { + shape = sizes.vec(); + } else { + shape.push_back(signal_pad_size / element_size); + } + + const size_t numel = std::accumulate( + shape.begin(), + shape.end(), + static_cast(1), + std::multiplies()); + const auto req_size = (numel + storage_offset) * element_size; + TORCH_CHECK( + req_size <= signal_pad_size, + "NCCLSymmetricMemory::get_signal_pad: the requested size (", + req_size, + " bytes) exceeds the allocated size (", + signal_pad_size, + " bytes)"); + auto data_ptr = reinterpret_cast(signal_pads_[rank]) + + storage_offset * element_size; + auto device = c10::Device(c10::DeviceType::CUDA, device_idx_); + auto options = at::TensorOptions().dtype(*dtype).device(device); + return at::for_blob(data_ptr, shape) + .options(options) + .target_device(device) + .make_tensor(); + } + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void barrier(int channel, size_t timeout_ms) override { // TODO } @@ -113,10 +192,13 @@ class NCCLSymmetricMemory : public SymmetricMemory { return world_size_; } +<<<<<<< HEAD c10::Device get_device() override { return c10::Device(c10::DeviceType::CUDA, device_idx_); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) virtual std::vector& get_rank_to_global_rank() override { return rank_to_global_rank_; }; @@ -157,6 +239,11 @@ class NCCLSymmetricMemoryAllocator : public SymmetricMemoryAllocator { auto group_info = get_group_info("0"); auto store = group_info.store; +<<<<<<< HEAD +======= + int rank = group_info.rank; + int world_size = group_info.world_size; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) c10::cuda::CUDAGuard guard(device_idx); // TODO: we might need to use a roundup or mempool for mem allocation. void* ptr; @@ -204,6 +291,10 @@ class NCCLSymmetricMemoryAllocator : public SymmetricMemoryAllocator { ncclWindow_t signal_handle; auto group_info = get_group_info(group_name.value()); +<<<<<<< HEAD +======= + auto global_rank = get_group_info("0").rank; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto buffer_size_map = storeExchange.all_gather(group_info.store, group_info.rank, group_info.world_size, it->second->buffer_size); @@ -229,7 +320,11 @@ class NCCLSymmetricMemoryAllocator : public SymmetricMemoryAllocator { comm)); void* signal_pad_ptr; +<<<<<<< HEAD C10D_NCCL_CHECK(ncclMemAlloc(&signal_pad_ptr, signal_pad_size), "ncclMemAlloc failed"); +======= + TORCH_CHECK(ncclMemAlloc(&signal_pad_ptr, signal_pad_size) == ncclSuccess, "ncclMemAlloc failed"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) C10D_NCCL_CHECK( ncclCommWindowRegister(comm, signal_pad_ptr, signal_pad_size, (ncclWindow_t*)&signal_handle, NCCL_WIN_COLL_SYMMETRIC), c10::str( @@ -252,6 +347,7 @@ class NCCLSymmetricMemoryAllocator : public SymmetricMemoryAllocator { return false; }; +<<<<<<< HEAD c10::DeviceType supported_device_type() override { return c10::DeviceType::CUDA; } @@ -260,6 +356,8 @@ class NCCLSymmetricMemoryAllocator : public SymmetricMemoryAllocator { return "NCCL"; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) private: std::unordered_map> ptr_to_symm_mem_; @@ -271,6 +369,7 @@ class NCCLSymmetricMemoryAllocator : public SymmetricMemoryAllocator { struct RegisterNCCLSymmetricMemoryAllocator { RegisterNCCLSymmetricMemoryAllocator() { +<<<<<<< HEAD auto allocator = c10::make_intrusive(); // Query backend used for CUDA tensor if (getSymmMemBackendCUDA() == "NCCL") { @@ -281,6 +380,13 @@ struct RegisterNCCLSymmetricMemoryAllocator { } else { // Register availability in case `set_backend` is called dynamically register_availability("NCCL", allocator); +======= + // Query backend used for CUDA tensor + if (getSymmMemBackendCUDA() == "NCCL") { + register_allocator( + c10::DeviceType::CUDA, + c10::make_intrusive()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } }; diff --git a/torch/csrc/distributed/c10d/symm_mem/NVSHMEMSymmetricMemory.cu b/torch/csrc/distributed/c10d/symm_mem/NVSHMEMSymmetricMemory.cu index a470c7e2e54f4..608a4bee1783f 100644 --- a/torch/csrc/distributed/c10d/symm_mem/NVSHMEMSymmetricMemory.cu +++ b/torch/csrc/distributed/c10d/symm_mem/NVSHMEMSymmetricMemory.cu @@ -9,6 +9,7 @@ #include #include #include +<<<<<<< HEAD // Starting from NVSHMEM 3.3.9, nvshmem_host.h exists so that we can cleanly // include only the nvshmem host library headers: @@ -17,11 +18,20 @@ #include #include // For maximum compatibility, we use the "host/" style for now. +======= +#include + +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace c10d { namespace symmetric_memory { +<<<<<<< HEAD /* Start of NVSHMEMSymmetricMemory implementation */ +======= +/* Start of CUDASymmetricMemory implementation */ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static StoreExchange storeExchange = StoreExchange("NVSHMEMSymmetricMemory"); @@ -43,6 +53,7 @@ struct NVSHMEMAllocation { } }; +<<<<<<< HEAD // A class to hold the base pointers and signal pad pointers for a group of // peers. One `NVSHMEMPeerAllocInfo` object can be shared by multiple // `NVSHMEMSymmetricMemory` objects when latter reside on the same allocation @@ -61,6 +72,23 @@ class NVSHMEMPeerAllocInfo : public c10::intrusive_ptr_target { auto global_rank = get_group_info("0").rank; GroupInfo& group_info = get_group_info(group_name); +======= +class NVSHMEMSymmetricMemory : public SymmetricMemory { + public: + NVSHMEMSymmetricMemory( + std::shared_ptr allocation, + const std::string& group_name) + : allocation_(allocation), + buffer_size_(allocation->buffer_size), + device_idx_(allocation->device_idx), + group_name_(group_name) { + // For logging only + static int exchanged_n_times = 0; + c10::cuda::CUDAGuard guard(device_idx_); + + auto global_rank = get_group_info("0").rank; + GroupInfo& group_info = get_group_info(group_name_); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto store = group_info.store; rank_ = group_info.rank; world_size_ = group_info.world_size; @@ -73,12 +101,17 @@ class NVSHMEMPeerAllocInfo : public c10::intrusive_ptr_target { if (rank_ == 0) { LOG(INFO) << "[rank " << rank_ << "]" << " rank_to_global_rank: " << group_info.rank_to_global_rank +<<<<<<< HEAD << ", group_name: " << group_name +======= + << ", group_name: " << group_name_ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) << ", exchanged_n_times: " << exchanged_n_times; } } TORCH_INTERNAL_ASSERT(!group_info.rank_to_global_rank.empty()); rank_to_global_rank_ = group_info.rank_to_global_rank; +<<<<<<< HEAD world_within_cuda_p2p_ = true; for (int r = 0; r < world_size_; ++r) { @@ -89,11 +122,19 @@ class NVSHMEMPeerAllocInfo : public c10::intrusive_ptr_target { if (peer_ptr == nullptr) { world_within_cuda_p2p_ = false; } +======= + for (int r = 0; r < world_size_; ++r) { + buffers_.push_back(nvshmem_ptr( + allocation->ptr, rank_to_global_rank_[r])); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // TODO: use the same allocation for signal pad void* signal_pad_ptr = nvshmem_malloc(signal_pad_size); +<<<<<<< HEAD TORCH_CHECK(signal_pad_ptr != nullptr, "nvshmem_malloc failed"); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AT_CUDA_CHECK(cudaMemset(signal_pad_ptr, 0, signal_pad_size)); for (int r = 0; r < world_size_; ++r) { @@ -124,6 +165,7 @@ class NVSHMEMPeerAllocInfo : public c10::intrusive_ptr_target { cudaMemcpyHostToDevice)); } +<<<<<<< HEAD private: void* base_ptr_; size_t buffer_size_; @@ -165,11 +207,14 @@ class NVSHMEMSymmetricMemory : public SymmetricMemory { offset_ = offset; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ~NVSHMEMSymmetricMemory() override{ // TODO }; std::vector get_buffer_ptrs() override { +<<<<<<< HEAD return pai_->buffers_; } @@ -187,6 +232,25 @@ class NVSHMEMSymmetricMemory : public SymmetricMemory { size_t get_buffer_size() override { return pai_->buffer_size_; +======= + return buffers_; + } + + std::vector get_signal_pad_ptrs() override { + return signal_pads_; + } + + void** get_buffer_ptrs_dev() override { + return buffers_dev_; + } + + void** get_signal_pad_ptrs_dev() override { + return signal_pads_dev_; + } + + size_t get_buffer_size() override { + return buffer_size_; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } size_t get_signal_pad_size() override { @@ -203,8 +267,83 @@ class NVSHMEMSymmetricMemory : public SymmetricMemory { return nullptr; } +<<<<<<< HEAD size_t get_offset() override { return offset_; +======= + at::Tensor get_buffer( + int rank, + c10::IntArrayRef sizes, + c10::ScalarType dtype, + int64_t storage_offset) { + // TODO: deduplicate + const size_t numel = std::accumulate( + sizes.begin(), + sizes.end(), + static_cast(1), + std::multiplies()); + const auto element_size = c10::elementSize(dtype); + const auto req_size = (numel + storage_offset) * element_size; + TORCH_CHECK( + req_size <= buffer_size_, + "NVSHMEMSymmetricMemory::get_buffer: the requested size (", + req_size, + " bytes) exceeds the allocated size (", + buffer_size_, + " bytes)"); + auto data_ptr = reinterpret_cast(buffers_[rank]) + + storage_offset * element_size; + auto device = c10::Device(c10::DeviceType::CUDA, device_idx_); + auto options = at::TensorOptions().dtype(dtype).device(device); + return at::for_blob(data_ptr, sizes) + .options(options) + .target_device(device) + .make_tensor(); + } + + at::Tensor get_signal_pad( + int rank, + c10::IntArrayRef sizes, + std::optional dtype, + int64_t storage_offset) override { + // TODO: deduplicate + // If the dtype is unspecified, default it to UInt32, as it + // is the most common type for signaling purposes. + if (!dtype.has_value()) { + dtype = c10::ScalarType::UInt32; + } + + // If the shape is unspecified, treat the signal pad as a 1d tensor. + const auto element_size = c10::elementSize(*dtype); + std::vector shape; + if (!sizes.empty()) { + shape = sizes.vec(); + } else { + shape.push_back(signal_pad_size / element_size); + } + + const size_t numel = std::accumulate( + shape.begin(), + shape.end(), + static_cast(1), + std::multiplies()); + const auto req_size = (numel + storage_offset) * element_size; + TORCH_CHECK( + req_size <= signal_pad_size, + "NVSHMEMSymmetricMemory::get_signal_pad: the requested size (", + req_size, + " bytes) exceeds the allocated size (", + signal_pad_size, + " bytes)"); + auto data_ptr = reinterpret_cast(signal_pads_[rank]) + + storage_offset * element_size; + auto device = c10::Device(c10::DeviceType::CUDA, device_idx_); + auto options = at::TensorOptions().dtype(*dtype).device(device); + return at::for_blob(data_ptr, shape) + .options(options) + .target_device(device) + .make_tensor(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } void barrier(int channel, size_t timeout_ms) override { @@ -220,6 +359,7 @@ class NVSHMEMSymmetricMemory : public SymmetricMemory { } int get_rank() override { +<<<<<<< HEAD return pai_->rank_; } @@ -312,6 +452,39 @@ static void initialize_nvshmem_with_store( LOG(INFO) << "NVSHMEM is available, version: " << major << '.' << minor; } +======= + return rank_; + } + + int get_world_size() override { + return world_size_; + } + + virtual const std::vector& get_rank_to_global_rank() override { + return rank_to_global_rank_; + }; + + int* get_rank_to_global_rank_dev() override { + return rank_to_global_rank_dev_; + }; + + private: + std::shared_ptr allocation_; + size_t buffer_size_; + std::vector buffers_; + std::vector signal_pads_; + int device_idx_; + int rank_; + int world_size_; + void** buffers_dev_; + void** signal_pads_dev_; + std::string group_name_; + + std::vector rank_to_global_rank_; + int* rank_to_global_rank_dev_; +}; + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class NVSHMEMSymmetricMemoryAllocator : public SymmetricMemoryAllocator { public: void* alloc( @@ -322,13 +495,17 @@ class NVSHMEMSymmetricMemoryAllocator : public SymmetricMemoryAllocator { group_name == std::nullopt, "NVSHMEMSymmetricMemoryAllocator::alloc " "must not be called with a group_name"); +<<<<<<< HEAD c10::cuda::CUDAGuard guard(device_idx); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto group_info = get_group_info("0"); auto store = group_info.store; int rank = group_info.rank; int world_size = group_info.world_size; +<<<<<<< HEAD initialize_nvshmem_with_store(store, rank, world_size, device_idx); auto ptr = nvshmem_malloc(size); // If size is 0 (which is legal allocation request) we shouldn't error out @@ -337,6 +514,14 @@ class NVSHMEMSymmetricMemoryAllocator : public SymmetricMemoryAllocator { allocations_.try_emplace( ptr, std::make_unique(ptr, size, device_idx)); +======= + nvshmem_extension::initialize_nvshmem_with_store(store, rank, world_size); + auto ptr = nvshmem_malloc(size); + auto allocation = + std::make_shared(ptr, size, device_idx); + // TODO: thread safety + allocations_.try_emplace(ptr, std::move(allocation)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ptr; } @@ -364,6 +549,7 @@ class NVSHMEMSymmetricMemoryAllocator : public SymmetricMemoryAllocator { return it->second; } } +<<<<<<< HEAD // In case of MemPool, tensor.storage().data_ptr() may not match // exactly an allocation's base address. Thus we perform the search by // testing if the former is within an allocation's range. @@ -406,6 +592,15 @@ class NVSHMEMSymmetricMemoryAllocator : public SymmetricMemoryAllocator { // "shallow" copy adjusting the offset field in the handle. return c10::make_intrusive(*symm_mem, (uintptr_t)ptr - (uintptr_t)allocation->ptr); } +======= + auto it = allocations_.find(ptr); + TORCH_CHECK(it != allocations_.end()); + auto symm_mem = + c10::make_intrusive(it->second, *group_name); + + symm_mems_[std::make_tuple(ptr, *group_name)] = symm_mem; + return symm_mem; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; bool has_multicast_support(int device_idx) override { @@ -413,6 +608,7 @@ class NVSHMEMSymmetricMemoryAllocator : public SymmetricMemoryAllocator { return false; }; +<<<<<<< HEAD c10::DeviceType supported_device_type() override { return c10::DeviceType::CUDA; } @@ -424,11 +620,17 @@ class NVSHMEMSymmetricMemoryAllocator : public SymmetricMemoryAllocator { private: std::unordered_map> allocations_; std::map, c10::intrusive_ptr> +======= + private: + std::unordered_map> allocations_; + std::map, c10::intrusive_ptr> +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) symm_mems_; }; struct RegisterNVSHMEMSymmetricMemoryAllocator { RegisterNVSHMEMSymmetricMemoryAllocator() { +<<<<<<< HEAD auto allocator = c10::make_intrusive(); // Query backend used for CUDA tensor if (getSymmMemBackendCUDA() == "NVSHMEM") { @@ -439,6 +641,13 @@ struct RegisterNVSHMEMSymmetricMemoryAllocator { } else { // Register availability in case `set_backend` is called dynamically register_availability("NVSHMEM", allocator); +======= + // Query backend used for CUDA tensor + if (getSymmMemBackendCUDA() == "NVSHMEM") { + register_allocator( + c10::DeviceType::CUDA, + c10::make_intrusive()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } }; diff --git a/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp b/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp index 81853c4c07d20..cc5b9e830db47 100644 --- a/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp +++ b/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp @@ -22,6 +22,7 @@ class AllocatorMap { map_[device_type] = std::move(allocator); } +<<<<<<< HEAD void register_availability( const std::string& name, c10::intrusive_ptr allocator) { @@ -55,6 +56,8 @@ class AllocatorMap { return it->second->name(); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) c10::intrusive_ptr get_allocator( c10::DeviceType device_type) { auto it = map_.find(device_type); @@ -62,7 +65,10 @@ class AllocatorMap { it != map_.end(), "SymmetricMemory does not support device type ", device_type); +<<<<<<< HEAD in_use_ = true; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return it->second; } @@ -72,6 +78,10 @@ class AllocatorMap { } ~AllocatorMap() { +<<<<<<< HEAD +======= + LOG(INFO) << "Destroying Symmetric Memory Allocators"; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) is_finalizing_ = true; } @@ -82,6 +92,7 @@ class AllocatorMap { c10::DeviceType, c10::intrusive_ptr> map_; +<<<<<<< HEAD // For backends to register availability. // This registration is at static time. Therefore, it is expected that the @@ -93,6 +104,8 @@ class AllocatorMap { avail_map_; bool in_use_ = false; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; static std::unordered_map group_info_map{}; @@ -173,6 +186,7 @@ void register_allocator( device_type, std::move(allocator)); } +<<<<<<< HEAD void register_availability( const std::string& name, c10::intrusive_ptr allocator) { @@ -187,6 +201,8 @@ std::optional get_backend(c10::Device device) { return AllocatorMap::get().get_backend(device.type()); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool has_allocator(c10::DeviceType device_type) { return AllocatorMap::get().has_allocator(device_type); } @@ -266,6 +282,7 @@ TORCH_API bool has_multicast_support( return allocator->has_multicast_support(device_idx); } } +<<<<<<< HEAD // MemPool Support @@ -427,6 +444,8 @@ at::Tensor SymmetricMemory::get_signal_pad( .make_tensor(); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace c10d::symmetric_memory namespace { @@ -496,6 +515,7 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) { "memset32_(Tensor(a!) input, int offset, int val, int count) -> Tensor(a!)"); m.def("nvshmem_put(Tensor(a!) tensor, int peer) -> ()"); +<<<<<<< HEAD m.def("nvshmem_get(Tensor(a!) tensor, int peer) -> ()"); m.def( "nvshmem_broadcast(Tensor(a!) input, int root, str group_name) -> Tensor(a!)"); @@ -507,6 +527,15 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) { "all_to_all_vdev_2d(Tensor input, Tensor(a!) out, Tensor in_splits, Tensor(a!) out_splits_offsets, str group_name, int? major_align=None) -> ()"); m.def( "all_to_all_vdev_2d_offset(Tensor input, Tensor(a!) out, Tensor in_splits_offsets, Tensor(a!) out_splits_offsets, str group_name) -> ()"); +======= + m.def("nvshmem_broadcast(Tensor(a!) input, str group_name) -> Tensor(a!)"); + m.def( + "nvshmem_all_to_all(Tensor input, Tensor(a!) out, str group_name) -> Tensor(a!)"); + m.def( + "all_to_all_vdev(Tensor input, Tensor(a!) out, Tensor(a!) in_out_splits, str group_name) -> Tensor(a!)"); + m.def( + "all_to_all_vdev_2d(Tensor input, Tensor(a!) out, Tensor(a!) in_out_splits, str group_name, int? major_align=None) -> Tensor(a!)"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } TORCH_LIBRARY_IMPL(symm_mem, Meta, m) { diff --git a/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp b/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp index d2cb70e1b1ae9..0ba8fd06eb1c7 100644 --- a/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp +++ b/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp @@ -50,6 +50,7 @@ class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target { virtual size_t get_buffer_size() = 0; virtual size_t get_signal_pad_size() = 0; +<<<<<<< HEAD virtual size_t get_offset() { TORCH_CHECK(false, "NYI"); } @@ -73,6 +74,22 @@ class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target { int peer, c10::IntArrayRef sizes, c10::ScalarType dtype); +======= + virtual bool has_multicast_support() = 0; + virtual void* get_multicast_ptr() = 0; + + virtual at::Tensor get_buffer( + int rank, + c10::IntArrayRef sizes, + c10::ScalarType dtype, + int64_t storage_offset) = 0; + + virtual at::Tensor get_signal_pad( + int rank, + c10::IntArrayRef sizes, + std::optional dtype = std::nullopt, + int64_t storage_offset = 0) = 0; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) virtual void barrier(int channel, size_t timeout_ms) = 0; virtual void put_signal(int dst_rank, int channel, size_t timeout_ms) = 0; @@ -80,7 +97,10 @@ class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target { virtual int get_rank() = 0; virtual int get_world_size() = 0; +<<<<<<< HEAD virtual c10::Device get_device() = 0; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) virtual const std::vector& get_rank_to_global_rank() { TORCH_CHECK(false, "NYI"); @@ -89,12 +109,15 @@ class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target { virtual int* get_rank_to_global_rank_dev() { TORCH_CHECK(false, "NYI"); } +<<<<<<< HEAD // Returns true if *all* peers within the group are accessible via direct // memory load and store. virtual bool world_within_direct_access() { TORCH_CHECK(false, "NYI"); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; class SymmetricMemoryAllocator : public c10::intrusive_ptr_target { @@ -112,8 +135,11 @@ class SymmetricMemoryAllocator : public c10::intrusive_ptr_target { void* ptr, const std::optional& group_name) = 0; virtual bool has_multicast_support(int device_idx) = 0; +<<<<<<< HEAD virtual c10::DeviceType supported_device_type() = 0; virtual std::string name() = 0; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; C10_EXPORT bool is_finalizing(); @@ -122,10 +148,13 @@ C10_EXPORT void register_allocator( c10::DeviceType device_type, c10::intrusive_ptr allocator); +<<<<<<< HEAD C10_EXPORT void register_availability( const std::string& name, c10::intrusive_ptr allocator); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) C10_EXPORT bool has_allocator(c10::DeviceType device_type); C10_EXPORT c10::intrusive_ptr get_allocator( @@ -195,6 +224,7 @@ TORCH_API c10::intrusive_ptr rendezvous( TORCH_API bool has_multicast_support( c10::DeviceType device_type, int device_idx); +<<<<<<< HEAD TORCH_API void set_backend(const std::string& name); @@ -207,4 +237,6 @@ C10_EXPORT void register_mempool_allocator( TORCH_API std::shared_ptr get_mempool_allocator( c10::Device device); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace c10d::symmetric_memory diff --git a/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu b/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu index 182eaeb90f1a0..b755e34143527 100644 --- a/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu +++ b/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu @@ -1,14 +1,21 @@ #include +<<<<<<< HEAD #include #include #include #include #include +======= +#include + +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include +<<<<<<< HEAD // Use torch's cub wrapper instead of CUDA's , see #55292 #include @@ -38,6 +45,22 @@ namespace c10d::nvshmem_extension { #define WARP_SIZE 32 extern "C" void nvshmem_init() __attribute__((weak)); +======= +#include +// Use torch's cub wrapper instead of CUDA's , see #55292 +#include +#include + +namespace c10d::nvshmem_extension { + +using c10d::symmetric_memory::StoreExchange; +static StoreExchange storeExchange = StoreExchange("nvshmem_ext"); + +#define THREADS_PER_BLOCK 512 +#define WARP_SIZE 32 + +constexpr int MiB = 1024 * 1024; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Check if NVSHMEM is available bool is_nvshmem_available() { @@ -45,12 +68,15 @@ bool is_nvshmem_available() { static std::mutex mutex; static int is_available = -2; std::lock_guard lock(mutex); +<<<<<<< HEAD // Checked if the symbol is statically linked if(is_available == -2 && nvshmem_init) { is_available = 1; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (is_available == -2) { void* handle{}; // Open the shared library, RTLD_LAZY defers symbol resolution until needed @@ -67,10 +93,72 @@ bool is_nvshmem_available() { return is_available == 1; } +<<<<<<< HEAD +======= +// Bootstrap based on user's setting for NCCL +// Long term, this may be a bit unclean; short term, it improves UX +void maybe_initialize_env_vars() { + auto nccl_socket_if_name = c10::utils::get_env("NCCL_SOCKET_IFNAME"); + auto nccl_hca_list = c10::utils::get_env("NCCL_IB_HCA"); + auto nccl_ib_gid_index = c10::utils::get_env("NCCL_IB_GID_INDEX"); + auto nvshmem_socket_if_name = + c10::utils::get_env("NVSHMEM_BOOTSTRAP_UID_SOCK_IFNAME"); + auto nvshmem_hca_list = c10::utils::get_env("NCCL_IB_HCA"); + auto nvshmem_ib_gid_index = c10::utils::get_env("NVSHMEM_IB_GID_INDEX"); + + if (!nvshmem_socket_if_name.has_value() && nccl_socket_if_name.has_value()) { + c10::utils::set_env( + "NVSHMEM_BOOTSTRAP_UID_SOCK_IFNAME", nccl_socket_if_name->c_str()); + } + if (!nvshmem_hca_list.has_value() && nccl_hca_list.has_value()) { + c10::utils::set_env("NVSHMEM_ENABLE_NIC_PE_MAPPING", "1"); + c10::utils::set_env("NVSHMEM_HCA_LIST", nccl_hca_list->c_str()); + } + if (!nvshmem_ib_gid_index.has_value() && nccl_ib_gid_index.has_value()) { + c10::utils::set_env("NVSHMEM_IB_GID_INDEX", nccl_ib_gid_index->c_str()); + } +} + +void initialize_nvshmem_with_store( + c10::intrusive_ptr store, + int rank, + int world_size) { + static bool is_initialized = false; + if (is_initialized) { + return; + } + + maybe_initialize_env_vars(); + + nvshmemx_uniqueid_t unique_id; + TORCH_CHECK( + nvshmemx_get_uniqueid(&unique_id) == 0, "nvshmemx_get_uniqueid failed"); + + // Using an existing store_all_gather due to laziness. + // TODO(yifu): should use broadcast + auto unique_ids = storeExchange.all_gather(store, rank, world_size, unique_id); + + nvshmemx_init_attr_t attr; + nvshmemx_set_attr_uniqueid_args(rank, world_size, &unique_ids[0], &attr); + + TORCH_CHECK( + nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr) == 0, + "nvshmemx_init_attr failed"); + + is_initialized = true; + + // Print version + int major, minor; + ::nvshmem_info_get_version(&major, &minor); + LOG(INFO) << "NVSHMEM is available, version: " << major << "." << minor; +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Initializes the device state in CUmodule so that it’s able to perform NVSHMEM // operations. void nvshmemx_cumodule_init(uintptr_t module) { auto cumodule = reinterpret_cast(module); +<<<<<<< HEAD NVSHMEM_CHECK( ::nvshmemx_cumodule_init(cumodule), "nvshmemx_cumodule_init failed"); @@ -92,6 +180,56 @@ at::Tensor nvshmem_broadcast(at::Tensor& input, const int64_t root, const std::s } void nvshmem_put(at::Tensor& tensor, const int64_t peer) { +======= + TORCH_CHECK( + ::nvshmemx_cumodule_init(cumodule) == 0, + "nvshmemx_cumodule_init failed"); +} + +std::unordered_map group_name_to_team_; + +nvshmem_team_t group_to_team( + const std::string& group_name, + const std::vector& global_ranks) { + auto it = group_name_to_team_.find(group_name); + if (it != group_name_to_team_.end()) { + return it->second; + } + TORCH_CHECK(global_ranks.size() > 1); + int stride = global_ranks[1] - global_ranks[0]; + for (size_t r = 1; r < global_ranks.size(); ++r) { + TORCH_CHECK(global_ranks[r] - global_ranks[r - 1] == stride); + } + + nvshmem_team_t team; + TORCH_CHECK( + nvshmem_team_split_strided( + NVSHMEM_TEAM_WORLD, + global_ranks[0], + stride, + global_ranks.size(), + nullptr, + 0, + &team) == 0); + group_name_to_team_[group_name] = team; + TORCH_CHECK(team != NVSHMEM_TEAM_INVALID); + return team; +} + +at::Tensor nvshmem_broadcast(at::Tensor& input, const std::string& group_name) { + auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name); + int rank = input_hdl->get_rank(); + int world_size = input_hdl->get_world_size(); + auto team = group_to_team(group_name, input_hdl->get_rank_to_global_rank()); + void* buffer_ptr = input_hdl->get_buffer_ptrs()[rank]; + + auto stream = at::cuda::getCurrentCUDAStream(); + nvshmemx_broadcastmem_on_stream(team, buffer_ptr, buffer_ptr, input_hdl->get_buffer_size(), 0, stream); + return input; +} + +void nvshmem_put(at::Tensor& tensor, int64_t peer) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // TODO: support non-contiguous tensors TORCH_CHECK(tensor.is_contiguous(), "put op currently supports contiguous tensors only"); @@ -100,13 +238,17 @@ void nvshmem_put(at::Tensor& tensor, const int64_t peer) { auto rank = hdl->get_rank(); void* buffer_ptr = hdl->get_buffer_ptrs()[rank]; auto buffer_size = tensor.numel() * tensor.element_size(); +<<<<<<< HEAD TORCH_CHECK(peer < hdl->get_world_size(), "peer must be smaller than world size"); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) c10::cuda::CUDAGuard guard(tensor.device()); auto stream = at::cuda::getCurrentCUDAStream(); nvshmemx_putmem_on_stream(buffer_ptr, tensor.data_ptr(), buffer_size, peer, stream); } +<<<<<<< HEAD void nvshmem_get(at::Tensor& tensor, const int64_t peer) { // TODO: support non-contiguous tensors TORCH_CHECK(tensor.is_contiguous(), @@ -123,6 +265,8 @@ void nvshmem_get(at::Tensor& tensor, const int64_t peer) { nvshmemx_getmem_on_stream(tensor.mutable_data_ptr(), buffer_ptr, buffer_size, peer, stream); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) at::Tensor nvshmem_all_to_all( at::Tensor& input, at::Tensor& out, @@ -131,6 +275,7 @@ at::Tensor nvshmem_all_to_all( auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name); int rank = input_hdl->get_rank(); int world_size = input_hdl->get_world_size(); +<<<<<<< HEAD auto& team_manager = TeamManager::get(input.device()); auto team = team_manager.get_team(group_name, input_hdl->get_rank_to_global_rank()); @@ -142,6 +287,13 @@ at::Tensor nvshmem_all_to_all( TORCH_CHECK_EQ(input.numel() % world_size, 0); auto buffer_size = input.numel() * input.element_size(); size_t bytes_per_rank = buffer_size / world_size; +======= + auto team = group_to_team(group_name, input_hdl->get_rank_to_global_rank()); + + void* input_ptr = input_hdl->get_buffer_ptrs()[rank]; + void* output_ptr = out_hdl->get_buffer_ptrs()[rank]; + size_t bytes_per_rank = input_hdl->get_buffer_size() / world_size; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto stream = at::cuda::getCurrentCUDAStream(input.device().index()); nvshmemx_alltoallmem_on_stream(team, output_ptr, input_ptr, bytes_per_rank, stream); @@ -180,6 +332,7 @@ __device__ int64_t prefixSum(int64_t *odata, int64_t *idata, int n) { // - input splits (IN) // - output splits (OUT) and // - source offsets (OUT). +<<<<<<< HEAD __global__ void exchangeSplitAndOffset(int64_t* input_splits, int64_t* out_splits_offsets, nvshmem_team_t team) { #ifndef _NVSHMEM_DEVICELIB_SUPPORTED CUDA_KERNEL_ASSERT_MSG(false, "SM arch unsupported for NVSHMEM"); @@ -192,6 +345,14 @@ __global__ void exchangeSplitAndOffset(int64_t* input_splits, int64_t* out_split int tid = threadIdx.x; CUDA_KERNEL_ASSERT(npes <= THREADS_PER_BLOCK); +======= +__global__ void exchangeSplitAndOffset(int64_t* in_out_splits, int mype, int npes) { + auto input_splits = in_out_splits; + auto output_splits = in_out_splits + npes; + auto source_offsets = in_out_splits + npes * 2; + int tid = threadIdx.x; + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __shared__ int64_t peer_offsets[THREADS_PER_BLOCK]; // Scan input splits to get the source offsets @@ -200,6 +361,7 @@ __global__ void exchangeSplitAndOffset(int64_t* input_splits, int64_t* out_split // Use 1 block to do the exchange if (tid < npes) { +<<<<<<< HEAD // tid is peer index within team, but put calls require global rank int peer_global = nvshmem_team_translate_pe(team, tid, NVSHMEM_TEAM_WORLD); nvshmem_int64_p(source_offsets + mype, peer_offsets[tid], peer_global); @@ -208,11 +370,20 @@ __global__ void exchangeSplitAndOffset(int64_t* input_splits, int64_t* out_split // This barrier ensures that all remote PEs see the updated values nvshmemx_barrier_block(team); #endif +======= + int peer = tid; + nvshmem_int64_p(source_offsets + mype, peer_offsets[peer], peer); + nvshmem_int64_p(output_splits + mype, input_splits[peer], peer); + } + // This barrier ensures that all remote PEs see the updated values + nvshmemx_barrier_all_block(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // This kernel is used to do the actual data exchange. // `in_out_splits` has the same definition as in `exchangeSplitAndOffset`. // `stride` is the stride at dim 0, unit in byte. +<<<<<<< HEAD __global__ void allToAllV(void *send_data, void *recv_data, int64_t* out_splits_offsets, size_t stride, nvshmem_team_t team) { #ifndef _NVSHMEM_DEVICELIB_SUPPORTED CUDA_KERNEL_ASSERT_MSG(false, "SM arch unsupported for NVSHMEM"); @@ -222,12 +393,20 @@ __global__ void allToAllV(void *send_data, void *recv_data, int64_t* out_splits_ int npes = nvshmem_team_n_pes(team); auto output_splits = out_splits_offsets; auto source_offsets = out_splits_offsets + npes; +======= +__global__ void allToAllV(void *send_data, void *recv_data, int64_t* in_out_splits, size_t stride, int mype, int npes) { + auto output_splits = in_out_splits + npes; + auto source_offsets = in_out_splits + npes * 2; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int bid = blockIdx.x; int tid = threadIdx.x; int blocks_per_peer = max(gridDim.x / npes, 1); // Calculate the output offsets +<<<<<<< HEAD CUDA_KERNEL_ASSERT(npes <= THREADS_PER_BLOCK); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __shared__ int64_t peer_offsets[THREADS_PER_BLOCK]; prefixSum(peer_offsets, output_splits, npes); __syncthreads(); @@ -235,7 +414,10 @@ __global__ void allToAllV(void *send_data, void *recv_data, int64_t* out_splits_ // Target a different peer based on bid for (int i = bid / blocks_per_peer; i < npes; i += gridDim.x / blocks_per_peer) { int peer = (mype + i) % npes; +<<<<<<< HEAD auto peer_global = nvshmem_team_translate_pe(team, peer, NVSHMEM_TEAM_WORLD); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Total amount from `peer` auto peer_size = output_splits[peer] * stride; // Amount to get from `peer` in this block @@ -246,16 +428,25 @@ __global__ void allToAllV(void *send_data, void *recv_data, int64_t* out_splits_ auto block_offset = block_size * (bid % blocks_per_peer); auto source_offset = source_offsets[peer] * stride + block_offset; auto write_offset = peer_offsets[peer] * stride + block_offset; +<<<<<<< HEAD nvshmemx_getmem_nbi_block( (char*)recv_data + write_offset, (char*)send_data + source_offset, block_size, peer_global); +======= + nvshmemx_getmem_block( + (char*)recv_data + write_offset, + (char*)send_data + source_offset, + block_size, + peer); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // Write out the output offsets (to the scratchpad line) if (bid == 0 && tid < npes) { source_offsets[tid] = peer_offsets[tid]; } +<<<<<<< HEAD // Make sure getmem_nbi calls finish nvshmem_quiet(); #endif @@ -281,11 +472,20 @@ void all_to_all_vdev( at::Tensor& out, at::Tensor& in_splits, at::Tensor& out_splits_offsets, +======= +} + +at::Tensor all_to_all_vdev( + at::Tensor& input, + at::Tensor& out, + at::Tensor& in_out_splits, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::string group_name) { /* Perform AllToAllv operation using NVSHMEM, with split information provided on device. * Arguments: * - `input` is the input tensor * - `out` is the output tensor +<<<<<<< HEAD * - `in_splits` is a 1D tensor of size (npes), containing the input splits * - `out_splits_offsets` is a 2D tensor of size (2, npes). The rows are (in order): output splits and output offsets. @@ -308,13 +508,37 @@ void all_to_all_vdev( auto& team_manager = TeamManager::get(device); auto team = team_manager.get_team(group_name, input_hdl->get_rank_to_global_rank()); auto stream = at::cuda::getCurrentCUDAStream(device.index()); +======= + * - `in_out_splits` is a 2D tensor of size (3, npes). The rows are (in order): + input splits (IN) + output splits (OUT) and + output offsets (OUT). + */ + auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name); + auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name); + auto splits_hdl = c10d::symmetric_memory::rendezvous(in_out_splits, group_name); + int rank = input_hdl->get_rank(); + int world_size = input_hdl->get_world_size(); + + void* input_ptr = input_hdl->get_buffer_ptrs()[rank]; + void* output_ptr = out_hdl->get_buffer_ptrs()[rank]; + int64_t* splits_ptr = (int64_t*)(splits_hdl->get_buffer_ptrs()[rank]); + + auto stream = at::cuda::getCurrentCUDAStream(input.device().index()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Exchange output splits and source offsets // Use collective launch because kernel involves nvshmem barrier void* args0[] = { +<<<<<<< HEAD &in_splits_ptr, &out_splits_offsets_ptr, &team}; +======= + &splits_ptr, + &rank, + &world_size}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) nvshmemx_collective_launch( (const void*)exchangeSplitAndOffset, dim3(1), @@ -324,11 +548,36 @@ void all_to_all_vdev( stream); // CTA Tuning +<<<<<<< HEAD auto input_size = input.numel() * input.element_size(); int num_blocks = get_a2a_nblocks( input_size, input_hdl->get_world_size(), input_hdl->world_within_direct_access()); +======= + // Intra-node: use multiple blocks per peer to increase data parallelism, up to 8. + // Up to 1 MB -> 1 block + // Up to 2 MB -> 2 blocks + // Up to 4 MB -> 4 blocks + // More -> 8 blocks + // The tuning for `num_blocks` below multiplies these numbers by world_size + // (e.g. 8 -> 8 * 8). If world_size is smaller, we simply shift the blocks + // towards data parallelism. (There may be room for improvement here) + auto input_size = input.numel() * input.element_size(); + int num_blocks = input_size < MiB ? 8 : + (input_size < 2 * MiB ? 16 : + (input_size < 4 * MiB ? 32 : 64)); + + // Inter-node: limit the total the number of blocks: + // = 16 for 16GPUs which is enough to max out 90 GB/s bandwidth perf + // = 8 for more than 16 GPUs which is enough to max out approx 50 GB/s bandwidth perf + // Above assumes 400Gb/s NIC for inter-node and 400GB/s NVLinks for intra-node comms. + // TODO: better intra vs inter detection, currently it is based on world_size. + int max_inter_node_blocks = world_size <= 16 ? 16 : 8; + if (world_size > 8) { + num_blocks = std::min(num_blocks, max_inter_node_blocks); + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Stride at dim 0 (assuming input is contiguous, TODO) size_t stride_bytes = input.stride(0) * input.element_size(); @@ -337,9 +586,16 @@ void all_to_all_vdev( void* args1[] = { &input_ptr, &output_ptr, +<<<<<<< HEAD &out_splits_offsets_ptr, &stride_bytes, &team}; +======= + &splits_ptr, + &stride_bytes, + &rank, + &world_size}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) nvshmemx_collective_launch( (const void*)allToAllV, dim3(num_blocks), @@ -347,6 +603,7 @@ void all_to_all_vdev( args1, 0, stream); +<<<<<<< HEAD } // Start of `all_to_all_vdev_2d` @@ -421,6 +678,45 @@ __global__ void exchangeSplitAndOffset_2d(int64_t* in_splits_offsets, int64_t* o // This barrier ensures that all remote PEs see the updated values nvshmemx_barrier_block(team); #endif +======= + return out; +} + +// Start of `all_to_all_vdev_2d` +// This kernel is used to exchange output splits and source offsets between peers. +// For meaning of `mype` and `npes`, see the docstring of `all_to_all_vdev_2d`. +// `in_out_splits` is of size (3, npes * ne) and contains: +// - input splits (IN) +// - output splits (OUT) and +// - source offsets (OUT). +__global__ void exchangeSplitAndOffset_2d(int64_t* in_out_splits, int mype, int npes, int ne, size_t input_dim0) { + int nsplits = npes * ne; + auto input_splits = in_out_splits; + auto output_splits = in_out_splits + nsplits; + auto source_offsets = in_out_splits + nsplits * 2; + int tid = threadIdx.x; + + __shared__ int64_t peer_offsets[THREADS_PER_BLOCK]; + + // Scan input splits to get the source offsets + auto sum_of_splits = prefixSum(peer_offsets, input_splits, nsplits); + __syncthreads();; + CUDA_KERNEL_ASSERT(sum_of_splits <= input_dim0); + + // Use 1 block to do the exchange + if (tid < nsplits) { + int peer = tid / ne; + int e = tid % ne; + // This does a transpose from rank-major order to expert-major order + int dst_offset = e * npes + mype; + auto split_val = input_splits[tid]; + CUDA_KERNEL_ASSERT(split_val >= 0); + nvshmem_int64_p(source_offsets + dst_offset, peer_offsets[tid], peer); + nvshmem_int64_p(output_splits + dst_offset, split_val, peer); + } + // This barrier ensures that all remote PEs see the updated values + nvshmemx_barrier_all_block(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // This is an warp-scope, exclusive prefix sum. When called by a block of @@ -463,6 +759,7 @@ __device__ int64_t prefixSum_warp(int64_t *odata, int64_t *idata, int n) { // `in_out_splits` has the same definition as in `exchangeSplitAndOffset`. // `stride` is the stride at dim 0, unit in byte. // For meaning of `mype` and `npes`, see the docstring of `all_to_all_vdev_2d`. +<<<<<<< HEAD // `major_align` is the alignment at dim 0, unit in element. If 0, no alignment is needed. // `rank_is_row_out` is a boolean flag indicating whether the output has ranks as rows or experts as rows. @@ -476,6 +773,12 @@ __global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_split int nsplits = minor_size * major_size; auto output_splits = out_splits_offsets; auto source_offsets = out_splits_offsets + nsplits; +======= +__global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_splits, size_t stride, int mype, int npes, int ne, int64_t major_align) { + int nsplits = npes * ne; + auto output_splits = in_out_splits + nsplits; + auto source_offsets = in_out_splits + nsplits * 2; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int bid = blockIdx.x; int tid = threadIdx.x; @@ -485,6 +788,7 @@ __global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_split int laneId = tid % A2AV_TILE_SIZE; // Each tile calculates its own prefix sum __shared__ int64_t tile_prefix_sums[NUM_TILES][A2AV_TILE_SIZE]; +<<<<<<< HEAD // A tile takes care of minor_size worth of splits int nsplits_per_tile = min(minor_size, nsplits - tileId * minor_size); // TODO: currently it is assumed that the number of PE's is smaller than @@ -494,6 +798,17 @@ __global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_split // Similarly, the number of experts per rank is also assumed to be smaller // than `NUM_TILES` CUDA_KERNEL_ASSERT(major_size <= NUM_TILES && "major_size is too large\n"); +======= + // A tile takes care of npes worth of splits + int nsplits_per_tile = min(npes, nsplits - tileId * npes); + // TODO: currently it is assumed that the number of PE's is smaller than + // `A2AV_TILE_SIZE` bc the warp-scope prefix sum can only handle up to + // WARP_SIZE elements + CUDA_KERNEL_ASSERT(npes <= A2AV_TILE_SIZE); + // Similarly, the number of experts per rank is also assumed to be smaller + // than `NUM_TILES` + CUDA_KERNEL_ASSERT(ne <= NUM_TILES); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Total length of each tile __shared__ int64_t len_per_tile[NUM_TILES]; @@ -501,6 +816,7 @@ __global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_split // this local prefix sum. if (nsplits_per_tile > 0) { // Each tile calculates its own prefix sum, return value is the sum of all elements in the tile. +<<<<<<< HEAD int64_t my_tile_len = prefixSum_warp(tile_prefix_sums[tileId], output_splits + tileId * minor_size, nsplits_per_tile); // Last thread in each tile does the up aligning. if (laneId == A2AV_TILE_SIZE - 1) { @@ -513,6 +829,16 @@ __global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_split } else { // 0 means alignment not needed len_per_tile[tileId] = my_tile_len; } +======= + int64_t my_tile_len = prefixSum_warp(tile_prefix_sums[tileId], output_splits + tileId * npes, nsplits_per_tile); + // Last thread in each tile does the up aligning. + if (laneId == A2AV_TILE_SIZE - 1) { + auto aligned_len = (my_tile_len + major_align - 1) / major_align * major_align; + // In case `aligned_len` is 0, we set it to `major_align` to avoid an + // empty bin, bc cutlass currently does not support it. See + // https://github.com/pytorch/pytorch/issues/152668. + len_per_tile[tileId] = max(aligned_len, major_align); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } __syncthreads(); @@ -534,6 +860,7 @@ __global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_split // Target a different e based on bid for (int eid = bid; eid < nsplits; eid += gridDim.x) { +<<<<<<< HEAD int row = eid / minor_size; int col = eid % minor_size; // Amount from `peer` for `e` @@ -562,6 +889,30 @@ void all_to_all_vdev_2d( at::Tensor& out, at::Tensor& in_splits, at::Tensor& out_splits_offsets, +======= + int peer = eid % npes; + // Amount from `peer` for `e` + auto peer_size = output_splits[eid] * stride; + auto source_offset = source_offsets[eid] * stride; + auto e_offset = tile_prefix_sums[eid / npes][peer]; + auto write_offset = e_offset * stride; + nvshmemx_getmem_block( + (char*)recv_data + write_offset, + (char*)send_data + source_offset, + peer_size, + peer); + } + // Write out the output offsets (to the scratchpad line) + if (bid == 0 && tid < nsplits) { + source_offsets[tid] = tile_prefix_sums[tid / npes][tid % npes]; + } +} + +at::Tensor all_to_all_vdev_2d( + at::Tensor& input, + at::Tensor& out, + at::Tensor& in_out_splits, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::string group_name, std::optional major_align) { /* Perform a 2D AllToAllv shuffle operation using NVSHMEM, with split information provided on device. @@ -602,8 +953,12 @@ void all_to_all_vdev_2d( */ auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name); auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name); +<<<<<<< HEAD auto in_splits_hdl = c10d::symmetric_memory::rendezvous(in_splits, group_name); auto out_splits_offsets_hdl = c10d::symmetric_memory::rendezvous(out_splits_offsets, group_name); +======= + auto splits_hdl = c10d::symmetric_memory::rendezvous(in_out_splits, group_name); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int rank = input_hdl->get_rank(); int world_size = input_hdl->get_world_size(); // TODO: world_size is currently limited by the number of elements in a WarpScan. @@ -613,6 +968,7 @@ void all_to_all_vdev_2d( int64_t major_align_val = major_align.value_or(1); TORCH_CHECK(major_align_val > 0, "major_align must be positive"); +<<<<<<< HEAD void* input_ptr = input.data_ptr(); void* output_ptr = out.mutable_data_ptr(); int64_t* in_splits_ptr = (int64_t*)(in_splits.data_ptr()); @@ -631,22 +987,46 @@ void all_to_all_vdev_2d( && out_split_shape[1] == in_split_shape[0] && in_split_shape[0] % world_size == 0, "out_splits_offsets must be 2D with 2 rows, " +======= + void* input_ptr = input_hdl->get_buffer_ptrs()[rank]; + void* output_ptr = out_hdl->get_buffer_ptrs()[rank]; + int64_t* splits_ptr = (int64_t*)(splits_hdl->get_buffer_ptrs()[rank]); + + // Shape checks + auto split_shape = in_out_splits.sizes(); + TORCH_CHECK(in_out_splits.is_contiguous() + && input.is_contiguous() + && out.is_contiguous(), + "input, out and in_out_splits must be contiguous"); + TORCH_CHECK(split_shape.size() == 2 + && split_shape[0] == 3 + && split_shape[1] % world_size == 0, + "in_out_splits must be 2D with 3 rows, " +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "each row must be a multiple of world_size"); // Consistency checks TORCH_CHECK(input.dtype() == out.dtype() && input.stride(0) == out.stride(0), "input and out must have the same dtype and same stride at dim 0"); +<<<<<<< HEAD TORCH_CHECK(in_splits.scalar_type() == at::kLong && out_splits_offsets.scalar_type() == at::kLong, "splits and offsets must be int64"); // Number of experts per rank int ne = in_split_shape[0] / world_size; +======= + TORCH_CHECK(in_out_splits.scalar_type() == at::kLong, "in_out_splits must be int64"); + + // Number of experts per rank + int ne = split_shape[1] / world_size; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) constexpr int NUM_TILES = THREADS_PER_BLOCK / A2AV_TILE_SIZE; TORCH_CHECK(ne <= NUM_TILES, "Number of experts must be smaller than NUM_TILES", NUM_TILES); // Set device context for getting the stream and launching kernels below +<<<<<<< HEAD auto device = input.device(); TORCH_CHECK(device.type() == at::DeviceType::CUDA && out.device() == device && @@ -671,6 +1051,22 @@ void all_to_all_vdev_2d( &rank_is_row_in}; nvshmemx_collective_launch( (const void*)exchangeSplitAndOffset_2d, // false: input offsets not provided +======= + c10::cuda::CUDAGuard guard(input.device()); + auto stream = at::cuda::getCurrentCUDAStream(); + + // Exchange output splits and source offsets + auto input_dim0 = input.size(0); + // Use collective launch because kernel involves nvshmem barrier + void* args0[] = { + &splits_ptr, + &rank, + &world_size, + &ne, + &input_dim0}; + nvshmemx_collective_launch( + (const void*)exchangeSplitAndOffset_2d, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dim3(1), dim3(THREADS_PER_BLOCK), args0, @@ -684,12 +1080,16 @@ void all_to_all_vdev_2d( // Stride at dim 0 size_t stride_bytes = input.stride(0) * input.element_size(); +<<<<<<< HEAD bool rank_is_row_out = !rank_is_row_in; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // All to all data exchange void* args1[] = { &input_ptr, &output_ptr, +<<<<<<< HEAD &in_splits_ptr, &out_splits_offsets_ptr, &stride_bytes, @@ -698,6 +1098,14 @@ void all_to_all_vdev_2d( &major_align_val, &rank_is_row_out, &team}; +======= + &splits_ptr, + &stride_bytes, + &rank, + &world_size, + &ne, + &major_align_val}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) nvshmemx_collective_launch( (const void*)allToAllV_2d, dim3(num_blocks), @@ -705,6 +1113,7 @@ void all_to_all_vdev_2d( args1, 0, stream); +<<<<<<< HEAD } void all_to_all_vdev_2d_offset( @@ -841,15 +1250,26 @@ void all_to_all_vdev_2d_offset( 0, stream); } +======= + return out; +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace c10d::nvshmem_extension TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) { m.impl("nvshmem_broadcast", c10d::nvshmem_extension::nvshmem_broadcast); m.impl("nvshmem_put", c10d::nvshmem_extension::nvshmem_put); +<<<<<<< HEAD m.impl("nvshmem_get", c10d::nvshmem_extension::nvshmem_get); m.impl("nvshmem_all_to_all", c10d::nvshmem_extension::nvshmem_all_to_all); m.impl("all_to_all_vdev", c10d::nvshmem_extension::all_to_all_vdev); m.impl("all_to_all_vdev_2d", c10d::nvshmem_extension::all_to_all_vdev_2d); m.impl("all_to_all_vdev_2d_offset", c10d::nvshmem_extension::all_to_all_vdev_2d_offset); +======= + m.impl("nvshmem_all_to_all", c10d::nvshmem_extension::nvshmem_all_to_all); + m.impl("all_to_all_vdev", c10d::nvshmem_extension::all_to_all_vdev); + m.impl("all_to_all_vdev_2d", c10d::nvshmem_extension::all_to_all_vdev_2d); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } diff --git a/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cuh b/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cuh index ae008921bcd83..8b1ef67ab39c5 100644 --- a/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cuh +++ b/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cuh @@ -1,5 +1,6 @@ #pragma once +<<<<<<< HEAD #include #include @@ -14,6 +15,19 @@ namespace c10d::nvshmem_extension { +======= +#include + +#include + +namespace c10d::nvshmem_extension { + +void initialize_nvshmem_with_store( + c10::intrusive_ptr store, + int rank, + int world_size); + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Check if NVSHMEM is available TORCH_API bool is_nvshmem_available(); @@ -21,17 +35,24 @@ TORCH_API bool is_nvshmem_available(); // operations. TORCH_API void nvshmemx_cumodule_init(uintptr_t module); +<<<<<<< HEAD TORCH_API void nvshmem_put(at::Tensor& tensor, const int64_t peer); TORCH_API void nvshmem_get(at::Tensor& tensor, const int64_t peer); at::Tensor nvshmem_broadcast(at::Tensor& input, const int64_t root, const std::string& group_name); +======= +TORCH_API void nvshmem_put(at::Tensor& tensor, int64_t peer); + +at::Tensor nvshmem_broadcast(at::Tensor& input, const std::string& group_name); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) at::Tensor nvshmem_all_to_all( at::Tensor& input, at::Tensor& out, std::string group_name); +<<<<<<< HEAD void all_to_all_vdev( at::Tensor& input, at::Tensor& out, @@ -54,4 +75,19 @@ void all_to_all_vdev_2d_offset( at::Tensor& out_splits_offsets, std::string group_name); +======= +at::Tensor all_to_all_vdev( + at::Tensor& input, + at::Tensor& out, + at::Tensor& in_out_splits, + std::string group_name); + +at::Tensor all_to_all_vdev_2d( + at::Tensor& input, + at::Tensor& out, + at::Tensor& in_out_splits, + std::string group_name, + std::optional major_align = std::nullopt); + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace c10d::nvshmem_extension diff --git a/torch/csrc/distributed/rpc/init.cpp b/torch/csrc/distributed/rpc/init.cpp index fb812f8522f5a..5a092bdc4ee9a 100644 --- a/torch/csrc/distributed/rpc/init.cpp +++ b/torch/csrc/distributed/rpc/init.cpp @@ -79,8 +79,11 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) { module.attr("_DEFAULT_RPC_TIMEOUT_SEC") = py::cast(kDefaultRpcTimeoutSeconds); module.attr("_UNSET_RPC_TIMEOUT") = py::cast(kUnsetRpcTimeout); module.attr("_DEFAULT_INIT_METHOD") = py::cast(kDefaultInitMethod); +<<<<<<< HEAD module.attr("_DEFAULT_NUM_WORKER_THREADS") = py::cast(kDefaultNumWorkerThreads); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto workerInfo = shared_ptr_class_( @@ -570,6 +573,12 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) { R"(All devices used by the local agent.)") .def("_set_device_map", &TensorPipeRpcBackendOptions::setDeviceMap); +<<<<<<< HEAD +======= + module.attr("_DEFAULT_NUM_WORKER_THREADS") = + py::cast(kDefaultNumWorkerThreads); + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shared_ptr_class_(module, "TensorPipeAgent", rpcAgent) .def( py::init( diff --git a/torch/csrc/distributed/rpc/python_remote_call.h b/torch/csrc/distributed/rpc/python_remote_call.h index 09d4ba36dc62b..f4b842a5c8667 100644 --- a/torch/csrc/distributed/rpc/python_remote_call.h +++ b/torch/csrc/distributed/rpc/python_remote_call.h @@ -3,6 +3,10 @@ #include #include #include +<<<<<<< HEAD +======= +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace torch::distributed::rpc { class TORCH_API PythonRemoteCall : public RpcCommandBase { diff --git a/torch/csrc/distributed/rpc/rpc_agent.h b/torch/csrc/distributed/rpc/rpc_agent.h index e353c54805415..b07f50c8a1c4c 100644 --- a/torch/csrc/distributed/rpc/rpc_agent.h +++ b/torch/csrc/distributed/rpc/rpc_agent.h @@ -24,7 +24,10 @@ constexpr auto kDefaultInitMethod = "env://"; constexpr float kSecToMsConversion = 1000; constexpr auto kRpcTimeoutErrorStr = "RPC ran for more than set timeout ({} ms) and will now be marked with an error"; +<<<<<<< HEAD constexpr auto kDefaultNumWorkerThreads = 16; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) using steady_clock_time_point = std::chrono::time_point; diff --git a/torch/csrc/distributed/rpc/rref_context.cpp b/torch/csrc/distributed/rpc/rref_context.cpp index c36c6386b861e..257a0e916f01e 100644 --- a/torch/csrc/distributed/rpc/rref_context.cpp +++ b/torch/csrc/distributed/rpc/rref_context.cpp @@ -348,7 +348,11 @@ c10::intrusive_ptr RRefContext::getOrCreateOwnerRRef( // here is a plain TensorType, they are not equal relationship: // specialized TensorType <: plain TensorType // +<<<<<<< HEAD // In RPC we don't care the difference as we Ser/De with just the +======= + // In RPC we don't care the difference as we ser'de with just the +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // plain TensorType. This is not a issue for UserRRef creation either, // since Tensor can only get specialized with a previous run of local // JIT function, and we shouldn't preserve the specialized SubTensorType diff --git a/torch/csrc/distributed/rpc/rref_proto.h b/torch/csrc/distributed/rpc/rref_proto.h index a1482b46939b1..9c005f32019fc 100644 --- a/torch/csrc/distributed/rpc/rref_proto.h +++ b/torch/csrc/distributed/rpc/rref_proto.h @@ -4,6 +4,10 @@ #include #include #include +<<<<<<< HEAD +======= +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include namespace torch::distributed::rpc { diff --git a/torch/csrc/distributed/rpc/script_call.h b/torch/csrc/distributed/rpc/script_call.h index 476ee118fe7fa..93962da46c4d1 100644 --- a/torch/csrc/distributed/rpc/script_call.h +++ b/torch/csrc/distributed/rpc/script_call.h @@ -3,6 +3,10 @@ #include #include #include +<<<<<<< HEAD +======= +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include diff --git a/torch/csrc/distributed/rpc/script_remote_call.h b/torch/csrc/distributed/rpc/script_remote_call.h index e18edab648210..28144c695c8a2 100644 --- a/torch/csrc/distributed/rpc/script_remote_call.h +++ b/torch/csrc/distributed/rpc/script_remote_call.h @@ -3,6 +3,10 @@ #include #include #include +<<<<<<< HEAD +======= +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include namespace torch::distributed::rpc { diff --git a/torch/csrc/distributed/rpc/script_resp.h b/torch/csrc/distributed/rpc/script_resp.h index 53841e3d705c2..f60f2b4a8b1dd 100644 --- a/torch/csrc/distributed/rpc/script_resp.h +++ b/torch/csrc/distributed/rpc/script_resp.h @@ -2,6 +2,10 @@ #include #include +<<<<<<< HEAD +======= +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace torch::distributed::rpc { diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.h b/torch/csrc/distributed/rpc/tensorpipe_agent.h index e6f4d66af1388..291df5693d52f 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.h +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.h @@ -74,6 +74,11 @@ struct TORCH_API ChannelRegistration { TORCH_DECLARE_REGISTRY(TensorPipeChannelRegistry, ChannelRegistration); +<<<<<<< HEAD +======= +constexpr auto kDefaultNumWorkerThreads = 16; + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) struct TORCH_API TensorPipeRpcBackendOptions : public RpcBackendOptions { TensorPipeRpcBackendOptions( int numWorkerThreads, diff --git a/torch/csrc/dynamo/cache_entry.cpp b/torch/csrc/dynamo/cache_entry.cpp index beb8064ba6c24..6451c822dc67c 100644 --- a/torch/csrc/dynamo/cache_entry.cpp +++ b/torch/csrc/dynamo/cache_entry.cpp @@ -5,7 +5,11 @@ #include CacheEntry::CacheEntry(const py::handle& guarded_code, PyObject* backend) +<<<<<<< HEAD : backend{py::cast(get_backend(backend))} { +======= + : backend{backend} { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) this->guard_manager = guarded_code.attr("guard_manager"); this->code = guarded_code.attr("code"); this->compile_id = guarded_code.attr("compile_id"); @@ -52,7 +56,10 @@ void CacheEntry::invalidate(py::object deleted_guard_manager) { this->guard_manager = std::move(deleted_guard_manager); this->root_mgr = nullptr; this->trace_annotation = "Invalidated"; +<<<<<<< HEAD this->backend = py::none(); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } void CacheEntry::update_diff_guard_root_manager() { @@ -77,8 +84,13 @@ PyObject* CacheEntry_to_obj(CacheEntry* e) { PyObject* get_backend(PyObject* callback) { py::handle handle = py::handle(callback); +<<<<<<< HEAD while (py::hasattr(handle, "_torchdynamo_orig_backend")) { handle = handle.attr("_torchdynamo_orig_backend"); +======= + while (py::hasattr(handle, "_torchdynamo_orig_callable")) { + handle = handle.attr("_torchdynamo_orig_callable"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } return handle.ptr(); } diff --git a/torch/csrc/dynamo/cache_entry.h b/torch/csrc/dynamo/cache_entry.h index e7c58f31a090d..4d30270b5630d 100644 --- a/torch/csrc/dynamo/cache_entry.h +++ b/torch/csrc/dynamo/cache_entry.h @@ -53,7 +53,11 @@ typedef struct VISIBILITY_HIDDEN CacheEntry { // diff guard root guard manager if exists void* diff_guard_root_mgr{nullptr}; // backend used to create this cache entry +<<<<<<< HEAD py::object backend; +======= + PyObject* backend{nullptr}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Reference to owning ExtraState ExtraState* _owner{nullptr}; // Reference to this CacheEntry's location in owner's linked list diff --git a/torch/csrc/dynamo/compiled_autograd.h b/torch/csrc/dynamo/compiled_autograd.h index c5f5fd8d2f188..99fa238bfa24d 100644 --- a/torch/csrc/dynamo/compiled_autograd.h +++ b/torch/csrc/dynamo/compiled_autograd.h @@ -1106,8 +1106,12 @@ struct IValuePacker { // That's what the TypePtr is for: it contains the information to do the // parsing. See torch::jit::toIValue for more information. static at::TypePtr packed_type() { +<<<<<<< HEAD // On windows CPU is support compiled autograd. #if defined(_WIN32) && (defined(USE_CUDA) || defined(USE_ROCM)) +======= +#ifdef _WIN32 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // NB: the if-constexpr usage triggers compilation errors on Windows // with certain compiler settings // (see https://github.com/pytorch/pytorch/pull/144707 for examples). @@ -1385,8 +1389,12 @@ struct IValuePacker> { } std::vector result; auto lst = t.toList(); +<<<<<<< HEAD for (size_t i = 0; i < lst.size(); ++i) { const at::IValue& elt = lst.get(i); +======= + for (const at::IValue& elt : lst) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) result.emplace_back(IValuePacker::unpack(elt)); } return result; diff --git a/torch/csrc/dynamo/cpython_defs.c b/torch/csrc/dynamo/cpython_defs.c index 244d4165d5e87..07e8781a10481 100644 --- a/torch/csrc/dynamo/cpython_defs.c +++ b/torch/csrc/dynamo/cpython_defs.c @@ -2,6 +2,7 @@ #include #include +<<<<<<< HEAD #if IS_PYTHON_3_14_PLUS const uint8_t* THP_PyOpcode_Caches = NULL; @@ -16,6 +17,8 @@ THP_PyFrame_Clear(_PyInterpreterFrame *frame) #else +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #if IS_PYTHON_3_11_PLUS #define Py_BUILD_CORE @@ -374,5 +377,8 @@ const uint8_t* THP_PyOpcode_Caches = NULL; const int THP_PyOpcode_Caches_size = 0; #endif +<<<<<<< HEAD -#endif // IS_PYTHON_3_14_PLUS \ No newline at end of file +#endif // IS_PYTHON_3_14_PLUS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/csrc/dynamo/cpython_includes.h b/torch/csrc/dynamo/cpython_includes.h index 8c88addf5e42b..745ee074a71fe 100644 --- a/torch/csrc/dynamo/cpython_includes.h +++ b/torch/csrc/dynamo/cpython_includes.h @@ -21,6 +21,7 @@ #if IS_PYTHON_3_11_PLUS #include +<<<<<<< HEAD #if IS_PYTHON_3_14_PLUS #include #endif @@ -31,6 +32,8 @@ #if IS_PYTHON_3_14_PLUS && !defined(_WIN32) #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif #undef Py_BUILD_CORE @@ -40,6 +43,7 @@ extern "C" { #endif +<<<<<<< HEAD #if IS_PYTHON_3_14_PLUS && !defined(_WIN32) #define F_CODE(x) (PyCodeObject*)PyStackRef_AsPyObjectBorrow(x->f_executable) @@ -52,6 +56,8 @@ extern "C" { #else +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #if IS_PYTHON_3_13_PLUS #define F_CODE(x) ((PyCodeObject*)(x)->f_executable) #define PREV_INSTR(x) (x)->instr_ptr @@ -60,8 +66,11 @@ extern "C" { #define PREV_INSTR(x) (x)->prev_instr #endif +<<<<<<< HEAD #endif // IS_PYTHON_3_14_PLUS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #if IS_PYTHON_3_12_PLUS #define FUNC(x) ((x)->f_funcobj) #else diff --git a/torch/csrc/dynamo/eval_frame.c b/torch/csrc/dynamo/eval_frame.c index 07d28e7c77cfb..06f48c4048ed9 100644 --- a/torch/csrc/dynamo/eval_frame.c +++ b/torch/csrc/dynamo/eval_frame.c @@ -224,6 +224,20 @@ const char* get_frame_name(THP_EVAL_API_FRAME_OBJECT* frame) { return PyUnicode_AsUTF8(F_CODE(frame)->co_name); } +<<<<<<< HEAD +======= +void clear_old_frame_if_python_312_plus( + PyThreadState* tstate, + THP_EVAL_API_FRAME_OBJECT* frame) { +#if IS_PYTHON_3_12_PLUS + + THP_PyFrame_Clear(frame); + THP_PyThreadState_PopFrame(tstate, frame); + +#endif +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static PyObject* dynamo_eval_custom_code_impl( PyThreadState* tstate, THP_EVAL_API_FRAME_OBJECT* frame, @@ -474,6 +488,7 @@ static PyObject* dynamo__custom_eval_frame_shim( static void enable_eval_frame_shim(PyThreadState* tstate) {} static void enable_eval_frame_default(PyThreadState* tstate) {} +<<<<<<< HEAD PyObject* dynamo_eval_custom_code( PyThreadState* tstate, THP_EVAL_API_FRAME_OBJECT* frame, @@ -488,6 +503,10 @@ PyObject* dynamo_eval_frame_default( int throw_flag) { return NULL; } static struct PyGetSetDef THPPyInterpreterFrame_properties[] = {{NULL}}; +======= + +static struct PyGetSetDef THPPyInterpreterFrame_properties[] = {NULL}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static PyTypeObject THPPyInterpreterFrameType = { PyVarObject_HEAD_INIT(NULL, 0) @@ -499,6 +518,7 @@ static PyTypeObject THPPyInterpreterFrameType = { #endif // !(IS_PYTHON_3_14_PLUS) +<<<<<<< HEAD void clear_old_frame_if_python_312_plus( PyThreadState* tstate, THP_EVAL_API_FRAME_OBJECT* frame) { @@ -510,6 +530,8 @@ void clear_old_frame_if_python_312_plus( #endif } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static PyObject* increment_working_threads( PyThreadState* tstate, PyObject* module) { @@ -544,6 +566,10 @@ static PyObject* decrement_working_threads( static PyObject* set_eval_frame( PyObject* new_callback, +<<<<<<< HEAD +======= + PyThreadState* tstate, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) PyObject* module) { // Change the eval frame callback and return the old one // - None: disables TorchDynamo @@ -551,6 +577,7 @@ static PyObject* set_eval_frame( // - Python callable(): enables TorchDynamo PyObject* old_callback = eval_frame_callback_get(); +<<<<<<< HEAD // Common case: if Dynamo is actually off, we might see a lot of // traffic setting the callback to None when it was already // None. Skip messing with threading, thread-local storage, and @@ -577,6 +604,24 @@ static PyObject* set_eval_frame( Py_INCREF(old_callback); } +======= + // owned by caller + Py_INCREF(old_callback); + + if (old_callback != Py_None && new_callback == Py_None) { + decrement_working_threads(tstate, module); + } else if (old_callback == Py_None && new_callback != Py_None) { + increment_working_threads(tstate, module); + } + + Py_INCREF(new_callback); + Py_DECREF(old_callback); + + // Set thread local callback. This will drive behavior of our shim, if/when it + // is installed. + eval_frame_callback_set(new_callback); + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return old_callback; } @@ -591,7 +636,11 @@ static PyObject* set_eval_frame_py(PyObject* module, PyObject* callback) { "python enabled=%d and is run_only=%d", callback != Py_None, callback == Py_False); +<<<<<<< HEAD return set_eval_frame(callback, module); +======= + return set_eval_frame(callback, PyThreadState_GET(), module); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } static PyObject* set_skip_guard_eval_unsafe( diff --git a/torch/csrc/dynamo/eval_frame_cpp.cpp b/torch/csrc/dynamo/eval_frame_cpp.cpp index 77927f43b9058..0d45b2e77457f 100644 --- a/torch/csrc/dynamo/eval_frame_cpp.cpp +++ b/torch/csrc/dynamo/eval_frame_cpp.cpp @@ -139,6 +139,7 @@ PyObject* dynamo__custom_eval_frame( auto fail = [&]() { clear_old_frame_if_python_312_plus(tstate, frame); }; +<<<<<<< HEAD #if IS_PYTHON_3_12_PLUS // skip tracing the frame if CPython is in a tracing state (e.g. // sys.monitoring call) @@ -148,6 +149,8 @@ PyObject* dynamo__custom_eval_frame( } #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ExtraState* extra = get_extra_state(F_CODE(frame)); if (callback.is(py::bool_(false)) && extra == nullptr) { @@ -344,6 +347,7 @@ PyObject* set_code_exec_strategy(PyObject* dummy, PyObject* args) { extra_state_set_exec_strategy(extra, strategy); Py_RETURN_NONE; } +<<<<<<< HEAD void skip_code_recursive(PyCodeObject* code) { ExtraState* extra = get_extra_state(code); @@ -355,3 +359,5 @@ void skip_code_recursive(PyCodeObject* code) { FrameExecStrategy{FrameAction::SKIP, FrameAction::SKIP}; extra_state_set_exec_strategy(extra, strategy); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/csrc/dynamo/eval_frame_cpp.h b/torch/csrc/dynamo/eval_frame_cpp.h index 2f3587094f763..f3e5e93d023b5 100644 --- a/torch/csrc/dynamo/eval_frame_cpp.h +++ b/torch/csrc/dynamo/eval_frame_cpp.h @@ -17,7 +17,10 @@ PyObject* dynamo__custom_eval_frame( PyObject* callback); PyObject* set_code_exec_strategy(PyObject* dummy, PyObject* obj); +<<<<<<< HEAD void skip_code_recursive(PyCodeObject* code); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #ifdef __cplusplus diff --git a/torch/csrc/dynamo/extra_state.cpp b/torch/csrc/dynamo/extra_state.cpp index ad617a8de5b09..2cccf967cf02c 100644 --- a/torch/csrc/dynamo/extra_state.cpp +++ b/torch/csrc/dynamo/extra_state.cpp @@ -152,8 +152,13 @@ void lookup( for (CacheEntry& cache_entry : extra_state->cache_entry_list) { // Check backend. Py_False means run only mode. +<<<<<<< HEAD bool valid = backend == Py_False || backend_match(cache_entry.backend.ptr(), backend); +======= + bool valid = + backend == Py_False || backend_match(cache_entry.backend, backend); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (valid) { try { @@ -268,6 +273,7 @@ void _load_precompile_entry( PrecompileEntry(std::move(guard_manager), std::move(dynamo_code)); extra->precompile_entries.push_back(std::move(entry)); } +<<<<<<< HEAD py::list _debug_get_precompile_entries(const py::handle& code_obj) { if (!py::isinstance(code_obj, py::module::import("types").attr("CodeType"))) { @@ -283,3 +289,5 @@ py::list _debug_get_precompile_entries(const py::handle& code_obj) { } return result; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/csrc/dynamo/extra_state.h b/torch/csrc/dynamo/extra_state.h index 1630ac90b21dd..caf4bc5b7529f 100644 --- a/torch/csrc/dynamo/extra_state.h +++ b/torch/csrc/dynamo/extra_state.h @@ -202,6 +202,9 @@ void _load_precompile_entry( const py::handle& code_obj, py::object guard_manager, py::object dynamo_code); +<<<<<<< HEAD py::list _debug_get_precompile_entries(const py::handle& code_obj); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif diff --git a/torch/csrc/dynamo/framelocals_mapping.cpp b/torch/csrc/dynamo/framelocals_mapping.cpp index 16420ddc90e60..dcf6d99012815 100644 --- a/torch/csrc/dynamo/framelocals_mapping.cpp +++ b/torch/csrc/dynamo/framelocals_mapping.cpp @@ -4,9 +4,13 @@ #include #include +<<<<<<< HEAD #define Py_BUILD_CORE #include #undef Py_BUILD_CORE +======= +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #if IS_PYTHON_3_11_PLUS @@ -28,6 +32,7 @@ FrameLocalsMapping::FrameLocalsMapping(FrameLocalsFrameType* frame) PyCodeObject* co = F_CODE(frame); _framelocals.resize(co->co_nlocalsplus, nullptr); +<<<<<<< HEAD #if IS_PYTHON_3_14_PLUS TORCH_CHECK(false, "Python 3.14+ not supported"); #else @@ -35,6 +40,11 @@ FrameLocalsMapping::FrameLocalsMapping(FrameLocalsFrameType* frame) return; } #endif +======= + if (!frame->stacktop) { + return; + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto update_framelocals = [&](int i, PyObject* value) { _PyLocals_Kind kind = _PyLocals_GetKind(co->co_localspluskinds, i); @@ -59,6 +69,7 @@ FrameLocalsMapping::FrameLocalsMapping(FrameLocalsFrameType* frame) }; auto offset = co->co_nlocalsplus - co->co_nfreevars; +<<<<<<< HEAD #if IS_PYTHON_3_14_PLUS TORCH_CHECK(false, "Python 3.14+ not supported"); #else @@ -74,6 +85,13 @@ FrameLocalsMapping::FrameLocalsMapping(FrameLocalsFrameType* frame) #else PyObject* closure = ((PyFunctionObject*)FUNC(frame))->func_closure; #endif +======= + for (int i = 0; i < offset; i++) { + update_framelocals(i, frame->localsplus[i]); + } + // Get references to closure variables + PyObject* closure = ((PyFunctionObject*)FUNC(frame))->func_closure; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (int i = 0; i < co->co_nfreevars; i++) { update_framelocals(offset + i, PyTuple_GET_ITEM(closure, i)); } diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index c8e0ae9c27360..436e772798b0e 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -6,7 +6,10 @@ #include #include #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include @@ -31,15 +34,19 @@ #include #endif +<<<<<<< HEAD #ifdef USE_MTIA #include #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include #include +<<<<<<< HEAD // Uncomment next line to count instructions for guard eval. // #define GUARD_INSTRUCTION_COUNT #ifdef GUARD_INSTRUCTION_COUNT @@ -79,6 +86,8 @@ uint64_t count_instructions(const std::function& fn) { } #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Certain CPython data structures are defined in `.c` files in earlier Python // versions, e.g., for TupleIteratorGetItemAccessor, we need a fast way to // retrieve the underlying tuple and access the item. Before Python 3.12 @@ -634,7 +643,11 @@ struct GlobalStateGuard { _torch_function_all_disabled = at::impl::torch_function_all_disabled(); _deterministic_algorithms = ctx.deterministicAlgorithms(); _deterministic_algorithms_warn_only = ctx.deterministicAlgorithmsWarnOnly(); +<<<<<<< HEAD _allow_tf32 = ctx.float32Precision("cuda", "matmul") == "tf32"; +======= + _allow_tf32 = ctx.allowTF32CuBLAS(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _allow_fp16_reduce = ctx.allowFP16ReductionCuBLAS(); _allow_bf16_reduce = ctx.allowBF16ReductionCuBLAS(); _num_threads = at::get_num_threads(); @@ -651,7 +664,11 @@ struct GlobalStateGuard { _deterministic_algorithms == ctx.deterministicAlgorithms() && _deterministic_algorithms_warn_only == ctx.deterministicAlgorithmsWarnOnly() && +<<<<<<< HEAD _allow_tf32 == (ctx.float32Precision("cuda", "matmul") == "tf32") && +======= + _allow_tf32 == ctx.allowTF32CuBLAS() && +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _allow_fp16_reduce == ctx.allowFP16ReductionCuBLAS() && _allow_bf16_reduce == ctx.allowBF16ReductionCuBLAS() && _num_threads == at::get_num_threads()) && @@ -672,7 +689,11 @@ struct GlobalStateGuard { if (_deterministic_algorithms_warn_only != ctx.deterministicAlgorithmsWarnOnly()) os << "deterministic_algorithms_warn_only "; +<<<<<<< HEAD if (_allow_tf32 != (ctx.float32Precision("cuda", "matmul") == "tf32")) +======= + if (_allow_tf32 != ctx.allowTF32CuBLAS()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) os << "allow_tf32 "; if (_allow_fp16_reduce != ctx.allowFP16ReductionCuBLAS()) os << "allow_fp16_reduce "; @@ -788,11 +809,19 @@ static PyMethodDef GlobalStateGuard_methods[] = { (PyCFunction)(void*)GlobalStateGuard_reason, METH_NOARGS, "Return string reason for guard check failing"}, +<<<<<<< HEAD {"__getstate__", (PyCFunction)(void*)GlobalStateGuard_dump, METH_NOARGS, "Return serialized json format"}, {"__setstate__", +======= + {"dump", + (PyCFunction)(void*)GlobalStateGuard_dump, + METH_NOARGS, + "Return serialized json format"}, + {"load", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) (PyCFunction)(void*)GlobalStateGuard_load, METH_VARARGS, "Parse serialized json format"}, @@ -834,7 +863,10 @@ static PyObject* check_obj_id(PyObject* dummy, PyObject* args) { static std::unordered_map dict_version_map; static int dict_version_watcher_id; +<<<<<<< HEAD static int dict_recursive_tag_watcher_id; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static uint64_t global_dict_version_id = 1; static int dict_version_watch_callback( PyDict_WatchEvent event, @@ -1043,8 +1075,12 @@ static void _parse_empty_strided_args( static PyObject* _empty_strided_device( PyObject* dummy, PyObject* args, +<<<<<<< HEAD c10::DeviceType device_type, bool is_pinned = false) { +======= + c10::DeviceType device_type) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) HANDLE_TH_ERRORS; at::SmallVector sizes; at::SmallVector strides; @@ -1052,7 +1088,11 @@ static PyObject* _empty_strided_device( _parse_empty_strided_args(args, sizes, strides, dtype); if (device_type == c10::DeviceType::CPU) { return THPVariable_Wrap( +<<<<<<< HEAD at::detail::empty_strided_cpu(sizes, strides, dtype, is_pinned)); +======= + at::detail::empty_strided_cpu(sizes, strides, dtype)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } #ifdef USE_CUDA else if (device_type == c10::DeviceType::CUDA) { @@ -1066,12 +1106,15 @@ static PyObject* _empty_strided_device( sizes, strides, dtype, c10::DeviceType::XPU)); } #endif +<<<<<<< HEAD #ifdef USE_MTIA else if (device_type == c10::DeviceType::MTIA) { return THPVariable_Wrap(at::detail::empty_strided_mtia( sizes, strides, dtype, c10::DeviceType::MTIA)); } #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else { TORCH_CHECK( false, "PyTorch compiled without support for the specified device."); @@ -1086,6 +1129,7 @@ static PyObject* _empty_strided_cpu(PyObject* dummy, PyObject* args) { return _empty_strided_device(dummy, args, c10::DeviceType::CPU); } +<<<<<<< HEAD static PyObject* _empty_strided_cpu_pinned(PyObject* dummy, PyObject* args) { // at::empty_strided is surprising slow. This is a lower-overhead // version that saves ~2us on every allocation. @@ -1093,6 +1137,8 @@ static PyObject* _empty_strided_cpu_pinned(PyObject* dummy, PyObject* args) { dummy, args, c10::DeviceType::CPU, /*is_pinned=*/true); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static PyObject* _empty_strided_cuda(PyObject* dummy, PyObject* args) { // at::empty_strided is surprising slow. This is lower-overhead. return _empty_strided_device(dummy, args, c10::DeviceType::CUDA); @@ -1103,10 +1149,13 @@ static PyObject* _empty_strided_xpu(PyObject* dummy, PyObject* args) { return _empty_strided_device(dummy, args, c10::DeviceType::XPU); } +<<<<<<< HEAD static PyObject* _empty_strided_mtia(PyObject* dummy, PyObject* args) { return _empty_strided_device(dummy, args, c10::DeviceType::MTIA); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static PyObject* _reinterpret_tensor(PyObject* dummy, PyObject* args) { HANDLE_TH_ERRORS; static PythonArgParser parser( @@ -1136,6 +1185,7 @@ static PyMethodDef _methods[] = { {"assert_alignment", assert_alignment, METH_VARARGS, nullptr}, {"dict_version", dict_version, METH_VARARGS, nullptr}, {"_empty_strided_cpu", _empty_strided_cpu, METH_VARARGS, nullptr}, +<<<<<<< HEAD {"_empty_strided_cpu_pinned", _empty_strided_cpu_pinned, METH_VARARGS, @@ -1143,6 +1193,10 @@ static PyMethodDef _methods[] = { {"_empty_strided_cuda", _empty_strided_cuda, METH_VARARGS, nullptr}, {"_empty_strided_xpu", _empty_strided_xpu, METH_VARARGS, nullptr}, {"_empty_strided_mtia", _empty_strided_mtia, METH_VARARGS, nullptr}, +======= + {"_empty_strided_cuda", _empty_strided_cuda, METH_VARARGS, nullptr}, + {"_empty_strided_xpu", _empty_strided_xpu, METH_VARARGS, nullptr}, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"_reinterpret_tensor", _reinterpret_tensor, METH_VARARGS, nullptr}, {nullptr, nullptr, 0, nullptr}}; @@ -1185,12 +1239,18 @@ bool is_immutable_object(py::handle example_value) { return true; } +<<<<<<< HEAD return (example_value.ptr() == Py_None) || PyLong_Check(example_value.ptr()) || PyFloat_Check(example_value.ptr()) || PyBool_Check(example_value.ptr()) || PyUnicode_Check(example_value.ptr()) || PyCode_Check(example_value.ptr()) || (Py_TYPE(example_value.ptr()) == &PyCFunction_Type) || +======= + return PyLong_Check(example_value.ptr()) || + PyFloat_Check(example_value.ptr()) || PyBool_Check(example_value.ptr()) || + PyUnicode_Check(example_value.ptr()) || +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) (is_tensor_immutable && THPVariable_Check(example_value.ptr())); } @@ -1558,6 +1618,7 @@ class GuardManager; class RootGuardManager; class DictGuardManager; +<<<<<<< HEAD // Global registry used by the *recursive-dict-tag* optimisation. // // Key : `PyObject*` pointing to a watched `dict` @@ -1589,11 +1650,22 @@ class DictGuardManager; // stores only lightweight pointers. std::unordered_map> dict_to_guard_managers; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) /** * Base class for the leaf guard in the GuardManager hierarchy. */ class LeafGuard { public: +<<<<<<< HEAD +======= + // Most guards do not need root guard manager. + LeafGuard(py::object verbose_code_parts) + : _verbose_code_parts(std::move(verbose_code_parts)) {} + + // Guards like TENSOR_MATCH require root_guard_manager to access local_state + // shared across all leaf guards. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) LeafGuard(RootGuardManager* root_guard_manager, py::object verbose_code_parts) : _root_guard_manager(root_guard_manager), _verbose_code_parts(std::move(verbose_code_parts)) {} @@ -1655,11 +1727,16 @@ class LeafGuard { */ class LAMBDA_GUARD : public LeafGuard { public: +<<<<<<< HEAD LAMBDA_GUARD( RootGuardManager* root_guard_manager, py::object guard_check_fn, py::object verbose_code_parts) : LeafGuard(root_guard_manager, std::move(verbose_code_parts)) { +======= + LAMBDA_GUARD(py::object guard_check_fn, py::object verbose_code_parts) + : LeafGuard(std::move(verbose_code_parts)) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (py::isinstance(guard_check_fn)) { _guard_check_fn = py::cast(std::move(guard_check_fn)); } else { @@ -1704,11 +1781,16 @@ class LAMBDA_GUARD : public LeafGuard { class TYPE_MATCH : public LeafGuard { public: // type_id = id(type(obj)) +<<<<<<< HEAD TYPE_MATCH( RootGuardManager* root_guard_manager, py::object type_id, py::object verbose_code_parts) : LeafGuard(root_guard_manager, std::move(verbose_code_parts)), +======= + TYPE_MATCH(py::object type_id, py::object verbose_code_parts) + : LeafGuard(std::move(verbose_code_parts)), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _expected(py::cast(std::move(type_id))) {} bool check_nopybind(PyObject* value) override { // borrowed ref @@ -1724,11 +1806,16 @@ class TYPE_MATCH : public LeafGuard { class ID_MATCH : public LeafGuard { public: // obj_id = id(obj) +<<<<<<< HEAD ID_MATCH( RootGuardManager* root_guard_manager, py::object obj_id, py::object verbose_code_parts) : LeafGuard(root_guard_manager, std::move(verbose_code_parts)), +======= + ID_MATCH(py::object obj_id, py::object verbose_code_parts) + : LeafGuard(std::move(verbose_code_parts)), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _expected(py::cast(std::move(obj_id))) {} bool check_nopybind(PyObject* value) override { // borrowed ref @@ -1743,10 +1830,15 @@ class ID_MATCH : public LeafGuard { class NONE_MATCH : public LeafGuard { public: +<<<<<<< HEAD NONE_MATCH( RootGuardManager* root_guard_manager, py::object verbose_code_parts) : LeafGuard(root_guard_manager, std::move(verbose_code_parts)) {} +======= + NONE_MATCH(py::object verbose_code_parts) + : LeafGuard(std::move(verbose_code_parts)) {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool check_nopybind(PyObject* value) override { // borrowed ref return value == Py_None; @@ -1755,10 +1847,15 @@ class NONE_MATCH : public LeafGuard { class TRUE_MATCH : public LeafGuard { public: +<<<<<<< HEAD TRUE_MATCH( RootGuardManager* root_guard_manager, py::object verbose_code_parts) : LeafGuard(root_guard_manager, std::move(verbose_code_parts)) {} +======= + TRUE_MATCH(py::object verbose_code_parts) + : LeafGuard(std::move(verbose_code_parts)) {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool check_nopybind(PyObject* value) override { // borrowed ref return value == Py_True; @@ -1767,10 +1864,15 @@ class TRUE_MATCH : public LeafGuard { class FALSE_MATCH : public LeafGuard { public: +<<<<<<< HEAD FALSE_MATCH( RootGuardManager* root_guard_manager, py::object verbose_code_parts) : LeafGuard(root_guard_manager, std::move(verbose_code_parts)) {} +======= + FALSE_MATCH(py::object verbose_code_parts) + : LeafGuard(std::move(verbose_code_parts)) {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool check_nopybind(PyObject* value) override { // borrowed ref return value == Py_False; @@ -1779,11 +1881,16 @@ class FALSE_MATCH : public LeafGuard { class EQUALS_MATCH : public LeafGuard { public: +<<<<<<< HEAD EQUALS_MATCH( RootGuardManager* root_guard_manager, py::object value, py::object verbose_code_parts) : LeafGuard(root_guard_manager, std::move(verbose_code_parts)), +======= + EQUALS_MATCH(py::object value, py::object verbose_code_parts) + : LeafGuard(std::move(verbose_code_parts)), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _value(value), _value_type(Py_TYPE(value.ptr())) {} @@ -1820,13 +1927,20 @@ class EQUALS_MATCH : public LeafGuard { class RANGE_ITERATOR_MATCH : public LeafGuard { public: RANGE_ITERATOR_MATCH( +<<<<<<< HEAD RootGuardManager* root_guard_manager, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py::object start, py::object stop, py::object step, py::object type_id, py::object verbose_code_parts) +<<<<<<< HEAD : LeafGuard(root_guard_manager, std::move(verbose_code_parts)), +======= + : LeafGuard(std::move(verbose_code_parts)), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _type_id(py::cast(std::move(type_id))) { PyObject* start_obj = start.ptr(); PyObject* stop_obj = stop.ptr(); @@ -1867,11 +1981,18 @@ class RANGE_ITERATOR_MATCH : public LeafGuard { class TUPLE_ITERATOR_LEN : public LeafGuard { public: TUPLE_ITERATOR_LEN( +<<<<<<< HEAD RootGuardManager* root_guard_manager, py::object length, py::object type_id, py::object verbose_code_parts) : LeafGuard(root_guard_manager, std::move(verbose_code_parts)), +======= + py::object length, + py::object type_id, + py::object verbose_code_parts) + : LeafGuard(std::move(verbose_code_parts)), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _length(py::cast(std::move(length))), _type_id(py::cast(std::move(type_id))) {} @@ -1896,11 +2017,16 @@ class TUPLE_ITERATOR_LEN : public LeafGuard { class LENGTH_CHECK : public LeafGuard { public: +<<<<<<< HEAD LENGTH_CHECK( RootGuardManager* root_guard_manager, py::object value, py::object verbose_code_parts) : LeafGuard(root_guard_manager, std::move(verbose_code_parts)), +======= + LENGTH_CHECK(py::object value, py::object verbose_code_parts) + : LeafGuard(std::move(verbose_code_parts)), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _length(py::cast(std::move(value))) {} bool check_nopybind(PyObject* value) override { // borrowed ref @@ -1916,11 +2042,16 @@ class LENGTH_CHECK : public LeafGuard { class DICT_LENGTH : public LeafGuard { public: +<<<<<<< HEAD DICT_LENGTH( RootGuardManager* root_guard_manager, py::object value, py::object verbose_code_parts) : LeafGuard(root_guard_manager, std::move(verbose_code_parts)), +======= + DICT_LENGTH(py::object value, py::object verbose_code_parts) + : LeafGuard(std::move(verbose_code_parts)), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _length(py::cast(std::move(value))) {} bool check_nopybind(PyObject* value) override { // borrowed ref @@ -1934,8 +2065,13 @@ class DICT_LENGTH : public LeafGuard { class NOT_NONE : public LeafGuard { public: +<<<<<<< HEAD NOT_NONE(RootGuardManager* root_guard_manager, py::object verbose_code_parts) : LeafGuard(root_guard_manager, std::move(verbose_code_parts)) {} +======= + NOT_NONE(py::object verbose_code_parts) + : LeafGuard(std::move(verbose_code_parts)) {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool check_nopybind(PyObject* value) override { // borrowed ref return value != Py_None; @@ -1944,11 +2080,16 @@ class NOT_NONE : public LeafGuard { class MAPPING_KEYS_MATCH : public LeafGuard { public: +<<<<<<< HEAD MAPPING_KEYS_MATCH( RootGuardManager* root_guard_manager, py::object value, py::object verbose_code_parts) : LeafGuard(root_guard_manager, std::move(verbose_code_parts)) { +======= + MAPPING_KEYS_MATCH(py::object value, py::object verbose_code_parts) + : LeafGuard(std::move(verbose_code_parts)) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // This is ok to stash in the state because we only support // MappingProxyType objects with constant keys. So, the mem overhead is // negligible. @@ -1968,10 +2109,15 @@ class MAPPING_KEYS_MATCH : public LeafGuard { class DEFAULT_DEVICE : public LeafGuard { public: +<<<<<<< HEAD DEFAULT_DEVICE( RootGuardManager* root_guard_manager, py::object verbose_code_parts) : LeafGuard(root_guard_manager, std::move(verbose_code_parts)) { +======= + DEFAULT_DEVICE(py::object verbose_code_parts) + : LeafGuard(std::move(verbose_code_parts)) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py::handle device_module = py::module::import("torch.utils._device"); // Save the dict using py::object _utils_device_dict = device_module.attr("__dict__"); @@ -2015,6 +2161,7 @@ class DEFAULT_DEVICE : public LeafGuard { class GLOBAL_STATE : public LeafGuard { public: +<<<<<<< HEAD GLOBAL_STATE( RootGuardManager* root_guard_manager, py::object verbose_code_parts) @@ -2034,6 +2181,12 @@ class GLOBAL_STATE : public LeafGuard { if (!PyObject_TypeCheck(owner_.ptr(), &GlobalStateGuardType)) { throw py::type_error("GLOBAL_STATE expects a GlobalStateGuard"); } +======= + GLOBAL_STATE(py::object verbose_code_parts) + : LeafGuard(std::move(verbose_code_parts)) { + _guard = std::make_unique(); + _guard->init(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } bool check_nopybind(PyObject* value) override { // borrowed ref @@ -2055,8 +2208,12 @@ class GLOBAL_STATE : public LeafGuard { } private: +<<<<<<< HEAD py::object owner_; GlobalStateGuard* _guard; +======= + std::unique_ptr _guard; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; // Checks that an attr is absent in the object. We don't need the opposite @@ -2064,11 +2221,16 @@ class GLOBAL_STATE : public LeafGuard { // HASATTR guard. class NO_HASATTR : public LeafGuard { public: +<<<<<<< HEAD NO_HASATTR( RootGuardManager* root_guard_manager, py::object attr_name, py::object verbose_code_parts) : LeafGuard(root_guard_manager, std::move(verbose_code_parts)), +======= + NO_HASATTR(py::object attr_name, py::object verbose_code_parts) + : LeafGuard(std::move(verbose_code_parts)), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _attr_name(std::move(attr_name)) {} bool check_nopybind(PyObject* value) override { // borrowed ref @@ -2086,12 +2248,17 @@ class NO_HASATTR : public LeafGuard { // being faster. class DICT_CONTAINS : public LeafGuard { public: +<<<<<<< HEAD DICT_CONTAINS( RootGuardManager* root_guard_manager, bool contains, py::object key, py::object verbose_code_parts) : LeafGuard(root_guard_manager, std::move(verbose_code_parts)), +======= + DICT_CONTAINS(bool contains, py::object key, py::object verbose_code_parts) + : LeafGuard(std::move(verbose_code_parts)), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _contains(contains ? 1 : 0), _key(std::move(key)) {} @@ -2109,6 +2276,7 @@ class DICT_CONTAINS : public LeafGuard { py::object _key; }; +<<<<<<< HEAD // Check that set contains an item. class SET_CONTAINS : public LeafGuard { public: @@ -2136,6 +2304,8 @@ class SET_CONTAINS : public LeafGuard { py::object _item; }; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) /** * Relational guards compare more than one value. We implement Relational * guards by capturing some state in the guard object. For example for tensor @@ -2153,10 +2323,15 @@ class SET_CONTAINS : public LeafGuard { */ class RelationalGuard : public LeafGuard { public: +<<<<<<< HEAD RelationalGuard( RootGuardManager* root_guard_manager, py::object verbose_code_parts) : LeafGuard(root_guard_manager, std::move(verbose_code_parts)) {} +======= + RelationalGuard(py::object verbose_code_parts) + : LeafGuard(std::move(verbose_code_parts)) {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // reset the relational guard state on guard failure. This is called by the // guard manager. @@ -2168,10 +2343,15 @@ class RelationalGuard : public LeafGuard { */ class OBJECT_ALIASING : public RelationalGuard { public: +<<<<<<< HEAD OBJECT_ALIASING( RootGuardManager* root_guard_manager, py::object verbose_code_parts) : RelationalGuard(root_guard_manager, std::move(verbose_code_parts)) {} +======= + OBJECT_ALIASING(py::object verbose_code_parts) + : RelationalGuard(std::move(verbose_code_parts)) {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool check_nopybind(PyObject* value) override { // borrowed ref if (_is_first_call) { @@ -2197,10 +2377,16 @@ class OBJECT_ALIASING : public RelationalGuard { class NO_TENSOR_ALIASING : public RelationalGuard { public: NO_TENSOR_ALIASING( +<<<<<<< HEAD RootGuardManager* root_guard_manager, const py::list& tensor_names, py::object verbose_code_parts) : RelationalGuard(root_guard_manager, std::move(verbose_code_parts)), +======= + const py::list& tensor_names, + py::object verbose_code_parts) + : RelationalGuard(std::move(verbose_code_parts)), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _tensor_names(tensor_names) { _unique_tensors.reserve(tensor_names.size()); } @@ -2248,11 +2434,18 @@ class NO_TENSOR_ALIASING : public RelationalGuard { class STORAGE_OVERLAPPING : public RelationalGuard { public: STORAGE_OVERLAPPING( +<<<<<<< HEAD RootGuardManager* root_guard_manager, bool overlapping, std::shared_ptr checker, py::object verbose_code_parts) : RelationalGuard(root_guard_manager, std::move(verbose_code_parts)), +======= + bool overlapping, + std::shared_ptr checker, + py::object verbose_code_parts) + : RelationalGuard(std::move(verbose_code_parts)), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _overlapping(overlapping), _checker(std::move(checker)) {} @@ -2280,13 +2473,20 @@ class STORAGE_OVERLAPPING : public RelationalGuard { class SYMBOLIC_SHAPE_GUARD : public RelationalGuard { public: SYMBOLIC_SHAPE_GUARD( +<<<<<<< HEAD RootGuardManager* root_guard_manager, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py::int_ nargs_int, py::int_ nargs_float, py::int_ py_addr, py::object py_addr_keep_alive, py::object verbose_code_parts) +<<<<<<< HEAD : RelationalGuard(root_guard_manager, std::move(verbose_code_parts)), +======= + : RelationalGuard(std::move(verbose_code_parts)), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _py_addr_keep_alive(std::move(py_addr_keep_alive)) { _nargs_int = PyLong_AsSize_t(nargs_int.ptr()); _nargs_float = PyLong_AsSize_t(nargs_float.ptr()); @@ -2394,12 +2594,19 @@ class DYNAMIC_INDICES : public LeafGuard { // f"(({tensor_name}._dynamo_dynamic_indices.issubset({value._dynamo_dynamic_indices})) // if hasattr({tensor_name}, '_dynamo_dynamic_indices') else True)" # // noqa: B950 +<<<<<<< HEAD public: DYNAMIC_INDICES( RootGuardManager* root_guard_manager, py::set dynamic_indices, py::object verbose_code_parts) : LeafGuard(root_guard_manager, std::move(verbose_code_parts)), +======= + // ) + public: + DYNAMIC_INDICES(py::set dynamic_indices, py::object verbose_code_parts) + : LeafGuard(std::move(verbose_code_parts)), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _dynamic_indices(std::move(dynamic_indices)) {} bool check_nopybind(PyObject* value) override { // borrowed ref @@ -2429,11 +2636,16 @@ class DYNAMIC_INDICES : public LeafGuard { class DICT_VERSION : public LeafGuard { public: +<<<<<<< HEAD DICT_VERSION( RootGuardManager* root_guard_manager, py::object value, py::object verbose_code_parts) : LeafGuard(root_guard_manager, std::move(verbose_code_parts)) { +======= + DICT_VERSION(py::object value, py::object verbose_code_parts) + : LeafGuard(std::move(verbose_code_parts)) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (!PyDict_Check(value.ptr())) { throw py::type_error("DICT_VERSION expects a dict"); } @@ -2458,6 +2670,7 @@ std::unique_ptr make_guard_manager( py::handle example_value, py::handle guard_manager_enum); +<<<<<<< HEAD // Forward declarations for tag safe related helpers. All of these require some // interaction between RootGuardManager and GuardManager. Since both of the // classes are forward declared, we have to forward declare these helpers as @@ -2473,6 +2686,8 @@ bool is_recording_dict_pointers(RootGuardManager* root); void record_dict_pointer(RootGuardManager* root, PyObject* dict_pointer); void record_tensor_pointer(RootGuardManager* root, PyObject* tensor_pointer); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GuardManager* clone_guard_manager( GuardManager* from, RootGuardManager* root, @@ -2480,6 +2695,7 @@ GuardManager* clone_guard_manager( void add_relational_guard_resetter_to_cloned_root( RootGuardManager* root, std::shared_ptr guard); +<<<<<<< HEAD std::shared_ptr get_no_tensor_aliasing_guard( RootGuardManager* _root); // std::string get_compile_id(RootGuardManager* root); @@ -2488,6 +2704,9 @@ struct WeakEntry { PyObject* wr; // weakref PyObject* cap; // capsule whose m_self is used by the callback }; +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) /** * Base class representing a pair of accessor and the associated guard * manager. The accessor defines how to access the child value from the @@ -2638,6 +2857,7 @@ class GuardManager { py::handle example_value) : _root(root), _source(std::move(source)), +<<<<<<< HEAD _is_dict(py::isinstance(example_value)), _is_immutable(is_immutable_object(example_value)) { if (_is_dict) { @@ -2650,10 +2870,17 @@ class GuardManager { _max_saved_pointers_for_recursive_dict_tags_check = config_module.attr("max_saved_pointers_for_recursive_dict_tags_check") .cast(); +======= + _is_dict(py::isinstance(example_value)) { + if (_is_dict) { + _dict_tag = get_dict_version_unchecked(example_value.ptr()); + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } GuardManager(const GuardManager& m) = delete; GuardManager& operator=(const GuardManager&) = delete; +<<<<<<< HEAD virtual ~GuardManager() { cleanup_tag_safe_entries(); @@ -2671,6 +2898,9 @@ class GuardManager { } _tag_safe_entries.clear(); } +======= + virtual ~GuardManager() = default; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) RootGuardManager* get_root() { return _root; @@ -2685,6 +2915,7 @@ class GuardManager { } public: +<<<<<<< HEAD // relational guard helpers void set_has_object_aliasing_guard() { _has_object_aliasing_guard = true; @@ -2778,6 +3009,11 @@ class GuardManager { _is_dict(is_dict), _is_immutable(is_immutable), _weak_type(weak_type) {} +======= + // For cloning + GuardManager(RootGuardManager* root, std::string source, bool is_dict) + : _root(root), _source(std::move(source)), _is_dict(is_dict) {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void clone_common( RootGuardManager* cloned_root, @@ -2808,6 +3044,7 @@ class GuardManager { if (!py::cast(clone_filter_fn(this))) { return nullptr; } +<<<<<<< HEAD GuardManager* cloned_mgr = new GuardManager( cloned_root, _source, _is_dict, _is_immutable, _weak_type); if (is_tag_safe()) { @@ -2816,6 +3053,9 @@ class GuardManager { cloned_mgr->mark_tag_safe_root(); } } +======= + GuardManager* cloned_mgr = new GuardManager(cloned_root, _source, _is_dict); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) clone_common(cloned_root, cloned_mgr, clone_filter_fn); return cloned_mgr; } @@ -2870,6 +3110,7 @@ class GuardManager { return this->check_accessors_nopybind(value); } +<<<<<<< HEAD bool check_dict_pointer_tags(PyObject* value) { if (_dict_callback_installed) { // This means that for 3.12+, there are callbacks watching dict pointers. @@ -3152,6 +3393,10 @@ class GuardManager { } } #endif +======= + virtual bool check_nopybind(PyObject* value) { + return check_nopybind_template(value); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } virtual bool check_nopybind(FrameLocalsMapping* value) { @@ -3375,6 +3620,7 @@ class GuardManager { // to enable fail fast for the next check. std::vector> _accessors; +<<<<<<< HEAD // relational guard helpers bool _has_object_aliasing_guard = false; bool _has_no_tensor_aliasing_guard = false; @@ -3400,6 +3646,10 @@ class GuardManager { // weakref to the type of guarded value // protected because it is used for cloning by DictGuardManager py::object _weak_type; +======= + bool _is_dict; + uint64_t _dict_tag{0}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; GuardAccessor::GuardAccessor( @@ -3446,6 +3696,7 @@ class RootGuardManager : public GuardManager { // This is the root node, set its _root member to nullptr RootGuardManager() : GuardManager(this, "L") {} +<<<<<<< HEAD void add_no_tensor_aliasing_guard( std::shared_ptr no_tensor_aliasing_guard) { // stash a pointer to the _no_tensor_alising_guard @@ -3457,6 +3708,8 @@ class RootGuardManager : public GuardManager { return _no_tensor_aliasing_guard; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Adds the relational guard resetter void add_relational_guard_resetter( std::shared_ptr relational_guard) { @@ -3483,9 +3736,12 @@ class RootGuardManager : public GuardManager { std::lock_guard lock_guard(_lock); Py_BLOCK_THREADS; // ; is added to avoid clang-formatting +<<<<<<< HEAD // Clean up dict pointer recording for tag safe roots reset_dict_tag_recording_variables(); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Get the local state. This will be used for TENSOR_MATCH guards. if (_init_local_state) { LocalState state; @@ -3628,6 +3884,7 @@ class RootGuardManager : public GuardManager { return ret; } +<<<<<<< HEAD void attach_compile_id(std::string compile_id) { _compile_id = compile_id; } @@ -3636,6 +3893,8 @@ class RootGuardManager : public GuardManager { // return _compile_id; // } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) private: // Reset the state of all the relational guards on failure. void _reset_relational_guard_state() { @@ -3645,6 +3904,7 @@ class RootGuardManager : public GuardManager { } public: +<<<<<<< HEAD // tag safe optimizations void start_recording_dict_pointers(GuardManager* tag_safe_root) { _current_tag_safe_root = tag_safe_root; @@ -3683,6 +3943,8 @@ class RootGuardManager : public GuardManager { } public: +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Local state for TENSOR_MATCH guards. LocalState _local_state; @@ -3726,6 +3988,7 @@ class RootGuardManager : public GuardManager { // We init LocalState only when this flag it set. This flag is set during // TENSOR_MATCH guard init. bool _init_local_state = false; +<<<<<<< HEAD // debug info std::string _compile_id; @@ -3738,6 +4001,8 @@ class RootGuardManager : public GuardManager { GuardManager* _current_tag_safe_root{nullptr}; std::vector> _recorded_dict_pointers; std::vector _recorded_tensor_pointers; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; /* @@ -3756,7 +4021,11 @@ class DictGuardManager : public GuardManager { RootGuardManager* root, std::string source, py::handle example_value) +<<<<<<< HEAD : GuardManager(root, std::move(source), example_value), +======= + : GuardManager(root, std::move(source)), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _size(PyDict_Size(example_value.ptr())), _expected_type(Py_TYPE(example_value.ptr())), _is_exact_dict_type(PyDict_CheckExact(example_value.ptr())) {} @@ -3971,6 +4240,7 @@ class DictGuardManager : public GuardManager { Py_ssize_t size, PyTypeObject* expected_type, bool is_exact_dict_type, +<<<<<<< HEAD std::vector indices, py::object weak_type) : GuardManager( @@ -3979,6 +4249,10 @@ class DictGuardManager : public GuardManager { true, // _is_dict false, // _is_immutable weak_type), +======= + std::vector indices) + : GuardManager(cloned_root, std::move(source), true), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _size(size), _expected_type(expected_type), _is_exact_dict_type(is_exact_dict_type), @@ -3997,6 +4271,7 @@ class DictGuardManager : public GuardManager { _size, _expected_type, _is_exact_dict_type, +<<<<<<< HEAD _indices, _weak_type); if (is_tag_safe()) { @@ -4005,6 +4280,10 @@ class DictGuardManager : public GuardManager { cloned_mgr->mark_tag_safe_root(); } } +======= + _indices); + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) clone_common(cloned_root, cloned_mgr, clone_filter_fn); for (auto index : _indices) { KeyValueManager& key_value_manager = _key_value_managers[index]; @@ -4083,6 +4362,7 @@ void add_relational_guard_resetter_to_cloned_root( root->add_relational_guard_resetter(std::move(guard)); } +<<<<<<< HEAD #if IS_PYTHON_3_12_PLUS static int dict_recursive_tag_watch_callback( PyDict_WatchEvent event, @@ -4104,6 +4384,8 @@ static int dict_recursive_tag_watch_callback( } #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::unique_ptr make_guard_manager( RootGuardManager* root, std::string source, @@ -4147,6 +4429,7 @@ std::unique_ptr make_guard_manager( throw py::type_error("Invalid guard manager enum"); } } +<<<<<<< HEAD return std::make_unique(root, std::move(source), example_value); } @@ -4191,6 +4474,17 @@ class TORCH_FUNCTION_MODE_STACK : public LeafGuard { const py::list& initial_stack, py::object verbose_code_parts) : LeafGuard(root_guard_manager, std::move(verbose_code_parts)) { +======= + return std::make_unique(root, std::move(source)); +} + +class TORCH_FUNCTION_MODE_STACK : public LeafGuard { + public: + TORCH_FUNCTION_MODE_STACK( + const py::list& initial_stack, + py::object verbose_code_parts) + : LeafGuard(std::move(verbose_code_parts)) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Py_ssize_t len = PyList_Size(initial_stack.ptr()); for (Py_ssize_t idx = 0; idx < len; idx++) { PyObject* mode = PyList_GetItem(initial_stack.ptr(), idx); // borrowed ref @@ -4403,10 +4697,13 @@ class GetAttrGuardAccessor : public GuardAccessor { ")"; } +<<<<<<< HEAD std::string get_attr_name() { return py::str(_attr_name).cast(); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) public: // cloning functions GetAttrGuardAccessor(GuardManager* guard_manager, GetAttrGuardAccessor* from) : GuardAccessor(guard_manager, from) { @@ -4530,6 +4827,7 @@ class GetGenericDictGuardAccessor : public GuardAccessor { // check_verbose_nopybind. bool check_nopybind(PyObject* obj, bool matches_dict_tag = false) override { // borrowed ref +<<<<<<< HEAD // NOTE for future guard optimization developers - We tried saving the dict // pointer and weakref of the original object to avoid calling // PyObject_GenericGetDict on a fast path, but this did not lead any @@ -4537,6 +4835,8 @@ class GetGenericDictGuardAccessor : public GuardAccessor { // 1) Once __dict__ is generated, accessing it the second time is fast. // 2) Getting the object from weakref, from 3.13 onwards, requires // Py_DECREF, which further eats into the benefit. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) PyObject* x = PyObject_GenericGetDict(obj, nullptr); // new ref if (x == nullptr) { // Attribute absent, clear the exception and return false. @@ -4806,7 +5106,10 @@ class DictGetItemGuardAccessor : public GuardAccessor { // check_verbose_nopybind. bool check_nopybind(PyObject* obj, bool matches_dict_tag = false) override { if (matches_dict_tag && _is_immutable_object && +<<<<<<< HEAD !is_recording_dict_pointers(get_guard_manager()->get_root()) && +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _guard_manager->has_no_accessors()) { // immutable object and dict tag matches, we can skip the guard subtree. // NB: We only skip the subtree if there are no accessors in the subtree. @@ -4940,6 +5243,7 @@ class ListGetItemGuardAccessor : public GuardAccessor { }; /** +<<<<<<< HEAD * Represents set[index] accessor by converting the set into a list. */ class SetGetItemGuardAccessor : public GuardAccessor { @@ -5016,6 +5320,8 @@ class SetGetItemGuardAccessor : public GuardAccessor { }; /** +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) * Represents tuple[index] accessor. It is faster than generic * GetItemGuardAccessor. */ @@ -5631,6 +5937,7 @@ class TypeGuardAccessor : public GuardAccessor { }; /** +<<<<<<< HEAD * Represent x.__dict__ accessor, where x is type object. */ class TypeDictGuardAccessor : public GuardAccessor { @@ -5743,6 +6050,8 @@ class TypeMROGuardAccessor : public GuardAccessor { }; /** +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) * Getitem tuple_iterator accessor. */ class TupleIteratorGetItemAccessor : public GuardAccessor { @@ -6018,6 +6327,7 @@ class WeakRefCallGuardAccessor : public GuardAccessor { }; /** +<<<<<<< HEAD * Represent x.__code__ */ class CodeGuardAccessor : public GuardAccessor { @@ -6170,6 +6480,8 @@ class ClosureGuardAccessor : public GuardAccessor { }; /** +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) * Implements function call no args - e.g, torch.cuda.current_device() */ class CallFunctionNoArgsGuardAccessor : public GuardAccessor { @@ -6331,16 +6643,24 @@ void install_object_aliasing_guard( py::object verbose_code_parts) { // Adds tensor X is tensor Y guard. This is a an example of relational guard. // There is one guard object that is shared between two guard managers. +<<<<<<< HEAD std::shared_ptr guard = std::make_shared( x->get_root(), std::move(verbose_code_parts)); +======= + std::shared_ptr guard = + std::make_shared(std::move(verbose_code_parts)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Register the resetter on the root guard manager, so that it can reset // the newly added relational guard when the guard eval fails. x->get_root()->add_relational_guard_resetter(guard); +<<<<<<< HEAD x->set_has_object_aliasing_guard(); y->set_has_object_aliasing_guard(); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // In case the guard is a DictGuardManager, OBJECT_ALIASING guard is a // permitted guard. x->add_permitted_leaf_guard(guard); @@ -6355,19 +6675,29 @@ void install_no_tensor_aliasing_guard( // relational guard. There is one guard object that is shared between multiple // guard managers. std::shared_ptr guard = std::make_shared( +<<<<<<< HEAD py::cast(guard_managers[0])->get_root(), tensor_names, std::move(verbose_code_parts)); +======= + tensor_names, std::move(verbose_code_parts)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Register the resetter on the root guard manager, so that it can reset // the newly added relational guard when the guard eval fails. py::cast(guard_managers[0]) ->get_root() +<<<<<<< HEAD ->add_no_tensor_aliasing_guard(guard); for (const auto& guard_manager : guard_managers) { py::cast(guard_manager)->add_leaf_guard(guard); py::cast(guard_manager)->set_has_no_tensor_aliasing_guard(); +======= + ->add_relational_guard_resetter(guard); + for (const auto& guard_manager : guard_managers) { + py::cast(guard_manager)->add_leaf_guard(guard); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } @@ -6383,7 +6713,10 @@ void install_symbolic_shape_guard( // multiple guard managers. std::shared_ptr guard = std::make_shared( +<<<<<<< HEAD py::cast(guard_managers[0])->get_root(), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::move(nargs_int), std::move(nargs_float), std::move(py_addr), @@ -6413,10 +6746,14 @@ void install_storage_overlapping_guard_with_checker( std::shared_ptr guard = std::make_shared( +<<<<<<< HEAD py::cast(guard_managers[0])->get_root(), overlapping, checker, verbose_code_parts); +======= + overlapping, checker, verbose_code_parts); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py::cast(guard_managers[0]) ->get_root() ->add_relational_guard_resetter(guard); @@ -6449,7 +6786,10 @@ void install_storage_overlapping_guard( /* overlapping= */ false); } +<<<<<<< HEAD C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated-volatile") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) char flush_cache_by_eviction() { constexpr size_t evict_size = 32 * 1024 * 1024; std::vector buffer(evict_size, 1); @@ -6460,7 +6800,10 @@ char flush_cache_by_eviction() { } return sink; } +<<<<<<< HEAD C10_DIAGNOSTIC_POP() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) double profile_guard_manager( RootGuardManager* root, @@ -6516,6 +6859,7 @@ bool run_root_guard_manager(void* root, FrameLocalsMapping* f_locals) { if (root == nullptr) { return false; } +<<<<<<< HEAD #ifdef GUARD_INSTRUCTION_COUNT auto n = count_instructions( @@ -6523,6 +6867,8 @@ bool run_root_guard_manager(void* root, FrameLocalsMapping* f_locals) { std::cout << "#instructions in guard eval = " << n << std::endl << std::flush; #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ((RootGuardManager*)root)->check_nopybind(f_locals); } @@ -6603,6 +6949,7 @@ PyObject* torch_c_dynamo_guards_init() { .def("verbose_code_parts", &LeafGuard::verbose_code_parts); py::class_>( py_m, "LAMBDA_GUARD") +<<<<<<< HEAD .def(py::init()) .def("__call__", &LAMBDA_GUARD::check); py::class_>( @@ -6642,23 +6989,73 @@ PyObject* torch_c_dynamo_guards_init() { .def("__call__", &DEFAULT_DEVICE::check); py::class_>(py_m, "NOT_NONE") .def(py::init()) +======= + .def(py::init()) + .def("__call__", &LAMBDA_GUARD::check); + py::class_>( + py_m, "TYPE_MATCH") + .def(py::init()) + .def("__call__", &TYPE_MATCH::check); + py::class_>(py_m, "ID_MATCH") + .def(py::init()) + .def("__call__", &ID_MATCH::check); + py::class_>( + py_m, "NONE_MATCH") + .def(py::init()) + .def("__call__", &NONE_MATCH::check); + py::class_>( + py_m, "TRUE_MATCH") + .def(py::init()) + .def("__call__", &TRUE_MATCH::check); + py::class_>( + py_m, "FALSE_MATCH") + .def(py::init()) + .def("__call__", &FALSE_MATCH::check); + py::class_>( + py_m, "EQUALS_MATCH") + .def(py::init()) + .def("__call__", &EQUALS_MATCH::check); + py::class_>( + py_m, "LENGTH_CHECK") + .def(py::init()) + .def("__call__", &LENGTH_CHECK::check); + py::class_>( + py_m, "DICT_LENGTH") + .def(py::init()) + .def("__call__", &DICT_LENGTH::check); + py::class_>( + py_m, "DEFAULT_DEVICE") + .def(py::init()) + .def("__call__", &DEFAULT_DEVICE::check); + py::class_>(py_m, "NOT_NONE") + .def(py::init()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .def("__call__", &NOT_NONE::check); py::class_< MAPPING_KEYS_MATCH, LeafGuard, std::shared_ptr>(py_m, "MAPPING_KEYS_MATCH") +<<<<<<< HEAD .def(py::init()) +======= + .def(py::init()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .def("__call__", &MAPPING_KEYS_MATCH::check); py::class_< TUPLE_ITERATOR_LEN, LeafGuard, std::shared_ptr>(py_m, "TUPLE_ITERATOR_LEN") +<<<<<<< HEAD .def(py::init()) +======= + .def(py::init()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .def("__call__", &TUPLE_ITERATOR_LEN::check); py::class_< RANGE_ITERATOR_MATCH, LeafGuard, std::shared_ptr>(py_m, "RANGE_ITERATOR_MATCH") +<<<<<<< HEAD .def(py::init< RootGuardManager*, py::object, @@ -6670,6 +7067,13 @@ PyObject* torch_c_dynamo_guards_init() { py::class_>( py_m, "GLOBAL_STATE") .def(py::init()) +======= + .def(py::init()) + .def("__call__", &RANGE_ITERATOR_MATCH::check); + py::class_>( + py_m, "GLOBAL_STATE") + .def(py::init()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .def("check_verbose", &GLOBAL_STATE::check_verbose) .def("__call__", &GLOBAL_STATE::check); py::class_< @@ -6677,6 +7081,7 @@ PyObject* torch_c_dynamo_guards_init() { LeafGuard, std::shared_ptr>( py_m, "TORCH_FUNCTION_MODE_STACK") +<<<<<<< HEAD .def(py::init()) .def("__call__", &TORCH_FUNCTION_MODE_STACK::check); py::class_>( @@ -6698,6 +7103,25 @@ PyObject* torch_c_dynamo_guards_init() { py::class_>( py_m, "DICT_VERSION") .def(py::init()) +======= + .def(py::init()) + .def("__call__", &TORCH_FUNCTION_MODE_STACK::check); + py::class_>( + py_m, "NO_HASATTR") + .def(py::init()) + .def("__call__", &NO_HASATTR::check); + py::class_>( + py_m, "DICT_CONTAINS") + .def(py::init()) + .def("__call__", &DICT_CONTAINS::check); + py::class_>( + py_m, "DYNAMIC_INDICES") + .def(py::init()) + .def("__call__", &DYNAMIC_INDICES::check); + py::class_>( + py_m, "DICT_VERSION") + .def(py::init()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .def("__call__", &DICT_VERSION::check); py::class_< DISPATCH_KEY_SET_MATCH, @@ -6747,11 +7171,19 @@ PyObject* torch_c_dynamo_guards_init() { py::class_>( py_m, "GuardAccessor") .def("repr", &GuardAccessor::repr); +<<<<<<< HEAD py::class_< GetAttrGuardAccessor, GuardAccessor, std::unique_ptr>(py_m, "GetAttrGuardAccessor") .def("get_attr_name", &GetAttrGuardAccessor::get_attr_name); +======= + // NOLINTNEXTLINE(bugprone-unused-raii) + py::class_< + GetAttrGuardAccessor, + GuardAccessor, + std::unique_ptr>(py_m, "GetAttrGuardAccessor"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // NOLINTNEXTLINE(bugprone-unused-raii) py::class_< GenericGetAttrGuardAccessor, @@ -6817,6 +7249,7 @@ PyObject* torch_c_dynamo_guards_init() { std::unique_ptr>(py_m, "TypeGuardAccessor"); // NOLINTNEXTLINE(bugprone-unused-raii) py::class_< +<<<<<<< HEAD TypeDictGuardAccessor, GuardAccessor, std::unique_ptr>(py_m, "TypeDictGuardAccessor"); @@ -6827,6 +7260,8 @@ PyObject* torch_c_dynamo_guards_init() { std::unique_ptr>(py_m, "TypeMROGuardAccessor"); // NOLINTNEXTLINE(bugprone-unused-raii) py::class_< +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) WeakRefCallGuardAccessor, GuardAccessor, std::unique_ptr>( @@ -6845,6 +7280,7 @@ PyObject* torch_c_dynamo_guards_init() { py_m, "TupleIteratorGetItemAccessor"); // NOLINTNEXTLINE(bugprone-unused-raii) py::class_< +<<<<<<< HEAD CodeGuardAccessor, GuardAccessor, std::unique_ptr>(py_m, "CodeGuardAccessor"); @@ -6855,6 +7291,8 @@ PyObject* torch_c_dynamo_guards_init() { std::unique_ptr>(py_m, "ClosureGuardAccessor"); // NOLINTNEXTLINE(bugprone-unused-raii) py::class_< +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GlobalWeakRefGuardAccessor, GuardAccessor, std::unique_ptr>( @@ -6867,6 +7305,7 @@ PyObject* torch_c_dynamo_guards_init() { .def("get_source", &GuardManager::get_source) .def("fail_count", &GuardManager::fail_count) .def( +<<<<<<< HEAD "has_object_aliasing_guard", &GuardManager::has_object_aliasing_guard) .def( "is_guarded_value_immutable", @@ -6882,6 +7321,8 @@ PyObject* torch_c_dynamo_guards_init() { .def( "get_type_of_guarded_value", &GuardManager::get_type_of_guarded_value) .def( +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "get_accessors", &GuardManager::get_accessors, py::return_value_policy::reference) @@ -6903,9 +7344,13 @@ PyObject* torch_c_dynamo_guards_init() { py::object lambda, py::object verbose_code_parts) -> void { self.add_leaf_guard(std::make_shared( +<<<<<<< HEAD self.get_root(), std::move(lambda), std::move(verbose_code_parts))); +======= + std::move(lambda), std::move(verbose_code_parts))); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }) .def( "add_type_match_guard", @@ -6914,9 +7359,13 @@ PyObject* torch_c_dynamo_guards_init() { py::object verbose_code_parts) -> void { SKIP_IF_GUARD_ALREADY_PRESENT("TYPE_MATCH"); self.add_leaf_guard(std::make_shared( +<<<<<<< HEAD self.get_root(), std::move(value), std::move(verbose_code_parts))); +======= + std::move(value), std::move(verbose_code_parts))); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }) .def( "add_id_match_guard", @@ -6925,30 +7374,49 @@ PyObject* torch_c_dynamo_guards_init() { py::object verbose_code_parts) -> void { SKIP_IF_GUARD_ALREADY_PRESENT("ID_MATCH"); self.add_leaf_guard(std::make_shared( +<<<<<<< HEAD self.get_root(), std::move(value), std::move(verbose_code_parts))); +======= + std::move(value), std::move(verbose_code_parts))); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }) .def( "add_none_match_guard", [](GuardManager& self, py::object verbose_code_parts) -> void { SKIP_IF_GUARD_ALREADY_PRESENT("NONE_MATCH"); +<<<<<<< HEAD self.add_leaf_guard(std::make_shared( self.get_root(), std::move(verbose_code_parts))); +======= + self.add_leaf_guard( + std::make_shared(std::move(verbose_code_parts))); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }) .def( "add_true_match_guard", [](GuardManager& self, py::object verbose_code_parts) -> void { SKIP_IF_GUARD_ALREADY_PRESENT("TRUE_MATCH"); +<<<<<<< HEAD self.add_leaf_guard(std::make_shared( self.get_root(), std::move(verbose_code_parts))); +======= + self.add_leaf_guard( + std::make_shared(std::move(verbose_code_parts))); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }) .def( "add_false_match_guard", [](GuardManager& self, py::object verbose_code_parts) -> void { SKIP_IF_GUARD_ALREADY_PRESENT("FALSE_MATCH"); +<<<<<<< HEAD self.add_leaf_guard(std::make_shared( self.get_root(), std::move(verbose_code_parts))); +======= + self.add_leaf_guard( + std::make_shared(std::move(verbose_code_parts))); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }) .def( "add_equals_match_guard", @@ -6957,9 +7425,13 @@ PyObject* torch_c_dynamo_guards_init() { py::object verbose_code_parts) -> void { SKIP_IF_GUARD_ALREADY_PRESENT("EQUALS_MATCH"); self.add_leaf_guard(std::make_shared( +<<<<<<< HEAD self.get_root(), std::move(value), std::move(verbose_code_parts))); +======= + std::move(value), std::move(verbose_code_parts))); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }) .def( "add_length_check_guard", @@ -6968,9 +7440,13 @@ PyObject* torch_c_dynamo_guards_init() { py::object verbose_code_parts) -> void { SKIP_IF_GUARD_ALREADY_PRESENT("LENGTH_CHECK"); self.add_leaf_guard(std::make_shared( +<<<<<<< HEAD self.get_root(), std::move(value), std::move(verbose_code_parts))); +======= + std::move(value), std::move(verbose_code_parts))); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }) .def( "add_dict_length_check_guard", @@ -6979,9 +7455,13 @@ PyObject* torch_c_dynamo_guards_init() { py::object verbose_code_parts) -> void { SKIP_IF_GUARD_ALREADY_PRESENT("DICT_LENGTH"); self.add_leaf_guard(std::make_shared( +<<<<<<< HEAD self.get_root(), std::move(value), std::move(verbose_code_parts))); +======= + std::move(value), std::move(verbose_code_parts))); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }) .def( "add_tuple_iterator_length_guard", @@ -6991,7 +7471,10 @@ PyObject* torch_c_dynamo_guards_init() { py::object verbose_code_parts) -> void { SKIP_IF_GUARD_ALREADY_PRESENT("TUPLE_ITERATOR_LEN"); self.add_leaf_guard(std::make_shared( +<<<<<<< HEAD self.get_root(), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::move(length), std::move(type_id), std::move(verbose_code_parts))); @@ -7006,7 +7489,10 @@ PyObject* torch_c_dynamo_guards_init() { py::object verbose_code_parts) -> void { SKIP_IF_GUARD_ALREADY_PRESENT("RANGE_ITERATOR_MATCH"); self.add_leaf_guard(std::make_shared( +<<<<<<< HEAD self.get_root(), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::move(start), std::move(stop), std::move(step), @@ -7017,14 +7503,23 @@ PyObject* torch_c_dynamo_guards_init() { "add_default_device_guard", [](GuardManager& self, py::object verbose_code_parts) -> void { self.add_leaf_guard(std::make_shared( +<<<<<<< HEAD self.get_root(), std::move(verbose_code_parts))); +======= + std::move(verbose_code_parts))); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }) .def( "add_not_none_guard", [](GuardManager& self, py::object verbose_code_parts) -> void { SKIP_IF_GUARD_ALREADY_PRESENT("NOT_NONE"); +<<<<<<< HEAD self.add_leaf_guard(std::make_shared( self.get_root(), std::move(verbose_code_parts))); +======= + self.add_leaf_guard( + std::make_shared(std::move(verbose_code_parts))); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }) .def( "add_mapping_keys_guard", @@ -7033,9 +7528,13 @@ PyObject* torch_c_dynamo_guards_init() { py::object verbose_code_parts) -> void { SKIP_IF_GUARD_ALREADY_PRESENT("MAPPING_KEYS_MATCH"); self.add_leaf_guard(std::make_shared( +<<<<<<< HEAD self.get_root(), std::move(value), std::move(verbose_code_parts))); +======= + std::move(value), std::move(verbose_code_parts))); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }) .def( "add_dispatch_key_set_guard", @@ -7050,6 +7549,7 @@ PyObject* torch_c_dynamo_guards_init() { }) .def( "add_global_state_guard", +<<<<<<< HEAD [](GuardManager& self, py::object initial_state, py::object verbose_code_parts) -> void { @@ -7057,6 +7557,11 @@ PyObject* torch_c_dynamo_guards_init() { self.get_root(), std::move(initial_state), std::move(verbose_code_parts))); +======= + [](GuardManager& self, py::object verbose_code_parts) -> void { + self.add_leaf_guard( + std::make_shared(std::move(verbose_code_parts))); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }) .def( "add_torch_function_mode_stack_guard", @@ -7064,7 +7569,11 @@ PyObject* torch_c_dynamo_guards_init() { const py::list& initial_stack, py::object verbose_code_parts) -> void { self.add_leaf_guard(std::make_shared( +<<<<<<< HEAD self.get_root(), initial_stack, std::move(verbose_code_parts))); +======= + initial_stack, std::move(verbose_code_parts))); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }) .def( "add_no_hasattr_guard", @@ -7072,9 +7581,13 @@ PyObject* torch_c_dynamo_guards_init() { py::object attr_name, py::object verbose_code_parts) -> void { self.add_leaf_guard(std::make_shared( +<<<<<<< HEAD self.get_root(), std::move(attr_name), std::move(verbose_code_parts))); +======= + std::move(attr_name), std::move(verbose_code_parts))); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }) .def( "add_dict_contains_guard", @@ -7083,6 +7596,7 @@ PyObject* torch_c_dynamo_guards_init() { py::object key, py::object verbose_code_parts) -> void { self.add_leaf_guard(std::make_shared( +<<<<<<< HEAD self.get_root(), contains, std::move(key), @@ -7099,6 +7613,9 @@ PyObject* torch_c_dynamo_guards_init() { contains, std::move(item), std::move(verbose_code_parts))); +======= + contains, std::move(key), std::move(verbose_code_parts))); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }) .def( "add_dynamic_indices_guard", @@ -7106,9 +7623,13 @@ PyObject* torch_c_dynamo_guards_init() { py::set value, py::object verbose_code_parts) -> void { self.add_leaf_guard(std::make_shared( +<<<<<<< HEAD self.get_root(), std::move(value), std::move(verbose_code_parts))); +======= + std::move(value), std::move(verbose_code_parts))); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }) .def( "add_dict_version_guard", @@ -7117,9 +7638,13 @@ PyObject* torch_c_dynamo_guards_init() { py::object verbose_code_parts) -> void { SKIP_IF_GUARD_ALREADY_PRESENT("DICT_VERSION"); self.add_leaf_guard(std::make_shared( +<<<<<<< HEAD self.get_root(), std::move(value), std::move(verbose_code_parts))); +======= + std::move(value), std::move(verbose_code_parts))); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }) .def( "add_tensor_match_guard", @@ -7310,6 +7835,7 @@ PyObject* torch_c_dynamo_guards_init() { // return by reference because GuardManager has the ownership of accessors // and guard managers .def( +<<<<<<< HEAD "type_dict_manager", [](GuardManager& self, std::string source, @@ -7350,6 +7876,8 @@ PyObject* torch_c_dynamo_guards_init() { // return by reference because GuardManager has the ownership of accessors // and guard managers .def( +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "weakref_call_manager", [](GuardManager& self, std::string source, @@ -7397,6 +7925,7 @@ PyObject* torch_c_dynamo_guards_init() { py::arg("example_value"), py::arg("guard_manager_enum"), py::return_value_policy::reference) +<<<<<<< HEAD .def( "set_getitem_manager", &GuardManager::get_child_manager, @@ -7445,6 +7974,8 @@ PyObject* torch_c_dynamo_guards_init() { py::arg("example_value"), py::arg("guard_manager_enum"), py::return_value_policy::reference) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // return by reference because GuardManager has the ownership of accessors // and guard managers .def( @@ -7532,7 +8063,10 @@ PyObject* torch_c_dynamo_guards_init() { .def(py::init<>()) .def("check", &RootGuardManager::check) .def("check_verbose", &RootGuardManager::check_verbose) +<<<<<<< HEAD .def("attach_compile_id", &RootGuardManager::attach_compile_id) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .def( "clone_manager", &RootGuardManager::clone_manager, @@ -7549,7 +8083,11 @@ PyObject* torch_c_dynamo_guards_init() { py::object lambda, py::object verbose_code_parts) -> void { self.add_epilogue_lambda_guard(std::make_unique( +<<<<<<< HEAD &self, std::move(lambda), std::move(verbose_code_parts))); +======= + std::move(lambda), std::move(verbose_code_parts))); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }); // Dict Guard Manager @@ -7612,10 +8150,14 @@ PyObject* torch_c_dynamo_guards_init() { py::object key, py::object verbose_code_parts) -> void { self.add_permitted_leaf_guard(std::make_shared( +<<<<<<< HEAD self.get_root(), contains, std::move(key), std::move(verbose_code_parts))); +======= + contains, std::move(key), std::move(verbose_code_parts))); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }) .def( "add_dict_version_guard", @@ -7624,9 +8166,13 @@ PyObject* torch_c_dynamo_guards_init() { py::object verbose_code_parts) -> void { SKIP_IF_GUARD_ALREADY_PRESENT("DICT_VERSION"); self.add_permitted_leaf_guard(std::make_shared( +<<<<<<< HEAD self.get_root(), std::move(value), std::move(verbose_code_parts))); +======= + std::move(value), std::move(verbose_code_parts))); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }) .def( "add_no_hasattr_guard", @@ -7634,9 +8180,13 @@ PyObject* torch_c_dynamo_guards_init() { py::object attr_name, py::object verbose_code_parts) -> void { self.add_permitted_leaf_guard(std::make_shared( +<<<<<<< HEAD self.get_root(), std::move(attr_name), std::move(verbose_code_parts))); +======= + std::move(attr_name), std::move(verbose_code_parts))); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }) // Not permitted accessors .def("lambda_manager", &DictGuardManager::fail_on_get_child_manager) @@ -7705,6 +8255,7 @@ PyObject* torch_c_dynamo_guards_init() { throw std::runtime_error("Failed to install dict_version_watch_callback"); } +<<<<<<< HEAD dict_recursive_tag_watcher_id = PyDict_AddWatcher(dict_recursive_tag_watch_callback); if (dict_recursive_tag_watcher_id == -1) { @@ -7712,6 +8263,8 @@ PyObject* torch_c_dynamo_guards_init() { "Failed to install dict_recursive_tag_watch_callback"); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif return m; diff --git a/torch/csrc/dynamo/init.cpp b/torch/csrc/dynamo/init.cpp index 2b642ce0bfe80..027ec4eb1209e 100644 --- a/torch/csrc/dynamo/init.cpp +++ b/torch/csrc/dynamo/init.cpp @@ -238,9 +238,12 @@ void initDynamoBindings(PyObject* torch) { "update_diff_guard_root_manager", &CacheEntry::update_diff_guard_root_manager); +<<<<<<< HEAD py::class_(m, "_PrecompileEntry") .def_readonly("guard_manager", &PrecompileEntry::guard_manager); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py::class_(m, "_ExtraState") .def("invalidate", &ExtraState::invalidate); @@ -262,7 +265,10 @@ void initDynamoBindings(PyObject* torch) { m.def("_debug_get_cache_entry_list", &_debug_get_cache_entry_list); m.def("_reset_precompile_entries", &_reset_precompile_entries); m.def("_load_precompile_entry", &_load_precompile_entry); +<<<<<<< HEAD m.def("_debug_get_precompile_entries", &_debug_get_precompile_entries); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py::bind_vector>(m, "VectorUInt8"); m.attr("py_opcode_caches") = _PyOpcode_Caches_vec; m.def("code_framelocals_names", &code_framelocals_names); diff --git a/torch/csrc/export/pt2_archive_constants.h b/torch/csrc/export/pt2_archive_constants.h index 1583f759acb65..98e0fad8c21c8 100644 --- a/torch/csrc/export/pt2_archive_constants.h +++ b/torch/csrc/export/pt2_archive_constants.h @@ -33,14 +33,20 @@ namespace torch::_export::archive_spec { DO(WEIGHTS_DIR, "data/weights/") \ DO(WEIGHT_FILENAME_PREFIX, "weight_") \ DO(WEIGHTS_PARAM_CONFIG_FORMAT, "data/weights/{}_model_param_config.json") \ +<<<<<<< HEAD DO(WEIGHTS_CONFIG_FILENAME_FORMAT, "data/weights/{}_weights_config.json") \ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) /* constants, including tensor_constants, non-persistent buffers and script \ * objects */ \ DO(CONSTANTS_DIR, "data/constants/") \ DO(CONSTANTS_PARAM_CONFIG_FORMAT, \ "data/constants/{}_model_constants_config.json") \ +<<<<<<< HEAD DO(CONSTANTS_CONFIG_FILENAME_FORMAT, \ "data/constants/{}_constants_config.json") \ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DO(TENSOR_CONSTANT_FILENAME_PREFIX, "tensor_") \ DO(CUSTOM_OBJ_FILENAME_PREFIX, "custom_obj_") \ /* example inputs */ \ diff --git a/torch/csrc/export/pybind.cpp b/torch/csrc/export/pybind.cpp index eedd8666ea168..8d544f9760be3 100644 --- a/torch/csrc/export/pybind.cpp +++ b/torch/csrc/export/pybind.cpp @@ -1,7 +1,12 @@ +<<<<<<< HEAD #include #include #include #include +======= +#include +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include @@ -17,6 +22,7 @@ void initExportBindings(PyObject* module) { exportModule.def( "deserialize_exported_program", [](const std::string& serialized) { +<<<<<<< HEAD auto parsed = nlohmann::json::parse(serialized); // Query the current Python schema version as target @@ -29,12 +35,16 @@ void initExportBindings(PyObject* module) { auto upgraded = upgrade(parsed, target_version); return upgraded.get(); +======= + return nlohmann::json::parse(serialized).get(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }); exportModule.def("serialize_exported_program", [](const ExportedProgram& ep) { return nlohmann::json(ep).dump(); }); +<<<<<<< HEAD exportModule.def( "upgrade", [](const std::string& serialized_json, int target_version) { auto parsed = nlohmann::json::parse(serialized_json); @@ -48,6 +58,8 @@ void initExportBindings(PyObject* module) { exportModule.def( "deregister_example_upgraders", []() { deregisterExampleUpgraders(); }); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (const auto& entry : torch::_export::archive_spec::kAllConstants) { pt2ArchiveModule.attr(entry.first) = entry.second; } diff --git a/torch/csrc/inductor/aoti_package/model_package_loader.cpp b/torch/csrc/inductor/aoti_package/model_package_loader.cpp index aa8ef905d57aa..0e919fe050639 100644 --- a/torch/csrc/inductor/aoti_package/model_package_loader.cpp +++ b/torch/csrc/inductor/aoti_package/model_package_loader.cpp @@ -9,10 +9,15 @@ #include #include #include +<<<<<<< HEAD #include #include #include #include +======= +#include +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #ifndef _WIN32 #include @@ -37,6 +42,7 @@ namespace fs = std::filesystem; #endif namespace { +<<<<<<< HEAD const std::string k_separator = "/"; @@ -60,6 +66,8 @@ std::string normalize_path_separator(const std::string& orig_path) { return normalized_path; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool file_exists(const std::string& path) { #ifdef _WIN32 return fs::exists(path); @@ -71,6 +79,7 @@ bool file_exists(const std::string& path) { std::string create_temp_dir() { #ifdef _WIN32 +<<<<<<< HEAD try { fs::path temp_dir = fs::temp_directory_path(); return temp_dir.string(); @@ -81,6 +90,9 @@ std::string create_temp_dir() { throw std::runtime_error( "Unknown error occurred while getting temporary directory"); } +======= + throw std::runtime_error("Not implemented"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #else std::string temp_dir = "/tmp/XXXXXX"; if (mkdtemp(temp_dir.data()) == nullptr) { @@ -92,6 +104,7 @@ std::string create_temp_dir() { #endif } +<<<<<<< HEAD const char* object_file_ext() { #ifdef _WIN32 return ".obj"; @@ -115,6 +128,13 @@ bool _is_windows_os() { return false; #endif } +======= +#ifdef _WIN32 +const std::string k_separator = "\\"; +#else +const std::string k_separator = "/"; +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace namespace torch::inductor { @@ -134,12 +154,19 @@ const nlohmann::json& load_json_file(const std::string& json_path) { } std::tuple get_cpp_compile_command( +<<<<<<< HEAD const std::string& arg_filename, +======= + const std::string& filename, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const std::vector& sources, const nlohmann::json& compile_options, const std::string& output_dir = "") { // Construct the cpp command +<<<<<<< HEAD auto filename = normalize_path_separator(arg_filename); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::string compiler = compile_options["compiler"].get(); bool compile_only = compile_options["compile_only"].get(); @@ -149,8 +176,12 @@ std::tuple get_cpp_compile_command( source_args += source + " "; } +<<<<<<< HEAD std::string file_ext = compile_only ? object_file_ext() : extension_file_ext(); +======= + std::string file_ext = compile_only ? ".o" : ".so"; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::string target_file = output_dir + filename + file_ext; std::string target_dir = output_dir; if (target_dir.empty()) { @@ -160,38 +191,59 @@ std::tuple get_cpp_compile_command( std::string cflags_args; for (auto& arg : compile_options["cflags"]) { +<<<<<<< HEAD cflags_args += _is_windows_os() ? "/" : "-" + arg.get() + " "; +======= + cflags_args += "-" + arg.get() + " "; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } std::string definitions_args; for (auto& arg : compile_options["definitions"]) { +<<<<<<< HEAD definitions_args += _is_windows_os() ? "/D" : "-D " + arg.get() + " "; +======= + definitions_args += "-D " + arg.get() + " "; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } std::string include_dirs_args; for (auto& arg : compile_options["include_dirs"]) { +<<<<<<< HEAD include_dirs_args += _is_windows_os() ? "/I" : "-I" + arg.get() + " "; +======= + include_dirs_args += "-I" + arg.get() + " "; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } std::string ldflags_args; for (auto& arg : compile_options["ldflags"]) { +<<<<<<< HEAD ldflags_args += _is_windows_os() ? "/" : "-" + arg.get() + " "; +======= + ldflags_args += "-" + arg.get() + " "; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } std::string libraries_dirs_args; for (auto& arg : compile_options["libraries_dirs"]) { +<<<<<<< HEAD if (_is_windows_os()) { libraries_dirs_args += fmt::format("/LIBPATH:\"{}\"", arg.get()) + " "; } else { libraries_dirs_args += "-L" + arg.get() + " "; } +======= + libraries_dirs_args += "-L" + arg.get() + " "; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } std::string libraries_args; for (auto& arg : compile_options["libraries"]) { +<<<<<<< HEAD if (_is_windows_os()) { libraries_args += fmt::format("{}.lib", arg.get()) + " "; } else { @@ -242,6 +294,39 @@ std::tuple get_cpp_compile_command( compile_only_arg, target_file)); } +======= + libraries_args += "-l" + arg.get() + " "; + } + + std::string passthrough_parameters_args; + for (auto& arg : compile_options["passthrough_args"]) { + std::string arg_str = arg.get(); + std::string target = "script.ld"; + std::string replacement = target_dir; + replacement.append(k_separator).append(target); + size_t pos = arg_str.find(target); + if (pos != std::string::npos) { + arg_str.replace(pos, target.length(), replacement); + } + passthrough_parameters_args += arg_str + " "; + } + + std::string compile_only_arg = compile_only ? "-c" : ""; + + std::string cmd = fmt::format( + "{} {} {} {} {} {} {} {} {} {} -o {}", + compiler, + source_args, + definitions_args, + cflags_args, + include_dirs_args, + passthrough_parameters_args, + ldflags_args, + libraries_args, + libraries_dirs_args, + compile_only_arg, + target_file); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return std::make_tuple(cmd, target_file); } @@ -404,6 +489,7 @@ std::string compile_so( return output_so; } +<<<<<<< HEAD std::unordered_set find_model_names( const std::vector& paths) { @@ -426,6 +512,8 @@ std::unordered_set find_model_names( return model_names; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace void AOTIModelPackageLoader::load_metadata(const std::string& cpp_filename) { @@ -441,6 +529,7 @@ void AOTIModelPackageLoader::load_metadata(const std::string& cpp_filename) { } } +<<<<<<< HEAD class RAIIMinizArchive { public: RAIIMinizArchive(const std::string& zip_path) { @@ -505,6 +594,8 @@ class RAIIMinizArchive { mz_zip_archive _zip_archive{}; }; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTIModelPackageLoader::AOTIModelPackageLoader( const std::string& model_package_path, const std::string& model_name, @@ -524,8 +615,37 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( } // Extract all files within the zipfile to a temporary directory +<<<<<<< HEAD RAIIMinizArchive zip_archive{model_package_path}; auto found_filenames{zip_archive.get_filenames()}; +======= + mz_zip_archive zip_archive; + memset(&zip_archive, 0, sizeof(zip_archive)); + + if (!mz_zip_reader_init_file(&zip_archive, model_package_path.c_str(), 0)) { + throw std::runtime_error( + std::string("Failed to initialize zip archive: ") + + mz_zip_get_error_string(mz_zip_get_last_error(&zip_archive))); + } + + std::vector found_filenames; + for (uint32_t i = 0; i < zip_archive.m_total_files; i++) { + uint32_t filename_len = + mz_zip_reader_get_filename(&zip_archive, i, nullptr, 0); + if (filename_len == 0) { + throw std::runtime_error("Failed to read filename"); + } + // filename_len returned by mz_zip_reader_get_filename includes the null + // terminator, so we need to subtract 1 here + std::string filename_str(filename_len - 1, '\0'); + if (!mz_zip_reader_get_filename( + &zip_archive, i, filename_str.data(), filename_len)) { + throw std::runtime_error("Failed to read filename"); + } + found_filenames.push_back(filename_str); + } + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (found_filenames.empty()) { throw std::runtime_error("No files found in zip archive."); } @@ -547,11 +667,16 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( << found_filenames[1]; } +<<<<<<< HEAD temp_dir_ = normalize_path_separator(create_temp_dir()); +======= + temp_dir_ = create_temp_dir(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::string so_filename; std::string cpp_filename; std::vector obj_filenames; +<<<<<<< HEAD std::string model_directory = normalize_path_separator( file_prefix + "data" + k_separator + "aotinductor" + k_separator + model_name); @@ -577,6 +702,29 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( std::string filename = cur_filename; if (lastSlash != std::string::npos) { filename = cur_filename.substr(lastSlash + 1); +======= + std::string model_directory = file_prefix + "data" + k_separator + + "aotinductor" + k_separator + model_name; + std::string const_directory = + file_prefix + "data" + k_separator + "constants"; + + for (const std::string& filename_str : found_filenames) { + // Only compile files in the specified model directory + if (c10::starts_with(filename_str, model_directory) || + c10::starts_with(filename_str, const_directory)) { + std::string output_path_str = temp_dir_; + + if (c10::starts_with(filename_str, model_directory)) { + output_path_str += k_separator; + output_path_str += filename_str; + } else { // startsWith(filename_str, const_directory) + // Extract constants to the same directory as the rest of the files + // to be consistent with internal implementation + size_t lastSlash = filename_str.find_last_of(k_separator); + std::string filename = filename_str; + if (lastSlash != std::string::npos) { + filename = filename_str.substr(lastSlash + 1); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } output_path_str.append(k_separator) .append(model_directory) @@ -584,6 +732,7 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( .append(filename); } +<<<<<<< HEAD std::string output_file_path = normalize_path_separator(output_path_str); LOG(INFO) << "Extract file: " << zip_filename_str << " to " << output_file_path; @@ -595,6 +744,18 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( "Failed to find parent path in " + output_file_path); } std::string parent_path = output_file_path.substr(0, parent_path_idx); +======= + LOG(INFO) << "Extract file: " << filename_str << " to " + << output_path_str; + + // Create the parent directory if it doesn't exist + size_t parent_path_idx = output_path_str.find_last_of(k_separator); + if (parent_path_idx == std::string::npos) { + throw std::runtime_error( + "Failed to find parent path in " + output_path_str); + } + std::string parent_path = output_path_str.substr(0, parent_path_idx); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (!recursive_mkdir(parent_path)) { throw std::runtime_error(fmt::format( "Failed to create directory {}: {}", @@ -603,6 +764,7 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( } // Extracts file to the temp directory +<<<<<<< HEAD zip_archive.extract_file(zip_filename_str, output_path_str); // Save the file for bookkeeping @@ -615,16 +777,43 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( obj_filenames.push_back(output_file_path); } else if (filename_extension == extension_file_ext()) { so_filename = output_file_path; +======= + mz_zip_reader_extract_file_to_file( + &zip_archive, filename_str.c_str(), output_path_str.c_str(), 0); + + // Save the file for bookkeeping + size_t extension_idx = output_path_str.find_last_of('.'); + if (extension_idx != std::string::npos) { + std::string filename_extension = output_path_str.substr(extension_idx); + if (filename_extension == ".cpp") { + cpp_filename = output_path_str; + } else if (filename_extension == ".o") { + obj_filenames.push_back(output_path_str); + } else if (filename_extension == ".so") { + so_filename = output_path_str; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } } } +<<<<<<< HEAD +======= + // Close the zip archive as we have extracted all files to the temp + // directory + if (!mz_zip_reader_end(&zip_archive)) { + throw std::runtime_error( + std::string("Failed to close zip archive: {}") + + mz_zip_get_error_string(mz_zip_get_last_error(&zip_archive))); + } + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (cpp_filename.empty() && so_filename.empty()) { std::string found_filenames_str; for (const std::string& filename : found_filenames) { found_filenames_str += filename + "\n"; } +<<<<<<< HEAD std::string model_names_str; for (const std::string& model_name_tmp : find_model_names(found_filenames)) { @@ -641,6 +830,11 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( "To load a specific model, please provide its name using the `model_name` parameter when calling AOTIModelPackageLoader() or torch._inductor.package.load_package.\n\n" "The following files were loaded from the archive:\n" + found_filenames_str); +======= + throw std::runtime_error( + "No AOTInductor generate cpp file or so file found in zip archive with the prefix " + + model_directory + "Loaded the following:\n" + found_filenames_str); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // Compile the .so diff --git a/torch/csrc/inductor/aoti_runtime/device_utils.h b/torch/csrc/inductor/aoti_runtime/device_utils.h index 8c75560f8d29b..8402660e40d54 100644 --- a/torch/csrc/inductor/aoti_runtime/device_utils.h +++ b/torch/csrc/inductor/aoti_runtime/device_utils.h @@ -14,7 +14,11 @@ #include #include +<<<<<<< HEAD #define AOTI_RUNTIME_CUDA_CHECK(EXPR) \ +======= +#define AOTI_RUNTIME_DEVICE_CHECK(EXPR) \ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) do { \ const cudaError_t code = EXPR; \ const char* msg = cudaGetErrorString(code); \ @@ -34,7 +38,11 @@ using DeviceStreamType = cudaStream_t; #include #include #include +<<<<<<< HEAD #define AOTI_RUNTIME_XPU_CHECK(EXPR) \ +======= +#define AOTI_RUNTIME_DEVICE_CHECK(EXPR) \ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) do { \ const ze_result_t status = EXPR; \ if (status != ZE_RESULT_SUCCESS) { \ @@ -52,7 +60,11 @@ using DeviceStreamType = sycl::queue*; #else +<<<<<<< HEAD #define AOTI_RUNTIME_CPU_CHECK(EXPR) \ +======= +#define AOTI_RUNTIME_DEVICE_CHECK(EXPR) \ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool ok = EXPR; \ if (!ok) { \ throw std::runtime_error("CPU runtime error"); \ diff --git a/torch/csrc/inductor/aoti_runtime/interface.h b/torch/csrc/inductor/aoti_runtime/interface.h index fab9a87a725e8..c2f5b95d73cf3 100644 --- a/torch/csrc/inductor/aoti_runtime/interface.h +++ b/torch/csrc/inductor/aoti_runtime/interface.h @@ -6,6 +6,7 @@ // applies to other files under torch/csrc/inductor/aoti_runtime/. #include +<<<<<<< HEAD #ifdef _WIN32 /* On Windows, we need to explicit declaration for export APIs. And because the @@ -17,6 +18,8 @@ the import case. #define AOTI_API __attribute__((__visibility__("default"))) #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) extern "C" { struct AOTInductorModelOpaque; using AOTInductorModelHandle = AOTInductorModelOpaque*; @@ -32,7 +35,11 @@ using AOTInductorConstantMapHandle = AOTInductorConstantMap*; // TODO: Deprecate this API. This was kept for BC compatibility. // Please use AOTInductorModelContainerCreateWithDevice instead. +<<<<<<< HEAD AOTI_API AOTIRuntimeError AOTInductorModelContainerCreate( +======= +AOTIRuntimeError AOTInductorModelContainerCreate( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTInductorModelContainerHandle* container_handle, size_t num_models, bool is_cpu, @@ -45,18 +52,30 @@ AOTI_API AOTIRuntimeError AOTInductorModelContainerCreate( // "cpu", "cuda", "cuda:0", etc. If the device index is not specified for CUDA // device, runtime will use the device index returned by // "cudaGetDevice(&device_idx)" +<<<<<<< HEAD AOTI_API AOTIRuntimeError AOTInductorModelContainerCreateWithDevice( +======= +AOTIRuntimeError AOTInductorModelContainerCreateWithDevice( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTInductorModelContainerHandle* container_handle, size_t num_models, const char* device_str, const char* cubin_dir); // Deletes the AOTInductor model container. +<<<<<<< HEAD AOTI_API AOTIRuntimeError AOTInductorModelContainerDelete( AOTInductorModelContainerHandle container_handle); // Runs the inference. AOTI_API AOTIRuntimeError AOTInductorModelContainerRun( +======= +AOTIRuntimeError AOTInductorModelContainerDelete( + AOTInductorModelContainerHandle container_handle); + +// Runs the inference. +AOTIRuntimeError AOTInductorModelContainerRun( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTInductorModelContainerHandle container_handle, AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles // are stolen; the array itself is borrowed @@ -70,7 +89,11 @@ AOTI_API AOTIRuntimeError AOTInductorModelContainerRun( AOTIProxyExecutorHandle proxy_executor_handle); // Single-threaded variant of previous. +<<<<<<< HEAD AOTI_API AOTIRuntimeError AOTInductorModelContainerRunSingleThreaded( +======= +AOTIRuntimeError AOTInductorModelContainerRunSingleThreaded( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTInductorModelContainerHandle container_handle, AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles // are stolen; the array itself is borrowed @@ -84,14 +107,22 @@ AOTI_API AOTIRuntimeError AOTInductorModelContainerRunSingleThreaded( AOTIProxyExecutorHandle proxy_executor_handle); // Retrieves the number of constants for the model. +<<<<<<< HEAD AOTI_API AOTIRuntimeError AOTInductorModelContainerGetNumConstants( +======= +AOTIRuntimeError AOTInductorModelContainerGetNumConstants( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTInductorModelContainerHandle container_handle, size_t* num_constants); // Retrieves a constant's name. // idx is the index of the internal's constants. // Need idx < num_constants from AOTInductorModelContainerGetNumConstants +<<<<<<< HEAD AOTI_API AOTIRuntimeError AOTInductorModelContainerGetConstantName( +======= +AOTIRuntimeError AOTInductorModelContainerGetConstantName( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTInductorModelContainerHandle container_handle, size_t idx, const char** name); @@ -99,7 +130,11 @@ AOTI_API AOTIRuntimeError AOTInductorModelContainerGetConstantName( // Retrieves a constant's original FQN. // idx is the index of the internal's constants. // Need idx < num_constants from AOTInductorModelContainerGetNumConstants +<<<<<<< HEAD AOTI_API AOTIRuntimeError AOTInductorModelContainerGetConstantOriginalFQN( +======= +AOTIRuntimeError AOTInductorModelContainerGetConstantOriginalFQN( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTInductorModelContainerHandle container_handle, size_t idx, const char** original_fqn); @@ -107,7 +142,11 @@ AOTI_API AOTIRuntimeError AOTInductorModelContainerGetConstantOriginalFQN( // Retrieves whether a constant is from folded. // idx is the index of the internal's constants. // Need idx < num_constants from AOTInductorModelContainerGetNumConstants +<<<<<<< HEAD AOTI_API AOTIRuntimeError AOTInductorModelContainerGetConstantFromFolded( +======= +AOTIRuntimeError AOTInductorModelContainerGetConstantFromFolded( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTInductorModelContainerHandle container_handle, size_t idx, bool* from_folded); @@ -115,7 +154,11 @@ AOTI_API AOTIRuntimeError AOTInductorModelContainerGetConstantFromFolded( // Retrieves the inductor constant type. // idx is the index of the internal's constants. // Need idx < num_constants from AOTInductorModelContainerGetNumConstants +<<<<<<< HEAD AOTI_API AOTIRuntimeError AOTInductorModelContainerGetConstantType( +======= +AOTIRuntimeError AOTInductorModelContainerGetConstantType( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTInductorModelContainerHandle container_handle, size_t idx, int32_t* type); @@ -123,7 +166,11 @@ AOTI_API AOTIRuntimeError AOTInductorModelContainerGetConstantType( // Retrieves a constant's dtype. // idx is the index of the internal's constants. // Need idx < num_constants from AOTInductorModelContainerGetNumConstants +<<<<<<< HEAD AOTI_API AOTIRuntimeError AOTInductorModelContainerGetConstantDtype( +======= +AOTIRuntimeError AOTInductorModelContainerGetConstantDtype( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTInductorModelContainerHandle container_handle, size_t idx, int32_t* dtype); @@ -131,21 +178,33 @@ AOTI_API AOTIRuntimeError AOTInductorModelContainerGetConstantDtype( // Retrieves a constant's data size. // idx is the index of the internal's constants. // Need idx < num_constants from AOTInductorModelContainerGetNumConstants +<<<<<<< HEAD AOTI_API AOTIRuntimeError AOTInductorModelContainerGetConstantDataSize( +======= +AOTIRuntimeError AOTInductorModelContainerGetConstantDataSize( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTInductorModelContainerHandle container_handle, size_t idx, size_t* data_size); // Extract the constants that is being used in the container. +<<<<<<< HEAD AOTI_API AOTIRuntimeError AOTInductorModelContainerExtractConstantsMap( +======= +AOTIRuntimeError AOTInductorModelContainerExtractConstantsMap( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTInductorModelContainerHandle container_handle, AOTInductorConstantMapHandle constant_map_handle, bool use_inactive); // Setup the constant buffer in model container with provided ConstantMap. // The ConstantMap is user managed, and the user would retain ownership. +<<<<<<< HEAD AOTI_API AOTIRuntimeError AOTInductorModelContainerUpdateUserManagedConstantBuffer( +======= +AOTIRuntimeError AOTInductorModelContainerUpdateUserManagedConstantBuffer( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTInductorModelContainerHandle container_handle, AOTInductorConstantMapHandle constant_map_handle, bool use_inactive, @@ -154,7 +213,11 @@ AOTInductorModelContainerUpdateUserManagedConstantBuffer( // Setup the constant buffer in model container with provided ConstantMap // use_inactive should be set as true if the inactive buffer is to be updated. // validate_full_update checks if all constants are included in the ConstantMap +<<<<<<< HEAD AOTI_API AOTIRuntimeError AOTInductorModelContainerUpdateConstantBuffer( +======= +AOTIRuntimeError AOTInductorModelContainerUpdateConstantBuffer( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTInductorModelContainerHandle container_handle, AOTInductorConstantMapHandle constant_map_handle, bool use_inactive, @@ -162,43 +225,75 @@ AOTI_API AOTIRuntimeError AOTInductorModelContainerUpdateConstantBuffer( // Setup the inactive constant buffer in model container with provided // ConstantMap +<<<<<<< HEAD AOTI_API AOTIRuntimeError AOTInductorModelContainerUpdateInactiveConstantBuffer( +======= +AOTIRuntimeError AOTInductorModelContainerUpdateInactiveConstantBuffer( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTInductorModelContainerHandle container_handle, AOTInductorConstantMapHandle constant_map_handle); // Free the inactive constant buffer in model container. +<<<<<<< HEAD AOTI_API AOTIRuntimeError AOTInductorModelContainerFreeInactiveConstantBuffer( AOTInductorModelContainerHandle container_handle); // Run constant folding on constant buffer. AOTI_API AOTIRuntimeError AOTInductorModelContainerRunConstantFolding( +======= +AOTIRuntimeError AOTInductorModelContainerFreeInactiveConstantBuffer( + AOTInductorModelContainerHandle container_handle); + +// Run constant folding on constant buffer. +AOTIRuntimeError AOTInductorModelContainerRunConstantFolding( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTInductorModelContainerHandle container_handle, bool use_inactive, AOTInductorStreamHandle stream_handle, AOTIProxyExecutorHandle proxy_executor_handle); // Swap the constant buffer being used to the inactive one. +<<<<<<< HEAD AOTI_API AOTIRuntimeError AOTInductorModelContainerSwapConstantBuffer( AOTInductorModelContainerHandle container_handle); // Retrieves the number of inputs for the model. AOTI_API AOTIRuntimeError AOTInductorModelContainerGetNumInputs( +======= +AOTIRuntimeError AOTInductorModelContainerSwapConstantBuffer( + AOTInductorModelContainerHandle container_handle); + +// Retrieves the number of inputs for the model. +AOTIRuntimeError AOTInductorModelContainerGetNumInputs( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTInductorModelContainerHandle container_handle, size_t* ret_num_inputs); // Retrieves the input name at the given index. +<<<<<<< HEAD AOTI_API AOTIRuntimeError AOTInductorModelContainerGetInputName( +======= +AOTIRuntimeError AOTInductorModelContainerGetInputName( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTInductorModelContainerHandle container_handle, size_t input_idx, const char** ret_input_names); // Retrieves the number of outputs for the model. +<<<<<<< HEAD AOTI_API AOTIRuntimeError AOTInductorModelContainerGetNumOutputs( +======= +AOTIRuntimeError AOTInductorModelContainerGetNumOutputs( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTInductorModelContainerHandle container_handle, size_t* ret_num_outputs); // Retrieves the output name at the given index. +<<<<<<< HEAD AOTI_API AOTIRuntimeError AOTInductorModelContainerGetOutputName( +======= +AOTIRuntimeError AOTInductorModelContainerGetOutputName( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTInductorModelContainerHandle container_handle, size_t output_idx, const char** ret_output_names); @@ -210,24 +305,37 @@ AOTI_API AOTIRuntimeError AOTInductorModelContainerGetOutputName( // // constant_map_handle is an opaque type to satisfy the C ABI. It should be a // std::unordered_map*. +<<<<<<< HEAD AOTI_API AOTIRuntimeError AOTInductorModelCreate( +======= +AOTIRuntimeError AOTInductorModelCreate( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTInductorModelHandle* model_handle, AOTInductorConstantMapHandle constant_map_handle); // Run an AOTInductorModel (see AOTInductorModelCreate for when one should use // this function versus AOTInductorModelContainerRun). +<<<<<<< HEAD AOTI_API AOTIRuntimeError AOTInductorModelRun( +======= +AOTIRuntimeError AOTInductorModelRun( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTInductorModelHandle model_handle, AtenTensorHandle* input_handles, AtenTensorHandle* output_handles); // Replace AOTInductorModel's constant map. Note it doesn't handle concurrency // so be sure to handle ordering if AOTInductorModelRun is ran concurrently. +<<<<<<< HEAD AOTI_API AOTIRuntimeError AOTInductorModelUpdateConstantsMap( +======= +AOTIRuntimeError AOTInductorModelUpdateConstantsMap( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTInductorModelHandle model_handle, AOTInductorConstantMapHandle constant_map_handle); // Delete an AOTInductorModel created by AOTInductorModelCreate. +<<<<<<< HEAD AOTI_API AOTIRuntimeError AOTInductorModelDelete(AOTInductorModelHandle model_handle); @@ -236,6 +344,15 @@ AOTI_API AOTIRuntimeError AOTInductorModelGetNumOutputs( size_t* ret_num_outputs); AOTI_API AOTIRuntimeError AOTInductorModelContainerGetCallSpec( +======= +AOTIRuntimeError AOTInductorModelDelete(AOTInductorModelHandle model_handle); + +AOTIRuntimeError AOTInductorModelGetNumOutputs( + AOTInductorModelHandle model_handle, + size_t* ret_num_outputs); + +AOTIRuntimeError AOTInductorModelContainerGetCallSpec( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTInductorModelContainerHandle container_handle, const char** in_spec, const char** out_spec); diff --git a/torch/csrc/inductor/aoti_runtime/model.h b/torch/csrc/inductor/aoti_runtime/model.h index 253c5e917e76b..44f66b705db4b 100644 --- a/torch/csrc/inductor/aoti_runtime/model.h +++ b/torch/csrc/inductor/aoti_runtime/model.h @@ -1,13 +1,742 @@ #pragma once +<<<<<<< HEAD +======= +#include +#include +#include +#include +#include +#include +#include +#include +#include + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // WARNING: Be careful when adding new includes here. This header will be used // in model.so, and should not refer to any aten/c10 headers except the stable // C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule // applies to other files under torch/csrc/inductor/aoti_runtime/. +<<<<<<< HEAD #include namespace torch::aot_inductor { +======= +#include +#ifdef USE_MPS +#include +#endif // USE_MPS +#ifdef USE_XPU +#include +#else +#include +#endif // USE_XPU +#include + +#define AOTI_RUNTIME_CHECK(EXPR, MSG) \ + do { \ + bool ok = EXPR; \ + if (!ok) { \ + throw std::runtime_error(MSG); \ + } \ + } while (0) + +// At codegen time, we write out a binary file called constants.bin. +// We then turn the raw binary to an object file that exposes this +// symbol and link it into the final .so. +// For information on the binary format, see `man objcopy`, under +// the "binary-architecture" flag: +// https://man7.org/linux/man-pages/man1/objcopy.1.html +// todo: use #embed in C++ 23 once available +// The constants are NOT readonly because they may be mutated. +// NOLINTNEXTLINE(*array*) +extern uint8_t _binary_constants_bin_start[]; +// NOLINTNEXTLINE(*array*) +extern uint8_t _binary_constants_bin_end[]; + +#if defined(USE_CUDA) || defined(USE_XPU) +// Compute required blob size with 64-alignment if on GPU. +#define AOTI_CONST_ALIGNMENT 64 +#else +// Use 64-alignment (use something >=64)for better performance on CPU. +#define AOTI_CONST_ALIGNMENT 64 +#endif + +namespace { + +using RAIIDataPtr = std::unique_ptr>; + +#ifdef USE_CUDA + +RAIIDataPtr RAII_gpuMalloc(size_t num_bytes) { + void* data_ptr; + AOTI_RUNTIME_DEVICE_CHECK(cudaMalloc((void**)&data_ptr, num_bytes)); + auto deleter = [](void* ptr) { AOTI_RUNTIME_DEVICE_CHECK(cudaFree(ptr)); }; + return RAIIDataPtr(data_ptr, deleter); +} + +#elif defined(USE_XPU) + +RAIIDataPtr RAII_gpuMalloc(size_t num_bytes) { + sycl::queue* queue_ptr = nullptr; + aoti_torch_get_current_sycl_queue((void**)&queue_ptr); + void* data_ptr = sycl::malloc_device(num_bytes, *queue_ptr); + auto deleter = [queue_ptr](void* ptr) { sycl::free(ptr, *queue_ptr); }; + return RAIIDataPtr(data_ptr, deleter); +} + +#elif defined(USE_MPS) + +RAIIDataPtr RAII_gpuMalloc(size_t num_bytes) { + void* data_ptr = nullptr; + aoti_torch_mps_malloc(&data_ptr, num_bytes); + auto deleter = [](void* ptr) { aoti_torch_mps_free(ptr); }; + return RAIIDataPtr(data_ptr, deleter); +} + +#else + +RAIIDataPtr RAII_cpuMalloc(size_t num_bytes) { + void* data_ptr = std::malloc(num_bytes); + if (!data_ptr) { + throw std::bad_alloc(); + } + auto deleter = [](void* ptr) { std::free(ptr); }; + return RAIIDataPtr(data_ptr, deleter); +} + +#endif // USE_CUDA + +} // anonymous namespace + +namespace torch::aot_inductor { + +using ConstantMap = + std::unordered_map; + +// valid device strs are: cpu, cuda, cuda:0, cuda:1, ... +// Update the list here if more devices are supported in the future +inline void parse_device_str( + const std::string& device_str, + int32_t& device_type, + int32_t& device_idx) { + std::regex re("(cpu|cuda|xpu|mps)(:([0-9]+))?"); + std::smatch sm; + bool matched = std::regex_match(device_str, sm, re); + AOTI_RUNTIME_CHECK(matched, "Invalid device: " + device_str); + + if (sm[1].str() == "cpu") { + device_type = aoti_torch_device_type_cpu(); + } else if (sm[1].str() == "cuda") { + device_type = aoti_torch_device_type_cuda(); +#ifdef USE_XPU + } else if (sm[1].str() == "xpu") { + device_type = aoti_torch_device_type_xpu(); +#endif +#ifdef USE_MPS + } else if (sm[1].str() == "mps") { + device_type = aoti_torch_device_type_mps(); +#endif + } else { + AOTI_RUNTIME_CHECK(false, "Invalid device: " + device_str); + } + + if (sm[3].matched) { + device_idx = stoi(sm[3].str()); + } else { + device_idx = -1; + } +} + +// Defines the base class for AOTInductorModel, which is generated by the +// AOTInductor cpp codegen. Since we do not need dynamic dispatch, we rely +// on curiously recurring template pattern (CRTP) to save some runtime +// v-table overhead. The generated AOTInductorModel is specialized with +// methods such as run_impl. +template +class AOTInductorModelBase { + public: + AOTInductorModelBase( + size_t num_inputs, + size_t num_outputs, + size_t num_constants, + const std::string& device_str, + std::optional cubin_dir, + bool include_weights = true) + : inputs_info_(num_inputs), + outputs_info_(num_outputs), + constants_info_(num_constants), + cubin_dir_(std::move(cubin_dir)), + include_weights(include_weights) { + parse_device_str(device_str, device_type_, device_idx_); + +#ifdef USE_CUDA + if (device_idx_ == -1) { + AOTI_RUNTIME_DEVICE_CHECK(cudaGetDevice(&device_idx_)); + } else { + // If device_idx_ is passed in, we need to set the current device to it + AOTI_RUNTIME_DEVICE_CHECK(cudaSetDevice(device_idx_)); + } +#endif // USE_CUDA +#ifdef USE_XPU + if (device_idx_ == -1) { + aoti_torch_get_current_xpu_device(&device_idx_); + } else { + aoti_torch_set_current_xpu_device(device_idx_); + } +#endif // USE_XPU +#ifdef USE_MPS + if (device_idx_ == -1) { + device_idx_ = 0; + } +#endif // USE_MPS + } + + // NOLINTNEXTLINE(modernize-use-equals-default) + ~AOTInductorModelBase() { +#ifdef USE_CUDA + if (run_finished_) { + auto code = cudaEventDestroy(*run_finished_); + if (code != cudaSuccess) { + std::cerr << "Failed to destroy CUDA event in AOTInductor model: " + << cudaGetErrorString(code) << std::endl; + } + } +#endif // USE_CUDA +#ifdef USE_XPU + if (run_finished_) { + (*run_finished_)->wait_and_throw(); + delete *run_finished_; + } +#endif // USE_XPU + } + + AOTInductorModelBase(AOTInductorModelBase&&) = delete; + AOTInductorModelBase& operator=(AOTInductorModelBase&&) = delete; + AOTInductorModelBase(const AOTInductorModelBase&) = delete; + AOTInductorModelBase& operator=(const AOTInductorModelBase&) = delete; + + void run( + AtenTensorHandle* + input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + AtenTensorHandle* + output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor) { +#ifdef USE_CUDA + if (!run_finished_) { + cudaEvent_t run_finished; + AOTI_RUNTIME_DEVICE_CHECK(cudaEventCreate(&run_finished)); + run_finished_.emplace(run_finished); + } +#elif defined(USE_XPU) + if (run_finished_) { + (*run_finished_)->wait_and_throw(); + delete *run_finished_; + run_finished_.reset(); + } +#else // !USE_CUDA && !USE_XPU + run_finished_ = false; +#endif + + auto* model = static_cast(this); + model->run_impl(input_handles, output_handles, stream, proxy_executor); + +#ifdef USE_CUDA + AOTI_RUNTIME_DEVICE_CHECK(cudaEventRecord(*run_finished_, stream)); +#elif defined(USE_XPU) + run_finished_ = std::make_optional(new sycl::event( + static_cast(stream)->ext_oneapi_submit_barrier())); +#else // !USE_CUDA && !USE_XPU + run_finished_ = true; +#endif // USE_CUDA + } + + // Non-thread-aware variant of run(). Obviously unsafe to use in a threaded + // environment :) + void run_single_threaded( + AtenTensorHandle* + input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + AtenTensorHandle* + output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor) { + // don't bother with any of the run_finished stuff; this is unsafe to call + // in a threaded context + auto* model = static_cast(this); + model->run_impl(input_handles, output_handles, stream, proxy_executor); + } + + std::unordered_map run_const_fold( + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor, + bool initialization = false) { +#ifdef USE_CUDA + if (!run_finished_) { + cudaEvent_t run_finished; + AOTI_RUNTIME_DEVICE_CHECK(cudaEventCreate(&run_finished)); + run_finished_.emplace(run_finished); + } +#elif defined(USE_XPU) + if (run_finished_) { + (*run_finished_)->wait_and_throw(); + delete *run_finished_; + run_finished_.reset(); + } +#else // !USE_CUDA && !USE_XPU + run_finished_ = false; +#endif + + auto* model = static_cast(this); + auto folded_constants = + model->const_run_impl(stream, proxy_executor, initialization); + +#ifdef USE_CUDA + AOTI_RUNTIME_DEVICE_CHECK(cudaEventRecord(*run_finished_, stream)); +#elif defined(USE_XPU) + // sycl::queue* queue_ptr = nullptr; + // aoti_torch_get_current_sycl_queue((void**)&queue_ptr); + run_finished_ = std::make_optional(new sycl::event( + static_cast(stream)->ext_oneapi_submit_barrier())); + +#else // !USE_CUDA && !USE_XPU + run_finished_ = true; +#endif // USE_CUDA + + return folded_constants; + } + + void load_constants() { + size_t num_constants = this->num_constants(); + size_t num_folded_constants = this->num_folded_constants(); + constants_map_->reserve(num_constants); + + std::vector constants_internal_offset( + num_constants - num_folded_constants); + size_t blob_size = 0; + compute_constant_blob(blob_size, constants_internal_offset); + if (!include_weights) { + return; + } +#if defined(USE_CUDA) || defined(USE_XPU) || defined(USE_MPS) + constant_blob_ = RAII_gpuMalloc(blob_size); +#else + constant_blob_ = RAII_cpuMalloc(blob_size); +#endif + + size_t bytes_read = 0; + for (size_t i = 0; i < num_constants; i++) { + bool from_folded = this->constant_from_folded(i); + if (from_folded) { + continue; + } + std::string name = this->constant_name(i); + size_t data_size = this->constant_data_size(i); + uint8_t* internal_ptr = (data_size != 0) + ? constant_ptr( + constants_internal_offset[i], + bytes_read, + data_size, + /* skip_copy = */ false) + : nullptr; + bytes_read += data_size; + + // Create at::Tensor from copied memory. + auto dtype = this->constant_dtype(i); + auto ndim = this->constant_ndim(i); + auto size = this->constant_shape(i); + auto stride = this->constant_stride(i); +#ifdef USE_MPS + auto offset = this->constant_offset(i) + + (constants_internal_offset[i] / aoti_torch_dtype_element_size(dtype)); +#else + auto offset = this->constant_offset(i); +#endif + auto layout = this->constant_layout(i); + auto opaque_metadata_ptr = this->opaque_metadata(i); + auto opaque_metadata_size = this->opaque_metadata_size(i); + + AtenTensorHandle tensor_handle = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob_v2( + internal_ptr, + ndim, + size, + stride, + offset, + dtype, + device_type_, + device_idx_, + &tensor_handle, + layout, + opaque_metadata_ptr, + opaque_metadata_size)); + constants_map_->emplace(std::move(name), tensor_handle); + } + if (constants_map_) { + this->update_constants_array_from_map(); + } + } + + RAIIDataPtr&& release_constant_blob() { + return std::move(constant_blob_); + } + + std::shared_ptr> get_constants_array() { + return constants_; + } + + int32_t get_device_type() const { + return device_type_; + } + + int32_t get_device_idx() const { + return device_idx_; + } + + uint8_t* constant_ptr( + size_t constant_offset, + size_t bytes_read, + size_t data_size, + bool skip_copy) { + auto* constants_ptr = static_cast(constant_blob_.get()); + uint8_t* internal_ptr = constants_ptr + constant_offset; + // TODO: Handle shared storage case. + if (!skip_copy) { +#ifdef USE_XPU + sycl::queue* queue_ptr = nullptr; + aoti_torch_get_current_sycl_queue((void**)&queue_ptr); + queue_ptr + ->memcpy(internal_ptr, _get_constants_start() + bytes_read, data_size) + .wait(); +#elif USE_CUDA + AOTI_RUNTIME_DEVICE_CHECK(cudaMemcpy( + internal_ptr, + _get_constants_start() + bytes_read, + data_size, + cudaMemcpyHostToDevice)); +#elif USE_MPS + aoti_torch_mps_memcpy( + constants_ptr, + constant_offset, + bytes_read, + data_size, + _get_constants_start()); + return constants_ptr; +#else + memcpy(internal_ptr, _get_constants_start() + bytes_read, data_size); +#endif + } + return internal_ptr; + } + + void compute_constant_blob( + size_t& blob_size, + std::vector& constants_internal_offset) { + size_t num_constants = this->num_constants(); + blob_size = 0; + size_t curr_idx = 0; + for (size_t i = 0; i < num_constants; i++) { + if (this->constant_from_folded(i)) { + continue; + } + size_t data_size = this->constant_data_size(i); + if (data_size % AOTI_CONST_ALIGNMENT) { + data_size = AOTI_CONST_ALIGNMENT + + (data_size / AOTI_CONST_ALIGNMENT) * AOTI_CONST_ALIGNMENT; + } + constants_internal_offset[curr_idx++] = blob_size; + blob_size += data_size; + } + } + + size_t num_inputs() const { + return inputs_info_.size(); + } + + size_t num_outputs() const { + return outputs_info_.size(); + } + + size_t num_constants() const { + return constants_info_.size(); + } + + size_t num_folded_constants() const { + size_t total_consts = this->num_constants(); + size_t folded_consts = 0; + for (size_t i = 0; i < total_consts; i++) { + if (this->constant_from_folded(i)) { + folded_consts++; + } + } + return folded_consts; + } + + const char* input_name(int64_t idx) const { + return inputs_info_.at(idx).name; + } + + const char* output_name(int64_t idx) const { + return outputs_info_.at(idx).name; + } + + const char* constant_name(int64_t idx) const { + return constants_info_.at(idx).name; + } + + size_t constant_ndim(int64_t idx) { + return constants_info_.at(idx).shape.size(); + } + + const int64_t* constant_shape(int64_t idx) const { + return constants_info_.at(idx).shape.data(); + } + + const int64_t* constant_stride(int64_t idx) const { + return constants_info_.at(idx).stride.data(); + } + + int32_t constant_dtype(int64_t idx) const { + return constants_info_.at(idx).dtype; + } + + int32_t constant_layout(int64_t idx) const { + return constants_info_.at(idx).layout; + } + + size_t constant_offset(int64_t idx) const { + return constants_info_.at(idx).offset; + } + + size_t constant_data_size(int64_t idx) const { + return constants_info_.at(idx).data_size; + } + + const char* constant_original_fqn(int64_t idx) const { + return constants_info_.at(idx).original_fqn; + } + + const uint8_t* opaque_metadata(int64_t idx) const { + return constants_info_.at(idx).opaque_metadata.data(); + } + + size_t opaque_metadata_size(int64_t idx) { + return constants_info_.at(idx).opaque_metadata.size(); + } + + bool constant_from_folded(int64_t idx) const { + return constants_info_.at(idx).from_folded; + } + + int32_t constant_type(int64_t idx) const { + return constants_info_.at(idx).type; + } + + const char* get_in_spec() const { + return in_spec_.c_str(); + } + + const char* get_out_spec() const { + return out_spec_.c_str(); + } + + void update_constants_array_from_map() { + if (!constants_map_) { + throw std::runtime_error{ + "constants_map_ was not ready when constants_ is trying to be constructed from it!"}; + } + if (!constants_) { + constants_ = + std::make_shared>(constants_info_.size()); + } else { + constants_->resize(constants_info_.size()); + } + int idx = 0; + for (const auto& info : constants_info_) { + const auto it = constants_map_->find(info.name); + if (it != constants_map_->end()) { + constants_->at(idx) = ConstantHandle(it->second); + } + idx++; + } + } + + void update_constants_map( + std::shared_ptr constants_map, + bool remap_constants_array = true) { + constants_map_ = std::move(constants_map); + if (remap_constants_array) { + update_constants_array_from_map(); + } + } + + // This function allows us to update the constants_ that is used to look up + // the corresponding constant tensor during runtime. + void update_constants_array( + std::shared_ptr> constants_array) { + constants_ = std::move(constants_array); + } + + /// Returns true if the model is complete. + bool is_finished() { +#ifdef USE_CUDA + if (!run_finished_) { + throw std::runtime_error{"Model CUDA event was not initialized"}; + } + + auto event_status = cudaEventQuery(*run_finished_); + if (event_status == cudaSuccess) { + return true; + } else if (event_status == cudaErrorNotReady) { + return false; + } + + throw std::runtime_error( + std::string("The model did not finish successfully. Error: ") + + cudaGetErrorString(cudaGetLastError())); +#elif defined(USE_XPU) + if (!run_finished_) { + throw std::runtime_error{"Model XPU event was not initialized"}; + } + using namespace sycl::info; + return (*run_finished_)->get_info() == + event_command_status::complete; + +#else // !USE_CUDA && !USE_XPU + return run_finished_; +#endif // USE_CUDA + } + + /// Synchronizes completion event. + void wait_for_completion() { +#ifdef USE_CUDA + if (!run_finished_) { + throw std::runtime_error{"Model event was not initialized"}; + } + + AOTI_RUNTIME_DEVICE_CHECK(cudaEventSynchronize(*run_finished_)); +#endif // USE_CUDA +#ifdef USE_XPU + if (!run_finished_) { + throw std::runtime_error{"Model event was not initialized"}; + } + (*run_finished_)->wait_and_throw(); +#endif + } + + protected: + uint8_t* _get_constants_start() { +#ifndef USE_MMAP_SELF + // NOLINTNEXTLINE(*const-cast*) + return const_cast(_binary_constants_bin_start); +#else + if (self_mmap) { + return self_mmap; + } + Dl_info dl_info; + // get pointer to constant which are appended to the binary + AOTI_RUNTIME_CHECK( + dladdr(__func__, &dl_info), "Can't find shared library name"); + int fd = open(dl_info.dli_fname, O_RDONLY); + AOTI_RUNTIME_CHECK(fd >= 0, "Shared library file cannot be opened"); + auto fsize = lseek(fd, 0, SEEK_END); + auto weights_size = + reinterpret_cast(_binary_constants_bin_start)[0]; + auto magic_number = + reinterpret_cast(_binary_constants_bin_start)[1]; + auto weights_offset = fsize - weights_size; + AOTI_RUNTIME_CHECK( + (weights_offset & 0x3fff) == 0, + "weights_offset must be aligned to 16K boundary"); + auto ptr = mmap( + NULL, + weights_size, + PROT_READ | PROT_WRITE, + MAP_PRIVATE, + fd, + weights_offset); + close(fd); + AOTI_RUNTIME_CHECK(ptr != MAP_FAILED, "mmap() failed"); + self_mmap = static_cast(ptr); + AOTI_RUNTIME_CHECK( + reinterpret_cast( + self_mmap + weights_size - sizeof(uint64_t))[0] == magic_number, + "Weights data seems corrupt"); + return self_mmap; +#endif + } + struct ParamInfo { + const char* name = nullptr; + }; + + struct ConstInfo { + const char* name = nullptr; + std::vector shape; + std::vector stride; + int32_t dtype{}; + int64_t offset{}; + size_t data_size{}; + int32_t layout{}; + std::vector opaque_metadata; + int64_t opaque_metadata_size{}; + const char* original_fqn = nullptr; + bool from_folded{}; + int32_t type{}; + }; + + std::vector inputs_info_; + std::vector outputs_info_; + std::vector constants_info_; + std::string in_spec_; + std::string out_spec_; + + std::shared_ptr constants_map_; + std::shared_ptr> constants_; + + // Holds the blob storage for constants' at::Tensor. + RAIIDataPtr constant_blob_; + +#ifdef USE_MMAP_SELF + uint8_t* self_mmap = NULL; +#endif + + // A directory with CUDA binary files, e.g. compiled kernels, etc. + const std::optional cubin_dir_; + + // This is the flag that implies whether the weight is included in the model. + // If True, we would prepare the weight when loading the model, otherwise the + // model will be loaded without weights, and need to be provided by the user. + bool include_weights; + + // Record if the model finishes an inference run so that its owning + // AOTModelContainer can reuse this instance. +#ifdef USE_CUDA + std::optional run_finished_; +#elif defined(USE_XPU) + std::optional run_finished_; +#else // !USE_CUDA + bool run_finished_{}; +#endif + + // Generated model uses this device index to create CUDA guards. + int32_t device_type_{}; + int32_t device_idx_{}; +}; + +// Codegen-ed classes can derive from this to keep pointers to loaded kernels. +class AOTInductorModelKernelsBase { + public: + virtual ~AOTInductorModelKernelsBase() = default; +}; + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class AOTInductorModel : public AOTInductorModelBase { public: AOTInductorModel( diff --git a/torch/csrc/inductor/aoti_runtime/model_container.h b/torch/csrc/inductor/aoti_runtime/model_container.h index 0bd12e841e39f..ba685416471c3 100644 --- a/torch/csrc/inductor/aoti_runtime/model_container.h +++ b/torch/csrc/inductor/aoti_runtime/model_container.h @@ -467,6 +467,7 @@ class AOTInductorModelContainer { constants_blob_ptr + constants_internal_offset_[idx]; void* user_constant_ptr; int64_t constant_size; +<<<<<<< HEAD int64_t* stride; int64_t offset; aoti_torch_get_data_ptr(tensor, &user_constant_ptr); @@ -476,12 +477,17 @@ class AOTInductorModelContainer { aoti_torch_get_storage_offset(tensor, &offset)); auto dtype = models_[0]->constant_dtype(idx); +======= + aoti_torch_get_data_ptr(tensor, &user_constant_ptr); + aoti_torch_get_storage_size(tensor, &constant_size); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #ifdef USE_XPU sycl::queue* queue_ptr = nullptr; aoti_torch_get_current_sycl_queue((void**)&queue_ptr); queue_ptr ->memcpy(internal_constants_ptr, user_constant_ptr, constant_size) .wait(); +<<<<<<< HEAD #elif USE_MPS internal_constants_ptr = constants_blob_ptr; aoti_torch_mps_copy_buffer( @@ -497,6 +503,10 @@ class AOTInductorModelContainer { aoti_torch_dtype_element_size(dtype); #elif USE_CUDA AOTI_RUNTIME_CUDA_CHECK(cudaMemcpy( +======= +#elif USE_CUDA + AOTI_RUNTIME_DEVICE_CHECK(cudaMemcpy( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) internal_constants_ptr, user_constant_ptr, constant_size, @@ -508,15 +518,29 @@ class AOTInductorModelContainer { // We extract stride and offset from provided Tensor since we do not // guarantee that the tensor is contiguous. AtenTensorHandle tensor_handle; +<<<<<<< HEAD + int device_type = models_[0]->get_device_type(); + int device_idx = models_[0]->get_device_idx(); +======= + int64_t* stride; + int64_t offset; int device_type = models_[0]->get_device_type(); int device_idx = models_[0]->get_device_idx(); + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(tensor, &stride)); + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_get_storage_offset(tensor, &offset)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob( internal_constants_ptr, models_[0]->constant_ndim(idx), models_[0]->constant_shape(idx), stride, offset, +<<<<<<< HEAD dtype, +======= + models_[0]->constant_dtype(idx), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device_type, device_idx, &tensor_handle)); diff --git a/torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h b/torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h index 3a2e91c37c916..276472840807a 100644 --- a/torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h +++ b/torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h @@ -128,10 +128,19 @@ static std::unique_ptr _createKernel( uint32_t numWarps, uint32_t sharedMemory, void** params, +<<<<<<< HEAD sycl::queue* queuePtr, uint32_t threadsPerWarp) { std::string kernelName = kernelPtr->get_info(); +======= + sycl::queue* queuePtr) { + std::string kernelName = + kernelPtr->get_info(); + // Currently threadsPerWarp is hard code to 32 from torch.compile to triton + // stack. + int threadsPerWarp = 32; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uint32_t numParams = kernelPtr->get_info(); size_t globalRangeX = gridX * threadsPerWarp * numWarps; size_t globalRangeY = gridY; diff --git a/torch/csrc/inductor/aoti_runtime/utils.h b/torch/csrc/inductor/aoti_runtime/utils.h index b813b3f6f745c..fdf3392e798c7 100644 --- a/torch/csrc/inductor/aoti_runtime/utils.h +++ b/torch/csrc/inductor/aoti_runtime/utils.h @@ -12,7 +12,10 @@ // C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule // applies to other files under torch/csrc/inductor/aoti_runtime/. #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #if defined(__GNUC__) || defined(__clang__) #define AOTI_NOINLINE __attribute__((noinline)) @@ -22,18 +25,39 @@ #define AOTI_NOINLINE #endif +<<<<<<< HEAD #define AOTI_TORCH_ERROR_CODE_CHECK(call) \ if ((call) != AOTI_TORCH_SUCCESS) { \ torch::headeronly::detail::throw_exception(#call, __FILE__, __LINE__); \ +======= +AOTI_NOINLINE static void throw_exception( + const char* call, + const char* file, + int64_t line) { + std::stringstream ss; + ss << call << " API call failed at " << file << ", line " << line; + throw std::runtime_error(ss.str()); +} + +#define AOTI_TORCH_ERROR_CODE_CHECK(call) \ + if ((call) != AOTI_TORCH_SUCCESS) { \ + throw_exception(#call, __FILE__, __LINE__); \ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } using AOTIRuntimeError = int32_t; #define AOTI_RUNTIME_SUCCESS 0 #define AOTI_RUNTIME_FAILURE 1 +<<<<<<< HEAD #define AOTI_RUNTIME_ERROR_CODE_CHECK(call) \ if ((call) != AOTI_RUNTIME_SUCCESS) { \ torch::headeronly::detail::throw_exception(#call, __FILE__, __LINE__); \ +======= +#define AOTI_RUNTIME_ERROR_CODE_CHECK(call) \ + if ((call) != AOTI_RUNTIME_SUCCESS) { \ + throw_exception(#call, __FILE__, __LINE__); \ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } namespace torch::aot_inductor { @@ -42,16 +66,20 @@ using DeleterFnPtr = void (*)(void*); inline void noop_deleter(void*) {} +<<<<<<< HEAD inline void delete_record_function_object(void* ptr) { AOTI_TORCH_ERROR_CODE_CHECK(aoti_record_function_end( reinterpret_cast(ptr))); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inline void delete_tensor_object(void* ptr) { AOTI_TORCH_ERROR_CODE_CHECK( aoti_torch_delete_tensor_object(reinterpret_cast(ptr))); } +<<<<<<< HEAD class RAIIAtenRecordFunctionHandle { public: RAIIAtenRecordFunctionHandle() : handle_(nullptr, noop_deleter) {} @@ -116,6 +144,8 @@ class RAIIAtenRecordFunctionHandle { std::unique_ptr handle_; }; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // RAIIAtenTensorHandle steals the tensor objects created by the libtorch C ABI class RAIIAtenTensorHandle { public: diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index 3ce4dd82cfdab..0b1e0c44a1370 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -1,8 +1,13 @@ #ifndef AOTI_TORCH_SHIM #define AOTI_TORCH_SHIM +<<<<<<< HEAD #include #include +======= +#include +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // This header defines a stable C API for certain ATen functionality in // libtorch. The AOTInductor compiled model.so will only refer to this header @@ -36,6 +41,32 @@ // maintain the old and new versions of the APIs until all old model.so // go out of use. +<<<<<<< HEAD +======= +#ifdef __GNUC__ +#define AOTI_TORCH_EXPORT __attribute__((__visibility__("default"))) +#else // !__GNUC__ +#ifdef _WIN32 +// PyTorch2 doesn't currently work on Windows. Exporting these APIs can lead +// to symbol clashes at link time if libtorch is included in a DLL and binary +// that depends on the DLL. As a short term fix, we don't export the symbols. +// In the long term, this will need to be addressed when Windows is supported. +#ifdef OVRSOURCE +// Do not export AOTI on Windows for internal builds +#define AOTI_TORCH_EXPORT +#else /* OVRSOURCE */ +#ifdef EXPORT_AOTI_FUNCTIONS +#define AOTI_TORCH_EXPORT __declspec(dllexport) +#else +#define AOTI_TORCH_EXPORT __declspec(dllimport) +#endif +#endif /* OVRSOURCE */ +#else // !_WIN32 +#define AOTI_TORCH_EXPORT +#endif // _WIN32 +#endif // __GNUC__ + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // The following files are implemented in a header-only way and are guarded by // test/cpp/aoti_abi_check #include @@ -46,6 +77,36 @@ extern "C" { #endif +<<<<<<< HEAD +======= +// AtenTensorHandle represents an abstract notion of Tensor that can be passed +// between model.so and libtorch.so. The contents of the structure itself +// are private; model.so is not allowed to access any fields directly, it must +// go through functions defined in this ABI. Under the hood, this is +// represented as at::Tensor*, but we reserve the right to change this (and in +// fact, we probably should change it to at::TensorImpl* at least). +// +// An AtenTensorHandle can be owning (please check the API reference for exact +// ownership/borrow semantics). If you have an owning AtenTensorHandle +// in model.so, you are obligated to aoti_torch_delete_tensor_object when you +// are done. You can use the helper C++ class RAIIAtenTensorHandle +// (see aot_runtime/model.h) to ensure the deallocator is called in RAII style +// (note that RAIIAtenTensorHandle is private to model.so, and never crosses +// the ABI boundary.) +struct AtenTensorOpaque; +using AtenTensorHandle = AtenTensorOpaque*; + +struct AtenGeneratorOpaque; +using AtenGeneratorHandle = AtenGeneratorOpaque*; + +struct AOTIProxyExecutorOpaque; +using AOTIProxyExecutorHandle = AOTIProxyExecutorOpaque*; + +using AOTITorchError = int32_t; +#define AOTI_TORCH_SUCCESS 0 +#define AOTI_TORCH_FAILURE 1 + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Getter functions for retrieving various constants from the runtime, that // can subsequently be passed to other aoti_* functions. By hiding these // behind functions, the precise value of device/dtype is NOT part of the @@ -220,9 +281,12 @@ aoti_torch_get_device_type(AtenTensorHandle tensor, int32_t* ret_device_type); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_device_index(AtenTensorHandle tensor, int32_t* ret_device_index); +<<<<<<< HEAD AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_layout(AtenTensorHandle tensor, int32_t* ret_layout); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_storage_offset( AtenTensorHandle tensor, int64_t* ret_storage_offset); @@ -230,9 +294,12 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_storage_offset( AOTI_TORCH_EXPORT AOTITorchError aoti_torch_is_contiguous(AtenTensorHandle tensor, bool* ret_is_contiguous); +<<<<<<< HEAD AOTI_TORCH_EXPORT AOTITorchError aoti_torch_is_defined(AtenTensorHandle tensor, bool* ret_is_defined); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTI_TORCH_EXPORT AOTITorchError aoti_torch_new_tensor_handle( AtenTensorHandle orig_handle, AtenTensorHandle* new_handle); @@ -273,6 +340,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_empty_strided( AtenTensorHandle* ret_new_tensor // returns new reference ); +<<<<<<< HEAD AOTI_TORCH_EXPORT AOTITorchError aoti_torch_empty_strided_pinned( int64_t ndim, const int64_t* sizes_ptr, @@ -283,6 +351,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_empty_strided_pinned( AtenTensorHandle* ret_new_tensor // returns new reference ); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTI_TORCH_EXPORT AOTITorchError aoti_torch_as_strided( AtenTensorHandle self, const int64_t* sizes_ptr, @@ -315,6 +385,130 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob_v2( const uint8_t* opaque_metadata, int64_t opaque_metadata_size); +<<<<<<< HEAD +======= +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__embedding_bag( + AtenTensorHandle weight, + AtenTensorHandle indices, + AtenTensorHandle offsets, + int32_t scale_grad_by_freq, + int32_t mode, + int32_t sparse, + AtenTensorHandle per_sample_weights, // optional argument + int32_t include_last_offset, + int32_t padding_idx, + AtenTensorHandle* ret0, // returns new reference + AtenTensorHandle* ret1, // returns new reference + AtenTensorHandle* ret2, // returns new reference + AtenTensorHandle* ret3 // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__fft_c2c( + AtenTensorHandle self, + const int64_t* dim_ptr, + int64_t dim_size, + int64_t normalization, + int32_t forward, + AtenTensorHandle* ret // returns new reference +); + +// This version is deprecated. We will remove it later +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_dot_product_flash_attention( + AtenTensorHandle query, + AtenTensorHandle key, + AtenTensorHandle value, + double dropout_p, + bool is_causal, + bool return_debug_mask, + double scale, + AtenTensorHandle* ret0, // returns new reference + AtenTensorHandle* ret1, // returns new reference + AtenTensorHandle* ret2, // returns new reference + AtenTensorHandle* ret3, // returns new reference + int64_t* ret4, + int64_t* ret5, + AtenTensorHandle* ret6, // returns new reference + AtenTensorHandle* ret7, // returns new reference + AtenTensorHandle* ret8 // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch__scaled_dot_product_flash_attention_v2( + AtenTensorHandle query, + AtenTensorHandle key, + AtenTensorHandle value, + double dropout_p, + int is_causal, + int return_debug_mask, + double* scale, // optional argument + AtenTensorHandle* ret0, // returns new reference + AtenTensorHandle* ret1, // returns new reference + AtenTensorHandle* ret2, // returns new reference + AtenTensorHandle* ret3, // returns new reference + int64_t* ret4, + int64_t* ret5, + AtenTensorHandle* ret6, // returns new reference + AtenTensorHandle* ret7, // returns new reference + AtenTensorHandle* ret8 // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch__scaled_dot_product_efficient_attention( + AtenTensorHandle query, + AtenTensorHandle key, + AtenTensorHandle value, + AtenTensorHandle attn_bias, // optional argument + int compute_log_sumexp, + double dropout_p, + int is_causal, + double* scale, // optional argument + AtenTensorHandle* ret0, // returns new reference + AtenTensorHandle* ret1, // returns new reference + AtenTensorHandle* ret2, // returns new reference + AtenTensorHandle* ret3 // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_mm( + AtenTensorHandle self, + AtenTensorHandle mat2, + AtenTensorHandle bias, + int32_t* out_dtype, + AtenTensorHandle scale_a, + AtenTensorHandle scale_b, + AtenTensorHandle scale_result, + int8_t use_fast_accum, + AtenTensorHandle* ret0, + AtenTensorHandle* ret1); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_mm_v2( + AtenTensorHandle self, + AtenTensorHandle mat2, + AtenTensorHandle scale_a, + AtenTensorHandle scale_b, + AtenTensorHandle bias, + AtenTensorHandle scale_result, + int32_t* out_dtype, + int8_t use_fast_accum, + AtenTensorHandle* ret0); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_convolution( + AtenTensorHandle input, + AtenTensorHandle weight, + AtenTensorHandle bias, // optional argument + const int64_t* stride_ptr, + int64_t stride_size, + const int64_t* padding_ptr, + int64_t padding_size, + const int64_t* dilation_ptr, + int64_t dilation_size, + int transposed, + const int64_t* output_padding_ptr, + int64_t output_padding_size, + int64_t groups, + AtenTensorHandle* ret // returns new reference +); + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // This function will create a new uninitialized tensor object // and its pointer is returned through *ret. AOTI_TORCH_EXPORT AOTITorchError @@ -347,11 +541,35 @@ aoti_torch_clone(AtenTensorHandle self, AtenTensorHandle* ret); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_clone_preserve_strides(AtenTensorHandle self, AtenTensorHandle* ret); +<<<<<<< HEAD +======= +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_addmm_out( + AtenTensorHandle out, + AtenTensorHandle self, + AtenTensorHandle mat1, + AtenTensorHandle mat2, + float beta, + float alpha); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_bmm_out( + AtenTensorHandle out, + AtenTensorHandle self, + AtenTensorHandle mat2); + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTI_TORCH_EXPORT AOTITorchError aoti_torch_copy_( AtenTensorHandle self, AtenTensorHandle src, int32_t non_blocking); +<<<<<<< HEAD +======= +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mm_out( + AtenTensorHandle out, + AtenTensorHandle self, + AtenTensorHandle mat2); + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTI_TORCH_EXPORT AOTITorchError aoti_torch__mm_plus_mm_out( AtenTensorHandle out, AtenTensorHandle a, @@ -381,7 +599,11 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_wrapped_fbgemm_linear_fp16_weight( AtenTensorHandle input, AtenTensorHandle weight, +<<<<<<< HEAD AtenTensorHandle bias, // optional argument +======= + AtenTensorHandle bias, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int64_t out_channel, AtenTensorHandle* out); @@ -398,6 +620,7 @@ aoti_torch_cpu__wrapped_quantized_linear_prepacked( int64_t out_channel, AtenTensorHandle* out); +<<<<<<< HEAD AOTI_TORCH_EXPORT AOTITorchError aoti_torch_zero_(AtenTensorHandle self); AOTI_TORCH_EXPORT AOTITorchError @@ -419,6 +642,21 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_record_function_start( AOTI_TORCH_EXPORT AOTITorchError aoti_record_function_end(AtenRecordFunctionHandle guard); +======= +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_nonzero(AtenTensorHandle self, AtenTensorHandle* out); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_zero_(AtenTensorHandle self); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_repeat_interleave_Tensor( + AtenTensorHandle repeats, + int64_t* output_size, + AtenTensorHandle* out); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_check_inf_and_nan(const char* tensor_name, AtenTensorHandle tensor); + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scatter_out( AtenTensorHandle out, AtenTensorHandle self, @@ -443,6 +681,20 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_index_put_out( const AtenTensorHandle values, bool accumulate); +<<<<<<< HEAD +======= +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_view_as_real( + AtenTensorHandle self, + AtenTensorHandle* ret // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_view_dtype( + AtenTensorHandle self, + int32_t dtype, + AtenTensorHandle* ret // returns new reference +); + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTI_TORCH_EXPORT void aoti_torch_print_tensor_handle( AtenTensorHandle self, const char* msg); @@ -515,6 +767,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_call_dispatcher( const char* overloadName, StableIValue* stack); +<<<<<<< HEAD // Device-generic guard for managing device context struct DeviceGuardOpaque; using DeviceGuardHandle = DeviceGuardOpaque*; @@ -548,6 +801,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_current_stream( AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_current_device_index(int32_t* ret_device_index); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #ifdef USE_CUDA struct CUDAGuardOpaque; diff --git a/torch/csrc/inductor/aoti_torch/c/shim_cpu.h b/torch/csrc/inductor/aoti_torch/c/shim_cpu.h index 5a10290decd1d..45fa8cc04368b 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim_cpu.h +++ b/torch/csrc/inductor/aoti_torch/c/shim_cpu.h @@ -245,6 +245,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__weight_int4pack_mm_cpu_tensor( AtenTensorHandle qScaleAndZeros, AtenTensorHandle* ret0); +<<<<<<< HEAD AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__c10d_functional_all_reduce_( AtenTensorHandle inp, const char* reduce_op, @@ -261,6 +262,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__c10d_functional_wait_tensor( AtenTensorHandle inp, AtenTensorHandle* ret0); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #ifdef __cplusplus } // extern "C" #endif diff --git a/torch/csrc/inductor/aoti_torch/c/shim_mps.h b/torch/csrc/inductor/aoti_torch/c/shim_mps.h index 08f1569927f00..791f5a91fa1f6 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim_mps.h +++ b/torch/csrc/inductor/aoti_torch/c/shim_mps.h @@ -32,6 +32,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_memcpy( size_t data_size, uint8_t* constants_start); +<<<<<<< HEAD AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_copy_buffer( void* src_buffer, void* dst_buffer, @@ -39,6 +40,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_copy_buffer( size_t src_offset, size_t dst_offset); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #ifdef __cplusplus } // extern "C" #endif diff --git a/torch/csrc/inductor/aoti_torch/c/shim_xpu.h b/torch/csrc/inductor/aoti_torch/c/shim_xpu.h index c25fe6443c948..a8b997336a322 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim_xpu.h +++ b/torch/csrc/inductor/aoti_torch/c/shim_xpu.h @@ -107,6 +107,7 @@ aoti_torch_xpu_mkldnn__convolution_pointwise_binary_( const char** unary_algorithm, AtenTensorHandle* ret0); +<<<<<<< HEAD AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__qlinear_pointwise_tensor( AtenTensorHandle X, AtenTensorHandle act_scale, @@ -201,6 +202,8 @@ aoti_torch_xpu__qconv2d_pointwise_binary_tensor( const char** unary_algorithm, AtenTensorHandle* ret0); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif // AT_MKLDNN_ENABLED() #ifdef __cplusplus } // extern "C" diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h index aced2b2f539de..02f9792b79976 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h @@ -29,7 +29,10 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fft_c2c(AtenTensorHandle self, AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fft_r2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t onesided, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fused_moving_avg_obs_fq_helper(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5); +<<<<<<< HEAD AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__histogramdd_from_bin_cts(AtenTensorHandle self, const int64_t* bins, int64_t bins_len_, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__int_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__pdist_backward(AtenTensorHandle grad, AtenTensorHandle self, double p, AtenTensorHandle pdist, AtenTensorHandle* ret0); diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h index 470919cf389c3..fd37218fcbd40 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h @@ -32,7 +32,10 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__flash_attention_backward(AtenT AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__flash_attention_forward(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* cum_seq_q, AtenTensorHandle* cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, int64_t* window_size_left, int64_t* window_size_right, AtenTensorHandle* seqused_k, AtenTensorHandle* alibi_slopes, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_moving_avg_obs_fq_helper(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5); +<<<<<<< HEAD AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__int_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__pdist_backward(AtenTensorHandle grad, AtenTensorHandle self, double p, AtenTensorHandle pdist, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__pdist_forward(AtenTensorHandle self, double p, AtenTensorHandle* ret0); @@ -51,7 +54,10 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__thnn_fused_lstm_cell(AtenTenso AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__weight_int4pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, int64_t qGroupSize, AtenTensorHandle qScaleAndZeros, AtenTensorHandle* ret0); +<<<<<<< HEAD AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__weight_int8pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scales, AtenTensorHandle* ret0); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_abs(AtenTensorHandle self, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_adaptive_max_pool2d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_adaptive_max_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle indices, AtenTensorHandle* ret0); diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h index 179c0074b3cdf..b5400668b5093 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h @@ -18,12 +18,18 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__efficientzerotensor(const int64 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fft_c2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t forward, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fft_r2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t onesided, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5); +<<<<<<< HEAD AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__histogramdd_from_bin_cts(AtenTensorHandle self, const int64_t* bins, int64_t bins_len_, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_mask, double dropout_p, int32_t is_causal, AtenTensorHandle* dropout_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0); +======= +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__histogramdd_from_bin_cts(AtenTensorHandle self, const int64_t* bins, int64_t bins_len_, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__weight_int4pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, int64_t qGroupSize, AtenTensorHandle qScaleAndZeros, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__weight_int8pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scales, AtenTensorHandle* ret0); @@ -38,8 +44,11 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_addmv(AtenTensorHandle self, Ate AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_angle(AtenTensorHandle self, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_avg_pool2d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_avg_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0); +<<<<<<< HEAD AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_avg_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_avg_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_baddbmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle batch1, AtenTensorHandle batch2, double beta, double alpha); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_bernoulli__Tensor(AtenTensorHandle self, AtenTensorHandle p, AtenGeneratorHandle* generator); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_bernoulli__float(AtenTensorHandle self, double p, AtenGeneratorHandle* generator); @@ -68,17 +77,23 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_masked_scatter_backward(AtenTens AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_masked_select(AtenTensorHandle self, AtenTensorHandle mask, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_max_pool2d_with_indices(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_max_pool2d_with_indices_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle indices, AtenTensorHandle* ret0); +<<<<<<< HEAD AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_max_pool3d_with_indices(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_max_pool3d_with_indices_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, const int64_t* dilation, int64_t dilation_len_, int32_t ceil_mode, AtenTensorHandle indices, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_max_unpool2d(AtenTensorHandle self, AtenTensorHandle indices, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_max_unpool3d(AtenTensorHandle self, AtenTensorHandle indices, const int64_t* output_size, int64_t output_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_median(AtenTensorHandle self, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_mul_Scalar(AtenTensorHandle self, double other, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_mul_Tensor(AtenTensorHandle self, AtenTensorHandle other, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_nanmedian(AtenTensorHandle self, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_narrow(AtenTensorHandle self, int64_t dim, int64_t start, int64_t length, AtenTensorHandle* ret0); +<<<<<<< HEAD AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_native_dropout(AtenTensorHandle input, double p, int32_t* train, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_nonzero(AtenTensorHandle self, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_normal_functional(AtenTensorHandle self, double mean, double std, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_pad(AtenTensorHandle self, const int64_t* pad, int64_t pad_len_, const char* mode, double* value, AtenTensorHandle* ret0); diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h index 09ebbb76d0b21..d9d4be4780c62 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h @@ -13,8 +13,11 @@ extern "C" { AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__addmm_activation(AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha, int32_t use_gelu, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5); +<<<<<<< HEAD AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__int_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0); diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp index 2cdeab071cd82..b1c2aa8962401 100644 --- a/torch/csrc/inductor/aoti_torch/shim_common.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp @@ -1,5 +1,8 @@ #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include @@ -25,10 +28,13 @@ #include #include +<<<<<<< HEAD #include #include #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #ifndef AT_PER_OPERATOR_HEADERS #include #else @@ -389,6 +395,7 @@ AOTITorchError aoti_torch_get_device_index( }); } +<<<<<<< HEAD AOTITorchError aoti_torch_get_layout( AtenTensorHandle tensor, int32_t* ret_layout) { @@ -398,6 +405,8 @@ AOTITorchError aoti_torch_get_layout( }); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTITorchError aoti_torch_get_storage_offset( AtenTensorHandle tensor, int64_t* ret_storage_offset) { @@ -416,6 +425,7 @@ AOTITorchError aoti_torch_is_contiguous( }); } +<<<<<<< HEAD AOTITorchError aoti_torch_is_defined( AtenTensorHandle tensor, bool* ret_is_defined) { @@ -425,6 +435,8 @@ AOTITorchError aoti_torch_is_defined( }); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTITorchError aoti_torch_new_tensor_handle( AtenTensorHandle orig_handle, AtenTensorHandle* new_handle) { @@ -475,6 +487,7 @@ AOTITorchError aoti_torch_empty_strided( }); } +<<<<<<< HEAD AOTITorchError aoti_torch_empty_strided_pinned( int64_t ndim, const int64_t* sizes_ptr, @@ -497,6 +510,8 @@ AOTITorchError aoti_torch_empty_strided_pinned( }); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTITorchError aoti_torch_create_tensor_from_blob( void* data, int64_t ndim, @@ -1026,17 +1041,28 @@ AOTITorchError aoti_torch_cpu__wrapped_linear_prepack( AOTITorchError aoti_torch_cpu_wrapped_fbgemm_linear_fp16_weight( AtenTensorHandle input, AtenTensorHandle weight, +<<<<<<< HEAD AtenTensorHandle bias, // optional argument +======= + AtenTensorHandle bias, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int64_t out_channel, AtenTensorHandle* out) { AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ at::Tensor* input_tensor = tensor_handle_to_tensor_pointer(input); at::Tensor* weight_tensor = tensor_handle_to_tensor_pointer(weight); +<<<<<<< HEAD auto optional_bias_tensor = pointer_to_optional(tensor_handle_to_tensor_pointer(bias)); *out = new_tensor_handle(at::fbgemm_linear_fp16_weight_fp32_activation( *input_tensor, *weight_tensor, optional_bias_tensor)); +======= + at::Tensor* bias_tensor = tensor_handle_to_tensor_pointer(bias); + + *out = new_tensor_handle(at::fbgemm_linear_fp16_weight_fp32_activation( + *input_tensor, *weight_tensor, *bias_tensor)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }); } @@ -1101,6 +1127,7 @@ AOTITorchError aoti_torch_check_inf_and_nan( }); } +<<<<<<< HEAD AOTITorchError aoti_record_function_start( const char* name, IValueMapHandle kwargs, @@ -1140,6 +1167,8 @@ AOTITorchError aoti_record_function_end(AtenRecordFunctionHandle guard) { }); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTITorchError aoti_torch_scatter_out( AtenTensorHandle out, AtenTensorHandle self, @@ -1266,7 +1295,12 @@ void aoti_torch_print_tensor_handle(AtenTensorHandle self, const char* msg) { if (msg) { std::cout << " " << msg; } +<<<<<<< HEAD std::cout << " " << "]:" << '\n'; +======= + std::cout << " " + << "]:" << '\n'; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Print exact tensor values for small size tensors const int64_t numel = t->numel(); @@ -1673,6 +1707,7 @@ AOTITorchError aoti_torch_call_dispatcher( } }); } +<<<<<<< HEAD AOTITorchError aoti_torch_create_device_guard( int32_t device_index, @@ -1730,3 +1765,5 @@ AOTITorchError aoti_torch_get_current_device_index(int32_t* ret_device_index) { AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE( { *ret_device_index = at::accelerator::getDeviceIndex(); }); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/csrc/inductor/aoti_torch/shim_cpu.cpp b/torch/csrc/inductor/aoti_torch/shim_cpu.cpp index b1c864bf3fbba..835f0b05d6910 100644 --- a/torch/csrc/inductor/aoti_torch/shim_cpu.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_cpu.cpp @@ -1,7 +1,10 @@ +<<<<<<< HEAD #ifdef USE_DISTRIBUTED #include #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include @@ -19,6 +22,19 @@ using namespace torch::aot_inductor; #if AT_MKLDNN_ENABLED() +<<<<<<< HEAD +======= +template +static c10::List convert_to_c10_List(const T* scalars, const int64_t len) { + c10::List scalars_list; + scalars_list.reserve(len); + for (int64_t i = 0; i < len; i++) { + scalars_list.emplace_back(scalars[i]); + } + return scalars_list; +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTITorchError aoti_torch_cpu_mkldnn__convolution_pointwise_binary( AtenTensorHandle X, AtenTensorHandle other, @@ -532,6 +548,7 @@ AOTITorchError aoti_torch_cpu__weight_int4pack_mm_cpu_tensor( *ret0 = new_tensor_handle(std::move(tmp_result)); }); } +<<<<<<< HEAD #ifdef USE_DISTRIBUTED AOTITorchError aoti_torch_cpu__c10d_functional_all_reduce_( @@ -567,3 +584,5 @@ AOTITorchError aoti_torch_cpu__c10d_functional_wait_tensor( }); } #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/csrc/inductor/aoti_torch/shim_mps.mm b/torch/csrc/inductor/aoti_torch/shim_mps.mm index 1bf88839ecfe0..cddae1e1c08ff 100644 --- a/torch/csrc/inductor/aoti_torch/shim_mps.mm +++ b/torch/csrc/inductor/aoti_torch/shim_mps.mm @@ -3,8 +3,11 @@ #include #include #include +<<<<<<< HEAD #include #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) using namespace torch::aot_inductor; @@ -42,6 +45,7 @@ AOTITorchError aoti_torch_mps_free( memcpy(buffer_pointer + constant_offset, constants_start + bytes_read, data_size); }); } +<<<<<<< HEAD AOTITorchError aoti_torch_mps_copy_buffer(void* src_buffer, void* dst_buffer, size_t data_size, size_t src_offset, size_t dst_offset) { @@ -55,3 +59,5 @@ AOTITorchError aoti_torch_mps_free( stream->copy_and_sync(src_mtl_buffer, dst_mtl_buffer, data_size, src_offset, dst_offset, true, profile_id); }); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/csrc/inductor/aoti_torch/shim_xpu.cpp b/torch/csrc/inductor/aoti_torch/shim_xpu.cpp index c05872ae04239..1021b99ad6dda 100644 --- a/torch/csrc/inductor/aoti_torch/shim_xpu.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_xpu.cpp @@ -80,8 +80,11 @@ AOTITorchError aoti_torch_get_current_sycl_queue(void** ret) { #if AT_MKLDNN_ENABLED() #include +<<<<<<< HEAD #include #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AOTITorchError aoti_torch_xpu_mkldnn__convolution_pointwise_binary( AtenTensorHandle X, @@ -206,6 +209,7 @@ AOTITorchError aoti_torch_xpu_mkldnn__convolution_pointwise( }); } +<<<<<<< HEAD AOTITorchError aoti_torch_xpu__qlinear_pointwise_tensor( AtenTensorHandle X, AtenTensorHandle act_scale, @@ -429,4 +433,6 @@ AOTITorchError aoti_torch_xpu__qconv2d_pointwise_binary_tensor( }); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif // AT_MKLDNN_ENABLED() diff --git a/torch/csrc/inductor/aoti_torch/utils.h b/torch/csrc/inductor/aoti_torch/utils.h index 22018cd70c829..e8038478daf3f 100644 --- a/torch/csrc/inductor/aoti_torch/utils.h +++ b/torch/csrc/inductor/aoti_torch/utils.h @@ -222,6 +222,7 @@ inline std::optional> pointer_to_optional_list( : std::nullopt; } +<<<<<<< HEAD template static c10::List convert_to_c10_List(const T* scalars, const int64_t len) { c10::List scalars_list; @@ -232,4 +233,6 @@ static c10::List convert_to_c10_List(const T* scalars, const int64_t len) { return scalars_list; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace torch::aot_inductor diff --git a/torch/csrc/inductor/cpp_prefix.h b/torch/csrc/inductor/cpp_prefix.h index f98da60a10496..2a1fcdaa36227 100644 --- a/torch/csrc/inductor/cpp_prefix.h +++ b/torch/csrc/inductor/cpp_prefix.h @@ -75,6 +75,7 @@ struct IsVecMaskType> : std::true_type {}; #endif template +<<<<<<< HEAD struct CascadeSumHelper { // A data struct to help cascade summation: std::vector sum_stk{}; @@ -134,6 +135,8 @@ inline T cascade_sum_final(CascadeSumHelper* c) { } template +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) struct WelfordHelper { // A data struct to help welford reduction: // 1. Save the reciprocal of weights to avoid redundant divisions. @@ -270,6 +273,7 @@ Welford welford_combine( out.index}; } +<<<<<<< HEAD template inline T cascade_sum_combine( T& data, @@ -295,6 +299,8 @@ inline T cascade_sum_combine( return c->sum_stk[0]; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template T max_masked_reduce(const T& a, const T& b, const int64_t tail_size) { auto out = at::vec::maximum(a, b); diff --git a/torch/csrc/inductor/cpp_wrapper/common.h b/torch/csrc/inductor/cpp_wrapper/common.h index a2eebfcc86032..604e1965f766f 100644 --- a/torch/csrc/inductor/cpp_wrapper/common.h +++ b/torch/csrc/inductor/cpp_wrapper/common.h @@ -6,7 +6,12 @@ #include #include +<<<<<<< HEAD #include +======= +#define PYBIND11_SIMPLE_GIL_MANAGEMENT +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Include some often-used cpp_wrapper headers, for precompiling. #include diff --git a/torch/csrc/jit/OVERVIEW.md b/torch/csrc/jit/OVERVIEW.md index 1ef0522d2175a..47a57e37f03c1 100644 --- a/torch/csrc/jit/OVERVIEW.md +++ b/torch/csrc/jit/OVERVIEW.md @@ -958,7 +958,11 @@ torch._C._jit_set_fusion_strategy([ ]) ``` +<<<<<<< HEAD This will make two attempts to generate static-shape graphs, and after that fall back to generating dynamic-shape graphs. If for some reason compilation keeps occurring (even with dynamic-shape graphs - e.g. this could happen if ranks or dtypes vary), after 20 compilation attempts the graph executor will fall back to running the graph without any attempts to compile it. +======= +This will make two attempts to generate static-shape graphs, and after that fall back to generating dynamic-shape graphs. If for some reason compilation keeps occuring (even with dynamic-shape graphs - e.g. this could happen if ranks or dtypes vary), after 20 compilation attempts the graph executor will fall back to running the graph without any attempts to compile it. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ### Pre-derivative Optimization ### diff --git a/torch/csrc/jit/README.md b/torch/csrc/jit/README.md index 4d9c2d07f3d1d..1fe6404cc7c83 100644 --- a/torch/csrc/jit/README.md +++ b/torch/csrc/jit/README.md @@ -26,5 +26,9 @@ A brief summary of the source tree: **Refer** to each folder for more in-depth documentation. Other relevant parts of the codebase not contained here: +<<<<<<< HEAD - [aten/src/ATen/core](../../../aten/src/ATen/core): contains JIT code reused by other elements of the +======= +- [aten/src/ATen/core](../../../aten/src/ATen/core): contains JIT code re-used by other elements of the +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) runtime system (eager, mobile, etc.) diff --git a/torch/csrc/jit/api/function_impl.h b/torch/csrc/jit/api/function_impl.h index f508f3e5d522b..2f67da6d2f68e 100644 --- a/torch/csrc/jit/api/function_impl.h +++ b/torch/csrc/jit/api/function_impl.h @@ -147,7 +147,11 @@ struct TORCH_API GraphFunction : public Function { mutable std::array, SpecializationKey::TotalCount> optimized_graphs_; +<<<<<<< HEAD // GraphFunctions are invocable from multiple threads, so this lock needs to +======= + // GraphFunctions are invokable from multiple threads, so this lock needs to +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // be held when we're initializing graph executor for the first time or // computing the optimized graph. We're using reentrant mutex so that we don't // need to worry about causing a deadlock by calling one method from another diff --git a/torch/csrc/jit/api/method.h b/torch/csrc/jit/api/method.h index d7ef14ddb193d..41a79c140e20b 100644 --- a/torch/csrc/jit/api/method.h +++ b/torch/csrc/jit/api/method.h @@ -67,7 +67,11 @@ struct TORCH_API Method : public torch::IMethod { private: void setArgumentNames(std::vector&) const override; +<<<<<<< HEAD // Methods are uniqued owned by a single module. This raw pointer allows +======= + // Methods are uniqued onwed by a single module. This raw pointer allows +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // looking up the module. ObjectPtr owner_; diff --git a/torch/csrc/jit/api/module.h b/torch/csrc/jit/api/module.h index 52cec12fb8598..23bcf743aebca 100644 --- a/torch/csrc/jit/api/module.h +++ b/torch/csrc/jit/api/module.h @@ -327,7 +327,11 @@ struct TORCH_API Module : public Object { // Map of function names to the traced inputs that they have been traced with c10::Dict traced_inputs_; +<<<<<<< HEAD // Mutex to keep registering buffer or parameter thread safe. +======= + // Mutex to keep registring buffer or parameter thread safe. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::shared_ptr register_mutex_ = std::make_shared(); }; diff --git a/torch/csrc/jit/backends/backend_debug_handler.cpp b/torch/csrc/jit/backends/backend_debug_handler.cpp index 0d41034130395..350ead90b032e 100644 --- a/torch/csrc/jit/backends/backend_debug_handler.cpp +++ b/torch/csrc/jit/backends/backend_debug_handler.cpp @@ -26,7 +26,11 @@ int64_t BackendDebugInfoRecorder::getNextDebugHandle(const Node* node) { BackendDebugInfoMapType BackendDebugInfoRecorder::stopRecording() { // Note that this is return by copy and since // InlinedCallStackPtrs are intrusive ptr it will result in +<<<<<<< HEAD // bump of refcount. Not performant, but this is not intended +======= + // bump of refcount. Not performant, but this is not intented +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // to be used in perf critical path. // Alternate might be do move but that will be destructive return handles_to_inlined_callstack_ptrs_; diff --git a/torch/csrc/jit/backends/backend_debug_handler.h b/torch/csrc/jit/backends/backend_debug_handler.h index 2e0145b56c294..c29ed3bd5176e 100644 --- a/torch/csrc/jit/backends/backend_debug_handler.h +++ b/torch/csrc/jit/backends/backend_debug_handler.h @@ -18,7 +18,11 @@ namespace torch::jit { * Effectively debug handles are something that is given to backend and later * when an exception occurs in the backend, backend can tell, using debug * handle, that an exception occurred here. Then the runtime can generate +<<<<<<< HEAD * callstack corresponding to the exception. +======= + * callstack correspoding to the exception. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) * There are two parts to BackendDebugHandleManager: * 1. static std::atomic debug_handle * 2. Map of [debug-handle, DebugInfoTuple] diff --git a/torch/csrc/jit/backends/backend_exception.h b/torch/csrc/jit/backends/backend_exception.h index 807ef38e28305..5362ee5468ef1 100644 --- a/torch/csrc/jit/backends/backend_exception.h +++ b/torch/csrc/jit/backends/backend_exception.h @@ -16,13 +16,21 @@ class TORCH_API BackendRuntimeException : public c10::Error { } // If rethrowing, can push another debug_handle // This is useful in couple of scenarios. +<<<<<<< HEAD // 1. A submodule is lowered and lite interpreter has CallMethod +======= + // 1. A submodule is lowered and lite interperter has CallMethod +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // to lowered module's method. In this case lowered module will throw with // a handle, plus there will be another debug handle corresponding // to the CallMethod node in lite interpreter. Both together give complete // trace. This function allows lite interpreter to rethrow with debug // handle it has for CallMethod. +<<<<<<< HEAD // 2. Another scenarios is when lite interpreter can make function calls or +======= + // 2. Another scenarios is when lite interperter can make function calls or +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // the lowered backend also has function call ability. Thus we have // multiple function frames. Now we need a stack of handles to symbolicate // entire stack trace. diff --git a/torch/csrc/jit/backends/xnnpack/serialization/serializer.h b/torch/csrc/jit/backends/xnnpack/serialization/serializer.h index 4d8fe049134fe..2f9b6ed305376 100644 --- a/torch/csrc/jit/backends/xnnpack/serialization/serializer.h +++ b/torch/csrc/jit/backends/xnnpack/serialization/serializer.h @@ -37,7 +37,11 @@ class XNNSerializer { // Serialize add node, we are serializing the argument needed to call // xnn_define_add2. Serializing these values, and at run time we build +<<<<<<< HEAD // the graph by re running xnn_define_add2 +======= + // teh graph by re running xnn_define_add2 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void serializeAddNode( uint32_t input1_id, uint32_t input2_id, diff --git a/torch/csrc/jit/backends/xnnpack/xnnpack_backend_lib.cpp b/torch/csrc/jit/backends/xnnpack/xnnpack_backend_lib.cpp index 0428ac370b728..96013d67f6c0f 100644 --- a/torch/csrc/jit/backends/xnnpack/xnnpack_backend_lib.cpp +++ b/torch/csrc/jit/backends/xnnpack/xnnpack_backend_lib.cpp @@ -34,7 +34,11 @@ class XNNPackBackend : public PyTorchBackendInterface { c10::impl::GenericDict method_compile_spec) override { auto dict = processed.toGenericDict(); +<<<<<<< HEAD // Compiling and wrapping execution object +======= + // Compiling and wrapping exeuction object +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const std::string& ser_model = dict.at("ser_model").toStringRef(); XNNExecutor executor; XNNCompiler::compileModel(ser_model.data(), ser_model.length(), &executor); diff --git a/torch/csrc/jit/codegen/cuda/README.md b/torch/csrc/jit/codegen/cuda/README.md index a68bc0491919b..ad822a9aa7b9d 100644 --- a/torch/csrc/jit/codegen/cuda/README.md +++ b/torch/csrc/jit/codegen/cuda/README.md @@ -78,7 +78,11 @@ Graph print out is straight forward and you should look for `prim::CudaFusionGro return (%o.5) ``` +<<<<<<< HEAD Note that one thing that could prevents fusion when you are running training is autodiff. Fusion pass only runs within `prim::DifferentiableGraph`, so the first thing you should check is to that targeted ops are within differentiable graph subgraphs. +======= +Note that one thing that could prevents fusion when you are running training is autodiff. Fusion pass only runs within `prim::DifferentiableGraph`, so the first thing you should check is to that targetted ops are within differentiable graph subgraphs. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Graph dump could be quite confusing to look at, since it naively dumps all graphs executed by profiling executor and differentiable graphs are executed via a nested graph executor. So for each graph, you might see a few segmented `Optimized Graph` where each corresponds to a differentiable node in the original graph. #### 2. Cuda Fusion Graphs diff --git a/torch/csrc/jit/codegen/fuser/codegen.cpp b/torch/csrc/jit/codegen/fuser/codegen.cpp index 2f1e7e8e95059..597aec8c961cd 100644 --- a/torch/csrc/jit/codegen/fuser/codegen.cpp +++ b/torch/csrc/jit/codegen/fuser/codegen.cpp @@ -635,7 +635,11 @@ std::string generateKernel( } // Includes headers +<<<<<<< HEAD // Note: CUDA kernels support Halfs and random generation, CPU kernels do not +======= + // Note: CUDA kernels support halfs and random generation, CPU kernels do not +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (has_half_tensor) { env.s("HalfHeader", cuda::half_support_literal); } else { diff --git a/torch/csrc/jit/codegen/fuser/executor.cpp b/torch/csrc/jit/codegen/fuser/executor.cpp index 67c4501dc2758..5757f18e81273 100644 --- a/torch/csrc/jit/codegen/fuser/executor.cpp +++ b/torch/csrc/jit/codegen/fuser/executor.cpp @@ -28,7 +28,11 @@ static std::optional> getMapSize( // exactly how much storage do we need, so this could be fixed in-place at // every step. We're just missing a few functions for ATen, but the fix // should be straightforward. +<<<<<<< HEAD // Note: left uninitialized since empty shape is broadcastable to any shape +======= + // Note: left unitialized since empty shape is broadcastable to any shape +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::vector map_size; map_size.reserve(8); for (const auto arg_idx : arg_subset) { @@ -201,7 +205,11 @@ static void launchFusion( for (const auto& c : fusion.concatDesc()) flat_outputs_size += c.nSubTensors(); +<<<<<<< HEAD // Fails if the elements of the first (any) tensor are not expressible as +======= + // Fails if the elements of the first (any) tensor are not expressable as +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // a 32-bit integer. // Note: this code assumes that inputs are 32-bit addressable // Note: this code assumes that all inputs are of the same size diff --git a/torch/csrc/jit/codegen/fuser/fused_kernel.h b/torch/csrc/jit/codegen/fuser/fused_kernel.h index 0f785c4506609..674edd0d6be5c 100644 --- a/torch/csrc/jit/codegen/fuser/fused_kernel.h +++ b/torch/csrc/jit/codegen/fuser/fused_kernel.h @@ -40,7 +40,11 @@ struct FusedKernel { // CUDA code), and the remainder are pointers to the TensorInfo structs // that compiled code uses to load Tensor data. // launch_with_tensors handles packing at::Tensors into this arguments array. +<<<<<<< HEAD // CPU code uses the same convention so that launch_with_tensors can be +======= + // CPU code uses the same convension so that launch_with_tensors can be +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // shared. virtual void launch_raw(const uint32_t numel, std::vector& arguments) const = 0; diff --git a/torch/csrc/jit/codegen/onednn/graph_helper.cpp b/torch/csrc/jit/codegen/onednn/graph_helper.cpp index 2ef9f3cfa955c..44883a81212ee 100644 --- a/torch/csrc/jit/codegen/onednn/graph_helper.cpp +++ b/torch/csrc/jit/codegen/onednn/graph_helper.cpp @@ -70,7 +70,11 @@ Operator LlgaGraphHelper::makeBinaryOp(Node* node, opkind kind) { // third_party/ideep/mkl-dnn/src/interface/op_def.hpp. Operator LlgaGraphHelper::createOperator(Node* node) { auto nodeKind = node->kind(); +<<<<<<< HEAD // we're using an if-else clause instead of a switch statement +======= + // we're using an if-else clause instead of a switch staement +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // because we would soon be adding custom ops with function schemas. // We would have to use Symbol::fromQualString at that time anyway, // but we are okay with this choice, since this code is not in the hot-path. diff --git a/torch/csrc/jit/codegen/onednn/kernel.cpp b/torch/csrc/jit/codegen/onednn/kernel.cpp index c5421643e8c43..375a47b3742cc 100644 --- a/torch/csrc/jit/codegen/onednn/kernel.cpp +++ b/torch/csrc/jit/codegen/onednn/kernel.cpp @@ -84,9 +84,15 @@ ArgSpecs LlgaKernel::initializeInputSpecs(const TensorArgs& inputs) { for (const auto i : c10::irange(nGraphInputs_)) { auto spec = ArgSpec(graph_->inputs()[i]).supplementTensorInfo(inputs[i]); initializedInputIds_.insert(spec.tid()); +<<<<<<< HEAD int64_t occurrence = tensorIdToOccurence[spec.tid()]; inputSpecs.insert(inputSpecs.end(), occurrence, spec); runArgsIdx_.insert(runArgsIdx_.end(), occurrence, i); +======= + int64_t occurence = tensorIdToOccurence[spec.tid()]; + inputSpecs.insert(inputSpecs.end(), occurence, spec); + runArgsIdx_.insert(runArgsIdx_.end(), occurence, i); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } GRAPH_DEBUG("Initializing constant input tensors"); initializeConstantInputs(); diff --git a/torch/csrc/jit/docs/serialization.md b/torch/csrc/jit/docs/serialization.md index 43f7e261f0207..3514f2d8e3310 100644 --- a/torch/csrc/jit/docs/serialization.md +++ b/torch/csrc/jit/docs/serialization.md @@ -371,7 +371,11 @@ TorchScript class, or a `ScriptModule`. Owns other its attribute types **`Object`**: An instance of a particular class. Own the `CompilationUnit` that owns its `ClassType`. This is to ensure that if the user passes the object around in C++, all its code will stay around and methods will be +<<<<<<< HEAD invocable. +======= +invokable. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) **`Module`**: A view over a `ClassType` and the `Object` that holds its state. Also responsible for turning unqualified names (e.g. `forward()`) into diff --git a/torch/csrc/jit/frontend/builtin_functions.cpp b/torch/csrc/jit/frontend/builtin_functions.cpp index 2225f58e54e75..6c8cf61b86e96 100644 --- a/torch/csrc/jit/frontend/builtin_functions.cpp +++ b/torch/csrc/jit/frontend/builtin_functions.cpp @@ -103,10 +103,17 @@ struct BuiltinFunctionRegistry { // re-lock, the mutex without waiting), and report no loaded builtins during // init. std::lock_guard guard(mutex); +<<<<<<< HEAD if (state == INITIALIZING) { return empty; } else if (state == UNINITIALIZED) { state = INITIALIZING; +======= + if (state == INTIIALIZING) { + return empty; + } else if (state == UNINITIALIZED) { + state = INTIIALIZING; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) loadBuiltinFunctions(); state = INITIALIZED; } @@ -168,6 +175,7 @@ struct BuiltinFunctionRegistry { loadSource(aten_ops_additional, "aten"); // These are under `prim` instead of `aten` since they exist to bind certain +<<<<<<< HEAD // tensor property getters to corresponding methods loadSource(tensor_properties, "prim"); } @@ -178,6 +186,12 @@ struct BuiltinFunctionRegistry { INTIIALIZING = 1, // codespell:ignore INITIALIZED = 2 } state = UNINITIALIZED; +======= + // tensor property getters to correpsonding methods + loadSource(tensor_properties, "prim"); + } + enum { UNINITIALIZED, INTIIALIZING, INITIALIZED } state = UNINITIALIZED; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::recursive_mutex mutex; std::vector> modules; std::unordered_map> builtins_by_name_; diff --git a/torch/csrc/jit/frontend/error_report.cpp b/torch/csrc/jit/frontend/error_report.cpp index d5a8408e971c0..6db7f0e25d4b0 100644 --- a/torch/csrc/jit/frontend/error_report.cpp +++ b/torch/csrc/jit/frontend/error_report.cpp @@ -6,6 +6,7 @@ namespace torch::jit { // Avoid storing objects with destructor in thread_local for mobile build. #ifndef C10_MOBILE +<<<<<<< HEAD // [NOTE: Thread-safe CallStack] // `calls` maintains a stack of Python calls that resulted in the // currently compiled TorchScript code. RAII ErrorReport::CallStack @@ -34,6 +35,9 @@ namespace torch::jit { // (since now multiple threads access a given thread_local calls object) static thread_local std::shared_ptr calls = std::make_shared(); +======= +static thread_local std::vector calls; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif // C10_MOBILE ErrorReport::ErrorReport(const ErrorReport& e) @@ -44,15 +48,23 @@ ErrorReport::ErrorReport(const ErrorReport& e) #ifndef C10_MOBILE ErrorReport::ErrorReport(const SourceRange& r) +<<<<<<< HEAD : context(r), error_stack(calls->get_stack()) {} void ErrorReport::CallStack::update_pending_range(const SourceRange& range) { calls->update_pending_range(range); +======= + : context(r), error_stack(calls.begin(), calls.end()) {} + +void ErrorReport::CallStack::update_pending_range(const SourceRange& range) { + calls.back().caller_range = range; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } ErrorReport::CallStack::CallStack( const std::string& name, const SourceRange& range) { +<<<<<<< HEAD source_callstack_ = calls; source_callstack_->push_back({name, range}); } @@ -61,6 +73,13 @@ ErrorReport::CallStack::~CallStack() { if (source_callstack_) { source_callstack_->pop_back(); } +======= + calls.push_back({name, range}); +} + +ErrorReport::CallStack::~CallStack() { + calls.pop_back(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } #else // defined C10_MOBILE ErrorReport::ErrorReport(const SourceRange& r) : context(r) {} @@ -91,7 +110,11 @@ static std::string get_stacked_errors(const std::vector& error_stack) { std::string ErrorReport::current_call_stack() { #ifndef C10_MOBILE +<<<<<<< HEAD return get_stacked_errors(calls->get_stack()); +======= + return get_stacked_errors(calls); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #else TORCH_CHECK(false, "Call stack not supported on mobile"); #endif // C10_MOBILE diff --git a/torch/csrc/jit/frontend/error_report.h b/torch/csrc/jit/frontend/error_report.h index 9f5ad9bf3bb68..0efedd0d2bf9b 100644 --- a/torch/csrc/jit/frontend/error_report.h +++ b/torch/csrc/jit/frontend/error_report.h @@ -1,7 +1,10 @@ #pragma once #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace torch::jit { @@ -19,6 +22,7 @@ struct TORCH_API ErrorReport : public std::exception { const char* what() const noexcept override; +<<<<<<< HEAD class TORCH_API Calls { private: std::vector calls_; @@ -51,6 +55,8 @@ struct TORCH_API ErrorReport : public std::exception { } }; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) struct TORCH_API CallStack { // These functions are used to report why a function was being compiled // (i.e. what was the call stack of user functions at compilation time that @@ -61,9 +67,12 @@ struct TORCH_API ErrorReport : public std::exception { // Change the range that is relevant for the current function (i.e. after // each successful expression compilation, change it to the next expression) static void update_pending_range(const SourceRange& range); +<<<<<<< HEAD private: std::shared_ptr source_callstack_; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; static std::string current_call_stack(); diff --git a/torch/csrc/jit/frontend/exit_transforms.cpp b/torch/csrc/jit/frontend/exit_transforms.cpp index 48fc133fe3d04..5cfd799453910 100644 --- a/torch/csrc/jit/frontend/exit_transforms.cpp +++ b/torch/csrc/jit/frontend/exit_transforms.cpp @@ -333,7 +333,11 @@ struct ExitTransformer { std::vector exit_block_vals; // after an exit, the only values that will get used // are the hasExited() and exitValues(), so we match the existing +<<<<<<< HEAD // block outputs with uninitialized +======= + // block outputs with unitialized +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) exit_block_vals = matchValuesWithUnitialized(block->outputs()); // Set the new if to have the same outputs of the original block, @@ -362,7 +366,11 @@ struct ExitTransformer { // break // j = j + 1 // where the j + 1 value will be a block output, but since they will +<<<<<<< HEAD // never be used, it is safe to replace them with uninitialized value +======= + // never be used, it is safe to replace them with unitialized value +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void destroyNodeAfterExit(Node* n) { for (auto output : n->outputs()) { if (!output->uses().empty()) { diff --git a/torch/csrc/jit/frontend/ir_emitter.cpp b/torch/csrc/jit/frontend/ir_emitter.cpp index 3004562e9ff56..c36b141d61410 100644 --- a/torch/csrc/jit/frontend/ir_emitter.cpp +++ b/torch/csrc/jit/frontend/ir_emitter.cpp @@ -959,7 +959,11 @@ struct to_ir { emitDef( def, nullptr, +<<<<<<< HEAD closure_block); // ignore schema return, we just won't use it for now +======= + closure_block); // ignore schema return, we just wont use it for now +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // since we never create a Method for the closure }; auto closure_value = emitClosure(emit_body); @@ -1578,7 +1582,11 @@ struct to_ir { /*default_to_union=*/true, elem_type_hint); +<<<<<<< HEAD // Case: The list comprehension generated heterogeneous values, +======= + // Case: The list comprehension generated heterogenous values, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // and we don't have a type hint to suggest that this is what the // user expected if (!type_hint && (*unified_elem_type)->isUnionType()) { @@ -1701,7 +1709,11 @@ struct to_ir { << "the first generated key was " << k->type()->repr_str()); } else if ( first_generated_key_type && first_generated_key_type != k->type()) { +<<<<<<< HEAD // Values can be heterogeneous, so we only need to check that the +======= + // Values can be heterogenous, so we only need to check that the +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // key types are all the same throw( ErrorReport(dc) @@ -2118,7 +2130,11 @@ struct to_ir { // Try to unify the types. If we found a type annotation earlier // in the environment, and if that type annotation is some form // of union, then we need to tell `unifyTypes` not to throw an +<<<<<<< HEAD // error if the branched return types we found are heterogeneous +======= + // error if the branched return types we found are heterogenous +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool default_to_union = full_type && (full_type->kind() == UnionType::Kind || full_type->kind() == OptionalType::Kind || @@ -2440,7 +2456,11 @@ struct to_ir { SugaredValuePtr iterable = sv->iter(loc, method); // We unroll the loop for iterables that contain ModuleLists so that we can +<<<<<<< HEAD // compile Heterogeneous module lists. +======= + // compile Heterogenous module lists. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (!iterable->shouldEmitUnrolled()) { emitLoopCommon(loc, emit_body, iterable, targets, {}); } else { @@ -4260,7 +4280,11 @@ struct to_ir { } std::shared_ptr emitRpcExpr(const Apply& apply, Symbol rpc_op) { +<<<<<<< HEAD // TODO: This is a temporary apporoach to enable calling user function +======= + // TODO: This is a temporary apporoach to enable calling user fucntion +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // through RPC in TorchScript, // Ideally, function value in JIT IR is first-class citizen and // The RPC C++ entry API can take c10::Function directly. @@ -5399,7 +5423,11 @@ struct FunctionResolver : public Resolver { CompilationUnit::CompilationUnit(const std::string& source) : CompilationUnit() { +<<<<<<< HEAD // calls the define with native resolver to generate the graph for functions +======= + // calles the define with native resolver to generate the graph for functions +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) define(std::nullopt, source, nativeResolver(), nullptr); } diff --git a/torch/csrc/jit/frontend/schema_matching.cpp b/torch/csrc/jit/frontend/schema_matching.cpp index cbc22fab84e23..a83851bcc09b6 100644 --- a/torch/csrc/jit/frontend/schema_matching.cpp +++ b/torch/csrc/jit/frontend/schema_matching.cpp @@ -333,12 +333,20 @@ bool isBlockListedSchema(const FunctionSchema& schema) { // Currently JIT does not distinguish ScalarType vs int, so there is really // no way to distinguish x.view(1) vs x.view(torch.int8). So we have to // hardcode the aten::view.dtype here to block this overload. This blocklist +<<<<<<< HEAD // should be removed when JIT fully supports ScalarType as its own type. +======= + // should be removed when JIT fully suports ScalarType as its own type. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (schema.name() == "aten::view" && schema.overload_name() == "dtype") { return true; } // Note (@tugsbayasgalan) +<<<<<<< HEAD // TorchScript doesn't support kwargs so this op collides with aten.max.others +======= + // TorchScript doesn't suport kwargs so this op collides with aten.max.others +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // since both of them have 2 Tensor inputs. Since we don't expect users to // use this op in TS, we just skip it if (schema.name() == "aten::max" && schema.overload_name() == "unary_out") { diff --git a/torch/csrc/jit/frontend/schema_type_parser.cpp b/torch/csrc/jit/frontend/schema_type_parser.cpp index 4df9fb6639842..eba4e399a1d02 100644 --- a/torch/csrc/jit/frontend/schema_type_parser.cpp +++ b/torch/csrc/jit/frontend/schema_type_parser.cpp @@ -33,7 +33,10 @@ using c10::StorageType; using c10::StreamObjType; using c10::StringType; using c10::Symbol; +<<<<<<< HEAD using c10::SymBoolType; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) using c10::SymIntType; using c10::TensorType; using c10::TupleType; @@ -67,7 +70,10 @@ TypePtr SchemaTypeParser::parseBaseType() { {"int", c10::TypeFactory::get()}, {"SymInt", c10::TypeFactory::get()}, {"bool", c10::TypeFactory::get()}, +<<<<<<< HEAD {"SymBool", c10::TypeFactory::get()}, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {"None", c10::TypeFactory::get()}, {"NoneType", c10::TypeFactory::get()}, {"Capsule", c10::TypeFactory::get()}, diff --git a/torch/csrc/jit/frontend/script_type_parser.cpp b/torch/csrc/jit/frontend/script_type_parser.cpp index 31fc483812ab0..441de2607607c 100644 --- a/torch/csrc/jit/frontend/script_type_parser.cpp +++ b/torch/csrc/jit/frontend/script_type_parser.cpp @@ -448,7 +448,11 @@ std::vector ScriptTypeParser::parseArgsFromDecl( } std::vector ScriptTypeParser::parseReturnFromDecl(const Decl& decl) { +<<<<<<< HEAD // we represent no annotation on a return type as having no values in the +======= + // we represent no annoation on a return type as having no values in the +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // schema's return() list // in emitReturn we take the actual return value to be the value of the // return statement if no one was provided here diff --git a/torch/csrc/jit/frontend/source_range.cpp b/torch/csrc/jit/frontend/source_range.cpp index 89815d386ac05..a862e9a28a3a0 100644 --- a/torch/csrc/jit/frontend/source_range.cpp +++ b/torch/csrc/jit/frontend/source_range.cpp @@ -42,12 +42,21 @@ size_t StringCordView::find(const std::string& tok, size_t start) const { size_t offset = start; for (; begin != end_iter; ++begin, ++offset) { if (*begin == tok[0]) { +<<<<<<< HEAD auto mismatch = std::mismatch(begin, end_iter, tok.begin(), tok.end()); if (mismatch.second == tok.end()) { // no mismatch, and second string (tok) is exhausted. return offset; } if (mismatch.first == end_iter) { +======= + auto mis = std::mismatch(begin, end_iter, tok.begin(), tok.end()); + if (mis.second == tok.end()) { + // no mismatch, and second string (tok) is exhausted. + return offset; + } + if (mis.first == end_iter) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // this str is exhausted but tok is not return std::string::npos; } @@ -312,7 +321,11 @@ void SourceRange::print_with_context( } out << "\n"; } +<<<<<<< HEAD // print out initial context +======= + // print out inital context +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out << str.substr(begin_context, start() - begin_context); size_t line_start = start(); size_t line_end = range_end; diff --git a/torch/csrc/jit/frontend/sugared_value.cpp b/torch/csrc/jit/frontend/sugared_value.cpp index 0e9f0c9c2178c..f1cd119f30c92 100644 --- a/torch/csrc/jit/frontend/sugared_value.cpp +++ b/torch/csrc/jit/frontend/sugared_value.cpp @@ -359,8 +359,13 @@ void SimpleValue::setAttr( throw( ErrorReport(loc) << "Assignment to attribute '" << field +<<<<<<< HEAD << "' cannot be of a type that contains class " << "'" << classType->repr_str() << "'.\n" +======= + << "' cannot be of a type that contains class " + << "'" << classType->repr_str() << "'.\n" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) << "Classes that recursively contain instances of themselves" << " are not yet supported"); } @@ -826,6 +831,7 @@ SugaredValuePtr SugaredEnumClass::iter( return enum_values_list_constant; } +<<<<<<< HEAD std::shared_ptr TorchCheckValue::call( const SourceRange& loc, GraphFunction& m, @@ -904,4 +910,6 @@ std::shared_ptr TorchCheckValue::call( return std::make_shared(); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace torch::jit diff --git a/torch/csrc/jit/frontend/sugared_value.h b/torch/csrc/jit/frontend/sugared_value.h index 59ddea774d5d1..8ea63629b3fb0 100644 --- a/torch/csrc/jit/frontend/sugared_value.h +++ b/torch/csrc/jit/frontend/sugared_value.h @@ -118,7 +118,11 @@ struct TORCH_API SugaredValue // If we are iterating over a Sugared Value and it returns a value from this // function, then we emit an unrolled loop over the variable. This allows us +<<<<<<< HEAD // to support containers of Heterogeneous types, like Module Containers & +======= + // to support containers of Heterogenous types, like Module Containers & +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Tuples virtual std::optional staticLen() { return std::nullopt; @@ -136,10 +140,18 @@ struct TORCH_API SugaredValue // Value * virtual Value* len(const SourceRange& loc, GraphFunction& m) { throw( +<<<<<<< HEAD ErrorReport(loc) << "'" << kind() << "'" << " object is not iterable"); } // expression for ith element for iterable value +======= + ErrorReport(loc) << "'" << kind() << "'" + << " object is not iterable"); + } + + // expression for ith elemement for iterable value +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) virtual std::shared_ptr getitem( const SourceRange& loc, GraphFunction& m, @@ -296,7 +308,11 @@ struct TORCH_API SugaredTupleValue : public SugaredValue { return shared_from_this(); } +<<<<<<< HEAD // Because this is used to contain SugaredValues of Heterogeneous types, +======= + // Because this is used to contain SugaredValues of Heterogenous types, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // we define staticLen() so that when this is iterated over it is emitted // as an unrolled loop. std::optional staticLen() override { @@ -318,7 +334,11 @@ struct TORCH_API BuiltinModule : public SugaredValue { GraphFunction& m, const std::string& field) override { if (field == "autograd") { +<<<<<<< HEAD // When referring torch.autograd, it is also considered to be a +======= + // When refering torch.autograd, it is also considered to be a +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // BuiltinModule and we will dispatch to the aten operators for the // methods under its module. return std::make_shared("aten", version); @@ -330,12 +350,20 @@ struct TORCH_API BuiltinModule : public SugaredValue { private: std::string name; +<<<<<<< HEAD // when we add operator versioning, emit this op as it existing at 'version' +======= + // when we add operator versioning, emit this op as it exising at 'version' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // if not set, use the latest version std::optional version; }; +<<<<<<< HEAD // Represents a class, analogous to `int` or `dict`. Instances of classes, +======= +// Represents a class, analagous to `int` or `dict`. Instances of classes, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // like `1` or `{"foo": 5}`, are represented as SimpleValues struct TORCH_API ClassValue : public SugaredValue { explicit ClassValue(ClassTypePtr type) : type_(std::move(type)) {} @@ -857,6 +885,7 @@ struct TORCH_API SliceValue : public SugaredValue { Value* step_; }; +<<<<<<< HEAD struct TORCH_API TorchCheckValue : public SugaredValue { explicit TorchCheckValue() = default; @@ -872,4 +901,6 @@ struct TORCH_API TorchCheckValue : public SugaredValue { size_t n_binders) override; }; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace torch::jit diff --git a/torch/csrc/jit/frontend/tracer.cpp b/torch/csrc/jit/frontend/tracer.cpp index 3cfa77ef05cca..bc186162ee09f 100644 --- a/torch/csrc/jit/frontend/tracer.cpp +++ b/torch/csrc/jit/frontend/tracer.cpp @@ -557,7 +557,11 @@ void TracingState::setValue(const IValue& v, Value* value) { // If the value comes from a CallFunction or CallMethod, it may not have // shape information attached. For debuggability, we enhance the type +<<<<<<< HEAD // information by assigning the concrete value's type to the jit::Value. +======= + // information by assigning the concrete value's tupe to the jit::Value. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (auto tensor_type = value->type()->cast()) { if (!tensor_type->isComplete()) { value->inferTypeFrom(var); diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp index 16edf669da9be..ac7635314098a 100644 --- a/torch/csrc/jit/ir/alias_analysis.cpp +++ b/torch/csrc/jit/ir/alias_analysis.cpp @@ -53,7 +53,11 @@ class MutableTypePtrHelper { // Tensor with shape information removed. For example, a Tensor // of dimension 4 would map to the same type as a Tensor of // dimension 1. This allows us to treat all subclasses of Tensor +<<<<<<< HEAD // as a single, homogeneous "Tensor" type. +======= + // as a single, homogenous "Tensor" type. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::optional mapTypeToAliasTypeSet(const TypePtr& type) { if (mutable_type_cache_) { const AliasTypeSet* result = mapTypeToBorrowedAliasTypeSet(type); diff --git a/torch/csrc/jit/ir/alias_analysis.h b/torch/csrc/jit/ir/alias_analysis.h index 497412c6476e5..21edb041b44b4 100644 --- a/torch/csrc/jit/ir/alias_analysis.h +++ b/torch/csrc/jit/ir/alias_analysis.h @@ -48,7 +48,11 @@ class ValueAndMemoryLocationSet; * * `descendFunctionCalls` - recursively analyze function and method calls * instead of conservative analysis. Generally analysis should be done after +<<<<<<< HEAD * inlining so the implementation for recursive analysis is unoptimized. +======= + * inlining so the implmentation for recursive analysis is unoptimized. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) */ class AliasDb { public: @@ -102,7 +106,11 @@ class AliasDb { // Do any nodes write to an alias set output by `n`? TORCH_API bool hasOutputWriters(const Node* n) const; +<<<<<<< HEAD // Do any nodes write to an alias set inputted/outputted by `n`? +======= + // Do any nodes write to an alias set inputed/outputed by `n`? +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_API bool hasWriters(const Node* n) const; // Do any nodes write to `v`s memory location? @@ -338,7 +346,11 @@ TORCH_API void Lint(const AliasDb* db); * * The AliasDb must not be mutated after construction of a * ValueAndMemoryLocationsSet, or else the MemoryLocations stored in the * ValueAndMemoryLocationSet will no longer be accurate. +<<<<<<< HEAD * * A ValueAndMemoryLocationsSet is tied to an instance of AliasDb but +======= + * * A ValueAndMemoryLocationsSet is tied to an instsance of AliasDb but +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) * does not own the AliasDb. It is the user's responsibility to ensure * that the AliasDb outlives the ValuesAndMemoryLocationsSet. * diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp index fea29767d2653..23b270b837aeb 100644 --- a/torch/csrc/jit/ir/ir.cpp +++ b/torch/csrc/jit/ir/ir.cpp @@ -1143,7 +1143,11 @@ bool Node::isNondeterministic() const { if (!kind().is_aten()) { return false; } +<<<<<<< HEAD // All aten ops are expected to have a schema. However this is left as a +======= + // All aten ops are expecte to have a schema. However this is left as a +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // warning instead of an assert to ensure that previous use cases do not // break. if (!schema) { @@ -1648,7 +1652,11 @@ Block* Node::findCommonAncestorBlockWith(Node* n) { n2 = n2->owningBlock()->owningNode(); } +<<<<<<< HEAD // Now they are the same number of blocks from the graph block, +======= + // Now they are the same numer of blocks from the graph block, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // recurse upwards, checking if they are on the same block while (true) { if (n1->owningBlock() == n2->owningBlock()) { diff --git a/torch/csrc/jit/ir/ir.h b/torch/csrc/jit/ir/ir.h index c3b4f455d576b..b48babaf36f6d 100644 --- a/torch/csrc/jit/ir/ir.h +++ b/torch/csrc/jit/ir/ir.h @@ -616,7 +616,11 @@ struct TORCH_API Node { // as the equivalents phi-nodes in standard SSA form, // defining a new Value to represent any term that has multiple // definitions depending on how control flowed. Outputs of the node containing +<<<<<<< HEAD // control flow serve a similar purpose defining new values for variables +======= + // control flow serve a similiar purpose defining new values for variables +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // that would have different definitions depending on which way control // flowed. @@ -1374,7 +1378,11 @@ struct Graph : std::enable_shared_from_this { // kwargs using Python argument matching rules, and checks that the op matches // a known schema. // +<<<<<<< HEAD // If this node successfully completes, it guarantees the node +======= + // If this node successfully completes, it guarentees the node +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // is a correctly-formed invocation of opname TORCH_API Value* insert( Symbol opname, diff --git a/torch/csrc/jit/ir/ir_views.h b/torch/csrc/jit/ir/ir_views.h index 94aec3bde85ae..e60a59a4f72cf 100644 --- a/torch/csrc/jit/ir/ir_views.h +++ b/torch/csrc/jit/ir/ir_views.h @@ -143,7 +143,11 @@ struct LoopView { private: Node* node_; +<<<<<<< HEAD // adjust index_ordering by adding indices 0 - thorough adjust, and +======= + // adjust index_ordering by adding indices 0 - thorugh adjust, and +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // incrementing all existing inputs by adjust static std::vector adjustIndices( size_t adjust, diff --git a/torch/csrc/jit/ir/irparser.h b/torch/csrc/jit/ir/irparser.h index 9b256b71487f6..c2d613014ac52 100644 --- a/torch/csrc/jit/ir/irparser.h +++ b/torch/csrc/jit/ir/irparser.h @@ -13,7 +13,11 @@ struct Value; // \brief Parse IR from \p STR constructing the corresponding IR in\ GRAPH. // if parse_tensor_constants is true will construct empty tensors +<<<<<<< HEAD // for Tensor constants with random or uninitialized contents, otherwise will +======= +// for Tensor constants with random or unitialized contents, otherwise will +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // throw TORCH_API void parseIR( const std::string& str, @@ -25,7 +29,11 @@ TORCH_API void parseIR( * \p VMAP is filled with String to Value pairs allowing to index Values in the * newly created graph by their name in the original IR string. * if parse_tensor_constants is true will construct empty tensors +<<<<<<< HEAD * for Tensor constants with random or uninitialized contents, otherwise will +======= + * for Tensor constants with random or unitialized contents, otherwise will +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) * throw */ TORCH_API void parseIR( diff --git a/torch/csrc/jit/ir/node_hashing.cpp b/torch/csrc/jit/ir/node_hashing.cpp index 1551e610c3d10..54307fc160b4c 100644 --- a/torch/csrc/jit/ir/node_hashing.cpp +++ b/torch/csrc/jit/ir/node_hashing.cpp @@ -16,7 +16,11 @@ namespace torch::jit { namespace { bool tensorEqual(const at::Tensor& lhs, const at::Tensor& rhs) { +<<<<<<< HEAD // type_equal doesn't distinguish between mkldnn/pytorch cpu tensors, +======= + // type_equal doesnt distinguish between mkldnn/pytorch cpu tensors, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // and we dont want to coalesce mkldnn tensors bc they do layout // transformations based on usage if (lhs.is_mkldnn() || rhs.is_mkldnn()) { diff --git a/torch/csrc/jit/ir/scope.h b/torch/csrc/jit/ir/scope.h index 51baee8e277c1..950c5796a5471 100644 --- a/torch/csrc/jit/ir/scope.h +++ b/torch/csrc/jit/ir/scope.h @@ -208,7 +208,11 @@ struct TORCH_API InlinedCallStack : public c10::intrusive_ptr_target { }; // {source range, node name, InlinedCallStack} +<<<<<<< HEAD // We store node name because same debug info will be used for +======= +// We store node name because same debug infor will be used for +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // profiling as well, so we need to know op names as well. using DebugInfoTuple = std::tuple; diff --git a/torch/csrc/jit/ir/subgraph_matcher.h b/torch/csrc/jit/ir/subgraph_matcher.h index 91e170c052750..d8a038359e76b 100644 --- a/torch/csrc/jit/ir/subgraph_matcher.h +++ b/torch/csrc/jit/ir/subgraph_matcher.h @@ -11,7 +11,11 @@ namespace torch::jit { * \brief A structure describing a match of a pattern in a graph. * * The structure contains an anchor node, from which the match was found, and +<<<<<<< HEAD * match-maps for nodes and values. A match-map specifies the correspondence +======= + * match-maps for nodes and values. A match-map specifies the correspondance +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) * between nodes in the pattern graph (match-map keys) with nodes in the actual * graph (match-map values). We keep such maps for both nodes and values. */ @@ -38,7 +42,11 @@ struct Match { * graph are ignored during matching (IOW, we're essentially performing DCE on * the pattern). * - Pattern graph nodes cannot alias. TODO: the check not implemented yet. +<<<<<<< HEAD * - Aliasing nodes in the graph cannot constitute a match (i.e. through all +======= + * - Aliasing nodes in the graph cannot consitute a match (i.e. through all +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) * found matches, no nodes in the subgraph alias with each other). TODO: check * not implemented yet. * - The matcher will not mutate either the pattern graph or the matched graph. diff --git a/torch/csrc/jit/mobile/compatibility/backport_manager.cpp b/torch/csrc/jit/mobile/compatibility/backport_manager.cpp index 4422608423ee7..4cfce2cb06211 100644 --- a/torch/csrc/jit/mobile/compatibility/backport_manager.cpp +++ b/torch/csrc/jit/mobile/compatibility/backport_manager.cpp @@ -125,7 +125,11 @@ void write_archive_current( std::string fname = tensor_dir + tensor_names[i++]; if (use_storage_context && pre_serialized_files.find(fname) != pre_serialized_files.end()) { +<<<<<<< HEAD // storage has been serialized already, skip +======= + // storage has been serialzed already, skip +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue; } writer.writeRecord(fname, writable_td.data(), writable_td.sizeInBytes()); @@ -230,7 +234,11 @@ std::stringstream update_bytecode_version( How to add backport_v{i}_to_v{i-1} ? There are two options: +<<<<<<< HEAD 1) [Format change only, recommended] Construct a reader with the +======= + 1) [Format change only, recommended] Constrcut a reader with the +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) input_model_stream, modify the file, and use PyTorchWriter to write it to output_model_stream. See backport_v5_to_v4. @@ -322,7 +330,11 @@ std::stringstream backport_v5_to_v4(std::stringstream& input_model_stream) { // The export function to generate bytecode.pkl for version 4. After bytecode // version bump, the old export function doesn't exist anymore, so keep a copy +<<<<<<< HEAD // here for backport purpose. +======= + // here for backport pupose. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto writeArchiveV4 = [](PyTorchStreamWriter& writer, const std::string& archive_name, const c10::IValue& value) { @@ -502,7 +514,11 @@ std::stringstream backport_v9_to_v8(std::stringstream& input_model_stream) { torch::jit::load(input_model_stream, std::nullopt, extra_files); std::stringstream intermediate_model_stream; // TODO(@pavithran) : Check if debug info is available and use load/save while +<<<<<<< HEAD // backporting hardcode debaug info to be false until supported. +======= + // backporting hardcode debaug info to be false untill supported. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool hasBytecodeDebug = false; { BytecodeEmitModeGuard argNumGuard( diff --git a/torch/csrc/jit/mobile/compatibility/model_compatibility.cpp b/torch/csrc/jit/mobile/compatibility/model_compatibility.cpp index 8d847ddeb533f..6014b50dda368 100644 --- a/torch/csrc/jit/mobile/compatibility/model_compatibility.cpp +++ b/torch/csrc/jit/mobile/compatibility/model_compatibility.cpp @@ -393,7 +393,11 @@ ModelCompatCheckResult is_compatible( OperatorInfo runtime_op_info = runtime_info.operator_info.at(op_name); // If the runtime op has no schema information its a false alarm and isn't +<<<<<<< HEAD // actually usable +======= + // actually useable +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (!runtime_op_info.num_schema_args.has_value()) { result.status = ModelCompatibilityStatus::ERROR; std::ostringstream s; diff --git a/torch/csrc/jit/mobile/debug_info.cpp b/torch/csrc/jit/mobile/debug_info.cpp index 0a410a42fef04..3339850050766 100644 --- a/torch/csrc/jit/mobile/debug_info.cpp +++ b/torch/csrc/jit/mobile/debug_info.cpp @@ -76,7 +76,11 @@ std::pair, std::string> getStackTraceWithModuleHierarchy // This function construct stacktrace with module hierarchy // Module hierarchy will contain information about where in the // module hierarchy this source is. For example if conv2d op +<<<<<<< HEAD // exist in hierarchy A->B->C->Conv2d with type annotations of +======= +// exist in hierarcy A->B->C->Conv2d with type annotations of +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // A -> TopM, B->MyModule, C->SomeModule, then module hierarchy // will be TopM(A).MyModule(B).SomeModule(C).Conv2d(conv) // Source level stack information will be from model source code. diff --git a/torch/csrc/jit/mobile/debug_info.h b/torch/csrc/jit/mobile/debug_info.h index 14e1b1e4e7cd1..2f3d468ec580d 100644 --- a/torch/csrc/jit/mobile/debug_info.h +++ b/torch/csrc/jit/mobile/debug_info.h @@ -14,7 +14,11 @@ namespace torch::jit { * exception of BackendRuntimeException should raised using debug handles. * getSourceDebugString method is responsible for translating debug * handles to correspond debug information. +<<<<<<< HEAD * This debug information includes stack trace of model level source code and +======= + * This debug informatin includes stack trace of model level source code and +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) * module hierarchy where the exception occurred. */ class MobileDebugTable { diff --git a/torch/csrc/jit/mobile/flatbuffer_loader.h b/torch/csrc/jit/mobile/flatbuffer_loader.h index 24c670e01f79b..1a0230da1f914 100644 --- a/torch/csrc/jit/mobile/flatbuffer_loader.h +++ b/torch/csrc/jit/mobile/flatbuffer_loader.h @@ -48,7 +48,11 @@ using ExtraFilesMap = std::unordered_map; // shared_ptr overload of this function. // // If should_copy_tensor_memory is true, then the returned module will NOT have +<<<<<<< HEAD // references to `data`, so `data` can be freed immediately. +======= +// refences to `data`, so `data` can be freed immediately. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // // If should_copy_tensor_memory is false, then returned module will have tensors // that points inside of `data`; the caller will need to make sure that `data` @@ -93,7 +97,11 @@ TORCH_API mobile::Module parse_and_initialize_mobile_module_for_jit( // // This function does steps 1+2+3 described above. // +<<<<<<< HEAD // We need to have this as a convenience because Python API will need to wrap +======= +// We need to have this as a convienience because Python API will need to wrap +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // this. C++ clients should use one of the versions of // parse_and_initialize_mobile_module() so they can manage the raw data more // directly. @@ -110,7 +118,11 @@ TORCH_API mobile::ModuleInfo get_module_info_from_flatbuffer( char* flatbuffer_content); // The methods below are less efficient because it need to read the stream in +<<<<<<< HEAD // its entirety to a buffer +======= +// its entirity to a buffer +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_API mobile::Module load_mobile_module_from_stream_with_copy( std::istream& in, std::optional device = std::nullopt, diff --git a/torch/csrc/jit/mobile/model_tracer/MobileModelRunner.cpp b/torch/csrc/jit/mobile/model_tracer/MobileModelRunner.cpp index b02c7ef74096a..3c3a861d78358 100644 --- a/torch/csrc/jit/mobile/model_tracer/MobileModelRunner.cpp +++ b/torch/csrc/jit/mobile/model_tracer/MobileModelRunner.cpp @@ -105,7 +105,11 @@ std::unordered_map MobileModelRunner:: function_and_info_dict[key.toStringRef()] = data_list; } +<<<<<<< HEAD // Could store the full mapping of std types, but the 'info' section isn't +======= + // Could store the full mapping of std types, but the 'info' section isnt +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // needed here std::string input_function = function_and_info_dict["get_inputs_function_name"][0]; diff --git a/torch/csrc/jit/mobile/profiler_edge.h b/torch/csrc/jit/mobile/profiler_edge.h index 4acfb041fc41f..d308765c3072c 100644 --- a/torch/csrc/jit/mobile/profiler_edge.h +++ b/torch/csrc/jit/mobile/profiler_edge.h @@ -38,7 +38,11 @@ class TORCH_API KinetoEdgeCPUProfiler { * * Thus, when KinetoEdgeCPUProfiler is used as RAII to do profiling * within certain scope. In that scope, the captured reference to +<<<<<<< HEAD * Module will outlive KinetoEdgeCPUProfiler. This is guaranteed because +======= + * Module will outlive KinetoEdgeCPUProfiler. This is gauranteed because +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) * KinetoEdgeCPUProfiler must be constructed later than Module, on stack. * * An example of the anti-pattern and wrong usage is: diff --git a/torch/csrc/jit/mobile/train/optim/sgd.cpp b/torch/csrc/jit/mobile/train/optim/sgd.cpp index ae1a40e106215..e4ef6697a0d52 100644 --- a/torch/csrc/jit/mobile/train/optim/sgd.cpp +++ b/torch/csrc/jit/mobile/train/optim/sgd.cpp @@ -102,7 +102,11 @@ Tensor SGD::step(const LossClosure& closure) { Tensor buf; auto param_state = state_.find(p.unsafeGetTensorImpl()); if (param_state == state_.end()) { +<<<<<<< HEAD buf = d_p.detach().clone(); +======= + buf = torch::clone(d_p).detach(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto state = std::make_unique(); state->momentum_buffer(buf); state_[p.unsafeGetTensorImpl()] = std::move(state); diff --git a/torch/csrc/jit/mobile/type_parser.cpp b/torch/csrc/jit/mobile/type_parser.cpp index f9287a5eb7040..1f9ce9ed7d74d 100644 --- a/torch/csrc/jit/mobile/type_parser.cpp +++ b/torch/csrc/jit/mobile/type_parser.cpp @@ -36,7 +36,11 @@ TypeParser::TypeParser(std::vector& pythonStrs) // instruction. In nested type, the lowest level type will be at the beginning // of the type list. It is possible to parse it without worrying about // ordering, but it also introduces 1) extra cost to process nested type to +<<<<<<< HEAD // the correct order 2) lost the benefit that the instruction order is likely +======= +// the correct order 2) lost the benifit that the instruction order is likely +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // problematic if type list parsing fails. std::vector TypeParser::parseList() { std::vector typePtrs; diff --git a/torch/csrc/jit/operator_upgraders/README.md b/torch/csrc/jit/operator_upgraders/README.md index 60558a308110b..3c76a80ce69d1 100644 --- a/torch/csrc/jit/operator_upgraders/README.md +++ b/torch/csrc/jit/operator_upgraders/README.md @@ -11,7 +11,11 @@ You can determine if your change in the operator is BC breaking, if it fails `te ### Some examples BC breaking changes +<<<<<<< HEAD When making changes to the operators, the first thing to identify is if it's BC/FC breaking. Again, we only targeting for BC breaking changes on this guidance. Here are some examples to help understanding what a BC changes may look like: +======= +When making changes to the operators, the first thing to identify is if it's BC/FC breaking. Again, we only targetting for BC breaking changes on this guidance. Here are some examples to help understanding what a BC changes may look like: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #### Backward Compatibility Breakage: @@ -32,7 +36,11 @@ When making changes to the operators, the first thing to identify is if it's BC/ ### 1.Preparation +<<<<<<< HEAD [Build PyTorch from source](https://github.com/pytorch/pytorch#from-source) and prepare a test model before making changes to the operator, following the process below. A test model before making the operator changes is needed to test the upgrader. Otherwise, after the change to operator, the new runtime will no longer be able to produce a model with the historic operator and can't test it anymore. +======= +[Build PyTorch from souce](https://github.com/pytorch/pytorch#from-source) and prepare a test model before making changes to the operator, following the process below. A test model before making the operator changes is needed to test the upgrader. Otherwise, after the change to operator, the new runtime will no longer be able to produce a model with the historic operator and can't test it anymore. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) 1. Add a test module in `test/jit/fixtures_srcs/fixtures_src.py`. In `test/jit/fixtures_srcs/generate_models.py`, ``` diff --git a/torch/csrc/jit/passes/batch_mm.cpp b/torch/csrc/jit/passes/batch_mm.cpp index 38e4b5068e2ff..cab4014dc009d 100644 --- a/torch/csrc/jit/passes/batch_mm.cpp +++ b/torch/csrc/jit/passes/batch_mm.cpp @@ -319,7 +319,11 @@ static void BatchMMTreeReduce(Block* block, AliasDb& alias_db) { } static bool shape_is_fast_for_side(const at::Tensor& other_side_input) { +<<<<<<< HEAD // Cutoff chose by benchmarking on a TITAN V +======= + // Cutoff chosed by benchmarking on a TITAN V +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return other_side_input.numel() <= 1024 * 2048; } diff --git a/torch/csrc/jit/passes/canonicalize.cpp b/torch/csrc/jit/passes/canonicalize.cpp index 1cc849d4a3cd7..430c6d6cf5a60 100644 --- a/torch/csrc/jit/passes/canonicalize.cpp +++ b/torch/csrc/jit/passes/canonicalize.cpp @@ -96,7 +96,11 @@ static bool isBefore(Node* n1, Node* n2) { } } +<<<<<<< HEAD // Now they are the same number of blocks from the graph block, +======= + // Now they are the same numer of blocks from the graph block, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // recurse upwards, checking if they are on the same block while (true) { if (n1->owningBlock() == n2->owningBlock()) { diff --git a/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp b/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp index 680f7683009c8..f05c136cff52a 100644 --- a/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp +++ b/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp @@ -98,7 +98,11 @@ void InplaceMKLDNNSubgraph(const std::shared_ptr& graph) { // This function first calculates aliasing sets, // then calculates the last node each aliasing set is alive for. // Then we go through each node, if it's a node which has an equivalent +<<<<<<< HEAD // inplace node and the aliasing set for its input is dead after this node, we +======= + // inplace node and the aliasing set for its input is dead afer this node, we +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // inplace it. Then we merge the aliasing sets for the input and output of the // node and extend the liveness of the set. To inplace a node you need to // prove device and dtype of the input and output are the same, which we've @@ -812,7 +816,11 @@ void ComputeSubgraphInMKLDNN(Node* subgraph_node) { if (body_node->kind() == aten::conv2d || body_node->kind() == aten::conv3d) { +<<<<<<< HEAD // this node doesn't handle string padding yet... +======= + // this node doesnt handle string padding yet... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (!body_node->namedInput("padding")->type()->cast()) { body_node->replaceWithNewSymbol(Symbol::prim("mkldnn_convolution")); body_node->destroy(); diff --git a/torch/csrc/jit/passes/onnx.cpp b/torch/csrc/jit/passes/onnx.cpp index cddae77768228..53a56e9294e01 100644 --- a/torch/csrc/jit/passes/onnx.cpp +++ b/torch/csrc/jit/passes/onnx.cpp @@ -167,7 +167,11 @@ std::shared_ptr ToONNX( ConstantValueMap::ClearMaps(); auto new_graph = std::make_shared(graph->current_scope()); py::dict env; +<<<<<<< HEAD // Kept identical to values in env. Used for constant-time existence check. +======= + // Kept identical to values in env. Used for constant-time existance check. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py::set values_in_env; try { BlockToONNX( @@ -260,12 +264,19 @@ void NodeToONNX( ::torch::onnx::OperatorExportTypes operator_export_type, py::dict& env, py::set& values_in_env) { +<<<<<<< HEAD py::object onnx_utils = py::module::import("torch.onnx._internal.torchscript_exporter.utils"); py::object onnx_globals = py::module::import("torch.onnx._internal.torchscript_exporter._globals"); py::object onnx_registration = py::module::import( "torch.onnx._internal.torchscript_exporter.registration"); +======= + py::object onnx = py::module::import("torch.onnx"); + py::object onnx_globals = py::module::import("torch.onnx._globals"); + py::object onnx_registration = + py::module::import("torch.onnx._internal.registration"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Setup all the lambda helper functions. @@ -476,7 +487,11 @@ void NodeToONNX( // IMPORTANT: NEVER pass raw pointer of smart pointer managed objects to // Python. Check #87343 for details. py::list new_nodes = py::list(); +<<<<<<< HEAD py::object raw_output = onnx_utils.attr("_run_symbolic_function")( +======= + py::object raw_output = onnx.attr("_run_symbolic_function")( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) g->shared_from_this(), new_block, n, @@ -592,7 +607,11 @@ void NodeToONNX( // IMPORTANT: NEVER pass raw pointer of smart pointer managed objects to // Python. Check #87343 for details. +<<<<<<< HEAD py::object raw_output = onnx_utils.attr("_run_symbolic_method")( +======= + py::object raw_output = onnx.attr("_run_symbolic_method")( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) new_block->owningGraph()->shared_from_this(), op->name(), pyobj.attr("symbolic"), @@ -607,7 +626,11 @@ void NodeToONNX( // IMPORTANT: NEVER pass raw pointer of smart pointer managed objects to // Python. Check #87343 for details. py::list new_nodes = py::list(); +<<<<<<< HEAD py::object raw_output = onnx_utils.attr("_run_symbolic_function")( +======= + py::object raw_output = onnx.attr("_run_symbolic_function")( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) new_block->owningGraph()->shared_from_this(), new_block, n, diff --git a/torch/csrc/jit/passes/onnx/constant_fold.cpp b/torch/csrc/jit/passes/onnx/constant_fold.cpp index 0ac07adf0d45c..c6ba5f0dae905 100644 --- a/torch/csrc/jit/passes/onnx/constant_fold.cpp +++ b/torch/csrc/jit/passes/onnx/constant_fold.cpp @@ -76,8 +76,13 @@ static std::optional runTorchSlice_opset9( if (!(node->hasAttributeS("starts") && node->hasAttributeS("ends"))) { return std::nullopt; } +<<<<<<< HEAD auto const& startsAttr = node->is(attr::starts); auto const& endsAttr = node->is(attr::ends); +======= + auto startsAttr = node->is(attr::starts); + auto endsAttr = node->is(attr::ends); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (startsAttr.size() != endsAttr.size()) { return std::nullopt; } diff --git a/torch/csrc/jit/passes/onnx/function_extraction.cpp b/torch/csrc/jit/passes/onnx/function_extraction.cpp index 32c0e1b77c2cb..50a1904dafb81 100644 --- a/torch/csrc/jit/passes/onnx/function_extraction.cpp +++ b/torch/csrc/jit/passes/onnx/function_extraction.cpp @@ -216,7 +216,11 @@ void FunctionExtractor::FunctionContext::SetAttrName( TORCH_INTERNAL_ASSERT( v_it != scope_ctxs_[scope_key_]->env_to_subgraph_.end()); auto* n_in_def = v_it->second->node(); +<<<<<<< HEAD node_attr_to_name_[n_in_def][attr.toUnqualString()] = name; +======= + auto n_attr_it = node_attr_to_name_[n_in_def][attr.toUnqualString()] = name; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } std::optional FunctionExtractor::FunctionContext::FindAttrName( @@ -405,7 +409,11 @@ std::optional FunctionExtractor::InferScope(Node* n) { auto common_ancestor = FindCommonAncestor(scopes); if (common_ancestor.has_value() && IsValidScope(common_ancestor.value())) { +<<<<<<< HEAD return common_ancestor; +======= + return common_ancestor.value(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } } diff --git a/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.h b/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.h index 8d05fbe942651..85d9b25181a58 100644 --- a/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.h +++ b/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.h @@ -17,7 +17,11 @@ namespace torch::jit { // information. Shape and type information is only available after // _jit_pass_onnx, which converts aten nodes to onnx nodes. So there is a // interdependent issue. _jit_pass_onnx depends on preprocess passes to convert +<<<<<<< HEAD // aten nodes into convertible condition, and preprocess passes depend on +======= +// aten nodes into convertable condition, and preprocess passes depend on +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // _jit_pass_onnx to convert upstream nodes and apply onnx shape inference. // Separating the pass into two parts breaks the interdependency. // diff --git a/torch/csrc/jit/passes/onnx/peephole.cpp b/torch/csrc/jit/passes/onnx/peephole.cpp index 71595b769ac1c..de5a747e4369e 100644 --- a/torch/csrc/jit/passes/onnx/peephole.cpp +++ b/torch/csrc/jit/passes/onnx/peephole.cpp @@ -35,8 +35,13 @@ static bool isRNN(const Node* node) { } static bool isNopTranspose(const std::vector& perm) { +<<<<<<< HEAD for (size_t i = 0, perm_size = perm.size(); i < perm_size; i++) { if (perm[i] != static_cast(i)) { +======= + for (int64_t i = 0, perm_size = perm.size(); i < perm_size; i++) { + if (perm[i] != i) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return false; } } diff --git a/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp b/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp index a188eb0abd6b8..677f5d4472f51 100644 --- a/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp +++ b/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp @@ -10,6 +10,11 @@ #include +<<<<<<< HEAD +======= +#include + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace torch::jit { namespace { @@ -191,7 +196,12 @@ std::pair PrepareCopyForONNX(Node* node) { expanded_value->node()->copyMetadata(node); auto index_put = graph->insert( +<<<<<<< HEAD aten::index_put_, {node->input(0), dummy_list, expanded_value}); +======= + aten::index_put_, + {node->input(0), dummy_list, expanded_value, node->input(2)}); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) index_put->node()->copyMetadata(node); index_put->copyMetadata(node->output()); node->output()->replaceAllUsesWith(index_put); @@ -341,7 +351,11 @@ static void PrepareForRemoveMutations(MutationRemover& mr, Block* b) { auto it = std::find(node->inputs().begin(), node->inputs().end(), input); if (it != node->inputs().end()) { +<<<<<<< HEAD auto index = std::distance(node->inputs().begin(), it); +======= + int index = std::distance(node->inputs().begin(), it); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_WARN( "ONNX Preprocess - Removing mutation from node ", node->kind().toQualString(), diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp index 452b18f3efc31..2b6924c6472e0 100644 --- a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp +++ b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp @@ -282,7 +282,11 @@ Value* CloneValueFromListConstruct( auto input = n_graph->addInput(); if (scalar_type) { auto v_type = TensorType::create( +<<<<<<< HEAD scalar_type, +======= + scalar_type.value(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) at::kCPU, c10::SymbolicShape(), c10::VaryingShape{}, @@ -411,9 +415,13 @@ void ConvertGraphToONNXProto( } } +<<<<<<< HEAD std::optional ComputeConstantFolding( const Node* n, int opset_version) { +======= +std::optional ComputeConstantFolding(Node* n, int opset_version) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (n->inputs().empty()) { return std::nullopt; } @@ -465,7 +473,11 @@ std::optional<::c10::SymbolicShape> ComputeShapeFromReshape( auto it_0 = std::find_if(shape_vector.begin(), shape_vector.end(), is_zero); bool shape_has_zero = it_0 != shape_vector.end(); +<<<<<<< HEAD int64_t minus_one_pos = -1; +======= + int minus_one_pos = -1; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (auto i : c10::irange(shape_vector.size())) { if (shape_vector[i].value() == -1) { minus_one_pos = i; @@ -775,7 +787,11 @@ void ProcessBroadcastNode(Node* n) { } void ProcessShapeForConcatNode(Node* n) { +<<<<<<< HEAD auto axis = n->i(attr::axis); +======= + int axis = n->i(attr::axis); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (ConstantValueMap::HasRank(n->input(0)->debugName())) { auto rank = ConstantValueMap::GetRank(n->input(0)->debugName()).value(); size_t axis_adjust = 0; @@ -1246,7 +1262,11 @@ void ProcessUnsqueezeNode(Node* n) { void ComputeConstant(Node* n, int opset_version) { if (n->kind() == ::c10::onnx::Constant) { if (n->kindOf(attr::value) == AttributeKind::t) { +<<<<<<< HEAD const at::Tensor& const_val = n->t(attr::value); +======= + at::Tensor const_val = n->t(attr::value); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) at::Tensor const_val_copy = at::empty(const_val.sizes(), const_val.options()); const_val_copy.copy_(const_val); @@ -1383,7 +1403,11 @@ void ComputeConstant(Node* n, int opset_version) { .value() .sizes(); if (input0_shape_size.has_value()) { +<<<<<<< HEAD const auto& input0_shape_value = input0_shape_size.value(); +======= + auto input0_shape_value = input0_shape_size.value(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (ConstantValueMap::HasValue(n->input(1)->debugName())) { // When value of `shape` is statically known, // output shape can be computed. @@ -1476,7 +1500,11 @@ void ComputeConstant(Node* n, int opset_version) { .value() .sizes(); if (input0_shape_size.has_value()) { +<<<<<<< HEAD const auto& input0_shape_value = input0_shape_size.value(); +======= + auto input0_shape_value = input0_shape_size.value(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) int64_t total_size = 1; auto is_full_static = true; for (const auto i : c10::irange(input0_shape_value.size())) { @@ -1512,7 +1540,11 @@ void ComputeConstant(Node* n, int opset_version) { .value() .sizes(); if (input0_shape_size.has_value()) { +<<<<<<< HEAD const auto& input0_shape_value = input0_shape_size.value(); +======= + auto input0_shape_value = input0_shape_size.value(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (ConstantValueMap::HasValue(n->input(1)->debugName())) { auto shape_temp = ConstantValueMap::GetValueInto1DInt64Vector( n->input(1)->debugName()); @@ -1661,10 +1693,17 @@ void SpecialPostProcess(Node* n) { }; auto find_sequence_empty = [](Value* input, +<<<<<<< HEAD const TensorTypePtr& t_type) -> Node* { auto find_sequence_empty_impl = [](Value* input, const TensorTypePtr& t_type, +======= + TensorTypePtr t_type) -> Node* { + auto find_sequence_empty_impl = + [](Value* input, + TensorTypePtr t_type, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto& find_sequence_empty_ref) -> Node* { auto input_node = input->node(); TORCH_INTERNAL_ASSERT(input_node); @@ -1710,7 +1749,11 @@ void SpecialPostProcess(Node* n) { return nullptr; }; return find_sequence_empty_impl( +<<<<<<< HEAD input, t_type, find_sequence_empty_impl); +======= + input, std::move(t_type), find_sequence_empty_impl); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; if (seq_node && t_type && t_type->scalarType()) { @@ -2257,7 +2300,11 @@ void ONNXSetDynamicInputShape( } } +<<<<<<< HEAD static bool HasSequenceTypeOutput(const Node* node) { +======= +static bool HasSequenceTypeOutput(Node* node) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (node->kind() == ::c10::onnx::SplitToSequence || node->kind() == ::c10::onnx::SequenceInsert || node->kind() == ::c10::onnx::SequenceEmpty || diff --git a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp index 63e6804c97eb3..aecf8533abe46 100644 --- a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp +++ b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp @@ -21,6 +21,86 @@ using namespace ::c10::onnx; } +<<<<<<< HEAD +======= +// Get the scale of the input to quantized op. There are two cases here +// 1. For ops with output_scale specified in op signature, we get the output +// scale +// 2. For ops with no output scale in op signature (like quantized::relu) +// we traverse up the graph to get the scale from its input until we hit a node +// where scale is explicitly specified. +double getScaleFromInput(Node* input_node) { + std::optional scale; + std::string input_name = input_node->kind().toQualString(); + std::unordered_set noscale_ops = { + "quantized::max_pool2d", + "aten::max_pool2d", + "aten::relu", + "prim::ListUnpack", + "aten::split_with_sizes", + "quantized::nchw2nhwc", + "quantized::nhwc2nchw", + "aten::slice", + "aten::avg_pool2d", + "quantized::cat", + "prim::ListConstruct", + "aten::upsample_nearest2d", + "aten::sigmoid", + "aten::reshape"}; + if (input_name == "aten::quantize_per_tensor") { + TORCH_CHECK( + input_node->inputs().size() > 1, + "aten::quantize_per_tensor expected scale to be 2nd input"); + scale = toIValue(input_node->inputs()[1]); + return scale.value().toDouble(); + } else if (input_name == "quantized::linear") { + // %r = quantized::linear(%input, %packed_weight, %w_scale, %w_zero_point) + TORCH_CHECK( + input_node->inputs().size() > 2, + "quantized::linear expected scale to be 3rd input"); + scale = toIValue(input_node->inputs()[2]); + return scale.value().toDouble(); + } else if (input_name == "quantized::conv2d") { + // %r = quantized::conv2d(%input, %packed_weight, %w_scale, %w_zero_point) + TORCH_CHECK( + input_node->inputs().size() > 2, + "quantized::conv2d expected scale to be 3rd input"); + auto num_inputs = input_node->inputs().size(); + scale = toIValue(input_node->inputs()[num_inputs - 2]); + return scale.value().toDouble(); + } else if (input_name == "quantized::conv2d_relu") { + // %r = quantized::conv2d_relu(%input, %packed_weight, %w_scale, + // %w_zero_point) + TORCH_CHECK( + input_node->inputs().size() > 2, + "quantized::conv2d_relu expected scale to be 3rd input"); + auto num_inputs = input_node->inputs().size(); + scale = toIValue(input_node->inputs()[num_inputs - 2]); + return scale.value().toDouble(); + } else if (input_name == "quantized::add") { + // %r = quantized::add(%input_a, %input_b, %w_scale, %w_zero_point) + TORCH_CHECK( + input_node->inputs().size() > 2, + "quantized::add expected scale to be 3rd input"); + scale = toIValue(input_node->inputs()[2]); + return scale.value().toDouble(); + } else if (input_name == "aten::sigmoid") { + // For the _caffe2::Int8Sigmoid op output scale is 1.0/256 + // And output zero_point is set to 0 (quint8 type). + return 1.0L / 256; + } + // For the ops below the scale is not part of the op signature, so we traverse + // up the graph to get the scale from its input when defined in the graph. + else if (noscale_ops.find(input_name) != noscale_ops.end()) { + return getScaleFromInput(input_node->inputs()[0]->node()); + } + TORCH_INTERNAL_ASSERT( + false, + "Unrecognized quantized operator while trying to compute q_scale for operator ", + input_name); +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static std::vector CreateQuantizedWeights( std::shared_ptr& graph, const at::Tensor& weight, @@ -238,7 +318,11 @@ static void unpackQuantizedWeightsHelper( auto config_vals = elements[1].to>(); auto tensors = elements[2].to>>(); +<<<<<<< HEAD const std::optional& weight = tensors[1]; +======= + std::optional weight = tensors[1]; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_INTERNAL_ASSERT( weight, "Weight should always be present in serialized qconv."); unpacked_weight = *weight; @@ -296,7 +380,11 @@ static void unpackQuantizedWeightsHelper( TORCH_INTERNAL_ASSERT(version == "2", "Unknown serialization version"); std::vector non_optional = elements[1].toTensorVector(); +<<<<<<< HEAD const at::Tensor& conv_params_packed = non_optional[0]; +======= + at::Tensor conv_params_packed = non_optional[0]; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) unpacked_weight = non_optional[1]; const int64_t kSpatialDim = conv_params_packed[0].item(); diff --git a/torch/csrc/jit/passes/quantization/helper.cpp b/torch/csrc/jit/passes/quantization/helper.cpp index 36d6884637d2a..3db8d9bd3546b 100644 --- a/torch/csrc/jit/passes/quantization/helper.cpp +++ b/torch/csrc/jit/passes/quantization/helper.cpp @@ -116,7 +116,11 @@ static std::vector _single_input_general_shape_aten_funcs = { "__getitem__", }; +<<<<<<< HEAD // These are prim::CallFunctions for ops that doesn't require observation and +======= +// Theses are prim::CallFunctions for ops that doesn't require observation and +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // have a single input Tensor // Also these ops do computation on the value of Tensor // TODO: [Need verify] looks like we can quantize simple functionals that just @@ -136,7 +140,11 @@ static std::vector _single_input_general_value_call_funcs = { "leaky_relu", }; +<<<<<<< HEAD // These are aten functions for ops that doesn't require observation and +======= +// Theses are aten functions for ops that doesn't require observation and +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // have a single input Tensor // Also these ops do computation on the value of Tensor // e.g. `aten::avg_pool2d(%input_tensor, ...)` diff --git a/torch/csrc/jit/passes/quantization/insert_observers.cpp b/torch/csrc/jit/passes/quantization/insert_observers.cpp index 5fab235044453..c153e74a7dc69 100644 --- a/torch/csrc/jit/passes/quantization/insert_observers.cpp +++ b/torch/csrc/jit/passes/quantization/insert_observers.cpp @@ -1702,7 +1702,11 @@ Module InsertObserversForOnDevicePTQ( // you will have multiple getattrs for the same attribute and thus potentially // multiple observers observing the same value. This will also lead to // increased size of the packed param struct. I dont expect this to be a +<<<<<<< HEAD // common pattern but something to be aware of Note that current quant +======= + // common pattern but something to be aware fo Note that current quant +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // workflow does not prevent this anyway since during inset quant dequant // things are inlined anyway helper.fillBoundaryValueMap(cloned_module, observer_method_name); diff --git a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp index 2e39bf67bf5f3..0de6ddfdc8333 100644 --- a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp +++ b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp @@ -1622,7 +1622,11 @@ void InsertQuantDeQuantHelper::insertCalculateQParamsAndQuantizationOps( void InsertQuantDeQuantHelper::runForOnDevicePTQ( Module& module, const std::string& method_name) { +<<<<<<< HEAD // In all likelihood this really won't do anything because we expect that +======= + // In all likelihood this really wont do anything because we expect that +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // the input method for quantization's prepare step will be inlined. Thus // only call methods we will see will belong to observer's forward calls. for (auto& invoked_methods : getInvokedMethods(module, method_name)) { @@ -1834,8 +1838,13 @@ Module InsertQuantDeQuantOnDevicePTQ( // ReplicateChooseQParamsQuantDequant: This is propagating dynamic quant's // quant dequant RemoveRedundantQuantizationOps: THis is removing activation // observers for dynamic quant when the op related to it is not dynamically +<<<<<<< HEAD // quantizable. Doesn't really make sense. In our case we won't have those // anyway since for dynamic quant activations won't be observed We can still +======= + // quantizable. Doesnt really make sense. In our case we wont have those + // anyway since for dynamic quant activations wont be observed We can still +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // use this function because the above two methods should really be a noop h.propagateQuantizationOps(module); return module; diff --git a/torch/csrc/jit/passes/quantization/quantization_patterns.h b/torch/csrc/jit/passes/quantization/quantization_patterns.h index 86d7b5857c49c..547ffc5bd2ae9 100644 --- a/torch/csrc/jit/passes/quantization/quantization_patterns.h +++ b/torch/csrc/jit/passes/quantization/quantization_patterns.h @@ -206,7 +206,11 @@ QuantFusionInfo getFixedQParamOpFusionInfo( %r = )"; op_pattern += op_name + "(" + "%a_dequant" + extra_op_arg_list + ")"; // IR pattern common to all ops with fixed quantization parameters for +<<<<<<< HEAD // asymmetric quantization +======= + // asymetric quantization +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::string asym_fixed_qparam_op_suffix = R"( %r_scale : float = prim::Constant[value=0.00390625]() %r_zero_point : int = prim::Constant[value=0]() diff --git a/torch/csrc/jit/passes/symbolic_shape_cache.h b/torch/csrc/jit/passes/symbolic_shape_cache.h index 4d0f1bdcd6298..17c04603dbfd2 100644 --- a/torch/csrc/jit/passes/symbolic_shape_cache.h +++ b/torch/csrc/jit/passes/symbolic_shape_cache.h @@ -8,7 +8,11 @@ namespace torch::jit { struct TORCH_API CanonicalizedSymbolicShape { // TODO: Consider in the future if it is reasonable to // merge code with SymbolicShape or VaryingShape while keeping +<<<<<<< HEAD // the two not implicitly convertible (and cause bugs). +======= + // the two not implicitly convertable (and cause bugs). +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CanonicalizedSymbolicShape( const c10::SymbolicShape& orig_shape, std::unordered_map& ss_map) { diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp index bb052fc8421ff..0e7de799a6e5a 100644 --- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp +++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp @@ -196,6 +196,7 @@ static void removeProfileNodesAndSpecializeTypes(Block* b) { if (it->input()->type()->kind() == c10::TypeKind::TensorType) { input_tensor_type = it->input()->type()->expect(); } else { +<<<<<<< HEAD auto element_type = it->input() ->type(); if (element_type->cast()) { @@ -210,6 +211,13 @@ static void removeProfileNodesAndSpecializeTypes(Block* b) { element_type->expect(); } +======= + input_tensor_type = it->input() + ->type() + ->expectRef() + .getElementType() + ->expect(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) input_is_optional = true; } @@ -405,7 +413,11 @@ void insertTypeGuard( namespace { bool has_unsupported_pin_memory(const Node* node) { +<<<<<<< HEAD // can't support non-constant pin_memory or pin_memory = True +======= + // cant support non-constant pin_memory or pin_memory = True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (auto maybe_index = node->schema().argumentIndexWithName("pin_memory")) { int index = *maybe_index; auto inp = node->input(index); diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.h b/torch/csrc/jit/passes/tensorexpr_fuser.h index c9007c82b95e5..16ee0e7bf00ed 100644 --- a/torch/csrc/jit/passes/tensorexpr_fuser.h +++ b/torch/csrc/jit/passes/tensorexpr_fuser.h @@ -66,7 +66,11 @@ TORCH_API bool isSupported(Node* node); /// work with dynamic shapes unless explicitly register the shape function via /// `torch::jit::RegisterShapeComputeGraphForSchema` for the custom operator. /// +<<<<<<< HEAD /// @return Reference of the custom operator set +======= +/// @return Reference of the custome operator set +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) /// TORCH_API OperatorSet& getCustomOperatorSet(); diff --git a/torch/csrc/jit/passes/utils/subgraph_utils.cpp b/torch/csrc/jit/passes/utils/subgraph_utils.cpp index f9fd65f9ce541..3e7fa1e2283bf 100644 --- a/torch/csrc/jit/passes/utils/subgraph_utils.cpp +++ b/torch/csrc/jit/passes/utils/subgraph_utils.cpp @@ -62,7 +62,11 @@ struct ValueMapper { auto new_outputs = merged_node->outputs(); for (Value* v : new_outputs) { auto maybe_last_use = firstOrLastUse(v, /*find_first*/ false); +<<<<<<< HEAD // if it doesn't have a use it shouldn't have been added as output +======= + // if it doesnt have a use it shouldnt have been added as output +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_INTERNAL_ASSERT(maybe_last_use); const Use last_use = *maybe_last_use; diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 1cc439aa65b20..023013559a465 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -1726,7 +1726,11 @@ void initJITBindings(PyObject* module) { const py::args& args, const py::kwargs& kwargs) { ToIValueAllowNumbersAsTensors g(allow_numbers_as_tensors); return _get_operation_for_overload_or_packet( +<<<<<<< HEAD op, symbol, args, kwargs, /*is_overload*/ true); +======= + {op}, symbol, args, kwargs, /*is_overload*/ true); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }); auto func_dk = py::cpp_function([op, symbol, allow_numbers_as_tensors]( @@ -1735,7 +1739,11 @@ void initJITBindings(PyObject* module) { const py::kwargs& kwargs) { ToIValueAllowNumbersAsTensors g(allow_numbers_as_tensors); return _get_operation_for_overload_or_packet( +<<<<<<< HEAD op, symbol, args, kwargs, /*is_overload*/ true, dk_); +======= + {op}, symbol, args, kwargs, /*is_overload*/ true, dk_); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }); return py::make_tuple( func, func_dk, py::cast(op->getTags().vec())); @@ -1958,6 +1966,7 @@ void initJITBindings(PyObject* module) { std::vector, bool, bool>()) +<<<<<<< HEAD .def_property_readonly("name", &FunctionSchema::name) .def_property_readonly("overload_name", &FunctionSchema::overload_name) .def_property_readonly("arguments", &FunctionSchema::arguments) @@ -1977,6 +1986,19 @@ void initJITBindings(PyObject* module) { // FunctionSchema::isBackwardCompatibleWith has an extra // defaulted argument, so we can't just use a // pointer-to-member here. +======= + .def_property_readonly( + "name", [](FunctionSchema& self) { return self.name(); }) + .def_property_readonly( + "overload_name", + [](FunctionSchema& self) { return self.overload_name(); }) + .def_property_readonly( + "arguments", [](FunctionSchema& self) { return self.arguments(); }) + .def_property_readonly( + "returns", [](FunctionSchema& self) { return self.returns(); }) + .def( + "is_backward_compatible_with", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) [](const FunctionSchema& self, const FunctionSchema& old_schema) { return self.isBackwardCompatibleWith(old_schema); }) @@ -1999,14 +2021,22 @@ void initJITBindings(PyObject* module) { }) .def( "__str__", +<<<<<<< HEAD [](const FunctionSchema& self) { +======= + [](FunctionSchema& self) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::stringstream ss; ss << self; return ss.str(); }) .def( "__repr__", +<<<<<<< HEAD [](const FunctionSchema& self) { +======= + [](FunctionSchema& self) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::stringstream ss; ss << self; return ss.str(); @@ -2020,9 +2050,14 @@ void initJITBindings(PyObject* module) { [](const py::str& schema) { // __setstate__, note: no `self` argument return parseSchema(schema); })) +<<<<<<< HEAD .def_property_readonly("is_mutable", [](const FunctionSchema& self) { return self.is_mutable(); }); +======= + .def_property_readonly( + "is_mutable", [](FunctionSchema& self) { return self.is_mutable(); }); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py::class_(m, "Argument") .def(py::init< std::string, @@ -2031,17 +2066,31 @@ void initJITBindings(PyObject* module) { std::optional, bool, std::optional>()) +<<<<<<< HEAD .def_property_readonly("name", &Argument::name) .def_property_readonly("type", &Argument::type) .def_property_readonly("real_type", &Argument::real_type) .def_property_readonly( "N", [](const Argument& self) -> py::object { +======= + .def_property_readonly("name", [](Argument& self) { return self.name(); }) + .def_property_readonly("type", [](Argument& self) { return self.type(); }) + .def_property_readonly( + "real_type", [](Argument& self) { return self.real_type(); }) + .def_property_readonly( + "N", + [](Argument& self) -> py::object { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return (self.N()) ? py::cast(*self.N()) : py::none(); }) .def_property_readonly( "default_value", +<<<<<<< HEAD [](const Argument& self) -> py::object { +======= + [](Argument& self) -> py::object { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (!self.default_value()) { return py::none(); } @@ -2050,6 +2099,7 @@ void initJITBindings(PyObject* module) { }) .def( "has_default_value", +<<<<<<< HEAD [](const Argument& self) -> py::bool_ { return self.default_value().has_value(); }) @@ -2058,30 +2108,56 @@ void initJITBindings(PyObject* module) { .def_property_readonly( "is_write", [](const Argument& self) { +======= + [](Argument& self) -> py::bool_ { + return self.default_value().has_value(); + }) + .def_property_readonly( + "alias_info", [](Argument& self) { return self.alias_info(); }) + .def_property_readonly( + "is_write", + [](Argument& self) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (self.alias_info() == nullptr) { return false; } return self.alias_info()->isWrite(); }) .def_property_readonly( +<<<<<<< HEAD "is_out", [](const Argument& self) { return self.is_out(); }) .def_property_readonly("kwarg_only", [](const Argument& self) -> bool { +======= + "is_out", [](Argument& self) { return self.is_out(); }) + .def_property_readonly("kwarg_only", [](Argument& self) -> bool { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.kwarg_only(); }); py::class_(m, "_AliasInfo") .def(py::init, std::set>()) .def_property_readonly( +<<<<<<< HEAD "is_write", [](const AliasInfo& self) { return self.isWrite(); }) .def_property_readonly( "before_set", [](const AliasInfo& self) { +======= + "is_write", [](AliasInfo& self) { return self.isWrite(); }) + .def_property_readonly( + "before_set", + [](AliasInfo& self) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::set before_set_python; for (const auto& set : self.beforeSets()) { before_set_python.insert(py::str(set.toUnqualString())); } return before_set_python; }) +<<<<<<< HEAD .def_property_readonly("after_set", [](const AliasInfo& self) { +======= + .def_property_readonly("after_set", [](AliasInfo& self) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::set after_set_python; for (const auto& set : self.afterSets()) { after_set_python.insert(py::str(set.toUnqualString())); @@ -2324,7 +2400,11 @@ void initJITBindings(PyObject* module) { // Throw errors when calling wait() on the returned Future if // any of the original futures would throw. // NB: PythonFutureWrapper takes an unwrap_func which serves as a +<<<<<<< HEAD // callback to evaluate the value in the Future. RPC uses this +======= + // callback to evalute the value in the Future. RPC uses this +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // unwrap_func to check whether the returned py::object is a // RemoteException object, and re-throw the exception if it is. // By extracting the c10::ivalue::Future from PythonFutureWrapper diff --git a/torch/csrc/jit/python/pybind_utils.cpp b/torch/csrc/jit/python/pybind_utils.cpp index a366aa58f822d..826a3dc741815 100644 --- a/torch/csrc/jit/python/pybind_utils.cpp +++ b/torch/csrc/jit/python/pybind_utils.cpp @@ -90,7 +90,11 @@ IValue toIValue(py::handle obj, const TypePtr& type, std::optional N) { if (PyBool_Check(obj.ptr())) { scalar = at::Scalar(THPUtils_unpackBool(obj.ptr())); } else if (THPUtils_checkLong(obj.ptr())) { +<<<<<<< HEAD scalar = THPUtils_unpackInteger(obj.ptr()); +======= + scalar = at::Scalar(THPUtils_unpackLong(obj.ptr())); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else if (PyComplex_Check(obj.ptr())) { scalar = at::Scalar(THPUtils_unpackComplexDouble(obj.ptr())); } else if (THPUtils_checkDouble(obj.ptr())) { @@ -313,7 +317,11 @@ IValue toIValue(py::handle obj, const TypePtr& type, std::optional N) { bool is_symbolic = false; for (auto it = obj.begin(); it != obj.end(); it++) { auto elm = *it; +<<<<<<< HEAD if (torch::is_symint(elm) || THPVariable_Check(elm.ptr())) { +======= + if (torch::is_symint(elm)) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) is_symbolic = true; break; } @@ -512,7 +520,11 @@ IValue toIValue(py::handle obj, const TypePtr& type, std::optional N) { if (py::isinstance(obj)) { return py::cast(obj); } else if (py::isinstance(obj)) { +<<<<<<< HEAD return THPUtils_unpackInteger(obj.ptr()); +======= + return py::cast(obj); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else if (py::isinstance(obj)) { return py::cast(obj); } else if (PyComplex_CheckExact(obj.ptr())) { @@ -598,8 +610,11 @@ py::object toPyObject(IValue ivalue) { return py::cast(*tensor.const_data_ptr()); case at::ScalarType::Long: return py::cast(*tensor.const_data_ptr()); +<<<<<<< HEAD case at::ScalarType::UInt64: return py::cast(*tensor.const_data_ptr()); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) case at::ScalarType::Double: return py::cast(*tensor.const_data_ptr()); case at::ScalarType::ComplexDouble: @@ -765,8 +780,11 @@ py::object toPyObject(IValue ivalue) { return py::cast(std::move(ivalue).toSymFloat()); } else if (ivalue.isSymBool()) { return py::cast(std::move(ivalue).toSymBool()); +<<<<<<< HEAD } else if (ivalue.isUnsigned()) { return py::cast(std::move(ivalue).toUInt()); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else { TORCH_CHECK( false, @@ -780,6 +798,7 @@ std::pair, Stack> getOpWithStack( const std::vector>& operations, const py::args& args, const py::kwargs& kwargs) { +<<<<<<< HEAD return getOpWithStack( c10::ArrayRef>(operations), args, kwargs); } @@ -791,6 +810,11 @@ std::pair, Stack> getOpWithStack( Stack stack; if (operations.size() == 1) { std::shared_ptr op = operations[0]; +======= + Stack stack; + if (operations.size() == 1) { + std::shared_ptr op = operations.at(0); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Create a stack full of the arguments and keyword arguments. stack = createStackForSchema(op->schema(), args, kwargs, std::nullopt); @@ -821,7 +845,11 @@ std::pair, Stack> getOpWithStack( } // This function is used to check if the schema is valid for the given args and +<<<<<<< HEAD // kwargs. It checks script object by checking whether the FakeScriptObject is +======= +// kwargs. It checks script object by checking wether the FakeScriptObject is +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // an instance of the corresponding fake class for the actual class used in // schema. bool checkSchemaAllowFakeScriptObject( @@ -842,6 +870,7 @@ py::object invokeOperatorFromPython( const py::args& args, const py::kwargs& kwargs, std::optional dk) { +<<<<<<< HEAD return invokeOperatorFromPython( c10::ArrayRef>(operations), args, kwargs, dk); } @@ -851,6 +880,8 @@ py::object invokeOperatorFromPython( const py::args& args, const py::kwargs& kwargs, std::optional dk) { +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto [found_op, stack] = getOpWithStack(operations, args, kwargs); { pybind11::gil_scoped_release no_gil_guard; @@ -872,9 +903,14 @@ std::optional _maybe_handle_torch_function( const py::args& args, const py::kwargs& kwargs) { std::vector overloaded_args; +<<<<<<< HEAD const auto args_size = args.size(); size_t total_arg_num = args_size + kwargs.size(); for (const auto i : c10::irange(args_size)) { +======= + size_t total_arg_num = args.size() + kwargs.size(); + for (const auto i : c10::irange(args.size())) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) is_tensor_and_append_overloaded(args[i].ptr(), &overloaded_args); is_tensor_list_and_append_overloaded( args[i].ptr(), @@ -929,6 +965,7 @@ py::object _get_operation_for_overload_or_packet( const py::kwargs& kwargs, bool is_overload, std::optional dk) { +<<<<<<< HEAD return _get_operation_for_overload_or_packet( c10::ArrayRef(operations), symbol, args, kwargs, is_overload, dk); } @@ -940,6 +977,8 @@ py::object _get_operation_for_overload_or_packet( const py::kwargs& kwargs, bool is_overload, std::optional dk) { +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::string ns = symbol.ns().toUnqualString(); std::string method_name = symbol.toUnqualString(); std::string overload_name = operations[0]->schema().overload_name(); diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h index 5ae84e3e0c68b..d106a070d0228 100644 --- a/torch/csrc/jit/python/pybind_utils.h +++ b/torch/csrc/jit/python/pybind_utils.h @@ -649,7 +649,11 @@ inline InferredType tryToInferContainerType( ".")); } else { // TODO: this message is not correct anymore, since this InferredType is +<<<<<<< HEAD // used from a bunch of circumstances unrelated to tracing. We can reuse +======= + // used from a bunch of circumstances unrelated to tracing. We can re-use +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // this instead of the attribute_failure stuff in concreteType return InferredType(c10::str( "Only tensors and (possibly nested) tuples of tensors, lists, or dicts ", @@ -1277,6 +1281,7 @@ TORCH_PYTHON_API std::pair, Stack> getOpWithStack( const py::args& args, const py::kwargs& kwargs); +<<<<<<< HEAD // Efficient overload (does not require vector allocation) of the // above for use from C++ code. std::pair, Stack> getOpWithStack( @@ -1284,12 +1289,15 @@ std::pair, Stack> getOpWithStack( const py::args& args, const py::kwargs& kwargs); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_PYTHON_API py::object invokeOperatorFromPython( const std::vector>& operations, const py::args& args, const py::kwargs& kwargs, std::optional dk = std::nullopt); +<<<<<<< HEAD // Efficient overload (does not require vector allocation) of the // above for use from C++ code. py::object invokeOperatorFromPython( @@ -1298,6 +1306,8 @@ py::object invokeOperatorFromPython( const py::kwargs& kwargs, std::optional dk = std::nullopt); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_PYTHON_API std::optional _maybe_handle_torch_function( const std::string& ns, const std::string& method_name, @@ -1319,6 +1329,7 @@ TORCH_PYTHON_API py::object _get_operation_for_overload_or_packet( bool is_overload, std::optional dk = std::nullopt); +<<<<<<< HEAD // Efficient overload (does not require vector allocation) of the // above for use from C++ code. py::object _get_operation_for_overload_or_packet( @@ -1329,4 +1340,6 @@ py::object _get_operation_for_overload_or_packet( bool is_overload, std::optional dk = std::nullopt); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace torch::jit diff --git a/torch/csrc/jit/python/python_ivalue.h b/torch/csrc/jit/python/python_ivalue.h index 73297c3ac0794..96b235f9307a6 100644 --- a/torch/csrc/jit/python/python_ivalue.h +++ b/torch/csrc/jit/python/python_ivalue.h @@ -99,7 +99,11 @@ struct C10_EXPORT ConcretePyObjectHolder final : PyObjectHolder { py_obj_.ptr() = nullptr; } +<<<<<<< HEAD // explicit construction to avoid erroneous implicit conversion and +======= + // explicit construction to avoid errornous implicit conversion and +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // copy-initialization explicit ConcretePyObjectHolder(py::object py_obj) : py_obj_(std::move(py_obj)) {} diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp index 8b16e089aa50e..390354f8a192b 100644 --- a/torch/csrc/jit/python/python_sugared_value.cpp +++ b/torch/csrc/jit/python/python_sugared_value.cpp @@ -1222,10 +1222,15 @@ std::shared_ptr toSugaredValue( } else if ( obj.ptr() == py::module::import("torch.jit").attr("isinstance").ptr()) { return SpecialFormValue::create(prim::isinstance); +<<<<<<< HEAD } else if (obj.ptr() == py::module::import("torch").attr("_check").ptr()) { return std::make_shared(); #ifdef USE_RPC // RPC module is only available when build flag "USE_DISTRIBUTED" is on. +======= +#ifdef USE_RPC + // RPC module is only avaialble when build flag "USE_DISTRIBUTED" is on. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else if ( isRpcAvailable && obj.ptr() == @@ -1238,7 +1243,11 @@ std::shared_ptr toSugaredValue( return SpecialFormValue::create(prim::rpc_sync); } else if ( isRpcAvailable && +<<<<<<< HEAD // RPC module is only available when build flag "USE_DISTRIBUTED" is on. +======= + // RPC module is only avaialble when build flag "USE_DISTRIBUTED" is on. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) obj.ptr() == py::module::import("torch.distributed.rpc").attr("remote").ptr()) { return SpecialFormValue::create(prim::rpc_remote); diff --git a/torch/csrc/jit/python/python_sugared_value.h b/torch/csrc/jit/python/python_sugared_value.h index c00eefa20df03..227af98a50c6b 100644 --- a/torch/csrc/jit/python/python_sugared_value.h +++ b/torch/csrc/jit/python/python_sugared_value.h @@ -68,7 +68,11 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue { ErrorReport(loc) << kind() << " cannot be used as a value. " << "Perhaps it is a closed over global variable? If so, please " +<<<<<<< HEAD << "consider passing it in as an argument or use a local variable " +======= + << "consider passing it in as an argument or use a local varible " +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) << "instead."); } diff --git a/torch/csrc/jit/python/python_tracer.cpp b/torch/csrc/jit/python/python_tracer.cpp index 81da1605fcbe2..80746e4ad18e4 100644 --- a/torch/csrc/jit/python/python_tracer.cpp +++ b/torch/csrc/jit/python/python_tracer.cpp @@ -89,7 +89,11 @@ std::pair, Stack> createGraphByTracingWithDict( }; // The argument_names parameter is parsed in python and its order +<<<<<<< HEAD // is the same as the arguments' declaration order in forward() method. +======= + // is the same as the arguments' decalaration order in forward() method. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // These name shall be added to the graph as debug name and the order // should align with the traceable stack we generated by the python dict. std::vector compact_argument_names; diff --git a/torch/csrc/jit/runtime/interpreter.cpp b/torch/csrc/jit/runtime/interpreter.cpp index 9d4d681f8b32f..4dedcefd0b170 100644 --- a/torch/csrc/jit/runtime/interpreter.cpp +++ b/torch/csrc/jit/runtime/interpreter.cpp @@ -55,7 +55,11 @@ C10_DEFINE_bool( C10_DEFINE_bool( torch_jit_enable_expanded_stacks, false, +<<<<<<< HEAD "When true we will attempts to pre-expand node stacks and cache expanded stacks.") +======= + "When true we will attemps to pre-expand node stacks and cache expanded stacks.") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) C10_DEFINE_bool( torch_jit_expanded_stacks_mangled, diff --git a/torch/csrc/jit/runtime/jit_exception.h b/torch/csrc/jit/runtime/jit_exception.h index 580febe465ff2..27c957e092c2a 100644 --- a/torch/csrc/jit/runtime/jit_exception.h +++ b/torch/csrc/jit/runtime/jit_exception.h @@ -18,7 +18,11 @@ struct TORCH_API JITException : public std::runtime_error { return python_class_name_; } +<<<<<<< HEAD // the original msg if this is from a python exception. The interpreter has +======= + // the original msg if this is from a python exception. The interpretor has +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // changed the original message by adding "The following operation failed in // the TorchScript interpreter." in front of it in the handleError function. std::optional getOriginalMsg() const { diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp index d59b93190e36a..b6f1f8489eb50 100644 --- a/torch/csrc/jit/runtime/register_prim_ops.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops.cpp @@ -115,8 +115,13 @@ bool isSortableListOfObjectsOrTuples( } auto type = ivalues.get(0).type(); +<<<<<<< HEAD // We assume lists have homogeneous types, use first element to determine // best sorting methods. If in the future we need to support heterogeneous +======= + // We assume lists have homogenous types, use first element to determine + // best sorting methods. If in the future we need to support heterogenous +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // types inside list, then sorting needs to have runtime sortable checks. const size_t n = ivalues.size(); for (const auto i : c10::irange(n)) { @@ -1141,7 +1146,11 @@ static const std::vector opGenArgs{ // // create a clone of these declarations with a _hacked_twin overload name // and nullability scrubbed from TensorList arg types +<<<<<<< HEAD // TODO find out why this exists and how to do it without the hack +======= + // TOOD find out why this exists and how to do it without the hack +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA( @@ -2839,7 +2848,11 @@ void hashValue(Stack& stack) { } static const std::vector opGenArgs2{ +<<<<<<< HEAD // registered as Any[] so that heterogeneous tuples can be called with len() +======= + // registered as Any[] so that heterogenous tuples can be called with len() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::len.any(Any[] a) -> int"), listLen, diff --git a/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp b/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp index d77e0b3a10d64..f87f2411054f7 100644 --- a/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp +++ b/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp @@ -3204,7 +3204,11 @@ def _batch_norm_with_update(input: List[int], )=====") + std::string(R"=====(def broadcast_inplace(a: List[int], b: List[int]) -> List[int]: +<<<<<<< HEAD _0 = "The dims of tensor b ({}) must be less than or equal to the dims of tensor a ({}) " +======= + _0 = "The dims of tensor b ({}) must be less than or equal tothe dims of tensor a ({}) " +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _1 = "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}" dimsA = torch.len(a) dimsB = torch.len(b) diff --git a/torch/csrc/jit/runtime/static/README.md b/torch/csrc/jit/runtime/static/README.md index ba5e057ca1ec8..3b44a9ea0b24d 100644 --- a/torch/csrc/jit/runtime/static/README.md +++ b/torch/csrc/jit/runtime/static/README.md @@ -71,7 +71,11 @@ Runtime instances in your code. Static runtime's memory planner does two things: 1) Coalesces internal allocations for tensor storage +<<<<<<< HEAD 2) Does static analysis to figure out how to efficiently reuse memory. +======= +2) Does static analysis to figure out how to efficiently re-use memory. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ### Standard Resizing Static runtime will record the space required for each intermediate managed tensor it sees diff --git a/torch/csrc/jit/runtime/static/impl.h b/torch/csrc/jit/runtime/static/impl.h index b25f63c939b04..ce3d50c04fbfd 100644 --- a/torch/csrc/jit/runtime/static/impl.h +++ b/torch/csrc/jit/runtime/static/impl.h @@ -29,7 +29,11 @@ TORCH_API std::string dumpValueSet( const c10::FastSet& value_set, const char* set_name = ""); +<<<<<<< HEAD inline bool doesNotHeapAllocateWhenStoredInIValue(const Type& type) { +======= +TORCH_API inline bool doesNotHeapAllocateWhenStoredInIValue(const Type& type) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) switch (type.kind()) { // NOTE: NumberType may allocate because it includes complex. case TypeKind::NoneType: @@ -44,11 +48,19 @@ inline bool doesNotHeapAllocateWhenStoredInIValue(const Type& type) { } } +<<<<<<< HEAD inline c10::Symbol getStaticRuntimeMetadataSymbol() { return Symbol::attr("static_runtime::metadata"); } inline bool borrowsOutputs(c10::Symbol kind) { +======= +TORCH_API inline c10::Symbol getStaticRuntimeMetadataSymbol() { + return Symbol::attr("static_runtime::metadata"); +} + +TORCH_API inline bool borrowsOutputs(c10::Symbol kind) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static const std::array symbols_with_borrowed_outputs = { c10::Symbol::fromQualString("static_runtime::select_tensor"), c10::Symbol::fromQualString("static_runtime::dict_unpack"), @@ -70,7 +82,11 @@ inline bool borrowsOutputs(c10::Symbol kind) { // The output aliases that end up here are as a result of aliasDb failing to // recognize them as outputs due to collection object (e.g., Tuple) aliasing // inputs. +<<<<<<< HEAD // Values that don't show up in output_aliases or external_aliases are created +======= +// Values that dont't show up in output_aliases or external_aliases are created +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // and consumed within the graph. class ValueGroup { public: @@ -111,7 +127,11 @@ class TORCH_API ManagedTensorRanges { // If true, then this node is the last use of at least one // managed tensor. availableTensorValuesAfterNode(node) will return a vector +<<<<<<< HEAD // of the managed tensors that are available for reuse +======= + // of the managed tensors that are available for re-use +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // in the nodes following this one. bool nodeFreesManagedTensors(Node* node) const; const std::vector& availableTensorValuesAfterNode( @@ -141,7 +161,11 @@ class TORCH_API ManagedTensorRanges { void extendInputLifetime(Node* node, size_t new_end); // Maps Node* to the set of managed tensors that are now available +<<<<<<< HEAD // for reuse after this node. +======= + // for re-use after this node. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) c10::FastMap> node_to_newly_free_tensors_{}; // Maps each Value* to its lifetime (start node index, end node index) c10::FastMap value_lifetimes_{}; diff --git a/torch/csrc/jit/runtime/static/memory_planner.cpp b/torch/csrc/jit/runtime/static/memory_planner.cpp index 8660183867e08..5b642b12786f8 100644 --- a/torch/csrc/jit/runtime/static/memory_planner.cpp +++ b/torch/csrc/jit/runtime/static/memory_planner.cpp @@ -76,7 +76,11 @@ std::vector assignStorageToManagedTensors( // This set maps each Value* to its assigned storage group. c10::FastMap storage_group_mapping; // On each iteration, this vector stores the set of storage groups that +<<<<<<< HEAD // are available for reuse. +======= + // are available for re-use. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::vector free_storage_groups; auto makeNewStorageGroup = [&](const Value* value) { diff --git a/torch/csrc/jit/runtime/static/native_ops.cpp b/torch/csrc/jit/runtime/static/native_ops.cpp index 716202f45687a..d8cf025897999 100644 --- a/torch/csrc/jit/runtime/static/native_ops.cpp +++ b/torch/csrc/jit/runtime/static/native_ops.cpp @@ -529,7 +529,11 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::to, aten_to, [](Node* n) -> SROperator { const auto in1_i = p_node->Input(1).toOptional(); const auto in2_i = p_node->Input(2).toBool(); const auto in3_i = p_node->Input(3).toBool(); +<<<<<<< HEAD // To mimic the behavior of the JIT interpreter, if both dtype +======= + // To mimick the behavior of the JIT interpreter, if both dtype +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // and copy are not set, we return self. Otherwise, we assume // that dtype is set. if (!in1_i && !in3_i) { diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp index 9e408682ca6c3..ffc36f8287267 100644 --- a/torch/csrc/jit/runtime/static/ops.cpp +++ b/torch/csrc/jit/runtime/static/ops.cpp @@ -1910,7 +1910,11 @@ REGISTER_OPERATOR_FUNCTOR(aten::div, aten_div, [](Node* n) -> SROperator { } auto& out_t = p_node->Output(0).toTensor(); +<<<<<<< HEAD if (te && te->checkInput(in0_t) && in0_t.sizes() == in1_t.sizes() && +======= + if (in0_t.sizes() == in1_t.sizes() && +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) in0_t.scalar_type() == in1_t.scalar_type() && in0_t.strides() == in1_t.strides() && in0_t.is_contiguous() && in0_t.scalar_type() == at::kFloat) { diff --git a/torch/csrc/jit/serialization/export.cpp b/torch/csrc/jit/serialization/export.cpp index 6184889e5f10e..94385d871dece 100644 --- a/torch/csrc/jit/serialization/export.cpp +++ b/torch/csrc/jit/serialization/export.cpp @@ -17,7 +17,10 @@ #include #include #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include diff --git a/torch/csrc/jit/serialization/export.h b/torch/csrc/jit/serialization/export.h index b8746d0722412..c46103dfd0373 100644 --- a/torch/csrc/jit/serialization/export.h +++ b/torch/csrc/jit/serialization/export.h @@ -5,6 +5,10 @@ #include #include #include +<<<<<<< HEAD +======= +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include @@ -214,7 +218,11 @@ struct TORCH_API BytecodeEmitMode { // true: instruction of default argument values (like LOADC) is emitted. // false: instruction of default argument values are not emitted. Instead // they are fetched from operator schema. +<<<<<<< HEAD // default_args_before_out_args (to forward compatible support +======= +// default_args_before_out_args (to forward compatibile support +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // operators allowing out arguments and default arguments): // true: the number of specified arguments will deserialized to (#all_args - // #default_args). false: the number of specified arguments will deserialized to diff --git a/torch/csrc/jit/serialization/export_module.cpp b/torch/csrc/jit/serialization/export_module.cpp index e0ded27d375b1..bc722d1ea1c56 100644 --- a/torch/csrc/jit/serialization/export_module.cpp +++ b/torch/csrc/jit/serialization/export_module.cpp @@ -131,7 +131,11 @@ std::string get_named_tuple_str_or_default( // str() return "Tensor" and repr_str() return "Tensor (inferred)". If // it's not inferred type, str() return "Tensor[]" and repr_str() // return "Tensor". In cpp, repr_str() will always return "Tensor" +<<<<<<< HEAD // regardless inferred type. When exporting custom type in bytecode, +======= + // regardless inferred type. When exporing custom type in bytecode, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // "Tensor" is the preferred way to deserialize Tensor type std::string named_tuple_type_str = it->is_inferred_type() ? named_tuple_type->str() @@ -554,7 +558,11 @@ void ScriptModuleSerializer::writeArchive( } WriteableTensorData writable_td = getWriteableTensorData(td); if (use_storage_context && serialized_tensors.count(tensor_name)) { +<<<<<<< HEAD // storage has been serialized already, skip +======= + // storage has been serialzed already, skip +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue; } writer_.writeRecord( @@ -698,10 +706,17 @@ void ScriptModuleSerializer::writeByteCode( // debug handles. // The reason we save debug handles conditionally is so that // we dont end up with a model that has debug handles but has not +<<<<<<< HEAD // debug map to correlate debug handles with. // Once we have a model with both handles and debug map, we can // strip off debug map and have a lean model served to production. // If exception occurs we have a model with debug map that can be +======= + // debug map to correlate debug handels with. + // Once we have a model with both handles and debug map, we can + // strip off debug map and have a lean model served to production. + // If exception ocurrs we have a model with debug map that can be +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // used to symbolicate debug handles writeArchive( debug_info_telements, diff --git a/torch/csrc/jit/serialization/pickler.cpp b/torch/csrc/jit/serialization/pickler.cpp index 2dc3f138ff76d..59c0d8c2651a4 100644 --- a/torch/csrc/jit/serialization/pickler.cpp +++ b/torch/csrc/jit/serialization/pickler.cpp @@ -1,3 +1,4 @@ +<<<<<<< HEAD #include #include @@ -5,16 +6,34 @@ #include #include +======= +#include +#include +#ifdef USE_RPC +#include +#endif +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include #include +<<<<<<< HEAD #ifdef USE_RPC #include #endif namespace torch::jit { +======= +#include +#include + +namespace torch::jit { + +using ::c10::IValue; + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Protocol 2 is the highest that can be decoded by Python 2 // See https://docs.python.org/3/library/pickle.html#data-stream-format constexpr static uint8_t PROTOCOL_VERSION = 2; @@ -719,4 +738,95 @@ void Pickler::pushTuple(const IValue& ivalue) { } } +<<<<<<< HEAD +======= +WriteableTensorData getWriteableTensorData( + const at::Tensor& tensor, + bool to_cpu) { + WriteableTensorData result; + result.tensor_ = tensor; + result.size_ = tensor.storage().nbytes(); + // TODO HIP support + if (tensor.storage().device_type() != DeviceType::CPU && to_cpu) { + // NB: This new tensor is created to support cuda tensors. + // Storages can be mutated when converting tensors from cuda to cpu, + // and we need a cpu tensor to copy data from. + result.tensor_ = + at::empty({0}, tensor.options()) + .set_( + tensor.storage(), + /* storage_offset = */ 0, + /* size = */ + {static_cast( + tensor.storage().nbytes() / tensor.element_size())}, + /* stride = */ {1}) + .cpu(); + TORCH_CHECK( + result.tensor_.storage().nbytes() == result.size_, + "Storage tensor size did not match record size"); + } + return result; +} + +bool checkHasValidSetGetState(const std::shared_ptr& cls) { + // Check that the schemas for __getstate__ and __setstate__ are correct + auto getstate = cls->findMethod("__getstate__"); + if (getstate == nullptr) { + return false; + } + auto get_schema = getstate->getSchema(); + + // Check __getstate__ + // __getstate__ is expected to be (self) -> T + TORCH_CHECK( + get_schema.arguments().size() == 1, + "'__getstate__' must have 'self' as its only argument, but found ", + get_schema.arguments().size(), + " arguments"); + TORCH_CHECK( + get_schema.returns().size() == 1, + "'__getstate__' must return 1 value, but found ", + get_schema.returns().size()); + + // Check __setstate__ if the method exists + // __setstate__ is expected to be (self, T) -> None + auto setstate = cls->findMethod("__setstate__"); + if (!setstate) { + return false; + } + auto set_schema = setstate->getSchema(); + + TORCH_CHECK( + set_schema.arguments().size() == 2, + "'__setstate__' must have 'self' and the state as its " + "only arguments, but found ", + set_schema.arguments().size(), + " arguments"); + TORCH_CHECK( + set_schema.returns().size() == 1, + "'__setstate__' must return None, but found ", + set_schema.returns().size(), + " return values"); + TORCH_CHECK( + set_schema.returns().at(0).type()->isSubtypeOf(*NoneType::get()), + "'__setstate__' must return None, but found value of type", + set_schema.returns().at(0).type()->annotation_str()); + + // Check that the return type of __getstate__ matches the input to + // __setstate__ + auto get_type = get_schema.returns().at(0).type(); + auto set_type = set_schema.arguments().at(1).type(); + + TORCH_CHECK( + get_type->isSubtypeOf(*set_type), + "'__getstate__'s return type (", + get_type->annotation_str(), + ") does not match '__setstate__'s argument type (", + set_type->annotation_str(), + ")"); + + return true; +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace torch::jit diff --git a/torch/csrc/jit/serialization/pickler.h b/torch/csrc/jit/serialization/pickler.h index 526c840bc10e8..b60c350ca0edf 100644 --- a/torch/csrc/jit/serialization/pickler.h +++ b/torch/csrc/jit/serialization/pickler.h @@ -1,5 +1,9 @@ #pragma once +<<<<<<< HEAD +======= +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include @@ -8,17 +12,125 @@ #include #include #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include #include +<<<<<<< HEAD #include namespace torch::jit { using ::c10::IValue; +======= + +namespace torch::jit { + +// See Python's pickletools.py for a detailed description of each of these codes +enum class PickleOpCode : char { + MARK = '(', + STOP = '.', + POP = '0', + POP_MARK = '1', + DUP = '2', + FLOAT = 'F', + INT = 'I', + BININT = 'J', + BININT1 = 'K', + LONG = 'L', + BININT2 = 'M', + NONE = 'N', + PERSID = 'P', + BINPERSID = 'Q', + REDUCE = 'R', + STRING = 'S', + BINSTRING = 'T', + SHORT_BINSTRING = 'U', + // NB: Avoid using UNICODE as it is a macro in the Windows API + UNICODE_ = 'V', + BINUNICODE = 'X', + APPEND = 'a', + BUILD = 'b', + GLOBAL = 'c', + DICT = 'd', + EMPTY_DICT = '}', + APPENDS = 'e', + GET = 'g', + BINGET = 'h', + INST = 'i', + LONG_BINGET = 'j', + LIST = 'l', + EMPTY_LIST = ']', + OBJ = 'o', + PUT = 'p', + BINPUT = 'q', + LONG_BINPUT = 'r', + SETITEM = 's', + TUPLE = 't', + EMPTY_TUPLE = ')', + SETITEMS = 'u', + BINFLOAT = 'G', + + // Protocol 2 + PROTO = char('\x80'), + NEWOBJ = '\x81', + EXT1 = '\x82', + EXT2 = '\x83', + EXT4 = '\x84', + TUPLE1 = '\x85', + TUPLE2 = '\x86', + TUPLE3 = '\x87', + NEWTRUE = '\x88', + NEWFALSE = '\x89', + LONG1 = '\x8a', + LONG4 = '\x8b', + + // Protocol 3 (Python 3.x) + BINBYTES = 'B', + SHORT_BINBYTES = 'C', + + // Protocol 4 + SHORT_BINUNICODE = char('\x8c'), + BINUNICODE8 = '\x8d', + BINBYTES8 = '\x8e', + EMPTY_SET = '\x8f', + ADDITEMS = '\x90', + FROZENSET = '\x91', + NEWOBJ_EX = '\x92', + STACK_GLOBAL = '\x93', + MEMOIZE = '\x94', + FRAME = '\x95' +}; + +using ::c10::IValue; + +struct WriteableTensorData { + const char* data() const { + return static_cast(tensor_.storage().data()); + } + size_t sizeInBytes() const { + return size_; + } + size_t nbytes() const { + return tensor_.storage().nbytes(); + } + bool storageHasDeleter() const { + return tensor_.storage().data_ptr().get_context() != nullptr; + } + + private: + friend TORCH_API WriteableTensorData + getWriteableTensorData(const at::Tensor& tensor, bool to_cpu); + at::Tensor tensor_; + uint64_t size_; +}; + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TORCH_API Pickler { AT_DISALLOW_COPY_AND_ASSIGN(Pickler); @@ -182,4 +294,145 @@ class TORCH_API Pickler { bool tag_aggregates_; }; +<<<<<<< HEAD +======= +// returns a (tensor, record_size) for a tensor, converting it to a CPU tensor +// if it was CUDA and to_cpu is True. +TORCH_API WriteableTensorData +getWriteableTensorData(const at::Tensor& tensor, bool to_cpu = true); + +// if the cls has __getstate__/__setstate__ +// assert they have the right schema and return true, +// otherwise return false +bool checkHasValidSetGetState(const std::shared_ptr& cls); + +// Declare BackendMeta serialization and deserialization function pointer types. +using BackendMetaPtr = std::function< + void(const at::Tensor&, std::unordered_map&)>; + +// A allowlist of device type, currently available is PrivateUse1 +inline std::unordered_set& GetBackendMetaAllowlist() { + static std::unordered_set DeviceTypeAllowlist{ + c10::DeviceType::PrivateUse1}; + return DeviceTypeAllowlist; +} + +// Dynamically obtain serialization function pairs +// that require the corresponding backend. +inline std::array< + std::optional>, + at::COMPILE_TIME_MAX_DEVICE_TYPES>& +GetBackendMetaSerialization() { + // The array to save function pointer for BackendMeta serialization. + // key is the DeviceType, value is std::pair obj. + // value.first represent get function and value.seconde represent set function + static std::array< + std::optional>, + at::COMPILE_TIME_MAX_DEVICE_TYPES> + BackendMetaSerialization; + return BackendMetaSerialization; +} + +// Register function pointer of Tensor BackendMetadata for serialization. +TORCH_API inline void TensorBackendMetaRegistry( + c10::DeviceType t, + const BackendMetaPtr& get_fptr, + const BackendMetaPtr& set_fptr) { + // allowlist verification + // Only if the devicetype is in the allowlist, + // we allow the serialization extension to be registered for backendmeta data. + const auto& DeviceTypeAllowlist = GetBackendMetaAllowlist(); + TORCH_CHECK( + DeviceTypeAllowlist.find(t) != DeviceTypeAllowlist.end(), + "It is not allowed to register the serialization method ", + "of backendMeta data for PrivateUse1. ", + "If you have related serialization requirements, ", + "please expand the allowlist"); + // Register function pointer + int device_type = static_cast(t); + auto& BackendMetaSerialization = GetBackendMetaSerialization(); + TORCH_CHECK( + !BackendMetaSerialization[device_type].has_value(), + "The tensor BackendMeta serialization function pointer for ", + t, + " has been registered."); + BackendMetaSerialization[device_type] = + std::optional>( + std::make_pair(get_fptr, set_fptr)); +} + +// Return a map of Tensor Metadata which including BackendMetaData for +// serialization. For now, it only takes care of `conj` and `neg` bit. +inline std::unordered_map getTensorMetadata( + const at::Tensor& t) { + // We don't support serializing `ZeroTensor` as it is not public + // facing yet. + TORCH_CHECK( + !t._is_zerotensor(), + "ZeroTensor is not serializable,", + " please file an issue if required."); + std::unordered_map metadata{}; + + // Only add meta-data if the value is not default. + if (t.is_conj()) { + metadata["conj"] = true; + } + if (t.is_neg()) { + metadata["neg"] = true; + } + // Only add BackendMetaData for custom backend if the function pointer is + // registered. + int device_type = static_cast(t.device().type()); + const auto& BackendMetaSerialization = GetBackendMetaSerialization(); + if (BackendMetaSerialization[device_type].has_value()) { + // Pass the tensor and metadata map references as parameters to the custom + // serialization function. + BackendMetaPtr fptr = BackendMetaSerialization[device_type].value().first; + fptr(t, metadata); + } + return metadata; +} + +// set Tensor Metadata based on the map. +// Refer: getTensorMetadata +inline void setTensorMetadata( + const at::Tensor& t, + std::unordered_map metadata) { + auto iter_end = metadata.end(); + auto iter_temp = metadata.find("conj"); + if (iter_temp != iter_end) { + t._set_conj(true); + metadata.erase(iter_temp); + } + iter_temp = metadata.find("neg"); + if (iter_temp != iter_end) { + t._set_neg(true); + metadata.erase(iter_temp); + } + // Only set BackendMetaData for custom backend if the function pointer is + // registered. + int device_type = static_cast(t.device().type()); + const auto& BackendMetaSerialization = GetBackendMetaSerialization(); + if (BackendMetaSerialization[device_type].has_value()) { + // Pass the tensor and metadata map references as parameters to the custom + // deserialization function. + BackendMetaPtr fptr = BackendMetaSerialization[device_type].value().second; + fptr(t, metadata); + } +} + +// set Tensor metadata based on the map. +// NOTE: This overload is required by unpickler.cpp +inline void setTensorMetadata( + const at::Tensor& t, + const c10::Dict& metadata_idict) { + std::unordered_map metadata; + for (auto& pair : metadata_idict) { + auto key = *pair.key().toString(); + metadata[key] = pair.value().toBool(); + } + setTensorMetadata(t, std::move(metadata)); +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace torch::jit diff --git a/torch/csrc/jit/serialization/python_print.cpp b/torch/csrc/jit/serialization/python_print.cpp index 70e188816fb4c..54b27743e5b9d 100644 --- a/torch/csrc/jit/serialization/python_print.cpp +++ b/torch/csrc/jit/serialization/python_print.cpp @@ -212,7 +212,11 @@ struct PythonPrintImpl { // and would appear in the same order when the expression tree is // reparsed. // The last case can be checked +<<<<<<< HEAD // because when we emit a expression tree in the parser, +======= + // because when we emit a expresion tree in the parser, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // we do a left-to-right postorder traversal of the expression tree (emit // children, then emit op). The reverse of this is a right-to-left preorder // traversal of the tree. By doing a right-to-left preorder traversal of the @@ -222,12 +226,20 @@ struct PythonPrintImpl { // expression. // The inductive step is that the right-most input should be produced by the +<<<<<<< HEAD // node immediately before the current node if it is in tree order. +======= + // node immediatly before the current node if it is in tree order. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool canInline(Value* v) { Node* n = v->node(); // there must be only 1 values, otherwise we need an assignment to handle +<<<<<<< HEAD // the multiple output values +======= + // the multiple outout values +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (n->outputs().size() != 1) return false; // if it is used more than once, then we need a variable @@ -651,7 +663,11 @@ struct PythonPrintImpl { // [reordering of inlines] // We inline anything that is semantically legal to inline, but sometimes // we find that these lines get too long. In that case we break the lines +<<<<<<< HEAD /// and it is important that we un-inline all the inputs preceding the long +======= + /// and it is important that we un-inline all the inputs preceeding the long +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) /// input: // r = foo(x.add_(b), some_long + expression) // wrong! @@ -1410,7 +1426,11 @@ struct PythonPrintImpl { enforce_importable_(enforce_importable) {} void printClass(const ClassTypePtr& classType) { +<<<<<<< HEAD // If any of the methods are not Graph functions, this indicates that +======= + // If any of the methods are not Graph funtions, this indicates that +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // this class is a custom-bound C++ class. Skip serialization // of this class, we will depend on the ClassType being defined // in the target process. diff --git a/torch/csrc/jit/serialization/unpickler.cpp b/torch/csrc/jit/serialization/unpickler.cpp index 0253a5588030c..e82cfd6390c9e 100644 --- a/torch/csrc/jit/serialization/unpickler.cpp +++ b/torch/csrc/jit/serialization/unpickler.cpp @@ -5,6 +5,10 @@ #endif #include #include +<<<<<<< HEAD +======= +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include @@ -44,7 +48,11 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) { to_process.pop_back(); // ensure we only scan each pointer value once, otherwise this // can become exponential (and if we allow recursive data in the future, +<<<<<<< HEAD // it would not terminate). +======= + // it would not terminiate). +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (w.value.isPtrType()) { const void* key = w.value.internalToPointer(); auto it = scanned.find(key); @@ -490,7 +498,11 @@ PickleOpCode Unpickler::readInstruction() { stack_.size(), " and start index is ", start, +<<<<<<< HEAD ", but stack_ is iterated by two elements at a time"); +======= + ", but stack_ is iterated by two elemenst at a time"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (size_t i = start; i < stack_.size(); i += 2) { dict.insert_or_assign(stack_[i], stack_[i + 1]); } @@ -669,6 +681,7 @@ void Unpickler::readGlobal( // See [NOTE] skip_next_read_global this->skip_next_read_global--; if (this->skip_next_read_global == 1) { +<<<<<<< HEAD if (module_name == "torch" && class_name == "Tensor") { // This is a special case when we are unpickling a subclassed tensor // with type torch.nn.Buffer. We didn't frequently run into this because @@ -679,6 +692,8 @@ void Unpickler::readGlobal( this->skip_next_read_global = 0; return; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Pass through to the correct handler } else if (this->skip_next_read_global == 0) { // Corresponds to the type of `Tensor` being unpickled @@ -784,10 +799,13 @@ void Unpickler::readGlobal( // a Subclassed Tensor. rebuildTensorFromTypeV2(); } else if ( +<<<<<<< HEAD module_name == "torch._utils" && (class_name == "_rebuild_parameter")) { // Unpickle a Parameter rebuildParameter(); } else if ( +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) module_name == "torch._utils" && class_name == "_rebuild_sparse_tensor") { rebuildSparseTensor(); } else if (module_name == "builtins" && class_name == "complex") { @@ -1038,6 +1056,7 @@ void Unpickler::rebuildTensorFromTypeV2() { }); } +<<<<<<< HEAD void Unpickler::rebuildParameter() { globals_.emplace_back([this] { auto args = pop(stack_).toTuple(); @@ -1050,6 +1069,8 @@ void Unpickler::rebuildParameter() { }); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #ifdef USE_RPC void Unpickler::rebuildRRef() { globals_.emplace_back([this] { diff --git a/torch/csrc/jit/serialization/unpickler.h b/torch/csrc/jit/serialization/unpickler.h index 702a1d8816e7f..b896b69212a65 100644 --- a/torch/csrc/jit/serialization/unpickler.h +++ b/torch/csrc/jit/serialization/unpickler.h @@ -3,10 +3,16 @@ #include #include #include +<<<<<<< HEAD #include #include #include +======= +#include +#include +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace torch::jit { @@ -137,7 +143,10 @@ class TORCH_API Unpickler { const std::string& module_name, const std::string& class_name); void rebuildTensor(bool quantized); +<<<<<<< HEAD void rebuildParameter(); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void rebuildTensorFromTypeV2(); void rebuildSparseTensor(); #ifdef USE_DISTRIBUTED diff --git a/torch/csrc/jit/tensorexpr/ConditionalsInTE.md b/torch/csrc/jit/tensorexpr/ConditionalsInTE.md index c7bcea4976483..596f6363160bd 100644 --- a/torch/csrc/jit/tensorexpr/ConditionalsInTE.md +++ b/torch/csrc/jit/tensorexpr/ConditionalsInTE.md @@ -14,7 +14,11 @@ So far the recommendation was to standardize on fused conditionals. ## Expression Conditionals vs Statement Conditionals +<<<<<<< HEAD Tensor IR contains both expression conditionals (`CompareSelect` and `IfThenElse`), as well as statement conditionals (`Cond`). Expression conditionals are defined by being functional in nature: there is no side effect from duplicating the conditional, evaluating it twice, etc. They are an important ingredient in expressing important operators like ReLU: +======= +Tensor IR contains both expression conditionals (`CompareSelect` and `IfThenElse`), as well as statement conditionals (`Cond`). Expression conditionals are defined by being functional in nature: there is no side effect from duplicating the conditional, evaluating it twice, etc. They are an important ingredient in expression important operators like ReLU: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ``` store (((load A) >= 0.0) ? (load A) : 0.0), B diff --git a/torch/csrc/jit/tensorexpr/codegen_external.py b/torch/csrc/jit/tensorexpr/codegen_external.py index 6c8316cc9a420..0dbb3d045cdd2 100644 --- a/torch/csrc/jit/tensorexpr/codegen_external.py +++ b/torch/csrc/jit/tensorexpr/codegen_external.py @@ -77,7 +77,11 @@ def gen_external(native_functions_path, tags_path, external_path): at::Tensor& r = tensors[0]; {nl.join(tensor_decls)} try {{ +<<<<<<< HEAD at::{name}_out({", ".join(["r"] + arg_names)}); +======= + at::{name}_out({', '.join(['r'] + arg_names)}); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }} catch (...) {{ }} }}""" diff --git a/torch/csrc/jit/tensorexpr/external_functions.cpp b/torch/csrc/jit/tensorexpr/external_functions.cpp index c9aedb115a98f..c6c12414e427a 100644 --- a/torch/csrc/jit/tensorexpr/external_functions.cpp +++ b/torch/csrc/jit/tensorexpr/external_functions.cpp @@ -1437,7 +1437,11 @@ void nnc_aten_embedding( r = at::embedding(weight, indices); } catch (...) { } +<<<<<<< HEAD // TODO: have to copy output because at::embedding doesn't have an out +======= + // TODO: have to copy output because at::embedding doesnt have an out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // variant and NNC's external calls don't support allocations memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel()); } diff --git a/torch/csrc/jit/tensorexpr/ir.cpp b/torch/csrc/jit/tensorexpr/ir.cpp index e75af13df9327..85c7d451d2d8a 100644 --- a/torch/csrc/jit/tensorexpr/ir.cpp +++ b/torch/csrc/jit/tensorexpr/ir.cpp @@ -125,7 +125,11 @@ Dtype Intrinsics::IntrinsicsDtype( IntrinsicsOp op_type, const std::vector& params) { // TODO: check the op_type and make a real decision +<<<<<<< HEAD // Doesn't this fail with kRand? +======= + // Doesnt this fail with kRand? +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (params.empty()) { throw malformed_input("invalid params in Intrinsics"); } else if (params.size() == 1) { diff --git a/torch/csrc/jit/tensorexpr/ir_simplifier.cpp b/torch/csrc/jit/tensorexpr/ir_simplifier.cpp index 88d86d639c686..07131ba821f7d 100644 --- a/torch/csrc/jit/tensorexpr/ir_simplifier.cpp +++ b/torch/csrc/jit/tensorexpr/ir_simplifier.cpp @@ -930,7 +930,11 @@ ExprPtr PolynomialTransformer::mutate(const MulPtr& v) { variable = lhs_new; } +<<<<<<< HEAD // Handle special case mul by 1 since that's safe for floating point, even if +======= + // Handle special case mul by 1 since thats safe for floating point, even if +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // it's Nan/Inf. if (scalar && immediateEquals(scalar, 1)) { auto c = alloc(v->dtype(), variable); @@ -1105,8 +1109,13 @@ ExprPtr PolynomialTransformer::mutate(const DivPtr& v) { return lhs_new; } +<<<<<<< HEAD // If numerator and denominator are equal the result is 1. // Unless the denominator could be zero. +======= + // If numberator and denominator are equal the result is 1. + // Unless the demoninator could be zero. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // if (hasher_.hash(lhs_new) == hasher_.hash(rhs_new)) { // return getImmediateByType(v->dtype(), 1); // } @@ -1745,7 +1754,11 @@ ExprPtr TermExpander::mutate(const TermPtr& v) { std::vector vars; std::vector multilaneVars; +<<<<<<< HEAD // Assume we can reorder here because we won't merge floating terms. +======= + // Assume we can reorder here because we wont merge floating terms. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ExprPtr lastNode{nullptr}; for (const auto& var : v->variables()) { ExprPtr node = var->accept_mutator(this); @@ -1830,7 +1843,11 @@ static ExprPtr polyGCD(const PolynomialPtr& poly) { ExprPtr scalar = poly->scalar(); const std::vector& variables = poly->variables(); +<<<<<<< HEAD // We only want to factorize if we're saving complete operations, i.e. no +======= + // We ony want to factorize if we're saving complete operations, i.e. no +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // value in factorizing 6x + 4y into 2 * (3x + 2y) since we don't save work. int opsSaved = 1; // default to saving the scalar. long GCD = std::abs(immediateAs(scalar)); @@ -2088,7 +2105,11 @@ static ExprPtr simplifyRoundModPattern(const PolynomialPtr& poly) { // TODO: for now don't attempt partial factorization of this // optimization. E.g. it's possible to do: 2 * (x/y) * y + (x%y) => x + +<<<<<<< HEAD // (x/y) * y but unsure that's actually much better, particularly with +======= + // (x/y) * y but unsure thats actually much better, particularly with +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // CSE. if (!immediateEquals( evaluateOp(alloc(r->scalar(), m->scalar())), 0)) { diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index a8ffa40f58dba..86b618190a908 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -1263,11 +1263,19 @@ Tensor TensorExprKernel::convertSymbolicOutputToCorrectStrides( const std::vector& sorted_stride_indices_descending, const std::vector& strides, BufPtr& buf) { +<<<<<<< HEAD // We need to convert the output tensor so that its values are laid // so that when viewed from the output strides the values are correct. // A contiguous Tensor of size(2, 3) with values 0-5 is laid out as: // [0] [1] [2] [3] [4] [5] // The same valued tensor with strides (1, 2) would be laid out like +======= + // We need to convert the output tensor so that its values are layed + // so that when viewed from the output strides the values are correct. + // A contiguous Tensor of size(2, 3) with values 0-5 is layed out as: + // [0] [1] [2] [3] [4] [5] + // The same valued tensor with strides (1, 2) would be layed out like +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // [0] [3] [1] [4] [2] [5] // When we are doing the re-ordering of values into the output tensor, // we are iterating per-element of the input, and we are fixed @@ -1378,7 +1386,11 @@ Tensor TensorExprKernel::convertStaticShapeOutputToCorrectStrides( tt->strides().concrete_sizes(), buildErrorMessage("Output strides are unknown.")); const std::vector strides = *tt->strides().concrete_sizes(); +<<<<<<< HEAD // All Tensors in NNC are laid out in default, contiguous layout. +======= + // All Tensors in NNC are layed out in default, contiguous layout. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // If the output is also default contiguous we don't need to do anything if (strides == default_strides) { return Tensor(buf, nullptr); diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp index 918d82579444f..0ce79b7a71df5 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp @@ -83,7 +83,11 @@ using namespace torch::jit::tensorexpr; C10_DEFINE_bool( torch_jit_llvm_use_fast_intrinsics, false, +<<<<<<< HEAD "Use fast (but slightly less accurate) implementations of tanh and sigmoid") +======= + "Use fast (but slightly less accurate) implementations of tanh and sigmoid"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace torch::jit::tensorexpr { @@ -246,7 +250,11 @@ class LLVMCodeGenImpl : public IRVisitor { std::string kernel_func_name_; #define LLVM_TYPE_DECLARE(_1, Name) llvm::Type* Name##Ty_; +<<<<<<< HEAD AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, LLVM_TYPE_DECLARE) +======= + AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, LLVM_TYPE_DECLARE); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #undef LLVM_TYPE_DECLARE #if LLVM_VERSION_MAJOR >= 15 @@ -780,7 +788,11 @@ void LLVMCodeGenImpl::emitKernel( GRAPH_DEBUG("\nLLVM generated assembly code\n\n", asmCode_, "\n"); } +<<<<<<< HEAD // TODO: The binary ops are copypaste. +======= +// TODO: The binary ops are copypasta. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void LLVMCodeGenImpl::visit(const AddPtr& v) { v->lhs()->accept(this); @@ -878,7 +890,11 @@ void LLVMCodeGenImpl::visit(const OrPtr& v) { bool rfp = rhs->getType()->isFPOrFPVectorTy(); if (!lfp && !rfp) { +<<<<<<< HEAD value_ = irb_.CreateOr(lhs, rhs); // codespell:ignore +======= + value_ = irb_.CreateOr(lhs, rhs); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else { throw malformed_input("llvm_codegen: bad type in Or", v); } @@ -1101,7 +1117,11 @@ std::enable_if_t, llvm::Value*> getFromType( void LLVMCodeGenImpl::visit(const Name##ImmPtr& v) { \ value_ = getFromType(Name##Ty_, v->value()); \ } +<<<<<<< HEAD AT_FORALL_SCALAR_TYPES(IMM_VISIT_DECLARE) +======= +AT_FORALL_SCALAR_TYPES(IMM_VISIT_DECLARE); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #undef IMM_VISIT_DECLARE void LLVMCodeGenImpl::visit(const HalfImmPtr& v) { @@ -1225,7 +1245,11 @@ void LLVMCodeGenImpl::visit(const CastPtr& v) { } value_ = irb_.CreateFPCast(value_, dstType); } else if (dstType->isIntOrIntVectorTy()) { +<<<<<<< HEAD // Strictly casting from Float -> i8 doesn't give correct results +======= + // Strictly casting from Float -> i8 doesnt give correct results +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // set one bit true if the input float is not 0 if (v->dtype().scalar_type() == ScalarType::Bool) { llvm::Value* zero = diff --git a/torch/csrc/jit/tensorexpr/llvm_jit.cpp b/torch/csrc/jit/tensorexpr/llvm_jit.cpp index 80d919a5674e6..f70000fd521c5 100644 --- a/torch/csrc/jit/tensorexpr/llvm_jit.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_jit.cpp @@ -11,7 +11,10 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wsuggest-override") #include C10_DIAGNOSTIC_POP() +<<<<<<< HEAD C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include @@ -36,7 +39,10 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #endif #include #include +<<<<<<< HEAD C10_DIAGNOSTIC_POP() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include diff --git a/torch/csrc/jit/tensorexpr/llvm_jit.h b/torch/csrc/jit/tensorexpr/llvm_jit.h index 19a21329b64a7..433a23f916000 100644 --- a/torch/csrc/jit/tensorexpr/llvm_jit.h +++ b/torch/csrc/jit/tensorexpr/llvm_jit.h @@ -9,11 +9,17 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wsuggest-override") #include C10_DIAGNOSTIC_POP() +<<<<<<< HEAD C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include #include #include C10_DIAGNOSTIC_POP() +======= +#include +#include +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp index 7f0888666d3af..39b7b5c7ecc1c 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.cpp +++ b/torch/csrc/jit/tensorexpr/loopnest.cpp @@ -987,7 +987,11 @@ void LoopNest::inlineIntermediateBufs(bool allow_duplicated_work) { } } +<<<<<<< HEAD // all bufs will have at least one store (if they have > 1 they can't be +======= + // all bufs will have at least one store (if they have > 1 they cant be +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // inlined anyway) size_t reads = uses.size() - 1; // if only one read, we can inline it without duplicating work diff --git a/torch/csrc/jit/tensorexpr/mem_dependency_checker.h b/torch/csrc/jit/tensorexpr/mem_dependency_checker.h index 222ac5713d36b..4ef204b19fa06 100644 --- a/torch/csrc/jit/tensorexpr/mem_dependency_checker.h +++ b/torch/csrc/jit/tensorexpr/mem_dependency_checker.h @@ -240,7 +240,11 @@ class TORCH_API MemDependencyChecker : public IRVisitor { std::unordered_set> accessesWithin( const StmtPtr& A) const; // TODO: this will return only the AccessInfo for A. It's included for +<<<<<<< HEAD // completeness but be aware it won't return accesses used in the computation +======= + // completeness but be aware it wont return accesses used in the computation +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // of A. std::unordered_set> accessesWithin( const ExprPtr& A) const; diff --git a/torch/csrc/jit/tensorexpr/registerizer.cpp b/torch/csrc/jit/tensorexpr/registerizer.cpp index 37f79d529238d..a2ae6af1ef5f0 100644 --- a/torch/csrc/jit/tensorexpr/registerizer.cpp +++ b/torch/csrc/jit/tensorexpr/registerizer.cpp @@ -225,7 +225,11 @@ void RegisterizerAnalysis::visit(const ForPtr& v) { // possible that an access at a higher scope could "unhide" the // conditional access, in which case we need to hoist. If there is no // access to this element at a higher scope then we cannot safely hoist. +<<<<<<< HEAD // We cannot know at this level whether that will or won't occur. +======= + // We cannot know at this level whether that will or wont occur. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // // The solution we take here is to split the space-time continuum, and // keep both versions of the access handy. If the hoisted access is not @@ -542,7 +546,11 @@ void RegisterizerAnalysis::mergeCurrentScopeIntoParent() { closeAccessIntoScope(pCandidate, parent); parentAccesses.erase(parentIt); +<<<<<<< HEAD // the children access inserted into the parent scope. +======= + // the childs access inserted into the parent scope. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) closeAccessIntoScope(candidate, parent); continue; } @@ -567,7 +575,11 @@ void RegisterizerAnalysis::mergeCurrentScopeIntoParent() { ++it; } +<<<<<<< HEAD // Insert the children closed access into the parent scope. +======= + // Insert the childs closed access into the parent scope. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) closeAccessIntoScope(candidate, parent); } diff --git a/torch/csrc/jit/tensorexpr/registerizer.h b/torch/csrc/jit/tensorexpr/registerizer.h index c507d3b13a95e..f27638c5a2999 100644 --- a/torch/csrc/jit/tensorexpr/registerizer.h +++ b/torch/csrc/jit/tensorexpr/registerizer.h @@ -186,7 +186,11 @@ class AccessInfo { bool firstUsageOverlapped_{false}; // The cost in real ops that this access represents, to enable +<<<<<<< HEAD // filtering accesses that won't save any loads or stores. +======= + // filtering accesses that wont save any loads or stores. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ExprPtr store_cost_; ExprPtr load_cost_; diff --git a/torch/csrc/lazy/core/tensor_impl.cpp b/torch/csrc/lazy/core/tensor_impl.cpp index ce49338936e39..334e2e57daf93 100644 --- a/torch/csrc/lazy/core/tensor_impl.cpp +++ b/torch/csrc/lazy/core/tensor_impl.cpp @@ -195,14 +195,22 @@ bool LTCTensorImpl::is_strides_like_custom( return false; } +<<<<<<< HEAD c10::SymBool LTCTensorImpl::sym_is_non_overlapping_and_dense_custom() const { +======= +bool LTCTensorImpl::is_non_overlapping_and_dense_custom() const { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // This should be true, but false as a temporary fix for a PyTorch core issue, // according to https://github.com/pytorch/xla/pull/2682. return false; } +<<<<<<< HEAD c10::SymBool LTCTensorImpl::sym_is_contiguous_custom( c10::MemoryFormat _unused) const { +======= +bool LTCTensorImpl::is_contiguous_custom(c10::MemoryFormat _unused) const { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // TODO(ezyang): I don't think this branch is actually necessary // TODO(ezyang): I don't think this logic is right, shouldn't we pass on // the memory format? diff --git a/torch/csrc/lazy/core/tensor_impl.h b/torch/csrc/lazy/core/tensor_impl.h index 02f68c01c6f44..4b54ef23b473b 100644 --- a/torch/csrc/lazy/core/tensor_impl.h +++ b/torch/csrc/lazy/core/tensor_impl.h @@ -41,11 +41,18 @@ class TORCH_API LTCTensorImpl final : public c10::TensorImpl { int64_t numel_custom() const override; int64_t storage_offset_custom() const override; int64_t dim_custom() const override; +<<<<<<< HEAD bool is_strides_like_custom(at::MemoryFormat memory_format) const override; c10::SymBool sym_is_non_overlapping_and_dense_custom() const override; c10::SymBool sym_is_contiguous_custom( at::MemoryFormat memory_format) const override; +======= + bool is_contiguous_custom(at::MemoryFormat memory_format) const override; + bool is_strides_like_custom(at::MemoryFormat memory_format) const override; + bool is_non_overlapping_and_dense_custom() const override; + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) c10::SymIntArrayRef sym_sizes_custom() const override; c10::SymIntArrayRef sym_strides_custom() const override; c10::SymInt sym_numel_custom() const override; diff --git a/torch/csrc/lazy/python/init.cpp b/torch/csrc/lazy/python/init.cpp index 4807aa6a4c7d1..0d7f8a81c9d25 100644 --- a/torch/csrc/lazy/python/init.cpp +++ b/torch/csrc/lazy/python/init.cpp @@ -331,9 +331,19 @@ void initLazyBindings(PyObject* module) { // So far this problem has only been observed internally, so we will just // block it off there. +<<<<<<< HEAD // When libtorch_python is loaded, we register the python frame getter // otherwise, debug util simply omits python frames GetPythonFramesFunction() = GetPythonFrames; +======= +#if !(defined(USE_DEPLOY)) + + // When libtorch_python is loaded, we register the python frame getter + // otherwise, debug util simply omits python frames + GetPythonFramesFunction() = GetPythonFrames; + +#endif // USE_DEPLOY +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } // namespace torch::lazy diff --git a/torch/csrc/lazy/ts_backend/ts_lowering_context.h b/torch/csrc/lazy/ts_backend/ts_lowering_context.h index 356ea3d8e9231..ba910be618b31 100644 --- a/torch/csrc/lazy/ts_backend/ts_lowering_context.h +++ b/torch/csrc/lazy/ts_backend/ts_lowering_context.h @@ -91,7 +91,11 @@ class TORCH_API TSLoweringContext : public LoweringContext { for (torch::jit::Value* output : root_tuple_) { graph_->block()->registerOutput(output); } +<<<<<<< HEAD return std::make_shared(graph_); +======= + return std::shared_ptr(new TSComputation(graph_)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // Retrieves the lowered operation for an output. If the requested output is diff --git a/torch/csrc/monitor/events.h b/torch/csrc/monitor/events.h index 2ec89251c62e4..072ba47fa72fe 100644 --- a/torch/csrc/monitor/events.h +++ b/torch/csrc/monitor/events.h @@ -35,7 +35,11 @@ struct TORCH_API Event { std::unordered_map data; }; +<<<<<<< HEAD inline bool operator==(const Event& lhs, const Event& rhs) { +======= +TORCH_API inline bool operator==(const Event& lhs, const Event& rhs) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return lhs.name == rhs.name && lhs.timestamp == rhs.timestamp && lhs.data == rhs.data; } diff --git a/torch/csrc/mps/Module.cpp b/torch/csrc/mps/Module.cpp index 51c77aba6d765..b5f5ce73e90ac 100644 --- a/torch/csrc/mps/Module.cpp +++ b/torch/csrc/mps/Module.cpp @@ -501,12 +501,15 @@ void initModule(PyObject* module) { at::mps::getMPSProfiler().startCapture(fileName); }); m.def("_mps_stopCapture", []() { at::mps::getMPSProfiler().stopCapture(); }); +<<<<<<< HEAD m.def("_mps_get_name", []() { return at::mps::MPSDevice::getInstance()->getName(); }); m.def("_mps_get_core_count", []() { return at::mps::MPSDevice::getInstance()->getCoreCount(); }); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } #endif /* USE_MPS */ diff --git a/torch/csrc/mtia/Module.cpp b/torch/csrc/mtia/Module.cpp index 0b8171a372653..ee56b11fcb928 100644 --- a/torch/csrc/mtia/Module.cpp +++ b/torch/csrc/mtia/Module.cpp @@ -63,6 +63,7 @@ void initModule(PyObject* module) { return at::detail::getMTIAHooks().getDefaultStream(device_index); }); +<<<<<<< HEAD m.def( "_mtia_setStream", [](int64_t stream_id, @@ -75,6 +76,8 @@ void initModule(PyObject* module) { static_cast(device_type))); }); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) m.def("_mtia_setCurrentStream", [](const c10::Stream& stream) { torch::utils::device_lazy_init(at::kMTIA); auto device = at::detail::getMTIAHooks().getCurrentDevice(); diff --git a/torch/csrc/profiler/README.md b/torch/csrc/profiler/README.md index dc27337349ddc..ec1a19757506c 100644 --- a/torch/csrc/profiler/README.md +++ b/torch/csrc/profiler/README.md @@ -13,6 +13,7 @@ The profiler instruments PyTorch to collect information about the model's execut - [Codebase Structure](#codebase-structure) - [`RecordFunction`](#recordfunction) - [Autograd Integration](#autograd-integration) +<<<<<<< HEAD - [Torch Operation Collection](#torch-operation-collection) - [Allocation Event Collection](#allocation-event-collection) - [Kineto Integration](#kineto-integration) @@ -56,6 +57,16 @@ torch/ │ └── record_function.h # RecordFunction definitions └── LICENSE # License information ``` +======= +- [Collection and Post-Processing](#collection-and-post-processing) +- [Kineto Integration](#kineto-integration) +- [Python Tracing](#python-tracing) + +## Codebase Structure ## + +TODO + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ## `RecordFunction` ## [aten/src/ATen/record_function.h](../../../aten/src/ATen/record_function.h) @@ -78,6 +89,7 @@ The profiler records two pieces of information from the autograd engine: (\*) Note that only op invocations whose inputs require gradients are assigned a sequence number +<<<<<<< HEAD ## Torch Operation Collection ## This section describes the general flow for collecting torch operations during auto-trace (in-process, synchronous tracing). For details on on-demand tracing (out-of-process, asynchronous), please refer to the Libkineto README. @@ -114,3 +126,16 @@ This setup allows for detailed and accurate data collection on both Python and C ## Clock Alignment ## Depending on the system environment, the profiler will use the most efficient clock when creating a timestamp. The default for most Linux systems is TSC, which records time in the form of CPU cycles. To convert from this time to the unix time in nanoseconds, we create a clock converter. If Kineto is included in the profiler, this converter will also be passed into Kineto as well to ensure alignment. +======= +## Collection and Post-Processing ## + +TODO + +## Kineto Integration ## + +TODO + +## Python Tracing ## + +TODO +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/csrc/profiler/collection.cpp b/torch/csrc/profiler/collection.cpp index c7f759cd077c9..b631e1a3d01e6 100644 --- a/torch/csrc/profiler/collection.cpp +++ b/torch/csrc/profiler/collection.cpp @@ -396,8 +396,12 @@ std::unique_ptr ThreadLocalSubqueue::begin_op( } event->start_time_ = c10::getApproximateTime(); +<<<<<<< HEAD event->allow_tf32_cublas_ = at::globalContext().float32Precision("cuda", "matmul") == "tf32"; +======= + event->allow_tf32_cublas_ = at::globalContext().allowTF32CuBLAS(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (!config_.experimental_config.performance_events.empty()) { const size_t n = config_.experimental_config.performance_events.size(); event->counters_ = std::make_unique(n, 0); @@ -613,7 +617,10 @@ std::string Result::name() const { ATTRIBUTE(OutOfMemory, std::string("[OutOfMemory]")), ATTRIBUTE(PyCall, toString(e)), ATTRIBUTE(PyCCall, std::string(e.function_name_.str())), +<<<<<<< HEAD ATTRIBUTE(PythonGC, std::string("Python GC")), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) [](const auto& e) -> std::string { return e.name_; })); } @@ -632,7 +639,10 @@ libkineto::ActivityType Result::kinetoType() const { ATTRIBUTE(OutOfMemory, libkineto::ActivityType::CPU_INSTANT_EVENT), ATTRIBUTE(PyCall, libkineto::ActivityType::PYTHON_FUNCTION), ATTRIBUTE(PyCCall, libkineto::ActivityType::PYTHON_FUNCTION), +<<<<<<< HEAD ATTRIBUTE(PythonGC, libkineto::ActivityType::PYTHON_FUNCTION), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ATTRIBUTE(Kineto, e.activity_type_))); } @@ -652,7 +662,10 @@ int64_t Result::endTimeNS() const { ATTRIBUTE(Allocation, start_time_ns_), ATTRIBUTE(OutOfMemory, start_time_ns_), ATTRIBUTE(Kineto, start_time_ns_ + e.duration_ns_), +<<<<<<< HEAD ATTRIBUTE(PythonGC, start_time_ns_ + e.duration_ns_), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) [&](const auto& e) -> int64_t { return e.end_time_ns_; })); // In rare cases we're willing to tolerate ops which are missing an end time @@ -703,9 +716,12 @@ RecordQueue::RecordQueue( activities_{std::move(activities)} { if (tracePython()) { python_tracer_ = python_tracer::PythonTracerBase::make(this); +<<<<<<< HEAD if (getPythonGcEvents()) { python_tracer_->register_gc_callback(); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } @@ -713,10 +729,13 @@ bool RecordQueue::tracePython() const { return config_.with_stack && activities_.count(ActivityType::CPU); } +<<<<<<< HEAD bool RecordQueue::getPythonGcEvents() const { return config_.experimental_config.record_python_gc_info; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ThreadLocalSubqueue* RecordQueue::getSubqueue() { // In the most common case, a thread will want to write to the same sub-queue // that it wrote to last call. The only time that isn't true is if: @@ -1026,12 +1045,15 @@ class TransferEvents { } } +<<<<<<< HEAD bool isHiddenEvent(const itrace_t* activity) const { TORCH_INTERNAL_ASSERT(activity != nullptr); // Kineto uses "hidden" metadata to mark events that should be hidden. return activity->getMetadataValue("hidden") == "1"; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::shared_ptr resultFromActivity(const itrace_t* activity) { TORCH_INTERNAL_ASSERT(activity != nullptr); @@ -1052,7 +1074,11 @@ class TransferEvents { {/*id=*/static_cast(activity->flowId()), /*type=*/static_cast(activity->flowType()), /*start=*/activity->flowStart()}}); +<<<<<<< HEAD event->hidden_ = isHiddenEvent(activity); +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // NB: It's tempting to set `event->kineto_activity_`; however we can only // guarantee that the events we passed to Kineto are of type // `GenericTraceActivity`. Others may derive from ITraceActivity and thus @@ -1498,6 +1524,7 @@ RecordQueue::getRecords( queue.allocations_.clear(); materialize(queue.ooms_); +<<<<<<< HEAD std::optional pending_start; for (auto& e : queue.pythongc_) { if (e.first.find("start") != std::string::npos) { @@ -1523,6 +1550,8 @@ RecordQueue::getRecords( } } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (auto& i : queue.py_calls_) { python_enters.push_back( {i.first, queue.tid(), queue.kineto_info(), converter(i.second)}); diff --git a/torch/csrc/profiler/collection.h b/torch/csrc/profiler/collection.h index 847819f971957..d367ddb0bb1dd 100644 --- a/torch/csrc/profiler/collection.h +++ b/torch/csrc/profiler/collection.h @@ -34,8 +34,12 @@ enum class EventType : uint8_t { OutOfMemory, PyCall, PyCCall, +<<<<<<< HEAD Kineto, PythonGC +======= + Kineto +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; // ============================================================================ @@ -193,12 +197,15 @@ struct ExtraFields { }; template <> +<<<<<<< HEAD struct ExtraFields { std::string phase; int64_t duration_ns_; }; template <> +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) struct ExtraFields { using raw_event_t = std::pair; std::string name_; @@ -422,14 +429,22 @@ struct TORCH_API Result : public std::enable_shared_from_this { ExtraFields, ExtraFields, ExtraFields, +<<<<<<< HEAD ExtraFields, ExtraFields> +======= + ExtraFields> +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) extra_fields_; std::weak_ptr parent_; std::vector> children_; bool finished_{false}; +<<<<<<< HEAD bool hidden_{false}; +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const torch::profiler::impl::kineto::activity_t* kineto_activity_{nullptr}; private: @@ -557,11 +572,14 @@ class TORCH_API ThreadLocalSubqueue { py_calls_.emplace_back(std::forward(args)...); } +<<<<<<< HEAD template void emplace_gc_call(Args&&... args) { pythongc_.emplace_back(std::forward(args)...); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uint64_t tid() const { return tid_; } @@ -652,9 +670,12 @@ class TORCH_API ThreadLocalSubqueue { std::pair, BlockSize> py_calls_; +<<<<<<< HEAD // gc with_stack (Python) AppendOnlyList, BlockSize> pythongc_; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; class TORCH_API RecordQueue { @@ -662,7 +683,10 @@ class TORCH_API RecordQueue { RecordQueue(ProfilerConfig config, std::set activities); bool tracePython() const; +<<<<<<< HEAD bool getPythonGcEvents() const; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ThreadLocalSubqueue* getSubqueue(); void stop(); void restart(); diff --git a/torch/csrc/profiler/kineto_shim.cpp b/torch/csrc/profiler/kineto_shim.cpp index ec9994e15ec9c..288e1cf2cdc18 100644 --- a/torch/csrc/profiler/kineto_shim.cpp +++ b/torch/csrc/profiler/kineto_shim.cpp @@ -50,7 +50,10 @@ const std::set kXpuTypes = { const std::set kMtiaTypes = { libkineto::ActivityType::MTIA_CCP_EVENTS, libkineto::ActivityType::MTIA_RUNTIME, +<<<<<<< HEAD libkineto::ActivityType::MTIA_INSIGHT, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; const std::set hpuTypes = { libkineto::ActivityType::HPU_OP, @@ -207,7 +210,10 @@ class ExperimentalConfigWrapper { configss << "\nCUPTI_PROFILER_ENABLE_PER_KERNEL=" << (config_.profiler_measure_per_kernel ? "true" : "false") << "\n"; +<<<<<<< HEAD configss << "CUSTOM_CONFIG=" << config_.custom_profiler_config << "\n"; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) LOG(INFO) << "Generated config = " << configss.str(); libkineto::api().activityProfiler().prepareTrace( @@ -240,6 +246,7 @@ static const std::string setTraceID(const std::string& trace_id) { configss << "REQUEST_GROUP_TRACE_ID=" << trace_id << "\n"; return configss.str(); } +<<<<<<< HEAD static const std::string appendCustomConfig( const std::string& config, @@ -252,6 +259,8 @@ static const std::string appendCustomConfig( configss << "CUSTOM_CONFIG=" << custom_profiler_config << "\n"; return configss.str(); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #endif void prepareTrace( @@ -308,9 +317,13 @@ void prepareTrace( return; } +<<<<<<< HEAD const std::string traceIdStr = setTraceID(trace_id); const std::string configStr = appendCustomConfig(traceIdStr, config.custom_profiler_config); +======= + const std::string configStr = setTraceID(trace_id); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) libkineto::api().activityProfiler().prepareTrace(k_activities, configStr); #endif // USE_KINETO diff --git a/torch/csrc/profiler/orchestration/observer.cpp b/torch/csrc/profiler/orchestration/observer.cpp index 5ef0690d18115..edba44f4289c3 100644 --- a/torch/csrc/profiler/orchestration/observer.cpp +++ b/torch/csrc/profiler/orchestration/observer.cpp @@ -21,8 +21,11 @@ ExperimentalConfig::ExperimentalConfig( bool disable_external_correlation, bool profile_all_threads, bool capture_overload_names, +<<<<<<< HEAD bool record_python_gc_info, std::string custom_profiler_config, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool adjust_timestamps) : profiler_metrics{std::move(profiler_metrics)}, profiler_measure_per_kernel{profiler_measure_per_kernel}, @@ -33,8 +36,11 @@ ExperimentalConfig::ExperimentalConfig( disable_external_correlation{disable_external_correlation}, profile_all_threads{profile_all_threads}, capture_overload_names{capture_overload_names}, +<<<<<<< HEAD record_python_gc_info{record_python_gc_info}, custom_profiler_config(std::move(custom_profiler_config)), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) adjust_timestamps{adjust_timestamps} {} /*explicit*/ ExperimentalConfig::operator bool() const { diff --git a/torch/csrc/profiler/orchestration/observer.h b/torch/csrc/profiler/orchestration/observer.h index ba62e9b56b5c6..5a913ec5ffb30 100644 --- a/torch/csrc/profiler/orchestration/observer.h +++ b/torch/csrc/profiler/orchestration/observer.h @@ -62,8 +62,11 @@ struct TORCH_API ExperimentalConfig { bool disable_external_correlation = false, bool profile_all_threads = false, bool capture_overload_names = false, +<<<<<<< HEAD bool record_python_gc_info = false, std::string custom_profiler_config = "", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool adjust_timestamps = false); explicit operator bool() const; @@ -104,6 +107,7 @@ struct TORCH_API ExperimentalConfig { bool capture_overload_names; /* +<<<<<<< HEAD * Controls whether or not python gc info is recorded. This is used to * determine if gc collect is slowing down your profile. */ @@ -116,6 +120,8 @@ struct TORCH_API ExperimentalConfig { std::string custom_profiler_config; /* +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) * Controls whether or not timestamp adjustment occurs after profiling. * The purpose of this is to adjust Vulkan event timelines to align with those * of their parent CPU events. diff --git a/torch/csrc/profiler/orchestration/python_tracer.cpp b/torch/csrc/profiler/orchestration/python_tracer.cpp index 0d1ad389f8896..1d5c9ca117910 100644 --- a/torch/csrc/profiler/orchestration/python_tracer.cpp +++ b/torch/csrc/profiler/orchestration/python_tracer.cpp @@ -11,7 +11,10 @@ struct NoOpPythonTracer : public PythonTracerBase { void stop() override {} void restart() override {} +<<<<<<< HEAD void register_gc_callback() override {} +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::vector> getEvents( std::function, std::vector&, diff --git a/torch/csrc/profiler/orchestration/python_tracer.h b/torch/csrc/profiler/orchestration/python_tracer.h index 1011f75b82308..4e02662ae25a3 100644 --- a/torch/csrc/profiler/orchestration/python_tracer.h +++ b/torch/csrc/profiler/orchestration/python_tracer.h @@ -48,7 +48,10 @@ struct TORCH_API PythonTracerBase { virtual void stop() = 0; virtual void restart() = 0; +<<<<<<< HEAD virtual void register_gc_callback() = 0; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) virtual std::vector> getEvents( std::function time_converter, std::vector& enters, diff --git a/torch/csrc/profiler/python/init.cpp b/torch/csrc/profiler/python/init.cpp index aa7abe9433fe1..039a826e414ef 100644 --- a/torch/csrc/profiler/python/init.cpp +++ b/torch/csrc/profiler/python/init.cpp @@ -340,9 +340,13 @@ void initPythonBindings(PyObject* module) { bool /* adjust_profiler_step */, bool /* disable_external_correlation*/, bool /* profile_all_threads */, +<<<<<<< HEAD bool /* capture_overload_names */, bool /* record_python_gc_info */, std::string /* custom_profiler_config*/ +======= + bool /* capture_overload_names */ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >(), "An experimental config for Kineto features. Please note that" "backward compatibility is not guaranteed.\n" @@ -357,12 +361,19 @@ void initPythonBindings(PyObject* module) { " that expose CUDA device, stream and event synchronization activities. This feature is new\n" " and currently disabled by default.\n" " adjust_profiler_step (bool) : whether to adjust the profiler step to\n" +<<<<<<< HEAD " match the parent python event duration. This feature is new and currently disabled by default.\n" " disable_external_correlation (bool) : whether to disable external correlation\n" " profile_all_threads (bool) : whether to profile all threads\n" " capture_overload_names (bool) : whether to include ATen overload names in the profile\n" " record_python_gc_info (bool) : adds python gc events to profile\n" " custom_profiler_config (string) : Used to pass some configurations to the custom profiler backend.\n", +======= + " match the parent python event duration. This feature is new and currently disabled by default.\n", + " disable_external_correlation (bool) : whether to disable external correlation\n", + " profile_all_threads (bool) : whether to profile all threads\n", + " capture_overload_names (bool) : whether to include ATen overload names in the profile\n", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) py::arg("profiler_metrics") = std::vector(), py::arg("profiler_measure_per_kernel") = false, py::arg("verbose") = false, @@ -371,9 +382,13 @@ void initPythonBindings(PyObject* module) { py::arg("adjust_profiler_step") = false, py::arg("disable_external_correlation") = false, py::arg("profile_all_threads") = false, +<<<<<<< HEAD py::arg("capture_overload_names") = false, py::arg("record_python_gc_info") = false, py::arg("custom_profiler_config") = "") +======= + py::arg("capture_overload_names") = false) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .def(py::pickle( [](const ExperimentalConfig& p) { // __getstate__ py::list py_metrics; @@ -396,8 +411,11 @@ void initPythonBindings(PyObject* module) { p.disable_external_correlation, p.profile_all_threads, p.capture_overload_names, +<<<<<<< HEAD p.record_python_gc_info, p.custom_profiler_config, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) p.performance_events); }, [](const py::tuple& t) { // __setstate__ diff --git a/torch/csrc/profiler/stubs/cuda.cpp b/torch/csrc/profiler/stubs/cuda.cpp index e08b2a3efd0f2..9d425aa8f67d1 100644 --- a/torch/csrc/profiler/stubs/cuda.cpp +++ b/torch/csrc/profiler/stubs/cuda.cpp @@ -1,11 +1,15 @@ #include #ifndef ROCM_ON_WINDOWS +<<<<<<< HEAD #ifdef TORCH_CUDA_USE_NVTX3 #include #else #include #endif +======= +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #else // ROCM_ON_WINDOWS #include #endif // ROCM_ON_WINDOWS diff --git a/torch/csrc/profiler/util.h b/torch/csrc/profiler/util.h index f2ae57fa0e591..4e24581a95531 100644 --- a/torch/csrc/profiler/util.h +++ b/torch/csrc/profiler/util.h @@ -43,7 +43,11 @@ TORCH_API void logSoftAssert( uint32_t line, const char* cond, const char* args); +<<<<<<< HEAD inline void logSoftAssert( +======= +TORCH_API inline void logSoftAssert( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const char* func, const char* file, uint32_t line, diff --git a/torch/csrc/serialization.cpp b/torch/csrc/serialization.cpp index dd3027d372dcf..774c7e50d2494 100644 --- a/torch/csrc/serialization.cpp +++ b/torch/csrc/serialization.cpp @@ -351,14 +351,25 @@ c10::intrusive_ptr THPStorage_readFileRaw( _storage_nbytes); } +<<<<<<< HEAD std::string cpu_data; +======= + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) + std::unique_ptr cpu_data; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) uint8_t* data{}; if (storage->device_type() == at::kCPU) { data = static_cast(storage->mutable_data()); } else { +<<<<<<< HEAD cpu_data.resize(nbytes); data = (uint8_t*)cpu_data.data(); +======= + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) + cpu_data = std::unique_ptr(new char[nbytes]); + data = (uint8_t*)cpu_data.get(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // fast track for bytes and little endian @@ -368,16 +379,27 @@ c10::intrusive_ptr THPStorage_readFileRaw( doRead(file, data, storage->nbytes()); } else { int64_t buffer_size = std::min(size, (int64_t)5000); +<<<<<<< HEAD std::vector le_buffer; le_buffer.resize(buffer_size * element_size); for (int64_t i = 0; i < size; i += buffer_size) { size_t to_convert = std::min(size - i, buffer_size); doRead(file, le_buffer.data(), element_size * to_convert); +======= + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) + std::unique_ptr le_buffer( + new uint8_t[buffer_size * element_size]); + + for (int64_t i = 0; i < size; i += buffer_size) { + size_t to_convert = std::min(size - i, buffer_size); + doRead(file, le_buffer.get(), element_size * to_convert); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // NOLINTNEXTLINE(bugprone-branch-clone) if (element_size == 2) { torch::utils::THP_decodeBuffer( +<<<<<<< HEAD (int16_t*)data + i, le_buffer.data(), true, to_convert); } else if (element_size == 4) { torch::utils::THP_decodeBuffer( @@ -385,6 +407,15 @@ c10::intrusive_ptr THPStorage_readFileRaw( } else if (element_size == 8) { torch::utils::THP_decodeBuffer( (int64_t*)data + i, le_buffer.data(), true, to_convert); +======= + (int16_t*)data + i, le_buffer.get(), true, to_convert); + } else if (element_size == 4) { + torch::utils::THP_decodeBuffer( + (int32_t*)data + i, le_buffer.get(), true, to_convert); + } else if (element_size == 8) { + torch::utils::THP_decodeBuffer( + (int64_t*)data + i, le_buffer.get(), true, to_convert); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } } diff --git a/torch/csrc/stable/library.h b/torch/csrc/stable/library.h index 741b6229042a4..060f58786ed56 100644 --- a/torch/csrc/stable/library.h +++ b/torch/csrc/stable/library.h @@ -4,16 +4,213 @@ // code for better UX. #include +<<<<<<< HEAD // Technically, this file doesn't use anything from stableivalue_conversions.h, // but we need to include it here as the contents of stableivalue_conversions.h // used to live here and so we need to expose them for backwards compatibility. #include +======= +#include + +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // use anonymous namespace to avoid collisions between differing // versions of this file that may be included by different sources namespace { +<<<<<<< HEAD +======= +// ============================================================================= +// helpers for converting between StableIValue and T +// ============================================================================= + +// forward declare so that from/to() calls in detail work +template +StableIValue from(T val); +template +T to(StableIValue val); + +namespace detail { + +// ============================================================================= +// FROM CONVERSIONS (T -> StableIValue) +// ============================================================================= + +// Specialization for general copyable types (catch-all) => StableIValue +template +struct FromImpl { + static StableIValue call(T val) { + static_assert( + sizeof(T) <= sizeof(StableIValue), + "StableLibrary stack does not support parameter types larger than 64 bits."); + static_assert(std::is_trivially_copyable_v); + // Initialization should be cheap enough; let's give people well-specified + // reproducible behavior. + StableIValue result = 0; + // NOTE [ -Wclass-memaccess ]: reinterpret_cast to suppress + // overzealous -Wclass-memaccess. (see + // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=107361) We have a + // static_assert above that T is trivially copyable, which should be + // enough. + std::memcpy(&result, reinterpret_cast(&val), sizeof(val)); + return result; + } +}; + +// Specialization for std::nullopt_t => StableIValue +template <> +struct FromImpl { + static StableIValue call(std::nullopt_t val) { + return from(nullptr); + } +}; + +// Specialization for std::optional => StableIValue +// [Handling std::optional] +// When the schema is represented by an optional type, say int?, then we +// expect the custom extension representation to be a std::optional +// (critically NOT int!). In order for all parameters to be stably parsed and +// handled by our dispatcher, we liaison custom extension parameters through +// boxed kernels, meaning that every value will make its way to be an IValue: +// +// custom extension value --(from)-> StableIValue --(to_ivalue)-> IValue +// +// When the custom extension value is a literal that can be trivially +// casted to StableIValue, e.g., an int, a float, a pointer, this route is +// ...trivial. The below specialization is for a case when the custom +// extension value would NOT fit within a StableIValue: a std::optional. +// +// If the std::optional has no value, it is treated as std::nullopt, +// whose StableIValue representation is from(nullptr). Otherwise, we: +// 1. unwrap the std::optional +// 2. recursively convert its value of type T to a StableIValue +// 3. allocate heap space for said StableIValue +// 4. convert the resulting StableIValue* into a StableIValue +// +// note that this allocates heap memory! which we expect to be cleaned +// up in the to_ivalue() function defined in shim_common.cpp. We +// purposefully hide this implementation detail from the user so that +// all the user needs to know is: +// +// The schema requests an optional (T?) so I must call `from` on a +// std::optional or a std::nullopt. +template +struct FromImpl> { + static StableIValue call(const std::optional& val) { + if (!val.has_value()) { + return from(std::nullopt); + } + StableIValue* heap_val = new StableIValue(from(val.value())); + return from(heap_val); + } +}; + +// Specialization for torch::stable::Tensor => StableIValue +// Returns a new owning reference of the underlying Tensor. +template <> +struct FromImpl { + static StableIValue call(const torch::stable::Tensor& val) { + AtenTensorHandle new_ath; + aoti_torch_new_tensor_handle(val.get(), &new_ath); + return from(new_ath); + } +}; + +// ============================================================================= +// TO CONVERSIONS (StableIValue -> T) +// ============================================================================= + +// Specialization for StableIValue => general copyable types (catch-all) +template +struct ToImpl { + static T call(StableIValue val) { + static_assert(std::is_trivially_copyable_v); + // T may not have a default constructor. (For example, it might be + // c10::Device.) However, std::memcpy implicitly creates a T at the + // destination. So, we can use a union to work around this lack of + // default constructor. + union Result { + Result() {} + T t; + }; + Result result; + // See NOTE[ -Wclass-memaccess ] above. + std::memcpy(reinterpret_cast(&result.t), &val, sizeof(result)); + return result.t; + } +}; + +// Specialization for StableIValue => std::nullopt_t +template <> +struct ToImpl { + static std::nullopt_t call(StableIValue val) { + // val should be equivalent to from(nullptr) + return std::nullopt; + } +}; + +// Specialization for StableIValue => std::optional, see [Handling +// std::optional] as the semantic is the same but in reverse direction as we go +// from IValue --(from_ivalue)-> StableIValue --(to)-> T in custom extension +template +struct ToImpl> { + static std::optional call(StableIValue val) { + auto sivp = to(val); + + // sivp is either nullptr or a pointer to a StableIValue + if (sivp == nullptr) { + return {}; + } + auto inner_val = to(*sivp); + + // free the memory associated with StableIValue* sivp + delete sivp; + + return std::make_optional(inner_val); + } +}; + +// Specialization for StableIValue => torch::stable::Tensor +// The resulting stable::Tensor steals ownership of the input's +// underlying AtenTensorHandle. +template <> +struct ToImpl { + static torch::stable::Tensor call(StableIValue val) { + return torch::stable::Tensor(to(val)); + } +}; + +} // namespace detail + +// Expose the partially templated class functions through single functions +template +StableIValue from(T val) { + return detail::FromImpl::call(val); +} + +template +StableIValue from(const std::optional& val) { + return detail::FromImpl>::call(val); +} + +// The below overload is used! See https://godbolt.org/z/859cshxrW +// We are suppressing the warning for versions clang12- and gcc11- +[[maybe_unused]] StableIValue from(const torch::stable::Tensor& val) { + return detail::FromImpl::call(val); +} + +template +T to(StableIValue val) { + return detail::ToImpl::call(val); +} + +// ============================================================================= +// end to helpers for converting between StableIValue and T +// ============================================================================= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class StableLibrary final { private: TorchLibraryHandle lib_; diff --git a/torch/csrc/stable/tensor.h b/torch/csrc/stable/tensor.h index 8762372a415cf..989776d426b31 100644 --- a/torch/csrc/stable/tensor.h +++ b/torch/csrc/stable/tensor.h @@ -1,4 +1,131 @@ #pragma once +<<<<<<< HEAD #include #include +======= +// TODO ASAP: THIS FILE SHOULD BE HEADER ONLY BUT ISN'T ENFORCED: +// I only need it for AOTI_TORCH_ERROR_CODE_CHECK, see #154908 +#include + +#include + +namespace torch::stable { + +using DeviceIndex = + int8_t; // this is from c10/core/Device.h and can be header only + +// The torch::stable::Tensor class is a highlevel C++ wrapper around +// the C shim Tensor APIs. We've modeled this class after TensorBase, as custom +// op kernels only really need to interact with Tensor metadata (think sizes, +// strides, device, dtype). Other functions on Tensor (like empty_like) should +// live like the ATen op that they are and exist outside of this struct. +// +// There are several goals of this class over AtenTensorHandle and +// RAIIAtenTensorHandle: +// 1. torch::stable::Tensor is a nicer UX much closer to torch::Tensor than the +// C APIs with AtenTensorHandle. Under the hood we still call to these C shim +// APIs to preserve stability. +// 2. RAIIAtenTensorHandle boils down to a uniq_ptr that forces the user to pass +// around ownership. This makes it difficult to pass one input into 2 +// different functions, e.g., doing something like c = a(t) + b(t) for +// stable::Tensor t. Thus, we use a shared_ptr here. +class Tensor { + private: + std::shared_ptr ath_; + + public: + Tensor() = delete; + + // Construct a stable::Tensor from an AtenTensorHandle (ATH) + // Steals ownership from the ATH + explicit Tensor(AtenTensorHandle ath) + : ath_(ath, [](AtenTensorHandle ath) { + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_delete_tensor_object(ath)); + }) {} + + // Copy and move constructors can be default cuz the underlying handle is a + // shared_ptr + Tensor(const Tensor& other) = default; + Tensor(Tensor&& other) noexcept = default; + + // Copy and move assignment operators can be default cuz the underlying handle + // is a shared_ptr + Tensor& operator=(const Tensor& other) = default; + Tensor& operator=(Tensor&& other) noexcept = default; + + // Destructor can be default: shared ptr has custom deletion logic + ~Tensor() = default; + + // Returns a borrowed reference to the AtenTensorHandle + AtenTensorHandle get() const { + return ath_.get(); + } + + // ============================================================================= + // C-shimified TensorBase APIs: the below APIs have the same signatures and + // semantics as their counterparts in TensorBase.h. + // ============================================================================= + + void* data_ptr() const { + void* data_ptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(ath_.get(), &data_ptr)); + return data_ptr; + } + + int64_t dim() const { + int64_t dim; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dim(ath_.get(), &dim)); + return dim; + } + + int64_t numel() const { + int64_t numel; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_numel(ath_.get(), &numel)); + return numel; + } + + // note: this is a subset of the original TensorBase API. It takes no + // arguments whereas the original API takes in a kwarg of memory format. + // Here, we assume the default contiguous memory format. + bool is_contiguous() const { + bool is_contiguous; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_is_contiguous(ath_.get(), &is_contiguous)); + return is_contiguous; + } + + int64_t stride(int64_t dim) const { + int64_t stride; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_get_stride(ath_.get(), dim, &stride)); + return stride; + } + + DeviceIndex get_device() const { + int32_t device_index; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_get_device_index(ath_.get(), &device_index)); + return static_cast(device_index); + } + + bool is_cuda() const { + int32_t device_type; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_get_device_type(ath_.get(), &device_type)); + return device_type == aoti_torch_device_type_cuda(); + } + + int64_t size(int64_t dim) const { + int64_t size; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_size(ath_.get(), dim, &size)); + return size; + } + + // ============================================================================= + // END of C-shimified TensorBase APIs + // ============================================================================= +}; + +} // namespace torch::stable +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/csrc/utils.cpp b/torch/csrc/utils.cpp index c23a41e8e64ef..a989477ed7f21 100644 --- a/torch/csrc/utils.cpp +++ b/torch/csrc/utils.cpp @@ -240,6 +240,7 @@ uint8_t storage_get(const at::Storage& self, ptrdiff_t idx) { return self_t[idx].item(); } +<<<<<<< HEAD std::string uuid_to_string(const char* uuid_bytes) { // UUIDs are a 128-bit label. CUDA/HIP and XPU store this as char[16]. // For string representation, the code here expands this to @@ -268,6 +269,8 @@ std::string uuid_to_string(const char* uuid_bytes) { (uint8_t)uuid_bytes[15]); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template class THPPointer; // NOLINTBEGIN(misc-use-internal-linkage) namespace torch::gdb { diff --git a/torch/csrc/utils.h b/torch/csrc/utils.h index 71a2b10e59046..87f81c3796b2c 100644 --- a/torch/csrc/utils.h +++ b/torch/csrc/utils.h @@ -201,5 +201,8 @@ bool maybeThrowBackCompatKeepdimWarn(char* func); void storage_fill(const at::Storage& self, uint8_t value); void storage_set(const at::Storage& self, ptrdiff_t idx, uint8_t value); uint8_t storage_get(const at::Storage& self, ptrdiff_t idx); +<<<<<<< HEAD std::string uuid_to_string(const char* uuid_bytes); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/csrc/utils/disable_torch_function.cpp b/torch/csrc/utils/disable_torch_function.cpp index 9dc6e9777a36e..432c4df9d2d23 100644 --- a/torch/csrc/utils/disable_torch_function.cpp +++ b/torch/csrc/utils/disable_torch_function.cpp @@ -5,7 +5,10 @@ #include #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace torch { static PyObject* disabled_torch_function = nullptr; @@ -220,9 +223,14 @@ PyObject* THPModule_disable_torch_function(PyObject* self, PyObject* a) { } else if (PyTuple_Check(args)) { py_args = py::reinterpret_borrow(args); } else { +<<<<<<< HEAD TORCH_CHECK_TYPE( false, fmt::format("expected List or Tuple (got {})", Py_TYPE(args)->tp_name)); +======= + throw torch::TypeError( + "expected List or Tuple (got %s)", Py_TYPE(args)->tp_name); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // These are all C-API calls so no exceptions will be raised @@ -255,9 +263,14 @@ PyObject* THPModule_disable_torch_dispatch(PyObject* self, PyObject* a) { } else if (PyTuple_Check(args)) { py_args = py::reinterpret_borrow(args); } else { +<<<<<<< HEAD TORCH_CHECK_TYPE( false, fmt::format("expected List or Tuple (got {})", Py_TYPE(args)->tp_name)); +======= + throw torch::TypeError( + "expected List or Tuple (got %s)", Py_TYPE(args)->tp_name); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // This implementation is not completely correct. The moral diff --git a/torch/csrc/utils/generated_serialization_types.h b/torch/csrc/utils/generated_serialization_types.h index bec4e283dcac8..3394cb63e0fdb 100644 --- a/torch/csrc/utils/generated_serialization_types.h +++ b/torch/csrc/utils/generated_serialization_types.h @@ -1,5 +1,9 @@ // @generated by update_schema.py +<<<<<<< HEAD // checksum<<74d07b92c36d5854263145c231553dcda15215f0460e7ace43554248c05378ec>> +======= +// checksum<<110c364974d3b0f7dcbdf6862781212bdcc7178925c43c894c336fc2b6ca6628>> +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // clang-format off #pragma once @@ -61,7 +65,10 @@ class ForwardRef { ptr_ = std::make_unique(*other.ptr_); return *this; } +<<<<<<< HEAD ~ForwardRef(); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const T& operator*() const { return *ptr_; } @@ -129,7 +136,10 @@ inline void from_json(const nlohmann::json& j, F64& f) { class AOTInductorModelPickleData; class Argument; class BufferMutationSpec; +<<<<<<< HEAD class ComplexValue; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class ConstantValue; class CustomObjArgument; class Device; @@ -150,6 +160,10 @@ class InputToParameterSpec; class InputToTensorConstantSpec; class InputTokenSpec; class LossOutputSpec; +<<<<<<< HEAD +======= +class Model; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class ModuleCallEntry; class ModuleCallSignature; class NamedArgument; @@ -158,9 +172,13 @@ class Node; class OptionalTensorArgument; class OutputSpec; class OutputTokenSpec; +<<<<<<< HEAD class ParameterMutationSpec; class PayloadConfig; class PayloadMeta; +======= +class Program; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class RangeConstraint; class SchemaVersion; class SymBool; @@ -286,8 +304,11 @@ enum class ScalarType { UINT16 = 28, FLOAT8E4M3FN = 29, FLOAT8E5M2 = 30, +<<<<<<< HEAD FLOAT8E4M3FNUZ = 31, FLOAT8E5M2FNUZ = 32, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; inline std::string_view printEnum(const ScalarType& e) { @@ -309,8 +330,11 @@ inline std::string_view printEnum(const ScalarType& e) { case ScalarType::UINT16: return "UINT16"; case ScalarType::FLOAT8E4M3FN: return "FLOAT8E4M3FN"; case ScalarType::FLOAT8E5M2: return "FLOAT8E5M2"; +<<<<<<< HEAD case ScalarType::FLOAT8E4M3FNUZ: return "FLOAT8E4M3FNUZ"; case ScalarType::FLOAT8E5M2FNUZ: return "FLOAT8E5M2FNUZ"; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) default: throw std::runtime_error("Unknown enum value"); } @@ -334,8 +358,11 @@ inline void parseEnum(std::string_view s, ScalarType& t) { if (s == "UINT16") { t = ScalarType::UINT16; return; } if (s == "FLOAT8E4M3FN") { t = ScalarType::FLOAT8E4M3FN; return; } if (s == "FLOAT8E5M2") { t = ScalarType::FLOAT8E5M2; return; } +<<<<<<< HEAD if (s == "FLOAT8E4M3FNUZ") { t = ScalarType::FLOAT8E4M3FNUZ; return; } if (s == "FLOAT8E5M2FNUZ") { t = ScalarType::FLOAT8E5M2FNUZ; return; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) throw std::runtime_error("Unknown enum value: " + std::string{s}); } @@ -1200,6 +1227,7 @@ class CustomObjArgument { friend void from_json(const nlohmann::json& nlohmann_json_j, CustomObjArgument& nlohmann_json_t); }; +<<<<<<< HEAD class ComplexValue { private: F64 real; @@ -1227,16 +1255,26 @@ class ComplexValue { friend void from_json(const nlohmann::json& nlohmann_json_j, ComplexValue& nlohmann_json_t); }; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class Argument { struct Void {}; public: enum class Tag { +<<<<<<< HEAD AS_NONE, AS_TENSOR, AS_TENSORS, AS_INT, AS_INTS, AS_FLOAT, AS_FLOATS, AS_STRING, AS_STRINGS, AS_SYM_INT, AS_SYM_INTS, AS_SCALAR_TYPE, AS_MEMORY_FORMAT, AS_LAYOUT, AS_DEVICE, AS_BOOL, AS_BOOLS, AS_SYM_BOOL, AS_SYM_BOOLS, AS_GRAPH, AS_OPTIONAL_TENSORS, AS_CUSTOM_OBJ, AS_OPERATOR, AS_SYM_FLOAT, AS_SYM_FLOATS, AS_OPTIONAL_TENSOR, AS_COMPLEX }; private: std::variant, int64_t, std::vector, F64, std::vector, std::string, std::vector, SymIntArgument, std::vector, ScalarType, MemoryFormat, Layout, Device, bool, std::vector, SymBoolArgument, std::vector, GraphArgument, std::vector, CustomObjArgument, std::string, SymFloatArgument, std::vector, OptionalTensorArgument, ComplexValue> variant_; +======= + AS_NONE, AS_TENSOR, AS_TENSORS, AS_INT, AS_INTS, AS_FLOAT, AS_FLOATS, AS_STRING, AS_STRINGS, AS_SYM_INT, AS_SYM_INTS, AS_SCALAR_TYPE, AS_MEMORY_FORMAT, AS_LAYOUT, AS_DEVICE, AS_BOOL, AS_BOOLS, AS_SYM_BOOL, AS_SYM_BOOLS, AS_GRAPH, AS_OPTIONAL_TENSORS, AS_CUSTOM_OBJ, AS_OPERATOR, AS_SYM_FLOAT, AS_SYM_FLOATS, AS_OPTIONAL_TENSOR + }; + + private: + std::variant, int64_t, std::vector, F64, std::vector, std::string, std::vector, SymIntArgument, std::vector, ScalarType, MemoryFormat, Layout, Device, bool, std::vector, SymBoolArgument, std::vector, GraphArgument, std::vector, CustomObjArgument, std::string, SymFloatArgument, std::vector, OptionalTensorArgument> variant_; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Tag tag_; public: @@ -1478,6 +1516,7 @@ class Argument { tag_ = Tag::AS_OPTIONAL_TENSOR; } +<<<<<<< HEAD const ComplexValue& get_as_complex() const { return std::get<27>(variant_); } @@ -1487,6 +1526,8 @@ class Argument { tag_ = Tag::AS_COMPLEX; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) friend void to_json(nlohmann::json& nlohmann_json_j, const Argument& nlohmann_json_t) { if (nlohmann_json_t.tag_ == Tag::AS_NONE) { @@ -1593,10 +1634,13 @@ class Argument { nlohmann_json_j["as_optional_tensor"] = nlohmann_json_t.get_as_optional_tensor(); return; } +<<<<<<< HEAD if (nlohmann_json_t.tag_ == Tag::AS_COMPLEX) { nlohmann_json_j["as_complex"] = nlohmann_json_t.get_as_complex(); return; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } friend void from_json(const nlohmann::json& nlohmann_json_j, Argument& nlohmann_json_t) { @@ -1731,11 +1775,14 @@ class Argument { nlohmann_json_t.tag_ = Tag::AS_OPTIONAL_TENSOR; return; } +<<<<<<< HEAD if (nlohmann_json_j.contains("as_complex")) { nlohmann_json_t.variant_.emplace<27>(nlohmann_json_j.at("as_complex").template get()); nlohmann_json_t.tag_ = Tag::AS_COMPLEX; return; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } }; @@ -1767,7 +1814,10 @@ inline std::string_view printEnum(const Argument::Tag& e) { case Argument::Tag::AS_SYM_FLOAT: return "AS_SYM_FLOAT"; case Argument::Tag::AS_SYM_FLOATS: return "AS_SYM_FLOATS"; case Argument::Tag::AS_OPTIONAL_TENSOR: return "AS_OPTIONAL_TENSOR"; +<<<<<<< HEAD case Argument::Tag::AS_COMPLEX: return "AS_COMPLEX"; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) default: throw std::runtime_error("Unknown enum value"); } @@ -1800,7 +1850,10 @@ inline void parseEnum(std::string_view s, Argument::Tag& t) { if (s == "AS_SYM_FLOAT") { t = Argument::Tag::AS_SYM_FLOAT; return; } if (s == "AS_SYM_FLOATS") { t = Argument::Tag::AS_SYM_FLOATS; return; } if (s == "AS_OPTIONAL_TENSOR") { t = Argument::Tag::AS_OPTIONAL_TENSOR; return; } +<<<<<<< HEAD if (s == "AS_COMPLEX") { t = Argument::Tag::AS_COMPLEX; return; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) throw std::runtime_error("Unknown enum value: " + std::string{s}); } @@ -2544,6 +2597,7 @@ class BufferMutationSpec { friend void from_json(const nlohmann::json& nlohmann_json_j, BufferMutationSpec& nlohmann_json_t); }; +<<<<<<< HEAD class ParameterMutationSpec { private: TensorArgument arg; @@ -2571,6 +2625,8 @@ class ParameterMutationSpec { friend void from_json(const nlohmann::json& nlohmann_json_j, ParameterMutationSpec& nlohmann_json_t); }; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class GradientToParameterSpec { private: TensorArgument arg; @@ -2675,11 +2731,19 @@ class OutputSpec { public: enum class Tag { +<<<<<<< HEAD USER_OUTPUT, LOSS_OUTPUT, BUFFER_MUTATION, GRADIENT_TO_PARAMETER, GRADIENT_TO_USER_INPUT, USER_INPUT_MUTATION, TOKEN, PARAMETER_MUTATION }; private: std::variant variant_; +======= + USER_OUTPUT, LOSS_OUTPUT, BUFFER_MUTATION, GRADIENT_TO_PARAMETER, GRADIENT_TO_USER_INPUT, USER_INPUT_MUTATION, TOKEN + }; + + private: + std::variant variant_; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Tag tag_; public: @@ -2750,6 +2814,7 @@ class OutputSpec { tag_ = Tag::TOKEN; } +<<<<<<< HEAD const ParameterMutationSpec& get_parameter_mutation() const { return std::get<8>(variant_); } @@ -2759,6 +2824,8 @@ class OutputSpec { tag_ = Tag::PARAMETER_MUTATION; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) friend void to_json(nlohmann::json& nlohmann_json_j, const OutputSpec& nlohmann_json_t) { if (nlohmann_json_t.tag_ == Tag::USER_OUTPUT) { @@ -2789,10 +2856,13 @@ class OutputSpec { nlohmann_json_j["token"] = nlohmann_json_t.get_token(); return; } +<<<<<<< HEAD if (nlohmann_json_t.tag_ == Tag::PARAMETER_MUTATION) { nlohmann_json_j["parameter_mutation"] = nlohmann_json_t.get_parameter_mutation(); return; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } friend void from_json(const nlohmann::json& nlohmann_json_j, OutputSpec& nlohmann_json_t) { @@ -2832,11 +2902,14 @@ class OutputSpec { nlohmann_json_t.tag_ = Tag::TOKEN; return; } +<<<<<<< HEAD if (nlohmann_json_j.contains("parameter_mutation")) { nlohmann_json_t.variant_.emplace<8>(nlohmann_json_j.at("parameter_mutation").template get()); nlohmann_json_t.tag_ = Tag::PARAMETER_MUTATION; return; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } }; @@ -2849,7 +2922,10 @@ inline std::string_view printEnum(const OutputSpec::Tag& e) { case OutputSpec::Tag::GRADIENT_TO_USER_INPUT: return "GRADIENT_TO_USER_INPUT"; case OutputSpec::Tag::USER_INPUT_MUTATION: return "USER_INPUT_MUTATION"; case OutputSpec::Tag::TOKEN: return "TOKEN"; +<<<<<<< HEAD case OutputSpec::Tag::PARAMETER_MUTATION: return "PARAMETER_MUTATION"; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) default: throw std::runtime_error("Unknown enum value"); } @@ -2863,7 +2939,10 @@ inline void parseEnum(std::string_view s, OutputSpec::Tag& t) { if (s == "GRADIENT_TO_USER_INPUT") { t = OutputSpec::Tag::GRADIENT_TO_USER_INPUT; return; } if (s == "USER_INPUT_MUTATION") { t = OutputSpec::Tag::USER_INPUT_MUTATION; return; } if (s == "TOKEN") { t = OutputSpec::Tag::TOKEN; return; } +<<<<<<< HEAD if (s == "PARAMETER_MUTATION") { t = OutputSpec::Tag::PARAMETER_MUTATION; return; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) throw std::runtime_error("Unknown enum value: " + std::string{s}); } @@ -3110,7 +3189,10 @@ class ExportedProgram { SchemaVersion schema_version; std::vector verifiers = {}; std::string torch_version = "<=2.4"; +<<<<<<< HEAD std::vector guards_code = {}; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) public: @@ -3162,6 +3244,7 @@ class ExportedProgram { torch_version = std::move(def); } +<<<<<<< HEAD const std::vector& get_guards_code() const { return guards_code; } @@ -3170,10 +3253,13 @@ class ExportedProgram { guards_code = std::move(def); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) friend void to_json(nlohmann::json& nlohmann_json_j, const ExportedProgram& nlohmann_json_t); friend void from_json(const nlohmann::json& nlohmann_json_j, ExportedProgram& nlohmann_json_t); }; +<<<<<<< HEAD class PayloadMeta { private: std::string path_name; @@ -3235,6 +3321,87 @@ class PayloadConfig { friend void to_json(nlohmann::json& nlohmann_json_j, const PayloadConfig& nlohmann_json_t); friend void from_json(const nlohmann::json& nlohmann_json_j, PayloadConfig& nlohmann_json_t); +======= +class Program { + private: + std::unordered_map methods; + + public: + + const std::unordered_map& get_methods() const { + return methods; + } + + void set_methods(std::unordered_map def) { + methods = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const Program& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, Program& nlohmann_json_t); +}; + +class Model { + private: + std::string name; + std::unordered_map tensorPaths; + Program program; + std::unordered_map delegates; + std::unordered_map deviceAllocationMap; + std::unordered_map constantPaths; + + public: + + const std::string& get_name() const { + return name; + } + + void set_name(std::string def) { + name = std::move(def); + } + + const std::unordered_map& get_tensorPaths() const { + return tensorPaths; + } + + void set_tensorPaths(std::unordered_map def) { + tensorPaths = std::move(def); + } + + const Program& get_program() const { + return program; + } + + void set_program(Program def) { + program = std::move(def); + } + + const std::unordered_map& get_delegates() const { + return delegates; + } + + void set_delegates(std::unordered_map def) { + delegates = std::move(def); + } + + const std::unordered_map& get_deviceAllocationMap() const { + return deviceAllocationMap; + } + + void set_deviceAllocationMap(std::unordered_map def) { + deviceAllocationMap = std::move(def); + } + + const std::unordered_map& get_constantPaths() const { + return constantPaths; + } + + void set_constantPaths(std::unordered_map def) { + constantPaths = std::move(def); + } + + friend void to_json(nlohmann::json& nlohmann_json_j, const Model& nlohmann_json_t); + friend void from_json(const nlohmann::json& nlohmann_json_j, Model& nlohmann_json_t); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; class AOTInductorModelPickleData { @@ -3375,6 +3542,7 @@ inline void from_json(const nlohmann::json& nlohmann_json_j, BufferMutationSpec& nlohmann_json_t.buffer_name = nlohmann_json_j.value("buffer_name", nlohmann_json_default_obj.buffer_name); } +<<<<<<< HEAD inline void to_json(nlohmann::json& nlohmann_json_j, const ComplexValue& nlohmann_json_t) { nlohmann_json_j["real"] = nlohmann_json_t.real; nlohmann_json_j["imag"] = nlohmann_json_t.imag; @@ -3386,6 +3554,8 @@ inline void from_json(const nlohmann::json& nlohmann_json_j, ComplexValue& nlohm nlohmann_json_t.imag = nlohmann_json_j.value("imag", nlohmann_json_default_obj.imag); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inline void to_json(nlohmann::json& nlohmann_json_j, const CustomObjArgument& nlohmann_json_t) { nlohmann_json_j["name"] = nlohmann_json_t.name; nlohmann_json_j["class_fqn"] = nlohmann_json_t.class_fqn; @@ -3415,7 +3585,10 @@ inline void to_json(nlohmann::json& nlohmann_json_j, const ExportedProgram& nloh nlohmann_json_j["schema_version"] = nlohmann_json_t.schema_version; nlohmann_json_j["verifiers"] = nlohmann_json_t.verifiers; nlohmann_json_j["torch_version"] = nlohmann_json_t.torch_version; +<<<<<<< HEAD nlohmann_json_j["guards_code"] = nlohmann_json_t.guards_code; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } inline void from_json(const nlohmann::json& nlohmann_json_j, ExportedProgram& nlohmann_json_t) { @@ -3426,7 +3599,10 @@ inline void from_json(const nlohmann::json& nlohmann_json_j, ExportedProgram& nl nlohmann_json_t.schema_version = nlohmann_json_j.value("schema_version", nlohmann_json_default_obj.schema_version); nlohmann_json_t.verifiers = nlohmann_json_j.value("verifiers", nlohmann_json_default_obj.verifiers); nlohmann_json_t.torch_version = nlohmann_json_j.value("torch_version", nlohmann_json_default_obj.torch_version); +<<<<<<< HEAD nlohmann_json_t.guards_code = nlohmann_json_j.value("guards_code", nlohmann_json_default_obj.guards_code); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } inline void to_json(nlohmann::json& nlohmann_json_j, const ExternKernelNode& nlohmann_json_t) { @@ -3610,6 +3786,28 @@ inline void from_json(const nlohmann::json& nlohmann_json_j, LossOutputSpec& nlo nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); } +<<<<<<< HEAD +======= +inline void to_json(nlohmann::json& nlohmann_json_j, const Model& nlohmann_json_t) { + nlohmann_json_j["name"] = nlohmann_json_t.name; + nlohmann_json_j["tensorPaths"] = nlohmann_json_t.tensorPaths; + nlohmann_json_j["program"] = nlohmann_json_t.program; + nlohmann_json_j["delegates"] = nlohmann_json_t.delegates; + nlohmann_json_j["deviceAllocationMap"] = nlohmann_json_t.deviceAllocationMap; + nlohmann_json_j["constantPaths"] = nlohmann_json_t.constantPaths; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, Model& nlohmann_json_t) { + Model nlohmann_json_default_obj; + nlohmann_json_t.name = nlohmann_json_j.value("name", nlohmann_json_default_obj.name); + nlohmann_json_t.tensorPaths = nlohmann_json_j.value("tensorPaths", nlohmann_json_default_obj.tensorPaths); + nlohmann_json_t.program = nlohmann_json_j.value("program", nlohmann_json_default_obj.program); + nlohmann_json_t.delegates = nlohmann_json_j.value("delegates", nlohmann_json_default_obj.delegates); + nlohmann_json_t.deviceAllocationMap = nlohmann_json_j.value("deviceAllocationMap", nlohmann_json_default_obj.deviceAllocationMap); + nlohmann_json_t.constantPaths = nlohmann_json_j.value("constantPaths", nlohmann_json_default_obj.constantPaths); +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inline void to_json(nlohmann::json& nlohmann_json_j, const ModuleCallEntry& nlohmann_json_t) { nlohmann_json_j["fqn"] = nlohmann_json_t.fqn; nlohmann_json_j["signature"] = nlohmann_json_t.signature; @@ -3686,6 +3884,7 @@ inline void from_json(const nlohmann::json& nlohmann_json_j, OutputTokenSpec& nl nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); } +<<<<<<< HEAD inline void to_json(nlohmann::json& nlohmann_json_j, const ParameterMutationSpec& nlohmann_json_t) { nlohmann_json_j["arg"] = nlohmann_json_t.arg; nlohmann_json_j["parameter_name"] = nlohmann_json_t.parameter_name; @@ -3719,6 +3918,15 @@ inline void from_json(const nlohmann::json& nlohmann_json_j, PayloadMeta& nlohma nlohmann_json_t.is_param = nlohmann_json_j.value("is_param", nlohmann_json_default_obj.is_param); nlohmann_json_t.use_pickle = nlohmann_json_j.value("use_pickle", nlohmann_json_default_obj.use_pickle); nlohmann_json_t.tensor_meta = nlohmann_json_j.value("tensor_meta", nlohmann_json_default_obj.tensor_meta); +======= +inline void to_json(nlohmann::json& nlohmann_json_j, const Program& nlohmann_json_t) { + nlohmann_json_j["methods"] = nlohmann_json_t.methods; +} + +inline void from_json(const nlohmann::json& nlohmann_json_j, Program& nlohmann_json_t) { + Program nlohmann_json_default_obj; + nlohmann_json_t.methods = nlohmann_json_j.value("methods", nlohmann_json_default_obj.methods); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } inline void to_json(nlohmann::json& nlohmann_json_j, const RangeConstraint& nlohmann_json_t) { @@ -3825,7 +4033,10 @@ inline void from_json(const nlohmann::json& nlohmann_json_j, UserOutputSpec& nlo template ForwardRef::ForwardRef(ForwardRef&&) = default; template ForwardRef& ForwardRef::operator=(ForwardRef&&) = default; +<<<<<<< HEAD template ForwardRef::~ForwardRef() = default; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace _export } // namespace torch diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index 613657e03b926..23d9d52f38687 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -303,10 +303,13 @@ static py::object maybe_get_registered_torch_dispatch_rule( return result; } +<<<<<<< HEAD // NB: Invariant: if you run this function, you MUST test if the returned // py::object is nullptr, as this will occur WITHOUT error condition being set. // And if an error happens, this function is responsible for throwing a C++ // error. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static py::object dispatch_on_subclass( PyObject* args, PyObject* kwargs, @@ -386,7 +389,10 @@ static py::object dispatch_on_subclass( break; } } +<<<<<<< HEAD // NB: PyErr_Occurred is NOT set here, this means NO dispatch happened +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ret; } @@ -588,6 +594,7 @@ auto handle_torch_function_no_python_arg_parser( } if (ret.ptr() == nullptr) { +<<<<<<< HEAD // We didn't successfully dispatch anything, this should be impossible TORCH_INTERNAL_ASSERT( 0, @@ -597,6 +604,11 @@ auto handle_torch_function_no_python_arg_parser( overloaded_args, ", is_mode_active = ", is_mode_active()); +======= + // if an exception occurred in a user's implementation of + // __torch_function__, throw it + throw python_error(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else if (ret.ptr() == Py_NotImplemented) { // all __torch_function__ implementations in overloaded_args // returned NotImplemented, so we raise a TypeError. @@ -677,6 +689,7 @@ auto handle_torch_function_indexing( auto size = PyTuple_GET_SIZE(index_tup.ptr()); for (auto i : c10::irange(size)) { auto* obj = PyTuple_GetItem(index_tup.ptr(), i); +<<<<<<< HEAD auto r = is_tensor_and_append_overloaded(obj, &overridable_args); if (!r && PySequence_Check(obj)) { auto inner_size = PySequence_Length(obj); @@ -693,6 +706,9 @@ auto handle_torch_function_indexing( } } } +======= + is_tensor_and_append_overloaded(obj, &overridable_args); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } if (val != nullptr) { is_tensor_and_append_overloaded(val, &overridable_args); @@ -819,15 +835,20 @@ bool is_tensor_and_append_overloaded( return false; } +<<<<<<< HEAD static bool is_scalar_list( PyObject* obj, std::vector* overloaded_args = nullptr) { +======= +static bool is_scalar_list(PyObject* obj) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto tuple = six::isTuple(obj); if (!(tuple || PyList_Check(obj))) { return false; } // NOLINTNEXTLINE(bugprone-branch-clone) const auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj); +<<<<<<< HEAD bool has_torch_func = false; for (const auto idx : c10::irange(size)) { @@ -842,6 +863,12 @@ static bool is_scalar_list( } if (!THPUtils_checkScalar(iobj) && !has_torch_func) { +======= + for (const auto idx : c10::irange(size)) { + PyObject* iobj = + tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx); + if (!THPUtils_checkScalar(iobj)) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return false; } } @@ -891,9 +918,13 @@ static bool is_float_or_symfloat(PyObject* obj) { return false; } +<<<<<<< HEAD static bool is_float_or_complex_list( PyObject* obj, std::vector* overloaded_args = nullptr) { +======= +static bool is_float_or_complex_list(PyObject* obj) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto tuple = six::isTuple(obj); if (!(tuple || PyList_Check(obj))) { return false; @@ -901,6 +932,7 @@ static bool is_float_or_complex_list( // NOLINTNEXTLINE(bugprone-branch-clone) const auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj); +<<<<<<< HEAD bool has_torch_func = false; for (long idx = 0; idx < size; idx++) { @@ -920,6 +952,12 @@ static bool is_float_or_complex_list( !has_torch_func) { return false; } +======= + if (size > 0) { + PyObject* iobj = tuple ? PyTuple_GET_ITEM(obj, 0) : PyList_GET_ITEM(obj, 0); + if (!is_float_or_symfloat(iobj) && !PyComplex_Check(iobj)) { + return false; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } @@ -927,6 +965,7 @@ static bool is_float_or_complex_list( } static bool is_int_or_symint(PyObject* obj) { +<<<<<<< HEAD // Call checkLong first so that actual ints go fast. if (THPUtils_checkLong(obj)) { return true; @@ -935,6 +974,12 @@ static bool is_int_or_symint(PyObject* obj) { // THPUtils_checkIndex may call __index__ or __int__ // which may have side effects if obj is a symint node // so we do `is_symint` check first +======= + // THPUtils_checkIndex may call __index__ or __int__ + // which may have side effects if obj is a symint node + // so we do `is_symint` check first + // TODO: maybe we should be using checkLong here? +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (torch::is_symint(py::handle(obj))) { return true; } @@ -964,6 +1009,7 @@ static bool is_int_or_symint(PyObject* obj) { static bool is_int_or_symint_list( PyObject* obj, int broadcast_size, +<<<<<<< HEAD int64_t* failed_idx = nullptr, std::vector* overloaded_args = nullptr) { const bool is_tuple = PyTuple_Check(obj); @@ -1009,6 +1055,28 @@ static bool is_int_or_symint_list( } return true; +======= + int64_t* failed_idx = nullptr) { + if (PyTuple_Check(obj) || PyList_Check(obj)) { + if (PySequence_Size(obj) == 0) { + return true; + } + auto item = py::reinterpret_steal(PySequence_GetItem(obj, 0)); + + if (is_int_or_symint(item.ptr())) { + return true; + } + + // NOTE: JIT tracer allows arbitrary scalar tensors to act as ints + // in an intlist argument. Even float or complex scalar tensors. + bool r = + (jit::tracer::isTracing() && THPVariable_Check(item.ptr()) && + THPVariable_Unpack(item.ptr()).sizes().empty()); + if (!r && failed_idx != nullptr) { + *failed_idx = 0; + } + return r; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // if a size is specified (e.g. IntArrayRef[2]) we also allow passing a single @@ -1022,6 +1090,7 @@ auto FunctionParameter::check( std::vector& overloaded_args, int argnum, int64_t* failed_idx) -> bool { +<<<<<<< HEAD if (_check(obj, overloaded_args, argnum, failed_idx)) { return true; } @@ -1043,6 +1112,8 @@ auto FunctionParameter::_check( std::vector& overloaded_args, int argnum, int64_t* failed_idx) -> bool { +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) switch (type_) { case ParameterType::TENSOR: { if (is_tensor_and_append_overloaded(obj, &overloaded_args)) { @@ -1108,7 +1179,11 @@ auto FunctionParameter::_check( obj, &overloaded_args, argnum, true /* throw_error */); } case ParameterType::FLOAT_LIST: +<<<<<<< HEAD return is_float_or_complex_list(obj, &overloaded_args); +======= + return is_float_or_complex_list(obj); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) case ParameterType::GENERATOR: return THPGenerator_Check(obj); case ParameterType::BOOL: @@ -1118,7 +1193,19 @@ auto FunctionParameter::_check( case ParameterType::PYOBJECT: return true; case ParameterType::SCALARTYPE: +<<<<<<< HEAD return THPDtype_Check(obj) || THPPythonScalarType_Check(obj); +======= + if (THPDtype_Check(obj) || THPPythonScalarType_Check(obj)) { + return true; + } + if (check_has_torch_function(obj, /*ignore_mode*/ true)) { + // tensor subclasses and unrelated objects with __torch_function__ + append_overloaded_arg(&overloaded_args, obj, /*obj_is_type*/ false); + return true; + } + return false; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) case ParameterType::LAYOUT: return THPLayout_Check(obj); case ParameterType::MEMORY_FORMAT: @@ -1135,13 +1222,21 @@ auto FunctionParameter::_check( case ParameterType::STRING: return THPUtils_checkString(obj); case ParameterType::SCALAR_LIST: +<<<<<<< HEAD return is_scalar_list(obj, &overloaded_args); +======= + return is_scalar_list(obj); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) case ParameterType::SYM_INT: return is_int_or_symint(obj); // Allow SymInt where int is expected; we'll guard in this case case ParameterType::INT_LIST: case ParameterType::SYM_INT_LIST: +<<<<<<< HEAD return is_int_or_symint_list(obj, size, failed_idx, &overloaded_args); +======= + return is_int_or_symint_list(obj, size, failed_idx); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) case ParameterType::DISPATCH_KEY_SET: return py::isinstance(py::handle(obj)); default: @@ -1514,6 +1609,7 @@ std::string FunctionSignature::toString() const { const auto min_args = signature.min_args; const long nargs_ = nargs; if (min_args != max_pos_args) { +<<<<<<< HEAD TORCH_CHECK_TYPE( false, fmt::format( @@ -1532,6 +1628,22 @@ std::string FunctionSignature::toString() const { max_pos_args == 1 ? "" : "s", nargs_, nargs == 1 ? "was" : "were")); +======= + throw TypeError( + "%s() takes from %zu to %zu positional arguments but %ld were given", + signature.name.c_str(), + min_args, + max_pos_args, + nargs_); + } + throw TypeError( + "%s() takes %zu positional argument%s but %ld %s given", + signature.name.c_str(), + max_pos_args, + max_pos_args == 1 ? "" : "s", + nargs_, + nargs == 1 ? "was" : "were"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } [[noreturn]] static void missing_args( @@ -1551,6 +1663,7 @@ std::string FunctionSignature::toString() const { } } +<<<<<<< HEAD TORCH_CHECK_TYPE( false, fmt::format( @@ -1559,6 +1672,14 @@ std::string FunctionSignature::toString() const { num_missing, num_missing == 1 ? "s" : "", ss.str())); +======= + throw TypeError( + "%s() missing %d required positional argument%s: %s", + signature.name.c_str(), + num_missing, + num_missing == 1 ? "s" : "", + ss.str().c_str()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } static Py_ssize_t find_param(FunctionSignature& signature, PyObject* name) { @@ -1587,11 +1708,16 @@ static Py_ssize_t find_param(FunctionSignature& signature, PyObject* name) { // accessible within this thread. while (PyDict_Next(kwargs, &pos, &key, &value)) { if (!THPUtils_checkString(key)) { +<<<<<<< HEAD TORCH_CHECK_TYPE(false, "keywords must be strings"); +======= + throw TypeError("keywords must be strings"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } auto param_idx = find_param(signature, key); if (param_idx < 0) { +<<<<<<< HEAD TORCH_CHECK_TYPE( false, fmt::format( @@ -1607,11 +1733,28 @@ static Py_ssize_t find_param(FunctionSignature& signature, PyObject* name) { "{}() got multiple values for argument '{}'", signature.name, THPUtils_unpackString(key))); +======= + throw TypeError( + "%s() got an unexpected keyword argument '%s'", + signature.name.c_str(), + THPUtils_unpackString(key).c_str()); + } + + if (param_idx < num_pos_args) { + throw TypeError( + "%s() got multiple values for argument '%s'", + signature.name.c_str(), + THPUtils_unpackString(key).c_str()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } // this should never be hit +<<<<<<< HEAD TORCH_CHECK_TYPE(false, "invalid keyword arguments"); +======= + throw TypeError("invalid keyword arguments"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } bool FunctionSignature::parse( @@ -1689,8 +1832,12 @@ bool FunctionSignature::parse( // should avoid having complex signatures that make use of it... } else if ( varargs_eligible && +<<<<<<< HEAD (is_int_or_symint_list( args, param.size, &failed_idx, &overloaded_args))) { +======= + (is_int_or_symint_list(args, param.size, &failed_idx))) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // take all positional arguments as this parameter // e.g. permute(1, 2, 3) -> permute((1, 2, 3)) dst[i++] = args; @@ -1699,6 +1846,7 @@ bool FunctionSignature::parse( } else if (raise_exception) { if (is_kwd) { // foo(): argument 'other' must be str, not int +<<<<<<< HEAD TORCH_CHECK_TYPE( false, fmt::format( @@ -1707,6 +1855,14 @@ bool FunctionSignature::parse( param.name, param.type_name(), Py_TYPE(obj)->tp_name)); +======= + throw TypeError( + "%s(): argument '%s' must be %s, not %s", + name.c_str(), + param.name.c_str(), + param.type_name().c_str(), + Py_TYPE(obj)->tp_name); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else { // foo(): argument 'other' (position 2) must be str, not int if (failed_idx != -1) { @@ -1715,6 +1871,7 @@ bool FunctionSignature::parse( obj = args; } TORCH_INTERNAL_ASSERT(failed_idx < PySequence_Size(obj)); +<<<<<<< HEAD TORCH_CHECK_TYPE( false, fmt::format( @@ -1738,6 +1895,27 @@ bool FunctionSignature::parse( arg_pos + 1, param.type_name(), Py_TYPE(obj)->tp_name)); +======= + throw TypeError( + "%s(): argument '%s' (position %ld) must be %s, but found element of type %s at pos %ld", + name.c_str(), + param.name.c_str(), + static_cast(arg_pos + 1), + param.type_name().c_str(), + Py_TYPE(py::reinterpret_steal( + PySequence_GetItem(obj, failed_idx)) + .ptr()) + ->tp_name, + static_cast(failed_idx)); + } + throw TypeError( + "%s(): argument '%s' (position %ld) must be %s, not %s", + name.c_str(), + param.name.c_str(), + static_cast(arg_pos + 1), + param.type_name().c_str(), + Py_TYPE(obj)->tp_name); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } else { return false; @@ -1859,7 +2037,11 @@ void PythonArgParser::print_error( auto options = get_signatures(); auto msg = torch::format_invalid_args(args, kwargs, function_name + "()", options); +<<<<<<< HEAD TORCH_CHECK_TYPE(false, msg); +======= + throw TypeError("%s", msg.c_str()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } std::vector PythonArgParser::get_signatures() const { @@ -1886,7 +2068,25 @@ at::Tensor PythonArgs::tensor_slow(int i) { if (PyBool_Check(obj)) { scalar = at::Scalar(THPUtils_unpackBool(obj)); } else if (THPUtils_checkLong(obj)) { +<<<<<<< HEAD scalar = THPUtils_unpackInteger(obj); +======= + int overflow = -1; + long long value = PyLong_AsLongLongAndOverflow(obj, &overflow); + if (value == -1 && PyErr_Occurred()) { + throw python_error(); + } + if (overflow != 0) { + // try unsigned + unsigned long long value = PyLong_AsUnsignedLongLong(obj); + if (value == static_cast(-1) && PyErr_Occurred()) { + throw python_error(); + } + scalar = at::Scalar(static_cast(value)); + } else { + scalar = at::Scalar(static_cast(value)); + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else if (PyComplex_Check(obj)) { scalar = at::Scalar(THPUtils_unpackComplexDouble(obj)); } else if (THPUtils_checkDouble(obj)) { @@ -1912,12 +2112,17 @@ at::Tensor PythonArgs::tensor_slow(int i) { // a test for Py_None here; instead, you need to mark the argument // as *allowing none*; you can do this by writing 'Tensor?' instead // of 'Tensor' in the ATen metadata. +<<<<<<< HEAD TORCH_CHECK_TYPE( false, fmt::format( "expected Tensor as argument {}, but got {}", i, Py_TYPE(obj)->tp_name)); +======= + throw TypeError( + "expected Tensor as argument %d, but got %s", i, Py_TYPE(obj)->tp_name); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } at::AutoDispatchBelowADInplaceOrView guard; // TODO: remove at::tracer::impl::NoTracerDispatchMode tracer_guard; diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h index a81f861ae9030..161cc083cb27d 100644 --- a/torch/csrc/utils/python_arg_parser.h +++ b/torch/csrc/utils/python_arg_parser.h @@ -39,7 +39,10 @@ // Scalar and Tensor, UNLESS they require grad (in which case // they only bind to Tensor). +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include @@ -322,12 +325,15 @@ struct FunctionParameter { int argnum, int64_t* failed_idx = nullptr); +<<<<<<< HEAD bool _check( PyObject* obj, std::vector& overloaded_args, int argnum, int64_t* failed_idx = nullptr); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void set_default_str(const std::string& str); TORCH_PYTHON_API std::string type_name() const; @@ -497,9 +503,13 @@ inline std::array PythonArgs::tensorlist_n(int i) { // NOLINTNEXTLINE(bugprone-branch-clone) auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get()); if (size != N) { +<<<<<<< HEAD TORCH_CHECK_TYPE( false, fmt::format("expected tuple of {} elements but got {}", N, size)); +======= + throw TypeError("expected tuple of %d elements but got %d", N, (int)size); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } for (const auto idx : c10::irange(size)) { PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx) @@ -537,6 +547,7 @@ inline void throw_intlist_exception( ? e.what() : std::string("type must be ") + args->signature.params[i].type_name() + ",but got " + Py_TYPE(obj)->tp_name; +<<<<<<< HEAD TORCH_CHECK_TYPE( false, fmt::format( @@ -545,6 +556,14 @@ inline void throw_intlist_exception( args->signature.params[i].name, idx + 1, error)); +======= + throw TypeError( + "%s(): argument '%s' failed to unpack the object at pos %zu with error \"%s\"", + args->signature.name.c_str(), + args->signature.params[i].name.c_str(), + idx + 1, + error.c_str()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } inline std::vector PythonArgs::symintlist(int i) { @@ -723,6 +742,7 @@ inline std::vector PythonArgs::getDoublelist(int i) { res[idx] = THPUtils_unpackDouble(obj); } } catch (const std::exception&) { +<<<<<<< HEAD TORCH_CHECK_TYPE( false, fmt::format( @@ -732,6 +752,15 @@ inline std::vector PythonArgs::getDoublelist(int i) { signature.params[i].type_name(), Py_TYPE(obj)->tp_name, idx + 1)); +======= + throw TypeError( + "%s(): argument '%s' must be %s, but found element of type %s at pos %zu", + signature.name.c_str(), + signature.params[i].name.c_str(), + signature.params[i].type_name().c_str(), + Py_TYPE(obj)->tp_name, + idx + 1); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } return res; @@ -1059,6 +1088,7 @@ inline double PythonArgs::toDouble(int i) { } inline bool PythonArgs::toBool(int i) { +<<<<<<< HEAD if (!args[i]) { return signature.params[i].default_bool; } @@ -1068,11 +1098,19 @@ inline bool PythonArgs::toBool(int i) { if (args[i] == Py_False) { return false; } +======= + if (!args[i]) + return signature.params[i].default_bool; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (torch::is_symbool(py::handle(args[i]))) { return py::cast(py::handle(args[i])) .guard_bool(__FILE__, __LINE__); } +<<<<<<< HEAD return false; +======= + return args[i] == Py_True; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } inline double PythonArgs::toDoubleWithDefault(int i, double default_double) { @@ -1139,10 +1177,15 @@ inline c10::Stream PythonArgs::stream(int i) { return c10::Stream( c10::Stream::Default::DEFAULT, c10::Device(c10::DeviceType::CPU, -1)); if (!THPStream_Check(args[i])) { +<<<<<<< HEAD TORCH_CHECK_TYPE( false, fmt::format( "expected Stream object. Got '{}'", Py_TYPE(args[i])->tp_name)); +======= + throw TypeError( + "expected Stream object. Got '%s'", Py_TYPE(args[i])->tp_name); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } return c10::Stream::unpack3( ((THPStream*)args[i])->stream_id, diff --git a/torch/csrc/utils/python_compat.h b/torch/csrc/utils/python_compat.h index 16292e4fd0308..52ff18f2bea36 100644 --- a/torch/csrc/utils/python_compat.h +++ b/torch/csrc/utils/python_compat.h @@ -13,7 +13,10 @@ extern "C" { #define IS_PYTHON_3_12_PLUS PY_VERSION_HEX >= 0x030C0000 #define IS_PYTHON_3_13_PLUS PY_VERSION_HEX >= 0x030D0000 #define IS_PYTHON_3_14_PLUS PY_VERSION_HEX >= 0x030E0000 +<<<<<<< HEAD #define IS_PYTHON_3_15_PLUS PY_VERSION_HEX >= 0x030F0000 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) static inline int PyCode_GetNCellvars(PyCodeObject* code) { // gh-26364 added co_ncellvars to Python 3.11.0rc1 diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index 07fa4ea5e1dd7..18e6707046864 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -2,11 +2,17 @@ #include #include +<<<<<<< HEAD #include #include #include #include #include +======= +#include +#include +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include @@ -27,8 +33,11 @@ #include #include +<<<<<<< HEAD #include #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include @@ -36,10 +45,13 @@ namespace py = pybind11; namespace torch::impl::dispatch { +<<<<<<< HEAD // Global storage for leaked Python filenames to ensure they remain valid // for the lifetime of Library objects static std::vector leaked_python_filenames_; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // NB: I'd like to index this on OperatorHandle, but I can't, as I can't // guarantee that the main interpreter has finish doing all registrations before // the other interpreters start banging on it @@ -194,6 +206,18 @@ class PythonKernelHolder : public c10::OperatorKernel { auto arguments = torch::jit::pop(*stack, op.schema().arguments().size()); py::gil_scoped_acquire g; +<<<<<<< HEAD +======= + // Jan 2024: We're slated to get rid of multipy, // codespell:ignore multipy + // so stop forcing hermetic mode unconditionally in all situations when + // you're using multipy. // codespell:ignore multipy + // Eventually just delete this entirely. (Note that you may break + // multipy anyway this way with dispatcher // codespell:ignore multipy + // registered functions that require hermetic to be off.) +#if defined(USE_DEPLOY) + EnableHermeticPyObject g2; +#endif +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments); auto func = py::reinterpret_borrow(func_.ptr(getPyInterpreter())); @@ -216,10 +240,19 @@ class PythonKernelHolder : public c10::OperatorKernel { } }; +<<<<<<< HEAD // @todo sahanp: Afait only register is used in the codebase. This can be // removed / simplified static torch::_RegisterOrVerify register_or_verify() { return torch::_RegisterOrVerify::REGISTER; +======= +static torch::_RegisterOrVerify register_or_verify() { + if (isMainPyInterpreter()) { + return torch::_RegisterOrVerify::REGISTER; + } else { + return torch::_RegisterOrVerify::VERIFY; + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } static py::object ophandle_call_boxed( @@ -292,6 +325,10 @@ void initDispatchBindings(PyObject* module) { .def( "reset", [](const py::object& self) { +<<<<<<< HEAD +======= + TORCH_INTERNAL_ASSERT(isMainPyInterpreter()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.cast().reset(); return; }, @@ -301,6 +338,10 @@ void initDispatchBindings(PyObject* module) { .def( "def_", [](py::object self, const char* schema, const char* alias) { +<<<<<<< HEAD +======= + TORCH_INTERNAL_ASSERT(isMainPyInterpreter()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.cast().def( torch::schema(schema, parseAliasAnalysisKind(alias))); return self; @@ -314,6 +355,10 @@ void initDispatchBindings(PyObject* module) { .def( "def_legacy", [](py::object self, const char* schema) { +<<<<<<< HEAD +======= + TORCH_INTERNAL_ASSERT(isMainPyInterpreter()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.cast().def(torch::jit::parseSchema(schema)); return self; }, @@ -333,6 +378,10 @@ void initDispatchBindings(PyObject* module) { const char* name, const char* dispatch, const char* debug) { +<<<<<<< HEAD +======= + TORCH_INTERNAL_ASSERT(isMainPyInterpreter()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.cast().def( name, dispatch_str(dispatch, [](const at::Tensor& a) { return a; @@ -350,6 +399,10 @@ void initDispatchBindings(PyObject* module) { const char* dispatch, const char* alias, const char* debug) { +<<<<<<< HEAD +======= + TORCH_INTERNAL_ASSERT(isMainPyInterpreter()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.cast().def( torch::schema(schema, parseAliasAnalysisKind(alias)), dispatch_str(dispatch, [](const at::Tensor& a) { @@ -370,6 +423,10 @@ void initDispatchBindings(PyObject* module) { const char* name, const char* dispatch, const char* debug) { +<<<<<<< HEAD +======= + TORCH_INTERNAL_ASSERT(isMainPyInterpreter()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.cast().impl( name, dispatch_str(dispatch, [](const at::Tensor& a) { return a; @@ -464,6 +521,10 @@ void initDispatchBindings(PyObject* module) { .def( "fallback_fallthrough", [](py::object self, const char* dispatch) { +<<<<<<< HEAD +======= + TORCH_INTERNAL_ASSERT(isMainPyInterpreter()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.cast().fallback( dispatch_str(dispatch, CppFunction::makeFallthrough())); return self; @@ -478,6 +539,10 @@ void initDispatchBindings(PyObject* module) { bool with_keyset) { HANDLE_TH_ERRORS auto& lib = self.cast(); +<<<<<<< HEAD +======= + TORCH_INTERNAL_ASSERT(isMainPyInterpreter()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (func.is(py::module::import("torch.library") .attr("fallthrough_kernel"))) { lib.fallback( @@ -504,18 +569,25 @@ void initDispatchBindings(PyObject* module) { const char* file, uint32_t linenum) { HANDLE_TH_ERRORS +<<<<<<< HEAD // Store the file string in global storage to ensure it remains valid // for the lifetime of the Library object leaked_python_filenames_.emplace_back(file); const char* leaked_file = leaked_python_filenames_.back().c_str(); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return std::make_unique( parseKind(kind), std::move(name), std::string(dispatch).empty() ? std::nullopt : std::make_optional(c10::parseDispatchKey(dispatch)), +<<<<<<< HEAD leaked_file, +======= + "/dev/null", // temporary workaround +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) linenum); END_HANDLE_TH_ERRORS_PYBIND }, @@ -527,12 +599,15 @@ void initDispatchBindings(PyObject* module) { py::arg("linenum") = 0); m.def( +<<<<<<< HEAD "_dispatch_clear_leaked_python_filenames", []() { leaked_python_filenames_.clear(); }, "Clear the global storage of leaked Python filenames. " "WARNING: Only call this if you're sure no Library objects are still using the filenames."); m.def( +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "_dispatch_find_schema_or_throw", [](const char* name, const char* overload_name) -> c10::OperatorHandle { return c10::Dispatcher::singleton().findSchemaOrThrow( @@ -921,6 +996,11 @@ void initDispatchBindings(PyObject* module) { handle.setReportErrorCallback_(std::move(callback_obj)); }); +<<<<<<< HEAD +======= + m.def( + "_dispatch_is_main_interpreter", []() { return isMainPyInterpreter(); }); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) m.def("_dispatch_pystub", [](const char* name, const char* overload) { return c10::Dispatcher::singleton().getPyStub( c10::OperatorName(name, overload)); @@ -956,6 +1036,7 @@ void initDispatchBindings(PyObject* module) { include_set.has(c10::DispatchKey::FuncTorchDynamicLayerBackMode)); }); +<<<<<<< HEAD m.def("_autocast_supported_devices", []() { std::vector result; for (const auto device_type : at::autocast::_AUTOCAST_SUPPORTED_DEVICES) { @@ -965,6 +1046,8 @@ void initDispatchBindings(PyObject* module) { return result; }); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) m.def("_get_nested_int", [](int64_t data, int64_t coeff) { return c10::SymInt(c10::SymNode( c10::make_intrusive(data, coeff))); @@ -1009,6 +1092,7 @@ void initDispatchBindings(PyObject* module) { m.def("_only_lift_cpu_tensors", &torch::utils::only_lift_cpu_tensors); m.def("_set_only_lift_cpu_tensors", &torch::utils::set_only_lift_cpu_tensors); +<<<<<<< HEAD m.def( "_get_dtensor_allow_implicit_replication", &at::get_dtensor_allow_implicit_replication); @@ -1016,6 +1100,8 @@ void initDispatchBindings(PyObject* module) { "_set_dtensor_allow_implicit_replication", &at::set_dtensor_allow_implicit_replication); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) using c10::impl::TorchDispatchModeKey; py::enum_(m, "_TorchDispatchModeKey") .value("FUNCTIONAL", TorchDispatchModeKey::FUNCTIONAL) diff --git a/torch/csrc/utils/python_numbers.h b/torch/csrc/utils/python_numbers.h index a8b9b8632a00b..26626d2865d73 100644 --- a/torch/csrc/utils/python_numbers.h +++ b/torch/csrc/utils/python_numbers.h @@ -62,11 +62,21 @@ inline int32_t THPUtils_unpackInt(PyObject* obj) { if (value == -1 && PyErr_Occurred()) { throw python_error(); } +<<<<<<< HEAD TORCH_CHECK_VALUE(overflow == 0, "Overflow when unpacking long long"); TORCH_CHECK_VALUE( value <= std::numeric_limits::max() && value >= std::numeric_limits::min(), "Overflow when unpacking long"); +======= + if (overflow != 0) { + throw std::runtime_error("Overflow when unpacking long"); + } + if (value > std::numeric_limits::max() || + value < std::numeric_limits::min()) { + throw std::runtime_error("Overflow when unpacking long"); + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return (int32_t)value; } @@ -76,7 +86,13 @@ inline int64_t THPUtils_unpackLong(PyObject* obj) { if (value == -1 && PyErr_Occurred()) { throw python_error(); } +<<<<<<< HEAD TORCH_CHECK_VALUE(overflow == 0, "Overflow when unpacking long long"); +======= + if (overflow != 0) { + throw std::runtime_error("Overflow when unpacking long"); + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return (int64_t)value; } @@ -85,9 +101,15 @@ inline uint32_t THPUtils_unpackUInt32(PyObject* obj) { if (PyErr_Occurred()) { throw python_error(); } +<<<<<<< HEAD TORCH_CHECK_VALUE( value <= std::numeric_limits::max(), "Overflow when unpacking long long"); +======= + if (value > std::numeric_limits::max()) { + throw std::runtime_error("Overflow when unpacking unsigned long"); + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return (uint32_t)value; } @@ -163,6 +185,7 @@ inline c10::complex THPUtils_unpackComplexDouble(PyObject* obj) { } inline bool THPUtils_unpackNumberAsBool(PyObject* obj) { +<<<<<<< HEAD #ifdef USE_NUMPY // Handle NumPy boolean scalars (np.bool_) if (torch::utils::is_numpy_bool(obj)) { @@ -173,6 +196,8 @@ inline bool THPUtils_unpackNumberAsBool(PyObject* obj) { return truth != 0; } #endif +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (PyFloat_Check(obj)) { return (bool)PyFloat_AS_DOUBLE(obj); } @@ -208,6 +233,7 @@ inline c10::DeviceIndex THPUtils_unpackDeviceIndex(PyObject* obj) { } return (c10::DeviceIndex)value; } +<<<<<<< HEAD template inline T THPUtils_unpackInteger(PyObject* obj) { @@ -227,3 +253,5 @@ inline T THPUtils_unpackInteger(PyObject* obj) { } return static_cast(uvalue); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/csrc/utils/python_strings.h b/torch/csrc/utils/python_strings.h index 1d26c4333bc2b..92a63af649c94 100644 --- a/torch/csrc/utils/python_strings.h +++ b/torch/csrc/utils/python_strings.h @@ -116,7 +116,11 @@ inline py::object PyObject_FastGetAttrString(PyObject* obj, const char* name) { } /* Attribute referenced by (PyObject *)name */ else if (tp->tp_getattro != nullptr) { +<<<<<<< HEAD auto w = py::reinterpret_steal(PyUnicode_FromString(name)); +======= + auto w = py::reinterpret_steal(THPUtils_internString(name)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (w.ptr() == nullptr) { return py::object(); } diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp index 35511300f703e..421ff01366cef 100644 --- a/torch/csrc/utils/tensor_new.cpp +++ b/torch/csrc/utils/tensor_new.cpp @@ -304,7 +304,11 @@ Tensor internal_new_from_data( TORCH_CHECK( !pin_memory, "Can't pin tensor constructed from __cuda_array_interface__"); +<<<<<<< HEAD auto tensor = tensor_from_cuda_array_interface(data, device_opt); +======= + auto tensor = tensor_from_cuda_array_interface(data); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const auto& inferred_scalar_type = type_inference ? tensor.scalar_type() : scalar_type; @@ -556,7 +560,10 @@ void check_base_legacy_new( c10::DispatchKey::SparseCUDA, c10::DispatchKey::SparseHIP, c10::DispatchKey::SparseXPU, +<<<<<<< HEAD c10::DispatchKey::SparseMPS, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) c10::DispatchKey::SparsePrivateUse1, }); TORCH_CHECK( @@ -670,6 +677,7 @@ Tensor legacy_sparse_tensor_generic_ctor_new( // new(sequence) binds to this signature but should be treated differently // unless the sequences is a torch.Size if (ctor_or_new == CtorOrNew::CTOR) { +<<<<<<< HEAD TORCH_CHECK_TYPE( false, "torch.sparse.SparseTensor(sequence) only accepts sizes. Please use torch.sparse_coo_tensor() " @@ -677,6 +685,13 @@ Tensor legacy_sparse_tensor_generic_ctor_new( } else { TORCH_CHECK_TYPE( false, +======= + throw TypeError( + "torch.sparse.SparseTensor(sequence) only accepts sizes. Please use torch.sparse_coo_tensor() " + "or construct a strided tensor and convert it to sparse via to_sparse."); + } else { + throw TypeError( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "SparseTensor.new(sequence) only accepts sizes. Please use torch.sparse_coo_tensor() " "or construct a strided tensor and convert it to sparse via to_sparse."); } @@ -1656,6 +1671,7 @@ Tensor tensor_frombuffer( return tensor; } +<<<<<<< HEAD namespace { template @@ -1673,6 +1689,21 @@ at::Tensor tensor_fromDLPackImpl(PyObject* data, T* tensor) { } else { tensor->deleter(tensor); } +======= +Tensor tensor_fromDLPack(PyObject* data) { + DLManagedTensor* dlMTensor = + (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor"); + TORCH_CHECK( + dlMTensor, + "from_dlpack received an invalid capsule. " + "Note that DLTensor capsules can be consumed only once, " + "so you might have already constructed a tensor from it once."); + + auto deleter_with_gil = [dlMTensor](void*) { + if (dlMTensor->deleter) { + pybind11::gil_scoped_acquire gil; + dlMTensor->deleter(dlMTensor); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } }; @@ -1680,11 +1711,22 @@ at::Tensor tensor_fromDLPackImpl(PyObject* data, T* tensor) { // destructor function that will be called when the underlying storage goes // out of scope. When the destructor is called, the dlMTensor is destructed // too. +<<<<<<< HEAD auto atensor = at::DLPackTraits::fromDLPack(tensor, std::move(deleter_maybe_gil)); // Make sure this capsule will never be used again. PyCapsule_SetName(data, at::DLPackTraits::used); +======= + // HACK: Ensure that we hold the GIL here just in case the + // managed tensor originating from a buggy NumPy build. + auto atensor = torch::utils::is_numpy_dlpack_deleter_bugged() + ? at::fromDLPack(dlMTensor, std::move(deleter_with_gil)) + : at::fromDLPack(dlMTensor); + + // Make sure this capsule will never be used again. + PyCapsule_SetName(data, "used_dltensor"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // It is possible that the call to at::fromDLPack is the very first // call to create a Tensor in PyTorch. If so, then _lazy_init has @@ -1696,6 +1738,7 @@ at::Tensor tensor_fromDLPackImpl(PyObject* data, T* tensor) { return atensor; } +<<<<<<< HEAD // Check whether `data` is a valid DLPack capsule. // This function checks for the versioned and unversioned forms. bool isValidDLPackCapsule(PyObject* data) { @@ -1734,6 +1777,8 @@ Tensor tensor_fromDLPack(PyObject* data) { } } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Tensor asarray( PyObject* obj, std::optional dtype, @@ -1798,7 +1843,11 @@ Tensor asarray( #endif // Check whether 'obj' is a 'DLPack' capsule +<<<<<<< HEAD if (!tensor.defined() && isValidDLPackCapsule(obj)) { +======= + if (!tensor.defined() && PyCapsule_IsValid(obj, "dltensor") != 0) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tensor = tensor_fromDLPack(obj); } diff --git a/torch/csrc/utils/tensor_numpy.cpp b/torch/csrc/utils/tensor_numpy.cpp index b9839a79f6110..087ca5e71c2ac 100644 --- a/torch/csrc/utils/tensor_numpy.cpp +++ b/torch/csrc/utils/tensor_numpy.cpp @@ -1,4 +1,7 @@ +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #define WITH_NUMPY_IMPORT_ARRAY @@ -27,9 +30,13 @@ bool is_numpy_int(PyObject* obj) { bool is_numpy_scalar(PyObject* obj) { throw std::runtime_error("PyTorch was compiled without NumPy support"); } +<<<<<<< HEAD at::Tensor tensor_from_cuda_array_interface( PyObject* obj, std::optional device_opt) { +======= +at::Tensor tensor_from_cuda_array_interface(PyObject* obj) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) throw std::runtime_error("PyTorch was compiled without NumPy support"); } @@ -108,7 +115,11 @@ static std::vector to_aten_shape(int ndim, npy_intp* values) { static std::vector seq_to_aten_shape(PyObject* py_seq) { int ndim = PySequence_Length(py_seq); if (ndim == -1) { +<<<<<<< HEAD TORCH_CHECK_TYPE(false, "shape and strides must be sequences"); +======= + throw TypeError("shape and strides must be sequences"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } auto result = std::vector(ndim); for (const auto i : c10::irange(ndim)) { @@ -306,8 +317,12 @@ int aten_to_numpy_dtype(const ScalarType scalar_type) { case kBool: return NPY_BOOL; default: +<<<<<<< HEAD TORCH_CHECK_TYPE( false, "Got unsupported ScalarType ", toString(scalar_type)); +======= + throw TypeError("Got unsupported ScalarType %s", toString(scalar_type)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } @@ -359,12 +374,19 @@ ScalarType numpy_dtype_to_aten(int dtype) { auto pytype = THPObjectPtr(PyArray_TypeObjectFromType(dtype)); if (!pytype) throw python_error(); +<<<<<<< HEAD TORCH_CHECK_TYPE( false, fmt::format( "can't convert np.ndarray of type {}. The only supported types are: " "float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint64, uint32, uint16, uint8, and bool.", ((PyTypeObject*)pytype.get())->tp_name)); +======= + throw TypeError( + "can't convert np.ndarray of type %s. The only supported types are: " + "float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint64, uint32, uint16, uint8, and bool.", + ((PyTypeObject*)pytype.get())->tp_name); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } bool is_numpy_int(PyObject* obj) { @@ -382,9 +404,13 @@ bool is_numpy_scalar(PyObject* obj) { PyArray_IsScalar(obj, ComplexFloating)); } +<<<<<<< HEAD at::Tensor tensor_from_cuda_array_interface( PyObject* obj, std::optional device_opt) { +======= +at::Tensor tensor_from_cuda_array_interface(PyObject* obj) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (!is_numpy_available()) { throw std::runtime_error("Numpy is not available"); } @@ -393,7 +419,11 @@ at::Tensor tensor_from_cuda_array_interface( TORCH_INTERNAL_ASSERT(cuda_dict); if (!PyDict_Check(cuda_dict.get())) { +<<<<<<< HEAD TORCH_CHECK_TYPE(false, "`__cuda_array_interface__` must be a dict"); +======= + throw TypeError("`__cuda_array_interface__` must be a dict"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // Extract the `obj.__cuda_array_interface__['shape']` attribute @@ -404,7 +434,11 @@ at::Tensor tensor_from_cuda_array_interface( throw python_error(); } if (py_shape == nullptr) { +<<<<<<< HEAD TORCH_CHECK_TYPE(false, "attribute `shape` must exist"); +======= + throw TypeError("attribute `shape` must exist"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } sizes = seq_to_aten_shape(py_shape); } @@ -418,7 +452,11 @@ at::Tensor tensor_from_cuda_array_interface( throw python_error(); } if (py_typestr == nullptr) { +<<<<<<< HEAD TORCH_CHECK_TYPE(false, "attribute `typestr` must exist"); +======= + throw TypeError("attribute `typestr` must exist"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } PyArray_Descr* descr = nullptr; TORCH_CHECK_VALUE( @@ -440,10 +478,17 @@ at::Tensor tensor_from_cuda_array_interface( throw python_error(); } if (py_data == nullptr) { +<<<<<<< HEAD TORCH_CHECK_TYPE(false, "attribute `shape` data exist"); } if (!PyTuple_Check(py_data) || PyTuple_GET_SIZE(py_data) != 2) { TORCH_CHECK_TYPE(false, "`data` must be a 2-tuple of (int, bool)"); +======= + throw TypeError("attribute `shape` data exist"); + } + if (!PyTuple_Check(py_data) || PyTuple_GET_SIZE(py_data) != 2) { + throw TypeError("`data` must be a 2-tuple of (int, bool)"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } data_ptr = PyLong_AsVoidPtr(PyTuple_GET_ITEM(py_data, 0)); if (data_ptr == nullptr && PyErr_Occurred()) { @@ -454,8 +499,13 @@ at::Tensor tensor_from_cuda_array_interface( throw python_error(); } if (read_only) { +<<<<<<< HEAD TORCH_CHECK_TYPE( false, "the read only flag is not supported, should always be False"); +======= + throw TypeError( + "the read only flag is not supported, should always be False"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } @@ -469,8 +519,13 @@ at::Tensor tensor_from_cuda_array_interface( if (py_strides != nullptr && py_strides != Py_None) { if (PySequence_Length(py_strides) == -1 || static_cast(PySequence_Length(py_strides)) != sizes.size()) { +<<<<<<< HEAD TORCH_CHECK_TYPE( false, "strides must be a sequence of the same length as shape"); +======= + throw TypeError( + "strides must be a sequence of the same length as shape"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } strides = seq_to_aten_shape(py_strides); @@ -493,6 +548,7 @@ at::Tensor tensor_from_cuda_array_interface( // ref: // https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html#cuda-array-interface-version-3 if (data_ptr != nullptr) { +<<<<<<< HEAD if (device_opt.has_value() && device_opt->has_index()) { // if device_opt is provided with explicit device index, use it return device_opt; @@ -500,6 +556,9 @@ at::Tensor tensor_from_cuda_array_interface( // otherwise infer from cudaPointerGetAttributes later in from_blob return std::nullopt; } +======= + return {}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else { const auto current_device = at::detail::getCUDAHooks().getCurrentDevice(); return Device( diff --git a/torch/csrc/utils/tensor_numpy.h b/torch/csrc/utils/tensor_numpy.h index 5f93cbb089c21..b8990f69b10b3 100644 --- a/torch/csrc/utils/tensor_numpy.h +++ b/torch/csrc/utils/tensor_numpy.h @@ -22,9 +22,13 @@ TORCH_API bool is_numpy_bool(PyObject* obj); TORCH_API bool is_numpy_scalar(PyObject* obj); void warn_numpy_not_writeable(); +<<<<<<< HEAD at::Tensor tensor_from_cuda_array_interface( PyObject* obj, std::optional device_opt = std::nullopt); +======= +at::Tensor tensor_from_cuda_array_interface(PyObject* obj); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void validate_numpy_for_dlpack_deleter_bug(); bool is_numpy_dlpack_deleter_bugged(); diff --git a/torch/csrc/utils/tensor_types.cpp b/torch/csrc/utils/tensor_types.cpp index d696a0cdf4ddd..ebd7e7bfa5425 100644 --- a/torch/csrc/utils/tensor_types.cpp +++ b/torch/csrc/utils/tensor_types.cpp @@ -39,8 +39,11 @@ const char* backend_to_string(const at::Backend& backend) { return "torch.cuda.sparse"; case at::Backend::SparseXPU: return "torch.xpu.sparse"; +<<<<<<< HEAD case at::Backend::SparseMPS: return "torch.mps.sparse"; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) case at::Backend::QuantizedCPU: return "torch.quantized"; case at::Backend::HPU: diff --git a/torch/csrc/xpu/Module.cpp b/torch/csrc/xpu/Module.cpp index d49fc0539a087..d79415041af1b 100644 --- a/torch/csrc/xpu/Module.cpp +++ b/torch/csrc/xpu/Module.cpp @@ -295,6 +295,7 @@ static void registerXpuDeviceProperties(PyObject* module) { return static_cast(prop.architecture); }; #endif +<<<<<<< HEAD // Wrapper class for XPU UUID struct XPUuuid { XPUuuid(const std::array& uuid) : bytes(uuid) {} @@ -312,6 +313,10 @@ static void registerXpuDeviceProperties(PyObject* module) { return uuid_to_string(reinterpret_cast(uuid.bytes.data())); }); +======= + auto m = py::handle(module).cast(); + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #define DEFINE_READONLY_MEMBER(member) \ def_readonly(#member, &DeviceProp::member) @@ -320,7 +325,10 @@ static void registerXpuDeviceProperties(PyObject* module) { ._(name) \ ._(platform_name) \ ._(vendor) \ +<<<<<<< HEAD ._(device_id) \ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ._(driver_version) \ ._(version) \ ._(max_compute_units) \ @@ -343,21 +351,29 @@ static void registerXpuDeviceProperties(PyObject* module) { .def_property_readonly("architecture", get_device_architecture) #endif .def_property_readonly("type", get_device_type) +<<<<<<< HEAD .def_property_readonly( "uuid", [](const DeviceProp& prop) -> XPUuuid { return XPUuuid(prop.uuid); }) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .def( "__repr__", [&get_device_type, &gpu_subslice_count](const DeviceProp& prop) { std::ostringstream stream; stream << "_XpuDeviceProperties(name='" << prop.name << "', platform_name='" << prop.platform_name << "', type='" +<<<<<<< HEAD << get_device_type(prop) << "', device_id=0x" << std::hex << std::uppercase << prop.device_id << std::dec << ", uuid=" << uuid_to_string( reinterpret_cast(prop.uuid.data())) << ", driver_version='" << prop.driver_version << "', total_memory=" +======= + << get_device_type(prop) << "', driver_version='" + << prop.driver_version << "', total_memory=" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) << prop.global_mem_size / (1024ull * 1024) << "MB" << ", max_compute_units=" << prop.max_compute_units << ", gpu_eu_count=" << prop.gpu_eu_count diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index e9049f036e1e2..364c5bc91c413 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -18,7 +18,11 @@ import traceback import warnings from functools import lru_cache +<<<<<<< HEAD from typing import Any, Callable, cast, NewType, Optional, TYPE_CHECKING, Union +======= +from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch._C @@ -236,10 +240,17 @@ def _sleep(cycles): torch._C._cuda_sleep(cycles) +<<<<<<< HEAD def _extract_arch_version(arch_string: str) -> int: """Extracts the architecture string from a CUDA version""" base = arch_string.split("_", maxsplit=2)[1] base = base.removesuffix("a").removesuffix("f") +======= +def _extract_arch_version(arch_string: str): + """Extracts the architecture string from a CUDA version""" + base = arch_string.split("_")[1] + base = base.removesuffix("a") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return int(base) @@ -259,7 +270,11 @@ def _check_capability(): CUDA_ARCHES_SUPPORTED = { "12.6": {"min": 50, "max": 90}, "12.8": {"min": 70, "max": 120}, +<<<<<<< HEAD "13.0": {"min": 75, "max": 120}, +======= + "12.9": {"min": 70, "max": 120}, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } if ( @@ -407,6 +422,11 @@ def _lazy_init(): ) # This function throws if there's a driver initialization error, no GPUs # are found or any other error occurs +<<<<<<< HEAD +======= + if "CUDA_MODULE_LOADING" not in os.environ: + os.environ["CUDA_MODULE_LOADING"] = "LAZY" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch._C._cuda_init() # Some of the queued calls may reentrantly call _lazy_init(); # we need to just return without initializing in that case. @@ -453,7 +473,11 @@ def cudart(): >>> from torch.cuda import cudart, check_error >>> import os >>> +<<<<<<< HEAD >>> os.environ["CUDA_PROFILE"] = "1" +======= + >>> os.environ['CUDA_PROFILE'] = '1' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> >>> def perform_cuda_operations_with_streams(): >>> stream = torch.cuda.Stream() @@ -1021,7 +1045,11 @@ def device_count() -> int: r""" Return the number of GPUs available. +<<<<<<< HEAD .. note:: This API will NOT poison fork if NVML discovery succeeds. +======= + .. note:: This API will NOT posion fork if NVML discovery succeeds. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) See :ref:`multiprocessing-poison-fork-note` for more details. """ global _cached_device_count @@ -1693,6 +1721,12 @@ def __call__(self, *args, **kwargs): def _register_triton_kernels(): +<<<<<<< HEAD +======= + if torch._running_with_deploy(): + return + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @_WrappedTritonKernel def kernel_impl(*args, **kwargs): from torch.sparse._triton_ops import bsr_dense_mm @@ -1770,7 +1804,11 @@ def _compile_kernel( >>> a = torch.randn(1024, device="cuda") >>> b = torch.randn(1024, device="cuda") >>> c = torch.empty_like(a) +<<<<<<< HEAD >>> add_kernel(grid=(4, 1, 1), block=(256, 1, 1), args=[a, b, c, a.numel()]) +======= + >>> add_kernel(grid=(4,1,1), block=(256,1,1), args=[a, b, c, a.numel()]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ import ctypes @@ -1800,9 +1838,12 @@ def _compile_kernel( from . import amp, jiterator, nvtx, profiler, sparse, tunable +<<<<<<< HEAD _POOL_HANDLE = NewType("_POOL_HANDLE", tuple[int, int]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __all__ = [ # Typed storage and tensors "BFloat16Storage", diff --git a/torch/cuda/_memory_viz.py b/torch/cuda/_memory_viz.py index 7f0f4fc3559ff..cd0a2648755e2 100644 --- a/torch/cuda/_memory_viz.py +++ b/torch/cuda/_memory_viz.py @@ -89,9 +89,13 @@ def _block_extra(b): def format_flamegraph(flamegraph_lines, flamegraph_script=None): if flamegraph_script is None: +<<<<<<< HEAD cache_dir = os.path.expanduser("~/.cache/") os.makedirs(cache_dir, exist_ok=True) flamegraph_script = f"{cache_dir}/flamegraph.pl" +======= + flamegraph_script = f"/tmp/{os.getuid()}_flamegraph.pl" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not os.path.exists(flamegraph_script): import tempfile import urllib.request @@ -102,8 +106,13 @@ def format_flamegraph(flamegraph_lines, flamegraph_script=None): "https://raw.githubusercontent.com/brendangregg/FlameGraph/master/flamegraph.pl", f.name, ) +<<<<<<< HEAD try: os.chmod(f.name, 0o755) +======= + subprocess.check_call(["chmod", "+x", f.name]) + try: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) os.rename(f.name, flamegraph_script) except OSError: # noqa: B001,E722 # Ok to skip, the file will be removed by tempfile @@ -133,7 +142,11 @@ def frames_fragment(frames): if "history" not in b: frames, accounted_for_size = _block_extra(b) f.write( +<<<<<<< HEAD f"{prefix};{b['state']};{frames_fragment(frames)} {accounted_for_size}\n" +======= + f'{prefix};{b["state"]};{frames_fragment(frames)} {accounted_for_size}\n' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) else: accounted_for_size = 0 @@ -142,18 +155,31 @@ def frames_fragment(frames): accounted_for_size += sz if "frames" in h: frames = h["frames"] +<<<<<<< HEAD f.write(f"{prefix};{b['state']};{frames_fragment(frames)} {sz}\n") else: f.write(f"{prefix};{b['state']}; {sz}\n") gaps = b["size"] - accounted_for_size if gaps: f.write(f"{prefix};{b['state']}; {gaps}\n") +======= + f.write(f'{prefix};{b["state"]};{frames_fragment(frames)} {sz}\n') + else: + f.write(f'{prefix};{b["state"]}; {sz}\n') + gaps = b["size"] - accounted_for_size + if gaps: + f.write(f'{prefix};{b["state"]}; {gaps}\n') +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def segments(snapshot, format_flamegraph=format_flamegraph): f = io.StringIO() for seg in snapshot["segments"]: +<<<<<<< HEAD prefix = f"stream_{seg['stream']};seg_{seg['address']}" +======= + prefix = f'stream_{seg["stream"]};seg_{seg["address"]}' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _write_blocks(f, prefix, seg["blocks"]) return format_flamegraph(f.getvalue()) @@ -161,7 +187,11 @@ def segments(snapshot, format_flamegraph=format_flamegraph): def memory(snapshot, format_flamegraph=format_flamegraph): f = io.StringIO() for seg in snapshot["segments"]: +<<<<<<< HEAD prefix = f"stream_{seg['stream']}" +======= + prefix = f'stream_{seg["stream"]}' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _write_blocks(f, prefix, seg["blocks"]) return format_flamegraph(f.getvalue()) @@ -171,7 +201,11 @@ def _seg_key(seg): return (seg["address"], seg["total_size"]) def _seg_info(seg): +<<<<<<< HEAD return f"stream_{seg['stream']};seg_{seg['address']}" +======= + return f'stream_{seg["stream"]};seg_{seg["address"]}' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f = io.StringIO() @@ -301,11 +335,16 @@ def segsum(data): occupied[j] = "0123456789*"[int(frac[j] * 10)] else: occupied[j] = m +<<<<<<< HEAD stream = "" if seg["stream"] == 0 else f", stream_{seg['stream']}" +======= + stream = "" if seg["stream"] == 0 else f', stream_{seg["stream"]}' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) body = "".join(occupied) assert ( seg_free_external + seg_free_internal + seg_allocated == seg["total_size"] ) +<<<<<<< HEAD stream = f" stream_{seg['stream']}" if seg["stream"] != 0 else "" if seg["total_size"] >= PAGE_SIZE: out.write( @@ -313,6 +352,15 @@ def segsum(data): f"{_report_free(seg_free_external, seg_free_internal)} free{stream}\n" ) out.write(f"segments: {len(data['segments'])}\n") +======= + stream = f' stream_{seg["stream"]}' if seg["stream"] != 0 else "" + if seg["total_size"] >= PAGE_SIZE: + out.write( + f'[{body}] {Bytes(seg["total_size"])} allocated, ' + f"{_report_free(seg_free_external, seg_free_internal)} free{stream}\n" + ) + out.write(f'segments: {len(data["segments"])}\n') +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out.write(f"total_reserved: {Bytes(total_reserved)}\n") out.write(f"total_allocated: {Bytes(total_allocated)}\n") out.write(f"total_free: {_report_free(free_external, free_internal)}\n") @@ -338,7 +386,11 @@ def _name(): return free_names.pop() r, m = next_name // 26, next_name % 26 next_name += 1 +<<<<<<< HEAD return f"{chr(ord('a') + m)}{'' if r == 0 else r}" +======= + return f'{chr(ord("a") + m)}{"" if r == 0 else r}' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def find_segment(addr): for name, saddr, size in segment_intervals: diff --git a/torch/cuda/_utils.py b/torch/cuda/_utils.py index d5e3a6d180132..c6bead46d8997 100644 --- a/torch/cuda/_utils.py +++ b/torch/cuda/_utils.py @@ -30,6 +30,7 @@ def _check_cuda(result: int) -> None: def _get_nvrtc_library() -> ctypes.CDLL: +<<<<<<< HEAD major_version = int(torch.version.cuda.split(".")[0]) # type: ignore[union-attr] if sys.platform == "win32": nvrtc_libs = [ @@ -46,6 +47,14 @@ def _get_nvrtc_library() -> ctypes.CDLL: except OSError: continue raise OSError("Could not find any NVRTC library") +======= + # Since PyTorch already loads NVRTC, we can use the system library + # which should be compatible with PyTorch's version + if sys.platform == "win32": + return ctypes.CDLL("nvrtc64_120_0.dll") + else: + return ctypes.CDLL("libnvrtc.so") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _nvrtc_compile( diff --git a/torch/cuda/error.py b/torch/cuda/error.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/torch/cuda/gds.py b/torch/cuda/gds.py index d3922499682e4..89fb05a2221b8 100644 --- a/torch/cuda/gds.py +++ b/torch/cuda/gds.py @@ -119,9 +119,15 @@ def register_handle(self) -> None: This is a wrapper around ``cuFileHandleRegister``. """ +<<<<<<< HEAD assert self.handle is None, ( "Cannot register a handle that is already registered." ) +======= + assert ( + self.handle is None + ), "Cannot register a handle that is already registered." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.handle = torch._C._gds_register_handle(self.fd) def deregister_handle(self) -> None: @@ -129,9 +135,15 @@ def deregister_handle(self) -> None: This is a wrapper around ``cuFileHandleDeregister``. """ +<<<<<<< HEAD assert self.handle is not None, ( "Cannot deregister a handle that is not registered." ) +======= + assert ( + self.handle is not None + ), "Cannot deregister a handle that is not registered." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch._C._gds_deregister_handle(self.handle) self.handle = None @@ -145,9 +157,15 @@ def load_storage(self, storage: Storage, offset: int = 0) -> None: storage (Storage): Storage to load data into. offset (int, optional): Offset into the file to start loading from. (Default: 0) """ +<<<<<<< HEAD assert self.handle is not None, ( "Cannot load data from a file that is not registered." ) +======= + assert ( + self.handle is not None + ), "Cannot load data from a file that is not registered." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch._C._gds_load_storage(self.handle, storage, offset) def save_storage(self, storage: Storage, offset: int = 0) -> None: @@ -160,7 +178,13 @@ def save_storage(self, storage: Storage, offset: int = 0) -> None: storage (Storage): Storage to save data from. offset (int, optional): Offset into the file to start saving to. (Default: 0) """ +<<<<<<< HEAD assert self.handle is not None, ( "Cannot save data to a file that is not registered." ) +======= + assert ( + self.handle is not None + ), "Cannot save data to a file that is not registered." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch._C._gds_save_storage(self.handle, storage, offset) diff --git a/torch/cuda/graphs.py b/torch/cuda/graphs.py index 3946b7b3360ad..58bd8621f2367 100644 --- a/torch/cuda/graphs.py +++ b/torch/cuda/graphs.py @@ -1,3 +1,4 @@ +<<<<<<< HEAD from __future__ import annotations import gc @@ -12,10 +13,18 @@ if TYPE_CHECKING: # importing _POOL_HANDLE at runtime toplevel causes an import cycle from torch.cuda import _POOL_HANDLE +======= +# mypy: allow-untyped-defs +import gc +import typing + +import torch +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .._utils import _dummy_type +<<<<<<< HEAD __all__ = [ "is_current_stream_capturing", "graph_pool_handle", @@ -29,6 +38,8 @@ _P = ParamSpec("_P") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not hasattr(torch._C, "_CudaStreamBase"): # Define dummy base classes torch._C.__dict__["_CUDAGraph"] = _dummy_type("_CUDAGraph") @@ -44,7 +55,11 @@ ) +<<<<<<< HEAD def is_current_stream_capturing() -> bool: +======= +def is_current_stream_capturing(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r"""Return True if CUDA graph capture is underway on the current CUDA stream, False otherwise. If a CUDA context does not exist on the current device, returns False without initializing the context. @@ -53,7 +68,11 @@ def is_current_stream_capturing() -> bool: # Python shim helps Sphinx process docstrings more reliably. +<<<<<<< HEAD def graph_pool_handle() -> _POOL_HANDLE: +======= +def graph_pool_handle(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r"""Return an opaque token representing the id of a graph memory pool. See :ref:`Graph memory management`. @@ -61,7 +80,11 @@ def graph_pool_handle() -> _POOL_HANDLE: .. warning:: This API is in beta and may change in future releases. """ +<<<<<<< HEAD return torch.cuda._POOL_HANDLE(_graph_pool_handle()) +======= + return _graph_pool_handle() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Python shim helps Sphinx process docstrings more reliably. @@ -73,11 +96,19 @@ class CUDAGraph(torch._C._CUDAGraph): cudaGraphExec_t will be instantiated on GPU at the end of ``capture_end`` and the underlying cudaGraph_t will be destroyed. Users who want to query or otherwise modify the +<<<<<<< HEAD underlying cudaGraph_t before instantiation can set ``keep_graph=True`` and access it via ``raw_cuda_graph`` after ``capture_end``. Note that the cudaGraphExec_t will not be instantiated at the end of ``capture_end`` in this case. Instead, it will be instantiated via an explicit called +======= + underlying cudaGraph_t before instantiatiation can set + ``keep_graph=True`` and access it via ``raw_cuda_graph`` after + ``capture_end``. Note that the cudaGraphExec_t will not be + instantiated at the end of ``capture_end`` in this + case. Instead, it wil be instantiated via an explicit called +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) to ``instantiate`` or automatically on the first call to ``replay`` if ``instantiate`` was not already called. Calling ``instantiate`` manually before ``replay`` is recommended to @@ -92,12 +123,19 @@ class CUDAGraph(torch._C._CUDAGraph): """ +<<<<<<< HEAD def __new__(cls, keep_graph: bool = False) -> Self: return super().__new__(cls, keep_graph) def capture_begin( self, pool: Optional[_POOL_HANDLE] = None, capture_error_mode: str = "global" ) -> None: +======= + def __new__(cls, keep_graph=False): + return super().__new__(cls, keep_graph) + + def capture_begin(self, pool=None, capture_error_mode="global"): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r"""Begin capturing CUDA work on the current stream. Typically, you shouldn't call ``capture_begin`` yourself. @@ -116,7 +154,11 @@ def capture_begin( """ # noqa: B950 super().capture_begin(pool=pool, capture_error_mode=capture_error_mode) +<<<<<<< HEAD def capture_end(self) -> None: +======= + def capture_end(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r"""End CUDA graph capture on the current stream. After ``capture_end``, ``replay`` may be called on this instance. @@ -127,7 +169,11 @@ def capture_end(self) -> None: """ super().capture_end() +<<<<<<< HEAD def instantiate(self) -> None: +======= + def instantiate(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r"""Instantiate the CUDA graph. Will be called by ``capture_end`` if ``keep_graph=False``, or by ``replay`` if ``keep_graph=True`` and ``instantiate`` has not already been @@ -136,6 +182,7 @@ def instantiate(self) -> None: """ super().instantiate() +<<<<<<< HEAD def replay(self) -> None: r"""Replay the CUDA work captured by this graph.""" super().replay() @@ -145,6 +192,17 @@ def reset(self) -> None: super().reset() def pool(self) -> _POOL_HANDLE: +======= + def replay(self): + r"""Replay the CUDA work captured by this graph.""" + super().replay() + + def reset(self): + r"""Delete the graph currently held by this instance.""" + super().reset() + + def pool(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r"""Return an opaque token representing the id of this graph's memory pool. This id can optionally be passed to another graph's ``capture_begin``, @@ -152,11 +210,19 @@ def pool(self) -> _POOL_HANDLE: """ return super().pool() +<<<<<<< HEAD def enable_debug_mode(self) -> None: r"""Enable debugging mode for CUDAGraph.debug_dump.""" return super().enable_debug_mode() def debug_dump(self, debug_path: str) -> None: +======= + def enable_debug_mode(self): + r"""Enable debugging mode for CUDAGraph.debug_dump.""" + return super().enable_debug_mode() + + def debug_dump(self, debug_path): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r""" Arguments: debug_path (required): Path to dump the graph to. @@ -166,13 +232,18 @@ def debug_dump(self, debug_path: str) -> None: """ return super().debug_dump(debug_path) +<<<<<<< HEAD def raw_cuda_graph(self) -> int: +======= + def raw_cuda_graph(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r"""Returns the underlying cudaGraph_t. ``keep_graph`` must be True. See the following for APIs for how to manipulate this object: `Graph Managmement `_ and `cuda-python Graph Management bindings `_ """ # noqa: B950 return super().raw_cuda_graph() +<<<<<<< HEAD def raw_cuda_graph_exec(self) -> int: r"""Returns the underlying cudaGraphExec_t. ``instantiate`` must have been called if ``keep_graph`` is True, or ``capture_end`` must have been called if ``keep_graph`` is False. If you call ``instantiate()`` after ``raw_cuda_graph_exec()``, the previously returned cudaGraphExec_t will be destroyed. It is your responsibility not to use this object after destruction. @@ -180,6 +251,8 @@ def raw_cuda_graph_exec(self) -> int: """ # noqa: B950 return super().raw_cuda_graph_exec() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class graph: r"""Context-manager that captures CUDA work into a :class:`torch.cuda.CUDAGraph` object for later replay. @@ -211,6 +284,7 @@ class graph: https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85 """ # noqa: B950 +<<<<<<< HEAD default_capture_stream: Optional[torch.cuda.Stream] = None def __init__( @@ -218,6 +292,15 @@ def __init__( cuda_graph: CUDAGraph, pool: Optional[_POOL_HANDLE] = None, stream: Optional[torch.cuda.Stream] = None, +======= + default_capture_stream: typing.Optional["torch.cuda.Stream"] = None + + def __init__( + self, + cuda_graph, + pool=None, + stream=None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) capture_error_mode: str = "global", ): # Lazy-init of default_capture_stream helps avoid circular-import errors. @@ -226,9 +309,13 @@ def __init__( if self.__class__.default_capture_stream is None: self.__class__.default_capture_stream = torch.cuda.Stream() +<<<<<<< HEAD self.pool: Union[tuple[()], tuple[_POOL_HANDLE]] = ( () if pool is None else (pool,) ) +======= + self.pool = () if pool is None else (pool,) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.capture_stream = ( stream if stream is not None else self.__class__.default_capture_stream ) @@ -237,6 +324,7 @@ def __init__( self.cuda_graph = cuda_graph self.capture_error_mode = capture_error_mode +<<<<<<< HEAD def __enter__(self) -> None: # Free as much memory as we can for the graph torch.cuda.synchronize() @@ -249,6 +337,12 @@ def __enter__(self) -> None: # when a dead python cycle is holding onto CUDA memory. gc.collect() +======= + def __enter__(self): + # Free as much memory as we can for the graph + torch.cuda.synchronize() + gc.collect() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.cuda.empty_cache() # Stackoverflow seems comfortable with this pattern @@ -256,6 +350,7 @@ def __enter__(self) -> None: self.stream_ctx.__enter__() self.cuda_graph.capture_begin( +<<<<<<< HEAD # type: ignore[misc] *self.pool, capture_error_mode=self.capture_error_mode, @@ -297,6 +392,20 @@ def make_graphed_callables( allow_unused_input: bool = False, pool: Optional[_POOL_HANDLE] = None, ) -> Union[_ModuleOrCallable, tuple[_ModuleOrCallable, ...]]: +======= + *self.pool, capture_error_mode=self.capture_error_mode + ) + + def __exit__(self, exc_type, exc_value, traceback): + self.cuda_graph.capture_end() + self.stream_ctx.__exit__(exc_type, exc_value, traceback) + # returning None should propagate exceptions from either capture_end or stream_ctx.__exit__() + + +def make_graphed_callables( + callables, sample_args, num_warmup_iters=3, allow_unused_input=False, pool=None +): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r"""Accept callables (functions or :class:`nn.Module`\ s) and returns graphed versions. Each graphed callable's forward pass runs its source callable's @@ -370,6 +479,7 @@ def make_graphed_callables( just_one_callable = False +<<<<<<< HEAD _sample_args: tuple[tuple[Tensor, ...], ...] if not isinstance(callables, tuple): just_one_callable = True @@ -381,6 +491,16 @@ def make_graphed_callables( flatten_sample_args = [] for c, args in zip(callables, _sample_args): +======= + if not isinstance(callables, tuple): + just_one_callable = True + callables = (callables,) + sample_args = (sample_args,) + + flatten_sample_args = [] + + for c, args in zip(callables, sample_args): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(c, torch.nn.Module): assert ( len(c._backward_hooks) == 0 @@ -425,7 +545,11 @@ def make_graphed_callables( torch.cuda.synchronize() with torch.cuda.stream(torch.cuda.Stream()): for func, args, static_input_surface in zip( +<<<<<<< HEAD callables, _sample_args, per_callable_static_input_surfaces +======= + callables, sample_args, per_callable_static_input_surfaces +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): grad_inputs, outputs, outputs_grad = None, None, None for _ in range(num_warmup_iters): @@ -455,11 +579,19 @@ def make_graphed_callables( # Capture forward graphs per_callable_static_outputs = [] per_callable_output_unflatten_spec = [] +<<<<<<< HEAD for func, args, fwd_graph in zip(callables, _sample_args, fwd_graphs): with torch.cuda.graph(fwd_graph, pool=mempool): func_outputs = func(*args) flatten_outputs, spec = torch.utils._pytree.tree_flatten(func_outputs) +======= + for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs): + with torch.cuda.graph(fwd_graph, pool=mempool): + outputs = func(*args) + + flatten_outputs, spec = torch.utils._pytree.tree_flatten(outputs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) per_callable_static_outputs.append(tuple(flatten_outputs)) per_callable_output_unflatten_spec.append(spec) @@ -511,6 +643,7 @@ def make_graphed_callables( # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable. def make_graphed_autograd_function( +<<<<<<< HEAD fwd_graph: CUDAGraph, bwd_graph: CUDAGraph, module_params: tuple[torch.nn.Parameter, ...], @@ -524,6 +657,21 @@ def make_graphed_autograd_function( class Graphed(torch.autograd.Function): @staticmethod def forward(ctx: object, *inputs: Tensor) -> tuple[Tensor, ...]: +======= + fwd_graph, + bwd_graph, + module_params, + len_user_args, + output_unflatten_spec, + static_input_surface, + static_outputs, + static_grad_outputs, + static_grad_inputs, + ): + class Graphed(torch.autograd.Function): + @staticmethod + def forward(ctx, *inputs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # At this stage, only the user args may (potentially) be new tensors. for i in range(len_user_args): if static_input_surface[i].data_ptr() != inputs[i].data_ptr(): @@ -534,7 +682,11 @@ def forward(ctx: object, *inputs: Tensor) -> tuple[Tensor, ...]: @staticmethod @torch.autograd.function.once_differentiable +<<<<<<< HEAD def backward(ctx: object, *grads: Tensor) -> tuple[Tensor, ...]: +======= + def backward(ctx, *grads): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert len(grads) == len(static_grad_outputs) for g, grad in zip(static_grad_outputs, grads): if g is not None: @@ -550,7 +702,11 @@ def backward(ctx: object, *grads: Tensor) -> tuple[Tensor, ...]: b.detach() if b is not None else b for b in static_grad_inputs ) +<<<<<<< HEAD def functionalized(*user_args: object) -> object: +======= + def functionalized(*user_args): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Runs the autograd function with inputs == all inputs to the graph that might require grad # (explicit user args + module parameters) # Assumes module params didn't change since capture. @@ -561,7 +717,11 @@ def functionalized(*user_args: object) -> object: return functionalized # Put together the final graphed callables +<<<<<<< HEAD ret: list[_ModuleOrCallable] = [] +======= + ret = [] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for i, func in enumerate(callables): graphed = make_graphed_autograd_function( fwd_graphs[i], @@ -577,6 +737,7 @@ def functionalized(*user_args: object) -> object: if isinstance(func, torch.nn.Module): +<<<<<<< HEAD def make_graphed_forward( func: torch.nn.Module, graph_training_state: bool, @@ -596,6 +757,20 @@ def new_fwd(*user_args: _P.args, **user_kwargs: _P.kwargs) -> _R: func.forward = make_graphed_forward( func, func.training, graphed, func.forward ) +======= + def make_graphed_forward(func, graph_training_state, graphed, orig_fwd): + def new_fwd(*user_args): + # If the module's training-or-eval state matches what we graphed, + # run the graph, otherwise run the original forward method + if func.training == graph_training_state: + return graphed(*user_args) + else: + return orig_fwd(*user_args) + + return new_fwd + + func.forward = make_graphed_forward(func, func.training, graphed, func.forward) # type: ignore[assignment] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ret.append(func) else: ret.append(graphed) diff --git a/torch/cuda/jiterator.py b/torch/cuda/jiterator.py index 8bcb14d9fcfbd..96fe487d266ab 100644 --- a/torch/cuda/jiterator.py +++ b/torch/cuda/jiterator.py @@ -57,9 +57,15 @@ def __init__( ): self.code_string = code_string +<<<<<<< HEAD assert return_by_ref or num_outputs == 1, ( "Return by value only works for single output. " ) +======= + assert ( + return_by_ref or num_outputs == 1 + ), "Return by value only works for single output. " +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.return_by_ref = return_by_ref self.num_outputs = num_outputs @@ -72,9 +78,15 @@ def __init__( def __call__(self, *tensors: Tensor, **kwargs): # Jiterator follow torch.cuda's lazy initialization behavior # Defer checking cuda's availability at the function invocation time +<<<<<<< HEAD assert self.is_cuda_available, ( "Jiterator is only supported on CUDA and ROCm GPUs, none are available." ) +======= + assert ( + self.is_cuda_available + ), "Jiterator is only supported on CUDA and ROCm GPUs, none are available." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert len(tensors) <= 8, "jiterator only supports up to 8 tensor inputs." @@ -114,8 +126,13 @@ def _create_jit_fn(code_string: str, **kwargs) -> Callable: code_string = "template T my_kernel(T x, T y, T alpha) { return -x + alpha * y; }" jitted_fn = create_jit_fn(code_string, alpha=1.0) +<<<<<<< HEAD a = torch.rand(3, device="cuda") b = torch.rand(3, device="cuda") +======= + a = torch.rand(3, device='cuda') + b = torch.rand(3, device='cuda') +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # invoke jitted function like a regular python function result = jitted_fn(a, b, alpha=3.14) @@ -123,6 +140,7 @@ def _create_jit_fn(code_string: str, **kwargs) -> Callable: Example:: +<<<<<<< HEAD code_string = ( "template T util_fn(T x, T y) { return ::sin(x) + ::cos(y); }" ) @@ -130,6 +148,13 @@ def _create_jit_fn(code_string: str, **kwargs) -> Callable: jitted_fn = create_jit_fn(code_string, val=0.0) a = torch.rand(3, device="cuda") b = torch.rand(3, device="cuda") +======= + code_string = "template T util_fn(T x, T y) { return ::sin(x) + ::cos(y); }" + code_string += "template T my_kernel(T x, T y, T val) { return ::min(val, util_fn(x, y)); }" + jitted_fn = create_jit_fn(code_string, val=0.0) + a = torch.rand(3, device='cuda') + b = torch.rand(3, device='cuda') +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # invoke jitted function like a regular python function result = jitted_fn(a, b) # using default val=0.0 @@ -141,9 +166,15 @@ def _create_jit_fn(code_string: str, **kwargs) -> Callable: code_string = "template T my_gelu(T a) { return a > 0 ? a : 0; }" my_gelu = create_jit_fn(code_string) my_lib = torch.library.Library("aten", "IMPL") +<<<<<<< HEAD my_lib.impl("aten::gelu", my_gelu, "CUDA") # torch.nn.GELU and torch.nn.function.gelu are now overridden a = torch.rand(3, device="cuda") +======= + my_lib.impl('aten::gelu', my_gelu, "CUDA") + # torch.nn.GELU and torch.nn.function.gelu are now overridden + a = torch.rand(3, device='cuda') +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.allclose(torch.nn.functional.gelu(a), torch.nn.functional.relu(a)) .. warning:: @@ -173,8 +204,13 @@ def _create_multi_output_jit_fn( code_string = "template void my_kernel(T x, T y, T alpha, T& out) { out = -x + alpha * y; }" jitted_fn = create_jit_fn(code_string, alpha=1.0) +<<<<<<< HEAD a = torch.rand(3, device="cuda") b = torch.rand(3, device="cuda") +======= + a = torch.rand(3, device='cuda') + b = torch.rand(3, device='cuda') +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # invoke jitted function like a regular python function result = jitted_fn(a, b, alpha=3.14) diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py index 5a1a0adc02afc..74cc9998ce486 100644 --- a/torch/cuda/memory.py +++ b/torch/cuda/memory.py @@ -255,9 +255,15 @@ def memory_stats(device: "Device" = None) -> dict[str, Any]: - ``all``: combined statistics across all memory pools. - ``large_pool``: statistics for the large allocation pool +<<<<<<< HEAD (as of June 2025, for size >= 1MB allocations). - ``small_pool``: statistics for the small allocation pool (as of June 2025, for size < 1MB allocations). +======= + (as of October 2019, for size >= 1MB allocations). + - ``small_pool``: statistics for the small allocation pool + (as of October 2019, for size < 1MB allocations). +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Metric type: @@ -886,7 +892,11 @@ def _record_memory_history( store the last accumulated `max_entries` entries, meaning new entries overwrite older entries. +<<<<<<< HEAD C++ implementation for reference to ring buffer implementation: +======= + C++ implementation for reference to ring buffer implemenation: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .. code-block:: cpp @@ -914,7 +924,11 @@ def _record_memory_history( Args: enabled (Literal[None, "state", "all"], optional): `None`, disable recording memory history. +<<<<<<< HEAD `"state"`, keep information for currently allocated memory. +======= + `"state"`, keep information for currenly allocated memory. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) `"all"`, additionally keep a history of all alloc/free calls. Defaults to "all". context (Literal[None, "state", "alloc", "all"], optional): @@ -968,10 +982,16 @@ def _snapshot(device: "Device" = None): .. code-block:: python class Snapshot(TypedDict): +<<<<<<< HEAD segments: List[Segment] device_traces: List[List[TraceEntry]] +======= + segments : List[Segment] + device_traces: List[List[TraceEntry]] + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class Segment(TypedDict): # Segments are memory returned from a cudaMalloc call. # The size of reserved memory is the sum of all Segments. @@ -980,6 +1000,7 @@ class Segment(TypedDict): # is split into more then one Block. # empty_cache() frees Segments that are entirely inactive. address: int +<<<<<<< HEAD total_size: int # cudaMalloc'd size of segment stream: int segment_type: Literal["small", "large"] # 'large' (>1MB) @@ -987,11 +1008,20 @@ class Segment(TypedDict): active_size: int # size of memory in use or in active_awaiting_free state blocks: List[Block] +======= + total_size: int # cudaMalloc'd size of segment + stream: int + segment_type: Literal['small', 'large'] # 'large' (>1MB) + allocated_size: int # size of memory in use + active_size: int # size of memory in use or in active_awaiting_free state + blocks : List[Block] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class Block(TypedDict): # A piece of memory returned from the allocator, or # current cached but inactive. size: int +<<<<<<< HEAD requested_size: int # size requested during malloc, may be smaller than # size due to rounding address: int @@ -1009,12 +1039,28 @@ class Frame(TypedDict): line: int name: str +======= + requested_size: int # size requested during malloc, may be smaller than + # size due to rounding + address: int + state: Literal['active_allocated', # used by a tensor + 'active_awaiting_free', # waiting for another stream to finish using + # this, then it will become free + 'inactive',] # free for reuse + frames: List[Frame] # stack trace from where the allocation occurred + + class Frame(TypedDict): + filename: str + line: int + name: str +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TraceEntry(TypedDict): # When `torch.cuda.memory._record_memory_history()` is enabled, # the snapshot will contain TraceEntry objects that record each # action the allocator took. action: Literal[ +<<<<<<< HEAD "alloc" # memory allocated "free_requested", # the allocated received a call to free memory "free_completed", # the memory that was requested to be freed is now @@ -1036,6 +1082,29 @@ class TraceEntry(TypedDict): stream: int device_free: int # only present for OOM, the amount of # memory cuda still reports to be free +======= + 'alloc' # memory allocated + 'free_requested', # the allocated received a call to free memory + 'free_completed', # the memory that was requested to be freed is now + # able to be used in future allocation calls + 'segment_alloc', # the caching allocator ask cudaMalloc for more memory + # and added it as a segment in its cache + 'segment_free', # the caching allocator called cudaFree to return memory + # to cuda possibly trying free up memory to + # allocate more segments or because empty_caches was called + 'oom', # the allocator threw an OOM exception. 'size' is + # the requested number of bytes that did not succeed + 'snapshot' # the allocator generated a memory snapshot + # useful to coorelate a previously taken + # snapshot with this trace + ] + addr: int # not present for OOM + frames: List[Frame] + size: int + stream: int + device_free: int # only present for OOM, the amount of + # memory cuda still reports to be free +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Returns: The Snapshot dictionary object @@ -1169,15 +1238,26 @@ class MemPool(_MemPool): use_on_oom(bool): a bool that indicates if this pool can be used as a last resort if a memory allocation outside of the pool fails due to Out Of Memory. This is False by default. +<<<<<<< HEAD +======= + symmetric(bool): a bool that indicates if this pool is symmetrical + across ranks. This is False by default. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ def __init__( self, allocator: Optional[_cuda_CUDAAllocator] = None, use_on_oom: bool = False, +<<<<<<< HEAD ): super().__init__(allocator, True, use_on_oom) +======= + symmetric: bool = False, + ): + super().__init__(allocator, True, use_on_oom, symmetric) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @property def id(self) -> tuple[int, int]: @@ -1185,6 +1265,14 @@ def id(self) -> tuple[int, int]: return super().id @property +<<<<<<< HEAD +======= + def is_symmetric(self) -> bool: + r"""Returns whether this pool is used for NCCL's symmetric memory.""" + return super().is_symmetric + + @property +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def allocator(self) -> Optional[_cuda_CUDAAllocator]: r"""Returns the allocator this MemPool routes allocations to.""" return super().allocator diff --git a/torch/cuda/tunable.py b/torch/cuda/tunable.py index d1ac7fad7480b..59468629b80a8 100644 --- a/torch/cuda/tunable.py +++ b/torch/cuda/tunable.py @@ -124,11 +124,19 @@ There are basically two steps: 1) Set the environment variables to collect the untuned GEMM and this will generate ``tunableop_untuned0.csv``: +<<<<<<< HEAD .. code-block:: bash export PYTORCH_TUNABLEOP_ENABLED=1 export PYTORCH_TUNABLEOP_TUNING=0 export PYTORCH_TUNABLEOP_RECORD_UNTUNED=1 +======= +.. code-block:: python + + PYTORCH_TUNABLEOP_ENABLED=1 + PYTORCH_TUNABLEOP_TUNING=0 + PYTORCH_TUNABLEOP_RECORD_UNTUNED=1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ... 2) Run a Python script that reads the ``tunableop_untuned0.csv`` and generates the ``tunableop_results0.csv``, like this: @@ -138,9 +146,15 @@ import torch.cuda.tunable as tunable import os +<<<<<<< HEAD os.putenv("PYTORCH_TUNABLEOP_ENABLED", "1") os.putenv("PYTORCH_TUNABLEOP_TUNING", "1") os.putenv("PYTORCH_TUNABLEOP_RECORD_UNTUNED", "0") +======= + os.putenv('PYTORCH_TUNABLEOP_ENABLED', '1') + os.putenv('PYTORCH_TUNABLEOP_TUNING', '1') + os.putenv('PYTORCH_TUNABLEOP_RECORD_UNTUNED', '0') +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tunable.tune_gemm_in_file("tunableop_untuned0.csv") @@ -155,7 +169,11 @@ .. code-block:: python if __name__ == "__main__": +<<<<<<< HEAD num_gpus = 8 # number of GPUs that will be used during the tuning process +======= + num_gpus = 8 # number of GPUs that will be used during the tuning process +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tunable.mgpu_tune_gemm_in_file("tunableop_untuned?.csv", num_gpus) Note that the usage of the ``mgpu_tune_gemm_in_file`` API is different from its single GPU counterpart @@ -179,7 +197,10 @@ Use the C++ or Python APIs instead. """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import concurrent.futures import glob import multiprocessing as mp @@ -286,7 +307,11 @@ def set_filename(filename: str, insert_device_ordinal: bool = False) -> None: If :attr:`insert_device_ordinal` is ``True`` then the current device ordinal will be added to the given filename automatically. This can be used in a +<<<<<<< HEAD 1-process-per-gpu scenario to ensure all processes write to a separate file. +======= + 1-process-per-gpu cenario to ensure all processes write to a separate file. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ torch._C._cuda_tunableop_set_filename(filename, insert_device_ordinal) # type: ignore[attr-defined] @@ -591,6 +616,10 @@ def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None: transA = layout[1] == "T" dtype = dtype_dict.get(data_type) if data_type == "tf32": +<<<<<<< HEAD +======= + # User must still set HIPBLASLT_ALLOW_TF32=1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.backends.cuda.matmul.allow_tf32 = True else: torch.backends.cuda.matmul.allow_tf32 = False diff --git a/torch/custom_class.h b/torch/custom_class.h index 01a4f2e92b28b..9a048c68f607e 100644 --- a/torch/custom_class.h +++ b/torch/custom_class.h @@ -287,7 +287,11 @@ class class_ : public ::torch::detail::class_base { /// __getstate__(intrusive_ptr) -> T1 /// __setstate__(T2) -> intrusive_ptr /// +<<<<<<< HEAD /// `T1` must be an object that is convertible to IValue by the same rules +======= + /// `T1` must be an object that is convertable to IValue by the same rules +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) /// for custom op/method registration. /// /// For the common case, T1 == T2. T1 can also be a subtype of T2. An @@ -444,7 +448,11 @@ c10::IValue make_custom_class(CtorArgs&&... args) { } // Alternative api for creating a torchbind class over torch::class_ this api is +<<<<<<< HEAD // preferred to prevent size regressions on Edge usecases. Must be used in +======= +// preffered to prevent size regressions on Edge usecases. Must be used in +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // conjunction with TORCH_SELECTIVE_CLASS macro aka // selective_class("foo_namespace", TORCH_SELECTIVE_CLASS("foo")) template diff --git a/torch/distributed/CONTRIBUTING.md b/torch/distributed/CONTRIBUTING.md index 7fd1288772e2a..5939163c984df 100644 --- a/torch/distributed/CONTRIBUTING.md +++ b/torch/distributed/CONTRIBUTING.md @@ -2,13 +2,21 @@ Please go through PyTorch's top level [Contributing Guide](../../CONTRIBUTING.md) before proceeding with this guide. +<<<<<<< HEAD [PyTorch Distributed Overview](https://pytorch.org/tutorials//beginner/dist_overview.html) is a great starting point with a lot of tutorials, documentation and design docs covering PyTorch Distributed. We highly recommend going through some of that material before you start working on PyTorch Distributed. +======= +[PyTorch Distributed Overview](https://pytorch.org/tutorials//beginner/dist_overview.html) is a great starting point with a lot of tutorials, documentation and design docs covering PyTorch Distributed. We would highly recommend going through some of that material before you start working on PyTorch Distributed. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) In this document, we mostly focus on some of the code structure for PyTorch distributed and implementation details. ### Onboarding Tasks +<<<<<<< HEAD A list of onboarding tasks can be found [here](https://github.com/pytorch/pytorch/issues?q=is%3Aopen%20is%3Aissue%20label%3A%22pt_distributed_rampup%22). +======= +A list of onboarding tasks can be found [here](https://github.com/pytorch/pytorch/issues?q=is%3Aopen+is%3Aissue+label%3A%22module%3A+distributed%22+label%3A%22topic%3A+bootcamp%22) and [here](https://github.com/pytorch/pytorch/issues?q=is%3Aopen+is%3Aissue+label%3A%22module%3A+distributed%22+label%3Apt_distributed_rampup). +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ## Code Pointers diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py index 38e2fdbee803a..1871e31a90d65 100644 --- a/torch/distributed/__init__.py +++ b/torch/distributed/__init__.py @@ -4,7 +4,10 @@ import sys import traceback import typing +<<<<<<< HEAD from datetime import timedelta +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch @@ -83,7 +86,11 @@ def interaction(self, *args, **kwargs): _breakpoint_cache: dict[int, typing.Any] = {} +<<<<<<< HEAD def breakpoint(rank: int = 0, skip: int = 0, timeout_s=3600): +======= + def breakpoint(rank: int = 0, skip: int = 0): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Set a breakpoint, but only on a single rank. All other ranks will wait for you to be done with the breakpoint before continuing. @@ -100,6 +107,7 @@ def breakpoint(rank: int = 0, skip: int = 0, timeout_s=3600): log.warning("Skip the breakpoint, counter=%d", counter) return +<<<<<<< HEAD # avoid having the default timeout (if short) interrupt your debug session if timeout_s is not None: for group in torch.distributed.distributed_c10d._pg_map: @@ -107,6 +115,8 @@ def breakpoint(rank: int = 0, skip: int = 0, timeout_s=3600): timedelta(seconds=timeout_s), group ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if get_rank() == rank: pdb = _DistributedPdb() pdb.message( diff --git a/torch/distributed/_composable/checkpoint_activation.py b/torch/distributed/_composable/checkpoint_activation.py index 2d109ad56835b..ff70573157c18 100644 --- a/torch/distributed/_composable/checkpoint_activation.py +++ b/torch/distributed/_composable/checkpoint_activation.py @@ -79,7 +79,10 @@ def checkpoint(module: nn.Module, **kwargs) -> nn.Module: user_context_fns = kwargs.pop("context_fn", None) determinism_check = kwargs.pop("determinism_check", _DEFAULT_DETERMINISM_MODE) debug = kwargs.pop("debug", False) +<<<<<<< HEAD early_stop = kwargs.pop("early_stop", True) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if kwargs: raise ValueError( @@ -104,7 +107,10 @@ def context_fns(): context_fns, determinism_check, debug, +<<<<<<< HEAD early_stop, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) *args, **kwargs, ) diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py index c893794fc3011..a4e4d5814d81e 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -19,6 +19,7 @@ from torch.utils._pytree import tree_map_only # type: ignore[no-redef] +<<<<<<< HEAD try: from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling except Exception: @@ -30,6 +31,25 @@ def is_torchdynamo_compiling(): # type: ignore[misc] return False return False +======= +if torch._running_with_deploy(): + + def is_torchdynamo_compiling(): + """Can't import torchdynamo in torchdeploy builds currently.""" + return False + +else: + try: + from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling + except Exception: + warnings.warn( + "Unable to import torchdynamo util `is_torchdynamo_compiling`, so won't support torchdynamo correctly" + ) + + def is_torchdynamo_compiling(): + return False + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ New traceable, functional collectives. @@ -815,11 +835,14 @@ def _are_we_tracing() -> bool: # If fake mode is turned on, we are almost definitely compiling/tracing. if torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) is not None: return True +<<<<<<< HEAD # See Note [enable_python_dispatcher in dynamo] if torch._C._dispatch_tls_is_dispatch_key_included( torch._C.DispatchKey.PythonDispatcher ): return True +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return get_proxy_mode() is not None @@ -869,6 +892,7 @@ def all_reduce_wait_compiled(y): ) +<<<<<<< HEAD def _make_all_gather_out_tensor(input, group_size): out_size = list(input.size()) if len(out_size) == 0: @@ -881,6 +905,16 @@ def _make_all_gather_out_tensor(input, group_size): def _all_gather_into_tensor_coalesced_meta(self, tag, rankset, group_size): return [_make_all_gather_out_tensor(t, group_size) for t in self] +======= +def _all_gather_into_tensor_coalesced_meta(self, tag, rankset, group_size): + def mk_out_tensor(shard): + out_size = list(shard.size()) + out_size[0] *= group_size + out_tensor = shard.new_empty(out_size) + return out_tensor + + return [mk_out_tensor(t) for t in self] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # We now register meta kernels to deal with tracing @@ -897,7 +931,13 @@ def _wait_tensor_meta(self, *args): def _all_gather_into_tensor_meta(shard, tag, rankset, group_size): +<<<<<<< HEAD return _make_all_gather_out_tensor(shard, group_size) +======= + out_size = list(shard.size()) + out_size[0] *= group_size + return shard.new_empty(out_size) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _reduce_scatter_tensor_meta(input, reduce_op, tag, rankset, group_size): @@ -951,11 +991,23 @@ def _all_to_all_single_meta( def _all_gather_into_tensor_out_native_meta(input, group_size, group_name, *, out): +<<<<<<< HEAD return _make_all_gather_out_tensor(input, group_size) def _all_gather_into_tensor_native_meta(input, group_size, group_name): return _make_all_gather_out_tensor(input, group_size) +======= + shape = list(input.size()) + shape[0] *= group_size + return input.new_empty(shape) + + +def _all_gather_into_tensor_native_meta(input, group_size, group_name): + shape = list(input.size()) + shape[0] *= group_size + return input.new_empty(shape) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _all_gather_into_tensor_coalesced_native_meta(inputs, group_size, group_name): @@ -980,6 +1032,7 @@ def _reduce_scatter_tensor_coalesced_native_meta( ] +<<<<<<< HEAD # Library MUST be defined at module scope or it doesn't work lib_impl = torch.library.Library("_c10d_functional", "IMPL") lib_impl.impl("all_reduce", _all_reduce_meta, "Meta") @@ -1032,6 +1085,68 @@ def _reduce_scatter_tensor_coalesced_native_meta( backend_impl = getattr(fun_col_impl, f"_{op_name}") legacy_lib.define(op_def, tags=torch.Tag.pt2_compliant_tag) legacy_lib_impl.impl(op_name, backend_impl, "CompositeImplicitAutograd") +======= +if not torch._running_with_deploy(): + # Library MUST be defined at module scope or it doesn't work + # Creating a "DEF" Library always crashes torch::deploy so we create our + # Library instances here guarded against running inside it + lib_impl = torch.library.Library("_c10d_functional", "IMPL") + lib_impl.impl("all_reduce", _all_reduce_meta, "Meta") + lib_impl.impl("all_reduce_", _all_reduce__meta, "Meta") + lib_impl.impl("all_reduce_coalesced", _all_reduce_coalesced_meta, "Meta") + lib_impl.impl("all_reduce_coalesced_", _all_reduce_coalesced__meta, "Meta") + lib_impl.impl("wait_tensor", _wait_tensor_meta, "Meta") + lib_impl.impl( + "all_gather_into_tensor_out", _all_gather_into_tensor_out_native_meta, "Meta" + ) + lib_impl.impl("all_gather_into_tensor", _all_gather_into_tensor_native_meta, "Meta") + lib_impl.impl( + "all_gather_into_tensor_coalesced", + _all_gather_into_tensor_coalesced_native_meta, + "Meta", + ) + lib_impl.impl("reduce_scatter_tensor", _reduce_scatter_tensor_native_meta, "Meta") + lib_impl.impl( + "reduce_scatter_tensor_coalesced", + _reduce_scatter_tensor_coalesced_native_meta, + "Meta", + ) + lib_impl.impl("all_to_all_single", _all_to_all_single_meta, "Meta") + lib_impl.impl("broadcast", _broadcast_meta, "Meta") + lib_impl.impl("broadcast_", _broadcast__meta, "Meta") + + # mark these ops has side effect so that they won't be removed by DCE + torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor.default) + torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor) + + # Register legacy ops for backward compatibility + # TODO(yifu): remove these in functional collective beta release + legacy_lib = torch.library.Library("c10d_functional", "DEF") + legacy_lib_impl = torch.library.Library("c10d_functional", "IMPL") + ops_defs = [ + "broadcast(Tensor self, int src, str tag, int[] ranks, int group_size) -> Tensor", + "all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor", + "all_reduce_coalesced(Tensor[] self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]", + "wait_tensor(Tensor self) -> Tensor", + "all_gather_into_tensor(Tensor shard, str tag, int[] ranks, int group_size) -> Tensor", + "all_gather_into_tensor_coalesced(Tensor[] input, str tag, int[] ranks, int group_size) -> Tensor[]", + "reduce_scatter_tensor(Tensor input, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor", + "reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]", + "all_to_all_single(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor", # noqa: B950 + ] + + my_module = sys.modules[__name__] + for op_def in ops_defs: + op_name = op_def[0 : op_def.index("(")] + backend_impl = getattr(fun_col_impl, f"_{op_name}") + legacy_lib.define(op_def, tags=torch.Tag.pt2_compliant_tag) + legacy_lib_impl.impl(op_name, backend_impl, "CompositeImplicitAutograd") + +else: + warnings.warn( + "PyTorch Distributed functional collectives do not work with torch::deploy." + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ @@ -1148,7 +1263,11 @@ def all_gather_inplace( assert not async_op, ( "Can't remap async version of inplace op to functional collective" ) +<<<<<<< HEAD assert tensor.dim() == 0 or all(t.size(0) == tensor.size(0) for t in tensor_list), ( +======= + assert all(t.size(0) == tensor.size(0) for t in tensor_list), ( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "Remapping variable size all_gather is not yet supported" ) @@ -1162,11 +1281,16 @@ def all_gather_inplace( output_splits = [] offset = 0 for t in tensor_list: +<<<<<<< HEAD is_scalar = t.dim() == 0 t_offset = 1 if is_scalar else t.size(0) out = output[offset] if is_scalar else output[offset : offset + t_offset] output_splits.append(out) offset += t_offset +======= + output_splits.append(output[offset : offset + t.size(0)]) + offset += t.size(0) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for dst, src in zip(tensor_list, output_splits): dst.copy_(src) return tensor_list diff --git a/torch/distributed/_shard/sharded_tensor/api.py b/torch/distributed/_shard/sharded_tensor/api.py index 772483322cc56..fa8e726d69fa0 100644 --- a/torch/distributed/_shard/sharded_tensor/api.py +++ b/torch/distributed/_shard/sharded_tensor/api.py @@ -1146,12 +1146,17 @@ def reshard(self, resharding_spec: shard_spec.ShardingSpec) -> ShardedTensor: resharding_spec, shard_spec.ChunkShardingSpec ) or not isinstance(self._sharding_spec, shard_spec.ChunkShardingSpec): raise NotImplementedError("Only ChunkShardingSpec supported for reshard.") +<<<<<<< HEAD num_local_shards = len(self.local_shards()) if num_local_shards != 1: raise NotImplementedError( f"Only single local shard supported for reshard. Number of shards: {num_local_shards}" ) +======= + if len(self.local_shards()) != 1: + raise NotImplementedError("Only single local shard supported for reshard.") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self._sharding_spec.dim == resharding_spec.dim: # type: ignore[attr-defined] if self._sharding_spec.placements == resharding_spec.placements: # type: ignore[attr-defined] @@ -1184,11 +1189,16 @@ def local_tensor(self) -> torch.Tensor: Returns: A :class:`torch.Tensor` of the local shard. """ +<<<<<<< HEAD num_local_shards = len(self.local_shards()) if num_local_shards != 1: raise NotImplementedError( f"Only single local shard is supported. Number of shards: {num_local_shards}" ) +======= + if len(self.local_shards()) != 1: + raise NotImplementedError("Only single local shard is supported.") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.local_shards()[0].tensor @classmethod diff --git a/torch/distributed/_state_dict_utils.py b/torch/distributed/_state_dict_utils.py index 8c527e7efe5d4..f0be3a6773e79 100644 --- a/torch/distributed/_state_dict_utils.py +++ b/torch/distributed/_state_dict_utils.py @@ -423,7 +423,11 @@ def tensor_func( t = t.share_memory_() if pin_memory: pin_memory_utils.pin_memory(t.data_ptr(), t.numel() * t.element_size()) +<<<<<<< HEAD weakref.finalize(t, pin_memory_utils.unpin_memory, t.data_ptr()) +======= + weakref.finalize(t, pin_memory_utils.unpin_memory, t) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return t elif pin_memory: diff --git a/torch/distributed/_symmetric_memory/__init__.py b/torch/distributed/_symmetric_memory/__init__.py index 43c2959fdd8d1..d208e274fa3fa 100644 --- a/torch/distributed/_symmetric_memory/__init__.py +++ b/torch/distributed/_symmetric_memory/__init__.py @@ -1,5 +1,8 @@ +<<<<<<< HEAD from __future__ import annotations +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import math import os import socket @@ -9,7 +12,11 @@ from datetime import timedelta from enum import Enum from functools import partial +<<<<<<< HEAD from typing import Any, Callable, Literal +======= +from typing import Any, Callable, Optional +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch.distributed._functional_collectives as funcol @@ -49,11 +56,19 @@ def enable_symm_mem_for_group(group_name: str) -> None: _is_test_mode: bool = False +<<<<<<< HEAD _mocked_group_names: set[str] | None = None @contextmanager def _test_mode(group_names: set[str] | None = None) -> Generator[None, None, None]: +======= +_mocked_group_names: Optional[set[str]] = None + + +@contextmanager +def _test_mode(group_names: Optional[set[str]] = None) -> Generator[None, None, None]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Forces ``is_symm_mem_enabled_for_group()`` to return ``True`` and the ops defined in the ``symm_mem`` namespace to use fallback implementations. @@ -85,7 +100,11 @@ def is_symm_mem_enabled_for_group(group_name: str) -> bool: return group_name in _group_name_to_store +<<<<<<< HEAD _group_name_to_workspace_tensor: dict[str, torch.Tensor | None] = {} +======= +_group_name_to_workspace_tensor: dict[str, Optional[torch.Tensor]] = {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_symm_mem_workspace(group_name: str, min_size: int) -> _SymmetricMemory: @@ -471,7 +490,11 @@ class _ScaleMode(Enum): def _check_and_verify_fp8_all_gather_scale_mode( +<<<<<<< HEAD shard: torch.Tensor, scale: torch.Tensor | None, gather_dim: int, group_size: int +======= + shard: torch.Tensor, scale: Optional[torch.Tensor], gather_dim: int, group_size: int +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> _ScaleMode: full_shape = list(shard.shape) full_shape[gather_dim] *= group_size @@ -500,6 +523,7 @@ def _fused_all_gather_matmul_impl( mm_out_op: torch._ops.OpOverload, A_shard: torch.Tensor, Bs: list[torch.Tensor], +<<<<<<< HEAD A_scale: torch.Tensor | None, kwargs_list: list[dict[str, Any]], out_dtypes: list[torch.dtype | None], @@ -507,6 +531,15 @@ def _fused_all_gather_matmul_impl( group_name: str, return_A: bool, ) -> tuple[torch.Tensor | None, list[torch.Tensor]]: +======= + A_scale: Optional[torch.Tensor], + kwargs_list: list[dict[str, Any]], + out_dtypes: list[Optional[torch.dtype]], + gather_dim: int, + group_name: str, + return_A: bool, +) -> tuple[Optional[torch.Tensor], list[torch.Tensor]]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if A_shard.dim() < 2: raise ValueError("A_shard must be a matrix") for B in Bs: @@ -629,7 +662,11 @@ def _fused_all_gather_matmul_fallback( group_name: str, *, return_A: bool = True, +<<<<<<< HEAD ) -> tuple[torch.Tensor | None, list[torch.Tensor]]: +======= +) -> tuple[Optional[torch.Tensor], list[torch.Tensor]]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) group_size = c10d._get_group_size_by_name(group_name) A = torch.ops._c10d_functional.all_gather_into_tensor( A_shard.contiguous(), group_size, group_name @@ -651,7 +688,11 @@ def _fused_all_gather_matmul( group_name: str, *, return_A: bool = True, +<<<<<<< HEAD ) -> tuple[torch.Tensor | None, list[torch.Tensor]]: +======= +) -> tuple[Optional[torch.Tensor], list[torch.Tensor]]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Perform the following logic with micro-pipelined computation and communication: @@ -821,9 +862,15 @@ def _fused_all_gather_scaled_matmul_fallback( B_scales: list[torch.Tensor], gather_dim: int, group_name: str, +<<<<<<< HEAD biases: list[torch.Tensor | None], result_scales: list[torch.Tensor | None], out_dtypes: list[torch.dtype | None], +======= + biases: list[Optional[torch.Tensor]], + result_scales: list[Optional[torch.Tensor]], + out_dtypes: list[Optional[torch.dtype]], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) use_fast_accum: list[bool], ) -> tuple[torch.Tensor, list[torch.Tensor]]: out_dtypes = _maybe_convert_scalar_types_to_dtypes(out_dtypes) @@ -859,9 +906,15 @@ def scaled_matmul( B: torch.Tensor, A_scale: torch.Tensor, B_scale: torch.Tensor, +<<<<<<< HEAD bias: torch.Tensor | None, result_scale: torch.Tensor | None, out_dtype: torch.dtype | None, +======= + bias: Optional[torch.Tensor], + result_scale: Optional[torch.Tensor], + out_dtype: Optional[torch.dtype], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) use_fast_accum: bool, ) -> torch.Tensor: leading_dims = A.shape[:-1] @@ -895,9 +948,15 @@ def _fused_all_gather_scaled_matmul( B_scales: list[torch.Tensor], gather_dim: int, group_name: str, +<<<<<<< HEAD biases: list[torch.Tensor | None], result_scales: list[torch.Tensor | None], out_dtypes: list[torch.dtype | None], +======= + biases: list[Optional[torch.Tensor]], + result_scales: list[Optional[torch.Tensor]], + out_dtypes: list[Optional[torch.dtype]], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) use_fast_accum: list[bool], ) -> tuple[torch.Tensor, list[torch.Tensor]]: """ @@ -1048,7 +1107,11 @@ def _fused_matmul_reduce_scatter_impl( A: torch.Tensor, B: torch.Tensor, kwargs: dict[str, Any], +<<<<<<< HEAD out_dtype: torch.dtype | None, +======= + out_dtype: Optional[torch.dtype], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) reduce_op: str, scatter_dim: int, group_name: str, @@ -1110,9 +1173,15 @@ def _fused_scaled_matmul_reduce_scatter( scatter_dim_after_maybe_reshape: int, group_name: str, output_shape: list[int], +<<<<<<< HEAD bias: torch.Tensor | None = None, result_scale: torch.Tensor | None = None, out_dtype: torch.dtype | None = None, +======= + bias: Optional[torch.Tensor] = None, + result_scale: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) use_fast_accum: bool = False, ) -> torch.Tensor: if _is_test_mode: @@ -1164,9 +1233,15 @@ def _fused_scaled_matmul_reduce_scatter_fallback( scatter_dim_after_maybe_reshape: int, group_name: str, output_shape: list[int], +<<<<<<< HEAD bias: torch.Tensor | None = None, result_scale: torch.Tensor | None = None, out_dtype: torch.dtype | None = None, +======= + bias: Optional[torch.Tensor] = None, + result_scale: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) use_fast_accum: bool = False, ) -> torch.Tensor: if A_scale.numel() > 1: @@ -1210,7 +1285,11 @@ def _fused_scaled_matmul_reduce_scatter_impl( B: torch.Tensor, A_scale: torch.Tensor, kwargs: dict[str, Any], +<<<<<<< HEAD out_dtype: torch.dtype | None, +======= + out_dtype: Optional[torch.dtype], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) reduce_op: str, orig_scatter_dim: int, scatter_dim_after_maybe_reshape: int, @@ -1272,11 +1351,14 @@ def _fused_scaled_matmul_reduce_scatter_impl( .flatten(0, -2) ) A_scale_shards = list(A_scale.chunk(group.size())) +<<<<<<< HEAD # cuBLAS's row-wise kernel requires scales to be aligned to 16 bytes. # When we slice them we might break this and need to reallocate them. A_scale_shards = [ t if t.data_ptr() % 16 == 0 else t.clone() for t in A_scale_shards ] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: raise ValueError("A_scale cannot be none for scaled_mm") @@ -1352,7 +1434,11 @@ def restride_A_for_fused_matmul_reduce_scatter( def _maybe_convert_scalar_types_to_dtypes( scalar_types: list[Any], +<<<<<<< HEAD ) -> list[torch.dtype | None]: +======= +) -> list[Optional[torch.dtype]]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ When a list of `torch.dtype`s is passed through the dispatcher as `ScalarType[]`, it is converted to a list of scalar type enum values. This @@ -1384,7 +1470,11 @@ def _maybe_convert_scalar_types_to_dtypes( if any(not isinstance(x, (type(None), int)) for x in scalar_types): return scalar_types +<<<<<<< HEAD dtypes: list[torch.dtype | None] = [] +======= + dtypes: list[Optional[torch.dtype]] = [] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for scalar_type in scalar_types: if scalar_type is None: dtypes.append(scalar_type) @@ -1616,6 +1706,7 @@ def _low_contention_reduce_scatter( ) +<<<<<<< HEAD @torch.library.impl(lib, "all_to_all_vdev_2d", "Meta") def _all_to_all_vdev_2d_meta( input: torch.Tensor, @@ -1639,23 +1730,38 @@ def _all_to_all_vdev_2d_offset_meta( return None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # ============================================================================= # User-facing APIs # ============================================================================= from collections.abc import Sequence +<<<<<<< HEAD from typing import overload, TYPE_CHECKING, Union +======= +from typing import Any, overload, TYPE_CHECKING, Union + +from torch.types import _device, _dtype, _int +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if TYPE_CHECKING: from torch._C._distributed_c10d import ProcessGroup +<<<<<<< HEAD from torch.types import _device, _dtype, _int +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @overload def empty( +<<<<<<< HEAD *size: _int, dtype: _dtype | None = None, device: _device | None = None +======= + *size: _int, dtype: Optional[_dtype] = None, device: Optional[_device] = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> torch.Tensor: ... @@ -1663,15 +1769,25 @@ def empty( def empty( size: Sequence[_int], *, +<<<<<<< HEAD dtype: _dtype | None = None, device: _device | None = None, +======= + dtype: Optional[_dtype] = None, + device: Optional[_device] = None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> torch.Tensor: ... def empty( # type: ignore[misc] *size: Any, +<<<<<<< HEAD dtype: _dtype | None = None, device: _device | None = None, +======= + dtype: Optional[_dtype] = None, + device: Optional[_device] = None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> torch.Tensor: r""" empty(*size, *, dtype=None, device=None) -> Tensor @@ -1712,7 +1828,11 @@ def empty( # type: ignore[misc] def rendezvous( +<<<<<<< HEAD tensor: torch.Tensor, group: Union[str, ProcessGroup] +======= + tensor: torch.Tensor, group: Union[str, "ProcessGroup"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> _SymmetricMemory: r""" rendezvous(tensor, group) -> _SymmetricMemory @@ -1756,6 +1876,7 @@ def is_nvshmem_available() -> bool: return _is_nvshmem_available() +<<<<<<< HEAD def set_backend(name: Literal["NVSHMEM", "CUDA", "NCCL"]) -> None: r""" Set the backend for symmetric memory allocation. This is a global setting @@ -1793,3 +1914,6 @@ def get_mempool_allocator(device: _device): # type: ignore[no-untyped-def] __all__ = ["empty", "rendezvous", "is_nvshmem_available", "set_backend", "get_backend"] +======= +__all__ = ["empty", "rendezvous", "is_nvshmem_available"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/distributed/_symmetric_memory/_nvshmem_triton.py b/torch/distributed/_symmetric_memory/_nvshmem_triton.py index 0d5e88e91805a..ac77215281db0 100644 --- a/torch/distributed/_symmetric_memory/_nvshmem_triton.py +++ b/torch/distributed/_symmetric_memory/_nvshmem_triton.py @@ -1,3 +1,4 @@ +<<<<<<< HEAD import logging import os import subprocess @@ -426,17 +427,126 @@ def putmem_signal_block_extern_wrapper( # type: ignore[no-untyped-def] sig_op, pe, _semantic=None, +======= +import os +import sysconfig +from typing import Optional + +from torch.utils._triton import has_triton + + +def enable_triton(lib_dir: Optional[str] = None) -> dict[str, str]: + """ + Enable NVSHMEM device functions for Triton. It performs a NVSHMEM + device-side initialization on the kernel module created by Triton. + + Args: + lib_dir (Optional[str]): The directory where the NVSHMEM device library + is located. If not provided, it will use the default path where NVSHMEM + wheel is installed. + + Returns: + dict[str, str]: A dictionary containing the NVSHMEM device library name + and path. + """ + from triton.runtime.jit import JITFunction + + from torch._C._distributed_c10d import _nvshmemx_cumodule_init + + # Detect NVSHMEM device library path from python library path + if lib_dir is None: + py_lib_path = sysconfig.get_path("purelib") + lib_dir = py_lib_path + "/nvidia/nvshmem/lib" + + lib_path = os.path.join(lib_dir, "libnvshmem_device.bc") + if not os.path.exists(lib_path): + raise RuntimeError("NVSHMEM device library not found") + + extern_libs = {"libnvshmem_device": lib_path} + + # A hook function to initialize NVSHMEM in Triton + def nvshmem_init_hook(*args, **kwargs) -> None: # type: ignore[no-untyped-def] + key = kwargs["key"] + device = kwargs["compile"]["device"] + jit_function = kwargs["fn"].jit_function + kernel_cache, _, _, _ = jit_function.device_caches[device] + kernel = kernel_cache.get(key, None) + kernel.run + _nvshmemx_cumodule_init(kernel.module) + + # Register the function as a post-compile hook + JITFunction.compiled_hook = nvshmem_init_hook + + # Return to user so that they can use it in Triton kernel invocation + return extern_libs + + +if has_triton(): + from triton.language import core + + @core.extern + def putmem_block(dst, src, nelems, pe, _builder=None): # type: ignore[no-untyped-def] + return core.extern_elementwise( + "", + "", + [dst, src, nelems, pe], + { + ( + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + ): ("nvshmemx_putmem_block", core.dtype("int32")) + }, + is_pure=False, + _builder=_builder, + ) + + @core.extern + def getmem_block(dst, src, nelems, pe, _builder=None): # type: ignore[no-untyped-def] + return core.extern_elementwise( + "", + "", + [dst, src, nelems, pe], + { + ( + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + ): ("nvshmemx_getmem_block", core.dtype("int32")) + }, + is_pure=False, + _builder=_builder, + ) + + @core.extern + def putmem_signal_block( # type: ignore[no-untyped-def] + dst, + src, + nelems, + sig_addr, + signal, + sig_op, + pe, + _builder=None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): # type: ignore[no-untyped-def] return core.extern_elementwise( "", "", +<<<<<<< HEAD [dst, src, size_bytes, signal, sig_val, sig_op, pe], +======= + [dst, src, nelems, sig_addr, signal, sig_op, pe], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) { ( core.dtype("int64"), core.dtype("int64"), core.dtype("int64"), core.dtype("int64"), +<<<<<<< HEAD core.dtype("uint64"), core.dtype("int32"), core.dtype("int32"), @@ -488,6 +598,19 @@ def wait_until(ivar, cmp_op, cmp_val): # type: ignore[no-untyped-def] @core.extern def wait_until_extern_wrapper(ivar, cmp, cmp_val, _semantic=None): # type: ignore[no-untyped-def] +======= + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + ): ("nvshmemx_putmem_signal_block", core.dtype("int32")) + }, + is_pure=False, + _builder=_builder, + ) + + @core.extern + def wait_until(ivar, cmp, cmp_val, _builder=None): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return core.extern_elementwise( "", "", @@ -495,6 +618,7 @@ def wait_until_extern_wrapper(ivar, cmp, cmp_val, _semantic=None): # type: igno { ( core.dtype("int64"), +<<<<<<< HEAD core.dtype("int32"), core.dtype("int32"), ): ("nvshmem_int_wait_until", core.dtype("int32")) @@ -654,6 +778,35 @@ def fence(_semantic=None): # type: ignore[no-untyped-def] nvshmem.put(dst2, src2, nelems, target_pe) ``` """ +======= + core.dtype("int64"), + core.dtype("int64"), + ): ("nvshmem_longlong_wait_until", core.dtype("int32")) + }, + is_pure=False, + _builder=_builder, + ) + + @core.extern + def signal_wait_until(sig_addr, cmp, cmp_val, _builder=None): # type: ignore[no-untyped-def] + return core.extern_elementwise( + "", + "", + [sig_addr, cmp, cmp_val], + { + ( + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + ): ("nvshmem_signal_wait_until", core.dtype("int32")) + }, + is_pure=False, + _builder=_builder, + ) + + @core.extern + def fence(_builder=None): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return core.extern_elementwise( "", "", @@ -662,6 +815,7 @@ def fence(_semantic=None): # type: ignore[no-untyped-def] (): ("nvshmem_fence", core.dtype("int32")), }, is_pure=False, +<<<<<<< HEAD _semantic=_semantic, ) @@ -702,6 +856,13 @@ def quiet(_semantic=None): # type: ignore[no-untyped-def] ) # Signal completion ``` """ +======= + _builder=_builder, + ) + + @core.extern + def quiet(_builder=None): # type: ignore[no-untyped-def] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return core.extern_elementwise( "", "", @@ -710,6 +871,7 @@ def quiet(_semantic=None): # type: ignore[no-untyped-def] (): ("nvshmem_quiet", core.dtype("int32")), }, is_pure=False, +<<<<<<< HEAD _semantic=_semantic, ) @@ -1164,3 +1326,7 @@ def on_exit() -> None: if kernel not in triton_kernels: triton_kernels[kernel] = None +======= + _builder=_builder, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/distributed/_tools/fake_collectives.py b/torch/distributed/_tools/fake_collectives.py index 3b201b395334b..9ca664e8b489a 100644 --- a/torch/distributed/_tools/fake_collectives.py +++ b/torch/distributed/_tools/fake_collectives.py @@ -63,9 +63,16 @@ def create_fakework(args, return_first_arg=True): # type: ignore[no-untyped-def "recv_any_source_": lambda *args: create_fakework(args, return_first_arg=False), } +<<<<<<< HEAD lib_impl = torch.library.Library("c10d", "IMPL") # noqa: TOR901 for op, meta_func in _META_FUNCTIONS.items(): lib_impl.impl(op, meta_func, "Meta") +======= +if not torch._running_with_deploy(): + lib_impl = torch.library.Library("c10d", "IMPL") # noqa: TOR901 + for op, meta_func in _META_FUNCTIONS.items(): + lib_impl.impl(op, meta_func, "Meta") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # List of collective operation functions including functional collectives # Note: The following collectives might be deprecated soon hence not adding them diff --git a/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py index 6153d8e03fdff..c80caf655c00b 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py @@ -232,11 +232,21 @@ def hook_with_zero_step( ) ddp_ref = weakref.ref(ddp) +<<<<<<< HEAD # NOTE: Gloo may hang with this overlapping approach; see https://github.com/pytorch/pytorch/issues/62300 pg = dist.get_backend(ddp_ref().process_group) # type: ignore[union-attr] if pg == dist.Backend.GLOO: raise RuntimeError( "Gloo backend using Overlapping DDP with ZeRO may meet hangs" +======= + # NOTE: Gloo may hang with this overlapping approach, so we require + # NCCL/HCCL backend for now; see https://github.com/pytorch/pytorch/issues/62300 + pg = dist.get_backend(ddp_ref().process_group) # type: ignore[union-attr] + if (pg != dist.Backend.NCCL) and (pg != "hccl"): + raise RuntimeError( + "Overlapping DDP with ZeRO using this approach currently requires " + "NCCL/HCCL backend to avoid hangs" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if shard_buckets: @@ -392,11 +402,21 @@ def hook_with_zero_step_interleaved( ) ddp_ref = weakref.ref(ddp) +<<<<<<< HEAD # NOTE: Gloo may hang with this overlapping approach; see https://github.com/pytorch/pytorch/issues/62300 pg = dist.get_backend(ddp_ref().process_group) # type: ignore[union-attr] if pg == dist.Backend.GLOO: raise RuntimeError( "Gloo backend using Overlapping DDP with ZeRO may meet hangs" +======= + # NOTE: Gloo may hang with this overlapping approach, so we require + # NCCL/HCCL backend for now; see https://github.com/pytorch/pytorch/issues/62300 + pg = dist.get_backend(ddp_ref().process_group) # type: ignore[union-attr] + if (pg != dist.Backend.NCCL) and (pg != "hccl"): + raise RuntimeError( + "Overlapping DDP with ZeRO using this approach currently requires " + "NCCL/HCCL backend to avoid hangs" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if shard_buckets: diff --git a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py index f1e95d12514ed..42c284fd9138c 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py @@ -434,7 +434,11 @@ def powerSGD_hook( # Keep a copy of the input tensor, # so that we can compute the local error caused by compression later, # by comparing this copy and the input tensor updated after decompression. +<<<<<<< HEAD input_tensor_cp = input_tensor.detach().clone() +======= + input_tensor_cp = torch.clone(input_tensor).detach() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Unflatten the input tensor into per-parameter tensors, for layer-wise compression. tensors = bucket.gradients() @@ -631,7 +635,10 @@ def decompress(fut): if state.use_error_feedback: # Memorize the local errors. +<<<<<<< HEAD assert input_tensor_cp is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) state.error_dict[bucket_index] = input_tensor_cp - input_tensor if not state.warm_start: state.p_memory_dict.clear() @@ -757,7 +764,11 @@ def batched_powerSGD_hook( # Keep a copy of the input tensor, # so that we can compute the local error caused by compression later, # by comparing this copy and the input tensor updated after decompression. +<<<<<<< HEAD input_tensor_cp = input_tensor.detach().clone() +======= + input_tensor_cp = torch.clone(input_tensor).detach() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) matrix = input_tensor.view(square_side_length, square_side_length) # Reuse P and Q from the previous iteration if possible. @@ -844,7 +855,10 @@ def decompress(fut): if state.use_error_feedback: # Memorize the local errors. +<<<<<<< HEAD assert input_tensor_cp is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) state.error_dict[bucket_index] = input_tensor_cp - input_tensor # Removing this seemingly unnecessary sync somehow may cause failures. # See: https://github.com/pytorch/pytorch/pull/54838 diff --git a/torch/distributed/autograd/__init__.py b/torch/distributed/autograd/__init__.py index 6a52c36942e48..1ce06fd4eff78 100644 --- a/torch/distributed/autograd/__init__.py +++ b/torch/distributed/autograd/__init__.py @@ -1,15 +1,23 @@ +<<<<<<< HEAD from __future__ import annotations from typing import Any, TYPE_CHECKING +======= +# mypy: allow-untyped-defs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch +<<<<<<< HEAD if TYPE_CHECKING: from types import TracebackType def is_available() -> bool: +======= +def is_available(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return hasattr(torch._C, "_dist_autograd_init") @@ -31,8 +39,11 @@ def is_available() -> bool: get_gradients, ) +<<<<<<< HEAD __all__ = ["context", "is_available"] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class context: """ @@ -53,6 +64,7 @@ class context: >>> dist_autograd.backward(context_id, [loss]) """ +<<<<<<< HEAD def __enter__(self) -> int: self.autograd_context = _new_context() return self.autograd_context._context_id() @@ -63,4 +75,11 @@ def __exit__( exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: +======= + def __enter__(self): + self.autograd_context = _new_context() + return self.autograd_context._context_id() + + def __exit__(self, type, value, traceback): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _release_context(self.autograd_context._context_id()) diff --git a/torch/distributed/checkpoint/__init__.py b/torch/distributed/checkpoint/__init__.py index c9eb7de5b25a8..713b4ce7bac07 100644 --- a/torch/distributed/checkpoint/__init__.py +++ b/torch/distributed/checkpoint/__init__.py @@ -11,7 +11,10 @@ ) from .optimizer import load_sharded_optimizer_state_dict from .planner import LoadPlan, LoadPlanner, ReadItem, SavePlan, SavePlanner, WriteItem +<<<<<<< HEAD from .quantized_hf_storage import QuantizedHuggingFaceStorageReader +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .state_dict_loader import load, load_state_dict from .state_dict_saver import async_save, save, save_state_dict from .storage import StorageReader, StorageWriter diff --git a/torch/distributed/checkpoint/_async_executor.py b/torch/distributed/checkpoint/_async_executor.py index 428c697b91e9b..c80dc141a8df4 100644 --- a/torch/distributed/checkpoint/_async_executor.py +++ b/torch/distributed/checkpoint/_async_executor.py @@ -15,14 +15,21 @@ class _AsyncCheckpointExecutor(abc.ABC): @abc.abstractmethod def execute_save( self, +<<<<<<< HEAD staging_future_or_state_dict: Union[STATE_DICT_TYPE, Future[STATE_DICT_TYPE]], +======= + staged_state_dict: STATE_DICT_TYPE, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) *, checkpoint_id: Union[str, os.PathLike, None] = None, storage_writer: Optional[StorageWriter] = None, planner: Optional[SavePlanner] = None, process_group: Optional[dist.ProcessGroup] = None, +<<<<<<< HEAD no_dist: bool = False, use_collectives: bool = True, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Future: """ Execute the checkpoint save request asynchronously. diff --git a/torch/distributed/checkpoint/_async_process_executor.py b/torch/distributed/checkpoint/_async_process_executor.py index e708433058440..8fa3fdd8dc987 100644 --- a/torch/distributed/checkpoint/_async_process_executor.py +++ b/torch/distributed/checkpoint/_async_process_executor.py @@ -44,8 +44,11 @@ class _AsyncCheckpointRequest: checkpoint_request_id: _CheckpointRequestIdentifier storage_writer: Optional[StorageWriter] = None planner: Optional[SavePlanner] = None +<<<<<<< HEAD no_dist: bool = False use_collectives: bool = True +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclass(init=False) @@ -152,8 +155,11 @@ def save( checkpoint_id: Union[str, os.PathLike, None] = None, storage_writer: Optional[StorageWriter] = None, planner: Optional[SavePlanner] = None, +<<<<<<< HEAD no_dist: bool = False, use_collectives: bool = True, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Metadata: # Create a unique identifier to locate requests/responses # from the checkpoint daemon process. @@ -163,8 +169,11 @@ def save( checkpoint_request_id=checkpoint_request_id, storage_writer=storage_writer, planner=planner, +<<<<<<< HEAD no_dist=no_dist, use_collectives=use_collectives, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) self._send(async_cp_request) result = self._wait_for_response() @@ -178,8 +187,11 @@ def _execute_save( checkpoint_request_id: _CheckpointRequestIdentifier, storage_writer: Optional[StorageWriter] = None, planner: Optional[SavePlanner] = None, +<<<<<<< HEAD no_dist: bool = False, use_collectives: bool = True, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Metadata: from torch.distributed.checkpoint.state_dict_saver import save @@ -188,8 +200,11 @@ def _execute_save( checkpoint_id=checkpoint_request_id.checkpoint_id, storage_writer=storage_writer, planner=planner, +<<<<<<< HEAD no_dist=no_dist, use_collectives=use_collectives, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return metadata @@ -198,8 +213,11 @@ def _checkpointing_subprocess( pg_init_info: _ProcessGroupInitInfo, parent_conn, ) -> None: +<<<<<<< HEAD # Phase 1: Process Group Initialization # Only needs to execute once during the lifetime of the checkpoint background process. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: _init_logger(pg_init_info.global_rank) @@ -220,6 +238,7 @@ def _checkpointing_subprocess( logger.info("Checkpoint background process is running...") parent_conn.send(_CheckpointSaveProcessControlOpts.INIT_COMPLETE) +<<<<<<< HEAD except BaseException as e: # noqa: B036 logger.error( f"Checkpoint background process failed during initialization: {e}" # noqa: G004 @@ -229,6 +248,10 @@ def _checkpointing_subprocess( # Phase 2: Serving Loop try: +======= + + # Serving loop. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) while True: logger.info("Waiting for checkpoint save request...") obj = parent_conn.recv() @@ -243,6 +266,7 @@ def _checkpointing_subprocess( f"Received async checkpoint request with id={obj.checkpoint_request_id.checkpoint_id}" # noqa: G004 ) +<<<<<<< HEAD try: response = _AsyncCheckpointProcess._execute_save( obj.staged_state_dict, @@ -262,6 +286,24 @@ def _checkpointing_subprocess( ) parent_conn.send(e) # Continue serving loop - don't exit process +======= + response = _AsyncCheckpointProcess._execute_save( + obj.staged_state_dict, + checkpoint_request_id=obj.checkpoint_request_id, + storage_writer=obj.storage_writer, + planner=obj.planner, + ) + parent_conn.send(response) + logger.info( + f"Submitted checkpoint save request for checkpoint_id={obj.checkpoint_request_id}" # noqa: G004 + ) + except BaseException as e: + logger.error( + f"Checkpoint background process encountered an exception: {e}" # noqa: G004 + ) + parent_conn.send(e) + raise +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) finally: logger.info("Checkpoint background process is shutting down...") dist.destroy_process_group() @@ -279,13 +321,20 @@ def __init__(self) -> None: def _execute_save_impl( *, pg_init_info: Optional[_ProcessGroupInitInfo], +<<<<<<< HEAD staging_future_or_state_dict: Union[Future[STATE_DICT_TYPE], STATE_DICT_TYPE], +======= + staged_state_dict: STATE_DICT_TYPE, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) checkpoint_id: Union[str, os.PathLike, None] = None, storage_writer: Optional[StorageWriter] = None, planner: Optional[SavePlanner] = None, process_group: Optional[dist.ProcessGroup] = None, +<<<<<<< HEAD no_dist: bool = False, use_collectives: bool = True, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Metadata: global _CHECKPOINT_PROCESS if _CHECKPOINT_PROCESS is None: @@ -303,30 +352,43 @@ def create_checkpoint_daemon_process() -> None: create_checkpoint_daemon_process() assert _CHECKPOINT_PROCESS is not None +<<<<<<< HEAD staged_state_dict = ( staging_future_or_state_dict.result() if isinstance(staging_future_or_state_dict, Future) else staging_future_or_state_dict ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return _CHECKPOINT_PROCESS.save( staged_state_dict=staged_state_dict, checkpoint_id=checkpoint_id, storage_writer=storage_writer, planner=planner, +<<<<<<< HEAD no_dist=no_dist, use_collectives=use_collectives, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def execute_save( self, +<<<<<<< HEAD staging_future_or_state_dict: Union[Future[STATE_DICT_TYPE], STATE_DICT_TYPE], +======= + staged_state_dict: STATE_DICT_TYPE, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) *, checkpoint_id: Union[str, os.PathLike, None] = None, storage_writer: Optional[StorageWriter] = None, planner: Optional[SavePlanner] = None, process_group: Optional[dist.ProcessGroup] = None, +<<<<<<< HEAD no_dist: bool = False, use_collectives: bool = True, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Future: """ NOTE: @@ -353,12 +415,19 @@ def execute_save( f: Future = self._executor.submit( self._execute_save_impl, pg_init_info=pg_init_info, +<<<<<<< HEAD staging_future_or_state_dict=staging_future_or_state_dict, checkpoint_id=checkpoint_id, storage_writer=storage_writer, planner=planner, no_dist=no_dist, use_collectives=use_collectives, +======= + staged_state_dict=staged_state_dict, + checkpoint_id=checkpoint_id, + storage_writer=storage_writer, + planner=planner, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) f.add_done_callback(lambda f: self._executor.shutdown(wait=False)) diff --git a/torch/distributed/checkpoint/_async_thread_executor.py b/torch/distributed/checkpoint/_async_thread_executor.py index 8dfe63413d433..8385b9697c734 100644 --- a/torch/distributed/checkpoint/_async_thread_executor.py +++ b/torch/distributed/checkpoint/_async_thread_executor.py @@ -11,6 +11,7 @@ from torch.distributed.checkpoint.storage import StorageWriter +<<<<<<< HEAD def save_wrapper( staging_future_or_state_dict: Union[Future[STATE_DICT_TYPE], STATE_DICT_TYPE], *, @@ -48,23 +49,44 @@ def __init__(self) -> None: def execute_save( self, staging_future_or_state_dict: Union[Future[STATE_DICT_TYPE], STATE_DICT_TYPE], +======= +class _ThreadBasedAsyncCheckpointExecutor(_AsyncCheckpointExecutor): + def __init__(self) -> None: + self._executor = ThreadPoolExecutor(max_workers=1) + + def execute_save( + self, + staged_state_dict: STATE_DICT_TYPE, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) *, checkpoint_id: Union[str, os.PathLike, None] = None, storage_writer: Optional[StorageWriter] = None, planner: Optional[SavePlanner] = None, process_group: Optional[dist.ProcessGroup] = None, +<<<<<<< HEAD no_dist: bool = False, use_collectives: bool = True, ) -> Future: f: Future = self._executor.submit( save_wrapper, staging_future_or_state_dict=staging_future_or_state_dict, +======= + ) -> Future: + from torch.distributed.checkpoint.state_dict_saver import save + + f: Future = self._executor.submit( + save, + staged_state_dict, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) checkpoint_id=checkpoint_id, storage_writer=storage_writer, planner=planner, process_group=process_group, +<<<<<<< HEAD no_dist=no_dist, use_collectives=use_collectives, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) f.add_done_callback(lambda f: self._executor.shutdown(wait=False)) diff --git a/torch/distributed/checkpoint/_checkpointer.py b/torch/distributed/checkpoint/_checkpointer.py index d21d8248d2047..7e35d2988f7d8 100644 --- a/torch/distributed/checkpoint/_checkpointer.py +++ b/torch/distributed/checkpoint/_checkpointer.py @@ -83,14 +83,21 @@ def async_save( Returns: Future: A future holding the resultant Metadata object from `save`. """ +<<<<<<< HEAD response = saver.async_save( +======= + return saver.async_save( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) state_dict, storage_writer=self.storage_writer, process_group=self.process_group, planner=self.save_planner, ) +<<<<<<< HEAD assert isinstance(response, Future) return response +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def load(self, state_dict: dict[str, Any]) -> None: """Calls :py:meth: `torch.distributed.state_dict_loader.load`. Utilizing values passed during initialization.""" diff --git a/torch/distributed/checkpoint/_hf_utils.py b/torch/distributed/checkpoint/_hf_utils.py index 0d14229b7f8cc..98263086b8c8c 100644 --- a/torch/distributed/checkpoint/_hf_utils.py +++ b/torch/distributed/checkpoint/_hf_utils.py @@ -41,16 +41,24 @@ FORMAT_KEY = "format" FORMAT_VALUE = "pt" +<<<<<<< HEAD NUM_BYTES_FOR_HEADER_LEN = 8 SHARDED_DIR_NAME = "sharded" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclass class _HFStorageInfo: """This is the per entry storage info.""" relative_path: str +<<<<<<< HEAD +======= + offset: int + length: int +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shape: torch.Size dtype: torch.dtype @@ -82,11 +90,20 @@ def _get_safetensors_file_metadata(file_bytes: io.IOBase) -> tuple[Any, int]: # and follows their documentation on how their files are serialized # https://huggingface.co/docs/safetensors/index#format +<<<<<<< HEAD header_len_bytes = file_bytes.read(NUM_BYTES_FOR_HEADER_LEN) header_len = struct.unpack(">>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _get_dtype(dtype_str: str) -> torch.dtype: diff --git a/torch/distributed/checkpoint/_state_dict_stager.py b/torch/distributed/checkpoint/_state_dict_stager.py index 45fbd7686d896..b82e28c05b3a7 100644 --- a/torch/distributed/checkpoint/_state_dict_stager.py +++ b/torch/distributed/checkpoint/_state_dict_stager.py @@ -1,8 +1,16 @@ # mypy: allow-untyped-defs +<<<<<<< HEAD import types import warnings import weakref from copyreg import dispatch_table +======= +import logging +import types +import weakref +from copyreg import dispatch_table +from logging import getLogger +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from typing import Any import torch @@ -11,6 +19,13 @@ from torch.utils.weak import WeakIdKeyDictionary +<<<<<<< HEAD +======= +logger = getLogger() +logger.setLevel(logging.INFO) + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class StateDictStager: """ A class for optimizing storage objects during staging for async checkpointing. @@ -28,10 +43,16 @@ class StateDictStager: def __init__(self, pin_memory: bool = False, share_memory: bool = False): if pin_memory and not torch.cuda.is_available(): +<<<<<<< HEAD warnings.warn( "Ignoring pin_memory flag for checkpoint staging as pinning memory" "requires CUDA, but CUDA is not available. ", stacklevel=2, +======= + logger.warning( + "Ignoring pin_memory flag for checkpoint staging as pinning memory" + "requires CUDA, but CUDA is not available. " +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) self.pin_memory = False else: @@ -178,9 +199,14 @@ def _offload_tensor(self, x, memo, non_blocking=False): Returns: A CPU copy of the tensor with optimized storage """ +<<<<<<< HEAD # if data_ptr is not 0, we allocate a new storage below. so we can skip # memory allocation by using [] for size. y = x.new_empty([] if x.data_ptr() != 0 else x.size(), device="cpu") +======= + # Create a new empty tensor on CPU + y = x.new_empty([], device="cpu") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Store in memo dict early to handle recursive references d = id(x) @@ -219,6 +245,7 @@ def _offload_tensor(self, x, memo, non_blocking=False): return y +<<<<<<< HEAD def close(self): """ Clean up all cached storages and release associated resources. @@ -229,6 +256,8 @@ def close(self): """ self._cached_storage_mapping.clear() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch.no_grad() def deepcopy_with_tensor_offload(self, x, memo=None, _nil=[], non_blocking=False): # noqa: B006 """Deep copy operation on arbitrary Python objects with special handling for PyTorch tensors. diff --git a/torch/distributed/checkpoint/default_planner.py b/torch/distributed/checkpoint/default_planner.py index 3c9f5831b7e81..bbdcf73306a8b 100644 --- a/torch/distributed/checkpoint/default_planner.py +++ b/torch/distributed/checkpoint/default_planner.py @@ -408,7 +408,11 @@ def _should_include_key(self, key: str, metadata: Metadata) -> bool: return True if key in self.keys: +<<<<<<< HEAD return True +======= + True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) unflattened_keys: list[str] = [] planner_data = metadata.planner_data.get(key) @@ -654,7 +658,11 @@ def _validate_global_plan(global_plan: list[SavePlan], metadata: Metadata) -> bo # Check whether combined chunk cover the whole tensor tensor_volume = reduce(operator.mul, value.size, 1) +<<<<<<< HEAD if len(global_plan) > 1 and chunks_volume != tensor_volume: +======= + if chunks_volume != tensor_volume: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) logger.warning( """ key:%s invalid fill tensor-volume: diff --git a/torch/distributed/checkpoint/examples/async_checkpointing_example.py b/torch/distributed/checkpoint/examples/async_checkpointing_example.py index eb0562ec3dada..1371cafd75e71 100644 --- a/torch/distributed/checkpoint/examples/async_checkpointing_example.py +++ b/torch/distributed/checkpoint/examples/async_checkpointing_example.py @@ -4,7 +4,10 @@ import os import shutil import traceback +<<<<<<< HEAD from concurrent.futures import Future +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch.distributed as dist @@ -107,7 +110,10 @@ def run(rank, world_size): if epoch % SAVE_PERIOD == 0: if f is not None: +<<<<<<< HEAD assert isinstance(f, Future) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f.result() f = dcp.state_dict_saver.async_save( state_dict, checkpoint_id=CHECKPOINT_DIR @@ -124,7 +130,10 @@ def run(rank, world_size): _print("Reloading model from last checkpoint!") if f is not None: +<<<<<<< HEAD assert isinstance(f, Future) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f.result() dcp.load(state_dict) diff --git a/torch/distributed/checkpoint/filesystem.py b/torch/distributed/checkpoint/filesystem.py index cc4115cb7de0e..d23b2dac2f46f 100644 --- a/torch/distributed/checkpoint/filesystem.py +++ b/torch/distributed/checkpoint/filesystem.py @@ -17,7 +17,11 @@ from enum import Enum from io import UnsupportedOperation from pathlib import Path +<<<<<<< HEAD from typing import Any, Callable, cast, Final, IO, Optional, Union +======= +from typing import Any, Callable, cast, IO, Optional, Union +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # introduced as collections.abc.Buffer in Python 3.12 from typing_extensions import Buffer @@ -68,8 +72,11 @@ _metadata_fn: str = ".metadata" +<<<<<<< HEAD CURRENT_DCP_VERSION: Final[str] = "1.0.0" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclass class _StorageInfo: @@ -620,14 +627,18 @@ def __init__( self.overwrite = overwrite self.transforms = _StorageWriterTransforms(_extensions) self.serialization_format = serialization_format +<<<<<<< HEAD self.rank: Optional[int] = None self.use_collectives: bool = True +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None: if checkpoint_id: self.path = self.fs.init_path(checkpoint_id) self.save_id = _generate_uuid() +<<<<<<< HEAD def set_up_storage_writer( self, is_coordinator: bool, *args: Any, **kwargs: Any ) -> None: @@ -651,24 +662,41 @@ def prepare_local_plan(self, plan: SavePlan) -> SavePlan: if self.overwrite: warnings.warn( f"Detected an existing checkpoint in {self.path}, overwriting since {self.overwrite=}." +======= + def set_up_storage_writer(self, is_coordinator: bool) -> None: + pass + + def prepare_local_plan(self, plan: SavePlan) -> SavePlan: + self.fs.mkdir(self.path) + if self.fs.exists(self.metadata_path): + if self.overwrite: + warnings.warn( + f"Detected an existing checkpoint in {self.metadata_path}, overwriting since {self.overwrite=}." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) " Past version 2.5 of PyTorch, `overwrite` will default to False. Set this variable to True to" " maintain this functionality or False to raise when an existing checkpoint is found." ) else: raise RuntimeError(f"Checkpoint already exists and {self.overwrite=}.") +<<<<<<< HEAD if self.rank is not None and not self.use_collectives: plan = dataclasses.replace( plan, storage_data=_StoragePrefix(f"__{self.rank}_") ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return plan def prepare_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]: new_plans = [ dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_")) +<<<<<<< HEAD if plan.storage_data is None else plan +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for i, plan in enumerate(plans) ] return new_plans @@ -752,20 +780,28 @@ def _write_data( return fut def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None: +<<<<<<< HEAD metadata = dataclasses.replace(metadata, version=CURRENT_DCP_VERSION) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) storage_md = {} for wr_list in results: storage_md.update({wr.index: wr.storage_data for wr in wr_list}) metadata.storage_data = storage_md metadata.storage_meta = self.storage_meta() +<<<<<<< HEAD tmp_filename = ( f"__{self.rank}{_metadata_fn}.tmp" if not self.use_collectives and self.rank is not None else f"{_metadata_fn}.tmp" ) tmp_path = cast(Path, self.fs.concat_path(self.path, tmp_filename)) +======= + + tmp_path = cast(Path, self.fs.concat_path(self.path, f"{_metadata_fn}.tmp")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with self.fs.create_stream(tmp_path, "wb") as metadata_file: pickle.dump(metadata, metadata_file) if self.sync_files: @@ -775,6 +811,7 @@ def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None: os.sync() # delete in-case other checkpoints were present. +<<<<<<< HEAD if not self.use_collectives and self.rank is not None: metadata_path = self._get_metadata_path(self.rank) else: @@ -784,13 +821,25 @@ def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None: self.fs.rm_file(metadata_path) self.fs.rename(tmp_path, metadata_path) +======= + if self.fs.exists(self.metadata_path): + self.fs.rm_file(self.metadata_path) + + self.fs.rename(tmp_path, self.metadata_path) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def storage_meta(self) -> Optional[StorageMeta]: return StorageMeta(checkpoint_id=self.checkpoint_id, save_id=self.save_id) +<<<<<<< HEAD def _get_metadata_path(self, rank: Optional[int] = None) -> os.PathLike: filename = f"{_metadata_fn}" if rank is None else f"__{rank}{_metadata_fn}" return cast(Path, self.fs.concat_path(self.path, filename)) +======= + @property + def metadata_path(self) -> Union[str, os.PathLike]: + return cast(Path, self.fs.concat_path(self.path, _metadata_fn)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @property def checkpoint_id(self) -> Union[str, os.PathLike]: @@ -842,8 +891,11 @@ def __init__( self.storage_data: dict[Any, Any] = {} self.load_id = _generate_uuid() self.transforms = _StorageReaderTransforms(_extension_registry) +<<<<<<< HEAD self.rank = None self.use_collectives = True +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _slice_file(self, file, sinfo: _StorageInfo) -> IO[bytes]: return cast(IO[bytes], _create_file_view(file, sinfo.offset, sinfo.length)) @@ -913,6 +965,7 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: fut.set_result(None) return fut +<<<<<<< HEAD def _get_metadata_path(self, rank: Optional[int] = None) -> os.PathLike: filename = f"{_metadata_fn}" if rank is None else f"__{rank}{_metadata_fn}" return cast(Path, self.fs.concat_path(self.path, filename)) @@ -921,6 +974,11 @@ def _get_metadata_path(self, rank: Optional[int] = None) -> os.PathLike: def read_metadata(self, *args: Any, **kwargs: Any) -> Metadata: rank = kwargs.get("rank", None) path = self._get_metadata_path(rank) +======= + # Implementing the abstract function in StorageReader + def read_metadata(self) -> Metadata: + path = self.fs.concat_path(self.path, ".metadata") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with self.fs.create_stream(path, "rb") as metadata_file: metadata = pickle.load(metadata_file) @@ -930,12 +988,17 @@ def read_metadata(self, *args: Any, **kwargs: Any) -> Metadata: return metadata +<<<<<<< HEAD def set_up_storage_reader( self, metadata: Metadata, is_coordinator: bool, *args: Any, **kwargs: Any ) -> None: self.storage_data = metadata.storage_data self.rank = kwargs.get("rank", None) self.use_collectives = kwargs.get("use_collectives", True) +======= + def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None: + self.storage_data = metadata.storage_data +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert self.storage_data is not None def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan: @@ -966,8 +1029,12 @@ class FileSystemWriter(_FileSystemWriter, BlockingAsyncStager): * File creation is atomic The checkpoint consist of one file per write request plus +<<<<<<< HEAD a global `.metadata` file with the serialized metadata if rank coordination is enabled. a rank local `__{rank}.metadata` file with the serialized metadata if rank coordination is NOT enabled. +======= + a `.metadata` file with the serialized metadata. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ diff --git a/torch/distributed/checkpoint/hf_storage.py b/torch/distributed/checkpoint/hf_storage.py index 17db989727d4a..af0ec996056e0 100644 --- a/torch/distributed/checkpoint/hf_storage.py +++ b/torch/distributed/checkpoint/hf_storage.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs import dataclasses import json +<<<<<<< HEAD import logging import queue import threading @@ -18,6 +19,27 @@ CUSTOM_METADATA_KEY, SAVED_OFFSETS_KEY, SHARDED_DIR_NAME, +======= +import queue +from typing import Any, Optional + +import torch +from torch.distributed._shard._utils import narrow_tensor_by_index +from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter +from torch.distributed.checkpoint._hf_utils import ( + _gen_file_name, + _get_dtype, + _get_safetensors_file_metadata, + _HFStorageInfo, + _metadata_fn, + CUSTOM_METADATA_KEY, + DATA_KEY, + DATA_OFFSETS_KEY, + DEFAULT_EXTRA_METADATA_KEY, + DTYPE_KEY, + SAVED_OFFSETS_KEY, + SHAPE_KEY, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SUFFIX, ) from torch.distributed.checkpoint.filesystem import SerializationFormat @@ -41,6 +63,7 @@ from torch.futures import Future +<<<<<<< HEAD logger: logging.Logger = logging.getLogger(__name__) __all__ = ["HuggingFaceStorageWriter", "HuggingFaceStorageReader"] @@ -49,26 +72,48 @@ class HuggingFaceStorageWriter(FileSystemWriter): """ A writer that writes to storage in the huggingface safetensors format. +======= +__all__ = ["HuggingFaceStorageWriter", "HuggingFaceStorageReader"] + + +class HuggingFaceStorageWriter(FsspecWriter): + """ + A writer that writes to a huggingface repository in the huggingface format. + Uses Fsspec back-end to communicate with back-end storage. + Fsspec registration of the storage solution is required. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ def __init__( self, path: str, fqn_to_index_mapping: Optional[dict[str, int]] = None, +<<<<<<< HEAD thread_count: int = 1, save_distributed: bool = False, enable_consolidation: bool = False, thread_count_consolidation: int = 1, +======= + token: Optional[str] = None, + save_sharded: bool = False, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: """ Initialize the huggingface writer pointing to path. Args: +<<<<<<< HEAD path: directory where the checkpoint will be read from. +======= + path: hf directory where the checkpoint will be read from. + Needs to have .safetensors files, but can be from any fsspec supported storage, + including localFS and hf://. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fqn_to_index_mapping: A mapping from tensor FQN to the index of the file that the tensor should be written to. Indices are from 1 to N, where N is the number of files. If not provided, the tensors will be written to a single file. If none, then all the tensors on the same rank will be written to the same file. +<<<<<<< HEAD thread_count: Number of threads to use to write distributed checkpoint. Default to 1. save_distributed: If True, save the checkpoint using distributed APIs where every rank saves its own shard. Default is False which assumes rank-0 checkpointing of the full state_dict. @@ -91,14 +136,41 @@ def __init__( self.consolidated_output_path = str(self.path) self.path = self.fs.concat_path(self.path, SHARDED_DIR_NAME) self.thread_count_consolidation = thread_count_consolidation +======= + token: The token to use to authenticate with huggingface hub. + save_sharded: If True, save the checkpoint as a sharded checkpoint where every rank saves its own shard. + Default is False which assumes full tensors are being saved. + + """ + + if token is not None: + super().__init__( + path=path, + token=token, + serialization_format=SerializationFormat.SAFETENSORS, + ) + else: + super().__init__( + path=path, + serialization_format=SerializationFormat.SAFETENSORS, + ) + self._fqn_to_index_mapping: Optional[dict[str, int]] = fqn_to_index_mapping + self._save_sharded = save_sharded +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def prepare_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]: new_plans = [] for i, plan in enumerate(plans, start=1): storage_data: dict[str, Any] = {} +<<<<<<< HEAD if self.fqn_to_index_mapping is not None: storage_data["fqn_to_index_mapping"] = self.fqn_to_index_mapping if self.save_distributed: +======= + if self._fqn_to_index_mapping is not None: + storage_data["fqn_to_index_mapping"] = self._fqn_to_index_mapping + if self._save_sharded: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) storage_data["shard_index"] = i new_plans.append(dataclasses.replace(plan, storage_data=storage_data)) @@ -137,6 +209,7 @@ def write_data( return super()._write_data(planner, file_queue) def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None: +<<<<<<< HEAD if self.save_distributed and not self.enable_consolidation: # if we are saving distributed, without consolidating, # then we have no metadata to write because a metadata @@ -160,6 +233,11 @@ def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None: # writing a model.index.safetensors.json file with fqn to file mapping # for the rank-0 checkpointing case +======= + if self._save_sharded: + return + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) metadata_to_write = {} storage_md = {} total_size = 0 @@ -199,16 +277,28 @@ def metadata_path(self) -> str: return _metadata_fn +<<<<<<< HEAD class HuggingFaceStorageReader(FileSystemReader): """ A reader that reads a checkpoint in the huggingface safetensors format. """ def __init__(self, path: str, thread_count: int = 1) -> None: +======= +class HuggingFaceStorageReader(FsspecReader): + """ + A reader that reads from a huggingface repository in the huggingface format. + Uses in Fsspec back-end to communicate with storage. + Fsspec registration of the storage solution is required. + """ + + def __init__(self, path: str, token: Optional[str] = None) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Initialize the huggingface reader pointing to path. Args: +<<<<<<< HEAD path: directory where the checkpoint will be read from. thread_count: Number of threads to use to read distributed checkpoint. Default to 1. """ @@ -253,6 +343,21 @@ def _read_files_from_queue( def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: from safetensors import safe_open # type: ignore[import] +======= + path: hf directory where the checkpoint will be read from. + Needs to have .safetensors file, but can be from any fsspec supported storage, + including localFS and hf://. + token: The token to use to authenticate with huggingface hub. + """ + + if token is not None: + super().__init__(path=path, token=token) + else: + super().__init__(path=path) + + def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: + from safetensors import deserialize # type: ignore[import-not-found] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) per_file: dict[str, list[ReadItem]] = {} @@ -261,6 +366,7 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: file_name = item_md.relative_path per_file.setdefault(file_name, []).append(read_item) +<<<<<<< HEAD if self.thread_count <= 1 or len(per_file) <= 1: for file_name, reqs in per_file.items(): with safe_open(filename=file_name, framework="pt") as f: @@ -302,15 +408,50 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: assert processed_count == len(per_file), ( f"Not all files were processed: {processed_count} out of {len(per_file)}" ) +======= + for file_name, reqs in per_file.items(): + with self.fs.create_stream(file_name, "rb") as stream: + # TODO: make this more efficient by doing offset reads instead of a + # full deserialization of the file + deserialized = deserialize(stream.read()) + deserialized_dict: dict[str, dict[str, Any]] = { + tensor_info[0]: tensor_info[1] for tensor_info in deserialized + } + + for req in reqs: + item_md = self.storage_data[req.storage_index] + + tensor_bytes = deserialized_dict[req.dest_index.fqn][DATA_KEY] + + tensor = torch.frombuffer( + tensor_bytes, + dtype=item_md.dtype, + ) + tensor = tensor.reshape(item_md.shape) + tensor = narrow_tensor_by_index( + tensor, req.storage_offsets, req.lengths + ) + target_tensor = planner.resolve_tensor(req).detach() + + assert target_tensor.size() == tensor.size(), ( + f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}" + ) + + target_tensor.copy_(tensor) + planner.commit_tensor(req, target_tensor) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fut: Future = Future() fut.set_result(None) return fut def read_metadata(self) -> Metadata: +<<<<<<< HEAD from safetensors import safe_open # type: ignore[import] from safetensors.torch import _getdtype # type: ignore[import] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) state_dict_metadata: dict[str, TensorStorageMetadata] = {} storage_data: dict[MetadataIndex, _HFStorageInfo] = {} @@ -320,6 +461,7 @@ def read_metadata(self) -> Metadata: safetensors_files.append(file) for safetensor_file in safetensors_files: +<<<<<<< HEAD with safe_open(safetensor_file, framework="pt") as f: keys = f.keys() extra_metadata = f.metadata() @@ -333,10 +475,27 @@ def read_metadata(self) -> Metadata: for key in keys: shape = f.get_slice(key).get_shape() dtype = f.get_slice(key).get_dtype() +======= + with self.fs.create_stream(safetensor_file, "rb") as f: + safetensors_metadata, _ = _get_safetensors_file_metadata(f) + custom_metadata = safetensors_metadata.get(DEFAULT_EXTRA_METADATA_KEY) + + dcp_sharding_info = None + if custom_metadata and custom_metadata.get(CUSTOM_METADATA_KEY): + dcp_sharding_info = json.loads( + custom_metadata.get(CUSTOM_METADATA_KEY) + ) + + for key, val in safetensors_metadata.items(): + if key == DEFAULT_EXTRA_METADATA_KEY: + continue + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # construct state_dict_metadata if dcp_sharding_info is not None: offset = dcp_sharding_info[key][SAVED_OFFSETS_KEY] else: +<<<<<<< HEAD offset = [0] * len(shape) if key not in state_dict_metadata: @@ -344,23 +503,49 @@ def read_metadata(self) -> Metadata: properties=TensorProperties(dtype=_getdtype(dtype)), size=torch.Size( [saved + offset for saved, offset in zip(shape, offset)] +======= + offset = [0] * len(val[SHAPE_KEY]) + + if key not in state_dict_metadata: + state_dict_metadata[key] = TensorStorageMetadata( + properties=TensorProperties( + dtype=_get_dtype(val[DTYPE_KEY]) + ), + size=torch.Size( + [ + saved + offset + for saved, offset in zip(val[SHAPE_KEY], offset) + ] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), chunks=[ ChunkStorageMetadata( offsets=torch.Size(offset), +<<<<<<< HEAD sizes=torch.Size(shape), +======= + sizes=torch.Size(val[SHAPE_KEY]), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) ], ) else: state_dict_metadata[key].chunks.append( ChunkStorageMetadata( +<<<<<<< HEAD torch.Size(offset), sizes=torch.Size(shape) +======= + torch.Size(offset), sizes=torch.Size(val[SHAPE_KEY]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) ) size = list(state_dict_metadata[key].size) for i in range(len(size)): +<<<<<<< HEAD size[i] = max(size[i], shape[i] + offset[i]) +======= + size[i] = max(size[i], val[SHAPE_KEY][i] + offset[i]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) state_dict_metadata[key].size = torch.Size(size) # construct storage data @@ -369,11 +554,23 @@ def read_metadata(self) -> Metadata: fqn=key, offset=dcp_sharding_info[key][SAVED_OFFSETS_KEY] ) else: +<<<<<<< HEAD metadata_index = MetadataIndex(fqn=key, offset=[0] * len(shape)) storage_data[metadata_index] = _HFStorageInfo( relative_path=safetensor_file, shape=torch.Size(shape), dtype=_getdtype(dtype), +======= + metadata_index = MetadataIndex( + fqn=key, offset=[0] * len(val[SHAPE_KEY]) + ) + storage_data[metadata_index] = _HFStorageInfo( + relative_path=safetensor_file, + offset=val[DATA_OFFSETS_KEY][0], + length=val[DATA_OFFSETS_KEY][1] - val[DATA_OFFSETS_KEY][0], + shape=torch.Size(val[SHAPE_KEY]), + dtype=_get_dtype(val[DTYPE_KEY]), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) metadata = Metadata( diff --git a/torch/distributed/checkpoint/metadata.py b/torch/distributed/checkpoint/metadata.py index 36864b6bf3ad6..528c6d87b176f 100644 --- a/torch/distributed/checkpoint/metadata.py +++ b/torch/distributed/checkpoint/metadata.py @@ -147,7 +147,10 @@ class Metadata: planner_data: Any = None storage_data: Any = None storage_meta: Optional[StorageMeta] = None +<<<<<<< HEAD version: Optional[str] = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclass(frozen=True) diff --git a/torch/distributed/checkpoint/staging.py b/torch/distributed/checkpoint/staging.py index e7acf4975173c..020063d9a3c90 100644 --- a/torch/distributed/checkpoint/staging.py +++ b/torch/distributed/checkpoint/staging.py @@ -1,3 +1,4 @@ +<<<<<<< HEAD import os import tempfile from concurrent.futures import Future, ThreadPoolExecutor @@ -34,6 +35,16 @@ BlockingAsyncStager: Implementation of AsyncStager which stages the state_dict on CPU RAM and blocks until the copy is complete. Please use DefaultStager instead. """ +======= +from typing import Optional +from typing_extensions import Protocol, runtime_checkable + +from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict +from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE + + +__all__ = ["AsyncStager", "BlockingAsyncStager"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @runtime_checkable @@ -72,11 +83,18 @@ def should_synchronize_after_execute(self) -> bool: """ Whether to synchronize after executing the stage. """ +<<<<<<< HEAD return self._synchronize_after_execute def stage( self, state_dict: STATE_DICT_TYPE ) -> Union[Future[STATE_DICT_TYPE], STATE_DICT_TYPE]: +======= + + return self._synchronize_after_execute + + def stage(self, state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Returns a "staged" copy of `state_dict`. The expectation of the staged copy is that it is inoculated from any updates incurred after the stage call is complete. @@ -85,17 +103,21 @@ def stage( f"{self.__class__.__name__} must implement stage method" ) +<<<<<<< HEAD @deprecated( "`synchronize_staging` is deprecated and will be removed in future versions." "Please use staging_future from AsyncSaveResponse instead.", category=FutureWarning, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def synchronize_staging(self) -> None: """ In the case `stage` is async in some way, this method should be called to ensure staging is complete and it is safe to begin modifying the original `state_dict` """ +<<<<<<< HEAD def close(self) -> None: """ Clean up all resources used by the stager. @@ -267,6 +289,8 @@ def synchronize_staging(self) -> None: if self._staging_future is not None: self._staging_future.result() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class BlockingAsyncStager(AsyncStager): """ @@ -318,6 +342,7 @@ def synchronize_staging(self) -> None: """ No-op function, since staging is blocking. """ +<<<<<<< HEAD def close(self) -> None: pass @@ -464,3 +489,5 @@ def close(self) -> None: """ Clean up resources. Persisted files are intentionally left for future discovery. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/distributed/checkpoint/state_dict_loader.py b/torch/distributed/checkpoint/state_dict_loader.py index ae3c4df775abd..ea54c9efad6e9 100644 --- a/torch/distributed/checkpoint/state_dict_loader.py +++ b/torch/distributed/checkpoint/state_dict_loader.py @@ -1,10 +1,16 @@ # mypy: allow-untyped-decorators # mypy: allow-untyped-defs +<<<<<<< HEAD import inspect import logging import os import warnings from typing import Any, cast, Optional, TYPE_CHECKING, Union +======= +import os +import warnings +from typing import Any, cast, Optional, Union +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from typing_extensions import deprecated import torch @@ -20,6 +26,7 @@ from .utils import _api_bc_check, _DistWrapper, _profile +<<<<<<< HEAD if TYPE_CHECKING: from torch.distributed.checkpoint.metadata import Metadata @@ -27,6 +34,10 @@ logger = logging.getLogger() +======= +__all__ = ["load_state_dict", "load"] + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @deprecated( "`load_state_dict` is deprecated and will be removed in future versions. " @@ -220,6 +231,7 @@ def _load_state_dict( ckpt_kwargs["checkpoint_id"] = ckpt_id ckpt_kwargs["process_group"] = distW.group +<<<<<<< HEAD use_collectives = True metadata: Optional[Metadata] = None @@ -262,6 +274,14 @@ def local_step(): ) else: storage_reader.set_up_storage_reader(metadata, distW.is_coordinator) +======= + @_dcp_method_logger(**ckpt_kwargs) + def local_step(): + assert planner is not None + metadata = storage_reader.read_metadata() + planner.set_up_planner(state_dict, metadata, distW.is_coordinator) + storage_reader.set_up_storage_reader(metadata, distW.is_coordinator) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) local_plan = planner.create_local_plan() local_plan = storage_reader.prepare_local_plan(local_plan) @@ -274,6 +294,7 @@ def global_step(all_local_plans): all_local_plans = storage_reader.prepare_global_plan(all_local_plans) return all_local_plans +<<<<<<< HEAD central_plan: Optional[LoadPlan] = None if use_collectives: central_plan = distW.reduce_scatter("plan", local_step, global_step) @@ -281,22 +302,32 @@ def global_step(all_local_plans): local_plan: LoadPlan = local_step() global_plan: list[LoadPlan] = global_step([local_plan]) central_plan = global_plan[0] +======= + central_plan: LoadPlan = distW.reduce_scatter("plan", local_step, global_step) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @_dcp_method_logger(**ckpt_kwargs) def read_data(): assert planner is not None +<<<<<<< HEAD assert central_plan is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) final_local_plan = planner.finish_plan(central_plan) all_reads = storage_reader.read_data(final_local_plan, planner) all_reads.wait() return None +<<<<<<< HEAD if use_collectives: _ = distW.all_gather("read", read_data) else: read_data() distW.barrier() +======= + _ = distW.all_gather("read", read_data) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _load_state_dict_from_keys( diff --git a/torch/distributed/checkpoint/state_dict_saver.py b/torch/distributed/checkpoint/state_dict_saver.py index 9971f19db8174..a2673d3537b48 100644 --- a/torch/distributed/checkpoint/state_dict_saver.py +++ b/torch/distributed/checkpoint/state_dict_saver.py @@ -4,14 +4,21 @@ import os import warnings from concurrent.futures import Future +<<<<<<< HEAD from dataclasses import dataclass +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from enum import Enum from typing import cast, Optional, Union from typing_extensions import deprecated import torch import torch.distributed as dist +<<<<<<< HEAD from torch.distributed._state_dict_utils import STATE_DICT_TYPE +======= +from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.distributed.checkpoint._async_executor import ( # noqa: TC001 _AsyncCheckpointExecutor, ) @@ -24,6 +31,7 @@ from torch.distributed.checkpoint._storage_utils import _storage_setup from torch.distributed.checkpoint.default_planner import DefaultSavePlanner from torch.distributed.checkpoint.logger import _dcp_method_logger +<<<<<<< HEAD from torch.distributed.checkpoint.metadata import Metadata from torch.distributed.checkpoint.planner import SavePlan, SavePlanner from torch.distributed.checkpoint.staging import ( @@ -33,11 +41,19 @@ ) from torch.distributed.checkpoint.stateful import Stateful from torch.distributed.checkpoint.storage import StorageWriter, WriteResult +======= +from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE +from torch.distributed.checkpoint.planner import SavePlan, SavePlanner +from torch.distributed.checkpoint.staging import AsyncStager +from torch.distributed.checkpoint.stateful import Stateful +from torch.distributed.checkpoint.storage import StorageWriter +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.distributed.distributed_c10d import _get_default_group from .utils import _api_bc_check, _DistWrapper, _profile +<<<<<<< HEAD __all__ = [ "save_state_dict", "save", @@ -45,6 +61,9 @@ "AsyncCheckpointerType", "AsyncSaveResponse", ] +======= +__all__ = ["save_state_dict", "save", "async_save", "AsyncCheckpointerType"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class AsyncCheckpointerType(Enum): @@ -92,7 +111,10 @@ def save( planner: Optional[SavePlanner] = None, process_group: Optional[dist.ProcessGroup] = None, no_dist: bool = False, +<<<<<<< HEAD use_collectives: bool = True, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Metadata: """ Save a distributed model in SPMD style. @@ -144,6 +166,7 @@ def save( (Default: ``None``) no_dist (bool): If ``True``, this function will assume the intent is to load +<<<<<<< HEAD a checkpoint on a single rank/process. (Default: ``False``) use_collectives (bool): If ``False``, this function will assume the intent is to save @@ -151,6 +174,10 @@ def save( (Default: ``True``) This configuration is experimental and should be used with caution. It will change the format of the saved checkpoint and may not be backward compatible. +======= + a checkpoint without using cross-rank synchronization. + (Default: ``False``) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Returns: Metadata: Metadata object for the saved checkpoint. @@ -196,6 +223,7 @@ def save( process_group=process_group, no_dist=no_dist, planner=planner, +<<<<<<< HEAD use_collectives=use_collectives, ) @@ -214,6 +242,11 @@ class AsyncSaveResponse: upload_completion: Future[None] +======= + ) + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @_dcp_method_logger(log_exceptions=True) def async_save( state_dict: STATE_DICT_TYPE, @@ -223,16 +256,23 @@ def async_save( planner: Optional[SavePlanner] = None, process_group: Optional[dist.ProcessGroup] = None, async_checkpointer_type: AsyncCheckpointerType = AsyncCheckpointerType.THREAD, +<<<<<<< HEAD async_stager: Optional[AsyncStager] = None, no_dist: bool = False, use_collectives: bool = True, ) -> Union[Future, AsyncSaveResponse]: +======= +) -> Future: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Asynchronous version of ``save``. This code first de-stages the state_dict on to the staging storage (defaults to CPU memory), and then calls the `save` in a separate thread. .. warning:: This feature is experimental and subject to change. +<<<<<<< HEAD MUST CALL CLOSE AFTER LAST CHECKPOINT IS SAVED +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Args: state_dict (Dict[str, Any]): The state_dict to save. @@ -252,6 +292,7 @@ def async_save( process_group (Optional[ProcessGroup]): ProcessGroup to be used for cross-rank synchronization. (Default: ``None``) +<<<<<<< HEAD async_checkpointer_type (AsyncCheckpointerType): whether to do checkpoint in separate thread or process (Default: ``AsyncCheckpointerType.THREAD``) @@ -265,6 +306,8 @@ def async_save( use_collectives: If False, Save the checkpoint without rank coordination. (Default: ``True``) This configuration is experimental and should be used with caution. It will change the format of the saved checkpoint and may not be backward compatible. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Returns: Future: A future holding the resultant Metadata object from `save`. @@ -298,6 +341,7 @@ def async_save( "A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:nccl'" ) +<<<<<<< HEAD if async_stager is None: if storage_writer is not None and isinstance(storage_writer, AsyncStager): # bwc with old storage_writers @@ -312,6 +356,8 @@ def async_save( ) ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) storage_writer = cast( StorageWriter, _storage_setup(storage_writer, checkpoint_id, reader=False) ) @@ -319,23 +365,44 @@ def async_save( state_dict = _stateful_to_state_dict(state_dict) @_dcp_method_logger(log_exceptions=True) +<<<<<<< HEAD def stage_state_dict() -> Union[Future[STATE_DICT_TYPE], STATE_DICT_TYPE]: return async_stager.stage(state_dict) staging_future_or_state_dict = stage_state_dict() upload_executor: _AsyncCheckpointExecutor = ( +======= + def stage_state_dict(): + if isinstance(storage_writer, AsyncStager): + staged_state_dict = storage_writer.stage(state_dict) + else: # provides bwc for storage_writers not implementing AsyncStager + staged_state_dict = _create_cpu_state_dict(state_dict) + _copy_state_dict(state_dict, staged_state_dict, type_check=False) + + return staged_state_dict + + staged_state_dict = stage_state_dict() + + executor: _AsyncCheckpointExecutor = ( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _ProcessBasedAsyncCheckpointExecutor() if async_checkpointer_type == AsyncCheckpointerType.PROCESS else _ThreadBasedAsyncCheckpointExecutor() ) +<<<<<<< HEAD upload_future: Future = upload_executor.execute_save( staging_future_or_state_dict, +======= + f: Future = executor.execute_save( + staged_state_dict, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) checkpoint_id=checkpoint_id, storage_writer=storage_writer, planner=planner, process_group=process_group, +<<<<<<< HEAD no_dist=no_dist, use_collectives=use_collectives, ) @@ -372,6 +439,21 @@ def maybe_synchronize_staging(): maybe_synchronize_staging() return upload_future +======= + ) + + @_dcp_method_logger(log_exceptions=True) + def maybe_synchronize_staging(): + if ( + isinstance(storage_writer, AsyncStager) + and storage_writer.should_synchronize_after_execute + ): + storage_writer.synchronize_staging() + + maybe_synchronize_staging() + + return f +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @_dcp_method_logger(log_exceptions=True) @@ -392,7 +474,10 @@ def _save_state_dict( coordinator_rank: int = 0, no_dist: bool = False, planner: Optional[SavePlanner] = None, +<<<<<<< HEAD use_collectives: bool = True, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Metadata: torch._C._log_api_usage_once("torch.distributed.checkpoint.save_state_dict") @@ -425,6 +510,7 @@ def local_step(): storage_meta=storage_meta, is_coordinator=distW.is_coordinator, ) +<<<<<<< HEAD if ( "kwargs" @@ -437,6 +523,9 @@ def local_step(): ) else: storage_writer.set_up_storage_writer(distW.is_coordinator) +======= + storage_writer.set_up_storage_writer(distW.is_coordinator) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) local_plan = planner.create_local_plan() local_plan = storage_writer.prepare_local_plan(local_plan) @@ -451,6 +540,7 @@ def global_step(all_local_plans): all_local_plans = storage_writer.prepare_global_plan(all_local_plans) return all_local_plans +<<<<<<< HEAD central_plan: Optional[SavePlan] = None if use_collectives: central_plan = distW.reduce_scatter("plan", local_step, global_step) @@ -458,11 +548,17 @@ def global_step(all_local_plans): local_plan: SavePlan = local_step() global_plan: list[SavePlan] = global_step([local_plan]) central_plan = global_plan[0] +======= + central_plan: SavePlan = distW.reduce_scatter("plan", local_step, global_step) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @_dcp_method_logger(**ckpt_kwargs) def write_data(): assert planner is not None +<<<<<<< HEAD assert central_plan is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) final_local_plan = planner.finish_plan(central_plan) all_writes = storage_writer.write_data(final_local_plan, planner) @@ -475,6 +571,7 @@ def finish_checkpoint(all_results): storage_writer.finish(metadata=global_metadata, results=all_results) return global_metadata +<<<<<<< HEAD if use_collectives: metadata = distW.all_reduce("write", write_data, finish_checkpoint) else: @@ -483,3 +580,6 @@ def finish_checkpoint(all_results): distW.barrier() return metadata +======= + return distW.all_reduce("write", write_data, finish_checkpoint) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/distributed/checkpoint/storage.py b/torch/distributed/checkpoint/storage.py index b184d7b170052..656976a2e47b1 100644 --- a/torch/distributed/checkpoint/storage.py +++ b/torch/distributed/checkpoint/storage.py @@ -61,9 +61,13 @@ def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None: ... @abc.abstractmethod +<<<<<<< HEAD def set_up_storage_writer( self, is_coordinator: bool, *args: Any, **kwargs: Any ) -> None: +======= + def set_up_storage_writer(self, is_coordinator: bool) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Initialize this instance. @@ -202,7 +206,11 @@ def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None: ... @abc.abstractmethod +<<<<<<< HEAD def read_metadata(self, *args: Any, **kwargs: Any) -> Metadata: +======= + def read_metadata(self) -> Metadata: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Read the checkpoint metadata. @@ -212,9 +220,13 @@ def read_metadata(self, *args: Any, **kwargs: Any) -> Metadata: """ @abc.abstractmethod +<<<<<<< HEAD def set_up_storage_reader( self, metadata: Metadata, is_coordinator: bool, *args: Any, **kwargs: Any ) -> None: +======= + def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Initialize this instance. diff --git a/torch/distributed/checkpoint/utils.py b/torch/distributed/checkpoint/utils.py index 6d00026d99349..02192a2a17992 100644 --- a/torch/distributed/checkpoint/utils.py +++ b/torch/distributed/checkpoint/utils.py @@ -190,7 +190,11 @@ def reduce_scatter( local_data: Union[WRAPPED_EXCEPTION, T] try: local_data = map_fun() +<<<<<<< HEAD except BaseException as e: # noqa: B036 +======= + except BaseException as e: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) local_data = _wrap_exception(e) all_data = self.gather_object(local_data) @@ -206,7 +210,11 @@ def reduce_scatter( list[Union[R, CheckpointException]], reduce_fun(cast(list[T], all_data)), ) +<<<<<<< HEAD except BaseException as e: # noqa: B036 +======= + except BaseException as e: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) node_failures[self.rank] = _wrap_exception(e) if len(node_failures) > 0: @@ -237,7 +245,11 @@ def all_reduce( local_data: Union[T, WRAPPED_EXCEPTION] try: local_data = map_fun() +<<<<<<< HEAD except BaseException as e: # noqa: B036 +======= + except BaseException as e: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) local_data = _wrap_exception(e) all_data = self.gather_object(local_data) @@ -248,7 +260,11 @@ def all_reduce( if len(node_failures) == 0: try: result = reduce_fun(cast(list[T], all_data)) +<<<<<<< HEAD except BaseException as e: # noqa: B036 +======= + except BaseException as e: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) node_failures[self.rank] = _wrap_exception(e) if len(node_failures) > 0: @@ -274,7 +290,11 @@ def all_gather( result: Union[T, WRAPPED_EXCEPTION] try: result = map_fun() +<<<<<<< HEAD except BaseException as e: # noqa: B036 +======= + except BaseException as e: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) result = _wrap_exception(e) all_results = self.all_gather_object(result) @@ -300,7 +320,11 @@ def broadcast( if self.is_coordinator: try: result = map_fun() +<<<<<<< HEAD except BaseException as e: # noqa: B036 +======= + except BaseException as e: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) result = CheckpointException(step, {self.rank: _wrap_exception(e)}) final_result = self.broadcast_object(result) if isinstance(final_result, CheckpointException): diff --git a/torch/distributed/collective_utils.py b/torch/distributed/collective_utils.py index 715cd251ea4d7..b14d6d3090e82 100644 --- a/torch/distributed/collective_utils.py +++ b/torch/distributed/collective_utils.py @@ -9,6 +9,7 @@ from __future__ import annotations +<<<<<<< HEAD import importlib import logging from collections import defaultdict @@ -32,6 +33,14 @@ logger = logging.getLogger(__name__) +======= +from dataclasses import dataclass +from typing import Any, Callable, cast, Generic, Optional, TypeVar, Union + +import torch.distributed as dist + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T = TypeVar("T") @@ -232,6 +241,7 @@ def all_gather_object_enforce_type( f"Object type at index {i} is {type(object_list[i])}, " f"while first object type is {type(first_obj)}" ) +<<<<<<< HEAD def _summarize_ranks(ranks: Iterable[int]) -> str: @@ -341,3 +351,5 @@ def _check_rng_sync( log_str = f"Generator desync detected:\n{_desync_table_str(value_header, value_ranks)}" logger.error(log_str) return log_str +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index 13bb084299c69..9af9e818de901 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -5,9 +5,14 @@ import os import threading import warnings +<<<<<<< HEAD from collections.abc import Iterator from functools import reduce from itertools import chain, zip_longest +======= +from functools import reduce +from itertools import chain +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from typing import Optional, TYPE_CHECKING, Union import torch @@ -70,7 +75,11 @@ def __init__(self) -> None: self.mesh_stack: list[DeviceMesh] = [] self.child_to_root_mapping: dict[DeviceMesh, DeviceMesh] = {} self.mesh_dim_group_options: dict[ +<<<<<<< HEAD int, tuple[Optional[str], Optional[C10dBackend.Options]] +======= + int, tuple[str, Optional[C10dBackend.Options]] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] = {} self.root_to_flatten_mapping: dict[DeviceMesh, dict[str, DeviceMesh]] = {} # Record flatten mesh name to its mesh dim index in root mesh. @@ -126,21 +135,32 @@ def create_sub_mesh( slice_dim_group_name.append( self.root_to_flatten_mapping[device_mesh][ mesh_dim_name +<<<<<<< HEAD ]._dim_group_names[0] # type: ignore[has-type] +======= + ]._dim_group_names[0] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) else: slice_dim_idx.append(mesh_dim_indices[0] - num_dims_flatten) slice_dim_group_name.append( +<<<<<<< HEAD device_mesh._dim_group_names[mesh_dim_indices[0]] # type: ignore[has-type] +======= + device_mesh._dim_group_names[mesh_dim_indices[0]] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # mesh_tensor has already been flattened if needed. So mesh_tensor.ndim <= device_mesh.mesh.ndim now. mesh_dims_remained_idx = list(range(mesh_tensor.ndim)) for idx in slice_dim_idx: +<<<<<<< HEAD if idx not in mesh_dims_remained_idx: raise NotImplementedError( "Currently, this only allows slicing out a contiguous flattened dim." ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mesh_dims_remained_idx.remove(idx) # pg_ranks_by_dim is the size of [number of local ranks of the outermost submesh dimension, *slice_dim_idx] @@ -161,12 +181,17 @@ def create_sub_mesh( if cur_rank in mesh_nd: res_submesh = submesh +<<<<<<< HEAD res_submesh._dim_group_names = slice_dim_group_name # type: ignore[possibly-undefined, has-type] +======= + res_submesh._dim_group_names = slice_dim_group_name # type: ignore[possibly-undefined] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.child_to_root_mapping[res_submesh] = device_mesh return res_submesh def create_flatten_mesh( +<<<<<<< HEAD self, device_mesh: "DeviceMesh", mesh_dim_name: Optional[str] = None, @@ -174,21 +199,42 @@ def create_flatten_mesh( None, None, ), +======= + self, device_mesh: "DeviceMesh", mesh_dim_name: Optional[str] = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> "DeviceMesh": root_mesh = _mesh_resources.get_root_mesh(device_mesh) flatten_dims_in_root = [ +<<<<<<< HEAD not_none(root_mesh.mesh_dim_names).index(flatten_mesh_dim_name) for flatten_mesh_dim_name in not_none(device_mesh.mesh_dim_names) ] if not mesh_dim_name: mesh_dim_name = "_".join(not_none(device_mesh.mesh_dim_names)) +======= + not_none(root_mesh.mesh_dim_names).index(flattened_mesh_dim_name) + for flattened_mesh_dim_name in not_none(device_mesh.mesh_dim_names) + ] + + if not mesh_dim_name: + mesh_dim_name = "_".join( + [ + not_none(root_mesh.mesh_dim_names)[dim] + for dim in flatten_dims_in_root + ] + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Check whether the mesh_dim_name for flattened mesh is valid. self.flatten_name_to_root_dims.setdefault(root_mesh, {}) invalid_dim_names = chain( +<<<<<<< HEAD list(not_none(root_mesh.mesh_dim_names)), +======= + *list(not_none(root_mesh.mesh_dim_names)), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) *self.flatten_name_to_root_dims[root_mesh].keys(), ) if mesh_dim_name in invalid_dim_names: @@ -224,7 +270,10 @@ def create_flatten_mesh( root_mesh.device_type, mesh_nd, mesh_dim_names=(mesh_dim_name,), +<<<<<<< HEAD backend_override=(backend_override,), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if cur_rank in mesh_nd: res_flattened_mesh = flattened_mesh @@ -291,7 +340,11 @@ def get_mesh_dim_by_name( def _set_mesh_dim_group_options( self, dim: int, +<<<<<<< HEAD backend: Optional[str], +======= + backend: str, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pg_options: Optional[C10dBackend.Options] = None, ) -> None: self.mesh_dim_group_options[dim] = (backend, pg_options) @@ -304,11 +357,15 @@ def _get_slice_mesh_dims( If valid, return dim indexes of the slice mesh in the device mesh. """ if device_mesh != self.get_root_mesh(device_mesh): +<<<<<<< HEAD warnings.warn( "You are attempting to slice a submesh from another submesh. While we support this operation, " "it is users' responsibility to ensure that the submesh is consistently sliced across all ranks. " "If not, this may result in some ranks receiving the submesh while others encounter errors." ) +======= + raise RuntimeError("Cannot create a submesh from a submesh.") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # The slice mesh_dim_names should consist either the device_mesh's mesh_dim_names # or its flattened mesh's mesh_dim_names. @@ -344,9 +401,15 @@ def _get_slice_mesh_dims( slice_mesh_dims.append((next_idx,)) if next_idx <= curr_idx: raise KeyError( +<<<<<<< HEAD f"Invalid mesh_dim_names {mesh_dim_names} specified. " f"Found mesh dim indices to slice: {slice_mesh_dims}. " "Mesh dim indices should be in ascending order." +======= + f"Invalid mesh_dim_names {mesh_dim_names} specified. ", + f"Found mesh dim indices to slice: {slice_mesh_dims}. ", + "Mesh dim indices should be in ascending order.", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) curr_idx = next_idx @@ -373,7 +436,11 @@ def _get_all_submeshes( _init_backend=False, ) submesh._dim_group_names = ( +<<<<<<< HEAD [device_mesh._dim_group_names[mesh_dim]] # type: ignore[has-type] +======= + [device_mesh._dim_group_names[mesh_dim]] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if cur_rank in mesh_1d else [] ) @@ -447,9 +514,12 @@ def __init__( mesh: Union[torch.Tensor, "ArrayLike"], *, mesh_dim_names: Optional[tuple[str, ...]] = None, +<<<<<<< HEAD backend_override: Optional[ tuple[tuple[Optional[str], Optional[C10dBackend.Options]], ...] ] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _init_backend: bool = True, ) -> None: self.device_type = device_type @@ -461,8 +531,11 @@ def __init__( else torch.tensor(mesh, device="cpu", dtype=torch.int) ) self.mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None +<<<<<<< HEAD if backend_override is None: backend_override = ((None, None),) * self.mesh.ndim +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # private field to pre-generate DeviceMesh's hash self._flatten_mesh_list = tuple(self.mesh.flatten().tolist()) @@ -476,7 +549,11 @@ def __init__( # process (we need to know if the current global rank is in the mesh or not). if _init_backend: self._setup_world_group_and_device() +<<<<<<< HEAD self._init_process_groups(backend_override) +======= + self._init_process_groups() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if is_initialized() and get_backend() == "threaded": self._thread_id = threading.get_ident() @@ -538,18 +615,23 @@ def _setup_world_group_and_device(self): return _get_default_group() +<<<<<<< HEAD def _init_process_groups( self, backend_override: tuple[ tuple[Optional[str], Optional[C10dBackend.Options]], ... ], ): +======= + def _init_process_groups(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # group_name associated with each mesh dimension, each # mesh dimension should have one sub-group per rank # dim_group_names: list[str] = [] default_group = _get_default_group() +<<<<<<< HEAD if ( self.mesh.ndim == 1 and self.mesh.numel() == get_world_size() @@ -557,6 +639,9 @@ def _init_process_groups( == (None, None) and backend_override[0] == (None, None) ): +======= + if self.mesh.ndim == 1 and self.mesh.numel() == get_world_size(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Append the default pg to the first dim groups only if the default pg is compatible with `self.device_type`. # Otherwise, create new pg. ranks = list(range(get_world_size())) @@ -583,17 +668,24 @@ def _init_process_groups( # Respect dim group options specified via _MeshEnv.set_dim_group_options(). # Inherit from the parent group if no options are specified for the group. if dim in _mesh_resources.mesh_dim_group_options: +<<<<<<< HEAD if backend_override[dim] != (None, None): raise RuntimeError( f"Dimension {dim} present both in the backend_override argument " "and via _mesh_resources._set_mesh_dim_group_options" ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ( backend, pg_options, ) = _mesh_resources.mesh_dim_group_options[dim] else: +<<<<<<< HEAD backend, pg_options = backend_override[dim] +======= + backend, pg_options = None, None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # If we have a 2D mesh with mesh_dim_names ("dp", "tp"), the group description # of the subgroups would be `mesh_dim_dp` and `mesh_name_tp`. @@ -616,6 +708,7 @@ def _init_process_groups( dim_group = None has_split_group = False if ( +<<<<<<< HEAD ( bound_device_id := getattr( default_group, "bound_device_id", None @@ -629,6 +722,12 @@ def _init_process_groups( == backend ) ): +======= + bound_device_id := getattr( + default_group, "bound_device_id", None + ) + ) is not None and torch.cuda.is_available(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dim_group = split_group( parent_pg=default_group, pg_options=pg_options, @@ -677,6 +776,7 @@ def __exit__(self, exc_type, exc_value, exc_traceback) -> None: def __repr__(self) -> str: device_mesh_repr = ( +<<<<<<< HEAD f"({', '.join(f'{k}={v}' for k, v in zip(self.mesh_dim_names, self.mesh.shape))})" if self.mesh_dim_names else f"{tuple(self.mesh.shape)}" @@ -686,6 +786,13 @@ def __repr__(self) -> str: if os.environ.get("TORCH_DISTRIBUTED_DEBUG", "") == "DETAIL": device_mesh_repr += f", Mesh: {self.mesh.tolist()}" return f"{device_mesh_repr})" +======= + f"DeviceMesh('{self.device_type}', {self.mesh.tolist()})" + if not self.mesh_dim_names + else f"DeviceMesh('{self.device_type}', {self.mesh.tolist()}, mesh_dim_names={self.mesh_dim_names})" + ) + return device_mesh_repr +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __hash__(self): # lazily compute hash @@ -703,6 +810,7 @@ def __hash__(self): return self._hash def __eq__(self, other: object) -> bool: +<<<<<<< HEAD if self is other: return True if not isinstance(other, DeviceMesh): @@ -714,6 +822,20 @@ def __eq__(self, other: object) -> bool: and self.mesh_dim_names == other.mesh_dim_names and self._thread_id == other._thread_id ) +======= + if not isinstance(other, DeviceMesh): + return False + if id(self) == id(other): + return True + else: + return ( + self._flatten_mesh_list == other._flatten_mesh_list + and self.mesh.shape == other.mesh.shape + and self.device_type == other.device_type + and self.mesh_dim_names == other.mesh_dim_names + and self._thread_id == other._thread_id + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __getitem__( self, mesh_dim_names: Union[str, tuple[str, ...]] @@ -1001,6 +1123,7 @@ def get_coordinate(self) -> Optional[list[int]]: """ return self._coordinate_on_dim if self._coordinate_on_dim else None +<<<<<<< HEAD def _flatten( self, mesh_dim_name: Optional[str] = None, @@ -1008,14 +1131,22 @@ def _flatten( None, str, C10dBackend.Options, tuple[str, C10dBackend.Options] ] = None, ) -> "DeviceMesh": +======= + def _flatten(self, mesh_dim_name: Optional[str] = None) -> "DeviceMesh": +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Returns a 1D DeviceMesh by flattening the current DeviceMesh. If no mesh_dim_name is provided, the default is a string concatenating the mesh_dim_names of the given submesh with each mesh_dim_name separated by "_". For example, if we have a 3D mesh DeviceMesh([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], mesh_dim_names=("dp", "cp", "tp")), calling +<<<<<<< HEAD mesh_3d["dp", "cp"]._flatten() will create a 1D submesh DeviceMesh([0, 2, 4, 6], mesh_dim_names=("dp_cp",)) on rank 0, 2, 4, 6 and a 1D submesh DeviceMesh([1, 3, 5, 7], mesh_dim_names=("dp_cp",)) on rank 1, 3, 5, 7. +======= + mesh_3d["dp", "cp"]._flatten() will create a 1D submesh DeviceMesh([0, 1, 2, 3], mesh_dim_names=("dp_cp",)) + on rank 0, 1, 2, 3 and a 1D submesh DeviceMesh([4, 5, 6, 7], mesh_dim_names=("dp_cp",)) on rank 4, 5, 6, 7. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) After the flattened dimension is created, to access the flattened dimension in mesh_3d, one can use the existing slicing method to obtain the flattened mesh through calling mesh_3d["dp_cp"]. @@ -1025,6 +1156,7 @@ def _flatten( "Cannot flatten a DeviceMesh without mesh_dim_names!" ) +<<<<<<< HEAD if backend_override is not None: (backend_override_tuple,) = _normalize_backend_override( {0: backend_override}, 1 @@ -1072,18 +1204,24 @@ def _normalize_backend_override( f"Found invalid keys in backend_override: got {list(backend_override.keys())}, " f"expected integers in range [0, {ndim}) or one of {mesh_dim_names}" ) +======= + return _mesh_resources.create_flatten_mesh(self, mesh_dim_name) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def init_device_mesh( device_type: str, mesh_shape: tuple[int, ...], *, mesh_dim_names: Optional[tuple[str, ...]] = None, +<<<<<<< HEAD backend_override: Optional[ dict[ Union[int, str], Union[str, C10dBackend.Options, tuple[str, C10dBackend.Options]], ] ] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> DeviceMesh: """ Initializes a `DeviceMesh` based on `device_type`, `mesh_shape`, and `mesh_dim_names` parameters. @@ -1101,18 +1239,25 @@ def init_device_mesh( required for distributed communications behind the scene. Args: +<<<<<<< HEAD device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like", "xpu". +======= + device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like". +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Passing in a device type with a GPU index, such as "cuda:0", is not allowed. mesh_shape (Tuple[int]): A tuple defining the dimensions of the multi-dimensional array describing the layout of devices. mesh_dim_names (Tuple[str], optional): A tuple of mesh dimension names to assign to each dimension of the multi-dimensional array describing the layout of devices. Its length must match the length of `mesh_shape`. Each string in `mesh_dim_names` must be unique. +<<<<<<< HEAD backend_override (Dict[int | str, tuple[str, Options] | str | Options], optional): Overrides for some or all of the ProcessGroups that will be created for each mesh dimension. Each key can be either the index of a dimension or its name (if mesh_dim_names is provided). Each value can be a tuple containing the name of the backend and its options, or just one of these two components (in which case the other will be set to its default value). +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Returns: DeviceMesh: A :class:`DeviceMesh` object representing the device layout. @@ -1139,6 +1284,7 @@ def init_device_mesh( f"Found len(mesh_dim_names): {len(mesh_dim_names)} and len(mesh_shape):{len(mesh_shape)}.", ) +<<<<<<< HEAD if backend_override is not None: backend_override_tuple = tuple( _normalize_backend_override( @@ -1148,6 +1294,8 @@ def init_device_mesh( else: backend_override_tuple = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # assume valid device types are all letters if device_type and not device_type.isalpha(): raise RuntimeError( @@ -1163,7 +1311,10 @@ def init_device_mesh( device_type=device_type, mesh=mesh, mesh_dim_names=mesh_dim_names, +<<<<<<< HEAD backend_override=backend_override_tuple, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return device_mesh diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 14790e5dba8af..d63b99c79f470 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -1568,12 +1568,20 @@ def init_process_group( Args: backend (str or Backend, optional): The backend to use. Depending on build-time configurations, valid values include ``mpi``, ``gloo``, +<<<<<<< HEAD ``nccl``, ``ucc``, ``xccl`` or one that is registered by a third-party +======= + ``nccl``, ``ucc``, or one that is registered by a third-party +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) plugin. Since 2.6, if ``backend`` is not provided, c10d will use a backend registered for the device type indicated by the `device_id` kwarg (if provided). The known default registrations today are: ``nccl`` +<<<<<<< HEAD for ``cuda``, ``gloo`` for ``cpu``, ``xccl`` for ``xpu``. +======= + for ``cuda``, ``gloo`` for ``cpu``. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) If neither ``backend`` nor ``device_id`` is provided, c10d will detect the accelerator on the run-time machine and use a backend registered for that detected accelerator (or ``cpu``). @@ -1751,6 +1759,7 @@ def init_process_group( else: # backward compatible API if store is None: +<<<<<<< HEAD if backend == "fake": from torch.testing._internal.distributed.fake_pg import FakeStore @@ -1761,6 +1770,13 @@ def init_process_group( ) store, rank, world_size = next(rendezvous_iterator) store.set_timeout(timeout) +======= + rendezvous_iterator = rendezvous( + not_none(init_method), rank, world_size, timeout=timeout + ) + store, rank, world_size = next(rendezvous_iterator) + store.set_timeout(timeout) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Use a PrefixStore to avoid accidental overrides of keys used by # different systems (e.g. RPC) in case the store is multi-tenant. @@ -1939,9 +1955,15 @@ def _new_process_group_helper( if "," not in str(backend) and ":" not in str(backend): assert backend in Backend.backend_type_map, f"Unknown backend type {backend}" if backend == Backend.UNDEFINED: +<<<<<<< HEAD # Currently when backend is UNDEFINED, only one backend will be initialized # we use nccl (if cuda is available) or gloo as default backend # so we can correctly call getDefaultBackend which in ProcessGroup. +======= + # Currently when backend is UNDEFINED, both ``gloo`` and ``nccl`` backends + # will be created, we use nccl(if cuda is available) or gloo as default + # backend so we can correctly call getDefaultBackend which in ProcessGroup. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if Backend.NCCL in backend_config.get_device_backend_map().values(): pg._set_default_backend(ProcessGroup.BackendType.NCCL) else: @@ -2040,12 +2062,17 @@ def _new_process_group_helper( elif backend_str == Backend.XCCL: if not is_xccl_available(): raise RuntimeError("Distributed package doesn't have XCCL built in") +<<<<<<< HEAD backend_options = ProcessGroupXCCL.Options() backend_options.global_ranks_in_group = global_ranks_in_group backend_options.group_name = group_name backend_options._timeout = timeout backend_class = ProcessGroupXCCL( backend_prefix_store, group_rank, group_size, backend_options +======= + backend_class = ProcessGroupXCCL( + backend_prefix_store, group_rank, group_size +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) backend_type = ProcessGroup.BackendType.XCCL else: @@ -2830,8 +2857,11 @@ def broadcast( opts.rootRank = group_src opts.rootTensor = 0 opts.asyncOp = async_op +<<<<<<< HEAD if tensor.is_complex(): tensor = torch.view_as_real(tensor) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) work = group.broadcast([tensor], opts) if async_op: return work @@ -3332,7 +3362,10 @@ def send_object_list( group: Optional[ProcessGroup] = None, device: Optional[torch.device] = None, group_dst: Optional[int] = None, +<<<<<<< HEAD use_batch: bool = False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): """ Sends picklable objects in ``object_list`` synchronously. @@ -3353,10 +3386,13 @@ def send_object_list( ``device`` before sending. Default is ``None``. group_dst (int, optional): Destination rank on ``group``. Must specify one of ``dst`` and ``group_dst`` but not both +<<<<<<< HEAD use_batch (bool, optional): If True, use batch p2p operations instead of regular send operations. This avoids initializing 2-rank communicators and uses existing entire group communicators. See batch_isend_irecv for usage and assumptions. Default is ``False``. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Returns: ``None``. @@ -3420,12 +3456,16 @@ def send_object_list( object_sizes_tensor = torch.cat(size_list) # Send object sizes +<<<<<<< HEAD if use_batch: batch_isend_irecv( [P2POp(isend, object_sizes_tensor, group_peer=group_dst, group=group)] ).pop().wait() else: send(object_sizes_tensor, group_dst=group_dst, group=group) +======= + send(object_sizes_tensor, group_dst=group_dst, group=group) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Concatenate and send serialized object tensors # Note: torch.cat will do an extra memory copy to the current device, if the tensor_list @@ -3435,12 +3475,16 @@ def send_object_list( else: object_tensor = torch.cat(tensor_list) +<<<<<<< HEAD if use_batch: batch_isend_irecv( [P2POp(isend, object_tensor, group_peer=group_dst, group=group)] ).pop().wait() else: send(object_tensor, group_dst=group_dst, group=group) +======= + send(object_tensor, group_dst=group_dst, group=group) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @_exception_logger @@ -3450,7 +3494,10 @@ def recv_object_list( group: Optional[ProcessGroup] = None, device: Optional[torch.device] = None, group_src: Optional[int] = None, +<<<<<<< HEAD use_batch: bool = False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): """ Receives picklable objects in ``object_list`` synchronously. @@ -3468,10 +3515,13 @@ def recv_object_list( device (``torch.device``, optional): If not None, receives on this device. Default is ``None``. group_src (int, optional): Destination rank on ``group``. Invalid to specify both ``src`` and ``group_src``. +<<<<<<< HEAD use_batch (bool, optional): If True, use batch p2p operations instead of regular send operations. This avoids initializing 2-rank communicators and uses existing entire group communicators. See batch_isend_irecv for usage and assumptions. Default is ``False``. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Returns: Sender rank. -1 if rank is not part of the group. If rank is part of the group, @@ -3515,10 +3565,13 @@ def recv_object_list( >>> objects ['foo', 12, {1: 2}] """ +<<<<<<< HEAD group = _group_or_default_group(group) group_src = _canonicalize_group_rank(group, src, group_src) _check_not_self_rank(group, group_src, "source") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if _rank_not_in_group(group): _warn_not_in_group("recv_object_list") return -1 @@ -3535,6 +3588,7 @@ def recv_object_list( ) # Receive object sizes +<<<<<<< HEAD if use_batch: work = batch_isend_irecv( [ @@ -3550,6 +3604,9 @@ def recv_object_list( rank_sizes = get_global_rank(group, group_src) else: rank_sizes = recv(object_sizes_tensor, group=group, group_src=group_src) +======= + rank_sizes = recv(object_sizes_tensor, src=src, group=group, group_src=group_src) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Tensor to receive serialized objects into. object_tensor = torch.empty( # type: ignore[call-overload] @@ -3558,6 +3615,7 @@ def recv_object_list( device=current_device, ) +<<<<<<< HEAD if use_batch: work = batch_isend_irecv( [ @@ -3573,6 +3631,9 @@ def recv_object_list( rank_objects = get_global_rank(group, group_src) else: rank_objects = recv(object_tensor, group=group, group_src=group_src) +======= + rank_objects = recv(object_tensor, src=src, group=group, group_src=group_src) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert rank_sizes == rank_objects, ( "Mismatch in return ranks for object sizes and objects." ) @@ -4841,11 +4902,14 @@ def barrier( None, if not async_op or if not part of the group .. note:: `ProcessGroupNCCL` now blocks the cpu thread till the completion of the barrier collective. +<<<<<<< HEAD .. note:: `ProcessGroupNCCL` implements barrier as an all_reduce of a 1-element tensor. A device must be chosen for allocating this tensor. The device choice is made by checking in this order (1) the first device passed to `device_ids` arg of barrier if not None, (2) the device passed to init_process_group if not None, (3) the device that was first used with this process group, if another collective with tensor inputs has been performed, (4) the device index indicated by the global rank mod local device count. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ group = group or _get_default_group() @@ -4872,11 +4936,17 @@ def barrier( # may use default device 0, causing issues like hang or all processes # creating context on device 0. opts.device = device +<<<<<<< HEAD if group.rank() == 0: warnings.warn( # warn only once "barrier(): using the device under current context. " "You can specify `device_id` in `init_process_group` to mute this warning." ) +======= + warnings.warn( # warn only once + "No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user. " + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) work = group.barrier(opts=opts) @@ -5081,8 +5151,13 @@ def split_group( """ # check inputs +<<<<<<< HEAD if split_ranks is None or len(split_ranks) == 0: raise ValueError("split_ranks cannot be None or empty") +======= + if split_ranks is None: + raise ValueError("split_ranks cannot be None") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) global _world default_pg = _get_default_group() @@ -5091,6 +5166,10 @@ def split_group( raise RuntimeError( "No device associated with the default pg, not safe to split any process groups" ) +<<<<<<< HEAD +======= + _default_backend, default_store = _world.pg_map[default_pg] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) global_rank = default_pg.rank() global_world_size = default_pg.size() @@ -5121,8 +5200,16 @@ def split_group( ) # set the group_desc before the color or no_cloor split +<<<<<<< HEAD if hasattr(parent_backend, "comm_split_count") and group_desc is None: group_desc = f"{parent_pg.group_desc}:split:{parent_backend.comm_split_count()}" # type: ignore[attr-defined] +======= + group_desc = ( + f"{parent_pg.group_desc}:split:{parent_backend.comm_split_count()}" # type: ignore[attr-defined] + if group_desc is None + else group_desc + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) parent_backend_str, _ = _world.pg_map[parent_pg] # same type of backend as the parent process group @@ -5140,9 +5227,14 @@ def split_group( _check_valid_timeout(timeout) # find my group of ranks and my group local rank in split_ranks +<<<<<<< HEAD # for ranks which are not in any split PGs, we just pass in this the first split group # and None will be returned. my_group = split_ranks[0] +======= + my_group = None + group_rank = -1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for split_group in split_ranks: if len(split_group) == 0: @@ -5156,6 +5248,7 @@ def split_group( split_group = sorted(split_group) if parent_group_rank in split_group: my_group = split_group +<<<<<<< HEAD break group_name = _process_group_name(my_group, use_hashed_name=False) @@ -5188,11 +5281,94 @@ def split_group( # Create the global rank to group rank mapping _world.pg_group_ranks[split_pg] = { +======= + group_rank = split_group.index(parent_group_rank) + break + # if my rank does not belong to any sub group, + # no_color split should be called + if my_group is None or group_rank == -1: + parent_backend.perform_nocolor_split(device_id) # type: ignore[attr-defined] + return None + + group_name = _process_group_name(my_group, use_hashed_name=False) + global_ranks_in_my_group = [parent_group_to_global_ranks[rank] for rank in my_group] + + prefix_store = PrefixStore(f"{group_name}/", default_store) + # We register the backend after initializing and timeout is set in pg_options. + pg: ProcessGroup = ProcessGroup( + prefix_store, + group_rank, + len(my_group), + ) + pg.bound_device_id = device_id # type: ignore[union-attr] + pg_options._timeout = timeout # type: ignore[union-attr] + pg_options.split_from = parent_backend # type: ignore[union-attr] + pg_options.split_color = _process_group_color(my_group) # type: ignore[union-attr] + pg_options.global_ranks_in_group = global_ranks_in_my_group # type: ignore[union-attr] + pg_options.group_name = group_name # type: ignore[union-attr] + + if parent_backend_str == Backend.NCCL: + backend_type = ProcessGroup.BackendType.NCCL + if not isinstance(pg_options, ProcessGroupNCCL.Options): + raise RuntimeError( + "Expected pg_options argument to be of type ProcessGroupNCCL.Options" + ) + backend_class = ProcessGroupNCCL( + prefix_store, group_rank, len(my_group), pg_options + ) + else: + assert parent_backend_str.upper() in Backend._plugins, ( + f"Unknown c10d backend type {parent_backend_str.upper()}" + ) + backend_plugin = Backend._plugins[parent_backend_str.upper()] + creator_fn = backend_plugin.creator_fn + extended_api = backend_plugin.extended_api + backend_type = ProcessGroup.BackendType.CUSTOM + if not extended_api: + backend_class = creator_fn(prefix_store, group_rank, len(my_group), timeout) + else: + dist_backend_opts = _DistributedBackendOptions() + dist_backend_opts.store = prefix_store + dist_backend_opts.group_rank = group_rank + dist_backend_opts.group_size = len(my_group) + backend_class = creator_fn(dist_backend_opts, pg_options) + + pg._set_default_backend(backend_type) + backend_class._set_sequence_number_for_group() + + pg._register_backend(torch.device("cuda"), backend_type, backend_class) + + # set group_name and group_desc to backend + assert group_name is not None + assert group_desc is not None + pg._set_group_name(group_name) + pg._set_group_desc(group_desc) + + # always eagerly initialize the backend in split_group + eager_backend = pg._get_backend(device_id) + eager_backend.eager_connect_single_device(device_id) + + # update global state + _world.pg_map[pg] = (backend, prefix_store) + _world.pg_names[pg] = group_name + _register_process_group(group_name, pg) + _world.pg_backend_config[pg] = str(backend_config) + pg_tag = f"ptd:{group_name}" + _world.tags_to_pg.setdefault(pg_tag, []).append(pg) + _world.pg_to_tag[pg] = pg_tag + + # Create the global rank to group rank mapping + _world.pg_group_ranks[pg] = { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) global_rank: group_rank for group_rank, global_rank in enumerate(global_ranks_in_my_group) } +<<<<<<< HEAD return split_pg +======= + return pg +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @_time_logger diff --git a/torch/distributed/elastic/agent/server/api.py b/torch/distributed/elastic/agent/server/api.py index 1175da3b91b7c..c2eea82a44c8d 100644 --- a/torch/distributed/elastic/agent/server/api.py +++ b/torch/distributed/elastic/agent/server/api.py @@ -27,7 +27,10 @@ from torch.distributed.elastic.multiprocessing import ProcessFailure, SignalException from torch.distributed.elastic.rendezvous import RendezvousGracefulExitError from torch.distributed.elastic.utils.logging import get_logger +<<<<<<< HEAD from torch.numa.binding import NumaOptions +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __all__ = [ @@ -89,7 +92,10 @@ class WorkerSpec: master_addr: Optional[str] = None local_addr: Optional[str] = None event_log_handler: str = "null" +<<<<<<< HEAD numa_options: Optional[NumaOptions] = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __post_init__(self): assert self.local_world_size > 0 diff --git a/torch/distributed/elastic/agent/server/local_elastic_agent.py b/torch/distributed/elastic/agent/server/local_elastic_agent.py index dd9f0647a153d..20cda7865dcce 100644 --- a/torch/distributed/elastic/agent/server/local_elastic_agent.py +++ b/torch/distributed/elastic/agent/server/local_elastic_agent.py @@ -353,7 +353,10 @@ def _start_workers(self, worker_group: WorkerGroup) -> dict[int, Any]: logs_specs=self._logs_specs, log_line_prefixes=log_line_prefixes, start_method=self._start_method, +<<<<<<< HEAD numa_options=spec.numa_options, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return self._pcontext.pids() diff --git a/torch/distributed/elastic/multiprocessing/__init__.py b/torch/distributed/elastic/multiprocessing/__init__.py index 3f9fabd720bdd..714feec733a63 100644 --- a/torch/distributed/elastic/multiprocessing/__init__.py +++ b/torch/distributed/elastic/multiprocessing/__init__.py @@ -80,7 +80,10 @@ def trainer(a, b, c): to_map, ) from torch.distributed.elastic.utils.logging import get_logger +<<<<<<< HEAD from torch.numa.binding import NumaOptions +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __all__ = [ @@ -107,7 +110,10 @@ def start_processes( logs_specs: LogsSpecs, log_line_prefixes: Optional[dict[int, str]] = None, start_method: str = "spawn", +<<<<<<< HEAD numa_options: Optional[NumaOptions] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> PContext: """ Start ``n`` copies of ``entrypoint`` processes with the provided options. @@ -141,8 +147,13 @@ def start_processes( For each process, the ``log_dir`` will contain: #. ``{local_rank}/error.json``: if the process failed, a file with the error info +<<<<<<< HEAD #. ``{local_rank}/stdout.log``: if ``redirect & STDOUT == STDOUT`` #. ``{local_rank}/stderr.log``: if ``redirect & STDERR == STDERR`` +======= + #. ``{local_rank}/stdout.json``: if ``redirect & STDOUT == STDOUT`` + #. ``{local_rank}/stderr.json``: if ``redirect & STDERR == STDERR`` +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .. note:: It is expected that the ``log_dir`` exists, is empty, and is a directory. @@ -216,7 +227,10 @@ def start_processes( envs=envs, logs_specs=logs_specs, log_line_prefixes=log_line_prefixes, +<<<<<<< HEAD numa_options=numa_options, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) else: context = MultiprocessContext( @@ -227,7 +241,10 @@ def start_processes( log_line_prefixes=log_line_prefixes, start_method=start_method, logs_specs=logs_specs, +<<<<<<< HEAD numa_options=numa_options, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) try: diff --git a/torch/distributed/elastic/multiprocessing/api.py b/torch/distributed/elastic/multiprocessing/api.py index ed3ea86b0f2aa..ba74c76d745d2 100644 --- a/torch/distributed/elastic/multiprocessing/api.py +++ b/torch/distributed/elastic/multiprocessing/api.py @@ -37,7 +37,10 @@ SubprocessHandler, ) from torch.distributed.elastic.multiprocessing.tail_log import TailLog +<<<<<<< HEAD from torch.numa.binding import NumaOptions +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) IS_WINDOWS = sys.platform == "win32" @@ -631,7 +634,10 @@ def __init__( start_method: str, logs_specs: LogsSpecs, log_line_prefixes: Optional[dict[int, str]] = None, +<<<<<<< HEAD numa_options: Optional[NumaOptions] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): super().__init__( name, @@ -656,8 +662,11 @@ def __init__( # successfully. If any process died on event.wait() calling set() method will deadlock. self._worker_finished_event = mp.get_context(self.start_method).Event() +<<<<<<< HEAD self._numa_options: Optional[NumaOptions] = numa_options +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _start(self): if self._pc: raise ValueError( @@ -679,7 +688,10 @@ def _start(self): join=False, daemon=False, start_method=self.start_method, +<<<<<<< HEAD numa_options=self._numa_options, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def _is_done(self) -> bool: @@ -816,7 +828,10 @@ def __init__( envs: dict[int, dict[str, str]], logs_specs: LogsSpecs, log_line_prefixes: Optional[dict[int, str]] = None, +<<<<<<< HEAD numa_options: Optional[NumaOptions] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): super().__init__( name, @@ -831,7 +846,10 @@ def __init__( self._running_local_ranks: set[int] = set(range(self.nprocs)) self._failures: dict[int, ProcessFailure] = {} self.subprocess_handlers: dict[int, SubprocessHandler] = {} +<<<<<<< HEAD self._numa_options: Optional[NumaOptions] = numa_options +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _start(self): if self.subprocess_handlers: @@ -846,7 +864,10 @@ def _start(self): stdout=self.stdouts[local_rank], stderr=self.stderrs[local_rank], local_rank_id=local_rank, +<<<<<<< HEAD numa_options=self._numa_options, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) for local_rank in range(self.nprocs) } diff --git a/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py b/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py index 947ce7b001ef7..9819becda4a7b 100644 --- a/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py +++ b/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py @@ -3,12 +3,18 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +<<<<<<< HEAD from typing import Optional +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler import ( SubprocessHandler, ) +<<<<<<< HEAD from torch.numa.binding import NumaOptions +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __all__ = ["get_subprocess_handler"] @@ -21,7 +27,10 @@ def get_subprocess_handler( stdout: str, stderr: str, local_rank_id: int, +<<<<<<< HEAD numa_options: Optional[NumaOptions] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> SubprocessHandler: return SubprocessHandler( entrypoint=entrypoint, @@ -30,5 +39,8 @@ def get_subprocess_handler( stdout=stdout, stderr=stderr, local_rank_id=local_rank_id, +<<<<<<< HEAD numa_options=numa_options, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) diff --git a/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py b/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py index 6a2e7ae35c4b7..cf50eff9c3378 100644 --- a/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py +++ b/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py @@ -7,6 +7,7 @@ # LICENSE file in the root directory of this source tree. import os import signal +<<<<<<< HEAD import sys from subprocess import Popen from typing import Any, Optional @@ -16,6 +17,12 @@ NumaOptions, ) +======= +import subprocess +import sys +from typing import Any, Optional + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __all__ = ["SubprocessHandler"] @@ -44,7 +51,10 @@ def __init__( stdout: Optional[str], stderr: Optional[str], local_rank_id: int, +<<<<<<< HEAD numa_options: Optional[NumaOptions], +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): self._stdout = open(stdout, "w") if stdout else None self._stderr = open(stderr, "w") if stderr else None @@ -53,6 +63,7 @@ def __init__( env_vars.update(env) args_str = (entrypoint, *[str(e) for e in args]) +<<<<<<< HEAD self.local_rank_id = local_rank_id @@ -68,6 +79,16 @@ def _popen(self, args: tuple, env: dict[str, str]) -> Popen: kwargs["start_new_session"] = True return Popen( +======= + self.local_rank_id = local_rank_id + self.proc: subprocess.Popen = self._popen(args_str, env_vars) + + def _popen(self, args: tuple, env: dict[str, str]) -> subprocess.Popen: + kwargs: dict[str, Any] = {} + if not IS_WINDOWS: + kwargs["start_new_session"] = True + return subprocess.Popen( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # pyre-fixme[6]: Expected `Union[typing.Sequence[Union[_PathLike[bytes], # _PathLike[str], bytes, str]], bytes, str]` for 1st param but got # `Tuple[str, *Tuple[Any, ...]]`. diff --git a/torch/distributed/elastic/rendezvous/etcd_rendezvous.py b/torch/distributed/elastic/rendezvous/etcd_rendezvous.py index 0e4da86d4621d..ad138820d077a 100644 --- a/torch/distributed/elastic/rendezvous/etcd_rendezvous.py +++ b/torch/distributed/elastic/rendezvous/etcd_rendezvous.py @@ -208,7 +208,11 @@ def shutdown(self) -> bool: try: self.set_closed() return True +<<<<<<< HEAD except BaseException as e: # noqa: B036 +======= + except BaseException as e: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) logger.warning("Shutdown failed. Error occurred: %s", str(e)) return False diff --git a/torch/distributed/fsdp/_flat_param.py b/torch/distributed/fsdp/_flat_param.py index 4fe05da4c844c..561a755f3950a 100644 --- a/torch/distributed/fsdp/_flat_param.py +++ b/torch/distributed/fsdp/_flat_param.py @@ -2294,9 +2294,14 @@ def _writeback_orig_params(self) -> bool: flat_param._params[i] = param if needs_param_writeback: expected_shape = torch.Size([numel_in_shard]) +<<<<<<< HEAD src = param if self.uses_sharded_strategy else param.view(-1) self._writeback_tensor( src, flat_param, i, expected_shape, offset_in_shard, True +======= + self._writeback_tensor( + param, flat_param, i, expected_shape, offset_in_shard, True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) wroteback = True @@ -2328,6 +2333,7 @@ def _writeback_orig_params(self) -> bool: if flat_param_grad is None: flat_param_grad = torch.zeros_like(flat_param) expected_shape = torch.Size([numel_in_shard]) +<<<<<<< HEAD src = ( param.grad if self.uses_sharded_strategy @@ -2335,6 +2341,10 @@ def _writeback_orig_params(self) -> bool: ) self._writeback_tensor( src, +======= + self._writeback_tensor( + param.grad, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) flat_param_grad, i, expected_shape, diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_api.py b/torch/distributed/fsdp/_fully_shard/_fsdp_api.py index 38650323f5e99..49ad0431dd0e8 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_api.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_api.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +<<<<<<< HEAD from abc import ABC, abstractmethod from collections.abc import Sequence from dataclasses import dataclass @@ -9,6 +10,12 @@ _ReduceOp = Union[dist.ReduceOp, dist.ReduceOp.RedOpType] +======= +from dataclasses import dataclass +from typing import Optional + +import torch +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclass(frozen=True) @@ -53,6 +60,7 @@ class MixedPrecisionPolicy: cast_forward_inputs: bool = True +<<<<<<< HEAD class Comm(ABC): """ Interface for communication primitives. @@ -127,6 +135,8 @@ def __call__( ) -> Optional[dist.Work]: ... +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclass class OffloadPolicy: """ diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py b/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py index 90b4b91a5cc7a..a143a401aa938 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py @@ -1,16 +1,27 @@ +<<<<<<< HEAD import math from collections.abc import Sequence from itertools import chain from typing import Any, Callable, cast, NamedTuple, Optional, Union +======= +from itertools import chain +from typing import Callable, cast, NamedTuple, Optional, Union +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch.distributed as dist from torch.distributed.device_mesh import _get_device_handle +<<<<<<< HEAD from torch.distributed.distributed_c10d import ReduceOp from torch.distributed.fsdp._fully_shard._fsdp_api import AllGather, ReduceScatter from torch.distributed.tensor import DTensor from ._fsdp_api import _ReduceOp +======= +from torch.distributed.distributed_c10d import _resolve_process_group, ReduceOp +from torch.distributed.tensor import DTensor + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from ._fsdp_common import ( _get_dim0_padded_size, _raise_assert_with_print, @@ -33,21 +44,50 @@ class AllGatherResult(NamedTuple): all_gather_input_split_sizes: list[int] +<<<<<<< HEAD +======= +def allocate_memory( + size: int, + dtype: torch.dtype, + device: torch.device, + group: dist.ProcessGroup, + from_process_group: bool, +) -> torch.Tensor: + if from_process_group: + backend = group._get_backend(device) + if backend.supports_tensor_alloc(device): + return backend.allocate_tensor(size, dtype=dtype, device=device) + return torch.empty((size,), dtype=dtype, device=device) + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) lib = torch.library.Library("fsdp", "FRAGMENT") # noqa: TOR901 lib.define( """ all_gather_copy_in( Tensor[] all_gather_inputs, +<<<<<<< HEAD Tensor all_gather_output, SymInt[] inp_split_sizes, SymInt all_gather_input_numel, SymInt rank +======= + SymInt[] inp_split_sizes, + SymInt all_gather_input_numel, + SymInt world_size, + SymInt rank, + ScalarType dtype, + Device device, + str group_name, + bool allocate_memory_from_process_group +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> (Tensor, Tensor) """ ) +<<<<<<< HEAD class DefaultAllocMixin: def allocate( self, @@ -160,6 +200,23 @@ def all_gather_copy_in_meta( all_gather_input_numel: int, rank: int, ) -> tuple[torch.Tensor, torch.Tensor]: +======= +@torch.library.impl(lib, "all_gather_copy_in", "Meta") +def all_gather_copy_in_meta( + all_gather_inputs: list[torch.Tensor], + inp_split_sizes: list[int], + all_gather_input_numel: int, + world_size: int, + rank: int, + dtype: torch.dtype, + device: torch.device, + group_name: str, + allocate_memory_from_process_group: bool, +) -> tuple[torch.Tensor, torch.Tensor]: + all_gather_output = torch.empty( + (all_gather_input_numel * world_size,), dtype=dtype, device="meta" + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) all_gather_input = all_gather_output.narrow( 0, all_gather_input_numel * rank, all_gather_input_numel ) @@ -174,11 +231,30 @@ def all_gather_copy_in_meta( @torch.library.impl(lib, "all_gather_copy_in", "PrivateUse1") def all_gather_copy_in_cuda( all_gather_inputs: list[torch.Tensor], +<<<<<<< HEAD all_gather_output: torch.Tensor, inp_split_sizes: list[int], all_gather_input_numel: int, rank: int, ) -> tuple[torch.Tensor, torch.Tensor]: +======= + inp_split_sizes: list[int], + all_gather_input_numel: int, + world_size: int, + rank: int, + dtype: torch.dtype, + device: torch.device, + group_name: str, + allocate_memory_from_process_group: bool, +) -> tuple[torch.Tensor, torch.Tensor]: + all_gather_output = allocate_memory( + all_gather_input_numel * world_size, + dtype=dtype, + device=device, + group=_resolve_process_group(group_name), + from_process_group=allocate_memory_from_process_group, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) all_gather_input = all_gather_output.narrow( 0, all_gather_input_numel * rank, all_gather_input_numel ) @@ -240,7 +316,11 @@ def foreach_all_gather( all_gather_copy_in_stream: torch.Stream, all_gather_stream: torch.Stream, device: torch.device, +<<<<<<< HEAD all_gather_comm: AllGather, +======= + allocate_memory_from_process_group: bool = False, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Optional[AllGatherResult]: world_size, rank = group.size(), group.rank() device_handle = _get_device_handle(device.type) @@ -259,6 +339,7 @@ def foreach_all_gather( all_gather_inputs = [*chain.from_iterable(param_all_gather_inputs)] inp_split_sizes = [t.numel() for t in all_gather_inputs] all_gather_input_numel = sum(inp_split_sizes) +<<<<<<< HEAD all_gather_output = all_gather_comm.allocate( (all_gather_input_numel * world_size,), dtype=dtype, device=device ) @@ -268,11 +349,27 @@ def foreach_all_gather( inp_split_sizes, all_gather_input_numel, rank, +======= + all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in( + all_gather_inputs, + inp_split_sizes, + all_gather_input_numel, + world_size, + rank, + dtype, + device, + group.group_name, + allocate_memory_from_process_group, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) del param_all_gather_inputs all_gather_stream.wait_stream(all_gather_copy_in_stream) with device_handle.stream(all_gather_stream): +<<<<<<< HEAD all_gather_work = all_gather_comm( +======= + all_gather_work = dist.all_gather_into_tensor( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) output_tensor=all_gather_output, input_tensor=all_gather_input, group=group, @@ -449,7 +546,10 @@ def foreach_reduce( unsharded_grads: list[torch.Tensor], reduce_scatter_group: dist.ProcessGroup, reduce_scatter_stream: torch.Stream, +<<<<<<< HEAD reduce_scatter_comm: ReduceScatter, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) orig_dtype: Optional[torch.dtype], reduce_dtype: Optional[torch.dtype], device: torch.device, @@ -459,6 +559,10 @@ def foreach_reduce( all_reduce_grads: bool, partial_reduce_output: Optional[torch.Tensor], # only used for HSDP all_reduce_hook: Optional[Callable[[torch.Tensor], None]], +<<<<<<< HEAD +======= + allocate_memory_from_process_group: bool = False, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) force_sum_reduction_for_comms: bool = False, ) -> tuple[ torch.Tensor, @@ -505,10 +609,19 @@ def foreach_reduce( ) reduce_scatter_input_numel = sum(s.numel() for s in padded_unsharded_sizes) reduce_scatter_output_numel = reduce_scatter_input_numel // world_size +<<<<<<< HEAD reduce_scatter_input = reduce_scatter_comm.allocate( (reduce_scatter_input_numel,), dtype=reduce_dtype, device=device, +======= + reduce_scatter_input = allocate_memory( + reduce_scatter_input_numel, + dtype=reduce_dtype, + device=device, + group=reduce_scatter_group, + from_process_group=allocate_memory_from_process_group, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) device_handle = _get_device_handle(device.type) foreach_reduce_scatter_copy_in(unsharded_grads, reduce_scatter_input, world_size) @@ -519,6 +632,7 @@ def foreach_reduce( all_reduce_input = None all_reduce_event = None with device_handle.stream(reduce_scatter_stream): +<<<<<<< HEAD reduce_output = reduce_scatter_comm.allocate( (reduce_scatter_output_numel,), dtype=reduce_dtype, @@ -528,6 +642,19 @@ def foreach_reduce( reduce_scatter_comm( output_tensor=reduce_output, input_tensor=reduce_scatter_input, +======= + reduce_output = allocate_memory( + reduce_scatter_output_numel, + dtype=reduce_dtype, + device=device, + group=reduce_scatter_group, + from_process_group=allocate_memory_from_process_group, + ) + _div_if_needed(reduce_scatter_input, predivide_factor) + dist.reduce_scatter_tensor( + output=reduce_output, + input=reduce_scatter_input, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) group=reduce_scatter_group, op=reduce_scatter_op, ) @@ -603,7 +730,11 @@ def foreach_reduce( if non_blocking: # Record an event on which to block the CPU thread to # ensure that the D2H copy finishes before the optimizer +<<<<<<< HEAD fsdp_param.grad_offload_event = post_reduce_stream.record_event() +======= + fsdp_param.grad_offload_event = reduce_scatter_stream.record_event() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if to_accumulate_grad: assert isinstance(fsdp_param.sharded_param.grad, DTensor) fsdp_param.sharded_param.grad._local_tensor += new_sharded_grad @@ -701,8 +832,11 @@ def _get_gradient_divide_factors( if factor == data_parallel_size: # Warning: NCCL ReduceOp.AVG may produce incorrect results with # world size 1. +<<<<<<< HEAD if data_parallel_size == 1: return None, None, ReduceOp.SUM, ReduceOp.SUM +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return None, None, ReduceOp.AVG, ReduceOp.AVG else: reduce_scatter_op = torch.distributed._make_nccl_premul_sum(1 / factor) diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_common.py b/torch/distributed/fsdp/_fully_shard/_fsdp_common.py index b599f48d77d1d..dff4da68dd360 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_common.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_common.py @@ -15,6 +15,7 @@ _compiled_autograd_enabled: bool = False +<<<<<<< HEAD def detect_compiled_autograd(): assert not torch.compiler.is_compiling(), ( @@ -33,6 +34,34 @@ def detect_compiled_autograd(): def compiled_autograd_enabled(): global _compiled_autograd_enabled return _compiled_autograd_enabled +======= +if torch._running_with_deploy(): + + def detect_compiled_autograd(): + pass + + def compiled_autograd_enabled(): + return False + +else: + + def detect_compiled_autograd(): + assert not torch.compiler.is_compiling(), ( + "`detect_compiled_autograd()` is designed to be called in eager mode" + ) + global _compiled_autograd_enabled + import torch._dynamo.compiled_autograd as ca + + _compiled_autograd_enabled = ( + ca.compiled_autograd_enabled + or ca.compiled_autograd_enabled_force_eager + or ca.in_compiled_autograd_region + ) + + def compiled_autograd_enabled(): + global _compiled_autograd_enabled + return _compiled_autograd_enabled +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclass diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_param.py b/torch/distributed/fsdp/_fully_shard/_fsdp_param.py index db8f2bf722f01..d3c2ea71d9317 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_param.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_param.py @@ -140,7 +140,12 @@ def copy__functionalize(tensor, data): torch.ops.fsdp.copy_.default(tensor_inner, data_inner) +<<<<<<< HEAD torch.fx.node.has_side_effect(torch.ops.fsdp.copy_.default) +======= +if not torch._running_with_deploy(): + torch.fx.node.has_side_effect(torch.ops.fsdp.copy_.default) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class ShardedState(Enum): @@ -291,14 +296,20 @@ def _init_sharded_param( dp_global_mesh is None or tp_global_mesh is None ): raise AssertionError( +<<<<<<< HEAD "FSDP requires the DP and model parallel TP/EP mesh to have the same parent mesh but got: \n" f"DP's global mesh: {dp_global_mesh}\nTP/EP's global mesh: {tp_global_mesh}" +======= + "FSDP requires the DP and TP mesh to have the same parent mesh but got: \n" + f"DP's global mesh: {dp_global_mesh}\nTP's global mesh: {tp_global_mesh}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) name_dims_error = "FSDP requires named DeviceMesh dims for ND parallelism" assert dp_mesh.mesh_dim_names is not None, name_dims_error assert tp_mesh.mesh_dim_names is not None, name_dims_error submesh_names = dp_mesh.mesh_dim_names + tp_mesh.mesh_dim_names self._spmd_mesh = dp_global_mesh[submesh_names] +<<<<<<< HEAD if len(self._tp_spec.placements) > 2: raise NotImplementedError( f"FSDP only supports 1D TP/EP or 2D EP+TP, not {self._tp_spec.placements}" @@ -307,6 +318,15 @@ def _init_sharded_param( assert 2 <= self._spmd_mesh.ndim <= 4, ( "_spmd_mesh.ndim can only be 2 (FSDP+TP/EP), 3 (FSDP+EP+TP, HSDP+TP/EP), " f"or 4 (HSDP+EP+TP) but got {self._spmd_mesh.ndim}." +======= + if len(self._tp_spec.placements) != 1: + raise NotImplementedError( + f"FSDP only supports 1D TP, not {self._tp_spec.placements}" + ) + split_factor = self._tp_spec.num_shards_map[shard_dim] + assert 2 <= self._spmd_mesh.ndim <= 3, ( + f"_spmd_mesh.ndim can only be 2 or 3 but got {self._spmd_mesh.ndim}." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) self._spmd_placements: tuple[Placement, ...] dp_shard_tp_placement = ( @@ -315,11 +335,19 @@ def _init_sharded_param( if split_factor > 1 else fsdp_placement ), +<<<<<<< HEAD *self._tp_spec.placements, ) if dp_mesh.ndim == 1: # FSDP self._spmd_placements = dp_shard_tp_placement else: # HSDP +======= + self._tp_spec.placements[0], + ) + if self._spmd_mesh.ndim == 2: + self._spmd_placements = dp_shard_tp_placement + else: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert self.mesh_info.replicate_mesh_dim == 0 self._spmd_placements = (Replicate(),) + dp_shard_tp_placement self._sharding_spec = DTensorSpec( @@ -376,7 +404,13 @@ def _init_sharded_param( if self.offload_to_cpu and not padded_sharded_param.is_meta: padded_sharded_param = padded_sharded_param.cpu() if self.pin_memory: +<<<<<<< HEAD padded_sharded_param = padded_sharded_param.pin_memory() +======= + padded_sharded_param = padded_sharded_param.pin_memory( + device=self.device + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._sharded_param_data = padded_sharded_param.view(-1) length = sharded_param.size(shard_dim) if sharded_param.numel() > 0 else 0 sharded_param = padded_sharded_param.narrow( @@ -694,7 +728,10 @@ def all_gather_inputs(self) -> list[torch.Tensor]: # 1D ), ( f"Invalid fsdp_pre_all_gather: {pre_all_gather_signature}\n" "Expects fsdp_pre_all_gather(self, mesh: DeviceMesh, " +<<<<<<< HEAD "outer_size: torch.Size, outer_stride: tuple[int, ...], " +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "module: nn.Module, mp_policy: MixedPrecisionPolicy)" ) if num_fn_params == 1: @@ -847,7 +884,11 @@ def reset_sharded_param(self): local_tensor = padded_local_tensor updated_local_tensor = True if self.pin_memory and not local_tensor.is_pinned(): +<<<<<<< HEAD local_tensor = local_tensor.cpu().pin_memory() +======= + local_tensor = local_tensor.cpu().pin_memory(device=self.device) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) updated_local_tensor = True self._sharded_param_data = local_tensor.view(-1) assert isinstance(self.sharded_param, DTensor) # mypy diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py b/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py index 554367e8705c8..8b33dd0e0cfdb 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py @@ -15,6 +15,7 @@ from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy from ._fsdp_collectives import ( +<<<<<<< HEAD AllGather, AllGatherResult, DefaultAllGather, @@ -25,6 +26,12 @@ ProcessGroupAllocAllGather, ProcessGroupAllocReduceScatter, ReduceScatter, +======= + AllGatherResult, + foreach_all_gather, + foreach_all_gather_copy_out, + foreach_reduce, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) from ._fsdp_common import ( compiled_autograd_enabled, @@ -32,7 +39,11 @@ HSDPMeshInfo, TrainingState, ) +<<<<<<< HEAD from ._fsdp_param import alloc_storage, FSDPParam, ParamModuleInfo, ShardedState +======= +from ._fsdp_param import FSDPParam, ParamModuleInfo, ShardedState +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) logger = logging.getLogger("torch.distributed.fsdp.fully_shard") @@ -165,9 +176,12 @@ def __init__( self._module_to_pre_save_state_dict_hook_handle: _ModuleToHandleDict = {} self._module_to_pre_load_state_dict_hook_handle: _ModuleToHandleDict = {} self._all_reduce_hook: Optional[Callable[[torch.Tensor], None]] = None +<<<<<<< HEAD self._all_gather_comm: AllGather = DefaultAllGather() self._all_gather_output = torch.empty(0, device=self.device) self._reduce_scatter_comm: ReduceScatter = DefaultReduceScatter() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Optional stream to run the user-defined all-reduce hook in # Saved here and not in the comm. context because we allow the user to # specify it, possibly at construction time before lazy init @@ -199,6 +213,12 @@ def __init__( # Whether to unshard in backward: can be overridden by the user if the # parameters in this group are not needed for backward (e.g. embedding) self.unshard_in_backward: bool = True +<<<<<<< HEAD +======= + # Whether to (try to) use the ProcessGroup's allocate_tensor method for + # the staging buffers for collective comms. + self.allocate_memory_from_process_group = False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # - CUDA events for stream synchronization # Holds the all-gather output buffer, sync objects, and metadata @@ -265,6 +285,7 @@ def lazy_init(self): self._init_mp_dtypes() self._register_state_dict_hooks() +<<<<<<< HEAD def set_allocate_memory_from_process_group(self, enable: bool) -> None: """ Whether to (try to) use the ProcessGroup's allocate_tensor method for @@ -295,6 +316,8 @@ def set_allocate_memory_from_process_group(self, enable: bool) -> None: else DefaultReduceScatter() ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Runtime # def unshard(self, async_op: bool = False): if self._all_gather_result is not None: # already called, pending wait @@ -311,6 +334,7 @@ def unshard(self, async_op: bool = False): # used in the all-gather streams self._wait_all_gather_streams_on_event(self._reshard_after_forward_event) self._reshard_after_forward_event = None +<<<<<<< HEAD world_size = self._all_gather_process_group.size() if world_size == 1: @@ -327,6 +351,8 @@ def unshard(self, async_op: bool = False): return +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with record_function(self._with_fqn("FSDP::all_gather")): self._all_gather_result = foreach_all_gather( self.fsdp_params, @@ -334,7 +360,11 @@ def unshard(self, async_op: bool = False): async_op, *self.comm_ctx.get_all_gather_streams(async_op, self._training_state), self.device, +<<<<<<< HEAD self._all_gather_comm, +======= + self.allocate_memory_from_process_group, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def wait_for_unshard(self): @@ -353,6 +383,7 @@ def wait_for_unshard(self): if prev_all_gather_state := self.comm_ctx.all_gather_state: self._wait_all_gather_streams_on_event(prev_all_gather_state.event) self.comm_ctx.all_gather_state = None # free the all-gather result +<<<<<<< HEAD world_size = self._all_gather_process_group.size() if world_size == 1: # directly initialize unsharded parameters from sharded parameters @@ -399,6 +430,20 @@ def wait_for_unshard(self): and self._training_state == TrainingState.FORWARD and world_size > 1 ): +======= + with record_function(self._with_fqn("FSDP::all_gather_copy_out")): + foreach_all_gather_copy_out( + self._all_gather_result, + self.fsdp_params, + self._all_gather_process_group, + ) + for fsdp_param in self.fsdp_params: + fsdp_param.init_unsharded_param() + self._to_unsharded() + all_gather_copy_out_event = self.device_handle.Event() + all_gather_copy_out_event.record() + if not async_op and self._training_state == TrainingState.FORWARD: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Defer free to allow for overlap of this copy-out with next # all-gather collective self.comm_ctx.all_gather_state = AllGatherState( @@ -406,7 +451,10 @@ def wait_for_unshard(self): ) else: self._wait_all_gather_streams_on_event(all_gather_copy_out_event) +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._all_gather_result = None # free unless saved in `all_gather_state` def _wait_all_gather_streams_on_event(self, event: Optional[torch.Event]): @@ -546,7 +594,10 @@ def post_backward(self, *unused: Any): unsharded_grads, self._reduce_scatter_process_group, self.comm_ctx.reduce_scatter_stream, +<<<<<<< HEAD self._reduce_scatter_comm, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._orig_dtype, self._reduce_dtype, self.device, @@ -556,6 +607,10 @@ def post_backward(self, *unused: Any): self.all_reduce_grads, self._partial_reduce_output, self._all_reduce_hook, +<<<<<<< HEAD +======= + self.allocate_memory_from_process_group, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.force_sum_reduction_for_comms, ) self.comm_ctx.reduce_scatter_state = ReduceScatterState( diff --git a/torch/distributed/fsdp/_fully_shard/_fully_shard.py b/torch/distributed/fsdp/_fully_shard/_fully_shard.py index eb348a00f5f98..e3faa61416636 100644 --- a/torch/distributed/fsdp/_fully_shard/_fully_shard.py +++ b/torch/distributed/fsdp/_fully_shard/_fully_shard.py @@ -21,7 +21,11 @@ from torch.distributed._composable import contract from torch.distributed.utils import _get_root_modules +<<<<<<< HEAD from ._fsdp_api import AllGather, MixedPrecisionPolicy, OffloadPolicy, ReduceScatter +======= +from ._fsdp_api import MixedPrecisionPolicy, OffloadPolicy +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from ._fsdp_common import FSDPMeshInfo, HSDPMeshInfo from ._fsdp_init import ( _get_device_from_mesh, @@ -455,6 +459,7 @@ def set_modules_to_backward_prefetch(self, modules: list[FSDPModule]) -> None: module._get_fsdp_state() for module in modules ] +<<<<<<< HEAD def set_custom_all_gather(self, comm: AllGather) -> None: """ Overrides the default ``all_gather`` communication behavior, @@ -481,6 +486,8 @@ def set_custom_reduce_scatter(self, comm: ReduceScatter) -> None: if (fsdp_param_group := state._fsdp_param_group) is not None: fsdp_param_group._reduce_scatter_comm = comm +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def set_all_reduce_hook( self, hook: Callable[[torch.Tensor], None], @@ -584,17 +591,24 @@ def set_allocate_memory_from_process_group_for_comm(self, enable: bool) -> None: using NCCL, this enables it to leverage zero-copy transfers over SHARP (for NVLink and/or InfiniBand). +<<<<<<< HEAD This cannot be used together with :meth:`set_custom_all_gather` or :meth:`set_custom_reduce_scatter` as those APIs allow for finer-grained control over each communication, and this method cannot determine their staging buffer allocation strategy. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Args: enable (bool): Whether to turn on ProcessGroup allocation. """ state = self._get_fsdp_state() if (fsdp_param_group := state._fsdp_param_group) is not None: +<<<<<<< HEAD fsdp_param_group.set_allocate_memory_from_process_group(enable) +======= + fsdp_param_group.allocate_memory_from_process_group = enable +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _set_unshard_async_op(self, async_op: bool): """ diff --git a/torch/distributed/fsdp/_state_dict_utils.py b/torch/distributed/fsdp/_state_dict_utils.py index a81d48ebdba86..f9498244c6717 100644 --- a/torch/distributed/fsdp/_state_dict_utils.py +++ b/torch/distributed/fsdp/_state_dict_utils.py @@ -330,7 +330,11 @@ def param_hook( try: state_dict[fqn] = state_dict[fqn].detach().clone() state_dict[fqn]._has_been_cloned = True # type: ignore[attr-defined] +<<<<<<< HEAD except BaseException as e: # noqa: B036 +======= + except BaseException as e: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) warnings.warn( f"Failed to clone() tensor with name {fqn} on rank {fsdp_state.rank}. " "This may mean that this state_dict entry could point to invalid " diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py index f8d0033eb59bd..99757494ceb26 100644 --- a/torch/distributed/fsdp/fully_sharded_data_parallel.py +++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py @@ -121,6 +121,12 @@ class FullyShardedDataParallel(nn.Module, _FSDPState): well as the ZeRO Stage 3 from `DeepSpeed `_. FullyShardedDataParallel is commonly shortened to FSDP. +<<<<<<< HEAD +======= + To understand FSDP internals, refer to the + :ref:`fsdp_notes`. + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Example:: >>> # xdoctest: +SKIP("undefined variables") diff --git a/torch/distributed/fsdp/sharded_grad_scaler.py b/torch/distributed/fsdp/sharded_grad_scaler.py index 4a8d41c9358a1..5e49075912e73 100644 --- a/torch/distributed/fsdp/sharded_grad_scaler.py +++ b/torch/distributed/fsdp/sharded_grad_scaler.py @@ -320,8 +320,13 @@ def update(self, new_scale: Optional[Union[float, torch.Tensor]] = None) -> None self._scale.fill_(new_scale) # type: ignore[union-attr] else: reason = ( +<<<<<<< HEAD "new_scale should be a float or a 1-element torch.cuda.FloatTensor or " "torch.FloatTensor with requires_grad=False." +======= + "new_scale should be a float or a 1-element torch.cuda.FloatTensor or \ + torch.FloatTensor with requires_grad=False." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) assert new_scale.device.type == self._device, reason assert new_scale.numel() == 1, reason diff --git a/torch/distributed/launcher/api.py b/torch/distributed/launcher/api.py index acf23b27ca2a6..a04a8e3f4705e 100644 --- a/torch/distributed/launcher/api.py +++ b/torch/distributed/launcher/api.py @@ -11,9 +11,13 @@ from dataclasses import dataclass, field from typing import Any, Callable, Optional, Union +<<<<<<< HEAD import torch import torch.distributed.elastic.rendezvous.registry as rdzv_registry from torch._utils_internal import get_default_numa_options +======= +import torch.distributed.elastic.rendezvous.registry as rdzv_registry +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.distributed.elastic import events, metrics from torch.distributed.elastic.agent.server.api import WorkerSpec from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent @@ -26,7 +30,10 @@ from torch.distributed.elastic.rendezvous import RendezvousParameters from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint from torch.distributed.elastic.utils.logging import get_logger +<<<<<<< HEAD from torch.numa.binding import NumaOptions +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __all__ = ["LaunchConfig", "elastic_launch", "launch_agent"] @@ -94,7 +101,10 @@ class LaunchConfig: metrics_cfg: dict[str, str] = field(default_factory=dict) local_addr: Optional[str] = None event_log_handler: str = "null" +<<<<<<< HEAD numa_options: Optional[NumaOptions] = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __post_init__(self): default_timeout = 900 @@ -107,6 +117,7 @@ def __post_init__(self): if self.logs_specs is None: self.logs_specs = DefaultLogsSpecs() +<<<<<<< HEAD if ( self.numa_options is None and torch.cuda.is_available() @@ -116,6 +127,8 @@ def __post_init__(self): self.numa_options = get_default_numa_options() logger.info("Using default numa options = %r", self.numa_options) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class elastic_launch: """ @@ -223,8 +236,12 @@ def launch_agent( " monitor_interval : %(monitor_interval)s\n" " log_dir : %(log_dir)s\n" " metrics_cfg : %(metrics_cfg)s\n" +<<<<<<< HEAD " event_log_handler : %(event_log_handler)s\n" " numa_options : %(numa_options)s\n", +======= + " event_log_handler : %(event_log_handler)s\n", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) { "entrypoint": entrypoint_name, "min_nodes": config.min_nodes, @@ -239,7 +256,10 @@ def launch_agent( "log_dir": config.logs_specs.root_log_dir, # type: ignore[union-attr] "metrics_cfg": config.metrics_cfg, "event_log_handler": config.event_log_handler, +<<<<<<< HEAD "numa_options": config.numa_options, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }, ) @@ -267,7 +287,10 @@ def launch_agent( master_port=master_port, local_addr=config.local_addr, event_log_handler=config.event_log_handler, +<<<<<<< HEAD numa_options=config.numa_options, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) agent = LocalElasticAgent( diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index 3dfb0fe25c4cd..e64a25785f289 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -681,7 +681,11 @@ def _from_traced( ``output_loss_value_spec={'loss': True, 'model_out': False}`` """ +<<<<<<< HEAD traced = exported_program.module(check_guards=False) +======= + traced = exported_program.module() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if split_policy is not None: logger.info("Auto-splitting model") diff --git a/torch/distributed/pipelining/__init__.py b/torch/distributed/pipelining/__init__.py index aacaf0b7f5e4a..5478e604243c3 100644 --- a/torch/distributed/pipelining/__init__.py +++ b/torch/distributed/pipelining/__init__.py @@ -3,7 +3,10 @@ from .schedules import ( _ScheduleForwardOnly, Schedule1F1B, +<<<<<<< HEAD ScheduleDualPipeV, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ScheduleGPipe, ScheduleInterleaved1F1B, ScheduleInterleavedZeroBubble, @@ -26,5 +29,8 @@ "ScheduleLoopedBFS", "ScheduleInterleavedZeroBubble", "ScheduleZBVZeroBubble", +<<<<<<< HEAD "ScheduleDualPipeV", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] diff --git a/torch/distributed/pipelining/_backward.py b/torch/distributed/pipelining/_backward.py index a3529067db793..3d2d4e8a0f060 100644 --- a/torch/distributed/pipelining/_backward.py +++ b/torch/distributed/pipelining/_backward.py @@ -235,6 +235,7 @@ def stage_backward_weight( weight_grads.append(weight.grad) for param_group in param_groups: +<<<<<<< HEAD valid_edges = [] valid_grad_outputs: list[torch.Tensor] = [] @@ -246,6 +247,13 @@ def stage_backward_weight( summed_grad = sum(non_none_grads) valid_edges.append(GradientEdge(intermediate, 0)) valid_grad_outputs.append(summed_grad) +======= + # TODO: Handle case where intermediate can have multiple outputs + intermediate_edges = tuple( + GradientEdge(i, 0) for i in param_group["intermediates"] + ) + weights_edges = tuple(GradientEdge(w, 0) for w in param_group["params"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Break a reference cycle caused inside stage_backward_input->get_hook->hook # The summarized cycle is: @@ -254,6 +262,7 @@ def stage_backward_weight( # We need to keep intermediates alive up until backward_weight, but we can free it now. del param_group["intermediates"] +<<<<<<< HEAD if valid_edges: # Only call autograd.grad if we have valid gradients # [NEW!] Able to pass a GradientEdge to autograd.grad as output weights_edges = tuple(GradientEdge(w, 0) for w in param_group["params"]) @@ -273,6 +282,27 @@ def stage_backward_weight( weight.grad = dw else: weight.grad += dw +======= + assert all(len(g) == 1 for g in param_group["grads"]) + # [NEW!] Able to pass a GradientEdge to autograd.grad as output + # We do not need to retain_graph because... guarantee no overlap? + # print("trying to execute: ", intermediate_edges, weights_edges) + dweights = torch.autograd.grad( + intermediate_edges, + weights_edges, + grad_outputs=sum(param_group["grads"], tuple()), + retain_graph=retain_graph, + ) + # release grad memory early after use + del param_group["grads"] + + for grad_acc, dw in zip(param_group["params"], dweights): + weight, index = grad_acc_to_weight[grad_acc] + if weight.grad is None: + weight.grad = dw + else: + weight.grad += dw +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # return grads in the original order weights were provided in return tuple(weight_grads) diff --git a/torch/distributed/pipelining/_schedule_visualizer.py b/torch/distributed/pipelining/_schedule_visualizer.py index 38ba1241c4e5a..098b5faa6db5c 100644 --- a/torch/distributed/pipelining/_schedule_visualizer.py +++ b/torch/distributed/pipelining/_schedule_visualizer.py @@ -24,7 +24,11 @@ def get_schedule_ops( +<<<<<<< HEAD schedule: Union[str, type[_PipelineSchedule]], +======= + schedule: Union[str, _PipelineSchedule], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pp_degree: int, num_microbatches: int, num_stages_per_rank: Optional[int] = None, @@ -38,7 +42,11 @@ def get_schedule_ops( if isinstance(schedule, str): schedule_class = get_schedule_class(schedule) +<<<<<<< HEAD elif issubclass(schedule, _PipelineSchedule): +======= + elif type(schedule) == _PipelineSchedule: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) schedule_class = schedule else: raise ValueError(f"Invalid schedule: {schedule}") @@ -98,7 +106,10 @@ def __init__( _ComputationType.BACKWARD_INPUT: _ComputationTypeColor("teal", "Backward Input"), _ComputationType.BACKWARD_WEIGHT: _ComputationTypeColor("green", "Backward Weight"), _ComputationType.FULL_BACKWARD: _ComputationTypeColor("orange", "Full Backward", 2), +<<<<<<< HEAD _ComputationType.OVERLAP_F_B: _ComputationTypeColor("purple", "Overlap F+B", 3), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } @@ -137,6 +148,7 @@ def visualize_schedule( used_computation.add(action.computation_type) color = comp_type_color.color width = comp_type_color.width +<<<<<<< HEAD # Check if action has sub_actions to determine styling if action.sub_actions is not None: @@ -146,6 +158,8 @@ def visualize_schedule( linewidth = 1 # Default linewidth for regular actions text_weight = "normal" # Default text weight +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Draw the rectangle to represent the action duration rect = Rectangle( (draw_position, num_ranks - rank_idx - 1), @@ -153,10 +167,15 @@ def visualize_schedule( 1, facecolor=color, edgecolor="black", +<<<<<<< HEAD linewidth=linewidth, ) ax.add_patch(rect) +======= + ) + ax.add_patch(rect) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Draw the text centered within the rectangle ax.text( draw_position + width / 2, @@ -166,9 +185,14 @@ def visualize_schedule( va="center", fontsize=font_size, color="white", +<<<<<<< HEAD weight=text_weight, ) +======= + ) + # Increment the drawing position by the width of the current action +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) draw_position += width else: draw_position += 1 # Move to the next diff --git a/torch/distributed/pipelining/_utils.py b/torch/distributed/pipelining/_utils.py index 2f0472211b8c8..2532fed78f546 100644 --- a/torch/distributed/pipelining/_utils.py +++ b/torch/distributed/pipelining/_utils.py @@ -1,6 +1,9 @@ # mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import logging from dataclasses import dataclass from typing import Union @@ -123,6 +126,7 @@ def generate_stage_to_rank_mapping( return mapping +<<<<<<< HEAD def generate_rank_to_stage_mapping( pp_size: int, num_stages: int, style: str = "loop" ) -> dict[int, list[int]]: @@ -149,6 +153,8 @@ def generate_rank_to_stage_mapping( return rank_to_stages +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclass class PipeInfo: """ diff --git a/torch/distributed/pipelining/schedules.py b/torch/distributed/pipelining/schedules.py index ffc23a654ec45..5e023d96fe74c 100644 --- a/torch/distributed/pipelining/schedules.py +++ b/torch/distributed/pipelining/schedules.py @@ -9,7 +9,10 @@ from abc import ABC, abstractmethod from collections import Counter, defaultdict from enum import Enum +<<<<<<< HEAD from functools import lru_cache +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from typing import Any, Callable, NamedTuple, Optional, Union import torch @@ -19,7 +22,11 @@ from torch.nn.modules.loss import _Loss from torch.profiler import record_function +<<<<<<< HEAD from ._utils import generate_rank_to_stage_mapping, generate_stage_to_rank_mapping +======= +from ._utils import generate_stage_to_rank_mapping +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .microbatch import merge_chunks, split_args_kwargs_into_chunks, TensorChunkSpec from .stage import _PipelineStageBase @@ -34,7 +41,10 @@ "ScheduleLoopedBFS", "ScheduleInterleavedZeroBubble", "ScheduleZBVZeroBubble", +<<<<<<< HEAD "ScheduleDualPipeV", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] logger = logging.getLogger(__name__) @@ -52,7 +62,10 @@ class _ComputationType(Enum): SEND_B = 8 RECV_B = 9 FULL_BACKWARD = 10 +<<<<<<< HEAD OVERLAP_F_B = 11 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __str__(self): str_map = { @@ -66,7 +79,10 @@ def __str__(self): _ComputationType.SEND_B: "SEND_B", _ComputationType.RECV_B: "RECV_B", _ComputationType.FULL_BACKWARD: "B", +<<<<<<< HEAD _ComputationType.OVERLAP_F_B: "OVERLAP_F_B", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } return str_map[self] @@ -92,8 +108,11 @@ def from_str(action): return _ComputationType.RECV_B elif action == "B": return _ComputationType.FULL_BACKWARD +<<<<<<< HEAD elif action == "OVERLAP_F_B": return _ComputationType.OVERLAP_F_B +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: raise RuntimeError(f"Invalid computation type {action}") @@ -108,7 +127,10 @@ def from_str(action): SEND_B = _ComputationType.SEND_B RECV_B = _ComputationType.RECV_B FULL_BACKWARD = _ComputationType.FULL_BACKWARD +<<<<<<< HEAD OVERLAP_F_B = _ComputationType.OVERLAP_F_B +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Convenience shorthand for compute actions only since they are used in 'simple schedule format' F = FORWARD @@ -126,6 +148,7 @@ class _Action(NamedTuple): stage_index: int computation_type: _ComputationType microbatch_index: Optional[int] = None +<<<<<<< HEAD sub_actions: Optional[tuple["_Action", ...]] = None def __str__(self): @@ -152,6 +175,15 @@ def is_compute_op(self) -> bool: BACKWARD_WEIGHT, OVERLAP_F_B, ) +======= + + def __repr__(self): + repr = str(self.stage_index) + repr += str(self.computation_type) + if self.microbatch_index is not None: + repr += str(self.microbatch_index) + return repr +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @staticmethod def from_str(action_string: str): @@ -162,6 +194,7 @@ def from_str(action_string: str): e.g. `2F0`, `1UNSHARD`, `3SEND_F1` """ action_string = action_string.strip() +<<<<<<< HEAD if action_string == "": return None @@ -194,6 +227,8 @@ def from_str(action_string: str): ) # Handle regular single action format +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if match := _action_regex.match(action_string): stage_index, computation_type, microbatch_index = match.groups() return _Action( @@ -208,11 +243,14 @@ def from_str(action_string: str): ) +<<<<<<< HEAD @lru_cache def _get_profiler_function_name(action: _Action) -> str: return f"PP:{str(action)}" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _format_pipeline_order( pipeline_order: dict[int, list[Optional[_Action]]], error_step_number: Optional[int] = None, @@ -311,13 +349,21 @@ def __init__( logger.info("Using %s", self.__class__.__name__) def _maybe_compute_loss(self, stage, output, target_mbs, mb_index): +<<<<<<< HEAD if stage.is_last and self._loss_fn is not None: +======= + if stage.is_last and self._has_backward: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) loss = self._compute_loss(output, target_mbs[mb_index]) # type: ignore[index] self._internal_losses.append(loss) def _maybe_get_loss(self, stage, mb_index): valid_index = 0 <= mb_index < len(self._internal_losses) +<<<<<<< HEAD if stage.is_last and self._loss_fn is not None and valid_index: +======= + if stage.is_last and self._has_backward and valid_index: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self._internal_losses[mb_index] elif len(self._internal_losses) != 0 and not valid_index: raise RuntimeError( @@ -382,6 +428,7 @@ def step(self, *args, target=None, losses: Optional[list] = None, **kwargs): """ raise NotImplementedError +<<<<<<< HEAD def eval(self, *args, target=None, losses: Optional[list] = None, **kwargs): """ Run one iteration of the pipeline schedule with *whole-batch* input. @@ -402,6 +449,8 @@ def eval(self, *args, target=None, losses: Optional[list] = None, **kwargs): # Restore the original state self._has_backward = original_has_backward +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _check_inputs( self, arg_mbs: Optional[list] = None, @@ -558,6 +607,11 @@ def __init__( # Self attributes self._stage = stage self._num_stages = stage.num_stages +<<<<<<< HEAD +======= + # Set the same has_backward flag for stage object + self._stage.has_backward = self._has_backward +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._stage_initialized = False if n_microbatches < self._num_stages: @@ -571,6 +625,7 @@ def __init__( ) def _initialize_stage(self, args, kwargs): +<<<<<<< HEAD # Prepare the communication needed for the pipeline schedule execution # This is needed because during execution we always perform a series of batch P2P ops # The first call of the batched P2P needs to involve the global group @@ -578,6 +633,8 @@ def _initialize_stage(self, args, kwargs): all_ops.extend(self._stage._get_init_p2p_neighbors_ops()) _wait_batch_p2p(_batch_p2p(all_ops)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._stage._prepare_forward_infra(self._n_microbatches, args, kwargs) if self._has_backward: self._stage._prepare_backward_infra(self._n_microbatches) @@ -594,6 +651,7 @@ def step(self, *args, target=None, losses: Optional[list] = None, **kwargs): target: target for the loss function. losses: a list to store the losses for each microbatch. """ +<<<<<<< HEAD if self._has_backward and not torch.is_grad_enabled(): raise RuntimeError( "step() requires gradients to be enabled for backward computation; " @@ -603,6 +661,8 @@ def step(self, *args, target=None, losses: Optional[list] = None, **kwargs): # Set the same has_backward flag for stage object self._stage.has_backward = self._has_backward +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Clean per iteration self._stage.clear_runtime_states() @@ -747,6 +807,13 @@ def _step_microbatches( for work in fwd_sends_to_wait: _wait_batch_p2p(work) +<<<<<<< HEAD +======= + # No loss function, no need to run backward + if not self._has_backward: + return + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Run backward # Delay send waits bwd_sends_to_wait: list[list[dist.Work]] = [] @@ -774,13 +841,22 @@ def _step_microbatches( grad_scale_factor=self._n_microbatches if self.scale_grads else 1 ) +<<<<<<< HEAD +======= + # Return losses if there is a container passed in + self._update_losses(self._stage, losses) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Wait for all backward sends to finish for work in bwd_sends_to_wait: _wait_batch_p2p(work) +<<<<<<< HEAD # Update losses if there is a container passed in self._update_losses(self._stage, losses) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]: """ Returns the pipeline order for GPipe schedule. @@ -1033,7 +1109,11 @@ def _add_unshard_reshard( compute_actions: list[Optional[_Action]], max_active_stages: int = 3, ) -> list[_Action]: +<<<<<<< HEAD """Given a basic schedule involving only compute actions (F,B,W,OVERLAP_F_B), add UNSHARD/RESHARD actions for FSDP. +======= + """Given a basic schedule involving only compute actions (F,B,W), add UNSHARD/RESHARD actions for FSDP. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) UNSHARD refers to fetching the full contents of an FSDP-sharded layer, requiring an all-gather operation. RESHARD does the opposite, releasing memory (but doing no communication) @@ -1053,6 +1133,7 @@ def next_stage_indices( ret: list[int] = [] for a in next_actions: +<<<<<<< HEAD if a is not None: # Handle OVERLAP_F_B actions by checking their sub_actions if a.computation_type == OVERLAP_F_B and a.sub_actions is not None: @@ -1071,6 +1152,13 @@ def next_stage_indices( ret.append(a.stage_index) if len(ret) == count: break +======= + if a is not None and a.stage_index not in seen: + seen.add(a.stage_index) + ret.append(a.stage_index) + if len(ret) == count: + break +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ret active_stages: set[int] = set() @@ -1127,6 +1215,7 @@ def _merge_bw( if action is None: continue +<<<<<<< HEAD # Remove any None actions and find the next non-None action while len(compute_actions) and compute_actions[0] is None: compute_actions.pop(0) @@ -1134,6 +1223,12 @@ def _merge_bw( # Get the next action if it exists next_action = compute_actions[0] if len(compute_actions) > 0 else None +======= + while len(compute_actions) and (next_action := compute_actions[0]) is None: + # remove any None actions between 'action' and 'next_action' + compute_actions.pop(0) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if ( action.computation_type == BACKWARD_INPUT and next_action is not None @@ -1155,9 +1250,12 @@ def _add_send_recv( stage_to_rank: Callable[[int], int], num_stages: int, ) -> dict[int, list[_Action]]: +<<<<<<< HEAD """ Transforms a compute-only schedule into a complete schedule with communication actions. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) comm_actions: dict[int, list[_Action]] = {rank: [] for rank in compute_actions} prev_actions: dict[int, set[_Action]] = {rank: set() for rank in compute_actions} @@ -1226,6 +1324,7 @@ def _ready_to_schedule( else: return True +<<<<<<< HEAD # TODO: For now we are splitting OVERLAP_F_B into replacing it to # its forward and backward components # We need to figure out how to do the communication @@ -1239,6 +1338,8 @@ def _ready_to_schedule( new_actions.append(action) compute_actions[rank] = new_actions +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) while compute_actions: progress = False # go in order of ranks even if dict keys aren't ordered @@ -1295,6 +1396,7 @@ def _validate_schedule( for stage_id in range(num_stages) } stage_index_to_rank_mapping = {} +<<<<<<< HEAD def _process_action(action: _Action, rank: int, step: int): """Process a single action and update stage_actions and stage_index_to_rank_mapping""" @@ -1371,6 +1473,42 @@ def _process_action(action: _Action, rank: int, step: int): else: # Process the main action normally _process_action(action, rank, step) +======= + for rank in actions: + for action in actions[rank]: + if action is None: + continue + assert isinstance(action, _Action), ( + f"Got an invalid action: {action}, expected instance of _Action" + ) + s_id = action.stage_index + ctype = action.computation_type + mb_id = action.microbatch_index + if ctype == F: + stage_actions[s_id][F].add(mb_id) + elif ctype == B: + assert mb_id in stage_actions[s_id][F], ( + f"Running Full Backward for stage {s_id}, microbatch {mb_id} without first running Forward" + ) + stage_actions[s_id][B].add(mb_id) + elif ctype == I: + assert mb_id in stage_actions[s_id][F], ( + f"Running Backward Input for stage {s_id}, microbatch {mb_id} without first running Forward" + ) + stage_actions[s_id][I].add(mb_id) + elif ctype == W: + assert mb_id in stage_actions[s_id][I], ( + f"Running Backward Weight for stage {s_id}, microbatch {mb_id} without first running Backward Input" + ) + stage_actions[s_id][W].add(mb_id) + if s_id not in stage_index_to_rank_mapping: + stage_index_to_rank_mapping[s_id] = rank + else: + existing_rank = stage_index_to_rank_mapping[s_id] + assert rank == existing_rank, ( + f"Stage {s_id} is assigned to both rank {rank} and rank {existing_rank}" + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for s_id in stage_actions: f_mb = len(stage_actions[s_id][F]) @@ -1382,11 +1520,14 @@ def _process_action(action: _Action, rank: int, step: int): f"Got {f_mb} {F} microbatches for stage {s_id}, expected {num_microbatches}" ) +<<<<<<< HEAD assert i_mb == w_mb, ( f"Invalid backward microbatches for stage {s_id}: I and W must have equal counts, \ but got I={i_mb}, W={w_mb}" ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert b_mb + (i_mb + w_mb) // 2 == num_microbatches, ( f"Invalid backward microbatches for stage {s_id}: expected {num_microbatches} total backwards, \ but got B={b_mb}, I={i_mb}, W={w_mb}" @@ -1436,6 +1577,12 @@ def __init__( for stage in self._stages: stage.stage_index_to_group_rank = self.stage_index_to_group_rank +<<<<<<< HEAD +======= + # Set the same has_backward flag for stage object + for stage in self._stages: + stage.has_backward = self._has_backward +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._stages_initialized = False # avoid putting a reference to 'self' inside the lambda, it creates a ref cycle @@ -1452,6 +1599,7 @@ def __init__( ) def _initialize_stages(self, args: tuple[Any, ...], kwargs): +<<<<<<< HEAD # Prepare the communication needed for the pipeline schedule execution # This is needed because during execution we always perform a series of batch P2P ops # The first call of the batched P2P needs to involve the global group @@ -1460,6 +1608,8 @@ def _initialize_stages(self, args: tuple[Any, ...], kwargs): all_ops.extend(stage._get_init_p2p_neighbors_ops()) _wait_batch_p2p(_batch_p2p(all_ops)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # may be 'none' value (if this stage sends its output shapes to the next stage via P2P) # or real value (if this stage and next stage are on the same device) next_stage_args: tuple[Any, ...] = tuple() @@ -1526,6 +1676,7 @@ def step(self, *args, target=None, losses: Optional[list] = None, **kwargs): target: target for the loss function. losses: a list to store the losses for each microbatch. """ +<<<<<<< HEAD if self._has_backward and not torch.is_grad_enabled(): raise RuntimeError( "step() requires gradients to be enabled for backward computation; " @@ -1537,6 +1688,8 @@ def step(self, *args, target=None, losses: Optional[list] = None, **kwargs): for stage in self._stages: stage.has_backward = self._has_backward +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Clean per iteration for stage in self._stages: stage.clear_runtime_states() @@ -1730,11 +1883,18 @@ def _step_microbatches( _wait_batch_p2p(_batch_p2p(ops)) except Exception as e: logger.error( +<<<<<<< HEAD "[Rank %s] pipeline schedule %s caught the following exception '%s' \ at time_step %s when running action %s", self.rank, self.__class__.__name__, str(e), +======= + "[Rank %s] pipeline schedule %s caught the following exception \ + at time_step %s when running action %s", + self.rank, + self.__class__.__name__, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) time_step, action, ) @@ -1757,7 +1917,11 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti): subclassed and the subclass can be responsible for creating a schedule IR. """ +<<<<<<< HEAD def _prepare_schedule_with_comms( +======= + def _load_actions( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self, actions: dict[int, list[Optional[_Action]]], format: str = "compute_only", @@ -1778,6 +1942,7 @@ def _prepare_schedule_with_comms( self.pipeline_order_with_comms[rank].append(action) # TODO what level of validation should we offer for compute+comms schedule? elif format == "compute_only": +<<<<<<< HEAD # Validate that the schedule does not have comms already added to it for rank, action_list in actions.items(): for i, action in enumerate(action_list): @@ -1789,6 +1954,8 @@ def _prepare_schedule_with_comms( f"should not be present when format='compute_only'." ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Perform schedule lowering for rank in actions: self.pipeline_order_with_comms[rank] = _add_unshard_reshard( @@ -1813,13 +1980,18 @@ def _load_csv(self, filename: str, format: str = "compute_only"): # this will populate self.pipeline_order super()._load_csv(filename) # this will populate self.pipeline_order_with_comms +<<<<<<< HEAD self._prepare_schedule_with_comms(self.pipeline_order) +======= + self._load_actions(self.pipeline_order) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif format == "compute_comms": actions = {} with open(filename, newline="") as csvfile: reader = csv.reader(csvfile) for rank, row in enumerate(reader): actions[rank] = [_Action.from_str(s) for s in row] +<<<<<<< HEAD self._prepare_schedule_with_comms(actions, format=format) else: raise NotImplementedError(f"{format=} is not implemented") @@ -1842,6 +2014,23 @@ def _dump_csv(self, filename: str, format: str = "compute_comms"): writer = csv.writer(csvfile) for rank in self.pipeline_order_with_comms: writer.writerow(self.pipeline_order_with_comms[rank]) +======= + self._load_actions(actions, format=format) + else: + raise NotImplementedError(f"{format=} is not implemented") + + def _dump_csv(self, filename: str): + """Dump a CSV representation of the compute + comms schedule into a file with the provided filename.""" + # TODO should there be an option to dump the compute_only schedule from PipelineScheduleRuntime? It's possible + # that it does not exist if it was created from a compute_comms schedule. + assert self.pipeline_order_with_comms is not None, ( + "Must initialize compute_comms schedule before dump_csv" + ) + with open(filename, "w", newline="") as csvfile: + writer = csv.writer(csvfile) + for rank in self.pipeline_order_with_comms: + writer.writerow(self.pipeline_order_with_comms[rank]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _simulate(self): return _simulate_comms_compute( @@ -1874,7 +2063,11 @@ def _step_microbatches( } assert self.pipeline_order_with_comms is not None, ( +<<<<<<< HEAD "Must call _prepare_schedule_with_comms() before calling _step_microbatches()" +======= + "Must call _load_actions() before calling _step_microbatches()" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # recv ops indexed by (stage_idx, mb_idx) need to be waited on before use @@ -1925,6 +2118,7 @@ def _assert_unsharded(stage_idx: int): action, ) +<<<<<<< HEAD with record_function(_get_profiler_function_name(action)): # TODO(whc) it's not actually safe to use _batch_p2p here in the uncommon case the model has skip-connections, # since we do not want to batch up ops between more than a pair of ranks. _sorted_batch_p2p would be @@ -2074,6 +2268,150 @@ def _assert_unsharded(stage_idx: int): ) else: raise ValueError(f"{action=} is unknown or unsupported") +======= + # TODO(whc) it's not actually safe to use _batch_p2p here in the uncommon case the model has skip-connections, + # since we do not want to batch up ops between more than a pair of ranks. _sorted_batch_p2p would be + # safe to use instead. + # However, I was wondering if I should avoid calling batched operators at all in the case that there is + # only one operator per batch. I could iterate through the 'fwd_send_ops' one by one and run them. + if comp_type == SEND_F: + send_ops.append(_batch_p2p(stage.get_fwd_send_ops(mb_index))) + elif comp_type == SEND_B: + send_ops.append(_batch_p2p(stage.get_bwd_send_ops(mb_index))) + elif comp_type == RECV_F: + assert ( + stage_idx, + mb_index, + ) not in fwd_recv_ops, ( + "Recv twice for {stage_idx=} {mb_index=} without executing forward" + ) + fwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p( + stage.get_fwd_recv_ops(mb_index) + ) + elif comp_type == RECV_B: + assert ( + stage_idx, + mb_index, + ) not in bwd_recv_ops, ( + "Recv twice for {stage_idx=} {mb_index=} without executing backward" + ) + bwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p( + stage.get_bwd_recv_ops(mb_index) + ) + elif comp_type == UNSHARD: + if stage_uses_fsdp: + assert ( + stage_idx not in unsharded_stages + and stage_idx not in unshard_ops + ), f"Unsharding the same {stage_idx=} twice" + unshard_ops[stage_idx] = stage.submod.unshard(async_op=True) # type: ignore[operator] + elif comp_type == RESHARD: + if stage_uses_fsdp: + assert stage_idx in unsharded_stages, ( + f"Resharding {stage_idx=} without unsharding" + ) + assert stage_idx not in unshard_ops, ( + f"Resharding {stage_idx=} before finishing unshard" + ) + stage.submod.reshard() # type: ignore[operator] + elif comp_type == FORWARD: + if stage_uses_fsdp: + _assert_unsharded(stage_idx) + + if ( + not stage.is_first + # no recv op expected for V-schedule special case (see [Note: V-schedule special case]) + and not is_prev_stage_on_this_rank + ): + assert ( + stage_idx, + mb_index, + ) in fwd_recv_ops, f"Computing {action=} before receiving input" + _wait_batch_p2p(fwd_recv_ops.pop((stage_idx, mb_index))) + + output = stage.forward_one_chunk( + mb_index, arg_mbs[mb_index], kwarg_mbs[mb_index] + ) + self._maybe_compute_loss(stage, output, target_mbs, mb_index) + + # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank + # see [Note: V-schedule special case] + if is_next_stage_on_this_rank: + stage_index_to_stage[stage_idx + 1].set_local_fwd_input( + output, mb_index + ) + + elif comp_type == FULL_BACKWARD: + if stage_uses_fsdp: + _assert_unsharded(stage_idx) + + if ( + not stage.is_last + # no recv op expected for V-schedule special case (see [Note: V-schedule special case]) + and not is_next_stage_on_this_rank + ): + assert ( + stage_idx, + mb_index, + ) in bwd_recv_ops, ( + f"Attempted to run compute {action=} before receiving input" + ) + _wait_batch_p2p(bwd_recv_ops.pop((stage_idx, mb_index))) + loss = self._maybe_get_loss(stage, mb_index) + backward_counter[stage_idx] += 1 + last_backward = backward_counter[stage_idx] == self._n_microbatches + grad_scale_factor = self._n_microbatches if self.scale_grads else 1 + stage.backward_one_chunk( + mb_index, + loss=loss, + full_backward=True, + last_backward=last_backward, + ) + if last_backward: + stage.scale_grads(grad_scale_factor) + # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank + # see [Note: V-schedule special case] + if is_prev_stage_on_this_rank: + stage_index_to_stage[stage_idx - 1].set_local_bwd_input( + stage.get_local_bwd_output(mb_index), mb_index + ) + elif comp_type == BACKWARD_INPUT: + if stage_uses_fsdp: + _assert_unsharded(stage_idx) + + if not stage.is_last and not is_next_stage_on_this_rank: + assert ( + stage_idx, + mb_index, + ) in bwd_recv_ops, ( + f"Attempted to run compute {action=} before receiving input" + ) + _wait_batch_p2p(bwd_recv_ops.pop((stage_idx, mb_index))) + loss = self._maybe_get_loss(stage, mb_index) + stage.backward_one_chunk( + mb_index, + loss=loss, + full_backward=False, + last_backward=False, + ) + # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank + # see [Note: V-schedule special case] + if is_prev_stage_on_this_rank: + stage_index_to_stage[stage_idx - 1].set_local_bwd_input( + stage.get_local_bwd_output(mb_index), mb_index + ) + elif comp_type == BACKWARD_WEIGHT: + if stage_uses_fsdp: + _assert_unsharded(stage_idx) + backward_counter[stage_idx] += 1 + stage.backward_weight_one_chunk( + mb_index, + last_backward=backward_counter[stage_idx] + == self._n_microbatches, + ) + else: + raise ValueError(f"{action=} is unknown or unsupported") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) except Exception as e: logger.error( "_PipelineScheduleRuntime caught exception at step %s when running action %s. Full Schedule:", @@ -2585,7 +2923,11 @@ def need_bubble(stage, op, microbatch, num_stages_global, seen_ops): if actions[rank][timestamp] is not None: temp_action = actions[rank][timestamp] assert temp_action is not None +<<<<<<< HEAD stage_index, op, microbatch, _ = temp_action +======= + stage_index, op, microbatch = temp_action +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not need_bubble( stage_index, op, microbatch, num_stages_global, seen_ops ): @@ -2792,6 +3134,7 @@ def _calculate_single_rank_operations(self, rank) -> list[Optional[_Action]]: return rank_ops +<<<<<<< HEAD class ScheduleDualPipeV(_PipelineScheduleRuntime): """ The DualPipeV schedule. A more efficient schedule variant based on the @@ -3013,6 +3356,8 @@ def add_weight_action_if_pending(actions: list): return actions +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_schedule_class(schedule_name: str): """ Maps a schedule name (case insensitive) to its corresponding class object. @@ -3029,7 +3374,10 @@ def get_schedule_class(schedule_name: str): "PipelineScheduleSingle": PipelineScheduleSingle, "PipelineScheduleMulti": PipelineScheduleMulti, "ZBVZeroBubble": ScheduleZBVZeroBubble, +<<<<<<< HEAD "DualPipeV": ScheduleDualPipeV, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } lowercase_keys = {k.lower(): k for k in schedule_map.keys()} lowercase_schedule_name = schedule_name.lower() diff --git a/torch/distributed/pipelining/stage.py b/torch/distributed/pipelining/stage.py index 6615ced0398e5..4e352e69630dc 100644 --- a/torch/distributed/pipelining/stage.py +++ b/torch/distributed/pipelining/stage.py @@ -462,10 +462,18 @@ def get_bwd_send_ops(self, bwd_chunk_id: int) -> list[dist.P2POp]: """ Get the gradient send ops for current stage's backward. """ +<<<<<<< HEAD if not self.has_backward or self.is_first: return [] self._check_chunk_id(bwd_chunk_id) +======= + self._check_chunk_id(bwd_chunk_id) + + if not self.has_backward or self.is_first: + return [] + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Create bwd send infra lazily if self.grad_send_info is None: # Send info for input grads during backward: @@ -760,10 +768,13 @@ def backward_one_chunk( last_backward is controlled by the schedule and signals synchronization of gradients across DP groups after the last backward. """ +<<<<<<< HEAD # skip backward computation if backward is not enabled if not self.has_backward: return +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._check_chunk_id(bwd_chunk_id) ( @@ -851,10 +862,13 @@ def backward_one_chunk( logger.debug("%s Backwarded chunk %s", self.log_prefix, bwd_chunk_id) def backward_weight_one_chunk(self, bwd_chunk_id: int, last_backward=False): +<<<<<<< HEAD # skip backward computation if backward is not enabled if not self.has_backward: return +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert bwd_chunk_id in self.dw_runner, ( f"{self.log_prefix} Attempted to run backward_weight_one_chunk for chunk {bwd_chunk_id}" " without first calling `backward_one_chunk(full_backward=False)`" @@ -935,6 +949,7 @@ def _validate_fwd_outputs(self, outputs: tuple[torch.Tensor, ...]): f"Stage {self.stage_index} forward outputs", expected_tensors_meta, outputs ) +<<<<<<< HEAD def _get_init_p2p_neighbors_ops(self) -> list[dist.P2POp]: """ Get the operations to initialize the p2p communicators between previous and next stages. @@ -989,6 +1004,8 @@ def _get_init_p2p_neighbors_ops(self) -> list[dist.P2POp]: return ops +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class _PipelineStage(_PipelineStageBase): def __init__( @@ -1424,7 +1441,10 @@ def _shape_inference( ), group=self.group, device=self.device, +<<<<<<< HEAD use_batch=True, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) recv_args = objects[0] assert isinstance(recv_args, tuple), type(recv_args) @@ -1490,7 +1510,10 @@ def _shape_inference( ), group=self.group, device=self.device, +<<<<<<< HEAD use_batch=True, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) outputs_meta = tuple() diff --git a/torch/distributed/rpc/internal.py b/torch/distributed/rpc/internal.py index c830fc11d8edd..dec73c80b66a1 100644 --- a/torch/distributed/rpc/internal.py +++ b/torch/distributed/rpc/internal.py @@ -226,7 +226,11 @@ def _handle_exception(result): exc = None try: exc = result.exception_type(exception_msg) +<<<<<<< HEAD except BaseException as e: # noqa: B036 +======= + except BaseException as e: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise RuntimeError( # noqa: B904 f"Failed to create original exception type. Error msg was {str(e)}" f" Original exception on remote side was {exception_msg}" diff --git a/torch/distributed/rpc/rref_proxy.py b/torch/distributed/rpc/rref_proxy.py index 71c111b2f2e65..d06502d692375 100644 --- a/torch/distributed/rpc/rref_proxy.py +++ b/torch/distributed/rpc/rref_proxy.py @@ -53,13 +53,21 @@ def _rref_type_cont(rref_fut): def _wrap_rref_type_cont(fut): try: _rref_type_cont(fut).then(_complete_op) +<<<<<<< HEAD except BaseException as ex: # noqa: B036 +======= + except BaseException as ex: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) result.set_exception(ex) def _complete_op(fut): try: result.set_result(fut.value()) +<<<<<<< HEAD except BaseException as ex: # noqa: B036 +======= + except BaseException as ex: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) result.set_exception(ex) rref_fut.then(_wrap_rref_type_cont) diff --git a/torch/distributed/run.py b/torch/distributed/run.py index 2738191f0e379..121d54ba141eb 100644 --- a/torch/distributed/run.py +++ b/torch/distributed/run.py @@ -382,10 +382,13 @@ def main(): from torch.distributed.elastic.utils import macros from torch.distributed.elastic.utils.logging import get_logger from torch.distributed.launcher.api import elastic_launch, LaunchConfig +<<<<<<< HEAD from torch.numa.binding import ( AffinityMode as _AffinityMode, # Signify as private with _ NumaOptions as _NumaOptions, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.utils.backend_registration import _get_custom_mod_func @@ -620,6 +623,7 @@ def get_args_parser() -> ArgumentParser: "Can be used to override custom logging behavior.", ) +<<<<<<< HEAD parser.add_argument( "--numa-binding", "--numa_binding", @@ -645,6 +649,8 @@ def get_args_parser() -> ArgumentParser: featuring a single L3 cache per socket.""", ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # # Positional arguments. # @@ -838,11 +844,14 @@ def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str tee=Std.from_str(args.tee), local_ranks_filter=ranks, ) +<<<<<<< HEAD numa_options = ( None if args.numa_binding is None else _NumaOptions(affinity_mode=_AffinityMode(args.numa_binding)) ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) config = LaunchConfig( min_nodes=min_nodes, @@ -860,7 +869,10 @@ def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str local_addr=args.local_addr, logs_specs=logs_specs, event_log_handler=args.event_log_handler, +<<<<<<< HEAD numa_options=numa_options, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) with_python = not args.no_python diff --git a/torch/distributed/tensor/_api.py b/torch/distributed/tensor/_api.py index 7eeafaa8eaf9d..643c84a4da689 100644 --- a/torch/distributed/tensor/_api.py +++ b/torch/distributed/tensor/_api.py @@ -269,6 +269,7 @@ def __new__( # new method instruct wrapper tensor from local_tensor and add # placement spec, it does not do actual distribution assert spec.tensor_meta is not None, "TensorMeta should not be None!" +<<<<<<< HEAD r = torch.Tensor._make_dtensor( cls, @@ -276,6 +277,16 @@ def __new__( spec.tensor_meta.stride, local_tensor, requires_grad, +======= + r = torch.Tensor._make_wrapper_subclass( + cls, + spec.tensor_meta.shape, + strides=spec.tensor_meta.stride, + dtype=local_tensor.dtype, + device=local_tensor.device, + layout=local_tensor.layout, + requires_grad=requires_grad, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) r._spec = spec @@ -562,7 +573,11 @@ def full_tensor( """ Return the full tensor of this DTensor. It will perform necessary collectives to gather the local tensors from other ranks in its DeviceMesh and concatenate +<<<<<<< HEAD them together. It's a syntactic sugar of the following code: +======= + them together. It's a syntatic sugar of the following code: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ``dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()`` @@ -1002,7 +1017,11 @@ def _dtensor_init_helper( # type: ignore[no-untyped-def] # set default placements to replicated if not specified placements = placements or tuple(Replicate() for _ in range(device_mesh.ndim)) +<<<<<<< HEAD # check device_mesh against placements +======= + # check device_mesh againts placements +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert device_mesh.ndim == len(placements), ( "mesh dimension does not match the length of placements" ) diff --git a/torch/distributed/tensor/_collective_utils.py b/torch/distributed/tensor/_collective_utils.py index 4fce6fea538a6..28508a2469fe3 100644 --- a/torch/distributed/tensor/_collective_utils.py +++ b/torch/distributed/tensor/_collective_utils.py @@ -25,6 +25,7 @@ logger = logging.getLogger(__name__) +<<<<<<< HEAD @torch.library.register_fake("_dtensor::shard_dim_alltoall") def _shard_dim_alltoall_meta(input, gather_dim, shard_dim, group_name): group_size = _get_group_size_by_name(group_name) @@ -36,6 +37,28 @@ def _shard_dim_alltoall_meta(input, gather_dim, shard_dim, group_name): torch.cat(stacked_list, dim=gather_dim) .chunk(group_size, dim=shard_dim)[group_rank] .contiguous() +======= +if not torch._running_with_deploy(): + + @torch.library.register_fake("_dtensor::shard_dim_alltoall") + def _shard_dim_alltoall_meta(input, gather_dim, shard_dim, group_name): + group_size = _get_group_size_by_name(group_name) + stacked_list = [torch.empty_like(input) for _ in range(group_size)] + group = _resolve_process_group(group_name) + group_rank = get_group_rank(group, get_rank()) + + return ( + torch.cat(stacked_list, dim=gather_dim) + .chunk(group_size, dim=shard_dim)[group_rank] + .contiguous() + ) + +else: + import warnings + + warnings.warn( + "PyTorch Distributed functional collectives do not work with torch::deploy." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @@ -316,7 +339,11 @@ def redistribute_cost( NOTE: 1. Only consider communication cost here, since computation costs for redistribute +<<<<<<< HEAD are quite trivial (i.e. we only need to narrow or simple division) +======= + are quite trival (i.e. we only need to narrow or simple division) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) 2. Only consider redistribute cost on same mesh, cross mesh communication cost is not quite needed for operator strategy estimation/selection. """ diff --git a/torch/distributed/tensor/_dispatch.py b/torch/distributed/tensor/_dispatch.py index 7ac7801b50bca..8987c1be9dfb3 100644 --- a/torch/distributed/tensor/_dispatch.py +++ b/torch/distributed/tensor/_dispatch.py @@ -23,7 +23,10 @@ ) from torch.distributed.tensor._utils import try_find_mesh_from_args from torch.distributed.tensor.placement_types import Partial, Placement, Replicate +<<<<<<< HEAD from torch.utils._python_dispatch import return_and_correct_aliasing +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: @@ -121,6 +124,7 @@ def __init__(self) -> None: aten._amp_foreach_non_finite_check_and_unscale_.default: found_inf_reduce_handler, } +<<<<<<< HEAD # This flag is used internally to control whether we treat the torch.Tensor(non-DTensor) # as implicitly replicated or we throw error to user. # NOTE: It is EXTREMELY UNSAFE to turn this flag on by default so we intentionally leave @@ -132,6 +136,13 @@ def _allow_implicit_replication(self) -> bool: @_allow_implicit_replication.setter def _allow_implicit_replication(self, value: bool) -> None: return torch._C._set_dtensor_allow_implicit_replication(value) +======= + # This flag is used internally to control whether we treat the torch.Tensor(non-DTensor) + # as implicitly replicated or we throw error to user. + # NOTE: It is EXTREMELY UNSAFE to turn this flag on by default so we intentionally leave + # it as False by default. + self._allow_implicit_replication = False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def dispatch( self, @@ -140,11 +151,25 @@ def dispatch( kwargs: dict[str, object], ) -> object: """ +<<<<<<< HEAD Main dispatching logic. Follows precedence order: (1) custom_op_handler (2) registered sharding strategy, then rule (3) composite implicit autograd decomposition """ +======= + Main dispatching logic + """ + # operators that does not need to go through sharding propagation + if torch._C._dispatch_has_kernel_for_dispatch_key( + op_call.name(), torch._C.DispatchKey.CompositeImplicitAutograd + ): + # When running under inference mode, CompositeImplicitAutograd ops show up in __torch_dispatch__, + # so we manually decompose them, here + out = op_call.decompose(*args, **kwargs) + assert out is not NotImplemented + return out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if op_call in self._custom_op_handlers: return self._custom_op_handlers[op_call](op_call, args, kwargs) # type: ignore[operator] @@ -152,6 +177,7 @@ def dispatch( op_info = self.unwrap_to_op_info(op_call, args, kwargs) logger.debug("Dispatching op_call: %s", op_info.schema) +<<<<<<< HEAD try: self.sharding_propagator.propagate(op_info) except NotImplementedError: @@ -170,22 +196,33 @@ def dispatch( f"Sharding propagation failed for {op_info.schema}" ) from e +======= + self.sharding_propagator.propagate(op_info) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) output_sharding = op_info.output_sharding logger.debug("output_sharding for %s: %s", op_call, output_sharding) assert output_sharding is not None, "output sharding should not be None" mesh = op_info.compute_mesh +<<<<<<< HEAD participating = mesh.get_coordinate() is not None if participating: +======= + if mesh.get_coordinate() is not None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # computation that happens in the current rank of the mesh, normal case if output_sharding.needs_redistribute: # If sharding propagation decision needs redistribute, perform redistribute # on args first, which could potentially modify args (i.e. allgather certain arg) assert output_sharding.redistribute_schema is not None self.redistribute_local_args( +<<<<<<< HEAD op_info, output_sharding.redistribute_schema, output_sharding.use_val_from_redistribute_schema, +======= + op_info, output_sharding.redistribute_schema +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) local_tensor_args = ( @@ -208,6 +245,7 @@ def dispatch( cast(dtensor.DTensor, args[0]), cast(torch.Tensor, local_tensor_args[0]), ) +<<<<<<< HEAD # If the user provided a generator, we hook it up to our RNG manager, but we also pop it from kwargs # so the op_call does not directly use it (we want op_call to fall back to the 'default' which is @@ -221,6 +259,10 @@ def dispatch( random._rng_tracker._distribute_region( first_arg._spec, generator=maybe_user_generator ) +======= + rng_context = ( + random._rng_tracker._distribute_region(first_arg._spec) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if random._rng_tracker and not first_local_arg.is_meta else contextlib.nullcontext() ) @@ -289,6 +331,7 @@ def default_tensor(spec: DTensorSpec) -> torch.Tensor: if op_info.schema.is_inplace_op(): # inplace op should return self instead of re-wrapping if output_sharding.output_spec is not None: +<<<<<<< HEAD # NOTE: aten.squeeze_.dim is an inplace op but it also may change # the inplace argument's tensor meta. Here we choose to special case # this op because as far as I know this is the only inplace op that @@ -303,6 +346,9 @@ def default_tensor(spec: DTensorSpec) -> torch.Tensor: return return_and_correct_aliasing(op_call, args, kwargs, args[0]) else: return args[0] +======= + return args[0] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: return None elif op_info.schema.is_out_variant_op(): @@ -324,17 +370,24 @@ def default_tensor(spec: DTensorSpec) -> torch.Tensor: assert len(out_dts) >= 1, "out variant should have at least one out arg" return tuple(out_dts) if len(out_dts) > 1 else out_dts[0] else: +<<<<<<< HEAD ret = self.wrap(local_results, output_sharding.output_spec) # type: ignore[possibly-undefined] if participating and op_info.schema.is_view_op(): return return_and_correct_aliasing(op_call, args, kwargs, ret) else: return ret +======= + return self.wrap(local_results, output_sharding.output_spec) # type: ignore[possibly-undefined] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @staticmethod def redistribute_local_args( op_info: OpInfo, suggested_input_schema: OpSchema, +<<<<<<< HEAD use_val_from_redistribute_schema: bool, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: # NOTE: it's very rare that we need to reshard kwargs so we intentionally skip it if op_info.args_tree_spec is not None: @@ -357,12 +410,16 @@ def redistribute_local_args( else: new_local_args.append(local_tensor) else: +<<<<<<< HEAD if use_val_from_redistribute_schema: # args can be updated for view related ops, we refer to the # update in redistribute_schema. new_local_args.append(reshard_arg_spec) else: new_local_args.append(arg_spec) +======= + new_local_args.append(reshard_arg_spec) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) op_info.local_args = tuple(new_local_args) @@ -489,7 +546,11 @@ def _try_replicate_spec_for_scalar_tensor( "Found a non-scalar tensor with numel=1 and ndim!=0, " "we are implicitly creating a replicated DTensor for it. " "However, please consider changing it to a scalar tensor " +<<<<<<< HEAD "or explicitly create a DTensor under distributed environment." +======= + "or explicitly create a DTensor under distributed enviroment." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if tensor_arg.numel() == 1 or self._allow_implicit_replication: diff --git a/torch/distributed/tensor/_dtensor_spec.py b/torch/distributed/tensor/_dtensor_spec.py index bffb399b2bca8..108a00b4e3c9b 100644 --- a/torch/distributed/tensor/_dtensor_spec.py +++ b/torch/distributed/tensor/_dtensor_spec.py @@ -40,6 +40,7 @@ def __setattr__(self, attr: str, value: Any) -> None: # change (though we do not expect `mesh` or `placements` to change) if hasattr(self, "_hash") and attr in ("mesh", "placements", "tensor_meta"): self._hash = None +<<<<<<< HEAD # This assert was triggered by buggy handling for dict outputs in some # FX passes, where you accidentally iterate over a dict and try to put # keys into TensorMeta. See https://github.com/pytorch/pytorch/issues/157919 @@ -50,6 +51,8 @@ def __setattr__(self, attr: str, value: Any) -> None: # test/distributed/tensor/experimental/test_tp_transform.py::TensorParallelTest::test_tp_transform_e2e # but I actually can't reproduce it, maybe it is also a bug! assert isinstance(value, (TensorMeta, TensorMetadata)), value +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _hash_impl(self) -> int: # hashing and equality check for DTensorSpec are used to cache the sharding @@ -250,7 +253,11 @@ def from_dim_map( if placement.is_shard(): placement = cast(Shard, placement) raise RuntimeError( +<<<<<<< HEAD f"DeviceMesh dimension can't be mapped to two dimension of the same tensor: {i} and {placement.dim}" +======= + f"DeviceMesh dimension cann't be mapped to two dimension of the same tensor: {i} and {placement.dim}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) elif placement.is_partial(): raise RuntimeError( diff --git a/torch/distributed/tensor/_op_schema.py b/torch/distributed/tensor/_op_schema.py index 6f8c644095eec..4c71a78b5b4d4 100644 --- a/torch/distributed/tensor/_op_schema.py +++ b/torch/distributed/tensor/_op_schema.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +<<<<<<< HEAD """ DTensor operator schema definitions and utilities. @@ -23,11 +24,16 @@ 4. Cache sharding decisions for performance optimization """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from collections.abc import Sequence from dataclasses import dataclass from functools import cached_property from typing import Any, Optional, Union +<<<<<<< HEAD from typing_extensions import deprecated +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch from torch._ops import OpOverload @@ -37,6 +43,7 @@ try: +<<<<<<< HEAD from torch.utils._cxx_pytree import ( register_pytree_node, tree_leaves, @@ -46,6 +53,11 @@ except ImportError: from torch.utils._pytree import ( # type: ignore[no-redef, assignment] register_pytree_node, +======= + from torch.utils._cxx_pytree import tree_leaves, tree_map_only, TreeSpec +except ImportError: + from torch.utils._pytree import ( # type: ignore[no-redef, assignment] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tree_leaves, tree_map_only, TreeSpec, @@ -58,7 +70,11 @@ PlacementList = list[Optional[Placement]] +<<<<<<< HEAD # ATen op schemas could have Tensor, Tuple[Tensor] and List[Tensor], so output type should +======= +# ATen op schemas could have Tensor, Tuple[Tensor] and List[Tensor], so output type sould +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # be the same set of possibilities. OutputSpecType = Optional[Union[DTensorSpec, Sequence[Optional[DTensorSpec]]]] @@ -95,6 +111,7 @@ class OpSpec: note: when the op return value is a single DTensor object, output_specs is DTensorSpec; when the return value is a tuple of Optional[DTensor], output_specs is a tuple of Optional[DTensorSpec]. +<<<<<<< HEAD note: we MUST produce an DTensorSpec for every output that is a Tensor. None entries only occur for non-Tensor outputs (e.g., operators that return Optional[Tensor], @@ -134,6 +151,16 @@ class OpSpec: K, # cost of redistributing tensor_a from 'Shard(0)' ], """ +======= + """ + + output_specs: Union[DTensorSpec, tuple[Optional[DTensorSpec], ...]] + input_specs: Optional[Sequence[DTensorSpec]] = None + + # redistribute costs to redistribute the operator input shardings to this OpSpec. + # Note that We need a nested list to record the cost for each operand of this + # operator, and for each operand of this operator it might have multiple OpSpecs. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) redistribute_cost: Optional[list[list[float]]] = None @cached_property @@ -190,8 +217,11 @@ class OpStrategy(StrategyType): """ OpStrategy that consists of a list of sharding strategies associated with the op, where each strategy is an OpSpec that describes the acceptable input/output sharding. +<<<<<<< HEAD invariant: the DeviceMesh on all OpSpec must be the same +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ def __init__(self, strategies: list[OpSpec]) -> None: @@ -228,6 +258,7 @@ def shape(self): class TupleStrategy(StrategyType): """ +<<<<<<< HEAD TupleStrategy is a special case for operators that are fundamentally compound or batched such that some subset of the inputs and outputs are completely unrelated to some other subset. @@ -262,16 +293,39 @@ def childs(self) -> Sequence[StrategyType]: # codespell:ignore childs def child_mesh(self, index: int) -> DeviceMesh: op_strategy = self.children[index] +======= + TupleStrategy represents the output strategy of this op is a tuple of OpStrategies, + i.e. If the output of this op is a tuple of tensors or list of tensors with possibly + different OpStrategies, we should return a TupleStrategy that contains a tuple of + OpStrategy, where each child represents the sharding strategy of "each element" of + the tuple/list of tensors the op returns. + + NOTE: if the output of the op is a List[Tensor] and they share the same OpStrategy, + then we should return a single OpStrategy instead of a TupleStrategy + """ + + def __init__(self, childs: Sequence[StrategyType]) -> None: + super().__init__() + self.childs: Sequence[StrategyType] = childs + + def child_mesh(self, index: int) -> DeviceMesh: + op_strategy = self.childs[index] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(op_strategy, OpStrategy) return op_strategy.mesh def __str__(self) -> str: child_strategies_str = ", ".join( +<<<<<<< HEAD [f"{str(strat)}" for idx, strat in enumerate(self.children)] +======= + [f"{str(strat)}" for idx, strat in enumerate(self.childs)] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return f"TupleStrategy({child_strategies_str})" +<<<<<<< HEAD try: register_pytree_node( TupleStrategy, @@ -283,6 +337,8 @@ def __str__(self) -> str: pass +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclass class RuntimeSchemaInfo: """ @@ -312,7 +368,11 @@ class OpSchema: order preserved). It is mainly used by the DTensor's dispatching logic to perform various actions (i.e. sharding propagation, caching sharding decisions, redistribute, etc.) +<<<<<<< HEAD NOTE: this must be used as a read only data class +======= + NOTE: this should be used as a read only data class +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TODO: make this a frozen dataclass Args: @@ -329,8 +389,11 @@ class OpSchema: schema_info: Optional[RuntimeSchemaInfo] = None +<<<<<<< HEAD _comparison_key: Optional[tuple[object, ...]] = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @property def args_spec(self) -> tuple[DTensorSpec, ...]: """ @@ -377,7 +440,11 @@ def __str__(self) -> str: args_schema.append(_pretty_print_spec(arg.strategies[0].output_specs)) mesh_shape = arg.mesh_shape elif isinstance(arg, TupleStrategy): +<<<<<<< HEAD first_op_strategy = arg.children[0] +======= + first_op_strategy = arg.childs[0] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(first_op_strategy, OpStrategy) mesh_shape = first_op_strategy.mesh_shape args_schema.append(str(arg)) @@ -393,9 +460,15 @@ def __post_init__(self) -> None: has_symints = True break self.has_symints = has_symints +<<<<<<< HEAD self._recompute_comparison_key() def arg_type_tensor_or_tensor_list_like(self, arg: object) -> bool: +======= + + def arg_type_tensor_or_tensor_list_like(self, arg_idx: int) -> bool: + arg = self.args_schema[arg_idx] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) is_tensor = isinstance(arg, DTensorSpec) if is_tensor: return True @@ -413,6 +486,7 @@ def return_type_tuple_tensor_like(self) -> bool: return_types[0].type, torch.TensorType ) +<<<<<<< HEAD def return_type_list_tensor_like(self) -> bool: # returns True if the return type is a List return_types = self.op._schema.returns @@ -420,6 +494,8 @@ def return_type_list_tensor_like(self) -> bool: return_types[0].type, torch.ListType ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def return_type_tensor(self) -> bool: return_types = self.op._schema.returns # all dispatch ops only return Tensor or Tuple[Tensor] for tensor like @@ -444,7 +520,11 @@ def get_mesh_from_args(self, validate: bool = True) -> DeviceMesh: mesh = first_arg.mesh elif isinstance(first_arg, (list, tuple, TupleStrategy)): first_elem = ( +<<<<<<< HEAD first_arg.children[0] +======= + first_arg.childs[0] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(first_arg, TupleStrategy) else first_arg[0] ) @@ -476,10 +556,15 @@ def is_out_variant_op(self) -> bool: # be entirely correct, but it's good enough for now. return "out" in self.op._schema.overload_name +<<<<<<< HEAD def is_view_op(self) -> bool: return self.op._schema._is_view_op() def _recompute_comparison_key(self): +======= + def __hash__(self) -> int: + # Only hash args and kwargs that op indicates to hash +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not self.schema_info: static_argnum = len(self.args_schema) static_kwargkey = None @@ -490,18 +575,28 @@ def _recompute_comparison_key(self): args_to_hash = tuple( tuple(e) if isinstance(e, list) else e for i, e in enumerate(self.args_schema) +<<<<<<< HEAD if self.arg_type_tensor_or_tensor_list_like(e) or i >= static_argnum +======= + if self.arg_type_tensor_or_tensor_list_like(i) or i >= static_argnum +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if static_kwargkey is not None: kwargs_to_hash = tuple( self.kwargs_schema.get(k, None) for k in static_kwargkey ) +<<<<<<< HEAD self._comparison_key = (self.op, args_to_hash, kwargs_to_hash) else: self._comparison_key = (self.op, args_to_hash) def __hash__(self) -> int: return hash(self._comparison_key) +======= + return hash((self.op, args_to_hash, kwargs_to_hash)) + else: + return hash((self.op, args_to_hash)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __eq__(self, other: object) -> bool: # early return checks @@ -514,7 +609,35 @@ def __eq__(self, other: object) -> bool: if len(self.args_schema) != len(other.args_schema): return False +<<<<<<< HEAD return self._comparison_key == other._comparison_key +======= + # compare each element and early return if any of them is different + if not self.schema_info: + static_argnum = len(self.args_schema) + static_kwargkey = None + else: + static_argnum = self.schema_info.static_argnum + static_kwargkey = self.schema_info.static_kwargkey + + for i, (self_arg, other_arg) in enumerate( + zip(self.args_schema, other.args_schema) + ): + if isinstance(self_arg, DTensorSpec) and self_arg != other_arg: + return False + elif i >= static_argnum and self_arg != other_arg: + return False + + # check kwarg equality when there's a static kwarg key + if static_kwargkey: + for key in static_kwargkey: + if self.kwargs_schema.get(key, None) != other.kwargs_schema.get( + key, None + ): + return False + + return True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def gen_fake_args(self) -> ArgsType: """ @@ -561,7 +684,10 @@ def _inplace_rewrap_schema_suggestion(self, origin_schema: "OpSchema") -> None: new_arg_schema.append(arg) self.args_schema = tuple(new_arg_schema) self.kwargs_schema = origin_schema.kwargs_schema +<<<<<<< HEAD self._recompute_comparison_key() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclass @@ -576,6 +702,7 @@ class OutputSharding: exactly the same as the operator OpSchema, except the DTensorSpecs """ +<<<<<<< HEAD # specifies the output sharding pattern output_spec: OutputSpecType # schema for redistribution if needed @@ -584,6 +711,11 @@ class OutputSharding: needs_redistribute: bool = False # flag to use values from `redistribute_schema` use_val_from_redistribute_schema: bool = False +======= + output_spec: OutputSpecType + redistribute_schema: Optional[OpSchema] = None + needs_redistribute: bool = False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @cached_property def mesh(self): diff --git a/torch/distributed/tensor/_ops/_einsum_strategy.py b/torch/distributed/tensor/_ops/_einsum_strategy.py index 506103d70a599..bb34032e39a04 100644 --- a/torch/distributed/tensor/_ops/_einsum_strategy.py +++ b/torch/distributed/tensor/_ops/_einsum_strategy.py @@ -90,6 +90,7 @@ def gen_einsum_strategies( ) -> OpStrategy: """ Generate a strategy list for the ops that follow einsum style notation. +<<<<<<< HEAD In principle, each mesh dim is independent of other device mesh dim when we generate strategies. So we generate strategy over each device mesh dim and @@ -107,10 +108,13 @@ def gen_einsum_strategies( 3. Linearity (Partial): If enabled, set Partial on output and inputs over the same device mesh dim. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ # parse einop equation and extract dims input_dims, output_dim = EinsumDims.parse_equation(equation) edims = EinsumDims.parse_dims(input_dims, output_dim) +<<<<<<< HEAD all_mesh_dim_strategies = [] # generate strategies for each mesh dim and do cartesian product for final strategy. E.g., for a 2D mesh, we can have [P(),R,R] @@ -177,6 +181,78 @@ def gen_einsum_strategies( # generate strategies for entire mesh all_mesh_dim_strategies = [strategies_over_one_mesh_dim] * mesh.ndim strategy_combs = itertools.product(*all_mesh_dim_strategies) +======= + + all_mesh_dim_strategies = [] + + # generate strategies for each mesh dim + for mesh_dim in range(mesh.ndim): + mesh_dim_strategies = [] + + # placement list stores placements of [output, input1, input2, ...] + # first we always have replicate all for inputs and output + placement_list: list[Placement] = [Replicate()] * (len(input_dims) + 1) + mesh_dim_strategies.append(placement_list) + + # split batch dim + for batch_dim in edims.batch_dims: + output_batch_dim = output_dim.index(batch_dim) + placement_list = [Shard(output_batch_dim)] + for input_dim in input_dims: + input_batch_dim = input_dim.index(batch_dim) + placement_list.append(Shard(input_batch_dim)) + + mesh_dim_strategies.append(placement_list) + + # split contracting dim + for contracting_dim in edims.contracting_dims: + placement_list = [Partial()] + for input_dim in input_dims: + input_contracting_dim = input_dim.index(contracting_dim) + placement_list.append(Shard(input_contracting_dim)) + + mesh_dim_strategies.append(placement_list) + + # split lhs free dim + for lhs_dim in edims.lhs_out_only_dims: + lhs_free_dim = output_dim.index(lhs_dim) + # this means split the lhs input and output + # i.e. S(0), R -> S(0) + lhs_placement_list: list[Placement] = [ + Shard(lhs_free_dim), + Shard(lhs_free_dim), + Replicate(), + ] + mesh_dim_strategies.append(lhs_placement_list) + + # split rhs free dim + for rhs_dim in edims.rhs_out_only_dims: + rhs_free_dim = output_dim.index(rhs_dim) + rhs_placement_list: list[Placement] = [ + Shard(rhs_free_dim), + Replicate(), + Shard(rhs_free_dim), + ] + mesh_dim_strategies.append(rhs_placement_list) + + # linearity strategy + if linearity: + linearity_placement_list: list[Placement] = [Partial()] + for input_dim in input_dims: + linearity_placement_list.append(Partial()) + mesh_dim_strategies.append(linearity_placement_list) + + all_mesh_dim_strategies.append(mesh_dim_strategies) + + # generate strategies for entire mesh + strategy_combs = itertools.product(*all_mesh_dim_strategies) + + # TODO: filter out invalid strategies, at this point we generate + # all possible strategies without considering the whether the tensor + # dim could be sharded or not, we would need to filter out invalid + # strategies base on the actual tensor shape + # (i.e. for Shard, tensor dim size must > mesh size) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) all_strategies = [] for strategy_comb in strategy_combs: spec_list = [DTensorSpec(mesh, tuple(specs)) for specs in zip(*strategy_comb)] diff --git a/torch/distributed/tensor/_ops/_embedding_ops.py b/torch/distributed/tensor/_ops/_embedding_ops.py index 1b8e47895ce59..d1196cad0df59 100644 --- a/torch/distributed/tensor/_ops/_embedding_ops.py +++ b/torch/distributed/tensor/_ops/_embedding_ops.py @@ -113,7 +113,11 @@ def _partition_value( def _reduce_value( self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int ) -> torch.Tensor: +<<<<<<< HEAD # by the time we need reduction, we should have already saved the mask +======= + # by the time we ned reduction, we should have already saved the mask +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert self.mask_buffer.data is not None # apply the mask to the tensor that pending reduction @@ -134,7 +138,11 @@ def _reduce_shard_value( mesh_dim: int, shard_spec: Placement, ) -> torch.Tensor: +<<<<<<< HEAD # by the time we need reduction, we should have already saved the mask +======= + # by the time we ned reduction, we should have already saved the mask +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert self.mask_buffer.data is not None # apply the mask to the tensor that pending reduction diff --git a/torch/distributed/tensor/_ops/_math_ops.py b/torch/distributed/tensor/_ops/_math_ops.py index 1e6eb40939e4a..710b56e12a91a 100644 --- a/torch/distributed/tensor/_ops/_math_ops.py +++ b/torch/distributed/tensor/_ops/_math_ops.py @@ -419,8 +419,13 @@ def foreach_norm_strategy(op_schema: OpSchema) -> TupleStrategy: assert isinstance(input_tuple_strategy, TupleStrategy) norm_type = args_schema[1] if len(args_schema) > 1 else 2 assert isinstance(norm_type, (int, float, str)), f"{norm_type}" +<<<<<<< HEAD output_tuple_strategy_children: list[OpStrategy] = [] for op_strategy in input_tuple_strategy.children: +======= + output_tuple_strategy_childs: list[OpStrategy] = [] + for op_strategy in input_tuple_strategy.childs: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(op_strategy, OpStrategy), f"{op_strategy}" reduce_dims = list(range(op_strategy.ndim)) output_strategy = common_reduction_strategy( @@ -429,8 +434,13 @@ def foreach_norm_strategy(op_schema: OpSchema) -> TupleStrategy: reduction_linear=True, reduction_op=NormReduction(norm_type), ) +<<<<<<< HEAD output_tuple_strategy_children.append(output_strategy) return TupleStrategy(output_tuple_strategy_children) +======= + output_tuple_strategy_childs.append(output_strategy) + return TupleStrategy(output_tuple_strategy_childs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register_op_strategy( @@ -556,6 +566,7 @@ def softmax_backward_strategy(op_schema: OpSchema) -> OpStrategy: mesh=grad_out_strategy.mesh, placements=replicate_reduction_dims(src_spec.placements, [softmax_dim]), ) +<<<<<<< HEAD new_grad_out_spec = DTensorSpec( mesh=tgt_spec.mesh, placements=tgt_spec.placements, @@ -566,12 +577,17 @@ def softmax_backward_strategy(op_schema: OpSchema) -> OpStrategy: placements=tgt_spec.placements, tensor_meta=out_src_spec.tensor_meta, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) redist_grad_out_cost = generate_redistribute_costs(grad_out_strategy, tgt_spec) redist_out_cost = generate_redistribute_costs(out_strategy, tgt_spec) grad_in_strategy.strategies.append( OpSpec( output_specs=tgt_spec, +<<<<<<< HEAD input_specs=(new_grad_out_spec, new_out_spec), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) redistribute_cost=[redist_grad_out_cost, redist_out_cost], ) ) @@ -818,6 +834,7 @@ def nll_loss_backward_strategy(op_schema: OpSchema) -> OpStrategy: return grad_in_strategy +<<<<<<< HEAD def _common_norm_forward_strategy( op_schema: OpSchema, rms_norm: bool = False, @@ -850,6 +867,29 @@ def _common_norm_forward_strategy( bias_strategy = None # the current norm implementation requires that all +======= +@register_op_strategy( + [aten.native_layer_norm.default], + schema_info=RuntimeSchemaInfo(1), +) +def layer_norm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + # args must be: input, normalized_shape, weight, bias, eps + # for None weight and bias, their corresponding objects will + # be None as well. layer_norm_strategy returns one OpStrategy + # for the triple return values (out, mean, rstd). + assert len(op_schema.args_schema) == 5 + ( + input_strategy, + normalized_shape, + weight_strategy, + bias_strategy, + _, + ) = op_schema.args_schema + + # the current layer norm implementation requires that all +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # input DTensor's sharding must be in form of OpStrategy assert isinstance(input_strategy, OpStrategy) assert isinstance(normalized_shape, (int, Sequence, torch.Size)) @@ -858,7 +898,11 @@ def _common_norm_forward_strategy( input_ndim = input_strategy.ndim axis = input_ndim - len(normalized_size) +<<<<<<< HEAD # we use OpStrategy because the output values (out, mean, rstd) +======= + # we use OpStrategy because the output (out, mean, rstd) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # should have the same placements output_strategy = OpStrategy([]) for idx, input_placement_strategy in enumerate(input_strategy.strategies): @@ -927,6 +971,7 @@ def _common_norm_forward_strategy( @register_op_strategy( +<<<<<<< HEAD [aten.native_layer_norm.default], schema_info=RuntimeSchemaInfo(1), ) @@ -984,6 +1029,35 @@ def _common_norm_backward_strategy( assert isinstance(rstd_strategy, OpStrategy) if mean_strategy is not None: assert isinstance(mean_strategy, OpStrategy) +======= + [aten.native_layer_norm_backward.default], + schema_info=RuntimeSchemaInfo(2), +) +def layer_norm_bwd_strategy(op_schema: OpSchema) -> OpStrategy: + # backward op does not need to validate the mesh since forward op has already done it + mesh = op_schema.get_mesh_from_args(validate=False) + + # args must be: grad_out, input, normalized_shape, mean, rstd, + # weight, bias, output_mask. For None weight and bias, their + # corresponding objects will be None as well. + + assert len(op_schema.args_schema) == 8 + ( + grad_out_strategy, + input_strategy, + normalized_shape, + mean_strategy, + rstd_strategy, + weight_strategy, + bias_strategy, + output_mask, + ) = op_schema.args_schema + + assert isinstance(grad_out_strategy, OpStrategy) + assert isinstance(input_strategy, OpStrategy) + assert isinstance(mean_strategy, OpStrategy) + assert isinstance(rstd_strategy, OpStrategy) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(normalized_shape, (int, Sequence, torch.Size)) normalized_size = normalize_to_torch_size(normalized_shape) @@ -991,12 +1065,18 @@ def _common_norm_backward_strategy( axis = input_ndim - len(normalized_size) outer_dims = list(range(axis)) +<<<<<<< HEAD if not rms_norm: assert isinstance(output_mask, list) and len(output_mask) == 3 else: assert isinstance(output_mask, list) and len(output_mask) == 2 # output tuple: (d_input, d_weight[, d_bias]) +======= + assert isinstance(output_mask, list) and len(output_mask) == 3 + + # output triple: (d_input, d_weight, d_bias) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out_tuple_strategy = OpStrategy([]) for idx, input_placement_strategy in enumerate(input_strategy.strategies): # args for OpSpec @@ -1037,6 +1117,7 @@ def _common_norm_backward_strategy( generate_redistribute_costs(input_strategy, input_target_spec) ) +<<<<<<< HEAD # arg: mean if not rms_norm: assert mean_strategy is not None # mypy fix @@ -1045,6 +1126,12 @@ def _common_norm_backward_strategy( redistribute_costs.append([0.0 for _ in mean_strategy.strategies]) # arg: rstd +======= + # arg: mean, rstd + mean_src_spec = mean_strategy.strategies[idx].output_spec + input_specs_list.append(mean_src_spec) + redistribute_costs.append([0.0 for _ in mean_strategy.strategies]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) rstd_src_spec = rstd_strategy.strategies[idx].output_spec input_specs_list.append(rstd_src_spec) redistribute_costs.append([0.0 for _ in rstd_strategy.strategies]) @@ -1060,7 +1147,10 @@ def _add_target_input_spec(strategy) -> DTensorSpec: # arg: weight # d_weight = sum(grad_out * (input - mean) / rstd, outer_dim, keepdim=False) +<<<<<<< HEAD # For RMS norm, mean is 0, so it's just: sum(grad_out * input / rstd, outer_dim, keepdim=False) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if weight_strategy is not None: weight_src_spec = _add_target_input_spec(weight_strategy) # TODO: now d_weight spec follows input spec w/ a reduction. @@ -1080,15 +1170,22 @@ def _add_target_input_spec(strategy) -> DTensorSpec: ) output_specs_list.append(weight_out_spec if output_mask[1] else None) else: +<<<<<<< HEAD if not rms_norm: error_msg = "output_mask[1] should not be `True` while weight argument is `None` in native_layer_norm_backward." else: error_msg = "output_mask[1] should not be `True` while weight argument is `None` in _fused_rms_norm_backward." assert output_mask[1] is False, error_msg +======= + assert output_mask[1] is False, ( + "output_mask[1] should not be `True` while weight argument is `None` in native_layer_norm_backward." + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) output_specs_list.append(None) # arg: bias # d_bias = sum(grad_out, outer_dim, keepdim=False) +<<<<<<< HEAD if not rms_norm: if bias_strategy is not None: bias_src_spec = _add_target_input_spec(bias_strategy) @@ -1113,6 +1210,31 @@ def _add_target_input_spec(strategy) -> DTensorSpec: "output_mask[2] should not be `True` while bias argument is `None` in native_layer_norm_backward." ) output_specs_list.append(None) +======= + if bias_strategy is not None: + bias_src_spec = _add_target_input_spec(bias_strategy) + # d_bias spec follows a reduction over grad_out + inp_placements = _replicate_dims_start_at( + grad_out_target_spec.placements, axis + ) + reduce_dims_map = _infer_reduce_dims_map( + outer_dims, grad_out_target_spec.ndim, False + ) + out_placements = map_placements_after_reduction( + inp_placements, outer_dims, reduce_dims_map, "sum" + ) + bias_out_spec = DTensorSpec( + mesh=mesh, + placements=out_placements, + tensor_meta=bias_src_spec.tensor_meta, + ) + output_specs_list.append(bias_out_spec if output_mask[2] else None) + else: + assert output_mask[2] is False, ( + "output_mask[2] should not be `True` while bias argument is `None` in native_layer_norm_backward." + ) + output_specs_list.append(None) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out_tuple_strategy.strategies.append( OpSpec( @@ -1126,6 +1248,7 @@ def _add_target_input_spec(strategy) -> DTensorSpec: @register_op_strategy( +<<<<<<< HEAD [aten.native_layer_norm_backward.default], schema_info=RuntimeSchemaInfo(2), ) @@ -1157,10 +1280,13 @@ def sort_strategy(op_schema: OpSchema, sort_dim: int) -> OpStrategy: @register_op_strategy( +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) [aten.topk.default], schema_info=RuntimeSchemaInfo(2), ) def topk_strategy(op_schema: OpSchema) -> OpStrategy: +<<<<<<< HEAD topk_dim = ( cast(int, op_schema.args_schema[2]) if len(op_schema.args_schema) > 2 else -1 ) @@ -1221,4 +1347,28 @@ def histc_strategy(op_schema: OpSchema) -> OpStrategy: return expand_to_full_mesh_op_strategy( input_strategy.mesh, op_schema, single_mesh_dim_strategies +======= + input_strategy = cast(OpStrategy, op_schema.args_schema[0]) + topk_dim = ( + cast(int, op_schema.args_schema[2]) if len(op_schema.args_schema) > 2 else -1 + ) + topk_dim = normalize_dim(topk_dim, input_strategy.ndim) + + single_mesh_dim_strategies = [] + + # two outputs (values, indices), 1 input + # replicate always works + all_replicate: PlacementList = [Replicate()] * 3 + single_mesh_dim_strategies.append(all_replicate) + + # every dim except topk dim should work + for dim in range(input_strategy.ndim): + if dim != topk_dim: + dim_shardings: PlacementList = [Shard(dim)] * 3 + single_mesh_dim_strategies.append(dim_shardings) + # TODO: topk on sharded dim requries non-trival reduction, address it later + + return expand_to_full_mesh_op_strategy( + input_strategy.mesh, op_schema, single_mesh_dim_strategies, input_index=2 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) diff --git a/torch/distributed/tensor/_ops/_matrix_ops.py b/torch/distributed/tensor/_ops/_matrix_ops.py index b0dc49dde358c..9ed6b515b94eb 100644 --- a/torch/distributed/tensor/_ops/_matrix_ops.py +++ b/torch/distributed/tensor/_ops/_matrix_ops.py @@ -6,7 +6,11 @@ import torch from torch.distributed.device_mesh import DeviceMesh +<<<<<<< HEAD from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +======= +from torch.distributed.tensor._dtensor_spec import DTensorSpec +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.distributed.tensor._op_schema import ( OpSchema, OpSpec, @@ -24,10 +28,13 @@ prod, register_op_strategy, ) +<<<<<<< HEAD from torch.distributed.tensor._utils import ( compute_local_shape_and_global_offset, compute_local_stride, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.distributed.tensor.placement_types import ( Partial, Placement, @@ -312,7 +319,10 @@ def scaled_dot_product_flash_attention_strategy(op_schema: OpSchema) -> OpStrate single_mesh_dim_strategies.append(num_heads_dim_sharding) # Shard on the batch dimension +<<<<<<< HEAD debug_attn_mask_sharding = Shard(0) if return_debug_mask else Replicate() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) single_mesh_dim_strategies.append( [ Shard(0), # output @@ -323,7 +333,11 @@ def scaled_dot_product_flash_attention_strategy(op_schema: OpSchema) -> OpStrate None, # max_k Replicate(), # rng_state None, # unused +<<<<<<< HEAD debug_attn_mask_sharding, # debugattn +======= + Shard(0), # debugattn +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Shard(0), # q Shard(0), # k Shard(0), # v @@ -331,7 +345,10 @@ def scaled_dot_product_flash_attention_strategy(op_schema: OpSchema) -> OpStrate ) # Context Parallelism: shards on the sequence dim +<<<<<<< HEAD debug_attn_mask_sharding = Shard(2) if return_debug_mask else Replicate() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) single_mesh_dim_strategies.append( [ Shard(2), # output @@ -342,7 +359,11 @@ def scaled_dot_product_flash_attention_strategy(op_schema: OpSchema) -> OpStrate None, # max_k Replicate(), # rng_state None, # unused +<<<<<<< HEAD debug_attn_mask_sharding, # debugattn +======= + Shard(2), # debugattn +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Shard(2), # q Shard(2), # k Shard(2), # v @@ -706,7 +727,11 @@ def scaled_dot_product_cudnn_attention_strategy(op_schema: OpSchema) -> OpStrate None, # max_k None, # philox_seed None, # philox_offset +<<<<<<< HEAD # NOTE: debug_attn_mask is not supported by pytorch and is always an empty tensor +======= + # NOTE: debug_attn_mask is not supproted by pytorch and is always an empty tensor +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # https://github.com/pytorch/pytorch/blob/60205b0eb2602317856312a66d955c88334ade0b/aten/src/ATen/native/transformers/cuda/attention.cu#L839-L840 debug_attn_mask_sharding, # debug_attn_mask Replicate(), # q @@ -1041,6 +1066,7 @@ def grouped_mm_strategy(op_schema: OpSchema) -> OpStrategy: ] ) +<<<<<<< HEAD def valid_grouped_mm_strides( input_specs: list[DTensorSpec], output_specs: tuple[Optional[DTensorSpec], ...] ) -> bool: @@ -1088,4 +1114,8 @@ def check_valid_strides(meta: TensorMeta) -> bool: single_mesh_dim_strategies, input_index=1, is_valid_strategy_cb=valid_grouped_mm_strides, +======= + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) diff --git a/torch/distributed/tensor/_ops/_pointwise_ops.py b/torch/distributed/tensor/_ops/_pointwise_ops.py index 46fc8fbc0d990..de23a28b7563d 100644 --- a/torch/distributed/tensor/_ops/_pointwise_ops.py +++ b/torch/distributed/tensor/_ops/_pointwise_ops.py @@ -1,6 +1,10 @@ # Copyright (c) Meta Platforms, Inc. and affiliates from collections.abc import Sequence +<<<<<<< HEAD from typing import cast, Optional +======= +from typing import cast +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch from torch.distributed.tensor._dtensor_spec import DTensorSpec @@ -25,7 +29,10 @@ Replicate, Shard, ) +<<<<<<< HEAD from torch.utils._typing_utils import not_none +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) aten = torch.ops.aten @@ -46,6 +53,18 @@ # ] +<<<<<<< HEAD +======= +linear_pointwise_ops = [ + aten.div.Scalar, # this op is linear on the first argument, and the second argument is scalar, so it fits as a linear op. + aten.div_.Scalar, # this op is linear on the first argument, and the second argument is scalar, so it fits as a linear op. + aten.to.dtype, + aten.add.Tensor, + aten.add_.Tensor, +] + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pointwise_ops = [ # please keep the entries below alphabetically sorted aten.__ilshift__.Scalar, @@ -134,6 +153,7 @@ aten.ceil.out, aten.ceil_.default, aten.clamp.default, +<<<<<<< HEAD aten.clamp.Tensor, aten.clamp.out, aten.clamp_.default, @@ -142,6 +162,10 @@ aten.clamp_min.Tensor, aten.clamp_max.default, aten.clamp_max.Tensor, +======= + aten.clamp.out, + aten.clamp_.default, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) aten.clip.default, aten.clip.out, aten.clip_.default, @@ -295,7 +319,15 @@ aten.maximum.out, aten.minimum.default, aten.minimum.out, +<<<<<<< HEAD + aten.mul.out, +======= + aten.mul.Scalar, + aten.mul.Tensor, aten.mul.out, + aten.mul_.Scalar, + aten.mul_.Tensor, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) aten.mvlgamma.default, aten.mvlgamma.out, aten.mvlgamma_.default, @@ -410,6 +442,7 @@ aten.threshold_backward.default, ] +<<<<<<< HEAD # the linear pointwise ops map, key is op, value is the type of linearity linear_pointwise_ops = { aten.to.dtype: 0, @@ -426,12 +459,18 @@ def pointwise_strategy(op_schema: OpSchema, linearity: int = -1) -> OpStrategy: followed_strategy_index = -1 +======= + +def pointwise_strategy(op_schema: OpSchema, linearity: bool = False) -> OpStrategy: + max_shards_strategy_index = -1 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) max_shards = -1 max_ndim = -1 if op_schema.is_inplace_op(): # inplace op should follow the first arg strategy followed_strategy = op_schema.args_schema[0] +<<<<<<< HEAD followed_strategy_index = 0 elif op_schema.is_out_variant_op(): # out variant op should follow the out kwarg strategy @@ -440,6 +479,11 @@ def pointwise_strategy(op_schema: OpSchema, linearity: int = -1) -> OpStrategy: # have an "index", we set it to a reasonably large number just to indicate it's # not a valid index followed_strategy_index = 100 +======= + elif op_schema.is_out_variant_op(): + # out variant op should follow the out kwarg strategy + followed_strategy = op_schema.kwargs_schema["out"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: # normal pointwise op, we choose to follow the arg with # the max shards in case operands needs reshard @@ -454,16 +498,25 @@ def pointwise_strategy(op_schema: OpSchema, linearity: int = -1) -> OpStrategy: if (arg_max_shards > max_shards) or ( arg_max_shards == max_shards and arg_max_ndim > max_ndim ): +<<<<<<< HEAD followed_strategy_index = idx max_shards = arg_max_shards max_ndim = arg_max_ndim followed_strategy = op_schema.args_schema[followed_strategy_index] +======= + max_shards_strategy_index = idx + max_shards = arg_max_shards + max_ndim = arg_max_ndim + + followed_strategy = op_schema.args_schema[max_shards_strategy_index] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(followed_strategy, OpStrategy), ( f"no strategy to follow for {op_schema}!" ) return common_pointwise_strategy( +<<<<<<< HEAD op_schema.args_schema, followed_strategy, followed_strategy_index, @@ -512,15 +565,31 @@ def common_pointwise_strategy( scalar_tensor_idx: Index of the Replicate scalar tensor for which we allow the mesh to be different from the mesh of followed_strategy """ +======= + op_schema.args_schema, followed_strategy, linearity + ) + + +def common_pointwise_strategy( + args_schema: Sequence[object], + followed_strategy: OpStrategy, + linearity: bool, +) -> OpStrategy: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # handle broadcasting common_shape = torch.broadcast_shapes( *[arg.shape for arg in args_schema if isinstance(arg, OpStrategy)] ) pointwise_strategy = OpStrategy([]) +<<<<<<< HEAD for op_spec in followed_strategy.strategies: spec_to_follow = op_spec.output_spec +======= + for placement_strategy in followed_strategy.strategies: + spec_to_follow = placement_strategy.output_spec +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out_placements: list[Placement] = [] for placement in spec_to_follow.placements: if isinstance(placement, Shard): @@ -528,6 +597,7 @@ def common_pointwise_strategy( common_ndim = len(common_shape) new_shard_dim = common_ndim - len(spec_to_follow.shape) + shard_dim out_placements.append(Shard(new_shard_dim)) +<<<<<<< HEAD elif isinstance(placement, Partial): # note that only partial-sum and partial-avg are supported for linearity partial_supports_linearity = placement.is_partial( @@ -541,11 +611,19 @@ def common_pointwise_strategy( # by default we just replicate the partial, need to see if this # is optimal for all cases out_placements.append(Replicate()) +======= + elif isinstance(placement, Partial) and not linearity: + # clear the partial placemnet if op does not support linearity + # by default we just replicate the partial, need to see if this + # is optimal for all cases + out_placements.append(Replicate()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: out_placements.append(placement) input_specs: list[DTensorSpec] = [] redistribute_costs: list[list[float]] = [] +<<<<<<< HEAD for input_idx, input_arg in enumerate(args_schema): if isinstance(input_arg, OpStrategy): input_arg_spec = input_arg.strategies[0].output_spec @@ -587,13 +665,34 @@ def common_pointwise_strategy( != followed_strategy_index # Don't convert the "followed" strategy ) +======= + for arg_idx, input_arg in enumerate(args_schema): + if isinstance(input_arg, OpStrategy): + # sanity check that all args that follow the same strategy + # are on the same DeviceMesh + if input_arg.mesh != followed_strategy.mesh: + raise ValueError( + f"Could not run pointwise computation across different mesh: " + f"Found {input_arg.mesh} and {followed_strategy.mesh}!" + ) + + # every arg follow the out_placements, but need to handle broadcasting + input_arg_spec = input_arg.strategies[0].output_spec + input_arg_dims_map = infer_broadcast_dims_map( + common_shape, input_arg_spec.shape + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) input_target_placements = map_placements_after_broadcast( tuple(out_placements), common_shape, input_arg_dims_map, +<<<<<<< HEAD partial_to_replicate=should_convert_partial, ) +======= + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) input_arg_target_spec = DTensorSpec( mesh=followed_strategy.mesh, placements=input_target_placements, @@ -617,7 +716,20 @@ def common_pointwise_strategy( return pointwise_strategy +<<<<<<< HEAD for op in linear_pointwise_ops.keys(): +======= +def linear_pointwise_strategy(op_schema: OpSchema) -> StrategyType: + """ + Linear pointwise operators can propagate pending reductions. + For example, c = add(a, b); if a is pending sum, then c will be + pending sum as well without any communication overhead. + """ + return pointwise_strategy(op_schema, linearity=True) + + +for op in linear_pointwise_ops: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) register_op_strategy(op, schema_info=RuntimeSchemaInfo(static_kwargkey=["out"]))( linear_pointwise_strategy ) @@ -708,6 +820,7 @@ def list_pointwise_strategy( OpStrategy: generated strategy """ +<<<<<<< HEAD def args_tuple_strategies( args_schema: tuple[object, ...], ) -> list[Optional[TupleStrategy]]: @@ -719,6 +832,17 @@ def args_tuple_strategies( if isinstance(arg, TupleStrategy): # every tuple strategy should have the same length assert len(arg.children) == strategy_len +======= + def args_tuple_strategies(args_schema: tuple[object, ...]) -> list[TupleStrategy]: + first_arg = args_schema[0] + assert isinstance(first_arg, TupleStrategy) + strategy_len = len(first_arg.childs) + tuple_strategies: list[TupleStrategy] = [] + for arg_idx, arg in enumerate(args_schema): + if isinstance(arg, TupleStrategy): + # every tuple strategy should have the same length + assert len(arg.childs) == strategy_len +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tuple_strategies.append(arg) elif isinstance(arg, OpStrategy): if arg_idx > 0: # implicitly broadcast @@ -729,6 +853,7 @@ def args_tuple_strategies( raise RuntimeError( f"list op only supports tuple strategy! {op_schema}" ) +<<<<<<< HEAD else: # insert None as placeholder so that the idx of arg is kept tuple_strategies.append(None) @@ -751,6 +876,21 @@ def args_tuple_strategies( scalar_tensor_idx=_FUSED_OP_SCALAR_IDX if op_schema.op in fused_ops else None, +======= + return tuple_strategies + + args_strategies = args_tuple_strategies(op_schema.args_schema) + follow_strategy: TupleStrategy = args_strategies[0] + list_strategy: list[OpStrategy] = [] + for child_idx, child_strtgy in enumerate(follow_strategy.childs): + assert isinstance(child_strtgy, OpStrategy) + args_schema: list[OpStrategy] = [ + cast(OpStrategy, arg_strategy.childs[child_idx]) + for arg_strategy in args_strategies + ] + pointwise_strategy: OpStrategy = common_pointwise_strategy( + args_schema, child_strtgy, linearity +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) list_strategy.append(pointwise_strategy) return TupleStrategy(list_strategy) @@ -784,6 +924,7 @@ def list_linear_pointwise_strategy(op_schema: OpSchema) -> StrategyType: aten._fused_adamw_.tensor_lr, ] +<<<<<<< HEAD # The state_steps arg of fused adam / adamw is a Replicate scalar tensor, which will be put on # the compute_mesh of an op across all parameter groups, even when not all parameter groups @@ -791,6 +932,8 @@ def list_linear_pointwise_strategy(op_schema: OpSchema) -> StrategyType: # redistribute during sharding propagation. _FUSED_OP_SCALAR_IDX = 5 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for op in fused_ops: register_op_strategy(op, schema_info=RuntimeSchemaInfo(needs_pytree=True))( list_pointwise_strategy diff --git a/torch/distributed/tensor/_ops/_random_ops.py b/torch/distributed/tensor/_ops/_random_ops.py index 9db9b85e58d2d..2136d4164c055 100644 --- a/torch/distributed/tensor/_ops/_random_ops.py +++ b/torch/distributed/tensor/_ops/_random_ops.py @@ -31,6 +31,7 @@ def random_op_strategy(op_schema: OpSchema) -> StrategyType: if is_tensor_partial(arg_spec): # TODO: figure out how inplace random op should behave when it's partial raise RuntimeError(f"{op_schema.op} with Partial is not supported yet!") +<<<<<<< HEAD random_strategy.strategies.append( OpSpec( output_specs=arg_spec, @@ -38,5 +39,8 @@ def random_op_strategy(op_schema: OpSchema) -> StrategyType: redistribute_cost=[[0.0] * len(self_strategy.strategies)], ) ) +======= + random_strategy.strategies.append(OpSpec(output_specs=arg_spec)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return random_strategy diff --git a/torch/distributed/tensor/_ops/_tensor_ops.py b/torch/distributed/tensor/_ops/_tensor_ops.py index a5a037a3c73e6..39fc0438818e6 100644 --- a/torch/distributed/tensor/_ops/_tensor_ops.py +++ b/torch/distributed/tensor/_ops/_tensor_ops.py @@ -4,7 +4,10 @@ from typing import cast, Optional import torch +<<<<<<< HEAD from torch._prims_common import IntLike +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.distributed.tensor._dtensor_spec import DTensorSpec from torch.distributed.tensor._op_schema import ( OpSchema, @@ -35,12 +38,16 @@ Shard, ) +<<<<<<< HEAD from ._pointwise_ops import pointwise_strategy +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) aten = torch.ops.aten +<<<<<<< HEAD def propagate_single_input_strategy(op_schema: OpSchema) -> StrategyType: # For ops with a single tensor input, we perform a 1:1 mapping such that # for each strategy that the input supports, we create a corresponding strategy. @@ -75,17 +82,40 @@ def propagate_single_input_strategy(op_schema: OpSchema) -> StrategyType: for strategy in first_input_strategy.strategies ] ) +======= +def default_strategy(op_schema: OpSchema) -> StrategyType: + # Default strategy by default just propagate the first input strategy + select_strategy = op_schema.args_schema[0] + assert isinstance(select_strategy, OpStrategy) + # we create new DTensorSpecs even for default strategy to assure that + # the tensor metas are distinct between the arguments and outputs + default_strategy = [ + OpSpec( + output_specs=DTensorSpec( + mesh=select_strategy.mesh, + placements=strategy.output_spec.placements, + ) + ) + for strategy in select_strategy.strategies + ] + return OpStrategy(default_strategy) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) register_op_strategy( [ aten.clone.default, aten.contiguous.default, +<<<<<<< HEAD +======= + aten.copy_.default, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) aten.detach.default, aten.fill_.Scalar, aten.view.dtype, aten.zero_.default, ] +<<<<<<< HEAD )(propagate_single_input_strategy) @@ -107,6 +137,13 @@ def propagate_single_input_strategy(op_schema: OpSchema) -> StrategyType: # - our 'src' input may be redistributed to match up with the 'self' input, with the caveat of adjusting for # broadcasting dim register_op_strategy(aten.copy_.default)(pointwise_strategy) +======= +)(default_strategy) + +register_op_strategy( + aten._to_copy.default, schema_info=RuntimeSchemaInfo(static_kwargkey=["dtype"]) +)(default_strategy) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register_op_strategy( @@ -226,7 +263,11 @@ def new_factory_strategy(op_schema: OpSchema) -> StrategyType: OpSpec( output_specs=replica_spec, input_specs=(input_spec,), +<<<<<<< HEAD redistribute_cost=[[0.0] * len(input_strategy.strategies)], +======= + redistribute_cost=[[0.0] * mesh.ndim], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) ) @@ -244,7 +285,11 @@ def new_factory_strategy(op_schema: OpSchema) -> StrategyType: output_specs=input_spec, input_specs=(input_spec,), # encouraging new tensor placement to be the same as input +<<<<<<< HEAD redistribute_cost=[[-0.1] * len(input_strategy.strategies)], +======= + redistribute_cost=[[-0.1] * mesh.ndim], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) ) @@ -255,6 +300,7 @@ def new_factory_strategy(op_schema: OpSchema) -> StrategyType: def gen_bucketize_strategy(op_schema: OpSchema) -> StrategyType: """Just propagate input sharding, but expect replicated for boundaries input.""" mesh = op_schema.get_mesh_from_args() +<<<<<<< HEAD input_strategy, boundaries_strategy = op_schema.args_schema bucketize_strategy = OpStrategy([]) assert isinstance(input_strategy, OpStrategy) @@ -279,6 +325,16 @@ def gen_bucketize_strategy(op_schema: OpSchema) -> StrategyType: generate_redistribute_costs(boundaries_strategy, replica_spec), ], ) +======= + input_strategy = op_schema.args_schema[0] + bucketize_strategy = OpStrategy([]) + assert isinstance(input_strategy, OpStrategy) + for arg_strategy in input_strategy.strategies: + arg_spec = DTensorSpec(mesh, arg_strategy.output_spec.placements) + replica_spec = DTensorSpec(mesh, tuple([Replicate()] * mesh.ndim)) + bucketize_strategy.strategies.append( + OpSpec(output_specs=arg_spec, input_specs=(arg_spec, replica_spec)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return bucketize_strategy @@ -396,6 +452,7 @@ def gen_slice_strategy(op_schema: OpSchema) -> StrategyType: start = 0 if end is None or end > input_shape[dim]: end = input_shape[dim] +<<<<<<< HEAD assert isinstance(start, IntLike) assert isinstance(end, IntLike) assert isinstance(step, IntLike) @@ -404,6 +461,16 @@ def gen_slice_strategy(op_schema: OpSchema) -> StrategyType: slice_dim = normalize_dim(dim, input_ndim) # type: ignore[arg-type] start = normalize_dim(start, input_shape[dim]) # type: ignore[arg-type] end = normalize_dim(end, input_shape[dim]) # type: ignore[arg-type] +======= + assert isinstance(start, int) + assert isinstance(end, int) + assert isinstance(step, int) + + # normalize args + slice_dim = normalize_dim(dim, input_ndim) + start = normalize_dim(start, input_shape[dim]) + end = normalize_dim(end, input_shape[dim]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) redundant_slice = start == 0 and end == input_shape[dim] and step == 1 @@ -414,6 +481,7 @@ def gen_slice_strategy(op_schema: OpSchema) -> StrategyType: if not is_tensor_dim_sharded(arg_spec, dim=slice_dim) or redundant_slice: # only add the strategy if the slice dim is not sharded out_spec = DTensorSpec(mesh, arg_spec.placements) +<<<<<<< HEAD slice_strategy.strategies.append( OpSpec( output_specs=out_spec, @@ -421,6 +489,9 @@ def gen_slice_strategy(op_schema: OpSchema) -> StrategyType: redistribute_cost=[[0.0] * len(input_strategy.strategies)], ) ) +======= + slice_strategy.strategies.append(OpSpec(output_specs=out_spec)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not slice_strategy.strategies: # if all strategies are filtered out, unsharding all specs on slice dim # of the input strategy, and use that as the op strategy @@ -429,6 +500,7 @@ def gen_slice_strategy(op_schema: OpSchema) -> StrategyType: unshard_spec = DTensorSpec( mesh, unshard_tensor_dim(arg_spec.placements, dim=slice_dim) ) +<<<<<<< HEAD slice_strategy.strategies.append( OpSpec( output_specs=unshard_spec, @@ -437,6 +509,9 @@ def gen_slice_strategy(op_schema: OpSchema) -> StrategyType: ], ) ) +======= + slice_strategy.strategies.append(OpSpec(output_specs=unshard_spec)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return slice_strategy @@ -461,9 +536,14 @@ def slice_backward_rules(op_schema: OpSchema) -> OpStrategy: new_placements.append(placement) new_spec = DTensorSpec(output_spec.mesh, tuple(new_placements)) redistribute_cost = [generate_redistribute_costs(input_strategy, new_spec)] +<<<<<<< HEAD new_strategy = OpSpec( output_specs=new_spec, redistribute_cost=redistribute_cost ) +======= + placement_strategy.redistribute_cost = redistribute_cost + new_strategy = OpSpec(output_specs=new_spec) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) output_strategies.append(new_strategy) return OpStrategy(output_strategies) @@ -501,9 +581,13 @@ def gen_slice_scatter_strategy(op_schema: OpSchema) -> StrategyType: # TODO: Ideally we'd like to make sure the output is re-sharded afterwards to keep input sharding. mesh = op_schema.get_mesh_from_args() input_strategy = op_schema.args_schema[0] +<<<<<<< HEAD src_strategy = op_schema.args_schema[1] assert isinstance(input_strategy, OpStrategy) assert isinstance(src_strategy, OpStrategy) +======= + assert isinstance(input_strategy, OpStrategy) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) input_ndim = input_strategy.ndim slice_dim = ( cast(int, op_schema.args_schema[2]) if len(op_schema.args_schema) > 2 else 0 @@ -518,6 +602,7 @@ def gen_slice_scatter_strategy(op_schema: OpSchema) -> StrategyType: is_tensor_dim_sharded(arg_spec, dim=slice_dim) or is_tensor_partial(arg_spec) ): +<<<<<<< HEAD input_spec = DTensorSpec(mesh, arg_spec.placements, arg_spec.tensor_meta) # TODO: need to relax the constraint to src src_spec = DTensorSpec(mesh, arg_spec.placements) @@ -532,12 +617,17 @@ def gen_slice_scatter_strategy(op_schema: OpSchema) -> StrategyType: ], ) ) +======= + # only add the strategy if the slice_scatter dim is not sharded or partial + slice_scatter_strategy.strategies.append(OpSpec(output_specs=arg_spec)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not slice_scatter_strategy.strategies: # if all strategies are filtered out, replicating all specs on slice_scatter dim # of the input strategy, and use that as the op strategy for arg_strategy in input_strategy.strategies: arg_spec = arg_strategy.output_spec +<<<<<<< HEAD new_placement = replicate_tensor_dim(arg_spec.placements, dim=slice_dim) input_spec = DTensorSpec(mesh, new_placement) src_spec = DTensorSpec(mesh, new_placement) @@ -550,6 +640,13 @@ def gen_slice_scatter_strategy(op_schema: OpSchema) -> StrategyType: generate_redistribute_costs(src_strategy, src_spec), ], ) +======= + replicate_spec = DTensorSpec( + mesh, replicate_tensor_dim(arg_spec.placements, dim=slice_dim) + ) + slice_scatter_strategy.strategies.append( + OpSpec(output_specs=replicate_spec) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return slice_scatter_strategy @@ -565,12 +662,16 @@ def replica_only_strategy(op_schema: OpSchema) -> StrategyType: @register_op_strategy( +<<<<<<< HEAD [ aten.scatter_.value, aten.scatter.value, aten.scatter_.src, aten.scatter.src, ], +======= + [aten.scatter_.value, aten.scatter.value, aten.scatter_.src, aten.scatter.src], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) schema_info=RuntimeSchemaInfo(1), ) def scatter_strategy(op_schema: OpSchema) -> StrategyType: @@ -596,6 +697,7 @@ def scatter_strategy(op_schema: OpSchema) -> StrategyType: return op_strategy +<<<<<<< HEAD @register_op_strategy(aten.scatter_add.default, schema_info=RuntimeSchemaInfo(1)) def scatter_add_strategy(op_schema: OpSchema) -> StrategyType: input_strategy = op_schema.args_schema[0] @@ -629,11 +731,17 @@ def scatter_add_strategy(op_schema: OpSchema) -> StrategyType: @register_op_strategy(aten.gather.default, schema_info=RuntimeSchemaInfo(1)) +======= +@register_op_strategy(aten.gather.default) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def gather_strategy(op_schema: OpSchema) -> StrategyType: mesh = op_schema.get_mesh_from_args() input_strategy = cast(OpStrategy, op_schema.args_schema[0]) dim = cast(int, op_schema.args_schema[1]) +<<<<<<< HEAD dim = normalize_dim(dim, input_strategy.ndim) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) index_strategy = cast(OpStrategy, op_schema.args_schema[2]) input_shape = input_strategy.shape @@ -649,7 +757,11 @@ def gather_strategy(op_schema: OpSchema) -> StrategyType: # input sharding, input sharded, index accepts mask partial, output follows index # this only works when the input is sharded on the gather dimension, and # index has size 1 on the gather dimension +<<<<<<< HEAD if dim < len(index_shape) and index_shape[dim] == 1: +======= + if index_shape[dim] == 1: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim) input_sharding: PlacementList = [ index_partial_placement, @@ -663,12 +775,15 @@ def gather_strategy(op_schema: OpSchema) -> StrategyType: index_sharding: PlacementList = [Shard(dim), Replicate(), Shard(dim)] single_mesh_dim_strategies.append(index_sharding) +<<<<<<< HEAD if len(input_shape) == len(index_shape): for d in range(len(input_shape)): if d != dim: sharding: PlacementList = [Shard(d), Shard(d), Shard(d)] single_mesh_dim_strategies.append(sharding) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return expand_to_full_mesh_op_strategy( mesh, op_schema, single_mesh_dim_strategies, input_index=1 ) @@ -718,7 +833,11 @@ def merge_placement( follow_placements: Optional[list[Placement]] = None mesh = tuple_strategy.child_mesh(0) +<<<<<<< HEAD for arg_strategy in tuple_strategy.children: +======= + for arg_strategy in tuple_strategy.childs: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(arg_strategy, OpStrategy) if arg_strategy.mesh != mesh: raise ValueError( @@ -760,7 +879,11 @@ def stack_strategy(op_schema: OpSchema) -> StrategyType: args_schema = op_schema.args_schema input_tuple_strategy = args_schema[0] assert isinstance(input_tuple_strategy, TupleStrategy), f"{input_tuple_strategy}" +<<<<<<< HEAD first_input_strategy = input_tuple_strategy.children[0] +======= + first_input_strategy = input_tuple_strategy.childs[0] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(first_input_strategy, OpStrategy), f"{first_input_strategy}" common_input_ndim = first_input_strategy.ndim dim = cast(int, args_schema[1]) if len(args_schema) > 1 else 0 @@ -778,11 +901,16 @@ def stack_strategy(op_schema: OpSchema) -> StrategyType: input_specs = tuple( DTensorSpec(mesh, tuple(follow_placements)) +<<<<<<< HEAD for _ in range(len(input_tuple_strategy.children)) +======= + for _ in range(len(input_tuple_strategy.childs)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) follow_placements = normalize_shard_for_stack(follow_placements, dim) +<<<<<<< HEAD for strategy in input_tuple_strategy.children: assert isinstance(strategy, OpStrategy) output_spec = DTensorSpec(mesh, tuple(follow_placements)) @@ -797,6 +925,14 @@ def stack_strategy(op_schema: OpSchema) -> StrategyType: redistribute_cost=redistribute_cost, ) ) +======= + op_strategy.strategies.append( + OpSpec( + output_specs=DTensorSpec(mesh, tuple(follow_placements)), + input_specs=input_specs, + ) + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return op_strategy @@ -805,8 +941,12 @@ def cat_strategy(op_schema: OpSchema) -> StrategyType: args_schema = op_schema.args_schema input_tuple_strategy = args_schema[0] assert isinstance(input_tuple_strategy, TupleStrategy), f"{input_tuple_strategy}" +<<<<<<< HEAD num_input_tensor = len(input_tuple_strategy.children) first_input_strategy = input_tuple_strategy.children[0] +======= + first_input_strategy = input_tuple_strategy.childs[0] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(first_input_strategy, OpStrategy), f"{first_input_strategy}" common_input_ndim = first_input_strategy.ndim dim = cast(int, args_schema[1]) if len(args_schema) > 1 else 0 @@ -815,6 +955,7 @@ def cat_strategy(op_schema: OpSchema) -> StrategyType: mesh = first_input_strategy.mesh +<<<<<<< HEAD op_strategy = OpStrategy([]) # use a set to deduplicate strategies with the same placement strategies_placement_pool = set() @@ -866,6 +1007,27 @@ def cat_strategy(op_schema: OpSchema) -> StrategyType: redistribute_cost=redistribute_costs, ) ) +======= + follow_placements = _derive_follow_placements_from_tuple_strategy( + op_schema.op, input_tuple_strategy + ) + # for cat we unshard the cat dim if it is sharded + follow_placements = unshard_tensor_dim(follow_placements, dim) + + # create op strategy base on the follow placements + op_strategy = OpStrategy([]) + + input_specs = tuple( + DTensorSpec(mesh, tuple(follow_placements)) + for _ in range(len(input_tuple_strategy.childs)) + ) + op_strategy.strategies.append( + OpSpec( + output_specs=DTensorSpec(mesh, tuple(follow_placements)), + input_specs=input_specs, + ) + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return op_strategy @@ -902,6 +1064,7 @@ def prop_index_select(op_schema: OpSchema) -> OutputSharding: return result +<<<<<<< HEAD @register_op_strategy( [ aten.index_put.default, @@ -1007,6 +1170,8 @@ def prop_index_put(op_schema: OpSchema) -> StrategyType: return op_strategy +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register_prop_rule(aten.index.Tensor, schema_info=RuntimeSchemaInfo(needs_pytree=True)) def prop_index(op_schema: OpSchema) -> OutputSharding: """ @@ -1137,7 +1302,11 @@ def place(vp: Placement, ip: Placement) -> Placement: ], RuntimeSchemaInfo(1), ) +<<<<<<< HEAD def split_strategy(op_schema: OpSchema) -> OpStrategy: +======= +def split_strategy(op_schema: OpSchema) -> TupleStrategy: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) input_strategy = op_schema.args_schema[0] split_size_or_sections = op_schema.args_schema[1] assert isinstance(input_strategy, OpStrategy) @@ -1147,6 +1316,19 @@ def split_strategy(op_schema: OpSchema) -> OpStrategy: ) dim = normalize_dim(split_dim, input_ndim) +<<<<<<< HEAD +======= + # tensor to split cannot have Partial for now + for arg_strategy in input_strategy.strategies: + arg_spec = arg_strategy.output_spec + if is_tensor_partial(arg_spec): + raise NotImplementedError( + f"splitting distributed tensor with " + f"Partial placement is not implemented!\n" + f"DTensorSpec={arg_strategy}" + ) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def size_split(N, i) -> list: # Last chunk will be smaller if the tensor size N # along the given dimension dim is not divisible by i. @@ -1160,6 +1342,7 @@ def size_split(N, i) -> list: ) assert isinstance(output_size_list, Sized) +<<<<<<< HEAD all_strategies = [] for strategy in input_strategy.strategies: spec = strategy.output_spec @@ -1184,3 +1367,25 @@ def size_split(N, i) -> list: ) return OpStrategy(all_strategies) +======= + split_strategies = [] + + for _ in range(len(output_size_list)): + op_strategy = OpStrategy([]) + + for strategy in input_strategy.strategies: + spec = strategy.output_spec + placements = spec.placements + if is_tensor_dim_sharded(spec, dim=dim): + # if the input is sharded on the split dim, we need to unshard it + placements = unshard_tensor_dim(spec.placements, dim=dim) + + spec = DTensorSpec(spec.mesh, placements) + + op_strategy.strategies.append( + OpSpec(output_specs=spec, input_specs=([spec])) + ) + split_strategies.append(op_strategy) + + return TupleStrategy(split_strategies) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/distributed/tensor/_ops/_view_ops.py b/torch/distributed/tensor/_ops/_view_ops.py index 62e8c68e9be9d..3ae26d720ca6c 100644 --- a/torch/distributed/tensor/_ops/_view_ops.py +++ b/torch/distributed/tensor/_ops/_view_ops.py @@ -22,12 +22,16 @@ prod, register_op_strategy, ) +<<<<<<< HEAD from torch.distributed.tensor.placement_types import ( _StridedShard, Placement, Replicate, Shard, ) +======= +from torch.distributed.tensor.placement_types import Placement, Replicate, Shard +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) aten = torch.ops.aten @@ -305,7 +309,11 @@ def view_groups(from_size: Shape, to_size: Shape) -> DimMap: Flatten((InputDim(1), InputDim(2))) ) +<<<<<<< HEAD - output dimension 0 maps to input dimension 0 +======= + - ouptut dimension 0 maps to input dimension 0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - output dimension 1 maps to a flattened input dimensions 1 and 2 @@ -519,6 +527,7 @@ def maybe_get_shard_mesh_dim_and_placement( return i, placement return None, None +<<<<<<< HEAD # NOTE: This function has three responsibilities: # 1. determine "theoretically" if an output dimension can be sharded, i.e. fill the shardable_dims map # 2. determine "theoretically" the corresponding input dimension to shard on, via return value @@ -527,10 +536,15 @@ def maybe_get_shard_mesh_dim_and_placement( # 3 requires that info, to decide whether we can error out. Maybe we can refactor # to make this function purely "theoretical". def get_in_dim_to_shard(cmd: DimSpec) -> Optional[InputDim]: +======= + def get_in_dim_to_shard(cmd: DimSpec) -> Optional[InputDim]: + # TODO(whc) this helper is pretty hard to understand, at least it should be better documented if not refactored +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(cmd, InputDim): return cmd elif isinstance(cmd, Flatten): for i, dim in enumerate(cmd.input_dims): +<<<<<<< HEAD # so far all Flatten is always composed of InputDims; revisit this if needed assert isinstance(dim, InputDim) can_shard_dim = True @@ -561,6 +575,41 @@ def get_in_dim_to_shard(cmd: DimSpec) -> Optional[InputDim]: assert isinstance(cmd.input_dims[0], InputDim) return cmd.input_dims[0] +======= + if isinstance(dim, InputDim): + can_shard_dim = True + shard_mesh_dim, shard_placement = ( + maybe_get_shard_mesh_dim_and_placement(dim) + ) + input_sharded = shard_mesh_dim is not None + if i > 0: + can_shard_dim = False + if strict_view and input_sharded: + raise RuntimeError( + f"Attempted to flatten sharded dimension {i}, ", + "but only the leftmost dim of a Flatten can be sharded.", + ) + elif input_sharded: + assert ( + shard_placement is not None and shard_mesh_dim is not None + ) + tensor_dim_size = global_input_shape[shard_placement.dim] + mesh_dim_size = mesh_sizes[shard_mesh_dim] + if tensor_dim_size % mesh_dim_size != 0: + can_shard_dim = False + if strict_view: + raise RuntimeError( + f"Attempted to flatten unevenly sharded dimension {i}, " + "which would require resharding the input. " + "Please explicitly redistribute the tensor instead." + ) + + shardable_dims[dim.input_dim] = [can_shard_dim] * mesh_ndim + dim0 = cmd.input_dims[0] + # TODO(whc) dim0 can be sharded or not sharded, can't it? + # should we only return it if its sharded in the placement? + return dim0 if isinstance(dim0, InputDim) else None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif isinstance(cmd, Split): in_dim = get_in_dim_to_shard(cmd.input_dim) out_size = cmd.group_shape[cmd.split_id] @@ -579,6 +628,7 @@ def get_in_dim_to_shard(cmd: DimSpec) -> Optional[InputDim]: out_size % mesh_dim_size == 0 for mesh_dim_size in mesh_sizes ] +<<<<<<< HEAD shard_mesh_dim, _ = maybe_get_shard_mesh_dim_and_placement(in_dim) if strict_view and shard_mesh_dim is not None: if not shardable_dims[in_dim.input_dim][shard_mesh_dim]: @@ -587,6 +637,8 @@ def get_in_dim_to_shard(cmd: DimSpec) -> Optional[InputDim]: "It cannot be performed without redistribution, which is disallowed by the current operator.", ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # 2. here we special case things like [Shard(0), Shard(0)] submesh_size = 1 for size, shard in zip(mesh_sizes, input_src_placements): @@ -621,6 +673,7 @@ def get_in_dim_to_shard(cmd: DimSpec) -> Optional[InputDim]: ) for mesh_dim, p in enumerate(input_src_placements) ] +<<<<<<< HEAD def _rewrite_shard_dim(p: Shard): """ @@ -645,6 +698,10 @@ def _rewrite_shard_dim(p: Shard): output_placements = [ _rewrite_shard_dim(p) if isinstance(p, Shard) else p +======= + output_placements = [ + Shard(shard_dim_map[p.dim]) if isinstance(p, Shard) else p +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for p in input_tgt_placements ] @@ -716,9 +773,12 @@ def reshape_strategy(op_schema: OpSchema) -> StrategyType: register_op_strategy_map(aten.squeeze.default, torch.squeeze) register_op_strategy_map( +<<<<<<< HEAD aten.squeeze_.dim, torch.squeeze, schema_info=RuntimeSchemaInfo(1) ) register_op_strategy_map( +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) aten.squeeze.dim, torch.squeeze, schema_info=RuntimeSchemaInfo(1) ) register_op_strategy_map( diff --git a/torch/distributed/tensor/_ops/utils.py b/torch/distributed/tensor/_ops/utils.py index fb6f8a8ba8108..9ccea394c6c44 100644 --- a/torch/distributed/tensor/_ops/utils.py +++ b/torch/distributed/tensor/_ops/utils.py @@ -19,7 +19,10 @@ OutputSharding, PlacementList, RuntimeSchemaInfo, +<<<<<<< HEAD StrategyType, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) from torch.distributed.tensor.device_mesh import DeviceMesh from torch.distributed.tensor.placement_types import ( @@ -35,12 +38,23 @@ # convenient wrapper to register sharding propagation rules +<<<<<<< HEAD +======= +# pyre-fixme[3]: Return type must be annotated. +# pyre-fixme[2]: Parameter must be annotated. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def register_prop_rule( op: Union[torch._ops.OpOverload, list[torch._ops.OpOverload]], schema_info: Optional[RuntimeSchemaInfo] = None, ) -> Callable[ [Callable[[OpSchema], OutputSharding]], Callable[[OpSchema], OutputSharding] ]: +<<<<<<< HEAD +======= + # pyre-fixme[53]: Captured variable `func` is not annotated. + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def wrapper( impl: Callable[[OpSchema], OutputSharding], ) -> Callable[[OpSchema], OutputSharding]: @@ -57,6 +71,11 @@ def wrapper( def register_op_strategy( op, schema_info=None ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: +<<<<<<< HEAD +======= + # pyre-fixme[53]: Captured variable `func` is not annotated. + # pyre-fixme[3]: Return type must be annotated. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # pyre-fixme[2]: Parameter must be annotated. # For every ATen op that accepts any args in this list, @@ -96,6 +115,7 @@ def wrapper(impl): return wrapper +<<<<<<< HEAD def replicate_op_strategy(op_schema: OpSchema) -> StrategyType: """ Fallback strategy all use Replication() @@ -126,6 +146,8 @@ def replicate_op_strategy(op_schema: OpSchema) -> StrategyType: ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def as_list( x: Union[list[object], object], # pyre-fixme[11]: Annotation `immutable_list` is not defined as a type. @@ -165,8 +187,11 @@ def is_tensor_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool: for i, placement in enumerate(spec.placements): if placement.is_shard(): shard_dim = cast(Shard, placement).dim +<<<<<<< HEAD if shard_dim >= len(shape): return False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shards_map[shard_dim] *= spec.mesh.size(i) for i, dim_size in enumerate(shape): @@ -222,11 +247,15 @@ def map_placements_after_broadcast( placements: tuple[Placement, ...], shape: torch.Size, broadcast_dims_map: list[int], +<<<<<<< HEAD partial_to_replicate: bool = False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> tuple[Placement, ...]: """Map each placement based on the output shape after broadcast.""" new_placements: list[Placement] = [] for placement in placements: +<<<<<<< HEAD if isinstance(placement, Partial): if partial_to_replicate: # map the partial placement to replicate @@ -234,6 +263,9 @@ def map_placements_after_broadcast( else: new_placements.append(placement) elif isinstance(placement, Replicate): +======= + if isinstance(placement, (Replicate, Partial)): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) new_placements.append(placement) else: assert isinstance(placement, Shard) @@ -249,7 +281,11 @@ def map_placements_after_broadcast( # the input shape shard dim before broadcasting, # in this case it means implicit broadcasting happen # in this dim, so we can just mark it as replicate +<<<<<<< HEAD # and implicit broadcast will broadcast automatically +======= + # and implict broadcast will broadcast automatically +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # to the sharded shape new_placements.append(Replicate()) @@ -259,11 +295,14 @@ def map_placements_after_broadcast( def generate_redistribute_costs( src_strategy: OpStrategy, dst_spec: DTensorSpec ) -> list[float]: +<<<<<<< HEAD """Generates one row in the 'redistribute_costs' matrix in an OpSpec The length of the returned list will match the number of strategies in 'src_strategy'. Each value in the row is the cost of redistributing from a particular src_strategy to dst_spec. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) redistribute_costs: list[float] = [ redistribute_cost(strat.output_spec, dst_spec) for strat in src_strategy.strategies @@ -279,6 +318,7 @@ def expand_to_full_mesh_op_strategy( *, input_index: int = 1, inplace_op: bool = False, +<<<<<<< HEAD is_valid_strategy_cb: Optional[ Callable[[list[DTensorSpec], tuple[Optional[DTensorSpec], ...]], bool] ] = None, @@ -309,6 +349,9 @@ def expand_to_full_mesh_op_strategy( [Replicate(), Replicate(), Replicate()] ] """ +======= +) -> OpStrategy: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Expand the single_mesh_dim_strategies to full mesh dim strategies. all_mesh_dim_strategies = [single_mesh_dim_strategies] * mesh.ndim @@ -319,7 +362,10 @@ def expand_to_full_mesh_op_strategy( spec_list: list[Optional[DTensorSpec]] = [] for specs in zip(*strategy_comb): if specs[0] is not None: +<<<<<<< HEAD # TODO: we should fill in tensor_meta here. If nothing else, it helps the filter strategy callback +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) spec_list.append(DTensorSpec(mesh, specs)) else: spec_list.append(None) @@ -337,6 +383,7 @@ def expand_to_full_mesh_op_strategy( # input_spec matches the first argument's runtime sharding, otherwise we skip continue +<<<<<<< HEAD output_specs: tuple[Optional[DTensorSpec], ...] if input_index > 1: output_specs = tuple(spec_list[:input_index]) @@ -369,4 +416,32 @@ def expand_to_full_mesh_op_strategy( redistribute_cost=redistribute_cost, ) all_strategies.append(strategy) +======= + # check inputs shardable + inputs_shardable = all( + is_tensor_shardable(inp.shape, s) + for inp, s in zip(input_args_strategy, input_specs) + ) + + # only add to the all_strategies list when all inputs are shardable + if inputs_shardable: + redistribute_cost = [ + generate_redistribute_costs(input_strategy, input_spec) + for input_strategy, input_spec in zip(input_args_strategy, input_specs) + ] + if input_index > 1: + output_specs = tuple(spec_list[:input_index]) + else: + if spec_list[0] is not None: + output_specs = spec_list[0] # type: ignore[assignment] + else: + raise RuntimeError("output spec is None") + strategy = OpSpec( + output_specs=output_specs, + input_specs=input_specs, + redistribute_cost=redistribute_cost, + ) + all_strategies.append(strategy) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return OpStrategy(all_strategies) diff --git a/torch/distributed/tensor/_random.py b/torch/distributed/tensor/_random.py index dc3a1fb10e4b3..805eb79a9f38a 100644 --- a/torch/distributed/tensor/_random.py +++ b/torch/distributed/tensor/_random.py @@ -2,17 +2,28 @@ # Copyright (c) Meta Platforms, Inc. and affiliates import contextlib import warnings +<<<<<<< HEAD from logging import getLogger from typing import Optional, Union import torch +======= +from typing import Optional, Union + +import torch +import torch.distributed as dist +from torch import Tensor +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.distributed.device_mesh import _get_device_handle, DeviceMesh from torch.distributed.tensor._dtensor_spec import DTensorSpec from torch.distributed.tensor.placement_types import Shard +<<<<<<< HEAD logger = getLogger(__name__) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __all__ = [ "is_rng_supported_mesh", "manual_seed", @@ -76,6 +87,7 @@ def manual_seed(seed: int, device_mesh: DeviceMesh) -> None: ) return +<<<<<<< HEAD # TODO: deprecate this API, but also need to ensure we disable broadcast for PP case, and that's currently # bundled together with this API. See torchtitan/distributed/utils.py:set_determinism # warnings.warn( @@ -84,19 +96,29 @@ def manual_seed(seed: int, device_mesh: DeviceMesh) -> None: # ) # Note: we still need to ensure setting `run_state_sync=False` to support the the pp case +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # instantiate a RNG tracker if haven't. By default DTensor uses an # OffsetBasedRNGTracker to perform random operators. global _rng_tracker if not _rng_tracker: _rng_tracker = OffsetBasedRNGTracker(device_mesh, run_state_sync=False) +<<<<<<< HEAD if device_mesh.get_coordinate() is None: +======= + # the current rank is in mesh + if device_mesh.get_coordinate() is not None: + _rng_tracker._manual_seed(seed) + else: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise RuntimeError( "manual_seed requires the current rank to be a part of the device mesh " "otherwise DTensor RNG state on the rank will not be initialized and " "the behavior of DTensor random ops is undefined." ) +<<<<<<< HEAD # DTensor no longer maintains a copy of rng state. manual seed on dtensor is the same thing # as manual seed on torch. torch.manual_seed(seed) @@ -139,6 +161,8 @@ def seed(self, seed: int) -> None: ) self._state[:8] = seed_tensor +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class _RNGStateTracker: """ @@ -157,9 +181,21 @@ def __init__(self, device: torch.device): f"{self.__class__.__name__} instantiation requires the presence of " f"{device.type} device but couldn't find." ) +<<<<<<< HEAD + self._use_distribute_region = True + + @property +======= + + self._states: dict[str, Tensor] = {} self._use_distribute_region = True @property + def rng_states(self) -> dict[str, Tensor]: + return self._states + + @property +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def distribute_region_enabled(self) -> bool: return self._use_distribute_region @@ -167,9 +203,34 @@ def distribute_region_enabled(self) -> bool: def distribute_region_enabled(self, value) -> None: self._use_distribute_region = value +<<<<<<< HEAD def _distribute_region( self, spec: DTensorSpec, generator: Optional[torch.Generator] = None ): +======= + def rng_state_is_sync(self, name) -> bool: + return name in self.rng_states + + def get_seed(self, name: str) -> int: + if name not in self.rng_states: + raise RuntimeError( + f"{self.__class__.__name__} does not have random state for {name}" + ) + + seed_tensor = (self.rng_states[name])[0:8].view(dtype=torch.int64) + return int(seed_tensor.item()) + + def set_seed(self, name: str, seed: int) -> None: + seed_tensor = torch.tensor([seed], dtype=torch.uint64, device="cpu").view( + torch.uint8 + ) + offset_tensor = torch.tensor([0], dtype=torch.uint64, device="cpu").view( + torch.uint8 + ) + self.rng_states[name] = torch.cat([seed_tensor, offset_tensor]) + + def _distribute_region(self, spec: DTensorSpec): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pass def _manual_seed(self, parallel_seed: int) -> None: @@ -199,6 +260,7 @@ def __init__( f"CUDA/CUDA-like/XPU device. Got {self._device.type} instead." ) +<<<<<<< HEAD rng_state = self._get_device_state() if run_state_sync: # synchronize RNG state using rank 0's current one @@ -249,15 +311,44 @@ def _distribute_region( self._device_handle.set_rng_ctx("philox") old_offset = state.offset self._set_pre_op_offset(state, spec) +======= + rng_state = self._device_handle.get_rng_state().to(self._device) + if run_state_sync: + # synchronize RNG state using rank 0's current one + dist.broadcast(rng_state, 0) + + self.rng_states["parallel-rng"] = rng_state.to("cpu") + + def _manual_seed(self, parallel_seed: int) -> None: + self.set_seed("parallel-rng", parallel_seed) + + @contextlib.contextmanager + def _distribute_region(self, spec: DTensorSpec): + # check if the parallel rng state has been synchronized or not + if not self.rng_state_is_sync("parallel-rng"): + raise RuntimeError( + "OffsetBasedRNGTracker requires the random state to be synchronized " + "before entering into a distribute region!" + ) + + if self.distribute_region_enabled: + old_offset = self.get_offset("parallel-rng") + self._set_pre_op_offset(spec) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with torch.random.fork_rng( devices=[self._device], device_type=self._device.type ): assert self._device_handle is not None +<<<<<<< HEAD self._device_handle.set_rng_state(state.state) +======= + self._device_handle.set_rng_state(self.rng_states["parallel-rng"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: yield # execute the region code finally: # update offset to synchronize among ranks +<<<<<<< HEAD self._set_post_op_offset(state, spec, old_offset) if self._device.type == "hpu": self._device_handle.unset_rng_ctx("philox") @@ -273,6 +364,34 @@ def _distribute_region( self._set_device_state(state.state) def _set_pre_op_offset(self, state: _PhiloxState, spec: DTensorSpec) -> None: +======= + self._set_post_op_offset(spec, old_offset) + else: + yield + + def get_offset(self, name: str) -> int: + if name not in self.rng_states: + raise RuntimeError( + f"{self.__class__.__name__} does not have random state for {name}" + ) + + offset_tensor = (self.rng_states[name])[8:].view(dtype=torch.int64) + return int(offset_tensor.item()) + + def set_offset(self, name: str, offset: int) -> None: + if name not in self.rng_states: + raise RuntimeError( + f"{self.__class__.__name__} does not have random state for {name}" + ) + + seed_tensor = (self.rng_states[name])[0:8] + offset_tensor = torch.tensor([offset], dtype=torch.uint64, device="cpu").view( + torch.uint8 + ) + self.rng_states[name] = torch.cat([seed_tensor, offset_tensor]) + + def _set_pre_op_offset(self, spec: DTensorSpec) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Set the starting RNG offset for current device's local shard before actual op execution. The pre_op_offset value should start from the current RNG offset and increment by the size of local shard until it reaches the size of the whole @@ -280,7 +399,10 @@ def _set_pre_op_offset(self, state: _PhiloxState, spec: DTensorSpec) -> None: will be the same. Args: +<<<<<<< HEAD state (:class:`Tensor`): The generator state to modify +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) spec (:class:`DTensorSpec`): the spec of the DTensor object on which we prepare the offset for running random ops. @@ -383,23 +505,36 @@ def _set_pre_op_offset(self, state: _PhiloxState, spec: DTensorSpec) -> None: local_size = prod(local_size_on_rank_0) # get current RNG offset +<<<<<<< HEAD current_offset = state.offset +======= + current_offset = self.get_offset("parallel-rng") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # pytorch: offset must be multiple of 4 # source: aten/src/ATen/cuda/CUDAGeneratorImpl.cpp offset_incr = (shard_linear_idx * local_size + 3) // 4 * 4 +<<<<<<< HEAD state.offset = current_offset + offset_incr def _set_post_op_offset( self, state: _PhiloxState, spec: DTensorSpec, old_offset: int ) -> None: +======= + self.set_offset("parallel-rng", current_offset + offset_incr) + + def _set_post_op_offset(self, spec: DTensorSpec, old_offset: int) -> None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Sets the RNG to a synchronized state after running the local random op. Every rank should set its RNG offset to `old_offset + DTensor.numel()` where old_offset is the offset before calling `set_pre_op_offset` i.e. the offset before running DTensor random ops. Args: +<<<<<<< HEAD state (:class:`Tensor`): The generator state to modify. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) spec (:class:`DTensorSpec`): the spec of the DTensor object on which we post-process the offset for running random ops. @@ -414,7 +549,11 @@ def _set_post_op_offset( # pytorch: offset must be multiple of 4 # source: aten/src/ATen/cuda/CUDAGeneratorImpl.cpp numel = (numel + 3) // 4 * 4 +<<<<<<< HEAD state.offset = old_offset + numel +======= + self.set_offset("parallel-rng", old_offset + numel) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _calc_shard_linear_idx( self, shard_coord: list[int], shard_size: list[int] diff --git a/torch/distributed/tensor/_redistribute.py b/torch/distributed/tensor/_redistribute.py index 54d8723b92f89..9f1d2c70b8fb3 100644 --- a/torch/distributed/tensor/_redistribute.py +++ b/torch/distributed/tensor/_redistribute.py @@ -7,7 +7,10 @@ import torch import torch.distributed._functional_collectives as funcol import torch.distributed.tensor._api as dtensor +<<<<<<< HEAD from torch.distributed._functional_collectives import _are_we_tracing +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta from torch.distributed.tensor.device_mesh import DeviceMesh from torch.distributed.tensor.placement_types import ( @@ -182,7 +185,14 @@ def redistribute_local_tensor( # which should be an empty tensor return local_tensor +<<<<<<< HEAD if _are_we_tracing(): +======= + has_symints = any(isinstance(s, torch.SymInt) for s in current_spec.shape) or any( + isinstance(s, torch.SymInt) for s in target_spec.shape + ) + if has_symints: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) transform_infos = _gen_transform_infos_non_cached(current_spec, target_spec) else: transform_infos = _gen_transform_infos(current_spec, target_spec) diff --git a/torch/distributed/tensor/_sharding_prop.py b/torch/distributed/tensor/_sharding_prop.py index cd5452a1e9c01..a1868cd09a7e1 100644 --- a/torch/distributed/tensor/_sharding_prop.py +++ b/torch/distributed/tensor/_sharding_prop.py @@ -8,7 +8,10 @@ import torch from torch._ops import OpOverload from torch._subclasses import FakeTensorMode +<<<<<<< HEAD from torch.distributed._functional_collectives import _are_we_tracing +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta from torch.distributed.tensor._op_schema import ( OpInfo, @@ -102,6 +105,7 @@ def register_op_strategy( schema_info: Optional[RuntimeSchemaInfo] = None, ): """ +<<<<<<< HEAD Register a :class:`OpStrategy` generator for an operator. During the sharding propagation, DTensor wants to enumerate all @@ -145,6 +149,9 @@ def register_op_strategy( last two would affect sharding propagation along with the :class:`DTensor` argument ``self``. Since the argument index of ``min`` is 2, the `schema_info` should be `RuntimeSchemaInfo(static_argnum=2)`. +======= + Register a sharding strategy generator for an operator. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ self.op_strategy_funcs[op_overload] = strategy_func if schema_info is not None: @@ -199,6 +206,7 @@ def _propagate_tensor_meta_non_cached( def _propagate_tensor_meta( self, op_schema: OpSchema ) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]: +<<<<<<< HEAD """ Cached version of _propagate_tensor_meta_non_cached This is a private API. Use propagate_tensor_meta instead. @@ -218,6 +226,10 @@ def propagate_tensor_meta( else: return self._propagate_tensor_meta(op_schema) +======= + return self._propagate_tensor_meta_non_cached(op_schema) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _wrap_output_spec_tensor_meta( self, op: OpOverload, @@ -312,7 +324,10 @@ def spec_to_strategy(spec: object) -> object: op=op_schema.op, args_schema=tuple(args_op_strategy), kwargs_schema=kwargs_op_strategy, +<<<<<<< HEAD schema_info=op_schema.schema_info, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def propagate(self, op_info: OpInfo) -> None: @@ -320,7 +335,11 @@ def propagate(self, op_info: OpInfo) -> None: # because SymInts are not hashable. # This is generally ok because this only happens during tracing in torch.compile, # and tracing does not need to be as fast as eagermode DTensor usages. +<<<<<<< HEAD if _are_we_tracing(): +======= + if op_info.schema.has_symints: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) output_sharding = self.propagate_op_sharding_non_cached(op_info.schema) else: output_sharding = cast( @@ -338,6 +357,10 @@ def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputShardin return OutputSharding(None, op_schema) out_tensor_meta = self._propagate_tensor_meta_non_cached(op_schema) +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if op_schema.op in self.op_strategy_funcs: # wrap the op_schema with op strategy for sharding strategy propagation strategy_schema = self._wrap_with_op_strategy(op_schema) @@ -347,12 +370,19 @@ def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputShardin if isinstance(op_strategy, OpStrategy): # single Op strategy +<<<<<<< HEAD output_strategy = self._select_strategy(op_strategy, op_schema) # check if we need to redistribute the input needs_redistribute = False # check if we want to use args value from redistribute_schema use_val_from_redistribute_schema = False +======= + output_strategy = self._select_strategy(op_strategy) + + # check if we need to redistribute the input + needs_redistribute = False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) expected_input_specs: list[DTensorSpec] = [] # in case where the op does not specify input_specs and output_specs @@ -395,7 +425,10 @@ def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputShardin out_tensor_meta, schema, output_strategy.output_spec ) needs_redistribute = True +<<<<<<< HEAD use_val_from_redistribute_schema = True +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # construct output spec for the op if op_schema.return_type_tuple_tensor_like(): @@ -416,10 +449,14 @@ def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputShardin for _ in range(len(op_schema.op._schema.returns)) ] ) +<<<<<<< HEAD elif ( op_schema.return_type_tensor() or op_schema.return_type_list_tensor_like() ): +======= + elif op_schema.return_type_tensor(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) output_specs = output_strategy.output_specs else: output_specs = None @@ -428,14 +465,21 @@ def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputShardin output_specs, suggestion_schema, needs_redistribute=needs_redistribute, +<<<<<<< HEAD use_val_from_redistribute_schema=use_val_from_redistribute_schema, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) elif isinstance(op_strategy, TupleStrategy): # tuple strategy output sharding processing # runtime select OpSpec for each TupleStrategy input arg selected_strategies: list[OpSpec] = [] out_spec_list: list[DTensorSpec] = [] +<<<<<<< HEAD for strategy in op_strategy.children: +======= + for strategy in op_strategy.childs: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(strategy, OpStrategy) selected_strategy = self._select_strategy(strategy) selected_strategies.append(selected_strategy) @@ -497,7 +541,10 @@ def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputShardin tuple(out_spec_list) if out_tensor_meta is not None else None, suggestion_schema, needs_redistribute=needs_redistribute, +<<<<<<< HEAD use_val_from_redistribute_schema=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) else: raise ValueError("Unsupported op strategy type") @@ -555,22 +602,31 @@ def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputShardin f"Operator {op_schema.op} does not have a sharding strategy registered." ) +<<<<<<< HEAD def _select_strategy( self, strategy: OpStrategy, op_schema: Optional[OpSchema] = None ) -> OpSpec: +======= + def _select_strategy(self, strategy: OpStrategy) -> OpSpec: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if len(strategy.strategies) == 1: # short cut with only one possible OpSpec return strategy.strategies[0] op_spec_costs: list[float] = [] +<<<<<<< HEAD no_redistribute_strategy_index: int = -1 for strategy_idx, op_spec in enumerate(strategy.strategies): +======= + for op_spec in strategy.strategies: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert op_spec.redistribute_cost is not None, ( "must set redistribute cost each OpSpec!" ) redistribute_cost = sum(chain.from_iterable(op_spec.redistribute_cost)) op_spec_costs.append(redistribute_cost) +<<<<<<< HEAD # If there's no redistribute cost, we record the index of the strategy # which doesn't need redistribute. # TODO: Currently this only applies to OpStrategy selection. Requires extra @@ -605,6 +661,10 @@ def _select_strategy( selected_strategy_index = op_spec_costs.index(min_cost) return strategy.strategies[selected_strategy_index] +======= + # for eager execution, we just select the one with the minimal redistribute cost + return strategy.strategies[op_spec_costs.index(min(op_spec_costs))] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _adjust_shape_and_stride_args( self, diff --git a/torch/distributed/tensor/_shards_wrapper.py b/torch/distributed/tensor/_shards_wrapper.py index a3798eac4ae0d..1344df779cd2c 100644 --- a/torch/distributed/tensor/_shards_wrapper.py +++ b/torch/distributed/tensor/_shards_wrapper.py @@ -27,7 +27,11 @@ class LocalShardsWrapper(torch.Tensor): """ A wrapper class to hold local shards of a DTensor. +<<<<<<< HEAD This class is used largely for checkpointing purposes and implicitly subtypes +======= + This class is used largely for checkpointing purposes and implicity subtypes +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) the _Checkpointable protocol. """ @@ -159,7 +163,11 @@ def handle_view(args, kwargs) -> "LocalShardsWrapper": ] elif args[0].local_shards()[0].ndim == 1: assert args[0].storage_metadata().size[0] == view_shape[0] +<<<<<<< HEAD # This case is for optimizer sharding as regardless of sharding type, optimizer state is row wise sharded +======= + # This case is for optimizer sharding as regardles of sharding type, optimizer state is row wise sharded +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) res_shards_list = [ aten.view.default(shard, shard.shape, **kwargs) for shard in args[0].local_shards() diff --git a/torch/distributed/tensor/_utils.py b/torch/distributed/tensor/_utils.py index a39c49f5230a4..31baf604e7261 100644 --- a/torch/distributed/tensor/_utils.py +++ b/torch/distributed/tensor/_utils.py @@ -284,19 +284,31 @@ def compute_global_tensor_shape( if isinstance(placements[0], Replicate): return shape elif isinstance(placements[0], Shard): +<<<<<<< HEAD local_shape = torch.tensor(list(shape), device=mesh.device_type) +======= + local_shape = torch.tensor(list(shape)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) gathered_shaped_tensors = [ torch.empty_like(local_shape, device=local_shape.device) for _ in range(mesh.size()) ] +<<<<<<< HEAD funcol.all_gather_inplace(gathered_shaped_tensors, local_shape, mesh) +======= + funcol.all_gather_inplace(gathered_shaped_tensors, local_shape) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sharded_dim_sum = 0 shard_dim = placements[0].dim other_dims = [d for d in range(mesh.ndim) if d != shard_dim] for shape_tensor in gathered_shaped_tensors: if not torch.equal(local_shape[other_dims], shape_tensor[other_dims]): raise RuntimeError( +<<<<<<< HEAD "Non-sharded dimensions should have identical size across ranks." +======= + "Non-sharded dimentions should have identical size across ranks." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) shape_tensor_list = shape_tensor.tolist() sharded_dim_sum += shape_tensor_list[shard_dim] diff --git a/torch/distributed/tensor/debug/_comm_mode.py b/torch/distributed/tensor/debug/_comm_mode.py index 99978f9cc6b5e..121d1c332da6d 100644 --- a/torch/distributed/tensor/debug/_comm_mode.py +++ b/torch/distributed/tensor/debug/_comm_mode.py @@ -395,7 +395,11 @@ def add_json_information(json_dict, fqn): json_dict: dict[str, Any] = {} add_json_information(json_dict, "Global") +<<<<<<< HEAD # converts dictionary into json file +======= + # converts dictonary into json file +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with open(file_name, "w") as json_file: json.dump(json_dict, json_file, indent=4) diff --git a/torch/distributed/tensor/examples/comm_mode_features_example.py b/torch/distributed/tensor/examples/comm_mode_features_example.py index 8625a3f7dd1d7..bb821a8974d41 100644 --- a/torch/distributed/tensor/examples/comm_mode_features_example.py +++ b/torch/distributed/tensor/examples/comm_mode_features_example.py @@ -27,10 +27,18 @@ def get_device_type() -> str: +<<<<<<< HEAD device_type = "cpu" if torch.accelerator.device_count() >= 4: device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") return device_type +======= + return ( + "cuda" + if torch.cuda.is_available() and torch.cuda.device_count() >= 4 + else "cpu" + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) c10d_functional = torch.ops.c10d_functional @@ -710,7 +718,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def run_example(world_size: int, rank: int, example_name: str) -> None: # set manual seed +<<<<<<< HEAD # initializing class with all of the functions +======= + # intializing class with all of the functions +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) instantiated_example = CommDebugModeExample(world_size, rank) # dict that stores example code function names name_to_example_code: dict[str, Callable[[], None]] = { diff --git a/torch/distributed/tensor/examples/convnext_example.py b/torch/distributed/tensor/examples/convnext_example.py index 994f2ee10f69b..823f4d72f70b6 100644 --- a/torch/distributed/tensor/examples/convnext_example.py +++ b/torch/distributed/tensor/examples/convnext_example.py @@ -1,7 +1,11 @@ # mypy: allow-untyped-defs """ The following example demonstrates how to train a ConvNeXt model +<<<<<<< HEAD with intermediate activations sharded across multiple GPUs via DTensor +======= +with intermediate activations sharded across mutliple GPUs via DTensor +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) To run the example, use the following command: torchrun --standalone --nnodes=1 --nproc-per-node=4 convnext_example.py diff --git a/torch/distributed/tensor/examples/torchrec_sharding_example.py b/torch/distributed/tensor/examples/torchrec_sharding_example.py index 2c5d104136102..4c0ff4dd487b1 100644 --- a/torch/distributed/tensor/examples/torchrec_sharding_example.py +++ b/torch/distributed/tensor/examples/torchrec_sharding_example.py @@ -231,7 +231,11 @@ def run_torchrec_row_wise_uneven_sharding_example(rank, world_size): # note: for uneven sharding, we need to specify the shape and stride because # DTensor would assume even sharding and compute shape/stride based on the +<<<<<<< HEAD # assumption. Torchrec needs to pass in this information explicitly. +======= + # assumption. Torchrec needs to pass in this information explicitely. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # shape/stride are global tensor's shape and stride dtensor = DTensor.from_local( local_shards_wrapper, # a torch.Tensor subclass @@ -324,7 +328,11 @@ def run_torchrec_table_wise_sharding_example(rank, world_size): # create a DTensor from the local shard for the current table # note: for uneven sharding, we need to specify the shape and stride because # DTensor would assume even sharding and compute shape/stride based on the +<<<<<<< HEAD # assumption. Torchrec needs to pass in this information explicitly. +======= + # assumption. Torchrec needs to pass in this information explicitely. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dtensor = DTensor.from_local( local_shards, device_submesh, diff --git a/torch/distributed/tensor/examples/visualize_sharding_example.py b/torch/distributed/tensor/examples/visualize_sharding_example.py index 7c0ab3adfffae..f8598a40d0e5a 100644 --- a/torch/distributed/tensor/examples/visualize_sharding_example.py +++ b/torch/distributed/tensor/examples/visualize_sharding_example.py @@ -18,9 +18,12 @@ rank = int(os.environ["RANK"]) +<<<<<<< HEAD device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def section(msg: str) -> None: if rank == 0: rich.print(rich.rule.Rule(msg)) @@ -34,7 +37,11 @@ def visualize(t: dt.DTensor, msg: str = "") -> None: section("[bold]1D Tensor; 1D Mesh[/bold]") +<<<<<<< HEAD m = dist.init_device_mesh(device_type, (4,)) +======= +m = dist.init_device_mesh("cuda", (4,)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) t = torch.ones(4) visualize( dt.distribute_tensor(t, m, [dt.Replicate()]), @@ -46,7 +53,11 @@ def visualize(t: dt.DTensor, msg: str = "") -> None: ) section("[bold]2D Tensor; 1D Mesh[/bold]") +<<<<<<< HEAD m = dist.init_device_mesh(device_type, (4,)) +======= +m = dist.init_device_mesh("cuda", (4,)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) t = torch.ones(4, 4) visualize( dt.distribute_tensor(t, m, [dt.Replicate()]), @@ -62,7 +73,11 @@ def visualize(t: dt.DTensor, msg: str = "") -> None: ) section("[bold]1D Tensor; 2D Mesh[/bold]") +<<<<<<< HEAD m = dist.init_device_mesh(device_type, (2, 2)) +======= +m = dist.init_device_mesh("cuda", (2, 2)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) t = torch.ones(4) visualize( dt.distribute_tensor(t, m, [dt.Replicate(), dt.Replicate()]), @@ -82,7 +97,11 @@ def visualize(t: dt.DTensor, msg: str = "") -> None: ) section("[bold]2D Tensor; 2D Mesh[/bold]") +<<<<<<< HEAD m = dist.init_device_mesh(device_type, (2, 2)) +======= +m = dist.init_device_mesh("cuda", (2, 2)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) t = torch.ones(4, 4) visualize( dt.distribute_tensor(t, m, [dt.Replicate(), dt.Replicate()]), diff --git a/torch/distributed/tensor/experimental/_attention.py b/torch/distributed/tensor/experimental/_attention.py index 6cd06727cd2b2..055b923cc2351 100644 --- a/torch/distributed/tensor/experimental/_attention.py +++ b/torch/distributed/tensor/experimental/_attention.py @@ -1,3 +1,8 @@ +<<<<<<< HEAD +======= +# Copyright (c) Meta Platforms, Inc. and affiliates + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import contextlib import itertools import logging @@ -15,6 +20,7 @@ import torch.nn.functional as F from torch import nn from torch.distributed.device_mesh import DeviceMesh +<<<<<<< HEAD from torch.distributed.tensor import ( distribute_module, distribute_tensor, @@ -28,6 +34,10 @@ BlockMask, create_block_mask, ) +======= +from torch.distributed.tensor import distribute_module, DTensor, Replicate, Shard +from torch.distributed.tensor.parallel.style import ParallelStyle +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.overrides import TorchFunctionMode @@ -71,6 +81,7 @@ class _ContextParallelOptions: _cp_options = _ContextParallelOptions() +<<<<<<< HEAD @dataclass class _ContextParallelGlobalVars: # The current context parallel impl requires a record of some info @@ -90,6 +101,8 @@ def _set_cp_global_var(name: str, value: Any) -> None: setattr(_cp_global_vars, name, value) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _is_causal_behavior( rank: int, world_size: int, i: int, is_causal: bool ) -> _CausalBehavior: @@ -221,6 +234,109 @@ def results(self) -> tuple[torch.Tensor, torch.Tensor]: return out.to(self._out_dtype), lse.to(self._lse_dtype) +<<<<<<< HEAD +======= +def _scaled_dot_product_ring_flash_attention( + mesh: DeviceMesh, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dropout_p: float = 0.0, + is_causal: bool = False, + return_debug_mask: bool = False, + *, + scale: Optional[float] = None, +) -> tuple[torch.Tensor, ...]: + if return_debug_mask: + raise NotImplementedError("return_debug_mask is not supported yet") + + seq_dim = 2 + return _templated_ring_attention( + mesh, + seq_dim, + aten._scaled_dot_product_flash_attention, + query=query, + key=key, + value=value, + is_causal=is_causal, + dropout_p=dropout_p, + scale=scale, + ) + + +def _scaled_dot_product_ring_efficient_attention( + mesh: DeviceMesh, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_bias: Optional[torch.Tensor] = None, + compute_log_sumexp: bool = True, + dropout_p: float = 0.0, + is_causal: bool = False, + *, + scale: Optional[float] = None, +) -> tuple[torch.Tensor, ...]: + if attn_bias is not None: + raise NotImplementedError("attn_bias is not supported yet") + + if not compute_log_sumexp: + # CP requires compute_log_sumexp to be True because it always merges LSE + compute_log_sumexp = True + + seq_dim = 2 + return _templated_ring_attention( + mesh, + seq_dim, + aten._scaled_dot_product_efficient_attention, + query=query, + key=key, + value=value, + is_causal=is_causal, + attn_bias=attn_bias, + dropout_p=dropout_p, + scale=scale, + compute_log_sumexp=compute_log_sumexp, + ) + + +def _scaled_dot_product_ring_cudnn_attention( + mesh: DeviceMesh, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_bias: Optional[torch.Tensor] = None, + compute_log_sumexp: bool = True, + dropout_p: float = 0.0, + is_causal: bool = False, + return_debug_mask: bool = False, + *, + scale: Optional[float] = None, +) -> tuple[torch.Tensor, ...]: + if attn_bias is not None: + raise NotImplementedError("attn_bias is not supported yet") + + if not compute_log_sumexp: + # CP requires compute_log_sumexp to be True because it always merges LSE + compute_log_sumexp = True + + seq_dim = 2 + return _templated_ring_attention( + mesh, + seq_dim, + aten._scaled_dot_product_cudnn_attention, + query=query, + key=key, + value=value, + attn_bias=attn_bias, + compute_log_sumexp=compute_log_sumexp, + dropout_p=dropout_p, + is_causal=is_causal, + return_debug_mask=return_debug_mask, + scale=scale, + ) + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class _AttentionOp(Protocol): def __call__( self, @@ -263,7 +379,11 @@ def next_buffer(self) -> torch.Tensor: class _AllGatherRotater(_RingRotater): """ +<<<<<<< HEAD Allgather the kv and return the only the required kv. +======= + Allgather the kv and return the only the requried kv. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Only one communication will be done. """ @@ -301,11 +421,32 @@ def _create_rotater( elif method == _RotateMethod.ALL_GATHER: return _AllGatherRotater(pg, seq_dim) else: +<<<<<<< HEAD raise NotImplementedError(f"Unknown method {method}") def _templated_ring_attention( group: dist.ProcessGroup, +======= + raise NotImplementedError(f"Unkonwn method {method}") + + +def _ring_rotate( + block: torch.Tensor, pg: dist.ProcessGroup, send_to_next: bool +) -> torch.Tensor: + block = block.contiguous() + size = dist.get_world_size(pg) + dsts = ( + list(range(1, size)) + [0] + if send_to_next + else [size - 1] + list(range(0, size - 1)) + ) + return ft_c.permute_tensor(block, dsts, pg) + + +def _templated_ring_attention( + mesh: DeviceMesh, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) seq_dim: int, op: _AttentionOp, query: torch.Tensor, @@ -363,12 +504,20 @@ def _templated_ring_attention( First Iteration: Both ranks perform SDPA with their local qkv pairs, similar to the no-load-balance case. This iteration corresponds to the `if` of the +<<<<<<< HEAD (`if, `elif`, `else`) in the implementation. +======= + (`if, `elif`, `else`) in the implemementation. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Second Iteration: Rank0 now has (q0, q3) and (k1, k2); rank1 has (q1, q2) and (k0, k3). For rank0, no computation is needed for q0. However, computations for q3k1 and q3k2 are required, so only q3 is used for SDPA. This corresponds to the +<<<<<<< HEAD `else` of the (`if`, `elif`, `else`) in the implementation. +======= + `else` of the (`if`, `elif`, `else`) in the implemementation. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) For rank1, k0 is not needed for q1 and q2, so only k3 is used for SDPA. This corresponds to the `elif` of (`if`, `elif`, `else`) in the implementation. @@ -395,11 +544,21 @@ def _templated_ring_attention( if not is_causal and _cp_options.enable_load_balance: raise RuntimeError("Load balancing requires `is_causal=True`.") +<<<<<<< HEAD assert isinstance(group, dist.ProcessGroup), ( "process group must be single dimension" ) rank = dist.get_rank(group) size = dist.get_world_size(group) +======= + if isinstance(mesh, dist.ProcessGroup): + pg: Union[dist.ProcessGroup, list[dist.ProcessGroup]] = mesh + else: + pg = mesh.get_group() + assert isinstance(pg, dist.ProcessGroup), "process group must be single dimension" + rank = dist.get_rank(pg) + size = dist.get_world_size(pg) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) next_kv = None @@ -415,7 +574,11 @@ def _templated_ring_attention( out: torch.Tensor logsumexp: torch.Tensor +<<<<<<< HEAD rotater = _create_rotater(group, 2) +======= + rotater = _create_rotater(pg, 2) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for i in range(size): if i > 0: @@ -475,8 +638,100 @@ def _templated_ring_attention( return *sdpa_merger.results(), *rest +<<<<<<< HEAD def _templated_ring_attention_backward( group: dist.ProcessGroup, +======= +def _sdpa_handler( + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], +) -> object: + # extract local tensor and sharding infos to a OpInfo + op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs) + logger.debug("Dispatching op_call: %s", op_info.schema) + + # sharding propagation + # TODO: remove the context parallel strategy from the default propagation + # rule. Either figure out how to dynamically enable it or just don't call + # propagate. + DTensor._op_dispatcher.sharding_propagator.propagate(op_info) + output_sharding = op_info.output_sharding + assert output_sharding is not None, "output sharding should not be None" + assert not output_sharding.needs_redistribute, "inputs need to be redistributed" + + if op_call == aten._scaled_dot_product_flash_attention.default: + local_results = _scaled_dot_product_ring_flash_attention( + op_info.compute_mesh, + *op_info.local_args, # type: ignore[arg-type] + **op_info.local_kwargs, # type: ignore[arg-type] + ) + elif op_call == aten._scaled_dot_product_efficient_attention.default: + local_results = _scaled_dot_product_ring_efficient_attention( + op_info.compute_mesh, + *op_info.local_args, # type: ignore[arg-type] + **op_info.local_kwargs, # type: ignore[arg-type] + ) + elif op_call == aten._scaled_dot_product_cudnn_attention.default: + local_results = _scaled_dot_product_ring_cudnn_attention( + op_info.compute_mesh, + *op_info.local_args, # type: ignore[arg-type] + **op_info.local_kwargs, # type: ignore[arg-type] + ) + else: + raise NotImplementedError( + "CP only supports flash attention and memory efficient attention now." + ) + + return DTensor._op_dispatcher.wrap(local_results, output_sharding.output_spec) + + +def _sdpa_backward_handler( + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], +) -> object: + # Redistribute grad_output tensor to the same placement as output tensor + args = list(args) + args = tuple(args) + + # extract local tensor and sharding infos to a OpInfo + op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs) + logger.debug("Dispatching op_call: %s", op_info.schema) + + # sharding propagation + DTensor._op_dispatcher.sharding_propagator.propagate(op_info) + output_sharding = op_info.output_sharding + assert output_sharding is not None, "output sharding should not be None" + assert not output_sharding.needs_redistribute, "inputs need to be redistributed" + + if op_call == aten._scaled_dot_product_flash_attention_backward.default: + local_results = _scaled_dot_product_ring_flash_attention_backward( + op_info.compute_mesh, + *op_info.local_args, # type: ignore[arg-type] + **op_info.local_kwargs, # type: ignore[arg-type] + ) + elif op_call == aten._scaled_dot_product_efficient_attention_backward.default: + local_results = _scaled_dot_product_ring_efficient_attention_backward( + op_info.compute_mesh, + *op_info.local_args, # type: ignore[arg-type] + **op_info.local_kwargs, # type: ignore[arg-type] + ) + elif op_call == aten._scaled_dot_product_cudnn_attention_backward.default: + local_results = _scaled_dot_product_ring_cudnn_attention_backward( + op_info.compute_mesh, + *op_info.local_args, # type: ignore[arg-type] + **op_info.local_kwargs, # type: ignore[arg-type] + ) + else: + raise NotImplementedError(f"{op_call=}") + + return DTensor._op_dispatcher.wrap(local_results, output_sharding.output_spec) + + +def _templated_ring_attention_backward( + mesh: DeviceMesh, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) seq_dim: int, op: _AttentionOp, grad_out: torch.Tensor, @@ -492,8 +747,15 @@ def _templated_ring_attention_backward( """This API implements the backward of the ring attention.""" if not is_causal and _cp_options.enable_load_balance: raise RuntimeError("Load balancing requires `is_causal=True`.") +<<<<<<< HEAD rank = dist.get_rank(group) size = dist.get_world_size(group) +======= + pg = mesh.get_group() + assert isinstance(pg, dist.ProcessGroup), "must be single dimension" + rank = dist.get_rank(pg) + size = dist.get_world_size(pg) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) next_kv = None next_grad_kv = None rest: list[Any] @@ -506,8 +768,13 @@ def _templated_ring_attention_backward( key = key.contiguous() value = value.contiguous() +<<<<<<< HEAD kv_rotater = _create_rotater(group, 2) dkv_rotater = _create_rotater(group, 2, method=_RotateMethod.ALL_TO_ALL) +======= + kv_rotater = _create_rotater(pg, 2) + dkv_rotater = _create_rotater(pg, 2, method=_RotateMethod.ALL_TO_ALL) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for i in range(size): if i > 0: # Wait for the kv from the (cp_rank - 1) rank. @@ -642,6 +909,7 @@ def _templated_ring_attention_backward( ) +<<<<<<< HEAD def _scaled_dot_product_ring_flash_attention( mesh: DeviceMesh, query: torch.Tensor, @@ -748,6 +1016,8 @@ def _scaled_dot_product_ring_cudnn_attention( ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _scaled_dot_product_ring_flash_attention_backward( mesh: DeviceMesh, grad_out: torch.Tensor, @@ -767,11 +1037,17 @@ def _scaled_dot_product_ring_flash_attention_backward( *, scale: Optional[float] = None, ) -> tuple[torch.Tensor, ...]: +<<<<<<< HEAD # TODO: remove this hardcoding seq_dim = 2 group = mesh.get_group() return _templated_ring_attention_backward( group, +======= + seq_dim = 2 + return _templated_ring_attention_backward( + mesh, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) seq_dim, aten._scaled_dot_product_flash_attention_backward.default, grad_out=grad_out, @@ -810,11 +1086,17 @@ def _scaled_dot_product_ring_efficient_attention_backward( *, scale: Optional[float] = None, ) -> tuple[torch.Tensor, ...]: +<<<<<<< HEAD # TODO: remove this hardcoding seq_dim = 2 group = mesh.get_group() return _templated_ring_attention_backward( group, +======= + seq_dim = 2 + return _templated_ring_attention_backward( + mesh, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) seq_dim, aten._scaled_dot_product_efficient_attention_backward.default, grad_out=grad_out, @@ -854,11 +1136,17 @@ def _scaled_dot_product_ring_cudnn_attention_backward( *, scale: Optional[float] = None, ) -> tuple[torch.Tensor, ...]: +<<<<<<< HEAD # TODO: remove this hardcoding seq_dim = 2 group = mesh.get_group() return _templated_ring_attention_backward( group, +======= + seq_dim = 2 + return _templated_ring_attention_backward( + mesh, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) seq_dim, aten._scaled_dot_product_cudnn_attention_backward.default, grad_out=grad_out, @@ -881,6 +1169,7 @@ def _scaled_dot_product_ring_cudnn_attention_backward( ) +<<<<<<< HEAD def _sdpa_handler( op_call: torch._ops.OpOverload, args: tuple[object, ...], @@ -928,6 +1217,15 @@ def _sdpa_handler( aten._scaled_dot_product_efficient_attention_backward.default: _sdpa_handler, aten._scaled_dot_product_cudnn_attention.default: _sdpa_handler, aten._scaled_dot_product_cudnn_attention_backward.default: _sdpa_handler, +======= +customized_ops = { + aten._scaled_dot_product_flash_attention.default: _sdpa_handler, + aten._scaled_dot_product_flash_attention_backward.default: _sdpa_backward_handler, + aten._scaled_dot_product_efficient_attention.default: _sdpa_handler, + aten._scaled_dot_product_efficient_attention_backward.default: _sdpa_backward_handler, + aten._scaled_dot_product_cudnn_attention.default: _sdpa_handler, + aten._scaled_dot_product_cudnn_attention_backward.default: _sdpa_backward_handler, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } @@ -946,7 +1244,11 @@ def _distribute_function( the inputs and outputs of a function. Similar to ``distribute_module``, this API installs hooks to the ``fn`` to convert the inputs and outputs. There are two major differences between ``distribute_function`` and ``distribute_module``. +<<<<<<< HEAD First, a function does not have parameters and buffers, as a result, +======= + First, a function does not have parammeters and buffers, as a result, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ``distribute_function`` itself won't convert any parameters/buffers but simply install the input and output hooks. The tensor conversion will happen in the hooks. Another difference is an nn.Module subclass can have several instances and each @@ -962,9 +1264,15 @@ def _distribute_function( ``fn_module`` is ``torch.nn.functional``. device_mesh (:class:`DeviceMesh`): the device mesh that will be used by the input and output hooks to distribute the tensors. +<<<<<<< HEAD input_fn (Optional[Callable]): the hook to distribute or convert the input arguments of ``fn``. output_fn (Optional[Callable]): the hook to distribute or convert the output +======= + input_fn (Optioinal[Callable]): the hook to distribute or convert the input + arguments of ``fn``. + output_fn (Optioinal[Callable]): the hook to distribute or convert the output +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) arguments of ``fn``. """ @@ -1019,7 +1327,11 @@ class _AttentionContextParallel(ParallelStyle): Applies context parallel optimizations to the attention layer. This will work for nn.MultiHeadedAttention and custom attention layers that +<<<<<<< HEAD call F.scaled_dotproduct_attention with a similar signature. +======= + call F.scaled_dotproduct_attention with a simliar signature. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) This expects the `forward` method consumes either: @@ -1038,8 +1350,18 @@ class _AttentionContextParallel(ParallelStyle): ) def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: +<<<<<<< HEAD if not device_mesh.ndim == 1: raise ValueError("CP only supports single dimension device mesh") +======= + if not isinstance(device_mesh, DeviceMesh): + raise ValueError( + f"{type(device_mesh)} is not supported by {type(self)} yet." + ) + + if not device_mesh.ndim == 1: + raise ValueError +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return distribute_module( module, @@ -1116,6 +1438,7 @@ def backward_hook(grad: torch.Tensor) -> None: return tuple(out) +<<<<<<< HEAD def create_cp_block_mask( mask_mod: _mask_mod_signature, B: int, @@ -1202,6 +1525,8 @@ def local_q_idx_to_q_idx(local_q_idx: torch.Tensor) -> torch.Tensor: return block_mask +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @contextlib.contextmanager def _context_parallel(seq_dim: int, mesh: DeviceMesh) -> Generator[None, None, None]: """Replace SDPA with the CP-wrapped version and enable DTensor CP dispatcher.""" @@ -1233,12 +1558,15 @@ def attention_output_fn(mesh: DeviceMesh, outputs: Any) -> Any: return tuple(new_outputs) +<<<<<<< HEAD def unshard(x: torch.Tensor, mesh: DeviceMesh, shard_dim: int) -> torch.Tensor: x = x.contiguous() all_xs = [torch.empty_like(x) for _ in range(mesh.size())] ft_c.all_gather_inplace(all_xs, x, mesh) return torch.cat(all_xs, dim=shard_dim) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class DistributeFunction(TorchFunctionMode): def __init__( self, @@ -1261,6 +1589,7 @@ def __torch_function__( ) -> Any: kwargs = kwargs or {} +<<<<<<< HEAD # special handler for flex_attention if func == torch._higher_order_ops.flex_attention: query, key, value, score_mod, block_mask = args[:5] @@ -1293,6 +1622,8 @@ def __torch_function__( **kwargs, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if func != self._fn: return func(*args, **kwargs) @@ -1315,6 +1646,7 @@ def __torch_function__( yield _restore_function(F.scaled_dot_product_attention, F) elif _dispatch_mode == _DispatchMode.TORCH_FUNCTION: +<<<<<<< HEAD tf_mode = _cp_global_vars.torch_function_mode if tf_mode is None: tf_mode = DistributeFunction( @@ -1326,12 +1658,21 @@ def __torch_function__( _set_cp_global_var("torch_function_mode", tf_mode) with tf_mode: +======= + with DistributeFunction( + F.scaled_dot_product_attention, + mesh, + attention_input_fn, + attention_output_fn, + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with _enable_cp_dispatcher(): yield else: raise NotImplementedError("torch dispatch mode is not supported yet.") +<<<<<<< HEAD def _generate_round_robin_indices( seq_length: int, cp_world_size: int, @@ -1373,12 +1714,100 @@ def _generate_round_robin_indices( if restore: all_indices_tensor = torch.argsort(all_indices_tensor) return all_indices_tensor +======= +class _LoadBalancer(ABC): + @classmethod + @abstractmethod + def shard( + cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int + ) -> torch.Tensor: ... + + @classmethod + @abstractmethod + def unshard( + cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int + ) -> torch.Tensor: ... + + +class _SequentialSharder(_LoadBalancer): + """ + This load balancer chunks the buffer into cp_world_size and rank0 gets + 0th shard, rank1 gets 1st shard, ... + So this doesn't have any load balancing effect when using the causal masking. + """ + + @classmethod + def shard( + cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int + ) -> torch.Tensor: + assert buffer.size()[seq_dim] % mesh.size() == 0 + return buffer.chunk(mesh.size(), dim=seq_dim)[mesh.get_local_rank()] + + @classmethod + def unshard( + cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int + ) -> torch.Tensor: + buffer = buffer.contiguous() + all_buffers = [torch.empty_like(buffer) for _ in range(mesh.size())] + ft_c.all_gather_inplace(all_buffers, buffer, mesh) + return torch.cat(all_buffers, dim=seq_dim) + + +class _RoundRobinLoadBalancer(_LoadBalancer): + """ + This load balancer chunk the buffer into cp_world_size * ROUND_ROBIN_CYCLE + shards, and uses a round robin approach to achieve load balancing. + Since ROUND_ROBIN_CYCLE being 2 will achieve perfect load balancing for + causal masking, we assume ROUND_ROBIN_CYCLE is always 2 to simplify the + implementation. + """ + + ROUND_ROBIN_CYCLE = 2 + + @classmethod + def shard( + cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int + ) -> torch.Tensor: + assert cls.ROUND_ROBIN_CYCLE == 2, ( + "The current implementation only works if ROUND_ROBIN_CYCLE is 2." + ) + cp_world_size = mesh.size() + cp_rank = mesh.get_local_rank() + assert buffer.size()[seq_dim] % (cp_world_size * 2) == 0 + chunks = buffer.chunk(cp_world_size * 2, dim=seq_dim) + return torch.cat( + (chunks[cp_rank], chunks[cp_world_size * 2 - cp_rank - 1]), + dim=seq_dim, + ) + + @classmethod + def unshard( + cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int + ) -> torch.Tensor: + assert cls.ROUND_ROBIN_CYCLE == 2, ( + "The current implementation only works if ROUND_ROBIN_CYCLE is 2." + ) + buffer = buffer.contiguous() + cp_world_size = mesh.size() + + all_buffers = [torch.empty_like(buffer) for _ in range(cp_world_size)] + ft_c.all_gather_inplace(all_buffers, buffer, mesh) + sliced_buffers = [sb for b in all_buffers for sb in b.chunk(2, dim=seq_dim)] + ordered_buffers = list(sliced_buffers) + for i, b in enumerate(sliced_buffers): + if i % 2 == 0: + ordered_buffers[i // 2] = b + else: + ordered_buffers[cp_world_size * 2 - (i // 2) - 1] = b + return torch.cat(ordered_buffers, dim=seq_dim) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _context_parallel_buffers( mesh: DeviceMesh, buffers: list[torch.Tensor], buffer_seq_dims: list[int], +<<<<<<< HEAD load_balance_indices: Optional[torch.Tensor] = None, ) -> list[torch.Tensor]: """Shard the buffers along the sequence dimensions according to CP rules.""" @@ -1392,6 +1821,18 @@ def _context_parallel_buffers( buffer, mesh, [Shard(seq_dim)], src_data_rank=None ).to_local() new_buffers.append(sharded_buffer) +======= +) -> list[torch.Tensor]: + """Shard the buffers along the sequence dimensions according to CP rules.""" + new_buffers = [] + sharder = ( + _RoundRobinLoadBalancer + if _cp_options.enable_load_balance + else _SequentialSharder + ) + for buffer, seq_dim in zip(buffers, buffer_seq_dims): + new_buffers.append(sharder.shard(buffer, mesh, seq_dim)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return new_buffers @@ -1449,6 +1890,7 @@ def context_parallel( raise ValueError("`no_restore_buffers` must be a subset of `buffers`.") original_buffers = [None if b in no_restore_buffers else b.clone() for b in buffers] +<<<<<<< HEAD device = buffers[0].device seq_length = buffers[0].shape[buffer_seq_dims[0]] @@ -1468,6 +1910,13 @@ def context_parallel( shard = shard.clone() buffer.resize_(shard.shape) buffer.copy_(shard) +======= + chunks = _context_parallel_buffers(mesh, buffers, buffer_seq_dims) + for buffer, chunk in zip(buffers, chunks): + chunk = chunk.clone() + buffer.resize_(chunk.shape) + buffer.copy_(chunk) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with _context_parallel(seq_dim=2, mesh=mesh): yield @@ -1496,6 +1945,7 @@ def context_parallel_unshard( Returns: List[torch.Tensor]: the unsharded buffers. """ +<<<<<<< HEAD if _cp_options.enable_load_balance: device = buffers[0].device cp_world_size = mesh.size() @@ -1520,6 +1970,14 @@ def context_parallel_unshard( unsharded_buffers.append(unsharded_b) return unsharded_buffers +======= + sharder = ( + _RoundRobinLoadBalancer + if _cp_options.enable_load_balance + else _SequentialSharder + ) + return [sharder.unshard(b, mesh, dim) for b, dim in zip(buffers, seq_dims)] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def set_rotate_method(rotate_method: str) -> None: diff --git a/torch/distributed/tensor/experimental/_func_map.py b/torch/distributed/tensor/experimental/_func_map.py index 31cdd0f9a06fc..1f6a988a855e1 100644 --- a/torch/distributed/tensor/experimental/_func_map.py +++ b/torch/distributed/tensor/experimental/_func_map.py @@ -24,10 +24,17 @@ def local_map( +<<<<<<< HEAD func: Optional[Callable] = None, out_placements: OutputPlacements = None, in_placements: InputPlacements = None, in_grad_placements: InputPlacements = None, +======= + func: Callable, + out_placements: OutputPlacements, + in_placements: Optional[InputPlacements] = None, + in_grad_placements: Optional[InputPlacements] = None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device_mesh: Optional[DeviceMesh] = None, *, redistribute_inputs: bool = False, @@ -76,11 +83,18 @@ def local_map( tensor input remains the same as the original :class:`DTensor` input and use that for gradient computation. Default: None. device_mesh (:class:`DeviceMesh`, optional): +<<<<<<< HEAD the device mesh that the output :class:`DTensor` s are placed on. If not specified, this will be inferred from the first input :class:`DTensor`'s device mesh. Default: None. Keyword Args: +======= + the device mesh that all the :class:`DTensor` s are placed on. If not + specified, this will be inferred from the input :class:`DTensor` s' device + mesh. `local_map` requires every :class:`DTensor` s to be placed on the same + device mesh. Default: None. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) redistribute_inputs (bool, optional): the bool value indicating whether to reshard the input :class:`DTensor` s when their placements are different from the required input placements. If this @@ -92,6 +106,13 @@ def local_map( and returns a :class:`DTensor` constructed from the return value of ``func``. Raises: +<<<<<<< HEAD +======= + AssertionError: If the input :class:`DTensor` is not placed on the same device + mesh, or if they are placed on a different device mesh than the ``device_mesh`` + argument passed in. + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) AssertionError: For any non-DTensor output, we require its corresponding output placement in ``out_placements`` be None. An AssertionError will be raised if this is not the case. @@ -112,7 +133,11 @@ def local_map( >>> row_wise = [Shard(0)] # row-wise sharding placements on 1-d mesh >>> col_wise = [Shard(1)] # col-wise sharding placements on 1-d mesh >>> +<<<<<<< HEAD >>> # local_mm_allreduce_forward is the function wrapped with DTensor/Tensor conversion +======= + >>> # local_mm_allreduce_forward is the function wrapped with DTensor/Tensor convertion +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> local_mm_allreduce_forward = local_map( >>> mm_allreduce_forward, >>> out_placements=[Replicate()], @@ -133,6 +158,7 @@ def local_map( .. note:: This API is currently experimental and subject to change """ +<<<<<<< HEAD if func is None: # decorator mode def decorated(func): @@ -274,3 +300,121 @@ def _local_map_wrapped( return pytree.tree_unflatten(flat_dist_out, out_spec) else: return out +======= + def wrapped(device_mesh: Optional[DeviceMesh], *args, **kwargs): + # process input args + flat_args, args_spec = pytree.tree_flatten(args) + if in_placements is not None: + assert len(in_placements) == len(flat_args), ( + f"in_placements length {len(in_placements)} does not match the number " + f"of input args {len(flat_args)}!" + ) + + # we assume every DTensor object is placed on the same device mesh + flat_local_args = [] + seen_dtensor_arg = False + for idx, arg in enumerate(flat_args): + if isinstance(arg, DTensor): + # TODO: the current code doesn't consider the uneven sharding case + # Need to think about what the consequence is when the input DTensor + # is uneven sharded. + if device_mesh is None: # infer device mesh from the DTensor arg + device_mesh = arg.device_mesh + + # this function is applied to at least one DTensor argument + seen_dtensor_arg = True + + assert arg.device_mesh == device_mesh, ( + f"arg {arg} in local_map has a mismatched device mesh: " + f"{arg} has device mesh {arg.device_mesh} while " + f"the expected device mesh is {device_mesh}!" + ) + if in_placements is not None: + spec = in_placements[idx] + assert spec is not None, ( + f"DTensor input {arg} expects placements but received {spec}!" + ) + + if not isinstance(spec, tuple): + spec = tuple(spec) + + if arg.placements != spec: + if redistribute_inputs: + # redistribute to input placements + arg = arg.redistribute(device_mesh, spec) + else: + raise ValueError( + f"arg {arg} in local_map has a mismatched placements: " + f"arg placements is {arg.placements} but the input " + f"placements is {spec}! " + "If redistribute_inputs is wanted, set " + "redistribute_inputs=True to local_map." + ) + + if in_grad_placements is not None: + spec = in_grad_placements[idx] + assert spec is not None, ( + f"DTensor input {arg} expects in grad placements but received {spec}!" + ) + if not isinstance(spec, tuple): + spec = tuple(spec) + local_arg = arg.to_local(grad_placements=spec) + else: + local_arg = arg.to_local() + + if isinstance(local_arg, AsyncCollectiveTensor): + local_arg = local_arg.wait() + + flat_local_args.append(local_arg) + else: + # Non-Tensor input must have None in `in_placements` + if in_placements is not None and not isinstance(arg, torch.Tensor): + spec = in_placements[idx] + assert spec is None, ( + f"Non-Tensor input {arg} expects None placements " + f"but received {spec}!" + ) + + flat_local_args.append(arg) + + local_args = pytree.tree_unflatten(flat_local_args, args_spec) + + out = func(*local_args, **kwargs) + + if seen_dtensor_arg: + # process output + flat_out, out_spec = pytree.tree_flatten(out) + + flat_dist_out = [] + out_placements_tuple = ( + out_placements + if isinstance(out_placements, tuple) + else (out_placements,) + ) + assert len(flat_out) == len(out_placements_tuple), ( + "local_map requires one PlacementType be provided for each output value," + f" received {len(out_placements_tuple)} out_placements but" + f" {len(flat_out)} is expected!" + ) + for out, spec in zip(flat_out, out_placements_tuple): + if isinstance(out, torch.Tensor): + assert not isinstance(out, DTensor), ( + f"torch.Tensor output expected but received {type(out)}: {out}" + ) + + flat_dist_out.append( + DTensor.from_local(out, device_mesh, spec, run_check=False) + ) + else: + assert spec is None, ( + f"Non-tensor output {out} expects None placements but received {spec}!" + ) + + flat_dist_out.append(out) + + return pytree.tree_unflatten(flat_dist_out, out_spec) + else: + return out + + return functools.partial(wrapped, device_mesh) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/distributed/tensor/experimental/_register_sharding.py b/torch/distributed/tensor/experimental/_register_sharding.py index b286b151efed5..73fed5f6cbea8 100644 --- a/torch/distributed/tensor/experimental/_register_sharding.py +++ b/torch/distributed/tensor/experimental/_register_sharding.py @@ -41,7 +41,11 @@ def register_sharding(op: Union[OpOverload, list[OpOverload]]): as the original op (except that if an arg is a :class:`torch.Tensor`, it will be replaced by a tensor-like object that DTensor uses internally). The function should return a sequence of 2-tuples, each specifying acceptable output placements and its +<<<<<<< HEAD corresponding input placements. +======= + corresponding intput placements. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Example: >>> # xdoctest: +SKIP("distributed") @@ -77,7 +81,11 @@ def strategy_to_spec(strategy: object) -> object: # take the output spec from the first strategy return strategy.strategies[0].output_spec elif isinstance(strategy, TupleStrategy): +<<<<<<< HEAD return tuple(strategy_to_spec(s) for s in strategy.children) +======= + return tuple(strategy_to_spec(s) for s in strategy.childs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: return strategy diff --git a/torch/distributed/tensor/parallel/_data_parallel_utils.py b/torch/distributed/tensor/parallel/_data_parallel_utils.py index c41da260a02f9..3c0520a5d796a 100644 --- a/torch/distributed/tensor/parallel/_data_parallel_utils.py +++ b/torch/distributed/tensor/parallel/_data_parallel_utils.py @@ -30,7 +30,11 @@ def _flatten_tensor( @no_type_check def _unflatten_tensor(tensor, spec, *, device_handle=None, compute_stream=None): +<<<<<<< HEAD # unflatten would mainly be called every time FSDP allgather parameters. +======= + # unflatten would mainly be called everytime FSDP allgather parameters. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) result = DTensor.from_local( tensor, spec.mesh, diff --git a/torch/distributed/tensor/parallel/_utils.py b/torch/distributed/tensor/parallel/_utils.py new file mode 100644 index 0000000000000..0a78872f57d8b --- /dev/null +++ b/torch/distributed/tensor/parallel/_utils.py @@ -0,0 +1,67 @@ +# mypy: allow-untyped-defs +import warnings +from typing import Union + +from torch.distributed.device_mesh import _mesh_resources +from torch.distributed.tensor import DeviceMesh +from torch.distributed.tensor.placement_types import Placement + + +try: + from torch._dynamo.external_utils import is_compiling as is_torchdynamo_compiling +except Exception: + + def is_torchdynamo_compiling(): # type: ignore[misc] + return False + + +LayoutsType = Union[Placement, tuple[Placement, ...]] + + +def _deprecate_warnings(func_name: str, extra_msg: str) -> None: + """ + Inject common validation logics for `_prepare_input` funcs via this decorator. + + Include verifying that input needs to be either a :class:`Tensor` or :class:`DTensor` + and only 1D :class:`DeviceMesh` is passed in. + """ + # TODO: Will follow up with dynamo POC to make warnings.warn working with dynamo. + if not is_torchdynamo_compiling(): + warnings.warn( + f"{func_name} is deprecated and will be removed soon. {extra_msg}", + FutureWarning, + stacklevel=3, + ) + + +def _validate_tp_mesh_dim( + device_mesh: DeviceMesh, +) -> None: + """ + Check whether TP mesh dimension is valid or not. + + Args: + device_mesh (:class:`DeviceMesh`): + The `device_mesh` where we perform + Tensor Parallelism on. + + Return: + `True` if the mesh dimension + is valid, `False` otherwise. + """ + if device_mesh.ndim > 1: + raise ValueError( + f"Tensor Parallel only accepts a 1D DeviceMesh, but found {device_mesh.ndim}D!" + 'If you have a 2-D or N-D device_mesh, consider passing in device_mesh["tp"]' + ) + + root_mesh = _mesh_resources.get_root_mesh(device_mesh) + # if a root mesh is not the same as device_mesh, + # meaning the device_mesh is sliced out from the root mesh. + if root_mesh and root_mesh != device_mesh: + tp_mesh_dim_in_root = _mesh_resources.get_root_mesh_dim(device_mesh) + if tp_mesh_dim_in_root != root_mesh.ndim - 1: + raise RuntimeError( + f"Found TP device_mesh on the {tp_mesh_dim_in_root} dimension of its parent mesh.", + "Currently we only support intranode TP and TP needs to be the innermost dimension on its parent mesh.", + ) diff --git a/torch/distributed/tensor/parallel/api.py b/torch/distributed/tensor/parallel/api.py index 2a3369a8edda0..e6d62475e2cc3 100644 --- a/torch/distributed/tensor/parallel/api.py +++ b/torch/distributed/tensor/parallel/api.py @@ -6,6 +6,10 @@ import torch import torch.nn as nn from torch.distributed.device_mesh import _mesh_resources, DeviceMesh +<<<<<<< HEAD +======= +from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.distributed.tensor.parallel.style import ParallelStyle @@ -70,6 +74,10 @@ def parallelize_module( # type: ignore[return] torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module") device_mesh = device_mesh or _mesh_resources.get_current_mesh() +<<<<<<< HEAD +======= + _validate_tp_mesh_dim(device_mesh) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if parallelize_plan is None: warnings.warn( @@ -86,6 +94,7 @@ def parallelize_module( # type: ignore[return] return parallelize_plan._apply(module, device_mesh) elif isinstance(parallelize_plan, dict): for module_path, parallelize_style in parallelize_plan.items(): +<<<<<<< HEAD if module_path == "": # shortcut: empty string means to apply the plan to the current module parallelize_module(module, device_mesh, parallelize_style) @@ -133,6 +142,41 @@ def parallelize_module( # type: ignore[return] parallelize_style, src_data_rank=src_data_rank, ) +======= + path_splits = module_path.split(".") + if len(path_splits) == 0: + raise ValueError( + "Expect module path to be non-empty, but got empty string!" + ) + while path_splits: + atom = path_splits.pop(0) + matched_children = filter( + # `t[0]` is child name + lambda t: fnmatch(t[0], atom), + module.named_children(), + ) + # apply the plan to all matched submodules + for _, submodule in matched_children: + if path_splits: + # we haven't reached the leaf, apply in dict style + leaf_path = ".".join( + path_splits + ) # rest of the path after `atom` + parallelize_module( + submodule, + device_mesh, + {leaf_path: parallelize_style}, + src_data_rank=src_data_rank, + ) + else: + # otherwise, directly apply style to this submodule + parallelize_module( + submodule, + device_mesh, + parallelize_style, + src_data_rank=src_data_rank, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return module else: raise TypeError( # pyre-ignore[7] diff --git a/torch/distributed/tensor/parallel/ddp.py b/torch/distributed/tensor/parallel/ddp.py index 7b19f97675197..881ced4561dfd 100644 --- a/torch/distributed/tensor/parallel/ddp.py +++ b/torch/distributed/tensor/parallel/ddp.py @@ -36,7 +36,11 @@ def _update_module_param(param_list: list[tuple[nn.Module, str, nn.Parameter]]): def _reconstruct_dtensor(module: nn.Module, _input: Any): """ +<<<<<<< HEAD Reconstruct DTensor parameters from local tensors +======= + Recontruct DTensor parameters from local tensors +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ param_list = [] # TODO: To add perf optimizations to this iterations diff --git a/torch/distributed/tensor/parallel/fsdp.py b/torch/distributed/tensor/parallel/fsdp.py index 1b0b8cac7c760..773859d1af6bf 100644 --- a/torch/distributed/tensor/parallel/fsdp.py +++ b/torch/distributed/tensor/parallel/fsdp.py @@ -326,7 +326,11 @@ def __init__(self, device_handle) -> None: super().__init__() self.compute_stream = None self.device_handle = device_handle +<<<<<<< HEAD # we have to use the dynamo disable this way to disable dynamo as the decorator way would +======= + # we have to use the dynamo disable this way to disable dynamo as the decorater way would +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # trigger build failure with torch deploy... self.post_unflatten_transform = torch._dynamo.disable( # type: ignore[method-assign] self.post_unflatten_transform diff --git a/torch/distributed/tensor/parallel/loss.py b/torch/distributed/tensor/parallel/loss.py index 32a90bc8f1fb3..4842ce1e66b22 100644 --- a/torch/distributed/tensor/parallel/loss.py +++ b/torch/distributed/tensor/parallel/loss.py @@ -112,7 +112,11 @@ def _propagate_tensor_meta( kwargs: dict[str, object], ) -> TensorMeta: op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs) +<<<<<<< HEAD tensor_meta = DTensor._op_dispatcher.sharding_propagator.propagate_tensor_meta( +======= + tensor_meta = DTensor._op_dispatcher.sharding_propagator._propagate_tensor_meta( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) op_info.schema ) if isinstance(tensor_meta, TensorMeta): diff --git a/torch/distributed/tensor/placement_types.py b/torch/distributed/tensor/placement_types.py index b37d49bd30744..604e07ee62b08 100644 --- a/torch/distributed/tensor/placement_types.py +++ b/torch/distributed/tensor/placement_types.py @@ -41,10 +41,15 @@ def is_shard(self, dim: Optional[int] = None) -> bool: def is_replicate(self) -> bool: return isinstance(self, Replicate) +<<<<<<< HEAD def is_partial(self, reduce_op: Optional[str] = None) -> bool: if reduce_op is None: return isinstance(self, Partial) return isinstance(self, Partial) and self.reduce_op == reduce_op +======= + def is_partial(self) -> bool: + return isinstance(self, Partial) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclass(frozen=True) @@ -701,7 +706,11 @@ def _partition_value( # _partition_value: partition the value of a replicated tensor on the mesh dimension # _partition_value is the conjugate operation of _reduce_value +<<<<<<< HEAD # - i.e. _partition_value on a sum reduce op is just a division operation +======= + # - i.e. _partition_value on a sum reduce op is just a divison operation +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # - the _reduce_value on a sum reduce op would just be a sum(allreduce) operation # TODO: if the reduce_op is min/max, etc. the _partition_value should be a # different operation diff --git a/torch/distributions/exp_family.py b/torch/distributions/exp_family.py index ab8d340bd7931..e72d3157c553d 100644 --- a/torch/distributions/exp_family.py +++ b/torch/distributions/exp_family.py @@ -1,6 +1,9 @@ # mypy: allow-untyped-defs +<<<<<<< HEAD from typing import Union +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch from torch import Tensor from torch.distributions.distribution import Distribution @@ -57,7 +60,11 @@ def entropy(self): """ Method to compute the entropy using Bregman divergence of the log normalizer. """ +<<<<<<< HEAD result: Union[Tensor, float] = -self._mean_carrier_measure +======= + result = -self._mean_carrier_measure +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) nparams = [p.detach().requires_grad_() for p in self._natural_params] lg_normal = self._log_normalizer(*nparams) gradients = torch.autograd.grad(lg_normal.sum(), nparams, create_graph=True) diff --git a/torch/distributions/lkj_cholesky.py b/torch/distributions/lkj_cholesky.py index f3fc4b20751e4..de3a020c8fc62 100644 --- a/torch/distributions/lkj_cholesky.py +++ b/torch/distributions/lkj_cholesky.py @@ -130,7 +130,11 @@ def log_prob(self, value): # Additionally, the Jacobian of the transformation from Cholesky factor to # correlation matrix is: # prod(L_ii ^ (D - i)) +<<<<<<< HEAD # So the probability of a Cholesky factor is proportional to +======= + # So the probability of a Cholesky factor is propotional to +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # prod(L_ii ^ (2 * concentration - 2 + D - i)) = prod(L_ii ^ order_i) # with order_i = 2 * concentration - 2 + D - i if self._validate_args: diff --git a/torch/distributions/transformed_distribution.py b/torch/distributions/transformed_distribution.py index 1724b586b5a76..6f33482e960b9 100644 --- a/torch/distributions/transformed_distribution.py +++ b/torch/distributions/transformed_distribution.py @@ -170,7 +170,11 @@ def log_prob(self, value): if self._validate_args: self._validate_sample(value) event_dim = len(self.event_shape) +<<<<<<< HEAD log_prob: Union[Tensor, float] = 0.0 +======= + log_prob = 0.0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) y = value for transform in reversed(self.transforms): x = transform.inv(y) diff --git a/torch/distributions/transforms.py b/torch/distributions/transforms.py index 9584bb0b342d1..0d96e472715b6 100644 --- a/torch/distributions/transforms.py +++ b/torch/distributions/transforms.py @@ -840,7 +840,11 @@ def inverse_shape(self, shape): class CorrCholeskyTransform(Transform): r""" +<<<<<<< HEAD Transforms an unconstrained real vector :math:`x` with length :math:`D*(D-1)/2` into the +======= + Transforms an uncontrained real vector :math:`x` with length :math:`D*(D-1)/2` into the +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Cholesky factor of a D-dimension correlation matrix. This Cholesky factor is a lower triangular matrix with positive diagonals and unit Euclidean norm for each row. The transform is processed as follows: @@ -907,7 +911,11 @@ def forward_shape(self, shape): N = shape[-1] D = round((0.25 + 2 * N) ** 0.5 + 0.5) if D * (D - 1) // 2 != N: +<<<<<<< HEAD raise ValueError("Input is not a flattened lower-diagonal number") +======= + raise ValueError("Input is not a flattend lower-diagonal number") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return shape[:-1] + (D, D) def inverse_shape(self, shape): diff --git a/torch/export/__init__.py b/torch/export/__init__.py index 621cabf15a3b8..73dba4e239db9 100644 --- a/torch/export/__init__.py +++ b/torch/export/__init__.py @@ -1,3 +1,4 @@ +<<<<<<< HEAD import logging import os import warnings @@ -34,6 +35,63 @@ "ShapesCollection", "unflatten", "UnflattenedModule", +======= +import builtins +import copy +import dataclasses +import inspect +import os +import sys +import typing +import warnings +import zipfile +from collections.abc import Iterator +from enum import auto, Enum +from typing import Any, Callable, Optional, TYPE_CHECKING, Union + +import torch +import torch.utils._pytree as pytree +from torch.fx._compatibility import compatibility +from torch.fx.passes.infra.pass_base import PassResult +from torch.fx.passes.infra.pass_manager import PassManager +from torch.types import FileLike +from torch.utils._pytree import ( + FlattenFunc, + FromDumpableContextFn, + ToDumpableContextFn, + UnflattenFunc, +) + + +if TYPE_CHECKING: + # Import the following modules during type checking to enable code intelligence features, + # Do not import unconditionally, as they import sympy and importing sympy is very slow + from torch._ops import OpOverload + from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint + + +__all__ = [ + "Constraint", + "Dim", + "ExportBackwardSignature", + "ExportGraphSignature", + "ExportedProgram", + "CustomDecompTable", + "ModuleCallEntry", + "ModuleCallSignature", + "default_decompositions", + "dims", + "export", + "export_for_training", + "load", + "register_dataclass", + "save", + "unflatten", + "FlatArgsAdapter", + "UnflattenedModule", + "AdditionalInputs", + "draft_export", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] # To make sure export specific custom ops are loaded @@ -53,6 +111,7 @@ PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]] +<<<<<<< HEAD log: logging.Logger = logging.getLogger(__name__) @@ -70,6 +129,17 @@ def export_for_training( strict: bool = False, preserve_module_call_signature: tuple[str, ...] = (), prefer_deferred_runtime_asserts_over_guards: bool = False, +======= + +def export_for_training( + mod: torch.nn.Module, + args: tuple[Any, ...], + kwargs: Optional[dict[str, Any]] = None, + *, + dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None, + strict: bool = False, + preserve_module_call_signature: tuple[str, ...] = (), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> ExportedProgram: """ :func:`export_for_training` takes any nn.Module along with example inputs, and produces a traced graph representing @@ -158,19 +228,30 @@ def export_for_training( dynamic_shapes, strict=strict, preserve_module_call_signature=preserve_module_call_signature, +<<<<<<< HEAD prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def export( mod: torch.nn.Module, args: tuple[Any, ...], +<<<<<<< HEAD kwargs: Optional[Mapping[str, Any]] = None, *, dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any, ...], list[Any]]] = None, strict: bool = False, preserve_module_call_signature: tuple[str, ...] = (), prefer_deferred_runtime_asserts_over_guards: bool = False, +======= + kwargs: Optional[dict[str, Any]] = None, + *, + dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None, + strict: bool = False, + preserve_module_call_signature: tuple[str, ...] = (), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> ExportedProgram: """ :func:`export` takes any nn.Module along with example inputs, and produces a traced graph representing @@ -282,7 +363,10 @@ def export( strict=strict, preserve_module_call_signature=preserve_module_call_signature, pre_dispatch=True, +<<<<<<< HEAD prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) except Exception as e: draft_export_msg = ( @@ -447,8 +531,12 @@ def load( f, expected_opset_version=expected_opset_version, ) +<<<<<<< HEAD except RuntimeError as e: log.warning("Ran into the following error when deserializing: %s", e) +======= + except RuntimeError: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pt2_contents = PT2ArchiveContents({}, {}, {}) if len(pt2_contents.exported_programs) > 0 or len(pt2_contents.extra_files) > 0: @@ -458,6 +546,7 @@ def load( return pt2_contents.exported_programs["model"] # TODO: For backward compatibility, we support loading a zip file from 2.7. Delete this path in 2.9(?) +<<<<<<< HEAD with zipfile.ZipFile(f, "r") as zipf: if "version" not in zipf.namelist(): raise RuntimeError( @@ -470,6 +559,12 @@ def load( "deprecated. Please generate a new pt2 saved file." ) +======= + warnings.warn( + "This version of file is deprecated. Please generate a new pt2 saved file." + ) + with zipfile.ZipFile(f, "r") as zipf: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Check the version version = zipf.read("version").decode().split(".") from torch._export.serde.schema import ( @@ -535,12 +630,20 @@ def load( def draft_export( mod: torch.nn.Module, args: tuple[Any, ...], +<<<<<<< HEAD kwargs: Optional[Mapping[str, Any]] = None, *, dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any, ...], list[Any]]] = None, preserve_module_call_signature: tuple[str, ...] = (), strict: bool = False, prefer_deferred_runtime_asserts_over_guards: bool = False, +======= + kwargs: Optional[dict[str, Any]] = None, + *, + dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None, + preserve_module_call_signature: tuple[str, ...] = (), + strict: bool = False, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> ExportedProgram: """ A version of torch.export.export which is designed to consistently produce @@ -556,7 +659,10 @@ def draft_export( dynamic_shapes=dynamic_shapes, preserve_module_call_signature=preserve_module_call_signature, strict=strict, +<<<<<<< HEAD prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) diff --git a/torch/export/_draft_export.py b/torch/export/_draft_export.py index 2b14327b24512..b1849d1437bfe 100644 --- a/torch/export/_draft_export.py +++ b/torch/export/_draft_export.py @@ -4,24 +4,37 @@ import os import re import tempfile +<<<<<<< HEAD import time from collections.abc import Mapping +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from dataclasses import dataclass from enum import IntEnum from typing import Any, Callable, Optional, Union import torch import torch._logging._internal +<<<<<<< HEAD import torch.utils._pytree as pytree from torch._dynamo.exc import UserError, UserErrorType +======= +import torch._logging.structured +import torch.utils._pytree as pytree +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._export.passes.insert_custom_op_guards import ( get_op_profiles, insert_custom_op_guards, OpProfile, ) +<<<<<<< HEAD from torch._utils_internal import log_draft_export_usage from ._trace import _export, get_ep_stats +======= + +from ._trace import _export +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .dynamic_shapes import _DimHint, _DimHintType, Dim from .exported_program import ExportedProgram @@ -365,12 +378,17 @@ def _log_expression_created( def draft_export( mod: torch.nn.Module, args: tuple[Any, ...], +<<<<<<< HEAD kwargs: Optional[Mapping[str, Any]] = None, +======= + kwargs: Optional[dict[str, Any]] = None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) *, dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None, preserve_module_call_signature: tuple[str, ...] = (), strict: bool = False, pre_dispatch: bool = True, +<<<<<<< HEAD prefer_deferred_runtime_asserts_over_guards: bool = False, ) -> ExportedProgram: start_time = time.time() @@ -378,6 +396,12 @@ def draft_export( dynamic_shapes = dynamic_shapes or {} constraint_violation_msg = None +======= +) -> ExportedProgram: + kwargs = kwargs or {} + dynamic_shapes = dynamic_shapes or {} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) capture_structured_log = CaptureStructuredTrace() with ( @@ -397,6 +421,7 @@ def draft_export( strict=strict, pre_dispatch=pre_dispatch, preserve_module_call_signature=preserve_module_call_signature, +<<<<<<< HEAD prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, ) except Exception as exc: @@ -433,6 +458,28 @@ def convert_dim_to_auto(dim: Any) -> Any: type=f"{type(exc).__name__}.{type(exc).__qualname__}", ) raise exc +======= + ) + except torch._dynamo.exc.UserError: + + def convert_dim_to_auto(dim: Any) -> Any: + if isinstance(dim, Dim): + return Dim.AUTO(min=dim.min, max=dim.max) + elif isinstance(dim, _DimHint) and dim.type == _DimHintType.DYNAMIC: + return Dim.AUTO(min=dim.min, max=dim.max) + return dim + + new_shapes = pytree.tree_map(convert_dim_to_auto, dynamic_shapes) + ep = _export( + mod, + args, + kwargs, + dynamic_shapes=new_shapes, + strict=strict, + pre_dispatch=pre_dispatch, + preserve_module_call_signature=preserve_module_call_signature, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch._logging.dtrace_structured("exported_program", payload_fn=lambda: str(ep)) @@ -531,6 +578,7 @@ def convert_dim_to_auto(dim: Any) -> Any: """ ) +<<<<<<< HEAD log_draft_export_usage( error=False, export_time=time.time() - start_time, @@ -539,4 +587,6 @@ def convert_dim_to_auto(dim: Any) -> Any: report=ep._report, **get_ep_stats(ep), ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ep diff --git a/torch/export/_remove_effect_tokens_pass.py b/torch/export/_remove_effect_tokens_pass.py index bde7eb6042245..2bc737abe877c 100644 --- a/torch/export/_remove_effect_tokens_pass.py +++ b/torch/export/_remove_effect_tokens_pass.py @@ -23,7 +23,11 @@ def _remove_effect_tokens_from_graph_helper( output_node = None with_effect_nodes: list[torch.fx.Node] = [] +<<<<<<< HEAD # Output node need to check its args against output_token_names (collected from output_spec) +======= + # Output node need to check its args agianst output_token_names (collected from output_spec) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Therefore, we only need to find the top-levele output node output_node = next(reversed(ep.graph_module.graph.find_nodes(op="output"))) for module in ep.graph_module.modules(): @@ -126,7 +130,11 @@ def _remove_effect_tokens_from_graph_helper( def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram: """ +<<<<<<< HEAD Removes the existence of tokens from the exported program, including: +======= + Removes the existance of tokens from the exported program, including: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - Removes the input and output tokens - Replaces with_effects(token, func, args) with just func(args) diff --git a/torch/export/_swap.py b/torch/export/_swap.py index 4c93956e32b49..3dd66c4fff28e 100644 --- a/torch/export/_swap.py +++ b/torch/export/_swap.py @@ -163,7 +163,11 @@ def _remove_extraneous_pytrees(gm: torch.fx.GraphModule) -> None: """ for node in gm.graph.nodes: +<<<<<<< HEAD if node.op == "call_module" and node.target != "_guards_fn": +======= + if node.op == "call_module": +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _try_remove_connecting_pytrees(node) gm.graph.eliminate_dead_code() diff --git a/torch/export/_trace.py b/torch/export/_trace.py index 76d80ff6eeec8..abd1caca9587c 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -2,18 +2,28 @@ # mypy: allow-untyped-defs import dataclasses import functools +<<<<<<< HEAD import gc import inspect import logging import os +======= +import inspect +import logging +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import re import sys import time import warnings +<<<<<<< HEAD import weakref from contextlib import contextmanager, nullcontext from typing import Any, Callable, Optional, Union from typing_extensions import TypeAlias +======= +from contextlib import contextmanager, nullcontext +from typing import Any, Callable, Optional, Union +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch._dynamo @@ -52,13 +62,22 @@ ) from torch._export.verifier import SpecViolationError from torch._export.wrappers import _wrap_submodules +<<<<<<< HEAD from torch._functorch._aot_autograd.graph_capture_wrappers import create_functional_call +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._functorch._aot_autograd.input_output_analysis import ( _graph_input_names, _graph_output_names, ) from torch._functorch._aot_autograd.schemas import GraphSignature from torch._functorch._aot_autograd.subclass_utils import get_subclass_typing_container +<<<<<<< HEAD +======= +from torch._functorch._aot_autograd.traced_function_transforms import ( + create_functional_call, +) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._functorch._aot_autograd.utils import ( create_tree_flattened_fn, register_buffer_assignment_hook, @@ -72,7 +91,10 @@ from torch._logging import dtrace_structured from torch._subclasses.fake_tensor import FakeTensorMode from torch._utils_internal import log_export_usage +<<<<<<< HEAD from torch.export._leakage_detection_utils import find_legit_leaks_from_referrers +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.export._unlift import _check_input_constraints_pre_hook from torch.export.dynamic_shapes import ( _check_dynamic_shapes, @@ -99,6 +121,11 @@ from torch.utils._pytree import TreeSpec from torch.utils._sympy.value_ranges import ValueRangeError +<<<<<<< HEAD +======= +from ._safeguard import AutogradStateOpsFailSafeguard +from ._wrapper_utils import _WrapperModule +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .exported_program import ( _disable_prexisiting_fake_mode, ExportedProgram, @@ -111,12 +138,15 @@ log = logging.getLogger(__name__) +<<<<<<< HEAD NONSTRICT_EXPORT_SANITIZE_TRACE = "NONSTRICT_EXPORT_SANITIZE_TRACE" # Type alias for dynamic shapes specification _DynamicShapesSpec: TypeAlias = Union[dict[str, Any], tuple[Any, ...], list[Any]] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @dataclasses.dataclass class ExportDynamoConfig: @@ -210,7 +240,11 @@ def _strip_root(x): def _rewrite_tracepoint_node(gm: torch.fx.GraphModule): """ +<<<<<<< HEAD In-place modify input graph module by replacing the export tracepoint with a new node +======= + In-place modifiy input graph module by replacing the export tracepoint with a new node +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) that has the same target and args, but with the _export_root stripped from path. """ for node in gm.graph.nodes: @@ -268,7 +302,11 @@ def _extract_fake_inputs(gm, args, kwargs): # We get both because now we might have a combination of symint and tensor # inputs, and we want to check that the shape env is consistent between +<<<<<<< HEAD # both. Unfortunately we can't see what fake mode is attached to the shape +======= + # both. Unforunately we can't see what fake mode is attached to the shape +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # env, then we can just compare fake modes. detected_fake_mode = detect_fake_mode(fake_inps + fake_vals) detected_shape_env = detect_shape_env(fake_inps + fake_vals) @@ -390,9 +428,13 @@ def _get_param_buffer_mapping( param_lookup: dict[int, str] = {} buffer_lookup: dict[int, str] = {} for name, param in original_module.named_parameters(remove_duplicate=False): +<<<<<<< HEAD if param_lookup.get(id(param)) is None: # we only want to keep the first occurrence of a parameter to guarantee parity of original and traced module. param_lookup[id(param)] = name +======= + param_lookup[id(param)] = name +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for name, buffer in original_module.named_buffers(remove_duplicate=False): buffer_lookup[id(buffer)] = name @@ -683,8 +725,13 @@ def _restore_state_dict( Restores the state dict of the traced module to that of the original module. """ param_buffer_table = _get_param_buffer_mapping(original_module, traced_module) +<<<<<<< HEAD # Since the graph module is flattened (no module hierarchy), we # need to normalize the module by replacing "." with "_". If we +======= + # Since the graph module is flattened (no module heirarchy), we + # need to noramlize the module by replacing "." with "_". If we +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # don't, it will try to save the weight to a submodule which no # longer exists. for name, fqn in param_buffer_table.items(): @@ -744,10 +791,13 @@ def _make_module_call_graph( return [*original, *additional] +<<<<<<< HEAD class _ExportModuleSpecTrackerDict(dict): pass +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _export_to_torch_ir( f: Callable, args: tuple[Any, ...], @@ -756,7 +806,11 @@ def _export_to_torch_ir( *, preserve_module_call_signature: tuple[str, ...] = (), disable_constraint_solver: bool = False, +<<<<<<< HEAD prefer_deferred_runtime_asserts_over_guards: bool = False, +======= + allow_complex_guards_as_runtime_asserts: bool = False, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) restore_fqn: bool = True, _log_export_usage: bool = True, same_signature: bool = True, @@ -800,9 +854,13 @@ def _export_to_torch_ir( with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)): try: +<<<<<<< HEAD module_call_specs: dict[str, dict[str, pytree.TreeSpec]] = ( _ExportModuleSpecTrackerDict() ) +======= + module_call_specs: dict[str, dict[str, pytree.TreeSpec]] = {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ctx = nullcontext() if not isinstance(f, torch.fx.GraphModule): ctx = _wrap_submodules( # type: ignore[assignment] @@ -816,7 +874,14 @@ def _export_to_torch_ir( assume_static_by_default=True, tracing_mode="symbolic", disable_constraint_solver=disable_constraint_solver, +<<<<<<< HEAD prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, +======= + # currently the following 2 flags are tied together for export purposes, + # but untangle for sake of dynamo export api + prefer_deferred_runtime_asserts_over_guards=True, + allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _log_export_usage=_log_export_usage, same_signature=same_signature, )( @@ -851,9 +916,29 @@ def _export_to_aten_ir( transform=lambda x: x, # TODO(zhxchen17) Revisit if this is needed later. pre_dispatch=False, decomp_table=None, +<<<<<<< HEAD _prettify_placeholder_names: bool = True, decompose_custom_triton_ops: bool = False, ) -> ATenExportArtifact: +======= + _check_autograd_state: bool = True, + _is_torch_jit_trace: bool = False, + _prettify_placeholder_names: bool = True, + decompose_custom_triton_ops: bool = False, +) -> ATenExportArtifact: + # [NOTE] If the user is exporting under training mode, we want to detect if there is any + # state change in the autograd global state and error. If the user is exporting under inference + # mode, we don't care. At predispatch level, we don't care about the state change. + is_grad_enabled = torch._C.is_grad_enabled() + grad_safe_guard = nullcontext() + # export_to_aten_ir is called when we decompose the ep into inference IR + # In that setting, we actually shouldn't check the state change as at this point, + # because the intention is specalizing to inference. + if _check_autograd_state: + if not pre_dispatch and is_grad_enabled: + grad_safe_guard = AutogradStateOpsFailSafeguard() # type: ignore[assignment] + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) custom_triton_ops_decomposition_ctx = ( nullcontext if decompose_custom_triton_ops @@ -870,6 +955,10 @@ def _export_to_aten_ir( strict=True, stack_weights=True, ), +<<<<<<< HEAD +======= + grad_safe_guard, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _ignore_backend_decomps(), _compiling_state_context(), custom_triton_ops_decomposition_ctx(), @@ -891,8 +980,11 @@ def _maybe_fixup_gm_and_output_node_meta(old_gm, new_gm): new_output_node = list(new_gm.graph.nodes)[-1] assert old_output_node.op == "output" and new_output_node.op == "output" # make sure we don't override any meta +<<<<<<< HEAD if "desc" in new_output_node.meta: del new_output_node.meta["desc"] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert len(new_output_node.meta) == 0 new_output_node.meta.update(old_output_node.meta) @@ -1202,6 +1294,7 @@ def _get_original_state_dict(mod: torch.nn.Module) -> dict[str, Any]: return original_state_dict +<<<<<<< HEAD def _process_export_inputs( mod: torch.nn.Module, args: tuple[object, ...], @@ -1247,6 +1340,9 @@ def _process_export_inputs( Raises: UserError: If args is not a tuple. """ +======= +def _process_export_inputs(mod, args, kwargs, dynamic_shapes): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not isinstance(args, tuple): raise UserError( UserErrorType.INVALID_INPUT, @@ -1255,6 +1351,7 @@ def _process_export_inputs( kwargs = kwargs if kwargs is not None else {} _, original_in_spec = pytree.tree_flatten((args, kwargs)) +<<<<<<< HEAD verify_additional_inputs: Callable[[ExportedProgram], None] out_dynamic_shapes: Optional[_DynamicShapesSpec] if isinstance(dynamic_shapes, torch.export.AdditionalInputs): @@ -1268,6 +1365,17 @@ def _process_export_inputs( out_dynamic_shapes = dynamic_shapes return args, kwargs, original_in_spec, out_dynamic_shapes, verify_additional_inputs +======= + if isinstance(dynamic_shapes, torch.export.AdditionalInputs): + verify_additional_inputs = dynamic_shapes.verify + dynamic_shapes = dynamic_shapes.dynamic_shapes(mod, args, kwargs) + else: + verify_additional_inputs = lambda ep: None # noqa: E731 + if isinstance(dynamic_shapes, torch.export.ShapesCollection): + dynamic_shapes = dynamic_shapes.dynamic_shapes(mod, args, kwargs) + + return args, kwargs, original_in_spec, dynamic_shapes, verify_additional_inputs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _get_module_call_graph( @@ -1297,7 +1405,11 @@ def _get_module_call_graph( outputs=[], in_spec=specs["in_spec"], out_spec=specs["out_spec"], +<<<<<<< HEAD forward_arg_names=None, # we only propagate forward_arg_names for the top level module +======= + forward_arg_names=None, # we only propage forward_arg_names for the top level module +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if len(preserve_module_call_signature) > 0: @@ -1323,6 +1435,10 @@ def _get_range_constraints( args, kwargs, dynamic_shapes, +<<<<<<< HEAD +======= + _is_torch_jit_trace=False, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): gm: torch.fx.GraphModule = export_artifact.aten.gm export_graph_signature: ExportGraphSignature = export_artifact.aten.sig @@ -1335,6 +1451,7 @@ def _get_range_constraints( ), len(export_graph_signature.input_specs), ) +<<<<<<< HEAD combined_args = _combine_args(mod, args, kwargs) # This is because we trace based on the kwargs passed in from user @@ -1350,6 +1467,26 @@ def _get_range_constraints( combined_args_traced_order[key] = kwargs[key] combined_args = combined_args_traced_order +======= + combined_args = _combine_args( + mod, args, kwargs, _is_torch_jit_trace=_is_torch_jit_trace + ) + + # This is because we trace based on the kewargs passed in from user + # not based on the signature. I feel it would be better to just enforce + # one ordering at the start of tracing to avoid confusions, but that is + # bigger refactor, so do this to unblock for now. + if not _is_torch_jit_trace: + combined_args_traced_order = {} + for arg in combined_args: + if arg not in kwargs: + combined_args_traced_order[arg] = combined_args[arg] + + for key in kwargs: + combined_args_traced_order[key] = kwargs[key] + + combined_args = combined_args_traced_order +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) range_constraints = make_constraints( fake_mode, @@ -1398,6 +1535,47 @@ def _temp_disable_texpr_fuser(): torch._C._jit_set_texpr_fuser_enabled(original_state) +<<<<<<< HEAD +======= +def _convert_ts_to_export_experimental(traced_callable, args, kwargs=None): + with _temp_disable_texpr_fuser(): + from torch.jit._trace import TopLevelTracedModule + + export_args, export_kwargs = _process_jit_trace_inputs_for_export(args, kwargs) + + if isinstance(traced_callable, (TopLevelTracedModule, torch._C.ScriptModule)): # type: ignore[operator] + return _export( + traced_callable, + export_args, + export_kwargs, + strict=False, + _is_torch_jit_trace=True, + ).module() + + elif isinstance(traced_callable, torch.ScriptMethod) and isinstance( + traced_callable.owner(), # type: ignore[operator] + (torch._C.ScriptModule, torch.nn.Module), + ): + with patch_forward(traced_callable.owner(), traced_callable): # type: ignore[operator] + return _export( + traced_callable.owner(), # type: ignore[operator] + export_args, + export_kwargs, + strict=False, + _is_torch_jit_trace=True, + ).module() + + else: + return _export( + _WrapperModule(traced_callable), + export_args, + export_kwargs, + strict=False, + _is_torch_jit_trace=True, + ).module() + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _strict_export( mod: torch.nn.Module, args: tuple[Any, ...], @@ -1405,7 +1583,12 @@ def _strict_export( dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]], preserve_module_call_signature: tuple[str, ...], orig_in_spec: TreeSpec, +<<<<<<< HEAD prefer_deferred_runtime_asserts_over_guards: bool, +======= + allow_complex_guards_as_runtime_asserts: bool, + _is_torch_jit_trace: bool, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _to_aten_func: Callable, ) -> ExportArtifact: """ @@ -1419,7 +1602,11 @@ def _strict_export( dynamic_shapes, preserve_module_call_signature=preserve_module_call_signature, restore_fqn=False, # don't need to restore because we will do it later +<<<<<<< HEAD prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, +======= + allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _log_export_usage=False, ) @@ -1685,6 +1872,7 @@ def override_getattribute_for_subclasses(args): for k, (old_getattr, _) in tensor_type_to_old_getattribute.items(): k.__getattribute__ = old_getattr # type: ignore[method-assign, attr-defined] +<<<<<<< HEAD @contextmanager def _maybe_restore_grad_state(): """ @@ -1703,6 +1891,9 @@ def _maybe_restore_grad_state(): override_getattribute_for_subclasses(flat_args), _maybe_restore_grad_state(), ): +======= + with ctx, override_getattribute_for_subclasses(flat_args): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) gm = make_fx( wrapped_fn, record_module_stack=True, @@ -1753,7 +1944,10 @@ def _is_impure(node): gm.graph.eliminate_dead_code(_is_impure) # create graph signature +<<<<<<< HEAD assert out_spec.spec is not None, "out_spec.spec is None!" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) input_names = _graph_input_names(gm) output_names = _graph_output_names(gm) sig = GraphSignature( @@ -1766,10 +1960,16 @@ def _is_impure(node): zip(input_names[param_len : param_len + buffer_len], named_buffers) ), buffers_to_mutate={}, +<<<<<<< HEAD parameters_to_mutate={}, user_inputs_to_mutate={}, in_spec=in_spec, out_spec=out_spec.spec, +======= + user_inputs_to_mutate={}, + in_spec=in_spec, + out_spec=out_spec, # type: ignore[arg-type] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) backward_signature=None, input_tokens=[], output_tokens=[], @@ -1862,7 +2062,12 @@ def _non_strict_export( dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]], preserve_module_call_signature: tuple[str, ...], orig_in_spec: TreeSpec, +<<<<<<< HEAD prefer_deferred_runtime_asserts_over_guards: bool, +======= + allow_complex_guards_as_runtime_asserts: bool, + _is_torch_jit_trace: bool, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _to_aten_func: Callable, ) -> ExportArtifact: """ @@ -1929,9 +2134,12 @@ def forward(self, *args, **kwargs): _strip_root, sig.inputs_to_parameters ) sig.buffers_to_mutate = pytree.tree_map(_strip_root, sig.buffers_to_mutate) +<<<<<<< HEAD sig.parameters_to_mutate = pytree.tree_map( _strip_root, sig.parameters_to_mutate ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for node in gm.graph.nodes: if "nn_module_stack" in node.meta: @@ -1959,7 +2167,12 @@ def forward(self, *args, **kwargs): args, kwargs, dynamic_shapes, +<<<<<<< HEAD prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, # for shape env initialization +======= + _is_torch_jit_trace=_is_torch_jit_trace, + allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, # for shape env initialization +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) fake_params_buffers = _fakify_params_buffers(fake_mode, mod) @@ -1971,6 +2184,10 @@ def _produce_guards_callback(gm): dynamic_shapes=dynamic_shapes, equalities_inputs=equalities_inputs, original_signature=original_signature, +<<<<<<< HEAD +======= + _is_torch_jit_trace=_is_torch_jit_trace, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) tx = TracingContext(fake_mode) @@ -2040,7 +2257,10 @@ def _export_for_training( *, strict: bool = True, preserve_module_call_signature: tuple[str, ...] = (), +<<<<<<< HEAD prefer_deferred_runtime_asserts_over_guards: bool = False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> ExportedProgram: global _EXPORT_MODULE_HIERARCHY _EXPORT_MODULE_HIERARCHY = _get_module_hierarchy(mod) @@ -2058,6 +2278,7 @@ def _export_for_training( # Call the appropriate export function based on the strictness of tracing. export_func = _strict_export if strict else _non_strict_export +<<<<<<< HEAD alive_fake_input_ids_before_export: list[int] = [] if not strict and os.environ.get(NONSTRICT_EXPORT_SANITIZE_TRACE, "0") == "1": @@ -2068,6 +2289,8 @@ def _export_for_training( if isinstance(i, torch._subclasses.fake_tensor.FakeTensor) ] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) export_artifact = export_func( mod=mod, args=args, @@ -2075,7 +2298,12 @@ def _export_for_training( dynamic_shapes=dynamic_shapes, preserve_module_call_signature=preserve_module_call_signature, orig_in_spec=orig_in_spec, +<<<<<<< HEAD prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, +======= + allow_complex_guards_as_runtime_asserts=False, + _is_torch_jit_trace=False, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _to_aten_func=_export_to_aten_ir_make_fx, ) @@ -2123,6 +2351,7 @@ def _export_for_training( ) verify_additional_inputs(exported_program) +<<<<<<< HEAD if not strict and os.environ.get(NONSTRICT_EXPORT_SANITIZE_TRACE, "0") == "1": # See NOTE [export non-strict fake tensor leak detection] @@ -2170,6 +2399,8 @@ def _export_for_training( del legit_leak +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return exported_program @@ -2184,7 +2415,12 @@ def _export( strict: bool = True, preserve_module_call_signature: tuple[str, ...] = (), pre_dispatch: bool = False, +<<<<<<< HEAD prefer_deferred_runtime_asserts_over_guards: bool = False, +======= + allow_complex_guards_as_runtime_asserts: bool = False, + _is_torch_jit_trace: bool = False, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> ExportedProgram: """ Traces either an nn.Module's forward function or just a callable with PyTorch @@ -2215,7 +2451,11 @@ def _export( preserve_module_call_signature: A list of submodule paths for which the original calling conventions are preserved as metadata. +<<<<<<< HEAD prefer_deferred_runtime_asserts_over_guards: +======= + allow_complex_guards_as_runtime_asserts: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) With the current dynamic shapes language for dims and derived dims, we can run into constraints that are not expressible with the language. For example, flattening a matrix and adding to a vector, both fully dynamic (i.e. x.reshape([-1]) + y) emits a guard s0 * s1 = s2, which is not expressible. @@ -2259,7 +2499,10 @@ def _export( dynamic_shapes, strict=strict, preserve_module_call_signature=preserve_module_call_signature, +<<<<<<< HEAD prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) dtrace_structured("exported_program", payload_fn=lambda: str(ep)) return ep @@ -2284,15 +2527,30 @@ def _export( dynamic_shapes=dynamic_shapes, preserve_module_call_signature=preserve_module_call_signature, orig_in_spec=original_in_spec, +<<<<<<< HEAD prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, _to_aten_func=functools.partial( _export_to_aten_ir, pre_dispatch=pre_dispatch, +======= + allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, + _is_torch_jit_trace=_is_torch_jit_trace, + _to_aten_func=functools.partial( + _export_to_aten_ir, + pre_dispatch=pre_dispatch, + _is_torch_jit_trace=_is_torch_jit_trace, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), ) export_graph_signature: ExportGraphSignature = export_artifact.aten.sig +<<<<<<< HEAD forward_arg_names = _get_forward_arg_names(mod, args, kwargs) +======= + forward_arg_names = ( + _get_forward_arg_names(mod, args, kwargs) if not _is_torch_jit_trace else None + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inline_constraints = _get_inline_constraints(export_artifact.fake_mode) # The unbacked symint symbols are updated in aot_export # so we serialize them here instead of inside dynamo. @@ -2304,6 +2562,10 @@ def _export( args, kwargs, dynamic_shapes, +<<<<<<< HEAD +======= + _is_torch_jit_trace=_is_torch_jit_trace, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) gm, module_call_graph = _get_module_call_graph( export_artifact, @@ -2314,7 +2576,12 @@ def _export( _verify_nn_module_stack(gm) _verify_stack_trace(gm) +<<<<<<< HEAD _verify_placeholder_names(gm, export_graph_signature) +======= + if not _is_torch_jit_trace: + _verify_placeholder_names(gm, export_graph_signature) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Remove Proxy because they cannot be deepcopied or pickled. torch._export.utils.remove_proxy_from_state_dict(original_state_dict, in_place=True) diff --git a/torch/export/_unlift.py b/torch/export/_unlift.py index f876e462214ca..adb094375b865 100644 --- a/torch/export/_unlift.py +++ b/torch/export/_unlift.py @@ -1,14 +1,20 @@ # mypy: allow-untyped-defs import copy +<<<<<<< HEAD import inspect import math +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import warnings from collections.abc import Sequence from itertools import chain from typing import Any, Optional +<<<<<<< HEAD import sympy +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import torch.utils._pytree as pytree from torch._export.non_strict_utils import ( @@ -16,16 +22,22 @@ _exit_enable_graph_inputs_of_type_nn_module, _get_graph_inputs_of_type_nn_module, ) +<<<<<<< HEAD from torch._export.passes.add_runtime_assertions_for_constraints_pass import ( _convert_range_to_int, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._export.utils import _check_input_constraints_for_graph from torch.export.unflatten import _assign_attr, _AttrKind from torch.fx.experimental.proxy_tensor import _pytree_subclasses_that_lose_info from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo from torch.fx.traceback import NodeSource, NodeSourceAction +<<<<<<< HEAD from torch.utils._sympy.solve import try_solve from torch.utils._sympy.value_ranges import ValueRanges +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from ._remove_effect_tokens_pass import _remove_effect_tokens from ._tree_utils import reorder_kwargs @@ -82,6 +94,7 @@ def _check_inputs_match(args, kwargs, in_spec: pytree.TreeSpec) -> list: return flat_args_with_path +<<<<<<< HEAD def _convert_guards_code_to_fn( guards_code: list[str], paths_of_placeholders: list[pytree.KeyPath], @@ -181,6 +194,20 @@ def _check_input_constraints_pre_hook(self, args, kwargs): # NOTE: this call is Dynamo disabled, as it used to be _check_input_constraints_for_module(self, args, kwargs) +======= +@torch._dynamo.disable +def _check_input_constraints_pre_hook(self, args, kwargs): + if not self.validate_inputs: + return + + flat_args_with_path = _check_inputs_match(args, kwargs, self._in_spec) + + _check_input_constraints_for_graph( + [node for node in self.graph.nodes if node.op == "placeholder"], + flat_args_with_path, + self.range_constraints, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _unlift_inputs_as_getattr( @@ -234,7 +261,16 @@ def _insert_copy_for_mutations( Find the all the buffers and inputs that were mutated and insert copy_ operators to reflect mutations. """ +<<<<<<< HEAD output_node = gm.graph.output_node() +======= + output_node = None + for node in gm.graph.nodes: + if node.op == "output": + output_node = node + break + assert output_node is not None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) outputs = pytree.tree_flatten(output_node.args)[0] assert len(outputs) == len(mutated_outputs) @@ -260,6 +296,7 @@ def _insert_copy_for_mutations( ) return_nodes_to_copy[return_node] = copy_node +<<<<<<< HEAD output_args = tuple( return_nodes_to_copy[node] if node in return_nodes_to_copy else node for node in user_output_nodes @@ -267,6 +304,15 @@ def _insert_copy_for_mutations( with gm.graph.inserting_before(output_node): # Only return user outputs new_output = gm.graph.output(output_args) +======= + output_args = [ + return_nodes_to_copy[node] if node in return_nodes_to_copy else node + for node in user_output_nodes + ] + with gm.graph.inserting_before(output_node): + # Only return user outputs + new_output = gm.graph.output(tuple(output_args)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) output_node.replace_all_uses_with(new_output) gm.graph.erase_node(output_node) new_output.name = output_node.name @@ -290,6 +336,7 @@ def _get_codegen( """ if forward_arg_names: names = forward_arg_names +<<<<<<< HEAD elif ( in_spec.type == tuple and in_spec.num_children == 2 @@ -302,6 +349,21 @@ def _get_codegen( names.extend(in_spec.children_specs[1].context) else: names = [f"arg_{i}" for i in range(in_spec.num_children)] +======= + else: + if ( + in_spec.type == tuple + and in_spec.num_children == 2 + and in_spec.children_specs[0].type == tuple + and in_spec.children_specs[1].type == dict + ): + # if in_spec contains the args (tuple) and kwargs (dict) + names = [f"arg_{i}" for i in range(in_spec.children_specs[0].num_children)] + # add kwarg names + names.extend(in_spec.children_specs[1].context) + else: + names = [f"arg_{i}" for i in range(in_spec.num_children)] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return _PyTreeCodeGen( _PyTreeInfo( @@ -318,6 +380,11 @@ def _unlift( mutated_outputs: Sequence[Optional[str]], in_spec: pytree.TreeSpec, out_spec: Optional[pytree.TreeSpec], +<<<<<<< HEAD +======= + state_dict: dict[str, Any], + constants: dict[str, Any], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) forward_arg_names: Optional[list[str]] = None, ): """ @@ -457,7 +524,11 @@ def _create_stateful_graph_module( for constant_fqn in ep.graph_signature.lifted_tensor_constants: # Sometimes, the constant can require gradient, this is probably a bug in user code, # e.g. `self.const = torch.randn(2, 2, requires_grad=True)`. +<<<<<<< HEAD # We call detach on the constant_val since they're tensor constants and we don't need to +======= + # We call detach on the constant_val since they're tensor contants and we don't need to +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # compute their gradients anyway. # Users should properly register it as parameter if they want it to require gradient. buffer = stateful_gm.get_buffer(constant_fqn) @@ -515,6 +586,7 @@ def _create_stateful_graph_module( return stateful_gm +<<<<<<< HEAD def _get_input_paths(example_inputs, signature): """ Generate paths of placeholders, needed for generating the guards function. @@ -658,6 +730,12 @@ def _unlift_exported_program_lifted_states( if ep.verifiers[0].dialect != "TRAINING": ep = _remove_effect_tokens(ep) +======= +def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.nn.Module: + # TODO T206340015 + if ep.verifiers[0].dialect != "TRAINING": + ep = _remove_effect_tokens(ep) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) new_gm = torch.fx.GraphModule(ep.graph_module, copy.deepcopy(ep.graph)) _register_attrs_to_new_gm(new_gm, ep.graph_signature, ep.state_dict, ep.constants) forward_arg_names = ( @@ -682,16 +760,21 @@ def _unlift_exported_program_lifted_states( ( out_spec.target if out_spec.kind +<<<<<<< HEAD in ( OutputKind.BUFFER_MUTATION, OutputKind.USER_INPUT_MUTATION, OutputKind.PARAMETER_MUTATION, ) +======= + in (OutputKind.BUFFER_MUTATION, OutputKind.USER_INPUT_MUTATION) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else None ) for out_spec in ep.graph_signature.output_specs ] +<<<<<<< HEAD source_node_dict = { node.name: node for node in ep.graph.nodes if node.op != "placeholder" } @@ -714,16 +797,29 @@ def _unlift_exported_program_lifted_states( ] assert ep.call_spec.in_spec is not None +======= + for node in new_gm.graph.nodes: + node.meta["from_node"] = [ + NodeSource(node, "ExportedProgram.module()", NodeSourceAction.CREATE) + ] + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) new_gm = _unlift( new_gm, lifted_inputs, mutated_outputs, ep.call_spec.in_spec, ep.call_spec.out_spec, +<<<<<<< HEAD +======= + ep.state_dict, + ep.constants, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) forward_arg_names=forward_arg_names, ) unlift_gm = _create_stateful_graph_module(new_gm, ep.range_constraints, ep) unlift_gm.meta.update(ep.graph_module.meta) +<<<<<<< HEAD # create a _guards_fn submodule and insert a call to it after placeholders graph = unlift_gm.graph @@ -758,3 +854,6 @@ class GuardsFn(torch.nn.Module): def forward(self, *args): pass +======= + return unlift_gm +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/export/custom_ops.py b/torch/export/custom_ops.py index 9df7988da9314..ddedc03a11489 100644 --- a/torch/export/custom_ops.py +++ b/torch/export/custom_ops.py @@ -1,6 +1,9 @@ +<<<<<<< HEAD # mypy: allow-untyped-defs import importlib +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch @@ -27,6 +30,7 @@ def _access_subclass_inner_tensor( f"Attribute {attr} is not a tensor or doesn't exist in {src_subclass_tensor}" ) return val +<<<<<<< HEAD def _call_custom_autograd_function_in_pre_dispatch(function_cls_name, *args, **kwargs): @@ -47,3 +51,5 @@ def _call_custom_autograd_function_in_pre_dispatch(function_cls_name, *args, **k function_cls = getattr(module, class_name) assert hasattr(function_cls, "apply") return function_cls.apply(*args, **kwargs) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index de41fdfdb3467..3356e81f0a4c3 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -85,6 +85,7 @@ def __call__(self, min=None, max=None) -> "_DimHint": class Dim: """ +<<<<<<< HEAD The ``Dim`` class allows users to specify dynamism in their exported programs. By marking a dimension with a ``Dim``, the compiler associates the dimension with a symbolic integer containing a dynamic range. @@ -98,6 +99,17 @@ class Dim: compiler to decide (``Dim.AUTO``). The export process will automatically infer the remaining constraints on min/max ranges and relationships between dimensions. +======= + The `Dim` class allows users to specify dynamism in their exported programs. By marking a dimension with a `Dim`, + the compiler associates the dimension with a symbolic integer containing a dynamic range. + + The API can be used in 2 ways: Dim hints (i.e. automatic dynamic shapes: `Dim.AUTO`, `Dim.DYNAMIC`, `Dim.STATIC`), + or named Dims (i.e. `Dim("name", min=1, max=2)`). + + Dim hints provide the lowest barrier to exportability, with the user only needing to specify if a dimension + if dynamic, static, or left for the compiler to decide (`Dim.AUTO`). The export process will automatically + infer the remaining constraints on min/max ranges and relationships between dimensions. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Example:: @@ -116,6 +128,7 @@ def forward(self, x, y): } ep = torch.export(Foo(), (x, y), dynamic_shapes=dynamic_shapes) +<<<<<<< HEAD Here, export would raise an exception if we replaced all uses of ``Dim.AUTO`` with ``Dim.DYNAMIC``, as ``x.shape[0]`` is constrained to be static by the model. @@ -123,12 +136,25 @@ def forward(self, x, y): e.g. ``(x.shape[0] + y.shape[1]) % 4 == 0``, to be raised if runtime inputs do not satisfy such constraints. You may also specify min-max bounds for Dim hints, e.g. ``Dim.AUTO(min=16, max=32)``, ``Dim.DYNAMIC(max=64)``, +======= + Here, export would raise an exception if we replaced all uses of `Dim.AUTO` with `Dim.DYNAMIC`, + as x.shape[0] is constrained to be static by the model. + + More complex relations between dimensions may also be codegened as runtime assertion nodes by the compiler, + e.g. (x.shape[0] + y.shape[1]) % 4 == 0, to be raised if runtime inputs do not satisfy such constraints. + + You may also specify min-max bounds for Dim hints, e.g. `Dim.AUTO(min=16, max=32)`, `Dim.DYNAMIC(max=64)`, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with the compiler inferring the remaining constraints within the ranges. An exception will be raised if the valid range is entirely outside the user-specified range. Named Dims provide a stricter way of specifying dynamism, where exceptions are raised if the compiler infers constraints that do not match the user specification. For example, exporting the previous +<<<<<<< HEAD model, the user would need the following ``dynamic_shapes`` argument:: +======= + model, the user would need the following `dynamic_shapes` argument:: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) s0 = Dim("s0") s1 = Dim("s1", min=16) @@ -138,9 +164,14 @@ def forward(self, x, y): } ep = torch.export(Foo(), (x, y), dynamic_shapes=dynamic_shapes) +<<<<<<< HEAD Named Dims also allow specification of relationships between dimensions, up to univariate linear relations. For example, the following indicates one dimension is a multiple of another plus 4:: +======= + Named Dims also allow specification of relationships between dimensions, up to univariate linear relations. + For example, the following indicates one dimension is a multiple of another plus 4:: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) s0 = Dim("s0") s1 = 3 * s0 + 4 @@ -687,11 +718,16 @@ def _compare(tree, dynamic_shapes, path): raise +<<<<<<< HEAD def _combine_args(f, args, kwargs) -> dict[str, Any]: +======= +def _combine_args(f, args, kwargs, _is_torch_jit_trace=False) -> dict[str, Any]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # combine args and kwargs following the signature of f, as it happens # in the body of f when called with *args, **kwargs if isinstance(f, ExportedProgram): f = f.module() +<<<<<<< HEAD signature = ( inspect.signature(f.forward) @@ -700,6 +736,17 @@ def _combine_args(f, args, kwargs) -> dict[str, Any]: ) kwargs = kwargs if kwargs is not None else {} return signature.bind(*args, **kwargs).arguments +======= + if not _is_torch_jit_trace: + signature = ( + inspect.signature(f.forward) + if isinstance(f, torch.nn.Module) + else inspect.signature(f) + ) + kwargs = kwargs if kwargs is not None else {} + return signature.bind(*args, **kwargs).arguments + return args +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class ShapesCollection: @@ -887,7 +934,11 @@ def verify(self, ep): epm = ep.module() for args, kwargs in self._examples: +<<<<<<< HEAD torch.export._unlift._check_input_constraints_for_module( +======= + torch.export._unlift._check_input_constraints_pre_hook( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) epm, args, kwargs or {} ) @@ -945,7 +996,11 @@ def check_symbols(path, tensor, shape): f"Unexpected dimension mapped to index {i} in input tensor shape {shape} " f"specified at `dynamic_shapes{keystr(path)}` " f"(expected None, an int, a Dim, Dim.AUTO, Dim.STATIC, or Dim.DYNAMIC, " +<<<<<<< HEAD f" but got {dim!r} instead)", +======= + f" but got {dim} instead)", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) case_name="dynamic_shapes_validation", ) elif isinstance(shape, (tuple, list)): @@ -968,7 +1023,11 @@ def check_symbols(path, tensor, shape): f"Unexpected dimension #{i} in input tensor shape {shape} " f"specified at `dynamic_shapes{keystr(path)}` " f"(expected None, an int, a Dim, Dim.AUTO, Dim.STATIC, or Dim.DYNAMIC, " +<<<<<<< HEAD f"but got {dim!r} instead)", +======= + f"but got {dim} instead)", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) case_name="dynamic_shapes_validation", ) elif shape is not None: diff --git a/torch/export/experimental/__init__.py b/torch/export/experimental/__init__.py index 1c87bb29bfe96..9e76fe08de67a 100644 --- a/torch/export/experimental/__init__.py +++ b/torch/export/experimental/__init__.py @@ -1,6 +1,7 @@ import copy import dataclasses import functools +<<<<<<< HEAD import os import types import typing @@ -22,6 +23,18 @@ def _copy_graph_module_and_signature( ep: torch.export.ExportedProgram, +======= +import types +import typing +import typing_extensions + +import torch +from torch.export.exported_program import _decompose_exported_program + + +def _copy_graph_module_and_signature( + ep: torch.fx.GraphModule, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> tuple[torch.fx.GraphModule, torch.export.graph_signature.ExportGraphSignature]: # copy.deepcopy lets the objects override __deepcopy__ methods with graph_copy() and node_copy(), # and this can break placeholder names in some particular cases. @@ -39,7 +52,11 @@ def _copy_graph_module_and_signature( for old_node, new_node in zip(old_phs, new_phs): new_node.name = old_node.name +<<<<<<< HEAD return gm, new_graph_signature +======= + return gm, new_graph_signature # type: ignore[return-value] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _remove_detach_pass( @@ -84,6 +101,7 @@ def _export_forward_backward( return ep._update(gm, new_graph_signature) +<<<<<<< HEAD def _sticky_export( forward_func: typing.Callable[_InputT, _RetT], dynamic_shapes_callback: typing.Optional[ @@ -95,16 +113,28 @@ def _sticky_export( ] ] = None, ) -> typing.Callable[_InputT, _RetT]: +======= +@typing.no_type_check +def _sticky_export(forward_func, dynamic_shapes_callback=None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Lazily export the model on first forward call. Usage: model.forward = _sticky_export(model.forward, dynamic_shapes_callback=callback) """ +<<<<<<< HEAD model = forward_func.__self__ # type: ignore[attr-defined] original_forward = forward_func.__func__ # type: ignore[attr-defined] @functools.wraps(forward_func) def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT: +======= + model = forward_func.__self__ + original_forward = forward_func.__func__ + + @functools.wraps(forward_func) + def wrapper(*args, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Unpatch forward to avoid recursion during export model.forward = types.MethodType(original_forward, model) @@ -119,7 +149,11 @@ def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT: kwargs, dynamic_shapes=dynamic_shapes_spec, ).module() +<<<<<<< HEAD wrapper._exported_artifact = exported # type: ignore[attr-defined] +======= + wrapper._exported_artifact = exported +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) finally: # Restore the wrapper after export model.forward = wrapper @@ -135,6 +169,13 @@ class _ExportMethod: fallbacks: list[torch.export.ExportedProgram] +<<<<<<< HEAD +======= +_InputT = typing_extensions.ParamSpec("_InputT") +_RetT = typing.TypeVar("_RetT") + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class _ExportPackage: """ An export package is a collection of torch.export()-ed PyTorch models consisting of @@ -217,7 +258,11 @@ def _exporter( - Returns an optional dynamic shape spec. Exporter will only export an overload when the spec callable successfully returns +<<<<<<< HEAD a result without raising AssertionError. +======= + a result without rasing AssertionError. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) For example: ``` @@ -324,8 +369,12 @@ def _exporter_context(*args, **kwargs): # type: ignore[no-untyped-def] if isinstance(fn, torch.nn.Module): _exporter_context = torch._dynamo.eval_frame.OptimizedModule( # type: ignore[assignment] # noqa: F811 +<<<<<<< HEAD fn, lambda _: _exporter_context, # type: ignore[arg-type] +======= + fn, lambda _: _exporter_context +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def _define_overload( @@ -341,6 +390,7 @@ def _define_overload( _exporter_context._define_overload = _define_overload # type: ignore[attr-defined] return _exporter_context +<<<<<<< HEAD @property def _method_overloads( @@ -427,3 +477,5 @@ def _compiled_and_package( ) with open(Path(base_directory) / "main.cpp", "w") as file: file.write(main_file_str) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index 1aa2e59d1752b..e2f953d204716 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -7,10 +7,17 @@ import operator import types import warnings +<<<<<<< HEAD from collections import defaultdict from collections.abc import Iterator from contextlib import contextmanager from typing import Any, Callable, final, NamedTuple, Optional, TYPE_CHECKING, Union +======= +from collections import defaultdict, namedtuple +from collections.abc import Iterator +from contextlib import contextmanager +from typing import Any, Callable, final, Optional, TYPE_CHECKING, Union +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._guards import tracing, TracingContext from torch._higher_order_ops.utils import autograd_not_implemented @@ -40,7 +47,10 @@ import torch import torch.utils._pytree as pytree from torch._export.utils import ( +<<<<<<< HEAD _build_cache, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _collect_all_valid_cia_ops, _collect_and_set_constant_attrs, _collect_param_buffer_metadata, @@ -289,7 +299,11 @@ def _split_decomp_table_to_cia_and_python_decomp( # decomp_table = decomp_table_to_core_aten() # del decomp_table[aten.linear] # In this case, user says decompose everything except for aten.linear +<<<<<<< HEAD # 2. Has been marked with custom decomp behaviour. Example: +======= + # 2. Has been marked with custom decomp behavour. Example: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # decomp_table = {aten.linear: some_op} # For (1), we want to remove all the CIA ops that weren't handled by user as # it suggests they are safe to decompose, so we should remove from preservable_list. @@ -326,7 +340,11 @@ def default_decompositions() -> "CustomDecompTable": def _decompose_and_get_gm_with_new_signature_constants( +<<<<<<< HEAD ep: "ExportedProgram", +======= + ep, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) *, cia_to_decomp: dict[torch._ops.OperatorBase, Callable], python_decomp_table: dict[torch._ops.OperatorBase, Callable], @@ -365,7 +383,11 @@ def _is_joint_ir_decomp(ep, joint_loss_index): # [NOTE] Unwrapping subclasses AOT # In torch.compile, the subclass unwrapping/wrapping happen at runtime +<<<<<<< HEAD # but at export, this is impossible as it is intended to be run on +======= + # but at export, this is impossible as it is intented to be run on +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # C++ environment. As a result, we unwrap subclass parameters AOT. After this, # ExportedProgram state_dict won't be same as eager model because eager model # could have subclass weights while ExportedProgram will have desugared versions. @@ -385,11 +407,17 @@ def _is_joint_ir_decomp(ep, joint_loss_index): # Fix the graph output signature to be tuple if scalar out_spec = mod._out_spec +<<<<<<< HEAD assert isinstance(mod.graph._codegen, _PyTreeCodeGen) orig_arg_names = mod.graph._codegen.pytree_info.orig_args # aot_export expect the return type to always be a tuple. assert out_spec is not None +======= + orig_arg_names = mod.graph._codegen.pytree_info.orig_args + + # aot_export expect the return type to always be a tuple. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if out_spec.type not in (list, tuple): out_spec = pytree.TreeSpec(tuple, None, [out_spec]) @@ -425,7 +453,11 @@ def _is_joint_ir_decomp(ep, joint_loss_index): # TODO (tmanlaibaatar) Ideally run_decomp should just call _non_strict_export # but due to special handling of constants as non-persistent buffers make it little +<<<<<<< HEAD # difficult. But we should unify this code path together. T206837815 +======= + # diffucult. But we should unify this code path together. T206837815 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._export.non_strict_utils import ( _enable_graph_inputs_of_type_nn_module, _fakify_script_objects, @@ -480,6 +512,10 @@ def _is_joint_ir_decomp(ep, joint_loss_index): fake_params_buffers, new_fake_constant_attrs, decomp_table=python_decomp_table, +<<<<<<< HEAD +======= + _check_autograd_state=False, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _prettify_placeholder_names=False, decompose_custom_triton_ops=decompose_custom_triton_ops, ) @@ -573,8 +609,13 @@ def _is_joint_ir_decomp(ep, joint_loss_index): delattr(ep.graph_module, name) # TODO(zhxhchen17) Return the new graph_signature directly. +<<<<<<< HEAD fake_mode_det = detect_fake_mode(fake_args) fake_mode_ctx = contextlib.nullcontext() if fake_mode_det is None else fake_mode_det # type: ignore[assignment] +======= + fake_mode = detect_fake_mode(fake_args) + fake_mode = contextlib.nullcontext() if fake_mode is None else fake_mode # type: ignore[assignment] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) custom_triton_ops_decomposition_ctx = ( contextlib.nullcontext if decompose_custom_triton_ops @@ -582,7 +623,11 @@ def _is_joint_ir_decomp(ep, joint_loss_index): ) with ( _ignore_backend_decomps(), +<<<<<<< HEAD fake_mode_ctx, +======= + fake_mode, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _override_composite_implicit_decomp(cia_to_decomp), custom_triton_ops_decomposition_ctx(), ): @@ -613,7 +658,11 @@ def update_arg(old_arg, new_ph): raise RuntimeError(f"Type of old_arg not supported: {type(old_arg)}") new_placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"] +<<<<<<< HEAD new_outputs: tuple[torch.fx.Node, ...] = tuple(gm.graph.output_node().args[0]) # type: ignore[arg-type] +======= + new_outputs = list(gm.graph.nodes)[-1].args[0] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # rename the placeholders assert len(new_placeholders) == len(old_placeholders) @@ -621,6 +670,7 @@ def update_arg(old_arg, new_ph): new_ph.name = new_ph.target = old_ph.name # handle name collisions with newly decomposed graph nodes +<<<<<<< HEAD name_map = {} find_available: dict[str, int] = defaultdict(int) used_names: set[str] = set() @@ -633,6 +683,13 @@ def update_arg(old_arg, new_ph): node.name = _rename_without_collisions( name_map, find_available, used_names, node.name, node.name ) +======= + name_map = {ph.name: ph.name for ph in new_placeholders} + for node in gm.graph.nodes: + if node.op == "placeholder": + continue + node.name = _rename_without_collisions(name_map, node.name, node.name) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # propagate names to higher order op subgraphs _name_hoo_subgraph_placeholders(gm) @@ -653,10 +710,14 @@ def update_arg(old_arg, new_ph): shape_env = _get_shape_env(gm) if shape_env is not None: with _set_node_metadata_hook( +<<<<<<< HEAD gm, functools.partial( _node_metadata_hook, metadata={"stack_trace": stack_trace} ), +======= + gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): insert_deferred_runtime_asserts( gm, @@ -667,9 +728,15 @@ def update_arg(old_arg, new_ph): # update output specs gm.recompile() +<<<<<<< HEAD for output, name in zip(new_outputs, _graph_output_names(gm)): if name is not None: output.name = name +======= + for i, name in enumerate(_graph_output_names(gm)): + if isinstance(new_outputs[i], torch.fx.Node): + new_outputs[i].name = name +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # To match the output target with correct input for input mutations # need to find the old to new placeholder map @@ -740,7 +807,11 @@ def update_arg(old_arg, new_ph): for i, spec in enumerate(ep.graph_signature.input_specs) if isinstance(spec.arg, TensorArgument) } +<<<<<<< HEAD for node in new_outputs[len(output_specs) :]: +======= + for i, node in enumerate(new_outputs[len(output_specs) :]): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) source = gradients[node.name] spec = specs[source] # type: ignore[index] if spec.kind == InputKind.PARAMETER: @@ -792,9 +863,15 @@ def _remove_unneccessary_copy_op_pass( if node.op == "output": args, _ = pytree.tree_flatten(node.args) for out in args: +<<<<<<< HEAD if isinstance(out, torch.fx.Node) and ( out.name in new_graph_signature.buffers_to_mutate or out.name in new_graph_signature.parameters_to_mutate +======= + if ( + isinstance(out, torch.fx.Node) + and out.name in new_graph_signature.buffers_to_mutate +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): if ( out.op == "call_function" @@ -1023,6 +1100,7 @@ class ExportedProgram: again to construct a correct ExportedProgram. """ +<<<<<<< HEAD _graph_module: torch.fx.GraphModule """The underlying GraphModule containing the exported computation graph.""" @@ -1049,6 +1127,8 @@ class ExportedProgram: _guards_code: list[str] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __init__( self, root: Union[torch.nn.Module, dict[str, Any]], @@ -1086,8 +1166,11 @@ def __init__( # Validate should be always the last step of the constructor. self.validate() +<<<<<<< HEAD self._guards_code = _convert_guards_to_code(_get_shape_env(self._graph_module)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @property @compatibility(is_backward_compatible=False) def graph_module(self): @@ -1225,9 +1308,13 @@ def example_inputs(self, value): @property @compatibility(is_backward_compatible=False) def call_spec(self): +<<<<<<< HEAD class CallSpec(NamedTuple): in_spec: Optional[pytree.TreeSpec] out_spec: Optional[pytree.TreeSpec] +======= + CallSpec = namedtuple("CallSpec", ["in_spec", "out_spec"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if len(self.module_call_graph) == 0: return CallSpec(in_spec=None, out_spec=None) @@ -1304,7 +1391,11 @@ def _get_flat_args_with_check(self, args, kwargs): Returns: A tuple of (flat_args, received_spec) +<<<<<<< HEAD flat_args is flattened args / kwargs +======= + flat_args is flattend args / kwargs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) received_spec is the pytree spec produced while flattening the tuple (args, kwargs) """ @@ -1370,6 +1461,64 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: "You should use `exported_program.module()` instead." ) +<<<<<<< HEAD +======= + def _postprocess_graph_module_outputs(self, res, orig_args, orig_kwargs): + """Process potential mutations to the input. + + Because self.graph_module is functional, so mutations has to be written + back after execution of graph_module. + """ + import torch._export.error as error + + flat_args, _ = self._get_flat_args_with_check(orig_args, orig_kwargs) + if self.call_spec.out_spec is not None: + buffer_mutation = self.graph_signature.buffers_to_mutate + user_input_mutation = self.graph_signature.user_inputs_to_mutate + num_mutated = len(buffer_mutation) + len(user_input_mutation) + mutated_values = res[:num_mutated] + + # Exclude dependency token from final result. + assertion_dep_token = self.graph_signature.assertion_dep_token + if assertion_dep_token is not None: + assertion_dep_token_index = next(iter(assertion_dep_token.keys())) + res = res[:assertion_dep_token_index] + + res = res[num_mutated:] + try: + res = pytree.tree_unflatten(res, self.call_spec.out_spec) + except Exception: + _, received_spec = pytree.tree_flatten(res) + raise error.InternalError( # noqa: B904 + "Trying to flatten user outputs with exported output tree spec: \n" + f"{self.call_spec.out_spec}\n" + "but actually got outputs with tree spec of: \n" + f"{received_spec}" + ) + finally: + user_inputs = [ + spec + for spec in self.graph_signature.input_specs + if spec.kind == InputKind.USER_INPUT + ] + for i, value in enumerate(mutated_values): + output_spec = self.graph_signature.output_specs[i] + if output_spec.kind == OutputKind.BUFFER_MUTATION: + assert output_spec.target is not None + self.state_dict[output_spec.target] = value + elif output_spec.kind == OutputKind.USER_INPUT_MUTATION: + assert output_spec.target is not None + index = next( + i + for i, spec in enumerate(user_inputs) + if spec.arg.name == output_spec.target + ) + flat_args[index].copy_(value) + else: + raise AssertionError(f"Unexpected kind: {output_spec.kind}") + return res + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __str__(self) -> str: graph_module = self.graph_module.print_readable( print_output=False, colored=False @@ -1383,6 +1532,7 @@ def __str__(self) -> str: ) return string +<<<<<<< HEAD def module(self, check_guards=True) -> torch.fx.GraphModule: """ Returns a self contained GraphModule with all the parameters/buffers inlined. @@ -1397,6 +1547,15 @@ def module(self, check_guards=True) -> torch.fx.GraphModule: from ._unlift import _unlift_exported_program_lifted_states module = _unlift_exported_program_lifted_states(self, check_guards=check_guards) +======= + def module(self) -> torch.nn.Module: + """ + Returns a self contained GraphModule with all the parameters/buffers inlined. + """ + from ._unlift import _unlift_exported_program_lifted_states + + module = _unlift_exported_program_lifted_states(self) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _train(self, mode: bool = True): raise NotImplementedError("Calling train() is not supported yet.") @@ -1464,7 +1623,11 @@ def run_decompositions( if isinstance(_decomp_table, CustomDecompTable): _decomp_table = _decomp_table.materialize() +<<<<<<< HEAD # Note [Separating decomp_table into CIA decomps and non-CIA decomps] +======= + # Note [Seperating decomp_table into CIA decomps and non-CIA decomps] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # At this point, we have a decomp_table that contains decomp behaviour for # both CIA and post-autograd ops. # We need to separate the op into two categories: @@ -1688,6 +1851,7 @@ def _create_graph_module_for_export(root, graph): gm._graph = graph return gm +<<<<<<< HEAD def _convert_guards_to_code(shape_env): @@ -1710,3 +1874,5 @@ def _convert_guards_to_code(shape_env): for guard in shape_env.guards if guard.expr.free_symbols.issubset(local_vars) ] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/export/graph_signature.py b/torch/export/graph_signature.py index e8935e359b0ee..f234867cd9ea2 100644 --- a/torch/export/graph_signature.py +++ b/torch/export/graph_signature.py @@ -121,7 +121,10 @@ class OutputKind(Enum): USER_OUTPUT = auto() LOSS_OUTPUT = auto() BUFFER_MUTATION = auto() +<<<<<<< HEAD PARAMETER_MUTATION = auto() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GRADIENT_TO_PARAMETER = auto() GRADIENT_TO_USER_INPUT = auto() USER_INPUT_MUTATION = auto() @@ -164,11 +167,19 @@ class ExportBackwardSignature: class ExportGraphSignature: """ :class:`ExportGraphSignature` models the input/output signature of Export Graph, +<<<<<<< HEAD which is a fx.Graph with stronger invariants guarantees. Export Graph is functional and does not access "states" like parameters or buffers within the graph via ``getattr`` nodes. Instead, :func:`export` guarantees that parameters, buffers, and constant tensors are lifted out of +======= + which is a fx.Graph with stronger invariants gurantees. + + Export Graph is functional and does not access "states" like parameters + or buffers within the graph via ``getattr`` nodes. Instead, :func:`export` + gurantees that parameters, buffers, and constant tensors are lifted out of +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) the graph as inputs. Similarly, any mutations to buffers are not included in the graph either, instead the updated values of mutated buffers are modeled as additional outputs of Export Graph. @@ -372,7 +383,11 @@ def user_outputs(self) -> Collection[Union[int, float, bool, None, str]]: return tuple(user_outputs) # A dictionary mapping graph input node names to parameters. If a graph input +<<<<<<< HEAD # name is found in this dictionary, it is guaranteed to be a lifted parameter. +======= + # name is found in this dictionary, it is guranteed to be a lifted parameter. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @property def inputs_to_parameters(self) -> Mapping[str, str]: return _immutable_dict( @@ -384,7 +399,11 @@ def inputs_to_parameters(self) -> Mapping[str, str]: ) # A dictionary mapping graph input node names to buffers. If a graph input +<<<<<<< HEAD # name is found in this dictionary, it is guaranteed to be a lifted buffer. +======= + # name is found in this dictionary, it is guranteed to be a lifted buffer. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @property def inputs_to_buffers(self) -> Mapping[str, str]: return _immutable_dict( @@ -408,6 +427,7 @@ def buffers_to_mutate(self) -> Mapping[str, str]: ) @property +<<<<<<< HEAD def parameters_to_mutate(self) -> Mapping[str, str]: return _immutable_dict( (s.arg.name, s.target) @@ -418,6 +438,8 @@ def parameters_to_mutate(self) -> Mapping[str, str]: ) @property +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def user_inputs_to_mutate(self) -> Mapping[str, str]: return _immutable_dict( (s.arg.name, s.target) @@ -612,7 +634,10 @@ def _convert_to_export_graph_signature( inputs_to_buffers = graph_signature.inputs_to_buffers user_outputs = set(graph_signature.user_outputs) buffer_mutations = graph_signature.buffers_to_mutate +<<<<<<< HEAD parameter_mutations = graph_signature.parameters_to_mutate +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) user_input_mutations = graph_signature.user_inputs_to_mutate grad_params = ( graph_signature.backward_signature.gradients_to_parameter # type: ignore[union-attr] @@ -674,21 +699,28 @@ def to_output_spec(idx: int, o: ArgumentSpec) -> OutputSpec: if not isinstance(o, TensorArgument): return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None) name = o.name +<<<<<<< HEAD if idx < len(buffer_mutations) + len(parameter_mutations) + len( user_input_mutations ) + len(output_tokens): +======= + if idx < len(buffer_mutations) + len(user_input_mutations) + len(output_tokens): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if name in buffer_mutations: return OutputSpec( kind=OutputKind.BUFFER_MUTATION, arg=o, target=buffer_mutations[name], # type: ignore[index] ) +<<<<<<< HEAD elif name in parameter_mutations: return OutputSpec( kind=OutputKind.PARAMETER_MUTATION, arg=o, target=parameter_mutations[name], # type: ignore[index] ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif name in user_input_mutations: return OutputSpec( kind=OutputKind.USER_INPUT_MUTATION, diff --git a/torch/export/passes/__init__.py b/torch/export/passes/__init__.py index 5e9c5a66008b9..0473f6584995d 100644 --- a/torch/export/passes/__init__.py +++ b/torch/export/passes/__init__.py @@ -52,6 +52,7 @@ def _get_new_device( if isinstance(v, torch.Tensor): ep._constants[k] = v.to(_get_new_device(v.device, location)) +<<<<<<< HEAD # move example_inputs if they exist if ep.example_inputs is not None: args, kwargs = ep.example_inputs @@ -91,6 +92,21 @@ def _get_new_device( else v, node.meta.get("val"), ) +======= + for node in ep.graph.nodes: + # move all the nodes kwargs with burnt-in device + if "device" in node.kwargs: + kwargs = node.kwargs.copy() + kwargs["device"] = _get_new_device(kwargs["device"], location) + node.kwargs = kwargs + # move all the tensor metadata + node.meta["val"] = pytree.tree_map( + lambda v: v.to(_get_new_device(v.device, location)) + if isinstance(v, torch.Tensor) + else v, + node.meta.get("val"), + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ep.validate() return ep diff --git a/torch/export/pt2_archive/_package.py b/torch/export/pt2_archive/_package.py index 19edd03d44e38..a68f411e3a009 100644 --- a/torch/export/pt2_archive/_package.py +++ b/torch/export/pt2_archive/_package.py @@ -11,6 +11,7 @@ import torch import torch.utils._pytree as pytree +<<<<<<< HEAD from torch._export.serde import schema from torch._export.serde.serialize import ( _dataclass_to_dict, @@ -29,6 +30,11 @@ from torch._subclasses.fake_tensor import FakeTensor from torch.export import ExportedProgram from torch.export._tree_utils import reorder_kwargs +======= +from torch._export.serde.serialize import deserialize, serialize, SerializedArtifact +from torch.export._tree_utils import reorder_kwargs +from torch.export.exported_program import ExportedProgram +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.export.pt2_archive._package_weights import ( get_complete, group_weights, @@ -40,16 +46,23 @@ ARCHIVE_FORMAT_VALUE, ARCHIVE_VERSION_PATH, ARCHIVE_VERSION_VALUE, +<<<<<<< HEAD CONSTANTS_CONFIG_FILENAME_FORMAT, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CONSTANTS_DIR, CUSTOM_OBJ_FILENAME_PREFIX, EXTRA_DIR, MODELS_DIR, MODELS_FILENAME_FORMAT, SAMPLE_INPUTS_FILENAME_FORMAT, +<<<<<<< HEAD TENSOR_CONSTANT_FILENAME_PREFIX, WEIGHT_FILENAME_PREFIX, WEIGHTS_CONFIG_FILENAME_FORMAT, +======= + WEIGHT_FILENAME_PREFIX, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) WEIGHTS_DIR, ) from torch.types import FileLike @@ -93,8 +106,11 @@ class PT2ArchiveWriter: """ def __init__(self, archive_path_or_buffer: FileLike): +<<<<<<< HEAD if isinstance(archive_path_or_buffer, str): archive_path_or_buffer = normalize_path_separator(archive_path_or_buffer) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.archive_file = torch._C.PyTorchFileWriter(archive_path_or_buffer) # type: ignore[arg-type] # NOTICE: version here is different from the archive_version # this is the version of zip file format, which is used by PyTorchFileWriter, which write to /.data/version @@ -189,8 +205,11 @@ class PT2ArchiveReader: """ def __init__(self, archive_path_or_buffer: FileLike): +<<<<<<< HEAD if isinstance(archive_path_or_buffer, str): archive_path_or_buffer = normalize_path_separator(archive_path_or_buffer) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.archive_file = torch._C.PyTorchFileReader(archive_path_or_buffer) # type: ignore[arg-type] assert self.read_string(ARCHIVE_FORMAT_PATH) == ARCHIVE_FORMAT_VALUE, ( "Invalid archive format" @@ -238,11 +257,14 @@ def get_file_names(self) -> list[str]: return self.archive_file.get_all_records() +<<<<<<< HEAD is_pt2_package.__module__ = "torch.export.pt2_archive" PT2ArchiveWriter.__module__ = "torch.export.pt2_archive" PT2ArchiveReader.__module__ = "torch.export.pt2_archive" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _package_aoti_files( archive_writer: PT2ArchiveWriter, aoti_files: Optional[AOTI_FILES], @@ -325,6 +347,7 @@ def _package_aoti_files( logger.debug(weights_config) +<<<<<<< HEAD def _is_fake_tensor(t: torch.Tensor) -> bool: return isinstance(t, FakeTensor) @@ -481,6 +504,8 @@ def _package_payload_config( ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _package_exported_programs( archive_writer: PT2ArchiveWriter, exported_programs: Optional[Union[ExportedProgram, dict[str, ExportedProgram]]], @@ -491,11 +516,16 @@ def _package_exported_programs( return if isinstance(exported_programs, ExportedProgram): +<<<<<<< HEAD exported_programs = {"model": exported_programs} +======= + exported_programs = {"model", exported_programs} # type: ignore[assignment] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert isinstance(exported_programs, dict) for model_name, ep in exported_programs.items(): +<<<<<<< HEAD weights_config = _package_state_dict(ep, archive_writer, pickle_protocol) weights_config_file = WEIGHTS_CONFIG_FILENAME_FORMAT.format(model_name) _package_payload_config(archive_writer, weights_config, weights_config_file) @@ -509,10 +539,21 @@ def _package_exported_programs( opset_version, pickle_protocol, ) +======= + artifact: SerializedArtifact = serialize(ep, opset_version, pickle_protocol) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) archive_writer.write_bytes( MODELS_FILENAME_FORMAT.format(model_name), artifact.exported_program ) +<<<<<<< HEAD +======= + # TODO:Consider dedup this with the weights saved in package_aoti_files + archive_writer.write_bytes(f"{WEIGHTS_DIR}{model_name}.pt", artifact.state_dict) + archive_writer.write_bytes( + f"{CONSTANTS_DIR}{model_name}.pt", artifact.constants + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) archive_writer.write_bytes( SAMPLE_INPUTS_FILENAME_FORMAT.format(model_name), artifact.example_inputs, @@ -540,21 +581,38 @@ def package_pt2( opset_version: Optional[dict[str, int]] = None, pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL, ) -> FileLike: +<<<<<<< HEAD r""" Saves the artifacts to a PT2Archive format. The artifact can then be loaded using ``load_pt2``. Args: f (str | os.PathLike[str] | IO[bytes]): A file-like object (has to +======= + """ + Saves the artifacts to a PT2Archive format + (https://docs.google.com/document/d/1RQ4cmywilnFUT1VE-4oTGxwXdc8vowCSZsrRgo3wFA8/edit?tab=t.0#heading=h.v2y2jgnwc56a). + The artifact can then be loaded using ``load_pt2``. + + Args: + f (str | os.PathLike[str] | IO[bytes]) A file-like object (has to +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) implement write and flush) or a string containing a file name. exported_programs (Union[ExportedProgram, dict[str, ExportedProgram]]): The exported program to save, or a dictionary mapping model name to an exported program to save. The exported program will be saved under +<<<<<<< HEAD models/\*.json. If only one ExportedProgram is specified, this will automatically be named "model". aoti_files (Union[list[str], dict[str, list[str]]]): A list of files +======= + models/*.json. If only one ExportedProgram is specified, this will + automatically be named "model". + + aoti_files (Union[list[str], dict[str, list[str]]): A list of files +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) generated by AOTInductor via ``torch._inductor.aot_compile(..., {"aot_inductor.package": True})``, or a dictionary mapping model name to its AOTInductor generated files. @@ -580,7 +638,10 @@ def package_pt2( if not ( (isinstance(f, (io.IOBase, IO)) and f.writable() and f.seekable()) or (isinstance(f, (str, os.PathLike)) and os.fspath(f).endswith(".pt2")) +<<<<<<< HEAD or (isinstance(f, tempfile._TemporaryFileWrapper) and f.name.endswith(".pt2")) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): # TODO: turn this into an error logger.warning( @@ -666,6 +727,7 @@ class PT2ArchiveContents: extra_files: dict[str, Any] +<<<<<<< HEAD def _create_flat_tensor_from_bytes( tensor_bytes: bytes, tensor_meta: schema.TensorMeta, @@ -845,6 +907,8 @@ def _load_constants( return constants +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _load_exported_programs( archive_reader: PT2ArchiveReader, file_names: list[str], @@ -862,6 +926,7 @@ def _load_exported_programs( len(prefix) : -len(suffix) ] # given "models/foo.json" we can now get "foo" +<<<<<<< HEAD sample_inputs_file = SAMPLE_INPUTS_FILENAME_FORMAT.format(model_name) serialized_sample_inputs = archive_reader.read_bytes(sample_inputs_file) @@ -881,6 +946,26 @@ def _load_exported_programs( serialized_sample_inputs, ) +======= + weights_file = f"{WEIGHTS_DIR}{model_name}.pt" + constants_file = f"{CONSTANTS_DIR}{model_name}.pt" + sample_inputs_file = SAMPLE_INPUTS_FILENAME_FORMAT.format(model_name) + + serialized_exported_program = archive_reader.read_bytes(file) + serialized_weights = archive_reader.read_bytes(weights_file) + serialized_constants = archive_reader.read_bytes(constants_file) + serialized_sample_inputs = archive_reader.read_bytes(sample_inputs_file) + + artifact: SerializedArtifact = SerializedArtifact( + serialized_exported_program, + serialized_weights, + serialized_constants, + serialized_sample_inputs, + ) + + # Deserialize ExportedProgram + ep = deserialize(artifact, expected_opset_version) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) exported_programs[model_name] = ep return exported_programs @@ -933,8 +1018,11 @@ def load_pt2( A ``PT2ArchiveContents`` object which contains all the objects in the PT2. """ +<<<<<<< HEAD from torch._inductor.cpp_builder import normalize_path_separator +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not ( (isinstance(f, (io.IOBase, IO)) and f.readable() and f.seekable()) or (isinstance(f, (str, os.PathLike)) and os.fspath(f).endswith(".pt2")) @@ -973,9 +1061,12 @@ def load_pt2( file_end = file[ len(AOTINDUCTOR_DIR) : ] # remove data/aotinductor/ prefix +<<<<<<< HEAD file_end = normalize_path_separator( file_end ) # Win32 need normalize path before split. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) model_name = file_end.split("/")[ 0 ] # split "model_name/...cpp" into "model_name" diff --git a/torch/export/pt2_archive/_package_weights.py b/torch/export/pt2_archive/_package_weights.py index 5e2a360b3dc6a..abc8257deca8f 100644 --- a/torch/export/pt2_archive/_package_weights.py +++ b/torch/export/pt2_archive/_package_weights.py @@ -28,7 +28,11 @@ def __init__(self, tensor: torch.Tensor): def is_complete(self) -> bool: """ +<<<<<<< HEAD Whether the tensor completely overlaps with its underlying storage +======= + Whehter the tensor completely overlaps with its underlying storage +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ return ( self.start == self.storage_ptr @@ -39,7 +43,11 @@ def is_complete(self) -> bool: class Weights(dict): """ A dictionary mapping from weight name to a tuple of (tensor, TensorProperties). +<<<<<<< HEAD tensor represents the actual initial value of the weight. +======= + tensor represents the actual intial value of the weight. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TensorProperties represents the properties of the weight that are needed to recover the weight. We use two separate entries because `tensor` could be a clone of the original weight tensor, diff --git a/torch/export/pt2_archive/constants.py b/torch/export/pt2_archive/constants.py index 772c3c0708412..e280d04645dce 100644 --- a/torch/export/pt2_archive/constants.py +++ b/torch/export/pt2_archive/constants.py @@ -9,9 +9,12 @@ ARCHIVE_VERSION_PATH: str = pt2_archive_constants.ARCHIVE_VERSION_PATH ARCHIVE_VERSION_VALUE: str = pt2_archive_constants.ARCHIVE_VERSION_VALUE CONSTANTS_DIR: str = pt2_archive_constants.CONSTANTS_DIR +<<<<<<< HEAD CONSTANTS_CONFIG_FILENAME_FORMAT: str = ( pt2_archive_constants.CONSTANTS_CONFIG_FILENAME_FORMAT ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CUSTOM_OBJ_FILENAME_PREFIX: str = pt2_archive_constants.CUSTOM_OBJ_FILENAME_PREFIX EXTRA_DIR: str = pt2_archive_constants.EXTRA_DIR MODELS_DIR: str = pt2_archive_constants.MODELS_DIR @@ -23,9 +26,12 @@ TENSOR_CONSTANT_FILENAME_PREFIX: str = ( pt2_archive_constants.TENSOR_CONSTANT_FILENAME_PREFIX ) +<<<<<<< HEAD WEIGHTS_CONFIG_FILENAME_FORMAT: str = ( pt2_archive_constants.WEIGHTS_CONFIG_FILENAME_FORMAT ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) WEIGHT_FILENAME_PREFIX: str = pt2_archive_constants.WEIGHT_FILENAME_PREFIX WEIGHTS_DIR: str = pt2_archive_constants.WEIGHTS_DIR XL_MODEL_WEIGHTS_DIR: str = pt2_archive_constants.XL_MODEL_WEIGHTS_DIR diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py index d09307f66d6b8..e82435c3e295f 100644 --- a/torch/export/unflatten.py +++ b/torch/export/unflatten.py @@ -15,10 +15,17 @@ import torch.fx._pytree as fx_pytree import torch.utils._pytree as pytree from torch._library.fake_class_registry import FakeScriptObject +<<<<<<< HEAD from torch.export import ExportedProgram from torch.export._tree_utils import reorder_kwargs from torch.export.exported_program import ( ConstantArgument, +======= +from torch.export._tree_utils import reorder_kwargs +from torch.export.exported_program import ( + ConstantArgument, + ExportedProgram, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ExportGraphSignature, InputKind, ModuleCallSignature, @@ -27,7 +34,11 @@ SymIntArgument, TensorArgument, ) +<<<<<<< HEAD from torch.fx._symbolic_trace import is_fx_symbolic_tracing +======= +from torch.fx._symbolic_trace import is_fx_tracing +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.fx.graph_module import _get_attr, _get_attr_via_attr_list, _print_readable from torch.utils._pytree import GetAttrKey, SequenceKey @@ -53,6 +64,7 @@ class _AttrKind(Enum): MODULE = "module" +<<<<<<< HEAD @dataclass(frozen=True) class _TensorID: """Custom tensor identifier containing storage, stride, and size information.""" @@ -63,6 +75,8 @@ class _TensorID: storage_offset: int +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) RUN_WITH_INTERPRETER = True @@ -134,11 +148,14 @@ class _SubmoduleBase: _ty: Optional[str] def type_name(self) -> Optional[str]: +<<<<<<< HEAD """ Subclass of this class - InterpreterModule, InterpreterModuleDispatcher, represents corresponding model in eager model. To get this type information for those modules in eager model we need to use this method. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self._ty @@ -163,7 +180,11 @@ def __init__( def forward(self, *args, **kwargs): assert self.graph_module is not None, "Didn't finalize this InterpreterModule" +<<<<<<< HEAD if not is_fx_symbolic_tracing() and ( +======= + if not is_fx_tracing() and ( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.compiler.is_dynamo_compiling() or not self._run_with_interpreter ): # Dynamo cannot trace through torch.fx.Interpreter, so fall back to @@ -295,10 +316,13 @@ def adapt( """NOTE: This adapter may mutate given ``input_args_with_path``.""" ... +<<<<<<< HEAD def get_flat_arg_paths(self) -> list[str]: """Returns a list of paths that are used to access the flat args.""" return [] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class UnflattenedModule(torch.nn.Module): def __init__( @@ -310,6 +334,7 @@ def __init__( if export_module.graph_signature.backward_signature is not None: raise ValueError("Unflattening on JointExportModule NYI") +<<<<<<< HEAD def _id(obj): """Returns _TensorID dataclass for tensors, otherwise id().""" if isinstance(obj, torch.Tensor): @@ -321,6 +346,8 @@ def _id(obj): ) return id(obj) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fqn_list = [entry.fqn for entry in export_module.module_call_graph] assert fqn_list[0] == "" export_graph = deepcopy(export_module.graph) @@ -338,7 +365,10 @@ def _id(obj): self._run_with_interpreter = RUN_WITH_INTERPRETER _inplace_buffer_and_input_mutations(export_graph, self.graph_signature) +<<<<<<< HEAD _fix_nn_module_stacks(export_graph) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.ivals = _IVals() # for any intermediate value of a mutation that is read, track the mutation @@ -365,6 +395,7 @@ def _id(obj): # graph's forward pass (_sink_params). state_dict = export_module.state_dict assigned_params: set[str] = set() # tracking unused params +<<<<<<< HEAD id_to_param: dict[ Union[int, _TensorID], torch.nn.Parameter ] = {} # handling weight-sharing @@ -372,11 +403,22 @@ def _id(obj): param = state_dict[name] if _id(param) not in id_to_param: id_to_param[_id(param)] = torch.nn.Parameter( +======= + id_to_param: dict[int, torch.nn.Parameter] = {} # handling weight-sharing + for name in self.graph_signature.parameters: # this loop adds used params + param = state_dict[name] + if id(param) not in id_to_param: + id_to_param[id(param)] = torch.nn.Parameter( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) param.clone(), requires_grad=param.requires_grad ) _assign_attr( +<<<<<<< HEAD id_to_param[_id(param)], +======= + id_to_param[id(param)], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self, name, attr_kind=_AttrKind.PARAMETER, @@ -385,7 +427,11 @@ def _id(obj): non_persistent_buffers = set(self.graph_signature.non_persistent_buffers) assigned_buffers: set[str] = set() # tracking unused buffers +<<<<<<< HEAD id_to_buffer: dict[Union[int, _TensorID], tuple[torch.nn.Parameter, bool]] = {} +======= + id_to_buffer: dict[int, tuple[torch.nn.Parameter, bool]] = {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for name in self.graph_signature.buffers: # this loop adds used buffers if name in non_persistent_buffers: persistent = False @@ -394,11 +440,19 @@ def _id(obj): persistent = True buffer = state_dict[name] +<<<<<<< HEAD if _id(buffer) not in id_to_buffer: id_to_buffer[_id(buffer)] = (buffer.clone(), persistent) _assign_attr( id_to_buffer[_id(buffer)][0], +======= + if id(buffer) not in id_to_buffer: + id_to_buffer[id(buffer)] = (buffer.clone(), persistent) + + _assign_attr( + id_to_buffer[id(buffer)][0], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self, name, attr_kind=_AttrKind.BUFFER, @@ -413,37 +467,59 @@ def _id(obj): continue is_buffer = False +<<<<<<< HEAD if _id(tensor) in id_to_buffer or not isinstance( +======= + if id(tensor) in id_to_buffer or not isinstance( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tensor, torch.nn.Parameter ): # aliased buffer is_buffer = True if is_buffer: if ( +<<<<<<< HEAD _id(tensor) not in id_to_buffer ): # this is completely unused (not weight-sharing) id_to_buffer[_id(tensor)] = ( +======= + id(tensor) not in id_to_buffer + ): # this is completely unused (not weight-sharing) + id_to_buffer[id(tensor)] = ( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tensor, True, ) # assign to respect original model _assign_attr( +<<<<<<< HEAD id_to_buffer[_id(tensor)][0], +======= + id_to_buffer[id(tensor)][0], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self, name, attr_kind=_AttrKind.BUFFER, persistent=True, ) else: +<<<<<<< HEAD if _id(tensor) not in id_to_param: # this is unused id_to_param[_id(tensor)] = tensor _assign_attr( id_to_param[_id(tensor)], +======= + if id(tensor) not in id_to_param: # this is unused + id_to_param[id(tensor)] = tensor + _assign_attr( + id_to_param[id(tensor)], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self, name, attr_kind=_AttrKind.PARAMETER, ) # use id map so we don't double-clone aliased constants +<<<<<<< HEAD id_to_const: dict[ Union[int, _TensorID], Union[torch.Tensor, torch._C.ScriptObject] ] = {} @@ -453,6 +529,15 @@ def _id(obj): constant = constant.clone() id_to_const[_id(constant)] = constant _constant = id_to_const[_id(constant)] +======= + id_to_const: dict[int, Union[torch.Tensor, torch._C.ScriptObject]] = {} + for fqn, constant in export_module.constants.items(): + if id(constant) not in id_to_const: + if isinstance(constant, torch.Tensor): + constant = constant.clone() + id_to_const[id(constant)] = constant + _constant = id_to_const[id(constant)] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _assign_attr( _constant, self, @@ -462,18 +547,26 @@ def _id(obj): # This is to handle parameters/buffers that point to the same tensor # object id -> list of (node_name, target_name) +<<<<<<< HEAD consts_map: dict[Union[int, _TensorID], list[tuple[str, str]]] = defaultdict( list ) +======= + consts_map: dict[int, list[tuple[str, str]]] = defaultdict(list) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) consts_targets: set[str] = set() def add_to_consts_map(obj_id, node_name, target_name): name_list = consts_map[obj_id] name_list.append((node_name, target_name)) +<<<<<<< HEAD # track aliased/unused params, buffers # prefer using untyped_storage() over id() when it's available added_params_buffers: set[str] = set() +======= + added_params_buffers: set[str] = set() # track aliased/unused params, buffers +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for s in self.graph_signature.input_specs: if s.kind == InputKind.PARAMETER or ( s.kind == InputKind.BUFFER and s.persistent @@ -481,47 +574,76 @@ def add_to_consts_map(obj_id, node_name, target_name): assert hasattr(s.arg, "name") assert isinstance(s.target, str) add_to_consts_map( +<<<<<<< HEAD _id(export_module.state_dict[s.target]), s.arg.name, s.target, +======= + id(export_module.state_dict[s.target]), s.arg.name, s.target +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) consts_targets.add(s.target) added_params_buffers.add(s.target) elif ( +<<<<<<< HEAD s.kind == InputKind.BUFFER and not s.persistent +======= + (s.kind == InputKind.BUFFER and not s.persistent) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) or s.kind == InputKind.CONSTANT_TENSOR or s.kind == InputKind.CUSTOM_OBJ ): assert hasattr(s.arg, "name") assert isinstance(s.target, str) add_to_consts_map( +<<<<<<< HEAD _id(export_module.constants[s.target]), s.arg.name, s.target, +======= + id(export_module.constants[s.target]), s.arg.name, s.target +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) consts_targets.add(s.target) # add constants that are aliased and don't appear in graph signature for const_name, const in export_module.constants.items(): if const_name not in consts_targets: +<<<<<<< HEAD const_id = _id(const) assert const_id in consts_map ph_name, _ = consts_map[const_id][0] add_to_consts_map(const_id, ph_name, const_name) +======= + assert id(const) in consts_map, ( + "Constants should be either aliased or appear in graph signature" + ) + ph_name, _ = consts_map[id(const)][0] + add_to_consts_map(id(const), ph_name, const_name) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) added_params_buffers.add(s.target) # add aliased/unused params and buffers that don't appear in graph signature for fqn, tensor in export_module.state_dict.items(): if fqn not in added_params_buffers: +<<<<<<< HEAD tensor_id = _id(tensor) if tensor_id not in consts_map: +======= + if id(tensor) not in consts_map: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # completely unused (no weight-sharing), ignore. # this weight doesn't appear in graph module, # so won't cause FQN assignment issues continue +<<<<<<< HEAD ph_name, _ = consts_map[tensor_id][0] add_to_consts_map(tensor_id, ph_name, fqn) +======= + ph_name, _ = consts_map[id(tensor)][0] + add_to_consts_map(id(tensor), ph_name, fqn) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # node name -> list of possible targets inputs_to_state: dict[str, list[str]] = {} @@ -600,7 +722,11 @@ def process_forward_inputs(self, *args, **kwargs): ) flat_args = [x[1] for x in flat_args_with_path] +<<<<<<< HEAD if is_fx_symbolic_tracing(): +======= + if is_fx_tracing(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return flat_args if in_spec != signature.in_spec: @@ -620,6 +746,7 @@ def process_forward_inputs(self, *args, **kwargs): from torch._export.utils import _check_input_constraints_for_graph if self.adapted is True: +<<<<<<< HEAD flat_arg_paths = ( self.flat_args_adapter.get_flat_arg_paths() if self.flat_args_adapter @@ -639,6 +766,14 @@ def process_forward_inputs(self, *args, **kwargs): arg, ) for idx, arg in enumerate(flat_args) +======= + # TODO(suo): The FlatArgsAdapter returns a list of flat args, + # which we don't have keypaths for. For now, just create a dummy + # keypath to associate with the arg. + new_flat_args_with_path = [ # type: ignore[var-annotated] + ((SequenceKey(idx=0), GetAttrKey(name="")), arg) + for arg in flat_args +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] else: new_flat_args_with_path = flat_args_with_path # type: ignore[assignment] @@ -650,10 +785,20 @@ def process_forward_inputs(self, *args, **kwargs): return flat_args def forward(self, *args, **kwargs): +<<<<<<< HEAD flat_args = self.process_forward_inputs(*args, **kwargs) signature = self.module_call_graph[0].signature if is_fx_symbolic_tracing(): +======= + flat_args = torch._dynamo.disable( + self.process_forward_inputs, + reason="do not trace into preprocessing the inputs", + )(*args, **kwargs) + signature = self.module_call_graph[0].signature + + if is_fx_tracing(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return_val = torch.fx.Interpreter(self, graph=self.graph).run( *flat_args, enable_io_processing=False ) @@ -755,7 +900,11 @@ def unflatten( """Unflatten an ExportedProgram, producing a module with the same module hierarchy as the original eager module. This can be useful if you are trying to use :mod:`torch.export` with another system that expects a module +<<<<<<< HEAD hierarchy instead of the flat graph that :mod:`torch.export` usually produces. +======= + hierachy instead of the flat graph that :mod:`torch.export` usually produces. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .. note:: The args/kwargs of unflattened modules will not necessarily match the eager module, so doing a module swap (e.g. :code:`self.submod = @@ -772,6 +921,7 @@ def unflatten( hierarchy as the original eager module pre-export. """ module = _remove_effect_tokens(module) +<<<<<<< HEAD m = UnflattenedModule(module, flat_args_adapter) # Disable process_forward_inputs as the adapter has many @@ -783,6 +933,9 @@ def unflatten( ) return m +======= + return UnflattenedModule(module, flat_args_adapter) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _inplace_buffer_and_input_mutations( @@ -857,6 +1010,7 @@ def forward(self, buffer, x): output_node.args = ((user_outputs),) +<<<<<<< HEAD def _fix_nn_module_stacks(graph): # For each nn module stack in the graph, check if the fqns in it represent a stack: # 1. Each fqn must be a prefix of the next fqn. @@ -896,6 +1050,8 @@ def _fix_nn_module_stacks(graph): ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _is_prefix(candidate, target): """Check whether `candidate` is a prefix of `target`.""" return len(candidate) < len(target) and target[: len(candidate)] == candidate diff --git a/torch/functional.py b/torch/functional.py index b5fcf8240c83f..59aa82b6e6fb9 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -105,9 +105,64 @@ def broadcast_shapes(*shapes): # This wrapper exists to support variadic args. # TODO Move this to C++ once the jit has better support for torch.Size. if not torch.jit.is_tracing(): +<<<<<<< HEAD result = torch._refs._broadcast_shapes(*shapes) if result is None: return torch.Size([]) +======= + max_len = 0 + for shape in shapes: + if isinstance(shape, (int, torch.SymInt)): + if max_len < 1: + max_len = 1 + elif isinstance(shape, (tuple, list)): + s = len(shape) + if max_len < s: + max_len = s + result = [1] * max_len + + from torch.fx.experimental.symbolic_shapes import ( + guard_size_oblivious, + is_nested_int, + ) + + for shape in shapes: + if isinstance(shape, (int, torch.SymInt)): + shape = (shape,) + if isinstance(shape, (tuple, list)): + for i in range(-1, -1 - len(shape), -1): + if shape[i] < 0: + raise RuntimeError( + f"Trying to create tensor with negative dimension ({shape[i]}): ({shape[i]})" + ) + + # NB: handle nested ints specially to avoid invalid guarding on Ne(j0, 1). + if is_nested_int(shape[i]): + # Broadcasting is allowed for (j0, 1) or (j0, j0); + # not (j0, j1), (j0, 5), etc. + if is_nested_int(result[i]) and guard_size_oblivious( + shape[i] == result[i] + ): + continue + else: + # NB: result is initialized to 1 so this is effectively an + # equals one test + if guard_size_oblivious(shape[i] == 1) or guard_size_oblivious( + shape[i] == result[i] + ): + continue + + if result[i] != 1: + raise RuntimeError( + "Shape mismatch: objects cannot be broadcast to a single shape" + ) + result[i] = shape[i] + else: + raise RuntimeError( + "Input shapes should be of type ints, a tuple of ints, or a list of ints, got ", + shape, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return torch.Size(result) else: # with implementation above, torch.jit.trace hardcodes the sizes which makes subsequent replays fail @@ -363,7 +418,11 @@ def parse_subscript(n: int) -> str: if len(operands) == 1 and isinstance(operands[0], (list, tuple)): # the old interface of passing the operands as one list argument _operands = operands[0] +<<<<<<< HEAD # recurse in case operands contains value that has torch function +======= + # recurse incase operands contains value that has torch function +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # in the original implementation this line is omitted return einsum(equation, *_operands) @@ -1473,7 +1532,11 @@ def atleast_1d(*tensors): Input tensors with one or more dimensions are returned as-is. Args: +<<<<<<< HEAD input (Tensor or sequence of Tensors): tensor(s) to be converted to at least 1-dimensional. +======= + input (Tensor or list of Tensors) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Returns: output (Tensor or tuple of Tensors) @@ -1494,8 +1557,11 @@ def atleast_1d(*tensors): >>> y = torch.tensor(1.) >>> torch.atleast_1d((x, y)) (tensor([0.5000]), tensor([1.])) +<<<<<<< HEAD >>> torch.atleast_1d() () +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ # This wrapper exists to support variadic args. if has_torch_function(tensors): @@ -1511,7 +1577,11 @@ def atleast_2d(*tensors): Input tensors with two or more dimensions are returned as-is. Args: +<<<<<<< HEAD input (Tensor or sequence of Tensors): tensor(s) to be converted to at least 2-dimensional. +======= + input (Tensor or list of Tensors) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Returns: output (Tensor or tuple of Tensors) @@ -1534,8 +1604,11 @@ def atleast_2d(*tensors): >>> y = torch.tensor(1.) >>> torch.atleast_2d((x, y)) (tensor([[0.5000]]), tensor([[1.]])) +<<<<<<< HEAD >>> torch.atleast_2d() () +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ # This wrapper exists to support variadic args. if has_torch_function(tensors): @@ -1551,7 +1624,11 @@ def atleast_3d(*tensors): Input tensors with three or more dimensions are returned as-is. Args: +<<<<<<< HEAD input (Tensor or sequence of Tensors): tensor(s) to be converted to at least 3-dimensional. +======= + input (Tensor or list of Tensors) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Returns: output (Tensor or tuple of Tensors) @@ -1582,8 +1659,11 @@ def atleast_3d(*tensors): >>> y = torch.tensor(1.0) >>> torch.atleast_3d((x, y)) (tensor([[[0.5000]]]), tensor([[[1.]]])) +<<<<<<< HEAD >>> torch.atleast_3d() () +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ # This wrapper exists to support variadic args. if has_torch_function(tensors): @@ -2075,8 +2155,12 @@ def _lu_impl(A, pivot=True, get_infos=False, out=None): Args: A (Tensor): the tensor to factor of size :math:`(*, m, n)` +<<<<<<< HEAD pivot (bool, optional): Whether to compute the LU decomposition with partial pivoting, or the regular LU decomposition. :attr:`pivot`\ `= False` not supported on CPU. Default: `True`. +======= + pivot (bool, optional): controls whether pivoting is done. Default: ``True`` +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) get_infos (bool, optional): if set to ``True``, returns an info IntTensor. Default: ``False`` out (tuple, optional): optional output tuple. If :attr:`get_infos` is ``True``, diff --git a/torch/futures/__init__.py b/torch/futures/__init__.py index 76a479d965a3c..6fa6d47aea6ec 100644 --- a/torch/futures/__init__.py +++ b/torch/futures/__init__.py @@ -13,7 +13,15 @@ S = TypeVar("S") +<<<<<<< HEAD class Future(torch._C.Future, Generic[T]): +======= +class _PyFutureMeta(type(torch._C.Future), type(Generic)): # type: ignore[misc, no-redef] + pass + + +class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r""" Wrapper around a ``torch._C.Future`` which encapsulates an asynchronous execution of a callable, e.g. :meth:`~torch.distributed.rpc.rpc_async`. It diff --git a/torch/fx/README.md b/torch/fx/README.md index 3d42cb9375d43..ea241276e82b6 100644 --- a/torch/fx/README.md +++ b/torch/fx/README.md @@ -70,7 +70,11 @@ Here, we set up a simple Module that exercises different language features: fetc The `fx.Graph` is a core data structure in FX that represents the operations and their dependencies in a structured format. It consists of a List of `fx.Node` representing individual operations and their inputs and outputs. The Graph enables simple manipulation and analysis of the model structure, which is essential for implementing various transformations and optimizations. ## Node +<<<<<<< HEAD An `fx.Node` is a data structure that represents individual operations within an `fx.Graph`, it maps to callsites such as operators, methods and modules. Each `fx.Node` keeps track of its inputs, the previous and next nodes, the stacktrace so you can map back the node to a line of code in your python file and some optional metadata stored in a `meta` dict. +======= +An `fx.Node` is a datastructure that represent individual operations within an `fx.Graph`, it maps to callsites such as operators, methods and modules. Each `fx.Node` keeps track of its inputs, the previous and next nodes, the stacktrace so you can map back the node to a line of code in your python file and some optional metadata stored in a `meta` dict. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ## [GraphModule](https://pytorch.org/docs/main/fx.html#torch.fx.GraphModule) ## The `fx.GraphModule` is a subclass of `nn.Module` that holds the transformed Graph, the original module's parameter attributes and its source code. It serves as the primary output of FX transformations and can be used like any other `nn.Module`. `fx.GraphModule` allows for the execution of the transformed model, as it generates a valid forward method based on the Graph's structure. @@ -115,11 +119,19 @@ Tracing captures an intermediate representation (IR), which is represented as a Node is the data structure that represents individual operations within a Graph. For the most part, Nodes represent callsites to various entities, such as operators, methods, and Modules (some exceptions include Nodes that specify function inputs and outputs). Each Node has a function specified by its `op` property. The Node semantics for each value of `op` are as follows: +<<<<<<< HEAD - `placeholder` represents a function input. The `name` attribute specifies the name this value will take on. `target` is similarly the name of the argument. `args` holds either: 1) nothing, or 2) a single argument denoting the default parameter of the function input. `kwargs` is ignored. Placeholders correspond to the function parameters (e.g. `x`) in the graph printout. - `get_attr` retrieves a parameter from the module hierarchy. `name` is similarly the name the result of the fetch is assigned to. `target` is the fully-qualified name of the parameter's position in the module hierarchy. `args` and `kwargs` are ignored - `call_function` applies a free function to some values. `name` is similarly the name of the value to assign to. `target` is the function to be applied. `args` and `kwargs` represent the arguments to the function, following the Python calling convention - `call_module` applies a module in the module hierarchy's `forward()` method to given arguments. `name` is as previous. `target` is the fully-qualified name of the module in the module hierarchy to call. `args` and `kwargs` represent the arguments to invoke the module on, *including the self argument*. - `call_method` calls a method on a value. `name` is similar. `target` is the string name of the method to apply to the `self` argument. `args` and `kwargs` represent the arguments to invoke the module on, *including the self argument* +======= +- `placeholder` represents a function input. The `name` attribute specifies the name this value will take on. `target` is similarly the name of the argument. `args` holds either: 1) nothing, or 2) a single argument denoting the default parameter of the function input. `kwargs` is don't-care. Placeholders correspond to the function parameters (e.g. `x`) in the graph printout. +- `get_attr` retrieves a parameter from the module hierarchy. `name` is similarly the name the result of the fetch is assigned to. `target` is the fully-qualified name of the parameter's position in the module hierarchy. `args` and `kwargs` are don't-care +- `call_function` applies a free function to some values. `name` is similarly the name of the value to assign to. `target` is the function to be applied. `args` and `kwargs` represent the arguments to the function, following the Python calling convention +- `call_module` applies a module in the module hierarchy's `forward()` method to given arguments. `name` is as previous. `target` is the fully-qualified name of the module in the module hierarchy to call. `args` and `kwargs` represent the arguments to invoke the module on, *including the self argument*. +- `call_method` calls a method on a value. `name` is as similar. `target` is the string name of the method to apply to the `self` argument. `args` and `kwargs` represent the arguments to invoke the module on, *including the self argument* +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - `output` contains the output of the traced function in its `args[0]` attribute. This corresponds to the "return" statement in the Graph printout. To facilitate easier analysis of data dependencies, Nodes have read-only properties `input_nodes` and `users`, which specify which Nodes in the Graph are used by this Node and which Nodes use this Node, respectively. Although Nodes are represented as a doubly-linked list, the use-def relationships form an acyclic graph and can be traversed as such. diff --git a/torch/fx/__init__.py b/torch/fx/__init__.py index c048b4fdd8f89..eb0cc9ba3a0c1 100644 --- a/torch/fx/__init__.py +++ b/torch/fx/__init__.py @@ -52,7 +52,11 @@ def forward(self, x): The **symbolic tracer** performs "symbolic execution" of the Python code. It feeds fake values, called Proxies, through the code. Operations +<<<<<<< HEAD on these Proxies are recorded. More information about symbolic tracing +======= +on theses Proxies are recorded. More information about symbolic tracing +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) can be found in the :func:`symbolic_trace` and :class:`Tracer` documentation. diff --git a/torch/fx/_graph_pickler.py b/torch/fx/_graph_pickler.py index a53cefb2c0189..7e9d142511686 100644 --- a/torch/fx/_graph_pickler.py +++ b/torch/fx/_graph_pickler.py @@ -212,6 +212,11 @@ def __init__(self, node: SymNode) -> None: self.hint = node._hint def _to_sym_node(self) -> SymNode: +<<<<<<< HEAD +======= + from torch.fx.experimental.sym_node import SymNode + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert self.shape_env is not None return SymNode(self.expr, self.shape_env, self.pytype, self.hint) diff --git a/torch/fx/_lazy_graph_module.py b/torch/fx/_lazy_graph_module.py index 83ce51fddd040..50d64d0dc1779 100644 --- a/torch/fx/_lazy_graph_module.py +++ b/torch/fx/_lazy_graph_module.py @@ -127,6 +127,14 @@ def _lazy_forward(self, *args, **kwargs): forward = _lazy_forward +<<<<<<< HEAD +======= + # TODO: we shold handle __reduce_deploy__ the same way as __reduce_package__, + # or __reduce__ by calling _real_recompile. But I don't find a good way + # to test __reduce_deploy__ out. Also it's very unlikely that LazyGraphModule + # will be used in torch::deploy. So it's skipped for now. + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __reduce_package__(self, exporter: PackageExporter): """ Follow GraphModule.__reduce__ but call 'self._real_recompile' rather diff --git a/torch/fx/_pytree.py b/torch/fx/_pytree.py index 7a31e4ef3cfa6..f681dbc9826e6 100644 --- a/torch/fx/_pytree.py +++ b/torch/fx/_pytree.py @@ -42,7 +42,11 @@ def tree_flatten_spec( # I guess these exist for BC, FC reasons. # In general, we should be able to directly # use pytree tree flattener to flatten them, +<<<<<<< HEAD # as export serializes the pytree separately. +======= + # as export serializes the pytree seperately. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Will remove it in follow up PR. if spec.type in SUPPORTED_NODES: flatten_fn_spec = SUPPORTED_NODES[spec.type] diff --git a/torch/fx/_symbolic_trace.py b/torch/fx/_symbolic_trace.py index 4775bef4ba318..fda97f7f0b3ca 100644 --- a/torch/fx/_symbolic_trace.py +++ b/torch/fx/_symbolic_trace.py @@ -5,7 +5,10 @@ import copy import functools import inspect +<<<<<<< HEAD import logging +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import math import os import warnings @@ -27,8 +30,11 @@ from .proxy import ParameterProxy, Proxy, Scope, ScopeContextManager, TracerBase +<<<<<<< HEAD log = logging.getLogger(__name__) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS # These need to run in global scope to handle nested calls correctly @@ -46,6 +52,7 @@ _constant_attribute_types = get_args(_ConstantAttributeType) +<<<<<<< HEAD # We only want to print this once to avoid flooding logs @functools.lru_cache def is_fx_tracing_warning(): @@ -66,6 +73,12 @@ def is_fx_symbolic_tracing(): return _is_fx_tracing_flag and not torch.compiler.is_compiling() +======= +def is_fx_tracing(): + return _is_fx_tracing_flag + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @compatibility(is_backward_compatible=True) class ProxyableClassMeta(type): """ @@ -164,7 +177,11 @@ def _patch_function(fn: FunctionType, nargs: int) -> FunctionType: co.co_name, co.co_qualname, # type: ignore[attr-defined] co.co_firstlineno, +<<<<<<< HEAD co.co_linetable, +======= + co.co_lnotab, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) co.co_exceptiontable, # type: ignore[attr-defined] co.co_freevars, co.co_cellvars, @@ -454,6 +471,10 @@ def create_arg(self, a: Any) -> "Argument": setattr(self.root, qualname, a) return self.create_node("get_attr", qualname, (), {}) +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return super().create_arg(a) @compatibility(is_backward_compatible=True) @@ -712,7 +733,11 @@ def proxy_placeholder(name): # In the case that we have pytree-flattened inputs in # `concrete_args`, generate a flattening wrapper around the # original root function and return that. +<<<<<<< HEAD self.graph._codegen = _PyTreeCodeGen( # type: ignore[has-type] +======= + self.graph._codegen = _PyTreeCodeGen( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _PyTreeInfo(orig_args[:total_args], in_spec, None) ) @@ -720,7 +745,11 @@ def flatten_fn(*args): tree_args = pytree.tree_unflatten(list(args), in_spec) tree_out = root_fn(*tree_args) out_args, out_spec = pytree.tree_flatten(tree_out) +<<<<<<< HEAD assert isinstance(self.graph._codegen, _PyTreeCodeGen) # type: ignore[has-type] +======= + assert isinstance(self.graph._codegen, _PyTreeCodeGen) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.graph._codegen.pytree_info = ( self.graph._codegen.pytree_info._replace(out_spec=out_spec) ) diff --git a/torch/fx/config.py b/torch/fx/config.py index db06176c43e13..fcd6d1ef6ebf9 100644 --- a/torch/fx/config.py +++ b/torch/fx/config.py @@ -1,5 +1,9 @@ # Whether to disable showing progress on compilation passes +<<<<<<< HEAD # Need to add a new config otherwise will get a circular import if dynamo config is imported here +======= +# Need to add a new config otherwise wil get a circular import if dynamo config is imported here +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) disable_progress = True # If True this also shows the node names in each pass, for small models this is great but larger models it's quite noisy diff --git a/torch/fx/experimental/const_fold.py b/torch/fx/experimental/const_fold.py index 3e53cb908fbfc..7b80fa1336d7f 100644 --- a/torch/fx/experimental/const_fold.py +++ b/torch/fx/experimental/const_fold.py @@ -164,9 +164,12 @@ def split_const_subgraphs( attributes on the module prior to running the non-constant portion of the graph. """ +<<<<<<< HEAD import sympy +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not isinstance(module, torch.fx.GraphModule): mod_traced = torch.fx.symbolic_trace(module) else: @@ -197,10 +200,13 @@ def split_const_subgraphs( if node.is_impure(): continue +<<<<<<< HEAD # Skip folding nodes that have symbolic fill_value if isinstance(node.kwargs.get("fill_value", None), sympy.Expr): continue +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Must be a constant foldable node at this point. const_nodes.add(node) if node.op != "get_attr": diff --git a/torch/fx/experimental/graph_gradual_typechecker.py b/torch/fx/experimental/graph_gradual_typechecker.py index a8798a6a0726a..23d3e47466848 100644 --- a/torch/fx/experimental/graph_gradual_typechecker.py +++ b/torch/fx/experimental/graph_gradual_typechecker.py @@ -82,7 +82,11 @@ def expand_to_tensor_dim(t, n): def broadcast_types(t1, t2): """ Applies broadcasting to both given types such that they +<<<<<<< HEAD become consistent with each other and returns two new +======= + become consistent with eachother and returns two new +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) resulting types """ @@ -846,7 +850,11 @@ def flatten_refinement_rule(n: Node): @register_algebraic_expressions_inference_rule(Conv2d) def conv_rule(n: Node, module_instance): """ +<<<<<<< HEAD Represents the output in terms of an algrbraic expression w.r.t +======= + Represents the outout in terms of an algrbraic expression w.r.t +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) the input when possible """ assert isinstance(n.args[0], Node) diff --git a/torch/fx/experimental/migrate_gradual_types/constraint.py b/torch/fx/experimental/migrate_gradual_types/constraint.py index 388d716245d4f..40506479a9530 100644 --- a/torch/fx/experimental/migrate_gradual_types/constraint.py +++ b/torch/fx/experimental/migrate_gradual_types/constraint.py @@ -164,7 +164,11 @@ class TGreatestUpperBound(Constraint): def __init__(self, res, rhs1, rhs2): """ +<<<<<<< HEAD :param res: tensor variable that stores the result of the output +======= + :param res: tensor variable that stores the result of the outout +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) :param rhs1: tensor or tensor variable :param rhs2: tensor or tensor variabke """ @@ -407,7 +411,11 @@ def __init__( """ :param conv_result: the convolution result :param input_var: input to convolution +<<<<<<< HEAD :param c_out: output channel type +======= + :param c_out: output chanel type +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) :param kernel: kernel tuple """ self.conv_result = conv_result diff --git a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py index e4951aab15cbf..826bbb4594c6d 100644 --- a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py +++ b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py @@ -532,7 +532,11 @@ def view_inference_rule(n: Node, symbols, constraints, counter): else: num_constraints.append(BinConstraintD(t, Dyn, op_neq)) +<<<<<<< HEAD t2_type.append(t) # type: ignore[arg-type] +======= + t2_type.append(t) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) t2_type = TensorType(t2_type) # type: ignore[assignment] @@ -681,7 +685,11 @@ def getitem_inference_rule(n: Node, symbols, constraints, counter): # tensor output case elif isinstance(n.args[1], tuple): # create and store the new tensor variable +<<<<<<< HEAD get_item_output, counter = gen_tvar(counter) # type: ignore[arg-type,assignment] +======= + get_item_output, counter = gen_tvar(counter) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) symbols[n] = get_item_output # retrieve arg variables @@ -1073,7 +1081,11 @@ def broadcasting_inference_rule(n: Node, symbols, constraints, counter): e1 = symbols[n.args[0]] return [BinConstraintT(my_output, e1, op_eq)], counter elif isinstance(symbols[n.args[0]], DVar): +<<<<<<< HEAD my_output, counter = gen_dvar(counter) # type: ignore[arg-type,assignment] +======= + my_output, counter = gen_dvar(counter) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) symbols[n] = my_output e1 = symbols[n.args[0]] @@ -1095,7 +1107,11 @@ def broadcasting_inference_rule(n: Node, symbols, constraints, counter): e2 = symbols[n.args[1]] return [BinConstraintT(my_output, e2, op_eq)], counter elif isinstance(symbols[n.args[1]], DVar): +<<<<<<< HEAD my_output, counter = gen_dvar(counter) # type: ignore[arg-type,assignment] +======= + my_output, counter = gen_dvar(counter) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) symbols[n] = my_output e2 = symbols[n.args[1]] diff --git a/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py b/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py index 9b84c12127f0f..021551e980115 100644 --- a/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py +++ b/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py @@ -823,7 +823,11 @@ def calc_last_two_dims(constraint, d: list[DVar]): [BinConstraintD(d[3], Dyn, op_neq), BinConstraintD(b4, Dyn, op_neq)] ) +<<<<<<< HEAD # transform parameters into tuples in case they are not already +======= + # transform parameters into tuples incase they are not already +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) padding = ( (constraint.padding, constraint.padding) if isinstance(constraint.padding, int) diff --git a/torch/fx/experimental/migrate_gradual_types/util.py b/torch/fx/experimental/migrate_gradual_types/util.py index b160ec8de70f9..5e1a85a19a3e9 100644 --- a/torch/fx/experimental/migrate_gradual_types/util.py +++ b/torch/fx/experimental/migrate_gradual_types/util.py @@ -1,3 +1,7 @@ +<<<<<<< HEAD +======= +# mypy: allow-untyped-defs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.fx.experimental.migrate_gradual_types.constraint import ( BinConstraintD, BVar, @@ -7,7 +11,11 @@ from torch.fx.experimental.migrate_gradual_types.operation import op_leq +<<<<<<< HEAD def gen_tvar(curr: int) -> tuple[TVar, int]: +======= +def gen_tvar(curr): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Generate a tensor variable :param curr: The current counter @@ -17,7 +25,11 @@ def gen_tvar(curr: int) -> tuple[TVar, int]: return TVar(curr), curr +<<<<<<< HEAD def gen_dvar(curr: int) -> tuple[DVar, int]: +======= +def gen_dvar(curr): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Generate a dimension variable :param curr: the current counter @@ -27,7 +39,11 @@ def gen_dvar(curr: int) -> tuple[DVar, int]: return DVar(curr), curr +<<<<<<< HEAD def gen_bvar(curr: int) -> tuple[BVar, int]: +======= +def gen_bvar(curr): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Generate a boolean variable :param curr: the current counter @@ -37,7 +53,11 @@ def gen_bvar(curr: int) -> tuple[BVar, int]: return BVar(curr), curr +<<<<<<< HEAD def gen_tensor_dims(n: int, curr: int) -> tuple[list[DVar], int]: +======= +def gen_tensor_dims(n, curr): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Generate a list of tensor dimensions :param n: the number of dimensions @@ -51,7 +71,11 @@ def gen_tensor_dims(n: int, curr: int) -> tuple[list[DVar], int]: return dims, curr +<<<<<<< HEAD def gen_nat_constraints(list_of_dims: list[DVar]) -> list[BinConstraintD]: +======= +def gen_nat_constraints(list_of_dims): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Generate natural number constraints for dimensions """ diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index ae4d1c59823a2..69911a9a7d5f6 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -11,7 +11,11 @@ import inspect import logging import operator +<<<<<<< HEAD import threading +======= +import traceback +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import typing import typing_extensions import weakref @@ -67,6 +71,10 @@ ) from torch.utils._stats import count from torch.utils._thunk import Thunk +<<<<<<< HEAD +======= +from torch.utils._traceback import CapturedTraceback +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.utils.weak import _WeakHashRef, WeakIdKeyDictionary, WeakTensorKeyDictionary from ._backward_state import BackwardState @@ -180,7 +188,11 @@ def is_sym_node(node: _HasMeta) -> bool: return "val" in node.meta and isinstance(node.meta["val"], py_sym_types) +<<<<<<< HEAD @overload # type: ignore[no-overload-impl] +======= +@overload +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def set_proxy_slot(obj: Tensor, tracer: _ProxyTracer, proxy: _ProxyTensor) -> None: ... @@ -196,6 +208,7 @@ def set_proxy_slot( ) -> None: ... +<<<<<<< HEAD class _DisableUpdateTensorTracker(threading.local): value: bool = False @@ -259,6 +272,9 @@ def f_graph_2(x, y, z): def set_proxy_slot( # type: ignore[no-redef] +======= +def set_proxy_slot( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) obj: Union[PySymType, _AnyScriptObjectType, Tensor], tracer: _ProxyTracer, proxy: object, @@ -268,9 +284,13 @@ def set_proxy_slot( # type: ignore[no-redef] # We DO want to clobber proxies whenever we run an inplace operation # on a tensor, and it affects the metadata on the proxy. assert isinstance(proxy, _ProxyTensor) +<<<<<<< HEAD # see NOTE [Do not clobber inplace ops] if not _is_proxy_tensor_update_tensor_tracker_disabled(): tracer.tensor_tracker[obj] = proxy +======= + tracer.tensor_tracker[obj] = proxy +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif isinstance(obj, (_AnyScriptObject)): # We DO want to clobber proxies, with a similar rationale as for tensors. assert isinstance(proxy, Proxy) @@ -820,6 +840,7 @@ def _maybe_record_pointwise_barrier( last_node.meta["low_precision_pointwise_barrier"] = True +<<<<<<< HEAD def _fetch_proxies_and_all_constant_flag( flat_args_kwargs: Union[list[object], tuple[object, ...]], tracer: _ProxyTracer ) -> tuple[list[object], tuple[object, ...], bool]: @@ -863,6 +884,8 @@ def _fetch_proxies_and_all_constant_flag( return f_flat_args_kwargs, tuple(proxy_flat_args_kwargs), all_constant +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def proxy_call( proxy_mode: ProxyTorchDispatchMode, func: OpOverload, @@ -915,8 +938,32 @@ def can_handle_tensor(x: Tensor) -> bool: return (args[0] != 0).item() # type: ignore[attr-defined] tracer = proxy_mode.tracer +<<<<<<< HEAD f_flat_args_kwargs, proxy_flat_args_kwargs, all_constant = ( _fetch_proxies_and_all_constant_flag(flat_args_kwargs, tracer) +======= + f_flat_args_kwargs = [ + ( + fetch_object_proxy(tracer, x) + if isinstance(x, (Tensor, _AnyScriptObject)) + else x + ) + for x in flat_args_kwargs + ] + + # If there are SymInts, we also should not consider this constant. + # However, fake tensor handling of SymInts is sufficiently broken that + # I couldn't write a test for this case + all_constant = ( + not any( + t.constant is None + for t in f_flat_args_kwargs + if isinstance(t, _ProxyTensor) + ) + # TODO: maybe constant SymInts should also be allowed? Not sure if + # this can happen + and not any(isinstance(x, py_sym_types) for x in flat_args_kwargs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if torch.Tag.data_dependent_output in func.tags: @@ -944,6 +991,16 @@ def can_handle_tensor(x: Tensor) -> bool: "in your make_fx call." ) +<<<<<<< HEAD +======= + proxy_flat_args_kwargs = [ + e.proxy if isinstance(e, _ProxyTensor) else e for e in f_flat_args_kwargs + ] + proxy_flat_args_kwargs = [ + (fetch_sym_proxy(proxy_mode.tracer)(e) if isinstance(e, py_sym_types) else e) + for e in proxy_flat_args_kwargs + ] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) proxy_args, proxy_kwargs = pytree.tree_unflatten(proxy_flat_args_kwargs, spec) # When we trace through a torch.tensor invocation, you never actually @@ -1097,6 +1154,10 @@ class PythonKeyTracer(Tracer): tensor_tracker: MutableMapping[Tensor, _ProxyTensor] torch_fn_counts: dict[OpOverload, int] enable_thunkify: bool = False +<<<<<<< HEAD +======= + stack_trace: bool = False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __init__(self) -> None: super().__init__(autowrap_modules=()) # type: ignore[arg-type] @@ -1160,7 +1221,11 @@ def unwrap_proxy( def unwrap_proxy(self, e: T) -> object: if isinstance(e, Tensor): +<<<<<<< HEAD return get_proxy_slot(e, self, e, lambda x: x.proxy) # type: ignore[attr-defined] +======= + return get_proxy_slot(e, self, e, lambda x: x.proxy) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif isinstance(e, py_sym_types): return get_proxy_slot(e, self, e, lambda e: e.force()) elif isinstance(e, _AnyScriptObject): @@ -1179,8 +1244,43 @@ def create_node( ) -> torch.fx.Node: node = super().create_node(kind, target, args, kwargs, name, type_expr) # type: ignore[arg-type] +<<<<<<< HEAD if node.op in ["placeholder", "output"] and "stack_trace" in node.meta: del node.meta["stack_trace"] +======= + # stack_trace + if ( + self.stack_trace + and "stack_trace" not in node.meta + and node.op not in ["placeholder", "output"] + ): + user_frame_summary = CapturedTraceback.extract().summary() + if user_frame_summary: + # we retain frames from forward() calls, or ops + # located in torch/__init__.py (e.g. sym_int, sym_constrain_range, vmap) + stack_trace = [ + frame + for frame in user_frame_summary + if ( + frame.name == "forward" + or frame.filename.endswith("torch/__init__.py") + ) + ] + # filter out forward() frames from fx/_symbolic_trace.py, export/_trace.py + # this is hardcoded, but leads to a much cleaner stack trace + stack_trace = [ + frame + for frame in stack_trace + if not frame.filename.endswith( + ("fx/_symbolic_trace.py", "export/_trace.py") + ) + ] + if ( + stack_trace + ): # empty list for strict mode, dynamo should handle stack_trace + stack_trace = traceback.StackSummary.from_list(stack_trace) + node.meta["stack_trace"] = "".join(stack_trace.format()).strip() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if kind == "get_attr": assert isinstance(target, str) @@ -1435,7 +1535,11 @@ def __torch_function__( kwargs = kwargs or {} if func in _side_effectful_need_to_be_preserved_pre_dispatch: # It's for passing the export verifier which needs to verify the meta['val'] +<<<<<<< HEAD # TODO(tmanlaibaatar): we should systematically couple it with export verifier, +======= + # TODO(tmanlaibaatar): we should systematically couple it with expoert verifier, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # instead of hardcoding it here. # T203648563 if func == torch.amp.autocast_mode._exit_autocast: @@ -1450,6 +1554,7 @@ def __torch_function__( torch.amp.autocast_mode._exit_autocast, ]: node.meta["val"] = None +<<<<<<< HEAD # For autocast, the python APIs run so we don't have to run them again # here. if func is torch._C._set_grad_enabled: @@ -1476,6 +1581,11 @@ def __torch_function__( res = func(*args, **kwargs) track_tensor_tree(res, out_proxy, constant=None, tracer=self.tracer) return res +======= + return node + # Don't actually run the function! We just want to trace the calls + # into a graph. We don't actualy want to change global autograd state. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return func(*args, **kwargs) @@ -1770,8 +1880,12 @@ class _ModuleStackTracer(PythonKeyTracer): def __init__(self, scope_root: GraphModule) -> None: super().__init__() +<<<<<<< HEAD self.record_stack_traces = True self._record_forward_stack_traces_only = True +======= + self.stack_trace = True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.scope_root = scope_root self.enable_attr_proxy = False self.submodule_paths = {} @@ -1907,6 +2021,7 @@ def trace( # type: ignore[override] ) -> fx.Graph: res = super().trace(root, concrete_args) +<<<<<<< HEAD # NOTE [export non-strict fake tensor leak detection] # In non-strict export, we don't have dynamo's side effect # tracking logic which makes some cases hard to detect. @@ -1926,6 +2041,8 @@ def trace( # type: ignore[override] for key, val in self.tensor_tracker.items(): _FAKE_TENSOR_ID_TO_PROXY_MAP_FOR_EXPORT[id(key)] = val.proxy.node +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Since we are making _AttrProxy mimic the original # submodule, when someone registers a module directly # to the tracer while tracing, the proxy object gets registered @@ -2021,7 +2138,11 @@ def create_node(self, *args: object, **kwargs: object) -> fx.node.Node: # nn_module_stack if node.op not in ["placeholder", "output"]: if "nn_module_stack" not in node.meta: +<<<<<<< HEAD node.meta["nn_module_stack"] = self.module_stack.copy() +======= + node.meta["nn_module_stack"] = self.module_stack +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # convert nn_module_stack from Dict[key, (FQN, class)] -> Dict[str, Tuple[str, str]] for key, (fqn, mod_cls) in node.meta["nn_module_stack"].items(): if isinstance(mod_cls, type): @@ -2054,8 +2175,12 @@ def __init__( record_module_stack: bool, _allow_fake_constant: bool, _error_on_data_dependent_ops: bool, +<<<<<<< HEAD record_stack_traces: bool = False, parent_tracer: Optional[_MakefxTracer] = None, +======= + stack_trace: bool = False, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: # Configurations that are used to initialize the context managers and their states. # Should not modify them during tracing. @@ -2086,8 +2211,12 @@ def __init__( self.torch_fn_metadata_mode: Union[nullcontext, TorchFunctionMetadataMode] = ( nullcontext() ) +<<<<<<< HEAD self.record_stack_traces = record_stack_traces self.parent_tracer: Optional[_MakefxTracer] = parent_tracer +======= + self.stack_trace = stack_trace +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _checkpoint_modes(self) -> list[Any]: return [ @@ -2127,6 +2256,7 @@ def _init_modes_from_inputs( if hasattr(f, "_orig_mod") and self.record_module_stack: scope_root = f._orig_mod # _ModuleStackTracer always try to preserve stack trace +<<<<<<< HEAD # in forward functions self.fx_tracer = _ModuleStackTracer(scope_root) else: @@ -2134,6 +2264,12 @@ def _init_modes_from_inputs( self.fx_tracer.record_stack_traces = self.record_stack_traces if self.record_stack_traces: self.fx_tracer._record_forward_stack_traces_only = True +======= + self.fx_tracer = _ModuleStackTracer(scope_root) + else: + self.fx_tracer = PythonKeyTracer() + self.fx_tracer.stack_trace = self.stack_trace +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.tracing_mode == "fake": import torch._dynamo @@ -2336,6 +2472,7 @@ def _wrap_func(f: Callable[_P, R], phs: Sequence[PHBase]) -> Callable[_P, R]: ) raise +<<<<<<< HEAD if ( self.is_hop_subgraph_tracer() and (fake_mode := torch._guards.detect_fake_mode(args)) @@ -2345,6 +2482,8 @@ def _wrap_func(f: Callable[_P, R], phs: Sequence[PHBase]) -> Callable[_P, R]: insert_deferred_runtime_asserts(t, fake_mode.shape_env, "reenter_make_fx") t.recompile() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO: kind of a bad way to do it, should maybe figure out a better way if self.tracing_mode == "symbolic": assert self.fake_tensor_mode is not None @@ -2355,9 +2494,12 @@ def trace(self, f: Callable, *args: object) -> fx.GraphModule: with self._init_modes_from_inputs(f, args): return self._trace_inner(f, *args) +<<<<<<< HEAD def is_hop_subgraph_tracer(self) -> bool: return self.parent_tracer is not None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def trace_subgraph(self, f: Callable, *args: object) -> GraphModule: # Create a new tracer based on parent's config sub_tracer = _MakefxTracer( @@ -2368,7 +2510,10 @@ def trace_subgraph(self, f: Callable, *args: object) -> GraphModule: self.record_module_stack, self._allow_fake_constant, self._error_on_data_dependent_ops, +<<<<<<< HEAD parent_tracer=self, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) with sub_tracer._init_modes_from_parent(self): return sub_tracer._trace_inner(f, *args) @@ -2398,14 +2543,22 @@ def make_fx( record_module_stack: bool = False, _allow_fake_constant: bool = False, _error_on_data_dependent_ops: bool = True, +<<<<<<< HEAD record_stack_traces: bool = False, +======= + stack_trace: bool = False, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Callable[..., GraphModule]: """ Given a function f, return a new function which when executed with valid arguments to f, returns an FX GraphModule representing the set of operations that were executed during the course of execution. +<<<<<<< HEAD If record_stack_traces is True, the stack trace will be preserved on node.meta["stack_trace"] +======= + If stack_trace is True, the stack_trace will be preserved on node.meta["stack_trace"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ assert tracing_mode in ["real", "fake", "symbolic"] @@ -2420,8 +2573,12 @@ def make_fx( record_module_stack, _allow_fake_constant, _error_on_data_dependent_ops, +<<<<<<< HEAD record_stack_traces=record_stack_traces or config.trace.provenance_tracking_level == 1, +======= + stack_trace=stack_trace or config.trace.enabled, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @functools.wraps(f) diff --git a/torch/fx/experimental/shape_inference/infer_symbol_values.py b/torch/fx/experimental/shape_inference/infer_symbol_values.py index 1ee6d2a939ae2..4413808cee8f1 100644 --- a/torch/fx/experimental/shape_inference/infer_symbol_values.py +++ b/torch/fx/experimental/shape_inference/infer_symbol_values.py @@ -64,7 +64,11 @@ def infer_symbol_values( for right_var in right_vars: if sp.sympify(right_var) == sp.sympify("s0"): right_equation = sp.cancel(right_equation / right_var) +<<<<<<< HEAD right_vars.remove(right_var) # noqa: B909 +======= + right_vars.remove(right_var) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) var = right_vars[0] idx = symbol_idx_dict[str(var)] diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 00f1fed899acf..2c280192934b9 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -56,7 +56,10 @@ # NB: The sym_* functions are used via getattr() and must be imported here. from torch import SymBool, SymFloat, SymInt +<<<<<<< HEAD from torch._C._functorch import get_unwrapped, is_batchedtensor +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._guards import ShapeGuard, SLoc, Source, TracingContext from torch._logging import dtrace_structured, LazyString, structured, trace_structured from torch._subclasses.meta_utils import is_sparse_any @@ -213,7 +216,11 @@ def log_lru_cache_stats(wrapped_f: functools._lru_cache_wrapper[object]) -> None class SymIntEqByExpr: """ This is a wrapper around SymInt which has alternative semantics for +<<<<<<< HEAD equality and pickling. Specifically, instead of erroring or guarding, we +======= + equality. Specifically, instead of erroring or guarding, we +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) instead will hash/compare equality based on the underlying sympy expression; e.g., s0 and s1 will always compare as False. @@ -222,6 +229,7 @@ class SymIntEqByExpr: canonicalize to the same expression via regular simplification. """ +<<<<<<< HEAD @staticmethod def _extract(val: Union[torch.SymInt, int]) -> sympy.Expr: if isinstance(val, torch.SymInt): @@ -231,16 +239,42 @@ def _extract(val: Union[torch.SymInt, int]) -> sympy.Expr: def __init__(self, val: Union[torch.SymInt, int]) -> None: self.val: sympy.Expr = SymIntEqByExpr._extract(val) +======= + val: Union[torch.SymInt, int] + + def __init__(self, val: Union[torch.SymInt, int]) -> None: + self.val = val +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __repr__(self) -> str: return repr(self.val) +<<<<<<< HEAD def __eq__(self, other: object) -> bool: assert isinstance(other, SymIntEqByExpr) return self.val == other.val def __hash__(self) -> int: return hash(self.val) +======= + def _extract(self) -> sympy.Expr: + if isinstance(self.val, torch.SymInt): + return self.val.node.expr + else: + return sympy.Integer(self.val) + + def __eq__(self, other: object) -> bool: + assert isinstance(other, SymIntEqByExpr) + + # int equality fastpath + if type(self.val) is int and type(other.val) is int: + return self.val == other.val + + return self._extract() == other._extract() + + def __hash__(self) -> int: + return hash(self._extract()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _nested_int_aware_sort( @@ -306,6 +340,7 @@ def uninteresting_files() -> set[str]: import torch._logging import torch._subclasses.fake_tensor import torch._subclasses.meta_utils +<<<<<<< HEAD import torch.export._trace mods = [ @@ -315,6 +350,14 @@ def uninteresting_files() -> set[str]: torch.fx.experimental.sym_node, torch.fx.interpreter, torch.fx._symbolic_trace, +======= + + mods = [ + sys.modules[__name__], + torch.fx.experimental.recording, + torch.fx.experimental.sym_node, + torch.fx.interpreter, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch, torch._compile, torch._dynamo.eval_frame, @@ -1141,10 +1184,14 @@ def expr(s: Union[SymInt, SymFloat, SymBool]) -> sympy.Expr: for attr in attrs: sub = getattr(a, attr) r.update(go(sub, path + (InnerTensorKey(attr),))) +<<<<<<< HEAD elif isinstance(a, torch.Tensor) and is_batchedtensor(a): unwrapped_tensor = get_unwrapped(a) r.update(go(unwrapped_tensor, path)) elif isinstance(a, torch.Tensor) and not is_batchedtensor(a): +======= + elif isinstance(a, torch.Tensor): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._subclasses.fake_tensor import FakeTensor assert isinstance(a, FakeTensor) @@ -1280,7 +1327,10 @@ def compute_unbacked_bindings( return None fs = shape_env.pending_fresh_unbacked_symbols +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pending = set(fs) if not pending: return None @@ -1329,6 +1379,7 @@ def compute_unbacked_bindings( isinstance(old_sym, SymTypes) and (old_s := old_sym.node.expr) != new_s ): +<<<<<<< HEAD # If old_s is not an unbacked_symbol, # we assume that the original unbacked symbol is replaced # by a backed symbol (old_s). This can happen @@ -1337,6 +1388,9 @@ def compute_unbacked_bindings( # When this happens we just replace new_s by the old_s # because we know the value is the same. if isinstance(old_s, sympy.Symbol) and free_unbacked_symbols(old_s): +======= + if isinstance(old_s, sympy.Symbol): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) shape_env._rename_unbacked_to(new_s, old_s) else: shape_env._eliminate_unbacked(new_s, old_s) @@ -1462,6 +1516,10 @@ def statically_known_true(x: BoolLikeType) -> bool: if not isinstance(x, SymBool): assert isinstance(x, bool) return x +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) result = _static_eval_sym_bool(x) if result is None: return False @@ -2373,7 +2431,11 @@ def _maybe_evaluate_static_worker( # Note: # Offset might be a fraction(e.g. aten.split.Tensor), but shapes are always integers. +<<<<<<< HEAD # Sympy might give unexpected results when comparing an integer with a non-integer +======= + # Sympy might give unexepected results when comparing an integer with a non-integer +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Therefore, we cast offset to int here. # For example: # shape_0 = sympy.Symbol("shape_0", positive=True, integer=True) @@ -3381,7 +3443,11 @@ def prettify_results( constraint_violation_error: object, forced_specializations: dict[str, str], ) -> str: +<<<<<<< HEAD """Format a message for constraint violation errors""" +======= + """Format a message for constraint violation erros""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.export.dynamic_shapes import _get_dim_name_mapping if not self._dcp.source_name_to_debug_name: @@ -3530,6 +3596,10 @@ class ShapeEnvSettings: specialize_zero_one: bool duck_shape: bool prefer_deferred_runtime_asserts_over_guards: bool +<<<<<<< HEAD +======= + allow_complex_guards_as_runtime_asserts: bool +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) trace_asserts: bool @@ -3667,6 +3737,13 @@ def _init( # in guards is helpful, since these guards in some sense are overly # pedantic. See also https://github.com/pytorch/pytorch/issues/121749 prefer_deferred_runtime_asserts_over_guards: bool = False, +<<<<<<< HEAD +======= + # When True, does not emit or raise constraint violation errors on + # implicit guards generated by ops, and defers to runtime assertions + # in the graph instead. For export. + allow_complex_guards_as_runtime_asserts: bool = False, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # XXX Add any new settings that could affect FakeTensor evaluation # to: torch._subclasses.fake_tensor._ShapeEnvSettings trace_asserts: bool = False, @@ -3683,6 +3760,10 @@ def _init( specialize_zero_one=specialize_zero_one, duck_shape=duck_shape, prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, +<<<<<<< HEAD +======= + allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) trace_asserts=trace_asserts, ) @@ -3841,7 +3922,12 @@ def _init( # with something like effect token tracking. self.unbacked_alloc_order: dict[sympy.Symbol, int] = {} +<<<<<<< HEAD self.specialization_stacks: dict[Source, traceback.StackSummary] = {} +======= + self.user_specialization_stacks: dict[Source, traceback.StackSummary] = {} + self.framework_specialization_stacks: dict[Source, traceback.StackSummary] = {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.trace_asserts = trace_asserts @@ -3894,6 +3980,13 @@ def duck_shape(self) -> bool: def prefer_deferred_runtime_asserts_over_guards(self) -> bool: return self.settings.prefer_deferred_runtime_asserts_over_guards +<<<<<<< HEAD +======= + @property + def allow_complex_guards_as_runtime_asserts(self) -> bool: + return self.settings.allow_complex_guards_as_runtime_asserts + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @contextmanager def patch_source_specialization( self, source: Source, check_fn: Callable[[sympy.Symbol], sympy.Expr] @@ -3928,7 +4021,11 @@ def patch_source_specialization( added_replacements[axiom.lhs] = axiom.rhs self.axioms.update(new_axioms) +<<<<<<< HEAD # We need to freeze the ShapeEnv because any additional modification of +======= + # We need to freeze the ShapeEnv becuase any additional modification of +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # the ShapeEnv will cause unsoundness for subsequent specialization calls. self.frozen = True try: @@ -3968,7 +4065,12 @@ def check_equal(self, other: ShapeEnv) -> None: "replacements_slocs", "_resimplify_floor_div_axioms", "_expr_sym_node_id", +<<<<<<< HEAD "specialization_stacks", +======= + "user_specialization_stacks", + "framework_specialization_stacks", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # Mapping of the value of each to-be-compared field into the values that @@ -4354,23 +4456,33 @@ def _produce_dyn_sizes_from_int_tuple( tensor_size: Sequence[IntLikeType], source: Source, symbolic_context: SymbolicContext, +<<<<<<< HEAD hint_overrides: Optional[dict[int, int]] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> list[sympy.Expr]: assert all(not is_symbolic(val) for val in tensor_size), ( f"Expect size to be a plain tuple of ints but got {tensor_size}" ) from torch._dynamo.source import TensorProperty, TensorPropertySource +<<<<<<< HEAD if not hint_overrides: hint_overrides = {} +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _assert_symbol_context(symbolic_context) dynamic_dims = symbolic_context.dynamic_sizes # type: ignore[attr-defined] constraint_dims = symbolic_context.constraint_sizes # type: ignore[attr-defined] size = [] for i, val in enumerate(tensor_size): sym = self.create_symbol( +<<<<<<< HEAD val if i not in hint_overrides else hint_overrides[i], +======= + val, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TensorPropertySource(source, TensorProperty.SIZE, i), dynamic_dims[i], constraint_dims[i], @@ -4466,7 +4578,11 @@ def create_symbolic_sizes_strides_storage_offset( # The order of checking the guards matters. In this specific example: # If True branch guard check precedes False branch and for True branch, y.size(0) check precedes x == True, +<<<<<<< HEAD # we may have an unnecessary shape speciliazation for y. +======= + # we may have an unnessary shape speciliazation for y. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _maybe_specialize_sym_int_with_hint( self, maybe_sym: IntLikeType ) -> IntLikeType: @@ -4490,7 +4606,10 @@ def _create_symbolic_sizes_strides_storage_offset( source: Source, *, symbolic_context: Optional[SymbolicContext] = None, +<<<<<<< HEAD hint_overrides: Optional[dict[int, int]] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> tuple[ tuple[IntLikeType, ...], tuple[IntLikeType, ...], @@ -4498,9 +4617,12 @@ def _create_symbolic_sizes_strides_storage_offset( ]: dim = len(ex_size) +<<<<<<< HEAD if not hint_overrides: hint_overrides = {} +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Reimplement the legacy behavior if symbolic_context is None: constraint_sizes: list[DimConstraint] = [None] * dim @@ -4555,7 +4677,11 @@ def _create_symbolic_sizes_strides_storage_offset( from torch._dynamo.source import TensorProperty, TensorPropertySource size: list[sympy.Expr] = self._produce_dyn_sizes_from_int_tuple( +<<<<<<< HEAD ex_size, source, symbolic_context, hint_overrides=hint_overrides +======= + ex_size, source, symbolic_context +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) stride = self._compute_symbolic_stride( source, @@ -4571,7 +4697,11 @@ def _create_symbolic_sizes_strides_storage_offset( sym_sizes = [ self.create_symintnode( sym, +<<<<<<< HEAD hint=hint if i not in hint_overrides else hint_overrides[i], +======= + hint=hint, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) source=TensorPropertySource(source, TensorProperty.SIZE, i), ) for i, (sym, hint) in enumerate(zip(size, ex_size)) @@ -4708,7 +4838,11 @@ def create_symfloatnode( self, sym: sympy.Expr, *, +<<<<<<< HEAD hint: Optional[int | float | bool], +======= + hint: Optional[int], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) source: Optional[Source] = None, ) -> FloatLikeType: """Create a SymFloat value from a symbolic expression""" @@ -5246,7 +5380,11 @@ def produce_guards_verbose( # calls on this new instance. Finally, it will check whether this new instance # has equal state. # +<<<<<<< HEAD # It's important that we do it in the beginning of this function, since it modifies +======= + # It's important that we do it in the begining of this function, since it modifies +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # self.dim_constraints through its execution. Changes that happen in this method # aren't interesting, since this is the function call we wish to reproduce at the # end. If we wish to simply reproduce ShapeEnv instances even after this call, @@ -5531,13 +5669,28 @@ def hint(s: sympy.Expr) -> str: var_with_range = self._render_range_for_constraint_violation( source, constraint ) +<<<<<<< HEAD user_stack = self.specialization_stacks.get(source, None) +======= + user_stack = self.user_specialization_stacks.get(source, None) + framework_stack = self.framework_specialization_stacks.get( + source, None + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) msg = ( f"You marked {self._debug_name(source)} as dynamic but your code " f"specialized it to be a constant ({val}). If you're using mark_dynamic, " f"either remove it or use maybe_mark_dynamic. If you're using Dim.DYNAMIC, " f"replace it with either Dim.STATIC or Dim.AUTO." + ( +<<<<<<< HEAD +======= + "\n\nFramework stack:\n" + "".join(framework_stack.format()) + if framework_stack + else "" + ) + + ( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "\n\nUser stack:\n" + "".join(user_stack.format()) if user_stack else "" @@ -6243,7 +6396,11 @@ def _maybe_evaluate_static( Use compute_hint == True if you are trying to compute a non-binding hint for the particular hint values of backed and unbacked SymInts, +<<<<<<< HEAD e.g., if s0 happens to be 3 this run, compute_hint will substitute s0 with 3. +======= + e.g., if s0 happens to be 3 this run, compute_hint will subsitute s0 with 3. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ # axioms with compute hint NYE @@ -6260,11 +6417,19 @@ def resimplify_floor_div(axioms: dict[sympy.Expr, sympy.Expr]) -> None: return self._resimplify_floor_div_axioms = False new_items = {} +<<<<<<< HEAD for k, v in list(axioms.items()): # A FloorDiv in implications could have became CleanDiv at this point, due to new facts # to the shapeEnv. This handles such issue but its not ideal. This is the only expression # simplification that depends on the global state of shape env. # TODO try to get rid of CleanDiv since it breaks the invariant that's simplifications of sympy +======= + for k, v in axioms.items(): + # A FloorDiv in implications could have became CleanDiv at this point, due to new facts + # to the shapeEnv. This handles such issue but its not ideal. This is the only expression + # simplification that depends on the global state of shape env. + # TODO try to get rid of CleanDiv since it breaks the invariant thats simplifications of sympy +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # expressions only depend on the expression itself. if k.has(FloorDiv): new_items.update({self.simplify(k): v}) @@ -6346,6 +6511,7 @@ def simplify(self, expr: _SympyT, size_oblivious: bool = False) -> _SympyT: expr = safe_expand(expr) expr = self.replace(expr) +<<<<<<< HEAD # Simplify max(0/1, x) to x when x >= 0/1. max(1, x) is a commonly introduced # expression when creating contiguous strides. if not size_oblivious: @@ -6364,6 +6530,8 @@ def simplify(self, expr: _SympyT, size_oblivious: bool = False) -> _SympyT: if min_max_replacements: expr = expr.xreplace(min_max_replacements) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if size_oblivious and (expr.has(Max) or expr.has(Min)): # type: ignore[has-type] min_max_replacements = {} for atom in (*expr.atoms(Max), *expr.atoms(Min)): # type: ignore[has-type] @@ -6522,6 +6690,10 @@ def _make_data_dependent_error( expr: sympy.Basic, unhinted_expr: sympy.Basic, *, +<<<<<<< HEAD +======= + size_oblivious_result: Optional[sympy.Basic] = None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) expr_sym_node_id: Optional[int] = None, ) -> GuardOnDataDependentSymNode: # TODO: in a Dynamo context, having user code, and having the @@ -6535,6 +6707,14 @@ def _make_data_dependent_error( if s in self.size_like: size_like_symbols.append(s) size_oblivious_result_msg = "" +<<<<<<< HEAD +======= + if size_oblivious_result is not None: + size_oblivious_result_msg = ( + f"ATTENTION: guard_size_oblivious would fix the error, evaluating expression to {size_oblivious_result}.\n" + "Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.\n\n" + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sloc, maybe_extra_debug = self._get_stack_summary(True) if expr.is_integer: # type: ignore[attr-defined] desc = ( @@ -6542,11 +6722,14 @@ def _make_data_dependent_error( ) else: desc = "Could not guard on data-dependent expression" +<<<<<<< HEAD size_oblivious_result_msg = ( "consider using data-dependent friendly APIs such as " "guard_or_false, guard_or_true and statically_known_true" ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) msg = ( f"{desc} {expr} (unhinted: {unhinted_expr}). " f"(Size-like symbols: {', '.join(map(str, size_like_symbols)) or 'none'})\n\n" @@ -6642,7 +6825,11 @@ def _set_replacement(self, a: sympy.Symbol, tgt: sympy.Expr, msg: str) -> None: assert isinstance(a, sympy.Symbol) if ( +<<<<<<< HEAD self.prefer_deferred_runtime_asserts_over_guards +======= + self.allow_complex_guards_as_runtime_asserts +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) and not _is_supported_equivalence(tgt) ): return # continuing leads to placeholder shapes having complex expressions that we can't resolve @@ -6776,7 +6963,14 @@ def _set_replacement(self, a: sympy.Symbol, tgt: sympy.Expr, msg: str) -> None: for source in self.var_to_sources.get(a, []): if user_tb: +<<<<<<< HEAD self.specialization_stacks[source] = user_tb +======= + self.user_specialization_stacks[source] = user_tb + self.framework_specialization_stacks[source] = ( + CapturedTraceback.extract(cpp=True) + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if config.print_specializations: self.log.warning( @@ -7571,9 +7765,22 @@ def compute_concrete_val() -> sympy.Basic: ok = True if not ok: +<<<<<<< HEAD + raise self._make_data_dependent_error( + expr.xreplace(self.var_to_val), + expr, +======= + size_oblivious_result = None + # compute size_oblivious_result to suggest it as a fix for the user if it works. + if not size_oblivious: + size_oblivious_result = self._maybe_evaluate_static( + expr, size_oblivious=True + ) raise self._make_data_dependent_error( expr.xreplace(self.var_to_val), expr, + size_oblivious_result=size_oblivious_result, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) expr_sym_node_id=self._expr_sym_node_id, ) else: @@ -7617,6 +7824,7 @@ def compute_concrete_val() -> sympy.Basic: # is no longer necessary) self._maybe_guard_rel(g) +<<<<<<< HEAD if ( torch.compiler.is_exporting() and self.prefer_deferred_runtime_asserts_over_guards @@ -7626,6 +7834,9 @@ def compute_concrete_val() -> sympy.Basic: # and so the result here will be statically known self.guard_or_defer_runtime_assert(g, f"evaluate_expr: {orig_expr}") else: +======= + if not self.allow_complex_guards_as_runtime_asserts: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # at this point, we've evaluated the concrete expr value, and have # flipped/negated the guard if necessary. Now we know what to guard # or defer to runtime assert on. @@ -7634,6 +7845,14 @@ def compute_concrete_val() -> sympy.Basic: ) self.guards.append(guard) self.axioms.update(dict(self.get_implications(self.simplify(g)))) +<<<<<<< HEAD +======= + else: + # it's fine to defer simple guards here without checking, + # the _maybe_guard_rel() call above will set replacements if possible, + # and so the result here will be statically known + self.guard_or_defer_runtime_assert(g, f"evaluate_expr: {orig_expr}") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: self._log_guard("eval [guard suppressed]", g, forcing_spec=forcing_spec) @@ -7862,9 +8081,13 @@ def run_node(self, n: torch.fx.Node) -> Result: from torch._guards import detect_fake_mode result = super().run_node(n) +<<<<<<< HEAD fake_mode = detect_fake_mode() assert fake_mode is not None rebind_unbacked(fake_mode.shape_env, n, result) +======= + rebind_unbacked(detect_fake_mode().shape_env, n, result) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return result diff --git a/torch/fx/experimental/unification/multipledispatch/utils.py b/torch/fx/experimental/unification/multipledispatch/utils.py index 0b21183c40b97..a1b30a63818de 100644 --- a/torch/fx/experimental/unification/multipledispatch/utils.py +++ b/torch/fx/experimental/unification/multipledispatch/utils.py @@ -5,9 +5,15 @@ __all__ = ["raises", "expand_tuples", "reverse_dict", "groupby", "typename"] +<<<<<<< HEAD def raises(err, lamda): # codespell:ignore lamda try: lamda() # codespell:ignore lamda +======= +def raises(err, lamda): + try: + lamda() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return False except err: return True diff --git a/torch/fx/experimental/unification/utils.py b/torch/fx/experimental/unification/utils.py index a8035f75d3027..d7d3ea2d0ba06 100644 --- a/torch/fx/experimental/unification/utils.py +++ b/torch/fx/experimental/unification/utils.py @@ -23,9 +23,15 @@ def transitive_get(key, d): return key +<<<<<<< HEAD def raises(err, lamda): # codespell:ignore lamda try: lamda() # codespell:ignore lamda +======= +def raises(err, lamda): + try: + lamda() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return False except err: return True diff --git a/torch/fx/experimental/validator.py b/torch/fx/experimental/validator.py index db00952512067..85bfc94e81940 100644 --- a/torch/fx/experimental/validator.py +++ b/torch/fx/experimental/validator.py @@ -651,7 +651,11 @@ def _validate(self) -> None: def translation_validation_enabled() -> bool: +<<<<<<< HEAD # Checks every time this function is called, in case the Dynamo +======= + # Checks everytime this function is called, in case the Dynamo +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # option is set, but Z3 is not installed. _assert_z3_installed_if_tv_set() return _HAS_Z3 and config.translation_validation diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 514490513cbf5..6bd067b168f2c 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -324,6 +324,7 @@ def __init__(self): self._body_transformer: Optional[TransformCodeFunc] = None self._func_name: str = "forward" +<<<<<<< HEAD def _format_multiline_args(self, args: list[str]) -> str: """Helper to format function arguments in expanded multiline format.""" return "".join(self._format_single_arg(arg) for arg in args) @@ -373,6 +374,9 @@ def gen_fn_def( *, expanded_def: bool = False, ) -> str: +======= + def gen_fn_def(self, free_vars: list[str], maybe_return_annotation: str) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Given the free variables and a return annotation, generates the beginning of the FX function. By default, `gen_fn_def(['a', 'b'], '') == 'def {self._func_name}(a, b):'` @@ -381,6 +385,7 @@ def gen_fn_def( # would have added it. if len(free_vars) == 0 or free_vars[0] != "self": free_vars.insert(0, "self") +<<<<<<< HEAD if expanded_def: args_formatted = self._format_multiline_args(free_vars) @@ -393,14 +398,25 @@ def gen_fn_def( def generate_output( self, output_args: Argument, *, descs: Optional[Any] = None ) -> str: +======= + return ( + f"def {self._func_name}({', '.join(free_vars)}){maybe_return_annotation}:" + ) + + def generate_output(self, output_args: Argument) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Given the output arguments, generates the return statement of the FX function. Note: The returned statement should not be indented. """ +<<<<<<< HEAD if descs is not None and isinstance(output_args, (list, tuple)): return self._format_multiline_container(output_args, descs, "return ") else: return f"return {repr(output_args)}" +======= + return f"return {repr(output_args)}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def process_inputs(self, *args: Any) -> Any: """ @@ -438,8 +454,11 @@ def _gen_python_code( include_stride: bool = False, include_device: bool = False, colored: bool = False, +<<<<<<< HEAD # Render each argument on its own line expanded_def: bool = False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> PythonCode: free_vars: list[str] = [] body: list[str] = [] @@ -646,7 +665,10 @@ def emit_node(node: Node): maybe_type_annotation = ( "" if node.type is None else f" : {type_repr(node.type)}" ) +<<<<<<< HEAD maybe_comment = "" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if verbose: # override annotation with more detailed information @@ -678,6 +700,7 @@ def emit_node(node: Node): elif isinstance(meta_val, TensorMetadata): maybe_type_annotation = f': "{dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}"' +<<<<<<< HEAD desc = None if expanded_def: desc = node.meta.get("desc", None) @@ -685,13 +708,19 @@ def emit_node(node: Node): maybe_comment += f" # {desc}" # output is handled specially +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if node.op == "placeholder": assert isinstance(node.target, str) maybe_default_arg = ( "" if not node.args else f" = {_get_repr(node.args[0])}" ) free_vars.append( +<<<<<<< HEAD f"{node.target}{maybe_type_annotation}{maybe_default_arg}{maybe_comment}" +======= + f"{node.target}{maybe_type_annotation}{maybe_default_arg}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) raw_name = node.target.replace("*", "") if raw_name != repr(node): @@ -767,6 +796,7 @@ def emit_node(node: Node): elif node.op == "output": if node.type is not None: maybe_return_annotation[0] = f" -> {type_repr(node.type)}" +<<<<<<< HEAD body.append( self._call_method_with_signature_check( self.generate_output, @@ -774,6 +804,9 @@ def emit_node(node: Node): descs=desc if expanded_def else None, ) ) +======= + body.append(self.generate_output(node.args[0])) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return raise NotImplementedError(f"node: {node.op} {node.target}") @@ -807,12 +840,16 @@ def emit_node(node: Node): for name, value in self.additional_globals(): add_global(name, value) +<<<<<<< HEAD prologue = self._call_method_with_signature_check( self.gen_fn_def, free_vars, maybe_return_annotation[0], expanded_def=expanded_def, ) +======= + prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # remove counter and generate lineno to node index mapping lineno_map: dict[int, Optional[int]] = {} @@ -861,6 +898,7 @@ def process_outputs(self, out: Any) -> Any: assert self.pytree_info.out_spec is not None return pytree.tree_unflatten(out, self.pytree_info.out_spec) +<<<<<<< HEAD def _format_annotations(self, free_vars: list[str], expanded_def: bool) -> str: """Helper to format annotations for variables in pytree codegen.""" if not free_vars: @@ -878,6 +916,9 @@ def _format_annotations(self, free_vars: list[str], expanded_def: bool) -> str: def gen_fn_def( self, free_vars, maybe_return_annotation, *, expanded_def: bool = False ): +======= + def gen_fn_def(self, free_vars, maybe_return_annotation): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Given a user function/model: # myargs = [myargs0, myargs1] # mykwargs = {'mykwargs0': ..., 'mykwargs1': ...} @@ -894,17 +935,25 @@ def gen_fn_def( # If the user function/model does not have keywords, the dict is suppressed from tree_flatten_spec # e.g. tree_flatten_spec([mypos, myargs0, myargs1]), self._in_spec) if self.pytree_info is None: +<<<<<<< HEAD return super().gen_fn_def( free_vars, maybe_return_annotation, expanded_def=expanded_def ) +======= + return super().gen_fn_def(free_vars, maybe_return_annotation) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fn_args = self.pytree_info.orig_args has_orig_self = (fn_args[0] == "self") if len(fn_args) > 0 else False if has_orig_self: free_vars.insert(0, "self") +<<<<<<< HEAD fn_definition = super().gen_fn_def( fn_args[:], maybe_return_annotation, expanded_def=expanded_def ) +======= + fn_definition = super().gen_fn_def(fn_args[:], maybe_return_annotation) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if len(free_vars) > 0: # pytree has placeholders in it # when kwargs is present, in_spec is tuple(args, kwargs) @@ -936,12 +985,20 @@ def gen_fn_def( # we need to split it to two lines: # one for annotation: `var1: annotation1; var2: annotation2;` (note the semicolon) # one for code: `var1, var2, = function_call()` +<<<<<<< HEAD without_annotation = [x.split(":")[0].split("#")[0] for x in free_vars] fn_definition += self._format_annotations(free_vars, expanded_def) +======= + without_annotation = [x.split(":")[0] for x in free_vars] + has_annotation = [x + "; " for x in free_vars if ":" in x] + if len(has_annotation) > 0: + fn_definition += "\n " + "".join(has_annotation) + "\n" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fn_definition += f""" {", ".join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})""" return fn_definition +<<<<<<< HEAD def generate_output(self, output_args, *, descs: Optional[Any] = None): if self.pytree_info and self.pytree_info.out_spec: if descs is not None and isinstance(output_args, (list, tuple)): @@ -957,6 +1014,13 @@ def generate_output(self, output_args, *, descs: Optional[Any] = None): ) else: return super().generate_output(output_args, descs=descs) +======= + def generate_output(self, output_args): + if self.pytree_info and self.pytree_info.out_spec: + return f"return pytree.tree_unflatten({repr(output_args)}, self._out_spec)" + else: + return super().generate_output(output_args) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class _FindNodesLookupTable: @@ -1112,7 +1176,11 @@ def find_nodes( Returns: +<<<<<<< HEAD Iterable of nodes with the requested op and target. +======= + Iteratable of nodes with the requested op and target. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ node_list = self._find_nodes_lookup_table.find_nodes(op=op, target=target) if sort: @@ -1621,7 +1689,11 @@ def output(self, result: "Argument", type_expr: Optional[Any] = None): op="output", target="output", args=(result,), type_expr=type_expr ) +<<<<<<< HEAD def _target_to_str(self, target: Optional[Target]) -> str: +======= + def _target_to_str(self, target: Target) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if callable(target): op = target.__name__ else: @@ -1641,7 +1713,10 @@ def python_code( include_stride: bool = False, include_device: bool = False, colored: bool = False, +<<<<<<< HEAD expanded_def: bool = False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> PythonCode: """ Turn this ``Graph`` into valid Python code. @@ -1673,7 +1748,11 @@ def python_code( # To do this, we create a new namespace just for this source. All names # that get printed must come from this namespace. # +<<<<<<< HEAD # Why can't we reuse node.name? Because it was generated within the +======= + # Why can't we re-use node.name? Because it was generated within the +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # namespace `self._graph_namespace`. In order to provide uniqueness # over both locals (node.name) *and* globals, we create a completely # new namespace to put all identifiers in. @@ -1681,7 +1760,11 @@ def python_code( # Override Node's repr to generate a valid name within our namespace. # Since repr() is designed to produce a valid Python expression, it +<<<<<<< HEAD # makes sense to reuse it. This way, it's easy to print something like +======= + # makes sense to re-use it. This way, it's easy to print something like +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Tuple[Node, Node] by simply calling repr() on it. Node's __repr__ is # implemented cooperatively to allow this. def node_repr(n: Node): @@ -1708,7 +1791,10 @@ def override_node_repr(graph: Graph): include_stride=include_stride, include_device=include_device, colored=colored, +<<<<<<< HEAD expanded_def=expanded_def, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def _python_code( @@ -1720,7 +1806,10 @@ def _python_code( include_stride: bool = False, include_device: bool = False, colored: bool = False, +<<<<<<< HEAD expanded_def: bool = False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> PythonCode: return self._codegen._gen_python_code( self.nodes, @@ -1730,7 +1819,10 @@ def _python_code( include_stride=include_stride, include_device=include_device, colored=colored, +<<<<<<< HEAD expanded_def=expanded_def, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def __str__(self) -> str: diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 4c067c0e76e4c..8456f22e291f6 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -30,6 +30,10 @@ __all__ = [ "reduce_graph_module", "reduce_package_graph_module", +<<<<<<< HEAD +======= + "reduce_deploy_graph_module", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "GraphModule", ] @@ -146,6 +150,21 @@ def reduce_package_graph_module( return _deserialize_graph_module(forward, body) +<<<<<<< HEAD +======= +@compatibility(is_backward_compatible=True) +def reduce_deploy_graph_module( + importer: PackageImporter, body: dict[Any, Any], import_block: str +) -> torch.nn.Module: + ns = {} + ns["__builtins__"] = importer.patched_builtins + fn_src = body.get("_code") + assert fn_src is not None + forward = _forward_from_src(import_block + fn_src, ns) + return _deserialize_graph_module(forward, body) + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # We create a dummy class here because symbolic_trace pulls the forward() # function off of the class, rather than the instance. This class is used # in _deserialize_graph_module() below. @@ -309,7 +328,10 @@ def _print_readable( include_stride=False, include_device=False, colored=False, +<<<<<<< HEAD expanded_def=False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): graph = module.graph assert graph is not None and isinstance(graph, torch.fx.Graph), ( @@ -322,7 +344,10 @@ def _print_readable( include_stride=include_stride, include_device=include_device, colored=colored, +<<<<<<< HEAD expanded_def=expanded_def, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) module_code = verbose_python_code.src module_code = module_code.lstrip("\n") @@ -842,6 +867,17 @@ def call_wrapped(self, *args, **kwargs): # Passing Tracer as argument allows subclasses extending fx.GraphModule # define their own Tracer (extending fx.Tracer). +<<<<<<< HEAD +======= + def __reduce_deploy__(self, importer: Importer): + dict_without_graph = self.__dict__.copy() + dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__ + del dict_without_graph["_graph"] + + python_code = self.recompile() + import_block = _format_import_block(python_code.globals, importer) + return (reduce_deploy_graph_module, (dict_without_graph, import_block)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __reduce_package__(self, exporter: PackageExporter): dict_without_graph = self.__dict__.copy() @@ -937,7 +973,10 @@ def print_readable( # If `fast_sympy_print` is True then we use a sympy printer which is faster # but may result in less-readable output. fast_sympy_print: bool = False, +<<<<<<< HEAD expanded_def: bool = False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): """ Return the Python code generated for current GraphModule and its children GraphModules @@ -959,7 +998,10 @@ def fast_repr(expr: torch.types.PySymType) -> str: include_stride, include_device, colored, +<<<<<<< HEAD expanded_def, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return r @@ -978,7 +1020,11 @@ def _replicate_for_data_parallel(self): @contextlib.contextmanager def _set_replace_hook(self, f): """ +<<<<<<< HEAD Takes a callable which will be called every time when we replace a node +======= + Takes a callable which will be called everytime when we replace a node +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) to a new node, or change the node's name. Callable takes three arguments: the old node we're changing, and NAME of the new node, followed by the user node which consumes the old node to be replaced. @@ -992,7 +1038,11 @@ def _set_replace_hook(self, f): def _register_replace_node_hook(self, f): """ +<<<<<<< HEAD Takes a callable which will be called every time when we replace a node +======= + Takes a callable which will be called everytime when we replace a node +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) to a new node, or change the node's name. Callable takes three arguments: the old node we're changing, and NAME of the new node, followed by the user node which consumes the old node to be replaced. @@ -1002,7 +1052,11 @@ def _register_replace_node_hook(self, f): def _unregister_replace_node_hook(self, f): """ +<<<<<<< HEAD Takes a callable which was previously registered to be called every time when we replace a node. +======= + Takes a callable which was previously registered to be called everytime when we replace a node. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) This function will unregister that callable so it is no longer invoked on node replacement. """ assert callable(f), "create_node hook must be a callable." diff --git a/torch/fx/interpreter.py b/torch/fx/interpreter.py index a6cbe1cfe2c82..5b3e72ed8be73 100644 --- a/torch/fx/interpreter.py +++ b/torch/fx/interpreter.py @@ -5,7 +5,10 @@ import torch import torch.fx.traceback as fx_traceback +<<<<<<< HEAD from torch._logging import trace_structured +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.hub import tqdm from . import config @@ -33,7 +36,11 @@ class Interpreter: transformations as well as analysis passes. Methods in the Interpreter class can be overridden to customize +<<<<<<< HEAD the behavior of execution. The map of overridable methods +======= + the behavior of execution. The map of overrideable methods +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) in terms of call hierarchy:: run() @@ -176,12 +183,16 @@ def run( if self.extra_traceback: msg = f"While executing {node.format_node()}" msg = f"{e.args[0]}\n\n{msg}" if e.args else str(msg) +<<<<<<< HEAD msg += f"\nOriginal traceback:\n{node.stack_trace}" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if ( isinstance(self.module, GraphModule) and self.module.graph is not None and isinstance(self.module.graph, torch.fx.Graph) ): +<<<<<<< HEAD trace_structured( "artifact", metadata_fn=lambda: { @@ -196,6 +207,10 @@ def run( msg += "\nUse tlparse to see full graph. " msg += "(https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)" +======= + msg += f"\nGraphModule: {self.module.print_readable(print_output=False, include_stride=True)}\n" + msg += f"\nOriginal traceback:\n{node.stack_trace}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) e.args = (msg,) + e.args[1:] if isinstance(e, KeyError): raise RuntimeError(*e.args) from e diff --git a/torch/fx/node.py b/torch/fx/node.py index dbd6ed93ef26c..605655428335c 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -4,9 +4,15 @@ import logging import operator import types +<<<<<<< HEAD from collections.abc import Iterable, Mapping, Sequence from typing import Any, Callable, Optional, TYPE_CHECKING, Union from typing_extensions import ParamSpec, TypeAlias, TypeVar +======= +from collections.abc import Mapping, Sequence +from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union +from typing_extensions import ParamSpec +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch from torch._C import _fx_map_aggregate, _fx_map_arg, _NodeBase @@ -15,7 +21,10 @@ normalize_function, normalize_module, ) +<<<<<<< HEAD from torch.utils._dtype_abbrs import dtype_abbrs +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .._ops import ops as _ops from ._compatibility import compatibility @@ -46,7 +55,11 @@ ] base_types = BaseArgumentTypes.__args__ # type: ignore[attr-defined] +<<<<<<< HEAD Target: TypeAlias = Union[Callable[..., Any], str] +======= +Target = Union[Callable[..., Any], str] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Argument = Optional[ Union[ @@ -151,6 +164,7 @@ def _get_qualified_name(func: Callable[..., Any]) -> str: if getattr(builtins, func.__name__, None) is func: return func.__name__ # torch.Tensor.{fn} +<<<<<<< HEAD if ( isinstance(func, (types.MethodDescriptorType, types.WrapperDescriptorType)) and func is getattr(torch.Tensor, func.__name__, None) @@ -158,6 +172,11 @@ def _get_qualified_name(func: Callable[..., Any]) -> str: func.__module__ == torch._tensor.__name__ and func.__qualname__ == f"Tensor.{func.__name__}" ): +======= + if isinstance( + func, (types.MethodDescriptorType, types.WrapperDescriptorType) + ) and func is getattr(torch.Tensor, func.__name__, None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"torch.Tensor.{func.__name__}" name = func.__name__ if name == "": @@ -250,7 +269,11 @@ class Node(_NodeBase): # should not be accessed directly. _input_nodes: dict["Node", None] # All of the nodes that use the value produced by this Node +<<<<<<< HEAD # Note one user may correspond to several uses, e.g. the node for ``x + x`` +======= + # Note one user may correspond to several uses, e.g. the node fo ``x + x`` +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # would appear once here, but represents two uses. # Is a dict to act as an "ordered set". Keys are significant, value dont-care users: dict["Node", None] @@ -602,8 +625,11 @@ def format_node( self, placeholder_names: Optional[list[str]] = None, maybe_return_typename: Optional[list[str]] = None, +<<<<<<< HEAD *, include_tensor_metadata: bool = False, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Optional[str]: """ Return a descriptive string representation of ``self``. @@ -625,7 +651,10 @@ def format_node( maybe_return_typename: A single-element list that will store a formatted string representing the output of the generated ``forward`` function. Internal use only. +<<<<<<< HEAD include_tensor_metadata: Whether to include tensor metadata +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Returns: str: If 1) we're using ``format_node`` as an internal helper @@ -657,6 +686,7 @@ def format_node( maybe_return_typename[0] = f" -> {_type_repr(self.type)}" return f"return {self.args[0]}" else: +<<<<<<< HEAD def stringify_shape(shape: Iterable) -> str: return f"[{', '.join([str(x) for x in shape])}]" @@ -687,6 +717,13 @@ def stringify_shape(shape: Iterable) -> str: ) return ( f"%{self.name} : {type_annotation}[num_users={len(self.users)}] = " +======= + maybe_typename = ( + f"{_type_repr(self.type)} " if self.type is not None else "" + ) + return ( + f"%{self.name} : {maybe_typename}[num_users={len(self.users)}] = " +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f"{self.op}[target={self._pretty_print_target(self.target)}](" f"args = {_format_arg(self.args)}, kwargs = {_format_arg(self.kwargs)})" ) @@ -777,6 +814,7 @@ def is_impure(self, impure_random: bool = True) -> bool: # impure since it mutates RNG state return True +<<<<<<< HEAD # Handle Python random functions that don't have _nondeterministic_seeded # but still affect global RNG state (issue #151524) # These should be impure regardless of impure_random setting to maintain @@ -800,6 +838,8 @@ def is_impure(self, impure_random: bool = True) -> bool: # between eager and compiled execution, regardless of generator usage return True +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.target in _side_effectful_functions # Check if an impure module. diff --git a/torch/fx/passes/_tensorify_python_scalars.py b/torch/fx/passes/_tensorify_python_scalars.py index dd8edb50e1612..890186e71a23a 100644 --- a/torch/fx/passes/_tensorify_python_scalars.py +++ b/torch/fx/passes/_tensorify_python_scalars.py @@ -64,8 +64,13 @@ # manage to eliminate all float compute, this ends up being equivalent, but # there is a critical difference when some floats cannot be eliminated: when # we call item() on them, what should it's SymFloat be? Ideally, it would +<<<<<<< HEAD # be the same backed SymFloat we had before. But without symbolic expression # propagation on tensor quantities, repropagating would instead give you an +======= +# be the same backed SymFloat we had before. But without symbolic expresssion +# propogation on tensor quantities, repropagating would instead give you an +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # unbacked SymFloat. Maybe it is a good idea to implement symbolic propagation # on 0d scalar tensors, but I decided to go for something simpler to start. # @@ -203,7 +208,11 @@ def _sympy_interp(expr: sympy.Expr) -> MetaProxy: and node.target is torch.ops.aten._local_scalar_dense.default ): dtype = node.args[0].meta["val"].dtype +<<<<<<< HEAD if not dtype.is_floating_point: +======= + if dtype != torch.float64: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) continue assert isinstance(node.args[0], fx.Node), node.args[0] @@ -212,10 +221,13 @@ def _sympy_interp(expr: sympy.Expr) -> MetaProxy: expr_to_tensor_proxy[s] = MetaProxy( node.args[0], tracer=tracer, fake_mode=fake_mode ) +<<<<<<< HEAD # Upcast the float tensor to torch.float64 to avoid precision problem expr_to_tensor_proxy[s] = torch.ops.prims.convert_element_type.default( expr_to_tensor_proxy[s], torch.float64 ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) expr_to_sym_proxy[s] = MetaProxy( node, tracer=tracer, fake_mode=fake_mode ) @@ -350,7 +362,11 @@ def _sympy_interp(expr: sympy.Expr) -> MetaProxy: # Sometimes by the time we get to tensorify, there have already been # specializations, eg. in python_arg_parser.h. In these cases, # placeholder nodes no longer have a reference to their original +<<<<<<< HEAD # symfloat and thus we need to deduce specializations have happened +======= + # symfloat and thus we need to deduce specializations have happend +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # via shape_env.replacements. NB: there's an important invariant here # that symfloats keep consistent names across restarts. for k, v in shape_env.var_to_val.items(): diff --git a/torch/fx/passes/graph_manipulation.py b/torch/fx/passes/graph_manipulation.py index 6026e9ca25c05..5c6f54553ccb7 100644 --- a/torch/fx/passes/graph_manipulation.py +++ b/torch/fx/passes/graph_manipulation.py @@ -88,7 +88,11 @@ def get_size_of_node(fx_module: GraphModule, node: Node) -> size_bytes: """ # Total num of elements total_num_of_elems = 0 +<<<<<<< HEAD # For a module, consider all parameters +======= + # For a module, conside all parameters +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if node.op == "call_module": submodule_dict = dict(fx_module.named_modules()) submodule = submodule_dict[node.target] diff --git a/torch/fx/passes/graph_transform_observer.py b/torch/fx/passes/graph_transform_observer.py index 6479af665895c..19175a0c00021 100644 --- a/torch/fx/passes/graph_transform_observer.py +++ b/torch/fx/passes/graph_transform_observer.py @@ -31,13 +31,18 @@ def __init__( """ log_url is inferred to be torch._inductor.config.trace.log_url_for_graph_xform unless otherwise specified """ +<<<<<<< HEAD from torch._inductor import config as inductor_config +======= + from torch._inductor.config import trace +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.gm = gm self.passname = passname self.subsystem = subsystem if log_url is None: +<<<<<<< HEAD log_url = inductor_config.trace.log_url_for_graph_xform self.log_url = log_url @@ -46,12 +51,23 @@ def __init__( self.log_url is not None or inductor_config.trace.provenance_tracking_level == 1 ) +======= + log_url = trace.log_url_for_graph_xform + + self.log_url = log_url + + self.active = trace.enabled or self.log_url is not None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.active: self.erased_nodes: set[str] = set() self.created_nodes: set[str] = set() self.name_to_node: dict[str, Node] = {} +<<<<<<< HEAD # record graph modules deepcopied from self.gm, so we can remove hooks on them when exiting the context +======= + # record graph modules deepcopied from self.gm, so we can remove hoooks on them when exiting the context +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.copied_gms: list[GraphModule] = [] self._node_creation_hook = self.get_node_creation_hook() @@ -193,12 +209,15 @@ def on_node_replace(old: Node, new: str, user: Node): assert isinstance(new_node, Node) +<<<<<<< HEAD # replace hook is called once for each user of old # this avoids adding duplicated source nodes added_nodes = {s.name for s in new_node.meta.get("from_node", [])} if old.name in added_nodes: return +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) action = [NodeSourceAction.REPLACE] if new_node.name in self.created_nodes: action.append(NodeSourceAction.CREATE) diff --git a/torch/fx/passes/infra/partitioner.py b/torch/fx/passes/infra/partitioner.py index 6fc17b959424d..c92aeec4019e5 100644 --- a/torch/fx/passes/infra/partitioner.py +++ b/torch/fx/passes/infra/partitioner.py @@ -18,6 +18,7 @@ class Partition: def __init__( +<<<<<<< HEAD self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None, @@ -35,12 +36,23 @@ def __init__( "nodes and node_orders must have the same length" ) self.nodes = dict(zip(nodes_list, node_orders_list)) +======= + self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None + ): + self.id = id + self.nodes = dict.fromkeys(nodes) if nodes is not None else {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __repr__(self) -> str: return str(self.nodes) +<<<<<<< HEAD def add_node(self, node: Node, node_order: Optional[int] = None): self.nodes.update({node: node_order}) +======= + def add_node(self, node: Node): + self.nodes.update({node: None}) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def remove_node(self, node: Node): del self.nodes[node] @@ -185,7 +197,11 @@ def dfs_iter_find_cycle(all_user_nodes: set[Node]): return merge_id, True +<<<<<<< HEAD def merge_single_node(node: Node, node_order: Optional[int], id: Optional[int]): +======= + def merge_single_node(node: Node, id: Optional[int]): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _update_partition_map(node: Node, id: int): # Iterate through all the users of this node and update the partition map to indicate # that there is a path from the partition id of this node to the target partition id. @@ -202,19 +218,31 @@ def _update_partition_map(node: Node, id: int): assignment.pop(node) elif id not in partitions_by_id: assignment[node] = id +<<<<<<< HEAD assert node_order is not None partitions_by_id[id] = Partition( id=id, nodes=[node], node_orders=[node_order] ) +======= + partitions_by_id[id] = Partition(id=id, nodes=[node]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) partition_users[id] = set(node.users) _update_partition_map(node, id) else: assignment[node] = id +<<<<<<< HEAD partitions_by_id[id].add_node(node, node_order) logger.debug("Proposing partitions...") for node_order, node in enumerate(reversed(self.graph_module.graph.nodes)): +======= + partitions_by_id[id].add_node(node) + + logger.debug("Proposing partitions...") + + for node in reversed(self.graph_module.graph.nodes): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # use Dict as an ordered set to ensure deterministic partitioning result, don't care value merge_candidates: dict[int, None] = {} @@ -227,7 +255,11 @@ def _update_partition_map(node: Node, id: int): partition_id = next(new_partition_id) nodes_order[node] = partition_id partitions_order[partition_id] = partition_id +<<<<<<< HEAD merge_single_node(node, node_order, partition_id) +======= + merge_single_node(node, partition_id) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) merge_candidates[partition_id] = None # merge all possible partitions @@ -244,6 +276,7 @@ def _update_partition_map(node: Node, id: int): # in the graph, otherwise, this is a no-op self_id, _ = maybe_merge_partition(self_id, other_id) +<<<<<<< HEAD # sort partition nodes based on descending node order for partition in partitions_by_id.values(): partition.nodes = dict( @@ -252,6 +285,8 @@ def _update_partition_map(node: Node, id: int): ) ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # post processing to re-assign "getitem" nodes into upstream partition logger.debug("Reassigning getitem nodes to its producer node's partition...") nodes_reassignment: dict[Node, int] = {} @@ -272,7 +307,11 @@ def _update_partition_map(node: Node, id: int): if assignment.get(user, None) != id: # type: ignore[arg-type] nodes_reassignment[user] = id # type: ignore[assignment] for node, id in nodes_reassignment.items(): +<<<<<<< HEAD merge_single_node(node, None, id) +======= + merge_single_node(node, id) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # filter out single node partitions if not self.allows_single_node_partition: diff --git a/torch/fx/passes/infra/pass_manager.py b/torch/fx/passes/infra/pass_manager.py index 4077e74360f56..dbfdf93449888 100644 --- a/torch/fx/passes/infra/pass_manager.py +++ b/torch/fx/passes/infra/pass_manager.py @@ -78,7 +78,11 @@ def _topological_sort_passes( if len(constraints) == 0: return passes +<<<<<<< HEAD # Construct a graph mapping nodes to a list of their users +======= + # Contruct a graph mapping nodes to a list of their users +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) graph: dict[Callable, list[Callable]] = {p: [] for p in passes} indegree_map: dict[Callable, int] = dict.fromkeys(passes, 0) candidates: Queue = Queue() diff --git a/torch/fx/passes/net_min_base.py b/torch/fx/passes/net_min_base.py index 8c15b9097397b..31a381ca4bd48 100644 --- a/torch/fx/passes/net_min_base.py +++ b/torch/fx/passes/net_min_base.py @@ -95,7 +95,11 @@ class _MinimizerBase: Currently we provides two ways to traverse the graph and generate submodules. 1. Sequential traversal: this will traverse the graph node by node and generate +<<<<<<< HEAD one submodule with one single node. +======= + one submodule with one sigle node. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) 2. Binary searching: this will do a binary search style traversal on the graph. For internal Users, a guide can be found here https://fb.quip.com/HDtuAgiKGfkP. @@ -648,7 +652,11 @@ def _block_traverse( ) -> NodeSet: """ Traverse topologically sorted node list +<<<<<<< HEAD Find minimum block (start_idx, end_idx) which contains the culprit +======= + Find minimium block (start_idx, end_idx) which contains the culprit +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) 1st pass: search for end_idx by finding the last node in culprit block where Numerical accuracy (0, end_idx) > threshold 2nd pass: search for start_idx by finding the first node in culprit block diff --git a/torch/fx/passes/reinplace.py b/torch/fx/passes/reinplace.py index 6027c603ec1fe..debdd81be4972 100644 --- a/torch/fx/passes/reinplace.py +++ b/torch/fx/passes/reinplace.py @@ -266,7 +266,11 @@ def matching_view_metadata(a, b): continue self_alias_base = self_alias.meta["view_of"] try: +<<<<<<< HEAD # The we're trying to reuse the args from the view_scatter call inside of the corresponding +======= + # The we're trying to re-use the args from the view_scatter call inside of the corresponding +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # view op, which might throw. This just indicates that view_scatter op isn't a valid inverse # of the current alias we're looking at. view_replay_metadata = original_view( @@ -291,7 +295,11 @@ def reinplace(gm, *sample_args): mutating the nodes of the graph. We look for out-of-place op call sites like `b = a.add(...)`, and convert them to be inplace (`b = a.add_(...)`), +<<<<<<< HEAD as long as the input to the current operator ("a") isn't reused +======= + as long as the input to the current operator ("a") isn't re-used +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) anywhere later in the graph. This pass currently expects to operate on a **functional, ATen** graph. @@ -342,7 +350,11 @@ def reinplace(gm, *sample_args): NOTE: there's a future optimization that we should make: if "a" is a (alias of a) program input, but later in the program there is a node that looks like "a.copy_(...)", +<<<<<<< HEAD Then re-inplacing is ok to do - we are temporarily reusing a's buffer, +======= + Then re-inplacing is ok to do - we are temporarily re-using a's buffer, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) which will later be overwritten by the copy_() call. This will be an important optimization to have for programs that mutate @@ -599,7 +611,11 @@ def _add_to_map(x): later_node_usages, self_aliases ) +<<<<<<< HEAD # Step 2: Check to see if the input to the op is reused later in the graph. +======= + # Step 2: Check to see if the input to the op is re-used later in the graph. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # If not (same goes for its aliases), then this op is safe to re-in place. # This is a slightly roundabout way to check that there are no later usages of the current self argument. # (later_view_inverse_node_usages corresponds to "view_scatter" nodes that we are allowed to delete) diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py index 19e101a5c120a..2adad42a02e57 100644 --- a/torch/fx/passes/runtime_assert.py +++ b/torch/fx/passes/runtime_assert.py @@ -337,6 +337,7 @@ def match_symbol(symint, cb): torch._check, torch.ops.aten._assert_scalar.default, ): +<<<<<<< HEAD cond = node.args[0] if node.args else node.kwargs.get("cond") if ( cond == True # noqa: E712 @@ -344,6 +345,14 @@ def match_symbol(symint, cb): and assert_expr in added_asserts ): arg = cond +======= + if ( + node.args[0] == True # noqa: E712 + or (assert_expr := _get_sym_val(node.args[0])) in expr_to_proxy + and assert_expr in added_asserts + ): + arg = node.args[0] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) gm.graph.erase_node(node) if isinstance(arg, fx.Node) and not arg.users: gm.graph.erase_node(arg) @@ -462,7 +471,10 @@ def go(node, keypath): ), keypath[2:], ) +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return go( graph.call_method( keypath[0].name, (node, keypath[1].idx) @@ -470,6 +482,7 @@ def go(node, keypath): keypath[2:], ) elif isinstance(keypath[0], CallMethodKey): +<<<<<<< HEAD if keypath[0].name == "storage_offset": return go( graph.call_function( @@ -479,6 +492,8 @@ def go(node, keypath): keypath[1:], ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return go( graph.call_method(keypath[0].name, (node,)), keypath[1:] ) diff --git a/torch/fx/passes/shape_prop.py b/torch/fx/passes/shape_prop.py index d734242abd82a..42a28bca4b3cc 100644 --- a/torch/fx/passes/shape_prop.py +++ b/torch/fx/passes/shape_prop.py @@ -7,7 +7,11 @@ import torch.fx from torch._dispatch.python import enable_python_dispatcher from torch._guards import detect_fake_mode +<<<<<<< HEAD from torch._prims_common import is_contiguous_for_memory_format_or_false +======= +from torch._prims_common import definitely_contiguous_for_memory_format +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._subclasses.meta_utils import is_sparse_any from torch.fx._compatibility import compatibility from torch.fx.node import map_aggregate, Node @@ -35,8 +39,13 @@ class TensorMetadata(NamedTuple): # When include_contiguity is True, we will set contiguity when its always true for the tensor. # Some tensors can represent both contiguous and non-contiguous tensors. e.g: (u0, u1) with (u2, u3). +<<<<<<< HEAD # In such situation contiguity is not set. We could also make it a tri-state i.e: (def_contiguous, # def_not_contiguous and unknown). +======= +# In such situation contiguity is not set. We could also make it a tri-state i.e: (definitely_contiguous, +# contiguous, and unknown). +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _extract_tensor_metadata( result: torch.Tensor, include_contiguity=True ) -> TensorMetadata: @@ -57,7 +66,11 @@ def _extract_tensor_metadata( torch.channels_last_3d, } for query_format in memory_formats: +<<<<<<< HEAD if is_contiguous_for_memory_format_or_false( +======= + if definitely_contiguous_for_memory_format( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) result, memory_format=query_format ): memory_format = query_format diff --git a/torch/fx/passes/split_module.py b/torch/fx/passes/split_module.py index 4d9526c63f83d..c5b490ad8edfa 100644 --- a/torch/fx/passes/split_module.py +++ b/torch/fx/passes/split_module.py @@ -248,7 +248,10 @@ def record_cross_partition_use(def_node: Node, use_node: Optional[Node]): s_def_partition = partitions[s_defined] s_def_partition.outputs.setdefault(s_node.name) s_def_partition.dependents.setdefault(used) +<<<<<<< HEAD use_partition.dependencies.setdefault(s_defined) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if defined is not None: use_partition.dependencies.setdefault(defined) @@ -395,7 +398,11 @@ def instantiate_node_partition_mapping(node): root_partition = root_partitions.pop() sorted_partitions.append(root_partition) for dependent in partitions[root_partition].dependents: +<<<<<<< HEAD partitions[dependent].dependencies.pop(root_partition) # noqa: B909 +======= + partitions[dependent].dependencies.pop(root_partition) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not partitions[dependent].dependencies: root_partitions.append(dependent) if len(sorted_partitions) != len(partitions): diff --git a/torch/fx/passes/split_utils.py b/torch/fx/passes/split_utils.py index 88da7ac7c4f55..808cc2145cf7e 100644 --- a/torch/fx/passes/split_utils.py +++ b/torch/fx/passes/split_utils.py @@ -17,12 +17,16 @@ @compatibility(is_backward_compatible=False) def getattr_recursive(obj, name): for layer in name.split("."): +<<<<<<< HEAD if isinstance(obj, torch.nn.ModuleList): if hasattr(obj, "_modules") and layer in obj._modules: obj = obj._modules[layer] else: return None elif hasattr(obj, layer): +======= + if hasattr(obj, layer): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) obj = getattr(obj, layer) else: return None diff --git a/torch/fx/passes/splitter_base.py b/torch/fx/passes/splitter_base.py index 8a23c73785e8c..63650c792cb51 100644 --- a/torch/fx/passes/splitter_base.py +++ b/torch/fx/passes/splitter_base.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs import argparse import copy +<<<<<<< HEAD import json import logging import os @@ -11,6 +12,15 @@ import torch from torch._logging import trace_structured +======= +import logging +from collections import defaultdict +from collections.abc import Iterable, Sequence +from dataclasses import dataclass +from typing import Any, NamedTuple, Optional + +import torch +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.fx._compatibility import compatibility from torch.fx.node import map_arg from torch.fx.passes.graph_manipulation import get_size_of_node @@ -35,8 +45,11 @@ "Subgraph", "SplitResult", "generate_inputs_for_submodules", +<<<<<<< HEAD "NodeEvent", "NodeEventTracker", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] _LOGGER = logging.getLogger(__name__) @@ -44,6 +57,7 @@ DEFAULT_SKIP_FUSION = False DEFAULT_ALLOW_NON_TENSOR = False +<<<<<<< HEAD # ENV var and constants for node tracker TRACKER_DUMP_PATH = "_fx_net_tracker" @@ -73,6 +87,8 @@ ENV_FX_NET_ACC_SPLITTER_TRACKER_MODE, "0" ) # type: ignore[assignment] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class _SplitterSettingBase: def __init__( @@ -134,6 +150,7 @@ def __init__( @compatibility(is_backward_compatible=False) +<<<<<<< HEAD class NodeEvent: """ An event in graph split that happened on a node. @@ -273,6 +290,8 @@ def dump_selected_nodes(nodes): @compatibility(is_backward_compatible=False) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class FxNetAccNodesFinder: """ Finds a set of nodes that can be supported on ACC, excluding nodes that have non-tensor @@ -298,8 +317,11 @@ def __init__( self.allow_non_tensor = allow_non_tensor self.acc_nodes: NodeSet = set() +<<<<<<< HEAD self.tracker = NodeEventTracker(int(TRACKER_MODE), DUMP_PREFIX) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def reduce_acc_nodes_non_tensor_input_helper(self, cpu_worklist: NodeList): """ Transitively excludes nodes from ACC supported set. @@ -314,9 +336,13 @@ def reduce_acc_nodes_non_tensor_input_helper(self, cpu_worklist: NodeList): for user in node.users: if user in self.acc_nodes: self.acc_nodes.remove(user) +<<<<<<< HEAD self.tracker.add(user, "acc_del|user_of_new_cpu_node", node) if not is_node_output_tensor(user): self.tracker.add(user, "new_cpu_node|non_tensor_output") +======= + if not is_node_output_tensor(user): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cpu_worklist.append(user) def reduce_acc_nodes_non_tensor_input(self): @@ -333,7 +359,10 @@ def reduce_acc_nodes_non_tensor_input(self): continue if is_node_output_tensor(node): continue +<<<<<<< HEAD self.tracker.add(node, "new_cpu_node|callable_non_tensor_input") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) non_tensor_cpu_nodes.append(node) self.reduce_acc_nodes_non_tensor_input_helper(non_tensor_cpu_nodes) @@ -352,9 +381,12 @@ def reduce_acc_nodes_non_tensor_output(self): for user in acc_node.users: if user not in self.acc_nodes: new_cpu_nodes.append(acc_node) +<<<<<<< HEAD self.tracker.add( acc_node, "acc_del|non_tensor_output_with_cpu_user", user ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) break if not new_cpu_nodes: @@ -367,6 +399,7 @@ def reduce_acc_nodes_non_tensor_output(self): def __call__(self) -> NodeSet: submodules = dict(self.module.named_modules()) +<<<<<<< HEAD self.acc_nodes = set() for n in self.module.graph.nodes: if n.op not in CALLABLE_NODE_OPS: @@ -378,11 +411,23 @@ def __call__(self) -> NodeSet: self.tracker.add(n, "init_acc|callable_and_operator_supported") self.acc_nodes.add(n) +======= + self.acc_nodes = { + n + for n in self.module.graph.nodes + if n.op in CALLABLE_NODE_OPS + and self.operator_support.is_node_supported(submodules, n) + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not self.allow_non_tensor: self.reduce_acc_nodes_non_tensor_input() self.reduce_acc_nodes_non_tensor_output() +<<<<<<< HEAD self.tracker.dump() +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self.acc_nodes @@ -408,7 +453,11 @@ class SplitResult(NamedTuple): split_module: root module after splitting. submodule_inputs: a dict that maps submodule name to its inputs. non_acc_submodule_prefix: the prefix for non acc submodules. For +<<<<<<< HEAD acc submodule the prefix is always "_run_on_acc_". +======= + acc submodule the prefix is alwasy "_run_on_acc_". +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ split_module: torch.fx.GraphModule @@ -447,8 +496,12 @@ def pre_forward(module, module_inputs): for name, mod in model.named_modules(): if name in target_submodules: +<<<<<<< HEAD if not isinstance(mod, torch.jit.ScriptModule): handles.append(mod.register_forward_pre_hook(pre_forward)) +======= + handles.append(mod.register_forward_pre_hook(pre_forward)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def clean_up_handles(): for h in handles: @@ -905,7 +958,11 @@ def extend_acc_subgraph(self, tag: str): """ # Dict that maps node to its users and ignore users that # are in the subgraph that has greater tag +<<<<<<< HEAD deps = self.find_reverse_deps(tag_id=int(tag.rsplit("_", maxsplit=1)[-1])) +======= + deps = self.find_reverse_deps(tag_id=int(tag.split("_")[-1])) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.update_reverse_deps_for_fusions(deps) # Parent nodes of the subgraph diff --git a/torch/fx/passes/utils/fuser_utils.py b/torch/fx/passes/utils/fuser_utils.py index 33db9fd03d790..c46cb14a134e1 100644 --- a/torch/fx/passes/utils/fuser_utils.py +++ b/torch/fx/passes/utils/fuser_utils.py @@ -44,7 +44,11 @@ def topo_sort(nodes: NodeList) -> NodeList: @compatibility(is_backward_compatible=False) def validate_partition(partition: NodeList) -> bool: +<<<<<<< HEAD # verify the partition doesn't form a dependency cycle in the original graph +======= + # verify the partition does't form a dependency cycle in the original graph +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # returns True for valid partition, False for invalid partition_set = set(partition) @@ -96,7 +100,11 @@ def fuse_as_graphmodule( gm: GraphModule, nodes: NodeList, module_name: str, +<<<<<<< HEAD partition_lookup_table: _Optional[dict[Node, _Optional[int]]] = None, +======= + partition_lookup_table: _Optional[dict[Node, None]] = None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) *, always_return_tuple: bool = False, ) -> tuple[GraphModule, tuple[Node, ...], tuple[Node, ...]]: @@ -157,13 +165,21 @@ def remap_inputs(x: Node) -> Node: if x in partition_lookup_table: # x is inside subgraph, return the copied node +<<<<<<< HEAD # the node should have been copied already, as we are copying graph in the topological order +======= + # the node should have been copied aleady, as we are copying graph in the topological order +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return node_map[x] if x not in node_to_placeholder: # x is not in subgraph, create a new placeholder for subgraph placeholder_node = subgraph.placeholder(x.name, type_expr=x.type) +<<<<<<< HEAD # copy all meta fields, even if some fields might be irrelevant for the placeholder node +======= + # copy all meta fields, even if some fields might be irrelvant for the placeholder node +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) placeholder_node.meta = copy.copy(x.meta) node_to_placeholder[x] = placeholder_node @@ -249,7 +265,11 @@ def erase_nodes(gm: GraphModule, nodes: NodeList) -> None: @compatibility(is_backward_compatible=False) def fuse_by_partitions( gm: GraphModule, +<<<<<<< HEAD partitions: list[dict[Node, _Optional[int]]], +======= + partitions: list[dict[Node, None]], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) prefix: str = "fused_", always_return_tuple: bool = False, ) -> GraphModule: diff --git a/torch/fx/passes/utils/matcher_utils.py b/torch/fx/passes/utils/matcher_utils.py index aa58b52933f94..21a024fb060bc 100644 --- a/torch/fx/passes/utils/matcher_utils.py +++ b/torch/fx/passes/utils/matcher_utils.py @@ -95,7 +95,11 @@ def __init__( ) for node in pattern.nodes: +<<<<<<< HEAD if node.op != "output" and not node.is_impure(): +======= + if node.op != "output": +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert len(node.users) > 0, ( "SubgraphMatcher cannot be initialized with an pattern with dead code" ) @@ -137,14 +141,21 @@ def _match_attributes(self, pn: Node, gn: Node) -> bool: raise RuntimeError(f"Unsupported type {pn_value} when matching attributes") return False +<<<<<<< HEAD def _nodes_are_equal(self, pn: Node, gn: Node, node_name_match: str = "") -> bool: +======= + def _nodes_are_equal(self, pn: Node, gn: Node) -> bool: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # if exact match for placeholder is not required, then use placeholder as a wildcard if not self.match_placeholder and pn.op == "placeholder": return True +<<<<<<< HEAD if node_name_match and node_name_match in gn.name: return True +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if pn.op == gn.op: if pn.op == "placeholder" or pn.op == "output": return True @@ -215,9 +226,13 @@ def _match_literals(self, pn: Any, gn: Any, match: InternalMatch) -> bool: else: return type(gn) == type(pn) and gn == pn +<<<<<<< HEAD def _match_nodes( self, pn: Node, gn: Node, match: InternalMatch, node_name_match: str = "" ) -> bool: +======= + def _match_nodes(self, pn: Node, gn: Node, match: InternalMatch) -> bool: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) logger.info(" matching %s to %s", pn, gn) assert isinstance(pn, Node) and isinstance(gn, Node), str( @@ -233,7 +248,11 @@ def _match_nodes( if gn in match.nodes_map.values(): return False +<<<<<<< HEAD if not self._nodes_are_equal(pn, gn, node_name_match): +======= + if not self._nodes_are_equal(pn, gn): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return False # Optimistically mark `pn` as a match for `gn`, and save a local copy of match @@ -318,11 +337,19 @@ def get_all_arguments(orig_args, orig_kwargs): return True +<<<<<<< HEAD def match(self, graph: Graph, node_name_match: str = "") -> list[InternalMatch]: """ Returns: The matched subgraphs. The returned subgraph would be fully self-contained, meaning the nodes (except placeholder +======= + def match(self, graph: Graph) -> list[InternalMatch]: + """ + Returns: + The matched subgraphs. + Thre returned subgraph would be fully self-contained, meaning the nodes (except placeholder +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) and nodes returned by output) can only be consumed by nodes within the matched subgraph. Subgraph pattern matcher is implemented with the backtracking style in the following steps: @@ -360,7 +387,11 @@ def match(self, graph: Graph, node_name_match: str = "") -> list[InternalMatch]: match_candidates: dict[Node, list[Node]] = defaultdict(list) for pattern_anchor in self.pattern_anchors: for node in graph.nodes: +<<<<<<< HEAD if self._nodes_are_equal(pattern_anchor, node, node_name_match): +======= + if self._nodes_are_equal(pattern_anchor, node): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) match_candidates[pattern_anchor].append(node) match_candidates_list = list(match_candidates.items()) @@ -387,9 +418,13 @@ def backtracking(anchor_index, match): for node in candidate_nodes: logger.info("Trying to match anchor %s to %s", pattern_anchor, node) +<<<<<<< HEAD match_found = self._match_nodes( pattern_anchor, node, match, node_name_match ) +======= + match_found = self._match_nodes(pattern_anchor, node, match) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if match_found: # match next anchor backtracking(anchor_index + 1, match) diff --git a/torch/fx/passes/utils/matcher_with_name_node_map_utils.py b/torch/fx/passes/utils/matcher_with_name_node_map_utils.py index 3114d55b635fc..f3548b98a1309 100644 --- a/torch/fx/passes/utils/matcher_with_name_node_map_utils.py +++ b/torch/fx/passes/utils/matcher_with_name_node_map_utils.py @@ -88,7 +88,11 @@ def __init__( ignore_literals, ) +<<<<<<< HEAD def match(self, graph: Graph, node_name_match: str = "") -> list[InternalMatch]: +======= + def match(self, graph: Graph) -> list[InternalMatch]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """The returned InternalMatch will have name_node_map populated with a map from node name (str) to the target node, e.g. {"conv": target_conv_ndoe, "relu": target_relu_node} @@ -107,7 +111,11 @@ def pattern(...): return relu, {"conv": conv, "relu": relu} ``` instead """ +<<<<<<< HEAD internal_matches = super().match(graph, node_name_match) +======= + internal_matches = super().match(graph) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for internal_match in internal_matches: for k, n in self.name_node_map.items(): internal_match.name_node_map[k] = internal_match.nodes_map[n] diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py index aa057085b3760..f138ca7beebf4 100644 --- a/torch/fx/proxy.py +++ b/torch/fx/proxy.py @@ -124,10 +124,13 @@ def __exit__(self, *args): class TracerBase: graph: Graph record_stack_traces: bool = False +<<<<<<< HEAD # When record_stack_traces is True, only reocrd stack traces # with forward function names. # This helps when we want stack trace back to model code _record_forward_stack_traces_only: bool = False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Feature flag for mutable schema checking # Enableby default in 1.12 check_mutable_operations: bool = False @@ -208,6 +211,7 @@ def create_node( elif self.module_stack: node.meta["nn_module_stack"] = copy.copy(self.module_stack) +<<<<<<< HEAD if self.record_stack_traces and not node.stack_trace: user_stack_summary = CapturedTraceback.extract().summary() if user_stack_summary: @@ -256,6 +260,11 @@ def _filter_traceback_frames( return traceback.StackSummary.from_list(user_frames) +======= + log.debug("create_node %s", node) + return node + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @compatibility(is_backward_compatible=True) def proxy(self, node: Node) -> "Proxy": return Proxy(node, self) @@ -294,6 +303,34 @@ def create_proxy( else: proxy = proxy_factory_fn(node) +<<<<<<< HEAD +======= + if self.record_stack_traces and not proxy.node.stack_trace: + from torch.fx.experimental.symbolic_shapes import uninteresting_files + + user_frame_summary = CapturedTraceback.extract().summary() + if user_frame_summary: + first_forward = -1 + for i, frame in enumerate(user_frame_summary): + if frame.name == "forward": + user_frame_summary = user_frame_summary[i:] + first_forward = i + break + + # Not having a "forward" call in the stacktrace implies the + # stacktrace will probably be irrelevant + if first_forward == -1: + user_frame_summary = [] + + stack_trace = [ + frame + for frame in user_frame_summary + if frame.filename not in uninteresting_files() + ] + stack_trace = traceback.StackSummary.from_list(stack_trace) + proxy.node.stack_trace = "".join(stack_trace.format()).strip() + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return proxy def _find_user_frame(self): diff --git a/torch/fx/subgraph_rewriter.py b/torch/fx/subgraph_rewriter.py index eebdfad096322..9f6d06901d321 100644 --- a/torch/fx/subgraph_rewriter.py +++ b/torch/fx/subgraph_rewriter.py @@ -234,7 +234,10 @@ def replace_pattern_with_filters( replacement_callback: Optional[ Callable[["InternalMatch", Graph, Graph], Graph] ] = None, +<<<<<<< HEAD node_name_match: str = "", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> list[ReplacedPatterns]: """ See replace_pattern for documentation. This function is an overload with an additional match_filter argument. @@ -247,6 +250,7 @@ def replace_pattern_with_filters( ``replacement_callback``: A function that takes in a match and returns a Graph to be used as the replacement. This allows you to construct a replacement graph based on the match. +<<<<<<< HEAD ``replacement_callback``: Node name to match. If not empty, it will try to match the node name. """ @@ -258,6 +262,12 @@ def replace_pattern_with_filters( ignore_literals, replacement_callback, node_name_match, +======= + """ + + return _replace_pattern( + gm, pattern, replacement, match_filters, ignore_literals, replacement_callback +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @@ -273,7 +283,10 @@ def _replace_pattern( replacement_callback: Optional[ Callable[["InternalMatch", Graph, Graph], Graph] ] = None, +<<<<<<< HEAD node_name_match: str = "", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> list[ReplacedPatterns]: from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher @@ -297,9 +310,13 @@ def _replace_pattern( remove_overlapping_matches=True, ignore_literals=ignore_literals, ) +<<<<<<< HEAD _matches: list[InternalMatch] = matcher.match( original_graph, node_name_match=node_name_match ) +======= + _matches: list[InternalMatch] = matcher.match(original_graph) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Filter out matches that don't match the filter _matches = [ diff --git a/torch/fx/traceback.py b/torch/fx/traceback.py index ed111b5f5b54b..87ba61d156432 100644 --- a/torch/fx/traceback.py +++ b/torch/fx/traceback.py @@ -1,20 +1,29 @@ # mypy: allow-untyped-defs import copy +<<<<<<< HEAD import logging +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import traceback from contextlib import contextmanager from enum import Enum from typing import Any, Optional, Union +<<<<<<< HEAD from torch._utils_internal import signpost_event +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from ._compatibility import compatibility from .graph import Graph from .node import Node +<<<<<<< HEAD log = logging.getLogger(__name__) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __all__ = [ "preserve_node_meta", "has_preserved_node_meta", @@ -56,8 +65,11 @@ def __init__(self, name: str, target: str, graph_id: int): action: list["NodeSourceAction"] from_node: list["NodeSource"] node_info: Optional["NodeInfo"] +<<<<<<< HEAD _dict: Optional[dict[str, Any]] _action_string: Optional[str] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __init__( self, @@ -87,10 +99,13 @@ def __init__( self.node_info = None self.from_node = [] +<<<<<<< HEAD # cache the action string and dict representation for performance. self._action_string: Optional[str] = None self._dict: Optional[dict[str, Any]] = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @property def name(self) -> str: return self.node_info.name if self.node_info else "" @@ -107,9 +122,13 @@ def __repr__(self): return self.print_readable() def _get_action_string(self): +<<<<<<< HEAD if self._action_string is None: self._action_string = "+".join([a.name.lower() for a in self.action]) return self._action_string +======= + return "+".join([a.name.lower() for a in self.action]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def print_readable(self, indent=0): if indent > 9: @@ -125,6 +144,7 @@ def print_readable(self, indent=0): return result def to_dict(self) -> dict: +<<<<<<< HEAD if self._dict is None: # Convert the object to a dictionary action_string = self._get_action_string() @@ -211,6 +231,18 @@ def _from_dict(cls, d: Optional[dict]) -> Optional["NodeSource"]: else: node_source.from_node = [] return node_source +======= + # Convert the object to a dictionary + action_string = self._get_action_string() + return { + "name": self.name, + "target": self.target, + "graph_id": self.graph_id, + "pass_name": self.pass_name, + "action": action_string, + "from_node": [node.to_dict() for node in self.from_node], + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @compatibility(is_backward_compatible=False) @@ -316,6 +348,7 @@ def get_graph_provenance_json(graph: Graph) -> dict[str, Any]: """ Given an fx.Graph, return a json that contains the provenance information of each node. """ +<<<<<<< HEAD try: provenance_tracking_json = {} for node in graph.nodes: @@ -339,3 +372,14 @@ def get_graph_provenance_json(graph: Graph) -> dict[str, Any]: }, ) return {} +======= + provenance_tracking_json = {} + for node in graph.nodes: + if node.op == "call_function": + provenance_tracking_json[node.name] = ( + [source.to_dict() for source in node.meta["from_node"]] + if "from_node" in node.meta + else [] + ) + return provenance_tracking_json +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/header_only_apis.txt b/torch/header_only_apis.txt index 4cfeeb6238ad5..223d4c4c57428 100644 --- a/torch/header_only_apis.txt +++ b/torch/header_only_apis.txt @@ -3,6 +3,7 @@ # to guarantee that compiling these symbols do not require linking libtorch # to ensure header-only-ness. +<<<<<<< HEAD # torch/headeronly/util/shim_utils.h TORCH_ERROR_CODE_CHECK @@ -43,6 +44,33 @@ fp16_ieee_to_fp32_value # fp32_to_bits called from fp16_ieee_from_fp32_value # c10/util/complex.h, torch/headeronly/util/complex.h +======= +# c10/util/TypeCast.h +convert + +# c10/util/bit_cast.h +bit_cast + +# c10/util/BFloat16-math.h, c10/util/BFloat16.h +BFloat16 + +# c10/util/Float8_e4m3fn.h +Float8_e4m3fn + +# c10/util/Float8_e4m3fnuz.h +Float8_e4m3fnuz + +# c10/util/Float8_e5m2.h +Float8_e5m2 + +# c10/util/Float8_e5m2fnuz.h +Float8_e5m2fnuz + +# c10/util/Half.h +Half + +# c10/util/complex.h +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) complex # ATen/NumericUtils.h, c10/util/generic_math.h @@ -63,6 +91,7 @@ maximum minimum size +<<<<<<< HEAD # torch/headeronly/cpu/vec/vec_half.h float2half_scalar half2float_scalar @@ -99,3 +128,7 @@ bits16 NumScalarTypes ScalarType # dummy_int1_7_t, dummy_uint1_7_t tested through ScalarType +======= +# torch/headeronly/macros/Export.h +C10_API +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/headeronly/BUILD.bazel b/torch/headeronly/BUILD.bazel index 030651b120436..1bcfa077b716b 100644 --- a/torch/headeronly/BUILD.bazel +++ b/torch/headeronly/BUILD.bazel @@ -1,5 +1,16 @@ load("@rules_cc//cc:defs.bzl", "cc_library") +<<<<<<< HEAD load("//:tools/bazel.bzl", "rules") load(":build.bzl", "define_targets") define_targets(rules = rules) +======= + +cc_library( + name = "torch_headeronly", + hdrs = glob([ + "**/*.h" + ]), + visibility = ["//visibility:public"], +) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/headeronly/macros/Export.h b/torch/headeronly/macros/Export.h index 8dd25419efb4e..8a80f77b8e634 100644 --- a/torch/headeronly/macros/Export.h +++ b/torch/headeronly/macros/Export.h @@ -1,5 +1,6 @@ #pragma once +<<<<<<< HEAD #ifndef C10_MACROS_EXPORT_H_ #define C10_MACROS_EXPORT_H_ @@ -7,6 +8,8 @@ #include #endif // C10_USING_CUSTOM_GENERATED_MACROS +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) /* Header file to define the common scaffolding for exported symbols. * * Export is by itself a quite tricky situation to deal with, and if you are @@ -92,6 +95,7 @@ #else #define C10_API C10_IMPORT #endif +<<<<<<< HEAD // This one is being used by libtorch.so #ifdef CAFFE2_BUILD_MAIN_LIB @@ -151,3 +155,5 @@ #define C10_API_ENUM #endif #endif // C10_MACROS_EXPORT_H_ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/hub.py b/torch/hub.py index fc943a4dd004d..89c5aced31cd2 100644 --- a/torch/hub.py +++ b/torch/hub.py @@ -200,12 +200,16 @@ def _validate_not_a_forked_repo(repo_owner, repo_name, ref): while True: page += 1 url = f"{url_prefix}?per_page=100&page={page}" +<<<<<<< HEAD try: response = json.loads(_read_url(Request(url, headers=headers))) except HTTPError: # Retry without token in case it had insufficient permissions. del headers["Authorization"] response = json.loads(_read_url(Request(url, headers=headers))) +======= + response = json.loads(_read_url(Request(url, headers=headers))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Empty response means no more data to process if not response: break @@ -272,7 +276,11 @@ def _get_cache_or_reload( except HTTPError as err: if err.code == 300: # Getting a 300 Multiple Choices error likely means that the ref is both a tag and a branch +<<<<<<< HEAD # in the repo. This can be disambiguated by explicitly using refs/heads/ or refs/tags +======= + # in the repo. This can be disambiguated by explicitely using refs/heads/ or refs/tags +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # See https://git-scm.com/book/en/v2/Git-Internals-Git-References # Here, we do the same as git: we throw a warning, and assume the user wanted the branch warnings.warn( diff --git a/torch/jit/_decompositions.py b/torch/jit/_decompositions.py index 000ec7d0ec796..cc59fb2b67cd2 100644 --- a/torch/jit/_decompositions.py +++ b/torch/jit/_decompositions.py @@ -40,7 +40,11 @@ def signatures_match(decomposition_sig, torch_op_sig): return False for decomp_param, op_param in zip(decomp_params.values(), op_params.values()): +<<<<<<< HEAD # can't check full equality yet because not all fields are correctly deduced +======= + # can't check full equality yet because not all fields are correcly deduced +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # in the torch_op_sig - like default value # can't check 'kind' bc # kwarg-only values with defaults not yet supported in TS diff --git a/torch/jit/_freeze.py b/torch/jit/_freeze.py index b61a2dd6207d1..2c698a16383b0 100644 --- a/torch/jit/_freeze.py +++ b/torch/jit/_freeze.py @@ -150,7 +150,11 @@ def run_frozen_optimizations( None Note: +<<<<<<< HEAD In rare occasions, this can result in slower execution. +======= + In rare occassions, this can result in slower execution. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Example (Freezing a module with Conv->Batchnorm) .. code-block:: python diff --git a/torch/jit/_monkeytype_config.py b/torch/jit/_monkeytype_config.py index 84ea4d5c3f6b0..507b2086e1de6 100644 --- a/torch/jit/_monkeytype_config.py +++ b/torch/jit/_monkeytype_config.py @@ -26,7 +26,11 @@ _IS_MONKEYTYPE_INSTALLED = False +<<<<<<< HEAD # Checks whether a class is defined in `torch.*` modules +======= +# Checks whether a class is defind in `torch.*` modules +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def is_torch_native_class(cls): if not hasattr(cls, "__module__"): return False @@ -130,7 +134,11 @@ def consolidate_types(self, qualified_name: str) -> dict: types = list(types) type_length = len(types) if type_length == 2 and type(None) in types: +<<<<<<< HEAD # TODO: To remove this check once Union support in TorchScript lands. +======= + # TODO: To remove this check once Union suppport in TorchScript lands. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) all_args[arg] = get_optional_of_element_type(types) elif type_length > 1: all_args[arg] = "Any" diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py index e89bcc47dff6b..fc4485c4a420a 100644 --- a/torch/jit/_recursive.py +++ b/torch/jit/_recursive.py @@ -431,7 +431,11 @@ def __init__(self) -> None: self.methods_compiled = set() def get_or_create_concrete_type(self, nn_module): +<<<<<<< HEAD """Infer a ConcreteType from this `nn.Module` instance. Underlying JIT types are reused if possible.""" +======= + """Infer a ConcreteType from this `nn.Module` instance. Underlying JIT types are re-used if possible.""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) concrete_type_builder = infer_concrete_type_builder(nn_module) nn_module_type = type(nn_module) @@ -502,7 +506,11 @@ def get_module_concrete_type(nn_module, share_types=True): # Look into the store of cached JIT types concrete_type = concrete_type_store.get_or_create_concrete_type(nn_module) else: +<<<<<<< HEAD # Get a concrete type directly, without trying to reuse an existing JIT +======= + # Get a concrete type directly, without trying to re-use an existing JIT +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # type from the type store. concrete_type_builder = infer_concrete_type_builder(nn_module, share_types) concrete_type_builder.set_poisoned() diff --git a/torch/jit/_script.py b/torch/jit/_script.py index ccd967d69f4e7..c81ff35a5a51b 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -704,7 +704,14 @@ def _reconstruct(self, cpp_module): @property def graph(self): +<<<<<<< HEAD r"""Return a string representation of the internal graph for the ``forward`` method.""" +======= + r"""Return a string representation of the internal graph for the ``forward`` method. + + See :ref:`interpreting-graphs` for details. + """ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return self._c._get_method("forward").graph @property @@ -713,6 +720,10 @@ def inlined_graph(self): Return a string representation of the internal graph for the ``forward`` method. This graph will be preprocessed to inline all function and method calls. +<<<<<<< HEAD +======= + See :ref:`interpreting-graphs` for details. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ return self.forward.inlined_graph # type: ignore[attr-defined] @@ -721,6 +732,10 @@ def code(self): r""" Return a pretty-printed representation (as valid Python syntax) of the internal graph for the ``forward`` method. +<<<<<<< HEAD +======= + See :ref:`inspecting-code` for details. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ return self.forward.code # type: ignore[attr-defined] @@ -735,6 +750,10 @@ def code_with_constants(self): [1] a ConstMap following the CONSTANT.cN format of the output in [0]. The indices in the [0] output are keys to the underlying constant's values. +<<<<<<< HEAD +======= + See :ref:`inspecting-code` for details. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ r = self.forward.code_with_constants # type: ignore[attr-defined] return (r[0], ConstMap(r[1])) @@ -1240,7 +1259,11 @@ def script( subsequently passed by reference between Python and TorchScript with zero copy overhead. ``torch.jit.script`` can be used as a function for modules, functions, dictionaries and lists +<<<<<<< HEAD and as a decorator ``@torch.jit.script`` for torchscript-classes and functions. +======= + and as a decorator ``@torch.jit.script`` for :ref:`torchscript-classes` and functions. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Args: obj (Callable, class, or nn.Module): The ``nn.Module``, function, class type, diff --git a/torch/jit/_shape_functions.py b/torch/jit/_shape_functions.py index f2a6f4a841763..7f90e0dcaf3d0 100644 --- a/torch/jit/_shape_functions.py +++ b/torch/jit/_shape_functions.py @@ -86,7 +86,11 @@ def broadcast_inplace(a: list[int], b: list[int]): dimsB = len(b) if dimsB > dimsA: raise AssertionError( +<<<<<<< HEAD f"The dims of tensor b ({dimsB}) must be less than or equal to the dims of tensor a ({dimsA}) " +======= + f"The dims of tensor b ({dimsB}) must be less than or equal tothe dims of tensor a ({dimsA}) " +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) for dimA in range(dimsA): dimB = dimsB - dimsA + dimA @@ -1172,7 +1176,11 @@ def cross_entropy_loss( adding ops. There are currently cases in the test case where this is being called in the SSA opinfo tests with with unexpected values (eg list of two ints, see the first +<<<<<<< HEAD opinfo test). The behavior of index is significantly dependent on the inputs. +======= +opinfo test). The behavoir of index is significantly dependent on the inputs. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) This could be an error with how we are matching up shape functions, or that this function needs to just implement everything. @@ -1452,7 +1460,11 @@ def add_bounded_compute_mapping( # add_shape_compute_mapping("aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor", index_Tensor) # TODO: migrate over all of symbolic_shape_registry_util.cpp +<<<<<<< HEAD # These are duplicated here so that the functions will be serialized +======= +# These are duplicated here so that the functions will be serialiazed +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) add_shape_compute_mapping( "aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor", broadcast_three, diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py index 5084d7c922837..8b67cbe1f3bb2 100644 --- a/torch/jit/_trace.py +++ b/torch/jit/_trace.py @@ -993,7 +993,15 @@ def forward(self, x): stacklevel=2, ) +<<<<<<< HEAD from torch._utils_internal import log_torchscript_usage +======= + from torch._utils_internal import ( + check_if_torch_exportable, + log_torch_jit_trace_exportability, + log_torchscript_usage, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) traced_func = _trace_impl( func, @@ -1010,6 +1018,106 @@ def forward(self, x): _store_inputs, ) log_torchscript_usage("trace", model_id=_get_model_id(traced_func)) +<<<<<<< HEAD +======= + + if check_if_torch_exportable(): + from torch._export.converter import TS2EPConverter + from torch.export._trace import ( + _convert_ts_to_export_experimental, + _process_jit_trace_inputs_for_export, + ) + + traced_func_for_export = _trace_impl( + func, + example_inputs=example_inputs, + optimize=optimize, + check_trace=False, + check_inputs=check_inputs, + check_tolerance=check_tolerance, + strict=strict, + _force_outplace=_force_outplace, + _module_class=_module_class, + _compilation_unit=_compilation_unit, + example_kwarg_inputs=example_kwarg_inputs, + _store_inputs=_store_inputs, + ) + + export_args, _ = _process_jit_trace_inputs_for_export( + example_inputs, example_kwarg_inputs + ) + + def _log_exportability(func_to_export, export_func, export_args, export_type): + try: + traced_result = func_to_export(*export_args) + except Exception as e: + _ = e + log_torch_jit_trace_exportability( + "trace", str(export_type), str(_ExportOutcome.SUCCESS), "succeeded" + ) + return + + try: + ep_module = export_func(func_to_export, export_args) + except Exception as e: + log_torch_jit_trace_exportability( + "trace", + str(export_type), + str(_ExportOutcome.FAILED_TO_EXPORT), + str(e), + ) + return + + try: + export = ep_module(*export_args) + except Exception as e: + log_torch_jit_trace_exportability( + "trace", str(export_type), str(_ExportOutcome.FAILED_TO_RUN), str(e) + ) + return + + if not analyze_ts_result_with_export_result(export, traced_result): + log_torch_jit_trace_exportability( + "trace", + str(export_type), + str(_ExportOutcome.ACCURACY_ERROR), + "accuracy error", + ) + return + + log_torch_jit_trace_exportability( + "trace", str(export_type), str(_ExportOutcome.SUCCESS), "succeeded" + ) + + def _direct_export_and_lower(func, export_args): + return torch.export.export(func, export_args, strict=False).module() + + def _convert_ts_to_export_source_to_source(func, export_args): + return TS2EPConverter(func, export_args).convert().module() + + # torch.jit.trace is noop when the original module is torch.jit.ScriptModule + if not isinstance(traced_func_for_export, torch.jit.ScriptModule): + _log_exportability( + traced_func_for_export, + _direct_export_and_lower, + export_args, + _ExportType.DIRECT_EXPORT, + ) + + _log_exportability( + traced_func_for_export, + _convert_ts_to_export_experimental, + export_args, + _ExportType.TRACE_AND_EXPORT, + ) + _log_exportability( + traced_func_for_export, + _convert_ts_to_export_source_to_source, + export_args, + _ExportType.SOURCE_TO_SOURCE, + ) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return traced_func diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py index 518eb7df2749c..712e9d9258adc 100644 --- a/torch/jit/frontend.py +++ b/torch/jit/frontend.py @@ -73,6 +73,17 @@ from torch.jit._monkeytype_config import get_qualified_name, monkeytype_trace +<<<<<<< HEAD +======= +_IS_ASTUNPARSE_INSTALLED = False +try: + import astunparse # type: ignore[import] + + _IS_ASTUNPARSE_INSTALLED = True +except ImportError: + pass + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Borrowed from cPython implementation # https://github.com/python/cpython/blob/561612d8456cfab5672c9b445521113b847bd6b3/Lib/textwrap.py#L411# @@ -430,11 +441,15 @@ def build_def(ctx, py_def, type_line, def_name, self_name=None, pdt_arg_types=No is_method = self_name is not None if type_line is not None: type_comment_decl = torch._C.parse_type_comment(type_line) +<<<<<<< HEAD decl = torch._C.merge_type_from_type_comment( decl, # type: ignore[arg-type] type_comment_decl, is_method, # type: ignore[assignment] ) +======= + decl = torch._C.merge_type_from_type_comment(decl, type_comment_decl, is_method) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return Def(Ident(r, def_name), decl, build_stmts(ctx, body)) @@ -583,7 +598,11 @@ def build_args(args): from typing import List, Dict, Tuple @torch.jit.ignore +<<<<<<< HEAD {ast.unparse(ignore_function)} +======= +{astunparse.unparse(ignore_function)} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ g = copy.copy(globals()) exec(ignore_func_str, g) # noqa: P204 @@ -838,6 +857,14 @@ def build_With(ctx, stmt): r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("with")) # Handle ignore context manager if is_torch_jit_ignore_context_manager(stmt): +<<<<<<< HEAD +======= + if not _IS_ASTUNPARSE_INSTALLED: + raise RuntimeError( + "torch.jit._IgnoreContextManager requires installing Python library `astunparse`, \ + please install it in your Python environment" + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assign_ast = build_ignore_context_manager(ctx, stmt) return build_stmt(ctx, assign_ast) return With(r, build_withitems(ctx, stmt.items), build_stmts(ctx, stmt.body)) @@ -1046,12 +1073,20 @@ def build_Compare(ctx, expr): in_expr = BinOp("in", lhs, rhs) cmp_expr = UnaryOp(r, "not", in_expr) else: +<<<<<<< HEAD cmp_expr = BinOp(op_token, lhs, rhs) # type: ignore[assignment] +======= + cmp_expr = BinOp(op_token, lhs, rhs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if result is None: result = cmp_expr else: +<<<<<<< HEAD result = BinOp("and", result, cmp_expr) # type: ignore[assignment] +======= + result = BinOp("and", result, cmp_expr) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return result @staticmethod @@ -1126,7 +1161,11 @@ def build_ExtSlice(ctx, base, extslice): return Subscript(base, [build_SliceExpr(ctx, base, expr.slice)]) elif sub_type is ast.ExtSlice: return Subscript(base, build_ExtSlice(ctx, base, expr.slice)) +<<<<<<< HEAD else: # In Python3.9 array indices are not wrapped in ast.Index +======= + else: # In Python3.9 array indicies are not wrapped in ast.Index +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if sub_type is ast.Tuple: # N-dimensional indexing using Tuple: x[(i, j, k)] is equivalent to x[i, j, k] indices = [] diff --git a/torch/jit/supported_ops.py b/torch/jit/supported_ops.py index 98229edff6ee8..c7220d8115005 100644 --- a/torch/jit/supported_ops.py +++ b/torch/jit/supported_ops.py @@ -243,8 +243,13 @@ def _get_global_builtins(): "getattr": "Attribute name must be a literal string", "hasattr": "Attribute name must be a literal string", "isinstance": "Result is static", +<<<<<<< HEAD "zip": "Arguments must be iterable.", "enumerate": "Arguments must be iterable.", +======= + "zip": "Arguments must be iterable. See :ref:`Iterables ` for details.", + "enumerate": "Arguments must be iterable. See :ref:`Iterables ` for details.", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "range": "Can only be used as an iterator in a for loop", } @@ -295,7 +300,11 @@ def _get_global_builtins(): {schemaless_ops_str} +<<<<<<< HEAD The following functions will use the corresponding magic method on TorchScript classes +======= +The following functions will use the corresponding magic method on :any:`TorchScript classes` +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .. csv-table:: :header: "Function", "Magic Method" diff --git a/torch/library.h b/torch/library.h index f906e04ddecff..14f5678830385 100644 --- a/torch/library.h +++ b/torch/library.h @@ -89,7 +89,11 @@ struct NoInferSchemaTag {}; #define HAS_PT2_COMPLIANT_TAG +<<<<<<< HEAD // For multipy/torchdeploy use case // codespell:ignore multipy +======= +// For multipy/torchdeploy use case +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) enum class _RegisterOrVerify { REGISTER, VERIFY }; template @@ -926,7 +930,11 @@ class TorchLibraryInit final { } void initialize() { +<<<<<<< HEAD lib = std::make_unique(kind, ns, key, file, line); +======= + lib = std::unique_ptr(new Library(kind, ns, key, file, line)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) init_function(*lib); } }; @@ -1022,7 +1030,11 @@ class TorchLibraryInit final { /// Macro for defining a function that will be run at static /// initialization time to define operator overrides for dispatch key /// `k` (must be an unqualified enum member of c10::DispatchKey) in +<<<<<<< HEAD /// namespace `ns` (must be a valid C++ identifier, no quotes). Use this +======= +/// namespace `ns` (must be a valid C++ identifer, no quotes). Use this +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) /// macro when you want to implement a preexisting set of custom /// operators on a new dispatch key (e.g., you want to provide CUDA /// implementations of already existing operators). One common usage diff --git a/torch/library.py b/torch/library.py index 372037f09dbe5..d4abb4e72701d 100644 --- a/torch/library.py +++ b/torch/library.py @@ -102,8 +102,16 @@ def __init__(self, ns, kind, dispatch_key=""): ns, " is a reserved namespace. Please try creating a library with another name.", ) +<<<<<<< HEAD frame = traceback.extract_stack(limit=2)[0] +======= + if torch._running_with_deploy(): + _library.utils.warn_deploy() + return + + frame = traceback.extract_stack(limit=3)[0] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) filename, lineno = frame.filename, frame.lineno self.m: Optional[Any] = torch._C._dispatch_library( kind, ns, dispatch_key, filename, lineno @@ -153,6 +161,12 @@ def define(self, schema, alias_analysis="", *, tags=()): >>> my_lib = Library("mylib", "DEF") >>> my_lib.define("sum(Tensor self) -> Tensor") """ +<<<<<<< HEAD +======= + if torch._running_with_deploy(): + _library.utils.warn_deploy() + return +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid # AliasAnalysis type in C++ @@ -185,6 +199,12 @@ def define(self, schema, alias_analysis="", *, tags=()): def _register_fake(self, op_name, fn, _stacklevel=1, *, allow_override=False): r"""Registers the fake impl for an operator defined in the library.""" +<<<<<<< HEAD +======= + if torch._running_with_deploy(): + _library.utils.warn_deploy() + return +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) source = torch._library.utils.get_source(_stacklevel + 1) frame = sys._getframe(_stacklevel) @@ -228,6 +248,12 @@ def _register_torch_dispatch_rule(self, op_name, torch_dispatch_class, fn): If it is a TorchDispatchMode, we expect fn to have the following signature: (mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any """ +<<<<<<< HEAD +======= + if torch._running_with_deploy(): + _library.utils.warn_deploy() + return +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) qualname = f"{self.ns}::{op_name}" entry = torch._library.simple_registry.singleton.find(qualname) @@ -247,6 +273,12 @@ def _impl_with_aoti_compile(self, op_name, dispatch_key=""): >>> my_lib = Library("aten", "IMPL") >>> my_lib._impl_with_aoti_compile("div.Tensor", "CPU") """ +<<<<<<< HEAD +======= + if torch._running_with_deploy(): + _library.utils.warn_deploy() + return +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if dispatch_key == "": dispatch_key = self.dispatch_key @@ -309,6 +341,12 @@ def impl( >>> return self * (1 / other) >>> my_lib.impl("div.Tensor", div_cpu, "CPU") """ +<<<<<<< HEAD +======= + if torch._running_with_deploy(): + _library.utils.warn_deploy() + return +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not callable(fn): raise TypeError( @@ -391,13 +429,23 @@ def fallback(self, fn, dispatch_key="", *, with_keyset=False): >>> # ... >>> my_lib.fallback(fallback_kernel, "Autocast") """ +<<<<<<< HEAD +======= + if torch._running_with_deploy(): + _library.utils.warn_deploy() + return +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if dispatch_key == "": dispatch_key = self.dispatch_key if self.ns != "_": raise RuntimeError( +<<<<<<< HEAD f"""Fallback can only be registered using library fragment on the global namespace "_" but it is {self.ns}""" +======= + f"""Fallback can only be registered using libary fragment on the global namespace "_" but it is {self.ns}""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) assert dispatch_key != "" @@ -906,6 +954,7 @@ def register_autocast( lib = Library(namespace, "FRAGMENT") _keep_alive.append(lib) +<<<<<<< HEAD def _maybe_override_py_impl(op: torch._ops.OpOverload, dispatch_key): def inner(kernel): if op.has_kernel_for_dispatch_key(dispatch_key): @@ -917,6 +966,9 @@ def inner(kernel): @_maybe_override_py_impl(_op, torch._C.DispatchKey.AutocastCPU) @_maybe_override_py_impl(_op, torch._C.DispatchKey.AutocastCUDA) def _autocast_py_impl(*args, **kwargs): +======= + def kernel(_, *args, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert len(kwargs) == 0, "Custom ops do not support kwargs yet." autocast_keyset = torch._C.DispatchKeySet( torch._C.DispatchKey.AutocastCPU @@ -924,10 +976,13 @@ def _autocast_py_impl(*args, **kwargs): with torch._C._ExcludeDispatchKeyGuard(autocast_keyset): return _op(*_cast(args, device_type, cast_inputs)) +<<<<<<< HEAD def kernel(_, *args, **kwargs): assert len(kwargs) == 0, "Custom ops do not support kwargs yet." return _autocast_py_impl(*args, **kwargs) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if device_type == "cuda": return lib.impl(opname, kernel, "AutocastCUDA", with_keyset=True) else: diff --git a/torch/masked/_ops.py b/torch/masked/_ops.py index f0e23fed81f5d..622bbbe56d4d7 100644 --- a/torch/masked/_ops.py +++ b/torch/masked/_ops.py @@ -166,6 +166,7 @@ def _generate_docstring(func): """, ) +<<<<<<< HEAD args_and_kwargs = { # argument name sufficies separated by double underscore will # be removed in the final documentation string. @@ -180,12 +181,29 @@ def _generate_docstring(func): "mean": (("dim",), ("keepdim=False", "dtype=None", "mask=None")), "median": (("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")), "norm": ( +======= + args_and_kwargs = dict( + # argument name sufficies separated by double underscore will + # be removed in the final documentation string. + sum=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), + prod=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), + cumsum=(("dim__as_int",), ("dtype=None", "mask=None")), + cumprod=(("dim__as_int",), ("dtype=None", "mask=None")), + amin=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), + amax=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), + argmin=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")), + argmax=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")), + mean=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), + median=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")), + norm=( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ( "ord", "dim", ), ("keepdim=False", "dtype=None", "mask=None"), ), +<<<<<<< HEAD "var": (("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")), "std": (("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")), "logsumexp": (("dim",), ("keepdim=False", "dtype=None", "mask=None")), @@ -193,12 +211,22 @@ def _generate_docstring(func): "log_softmax": (("dim__as_int",), ("dtype=None", "mask=None")), "softmin": (("dim__as_int",), ("dtype=None", "mask=None")), "normalize": ( +======= + var=(("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")), + std=(("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")), + logsumexp=(("dim",), ("keepdim=False", "dtype=None", "mask=None")), + softmax=(("dim__as_int",), ("dtype=None", "mask=None")), + log_softmax=(("dim__as_int",), ("dtype=None", "mask=None")), + softmin=(("dim__as_int",), ("dtype=None", "mask=None")), + normalize=( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ( "ord__required", "dim__as_int", ), ("eps=1e-12", "dtype=None", "mask=None"), ), +<<<<<<< HEAD } argument_declarations = { @@ -283,6 +311,92 @@ def _generate_docstring(func): "cumsum": "cumulative_sum", "cumprod": "cumulative_prod", } +======= + ) + + argument_declarations = dict( + dim="""\ +dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + Default: None that is equivalent to ``tuple(range(input.ndim))``.""", + dim__as_int="""\ +dim (int): the dimension along which {operation name} is computed.""", + ord="""\ +ord (int, float, optional): the order of vector norm. Default: 2. + See :func:`torch.linalg.vector_norm` for a list of supported norms.""", + ord__required="""\ +ord (int, float): the order of vector norm. Default: 2. + See :func:`torch.linalg.vector_norm` for a list of supported norms.""", + unbiased="""\ +unbiased (bool): when True, use Bessel's correction, otherwise, compute + the uncorrected sample variance.""", + eps="""\ +eps (float, optional): small value to avoid division by zero. Default: {default}.""", + keepdim="""\ +keepdim (bool, optional): whether the output tensor has + :attr:`dim` retained or not. Default: {default}.""", + dtype="""\ +dtype (:class:`torch.dtype`, optional): the desired data type + of returned tensor. If specified, the input tensor is + casted to :attr:`dtype` before the operation is + performed. Default: {default}.""", + mask="""\ +mask (:class:`torch.Tensor`, optional): the boolean tensor + containing the binary mask of validity of input tensor + elements. + Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.""", + ) + + definitions = dict( + softmax="""\ +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. Softmax of i-th element in ``x`` is +defined as ``exp(x[i])/sum(exp(x))``.""", + log_softmax="""\ +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. LogSoftmax of i-th element in ``x`` is +defined as ``log(exp(x[i])/sum(exp(x)))``.""", + softmin="""\ +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. Softmin of i-th element in ``x`` is +defined as ``exp(-x[i])/sum(exp(-x))``.""", + normalize="""\ +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. Normalize of i-th element in ``x`` is +defined as ``x[i]/max(norm(x, p), eps)``.""", + cumsum="""\ +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is +defined as ``sum(x[:i])``.""", + cumprod="""\ +Let ``x`` be a sequence of unmasked elements of one-dimensional slice +of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is +defined as ``prod(x[:i])``.""", + ) + + reduction_names = dict( + sum="sum", + prod="product", + amax="maximum", + amin="minimum", + argmax="argmax", + argmin="argmin", + mean="mean", + median="median", + norm="norm", + var="variance", + std="standard_deviation", + logsumexp="logsumexp", + ) + + normalization_names = dict( + softmax="softmax", + log_softmax="log_softmax", + softmin="softmin", + normalize="normalize", + cumsum="cumulative_sum", + cumprod="cumulative_prod", + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) operation_names = {} operation_names.update(reduction_names) diff --git a/torch/masked/maskedtensor/_ops_refs.py b/torch/masked/maskedtensor/_ops_refs.py index 9a4df21429ad6..ee809e2decc65 100644 --- a/torch/masked/maskedtensor/_ops_refs.py +++ b/torch/masked/maskedtensor/_ops_refs.py @@ -285,9 +285,13 @@ def layout(func, *args, **kwargs): return _get_data(args[0]).layout +<<<<<<< HEAD @register_dispatch_func( [torch.ops.aten.is_contiguous, torch.ops.aten.sym_is_contiguous] ) +======= +@register_dispatch_func([torch.ops.aten.is_contiguous]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def is_contiguous(func, *args, **kwargs): data = _get_data(args[0]) if data.is_sparse: @@ -353,10 +357,14 @@ def _apply_fn_on_data(func, *args, **kwargs): @register_dispatch_func([torch.ops.aten._to_copy]) def _to_copy(func, *args, **kwargs): new_data = func(_get_data(args[0]), *args[1:], **kwargs) +<<<<<<< HEAD cloned_kwargs = kwargs.copy() cloned_kwargs["dtype"] = torch.bool new_mask = func(_maybe_get_mask(args[0]), *args[1:], **cloned_kwargs) return MaskedTensor(new_data, new_mask) +======= + return MaskedTensor(new_data, _maybe_get_mask(args[0])) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @register_dispatch_func([torch.ops.aten._softmax]) diff --git a/torch/mps/profiler.py b/torch/mps/profiler.py index eebeea9a02a49..6a47243911183 100644 --- a/torch/mps/profiler.py +++ b/torch/mps/profiler.py @@ -76,13 +76,21 @@ def is_metal_capture_enabled() -> bool: def is_capturing_metal() -> bool: +<<<<<<< HEAD """Checks if metal capture is in progress""" +======= + """Cheks if metal capture is in progress""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return torch._C._mps_isCapturing() # type: ignore[attr-defined] @contextlib.contextmanager def metal_capture(fname: str): +<<<<<<< HEAD """Context manager that enables capturing of Metal calls into gputrace""" +======= + """Conext manager that enables capturing of Metal calls into gputrace""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: torch._C._mps_startCapture(fname) # type: ignore[attr-defined] yield diff --git a/torch/mtia/__init__.py b/torch/mtia/__init__.py index 4c4ee32024732..2431680c918e4 100644 --- a/torch/mtia/__init__.py +++ b/torch/mtia/__init__.py @@ -204,10 +204,13 @@ def attach_out_of_memory_observer( torch._C._mtia_attachOutOfMemoryObserver(observer) +<<<<<<< HEAD def is_bf16_supported(including_emulation: bool = True): return True +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_device_capability(device: Optional[_device_t] = None) -> tuple[int, int]: r"""Return capability of a given device as a tuple of (major version, minor version). @@ -339,6 +342,7 @@ def __exit__(self, type: Any, value: Any, traceback: Any): torch.mtia.set_stream(self.src_prev_stream) # type: ignore[arg-type] +<<<<<<< HEAD def _set_stream_by_id(stream_id, device_index, device_type): r"""set stream specified by the stream id, device index and device type @@ -350,6 +354,8 @@ def _set_stream_by_id(stream_id, device_index, device_type): torch._C._mtia_setStream(stream_id, device_index, device_type) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def stream(stream: Optional["torch.mtia.Stream"]) -> StreamContext: r"""Wrap around the Context-manager StreamContext that selects a given stream. @@ -407,7 +413,10 @@ def set_rng_state( "default_stream", "memory_stats", "max_memory_allocated", +<<<<<<< HEAD "memory_allocated", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "reset_peak_memory_stats", "get_device_capability", "get_device_properties", @@ -421,5 +430,8 @@ def set_rng_state( "device", "set_rng_state", "get_rng_state", +<<<<<<< HEAD "is_bf16_supported", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] diff --git a/torch/mtia/memory.py b/torch/mtia/memory.py index ff350f42b7020..7965ddea1e148 100644 --- a/torch/mtia/memory.py +++ b/torch/mtia/memory.py @@ -36,6 +36,7 @@ def max_memory_allocated(device: Optional[_device_t] = None) -> int: return memory_stats(device).get("dram", 0).get("peak_bytes", 0) +<<<<<<< HEAD def memory_allocated(device: Optional[_device_t] = None) -> int: r"""Return the current MTIA memory occupied by tensors in bytes for a given device. @@ -49,6 +50,8 @@ def memory_allocated(device: Optional[_device_t] = None) -> int: return memory_stats(device).get("dram", 0).get("allocated_bytes", 0) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def reset_peak_memory_stats(device: Optional[_device_t] = None) -> None: r"""Reset the peak memory stats for a given device. @@ -66,6 +69,9 @@ def reset_peak_memory_stats(device: Optional[_device_t] = None) -> None: __all__ = [ "memory_stats", "max_memory_allocated", +<<<<<<< HEAD "memory_allocated", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "reset_peak_memory_stats", ] diff --git a/torch/multiprocessing/reductions.py b/torch/multiprocessing/reductions.py index cbd6eee571f13..c51743ec8181a 100644 --- a/torch/multiprocessing/reductions.py +++ b/torch/multiprocessing/reductions.py @@ -290,7 +290,11 @@ def reduce_tensor(tensor): # 0xE000 -> --------CUDA allocation----- # # To send tensor1, the following info are required from sender to receiver for +<<<<<<< HEAD # storage reconstruction. +======= + # storage recontruction. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # 1. cudaIpcMemHandle of 0xA000(which can be mapped to a basePtr in receiver process). # basePtr may not be exactly 0xA000 since it's a different process. # 2. offset(0xA100) of storage1 in the CUDA allocation. diff --git a/torch/multiprocessing/spawn.py b/torch/multiprocessing/spawn.py index b11e5714fc2e8..54df6a2a0d047 100644 --- a/torch/multiprocessing/spawn.py +++ b/torch/multiprocessing/spawn.py @@ -12,11 +12,14 @@ from concurrent.futures import as_completed, ThreadPoolExecutor from typing import Optional +<<<<<<< HEAD from torch.numa.binding import ( maybe_temporarily_apply_numa_binding_to_current_thread, NumaOptions, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from . import _prctl_pr_set_pdeathsig # type: ignore[attr-defined] @@ -241,7 +244,10 @@ def start_processes( join=True, daemon=False, start_method="spawn", +<<<<<<< HEAD numa_options: Optional[NumaOptions] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): # To speed up performance in certain cases (see https://github.com/pytorch/pytorch/issues/133010), # this func will start processes in parallel if start_method is 'forkserver'. @@ -273,12 +279,16 @@ def start_process(i): ) tf.close() os.unlink(tf.name) +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) process = mp.Process( target=_wrap, args=(fn, i, args, tf.name), daemon=daemon, ) +<<<<<<< HEAD # HACK [NUMA inheritance]: Subprocesses inherit the parent thread's CPU # affinity. So, we temporarily apply the bindings to the current thread, @@ -297,6 +307,9 @@ def start_process(i): gpu_index=i, numa_options=numa_options ): process.start() +======= + process.start() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return i, process, tf.name if not start_parallel: diff --git a/torch/nativert/OVERVIEW.md b/torch/nativert/OVERVIEW.md index bfe97c9aefc75..24d0efcf25536 100644 --- a/torch/nativert/OVERVIEW.md +++ b/torch/nativert/OVERVIEW.md @@ -244,7 +244,11 @@ For CPU kernels, it is extremely inefficient to go through the dispatcher. For one, the dispatcher doesn't deal with kernel out-variants. > **_NOTE:_** an out-variant of a kernel is one that takes the outputs as +<<<<<<< HEAD > mutable references. this has a few benefits... namely, it allows us to reuse +======= +> mutable references. this has a few benefits... namely, it allows us to re-use +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) > the storage/manage from the previous execution. In addition, the dispatcher acts as a stack machine. You push the inputs to the @@ -281,7 +285,11 @@ RuntimeConfigs { ### Constant Folding Constant folding is the process of finding all of the constant-evaluable +<<<<<<< HEAD subgraphs, evaluating them at startup, and then storing their results as +======= +subgraphs, evaluating them at startup, and then storing thier results as +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) constants as opposed to re-evaluting them every time. To enable constant folding, you can set the following configurations. @@ -311,7 +319,11 @@ torch.ops.quantized.linear_dynamic_fp16.default which should give a ~2x speedup over the fp32 variant with minimal effect on correctness. +<<<<<<< HEAD The linear_prepack_fp16 op will be constant-folded, so it's imperative that +======= +The linear_prepack_fp16 op will be constant-folded, so it's imperitive that +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) these two features are used together. To enable this feature, use the following configurations. @@ -327,7 +339,11 @@ RuntimeConfigs { > :warning: **This is an experimental feature** +<<<<<<< HEAD The main upside of memory planning comes from the efficient reuse of tensor +======= +The main upside of memory planning comes from the efficient re-use of tensor +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) buffers, which is extremely important in memory-bound services. That is, if two tensors don’t have an overlapping lifetime during execution, and the first tensor is larger than the second, then the second tensor can share the same diff --git a/torch/nativert/common/FileUtil.cpp b/torch/nativert/common/FileUtil.cpp index c0887b5277922..4042ca2d018a5 100644 --- a/torch/nativert/common/FileUtil.cpp +++ b/torch/nativert/common/FileUtil.cpp @@ -76,7 +76,11 @@ int filterCloseReturn(int r) { return r; } +<<<<<<< HEAD // The following wrapX() functions are private functions for wrapping file-io +======= +// The following wrapX() funcions are private functions for wrapping file-io +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // against interrupt and partial op completions. // Wrap call to f(args) in loop to retry on EINTR diff --git a/torch/nativert/detail/ITree.cpp b/torch/nativert/detail/ITree.cpp index cd24ca78320fb..16ed3ead5a6ee 100644 --- a/torch/nativert/detail/ITree.cpp +++ b/torch/nativert/detail/ITree.cpp @@ -46,7 +46,11 @@ class PytreeNodeRegistry { const ITreeSpec& spec, std::vector& ivalues) { const auto& tuple = nested.toTupleRef().elements(); +<<<<<<< HEAD TORCH_CHECK(tuple.size() == spec.children().size()); +======= + TORCH_CHECK_EQ(tuple.size(), spec.children().size()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (size_t i = 0; i < tuple.size(); i++) { itreeFlatten(tuple[i], spec.children(i), ivalues); } @@ -60,7 +64,11 @@ class PytreeNodeRegistry { const c10::IValue& nested, const ITreeSpec& spec) { const auto& tuple = nested.toTupleRef().elements(); +<<<<<<< HEAD TORCH_CHECK(tuple.size() == spec.children().size()); +======= + TORCH_CHECK_EQ(tuple.size(), spec.children().size()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (size_t i = 0; i < tuple.size(); i++) { ivalueApply(fn, tuple[i], spec.children(i)); } @@ -119,7 +127,11 @@ class PytreeNodeRegistry { const auto& contextKeys = spec.contextKeys(); // allow the dict size less than the spec, missing key will be // filled with empty tensor +<<<<<<< HEAD TORCH_CHECK(dict.size() <= contextKeys.size()); +======= + TORCH_CHECK_LE(dict.size(), contextKeys.size()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) size_t i = 0; for (const auto& key : contextKeys) { auto it = dict.find(key); @@ -143,7 +155,11 @@ class PytreeNodeRegistry { c10::Dict dict( c10::AnyType::get(), c10::AnyType::get()); TORCH_CHECK(obj.is_array()); +<<<<<<< HEAD TORCH_CHECK(obj.size() == flats.size()); +======= + TORCH_CHECK_EQ(obj.size(), flats.size()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dict.reserve(flats.size()); for (size_t i = 0; i < flats.size(); i++) { dict.insert(dynamicToIValue(obj[i]), std::move(flats[i])); @@ -200,7 +216,11 @@ ITreeSpec makeITreeSpec( TORCH_CHECK(obj.is_object()); TORCH_CHECK(obj.find("type") != obj.end()); if (obj["type"].is_null()) { +<<<<<<< HEAD TORCH_CHECK(obj["children_spec"].empty()); +======= + TORCH_CHECK_EQ(obj["children_spec"].size(), 0); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_CHECK(obj["context"].is_null()); const Value* value = values[start]; @@ -244,11 +264,19 @@ ITreeSpec itreeSpecLoads( const std::vector& values) { const auto obj = nlohmann::json::parse(json); TORCH_CHECK(obj.is_array()); +<<<<<<< HEAD TORCH_CHECK(obj.size() == 2); TORCH_CHECK(obj[0].get() == kDefaultTreeSpecSerializationProtocol); auto result = makeITreeSpec(obj[1], values, 0); TORCH_CHECK(result.numIValues() == values.size()); +======= + TORCH_CHECK_EQ(obj.size(), 2); + TORCH_CHECK_EQ(obj[0].get(), kDefaultTreeSpecSerializationProtocol); + auto result = makeITreeSpec(obj[1], values, 0); + + TORCH_CHECK_EQ(result.numIValues(), values.size()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return result; } @@ -256,7 +284,11 @@ c10::IValue itreeUnflatten( std::vector ivalues, const ITreeSpec& spec) { RECORD_USER_SCOPE("nativert::itreeUnflatten"); +<<<<<<< HEAD TORCH_CHECK(ivalues.size() == spec.numIValues()); +======= + TORCH_CHECK_EQ(ivalues.size(), spec.numIValues()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (spec.isIValue()) { return std::move(ivalues[0]); } @@ -299,20 +331,32 @@ std::vector itreeFlattenFromArgs( const ITreeSpec& spec) { RECORD_USER_SCOPE("nativert::itreeFlattenFromArgs"); TORCH_CHECK(!spec.isIValue()); +<<<<<<< HEAD TORCH_CHECK(spec.children().size() == 2); +======= + TORCH_CHECK_EQ(spec.children().size(), 2); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::vector ivalues; ivalues.reserve(spec.numIValues()); const auto& specArgs = spec.children(0); TORCH_CHECK(!specArgs.isIValue()); +<<<<<<< HEAD TORCH_CHECK(specArgs.children().size() == args.size()); +======= + TORCH_CHECK_EQ(specArgs.children().size(), args.size()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (size_t i = 0; i < args.size(); i++) { itreeFlatten(args[i], specArgs.children(i), ivalues); } const auto& specKwargs = spec.children(1); TORCH_CHECK(!specKwargs.isIValue()); +<<<<<<< HEAD TORCH_CHECK(specKwargs.context().size() == kwargs.size()); +======= + TORCH_CHECK_EQ(specKwargs.context().size(), kwargs.size()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (size_t i = 0; i < specKwargs.context().size(); i++) { itreeFlatten( kwargs.at(specKwargs.context()[i].get_ref()), @@ -329,11 +373,19 @@ void ivalueApplyFromArgs( const ITreeSpec& spec) { RECORD_USER_SCOPE("nativert::ivalueApplyFromArgs"); TORCH_CHECK(!spec.isIValue()); +<<<<<<< HEAD TORCH_CHECK(spec.children().size() == 2); const auto& specArgs = spec.children(0); TORCH_CHECK(!specArgs.isIValue()); TORCH_CHECK(specArgs.children().size() == args.size()); +======= + TORCH_CHECK_EQ(spec.children().size(), 2); + + const auto& specArgs = spec.children(0); + TORCH_CHECK(!specArgs.isIValue()); + TORCH_CHECK_EQ(specArgs.children().size(), args.size()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (size_t i = 0; i < args.size(); i++) { ivalueApply(fn, args[i], specArgs.children(i)); } @@ -342,7 +394,11 @@ void ivalueApplyFromArgs( TORCH_CHECK(!specKwargs.isIValue()); const auto& ctx = specKwargs.context(); +<<<<<<< HEAD TORCH_CHECK(ctx.size() == kwargs.size()); +======= + TORCH_CHECK_EQ(ctx.size(), kwargs.size()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (size_t i = 0; i < ctx.size(); i++) { ivalueApply( diff --git a/torch/nativert/detail/ITree.h b/torch/nativert/detail/ITree.h index 19359920720ac..7bec677c6758c 100644 --- a/torch/nativert/detail/ITree.h +++ b/torch/nativert/detail/ITree.h @@ -16,6 +16,11 @@ namespace torch::nativert::detail { +<<<<<<< HEAD +======= +using torch::nativert::Value; + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class ITreeSpec; using ITreeFlattenFn = diff --git a/torch/nativert/detail/MPMCQueue.h b/torch/nativert/detail/MPMCQueue.h index 8301ce3fdb4c5..3be917ac5e476 100644 --- a/torch/nativert/detail/MPMCQueue.h +++ b/torch/nativert/detail/MPMCQueue.h @@ -55,6 +55,7 @@ class MPMCQueue { return true; } +<<<<<<< HEAD /** * Get the current size of the queue. * @return The number of elements in the queue. @@ -64,6 +65,8 @@ class MPMCQueue { return storage_.size(); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) private: std::mutex mutex_; std::deque storage_; diff --git a/torch/nativert/executor/DelegateExecutor.cpp b/torch/nativert/executor/DelegateExecutor.cpp index 6585ac34ddd6c..b5a447f226128 100644 --- a/torch/nativert/executor/DelegateExecutor.cpp +++ b/torch/nativert/executor/DelegateExecutor.cpp @@ -8,8 +8,13 @@ #include +<<<<<<< HEAD #include #include +======= +#include +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace torch::nativert { @@ -28,7 +33,10 @@ char* _mkdtemp(char* outputDir) { std::string extractToTemporaryFolder( caffe2::serialize::PyTorchStreamReader& packageReader, const std::string& targetPath) { +<<<<<<< HEAD // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) char outputDir[] = "/tmp/delegate_model_XXXXXX"; char* tempdir = _mkdtemp(outputDir); TORCH_CHECK( @@ -54,7 +62,11 @@ std::string extractToTemporaryFolder( << " from archive path: " << path << " size: " << dataSize; File extracted(extractedFilename, O_CREAT | O_WRONLY, 0640); +<<<<<<< HEAD const auto bytesWritten = writeFull( +======= + const auto bytesWritten = torch::nativert::writeFull( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) extracted.fd(), const_cast(dataPointer.get()), dataSize); TORCH_CHECK( bytesWritten != -1, diff --git a/torch/nativert/executor/DelegateExecutor.h b/torch/nativert/executor/DelegateExecutor.h index 7d88f98987764..b4a9dd26038ae 100644 --- a/torch/nativert/executor/DelegateExecutor.h +++ b/torch/nativert/executor/DelegateExecutor.h @@ -26,7 +26,11 @@ class DelegateExecutor { // Runtime calls processWeights() to pass the weights to the delegate backend. // Typically, a backend would perform some form of validation and processing, +<<<<<<< HEAD // such as constant folding. The processed weights stays in the deactivate +======= + // such as constant folding. The processed weights stays in the inactivate +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // state until commitWeights() is called. // // Weights tensors are co-owned by the runtime and the delegate backend. @@ -38,7 +42,11 @@ class DelegateExecutor { // affect the weight tensors in the delegate backend. // When a weight tensor is no longer used by the delegate backend, the backend // must release it by decreasing a refcount. Runtime would +<<<<<<< HEAD // also release the refcount for weight tensor if it's no longer activate. The +======= + // also release the refcount for weight tensor if it's no longer activte. The +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // underlying storage for weight tensors will be freed when the refcount // reaches 0. virtual void processWeights(std::shared_ptr weights) = 0; @@ -46,8 +54,11 @@ class DelegateExecutor { // This call activate the processed weights. virtual void commitWeights() = 0; +<<<<<<< HEAD virtual void initWeights(std::shared_ptr weights) = 0; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) virtual std::vector run(std::vector& inputs) = 0; }; diff --git a/torch/nativert/executor/ExecutionFrame.cpp b/torch/nativert/executor/ExecutionFrame.cpp index 133a6125a26a2..d2d7302c497f8 100644 --- a/torch/nativert/executor/ExecutionFrame.cpp +++ b/torch/nativert/executor/ExecutionFrame.cpp @@ -2,6 +2,10 @@ #include #include +<<<<<<< HEAD +======= +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace torch::nativert { @@ -10,6 +14,7 @@ ExecutionFrame::ExecutionFrame(const Graph& graph) allValues_(graph.numValues()), persistent_(graph.numValues()), moveable_output_mask_(graph.userOutputs().size()) { +<<<<<<< HEAD updatePersistentValues(/* weights = nullptr */); updateMovableOutputs(); } @@ -68,15 +73,29 @@ void ExecutionFrame::setWeights(const Weights& weights) { std::unordered_map foldedConstIds; for (const Node& node : graph.nodes()) { +======= + // load constant SymInts into execution frame + for (const auto& [valueId, constSymintValue] : + graph_.getConstantSymIntValues()) { + setPersistentIValue(valueId, constSymintValue); + } + + for (const Node& node : graph_.nodes()) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (node.target() == "torch.ops.higher_order.run_const_graph") { const auto& const_graph = std::get>(node.attributes().at(0).value); for (size_t i = 0; i < node.outputs().size(); ++i) { +<<<<<<< HEAD foldedConstIds[std::string{const_graph->outputs().at(i)->name()}] = +======= + foldedConstIds_[std::string{const_graph->outputs().at(i)->name()}] = +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) node.outputs()[i]->id(); } } } +<<<<<<< HEAD for (const auto& [name, tensor] : weights->getFoldedConsts()) { persistentValues.emplace_back(foldedConstIds.at(name), tensor); } @@ -97,6 +116,40 @@ void ExecutionFrame::updatePersistentValues(const Weights* weights) { auto&& [value, iv] = *it; setPersistentIValue(value, std::move(iv)); } +======= +} + +ExecutionFrame::ExecutionFrame(const Graph& graph, const Weights& weights) + : ExecutionFrame(graph) { + setWeights(weights); +} + +void ExecutionFrame::setWeights(const Weights& weights) { + weightVersion_ = weights.version(); + + const auto& inputsToWeights = graph_.signature().inputsToWeights(); + for (const auto& [inputName, weightName] : inputsToWeights) { + const Value* value = graph_.getValue(inputName); + setPersistentIValue(value->id(), weights.at(weightName)); + } + + const auto& inputsToCustomObjs = graph_.signature().inputsToCustomObjs(); + for (const auto& [inputName, customObjName] : inputsToCustomObjs) { + const Value* value = graph_.getValue(inputName); + setPersistentIValue(value->id(), weights.getCustomObj(customObjName)); + } + + for (const auto& [value, tensor] : weights.getFoldedConsts()) { + setPersistentIValue(foldedConstIds_.at(value), tensor); + } + + for (const auto& [n, iv] : weights.getConstFoldedValues()) { + const Value* v = graph_.getValue(n); + setPersistentIValue(v->id(), iv); + } + + updateMovableOutputs(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } void ExecutionFrame::updateMovableOutputs() { diff --git a/torch/nativert/executor/ExecutionFrame.h b/torch/nativert/executor/ExecutionFrame.h index 1f256bb6b534b..7faa6bb25887f 100644 --- a/torch/nativert/executor/ExecutionFrame.h +++ b/torch/nativert/executor/ExecutionFrame.h @@ -2,9 +2,14 @@ #include +<<<<<<< HEAD #include #include #include +======= +#include +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include @@ -22,11 +27,15 @@ class ExecutionFrame { // torch.cond explicit ExecutionFrame(const Graph& graph); +<<<<<<< HEAD explicit ExecutionFrame( const Graph& graph, const Weights& weights, const torch::nativert::ExecutorConfig& executorConfig = {}, LayoutPlanner* layoutPlanner = nullptr); +======= + explicit ExecutionFrame(const Graph& graph, const Weights& weights); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Constructor for testing purpose explicit ExecutionFrame( @@ -35,15 +44,19 @@ class ExecutionFrame { const std::vector& graphInputIds, const std::vector& graphOutputIds); +<<<<<<< HEAD ExecutionFrame(const ExecutionFrame&) = delete; ExecutionFrame& operator=(const ExecutionFrame&) = delete; ExecutionFrame(ExecutionFrame&&) = delete; ExecutionFrame& operator=(ExecutionFrame&&) = delete; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ~ExecutionFrame() { destroyBorrowedIValues(); } +<<<<<<< HEAD template auto withManagedMemory(CB&& cb) { if (!layoutManager_) { @@ -55,6 +68,8 @@ class ExecutionFrame { const_cast(layoutManager_.get())); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::vector tryMoveUserOutputs(); c10::IValue moveIValue(ValueId id) { @@ -96,19 +111,28 @@ class ExecutionFrame { return getIValue(id).toDouble(); } +<<<<<<< HEAD C10_ALWAYS_INLINE bool isManagedValue(const ValueId id) const { return layoutPlanner_ != nullptr && layoutPlanner_->is_managed(id); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void setPersistentIValue(ValueId id, c10::IValue ivalue) { setIValue(id, std::move(ivalue)); persistent_[id] = true; } +<<<<<<< HEAD void releaseValueIfNeeded(ValueId id) { if (!isManagedValue(id) && !persistent_[id]) { allValues_[id] = c10::IValue(); } +======= + void releaseValue(ValueId id) { + CHECK(!persistent_[id]) << "Cannot release persistent value"; + allValues_[id] = c10::IValue(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } void destroyBorrowedIValues() { @@ -118,12 +142,26 @@ class ExecutionFrame { borrowedValueIds_.clear(); } +<<<<<<< HEAD +======= + void setWork(int64_t workId, const c10::intrusive_ptr& work) { + work_[workId] = work; + } + + c10::intrusive_ptr getWork(int64_t workId) const { + CHECK(work_.find(workId) != work_.end()) + << "Couldn't find work with Id: " << workId; + return work_.at(workId); + } + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) WeightVersion weightVersion() const { return weightVersion_; } void setWeights(const Weights& weights); +<<<<<<< HEAD static std::vector> getPersistentValues( const Graph& graph, const Weights* weights = nullptr); @@ -145,11 +183,19 @@ class ExecutionFrame { } void updatePersistentValues(const Weights* weights = nullptr); +======= + private: + bool isOutputMovable(size_t idx) const { + TORCH_CHECK_LT(idx, moveable_output_mask_.size()); + return moveable_output_mask_[idx]; + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void updateMovableOutputs(); const Graph& graph_; WeightVersion weightVersion_ = -1; +<<<<<<< HEAD std::unique_ptr layoutManager_; LayoutPlanner* layoutPlanner_{nullptr}; @@ -161,6 +207,19 @@ class ExecutionFrame { std::vector borrowedValueIds_; +======= + // All the intermediate values for the entire graph, including graph inputs + // and outputs This table is fixed once constructed + std::vector allValues_; + std::vector persistent_; + + std::unordered_map> work_; + + std::vector borrowedValueIds_; + + std::unordered_map foldedConstIds_; + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // moveable_output_mask_[i] corresponds to user_outputs_[i] // // if moveable_output_mask_[i] is true, then user_outputs_[i] diff --git a/torch/nativert/executor/ExecutorConfig.h b/torch/nativert/executor/ExecutorConfig.h index fb57f2b6f2ef6..3e602ab12ca27 100644 --- a/torch/nativert/executor/ExecutorConfig.h +++ b/torch/nativert/executor/ExecutorConfig.h @@ -1,6 +1,9 @@ #pragma once +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include @@ -9,17 +12,27 @@ namespace torch::nativert { struct ExecutorConfig { bool validateInputs = false; bool debugNan = false; +<<<<<<< HEAD bool enableStaticCPUKernels = true; bool runConstFolding = false; bool doExecutionFrameCleanup = true; bool tryFreeUnmanagedValuesAfterUse = true; +======= + bool enableStaticCPUKernels = false; + bool enableStaticMemoryPlanning = false; + bool runConstFolding = false; + bool doExecutionFrameCleanup = true; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // allows up to max number of concurrent threads. int64_t maxNumConcurrentThreads = 8; // allows up to max number of parallel ops. int64_t maxParallelOps = 1; int64_t minNumExecutionFrames = 1; int64_t executionFramePoolCleanupIntervalSec = 600; +<<<<<<< HEAD LayoutPlannerSettings layoutPlannerSettings; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::string modelName = "unknown"; }; diff --git a/torch/nativert/executor/GraphExecutorBase.cpp b/torch/nativert/executor/GraphExecutorBase.cpp index a623d5873ea56..c2c6696ce4425 100644 --- a/torch/nativert/executor/GraphExecutorBase.cpp +++ b/torch/nativert/executor/GraphExecutorBase.cpp @@ -13,14 +13,22 @@ GraphExecutorBase::GraphExecutorBase( : graph_(graph), nodeKernels_(std::move(nodeKernels)), executorConfig_(executorConfig), +<<<<<<< HEAD execPlan_(ExecutionPlanner{graph_}.createPlan()) {} +======= + execPlan_(ExecutionPlanner{graph_}.createPlan()) {}; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) void GraphExecutorBase::fillUserInputs( ExecutionFrame& frame, std::vector inputs) { RECORD_USER_SCOPE("Executor::fillUserInputs"); const auto& inputValues = graph_.userInputs(); +<<<<<<< HEAD TORCH_CHECK(inputValues.size() == inputs.size()); +======= + TORCH_CHECK_EQ(inputValues.size(), inputs.size()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // load user input tensor into execution frame for (size_t i = 0; i < inputValues.size(); i++) { @@ -32,7 +40,11 @@ void GraphExecutorBase::fillUserInputs( ProfileMetrics GraphExecutorBase::benchmarkIndividualNodes( ExecutionFrame& executionFrame, +<<<<<<< HEAD const std::vector>& inputsList, +======= + std::vector> inputsList, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const uint32_t warmupRuns, const uint32_t mainRuns) { // TODO: add support for memory profiling @@ -40,6 +52,7 @@ ProfileMetrics GraphExecutorBase::benchmarkIndividualNodes( ProfileMetrics results; const auto numNodes = static_cast(nodeKernels_.size()); +<<<<<<< HEAD results.percentPerNode.resize(numNodes, 0.0f); results.nodeTypes.reserve(numNodes); @@ -47,6 +60,8 @@ ProfileMetrics GraphExecutorBase::benchmarkIndividualNodes( results.nodeTypes.emplace_back(nodeKernel->node()->target()); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) results.timePerNode.resize(numNodes, 0); if (inputsList.empty()) { auto i = 0; @@ -81,6 +96,7 @@ ProfileMetrics GraphExecutorBase::benchmarkIndividualNodes( // Execute kernels caffe2::Timer timer; +<<<<<<< HEAD executionFrame.withManagedMemory([&](auto) { for (uint32_t i = 0; i < mainRuns; i++) { for (auto inputs : inputsList) { @@ -99,6 +115,24 @@ ProfileMetrics GraphExecutorBase::benchmarkIndividualNodes( } } }); +======= + for (uint32_t i = 0; i < mainRuns; i++) { + for (auto inputs : inputsList) { + const auto& inputValues = graph_.userInputs(); + + TORCH_CHECK_EQ(inputValues.size(), inputs.size()); + for (size_t j = 0; j < inputValues.size(); j++) { + executionFrame.setIValue(inputValues[j]->id(), std::move(inputs[j])); + } + for (NodeIndex nodeIdx = 0; nodeIdx < nodeKernels_.size(); ++nodeIdx) { + timer.Start(); + nodeKernels_[nodeIdx]->compute(executionFrame); + float millis = timer.MilliSeconds(); + results.timePerNode[nodeIdx] += millis; + } + } + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Summarize results const float numTotalIters = @@ -121,11 +155,15 @@ ProfileMetrics GraphExecutorBase::benchmarkIndividualNodes( results.totalNodesCount = numNodes; for (const auto& r : results.timePerNodeType) { const std::string& target = r.first; +<<<<<<< HEAD results.percentPerNodeType[target] = r.second * 100.0f / results.totalTime; } for (const auto i : c10::irange(numNodes)) { results.percentPerNode[i] = results.timePerNode[i] * 100.0f / results.totalTime; +======= + results.percentPerNodeType[target] = r.second * 100.0 / results.totalTime; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } return results; } diff --git a/torch/nativert/executor/GraphExecutorBase.h b/torch/nativert/executor/GraphExecutorBase.h index dfe020ebae29e..5da7005a5e39c 100644 --- a/torch/nativert/executor/GraphExecutorBase.h +++ b/torch/nativert/executor/GraphExecutorBase.h @@ -14,15 +14,23 @@ struct ProfileMetrics { size_t staticDispatchNodesCount{0}; size_t totalNodesCount{0}; std::vector timePerNode; +<<<<<<< HEAD std::vector nodeTypes; std::unordered_map timePerNodeType; std::unordered_map percentPerNodeType; std::vector percentPerNode; +======= + std::unordered_map timePerNodeType; + std::unordered_map percentPerNodeType; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::unordered_map instancesPerNodeType; std::unordered_set staticDispatchNodes; std::unordered_set primNodes; float totalTime{0}; +<<<<<<< HEAD std::string name; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; /** @@ -54,7 +62,11 @@ class GraphExecutorBase { ProfileMetrics benchmarkIndividualNodes( ExecutionFrame& executionFrame, +<<<<<<< HEAD const std::vector>& inputs, +======= + std::vector> inputs, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const uint32_t warmup_runs, const uint32_t main_runs); diff --git a/torch/nativert/executor/OpKernel.h b/torch/nativert/executor/OpKernel.h index 4e3fd2dd9a8e8..a970f43be8ef3 100644 --- a/torch/nativert/executor/OpKernel.h +++ b/torch/nativert/executor/OpKernel.h @@ -98,8 +98,15 @@ class OpKernel { public: explicit OpKernel( const Node* node, +<<<<<<< HEAD OpKernelKind kind = OpKernelKind::kInterpreterFallbackKernel) : node_(node), kind_(kind) { +======= + std::optional device = std::nullopt, + torch::nativert::OpKernelKind kind = + torch::nativert::OpKernelKind::kInterpreterFallbackKernel) + : node_(node), device_(device), kind_(kind) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) VLOG(1) << "Initializing kernel for node: " << *node_; } @@ -108,17 +115,30 @@ class OpKernel { } void compute(ExecutionFrame& executionFrame) const; +<<<<<<< HEAD OpKernelKind kind() const { +======= + torch::nativert::OpKernelKind kind() const { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return kind_; } bool hasPrimKernel() const { +<<<<<<< HEAD return kind() == OpKernelKind::kPrimKernel; } bool hasStaticDispatch() const { return kind() == OpKernelKind::kStaticDispatchKernel || kind() == OpKernelKind::kNativeStaticDispatchKernel; +======= + return kind() == torch::nativert::OpKernelKind::kPrimKernel; + } + + bool hasStaticDispatch() const { + return kind() == torch::nativert::OpKernelKind::kStaticDispatchKernel || + kind() == torch::nativert::OpKernelKind::kNativeStaticDispatchKernel; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } size_t numInputs() const { @@ -149,9 +169,16 @@ class OpKernel { virtual void computeInternal(ExecutionFrame& executionFrame) const = 0; const Node* node_; +<<<<<<< HEAD const static bool blockingEnabled_; // this should be set in the ctor! const OpKernelKind kind_; +======= + std::optional device_; + const static bool blockingEnabled_; + // this should be set in the ctor! + const torch::nativert::OpKernelKind kind_; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; } // namespace torch::nativert diff --git a/torch/nativert/executor/OpKernelKind.h b/torch/nativert/executor/OpKernelKind.h index 5a8ba38316f67..6822ae5401946 100644 --- a/torch/nativert/executor/OpKernelKind.h +++ b/torch/nativert/executor/OpKernelKind.h @@ -8,10 +8,16 @@ enum class OpKernelKind : uint8_t { kPrimKernel, kStaticDispatchKernel, kInterpreterFallbackKernel, +<<<<<<< HEAD // static dispatch kernels that don't reuse // out TensorImpl kNativeStaticDispatchKernel, kTritonKernel, +======= + // static dispatch kernels that don't re-use + // out TensorImpl + kNativeStaticDispatchKernel, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; } // namespace torch::nativert diff --git a/torch/nativert/executor/Placement.cpp b/torch/nativert/executor/Placement.cpp index 0432ecdc2a7c3..5c58b3abb246c 100644 --- a/torch/nativert/executor/Placement.cpp +++ b/torch/nativert/executor/Placement.cpp @@ -32,6 +32,7 @@ std::ostream& operator<<(std::ostream& os, const Placement& placement) { return os; } +<<<<<<< HEAD namespace { void assertCudaDeviceHasIndex(const c10::Device& device) { if (device.is_cuda()) { @@ -41,6 +42,8 @@ void assertCudaDeviceHasIndex(const c10::Device& device) { } } // namespace +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Placement::Placement(std::optional defaultDevice) : Placement({}, defaultDevice) {} @@ -48,6 +51,7 @@ Placement::Placement( const std::unordered_map& deviceMap, std::optional defaultDevice) { for (const auto& [srcDevice, dstDevice] : deviceMap) { +<<<<<<< HEAD assertCudaDeviceHasIndex(srcDevice); assertCudaDeviceHasIndex(dstDevice); @@ -57,11 +61,22 @@ Placement::Placement( if (defaultDevice.has_value()) { assertCudaDeviceHasIndex(defaultDevice.value()); defaultDevice_ = defaultDevice.value(); +======= + deviceMap_.try_emplace( + normalizeDevice(srcDevice), normalizeDevice(dstDevice)); + } + if (defaultDevice.has_value()) { + defaultDevice_ = normalizeDevice(defaultDevice.value()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } } c10::Device Placement::getMappedDevice(const c10::Device& srcDevice) const { +<<<<<<< HEAD auto it = deviceMap_.find(srcDevice); +======= + auto it = deviceMap_.find(normalizeDevice(srcDevice)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (it != deviceMap_.end()) { return it->second; } diff --git a/torch/nativert/executor/Placement.h b/torch/nativert/executor/Placement.h index 6ea86348973ee..1b7237c9f9076 100644 --- a/torch/nativert/executor/Placement.h +++ b/torch/nativert/executor/Placement.h @@ -9,6 +9,24 @@ namespace torch::nativert { /** +<<<<<<< HEAD +======= + * This function returns a normalized version of the input device: + * - For CPU devices, the returned device will have no index (i.e., the default + * CPU device). + * - For CUDA devices, if no index is specified, index 0 is assumed. + * - For other device types, the function will raise an error. + * + * @param device The input c10::Device to normalize. + * @return A normalized c10::Device with standardized indexing. + * + * @throws c10::Error If the device type is not CPU or CUDA. + */ + +c10::Device normalizeDevice(const c10::Device& device); + +/** +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) * Returns true if the two devices are the same and has the same device index * (if cuda). */ diff --git a/torch/nativert/executor/PlacementUtils.cpp b/torch/nativert/executor/PlacementUtils.cpp index e73224b4f4f52..173a364bc330c 100644 --- a/torch/nativert/executor/PlacementUtils.cpp +++ b/torch/nativert/executor/PlacementUtils.cpp @@ -4,6 +4,23 @@ namespace torch::nativert { +<<<<<<< HEAD +======= +c10::Device normalizeDevice(const c10::Device& device) { + // cpu device doesn't have index + // cuda device index must have a index + if (device.is_cpu()) { + return c10::Device(c10::DeviceType::CPU); + } else if (device.is_cuda()) { + return c10::Device( + c10::DeviceType::CUDA, + device.has_index() ? device.index() : static_cast(0)); + } else { + TORCH_CHECK(false, "Unsupported device type", device); + } +} + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bool isSameDevice(const c10::Device& a, const c10::Device& b) { if (a.is_cpu()) { return b.is_cpu(); diff --git a/torch/nativert/executor/SerialGraphExecutor.cpp b/torch/nativert/executor/SerialGraphExecutor.cpp index 58a7cd1c4307c..7a91ad640d1be 100644 --- a/torch/nativert/executor/SerialGraphExecutor.cpp +++ b/torch/nativert/executor/SerialGraphExecutor.cpp @@ -14,6 +14,7 @@ std::vector SerialGraphExecutor::execute( std::vector SerialGraphExecutor::executeWithPrefilledFrame( ExecutionFrame& executionFrame) { +<<<<<<< HEAD executionFrame.withManagedMemory([&](const LayoutManager* layout_manager) { // Execute kernels for all nodes except prim.Input and prim.Output for (NodeIndex nodeIdx = 1; nodeIdx < nodeKernels_.size() - 1; ++nodeIdx) { @@ -34,6 +35,21 @@ std::vector SerialGraphExecutor::executeWithPrefilledFrame( } } }); +======= + // Execute kernels for all nodes except prim.Input and prim.Output + for (NodeIndex nodeIdx = 1; nodeIdx < nodeKernels_.size() - 1; ++nodeIdx) { + nodeKernels_[nodeIdx]->compute(executionFrame); + + // don't free intermediate values when static memory planning is enabled + if (!executorConfig_.enableStaticMemoryPlanning) { + // Free the intermediate values that are no used anymore + for (const auto& valueKey : execPlan_->valuesToFree[nodeIdx]) { + executionFrame.releaseValue(valueKey); + } + } + } + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return executionFrame.tryMoveUserOutputs(); } diff --git a/torch/nativert/executor/SessionState.h b/torch/nativert/executor/SessionState.h index 37cdf32b3fd3e..f715f71a2facb 100644 --- a/torch/nativert/executor/SessionState.h +++ b/torch/nativert/executor/SessionState.h @@ -9,6 +9,12 @@ namespace torch::nativert { +<<<<<<< HEAD +======= +using torch::nativert::ExecutionFrame; +using torch::nativert::Node; + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) template > struct copyable_atomic : public __atomic_base { public: diff --git a/torch/nativert/executor/Weights.cpp b/torch/nativert/executor/Weights.cpp index d685cc1a78163..c6ca1c8df090c 100644 --- a/torch/nativert/executor/Weights.cpp +++ b/torch/nativert/executor/Weights.cpp @@ -16,6 +16,10 @@ #include #endif +<<<<<<< HEAD +======= +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include namespace torch::nativert { @@ -26,14 +30,22 @@ Weights::Weights( const Graph* graph, const std::optional>& stateDict, +<<<<<<< HEAD const std::optional>& constants) : graph_(graph), weightsMeta_(graph->weightsMeta()), +======= + Placement placement) + : graph_(graph), + weightsMeta_(graph->weightsMeta()), + placement_(std::move(placement)), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) version_(globalVersion_++) { if (stateDict.has_value()) { loadStateDict(stateDict.value()); } +<<<<<<< HEAD if (constants.has_value()) { for (const auto& [name, value] : constants.value()) { if (value.isTensor()) { @@ -45,6 +57,8 @@ Weights::Weights( } } } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } Weights::Weights( @@ -54,10 +68,18 @@ Weights::Weights( std::string_view stateDictPathPrefix, const std::unordered_map& constantPaths, std::string_view constantPathPrefix, +<<<<<<< HEAD +======= + Placement placement, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::function skipSizeCheck, std::function skipDtypeCheck) : graph_(graph), weightsMeta_(graph->weightsMeta()), +<<<<<<< HEAD +======= + placement_(std::move(placement)), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) version_(globalVersion_++), skipSizeCheck_(std::move(skipSizeCheck)), skipDtypeCheck_(std::move(skipDtypeCheck)) { @@ -106,7 +128,11 @@ Weights::Weights( if (!isUsed) { VLOG(1) << "Tensor " << tensorName << " is not used during inference"; +<<<<<<< HEAD auto targetDevice = tensorMeta->device(); +======= + auto targetDevice = placement_.getMappedDevice(tensorMeta->device()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) allValues_[tensorName] = at::scalar_tensor(0, at::TensorOptions().device(targetDevice)); return; @@ -129,7 +155,11 @@ Weights::Weights( at::empty({0}, tensorOptions) .set_(storage, 0, tensorMeta->sizes(), tensorMeta->strides()); +<<<<<<< HEAD auto targetDevice = tensorMeta->device(); +======= + auto targetDevice = placement_.getMappedDevice(tensorMeta->device()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) VLOG(1) << "Loading weight " << tensorName << " on " << targetDevice; if (!isSameDevice(targetDevice, tensor.device())) { tensor = tensor.to(targetDevice); @@ -317,7 +347,11 @@ void Weights::loadStateDict( TORCH_CHECK( it != weightsMeta_.end(), "Couldn't find ", name, " in weightsMeta"); +<<<<<<< HEAD auto targetDevice = it->second.device(); +======= + auto targetDevice = placement_.getMappedDevice(it->second.device()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto tensor = stateDictIt->second.toTensor().to(targetDevice); TORCH_CHECK(tensor.sizes() == it->second.sizes()); @@ -360,7 +394,11 @@ void Weights::validateValue(const std::string& name, const at::Tensor& newValue) " vs ", newValue.dtype()); +<<<<<<< HEAD auto targetDevice = weightMeta.device(); +======= + auto targetDevice = placement_.getMappedDevice(weightMeta.device()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (targetDevice.is_cpu() && targetDevice.has_index()) { LOG(WARNING) << "Target device is cpu but has index: " << targetDevice; } diff --git a/torch/nativert/executor/Weights.h b/torch/nativert/executor/Weights.h index e3c1469c0d5c5..91d26fb1c8186 100644 --- a/torch/nativert/executor/Weights.h +++ b/torch/nativert/executor/Weights.h @@ -1,8 +1,17 @@ #pragma once +<<<<<<< HEAD #include #include #include +======= +#include + +#include +#include +#include +#include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include @@ -20,12 +29,20 @@ using WeightVersion = int; */ class Weights { public: +<<<<<<< HEAD Weights( const Graph* graph, const std::optional>& stateDict = std::nullopt, const std::optional>& constants = std::nullopt); +======= + explicit Weights( + const Graph* graph, + const std::optional>& + stateDict = std::nullopt, + Placement placement = Placement()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // Arguments // - pytorchStreamReader: the reader for the model archive @@ -36,6 +53,11 @@ class Weights { // - constantPaths: a map from constant name to file path in the archive // - constantPathPrefix: a prefix that will be prepended to paths in // constantPathPrefix +<<<<<<< HEAD +======= + // - placement: the device placement of the weights, default to follow the + // original device in the weight's metadata +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) explicit Weights( const Graph* graph, std::shared_ptr @@ -44,6 +66,10 @@ class Weights { std::string_view stateDictPathPrefix, const std::unordered_map& constantPaths, std::string_view constantPathPrefix, +<<<<<<< HEAD +======= + Placement placement = Placement(), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::function skipSizeCheck = {}, std::function skipDtypeCheck = {}); @@ -104,13 +130,21 @@ class Weights { private: const Graph* graph_; const std::unordered_map& weightsMeta_; +<<<<<<< HEAD +======= + Placement placement_; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // keys are parameter/buffer/constant names, not graph input names! std::unordered_map allValues_; std::unordered_map customObjs_; +<<<<<<< HEAD // contains CustomClassHolder map from a file name to an arbitrary +======= + // contains CustomClassHolder map from a file name to an arbitray +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // key in customObjs_ that hold the loaded content of the file. // This is used in AOTIDelegateExecutor. std::unordered_map customObjsPaths_; diff --git a/torch/nativert/executor/memory/FunctionSchema.cpp b/torch/nativert/executor/memory/FunctionSchema.cpp index 264ed702cbc0d..b3cea31beede0 100644 --- a/torch/nativert/executor/memory/FunctionSchema.cpp +++ b/torch/nativert/executor/memory/FunctionSchema.cpp @@ -11,16 +11,20 @@ bool FunctionSchema::alias(size_t input_idx, size_t output_idx) const { } } +<<<<<<< HEAD VLOG(1) << "checking aliasing spec for " << c10_fn_schema_.name() << " " << (c10_fn_schema_.is_varret() ? "varret" : "non-varret") << " " << (c10_fn_schema_.is_vararg() ? "vararg" : "non-vararg"); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (!aliasing_spec_.empty()) { VLOG(1) << "aliasing spec is not empty but no entry found for (" << input_idx << "-->" << output_idx << ") -- falling back to schema->may_contain_alias()"; } +<<<<<<< HEAD /* varret and vararg will contribute to the input/output idx's but because we don't know how many inputs/outputs there are, @@ -49,6 +53,8 @@ bool FunctionSchema::alias(size_t input_idx, size_t output_idx) const { return true; } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return c10_fn_schema_.may_contain_alias( {c10::SchemaArgType::output, output_idx}, {c10::SchemaArgType::input, input_idx}, diff --git a/torch/nativert/executor/memory/FunctionSchema.h b/torch/nativert/executor/memory/FunctionSchema.h index 713df50805847..adf461f3863be 100644 --- a/torch/nativert/executor/memory/FunctionSchema.h +++ b/torch/nativert/executor/memory/FunctionSchema.h @@ -17,7 +17,12 @@ class FunctionSchema { explicit FunctionSchema( const c10::FunctionSchema& schema, AliasingSpec&& aliasing_spec = {}, +<<<<<<< HEAD OpKernelKind kernel_kind = OpKernelKind::kInterpreterFallbackKernel) +======= + torch::nativert::OpKernelKind kernel_kind = + torch::nativert::OpKernelKind::kInterpreterFallbackKernel) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) : aliasing_spec_(std::move(aliasing_spec)), kernel_kind_(kernel_kind), c10_fn_schema_(schema) {} @@ -32,13 +37,21 @@ class FunctionSchema { bool alias(size_t input_idx, size_t output_idx) const; +<<<<<<< HEAD C10_ALWAYS_INLINE OpKernelKind kernel_kind() const { +======= + C10_ALWAYS_INLINE torch::nativert::OpKernelKind kernel_kind() const { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return kernel_kind_; } private: AliasingSpec aliasing_spec_; +<<<<<<< HEAD OpKernelKind kernel_kind_; +======= + torch::nativert::OpKernelKind kernel_kind_; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) c10::FunctionSchema c10_fn_schema_; }; diff --git a/torch/nativert/executor/memory/LayoutPlannerSettings.h b/torch/nativert/executor/memory/LayoutPlannerSettings.h index 8ade27997bdfc..5755e591d7cf0 100644 --- a/torch/nativert/executor/memory/LayoutPlannerSettings.h +++ b/torch/nativert/executor/memory/LayoutPlannerSettings.h @@ -7,7 +7,10 @@ namespace torch::nativert { enum class LayoutPlannerAlgorithmType { Bump, GreedyBySize, +<<<<<<< HEAD DisjointStorageGroups, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; class LayoutManagerSettings { diff --git a/torch/nativert/graph/Graph.cpp b/torch/nativert/graph/Graph.cpp index ee5fbaca11b91..cb318562ff804 100644 --- a/torch/nativert/graph/Graph.cpp +++ b/torch/nativert/graph/Graph.cpp @@ -8,8 +8,14 @@ #include #include #include +<<<<<<< HEAD #include #include +======= +#include +#include // @manual +#include // @manual +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace torch::nativert { @@ -39,7 +45,11 @@ size_t expectImpl( TORCH_CHECK( expected == actual, fmt::format( +<<<<<<< HEAD "Parser error: expected '{}' at position {}, but found '{}'.", +======= + "Parser error: expected '{}' at postition {}, but found '{}'.", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) expected, curPos, actual)); @@ -54,7 +64,11 @@ size_t expectImpl(std::string_view source, char expected, size_t curPos) { } TORCH_CHECK( expected == source[curPos], +<<<<<<< HEAD "Parser error: expected '{}' at position {}, but found '{}'.", +======= + "Parser error: expected '{}' at postition {}, but found '{}'.", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) expected, curPos, source[curPos]); @@ -281,7 +295,11 @@ void Node::applyDevicePlacement(const Placement& placement) { auto device = std::get(attribute.value); auto targetDevice = placement.getMappedDevice(std::get(attribute.value)); +<<<<<<< HEAD if (!isSameDevice(targetDevice, device)) { +======= + if (!torch::nativert::isSameDevice(targetDevice, device)) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) LOG(INFO) << "Overriding " << device.str() << " to " << targetDevice.str() << " for node " << *this; attribute.value = targetDevice; @@ -568,7 +586,11 @@ void Graph::lint() const { } } for (const auto& node : nodes()) { +<<<<<<< HEAD TORCH_CHECK(node.owningGraph() == this); +======= + TORCH_CHECK_EQ(node.owningGraph(), this); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // Check that every list type is either produced by a prim.ListPack or // immediately consumed by a prim.ListUnpack. We make use of this invariant @@ -661,6 +683,7 @@ void Graph::replaceAllUsesAfterNode( } void Graph::applyDevicePlacement(const Placement& placement) { +<<<<<<< HEAD TORCH_CHECK( !placementApplied_, "placement has been applied to the graph! placement must be applied once and once only."); @@ -685,6 +708,16 @@ void Graph::applyDevicePlacement(const Placement& placement) { Node* Graph::nodeAfter(Node* n) { TORCH_CHECK(n->owningGraph() == this); +======= + // TODO: consolidate device info in weight loading here as well. + for (auto& node : nodes_) { + node.applyDevicePlacement(placement); + } +} + +Node* Graph::nodeAfter(Node* n) { + TORCH_CHECK_EQ(n->owningGraph(), this); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (n == outputNode_) { return nullptr; } @@ -693,7 +726,11 @@ Node* Graph::nodeAfter(Node* n) { } const Node* Graph::nodeAfter(const Node* n) const { +<<<<<<< HEAD TORCH_CHECK(n->owningGraph() == this); +======= + TORCH_CHECK_EQ(n->owningGraph(), this); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (n == outputNode_) { return nullptr; } @@ -702,7 +739,11 @@ const Node* Graph::nodeAfter(const Node* n) const { } Node* Graph::nodeBefore(Node* n) { +<<<<<<< HEAD TORCH_CHECK(n->owningGraph() == this); +======= + TORCH_CHECK_EQ(n->owningGraph(), this); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (n == inputNode_) { return nullptr; } @@ -711,7 +752,11 @@ Node* Graph::nodeBefore(Node* n) { } const Node* Graph::nodeBefore(const Node* n) const { +<<<<<<< HEAD TORCH_CHECK(n->owningGraph() == this); +======= + TORCH_CHECK_EQ(n->owningGraph(), this); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if (n == inputNode_) { return nullptr; } @@ -720,7 +765,12 @@ const Node* Graph::nodeBefore(const Node* n) const { } void Graph::removeNode(Node* n) { +<<<<<<< HEAD TORCH_CHECK(n->owningGraph() == this, "Node does not belong to this graph!"); +======= + TORCH_CHECK_EQ(n->owningGraph(), this) + << "Node does not belong to this graph!"; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (auto* outputVal : n->outputs()) { TORCH_CHECK( @@ -762,7 +812,12 @@ std::vector Graph::insertGraph( const Graph& subgraph, std::vector inputs, std::unordered_map& valueMap) { +<<<<<<< HEAD TORCH_CHECK(subgraph.inputs().size() == inputs.size(), "Input size mismatch"); +======= + TORCH_CHECK_EQ(subgraph.inputs().size(), inputs.size()) + << "Input size mismatch"; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (auto i : c10::irange(subgraph.inputs().size())) { valueMap[subgraph.inputs()[i]] = inputs[i]; } @@ -868,7 +923,11 @@ void Node::addOutput() { } Value* Node::addOutput(const Type& type) { +<<<<<<< HEAD TORCH_CHECK(type == Type::Kind::None); +======= + TORCH_CHECK_EQ(type, Type::Kind::None); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Value* v = owningGraph_->addValue(std::nullopt, type, this); outputs_.push_back(v); return v; @@ -907,9 +966,15 @@ std::vector Value::getListElements() const { ret.push_back(tv.value); } } else { +<<<<<<< HEAD TORCH_CHECK(users().size() == 1); const auto listUnpack = users()[0]; TORCH_CHECK(listUnpack->target() == "prim.ListUnpack"); +======= + TORCH_CHECK_EQ(users().size(), 1); + const auto listUnpack = users()[0]; + TORCH_CHECK_EQ(listUnpack->target(), "prim.ListUnpack"); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (const auto v : listUnpack->outputs()) { ret.push_back(v); } @@ -1084,6 +1149,7 @@ std::ostream& operator<<(std::ostream& out, const Graph& graph) { c10::Device convertDevice(std::string_view symbol) { // Symbol looks like `Device{cuda:1}` const auto typeStart = symbol.find('{') + 1; +<<<<<<< HEAD TORCH_CHECK(typeStart < symbol.size()); const auto typeEnd = symbol.find(':'); @@ -1095,6 +1161,19 @@ c10::Device convertDevice(std::string_view symbol) { const auto indexEnd = symbol.find('}'); TORCH_CHECK(indexEnd != std::string_view::npos); +======= + TORCH_CHECK_LT(typeStart, symbol.size()); + + const auto typeEnd = symbol.find(':'); + TORCH_CHECK_NE(typeEnd, std::string_view::npos); + + const auto type = symbol.substr(typeStart, typeEnd - typeStart); + const auto indexStart = typeEnd + 1; + TORCH_CHECK_LT(indexStart, symbol.size()); + + const auto indexEnd = symbol.find('}'); + TORCH_CHECK_NE(indexEnd, std::string_view::npos); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) const auto index = symbol.substr(indexStart, indexEnd - indexStart); @@ -1113,7 +1192,11 @@ c10::Device convertDevice(std::string_view symbol) { Constant convertAtomicConstant(std::string_view symbol) { if (c10::starts_with(symbol, "\"")) { // chop off the outer quotes and return the string +<<<<<<< HEAD TORCH_CHECK(symbol.size() >= 2); +======= + TORCH_CHECK_GE(symbol.size(), 2); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) symbol.remove_prefix(1); symbol.remove_suffix(1); return std::string(symbol); @@ -1192,8 +1275,13 @@ Constant convertListConstant(std::string_view source) { TORCH_CHECK(false, "constant lists only support int, float, bool"); } } else { +<<<<<<< HEAD TORCH_CHECK( type.index() == val.index(), "lists must have all the same type"); +======= + TORCH_CHECK_EQ(type.index(), val.index()) + << "lists must have all the same type"; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } values.push_back(std::move(val)); if (source.at(curPos) == ']') { @@ -1296,7 +1384,11 @@ std::unique_ptr Parser::parse() { } // For graph textual format, it should be safe to assume all // inputs/outputs are from users. +<<<<<<< HEAD graph_->setSignature(GraphSignature{signature_}); +======= + graph_->setSignature(torch::nativert::GraphSignature{signature_}); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) graph_->finalize(); graph_->lint(); // TODO: Might have some source left over, should check it if so. @@ -1320,7 +1412,11 @@ bool Parser::nextIf(char expected) { } void Parser::parseGraphInputs() { +<<<<<<< HEAD TORCH_CHECK(curPos_ == 0); +======= + TORCH_CHECK_EQ(curPos_, 0); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) expect("graph"); const auto inputs = parseList( '(', ')', [&]() { return parseAtomicSymbol(); }); @@ -1383,7 +1479,11 @@ std::string_view Parser::parseUntil( return source_.substr(start, curPos_ - start); } +<<<<<<< HEAD // Parse a string, including the outer quotes +======= +// Parse a strng, including the outer quotes +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::string_view Parser::parseString() { size_t start = curPos_; expect('"'); diff --git a/torch/nativert/graph/Graph.h b/torch/nativert/graph/Graph.h index a86e973621994..80b2933a84a40 100644 --- a/torch/nativert/graph/Graph.h +++ b/torch/nativert/graph/Graph.h @@ -584,8 +584,11 @@ class Graph { void setWeightsMeta( const std::unordered_map& tensorsMeta) { +<<<<<<< HEAD TORCH_CHECK(!placementApplied_); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (auto [name, tensorMeta] : tensorsMeta) { weightsMeta_.emplace(name, TensorMeta{tensorMeta}); } @@ -607,8 +610,11 @@ class Graph { void setTensorValuesMeta( const std::unordered_map& tensorsMeta) { +<<<<<<< HEAD TORCH_CHECK(!placementApplied_); +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (auto [name, tensorMeta] : tensorsMeta) { tensorValuesMeta_.emplace(name, TensorMeta{tensorMeta}); } @@ -634,8 +640,11 @@ class Graph { friend std::ostream& operator<<(std::ostream& out, const Graph& g); GraphSignature signature_; +<<<<<<< HEAD bool placementApplied_ = false; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // keys are parameters, buffers, tensor_constants' names std::unordered_map weightsMeta_; diff --git a/torch/nativert/graph/GraphPasses.cpp b/torch/nativert/graph/GraphPasses.cpp index 6cb378af80dbd..b569376322047 100644 --- a/torch/nativert/graph/GraphPasses.cpp +++ b/torch/nativert/graph/GraphPasses.cpp @@ -101,10 +101,14 @@ std::string selectScalarOverloadName(const Node& node) { "floor_divide_out", "_conj"}; std::vector atoms = c10::split(node.target(), '.'); +<<<<<<< HEAD if (atoms.size() < 3) { return ""; } +======= + TORCH_CHECK_GE(atoms.size(), 3); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) std::string ns = std::string{atoms[atoms.size() - 3]}; std::string opName = std::string{atoms[atoms.size() - 2]}; @@ -113,7 +117,11 @@ std::string selectScalarOverloadName(const Node& node) { overloadName != "Tensor_mode") { return overloadName; } +<<<<<<< HEAD if (allowed.find(opName) == allowed.end()) { +======= + if (allowed.find(std::string{opName}) == allowed.end()) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return overloadName; } auto op = c10::Dispatcher::singleton().findSchemaOrThrow( diff --git a/torch/nativert/graph/GraphPasses.h b/torch/nativert/graph/GraphPasses.h index 7971aeb6b2242..45d08993c57ce 100644 --- a/torch/nativert/graph/GraphPasses.h +++ b/torch/nativert/graph/GraphPasses.h @@ -4,8 +4,14 @@ namespace torch::nativert { +<<<<<<< HEAD void selectScalarOverload(Graph* graph); std::string selectScalarOverloadName(const Node& node); +======= +void selectScalarOverload(torch::nativert::Graph* graph); + +std::string selectScalarOverloadName(const torch::nativert::Node& node); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // namespace torch::nativert diff --git a/torch/nativert/graph/GraphSignature.h b/torch/nativert/graph/GraphSignature.h index a9e2a95bbaa6b..1ac4f25273b9e 100644 --- a/torch/nativert/graph/GraphSignature.h +++ b/torch/nativert/graph/GraphSignature.h @@ -14,7 +14,11 @@ namespace torch::nativert { * * The GraphSignature class models the input and output specs of an exported * graph produced by torch.export, which is a fx.Graph with stronger invariants +<<<<<<< HEAD * guarantees. It holds the graph information deserialized from the pt2 archive +======= + * gurantees. It holds the graph information deserialized from the pt2 archive +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) * package. Runtime relies on the GraphSignature for weight name lookup and * weight loading. The serialization schema is defined in * torch/_export/serde/schema.py See more at: diff --git a/torch/nativert/graph/Serialization.cpp b/torch/nativert/graph/Serialization.cpp index 4c45edd1f5751..40318f0b67b32 100644 --- a/torch/nativert/graph/Serialization.cpp +++ b/torch/nativert/graph/Serialization.cpp @@ -184,7 +184,11 @@ std::unique_ptr jsonToSubgraph( graphInputs = std::move(reorderedGraphInputs); auto reorderedSignature = *signature; reorderedSignature.set_input_specs(reorderedInputSpecs); +<<<<<<< HEAD graph->setSignature(GraphSignature{reorderedSignature}); +======= + graph->setSignature(torch::nativert::GraphSignature{reorderedSignature}); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } for (const auto& input : graphInputs) { @@ -408,7 +412,11 @@ std::unique_ptr jsonToSubgraph( } sig.set_output_specs(std::move(outputSpecs)); +<<<<<<< HEAD graph->setSignature(GraphSignature{sig}); +======= + graph->setSignature(torch::nativert::GraphSignature{sig}); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } // weightsTensorMeta are indexed by weight's name, not graph input's name @@ -422,11 +430,17 @@ std::unique_ptr jsonToSubgraph( } auto it = jsonTensorValue.find(inputName); +<<<<<<< HEAD TORCH_CHECK( it != jsonTensorValue.end(), "Missing tensor metadata for ", inputName, "in thriftGraph.tensorValue"); +======= + CHECK(it != jsonTensorValue.end()) + << "Missing tensor metadata for " << inputName + << "in thriftGraph.tensorValue"; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) weightsTensorMeta[weightName] = it->second; } graph->setWeightsMeta(weightsTensorMeta); @@ -464,7 +478,11 @@ Constant constantToValue( bool loadNodeMetadata) { switch (jsonArg.tag()) { case torch::_export::Argument::Tag::AS_NONE: +<<<<<<< HEAD return None(); +======= + return torch::nativert::None(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) case torch::_export::Argument::Tag::AS_INT: return jsonArg.get_as_int(); case torch::_export::Argument::Tag::AS_INTS: { @@ -493,6 +511,7 @@ Constant constantToValue( return ret; } case torch::_export::Argument::Tag::AS_SCALAR_TYPE: +<<<<<<< HEAD return convertJsonScalarType(jsonArg.get_as_scalar_type()); case torch::_export::Argument::Tag::AS_MEMORY_FORMAT: return convertJsonMemoryFormat(jsonArg.get_as_memory_format()); @@ -500,6 +519,17 @@ Constant constantToValue( return convertJsonLayout(jsonArg.get_as_layout()); case torch::_export::Argument::Tag::AS_DEVICE: return convertJsonDevice(jsonArg.get_as_device()); +======= + return torch::nativert::convertJsonScalarType( + jsonArg.get_as_scalar_type()); + case torch::_export::Argument::Tag::AS_MEMORY_FORMAT: + return torch::nativert::convertJsonMemoryFormat( + jsonArg.get_as_memory_format()); + case torch::_export::Argument::Tag::AS_LAYOUT: + return torch::nativert::convertJsonLayout(jsonArg.get_as_layout()); + case torch::_export::Argument::Tag::AS_DEVICE: + return torch::nativert::convertJsonDevice(jsonArg.get_as_device()); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) case torch::_export::Argument::Tag::AS_BOOL: return jsonArg.get_as_bool(); case torch::_export::Argument::Tag::AS_BOOLS: { diff --git a/torch/nativert/graph/TensorMeta.cpp b/torch/nativert/graph/TensorMeta.cpp index 68d47a58fb68a..9b42d5cabb7ce 100644 --- a/torch/nativert/graph/TensorMeta.cpp +++ b/torch/nativert/graph/TensorMeta.cpp @@ -41,10 +41,13 @@ c10::ScalarType convertJsonScalarType( return c10::ScalarType::Float8_e4m3fn; case torch::_export::ScalarType::FLOAT8E5M2: return c10::ScalarType::Float8_e5m2; +<<<<<<< HEAD case torch::_export::ScalarType::FLOAT8E4M3FNUZ: return c10::ScalarType::Float8_e4m3fnuz; case torch::_export::ScalarType::FLOAT8E5M2FNUZ: return c10::ScalarType::Float8_e5m2fnuz; +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) default: TORCH_CHECK(false, "unknown scalar type", static_cast(scalarType)); } @@ -106,6 +109,7 @@ TensorMeta::TensorMeta(const torch::_export::TensorMeta& tensorMeta) layout_(convertJsonLayout(tensorMeta.get_layout())), requiresGrad_(tensorMeta.get_requires_grad()), device_(convertJsonDevice(tensorMeta.get_device())) { +<<<<<<< HEAD const auto& storageOffset = tensorMeta.get_storage_offset(); if (storageOffset.tag() == torch::_export::SymInt::Tag::AS_INT) { storage_offset_ = tensorMeta.get_storage_offset().get_as_int(); @@ -114,6 +118,13 @@ TensorMeta::TensorMeta(const torch::_export::TensorMeta& tensorMeta) // setting the storage offset to 0 for now hasSymbolicShape_ = true; storage_offset_ = 0; +======= + if (tensorMeta.get_storage_offset().tag() == + torch::_export::SymInt::Tag::AS_INT) { + storage_offset_ = tensorMeta.get_storage_offset().get_as_int(); + } else { + CHECK(false) << "SymInt not supported yet"; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } for (const auto& size : tensorMeta.get_sizes()) { @@ -123,7 +134,11 @@ TensorMeta::TensorMeta(const torch::_export::TensorMeta& tensorMeta) numel_ *= val; } else if (size.tag() == torch::_export::SymInt::Tag::AS_EXPR) { // TODO: it's still unclear how SymInt shape should be used in runtime +<<<<<<< HEAD // One potential use cases is for verifying inputs shape matches constrain +======= + // One potential use cases is for verifing inputs shape matches constrain +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // This would require unpacking the serialized constrain, which is NYI // // For the time being, we just set the symbolic dim to -1 diff --git a/torch/nativert/graph/TensorMeta.h b/torch/nativert/graph/TensorMeta.h index 7fe9a88c731af..143e2d94e285c 100644 --- a/torch/nativert/graph/TensorMeta.h +++ b/torch/nativert/graph/TensorMeta.h @@ -10,7 +10,10 @@ #include #include +<<<<<<< HEAD #include +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace torch::nativert { @@ -26,12 +29,20 @@ class TensorMeta { explicit TensorMeta(const torch::_export::TensorMeta& tensorMeta); c10::IntArrayRef sizes() const { +<<<<<<< HEAD TORCH_CHECK(!hasSymbolicShape_, "TensorMeta has symbolic shape"); +======= + CHECK(!hasSymbolicShape_) << "TensorMeta has symbolic shape"; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return sizes_; } c10::IntArrayRef strides() const { +<<<<<<< HEAD TORCH_CHECK(!hasSymbolicShape_, "TensorMeta has symbolic shape"); +======= + CHECK(!hasSymbolicShape_) << "TensorMeta has symbolic shape"; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return strides_; } @@ -56,7 +67,11 @@ class TensorMeta { } int64_t numel() const { +<<<<<<< HEAD TORCH_CHECK(!hasSymbolicShape_, "TensorMeta has symbolic shape"); +======= + CHECK(!hasSymbolicShape_) << "TensorMeta has symbolic shape"; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return numel_; } @@ -69,11 +84,14 @@ class TensorMeta { requiresGrad_); } +<<<<<<< HEAD // override device according to placement void applyDevicePlacement(const Placement& placement) { device_ = placement.getMappedDevice(device_); } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) // NYI // c10::SymIntArrayRef sym_sizes() const {} // c10::SymIntArrayRef sym_strides() const {} diff --git a/torch/nativert/kernels/AutoFunctionalizeKernel.cpp b/torch/nativert/kernels/AutoFunctionalizeKernel.cpp index 76589b52c56ee..7bc57ff2cd180 100644 --- a/torch/nativert/kernels/AutoFunctionalizeKernel.cpp +++ b/torch/nativert/kernels/AutoFunctionalizeKernel.cpp @@ -11,14 +11,23 @@ UnsafeAutoFunctionalizeKernel::UnsafeAutoFunctionalizeKernel(const Node* node) op_(getOperatorForTarget( std::get(node->attributes()[0].value))), schema_(op_.schema()), +<<<<<<< HEAD arguments_(prefillStackWithStaticArgs(node, schema_)), numOutputs_(static_cast(schema_.returns().size())) { +======= + arguments_(prefillStackWithStaticArgs(node, schema_)) { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (const auto& [idx, schemaArg] : c10::enumerate(schema_.arguments())) { if (schemaArg.alias_info() != nullptr && schemaArg.alias_info()->isWrite()) { mutatingInputArgs_.push_back(node->getInput(schemaArg.name()).value); } } +<<<<<<< HEAD +======= + + numOutputs_ = schema_.returns().size(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } void UnsafeAutoFunctionalizeKernel::computeInternal( diff --git a/torch/nativert/kernels/AutoFunctionalizeKernel.h b/torch/nativert/kernels/AutoFunctionalizeKernel.h index f9d6e6e58c6c9..9d85c256fe410 100644 --- a/torch/nativert/kernels/AutoFunctionalizeKernel.h +++ b/torch/nativert/kernels/AutoFunctionalizeKernel.h @@ -4,8 +4,13 @@ #include #include +<<<<<<< HEAD #include #include +======= +#include // @manual +#include // @manual +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) namespace torch::nativert { diff --git a/torch/nativert/kernels/C10Kernel.cpp b/torch/nativert/kernels/C10Kernel.cpp index 98f028ff3da2d..7bce544a10fc2 100644 --- a/torch/nativert/kernels/C10Kernel.cpp +++ b/torch/nativert/kernels/C10Kernel.cpp @@ -13,9 +13,16 @@ namespace torch::nativert { C10Kernel::C10Kernel( const Node* node, +<<<<<<< HEAD OpKernelKind kind, AliasingSpec&& aliasingSpec) : OpKernel(node, kind), +======= + c10::Device device, + OpKernelKind kind, + AliasingSpec&& aliasingSpec) + : OpKernel(node, device, kind), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) op_(getOperatorForTarget(node->target(), node)), schema_(op_.schema(), std::move(aliasingSpec), kind_), arguments_(prefillStackWithStaticArgs(node, op_.schema())) {} @@ -48,10 +55,15 @@ void C10Kernel::computeInternal(ExecutionFrame& executionFrame) const { // these are named I don't think it will ever happen in practice. We need to // enforce it though. const auto& outputValues = node_->outputs(); +<<<<<<< HEAD TORCH_CHECK( outputValues.size() == stack.size(), "Output size mismatch for ", node_->toString()); +======= + TORCH_CHECK_EQ(outputValues.size(), stack.size()) + << "Output size mismatch for " << node_->toString(); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for (auto&& [i, actualOutput] : c10::enumerate(stack)) { executionFrame.setIValue(outputValues[i]->id(), std::move(actualOutput)); } diff --git a/torch/nativert/kernels/C10Kernel.h b/torch/nativert/kernels/C10Kernel.h index d477524103819..114351689b92f 100644 --- a/torch/nativert/kernels/C10Kernel.h +++ b/torch/nativert/kernels/C10Kernel.h @@ -24,9 +24,16 @@ class C10Kernel : public OpKernel { C10Kernel() = delete; // deleted default constructor C10Kernel( const Node* node, +<<<<<<< HEAD OpKernelKind kind = OpKernelKind::kInterpreterFallbackKernel, AliasingSpec&& aliasingSpec = {}); ~C10Kernel() override = default; +======= + c10::Device device, + OpKernelKind kind = OpKernelKind::kInterpreterFallbackKernel, + AliasingSpec&& aliasingSpec = {}); + virtual ~C10Kernel() override = default; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) [[nodiscard]] const c10::IValue& input( uint32_t i, @@ -57,19 +64,31 @@ class C10Kernel : public OpKernel { class SymIntOpKernel : public OpKernel { public: explicit SymIntOpKernel(const Node* node) : OpKernel(node) {} +<<<<<<< HEAD void computeInternal(ExecutionFrame& executionFrame) const final; +======= + void computeInternal(ExecutionFrame& executionFrame) const override final; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; class SymBoolOpKernel : public OpKernel { public: explicit SymBoolOpKernel(const Node* node) : OpKernel(node) {} +<<<<<<< HEAD void computeInternal(ExecutionFrame& executionFrame) const final; +======= + void computeInternal(ExecutionFrame& executionFrame) const override final; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; class SymFloatOpKernel : public OpKernel { public: explicit SymFloatOpKernel(const Node* node) : OpKernel(node) {} +<<<<<<< HEAD void computeInternal(ExecutionFrame& executionFrame) const final; +======= + void computeInternal(ExecutionFrame& executionFrame) const override final; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; // ScalarOpKernel does binary arithmetic operations on scalar values. @@ -78,7 +97,11 @@ class SymFloatOpKernel : public OpKernel { class ScalarBinaryOpKernel : public OpKernel { public: explicit ScalarBinaryOpKernel(const Node* node) : OpKernel(node) {} +<<<<<<< HEAD void computeInternal(ExecutionFrame& executionFrame) const final; +======= + void computeInternal(ExecutionFrame& executionFrame) const override final; +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) }; } // namespace torch::nativert diff --git a/torch/nativert/kernels/HigherOrderKernel.cpp b/torch/nativert/kernels/HigherOrderKernel.cpp index 370339c82f820..2c8ba9d385d12 100644 --- a/torch/nativert/kernels/HigherOrderKernel.cpp +++ b/torch/nativert/kernels/HigherOrderKernel.cpp @@ -6,33 +6,58 @@ namespace torch::nativert { +<<<<<<< HEAD +======= +using torch::nativert::Graph; + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) HigherOrderKernel::HigherOrderKernel( const Node* node, std::vector> graphExecutors) : OpKernel(node), graphExecutors_(std::move(graphExecutors)) { static constexpr std::string_view prefix = "torch.ops.higher_order."; +<<<<<<< HEAD TORCH_CHECK(c10::starts_with(node->target(), prefix)); +======= + CHECK(c10::starts_with(node->target(), prefix)); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto opName = node->target().substr(prefix.size()); if (opName == "cond") { opType_ = OpType::COND; // Checking torch.cond schema is as expected: // torch.cond(Tensor predicate, Graph graph1, Graph graph2, Tensor[] args) // -> Tensor[] +<<<<<<< HEAD TORCH_CHECK(node_->attributes().size() == 2); TORCH_CHECK(node_->inputs().size() == 2); +======= + TORCH_CHECK_EQ(node_->attributes().size(), 2); + TORCH_CHECK_EQ(node_->inputs().size(), 2); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else if (opName == "while_loop") { opType_ = OpType::WHILE_LOOP; // Checking torch.while_loop schema is as expected: // torch.while_loop(Graph cond, Graph body, Tensor[] args, Tensor[] +<<<<<<< HEAD // additional) -> Tensor[] TORCH_CHECK(node_->attributes().size() == 2); TORCH_CHECK(node_->inputs().size() == 2); +======= + // additonal) -> Tensor[] + TORCH_CHECK_EQ(node_->attributes().size(), 2); + TORCH_CHECK_EQ(node_->inputs().size(), 2); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else if (opName == "run_const_graph") { opType_ = OpType::RUN_CONST_GRAPH; // Checking torch.run_const_graph schema is as expected: // torch.run_const_graph(Graph graph, Tensor[] args) -> Tensor[] +<<<<<<< HEAD TORCH_CHECK(!node_->attributes().empty()); TORCH_CHECK(node_->inputs().size() == 1); +======= + TORCH_CHECK_GE(node_->attributes().size(), 1); + TORCH_CHECK_EQ(node_->inputs().size(), 1); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } else { throw std::runtime_error( fmt::format("Unknown higher order op: {}", opName)); diff --git a/torch/nested/__init__.py b/torch/nested/__init__.py index 5aa739efd2edb..8e002b8faaeb4 100644 --- a/torch/nested/__init__.py +++ b/torch/nested/__init__.py @@ -392,8 +392,13 @@ def nested_tensor_from_jagged( offsets (optional :class:`torch.Tensor`): Offsets into the jagged dimension of shape B + 1. lengths (optional :class:`torch.Tensor`): Lengths of the batch elements of shape B. jagged_dim (optional int): Indicates which dimension in values is the packed jagged +<<<<<<< HEAD dimension. Must be >= 1 as the batch dimension (dim=0) cannot be ragged. If None, this is set to dim=1 (i.e. the dimension immediately following the batch dimension). Default: None +======= + dimension. If None, this is set to dim=1 (i.e. the dimension immediately following + the batch dimension). Default: None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) min_seqlen (optional int): If set, uses the specified value as the cached minimum sequence length for the returned nested tensor. This can be a useful alternative to computing this value on-demand, possibly avoiding a GPU -> CPU sync. Default: None @@ -450,8 +455,11 @@ def nested_tensor_from_jagged( if jagged_dim is None: jagged_dim = 1 +<<<<<<< HEAD elif jagged_dim < 1: raise ValueError(f"Expected jagged_dim >=1, but got {jagged_dim}.") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.nested._internal.nested_tensor import ( nested_view_from_values_offsets_lengths, diff --git a/torch/nested/_internal/nested_tensor.py b/torch/nested/_internal/nested_tensor.py index d3c4ba8c91661..3c76435aafbf2 100644 --- a/torch/nested/_internal/nested_tensor.py +++ b/torch/nested/_internal/nested_tensor.py @@ -234,6 +234,7 @@ def _maybe_min_seqlen(self) -> Optional[int]: mt = self._min_seqlen_tensor return None if mt is None else _load_val_from_tensor(mt) +<<<<<<< HEAD def _is_contiguous_or_false(self): if self.lengths() is not None: return False @@ -243,16 +244,24 @@ def _is_contiguous_or_false(self): self._values, memory_format=torch.contiguous_format ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __repr__(self): # type: ignore[override] # We should implement this in torch/_tensor_str.py instead grad_fn_str = ( f", requires_grad={self.requires_grad}" if self.requires_grad else "" ) +<<<<<<< HEAD if self.grad_fn: grad_fn_str = f", grad_fn={self.grad_fn}" return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str}, contiguous={self._is_contiguous_or_false()})" +======= + if self.grad_fn: + grad_fn_str = f", grad_fn={self.grad_fn}" + return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str}, contiguous={self.is_contiguous()})" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO: Remove this in favor of the default tensor subclass serialization logic. # We don't do this today because of https://github.com/pytorch/pytorch/issues/125622. diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index 19b1fe670835f..8d9ae1bb61b94 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -517,6 +517,7 @@ def is_contiguous_general(func, *args, **kwargs): @register_jagged_func( +<<<<<<< HEAD torch.ops.aten.sym_is_contiguous.default, "self: jt_all, memory_format: any?" ) def sym_is_contiguous_general(func, *args, **kwargs): @@ -540,6 +541,8 @@ def sym_is_contiguous_general(func, *args, **kwargs): @register_jagged_func( +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.ops.aten.clone.default, "input: jt_all, memory_format: any?" ) def clone_default(func, *args, **kwargs): @@ -865,6 +868,7 @@ def _softmax_default(func, *args, **kwargs): @register_jagged_func( +<<<<<<< HEAD torch.ops.aten._log_softmax.default, "self: jt_all, dim: any, half_to_float: any" ) def _log_softmax_default(func, *args, **kwargs): @@ -905,6 +909,8 @@ def _log_softmax_default(func, *args, **kwargs): @register_jagged_func( +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.ops.aten._softmax_backward_data.default, "grad_output: jt, output: jt, dim: any, input_dtype: any", ) @@ -2683,7 +2689,11 @@ def flex_njt( kernel_options: Dict[str, Any], score_mod_other_buffers: Tuple = (), mask_mod_other_buffers: Tuple = (), +<<<<<<< HEAD ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +======= +) -> Tuple[torch.Tensor, torch.Tensor]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert query.dim() == 4 and key.dim() == 4 and value.dim() == 4 # TODO: Support this if needed; determine if NJT buffers need be unwrapped as dense. @@ -2696,9 +2706,12 @@ def flex_njt( "currently supported. Please file an issue if this is important to you." ) +<<<<<<< HEAD # Always set them since 0 sized elements are not handled gracefully kernel_options = {**kernel_options, "OUTPUT_MAX": True, "OUTPUT_LOGSUMEXP": True} +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # need to pass dense tensor of shape (B, n_heads, sum(seq_len), D) output = flex_attention_hop( query.values().unsqueeze(0), @@ -2729,6 +2742,7 @@ def flex_njt( max_seqlen=query._maybe_max_seqlen, # type: ignore[attr-defined] ).transpose(1, 2) +<<<<<<< HEAD max_scores_njt = torch.nested.nested_tensor_from_jagged( output[2].transpose(1, 2).squeeze(0), query._offsets, # type: ignore[attr-defined] @@ -2738,6 +2752,9 @@ def flex_njt( ).transpose(1, 2) return (output_njt, logsumexp_njt, max_scores_njt) +======= + return (output_njt, logsumexp_njt) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @flex_attention_backward_hop.py_impl(NestedTensor) # type: ignore[misc] diff --git a/torch/nested/_internal/sdpa.py b/torch/nested/_internal/sdpa.py index 997e1805d08c3..1e400f46d3225 100644 --- a/torch/nested/_internal/sdpa.py +++ b/torch/nested/_internal/sdpa.py @@ -369,7 +369,11 @@ def _is_safe_to_get_storage_as_tensor(tensor: torch.Tensor): # use with the flash-attention and efficient_attention kernels without # needing to call contiguous on the nested tensor input. # It checks that the storage offsets' adjacent_differences are a constant +<<<<<<< HEAD # multiple of the previous tensor in the nested tensor and that the strides +======= + # mutiple of the previous tensor in the nested tensor and that the strides +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # are monitonically decreasing. This check is done after calling transpose on # the nested tensor resulting in a Nt of shape [bsz, {seq_len}, num_heads, dim] diff --git a/torch/nn/attention/__init__.py b/torch/nn/attention/__init__.py index efdd7daa0d2a6..6cf8d68159779 100644 --- a/torch/nn/attention/__init__.py +++ b/torch/nn/attention/__init__.py @@ -39,7 +39,10 @@ - FLASH_ATTENTION: The flash attention backend for scaled dot product attention. - EFFICIENT_ATTENTION: The efficient attention backend for scaled dot product attention. - CUDNN_ATTENTION: The cuDNN backend for scaled dot product attention. +<<<<<<< HEAD - OVERRIDEABLE: The overridable backend for extension. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) See :func:`torch.nn.attention.sdpa_kernel` for more details. @@ -68,7 +71,10 @@ def _raise_kernel_warnings(params: SDPAParams) -> None: "flash": "FLASH_ATTENTION", "mem_efficient": "EFFICIENT_ATTENTION", "math": "MATH", +<<<<<<< HEAD "overrideable": "OVERRIDEABLE", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } @@ -79,7 +85,11 @@ def _backend_from_string(name: str): def _cur_sdpa_kernel_backends(with_priority: bool = False): backends = [] for name, val in _backend_names.items(): +<<<<<<< HEAD if getattr(torch._C, f"_get_{name}_sdp_enabled")(): +======= + if getattr(torch.backends.cuda, f"{name}_sdp_enabled")(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) backends.append(getattr(SDPBackend, val)) if with_priority: curr_priority = torch._C._get_sdp_priority_order() @@ -92,7 +102,11 @@ def _cur_sdpa_kernel_backends(with_priority: bool = False): def _sdpa_kernel(backends: Iterable, set_priority: bool = False): for name, val in _backend_names.items(): enabled = getattr(SDPBackend, val) in backends +<<<<<<< HEAD getattr(torch._C, f"_set_sdp_use_{name}")(enabled) +======= + getattr(torch.backends.cuda, f"enable_{name}_sdp")(enabled) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if set_priority: # backends should be a unique list user_priority = [int(backend) for backend in backends] diff --git a/torch/nn/attention/bias.py b/torch/nn/attention/bias.py index 3d002b7b23656..d3a0f10b85258 100644 --- a/torch/nn/attention/bias.py +++ b/torch/nn/attention/bias.py @@ -269,7 +269,11 @@ def _dispatch( )[0].transpose(1, 2) else: _raise_kernel_warnings(sdpa_params) +<<<<<<< HEAD # We can't use efficient attention the only support for lower right is via materialization +======= + # We cant use efficient attention the only support for lower right is via materialization +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.scaled_dot_product_attention( query, key, diff --git a/torch/nn/attention/experimental/__init__.py b/torch/nn/attention/experimental/__init__.py index 4a6694bbe3990..ede944bdfc6bd 100644 --- a/torch/nn/attention/experimental/__init__.py +++ b/torch/nn/attention/experimental/__init__.py @@ -1,2 +1,6 @@ # Experimental features are not mature yet and are subject to change. +<<<<<<< HEAD # We do not provide any BC/FC guarantees +======= +# We do not provide any BC/FC guarntees +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/nn/attention/experimental/_paged_attention.py b/torch/nn/attention/experimental/_paged_attention.py index 70eadcdadfaa0..d71734d4b821e 100644 --- a/torch/nn/attention/experimental/_paged_attention.py +++ b/torch/nn/attention/experimental/_paged_attention.py @@ -29,7 +29,11 @@ class PagedAttention: """ PagedAttention supports flex attention inference with a large batch size. With PagedAttention, a batch of key/value tensors with varying kv length +<<<<<<< HEAD is split into tensor blocks of fixed length and cached in a compact way. +======= + is splitted into tensor blocks of fixed length and cached in a compact way. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Thus we can avoid redundant memory consumption due to varying kv length and support a larger batch size. """ @@ -198,7 +202,10 @@ def convert_logical_block_mask( self, block_mask: BlockMask, batch_idx: Optional[torch.Tensor] = None, +<<<<<<< HEAD kv_len: Optional[torch.Tensor] = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> BlockMask: """ Converts a logical block mask by mapping its logical kv indices to the corresponding @@ -211,8 +218,11 @@ def convert_logical_block_mask( batch dimension. This provides flexibility to convert a block mask with smaller batch size than the page table; shape :math:`(B)`. +<<<<<<< HEAD kv_len (Optional[Tensor]): actual KV sequence length for upper bound check; shape :math:`(B,)` to handle multiple batches. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ B, H, ROWS, MAX_BLOCKS_IN_COL = block_mask.kv_indices.shape @@ -264,7 +274,11 @@ def convert_logical_block_mask( .to(torch.int32) ) +<<<<<<< HEAD new_mask_mod = self.get_mask_mod(block_mask.mask_mod, kv_len) +======= + new_mask_mod = self.get_mask_mod(block_mask.mask_mod) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) seq_lengths = (block_mask.seq_lengths[0], self.n_pages * self.page_size) return BlockMask.from_kv_blocks( @@ -278,9 +292,13 @@ def convert_logical_block_mask( ) def get_mask_mod( +<<<<<<< HEAD self, mask_mod: Optional[_mask_mod_signature], kv_len: Optional[torch.Tensor] = None, +======= + self, mask_mod: Optional[_mask_mod_signature] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> _mask_mod_signature: """ Converts a mask_mod based on mapping from the physical block index to the logical @@ -288,7 +306,10 @@ def get_mask_mod( Args: mask_mod (_mask_mod_signature): mask_mod based on the logical block index. +<<<<<<< HEAD kv_len (Optional[torch.Tensor]): actual KV sequence length for upper bound check. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ if mask_mod is None: mask_mod = noop_mask @@ -303,6 +324,7 @@ def new_mask_mod( physical_kv_offset = physical_kv_idx % self.page_size logical_block_idx = self.physical_to_logical[b, physical_kv_block] logical_kv_idx = logical_block_idx * self.page_size + physical_kv_offset +<<<<<<< HEAD live_block = logical_block_idx >= 0 within_upper_bound = ( logical_kv_idx < kv_len[b] if kv_len is not None else True @@ -311,13 +333,22 @@ def new_mask_mod( is_valid = live_block & within_upper_bound & within_lower_bound return torch.where(is_valid, mask_mod(b, h, q_idx, logical_kv_idx), False) +======= + return torch.where( + logical_block_idx >= 0, mask_mod(b, h, q_idx, logical_kv_idx), False + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return new_mask_mod def get_score_mod( +<<<<<<< HEAD self, score_mod: Optional[_score_mod_signature], kv_len: Optional[torch.Tensor] = None, +======= + self, score_mod: Optional[_score_mod_signature] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> _score_mod_signature: """ Converts a score_mod based on mapping from the physical block index to the logical @@ -325,8 +356,11 @@ def get_score_mod( Args: score_mod (_score_mod_signature): score_mod based on the logical block index. +<<<<<<< HEAD `kv_len (Optional[torch.Tensor]): actual KV sequence length for upper bound check. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ if score_mod is None: score_mod = _identity @@ -342,6 +376,7 @@ def new_score_mod( physical_kv_offset = physical_kv_idx % self.page_size logical_block_idx = self.physical_to_logical[b, physical_kv_block] logical_kv_idx = logical_block_idx * self.page_size + physical_kv_offset +<<<<<<< HEAD live_block = logical_block_idx >= 0 within_upper_bound = ( logical_kv_idx < kv_len[b] if kv_len is not None else True @@ -351,6 +386,10 @@ def new_score_mod( return torch.where( is_valid, +======= + return torch.where( + logical_block_idx >= 0, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) score_mod(score, b, h, q_idx, logical_kv_idx), float("-inf"), ) diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py index f99acaf50a126..e20f37e4f5ff4 100644 --- a/torch/nn/attention/flex_attention.py +++ b/torch/nn/attention/flex_attention.py @@ -9,6 +9,7 @@ import operator import warnings from enum import Enum +<<<<<<< HEAD from typing import Any, Callable, NamedTuple, Optional, Union import torch @@ -25,6 +26,12 @@ except ImportError: from typing_extensions import NotRequired +======= +from typing import Any, Callable, Optional, Union + +import torch +from torch import Tensor +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._higher_order_ops.flex_attention import flex_attention as flex_attention_hop from torch._higher_order_ops.utils import _set_compilation_env from torch._prims_common import DeviceLikeType @@ -36,6 +43,7 @@ from torch.utils._pytree import tree_map_only +<<<<<<< HEAD # Private debug flag to disable internal compilation wrapping for debugging purposes. # WARNING: This is intended ONLY for debugging score_mod and mask_mod functions. # When enabled, this bypasses the required internal compilation that ensures correctness @@ -72,6 +80,11 @@ def _warn_once( "AuxOutput", "AuxRequest", "FlexKernelOptions", +======= +__all__ = [ + "BlockMask", + "flex_attention", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "create_block_mask", "create_mask", "create_nested_block_mask", @@ -84,6 +97,7 @@ def _warn_once( _mask_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor] +<<<<<<< HEAD class FlexKernelOptions(TypedDict, total=False): """Options for controlling the behavior of FlexAttention kernels. @@ -221,6 +235,8 @@ class AuxOutput(NamedTuple): max_scores: Optional[Tensor] = None +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class _ModificationType(Enum): """Enum for the type of modification function. - SCORE_MOD: score_mod function which accepts a score as the first argument @@ -474,6 +490,11 @@ def __init__( raise RuntimeError("BlockMask must have at least 2 dimensions") assert kv_num_blocks is not None, "kv_num_blocks must be provided" assert kv_indices is not None, "kv_indices must be provided" +<<<<<<< HEAD +======= + assert q_num_blocks is not None, "q_num_blocks must be provided" + assert q_indices is not None, "q_indices must be provided" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert (full_kv_num_blocks is None) == (full_kv_indices is None), ( "full_kv_num_blocks and full_kv_indices must be both provided or omitted" ) @@ -503,7 +524,10 @@ def from_kv_blocks( BLOCK_SIZE: Union[int, tuple[int, int]] = _DEFAULT_SPARSE_BLOCK_SIZE, mask_mod: Optional[_mask_mod_signature] = None, seq_lengths: Optional[tuple[int, int]] = None, +<<<<<<< HEAD compute_q_blocks: bool = True, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): """ Creates a BlockMask instance from key-value block information. @@ -531,6 +555,7 @@ def from_kv_blocks( ) # Generate q_num_blocks and q_indices +<<<<<<< HEAD if compute_q_blocks: q_num_blocks, q_indices = _transpose_ordered(kv_num_blocks, kv_indices) if full_kv_num_blocks is not None: @@ -542,6 +567,15 @@ def from_kv_blocks( full_q_num_blocks, full_q_indices = None, None else: q_num_blocks, q_indices = None, None +======= + q_num_blocks, q_indices = _transpose_ordered(kv_num_blocks, kv_indices) + if full_kv_num_blocks is not None: + assert full_kv_indices is not None + full_q_num_blocks, full_q_indices = _transpose_ordered( + full_kv_num_blocks, full_kv_indices + ) + else: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) full_q_num_blocks, full_q_indices = None, None if isinstance(BLOCK_SIZE, int): @@ -550,7 +584,11 @@ def from_kv_blocks( mask_mod = mask_mod if mask_mod is not None else noop_mask if seq_lengths is None: q_length = kv_indices.shape[-2] * BLOCK_SIZE[0] +<<<<<<< HEAD kv_length = kv_indices.shape[-1] * BLOCK_SIZE[1] +======= + kv_length = q_indices.shape[-2] * BLOCK_SIZE[1] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) seq_lengths = (q_length, kv_length) return cls( @@ -649,6 +687,7 @@ def causal_mask(b, h, q_idx, kv_idx): assert new_block_mask.kv_num_blocks.shape == (2, 1, 1) assert new_block_mask.kv_indices.shape == (2, 1, 1, 4) """ +<<<<<<< HEAD index = (index,) if not isinstance(index, tuple) else index padded = (*index, slice(None), slice(None), slice(None))[:3] sizes = self.kv_num_blocks.shape[:3] @@ -658,6 +697,8 @@ def causal_mask(b, h, q_idx, kv_idx): else i for i, n in zip(padded, sizes) ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) new_kv_num_blocks = self.kv_num_blocks[index] new_kv_indices = self.kv_indices[index] if self.full_kv_num_blocks is not None: @@ -675,7 +716,10 @@ def causal_mask(b, h, q_idx, kv_idx): BLOCK_SIZE=self.BLOCK_SIZE, mask_mod=None, seq_lengths=self.seq_lengths, +<<<<<<< HEAD compute_q_blocks=self.q_indices is not None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def __repr__(self): @@ -837,7 +881,11 @@ def to(self, device: Union[torch.device, str]) -> "BlockMask": Note: This method does not modify the original BlockMask in-place. +<<<<<<< HEAD Instead, it returns a new BlockMask instance where individual tensor attributes +======= + Instead, it returns a new BlockMask instance where invidual tensor attributes +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) may or may not be moved to the specified device, depending on their current device placement. """ @@ -1294,12 +1342,16 @@ def causal_mask(b, h, q_idx, kv_idx): def _apply_kernel_options( +<<<<<<< HEAD query: Tensor, key: Tensor, value: Tensor, return_lse: bool, kernel_options, return_aux: Optional[AuxRequest] = None, +======= + query: Tensor, key: Tensor, value: Tensor, return_lse: bool, kernel_options +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): kernel_options = {} if kernel_options is None else dict(kernel_options) @@ -1309,6 +1361,7 @@ def _apply_kernel_options( # This forces all biases grad scatters to be done in the DQ iteration loop of the backwards kernel_options.setdefault("WRITE_DQ", True) +<<<<<<< HEAD any_inputs_on_cpu_device = ( query.device.type == "cpu" or key.device.type == "cpu" @@ -1345,6 +1398,26 @@ def _apply_kernel_options( raise NotImplementedError("Returning max scores is not supported on CPU.") kernel_options["OUTPUT_MAX"] = False +======= + # If forward kernel needs to return logsumexp is decided by this rule internally. + assert "OUTPUT_LOGSUMEXP" not in kernel_options + kernel_options["OUTPUT_LOGSUMEXP"] = True + if not return_lse: + # We used to check if q,k,v required grads but since captured buffers can require grad + # we always write unless in no_grad + output_logsumexp = torch.is_grad_enabled() + kernel_options["OUTPUT_LOGSUMEXP"] = output_logsumexp + any_inputs_on_cpu_device = ( + query.device.type == "cpu" + or key.device.type == "cpu" + or value.device.type == "cpu" + ) + if any_inputs_on_cpu_device: + # CPU with torch.compile now supports infernece, and will not return lse + # TODO: support CPU for training and return lse + kernel_options["OUTPUT_LOGSUMEXP"] = False + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return kernel_options @@ -1360,8 +1433,16 @@ def _validate_device(query: Tensor, key: Tensor, value: Tensor): """TODO: Remove once non cuda/cpu devices support is added We only need to check query since we have already that q,k,v are on the same device """ +<<<<<<< HEAD supported_devices = {"cuda", "cpu", "xpu", "hpu"} if query.device.type not in supported_devices: +======= + if ( + query.device.type != "cuda" + and query.device.type != "cpu" + and query.device.type != "hpu" + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise ValueError( "FlexAttention is only supported on CUDA, CPU or HPU devices. " f"Found input tensors on {query.device.type} device." @@ -1455,10 +1536,15 @@ def flex_attention( scale: Optional[float] = None, enable_gqa: bool = False, return_lse: bool = False, +<<<<<<< HEAD kernel_options: Optional[FlexKernelOptions] = None, *, return_aux: Optional[AuxRequest] = None, ) -> Union[Tensor, tuple[Tensor, Tensor], tuple[Tensor, AuxOutput]]: +======= + kernel_options: Optional[dict[str, Any]] = None, +) -> Union[Tensor, tuple[Tensor, Tensor]]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r"""This function implements scaled dot product attention with an arbitrary attention score modification function. This function computes the scaled dot product attention between query, key, and value tensors with a user-defined @@ -1492,6 +1578,7 @@ def score_mod( block_mask (Optional[BlockMask]): BlockMask object that controls the blocksparsity pattern of the attention. scale (Optional[float]): Scaling factor applied prior to softmax. If none, the default value is set to :math:`\frac{1}{\sqrt{E}}`. enable_gqa (bool): If set to True, enables Grouped Query Attention (GQA) and broadcasts key/value heads to query heads. +<<<<<<< HEAD return_lse (bool): Whether to return the logsumexp of the attention scores. Default is False. **Deprecated**: Use ``return_aux=AuxRequest(lse=True)`` instead. kernel_options (Optional[FlexKernelOptions]): Options to control the behavior of the underlying Triton kernels. @@ -1499,16 +1586,23 @@ def score_mod( return_aux (Optional[AuxRequest]): Specifies which auxiliary outputs to compute and return. If None, only the attention output is returned. Use ``AuxRequest(lse=True, max_scores=True)`` to request both auxiliary outputs. +======= + return_lse (bool): Whether to return the logsumexp of the attention scores. Default is False. + kernel_options (Optional[Dict[str, Any]]): Options to pass into the Triton kernels. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Returns: output (Tensor): Attention output; shape :math:`(B, Hq, L, Ev)`. +<<<<<<< HEAD When ``return_aux`` is not None: aux (AuxOutput): Auxiliary outputs with requested fields populated. When ``return_aux`` is None (deprecated paths): lse (Tensor): Log-sum-exp of attention scores; shape :math:`(B, Hq, L)`. Only returned if ``return_lse=True``. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Shape legend: - :math:`N: \text{Batch size} ... : \text{Any number of other batch dimensions (optional)}` - :math:`S: \text{Source sequence length}` @@ -1612,6 +1706,7 @@ def score_mod( f"but got {query.device} and {block_mask.kv_num_blocks.device}." # type: ignore[union-attr] ) +<<<<<<< HEAD # Handle deprecation warnings for old parameters if return_lse and return_aux is not None: raise ValueError( @@ -1626,12 +1721,15 @@ def score_mod( category=FutureWarning, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kernel_options = _apply_kernel_options( query, key, value, return_lse, kernel_options, +<<<<<<< HEAD return_aux, ) @@ -1664,13 +1762,21 @@ def _finalize_outputs( return out +======= + ) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if torch.compiler.is_dynamo_compiling(): # mark head_dim and number of heads to be static for x in [query, key, value]: torch._dynamo.mark_static(x, -3) torch._dynamo.mark_static(x, -1) +<<<<<<< HEAD out, lse, max_scores = flex_attention_hop( +======= + out, lse = flex_attention_hop( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) query, key, value, @@ -1679,6 +1785,7 @@ def _finalize_outputs( scale, kernel_options, # type: ignore[union-attr] ) +<<<<<<< HEAD return _finalize_outputs( out, lse, max_scores, return_aux=return_aux, return_lse=return_lse ) @@ -1694,6 +1801,12 @@ def _finalize_outputs( "This will allow you to use print statements or breakpoints. Note: This doesn't work with the backwards pass and may produce incorrect results." ), ) +======= + if return_lse: + return out, lse * math.log(2) + else: + return out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not torch._dynamo.is_dynamo_supported(): raise RuntimeError("flex_attention requires dynamo support") @@ -1712,6 +1825,7 @@ def _flex_attention_hop_wrapper(*args, **kwargs): with _temp_remove_pre_dispatch_torch_function_mode(): with _temp_remove_metadata_torch_function_mode() as metadata_mode: if metadata_mode: +<<<<<<< HEAD backend: Union[str, Callable[..., Any]] = ( make_eager_backend_with_torch_function_mode(metadata_mode) ) @@ -1726,6 +1840,16 @@ def _flex_attention_hop_wrapper(*args, **kwargs): ) out, lse, max_scores = flex_fn( +======= + backend = make_eager_backend_with_torch_function_mode( + metadata_mode + ) + else: + backend = "eager" + out, lse = torch.compile( + _flex_attention_hop_wrapper, backend=backend, fullgraph=True + )( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) query, key, value, @@ -1734,6 +1858,13 @@ def _flex_attention_hop_wrapper(*args, **kwargs): scale, kernel_options, ) +<<<<<<< HEAD return _finalize_outputs( out, lse, max_scores, return_aux=return_aux, return_lse=return_lse ) +======= + if return_lse: + return out, lse * math.log(2) + else: + return out +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 92142fd44df88..e7afab4c9c910 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -338,6 +338,12 @@ Applies a 1D average pooling over an input signal composed of several input planes. +<<<<<<< HEAD +======= +.. note:: + pad should be at most half of effective kernel size. + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) See :class:`~torch.nn.AvgPool1d` for details and output shape. Args: @@ -346,9 +352,14 @@ tuple `(kW,)` stride: the stride of the window. Can be a single number or a tuple `(sW,)`. Default: :attr:`kernel_size` +<<<<<<< HEAD padding: implicit zero paddings on both sides of the input. Can be a single number or a tuple `(padW,)`. Should be at most half of effective kernel size, that is :math:`((kernelSize - 1) * dilation + 1) / 2`. Default: 0 +======= + padding: implicit zero paddings on both sides of the input. Can be a + single number or a tuple `(padW,)`. Default: 0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape. Default: ``False`` count_include_pad: when True, will include the zero-padding in the @@ -374,6 +385,12 @@ :math:`sH \times sW` steps. The number of output features is equal to the number of input planes. +<<<<<<< HEAD +======= +.. note:: + pad should be at most half of effective kernel size. + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) See :class:`~torch.nn.AvgPool2d` for details and output shape. Args: @@ -383,9 +400,13 @@ stride: stride of the pooling operation. Can be a single number, a single-element tuple or a tuple `(sH, sW)`. Default: :attr:`kernel_size` padding: implicit zero paddings on both sides of the input. Can be a +<<<<<<< HEAD single number, a single-element tuple or a tuple `(padH, padW)`. Should be at most half of effective kernel size, that is :math:`((kernelSize - 1) * dilation + 1) / 2`. Default: 0 +======= + single number, a single-element tuple or a tuple `(padH, padW)`. Default: 0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ceil_mode: when True, will use `ceil` instead of `floor` in the formula to compute the output shape. Default: ``False`` count_include_pad: when True, will include the zero-padding in the @@ -404,6 +425,12 @@ size :math:`sT \times sH \times sW` steps. The number of output features is equal to :math:`\lfloor\frac{\text{input planes}}{sT}\rfloor`. +<<<<<<< HEAD +======= +.. note:: + pad should be at most half of effective kernel size. + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) See :class:`~torch.nn.AvgPool3d` for details and output shape. Args: @@ -413,9 +440,13 @@ stride: stride of the pooling operation. Can be a single number or a tuple `(sT, sH, sW)`. Default: :attr:`kernel_size` padding: implicit zero paddings on both sides of the input. Can be a +<<<<<<< HEAD single number or a tuple `(padT, padH, padW)`. Should be at most half of effective kernel size, that is :math:`((kernelSize - 1) * dilation + 1) / 2`. Default: 0 +======= + single number or a tuple `(padT, padH, padW)`, Default: 0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ceil_mode: when True, will use `ceil` instead of `floor` in the formula to compute the output shape count_include_pad: when True, will include the zero-padding in the @@ -4683,7 +4714,11 @@ def interpolate( # noqa: F811 ) # "area" mode always requires an explicit size rather than scale factor. +<<<<<<< HEAD # Reuse the recompute_scale_factor code path. +======= + # Re-use the recompute_scale_factor code path. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if mode == "area" and output_size is None: recompute_scale_factor = True @@ -4758,7 +4793,13 @@ def interpolate( # noqa: F811 # Two levels are necessary to prevent TorchScript from touching # are_deterministic_algorithms_enabled. if not torch.jit.is_scripting(): +<<<<<<< HEAD if not input.is_cpu and torch.are_deterministic_algorithms_enabled(): +======= + if torch.are_deterministic_algorithms_enabled() and ( + input.is_cuda or input.is_xpu + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Use slow decomp whose backward will be in terms of index_put # importlib is required because the import cannot be top level # (cycle) and cannot be nested (TS doesn't support) @@ -4770,6 +4811,7 @@ def interpolate( # noqa: F811 ) if input.dim() == 5 and mode == "trilinear": assert align_corners is not None +<<<<<<< HEAD # Two levels are necessary to prevent TorchScript from touching # are_deterministic_algorithms_enabled. if not torch.jit.is_scripting(): @@ -4780,6 +4822,8 @@ def interpolate( # noqa: F811 return importlib.import_module( "torch._decomp.decompositions" )._upsample_linear_vec(input, output_size, align_corners, scale_factors) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return torch._C._nn.upsample_trilinear3d( input, output_size, align_corners, scale_factors ) @@ -4911,7 +4955,11 @@ def upsample_bilinear(input, size=None, scale_factor=None): # noqa: F811 This is equivalent with ``nn.functional.interpolate(..., mode='bilinear', align_corners=True)``. +<<<<<<< HEAD Expected inputs are spatial (4 dimensional). Use `upsample_trilinear` for +======= + Expected inputs are spatial (4 dimensional). Use `upsample_trilinear` fo +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) volumetric (5 dimensional) inputs. Args: @@ -5823,6 +5871,10 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0. assert attn_mask is None temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) +<<<<<<< HEAD +======= + attn_bias.to(query.dtype) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if attn_mask is not None: if attn_mask.dtype == torch.bool: diff --git a/torch/nn/functional.pyi.in b/torch/nn/functional.pyi.in index d0b64447e900b..c5db59c8b57d7 100644 --- a/torch/nn/functional.pyi.in +++ b/torch/nn/functional.pyi.in @@ -396,6 +396,7 @@ def instance_norm( __all__ += ["instance_norm"] def interpolate( +<<<<<<< HEAD input: Tensor, size: int | Sequence[int] | None = ..., scale_factor: float | Sequence[float] | None = ..., @@ -404,6 +405,16 @@ def interpolate( recompute_scale_factor: bool | None = ..., antialias: bool = ..., ) -> Tensor: ... +======= + input: Any, + size: Any | None = ..., + scale_factor: Any | None = ..., + mode: str = ..., + align_corners: Any | None = ..., + recompute_scale_factor: Any | None = ..., + antialias: bool = ..., +): ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __all__ += ["interpolate"] diff --git a/torch/nn/init.py b/torch/nn/init.py index af31b6fa22824..bcbd1bb8c1d78 100644 --- a/torch/nn/init.py +++ b/torch/nn/init.py @@ -64,7 +64,11 @@ # These no_grad_* functions are necessary as wrappers around the parts of these # functions that use `with torch.no_grad()`. The JIT doesn't support context # managers, so these need to be implemented as builtins. Using these wrappers +<<<<<<< HEAD # lets us keep those builtins small and reusable. +======= +# lets us keep those builtins small and re-usable. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _no_grad_uniform_( tensor: Tensor, a: float, b: float, generator: _Optional[torch.Generator] = None ) -> Tensor: @@ -459,6 +463,17 @@ def xavier_uniform_( Examples: >>> w = torch.empty(3, 5) >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain("relu")) +<<<<<<< HEAD +======= + + Note: + Be aware that ``fan_in`` and ``fan_out`` are calculated assuming + that the weight matrix is used in a transposed manner, + (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``). + This is important for correct initialization. + If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``, + pass in a transposed weight matrix, i.e. ``nn.init.xavier_uniform_(w.T, ...)``. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) @@ -491,6 +506,17 @@ def xavier_normal_( Examples: >>> w = torch.empty(3, 5) >>> nn.init.xavier_normal_(w) +<<<<<<< HEAD +======= + + Note: + Be aware that ``fan_in`` and ``fan_out`` are calculated assuming + that the weight matrix is used in a transposed manner, + (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``). + This is important for correct initialization. + If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``, + pass in a transposed weight matrix, i.e. ``nn.init.xavier_normal_(w.T, ...)``. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index ed7efdb2b4261..997573ca5bbcb 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -89,6 +89,7 @@ def __init__(self, threshold: float, value: float, inplace: bool = False) -> Non # TODO: check in THNN (if inplace == True, then assert value <= threshold) def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ @@ -98,6 +99,11 @@ def extra_repr(self) -> str: """ Return the extra representation of the module. """ +======= + return F.threshold(input, self.threshold, self.value, self.inplace) + + def extra_repr(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inplace_str = ", inplace=True" if self.inplace else "" return f"threshold={self.threshold}, value={self.value}{inplace_str}" @@ -133,11 +139,16 @@ class ReLU(Module): __constants__ = ["inplace"] inplace: bool +<<<<<<< HEAD def __init__(self, inplace: bool = False) -> None: +======= + def __init__(self, inplace: bool = False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__() self.inplace = inplace def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ @@ -147,6 +158,11 @@ def extra_repr(self) -> str: """ Return the extra representation of the module. """ +======= + return F.relu(input, inplace=self.inplace) + + def extra_repr(self) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inplace_str = "inplace=True" if self.inplace else "" return inplace_str @@ -197,13 +213,18 @@ class RReLU(Module): def __init__( self, lower: float = 1.0 / 8, upper: float = 1.0 / 3, inplace: bool = False +<<<<<<< HEAD ) -> None: +======= + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__() self.lower = lower self.upper = upper self.inplace = inplace def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ @@ -213,6 +234,11 @@ def extra_repr(self) -> str: """ Return the extra representation of the module. """ +======= + return F.rrelu(input, self.lower, self.upper, self.training, self.inplace) + + def extra_repr(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inplace_str = ", inplace=True" if self.inplace else "" return f"lower={self.lower}, upper={self.upper}{inplace_str}" @@ -286,6 +312,7 @@ def __init__( assert self.max_val > self.min_val def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ @@ -295,6 +322,11 @@ def extra_repr(self) -> str: """ Return the extra representation of the module. """ +======= + return F.hardtanh(input, self.min_val, self.max_val, self.inplace) + + def extra_repr(self) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inplace_str = ", inplace=True" if self.inplace else "" return f"min_val={self.min_val}, max_val={self.max_val}{inplace_str}" @@ -321,6 +353,7 @@ class ReLU6(Hardtanh): >>> output = m(input) """ +<<<<<<< HEAD def __init__(self, inplace: bool = False) -> None: super().__init__(0.0, 6.0, inplace) @@ -328,6 +361,12 @@ def extra_repr(self) -> str: """ Return the extra representation of the module. """ +======= + def __init__(self, inplace: bool = False): + super().__init__(0.0, 6.0, inplace) + + def extra_repr(self) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inplace_str = "inplace=True" if self.inplace else "" return inplace_str @@ -353,9 +392,12 @@ class Sigmoid(Module): """ def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return torch.sigmoid(input) @@ -396,9 +438,12 @@ def __init__(self, inplace: bool = False) -> None: self.inplace = inplace def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.hardsigmoid(input, self.inplace) @@ -424,9 +469,12 @@ class Tanh(Module): """ def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return torch.tanh(input) @@ -462,11 +510,16 @@ class SiLU(Module): __constants__ = ["inplace"] inplace: bool +<<<<<<< HEAD def __init__(self, inplace: bool = False) -> None: +======= + def __init__(self, inplace: bool = False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__() self.inplace = inplace def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ @@ -476,6 +529,11 @@ def extra_repr(self) -> str: """ Return the extra representation of the module. """ +======= + return F.silu(input, inplace=self.inplace) + + def extra_repr(self) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inplace_str = "inplace=True" if self.inplace else "" return inplace_str @@ -507,11 +565,16 @@ class Mish(Module): __constants__ = ["inplace"] inplace: bool +<<<<<<< HEAD def __init__(self, inplace: bool = False) -> None: +======= + def __init__(self, inplace: bool = False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__() self.inplace = inplace def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ @@ -521,6 +584,11 @@ def extra_repr(self) -> str: """ Return the extra representation of the module. """ +======= + return F.mish(input, inplace=self.inplace) + + def extra_repr(self) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inplace_str = "inplace=True" if self.inplace else "" return inplace_str @@ -564,9 +632,12 @@ def __init__(self, inplace: bool = False) -> None: self.inplace = inplace def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.hardswish(input, self.inplace) @@ -611,6 +682,7 @@ def __init__(self, alpha: float = 1.0, inplace: bool = False) -> None: self.inplace = inplace def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ @@ -620,6 +692,11 @@ def extra_repr(self) -> str: """ Return the extra representation of the module. """ +======= + return F.elu(input, self.alpha, self.inplace) + + def extra_repr(self) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inplace_str = ", inplace=True" if self.inplace else "" return f"alpha={self.alpha}{inplace_str}" @@ -662,6 +739,7 @@ def __init__(self, alpha: float = 1.0, inplace: bool = False) -> None: self.inplace = inplace def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ @@ -671,6 +749,11 @@ def extra_repr(self) -> str: """ Return the extra representation of the module. """ +======= + return F.celu(input, self.alpha, self.inplace) + + def extra_repr(self) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inplace_str = ", inplace=True" if self.inplace else "" return f"alpha={self.alpha}{inplace_str}" @@ -718,6 +801,7 @@ def __init__(self, inplace: bool = False) -> None: self.inplace = inplace def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ @@ -727,6 +811,11 @@ def extra_repr(self) -> str: """ Return the extra representation of the module. """ +======= + return F.selu(input, self.inplace) + + def extra_repr(self) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inplace_str = "inplace=True" if self.inplace else "" return inplace_str @@ -762,6 +851,7 @@ def __init__(self, dim: int = -1) -> None: self.dim = dim def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ @@ -771,6 +861,11 @@ def extra_repr(self) -> str: """ Return the extra representation of the module. """ +======= + return F.glu(input, self.dim) + + def extra_repr(self) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"dim={self.dim}" @@ -810,6 +905,7 @@ def __init__(self, approximate: str = "none") -> None: self.approximate = approximate def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ @@ -819,6 +915,11 @@ def extra_repr(self) -> str: """ Return the extra representation of the module. """ +======= + return F.gelu(input, approximate=self.approximate) + + def extra_repr(self) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"approximate={repr(self.approximate)}" @@ -859,6 +960,7 @@ def __init__(self, lambd: float = 0.5) -> None: self.lambd = lambd def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Run forward pass. """ @@ -868,6 +970,11 @@ def extra_repr(self) -> str: """ Return the extra representation of the module. """ +======= + return F.hardshrink(input, self.lambd) + + def extra_repr(self) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"{self.lambd}" @@ -916,6 +1023,7 @@ def __init__(self, negative_slope: float = 1e-2, inplace: bool = False) -> None: self.inplace = inplace def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Run forward pass. """ @@ -925,6 +1033,11 @@ def extra_repr(self) -> str: """ Return the extra representation of the module. """ +======= + return F.leaky_relu(input, self.negative_slope, self.inplace) + + def extra_repr(self) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) inplace_str = ", inplace=True" if self.inplace else "" return f"negative_slope={self.negative_slope}{inplace_str}" @@ -949,9 +1062,12 @@ class LogSigmoid(Module): """ def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Run forward pass. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.logsigmoid(input) @@ -994,6 +1110,7 @@ def __init__(self, beta: float = 1.0, threshold: float = 20.0) -> None: self.threshold = threshold def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Run forward pass. """ @@ -1003,6 +1120,11 @@ def extra_repr(self) -> str: """ Return the extra representation of the module. """ +======= + return F.softplus(input, self.beta, self.threshold) + + def extra_repr(self) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"beta={self.beta}, threshold={self.threshold}" @@ -1041,6 +1163,7 @@ def __init__(self, lambd: float = 0.5) -> None: self.lambd = lambd def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Run forward pass. """ @@ -1050,6 +1173,11 @@ def extra_repr(self) -> str: """ Return the extra representation of the module. """ +======= + return F.softshrink(input, self.lambd) + + def extra_repr(self) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return str(self.lambd) @@ -1226,7 +1354,11 @@ def __init__( self._reset_parameters() +<<<<<<< HEAD def _reset_parameters(self) -> None: +======= + def _reset_parameters(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self._qkv_same_embed_dim: xavier_uniform_(self.in_proj_weight) else: @@ -1625,6 +1757,7 @@ def __init__( self.weight = Parameter(torch.empty(num_parameters, **factory_kwargs)) self.reset_parameters() +<<<<<<< HEAD def reset_parameters(self) -> None: """ Resets parameters based on their initialization used in ``__init__``. @@ -1641,6 +1774,15 @@ def extra_repr(self) -> str: """ Return the extra representation of the module. """ +======= + def reset_parameters(self): + torch.nn.init.constant_(self.weight, self.init) + + def forward(self, input: Tensor) -> Tensor: + return F.prelu(input, self.weight) + + def extra_repr(self) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"num_parameters={self.num_parameters}" @@ -1664,9 +1806,12 @@ class Softsign(Module): """ def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.softsign(input) @@ -1690,9 +1835,12 @@ class Tanhshrink(Module): """ def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.tanhshrink(input) @@ -1740,6 +1888,7 @@ def __setstate__(self, state): self.dim = None def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ @@ -1749,6 +1898,11 @@ def extra_repr(self) -> str: """ Return the extra representation of the module. """ +======= + return F.softmin(input, self.dim, _stacklevel=5) + + def extra_repr(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"dim={self.dim}" @@ -1805,6 +1959,7 @@ def __setstate__(self, state): self.dim = None def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ @@ -1814,6 +1969,11 @@ def extra_repr(self) -> str: """ Return the extra representation of the module. """ +======= + return F.softmax(input, self.dim, _stacklevel=5) + + def extra_repr(self) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"dim={self.dim}" @@ -1840,9 +2000,12 @@ class Softmax2d(Module): """ def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if input.dim() not in (3, 4): raise ValueError( f"Softmax2d: expected input to be 3D or 4D, got {input.dim()}D instead" @@ -1890,6 +2053,7 @@ def __setstate__(self, state): self.dim = None def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ @@ -1899,4 +2063,9 @@ def extra_repr(self) -> str: """ Return the extra representation of the module. """ +======= + return F.log_softmax(input, self.dim, _stacklevel=5) + + def extra_repr(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"dim={self.dim}" diff --git a/torch/nn/modules/adaptive.py b/torch/nn/modules/adaptive.py index 26103c4f2a7b9..42809ba2ac89c 100644 --- a/torch/nn/modules/adaptive.py +++ b/torch/nn/modules/adaptive.py @@ -174,18 +174,24 @@ def __init__( self.tail.append(projection) def reset_parameters(self) -> None: +<<<<<<< HEAD """ Resets parameters based on their initialization used in ``__init__``. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.head.reset_parameters() for i2h, h2o in self.tail: # type: ignore[misc] i2h.reset_parameters() # type: ignore[has-type] h2o.reset_parameters() # type: ignore[has-type] def forward(self, input_: Tensor, target_: Tensor) -> _ASMoutput: +<<<<<<< HEAD """ Runs the forward pass. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) targ_dim = target_.dim() if targ_dim == 1: diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py index d975641e80e84..dcaa8b3f5e8f2 100644 --- a/torch/nn/modules/batchnorm.py +++ b/torch/nn/modules/batchnorm.py @@ -114,7 +114,11 @@ def _load_from_state_dict( missing_keys, unexpected_keys, error_msgs, +<<<<<<< HEAD ) -> None: +======= + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) version = local_metadata.get("version", None) if (version is None or version < 2) and self.track_running_stats: @@ -193,11 +197,17 @@ def forward(self, input: Tensor) -> Tensor: return F.batch_norm( input, # If buffers are not to be tracked, ensure that they won't be updated +<<<<<<< HEAD ( self.running_mean if not self.training or self.track_running_stats else None ), +======= + self.running_mean + if not self.training or self.track_running_stats + else None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.running_var if not self.training or self.track_running_stats else None, self.weight, self.bias, @@ -338,7 +348,11 @@ class BatchNorm1d(_BatchNorm): >>> output = m(input) """ +<<<<<<< HEAD def _check_input_dim(self, input) -> None: +======= + def _check_input_dim(self, input): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if input.dim() != 2 and input.dim() != 3: raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)") @@ -372,7 +386,11 @@ class LazyBatchNorm1d(_LazyNormBase, _BatchNorm): cls_to_become = BatchNorm1d # type: ignore[assignment] +<<<<<<< HEAD def _check_input_dim(self, input) -> None: +======= + def _check_input_dim(self, input): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if input.dim() != 2 and input.dim() != 3: raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)") @@ -449,7 +467,11 @@ class BatchNorm2d(_BatchNorm): >>> output = m(input) """ +<<<<<<< HEAD def _check_input_dim(self, input) -> None: +======= + def _check_input_dim(self, input): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if input.dim() != 4: raise ValueError(f"expected 4D input (got {input.dim()}D input)") @@ -483,7 +505,11 @@ class LazyBatchNorm2d(_LazyNormBase, _BatchNorm): cls_to_become = BatchNorm2d # type: ignore[assignment] +<<<<<<< HEAD def _check_input_dim(self, input) -> None: +======= + def _check_input_dim(self, input): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if input.dim() != 4: raise ValueError(f"expected 4D input (got {input.dim()}D input)") @@ -560,7 +586,11 @@ class BatchNorm3d(_BatchNorm): >>> output = m(input) """ +<<<<<<< HEAD def _check_input_dim(self, input) -> None: +======= + def _check_input_dim(self, input): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if input.dim() != 5: raise ValueError(f"expected 5D input (got {input.dim()}D input)") @@ -594,7 +624,11 @@ class LazyBatchNorm3d(_LazyNormBase, _BatchNorm): cls_to_become = BatchNorm3d # type: ignore[assignment] +<<<<<<< HEAD def _check_input_dim(self, input) -> None: +======= + def _check_input_dim(self, input): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if input.dim() != 5: raise ValueError(f"expected 5D input (got {input.dim()}D input)") @@ -719,20 +753,31 @@ def __init__( ) self.process_group = process_group +<<<<<<< HEAD def _check_input_dim(self, input) -> None: if input.dim() < 2: raise ValueError(f"expected at least 2D input (got {input.dim()}D input)") def _check_non_zero_input_channels(self, input) -> None: +======= + def _check_input_dim(self, input): + if input.dim() < 2: + raise ValueError(f"expected at least 2D input (got {input.dim()}D input)") + + def _check_non_zero_input_channels(self, input): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if input.size(1) == 0: raise ValueError( "SyncBatchNorm number of input channels should be non-zero" ) def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._check_input_dim(input) self._check_non_zero_input_channels(input) diff --git a/torch/nn/modules/channelshuffle.py b/torch/nn/modules/channelshuffle.py index 34a48f04f853d..65af56a7d39c7 100644 --- a/torch/nn/modules/channelshuffle.py +++ b/torch/nn/modules/channelshuffle.py @@ -50,6 +50,7 @@ def __init__(self, groups: int) -> None: self.groups = groups def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ @@ -59,4 +60,9 @@ def extra_repr(self) -> str: """ Return the extra representation of the module. """ +======= + return F.channel_shuffle(input, self.groups) + + def extra_repr(self) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"groups={self.groups}" diff --git a/torch/nn/modules/container.py b/torch/nn/modules/container.py index a03f57ea58a80..f0051f7dc7009 100644 --- a/torch/nn/modules/container.py +++ b/torch/nn/modules/container.py @@ -170,9 +170,12 @@ def __add__(self, other) -> Sequential: ) def pop(self, key: Union[int, slice]) -> Module: +<<<<<<< HEAD """ Pop ``key`` from self. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) v = self[key] del self[key] return v @@ -243,9 +246,12 @@ def __iter__(self) -> Iterator[Module]: # TestScript.test_sequential_intermediary_types). Cannot annotate # with Any as TorchScript expects a more precise type def forward(self, input): +<<<<<<< HEAD """ Runs the forward pass. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for module in self: input = module(input) return input @@ -492,7 +498,11 @@ def extend(self, modules: Iterable[Module]) -> Self: self.add_module(str(offset + i), module) return self +<<<<<<< HEAD # remove forward altogether to fallback on Module's _forward_unimplemented +======= + # remove forward alltogether to fallback on Module's _forward_unimplemented +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class ModuleDict(Module): @@ -632,7 +642,11 @@ def update(self, modules: Mapping[str, Module]) -> None: # that's too cumbersome to type correctly with overloads, so we add an ignore here self[m[0]] = m[1] # type: ignore[assignment] +<<<<<<< HEAD # remove forward altogether to fallback on Module's _forward_unimplemented +======= + # remove forward alltogether to fallback on Module's _forward_unimplemented +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class ParameterList(Module): @@ -752,9 +766,12 @@ def extend(self, values: Iterable[Any]) -> Self: return self def extra_repr(self) -> str: +<<<<<<< HEAD """ Return the extra representation of the module. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) child_lines = [] for k, p in enumerate(self): if isinstance(p, torch.Tensor): diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index 2f15c3d488f72..28a54c6d85f1d 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -1,6 +1,10 @@ # mypy: allow-untyped-defs import math +<<<<<<< HEAD from typing import Literal, Optional, Union +======= +from typing import Optional, Union +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from typing_extensions import deprecated import torch @@ -80,7 +84,11 @@ def _conv_forward( # type: ignore[empty-body] transposed: bool output_padding: tuple[int, ...] groups: int +<<<<<<< HEAD padding_mode: Literal["zeros", "reflect", "replicate", "circular"] +======= + padding_mode: str +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) weight: Tensor bias: Optional[Tensor] @@ -96,7 +104,11 @@ def __init__( output_padding: tuple[int, ...], groups: int, bias: bool, +<<<<<<< HEAD padding_mode: Literal["zeros", "reflect", "replicate", "circular"], +======= + padding_mode: str, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device=None, dtype=None, ) -> None: @@ -324,7 +336,11 @@ def __init__( dilation: _size_1_t = 1, groups: int = 1, bias: bool = True, +<<<<<<< HEAD padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", +======= + padding_mode: str = "zeros", # TODO: refine this type +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device=None, dtype=None, ) -> None: @@ -503,7 +519,11 @@ def __init__( dilation: _size_2_t = 1, groups: int = 1, bias: bool = True, +<<<<<<< HEAD padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", +======= + padding_mode: str = "zeros", # TODO: refine this type +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device=None, dtype=None, ) -> None: @@ -672,7 +692,11 @@ def __init__( dilation: _size_3_t = 1, groups: int = 1, bias: bool = True, +<<<<<<< HEAD padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", +======= + padding_mode: str = "zeros", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device=None, dtype=None, ) -> None: @@ -917,7 +941,11 @@ def __init__( groups: int = 1, bias: bool = True, dilation: _size_1_t = 1, +<<<<<<< HEAD padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", +======= + padding_mode: str = "zeros", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device=None, dtype=None, ) -> None: @@ -1105,7 +1133,11 @@ def __init__( groups: int = 1, bias: bool = True, dilation: _size_2_t = 1, +<<<<<<< HEAD padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", +======= + padding_mode: str = "zeros", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device=None, dtype=None, ) -> None: @@ -1296,7 +1328,11 @@ def __init__( groups: int = 1, bias: bool = True, dilation: _size_3_t = 1, +<<<<<<< HEAD padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", +======= + padding_mode: str = "zeros", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device=None, dtype=None, ) -> None: @@ -1374,7 +1410,11 @@ class _ConvTransposeMixin(_ConvTransposeNd): "Please consider using public APIs.", category=FutureWarning, ) +<<<<<<< HEAD def __init__(self, *args, **kwargs) -> None: +======= + def __init__(self, *args, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__(*args, **kwargs) @@ -1489,7 +1529,11 @@ def __init__( dilation: _size_1_t = 1, groups: int = 1, bias: bool = True, +<<<<<<< HEAD padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", +======= + padding_mode: str = "zeros", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device=None, dtype=None, ) -> None: @@ -1558,7 +1602,11 @@ def __init__( dilation: _size_2_t = 1, groups: int = 1, bias: bool = True, +<<<<<<< HEAD padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", +======= + padding_mode: str = "zeros", # TODO: refine this type +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device=None, dtype=None, ) -> None: @@ -1628,7 +1676,11 @@ def __init__( dilation: _size_3_t = 1, groups: int = 1, bias: bool = True, +<<<<<<< HEAD padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", +======= + padding_mode: str = "zeros", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device=None, dtype=None, ) -> None: @@ -1696,7 +1748,11 @@ def __init__( groups: int = 1, bias: bool = True, dilation: _size_1_t = 1, +<<<<<<< HEAD padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", +======= + padding_mode: str = "zeros", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device=None, dtype=None, ) -> None: @@ -1765,7 +1821,11 @@ def __init__( groups: int = 1, bias: bool = True, dilation: int = 1, +<<<<<<< HEAD padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", +======= + padding_mode: str = "zeros", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device=None, dtype=None, ) -> None: @@ -1834,7 +1894,11 @@ def __init__( groups: int = 1, bias: bool = True, dilation: _size_3_t = 1, +<<<<<<< HEAD padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", +======= + padding_mode: str = "zeros", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) device=None, dtype=None, ) -> None: diff --git a/torch/nn/modules/distance.py b/torch/nn/modules/distance.py index 27ab92fef5eb4..023184d945919 100644 --- a/torch/nn/modules/distance.py +++ b/torch/nn/modules/distance.py @@ -55,9 +55,12 @@ def __init__( self.keepdim = keepdim def forward(self, x1: Tensor, x2: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.pairwise_distance(x1, x2, self.norm, self.eps, self.keepdim) @@ -94,7 +97,10 @@ def __init__(self, dim: int = 1, eps: float = 1e-8) -> None: self.eps = eps def forward(self, x1: Tensor, x2: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.cosine_similarity(x1, x2, self.dim, self.eps) diff --git a/torch/nn/modules/dropout.py b/torch/nn/modules/dropout.py index ee3de5d61dc0b..e4237c33ee8c0 100644 --- a/torch/nn/modules/dropout.py +++ b/torch/nn/modules/dropout.py @@ -67,9 +67,12 @@ class Dropout(_DropoutNd): """ def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.dropout(input, self.p, self.training, self.inplace) @@ -115,9 +118,12 @@ class Dropout1d(_DropoutNd): """ def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.dropout1d(input, self.p, self.training, self.inplace) @@ -170,9 +176,12 @@ class Dropout2d(_DropoutNd): """ def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.dropout2d(input, self.p, self.training, self.inplace) @@ -218,9 +227,12 @@ class Dropout3d(_DropoutNd): """ def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.dropout3d(input, self.p, self.training, self.inplace) @@ -263,9 +275,12 @@ class AlphaDropout(_DropoutNd): """ def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.alpha_dropout(input, self.p, self.training) @@ -317,7 +332,10 @@ class FeatureAlphaDropout(_DropoutNd): """ def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.feature_alpha_dropout(input, self.p, self.training) diff --git a/torch/nn/modules/flatten.py b/torch/nn/modules/flatten.py index c4920ccd65b00..64582c2bf1614 100644 --- a/torch/nn/modules/flatten.py +++ b/torch/nn/modules/flatten.py @@ -50,6 +50,7 @@ def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None: self.end_dim = end_dim def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ @@ -59,6 +60,11 @@ def extra_repr(self) -> str: """ Returns the extra representation of the module. """ +======= + return input.flatten(self.start_dim, self.end_dim) + + def extra_repr(self) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"start_dim={self.start_dim}, end_dim={self.end_dim}" @@ -130,7 +136,11 @@ def __init__( self.dim = dim self.unflattened_size = unflattened_size +<<<<<<< HEAD def _require_tuple_tuple(self, input) -> None: +======= + def _require_tuple_tuple(self, input): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(input, tuple): for idx, elem in enumerate(input): if not isinstance(elem, tuple): @@ -144,7 +154,11 @@ def _require_tuple_tuple(self, input) -> None: + f"but found type {type(input).__name__}" ) +<<<<<<< HEAD def _require_tuple_int(self, input) -> None: +======= + def _require_tuple_int(self, input): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(input, (tuple, list)): for idx, elem in enumerate(input): if not isinstance(elem, int): @@ -158,6 +172,7 @@ def _require_tuple_int(self, input) -> None: ) def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ @@ -167,4 +182,9 @@ def extra_repr(self) -> str: """ Returns the extra representation of the module. """ +======= + return input.unflatten(self.dim, self.unflattened_size) + + def extra_repr(self) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"dim={self.dim}, unflattened_size={self.unflattened_size}" diff --git a/torch/nn/modules/fold.py b/torch/nn/modules/fold.py index ab1a58882c852..a2525ca8b474b 100644 --- a/torch/nn/modules/fold.py +++ b/torch/nn/modules/fold.py @@ -147,9 +147,12 @@ def __init__( self.stride = stride def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.fold( input, self.output_size, @@ -160,9 +163,12 @@ def forward(self, input: Tensor) -> Tensor: ) def extra_repr(self) -> str: +<<<<<<< HEAD """ Return the extra representation of the module. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ( "output_size={output_size}, kernel_size={kernel_size}, " "dilation={dilation}, padding={padding}, stride={stride}".format( @@ -318,17 +324,23 @@ def __init__( self.stride = stride def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.unfold( input, self.kernel_size, self.dilation, self.padding, self.stride ) def extra_repr(self) -> str: +<<<<<<< HEAD """ Return the extra representation of the module. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ( "kernel_size={kernel_size}, dilation={dilation}, padding={padding}," " stride={stride}".format(**self.__dict__) diff --git a/torch/nn/modules/instancenorm.py b/torch/nn/modules/instancenorm.py index 25f0c45d5c1a8..c1be0ff6b7836 100644 --- a/torch/nn/modules/instancenorm.py +++ b/torch/nn/modules/instancenorm.py @@ -64,7 +64,11 @@ def _load_from_state_dict( missing_keys, unexpected_keys, error_msgs, +<<<<<<< HEAD ) -> None: +======= + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) version = local_metadata.get("version", None) # at version 1: removed running_mean and running_var when # track_running_stats=False (default) @@ -193,10 +197,17 @@ class InstanceNorm1d(_InstanceNorm): >>> output = m(input) """ +<<<<<<< HEAD def _get_no_batch_dim(self) -> int: return 2 def _check_input_dim(self, input) -> None: +======= + def _get_no_batch_dim(self): + return 2 + + def _check_input_dim(self, input): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if input.dim() not in (2, 3): raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)") @@ -230,10 +241,17 @@ class LazyInstanceNorm1d(_LazyNormBase, _InstanceNorm): cls_to_become = InstanceNorm1d # type: ignore[assignment] +<<<<<<< HEAD def _get_no_batch_dim(self) -> int: return 2 def _check_input_dim(self, input) -> None: +======= + def _get_no_batch_dim(self): + return 2 + + def _check_input_dim(self, input): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if input.dim() not in (2, 3): raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)") @@ -309,10 +327,17 @@ class InstanceNorm2d(_InstanceNorm): >>> output = m(input) """ +<<<<<<< HEAD def _get_no_batch_dim(self) -> int: return 3 def _check_input_dim(self, input) -> None: +======= + def _get_no_batch_dim(self): + return 3 + + def _check_input_dim(self, input): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if input.dim() not in (3, 4): raise ValueError(f"expected 3D or 4D input (got {input.dim()}D input)") @@ -347,10 +372,17 @@ class LazyInstanceNorm2d(_LazyNormBase, _InstanceNorm): cls_to_become = InstanceNorm2d # type: ignore[assignment] +<<<<<<< HEAD def _get_no_batch_dim(self) -> int: return 3 def _check_input_dim(self, input) -> None: +======= + def _get_no_batch_dim(self): + return 3 + + def _check_input_dim(self, input): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if input.dim() not in (3, 4): raise ValueError(f"expected 3D or 4D input (got {input.dim()}D input)") @@ -425,10 +457,17 @@ class InstanceNorm3d(_InstanceNorm): >>> output = m(input) """ +<<<<<<< HEAD def _get_no_batch_dim(self) -> int: return 4 def _check_input_dim(self, input) -> None: +======= + def _get_no_batch_dim(self): + return 4 + + def _check_input_dim(self, input): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if input.dim() not in (4, 5): raise ValueError(f"expected 4D or 5D input (got {input.dim()}D input)") @@ -463,9 +502,16 @@ class LazyInstanceNorm3d(_LazyNormBase, _InstanceNorm): cls_to_become = InstanceNorm3d # type: ignore[assignment] +<<<<<<< HEAD def _get_no_batch_dim(self) -> int: return 4 def _check_input_dim(self, input) -> None: +======= + def _get_no_batch_dim(self): + return 4 + + def _check_input_dim(self, input): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if input.dim() not in (4, 5): raise ValueError(f"expected 4D or 5D input (got {input.dim()}D input)") diff --git a/torch/nn/modules/lazy.py b/torch/nn/modules/lazy.py index 46e7c7be63dbc..538c3cae5663a 100644 --- a/torch/nn/modules/lazy.py +++ b/torch/nn/modules/lazy.py @@ -170,7 +170,11 @@ class LazyModuleMixin: cls_to_become: Optional[type[Any]] = None def __init__(self: _LazyProtocol, *args, **kwargs): +<<<<<<< HEAD # Mypy doesn't like this super call in a mixin +======= + # Mypy doesnt like this super call in a mixin +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__(*args, **kwargs) # type: ignore[misc] self._load_hook = self._register_load_state_dict_pre_hook(self._lazy_load_hook) self._initialize_hook = self.register_forward_pre_hook( @@ -250,7 +254,11 @@ def has_uninitialized_params(self: _LazyProtocol): def _infer_parameters(self: _LazyProtocol, module, args, kwargs=None): r"""Infers the size and initializes the parameters according to the provided input batch. +<<<<<<< HEAD Given a module that contains parameters that were declared inferable +======= + Given a module that contains parameters that were declared inferrable +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) using :class:`torch.nn.parameter.ParameterMode.Infer`, runs a forward pass in the complete module using the provided input to initialize all the parameters as needed. diff --git a/torch/nn/modules/linear.py b/torch/nn/modules/linear.py index a3c867d533d6b..010635404d1ce 100644 --- a/torch/nn/modules/linear.py +++ b/torch/nn/modules/linear.py @@ -44,9 +44,12 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__() def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return input @@ -115,9 +118,12 @@ def __init__( self.reset_parameters() def reset_parameters(self) -> None: +<<<<<<< HEAD """ Resets parameters based on their initialization used in ``__init__``. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see # https://github.com/pytorch/pytorch/issues/57109 @@ -128,6 +134,7 @@ def reset_parameters(self) -> None: init.uniform_(self.bias, -bound, bound) def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ @@ -137,6 +144,11 @@ def extra_repr(self) -> str: """ Return the extra representation of the module. """ +======= + return F.linear(input, self.weight, self.bias) + + def extra_repr(self) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}" @@ -230,15 +242,19 @@ def __init__( self.reset_parameters() def reset_parameters(self) -> None: +<<<<<<< HEAD """ Resets parameters based on their initialization used in ``__init__``. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) bound = 1 / math.sqrt(self.weight.size(1)) init.uniform_(self.weight, -bound, bound) if self.bias is not None: init.uniform_(self.bias, -bound, bound) def forward(self, input1: Tensor, input2: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ @@ -248,6 +264,11 @@ def extra_repr(self) -> str: """ Return the extra representation of the module. """ +======= + return F.bilinear(input1, input2, self.weight, self.bias) + + def extra_repr(self) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ( f"in1_features={self.in1_features}, in2_features={self.in2_features}, " f"out_features={self.out_features}, bias={self.bias is not None}" @@ -300,16 +321,22 @@ def __init__( self.bias = UninitializedParameter(**factory_kwargs) def reset_parameters(self) -> None: +<<<<<<< HEAD """ Resets parameters based on their initialization used in ``__init__``. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not self.has_uninitialized_params() and self.in_features != 0: super().reset_parameters() def initialize_parameters(self, input) -> None: # type: ignore[override] +<<<<<<< HEAD """ Infers ``in_features`` based on ``input`` and initializes parameters. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.has_uninitialized_params(): with torch.no_grad(): self.in_features = input.shape[-1] diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py index 949c9f46d0085..c3f8518567549 100644 --- a/torch/nn/modules/loss.py +++ b/torch/nn/modules/loss.py @@ -126,9 +126,12 @@ def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> N super().__init__(size_average, reduce, reduction) def forward(self, input: Tensor, target: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.l1_loss(input, target, reduction=self.reduction) @@ -253,9 +256,12 @@ def __init__( self.ignore_index = ignore_index def forward(self, input: Tensor, target: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.nll_loss( input, target, @@ -360,9 +366,12 @@ def __init__( self.eps = eps def forward(self, log_input: Tensor, target: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.poisson_nll_loss( log_input, target, @@ -454,9 +463,12 @@ def __init__( def forward( self, input: Tensor, target: Tensor, var: Union[Tensor, float] ) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.gaussian_nll_loss( input, target, var, full=self.full, eps=self.eps, reduction=self.reduction ) @@ -557,9 +569,12 @@ def __init__( self.log_target = log_target def forward(self, input: Tensor, target: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.kl_div( input, target, reduction=self.reduction, log_target=self.log_target ) @@ -628,9 +643,12 @@ def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> N super().__init__(size_average, reduce, reduction) def forward(self, input: Tensor, target: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.mse_loss(input, target, reduction=self.reduction) @@ -721,9 +739,12 @@ def __init__( super().__init__(weight, size_average, reduce, reduction) def forward(self, input: Tensor, target: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.binary_cross_entropy( input, target, weight=self.weight, reduction=self.reduction ) @@ -813,7 +834,11 @@ class BCEWithLogitsLoss(_Loss): operations. For a target of size [B, C, H, W] (where B is batch size) pos_weight of size [B, C, H, W] will apply different pos_weights to each element of the batch or [C, H, W] the same pos_weights across the batch. To apply the same positive weight +<<<<<<< HEAD along all spatial dimensions for a 2D multi-class target [C, H, W] use: [C, 1, 1]. +======= + along all spacial dimensions for a 2D multi-class target [C, H, W] use: [C, 1, 1]. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Default: ``None`` Shape: @@ -846,7 +871,10 @@ def __init__( self.pos_weight: Optional[Tensor] def forward(self, input: Tensor, target: Tensor) -> Tensor: +<<<<<<< HEAD """Runs the forward pass.""" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.binary_cross_entropy_with_logits( input, target, @@ -920,7 +948,10 @@ def __init__( self.margin = margin def forward(self, input: Tensor, target: Tensor) -> Tensor: +<<<<<<< HEAD """Runs the forward pass.""" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.hinge_embedding_loss( input, target, margin=self.margin, reduction=self.reduction ) @@ -988,7 +1019,10 @@ def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> N super().__init__(size_average, reduce, reduction) def forward(self, input: Tensor, target: Tensor) -> Tensor: +<<<<<<< HEAD """Runs the forward pass.""" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.multilabel_margin_loss(input, target, reduction=self.reduction) @@ -1073,7 +1107,10 @@ def __init__( self.beta = beta def forward(self, input: Tensor, target: Tensor) -> Tensor: +<<<<<<< HEAD """Runs the forward pass.""" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.smooth_l1_loss(input, target, reduction=self.reduction, beta=self.beta) @@ -1135,7 +1172,10 @@ def __init__(self, reduction: str = "mean", delta: float = 1.0) -> None: self.delta = delta def forward(self, input: Tensor, target: Tensor) -> Tensor: +<<<<<<< HEAD """Runs the forward pass.""" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.huber_loss(input, target, reduction=self.reduction, delta=self.delta) @@ -1178,7 +1218,10 @@ def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> N super().__init__(size_average, reduce, reduction) def forward(self, input: Tensor, target: Tensor) -> Tensor: +<<<<<<< HEAD """Runs the forward pass.""" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.soft_margin_loss(input, target, reduction=self.reduction) @@ -1287,9 +1330,13 @@ class probabilities only when a single class label per minibatch item is too res :math:`K \geq 1` in the case of K-dimensional loss where each value should be between :math:`[0, C)`. The target data type is required to be long when using class indices. If containing class probabilities, the target must be the same shape input, and each value should be between :math:`[0, 1]`. This means the target +<<<<<<< HEAD data type is required to be float when using class probabilities. Note that PyTorch does not strictly enforce probability constraints on the class probabilities and that it is the user's responsibility to ensure ``target`` contains valid probability distributions (see below examples section for more details). +======= + data type is required to be float when using class probabilities. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - Output: If reduction is 'none', shape :math:`()`, :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of K-dimensional loss, depending on the shape of the input. Otherwise, scalar. @@ -1316,6 +1363,7 @@ class probabilities only when a single class label per minibatch item is too res >>> target = torch.randn(3, 5).softmax(dim=1) >>> output = loss(input, target) >>> output.backward() +<<<<<<< HEAD .. note:: When ``target`` contains class probabilities, it should consist of soft labels—that is, @@ -1361,6 +1409,8 @@ class probabilities only when a single class label per minibatch item is too res tensor([1.0000, 1.0000, 1.0000]) >>> loss(input, target_new).item() 2.55349063873291 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ __constants__ = ["ignore_index", "reduction", "label_smoothing"] @@ -1381,7 +1431,10 @@ def __init__( self.label_smoothing = label_smoothing def forward(self, input: Tensor, target: Tensor) -> Tensor: +<<<<<<< HEAD """Runs the forward pass.""" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.cross_entropy( input, target, @@ -1443,7 +1496,10 @@ def __init__( super().__init__(weight, size_average, reduce, reduction) def forward(self, input: Tensor, target: Tensor) -> Tensor: +<<<<<<< HEAD """Runs the forward pass.""" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.multilabel_soft_margin_loss( input, target, weight=self.weight, reduction=self.reduction ) @@ -1515,7 +1571,10 @@ def __init__( self.margin = margin def forward(self, input1: Tensor, input2: Tensor, target: Tensor) -> Tensor: +<<<<<<< HEAD """Runs the forward pass.""" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.cosine_embedding_loss( input1, input2, target, margin=self.margin, reduction=self.reduction ) @@ -1582,7 +1641,10 @@ def __init__( self.margin = margin def forward(self, input1: Tensor, input2: Tensor, target: Tensor) -> Tensor: +<<<<<<< HEAD """Runs the forward pass.""" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.margin_ranking_loss( input1, input2, target, margin=self.margin, reduction=self.reduction ) @@ -1673,7 +1735,10 @@ def __init__( self.margin = margin def forward(self, input: Tensor, target: Tensor) -> Tensor: +<<<<<<< HEAD """Runs the forward pass.""" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.multi_margin_loss( input, target, @@ -1769,7 +1834,11 @@ def __init__( size_average=None, reduce=None, reduction: str = "mean", +<<<<<<< HEAD ) -> None: +======= + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__(size_average, reduce, reduction) if margin <= 0: raise ValueError( @@ -1781,7 +1850,10 @@ def __init__( self.swap = swap def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor: +<<<<<<< HEAD """Runs the forward pass.""" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.triplet_margin_loss( anchor, positive, @@ -1904,7 +1976,11 @@ def __init__( margin: float = 1.0, swap: bool = False, reduction: str = "mean", +<<<<<<< HEAD ) -> None: +======= + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__(size_average=None, reduce=None, reduction=reduction) if margin <= 0: raise ValueError( @@ -1917,7 +1993,10 @@ def __init__( self.swap = swap def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor: +<<<<<<< HEAD """Runs the forward pass.""" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.triplet_margin_with_distance_loss( anchor, positive, @@ -2085,7 +2164,11 @@ class CTCLoss(_Loss): def __init__( self, blank: int = 0, reduction: str = "mean", zero_infinity: bool = False +<<<<<<< HEAD ) -> None: +======= + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__(reduction=reduction) self.blank = blank self.zero_infinity = zero_infinity @@ -2097,7 +2180,10 @@ def forward( input_lengths: Tensor, target_lengths: Tensor, ) -> Tensor: +<<<<<<< HEAD """Runs the forward pass.""" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.ctc_loss( log_probs, targets, diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index f0c4914782f39..deb6bb6d7da7b 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -42,7 +42,11 @@ class _IncompatibleKeys( ): __slots__ = () +<<<<<<< HEAD def __repr__(self) -> str: +======= + def __repr__(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not self.missing_keys and not self.unexpected_keys: return "" return super().__repr__() @@ -70,7 +74,11 @@ def _addindent(s_, numSpaces): class _WrappedHook: +<<<<<<< HEAD def __init__(self, hook: Callable, module: Optional["Module"] = None) -> None: +======= + def __init__(self, hook: Callable, module: Optional["Module"] = None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.hook: Callable = hook functools.update_wrapper(self, hook) @@ -568,9 +576,13 @@ def register_buffer( raise KeyError('buffer name can\'t be empty string ""') elif hasattr(self, name) and name not in self._buffers: raise KeyError(f"attribute '{name}' already exists") +<<<<<<< HEAD elif tensor is not None and not ( isinstance(tensor, torch.Tensor) or hasattr(tensor, "__torch_function__") ): +======= + elif tensor is not None and not isinstance(tensor, torch.Tensor): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise TypeError( f"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' " "(torch Tensor or None required)" @@ -929,7 +941,11 @@ def _apply(self, fn, recurse=True): for module in self.children(): module._apply(fn) +<<<<<<< HEAD def compute_should_use_set_data(tensor, tensor_applied) -> bool: +======= + def compute_should_use_set_data(tensor, tensor_applied): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if torch._has_compatible_shallow_copy_type(tensor, tensor_applied): # If the new tensor has compatible tensor type as the existing tensor, # the current behavior is to change the tensor in-place using `.data =`, @@ -1538,7 +1554,11 @@ def _get_backward_pre_hooks(self): return backward_pre_hooks +<<<<<<< HEAD def _maybe_warn_non_full_backward_hook(self, inputs, result, grad_fn) -> None: +======= + def _maybe_warn_non_full_backward_hook(self, inputs, result, grad_fn): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not isinstance(result, torch.Tensor): if not ( isinstance(result, tuple) @@ -1966,7 +1986,11 @@ def __getattr__(self, name: str) -> Union[Tensor, "Module"]: ) def __setattr__(self, name: str, value: Union[Tensor, "Module"]) -> None: +<<<<<<< HEAD def remove_from(*dicts_or_sets) -> None: +======= + def remove_from(*dicts_or_sets): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for d in dicts_or_sets: if name in d: if isinstance(d, dict): @@ -2026,10 +2050,14 @@ def remove_from(*dicts_or_sets) -> None: else: buffers = self.__dict__.get("_buffers") if isinstance(value, Buffer) or buffers is not None and name in buffers: +<<<<<<< HEAD if value is not None and not ( isinstance(value, torch.Tensor) or hasattr(value, "__torch_function__") ): +======= + if value is not None and not isinstance(value, torch.Tensor): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise TypeError( f"cannot assign '{torch.typename(value)}' as buffer '{name}' " "(torch.nn.Buffer, torch.Tensor or None expected)" @@ -2070,7 +2098,11 @@ def remove_from(*dicts_or_sets) -> None: else: super().__setattr__(name, value) +<<<<<<< HEAD def __delattr__(self, name) -> None: +======= + def __delattr__(self, name): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if name in self._parameters: del self._parameters[name] elif name in self._buffers: @@ -2137,7 +2169,11 @@ def register_state_dict_pre_hook(self, hook): self._state_dict_pre_hooks[handle.id] = hook return handle +<<<<<<< HEAD def _save_to_state_dict(self, destination, prefix, keep_vars) -> None: +======= + def _save_to_state_dict(self, destination, prefix, keep_vars): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r"""Save module state to the `destination` dictionary. The `destination` dictionary will contain the state @@ -2347,7 +2383,11 @@ def _load_from_state_dict( missing_keys, unexpected_keys, error_msgs, +<<<<<<< HEAD ) -> None: +======= + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r"""Copy parameters and buffers from :attr:`state_dict` into only this module, but not its descendants. This is called on every submodule @@ -2573,7 +2613,11 @@ def load_state_dict( # mypy isn't aware that "_metadata" exists in state_dict state_dict._metadata = metadata # type: ignore[attr-defined] +<<<<<<< HEAD def load(module, local_state_dict, prefix="") -> None: +======= + def load(module, local_state_dict, prefix=""): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) if assign: local_metadata["assign_to_params_buffers"] = assign @@ -2976,7 +3020,11 @@ def extra_repr(self) -> str: """ return "" +<<<<<<< HEAD def __repr__(self) -> str: +======= + def __repr__(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # We treat the extra repr like the sub-module, one item per line extra_lines = [] extra_repr = self.extra_repr() diff --git a/torch/nn/modules/normalization.py b/torch/nn/modules/normalization.py index 1474de008c185..e741af67f8b69 100644 --- a/torch/nn/modules/normalization.py +++ b/torch/nn/modules/normalization.py @@ -60,6 +60,7 @@ def __init__( self.k = k def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ @@ -69,6 +70,11 @@ def extra_repr(self): """ Return the extra representation of the module. """ +======= + return F.local_response_norm(input, self.size, self.alpha, self.beta, self.k) + + def extra_repr(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return "{size}, alpha={alpha}, beta={beta}, k={k}".format(**self.__dict__) @@ -88,6 +94,7 @@ def __init__( self.k = k def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ @@ -97,6 +104,11 @@ def extra_repr(self) -> str: """ Return the extra representation of the module. """ +======= + return _cross_map_lrn2d.apply(input, self.size, self.alpha, self.beta, self.k) + + def extra_repr(self) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return "{size}, alpha={alpha}, beta={beta}, k={k}".format(**self.__dict__) @@ -249,7 +261,11 @@ class GroupNorm(Module): The input channels are separated into :attr:`num_groups` groups, each containing ``num_channels / num_groups`` channels. :attr:`num_channels` must be divisible by :attr:`num_groups`. The mean and standard-deviation are calculated +<<<<<<< HEAD separately over each group. :math:`\gamma` and :math:`\beta` are learnable +======= + separately over the each group. :math:`\gamma` and :math:`\beta` are learnable +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) per-channel affine transform parameter vectors of size :attr:`num_channels` if :attr:`affine` is ``True``. The variance is calculated via the biased estimator, equivalent to @@ -409,13 +425,21 @@ def reset_parameters(self) -> None: def forward(self, x: torch.Tensor) -> torch.Tensor: """ +<<<<<<< HEAD Runs the forward pass. +======= + Runs forward pass. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ return F.rms_norm(x, self.normalized_shape, self.weight, self.eps) def extra_repr(self) -> str: """ +<<<<<<< HEAD Return the extra representation of the module. +======= + Extra information about the module. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ return ( "{normalized_shape}, eps={eps}, " diff --git a/torch/nn/modules/padding.py b/torch/nn/modules/padding.py index 6c4c117d1a7d3..72c785015a1a1 100644 --- a/torch/nn/modules/padding.py +++ b/torch/nn/modules/padding.py @@ -90,7 +90,11 @@ def __init__(self, padding: _size_2_t) -> None: super().__init__() self.padding = _pair(padding) +<<<<<<< HEAD def _check_input_dim(self, input) -> None: +======= + def _check_input_dim(self, input): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if input.dim() != 2 and input.dim() != 3: raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)") @@ -150,7 +154,11 @@ def __init__(self, padding: _size_4_t) -> None: super().__init__() self.padding = _quadruple(padding) +<<<<<<< HEAD def _check_input_dim(self, input) -> None: +======= + def _check_input_dim(self, input): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if input.dim() != 3 and input.dim() != 4: raise ValueError(f"expected 3D or 4D input (got {input.dim()}D input)") @@ -200,7 +208,11 @@ def __init__(self, padding: _size_6_t) -> None: super().__init__() self.padding = _ntuple(6)(padding) +<<<<<<< HEAD def _check_input_dim(self, input) -> None: +======= + def _check_input_dim(self, input): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if input.dim() != 4 and input.dim() != 5: raise ValueError(f"expected 4D or 5D input (got {input.dim()}D input)") @@ -267,7 +279,11 @@ class ConstantPad1d(_ConstantPadNd): padding: tuple[int, int] +<<<<<<< HEAD def __init__(self, padding: _size_2_t, value: float) -> None: +======= + def __init__(self, padding: _size_2_t, value: float): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__(value) self.padding = _pair(padding) @@ -722,9 +738,12 @@ def __init__(self, padding: _size_2_t) -> None: super().__init__(padding, 0.0) def extra_repr(self) -> str: +<<<<<<< HEAD """ Return the extra representation of the module. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"{self.padding}" @@ -779,9 +798,12 @@ def __init__(self, padding: _size_4_t) -> None: super().__init__(padding, 0.0) def extra_repr(self) -> str: +<<<<<<< HEAD """ Return the extra representation of the module. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"{self.padding}" @@ -824,7 +846,10 @@ def __init__(self, padding: _size_6_t) -> None: super().__init__(padding, 0.0) def extra_repr(self) -> str: +<<<<<<< HEAD """ Return the extra representation of the module. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"{self.padding}" diff --git a/torch/nn/modules/pixelshuffle.py b/torch/nn/modules/pixelshuffle.py index 74c9e0878f0b5..57e0e782c8290 100644 --- a/torch/nn/modules/pixelshuffle.py +++ b/torch/nn/modules/pixelshuffle.py @@ -56,6 +56,7 @@ def __init__(self, upscale_factor: int) -> None: self.upscale_factor = upscale_factor def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ @@ -65,6 +66,11 @@ def extra_repr(self) -> str: """ Return the extra representation of the module. """ +======= + return F.pixel_shuffle(input, self.upscale_factor) + + def extra_repr(self) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"upscale_factor={self.upscale_factor}" @@ -115,6 +121,7 @@ def __init__(self, downscale_factor: int) -> None: self.downscale_factor = downscale_factor def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ @@ -124,4 +131,9 @@ def extra_repr(self) -> str: """ Return the extra representation of the module. """ +======= + return F.pixel_unshuffle(input, self.downscale_factor) + + def extra_repr(self) -> str: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"downscale_factor={self.downscale_factor}" diff --git a/torch/nn/modules/pooling.py b/torch/nn/modules/pooling.py index 3a5cd9805bfb9..4d3b0c469e56b 100644 --- a/torch/nn/modules/pooling.py +++ b/torch/nn/modules/pooling.py @@ -142,7 +142,10 @@ class MaxPool1d(_MaxPoolNd): dilation: _size_1_t def forward(self, input: Tensor): +<<<<<<< HEAD """Runs the forward pass.""" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.max_pool1d( input, self.kernel_size, @@ -222,7 +225,10 @@ class MaxPool2d(_MaxPoolNd): dilation: _size_2_t def forward(self, input: Tensor): +<<<<<<< HEAD """Runs the forward pass.""" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.max_pool2d( input, self.kernel_size, @@ -306,7 +312,10 @@ class MaxPool3d(_MaxPoolNd): dilation: _size_3_t def forward(self, input: Tensor): +<<<<<<< HEAD """Runs the forward pass.""" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.max_pool3d( input, self.kernel_size, @@ -400,7 +409,10 @@ def __init__( def forward( self, input: Tensor, indices: Tensor, output_size: Optional[list[int]] = None ) -> Tensor: +<<<<<<< HEAD """Runs the forward pass.""" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.max_unpool1d( input, indices, self.kernel_size, self.stride, self.padding, output_size ) @@ -496,7 +508,10 @@ def __init__( def forward( self, input: Tensor, indices: Tensor, output_size: Optional[list[int]] = None ) -> Tensor: +<<<<<<< HEAD """Runs the forward pass.""" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.max_unpool2d( input, indices, self.kernel_size, self.stride, self.padding, output_size ) @@ -575,7 +590,10 @@ def __init__( def forward( self, input: Tensor, indices: Tensor, output_size: Optional[list[int]] = None ) -> Tensor: +<<<<<<< HEAD """Runs the forward pass.""" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.max_unpool3d( input, indices, self.kernel_size, self.stride, self.padding, output_size ) @@ -668,7 +686,10 @@ def __init__( self.count_include_pad = count_include_pad def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """Runs the forward pass.""" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.avg_pool1d( input, self.kernel_size, @@ -777,7 +798,10 @@ def __init__( self.divisor_override = divisor_override def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """Runs the forward pass.""" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.avg_pool2d( input, self.kernel_size, @@ -894,7 +918,10 @@ def __init__( self.divisor_override = divisor_override def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """Runs the forward pass.""" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.avg_pool3d( input, self.kernel_size, @@ -1014,8 +1041,12 @@ class FractionalMaxPool3d(Module): Args: kernel_size: the size of the window to take a max over. +<<<<<<< HEAD Can be a single number `k` (for a square kernel of `k x k x k`) or a tuple `(kt x kh x kw)`, `k` must greater than 0. +======= + Can be a single number k (for a square kernel of k x k x k) or a tuple `(kt x kh x kw)` +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) output_size: the target output size of the image of the form `oT x oH x oW`. Can be a tuple `(oT, oH, oW)` or a single number oH for a square image `oH x oH x oH` output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given. @@ -1056,11 +1087,14 @@ def __init__( _random_samples=None, ) -> None: super().__init__() +<<<<<<< HEAD if (isinstance(kernel_size, int) and kernel_size <= 0) or ( isinstance(kernel_size, (tuple, list)) and not all(k > 0 for k in kernel_size) ): raise ValueError(f"kernel_size must greater than 0, but got {kernel_size}") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.kernel_size = _triple(kernel_size) self.return_indices = return_indices self.register_buffer("_random_samples", _random_samples) @@ -1159,7 +1193,10 @@ class LPPool1d(_LPPoolNd): stride: _size_1_t def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """Runs the forward pass.""" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.lp_pool1d( input, float(self.norm_type), self.kernel_size, self.stride, self.ceil_mode ) @@ -1215,7 +1252,10 @@ class LPPool2d(_LPPoolNd): stride: _size_2_t def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """Runs the forward pass.""" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.lp_pool2d( input, float(self.norm_type), self.kernel_size, self.stride, self.ceil_mode ) @@ -1275,7 +1315,10 @@ class LPPool3d(_LPPoolNd): stride: _size_3_t def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """Runs the forward pass.""" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.lp_pool3d( input, float(self.norm_type), self.kernel_size, self.stride, self.ceil_mode ) @@ -1327,7 +1370,10 @@ class AdaptiveMaxPool1d(_AdaptiveMaxPoolNd): output_size: _size_1_t def forward(self, input: Tensor): +<<<<<<< HEAD """Runs the forward pass.""" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.adaptive_max_pool1d(input, self.output_size, self.return_indices) @@ -1370,7 +1416,10 @@ class AdaptiveMaxPool2d(_AdaptiveMaxPoolNd): output_size: _size_2_opt_t def forward(self, input: Tensor): +<<<<<<< HEAD """Runs the forward pass.""" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.adaptive_max_pool2d(input, self.output_size, self.return_indices) @@ -1414,7 +1463,10 @@ class AdaptiveMaxPool3d(_AdaptiveMaxPoolNd): output_size: _size_3_opt_t def forward(self, input: Tensor): +<<<<<<< HEAD """Runs the forward pass.""" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.adaptive_max_pool3d(input, self.output_size, self.return_indices) @@ -1454,9 +1506,12 @@ class AdaptiveAvgPool1d(_AdaptiveAvgPoolNd): output_size: _size_1_t def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.adaptive_avg_pool1d(input, self.output_size) @@ -1496,7 +1551,10 @@ class AdaptiveAvgPool2d(_AdaptiveAvgPoolNd): output_size: _size_2_opt_t def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """Runs the forward pass.""" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.adaptive_avg_pool2d(input, self.output_size) @@ -1536,5 +1594,8 @@ class AdaptiveAvgPool3d(_AdaptiveAvgPoolNd): output_size: _size_3_opt_t def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """Runs the forward pass.""" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.adaptive_avg_pool3d(input, self.output_size) diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py index df4f2f882628b..e72a008a62f3e 100644 --- a/torch/nn/modules/rnn.py +++ b/torch/nn/modules/rnn.py @@ -204,7 +204,11 @@ def __init__( self.reset_parameters() +<<<<<<< HEAD def _init_flat_weights(self) -> None: +======= + def _init_flat_weights(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._flat_weights = [ getattr(self, wn) if hasattr(self, wn) else None for wn in self._flat_weights_names @@ -214,7 +218,11 @@ def _init_flat_weights(self) -> None: ] self.flatten_parameters() +<<<<<<< HEAD def __setattr__(self, attr, value) -> None: +======= + def __setattr__(self, attr, value): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if hasattr(self, "_flat_weights_names") and attr in self._flat_weights_names: # keep self._flat_weights up to date if you do self.weight = ... idx = self._flat_weights_names.index(attr) @@ -360,7 +368,11 @@ def _weights_have_changed(self): def check_forward_args( self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor] +<<<<<<< HEAD ) -> None: +======= + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.check_input(input, batch_sizes) expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes) @@ -387,7 +399,11 @@ def extra_repr(self) -> str: s += ", bidirectional={bidirectional}" return s.format(**self.__dict__) +<<<<<<< HEAD def _update_flat_weights(self) -> None: +======= + def _update_flat_weights(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not torch.jit.is_scripting(): if self._weights_have_changed(): self._init_flat_weights() @@ -615,7 +631,11 @@ def __init__( ) -> None: ... @overload +<<<<<<< HEAD def __init__(self, *args, **kwargs) -> None: ... +======= + def __init__(self, *args, **kwargs): ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __init__(self, *args, **kwargs): if "proj_size" in kwargs: @@ -652,9 +672,12 @@ def forward( pass def forward(self, input, hx=None): # noqa: F811 +<<<<<<< HEAD """ Runs the forward pass. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._update_flat_weights() num_directions = 2 if self.bidirectional else 1 @@ -974,7 +997,11 @@ def __init__( ) -> None: ... @overload +<<<<<<< HEAD def __init__(self, *args, **kwargs) -> None: ... +======= + def __init__(self, *args, **kwargs): ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __init__(self, *args, **kwargs): super().__init__("LSTM", *args, **kwargs) @@ -1001,7 +1028,11 @@ def check_forward_args( input: Tensor, hidden: tuple[Tensor, Tensor], # type: ignore[override] batch_sizes: Optional[Tensor], +<<<<<<< HEAD ) -> None: +======= + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.check_input(input, batch_sizes) self.check_hidden_size( hidden[0], @@ -1307,7 +1338,11 @@ def __init__( ) -> None: ... @overload +<<<<<<< HEAD def __init__(self, *args, **kwargs) -> None: ... +======= + def __init__(self, *args, **kwargs): ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __init__(self, *args, **kwargs): if "proj_size" in kwargs: diff --git a/torch/nn/modules/transformer.py b/torch/nn/modules/transformer.py index dd699537e08c6..b3d17b14cee6b 100644 --- a/torch/nn/modules/transformer.py +++ b/torch/nn/modules/transformer.py @@ -301,7 +301,11 @@ def generate_square_subsequent_mask( """ return _generate_square_subsequent_mask(sz, dtype=dtype, device=device) +<<<<<<< HEAD def _reset_parameters(self) -> None: +======= + def _reset_parameters(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r"""Initiate parameters in the transformer model.""" for p in self.parameters(): if p.dim() > 1: diff --git a/torch/nn/modules/upsampling.py b/torch/nn/modules/upsampling.py index 7fd102a768225..18497b625d1da 100644 --- a/torch/nn/modules/upsampling.py +++ b/torch/nn/modules/upsampling.py @@ -169,9 +169,12 @@ def __init__( self.recompute_scale_factor = recompute_scale_factor def forward(self, input: Tensor) -> Tensor: +<<<<<<< HEAD """ Runs the forward pass. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return F.interpolate( input, self.size, @@ -188,9 +191,12 @@ def __setstate__(self, state): super().__setstate__(state) def extra_repr(self) -> str: +<<<<<<< HEAD """ Return the extra representation of the module. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.scale_factor is not None: info = "scale_factor=" + repr(self.scale_factor) else: diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index c5db538f52bb2..c6063658f9622 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -347,20 +347,28 @@ class DistributedDataParallel(Module, Joinable): To use ``DistributedDataParallel`` on a host with N GPUs, you should spawn up ``N`` processes, ensuring that each process exclusively works on a single GPU from 0 to N-1. This can be done by either setting +<<<<<<< HEAD ``CUDA_VISIBLE_DEVICES`` for every process or by calling the following API for GPUs, +======= + ``CUDA_VISIBLE_DEVICES`` for every process or by calling: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> # xdoctest: +SKIP("undefined variables") >>> torch.cuda.set_device(i) +<<<<<<< HEAD or calling the unified API for :ref:`accelerator`, >>> # xdoctest: +SKIP("undefined variables") >>> torch.accelerator.set_device_index(i) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) where i is from 0 to N-1. In each process, you should refer the following to construct this module: >>> # xdoctest: +SKIP("undefined variables") +<<<<<<< HEAD >>> if torch.accelerator.is_available(): >>> device_type = torch.accelerator.current_accelerator().type >>> vendor_backend = torch.distributed.get_default_backend_for_device(device_type) @@ -374,6 +382,13 @@ class DistributedDataParallel(Module, Joinable): >>> torch.distributed.init_process_group(device_id=i) +======= + >>> torch.distributed.init_process_group( + >>> backend='nccl', world_size=N, init_method='...' + >>> ) + >>> model = DistributedDataParallel(model, device_ids=[i], output_device=i) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) In order to spawn up multiple processes per node, you can use either ``torch.distributed.launch`` or ``torch.multiprocessing.spawn``. @@ -2184,7 +2199,11 @@ def _sync_buffers(self): else: # The process with rank 0 is considered the authoritative copy. authoritative_rank = 0 +<<<<<<< HEAD # Update self.modules_buffers in case any buffers were +======= + # Update self.modules_buffers incase any buffers were +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # reassigned. self._assign_modules_buffers() self._sync_module_buffers(authoritative_rank) diff --git a/torch/nn/parallel/replicate.py b/torch/nn/parallel/replicate.py index 6c6e4567efa11..49dd35294a62f 100644 --- a/torch/nn/parallel/replicate.py +++ b/torch/nn/parallel/replicate.py @@ -111,7 +111,12 @@ def replicate( ) -> list[T]: if not _replicatable_module(network): raise RuntimeError( +<<<<<<< HEAD "Cannot replicate network where python modules are children of ScriptModule" +======= + "Cannot replicate network where python modules are " + "childrens of ScriptModule" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if not devices: diff --git a/torch/nn/parameter.py b/torch/nn/parameter.py index c41a102fc946f..fc41911b29e79 100644 --- a/torch/nn/parameter.py +++ b/torch/nn/parameter.py @@ -1,10 +1,14 @@ from collections import OrderedDict +<<<<<<< HEAD from typing import Any +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch from torch._C import _disabled_torch_function_impl +<<<<<<< HEAD __all__ = [ "Parameter", "UninitializedParameter", @@ -15,6 +19,8 @@ ] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Metaclass to combine _TensorMeta and the instance check override for Parameter. class _ParameterMeta(torch._C._TensorMeta): # Make `isinstance(t, Parameter)` return True for custom tensor instances that have the _is_param flag. @@ -185,6 +191,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): ) +<<<<<<< HEAD def is_lazy(param: Any) -> bool: """ Returns whether ``param`` is an ``UninitializedParameter`` or ``UninitializedBuffer``. @@ -192,6 +199,9 @@ def is_lazy(param: Any) -> bool: Args: param (Any): the input to check. """ +======= +def is_lazy(param): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return isinstance(param, UninitializedTensorMixin) diff --git a/torch/nn/utils/clip_grad.py b/torch/nn/utils/clip_grad.py index fd6e908e1ff09..3fc5340a258ab 100644 --- a/torch/nn/utils/clip_grad.py +++ b/torch/nn/utils/clip_grad.py @@ -4,8 +4,13 @@ import types import typing import warnings +<<<<<<< HEAD from typing import Callable, cast, Optional, TypeVar, Union from typing_extensions import deprecated, ParamSpec, TypeAlias +======= +from typing import cast, Optional, Union +from typing_extensions import deprecated +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch from torch import Tensor @@ -16,6 +21,7 @@ ) +<<<<<<< HEAD __all__: list[str] = [ "clip_grad_norm", "clip_grad_norm_", @@ -24,15 +30,26 @@ _tensor_or_tensors: TypeAlias = Union[ # noqa: PYI042 +======= +__all__: list[str] = [] + + +_tensor_or_tensors = Union[ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.Tensor, typing.Iterable[torch.Tensor], # noqa: UP006 - needed until XLA's patch is updated ] +<<<<<<< HEAD _P = ParamSpec("_P") _R = TypeVar("_R") def _no_grad(func: Callable[_P, _R]) -> Callable[_P, _R]: +======= + +def _no_grad(func): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ This wrapper is needed to avoid a circular import when using @torch.no_grad on the exposed functions clip_grad_norm_ and clip_grad_value_ themselves. @@ -129,6 +146,7 @@ def _clip_grads_with_norm_( The gradients will be scaled by the following calculation .. math:: +<<<<<<< HEAD grad = grad * \min(\frac{max\_norm}{total\_norm + 1e-6}, 1) Gradients are modified in-place. @@ -136,6 +154,12 @@ def _clip_grads_with_norm_( Note: The scale coefficient is clamped to a maximum of 1.0 to prevent gradient amplification. This ensures that gradients are only scaled down when the total norm exceeds max_norm. +======= + grad = grad * \frac{max\_norm}{total\_norm + 1e-6} + + Gradients are modified in-place. + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) This function is equivalent to :func:`torch.nn.utils.clip_grad_norm_` with a pre-calculated total norm. @@ -296,3 +320,11 @@ def clip_grad_value_( else: for grad in grads: cast(Tensor, grad).clamp_(min=-clip_value, max=clip_value) +<<<<<<< HEAD +======= + + +clip_grad_norm.__module__ = "torch.nn.utils" +clip_grad_norm_.__module__ = "torch.nn.utils" +clip_grad_value_.__module__ = "torch.nn.utils" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/nn/utils/rnn.py b/torch/nn/utils/rnn.py index aa7c0f35a183e..ef2e5f71c9b4a 100644 --- a/torch/nn/utils/rnn.py +++ b/torch/nn/utils/rnn.py @@ -255,6 +255,7 @@ def _packed_sequence_init( def invert_permutation(permutation: Optional[Tensor]) -> Optional[Tensor]: +<<<<<<< HEAD """Returns the inverse of ``permutation``. This is useful for converting between sorted and unsorted indices in @@ -263,6 +264,8 @@ def invert_permutation(permutation: Optional[Tensor]) -> Optional[Tensor]: Args: permutation (Tensor, optional): a 1-D tensor of indices to invert """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if permutation is None: return None output = torch.empty_like(permutation, memory_format=torch.legacy_contiguous_format) diff --git a/torch/onnx/README.md b/torch/onnx/README.md index 3878f48d70be0..69ab033435b13 100644 --- a/torch/onnx/README.md +++ b/torch/onnx/README.md @@ -4,3 +4,95 @@ Torch->ONNX converter / exporter. - User-facing docs: https://pytorch.org/docs/main/onnx.html - Developer docs: https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter +<<<<<<< HEAD +======= + +> Read the following if you are contributing to `torch.onnx` + +## Symbolic functions Opsets + +Opset 9 is the base version. It is selected as the base version because + +1. It is the first opset version supported by PyTorch export. +2. Opset 9 is more robust than previous opset versions. Opset versions like 7/8 have limitations + that certain basic operators cannot be expressed in ONNX. Instead of basing on these limitations, + we chose to handle them as special cases separately. + +Backward support for opset versions beyond opset 7 is not in our roadmap. + +For opset versions other than 9, by default they will inherit the symbolic functions defined in +symbolic_opset9.py. + +To extend support for updated operators in different opset versions on top of opset 9, +simply add the updated symbolic functions in the respective symbolic_opset{version}.py file. +Checkout topk in symbolic_opset10.py, and upsample_nearest2d in symbolic_opset8.py for example. + +## Editing Symbolic Files + +- Use the internal `registration.onnx_symbolic` decorator to register a new symbolic function. Search for `def reshape(g, self, shape):` to see an example. +- Parameter names must *exactly* match the names in + aten/src/ATen/native/native_functions.yaml, because + dispatch is done with keyword arguments. +- Looking for inplace ops? They're detected by + `_jit_pass_onnx_remove_inplace_ops_for_onnx`, and + transparently dispatched to their non inplace versions in + "run_symbolic_function". See Note [Export inplace](#export-inplace) + +### A note on Tensor types + +In general, we should avoid depending on the type of Tensor Values contained +within the trace graph. However, this is sometimes unavoidable (due to ONNX +spec requirements, etc). The TensorType object has accessors for these properties that return the property if it is statically known and return nullopt otherwise. + +In general, we should prefer to rely on the least specific information possible. +For example, not relying on tensor properties at all is better than relying +on the number of dimensions which is better than relying on +concrete shapes. Doing so will make the export symbolics +more robust to different graphs. + +### Extra context for symbolic functions + +The first argument of a symbolic function is always a `GraphContext` object. + +`GraphContext` contains all methods defined in a `torch.Graph` object and context +for the symbolic function. + +In general, symbolic functions only require inputs and attributes to +the original node. An example of a symbolic function needing context is +`prim::Loop`. It needs access to the sub-block of the original node. + +### Export inplace + +It would be better for us to export inplace annotations, +than to not export them, since it is useful information that can +help the target of an ONNX export export more efficiently. However, +ONNX doesn't currently formalize inplace. Fortunately, it's sound to drop +inplace annotations, but we are losing information this way. + +### Pointwise by scalar + +What happens if you add a tensor with a constant (e.g., x + 2)? There are +some moving parts to implementing the ONNX translation in this case: + +- By the time we get the scalar in a symbolic function here, it is no longer a + Python long/float, but a PyTorch tensor with `numel == 1` (eventually, we want + it to be a zero dim tensor but this change has not happened yet.) However, the + type of this scalar is *exactly* what the user wrote in Python, which may not + match the tensor it is being added to. PyTorch will do implicit conversions on + scalars; however, ONNX will not, so we must do the conversion ourselves. This + is what `symbolic_helper._if_scalar_type_as()` and + `_jit_pass_onnx_scalar_type_analysis` does. + +- Dispatch to these functions takes advantage an outrageous coincidence + between the tensor and scalar name. When we add two tensors together, + you get the dispatch: + + add(*[self, other], **{"alpha": alpha}) + + When you add a tensor and a scalar, you get the dispatch: + + add(*[self], **{"other": other, "alpha": alpha}) + + By having the argument name line up with the name of the scalar attribute + if it exists, we can write a single function for both overloads. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index 668f47c15bc82..681821d748452 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -6,6 +6,7 @@ # Modules "errors", "ops", +<<<<<<< HEAD # Public functions "export", "is_in_onnx_export", @@ -36,11 +37,98 @@ JitScalarType, # Deprecated members that are excluded from __all__ ) from ._internal.torchscript_exporter.utils import ( # Deprecated members that are excluded from __all__ +======= + "symbolic_helper", + "utils", + # All opsets + "symbolic_caffe2", + "symbolic_opset7", + "symbolic_opset8", + "symbolic_opset9", + "symbolic_opset10", + "symbolic_opset11", + "symbolic_opset12", + "symbolic_opset13", + "symbolic_opset14", + "symbolic_opset15", + "symbolic_opset16", + "symbolic_opset17", + "symbolic_opset18", + "symbolic_opset19", + "symbolic_opset20", + # Enums + "OperatorExportTypes", + "TrainingMode", + "TensorProtoDataType", + "JitScalarType", + # Public functions + "export", + "is_in_onnx_export", + "select_model_mode_for_export", + "register_custom_op_symbolic", + "unregister_custom_op_symbolic", + # Base error + "OnnxExporterError", + "ExportOptions", + "ONNXProgram", + "dynamo_export", + "enable_fake_mode", + # DORT / torch.compile + "is_onnxrt_backend_supported", +] + +from typing import Any, Callable, TYPE_CHECKING +from typing_extensions import deprecated + +import torch +from torch._C import _onnx as _C_onnx +from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode + +from ._internal._exporter_legacy import enable_fake_mode +from ._internal.exporter._onnx_program import ONNXProgram +from ._internal.onnxruntime import ( + is_onnxrt_backend_supported, + OrtBackend as _OrtBackend, + OrtBackendOptions as _OrtBackendOptions, + OrtExecutionProvider as _OrtExecutionProvider, +) +from ._type_utils import JitScalarType +from .errors import OnnxExporterError +from .utils import ( + _run_symbolic_function, + _run_symbolic_method, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) register_custom_op_symbolic, select_model_mode_for_export, unregister_custom_op_symbolic, ) +<<<<<<< HEAD from .errors import OnnxExporterError +======= + + +from . import ( # usort: skip. Keep the order instead of sorting lexicographically + errors, + ops, + symbolic_caffe2, + symbolic_helper, + symbolic_opset7, + symbolic_opset8, + symbolic_opset9, + symbolic_opset10, + symbolic_opset11, + symbolic_opset12, + symbolic_opset13, + symbolic_opset14, + symbolic_opset15, + symbolic_opset16, + symbolic_opset17, + symbolic_opset18, + symbolic_opset19, + symbolic_opset20, + utils, +) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if TYPE_CHECKING: @@ -48,10 +136,22 @@ from collections.abc import Collection, Mapping, Sequence # Set namespace for exposed private names +<<<<<<< HEAD ONNXProgram.__module__ = "torch.onnx" OnnxExporterError.__module__ = "torch.onnx" # TODO(justinchuby): Remove these two properties +======= +JitScalarType.__module__ = "torch.onnx" +ONNXProgram.__module__ = "torch.onnx" +OnnxExporterError.__module__ = "torch.onnx" +_OrtBackend.__module__ = "torch.onnx" +_OrtBackendOptions.__module__ = "torch.onnx" +_OrtExecutionProvider.__module__ = "torch.onnx" +enable_fake_mode.__module__ = "torch.onnx" +is_onnxrt_backend_supported.__module__ = "torch.onnx" + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) producer_name = "pytorch" producer_version = _C_onnx.PRODUCER_VERSION @@ -65,11 +165,23 @@ def export( f: str | os.PathLike | None = None, *, kwargs: dict[str, Any] | None = None, +<<<<<<< HEAD +======= + export_params: bool = True, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) verbose: bool | None = None, input_names: Sequence[str] | None = None, output_names: Sequence[str] | None = None, opset_version: int | None = None, +<<<<<<< HEAD dynamo: bool = True, +======= + dynamic_axes: Mapping[str, Mapping[int, str]] + | Mapping[str, Sequence[int]] + | None = None, + keep_initializers_as_inputs: bool = False, + dynamo: bool = False, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Dynamo only options external_data: bool = True, dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None, @@ -82,12 +194,15 @@ def export( dump_exported_program: bool = False, artifacts_dir: str | os.PathLike = ".", fallback: bool = False, +<<<<<<< HEAD # BC options export_params: bool = True, keep_initializers_as_inputs: bool = False, dynamic_axes: Mapping[str, Mapping[int, str]] | Mapping[str, Sequence[int]] | None = None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Deprecated options training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL, operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX, @@ -100,7 +215,11 @@ def export( Setting ``dynamo=True`` enables the new ONNX export logic which is based on :class:`torch.export.ExportedProgram` and a more modern +<<<<<<< HEAD set of translation logic. This is the recommended and default way to export models +======= + set of translation logic. This is the recommended way to export models +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) to ONNX. When ``dynamo=True``: @@ -110,22 +229,38 @@ def export( #. If the model is already an ExportedProgram, it will be used as-is. #. Use :func:`torch.export.export` and set ``strict=False``. #. Use :func:`torch.export.export` and set ``strict=True``. +<<<<<<< HEAD +======= + #. Use ``draft_export`` which removes some soundness guarantees in data-dependent + operations to allow export to proceed. You will get a warning if the exporter + encounters any unsound data-dependent operation. + #. Use :func:`torch.jit.trace` to trace the model then convert to ExportedProgram. + This is the most unsound strategy but may be useful for converting TorchScript + models to ONNX. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Args: model: The model to be exported. args: Example positional inputs. Any non-Tensor arguments will be hard-coded into the exported model; any Tensor arguments will become inputs of the exported model, in the order they occur in the tuple. +<<<<<<< HEAD f: Path to the output ONNX model file. E.g. "model.onnx". This argument is kept for backward compatibility. It is recommended to leave unspecified (None) and use the returned :class:`torch.onnx.ONNXProgram` to serialize the model to a file instead. kwargs: Optional example keyword inputs. +======= + f: Path to the output ONNX model file. E.g. "model.onnx". + kwargs: Optional example keyword inputs. + export_params: If false, parameters (weights) will not be exported. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) verbose: Whether to enable verbose logging. input_names: names to assign to the input nodes of the graph, in order. output_names: names to assign to the output nodes of the graph, in order. opset_version: The version of the `default (ai.onnx) opset `_ +<<<<<<< HEAD to target. You should set ``opset_version`` according to the supported opset versions of the runtime backend or compiler you want to run the exported model with. Leave as default (``None``) to use the recommended version, or refer to @@ -176,6 +311,10 @@ def export( dynamic_axes: Prefer specifying ``dynamic_shapes`` when ``dynamo=True`` and when ``fallback`` is not enabled. +======= + to target. Must be >= 7. + dynamic_axes: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) By default the exported model will have the shapes of all input and output tensors set to exactly match those given in ``args``. To specify axes of tensors as @@ -257,12 +396,93 @@ def forward(self, x): dim_param: "sum_dynamic_axes_1" # axis 0 ... +<<<<<<< HEAD training: Deprecated option. Instead, set the training mode of the model before exporting. operator_export_type: Deprecated option. Only ONNX is supported. do_constant_folding: Deprecated option. custom_opsets: Deprecated option. export_modules_as_functions: Deprecated option. autograd_inlining: Deprecated option. +======= + keep_initializers_as_inputs: If True, all the + initializers (typically corresponding to model weights) in the + exported graph will also be added as inputs to the graph. If False, + then initializers are not added as inputs to the graph, and only + the user inputs are added as inputs. + + Set this to True if you intend to supply model weights at runtime. + Set it to False if the weights are static to allow for better optimizations + (e.g. constant folding) by backends/runtimes. + + dynamo: Whether to export the model with ``torch.export`` ExportedProgram instead of TorchScript. + external_data: Whether to save the model weights as an external data file. + This is required for models with large weights that exceed the ONNX file size limit (2GB). + When False, the weights are saved in the ONNX file with the model architecture. + dynamic_shapes: A dictionary or a tuple of dynamic shapes for the model inputs. Refer to + :func:`torch.export.export` for more details. This is only used (and preferred) when dynamo is True. + Note that dynamic_shapes is designed to be used when the model is exported with dynamo=True, while + dynamic_axes is used when dynamo=False. + custom_translation_table: A dictionary of custom decompositions for operators in the model. + The dictionary should have the callable target in the fx Node as the key (e.g. ``torch.ops.aten.stft.default``), + and the value should be a function that builds that graph using ONNX Script. This option + is only valid when dynamo is True. + report: Whether to generate a markdown report for the export process. This option + is only valid when dynamo is True. + optimize: Whether to optimize the exported model. This option + is only valid when dynamo is True. Default is True. + verify: Whether to verify the exported model using ONNX Runtime. This option + is only valid when dynamo is True. + profile: Whether to profile the export process. This option + is only valid when dynamo is True. + dump_exported_program: Whether to dump the :class:`torch.export.ExportedProgram` to a file. + This is useful for debugging the exporter. This option is only valid when dynamo is True. + artifacts_dir: The directory to save the debugging artifacts like the report and the serialized + exported program. This option is only valid when dynamo is True. + fallback: Whether to fallback to the TorchScript exporter if the dynamo exporter fails. + This option is only valid when dynamo is True. When fallback is enabled, It is + recommended to set dynamic_axes even when dynamic_shapes is provided. + + training: Deprecated option. Instead, set the training mode of the model before exporting. + operator_export_type: Deprecated option. Only ONNX is supported. + do_constant_folding: Deprecated option. + custom_opsets: Deprecated. + A dictionary: + + * KEY (str): opset domain name + * VALUE (int): opset version + + If a custom opset is referenced by ``model`` but not mentioned in this dictionary, + the opset version is set to 1. Only custom opset domain name and version should be + indicated through this argument. + export_modules_as_functions: Deprecated option. + + Flag to enable + exporting all ``nn.Module`` forward calls as local functions in ONNX. Or a set to indicate the + particular types of modules to export as local functions in ONNX. + This feature requires ``opset_version`` >= 15, otherwise the export will fail. This is because + ``opset_version`` < 15 implies IR version < 8, which means no local function support. + Module variables will be exported as function attributes. There are two categories of function + attributes. + + 1. Annotated attributes: class variables that have type annotations via + `PEP 526-style `_ + will be exported as attributes. + Annotated attributes are not used inside the subgraph of ONNX local function because + they are not created by PyTorch JIT tracing, but they may be used by consumers + to determine whether or not to replace the function with a particular fused kernel. + + 2. Inferred attributes: variables that are used by operators inside the module. Attribute names + will have prefix "inferred::". This is to differentiate from predefined attributes retrieved from + python module annotations. Inferred attributes are used inside the subgraph of ONNX local function. + + * ``False`` (default): export ``nn.Module`` forward calls as fine grained nodes. + * ``True``: export all ``nn.Module`` forward calls as local function nodes. + * Set of type of nn.Module: export ``nn.Module`` forward calls as local function nodes, + only if the type of the ``nn.Module`` is found in the set. + autograd_inlining: Deprecated. + Flag used to control whether to inline autograd functions. + Refer to https://github.com/pytorch/pytorch/pull/74765 for more details. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Returns: :class:`torch.onnx.ONNXProgram` if dynamo is True, otherwise None. @@ -275,8 +495,11 @@ def forward(self, x): *autograd_inlining* is now deprecated. .. versionchanged:: 2.7 *optimize* is now True by default. +<<<<<<< HEAD .. versionchanged:: 2.9 *dynamo* is now True by default. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ if dynamo is True or isinstance(model, torch.export.ExportedProgram): from torch.onnx._internal.exporter import _compat @@ -320,7 +543,11 @@ def forward(self, x): else: import warnings +<<<<<<< HEAD from ._internal.torchscript_exporter.utils import export +======= + from torch.onnx.utils import export +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) warnings.warn( "You are using the legacy TorchScript-based ONNX export. Starting in PyTorch 2.9, " @@ -362,9 +589,118 @@ def forward(self, x): return None +<<<<<<< HEAD def is_in_onnx_export() -> bool: """Returns whether it is in the middle of ONNX export.""" from torch.onnx._internal.exporter import _flags from torch.onnx._internal.torchscript_exporter._globals import GLOBALS +======= +@deprecated( + "torch.onnx.dynamo_export is deprecated since 2.7.0. Please use torch.onnx.export(..., dynamo=True) instead." +) +class ExportOptions: + """Options for dynamo_export. + + .. deprecated:: 2.7 + Please use ``torch.onnx.export(..., dynamo=True)`` instead. + + Attributes: + dynamic_shapes: Shape information hint for input/output tensors. + When ``None``, the exporter determines the most compatible setting. + When ``True``, all input shapes are considered dynamic. + When ``False``, all input shapes are considered static. + """ + + def __init__(self, *, dynamic_shapes: bool | None = None): + self.dynamic_shapes: bool | None = dynamic_shapes + + +@deprecated( + "torch.onnx.dynamo_export is deprecated since 2.7.0. Please use torch.onnx.export(..., dynamo=True) instead." +) +def dynamo_export( + model: torch.nn.Module | Callable | torch.export.ExportedProgram, # type: ignore[name-defined] + /, + *model_args, + export_options: ExportOptions | None = None, + **model_kwargs, +) -> ONNXProgram: + """Export a torch.nn.Module to an ONNX graph. + + .. deprecated:: 2.7 + Please use ``torch.onnx.export(..., dynamo=True)`` instead. + + Args: + model: The PyTorch model to be exported to ONNX. + model_args: Positional inputs to ``model``. + model_kwargs: Keyword inputs to ``model``. + export_options: Options to influence the export to ONNX. + + Returns: + An in-memory representation of the exported ONNX model. + """ + + import warnings + + from torch.onnx._internal.exporter import _compat + from torch.utils import _pytree + + if isinstance(model, torch.export.ExportedProgram): + return _compat.export_compat( + model, # type: ignore[arg-type] + model_args, + f=None, + kwargs=model_kwargs, + opset_version=18, + external_data=True, + export_params=True, + fallback=True, + ) + if export_options is not None: + warnings.warn( + "You are using an experimental ONNX export logic, which currently only supports dynamic shapes. " + "For a more comprehensive set of export options, including advanced features, please consider using " + "`torch.onnx.export(..., dynamo=True)`. ", + category=DeprecationWarning, + ) + + if export_options is not None and export_options.dynamic_shapes: + # Make all shapes dynamic if it's possible + def _to_dynamic_shape(x): + if isinstance(x, torch.Tensor): + rank = len(x.shape) + dynamic_shape = {} + for i in range(rank): + dynamic_shape[i] = torch.export.Dim.AUTO + return dynamic_shape + else: + return None + + # model_args could be nested + dynamic_shapes = _pytree.tree_map( + _to_dynamic_shape, + model_args, + ) + else: + dynamic_shapes = None + + return _compat.export_compat( + model, # type: ignore[arg-type] + model_args, + f=None, + kwargs=model_kwargs, + dynamic_shapes=dynamic_shapes, + opset_version=18, + external_data=True, + export_params=True, + fallback=True, + ) + + +def is_in_onnx_export() -> bool: + """Returns whether it is in the middle of ONNX export.""" + from torch.onnx._globals import GLOBALS + from torch.onnx._internal.exporter import _flags +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return GLOBALS.in_onnx_export or _flags._is_onnx_exporting diff --git a/torch/onnx/_constants.py b/torch/onnx/_constants.py index 87ff04da8cd1e..a1cdbbdc708c4 100644 --- a/torch/onnx/_constants.py +++ b/torch/onnx/_constants.py @@ -6,7 +6,11 @@ ONNX_MIN_OPSET = 7 ONNX_MAX_OPSET = 23 ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET = 20 +<<<<<<< HEAD ONNX_DEFAULT_OPSET = 20 +======= +ONNX_DEFAULT_OPSET = 18 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ONNX_CONSTANT_FOLDING_MIN_OPSET = 9 PYTORCH_GITHUB_ISSUES_URL = "https://github.com/pytorch/pytorch/issues" diff --git a/torch/onnx/_experimental.py b/torch/onnx/_experimental.py new file mode 100644 index 0000000000000..0fac4450a71c8 --- /dev/null +++ b/torch/onnx/_experimental.py @@ -0,0 +1,28 @@ +"""Experimental classes and functions used by ONNX export.""" + +import dataclasses +from collections.abc import Mapping, Sequence +from typing import Optional, Union + +import torch +import torch._C._onnx as _C_onnx + + +@dataclasses.dataclass +class ExportOptions: + """Arguments used by :func:`torch.onnx.export`.""" + + # TODO(justinchuby): Deprecate and remove this class. + + export_params: bool = True + verbose: bool = False + training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL + input_names: Optional[Sequence[str]] = None + output_names: Optional[Sequence[str]] = None + operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX + opset_version: Optional[int] = None + do_constant_folding: bool = True + dynamic_axes: Optional[Mapping[str, Union[Mapping[int, str], Sequence[int]]]] = None + keep_initializers_as_inputs: Optional[bool] = None + custom_opsets: Optional[Mapping[str, int]] = None + export_modules_as_functions: Union[bool, set[type[torch.nn.Module]]] = False diff --git a/torch/onnx/_flags.py b/torch/onnx/_flags.py index b88e3b3363f1d..b5f7d44e01961 100644 --- a/torch/onnx/_flags.py +++ b/torch/onnx/_flags.py @@ -43,8 +43,15 @@ def _load_boolean_flag( return state +<<<<<<< HEAD ENABLE_DRAFT_EXPORT: bool = _load_boolean_flag( "TORCH_ONNX_ENABLE_DRAFT_EXPORT", this_will="enable torch.export.draft_export as a strategy for capturing models", default=False, +======= +PLACEHOLDER: bool = _load_boolean_flag( + "TORCH_ONNX_PLACEHOLDER", + this_will="do nothing", + default=True, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) diff --git a/torch/onnx/_globals.py b/torch/onnx/_globals.py new file mode 100644 index 0000000000000..55d0550324e73 --- /dev/null +++ b/torch/onnx/_globals.py @@ -0,0 +1,82 @@ +# mypy: allow-untyped-defs +"""Globals used internally by the ONNX exporter. + +Do not use this module outside of `torch.onnx` and its tests. + +Be very judicious when adding any new global variables. Do not create new global +variables unless they are absolutely necessary. +""" + +import torch._C._onnx as _C_onnx + +# This module should only depend on _constants and nothing else in torch.onnx to keep +# dependency direction clean. +from torch.onnx import _constants + + +class _InternalGlobals: + """Globals used internally by ONNX exporter. + + NOTE: Be very judicious when adding any new variables. Do not create new + global variables unless they are absolutely necessary. + """ + + def __init__(self) -> None: + self._export_onnx_opset_version = _constants.ONNX_DEFAULT_OPSET + self._training_mode: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL + self._in_onnx_export: bool = False + # Whether the user's model is training during export + self.export_training: bool = False + self.operator_export_type: _C_onnx.OperatorExportTypes = ( + _C_onnx.OperatorExportTypes.ONNX + ) + self.onnx_shape_inference: bool = True + self._autograd_inlining: bool = True + + @property + def training_mode(self): + """The training mode for the exporter.""" + return self._training_mode + + @training_mode.setter + def training_mode(self, training_mode: _C_onnx.TrainingMode): + if not isinstance(training_mode, _C_onnx.TrainingMode): + raise TypeError( + "training_mode must be of type 'torch.onnx.TrainingMode'. This is " + "likely a bug in torch.onnx." + ) + self._training_mode = training_mode + + @property + def export_onnx_opset_version(self) -> int: + """Opset version used during export.""" + return self._export_onnx_opset_version + + @export_onnx_opset_version.setter + def export_onnx_opset_version(self, value: int): + self._export_onnx_opset_version = value + + @property + def in_onnx_export(self) -> bool: + """Whether it is in the middle of ONNX export.""" + return self._in_onnx_export + + @in_onnx_export.setter + def in_onnx_export(self, value: bool): + if type(value) is not bool: + raise TypeError("in_onnx_export must be a boolean") + self._in_onnx_export = value + + @property + def autograd_inlining(self) -> bool: + """Whether Autograd must be inlined.""" + return self._autograd_inlining + + @autograd_inlining.setter + def autograd_inlining(self, value: bool): + if type(value) is not bool: + raise TypeError("autograd_inlining must be a boolean") + self._autograd_inlining = value + + +GLOBALS = _InternalGlobals() diff --git a/torch/onnx/_internal/_exporter_legacy.py b/torch/onnx/_internal/_exporter_legacy.py new file mode 100644 index 0000000000000..b3150ef9cdeb3 --- /dev/null +++ b/torch/onnx/_internal/_exporter_legacy.py @@ -0,0 +1,496 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + + +__all__ = [ + "ExportOptions", + "ONNXRuntimeOptions", + "OnnxRegistry", + "enable_fake_mode", +] + + +import abc +import contextlib +import dataclasses +import logging +import warnings +from collections import defaultdict +from typing import Any, Callable, TYPE_CHECKING +from typing_extensions import deprecated + +import torch +import torch._ops +from torch.onnx._internal import io_adapter +from torch.onnx._internal._lazy_import import onnxscript_apis +from torch.onnx._internal.exporter import _constants +from torch.onnx._internal.fx import ( + decomposition_table, + patcher as patcher, + registration, +) + + +# We can only import onnx from this module in a type-checking context to ensure that +# 'import torch.onnx' continues to work without having 'onnx' installed. We fully +# 'import onnx' inside of dynamo_export (by way of _assert_dependencies). +if TYPE_CHECKING: + import io + from collections.abc import Mapping, Sequence + + import onnxruntime + import onnxscript + + from torch._subclasses import fake_tensor + +log = logging.getLogger(__name__) + + +@dataclasses.dataclass +class ONNXFakeContext: + """A dataclass used to store context for model export using FakeTensor. + + This dataclass stores the FakeTensorMode instance used to convert + real tensors and model parameters into fake tensors. This :attr:`ONNXFakeContext.fake_mode` is + reused internally during tracing of a :class:`torch.nn.Module` into a FX :class:`GraphModule`. + """ + + fake_mode: fake_tensor.FakeTensorMode + """The fake tensor mode used for tracing model using fake tensors and parameters.""" + + state_dict_paths: tuple[str | io.BytesIO | dict[str, Any]] | None = None + """List of paths of files that contain the model :meth:`state_dict`""" + + +@deprecated( + "torch.onnx.dynamo_export is deprecated since 2.7.0. Please use torch.onnx.export(..., dynamo=True) instead.", +) +class OnnxRegistry: + """Registry for ONNX functions. + + .. deprecated:: 2.7 + Please use ``torch.onnx.export(..., dynamo=True)`` instead. + + The registry maintains a mapping from qualified names to symbolic functions under a + fixed opset version. It supports registering custom onnx-script functions and for + dispatcher to dispatch calls to the appropriate function. + + """ + + def __init__(self) -> None: + """Initializes the registry""" + + # NOTE: _registry is the registry maps OpNameto a list of ONNXFunctions. It is important + # not to directly modify this variable. Instead, access to it should be done through + # the public methods: register_custom_op, get_ops, and is_registered_op. + self._registry: dict[registration.OpName, list[registration.ONNXFunction]] = ( + defaultdict(list) + ) + + self._opset_version = _constants.TORCHLIB_OPSET + warnings.warn( + f"torch.onnx.dynamo_export only implements opset version {self._opset_version} for now. If you need to use a " + "different opset version, please register them with register_custom_op." + ) + + self._initiate_registry_from_torchlib() + + @property + def opset_version(self) -> int: + """The ONNX opset version the exporter should target.""" + + return self._opset_version + + def _initiate_registry_from_torchlib(self) -> None: + """Populates the registry with ATen functions from torchlib. + + Args: + torchlib_registry: The torchlib registry to use for populating the registry. + """ + for meta in onnxscript_apis.get_torchlib_ops(): + internal_name_instance = registration.OpName.from_qualified_name( + meta.qualified_name + ) + symbolic_function = registration.ONNXFunction( + onnx_function=meta.function, # type: ignore[arg-type] + op_full_name=internal_name_instance.qualified_name(), + is_custom=False, + is_complex=meta.is_complex, + ) + self._register(internal_name_instance, symbolic_function) + + def _register( + self, + internal_qualified_name: registration.OpName, + symbolic_function: registration.ONNXFunction, + ) -> None: + """Registers a ONNXFunction to an operator. + + Args: + internal_qualified_name: The qualified name of the operator to register: OpName. + symbolic_function: The ONNXFunction to register. + """ + self._registry[internal_qualified_name].append(symbolic_function) + + def register_op( + self, + function: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction, + namespace: str, + op_name: str, + overload: str | None = None, + is_complex: bool = False, + ) -> None: + """Registers a custom operator: torch.ops.... + + Args: + function: The onnx-sctip function to register. + namespace: The namespace of the operator to register. + op_name: The name of the operator to register. + overload: The overload of the operator to register. If it's default overload, + leave it to None. + is_complex: Whether the function is a function that handles complex valued inputs. + + Raises: + ValueError: If the name is not in the form of 'namespace::op'. + """ + internal_name_instance = registration.OpName.from_name_parts( + namespace=namespace, op_name=op_name, overload=overload + ) + symbolic_function = registration.ONNXFunction( + onnx_function=function, + op_full_name=internal_name_instance.qualified_name(), + is_custom=True, + is_complex=is_complex, + ) + self._register(internal_name_instance, symbolic_function) + + def get_op_functions( + self, namespace: str, op_name: str, overload: str | None = None + ) -> list[registration.ONNXFunction] | None: + """Returns a list of ONNXFunctions for the given op: torch.ops.... + + The list is ordered by the time of registration. The custom operators should be + in the second half of the list. + + Args: + namespace: The namespace of the operator to get. + op_name: The name of the operator to get. + overload: The overload of the operator to get. If it's default overload, + leave it to None. + Returns: + A list of ONNXFunctions corresponding to the given name, or None if + the name is not in the registry. + """ + internal_name_instance = registration.OpName.from_name_parts( + namespace=namespace, op_name=op_name, overload=overload + ) + return self._registry.get(internal_name_instance) + + def is_registered_op( + self, namespace: str, op_name: str, overload: str | None = None + ) -> bool: + """Returns whether the given op is registered: torch.ops.... + + Args: + namespace: The namespace of the operator to check. + op_name: The name of the operator to check. + overload: The overload of the operator to check. If it's default overload, + leave it to None. + + Returns: + True if the given op is registered, otherwise False. + """ + functions = self.get_op_functions( + namespace=namespace, op_name=op_name, overload=overload + ) + return functions is not None + + def _all_registered_ops(self) -> set[str]: + """Returns the set of all registered function names.""" + return { + op_name_class.qualified_name() for op_name_class in self._registry.keys() + } + + +@deprecated( + "torch.onnx.dynamo_export is deprecated since 2.7.0. Please use torch.onnx.export(..., dynamo=True) instead.", + category=None, +) +class ExportOptions: + """Options to influence the TorchDynamo ONNX exporter. + + .. deprecated:: 2.7 + Please use ``torch.onnx.export(..., dynamo=True)`` instead. + + Attributes: + dynamic_shapes: Shape information hint for input/output tensors. + When ``None``, the exporter determines the most compatible setting. + When ``True``, all input shapes are considered dynamic. + When ``False``, all input shapes are considered static. + fake_context: The fake context used for symbolic tracing. + onnx_registry: The ONNX registry used to register ATen operators to ONNX functions. + """ + + def __init__( + self, + *, + dynamic_shapes: bool | None = True, + fake_context: ONNXFakeContext | None = None, + onnx_registry: OnnxRegistry | None = None, + ): + self.dynamic_shapes = dynamic_shapes + self.fake_context = fake_context + self.onnx_registry = onnx_registry + + +@deprecated( + "torch.onnx.dynamo_export is deprecated since 2.7.0. Please use torch.onnx.export(..., dynamo=True) instead.", + category=None, +) +class ResolvedExportOptions(ExportOptions): + """Consolidates :class:`ExportOptions` with default values. + All unspecified options from :class:`ExportOptions` are assigned a default value. + This is an internal class and its API may be changed at any time without notice. + """ + + def __init__(self): + from torch.onnx._internal.fx import ( + dynamo_graph_extractor, + onnxfunction_dispatcher, + ) + + self.dynamic_shapes: bool = True + self.fx_tracer: dynamo_graph_extractor.DynamoExport = ( + dynamo_graph_extractor.DynamoExport() + ) + self.fake_context = None + self.onnx_registry: OnnxRegistry = OnnxRegistry() + self.decomposition_table = ( + decomposition_table.create_onnx_friendly_decomposition_table( # type: ignore[assignment] + self.onnx_registry + ) + ) + self.onnxfunction_dispatcher = onnxfunction_dispatcher.OnnxFunctionDispatcher( + self.onnx_registry, + ) + + +@contextlib.contextmanager +def enable_fake_mode(): + """Enable fake mode for the duration of the context. + + Internally it instantiates a :class:`torch._subclasses.fake_tensor.FakeTensorMode` context manager + that converts user input and model parameters into :class:`torch._subclasses.fake_tensor.FakeTensor`. + + A :class:`torch._subclasses.fake_tensor.FakeTensor` + is a :class:`torch.Tensor` with the ability to run PyTorch code without having to + actually do computation through tensors allocated on a ``meta`` device. Because + there is no actual data being allocated on the device, this API allows for + initializing and exporting large models without the actual memory footprint needed for executing it. + + It is highly recommended to initialize the model in fake mode when exporting models that + are too large to fit into memory. + + .. note:: + This function does not support torch.onnx.export(..., dynamo=True, optimize=True). + Please call ONNXProgram.optimize() outside of the function after the model is exported. + + Example:: + + # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX) + >>> import torch + >>> class MyModel(torch.nn.Module): # Model with a parameter + ... def __init__(self) -> None: + ... super().__init__() + ... self.weight = torch.nn.Parameter(torch.tensor(42.0)) + ... def forward(self, x): + ... return self.weight + x + >>> with torch.onnx.enable_fake_mode(): + ... # When initialized in fake mode, the model's parameters are fake tensors + ... # They do not take up memory so we can initialize large models + ... my_nn_module = MyModel() + ... arg1 = torch.randn(2, 2, 2) + >>> onnx_program = torch.onnx.export(my_nn_module, (arg1,), dynamo=True, optimize=False) + >>> # Saving model WITHOUT initializers (only the architecture) + >>> onnx_program.save( + ... "my_model_without_initializers.onnx", + ... include_initializers=False, + ... keep_initializers_as_inputs=True, + ... ) + >>> # Saving model WITH initializers after applying concrete weights + >>> onnx_program.apply_weights({"weight": torch.tensor(42.0)}) + >>> onnx_program.save("my_model_with_initializers.onnx") + + .. warning:: + This API is experimental and is *NOT* backward-compatible. + + """ + from torch._subclasses import fake_tensor + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + # This overrides the internal `FakeTensorMode` instance created by `torch._dynamo.export`[1]. + # It is a good idea to keep them in sync (constructor args) to maintain the same default behavior + # [1] `torch/_dynamo/output_graph.py::InstructionTranslator::OutputGraph.__init__` + # Mixed fake/real tensors are only allowed when `torch.onnx.dynamo_export` is not called within `FakeTensorMode` + # This is needed because models can create new parameters during `forward(self, *args, **kwargs)` run + fake_mode = fake_tensor.FakeTensorMode( + allow_non_fake_inputs=not torch._guards.detect_fake_mode(), + shape_env=ShapeEnv( + allow_scalar_outputs=False, allow_dynamic_output_shape_ops=False + ), + ) + # The patcher is needed for when user calls `fake_model.load_state_dict(...)` within fake mode + patcher_context = patcher.ONNXTorchPatcher() + fake_context = ONNXFakeContext(fake_mode=fake_mode) + with fake_mode, patcher_context: + yield fake_context + fake_context.state_dict_paths = tuple( + patcher_context.paths, + ) # type: ignore[assignment] + + +@deprecated( + "torch.onnx.dynamo_export is deprecated since 2.7.0. Please use torch.onnx.export(..., dynamo=True) instead.", +) +class ONNXRuntimeOptions: + """Options to influence the execution of the ONNX model through ONNX Runtime. + + .. deprecated:: 2.7 + Please use ``torch.onnx.export(..., dynamo=True)`` instead. + + Attributes: + session_options: ONNX Runtime session options. + execution_providers: ONNX Runtime execution providers to use during model execution. + execution_provider_options: ONNX Runtime execution provider options. + """ + + session_options: Sequence[onnxruntime.SessionOptions] | None = None + """ONNX Runtime session options.""" + + execution_providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None + """ONNX Runtime execution providers to use during model execution.""" + + execution_provider_options: Sequence[dict[Any, Any]] | None = None + """ONNX Runtime execution provider options.""" + + def __init__( + self, + *, + session_options: Sequence[onnxruntime.SessionOptions] | None = None, + execution_providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None, + execution_provider_options: Sequence[dict[Any, Any]] | None = None, + ): + self.session_options = session_options + self.execution_providers = execution_providers + self.execution_provider_options = execution_provider_options + + +class FXGraphExtractor(abc.ABC): + """Abstract interface for FX graph extractor engines. + This class isolates FX extraction logic from the rest of the export logic. + That allows a single ONNX exporter that can leverage different FX graphs.""" + + def __init__(self) -> None: + super().__init__() + self.input_adapter: io_adapter.InputAdapter = io_adapter.InputAdapter() + self.output_adapter: io_adapter.OutputAdapter = io_adapter.OutputAdapter() + + @abc.abstractmethod + def generate_fx( + self, + options: ResolvedExportOptions, + model: torch.nn.Module | Callable, + model_args: Sequence[Any], + model_kwargs: Mapping[str, Any], + ) -> torch.fx.GraphModule: + """Analyzes user ``model`` and generates a FX graph. + Args: + options: The export options. + model: The user model. + model_args: The model's positional input arguments. + model_kwargs: The model's keyword input arguments. + Returns: + The generated FX Graph. + """ + ... + + # TODO: Design the passes API + @abc.abstractmethod + def pre_export_passes( + self, + options: ResolvedExportOptions, + original_model: torch.nn.Module | Callable, + fx_module: torch.fx.GraphModule, + fx_module_args: Sequence[Any], + ): + """Applies pre-export passes to the FX graph. + + Pre-export passes are FX-to-FX graph transformations that make the graph + more palatable for the FX-to-ONNX conversion. + For example, it can be used to flatten model input/output, add explicit + casts to the graph, replace/decompose operators, functionalize the graph, etc. + """ + ... + + +def common_pre_export_passes( + options: ResolvedExportOptions, + original_model: torch.nn.Module | Callable, + fx_module: torch.fx.GraphModule, + fx_module_args: Sequence[Any], +): + # TODO: Import here to prevent circular dependency + from torch.onnx._internal.fx import passes + + # Apply decomposition table to the input graph. + module = passes.Decompose( + fx_module, + options.decomposition_table, # type: ignore[arg-type] + enable_dynamic_axes=options.dynamic_shapes, + allow_fake_constant=options.fake_context is not None, + ).run(*fx_module_args) + + # ONNX does not support views and mutations. + # Functionalize to get a semantically equivalent graph without mutations. + module = passes.Functionalize( + module, + enable_dynamic_axes=options.dynamic_shapes, + allow_fake_constant=options.fake_context is not None, + ).run(*fx_module_args) + + # Input mutations are detected and distilled after `Functionalize` pass. + # Remove them since ONNX inference does not need them. + module = passes.RemoveInputMutation(module).run(*fx_module_args) + + # ONNX does not support concept of (implicit) type promotion. + # Insert type casts explicitly where needed. + module = passes.InsertTypePromotion(module).run() + + if isinstance(original_model, torch.nn.Module): + module = passes.RestoreParameterAndBufferNames(module, original_model).run() + + # ONNX does not support None inputs. During graph building, all None inputs + # are removed. Here we register this step to input adapter. + options.fx_tracer.input_adapter.append_step(io_adapter.RemoveNoneInputStep()) + + # NOTE: temp workaround for https://github.com/pytorch/pytorch/issues/99534 + # Dynamo doesn't support non-tensor inputs. + options.fx_tracer.input_adapter.append_step(io_adapter.RemoveNonTensorInputStep()) + + # ONNX does not support complex inputs. During graph building, all complex inputs + # are converted to real representation inputs. Here we register this step to + # input/output adapter. + options.fx_tracer.input_adapter.append_step( + io_adapter.ConvertComplexToRealRepresentationInputStep() + ) + + # ONNX can't represent collection types (e.g., dictionary, tuple of tuple of + # tensor, etc), we flatten the collection and register each element as output. + options.fx_tracer.output_adapter.append_step(io_adapter.FlattenOutputStep()) + + # Output post-processing steps should happen after `FlattenOutputStep`. + options.fx_tracer.output_adapter.append_step( + io_adapter.ConvertComplexToRealRepresentationOutputStep() + ) + + return module diff --git a/torch/onnx/_internal/_lazy_import.py b/torch/onnx/_internal/_lazy_import.py index 5e2340fe4c42d..84c0cb71fe3ca 100644 --- a/torch/onnx/_internal/_lazy_import.py +++ b/torch/onnx/_internal/_lazy_import.py @@ -30,7 +30,11 @@ def __getattr__(self, attr: str) -> object: import onnx import onnx_ir # type: ignore[import-untyped] import onnxscript +<<<<<<< HEAD import onnxscript._framework_apis.torch_2_9 as onnxscript_apis +======= + import onnxscript._framework_apis.torch_2_8 as onnxscript_apis +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) onnxscript_ir = onnx_ir @@ -38,4 +42,8 @@ def __getattr__(self, attr: str) -> object: onnx = _LazyModule("onnx") onnxscript = _LazyModule("onnxscript") onnxscript_ir = _LazyModule("onnx_ir") +<<<<<<< HEAD onnxscript_apis = _LazyModule("onnxscript._framework_apis.torch_2_9") +======= + onnxscript_apis = _LazyModule("onnxscript._framework_apis.torch_2_8") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/onnx/_internal/exporter/_analysis.py b/torch/onnx/_internal/exporter/_analysis.py index 53860413526ee..ead68e6a699d1 100644 --- a/torch/onnx/_internal/exporter/_analysis.py +++ b/torch/onnx/_internal/exporter/_analysis.py @@ -159,6 +159,7 @@ def _get_io_specs(exported_program: torch.export.ExportedProgram) -> tuple[dict, for spec in exported_program.graph_signature.output_specs if spec.kind == graph_signature.OutputKind.USER_OUTPUT ] +<<<<<<< HEAD inputs: dict[str, torch._export.serde.schema.TensorMeta | str] = {} outputs: dict[str, torch._export.serde.schema.TensorMeta | str] = {} for spec in user_inputs: @@ -201,6 +202,24 @@ def _log_spec_into_io_specs( return inputs_or_outputs +======= + inputs: dict[str, torch._export.serde.schema.TensorMeta] = {} + outputs: dict[str, torch._export.serde.schema.TensorMeta] = {} + for spec in user_inputs: + if isinstance(spec.arg, graph_signature.ConstantArgument): + continue + name = spec.arg.name + # FIXME: tensor_meta is None sometimes when the exported program still knows the shape/type + inputs[name] = nodes[name].meta["tensor_meta"] + for spec in user_outputs: + if isinstance(spec.arg, graph_signature.ConstantArgument): + continue + name = spec.arg.name + outputs[name] = nodes[name].meta["tensor_meta"] + return inputs, outputs + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _count_fx_targets( exported_program: torch.export.ExportedProgram, ) -> defaultdict[str, int]: diff --git a/torch/onnx/_internal/exporter/_capture_strategies.py b/torch/onnx/_internal/exporter/_capture_strategies.py index 89a2b7e9e5e2f..bae5b11ff2bad 100644 --- a/torch/onnx/_internal/exporter/_capture_strategies.py +++ b/torch/onnx/_internal/exporter/_capture_strategies.py @@ -12,7 +12,11 @@ from typing import Any, Callable, TYPE_CHECKING import torch +<<<<<<< HEAD from torch.onnx import _flags +======= +from torch.export import _draft_export +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if TYPE_CHECKING: @@ -251,7 +255,11 @@ class TorchExportDraftExportStrategy(CaptureStrategy): def _capture( self, model, args, kwargs, dynamic_shapes ) -> torch.export.ExportedProgram: +<<<<<<< HEAD ep = torch.export.draft_export( +======= + ep = _draft_export.draft_export( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes ) report = ep._report # type: ignore[attr-defined] @@ -263,19 +271,28 @@ def _capture( def _enter(self, model) -> None: model_repr = _take_first_line(repr(model)) self._verbose_print( +<<<<<<< HEAD f"Obtain model graph for `{model_repr}` with `torch.export.draft_export`..." +======= + f"Obtain model graph for `{model_repr}` with `torch.export draft_export`..." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def _success(self, model) -> None: model_repr = _take_first_line(repr(model)) self._verbose_print( +<<<<<<< HEAD f"Obtain model graph for `{model_repr}` with `torch.export.draft_export`... āœ…" +======= + f"Obtain model graph for `{model_repr}` with `torch.export draft_export`... āœ…" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def _failure(self, model, e) -> None: del e # Unused model_repr = _take_first_line(repr(model)) self._verbose_print( +<<<<<<< HEAD f"Obtain model graph for `{model_repr}` with `torch.export.draft_export`... āŒ" ) @@ -287,3 +304,14 @@ def _failure(self, model, e) -> None: if _flags.ENABLE_DRAFT_EXPORT: CAPTURE_STRATEGIES = (*CAPTURE_STRATEGIES, TorchExportDraftExportStrategy) +======= + f"Obtain model graph for `{model_repr}` with `torch.export draft_export`... āŒ" + ) + + +CAPTURE_STRATEGIES = ( + TorchExportNonStrictStrategy, # strict=False is preferred over strict=True because it does not have dynamo issues + TorchExportStrictStrategy, + TorchExportDraftExportStrategy, +) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/onnx/_internal/exporter/_compat.py b/torch/onnx/_internal/exporter/_compat.py index 0bc0c6182fca0..22046955dac6c 100644 --- a/torch/onnx/_internal/exporter/_compat.py +++ b/torch/onnx/_internal/exporter/_compat.py @@ -4,15 +4,22 @@ # mypy: disable-error-code=attr-defined from __future__ import annotations +<<<<<<< HEAD import io +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import logging import warnings from collections.abc import Mapping, Sequence from typing import Any, Callable, TYPE_CHECKING import torch +<<<<<<< HEAD from torch.onnx import _constants as onnx_constants from torch.onnx._internal._lazy_import import onnx, onnxscript_apis, onnxscript_ir as ir +======= +from torch.onnx._internal._lazy_import import onnxscript_apis, onnxscript_ir as ir +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.onnx._internal.exporter import ( _constants, _core, @@ -52,7 +59,11 @@ def export_compat( verbose: bool | None = None, input_names: Sequence[str] | None = None, output_names: Sequence[str] | None = None, +<<<<<<< HEAD opset_version: int | None = onnx_constants.ONNX_DEFAULT_OPSET, +======= + opset_version: int | None = _constants.TORCHLIB_OPSET, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) custom_translation_table: dict[Callable, Callable | Sequence[Callable]] | None = None, dynamic_axes: Mapping[str, Mapping[int, str]] @@ -62,7 +73,11 @@ def export_compat( keep_initializers_as_inputs: bool = False, external_data: bool = True, report: bool = False, +<<<<<<< HEAD optimize: bool = True, +======= + optimize: bool = False, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) verify: bool = False, profile: bool = False, dump_exported_program: bool = False, @@ -72,7 +87,11 @@ def export_compat( legacy_export_kwargs: dict[str, Any] | None = None, ) -> _onnx_program.ONNXProgram: if opset_version is None: +<<<<<<< HEAD opset_version = onnx_constants.ONNX_DEFAULT_OPSET +======= + opset_version = _constants.TORCHLIB_OPSET +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(model, torch.export.ExportedProgram): # We know the model is already exported program, so the args, kwargs, and dynamic_shapes @@ -109,6 +128,7 @@ def export_compat( dynamic_shapes_with_export_dim, need_axis_mapping = ( _dynamic_shapes.convert_str_to_export_dim(dynamic_shapes) ) +<<<<<<< HEAD if opset_version < _constants.TORCHLIB_OPSET: logger.warning( @@ -130,6 +150,9 @@ def export_compat( registry = _registration.ONNXRegistry().from_torchlib( opset_version=registry_opset_version ) +======= + registry = _registration.ONNXRegistry().from_torchlib(opset_version=opset_version) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if custom_translation_table is not None: for torch_op, onnx_ops in custom_translation_table.items(): # TODO(justinchuby): Support complex inputs with annotations @@ -212,6 +235,7 @@ def export_compat( onnx_program.optimize() if f is not None: +<<<<<<< HEAD if isinstance(f, io.BytesIO): # For legacy export compatibility, we allow f to be a BytesIO object. # This is not explicitly supported but we may need to maintain the @@ -230,5 +254,13 @@ def export_compat( keep_initializers_as_inputs=keep_initializers_as_inputs, external_data=external_data, ) +======= + onnx_program.save( + f, + include_initializers=export_params, + keep_initializers_as_inputs=keep_initializers_as_inputs, + external_data=external_data, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return onnx_program diff --git a/torch/onnx/_internal/exporter/_core.py b/torch/onnx/_internal/exporter/_core.py index 33a19d629388d..7a4c37235f340 100644 --- a/torch/onnx/_internal/exporter/_core.py +++ b/torch/onnx/_internal/exporter/_core.py @@ -79,7 +79,11 @@ f"""\ Failed to export the model with torch.export. {_BLUE}This is step 1/3{_END} of exporting the model to ONNX. Next steps: - Modify the model code for `torch.export.export` to succeed. Refer to https://pytorch.org/docs/stable/generated/exportdb/index.html for more information. +<<<<<<< HEAD - Debug `torch.export.export` and submit a PR to PyTorch. +======= + - Debug `torch.export.export` and summit a PR to PyTorch. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) - Create an issue in the PyTorch GitHub repository against the {_BLUE}*torch.export*{_END} component and attach the full error stack as well as reproduction scripts.""" ) @@ -726,12 +730,15 @@ def _handle_output_node( # node.args[0] can be a tuple with more than one elements. This happens when, # for example, a subgraph has multiple outputs. We flatten them all as ONNX graph outputs for output in node.args[0]: # type: ignore[index,union-attr] +<<<<<<< HEAD if output is None: logger.warning( "Output node %s has None output. The output is ignored in the exported graph. Please ensure the graph output order is expected", node.name, ) continue +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) output_value_name = output.name # type: ignore[union-attr] assert isinstance(output_value_name, str), ( f"Bug: Expected {output_value_name!r} to be a string" @@ -954,7 +961,11 @@ def exported_program_to_ir( Args: exported_program: The exported program to convert. lower: Whether to lower the graph to core ONNX operators. +<<<<<<< HEAD at_conversion: Lower when translating the FX graph to ONNX IR. +======= + at_conversion: Lower whe translating the FX graph to ONNX IR. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) none: Do not lower the graph. registry: The registry of all ONNX Script decomposition. """ @@ -1032,7 +1043,11 @@ def _exported_program_to_onnx_program( exported_program: The exported program to convert. The exported program should be the one that is after decompositions have been applied. lower: Whether to lower the graph to core ONNX operators. +<<<<<<< HEAD at_conversion: Lower when translating the FX graph to ONNX IR. +======= + at_conversion: Lower whe translating the FX graph to ONNX IR. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) none: Do not lower the graph. registry: The registry of all ONNX Script decomposition. """ diff --git a/torch/onnx/_internal/exporter/_flags.py b/torch/onnx/_internal/exporter/_flags.py index 0f07508f831ec..11d548114bb07 100644 --- a/torch/onnx/_internal/exporter/_flags.py +++ b/torch/onnx/_internal/exporter/_flags.py @@ -3,12 +3,17 @@ from __future__ import annotations import functools +<<<<<<< HEAD from typing import Callable, TypeVar from typing_extensions import ParamSpec +======= +from typing import Any, Callable, cast, TypeVar +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _is_onnx_exporting = False +<<<<<<< HEAD # Use ParamSpec to preserve parameter types instead of erasing to Any _P = ParamSpec("_P") _R = TypeVar("_R") @@ -17,6 +22,14 @@ def set_onnx_exporting_flag(func: Callable[_P, _R]) -> Callable[_P, _R]: @functools.wraps(func) def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: +======= +TCallable = TypeVar("TCallable", bound=Callable[..., Any]) + + +def set_onnx_exporting_flag(func: TCallable) -> TCallable: + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) global _is_onnx_exporting _is_onnx_exporting = True try: @@ -25,4 +38,8 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: # Ensure it resets even if an exception occurs _is_onnx_exporting = False +<<<<<<< HEAD return wrapper +======= + return cast(TCallable, wrapper) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/onnx/_internal/exporter/_fx_passes.py b/torch/onnx/_internal/exporter/_fx_passes.py index 98359f2ebaff1..b890d2bf48eb9 100644 --- a/torch/onnx/_internal/exporter/_fx_passes.py +++ b/torch/onnx/_internal/exporter/_fx_passes.py @@ -37,9 +37,16 @@ def remove_assertion_nodes(graph_module: torch.fx.GraphModule) -> torch.fx.Graph torch.ops.aten._assert_scalar.default, torch.ops.aten._assert_tensor_metadata.default, } +<<<<<<< HEAD for gm in graph_module.modules(): for node in gm.graph.nodes: # type: ignore[union-attr] if node.op == "call_function" and node.target in aten_assertion_targets: gm.graph.erase_node(node) # type: ignore[operator, union-attr] gm.recompile() # type: ignore[operator] +======= + for node in graph_module.graph.nodes: + if node.op == "call_function" and node.target in aten_assertion_targets: + graph_module.graph.erase_node(node) + graph_module.recompile() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return graph_module diff --git a/torch/onnx/_internal/exporter/_onnx_program.py b/torch/onnx/_internal/exporter/_onnx_program.py index 35d51e8329499..e230c79e9748f 100644 --- a/torch/onnx/_internal/exporter/_onnx_program.py +++ b/torch/onnx/_internal/exporter/_onnx_program.py @@ -13,7 +13,10 @@ import tempfile import textwrap import warnings +<<<<<<< HEAD from collections.abc import Sequence +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from typing import Any, Callable, TYPE_CHECKING import torch @@ -26,7 +29,12 @@ # because ONNXProgram is exposed to the public API if TYPE_CHECKING: +<<<<<<< HEAD import numpy as np +======= + from collections.abc import Sequence + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import onnxruntime as ort _LARGE_MODEL_THRESHOLD = 1536 * 1024 * 1024 # 1536MB @@ -117,6 +125,7 @@ def _create_value_mapping(graph: ir.Graph) -> dict[str, ir.Value]: return values +<<<<<<< HEAD def _to_numpy_array(input: torch.Tensor | int | float | str | bool) -> np.ndarray: if isinstance(input, (int, float, str, bool)): return ir.tensor(input).numpy() @@ -147,10 +156,15 @@ def _from_numpy_array(array: np.ndarray) -> torch.Tensor: def _to_ort_value(input: torch.Tensor | int | float | str | bool) -> ort.OrtValue: """Convert a PyTorch tensor to an ONNX Runtime OrtValue.""" import numpy as np +======= +def _to_ort_value(tensor: torch.Tensor) -> ort.OrtValue: + """Convert a PyTorch tensor to an ONNX Runtime OrtValue.""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import onnxruntime as ort from torch.onnx._internal.exporter import _core +<<<<<<< HEAD if isinstance(input, (int, float, str, bool)): # Convert scalar values to OrtValue dtype_mapping = { @@ -179,6 +193,27 @@ def _to_ort_value(input: torch.Tensor | int | float | str | bool) -> ort.OrtValu ) # TODO(#151064): Use dlpack when ORT properly supports it return ort.OrtValue.ortvalue_from_numpy(input.numpy(force=True)) +======= + if tensor.dtype == torch.bfloat16 or tensor.dtype in _NP_UNSUPPORTED_DTYPES_8BIT: + if hasattr(ort.OrtValue, "ortvalue_from_numpy_with_onnx_type"): + # This requires ONNX Runtime 1.21 or newer + if tensor.dtype == torch.bfloat16: + uint_type = torch.uint16 + else: + uint_type = torch.uint8 + onnx_type = _core.torch_dtype_to_onnx_dtype(tensor.dtype) + # Make tensor contiguous to ensure view() works + tensor = tensor.contiguous() + return ort.OrtValue.ortvalue_from_numpy_with_onnx_type( + tensor.view(uint_type).numpy(force=True), onnx_element_type=onnx_type + ) + raise RuntimeError( + f"Failed to convert tensor of type '{tensor.dtype}' to OrtValue. " + "Please ensure that ONNX Runtime is built with DLPack support or is the latest version" + ) + # TODO(#151064): Use dlpack when ORT properly supports it + return ort.OrtValue.ortvalue_from_numpy(tensor.numpy(force=True)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _from_ort_value(value: ort.OrtValue) -> torch.Tensor: @@ -245,6 +280,10 @@ def __call__(self, *args, **kwargs) -> Sequence[torch.Tensor]: assert self._inference_session is not None +<<<<<<< HEAD +======= + # We don't expect non-tensor as inputs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ort_input = { k.name: _to_ort_value(v) for k, v in zip(self.model.graph.inputs, flatten_args) @@ -258,6 +297,7 @@ def __call__(self, *args, **kwargs) -> Sequence[torch.Tensor]: logger.debug("Inference session run completed.") return tuple(_from_ort_value(output) for output in outputs) +<<<<<<< HEAD def call_reference(self, *args, **kwargs) -> Sequence[torch.Tensor]: """Run the ONNX model using the reference backend.""" import onnx.reference @@ -273,6 +313,8 @@ def call_reference(self, *args, **kwargs) -> Sequence[torch.Tensor]: assert isinstance(outputs, Sequence) return tuple(_from_numpy_array(output) for output in outputs) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def compute_values( self, value_names: Sequence[str], args=(), kwargs=None ) -> Sequence[torch.Tensor]: @@ -465,6 +507,10 @@ def _process_args(args, kwargs) -> tuple[torch.Tensor, ...]: """Process input arguments for the ONNX model.""" args = _flatten_inputs(args, kwargs) args = _remove_none_from_inputs(args) +<<<<<<< HEAD +======= + args = _remove_non_tensor(args) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) args = _convert_complex_to_real_representation(args) return args @@ -478,6 +524,50 @@ def _remove_none_from_inputs(model_args): return tuple(arg for arg in model_args if arg is not None) +<<<<<<< HEAD +======= +def _remove_non_tensor(model_args): + """Remove the non-tensor input arguments. + + Dynamo does not support non-tensor input arguments (https://github.com/pytorch/pytorch/issues/99534). + + Specifically, it does put the input into graph with an empty node, but consumed by no ones. + The concrete value is embedded into the graph as a constant arg of a target node. Meta + suggests in this case that one should rewrite the model code to make it tensor if the + input value is supposed to change at runtime. We might need to further investigate + the feasibility of that suggestion. + + For example, + + def func(x, b=1.0): + y = x + b + z = y.relu() + return (y, z) + + x = torch.randn(1, 1, 2, dtype=torch.float32) + gm_fun, _ = dynamo.export(func, x, b=8.0, aten_graph=True, tracing_mode="real") + + # class GraphModule(torch.nn.Module): + # def forward(self, x, b): + # arg0: f32[1, 1, 2], arg1, = fx_pytree.tree_flatten_spec(([x, b], {}), self._in_spec) + # # File: path/to/pytorch/test_constant_input.py:5, code: y = x + b + # add_tensor: f32[1, 1, 2] = torch.ops.aten.add.Tensor(arg0, 8.0); arg0 = None + + # # File: path/to/pytorch/test_constant_input.py:6, code: z = y.relu() + # relu_default: f32[1, 1, 2] = torch.ops.aten.relu.default(add_tensor) + # return pytree.tree_unflatten([add_tensor, relu_default], self._out_spec) + + Empty torch.fx.Node input leading to a mismatched number of input with PyTorch, as + it's ignored in ONNX graph. Thus, we delete the useless input here. + + """ + + return tuple( + arg for arg in model_args if not isinstance(arg, (int, float, bool, str)) + ) + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _convert_complex_to_real_representation(model_args): """Convert complex dtype tensors to real representation tensors. diff --git a/torch/onnx/_internal/exporter/_reporting.py b/torch/onnx/_internal/exporter/_reporting.py index e2e02e089c5d1..0a52776284884 100644 --- a/torch/onnx/_internal/exporter/_reporting.py +++ b/torch/onnx/_internal/exporter/_reporting.py @@ -22,7 +22,11 @@ class ExportStatus: torch_export_strict: bool | None = None # Whether torch.export.export(..., strict=False) succeeds torch_export_non_strict: bool | None = None +<<<<<<< HEAD # Whether torch.export.draft_export() succeeds +======= + # Whether torch.export._draft_export.draft_export() succeeds +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch_export_draft_export: bool | None = None # Whether decomposition succeeds decomposition: bool | None = None @@ -47,7 +51,11 @@ def _format_export_status(status: ExportStatus) -> str: f"```\n" f"{_status_emoji(status.torch_export_non_strict)} Obtain model graph with `torch.export.export(..., strict=False)`\n" f"{_status_emoji(status.torch_export_strict)} Obtain model graph with `torch.export.export(..., strict=True)`\n" +<<<<<<< HEAD f"{_status_emoji(status.torch_export_draft_export)} Obtain model graph with `torch.export.draft_export`\n" +======= + f"{_status_emoji(status.torch_export_draft_export)} Obtain model graph with `torch.export._draft_export.draft_export`\n" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f"{_status_emoji(status.decomposition)} Decompose operators for ONNX compatibility\n" f"{_status_emoji(status.onnx_translation)} Translate the graph into ONNX\n" f"{_status_emoji(status.onnx_checker)} Run `onnx.checker` on the ONNX model\n" diff --git a/torch/onnx/_internal/exporter/_testing.py b/torch/onnx/_internal/exporter/_testing.py index c34c2f1a38c3d..8a5f376d043b9 100644 --- a/torch/onnx/_internal/exporter/_testing.py +++ b/torch/onnx/_internal/exporter/_testing.py @@ -5,7 +5,11 @@ __all__ = ["assert_onnx_program"] +<<<<<<< HEAD from typing import Any, Literal, TYPE_CHECKING +======= +from typing import Any, TYPE_CHECKING +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch from torch.utils import _pytree @@ -23,7 +27,10 @@ def assert_onnx_program( args: tuple[Any, ...] | None = None, kwargs: dict[str, Any] | None = None, strategy: str | None = "TorchExportNonStrictStrategy", +<<<<<<< HEAD backend: Literal["onnxruntime", "reference"] = "onnxruntime", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> None: """Assert that the ONNX model produces the same output as the PyTorch ExportedProgram. @@ -38,8 +45,11 @@ def assert_onnx_program( strategy: Assert the capture strategy used to export the program. Values can be class names like "TorchExportNonStrictStrategy". If None, the strategy is not asserted. +<<<<<<< HEAD backend: The backend to use for evaluating the ONNX program. Supported values are "onnxruntime" and "reference". +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ if strategy is not None: if program._capture_strategy != strategy: @@ -71,15 +81,19 @@ class names like "TorchExportNonStrictStrategy". # ONNX outputs are always real, so we need to convert torch complex outputs to real representations torch_outputs_adapted = [] for output in torch_outputs: +<<<<<<< HEAD # ONNX graph does not support None outputs, so we skip them if output is None: continue +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not isinstance(output, torch.Tensor): torch_outputs_adapted.append(torch.tensor(output)) elif torch.is_complex(output): torch_outputs_adapted.append(torch.view_as_real(output)) else: torch_outputs_adapted.append(output) +<<<<<<< HEAD # Obtain the ONNX outputs using the specified backend if backend == "onnxruntime": @@ -91,6 +105,9 @@ class names like "TorchExportNonStrictStrategy". f"Unsupported backend '{backend}'. Supported backends are 'onnxruntime' and 'reference'." ) +======= + onnx_outputs = program(*args, **kwargs) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # TODO(justinchuby): Include output names in the error message torch.testing.assert_close( tuple(onnx_outputs), diff --git a/torch/onnx/_internal/exporter/_torchlib/_torchlib_registry.py b/torch/onnx/_internal/exporter/_torchlib/_torchlib_registry.py index 8c045d11a2b8f..43a1c9be53914 100644 --- a/torch/onnx/_internal/exporter/_torchlib/_torchlib_registry.py +++ b/torch/onnx/_internal/exporter/_torchlib/_torchlib_registry.py @@ -10,7 +10,10 @@ import logging from collections.abc import Sequence from typing import Any, Callable, TypeVar +<<<<<<< HEAD from typing_extensions import ParamSpec +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import onnxscript @@ -18,9 +21,13 @@ from torch.onnx._internal.exporter import _constants, _registration +<<<<<<< HEAD # Use ParamSpec for better type preservation instead of bound Callable TypeVar _P = ParamSpec("_P") _R = TypeVar("_R") +======= +_T = TypeVar("_T", bound=Callable) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) logger = logging.getLogger("__name__") @@ -36,7 +43,11 @@ def onnx_impl( opset_introduced: int = 18, no_compile: bool = False, private: bool = False, +<<<<<<< HEAD ) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]: +======= +) -> Callable[[_T], _T]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Register an ONNX implementation of a torch op.""" if isinstance(target, torch._ops.OpOverloadPacket): @@ -47,8 +58,13 @@ def onnx_impl( ) def wrapper( +<<<<<<< HEAD func: Callable[_P, _R], ) -> Callable[_P, _R]: +======= + func: _T, + ) -> _T: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) processed_func: Any if no_compile: processed_func = func diff --git a/torch/onnx/_internal/exporter/_torchlib/ops/nn.py b/torch/onnx/_internal/exporter/_torchlib/ops/nn.py index 90815bc18d6e3..ee8787585f116 100644 --- a/torch/onnx/_internal/exporter/_torchlib/ops/nn.py +++ b/torch/onnx/_internal/exporter/_torchlib/ops/nn.py @@ -58,6 +58,7 @@ def aten_group_norm( ) +<<<<<<< HEAD @onnx_impl(aten.rms_norm.default, trace_only=True, opset_introduced=23) def aten_rms_norm( input: TFloat, @@ -85,6 +86,8 @@ def aten_rms_norm( return op23.RMSNormalization(input, weight, axis=axis, epsilon=eps) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @onnx_impl( aten.scaled_dot_product_attention.default, trace_only=True, opset_introduced=23 ) @@ -104,7 +107,11 @@ def aten_scaled_dot_product_attention_23( 1. https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html 2. https://onnx.ai/onnx/operators/onnx__Attention.html +<<<<<<< HEAD Attempts to convert SDPA to Attention onnx op and fallbacks to an onnx graph equivalent to the following PyTorch code:: +======= + Attempts to convert SDPA to Attention onnx op and fallbacks to an onnx graph equivivalent to the following PyTorch code:: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) scale_factor = 1 / math.sqrt(Q.size(-1)) if scale is None else scale attn_mask = ( torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) @@ -147,7 +154,11 @@ def aten_scaled_dot_product_attention_23( ) # NOTE: num_heads attributes (q_num_heads/kv_num_heads) should not be specified for 4D. +<<<<<<< HEAD # They are not populated with 4D inputs because this information directly comes from input shapes: +======= + # They are not populated with 4D inputs because this information directy comes from input shapes: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # `q_num_heads=query.shape[1]` and `kv_num_heads=key.shape[1]`. # This dimension is usually static but it could not be dynamic if also given as an attribute. # num_heads attributes are needed for 3D attention inputs: diff --git a/torch/onnx/_internal/fx/__init__.py b/torch/onnx/_internal/fx/__init__.py index e69de29bb2d1d..b5716bdafced7 100644 --- a/torch/onnx/_internal/fx/__init__.py +++ b/torch/onnx/_internal/fx/__init__.py @@ -0,0 +1,8 @@ +from .patcher import ONNXTorchPatcher +from .serialization import save_model_with_external_data + + +__all__ = [ + "save_model_with_external_data", + "ONNXTorchPatcher", +] diff --git a/torch/onnx/_internal/fx/decomposition_table.py b/torch/onnx/_internal/fx/decomposition_table.py new file mode 100644 index 0000000000000..71715e1ad2344 --- /dev/null +++ b/torch/onnx/_internal/fx/decomposition_table.py @@ -0,0 +1,116 @@ +# mypy: allow-untyped-defs +"""Dispatcher for AtenLib functions from onnx-script.""" + +from __future__ import annotations + +from typing import Callable + +import torch +import torch._ops +import torch.fx +from torch.onnx._internal.fx import registration + + +def _create_onnx_supports_op_overload_table( + registry, +) -> set[torch._ops.OperatorBase | Callable]: + """ + Creates a set of OperatorBase and Callable objects that represent ONNX-supported PyTorch operations. + + Args: + registry (OnnxRegistry): The ONNX registry for PyTorch. + + Returns: + A collection of OperatorBase and Callable objects representing ONNX-supported PyTorch operations. + """ + table: set[torch._ops.OperatorBase | Callable] = set() + + # Some ops in `torch.ops.aten` are not discoverable through `dir(torch.ops.aten)`, + # but retrievable via explicit lookup. + # https://github.com/pytorch/pytorch/issues/99681 + # This is a workaround to make sure we register ONNX symbolic functions for these. + onnx_supported_aten_lookup_table = [ + k.split("::")[1].split(".")[0] + for k in registry._all_registered_ops() + if k.startswith("aten::") + ] + + for op_namespace in (torch.ops.aten, torch.ops.prims): + attr_names = dir(op_namespace) + if op_namespace is torch.ops.aten: + attr_names += onnx_supported_aten_lookup_table + for attr_name in attr_names: + if not hasattr(op_namespace, attr_name): + # torchlib owns some attributes that are not aten ops. + continue + op_overload_packet = getattr(op_namespace, attr_name) + if not isinstance(op_overload_packet, torch._ops.OpOverloadPacket): + continue + + for overload_name in op_overload_packet.overloads(): + op_overload = getattr(op_overload_packet, overload_name) + internal_op_name = registration.OpName.from_qualified_name( + qualified_name=op_overload.name() + ) + # NOTE: If the overload is supported in registry or it's default overload is supported in registry, + # we add it to the table. + if registry.is_registered_op( + namespace=internal_op_name.namespace, + op_name=internal_op_name.op_name, + overload=internal_op_name.overload, + ) or registry.is_registered_op( + namespace=internal_op_name.namespace, + op_name=internal_op_name.op_name, + overload=None, + ): + # This line maps torch.ops.aten.add.Tensor, torch.ops.aten.add.Scalar, torch.ops.aten.add.out, etc + # to "aten::add". This means the exporter for "aten::add" is used for all overloads of "aten::add". + # This is applied to all ops under torch.ops.aten. + table.add(op_overload) + return table + + +def create_onnx_friendly_decomposition_table( + registry, +) -> dict[torch._ops.OperatorBase, Callable]: + """ + This function creates a dictionary of op overloads and their decomposition functions + for ops that do not have ONNX symbolic functions. If an op already has an ONNX symbolic function, + its decomposition function is excluded from the table. The decomposition table is a subset of PyTorch's + built-in aten-to-aten decomposition. + + Args: + registry: The ONNX registry for PyTorch. + + Returns: + Dict[torch._ops.OperatorBase, Callable]: A dictionary that maps op overloads to their corresponding + decomposition functions. + """ + decomposition_table: dict[torch._ops.OperatorBase, Callable] = {} + # Dictionary that maps torch.ops.aten.* to exporter look up key; e.g., + # _OP_OVERLOAD_TO_EXPORTER_KEY_TABLE[torch.add.Tensor] is "aten::add". + _ONNX_SUPPORT_OP_OVERLOADS = _create_onnx_supports_op_overload_table(registry) + + # NOTE: If we import torch._decomp, we will get RuntimeError: Only a single + # TORCH_LIBRARY can be used to register the namespace nvprims; please put all of your + # definitions in a single TORCH_LIBRARY block. + for op_overload, decomp_fn in torch._decomp.decomposition_table.items(): + # Skip decomposition into "prim::*" ops (defined in 'torch._refs'), because they + # are not generally supported by ONNX. + # Skip decomposition for op_overload as long as that op_overload has a corresponding ONNX + # symbolic function. + if ( + "torch._refs" in decomp_fn.__module__ + or op_overload in _ONNX_SUPPORT_OP_OVERLOADS + ): + continue + decomposition_table[op_overload] = decomp_fn + + # NOTE: There are ops in core ATen and under torch._refs, + # that are not decomposed to prim::ops. We need to pick them + # back + for op_overload, decomp_fn in torch._decomp.core_aten_decompositions().items(): + if op_overload in _ONNX_SUPPORT_OP_OVERLOADS: + continue + decomposition_table[op_overload] = decomp_fn + return decomposition_table diff --git a/torch/onnx/_internal/fx/dynamo_graph_extractor.py b/torch/onnx/_internal/fx/dynamo_graph_extractor.py new file mode 100644 index 0000000000000..b11903619c080 --- /dev/null +++ b/torch/onnx/_internal/fx/dynamo_graph_extractor.py @@ -0,0 +1,232 @@ +# mypy: allow-untyped-defs +# NOTE: This file is referenced by name at +# /opt/pytorch/torch/_dynamo/eval_frame.py::DONT_WRAP_FILES. +# introduced by https://github.com/pytorch/pytorch/pull/98894. +# If this file is renamed, moved, etc please update the reference there! + +from __future__ import annotations + +import contextlib +import functools +import inspect +from typing import Any, Callable, TYPE_CHECKING + +import torch._dynamo +import torch.export as torch_export +import torch.fx +import torch.onnx +from torch.onnx._internal import _exporter_legacy, io_adapter +from torch.utils import _pytree as pytree + + +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + + +class _PyTreeExtensionContext: + """Context manager to register PyTree extension.""" + + _extensions: dict[type, tuple[pytree.FlattenFunc, pytree.UnflattenFunc]] + + def __init__(self) -> None: + self._extensions = {} + # Register PyTree extension for HuggingFace model output. + self._register_huggingface_model_output_extension() + + def __enter__(self): + for class_type, (flatten_func, unflatten_func) in self._extensions.items(): + pytree._private_register_pytree_node( + class_type, + flatten_func, + unflatten_func, + ) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + for class_type in self._extensions: + pytree.SUPPORTED_NODES.pop(class_type) + + def register_pytree_node( + self, + class_type: type, + flatten_func: pytree.FlattenFunc, + unflatten_func: pytree.UnflattenFunc, + ): + """Register PyTree extension for a custom python type. + + Args: + class_type: The custom python type. + flatten_func: The flatten function. + unflatten_func: The unflatten function. + + Raises: + AssertionError: If the custom python type is already registered. + """ + if class_type in pytree.SUPPORTED_NODES or class_type in self._extensions: + # PyTree node already registered. + # E.g., `huggingface/transformer` registers `ModelOutput` as PyTree node after + # https://github.com/huggingface/transformers/pull/25358. + return + self._extensions[class_type] = (flatten_func, unflatten_func) + + def _register_huggingface_model_output_extension(self): + try: + from transformers import modeling_outputs # type: ignore[import] + except ImportError: + return + + def model_output_flatten( + output: modeling_outputs.ModelOutput, + ) -> tuple[list[Any], pytree.Context]: + return list(output.values()), (type(output), list(output.keys())) + + def model_output_unflatten( + values: list[Any], context: pytree.Context + ) -> modeling_outputs.ModelOutput: + output_type, keys = context + return output_type(**dict(zip(keys, values))) + + # All 'ModelOutput' subclasses are defined under module 'modeling_outputs'. + named_model_output_classes = inspect.getmembers( + modeling_outputs, + lambda x: ( + inspect.isclass(x) + and issubclass(x, modeling_outputs.ModelOutput) + and x is not modeling_outputs.ModelOutput + ), + ) + + for _, class_type in named_model_output_classes: + self.register_pytree_node( + class_type, + model_output_flatten, + model_output_unflatten, # type: ignore[arg-type ] + ) + + +class DynamoFlattenOutputStep(io_adapter.FlattenOutputStep): + """Flatten nested collection and custom python types and return a flat list of elements. + + Extended from :class:`io_adapter.FlattenOutputStep` to support flattening arbitrary + types via pytree extension. By default this supports many common user defined python + types such as :class:`ModelOutput` from HuggingFace transformers. + + The pytree extension can be customized by passing in a ``_PyTreeExtensionContext`` + object. See :meth:`_PyTreeExtensionContext.register_pytree_node`. + """ + + def __init__(self, pytree_extension_context: _PyTreeExtensionContext | None = None): + super().__init__() + self._pytree_extension_context = ( + pytree_extension_context or _PyTreeExtensionContext() + ) + + def apply( + self, + model_outputs: Any, + model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, + ) -> Sequence[Any]: + """Flatten the model outputs, under the context of pytree extension.""" + with self._pytree_extension_context: + return super().apply(model_outputs, model=model) + + +def _wrap_model_with_output_adapter( + model: torch.nn.Module | Callable, + output_adapter: DynamoFlattenOutputStep, +) -> Callable: + """Wrap model with output adapter. + + This is a helper function to enable :func:`dynamo.export` on models that produce + custom user defined types outputs. It wraps the model with an output adapter to + convert the outputs to :func:`dynamo.export` compatible types, i.e. :class:`torch.Tensor`. + + The adapting logic is controlled by ``output_adapter``. + + Args: + model: PyTorch model or function. + output_adapter: Output adapter to apply to model output. + Returns: + Wrapped model. + """ + model_func = model.forward if isinstance(model, torch.nn.Module) else model + + # Preserve original function signature. + @functools.wraps(model_func) + def wrapped(*args, **kwargs): + return output_adapter.apply(model_func(*args, **kwargs), model=model) + + return wrapped + + +class DynamoExport(_exporter_legacy.FXGraphExtractor): + """Generates a FX GraphModule using torch.dynamo.export API + Args: + aten_graph: If True, exports a graph with ATen operators. + If False, exports a graph with Python operators. + """ + + def __init__( + self, + aten_graph: bool | None = None, + ): + super().__init__() + self.aten_graph = aten_graph or True + + def generate_fx( + self, + options: _exporter_legacy.ResolvedExportOptions, + model: torch.nn.Module | Callable, + model_args: Sequence[Any], + model_kwargs: Mapping[str, Any], + ) -> torch.fx.GraphModule: + # `dynamo.export` does not recognize custom user defined classes as output type. + # Apply wrapper to adapt the outputs back to `dynamo.export` compatible types, + # i.e. :class:`torch.Tensor`. + dynamo_flatten_output_step = DynamoFlattenOutputStep() + wrapped_model = _wrap_model_with_output_adapter( + model, dynamo_flatten_output_step + ) + # Record the output adapter step. + self.output_adapter.append_step(dynamo_flatten_output_step) + + # Translate callable to FX graph. + # + fake_mode = ( + options.fake_context.fake_mode + if options.fake_context + else contextlib.nullcontext() + ) + fx_mode = "symbolic" if options.dynamic_shapes else "fake" + with fake_mode: # type: ignore[attr-defined] + graph_module, graph_guard = torch._dynamo.export( + wrapped_model, + tracing_mode=fx_mode, + )( + *model_args, + **model_kwargs, + ) + del graph_guard # Unused + torch._dynamo.reset() + + # Export FX graph to ONNX ModelProto. + self.input_adapter.append_step( + io_adapter.FlattenInputWithTreeSpecValidationInputStep() + ) + + updated_model_args = self.input_adapter.apply( + *model_args, model=model, **model_kwargs + ) + + return self.pre_export_passes(options, model, graph_module, updated_model_args) # type: ignore[return-value] + + def pre_export_passes( + self, + options: _exporter_legacy.ResolvedExportOptions, + original_model: torch.nn.Module | Callable, + fx_module: torch.fx.GraphModule, + fx_module_args: Sequence[Any], + ): + return _exporter_legacy.common_pre_export_passes( + options, original_model, fx_module, fx_module_args + ) diff --git a/torch/onnx/_internal/fx/fx_onnx_interpreter.py b/torch/onnx/_internal/fx/fx_onnx_interpreter.py new file mode 100644 index 0000000000000..424f2d171b978 --- /dev/null +++ b/torch/onnx/_internal/fx/fx_onnx_interpreter.py @@ -0,0 +1,718 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import inspect +import operator +from typing import Callable, TYPE_CHECKING + +import onnxscript +from onnxscript.function_libs.torch_lib import ( + graph_building as onnxscript_graph_building, +) + +import torch +import torch.fx +from torch.onnx import _type_utils as jit_type_utils +from torch.onnx._internal.fx import ( + _pass, + onnxfunction_dispatcher, + type_utils as fx_type_utils, +) +from torch.utils import _pytree + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +def _fx_node_to_onnx_message_formatter( + fn: Callable, + self, + node: torch.fx.Node, + *args, + **kwargs, +) -> str: + return f"FX Node: {node.op}:{node.target}[name={node.name}]. " + + +def _fx_graph_to_onnx_message_formatter( + fn: Callable, + self, + fx_graph_module: torch.fx.GraphModule, + *args, + **kwargs, +) -> str: + return f"FX Graph: {fx_graph_module._get_name()}. " + + +def _retrieve_or_adapt_input_to_graph_set( + fx_node_arg: fx_type_utils.Argument, + fx_name_to_onnxscript_value: dict[ + str, + onnxscript_graph_building.TorchScriptTensor + | tuple[onnxscript_graph_building.TorchScriptTensor, ...], + ], + tracer: onnxscript_graph_building.TorchScriptTracingEvaluator, +): + """Map FX value to TorchScript value. + + When creating TorchScript graph from FX graph, we need a mapping from FX variable + to TorchScript variable. This function maps FX variable, fx_node_arg, to torch.jit.Value. + """ + from onnxscript import opset18 as op + + onnx_tensor = fx_node_arg + if isinstance(onnx_tensor, torch.fx.Node): + # 1. fx_node_arg is a torch.fx.Node, which means + # fx_node_arg stands for the output of that torch.fx.Node. + # 2. fx_node_arg (variable in torch.fx.Graph) is be mapped to + # torch.jit.Value, fx_name_to_onnxscript_value[fx_node_arg.name], + # in TorchScript graph. + return fx_name_to_onnxscript_value[onnx_tensor.name] + elif isinstance(onnx_tensor, (tuple, list)) and any( + isinstance(node, torch.fx.Node) + and fx_type_utils.is_torch_symbolic_type(node.meta.get("val")) + for node in onnx_tensor + ): + # This intends to handle dynamic axes. for example, if the input size of op.Expand + # is dynamic, each dimension would be variable (i.e., sym variable in Pytorch + # FX graph. Note that sym variable is mapped to tensor in ONNX Script world) + # calculated by other operators. + sequence_mixed_elements: list[ + onnxscript_graph_building.TorchScriptTensor + | tuple[onnxscript_graph_building.TorchScriptTensor, ...] + | list[int] + ] = [] + # onnx_tensor contains a list of scalars which could be one of + # - tensor with empty shape, + # - tensor with tensor with shape (1,), + # - torch.SymInt, + # - int + # - ... + # They should all be promoted to tensor with shape (1,) + # in order to call ONNX's Concat. + for tensor in onnx_tensor: + # Prepare `tensor` as input of ONNX's Concat. + + if isinstance( + tensor, torch.fx.Node + ) and fx_type_utils.is_torch_symbolic_type(tensor.meta.get("val")): + # In this case, tensor is a torch.SymInt from Dynamo's perspective. + # It might be mapped to tensor with shape () or (1,) in ONNX. + element_value = fx_name_to_onnxscript_value[tensor.name] + if isinstance( + element_value, onnxscript_graph_building.TorchScriptTensor + ): + # All elements sequence_mixed_elements will be send to onnx's Concat + # as inputs. Therefore, they are required to have the same rank. + # Since tensors with rank=0 (i.e., scalar) cannot be concated, all + # scalars are promoted to tensors with shape (1,). + with onnxscript.evaluator.default_as(tracer): + element_value = op.Reshape( + element_value, # type: ignore[arg-type, type-var] + [1], # type: ignore[arg-type, type-var] + ) + sequence_mixed_elements.append(element_value) + elif isinstance(tensor, int): + # NOTE: op.Concat doesn't support scalar, so we need to wrap it with + # dim, and onnx-script will promote it to tensor(int64) + sequence_mixed_elements.append([tensor]) + else: + raise RuntimeError( + f"Unsupported type in sequence_mixed_elements: {type(tensor)}" + ) + # Concat all the elements in the sequence. + # shapes are mapped to tensors in ONNX graph (TorchScriptGraph), + # so list of sym_ints is concatenated to a tensor before calling ONNX op. + + # For example: + # inputs: [[2], [4], fx.Node(SymIntA), [1], fx.Node(SymIntB)] + # outputs: op.Concat([op.Constant(2), op.Constant(4), TorchScriptTensor(A), op.Constant(1), TorchScriptTensor(B)]) + + # onnx-script auto wraps python number with op.Constants, + # so we don't need to specifically process them. + with onnxscript.evaluator.default_as(tracer): + output = op.Concat(*sequence_mixed_elements, axis=0) # type: ignore[type-var] + output.dtype = torch.int64 # type: ignore[union-attr] + output.shape = [len(sequence_mixed_elements)] # type: ignore[union-attr] + return output + elif isinstance(onnx_tensor, (tuple, list)) and all( + isinstance(node, torch.fx.Node) or node is None for node in onnx_tensor + ): + sequence_elements: list[ + onnxscript_graph_building.TorchScriptTensor + | None + | tuple[onnxscript_graph_building.TorchScriptTensor, ...] + ] = [] + for tensor in onnx_tensor: + sequence_elements.append( + fx_name_to_onnxscript_value[tensor.name] if tensor is not None else None # type: ignore[index, union-attr] + ) + return sequence_elements + if isinstance(onnx_tensor, torch.dtype): + onnx_tensor = int( # type: ignore[call-overload] + jit_type_utils.JitScalarType.from_dtype(onnx_tensor).onnx_type() + ) + # NOTE: if device is specified in kwargs (not consumed), it's free to ignored. But + # if it's in args, we need to set it to string for dispatcher to match schema. + if isinstance(onnx_tensor, torch.device): + # torch.device is not supported by onnxscript (no op). We turn it into + # a string. + return str(onnx_tensor) + # all other cases, we do nothing. + return onnx_tensor + + +def filter_incompatible_and_dtype_convert_kwargs(kwargs): + """Filter out kwargs that are not supported by onnxscript.""" + filtered = {} + for key, value in kwargs.items(): + if key in { + "layout", + "device", + "requires_grad", + "pin_memory", + "memory_format", + "implicit", + }: + continue + if key == "dtype": + if value is None: + # We omit if dtype is not provided, because onnxscript handles the + # default case. + continue + else: + value = int(jit_type_utils.JitScalarType.from_dtype(value).onnx_type()) # type: ignore[call-overload] + filtered[key] = value + return filtered + + +def _fill_tensor_shape_type( + onnxscript_values: onnxscript_graph_building.TorchScriptTensor + | tuple[onnxscript_graph_building.TorchScriptTensor, ...], + name: str, + expected_values: fx_type_utils.META_VALUE_TYPE + | list[fx_type_utils.META_VALUE_TYPE] + | tuple[fx_type_utils.META_VALUE_TYPE | None, ...], +): + """Fill the meta information of onnxscript_values with that from the fx FakeTensor.""" + + if isinstance(expected_values, (list, tuple)) and not isinstance( + onnxscript_values, (list, tuple) + ): + # ex: aten::split - in onnx_dtype: seq(tensor) + # onnxscript_values is a single tensor, but expected_values is a list of tensors. + return + + flat_onnxscript_values, _ = _pytree.tree_flatten(onnxscript_values) + flat_expected_values, _ = _pytree.tree_flatten(expected_values) + for i, (onnxscript_value, expected_value) in enumerate( + zip(flat_onnxscript_values, flat_expected_values) + ): + if expected_value is None: + # There is no shape/type from None. + # NOTE: according to https://github.com/pytorch/pytorch/blob/main/torch/_meta_registrations.py, + # None could be a valid value for return type, so we need to handle it. + # e.g. the function: meta__scaled_dot_product_flash() in cpu mode. + continue + elif fx_type_utils.is_torch_symbolic_type(expected_value): + # aten::sym_size output is a int, not a tensor, which stands + # for the size of one dim. We treat it as 1-D tensor. + onnxscript_value.dtype = fx_type_utils.from_sym_value_to_torch_dtype( + expected_value + ) + onnxscript_value.shape = torch.Size([1]) + elif isinstance(expected_value, (int, float, bool)): + onnxscript_value.dtype = fx_type_utils.from_scalar_type_to_torch_dtype( + type(expected_value) + ) + onnxscript_value.shape = torch.Size([]) + elif isinstance(expected_value, complex): + # From complex scalar to real representation + onnxscript_value_to_torch_dtype = ( + fx_type_utils.from_scalar_type_to_torch_dtype(type(expected_value)) + ) + onnxscript_value.dtype = ( + fx_type_utils.from_complex_to_float(onnxscript_value_to_torch_dtype) + if onnxscript_value_to_torch_dtype is not None + else None + ) + onnxscript_value.shape = torch.Size([2]) + elif fx_type_utils.is_torch_complex_dtype(expected_value.dtype): + # Like torch.view_as_real, we flatten complex tensors to real tensors with + # additional last dimension of 2 + onnxscript_value.shape = torch.Size((*expected_value.size(), 2)) + # complex64 -> float32, complex128 -> float64, etc. + onnxscript_value.dtype = fx_type_utils.from_complex_to_float( + expected_value.dtype + ) + # Dispatcher needs to know the value is complex + onnxscript_value.is_complex = True + else: + # We set node output sizes to be dynamic to continue the model conversion, + # and inputs are also set to be dynamic in add_input(). + onnxscript_value.shape = expected_value.size() + onnxscript_value.dtype = expected_value.dtype + + # naming + if i > 0: + onnxscript_value.name = f"{name}_{i}" + else: + onnxscript_value.name = name + + +def _fill_in_default_kwargs( + node: torch.fx.Node, +) -> tuple[list[fx_type_utils.Argument], dict[str, fx_type_utils.Argument]]: + """Find and Fill in the not provided kwargs with default values.""" + + # TODO: aten::sym_size has overload, but fx graph is using + # overloadpacket for some reasons. + # https://github.com/pytorch/pytorch/issues/97201 + # We manually assigned overload for aten::sym_size. + if hasattr(node.target, "_schema"): + node_schema = node.target._schema # type: ignore[union-attr] + else: + node_schema = torch.ops.aten.sym_size.int._schema # type: ignore[union-attr] + + # This function assumes the order of arguments in FX op is the + # same as the order of arguments in TorchScript op. + complete_args: list[fx_type_utils.Argument] = [] + complete_kwargs: dict[str, fx_type_utils.Argument] = {} + + if inspect.isbuiltin(node.target): + complete_args = list(node.args) + else: + for i, expected_arg in enumerate(node_schema.arguments): + if i < len(node.args): + complete_args.append(node.args[i]) + elif expected_arg.name in node.kwargs: + complete_kwargs[expected_arg.name] = node.kwargs[expected_arg.name] + else: + # Get default from schema. + complete_kwargs[expected_arg.name] = expected_arg.default_value + + return complete_args, complete_kwargs + + +def _wrap_fx_args_as_onnxscript_args( + complete_args: list[fx_type_utils.Argument], + complete_kwargs: dict[str, fx_type_utils.Argument], + fx_name_to_onnxscript_value: dict[ + str, + onnxscript_graph_building.TorchScriptTensor + | tuple[onnxscript_graph_building.TorchScriptTensor, ...], + ], + tracer: onnxscript_graph_building.TorchScriptTracingEvaluator, +) -> tuple[ + Sequence[ + onnxscript_graph_building.TorchScriptTensor + | str + | int + | float + | bool + | list + | complex + | None + ], + dict[str, fx_type_utils.Argument], +]: + """Map all FX arguments of a node to arguments in TorchScript graph.""" + + onnxscript_args = tuple( + _retrieve_or_adapt_input_to_graph_set(arg, fx_name_to_onnxscript_value, tracer) + for arg in complete_args + ) + onnxscript_kwargs = filter_incompatible_and_dtype_convert_kwargs(complete_kwargs) + + return onnxscript_args, onnxscript_kwargs + + +class FxOnnxInterpreter: + """Stateless class to process FX graph Nodes and translate them into their ONNX counterparts. + + All FX nodes described by [FX Graph](https://pytorch.org/docs/stable/fx.html#torch.fx.Graph) are supported. + Similarly to [FX Interpreter pattern](https://pytorch.org/docs/stable/fx.html#torch.fx.Interpreter), each FX node + must be implemented on its own method in this class. + + Each operator's implementation returns either an `onnxscript.OnnxFunction` or + `onnxscript.TracedOnnxFunction` instance based on the dispatch algorithm. They can + also raise RuntimeError: If there are no overloaded functions available for the given FX node. + """ + + def run_node( + self, + node, + fx_graph_module: torch.fx.GraphModule, + onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher, + onnxscript_graph: onnxscript_graph_building.TorchScriptGraph, + onnxscript_tracer: onnxscript_graph_building.TorchScriptTracingEvaluator, + fx_name_to_onnxscript_value: dict[ + str, + onnxscript_graph_building.TorchScriptTensor + | tuple[onnxscript_graph_building.TorchScriptTensor, ...], + ], + ): + """Execute a single FX node to produce its ONNX counterpart. + + Args: + node: The FX node to be translated. + fx_graph_module: The FX graph module containing the node. + onnxfunction_dispatcher: The dispatcher to find the best matched ONNX op. + onnxscript_graph: The ONNX graph to be populated. + onnxscript_tracer: The tracer to trace the ONNX graph. + fx_name_to_onnxscript_value: The mapping from FX node name to ONNX Script value. + + Raises: + RuntimeError: When a node.op is not supported. + """ + if node.op == "placeholder": + self.placeholder(node, onnxscript_graph, fx_name_to_onnxscript_value) + elif node.op == "get_attr": + self.get_attr( + node, + onnxscript_graph, + fx_name_to_onnxscript_value, + fx_graph_module, + ) + elif node.op == "call_function": + self.call_function( + node, + onnxscript_tracer, + fx_name_to_onnxscript_value, + onnxfunction_dispatcher, + fx_graph_module, + ) + elif node.op == "call_method": + self.call_method(node) + elif node.op == "call_module": + self.call_module( + node, + onnxscript_graph, + fx_name_to_onnxscript_value, + onnxscript_tracer, + fx_graph_module, + onnxfunction_dispatcher, + ) + elif node.op == "output": + self.output(node, onnxscript_graph, fx_name_to_onnxscript_value) + else: + raise RuntimeError(f"Found node type not defined in torch.fx: {node.op}") + + def run( + self, + fx_graph_module: torch.fx.GraphModule, + onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher, + parent_onnxscript_graph: onnxscript_graph_building.TorchScriptGraph + | None = None, + ) -> onnxscript_graph_building.TorchScriptGraph: + """Analyze all FX nodes and trigger their ONNX translation. + + Args: + fx_graph_module: FX graph module to be translated. + onnxfunction_dispatcher: ONNX function dispatcher. + parent_onnxscript_graph: The parent TorchScript graph. Must be provided if + `fx_graph_module` is a submodule. If not provided, + `fx_graph_module` is assumed to be the root module. + """ + if parent_onnxscript_graph is not None: + # If parent_onnxscript_graph is provided, we assume fx_graph_module is a + # submodule representing a forward call of an nn.Module. + # Compose package and version where the nn.Module is defined as domain name + # for the local function. + + onnx_meta: _pass.GraphModuleOnnxMeta | None = fx_graph_module.meta.get( + "onnx" + ) + if onnx_meta is None: + raise RuntimeError( + f"ONNX meta is not found in submodule {fx_graph_module._get_name()}. " + f"Only submodules produced by `Modularize` pass is supported in ONNX export." + ) + + onnx_domain = onnx_meta.package_info.to_onnx_domain_string() + else: + # Leave as default domain name for the root module. + onnx_domain = None + + onnxscript_graph = onnxscript_graph_building.TorchScriptGraph( + parent_onnxscript_graph, domain_name=onnx_domain + ) + onnxscript_tracer = onnxscript_graph_building.TorchScriptTracingEvaluator( + onnxscript_graph + ) + # In the following loop, a TorchScript graph is created to + # represent the input FX graph with ONNX symbols (e.g., onnx::add). + # To connect the values to nodes in the TorchScript graph, we maintain + # fx_name_to_onnxscript_value. Basically, we want to translate + # fx_tensor_x (type: torch.fx.Node) -> fx_node_1 -> fx_tensor_y (type: torch.fx.Node) + # to + # fx_name_to_onnxscript_value[fx_tensor_x.name] -> onnx_node_1 -> fx_name_to_onnxscript_value[fx_tensor_y.name] + fx_name_to_onnxscript_value: dict[ + str, + onnxscript_graph_building.TorchScriptTensor + | tuple[onnxscript_graph_building.TorchScriptTensor, ...], + ] = {} + + # TODO: Fix FakeTensorMode limitation asap + # We want to pass list of ints and floats to TorchScript graph correctly + # in _export_fx_to_ts, so we must disable FakeTensorMode. Otherwise, graph may + # receive FakeTensor and results runtime error. In addition, TorchScript-based + # ONNX exporter used in _ts_graph_to_onnx_model_in_protobuf is not compatible + # with FakeTensorMode. + with torch.utils._mode_utils.no_dispatch(): + for node in fx_graph_module.graph.nodes: + self.run_node( + node, + fx_graph_module, + onnxfunction_dispatcher, + onnxscript_graph, + onnxscript_tracer, + fx_name_to_onnxscript_value, + ) + + return onnxscript_graph + + def placeholder( + self, + node: torch.fx.Node, + onnxscript_graph: onnxscript_graph_building.TorchScriptGraph, + fx_name_to_onnxscript_value: dict[ + str, + onnxscript_graph_building.TorchScriptTensor + | tuple[onnxscript_graph_building.TorchScriptTensor, ...], + ], + ): + # Input of graph. + # The node.meta["val"] is generated by FakeTensorProp. + # NOTE: add_input() intends to create nodes with shape/type + fake_tensor = node.meta.get("val", None) + # NOTE: During the tracing, when inputs are constants, they are represented + # by nodes with node.meta['val'] being None (nn.Module to dynamo_export) + # or nodes with node.meta['val'] being a builtin value (ExportedProgram to dynamo_export). + # Nonethless, the nodes are not consumed by others, so we don't need to + # create a TorchScriptTensor for them. + if fake_tensor is None or isinstance(fake_tensor, (int, float, bool, str)): + output = onnxscript_graph.add_input( + input_name=None, + ) + elif isinstance(fake_tensor, torch.Tensor): + # NOTE: ONNX doesn't support tensor of complex64/complex128, so we + # convert them to float32/float64 with real representation. + if fx_type_utils.is_torch_complex_dtype(fake_tensor.dtype): + fake_tensor = torch.view_as_real(fake_tensor.resolve_conj()) + output = onnxscript_graph.add_input( + input_name=node.name, + shape=fake_tensor.shape, + dtype=fake_tensor.dtype, + ) + + elif fx_type_utils.is_torch_symbolic_type(fake_tensor): + output = onnxscript_graph.add_input( + input_name=node.name, + shape=torch.Size([]), + dtype=fx_type_utils.from_sym_value_to_torch_dtype(fake_tensor), + ) + else: + raise RuntimeError( + f"Unsupported type(node.meta['val']) for placeholder: {type(fake_tensor)}" + ) + assert output is not None, ( + f"Node creates None with target={node.target} and name={node.name}" + ) + + assert isinstance(output, onnxscript_graph_building.TorchScriptTensor) + assert isinstance(output, onnxscript.tensor.Tensor) + + fx_name_to_onnxscript_value[node.name] = output + + def call_function( + self, + node: torch.fx.Node, + onnxscript_tracer: onnxscript_graph_building.TorchScriptTracingEvaluator, + fx_name_to_onnxscript_value: dict[ + str, + onnxscript_graph_building.TorchScriptTensor + | tuple[onnxscript_graph_building.TorchScriptTensor, ...], + ], + onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher, + fx_graph_module: torch.fx.GraphModule, + ): + # aten ops and other stateless functions. + if node.target == operator.getitem and isinstance( + fx_name_to_onnxscript_value[node.args[0].name], # type: ignore[union-attr,index] + tuple, + ): + onnx_tensor_tuple = fx_name_to_onnxscript_value[node.args[0].name] # type: ignore[union-attr,index] + index = node.args[1] + value = onnx_tensor_tuple[index] # type: ignore[index] + assert value is not None, ( + f"Node creates None with target={node.target} and name={node.name}" + ) + assert isinstance( + value, (onnxscript_graph_building.TorchScriptTensor, tuple) + ), type(value) + + fx_name_to_onnxscript_value[node.name] = value + return + + # Map FX inputs to ONNX inputs and fill optional inputs with default values. + # torch_args and torch_kwargs are for op-level validation + fx_args, fx_kwargs = _fill_in_default_kwargs(node) + + onnx_args, onnx_kwargs = _wrap_fx_args_as_onnxscript_args( + fx_args, + fx_kwargs, + fx_name_to_onnxscript_value, + onnxscript_tracer, + ) + # Dispatch to ONNX op through OpShema. The input argument dtypes are compared to + # function signature in OpSchema, and find the best matched overload. + symbolic_fn = onnxfunction_dispatcher.dispatch( + node=node, + onnx_args=onnx_args, # type: ignore[arg-type] + onnx_kwargs=onnx_kwargs, + ) + with onnxscript.evaluator.default_as(onnxscript_tracer): + output: ( + onnxscript_graph_building.TorchScriptTensor + | tuple[onnxscript_graph_building.TorchScriptTensor, ...] + ) = symbolic_fn(*onnx_args, **onnx_kwargs) + assert output is not None, ( + f"Node creates None with target={node.target}, name={node.name}, args={onnx_args}, kwargs={onnx_kwargs}" + ) + # Assign type and shape from fx graph. + _fill_tensor_shape_type(output, node.name, node.meta["val"]) + # One fx node could produce multiple outputs (e.g., tuple of tensors); in + # that case, v is a tuple of TorchScriptTensors. + assert isinstance( + output, (onnxscript_graph_building.TorchScriptTensor, tuple) + ), type(output) + fx_name_to_onnxscript_value[node.name] = output + + def output( + self, + node: torch.fx.Node, + onnxscript_graph: onnxscript_graph_building.TorchScriptGraph, + fx_name_to_onnxscript_value: dict[ + str, + onnxscript_graph_building.TorchScriptTensor + | tuple[onnxscript_graph_building.TorchScriptTensor, ...], + ], + ): + if isinstance(node.args[0], torch.fx.Node): + onnx_tensor_or_tensor_tuple = fx_name_to_onnxscript_value[node.args[0].name] + onnxscript_graph.register_outputs(onnx_tensor_or_tensor_tuple) + else: + # ONNX can't represent collection types (e.g., dictionary, tuple of tuple of + # tensor, etc), we flatten the collection and register each element as output. + flat_args, _ = _pytree.tree_flatten(node.args[0]) + for arg in flat_args: + assert isinstance(arg, torch.fx.Node), ( + f"arg must be a torch.fx.Node, not {type(arg)}" + ) + onnx_tensor_or_tensor_tuple = fx_name_to_onnxscript_value[arg.name] + onnxscript_graph.register_outputs(onnx_tensor_or_tensor_tuple) + + def call_method(self, node: torch.fx.Node): + # TODO(wechi): Support call_method. + raise RuntimeError("call_method is not supported yet.") + + def call_module( + self, + node: torch.fx.Node, + parent_onnxscript_graph: onnxscript_graph_building.TorchScriptGraph, + fx_name_to_onnxscript_value: dict[ + str, + onnxscript_graph_building.TorchScriptTensor + | tuple[onnxscript_graph_building.TorchScriptTensor, ...], + ], + tracer: onnxscript_graph_building.TorchScriptTracingEvaluator, + root_fx_graph_module: torch.fx.GraphModule, + onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher, + ) -> None: + """Export a fx.GraphModule submodule to ONNXScript graph. + + The export process specifically targets `call_module` nodes that are created by + the exporter's `Modularize` pass. Each `call_module` node has an associated fx.GraphModule + by `node.target` underneath the root fx.GraphModule. These `call_module` nodes are exported as ONNX + function nodes. The related `sub_module` is then exported as an ONNX model local function, + which is represented by another `TorchScriptGraph`. This `TorchScriptGraph` sets the current + `onnxscript_graph` as its parent. + + Args: + node: The call_module node in the FX graph that represents the submodule call. + parent_onnxscript_graph: The parent ONNXScript graph to which the ONNX function and + function node belong. + fx_name_to_onnxscript_value: The mapping from FX node name to ONNXScript value. + tracer: The tracer used to trace the ONNXScript graph. + root_fx_graph_module: The root FX module. + onnxfunction_dispatcher: The dispatcher. + """ + assert isinstance(node.target, str), ( + f"node.target must be a str, not {type(node.target)} for node {node}." + ) + + sub_module = root_fx_graph_module.get_submodule(node.target) + + assert isinstance(sub_module, torch.fx.GraphModule), ( + f"sub_module must be a torch.fx.GraphModule, not {type(sub_module)} for node {node}." + ) + + sub_onnxscript_graph = self.run( + sub_module, onnxfunction_dispatcher, parent_onnxscript_graph + ) + + onnx_args, _ = _wrap_fx_args_as_onnxscript_args( + list(node.args), {}, fx_name_to_onnxscript_value, tracer + ) + + # TODO: We may want to consider other naming styles. The goal is to be stable and + # unique such that it can be easily identified in case of kernel substitution. + # Example for current style is combination of qualified module class name and + # module attribute name: `torch_nn_modules_conv_Conv2d_conv1`. + # Other naming styles such as qualified module class name made unique can also + # be considered. + unique_module_name = f"{sub_module._get_name()}_{node.target}" + + outputs: ( + onnxscript_graph_building.TorchScriptTensor + | tuple[onnxscript_graph_building.TorchScriptTensor, ...] + ) = parent_onnxscript_graph.add_module_call( # type: ignore[assignment] + unique_module_name, sub_onnxscript_graph, onnx_args + ) + + assert isinstance( + outputs, (onnxscript_graph_building.TorchScriptTensor, tuple) + ), f"Unexpected outputs type {type(outputs)} for node {node}." + + _fill_tensor_shape_type(outputs, node.name, node.meta["val"]) + fx_name_to_onnxscript_value[node.name] = outputs + + # Skip op_level_validation for call_module. Subgraph nodes are validated individually. + + def get_attr( + self, + node: torch.fx.Node, + onnxscript_graph: onnxscript_graph_building.TorchScriptGraph, + fx_name_to_onnxscript_value: dict[ + str, + onnxscript_graph_building.TorchScriptTensor + | tuple[onnxscript_graph_building.TorchScriptTensor, ...], + ], + fx_graph_module: torch.fx.GraphModule, + ): + assert isinstance(node.target, str), f"node.target {node.target} is not a str." + attr_tensor = getattr(fx_graph_module, node.target) + assert isinstance(attr_tensor, torch.Tensor), f"{attr_tensor} is not a tensor." + + # Parameter/buffer name cannot contain "." + # Revert from "/" to restore namespace formatting. + input_ = onnxscript_graph.add_initializer( + name=node.target.replace("/", "."), + value=attr_tensor, + ) + + assert isinstance(input_, onnxscript_graph_building.TorchScriptTensor) + assert isinstance(input_, onnxscript.tensor.Tensor) + fx_name_to_onnxscript_value[node.name] = input_ diff --git a/torch/onnx/_internal/fx/onnxfunction_dispatcher.py b/torch/onnx/_internal/fx/onnxfunction_dispatcher.py new file mode 100644 index 0000000000000..516eb36368886 --- /dev/null +++ b/torch/onnx/_internal/fx/onnxfunction_dispatcher.py @@ -0,0 +1,731 @@ +# mypy: allow-untyped-defs +"""Dispatcher for AtenLib functions from onnx-script. + +This is a deprecated module to be removed. +""" + +from __future__ import annotations + +import logging +import operator +import types +from typing import Any, TYPE_CHECKING + +import torch +import torch._ops +import torch.fx +from torch.onnx._internal.fx import registration, type_utils as fx_type_utils + + +if TYPE_CHECKING: + from collections.abc import Sequence + + import onnxscript # type: ignore[import] + from onnxscript.function_libs.torch_lib import ( # type: ignore[import] + graph_building as onnxscript_graph_building, + ) + + from torch.onnx._internal._exporter_legacy import OnnxRegistry + + +logger = logging.getLogger(__name__) + + +class OnnxFunctionDispatcher: + """A dispatcher that finds the best ONNX Function for ATen/Custom operators. + + It uses the `torch.ops` name to find the function. If not found, it falls back to default. + Otherwise, the best match is found among all function overloads. An exact match has + higher precedence over the closest ones. + + Below is a breakdown on how the dispatch mechanism works: + + 1. Use the torch.ops name to find the function: + a. Check if the ATen overload exists in the registry. + b. If not, check if the default overload exists in the registry. + + 2. Find the nearest match among all overloaded functions: + a. If the types match perfectly, select the function. + b. Otherwise, find the nearest one with the highest matching score. Because of + the potential wrongly annotated dtypes and attributes matching, we use + nearest match to find the best function once the aten name is targeted. + + 3. Tie-breaker: If there are multiple nearest matches, we will select the one with + the highest matching score. + + NOTE: The nearest match `doesn't guarantee` a correct match, and a warning message is logged. + """ + + def __init__( + self, + onnx_registry: OnnxRegistry, + ): + """Initialize the ONNX Function dispatcher. + + Args: + onnx_registry: The ONNX registry. + """ + self.onnx_registry = onnx_registry + + def dispatch( + self, + node: torch.fx.Node, + onnx_args: Sequence[ + fx_type_utils.TensorLike | str | int | float | bool | list | complex | None + ], + onnx_kwargs: dict[str, fx_type_utils.Argument], + ) -> onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction: + """Dispatches an ONNX function based on the given FX node, arguments, and keyword arguments. + Args: + node: The TorchFX node to dispatch the function for. + onnx_args: The arguments of the ONNX function. + onnx_kwargs: The keyword arguments of the ONNX function. + + Returns: + Either an `onnxscript.OnnxFunction` or `onnxscript.TracedOnnxFunction` instance based on the dispatch algorithm. + Raises: + RuntimeError: If there are no overloaded functions available for the given FX node. + """ + # If there are no overloaded functions available for the given FX node, raise an + # unsupported error + default_and_custom_functions = self.get_function_overloads(node) + + # If there are overloaded functions available, we will find one that perfect or + # nearest matches the given arguments and keyword arguments + return self._find_the_perfect_or_nearest_match_onnxfunction( + node, + default_and_custom_functions, + onnx_args, + onnx_kwargs, + ) + + def _filter_or_keep_complex( + self, + node, + default_and_custom_functions: list[registration.ONNXFunction], + ) -> list[registration.ONNXFunction]: + """Filter the complex functions if the input has complex dtype.""" + + args_with_complex_dtype = [_is_arg_with_complex_dtype(arg) for arg in node.args] + if any(args_with_complex_dtype): + default_and_custom_functions = [ + func for func in default_and_custom_functions if func.is_complex + ] + # If we can't find the complex function group, raise error. + if not default_and_custom_functions: + op_full_name = self._get_aten_name(node).qualified_name() + raise RuntimeError( + f"Cannot find any COMPLEX symbolic function for {op_full_name}, " + f"which should be registered under {node.target}.", + ) + else: + default_and_custom_functions = [ + func for func in default_and_custom_functions if not func.is_complex + ] + # If we can't find the complex function group, raise error. + if not default_and_custom_functions: + op_full_name = self._get_aten_name(node).qualified_name() + raise RuntimeError( + f"Can ONLY find COMPLEX symbolic function for {op_full_name}, " + f"which should be registered under {node.target}.", + ) + return default_and_custom_functions + + def _find_the_perfect_or_nearest_match_onnxfunction( + self, + node: torch.fx.Node, + default_and_custom_functions: list[registration.ONNXFunction], + onnx_args: Sequence[ + fx_type_utils.TensorLike | str | int | float | bool | list | complex | None + ], + onnx_kwargs: dict[str, fx_type_utils.Argument], + ): + """Find the perfect/nearest matched OnnxFunction for the given FX node, arguments, and keyword arguments. + + Args: + default_and_custom_functions: The list includes overloaded functions, with + custom ones appearing after the default ones. + onnx_args: Arguments organized in PyTorch inputs way. + onnx_kwargs: Keyword arguments organized in PyTorch inputs way. + + Returns: + Either an `onnxscript.OnnxFunction` or `onnxscript.TracedOnnxFunction` instance based on the dispatch algorithm. + Raises: + RuntimeError: If there are no overloaded functions available for the given FX node. + """ + overload_match_ranking: dict[registration.ONNXFunction, int | None] = {} + + # Iterate the overloaded functions in reverse order to prioritize the custom ones + # over the default ones, and find the perfect match. + for symbolic_function in reversed(default_and_custom_functions): + function_opschema = _OnnxSchemaChecker(symbolic_function.onnx_function) + + # NOTE: 1. If the perfect match is found, return the function + if function_opschema.perfect_match_inputs(onnx_args, onnx_kwargs): + return symbolic_function.onnx_function + # Record the match score for the nearest match if it's not the perfect match + overload_match_ranking[symbolic_function] = function_opschema.match_score + + # NOTE: 2. If there is no perfect match, find the nearest match among the nearest matche candidates + # If there is no nearest match, raise an error + overload_match_ranking = { + k: v for k, v in overload_match_ranking.items() if v is not None + } + if not overload_match_ranking: + # If there are no overloaded functions available for the given FX node, raise an + # unsupported error + op_full_name = self._get_aten_name(node).qualified_name() + raise RuntimeError( + f"Cannot find any perfect/nearest match of symbolic function for {op_full_name}," + f"which should be registered under {node.target}.", + ) + + # NOTE: 3. Tie breaker: if there are multiple nearest matches, we will choose the one + # that is custom first. If there are multiple custom ones, we will choose the one + # that is added lastly in the list. + symbolic_function_list: list[registration.ONNXFunction] = sorted( + overload_match_ranking, + key=lambda k: ( + overload_match_ranking[k], + k.is_custom, + default_and_custom_functions.index(k), + ), + reverse=True, + ) + return symbolic_function_list[0].onnx_function + + def _get_aten_name(self, node: torch.fx.Node) -> registration.OpName: + """Get the OpName from the target. + + Args: + node: The TorchFX node to get the aten name for. + + Returns: + The internal op name within dataclass: registration.OpName. + """ + if node.target == operator.getitem: + return registration.OpName.from_name_parts( + namespace="aten", op_name="getitem" + ) + if isinstance(node.target, torch._ops.OpOverloadPacket): + # aten::sym_size is the only OverloadPacket that we support. + # schema: aten::sym_size(Tensor self, int dim) -> Tensor + if node.target != torch.ops.aten.sym_size: + raise RuntimeError( + f"Unsupported OverloadPacket: {node.target}, aten.sym_size is the only allowed OverloadPacket!", + ) + # TODO(titaiwang): aten::sym_size has overload, but fx graph is using + # overloadpacket for some reasons. + # https://github.com/pytorch/pytorch/issues/97201 + aten_op_default = node.target.default + return registration.OpName.from_op_overload(op_overload=aten_op_default) # type: ignore[no-any-return] + + if isinstance(node.target, types.BuiltinFunctionType): + # Make sure it's symint/symfloat consuming builtin ops. + for node_arg in node.args: + if (not isinstance(node_arg, (torch.fx.Node, int, float))) or ( + isinstance(node_arg, torch.fx.Node) + and not fx_type_utils.is_torch_symbolic_type(node_arg.meta["val"]) + ): + raise RuntimeError( + f"Unsupported node arg: {node_arg} (type {type(node_arg)}) with builtin function: {node.target}," + " only int/float/SymInt/SymFloat is supported with built-in ops!", + ) + return registration.OpName.from_builtin_function(node.target) + + if isinstance(node.target, torch._ops.OpOverload): + return registration.OpName.from_op_overload(op_overload=node.target) + + # Unexpected target, raise error. + raise RuntimeError(f"Unknown call_function target: {node.target}") + + def get_function_overloads( + self, + node: torch.fx.Node, + ) -> list[registration.ONNXFunction]: + """Get the function overloads from the registry. + + Args: + node: The node to get the function overloads for. + + Returns: + The list contains ONNXFunctions, starting with the default ones and + followed by any custom ones. + """ + + internal_opname: registration.OpName = self._get_aten_name(node=node) + + # If the ATen/Custom operators are not registered, the group will be None. + # And non-registered ATen/Custom operators will trigger error in the next step. + function_group: list[registration.ONNXFunction] | None = None + + function_group = self.onnx_registry.get_op_functions( + namespace=internal_opname.namespace, + op_name=internal_opname.op_name, + overload=internal_opname.overload, + ) + + # NOTE: Fall back to default overload if the ONNX registry doesn't have the overload. + if function_group is None: + function_group = self.onnx_registry.get_op_functions( + namespace=internal_opname.namespace, + op_name=internal_opname.op_name, + overload=None, + ) + if function_group is not None: + op_full_name = internal_opname.qualified_name() + + if function_group is not None: + # NOTE: If the input has complex dtype, we will only dispatch to the complex functions. + function_group = self._filter_or_keep_complex(node, function_group) + return function_group # type: ignore[return-value] + + op_full_name = internal_opname.qualified_name() + raise RuntimeError( + f"Cannot find symbolic function for {op_full_name}, " + f"which should be registered under {node.target}.", + ) + + +class _OnnxSchemaChecker: + """ + The OnnxSchemaChecker class is a checker for ONNX OpSchema and param schema. + + It provides methods to check for input compatibility based on the OpSchema. It also + provides a matching score to indicate how well the OpSchema matches the input and + kwargs types. A function will be evaluated as perfect match, nearest match eligible, + or no match. + + Here are some common examples in categories: + + 1. [NOTE: Perfect match]: The number of inputs and attributes are exactly the same as + the OpSchema. The types of inputs and attributes are exactly the same as the + OpSchema. + + ```python + inputs = (Tensor[2, 3], Tensor[2, 3]) + attributes = {"alpha": 1.0} + + + @torch_op("aten::op") + def aten_op(self: TReal, other: TReal, alpha: float = 1) -> TReal: ... + ``` + Result: Perfect match. + + 2. [NOTE: Optional input]: The dispatcher recognizes optional inputs. However, + the input can't be ignored. None must be provided. + + ```python + inputs = (Tensor([2, 3]), None) + attributes = {} + + aten_op(X: TTensor, Y: Optional[INT64]): + ... + ``` + Result: Perfect match. + Real example: `aten::convolution`. + + 3. [NOTE: Different attributes]: If an attribute is provided with value, it's + a must to match the attribute in function signature. + ```python + inputs = (Tensor([2, 3]),) + attributes = {"a":1, "b":2} + + aten_op(X: TTensor, a: int): + ... + ``` + Result: No match. + Real example: `aten::div` vs `aten::div.Tensor_mode`. + + 4. [NOTE: Default attributes]: Default attribute will fill in the value into + inputs/attributes. + ```python + inputs = (Tensor([2, 3]),) + attributes = {} + + aten_op(X: TTensor, a: int = 3): + ... + ``` + Result: Perfect match. + Real example: `aten::clone` + + 5. [NOTE: Ignore attribute with None value]: The attributes with None value + will be ignored in matching. + ```python + inputs = (Tensor([2, 3]),) + attributes = {"a": None} + + aten_op(X: TTensor): + ... + ``` + Result: Perfect match. + + ```python + inputs = (Tensor([2, 3]),) + attributes = {"a": None} + + aten_op(X: TTensor, a: int = 3): + ... + ``` + Result: Nearest match eligible. + + Real example: `aten::div` vs `aten::div.Tensor_mode`. + + Attributes: + onnxfunction: The OnnxFunction. + param_schema: The parameter schema defined in the OnnxFunction. + op_schema: The ONNX OpSchema. + type_constraints: The type constraints defined in the OpSchema. + attributes: The attributes defined in the OpSchema. + _matching_score: The matching score of the OnnxSchemaChecker . + + """ + + def __init__( + self, + onnxfunction: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction, + ): + """Initialize the OnnxSchemaChecker . + + Args: + onnxfunction: The OnnxFunction. + """ + self.onnxfunction = onnxfunction + self.param_schema = self.onnxfunction.param_schemas() + op_schema = self.onnxfunction.op_schema + # Both `OnnxFunction` and `TracedOnnxFunction` never return None for `op_schema`. + # However their base class would. Hence return type is annotated as Optional[OpSchema]. + assert op_schema is not None + self.op_schema = op_schema + self.type_constraints = { + # "T": {"tensor(int64)"} + constraint.type_param_str: set(constraint.allowed_type_strs) + for constraint in self.op_schema.type_constraints + } + self.attributes = self.op_schema.attributes + self._matching_score: int | None = None + + @property + def match_score(self) -> int | None: + """The matching score of the OnnxSchemaChecker . + + If this remains None, it means the matching score has not been calculated, + and it's not a nearest match candidate. + + Returns: + The matching score of the OnnxSchemaChecker . + """ + return self._matching_score + + def perfect_match_inputs( + self, + args: Sequence[ + fx_type_utils.TensorLike | str | int | float | bool | list | complex | None + ], + kwargs: dict[str, fx_type_utils.Argument], + ) -> bool: + """Check if the inputs perfectly match the OpSchema requirements. + + The definition of perfect match is that the input types are all in the type + constraints and the number of inputs matches the number of inputs in the + OpSchema. + + Checking steps: + 1. The function signature matches the inputs number, and attribute names. + 2. The input/attribute types are all in the type constraints. + + A function should at least pass the first step to be eligible for the + nearest matching. + + Args: + args: The input arguments organized in PyTorch inputs way. + kwargs: The input keyword arguments organized in PyTorch inputs way. + + Returns: + True if the inputs match the requirements, False otherwise. + """ + + # NOTE: OnnxFunction does not have the same function signature as the original + # PyTorch operator. We need to separate the input/attributes from the arguments. + ( + function_inputs, + function_attributes, + ) = self._separate_input_attributes_from_arguments( + self.param_schema, + args, + kwargs, + fill_defaults=True, # fill defaults for optional arguments to match + ) + # NOTE: 1. Check if the input number and attribute names match the + # OpSchema. If it's not, we know the function is not eligible to be a perfect + # match, nor a nearest match. + # We use is_perfect_match to postpone the return value to the end + # of the function, as we want to log all the mismatch info. + is_perfect_match = True + if len(function_inputs) != len(self.op_schema.inputs): + logger.info( + "Actual %d vs expected %d", + len(function_inputs), + len(self.op_schema.inputs), + ) + logger.info("The function is not a nearest match candidate.") + is_perfect_match = False + + if set(function_attributes) != set(self.attributes): + logger.info("The function is not a nearest match candidate.") + is_perfect_match = False + + # If it's already not a perfect match, we can return False directly. Further + # checking is only for the functions that are eligible for nearest match. + if not is_perfect_match: + return False + + # NOTE: 2. The dtypes of inputs and attributes should be in the + # type constraints of the OpSchema. If they are not, we know the function is not + # eligible to be a perfect match, but can be a nearest match candidate. + for schema_input, torch_input in zip(self.op_schema.inputs, function_inputs): + torch_input_compatible_types = _find_onnx_data_type(torch_input) + allowed_types = self.type_constraints[schema_input.type_str] + if not allowed_types.intersection(torch_input_compatible_types) and not any( + fx_type_utils.is_optional_onnx_dtype_str(onnx_type_str) + for onnx_type_str in allowed_types + ): + # If torch_input_compatible_types isn't in allowed_types + # of this input defined in the OpSchema, we know the function + # and the input are not compatible + logger.info( + "Actual %s vs\nExpected %s", + torch_input_compatible_types, + allowed_types, + ) + is_perfect_match = False + + for attribute_name, attribute in function_attributes.items(): + if not self._match_onnx_attribute_type(attribute_name, attribute): + # If the attribute type of the OpSchema and the attribute type don't match, + # we know the function and the input are not compatible + logger.info( + "Actual %s vs\nExpected %s", + type(attribute), + self.attributes[attribute_name].type, + ) + is_perfect_match = False + + # NOTE: This is still a candidate for nearest match, as it only mismatches attributes on dtype. + self._record_matching_score(function_inputs, function_attributes) + logger.info("match score: %d", self.match_score) + return is_perfect_match + + def _match_onnx_attribute_type( + self, + attribute_name: str, + attribute: fx_type_utils.Argument | onnxscript_graph_building.TorchScriptTensor, + is_sequence: bool = False, + ) -> bool: + if isinstance(attribute, (int, float, bool, str)): + attribute_onnx_type = fx_type_utils.from_python_type_to_onnx_attribute_type( + type(attribute), is_sequence=is_sequence + ) + if attribute_onnx_type != self.attributes[attribute_name].type: + return False + # If the attribute is an empty list, we don't know the type of the list + # so it's a mismatch + elif isinstance(attribute, (list, tuple)) and attribute: + return self._match_onnx_attribute_type( + attribute_name, attribute[0], is_sequence=True + ) + else: + # NOTE: Unrecognized attribute type + return False + return True + + def _record_matching_score( + self, + inputs: Sequence[ + fx_type_utils.TensorLike | str | int | float | bool | list | complex | None + ], + attributes: dict[str, fx_type_utils.Argument], + ): + """Calculate the inputs matching score of the OpSchema requirements to find the nearest match. + + Only the functions which have the same number of inputs and attributes as the + OpSchema are eligible to be a nearest match candidate. Thus, we don't need to + check the length of inputs and attributes here, and only check the types of + inputs and attributes. + + How the matchsing score is calculated: + score += 1 if one input/attribute type is in the type constraints. + + Limitations: + None/NoeType/[] could result in zero matches, and the same score of overloads. + + Args: + inputs: The input arguments. + attributes: The input keyword arguments. + + Returns: + True if the inputs match the requirements, False otherwise. + """ + self._matching_score = 0 + # If they have different length of arguments, the score would be lower to those + # functions which have the same length of arguments. + for schema_input, torch_input in zip(self.op_schema.inputs, inputs): + torch_input_compatible_types = _find_onnx_data_type(torch_input) + allowed_types = self.type_constraints[schema_input.type_str] + if allowed_types.intersection(torch_input_compatible_types): + # If torch_input_compatible_types is in allowed_types + # of this input defined in the OpSchema, we know the function + # and the input are compatible + self._matching_score += 1 + # NOTE: The penalty is applied to those functions which have different attributes. + for attribute_name, attribute_proto in self.attributes.items(): + attribute = attributes[attribute_name] + attribute_onnx_type = fx_type_utils.from_python_type_to_onnx_attribute_type( + type(attribute) + ) + if attribute_onnx_type != attribute_proto.type: + # If the attribute type of the OpSchema and the attribute type don't match, + # we know the function and the input are not compatible + self._matching_score -= 1 + + # NOTE: Referenced from onnxscript internal function. + # Importing this function makes the code less robust, as it is not a public API. + + def _separate_input_attributes_from_arguments( + self, + param_schemas: Sequence[onnxscript.values.ParamSchema], + args: Sequence[ + fx_type_utils.TensorLike | str | int | float | bool | list | complex | None + ], + kwargs: dict[str, fx_type_utils.Argument], + fill_defaults: bool = True, + ) -> tuple[list[Any], dict[str, Any]]: + """Separate Python args and kwargs into ONNX inputs and attributes. + + Extra_kwargs are ignored if their values are None. For example, if the + OpSchema has an attribute "rounding_mode" and the caller provides + "rounding_mode=None", the attribute "rounding_mode" will not be included + in the returned attributes when the OnnxFunction signature doesn't have + "rounding_mode" as an attribute. + + Args: + param_schemas: The parameter schemas of an Op or a OnnxFunction. + args: The Python positional arguments supplied by the caller. + kwargs: The Python keyword arguments supplied by the caller. + fill_defaults: Whether to fill the default values for attributes. + + Returns: + A tuple of two elements: + - A list of ONNX inputs. + - An dictionary of ONNX attribute names and values. + + Raises: + TypeError: When allow_extra_kwargs is False and there are unknown kwargs. + TypeError: When a required input is not provided. + """ + # args, kwargs and param_schemas should be all in order + # user may not specify all inputs or attributes + + import onnx + + onnx_inputs: list[Any] = [] + onnx_attributes: dict[str, Any] = {} + # NOTE: We need to copy kwargs because we will mutate it + copy_kwargs = kwargs.copy() + for i, param in enumerate(param_schemas): + if param.is_variadic_input: + # Exhaust all remaining args + onnx_inputs.extend(args[i:]) + args = [] + continue + if i < len(args): + if param.is_input: + onnx_inputs.append(args[i]) + else: + onnx_attributes[param.name] = args[i] + elif param.name in copy_kwargs: + if param.is_input: + # Move the input from kwargs to inputs + onnx_inputs.append(copy_kwargs[param.name]) + copy_kwargs.pop(param.name) + else: + onnx_attributes[param.name] = copy_kwargs[param.name] + elif ( + param.is_attribute + and self.attributes[param.name].default_value.type + != onnx.AttributeProto.UNDEFINED # type: ignore[attr-defined] + ): + # User did not provide the attribute + if fill_defaults: + onnx_attributes[param.name] = param.default + # optional input + elif param.is_input: + if fill_defaults: + onnx_inputs.append(None) + + # NOTE: Pick up extra kwargs if it's not None. None is not expected + # as an attribute value in torchlib. + for k, v in copy_kwargs.items(): + if k not in onnx_attributes and v is not None: + onnx_attributes[k] = v + return onnx_inputs, onnx_attributes + + +def _is_arg_with_complex_dtype(arg: fx_type_utils.Argument) -> bool: + """Check if the node has complex dtype recursively.""" + if ( + isinstance(arg, torch.fx.Node) + and "val" in arg.meta + and isinstance(arg.meta["val"], torch.Tensor) + and torch.is_complex(arg.meta["val"]) + ): + return True + elif isinstance(arg, list): + for item in arg: + return _is_arg_with_complex_dtype(item) + return False + + +def _find_onnx_data_type( + torch_input: fx_type_utils.TensorLike + | str + | int + | float + | bool + | list + | tuple + | complex + | None, +) -> set[str]: + """Convert inputs data type from torch acceptable dtype to the compatible onnx dtype string.""" + if ( + isinstance(torch_input, fx_type_utils.TensorLike) + and torch_input.dtype is not None + ): + return fx_type_utils.from_torch_dtype_to_onnx_dtype_str(torch_input.dtype) + if isinstance(torch_input, (int, float, bool, str, complex)): + return fx_type_utils.from_torch_dtype_to_onnx_dtype_str(type(torch_input)) + if isinstance(torch_input, (list, tuple)) and torch_input: # [Tensor, Tensor] + the_first_non_none_item = next( + (item for item in torch_input if item is not None), None + ) + set_dtype = _find_onnx_data_type(the_first_non_none_item) + if any(isinstance(input, fx_type_utils.TensorLike) for input in torch_input): + # NOTE: Any Tensor involved in a list would make it a seq(tensor(onnx_type)) + return {f"seq({dtype})" for dtype in set_dtype} + else: + # constant list of non-tensor type + return set_dtype + if ( + torch_input is None + or ( + isinstance(torch_input, fx_type_utils.TensorLike) + and torch_input.dtype is None + ) + or (isinstance(torch_input, (list, tuple)) and not torch_input) + ): + # NOTE: None, No dtype, and empty list are edge cases, we allow it to be any type to relax the type check + # seq(tensor) also goes to here, as it is not supported in torchscript, and it would be None in this case. + return set() + + raise RuntimeError(f"Unknown input type from input: {torch_input}") diff --git a/torch/onnx/_internal/fx/passes/__init__.py b/torch/onnx/_internal/fx/passes/__init__.py index eff83563a5a08..b161d89bcf9f9 100644 --- a/torch/onnx/_internal/fx/passes/__init__.py +++ b/torch/onnx/_internal/fx/passes/__init__.py @@ -1,6 +1,26 @@ +<<<<<<< HEAD from .type_promotion import InsertTypePromotion __all__ = [ "InsertTypePromotion", +======= +from .decomp import Decompose +from .functionalization import Functionalize, RemoveInputMutation +from .modularization import Modularize +from .readability import RestoreParameterAndBufferNames +from .type_promotion import InsertTypePromotion +from .virtualization import MovePlaceholderToFront, ReplaceGetAttrWithPlaceholder + + +__all__ = [ + "Decompose", + "InsertTypePromotion", + "Functionalize", + "Modularize", + "MovePlaceholderToFront", + "RemoveInputMutation", + "RestoreParameterAndBufferNames", + "ReplaceGetAttrWithPlaceholder", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] diff --git a/torch/onnx/_internal/fx/passes/_utils.py b/torch/onnx/_internal/fx/passes/_utils.py new file mode 100644 index 0000000000000..a7b05786ab171 --- /dev/null +++ b/torch/onnx/_internal/fx/passes/_utils.py @@ -0,0 +1,114 @@ +# mypy: allow-untyped-defs +"""Common utility functions for FX passes. + +These functions should NOT be directly invoked outside of `passes` package. +""" + +from __future__ import annotations + +import collections +import re +from typing import Callable + +import torch.fx +import torch.fx.traceback as fx_traceback + + +def wrap_graph_module_for_node_meta_preservation( + graph_module: torch.fx.GraphModule, +) -> Callable: + """Wrap a GraphModule with contexts to preserve node meta information, such as stacktrace info. + + This is typically useful before calling `make_fx`. Without this wrapper, the + stacktrace information will be lost afterwards. + """ + + def wrapped(*args): + with fx_traceback.preserve_node_meta(): + return torch.fx.Interpreter(graph_module).run(*args) + + return wrapped + + +def _get_node_base_name(node_name: str) -> tuple[str, int | None]: + pattern = r"(.*)\.(\d+)" + match = re.match(pattern, node_name) + if match is not None: + base_name, count_str = match.groups() + return base_name, int(count_str) + return node_name, None + + +def set_node_name( + node: torch.fx.Node, + new_name: str, + name_to_node_cache: dict[str, torch.fx.Node], +): + """Safely set the unique name of a node. + + If the new name is already taken by another node, the name of the other node will be + updated. If `new_name` is a string of format f"{base_name}.{count}", where `count` + is an integer, the other node will be renamed as f"{base_name}.{count+1}". If not, + the other node will be renamed as "{new_name}.1". This function will iteratively + update the names until there is no conflict. + + ``name_to_node_cache`` is required as an argument to avoid recomputation. The caller + is responsible for ensuring the cache is accurate and in sync with the owning module + of the node. The values in the cache will be updated accordingly. + + Args: + node: The node to update. + new_name: The new name to use. + name_to_node_cache: A cache of node names to nodes. + """ + node_name_to_set = collections.deque([(node, new_name)]) + + while node_name_to_set: + node, new_name = node_name_to_set.pop() + if new_name in name_to_node_cache and name_to_node_cache[new_name] != node: + base_name, postfix_count = _get_node_base_name(new_name) + if postfix_count is None: + postfix_count = 0 + node_name_to_set.append( + (name_to_node_cache[new_name], f"{base_name}.{postfix_count + 1}") + ) + node.name = new_name + name_to_node_cache[new_name] = node + + +def replace_placeholder_name_and_target( + module: torch.fx.GraphModule, reference_module: torch.fx.GraphModule +): + """Replace the argument names in module with those in reference_module. + + This function assumes the two modules have the same signature structure. + The caller is responsible for ensuring this. Otherwise, the behavior of this + function is undefined. This function only does minimal sanity check that the two + modules have the same number of arguments. + + Name conflicts between new names and existing node names in the graph are handled. + Check the documentation of :func:`set_node_name` for more details. + + Raises: + RuntimeError: If the two modules have different number of arguments. + """ + placeholders = [node for node in module.graph.nodes if node.op == "placeholder"] + reference_placeholders = [ + node for node in reference_module.graph.nodes if node.op == "placeholder" + ] + + if len(placeholders) != len(reference_placeholders): + raise RuntimeError( + "The two modules have different number of arguments. " + f"module: {len(placeholders)}, reference_module: {len(reference_placeholders)}" + ) + + name_to_node: dict[str, torch.fx.Node] = {} + for node in module.graph.nodes: + name_to_node[node.name] = node + + for placeholder, reference_placeholder in zip(placeholders, reference_placeholders): + placeholder.target = reference_placeholder.target + set_node_name(placeholder, reference_placeholder.name, name_to_node) + + module.recompile() diff --git a/torch/onnx/_internal/fx/passes/decomp.py b/torch/onnx/_internal/fx/passes/decomp.py new file mode 100644 index 0000000000000..1573264d6fc76 --- /dev/null +++ b/torch/onnx/_internal/fx/passes/decomp.py @@ -0,0 +1,87 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import contextlib +from typing import Callable, TYPE_CHECKING + +import torch +import torch._ops +from torch._dispatch import python as python_dispatch +from torch._subclasses import fake_tensor +from torch.fx.experimental import proxy_tensor +from torch.onnx._internal.fx import _pass +from torch.onnx._internal.fx.passes import _utils + + +if TYPE_CHECKING: + from collections.abc import Mapping + + import torch.fx + + +class Decompose(_pass.Transform): + def __init__( + self, + module: torch.fx.GraphModule, + decomposition_table: Mapping[torch._ops.OpOverload, Callable], + enable_dynamic_axes: bool, + allow_fake_constant: bool | None = False, + ): + super().__init__(module) + self.decomposition_table = decomposition_table + self.enable_dynamic_axes = enable_dynamic_axes + self.allow_fake_constant = allow_fake_constant + + def _run(self, *args, **kwargs) -> torch.fx.GraphModule: + assert not kwargs, "kwargs is not supported in Decompose." + + # To preserve stack trace info after `make_fx`. + module = _utils.wrap_graph_module_for_node_meta_preservation(self.module) + + # fake mode use static size to trace the size of tensors. while symbolic + # mode generates aten::sym_size to dynamically trace the size of tensors. + + # e.g. fake mode: + # view: f32[3, 5, 20] = torch.ops.aten.view.default(x, [3, 5, 20]) + + # e.g. symbolic mode: + # sym_size = torch.ops.aten.sym_size(x, 0) + # sym_size_1 = torch.ops.aten.sym_size(x, 1) + # sym_size_2 = torch.ops.aten.sym_size(x, 2) + # sym_size_3 = torch.ops.aten.sym_size(x, 3) + # mul = sym_size_2 * sym_size_3; sym_size_2 = sym_size_3 = None + # view: f32[3, 5, 20] = torch.ops.aten.view.default(x, [sym_size, sym_size_1, mul]) + + # Mimic `torch._dynamo.export(aten_graph=True)` behavior in invoking `make_fx`. + # TODO: May need revisit for user fake mode export + dynamic shape scenario. + fake_mode: fake_tensor.FakeTensorMode | None = self.fake_mode + maybe_fake_args = self._maybe_fakefy_args(fake_mode, *args) + if fake_mode is not None: + # Using existing fake mode as context, signal `make_fx` that it does not need + # to create a new fake mode by passing tracing_mode as "real". + tracing_mode = "real" + else: + # Existing fake mode not found, signal `make_fx` to create one. + fake_mode = contextlib.nullcontext() # type: ignore[assignment] + tracing_mode = "symbolic" if self.enable_dynamic_axes else "fake" + + # Apply decomposition table to the input graph. + assert fake_mode is not None # for mypy + with ( + fake_tensor.unset_fake_temporarily(), + python_dispatch.enable_python_dispatcher(), + fake_mode, + ): + decomposed_module = proxy_tensor.make_fx( + module, + decomposition_table=self.decomposition_table, + tracing_mode=tracing_mode, + _allow_non_fake_inputs=True, + _allow_fake_constant=bool(self.allow_fake_constant), + )(*maybe_fake_args) + + # Rename placeholder targets to match the original module's signature since + # We don't want to map forward(x, y, z) to forward(arg0, arg1, arg2). + _utils.replace_placeholder_name_and_target(decomposed_module, self.module) + + return decomposed_module diff --git a/torch/onnx/_internal/fx/passes/functionalization.py b/torch/onnx/_internal/fx/passes/functionalization.py new file mode 100644 index 0000000000000..fd8d3c7d48ac5 --- /dev/null +++ b/torch/onnx/_internal/fx/passes/functionalization.py @@ -0,0 +1,152 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import contextlib +from typing import Callable + +import torch +import torch._ops +import torch.func +import torch.fx +from torch._subclasses import fake_tensor +from torch.fx.experimental import proxy_tensor +from torch.onnx._internal.fx import _pass +from torch.onnx._internal.fx.passes import _utils +from torch.utils import _pytree as pytree + + +class Functionalize(_pass.Transform): + """Functionalize a GraphModule. + + This pass utilizes ``functionalization`` utility of ``torch._functorch`` to convert + a GraphModule into a functional form. The two main functionalities are (copied from + its documentations): + + * ``functionalization`` removes (intermediate) mutations and aliasing from a + function, while preserving the function's semantics. + + * ``functionalization`` also removes mutations (and views) that were performed + on function inputs. However to preserve semantics, functionalize will "fix up" the + mutations after the transform has finished running, by detecting if any tensor inputs + "should have" been mutated, and copying the new data back to the inputs if necessary. + For example, consider:: + + def fn(a, b): + a.add_(b) + return a + + For a call like `fn(x, y)`, the variable `x` outside is also mutated. Hence just + functionalizing is not enough for preserving the original semantics. A "special" + input mutation step needs to be inserted at the end.:: + + # After functionalization, without input mutation "fix up". + # This is not semantically the same. The variable outside the function call that + # was passed in as `a` is not mutated. + def fn(a, b): + new_a = a + b + return new_a + + # Functionalization with input mutation "fix up" that preserves semantics. + def fn(a, b): + new_a = a + b + + # Copying the new data back to the inputs + a.copy_(new_a) + + return new_a + + For ONNX inference, it is recommended to run ``RemoveInputMutation`` after this pass. + ``RemoveInputMutation`` removes the "fix up" nodes that were added by ``Functionalize``, + which are not needed for ONNX inference. + """ + + def __init__( + self, + module: torch.fx.GraphModule, + enable_dynamic_axes: bool, + allow_fake_constant: bool | None = False, + ): + super().__init__(module) + self.enable_dynamic_axes = enable_dynamic_axes + self.allow_fake_constant = allow_fake_constant + + def _functionalize(self, function: Callable) -> Callable: + # Working around a dispatcher issue with `torch.func.functionalize` when used + # together with `make_fx`. + # Ref: https://github.com/pytorch/pytorch/issues/99774#issuecomment-1527949391 + def wrapped(*inputs): + inputs_functional = pytree.tree_map_only( + torch.Tensor, torch._to_functional_tensor, inputs + ) + torch._enable_functionalization(reapply_views=True) + try: + out = function(*inputs_functional) + finally: + torch._disable_functionalization() + + flat_inputs_functional = pytree.tree_leaves(inputs_functional) + for input_functional in flat_inputs_functional: + if isinstance(input_functional, torch.Tensor): + torch._sync(input_functional) + pytree.tree_map(torch._sync, out) + out_unwrapped = pytree.tree_map(torch._from_functional_tensor, out) + return out_unwrapped + + return wrapped + + def _run(self, *args) -> torch.fx.GraphModule: + # To preserve stack trace info after `make_fx`. + module = _utils.wrap_graph_module_for_node_meta_preservation(self.module) + + functionalized_callable = self._functionalize(module) + + # Mimic `torch._dynamo.export(aten_graph=True)` behavior in invoking `make_fx`. + # TODO: May need revisit for user fake mode export + dynamic shape scenario. + fake_mode: fake_tensor.FakeTensorMode | None = self.fake_mode + maybe_fake_args = self._maybe_fakefy_args(fake_mode, *args) + if fake_mode is not None: + # Using existing fake mode as context, signal `make_fx` that it does not need + # to create a new fake mode by passing tracing_mode as "real". + tracing_mode = "real" + else: + # Existing fake mode not found, signal `make_fx` to create one. + fake_mode = contextlib.nullcontext() # type: ignore[assignment] + tracing_mode = "symbolic" if self.enable_dynamic_axes else "fake" + + assert fake_mode is not None # for mypy + with fake_tensor.unset_fake_temporarily(), fake_mode: + graph_module = proxy_tensor.make_fx( + functionalized_callable, + decomposition_table={}, + tracing_mode=tracing_mode, + _allow_non_fake_inputs=True, + _allow_fake_constant=bool(self.allow_fake_constant), + )(*maybe_fake_args) + + # Rename placeholder targets to match the original module's signature since + # We don't want to map forward(x, y, z) to forward(arg0, arg1, arg2). + _utils.replace_placeholder_name_and_target(graph_module, self.module) + + return graph_module + + +class RemoveInputMutation(_pass.Transform): + """Remove `aten.copy_.default` nodes that mutate module inputs. + + This pass is recommended to be used after ``Functionalization`` pass. + ``Functionalization`` pass adds `aten.copy_.default` nodes to the graph + when it detects mutations to inputs. These nodes are not needed for ONNX export + for inference. They could be useful for training. + """ + + def _run(self, *args) -> torch.fx.GraphModule: + for node in reversed(self.module.graph.nodes): + if ( + node.op == "call_function" + and node.target == torch.ops.aten.copy_.default + and len(node.users) == 0 + and isinstance(node.args[0], torch.fx.Node) + and node.args[0].op == "placeholder" + ): + self.module.graph.erase_node(node) + return self.module diff --git a/torch/onnx/_internal/fx/passes/modularization.py b/torch/onnx/_internal/fx/passes/modularization.py new file mode 100644 index 0000000000000..18a424826bfef --- /dev/null +++ b/torch/onnx/_internal/fx/passes/modularization.py @@ -0,0 +1,857 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import abc +import collections +import copy +import operator +from typing import Any, Final, TYPE_CHECKING + +import torch +import torch.fx +from torch.onnx._internal.fx import _pass +from torch.utils import _pytree as pytree + + +if TYPE_CHECKING: + from collections.abc import Generator, Iterator, Sequence + + +_FX_TRACER_NN_MODULE_META_TYPE = tuple[str, type] +"""Legacy type of item from `node.meta["nn_module_stack"].items()` produced by FX symbolic tracer.""" +_FX_TRACER_NN_MODULE_STACK_META_TYPE = collections.OrderedDict +"""Legacy type of `node.meta["nn_module_stack"]` produced by FX symbolic tracer.""" + +_DYNAMO_NN_MODULE_META_TYPE = tuple[str, tuple[str, type]] +"""Type of item from `node.meta["nn_module_stack"].items()` produced by FX dynamo tracer.""" +_DYNAMO_NN_MODULE_STACK_META_TYPE = dict[str, _DYNAMO_NN_MODULE_META_TYPE] +"""Type of `node.meta["nn_module_stack"]` produced by FX dynamo tracer.""" + + +class _ModuleMeta: + """Meta information about a module. + + This class is used to represent the module information in a more structured way. + It parses raw module information from a single item from + `node.meta["nn_module_stack"].items()`. + + See the uses of `from_raw_meta`, `from_fx_tracer_produced_raw_meta`, and + `from_dynamo_produced_raw_meta` for how to create an instance. + + Attributes: + _module_class: The class of the module. E.g. `torch.nn.module.sparse.Embedding`. + _module_name: The name of the module. E.g. `L__self___h_1_mlp_c_proj`. + _raw_meta: The raw meta '(module_name, node.meta["nn_module_stack"][module_name])'. + """ + + _module_class: Final[type | str | None] # type: ignore[misc] + _module_name: Final[str] # type: ignore[misc] + _raw_meta: Final[tuple[Any, Any]] # type: ignore[misc] + + def __init__( + self, + module_name: str, + module_class: type | str | None, + raw_meta: tuple[Any, Any], + ): + self._module_name = module_name + self._module_class = module_class + self._raw_meta = raw_meta + + @property + def module_display_name(self) -> str: + """The display name of the module. + + E.g. `h_1_mlp_c_proj`. + """ + # E.g., from 'L__self___h_1_mlp_c_proj' to 'h_1_mlp_c_proj'. + name = self.module_name + name = name.removeprefix("L__self___") + return name + + @property + def qualified_module_class_name(self) -> str: + """Qualified name of the module class. + + E.g. `torch_nn_module_sparse_Embedding`. + """ + if self._module_class is None: + return "" + mod_cls = self._module_class + if isinstance(mod_cls, type): + mod_cls = mod_cls.__module__ + "." + mod_cls.__qualname__ + return mod_cls.replace(".", "_") + + @property + def module_class_name(self) -> str: + """Name of the module class. + + E.g. `Embedding`. + """ + if self._module_class is None: + return "" + if isinstance(self._module_class, type): + return self._module_class.__name__ + return self._module_class + + @property + def module_name(self) -> str: + """Name of the module. + + E.g. `L__self___h_1_mlp_c_proj`. + """ + return self._module_name + + @property + def raw_meta(self) -> tuple[Any, Any]: + """Returns the raw module meta data. + + I.e. (module_name, node.meta['nn_module_stack'][module_name]). + """ + return self._raw_meta + + def __eq__(self, other: object, /) -> bool: + if not isinstance(other, _ModuleMeta): + return False + return ( + self._module_name == other._module_name + and self._module_class == other._module_class + ) + + def __hash__(self) -> int: + return hash((self._module_name, self._module_class)) + + def __repr__(self) -> str: + return f"ModuleMeta(name={self._module_name}, class={self._module_class})" + + @classmethod + def create_root(cls) -> _ModuleMeta: + """Create an empty module meta representing root module.""" + return _ModuleMeta("", None, ("", None)) + + @classmethod + def from_fx_tracer_produced_raw_meta( + cls, raw_meta: _FX_TRACER_NN_MODULE_META_TYPE + ) -> _ModuleMeta: + """Create a module meta from raw meta produced by FX symbolic tracer.""" + module_name, module_class = raw_meta + return _ModuleMeta(module_name, module_class, raw_meta) + + @classmethod + def from_dynamo_produced_raw_meta( + cls, raw_meta: _DYNAMO_NN_MODULE_META_TYPE + ) -> _ModuleMeta: + """Create a module meta from raw meta produced by FX dynamo tracer.""" + module_name, (_qualified_name, module_class) = raw_meta + return _ModuleMeta(module_name.split("@")[0], module_class, raw_meta) + + @classmethod + def from_raw_meta( + cls, + raw_meta: _FX_TRACER_NN_MODULE_META_TYPE | _DYNAMO_NN_MODULE_META_TYPE, + ) -> _ModuleMeta: + if ( + isinstance(raw_meta, tuple) + and len(raw_meta) == 2 + and isinstance(raw_meta[1], type) + ): + # Trying to do `instance(raw_meta, _FX_TRACER_NN_MODULE_META_TYPE)` + return _ModuleMeta.from_fx_tracer_produced_raw_meta(raw_meta) + if ( + isinstance(raw_meta, tuple) + and len(raw_meta) == 2 + and isinstance(raw_meta[1], tuple) + ): + # Trying to do `instance(raw_meta, _DYNAMO_NN_MODULE_META_TYPE)` + return _ModuleMeta.from_dynamo_produced_raw_meta(raw_meta) + raise TypeError( + f"Unknown type of raw meta item from node.meta['nn_module_stack'].items(): {type(raw_meta)}" + ) + + +class _ModuleStackMeta: + """Meta information about the module call stack. + + This class is used to represent the module call stack information in a more + structured way. It parses raw module stack information from `node.meta["nn_module_stack"]`. + + Example of raw module stack information: + + If produced by dynamo: + + { + 'L__self___h_1': ( + "L['self'].h[1]", + + ), + 'L__self___h_1_attn': ( + "L['self'].h[1].attn", + + ) + } + + If produced by fx.symbolic_trace: + + { + 'h.1': , + 'h.1.attn': + } + """ + + _module_stack: Final[list[_ModuleMeta]] # type: ignore[misc] + + def __init__( + self, + nn_module_stack_meta: _FX_TRACER_NN_MODULE_STACK_META_TYPE + | _DYNAMO_NN_MODULE_STACK_META_TYPE + | None, + is_exported_program: bool = True, + ): + self._module_stack = [] + if nn_module_stack_meta is None: + return + raw_meta = copy.copy(nn_module_stack_meta) + for item in raw_meta.items(): + # If produced by torch.export.export, there is another call stack layer + # that we need to skip + if is_exported_program: + is_exported_program = False + continue + self.push(_ModuleMeta.from_raw_meta(item)) # type: ignore[arg-type] + + def __len__(self) -> int: + return len(self._module_stack) + + def __getitem__(self, index: int) -> _ModuleMeta: + return self._module_stack[index] + + def __iter__(self) -> Iterator[_ModuleMeta]: + return iter(self._module_stack) + + def is_empty_or_root(self) -> bool: + return len(self._module_stack) == 0 + + def top(self) -> _ModuleMeta: + """Returns the top module meta in the stack. I.e., the meta for leaf module. + + Example: + + Consider the following module stack: + + stack = [GPT, block1, Attention_1, MLP] + + stack.top() == MLP + """ + if self.is_empty_or_root(): + return _ModuleMeta.create_root() + return self._module_stack[-1] + + def is_superset_of( + self, + module_stack: _ModuleStackMeta, + ) -> bool: + """Determines if self is a superset of the provided module stack. + + I.e., If self includes all elements from the provided module stack, plus additional + elements on top. If self is empty or root, this method always return False. + + Example: + + Consider the following module stack: + + stack_1 = [GPT, block1, Attention_1, MLP] + stack_2 = [GPT, block1] + + stack_1.is_superset_of(stack_2) == True + stack_2.is_superset_of(stack_1) == False + + stack_3 = [GPT, block2, Attention_1] + + stack_1.is_superset_of(stack_3) == False + stack_3.is_superset_of(stack_1) == False + """ + if self.is_empty_or_root(): + return False + + if module_stack.is_empty_or_root() is None: + return True + + if len(self) <= len(module_stack): + return False + + for i, parent_key in enumerate(module_stack): + if self[i] != parent_key: + return False + + return True + + def push(self, module_meta: _ModuleMeta) -> None: + """Pushes a module meta to the stack.""" + self._module_stack.append(module_meta) + + def __eq__(self, other: object, /) -> bool: + if not isinstance(other, _ModuleStackMeta): + return False + return self._module_stack == other._module_stack + + @property + def raw_meta(self) -> dict[str, tuple[str, type]] | None: + """Returns the raw module stack meta data, i.e. node.meta['nn_module_stack'].""" + return { + module_meta.raw_meta[0]: module_meta.raw_meta[1] + for module_meta in self._module_stack + } + + def __repr__(self) -> str: + return f"ModuleStackMeta({self._module_stack})" + + @property + def module_display_name(self) -> str: + """Returns the module display name of the top module.""" + return self.top().module_display_name + + @property + def qualified_module_class_name(self) -> str: + """Returns the qualified module class name of the top module.""" + return self.top().qualified_module_class_name + + @property + def module_class(self) -> type | str | None: + """Returns the module class of the top module.""" + return self.top()._module_class + + +def _module_stack_meta_from_node( + node: torch.fx.Node, is_exported_program: bool = False +) -> _ModuleStackMeta: + return _ModuleStackMeta( + node.meta.get("nn_module_stack"), is_exported_program=is_exported_program + ) + + +def _get_unique_module_name(module_names: dict[str, int], module_name: str) -> str: + module_names.setdefault(module_name, 0) + module_names[module_name] += 1 + return f"{module_name}_{module_names[module_name]}" + + +class _IRNode(abc.ABC): + """Base class for IR nodes. + + IR nodes are used for Modularize pass only. They add a layer of abstraction on top of + torch.fx.Node. + + [NOTE: Modularize Pass Implementation] + The main job of the pass is to group `fx.Node`s that belong to the same `nn.Module` + forward call, and then create `call_module` node and sub `fx.GraphModule` from them. + Each `fx.Node` possesses an `nn_module_stack` meta data that contains information + about the module call stack. See `_ModuleStackMeta` for examples. + + Analysis step + ------------- + + Each module call is identified by a set of base stack layers. For each module call, + the pass creates a `_ModuleNode` and groups the sequence of nodes that shares the + same base stack layers. + + For example, + + stack_of_node_0 = [GPT, block0] + stack_of_node_1 = [GPT, block1] + stack_of_node_2 = [GPT, block1, Attention1, MLP] + stack_of_node_3 = [GPT, block1, Attention1] + stack_of_node_4 = [GPT, block2] + + All nodes belong to the `GPT` module call, since they share the base stack layers [GPT]. + [node_1, node_2, node_3] are grouped for `GPT.block1`, because they share the base + stack layers [GPT, block1]. And [node_2, node_3] for `GPT.block1.Attention1`, [node_0] + for `GPT.block0`, and [node_4] for `GPT.block2` respectfully. + + After the analysis step, a hierarchical representation is generated. + + For above example, the representation is: + + _ModuleNode(GPT) + _ModuleNode(block0) + _LeafNode(node_0) + _ModuleNode(block1) + _LeafNode(node_1) + _ModuleNode(Attention1) + _ModuleNode(MLP) + _LeafNode(node_2) + _LeafNode(node_3) + _ModuleNode(block2) + _LeafNode(node_4) + + Construction step + ----------------- + + The second step is to build the actual `call_module` node and the sub `fx.GraphModule`. + This is done recursively from the leaf `_ModuleNode` to the root. + + For example, the first submodule to be built is `GPT.block1.Attention1.MLP`. Below pair + is generated from `_ModuleNode(MLP)`. + + fx.GraphModule(GPT.block1.Attention1.MLP) + graph: + node_2 + + new_mlp_node = `call_module[GPT.block1.Attention1.MLP](...)` + + Next, the `GPT.block1.Attention1` submodule is built. Below is generated from + `_ModuleNode(Attention1)`. + + fx.GraphModule(GPT.block1.Attention1) + graph: + new_mlp_node + node_3 + + new_attention1_node = `call_module[GPT.block1.Attention1](...)` + + Until every submodule is built, the new modularized `fx.GraphModule` is generated. + + Alternatives + ------------ + + The current algorithm adopts a top down approach. A bottom up approach is similar. + In contrast to these two, an alternative flat order approach is also possible, where + each node is traversed and copied to the corresponding submodule. + + The advantage of the current approach lies in the encapsulation of the fx.GraphModule + construction for each individual submodule within a single `build_module` method, which + can be called separately once the analysis phase is completed, making debugging more + convenient. + + Regarding construction step, an alternative implementation is to utilize `fx.Interpreter` + for traversing all the nodes under the flattened root module and copying the nodes + into their respective submodule under construction. This approach is not adopted because + + 1. It uses the flat order approach discussed above. This means one cannot individually + construct a submodule and examine it while debugging. + + 2. The graph execution functionality of `fx.Interpreter` is not necessary for the + purpose of this pass. Ignoring that, `fx.Interpreter.run` achieves the same effect + as a for loop over all the nodes. + """ + + @property + @abc.abstractmethod + def stack_meta(self) -> _ModuleStackMeta: + """The module stack meta data associated with this node.""" + ... + + @property + @abc.abstractmethod + def stack_trace(self) -> str | None: + """The stack trace associated with this node.""" + ... + + +class _ModuleNode(_IRNode): + """Representing a sequence of fx.Nodes to be formed into a fx.GraphModule. + + This class encapsulates metadata and provides building block methods to construct this + layered abstraction from a sequence of flat fx.Nodes. + + Attributes: + - _stack_meta: Metadata of the module stack. + - _nodes: List of IR nodes in the module. + - _reference_root_module: Reference to the root flat fx.GraphModule instance. + """ + + def __init__( + self, reference_root_module: torch.fx.GraphModule, stack_meta: _ModuleStackMeta + ): + self._stack_meta = stack_meta + self._nodes: list[_IRNode] = [] + self._reference_module = reference_root_module + + @property + def stack_meta(self) -> _ModuleStackMeta: + return self._stack_meta + + @property + def stack_trace(self) -> str | None: + assert self._nodes + return self._nodes[0].stack_trace + + def __str__(self) -> str: + return f"ModuleNode({self._stack_meta})" + + def is_same_module_as(self, node: _IRNode) -> bool: + """Determines if the provided node pertains to the same module as this node.""" + return self.stack_meta == node.stack_meta + + def is_parent_module_of(self, node: _IRNode) -> bool: + """Determines if this node represents a parent module of the provided node.""" + return node.stack_meta.is_superset_of(self.stack_meta) + + def add_leaf_node(self, leaf_node: _LeafNode) -> None: + """Adds a leaf node to the module. + + The leaf node must belong to the same or a child module. This method will recursively + construct _ModuleNode instance based on the stack_meta information of the leaf node. + """ + if self.is_same_module_as(leaf_node) or leaf_node.fx_op == "call_module": + self._nodes.append(leaf_node) + elif leaf_node.fx_op == "placeholder": + # Although the original placeholder has empty nn_module_stack, the placeholder lifted + # from get_attr nodes by exported program has their original nn_module_stack. Here + # we need to avoid them building submodule. + self._nodes.append(leaf_node) + elif self.is_parent_module_of(leaf_node): + # This node belongs in a submodule. + # Check if the last node is a submodule and if it is the parent of this node. + last_node = self._nodes[-1] if self._nodes else None + if isinstance(last_node, _ModuleNode) and ( + last_node.is_parent_module_of(leaf_node) + or last_node.is_same_module_as(leaf_node) + ): + # This node belongs to the last_node. + last_node.add_leaf_node(leaf_node) + else: + # Create a new SubmoduleNode for the immediate child module of the current + # module. The leaf node may be a grandchild of the current module. + # Example: + # self.stack_meta = [A, B, C] + # leaf_node.stack_meta = [A, B, C, D, E, F] + # Create a new ModuleNode with stack_meta = [A, B, C, D] and add leaf_node to it. + stack_meta = copy.deepcopy(self.stack_meta) + stack_meta.push(leaf_node.stack_meta[len(self.stack_meta)]) + last_node = _ModuleNode( + self._reference_module, + stack_meta, + ) + self._nodes.append(last_node) + last_node.add_leaf_node(leaf_node) + else: + raise AssertionError( + f"Node {leaf_node} ({leaf_node.stack_meta}) does not belong to module " + f"{self._stack_meta}." + ) + + def fx_nodes(self) -> Generator[torch.fx.Node, None, None]: + """Returns an iterator for the sequence of fx nodes this instance holds.""" + for node in self._nodes: + if isinstance(node, _ModuleNode): + yield from node.fx_nodes() + else: + assert isinstance(node, _LeafNode) + yield node.fx_node + + def module_inputs(self) -> Sequence[torch.fx.Node]: + """Extract module inputs from the sequence of fx nodes this instance holds. + + All node args that are produced by nodes outside of the module are considered module + inputs. The order of returned module inputs is the same as the their use order. + + ### Known limitations + + The original ordering of module inputs is not preserved. There is no meta information + to be found from the `fx.GraphModule` that can be used to recover the original ordering. + + Returns: + Sequence of module inputs. + """ + nodes = list(self.fx_nodes()) + assert len(nodes) > 0, "Cannot extract module inputs from empty nodes." + module_inputs: dict[torch.fx.Node, None] = {} + node_set: set[torch.fx.Node] = set(nodes) + + def _extract_arg_if_node_outside_module(arg: Any): + if isinstance(arg, torch.fx.Node) and arg not in node_set: + module_inputs[arg] = None + + for node in nodes: + pytree.tree_map(_extract_arg_if_node_outside_module, node.args) + pytree.tree_map(_extract_arg_if_node_outside_module, node.kwargs) + return list(module_inputs.keys()) + + def module_outputs(self) -> Sequence[torch.fx.Node]: + """Extract module outputs from the sequence of fx nodes this instance holds. + + All nodes that are used by nodes outside of the module are considered module + outputs. The order of returned module outputs is the same as the their creation order. + + ### Known limitations + + The original ordering of module outputs is not preserved. There is no meta information + to be found from the `fx.GraphModule` that can be used to recover the original ordering. + + Returns: + Sequence of module outputs. + """ + nodes = list(self.fx_nodes()) + assert len(nodes) > 0, "Cannot extract module inputs from empty nodes." + # Need ordered set. Emulate with dict. + module_outputs: dict[torch.fx.Node, None] = {} + node_set: set[torch.fx.Node] = set(nodes) + + for node in nodes: + if any(user not in node_set for user in node.users): + module_outputs[node] = None + return list(module_outputs.keys()) + + def build_module(self, module_names: dict[str, int]) -> torch.fx.GraphModule: + """ + Constructs the fx.GraphModule for this node, registering submodules as necessary. + + Args: + module_names: A dictionary of module names and their counts. This is used to + generate unique module names for submodules. This should be an empty + dictionary when the method is called on a root module. + """ + module_class_name = self._stack_meta.qualified_module_class_name + fx_graph = torch.fx.Graph() + copy_env: dict[torch.fx.Node, torch.fx.Node] = {} + + def _arg_transform(node: torch.fx.Node) -> torch.fx.Node: + return copy_env[node] + + ref_inputs = self.module_inputs() + for node in ref_inputs: + copy_env[node] = fx_graph.placeholder(node.name, node.type) + copy_env[node].meta = copy.copy(node.meta) + + for ir_node in self._nodes: + if isinstance(ir_node, _LeafNode): + fx_node = ir_node.fx_node + copy_env[fx_node] = fx_graph.node_copy( + fx_node, arg_transform=_arg_transform + ) + continue + + assert isinstance(ir_node, _ModuleNode) + # Create fx.GraphModule for child submodule. + submodule = ir_node.build_module(module_names) + ref_submodule_inputs = ir_node.module_inputs() + ref_submodule_outputs = ir_node.module_outputs() + unique_submodule_name = _get_unique_module_name( + module_names, ir_node.stack_meta.module_display_name + ) + # Link the newly generated sub fx.GraphModule with the root reference module. + # This step is essential to meet the needs of the subsequent fx.GraphModule initialization + # for the fx.GraphModule being created by this method. + # The initialization of fx.GraphModule will replicate all necessary attributes from a reference + # fx.GraphModule for the fx.Graph. While the root reference module possesses all + # parameters and buffers, it does not include the newly created sub fx.GraphModule. + # Therefore, it's necessary to register it under the root reference at this stage. + self._reference_module.add_submodule(unique_submodule_name, submodule) + + # create call_module fx.Node + submodule_node = fx_graph.call_module( + unique_submodule_name, + tuple(_arg_transform(node) for node in ref_submodule_inputs), + ) + if len(ref_submodule_outputs) > 1: + # Module node has multiple output. Create 'getitem' node for each output. + submodule_node.meta["val"] = tuple( + ref_output.meta.get("val") for ref_output in ref_submodule_outputs + ) + for i, ref_output in enumerate(ref_submodule_outputs): + getitem_node = fx_graph.call_function( + operator.getitem, + args=(submodule_node, i), + type_expr=ref_output.type, + ) + getitem_node.meta = copy.copy(ref_output.meta) + # Make a copy for "nn_module_stack" since the current module will be + # popped from the stack for this 'getitem' node. + getitem_node.meta["nn_module_stack"] = copy.copy( + ref_output.meta["nn_module_stack"] + ) + # The node is associated with the parent module. + getitem_node.meta["nn_module_stack"].popitem() + copy_env[ref_output] = getitem_node + else: + # Module node has single output. Use module node directly. + copy_env[ref_submodule_outputs[0]] = submodule_node + submodule_node.meta = copy.copy(ref_submodule_outputs[0].meta) + + # Update meta for new call_module node. + if (stack_trace := ir_node.stack_trace) is not None: + submodule_node.meta["stack_trace"] = stack_trace + raw_module_stack_meta = ir_node.stack_meta.raw_meta + assert raw_module_stack_meta is not None + submodule_node.meta["nn_module_stack"] = copy.copy(raw_module_stack_meta) + # The node is associated with the parent module. + submodule_node.meta["nn_module_stack"].popitem() + + new_nodes = fx_graph.nodes + # Skip if the last node is already 'output'. This is the case for root module. + # Otherwise create an 'output' node for the inferred outputs. + if next(iter(reversed(new_nodes))).op != "output": + ref_submodule_outputs = self.module_outputs() + new_outputs = [copy_env[ref_output] for ref_output in self.module_outputs()] + node = fx_graph.output( + new_outputs[0] if len(new_outputs) == 1 else new_outputs + ) + + graph_module = torch.fx.GraphModule( + self._reference_module, fx_graph, module_class_name + ) + if (module_class := self._stack_meta.module_class) is not None: + graph_module.meta["onnx"] = _pass.GraphModuleOnnxMeta( + _pass.PackageInfo.from_python_class(module_class) + ) + return graph_module + + +class _LeafNode(_IRNode): + """Representing a single fx.Node.""" + + def __init__(self, node: torch.fx.Node, is_exported_program: bool = False): + self._node = node + self._stack_meta = _module_stack_meta_from_node( + node, is_exported_program=is_exported_program + ) + + @property + def fx_op(self) -> str: + """Syntax sugar for self.fx_node.op.""" + return self._node.op + + @property + def fx_node(self) -> torch.fx.Node: + """Returns the fx.Node this instance represents.""" + return self._node + + @property + def stack_meta(self) -> _ModuleStackMeta: + """Returns the module stack meta data associated with this node.""" + return self._stack_meta + + @property + def stack_trace(self) -> str | None: + """Returns the stack trace associated with this node.""" + return self.fx_node.meta.get("stack_trace") + + def __str__(self) -> str: + return f"LeafNode({self._node})" + + +class Modularize(_pass.Transform): + """Transforms a flattened `fx.GraphModule` into a modular structure. + + In the flattened `fx.GraphModule`, each `nn.Module` forward call has been traced as + a sequence of `fx.Node`s. All these `fx.Node`s are flattened and reside in the same + `fx.GraphModule`. `fx.GraphModule` could be from `torch.export.ExportedProgram` or + directly generated by `torch._dynamo.export` with torch.nn.Module. + + This pass generates a new `fx.GraphModule`. It groups the flattened `fx.Node`s that belong + to the same `nn.Module` forward call into a sub `fx.GraphModule`. It then replaces the + sequence of flattened `fx.Node`s with a single `call_module` node, which is linked with + the sub `fx.GraphModule` by `node.target`. The sub `fx.GraphModule` is registered as a + submodule of the new `fx.GraphModule`. + + The process is done based on information from the `nn_module_stack` metadata of each node, i.e. + `node.meta["nn_module_stack"]`. For more implementation details, see [NOTE: Modularize Pass Implementation]. + + An fx submodule under this context can typically be interpreted in three different ways: + + 1. As an embodiment of an nn.Module class, which is considered stateless. + Its execution path can vary depending on the configuration of module initialization, + which should also be part of the inputs. + + 2. As a representation of an nn.Module instance. It maintains the state initialized in the module. + The execution path can vary based on actual input data. + + 3. As a captured call of an nn.Module instance, where the execution path + is set. + + The generality decreases along this list. Within the scope of this function, the pass + creates fx submodules according to the third interpretation. + + The first interpretation is the most general case. It requires complex analysis and additional + metadata and code information to construct its general form. Consider an example nn.Module + that generates arbitrary submodules based on an initialization configuration file. It's impractical + to extract this logic for the generated fx submodule to function with arbitrary configuration. + + The second interpretation demands less analysis and is sturdier than the + first. In most use cases, it's equivalent to the third. It only differs in exceptional situations + where a complex nn.Module instance is called multiple times, each with a different set of inputs + leading to a unique execution branching path. + + The third interpretation is the most specific scenario. It necessitates the minimum + analysis and creates the most stable representation. The drawback is that it + generates more redundancy than the other two methods. If needed, a subsequent post-processing + pass can be applied to consolidate completely identical functions and reduce duplication. + + ### Known constraints + Two successive calls to the same module instance will be conflated. They are indistinguishable. + This is due to limitations of the current fx metadata "nn_module_stack". + + [NOTE: Modularize pass ordering] + This pass groups fx nodes into subgraphs that reside within the `call_module` fx node. + Other fx passes (including some outside the exporter) might not recognize `call_module`. + They may assume that all nodes are flattened. Hence it is recommended to invoke this pass + as the last pre onnx export fx pass. If not for this consideration, this operation could + potentially be relocated anywhere earlier in the pipeline. + + Example: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX) + >>> import torch + >>> from torch.onnx._internal.fx import passes + >>> + >>> class CustomModule(torch.nn.Module): + >>> def __init__(self) -> None: + >>> super().__init__() + >>> self.embedding = torch.nn.Embedding(10, 32) + >>> self.relu = torch.nn.ReLU() + >>> + >>> def forward(self, x): + >>> out = self.embedding(x) + >>> out = self.relu(out) + >>> return out + >>> + >>> class TestModule(torch.nn.Module): + >>> def __init__(self) -> None: + >>> super().__init__() + >>> self.layer = CustomModule() + >>> self.linear = torch.nn.Linear(32, 10) + >>> + >>> def forward(self, x): + >>> out = self.layer(x) + >>> out = self.linear(out) + >>> return out + >>> + >>> gm, _ = torch._dynamo.export(TestModule(), aten_graph=True)( + ... torch.tensor([0, 1, 2]) + ... ) + >>> gm.print_readable() + + >>> gm = passes.Modularize( + ... gm, + ... ).run() + >>> gm.print_readable() + + """ + + def __init__( + self, + module: torch.fx.GraphModule, + is_exported_program: bool = False, + ): + super().__init__(module) + self.module = module + self.is_exported_program = is_exported_program + + def _run(self) -> torch.fx.GraphModule: + # DCE to remove unused nodes. + # If a submodule is unused, it is hard to analyze which nodes constitutes the submodule + # outputs. But since it is unused, we can just remove it. + self.module.graph.eliminate_dead_code() + + reference_module = torch.fx.GraphModule(self.module, self.module.graph) + root_module_node = _ModuleNode( + reference_module, + _ModuleStackMeta( + nn_module_stack_meta=None, is_exported_program=self.is_exported_program + ), + ) + for fx_node in self.module.graph.nodes: + root_module_node.add_leaf_node( + _LeafNode(fx_node, is_exported_program=self.is_exported_program) + ) + return root_module_node.build_module({}) diff --git a/torch/onnx/_internal/fx/passes/readability.py b/torch/onnx/_internal/fx/passes/readability.py new file mode 100644 index 0000000000000..a14d07b9aa197 --- /dev/null +++ b/torch/onnx/_internal/fx/passes/readability.py @@ -0,0 +1,130 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import torch +from torch.onnx._internal.fx import _pass + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +logger = logging.getLogger(__name__) + + +class RestoreParameterAndBufferNames(_pass.Transform): + """Restore parameter and buffer names from original nn.module. + + This pass is useful for readability of the exported ONNX graph. It restores the + parameter and buffer names from the original nn.module. For example, if the original + nn.module has a parameter named `root.linear.0.weight`, and the parameter is renamed to + `_param_constant9` by FX, this pass will rename it back. + + This pass must be run after `Decompose` pass. Because this pass is expected to be called on + `fx.GraphModule` produced by `proxy_tensor.make_fx`, where all parameters and buffers + are registered at root level. + """ + + def __init__( + self, + fx_module: torch.fx.GraphModule, + original_nn_module: torch.nn.Module, + ): + super().__init__(fx_module) + self.original_nn_module = original_nn_module + + def _rename_param_and_buffer( + self, + nodes: Sequence[torch.fx.Node], + new_name: str, + ) -> None: + """Rename the parameter/buffer and replace corresponding nodes with new nodes of updated target.""" + assert len(nodes) > 0, "`nodes` cannot be empty" + assert len({node.target for node in nodes}) == 1, ( + "`nodes` must all have same `target`" + ) + old_name = nodes[0].target + assert isinstance(old_name, str), f"Expected str, got type({old_name})" + # Parameter/buffer name cannot contain "." + normalized_name = new_name.replace(".", "/") + attr_value = getattr(self.module, old_name) + setattr(self.module, normalized_name, attr_value) + delattr(self.module, old_name) + for node in nodes: + with self.module.graph.inserting_before(node): + new_node = self.module.graph.get_attr(normalized_name) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + self.module.graph.erase_node(node) + logger.info( + "Renamed 'self.%s' to 'self.%s', " + "normalized from original parameter name '%s'.", + old_name, + normalized_name, + new_name, + ) + + def _run(self, *args, **kwargs) -> torch.fx.GraphModule: + """Restore parameter and buffer names from original module. + + For each `get_attr` node, if the target is a str representing a parameter or buffer + under `self.module`, we rename the parameter or buffer to its original name. + The parameters and buffers between `self.module` and `self.original_nn_module` refer + to the same objects, allowing us to use it as key to retrieve the original name. + """ + assert len(args) == 0, "RestoreParameterAndBufferNames does not take any args" + assert len(kwargs) == 0, ( + "RestoreParameterAndBufferNames does not take any kwargs" + ) + # state_to_readable_name[parameter/buffer] returns the original readable name of + # the parameter/buffer. E.g., "self.linear.weight". + state_to_readable_name: dict[torch.nn.Parameter | torch.Tensor, str] = {} + state_to_readable_name.update( + {v: k for k, v in self.original_nn_module.named_parameters()} + ) + state_to_readable_name.update( + {v: k for k, v in self.original_nn_module.named_buffers()} + ) + + # old_name_to_nodes[old_name] returns a tuple of (nodes, new_name) + # where `nodes` is a list of `get_attr` nodes with `old_name` as `target` and + # `new_name` is the new readable name. + old_name_to_nodes: dict[str, tuple[list[torch.fx.Node], str]] = {} + + for node in self.module.graph.nodes: + if node.op == "get_attr": + assert isinstance(node.target, str), ( + f"Expected str, got type({node.target})" + ) + if node.target.find(".") != -1: + raise RuntimeError( + f"Unexpected target {node.target} in get_attr, found '.' in target. " + f"All parameters and buffers are expected to be registered at root level, " + f"i.e., self.module. " + ) + if node.target in old_name_to_nodes: + # We have already processed this parameter/buffer. + old_name_to_nodes[node.target][0].append(node) + continue + attr_value = getattr(self.module, node.target) + if ( + isinstance(attr_value, (torch.nn.Parameter, torch.Tensor)) + and attr_value in state_to_readable_name + ): + readable_name = state_to_readable_name[attr_value] + old_name_to_nodes[node.target] = ([node], readable_name) + continue + + logger.info( + "Cannot find readable name for self.%s: %s. The name is unchanged.", + node.target, + type(attr_value), + ) + + for nodes, new_name in old_name_to_nodes.values(): + self._rename_param_and_buffer(nodes, new_name) + + return self.module diff --git a/torch/onnx/_internal/fx/passes/virtualization.py b/torch/onnx/_internal/fx/passes/virtualization.py new file mode 100644 index 0000000000000..504dea1d84247 --- /dev/null +++ b/torch/onnx/_internal/fx/passes/virtualization.py @@ -0,0 +1,96 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +from torch.onnx._internal.fx import _pass + + +if TYPE_CHECKING: + import torch.fx + + +class MovePlaceholderToFront(_pass.Transform): + """This pass move all placeholder nodes to the front of the graph node list. + + In torch.fx.Graph, placeholder is a special assignment node. If it's not + executed in the beginning, it could overwrite values computed by upstream + nodes. + """ + + def _run(self, *args, **kwargs) -> torch.fx.GraphModule: + graph_module = self.module + graph = graph_module.graph + placeholders = [] + first_not_placeholder = None + for node in graph.nodes: + if node.op == "placeholder": + placeholders.append(node) + if first_not_placeholder is None and node.op != "placeholder": + first_not_placeholder = node + if first_not_placeholder is None: + return graph_module + for placeholder in placeholders: + first_not_placeholder.prepend(placeholder) + return graph_module + + +class ReplaceGetAttrWithPlaceholder(_pass.Transform): + """Replace get_attr with placeholder. + + The parameters and buffers accessed by the original get_attr are returned; + they are useful when creating random inputs for the modified graph_module. + """ + + _replaced_attrs: tuple[torch.Tensor, ...] | None + + @property + def replaced_attrs(self) -> tuple[torch.Tensor, ...]: + """The list of replaced weight tensors.""" + assert self._replaced_attrs is not None, ( + "Must run ReplaceGetAttrWithPlaceholder first" + ) + return self._replaced_attrs + + def _run(self, *args, **kwargs) -> torch.fx.GraphModule: + graph_module = self.module + graph = graph_module.graph + replaced_attrs: list[torch.Tensor] = [] + for node in graph.nodes: + if node.op == "get_attr": + replaced_attr: torch.Tensor | None = None + # get_attr could retrieve either parameter or buffer, so + # we need to try both. + try: + replaced_attr = graph_module.get_parameter(node.target) + except AttributeError: + # It's possible that model author use buffer instead of + # parameter to store trainable weights. In this case, + # 1. get_parameter will throw something like + # AttributeError: `bias` is not an nn.Parameter. + # 2. get_buffer should work. + replaced_attr = graph_module.get_buffer(node.target) + + # Reassign op type so that get_attr node becomes placeholder node. + node.op = "placeholder" + # The target name in placeholder must be a valid Python identifier. + # Thus, we replace, e.g., "module.submodule.weight" with + # "module_submodule_weight". + node.target = node.target.replace(".", "_") + # Default value is None. This is needed as long as the "graph_module" + # has optional inputs. Assume the original forward signature is + # def forward(self, x, y=None) + # and the replaced get_attr node has target "z". Then, the modified + # signature should be + # def forward(self, x, y=None, z=None) + # Without the following line, the signature will be + # def forward(self, x, y=None, z) + # , which is not valid Python code. + node.args = (None,) + + replaced_attrs.append(replaced_attr) + + self._replaced_attrs = tuple(replaced_attrs) + + return graph_module diff --git a/torch/onnx/_internal/fx/patcher.py b/torch/onnx/_internal/fx/patcher.py new file mode 100644 index 0000000000000..6c9724e9f5a73 --- /dev/null +++ b/torch/onnx/_internal/fx/patcher.py @@ -0,0 +1,143 @@ +# mypy: allow-untyped-defs +import copy +import functools +from typing import TYPE_CHECKING, Union + +import torch + + +if TYPE_CHECKING: + import io + + +# TODO: Remove after https://github.com/huggingface/safetensors/pull/318 +@functools.cache +def has_safetensors_and_transformers(): + try: + # safetensors is not an exporter requirement, but needed for some huggingface models + import safetensors # type: ignore[import] # noqa: F401 + import transformers # type: ignore[import] # noqa: F401 + from safetensors import torch as safetensors_torch # noqa: F401 + + return True + except ImportError: + return False + + +class ONNXTorchPatcher: + """Context manager to temporarily patch PyTorch during FX-to-ONNX export. + + This class is a collection of "patches" required by FX-to-ONNX exporter. + + This context overrides several torch functions to support symbolic + export of large scale models. + + torch.load: + This function is patched to record the files PyTorch stores model + parameters and buffers. Downstream FX-to-ONNX exporter can create + initializers from these files. + torch.fx._symbolic_trace._wrapped_methods_to_patch: + This list is extended with (torch.Tensor, "__getitem__") so that + weight[x, :, y] becomes exportable with torch.fx.symbolic_trace. + safetensors.torch.load_file: + This function is patched to allow safetensors to be loaded within + FakeTensorMode. Remove after https://github.com/huggingface/safetensors/pull/318 + + Search for ONNXTorchPatcher in test_fx_to_onnx_with_onnxruntime.py for + example usage. + + TODO: Should this really be a global patcher? Can we make it a local patcher? + A reason for splitting this into several patchers is to patch one part of the code + as a collateral damage of patching another part of the code. For example, we + for tracing model with torch._dynamo.export, we don't need to patch + `torch.fx._symbolic_trace._wrapped_methods_to_patch` + """ + + def __init__(self) -> None: + # List of file paths processed by torch.load. + self.paths: list[Union[str, io.BufferedIOBase]] = [] + + def torch_load_wrapper(f, *args, **kwargs): + # Record path for later serialization into ONNX proto + self.paths.append(f) + # Then, call the original torch.load. + return self.torch_load(f, *args, **kwargs) + + # Original version of torch.load. + self.torch_load = torch.load + + # Wrapper or modified version of torch functions. + self.torch_load_wrapper = torch_load_wrapper + + if has_safetensors_and_transformers(): + import safetensors + import transformers + + def safetensors_load_file_wrapper(filename, device="cpu"): + # Record path for later serialization into ONNX proto + self.paths.append(filename) + result = {} + with safetensors.torch.safe_open( # type: ignore[attr-defined] + filename, framework="pt", device=device + ) as f: + for k in f.keys(): + fake_mode = torch._guards.detect_fake_mode() + if not fake_mode: + result[k] = f.get_tensor(k) + else: + empty_tensor = f.get_slice(k) + result[k] = torch.empty( + tuple(empty_tensor.get_shape()), + dtype=safetensors.torch._getdtype( + empty_tensor.get_dtype() + ), + ) + return result + + self.safetensors_torch_load_file = safetensors.torch.load_file + self.safetensors_torch_load_file_wrapper = safetensors_load_file_wrapper + self.transformers_modeling_utils_safe_load_file = ( + transformers.modeling_utils.safe_load_file + ) + + def __enter__(self): + torch.load = self.torch_load_wrapper + + self.torch_fx__symbolic_trace__wrapped_methods_to_patch = ( + torch.fx._symbolic_trace._wrapped_methods_to_patch + ) + desired_wrapped_methods = copy.deepcopy( + torch.fx._symbolic_trace._wrapped_methods_to_patch + ) + if (torch.Tensor, "__getitem__") not in desired_wrapped_methods: + # Adding `__getitem__` to the patching list will make tensor indexing traceable via + # torch.fx.symbolic_trace. Otherwise, `tensor[x, :, y]` cannot be traced. + # This happens because `__getitem__` is neither under torch domain nor an aten operator, + # so the patching (or similar Proxy-generating mechanism) doesn't happen automatically. + # Note that torch.fx.symbolic_trace defines FX_PATCH_GETITEM environment variable for + # enabling the line below for patching. + desired_wrapped_methods.append((torch.Tensor, "__getitem__")) + torch.fx._symbolic_trace._wrapped_methods_to_patch = desired_wrapped_methods + + if has_safetensors_and_transformers(): + import safetensors + import transformers + + safetensors.torch.load_file = self.safetensors_torch_load_file_wrapper + transformers.modeling_utils.safe_load_file = ( + self.safetensors_torch_load_file_wrapper + ) + + def __exit__(self, exc_type, exc_value, traceback): + torch.load = self.torch_load + torch.fx._symbolic_trace._wrapped_methods_to_patch = ( + self.torch_fx__symbolic_trace__wrapped_methods_to_patch + ) + if has_safetensors_and_transformers(): + import safetensors + import transformers + + safetensors.torch.load_file = self.safetensors_torch_load_file + transformers.modeling_utils.safe_load_file = ( + self.transformers_modeling_utils_safe_load_file + ) diff --git a/torch/onnx/_internal/fx/registration.py b/torch/onnx/_internal/fx/registration.py new file mode 100644 index 0000000000000..e855f98f044f6 --- /dev/null +++ b/torch/onnx/_internal/fx/registration.py @@ -0,0 +1,87 @@ +"""Module for handling ATen to ONNX functions registration.""" + +from __future__ import annotations + +import dataclasses +from typing import TYPE_CHECKING + + +# We can only import onnx from this module in a type-checking context to ensure that +# 'import torch.onnx' continues to work without having 'onnx' installed. We fully +# 'import onnx' inside of dynamo_export (by way of _assert_dependencies). +if TYPE_CHECKING: + import types + + import onnxscript # type: ignore[import] + + import torch._ops + + +@dataclasses.dataclass(frozen=True, eq=True) +class ONNXFunction: + """A wrapper of onnx-script function. + + op_full_name: The qualified name of the function. In the form of '::.'. + onnx_function: The onnx-script function from torchlib. + is_custom: Whether the function is a custom function. + is_complex: Whether the function is a function that handles complex valued inputs. + + """ + + onnx_function: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction + op_full_name: str + is_custom: bool = False + is_complex: bool = False + + +@dataclasses.dataclass(frozen=True, eq=True) +class OpName: + """A class representing an operator name in internal ONNX converter.""" + + namespace: str + op_name: str + overload: str + + @classmethod + def from_name_parts( + cls, namespace: str, op_name: str, overload: str | None = None + ) -> OpName: + # NOTE: in PyTorch, the overload could be unprovided to indicate the + # default overload + if overload is None or overload == "": + overload = "default" + return cls(namespace, op_name, overload) + + @classmethod + def from_qualified_name(cls, qualified_name: str) -> OpName: + """When the name is ::[.]""" + namespace, opname_overload = qualified_name.split("::") + op_name, *overload = opname_overload.split(".", 1) + overload = overload[0] if overload else "default" + return cls(namespace, op_name, overload) + + @classmethod + def from_op_overload(cls, op_overload: torch._ops.OpOverload) -> OpName: + return cls.from_qualified_name(op_overload.name()) + + @classmethod + def from_builtin_function( + cls, builtin_function: types.BuiltinFunctionType + ) -> OpName: + """From a builtin function, e.g. operator.add, math.ceil, etc, get the OpName. + + FX graph uses built-in functions to caculate sympy expression. This function + is used to get the OpName from a builtin function. + + Args: + builtin_function (types.BuiltinFunctionType): operator.add, math.ceil, etc. + + Returns: + OpName: _description_ + """ + op = builtin_function.__name__ # add, sub, etc. + module = builtin_function.__module__ # _operators or math + return cls.from_qualified_name(module + "::" + op) + + def qualified_name(self) -> str: + return f"{self.namespace}::{self.op_name}.{self.overload}" diff --git a/torch/onnx/_internal/fx/serialization.py b/torch/onnx/_internal/fx/serialization.py new file mode 100644 index 0000000000000..cda71e465758d --- /dev/null +++ b/torch/onnx/_internal/fx/serialization.py @@ -0,0 +1,250 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import io +import logging +import os +from typing import IO, TYPE_CHECKING + +import torch +from torch.onnx import _type_utils as jit_type_utils + + +if TYPE_CHECKING: + import onnx + + from torch.types import FileLike + +log = logging.getLogger(__name__) + + +def _create_tensor_proto_with_external_data( + tensor: torch.Tensor, + name: str, + location: str, + basepath: str, + dtype_override: onnx.TypeProto | None = None, # type: ignore[name-defined] +) -> onnx.TensorProto: # type: ignore[name-defined] + """Create a TensorProto with external data from a PyTorch tensor. + The external data is saved to os.path.join(basepath, location). + + Args: + tensor: Tensor to be saved. + name: Name of the tensor (i.e., initializer name in ONNX graph). + location: Relative location of the external data file + (e.g., "/tmp/initializers/weight_0" when model is "/tmp/model_name.onnx"). + basepath: Base path of the external data file (e.g., "/tmp/external_data" while model must be in "/tmp"). + + + Reference for ONNX's external data format: + How to load? + https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L187 + How to save? + https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L43 + How to set ONNX fields? + https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L88 + """ + # FIXME: Avoid importing onnx into torch.onnx. + import onnx + + scalar_type = ( + jit_type_utils.JitScalarType.from_onnx_type( + dtype_override.tensor_type.elem_type + ) + if dtype_override is not None + else jit_type_utils.JitScalarType.from_dtype(tensor.dtype) + ) + + # Checkpoints can be stored with a different dtype as the model expects because + # the user script can explicitly cast the original type to something or maybe + # PyTorch's type promotion might do it + if dtype_override is not None and scalar_type.dtype() != tensor.dtype: + tensor = tensor.to(scalar_type.dtype()) + + tensor_proto = onnx.TensorProto() # type: ignore[attr-defined] + tensor_proto.name = name + tensor_proto.data_type = scalar_type.onnx_type() # type: ignore[assignment] + + tensor_proto.dims.extend(tensor.shape) + tensor_proto.data_location = onnx.TensorProto.EXTERNAL # type: ignore[attr-defined] + + # Settings for saving one tensor per file. + # Offset is zero because there is no other tensor in the same file. + key_value_pairs = { + "location": location, + "offset": 0, + "length": tensor.untyped_storage().nbytes(), + } + for k, v in key_value_pairs.items(): + entry = tensor_proto.external_data.add() + entry.key = k + entry.value = str(v) + + # Actual path to write content of tensor. + external_data_file_path = os.path.join(basepath, location) + if os.path.exists(external_data_file_path): + os.remove(external_data_file_path) + + # Create external data's folder if not exists. + external_data_dir_path = os.path.dirname(external_data_file_path) + if not os.path.exists(external_data_dir_path): + # if the demo_folder directory is not present + # then create it. + os.makedirs(external_data_dir_path) + + # Create a fresh file. + with open(external_data_file_path, "xb") as data_file: + # No need to call "seek" because offset is 0. + # data_file.seek(0) + # Write tensor content to the file. + data_file.write(tensor.numpy(force=True).tobytes()) + + return tensor_proto + + +def _convert_safetensors_to_torch_format(safetensors_file): + # It this function is called, safetensors is guaranteed to exist + # because the HF model with safetensors was already loaded and exported to ONNX + from safetensors import safe_open # type: ignore[import-not-found, import-untyped] + + tensors = {} + with safe_open(safetensors_file, framework="pt", device="cpu") as f: # type: ignore[attr-defined] + for k in f.keys(): + tensors[k] = f.get_tensor(k).cpu() + return tensors + + +# TODO: generalize to allow more checkpoints formats (torch or gguf) +def save_model_with_external_data( + basepath: str, + model_location: str, + initializer_location: str, + torch_state_dicts: tuple[dict | FileLike, ...], + onnx_model: onnx.ModelProto, # type: ignore[name-defined] + rename_initializer: bool = False, +) -> None: + """Load PyTorch tensors from files and add to "onnx_model" as external initializers. + + Output files: + ONNX model file path: + ONNX initializer folder: os.path.join(basepath, initializer_location) + + After running this function, you can do + ort_sess = onnxruntime.InferenceSession(os.path.join(basepath, model_location)) + to execute the model. + + Arguments: + basepath: Base path of the ONNX external data file (e.g., "/path/to/large_model/"). + model_location: Relative location of the ONNX model file. + E.g., "model.onnx" so that the model file is saved to + "/model.onnx". + initializer_location: Relative location of the ONNX initializer folder. + E.g., "initializers" so that the initializers are saved to + "/initializers/". + Note: When initializers are >2GB, must be the same as `model_location`. + torch_state_dicts: Dictionaries or files which contain PyTorch tensors to be saved + as ONNX initializers. For non-dict arguments, `torch.load` will be used to load them from file-like objects. + onnx_model: ONNX model to be saved with external initializers. + If an input name matches a tensor loaded from "torch_state_dicts", + the tensor will be saved as that input's external initializer. + rename_initializer: Replaces "." by "_" for all ONNX initializer names. + Not needed by the official torch.onnx.dynamo_export. This is a hack + for supporting `FXSymbolicTracer` tracer with fake tensor mode. + In short, `FXSymbolicTracer` lifts FX parameters (self.linear_weight) + as inputs (`def forward(self, linear_weight)`) and therefore, `.` cannot be used. + """ + # FIXME: Avoid importing onnx into torch.onnx. + import onnx + + initializers_to_be_deleted = {} # Using dict because it is **ordered** + existing_initializers = { + k.name: idx for idx, k in enumerate(onnx_model.graph.initializer) + } + onnx_input_names = {input.name for input in onnx_model.graph.input} + for el in torch_state_dicts: + if isinstance(el, dict): + # Useful for when state_dict is loaded with torch.load(..., mmap=True, map_location="cpu") by the user + # Using torch.save wouldn't leverage mmap, leading to higher memory usage + state_dict = el + else: + if isinstance(el, (str, os.PathLike)) and os.fspath(el).endswith( + ".safetensors" + ): + state_dict = _convert_safetensors_to_torch_format(el) + else: + try: + # Loads checkpoint using memory-map on CPU to support really large models + # The underlying torch.UntypedStorage is memory mapped, so state_dict is lazy loaded + state_dict = torch.load(el, map_location="cpu", mmap=True) + except (RuntimeError, ValueError) as e: + if "mmap can only be used with files saved with" in str(e) or ( + isinstance(el, (io.IOBase, IO)) + and el.readable() + and el.seekable() + ): + log.warning( + "Failed to load the checkpoint with memory-map enabled, retrying without memory-map." + "Consider updating the checkpoint with mmap by using torch.save() on PyTorch version >= 1.6." + ) + if isinstance(el, (io.IOBase, IO)): + el.seek(0) # torch.load from `try:` has read the file. + state_dict = torch.load(el, map_location="cpu") + else: + raise e + + for name, tensor in state_dict.items(): + if rename_initializer: + # Basically, "transformer.attention.self.query.weight" is mapped + # to "transformer_attention_self_query_weight" for mimicking the + # name-modifying code in FX-to-ONNX exporter. + # See function _replace_get_attr_with_placeholder for details. + name = name.replace(".", "_") + + # This block tries to match the onnx initializer name with torch parameter/buffer + # e.g. A pytorch buffer 'transformer.h.0.attn.bias' can be named 'h.0.attn.bias' in a ONNX initializer + # For each PyTorch tensor name loaded by torch.load, + # 1. Search its best match in ONNX model. E.g., the match of + # "transformer_attention_weight" could be "attention_weight". + # 2. Set "tensor" as the initializer of the matched ONNX input. + # E.g., "tensor" is stored as the initializer of "attention_weight". + # Step 1 is required because sometimes, tensor names are stored with prefix the dictionary + # loaded by torch.load. + if name in onnx_input_names: + # Same input name shouldn't be matched again + onnx_input_names.remove(name) + else: + for onnx_input_name in onnx_input_names: + if onnx_input_name.endswith(name) or name.endswith(onnx_input_name): + # Find a match. Change name to the matched ONNX input name, so that we + # create initializer with the right ONNX name. + name = onnx_input_name + onnx_input_names.remove(onnx_input_name) + break + + relative_tensor_file_path = os.path.join(initializer_location, name) + # Create one file per tensor. + # tensor_proto.raw_data is stored to external file at + # os.path.join(basepath, relative_tensor_file_path). + model_input_types = {k.name: k.type for k in onnx_model.graph.input} + + # Mark for deletion - a replacement will be appended next + if name in existing_initializers: + initializers_to_be_deleted[existing_initializers[name]] = name + tensor_proto = _create_tensor_proto_with_external_data( + tensor, + name, + relative_tensor_file_path, + basepath, + model_input_types.pop(name, None), + ) + # Add the tensor_proto to the ONNX model as an initializer with external data. + onnx_model.graph.initializer.append(tensor_proto) + # Remove old duplicated initializers, if any. delete in desc order to not invalidate deletion indices + initializers_to_be_deleted = dict( + sorted(initializers_to_be_deleted.items(), reverse=True) + ) + for idx in initializers_to_be_deleted.keys(): + del onnx_model.graph.initializer[idx] + + # model_location should be a pure file name such as "file_name.onnx", not "folder/file_name.onnx". + onnx.save(onnx_model, os.path.join(basepath, model_location)) # type: ignore[attr-defined] diff --git a/torch/onnx/_internal/io_adapter.py b/torch/onnx/_internal/io_adapter.py new file mode 100644 index 0000000000000..6c414e8d54e78 --- /dev/null +++ b/torch/onnx/_internal/io_adapter.py @@ -0,0 +1,652 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +from typing import Any, Callable, TYPE_CHECKING +from typing_extensions import Protocol, runtime_checkable + +import torch +import torch.export as torch_export +from torch.utils import _pytree as pytree + + +if TYPE_CHECKING: + import inspect + from collections.abc import Mapping, Sequence + + +@runtime_checkable +class InputAdaptStep(Protocol): + """A protocol that defines a step in the input adapting process. + + The input adapting process is a sequence of steps that are applied to the + PyTorch model inputs to transform them into the inputs format expected by the + exported ONNX model. Each step takes the PyTorch model inputs as arguments and + returns the transformed inputs. + + This serves as a base formalized construct for the transformation done to model + input signature by any individual component in the exporter. + """ + + def apply( + self, + model_args: Sequence[Any], + model_kwargs: Mapping[str, Any], + model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, + ) -> tuple[Sequence[Any], Mapping[str, Any]]: ... + + +class InputAdapter: + """A class that adapts the PyTorch model inputs to exported ONNX model inputs format.""" + + def __init__(self, steps: list[InputAdaptStep] | None = None): + self._steps = steps or [] + + def append_step(self, step: InputAdaptStep) -> None: + """Appends a step to the input adapt steps. + + Args: + step: The step to append. + """ + self._steps.append(step) + + def apply( + self, + *model_args, + model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, + **model_kwargs, + ) -> Sequence[int | float | bool | str | torch.Tensor | torch.dtype | None]: + """Converts the PyTorch model inputs to exported ONNX model inputs format. + + Args: + model_args: The PyTorch model inputs. + model: The PyTorch model. + model_kwargs: The PyTorch model keyword inputs. + Returns: + A sequence of tensors converted from PyTorch model inputs. + """ + args: Sequence[Any] = model_args + kwargs: Mapping[str, Any] = model_kwargs + for step in self._steps: + args, kwargs = step.apply(args, kwargs, model=model) + assert not kwargs + return args + + +@runtime_checkable +class OutputAdaptStep(Protocol): + """A protocol that defines a step in the output adapting process. + + The output adapting process is a sequence of steps that are applied to the + PyTorch model outputs to transform them into the outputs format produced by the + exported ONNX model. Each step takes the PyTorch model outputs as arguments and + returns the transformed outputs. + + This serves as a base formalized construct for the transformation done to model + output signature by any individual component in the exporter. + """ + + def apply( + self, + model_outputs: Any, + model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, + ) -> Any: ... + + +class OutputAdapter: + """A class that adapts the PyTorch model outputs to exported ONNX model outputs format.""" + + def __init__(self, steps: list[OutputAdaptStep] | None = None): + self._steps = steps or [] + + def append_step(self, step: OutputAdaptStep) -> None: + """Appends a step to the output format steps. + + Args: + step: The step to append. + """ + self._steps.append(step) + + def apply( + self, + model_outputs: Any, + model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, + ) -> Sequence[torch.Tensor | int | float | bool | str]: + """Converts the PyTorch model outputs to exported ONNX model outputs format. + + Args: + model_outputs: The PyTorch model outputs. + model: The PyTorch model. + + Returns: + PyTorch model outputs in exported ONNX model outputs format. + """ + for step in self._steps: + model_outputs = step.apply(model_outputs, model=model) + return model_outputs + + +# TODO: make_fx lose stack info https://github.com/pytorch/pytorch/issues/90276 + + +# TODO(XuehaiPan): Dynamo does not support `dummy_leaf = object()` as a sentinel value in the frame. +class _DummyLeaf: # use a class instead. + pass + + +def _replace_list_with_tuple(spec: pytree.TreeSpec) -> pytree.TreeSpec: + def replace_list_with_tuple(x: Any) -> Any: + if type(x) is list: + return pytree.tree_map( + replace_list_with_tuple, + tuple(x), + is_leaf=lambda x: type(x) is list, + ) + return x + + dummy_leaf = _DummyLeaf() + dummy_tree = pytree.tree_unflatten([dummy_leaf] * spec.num_leaves, spec) + dummy_tree = pytree.tree_map( + replace_list_with_tuple, + dummy_tree, + is_leaf=lambda x: type(x) is list, + ) + return pytree.tree_structure(dummy_tree) + + +def _open_top_level_sequence_if_single_element( + spec: pytree.TreeSpec, +) -> pytree.TreeSpec: + if spec.type in (tuple, list) and spec.num_children == 1: + return spec.children_specs[0] + return spec + + +def _assert_identical_pytree_spec( + spec1: pytree.TreeSpec, spec2: pytree.TreeSpec, error_message: str +) -> None: + """Assert the two `TreeSpec` objects are identical. + + Args: + spec1: The first `TreeSpec` object. + spec2: The second `TreeSpec` object. + error_message: The error message to raise if the two `TreeSpec` objects are not + identical. + + Raises: + ValueError: If the two `TreeSpec` objects are not identical. + """ + pass_if_any_checks: Sequence[Callable[[], bool]] = [ + lambda: spec1 == spec2, + # FIXME: Bug in `dynamo.export`. Sometimes outputs returned in 'list' instead of 'tuple'. + lambda: _replace_list_with_tuple(spec1) == _replace_list_with_tuple(spec2), + # FIXME: Bug in `dynamo.export`. Sometimes single function return is wrapped in list. + lambda: _open_top_level_sequence_if_single_element(spec1) == spec2, + lambda: spec1 == _open_top_level_sequence_if_single_element(spec2), + ] + + if not any(check() for check in pass_if_any_checks): + raise ValueError(f"{error_message}\nExpect {spec1}.\nActual {spec2}.") + + +class BindInputStep(InputAdaptStep): + """Bind the input arguments to the model signature.""" + + def __init__(self, model_signature: inspect.Signature): + self._model_signature = model_signature + + def apply( + self, + model_args: Sequence[Any], + model_kwargs: Mapping[str, Any], + model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, + ) -> tuple[Sequence[Any], Mapping[str, Any]]: + """Bind the input arguments to the model signature. + + We hope the input kwargs will be mapped to bound.args after binding. + If not, we will raise an error. + + Args: + model_args: The model args. + model_kwargs: The model kwargs. + model: The PyTorch model. + + Returns: + A tuple of the model args and kwargs. args is always empty. + + Raises: + ValueError: If there are keyword-only arguments left after binding args and + kwargs to model signature. + """ + bound = self._model_signature.bind(*model_args, **model_kwargs) + bound.apply_defaults() + + # keyword-only arguments are not handled. + # bound.kwargs only contains keyword-only arguments after calling + # bind & apply_defaults, so we raise if it's not empty. + if bound.kwargs: + raise ValueError("Keyword-only arguments are not supported.") + return (), bound.arguments + + +class MergeKwargsIntoArgsInputStep(InputAdaptStep): + """Merge the input kwargs into the input args.""" + + def apply( + self, + model_args: Sequence[Any], + model_kwargs: Mapping[str, Any], + model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, + ) -> tuple[Sequence[Any], Mapping[str, Any]]: + """Merge the input kwargs into the input args. + + Args: + model_args: The model args. + model_kwargs: The model kwargs. + model: The PyTorch model. + + Returns: + A tuple of the model args and kwargs. kwargs is always empty. + """ + return tuple(model_args) + tuple(model_kwargs.values()), {} + + +class LiftParametersAndBuffersIntoArgsInputStep(InputAdaptStep): + """Append parameters and buffers to model's positional argument list.""" + + def __init__(self, inputs: tuple[torch.Tensor, ...]) -> None: + self.inputs = inputs + + def apply( + self, + model_args: Sequence[Any], + model_kwargs: Mapping[str, Any], + model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, + ) -> tuple[Sequence[Any], Mapping[str, Any]]: + """Append model's parameters and buffers into its input. + + Args: + model_args: The model args. + model_kwargs: The model kwargs. + model: The PyTorch model. + + Returns: + A tuple of the model args + appended inputs and kwargs. + """ + return (*model_args, *self.inputs), model_kwargs + + +class ConvertComplexToRealRepresentationInputStep(InputAdaptStep): + """Convert complex dtype tensors to real representation tensors. + + ONNX does not support complex dtype tensors. Thus, we convert complex dtype tensors + to real representation tensors (i.e., float dtype tensors with an extra dimension + representing the real and imaginary parts of the complex number). + + """ + + def apply( + self, + model_args: Sequence[Any], + model_kwargs: Mapping[str, Any], + model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, + ) -> tuple[Sequence[Any], Mapping[str, Any]]: + """Convert complex tensors to float tensors. + + Args: + model_args: The model args. + model_kwargs: The model kwargs. + model: The PyTorch model. + + Returns: + A tuple of the model args and kwargs. + """ + return ( + tuple( + torch.view_as_real(arg.resolve_conj()) + if isinstance(arg, torch.Tensor) and arg.is_complex() + else arg + for arg in model_args + ), + model_kwargs, + ) + + +class RemoveNoneInputStep(InputAdaptStep): + """Remove `None` from arguments. + + This adapt step assumes ``model_kwargs`` is empty. It also assumes ``model_args`` + is flattened, i.e. it does not check `None` inside nested collections. + """ + + def apply( + self, + model_args: Sequence[Any], + model_kwargs: Mapping[str, Any], + model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, + ) -> tuple[Sequence[Any], Mapping[str, Any]]: + """Remove `None` from arguments. + + Args: + model_args: The model args. + model_kwargs: The model kwargs. + model: The PyTorch model. + + Returns: + A tuple of the model args and kwargs. + + Raises: + ValueError: If `model_kwargs` is not empty. + """ + assert not model_kwargs + return tuple(arg for arg in model_args if arg is not None), {} + + +class RemoveNonTensorInputStep(InputAdaptStep): + """Remove the non-tensor input arguments. + + Dynamo does not support non-tensor input arguments (https://github.com/pytorch/pytorch/issues/99534). + + Specifically, it does put the input into graph with an empty node, but consumed by no ones. + The concrete value is embedded into the graph as a constant arg of a target node. Meta + suggests in this case that one should rewrite the model code to make it tensor if the + input value is supposed to change at runtime. We might need to further investigate + the feasibility of that suggestion. + + For example, + + def func(x, b=1.0): + y = x + b + z = y.relu() + return (y, z) + + x = torch.randn(1, 1, 2, dtype=torch.float32) + gm_fun, _ = dynamo.export(func, x, b=8.0, aten_graph=True, tracing_mode="real") + + # class GraphModule(torch.nn.Module): + # def forward(self, x, b): + # arg0: f32[1, 1, 2], arg1, = fx_pytree.tree_flatten_spec(([x, b], {}), self._in_spec) + # # File: path/to/pytorch/test_constant_input.py:5, code: y = x + b + # add_tensor: f32[1, 1, 2] = torch.ops.aten.add.Tensor(arg0, 8.0); arg0 = None + + # # File: path/to/pytorch/test_constant_input.py:6, code: z = y.relu() + # relu_default: f32[1, 1, 2] = torch.ops.aten.relu.default(add_tensor) + # return pytree.tree_unflatten([add_tensor, relu_default], self._out_spec) + + Empty torch.fx.Node input leading to a mismatched number of input with PyTorch, as + it's ignored in ONNX graph. Thus, we delete the useless input here. + + """ + + def apply( + self, + model_args: Sequence[Any], + model_kwargs: Mapping[str, Any], + model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, + ) -> tuple[Sequence[Any], Mapping[str, Any]]: + """Remove Constant from arguments. + + Args: + model_args: The model args. + model_kwargs: The model kwargs. + model: The PyTorch model. + + Returns: + A tuple of the model args and kwargs. + + Raises: + ValueError: If `model_kwargs` is not empty. + """ + assert not model_kwargs + return ( + tuple( + arg + for arg in model_args + if not isinstance(arg, (int, float, bool, str)) + ), + {}, + ) + + +class FlattenInputWithTreeSpecValidationInputStep(InputAdaptStep): + """Flatten nested collection types and return a flat list of elements. + + ONNX can't represent collection types (e.g., dictionary, tuple of tuple of tensor, + etc). + + This class stores the `SpecTree` output produced when `adapt` was called the first + time. It then validates the `SpecTree` output produced from later `adapt` calls. + """ + + _spec: pytree.TreeSpec | None = None + + def apply( + self, + model_args: Sequence[Any], + model_kwargs: Mapping[str, Any], + model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, + ) -> tuple[Sequence[Any], Mapping[str, Any]]: + """Flatten the model args and kwargs and validate the `SpecTree` output. + + Args: + model_args: The model args. + model_kwargs: The model kwargs. + model: The PyTorch model. + + Returns: + A tuple of the flattened model args and kwargs. The kwargs is empty, because + they are flattened and merged into the args. + + Raises: + ValueError: If the `SpecTree` output produced from the current `model_outputs` + is not identical to the `SpecTree` output produced from the first + `model_outputs` that was passed to this method. + """ + flattened_args, spec = pytree.tree_flatten((model_args, model_kwargs)) + if self._spec is None: + self._spec = spec + else: + _assert_identical_pytree_spec( + self._spec, + spec, + error_message="Model inputs incompatible with the format that was exported. ", + ) + return flattened_args, {} + + +class FlattenOutputStep(OutputAdaptStep): + """Flatten nested collection types and return a flat list of elements. + + ONNX can't represent collection types (e.g., dictionary, tuple of tuple of tensor, + etc). + + NOTE: Ideally we would want to use ``FlattenOutputWithTreeSpecValidationOutputStep``, such + that `SpecTree` can be validate for new model outputs. However, this is not possible + currently because we never have access to real PyTorch model outputs during export. + Only traced outputs may be available, but they are not an accurate reflection of the + original PyTorch model outputs format as they are typically in their own unique format, + depending on the tracing strategy. + """ + + def apply( + self, + model_outputs: Any, + model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, + ) -> Sequence[Any]: + """Flatten the model outputs. + + Args: + model_outputs: The model outputs to flatten. + model: The PyTorch model. + + Returns: + A tuple of the flattened model outputs. + """ + return pytree.tree_leaves(model_outputs) + + +class ConvertComplexToRealRepresentationOutputStep(OutputAdaptStep): + """Convert complex dtype tensors to real representation tensors. + + ONNX does not support complex dtype tensors. Thus, we convert complex dtype tensors + to real representation tensors (i.e., float dtype tensors with an extra dimension + representing the real and imaginary parts of the complex number). + + """ + + def apply( + self, + model_outputs: Any, + model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, + ) -> Any: + """Convert float tensors to complex tensors. + + Args: + model_output: The model output. + model: The PyTorch model. + + Returns: + A tuple of the model output. + """ + return [ + torch.view_as_real(output.resolve_conj()) + if isinstance(output, torch.Tensor) and torch.is_complex(output) + else output + for output in model_outputs + ] + + +class FlattenOutputWithTreeSpecValidationOutputStep(OutputAdaptStep): + """Same as ``FlattenOutputStep``, with additional `TreeSpec` validation. + + This class stores the `SpecTree` output produced when `adapt` was called the first + time. It then validates the `SpecTree` output produced from later `adapt` calls. + """ + + _spec: pytree.TreeSpec | None = None + + def apply( + self, + model_outputs: Any, + model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, + ) -> Sequence[Any]: + """Flatten the model outputs and validate the `SpecTree` output. + + Args: + model_outputs: The model outputs to flatten. + model: The PyTorch model. + + Returns: + flattened_outputs: The flattened model outputs. + + Raises: + ValueError: If the `SpecTree` output produced from the current `model_outputs` + is not identical to the `SpecTree` output produced from the first + `model_outputs` that was passed to this method. + """ + flattened_outputs, spec = pytree.tree_flatten(model_outputs) + if self._spec is None: + self._spec = spec + else: + _assert_identical_pytree_spec( + self._spec, + spec, + error_message="Model outputs incompatible with the format that was exported. ", + ) + return flattened_outputs + + +class PrependParamsBuffersConstantAotAutogradInputStep(InputAdaptStep): + """Prepend model parameters, buffers and constants to the user input. + + :func:`torch.export.export` lifts model parameters, buffers and constants as model input, thus, they + must be added to the user input before the model is executed. + + Args: + model: The PyTorch model with embedded parameters and buffers. + """ + + def apply( + self, + model_args: Sequence[Any], + model_kwargs: Mapping[str, Any], + model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, + ) -> tuple[Sequence[Any], Mapping[str, Any]]: + """Convert complex tensors to float tensors. + + Args: + model_args: The model args. + model_kwargs: The model kwargs. + model: The PyTorch model. + + Returns: + A tuple of the model args and kwargs. + """ + ordered_params = tuple( + model.state_dict[name] # type: ignore[union-attr,index] + for name in model.graph_signature.parameters # type: ignore[union-attr] + ) + non_persistent_buffers = set(model.graph_signature.non_persistent_buffers) # type: ignore[arg-type, union-attr] + ordered_buffers = [] + for name in model.graph_signature.buffers: # type: ignore[union-attr] + if name in non_persistent_buffers: + ordered_buffers.append(model.constants[name]) # type: ignore[index, union-attr] + else: + ordered_buffers.append(model.state_dict[name]) # type: ignore[union-attr,index] + ordered_constant_tensors = tuple( + model.constants[fqn] # type: ignore[union-attr,index] + for fqn in model.graph_signature.lifted_tensor_constants # type: ignore[union-attr] + ) + + # NOTE: calling convention is first params, then buffers, then args as user supplied them. + # See: torch/_functorch/aot_autograd.py#L1034 + updated_args = ( + *ordered_params, + *ordered_buffers, + *ordered_constant_tensors, + *model_args, + ) + if model_kwargs: + return MergeKwargsIntoArgsInputStep().apply( + updated_args, model_kwargs, model=model + ) + return updated_args, {} + + +class PrependParamsAndBuffersAotAutogradOutputStep(OutputAdaptStep): + """Prepend model's mutated buffers to the user output. + + :func:`torch.export.export` lifts model's mutated buffers as outputs, thus, they + must be added to the user output after the model is executed. + + Args: + model: The PyTorch model with mutated buffers. + """ + + def apply( + self, + model_outputs: Any, + model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, + ) -> Sequence[Any]: + """Flatten the model outputs and validate the `SpecTree` output. + + Args: + model_outputs: The model outputs to flatten. + model: The PyTorch model. + + Returns: + flattened_outputs: The flattened model outputs. + """ + + assert isinstance(model, torch_export.ExportedProgram), ( + "'model' must be torch_export.ExportedProgram" + ) + ordered_buffers = tuple( + model.state_dict[name] + if name in model.state_dict + else model.constants[name] + for name in model.graph_signature.buffers_to_mutate.values() + ) + + # NOTE: calling convention is first mutated buffers, then outputs args as model returned them. + updated_outputs = (*ordered_buffers, *model_outputs) + return updated_outputs diff --git a/torch/onnx/_internal/jit_utils.py b/torch/onnx/_internal/jit_utils.py new file mode 100644 index 0000000000000..5db66f6c83a4e --- /dev/null +++ b/torch/onnx/_internal/jit_utils.py @@ -0,0 +1,374 @@ +# mypy: allow-untyped-defs +"""Utilities for manipulating the torch.Graph object and the torchscript.""" + +# TODO(justinchuby): Move more of the symbolic helper functions here and expose +# them to the user. + +from __future__ import annotations + +import dataclasses +import re +import typing +from collections.abc import Iterable, Sequence +from typing import Any + +import torch +from torch import _C +from torch.onnx._globals import GLOBALS +from torch.onnx._internal import registration + + +_ATTR_PATTERN = re.compile("^(.+)_(([ifstgz])|(ty))$") +_SKIP_NODE_ATTRIBUTES = {"inplace", "aten"} + + +@dataclasses.dataclass +class GraphContext: + """Extra context for symbolic functions with all methods from torch.Graph. + + NOTE: This class is not meant for external consumption. Please do not depend on + it outside of torch.onnx as the interface may evolve. + + Attributes: + graph: The _C.Graph being constructed. + block: The current _C.Block being constructed. + opset: The opset version. + original_node: Current node that is being converted from. + params_dict: Mapping from graph initializer name to IValue. + env: Mapping from Torch domain graph Value to ONNX domain graph Value. + values_in_env: Set of all values in env, for constant-time lookups. + new_nodes: List that tracks all new nodes that are added (used to make + sure metadata is propagated to all new nodes). + """ + + graph: _C.Graph + block: _C.Block + opset: int + original_node: _C.Node + params_dict: dict[str, _C.IValue] + env: dict[_C.Value, _C.Value] + values_in_env: set[_C.Value] + new_nodes: list[_C.Node] = dataclasses.field(default_factory=list) + + # Relay methods from _C.Graph for compatibility with symbolic functions that expect + # a _C.Graph + def __getattr__(self, name: str) -> Any: + return getattr(self.graph, name) + + def op( + self, + opname: str, + *raw_args: torch.Tensor | _C.Value, + outputs: int = 1, + **kwargs, + ): + """Creates an ONNX operator "opname", taking "raw_args" as inputs and "kwargs" as attributes. + + The set of operators and the inputs/attributes they take + is documented at https://github.com/onnx/onnx/blob/master/docs/Operators.md + + Args: + opname: The ONNX operator name, e.g., `Abs` or `Add`, or an operator qualified + with a namespace, e.g., `aten::add`. + raw_args: The inputs to the operator; usually provided + as arguments to the `symbolic` definition. + outputs: The number of outputs this operator returns. + By default an operator is assumed to return a single output. + If `outputs` is greater than one, this functions returns a tuple + of output `Value`, representing each output of the ONNX operator + in order. + kwargs: The attributes of the ONNX operator, whose keys are named + according to the following convention: `alpha_f` indicates + the `alpha` attribute with type `f`. The valid type specifiers are + `f` (float), `i` (int), `s` (string) or `t` (Tensor). An attribute + specified with type float accepts either a single float, or a + list of floats (e.g., you would say `dims_i` for a `dims` attribute + that takes a list of integers). + + Returns: + The value representing the single output of this operator (see the `outputs` + keyword argument for multi-return nodes). + """ + # FIXME(justinchuby): Add the return type back once we know how to handle mypy + return _add_op(self, opname, *raw_args, outputs=outputs, **kwargs) + + def aten_op(self, operator: str, *args, overload_name: str = "", **kwargs): + """Generates an ONNX ATen op node. + + This function is for backward compatibility with the old symbolic functions. + """ + return self.op( + "aten::ATen", + *args, + operator_s=operator, + overload_name_s=overload_name, + **kwargs, + ) + + # NOTE: For backward compatibility with the old symbolic functions. + # We are probably going to remove this only after the fx exporter is established. + at = aten_op + + def onnxscript_op( + self, + onnx_fn, + *raw_args: torch.Tensor | _C.Value, + outputs: int = 1, + **kwargs, + ): + """Creates an ONNX operator from onnx-script function, taking "raw_args" as inputs and "kwargs" as attributes. + + onnx-script repository: https://github.com/microsoft/onnx-script + + Args: + onnx_fn: ONNXFunction from onnx-script; An example can be found at + https://github.com/microsoft/onnx-script#example + raw_args: The inputs to the operator; usually provided + as arguments to the `symbolic` definition. + outputs: The number of outputs this operator returns. + By default an operator is assumed to return a single output. + If `outputs` is greater than one, this functions returns a tuple + of output `Value`, representing each output of the ONNX operator + in order. + kwargs: The attributes of the ONNX operator, whose keys are named + according to the following convention: `alpha_f` indicates + the `alpha` attribute with type `f`. The valid type specifiers are + `f` (float), `i` (int), `s` (string) or `t` (Tensor). An attribute + specified with type float accepts either a single float, or a + list of floats (e.g., you would say `dims_i` for a `dims` attribute + that takes a list of integers). + + Returns: + The value representing the single output of this operator (see the `outputs` + keyword argument for multi-return nodes). + """ + # NOTE(titaiwang): This is using class attributes, and it needs to be updated + # if onnx-script makes any change on these. + symbolic_name = f"{onnx_fn.opset.domain}::{onnx_fn.name}" + opset_version = onnx_fn.opset.version + + registration.custom_onnx_symbolic(symbolic_name, opset_version)(onnx_fn) + + return _add_op(self, symbolic_name, *raw_args, outputs=outputs, **kwargs) + + +def add_op_with_blocks( + graph_context: GraphContext, + opname: str, + *inputs: _C.Value, + outputs: int = 1, + n_blocks: int = 1, + **attributes, +) -> tuple[Any, tuple[GraphContext, ...], _C.Node]: + """Creates an ONNX operator "opname", taking inputs and attributes. + + Args: + graph_context: The context for the current graph. + opname: The ONNX operator name, e.g., `Abs` or `Add`, or an operator qualified + with a namespace, e.g., `aten::add`. + inputs: The inputs to the operator. + outputs: The number of outputs this operator returns. + By default an operator is assumed to return a single output. + If `outputs` is greater than one, this functions returns a tuple + of output `Value`, representing each output of the ONNX operator + in order. + n_blocks: The number of sub-blocks to create in the node. + attributes: The attributes of the ONNX operator. + + Returns: + A tuple of (output_values, new_contexts, node) where: + output_values: One or more output value of this operator + (see the `outputs` keyword argument for multi-return nodes). + new_contexts: A tuple of new graph contexts for each sub-block. + node: The node representing the operator. + """ + + output_values = graph_context.op(opname, *inputs, outputs=outputs, **attributes) + if isinstance(output_values, Sequence): + node = output_values[0].node() + else: + node = output_values.node() + + new_contexts = [] + for _ in range(n_blocks): + new_block = node.addBlock() + # Create shallow copy of the graph context and update the block + new_context = dataclasses.replace(graph_context, block=new_block) + new_contexts.append(new_context) + + return output_values, tuple(new_contexts), node + + +def _add_op( + graph_context: GraphContext, + opname: str, + *args: torch.Tensor | _C.Value, + outputs: int = 1, + **kwargs, +): + """Creates an ONNX operator "opname", taking "args" as inputs and attributes "kwargs". + + The set of operators and the inputs/attributes they take + is documented at https://github.com/onnx/onnx/blob/master/docs/Operators.md + + This function is monkey-patched onto Graph. + + Args: + graph_context: The Torch Graph or Block. + opname: The ONNX operator name, e.g., `Abs` or `Add`, or an operator qualified + with a namespace, e.g., `aten::add`. + args: The inputs to the operator; usually provided + as arguments to the `symbolic` definition. + outputs: The number of outputs this operator returns. + By default an operator is assumed to return a single output. + If `outputs` is greater than one, this functions returns a tuple + of output `Value`, representing each output of the ONNX operator + in order. + kwargs: The attributes of the ONNX operator, whose keys are named + according to the following convention: `alpha_f` indicates + the `alpha` attribute with type `f`. The valid type specifiers are + `f` (float), `i` (int), `s` (string) or `t` (Tensor). An attribute + specified with type float accepts either a single float, or a + list of floats (e.g., you would say `dims_i` for a `dims` attribute + that takes a list of integers). + + Returns: + (Union[_C.Value, Tuple[_C.Value, ...]]) + The value representing the single output of this operator (see the `outputs` + keyword argument for multi-return nodes). + """ + inputs = [_const_if_tensor(graph_context, arg) for arg in args] + # Filter out None attributes, this can be convenient client side because + # now they can pass through None attributes, and have them not show up + attributes = {k: v for k, v in kwargs.items() if v is not None} + + if "::" not in opname: + opname = "onnx::" + opname + + node = _create_node( + graph_context.block, + opname, + inputs, + attributes, + params_dict=graph_context.params_dict, + opset_version=graph_context.opset, + n_outputs=outputs, + shape_inference=GLOBALS.onnx_shape_inference, + ) + graph_context.new_nodes.append(node) + + if outputs == 1: + return node.output() + return tuple(node.outputs()) + + +def _const_if_tensor(graph_context: GraphContext, arg): + if arg is None: + return arg + if isinstance(arg, _C.Value): + return arg + + return _add_op(graph_context, "onnx::Constant", value_z=arg) + + +def _create_node( + graph_or_block: _C.Graph | _C.Block, + domain_op: str, + inputs: Sequence, + attributes: dict, + params_dict: dict, + opset_version: int, + n_outputs: int, + shape_inference: bool = True, +) -> _C.Node: + """Creates an node 'domain_op', taking inputs and attributes.""" + if isinstance(graph_or_block, _C.Graph): + graph = graph_or_block + node = graph.create(domain_op, inputs, n_outputs) + node = graph.insertNode(node) + elif isinstance(graph_or_block, _C.Block): + block = graph_or_block + node = block.addNode(domain_op, inputs) + + # Block does not have create defined, so we need to add outputs manually + if n_outputs > 1: + for _ in range(1, n_outputs): + node.addOutput() + + node_outputs = tuple(node.outputs()) # type: ignore[possibly-undefined] + assert len(node_outputs) == n_outputs + + aten = domain_op.startswith("aten::") + + # Add all attributes + for key, value in sorted(attributes.items()): + if key in _SKIP_NODE_ATTRIBUTES: + continue + _add_attribute(node, key, value, aten=aten) + if shape_inference: + _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version) + return node + + +def _is_onnx_list(value): + return isinstance(value, Iterable) and not isinstance( + value, (str, bytes, torch.Tensor) + ) + + +def _scalar(x: torch.Tensor): + """Convert a scalar tensor into a Python value.""" + assert x.numel() == 1 + return x[0] + + +def _add_attribute(node: _C.Node, key: str, value: Any, aten: bool): + r"""Initializes the right attribute based on type of value.""" + m = _ATTR_PATTERN.match(key) + if m is None: + raise ValueError( + f"Invalid attribute specifier '{key}' names " + "must be suffixed with type, e.g. 'dim_i' or 'dims_i'" + ) + name, kind = m.group(1), m.group(2) + if _is_onnx_list(value): + kind += "s" + + return getattr(node, f"{kind}_")(name, value) + + +# TODO: Expose this to user when migrating symbolic helper functions to here. +def _is_tensor(x: _C.Value) -> bool: + return x.type().isSubtypeOf(_C.TensorType.get()) + + +def get_device_from_value(value: _C.Value) -> torch.device | None: + if not _is_tensor(value): + return None + tensor_type = typing.cast(_C.TensorType, value.type()) + return tensor_type.device() + + +def parse_node_kind(kind: str) -> tuple[str, str]: + """Parse node kind into domain and Op name.""" + if "::" not in kind: + raise ValueError(f"Node kind: {kind} is invalid. '::' is not in node kind.") + domain, opname = kind.split("::", 1) + if "::" in opname: + raise ValueError(f"Node kind: {kind} is invalid. '::' should only apear once.") + return domain, opname + + +def is_aten(domain: str) -> bool: + """Check if the domain is official.""" + return domain == "aten" + + +def is_prim(domain: str) -> bool: + """Check if the domain is official.""" + return domain == "prim" + + +def is_onnx(domain: str) -> bool: + """Check if the domain is official.""" + return domain == "onnx" diff --git a/torch/onnx/_internal/onnx_proto_utils.py b/torch/onnx/_internal/onnx_proto_utils.py new file mode 100644 index 0000000000000..04ed0f83ef84c --- /dev/null +++ b/torch/onnx/_internal/onnx_proto_utils.py @@ -0,0 +1,250 @@ +# mypy: allow-untyped-defs +"""Utilities for manipulating the onnx and onnx-script dependencies and ONNX proto.""" + +from __future__ import annotations + +import glob +import os +import shutil +from typing import Any, TYPE_CHECKING + +import torch +import torch.jit._trace +import torch.serialization +from torch.onnx import errors +from torch.onnx._internal import jit_utils, registration + + +if TYPE_CHECKING: + import io + from collections.abc import Mapping + + +def export_as_test_case( + model_bytes: bytes, inputs_data, outputs_data, name: str, dir: str +) -> str: + """Export an ONNX model as a self contained ONNX test case. + + The test case contains the model and the inputs/outputs data. The directory structure + is as follows: + + dir + \u251c\u2500\u2500 test_ + \u2502 \u251c\u2500\u2500 model.onnx + \u2502 \u2514\u2500\u2500 test_data_set_0 + \u2502 \u251c\u2500\u2500 input_0.pb + \u2502 \u251c\u2500\u2500 input_1.pb + \u2502 \u251c\u2500\u2500 output_0.pb + \u2502 \u2514\u2500\u2500 output_1.pb + + Args: + model_bytes: The ONNX model in bytes. + inputs_data: The inputs data, nested data structure of numpy.ndarray. + outputs_data: The outputs data, nested data structure of numpy.ndarray. + + Returns: + The path to the test case directory. + """ + try: + import onnx + except ImportError as exc: + raise ImportError( + "Export test case to ONNX format failed: Please install ONNX." + ) from exc + + test_case_dir = os.path.join(dir, "test_" + name) + os.makedirs(test_case_dir, exist_ok=True) + _export_file( + model_bytes, + os.path.join(test_case_dir, "model.onnx"), + {}, + ) + data_set_dir = os.path.join(test_case_dir, "test_data_set_0") + if os.path.exists(data_set_dir): + shutil.rmtree(data_set_dir) + os.makedirs(data_set_dir) + + proto = onnx.load_model_from_string(model_bytes) # type: ignore[attr-defined] + + for i, (input_proto, input) in enumerate(zip(proto.graph.input, inputs_data)): + export_data(input, input_proto, os.path.join(data_set_dir, f"input_{i}.pb")) + for i, (output_proto, output) in enumerate(zip(proto.graph.output, outputs_data)): + export_data(output, output_proto, os.path.join(data_set_dir, f"output_{i}.pb")) + + return test_case_dir + + +def load_test_case(dir: str) -> tuple[bytes, Any, Any]: + """Load a self contained ONNX test case from a directory. + + The test case must contain the model and the inputs/outputs data. The directory structure + should be as follows: + + dir + \u251c\u2500\u2500 test_ + \u2502 \u251c\u2500\u2500 model.onnx + \u2502 \u2514\u2500\u2500 test_data_set_0 + \u2502 \u251c\u2500\u2500 input_0.pb + \u2502 \u251c\u2500\u2500 input_1.pb + \u2502 \u251c\u2500\u2500 output_0.pb + \u2502 \u2514\u2500\u2500 output_1.pb + + Args: + dir: The directory containing the test case. + + Returns: + model_bytes: The ONNX model in bytes. + inputs: the inputs data, mapping from input name to numpy.ndarray. + outputs: the outputs data, mapping from output name to numpy.ndarray. + """ + try: + import onnx + from onnx import numpy_helper # type: ignore[attr-defined] + except ImportError as exc: + raise ImportError( + "Load test case from ONNX format failed: Please install ONNX." + ) from exc + + with open(os.path.join(dir, "model.onnx"), "rb") as f: + model_bytes = f.read() + + test_data_dir = os.path.join(dir, "test_data_set_0") + + inputs = {} + input_files = glob.glob(os.path.join(test_data_dir, "input_*.pb")) + for input_file in input_files: + tensor = onnx.load_tensor(input_file) # type: ignore[attr-defined] + inputs[tensor.name] = numpy_helper.to_array(tensor) + outputs = {} + output_files = glob.glob(os.path.join(test_data_dir, "output_*.pb")) + for output_file in output_files: + tensor = onnx.load_tensor(output_file) # type: ignore[attr-defined] + outputs[tensor.name] = numpy_helper.to_array(tensor) + + return model_bytes, inputs, outputs + + +def export_data(data, value_info_proto, f: str) -> None: + """Export data to ONNX protobuf format. + + Args: + data: The data to export, nested data structure of numpy.ndarray. + value_info_proto: The ValueInfoProto of the data. The type of the ValueInfoProto + determines how the data is stored. + f: The file to write the data to. + """ + try: + from onnx import numpy_helper # type: ignore[attr-defined] + except ImportError as exc: + raise ImportError( + "Export data to ONNX format failed: Please install ONNX." + ) from exc + + with open(f, "wb") as opened_file: + if value_info_proto.type.HasField("map_type"): + opened_file.write( + numpy_helper.from_dict(data, value_info_proto.name).SerializeToString() + ) + elif value_info_proto.type.HasField("sequence_type"): + opened_file.write( + numpy_helper.from_list(data, value_info_proto.name).SerializeToString() + ) + elif value_info_proto.type.HasField("optional_type"): + opened_file.write( + numpy_helper.from_optional( + data, value_info_proto.name + ).SerializeToString() + ) + else: + assert value_info_proto.type.HasField("tensor_type") + opened_file.write( + numpy_helper.from_array(data, value_info_proto.name).SerializeToString() + ) + + +def _export_file( + model_bytes: bytes, + f: io.BytesIO | str, + export_map: Mapping[str, bytes], +) -> None: + """export/write model bytes into directory/protobuf/zip""" + assert len(export_map) == 0 + with torch.serialization._open_file_like(f, "wb") as opened_file: + opened_file.write(model_bytes) + + +def _add_onnxscript_fn( + model_bytes: bytes, + custom_opsets: Mapping[str, int], +) -> bytes: + """Insert model-included custom onnx-script function into ModelProto""" + try: + import onnx + except ImportError as e: + raise errors.OnnxExporterError("Module onnx is not installed!") from e + + # For > 2GB model, onnx.load_fromstring would fail. However, because + # in _export_onnx, the tensors should be saved separately if the proto + # size > 2GB, and if it for some reason did not, the model would fail on + # serialization anyway in terms of the protobuf limitation. So we don't + # need to worry about > 2GB model getting here. + model_proto = onnx.load_model_from_string(model_bytes) # type: ignore[attr-defined] + + # Iterate graph nodes to insert only the included custom + # function_proto into model_proto + onnx_function_list = [] # type: ignore[var-annotated] + included_node_func: set[str] = set() + # onnx_function_list and included_node_func are expanded in-place + _find_onnxscript_op( + model_proto.graph, included_node_func, custom_opsets, onnx_function_list + ) + + if onnx_function_list: + model_proto.functions.extend(onnx_function_list) + model_bytes = model_proto.SerializeToString() + return model_bytes + + +def _find_onnxscript_op( + graph_proto, + included_node_func: set[str], + custom_opsets: Mapping[str, int], + onnx_function_list: list, +): + """Recursively iterate ModelProto to find ONNXFunction op as it may contain control flow Op.""" + for node in graph_proto.node: + node_kind = node.domain + "::" + node.op_type + # Recursive needed for control flow nodes: IF/Loop which has inner graph_proto + for attr in node.attribute: + if attr.g is not None: + _find_onnxscript_op( + attr.g, included_node_func, custom_opsets, onnx_function_list + ) + # Only custom Op with ONNX function and aten with symbolic_fn should be found in registry + onnx_function_group = registration.registry.get_function_group(node_kind) + # Ruled out corner cases: onnx/prim in registry + if ( + node.domain + and not jit_utils.is_aten(node.domain) + and not jit_utils.is_prim(node.domain) + and not jit_utils.is_onnx(node.domain) + and onnx_function_group is not None + and node_kind not in included_node_func + ): + specified_version = custom_opsets.get(node.domain, 1) + onnx_fn = onnx_function_group.get(specified_version) + if onnx_fn is not None: + if hasattr(onnx_fn, "to_function_proto"): + onnx_function_proto = onnx_fn.to_function_proto() # type: ignore[attr-defined] + onnx_function_list.append(onnx_function_proto) + included_node_func.add(node_kind) + continue + + raise errors.UnsupportedOperatorError( + node_kind, + specified_version, + onnx_function_group.get_min_supported() + if onnx_function_group + else None, + ) + return onnx_function_list, included_node_func diff --git a/torch/onnx/_internal/onnxruntime.py b/torch/onnx/_internal/onnxruntime.py new file mode 100644 index 0000000000000..b994328fcdd82 --- /dev/null +++ b/torch/onnx/_internal/onnxruntime.py @@ -0,0 +1,1260 @@ +# mypy: allow-untyped-defs +import dataclasses +import importlib +import logging +import os +from collections.abc import Mapping, Sequence +from typing import Any, Callable, Final, Optional, TYPE_CHECKING, Union +from typing_extensions import TypeAlias + +import torch +import torch._C +import torch._ops +import torch._prims.executor +import torch.fx +import torch.onnx._internal._lazy_import +from torch._subclasses.fake_tensor import FakeTensor +from torch.fx._compatibility import compatibility +from torch.fx.passes.fake_tensor_prop import FakeTensorProp +from torch.fx.passes.operator_support import OperatorSupport +from torch.fx.passes.tools_common import CALLABLE_NODE_OPS +from torch.utils import _pytree + + +if TYPE_CHECKING: + import onnx + import onnxruntime + from onnxruntime.capi import _pybind_state as ORTC + + import torch.onnx + import torch.onnx._internal + import torch.onnx._internal._exporter_legacy + import torch.onnx._internal.fx.decomposition_table + import torch.onnx._internal.fx.passes # noqa: TCH004 + + +_SUPPORT_ONNXRT: Optional[bool] = None + +__all__ = [ + "is_onnxrt_backend_supported", + "torch_compile_backend", + "OrtExecutionProvider", + "OrtBackendOptions", + "OrtBackend", +] + + +def is_onnxrt_backend_supported() -> bool: + """Returns ``True`` if ONNX Runtime dependencies are installed and usable + to support TorchDynamo backend integration; ``False`` otherwise. + + Example:: + + # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX) + >>> import torch + >>> if torch.onnx.is_onnxrt_backend_supported(): + ... @torch.compile(backend="onnxrt") + ... def f(x): + ... return x * x + ... print(f(torch.randn(10))) + ... else: + ... print("pip install onnx onnxscript onnxruntime") + ... + """ + global _SUPPORT_ONNXRT + + if _SUPPORT_ONNXRT is None: + # `onnxruntime` might import a lot of other runtime packages, + # e.g. apex, deepspeed, transformers. + # So lazy-importing onnxruntime to avoid possible circular import. + try: + importlib.import_module("onnxruntime") + importlib.import_module("onnxruntime.capi._pybind_state") + + # This is not use directly in DORT but needed by underlying exporter, + # so we still need to check if it exists. + importlib.import_module("onnxscript") + + import torch.onnx # noqa: F401 + import torch.onnx._internal # noqa: F401 + import torch.onnx._internal._exporter_legacy # noqa: F401 + from torch.onnx._internal.fx import ( # noqa: F401 + decomposition_table, + fx_onnx_interpreter, + passes, + type_utils, + ) + + _SUPPORT_ONNXRT = True + except ImportError: + _SUPPORT_ONNXRT = False + + return _SUPPORT_ONNXRT + + +_dumped_onnx_model: dict[str, int] = {} + + +def _dump_onnx_model( + model_string: bytes, graph_module: Optional[torch.fx.GraphModule] = None +) -> str: + """Stores the onnx model into a file. + The name is "{ONNXRT_DUMP_PATH}{N}.onnx" + where *N* is the number of files already stored with + this prefix. + If graph_module is not None, the graph is stored as a string with + the same filename except the extension (.txt). + """ + prefix = os.environ.get("ONNXRT_DUMP_PATH", None) + if not prefix: + return "" + n = _dumped_onnx_model.get(prefix, -1) + 1 + filename = f"{prefix}{n}.onnx" + with open(filename, "wb") as f: + f.write(model_string) + _dumped_onnx_model[prefix] = n + if graph_module is not None: + filename_txt = f"{prefix}{n}.txt" + with open(filename_txt, "w", encoding="utf-8") as f: + f.write(str(graph_module.graph)) + return filename + + +def _infer_default_eps() -> Sequence[str]: + # TODO: select a good default based on the capabilities of the host + # e.g. DML on Windows, etc. + return ["CPUExecutionProvider"] + + +def _nvtx_range_push(name: str): + """If PyTorch is installed with CUDA support, this starts NVTX range. + + Check torch.cuda.nvtx.range_push's document for more details. + """ + if torch.cuda.is_available(): + torch.cuda.nvtx.range_push(name) + + +def _nvtx_range_pop(): + """If PyTorch is installed with CUDA support, this terminates NVTX range. + + Check torch.cuda.nvtx.range_pop's document for more details. + """ + if torch.cuda.is_available(): + torch.cuda.nvtx.range_pop() + + +def _get_ort_device_type(device_type: str): + from onnxruntime.capi import _pybind_state as ORTC + + if device_type == "cuda": + return ORTC.OrtDevice.cuda() + if device_type == "cpu": + return ORTC.OrtDevice.cpu() + # ort pytorch device is mapped to NPU OrtDevice type + if device_type == "maia": + return ORTC.OrtDevice.npu() + raise ValueError("Unsupported device type: " + device_type) + + +logger = logging.getLogger(__name__) +# Uncomment the following lines to print out development info. +# logging.basicConfig(level=logging.WARNING) +# logger.setLevel(logging.WARNING) + + +class OrtOperatorSupport(OperatorSupport): + """Operator support for ONNXRuntime backend. + + It has two-level of support decision. One is via support_dict and the other one + is via extra_support_dict. The logic of using support_dict is implemented in + OrtOperatorSupport and extra_support_dict is used by OperatorSupport.is_node_supported. + """ + + def __init__(self, support_dict: set[Any], extra_support_dict: dict[str, Any]): + # Use extra_support_dict[op_name] = None to indicate + # we support op_name with all input types. Otherwise, + # see support_dict (type: SupportDict) in operator_support.py + # for specifying supported types. + super().__init__(extra_support_dict) + self._onnx_support_dict = support_dict + + def is_node_supported( + self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node + ) -> bool: + # OperatorSupport.is_node_supported returns True for non-callable nodes. + # Since ORT can't execute them, we return False here to override the base + # behavior. + if node.op not in CALLABLE_NODE_OPS: + return False + # This is the and the only place to decide if aten op is supported. + if node.op == "call_function" and node.target in self._onnx_support_dict: + logger.info( + "support_dict supports node.target: %s (type: %s)", + node.target, + type(node.target), + ) + return True + # If node.target is not in support_dict, we still want to check if torch.jit.script + # can convert it to ONNX equivalence. Let's use base mechanism to do this. + # See extra_support_dict for supported ops. + if super().is_node_supported(submodules, node): + logger.info( + "extra_support_dict supports node.target: %s (type: %s)", + node.target, + type(node.target), + ) + return True + logger.warning( + "support_dict and extra_support_dict don't support node.target: %s (type: %s)", + node.target, + type(node.target), + ) + return False + + +def _move_placeholder_to_front(graph_module: torch.fx.GraphModule) -> None: + """ + In torch.fx.Graph, placeholder is a special assignment node. If it's not + executed in the beginning, it could overwrite values computed by upstream + nodes. + """ + + graph = graph_module.graph + placeholders = [] + first_not_placeholder = None + for node in graph.nodes: + if node.op == "placeholder": + placeholders.append(node) + if first_not_placeholder is None and node.op != "placeholder": + first_not_placeholder = node + if first_not_placeholder is None: + return + for placeholder in placeholders: + first_not_placeholder.prepend(placeholder) + + +def _infer_ep_from_device(*args) -> tuple[str, ...]: + """Return the first valid device (i.e., GPU or CPU) in argument list.""" + eps = [] + for arg in args: + if hasattr(arg, "device"): + device = arg.device + if device.type == "cuda": + eps.append("CUDAExecutionProvider") + elif device.type == "cpu": + eps.append("CPUExecutionProvider") + return tuple(eps) + + +def _extract_graph_module_inputs(graph_module: torch.fx.GraphModule) -> tuple[Any, ...]: + placeholders = [] + for node in graph_module.graph.nodes: + if node.op == "placeholder": + if hasattr(node, "meta") and "val" in node.meta: + assert isinstance(node.meta["val"], torch.Tensor) + placeholders.append(node) + return tuple(placeholders) + + +def _extract_graph_module_outputs(graph_module: torch.fx.GraphModule) -> Any: + """Collect "val" fields from outputs metadata in this torch.fx.GraphModule.""" + for node in graph_module.graph.nodes: + if node.op == "output": + # Output node is unique. Let's retrieve output values from + # this node's input list. And then just return. + return node.args[0] + raise ValueError("No output node found in this torch.fx.GraphModule.") + + +def _infer_ep_from_graph_module(graph_module: torch.fx.GraphModule) -> tuple[str, ...]: + """Return the all valid devices (i.e., GPU or CPU) among outputs of this torch.fx.GraphModule.""" + flattened_output_args, _ = _pytree.tree_flatten( + _extract_graph_module_outputs(graph_module) + ) + # Output arguments with example value (type: torch.Tensor) in the `graph_module`. + selected_output_args = [ + output_arg.meta["val"] + for output_arg in flattened_output_args + # output_arg must have tensor for its device information. + # Otherwise, skip it. + if (hasattr(output_arg, "meta") and "val" in output_arg.meta) + ] + return _infer_ep_from_device(*selected_output_args) + + +def _sort_eps(eps: tuple[str, ...]) -> tuple[str, ...]: + """Sort execution providers in eps based on pre-set priority.""" + + def get_execution_provider_priority(ep: str) -> int: + if ep == "CPUExecutionProvider": + # Lowest priority. + return 2 + if ep == "CUDAExecutionProvider": + # Higher priority than CPU but lower than + # other specialized EPs. + return 1 + # Highest priority. + return 0 + + unique_eps = set(eps) + return tuple(sorted(unique_eps, key=get_execution_provider_priority, reverse=True)) + + +def _get_onnx_devices( + values: tuple[ + Union[ + torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool + ], + ..., + ], +) -> tuple["ORTC.OrtDevice", ...]: + from onnxruntime.capi import _pybind_state as ORTC + + def _device_id_or_zero(device_id: int) -> int: + return device_id or 0 + + def _map_tensor_or_sym_to_device( + value: Union[ + torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool + ], + ) -> int: + if isinstance(value, torch.Tensor): + return ORTC.OrtDevice( + _get_ort_device_type(value.device.type), + ORTC.OrtDevice.default_memory(), + _device_id_or_zero(value.device.index), + ) + elif isinstance( + value, (torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool) + ): + return ORTC.OrtDevice( + _get_ort_device_type("cpu"), ORTC.OrtDevice.default_memory(), 0 + ) + else: + raise ValueError("Unsupported value type: " + str(type(value))) + + if len(values) > 0: + ort_devices = tuple(_map_tensor_or_sym_to_device(value) for value in values) + return ort_devices + else: + return (_map_tensor_or_sym_to_device(1),) + + +def _get_ortvalues_from_torch_tensors( + tensors: tuple[torch.Tensor, ...], devices: tuple["ORTC.OrtDevice", ...] +) -> tuple[torch.Tensor, ...]: + # TODO(justinchuby): Refactor this function + import numpy as np + from onnxruntime.capi import _pybind_state as ORTC + + torch_dtype_to_numpy_dtype = { + torch.float16: np.float16, + torch.float32: np.float32, + torch.float64: np.float64, + torch.uint8: np.uint8, + torch.int8: np.int8, + torch.int16: np.int16, + torch.int32: np.int32, + torch.int64: np.longlong, + torch.bool: np.bool_, + } + ortvalues = ORTC.OrtValueVector() + ortvalues.reserve(len(tensors)) + dtypes = [] + shapes = [] + data_ptrs = [] + + for tensor in tensors: + dtypes.append(torch_dtype_to_numpy_dtype[tensor.dtype]) + shapes.append(tensor.size()) + data_ptrs.append(tensor.data_ptr()) + ortvalues.push_back_batch(tensors, data_ptrs, dtypes, shapes, devices) + return ortvalues + + +def _to_real_tensor(tensor: FakeTensor) -> torch.Tensor: + if tensor.is_sparse: + raise ValueError("sparse tensor is not yet supported.") + out = torch.empty(tensor.size(), dtype=tensor.dtype, device=tensor.device) + return out + + +def _adjust_scalar_from_fx_to_onnx( + dynamo_value: Union[ + torch.Tensor, + int, + float, + bool, + ], + value_info: "onnx.ValueInfoProto", # type: ignore[name-defined] +) -> torch.Tensor: + """Helper function to wrap PyTorch variables as torch.Tensor""" + if ( + isinstance(dynamo_value, torch.Tensor) + and len(value_info.type.tensor_type.shape.dim) == 0 + and dynamo_value.shape == (1,) + ): + # ONNX expect a scalar with empty shape. + # In contrast, PyTorch usually allows implicit + # conversion between shape=() and shape=(1,). + # + # Below, PyTorch's shape (1,) is reshaped to (). + return torch.squeeze(dynamo_value) + elif isinstance(dynamo_value, int): + return torch.tensor(dynamo_value, dtype=torch.int64) + elif isinstance(dynamo_value, float): + return torch.tensor(dynamo_value, dtype=torch.float32) + elif isinstance(dynamo_value, bool): + return torch.tensor(dynamo_value, dtype=torch.bool) + else: + assert isinstance(dynamo_value, torch.Tensor) + return dynamo_value.contiguous() + + +def _adjust_scalar_from_onnx_to_fx( + tensor: torch.Tensor, + prim_value: Union[ + torch.Tensor, + torch.SymInt, + int, + torch.SymFloat, + float, + torch.SymBool, + bool, + ], +) -> Union[ + torch.Tensor, + int, + float, + bool, +]: + """Helper function to wrap ORT-produced torch.Tensor as PyTorch variables""" + assert isinstance(tensor, torch.Tensor), "ORT's output must be tensor." + if isinstance( + prim_value, + (torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool), + ): + # Convert tensor back to scalar to match Dynamo's expectation. + return tensor.item() + return tensor + + +def _run_onnx_session_with_ortvaluevector( + sess: "onnxruntime.InferenceSession", + input_names: tuple[str, ...], + inputs: tuple[torch.Tensor, ...], + input_devices: tuple["ORTC.OrtDevice", ...], + output_names: tuple[str, ...], + outputs: tuple[torch.Tensor, ...], + output_devices: tuple["ORTC.OrtDevice", ...], + preallocate_output: bool, + input_value_infos: tuple["onnx.ValueInfoProto", ...], # type: ignore[name-defined] + normalized_prim_outputs: tuple[ + Union[ + torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool + ], + ..., + ], +) -> tuple[Union[torch.Tensor, int, float, bool], ...]: + import onnxruntime + from onnxruntime.capi import _pybind_state as ORTC + + _nvtx_range_push("contiguous") + inputs = tuple( + _adjust_scalar_from_fx_to_onnx(arg, value_info) + for arg, value_info in zip(inputs, input_value_infos) + ) + _nvtx_range_pop() + + _nvtx_range_push("push_back_batch") + ort_inputs = _get_ortvalues_from_torch_tensors(inputs, input_devices) + + # preallocate output pytorch Tensors and use the buffers affined to the torch device for the output ortvalue. + # Because the output ortvalue is not allocated and owned by ort, it does not need to convert the output ortvalue + # to torch Tensor transferring the ownership. + if preallocate_output: + pth_outputs = tuple( + _to_real_tensor(t) if isinstance(t, FakeTensor) else t for t in outputs + ) + ort_outputs = _get_ortvalues_from_torch_tensors(pth_outputs, output_devices) + else: + ort_outputs = ORTC.OrtValueVector() + _nvtx_range_pop() + + _nvtx_range_push("run_with_ortvaluevector") + run_options = onnxruntime.RunOptions() + run_options.add_run_config_entry("disable_synchronize_execution_providers", "1") + sess.run_with_ortvaluevector( + run_options, input_names, ort_inputs, output_names, ort_outputs, output_devices + ) + _nvtx_range_pop() + + # Post-processing step: + # wrap ORT's outputs to the schema represented by + # `prim_output` (obtained by running the original + # torch.fx.GraphModule). + if preallocate_output: + # Profile the ORT-to-PyTorch type cast below + _nvtx_range_push("after run_with_ortvaluevector") + # Outputs are stored on pre-allocated torch.Tensors' memory, + # so this case doesn't need to convert ORTValue to torch.Tensor. + pth_outputs = tuple( + _adjust_scalar_from_onnx_to_fx(onnx_output, prim_output) # type: ignore[misc] + for onnx_output, prim_output in zip(pth_outputs, normalized_prim_outputs) + ) + _nvtx_range_pop() + return pth_outputs + else: + import onnxruntime.training + + # Profile the two ORT-to-PyTorch type casts below + _nvtx_range_push("after run_with_ortvaluevector") + # Map ORTValue to torch.Tensor. + pth_outputs = onnxruntime.training.ortmodule._utils._ortvalues_to_torch_tensor( + ort_outputs + ) + # Change some torch.Tensor to int, float, bool. + pth_outputs = tuple( + _adjust_scalar_from_onnx_to_fx(onnx_output, prim_output) # type: ignore[misc] + for onnx_output, prim_output in zip(pth_outputs, normalized_prim_outputs) + ) + _nvtx_range_pop() + return pth_outputs + + +def _run_onnx_session_with_fetch( + sess: "onnxruntime.InferenceSession", + input_names: tuple[str, ...], + inputs: tuple[torch.Tensor, ...], + input_devices: tuple["ORTC.OrtDevice", ...], + output_names: tuple[str, ...], + outputs: tuple[torch.Tensor, ...], + output_devices: tuple["ORTC.OrtDevice", ...], + preallocate_output: bool, + input_value_infos: tuple["onnx.ValueInfoProto", ...], # type: ignore[name-defined] + normalized_prim_outputs: tuple[ + Union[ + torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool + ], + ..., + ], +) -> tuple[Union[torch.Tensor, int, float, bool], ...]: + import onnxruntime + + inputs = tuple( + _adjust_scalar_from_fx_to_onnx(arg, value_info) + for arg, value_info in zip(inputs, input_value_infos) + ) + feed = { + name: onnxruntime.OrtValue.ortvalue_from_numpy(tensor.cpu().numpy()) + for name, tensor in zip(input_names, inputs) + } + ort_outputs = sess.run(output_names, feed) + pth_outputs = tuple( + _adjust_scalar_from_onnx_to_fx( + torch.from_numpy(value), + prim_output, + ) + for value, prim_output in zip(ort_outputs, normalized_prim_outputs) + ) + return pth_outputs + + +def _from_python_type_to_onnx_tensor_element_type(type: type): + """ + Converts a Python type to the corresponding ONNX tensor element type. + For example, `_from_python_type_to_onnx_tensor_element_type(float)` returns + `onnx.TensorProto.FLOAT`. + + Args: + type (type): The Python type to convert. + + Returns: + int: The corresponding ONNX tensor element type. + + """ + import onnx + + _PYTHON_TYPE_TO_ONNX_TENSOR_ELEMENT_TYPE = { + float: onnx.TensorProto.FLOAT, # type: ignore[attr-defined] + int: onnx.TensorProto.INT64, # type: ignore[attr-defined] + bool: onnx.TensorProto.BOOL, # type: ignore[attr-defined] + } + return _PYTHON_TYPE_TO_ONNX_TENSOR_ELEMENT_TYPE.get(type) + + +class OrtExecutionInfoPerSession: + """Information required to execute torch.fx.GraphModule using onnxruntime.InferenceSession""" + + def __init__( + self, + session: "onnxruntime.InferenceSession", + input_names: tuple[str, ...], + input_value_infos: tuple["onnx.ValueInfoProto", ...], # type: ignore[name-defined] + output_names: tuple[str, ...], + output_value_infos: tuple["onnx.ValueInfoProto", ...], # type: ignore[name-defined] + input_devices: tuple["ORTC.OrtDevice", ...], + output_devices: tuple["ORTC.OrtDevice", ...], + example_outputs: Union[tuple[torch.Tensor, ...], torch.Tensor], + ): + # Carrier of ONNX model and its executor. + self.session: onnxruntime.InferenceSession = session + # For the ONNX model stored in self.session, self.input_names[i] is the + # name of the i-th positional input. + self.input_names: tuple[str, ...] = input_names + # self.input_name[i]'s type information is stored in self.input_value_infos[i]. + self.input_value_infos: tuple[onnx.ValueInfoProto, ...] = input_value_infos # type: ignore[name-defined] + # Similar to self.input_names, but for outputs. + self.output_names: tuple[str, ...] = output_names + # Similar to self.input_value_infos but for outputs. + self.output_value_infos: tuple[onnx.ValueInfoProto, ...] = output_value_infos # type: ignore[name-defined] + # For the ONNX model stored in self.session, self.input_devices[i] is the + # i-th positional input's device. + self.input_devices: tuple[ORTC.OrtDevice, ...] = input_devices + # Similar to self.input_devices, but for outputs. + self.output_devices: tuple[ORTC.OrtDevice, ...] = output_devices + # This is the outputs of executing the original torch.fx.GraphModule with example inputs + # (i.e., args passed into OrtBackend._ort_acclerated_call). + self.example_outputs: Union[tuple[torch.Tensor, ...], torch.Tensor] = ( + example_outputs + ) + + def is_supported(self, *args): + # TODO(justinchuby): Simplify + import onnx + + _onnx_tensor_element_type_to_torch_dtype = { + onnx.TensorProto.FLOAT: torch.float32, # type: ignore[attr-defined] + onnx.TensorProto.FLOAT16: torch.float16, # type: ignore[attr-defined] + onnx.TensorProto.FLOAT8E5M2: torch.float8_e5m2, # type: ignore[attr-defined] + onnx.TensorProto.FLOAT8E5M2FNUZ: torch.float8_e5m2fnuz, # type: ignore[attr-defined] + onnx.TensorProto.FLOAT8E4M3FN: torch.float8_e4m3fn, # type: ignore[attr-defined] + onnx.TensorProto.FLOAT8E4M3FNUZ: torch.float8_e4m3fnuz, # type: ignore[attr-defined] + onnx.TensorProto.DOUBLE: torch.float64, # type: ignore[attr-defined] + onnx.TensorProto.BOOL: torch.bool, # type: ignore[attr-defined] + onnx.TensorProto.UINT8: torch.uint8, # type: ignore[attr-defined] + onnx.TensorProto.INT8: torch.int8, # type: ignore[attr-defined] + onnx.TensorProto.INT16: torch.int16, # type: ignore[attr-defined] + onnx.TensorProto.INT32: torch.int32, # type: ignore[attr-defined] + onnx.TensorProto.INT64: torch.int64, # type: ignore[attr-defined] + } + _torch_dtype_to_onnx_tensor_element_type = { + value: key + for key, value in _onnx_tensor_element_type_to_torch_dtype.items() + } + + # Compare the args and the input schema in ONNX model and + # return the first match. + if len(args) != len(self.input_value_infos): + return False + for arg, value_info in zip(args, self.input_value_infos): + if not isinstance(arg, (torch.Tensor, float, int)): + return False + + # Check Python scalars such as int, float, and bool. + if isinstance(arg, (int, float, bool)): + # Map, e.g., float to onnx.TensorProto.FLOAT. + onnx_dtype = _from_python_type_to_onnx_tensor_element_type(type(arg)) + if onnx_dtype != value_info.type.tensor_type.elem_type: + return False + if len(value_info.type.tensor_type.shape.dim) != 0: + return False + continue + + # Check tensor. + onnx_dtype = _torch_dtype_to_onnx_tensor_element_type[arg.dtype] + if onnx_dtype != value_info.type.tensor_type.elem_type: + return False + for dim, onnx_dim in zip(arg.shape, value_info.type.tensor_type.shape.dim): + if isinstance(dim, int) and ( + onnx_dim.dim_value == dim or onnx_dim.dim_param + ): + continue + elif isinstance(dim, torch.SymInt) and onnx_dim.dim_param: + continue + else: + return False + return True + + +@dataclasses.dataclass +class OrtExecutionInfoForAllGraphModules: + def __init__(self) -> None: + # All sessions (and their related information) created by exporting the same GraphModule + # with different inputs. + self.execution_info_per_graph_module: dict[ + torch.fx.GraphModule, list[OrtExecutionInfoPerSession] + ] = {} + + def search_reusable_session_execution_info( + self, graph_module: torch.fx.GraphModule, *args + ): + if graph_module not in self.execution_info_per_graph_module: + return None + # All execution information for ONNX models exported from the same `graph_module` + # with different inputs. + candidates = self.execution_info_per_graph_module[graph_module] + + for candidate in candidates: + if candidate.is_supported(*args): + # Returns the first session that accepts this input schema. + return candidate + # No reusable session found. + return None + + def cache_session_execution_info( + self, graph_module: torch.fx.GraphModule, info: OrtExecutionInfoPerSession + ): + if graph_module not in self.execution_info_per_graph_module: + self.execution_info_per_graph_module[graph_module] = [info] + else: + self.execution_info_per_graph_module[graph_module].append(info) + + +OrtExecutionProvider: TypeAlias = Union[str, tuple[str, Mapping[str, Any]]] +"""Either the name of an ONNX Runtime execution provider as a string or +a 2-tuple of the name and a dictionary of execution provider options. + +Examples:: + + >>> "CPUExecutionProvider" + + >>> ("CUDAExecutionProvider", {"device_id": 3}) + +""" + + +@dataclasses.dataclass(frozen=True) +@compatibility(is_backward_compatible=False) +class OrtBackendOptions: + """Options for constructing an ``OrtBackend``, the ONNX Runtime + backend (``"onnxrt"``) for ``torch.compile``. + + Example:: + + >>> @torch.compile( + ... backend="onnxrt", + ... options=torch.onnx._OrtBackendOptions(...), + ... ) + ... def ort_function(x): + ... return x ** x + """ + + preferred_execution_providers: Optional[Sequence[OrtExecutionProvider]] = None + """An optional sequence of execution providers to be prioritized ahead of any + execution providers that may be inferred (see ``infer_execution_providers``). + """ + + infer_execution_providers: bool = True + """Whether to infer an execution provider from ``torch.device`` bound to inputs or found in the graph.""" + + default_execution_providers: Optional[Sequence[OrtExecutionProvider]] = None + """The default fallback execution providers. If not specified, one will be + be selected based on the host environment (most likely ``"CPUExecutionProvider"``). + """ + + # preallocate_output allows for allocating output torch Tensor buffers and feeding them to InferenceSession + # in order to avoid internal allocation of output buffers in InferenceSession. + # If output ortvalue returned from InferenceSession is allocated internally, + # it needs to be converted to torch Tensor for return, and the torch Tensor should hold the ownership. + # When a custom torch device is used with a custom aten allocator, the conversion from ortvalue to torch Tensor + # should be supported, which is currently done through dlpack. Note that dlpack might not support a custom torch device. + # It can be avoided by allowing for preallocation for output buffers allocated by a custom aten allocator, + # and use the preallocated output buffers for InferenceSession not holding any ownership for them. + # TODO(wschin): Make it to inference session level flag. + # See https://github.com/pytorch/pytorch/issues/106869. + preallocate_output: bool = False + """If ``True``, allocate memory for ONNX Runtime's outputs on the PyTorch side.""" + + use_aot_autograd: bool = True + """Whether to wrap the ``OrtBackend`` with TorchDynamo's aot_autograd backend + to support training (i.e., backward graphs are also sent to ``OrtBackend``). + + Symbolic execution is used to capture the forward pass and backward passes as a single graph. + Then, a selected graph partition algorithm (``min_cut_rematerialization_partition``) is used + to split the entire graph into forward sub-graph and backward sub-graph. Finally, both + sub-graphs are compiled by ``OrtBackend``. + """ + + ort_session_options: Optional["onnxruntime.SessionOptions"] = None + """Options for the ``onnxruntime.InferenceSession`` used by the ``OrtBackend``.""" + + pre_ort_model_transforms: Optional[ # type: ignore[name-defined] + Sequence[Callable[["onnx.ModelProto"], None]] + ] = None + """A list of graph transforms to be applied to the ONNX model before it + is fed to ONNXRuntime's InferenceSession.""" + + +@compatibility(is_backward_compatible=False) +class OrtBackend: + """A backend compiles (sub-)graphs in torch.fx.GraphModule to onnxruntime.InferenceSession calls. + + The compiler entry point is OrtBackend.compile, which + 1. partitions the original graph into supported sub-graphs (type: torch.fx.GraphModule) and unsupported + sub-graphs. + 2. For each supported sub-graph, it replaces its _wrapped_call function with _ort_accelerated_call. + 3. Inside _ort_accelerated_call, it creates onnxruntime.InferenceSession and calls it to execute the sub-graph. + """ + + def __init__(self, options: Optional[OrtBackendOptions] = None): + from onnxruntime.capi import _pybind_state as ORTC + + import torch.onnx + import torch.onnx._internal._exporter_legacy + import torch.onnx._internal.fx.decomposition_table + + self._options: Final = OrtBackendOptions() if options is None else options + + # options.export_options contains information shared between exporter and DORT. + # For example, they should use the same decomposition table when + # 1. capturing FX graph in torch.compile (see how we create aot_ort in register_backend.py) + # 2. call exporter's API to convert `torch.fx.GraphModule` to ONNX model + # (see onnxfunction_dispatcher passed to FxOnnxInterpreter.run below). + # + # Convert user-facing option to internal option used by ONNX exporter + # to access required information. + # Some useful fields: + # - Decomposition table for decomposing FX operators in exporter is + # self._resolved_onnx_exporter_options.decomposition_table. + # - self._resolved_onnx_exporter_options.onnx_registry records what + # aten/prim ops are supported by exporter and their exporters (type: callable). + self._resolved_onnx_exporter_options = ( + torch.onnx._internal._exporter_legacy.ResolvedExportOptions() + ) + + # Given DORT's computation flow: + # 1. OrtOperatorSupport uses support_dict and extra_support_dict to select operators + # and send them to DORT. + # 2. Then, DORT exports the selected sub-graphs into ONNX. + # 3. Finally DORT calls ORT to do the computation. + # OrtOperatorSupport and create_onnx_friendly_decomposition_table(...) + # must use the same support_dict. If the support_dict here contains something not + # supported by exporter, exporter will fails in step 2 since the selected graphs may + # contains unsupported operators such as aten::_who_you_are. + # This restriction is automatically done since DORT and exporter shares the same + # self._resolved_onnx_exporter_options. + support_dict = torch.onnx._internal.fx.decomposition_table._create_onnx_supports_op_overload_table( + self._resolved_onnx_exporter_options.onnx_registry + ) + + extra_support_dict: dict[str, Any] = { + "getattr": None, + # To send operator.getitem to ORT, add the corresponding string + # recognized by PyTorch's OperatorSupport class. + "_operator.getitem": None, + # To send operator.mul to ORT, add the corresponding string + # recognized by PyTorch's OperatorSupport class. + "_operator.mul": None, + "_operator.add": None, + "_operator.sub": None, + } + + self._supported_ops = OrtOperatorSupport(support_dict, extra_support_dict) + # TODO(wschin): this is a naive implementation of cache without proper guard + # See https://github.com/pytorch/pytorch/issues/106868. + self._partitioner_cache: dict[torch.fx.GraphModule, torch.fx.GraphModule] = {} + # Conceptually, this filed is a 2-layer dictionary + # GraphModule 0 + # ONNX Model 0 (with ORT InferenceSession and related information. type: OrtExecutionInfoPerSession) + # ONNX Model 1 + # ... + # GraphModule 1 + # ONNX Model 2 (with ORT InferenceSession and related information. type: OrtExecutionInfoPerSession) + # ONNX Model 3 + # ... + # ... + # , which caches all previous compilation result so that we can reuse them. + # ONNX Model 0 and 1 are exported from the same GraphModule 0 but with different inputs + # (e.g., tensors with different ranks). GraphModule 0 and GraphModule 1 are different + # graphs captured by Dynamo and sent to OrtBackend.compile. + self._all_ort_execution_info = OrtExecutionInfoForAllGraphModules() + + self._assert_allclose_to_baseline = False + + self.execution_count = 0 + + # Function which invokes ORT do to the real computation. + self.run = ( + _run_onnx_session_with_ortvaluevector + if hasattr(ORTC.OrtValueVector, "push_back_batch") + else _run_onnx_session_with_fetch + ) + + def _select_eps( + self, graph_module: torch.fx.GraphModule, *args + ) -> Sequence[tuple[str, Mapping[str, Any]]]: + inferred_eps: tuple[str, ...] = () + if self._options.infer_execution_providers: + if eps_from_args := _infer_ep_from_device(*args): + # If user feeds CUDA tensor as input argument, + # we want to use CUDA EP. + # Thus, `eps_from_args` (deduced from input arguments) + # has highest priority. + inferred_eps = eps_from_args + elif eps_from_graph_module := _infer_ep_from_graph_module(graph_module): + # If there is no EP in input arguments, we deduce EP from + # graph_module's outputs. Those outputs may come from + # FakeTensorProp or Dynamo's built-in symbolic shape inference. + inferred_eps = eps_from_graph_module + + selected_eps = [] + + for ep in ( + *(self._options.preferred_execution_providers or []), + *_sort_eps(inferred_eps), + *(self._options.default_execution_providers or _infer_default_eps()), + ): + if isinstance(ep, str): + ep = (ep, {}) + elif isinstance(ep, tuple) and ep[1] is None: + ep = (ep[0], {}) + if ep is not None and ep not in selected_eps: + selected_eps.append(ep) + + return selected_eps + + def _ort_acclerated_call(self, graph_module: torch.fx.GraphModule, *args, **kwargs): + """This function replaces GraphModule._wrapped_call in compiled model. + + The _wrapped_call is the underlying implementation of forward method. Replacing + it means we delegate the computation to _ort_acclerated_call and therefore + onnxruntime.InferenceSession. + """ + import onnxruntime + + from torch.onnx._internal.fx import fx_onnx_interpreter, passes + + cached_execution_info_per_session = ( + self._all_ort_execution_info.search_reusable_session_execution_info( + graph_module, *args + ) + ) + if cached_execution_info_per_session: + onnx_session = cached_execution_info_per_session.session + input_names = cached_execution_info_per_session.input_names + output_names = cached_execution_info_per_session.output_names + input_value_infos = cached_execution_info_per_session.input_value_infos + output_value_infos = cached_execution_info_per_session.output_value_infos + input_devices = cached_execution_info_per_session.input_devices + output_devices = cached_execution_info_per_session.output_devices + prim_outputs = cached_execution_info_per_session.example_outputs + else: + # It's first time seeing such as graph. Let's make a new session + # (type: onnxruntime.InferenceSession) for it. + + graph_module = passes.MovePlaceholderToFront( + graph_module, + ).run() + # Generate reference outputs. They are used to indicate output + # tensors' types and devices when calling ORT. + # + # WARNING: The downstream code should not change prim_outputs and + # this backend should always produces output with schema identical to prim_outputs'. + + if self._resolved_onnx_exporter_options.dynamic_shapes: + # No pre-allocation when dynamic shape is enabled. + self.preallocate_output = False + extracted_outputs = _extract_graph_module_outputs(graph_module) + + def maybe_map_to_meta_val(value): + if hasattr(value, "meta") and "val" in value.meta: + # Select outputs with "val" information. Without "val", + # it's not possible access output_arg.meta["val"].device. + return value.meta["val"] + else: + return value + + prim_outputs = _pytree.tree_map( + maybe_map_to_meta_val, extracted_outputs + ) + else: + try: + prim_outputs = FakeTensorProp(graph_module).propagate( + *args, **kwargs + ) + except Exception: + logger.warning("FakeTensorProb failed for %s", graph_module) + # When FakeTensorProp fails, it is not possible to preallocate output buffers + # because the output shapes are not inferred. + self.preallocate_output = False + + # rethrow FakeTensorProb failure because it is not yet currently handled. + raise + + # Create the object to iterate through the nodes in graph one-by-one + # and calls the corresponding ONNX exporter for each node. + fx_interpreter = fx_onnx_interpreter.FxOnnxInterpreter() + # Cast FX variables if they will result schema-mismatch when searching + # for ONNX operator. E.g., add(double_tensor, int_tensor) is fine in PyTorch, + # but ONNX expects add(double_tensor, double_tensor). + graph_module = passes.InsertTypePromotion(graph_module).run() + # Start the per-node exporting process. It's conceptually a for loop + # scanning through the nodes in the graph. + exported = fx_interpreter.run( + fx_graph_module=graph_module, + onnxfunction_dispatcher=self._resolved_onnx_exporter_options.onnxfunction_dispatcher, + ) + # Convert the exported result to ONNX ModelProto. + onnx_model = exported.to_model_proto( + opset_version=self._resolved_onnx_exporter_options.onnx_registry.opset_version, + ) + + # Modify ONNX model using pre-registered graph transforms. + # They are in-place modifications for avoiding unnecessary + # copy of ONNX initializers. + if self._options.pre_ort_model_transforms: + for transform in self._options.pre_ort_model_transforms: + transform(onnx_model) + + onnx_model_bytes = onnx_model.SerializeToString() + if os.environ.get("ONNXRT_DUMP_PATH", None): + # If not empty, environment variable ONNXRT_DUMP_PATH defined the path + # where generated onnx files should be stored. + # This module keeps a global variables keeping track of the + # stored models. + # If ONNXRT_DUMP_PATH="dumped/dumped_model_" + # The first file name will be 'dumped/dumped_model_0.onnx'. + # For every dumped model, a text file 'dumped/dumped_model_0.txt' + # is created as well to contain the string representing the graph_module. + _dump_onnx_model(onnx_model_bytes, graph_module=graph_module) + + # Initialize a ORT session to execute this ONNX model. + # Note that TorchDynamo assumes all inputs/outputs are on the + # same device, but it's subject to change (very likely with + # dynamic shape support), so we add execution providers + # based on the logic in _select_eps: (explicitly preferred EPs, + # EPs inferred from inputs or graph, and the fallback default EP)/ + # + # TODO(wschin): enable external allocators. + # See https://github.com/pytorch/pytorch/issues/106867 + onnx_session = onnxruntime.InferenceSession( + path_or_bytes=onnx_model_bytes, + sess_options=self._options.ort_session_options, + providers=self._select_eps(graph_module, *args), + ) + + # Cache ORT session. It's reused for the same "graph_module". + # Generate ONNX model and extract its input and output names. + input_names = tuple(input.name for input in onnx_model.graph.input) + output_names = tuple(output.name for output in onnx_model.graph.output) + input_devices = _get_onnx_devices(args) + # Cache devices for inputs and outputs. They are used to invoke + # ORT session. Output devices indicate where (e.g., GPU or CPU) + # to store outputs + if isinstance(prim_outputs, tuple): + output_devices = _get_onnx_devices(prim_outputs) + else: + output_devices = _get_onnx_devices((prim_outputs,)) + + input_value_infos = tuple(input for input in onnx_model.graph.input) + output_value_infos = tuple(output for output in onnx_model.graph.output) + + execution_info_per_session = OrtExecutionInfoPerSession( + session=onnx_session, + input_names=input_names, + input_value_infos=input_value_infos, + output_names=output_names, + output_value_infos=output_value_infos, + input_devices=input_devices, + output_devices=output_devices, + example_outputs=prim_outputs, + ) + + self._all_ort_execution_info.cache_session_execution_info( + graph_module, execution_info_per_session + ) + + self.execution_count += 1 + + # ORT always returns a tuple of outputs. If the original output is a tensor, + # ORT output's first element must be extracted and returned. Otherwise, type + # mismatch may happen in downstream computation. + is_single_tensor_output = isinstance(prim_outputs, torch.Tensor) + normalized_prim_outputs = ( + (prim_outputs,) if is_single_tensor_output else prim_outputs + ) + assert isinstance(normalized_prim_outputs, tuple) + assert all( + isinstance(elem, (torch.Tensor, torch.SymInt, int)) + for elem in normalized_prim_outputs + ) + + _nvtx_range_push("run_onnx_session_with_ortvaluevector") + onnx_outputs = self.run( + onnx_session, + input_names, + args, + input_devices, + output_names, + normalized_prim_outputs, + output_devices, + self._options.preallocate_output, + input_value_infos, + normalized_prim_outputs, + ) + _nvtx_range_pop() + + if self._assert_allclose_to_baseline: + # Compute baseline. + baseline_outputs = torch._prims.executor.execute( + graph_module, *args, executor="aten" + ) + normalized_baseline_ouptuts = ( + (baseline_outputs,) if is_single_tensor_output else baseline_outputs + ) + # Ensure every output tensor is close to the corresponding baseline. + for onnx_output, baseline_output in zip( + onnx_outputs, normalized_baseline_ouptuts + ): + torch.testing.assert_close(onnx_output, baseline_output) + return onnx_outputs[0] if is_single_tensor_output else onnx_outputs + + def compile(self, graph_module: torch.fx.GraphModule, args) -> torch.fx.GraphModule: + # Deferred import since CapabilityBasedPartitioner is not decorated with + # @compatibility; importing it at the module level will result in the test + # failing: pytest test/test_fx.py -k test_public_api_surface + # because this module is imported into torch.onnx. + from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner + + # FX graph based partitioning based on ONNX supported ops. + # Given a graph module + # GraphModule0 + # node_0 + # node_1 + # node_2 + # node_3 + # node_4 + # If only node_2 is not supported by ONNX, this graph module will be partitioned into + # GraphModule0 + # GraphModule1 + # node_0 + # node_1 + # node_2 + # GraphModule2 + # node_3 + # node_4 + # by calling CapabilityBasedPartitioner.partition_and_fuse. + # Then, GraphModule1's and GraphModule2's forward method (GraphModule._wrapped_call) + # will be replaced by OrtBackend._ort_accelerated_call to delegate computation to ORT. + if graph_module in self._partitioner_cache: + partitioned_prim_graph_module = self._partitioner_cache[graph_module] + else: + prim_graph_module = graph_module + partitioner = CapabilityBasedPartitioner( + prim_graph_module, + self._supported_ops, + allows_single_node_partition=True, + ) + partitioned_prim_graph_module = partitioner.partition_and_fuse() + self._partitioner_cache[graph_module] = partitioned_prim_graph_module + + # Overriding fused_module's __call__() function with ort_acclerated_call() + # This loop goes through all graph partitions (each of them is an ONNX-representable graph) + # and override their _wrapped_call function with _ort_accelerated_call. + # Inside _ort_accelerated_call, the partition's graph is exported into ONNX and executed by ORT. + for node in partitioned_prim_graph_module.graph.nodes: + # TODO(wschin): use a better way to identify fused submodule + # See https://github.com/pytorch/pytorch/issues/106872. + if node.op == "call_module" and "fused_" in node.name: + fused_module = getattr(partitioned_prim_graph_module, node.name) + # self.ort_acclerated_call is responsible for exporting graph to ONNX, + # creating ORT session, and running ORT session. + fused_module._wrapped_call = self._ort_acclerated_call + + return partitioned_prim_graph_module + + def __call__( + self, graph_module: torch.fx.GraphModule, args + ) -> torch.fx.GraphModule: + """If ``OrtBackendOptions.use_aot_autograd`` is ``True``, the `auto_autograd` compiler + will be invoked, wrapping this ``OrtBackend`` instance's ``compile`` method. Otherwise, + the ``compile`` method is invoked directly.""" + if self._options.use_aot_autograd: + from functorch.compile import min_cut_rematerialization_partition + from torch._dynamo.backends.common import aot_autograd + + return aot_autograd( + fw_compiler=self.compile, + partition_fn=min_cut_rematerialization_partition, + decompositions=self._resolved_onnx_exporter_options.decomposition_table, + )(graph_module, args) + + return self.compile(graph_module, args) + + __instance_cache_max_count: Final = 8 + __instance_cache: Final[list["OrtBackend"]] = [] + + @staticmethod + def get_cached_instance_for_options( + options: Optional[Union[OrtBackendOptions, Mapping[str, Any]]] = None, + ) -> "OrtBackend": + """Returns a possibly cached instance of an ``OrtBackend``. If an existing + backend was created previously through this function with the same options, + it will be returned. Otherwise a new backend will be created, cached, and + returned. + + Note: if ``options`` sets ``ort_session_options``, a new ``OrtBackend`` + will always be returned, since ``onnxruntime.SessionOptions`` cannot + participate in caching.""" + + def reusable(a: OrtBackendOptions, b: OrtBackendOptions): + if ( + a.preferred_execution_providers != b.preferred_execution_providers + or a.infer_execution_providers != b.infer_execution_providers + or a.default_execution_providers != b.default_execution_providers + or a.preallocate_output != b.preallocate_output + or a.use_aot_autograd != b.use_aot_autograd + or a.pre_ort_model_transforms != b.pre_ort_model_transforms + ): + return False + + # onnxruntime.SessionOptions is a pybind11 object, cannot be pickled, + # and holds too much potential state to reasonably check manually; + # ort_session_options is provided at all, the backend does not participate + # in caching. + if a.ort_session_options is not None or b.ort_session_options is not None: + return False + + return True + + if not isinstance(options, OrtBackendOptions): + options = OrtBackendOptions(**(options or {})) + + backend = next( + (b for b in OrtBackend.__instance_cache if reusable(b._options, options)), + None, + ) + + if backend is None: + assert ( + len(OrtBackend.__instance_cache) < OrtBackend.__instance_cache_max_count + ), ( + f"No more than {OrtBackend.__instance_cache_max_count} instances of " + f"{OrtBackend} allowed. Please instantiate `{OrtBackend}` explicitly " + "to pass to `torch.compile`. " + "See https://github.com/pytorch/pytorch/pull/107973#discussion_r1306144795 " + "for discussion." + ) + OrtBackend.__instance_cache.append(backend := OrtBackend(options)) + + return backend + + @staticmethod + def clear_cached_instances(): + OrtBackend.__instance_cache.clear() + + @staticmethod + def get_cached_instances(): + return tuple(OrtBackend.__instance_cache) + + +@compatibility(is_backward_compatible=False) +def torch_compile_backend( + graph_module: torch.fx.GraphModule, + args, + *, + options: Optional[Union[OrtBackendOptions, Mapping[str, Any]]] = None, +): + return OrtBackend.get_cached_instance_for_options(options)(graph_module, args) diff --git a/torch/onnx/_internal/registration.py b/torch/onnx/_internal/registration.py new file mode 100644 index 0000000000000..b8bba134f36b6 --- /dev/null +++ b/torch/onnx/_internal/registration.py @@ -0,0 +1,335 @@ +# mypy: allow-untyped-defs +"""Module for handling symbolic function registration.""" + +import warnings +from collections.abc import Collection, Sequence +from typing import Callable, Generic, Optional, TypeVar, Union +from typing_extensions import ParamSpec + +from torch.onnx import _constants, errors + + +OpsetVersion = int + + +def _dispatch_opset_version( + target: OpsetVersion, registered_opsets: Collection[OpsetVersion] +) -> Optional[OpsetVersion]: + """Finds the registered opset given a target opset version and the available opsets. + + Args: + target: The target opset version. + registered_opsets: The available opsets. + + Returns: + The registered opset version. + """ + if not registered_opsets: + return None + + descending_registered_versions = sorted(registered_opsets, reverse=True) + # Linear search for the opset version, which is fine since the number of opset + # versions is small. + + if target >= _constants.ONNX_BASE_OPSET: + # Always look down toward opset 1 when the target is >= ONNX_BASE_OPSET (opset 9). + # When a custom op is register at opset 1, we want to be able to discover it as a + # fallback for all opsets >= ONNX_BASE_OPSET. + for version in descending_registered_versions: + if version <= target: + return version + return None + + # target < opset 9. This is the legacy behavior to support opset 7 and opset 8. + # for caffe2 support. We search up toward opset 9. + for version in reversed(descending_registered_versions): + # Count back up until _constants.ONNX_BASE_OPSET + if target <= version <= _constants.ONNX_BASE_OPSET: + return version + + return None + + +_K = TypeVar("_K") +_V = TypeVar("_V") +_R = TypeVar("_R") +_P = ParamSpec("_P") + + +class OverrideDict(Collection[_K], Generic[_K, _V]): + """A dictionary that merges built-in and custom symbolic functions. + + It supports overriding and un-overriding built-in symbolic functions with custom + ones. + """ + + def __init__(self) -> None: + self._base: dict[_K, _V] = {} + self._overrides: dict[_K, _V] = {} + self._merged: dict[_K, _V] = {} + + def set_base(self, key: _K, value: _V) -> None: + self._base[key] = value + if key not in self._overrides: + self._merged[key] = value + + def in_base(self, key: _K) -> bool: + """Checks if a key is in the base dictionary.""" + return key in self._base + + def override(self, key: _K, value: _V) -> None: + """Overrides a base key-value with a new pair.""" + self._overrides[key] = value + self._merged[key] = value + + def remove_override(self, key: _K) -> None: + """Un-overrides a key-value pair.""" + self._overrides.pop(key, None) # type: ignore[arg-type] + self._merged.pop(key, None) # type: ignore[arg-type] + if key in self._base: + self._merged[key] = self._base[key] + + def overridden(self, key: _K) -> bool: + """Checks if a key-value pair is overridden.""" + return key in self._overrides + + def __getitem__(self, key: _K) -> _V: + return self._merged[key] + + def get(self, key: _K, default: Optional[_V] = None): + return self._merged.get(key, default) + + def __contains__(self, key: object) -> bool: + return key in self._merged + + def __iter__(self): + return iter(self._merged) + + def __len__(self) -> int: + return len(self._merged) + + def __repr__(self) -> str: + return f"OverrideDict(base={self._base}, overrides={self._overrides})" + + def __bool__(self) -> bool: + return bool(self._merged) + + +class _SymbolicFunctionGroup: + """Different versions of symbolic functions registered to the same name. + + O(number of registered versions of an op) search is performed to find the most + recent version of the op. + + The registration is delayed until op is used to improve startup time. + + Function overloads with different arguments are not allowed. + Custom op overrides are supported. + """ + + def __init__(self, name: str) -> None: + self._name = name + # A dictionary of functions, keyed by the opset version. + self._functions: OverrideDict[OpsetVersion, Callable] = OverrideDict() + + def __repr__(self) -> str: + return f"_SymbolicFunctionGroup({self._name}, registered={self._functions})" + + def __getitem__(self, key: OpsetVersion) -> Callable: + result = self.get(key) + if result is None: + raise KeyError(key) + return result + + # TODO(justinchuby): Add @functools.lru_cache(maxsize=None) if lookup time becomes + # a problem. + def get(self, opset: OpsetVersion) -> Optional[Callable]: + """Find the most recent version of the function.""" + version = _dispatch_opset_version(opset, self._functions) + if version is None: + return None + + return self._functions[version] + + def add(self, func: Callable, opset: OpsetVersion) -> None: + """Adds a symbolic function. + + Args: + func: The function to add. + opset: The opset version of the function to add. + """ + if self._functions.in_base(opset): + warnings.warn( + f"Symbolic function '{self._name}' already registered for opset {opset}. " + f"Replacing the existing function with new function. This is unexpected. " + f"Please report it on {_constants.PYTORCH_GITHUB_ISSUES_URL}.", + errors.OnnxExporterWarning, + ) + self._functions.set_base(opset, func) + + def add_custom(self, func: Callable, opset: OpsetVersion) -> None: + """Adds a custom symbolic function. + + Args: + func: The symbolic function to register. + opset: The corresponding opset version. + """ + self._functions.override(opset, func) + + def remove_custom(self, opset: OpsetVersion) -> None: + """Removes a custom symbolic function. + + Args: + opset: The opset version of the custom function to remove. + """ + if not self._functions.overridden(opset): + warnings.warn( + f"No custom function registered for '{self._name}' opset {opset}" + ) + return + self._functions.remove_override(opset) + + def get_min_supported(self) -> OpsetVersion: + """Returns the lowest built-in opset version supported by the function.""" + return min(self._functions) + + +class SymbolicRegistry: + """Registry for symbolic functions. + + The registry maintains a mapping from qualified names to symbolic functions. + It is used to register new symbolic functions and to dispatch calls to + the appropriate function. + """ + + def __init__(self) -> None: + self._registry: dict[str, _SymbolicFunctionGroup] = {} + + def register( + self, name: str, opset: OpsetVersion, func: Callable, custom: bool = False + ) -> None: + """Registers a symbolic function. + + Args: + name: The qualified name of the function to register. In the form of 'domain::op'. + E.g. 'aten::add'. + opset: The opset version of the function to register. + func: The symbolic function to register. + custom: Whether the function is a custom function that overrides existing ones. + + Raises: + ValueError: If the separator '::' is not in the name. + """ + if "::" not in name: + raise ValueError( + f"The name must be in the form of 'domain::op', not '{name}'" + ) + symbolic_functions = self._registry.setdefault( + name, _SymbolicFunctionGroup(name) + ) + if custom: + symbolic_functions.add_custom(func, opset) + else: + symbolic_functions.add(func, opset) + + def unregister(self, name: str, opset: OpsetVersion) -> None: + """Unregisters a symbolic function. + + Args: + name: The qualified name of the function to unregister. + opset: The opset version of the function to unregister. + """ + if name not in self._registry: + return + self._registry[name].remove_custom(opset) + + def get_function_group(self, name: str) -> Optional[_SymbolicFunctionGroup]: + """Returns the function group for the given name.""" + return self._registry.get(name) + + def is_registered_op(self, name: str, version: int) -> bool: + """Returns whether the given op is registered for the given opset version.""" + functions = self.get_function_group(name) + if functions is None: + return False + return functions.get(version) is not None + + def all_functions(self) -> set[str]: + """Returns the set of all registered function names.""" + return set(self._registry) + + +def onnx_symbolic( + name: str, + opset: Union[OpsetVersion, Sequence[OpsetVersion]], + decorate: Optional[Sequence[Callable]] = None, + custom: bool = False, +) -> Callable: + """Registers a symbolic function. + + Usage:: + + ``` + @onnx_symbolic( + "aten::symbolic_b", + opset=10, + decorate=[quantized_aten_handler(scale=1 / 128, zero_point=0)], + ) + @symbolic_helper.parse_args("v", "v", "b") + def symbolic_b(g: _C.Graph, x: _C.Value, y: _C.Value, arg1: bool) -> _C.Value: ... + ``` + + Args: + name: The qualified name of the function in the form of 'domain::op'. + E.g. 'aten::add'. + opset: The opset versions of the function to register at. + decorate: A sequence of decorators to apply to the function. + custom: Whether the function is a custom symbolic function. + + Raises: + ValueError: If the separator '::' is not in the name. + """ + + def wrapper(func: Callable[_P, _R]) -> Callable[_P, _R]: + decorated = func + if decorate is not None: + for decorate_func in decorate: + decorated = decorate_func(decorated) + + global registry + nonlocal opset + if isinstance(opset, OpsetVersion): + opset = (opset,) + for opset_version in opset: + registry.register(name, opset_version, decorated, custom=custom) + + # Return the original function because the decorators in "decorate" are only + # specific to the instance being registered. + return func + + return wrapper + + +def custom_onnx_symbolic( + name: str, + opset: Union[OpsetVersion, Sequence[OpsetVersion]], + decorate: Optional[Sequence[Callable]] = None, +) -> Callable: + """Registers a custom symbolic function. + + Args: + name: the qualified name of the function. + opset: the opset version of the function. + decorate: a sequence of decorators to apply to the function. + + Returns: + The decorator. + + Raises: + ValueError: If the separator '::' is not in the name. + """ + return onnx_symbolic(name, opset, decorate, custom=True) + + +# The registry for all symbolic functions. +registry = SymbolicRegistry() diff --git a/torch/onnx/_onnx_supported_ops.py b/torch/onnx/_onnx_supported_ops.py new file mode 100644 index 0000000000000..f3d703ffc227f --- /dev/null +++ b/torch/onnx/_onnx_supported_ops.py @@ -0,0 +1,98 @@ +# mypy: allow-untyped-defs +import inspect +from typing import Union + +from torch import _C +from torch.onnx import _constants +from torch.onnx._internal import registration + + +class _TorchSchema: + def __init__(self, schema: Union[_C.FunctionSchema, str]) -> None: + if isinstance(schema, _C.FunctionSchema): + self.name: str = schema.name + self.overload_name: str = schema.overload_name + self.arguments: list[str] = [arg.name for arg in schema.arguments] + self.optional_arguments: list[str] = [] + self.returns: list[str] = [ret.name for ret in schema.returns] + self.opsets: list[int] = [] + else: + self.name = schema + self.overload_name = "" + self.arguments = [] + self.optional_arguments = [] + self.returns = [] + self.opsets = [] + + def __str__(self) -> str: + s = ( + f"{self.name}.{self.overload_name}(" + + ", ".join(self.arguments) + + ") -> (" + + ", ".join(self.returns) + + ")" + + " in opsets " + + ", ".join(str(opset) for opset in self.opsets) + ) + return s + + def __hash__(self): + # TODO(thiagocrepaldi): handle overload_name? + return hash(self.name) + + def __eq__(self, other) -> bool: + if not isinstance(other, _TorchSchema): + return False + # TODO(thiagocrepaldi): handle overload_name? + return self.name == other.name + + def is_aten(self) -> bool: + return self.name.startswith("aten::") + + def is_backward(self) -> bool: + return "backward" in self.name + + +def _symbolic_argument_count(func): + params = [] + signature = inspect.signature(func) + optional_params = [] + for name, parameter in signature.parameters.items(): + if name in {"_outputs", "g"}: + continue + if parameter.default is parameter.empty: + optional_params.append(parameter) + else: + params.append(str(parameter)) + return params + + +def all_forward_schemas() -> dict[str, _TorchSchema]: + """Returns schemas for all TorchScript forward ops.""" + torch_schemas = [_TorchSchema(s) for s in _C._jit_get_all_schemas()] + return {schema.name: schema for schema in torch_schemas if not schema.is_backward()} + + +def all_symbolics_schemas() -> dict[str, _TorchSchema]: + """Returns schemas for all onnx supported ops.""" + symbolics_schemas = {} + + for name in registration.registry.all_functions(): + func_group = registration.registry.get_function_group(name) + assert func_group is not None + symbolics_schema = _TorchSchema(name) + func = func_group.get(_constants.ONNX_MAX_OPSET) + if func is not None: + symbolics_schema.arguments = _symbolic_argument_count(func) + symbolics_schema.opsets = list( + range(func_group.get_min_supported(), _constants.ONNX_MAX_OPSET + 1) + ) + else: + # Only support opset < 9 + func = func_group.get(7) + symbolics_schema.arguments = _symbolic_argument_count(func) + symbolics_schema.opsets = list(range(7, _constants.ONNX_BASE_OPSET)) + + symbolics_schemas[name] = symbolics_schema + + return symbolics_schemas diff --git a/torch/onnx/_type_utils.py b/torch/onnx/_type_utils.py new file mode 100644 index 0000000000000..81bcaeef1107a --- /dev/null +++ b/torch/onnx/_type_utils.py @@ -0,0 +1,391 @@ +# mypy: allow-untyped-defs +"""Utilities for converting and operating on ONNX, JIT and torch types.""" + +from __future__ import annotations + +import enum +import typing +from typing import Literal + +import torch +from torch._C import _onnx as _C_onnx +from torch.onnx import errors + + +if typing.TYPE_CHECKING: + # Hack to help mypy to recognize torch._C.Value + from torch import _C # noqa: F401 + +ScalarName = Literal[ + "Byte", + "Char", + "Double", + "Float", + "Half", + "Int", + "Long", + "Short", + "Bool", + "ComplexHalf", + "ComplexFloat", + "ComplexDouble", + "QInt8", + "QUInt8", + "QInt32", + "BFloat16", + "Float8E5M2", + "Float8E4M3FN", + "Float8E5M2FNUZ", + "Float8E4M3FNUZ", + "Undefined", +] + +TorchName = Literal[ + "bool", + "uint8_t", + "int8_t", + "double", + "float", + "half", + "int", + "int64_t", + "int16_t", + "complex32", + "complex64", + "complex128", + "qint8", + "quint8", + "qint32", + "bfloat16", + "float8_e5m2", + "float8_e4m3fn", + "float8_e5m2fnuz", + "float8_e4m3fnuz", +] + + +class JitScalarType(enum.IntEnum): + """Scalar types defined in torch. + + Use ``JitScalarType`` to convert from torch and JIT scalar types to ONNX scalar types. + + Examples: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX) + >>> # xdoctest: +IGNORE_WANT("win32 has different output") + >>> JitScalarType.from_value(torch.ones(1, 2)).onnx_type() + TensorProtoDataType.FLOAT + + >>> JitScalarType.from_value(torch_c_value_with_type_float).onnx_type() + TensorProtoDataType.FLOAT + + >>> JitScalarType.from_dtype(torch.get_default_dtype).onnx_type() + TensorProtoDataType.FLOAT + + """ + + # Order defined in https://github.com/pytorch/pytorch/blob/344defc9733a45fee8d0c4d3f5530f631e823196/c10/core/ScalarType.h + UINT8 = 0 + INT8 = enum.auto() # 1 + INT16 = enum.auto() # 2 + INT = enum.auto() # 3 + INT64 = enum.auto() # 4 + HALF = enum.auto() # 5 + FLOAT = enum.auto() # 6 + DOUBLE = enum.auto() # 7 + COMPLEX32 = enum.auto() # 8 + COMPLEX64 = enum.auto() # 9 + COMPLEX128 = enum.auto() # 10 + BOOL = enum.auto() # 11 + QINT8 = enum.auto() # 12 + QUINT8 = enum.auto() # 13 + QINT32 = enum.auto() # 14 + BFLOAT16 = enum.auto() # 15 + FLOAT8E5M2 = enum.auto() # 16 + FLOAT8E4M3FN = enum.auto() # 17 + FLOAT8E5M2FNUZ = enum.auto() # 18 + FLOAT8E4M3FNUZ = enum.auto() # 19 + UNDEFINED = enum.auto() # 20 + + @classmethod + def _from_name(cls, name: ScalarName | TorchName | str | None) -> JitScalarType: + """Convert a JIT scalar type or torch type name to ScalarType. + + Note: DO NOT USE this API when `name` comes from a `torch._C.Value.type()` calls. + A "RuntimeError: INTERNAL ASSERT FAILED at "../aten/src/ATen/core/jit_type_base.h" can + be raised in several scenarios where shape info is not present. + Instead use `from_value` API which is safer. + + Args: + name: JIT scalar type name (Byte) or torch type name (uint8_t). + + Returns: + JitScalarType + + Raises: + OnnxExporterError: if name is not a valid scalar type name or if it is None. + """ + if name is None: + raise errors.OnnxExporterError("Scalar type name cannot be None") + if valid_scalar_name(name): + return _SCALAR_NAME_TO_TYPE[name] # type: ignore[index] + if valid_torch_name(name): + return _TORCH_NAME_TO_SCALAR_TYPE[name] # type: ignore[index] + + raise errors.OnnxExporterError(f"Unknown torch or scalar type: '{name}'") + + @classmethod + def from_dtype(cls, dtype: torch.dtype | None) -> JitScalarType: + """Convert a torch dtype to JitScalarType. + + Note: DO NOT USE this API when `dtype` comes from a `torch._C.Value.type()` calls. + A "RuntimeError: INTERNAL ASSERT FAILED at "../aten/src/ATen/core/jit_type_base.h" can + be raised in several scenarios where shape info is not present. + Instead use `from_value` API which is safer. + + Args: + dtype: A torch.dtype to create a JitScalarType from + + Returns: + JitScalarType + + Raises: + OnnxExporterError: if dtype is not a valid torch.dtype or if it is None. + """ + if dtype not in _DTYPE_TO_SCALAR_TYPE: + raise errors.OnnxExporterError(f"Unknown dtype: {dtype}") + return _DTYPE_TO_SCALAR_TYPE[dtype] + + @classmethod + def from_onnx_type( + cls, onnx_type: int | _C_onnx.TensorProtoDataType | None + ) -> JitScalarType: + """Convert a ONNX data type to JitScalarType. + + Args: + onnx_type: A torch._C._onnx.TensorProtoDataType to create a JitScalarType from + + Returns: + JitScalarType + + Raises: + OnnxExporterError: if dtype is not a valid torch.dtype or if it is None. + """ + if onnx_type not in _ONNX_TO_SCALAR_TYPE: + raise errors.OnnxExporterError(f"Unknown onnx_type: {onnx_type}") + return _ONNX_TO_SCALAR_TYPE[typing.cast(_C_onnx.TensorProtoDataType, onnx_type)] + + @classmethod + def from_value( + cls, value: None | torch._C.Value | torch.Tensor, default=None + ) -> JitScalarType: + """Create a JitScalarType from an value's scalar type. + + Args: + value: An object to fetch scalar type from. + default: The JitScalarType to return if a valid scalar cannot be fetched from value + + Returns: + JitScalarType. + + Raises: + OnnxExporterError: if value does not have a valid scalar type and default is None. + SymbolicValueError: when value.type()'s info are empty and default is None + """ + + if not isinstance(value, (torch._C.Value, torch.Tensor)) or ( + isinstance(value, torch._C.Value) and value.node().mustBeNone() + ): + # default value of type JitScalarType is returned when value is not valid + if default is None: + raise errors.OnnxExporterError( + "value must be either torch._C.Value or torch.Tensor objects." + ) + elif not isinstance(default, JitScalarType): + raise errors.OnnxExporterError( + "default value must be a JitScalarType object." + ) + return default + + # Each value type has their own way of storing scalar type + if isinstance(value, torch.Tensor): + return cls.from_dtype(value.dtype) + if isinstance(value.type(), torch.ListType): + try: + return cls.from_dtype(value.type().getElementType().dtype()) + except RuntimeError: + return cls._from_name(str(value.type().getElementType())) + if isinstance(value.type(), torch._C.OptionalType): + if value.type().getElementType().dtype() is None: + if isinstance(default, JitScalarType): + return default + raise errors.OnnxExporterError( + "default value must be a JitScalarType object." + ) + return cls.from_dtype(value.type().getElementType().dtype()) + + scalar_type = None + if value.node().kind() != "prim::Constant" or not isinstance( + value.type(), torch._C.NoneType + ): + # value must be a non-list torch._C.Value scalar + scalar_type = value.type().scalarType() + + if scalar_type is not None: + return cls._from_name(scalar_type) + + # When everything fails... try to default + if default is not None: + return default + raise errors.SymbolicValueError( + f"Cannot determine scalar type for this '{type(value.type())}' instance and " + "a default value was not provided.", + value, + ) + + def scalar_name(self) -> ScalarName: + """Convert a JitScalarType to a JIT scalar type name.""" + return _SCALAR_TYPE_TO_NAME[self] + + def torch_name(self) -> TorchName: + """Convert a JitScalarType to a torch type name.""" + return _SCALAR_TYPE_TO_TORCH_NAME[self] + + def dtype(self) -> torch.dtype: + """Convert a JitScalarType to a torch dtype.""" + return _SCALAR_TYPE_TO_DTYPE[self] + + def onnx_type(self) -> _C_onnx.TensorProtoDataType: + """Convert a JitScalarType to an ONNX data type.""" + if self not in _SCALAR_TYPE_TO_ONNX: + raise errors.OnnxExporterError( + f"Scalar type {self} cannot be converted to ONNX" + ) + return _SCALAR_TYPE_TO_ONNX[self] + + def onnx_compatible(self) -> bool: + """Return whether this JitScalarType is compatible with ONNX.""" + return ( + self in _SCALAR_TYPE_TO_ONNX + and self != JitScalarType.UNDEFINED + and self != JitScalarType.COMPLEX32 + ) + + +def valid_scalar_name(scalar_name: ScalarName | str) -> bool: + """Return whether the given scalar name is a valid JIT scalar type name.""" + return scalar_name in _SCALAR_NAME_TO_TYPE + + +def valid_torch_name(torch_name: TorchName | str) -> bool: + """Return whether the given torch name is a valid torch type name.""" + return torch_name in _TORCH_NAME_TO_SCALAR_TYPE + + +# https://github.com/pytorch/pytorch/blob/344defc9733a45fee8d0c4d3f5530f631e823196/c10/core/ScalarType.h +_SCALAR_TYPE_TO_NAME: dict[JitScalarType, ScalarName] = { + JitScalarType.BOOL: "Bool", + JitScalarType.UINT8: "Byte", + JitScalarType.INT8: "Char", + JitScalarType.INT16: "Short", + JitScalarType.INT: "Int", + JitScalarType.INT64: "Long", + JitScalarType.HALF: "Half", + JitScalarType.FLOAT: "Float", + JitScalarType.DOUBLE: "Double", + JitScalarType.COMPLEX32: "ComplexHalf", + JitScalarType.COMPLEX64: "ComplexFloat", + JitScalarType.COMPLEX128: "ComplexDouble", + JitScalarType.QINT8: "QInt8", + JitScalarType.QUINT8: "QUInt8", + JitScalarType.QINT32: "QInt32", + JitScalarType.BFLOAT16: "BFloat16", + JitScalarType.FLOAT8E5M2: "Float8E5M2", + JitScalarType.FLOAT8E4M3FN: "Float8E4M3FN", + JitScalarType.FLOAT8E5M2FNUZ: "Float8E5M2FNUZ", + JitScalarType.FLOAT8E4M3FNUZ: "Float8E4M3FNUZ", + JitScalarType.UNDEFINED: "Undefined", +} + +_SCALAR_NAME_TO_TYPE: dict[ScalarName, JitScalarType] = { + v: k for k, v in _SCALAR_TYPE_TO_NAME.items() +} + +_SCALAR_TYPE_TO_TORCH_NAME: dict[JitScalarType, TorchName] = { + JitScalarType.BOOL: "bool", + JitScalarType.UINT8: "uint8_t", + JitScalarType.INT8: "int8_t", + JitScalarType.INT16: "int16_t", + JitScalarType.INT: "int", + JitScalarType.INT64: "int64_t", + JitScalarType.HALF: "half", + JitScalarType.FLOAT: "float", + JitScalarType.DOUBLE: "double", + JitScalarType.COMPLEX32: "complex32", + JitScalarType.COMPLEX64: "complex64", + JitScalarType.COMPLEX128: "complex128", + JitScalarType.QINT8: "qint8", + JitScalarType.QUINT8: "quint8", + JitScalarType.QINT32: "qint32", + JitScalarType.BFLOAT16: "bfloat16", + JitScalarType.FLOAT8E5M2: "float8_e5m2", + JitScalarType.FLOAT8E4M3FN: "float8_e4m3fn", + JitScalarType.FLOAT8E5M2FNUZ: "float8_e5m2fnuz", + JitScalarType.FLOAT8E4M3FNUZ: "float8_e4m3fnuz", +} + +_TORCH_NAME_TO_SCALAR_TYPE: dict[TorchName, JitScalarType] = { + v: k for k, v in _SCALAR_TYPE_TO_TORCH_NAME.items() +} + +_SCALAR_TYPE_TO_ONNX = { + JitScalarType.BOOL: _C_onnx.TensorProtoDataType.BOOL, + JitScalarType.UINT8: _C_onnx.TensorProtoDataType.UINT8, + JitScalarType.INT8: _C_onnx.TensorProtoDataType.INT8, + JitScalarType.INT16: _C_onnx.TensorProtoDataType.INT16, + JitScalarType.INT: _C_onnx.TensorProtoDataType.INT32, + JitScalarType.INT64: _C_onnx.TensorProtoDataType.INT64, + JitScalarType.HALF: _C_onnx.TensorProtoDataType.FLOAT16, + JitScalarType.FLOAT: _C_onnx.TensorProtoDataType.FLOAT, + JitScalarType.DOUBLE: _C_onnx.TensorProtoDataType.DOUBLE, + JitScalarType.COMPLEX64: _C_onnx.TensorProtoDataType.COMPLEX64, + JitScalarType.COMPLEX128: _C_onnx.TensorProtoDataType.COMPLEX128, + JitScalarType.BFLOAT16: _C_onnx.TensorProtoDataType.BFLOAT16, + JitScalarType.UNDEFINED: _C_onnx.TensorProtoDataType.UNDEFINED, + JitScalarType.COMPLEX32: _C_onnx.TensorProtoDataType.UNDEFINED, + JitScalarType.QINT8: _C_onnx.TensorProtoDataType.INT8, + JitScalarType.QUINT8: _C_onnx.TensorProtoDataType.UINT8, + JitScalarType.QINT32: _C_onnx.TensorProtoDataType.INT32, + JitScalarType.FLOAT8E5M2: _C_onnx.TensorProtoDataType.FLOAT8E5M2, + JitScalarType.FLOAT8E4M3FN: _C_onnx.TensorProtoDataType.FLOAT8E4M3FN, + JitScalarType.FLOAT8E5M2FNUZ: _C_onnx.TensorProtoDataType.FLOAT8E5M2FNUZ, + JitScalarType.FLOAT8E4M3FNUZ: _C_onnx.TensorProtoDataType.FLOAT8E4M3FNUZ, +} + +_ONNX_TO_SCALAR_TYPE = {v: k for k, v in _SCALAR_TYPE_TO_ONNX.items()} + +# source of truth is +# https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_dtypes.cpp +_SCALAR_TYPE_TO_DTYPE = { + JitScalarType.BOOL: torch.bool, + JitScalarType.UINT8: torch.uint8, + JitScalarType.INT8: torch.int8, + JitScalarType.INT16: torch.short, + JitScalarType.INT: torch.int, + JitScalarType.INT64: torch.int64, + JitScalarType.HALF: torch.half, + JitScalarType.FLOAT: torch.float, + JitScalarType.DOUBLE: torch.double, + JitScalarType.COMPLEX32: torch.complex32, + JitScalarType.COMPLEX64: torch.complex64, + JitScalarType.COMPLEX128: torch.complex128, + JitScalarType.QINT8: torch.qint8, + JitScalarType.QUINT8: torch.quint8, + JitScalarType.QINT32: torch.qint32, + JitScalarType.BFLOAT16: torch.bfloat16, + JitScalarType.FLOAT8E5M2: torch.float8_e5m2, + JitScalarType.FLOAT8E4M3FN: torch.float8_e4m3fn, + JitScalarType.FLOAT8E5M2FNUZ: torch.float8_e5m2fnuz, + JitScalarType.FLOAT8E4M3FNUZ: torch.float8_e4m3fnuz, +} + +_DTYPE_TO_SCALAR_TYPE = {v: k for k, v in _SCALAR_TYPE_TO_DTYPE.items()} diff --git a/torch/onnx/ops/__init__.py b/torch/onnx/ops/__init__.py index d10ba1ac7a3cc..af4d40d8a451c 100644 --- a/torch/onnx/ops/__init__.py +++ b/torch/onnx/ops/__init__.py @@ -61,6 +61,7 @@ def aten_decompositions() -> dict[torch._ops.OpOverload, Callable]: def _parse_domain_op_type(domain_op: str) -> tuple[str, str]: +<<<<<<< HEAD split = domain_op.split("::", 1) if len(split) == 1: domain = "" @@ -68,6 +69,15 @@ def _parse_domain_op_type(domain_op: str) -> tuple[str, str]: else: domain = split[0] op_type = split[1] +======= + splitted = domain_op.split("::", 1) + if len(splitted) == 1: + domain = "" + op_type = splitted[0] + else: + domain = splitted[0] + op_type = splitted[1] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return domain, op_type @@ -208,7 +218,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Create a symbolic ONNX operator with the name "CustomOp" in the "custom_domain" domain. # The output tensors will have the specified dtypes and shapes +<<<<<<< HEAD (out1, out2) = torch.onnx.ops.symbolic_multi_out( +======= + (out1, out2) = torch.onnx.ops.symbolic( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "custom_domain::CustomOp", (x,), dict(attr_key="attr_value"), diff --git a/torch/onnx/ops/_impl.py b/torch/onnx/ops/_impl.py index a7eba334ecfc8..0e79fc353479f 100644 --- a/torch/onnx/ops/_impl.py +++ b/torch/onnx/ops/_impl.py @@ -1,15 +1,24 @@ # flake8: noqa: B950 import math +<<<<<<< HEAD from typing import Callable, Optional, TypeVar from typing_extensions import ParamSpec +======= +import typing +from typing import Callable, Optional +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch from torch.onnx.ops import _dtype_mappings +<<<<<<< HEAD # Use ParamSpec for better type preservation instead of bound Callable TypeVar _P = ParamSpec("_P") _R = TypeVar("_R") +======= +_T = typing.TypeVar("_T", bound=Callable) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # ONNX to ATen decomp table ONNX_ATEN_DECOMP_TABLE: dict[torch._ops.OpOverload, Callable] = {} @@ -23,12 +32,19 @@ ) +<<<<<<< HEAD def _onnx_op( op_type: str, opset_version: int ) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]: """Decorator to register an ONNX operator with a custom implementation.""" def decorator(func: Callable[_P, _R]) -> Callable[_P, _R]: +======= +def _onnx_op(op_type: str, opset_version: int) -> Callable[[_T], _T]: + """Decorator to register an ONNX operator with a custom implementation.""" + + def decorator(func: _T) -> _T: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) overload = f"opset{opset_version}" torch_op = torch.library.custom_op( f"onnx::{op_type}.{overload}", mutates_args=() @@ -56,6 +72,7 @@ def rotary_embedding_23( rotary_embedding_dim: int = 0, ) -> torch.Tensor: """RotaryEmbedding-23 https://onnx.ai/onnx/operators/onnx__RotaryEmbedding.html#rotaryembedding-23""" +<<<<<<< HEAD # x has shape (batch_size, num_heads, sequence_length, head_size) # or (batch_size, sequence_length, hidden_size) input_shape = x.shape @@ -105,6 +122,20 @@ def rotary_embedding_23( new_shape = [batch_size, sequence_length, num_heads, head_size] x = torch.reshape(x, new_shape) +======= + # First ensure x has shape [batch_size, num_heads, seq_len, head_size] + batch_size = x.shape[0] + sequence_length = x.shape[1] + if len(x.shape) == 3: + hidden_size = x.shape[2] + torch._check( + num_heads != 0, + lambda: f"num_heads must be provided for 3D inputs. Received input tensor with shape {x.shape}", + ) + head_size = hidden_size // num_heads + new_shape = [batch_size, sequence_length, num_heads, head_size] + x = torch.reshape(x, new_shape) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch._check(len(x.shape) == 4, lambda: "x should be a 4D tensor by now") head_size = x.shape[3] @@ -125,6 +156,7 @@ def rotary_embedding_23( position_ids ] # Shape: [batch_size, sequence_length, head_size/2] else: +<<<<<<< HEAD cos = cos_cache # Shape: [batch_size, sequence_length, rotary_embedding_dim/2] sin = sin_cache # Shape: [batch_size, sequence_length, rotary_embedding_dim/2] @@ -144,6 +176,16 @@ def rotary_embedding_23( sin.shape[-1] == rotary_embedding_dim_half, lambda: f"Last dimension of sin cache ({sin.shape[-1]}) should match rotary_embedding_dim/2 ({rotary_embedding_dim_half}).", ) +======= + cos = cos_cache + sin = sin_cache + cos = cos[ + :, :, :rotary_embedding_dim_half + ] # Shape: [batch_size, sequence_length, rotary_embedding_dim/2] + sin = sin[ + :, :, :rotary_embedding_dim_half + ] # Shape: [batch_size, sequence_length, rotary_embedding_dim/2] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cos = torch.unsqueeze( cos, 2 ) # Shape: [batch_size, sequence_length, 1, rotary_embedding_dim/2] @@ -173,11 +215,17 @@ def rotary_embedding_23( else: x_rotate = torch.cat((real, imag), dim=-1) output = torch.cat((x_rotate, x_not_rotate), dim=-1) +<<<<<<< HEAD if input_rank == 3: return torch.reshape(output, input_shape) # Return the dimensions to the original order return torch.permute(output, (0, 2, 1, 3)) +======= + if len(x.shape) == 3: + output = torch.reshape(output, x.shape) + return output +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _get_scale_factor(scale: Optional[float], head_size: int) -> float: diff --git a/torch/onnx/symbolic_caffe2.py b/torch/onnx/symbolic_caffe2.py new file mode 100644 index 0000000000000..83a2ff6c32ec9 --- /dev/null +++ b/torch/onnx/symbolic_caffe2.py @@ -0,0 +1,361 @@ +# mypy: allow-untyped-defs +# mypy: disable-error-code=arg-type +import importlib +import inspect + +from torch.onnx import symbolic_helper, symbolic_opset9 as opset9 +from torch.onnx._internal import jit_utils, registration + + +def register_quantized_ops(domain: str, version: int): + # Register all quantized ops + module = importlib.import_module("torch.onnx.symbolic_caffe2") + quant_version_ops = inspect.getmembers(module) + aten_q_ops = { + "relu", + "_empty_affine_quantized", + "dequantize", + "quantize_per_tensor", + "upsample_nearest2d", + "avg_pool2d", + "reshape", + "slice", + "cat", + "max_pool2d", + "sigmoid", + } + for op, func in quant_version_ops: + name = f"{domain}::{op}" + if inspect.isfunction(func) and not registration.registry.is_registered_op( + name, version + ): + if op in aten_q_ops: + # Override the builtin aten ops + registration.registry.register( + f"aten::{op}", version, func, custom=True + ) + registration.registry.register(name, version, func) + + +def _permute_helper(g: jit_utils.GraphContext, input, axes): + quant_args = { + "axes_i": axes, + "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"), + "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"), + } + output = g.op("_caffe2::Int8Transpose", input, **quant_args) + symbolic_helper._quantized_ops.add(output) + return output + + +def nchw2nhwc(g: jit_utils.GraphContext, input): + axes = [0, 2, 3, 1] + return _permute_helper(g, input, axes) + + +def nhwc2nchw(g: jit_utils.GraphContext, input): + axes = [0, 3, 1, 2] + return _permute_helper(g, input, axes) + + +def linear_prepack(g: jit_utils.GraphContext, weight, bias): + # Mapping to a dummy caffe2 prepack node. + # During the onnx -> c2 conversion we can look up original weight and bias + # from this node + output = g.op("_caffe2::WeightPrepack", weight, bias) + symbolic_helper._quantized_ops.add(output) + return output + + +@symbolic_helper.parse_args("v", "v", "v", "f", "i") +def linear(g: jit_utils.GraphContext, input, weight, bias, scale, zero_point): + kwargs = { + "Y_scale_f": scale, + "Y_zero_point_i": zero_point, + } + output = g.op("_caffe2::Int8FC", input, weight, bias, **kwargs) + symbolic_helper._quantized_ops.add(output) + return output + + +def conv_prepack( + g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups +): + # Mapping to a dummy caffe2 prepack node. + # During the onnx -> c2 conversion we can look up original weight and bias + # from this node + output = g.op("_caffe2::WeightPrepack", input, weight, bias) + symbolic_helper._quantized_ops.add(output) + return output + + +@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "f", "i") +def conv2d( + g: jit_utils.GraphContext, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + scale, + zero_point, +): + kernel_size = weight.node()["shape"][1:3] + kwargs = { + "strides_i": stride, + "pads_i": padding + padding, + "dilations_i": dilation, + "group_i": groups, + "kernels_i": kernel_size, + "order_s": "NHWC", + "Y_scale_f": scale, + "Y_zero_point_i": zero_point, + } + output = g.op("_caffe2::Int8Conv", input, weight, bias, **kwargs) + symbolic_helper._quantized_ops.add(output) + return output + + +@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "f", "i") +def conv2d_relu( + g: jit_utils.GraphContext, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + scale, + zero_point, +): + kernel_size = weight.node()["shape"][1:3] + kwargs = { + "strides_i": stride, + "pads_i": padding + padding, + "dilations_i": dilation, + "group_i": groups, + "kernels_i": kernel_size, + "order_s": "NHWC", + "Y_scale_f": scale, + "Y_zero_point_i": zero_point, + } + output = g.op("_caffe2::Int8ConvRelu", input, weight, bias, **kwargs) + symbolic_helper._quantized_ops.add(output) + return output + + +@symbolic_helper.parse_args("v", "v", "f", "i") +def add(g: jit_utils.GraphContext, input_a, input_b, scale, zero_point): + kwargs = { + "Y_scale_f": scale, + "Y_zero_point_i": zero_point, + } + output = g.op("_caffe2::Int8Add", input_a, input_b, **kwargs) + symbolic_helper._quantized_ops.add(output) + return output + + +@symbolic_helper.parse_args("v") +def relu(g: jit_utils.GraphContext, input): + if input not in symbolic_helper._quantized_ops: + return opset9.relu(g, input) + kwargs = { + "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"), + "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"), + } + output = g.op("_caffe2::Int8Relu", input, **kwargs) + symbolic_helper._quantized_ops.add(output) + return output + + +@symbolic_helper.parse_args("v", "f", "i", "t") +def quantize_per_tensor(g: jit_utils.GraphContext, input, scale, zero_point, dtype): + kwargs = { + "Y_scale_f": scale, + "Y_zero_point_i": zero_point, + } + output = g.op("_caffe2::Int8Quantize", input, **kwargs) + symbolic_helper._quantized_ops.add(output) + return output + + +@symbolic_helper.parse_args("v") +def dequantize(g: jit_utils.GraphContext, input): + return g.op("_caffe2::Int8Dequantize", input) + + +@symbolic_helper.parse_args("v", "t", "t", "t", "t", "t", "t", "t") +def _empty_affine_quantized( + g: jit_utils.GraphContext, + input, + shape, + scale, + zero_point, + dtype, + pin_memory, + memory_format, + layout, +): + return input + + +def upsample_nearest2d( + g: jit_utils.GraphContext, + input, + output_size, + align_corners=None, + scales_h=None, + scales_w=None, +): + if input not in symbolic_helper._quantized_ops: + return opset9.upsample_nearest2d(g, input, output_size, align_corners) # type: ignore[attr-defined] + + output_size = symbolic_helper._parse_arg(output_size, "is") + kwargs = { + "output_size_i": output_size, + "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"), + "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"), + } + input = nchw2nhwc(g, input) + output = g.op("_caffe2::Int8ResizeNearest", input, **kwargs) + output = nhwc2nchw(g, output) + symbolic_helper._quantized_ops.add(output) + return output + + +@symbolic_helper.parse_args("v", "is", "is", "is", "is", "i") +def max_pool2d( + g: jit_utils.GraphContext, + input, + kernel_size, + stride, + padding, + dilation, + ceil_mode, +): + if input not in symbolic_helper._quantized_ops: + return opset9.max_pool2d( # type: ignore[attr-defined] + g, input, kernel_size, stride, padding, dilation, ceil_mode + ) + kwargs = { + "strides_i": stride, + "pads_i": padding + padding, + "kernel_i": kernel_size[0], + "order_s": "NHWC", + "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"), + "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"), + } + input = nchw2nhwc(g, input) + output = g.op("_caffe2::Int8MaxPool", input, **kwargs) + output = nhwc2nchw(g, output) + symbolic_helper._quantized_ops.add(output) + return output + + +@symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none") +def avg_pool2d( + g: jit_utils.GraphContext, + input, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override=None, +): + if input not in symbolic_helper._quantized_ops: + return opset9.avg_pool2d( # type: ignore[attr-defined] + g, + input, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + kwargs = { + "strides_i": stride, + "pads_i": padding + padding, + "kernel_i": kernel_size[0], + "order_s": "NHWC", + "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"), + "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"), + } + input = nchw2nhwc(g, input) + output = g.op("_caffe2::Int8AveragePool", input, **kwargs) + output = nhwc2nchw(g, output) + symbolic_helper._quantized_ops.add(output) + return output + + +def reshape(g: jit_utils.GraphContext, input, shape): + if input not in symbolic_helper._quantized_ops: + return opset9.reshape(g, input, shape) + + kwargs = { + "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"), + "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"), + } + output = g.op("_caffe2::Int8Reshape", input, shape, **kwargs) + symbolic_helper._quantized_ops.add(output) + return output + + +@symbolic_helper.parse_args("v", "v", "v", "v", "i") +def slice(g: jit_utils.GraphContext, input, dim, start, end, step): + if input not in symbolic_helper._quantized_ops: + return opset9.slice(g, input, dim, start, end, step) + + if step != 1: + raise RuntimeError("ONNX quantized slice export only works for step 1.") + start = symbolic_helper._parse_arg(start, "i") + end = symbolic_helper._parse_arg(end, "i") + dim = symbolic_helper._parse_arg(dim, "i") + + kwargs = { + "start_idx_i": start, + "end_idx_i": end, + "dim_i": dim, + "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"), + "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"), + } + output = g.op("_caffe2::Int8Slice", input, **kwargs) + symbolic_helper._quantized_ops.add(output) + return output + + +def cat(g: jit_utils.GraphContext, tensor_list, dim, scale=None, zero_point=None): + tensors = symbolic_helper._unpack_list(tensor_list) + input = tensors[0] + if input not in symbolic_helper._quantized_ops: + return opset9.cat(g, tensor_list, dim) + + dim = symbolic_helper._parse_arg(dim, "i") + kwargs = { + "Y_scale_f": tensors[0].node()["Y_scale"], + "Y_zero_point_i": tensors[0].node()["Y_zero_point"], + } + output = g.op("_caffe2::Int8Concat", *tensors, axis_i=dim, **kwargs) + symbolic_helper._quantized_ops.add(output) + return output + + +@symbolic_helper.parse_args("v") +def sigmoid(g: jit_utils.GraphContext, input): + if input not in symbolic_helper._quantized_ops: + return opset9.sigmoid(g, input) + # Caffe2 expects the output scale to be 1/2^8 + # and output zero_point to be 0 (quint8 type) + out_scale = 1.0 / 256 + zero_point = 0 + kwargs = { + "Y_scale_f": out_scale, + "Y_zero_point_i": zero_point, + } + output = g.op("_caffe2::Int8Sigmoid", input, **kwargs) + symbolic_helper._quantized_ops.add(output) + return output diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py index 76b50a8eb3f77..1e104c8f5b93c 100644 --- a/torch/onnx/symbolic_helper.py +++ b/torch/onnx/symbolic_helper.py @@ -1,3 +1,4 @@ +<<<<<<< HEAD """Backward compatibility module for torch.onnx.symbolic_helper.""" from __future__ import annotations @@ -6,3 +7,2272 @@ __all__: list[str] = [] from torch.onnx._internal.torchscript_exporter.symbolic_helper import * # noqa: F401,F403 +======= +# mypy: allow-untyped-defs +from __future__ import annotations + +import functools +import inspect +import math +import sys +import typing +import warnings +from typing import Any, Callable, Literal, NoReturn, TypeVar as _TypeVar +from typing_extensions import Concatenate as _Concatenate, ParamSpec as _ParamSpec + +import torch +import torch._C._onnx as _C_onnx +from torch import _C + +# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics +from torch.onnx import _constants, _type_utils, errors, utils +from torch.onnx._globals import GLOBALS +from torch.onnx._internal import jit_utils + + +if typing.TYPE_CHECKING: + from collections.abc import Sequence + + from torch.types import Number + +_T = _TypeVar("_T") +_U = _TypeVar("_U") +_P = _ParamSpec("_P") + +# --------------------------------------------------------------------------------- +# Helper functions +# --------------------------------------------------------------------------------- + +_ValueDescriptor = Literal[ + "v", + "i", + "is", + "f", + "fs", + "b", + "s", + "t", + "none", +] + + +def _parse_arg( + value, + desc: _ValueDescriptor, + arg_name: str | None = None, + node_name: str | None = None, +): + if desc == "none": + return value + if desc == "v" or not _is_value(value): + return value + + node = value.node() + if node.mustBeNone(): + return None + if node.kind() == "onnx::Constant": + node_val = _node_get(node, "value") + if desc == "i": + return int(node_val) + elif desc == "f": + return float(node_val) + elif desc == "b": + return bool(node_val) + elif desc == "s": + return str(node_val) + elif desc == "t": + return node_val + elif desc == "is": + return [int(v) for v in node_val] + elif desc == "fs": + return [float(v) for v in node_val] + else: + raise errors.SymbolicValueError( + f"ONNX symbolic does not understand the Constant node '{node}' " + f"specified with descriptor '{desc}'.", + value, + ) + elif node.kind() == "prim::ListConstruct": + if desc == "is": + for v in node.inputs(): + element_node = v.node() + if element_node.kind() != "onnx::Constant": + raise errors.SymbolicValueError( + f"Failed to export a node '{element_node}' " + f"(in list node {node}) " + f"because it is not constant. " + f"Please try to make things (e.g. kernel sizes) static if possible.", + value, + ) + return [int(_node_get(v.node(), "value")) for v in value.node().inputs()] + else: + raise errors.SymbolicValueError( + f"ONNX symbolic does not know how to unpack the ListConstruct node that " + f"is not a list of integers: '{node}'", + value, + ) + + if arg_name is None or node_name is None: + raise errors.SymbolicValueError( + f"Expected node type 'onnx::Constant', got '{node.kind()}'.", + value, + ) + + raise errors.SymbolicValueError( + "Expected node type 'onnx::Constant' " + f"for argument '{arg_name}' of node '{node_name}', got '{node.kind()}'.", + value, + ) + + +def _node_get(node: _C.Node, key: str): + """Gets attributes of a node which is polymorphic over return type.""" + assert isinstance(node, _C.Node) + sel = node.kindOf(key) + return getattr(node, sel)(key) + + +def _is_onnx_constant(value: _C.Value): + """Whether a Value is an ONNX constant.""" + return value.node().kind() == "onnx::Constant" + + +def _maybe_get_const( + value: _C.Value | torch.Tensor | Number | Sequence | None, + descriptor: _ValueDescriptor, +): + # NOTE: prim::Constant at this stage usually means something not compatible in ONNX, + # otherwise it'd be converted to onnx::Constant + # TODO(justinchuby): Replace insinstance with _is_value once we figure out mypy + if isinstance(value, _C.Value) and _is_onnx_constant(value): + return _parse_arg(value, descriptor) + return value + + +def _maybe_get_scalar(value): + value_t = _maybe_get_const(value, "t") + if isinstance(value_t, torch.Tensor) and value_t.shape == (): + return value_t + return value + + +def _get_const(value, desc, arg_name): + if not _is_constant(value): + raise errors.SymbolicValueError( + f"ONNX symbolic expected a constant value of the '{arg_name}' argument, " + f"got '{value}'", + value, + ) + return _parse_arg(value, desc) + + +def _unpack_list(list_value: _C.Value) -> list[_C.Value]: + list_node = list_value.node() + if list_node.kind() != "prim::ListConstruct": + raise errors.SymbolicValueError( + f"ONNX symbolic expected node type prim::ListConstruct, got '{list_node}'.", + list_value, + ) + return list(list_node.inputs()) + + +def _unpack_tuple(tuple_value: _C.Value) -> tuple[_C.Value, ...]: + tuple_node = tuple_value.node() + if not _is_tuple_construct(tuple_value): + raise errors.SymbolicValueError( + f"ONNX symbolic expected node type 'prim::TupleConstruct', " + f"got '{tuple_node.kind()}'.", + tuple_value, + ) + return tuple(tuple_node.inputs()) + + +def _unpack_quantized_tensor(tuple_value: _C.Value) -> tuple[_C.Value, ...]: + """Unpacks a quantized tensor into a tuple of tensor and scale/zero_point. + Args: + tuple_value: A tuple of tensor, scale, zero_point, and optionally axis. + Returns: + A tuple of tensor, scale, zero_point, and optionally axis. + """ + tuple_node = tuple_value.node() + # A quantized tensor is represented as tuple of the form (tensor, scale, zero_point, ) + if not _is_tuple_construct(tuple_value): + raise errors.SymbolicValueError( + f"ONNX symbolic expected the output of `{tuple_node}` to be a quantized " + f"tensor. Is this likely due to missing support for quantized " + f"`{tuple_node.kind()}`. Please create an issue on {_constants.PYTORCH_GITHUB_ISSUES_URL}", + tuple_value, + ) + unpacked = tuple(tuple_node.inputs()) + assert len(unpacked) == 3 or len(unpacked) == 4 + return unpacked + + +# Check if list_value is output from prim::ListConstruct +# This is usually called before _unpack_list to ensure the list can be unpacked. +def _is_packed_list(list_value: Any) -> bool: + return _is_value(list_value) and list_value.node().kind() == "prim::ListConstruct" + + +def parse_args( + *arg_descriptors: _ValueDescriptor, +) -> Callable[[Callable[_Concatenate[_U, _P], _T]], Callable[_Concatenate[_U, _P], _T]]: + """A decorator which converts args from torch._C.Value to built-in types. + + For example: + + ``` + @parse_args('v', 'i', 'fs') + foo(g, a, b, c): + assert isinstance(a, torch._C.Value) + assert isinstance(b, int) + assert isinstance(c, list) + assert isinstance(c[0], float) + ``` + + Args: + arg_descriptors: list of str, where each element is + a string that specifies the type to convert to. Valid descriptors: + "v": no conversion, keep torch._C.Value. + "i": int + "is": list of int + "f": float + "fs": list of float + "b": bool + "s": str + "t": torch.Tensor + "none": the variable is unused + """ + + def decorator( + fn: Callable[_Concatenate[_U, _P], _T], + ) -> Callable[_Concatenate[_U, _P], _T]: + fn._arg_descriptors = arg_descriptors # type: ignore[attr-defined] + + @functools.wraps(fn) + def wrapper(g: _U, *args: _P.args, **kwargs: _P.kwargs) -> _T: + # some args may be optional, so the length may be smaller + FILE_BUG_MSG = ( + "If you believe this is not due to custom symbolic implementation within your code or " + "an external library, please file an issue at " + "https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml to report this bug." + ) + assert len(arg_descriptors) >= len(args), ( + f"A mismatch between the number of arguments ({len(args)}) and " + f"their descriptors ({len(arg_descriptors)}) was found at symbolic function '{fn.__name__}'. " + f"{FILE_BUG_MSG}" + ) + + try: + sig = inspect.signature(fn) + arg_names = list(sig.parameters.keys())[1:] + fn_name = fn.__name__ + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. + arg_names = [None] * len(args) # type: ignore[list-item] + fn_name = None + args = [ + _parse_arg(arg, arg_desc, arg_name, fn_name) # type: ignore[method-assign] + for arg, arg_desc, arg_name in zip(args, arg_descriptors, arg_names) + ] + # only support _outputs in kwargs + assert len(kwargs) <= 1, ( + f"Symbolic function {fn.__name__}'s '**kwargs' can contain a single " + f"key/value entry. " + f"{FILE_BUG_MSG}" + ) + + if len(kwargs) == 1: + assert "_outputs" in kwargs, ( + f"Symbolic function {fn.__name__}'s '**kwargs' can only contain " + f"'_outputs' key at '**kwargs'. " + f"{FILE_BUG_MSG}" + ) + return fn(g, *args, **kwargs) + + return wrapper + + return decorator + + +def quantized_args( + *arg_q_descriptors: bool, + scale: float | None = None, + zero_point: int | None = None, + quantize_output: bool = True, +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + """A decorator which extends support for quantized version of the base operator. + + Quantization is detected by examining the arguments that are annotated by + `arg_q_descriptors`. + + If quantization is detected, the base operator symbolic function will be wrapped with + argument de-quantization and output quantization. + + Otherwise, only the base symbolic function will be invoked. + + For example: + + ``` + @quantized_args(True, False) + def foo(g, x, y): + return x + y + ``` + + is equivalent to + + ``` + def q_foo(g, x, y): + if is_quantized_tensor(x): + x = dequantize(x) + out = foo(g, x, y) + return quantize(out) + else: + return foo(g, x, y) + ``` + + Args: + arg_q_descriptors: A sequence of bool, where each element represents if the + argument is QTensor for quantized version of this operator. It defaults + to False for unspecified (variable length) arguments. + scale: Quantized output scale. If None, derive from + the first quantized input scale. + zero_point: Quantized output zero point. If None, + derive from the first quantized input zero point. + quantize_output: If True, quantize the output of the base operator. Default is True + """ + + def decorator(fn): + @functools.wraps(fn) + def wrapper(g, *args, **kwargs): + nonlocal scale + nonlocal zero_point + if scale is not None: + _scale = g.op("Constant", value_t=torch.tensor(scale)) + else: + _scale = None + if zero_point is not None: + _zero_point = g.op("Constant", value_t=torch.tensor(zero_point)) + else: + _zero_point = None + + # Support variable length arguments by marking unspecified ones as non-quantized + arg_q_descriptors_extended = arg_q_descriptors + (False,) * ( + len(args) - len(arg_q_descriptors) + ) + descriptor_args = tuple(zip(arg_q_descriptors_extended, args)) + + def _is_arg_quantized(descriptor, arg): + return descriptor and _is_value(arg) and _is_tuple_construct(arg) + + # Run regular symbolic function if none of the argument is QTensor. + is_quantized: list[bool] = [] + for descriptor, arg in descriptor_args: + # ListConstruct + if _is_packed_list(arg): + is_quantized.extend( + _is_arg_quantized(descriptor, arg_input) + for arg_input in arg.node().inputs() + ) + else: + is_quantized.append(_is_arg_quantized(descriptor, arg)) + + if not any(is_quantized): + return fn(g, *args, **kwargs) + + # Dequantize arguments that are quantized + non_quantized_args = [] + for descriptor, arg in descriptor_args: + if _is_arg_quantized(descriptor, arg): + # Quantized arg is a tuple of (value, scale, zero_point) + dequantized_arg, arg_scale, arg_zero_point, _ = dequantize_helper( + g, arg + ) + non_quantized_args.append(dequantized_arg) + # Set scale and zero_point to the first quantized input if not already set + if _scale is None: + _scale = arg_scale + if _zero_point is None: + _zero_point = arg_zero_point + # ListConstruct + elif _is_packed_list(arg): + for arg_input in arg.node().inputs(): + if _is_arg_quantized(descriptor, arg_input): + # Quantized arg is a tuple of (value, scale, zero_point) + ( + dequantized_arg, + arg_scale, + arg_zero_point, + _, + ) = dequantize_helper(g, arg_input) + # Set scale and zero_point to the first quantized input if not already set + if _scale is None: + _scale = arg_scale + if _zero_point is None: + _zero_point = arg_zero_point + arg_input.replaceAllUsesWith(dequantized_arg) + non_quantized_args.append(arg) + else: + # Non-quantized arg + non_quantized_args.append(arg) + # TODO(justinchuby): Only single output is supported for now. We may want to + # support multiple outputs in the future. + output = fn(g, *non_quantized_args, **kwargs) + + assert _scale is not None, "Bug: Scale must be set for quantized operator" + assert _zero_point is not None, ( + "Bug: Zero point must be set for quantized operator" + ) + + if quantize_output: + return quantize_helper(g, output, _scale, _zero_point) + return output + + return wrapper + + return decorator + + +def _scalar(x: Any) -> Number | None: + """Convert a scalar tensor into a Python value.""" + if isinstance(x, torch.Tensor) and x.shape == (): + return x.item() + return None + + +def _if_scalar_type_as(self, tensor): + """ + Convert self into the same type of tensor, as necessary. + We only support implicit casting for scalars, so we never + actually need to insert an ONNX cast operator here; just + fix up the scalar. + """ + if isinstance(self, _C.Value): + return self + + scalar_type = _type_utils.JitScalarType.from_value( + tensor, _type_utils.JitScalarType.UNDEFINED + ) + if scalar_type != _type_utils.JitScalarType.UNDEFINED: + ty = scalar_type.scalar_name().lower() + return getattr(self, ty)() + return self + + +def _is_none(x: Any) -> bool: + return x is None or (x.node().mustBeNone() if isinstance(x, _C.Value) else False) + + +def _is_value(x: Any) -> bool: + return isinstance(x, _C.Value) + + +def _is_constant(value: Any) -> bool: + return not _is_value(value) or value.node().kind() in { + "onnx::Constant", + "prim::Constant", + } + + +def _is_tensor(x: _C.Value) -> bool: + return x.type().isSubtypeOf(_C.TensorType.get()) + + +# Note: _C.JitType is not exposed to Python and cannot be checked in runtime. +def _as_list_type(jit_type: _C.JitType) -> _C.ListType | None: + if isinstance(jit_type, _C.ListType): + return jit_type + return None + + +def _is_list(x: _C.Value) -> bool: + return _as_list_type(x.type()) is not None + + +def _is_tensor_list(x: _C.Value) -> bool: + x_type = _as_list_type(x.type()) + if x_type is None: + return False + return isinstance(x_type.getElementType(), _C.TensorType) + + +def _is_scalar_list(x: _C.Value) -> bool: + """Checks if x is a scalar list, for example: List[float], List[int]. + + Besides checking the type is ListType, we also check if the data type is + a valid ONNX data type. + """ + x_type = _as_list_type(x.type()) + if x_type is None: + return False + scalar_type = _type_utils.JitScalarType.from_value(x) + return scalar_type.onnx_compatible() + + +def _is_tuple_construct(x: _C.Value) -> bool: + return x.node().kind() == "prim::TupleConstruct" + + +def is_complex_value(x: _C.Value) -> bool: + assert _is_value(x) + return _type_utils.JitScalarType.from_value( + x, _type_utils.JitScalarType.UNDEFINED + ) in { + _type_utils.JitScalarType.COMPLEX32, + _type_utils.JitScalarType.COMPLEX64, + _type_utils.JitScalarType.COMPLEX128, + } + + +def _get_tensor_rank(x: _C.Value) -> int | None: + if not _is_tensor(x) or x.type() is None: + return None + x_type = x.type() + x_type = typing.cast(_C.TensorType, x_type) + return x_type.dim() + + +def _get_tensor_sizes(x: _C.Value, allow_nonstatic: bool = True): + if not _is_tensor(x) or x.type() is None: + return None + x_type = x.type() + x_type = typing.cast(_C.TensorType, x_type) + if allow_nonstatic: + # Each individual symbol is returned as None. + # e.g. [1, "a", "b"] -> [1, None, None] + return x_type.varyingSizes() + # returns None, if exists any symbol in sizes. + # e.g. [1, "a", "b"] -> None + return x_type.sizes() + + +def _get_tensor_dim_size(x: _C.Value, dim: int) -> int | None: + sizes = _get_tensor_sizes(x) + return sizes[dim] if sizes else None + + +def _get_dim_for_cross(x: _C.Value, dim: int | None): + if dim == -1: + tensor_rank = _get_tensor_rank(x) + assert tensor_rank is not None + return dim + tensor_rank + # If dim is not given, it defaults to the first dimension found with the size 3 + if dim is None: + sizes = _get_tensor_sizes(x) + assert sizes is not None + for index, size in enumerate(sizes): + if size is not None and size == 3: + return index + return dim + + +def _unimplemented(op: str, msg: str, value: _C.Value | None = None) -> None: + # For BC reasons, the behavior for Caffe2 does not raise exception for unimplemented operators + if GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX: + _onnx_unsupported(f"{op}, {msg}", value) + + +def _onnx_unsupported(op_name: str, value: _C.Value | None = None) -> NoReturn: + message = ( + f"Unsupported: ONNX export of operator {op_name}. " + f"Please feel free to request support or submit a pull request " + f"on PyTorch GitHub: {_constants.PYTORCH_GITHUB_ISSUES_URL}" + ) + if isinstance(value, _C.Value): + raise errors.SymbolicValueError( + message, + value, + ) + raise errors.OnnxExporterError(message) + + +def _onnx_opset_unsupported( + op_name: str, + current_opset: int, + supported_opset: int, + value: _C.Value | None = None, +) -> NoReturn: + message = ( + f"Unsupported: ONNX export of {op_name} in opset {current_opset}. " + f"Please try opset version {supported_opset}." + ) + if isinstance(value, _C.Value): + raise errors.SymbolicValueError( + message, + value, + ) + raise errors.OnnxExporterError(message) + + +def _onnx_opset_unsupported_detailed( + op_name: str, + current_opset: int, + supported_opset: int, + reason: str, + value: _C.Value | None = None, +) -> NoReturn: + message = ( + f"Unsupported: ONNX export of {op_name} in " + f"opset {current_opset}. {reason}. Please try opset version {supported_opset}." + ) + if isinstance(value, _C.Value): + raise errors.SymbolicValueError( + message, + value, + ) + raise errors.OnnxExporterError(message) + + +def _block_list_in_opset(name: str): + def symbolic_fn(*args, **kwargs): + raise errors.OnnxExporterError( + f"ONNX export failed on {name}, which is not implemented for opset " + f"{GLOBALS.export_onnx_opset_version}. " + "Try exporting with other opset versions." + ) + + return symbolic_fn + + +def _try_get_scalar_type(*args) -> _type_utils.JitScalarType | None: + for arg in args: + scalar_type = _type_utils.JitScalarType.from_value( + arg, _type_utils.JitScalarType.UNDEFINED + ) + if scalar_type != _type_utils.JitScalarType.UNDEFINED: + return scalar_type + return None + + +def _type_promote_from_values(*args) -> _type_utils.JitScalarType: + undef = _type_utils.JitScalarType.UNDEFINED + jit_types = [_try_get_scalar_type(arg) for arg in args] + if len(jit_types) == 0: + return undef + if len(jit_types) == 1: + return jit_types[0] # type: ignore[return-value] + new_dtype = jit_types[0].dtype() # type: ignore[union-attr] + for t in jit_types: + new_dtype = torch.promote_types(new_dtype, t.dtype()) # type: ignore[union-attr] + return _type_utils.JitScalarType.from_dtype(new_dtype) + + +def _maybe_cast_to_type( + g: jit_utils.GraphContext, value, jit_type: _type_utils.JitScalarType +): + if ( + _type_utils.JitScalarType.from_value(value, _type_utils.JitScalarType.UNDEFINED) + != jit_type + ): + return g.op( + "Cast", + value, + to_i=jit_type.onnx_type(), + ) + return value + + +def _select_helper(g: jit_utils.GraphContext, self, dim, index, apply_reshape=True): + index_const = _maybe_get_scalar(index) + index_dim = _get_tensor_rank(index) + if not _is_value(index_const): + # Index is a constant scalar. Make it a size 1 constant tensor. + index = g.op("Constant", value_t=torch.LongTensor([index_const])) + elif index_dim is not None and apply_reshape: + if index_dim == 0: + # Index is a scalar. Reshape it to a size 1 tensor. + index = _reshape_helper( + g, index, g.op("Constant", value_t=torch.LongTensor([1])) + ) + + index_scalar_type = _type_utils.JitScalarType.from_value( + index, _type_utils.JitScalarType.UNDEFINED + ) + if index_scalar_type not in { + _type_utils.JitScalarType.INT64, + _type_utils.JitScalarType.INT, + }: + index = g.op("Cast", index, to_i=_C_onnx.TensorProtoDataType.INT64) + return g.op("Gather", self, index, axis_i=dim) + + +def _slice_helper( + g: jit_utils.GraphContext, + input, + axes, + starts, + ends, + steps=None, +): + if g.opset <= 9: + from torch.onnx.symbolic_opset9 import _slice as _slice9 + + return _slice9(g, input, axes, starts, ends) + else: + from torch.onnx.symbolic_opset10 import _slice as _slice10 + + return _slice10(g, input, axes, starts, ends, steps) + + +def _is_fp(value) -> bool: + return _type_utils.JitScalarType.from_value( + value, _type_utils.JitScalarType.UNDEFINED + ) in { + _type_utils.JitScalarType.FLOAT, + _type_utils.JitScalarType.DOUBLE, + _type_utils.JitScalarType.HALF, + _type_utils.JitScalarType.BFLOAT16, + } + + +def _is_bool(value) -> bool: + return _type_utils.JitScalarType.from_value( + value, _type_utils.JitScalarType.UNDEFINED + ) in {_type_utils.JitScalarType.BOOL} + + +def _generate_wrapped_number(g: jit_utils.GraphContext, scalar): + """Creates a wrapped number based on https://github.com/pytorch/pytorch/issues/9515. + + A Tensor is a considered a "wrapped number" if it is + auto-wrapped from a C++ or Python number type. Integer types are + wrapped as 0-dim int64 tensors and floating-point types are + wrapped as 0-dim double tensors. + + The input to this function is constant value. If the data type + is a floating point type, it is converted to a 0-dim double + tensor, else it is converted to a 0-dim tensor of its original type + """ + assert not isinstance(scalar, torch.Tensor) + if isinstance(scalar, float): + return g.op("Constant", value_t=torch.tensor(scalar, dtype=torch.double)) + return g.op("Constant", value_t=torch.tensor(scalar)) + + +def _sort_helper(g: jit_utils.GraphContext, input, dim, decending=True, out=None): + if out is not None: + _unimplemented("Sort", "Out parameter is not supported") + shape_ = g.op("Shape", input) + dim_size_ = g.op( + "Gather", + shape_, + g.op("Constant", value_t=torch.tensor([dim], dtype=torch.int64)), + ) + if g.opset <= 10: + if not decending: + _unimplemented("Sort", "Ascending is not supported") + return g.op("TopK", input, dim_size_, axis_i=dim, outputs=2) + else: + return g.op( + "TopK", input, dim_size_, axis_i=dim, largest_i=decending, outputs=2 + ) + + +def _topk_helper( + g: jit_utils.GraphContext, input, k, dim, largest=True, sorted=False, out=None +): + if out is not None: + _unimplemented("TopK", "Out parameter is not supported") + if not _is_value(k): + k = g.op("Constant", value_t=torch.tensor([k], dtype=torch.int64)) + else: + k = _reshape_helper(g, k, g.op("Constant", value_t=torch.tensor([1]))) + if _try_get_scalar_type(k) != _type_utils.JitScalarType.INT64: + k = g.op("Cast", k, to_i=_C_onnx.TensorProtoDataType.INT64) + if g.opset <= 10: + if not largest: + _unimplemented("TopK", "Ascending is not supported") + return g.op("TopK", input, k, axis_i=dim, outputs=2) + else: + return g.op( + "TopK", input, k, axis_i=dim, largest_i=largest, sorted_i=sorted, outputs=2 + ) + + +def _lt_helper(g: jit_utils.GraphContext, input, other): + if g.opset <= 8: + from torch.onnx.symbolic_opset8 import lt as _lt8 + + return _lt8(g, input, other) + else: + from torch.onnx.symbolic_opset9 import lt as _lt9 + + return _lt9(g, input, other) + + +def _interpolate_warning(interpolate_mode): + onnx_op = ( + "onnx:Resize" if GLOBALS.export_onnx_opset_version >= 10 else "onnx:Upsample" + ) + warnings.warn( + "You are trying to export the model with " + + onnx_op + + " for ONNX opset version " + "" + str(GLOBALS.export_onnx_opset_version) + ". " + "This operator might cause results to not match the expected results by PyTorch.\n" + "ONNX's Upsample/Resize operator did not match Pytorch's Interpolation until opset 11. " + "Attributes to determine how to transform the input were added in onnx:Resize in opset 11 " + "to support Pytorch's behavior (like coordinate_transformation_mode and nearest_mode).\n" + "We recommend using opset 11 and above for models using this operator." + ) + + +def _unsqueeze_helper(g: jit_utils.GraphContext, input, axes_i): + if len(axes_i) == 0: + # unnecessary unsqueeze if axes length==0 + return input + elif _is_constant(axes_i[0]): + if g.opset >= 13: + axes = g.op("Constant", value_t=torch.tensor(axes_i, dtype=torch.long)) + return g.op("Unsqueeze", input, axes) + return g.op("Unsqueeze", input, axes_i=axes_i) + # Tensor type + if g.opset < 13: + raise errors.SymbolicValueError( + "Opset version must be >= 13 for Unsqueeze with dynamic axes.", input + ) + return g.op("Unsqueeze", input, axes_i[0]) + + +def _squeeze_helper(g: jit_utils.GraphContext, input, axes_i): + if _is_constant(axes_i[0]): + if g.opset >= 13: + axes = g.op("Constant", value_t=torch.tensor(axes_i, dtype=torch.long)) + return g.op("Squeeze", input, axes) + return g.op("Squeeze", input, axes_i=axes_i) + # Tensor type + if g.opset < 13: + raise errors.SymbolicValueError( + "Opset version must be >= 13 for Squeeze with dynamic axes.", input + ) + axes_t = axes_i[0] + axes_rank = _get_tensor_rank(axes_t) + assert axes_rank is not None + if axes_rank > 1: + raise errors.SymbolicValueError( + "For Squeeze axses as input, the axes rank must be one in ONNX spec.", input + ) + elif axes_rank == 0: + # The axes is a scalar. Unsqueeze it to a rank 1 tensor. + axes_t = _unsqueeze_helper(g, axes_t, [0]) + return g.op("Squeeze", input, axes_t) + return g.op("Squeeze", input, axes_t) + + +def _reducesum_helper( + g: jit_utils.GraphContext, + input, + axes_i=None, + keepdims_i=1, + noop_with_empty_axes_i=0, +): + keepdims_i = _maybe_get_const(keepdims_i, "i") + if g.opset >= 13: + if axes_i: + if not _is_value(axes_i): + axes_i = g.op( + "Constant", value_t=torch.tensor(axes_i, dtype=torch.long) + ) + return g.op( + "ReduceSum", + input, + axes_i, + keepdims_i=keepdims_i, + noop_with_empty_axes_i=noop_with_empty_axes_i, + ) + return g.op( + "ReduceSum", + input, + keepdims_i=keepdims_i, + noop_with_empty_axes_i=noop_with_empty_axes_i, + ) + else: + return g.op("ReduceSum", input, axes_i=axes_i, keepdims_i=keepdims_i) + + +def _interpolate_size_to_scales(g: jit_utils.GraphContext, input, output_size, dim): + output_size = _maybe_get_const(output_size, "is") + if _is_value(output_size): + offset = 2 + offsets = g.op("Constant", value_t=torch.ones(offset, dtype=torch.float32)) + dividend = g.op("Cast", output_size, to_i=_C_onnx.TensorProtoDataType.FLOAT) + divisor = _slice_helper( + g, g.op("Shape", input), axes=[0], ends=[sys.maxsize], starts=[offset] + ) + divisor = g.op("Cast", divisor, to_i=_C_onnx.TensorProtoDataType.FLOAT) + scale_dims = g.op("Div", dividend, divisor) + scales = g.op("Concat", offsets, scale_dims, axis_i=0) + else: + scales_constant = [ + 1.0 + if i < 2 + else float(output_size[-(dim - i)]) + / float(input.type().sizes()[-(dim - i)]) + for i in range(0, dim) + ] + scales = g.op( + "Constant", value_t=torch.tensor(scales_constant, dtype=torch.float32) + ) + return scales + + +def _interpolate_get_scales_if_available(g: jit_utils.GraphContext, scales): + available_scales = _maybe_get_const(scales[0], "fs") != -1 and not _is_none( + scales[0] + ) + + if not available_scales: + return None + + offsets = g.op("Constant", value_t=torch.ones(2, dtype=torch.float32)) + scales_list = g.op( + "Constant", value_t=torch.tensor(_maybe_get_const(scales[0], "fs")) + ) + scales = g.op("Concat", offsets, scales_list, axis_i=0) + return scales + + +def _get_interpolate_attributes(g: jit_utils.GraphContext, mode, args): + if mode == "nearest": + align_corners = None + scales = args[0:] + else: + align_corners = args[0] + scales = args[1:] + scales = _interpolate_get_scales_if_available(g, scales) + return scales, align_corners + + +def _interpolate_get_scales(g: jit_utils.GraphContext, scale_factor, dim): + offsets = g.op("Constant", value_t=torch.ones(2, dtype=torch.float32)) + scale_factor_rank = _get_tensor_rank(scale_factor) + if isinstance(scale_factor.type(), _C.ListType) or ( + scale_factor_rank is not None and scale_factor_rank > 0 + ): + return g.op("Concat", offsets, scale_factor, axis_i=0) + else: + scale_factor = _unsqueeze_helper(g, scale_factor, [0]) + scale_factor = g.op( + "Cast", scale_factor, to_i=_C_onnx.TensorProtoDataType.FLOAT + ) + scales = [scale_factor for i in range(dim - 2)] + scale_factor = g.op("Concat", offsets, *scales, axis_i=0) + return scale_factor + + +def _interpolate_get_scales_and_mode( + g: jit_utils.GraphContext, input, size, scale_factor, mode, align_corners +): + mode = _maybe_get_const(mode, "s") + if "linear" in mode: + mode = "linear" + if "cubic" in mode: + mode = "cubic" + _interpolate_warning(mode) + + align_corners = _maybe_get_const(align_corners, "b") + if isinstance(align_corners, bool) and align_corners: + return _unimplemented("interpolate", "align_corners == True") + + if not input.type().dim(): + return _unimplemented("interpolate", "missing input shape") + dim = input.type().dim() + + if not _is_none(scale_factor): + scale_factor = _interpolate_get_scales(g, scale_factor, dim) + elif not _is_none(size): + if not _is_packed_list(size): + is_scalar = _maybe_get_const(size, "t").dim() == 0 + if is_scalar: + size = _unsqueeze_helper(g, size, [0]) + size = [size for i in range(dim - 2)] + size = g.op("Concat", *size, axis_i=0) + scale_factor = _interpolate_size_to_scales(g, input, size, dim) + else: + return _unimplemented( + "interpolate", "Both size and scales are None in __interpolate" + ) + return scale_factor, mode + + +def _argmin_argmax_helper( + g: jit_utils.GraphContext, + input: torch._C.Value, + dim: torch._C.Value, + keepdim: bool, + op_name: str, +): + def op_wrapper(input, axis_i, keepdims_i): + if g.opset >= 12: + return g.op( + op_name, + input, + axis_i=axis_i, + keepdims_i=keepdims_i, + select_last_index_i=False, + ) + return g.op(op_name, input, axis_i=axis_i, keepdims_i=keepdims_i) + + if _is_none(dim): + flattened = _reshape_helper( + g, input, g.op("Constant", value_t=torch.tensor([-1])) + ) + output = op_wrapper(flattened, axis_i=0, keepdims_i=False) + if keepdim: + input_shape = g.op("Shape", input) + input_shape_shape = g.op("Shape", input_shape) + new_shape = g.op( + "ConstantOfShape", + input_shape_shape, + value_t=torch.tensor([1], dtype=torch.int64), + ) + output = g.op("Reshape", output, new_shape) + return output + + dim = _parse_arg(dim, "i") + return op_wrapper(input, axis_i=dim, keepdims_i=keepdim) + + +def _interpolate_helper(name, dim, interpolate_mode): + @quantized_args(True, False, False) + def symbolic_fn(g, input, output_size, *args): + scales, align_corners = _get_interpolate_attributes(g, interpolate_mode, args) + align_corners = _maybe_get_scalar(align_corners) + coordinate_transformation_mode = ( + "asymmetric" + if interpolate_mode == "nearest" + else "align_corners" + if align_corners + else "half_pixel" + ) + + if scales is None: + input_size = g.op("Shape", input) + input_size_beg = _slice_helper( + g, input_size, axes=[0], ends=[2], starts=[0] + ) + output_size = g.op( + "Cast", output_size, to_i=_C_onnx.TensorProtoDataType.INT64 + ) + output_size = g.op("Concat", input_size_beg, output_size, axis_i=0) + + if g.opset >= 13: + empty_roi = _optional_input_placeholder_tensor(g) + empty_scales = _optional_input_placeholder_tensor(g) + else: + empty_roi = g.op( + "Constant", value_t=torch.tensor([], dtype=torch.float32) + ) + empty_scales = g.op( + "Constant", value_t=torch.tensor([], dtype=torch.float32) + ) + + return g.op( + "Resize", + input, + empty_roi, + empty_scales, + output_size, + coordinate_transformation_mode_s=coordinate_transformation_mode, + cubic_coeff_a_f=-0.75, # only valid when mode="cubic" + mode_s=interpolate_mode, # nearest, linear, or cubic + nearest_mode_s="floor", + ) # only valid when mode="nearest" + else: + if g.opset >= 13: + empty_roi = _optional_input_placeholder_tensor(g) + else: + empty_roi = g.op( + "Constant", value_t=torch.tensor([], dtype=torch.float32) + ) + + return g.op( + "Resize", + input, + empty_roi, + scales, + coordinate_transformation_mode_s=coordinate_transformation_mode, + cubic_coeff_a_f=-0.75, # only valid when mode="cubic" + mode_s=interpolate_mode, # nearest, linear, or cubic + nearest_mode_s="floor", + ) # only valid when mode="nearest" + + return symbolic_fn + + +def __interpolate_helper( + g: jit_utils.GraphContext, + input, + size, + scale_factor, + mode, + align_corners, + recompute_scale_factor, +): + mode = _maybe_get_const(mode, "s") + if "linear" in mode: + mode = "linear" + if "cubic" in mode: + mode = "cubic" + align_corners = _maybe_get_const(align_corners, "b") + align_corners = False if not isinstance(align_corners, bool) else align_corners + coordinate_transformation_mode = ( + "asymmetric" + if mode == "nearest" + else "align_corners" + if align_corners + else "half_pixel" + ) + + if not _is_none(size): + input_size = g.op("Shape", input) + input_size = _slice_helper(g, input_size, axes=[0], ends=[2], starts=[0]) + # in some cases size is not a packed list but size is a scalar + # We need to also verify that (_maybe_get_const(size, "t").dim() == 0) + # but this information is not always available. Try to get the dim, + # and if not assume that it is not a scalar. + try: + is_scalar = not _is_packed_list(size) and ( + _maybe_get_const(size, "t").dim() == 0 + ) + except AttributeError: + is_scalar = not _is_packed_list(size) + if not is_scalar: + warnings.warn( + "Cannot verify if the output_size is a scalar " + "while exporting interpolate. Assuming that it is not a scalar." + ) + + if is_scalar: + rank = _get_tensor_rank(input) + if rank is None: + return _unimplemented( + "interpolate (with a scalar output_size)", + "missing input shape (try giving an array of output_size values)", + ) + size = _unsqueeze_helper(g, size, [0]) + size = [size for i in range(rank - 2)] + size = g.op("Concat", *size, axis_i=0) + size = g.op("Cast", size, to_i=_C_onnx.TensorProtoDataType.INT64) + size = g.op("Concat", input_size, size, axis_i=0) + + if g.opset >= 13: + empty_roi = _optional_input_placeholder_tensor(g) + empty_scales = _optional_input_placeholder_tensor(g) + else: + empty_roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32)) + empty_scales = g.op( + "Constant", value_t=torch.tensor([], dtype=torch.float32) + ) + + return g.op( + "Resize", + input, + empty_roi, + empty_scales, + size, + coordinate_transformation_mode_s=coordinate_transformation_mode, + cubic_coeff_a_f=-0.75, # only valid when mode="cubic" + mode_s=mode, # nearest, linear, or cubic + nearest_mode_s="floor", + ) + else: # if not _is_none(scales) + rank = _get_tensor_rank(input) + if rank is None: + return _unimplemented("interpolate (with scales)", "missing input shape") + + if g.opset >= 13: + empty_roi = _optional_input_placeholder_tensor(g) + else: + empty_roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32)) + + scales = _interpolate_get_scales(g, scale_factor, rank) + return g.op( + "Resize", + input, + empty_roi, + scales, + coordinate_transformation_mode_s=coordinate_transformation_mode, + cubic_coeff_a_f=-0.75, # only valid when mode="cubic" + mode_s=mode, # nearest, linear, or cubic + nearest_mode_s="floor", + ) # only valid when mode="nearest" + + +def _unbind_helper(g: jit_utils.GraphContext, self, dim, _outputs): + if g.opset < 11: + from torch.onnx.symbolic_opset9 import unbind + elif g.opset <= 12: + from torch.onnx.symbolic_opset11 import unbind # type: ignore[no-redef] + else: + from torch.onnx.symbolic_opset13 import unbind # type: ignore[no-redef] + return unbind(g, self, dim, _outputs) + + +def _scatter_helper(g: jit_utils.GraphContext, self, dim, index, src): + if g.opset <= 10: + from torch.onnx.symbolic_opset9 import scatter + else: + # for mypy, scatter was imported two lines above + from torch.onnx.symbolic_opset11 import scatter # type: ignore[no-redef] + return scatter(g, self, dim, index, src) + + +def _repeat_interleave_split_helper(g: jit_utils.GraphContext, self, reps, dim): + if g.opset <= 12: + split_out = g.op("Split", self, split_i=[1] * reps, axis_i=dim, outputs=reps) + else: + from torch.onnx.symbolic_opset13 import split + + repeats = g.op("Constant", value_t=torch.tensor([1] * reps)) + split_out = split(g, self, repeats, dim, _outputs=reps) + return split_out if reps > 1 else [split_out] + + +def _repeat_interleave_single_value_repeat_helper( + g: jit_utils.GraphContext, self, repeats, dim +): + from torch.onnx.symbolic_opset9 import flatten, unsqueeze + + if not _is_tensor(repeats): + repeats = g.op("Constant", value_t=torch.LongTensor(repeats)) + + const_repeats: bool = _is_constant(repeats) + reps = _maybe_get_const(repeats, "t") + + # Convert 'repeats' to 1-d if it is 0-d. + if _get_tensor_rank(repeats) == 0: + repeats = g.op("Reshape", repeats, g.op("Constant", value_t=torch.tensor([1]))) + + # Create a new dim of size 1, then expand it to be 'repeats' long, and finally collapse it. + unsqueezed = unsqueeze(g, self, dim + 1) + + # repeats_per_dim is 1 for all dims except for the new unsqueezed dim, where it has value 'repeats'. + if const_repeats: + # 'Repeats' is a constant, 'repeats_per_dim' can be a constant. + onehot = torch.ones(_get_tensor_rank(unsqueezed), dtype=torch.int64) # type: ignore[arg-type] + onehot[dim + 1] = reps + repeats_per_dim = g.op("Constant", value_t=onehot) + else: + # 'Repeats' is a variable, 'repeats_per_dim' cannot be a constant. + onehot = g.op( + "OneHot", + unsqueeze(g, dim + 1, 0), # indices, must be >= 1-dimensional + g.op( + "Constant", value_t=torch.tensor(_get_tensor_rank(unsqueezed)) + ), # depth + g.op( + "Concat", g.op("Constant", value_t=torch.tensor([1])), repeats, axis_i=0 + ), # on/off values + ) + repeats_per_dim = flatten(g, onehot, 0, 1) + + tiled = g.op("Tile", unsqueezed, repeats_per_dim) + return flatten(g, tiled, dim, dim + 1) + + +def _arange_cast_helper( + g: jit_utils.GraphContext, end, start=None, step=None, dtype=None +) -> tuple[ + _type_utils.JitScalarType, + _C.Value | None, + _C.Value | None, + _C.Value | None, +]: + def _is_all_integral(scalars): + for scalar in scalars: + scalar_type = _type_utils.JitScalarType.from_value( + scalar, _type_utils.JitScalarType.UNDEFINED + ) + if ( + scalar_type != _type_utils.JitScalarType.INT64 + and scalar_type != _type_utils.JitScalarType.UNDEFINED + ): + return False + return True + + # This logic is based on torch.arange docs. If "dtype" is provided, + # infer input types from dtype. If not, then check if any of start, stop, + # or step are floating point, and infer the type from get_default. + # Otherwise, the dtype is inferred to be torch.int64. + if dtype is None or (_is_value(dtype) and _is_none(dtype)): + if _is_all_integral([start, end, step]): + scalar_type = _type_utils.JitScalarType.INT64 + else: + scalar_type = _type_utils.JitScalarType.from_dtype( + torch.get_default_dtype() + ) + else: + assert isinstance(dtype, int) + # TODO(justinchuby): Check if dtype is indeed a int. + scalar_type = _type_utils.JitScalarType(dtype) + + start = g.op("Cast", start, to_i=scalar_type.onnx_type()) if start else None + end = g.op("Cast", end, to_i=scalar_type.onnx_type()) if end else None + step = g.op("Cast", step, to_i=scalar_type.onnx_type()) if step else None + return scalar_type, end, start, step + + +def _arange_helper(g: jit_utils.GraphContext, *args): + if g.opset <= 10: + from torch.onnx.symbolic_opset9 import arange + else: + from torch.onnx.symbolic_opset11 import arange # type: ignore[no-redef] + return arange(g, *args) + + +def _size_helper(g: jit_utils.GraphContext, self, dim): + full_shape = g.op("Shape", self) + from torch.onnx.symbolic_opset9 import select + + return select(g, full_shape, g.op("Constant", value_t=torch.tensor([0])), dim) + + +def _index_fill_reshape_helper(g: jit_utils.GraphContext, self, dim, index): + # 1. reshape index => [1, ..., 1, dim, 1, ..., 1] + # 2. expand index => [..., dim, ...], same shape as self except for dim. + # 3. expand value as well. + # 4. apply onnx::scatter. + + from torch.onnx.symbolic_opset9 import expand + + if g.opset <= 10: + from torch.onnx.symbolic_opset9 import scatter + else: + # for mypy, scatter was imported two lines above + from torch.onnx.symbolic_opset11 import scatter # type: ignore[no-redef] + + if self.type().dim() is None: + return _unimplemented("index_fill", "input rank not accessible") + self_dim = self.type().dim() + dim_value = _parse_arg(dim, "i") + if dim_value < 0: + dim_value += self_dim + unsqueezed_index = _unsqueeze_helper( + g, index, [i for i in range(self_dim) if i != dim_value] + ) + expanded_index_shape = scatter( + g, g.op("Shape", self), 0, _unsqueeze_helper(g, dim, [0]), g.op("Shape", index) + ) + expanded_index = expand(g, unsqueezed_index, expanded_index_shape, None) + return expanded_index_shape, expanded_index + + +# By default, when any value in the 'shape' input is equal to zero +# the corresponding dimension value is copied from the input tensor dynamically. +# allowzero=1 indicates that if any value in the 'shape' input is set to zero, +# the zero value is honored, similar to NumPy. +# allowzero=1 is only supported for opset version >= 14. +def _reshape_helper(g: jit_utils.GraphContext, input, shape, allowzero=0): + shape = _maybe_get_const(shape, "is") + if not _is_value(shape): + shape = g.op("Constant", value_t=torch.LongTensor(shape)) + if g.opset <= 13: + if allowzero == 1: + _onnx_opset_unsupported( + "Reshape with allowzero=1", GLOBALS.export_onnx_opset_version, 14, input + ) + return g.op("Reshape", input, shape) + else: + return g.op("Reshape", input, shape, allowzero_i=allowzero) + + +def _batchnorm_helper( + g: jit_utils.GraphContext, input, weight, bias, running_mean, running_var +): + from torch.onnx.symbolic_opset9 import _var_mean + + batch_size = _get_tensor_dim_size(input, 0) + channel_size = _get_tensor_dim_size(input, 1) + + if weight is None or _is_none(weight): + if channel_size is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of batch_norm for unknown channel size.", + input, + ) + weight_value = torch.tensor( + [1.0] * channel_size, + dtype=_type_utils.JitScalarType.from_value(input).dtype(), + ) + weight = g.op("Constant", value_t=weight_value) + if bias is None or _is_none(bias): + if channel_size is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of batch_norm for unknown channel size.", + input, + ) + bias_value = torch.tensor( + [0.0] * channel_size, + dtype=_type_utils.JitScalarType.from_value(input).dtype(), + ) + bias = g.op("Constant", value_t=bias_value) + # If track_running_stats is set to False batch statistics are instead used during evaluation time + if ( + running_mean is None + or _is_none(running_mean) + or running_var is None + or _is_none(running_var) + ): + assert batch_size is not None and channel_size is not None + reshape_in = _reshape_helper( + g, + input, + g.op( + "Constant", + value_t=torch.tensor([batch_size, channel_size, -1], dtype=torch.int64), + ), + ) + trans_in = g.op("Transpose", reshape_in, perm_i=[0, 2, 1]) + running_var, running_mean = _var_mean( + g, + trans_in, + g.op("Constant", value_t=torch.tensor([0, 1], dtype=torch.int64)), + False, + False, + ) + return weight, bias, running_mean, running_var + + +def _avgpool_helper( + tuple_fn: Callable[[Any], Sequence[int]], + padding: int | Sequence[int], + kernel_size, + stride, + divisor_override, + name, +) -> tuple[int, ...]: + if divisor_override and divisor_override.node().kind() != "prim::Constant": + _unimplemented(name, "divisor_override") + return tuple(tuple_fn(padding)) + + +def check_training_mode(op_train_mode: int, op_name: str) -> None: + """Warns the user if the model's training mode and the export mode do not agree.""" + if GLOBALS.training_mode == _C_onnx.TrainingMode.PRESERVE: + return + + if op_train_mode: + op_mode_enum = _C_onnx.TrainingMode.TRAINING + else: + op_mode_enum = _C_onnx.TrainingMode.EVAL + if op_mode_enum == GLOBALS.training_mode: + # The modes agree. Do nothing + return + + op_mode_text = f"train={bool(op_train_mode)}" + # Setting the model mode could result in op_mode != GLOBALS.training_mode + # if the model is a FuncModule. In this case we warn the user of + # the state and export depending on op_mode + # This is to support use-cases of fixing certain layer weights + # in training. + warnings.warn( + f"ONNX export mode is set to {GLOBALS.training_mode}, but operator '{op_name}' " + f"is set to {op_mode_text}. Exporting with {op_mode_text}." + ) + + +def _flatten_helper(g: jit_utils.GraphContext, input, start_dim, end_dim, dim): + input_size = g.op("Shape", input) + slice1 = _slice_helper(g, input_size, axes=[0], starts=[0], ends=[start_dim]) + slices = [slice1, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long))] + if end_dim < dim - 1: + slice3 = _slice_helper( + g, input_size, axes=[0], starts=[end_dim + 1], ends=[dim] + ) + slices = [ + slice1, + g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)), + slice3, + ] + + final_shape = g.op("Concat", *slices, axis_i=0) + from torch.onnx.symbolic_opset9 import _reshape_from_tensor + + return _reshape_from_tensor(g, input, final_shape) + + +def _is_split_static(split_size_or_sizes, _outputs): + if _outputs is None: + return False + if ( + _is_value(split_size_or_sizes) + and split_size_or_sizes.node().kind() != "onnx::Constant" + ): + return False + return True + + +def _optional_input_placeholder_tensor(g): + n = g.op("prim::Constant") + n.setType(_C.OptionalType.ofTensor()) + return n + + +def _handle_reduce_dim_none(g: jit_utils.GraphContext, self, op_name): + rank = _get_tensor_rank(self) + if rank is not None and any( + _get_tensor_dim_size(self, i) == 0 for i in range(rank) + ): + # If input tensor is empty, according to ONNX ReduceSum definition, + # set keepdims=1 so that the resulted tensor has the same rank as the input. + return g.op(op_name, self, keepdims_i=1) + return g.op(op_name, self, keepdims_i=0) + + +def dequantize_helper( + g: jit_utils.GraphContext, + qtensor: _C.Value, + qdtype: _C_onnx.TensorProtoDataType | None = None, +) -> tuple[_C.Value, _C.Value, _C.Value, _C.Value | None]: + """Appends to graph `g` ONNX nodes that dequantizes `qtensor` into `tensor`. + + Args: + g: Graph, the ONNX IR graph that is under construction. + qtensor: torch._C.Value, either a tuple of (quantized_tensor, scale, zero_point) + for per tensor quantization, or + (quantized_tensor, scale, zero_point, axis) for per channel quantization, + representing the quantized tensor. + qdtype: torch.onnx.TensorProtoDataType default None, if not None, represents the + data type of quantized tensor. It must be either + torch.onnx.TensorProtoDataType.UINT8 or torch.onnx.TensorProtoDataType.INT8. + """ + unpacked_qtensors = _unpack_quantized_tensor(qtensor) + tensor, scale, zero_point = unpacked_qtensors[:3] + axis = unpacked_qtensors[3] if len(unpacked_qtensors) >= 4 else None + axis_i = _get_const(axis, "i", "axis") + input_qdtype = _type_utils.JitScalarType.from_value(tensor) + if qdtype is None: + if input_qdtype is not None: + qdtype = input_qdtype.onnx_type() + else: + qdtype = _C_onnx.TensorProtoDataType.UINT8 + value = g.op("Cast", tensor, to_i=qdtype) + scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT) + zero_point = g.op("Cast", zero_point, to_i=qdtype) + + if axis_i is not None and GLOBALS.export_onnx_opset_version < 13: + _onnx_opset_unsupported_detailed( + "DequantizeLinear", + GLOBALS.export_onnx_opset_version, + 13, + "Attribute axis is not supported.", + qtensor, + ) + + return ( + g.op("DequantizeLinear", value, scale, zero_point, axis_i=axis_i), + scale, + zero_point, + axis, + ) + + +def quantize_helper( + g: jit_utils.GraphContext, + tensor: _C.Value, + scale: _C.Value, + zero_point: _C.Value, + axis: _C.Value | None = None, +) -> _C.Value: + """Appends to graph `g` ONNX nodes that quantizes `tensor` based on `scale`, `zero_point` and `axis`. + + Args: + g: Graph, the ONNX IR graph that is under construction. + tensor: torch._C.Value, representing the tensor to be quantized. + scale: torch._C.Value, quantized scale. + zero_point: torch._C.Value, quantized zero point. + axis: Optional[torch._C.Value] default None, if None, represents per tensor quantization. + Otherwise, represents per channel quantization, along given axis. + + Returns: + A TupleConstruct storing information of the quantized tensor. + """ + if ( + axis is not None + and not _is_none(axis) + and GLOBALS.export_onnx_opset_version < 13 + ): + _onnx_opset_unsupported_detailed( + "QuantizeLinear", + GLOBALS.export_onnx_opset_version, + 13, + "Attribute axis is not supported.", + tensor, + ) + + assert scale is not None + if ( + _type_utils.JitScalarType.from_value(scale, _type_utils.JitScalarType.UNDEFINED) + != _type_utils.JitScalarType.FLOAT + ): + scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT) + + assert zero_point is not None + if _type_utils.JitScalarType.from_value( + zero_point, _type_utils.JitScalarType.UNDEFINED + ) not in { + _type_utils.JitScalarType.UINT8, + _type_utils.JitScalarType.INT8, + }: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8) + output = g.op( + "QuantizeLinear", + tensor, + scale, + zero_point, + axis_i=_get_const(axis, "i", "axis"), + ) + args = [output, scale, zero_point] + if axis is not None and not _is_none(axis): + args.append(axis) + return g.op("prim::TupleConstruct", *args) + + +def requantize_bias_helper( + g: jit_utils.GraphContext, bias, input_scale, weight_scale, axis=None +): + """In PyTorch, bias is float and is quantized to int32 implicitly inside the quantized ATen op kernel. + In ONNX we need to make the quantization explicit because operators expect all of their inputs to be quantized. + Since int32 is not a supported output type by ONNX operator `QuantizeLinear`, quantization is exported using + regular operators. + """ + bias_scale = g.op("Mul", weight_scale, input_scale) + bias_scale_shape = g.op("Shape", bias_scale) + bias_zero_point = g.op( + "ConstantOfShape", bias_scale_shape, value_t=torch.tensor([0], dtype=torch.int) + ) + q_bias = g.op( + "Cast", g.op("Div", bias, bias_scale), to_i=_C_onnx.TensorProtoDataType.INT32 + ) + axis_args = [] + if axis is not None and not _is_none(axis): + axis_args.append(axis) + return g.op("prim::TupleConstruct", q_bias, bias_scale, bias_zero_point, *axis_args) + + +def args_have_same_dtype(args): + assert args + base_dtype = _type_utils.JitScalarType.from_value(args[0]) + has_same_dtype = all( + _type_utils.JitScalarType.from_value(elem) == base_dtype for elem in args + ) + return has_same_dtype + + +def _op_with_optional_float_cast(g: jit_utils.GraphContext, op_name, *args, **kwargs): + """Some PyTorch operators (e.g., Clip/Min/ReLU/Pad) are super set of ONNX in terms of data types. + This function maximizes the exportability of PyTorch-ONNX by allowing ONNX-unsupported PyTorch + operator data type. For example, `Cast(Clip(Cast(INPUT)))` can be used to mimic + `Clip(INPUT)` (opset version < 12). + + Args: + g (torch._C.Graph): graph to write the ONNX representation into. + op_name (str): operator name in ONNX. + *args (tuple): operands to the operator. + **kwargs (dict): attributes to the operator along with "opset_before" (optional, None by default) + indicating the smallest opset version to trigger such casting behavior and "target_float_t" + (optional, torch.onnx.JitScalarType.FLOAT by default) indicating the data type of internal operator. + + Returns: + Optional[torch._C.Value, Tuple[torch._C.Value, ...]]: output(s) of the operator. + """ + opset_before = kwargs.pop("opset_before", None) + target_float_t = kwargs.pop("target_float_t", _type_utils.JitScalarType.FLOAT) + + inputs = list(args) + dtype_0 = _type_utils.JitScalarType.from_value(inputs[0]) + + require_cast = not _is_fp(inputs[0]) and ( + opset_before is None or GLOBALS.export_onnx_opset_version < opset_before + ) + + if require_cast: + for input in inputs: + if input.isCompleteTensor(): + input_scalar_type = _type_utils.JitScalarType.from_value(input) + if input_scalar_type != dtype_0: + raise errors.SymbolicValueError( + f"Inputs of {op_name} must have same dtype." + f"Got {dtype_0.scalar_name()} and {input_scalar_type.scalar_name()}", + input, + ) + for i, input in enumerate(inputs): + if input.isCompleteTensor() and not _is_fp(input): + inputs[i] = g.op( + "Cast", + input, + to_i=target_float_t.onnx_type(), + ) + + self = g.op(op_name, *inputs, **kwargs) + + if require_cast: + self = g.op("Cast", self, to_i=dtype_0.onnx_type()) + + return self + + +def _maybe_cast_reduce_op_input(g: jit_utils.GraphContext, self): + scalar_type = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.UNDEFINED + ) + if scalar_type != _type_utils.JitScalarType.UNDEFINED: + # This check only covers traced modules where dtype is present + # pytorch reduce-ops cast all other integral types to int64 + if not _is_fp(self) and scalar_type != _type_utils.JitScalarType.INT64: + self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.INT64) + return self + + +def _apply_params(*args, **kwargs): + """Returns a decorator that calls the decorated (higher-order) function with the given parameters.""" + + def _apply(fn): + return fn(*args, **kwargs) + + return _apply + + +def _reduce_op_symbolic_helper(onnx_op_name, allow_multi_dim_support=True): + def symbolic(g, self, dim=None, keepdim=None): + self = _maybe_cast_reduce_op_input(g, self) + if dim is None or dim == (): + # Dim can be 0, which will cause (not dim) == True. So we don't want to do + # (not dim) + # all-reduce path + return _handle_reduce_dim_none(g, self, onnx_op_name) + else: + # dim-reduce path + keepdim = _get_const(keepdim, "i", "keepdim") + if g.opset < 18: + desc = "is" if allow_multi_dim_support else "i" + dim = _get_const(dim, desc, "dim") + dim_list = dim if allow_multi_dim_support else [dim] + return g.op(onnx_op_name, self, axes_i=dim_list, keepdims_i=keepdim) + else: + if _is_value(dim): + axes = dim + else: + if allow_multi_dim_support: + axes = g.op( + "Constant", value_t=torch.tensor(dim, dtype=torch.long) + ) + else: + axes = g.op( + "Constant", value_t=torch.tensor([dim], dtype=torch.long) + ) + return g.op(onnx_op_name, self, axes, keepdims_i=keepdim) + + return symbolic + + +def _overload_by_arg_count(fn): + @functools.wraps(fn) + def wrapper(g, *args): + overloads = fn(g, *args) + for overload in overloads: + arg_descriptors = overload._arg_descriptors + if len(arg_descriptors) == len(args): + return overload(g, *args) + return _unimplemented(f"aten::{fn.__name__}", f"with {len(args)} arguments") + + return wrapper + + +def _reduce_with_dtype_helper( + onnx_op: str, name: str, allow_multi_dim_support: bool = True +): + symbolic = _reduce_op_symbolic_helper( + onnx_op, allow_multi_dim_support=allow_multi_dim_support + ) + + @_overload_by_arg_count + def reduce(g, *args, **kwargs): + @quantized_args(True) + @parse_args("v", "none") + def reduce_nodim(g, self, dtype): + dtype_onnx = None + if dtype.node().kind() == "onnx::Constant": + dtype = _get_const(dtype, "i", "dtype") + dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type() + self = g.op("Cast", self, to_i=dtype_onnx) + elif dtype.node().kind() != "prim::Constant": + return _unimplemented(name, "dtype", dtype) + result = symbolic(g, self) + if dtype_onnx is not None: + result_dtype_onnx = _type_utils.JitScalarType.from_value( + result + ).onnx_type() + if result_dtype_onnx != dtype_onnx: + result = g.op("Cast", result, to_i=dtype_onnx) + return result + + dim_desc = "is" if allow_multi_dim_support else "i" + + @quantized_args(True) + @parse_args("v", dim_desc, "i", "none") # type: ignore[arg-type] + def reduce_dim(g, self, dim, keepdim, dtype): + dtype_onnx = None + if dtype.node().kind() == "onnx::Constant": + dtype = _get_const(dtype, "i", "dtype") + dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type() + self = g.op("Cast", self, to_i=dtype_onnx) + elif dtype.node().kind() != "prim::Constant": + return _unimplemented(name, "dtype", dtype) + result = symbolic(g, self, dim, keepdim) + if dtype_onnx is not None: + result_dtype_onnx = _type_utils.JitScalarType.from_value( + result + ).onnx_type() + if result_dtype_onnx != dtype_onnx: + result = g.op("Cast", result, to_i=dtype_onnx) + return result + + return reduce_nodim, reduce_dim + + return reduce + + +def _max_helper(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): + # torch.max(input) + if dim_or_y is None and keepdim is None: + return g.op("ReduceMax", self, keepdims_i=0) + # torch.max(input, other) + if keepdim is None: + return _op_with_optional_float_cast(g, "Max", self, dim_or_y, opset_before=12) + # torch.max(input, dim, keepdim) + else: + keepdim = _get_const(keepdim, "i", "keepdim") + dim = _get_const(dim_or_y, "i", "dim") + if g.opset < 18: + max = g.op("ReduceMax", self, axes_i=[dim], keepdims_i=keepdim) + else: + axes = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) + max = g.op("ReduceMax", self, axes, keepdims_i=keepdim) + indices = g.op("ArgMax", self, axis_i=dim, keepdims_i=keepdim) + return max, indices + + +def _min_helper(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): + # torch.min(input) + if dim_or_y is None and keepdim is None: + return g.op("ReduceMin", self, keepdims_i=0) + # torch.min(input, other) + if keepdim is None: + return _op_with_optional_float_cast(g, "Min", self, dim_or_y, opset_before=12) + # torch.min(input, dim, keepdim) + else: + keepdim = _get_const(keepdim, "i", "keepdim") + dim = _get_const(dim_or_y, "i", "dim") + if g.opset < 18: + min = g.op("ReduceMin", self, axes_i=[dim], keepdims_i=keepdim) + else: + axes = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) + min = g.op("ReduceMin", self, axes, keepdims_i=keepdim) + indices = g.op("ArgMin", self, axis_i=dim, keepdims_i=keepdim) + return min, indices + + +def _numel_helper(g: jit_utils.GraphContext, self): + shape = g.op("Shape", self) + return g.op("ReduceProd", shape, keepdims_i=0) + + +@parse_args("v", "is", "i", "i") +def _var_mean_helper(g: jit_utils.GraphContext, input, dim, correction, keepdim): + if g.opset < 18: + if dim is None: + mean = g.op("ReduceMean", input, keepdims_i=0) + t_mean = mean + num_elements = _numel_helper(g, input) + else: + mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=keepdim) + t_mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=1) + redudced_dims = g.op("Shape", input) + # dim could contain one or multiple dimensions + redudced_dims = g.op( + "Gather", + redudced_dims, + g.op("Constant", value_t=torch.tensor(dim)), + axis_i=0, + ) + num_elements = g.op("ReduceProd", redudced_dims, keepdims_i=0) + sub_v = g.op("Sub", input, t_mean) + sqr_sub = g.op("Mul", sub_v, sub_v) + keepdim_mean = 0 if dim is None else keepdim + var = g.op("ReduceMean", sqr_sub, axes_i=dim, keepdims_i=keepdim_mean) + # Correct bias in calculating variance, by dividing it over (N - correction) instead on N + if correction is None: + correction = 1 + if correction != 0: + num_elements = g.op( + "Cast", num_elements, to_i=_C_onnx.TensorProtoDataType.FLOAT + ) + one = g.op("Constant", value_t=torch.tensor(correction, dtype=torch.float)) + mul = g.op("Mul", var, num_elements) + var = g.op("Div", mul, g.op("Sub", num_elements, one)) + return var, mean + else: + axes = None + if dim is None: + mean = g.op("ReduceMean", input, keepdims_i=0) + t_mean = mean + num_elements = _numel_helper(g, input) + else: + axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)) + mean = g.op("ReduceMean", input, axes, keepdims_i=keepdim) + t_mean = g.op("ReduceMean", input, axes, keepdims_i=1) + redudced_dims = g.op("Shape", input) + # dim could contain one or multiple dimensions + redudced_dims = g.op( + "Gather", + redudced_dims, + g.op("Constant", value_t=torch.tensor(dim)), + axis_i=0, + ) + num_elements = g.op("ReduceProd", redudced_dims, keepdims_i=0) + sub_v = g.op("Sub", input, t_mean) + sqr_sub = g.op("Mul", sub_v, sub_v) + keepdim_mean = 0 if dim is None else keepdim + if axes is None: + var = g.op("ReduceMean", sqr_sub, keepdims_i=keepdim_mean) + else: + var = g.op("ReduceMean", sqr_sub, axes, keepdims_i=keepdim_mean) + # Correct bias in calculating variance, by dividing it over (N - correction) instead on N + if correction is None: + correction = 1 + if correction != 0: + num_elements = g.op( + "Cast", num_elements, to_i=_C_onnx.TensorProtoDataType.FLOAT + ) + one = g.op("Constant", value_t=torch.tensor(correction, dtype=torch.float)) + mul = g.op("Mul", var, num_elements) + var = g.op("Div", mul, g.op("Sub", num_elements, one)) + return var, mean + + +def _embedding_bag_helper( + g: jit_utils.GraphContext, + embedding_matrix, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, +): + if scale_grad_by_freq and GLOBALS.export_training: + return _onnx_unsupported( + "embedding_bag with scale_grad_by_freq for training mode" + ) + if padding_idx is not None and padding_idx >= 0: + raise RuntimeError("embedding_bag with padding_idx") + + loop_condition = g.op("Constant", value_t=torch.tensor(1)) + loop_condition = g.op("Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL) + zero = g.op("Constant", value_t=torch.tensor([0])) + + indices_len = _unsqueeze_helper( + g, + _size_helper(g, indices, g.op("Constant", value_t=torch.tensor(0))), + [0], + ) + if not include_last_offset: + offsets = [offsets, indices_len] + offsets = g.op("Concat", *offsets, axis_i=0) + + # Offsets holds the starting index position of each bag. So we create a list of the indices slices (determined by + # offsets) and gather those indices in indices_row. Then we use this subset of indices to gather from embeddings. + # The embeddings output is a loop scan output, so we can avoid creating a sequence and inserting elements in. + offsets_starts = _slice_helper( + g, offsets, axes=[0], starts=[0], ends=[sys.maxsize], steps=[1] + ) + offsets_ends = _slice_helper( + g, offsets, axes=[0], starts=[1], ends=[sys.maxsize], steps=[1] + ) + + loop_len = _size_helper(g, offsets_ends, g.op("Constant", value_t=torch.tensor(0))) + + loop, (loop_context,), _ = jit_utils.add_op_with_blocks( + g, "Loop", loop_len, loop_condition, n_blocks=1 + ) + loop_block = loop_context.block + + # FIXME(justinchuby): We need to handle what happens when we call b.op on a node return + block_input_iter = utils._add_input_to_block(loop_block) + utils._add_input_to_block(loop_block) + + indices_start = loop_context.op( + "Gather", offsets_starts, block_input_iter, axis_i=0 + ) + indices_end = loop_context.op("Gather", offsets_ends, block_input_iter, axis_i=0) + indices_start = _unsqueeze_helper(loop_context, indices_start, [0]) + indices_end = _unsqueeze_helper(loop_context, indices_end, [0]) + + indices_row = loop_context.op("Slice", indices, indices_start, indices_end, zero) + embeddings = loop_context.op("Gather", embedding_matrix, indices_row, axis_i=0) + if not _is_none(per_sample_weights): + per_sample_weights_row = loop_context.op( + "Slice", per_sample_weights, indices_start, indices_end, zero + ) + per_sample_weights_row = _unsqueeze_helper( + loop_context, per_sample_weights_row, [1] + ) + embeddings = loop_context.op("Mul", embeddings, per_sample_weights_row) + if mode == 0: + embeddings = _reducesum_helper( + loop_context, embeddings, axes_i=[0], keepdims_i=0 + ) + elif mode == 1: + if loop_context.opset < 18: + embeddings = loop_context.op( + "ReduceMean", embeddings, axes_i=[0], keepdims_i=0 + ) + else: + axes = loop_context.op( + "Constant", value_t=torch.tensor([0], dtype=torch.long) + ) + embeddings = loop_context.op("ReduceMean", embeddings, axes, keepdims_i=0) + else: + if loop_context.opset < 18: + embeddings = loop_context.op( + "ReduceMax", embeddings, axes_i=[0], keepdims_i=0 + ) + else: + axes = loop_context.op( + "Constant", value_t=torch.tensor([0], dtype=torch.long) + ) + embeddings = loop_context.op("ReduceMax", embeddings, axes, keepdims_i=0) + + cond_out = loop_context.op( + "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL + ) + utils._add_output_to_block(loop_block, cond_out) + utils._add_output_to_block(loop_block, embeddings) + + # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices. + # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag. + return loop.node().output(), None, None, None + + +def _linalg_vector_norm_helper( + g: jit_utils.GraphContext, + self: torch._C.Value, + ord: float, + dim: Sequence[int] | None, + keepdim: bool, + dtype: torch._C.Value, +): + axes = None + # Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.vector_norm.html + if _is_none(dim): + self = _reshape_helper(g, self, [-1]) + keepdim = False + elif g.opset >= 18: + axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)) + + if ord == math.inf: + if g.opset < 18: + result = g.op( + "ReduceMax", g.op("Abs", self), axes_i=dim, keepdims_i=keepdim + ) + else: + if axes is None: + result = g.op("ReduceMax", g.op("Abs", self), keepdims_i=keepdim) + else: + result = g.op("ReduceMax", g.op("Abs", self), axes, keepdims_i=keepdim) + elif ord == -math.inf: + if g.opset < 18: + result = g.op( + "ReduceMin", g.op("Abs", self), axes_i=dim, keepdims_i=keepdim + ) + else: + if axes is None: + result = g.op("ReduceMin", g.op("Abs", self), keepdims_i=keepdim) + else: + result = g.op("ReduceMin", g.op("Abs", self), axes, keepdims_i=keepdim) + elif ord == 0: + if g.opset < 11: + return _onnx_opset_unsupported_detailed( + "linalg_vector_norm", 9, 11, "ord=0 not supported", self + ) + else: + if dim is None: + self = _reshape_helper( + g, + self, + g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)), + ) + keepdim = False + + cond_op = g.op( + "Not", + g.op("Equal", self, g.op("Constant", value_t=torch.LongTensor([0]))), + ) + cond_op = g.op( + "Cast", + cond_op, + to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), + ) + return _reducesum_helper(g, cond_op, axes_i=dim, keepdims_i=keepdim) + elif ord == 1: + if g.opset < 18: + result = _reduce_op_symbolic_helper("ReduceL1")( + g, self, dim=dim, keepdim=keepdim + ) + else: + if axes is None: + result = _reduce_op_symbolic_helper("ReduceL1")( + g, self, keepdim=keepdim + ) + else: + result = _reduce_op_symbolic_helper("ReduceL1")( + g, self, axes, keepdim=keepdim + ) + elif ord == 2: + if g.opset < 18: + result = _reduce_op_symbolic_helper("ReduceL2")( + g, self, dim=dim, keepdim=keepdim + ) + else: + if axes is None: + result = _reduce_op_symbolic_helper("ReduceL2")( + g, self, keepdim=keepdim + ) + else: + result = _reduce_op_symbolic_helper("ReduceL2")( + g, self, axes, keepdim=keepdim + ) + else: + ord_op = g.op("Constant", value_t=torch.tensor(ord, dtype=torch.float32)) + result = _reducesum_helper( + g, g.op("Pow", g.op("Abs", self), ord_op), axes_i=dim, keepdims_i=keepdim + ) + result = g.op( + "Pow", + result, + g.op( + "Div", + g.op("Constant", value_t=torch.tensor(1, dtype=torch.float32)), + ord_op, + ), + ) + + if not _is_none(dtype): + dtype = _get_const(dtype, "i", "dtype") + result = g.op("Cast", result, to_i=_type_utils.JitScalarType(dtype).onnx_type()) # type: ignore[arg-type] + return result + + +# Deprecated. Internally use _type_utils.ScalarType +# TODO: remove these once we support Type's in the JIT IR and we can once again +# use the unified toType operator +cast_pytorch_to_onnx = { + "Byte": _C_onnx.TensorProtoDataType.UINT8, + "Char": _C_onnx.TensorProtoDataType.INT8, + "Double": _C_onnx.TensorProtoDataType.DOUBLE, + "Float": _C_onnx.TensorProtoDataType.FLOAT, + "Half": _C_onnx.TensorProtoDataType.FLOAT16, + "Int": _C_onnx.TensorProtoDataType.INT32, + "Long": _C_onnx.TensorProtoDataType.INT64, + "Short": _C_onnx.TensorProtoDataType.INT16, + "Bool": _C_onnx.TensorProtoDataType.BOOL, + "ComplexFloat": _C_onnx.TensorProtoDataType.COMPLEX64, + "ComplexDouble": _C_onnx.TensorProtoDataType.COMPLEX128, + "BFloat16": _C_onnx.TensorProtoDataType.BFLOAT16, + "Undefined": _C_onnx.TensorProtoDataType.UNDEFINED, +} + +# Deprecated. Internally use _type_utils.ScalarType +scalar_name_to_pytorch = { + "uint8_t": "Byte", + "int8_t": "Char", + "double": "Double", + "float": "Float", + "half": "Half", + "int": "Int", + "int64_t": "Long", + "int16_t": "Short", + "bool": "Bool", + "complex64": "ComplexFloat", + "complex128": "ComplexDouble", + "qint8": "QInt8", + "quint8": "QUInt8", + "qint32": "QInt32", + "bfloat16": "BFloat16", +} + + +# Deprecated. Internally use _type_utils.ScalarType +# This indicates each scalar type's corresponding +# torch type. Related source: +# https://github.com/pytorch/pytorch/blob/344defc9733a45fee8d0c4d3f5530f631e823196/c10/core/ScalarType.h +scalar_type_to_pytorch_type = [ + torch.uint8, # 0 + torch.int8, # 1 + torch.short, # 2 + torch.int, # 3 + torch.int64, # 4 + torch.half, # 5 + torch.float, # 6 + torch.double, # 7 + torch.complex32, # 8 + torch.complex64, # 9 + torch.complex128, # 10 + torch.bool, # 11 + torch.qint8, # 12 + torch.quint8, # 13 + torch.qint32, # 14 + torch.bfloat16, # 15 +] + +# Deprecated. Internally use _type_utils.ScalarType +# source of truth is +# https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_dtypes.cpp +pytorch_name_to_type = { + "Byte": torch.uint8, + "Char": torch.int8, + "Double": torch.double, + "Float": torch.float, + "Half": torch.half, + "Int": torch.int, + "Long": torch.int64, + "Short": torch.short, + "Bool": torch.bool, + "ComplexFloat": torch.complex64, + "ComplexDouble": torch.complex128, + "QInt8": torch.qint8, + "QUInt8": torch.quint8, + "QInt32": torch.qint32, + "BFloat16": torch.bfloat16, +} + + +# Deprecated. Internally use _type_utils.ScalarType +scalar_type_to_onnx = [ + cast_pytorch_to_onnx["Byte"], # 0 + cast_pytorch_to_onnx["Char"], # 1 + cast_pytorch_to_onnx["Short"], # 2 + cast_pytorch_to_onnx["Int"], # 3 + cast_pytorch_to_onnx["Long"], # 4 + cast_pytorch_to_onnx["Half"], # 5 + cast_pytorch_to_onnx["Float"], # 6 + cast_pytorch_to_onnx["Double"], # 7 + cast_pytorch_to_onnx["Undefined"], # 8 + cast_pytorch_to_onnx["ComplexFloat"], # 9 + cast_pytorch_to_onnx["ComplexDouble"], # 10 + cast_pytorch_to_onnx["Bool"], # 11 + cast_pytorch_to_onnx["Char"], # 12 + cast_pytorch_to_onnx["Byte"], # 13 + cast_pytorch_to_onnx["Int"], # 14 + cast_pytorch_to_onnx["BFloat16"], # 15 +] + +# Global set to store the list of quantized operators in the network. +# This is currently only used in the conversion of quantized ops from PT -> C2 via ONNX. +_quantized_ops: set[int] = set() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/onnx/symbolic_opset10.py b/torch/onnx/symbolic_opset10.py index 9bda69b81ab60..e82b1324b67c2 100644 --- a/torch/onnx/symbolic_opset10.py +++ b/torch/onnx/symbolic_opset10.py @@ -1,3 +1,4 @@ +<<<<<<< HEAD """Backward compatibility module for torch.onnx.symbolic_opset10.""" from __future__ import annotations @@ -9,3 +10,1193 @@ from torch.onnx._internal.torchscript_exporter.symbolic_opset10 import ( # noqa: F401 _slice, ) +======= +# mypy: allow-untyped-defs +# mypy: disable-error-code=arg-type +from __future__ import annotations + +import functools +import sys +import warnings +from typing import TYPE_CHECKING + +import torch +import torch._C._onnx as _C_onnx +import torch.onnx +from torch import _C + +# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics +from torch.onnx import ( + _constants, + _type_utils, + errors, + symbolic_helper, + symbolic_opset9 as opset9, +) +from torch.onnx._globals import GLOBALS +from torch.onnx._internal import jit_utils, registration + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in README.md + +# This file exports ONNX ops for opset 10 +# Opset 10 is supported by ONNX release 1.5.0 +# release on 04/24/19 + + +__all__ = [ + "dequantize", + "div", + "embedding_bag", + "fake_quantize_per_tensor_affine", + "flip", + "fmod", + "isfinite", + "isinf", + "nan_to_num", + "quantize_per_tensor", + "quantized_add_relu", + "quantized_add", + "quantized_cat", + "quantized_conv1d_relu", + "quantized_conv2d_relu", + "quantized_conv3d_relu", + "quantized_conv1d", + "quantized_conv2d", + "quantized_conv3d", + "quantized_conv_transpose1d", + "quantized_conv_transpose2d", + "quantized_conv_transpose3d", + "quantized_group_norm", + "quantized_hardswish", + "quantized_instance_norm", + "quantized_layer_norm", + "quantized_leaky_relu", + "quantized_linear", + "quantized_linear_relu", + "quantized_mul", + "quantized_sigmoid", + "slice", + "sort", + "topk", +] + + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=10) + + +@_onnx_symbolic("aten::div") +def div(g: jit_utils.GraphContext, self, other, *args): + if len(args) == 0: + return opset9.true_divide(g, self, other) + else: + return _div_rounding_mode(g, self, other, *args) + + +@symbolic_helper.parse_args("v", "v", "s") +def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode): + if rounding_mode == "floor": + return _floor_divide(g, self, other) + else: + return opset9._div_rounding_mode(g, self, other, rounding_mode) + + +@_onnx_symbolic("aten::_floor_divide") +def _floor_divide(g: jit_utils.GraphContext, self, other): + if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other): + out = opset9.true_divide(g, self, other) + return g.op("Floor", out) + else: + # Integer division does trunction rounding + div = g.op("Div", self, other) + # Division is negative if: self < 0 != other < 0 + zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)) + negative = g.op("Xor", g.op("Less", self, zero), g.op("Less", other, zero)) + + # For negative numbers with self % other != 0, subtract 1 to round down instead of up + mod = g.op("Mod", self, other, fmod_i=0) + fixup_mask = g.op("And", negative, g.op("Not", g.op("Equal", mod, zero))) + + one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) + fixup = g.op("Sub", div, one) + return g.op("Where", fixup_mask, fixup, div) + + +@_onnx_symbolic("aten::sort") +@symbolic_helper.parse_args("v", "i", "i", "none") +def sort(g: jit_utils.GraphContext, self, dim, decending, out=None): + return symbolic_helper._sort_helper(g, self, dim, decending=decending, out=out) + + +@_onnx_symbolic("aten::topk") +@symbolic_helper.parse_args("v", "v", "i", "i", "i", "none") +def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None): + return symbolic_helper._topk_helper( + g, self, k, dim, largest=largest, sorted=sorted, out=out + ) + + +def _aten_max_pool_onnx( + g: jit_utils.GraphContext, + self: _C.Value, + kernel_shape: Sequence[int], + strides: Sequence[int], + pads: Sequence[int], + dilations: Sequence[int], + ceil_mode: bool, + unbatched_rank: int, +) -> _C.Value: + self_rank = g.op("Size", g.op("Shape", self)) + if self_rank == unbatched_rank: # C,H,W -> N,C,H,W and N=1 + self = g.op( + "Unsqueeze", + self, + g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)), + ) + + pool_result, _ = g.op( + "MaxPool", + self, + outputs=2, + ceil_mode_i=ceil_mode, + dilations_i=dilations, + kernel_shape_i=kernel_shape, + pads_i=pads, + strides_i=strides, + ) + + if self_rank == unbatched_rank: + pool_result = g.op( + "Squeeze", + pool_result, + g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)), + ) + + return pool_result + + +# For MaxPool +def _adjust_attributes_of_max_pool( + expand_size: int, + kernel_size: Sequence[int] | int, + stride: Sequence[int] | int, + padding: Sequence[int] | int, + dilation: Sequence[int] | int, +) -> tuple[Sequence[int], Sequence[int], Sequence[int], Sequence[int]]: + """Adjust attributes of avg_pool to match ONNX specification.""" + + if isinstance(dilation, int): + dilation = [dilation] * expand_size + + if isinstance(kernel_size, int): + kernel_shape = [kernel_size] * expand_size + else: + kernel_shape = kernel_size # type: ignore[assignment] + + if isinstance(padding, int): + pads = [padding] * expand_size * 2 # type: ignore[operator, assignment] + elif len(padding) == 1: + pads = padding * expand_size * 2 # type: ignore[operator, assignment] + elif len(padding) == 2: + # 2D padding + pads = padding * 2 # type: ignore[operator, assignment] + elif len(padding) == 3: + # 3D padding + pads = padding * 2 # type: ignore[operator, assignment] + else: + # When padding is already done for all dimensions, + # we don't need to double it + # eg: (1, 1, 1, 1, 1, 1) + pads = padding # type: ignore[assignment] + + if isinstance(stride, int): + strides = [stride] * expand_size + elif not stride: + strides = kernel_shape + else: + strides = stride # type: ignore[assignment] + + return (kernel_shape, strides, pads, dilation) + + +def _aten_max_pool_with_indices_onnx( + g: jit_utils.GraphContext, + self: _C.Value, + kernel_shape: Sequence[int], + strides: Sequence[int], + pads: Sequence[int], + dilations: Sequence[int], + ceil_mode: bool, + unbatched_rank: int, + n_dims_one: Sequence[int], + n_dims_zero: Sequence[int], + n_dims_axes: Sequence[int], +) -> tuple[_C.Value, Sequence[int]]: + self_rank = g.op("Size", g.op("Shape", self)) + if self_rank == unbatched_rank: # C,H,W -> N,C,H,W and N=1 + self = g.op( + "Unsqueeze", + self, + g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)), + ) + + pool_result, indices = g.op( + "MaxPool", + self, + outputs=2, + ceil_mode_i=ceil_mode, + dilations_i=dilations, + kernel_shape_i=kernel_shape, + pads_i=pads, + strides_i=strides, + ) + _, flatten_indices = g.op( + "MaxPool", + self, + outputs=2, + dilations_i=dilations, + kernel_shape_i=n_dims_one, + strides_i=n_dims_one, + ) + + ends = g.op("Constant", value_t=torch.tensor(n_dims_one)) + starts = g.op("Constant", value_t=torch.tensor(n_dims_zero)) + axes = g.op("Constant", value_t=torch.tensor(n_dims_axes)) + + delta = g.op("Slice", flatten_indices, starts, ends, axes) + indices = g.op("Sub", indices, delta) + + if self_rank == unbatched_rank: + pool_result = g.op( + "Squeeze", pool_result, value_t=torch.tensor([0], dtype=torch.int64) + ) + indices = g.op("Squeeze", indices, value_t=torch.tensor([0], dtype=torch.int64)) + + return (pool_result, indices) + + +@_onnx_symbolic( + "aten::max_pool1d", + decorate=[symbolic_helper._apply_params("max_pool1d", 1, return_indices=False)], +) +@_onnx_symbolic( + "aten::max_pool2d", + decorate=[symbolic_helper._apply_params("max_pool2d", 2, return_indices=False)], +) +@_onnx_symbolic( + "aten::max_pool3d", + decorate=[symbolic_helper._apply_params("max_pool3d", 3, return_indices=False)], +) +@_onnx_symbolic( + "aten::max_pool1d_with_indices", + decorate=[ + symbolic_helper._apply_params( + "max_pool1d_with_indices", + 1, + return_indices=True, + ) + ], +) +@_onnx_symbolic( + "aten::max_pool2d_with_indices", + decorate=[ + symbolic_helper._apply_params( + "max_pool2d_with_indices", + 2, + return_indices=True, + ) + ], +) +@_onnx_symbolic( + "aten::max_pool3d_with_indices", + decorate=[ + symbolic_helper._apply_params( + "max_pool3d_with_indices", + 3, + return_indices=True, + ) + ], +) +def _max_pool(name: str, expand_size: int, return_indices: bool): + @symbolic_helper.quantized_args(True, False, False, False, False, False) + @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i") + def symbolic_fn( + g: jit_utils.GraphContext, + input: _C.Value, + kernel_size: Sequence[int], + stride: Sequence[int], + padding: int | Sequence[int], + dilation: Sequence[int], + ceil_mode: bool, + ): + kernel_shape, strides, pads, dilations = _adjust_attributes_of_max_pool( + expand_size, kernel_size, stride, padding, dilation + ) + + if return_indices: + return _aten_max_pool_with_indices_onnx( + g, + input, + kernel_shape, + strides, + pads, + dilations, + ceil_mode, + expand_size + 1, + ([1] * expand_size), + ([0] * expand_size), + ([2 + i for i in range(expand_size)]), + ) + else: + return _aten_max_pool_onnx( + g, + input, + kernel_shape, + strides, + pads, + dilations, + ceil_mode, + expand_size + 1, + ) + + return symbolic_fn + + +# For AvgPool +def _adjust_attributes_of_avg_pool( + expand_size: int, + kernel_size: Sequence[int] | int, + stride: Sequence[int] | int, + padding: Sequence[int] | int, +) -> tuple[Sequence[int], Sequence[int], Sequence[int]]: + """Adjust attributes of avg_pool to match ONNX specification.""" + + if isinstance(kernel_size, int): + kernel_shape = [kernel_size] * expand_size + else: + kernel_shape = kernel_size # type: ignore[assignment] + + if isinstance(padding, int): + pads = [padding] * expand_size * 2 + elif len(padding) == 1: + pads = padding * expand_size * 2 # type: ignore[operator, assignment] + elif len(padding) == 2: + pads = padding * expand_size # type: ignore[operator, assignment] + else: + pads = padding * 2 # type: ignore[operator, assignment] + + if isinstance(stride, int): + strides = [stride] * expand_size + elif not stride: + strides = kernel_shape + else: + strides = stride # type: ignore[assignment] + + return (kernel_shape, strides, pads) + + +@_onnx_symbolic( + "aten::avg_pool1d", + decorate=[symbolic_helper._apply_params("avg_pool1d", 1)], +) +@_onnx_symbolic( + "aten::avg_pool2d", + decorate=[symbolic_helper._apply_params("avg_pool2d", 2)], +) +@_onnx_symbolic( + "aten::avg_pool3d", + decorate=[symbolic_helper._apply_params("avg_pool3d", 3)], +) +def _avg_pool(name, expand_size): + @symbolic_helper.quantized_args(True, False, False, False, False, False, False) + @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none") + def symbolic_fn( + g, + input: _C.Value, + kernel_size: Sequence[int], + stride: Sequence[int], + padding: int | Sequence[int], + ceil_mode: int, + count_include_pad: int, + divisor_override=None, + ): + kernel_shape, strides, pads = _adjust_attributes_of_avg_pool( + expand_size, kernel_size, stride, padding + ) + + result = g.op( + "AveragePool", + input, + ceil_mode_i=ceil_mode, + count_include_pad_i=count_include_pad, + kernel_shape_i=kernel_shape, + pads_i=pads, + strides_i=strides, + ) + + return result + + return symbolic_fn + + +@_onnx_symbolic( + "aten::upsample_nearest1d", + decorate=[symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_nearest2d", + decorate=[symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_nearest3d", + decorate=[symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_linear1d", + decorate=[symbolic_helper._apply_params("upsample_linear1d", 3, "linear")], +) +@_onnx_symbolic( + "aten::upsample_bilinear2d", + decorate=[symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear")], +) +@_onnx_symbolic( + "aten::upsample_trilinear3d", + decorate=[symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear")], +) +def _interpolate(name, dim, interpolate_mode): + @symbolic_helper.quantized_args(True, False, False) + def symbolic_fn(g, input, output_size, *args): + scales, align_corners = symbolic_helper._get_interpolate_attributes( + g, interpolate_mode, args + ) + symbolic_helper._interpolate_warning(interpolate_mode) + align_corners = symbolic_helper._maybe_get_scalar(align_corners) + if align_corners: + return symbolic_helper._unimplemented(name, "align_corners == True", input) + if scales is None: + scales = symbolic_helper._interpolate_size_to_scales( + g, input, output_size, dim + ) + return g.op("Resize", input, scales, mode_s=interpolate_mode) + + return symbolic_fn + + +@_onnx_symbolic("aten::__interpolate") +def __interpolate( + g: jit_utils.GraphContext, + input, + size, + scale_factor, + mode, + align_corners, + recompute_scale_factor, + antialias, +): + scales, mode = symbolic_helper._interpolate_get_scales_and_mode( + g, input, size, scale_factor, mode, align_corners + ) + return g.op("Resize", input, scales, mode_s=mode) + + +def _slice( + g: jit_utils.GraphContext, + input: torch._C.Value, + axes: list | torch.Tensor | torch._C.Value, + starts: list | torch.Tensor | torch._C.Value, + ends: list | torch.Tensor | torch._C.Value, + steps: list | torch.Tensor | torch._C.Value | None = None, +): + def is_none_value(value): + if value is None: + return True + return ( + isinstance(value, torch._C.Value) + and value.node().kind() == "prim::Constant" + and isinstance(value.type(), _C.NoneType) + ) + + def to_slice_input(list_or_value, default_value=None): + # Convert input param into a 1D torch.Value. + if is_none_value(list_or_value) and default_value is not None: + list_or_value = [default_value] + + if isinstance(list_or_value, (list, torch.Tensor)): + return g.op("Constant", value_t=torch.tensor(list_or_value)) + + rank = symbolic_helper._get_tensor_rank(list_or_value) + if rank == 0: + return symbolic_helper._unsqueeze_helper(g, list_or_value, [0]) + if rank == 1: + return list_or_value + raise errors.SymbolicValueError( + f"Rank must be 0 or 1, not {rank}", list_or_value + ) + + def get_const_value(list_or_value): + if isinstance(list_or_value, (list, torch.Tensor)): + if len(list_or_value) == 1: + return list_or_value[0] + return None + return symbolic_helper._maybe_get_const(list_or_value, "i") + + # Check if slice is a no-op + if ( + get_const_value(starts) == 0 + and get_const_value(ends) == _constants.INT64_MAX + and (steps is None or get_const_value(steps) == 1) + ): + return input + + axes = to_slice_input(axes) + starts = to_slice_input(starts, default_value=0) + ends = to_slice_input(ends, default_value=_constants.INT64_MAX) + if steps is None: + return g.op("Slice", input, starts, ends, axes) + steps = to_slice_input(steps, default_value=1) + return g.op("Slice", input, starts, ends, axes, steps) + + +@_onnx_symbolic("aten::slice") +def slice(g: jit_utils.GraphContext, self, *args): + if len(args) == 4: + # aten::slice(Tensor self, int dim, int? start=None, int? end=None, int step=1) -> Tensor + dims, start, end, step = args + elif len(args) == 3: + # aten::slice(t[] l, int? start=None, int? end=None, int step=1) -> t[] + start, end, step = args + dims = [0] + else: + raise errors.SymbolicValueError("Unknown aten::slice signature", self) + + return symbolic_helper._slice_helper( + g, + self, + axes=dims, + starts=start, + ends=end, + steps=step, + ) + + +@_onnx_symbolic("aten::flip") +@symbolic_helper.parse_args("v", "is") +def flip(g: jit_utils.GraphContext, input, dims): + return symbolic_helper._slice_helper( + g, + input, + axes=dims, + starts=[-1] * len(dims), + ends=[-_constants.INT64_MAX] * len(dims), + steps=[-1] * len(dims), + ) + + +@_onnx_symbolic("aten::fmod") +def fmod(g: jit_utils.GraphContext, input, other): + return g.op("Mod", input, other, fmod_i=1) + + +@_onnx_symbolic("aten::embedding_bag") +@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i") +def embedding_bag( + g: jit_utils.GraphContext, + embedding_matrix, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, +): + if scale_grad_by_freq and GLOBALS.export_training: + return symbolic_helper._onnx_unsupported( + "embedding_bag with scale_grad_by_freq for training mode" + ) + if padding_idx is not None and padding_idx >= 0: + raise RuntimeError("embedding_bag with padding_idx") + + warnings.warn( + "Export of embedding_bag with dynamic input/offsets shape is not supported in opset 10. " + "Please use opset 11 or higher to export model for dynamic input shape.'" + ) + offsets_dim_0 = symbolic_helper._get_tensor_dim_size(offsets, 0) + if offsets_dim_0 is not None: + if include_last_offset: + offset_len = offsets_dim_0 - 1 + offsets_extended = offsets + else: + offset_len = offsets_dim_0 + offsets_extended = [ + offsets, + g.op("Constant", value_t=torch.tensor([sys.maxsize])), + ] + offsets_extended = g.op("Concat", *offsets_extended, axis_i=0) + list_ = [] + for i in range(offset_len): + start_ = symbolic_helper._unsqueeze_helper( + g, + opset9.select(g, offsets_extended, torch.tensor(0), torch.tensor(i)), + [0], + ) + end_ = symbolic_helper._unsqueeze_helper( + g, + opset9.select( + g, offsets_extended, torch.tensor(0), torch.tensor(i + 1) + ), + [0], + ) + axes_ = g.op("Constant", value_t=torch.tensor([0])) + indices_row = g.op("Slice", indices, start_, end_, axes_) + + embeddings = g.op("Gather", embedding_matrix, indices_row) + if not symbolic_helper._is_none(per_sample_weights): + per_sample_weights_row = g.op( + "Slice", per_sample_weights, start_, end_, axes_ + ) + per_sample_weights_row = symbolic_helper._unsqueeze_helper( + g, per_sample_weights_row, [1] + ) + embeddings = g.op("Mul", embeddings, per_sample_weights_row) + if mode == 0: + embeddings = symbolic_helper._reducesum_helper( + g, embeddings, axes_i=[0], keepdims_i=0 + ) + elif mode == 1: + embeddings = g.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0) + else: + embeddings = g.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0) + + embeddings = symbolic_helper._unsqueeze_helper(g, embeddings, [0]) + list_.append(embeddings) + + output = g.op("Concat", *list_, axis_i=0) + # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices. + # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag. + return output, None, None, None + else: + return symbolic_helper._onnx_unsupported( + "embedding_bag with unknown shape of offsets for opset 10 is not supported. " + "please use opset 11 or higher." + ) + + +@_onnx_symbolic("aten::fake_quantize_per_tensor_affine") +@symbolic_helper.parse_args("v", "v", "v", "i", "i") +def fake_quantize_per_tensor_affine( + g: jit_utils.GraphContext, + inputs, + scale, + zero_point, + quant_min=-128, + quant_max=127, +): + # NOTE: (0, 127) is a special case. PyTorch restricts activations to be in the range (0, 127). + # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 + if (quant_min, quant_max) == (0, 127): + symbolic_helper._onnx_opset_unsupported_detailed( + "fake_quantize_per_tensor_affine", + 10, + 13, + "Quantize range (0, 127) not supported, requires opset 13 Clip", + inputs, + ) + if (quant_min, quant_max) not in [(0, 255), (-128, 127)]: + raise errors.SymbolicValueError( + f"For (quant_min, quant_max), ONNX allows only (0, 255) and (-128, 127). " + f"Got ({quant_min}, {quant_max})", + inputs, + ) + scale = symbolic_helper._maybe_get_scalar(scale) + if scale is None: + symbolic_helper._onnx_opset_unsupported_detailed( + "fake_quantize_per_tensor_affine", + 10, + 13, + "Non-constant scale not supported", + inputs, + ) + scale = scale.float().data # Avoid exporter generating double type + if quant_min == 0: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8) + else: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8) + return g.op( + "DequantizeLinear", + g.op("QuantizeLinear", inputs, scale, zero_point), + scale, + zero_point, + ) + + +@_onnx_symbolic("aten::isinf") +def isinf(g: jit_utils.GraphContext, input): + return g.op("IsInf", g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE)) + + +@_onnx_symbolic("aten::isfinite") +def isfinite(g: jit_utils.GraphContext, input): + inf_node = isinf(g, input) + nan_node = opset9.isnan(g, input) + return opset9.__not_(g, opset9.__or_(g, inf_node, nan_node)) + + +@_onnx_symbolic("aten::quantize_per_tensor") +def quantize_per_tensor(g: jit_utils.GraphContext, input, scale, zero_point, dtype): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + # TODO(justinchuby): Extract all the cast ops into a helper function. + zero_point = g.op( + "Cast", zero_point, to_i=_type_utils.JitScalarType(dtype).onnx_type() + ) + scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT) + return symbolic_helper.quantize_helper(g, input, scale, zero_point) + + +@_onnx_symbolic("aten::dequantize") +def dequantize(g: jit_utils.GraphContext, input): + return symbolic_helper.dequantize_helper(g, input)[0] + + +@_onnx_symbolic("aten::nan_to_num") +@symbolic_helper.parse_args("v", "f", "f", "f") +def nan_to_num(g: jit_utils.GraphContext, input, nan, posinf, neginf): + # Cannot create a int type tensor with inf/nan values, so we simply + # return the original tensor + if not symbolic_helper._is_fp(input): + return input + input_dtype = _type_utils.JitScalarType.from_value(input).dtype() + if nan is None: + nan = 0.0 + nan_cond = opset9.isnan(g, input) + nan_result = g.op( + "Where", + nan_cond, + g.op("Constant", value_t=torch.tensor([nan], dtype=input_dtype)), + input, + ) + + # For None values of posinf, neginf we use the greatest/lowest finite + # value representable by input's dtype. + finfo = torch.finfo(input_dtype) + if posinf is None: + posinf = finfo.max + posinf_cond = opset9.logical_and( + g, + isinf(g, nan_result), + opset9.gt(g, nan_result, g.op("Constant", value_t=torch.LongTensor([0]))), + ) + nan_posinf_result = g.op( + "Where", + posinf_cond, + g.op("Constant", value_t=torch.tensor([posinf], dtype=input_dtype)), + nan_result, + ) + + if neginf is None: + neginf = finfo.min + neginf_cond = opset9.logical_and( + g, + isinf(g, nan_posinf_result), + opset9.lt( + g, nan_posinf_result, g.op("Constant", value_t=torch.LongTensor([0])) + ), + ) + return g.op( + "Where", + neginf_cond, + g.op("Constant", value_t=torch.tensor([neginf], dtype=input_dtype)), + nan_posinf_result, + ) + + +# Quantized symbolics --------------------------------------------------------- +# https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export +# Support starts from opset 10 because `DequantizeLinear` and `QuantizeLinear` were +# introduced in opset version 10. +@_onnx_symbolic("quantized::linear") +def quantized_linear( + g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.linear(g, input, weight, bias) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::linear_relu") +def quantized_linear_relu( + g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.linear(g, input, weight, bias) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::add") +def quantized_add(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + y, _, _, _ = symbolic_helper.dequantize_helper(g, y) + + output = opset9.add(g, x, y) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::add_relu") +def quantized_add_relu(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + y, _, _, _ = symbolic_helper.dequantize_helper(g, y) + + output = opset9.add(g, x, y) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::mul") +def quantized_mul(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + y, _, _, _ = symbolic_helper.dequantize_helper(g, y) + + output = opset9.mul(g, x, y) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::hardswish") +def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + + output = opset9.hardswish(g, x) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::sigmoid") +def quantized_sigmoid(g: jit_utils.GraphContext, x, op_scale, op_zero_point): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + + output = opset9.sigmoid(g, x) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::leaky_relu") +def quantized_leaky_relu( + g: jit_utils.GraphContext, x, negative_slope, inplace, op_scale, op_zero_point +): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + + output = opset9.leaky_relu(g, x, negative_slope, inplace) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::layer_norm") +def quantized_layer_norm( + g: jit_utils.GraphContext, + x, + normalized_shape, + weight, + bias, + eps, + op_scale, + op_zero_point, +): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + + output = opset9.layer_norm(g, x, normalized_shape, weight, bias, eps, False) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::group_norm") +def quantized_group_norm( + g: jit_utils.GraphContext, + x, + num_groups, + weight, + bias, + eps, + op_scale, + op_zero_point, +): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + + output = opset9.group_norm(g, x, num_groups, weight, bias, eps, False) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::instance_norm") +@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v") +def quantized_instance_norm( + g: jit_utils.GraphContext, + q_input, + weight, + bias, + eps, + op_scale, + op_zero_point, +): + input, _, _, _ = symbolic_helper.dequantize_helper(g, q_input) + + output = opset9.instance_norm( + g, input, weight, bias, None, None, False, 0.0, eps, False + ) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv1d_relu") +def quantized_conv1d_relu( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv2d_relu") +def quantized_conv2d_relu( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv3d_relu") +def quantized_conv3d_relu( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv3d(g, input, weight, bias, stride, padding, dilation, groups) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv1d") +def quantized_conv1d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv2d") +def quantized_conv2d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv3d") +def quantized_conv3d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv3d(g, input, weight, bias, stride, padding, dilation, groups) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv_transpose1d") +def quantized_conv_transpose1d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + output_padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv_transpose2d( + g, input, weight, bias, stride, padding, output_padding, groups, dilation + ) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv_transpose2d") +def quantized_conv_transpose2d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + output_padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv_transpose2d( + g, input, weight, bias, stride, padding, output_padding, groups, dilation + ) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv_transpose3d") +def quantized_conv_transpose3d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + output_padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv_transpose3d( + g, input, weight, bias, stride, padding, output_padding, groups, dilation + ) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::cat") +@symbolic_helper.parse_args("v", "i", "v", "v") +def quantized_cat( + g: jit_utils.GraphContext, + q_inputs: _C.Value, + dim: int, + op_scale: _C.Value, + op_zero_point: _C.Value, +) -> _C.Value: + unpacked_inputs = symbolic_helper._unpack_list(q_inputs) + dequantized = [ + symbolic_helper.dequantize_helper(g, input)[0] for input in unpacked_inputs + ] + concatenated = g.op("Concat", *dequantized, axis_i=dim) + return symbolic_helper.quantize_helper(g, concatenated, op_scale, op_zero_point) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py index 276ef7209bf69..adfdc0f89dbf5 100644 --- a/torch/onnx/symbolic_opset11.py +++ b/torch/onnx/symbolic_opset11.py @@ -1,3 +1,4 @@ +<<<<<<< HEAD """Backward compatibility module for torch.onnx.symbolic_opset11.""" from __future__ import annotations @@ -6,3 +7,1474 @@ __all__: list[str] = [] from torch.onnx._internal.torchscript_exporter.symbolic_opset11 import * # noqa: F401,F403 +======= +# mypy: allow-untyped-defs +# mypy: disable-error-code=arg-type +"""This file exports ONNX ops for opset 11.""" + +from __future__ import annotations + +import functools +import sys +import warnings +from typing import TYPE_CHECKING + +import torch +from torch import _C +from torch._C import _onnx as _C_onnx +from torch.onnx import ( + _type_utils, + errors, + symbolic_helper, + symbolic_opset10 as opset10, + symbolic_opset9 as opset9, + utils, +) +from torch.onnx._internal import jit_utils, registration + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in README.md + +__all__ = [ + "add", + "append", + "arange", + "argsort", + "atleast_1d", + "atleast_2d", + "atleast_3d", + "cat", + "chunk", + "clamp_max", + "clamp_min", + "clamp", + "constant_pad_nd", + "cumsum", + "Delete", + "embedding_bag", + "embedding_renorm", + "flatten", + "gather", + "hardtanh", + "hstack", + "im2col", + "index_fill", + "index", + "index_copy", + "index_put", + "insert", + "linalg_det", + "linalg_vector_norm", + "logdet", + "masked_scatter", + "masked_select", + "mm", + "narrow", + "normal", + "pad", + "pixel_shuffle", + "pop", + "prim_constant_chunk", + "reflection_pad", + "relu6", + "remainder", + "replication_pad", + "round", + "scatter", + "select", + "size", + "sort", + "split_with_sizes", + "split", + "squeeze", + "stack", + "topk", + "unbind", + "unique_dim", + "unsqueeze", + "vstack", +] + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=11) + + +@_onnx_symbolic("aten::hardtanh") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "f", "f") +def hardtanh(g: jit_utils.GraphContext, self: _C.Value, min_val: float, max_val: float): + scalar_type = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.FLOAT + ) + min_val = g.op( + "Constant", + value_t=torch.tensor(min_val, dtype=scalar_type.dtype()), + ) + max_val = g.op( + "Constant", + value_t=torch.tensor(max_val, dtype=scalar_type.dtype()), + ) + return symbolic_helper._op_with_optional_float_cast( + g, "Clip", self, min_val, max_val, opset_before=12 + ) + + +@_onnx_symbolic("aten::clamp") +def clamp(g: jit_utils.GraphContext, self, min, max): + def _cast_if_not_none(tensor, dtype): + if tensor is not None and not symbolic_helper._is_none(tensor): + return g.op( + "Cast", + tensor, + to_i=dtype.onnx_type(), + ) + else: + return tensor + + scalar_type = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.UNDEFINED + ) + if scalar_type != _type_utils.JitScalarType.UNDEFINED: + min = _cast_if_not_none(min, scalar_type) + max = _cast_if_not_none(max, scalar_type) + + if symbolic_helper._is_none(min): + return clamp_max(g, self, max) + elif symbolic_helper._is_none(max): + return clamp_min(g, self, min) + else: + if ( + symbolic_helper._get_tensor_rank(min) == 0 + and symbolic_helper._get_tensor_rank(max) == 0 + ): + return symbolic_helper._op_with_optional_float_cast( + g, "Clip", self, min, max, opset_before=12 + ) + else: + return clamp_max(g, clamp_min(g, self, min), max) + + +@_onnx_symbolic("aten::clamp_min") +@symbolic_helper.parse_args("v", "v") +def clamp_min(g: jit_utils.GraphContext, self, min): + min = g.op("Cast", min, to_i=_type_utils.JitScalarType.from_value(self).onnx_type()) + if symbolic_helper._get_tensor_rank(min) == 0: + max = opset9.unused(g) + return symbolic_helper._op_with_optional_float_cast( + g, "Clip", self, min, max, opset_before=12 + ) + else: + return symbolic_helper._op_with_optional_float_cast( + g, "Max", self, min, opset_before=12 + ) + + +@_onnx_symbolic("aten::clamp_max") +@symbolic_helper.parse_args("v", "v") +def clamp_max(g: jit_utils.GraphContext, self, max): + max = g.op("Cast", max, to_i=_type_utils.JitScalarType.from_value(self).onnx_type()) + if symbolic_helper._get_tensor_rank(max) == 0: + min = opset9.unused(g) + return symbolic_helper._op_with_optional_float_cast( + g, "Clip", self, min, max, opset_before=12 + ) + else: + return symbolic_helper._op_with_optional_float_cast( + g, "Min", self, max, opset_before=12 + ) + + +@_onnx_symbolic("aten::relu6") +def relu6(g: jit_utils.GraphContext, input): + scalar_type = _type_utils.JitScalarType.from_value( + input, _type_utils.JitScalarType.FLOAT + ) + min_val = g.op( + "Constant", + value_t=torch.tensor(0, dtype=scalar_type.dtype()), + ) + max_val = g.op( + "Constant", + value_t=torch.tensor(6, dtype=scalar_type.dtype()), + ) + return clamp(g, input, min_val, max_val) + + +@_onnx_symbolic("aten::select") +# Opset 11 gather accepts negative indices +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "i", "v") +def select(g: jit_utils.GraphContext, self, dim, index): + return g.op("Gather", self, index, axis_i=dim) + + +@_onnx_symbolic("aten::index_put") +def index_put( + g: jit_utils.GraphContext, self, indices_list_value, values, accumulate=False +): + if symbolic_helper._is_packed_list(indices_list_value): + indices_list = symbolic_helper._unpack_list(indices_list_value) + else: + indices_list = [indices_list_value] + accumulate = symbolic_helper._parse_arg(accumulate, "b") + + if len(indices_list) == 0: + return values + + if len(indices_list) > 1: + for idx_ in range(len(indices_list)): + if symbolic_helper._is_bool(indices_list[idx_]): + indices_list[idx_] = g.op("NonZero", indices_list[idx_]) + index = indices_list[0] + + for ind in indices_list[1:]: + index = opset9.add(g, index, ind) + broadcast_index_shape = g.op("Shape", index) + indices_list = [ + symbolic_helper._unsqueeze_helper( + g, opset9.expand(g, ind, broadcast_index_shape, None), [-1] + ) + for ind in indices_list + ] + index = g.op("Concat", *indices_list, axis_i=-1) + else: + # Replace index_put node with masked_scatter or masked_fill + # when inputs to the index_put node contains a single boolean input. + # + # index_put -> masked_fill + # * input index contains single tensor of Bool type (e.g.: %24 <- %23). + # * input value contains single element (e.g.: %18). + # + # Torch IR + # %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6) + # %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = + # aten::to(%8, %26, %27, %11, %12, %28, %29, %15) + # %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]() + # %23 : Bool(8, strides=[1], device=cpu) = aten::view(%16, %22) + # %24 : Tensor?[] = prim::ListConstruct(%23) + # %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = + # aten::index_put(%mask, %24, %18, %30) + # return (%25) + # + # + # index_put -> masked_scatter + # * input index contains single tensor of Bool type (e.g.: %32 <- %31). + # * input value contains multiple elements (e.g.: %28). + # + # Torch IR + # %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6) + # %28 : Float(8, strides=[1], requires_grad=0, device=cpu) + # = prim::Constant[value= 1 1 1 1 1 1 1 1 [ CPUFloatType{8} ]]() + # %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) + # = aten::ne(%mask, %some_const) + # %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) + # = aten::to(%15, %34, %35, %18, %19, %36, %37, %22) + # %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]() + # %30 : int[] = prim::Constant[value=[-1]]() + # %31 : Bool(8, strides=[1], device=cpu) = aten::view(%23, %30) + # %32 : Tensor?[] = prim::ListConstruct(%31) + # %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) + # = aten::index_put(%mask, %32, %28, %38) + # return (%33) + index = indices_list[0] + bool_inp = index + if symbolic_helper._is_bool(bool_inp): + rank = symbolic_helper._get_tensor_rank(values) + if rank is not None and rank == 0: + return opset9.masked_fill(g, self, bool_inp, values) + mask_rank = symbolic_helper._get_tensor_rank(bool_inp) + self_rank = symbolic_helper._get_tensor_rank(self) + if ( + mask_rank is not None + and self_rank is not None + and self_rank > mask_rank + ): + # Unsqueeze 'bool_inp' to be broadcastable to shape of 'self'. + bool_inp = symbolic_helper._unsqueeze_helper( + g, bool_inp, list(range(mask_rank, self_rank)) + ) + return masked_scatter(g, self, bool_inp, values) + broadcast_index_shape = g.op("Shape", index) + index = symbolic_helper._unsqueeze_helper(g, index, [-1]) + sub_data_shape = symbolic_helper._slice_helper( + g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[sys.maxsize] + ) + values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0) + # Check if values is a singular value and expand accordingly + rank = symbolic_helper._get_tensor_rank(values) + if rank is not None and rank == 0: + values = opset9.expand(g, values, values_shape, None) + values = symbolic_helper._reshape_helper(g, values, values_shape) + + self_scalar_type = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.UNDEFINED + ) + if self_scalar_type != _type_utils.JitScalarType.UNDEFINED: + values_scalar_type = _type_utils.JitScalarType.from_value( + values, _type_utils.JitScalarType.UNDEFINED + ) + if self_scalar_type != values_scalar_type: + values = g.op("Cast", values, to_i=self_scalar_type.onnx_type()) + elif accumulate: + raise errors.SymbolicValueError("self does not have a valid scalar type.", self) + + if accumulate: + zeros = g.op( + "ConstantOfShape", + g.op("Shape", self), + value_t=torch.tensor([0], dtype=self_scalar_type.dtype()), + ) + result = g.op("ScatterND", zeros, index, values) + result = add(g, self, result) + else: + result = g.op("ScatterND", self, index, values) + + return result + + +@_onnx_symbolic("aten::pixel_shuffle") +@symbolic_helper.parse_args("v", "i") +def pixel_shuffle(g: jit_utils.GraphContext, self, upscale_factor): + rank = symbolic_helper._get_tensor_rank(self) + if rank is not None and rank != 4: + return symbolic_helper._unimplemented("pixel_shuffle", "only support 4d input") + return g.op("DepthToSpace", self, blocksize_i=upscale_factor, mode_s="CRD") + + +@_onnx_symbolic( + "aten::upsample_nearest1d", + decorate=[symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_nearest2d", + decorate=[symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_nearest3d", + decorate=[symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_linear1d", + decorate=[symbolic_helper._apply_params("upsample_linear1d", 3, "linear")], +) +@_onnx_symbolic( + "aten::upsample_bilinear2d", + decorate=[symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear")], +) +@_onnx_symbolic( + "aten::upsample_trilinear3d", + decorate=[symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear")], +) +@_onnx_symbolic( + "aten::upsample_bicubic2d", + decorate=[symbolic_helper._apply_params("upsample_bicubic2d", 4, "cubic")], +) +def _interpolate(name: str, dim: int, interpolate_mode: str): + return symbolic_helper._interpolate_helper(name, dim, interpolate_mode) + + +@_onnx_symbolic("aten::__interpolate") +@symbolic_helper.quantized_args(True, False, False, False, False, False, False) +def __interpolate( + g: jit_utils.GraphContext, + input, + size, + scale_factor, + mode, + align_corners, + recompute_scale_factor, + antialias, +): + return symbolic_helper.__interpolate_helper( + g, input, size, scale_factor, mode, align_corners, recompute_scale_factor + ) + + +@_onnx_symbolic("aten::gather") +@symbolic_helper.parse_args("v", "i", "v", "v") +def gather(g: jit_utils.GraphContext, self, dim, index, sparse_grad=False): + if symbolic_helper._maybe_get_const(sparse_grad, "i"): + return symbolic_helper._unimplemented("gather", "sparse_grad == True") + return g.op("GatherElements", self, index, axis_i=dim) + + +@_onnx_symbolic("aten::scatter") +@symbolic_helper.parse_args("v", "i", "v", "v") +def scatter(g: jit_utils.GraphContext, self, dim, index, src): + src_type = _type_utils.JitScalarType.from_value(src) + src = symbolic_helper._maybe_get_scalar(src) + if symbolic_helper._is_value(src): + return g.op("ScatterElements", self, index, src, axis_i=dim) + else: + # Check if scalar "src" has same type as self (PyTorch allows different + # type for scalar src (but not when src is tensor)). If not, insert Cast node. + if _type_utils.JitScalarType.from_value(self) != src_type: + src = g.op( + "Cast", + src, + to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), + ) + return g.op( + "ScatterElements", self, index, opset9.expand_as(g, src, index), axis_i=dim + ) + + +@_onnx_symbolic("aten::cumsum") +@symbolic_helper.parse_args("v", "i", "none") +def cumsum(g: jit_utils.GraphContext, self, dim, dtype=None): + dim_tensor = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.int)) + if dtype and dtype.node().kind() != "prim::Constant": + parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") + cast = g.op( + "Cast", self, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type() + ) + else: + cast = self + csum = g.op("CumSum", cast, dim_tensor) + return csum + + +@_onnx_symbolic("aten::masked_select") +def masked_select(g: jit_utils.GraphContext, self, mask): + index = opset9.nonzero(g, opset9.expand_as(g, mask, self)) + return g.op("GatherND", self, index) + + +@_onnx_symbolic("aten::masked_scatter") +def masked_scatter(g: jit_utils.GraphContext, self, mask, source): + index = opset9.nonzero(g, opset9.expand_as(g, mask, self)) + # NOTE: source can have more elements than needed. + # It could also have arbitrary shape. + # This is not supported by ONNX::ScatterND, so we need to flatten and slice source tensor. + source = symbolic_helper._reshape_helper(g, source, torch.LongTensor([-1])) + source = symbolic_helper._slice_helper( + g, + source, + axes=torch.LongTensor([0]), + starts=torch.LongTensor([0]), + ends=opset9.size(g, index, torch.LongTensor([0])), + ) + return g.op("ScatterND", self, index, source) + + +@_onnx_symbolic("aten::len") +def _len(g: jit_utils.GraphContext, self): + if ( + symbolic_helper._is_tensor_list(self) + or self.node().kind() == "onnx::SplitToSequence" + ): + return g.op("SequenceLength", self) + sz_0 = size(g, self, g.op("Constant", value_t=torch.LongTensor([0]))) + return symbolic_helper._squeeze_helper(g, sz_0, [0]) + + +@_onnx_symbolic("aten::__getitem_") +def __getitem_(g: jit_utils.GraphContext, self, i): + if symbolic_helper._is_tensor_list(self): + # SequenceAt requires that the input be a List of Tensors + return g.op("SequenceAt", self, i) + else: + from torch.onnx.symbolic_opset9 import __getitem_ as getitem + + return getitem(g, self, i) + + +@_onnx_symbolic("aten::_set_item") +def _set_item(g: jit_utils.GraphContext, tensor_list, i, v): + tensor_list = g.op("SequenceErase", tensor_list, i) + return g.op("SequenceInsert", tensor_list, v, i) + + +@_onnx_symbolic("aten::append") +def append(g: jit_utils.GraphContext, self, tensor): + return g.op("SequenceInsert", self, tensor) + + +@_onnx_symbolic("aten::add") +def add(g: jit_utils.GraphContext, self, other, alpha=None): + if symbolic_helper._is_value(self) and symbolic_helper._is_tensor_list(self): + tensor_list_node = other.node() + if tensor_list_node.kind() != "prim::ListConstruct": + return symbolic_helper._unimplemented( + "add", "does not support adding dynamic tensor list to another" + ) + tensors = symbolic_helper._unpack_list(other) + l = self + for t in tensors: + l = g.op("SequenceInsert", l, t) + return l + + return opset9.add(g, self, other, alpha) + + +@_onnx_symbolic("aten::insert") +def insert(g: jit_utils.GraphContext, self, pos, tensor): + return g.op("SequenceInsert", self, tensor, pos) + + +@_onnx_symbolic("aten::pop") +def pop(g: jit_utils.GraphContext, tensor_list, dim): + return g.op("SequenceErase", tensor_list, dim) + + +@_onnx_symbolic("aten::Delete") +def Delete(g: jit_utils.GraphContext, tensor_list, dim): + return g.op("SequenceErase", tensor_list, dim) + + +@_onnx_symbolic("aten::cat") +@symbolic_helper.quantized_args(True) +def cat(g: jit_utils.GraphContext, tensor_list, dim): + if symbolic_helper._is_packed_list(tensor_list): + return opset9.cat(g, tensor_list, dim) + else: + dim = symbolic_helper._get_const(dim, "i", "dim") + return g.op("ConcatFromSequence", tensor_list, axis_i=dim) + + +@_onnx_symbolic("aten::stack") +def stack(g: jit_utils.GraphContext, tensor_list, dim): + if symbolic_helper._is_packed_list(tensor_list): + return opset9.stack(g, tensor_list, dim) + else: + dim = symbolic_helper._get_const(dim, "i", "dim") + return g.op("ConcatFromSequence", tensor_list, axis_i=dim, new_axis_i=1) + + +@_onnx_symbolic("aten::_unique2") +@symbolic_helper.parse_args("v", "i", "i", "i") +def _unique2(g: jit_utils.GraphContext, self, sorted, return_inverse, return_counts): + u, _indices, inverse_indices, counts = g.op( + "Unique", self, sorted_i=sorted, outputs=4 + ) + return u, inverse_indices, counts + + +@_onnx_symbolic("aten::unique_dim") +@symbolic_helper.parse_args("v", "i", "i", "i", "i") +def unique_dim( + g: jit_utils.GraphContext, self, dim, sorted, return_inverse, return_counts +): + u, _indices, inverse_indices, counts = g.op( + "Unique", self, axis_i=dim, sorted_i=sorted, outputs=4 + ) + return u, inverse_indices, counts + + +@_onnx_symbolic("aten::topk") +@symbolic_helper.parse_args("v", "v", "i", "i", "i", "none") +def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None): + return symbolic_helper._topk_helper( + g, self, k, dim, largest=largest, sorted=sorted, out=out + ) + + +@_onnx_symbolic("aten::sort") +@symbolic_helper.parse_args("v", "i", "i", "none") +def sort(g: jit_utils.GraphContext, self, dim, decending, out=None): + return symbolic_helper._sort_helper(g, self, dim, decending=decending, out=out) + + +@_onnx_symbolic("aten::argsort") +@symbolic_helper.parse_args("v", "i", "i", "none") +def argsort(g: jit_utils.GraphContext, self, dim, decending, out=None): + _, indices = symbolic_helper._sort_helper( + g, self, dim, decending=decending, out=out + ) + return indices + + +@_onnx_symbolic("aten::round") +@symbolic_helper.parse_args("v", "i") +def round(g: jit_utils.GraphContext, self, decimals=0): + if not symbolic_helper._is_fp(self): + return self + if decimals == 0: + return g.op("Round", self) + mul = g.op("Mul", self, g.op("Constant", value_t=torch.tensor(pow(10, decimals)))) + round = g.op("Round", mul) + return g.op( + "Mul", round, g.op("Constant", value_t=torch.tensor(pow(10, -1 * decimals))) + ) + + +@_onnx_symbolic("aten::remainder") +def remainder(g: jit_utils.GraphContext, input, other): + if symbolic_helper._is_fp(input) or symbolic_helper._is_fp(other): + return opset9.remainder(g, input, other) + return g.op("Mod", input, other, fmod_i=0) + + +@_onnx_symbolic("aten::split") +@symbolic_helper.parse_args("v", "v", "i", "i") +def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None): + if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs): + split_out = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim) + if _outputs is None: + return split_out + # Convert to multiple slice nodes iff number of splits and number of outputs are statically known. + if ( + symbolic_helper._is_packed_list(split_size_or_sizes) + and len(symbolic_helper._unpack_list(split_size_or_sizes)) == _outputs + ): + split_sizes = [ + symbolic_helper._unsqueeze_helper(g, v, [0]) + for v in symbolic_helper._unpack_list(split_size_or_sizes) + ] + start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) + axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) + res = [] + for i in range(_outputs): + end = g.op( + "Add", start, split_sizes[i] + ) # split_sizes is a list of same length as _outputs + res.append(g.op("Slice", self, start, end, axis)) + start = end + return res + return [ + g.op( + "SequenceAt", + split_out, + g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)), + ) + for i in range(_outputs) + ] + else: + return opset9.split(g, self, split_size_or_sizes, dim, _outputs) + + +@_onnx_symbolic("aten::split_with_sizes") +@symbolic_helper.parse_args("v", "v", "i", "i") +def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None): + return split(g, self, split_sizes, dim, _outputs) + + +@_onnx_symbolic("aten::unbind") +@symbolic_helper.parse_args("v", "i", "i") +def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None): + if _outputs is None: + return g.op( + "SplitToSequence", + self, + g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)), + axis_i=dim, + keepdims_i=0, + ) + else: + return opset9.unbind(g, self, dim, _outputs) + + +def _prepare_onnx_paddings(g: jit_utils.GraphContext, input, pad): + """Generate paddings in ONNX order based on pad in pytorch. + + Args: + input: the input tensor. + pad: the paddings in pytorch. + The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ..., dim_m_begin, dim_m_end, + where m is in range [0, n]. + """ + if ( + not symbolic_helper._is_packed_list(pad) + and symbolic_helper._is_list(pad) + and symbolic_helper._is_scalar_list(pad) + ): + pad = g.op("ConcatFromSequence", pad, axis_i=0, new_axis_i=1) + # The desired order of paddings is + # dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end. + # n is the dimension of input. + # Assume zero-dimensions in the beginning, pad the "pad" sequence with zeros in the beginning + pad_len = opset9.size(g, pad, g.op("Constant", value_t=torch.tensor([0]))) + # Set extension = [0] * (dim * 2 - len(pad)) + rank = symbolic_helper._get_tensor_rank(input) + if rank is None: + rank = g.op("Size", g.op("Shape", input)) + else: + rank = g.op("Constant", value_t=torch.tensor(rank, dtype=torch.int64)) + extension = g.op( + "Sub", + g.op("Mul", rank, g.op("Constant", value_t=torch.tensor(2, dtype=torch.int64))), + pad_len, + ) + # Concat pad with extension: paddings = [dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, 0, 0, ... ] + # Currently ONNX only supports int64 type for Pad + pad = g.op("Cast", pad, to_i=_C_onnx.TensorProtoDataType.INT64) + paddings = g.op( + "Concat", + pad, + g.op( + "ConstantOfShape", extension, value_t=torch.tensor([0], dtype=torch.int64) + ), + axis_i=0, + ) + # Reshape and reverse order and collate first beginnings and then ends + # paddings = [[..., 0, dim_n-1_begin, dim_n_begin], + # [..., 0, dim_n-1_end, dim_n_end]] + # Reshape back to 1-D paddings = [..., 0, dim_n - 1_begin, dim_n_begin, ..., 0, dim_n - 1_end, dim_n_end] + paddings = symbolic_helper._reshape_helper( + g, paddings, g.op("Constant", value_t=torch.tensor([-1, 2])) + ) + paddings = g.op("Transpose", opset10.flip(g, paddings, [0]), perm_i=[1, 0]) + paddings = symbolic_helper._reshape_helper( + g, paddings, g.op("Constant", value_t=torch.tensor([-1])) + ) + padding_c = g.op("Cast", paddings, to_i=_C_onnx.TensorProtoDataType.INT64) + return padding_c + + +@_onnx_symbolic("aten::constant_pad_nd") +def constant_pad_nd(g: jit_utils.GraphContext, input, padding, value=None): + mode = "constant" + value = symbolic_helper._maybe_get_scalar(value) + value = symbolic_helper._if_scalar_type_as(value, input) + pad = _prepare_onnx_paddings(g, input, padding) + return g.op("Pad", input, pad, value, mode_s=mode) + + +@_onnx_symbolic("aten::reflection_pad1d") +@_onnx_symbolic("aten::reflection_pad2d") +@_onnx_symbolic("aten::reflection_pad3d") +def reflection_pad(g: jit_utils.GraphContext, input, padding): + mode = "reflect" + paddings = _prepare_onnx_paddings(g, input, padding) + return g.op("Pad", input, paddings, mode_s=mode) + + +@_onnx_symbolic("aten::replication_pad1d") +@_onnx_symbolic("aten::replication_pad2d") +@_onnx_symbolic("aten::replication_pad3d") +def replication_pad(g: jit_utils.GraphContext, input, padding): + mode = "edge" + paddings = _prepare_onnx_paddings(g, input, padding) + return g.op("Pad", input, paddings, mode_s=mode) + + +@_onnx_symbolic("aten::pad") +def pad( + g: jit_utils.GraphContext, + input: _C.Value, + pad: _C.Value, + mode: _C.Value, + value: _C.Value, +): + mode = symbolic_helper._parse_arg(mode, "s") + if mode == "replicate": + return replication_pad(g, input, pad) + elif mode == "reflect": + return reflection_pad(g, input, pad) + elif mode == "constant": + return constant_pad_nd(g, input, pad, value) + elif mode == "circular": + return opset9._pad_circular(g, input, pad) + else: + raise errors.SymbolicValueError(f"Unrecognized padding mode {mode}", input) + + +@_onnx_symbolic("aten::linalg_det") +def linalg_det(g: jit_utils.GraphContext, self): + return g.op("Det", self) + + +@_onnx_symbolic("aten::logdet") +def logdet(g: jit_utils.GraphContext, input): + return opset9.log(g, linalg_det(g, input)) + + +@_onnx_symbolic("aten::arange") +def arange(g: jit_utils.GraphContext, *args): + def _get_arange_dtype(dtype): + dtype = symbolic_helper._maybe_get_const(dtype, "i") + return dtype + + if len(args) == 2 and all(isinstance(val, int) for val in args): + # aten::arange(Scalar start, Scalar end) + dtype = torch.int64 + # Start index. + start = g.op( + "Constant", + value_t=torch.tensor(args[0], dtype=dtype), + ) + # End (exclusive) index. + end = g.op( + "Constant", + value_t=torch.tensor(args[1], dtype=dtype), + ) + # Step size from start to end indexes. + delta_default = g.op( + "Constant", + value_t=torch.tensor(1, dtype=dtype), + ) + return g.op("Range", start, end, delta_default) + elif len(args) == 2 or len(args) == 5: + if len(args) == 2: + # aten::arange(Scalar end, Tensor out) + dtype = None + else: + # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) + dtype = _get_arange_dtype(args[1]) + type_, end, start, step = symbolic_helper._arange_cast_helper( + g, end=args[0], dtype=dtype + ) + start_default = g.op( + "Constant", + value_t=torch.tensor(0, dtype=type_.dtype()), + ) + delta_default = g.op( + "Constant", + value_t=torch.tensor(1, dtype=type_.dtype()), + ) + return g.op("Range", start_default, end, delta_default) + elif len(args) == 4 or len(args) == 7: + if len(args) == 4: + # aten::arange(Scalar start, Scalar end, Scalar step, Tensor out) + dtype = None + else: + # aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory) + dtype = _get_arange_dtype(args[3]) + _, end, start, step = symbolic_helper._arange_cast_helper( + g, start=args[0], end=args[1], step=args[2], dtype=dtype + ) + return g.op("Range", start, end, step) + elif len(args) == 6: + # aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) + dtype = _get_arange_dtype(args[2]) + type_, end, start, step = symbolic_helper._arange_cast_helper( + g, start=args[0], end=args[1], dtype=dtype + ) + delta_default = g.op( + "Constant", + value_t=torch.tensor(1, dtype=type_.dtype()), + ) + return g.op("Range", start, end, delta_default) + else: + return symbolic_helper._unimplemented( + "aten::arange", f"with {len(args)} arguments" + ) + + +@_onnx_symbolic("aten::_dim_arange") +@symbolic_helper.parse_args("v", "i") +def _dim_arange(g: jit_utils.GraphContext, like, dim): + like_shape = g.op("Shape", like) + stop = g.op( + "Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0 + ) + return arange(g, stop, 4, None, None, None) + + +@_onnx_symbolic("aten::size") +@symbolic_helper.quantized_args(True, quantize_output=False) +def size(g: jit_utils.GraphContext, self, dim=None): + if dim is None: + return g.op("Shape", self) + return symbolic_helper._size_helper(g, self, dim) + + +@_onnx_symbolic("aten::squeeze") +def squeeze(g: jit_utils.GraphContext, self, dim=None): + if dim is None: + return g.op("Squeeze", self) + + # dim as a tensor + if not symbolic_helper._is_constant(dim): + return symbolic_helper._squeeze_helper(g, self, [dim]) + + dim = symbolic_helper._get_const(dim, "i", "dim") + + input_rank = symbolic_helper._get_tensor_rank(self) + adjusted_dim = dim + if input_rank is not None and dim < 0: + adjusted_dim += input_rank + dim_size = symbolic_helper._get_tensor_dim_size(self, adjusted_dim) + if (dim < 0 and input_rank is None) or dim_size is None: + # If onnx shape inference is not on, export always as dynamic. + # Because we cannot tell if observed static shape is also static at runtime. + # create "cond" node (condition is shape[i]==1) + dim_constant = g.op("Constant", value_t=torch.tensor([dim])) + size = symbolic_helper._size_helper(g, self, dim_constant) + const_one = g.op("Constant", value_t=torch.ones(1, dtype=torch.int64)) + cond = g.op("Equal", size, const_one) + # create the "If" node and add the "then" and "else" blocks to it. + if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks( + g, "If", cond, n_blocks=2 + ) + squeeze_ = symbolic_helper._squeeze_helper(if_context, self, [dim]) + utils._add_output_to_block(if_context.block, squeeze_) + identity_ = else_context.op("Identity", self) + utils._add_output_to_block(else_context.block, identity_) + return if_op + + # For static input shape + dim = adjusted_dim + if dim_size > 1: + warnings.warn( + "This model contains a squeeze operation on dimension " + + str(dim) + + ". The size of " + + "this dimension in the given input is " + + str(dim_size) + + ". The model will " + + "be exported without the squeeze node. If the model is intended to be used with dynamic " + + "input shapes, please export with dynamic_axes argument." + ) + return self + return symbolic_helper._squeeze_helper(g, self, [dim]) + + +@_onnx_symbolic("aten::unsqueeze") +def unsqueeze(g: jit_utils.GraphContext, self, dim): + if symbolic_helper._is_constant(dim): + dim = symbolic_helper._get_const(dim, "i", "dim") + + return symbolic_helper._unsqueeze_helper(g, self, [dim]) + + +@_onnx_symbolic("aten::mm") +def mm(g: jit_utils.GraphContext, self, other): + return g.op("Gemm", self, other, beta_f=0.0, alpha_f=1.0) + + +@_onnx_symbolic("aten::index") +def index(g: jit_utils.GraphContext, self, index): + if symbolic_helper._is_packed_list(index): + indices = symbolic_helper._unpack_list(index) + else: + indices = [index] + + # Handle single mask index. + if len(indices) == 1: + index = indices[0] + if not symbolic_helper._is_none(index) and ( + symbolic_helper._is_bool(index) + or _type_utils.JitScalarType.from_value(index) + == _type_utils.JitScalarType.UINT8 + ): + index = opset9.nonzero(g, index) + return g.op("GatherND", self, index) + return opset9.index(g, self, index) + + +@_onnx_symbolic("aten::index_fill") +def index_fill(g: jit_utils.GraphContext, self, dim, index, value): + expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( + g, self, dim, index + ) + value = symbolic_helper._maybe_get_scalar(value) + value = symbolic_helper._if_scalar_type_as(value, self) + expanded_value = opset9.expand(g, value, expanded_index_shape, None) + return scatter(g, self, dim, expanded_index, expanded_value) + + +@_onnx_symbolic("aten::index_copy") +def index_copy(g: jit_utils.GraphContext, self, dim, index, source): + _expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( + g, self, dim, index + ) + return scatter(g, self, dim, expanded_index, source) + + +@_onnx_symbolic("aten::bitwise_right_shift") +@_onnx_symbolic("aten::__rshift_") +def __rshift_(g: jit_utils.GraphContext, self, other): + # make sure to cast other to self's type + # (when self is long, make sure that other is not float) + if _type_utils.JitScalarType.from_value( + other, _type_utils.JitScalarType.UNDEFINED + ) != _type_utils.JitScalarType.from_value(self): + other = g.op( + "Cast", + other, + to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), + ) + + if ( + _type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED) + == _type_utils.JitScalarType.UINT8 + ): + return g.op("BitShift", self, other, direction_s="RIGHT") + + two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32)) + # exponent (same type as self) has to be float or double in onnx::Pow + if not symbolic_helper._is_fp(self): + other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT) + two_pow = g.op("Pow", two, other) + two_pow = g.op( + "Cast", + two_pow, + to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), + ) + rshift = g.op("Div", self, two_pow) + return rshift + + +@_onnx_symbolic("aten::bitwise_left_shift") +@_onnx_symbolic("aten::__lshift_") +def __lshift_(g: jit_utils.GraphContext, self, other): + # make sure to cast other to self's type + # (when self is long, make sure that other is not float) + if _type_utils.JitScalarType.from_value( + other, _type_utils.JitScalarType.UNDEFINED + ) != _type_utils.JitScalarType.from_value(self): + other = g.op( + "Cast", + other, + to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), + ) + + if ( + _type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED) + == _type_utils.JitScalarType.UINT8 + ): + return g.op("BitShift", self, other, direction_s="LEFT") + + two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32)) + # exponent (same type as self) has to be float or double in onnx::Pow + if not symbolic_helper._is_fp(self): + other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT) + two_pow = g.op("Pow", two, other) + two_pow = g.op( + "Cast", + two_pow, + to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), + ) + lshift = g.op("Mul", self, two_pow) + return lshift + + +def _get_im2col_indices_along_dim( + g: jit_utils.GraphContext, input_d, kernel_size_d, dilation_d, padding_d, stride_d +): + # Input is always 4-D (N, C, H, W) + # Calculate indices of sliding blocks along spatial dimension + # Slide kernel over input each dim d: + # each dimension d ranges from 0 to input[d]+2xpadding[d]-dilation[d]x(kernel_size[d]-1) + # with steps = stride + + blocks_d = g.op( + "Add", input_d, g.op("Constant", value_t=torch.tensor(padding_d * 2)) + ) + blocks_d = g.op( + "Sub", + blocks_d, + g.op("Constant", value_t=torch.tensor(dilation_d * (kernel_size_d - 1))), + ) + + # Stride kernel over input and find starting indices along dim d + blocks_d_indices = g.op( + "Range", + g.op("Constant", value_t=torch.tensor(0)), + blocks_d, + g.op("Constant", value_t=torch.tensor(stride_d)), + ) + + # Apply dilation on kernel and find its indices along dim d + kernel_grid = torch.arange(0, kernel_size_d * dilation_d, dilation_d) + kernel_grid = g.op("Constant", value_t=kernel_grid.unsqueeze(0)) + + # Broadcast and add kernel staring positions (indices) with + # kernel_grid along dim d, to get block indices along dim d + blocks_d_indices = symbolic_helper._unsqueeze_helper( + g, blocks_d_indices, [0] + ) # Reshape to [1, -1] + kernel_mask = symbolic_helper._reshape_helper( + g, kernel_grid, g.op("Constant", value_t=torch.tensor([-1, 1])) + ) + block_mask = g.op("Add", blocks_d_indices, kernel_mask) + + return block_mask + + +def _get_im2col_padded_input(g: jit_utils.GraphContext, input, padding_h, padding_w): + # Input is always 4-D tensor (N, C, H, W) + # Padding tensor has the following format: (padding_h, padding_w) + # Reshape the padding to follow ONNX format: (dim1_begin, dim2_begin,...,dim1_end, dim2_end,...) + pad = g.op("Constant", value_t=torch.LongTensor([0, 0, padding_h, padding_w] * 2)) + return g.op("Pad", input, pad) + + +def _get_im2col_output_shape(g: jit_utils.GraphContext, input, kernel_h, kernel_w): + batch_dim = size(g, input, g.op("Constant", value_t=torch.tensor(0))) + channel_dim = size(g, input, g.op("Constant", value_t=torch.tensor(1))) + channel_unfolded = g.op( + "Mul", channel_dim, g.op("Constant", value_t=torch.tensor(kernel_h * kernel_w)) + ) + + return g.op( + "Concat", + symbolic_helper._unsqueeze_helper(g, batch_dim, [0]), + symbolic_helper._unsqueeze_helper(g, channel_unfolded, [0]), + g.op("Constant", value_t=torch.tensor([-1])), + axis_i=0, + ) + + +@_onnx_symbolic("aten::im2col") +@symbolic_helper.parse_args("v", "is", "is", "is", "is") +def im2col(g: jit_utils.GraphContext, input, kernel_size, dilation, padding, stride): + # Input is always 4-D tensor (N, C, H, W) + # All other args are int[2] + + input_h = size(g, input, g.op("Constant", value_t=torch.tensor(2))) + input_w = size(g, input, g.op("Constant", value_t=torch.tensor(3))) + + stride_h, stride_w = stride[0], stride[1] + padding_h, padding_w = padding[0], padding[1] + dilation_h, dilation_w = dilation[0], dilation[1] + kernel_h, kernel_w = kernel_size[0], kernel_size[1] + + blocks_row_indices = _get_im2col_indices_along_dim( + g, input_h, kernel_h, dilation_h, padding_h, stride_h + ) + blocks_col_indices = _get_im2col_indices_along_dim( + g, input_w, kernel_w, dilation_w, padding_w, stride_w + ) + + output_shape = _get_im2col_output_shape(g, input, kernel_h, kernel_w) + padded_input = _get_im2col_padded_input(g, input, padding_h, padding_w) + + # For a 4D matrix of size (1, 1, 3, 3) as below with kernel_size=2, stride=1, and dilation=1 + # [[[[1., 2., 3.,], + # [4., 5., 6.,], + # [7., 8., 9.,]]]] + # First gather indices along rows (dim=2) with blocks_row_indices = [[0,1], [1,2]] to get: + # [[[[[1., 2., 3.], + # [4., 5., 6.]], + # [[4., 5., 6.], + # [7., 8., 9.]]]]] + # And then gather along cols (dim=4) with blocks_row_indices = [[0,1], [1,2]] to get: + # [[[[[[1., 2.], + # [4., 5.]], + # [[2., 3.], + # [5., 6]]], + # [[[4., 5.], + # [7., 8.]], + # [[5., 6.], + # [8., 9.]]]]]] + # Transpose dims 3 (depth) and 4 (rows), and then reshape to output shape (1, 1, 4, 4) to get: + # [[[1., 2., 4., 5.], + # [2., 3., 5., 6.], + # [4., 5., 7., 8.], + # [5., 6., 8., 9.]]] + output = g.op("Gather", padded_input, blocks_row_indices, axis_i=2) + output = g.op("Gather", output, blocks_col_indices, axis_i=4) + output = g.op("Transpose", output, perm_i=[0, 1, 2, 4, 3, 5]) + return symbolic_helper._reshape_helper(g, output, output_shape) + + +@_onnx_symbolic("aten::narrow") +def narrow(g: jit_utils.GraphContext, input, dim, start, length): + end = g.op("Add", start, length) + return symbolic_helper._slice_helper(g, input, axes=dim, starts=start, ends=end) + + +@_onnx_symbolic("aten::flatten") +@symbolic_helper.quantized_args(True, False, False) +@symbolic_helper.parse_args("v", "i", "i") +def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim): + dim = symbolic_helper._get_tensor_rank(input) + if dim == 1: + return input + # use ONNX's Flatten operator for cases where the output shape is 2D + if start_dim == 1: + if end_dim == -1 or (dim is not None and end_dim == dim - 1): + return g.op("Flatten", input, axis_i=start_dim) + elif start_dim == 0: + if end_dim == -2 or (dim is not None and end_dim == dim - 2): + return g.op("Flatten", input, axis_i=end_dim + 1) + if dim is None: + return symbolic_helper._unimplemented( + "dim", + "ONNX and PyTorch use different strategies to split the input. " + "Input rank must be known at export time.", + ) + # if end_dim is negative add dim + if end_dim < 0: + end_dim = dim + end_dim + + return symbolic_helper._flatten_helper(g, input, start_dim, end_dim, dim) + + +@_onnx_symbolic("aten::linalg_vector_norm") +@symbolic_helper.parse_args("v", "f", "is", "b", "v") +def linalg_vector_norm( + g: jit_utils.GraphContext, + self, + ord, + dim: Sequence[int] | None, + keepdim: bool, + dtype, +): + return symbolic_helper._linalg_vector_norm_helper(g, self, ord, dim, keepdim, dtype) + + +@_onnx_symbolic("aten::embedding_bag") +@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i") +def embedding_bag( + g: jit_utils.GraphContext, + embedding_matrix, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, +): + return symbolic_helper._embedding_bag_helper( + g, + embedding_matrix, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, + ) + + +@_onnx_symbolic("aten::embedding_renorm") +@symbolic_helper.parse_args("v", "v", "f", "f") +def embedding_renorm(g: jit_utils.GraphContext, weight, indices, max_norm, norm_type): + unique_indices = g.op("Unique", indices) + partial_weight = g.op("Gather", weight, unique_indices) + norm_i = int(norm_type) + if norm_i == 1: + norm_type = "ReduceL1" + elif norm_i == 2: + norm_type = "ReduceL2" + else: + raise errors.SymbolicValueError( + f"Unsupported: ONNX export of embedding_renorm with norm: {norm_i}. " + "Only 1. and 2. are supported.", + weight, + ) + partial_weight_norm = g.op(norm_type, partial_weight, axes_i=[1], keepdims_i=1) + # https://github.com/pytorch/pytorch/blob/0a07488ed2c47765e337e290bd138c0e6e459cbd/aten/src/ATen/native/Embedding.cpp#L177 + # Add 1e-7 to prevent division by zero. + partial_weight_norm_ = g.op( + "Add", partial_weight_norm, g.op("Constant", value_t=torch.tensor(1e-7)) + ) + max_norm = torch.tensor(max_norm) + scales = g.op("Div", max_norm, partial_weight_norm_) + partial_weight_renorm = g.op("Mul", partial_weight, scales) + partial_weight_renorm = g.op( + "Where", + g.op("Greater", partial_weight_norm, max_norm), + partial_weight_renorm, + partial_weight, + ) + return g.op( + "ScatterND", + weight, + symbolic_helper._unsqueeze_helper(g, unique_indices, [1]), + partial_weight_renorm, + ) + + +@_onnx_symbolic("aten::chunk") +def chunk(g: jit_utils.GraphContext, self, chunks, dim): + # Calculate chunk size for dynamic chunk + dim_size = g.op("Gather", g.op("Shape", self), dim, axis_i=0) + chunk_size_s = g.op( + "Sub", chunks, g.op("Constant", value_t=torch.tensor([1], dtype=torch.long)) + ) + chunk_size = g.op("Div", g.op("Add", dim_size, chunk_size_s), chunks) + # Create splits vector + chunk_vec = [ + opset9.expand(g, chunk_size, chunk_size_s, None), + g.op("Sub", dim_size, g.op("Mul", chunk_size, chunk_size_s)), + ] + chunk_vec = g.op("Concat", *chunk_vec, axis_i=0) + return split(g, self, chunk_vec, dim) + + +@_onnx_symbolic("aten::normal") +def normal( + g: jit_utils.GraphContext, + mean, + std, + sizes=None, + generator=None, + dtype=None, + layout=None, + device=None, + pin_memory=None, +): + # If you can sample from a given distribution with mean 0 and variance 1, then you can easily sample from a + # scale-location transformation of that distribution, which has mean mu and variance sigma's square. If x is a sample + # from a mean 0 and variance 1 distribution then + # sigma x+mu + # is a sample with mean mu and variance sigma's square. + if sizes is not None and not symbolic_helper._is_none(sizes): + mean = opset9.expand(g, mean, sizes, None) + result = opset9.mul(g, std, g.op("RandomNormalLike", mean)) + return add(g, result, mean) + + +@_onnx_symbolic("aten::atleast_1d") +def atleast_1d(g: jit_utils.GraphContext, self: torch._C.Value): + # NOTE: If it's 0D, reshape to 1D + + # NOTE: self could be a packed list or a tensor + if symbolic_helper._is_value(self) and symbolic_helper._is_packed_list(self): + tensor_list = symbolic_helper._unpack_list(self) + new_tensor_list = [] + for tensor in tensor_list: + new_tensor = tensor + tensor_rank = symbolic_helper._get_tensor_rank(tensor) + if tensor_rank == 0: + new_tensor = symbolic_helper._reshape_helper( + g, new_tensor, g.op("Constant", value_t=torch.tensor([1])) + ) + new_tensor_list.append(new_tensor) + return g.op("SequenceConstruct", *new_tensor_list) + + tensor_rank = symbolic_helper._get_tensor_rank(self) + if tensor_rank == 0: + self = symbolic_helper._reshape_helper( + g, self, g.op("Constant", value_t=torch.tensor([1])) + ) + return self + + +@_onnx_symbolic("aten::atleast_2d") +def atleast_2d(g: jit_utils.GraphContext, self: torch._C.Value): + # NOTE: If it's 0D, reshape to 2D + # If it's 1D, unsqueeze to 2D + + # NOTE: self could be a packed list or a tensor + if symbolic_helper._is_value(self) and symbolic_helper._is_packed_list(self): + tensor_list = symbolic_helper._unpack_list(self) + new_tensor_list = [] + for tensor in tensor_list: + new_tensor = tensor + tensor_rank = symbolic_helper._get_tensor_rank(tensor) + if tensor_rank == 0: + new_tensor = symbolic_helper._reshape_helper( + g, new_tensor, g.op("Constant", value_t=torch.tensor([1, 1])) + ) + elif tensor_rank == 1: + new_tensor = symbolic_helper._unsqueeze_helper( + g, new_tensor, axes_i=[0] + ) + new_tensor_list.append(new_tensor) + return g.op("SequenceConstruct", *new_tensor_list) + + tensor_rank = symbolic_helper._get_tensor_rank(self) + if tensor_rank == 0: + self = symbolic_helper._reshape_helper( + g, self, g.op("Constant", value_t=torch.tensor([1, 1])) + ) + elif tensor_rank == 1: + self = symbolic_helper._unsqueeze_helper(g, self, axes_i=[0]) + return self + + +@_onnx_symbolic("aten::atleast_3d") +def atleast_3d(g: jit_utils.GraphContext, self: torch._C.Value): + # NOTE: If it's 0D, reshape to 3D + # If it's 1D, unsqueeze to 3D + # If it's 2D, unsqueeze to 3D + + # NOTE: self could be a packed list or a tensor + if symbolic_helper._is_value(self) and symbolic_helper._is_packed_list(self): + tensor_list = symbolic_helper._unpack_list(self) + new_tensor_list = [] + for tensor in tensor_list: + new_tensor = tensor + tensor_rank = symbolic_helper._get_tensor_rank(tensor) + if tensor_rank == 0: + new_tensor = symbolic_helper._reshape_helper( + g, new_tensor, g.op("Constant", value_t=torch.tensor([1, 1, 1])) + ) + elif tensor_rank == 1: + new_tensor = symbolic_helper._unsqueeze_helper( + g, new_tensor, axes_i=[0] + ) + new_tensor = symbolic_helper._unsqueeze_helper( + g, new_tensor, axes_i=[-1] + ) + elif tensor_rank == 2: + new_tensor = symbolic_helper._unsqueeze_helper( + g, new_tensor, axes_i=[-1] + ) + new_tensor_list.append(new_tensor) + return g.op("SequenceConstruct", *new_tensor_list) + + tensor_rank = symbolic_helper._get_tensor_rank(self) + if tensor_rank == 0: + self = symbolic_helper._reshape_helper( + g, self, g.op("Constant", value_t=torch.tensor([1, 1, 1])) + ) + elif tensor_rank == 1: + self = symbolic_helper._unsqueeze_helper(g, self, axes_i=[0]) + self = symbolic_helper._unsqueeze_helper(g, self, axes_i=[-1]) + elif tensor_rank == 2: + self = symbolic_helper._unsqueeze_helper(g, self, axes_i=[-1]) + return self + + +@_onnx_symbolic("prim::ConstantChunk") +def prim_constant_chunk(g: jit_utils.GraphContext, self, chunks, dim): + input_shape = g.op("Shape", self) + axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) + input_shape_dim = g.op("Gather", input_shape, axis, axis_i=0) + start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) + chunk_size = g.op("Constant", value_t=torch.tensor([chunks], dtype=torch.long)) + chunk_size_minus_1 = g.op( + "Constant", value_t=torch.tensor([chunks - 1], dtype=torch.long) + ) + input_shape_dim_shift = g.op("Add", input_shape_dim, chunk_size_minus_1) + chunk_dim = g.op("Div", input_shape_dim_shift, chunk_size) + res = [] + for i in range(chunks): + index = g.op("Constant", value_t=torch.tensor([i + 1], dtype=torch.long)) + end = g.op("Mul", chunk_dim, index) + res.append(g.op("Slice", self, start, end, axis)) + start = end + return res + + +@_onnx_symbolic("aten::hstack") +def hstack(g: jit_utils.GraphContext, tensor_list: _C.Value): + tensor_list = atleast_1d(g, tensor_list) + first_tensor = g.op( + "SequenceAt", + tensor_list, + g.op("Constant", value_t=torch.tensor(0, dtype=torch.long)), + ) + first_tensor_shape = g.op("Shape", first_tensor) + first_tensor_dim = g.op("Size", first_tensor_shape) + + const_one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)) + equal_to_one = g.op("Equal", first_tensor_dim, const_one) + + ( + if_op_greater, + (if_context_equal, else_context_equal), + _, + ) = jit_utils.add_op_with_blocks(g, "If", equal_to_one, n_blocks=2, outputs=1) + result_if = if_context_equal.op( + "ConcatFromSequence", tensor_list, axis_i=0, new_axis_i=0 + ) + utils._add_output_to_block(if_context_equal.block, result_if) + result_else = else_context_equal.op( + "ConcatFromSequence", tensor_list, axis_i=1, new_axis_i=0 + ) + utils._add_output_to_block(else_context_equal.block, result_else) + result = if_op_greater.node().output() + + return result + + +@_onnx_symbolic("aten::vstack") +def vstack(g: jit_utils.GraphContext, tensor_list: _C.Value): + tensor_list = atleast_2d(g, tensor_list) + return g.op("ConcatFromSequence", tensor_list, axis_i=0, new_axis_i=0) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/onnx/symbolic_opset12.py b/torch/onnx/symbolic_opset12.py index 63e137734e8a7..b847cd2f351db 100644 --- a/torch/onnx/symbolic_opset12.py +++ b/torch/onnx/symbolic_opset12.py @@ -1,3 +1,4 @@ +<<<<<<< HEAD """Backward compatibility module for torch.onnx.symbolic_opset12.""" from __future__ import annotations @@ -6,3 +7,469 @@ __all__: list[str] = [] from torch.onnx._internal.torchscript_exporter.symbolic_opset12 import * # noqa: F401,F403 +======= +# mypy: allow-untyped-defs +# mypy: disable-error-code=arg-type +from __future__ import annotations + +import functools +import sys + +import torch +from torch._C import _onnx as _C_onnx +from torch.onnx import ( + _type_utils, + errors, + symbolic_helper, + symbolic_opset9 as opset9, + utils, +) +from torch.onnx._internal import jit_utils, registration + + +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in README.md + +# This file exports ONNX ops for opset 12 + +__all__ = [ + "argmax", + "argmin", + "binary_cross_entropy_with_logits", + "celu", + "cross_entropy_loss", + "dropout", + "einsum", + "ge", + "le", + "native_dropout", + "nll_loss", + "nll_loss2d", + "nll_loss_nd", + "outer", + "pow", + "tensordot", + "unfold", +] + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=12) + + +def _einsum_helper(g: jit_utils.GraphContext, equation, tensors): + if not tensors: + raise RuntimeError("Einsum inputs are empty.") + # ONNX does not support bool for Einsum inputs. + if symbolic_helper._is_bool(tensors[0]): + tensors = [ + g.op("Cast", tensor, to_i=_C_onnx.TensorProtoDataType.INT64) + for tensor in tensors + ] + return g.op( + "Cast", + g.op("Einsum", *tensors, equation_s=equation), + to_i=_C_onnx.TensorProtoDataType.BOOL, + ) + else: + return g.op("Einsum", *tensors, equation_s=equation) + + +@_onnx_symbolic("aten::einsum") +@symbolic_helper.parse_args("s", "v", "is") +def einsum(g: jit_utils.GraphContext, equation, tensor_list, path=None): + tensors = symbolic_helper._unpack_list(tensor_list) + return _einsum_helper(g, equation, tensors) + + +@_onnx_symbolic("aten::outer") +@symbolic_helper.parse_args("v", "v") +def outer(g: jit_utils.GraphContext, input, other): + # make sure to cast other to self's type + if _type_utils.JitScalarType.from_value( + other, _type_utils.JitScalarType.UNDEFINED + ) != _type_utils.JitScalarType.from_value(input): + other = g.op( + "Cast", + other, + to_i=_type_utils.JitScalarType.from_value(input).onnx_type(), + ) + return _einsum_helper(g, "i,j->ij", [input, other]) + + +def _dropout_returns_masked_input_and_mask( + g: jit_utils.GraphContext, input: torch._C.Value, p: float, train: bool +) -> tuple[torch._C.Value, torch._C.Value | None]: + symbolic_helper.check_training_mode(train, "dropout") + # In eval mode, dropout is non-op. That is, if the node's + # train param is set to False, dropout just returns its inputs. + if not train: + return input, None + p = g.op("Constant", value_t=torch.tensor(p)) + t = g.op("Constant", value_t=torch.tensor(train, dtype=torch.bool)) + r, mask = g.op("Dropout", input, p, t, outputs=2) + return r, mask + + +@_onnx_symbolic("aten::dropout") +@symbolic_helper.parse_args("v", "f", "b") +def dropout(g: jit_utils.GraphContext, input, p, train): + masked, _ = _dropout_returns_masked_input_and_mask(g, input, p, train) + return masked + + +@_onnx_symbolic("aten::native_dropout") +@symbolic_helper.parse_args("v", "f", "b") +def native_dropout(g: jit_utils.GraphContext, input, p, train): + return _dropout_returns_masked_input_and_mask(g, input, p, train) + + +@_onnx_symbolic("aten::nll_loss") +def nll_loss(g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index): + # none reduction : onnx::Constant[value={0}] + # mean reduction : onnx::Constant[value={1}] + # sum reduction : onnx::Constant[value={2}] + reduction = symbolic_helper._maybe_get_const(reduction, "i") + reduction_vals = ["none", "mean", "sum"] + reduction = reduction_vals[reduction] + + # in onnx NegativeLogLikelihoodLoss specification, ignore_index is optional without default value. + # therefore we need to set ignore_index attribute even if it is not specified (e.g. ignore_index=-100). + ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i") + if weight.node().mustBeNone(): + nllloss = g.op( + "NegativeLogLikelihoodLoss", + self, + target, + reduction_s=reduction, + ignore_index_i=ignore_index, + ) + else: + nllloss = g.op( + "NegativeLogLikelihoodLoss", + self, + target, + weight, + reduction_s=reduction, + ignore_index_i=ignore_index, + ) + + return nllloss + + +@_onnx_symbolic("aten::nll_loss2d") +def nll_loss2d( + g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index +): + return nll_loss(g, self, target, weight, reduction, ignore_index) + + +@_onnx_symbolic("aten::nll_loss_nd") +def nll_loss_nd( + g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index +): + return nll_loss(g, self, target, weight, reduction, ignore_index) + + +@_onnx_symbolic("aten::cross_entropy_loss") +def cross_entropy_loss( + g: jit_utils.GraphContext, + self, + target, + weight, + reduction, + ignore_index, + label_smoothing, +): + # none reduction : onnx::Constant[value={0}] + # mean reduction : onnx::Constant[value={1}] + # sum reduction : onnx::Constant[value={2}] + reduction = symbolic_helper._maybe_get_const(reduction, "i") + reduction_vals = ["none", "mean", "sum"] + reduction = reduction_vals[reduction] + + label_smoothing = symbolic_helper._maybe_get_const(label_smoothing, "f") + if label_smoothing is not None and label_smoothing > 0.0: + raise errors.SymbolicValueError( + "Unsupported: ONNX does not support label_smoothing", self + ) + + # in onnx SoftmaxCrossEntropyLoss specification, ignore_index is optional without default value. + # therefore we need to set ignore_index attribute even if it is not specified (e.g. ignore_index=-100). + ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i") + if weight.node().mustBeNone(): + celoss = g.op( + "SoftmaxCrossEntropyLoss", + self, + target, + reduction_s=reduction, + ignore_index_i=ignore_index, + ) + else: + celoss = g.op( + "SoftmaxCrossEntropyLoss", + self, + target, + weight, + reduction_s=reduction, + ignore_index_i=ignore_index, + ) + + return celoss + + +@_onnx_symbolic("aten::binary_cross_entropy_with_logits") +@symbolic_helper.parse_args("v", "v", "v", "v", "i") +def binary_cross_entropy_with_logits( + g: jit_utils.GraphContext, input, target, weight, pos_weight, reduction +): + p = g.op("Constant", value_t=torch.tensor([1])) + sig_x = opset9.sigmoid(g, input) + log_sig_x = opset9.log(g, sig_x) + sub_1_x = opset9.sub(g, p, sig_x) + sub_1_y = opset9.sub(g, p, target) + log_1_x = opset9.log(g, sub_1_x) + if pos_weight is None or symbolic_helper._is_none(pos_weight): + output = opset9.neg( + g, + opset9.add( + g, opset9.mul(g, target, log_sig_x), opset9.mul(g, sub_1_y, log_1_x) + ), + ) + else: + output = opset9.neg( + g, + opset9.add( + g, + opset9.mul(g, opset9.mul(g, target, log_sig_x), pos_weight), + opset9.mul(g, sub_1_y, log_1_x), + ), + ) + + if weight is not None and not symbolic_helper._is_none(weight): + output = opset9.mul(g, weight, output) + + reduction = symbolic_helper._maybe_get_const(reduction, "i") + if reduction == 0: + return output + elif reduction == 1: + return g.op("ReduceMean", output, keepdims_i=0) + elif reduction == 2: + return g.op("ReduceSum", output, keepdims_i=0) + else: + return symbolic_helper._onnx_unsupported( + "binary_cross_entropy_with_logits with reduction other than none, mean, or sum", + input, + ) + + +@_onnx_symbolic("aten::celu") +def celu(g: jit_utils.GraphContext, self, alpha): + alpha = symbolic_helper._maybe_get_const(alpha, "f") + # if the input is of type double cast it to float + if ( + _type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED) + == _type_utils.JitScalarType.DOUBLE + ): + self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT) + out = g.op("Celu", self, alpha_f=alpha) + return g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.DOUBLE) + + return g.op("Celu", self, alpha_f=alpha) + + +@_onnx_symbolic("aten::argmax") +@symbolic_helper.parse_args("v", "v", "b") +def argmax( + g: jit_utils.GraphContext, + input: torch._C.Value, + dim: torch._C.Value, + keepdim: bool, +): + return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMax") + + +@_onnx_symbolic("aten::argmin") +@symbolic_helper.parse_args("v", "v", "b") +def argmin( + g: jit_utils.GraphContext, + input: torch._C.Value, + dim: torch._C.Value, + keepdim: bool, +): + return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMin") + + +@_onnx_symbolic("aten::pow") +def pow(g: jit_utils.GraphContext, self, exponent): + return g.op("Pow", self, exponent) + + +@_onnx_symbolic("aten::ge") +def ge(g: jit_utils.GraphContext, input, other): + return g.op("GreaterOrEqual", input, other) + + +@_onnx_symbolic("aten::le") +def le(g: jit_utils.GraphContext, input, other): + return g.op("LessOrEqual", input, other) + + +@_onnx_symbolic("aten::unfold") +@symbolic_helper.parse_args("v", "i", "v", "v") +def unfold(g: jit_utils.GraphContext, input, dimension, size, step): + const_size = symbolic_helper._maybe_get_const(size, "i") + const_step = symbolic_helper._maybe_get_const(step, "i") + if not symbolic_helper._is_value(const_size) and not symbolic_helper._is_value( + const_step + ): + return opset9.unfold(g, input, dimension, const_size, const_step) + + sizedim = symbolic_helper._get_tensor_dim_size(input, dimension) + if sizedim is not None: + low_start = g.op("Constant", value_t=torch.tensor(0)) + low_end = g.op("Constant", value_t=torch.tensor(sizedim)) + hi_end = g.op("Constant", value_t=torch.tensor(sizedim + 1)) + low_indices = g.op("Range", low_start, low_end, step) + hi_indices = g.op("Range", size, hi_end, step) + + low_size = symbolic_helper._size_helper( + g, low_indices, g.op("Constant", value_t=torch.tensor(0)) + ) + hi_size = symbolic_helper._size_helper( + g, hi_indices, g.op("Constant", value_t=torch.tensor(0)) + ) + + ndim = symbolic_helper._get_tensor_rank(input) + assert ndim is not None + perm = list(range(0, ndim)) + perm.append(perm.pop(dimension)) + + unsqueeze_list = [] + loop_condition = g.op("Constant", value_t=torch.tensor(1)) + loop_condition = g.op( + "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL + ) + loop_len = g.op("Min", low_size, hi_size) + + loop, (loop_context,), _ = jit_utils.add_op_with_blocks( + g, "Loop", loop_len, loop_condition, n_blocks=1 + ) + + loop_block = loop_context.block + block_input_iter = utils._add_input_to_block(loop_block) + cond = utils._add_input_to_block(loop_block) # noqa: F841 + + starts = loop_context.op("Gather", low_indices, block_input_iter) + ends = loop_context.op("Gather", hi_indices, block_input_iter) + axes = loop_context.op("Constant", value_t=torch.tensor([2])) + starts = symbolic_helper._unsqueeze_helper(loop_context, starts, [0]) + ends = symbolic_helper._unsqueeze_helper(loop_context, ends, [0]) + stack = loop_context.op("Slice", input, starts, ends, axes) + + unsqueeze = symbolic_helper._unsqueeze_helper( + loop_context, loop_context.op("Transpose", stack, perm_i=perm), [dimension] + ) + unsqueeze_list.append(unsqueeze) + concat = loop_context.op("Concat", *unsqueeze_list, axis_i=0) + + cond_out = loop_context.op( + "Cast", loop_condition, _C_onnx.TensorProtoDataType.BOOL + ) + utils._add_output_to_block(loop_block, cond_out) + utils._add_output_to_block(loop_block, concat) + + loop_output = loop.node().output() + perm = [0, 1, 2, 3, 4] + perm[0], perm[dimension + 1] = perm[dimension + 1], perm[0] + transpose = g.op("Transpose", loop_output, perm_i=perm) + squeeze = symbolic_helper._squeeze_helper(g, transpose, [0]) + + return squeeze + + return symbolic_helper._unimplemented("Unfold", "input size not accessible") + + +@_onnx_symbolic("aten::tensordot") +@symbolic_helper.parse_args("v", "v", "is", "is", "v") +def tensordot(g: jit_utils.GraphContext, input_a, input_b, dims_a, dims_b, out=None): + if out is not None: + symbolic_helper._unimplemented( + "Tensordot", "Out parameter is not supported for tensordot." + ) + + dim_count_a = symbolic_helper._get_tensor_rank(input_a) + if dim_count_a is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of tensordot for tensor(input_a) of unknown rank.", + input_a, + ) + + dim_count_b = symbolic_helper._get_tensor_rank(input_b) + if dim_count_b is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of tensordot for tensor(input_b) of unknown rank.", + input_b, + ) + + dims_a = [ + (dims_a[i] + dim_count_a) if (dims_a[i] < 0) else dims_a[i] + for i in range(len(dims_a)) + ] + dims_b = [ + (dims_b[i] + dim_count_b) if (dims_b[i] < 0) else dims_b[i] + for i in range(len(dims_b)) + ] + + left_dims_a = [i for i in range(dim_count_a) if (i not in dims_a)] + left_dims_b = [i for i in range(dim_count_b) if (i not in dims_b)] + + new_input_a = opset9.permute(g, input_a, left_dims_a + dims_a) + new_input_b = opset9.permute(g, input_b, dims_b + left_dims_b) + + input_shape = g.op("Shape", new_input_a) + left_sizes_a = symbolic_helper._slice_helper( + g, input_shape, axes=[0], starts=[0], ends=[len(left_dims_a)] + ) + shape_sizes = [ + left_sizes_a, + g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)), + ] + output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes) + + input_shape = g.op("Shape", output_a) + slices = symbolic_helper._slice_helper( + g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize] + ) + shape_sizes = [ + g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)), + slices, + ] + output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes) + + input_shape = g.op("Shape", new_input_b) + left_sizes_b = symbolic_helper._slice_helper( + g, input_shape, axes=[0], starts=[len(dims_b)], ends=[sys.maxsize] + ) + slices = symbolic_helper._slice_helper( + g, input_shape, axes=[0], starts=[0], ends=[len(dims_b)] + ) + shape_sizes = [ + slices, + g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)), + ] + output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes) + + input_shape = g.op("Shape", output_b) + slices = symbolic_helper._slice_helper( + g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize] + ) + shape_sizes = [ + g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)), + slices, + ] + output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes) + + output = einsum(g, "ij,jk->ik", g.op("prim::ListConstruct", *[output_a, output_b])) + + shape_sizes = [left_sizes_a, left_sizes_b] + return opset9._reshape_from_tensor(g, output, shape_sizes) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/onnx/symbolic_opset13.py b/torch/onnx/symbolic_opset13.py index 18aff9295be8c..9aaef2b5a36a1 100644 --- a/torch/onnx/symbolic_opset13.py +++ b/torch/onnx/symbolic_opset13.py @@ -1,3 +1,4 @@ +<<<<<<< HEAD """Backward compatibility module for torch.onnx.symbolic_opset13.""" from __future__ import annotations @@ -6,3 +7,1118 @@ __all__: list[str] = [] from torch.onnx._internal.torchscript_exporter.symbolic_opset13 import * # noqa: F401,F403 +======= +# mypy: allow-untyped-defs +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in README.md + +# This file exports ONNX ops for opset 13 +import functools + +import torch +import torch._C._onnx as _C_onnx +from torch.onnx import ( + _constants, + _type_utils, + errors, + symbolic_helper, + symbolic_opset11 as opset11, + symbolic_opset9 as opset9, + utils, +) +from torch.onnx._internal import jit_utils, registration + + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=13) + + +@_onnx_symbolic("aten::softmax") +@symbolic_helper.parse_args("v", "i", "none") +def softmax(g: jit_utils.GraphContext, input, dim, dtype=None): + softmax = g.op("Softmax", input, axis_i=dim) + if dtype and dtype.node().kind() != "prim::Constant": + parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") + softmax = g.op( + "Cast", softmax, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type() + ) + + return softmax + + +@_onnx_symbolic("aten::log_softmax") +@symbolic_helper.parse_args("v", "i", "none") +def log_softmax(g: jit_utils.GraphContext, input, dim, dtype=None): + return_op = g.op("LogSoftmax", input, axis_i=dim) + if dtype and dtype.node().kind() != "prim::Constant": + parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") + return_op = g.op( + "Cast", return_op, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type() + ) + return return_op + + +@_onnx_symbolic("aten::frobenius_norm") +@symbolic_helper.parse_args("v", "v", "i") +def frobenius_norm(g: jit_utils.GraphContext, self, dim=None, keepdim=False): + dim_val = symbolic_helper._maybe_get_const(dim, "is") + if not symbolic_helper._is_value(dim_val) and len(dim_val) == 0: + return g.op("ReduceL2", self, keepdims_i=0) + sqr = g.op("Mul", self, self) + sumsqr = symbolic_helper._reducesum_helper(g, sqr, dim, keepdims_i=keepdim) + return g.op("Sqrt", sumsqr) + + +@_onnx_symbolic("aten::split") +@symbolic_helper.parse_args("v", "v", "i", "i") +def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None): + if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs): + split_out = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim) + if _outputs is None: + return split_out + # Convert to multiple slice nodes iff number of splits and number of outputs are statically known. + if ( + symbolic_helper._is_packed_list(split_size_or_sizes) + and len(symbolic_helper._unpack_list(split_size_or_sizes)) == _outputs + ): + split_sizes = [ + symbolic_helper._unsqueeze_helper(g, v, [0]) + for v in symbolic_helper._unpack_list(split_size_or_sizes) + ] + + start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) + axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) + res = [] + for i in range(_outputs): + end = g.op( + "Add", start, split_sizes[i] + ) # split_sizes is a list of same length as _outputs + res.append(g.op("Slice", self, start, end, axis)) + start = end + return res + return [ + g.op( + "SequenceAt", + split_out, + g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)), + ) + for i in range(_outputs) + ] + + split_val = symbolic_helper._node_get(split_size_or_sizes.node(), "value") + if split_val.dim() > 0: + return g.op("Split", self, split_size_or_sizes, axis_i=dim, outputs=_outputs) + split_size = symbolic_helper._get_const(split_size_or_sizes, "i", "split_size") + + size = symbolic_helper._get_tensor_dim_size(self, dim) + if size is None: + if _outputs is not None: + size = split_size * _outputs + else: + raise errors.SymbolicValueError( + "Unknown dimension size not supported", self + ) + splits = [split_size] * (size // split_size) + leftover = size % split_size + if leftover: + splits.append(leftover) + splits = g.op("Constant", value_t=torch.tensor(splits)) + return g.op("Split", self, splits, axis_i=dim, outputs=_outputs) + + +@_onnx_symbolic("aten::split_with_sizes") +def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None): + return split(g, self, split_sizes, dim, _outputs) + + +@_onnx_symbolic("aten::unsafe_split") +def unsafe_split( + g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None +): + return split(g, self, split_size_or_sizes, dim, _outputs) + + +@_onnx_symbolic("aten::unsafe_split_with_sizes") +def unsafe_split_with_sizes( + g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None +): + return split_with_sizes(g, self, split_sizes, dim, _outputs) + + +@_onnx_symbolic("aten::tensor_split") +@symbolic_helper.parse_args("v", "v", "i", "i") +def tensor_split( + g: jit_utils.GraphContext, self, indices_or_sections, dim, _outputs=None +): + axis = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)) + axis = opset11.unsqueeze(g, axis, 0) + const_1 = g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)) + + if symbolic_helper._is_split_static(indices_or_sections, _outputs): + split_val = symbolic_helper._node_get(indices_or_sections.node(), "value") + + if split_val.dim() > 0: + start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) + res = [] + assert _outputs is not None + for i in range(_outputs - 1): + end = g.op( + "Gather", + indices_or_sections, + g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)), + axis_i=0, + ) + res.append(g.op("Slice", self, start, end, axis)) + start = end + + end = symbolic_helper._size_helper(g, self, axis) + res.append(g.op("Slice", self, start, end, axis)) + return res + + split_size = symbolic_helper._get_const( + indices_or_sections, "i", "indices_or_sections" + ) + + size = symbolic_helper._get_tensor_dim_size(self, dim) + if size is None: + if _outputs is not None: + size = split_size * _outputs + else: + raise errors.SymbolicValueError( + "Unknown dimension size not supported", self + ) + + min_split_size = size // split_size + num_splits_one_extra = size % split_size + + splits = num_splits_one_extra * [min_split_size + 1] + leftover = (split_size - num_splits_one_extra) * [min_split_size] + + splits = g.op( + "Constant", value_t=torch.tensor(splits + leftover, dtype=torch.long) + ) + return g.op("Split", self, splits, axis_i=dim, outputs=_outputs) + + if ( + symbolic_helper._is_tensor(indices_or_sections) + and symbolic_helper._get_tensor_rank(indices_or_sections) == 1 + ): + loop_len = symbolic_helper._size_helper( + g, indices_or_sections, g.op("Constant", value_t=torch.tensor(0)) + ) + loop_len = opset11.unsqueeze(g, loop_len, 0) + loop_condition = g.op("Cast", const_1, to_i=_C_onnx.TensorProtoDataType.BOOL) + + # To make the first slice in the below loop work, + # we pad a zero to the first position so that it will be the initial start of slice. + padding_0 = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) + indices_or_sections = g.op("Concat", padding_0, indices_or_sections, axis_i=0) + + final_splits = g.op("SequenceEmpty") + # Loop inputs + loop, (loop_context,), _ = jit_utils.add_op_with_blocks( + g, "Loop", loop_len, loop_condition, final_splits, outputs=1, n_blocks=1 + ) + + loop_block = loop_context.block + block_input_iter = utils._add_input_to_block(loop_block) + cond = utils._add_input_to_block(loop_block) # noqa: F841 + final_splits = utils._add_input_to_block(loop_block) + + start = loop_context.op( + "Gather", indices_or_sections, block_input_iter, axis_i=0 + ) + end = loop_context.op( + "Gather", + indices_or_sections, + loop_context.op("Add", block_input_iter, const_1), + axis_i=0, + ) + + slice = loop_context.op("Slice", self, start, end, axis) + final_splits = loop_context.op("SequenceInsert", final_splits, slice) + + # Loop outputs + cond_out = loop_context.op("Identity", loop_condition) + utils._add_output_to_block(loop_block, cond_out) + utils._add_output_to_block(loop_block, final_splits) + + loop_out = loop.node().output() + start = g.op( + "Gather", + indices_or_sections, + g.op("Constant", value_t=torch.tensor(-1, dtype=torch.long)), + axis_i=0, + ) + start = opset11.unsqueeze(g, start, 0) + end = symbolic_helper._size_helper(g, self, axis) + + last_slice = g.op("Slice", self, start, end, axis) + + return g.op("SequenceInsert", loop_out, last_slice) + + else: # scalar tensor + dim_size = symbolic_helper._size_helper(g, self, axis) + min_split_size = g.op("Div", dim_size, indices_or_sections) + min_split_size_plus_1 = g.op( + "Add", + min_split_size, + const_1, + ) + num_splits_one_extra = g.op("Mod", dim_size, indices_or_sections) + splits = g.op("Tile", min_split_size_plus_1, num_splits_one_extra) + leftover = g.op( + "Tile", + min_split_size, + g.op( + "Sub", + opset11.unsqueeze(g, indices_or_sections, 0), + num_splits_one_extra, + ), + ) + + splits = g.op("Concat", splits, leftover, axis_i=0) + if _outputs is None: + return g.op("SplitToSequence", self, splits, axis_i=dim) + return g.op("Split", self, splits, axis_i=dim, outputs=_outputs) + + +@_onnx_symbolic("aten::unbind") +@symbolic_helper.parse_args("v", "i", "i") +def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None): + if _outputs is None: + return g.op( + "SplitToSequence", + self, + g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)), + axis_i=dim, + keepdims_i=0, + ) + + splits = g.op("Constant", value_t=torch.tensor([1] * _outputs)) + outputs = g.op("Split", self, splits, axis_i=dim, outputs=_outputs) + outputs = [outputs] if _outputs == 1 else outputs + squeezed_outputs = [ + g.op("Squeeze", out, g.op("Constant", value_t=torch.tensor([dim]))) + for out in outputs + ] + return squeezed_outputs + + +@_onnx_symbolic("aten::nonzero_numpy") +# Emitted from `torch.nonzero(x, as_tuple=True)` +def nonzero_numpy(g: jit_utils.GraphContext, input, _outputs=None): + return unbind(g, opset9.nonzero(g, input), 1, _outputs=_outputs) + + +@_onnx_symbolic("aten::where") +@symbolic_helper.parse_args("v", "v", "v", "i") +def where(g: jit_utils.GraphContext, condition, self=None, other=None, _outputs=None): + # Assumes that torch.where's first argument takes only Bool and Byte tensors. + if not symbolic_helper._is_bool(condition): + condition = g.op("Cast", condition, to_i=_C_onnx.TensorProtoDataType.BOOL) + if self is None: + condition = opset9.nonzero(g, condition) + return symbolic_helper._unbind_helper( + g, condition, g.op("Constant", value_t=torch.tensor(1)), _outputs + ) + return g.op("Where", condition, self, other) + + +@_onnx_symbolic("aten::fake_quantize_per_channel_affine") +@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i") +def fake_quantize_per_channel_affine( + g: jit_utils.GraphContext, + inputs, + scale, + zero_point, + axis, + quant_min=-128, + quant_max=127, +): + # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127). + # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 + if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]: + raise errors.SymbolicValueError( + "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). " + f"Got ({quant_min}, {quant_max})", + inputs, + ) + # ONNX defines zero_point to be int8 or uint8 + if quant_min == 0: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8) + else: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8) + quantized = g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=axis) + if (quant_min, quant_max) == (0, 127): + quantized = g.op( + "Clip", + quantized, + opset9.unused(g), + g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)), + ) + return g.op("DequantizeLinear", quantized, scale, zero_point, axis_i=axis) + + +@_onnx_symbolic("aten::fake_quantize_per_tensor_affine") +@symbolic_helper.parse_args("v", "v", "v", "i", "i") +def fake_quantize_per_tensor_affine( + g: jit_utils.GraphContext, + inputs, + scale, + zero_point, + quant_min=-128, + quant_max=127, +): + # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127). + # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 + if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]: + raise errors.SymbolicValueError( + "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). " + f"Got ({quant_min}, {quant_max})", + inputs, + ) + if quant_min == 0: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8) + else: + zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8) + if ( + _type_utils.JitScalarType.from_value(scale, _type_utils.JitScalarType.UNDEFINED) + != _type_utils.JitScalarType.FLOAT + ): + scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT) + quantized = g.op("QuantizeLinear", inputs, scale, zero_point) + if (quant_min, quant_max) == (0, 127): + quantized = g.op( + "Clip", + quantized, + opset9.unused(g), + g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)), + ) + return g.op("DequantizeLinear", quantized, scale, zero_point) + + +def _reduce_op_symbolic(onnx_op_name): + def symbolic(g, self, dim=None, keepdim=None): + self = symbolic_helper._maybe_cast_reduce_op_input(g, self) + if dim is None: + # all-reduce path + return symbolic_helper._handle_reduce_dim_none(g, self, onnx_op_name) + else: + keepdim = symbolic_helper._get_const(keepdim, "i", "keepdim") + return g.op(onnx_op_name, self, dim, keepdims_i=keepdim) + + return symbolic + + +@_onnx_symbolic( + "aten::sum", + decorate=[symbolic_helper._apply_params("ReduceSum", "sum")], +) +def _reduce_with_dtype(onnx_op, name): + symbolic = _reduce_op_symbolic(onnx_op) + + @symbolic_helper._overload_by_arg_count + def reduce(g, *args, **kwargs): + @symbolic_helper.parse_args("v", "none") + def reduce_nodim(g, self, dtype): + dtype_onnx = None + if dtype.node().kind() == "onnx::Constant": + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type() + self = g.op("Cast", self, to_i=dtype_onnx) + elif dtype.node().kind() != "prim::Constant": + return symbolic_helper._unimplemented(name, "dtype", dtype) + result = symbolic(g, self) + if dtype_onnx is not None: + result_dtype_onnx = _type_utils.JitScalarType.from_value( + result + ).onnx_type() + if result_dtype_onnx != dtype_onnx: + result = g.op("Cast", result, to_i=dtype_onnx) + return result + + @symbolic_helper.parse_args("v", "v", "i", "none") + def reduce_dim(g, self, dim, keepdim, dtype): + dtype_onnx = None + if dtype.node().kind() == "onnx::Constant": + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + dtype_onnx = _type_utils.JitScalarType(dtype).onnx_type() + self = g.op("Cast", self, to_i=dtype_onnx) + elif dtype.node().kind() != "prim::Constant": + return symbolic_helper._unimplemented(name, "dtype", dtype) + result = symbolic(g, self, dim, keepdim) + if dtype_onnx is not None: + result_dtype_onnx = _type_utils.JitScalarType.from_value( + result + ).onnx_type() + if result_dtype_onnx != dtype_onnx: + result = g.op("Cast", result, to_i=dtype_onnx) + return result + + return reduce_nodim, reduce_dim + + return reduce + + +# Ported from +# https://github.com/microsoft/onnxscript/blob/6b1b81700b4523f31d8c6d3321e5d8ef5d42b764/onnxscript/function_libs/torch_aten/ops/core.py#L6097 +# NOTE: Supporting aten::unflatten before opset13 needs helper function to adjust ONNX op changes in Concat, Slice, ... +@_onnx_symbolic("aten::unflatten") +def unflatten(g: jit_utils.GraphContext, input, dim, unflattened_size): + input_dim = symbolic_helper._get_tensor_rank(input) + if input_dim is None: + return symbolic_helper._unimplemented( + "dim", + "ONNX and PyTorch use different strategies to split the input. " + "Input rank must be known at export time.", + ) + + # dim could be negative + input_dim = g.op("Constant", value_t=torch.tensor([input_dim], dtype=torch.int64)) + dim = g.op("Add", input_dim, dim) + dim = g.op("Mod", dim, input_dim) + + input_size = g.op("Shape", input) + + head_start_idx = g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)) + head_end_idx = g.op( + "Reshape", dim, g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64)) + ) + head_part_rank = g.op("Slice", input_size, head_start_idx, head_end_idx) + + dim_plus_one = g.op( + "Add", dim, g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64)) + ) + tail_start_idx = g.op( + "Reshape", + dim_plus_one, + g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64)), + ) + tail_end_idx = g.op( + "Constant", value_t=torch.tensor([_constants.INT64_MAX], dtype=torch.int64) + ) + tail_part_rank = g.op("Slice", input_size, tail_start_idx, tail_end_idx) + + final_shape = g.op( + "Concat", head_part_rank, unflattened_size, tail_part_rank, axis_i=0 + ) + + return symbolic_helper._reshape_helper(g, input, final_shape) + + +@_onnx_symbolic("aten::unsafe_chunk") +@symbolic_helper.parse_args("v", "i", "i", "i") +def unsafe_chunk(g: jit_utils.GraphContext, self, chunks, dim, _outputs=None): + if _outputs is None: + return g.op( + "SplitToSequence", + self, + g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)), + axis_i=dim, + keepdims_i=0, + ) + + size = symbolic_helper._get_tensor_dim_size(self, dim) + if size is None: + return symbolic_helper._unimplemented("unsafe_chunk", "unknown dimension size") + split_size = (size + chunks - 1) // chunks + splits = [split_size] * (size // split_size) + leftover = size % split_size + if leftover: + splits.append(leftover) + + # TODO: So far we don"t have a module using this method. We"ll keep + # this as a constant unless we see a request of dynamics in any + # user's modules. + splits = g.op("Constant", value_t=torch.tensor(splits, dtype=torch.long)) + return g.op("Split", self, splits, axis_i=dim, outputs=_outputs) + + +@_onnx_symbolic("aten::tile") +def tile(g: jit_utils.GraphContext, self, dims): + self_shape = g.op("Shape", self) + self_rank = g.op("Size", self_shape) + dims_rank = g.op("Size", dims) + diff = g.op("Sub", self_rank, dims_rank) + const_zero = g.op("Constant", value_t=torch.tensor([0])) + + # 1. If dims is shorter than self.shape pad dims with 1 + dims_shorter_than_self_shape = g.op("Greater", diff, const_zero) + ( + if_op_greater, + (if_context_greater, else_context_greater), + _, + ) = jit_utils.add_op_with_blocks( + g, "If", dims_shorter_than_self_shape, n_blocks=2, outputs=1 + ) + const_one = if_context_greater.op("Constant", value_t=torch.LongTensor([1])) + diff_1d_greater = if_context_greater.op("Reshape", diff, const_one) + exapnd_ones_greater = if_context_greater.op("Expand", const_one, diff_1d_greater) + dims_ = if_context_greater.op("Concat", exapnd_ones_greater, dims, axis_i=0) + utils._add_output_to_block(if_context_greater.block, dims_) + identity_dim = else_context_greater.op("Identity", dims) + utils._add_output_to_block(else_context_greater.block, identity_dim) + dims_final = if_op_greater.node().output() + + # 2. If dims is longer than self.shape pad self.shape with 1 + dims_longer_than_self_shape = g.op("Less", diff, const_zero) + ( + if_op_less, + (if_context_less, else_context_less), + _, + ) = jit_utils.add_op_with_blocks( + g, "If", dims_longer_than_self_shape, n_blocks=2, outputs=1 + ) + const_one = if_context_less.op("Constant", value_t=torch.LongTensor([1])) + diff_1d_less = if_context_less.op( + "Reshape", + if_context_less.op("Abs", diff), + const_one, + ) + exapnd_ones_less = if_context_less.op("Expand", const_one, diff_1d_less) + self_final_shape = if_context_less.op( + "Concat", exapnd_ones_less, self_shape, axis_i=0 + ) + self_ = if_context_less.op("Reshape", self, self_final_shape) + utils._add_output_to_block(if_context_less.block, self_) + identity_self = else_context_less.op("Identity", self) + utils._add_output_to_block(else_context_less.block, identity_self) + self_final = if_op_less.node().output() + + dims_final = g.op("Cast", dims_final, to_i=_C_onnx.TensorProtoDataType.INT64) + return g.op("Tile", self_final, dims_final) + + +@_onnx_symbolic("aten::repeat_interleave") +def repeat_interleave( + g: jit_utils.GraphContext, self, repeats, dim=None, output_size=None +): + repeats_dim = symbolic_helper._get_tensor_rank(repeats) + repeats_sizes = symbolic_helper._get_tensor_sizes(repeats) + input_sizes = symbolic_helper._get_tensor_sizes(self) + if repeats_dim is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of repeat_interleave for unknown repeats rank.", + self, + ) + if repeats_sizes is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of repeat_interleave for unknown repeats size.", + self, + ) + if input_sizes is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of repeat_interleave for unknown input size.", + self, + ) + + final_dim = dim + # if dim is None flatten + # By default, use the flattened input array, and return a flat output array + if symbolic_helper._is_none(dim): + self = symbolic_helper._reshape_helper( + g, self, g.op("Constant", value_t=torch.tensor([-1])) + ) + dim = torch.tensor(0, dtype=torch.int64) + else: + dim = symbolic_helper._maybe_get_scalar(dim) + + # Handle cases where dim is negative + if dim < 0: + dim += len(input_sizes) + + output_sizes = input_sizes.copy() + for idx, input_size in enumerate(input_sizes): + if input_size is None: + output_sizes[idx], input_sizes[idx] = 0, -1 + + # Check if all indices should be repeated the same number of times. + if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1): + return symbolic_helper._repeat_interleave_single_value_repeat_helper( + g, self, repeats, dim + ) + + cond_dynamic_repeats = repeats_dim == 1 and repeats_sizes[0] is None + # If input size is dynamic or repeats vector is dynamic + if output_sizes[dim] == 0 or cond_dynamic_repeats: + reps = symbolic_helper._size_helper(g, self, dim) + reps = opset11.unsqueeze(g, reps, 0) + + # Check if repeats is dynamic + # As repeats is dynamic, we use a where node as a substitute for the if statement + # If repests_dim = 1, expand repeats otherwise use original tensor + if cond_dynamic_repeats: + repeat_dim = symbolic_helper._size_helper( + g, repeats, g.op("Constant", value_t=torch.LongTensor([0])) + ) + repeat_cond = g.op( + "Equal", repeat_dim, g.op("Constant", value_t=torch.LongTensor([1])) + ) + repeats = where(g, repeat_cond, g.op("Expand", repeats, reps), repeats) + # There are cases when the repeats are 1-d tensor with multiple repeats, but dim + # provided along one of the dynamic axes provided. A simple example would be + # input.shape -> [1, 1, *] where * represents the dynamic axes, and dim = 2 + # Now, repeat interleaving can be performed in pytorch when the value of * matches + # with the number of elements in repeat, for example if * -> 2, number of repeats + # should be 2 as well. + else: + return opset9.repeat_interleave(g, self, repeats, final_dim) + + reps_like = g.op( + "ConstantOfShape", + g.op("Shape", repeats), + value_t=torch.tensor([1], dtype=torch.long), + ) + r_splits = split(g, repeats, reps_like, 0) + i_splits = split(g, self, reps_like, dim) + + output_sizes[dim], input_sizes[dim] = -1, 1 + + # Create a loop to iterate over each value along the dimension + # and perform individual interleaving using the repeats tensor + # Loop is of the following pattern + # input (trip_count, cond) + # int trip_count = ...; + # bool cond = ...; + # for (int i=0; i < trip_count && cond; ++i) { + # cond = ...; + # } + + # Loop conditions + loop_condition = g.op("Constant", value_t=torch.tensor(1)) + loop_condition = g.op("Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL) + loop_len = reps + + # Create an empty sequence to store final expansions + final_splits = g.op("SequenceEmpty") + + # Loop inputs + loop, (loop_context,), _ = jit_utils.add_op_with_blocks( + g, "Loop", loop_len, loop_condition, final_splits, n_blocks=1 + ) + + loop_block = loop_context.block + block_input_iter = utils._add_input_to_block(loop_block) + cond = utils._add_input_to_block(loop_block) # noqa: F841 + final_splits = utils._add_input_to_block(loop_block) + + r_split = loop_context.op("SequenceAt", r_splits, block_input_iter) + i_split = loop_context.op("SequenceAt", i_splits, block_input_iter) + + i_split = opset11.unsqueeze(loop_context, i_split, dim + 1) + r_concat = [ + loop_context.op("Constant", value_t=torch.LongTensor(input_sizes[: dim + 1])), + r_split, + loop_context.op("Constant", value_t=torch.LongTensor(input_sizes[dim + 1 :])), + ] + r_concat = loop_context.op("Concat", *r_concat, axis_i=0) + i_split = opset9.expand(loop_context, i_split, r_concat, None) + i_split = symbolic_helper._reshape_helper( + loop_context, i_split, g.op("Constant", value_t=torch.LongTensor(output_sizes)) + ) + final_splits = loop_context.op("SequenceInsert", final_splits, i_split) + + # Loop outputs + cond_out = loop_context.op( + "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL + ) + utils._add_output_to_block(loop_block, cond_out) + utils._add_output_to_block(loop_block, final_splits) + + loop_out = loop.node().output() + loop_out = g.op("ConcatFromSequence", loop_out, axis_i=dim) + return loop_out + + +@_onnx_symbolic("aten::diagonal") +@symbolic_helper.parse_args("v", "i", "i", "i") +def diagonal(g: jit_utils.GraphContext, self, offset, dim1, dim2): + rank = symbolic_helper._get_tensor_rank(self) + # Replace negative indexing when rank is known + if rank is not None: + dim1 = dim1 if dim1 >= 0 else dim1 + rank + dim2 = dim2 if dim2 >= 0 else dim2 + rank + + dim1_size = opset9.size( + g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim1])) + ) + dim2_size = opset9.size( + g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim2])) + ) + # Create appropriate mask + mask_shape = g.op("Concat", dim1_size, dim2_size, axis_i=0) + mask = opset9.zeros(g, mask_shape, None, None, None) + mask = g.op("EyeLike", mask, k_i=offset) + # dim1 and dim2 appended as a dimension at the end of the shape + + if rank is not None: + axes = list(range(rank)) + axes.remove(dim1) + axes.remove(dim2) + self = g.op("Transpose", self, perm_i=axes + [dim1, dim2]) + else: + return symbolic_helper._unimplemented("diagonal", "unknown input rank") + + # Multiply input and mask to calculate values along diagonal + # The mask consists of one values where diagonal values are to be calculated + # For example: + # [[1.1, 1.2, 1.3], * [[1, 0, 0] = [[1.1, 0, 0], + # [2.1, 2.2, 2.3], [0, 1, 0] [0, 2.2, 0], + # [3.1, 3.2, 3.3]] [0, 0, 1]] [0, 0, 3.3]] + result = g.op("Mul", self, mask) + result = symbolic_helper._reducesum_helper(g, result, axes_i=[-1], keepdims_i=0) + + # Calculate gather indices based on offset and dims + # If offset is greater than zero, set offset to zero as this aids in + # calculation of selection window + offset_op = g.op("Constant", value_t=torch.LongTensor([offset])) + if offset >= 0: + diag_size = g.op( + "Max", + g.op("Min", dim1_size, g.op("Sub", dim2_size, offset_op)), + g.op("Constant", value_t=torch.LongTensor([0])), + ) + offset = 0 + else: + diag_size = g.op( + "Max", + g.op("Min", g.op("Add", dim1_size, offset_op), dim2_size), + g.op("Constant", value_t=torch.LongTensor([0])), + ) + diag_size = g.op("Concat", diag_size, axis_i=0) + + # Calculate which diagonal values to select + # For example, in cases with offsets: + # [[0, 1.1, 0] + # [0, 0, 2.2]] + # we need to select the last two columns, so we create a tensor + # with all columns that are to be selected + # So in this example, it is [1, 2] + select_window_ones_fill = opset9.ones(g, diag_size, 4, None, None) + select_window = g.op( + "CumSum", + select_window_ones_fill, + g.op("Constant", value_t=torch.LongTensor([0])), + ) + select_window = g.op( + "Add", + select_window, + g.op("Constant", value_t=torch.LongTensor([abs(offset) - 1])), + ) + + gather_shape = [ + opset9.size(g, result, dim=g.op("Constant", value_t=torch.LongTensor([axis]))) + for axis in list(range(rank))[:-2] + ] + gather_shape.append(diag_size) + gather_shape = g.op("Concat", *gather_shape, axis_i=0) + gather_indices = opset9.zeros(g, gather_shape, 4, None, None) + + # There might be cases where offset value is greater than number of rows/columns + # and might cause the diagonal to overrun and as a result of this, diag_size would be zero. + # For example, if + # offset = 9, dim1_size = 2 (columns), dim2_size = 4 (rows) + # diag_size = max(min(2, (4-9)), 0) = 0, based on calculation above + # Cases with diagonal overrun always result in diag_size = max(0, -ve value) = 0 + # In cases without diagonal overrun, we select the appropriate rows/columns along which we + # are calculating diagonal values. In cases with diagonal overrun, we return a tensor which has + # the dimension of the row/column where overrun occurred as 0-dim, as we are essentially + # returning an empty tensor + overrun_cond = g.op( + "Not", + g.op( + "Equal", + diag_size, + g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)), + ), + ) + + if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks( + g, "If", overrun_cond, n_blocks=2 + ) + + gather_indices_if_block = if_context.op("Add", gather_indices, select_window) + gather_indices_if_block = symbolic_helper._unsqueeze_helper( + if_context, gather_indices_if_block, [rank - 1] + ) + final_non_overrun = if_context.op( + "GatherND", result, gather_indices_if_block, batch_dims_i=rank - 2 + ) + final_overrun = opset9.zeros(else_context, gather_shape, 6, None, None) + utils._add_output_to_block(if_context.block, final_non_overrun) + utils._add_output_to_block(else_context.block, final_overrun) + return if_op + + +# Quantized ops + + +@_onnx_symbolic("quantized::linear") +def quantized_linear( + g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.linear(g, input, weight, bias) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::linear_relu") +def quantized_linear_relu( + g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.linear(g, input, weight, bias) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv1d_relu") +def quantized_conv1d_relu( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv2d_relu") +def quantized_conv2d_relu( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv3d_relu") +def quantized_conv3d_relu( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv3d(g, input, weight, bias, stride, padding, dilation, groups) + output = opset9.relu(g, output) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv1d") +def quantized_conv1d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv2d") +def quantized_conv2d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv3d") +def quantized_conv3d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv3d(g, input, weight, bias, stride, padding, dilation, groups) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv_transpose1d") +def quantized_conv_transpose1d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + output_padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv_transpose2d( + g, input, weight, bias, stride, padding, output_padding, groups, dilation + ) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv_transpose2d") +def quantized_conv_transpose2d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + output_padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv_transpose2d( + g, input, weight, bias, stride, padding, output_padding, groups, dilation + ) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +@_onnx_symbolic("quantized::conv_transpose3d") +def quantized_conv_transpose3d( + g: jit_utils.GraphContext, + q_input, + q_weight, + bias, + stride, + padding, + output_padding, + dilation, + groups, + op_scale, + op_zero_point, +): + input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input) + weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight) + q_bias = symbolic_helper.requantize_bias_helper( + g, bias, input_scale, weight_scale, axis + ) + bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias) + + output = opset9.conv_transpose3d( + g, input, weight, bias, stride, padding, output_padding, groups, dilation + ) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/onnx/symbolic_opset14.py b/torch/onnx/symbolic_opset14.py index 367aa9eb0832a..106566e26c05d 100644 --- a/torch/onnx/symbolic_opset14.py +++ b/torch/onnx/symbolic_opset14.py @@ -1,3 +1,4 @@ +<<<<<<< HEAD """Backward compatibility module for torch.onnx.symbolic_opset14.""" from __future__ import annotations @@ -6,3 +7,290 @@ __all__: list[str] = [] from torch.onnx._internal.torchscript_exporter.symbolic_opset14 import * # noqa: F401,F403 +======= +# mypy: allow-untyped-defs +# mypy: disable-error-code=arg-type +"""This file exports ONNX ops for opset 14. + +Note [ONNX operators that are added/updated in opset 14] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +New operators: + HardSwish, Trilu + +Updated operators: + Reshape + Add, Sub, Mul, Div + GRU, LSTM, RNN + BatchNorm, Cumsum, Relu +""" + +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in README.md +from __future__ import annotations + +import functools + +import torch +from torch.onnx import _constants, _type_utils, symbolic_helper +from torch.onnx._globals import GLOBALS +from torch.onnx._internal import jit_utils, registration + + +__all__ = [ + "hardswish", + "tril", + "triu", + "reshape", + "batch_norm", + "quantized_hardswish", + "scaled_dot_product_attention", +] + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=14) + + +@_onnx_symbolic("aten::hardswish") +@symbolic_helper.parse_args("v") +def hardswish(g: jit_utils.GraphContext, self): + return g.op("HardSwish", self) + + +@_onnx_symbolic("aten::tril") +def tril(g: jit_utils.GraphContext, self, diagonal, out=None): + return g.op("Trilu", self, diagonal, upper_i=0) + + +@_onnx_symbolic("aten::triu") +def triu(g: jit_utils.GraphContext, self, diagonal, out=None): + return g.op("Trilu", self, diagonal, upper_i=1) + + +@_onnx_symbolic("aten::reshape") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "v") +def reshape(g: jit_utils.GraphContext, self, shape): + # NOTE: Due to bug in ORT https://github.com/microsoft/onnxruntime/issues/10664 + # Reshape export cannot utilize the new allowzero attribute introduced in opset 14. + return symbolic_helper._reshape_helper(g, self, shape, allowzero=0) + + +@_onnx_symbolic("aten::batch_norm") +@symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i") +def batch_norm( + g: jit_utils.GraphContext, + input, + weight, + bias, + running_mean, + running_var, + training, + momentum, + eps, + cudnn_enabled, +): + if ( + torch.is_autocast_enabled() + and not symbolic_helper.args_have_same_dtype( + [input, weight, bias, running_mean, running_var] + ) + and GLOBALS.export_onnx_opset_version < 15 + ): + return symbolic_helper._onnx_opset_unsupported_detailed( + "BatchNormalization", + 14, + 15, + "All input tensors must have the same `dtype`." + " Turn off Autocast or export using opset version 15.", + input, + ) + + symbolic_helper.check_training_mode(training, "batch_norm") + weight, bias, running_mean, running_var = symbolic_helper._batchnorm_helper( + g, input, weight, bias, running_mean, running_var + ) + out = g.op( + "BatchNormalization", + input, + weight, + bias, + running_mean, + running_var, + epsilon_f=eps, + momentum_f=1 - momentum, + training_mode_i=0 if not training else 1, + outputs=1 if not training else 3, + ) + if not training: + return out + else: + res, new_running_mean, new_running_var = out + new_running_mean.setType(running_mean.type()) + new_running_var.setType(running_var.type()) + return res + + +@_onnx_symbolic("quantized::hardswish") +def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + + output = hardswish(g, x) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +# Ported from +# https://github.com/microsoft/onnxscript/blob/6b1b81700b4523f31d8c6d3321e5d8ef5d42b764/onnxscript/function_libs/torch_aten/ops/nn.py#L1504 +# aten_scaled_dot_product_attention +# NOTE: Need op.Trilu +@_onnx_symbolic("aten::scaled_dot_product_attention") +@symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "b") +def scaled_dot_product_attention( + g: jit_utils.GraphContext, + query: torch._C.Value, + key: torch._C.Value, + value: torch._C.Value, + attn_mask: torch._C.Value | None = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: torch._C.Value | None = None, + enable_gqa: bool = False, +): + assert (not is_causal) or (is_causal and symbolic_helper._is_none(attn_mask)), ( + "is_causal and attn_mask cannot be set at the same time" + ) + assert not enable_gqa, ( + "conversion of scaled_dot_product_attention not implemented if enable_gqa is True" + ) + + if symbolic_helper._is_none(scale): + scale = _attention_scale(g, query) + + if is_causal: + attn_mask = _causal_attention_mask(g, query, key) + + # Swap the last two axes of key + # NOTE: onnx-script has different logic here, because the attribute perms in + # transpose needs list of ints + key_shape_builtin = symbolic_helper._get_tensor_rank(key) + key_transposed_axes = list(range(key_shape_builtin)) + key_transposed_axes[-1], key_transposed_axes[-2] = ( + key_transposed_axes[-2], + key_transposed_axes[-1], + ) + key_transposed = g.op("Transpose", key, perm_i=key_transposed_axes) + + # https://github.com/pytorch/pytorch/blob/12da0c70378b5be9135c6fda62a9863bce4a4818/aten/src/ATen/native/transformers/attention.cpp#L653 + # Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math + query_scaled = g.op("Mul", query, g.op("Sqrt", scale)) + key_transposed_scaled = g.op("Mul", key_transposed, g.op("Sqrt", scale)) + mul_qk = g.op("MatMul", query_scaled, key_transposed_scaled) + + if symbolic_helper._is_none(attn_mask): + mul_qk_add = mul_qk + elif ( + _type_utils.JitScalarType.from_value(attn_mask) + == _type_utils.JitScalarType.BOOL + ): + # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf')) + const_zero = g.op("Constant", value_t=torch.tensor([0.0])) + const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")])) + attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf) + mul_qk_add = g.op("Add", mul_qk, attn_mask) + elif _type_utils.JitScalarType.from_value(attn_mask) in ( + _type_utils.JitScalarType.FLOAT, + _type_utils.JitScalarType.HALF, + _type_utils.JitScalarType.BFLOAT16, + ): + mul_qk_add = g.op("Add", mul_qk, attn_mask) + else: + raise ValueError( + f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}" + ) + + attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1) + + if dropout_p != 0: + attn_weight = g.op( + "Dropout", + attn_weight, + g.op("Constant", value_t=torch.tensor(dropout_p, dtype=torch.float)), + ) + + return g.op("MatMul", attn_weight, value) + + +def _attention_scale( + g: jit_utils.GraphContext, query: torch._C.Value +) -> torch._C.Value: + """Calculate the scale factor for the attention result. + + Args: + query: Tensor of shape [..., L, E] + + Returns: + Scalar scale factor := 1 / math.sqrt(query.size(-1)) + """ + query_shape = g.op("Shape", query) + query_shape_last = g.op( + "Slice", + query_shape, + g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)), + g.op( + "Constant", value_t=torch.tensor([_constants.INT64_MAX], dtype=torch.int64) + ), + ) + embedding_size = g.op( + "Cast", + query_shape_last, + to_i=_type_utils.JitScalarType.from_value(query).onnx_type(), + ) + const_one = g.op("Constant", value_t=torch.tensor([1.0], dtype=torch.float)) + scale = g.op("Div", const_one, g.op("Sqrt", embedding_size)) + # Add a Cast to convert the scale back to original type + scale = g.op( + "Cast", + scale, + to_i=_type_utils.JitScalarType.from_value(query).onnx_type(), + ) + return scale + + +def _causal_attention_mask( + g: jit_utils.GraphContext, query: torch._C.Value, key: torch._C.Value +) -> torch._C.Value: + """Create a causal mask for the given query and key tensors. + + Equivalent to:: + mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + attn_mask = torch.zeros(L, S, dtype=torch.float) + attn_mask = attn_mask.masked_fill(not mask, -float("inf")) + + Args: + query: Tensor of shape [..., L, E] + key: Tensor of shape [..., S, E] + + Returns: + Tensor of shape [L, S] + """ + + query_shape = g.op("Shape", query) + key_shape = g.op("Shape", key) + + last_idx = g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) + second_last_idx = g.op("Constant", value_t=torch.tensor([-2], dtype=torch.int64)) + target_length = g.op("Slice", query_shape, second_last_idx, last_idx) + source_length = g.op("Slice", key_shape, second_last_idx, last_idx) + # attn_mask = torch.ones(L, S) := { + size = g.op("Concat", target_length, source_length, axis_i=0) + const_one = g.op("Constant", value_t=torch.tensor([1.0])) + attn_mask = g.op("Expand", const_one, size) + # } + attn_mask = g.op("Trilu", attn_mask, upper_i=0) + # The causal mask has 0s in the lower triangle and -inf in the upper triangle. + const_zero = g.op("Constant", value_t=torch.tensor([0.0])) + const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")])) + attn_mask = g.op( + "Where", g.op("Equal", attn_mask, const_zero), const_neg_inf, const_zero + ) + return attn_mask +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/onnx/symbolic_opset15.py b/torch/onnx/symbolic_opset15.py index e04e3b0452127..dd30af7729de7 100644 --- a/torch/onnx/symbolic_opset15.py +++ b/torch/onnx/symbolic_opset15.py @@ -1,3 +1,4 @@ +<<<<<<< HEAD """Backward compatibility module for torch.onnx.symbolic_opset15.""" from __future__ import annotations @@ -6,3 +7,85 @@ __all__: list[str] = [] from torch.onnx._internal.torchscript_exporter.symbolic_opset15 import * # noqa: F401,F403 +======= +# mypy: allow-untyped-defs +"""This file exports ONNX ops for opset 15. + +Note [ONNX operators that are added/updated in opset 15] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +https://github.com/onnx/onnx/blob/master/docs/Changelog.md#version-15-of-the-default-onnx-operator-set +New operators: + Bernoulli + CastLike + Optional + OptionalGetElement + OptionalHasElement + +Updated operators: + BatchNormalization https://github.com/onnx/onnx/pull/3545 + Backwards compatible + TODO: test coverage for mixed types inputs. + Pow https://github.com/onnx/onnx/pull/3412 + Backwards compatible + TODO: bfloat16 support. + Shape https://github.com/onnx/onnx/pull/3580 + Backwards compatible + TODO: optional start/end attribute. +""" + +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in README.md + +import functools + +import torch +from torch import _C +from torch.onnx import symbolic_helper, symbolic_opset9 as opset9 +from torch.onnx._internal import jit_utils, registration + + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=15) + + +@_onnx_symbolic("aten::__is_") +def aten__is_(g: jit_utils.GraphContext, self, other): + if symbolic_helper._is_none(other): + if isinstance(self.type(), _C.OptionalType): + none = g.op("OptionalHasElement", self) + return g.op("Not", none) + else: + return g.op("Constant", value_t=torch.BoolTensor([0])) + return opset9.eq(g, self, other) + + +@_onnx_symbolic("aten::__isnot_") +@opset9.wrap_logical_op_with_negation # type: ignore[has-type] +def aten__isnot_(g: jit_utils.GraphContext, self, other): + return aten__is_(g, self, other) + + +@_onnx_symbolic("aten::bernoulli") +def bernoulli(g: jit_utils.GraphContext, input, p=None, generator=None, out=None): + if out is not None and not symbolic_helper._is_none(out): + symbolic_helper._unimplemented( + "Bernoulli", "out parameter is not supported for bernoulli", input + ) + if generator is not None and not symbolic_helper._is_none(generator): + symbolic_helper._unimplemented( + "Bernoulli", "generator is not supported for bernoulli", input + ) + if p is None or symbolic_helper._is_none(p): + return g.op("Bernoulli", input) + return opset9.bernoulli(g, input, p, generator, out) + + +@_onnx_symbolic("prim::unchecked_cast") +def prim_unchecked_cast(g: jit_utils.GraphContext, self): + # exists to refine the type of the Value + # if x is Optional[Tensor], unchecked_cast will cast + # x to Tensor, so the rest of the graph knows that x is a Tensor. + if isinstance(self.type(), _C.OptionalType): + return g.op("OptionalGetElement", self) + + return self +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/onnx/symbolic_opset16.py b/torch/onnx/symbolic_opset16.py index 9a248bb0f26c5..adde30c642dbe 100644 --- a/torch/onnx/symbolic_opset16.py +++ b/torch/onnx/symbolic_opset16.py @@ -1,3 +1,4 @@ +<<<<<<< HEAD """Backward compatibility module for torch.onnx.symbolic_opset16.""" from __future__ import annotations @@ -6,3 +7,190 @@ __all__: list[str] = [] from torch.onnx._internal.torchscript_exporter.symbolic_opset16 import * # noqa: F401,F403 +======= +# mypy: allow-untyped-defs +"""This file exports ONNX ops for opset 16. + +Note [ONNX Operators that are added/updated in opset 16] + +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-16-of-the-default-onnx-operator-set +New operators: + GridSample https://github.com/onnx/onnx/pull/3557 + +Updated operators: + Identity + If + LeakyRelu + Loop + PRelu + RoiAlign + Scan + ScatterElements + ScatterND + Where + GreaterOrEqual + LessOrEqual +""" + +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in README.md + +import functools + +import torch +from torch.nn.functional import ( + GRID_SAMPLE_INTERPOLATION_MODES, + GRID_SAMPLE_PADDING_MODES, +) +from torch.onnx import _type_utils, errors, symbolic_helper, utils +from torch.onnx._internal import jit_utils, registration + + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=16) + + +# note (mkozuki): Why `grid_sampler` instead of `grid_sample`? +# Because `torch.nn.functional.grid_sample` calls `torch.grid_sampler`. +@_onnx_symbolic("aten::grid_sampler") +@symbolic_helper.parse_args("v", "v", "i", "i", "b") +def grid_sampler( + g: jit_utils.GraphContext, + input, + grid, + mode_enum, + padding_mode_enum, + align_corners, +): + # Check the input and grid tensor rank beforehand. + if symbolic_helper._get_tensor_rank(input) == 5: + return symbolic_helper._onnx_unsupported("GridSample with 5D volumetric input") + mode_s = {v: k for k, v in GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg] + padding_mode_s = {v: k for k, v in GRID_SAMPLE_PADDING_MODES.items()}[ # type: ignore[call-arg] + padding_mode_enum + ] + return g.op( + "GridSample", + input, + grid, + align_corners_i=int(align_corners), + mode_s=mode_s, + padding_mode_s=padding_mode_s, + ) + + +@_onnx_symbolic("aten::scatter_add") +@symbolic_helper.parse_args("v", "i", "v", "v") +def scatter_add(g: jit_utils.GraphContext, self, dim, index, src): + src_type = _type_utils.JitScalarType.from_value( + src, _type_utils.JitScalarType.UNDEFINED + ) + src_sizes = symbolic_helper._get_tensor_sizes(src) + index_sizes = symbolic_helper._get_tensor_sizes(index) + + if len(src_sizes) != len(index_sizes): + return symbolic_helper._unimplemented( + "scatter_add", + f"`index` ({index_sizes}) should have the same dimensionality as `src` ({src_sizes})", + ) + + # PyTorch only allows index shape <= src shape, so we can only consider + # taking index as subset size to src, like PyTorch does. When sizes for src + # and index are not matched or there are dynamic axes, we take index shape to + # slice src to accommodate. + if src_sizes != index_sizes or None in index_sizes: + adjusted_shape = g.op("Shape", index) + starts = g.op("Constant", value_t=torch.tensor([0] * len(index_sizes))) + src = g.op("Slice", src, starts, adjusted_shape) + + src = symbolic_helper._maybe_get_scalar(src) + if symbolic_helper._is_value(src): + return g.op("ScatterElements", self, index, src, axis_i=dim, reduction_s="add") + else: + # Check if scalar "src" has same type as self (PyTorch allows different + # type for scalar src (but not when src is tensor)). If not, insert Cast node. + if _type_utils.JitScalarType.from_value(self) != src_type: + src = g.op( + "Cast", + src, + to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), + ) + + return g.op( + "ScatterElements", + self, + index, + src, + axis_i=dim, + reduction_s="add", + ) + + +@_onnx_symbolic("aten::scatter_reduce") +@symbolic_helper.parse_args("v", "i", "v", "v", "s", "b") +def scatter_reduce( + g: jit_utils.GraphContext, + self: torch._C.Value, + dim: int, + index: torch._C.Value, + src: torch._C.Value, + reduce: str, + include_self: bool, +): + if reduce == "mean": + raise errors.OnnxExporterError( + "ONNX does not support mean reduction for scatter_reduce" + ) + if not include_self: + raise errors.OnnxExporterError( + "ONNX does not support include_self=False for scatter_reduce" + ) + + reduce_mode = { # convert torch string name to onnx string name + "mean": "none", # 'mean' doesn't support in ONNX 1.14 definition + "sum": "add", + "prod": "mul", + "amin": "min", + "amax": "max", + } + onnx_reduce = reduce_mode[reduce] + + self_rank = g.op("Size", g.op("Shape", self)) + + # if self_rank == 0: # assert (index_rank == 0 and rank_src == 0) + self_rank_is_zero = g.op( + "Equal", self_rank, g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)) + ) + if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks( + g, "If", self_rank_is_zero, n_blocks=2, outputs=3 + ) + neg_1 = if_context.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) + + self_reshape = if_context.op("Reshape", self, neg_1) + utils._add_output_to_block(if_context.block, self_reshape) + index_reshape = if_context.op("Reshape", index, neg_1) + utils._add_output_to_block(if_context.block, index_reshape) + src_reshape = if_context.op("Reshape", src, neg_1) + utils._add_output_to_block(if_context.block, src_reshape) + + self_identity = else_context.op("Identity", self) + utils._add_output_to_block(else_context.block, self_identity) + index_identitye = else_context.op("Identity", index) + utils._add_output_to_block(else_context.block, index_identitye) + src_identity = else_context.op("Identity", src) + utils._add_output_to_block(else_context.block, src_identity) + + result = g.op("ScatterElements", *if_op, axis_i=dim, reduction_s=onnx_reduce) + + # if self_rank == 0: + if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks( + g, "If", self_rank_is_zero, n_blocks=2, outputs=1 + ) + result_squeezed = if_context.op("Squeeze", result) + utils._add_output_to_block(if_context.block, result_squeezed) + result_identity = else_context.op("Identity", result) + utils._add_output_to_block(else_context.block, result_identity) + result_final = if_op.node().output() + + return result_final +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/onnx/symbolic_opset17.py b/torch/onnx/symbolic_opset17.py index 800acd446b5dc..d028b8cf9c109 100644 --- a/torch/onnx/symbolic_opset17.py +++ b/torch/onnx/symbolic_opset17.py @@ -1,3 +1,4 @@ +<<<<<<< HEAD """Backward compatibility module for torch.onnx.symbolic_opset17.""" from __future__ import annotations @@ -6,3 +7,244 @@ __all__: list[str] = [] from torch.onnx._internal.torchscript_exporter.symbolic_opset17 import * # noqa: F401,F403 +======= +# mypy: allow-untyped-defs +# mypy: disable-error-code=arg-type +"""This file exports ONNX ops for opset 17. + +Note [ONNX Operators that are added/updated in opset 17] + +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-17-of-the-default-onnx-operator-set +New operators: + BlackmanWindow + DFT + HammingWindow + HannWindow + LayerNormalization + MelWeightMatrix + STFT + SequenceMap +""" + +import functools +from collections.abc import Sequence +from typing import Optional + +import torch +from torch import _C +from torch.onnx import _type_utils, errors, symbolic_helper +from torch.onnx._internal import jit_utils, registration + + +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in README.md + +__all__ = ["layer_norm", "stft", "quantized_layer_norm"] + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=17) + + +@_onnx_symbolic("aten::layer_norm") +@symbolic_helper.parse_args("v", "is", "v", "v", "f", "none") +def layer_norm( + g: jit_utils.GraphContext, + input: _C.Value, + normalized_shape: Sequence[int], + weight: _C.Value, + bias: _C.Value, + eps: float, + cudnn_enable: bool, +): + # normalized_shape: input shape from an expected input of size + # axis: The first normalization dimension. + # layer_norm normalizes on the last D dimensions, + # where D is the size of normalized_shape + axis = -len(normalized_shape) + scalar_type = _type_utils.JitScalarType.from_value( + input, _type_utils.JitScalarType.FLOAT + ) + dtype = scalar_type.dtype() + if symbolic_helper._is_none(weight): + weight_value = torch.ones(normalized_shape, dtype=dtype) + weight = g.op("Constant", value_t=weight_value) + if symbolic_helper._is_none(bias): + bias_value = torch.zeros(normalized_shape, dtype=dtype) + bias = g.op("Constant", value_t=bias_value) + return g.op( + "LayerNormalization", + input, + weight, + bias, + epsilon_f=eps, + axis_i=axis, + ) + + +@_onnx_symbolic("quantized::layer_norm") +def quantized_layer_norm( + g: jit_utils.GraphContext, + x, + normalized_shape, + weight, + bias, + eps, + op_scale, + op_zero_point, +): + x, _, _, _ = symbolic_helper.dequantize_helper(g, x) + + output = layer_norm(g, x, normalized_shape, weight, bias, eps, False) + + return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point) + + +def _compute_edge_sizes(n_fft, window_size): + """Helper function to compute the sizes of the edges (left and right) + of a given window centered within an FFT size.""" + left = (n_fft - window_size) // 2 + right = n_fft - left - window_size + return left, right + + +@_onnx_symbolic("aten::stft") +@symbolic_helper.parse_args("v", "i", "i", "i", "v", "b", "b", "b", "b") +def stft( + g: jit_utils.GraphContext, + input: _C.Value, + n_fft: int, + hop_length: Optional[int] = None, + win_length: Optional[int] = None, + window: Optional[_C.Value] = None, + normalized: bool = False, + onesided: Optional[bool] = True, + return_complex: Optional[bool] = False, + align_to_window: Optional[bool] = None, +) -> _C.Value: + """Associates `torch.stft` with the `STFT` ONNX operator. + Note that torch.stft calls _VF.stft, without centering or padding options. + Hence, this function does not contain these two arguments. + See torch.stft source code for more info. + + Args: + g: Graph to write the ONNX representation into + input: Input tensor for the transformation + n_fft: FFT size + hop_length: Size of the hop. Defaults to `floot(n_fft // 4)` + win_length: Size of the analysis window. Defaults to `n_fft` + window: Analysis window. Defaults to a window of all ones + normalized: Whether to return a normalized STFT + onesided: Whether to return only half (+1) of the results, given the + symmetry of the STFT + return_complex: Whether to return the complex value (Note: Must be + `False` or `None`) + + Returns: + op: Operator for torch.stft associated with STFT (ONNX) + """ + # Checks + if return_complex: + raise errors.SymbolicValueError( + msg="STFT does not currently support complex types", value=input + ) + + if align_to_window is not None: + raise errors.SymbolicValueError( + msg="STFT does not currently support the align_to_window option", + value=input, + ) # TODO(#145944): add compatibility with align_to_window option. + + # Get STFT sizes + frame_step_value = hop_length if hop_length is not None else n_fft // 4 + frame_step_const = g.op( + "Constant", value_t=torch.tensor(frame_step_value, dtype=torch.int64) + ) + frame_length_const = g.op( + "Constant", value_t=torch.tensor(n_fft, dtype=torch.int64) + ) + + # Pre-process input if needed + signal = input + signal_rank = symbolic_helper._get_tensor_rank(signal) + if signal_rank == 1: + # Add batch dimension + signal = g.op( + "Unsqueeze", + signal, + g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)), + ) + elif signal_rank is None or signal_rank > 2: + raise errors.SymbolicValueError( + msg="STFT can only take inputs of 1 [signal] or 2 [batch, signal] dimensions. " + f"Current rank of signal is {signal_rank}, please reduce it.", + value=input, + ) + + # Get window and make sure it's the same size as `win_length` or `n_fft` + n_win = symbolic_helper._get_tensor_dim_size(window, dim=0) + if n_win is not None: + win_length_default = win_length if win_length else n_fft + assert n_win == win_length_default, ( + "Analysis window size must equal `win_length` or `n_fft`. " + f"Please, set `win_length` or `n_fft` to match `window` size ({n_win})", + ) + + # Center window around zeros if needed (required by ONNX's STFT) + if n_win < n_fft: + left, right = _compute_edge_sizes(n_fft, n_win) + left_win = g.op("Constant", value_t=torch.zeros(left)) + right_win = g.op("Constant", value_t=torch.zeros(right)) + window = g.op("Concat", left_win, window, right_win, axis_i=0) + + # Create window, if needed + if symbolic_helper._is_none(window): + if win_length: + if win_length > n_fft: + raise errors.SymbolicValueError( + msg="The analysis window can't be longer than the size of the FFT. " + f"Please set `win_length` ({win_length}) to `n_fft` ({n_fft}) or less.", + value=input, + ) + + # Center window, if needed + left, right = _compute_edge_sizes(n_fft, win_length) + torch_window = torch.hstack( + (torch.zeros(left), torch.ones(win_length), torch.zeros(right)) + ) + else: + # Rectangle window + torch_window = torch.ones(n_fft) + assert torch_window.shape[0] == n_fft + window = g.op("Constant", value_t=torch_window) + window = g.op( + "Cast", window, to_i=_type_utils.JitScalarType.from_value(signal).onnx_type() + ) + + # Run STFT + result = g.op( + "STFT", + signal, + frame_step_const, + window, + frame_length_const, + onesided_i=1 if onesided is None or onesided else 0, + ) + + # Transpose to mimic torch.stft's behavior + result = g.op("Transpose", result, perm_i=[0, 2, 1, 3]) + + # Remove batch dimension, if needed + if signal_rank == 1: + result = g.op( + "Squeeze", + result, + g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)), + ) + + # Normalize, if needed + if normalized: + sqrt_nfft = torch.sqrt(torch.tensor(n_fft, dtype=signal.type().dtype())) + result = g.op("Div", result, g.op("Constant", value_t=sqrt_nfft)) + + return result +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/onnx/symbolic_opset18.py b/torch/onnx/symbolic_opset18.py index cc07a60f018d8..5ade100a5b031 100644 --- a/torch/onnx/symbolic_opset18.py +++ b/torch/onnx/symbolic_opset18.py @@ -1,3 +1,4 @@ +<<<<<<< HEAD """Backward compatibility module for torch.onnx.symbolic_opset18.""" from __future__ import annotations @@ -6,3 +7,270 @@ __all__: list[str] = [] from torch.onnx._internal.torchscript_exporter.symbolic_opset18 import * # noqa: F401,F403 +======= +# mypy: allow-untyped-defs +"""This file exports ONNX ops for opset 18. + +Note [ONNX Operators that are added/updated in opset 18] + +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-18-of-the-default-onnx-operator-set +New operators: + BitwiseAnd + CenterCropPad + Col2Im + Mish + OptionalGetElement + OptionalHasElement + Pad + Resize + ScatterElements + ScatterND + Split +""" + +import functools +from collections.abc import Sequence +from typing import Optional + +import torch +from torch import _C +from torch.onnx import _type_utils, symbolic_helper, symbolic_opset9 as opset9 +from torch.onnx._internal import jit_utils, registration + + +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in symbolic_helper.py + +__all__ = [ + "col2im", +] + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=18) + + +@_onnx_symbolic("aten::__and_") +@_onnx_symbolic("aten::bitwise_and") +def __and_(g: jit_utils.GraphContext, self, other): + # do type promotion (scalars don't seem to apply) + args = [self, other] + # type promotion doesn't happen with torch.bitwise_and(tensor, scalar) + prom_args = [arg for arg in args if symbolic_helper._get_tensor_rank(arg)] + if len(prom_args) == 0: + prom_args = args + promotion_jit_type = symbolic_helper._type_promote_from_values(*prom_args) + self = symbolic_helper._maybe_cast_to_type(g, self, promotion_jit_type) + other = symbolic_helper._maybe_cast_to_type(g, other, promotion_jit_type) + if promotion_jit_type == _type_utils.JitScalarType.BOOL: + return g.op("And", self, other) + return g.op("BitwiseAnd", self, other) + + +@_onnx_symbolic("aten::col2im") +@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is") +def col2im( + g, + input: _C.Value, + output_size: _C.Value, + kernel_size: _C.Value, + dilation: Sequence[int], + padding: Sequence[int], + stride: Sequence[int], +): + # convert [i0, i1, ..., in] into [i0, i0, i1, i1, ..., in, in] + adjusted_padding: list[int] = [] + for pad in padding: + adjusted_padding.extend(pad for _ in range(2)) + + num_dimensional_axis = symbolic_helper._get_tensor_sizes(output_size)[0] + if not adjusted_padding: + adjusted_padding = [0, 0] * num_dimensional_axis + + if not dilation: + dilation = [1] * num_dimensional_axis + + if not stride: + stride = [1] * num_dimensional_axis + + return g.op( + "Col2Im", + input, + output_size, + kernel_size, + dilations_i=dilation, + pads_i=adjusted_padding, + strides_i=stride, + ) + + +@_onnx_symbolic( + "aten::mean", decorate=[symbolic_helper._apply_params("ReduceMean", "mean")] +) +@_onnx_symbolic( + "aten::prod", + decorate=[ + symbolic_helper._apply_params( + "ReduceProd", "prod", allow_multi_dim_support=False + ) + ], +) +def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool = True): + return symbolic_helper._reduce_with_dtype_helper( + onnx_op, name, allow_multi_dim_support + ) + + +@_onnx_symbolic("aten::native_layer_norm") +@symbolic_helper.quantized_args(True, False, False, False) +@symbolic_helper.parse_args("v", "is", "v", "v", "f") +def _native_layer_norm( + g: jit_utils.GraphContext, + input: _C.Value, + normalized_shape: Sequence[int], + weight: _C.Value, + bias: _C.Value, + eps: float, +) -> tuple[_C.Value, _C.Value, _C.Value]: + return opset9.native_layer_norm(g, input, normalized_shape, weight, bias, eps) + + +@_onnx_symbolic("aten::glu") +@symbolic_helper.parse_args("v", "i") +def _glu(g: jit_utils.GraphContext, input, dim): + dim_size = symbolic_helper._get_tensor_dim_size(input, dim) + if dim_size is not None: + assert dim_size % 2 == 0 + + first, second = g.op("Split", input, axis_i=dim, num_outputs_i=2, outputs=2) + return g.op("Mul", first, g.op("Sigmoid", second)) + + +@_onnx_symbolic("aten::max") +# torch.max (same for torch.min) actually has two interfaces smashed together: +# torch.max(x, dim, keepdim) and torch.max(x, y) +# TODO(justinchuby): Support multiple quantized args in output +def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): + return symbolic_helper._max_helper(g, self, dim_or_y, keepdim) + + +@_onnx_symbolic("aten::maximum") +@symbolic_helper.quantized_args(True, True) +def maximum(g: jit_utils.GraphContext, input, other): + return max(g, input, dim_or_y=other) + + +@_onnx_symbolic("aten::min") +# TODO(justinchuby): Support multiple quantized args in output +def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): + return symbolic_helper._min_helper(g, self, dim_or_y, keepdim) + + +@_onnx_symbolic("aten::minimum") +@symbolic_helper.quantized_args(True, True) +def minimum(g: jit_utils.GraphContext, input, other): + return min(g, input, dim_or_y=other) + + +@_onnx_symbolic("aten::amax") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "is", "i") +def amax(g: jit_utils.GraphContext, self, dim, keepdim): + axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)) + return g.op("ReduceMax", self, axes, keepdims_i=keepdim) + + +@_onnx_symbolic("aten::amin") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "is", "i") +def amin(g: jit_utils.GraphContext, self, dim, keepdim): + axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)) + return g.op("ReduceMin", self, axes, keepdims_i=keepdim) + + +@_onnx_symbolic("aten::aminmax") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "v", "i") +def aminmax(g: jit_utils.GraphContext, self, dim, keepdim): + if not symbolic_helper._is_none(dim): + dim = symbolic_helper._get_const(dim, "i", "dim") + axes = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) + return g.op("ReduceMin", self, axes, keepdims_i=keepdim), g.op( + "ReduceMax", self, axes, keepdims_i=keepdim + ) + else: + return g.op("ReduceMin", self, keepdims_i=keepdim), g.op( + "ReduceMax", self, keepdims_i=keepdim + ) + + +@_onnx_symbolic("aten::var_mean") +def _var_mean(g: jit_utils.GraphContext, input, *args): + if len(args) == 1: + return symbolic_helper._var_mean_helper(g, input, None, args[0], None) + else: + return symbolic_helper._var_mean_helper(g, input, *args) + + +@_onnx_symbolic("aten::logsumexp") +@symbolic_helper.parse_args("v", "is", "i") +def _logsumexp(g: jit_utils.GraphContext, input, dim, keepdim): + if dim is None: + return g.op("ReduceLogSumExp", input, keepdims_i=0) + else: + axes = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long)) + return g.op("ReduceLogSumExp", input, axes, keepdims_i=keepdim) + + +@_onnx_symbolic("aten::linalg_matrix_norm") +@symbolic_helper.parse_args("v", "v", "is", "b", "v") +def _linalg_matrix_norm( + g: jit_utils.GraphContext, + self: torch._C.Value, + ord: torch._C.Value, + dim: list[int], + keepdim: bool, + dtype: torch._C.Value, +): + return opset9.linalg_matrix_norm(g, self, ord, dim, keepdim, dtype) + + +@_onnx_symbolic("aten::embedding_bag") +@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i") +def embedding_bag( + g: jit_utils.GraphContext, + embedding_matrix, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, +): + return symbolic_helper._embedding_bag_helper( + g, + embedding_matrix, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, + ) + + +@_onnx_symbolic("aten::linalg_vector_norm") +@symbolic_helper.parse_args("v", "f", "is", "b", "v") +def linalg_vector_norm( + g: jit_utils.GraphContext, + self: torch._C.Value, + ord: float, + dim: Optional[Sequence[int]], + keepdim: bool, + dtype: torch._C.Value, +): + return symbolic_helper._linalg_vector_norm_helper(g, self, ord, dim, keepdim, dtype) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/onnx/symbolic_opset19.py b/torch/onnx/symbolic_opset19.py index 4f7a54fc1dd38..9a5908ac8cf02 100644 --- a/torch/onnx/symbolic_opset19.py +++ b/torch/onnx/symbolic_opset19.py @@ -1,3 +1,4 @@ +<<<<<<< HEAD """Backward compatibility module for torch.onnx.symbolic_opset19.""" from __future__ import annotations @@ -6,3 +7,36 @@ __all__: list[str] = [] from torch.onnx._internal.torchscript_exporter.symbolic_opset19 import * # noqa: F401,F403 +======= +"""This file exports ONNX ops for opset 19. + +Note [ONNX Operators that are added/updated in opset 19] + +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-19-of-the-default-onnx-operator-set +New operators: +AveragePool +Cast +CastLike +Constant +DeformConv +DequantizeLinear +Equal +Identity +If +Loop +Pad +QuantizeLinear +Reshape +Resize +Scan +Shape +Size +""" + + +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in symbolic_helper.py + +__all__: list[str] = [] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/onnx/symbolic_opset20.py b/torch/onnx/symbolic_opset20.py index 56635a7811611..5b588565aa0fe 100644 --- a/torch/onnx/symbolic_opset20.py +++ b/torch/onnx/symbolic_opset20.py @@ -1,3 +1,4 @@ +<<<<<<< HEAD """Backward compatibility module for torch.onnx.symbolic_opset20.""" from __future__ import annotations @@ -6,3 +7,97 @@ __all__: list[str] = [] from torch.onnx._internal.torchscript_exporter.symbolic_opset20 import * # noqa: F401,F403 +======= +# mypy: allow-untyped-defs +"""This file exports ONNX ops for opset 20. + +Note [ONNX Operators that are added/updated in opset 20] + +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-20-of-the-default-onnx-operator-set +New operators: + AffineGrid + ConstantOfShape + DFT + Gelu + GridSample + ImageDecoder + IsInf + IsNaN + ReduceMax + ReduceMin + RegexFullMatch + StringConcat + StringSplit +""" + +import functools + +import torch.nn.functional as F +from torch import _C +from torch.onnx import symbolic_helper +from torch.onnx._internal import jit_utils, registration + + +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in symbolic_helper.py + +__all__ = ["_grid_sampler", "_affine_grid_generator", "gelu"] + + +def convert_grid_sample_mode(mode_s): + return ( + "linear" if mode_s == "bilinear" else "cubic" if mode_s == "bicubic" else mode_s + ) + + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=20) + + +@_onnx_symbolic("aten::grid_sampler") +@symbolic_helper.parse_args("v", "v", "i", "i", "b") +def _grid_sampler( + g: jit_utils.GraphContext, + input: _C.Value, + grid: _C.Value, + mode_enum: int, + padding_mode_enum: int, + align_corners: bool, +): + mode_s = {v: k for k, v in F.GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg, index] + # mode string changes at https://onnx.ai/onnx/operators/text_diff_GridSample_16_20.html + mode_s = convert_grid_sample_mode(mode_s) + padding_mode_s = {v: k for k, v in F.GRID_SAMPLE_PADDING_MODES.items()}[ # type: ignore[call-arg, index] + padding_mode_enum # type: ignore[index] + ] + return g.op( + "GridSample", + input, + grid, + align_corners_i=int(align_corners), + mode_s=mode_s, + padding_mode_s=padding_mode_s, + ) + + +@_onnx_symbolic("aten::affine_grid_generator") +@symbolic_helper.parse_args("v", "v", "b") +def _affine_grid_generator( + g: jit_utils.GraphContext, + theta: _C.Value, + size: _C.Value, + align_corners: bool, +): + return g.op( + "AffineGrid", + theta, + size, + align_corners_i=int(align_corners), + ) + + +@_onnx_symbolic("aten::gelu") +@symbolic_helper.parse_args("v", "s") +def gelu(g: jit_utils.GraphContext, self: _C.Value, approximate: str = "none"): + return g.op("Gelu", self, approximate_s=approximate) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/onnx/symbolic_opset7.py b/torch/onnx/symbolic_opset7.py index c11e769677ec4..7bf870a322c68 100644 --- a/torch/onnx/symbolic_opset7.py +++ b/torch/onnx/symbolic_opset7.py @@ -1,3 +1,4 @@ +<<<<<<< HEAD """Backward compatibility module for torch.onnx.symbolic_opset7.""" from __future__ import annotations @@ -6,3 +7,72 @@ __all__: list[str] = [] from torch.onnx._internal.torchscript_exporter.symbolic_opset7 import * # noqa: F401,F403 +======= +# mypy: allow-untyped-defs +""" +Note [ONNX operators that are added/updated from opset 7 to opset 8] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +New operators: + Expand + +Updated operators: + Min, Max, Sum, Mean: supports multidirectional broadcasting. + MaxPool: added optional indices output. + Scan +""" + +import functools +import warnings + +from torch.onnx import symbolic_helper, symbolic_opset9 as opset9 +from torch.onnx._internal import jit_utils, registration + + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=7) + +block_listed_operators = ( + "scan", + "expand", + "expand_as", + "meshgrid", + "adaptive_max_pool1d", + "adaptive_max_pool2d", + "adaptive_max_pool3d", + "max_pool1d_with_indices", + "max_pool2d_with_indices", + "max_pool3d_with_indices", +) + + +# NOTE: max, min, sum, mean: broadcasting is not supported in opset 7. +# torch.max (same for torch.min) actually has two interfaces smashed together: +# torch.max(x, dim, keepdim) and torch.max(x, y) +@_onnx_symbolic("aten::max") +def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): + # torch.max(input, other) + if keepdim is None and dim_or_y is not None: + warnings.warn( + "Multidirectional broadcasting is not supported in opset 7. " + "This might cause the onnx model to be incorrect, if inputs to max operators " + "have different shapes" + ) + return opset9.max(g, self, dim_or_y, keepdim) + + +@_onnx_symbolic("aten::min") +def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): + # torch.min(input, other) + if keepdim is None and dim_or_y is not None: + warnings.warn( + "Multidirectional broadcasting is not supported in opset 7. " + "This might cause the onnx model to be incorrect, if inputs to min operators " + "have different shapes" + ) + return opset9.min(g, self, dim_or_y, keepdim) + + +for block_listed_op in block_listed_operators: + _onnx_symbolic(f"aten::{block_listed_op}")( + symbolic_helper._block_list_in_opset(block_listed_op) + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/onnx/symbolic_opset8.py b/torch/onnx/symbolic_opset8.py index 0e4411649f3e0..ba799bd6a5721 100644 --- a/torch/onnx/symbolic_opset8.py +++ b/torch/onnx/symbolic_opset8.py @@ -1,3 +1,4 @@ +<<<<<<< HEAD """Backward compatibility module for torch.onnx.symbolic_opset8.""" from __future__ import annotations @@ -6,3 +7,468 @@ __all__: list[str] = [] from torch.onnx._internal.torchscript_exporter.symbolic_opset8 import * # noqa: F401,F403 +======= +# mypy: allow-untyped-defs +""" +Note [ONNX operators that are added/updated from opset 8 to opset 9] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +New operators: + Compress + ConstantOfShape + EyeLike + MaxUnpool + OneHot + Sinh + Cosh + Asinh + Acosh + Atanh + Shrink + IsNaN + Sign + Erf + Scatter + Where + NonZero + TfIdfVectorizer + MeanVarianceNormalization + +Updated operators: + BatchNormalization: removed spatial attribute. + Greater, Less, Constant, MatMul, PRelu, Gemm, Flatten: more data types{integers} supported. + Cast: more data types{string} supported. + Upsample: moved scales from attribute to input. + Scan +""" + +import functools +import warnings + +import torch +from torch._C import _onnx as _C_onnx +from torch.onnx import _type_utils, errors, symbolic_helper, symbolic_opset9 as opset9 +from torch.onnx._internal import jit_utils, registration + + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=8) + +block_listed_operators = ( + "nonzero", + "where", + "scatter", + "scatter_add", + "erf", + "sign", + "isnan", + "gather", + "arange", + "masked_fill", + "index_fill", + "index_copy", + "repeat_interleave", + "any", + "all", +) + +for block_listed_op in block_listed_operators: + _onnx_symbolic(f"aten::{block_listed_op}")( + symbolic_helper._block_list_in_opset(block_listed_op) + ) + + +@_onnx_symbolic( + "aten::upsample_nearest1d", + decorate=[symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_nearest2d", + decorate=[symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_nearest3d", + decorate=[symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest")], +) +@_onnx_symbolic( + "aten::upsample_linear1d", + decorate=[symbolic_helper._apply_params("upsample_linear1d", 3, "linear")], +) +@_onnx_symbolic( + "aten::upsample_bilinear2d", + decorate=[symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear")], +) +@_onnx_symbolic( + "aten::upsample_trilinear3d", + decorate=[symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear")], +) +def _interpolate(name, dim, interpolate_mode): + def symbolic_fn(g, input, output_size, *args): + scales, align_corners = symbolic_helper._get_interpolate_attributes( + g, interpolate_mode, args + ) + symbolic_helper._interpolate_warning(interpolate_mode) + align_corners = symbolic_helper._maybe_get_scalar(align_corners) + if align_corners: + return symbolic_helper._unimplemented(name, "align_corners == True", input) + output_size = symbolic_helper._maybe_get_const(output_size, "is") + if symbolic_helper._is_value(output_size): + return symbolic_helper._unimplemented( + name, "torch._C.Value (output_size) indexing" + ) + if scales is None: + scales = [ + 1.0 + if i < 2 + else float(output_size[-(dim - i)]) + / float(input.type().sizes()[-(dim - i)]) + for i in range(0, dim) + ] + return g.op("Upsample", input, mode_s=interpolate_mode, scales_f=scales) + + return symbolic_fn + + +@_onnx_symbolic("aten::__interpolate") +def __interpolate( + g: jit_utils.GraphContext, + input, + size, + scale_factor, + mode, + align_corners, + recompute_scale_factor, + antialias, +): + align_corners = symbolic_helper._maybe_get_const(align_corners, "b") + if not symbolic_helper._is_none(align_corners) and align_corners: + return symbolic_helper._unimplemented("interpolate", "align_corners == True") + + if not symbolic_helper._is_none(scale_factor) and symbolic_helper._is_value( + scale_factor + ): + return symbolic_helper._unimplemented( + "interpolate", "dynamic scales in opset 8" + ) + + if not symbolic_helper._is_none(size) and symbolic_helper._is_value(size): + return symbolic_helper._unimplemented("interpolate", "dynamic size in opset 8") + + scales, mode = symbolic_helper._interpolate_get_scales_and_mode( + g, input, size, scale_factor, mode, align_corners + ) + return g.op("Upsample", input, mode_s=mode, scales_f=scales) + + +# NOTE: We should create a wrapper for this kind of operation, after resolving the shape/type propagation +# issue for "cast" operators. Some symbolic functions depend on shape information of input tensor, which +# is lost after casting. +def _try_cast_integer_to_float(g: jit_utils.GraphContext, *args): + floating_scalar_types = { + _type_utils.JitScalarType.HALF, + _type_utils.JitScalarType.FLOAT, + _type_utils.JitScalarType.DOUBLE, + } + old_type = None + # Cast the input tensor to Float if its scalarType is known and is not floating number. + # If casting is performed, return the old scalarType, otherwise return None. + arg0_type = _type_utils.JitScalarType.from_value( + args[0], _type_utils.JitScalarType.UNDEFINED + ) + if arg0_type != _type_utils.JitScalarType.UNDEFINED: + old_type = arg0_type + if old_type not in floating_scalar_types: + old_type = old_type.scalar_name() # type: ignore[assignment] + args = tuple( + g.op("Cast", arg, to_i=_C_onnx.TensorProtoDataType.FLOAT) + for arg in args + ) + else: + return (None,) + args + else: + warnings.warn( + "Only floating datatype is supported for these operators: " + "{Greater, Less, MatMul, PRelu, Gemm, Flatten}. This might cause " + "the onnx model to be incorrect, if inputs have integer datatypes." + ) + return (old_type,) + args + + +def _cast_to_type(g: jit_utils.GraphContext, input, to_type): + if to_type is None: + return input + return getattr(opset9, f"_cast_{to_type}")(g, input, False) + + +def _comparison_operator(g: jit_utils.GraphContext, input, other, op_name): + other = symbolic_helper._maybe_get_scalar(other) + other = symbolic_helper._if_scalar_type_as(other, input) + _, input, other = _try_cast_integer_to_float(g, input, other) + return g.op(op_name, input, other) + + +# NOTE: For symbolics {gt, lt, bmm, matmul, prelu, mm, addmm, view, flatten}, +# integer input type not supported in opset8. Cast to float if possible. +@_onnx_symbolic("aten::gt") +def gt(g: jit_utils.GraphContext, input, other): + return _comparison_operator(g, input, other, "Greater") + + +@_onnx_symbolic("aten::lt") +def lt(g: jit_utils.GraphContext, input, other): + return _comparison_operator(g, input, other, "Less") + + +@_onnx_symbolic("aten::bmm") +def bmm(g: jit_utils.GraphContext, self, other): + if symbolic_helper._try_get_scalar_type(self): + old_type, self, other = _try_cast_integer_to_float(g, self, other) + return _cast_to_type(g, g.op("MatMul", self, other), old_type) + else: + return g.op("MatMul", self, other) + + +@_onnx_symbolic("aten::matmul") +def matmul(g: jit_utils.GraphContext, self, other): + return bmm(g, self, other) + + +@_onnx_symbolic("aten::prelu") +def prelu(g: jit_utils.GraphContext, self, weight): + self_rank = symbolic_helper._get_tensor_rank(self) + weight_sizes = symbolic_helper._get_tensor_sizes(weight) + if self_rank is not None and self_rank > 2: + weight = g.op("Unsqueeze", weight, axes_i=list(range(1, self_rank - 1))) + elif self_rank == 0 and weight_sizes == [1]: + # self and weight are both scalar but weight has rank == 1, squeeze weight. + weight = symbolic_helper._squeeze_helper(g, weight, [0]) + if symbolic_helper._try_get_scalar_type(self): + old_type, self, weight = _try_cast_integer_to_float(g, self, weight) + return _cast_to_type(g, g.op("PRelu", self, weight), old_type) + else: + return g.op("PRelu", self, weight) + + +@_onnx_symbolic("aten::mm") +def mm(g: jit_utils.GraphContext, self, other): + # Create a dummy C tensor. Only needed for API purposes, the value is + # since beta = 0 + scalar_type = symbolic_helper._try_get_scalar_type(self, other) + if scalar_type is None: + raise errors.SymbolicValueError( + "mm can only operate on tensors with known types", self + ) + zero_constant = g.op( + "Constant", + value_t=torch.tensor([0], dtype=scalar_type.dtype()), + ) + + if symbolic_helper._try_get_scalar_type(self): + old_type, self, other, zero_constant = _try_cast_integer_to_float( + g, self, other, zero_constant + ) + return _cast_to_type( + g, + g.op("Gemm", self, other, zero_constant, beta_f=0.0, alpha_f=1.0), + old_type, + ) + return g.op("Gemm", self, other, zero_constant, beta_f=0.0, alpha_f=1.0) + + +@_onnx_symbolic("aten::addmm") +@symbolic_helper.parse_args("v", "v", "v", "t", "t") +def addmm(g: jit_utils.GraphContext, self, mat1, mat2, beta, alpha): + if symbolic_helper._try_get_scalar_type(self): + old_type, self, mat1, mat2 = _try_cast_integer_to_float(g, self, mat1, mat2) + return _cast_to_type( + g, + g.op( + "Gemm", + mat1, + mat2, + self, + beta_f=symbolic_helper._scalar(beta), + alpha_f=symbolic_helper._scalar(alpha), + ), + old_type, + ) + else: + return g.op( + "Gemm", + mat1, + mat2, + self, + beta_f=symbolic_helper._scalar(beta), + alpha_f=symbolic_helper._scalar(alpha), + ) + + +@_onnx_symbolic("aten::flatten") +def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim): + start_dim_i = symbolic_helper._get_const(start_dim, "i", "start_dim") + end_dim_i = symbolic_helper._get_const(end_dim, "i", "end_dim") + + dim = input.type().dim() + if end_dim_i < 0: + end_dim_i = dim + end_dim_i + # use ONNX's Flatten operator for cases where the output shape is 2D + if start_dim_i == 1 and end_dim_i == dim - 1: + if symbolic_helper._try_get_scalar_type(input): + old_type, input = _try_cast_integer_to_float(g, input) + return _cast_to_type( + g, g.op("Flatten", input, axis_i=start_dim_i), old_type + ) + else: + return g.op("Flatten", input, axis_i=start_dim_i) + if start_dim_i == 0 and end_dim_i == dim - 2: + if symbolic_helper._try_get_scalar_type(input): + old_type, input = _try_cast_integer_to_float(g, input) + return _cast_to_type( + g, g.op("Flatten", input, axis_i=end_dim_i + 1), old_type + ) + else: + return g.op("Flatten", input, axis_i=end_dim_i + 1) + + return opset9.flatten(g, input, start_dim, end_dim) + + +def _constant_fill(g: jit_utils.GraphContext, sizes, dtype: int, const_value): + if dtype is None: + scalar_type = _type_utils.JitScalarType.FLOAT + else: + scalar_type = _type_utils.JitScalarType(dtype) + if not scalar_type.dtype().is_floating_point: + result = g.op( + "ConstantFill", + sizes, + dtype_i=_type_utils.JitScalarType.FLOAT.onnx_type(), + input_as_shape_i=1, + value_f=const_value, + ) + return g.op("Cast", result, to_i=scalar_type.onnx_type()) + else: + return g.op( + "ConstantFill", + sizes, + dtype_i=scalar_type.onnx_type(), + input_as_shape_i=1, + value_f=const_value, + ) + + +@_onnx_symbolic("aten::empty") +@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") +def empty( + g: jit_utils.GraphContext, + sizes, + dtype, + layout, + device, + pin_memory=False, + memory_format=None, +): + return zeros(g, sizes, dtype, layout, device, pin_memory) + + +@_onnx_symbolic("aten::empty_like") +@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") +def empty_like( + g: jit_utils.GraphContext, + input, + dtype, + layout, + device, + pin_memory=False, + memory_format=None, +): + return zeros_like(g, input, dtype, layout, device, pin_memory) + + +@_onnx_symbolic("aten::zeros") +@symbolic_helper.parse_args("v", "i", "v", "v", "v") +def zeros(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False): + # NOTE: no way to set device and layout in ONNX, so we ignore it + return _constant_fill(g, sizes, dtype, 0) + + +@_onnx_symbolic("aten::zeros_like") +@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") +def zeros_like( + g: jit_utils.GraphContext, + input, + dtype, + layout, + device, + pin_memory=False, + memory_format=None, +): + shape = g.op("Shape", input) + return _constant_fill(g, shape, dtype, 0) + + +@_onnx_symbolic("aten::ones") +@symbolic_helper.parse_args("v", "i", "v", "v", "v") +def ones(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False): + return _constant_fill(g, sizes, dtype, 1) + + +@_onnx_symbolic("aten::ones_like") +@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") +def ones_like( + g: jit_utils.GraphContext, + input, + dtype, + layout, + device, + pin_memory=False, + memory_format=None, +): + shape = g.op("Shape", input) + return _constant_fill(g, shape, dtype, 1) + + +@_onnx_symbolic("aten::full") +def full( + g: jit_utils.GraphContext, sizes, value, dtype, layout, device, pin_memory=False +): + const_value = symbolic_helper._maybe_get_const(value, "t") + if symbolic_helper._is_value(const_value): + tmp = zeros(g, sizes, dtype, layout, device) + return opset9.add(g, tmp, value, g.op("Constant", value_t=torch.tensor(1))) + else: + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + return _constant_fill(g, sizes, dtype, const_value) + + +@_onnx_symbolic("aten::full_like") +@symbolic_helper.parse_args("v", "f", "i", "v", "v", "v", "v") +def full_like( + g: jit_utils.GraphContext, + input, + fill_value, + dtype, + layout, + device, + pin_memory=False, + memory_format=None, +): + shape = g.op("Shape", input) + return _constant_fill(g, shape, dtype, fill_value) + + +@_onnx_symbolic("aten::repeat") +def repeat(g: jit_utils.GraphContext, self, repeats): + if not symbolic_helper._is_value(repeats): + repeats = g.op("Constant", value_t=torch.LongTensor(repeats)) + if symbolic_helper._is_packed_list(repeats): + repeat_size_len = len(symbolic_helper._unpack_list(repeats)) + else: + const_repeats = symbolic_helper._maybe_get_const(repeats, "is") + repeat_size_len = len(const_repeats) + if self.isCompleteTensor(): + sizes = self.type().sizes() + diff_dims = repeat_size_len - len(sizes) + if diff_dims > 0: + self = opset9.view( + g, self, g.op("Constant", value_t=torch.tensor([1] * diff_dims + sizes)) + ) + return g.op("Tile", self, repeats) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index bd0f4795340ae..68d83695d692c 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -1,3 +1,4 @@ +<<<<<<< HEAD """Backward compatibility module for torch.onnx.symbolic_opset9.""" from __future__ import annotations @@ -12,3 +13,6658 @@ _slice, _var_mean, ) +======= +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +# mypy: disable-error-code=arg-type +"""This file exports ONNX ops for opset 9. + +Opset 9 is supported by ONNX release 1.4.1 +release on 01/23/19 +""" + +from __future__ import annotations + +import builtins +import functools +import math +import sys +import warnings +from typing import Callable, TYPE_CHECKING +from typing_extensions import deprecated + +import torch +import torch._C._onnx as _C_onnx +import torch.nn.modules.utils +import torch.onnx +from torch import _C + +# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics +from torch.onnx import _constants, _type_utils, errors, symbolic_helper +from torch.onnx._globals import GLOBALS +from torch.onnx._internal import jit_utils, registration + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from torch.types import Number + +# EDITING THIS FILE? READ THIS FIRST! +# see Note [Edit Symbolic Files] in README.md + +__all__ = [ + "abs", + "acos", + "add", + "addcmul", + "addmm", + "alias", + "amax", + "amin", + "aminmax", + "arange", + "argmax", + "argmin", + "as_strided", + "as_tensor", + "asin", + "atan", + "atan2", + "baddbmm", + "batch_norm", + "bernoulli", + "bitwise_not", + "bitwise_or", + "bmm", + "broadcast_tensors", + "broadcast_to", + "bucketize", + "cat", + "cdist", + "ceil", + "clamp_max", + "clamp_min", + "clamp", + "clone", + "constant_pad_nd", + "contiguous", + "conv_tbc", + "conv_transpose1d", + "conv_transpose2d", + "conv_transpose3d", + "conv1d", + "conv2d", + "conv3d", + "convert_element_type", + "convolution", + "cos", + "cosine_similarity", + "cross", + "cumsum", + "detach", + "dim", + "div", + "dot", + "dropout", + "elu", + "embedding_bag", + "embedding", + "empty_like", + "empty", + "eq", + "erf", + "exp", + "expand_as", + "expand", + "eye", + "fill", + "flatten", + "floor_divide", + "floor", + "floordiv", + "frobenius_norm", + "full_like", + "full", + "gather", + "ge", + "gelu", + "get_pool_ceil_padding", + "glu", + "group_norm", + "gt", + "hann_window", + "hardshrink", + "hardsigmoid", + "hardswish", + "hardtanh", + "index_add", + "index_copy", + "index_fill", + "index_put", + "index_select", + "index", + "instance_norm", + "is_floating_point", + "is_pinned", + "isnan", + "item", + "kl_div", + "layer_norm", + "le", + "leaky_relu", + "lerp", + "lift", + "linalg_cross", + "linalg_matrix_norm", + "linalg_norm", + "linalg_vector_norm", + "linear", + "linspace", + "log_sigmoid", + "log_softmax", + "log", + "log10", + "log1p", + "log2", + "logical_and", + "logical_not", + "logical_or", + "logical_xor", + "logit", + "logsumexp", + "lstm_cell", + "lstm", + "lt", + "masked_fill", + "masked_fill_", + "matmul", + "max_pool1d_with_indices", + "max_pool2d_with_indices", + "max_pool3d_with_indices", + "max", + "maximum", + "meshgrid", + "min", + "minimum", + "mish", + "mm", + "movedim", + "mse_loss", + "mul", + "multinomial", + "mv", + "narrow", + "native_layer_norm", + "ne", + "neg", + "new_empty", + "new_full", + "new_ones", + "new_zeros", + "nonzero_numpy", + "nonzero", + "norm", + "numel", + "numpy_T", + "one_hot", + "ones_like", + "ones", + "onnx_placeholder", + "pad", + "pairwise_distance", + "permute", + "pixel_shuffle", + "pixel_unshuffle", + "pow", + "prelu", + "prim_constant_chunk", + "prim_constant_split", + "prim_constant", + "prim_data", + "prim_device", + "prim_dtype", + "prim_if", + "prim_layout", + "prim_list_construct", + "prim_list_unpack", + "prim_loop", + "prim_max", + "prim_min", + "prim_shape", + "prim_tolist", + "prim_tuple_construct", + "prim_type", + "prim_unchecked_cast", + "prim_uninitialized", + "rand_like", + "rand", + "randint_like", + "randint", + "randn_like", + "randn", + "reciprocal", + "reflection_pad", + "relu", + "relu6", + "remainder", + "repeat_interleave", + "repeat", + "replication_pad", + "reshape_as", + "reshape", + "roll", + "rrelu", + "rsqrt", + "rsub", + "scalar_tensor", + "scatter_add", + "scatter", + "select", + "selu", + "sigmoid", + "sign", + "silu", + "sin", + "size", + "slice", + "softmax", + "softplus", + "softshrink", + "sort", + "split_with_sizes", + "split", + "sqrt", + "square", + "squeeze", + "stack", + "std_mean", + "std", + "sub", + "t", + "take", + "tan", + "tanh", + "tanhshrink", + "tensor", + "threshold", + "to", + "topk", + "transpose", + "true_divide", + "type_as", + "unbind", + "unfold", + "unsafe_chunk", + "unsafe_split_with_sizes", + "unsafe_split", + "unsqueeze", + "unsupported_complex_operators", + "noop_complex_operators", + "unused", + "var_mean", + "var", + "view_as", + "view", + "where", + "wrap_logical_op_with_cast_to", + "wrap_logical_op_with_negation", + "zeros_like", + "zeros", + "zero", +] + + +_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=9) + + +def _export(name: str): + """Exports the function in the current global namespace.""" + + def wrapper(func): + globals()[name] = func + __all__.append(name) + return func + + return wrapper + + +def unused(g): + """Represents "missing" optional inputs.""" + n = g.op("prim::Constant") + n.setType(_C.OptionalType.ofTensor()) + return n + + +@_onnx_symbolic("aten::_shape_as_tensor") +def _shape_as_tensor(g: jit_utils.GraphContext, input): + return g.op("Shape", input) + + +@_onnx_symbolic("aten::_reshape_from_tensor") +def _reshape_from_tensor(g: jit_utils.GraphContext, input, shape): + if isinstance(shape, list): + shape = g.op("Concat", *shape, axis_i=0) + return reshape(g, input, shape) + + +@_onnx_symbolic("aten::reshape") +@symbolic_helper.quantized_args(True) +def reshape(g: jit_utils.GraphContext, self, shape): + return symbolic_helper._reshape_helper(g, self, shape) + + +@_onnx_symbolic("aten::reshape_as") +@symbolic_helper.quantized_args(True) +def reshape_as(g: jit_utils.GraphContext, self, other): + shape = g.op("Shape", other) + return reshape(g, self, shape) + + +@_onnx_symbolic("aten::add") +def add(g: jit_utils.GraphContext, self, other, alpha=None): + """ + This function takes the add function and returns the corresponding ONNX operator. + + This function is not meant to be called directly by the user. + + Args: + g (GraphContext): The graph context. + self (Tensor): The first operand. + other (Tensor): The second operand. + alpha (float, optional): The scaling factor for the second operand. Defaults to None. + + Returns: + ONNX operator. + """ + if symbolic_helper._is_value(self) and symbolic_helper._is_tensor_list(self): + return symbolic_helper._onnx_opset_unsupported_detailed( + "Add", 9, 11, "Add between list of tensors not supported", self + ) + if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1: + other = g.op("Mul", other, alpha) + return g.op("Add", self, other) + + +@_onnx_symbolic("aten::sub") +def sub(g: jit_utils.GraphContext, self, other, alpha=None): + """ + Consumes sub function and returns the corresponding ONNX operator. + + This function is not meant to be called directly by the user. + + Args: + g (GraphContext): The graph context. + self (Tensor): The first operand. + other (Tensor): The second operand. + alpha (Optional[Tensor]): A scaling factor to apply to the second operand. + If `alpha` is not provided, it defaults to 1. + + Returns: + ONNX operator + """ + if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1: + other = g.op("Mul", other, alpha) + return g.op("Sub", self, other) + + +@_onnx_symbolic("aten::rsub") +def rsub(g: jit_utils.GraphContext, self, other, alpha=None): + return sub(g, other, self, alpha=alpha) + + +@_onnx_symbolic("aten::mul") +def mul(g: jit_utils.GraphContext, self, other): + if symbolic_helper._is_bool(self) and symbolic_helper._is_bool(other): + # ONNX Mul doesn't support Boolean, so use And as an equivalent operator. + return g.op("And", self, other) + else: + return g.op("Mul", self, other) + + +@_onnx_symbolic("aten::div") +def div(g: jit_utils.GraphContext, self, other, *args): + if len(args) == 0: + return true_divide(g, self, other) + else: + return _div_rounding_mode(g, self, other, *args) + + +@_onnx_symbolic("aten::addcmul") +@symbolic_helper.parse_args("v", "v", "v", "f") +def addcmul(g: jit_utils.GraphContext, self, tensor1, tensor2, value=1.0): + value_tens = g.op("Constant", value_t=torch.tensor([value])) + return add(g, self, mul(g, mul(g, tensor1, tensor2), value_tens)) + + +@symbolic_helper.parse_args("v", "v", "s") +def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode): + if rounding_mode is None: + return true_divide(g, self, other) + elif rounding_mode == "floor": + return _floor_divide(g, self, other) + elif rounding_mode == "trunc": + return _trunc_divide(g, self, other) + else: + raise errors.SymbolicValueError( + f'Unsupported rounding mode: "{rounding_mode}". Expected None, "floor" or "trunc"', + self, + ) + + +def _trunc_divide(g: jit_utils.GraphContext, self, other): + out = g.op("Div", self, other) + # the correct operation is truncate, which is not supported in ONNX, + # we cannot call floor since it will behave differently for negative numbers + # (eg. -0.1 should become -0 ) + # - if scalar_type information are not available, assume that + # we need to call floor (treat as float) + out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.INT64) + + # Matching PyTorch's behavior: + # - if self is fp the output's type is self's type + # - if self is not fp and other is fp, the output is of type JitScalarType.FLOAT + # - self is not fp and other is not fp, the output's type is self's output type + # - the output type defaults to Float + scalar_type = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.UNDEFINED + ) + if scalar_type != _type_utils.JitScalarType.UNDEFINED: + if not symbolic_helper._is_fp(self) and symbolic_helper._is_fp(other): + out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT) + else: + out = g.op( + "Cast", + out, + to_i=scalar_type.onnx_type(), + ) + else: + out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT) + return out + + +def _floor_divide(g: jit_utils.GraphContext, self, other): + if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other): + out = true_divide(g, self, other) + return g.op("Floor", out) + else: + # Integer division does trunction rounding + div = g.op("Div", self, other) + # Division is negative if: self < 0 != other < 0 + zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)) + negative = g.op( + "Xor", + symbolic_helper._lt_helper(g, self, zero), + symbolic_helper._lt_helper(g, other, zero), + ) + + # For negative numbers with self % other != 0, subtract 1 to round down instead of up + mod = g.op("Sub", self, g.op("Mul", div, other)) + fixup_mask = g.op("And", negative, g.op("Not", g.op("Equal", mod, zero))) + + one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) + fixup = g.op("Mul", fixup_mask, one) + return g.op("Sub", div, fixup) + + +@_onnx_symbolic("aten::floor_divide") +def floor_divide(g: jit_utils.GraphContext, self, other): + # Deprecated behavior, floor_divide actually truncates + return _trunc_divide(g, self, other) + + +@_onnx_symbolic("aten::floordiv") +def floordiv(g: jit_utils.GraphContext, self, other): + return floor_divide(g, self, other) + + +@_onnx_symbolic("aten::true_divide") +def true_divide(g: jit_utils.GraphContext, self, other): + """Division where both inputs are cast to floating types + + If both inputs are floating, performs div as usual + If only one input is a floating type, the other input is cast to its type + If neither input is a floating type, both inputs are cast to the default scalar type + """ + + # Case 1: either values are floating + # Performs div as usual. + # Implicit casting will be handled in scalar type analysis pass. + if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other): + return g.op("Div", self, other) + + # Case 2: neither is floating + # Casts both inputs to the default scalar type + scalar_type = torch.get_default_dtype() + onnx_scalar_type = _C_onnx.TensorProtoDataType.FLOAT + assert scalar_type is torch.float or scalar_type is torch.double + if torch.get_default_dtype() is torch.double: + onnx_scalar_type = _C_onnx.TensorProtoDataType.DOUBLE + + self = g.op("Cast", self, to_i=onnx_scalar_type) + other = g.op("Cast", other, to_i=onnx_scalar_type) + return g.op("Div", self, other) + + +@_onnx_symbolic("aten::reciprocal") +def reciprocal(g: jit_utils.GraphContext, self): + # torch.reciprocal implicitly casts to float, so we do the same. + if not symbolic_helper._is_fp(self): + self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT) + return g.op("Reciprocal", self) + + +@_onnx_symbolic("aten::cat") +@symbolic_helper.parse_args("v", "i") +def cat(g: jit_utils.GraphContext, tensor_list, dim): + """Implement concatenation of pytorch tensors in ONNX along the specified `dim` dimension. + + Parameters: + g (jit_utils.GraphContext): Graph context. + tensor_list (List[torch.Tensor]): List of tensors to concatenate. + dim (int): Dimension along which to concatenate the tensors. + + Returns: + ONNX graph node representing the concatenated tensor. + """ + tensors = symbolic_helper._unpack_list(tensor_list) + # torch.cat ignores empty tensors such as `torch.Tensor([])` + # These needs to be removed as input from ONNX's concat too, otherwise shape inference + # will likely fail due to inputs with different ranks (0 for empty tensor, > 0 for anything else) + nonempty_tensors = [] + for t in tensors: + if symbolic_helper._is_constant(t) and not symbolic_helper._get_tensor_dim_size( + t, 0 + ): + continue + nonempty_tensors.append(t) + assert len(nonempty_tensors) > 0 + assert all( + symbolic_helper._get_tensor_rank(nonempty_tensors[0]) is None + or symbolic_helper._get_tensor_rank(t) is None + or symbolic_helper._get_tensor_rank(t) + == symbolic_helper._get_tensor_rank(nonempty_tensors[0]) + for t in nonempty_tensors + ) + tensor_list.node().removeAllInputs() + for t in nonempty_tensors: + tensor_list.node().addInput(t) + + tensors = symbolic_helper._unpack_list(tensor_list) + return g.op("Concat", *tensors, axis_i=dim) + + +@_onnx_symbolic("aten::stack") +@symbolic_helper.parse_args("v", "i") +def stack(g: jit_utils.GraphContext, tensor_list, dim): + unsqueezed = [ + symbolic_helper._unsqueeze_helper(g, t, [dim]) + for t in symbolic_helper._unpack_list(tensor_list) + ] + return g.op("Concat", *unsqueezed, axis_i=dim) + + +@_onnx_symbolic("aten::list") +def _list(g: jit_utils.GraphContext, self): + return self + + +@_onnx_symbolic("aten::mm") +def mm(g: jit_utils.GraphContext, self, other): + # Create a dummy C tensor. Only needed for API purposes, the value is + # since beta = 0 + C = g.op("Constant", value_t=torch.tensor([1])) + return g.op("Gemm", self, other, C, beta_f=0.0, alpha_f=1.0) + + +@_onnx_symbolic("aten::bmm") +def bmm(g: jit_utils.GraphContext, self, other): + return g.op("MatMul", self, other) + + +@_onnx_symbolic("aten::matmul") +def matmul(g: jit_utils.GraphContext, self, other): + return g.op("MatMul", self, other) + + +@_onnx_symbolic("aten::addmm") +@symbolic_helper.parse_args("v", "v", "v", "t", "t") +def addmm(g: jit_utils.GraphContext, self, mat1, mat2, beta, alpha): + scalar_type = None + self_scalar_type = symbolic_helper._try_get_scalar_type(self) + mat1_scalar_type = symbolic_helper._try_get_scalar_type(mat1) + mat2_scalar_type = symbolic_helper._try_get_scalar_type(mat2) + if self_scalar_type is not None: + scalar_type = self_scalar_type + elif mat1_scalar_type is not None: + scalar_type = mat1_scalar_type + elif mat2_scalar_type is not None: + scalar_type = mat2_scalar_type + + mat1_rank = symbolic_helper._get_tensor_rank(mat1) + mat2_rank = symbolic_helper._get_tensor_rank(mat2) + + def is_not_none_nor(v, u): + return v is not None and v != u + + if scalar_type is not None and ( + is_not_none_nor(mat1_rank, 2) or is_not_none_nor(mat2_rank, 2) + ): + res1 = g.op("MatMul", mat1, mat2) + res2 = self + + alpha = symbolic_helper._scalar(alpha) + beta = symbolic_helper._scalar(beta) + + if alpha != 1: + alpha = g.op( + "Constant", value_t=torch.tensor(alpha, dtype=scalar_type.dtype()) + ) + res1 = g.op("Mul", res1, alpha) + if beta != 1: + beta = g.op( + "Constant", + value_t=torch.tensor( + symbolic_helper._scalar(beta), dtype=scalar_type.dtype() + ), + ) + res2 = g.op("Mul", res2, beta) + + return g.op("Add", res1, res2) + + return g.op( + "Gemm", + mat1, + mat2, + self, + beta_f=symbolic_helper._scalar(beta), + alpha_f=symbolic_helper._scalar(alpha), + ) + + +@_onnx_symbolic("aten::neg") +def neg(g: jit_utils.GraphContext, self): + return g.op("Neg", self) + + +@_onnx_symbolic("aten::sqrt") +def sqrt(g: jit_utils.GraphContext, self): + if _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.UNDEFINED + ) in { + _type_utils.JitScalarType.UINT8, + _type_utils.JitScalarType.INT8, + _type_utils.JitScalarType.INT16, + _type_utils.JitScalarType.INT, + _type_utils.JitScalarType.INT64, + }: + # torch converts all int inputs to sqrt to float + self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT) + + return g.op("Sqrt", self) + + +@_onnx_symbolic("aten::rsqrt") +def rsqrt(g: jit_utils.GraphContext, self): + return g.op( + "Div", symbolic_helper._if_scalar_type_as(torch.ones(1), self), sqrt(g, self) + ) + + +@_onnx_symbolic("aten::tanh") +# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qtanh.cpp +@symbolic_helper.quantized_args(True, scale=2.0 / 256.0, zero_point=128) +def tanh(g: jit_utils.GraphContext, self): + return g.op("Tanh", self) + + +@_onnx_symbolic("aten::sin") +def sin(g: jit_utils.GraphContext, self): + return g.op("Sin", self) + + +@_onnx_symbolic("aten::cos") +def cos(g: jit_utils.GraphContext, self): + return g.op("Cos", self) + + +@_onnx_symbolic("aten::tan") +def tan(g: jit_utils.GraphContext, self): + return g.op("Tan", self) + + +@_onnx_symbolic("aten::asin") +def asin(g: jit_utils.GraphContext, self): + return g.op("Asin", self) + + +@_onnx_symbolic("aten::acos") +def acos(g: jit_utils.GraphContext, self): + return g.op("Acos", self) + + +@_onnx_symbolic("aten::atan") +def atan(g: jit_utils.GraphContext, self): + return g.op("Atan", self) + + +@_onnx_symbolic("aten::atan2") +def atan2(g: jit_utils.GraphContext, self, other): + # self is y, and other is x on coordinate + slope = g.op("Div", self, other) + atan = g.op("Atan", slope) + const_zero = g.op("Constant", value_t=torch.tensor(0)) + const_pi = g.op("Constant", value_t=torch.tensor(math.pi)) + + condition_second_or_third_quadrant = g.op("Greater", self, const_zero) + second_third_quadrant = g.op( + "Where", + condition_second_or_third_quadrant, + g.op("Add", atan, const_pi), + g.op("Sub", atan, const_pi), + ) + + condition_14_or_23_quadrant = g.op("Less", other, const_zero) + result = g.op("Where", condition_14_or_23_quadrant, second_third_quadrant, atan) + + return result + + +@_onnx_symbolic("aten::sigmoid") +# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qsigmoid.cpp +@symbolic_helper.quantized_args(True, scale=1.0 / 256.0, zero_point=0) +def sigmoid(g: jit_utils.GraphContext, self): + """Converts the corresponding PyTorch function into ONNX operators. + + It is not meant to be called directly by a user. + + Args: + g (jit_utils.GraphContext): Graph context. + self (Tensor): the input tensor. + Returns: + ONNX operator + """ + return g.op("Sigmoid", self) + + +@_onnx_symbolic("aten::sign") +def sign(g: jit_utils.GraphContext, self): + return g.op("Sign", self) + + +@symbolic_helper.quantized_args(True) +def _slice(g: jit_utils.GraphContext, input, axes, starts, ends): + assert len(starts) == len(ends) + if len(starts) == 1 and starts[0] == 0 and ends[0] == _constants.INT64_MAX: + return input + return g.op("Slice", input, axes_i=axes, starts_i=starts, ends_i=ends) + + +@_onnx_symbolic( + "aten::sum", decorate=[symbolic_helper._apply_params("ReduceSum", "sum")] +) +@_onnx_symbolic( + "aten::mean", decorate=[symbolic_helper._apply_params("ReduceMean", "mean")] +) +# torch.prod does not support multidimensional "dim" +@_onnx_symbolic( + "aten::prod", + decorate=[ + symbolic_helper._apply_params( + "ReduceProd", "prod", allow_multi_dim_support=False + ) + ], +) +def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool = True): + return symbolic_helper._reduce_with_dtype_helper( + onnx_op, name, allow_multi_dim_support + ) + + +@_onnx_symbolic("aten::cumsum") +@symbolic_helper.parse_args("v", "i", "none") +def cumsum(g: jit_utils.GraphContext, input, dim, dtype): + symbolic_helper._onnx_opset_unsupported("cumsum", 9, 11, input) + + +@_onnx_symbolic("aten::_sample_dirichlet") +def _sample_dirichlet(g: jit_utils.GraphContext, self, generator): + return symbolic_helper._onnx_unsupported("_sample_dirichlet", self) + + +@_onnx_symbolic("aten::_standard_gamma") +def _standard_gamma(g: jit_utils.GraphContext, self, generator): + return symbolic_helper._onnx_unsupported("_standard_gamma", self) + + +@_onnx_symbolic("aten::t") +def t(g: jit_utils.GraphContext, self): + rank = symbolic_helper._get_tensor_rank(self) + if rank is None or rank < 2: + # The transpose of a 1d or 0d tensor is itself. ONNX does not define the behavior + # clearly and onnxruntime fails on these cases. So we add an Identity node to + # mirror the behavior of eager mode. + return g.op("Identity", self) + return g.op("Transpose", self, perm_i=(1, 0)) + + +@_onnx_symbolic("aten::numpy_T") +@symbolic_helper.quantized_args(True) +def numpy_T(g: jit_utils.GraphContext, input): + ndim = symbolic_helper._get_tensor_rank(input) + assert ndim is not None + perm = list(reversed(range(0, ndim))) + return g.op("Transpose", input, perm_i=perm) + + +@_onnx_symbolic("aten::expand") +@symbolic_helper.quantized_args(True) +def expand(g: jit_utils.GraphContext, self, size, implicit): + """Implement the expand function for a pytorch tensor in ONNX according to specified `size`""" + size = symbolic_helper._maybe_get_const(size, "is") + if not symbolic_helper._is_value(size): + size = g.op("Constant", value_t=torch.LongTensor(size)) + elif symbolic_helper._is_packed_list(size): + # Expand with -1 dim value means dim is unchanged. + # Since onnx::expand supports two-way broadcasting, + # -1 dim value can be exported to onnx as 1 + size = symbolic_helper._reshape_helper( + g, stack(g, size, 0), g.op("Constant", value_t=torch.tensor([-1])) + ) + dtype = _type_utils.JitScalarType.INT64 + ones = ones_like(g, size, dtype) + neg_ones = mul(g, ones, g.op("Constant", value_t=torch.tensor(-1))) + size = where(g, g.op("Equal", size, neg_ones), ones, size) + return g.op("Expand", self, size) + + +@_onnx_symbolic("aten::broadcast_to") +@symbolic_helper.quantized_args(True) +def broadcast_to(g: jit_utils.GraphContext, self, size): + size = symbolic_helper._maybe_get_const(size, "is") + if not symbolic_helper._is_value(size): + size = g.op("Constant", value_t=torch.LongTensor(size)) + elif symbolic_helper._is_packed_list(size): + # Expand with -1 dim value means dim is unchanged. + # Since onnx::expand supports two-way broadcasting, + # -1 dim value can be exported to onnx as 1 + size = symbolic_helper._reshape_helper( + g, stack(g, size, 0), g.op("Constant", value_t=torch.tensor([-1])) + ) + dtype = _type_utils.JitScalarType.INT64 + ones = ones_like(g, size, dtype) + neg_ones = mul(g, ones, g.op("Constant", value_t=torch.tensor(-1))) + size = where(g, g.op("Equal", size, neg_ones), ones, size) + return g.op("Expand", self, size) + + +@_onnx_symbolic("aten::expand_as") +@symbolic_helper.quantized_args(True, True) +def expand_as(g: jit_utils.GraphContext, self, other): + self_t = symbolic_helper._maybe_get_const(self, "t") + if isinstance(self_t, torch.Tensor): + orig_type = self_t.dtype + self_t = self_t.to(torch.double) + dims = [] + for d in range(self_t.dim()): + if torch.equal(self_t.mean(d).unsqueeze(d).expand_as(self_t), self_t): + dims.append(d) + self = g.op( + "Constant", value_t=self_t.mean(dims, keepdim=True).to(orig_type) + ) + + shape = g.op("Shape", other) + return g.op("Expand", self, shape) + + +@_onnx_symbolic("aten::embedding") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "v", "i", "b", "v") +def embedding( + g: jit_utils.GraphContext, + weight, + indices, + padding_idx, + scale_grad_by_freq, + sparse, +): + if scale_grad_by_freq and GLOBALS.export_training: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of embedding with scale_grad_by_freq=True " + "for training mode. ONNX does not support scaling the gradients.", + weight, + ) + if padding_idx >= 0 and GLOBALS.export_training: + warnings.warn( + "Warning: ONNX export of embedding with padding_idx >= 0 " + "for training mode. " + "ONNX does not support not updating the embedding vector at padding_idx during training." + ) + + return g.op("Gather", weight, indices) + + +@_onnx_symbolic("aten::embedding_bag") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i") +def embedding_bag( + g: jit_utils.GraphContext, + embedding_matrix, + indices, + offsets, + scale_grad_by_freq, + mode, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, +): + if not symbolic_helper._is_none(per_sample_weights): + return symbolic_helper._onnx_unsupported( + "embedding_bag with per_sample_weights" + ) + + return symbolic_helper._onnx_unsupported("embedding_bag", embedding_matrix) + + +@_onnx_symbolic("aten::size") +@symbolic_helper.quantized_args(True, quantize_output=False) +def size(g: jit_utils.GraphContext, self, dim=None): + if dim is None: + return g.op("Shape", self) + if symbolic_helper._maybe_get_const(dim, "i") < 0: + rank = symbolic_helper._get_tensor_rank(self) + if rank is not None: + dim = symbolic_helper._maybe_get_const(dim, "i") + rank + dim = g.op("Constant", value_t=torch.tensor(dim)) + return symbolic_helper._size_helper(g, self, dim) + + +@_onnx_symbolic("aten::transpose") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "i", "i") +def transpose(g: jit_utils.GraphContext, self, dim0, dim1): + if dim0 == dim1: # micro-optimization + return self + + # NB: Transpose in ONNX is actually a Permute + rank = symbolic_helper._get_tensor_rank(self) + if rank is not None: + axes = list(range(rank)) + axes[dim0], axes[dim1] = axes[dim1], axes[dim0] + return g.op("Transpose", self, perm_i=axes) + else: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of transpose for tensor of unknown rank.", + self, + ) + + +@_onnx_symbolic("aten::permute") +@symbolic_helper.parse_args("v", "is") +def permute(g: jit_utils.GraphContext, self, dims): + if dims == list(range(0, len(dims))): + return self + return g.op("Transpose", self, perm_i=dims) + + +@_onnx_symbolic("aten::view") +@symbolic_helper.quantized_args(True) +def view(g: jit_utils.GraphContext, self, size): + return reshape(g, self, size) + + +@_onnx_symbolic("aten::view_as") +def view_as(g: jit_utils.GraphContext, self, other): + shape = g.op("Shape", other) + return reshape(g, self, shape) + + +@_onnx_symbolic("aten::unsafe_chunk") +@symbolic_helper.parse_args("v", "i", "i", "i") +def unsafe_chunk(g: jit_utils.GraphContext, self, chunks, dim, _outputs=None): + if _outputs is None: + return symbolic_helper._onnx_opset_unsupported_detailed( + "unsafe_chunk", 9, 11, "Dynamic number of outputs not supported", self + ) + size = symbolic_helper._get_tensor_dim_size(self, dim) + if size is None: + return symbolic_helper._unimplemented( + "unsafe_chunk", "unknown dimension size", self + ) + split_size = (size + chunks - 1) // chunks + splits = [split_size] * (size // split_size) + leftover = size % split_size + if leftover: + splits.append(leftover) + return g.op("Split", self, split_i=splits, axis_i=dim, outputs=_outputs) + + +@_onnx_symbolic("aten::split") +@symbolic_helper.parse_args("v", "v", "i", "i") +def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None): + if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs): + return symbolic_helper._onnx_opset_unsupported_detailed( + "split", 9, 11, "Dynamic number of outputs not supported", self + ) + split_val = symbolic_helper._node_get(split_size_or_sizes.node(), "value") + if split_val.dim() > 0: + return split_with_sizes(g, self, split_size_or_sizes, dim, _outputs) + split_size = symbolic_helper._get_const(split_size_or_sizes, "i", "split_size") + + size = symbolic_helper._get_tensor_dim_size(self, dim) + if size is None: + if _outputs is not None: + size = split_size * _outputs + else: + return symbolic_helper._onnx_opset_unsupported_detailed( + "split", 9, 11, "Unknown dimension size not supported", self + ) + splits = [split_size] * (size // split_size) + leftover = size % split_size + if leftover: + splits.append(leftover) + return g.op("Split", self, split_i=splits, axis_i=dim, outputs=_outputs) + + +@_onnx_symbolic("aten::unsafe_split") +def unsafe_split( + g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None +): + return split(g, self, split_size_or_sizes, dim, _outputs) + + +@_onnx_symbolic("aten::split_with_sizes") +@symbolic_helper.parse_args("v", "is", "i", "i") +def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None): + if not symbolic_helper._is_split_static(split_sizes, _outputs): + return symbolic_helper._onnx_opset_unsupported_detailed( + "split_with_sizes", 9, 11, "Dynamic number of outputs not supported", self + ) + return g.op("Split", self, split_i=split_sizes, axis_i=dim, outputs=_outputs) + + +@_onnx_symbolic("aten::unsafe_split_with_sizes") +def unsafe_split_with_sizes( + g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None +): + return split_with_sizes(g, self, split_sizes, dim, _outputs) + + +@_onnx_symbolic("aten::unbind") +@symbolic_helper.parse_args("v", "i", "i") +def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None): + if _outputs is None: + return symbolic_helper._onnx_opset_unsupported_detailed( + "unbind", 9, 11, "Dynamic number of outputs not supported", self + ) + + outputs = g.op("Split", self, split_i=[1] * _outputs, axis_i=dim, outputs=_outputs) + outputs = [outputs] if _outputs == 1 else outputs + squeezed_outputs = [ + symbolic_helper._squeeze_helper(g, out, [dim]) for out in outputs + ] + return squeezed_outputs + + +@_onnx_symbolic("aten::select") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "i", "v") +def select(g: jit_utils.GraphContext, self, dim, index): + """Implement the select functionality for a pytorch tensor in ONNX. + + Selects elements from the input tensor along the specified `dim` dimension based on the `index` tensor. + """ + index = symbolic_helper._maybe_get_scalar(index) + if (not symbolic_helper._is_value(index)) and (index < 0): + if index == -1: + end_index = _constants.INT64_MAX + else: + end_index = index + 1 + slice_node = symbolic_helper._slice_helper( + g, self, axes=[dim], starts=[index], ends=[end_index] + ) + return symbolic_helper._squeeze_helper(g, slice_node, [dim]) + else: + # FIXME(justinchuby): can index be an int and not a value? + return g.op("Gather", self, index, axis_i=dim) + + +@_onnx_symbolic("aten::square") +def square(g: jit_utils.GraphContext, self): + return g.op("Mul", self, self) + + +@_onnx_symbolic("aten::squeeze") +def squeeze(g: jit_utils.GraphContext, self, dim=None): + if dim is None: + return g.op("Squeeze", self) + + squeeze_dim = symbolic_helper._get_const(dim, "i", "dim") + # Handle negative dims + if squeeze_dim < 0: + rank = symbolic_helper._get_tensor_rank(self) + if rank is not None: + warnings.warn( + "ONNX export squeeze with negative axis " + + str(squeeze_dim) + + " might cause the onnx model to be incorrect. " + + "Negative axis is not supported in ONNX. " + + "Axis is converted to " + + str(squeeze_dim + rank) + + " based on input shape at export time. " + + "Passing an tensor of different rank in execution will be incorrect." + ) + squeeze_dim += rank + else: + return symbolic_helper._unimplemented( + "squeeze", "negative axis with unknown input rank", self + ) + + dim_size = symbolic_helper._get_tensor_dim_size(self, squeeze_dim) + if dim_size is None: + warnings.warn( + "This model contains a squeeze operation on dimension " + + str(squeeze_dim) + + " on an input " + + "with unknown shape. Note that if the size of dimension " + + str(squeeze_dim) + + " of the input " + + "is not 1, the ONNX model will return an error. Opset version 11 supports squeezing on " + + "non-singleton dimensions, it is recommended to export this model using opset " + + "version 11 or higher." + ) + return symbolic_helper._squeeze_helper(g, self, axes_i=[squeeze_dim]) + if dim_size > 1: + warnings.warn( + "This model contains a squeeze operation on dimension " + + str(squeeze_dim) + + ". The size of " + + "this dimension in the given input is " + + str(dim_size) + + ". The model will " + + "be exported without the squeeze node. If the model is intended to be used with dynamic " + + "input shapes, please use opset version 11 to " + + "export the model." + ) + return self + + warnings.warn( + "This model contains a squeeze operation on dimension " + + str(squeeze_dim) + + ". If the model is " + + "intended to be used with dynamic input shapes, please use opset version 11 to export the model." + ) + return symbolic_helper._squeeze_helper(g, self, axes_i=[squeeze_dim]) + + +@_onnx_symbolic("aten::prelu") +def prelu(g: jit_utils.GraphContext, self, weight): + self_rank = symbolic_helper._get_tensor_rank(self) + weight_sizes = symbolic_helper._get_tensor_sizes(weight) + weight_rank = len(weight_sizes) + if self_rank is not None: + if self_rank > 2: + # make weight unidirectional broadcastable + weight = symbolic_helper._unsqueeze_helper( + g, weight, list(range(1, self_rank - 1)) + ) + elif self_rank == 0 and weight_sizes == [1]: + # self and weight are both scalar but weight has rank == 1, squeeze weight. + weight = symbolic_helper._squeeze_helper(g, weight, [0]) + weight_rank = 0 + + if self_rank is not None and weight_rank is not None: + assert self_rank >= weight_rank, ( + f"rank(x) should be >= rank(slope) but got {self_rank} < {weight_rank}" + ) + return g.op("PRelu", self, weight) + + +@_onnx_symbolic("aten::silu") +def silu(g: jit_utils.GraphContext, input): + return g.op("Mul", input, g.op("Sigmoid", input)) + + +@_onnx_symbolic("aten::mish") +def mish(g: jit_utils.GraphContext, input): + return g.op("Mul", input, g.op("Tanh", g.op("Softplus", input))) + + +@_onnx_symbolic("aten::relu") +@symbolic_helper.quantized_args(True) +def relu(g: jit_utils.GraphContext, input): + return symbolic_helper._op_with_optional_float_cast( + g, "Relu", input, opset_before=14 + ) + + +@_onnx_symbolic("aten::relu6") +@symbolic_helper.quantized_args(True) +def relu6(g: jit_utils.GraphContext, input): + return clamp(g, input, 0, 6) + + +@_onnx_symbolic("aten::ceil") +def ceil(g: jit_utils.GraphContext, input): + return g.op("Ceil", input) + + +@_onnx_symbolic("aten::floor") +def floor(g: jit_utils.GraphContext, input): + return g.op("Floor", input) + + +@_onnx_symbolic("aten::len") +def _len(g: jit_utils.GraphContext, self): + sz_0 = size(g, self, g.op("Constant", value_t=torch.LongTensor([0]))) + return symbolic_helper._squeeze_helper(g, sz_0, [0]) + + +@_onnx_symbolic("aten::threshold") +@symbolic_helper.parse_args("v", "t", "t") +def threshold(g: jit_utils.GraphContext, self, threshold, value): + # See Note [Export inplace] + if symbolic_helper._scalar(threshold) != 0: + return symbolic_helper._unimplemented("threshold", "non-zero threshold", self) + if symbolic_helper._scalar(value) != 0: + return symbolic_helper._unimplemented("threshold", "non-zero value", self) + return g.op("Relu", self) + + +@_onnx_symbolic("aten::leaky_relu") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "f", "b") +def leaky_relu( + g: jit_utils.GraphContext, + input: _C.Value, + negative_slope: float, + inplace: bool = False, +): + # See Note [Export inplace] + return g.op("LeakyRelu", input, alpha_f=negative_slope) + + +@_onnx_symbolic("aten::glu") +@symbolic_helper.parse_args("v", "i") +def glu(g: jit_utils.GraphContext, input, dim): + dim_size = symbolic_helper._get_tensor_dim_size(input, dim) + if dim_size is not None: + assert dim_size % 2 == 0 + + first, second = g.op("Split", input, axis_i=dim, outputs=2) + return g.op("Mul", first, g.op("Sigmoid", second)) + + +@_onnx_symbolic("aten::softmax") +@symbolic_helper.parse_args("v", "i", "none") +def softmax(g: jit_utils.GraphContext, input, dim, dtype=None): + # Softmax does normalization at vector level. + # PyTorch and ONNX use different strategies to split the input tensor into vectors. + # Thus dim and axis have different meanings. + # PyTorch slices the input tensor into vectors along the `dim`-th dimension. + # ONNX reshapes the input into a 2-D tensor, and `axis` indicates where the input is coerced. + # If input is a 2 x 3 tensor: + # input = [[1.0, 1.0, 1.0], + # [1.0, 1,0, 1,0]] + # with dim = 0, the result is: + # result = [[0.5, 0.5, 0.5], + # [0.5, 0.5, 0.5]] + # with axis = 0, the result is: + # result = [[0.167, 0.167, 0.167], + # [0.167, 0.167, 0.167]] + # So only when dim and axis both equal to ndim - 1 (the last dimension), + # their semantics are equivalent. + # So use softmax when dim and axis both equal to ndim - 1, + # otherwise transpose the input to put the vectors to be normalized to the last dimension. + # When input rank is not known at export time we compute softmax using a subgraph + # with other operators + input_dim = symbolic_helper._get_tensor_rank(input) + if input_dim is not None: + # TODO: remove this as onnx opset 11 spec allows negative axes + if dim < 0: + dim = input_dim + dim + + is_transpose_required = input_dim != dim + 1 + + if is_transpose_required: + axes = list(range(input_dim)) + axes[dim], axes[-1] = axes[-1], axes[dim] + input = g.op("Transpose", input, perm_i=axes) + dim = input_dim - 1 + + softmax = g.op("Softmax", input, axis_i=dim) + if dtype and dtype.node().kind() != "prim::Constant": + parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") + softmax = g.op( + "Cast", + softmax, + to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type(), + ) + + if is_transpose_required: + softmax = g.op("Transpose", softmax, perm_i=axes) # type: ignore[possibly-undefined] + return softmax + + # Apply max normalization. + input = g.op("Sub", input, g.op("ReduceMax", input, axes_i=[dim], keepdims_i=1)) + + exp = g.op("Exp", input) + sum = symbolic_helper._reducesum_helper(g, exp, axes_i=[dim]) + softmax = g.op("Div", exp, sum) + if dtype and dtype.node().kind() != "prim::Constant": + parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") + softmax = g.op( + "Cast", softmax, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type() + ) + return softmax + + +@_onnx_symbolic("aten::softplus") +def softplus(g: jit_utils.GraphContext, self, beta, threshold): + beta_const = symbolic_helper._maybe_get_const(beta, "f") + if beta_const != 1: + return g.op("Div", g.op("Softplus", g.op("Mul", self, beta)), beta) + return g.op("Softplus", self) + + +@_onnx_symbolic("aten::get_pool_ceil_padding") +def get_pool_ceil_padding(input, kernel_size, stride, padding): + # TODO(justinchuby): Looks like this op is deprecated in torch + sizes = symbolic_helper._get_tensor_sizes(input) + dim = sizes[-len(padding) :] if sizes is not None else None + if dim is None or any(i is None for i in dim): + return symbolic_helper._unimplemented( + "get_pool_ceil_padding", "input size not accessible", input + ) + ceiled_output_dim = [ + int(math.ceil((dim[i] + 2 * padding[i] - kernel_size[i]) / float(stride[i]))) + + 1 + for i in range(0, len(padding)) + ] + # ensure last pooling starts inside + ceiled_output_dim = [ + ( + ceiled_output_dim[i] - 1 + if (((ceiled_output_dim[i] - 1) * stride[i]) >= (dim[i] + padding[i])) + else ceiled_output_dim[i] + ) + for i in range(0, len(ceiled_output_dim)) + ] + padding_ceil = [ + ( + 0 + if (stride[i] == 1) + else ( + kernel_size[i] + - ( + dim[i] + + 2 * padding[i] + - ((ceiled_output_dim[i] - 1) * stride[i] + 1) + ) + ) + ) + for i in range(0, len(padding)) + ] + # ensure padding is not > kernel_size + padding_ceil = [ + ( + ( + int(padding_ceil[i]) + if padding_ceil[i] < kernel_size[i] - 1 + else int(kernel_size[i] - 1) + ) + if ((padding_ceil[i] + 2 * padding[i]) >= (kernel_size[i])) + else int(padding_ceil[i]) + ) + for i in range(0, len(padding_ceil)) + ] + return padding_ceil + + +@_onnx_symbolic( + "aten::max_pool1d", + decorate=[ + symbolic_helper._apply_params( + "max_pool1d", torch.nn.modules.utils._single, 1, return_indices=False + ), + _export("max_pool1d"), + ], +) +@_onnx_symbolic( + "aten::max_pool2d", + decorate=[ + symbolic_helper._apply_params( + "max_pool2d", torch.nn.modules.utils._pair, 2, return_indices=False + ), + _export("max_pool2d"), + ], +) +@_onnx_symbolic( + "aten::max_pool3d", + decorate=[ + symbolic_helper._apply_params( + "max_pool3d", torch.nn.modules.utils._triple, 3, return_indices=False + ), + _export("max_pool3d"), + ], +) +def _max_pool(name, tuple_fn, ndims, return_indices): + @symbolic_helper.quantized_args(True, False, False, False, False, False) + @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i") + def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode): + if set(tuple_fn(dilation)) != {1}: + return symbolic_helper._unimplemented(name, "dilation", input) + if not stride: + stride = kernel_size + padding = tuple(tuple_fn(padding)) + if ceil_mode: + padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding) + padding = padding + tuple(a + b for (a, b) in zip(padding_ceil, padding)) + else: + padding = padding * 2 + kwargs = { + "kernel_shape_i": tuple_fn(kernel_size), + "pads_i": padding, + "strides_i": tuple_fn(stride), + } + # easy but hacky way to get flattened indices values + # to be used to convert the indices values to non-flattened. + # In ONNX the indices are computed as a flatten 1-D tensor, + # so the values in indices are in [0, N x C x D1 x ... x Dn). + # To convert the indices to the same format used by Pytorch, + # we first execute a maxpool with a kernel and stride of 1 on the same input. + # This will result in a tensor of indices in which each index will have it's own value. + # Using this tensor as a reference, we extract the first index of each axis and subtract + # it from each index of this axis in the indices to convert. + # This step will result in a tensor were each dimension has values of indices within + # the dimension it is in. + # For more information : + # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407 + if return_indices: + r, indices = g.op("MaxPool", input, outputs=2, **kwargs) + _, flattened_indices = g.op( + "MaxPool", + input, + outputs=2, + kernel_shape_i=[1 for _ in range(ndims)], + strides_i=[1 for _ in range(ndims)], + ) + # convert indices to have non-flattened indices values + s = symbolic_helper._slice_helper( + g, + flattened_indices, + axes=[2 + i for i in range(ndims)], + starts=list(tuple_fn(0)), + ends=list(tuple_fn(1)), + ) + indices = sub(g, indices, s) + return r, indices + else: + r = g.op("MaxPool", input, outputs=1, **kwargs) + return r + + return symbolic_fn + + +max_pool1d_with_indices = _onnx_symbolic("aten::max_pool1d_with_indices")( + _max_pool( + "max_pool1d_with_indices", + torch.nn.modules.utils._single, + 1, + return_indices=True, + ) +) +max_pool2d_with_indices = _onnx_symbolic("aten::max_pool2d_with_indices")( + _max_pool( + "max_pool2d_with_indices", + torch.nn.modules.utils._pair, + 2, + return_indices=True, + ) +) +max_pool3d_with_indices = _onnx_symbolic("aten::max_pool3d_with_indices")( + _max_pool( + "max_pool3d_with_indices", + torch.nn.modules.utils._triple, + 3, + return_indices=True, + ) +) + + +@_onnx_symbolic( + "aten::avg_pool1d", + decorate=[ + symbolic_helper._apply_params("avg_pool1d", torch.nn.modules.utils._single), + _export("avg_pool1d"), + ], +) +@_onnx_symbolic( + "aten::avg_pool2d", + decorate=[ + symbolic_helper._apply_params("avg_pool2d", torch.nn.modules.utils._pair), + _export("avg_pool2d"), + ], +) +@_onnx_symbolic( + "aten::avg_pool3d", + decorate=[ + symbolic_helper._apply_params("avg_pool3d", torch.nn.modules.utils._triple), + _export("avg_pool3d"), + ], +) +def _avg_pool(name, tuple_fn): + @symbolic_helper.quantized_args(True) + @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none") + def symbolic_fn( + g, + input: _C.Value, + kernel_size: Sequence[int], + stride: Sequence[int], + padding: int | Sequence[int], + ceil_mode: int, + count_include_pad: int, + divisor_override=None, + ): + if not stride: + stride = kernel_size + padding = symbolic_helper._avgpool_helper( + tuple_fn, padding, kernel_size, stride, divisor_override, name + ) + assert isinstance(padding, tuple) + adjusted_padding = padding + # Although onnx::AvgPool provides count_include_pad, + # The corner case of Average Pooling with ceil_mode on + # PyTorch allows sliding window go off bound, which leads to + # this accommodation. + # More detail on https://github.com/pytorch/pytorch/issues/57178 + if count_include_pad: + input = symbolic_helper._op_with_optional_float_cast( + g, + "Pad", + input, + pads_i=((0,) * 2 + padding) * 2, + mode_s="constant", + value_f=0.0, + opset_before=11, + ) + adjusted_padding = (0,) * len(padding) + if ceil_mode: + padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding) + adjusted_padding = adjusted_padding + tuple( + a + b for (a, b) in zip(padding_ceil, adjusted_padding) + ) + else: + adjusted_padding = adjusted_padding * 2 + output = g.op( + "AveragePool", + input, + kernel_shape_i=tuple_fn(kernel_size), + strides_i=tuple_fn(stride), + pads_i=adjusted_padding, + ) + return output + + return symbolic_fn + + +@_onnx_symbolic( + "aten::adaptive_avg_pool1d", + decorate=[ + symbolic_helper._apply_params( + "adaptive_avg_pool1d", "AveragePool", torch.nn.modules.utils._single + ), + _export("adaptive_avg_pool1d"), + ], +) +@_onnx_symbolic( + "aten::adaptive_avg_pool2d", + decorate=[ + symbolic_helper._apply_params( + "adaptive_avg_pool2d", "AveragePool", torch.nn.modules.utils._pair + ), + _export("adaptive_avg_pool2d"), + ], +) +@_onnx_symbolic( + "aten::adaptive_avg_pool3d", + decorate=[ + symbolic_helper._apply_params( + "adaptive_avg_pool3d", "AveragePool", torch.nn.modules.utils._triple + ), + _export("adaptive_avg_pool3d"), + ], +) +@_onnx_symbolic( + "aten::adaptive_max_pool1d", + decorate=[ + symbolic_helper._apply_params( + "adaptive_max_pool1d", + "MaxPool", + torch.nn.modules.utils._single, + max_pool1d_with_indices, + ), + _export("adaptive_max_pool1d"), + ], +) +@_onnx_symbolic( + "aten::adaptive_max_pool2d", + decorate=[ + symbolic_helper._apply_params( + "adaptive_max_pool2d", + "MaxPool", + torch.nn.modules.utils._pair, + max_pool2d_with_indices, + ), + _export("adaptive_max_pool2d"), + ], +) +@_onnx_symbolic( + "aten::adaptive_max_pool3d", + decorate=[ + symbolic_helper._apply_params( + "adaptive_max_pool3d", + "MaxPool", + torch.nn.modules.utils._triple, + max_pool3d_with_indices, + ), + _export("adaptive_max_pool3d"), + ], +) +def _adaptive_pool(name, type, tuple_fn, fn=None): + @symbolic_helper.quantized_args(True, False) + def symbolic_fn(g, input, output_size): + # _adaptive_pool is supported for cases where output_size is 1 for all dimensions, + # by executing a GlobalPool. + # It is also supported for cases where the output size is a factor of the input size. + # For these cases the stride and kernel size are uniform along all the indices of + # the same dimension, which makes it possible to export it to ONNX. + # for MaxPool, GlobalMaxPool does not return indices, + # so we try using max_poolxd_with_indices, and if it is not possible + # (input is not a complete tensor or output size not factor of input size) + # then we call GlobalAveragePool and return None for the indices + output_size_value = output_size + try: + output_size = symbolic_helper._parse_arg(output_size, "is") + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. + return symbolic_helper._onnx_unsupported( + "adaptive pooling, since output_size is not constant.", input + ) + if output_size == [1] * len(output_size) and type == "AveragePool": + return g.op("GlobalAveragePool", input) + sizes = symbolic_helper._get_tensor_sizes(input) + try: + dim = sizes[2:] + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. + dim = None + if dim is None or any(i is None for i in dim): + if output_size == [1] * len(output_size): + return g.op("GlobalMaxPool", input), None + return symbolic_helper._unimplemented( + name, "input size not accessible", input + ) + # verify if output size % input size = 0 for all dim + mod = [dim[i] % output_size[i] for i in range(0, len(dim))] + if mod != [0] * len(mod): + if output_size == [1] * len(output_size): + return g.op("GlobalMaxPool", input), None + return symbolic_helper._unimplemented( + name, "output size that are not factor of input size", output_size_value + ) + k = [int(dim[i] / output_size[i]) for i in range(0, len(dim))] + # call max_poolxd_with_indices to get indices in the output + if type == "MaxPool": + return fn(g, input, k, k, (0,) * len(dim), (1,) * len(dim), False) + output = g.op(type, input, kernel_shape_i=tuple_fn(k), strides_i=tuple_fn(k)) + return output + + return symbolic_fn + + +def _prepare_onnx_paddings(dim: int, pad): + """Generate paddings in ONNX order based on pad in pytorch. + Args: + dim: the dimension of the tensor. + pad: the paddings in pytorch. + The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ... + """ + # The desired order of paddings is + # dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end. + # n is the dimension of input. + # assume zero-dimensions in the beginning + paddings = list(pad[:]) + [0] * (dim * 2 - len(pad)) + # reverse order and collate first beginnings and then ends + paddings = paddings[-2::-2] + paddings[-1::-2] + return paddings + + +def _convert_padding_node(input): + padding = symbolic_helper._maybe_get_const(input, "is") + if symbolic_helper._is_value(padding) and symbolic_helper._is_packed_list(padding): + input_list = symbolic_helper._unpack_list(padding) + try: + padding = [ + symbolic_helper._get_const(v, "i", "padding") for v in input_list + ] + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. + return symbolic_helper._onnx_opset_unsupported_detailed( + "Pad", 9, 11, "The sizes of the padding must be constant", input + ) + return padding + + +@_onnx_symbolic("aten::constant_pad_nd") +def constant_pad_nd(g: jit_utils.GraphContext, input, padding, value): + mode = "constant" + try: + value = symbolic_helper._get_const(value, "f", "value") + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. + return symbolic_helper._onnx_opset_unsupported_detailed( + "Pad", 9, 11, "The value for the padding must be constant", value + ) + + padding = _convert_padding_node(padding) + paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding) + return symbolic_helper._op_with_optional_float_cast( + g, "Pad", input, pads_i=paddings, mode_s=mode, value_f=value, opset_before=11 + ) + + +def _pad_circular(g: jit_utils.GraphContext, input: _C.Value, pad: _C.Value): + padding = _convert_padding_node(pad) + assert len(padding) % 2 == 0 + ndim = len(padding) // 2 + + cur = input + for idx in range(ndim): + pad_r = padding[-(2 * idx + 1)] + pad_l = padding[-(2 * idx + 2)] + tensors = [] + if pad_l > 0: + left = symbolic_helper._slice_helper( + g, cur, axes=[2 + idx], starts=[-(pad_l)], ends=[_constants.INT64_MAX] + ) + tensors.append(left) + + if pad_l < 0 or pad_r < 0: + start = builtins.max(0, -pad_l) + end = -(builtins.max(0, -pad_r)) + middle = symbolic_helper._slice_helper( + g, + cur, + axes=[2 + idx], + starts=[start], + ends=[end], + ) + tensors.append(middle) + else: + tensors.append(cur) + + if pad_r > 0: + right = symbolic_helper._slice_helper( + g, cur, axes=[2 + idx], starts=[0], ends=[pad_r] + ) + tensors.append(right) + + cur = g.op("Concat", *tensors, axis_i=(2 + idx)) + + return cur + + +@_onnx_symbolic("aten::reflection_pad1d") +@_onnx_symbolic("aten::reflection_pad2d") +@_onnx_symbolic("aten::reflection_pad3d") +def reflection_pad(g: jit_utils.GraphContext, input, padding): + mode = "reflect" + padding = _convert_padding_node(padding) + paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding) + return symbolic_helper._op_with_optional_float_cast( + g, "Pad", input, pads_i=paddings, mode_s=mode, opset_before=11 + ) + + +@_onnx_symbolic("aten::replication_pad1d") +@_onnx_symbolic("aten::replication_pad2d") +@_onnx_symbolic("aten::replication_pad3d") +def replication_pad(g: jit_utils.GraphContext, input, padding): + mode = "edge" + padding = _convert_padding_node(padding) + paddings = _prepare_onnx_paddings(symbolic_helper._get_tensor_rank(input), padding) + return symbolic_helper._op_with_optional_float_cast( + g, "Pad", input, pads_i=paddings, mode_s=mode, opset_before=11 + ) + + +@_onnx_symbolic("aten::pad") +def pad( + g: jit_utils.GraphContext, + input: _C.Value, + pad: _C.Value, + mode: _C.Value, + value: _C.Value, +): + mode = symbolic_helper._parse_arg(mode, "s") + if mode == "replicate": + return replication_pad(g, input, pad) + elif mode == "reflect": + return reflection_pad(g, input, pad) + elif mode == "constant": + return constant_pad_nd(g, input, pad, value) + elif mode == "circular": + return _pad_circular(g, input, pad) + else: + raise errors.SymbolicValueError(f"Unrecognized padding mode {mode}", input) + + +@_onnx_symbolic( + "aten::upsample_nearest1d", + decorate=[ + symbolic_helper._apply_params("upsample_nearest1d", 3, "nearest"), + _export("upsample_nearest1d"), + ], +) +@_onnx_symbolic( + "aten::upsample_nearest2d", + decorate=[ + symbolic_helper._apply_params("upsample_nearest2d", 4, "nearest"), + _export("upsample_nearest2d"), + ], +) +@_onnx_symbolic( + "aten::upsample_nearest3d", + decorate=[ + symbolic_helper._apply_params("upsample_nearest3d", 5, "nearest"), + _export("upsample_nearest3d"), + ], +) +@_onnx_symbolic( + "aten::upsample_linear1d", + decorate=[ + symbolic_helper._apply_params("upsample_linear1d", 3, "linear"), + _export("upsample_linear1d"), + ], +) +@_onnx_symbolic( + "aten::upsample_bilinear2d", + decorate=[ + symbolic_helper._apply_params("upsample_bilinear2d", 4, "linear"), + _export("upsample_bilinear2d"), + ], +) +@_onnx_symbolic( + "aten::upsample_trilinear3d", + decorate=[ + symbolic_helper._apply_params("upsample_trilinear3d", 5, "linear"), + _export("upsample_trilinear3d"), + ], +) +def _interpolate(name: str, dim: int, interpolate_mode: str): + def symbolic_fn(g, input, output_size, *args): + scales, align_corners = symbolic_helper._get_interpolate_attributes( + g, interpolate_mode, args + ) + symbolic_helper._interpolate_warning(interpolate_mode) + align_corners = symbolic_helper._maybe_get_scalar(align_corners) + if align_corners: + return symbolic_helper._unimplemented(name, "align_corners == True", input) + if scales is None: + scales = symbolic_helper._interpolate_size_to_scales( + g, input, output_size, dim + ) + return g.op("Upsample", input, scales, mode_s=interpolate_mode) + + return symbolic_fn + + +@_onnx_symbolic("aten::__interpolate") +def __interpolate( + g: jit_utils.GraphContext, + input, + size, + scale_factor, + mode, + align_corners, + recompute_scale_factor, + antialias, +): + scales, mode = symbolic_helper._interpolate_get_scales_and_mode( + g, input, size, scale_factor, mode, align_corners + ) + return g.op("Upsample", input, scales, mode_s=mode) + + +@_onnx_symbolic("aten::bitwise_not") +def bitwise_not(g: jit_utils.GraphContext, input): + if not symbolic_helper._is_bool(input): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise Not " + "for non-boolean input values", + input, + ) + return g.op("Not", input) + + +@_onnx_symbolic("aten::bitwise_or") +def bitwise_or(g, self, other): + if not symbolic_helper._is_bool(self): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise OR " + "for non-boolean input values. self: ", + self, + ) + if not symbolic_helper._is_bool(other): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise OR " + "for non-boolean input values. other: ", + other, + ) + return g.op("Or", self, other) + + +def wrap_logical_op_with_cast_to(to_type): + def decorator(fn): + @functools.wraps(fn) + def wrap_with_cast(g, input, other): + to_cast_func = globals()[f"_cast_{to_type}"] + return fn(g, to_cast_func(g, input, False), to_cast_func(g, other, False)) + + return wrap_with_cast + + return decorator + + +def wrap_logical_op_with_negation(func: Callable) -> Callable: + @functools.wraps(func) + def wrap_with_not(g, input, other): + return g.op("Not", func(g, input, other)) + + return wrap_with_not + + +@_onnx_symbolic("aten::__not_") +def __not_(g: jit_utils.GraphContext, self): + if not symbolic_helper._is_bool(self): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise Not " + "for non-boolean input values", + self, + ) + return g.op("Not", self) + + +@_onnx_symbolic("aten::eq") +@symbolic_helper.quantized_args(True, True) +def eq(g: jit_utils.GraphContext, self, other): + if isinstance(self.type(), _C.DeviceObjType) and isinstance( + other.type(), _C.DeviceObjType + ): + # ONNX doesn't have devices, so consider them all to be equal. + # The no-op check for equality will get constant-folded. + return g.op("Constant", value_t=torch.tensor(True, dtype=torch.bool)) + self_node = self.node() + other_node = other.node() + if self_node.kind() == other_node.kind() == "onnx::Constant": + if self_node.kindOf("value") == other_node.kindOf("value") == "s": + # Exporting strings to ONNX is not supported. + # If both strings are constant, we can compare them directly. + # The no-op check for equality will get constant-folded. + return g.op( + "Constant", + value_t=torch.tensor( + self_node.s("value") == other_node.s("value"), + dtype=torch.bool, + ), + ) + + return g.op("Equal", self, other) + + +@_onnx_symbolic("aten::ne") +@symbolic_helper.quantized_args(True, True) +@wrap_logical_op_with_negation +def ne(g: jit_utils.GraphContext, self, other): + return eq(g, self, other) + + +@_onnx_symbolic("aten::gt") +@symbolic_helper.quantized_args(True, True) +def gt(g: jit_utils.GraphContext, input, other): + return _gt_impl(g, input, other) + + +def _gt_impl(g: jit_utils.GraphContext, input, other): + if symbolic_helper._is_bool(input) and symbolic_helper._is_bool(other): + input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32) + other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.INT32) + return g.op("Greater", input, other) + + +@_onnx_symbolic("aten::lt") +@symbolic_helper.quantized_args(True, True) +def lt(g: jit_utils.GraphContext, input, other): + return _lt_impl(g, input, other) + + +def _lt_impl(g: jit_utils.GraphContext, input, other): + if symbolic_helper._is_bool(input) and symbolic_helper._is_bool(other): + input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32) + other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.INT32) + return g.op("Less", input, other) + + +@_onnx_symbolic("aten::ge") +@symbolic_helper.quantized_args(True, True) +@wrap_logical_op_with_negation +def ge(g: jit_utils.GraphContext, input, other): + return _lt_impl(g, input, other) + + +@_onnx_symbolic("aten::le") +@symbolic_helper.quantized_args(True, True) +@wrap_logical_op_with_negation +def le(g: jit_utils.GraphContext, input, other): + return _gt_impl(g, input, other) + + +@_onnx_symbolic("aten::__and_") +def __and_(g: jit_utils.GraphContext, input, other): + if not symbolic_helper._is_bool(input): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise AND " + "for non-boolean input values", + input, + ) + if not symbolic_helper._is_bool(other): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise AND " + "for non-boolean input values", + other, + ) + return g.op("And", input, other) + + +@_onnx_symbolic("aten::__or_") +def __or_(g: jit_utils.GraphContext, input, other): + if not symbolic_helper._is_bool(input): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise OR " + "for non-boolean input values", + input, + ) + if not symbolic_helper._is_bool(other): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise OR " + "for non-boolean input values", + other, + ) + return g.op("Or", input, other) + + +@_onnx_symbolic("aten::__xor_") +def __xor_(g: jit_utils.GraphContext, input, other): + if not symbolic_helper._is_bool(input): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise XOR " + "for non-boolean input values", + input, + ) + if not symbolic_helper._is_bool(other): + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting bitwise XOR " + "for non-boolean input values", + other, + ) + return g.op("Xor", input, other) + + +@_onnx_symbolic("aten::logical_and") +@wrap_logical_op_with_cast_to("Bool") +def logical_and(g: jit_utils.GraphContext, input, other): + return g.op("And", input, other) + + +@_onnx_symbolic("aten::logical_or") +@wrap_logical_op_with_cast_to("Bool") +def logical_or(g: jit_utils.GraphContext, input, other): + return g.op("Or", input, other) + + +@_onnx_symbolic("aten::logical_xor") +@wrap_logical_op_with_cast_to("Bool") +def logical_xor(g: jit_utils.GraphContext, input, other): + return g.op("Xor", input, other) + + +@_onnx_symbolic("aten::logical_not") +def logical_not(g: jit_utils.GraphContext, input): + return g.op("Not", g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.BOOL)) + + +@_onnx_symbolic("aten::__rshift_") +def __rshift_(g: jit_utils.GraphContext, self, other): + # make sure to cast other to self's type + # (when self is long, make sure that other is not float) + self_scalar_type = _type_utils.JitScalarType.from_value(self) + if ( + _type_utils.JitScalarType.from_value(other, _type_utils.JitScalarType.UNDEFINED) + != self_scalar_type + ): + other = g.op( + "Cast", + other, + to_i=self_scalar_type.onnx_type(), + ) + + two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32)) + # exponent (same type as self) has to be float or double in onnx::Pow + if not symbolic_helper._is_fp(self): + other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT) + two_pow = g.op("Pow", two, other) + two_pow = g.op( + "Cast", + two_pow, + to_i=self_scalar_type.onnx_type(), + ) + rshift = g.op("Div", self, two_pow) + return rshift + + +@_onnx_symbolic("aten::__lshift_") +def __lshift_(g: jit_utils.GraphContext, self, other): + # make sure to cast other to self's type + # (when self is long, make sure that other is not float) + self_scalar_type = _type_utils.JitScalarType.from_value(self) + if ( + _type_utils.JitScalarType.from_value(other, _type_utils.JitScalarType.UNDEFINED) + != self_scalar_type + ): + other = g.op( + "Cast", + other, + to_i=self_scalar_type.onnx_type(), + ) + + two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32)) + # exponent (same type as self) has to be float or double in onnx::Pow + if not symbolic_helper._is_fp(self): + other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT) + two_pow = g.op("Pow", two, other) + two_pow = g.op( + "Cast", + two_pow, + to_i=self_scalar_type.onnx_type(), + ) + lshift = g.op("Mul", self, two_pow) + return lshift + + +@_onnx_symbolic("aten::where") +@symbolic_helper.parse_args("v", "v", "v", "i") +def where(g: jit_utils.GraphContext, condition, self=None, other=None, _outputs=None): + # Assumes that torch.where's first argument takes only Bool and Byte tensors. + if not symbolic_helper._is_bool(condition): + condition = g.op("Cast", condition, to_i=_C_onnx.TensorProtoDataType.BOOL) + if self is None: + condition = nonzero(g, condition) + return symbolic_helper._unbind_helper( + g, condition, g.op("Constant", value_t=torch.tensor(1)), _outputs + ) + return g.op("Where", condition, self, other) + + +@_onnx_symbolic("aten::log_softmax") +@symbolic_helper.parse_args("v", "i", "none") +def log_softmax(g: jit_utils.GraphContext, input, dim, dtype=None): + # PyTorch dim and ONNX axis have different meanings. + # See Softmax comment for details. + # TODO: remove this as onnx opset 11 spec allows negative axes + input_dim = symbolic_helper._get_tensor_rank(input) + if input_dim is None: + return symbolic_helper._unimplemented( + "dim", + "ONNX and PyTorch use different strategies to split the input. " + "Input rank must be known at export time.", + ) + if dim < 0: + dim = input_dim + dim + is_transpose_required = input_dim != dim + 1 + # ONNX only supports log_softmax with dim = -1. Transpose must be added before and after log_softmax to support other cases. + if is_transpose_required: + axes = list(range(input_dim)) + axes[dim], axes[-1] = axes[-1], axes[dim] + input = g.op("Transpose", input, perm_i=axes) + dim = input_dim - 1 + return_op = g.op("LogSoftmax", input, axis_i=dim) + if dtype and dtype.node().kind() != "prim::Constant": + parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") + return_op = g.op( + "Cast", return_op, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type() + ) + if is_transpose_required: + return_op = g.op("Transpose", return_op, perm_i=axes) # type: ignore[possibly-undefined] + return return_op + + +@_onnx_symbolic("aten::_log_softmax") +@symbolic_helper.parse_args("v", "i", "i") +def _log_softmax(g: jit_utils.GraphContext, input, dim, half_to_float): + if ( + half_to_float + and _type_utils.JitScalarType.from_value( + input, _type_utils.JitScalarType.UNDEFINED + ) + == _type_utils.JitScalarType.HALF + ): + input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT) + return log_softmax(g, input, dim) + + +@_onnx_symbolic("aten::_convolution") +@symbolic_helper.parse_args( + "v", "v", "v", "is", "is", "is", "i", "is", "i", "i", "i", "i", "i" +) +def _convolution( + g: jit_utils.GraphContext, + input, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + benchmark, + deterministic, + cudnn_enabled, + allow_tf32=None, +): + weight_size = symbolic_helper._get_tensor_sizes(weight) + try: + kernel_shape = weight_size[2:] + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. + kernel_shape = None + + if kernel_shape is None or any(i is None for i in kernel_shape): + raise errors.SymbolicValueError( + "Unsupported: ONNX export of convolution for kernel of unknown shape.", + input, + ) + + args = [input, weight] + # ONNX only supports 1D bias + if ( + not symbolic_helper._is_none(bias) + and symbolic_helper._get_tensor_rank(bias) == 1 + ): + args.append(bias) + + kwargs = { + "kernel_shape_i": weight_size[2:], + "strides_i": stride, + # NB: ONNX supports asymmetric padding, whereas PyTorch supports only + # symmetric padding + "pads_i": padding + padding, + "dilations_i": dilation, + "group_i": groups, + } + + if any(o != 0 for o in output_padding): + # ONNX supports both output_shape and output_padding. they are equivalent expressive. + # output_padding is more straightforward, so we use it here. + # output_shape = stride * (input_shape - 1) + output_padding + kernel_shape - padding * 2 + assert transposed + assert len(stride) == len(output_padding) + kwargs["output_padding_i"] = output_padding + + n = g.op("ConvTranspose" if transposed else "Conv", *args, **kwargs) + + if ( + not symbolic_helper._is_none(bias) + and symbolic_helper._get_tensor_rank(bias) != 1 + ): + return g.op("Add", n, bias) + else: + return n + + +@_onnx_symbolic("aten::_convolution_mode") +@symbolic_helper.parse_args( + "v", + "v", + "v", + "is", + "s", + "is", + "i", +) +def _convolution_mode( + g: jit_utils.GraphContext, + input, + weight, + bias, + stride, + padding, + dilation, + groups, +): + weight_size = symbolic_helper._get_tensor_sizes(weight) + try: + kernel_shape = weight_size[2:] + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. + kernel_shape = None + + if kernel_shape is None or any(i is None for i in kernel_shape): + raise errors.SymbolicValueError( + "Unsupported: ONNX export of convolution for kernel of unknown shape.", + input, + ) + + args = [input, weight] + # ONNX only supports 1D bias + if ( + not symbolic_helper._is_none(bias) + and symbolic_helper._get_tensor_rank(bias) == 1 + ): + args.append(bias) + + if padding == "valid": + padding = "VALID" + elif padding == "same": + padding = "SAME_UPPER" + kwargs = { + "kernel_shape_i": weight_size[2:], + "strides_i": stride, + "auto_pad_s": padding, + "dilations_i": dilation, + "group_i": groups, + } + + n = g.op("Conv", *args, **kwargs) + + if ( + not symbolic_helper._is_none(bias) + and symbolic_helper._get_tensor_rank(bias) != 1 + ): + return g.op("Add", n, bias) + else: + return n + + +@_onnx_symbolic("aten::convolution") +@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is", "i") +def convolution( + g: jit_utils.GraphContext, + input, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, +): + return _convolution( + g, + input, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + None, + None, + None, + None, + ) + + +@_onnx_symbolic("aten::conv1d") +@symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i") +def conv1d( + g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups +): + str_padding = symbolic_helper._parse_arg(padding, "s") + if str_padding in ["valid", "same"]: + return _convolution_mode( + g, + input, + weight, + bias, + stride, + str_padding, + dilation, + groups, + ) + else: + padding = symbolic_helper._parse_arg(padding, "is") + return _convolution( + g, + input, + weight, + bias, + stride, + padding, + dilation, + False, + (), + groups, + None, + None, + None, + None, + ) + + +@_onnx_symbolic("aten::conv2d") +@symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i") +def conv2d( + g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups +): + str_padding = symbolic_helper._parse_arg(padding, "s") + if str_padding in ["valid", "same"]: + return _convolution_mode( + g, + input, + weight, + bias, + stride, + str_padding, + dilation, + groups, + ) + else: + padding = symbolic_helper._parse_arg(padding, "is") + return _convolution( + g, + input, + weight, + bias, + stride, + padding, + dilation, + False, + (), + groups, + None, + None, + None, + None, + ) + + +@_onnx_symbolic("aten::conv3d") +@symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i") +def conv3d( + g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups +): + str_padding = symbolic_helper._parse_arg(padding, "s") + if str_padding in ["valid", "same"]: + return _convolution_mode( + g, + input, + weight, + bias, + stride, + str_padding, + dilation, + groups, + ) + else: + padding = symbolic_helper._parse_arg(padding, "is") + return _convolution( + g, + input, + weight, + bias, + stride, + padding, + dilation, + False, + (), + groups, + None, + None, + None, + None, + ) + + +@_onnx_symbolic("aten::conv_transpose1d") +@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is") +def conv_transpose1d( + g: jit_utils.GraphContext, + input, + weight, + bias, + stride, + padding, + output_padding, + groups, + dilation, +): + return _convolution( + g, + input, + weight, + bias, + stride, + padding, + dilation, + True, + output_padding, + groups, + None, + None, + None, + None, + ) + + +@_onnx_symbolic("aten::conv_transpose2d") +@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is") +def conv_transpose2d( + g: jit_utils.GraphContext, + input, + weight, + bias, + stride, + padding, + output_padding, + groups, + dilation, +): + return _convolution( + g, + input, + weight, + bias, + stride, + padding, + dilation, + True, + output_padding, + groups, + None, + None, + None, + None, + ) + + +@_onnx_symbolic("aten::conv_transpose3d") +@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is") +def conv_transpose3d( + g: jit_utils.GraphContext, + input, + weight, + bias, + stride, + padding, + output_padding, + groups, + dilation, +): + return _convolution( + g, + input, + weight, + bias, + stride, + padding, + dilation, + True, + output_padding, + groups, + None, + None, + None, + None, + ) + + +@_onnx_symbolic("aten::batch_norm") +@symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i") +def batch_norm( + g: jit_utils.GraphContext, + input, + weight, + bias, + running_mean, + running_var, + training, + momentum, + eps, + cudnn_enabled, +): + symbolic_helper.check_training_mode(training, "batch_norm") + + if ( + torch.is_autocast_enabled() + and not symbolic_helper.args_have_same_dtype( + [input, weight, bias, running_mean, running_var] + ) + and GLOBALS.export_onnx_opset_version < 15 + ): + return symbolic_helper._onnx_opset_unsupported_detailed( + "BatchNormalization", + 9, + 15, + "All input tensors must have the same `dtype`." + " Turn off Autocast or export using opset version 15.", + input, + ) + + weight, bias, running_mean, running_var = symbolic_helper._batchnorm_helper( + g, input, weight, bias, running_mean, running_var + ) + out = g.op( + "BatchNormalization", + input, + weight, + bias, + running_mean, + running_var, + epsilon_f=eps, + momentum_f=1 - momentum, + outputs=1 if not training else 5, + ) + if not training: + return out + else: + res, new_running_mean, new_running_var, saved_mean, saved_var = out + new_running_mean.setType(running_mean.type()) + new_running_var.setType(running_var.type()) + saved_mean.setDebugName("batch_norm_dead_output-" + saved_mean.debugName()) + saved_var.setDebugName("batch_norm_dead_output-" + saved_var.debugName()) + return res + + +@_onnx_symbolic("aten::native_layer_norm") +@symbolic_helper.quantized_args(True, False, False, False) +@symbolic_helper.parse_args("v", "is", "v", "v", "f") +def native_layer_norm( + g: jit_utils.GraphContext, + input: _C.Value, + normalized_shape: Sequence[int], + weight: _C.Value, + bias: _C.Value, + eps: float, +) -> tuple[_C.Value, _C.Value, _C.Value]: + axes = [-i for i in range(len(normalized_shape), 0, -1)] + + two_cst = symbolic_helper._generate_wrapped_number(g, 2.0) + eps_cst = symbolic_helper._generate_wrapped_number(g, eps) + + if g.opset < 18: + mean = g.op("ReduceMean", input, axes_i=axes) + else: + mean = g.op( + "ReduceMean", + input, + g.op("Constant", value_t=torch.tensor(axes, dtype=torch.long)), + ) + + numerator = sub(g, input, mean) + + # Cast it to eps dtype to avoid precision loss + is_type_half = ( + _type_utils.JitScalarType.from_value(numerator) + == _type_utils.JitScalarType.HALF + ) + if is_type_half: + eps_dtype = _type_utils.JitScalarType.from_value(eps_cst) + numerator = g.op( + "Cast", numerator, to_i=_type_utils.JitScalarType(eps_dtype).onnx_type() + ) + + # variance = e((x - e(x))^2), and (x - e(x)) is the numerator in the layer_norm formula + if g.opset < 18: + variance = g.op("ReduceMean", pow(g, numerator, two_cst), axes_i=axes) + else: + variance = g.op( + "ReduceMean", + pow(g, numerator, two_cst), + g.op("Constant", value_t=torch.tensor(axes, dtype=torch.long)), + ) + + denominator = sqrt(g, g.op("Add", variance, eps_cst)) + normalized = g.op("Div", numerator, denominator) + + # Cast back to input type as eps related ops are all done + if is_type_half: + input_dtype = _type_utils.JitScalarType.from_value(input) + normalized = g.op( + "Cast", normalized, to_i=_type_utils.JitScalarType(input_dtype).onnx_type() + ) + + if not (weight is None or symbolic_helper._is_none(weight)): + normalized = mul(g, normalized, weight) + if not (bias is None or symbolic_helper._is_none(bias)): + normalized = add(g, normalized, bias) + + # rdenominator := 1 / sqrt(variance + eps) + # According to aten::native_layer_norm, rdenominator should have the same dtype as input, + # mean and normalized, so we need to Cast it back + if is_type_half: + denominator = g.op( + "Cast", + denominator, + to_i=_type_utils.JitScalarType(input_dtype).onnx_type(), # type: ignore[possibly-undefined] + ) + rdenominator = g.op("Reciprocal", denominator) + else: + rdenominator = reciprocal(g, denominator) + + return normalized, mean, rdenominator + + +@_onnx_symbolic("aten::layer_norm") +@symbolic_helper.quantized_args(True, False, False, False) +@symbolic_helper.parse_args("v", "is", "v", "v", "f", "b") +def layer_norm( + g: jit_utils.GraphContext, + input: _C.Value, + normalized_shape: Sequence[int], + weight: _C.Value, + bias: _C.Value, + eps: float, + cudnn_enable: bool, +) -> _C.Value: + normalized, _, _ = native_layer_norm(g, input, normalized_shape, weight, bias, eps) + return normalized + + +@_onnx_symbolic("aten::instance_norm") +@symbolic_helper.parse_args("v", "v", "v", "v", "v", "b", "f", "f", "b") +def instance_norm( + g: jit_utils.GraphContext, + input, + weight, + bias, + running_mean, + running_var, + use_input_stats: bool, + momentum: Number, + eps: Number, + cudnn_enabled: bool, +): + symbolic_helper.check_training_mode(use_input_stats, "instance_norm") + channel_size = symbolic_helper._get_tensor_dim_size(input, 1) + if weight is None or symbolic_helper._is_none(weight): + if channel_size is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of instance_norm for unknown channel size.", + input, + ) + weight_value = torch.tensor( + [1.0] * channel_size, + dtype=_type_utils.JitScalarType.from_value(input).dtype(), + ) + weight = g.op("Constant", value_t=weight_value) + if bias is None or symbolic_helper._is_none(bias): + if channel_size is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of instance_norm for unknown channel size.", + input, + ) + bias_value = torch.tensor( + [0.0] * channel_size, + dtype=_type_utils.JitScalarType.from_value(input).dtype(), + ) + bias = g.op("Constant", value_t=bias_value) + if ( + running_mean is None + or symbolic_helper._is_none(running_mean) + or running_var is None + or symbolic_helper._is_none(running_var) + ): + return g.op("InstanceNormalization", input, weight, bias, epsilon_f=eps) + else: + input_size = symbolic_helper._get_tensor_sizes(input) + # If input shape is [N, C, H, W], reshape to [1, N * C, H, W] and call batch_norm. + # For more information instance_norm(): + # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Normalization.cpp#L542 + input_size_reshape = input_size.copy() + n = input_size[0] + if n is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of instance_norm training for unknown " + "batch size.", + input, + ) + c = input_size[1] + input_size_reshape[0] = 1 + input_size_reshape[1] = n * c + weight_ = repeat( + g, weight, g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)) + ) + bias_ = repeat( + g, bias, g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)) + ) + running_mean_ = repeat( + g, + running_mean, + g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)), + ) + running_var_ = repeat( + g, + running_var, + g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64)), + ) + input_reshaped = g.op( + "Reshape", + input, + g.op("Constant", value_t=torch.LongTensor(input_size_reshape)), + ) + out = batch_norm( + g, + input_reshaped, + weight_, + bias_, + running_mean_, + running_var_, + use_input_stats, + momentum, + eps, + cudnn_enabled, + ) + return view(g, out, g.op("Constant", value_t=torch.tensor(input_size))) + + +@_onnx_symbolic("aten::unfold") +@symbolic_helper.parse_args("v", "i", "i", "i") +def unfold(g: jit_utils.GraphContext, input, dimension, size, step): + sizes = symbolic_helper._get_tensor_sizes(input) + # FIXME(justinchuby): Get rid of the try catch here to improve readability + try: + sizedim = sizes[dimension] + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. + sizedim = None + if sizedim is not None: + low_indices = range(0, sizedim, step) + hi_indices = range(size, sizedim + 1, step) + stack = [ + symbolic_helper._slice_helper( + g, input, axes=[dimension], starts=[low], ends=[hi] + ) + for low, hi in zip(low_indices, hi_indices) + ] + ndim = len(sizes) + perm = list(range(0, ndim)) + perm.append(perm.pop(dimension)) + unsqueeze = [ + symbolic_helper._unsqueeze_helper( + g, g.op("Transpose", t, perm_i=perm), [dimension] + ) + for t in stack + ] + return g.op("Concat", *unsqueeze, axis_i=dimension) + else: + return symbolic_helper._unimplemented( + "Unfold", "input size not accessible", input + ) + + +@_onnx_symbolic("aten::elu") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "t", "t", "t") +def elu(g: jit_utils.GraphContext, input, alpha, scale, input_scale): + if scale and scale != 1.0: + return symbolic_helper._unimplemented( + "scale", "does not support scale in Elu", scale + ) + if input_scale and input_scale != 1.0: + return symbolic_helper._unimplemented( + "input_scale", "does not support input_scale in Elu", input_scale + ) + # See Note [Export inplace] + return g.op("Elu", input, alpha_f=symbolic_helper._scalar(alpha)) + + +@_onnx_symbolic("aten::selu") +@symbolic_helper.quantized_args(True) +def selu(g: jit_utils.GraphContext, input): + return g.op("Selu", input) + + +@_onnx_symbolic("aten::index_select") +@symbolic_helper.parse_args("v", "i", "v") +def index_select(g: jit_utils.GraphContext, self, dim, index): + # In case of a scalar index, index_select returns a tensor with the same rank as the input. + # To match this behavior in ONNX, we make index a 1D tensor so that the following gather + # also produces a tensor with the same rank as the input. + return symbolic_helper._select_helper(g, self, dim, index) + + +@_onnx_symbolic("aten::index_put") +def index_put(g: jit_utils.GraphContext, self, indices_list_value, values, accumulate): + if symbolic_helper._is_packed_list(indices_list_value): + indices_list = symbolic_helper._unpack_list(indices_list_value) + else: + indices_list = [indices_list_value] + + accumulate = symbolic_helper._parse_arg(accumulate, "b") + + if len(indices_list) == 0: + if accumulate: + return add(g, self, values) + return values + symbolic_helper._onnx_opset_unsupported("index_put", 9, 11, self) + + +@_onnx_symbolic("aten::index_fill") +def index_fill(g: jit_utils.GraphContext, self, dim, index, value): + expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( + g, self, dim, index + ) + value = symbolic_helper._maybe_get_scalar(value) + value = symbolic_helper._if_scalar_type_as(value, self) + expanded_value = expand(g, value, expanded_index_shape, None) + + return scatter(g, self, dim, expanded_index, expanded_value) + + +@_onnx_symbolic("aten::index_copy") +def index_copy(g: jit_utils.GraphContext, self, dim, index, source): + _expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( + g, self, dim, index + ) + return scatter(g, self, dim, expanded_index, source) + + +@_onnx_symbolic("aten::bucketize") +@symbolic_helper.parse_args("v", "v", "b", "b") +def bucketize( + g: jit_utils.GraphContext, self, boundaries, out_int32=False, right=False +): + out_type = _C_onnx.TensorProtoDataType.INT64 + if out_int32: + out_type = _C_onnx.TensorProtoDataType.INT32 + # A tensor expanded_boundaries is created such that it + # contains a copy of boundaries for each element of self. + new_shape = g.op("Concat", g.op("Shape", boundaries), g.op("Shape", self), axis_i=0) + # Unsqueeze step is performed to respect ONNX's numpy style broadcasting for comparison ops + # https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md + tensor_rank = symbolic_helper._get_tensor_rank(self) + assert tensor_rank is not None + unsqueeze_axes = list(range(1, tensor_rank + 1)) + expanded_boundaries = expand( + g, + symbolic_helper._unsqueeze_helper(g, boundaries, unsqueeze_axes), + new_shape, + None, + ) + # Compare each element of self to boundaries to get a tensor + # with leading 1s and trailing 0s. + # e.g., 4 > [1, 3, 4] = [1, 1, 0] + # The index of the last 1 is the bucket where the element should go. + if right: + cond = ge(g, self, expanded_boundaries) + else: + cond = gt(g, self, expanded_boundaries) + cond_out = g.op("Cast", cond, to_i=out_type) + # Sum to get the number of 1s corresponding to each element, + # which is the same as the bucket index. + # e.g., sum(4 > [1, 3, 4]) = sum([1, 1, 0]) = 2 + return symbolic_helper._reducesum_helper(g, cond_out, axes_i=[0], keepdims_i=0) + + +@_onnx_symbolic("aten::type_as") +def type_as(g: jit_utils.GraphContext, self, other): + self_dtype = symbolic_helper._try_get_scalar_type(self) + other_dtype = symbolic_helper._try_get_scalar_type(other) + if self_dtype == other_dtype and self_dtype is not None: + return self + if other_dtype is not None: + return g.op( + "Cast", + self, + to_i=other_dtype.onnx_type(), + ) + + raise errors.SymbolicValueError( + "Unsupported: ONNX export of type_as for tensor " + "of unknown dtype. Please check if the dtype of the " + "parameter passed to the type_as function is correct.", + other, + ) + + +@_onnx_symbolic("aten::cosine_similarity") +@symbolic_helper.parse_args("v", "v", "i", "f") +def cosine_similarity(g: jit_utils.GraphContext, x1, x2, dim, eps): + cross = symbolic_helper._reducesum_helper( + g, mul(g, x1, x2), axes_i=[dim], keepdims_i=0 + ) + x1_l2 = symbolic_helper._reducesum_helper( + g, mul(g, x1, x1), axes_i=[dim], keepdims_i=0 + ) + x2_l2 = symbolic_helper._reducesum_helper( + g, mul(g, x2, x2), axes_i=[dim], keepdims_i=0 + ) + div_tens = max( + g, sqrt(g, mul(g, x1_l2, x2_l2)), g.op("Constant", value_t=torch.tensor([eps])) + ) + return div(g, cross, div_tens) + + +@_onnx_symbolic("aten::pairwise_distance") +def pairwise_distance(g: jit_utils.GraphContext, input1, input2, p, eps, keepdim): + if not symbolic_helper._is_value(eps): + eps = g.op("Constant", value_t=torch.tensor([eps])) + inv_p = div( + g, + g.op("Constant", value_t=torch.tensor([1], dtype=torch.float)), + add(g, p, eps), + ) + summation = symbolic_helper._reducesum_helper( + g, + pow(g, sub(g, input1, input2), p), + axes_i=[-1], + keepdims_i=symbolic_helper._parse_arg(keepdim, "i"), + ) + return pow(g, summation, inv_p) + + +@_onnx_symbolic("aten::clone") +# ignore clone operators that are inserted by PyTorch autograd +def clone(g: jit_utils.GraphContext, input, unused_memory_format): + return input + + +@_onnx_symbolic("aten::abs") +def abs(g: jit_utils.GraphContext, self): + return g.op("Abs", self) + + +@_onnx_symbolic("aten::log") +def log(g: jit_utils.GraphContext, self): + return g.op("Log", self) + + +@_onnx_symbolic("aten::log1p") +def log1p(g: jit_utils.GraphContext, self): + return log(g, add(g, symbolic_helper._if_scalar_type_as(torch.ones(1), self), self)) + + +@_onnx_symbolic("aten::log10") +def log10(g: jit_utils.GraphContext, self): + _ln10 = 2.30258509299404568401 + return g.op("Div", log(g, self), g.op("Constant", value_t=torch.tensor([_ln10]))) + + +@_onnx_symbolic("aten::pow") +def pow(g: jit_utils.GraphContext, self, exponent): + f_dtype = _type_utils.JitScalarType.from_value(self) + if not symbolic_helper._is_fp(self): + f_dtype = _type_utils.JitScalarType.FLOAT + self = g.op("Cast", self, to_i=f_dtype.onnx_type()) + if not symbolic_helper._is_fp(exponent): + exponent = g.op( + "Cast", + exponent, + to_i=f_dtype.onnx_type(), + ) + pow = g.op("Pow", self, exponent) + return pow + + +@_onnx_symbolic("aten::clamp") +def clamp(g: jit_utils.GraphContext, self, min, max): + # min or max may be None that we need to dispatch to + # Clip separately, as ONNX does not have None syntax + if symbolic_helper._is_none(min): + return clamp_max(g, self, max) + elif symbolic_helper._is_none(max): + return clamp_min(g, self, min) + else: + if symbolic_helper._is_constant(min) and symbolic_helper._is_constant(max): + return symbolic_helper._op_with_optional_float_cast( + g, + "Clip", + self, + min_f=symbolic_helper._parse_arg(min, "f"), + max_f=symbolic_helper._parse_arg(max, "f"), + opset_before=12, + ) + else: + return clamp_max(g, clamp_min(g, self, min), max) + + +@_onnx_symbolic("aten::clamp_min") +@symbolic_helper.parse_args("v", "v") +def clamp_min(g: jit_utils.GraphContext, self, min): + if symbolic_helper._is_constant(min): + return symbolic_helper._op_with_optional_float_cast( + g, "Clip", self, min_f=symbolic_helper._parse_arg(min, "f"), opset_before=12 + ) + else: + dtype = _type_utils.JitScalarType.from_value(self) + min = g.op("Cast", min, to_i=dtype.onnx_type()) + return symbolic_helper._op_with_optional_float_cast( + g, "Max", self, min, opset_before=12 + ) + + +@_onnx_symbolic("aten::clamp_max") +@symbolic_helper.parse_args("v", "v") +def clamp_max(g: jit_utils.GraphContext, self, max): + if symbolic_helper._is_constant(max): + return symbolic_helper._op_with_optional_float_cast( + g, "Clip", self, max_f=symbolic_helper._parse_arg(max, "f"), opset_before=12 + ) + else: + dtype = _type_utils.JitScalarType.from_value(self) + max = g.op("Cast", max, to_i=dtype.onnx_type()) + return symbolic_helper._op_with_optional_float_cast( + g, "Min", self, max, opset_before=12 + ) + + +@_onnx_symbolic("aten::max") +# torch.max (same for torch.min) actually has two interfaces smashed together: +# torch.max(x, dim, keepdim) and torch.max(x, y) +# TODO(justinchuby): Support multiple quantized args in output +def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): + return symbolic_helper._max_helper(g, self, dim_or_y, keepdim) + + +@_onnx_symbolic("aten::maximum") +@symbolic_helper.quantized_args(True, True) +def maximum(g: jit_utils.GraphContext, input, other): + return max(g, input, dim_or_y=other) + + +@_onnx_symbolic("aten::min") +# TODO(justinchuby): Support multiple quantized args in output +def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None): + return symbolic_helper._min_helper(g, self, dim_or_y, keepdim) + + +@_onnx_symbolic("aten::minimum") +@symbolic_helper.quantized_args(True, True) +def minimum(g: jit_utils.GraphContext, input, other): + return min(g, input, dim_or_y=other) + + +@_onnx_symbolic("aten::amax") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "is", "i") +def amax(g: jit_utils.GraphContext, self, dim, keepdim): + return g.op("ReduceMax", self, axes_i=dim, keepdims_i=keepdim) + + +@_onnx_symbolic("aten::amin") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "is", "i") +def amin(g: jit_utils.GraphContext, self, dim, keepdim): + return g.op("ReduceMin", self, axes_i=dim, keepdims_i=keepdim) + + +@_onnx_symbolic("aten::aminmax") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "v", "i") +def aminmax(g: jit_utils.GraphContext, self, dim, keepdim): + reduce_kwargs = {"keepdims_i": keepdim} + if not symbolic_helper._is_none(dim): + dim = symbolic_helper._get_const(dim, "i", "dim") + reduce_kwargs["axes_i"] = [dim] + + return g.op("ReduceMin", self, **reduce_kwargs), g.op( + "ReduceMax", self, **reduce_kwargs + ) + + +@_onnx_symbolic("aten::exp") +def exp(g: jit_utils.GraphContext, self): + return g.op("Exp", self) + + +@_onnx_symbolic("aten::dropout_") +@_onnx_symbolic("aten::dropout") +@symbolic_helper.parse_args("v", "f", "i") +def dropout(g: jit_utils.GraphContext, input, p, train): + symbolic_helper.check_training_mode(train, "dropout") + # if train is False, dropout is no-op + if not train: + return input + r, _ = g.op("Dropout", input, ratio_f=p, outputs=2) + return r + + +@_onnx_symbolic( + "aten::alpha_dropout_", + decorate=[symbolic_helper._apply_params("aten::alpha_dropout_")], +) # See Note [Export inplace] +@_onnx_symbolic( + "aten::feature_alpha_dropout_", + decorate=[symbolic_helper._apply_params("aten::feature_alpha_dropout_")], +) +@_onnx_symbolic( + "aten::feature_dropout_", + decorate=[symbolic_helper._apply_params("aten::feature_dropout_")], +) +@_onnx_symbolic( + "aten::feature_alpha_dropout", + decorate=[symbolic_helper._apply_params("aten::feature_alpha_dropout")], +) +@_onnx_symbolic( + "aten::alpha_dropout", + decorate=[symbolic_helper._apply_params("aten::alpha_dropout")], +) +@_onnx_symbolic( + "aten::feature_dropout", + decorate=[symbolic_helper._apply_params("aten::feature_dropout")], +) +def _unsupported_dropout(name: str): + @symbolic_helper.parse_args("v", "none", "b") + def feature_dropout(g, input, p, train): + # NB: In inference mode, FeatureDropout is exported as an identity op. + if train: + return symbolic_helper._unimplemented(name, "training mode", input) + return input + + return feature_dropout + + +@_onnx_symbolic("aten::norm") +@symbolic_helper.parse_args("v", "t", "is", "i", "v") +def norm(g: jit_utils.GraphContext, self, p, dim, keepdim, dtype=None): + if p == 1: + f = symbolic_helper._reduce_op_symbolic_helper("ReduceL1") + elif p == 2: + f = symbolic_helper._reduce_op_symbolic_helper("ReduceL2") + else: + raise errors.SymbolicValueError( + "ONNX export only p-norms with p of 1 or 2", self + ) + result = f(g, self, dim=dim, keepdim=keepdim) + if dtype is not None: + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + result = g.op("Cast", result, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + return result + + +@_onnx_symbolic("aten::conv_tbc") +@symbolic_helper.parse_args("v", "v", "v", "i") +def conv_tbc(g: jit_utils.GraphContext, input, weight, bias, pad): + # input must have 3 dimensions, see: + # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ConvolutionTBC.cpp#L8-L10 + # input = (time, batch, in_channels) + # weight = (kernel_width, in_channels, out_channels) + # bias = (out_channels,) + input = g.op("Transpose", input, perm_i=[1, 2, 0]) + weight = g.op("Transpose", weight, perm_i=[2, 1, 0]) + conv = conv1d(g, input, weight, bias, [1], [pad], [1], 1) + return g.op("Transpose", conv, perm_i=[2, 0, 1]) + + +@_onnx_symbolic("aten::_unique") +@symbolic_helper.parse_args("v", "i", "i") +def _unique(g: jit_utils.GraphContext, input, sorted, return_inverse): + return symbolic_helper._onnx_unsupported("_unique", input) + + +@_onnx_symbolic("aten::_unique2") +@symbolic_helper.parse_args("v", "i", "i", "i") +def _unique2(g: jit_utils.GraphContext, input, sorted, return_inverse, return_counts): + symbolic_helper._onnx_opset_unsupported("_unique2", 9, 11, input) + + +@_onnx_symbolic("aten::_cast_Byte") +@deprecated("Avoid using this function and create a Cast node instead") +def _cast_Byte(g: jit_utils.GraphContext, input, non_blocking): + return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.UINT8) + + +@_onnx_symbolic("aten::_cast_Char") +@deprecated("Avoid using this function and create a Cast node instead") +def _cast_Char(g: jit_utils.GraphContext, input, non_blocking): + return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT8) + + +@_onnx_symbolic("aten::_cast_Short") +@deprecated("Avoid using this function and create a Cast node instead") +def _cast_Short(g: jit_utils.GraphContext, input, non_blocking): + return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT16) + + +@_onnx_symbolic("aten::_cast_Int") +@deprecated("Avoid using this function and create a Cast node instead") +def _cast_Int(g: jit_utils.GraphContext, input, non_blocking): + return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT32) + + +@_onnx_symbolic("aten::_cast_Long") +@deprecated("Avoid using this function and create a Cast node instead") +def _cast_Long(g: jit_utils.GraphContext, input, non_blocking): + return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT64) + + +@_onnx_symbolic("aten::_cast_Half") +@deprecated("Avoid using this function and create a Cast node instead") +def _cast_Half(g: jit_utils.GraphContext, input, non_blocking): + return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT16) + + +@_onnx_symbolic("aten::_cast_Float") +@deprecated("Avoid using this function and create a Cast node instead") +def _cast_Float(g: jit_utils.GraphContext, input, non_blocking): + return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT) + + +@_onnx_symbolic("aten::_cast_Double") +@deprecated("Avoid using this function and create a Cast node instead") +def _cast_Double(g: jit_utils.GraphContext, input, non_blocking): + return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE) + + +@_onnx_symbolic("aten::_cast_Bool") +@deprecated("Avoid using this function and create a Cast node instead") +def _cast_Bool(g: jit_utils.GraphContext, input, non_blocking): + return g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.BOOL) + + +@_onnx_symbolic("aten::empty") +@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") +def empty( + g: jit_utils.GraphContext, + sizes, + dtype, + layout, + device, + pin_memory=False, + memory_format=None, +): + return zeros(g, sizes, dtype, layout, device, pin_memory) + + +@_onnx_symbolic("aten::empty_like") +@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") +def empty_like( + g: jit_utils.GraphContext, + input, + dtype=None, + layout=None, + device=None, + pin_memory=False, + memory_format=None, +): + return zeros_like(g, input, dtype, layout, device, pin_memory) + + +@_onnx_symbolic("aten::new_empty") +def new_empty( + g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False +): + self_dtype = symbolic_helper._try_get_scalar_type(self) + if symbolic_helper._is_none(dtype) and self_dtype is not None: + dtype = self_dtype + return empty(g, sizes, dtype, layout, device, pin_memory) + + +@_onnx_symbolic("aten::scalar_tensor") +def scalar_tensor(g: jit_utils.GraphContext, scalar, dtype, *options): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + if dtype is None: + dtype = _type_utils.JitScalarType.FLOAT + scalar = g.op("Cast", scalar, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + return scalar + + +@_onnx_symbolic("aten::tensor") +def tensor( + g: jit_utils.GraphContext, data, dtype=None, device=None, requires_grad=False +): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + if symbolic_helper._is_packed_list(data): + if dtype is None: + dtype = _type_utils.JitScalarType.from_value( + symbolic_helper._unpack_list(data)[0] + ) + input_list = [] + for t in symbolic_helper._unpack_list(data): + shape_reference = g.op("Constant", value_t=torch.LongTensor([1])) + t = symbolic_helper._reshape_helper(g, t, shape_reference) + t = g.op("Cast", t, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + input_list.append(t) + return g.op("Concat", *input_list, axis_i=0) + else: + if dtype is None: + dtype = _type_utils.JitScalarType.from_value(data) + if symbolic_helper._is_list(data) and ( + symbolic_helper._is_tensor_list(data) + or symbolic_helper._is_scalar_list(data) + ): + data = g.op("ConcatFromSequence", data, axis_i=0, new_axis_i=1) + return g.op("Cast", data, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + + +@_onnx_symbolic("aten::as_tensor") +def as_tensor(g: jit_utils.GraphContext, data, dtype=None, device=None): + return tensor(g, data, dtype, device) + + +@_onnx_symbolic("aten::zeros") +@symbolic_helper.parse_args("v", "i", "v", "v", "v") +def zeros(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False): + # NOTE: no way to set device, layout and pin_memory in ONNX, so we ignore it + if dtype is None: + scalar_type = _type_utils.JitScalarType.FLOAT + else: + scalar_type = _type_utils.JitScalarType(dtype) + sizes_ = symbolic_helper._maybe_get_const(sizes, "is") + if isinstance(sizes_, list) and len(sizes_) == 0: + sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64)) + return g.op( + "ConstantOfShape", + sizes, + value_t=torch.tensor([0], dtype=scalar_type.dtype()), + ) + + +@_onnx_symbolic("aten::zeros_like") +@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") +def zeros_like( + g: jit_utils.GraphContext, + input, + dtype=None, + layout=None, + device=None, + pin_memory=False, + memory_format=None, +): + shape = g.op("Shape", input) + if symbolic_helper._is_none(dtype): + scalar_type = _type_utils.JitScalarType.from_value( + input, _type_utils.JitScalarType.FLOAT + ) + else: + scalar_type = _type_utils.JitScalarType(dtype) + return g.op( + "ConstantOfShape", + shape, + value_t=torch.tensor([0], dtype=scalar_type.dtype()), + ) + + +@_onnx_symbolic("aten::new_zeros") +def new_zeros( + g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False +): + self_dtype = symbolic_helper._try_get_scalar_type(self) + + if symbolic_helper._is_none(dtype) and self_dtype is not None: + dtype = self_dtype + return zeros(g, sizes, dtype, layout, device, pin_memory) + + +@_onnx_symbolic("aten::zero") +def zero(g: jit_utils.GraphContext, self): + self_dtype = symbolic_helper._try_get_scalar_type(self) + return zeros_like(g, self, self_dtype) + + +@_onnx_symbolic("aten::ones") +@symbolic_helper.parse_args("v", "i", "v", "v", "v") +def ones(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False): + if dtype is None: + scalar_type = _type_utils.JitScalarType.FLOAT + else: + scalar_type = _type_utils.JitScalarType(dtype) + sizes_ = symbolic_helper._maybe_get_const(sizes, "is") + if isinstance(sizes_, list) and len(sizes_) == 0: + sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64)) + return g.op( + "ConstantOfShape", + sizes, + value_t=torch.tensor([1], dtype=scalar_type.dtype()), + ) + + +@_onnx_symbolic("aten::ones_like") +@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v") +def ones_like( + g: jit_utils.GraphContext, + input, + dtype=None, + layout=None, + device=None, + pin_memory=False, + memory_format=None, +): + shape = g.op("Shape", input) + if symbolic_helper._is_none(dtype): + scalar_type = _type_utils.JitScalarType.from_value( + input, _type_utils.JitScalarType.FLOAT + ) + else: + scalar_type = _type_utils.JitScalarType(dtype) + return g.op( + "ConstantOfShape", + shape, + value_t=torch.tensor([1], dtype=scalar_type.dtype()), + ) + + +@_onnx_symbolic("aten::new_ones") +def new_ones( + g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False +): + self_dtype = symbolic_helper._try_get_scalar_type(self) + if symbolic_helper._is_none(dtype) and self_dtype is not None: + dtype = self_dtype + return ones(g, sizes, dtype, layout, device, pin_memory) + + +@_onnx_symbolic("aten::full") +def full( + g: jit_utils.GraphContext, sizes, value, dtype, layout, device, pin_memory=False +): + const_value = symbolic_helper._maybe_get_const(value, "t") + if symbolic_helper._is_value(const_value): + dtype = _type_utils.JitScalarType.FLOAT if dtype is None else dtype + tmp = zeros(g, sizes, dtype, layout, device) + return add(g, tmp, value, g.op("Constant", value_t=torch.tensor(1))) + else: + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + if dtype is None: + scalar_type = _type_utils.JitScalarType.FLOAT + else: + scalar_type = _type_utils.JitScalarType(dtype) + sizes_ = symbolic_helper._maybe_get_const(sizes, "is") + if isinstance(sizes_, list) and len(sizes_) == 0: + sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64)) + return g.op( + "ConstantOfShape", + sizes, + value_t=const_value.view(1).to(scalar_type.dtype()), + ) + + +@_onnx_symbolic("aten::full_like") +def full_like( + g: jit_utils.GraphContext, + input, + fill_value, + dtype=None, + layout=None, + device=None, + pin_memory=False, + memory_format=None, +): + fill_value = symbolic_helper._maybe_get_const(fill_value, "f") + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + if dtype is None: + scalar_type = _type_utils.JitScalarType.from_value( + input, _type_utils.JitScalarType.FLOAT + ) + else: + scalar_type = _type_utils.JitScalarType(dtype) + if symbolic_helper._is_value(fill_value): + tmp = zeros_like(g, input, dtype, layout, device) + fill_value = g.op("Cast", fill_value, to_i=scalar_type.onnx_type()) + return add(g, tmp, fill_value, g.op("Constant", value_t=torch.tensor(1))) + else: + shape = g.op("Shape", input) + return g.op( + "ConstantOfShape", + shape, + value_t=torch.tensor([fill_value], dtype=scalar_type.dtype()), + ) + + +@_onnx_symbolic("aten::new_full") +def new_full( + g: jit_utils.GraphContext, + self, + size, + fill_value, + dtype, + layout, + device, + pin_memory=False, +): + self_dtype = symbolic_helper._try_get_scalar_type(self) + if symbolic_helper._is_none(dtype) and self_dtype is not None: + dtype = self_dtype + return full(g, size, fill_value, dtype, layout, device, pin_memory) + + +@_onnx_symbolic("aten::eye") +def eye(g: jit_utils.GraphContext, *args): + if len(args) == 5: + # aten::eye(n, dtype, layout, device, pin_memory) + n, dtype, layout, device, _pin_memory = args + dim_size = symbolic_helper._unsqueeze_helper(g, n, [0]) + shape = g.op("Concat", dim_size, dim_size, axis_i=0) + tensor = zeros(g, shape, dtype, layout, device) + return g.op("EyeLike", tensor) + if len(args) == 6: + # aten::eye(n, m, dtype, layout, device, pin_memory) + n, m, dtype, layout, device, _pin_memory = args + shape = g.op( + "Concat", + symbolic_helper._unsqueeze_helper(g, n, [0]), + symbolic_helper._unsqueeze_helper(g, m, [0]), + axis_i=0, + ) + tensor = zeros(g, shape, dtype, layout, device) + return g.op("EyeLike", tensor) + + return symbolic_helper._unimplemented("aten::eye", f"with {len(args)} arguments") + + +@_onnx_symbolic("aten::slice") +def slice(g: jit_utils.GraphContext, self, *args): + if len(args) == 4: + # aten::slice(Tensor self, int dim, int start, int end, int step) -> Tensor + dim, start, end, step = args + step = symbolic_helper._parse_arg(step, "i") + if step != 1: + raise errors.SymbolicValueError("step!=1 is currently not supported", self) + is_start_none = start.node().kind() == "prim::Constant" and isinstance( + start.type(), _C.NoneType + ) + is_end_none = end.node().kind() == "prim::Constant" and isinstance( + end.type(), _C.NoneType + ) + is_start_onnx_const = start.node().kind() == "onnx::Constant" + is_end_onnx_const = end.node().kind() == "onnx::Constant" + if ( + ((not is_start_none) and (not is_start_onnx_const)) + or ((not is_end_none) and (not is_end_onnx_const)) + or dim.node().kind() != "onnx::Constant" + ): + if GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of Slice with dynamic inputs. DynamicSlice " + "is a deprecated experimental op. Please use statically allocated " + "variables or export to a higher opset version.", + self, + ) + else: + start_unsqueezed = symbolic_helper._unsqueeze_helper(g, start, [0]) + end_unsqueezed = symbolic_helper._unsqueeze_helper(g, end, [0]) + dim_unsqueezed = symbolic_helper._unsqueeze_helper(g, dim, [0]) + return g.op( + "DynamicSlice", + self, + start_unsqueezed, + end_unsqueezed, + dim_unsqueezed, + ) + else: + start = 0 if is_start_none else symbolic_helper._parse_arg(start, "i") + end = ( + _constants.INT64_MAX + if is_end_none + else symbolic_helper._parse_arg(end, "i") + ) + dim = symbolic_helper._parse_arg(dim, "i") + return symbolic_helper._slice_helper( + g, self, axes=[dim], starts=[start], ends=[end] + ) + elif len(args) == 3: + # aten::slice(t[] l, int start, int end, int step) -> t[] + start, end, step = args + dim = 0 + is_start_none = start.node().kind() == "prim::Constant" and isinstance( + start.type(), _C.NoneType + ) + is_end_none = end.node().kind() == "prim::Constant" and isinstance( + end.type(), _C.NoneType + ) + start = 0 if is_start_none else symbolic_helper._parse_arg(start, "i") + end = ( + _constants.INT64_MAX + if is_end_none + else symbolic_helper._parse_arg(end, "i") + ) + return symbolic_helper._slice_helper( + g, self, axes=[dim], starts=[start], ends=[end] + ) + + return symbolic_helper._unimplemented("aten::slice", f"with {len(args)} arguments") + + +@_onnx_symbolic("aten::hardtanh") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "f", "f") +def hardtanh(g: jit_utils.GraphContext, self: _C.Value, min_val: float, max_val: float): + return symbolic_helper._op_with_optional_float_cast( + g, "Clip", self, min_f=min_val, max_f=max_val, opset_before=12 + ) + + +@_onnx_symbolic("aten::hardswish") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v") +def hardswish(g: jit_utils.GraphContext, self): + hs = hardsigmoid(g, self) + return g.op("Mul", self, hs) + + +@_onnx_symbolic("aten::hardsigmoid") +# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qhardsigmoid.cpp +@symbolic_helper.quantized_args(True, scale=1.0 / 256.0, zero_point=0) +@symbolic_helper.parse_args("v") +def hardsigmoid(g: jit_utils.GraphContext, self): + # Set alpha_f to 1 / 6 to make op equivalent to PyTorch's definition of Hardsigmoid. + # See https://pytorch.org/docs/stable/generated/torch.nn.Hardsigmoid.html + return g.op("HardSigmoid", self, alpha_f=1 / 6) + + +@_onnx_symbolic("aten::tanhshrink") +@symbolic_helper.parse_args("v") +def tanhshrink(g: jit_utils.GraphContext, self): + return g.op("Sub", self, tanh(g, self)) + + +@_onnx_symbolic("aten::hardshrink") +@symbolic_helper.parse_args("v", "f") +def hardshrink(g: jit_utils.GraphContext, self, lambd): + scalar_type = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.FLOAT + ) + lambd_op = g.op( + "Constant", + value_t=torch.tensor(lambd, dtype=scalar_type.dtype()), + ) + cond = logical_or(g, gt(g, self, lambd_op), lt(g, self, neg(g, lambd_op))) + return g.op( + "Where", + cond, + self, + g.op( + "Constant", + value_t=torch.tensor(0, dtype=scalar_type.dtype()), + ), + ) + + +@_onnx_symbolic("aten::softshrink") +@symbolic_helper.parse_args("v", "f") +def softshrink(g: jit_utils.GraphContext, self, lambd): + scalar_type = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.FLOAT + ) + lambd_op = g.op( + "Constant", + value_t=torch.tensor(lambd, dtype=scalar_type.dtype()), + ) + gt_cond = gt(g, self, lambd_op) + gt_out = g.op( + "Where", + gt_cond, + sub(g, self, lambd_op), + g.op( + "Constant", + value_t=torch.tensor(0, dtype=scalar_type.dtype()), + ), + ) + lt_cond = lt(g, self, neg(g, lambd_op)) + lt_out = g.op( + "Where", + lt_cond, + add(g, self, lambd_op), + g.op( + "Constant", + value_t=torch.tensor(0, dtype=scalar_type.dtype()), + ), + ) + return add(g, gt_out, lt_out) + + +@_onnx_symbolic("aten::alias") +def alias(g: jit_utils.GraphContext, self): + return self + + +@_onnx_symbolic("aten::unsqueeze") +@symbolic_helper.parse_args("v", "i") +def unsqueeze(g: jit_utils.GraphContext, self, dim): + """Implement unsqueezing a pytorch tensor in ONNX by inserting a new dimension at the specified `dim`""" + # Handle negative dim + if dim < 0: + rank = symbolic_helper._get_tensor_rank(self) + if rank is not None: + warnings.warn( + "ONNX export unsqueeze with negative axis " + + str(dim) + + " might cause the onnx model to be incorrect. " + + "Negative axis is not supported in ONNX. " + + "Axis is converted to " + + str(dim + rank + 1) + + " based on input shape at export time. " + + "Passing an tensor of different rank in execution will be incorrect." + ) + dim = dim + rank + 1 + else: + return symbolic_helper._unimplemented( + "unsqueeze", "negative axis with unknown input rank", self + ) + + return symbolic_helper._unsqueeze_helper(g, self, axes_i=[dim]) + + +@_onnx_symbolic("aten::sort") +# TODO(justinchuby): Support multiple quantized args in output +@symbolic_helper.parse_args("v", "i", "i", "none") +def sort(g: jit_utils.GraphContext, self, dim, decending, out=None): + if out is not None: + symbolic_helper._unimplemented( + "Sort", "Out parameter is not supported for sort", self + ) + self_sizes = symbolic_helper._get_tensor_sizes(self) + try: + dim_size = self_sizes[dim] + except Exception: + # FIXME(justinchuby): Avoid catching Exception. + # Catch a more specific exception instead. + dim_size = None + + if dim_size is None: + return symbolic_helper._unimplemented("Sort", "input size not accessible", self) + + return g.op("TopK", self, k_i=dim_size, axis_i=dim, outputs=2) + + +@_onnx_symbolic("aten::numel") +def numel(g: jit_utils.GraphContext, self): + return symbolic_helper._numel_helper(g, self) + + +@_onnx_symbolic("aten::topk") +# TODO(justinchuby): Support multiple quantized args in output +@symbolic_helper.parse_args("v", "i", "i", "i", "i", "none") +def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None): + if out is not None: + symbolic_helper._unimplemented( + "TopK", "Out parameter is not supported for topk", self + ) + if not largest: + symbolic_helper._unimplemented("TopK", "Ascending TopK is not supported", self) + + return g.op("TopK", self, k_i=k, axis_i=dim, outputs=2) + + +@_onnx_symbolic("prim::convert_element_type") +def convert_element_type(g: jit_utils.GraphContext, self, *args): + dtype = symbolic_helper._get_const(args[0], "i", "dtype") + return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + + +@_onnx_symbolic("aten::to") +def to(g: jit_utils.GraphContext, self, *args): + def is_aten_to_device_only(args): + if len(args) == 4: + # aten::to(Tensor, Device, bool, bool, memory_format) + return ( + args[0].node().kind() == "prim::device" + or args[0].type().isSubtypeOf(_C.ListType.ofInts()) + or isinstance(args[0].type(), _C.DeviceObjType) + ) + elif len(args) == 5: + # aten::to(Tensor, Device, ScalarType, bool, bool, memory_format) + # When dtype is None, this is a aten::to(device) call + dtype = symbolic_helper._get_const(args[1], "i", "dtype") + return dtype is None + elif len(args) in (6, 7): + # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format) -> Tensor + # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format) -> Tensor + # When dtype is None, this is a aten::to(device) call + dtype = symbolic_helper._get_const(args[0], "i", "dtype") + return dtype is None + return False + + # ONNX doesn't have a concept of a device, so we ignore device-only casts + if is_aten_to_device_only(args): + return self + + if len(args) == 4: + # TestONNXRuntime::test_ones_bool shows args[0] of aten::to() can be onnx::Constant[value=]() + # In this case, the constant value is a tensor not int, + # so symbolic_helper._maybe_get_const(args[0], 'i') would not work. + dtype = args[0] + if ( + symbolic_helper._is_value(args[0]) + and args[0].node().kind() == "onnx::Constant" + ): + tval = symbolic_helper._node_get(args[0].node(), "value") + if isinstance(tval, torch.Tensor): + if len(tval.shape) == 0: + tval = tval.item() + dtype = int(tval) + else: + dtype = tval + + if symbolic_helper._is_value(dtype) or isinstance(dtype, torch.Tensor): + # aten::to(Tensor, Tensor, bool, bool, memory_format) + dtype = _type_utils.JitScalarType.from_value(args[0]) + return g.op( + "Cast", + self, + to_i=dtype.onnx_type(), + ) + else: + # aten::to(Tensor, ScalarType, bool, bool, memory_format) + # memory_format is ignored + return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + elif len(args) == 5: + # aten::to(Tensor, Device, ScalarType, bool, bool, memory_format) + dtype = symbolic_helper._get_const(args[1], "i", "dtype") + # memory_format is ignored + return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + elif len(args) == 6: + # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format) -> Tensor + dtype = symbolic_helper._get_const(args[0], "i", "dtype") + # Layout, device and memory_format are ignored + return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + elif len(args) == 7: + # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format) -> Tensor + dtype = symbolic_helper._get_const(args[0], "i", "dtype") + # Layout, device and memory_format are ignored + return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()) + + return symbolic_helper._onnx_unsupported("Unknown aten::to signature", self) + + +@_onnx_symbolic("aten::repeat") +def repeat(g: jit_utils.GraphContext, self, repeats): + dtype = _type_utils.JitScalarType.INT64 + shape_ = ones_like(g, repeats, dtype) + self = g.op("Expand", self, shape_) + return g.op("Tile", self, repeats) + + +@_onnx_symbolic("aten::repeat_interleave") +def repeat_interleave( + g: jit_utils.GraphContext, self, repeats, dim=None, output_size=None +): + repeats_dim = symbolic_helper._get_tensor_rank(repeats) + repeats_sizes = symbolic_helper._get_tensor_sizes(repeats) + input_sizes = symbolic_helper._get_tensor_sizes(self) + if repeats_dim is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of repeat_interleave for unknown repeats rank.", + self, + ) + if repeats_sizes is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of repeat_interleave for unknown repeats size.", + self, + ) + if input_sizes is None: + raise errors.SymbolicValueError( + "Unsupported: ONNX export of repeat_interleave for unknown input size.", + self, + ) + + # if dim is None flatten + # By default, use the flattened input array, and return a flat output array + if symbolic_helper._is_none(dim): + self = symbolic_helper._reshape_helper( + g, self, g.op("Constant", value_t=torch.tensor([-1])) + ) + dim = torch.tensor(0, dtype=torch.int64) + else: + dim = symbolic_helper._maybe_get_scalar(dim) + + # Handle cases where dim is negative + if dim < 0: + dim += len(input_sizes) + + input_sizes_temp = input_sizes.copy() + for idx, input_size in enumerate(input_sizes): + if input_size is None: + input_sizes[idx], input_sizes_temp[idx] = 0, -1 + + # Cases where repeats is an int or single value tensor + if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1): + if input_sizes[dim] == 0: + return symbolic_helper._onnx_opset_unsupported_detailed( + "repeat_interleave", + 9, + 13, + "Unsupported along dimension with unknown input size", + self, + ) + return symbolic_helper._repeat_interleave_single_value_repeat_helper( + g, self, repeats, dim + ) + + # Cases where repeats is a 1 dim Tensor + elif repeats_dim == 1: + if input_sizes[dim] == 0: + return symbolic_helper._onnx_opset_unsupported_detailed( + "repeat_interleave", + 9, + 13, + "Unsupported along dimension with unknown input size", + self, + ) + if repeats_sizes[0] is None: + return symbolic_helper._onnx_opset_unsupported_detailed( + "repeat_interleave", + 9, + 13, + "Unsupported for cases with dynamic repeats", + self, + ) + assert repeats_sizes[0] == input_sizes[dim], ( + "repeats must have the same size as input along dim" + ) + reps = repeats_sizes[0] + else: + raise errors.SymbolicValueError("repeats must be 0-dim or 1-dim tensor", self) + + final_splits = [] + r_splits = symbolic_helper._repeat_interleave_split_helper(g, repeats, reps, 0) + i_splits = symbolic_helper._repeat_interleave_split_helper(g, self, reps, dim) + input_sizes[dim], input_sizes_temp[dim] = -1, 1 + for idx, r_split in enumerate(r_splits): + i_split = unsqueeze(g, i_splits[idx], dim + 1) + r_concat = [ + g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[: dim + 1])), + r_split, + g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[dim + 1 :])), + ] + r_concat = g.op("Concat", *r_concat, axis_i=0) + i_split = expand(g, i_split, r_concat, None) + i_split = symbolic_helper._reshape_helper( + g, + i_split, + g.op("Constant", value_t=torch.LongTensor(input_sizes)), + allowzero=0, + ) + final_splits.append(i_split) + return g.op("Concat", *final_splits, axis_i=dim) + + +@_onnx_symbolic("aten::pixel_shuffle") +@symbolic_helper.parse_args("v", "i") +def pixel_shuffle(g: jit_utils.GraphContext, self, upscale_factor): + dims = symbolic_helper._get_tensor_sizes(self) + if len(dims) != 4: + return symbolic_helper._unimplemented( + "pixel_shuffle", "only support 4d input", self + ) + if any(i is None for i in dims[1:]): + after_view = symbolic_helper._reshape_helper( + g, + symbolic_helper._unsqueeze_helper(g, self, [2, 3]), + g.op( + "Constant", + value_t=torch.tensor([0, -1, upscale_factor, upscale_factor, 0, 0]), + ), + allowzero=0, + ) + after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3]) + # For dynamic input shapes, two reshapes are performed + reshape_h = symbolic_helper._reshape_helper( + g, + after_transpose, + g.op("Constant", value_t=torch.tensor([0, 0, -1, 1, 0, 0])), + allowzero=0, + ) + reshape_w = symbolic_helper._reshape_helper( + g, + reshape_h, + g.op("Constant", value_t=torch.tensor([0, 0, 0, 0, -1, 1])), + allowzero=0, + ) + return symbolic_helper._squeeze_helper(g, reshape_w, [3, 5]) + else: + output_channel = dims[1] // upscale_factor // upscale_factor + after_view = symbolic_helper._reshape_helper( + g, + self, + g.op( + "Constant", + value_t=torch.tensor( + [ + -1, + output_channel, + upscale_factor, + upscale_factor, + dims[2], + dims[3], + ] + ), + ), + allowzero=0, + ) + after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3]) + return symbolic_helper._reshape_helper( + g, + after_transpose, + g.op( + "Constant", + value_t=torch.tensor( + [ + -1, + output_channel, + dims[2] * upscale_factor, + dims[3] * upscale_factor, + ] + ), + ), + allowzero=0, + ) + + +@_onnx_symbolic("aten::pixel_unshuffle") +@symbolic_helper.parse_args("v", "i") +def pixel_unshuffle(g: jit_utils.GraphContext, self, downscale_factor): + dims = symbolic_helper._get_tensor_sizes(self) + if len(dims) != 4: + return symbolic_helper._unimplemented( + "pixel_shuffle", "only support 4d input", self + ) + if any(i is None for i in dims[1:]): + # For dynamic input shapes, two reshapes are performed + reshape_h = symbolic_helper._reshape_helper( + g, + symbolic_helper._unsqueeze_helper(g, self, [3]), + g.op("Constant", value_t=torch.tensor([0, 0, -1, downscale_factor, 0])), + allowzero=0, + ) + reshape_w = symbolic_helper._reshape_helper( + g, + reshape_h, + g.op("Constant", value_t=torch.tensor([0, 0, 0, 0, -1, downscale_factor])), + allowzero=0, + ) + after_transpose = g.op("Transpose", reshape_w, perm_i=[0, 1, 3, 5, 2, 4]) + final_reshape = symbolic_helper._reshape_helper( + g, + after_transpose, + g.op("Constant", value_t=torch.tensor([0, -1, 1, 1, 0, 0])), + allowzero=0, + ) + return symbolic_helper._squeeze_helper(g, final_reshape, [2, 3]) + else: + output_channel = dims[1] * downscale_factor * downscale_factor + after_view = symbolic_helper._reshape_helper( + g, + self, + g.op( + "Constant", + value_t=torch.tensor( + [ + -1, + dims[1], + dims[2] // downscale_factor, + downscale_factor, + dims[3] // downscale_factor, + downscale_factor, + ] + ), + ), + allowzero=0, + ) + after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 3, 5, 2, 4]) + return symbolic_helper._reshape_helper( + g, + after_transpose, + g.op( + "Constant", + value_t=torch.tensor( + [ + -1, + output_channel, + dims[2] // downscale_factor, + dims[3] // downscale_factor, + ] + ), + ), + allowzero=0, + ) + + +def _generic_rnn( + g: jit_utils.GraphContext, + variant, + input, + initial_states, + all_weights, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first=None, + batch_sizes=None, +): + warnings.warn( + "Exporting a model to ONNX with a batch_size other than 1, " + + "with a variable length with " + + variant + + " can cause an error " + + "when running the ONNX model with a different batch size. " + + "Make sure to save the model with a batch size of 1, " + + "or define the initial states (h0/c0) as inputs of the model. " + ) + + onnxActivations = [ + "Relu", + "Tanh", + "Sigmoid", + "Affine", + "LeakyRelu", + "ThresholdedRelu", + "ScaledTanh", + "HardSigmoid", + "Elu", + "Softsign", + "Softplus", + ] + variantToOnnxActivationMap = dict( + zip([act_fun.lower() for act_fun in onnxActivations], onnxActivations) + ) + weights_per_layer = 4 if has_biases else 2 + # this means that projections are used inside LSTM, so need to tell user that it's not supported + if variant == "LSTM" and len(all_weights) != num_layers * weights_per_layer * ( + 1 + bidirectional + ): + return symbolic_helper._unimplemented("LSTM", "LSTMs with projections", input) + assert len(all_weights) == num_layers * weights_per_layer * (1 + bidirectional) + layer_weights = [ + all_weights[i : i + weights_per_layer] + for i in range(0, len(all_weights), weights_per_layer) + ] + if batch_first: + # batch, seq, feat -> seq, batch, feat + input = g.op("Transpose", input, perm_i=[1, 0, 2]) + if dropout and train: + return symbolic_helper._unimplemented( + "RNN/GRU/LSTM", "dropout in training mode", input + ) + + if variant.startswith("RNN"): + nonlinearity = variantToOnnxActivationMap[variant[4:].lower()] + variant = "RNN" + + w_hh = all_weights[1] + hidden_size = symbolic_helper._get_tensor_dim_size(w_hh, 1) + if hidden_size is None: + return symbolic_helper._unimplemented( + "RNN/GRU/LSTM", "unknown hidden size", input + ) + + unidirectional = not bidirectional + + prev_output = input + + h_outs = [] + if variant == "RNN" or variant == "GRU": + h0 = initial_states + elif variant == "LSTM": + h0, c0 = initial_states + c_outs = [] + + sequence_lens = unused(g) if batch_sizes is None else batch_sizes + + if variant == "GRU": + # pytorch is reset, input, hidden + # onnx is input, reset, hidden + reform_permutation = [(1, 2), (0, 1), (2, 3)] + elif variant == "LSTM": + # pytorch is input, forget, cell, output. + # onnx is input, output, forget, cell. + reform_permutation = [(0, 1), (3, 4), (1, 3)] + + def reform_weights(g, w, n, intervals): + slices = [ + symbolic_helper._slice_helper(g, w, axes=[0], starts=[x * n], ends=[y * n]) + for x, y in intervals + ] + return g.op("Concat", *slices, axis_i=0) + + def transform_weights_no_bias(layer_index): + weights = layer_weights[layer_index] + if variant == "RNN": + weight_ih, weight_hh = weights + elif variant == "GRU" or variant == "LSTM": + weight_ih, weight_hh = ( + reform_weights(g, w, hidden_size, reform_permutation) for w in weights + ) + return tuple( + symbolic_helper._unsqueeze_helper(g, x, [0]) + for x in (weight_ih, weight_hh) # type: ignore[possibly-undefined] + ) + + def transform_weights(layer_index): + weights = layer_weights[layer_index] + if variant == "RNN": + weight_ih, weight_hh, bias_ih, bias_hh = weights + elif variant == "GRU" or variant == "LSTM": + weight_ih, weight_hh, bias_ih, bias_hh = ( + reform_weights(g, w, hidden_size, reform_permutation) for w in weights + ) + bias_concat = g.op("Concat", bias_ih, bias_hh, axis_i=0) # type: ignore[possibly-undefined] + return tuple( + symbolic_helper._unsqueeze_helper(g, x, [0]) + for x in (weight_ih, weight_hh, bias_concat) # type: ignore[possibly-undefined] + ) + + def retrieve_state(x, start, end): + return ( + x + if num_layers == 1 + else symbolic_helper._slice_helper( + g, x, axes=[0], starts=[start], ends=[end] + ) + ) + + for i in range(num_layers): + if unidirectional: + if weights_per_layer == 4: + weight_ih, weight_hh, bias_concat = transform_weights(i) + else: + weight_ih, weight_hh = transform_weights_no_bias(i) + bias_concat = unused(g) + + state_indices = i, i + 1 + else: + if weights_per_layer == 4: + weight_ih_f, weight_hh_f, bias_f = transform_weights(2 * i) + weight_ih_b, weight_hh_b, bias_b = transform_weights(2 * i + 1) + bias_concat = g.op("Concat", bias_f, bias_b, axis_i=0) + else: + weight_ih_f, weight_hh_f = transform_weights_no_bias(2 * i) + weight_ih_b, weight_hh_b = transform_weights_no_bias(2 * i + 1) + bias_concat = unused(g) + + weight_ih = g.op("Concat", weight_ih_f, weight_ih_b, axis_i=0) + weight_hh = g.op("Concat", weight_hh_f, weight_hh_b, axis_i=0) + + state_indices = 2 * i, 2 * i + 2 + + inputs = [prev_output, weight_ih, weight_hh, bias_concat, sequence_lens] + + inputs.append(retrieve_state(h0, *state_indices)) # type: ignore[possibly-undefined] + if variant == "LSTM": + inputs.append(retrieve_state(c0, *state_indices)) # type: ignore[possibly-undefined] + + extra_kwargs = {} if unidirectional else {"direction_s": "bidirectional"} + if variant == "RNN": + if bidirectional: + activation = [nonlinearity, nonlinearity] # type: ignore[possibly-undefined] + else: + activation = [nonlinearity] # type: ignore[possibly-undefined] + + prev_output, h_out = g.op( + "RNN", + *inputs, + outputs=2, + hidden_size_i=hidden_size, + activations_s=activation, + **extra_kwargs, + ) + elif variant == "GRU": + prev_output, h_out = g.op( + "GRU", + *inputs, + outputs=2, + hidden_size_i=hidden_size, + linear_before_reset_i=1, + **extra_kwargs, + ) + elif variant == "LSTM": + prev_output, h_out, c_out = g.op( + "LSTM", *inputs, outputs=3, hidden_size_i=hidden_size, **extra_kwargs + ) + + if bidirectional: + # The ONNX RNN/GRU/LSTM produce an output of dimensions + # seq_len, num_directions, batch, hidden_size + # We have to convert to match pytorch's expected + # seq_len, batch, num_directions * hidden_size + # by first moving num_directions before hidden_size with + # Transpose, and then combining it with hidden_size + # with Reshape. + prev_output = g.op("Transpose", prev_output, perm_i=[0, 2, 1, 3]) + prev_output = symbolic_helper._reshape_helper( + g, + prev_output, + g.op("Constant", value_t=torch.LongTensor([0, 0, -1])), + allowzero=0, + ) + else: + prev_output = symbolic_helper._squeeze_helper(g, prev_output, [1]) + + h_outs.append(h_out) # type: ignore[possibly-undefined] + if variant == "LSTM": + c_outs.append(c_out) # type: ignore[possibly-undefined] + if batch_first: + # seq, batch, num_directions * hidden_size -> batch, seq, num_directions * hidden_size + prev_output = g.op("Transpose", prev_output, perm_i=[1, 0, 2]) + h_outs = h_out if num_layers == 1 else g.op("Concat", *h_outs, axis_i=0) # type: ignore[possibly-undefined] + if variant == "RNN" or variant == "GRU": + return prev_output, h_outs + elif variant == "LSTM": + c_outs = c_out if num_layers == 1 else g.op("Concat", *c_outs, axis_i=0) # type: ignore[possibly-undefined] + return prev_output, h_outs, c_outs + + +@symbolic_helper.parse_args("v", "v", "v", "i", "i", "f", "i", "i", "i") +def _lstm_full( + g: jit_utils.GraphContext, + input, + hidden_v, + weight_v, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, +): + hidden, weight = ( + symbolic_helper._unpack_list(hidden_v), + symbolic_helper._unpack_list(weight_v), + ) + return _generic_rnn( + g, + "LSTM", + input, + hidden, + weight, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, + ) + + +@symbolic_helper.parse_args("v", "v", "v", "v", "i", "i", "f", "i", "i") +def _lstm_packed( + g: jit_utils.GraphContext, + input, + batch_sizes, + hidden_v, + weight_v, + has_biases, + num_layers, + dropout, + train, + bidirectional, +): + hidden, weight = ( + symbolic_helper._unpack_list(hidden_v), + symbolic_helper._unpack_list(weight_v), + ) + return _generic_rnn( + g, + "LSTM", + input, + hidden, + weight, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_sizes=batch_sizes, + ) + + +@_onnx_symbolic("aten::lstm") +def lstm(g: jit_utils.GraphContext, *args): + if symbolic_helper._is_tensor_list(args[3]): + return _lstm_packed(g, *args) + else: + return _lstm_full(g, *args) + + +@_onnx_symbolic("aten::lstm_cell") +def lstm_cell(g: jit_utils.GraphContext, self, hidden, w_ih, w_hh, b_ih, b_hh): + input = symbolic_helper._unsqueeze_helper(g, self, [0]) + hidden = symbolic_helper._unpack_list(hidden) + hidden = [symbolic_helper._unsqueeze_helper(g, x, [0]) for x in hidden] + weight = ( + (w_ih, w_hh, b_ih, b_hh) if symbolic_helper._is_tensor(b_ih) else (w_ih, w_hh) + ) + has_biases = True if symbolic_helper._is_tensor(b_ih) else False + _, h_outs, c_outs = _generic_rnn( + g, + "LSTM", + input, + hidden, + weight, + has_biases, + num_layers=1, + dropout=0, + train=0, + bidirectional=False, + batch_first=False, + ) + return symbolic_helper._squeeze_helper( + g, h_outs, [0] + ), symbolic_helper._squeeze_helper(g, c_outs, [0]) + + +@_onnx_symbolic( + "aten::gru", decorate=[symbolic_helper._apply_params("GRU"), _export("gru")] +) +@_onnx_symbolic( + "aten::rnn_tanh", + decorate=[symbolic_helper._apply_params("RNN_TANH"), _export("rnn_tanh")], +) +@_onnx_symbolic( + "aten::rnn_relu", + decorate=[symbolic_helper._apply_params("RNN_RELU"), _export("rnn_relu")], +) +def _one_hidden_rnn(kind: str): + @symbolic_helper.parse_args("v", "v", "v", "i", "i", "f", "i", "i", "i") + def _rnn_full( + g, + input, + hidden, + weight_v, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, + ): + weight = symbolic_helper._unpack_list(weight_v) + return _generic_rnn( + g, + kind, + input, + hidden, + weight, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, + ) + + @symbolic_helper.parse_args("v", "v", "v", "v", "i", "i", "f", "i", "i") + def _rnn_packed( + g, + input, + batch_sizes, + hidden, + weight_v, + has_biases, + num_layers, + dropout, + train, + bidirectional, + ): + weight = symbolic_helper._unpack_list(weight_v) + return _generic_rnn( + g, + kind, + input, + hidden, + weight, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_sizes=batch_sizes, + ) + + def symbolic(g, *args): + if symbolic_helper._is_tensor_list(args[3]): + return _rnn_packed(g, *args) + else: + return _rnn_full(g, *args) + + return symbolic + + +@_onnx_symbolic("aten::_dim_arange") +@symbolic_helper.parse_args("v", "i") +def _dim_arange(g: jit_utils.GraphContext, like, dim): + like_shape = g.op("Shape", like) + stop = g.op( + "Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0 + ) + # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) + return arange(g, stop, 4, None, None, None) + + +@_onnx_symbolic("aten::detach") +def detach(g: jit_utils.GraphContext, input): + # Erase aten::detach nodes because ONNX is inference only + return input + + +@_onnx_symbolic("aten::contiguous") +@symbolic_helper.parse_args("v", "i") +def contiguous(g: jit_utils.GraphContext, input, memory_format): + if memory_format > 2: # allower values are any, preserve and contiguous_format + raise errors.SymbolicValueError( + "onnx memory_format support is not implemented", input + ) + return input + + +@_onnx_symbolic("aten::_pack_padded_sequence") +@symbolic_helper.parse_args("v", "v", "i") +def _pack_padded_sequence(g: jit_utils.GraphContext, input, lengths, batch_first): + # Currently there is no PackPadded operator in ONNX. We rely on an + # optimization pass to remove this later. It is an error if all + # PackPadded operators cannot be optimized out. + if batch_first: + input = g.op("Transpose", input, perm_i=[1, 0, 2]) + if not lengths.type().isSubtypeOf(torch._C.TensorType.get()): + raise errors.SymbolicValueError( + "'lengths' must be a Tensor for ONNX export", input + ) + # We know it's a TensorType so this check is now safe. + # It's really only necessary because those operators expand to something that + # only works with int32 types in Caffe2... + if ( + _type_utils.JitScalarType.from_value( + lengths, _type_utils.JitScalarType.UNDEFINED + ) + != _type_utils.JitScalarType.INT + ): + lengths = g.op("Cast", lengths, to_i=_C_onnx.TensorProtoDataType.INT32) + return g.op("prim::PackPadded", input, lengths, outputs=2) + + +@_onnx_symbolic("aten::_pad_packed_sequence") +@symbolic_helper.parse_args("v", "v", "i", "t", "v") +def _pad_packed_sequence( + g: jit_utils.GraphContext, + data, + batch_sizes, + batch_first, + padding_value, + total_length, +): + # Ignore total_length as it is not supported in _symbolic_pad_packed_sequence + # It is only useful/used when training using data_parallel model, so + # It shouldn't be relevant for ONNX anyway + data, lengths = g.op("prim::PadPacked", data, batch_sizes, outputs=2) + if batch_first: + data = g.op("Transpose", data, perm_i=[1, 0, 2]) + return data, lengths + + +@_onnx_symbolic("aten::randint") +def randint(g: jit_utils.GraphContext, low, high, shapes, dtype, *options): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + low_i = symbolic_helper._get_const(low, "i", "low") + high_i = symbolic_helper._get_const(high, "i", "high") + if dtype is None: + scalar_type = _type_utils.JitScalarType.INT64 + else: + scalar_type = _type_utils.JitScalarType(dtype) + if low_i is None: + raise symbolic_helper._onnx_unsupported("randint", low) + if high_i is None: + raise symbolic_helper._onnx_unsupported("randint", high) + + shape = symbolic_helper._maybe_get_const(shapes, "is") + if symbolic_helper._is_value(shape): + shape_const = g.op( + "ConstantOfShape", + shapes, + value_t=torch.tensor([0], dtype=torch.float), + ) + randn = g.op( + "RandomUniformLike", + shape_const, + low_f=low_i, + high_f=high_i, + ) + else: + randn = g.op( + "RandomUniform", + shape_i=shape, + low_f=low_i, + high_f=high_i, + ) + + # cast to integer type + int_dtype = _type_utils.JitScalarType.INT64 + randint = g.op("Cast", randn, to_i=int_dtype.onnx_type()) + if int_dtype != scalar_type: + randint = g.op("Cast", randint, to_i=scalar_type.onnx_type()) + return randint + + +@_onnx_symbolic("aten::randint_like") +def randint_like(g: jit_utils.GraphContext, self, low, high, dtype, *options): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + low_i = symbolic_helper._get_const(low, "i", "low") + high_i = symbolic_helper._get_const(high, "i", "high") + if dtype is None: + scalar_type = _type_utils.JitScalarType.INT64 + else: + scalar_type = _type_utils.JitScalarType(dtype) + if low_i is None: + raise symbolic_helper._onnx_unsupported("randint", low) + if high_i is None: + raise symbolic_helper._onnx_unsupported("randint", high) + + randn = g.op( + "RandomUniformLike", + self, + low_f=low_i, + high_f=high_i, + ) + + # cast to integer type + int_dtype = _type_utils.JitScalarType.INT64 + randint = g.op("Cast", randn, to_i=int_dtype.onnx_type()) + if int_dtype != scalar_type: + randint = g.op("Cast", randint, to_i=scalar_type.onnx_type()) + return randint + + +@_onnx_symbolic("aten::randn") +def randn(g: jit_utils.GraphContext, shapes, dtype, *options): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + if dtype is None: + scalar_type = _type_utils.JitScalarType.FLOAT + else: + scalar_type = _type_utils.JitScalarType(dtype) + shape = symbolic_helper._maybe_get_const(shapes, "is") + if symbolic_helper._is_value(shape): + shape_const = g.op( + "ConstantOfShape", + shapes, + value_t=torch.tensor([0], dtype=torch.float), + ) + return g.op( + "RandomNormalLike", + shape_const, + dtype_i=scalar_type.onnx_type(), + ) + return g.op( + "RandomNormal", + shape_i=shape, + dtype_i=scalar_type.onnx_type(), + ) + + +@_onnx_symbolic("aten::rand") +def rand(g: jit_utils.GraphContext, shapes, dtype, *options): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + if dtype is None: + scalar_type = _type_utils.JitScalarType.FLOAT + else: + scalar_type = _type_utils.JitScalarType(dtype) + shape = symbolic_helper._maybe_get_const(shapes, "is") + if symbolic_helper._is_value(shape): + shape_const = g.op( + "ConstantOfShape", + shapes, + value_t=torch.tensor([0], dtype=torch.float), + ) + return g.op( + "RandomUniformLike", + shape_const, + dtype_i=scalar_type.onnx_type(), + ) + return g.op( + "RandomUniform", + shape_i=shape, + dtype_i=scalar_type.onnx_type(), + ) + + +@_onnx_symbolic("aten::randn_like") +def randn_like( + g: jit_utils.GraphContext, + self, + dtype, + layout=None, + device=None, + pin_memory=False, + memory_format=None, +): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + if dtype is None: + scalar_type = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.FLOAT + ) + else: + scalar_type = _type_utils.JitScalarType(dtype) + return g.op("RandomNormalLike", self, dtype_i=scalar_type.onnx_type()) + + +@_onnx_symbolic("aten::rand_like") +def rand_like( + g: jit_utils.GraphContext, + self, + dtype, + layout=None, + device=None, + pin_memory=False, + memory_format=None, +): + dtype = symbolic_helper._get_const(dtype, "i", "dtype") + if dtype is None: + dtype = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.FLOAT + ) + return g.op( + "RandomUniformLike", self, dtype_i=_type_utils.JitScalarType(dtype).onnx_type() + ) + + +@_onnx_symbolic("aten::rrelu") +@symbolic_helper.parse_args("v", "f", "f", "i", "none") +def rrelu(g: jit_utils.GraphContext, input, lower, upper, training, generator): + if not training: + slope = (upper + lower) / 2.0 + return g.op("LeakyRelu", input, alpha_f=slope) + p = g.op("RandomUniformLike", input, high_f=upper, low_f=lower) + return g.op("PRelu", input, p) + + +@_onnx_symbolic("aten::bernoulli") +def bernoulli(g: jit_utils.GraphContext, input, p=None, generator=None, out=None): + if out is not None and not symbolic_helper._is_none(out): + symbolic_helper._unimplemented( + "Bernoulli", "out parameter is not supported for bernoulli", input + ) + if generator is not None and not symbolic_helper._is_none(generator): + symbolic_helper._unimplemented( + "Bernoulli", "generator is not supported for bernoulli", input + ) + + dtype = _type_utils.JitScalarType.from_value( + input, _type_utils.JitScalarType.UNDEFINED + ) + if dtype == _type_utils.JitScalarType.UNDEFINED: + return symbolic_helper._unimplemented( + "Bernoulli", "input dtype not accessible", input + ) + + rands = g.op( + "RandomUniformLike", + input, + high_f=1.0, + low_f=0.0, + dtype_i=dtype.onnx_type(), + ) + prob = p if p is not None and not symbolic_helper._is_none(p) else input + output = g.op("Less", rands, prob) + return g.op("Cast", output, to_i=dtype.onnx_type()) + + +@_onnx_symbolic("aten::log_sigmoid") +@symbolic_helper.parse_args("v") +def log_sigmoid(g: jit_utils.GraphContext, input): + p = g.op("Sigmoid", input) + return g.op("Log", p) + + +@_onnx_symbolic("aten::erf") +@symbolic_helper.parse_args("v") +def erf(g: jit_utils.GraphContext, input): + return g.op("Erf", input) + + +@_onnx_symbolic("aten::flatten") +@symbolic_helper.quantized_args(True, False, False) +@symbolic_helper.parse_args("v", "i", "i") +def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim): + dim = symbolic_helper._get_tensor_rank(input) + if dim is None: + return symbolic_helper._unimplemented( + "dim", + "ONNX and PyTorch use different strategies to split the input. " + "Input rank must be known at export time.", + input, + ) + + if dim == 0: + return symbolic_helper._reshape_helper(g, input, [1]) + if dim == 1: + return g.op("Identity", input) + # TODO: remove this as onnx opset 11 spec allows negative axes + if end_dim < 0: + end_dim = dim + end_dim + # use ONNX's Flatten operator for cases where the output shape is 2D + if start_dim == 1 and end_dim == dim - 1: + return g.op("Flatten", input, axis_i=start_dim) + if start_dim == 0 and end_dim == dim - 2: + return g.op("Flatten", input, axis_i=end_dim + 1) + + return symbolic_helper._flatten_helper(g, input, start_dim, end_dim, dim) + + +@_onnx_symbolic("aten::nonzero") +@symbolic_helper.parse_args("v") +def nonzero(g: jit_utils.GraphContext, input): + """Emitted from `torch.nonzero(x, as_tuple=False)`""" + return t(g, g.op("NonZero", input)) + + +@_onnx_symbolic("aten::nonzero_numpy") +# Emitted from `torch.nonzero(x, as_tuple=True)` +def nonzero_numpy(g: jit_utils.GraphContext, input, _outputs=None): + return unbind(g, nonzero(g, input), 1, _outputs=_outputs) + + +@_onnx_symbolic("aten::isnan") +@symbolic_helper.parse_args("v") +def isnan(g: jit_utils.GraphContext, input): + output = g.op("IsNaN", input) + return output + + +@_onnx_symbolic("aten::any") +def _any(g: jit_utils.GraphContext, *args): + # aten::any(Tensor self) + if len(args) == 1: + input = args[0] + dim, keepdim = None, 0 + # aten::any(Tensor self, int[]? dim, bool keepdim) + else: + input, dim, keepdim = args + # Can be int list or single int + dim = symbolic_helper._parse_arg(dim, "t") + dim = [int(d) for d in dim.view(-1)] + keepdim = symbolic_helper._parse_arg(keepdim, "i") + input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT64) + input_sum = symbolic_helper._reducesum_helper( + g, input, axes_i=dim, keepdims_i=keepdim + ) + return gt(g, input_sum, g.op("Constant", value_t=torch.tensor(0, dtype=torch.long))) + + +@_onnx_symbolic("aten::all") +def _all(g: jit_utils.GraphContext, *args): + input = g.op("Not", args[0]) + # aten::all(Tensor self) + if len(args) == 1: + return g.op("Not", _any(g, input)) + # aten::all(Tensor self, int[]? dim, bool keepdim) + else: + return g.op("Not", _any(g, input, args[1], args[2])) + + +@_onnx_symbolic("aten::narrow") +@symbolic_helper.parse_args("v", "i", "i", "i") +def narrow(g: jit_utils.GraphContext, input, dim, start, length): + return symbolic_helper._slice_helper( + g, input, axes=[dim], starts=[start], ends=[start + length] + ) + + +@_onnx_symbolic("aten::argmax") +@symbolic_helper.parse_args("v", "v", "b") +def argmax( + g: jit_utils.GraphContext, + input: torch._C.Value, + dim: torch._C.Value, + keepdim: bool, +): + return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMax") + + +@_onnx_symbolic("aten::argmin") +@symbolic_helper.parse_args("v", "v", "b") +def argmin( + g: jit_utils.GraphContext, + input: torch._C.Value, + dim: torch._C.Value, + keepdim: bool, +): + return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMin") + + +@_onnx_symbolic("aten::scatter") +@symbolic_helper.parse_args("v", "i", "v", "v") +def scatter(g: jit_utils.GraphContext, self, dim, index, src): + src_type = _type_utils.JitScalarType.from_value( + src, _type_utils.JitScalarType.UNDEFINED + ) + src = symbolic_helper._maybe_get_scalar(src) + if symbolic_helper._is_value(src): + return g.op("Scatter", self, index, src, axis_i=dim) + else: + # Check if scalar "src" has same type as self (PyTorch allows different + # type for scalar src (but not when src is tensor)). If not, insert Cast node. + self_scalar_type = _type_utils.JitScalarType.from_value(self) + if self_scalar_type != src_type: + src = g.op("Cast", src, to_i=self_scalar_type.onnx_type()) + return g.op("Scatter", self, index, expand_as(g, src, index), axis_i=dim) + + +@_onnx_symbolic("aten::scatter_add") +@symbolic_helper.parse_args("v", "i", "v", "v") +def scatter_add(g: jit_utils.GraphContext, self, dim, index, src): + scalar_type = symbolic_helper._try_get_scalar_type(self) + if scalar_type is None: + return symbolic_helper._unimplemented( + "scatter_add", "input dtype not accessible", self + ) + sizes = symbolic_helper._get_tensor_sizes(self, allow_nonstatic=False) + if sizes: + to_add = g.op("Constant", value_t=torch.zeros(sizes, dtype=scalar_type.dtype())) + else: + to_add = zeros_like(g, self, scalar_type) + to_add = symbolic_helper._scatter_helper(g, to_add, dim, index, src) + return add(g, self, to_add) + + +@_onnx_symbolic("aten::log2") +def log2(g: jit_utils.GraphContext, self): + _ln2 = 0.693147180559945309 + return g.op("Div", log(g, self), g.op("Constant", value_t=torch.tensor(_ln2))) + + +@_onnx_symbolic("aten::is_floating_point") +def is_floating_point(g: jit_utils.GraphContext, self): + if symbolic_helper._is_fp(self): + return g.op("Constant", value_t=torch.BoolTensor([1])) + return g.op("Constant", value_t=torch.BoolTensor([0])) + + +@_onnx_symbolic("aten::__is_") +def __is_(g: jit_utils.GraphContext, self, other): + if symbolic_helper._is_none(other): + if symbolic_helper._is_none(self): + return g.op("Constant", value_t=torch.BoolTensor([1])) + return g.op("Constant", value_t=torch.BoolTensor([0])) + return eq(g, self, other) + + +@_onnx_symbolic("aten::__isnot_") +@wrap_logical_op_with_negation +def __isnot_(g: jit_utils.GraphContext, self, other): + return __is_(g, self, other) + + +@_onnx_symbolic("aten::one_hot") +def one_hot(g: jit_utils.GraphContext, self, num_classes): + values = g.op("Constant", value_t=torch.LongTensor([0, 1])) + # onnxruntime supports limited type combinations for OneHot. + if _type_utils.JitScalarType.from_value( + num_classes, _type_utils.JitScalarType.UNDEFINED + ) in { + _type_utils.JitScalarType.UINT8, + _type_utils.JitScalarType.INT8, + _type_utils.JitScalarType.INT, + _type_utils.JitScalarType.INT16, + }: + num_classes = g.op("Cast", num_classes, to_i=_C_onnx.TensorProtoDataType.INT64) + return g.op("OneHot", self, num_classes, values, axis_i=-1) + + +@_onnx_symbolic("aten::gather") +@symbolic_helper.parse_args("v", "i", "v", "v") +def gather(g: jit_utils.GraphContext, self, dim, index, sparse_grad=False): + if symbolic_helper._maybe_get_const(sparse_grad, "i"): + return symbolic_helper._unimplemented("gather", "sparse_grad == True", self) + # NOTE: This workaround is needed since GatherElement is only supported + # since opset 11, and Gather in ONNX is not the same as torch.gather. + scalar_type = _type_utils.JitScalarType.from_value(self) + values = g.op("Constant", value_t=torch.LongTensor([0, 1])) + depth = size(g, self, g.op("Constant", value_t=torch.LongTensor([dim]))) + index = g.op( + "Cast", + g.op("OneHot", index, depth, values, axis_i=dim), + to_i=scalar_type.onnx_type(), + ) + mul = g.op("Mul", symbolic_helper._unsqueeze_helper(g, self, [dim + 1]), index) + return symbolic_helper._reducesum_helper(g, mul, axes_i=[dim], keepdims_i=0) + + +@symbolic_helper.parse_args("v", "is", "i", "i") +def _var_mean(g: jit_utils.GraphContext, input, dim, correction, keepdim): + return symbolic_helper._var_mean_helper(g, input, dim, correction, keepdim) + + +@_onnx_symbolic("aten::std") +def std(g: jit_utils.GraphContext, input, *args): + var, _ = var_mean(g, input, *args) + return g.op("Sqrt", var) + + +@_onnx_symbolic("aten::var") +def var(g: jit_utils.GraphContext, input, *args): + var, _ = var_mean(g, input, *args) + return var + + +@_onnx_symbolic("aten::var_mean") +def var_mean(g: jit_utils.GraphContext, input, *args): + if len(args) == 1: + return _var_mean(g, input, None, args[0], None) + else: + return _var_mean(g, input, *args) + + +@_onnx_symbolic("aten::std_mean") +def std_mean(g: jit_utils.GraphContext, input, *args): + var, mean = var_mean(g, input, *args) + return g.op("Sqrt", var), mean + + +@_onnx_symbolic("aten::logsumexp") +@symbolic_helper.parse_args("v", "is", "i") +def logsumexp(g: jit_utils.GraphContext, input, dim, keepdim): + return g.op("ReduceLogSumExp", input, axes_i=dim, keepdims_i=keepdim) + + +@_onnx_symbolic("aten::arange") +def arange(g: jit_utils.GraphContext, *args): + def _get_arange_dtype(dtype): + dtype = symbolic_helper._maybe_get_const(dtype, "i") + return dtype + + def _float_step_convert(range_tensor): + if symbolic_helper._is_fp(range_tensor): + range_tensor = g.op( + "Cast", + g.op("Ceil", range_tensor), + to_i=_type_utils.JitScalarType.INT64.onnx_type(), + ) + return range_tensor + + if len(args) == 2 or len(args) == 5: + if len(args) == 2: + # aten::arange(Scalar end, Tensor out) + dtype = None + else: + # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) + dtype = _get_arange_dtype(args[1]) + dtype, end, start, step = symbolic_helper._arange_cast_helper( + g, end=args[0], dtype=dtype + ) + end = symbolic_helper._unsqueeze_helper(g, end, [0]) + range_tensor = _float_step_convert(end) + arange_tensor = symbolic_helper._squeeze_helper( + g, nonzero(g, ones(g, range_tensor, dtype, None, None)), [1] + ) + return g.op( + "Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type() + ) + elif len(args) == 4 or len(args) == 7: + if len(args) == 4: + # aten::arange(Scalar start, Scalar end, Scalar step, Tensor out) + dtype = None + else: + # aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory) + dtype = _get_arange_dtype(args[3]) + dtype, end, start, step = symbolic_helper._arange_cast_helper( + g, start=args[0], end=args[1], step=args[2], dtype=dtype + ) + step = symbolic_helper._unsqueeze_helper(g, step, [0]) + end = symbolic_helper._unsqueeze_helper(g, end, [0]) + start = symbolic_helper._unsqueeze_helper(g, start, [0]) + range_tensor = _float_step_convert(g.op("Div", g.op("Sub", end, start), step)) + arange_tensor = symbolic_helper._squeeze_helper( + g, nonzero(g, ones(g, range_tensor, None, None, None)), [1] + ) + arange_tensor = g.op("Add", g.op("Mul", arange_tensor, step), start) + return g.op( + "Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type() + ) + elif len(args) == 6: + # aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) + dtype = _get_arange_dtype(args[2]) + dtype, end, start, step = symbolic_helper._arange_cast_helper( + g, start=args[0], end=args[1], dtype=dtype + ) + end = symbolic_helper._unsqueeze_helper(g, end, [0]) + start = symbolic_helper._unsqueeze_helper(g, start, [0]) + range_tensor = _float_step_convert(g.op("Sub", end, start)) + arange_tensor = g.op( + "Add", + symbolic_helper._squeeze_helper( + g, nonzero(g, ones(g, range_tensor, dtype, *(args[3:]))), [1] + ), + start, + ) + return g.op( + "Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type() + ) + + return symbolic_helper._unimplemented("aten::arange", f"with {len(args)} arguments") + + +@_onnx_symbolic("aten::linspace") +def linspace( + g: jit_utils.GraphContext, start, end, steps, dtype, layout, device, pin_memory +): + range_tensor = symbolic_helper._arange_helper(g, steps, None) + step = div( + g, + sub(g, end, start), + sub(g, steps, g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))), + ) + return add(g, mul(g, range_tensor, step), start) + + +@_onnx_symbolic("aten::lift") +def lift(g: jit_utils.GraphContext, self): + # at::lift() is a no-op from the perspective of tracing for onnx + return self + + +@_onnx_symbolic("aten::masked_fill") +def masked_fill(g: jit_utils.GraphContext, self, mask, value): + """Implement the masked_fill functionality available for a pytorch tensor in ONNX. + + Fills elements of the input tensor with `value` where `mask` is True. + """ + mask = g.op("Cast", mask, to_i=_C_onnx.TensorProtoDataType.BOOL) + value = symbolic_helper._maybe_get_scalar(value) + return g.op("Where", mask, symbolic_helper._if_scalar_type_as(value, self), self) + + +@_onnx_symbolic("aten::masked_fill_") +def masked_fill_(g: jit_utils.GraphContext, self, mask, value): + return masked_fill(g, self, mask, value) + + +@_onnx_symbolic("aten::index") +def index(g: jit_utils.GraphContext, self, index): + if symbolic_helper._is_packed_list(index): + indices = symbolic_helper._unpack_list(index) + else: + indices = [index] + + def try_mask_to_index(index): + if not symbolic_helper._is_none(index) and ( + _type_utils.JitScalarType.from_value( + index, _type_utils.JitScalarType.UNDEFINED + ) + == _type_utils.JitScalarType.UINT8 + or symbolic_helper._is_bool(index) + ): + if g.opset < 9: + raise errors.SymbolicValueError( + "Exporting masked indices are only supported after ONNX opset 9.", + self, + ) + warnings.warn( + "Exporting aten::index operator with indices of type Byte. " + "Only 1-D indices are supported. In any other case, " + "this will produce an incorrect ONNX graph." + ) + index = symbolic_helper._squeeze_helper(g, nonzero(g, index), [1]) + return index + + indices = [try_mask_to_index(idx) for idx in indices] + if len(indices) == 1: + return symbolic_helper._select_helper( + g, self, 0, indices[0], apply_reshape=False + ) + else: + # Multiple tensors as indices. Each tensor could either be + # 1. prim::Constant() + # representing ":" in python indexing. E.g. tensor[:, :] + # 2. prim::Constant[value=...] or tensor output + # representing advanced indexing. E.g. tensor[[0, 1], [2, 0]]. + # For more info on advanced indexing, + # check https://numpy.org/doc/stable/user/basics.indexing.html#advanced-indexing + + # Consider a general case of + # t: [x_1, y_1, y_2, ..., x_m, ..., y_n] + # where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes for ":". + # Same results can be achieved through transposing t into + # t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n] + # and use gatherND. However ONNX does not have gatherND, to use 1d gather we'll need to flatten t + # and process the tensor indices. + # t: [x_1 * x_2 * ... * x_m, y_1 * y_2 * ... * y_n] + # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)) + # After gather, reshape and transpose back. + adv_idx_indices = [ + i for i, idx in enumerate(indices) if not symbolic_helper._is_none(idx) + ] + + if len(adv_idx_indices) == 0: + return self + elif len(adv_idx_indices) == 1: + return index_select( + g, self, adv_idx_indices[0], indices[adv_idx_indices[0]] + ) + else: + rank = symbolic_helper._get_tensor_rank(self) + if rank is None: + return symbolic_helper._unimplemented( + "aten::index", + "operator of advanced indexing on tensor of unknown rank. ", + self, + ) + # TODO: If indexing is supported natively in ONNX in future opsets, + # update the warning to recommend exporting with higher opset version. + warnings.warn( + "Exporting aten::index operator of advanced indexing in opset " + f"{GLOBALS.export_onnx_opset_version}" + " is achieved by combination of multiple ONNX operators, " + "including Reshape, Transpose, Concat, and Gather. " + "If indices include negative values, the exported graph will produce incorrect results." + ) + adv_idx_count = len(adv_idx_indices) + shape_tensor = _shape_as_tensor(g, self) + dim_tensor_list = [ + g.op( + "Gather", + shape_tensor, + g.op("Constant", value_t=torch.LongTensor([dim])), + axis_i=0, + ) + for dim in range(rank) + ] + + self = g.op( + "Transpose", + self, + perm_i=adv_idx_indices + + [i for i in range(rank) if i not in adv_idx_indices], + ) + self = g.op("Flatten", self, axis_i=adv_idx_count) + + # Note that tensor indices will be broadcasted while accumulating. Thus we get the final subarray shape as well. + cum_adv_index = indices[adv_idx_indices[-1]] + multiplier = dim_tensor_list[adv_idx_indices[-1]] + for i in range(adv_idx_count - 2, -1, -1): + adv_index = g.op("Mul", indices[adv_idx_indices[i]], multiplier) + cum_adv_index = g.op("Add", cum_adv_index, adv_index) + multiplier = g.op( + "Mul", multiplier, dim_tensor_list[adv_idx_indices[i]] + ) + + # perform gather + self = index_select(g, self, 0, cum_adv_index) + + cum_adv_index_shape_tensor = _shape_as_tensor(g, cum_adv_index) + # check if all advanced indices are consecutive. + # Refer to https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing + # to understand how the subarray position is decided. + if adv_idx_indices == list( + range(adv_idx_indices[0], adv_idx_indices[-1] + 1) + ): + # unfold regular index axes + folded_adv_idx_shape_list = [ + g.op("Constant", value_t=torch.LongTensor([-1])) + ] + [ + dim_tensor_list[i] for i in range(rank) if i not in adv_idx_indices + ] + folded_adv_idx_shape = g.op( + "Concat", *folded_adv_idx_shape_list, axis_i=0 + ) + self = symbolic_helper._reshape_helper(g, self, folded_adv_idx_shape) + + # Transpose folded advanced indexed axis to its original location. + adv_idx_permute = ( + list(range(1, adv_idx_indices[0] + 1)) + + [0] + + list(range(adv_idx_indices[0] + 1, rank - adv_idx_count + 1)) + ) + self = g.op("Transpose", self, perm_i=adv_idx_permute) + + # unfold advanced index axes + final_shape_list = ( + [dim_tensor_list[i] for i in range(adv_idx_indices[0])] + + [cum_adv_index_shape_tensor] + + [ + dim_tensor_list[i] + for i in range(adv_idx_indices[0], rank) + if i not in adv_idx_indices + ] + ) + final_shape = g.op("Concat", *final_shape_list, axis_i=0) + else: + final_shape = g.op( + "Concat", + cum_adv_index_shape_tensor, + *[ + dim_tensor_list[i] + for i in range(rank) + if i not in adv_idx_indices + ], + axis_i=0, + ) + + return symbolic_helper._reshape_helper(g, self, final_shape) + + +@_onnx_symbolic("aten::linalg_norm") +@symbolic_helper.parse_args("v", "v", "is", "b", "v") +def linalg_norm( + g: jit_utils.GraphContext, + self: torch._C.Value, + ord: torch._C.Value, + dim: Sequence[int] | None, + keepdim: bool, + dtype: torch._C.Value, +): + # Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.norm.html + ord_value = None + if dim is None: + if symbolic_helper._is_none(ord): + self = symbolic_helper._reshape_helper(g, self, [-1]) + ord = g.op("Constant", value_t=torch.LongTensor([2])) + self_dim = symbolic_helper._get_tensor_rank(self) + if self_dim is None: + return symbolic_helper._unimplemented( + "dim", "Input rank must be known at export time.", self + ) + if self_dim == 1: + ord_value = symbolic_helper._parse_arg(ord, "f") + else: + dim = [0, 1] + else: + if len(dim) == 1: + if symbolic_helper._is_none(ord): + ord = g.op("Constant", value_t=torch.LongTensor([2])) + ord_value = symbolic_helper._parse_arg(ord, "f") + if ord_value: + return linalg_vector_norm(g, self, ord_value, dim, keepdim, dtype) + return linalg_matrix_norm(g, self, ord, dim, keepdim, dtype) + + +@_onnx_symbolic("aten::linalg_vector_norm") +@symbolic_helper.parse_args("v", "f", "is", "b", "v") +def linalg_vector_norm( + g: jit_utils.GraphContext, + self: torch._C.Value, + ord: float, + dim: Sequence[int] | None, + keepdim: bool, + dtype: torch._C.Value, +): + return symbolic_helper._linalg_vector_norm_helper(g, self, ord, dim, keepdim, dtype) + + +@_onnx_symbolic("aten::linalg_matrix_norm") +@symbolic_helper.parse_args("v", "v", "is", "b", "v") +def linalg_matrix_norm( + g: jit_utils.GraphContext, + self: torch._C.Value, + ord: torch._C.Value, + dim: list[int], + keepdim: bool, + dtype: torch._C.Value, +): + # Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.matrix_norm.html + ord_value = symbolic_helper._parse_arg(ord, "s") + if ord_value == "fro": + return frobenius_norm(g, self, dim, keepdim) + elif ord_value == "nuc": + return symbolic_helper._unimplemented("linalg.matrix_norm", "ord==nuc", self) + else: + ord_value = symbolic_helper._parse_arg(ord, "f") + if ord_value is None: + return frobenius_norm(g, self, dim, keepdim) + if ord_value == 2 or ord_value == -2: + # ord = 2/-2 unimplemented due to lack of operators + # used to calculate singular values + return symbolic_helper._unimplemented("linalg.matrix_norm", "ord==2", self) + # Wrap the dim vector to handle negative dim values + self_dim = symbolic_helper._get_tensor_rank(self) + if self_dim is None: + return symbolic_helper._unimplemented( + "linalg.matrix_norm", "Input rank must be known at export time.", self + ) + # Common implementation for cases with + # ord = 1/-1 and ord = inf/-inf + if dim[0] < 0: + dim[0] += self_dim + if dim[1] < 0: + dim[1] += self_dim + + if ord_value == math.inf or ord_value == -math.inf: + dim[0], dim[1] = dim[1], dim[0] + if dim[1] > dim[0] and not keepdim: + dim[1] -= 1 + sum = symbolic_helper._reducesum_helper( + g, g.op("Abs", self), axes_i=[dim[0]], keepdims_i=keepdim + ) + if ord_value > 0: + result, _indices = max( + g, + sum, + dim_or_y=g.op("Constant", value_t=torch.LongTensor([dim[1]])), + keepdim=keepdim, + ) + else: + result, _indices = min( + g, + sum, + dim_or_y=g.op("Constant", value_t=torch.LongTensor([dim[1]])), + keepdim=keepdim, + ) + return result + + +@_onnx_symbolic("aten::linalg_cross") +@symbolic_helper.parse_args("v", "v", "i") +def linalg_cross(g: jit_utils.GraphContext, input, other, dim=-1): + return cross(g, input, other, dim) + + +@_onnx_symbolic("aten::frobenius_norm") +@symbolic_helper.parse_args("v", "is", "b") +def frobenius_norm(g: jit_utils.GraphContext, self, dim=None, keepdim=False): + sqr = g.op("Mul", self, self) + sumsqr = symbolic_helper._reducesum_helper(g, sqr, axes_i=dim, keepdims_i=keepdim) + return g.op("Sqrt", sumsqr) + + +@_onnx_symbolic("aten::multinomial") +@symbolic_helper.parse_args("v", "i", "b", "v") +def multinomial( + g: jit_utils.GraphContext, input, num_samples, replacement=False, generator=None +): + if generator is not None and not symbolic_helper._is_none(generator): + symbolic_helper._unimplemented( + "Multinomial", "generator is not supported for multinomial", input + ) + if not replacement and num_samples > 1: + symbolic_helper._unimplemented( + "Multinomial", + "replacement=False when num_samples > 1 is not supported for multinomial", + input, + ) + + log_input = log(g, input) + return g.op( + "Multinomial", + log_input, + dtype_i=_C_onnx.TensorProtoDataType.INT64, + sample_size_i=num_samples, + ) + + +@_onnx_symbolic("aten::baddbmm") +def baddbmm(g: jit_utils.GraphContext, self, batch1, batch2, beta, alpha): + scalar_type = _type_utils.JitScalarType.from_value(self) + batch_mul = matmul(g, batch1, batch2) + mul_a = mul( + g, + batch_mul, + g.op("Cast", alpha, to_i=scalar_type.onnx_type()), + ) + mul_b = mul( + g, + self, + g.op("Cast", beta, to_i=scalar_type.onnx_type()), + ) + return add(g, mul_a, mul_b) + + +@_onnx_symbolic("aten::meshgrid") +@symbolic_helper.parse_args("v", "s") +def meshgrid(g: jit_utils.GraphContext, tensor_list, indexing: str | None = None): + if indexing is None: + indexing = "ij" + elif indexing not in {"ij", "xy"}: + raise errors.SymbolicValueError( + f"Unsupported indexing: {indexing}", tensor_list + ) + unpacked_tensor_list = symbolic_helper._unpack_list(tensor_list) + if indexing == "xy": + unpacked_tensor_list[:2] = unpacked_tensor_list[1::-1] + tensors = [ + symbolic_helper._reshape_helper( + g, t, g.op("Constant", value_t=torch.LongTensor([-1])) + ) + for t in unpacked_tensor_list + ] + tensors_shape = [g.op("Shape", t) for t in tensors] + out_shape = g.op("Concat", *tensors_shape, axis_i=0) + out = [] + for i, t in enumerate(tensors): + shape_i = [g.op("Constant", value_t=torch.ones(1, dtype=torch.int64))] * len( + tensors + ) + shape_i[i] = tensors_shape[i] + t_reshaped = _reshape_from_tensor(g, t, g.op("Concat", *shape_i, axis_i=0)) + out.append(g.op("Expand", t_reshaped, out_shape)) + if indexing == "xy": + out[0], out[1] = out[1], out[0] + return g.op("prim::ListConstruct", *out) + + +@_onnx_symbolic("aten::remainder") +def remainder(g: jit_utils.GraphContext, input, other): + div = _floor_divide(g, input, other) + quo = g.op("Mul", div, other) + return g.op("Sub", input, quo) + + +@_onnx_symbolic("aten::gelu") +@symbolic_helper.parse_args("v", "s") +def gelu(g: jit_utils.GraphContext, self: torch._C.Value, approximate: str = "none"): + if approximate == "tanh": + kBeta = math.sqrt(2 / math.pi) + kKappa = 0.044715 + + beta = torch.tensor(kBeta, dtype=torch.double) + kappa = torch.tensor(kKappa, dtype=torch.double) + one = torch.tensor(1.0, dtype=torch.double) + half = torch.tensor(0.5, dtype=torch.double) + + self_cube = mul(g, self, mul(g, self, self)) + inner = mul(g, beta, add(g, self, mul(g, kappa, self_cube))) + return mul(g, half, mul(g, self, add(g, one, g.op("Tanh", inner)))) + else: + _sqrt2 = 1.4142135623730951 + erf = g.op("Erf", g.op("Div", self, torch.tensor(_sqrt2, dtype=torch.double))) + erf_plusone = add( + g, erf, g.op("Constant", value_t=torch.tensor(1, dtype=torch.double)) + ) + return mul( + g, + mul(g, self, erf_plusone), + g.op("Constant", value_t=torch.tensor(0.5, dtype=torch.double)), + ) + + +@_onnx_symbolic("aten::group_norm") +@symbolic_helper.quantized_args(True, False, False, False) +@symbolic_helper.parse_args("v", "i", "v", "v", "f", "i") +def group_norm( + g: jit_utils.GraphContext, input, num_groups, weight, bias, eps, cudnn_enabled +): + channel_size = symbolic_helper._get_tensor_dim_size(input, 1) + if channel_size is not None: + assert channel_size % num_groups == 0 + input_rank = symbolic_helper._get_tensor_rank(input) + if input_rank is None: + return symbolic_helper._unimplemented("group_norm", "unknown input rank", input) + # 0 in the shape list keeps dimension value unchanged. + shape = [0, num_groups, -1] + input_reshaped = symbolic_helper._reshape_helper( + g, input, g.op("Constant", value_t=torch.LongTensor(shape)) + ) + + # C is always divisible by num_groups + # Due to shape difference. we need to apply weight and bias after + # instance norm computation and reshape + weight_ = g.op( + "Constant", + value_t=torch.tensor( + [1.0] * num_groups, + dtype=_type_utils.JitScalarType.from_value(input).dtype(), + ), + ) + bias_ = g.op( + "Constant", + value_t=torch.tensor( + [0.0] * num_groups, + dtype=_type_utils.JitScalarType.from_value(input).dtype(), + ), + ) + + norm_reshaped = g.op( + "InstanceNormalization", input_reshaped, weight_, bias_, epsilon_f=eps + ) + norm = symbolic_helper._reshape_helper(g, norm_reshaped, g.op("Shape", input)) + + if weight is None or weight.node().mustBeNone(): + weight_value = torch.tensor( + [1.0], dtype=_type_utils.JitScalarType.from_value(input).dtype() + ) + weight = g.op("Constant", value_t=weight_value) + if bias is None or bias.node().mustBeNone(): + bias_value = torch.tensor( + [0.0], dtype=_type_utils.JitScalarType.from_value(input).dtype() + ) + bias = g.op("Constant", value_t=bias_value) + + # Norm has shape [N, C, *] so we reshape weight and bias to [C, *] + axes = list(range(1, input_rank - 1)) + return add( + g, + mul(g, norm, symbolic_helper._unsqueeze_helper(g, weight, axes)), + symbolic_helper._unsqueeze_helper(g, bias, axes), + ) + + +@_onnx_symbolic("aten::_weight_norm") +@symbolic_helper.parse_args("v", "v", "i") +def _weight_norm(g: jit_utils.GraphContext, weight_v, weight_g, dim): + rank = symbolic_helper._get_tensor_rank(weight_v) + if rank is not None: + # W = g * ((v) / ||v||) + # Compute norm_except_dim for l2 norm. dim = None means over all dims + # torch's weight_norm module sets dim = -1 if it's None. + # This conflicts the logic for negative axes to access dims backwards + # TODO: Might need a fix in torch group_norm module + axes = list(range(rank)) + if dim is not None: + if dim < -1: + dim += rank + if dim != -1: + axes.remove(dim) + norm_v = norm(g, weight_v, 2, axes, 1) + div = g.op("Div", weight_v, norm_v) + return g.op("Mul", div, weight_g) + raise errors.SymbolicValueError( + "Unsupported: ONNX export of _weight_norm for tensor of unknown rank.", + weight_v, + ) + + +@_onnx_symbolic("aten::dim") +def dim(g: jit_utils.GraphContext, self): + """Implement the dim functionality available for a pytorch tensor in ONNX""" + # ONNX does not support dim directly in this opset so we can use 2 ops to get the info + shape = g.op("Shape", self) + return g.op("Size", shape) + + +@_onnx_symbolic("aten::__contains_") +def __contains_(g: jit_utils.GraphContext, self, element): + unpacked_list = symbolic_helper._unpack_list(self) + if all( + symbolic_helper._is_constant(x) for x in unpacked_list + ) and symbolic_helper._is_constant(element): + return g.op( + "Constant", + value_t=torch.tensor( + symbolic_helper._node_get(element.node(), "value") + in (symbolic_helper._node_get(x.node(), "value") for x in unpacked_list) + ), + ) + + raise errors.SymbolicValueError( + "Unsupported: ONNX export of __contains__ for non-constant list or element.", + self, + ) + + +@_onnx_symbolic("aten::__getitem_") +def __getitem_(g: jit_utils.GraphContext, self, i): + return select(g, self, g.op("Constant", value_t=torch.tensor([0])), i) + + +@_onnx_symbolic("aten::item") +def item(g: jit_utils.GraphContext, self): + return self + + +@_onnx_symbolic("aten::take") +def take(g: jit_utils.GraphContext, self, index): + self_flattened = symbolic_helper._reshape_helper( + g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) + ) + out = index_select(g, self_flattened, 0, index) + out = reshape_as(g, out, index) + return out + + +def _kl_div_log_target_impl(g: jit_utils.GraphContext, input, target): + diff_ = sub(g, target, input) + exp_ = exp(g, target) + output = mul(g, exp_, diff_) + return output + + +def _kl_div_non_log_target_impl(g: jit_utils.GraphContext, input, target): + log_ = log(g, target) + diff_ = sub(g, log_, input) + output_pos = mul(g, target, diff_) + zeros_ = zeros_like(g, output_pos) + mask_ = gt(g, target, g.op("Constant", value_t=torch.tensor(0))) + output = where(g, mask_, output_pos, zeros_) + return output + + +@_onnx_symbolic("aten::kl_div") +@symbolic_helper.parse_args("v", "v", "i", "b") +def kl_div(g: jit_utils.GraphContext, input, target, reduction, log_target): + if log_target: + output = _kl_div_log_target_impl(g, input, target) + else: + output = _kl_div_non_log_target_impl(g, input, target) + + if reduction == 0: + return output + elif reduction == 1: + return g.op("ReduceMean", output, keepdims_i=0) + elif reduction == 2: + return symbolic_helper._reducesum_helper(g, output, keepdims_i=0) + else: + return symbolic_helper._onnx_unsupported( + "kl_div with reduction other than none, mean, or sum.", input + ) + + +@_onnx_symbolic("aten::mse_loss") +@symbolic_helper.parse_args("v", "v", "i") +def mse_loss(g: jit_utils.GraphContext, input, target, reduction): + output = mul(g, sub(g, input, target), sub(g, input, target)) + if reduction == 0: + return output + elif reduction == 1: + return g.op("ReduceMean", output, keepdims_i=0) + elif reduction == 2: + return symbolic_helper._reducesum_helper(g, output, keepdims_i=0) + else: + return symbolic_helper._onnx_unsupported( + "mse_loss with reduction other than none, mean, or sum.", input + ) + + +@_onnx_symbolic("aten::as_strided") +@symbolic_helper.quantized_args(True) +@symbolic_helper.parse_args("v", "v", "is", "i") +def as_strided(g: jit_utils.GraphContext, self, sizes, strides, offset=None): + sizes = symbolic_helper._maybe_get_const(sizes, "is") + rank = len(strides) + self_1d = symbolic_helper._reshape_helper( + g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) + ) + ind: torch.Tensor | None + if not symbolic_helper._is_value(sizes): + ind = torch.tensor([0], dtype=torch.long) + for i, (size, stride) in enumerate(zip(sizes, strides)): + r_size = [1] * rank + r_size[i] = -1 + ind = ind + torch.arange(size).view(r_size) * stride + if offset: + ind = ind + offset + return g.op("Gather", self_1d, g.op("Constant", value_t=ind)) + else: + ind = None + for i, stride in enumerate(strides): + r_size = [1] * rank + r_size[i] = -1 + size = select( + g, + sizes, + g.op("Constant", value_t=torch.tensor([0])), + g.op("Constant", value_t=torch.tensor(i)), + ) + tmp_ind = symbolic_helper._reshape_helper( + g, + arange(g, size, 4, None, None, None), + g.op("Constant", value_t=torch.tensor(r_size)), + ) + tmp_ind = g.op( + "Mul", tmp_ind, g.op("Constant", value_t=torch.tensor([stride])) + ) + if ind is None: + ind = tmp_ind + else: + ind = g.op("Add", ind, tmp_ind) + if offset: + ind = g.op("Add", ind, g.op("Constant", torch.tensor([offset]))) + return g.op("Gather", self_1d, ind) + + +@_onnx_symbolic("aten::__derive_index") +def __derive_index(g: jit_utils.GraphContext, index, start, step): + return g.op("Add", start, g.op("Mul", index, step)) + + +@_onnx_symbolic("aten::__range_length") +# Source code for aten op can be found here: pytorch/torch/csrc/jit/runtime/register_prim_ops.cpp +# if (step > 0 && lo < hi) { +# push(stack, 1 + (hi - 1 - lo) / step); +# } else if (step < 0 && lo > hi) { +# push(stack, 1 + (lo - 1 - hi) / (0 - step)); +# } else { +# push(stack, 0); +# } +def __range_length(g: jit_utils.GraphContext, lo, hi, step): + sub = g.op("Sub", hi, lo) + div = g.op("Ceil", true_divide(g, sub, step)) + return g.op("Cast", div, to_i=_C_onnx.TensorProtoDataType.INT64) + + +@_onnx_symbolic("aten::linear") +def linear(g: jit_utils.GraphContext, input, weight, bias): + rank = symbolic_helper._get_tensor_rank(input) + weight = t(g, weight) + if rank == 2 and not bias.node().mustBeNone(): + alpha = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) + beta = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) + output = addmm(g, bias, input, weight, alpha, beta) + else: + output = matmul(g, input, weight) + if not bias.node().mustBeNone(): + output = add(g, bias, output) + + return output + + +@_onnx_symbolic("aten::hann_window") +@symbolic_helper.parse_args("v", "b", "i", "v", "v", "v", "v") +def hann_window( + g: jit_utils.GraphContext, + window_length, + periodic=True, + dtype: int | None = None, + layout=None, + device=None, + pin_memory=None, + requires_grad=False, +): + if dtype is None: + dtype_ = torch.get_default_dtype() + if not dtype_ or not dtype_.is_floating_point: + dtype_ = torch.float + scalar_type = _type_utils.JitScalarType.from_dtype(dtype_) + else: + scalar_type = _type_utils.JitScalarType(dtype) + + n_array = arange(g, window_length, 4, None, None, None) + output = g.op("Cast", n_array, to_i=_C_onnx.TensorProtoDataType.FLOAT) + output = mul( + g, g.op("Constant", value_t=torch.tensor(math.pi, dtype=torch.float)), output + ) + + if periodic is False: + window_length = sub( + g, window_length, g.op("Constant", value_t=torch.tensor(1, dtype=torch.int)) + ) + output = div(g, output, window_length) + output = g.op( + "Cast", + square(g, sin(g, output)), + to_i=scalar_type.onnx_type(), + ) + + return output + + +@_onnx_symbolic("aten::mv") +def mv(g: jit_utils.GraphContext, self, vec): + return matmul(g, self, vec) + + +@_onnx_symbolic("aten::dot") +def dot(g: jit_utils.GraphContext, self, other): + return matmul(g, self, other) + + +@_onnx_symbolic("aten::movedim") +@symbolic_helper.parse_args("v", "t", "t") +def movedim(g: jit_utils.GraphContext, self, source, destination): + # This is a pythonic implementation mostly taken from aten/src/ATen/native/TensorShape.cpp::movedim + source = source.view(-1) + destination = destination.view(-1) + + assert source.size() == destination.size() + + if (source == destination).all(): + return self + + self_rank = symbolic_helper._get_tensor_rank(self) + assert self_rank is not None + + perm = list(range(self_rank)) + + src_dims = perm.copy() + dst_dims = perm.copy() + + for src, dst in zip(source.tolist(), destination.tolist()): + perm[dst] = src + src_dims[src] = -1 + dst_dims[dst] = -1 + + src_dims = [dim for dim in src_dims if dim != -1] + dst_dims = [dim for dim in dst_dims if dim != -1] + + for src, dst in zip(src_dims, dst_dims): + perm[dst] = src + + return g.op("Transpose", self, perm_i=perm) + + +@_onnx_symbolic("aten::fill") +@symbolic_helper.parse_args("v", "v") +def fill(g: jit_utils.GraphContext, self, value): + scalar_type = _type_utils.JitScalarType.from_value( + self, _type_utils.JitScalarType.FLOAT + ) + return full_like(g, self, value, scalar_type) + + +@_onnx_symbolic("aten::index_add") +def index_add(g: jit_utils.GraphContext, self, dim, index, other, alpha=None): + warnings.warn( + "Warning: ONNX export does not support duplicated values in 'index' field, " + + "this will cause the ONNX model to be incorrect." + ) + + # ONNX does not support "alpha" argument, unlike aten index_add + # See: https://github.com/pytorch/pytorch/pull/65993#issuecomment-953151102 for more context + if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1: + return symbolic_helper._unimplemented("index_add", "alpha != 1", self) + + dim = symbolic_helper._maybe_get_const(dim, "i") + if dim is None: + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting 'index_add_()' function with " + "unknown 'dim' value.", + self, + ) + + self_dim_rank = symbolic_helper._get_tensor_rank(self) + other_dim_rank = symbolic_helper._get_tensor_rank(other) + + if self_dim_rank is None or other_dim_rank is None: + raise errors.SymbolicValueError( + "ONNX export does NOT support exporting 'index_add_()' function while " + "the rank of self tensor or tensor to be added is unknown.", + self, + ) + + if other_dim_rank != self_dim_rank: + delta = self_dim_rank - other_dim_rank + for i in range(delta): + other = symbolic_helper._unsqueeze_helper( + g, other, [symbolic_helper._get_tensor_rank(other)] + ) + + other_dim_size = symbolic_helper._get_tensor_dim_size(other, dim) + self_dim_size = symbolic_helper._get_tensor_dim_size(self, dim) + + if (other_dim_size is not None) and (self_dim_size is not None): + if other_dim_size > self_dim_size: + raise errors.SymbolicValueError( + "ONNX export does not support exporting 'index_add_()' function with " + "duplicated values in 'index' parameter yet.", + self, + ) + + # Construct a new shape. It's almost as same as self except the size of the 'dim' + # dimension is 1, so that we can expand other dimensions as expected. + new_shape_axes = list(range(self_dim_rank)) + new_shape_starts = [0 for i in range(self_dim_rank)] + new_shape_ends = [sys.maxsize if (i != dim) else 1 for i in range(self_dim_rank)] + + new_shape = symbolic_helper._slice_helper( + g, self, axes=new_shape_axes, starts=new_shape_starts, ends=new_shape_ends + ) + other = expand_as(g, other, new_shape) + + for i in range(dim): + index = symbolic_helper._unsqueeze_helper(g, index, [0]) + + for i in range(self_dim_rank - dim - 1): + index = symbolic_helper._unsqueeze_helper( + g, index, [symbolic_helper._get_tensor_rank(index)] + ) + + return scatter_add(g, self, dim, expand_as(g, index, other), other) + + +@_onnx_symbolic("aten::roll") +@symbolic_helper.parse_args("v", "is", "is") +def roll(g: jit_utils.GraphContext, self, shifts, dims): + assert len(shifts) == len(dims) + + result = self + for i in range(len(shifts)): + shapes = [] + shape = symbolic_helper._slice_helper( + g, result, axes=[dims[i]], starts=[-shifts[i]], ends=[sys.maxsize] + ) + shapes.append(shape) + shape = symbolic_helper._slice_helper( + g, result, axes=[dims[i]], starts=[0], ends=[-shifts[i]] + ) + shapes.append(shape) + result = g.op("Concat", *shapes, axis_i=dims[i]) + + return result + + +@_onnx_symbolic("aten::cross") +@symbolic_helper.parse_args("v", "v", "i") +def cross(g: jit_utils.GraphContext, input, other, dim=None): + dim = symbolic_helper._get_dim_for_cross(input, dim) + # If we have two tensors such that + # A = [a, b, c], B = [d, e, f], we permute the tensor such that we have + # After first roll, + # A' = [b, c, a], B' = [f, d, e], so that we calculate (b*f, c*d, a*e) + roll_x_1 = roll(g, input, [2], [dim]) + roll_y_1 = roll(g, other, [1], [dim]) + # After second roll, + # A' = [c, a, b], B' = [e, f, d], so that we calculate (c*e, a*f, b*d) + roll_x_2 = roll(g, input, [1], [dim]) + roll_y_2 = roll(g, other, [2], [dim]) + # cross product is calculated as + # result = [(b*f - c*e), (c*d - a*f), (a*e - b*d)] + return sub(g, mul(g, roll_x_1, roll_y_1), mul(g, roll_x_2, roll_y_2)) + + +@_onnx_symbolic("aten::cdist") +def cdist( + g: jit_utils.GraphContext, + x1, + x2, + p=2.0, + compute_mode="use_mm_for_euclid_dist_if_necessary", +): + # X1.shape = (B * P * D), X2.shape = (B * R * D) + # In order to respect numpy style broadcasting as demonstrated in + # https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md + # we unsqueeze both input tensors + row_size_x1 = symbolic_helper._get_tensor_dim_size(x1, -2) + row_size_x2 = symbolic_helper._get_tensor_dim_size(x2, -2) + assert row_size_x1 is not None + assert row_size_x2 is not None + p_float = symbolic_helper._parse_arg(p, "f") + compute_mode = symbolic_helper._parse_arg(compute_mode, "i") + if p_float == 2.0 and ( + compute_mode == 1 + or (compute_mode is None and row_size_x1 >= 25 and row_size_x2 >= 25) + ): + return _euclidean_dist(g, x1, x2) + rank = symbolic_helper._get_tensor_rank(x1) + assert rank is not None + broadcasted_x1 = symbolic_helper._unsqueeze_helper(g, x1, [rank - 1]) + broadcasted_x2 = symbolic_helper._unsqueeze_helper(g, x2, [rank - 2]) + return pairwise_distance( + g, broadcasted_x1, broadcasted_x2, p, eps=1e-06, keepdim=False + ) + + +def _euclidean_dist(g: jit_utils.GraphContext, x1, x2): + # X1.shape = (B * P * D), X2.shape = (B * R * D) + # using matrix multiplication to accelerate the calculation of + # the euclidean distance + rank = symbolic_helper._get_tensor_rank(x1) + assert rank is not None + x1_norm = symbolic_helper._reducesum_helper( + g, + pow(g, x1, symbolic_helper._generate_wrapped_number(g, 2.0)), + axes_i=[-1], + keepdims_i=True, + ) + x1_pad = ones_like(g, x1_norm) + x2_norm = symbolic_helper._reducesum_helper( + g, + pow(g, x2, symbolic_helper._generate_wrapped_number(g, 2.0)), + axes_i=[-1], + keepdims_i=True, + ) + x2_pad = ones_like(g, x2_norm) + x1_ = g.op( + "Concat", + *[ + mul(g, symbolic_helper._generate_wrapped_number(g, -2.0), x1), + x1_norm, + x1_pad, + ], + axis_i=-1, + ) + x2_ = g.op("Concat", *[x2, x2_pad, x2_norm], axis_i=-1) + result = matmul(g, x1_, transpose(g, x2_, -2, -1)) + dtype = _type_utils.JitScalarType.from_value(result) + min = g.op( + "Cast", symbolic_helper._generate_wrapped_number(g, 0.0), to_i=dtype.onnx_type() + ) + result = symbolic_helper._op_with_optional_float_cast( + g, "Max", result, min, opset_before=12 + ) + result = sqrt(g, result) + return result + + +@_onnx_symbolic("aten::lerp") +def lerp(g: jit_utils.GraphContext, self, end, weight): + # Conditional for better numeric. This has been discussed in + # https://github.com/pytorch/pytorch/pull/18871 + diff = g.op("Sub", end, self) + return where( + g, + g.op("Less", weight, g.op("Constant", value_t=torch.tensor(0.5))), + g.op("Add", self, g.op("Mul", weight, diff)), + g.op( + "Sub", + end, + g.op( + "Mul", + diff, + g.op("Sub", g.op("Constant", value_t=torch.tensor(1.0)), weight), + ), + ), + ) + + +@_onnx_symbolic("aten::broadcast_tensors") +def broadcast_tensors(g: jit_utils.GraphContext, self): + all_tensors = symbolic_helper._unpack_list(self) + t_with_final_shape = zeros_like(g, all_tensors[0]) + + # Add operator supports multidirectional broadcasting. So we leverage this function + # to infer the final shape generated by the broadcast. + for t in all_tensors: + t_with_final_shape = add(g, t_with_final_shape, t) + + t_list = [expand_as(g, t, t_with_final_shape) for t in all_tensors] + return g.op("prim::ListConstruct", *t_list) + + +@_onnx_symbolic("aten::is_pinned") +def is_pinned(g: jit_utils.GraphContext, self, device=None): + # Unused by ONNX. + return None + + +@_onnx_symbolic("prim::ConstantSplit") +def prim_constant_split(g: jit_utils.GraphContext, self, split_size, dim): + size = symbolic_helper._get_tensor_dim_size(self, dim) + if size is None: + return symbolic_helper._unimplemented( + "prim::ConstantSplit", "unknown dimension size", self + ) + splits = [split_size] * (size // split_size) + leftover = size % split_size + if leftover: + splits.append(leftover) + return g.op("Split", self, split_i=splits, axis_i=dim, outputs=len(splits)) + + +# TODO: It would be better to export this as a chunk directly, as this is +# less sensitive to changes in input size. +# TODO: Once we have proper scoping, stop reimplementing chunk, delete this +# method, and use the desugared version +@_onnx_symbolic("prim::ConstantChunk") +def prim_constant_chunk(g: jit_utils.GraphContext, self, chunks, dim): + dim_size = symbolic_helper._get_tensor_dim_size(self, dim) + if dim_size is None: + return symbolic_helper._unimplemented( + "prim::ConstantChunk", "unknown dimension size", self + ) + split_size = (dim_size + chunks - 1) // chunks + return prim_constant_split(g, self, split_size, dim) + + +@_onnx_symbolic("prim::shape") +def prim_shape(g: jit_utils.GraphContext, self): + return g.op("Shape", self) + + +@_onnx_symbolic("prim::max") +def prim_max(g: jit_utils.GraphContext, self, other): + return symbolic_helper._op_with_optional_float_cast( + g, "Max", self, other, opset_before=12 + ) + + +@_onnx_symbolic("prim::min") +def prim_min(g: jit_utils.GraphContext, self, other=None): + if not other: + if symbolic_helper._is_packed_list(self): + self = stack(g, self, g.op("Constant", value_t=torch.tensor([0]))) + return min(g, self) + return min(g, self, other) + + +@_onnx_symbolic("prim::data") +def prim_data(g: jit_utils.GraphContext, self): + return self + + +@_onnx_symbolic("prim::layout") +def prim_layout(g: jit_utils.GraphContext, self): + # Always return 'torch.strided'. Other layout types are not supported by JIT 'TensorType'. + # Layout class defined in 'c10/core/Layout.h'. + return g.op("Constant", value_t=torch.tensor(0)) + + +@_onnx_symbolic("prim::ListConstruct") +def prim_list_construct(g: jit_utils.GraphContext, *inputs, **kwargs): + return None + + +@_onnx_symbolic("prim::ListUnpack") +def prim_list_unpack( + g: jit_utils.GraphContext, *inputs, **kwargs +) -> list[_C.Value] | None: + if len(inputs) == 1 and inputs[0].node().kind() == "prim::ListConstruct": + # Cancel the previous node if it is ListConstruct by returning its inputs + # TODO(justinchuby): Use a public method in the helper module + return symbolic_helper._unpack_list(inputs[0]) + + return None + + +@_onnx_symbolic("prim::TupleConstruct") +def prim_tuple_construct(g: jit_utils.GraphContext, *inputs, **kwargs): + return None + + +@_onnx_symbolic("prim::Uninitialized") +def prim_uninitialized(g: jit_utils.GraphContext, *inputs, **kwargs): + return None + + +# exists to refine the type of the Value +# if x is an optional Tensor, unchecked_cast will cast +# x to Tensor, so the rest of the graph knows that x is a Tensor +# this doesn't do anything in runtime and is a noop in ONNX +@_onnx_symbolic("prim::unchecked_cast") +def prim_unchecked_cast(g: jit_utils.GraphContext, self): + return self + + +@_onnx_symbolic("prim::dtype") +def prim_dtype(g: jit_utils.GraphContext, self): + scalar_type = symbolic_helper._try_get_scalar_type(self) + if scalar_type is None: + scalar_type = _type_utils.JitScalarType.FLOAT + # This node records a torch dtype as int + return g.op("Constant", value_t=torch.tensor(scalar_type)) + + +@_onnx_symbolic("prim::tolist") +def prim_tolist(g: jit_utils.GraphContext, input, dim_val, elem_ty_val): + """tolist is currently supported only for 1D input tensors. + + dim_val and elem_ty_val represent dimension and type annotations + that need to match dimension and type of the input tensor. + """ + dim = symbolic_helper._maybe_get_const(dim_val, "i") + if dim > 1: + return symbolic_helper._unimplemented("prim::tolist", "dim_val > 1", input) + return input + + +# ----------------------------------------------------------------------------- +# Symbolic functions that need extra context +# ----------------------------------------------------------------------------- +@_onnx_symbolic("prim::device") +def prim_device(g: jit_utils.GraphContext, *inputs, **kwargs) -> None: + output_type = g.original_node.output().type() + if isinstance(output_type, _C.DeviceObjType): + return None + + return symbolic_helper._unimplemented( + "prim::device", + f"output type should be 'DeviceObjType', not '{output_type.kind()}'", + g.original_node.output(), + ) + + +@_onnx_symbolic("prim::Loop") +def prim_loop(g: jit_utils.GraphContext, *inputs, **attrs) -> list[_C.Value]: + node = g.original_node + env = g.env + values_in_env = g.values_in_env + params_dict = g.params_dict + + operator_export_type = GLOBALS.operator_export_type + opset_version = GLOBALS.export_onnx_opset_version + + old_blocks = tuple(node.blocks()) + _new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks( + g, "Loop", *inputs, outputs=node.outputsSize(), n_blocks=len(old_blocks) + ) + + for old_block, new_block_context in zip(old_blocks, new_block_contexts): + # Copy input metadata to subblock + # + # prim::Loop(iter, cond, input_1, ..., input_n) + # block0(iter, input_1, ..., input_n) + # + # For `Loop` node, copy metadata for `iter`, `input_1`, ..., `input_n`. + for i, b_in in enumerate(old_block.inputs()): + if i == 0 and i < len(inputs): + b_in.setType(inputs[i].type()) + # For optional block inputs, they may switch between None not-None inside + # the loop body, so if the loop input is not optional, the block input may + # still need to be optional. + if ( + i > 0 + and (i + 1) < len(inputs) + and not isinstance(b_in.type(), _C.OptionalType) + ): + b_in.setType(inputs[i + 1].type()) + torch._C._jit_pass_onnx_block( + old_block, + new_block_context.block, + operator_export_type, + env, + values_in_env, + False, + ) + fixed_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node( + new_node, opset_version + ) + # Run shape type inference for Loop after subblock is converted. + if GLOBALS.onnx_shape_inference: + torch._C._jit_pass_onnx_node_shape_type_inference( + new_node, params_dict, opset_version + ) + return fixed_outputs + + +@_onnx_symbolic("prim::If") +def prim_if(g: jit_utils.GraphContext, *inputs, **attrs) -> list[_C.Value]: + n = g.original_node + block = g.block + env = g.env + values_in_env = g.values_in_env + params_dict = g.params_dict + + operator_export_type = GLOBALS.operator_export_type + opset_version = GLOBALS.export_onnx_opset_version + + static_if = inputs[0].node().kind() == "onnx::Constant" + if static_if: + # Fold static if + # + # The torch IR + # graph(%embedding_matrix.1 : Float(10, 15, strides=[15, 1], requires_grad=0, device=cpu), + # %input.1 : Long(6, strides=[1], requires_grad=0, device=cpu), ... + # %65 : Bool(requires_grad=0, device=cpu) = prim::Constant[value={0}]() + # %21 : Long(device=cpu) = aten::eq(%20, %64) + # %22 : Long(device=cpu) = prim::If(%21) + # block0(): + # %23 : Long(device=cpu) = aten::is_floating_point(%input.1) + # -> (%23) + # block1(): + # -> (%65) + # %input.53 : Tensor, %weight : Tensor = prim::If(%22) + # block0(): + # -> (%embedding_matrix.1, %input.1) + # block1(): + # -> (%input.1, %embedding_matrix.1) + # %26 : int[] = aten::size(%input.53) + # + # The converted ONNX graph + # %10 : Bool(device=cpu) = onnx::Constant[value={0}]() + # %14 : Bool(device=cpu) = onnx::Equal(%13, %8) + # %15 : Bool(requires_grad=0, device=cpu) = onnx::Constant[value={0}]() + # %16 : Long(1, strides=[1], device=cpu) = onnx::Shape(%input.1) + input_flag = symbolic_helper._node_get(inputs[0].node(), "value").tolist() + const_value = ( + all(input_flag) if isinstance(input_flag, list) else bool(input_flag) + ) + block_idx = 0 if const_value else 1 + current_b = list(n.blocks())[block_idx] + env = torch._C._jit_pass_onnx_block( + current_b, + block, + operator_export_type, + env, + values_in_env, + True, + ) + if_output_list = list(n.outputs()) + current_b_list = list(current_b.outputs()) + + final_b_list = [] + for idx in range(len(if_output_list)): + if current_b_list[idx] not in env: + raise errors.SymbolicValueError( + f"The sub block ATen output {current_b_list[idx]} is not in env.", + current_b_list[idx], + ) # type:ignore[operator] + onnx_b = env[current_b_list[idx]] + final_b_list.append(onnx_b) + return final_b_list + else: + old_blocks = tuple(n.blocks()) + _new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks( + g, "If", *inputs, outputs=n.outputsSize(), n_blocks=len(old_blocks) + ) + + for old_block, new_block_context in zip(old_blocks, new_block_contexts): + torch._C._jit_pass_onnx_block( + old_block, + new_block_context.block, + operator_export_type, + env, + values_in_env, + False, + ) + fixed_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node( + new_node, opset_version + ) + # Run shape type inference for If after subblock is converted. + if GLOBALS.onnx_shape_inference: + torch._C._jit_pass_onnx_node_shape_type_inference( + new_node, params_dict, opset_version + ) + return fixed_outputs + + +@_onnx_symbolic("prim::Constant") +def prim_constant(g: jit_utils.GraphContext, *inputs, **attrs): + node = g.original_node + + if node.mustBeNone(): + return None + # This must go before checking for string values, because some device constants + # have string values, but we want to keep them as unconverted Device types so + # that eq() can work on them. + if isinstance(node.output().type(), _C.DeviceObjType): + return None + if node.kindOf("value") == "t": + return g.op("Constant", value_t=symbolic_helper._node_get(node, "value")) + if node.kindOf("value") == "s": + return g.op("Constant", value_s=symbolic_helper._node_get(node, "value")) + if node.output().type().isSubtypeOf( + _C.ListType.ofInts() + ) or node.output().type().isSubtypeOf(_C.ListType.ofFloats()): + return g.op( + "Constant", value_t=torch.tensor(symbolic_helper._node_get(node, "value")) + ) + if node.output().type().isSubtypeOf(_C.ListType.ofStrings()): + str_constants = [ + g.op("Constant", value_s=s) + for s in symbolic_helper._node_get(node, "value") + ] + return g.op("prim::ListConstruct", *str_constants) + + raise errors.SymbolicValueError( + f"Unsupported prim::Constant kind: '{node.kindOf('value')}'. " + f"Please send a bug report at {_constants.PYTORCH_GITHUB_ISSUES_URL}.", + node.output(), + ) + + +@_onnx_symbolic("prim::type") +def prim_type(g: jit_utils.GraphContext, device_value: _C.Value, *args, **kwargs): + if device_value.node().kind() == "prim::device": + device = jit_utils.get_device_from_value(device_value.node().input()) + if device is not None: + return g.op("Constant", value_s=str(device)) + + return symbolic_helper._unimplemented( + "prim::type", + "Device type cannot be statically determined.", + device_value, + ) + + +@_onnx_symbolic("onnx::Placeholder") +def onnx_placeholder(g: jit_utils.GraphContext, *inputs, **attrs): + node = g.original_node + block = g.block + env = g.env + values_in_env = g.values_in_env + + return torch._C._jit_onnx_convert_pattern_from_subblock( + block, node, env, values_in_env + ) + + +@_onnx_symbolic("aten::resolve_conj") +@_onnx_symbolic("aten::resolve_neg") +def noop_complex_operators(g: jit_utils.GraphContext, input: _C.Value): + # ONNX does not have operators to *directly* manipulate real/imaginary components + # However, a few torch APIs (e.g. .tolist()) use complex operations when input is real, + # which results in failures due to missing operators for complex numbers + + # `aten::resolve_conj` and `aten::resolve_neg` can safely be implemented as no-op + return input + + +@_onnx_symbolic("aten::_conj") +@_onnx_symbolic("aten::conj_physical") +def unsupported_complex_operators(g: jit_utils.GraphContext, input: _C.Value): + # ONNX does not have operators to *directly* manipulate real/imaginary components + # However, a few torch APIs (e.g. .tolist()) use complex operations when input is real, + # which results in failures due to missing operators for complex numbers + + # While `aten::_conj` and `aten::conj_physical` raise exception when input is complex + if symbolic_helper.is_complex_value(input): + # FIXME(justinchuby): report correct name for symbolic being executed + return symbolic_helper._onnx_unsupported( + "aten::_conj, aten::conj_physical", + input, + ) + + # they can safely be implemented as no-op for real numbers only + return noop_complex_operators(g, input) + + +@_onnx_symbolic("aten::logit") +def logit(g: jit_utils.GraphContext, self: torch._C.Value, eps: torch._C.Value): + one = g.op("Constant", value_t=torch.tensor(1.0)) + + if not symbolic_helper._is_none(eps): + eps = g.op( + "Cast", eps, to_i=_type_utils.JitScalarType.from_value(self).onnx_type() + ) + one_sub_eps = g.op("Sub", one, eps) + self_less_equal_one_sub_eps = g.op("Greater", one_sub_eps, self) + temporary_self = g.op("Where", self_less_equal_one_sub_eps, self, one_sub_eps) + + temporary_self_less_eps = g.op("Less", temporary_self, eps) + z = g.op("Where", temporary_self_less_eps, eps, temporary_self) + else: + z = self + + sub = g.op("Sub", one, z) + div = g.op("Div", z, sub) + return g.op("Log", div) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 6b1d752bb04ea..4e163f110537e 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -1,3 +1,4 @@ +<<<<<<< HEAD """Backward compatibility module for torch.onnx.utils.""" from __future__ import annotations @@ -6,3 +7,1885 @@ __all__: list[str] = [] from torch.onnx._internal.torchscript_exporter.utils import * # noqa: F401,F403 +======= +# mypy: allow-untyped-defs +"""Functions to export models into the ONNX IR format. + +These models can be loaded with the ONNX library and then +converted to models which run on other deep learning frameworks. +""" + +from __future__ import annotations + +import contextlib +import copy +import inspect +import re +import typing +import warnings +from typing import Any, Callable, cast +from typing_extensions import deprecated + +import torch +import torch._C._onnx as _C_onnx +import torch.jit._trace +import torch.serialization +from torch import _C +from torch.onnx import _constants, errors, symbolic_helper # noqa: F401 +from torch.onnx._globals import GLOBALS +from torch.onnx._internal import jit_utils, onnx_proto_utils, registration + + +if typing.TYPE_CHECKING: + from collections.abc import Collection, Mapping, Sequence + + +__all__ = [ + "select_model_mode_for_export", + "disable_apex_o2_state_dict_hook", + "setup_onnx_logging", + "exporter_context", + "export", + "model_signature", + "warn_on_static_input_change", + "unpack_quantized_tensor", + "unconvertible_ops", + "register_custom_op_symbolic", + "unregister_custom_op_symbolic", +] + + +# TODO(justinchuby): Remove dependency to this global variable from constant_fold.cpp +# Skip check due to cannot import IValue from torch._C +_params_dict = {} # type: ignore[var-annotated] + + +@deprecated("Please set training mode before exporting the model", category=None) +@contextlib.contextmanager +def select_model_mode_for_export(model, mode: _C_onnx.TrainingMode): + """A context manager to temporarily set the training mode of ``model`` + to ``mode``, resetting it when we exit the with-block. + + .. deprecated:: 2.7 + Please set training mode before exporting the model. + + Args: + model: Same type and meaning as ``model`` arg to :func:`export`. + mode: Same type and meaning as ``training`` arg to :func:`export`. + """ + if not isinstance(mode, _C_onnx.TrainingMode): + raise TypeError( + f"'mode' should be a torch.onnx.TrainingMode enum, but got '{type(mode)}'." + ) + originally_training: bool = False + + if hasattr(model, "training"): + originally_training = model.training + + # ONNX opset 12 has better support for training amenable models, with updated + # versions of the dropout and batch_norm operators + if mode == _C_onnx.TrainingMode.TRAINING or ( + mode == _C_onnx.TrainingMode.PRESERVE and originally_training + ): + GLOBALS.export_training = True + if GLOBALS.export_onnx_opset_version < 12: + warnings.warn( + "You are exporting the model in training mode with onnx opset " + f"version {GLOBALS.export_onnx_opset_version}. " + "Opset versions lower than opset 12 will not be able to export " + "nodes such as Dropout and BatchNorm correctly." + ) + else: + GLOBALS.export_training = False + + GLOBALS.training_mode = mode + if mode == _C_onnx.TrainingMode.TRAINING: + model.train(True) + elif mode == _C_onnx.TrainingMode.EVAL: + model.train(False) + # else mode == _C_onnx.TrainingMode.PRESERVE, do nothing + + try: + yield + finally: + if hasattr(model, "training") and not mode == _C_onnx.TrainingMode.PRESERVE: + model.train(originally_training) + + +@deprecated( + "Please remove usage of this function. Copy its logic if it is required in user code", + category=None, +) +@contextlib.contextmanager +def disable_apex_o2_state_dict_hook(model: torch.nn.Module | torch.jit.ScriptFunction): + """A context manager to temporarily disable the Apex O2 hook that returns. + + .. deprecated:: 2.7 + Please remove usage of this function. + """ + # Apex O2 hook state_dict to return fp16 weights as fp32. + # Exporter cannot identify them as same tensors. + # Since this hook is only used by optimizer, it is safe to + # remove this hook while exporting. + if not isinstance(model, torch.jit.ScriptFunction): + model_hooks = {} # type: ignore[var-annotated] + for module in model.modules(): + for key, hook in module._state_dict_hooks.items(): + if type(hook).__name__ == "O2StateDictHook": + if module not in model_hooks: + model_hooks[module] = {} + model_hooks[module][key] = hook + if module in model_hooks: + for key in model_hooks[module]: + module._state_dict_hooks.pop(key) + try: + yield + finally: + # Add the hooks back + for module, m_map in model_hooks.items(): + for key, hook in m_map.items(): + module._state_dict_hooks[key] = hook + else: + try: + yield + finally: + pass + + +@deprecated("The feature will be removed. Please remove usage of this function") +@contextlib.contextmanager +def setup_onnx_logging(verbose: bool): + """A context manager to temporarily set the ONNX logging verbosity. + + .. deprecated:: 2.7 + Please remove usage of this function. + """ + is_originally_enabled = _C._jit_is_onnx_log_enabled + if is_originally_enabled or verbose: # type: ignore[truthy-function] + _C._jit_set_onnx_log_enabled(True) + try: + yield + finally: + if not is_originally_enabled: # type: ignore[truthy-function] + _C._jit_set_onnx_log_enabled(False) + + +@deprecated( + "The feature will be removed. Please remove usage of this function " + "and implement equivalent logic if needed", + category=None, +) +@contextlib.contextmanager +def exporter_context(model, mode: _C_onnx.TrainingMode, verbose: bool): + """A context manager to temporarily set the training mode of ``model`` + to ``mode``, disable the Apex O2 hook, and set the ONNX logging verbosity. + + .. deprecated:: 2.7 + Please set training mode before exporting the model. + """ + with ( + select_model_mode_for_export(model, mode) as mode_ctx, + disable_apex_o2_state_dict_hook(model) as apex_ctx, + setup_onnx_logging(verbose) as log_ctx, + ): + yield (mode_ctx, apex_ctx, log_ctx) + + +def _get_torch_export_args( + args: tuple[Any, ...], + kwargs: dict[str, Any] | None, +) -> tuple[tuple[Any, ...], dict[str, Any] | None]: + """Obtain the arguments for torch.onnx.export from the model and the input arguments.""" + if not kwargs and args and isinstance(args[-1], dict): + kwargs = args[-1] + args = args[:-1] + return args, kwargs + + +def export( + model: torch.nn.Module | torch.jit.ScriptModule | torch.jit.ScriptFunction, + args: tuple[Any, ...] | torch.Tensor, + f: str, + *, + kwargs: dict[str, Any] | None = None, + export_params: bool = True, + verbose: bool = False, + training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL, + input_names: Sequence[str] | None = None, + output_names: Sequence[str] | None = None, + operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX, + opset_version: int | None = None, + do_constant_folding: bool = True, + dynamic_axes: Mapping[str, Mapping[int, str]] + | Mapping[str, Sequence[int]] + | None = None, + keep_initializers_as_inputs: bool | None = None, + custom_opsets: Mapping[str, int] | None = None, + export_modules_as_functions: bool | Collection[type[torch.nn.Module]] = False, + autograd_inlining: bool = True, +) -> None: + r"""Exports a model into ONNX format. + + If ``model`` is not a :class:`torch.jit.ScriptModule` nor a + :class:`torch.jit.ScriptFunction`, this runs + ``model`` once in order to convert it to a TorchScript graph to be exported + (the equivalent of :func:`torch.jit.trace`). Thus this has the same limited support + for dynamic control flow as :func:`torch.jit.trace`. + + Args: + model: The model to be exported. + args: + + args can be structured either as: + + 1. ONLY A TUPLE OF ARGUMENTS:: + + args = (x, y, z) + + The tuple should contain model inputs such that ``model(*args)`` is a valid + invocation of the model. Any non-Tensor arguments will be hard-coded into the + exported model; any Tensor arguments will become inputs of the exported model, + in the order they occur in the tuple. + + 2. A TENSOR:: + + args = torch.Tensor([1]) + + This is equivalent to a 1-ary tuple of that Tensor. + + 3. A TUPLE OF ARGUMENTS ENDING WITH A DICTIONARY OF NAMED ARGUMENTS:: + + args = (x, {"y": input_y, "z": input_z}) + + All but the last element of the tuple will be passed as non-keyword arguments, + and named arguments will be set from the last element. If a named argument is + not present in the dictionary, it is assigned the default value, or None if a + default value is not provided. + + .. warning:: + This behavior will be deprecated in a future release. Please use the + kwargs argument instead. + + .. note:: + If a dictionary is the last element of the args tuple, it will be + interpreted as containing named arguments. In order to pass a dict as the + last non-keyword arg, provide an empty dict as the last element of the args + tuple. For example, instead of:: + + torch.onnx.export( + model, + ( + x, + # WRONG: will be interpreted as named arguments + {y: z}, + ), + "test.onnx.pb", + ) + + Write:: + + torch.onnx.export(model, (x, {y: z}, {}), "test.onnx.pb") + + f: Path to the output ONNX model file. E.g. "model.onnx". + kwargs: Named arguments to the model. + export_params: If True, all parameters will + be exported. Set this to False if you want to export an untrained model. + In this case, the exported model will first take all of its parameters + as arguments, with the ordering as specified by ``model.state_dict().values()`` + verbose: if True, prints a description of the + model being exported to stdout. In addition, the final ONNX graph will include the + field ``doc_string``` from the exported model which mentions the source code locations + for ``model``. If True, ONNX exporter logging will be turned on. + training: + * ``TrainingMode.EVAL``: export the model in inference mode. + * ``TrainingMode.PRESERVE``: export the model in inference mode if model.training is + False and in training mode if model.training is True. + * ``TrainingMode.TRAINING``: export the model in training mode. Disables optimizations + which might interfere with training. + input_names (list of str, default empty list): names to assign to the + input nodes of the graph, in order. + output_names (list of str, default empty list): names to assign to the + output nodes of the graph, in order. + operator_export_type (enum, default OperatorExportTypes.ONNX): + + .. warning:: + This option will be deprecated in a future release. Future exported + graphs will always use the default opset domain. + + * ``OperatorExportTypes.ONNX``: Export all ops as regular ONNX ops + (in the default opset domain). + * ``OperatorExportTypes.ONNX_FALLTHROUGH``: Try to convert all ops + to standard ONNX ops in the default opset domain. If unable to do so + (e.g. because support has not been added to convert a particular torch op to ONNX), + fall back to exporting the op into a custom opset domain without conversion. Applies + to `custom ops `_ + as well as ATen ops. For the exported model to be usable, the runtime must support + these non-standard ops. + * ``OperatorExportTypes.ONNX_ATEN``: All ATen ops (in the TorchScript namespace "aten") + are exported as ATen ops (in opset domain "org.pytorch.aten"). + `ATen `_ is PyTorch's built-in tensor library, so + this instructs the runtime to use PyTorch's implementation of these ops. + + .. warning:: + + Models exported this way are probably runnable only by Caffe2. + + This may be useful if the numeric differences in implementations of operators are + causing large differences in behavior between PyTorch and Caffe2 (which is more + common on untrained models). + + * ``OperatorExportTypes.ONNX_ATEN_FALLBACK``: Try to export each ATen op + (in the TorchScript namespace "aten") as a regular ONNX op. If we are unable to do so + (e.g. because support has not been added to convert a particular torch op to ONNX), + fall back to exporting an ATen op. See documentation on OperatorExportTypes.ONNX_ATEN for + context. + For example:: + + graph(%0 : Float): + %3 : int = prim::Constant[value=0]() + # conversion unsupported + %4 : Float = aten::triu(%0, %3) + # conversion supported + %5 : Float = aten::mul(%4, %0) + return (%5) + + Assuming ``aten::triu`` is not supported in ONNX, this will be exported as:: + + graph(%0 : Float): + %1 : Long() = onnx::Constant[value={0}]() + # not converted + %2 : Float = aten::ATen[operator="triu"](%0, %1) + # converted + %3 : Float = onnx::Mul(%2, %0) + return (%3) + + .. warning:: + + Models exported this way are probably runnable only by Caffe2. + + opset_version (int, default 18): The version of the + `default (ai.onnx) opset `_ + to target. Must be >= 7. + do_constant_folding: Apply the constant-folding optimization. + Constant-folding will replace some of the ops that have all constant inputs + with pre-computed constant nodes. + dynamic_axes: + + By default the exported model will have the shapes of all input and output tensors + set to exactly match those given in ``args``. To specify axes of tensors as + dynamic (i.e. known only at run-time), set ``dynamic_axes`` to a dict with schema: + + * KEY (str): an input or output name. Each name must also be provided in ``input_names`` or + ``output_names``. + * VALUE (dict or list): If a dict, keys are axis indices and values are axis names. If a + list, each element is an axis index. + + For example:: + + class SumModule(torch.nn.Module): + def forward(self, x): + return torch.sum(x, dim=1) + + + torch.onnx.export( + SumModule(), + (torch.ones(2, 2),), + "onnx.pb", + input_names=["x"], + output_names=["sum"], + ) + + Produces:: + + input { + name: "x" + ... + shape { + dim { + dim_value: 2 # axis 0 + } + dim { + dim_value: 2 # axis 1 + ... + output { + name: "sum" + ... + shape { + dim { + dim_value: 2 # axis 0 + ... + + While:: + + torch.onnx.export( + SumModule(), + (torch.ones(2, 2),), + "onnx.pb", + input_names=["x"], + output_names=["sum"], + dynamic_axes={ + # dict value: manually named axes + "x": {0: "my_custom_axis_name"}, + # list value: automatic names + "sum": [0], + }, + ) + + Produces:: + + input { + name: "x" + ... + shape { + dim { + dim_param: "my_custom_axis_name" # axis 0 + } + dim { + dim_value: 2 # axis 1 + ... + output { + name: "sum" + ... + shape { + dim { + dim_param: "sum_dynamic_axes_1" # axis 0 + ... + + keep_initializers_as_inputs: If True, all the + initializers (typically corresponding to parameters) in the + exported graph will also be added as inputs to the graph. If False, + then initializers are not added as inputs to the graph, and only + the non-parameter inputs are added as inputs. + This may allow for better optimizations (e.g. constant folding) by + backends/runtimes. + + If True, `deduplicate_initializers` pass will not be executed. This means + initializers with duplicated values will not be deduplicated and + will be treated as distinct inputs to the graph. This allows different + input initializers to be supplied at the runtime following export. + + If ``opset_version < 9``, initializers MUST be part of graph + inputs and this argument will be ignored and the behavior will be + equivalent to setting this argument to True. + + custom_opsets (dict[str, int], default empty dict): A dict with schema: + + * KEY (str): opset domain name + * VALUE (int): opset version + + If a custom opset is referenced by ``model`` but not mentioned in this dictionary, + the opset version is set to 1. Only custom opset domain name and version should be + indicated through this argument. + + export_modules_as_functions: Flag to enable + exporting all ``nn.Module`` forward calls as local functions in ONNX. Or a set to indicate the + particular types of modules to export as local functions in ONNX. + This feature requires ``opset_version`` >= 15, otherwise the export will fail. This is because + ``opset_version`` < 15 implies IR version < 8, which means no local function support. + Module variables will be exported as function attributes. There are two categories of function + attributes. + + 1. Annotated attributes: class variables that have type annotations via + `PEP 526-style `_ + will be exported as attributes. + Annotated attributes are not used inside the subgraph of ONNX local function because + they are not created by PyTorch JIT tracing, but they may be used by consumers + to determine whether or not to replace the function with a particular fused kernel. + + 2. Inferred attributes: variables that are used by operators inside the module. Attribute names + will have prefix "inferred::". This is to differentiate from predefined attributes retrieved from + python module annotations. Inferred attributes are used inside the subgraph of ONNX local function. + + * ``False`` (default): export ``nn.Module`` forward calls as fine grained nodes. + * ``True``: export all ``nn.Module`` forward calls as local function nodes. + * Set of type of nn.Module: export ``nn.Module`` forward calls as local function nodes, + only if the type of the ``nn.Module`` is found in the set. + + autograd_inlining: Flag used to control whether to inline autograd functions. + Refer to https://github.com/pytorch/pytorch/pull/74765 for more details. + + Raises: + :class:`torch.onnx.errors.CheckerError`: If the ONNX checker detects an invalid ONNX graph. + :class:`torch.onnx.errors.UnsupportedOperatorError`: If the ONNX graph cannot be exported because it + uses an operator that is not supported by the exporter. + :class:`torch.onnx.errors.OnnxExporterError`: Other errors that can occur during export. + All errors are subclasses of :class:`errors.OnnxExporterError`. + """ + if operator_export_type != _C_onnx.OperatorExportTypes.ONNX: + warnings.warn( + "Setting `operator_export_type` to something other than default is deprecated. " + "The option will be removed in a future release.", + category=DeprecationWarning, + ) + if training == _C_onnx.TrainingMode.TRAINING: + warnings.warn( + "Setting `training` to something other than default is deprecated. " + "The option will be removed in a future release. Please set the training mode " + "before exporting the model.", + category=DeprecationWarning, + ) + + args = (args,) if isinstance(args, torch.Tensor) else args + if kwargs is not None: + args = args + (kwargs,) + + _export( + model, + args, + f, + export_params, + verbose, + training, + input_names, + output_names, + operator_export_type=operator_export_type, + opset_version=opset_version, + do_constant_folding=do_constant_folding, + dynamic_axes=dynamic_axes, + keep_initializers_as_inputs=keep_initializers_as_inputs, + custom_opsets=custom_opsets, + export_modules_as_functions=export_modules_as_functions, + autograd_inlining=autograd_inlining, + ) + + return None + + +def _is_constant_tensor_list(node): + if node.kind() != "prim::Constant": + return False + output_type = node.output().type() + if output_type.isSubtypeOf(_C.ListType.ofTensors()): + return True + if output_type.isSubtypeOf(_C.ListType(_C.OptionalType.ofTensor())): + return True + + +# ONNX can't handle constants that are lists of tensors, which can +# get generated in constant prop. So we split them back into prim::ListConstructs + + +def _split_tensor_list_constants(g, block): + for node in block.nodes(): + for subblock in node.blocks(): + _split_tensor_list_constants(g, subblock) + if _is_constant_tensor_list(node): + inputs = [] + for val in node.output().toIValue(): + input = g.insertConstant(val) + input.node().moveBefore(node) + input.node().copyMetadata(node) + inputs.append(input) + + lc = ( + g.create("prim::ListConstruct", inputs) + .insertBefore(node) + .output() + .setType(_C.ListType.ofTensors()) + ) + lc.node().copyMetadata(node) + node.output().replaceAllUsesWith(lc) + + +def _optimize_graph( + graph: _C.Graph, + operator_export_type: _C_onnx.OperatorExportTypes, + _disable_torch_constant_prop: bool = False, + fixed_batch_size: bool = False, + params_dict=None, + dynamic_axes=None, + input_names=None, + module=None, +): + if params_dict is None: + params_dict = {} + + # Inline everything + _C._jit_pass_inline(graph) + + # Remove fork/wait nodes + _C._jit_pass_inline_fork_wait(graph) + _C._jit_pass_lint(graph) + if GLOBALS.autograd_inlining: + _C._jit_pass_onnx_autograd_function_process(graph) + _C._jit_pass_lower_all_tuples(graph) + + # we now record some ops like ones/zeros + # into a trace where we previously recorded constants. + # use constant prop to maintain our current level of onnx support + # without implementing symbolics for all of them + if _disable_torch_constant_prop is False: + _C._jit_pass_constant_propagation(graph) + + _split_tensor_list_constants(graph, graph) + # run dce to eliminate dead parts of the graph that might have been + # left behind by things like symbolic_override + _C._jit_pass_dce(graph) + _C._jit_pass_lint(graph) + + # CSE should improve perf when Autocast is used with disabled cache + # Autocast is disabled due to a limitation on tracer as described at https://github.com/pytorch/pytorch/issues/84092 + # Must run before _C._jit_pass_erase_number_types to prevent type substitution + if _C._jit_pass_cse(graph): + _C._jit_pass_onnx_lint(graph) + + _C._jit_pass_canonicalize_graph_fuser_ops(graph) + _C._jit_pass_lint(graph) + _C._jit_pass_peephole(graph, True) + _C._jit_pass_fuse_addmm(graph) + _C._jit_pass_lint(graph) + + _C._jit_pass_peephole(graph, True) + _C._jit_pass_lower_all_tuples(graph) + # in _jit_pass_onnx, symbolic functions are called for each node for conversion. + # However, there are nodes that cannot be converted without additional context. + # For example, the number of outputs from split (and whether it is static or dynamic) is unknown + # until the point where it is unpacked by listUnpack node. + # This pass does a preprocess, and prepares the nodes such that enough context can be received + # by the symbolic function. + _C._jit_pass_onnx_remove_inplace_ops_for_onnx(graph, module) + _C._jit_pass_onnx_preprocess(graph) + + # onnx does not support tuples, so try to remove them + _C._jit_pass_lint(graph) + + # onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0 + _C._jit_pass_prepare_division_for_onnx(graph) + + _C._jit_pass_onnx_remove_print(graph) + _C._jit_pass_onnx_preprocess_caffe2(graph) + + symbolic_helper._quantized_ops.clear() + # Unpack quantized weights for conv and linear ops and insert into graph. + _C._jit_pass_onnx_unpack_quantized_weights(graph, params_dict) + # onnx only supports tensors, so we turn all out number types into tensors + _C._jit_pass_erase_number_types(graph) + if GLOBALS.onnx_shape_inference: + input_names = [] if input_names is None else input_names + dynamic_axes = {} if dynamic_axes is None else dynamic_axes + _C._jit_pass_onnx_set_dynamic_input_shape(graph, dynamic_axes, input_names) + _C._jit_pass_onnx_lint(graph) + + graph = _C._jit_pass_onnx(graph, operator_export_type) + _C._jit_pass_onnx_lint(graph) + _C._jit_pass_lint(graph) + + _C._jit_pass_onnx_scalar_type_analysis( + graph, True, GLOBALS.export_onnx_opset_version + ) + _C._jit_pass_lint(graph) + + _C._jit_pass_onnx_peephole( + graph, GLOBALS.export_onnx_opset_version, fixed_batch_size + ) + _C._jit_pass_lint(graph) + + # graph is not a valid jit graph anymore because types have been replaced + # (e.g. int with Tensor), so it now contains operators that don't actually + # exist. We can't run normal dead code elimination because it'd fail trying + # to look up if an operator has side effects, but we can run a dead code + # elimination variant that doesn't need to look up if an op has side effects. + _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) + _C._jit_pass_lint(graph) + graph = _C._jit_pass_canonicalize(graph) + _C._jit_pass_lint(graph) + if GLOBALS.onnx_shape_inference: + try: + _C._jit_pass_onnx_graph_shape_type_inference( + graph, params_dict, GLOBALS.export_onnx_opset_version + ) + except RuntimeError: + # NOTE: shape type inference error should not stop the export process + # https://github.com/pytorch/pytorch/issues/132205 + pass + + return graph + + +def warn_on_static_input_change(input_states): + """Warns that changes to input dictionaries and strings won't take effect in the traced ONNX graph. + + We accept dictionaries and strings as ONNX inputs, but they should be only for + configuration use. we detect here if these inputs are modified, and if so we warn + the user that the changes won't take effect in the traced ONNX graph. + """ + for input, traced_input in zip(input_states[0], input_states[1]): + if isinstance(input, dict): + if list(input.keys()) != list(traced_input.keys()): + warning = ( + "We detected that you are modifying a dictionary that is an input to your " + "model. " + "Note that dictionaries are allowed as inputs in ONNX but they should be " + "handled with care. " + "Usages of dictionaries is not recommended, and should not be used except " + "for configuration use. " + "Also note that the order and values of the keys must remain the same. " + ) + warnings.warn(warning) + elif isinstance(input, str): + if input != traced_input: + warning = ( + "The model seems to have string inputs/outputs. " + "Note that strings will not appear as inputs/outputs of the ONNX graph. " + ) + warnings.warn(warning) + + +def _resolve_args_by_export_type(arg_name, arg_value, operator_export_type): + """Resolves the arguments that are ignored when export_type != operator_export_type.ONNX.""" + return arg_value + + +def _decide_keep_init_as_input( + keep_initializers_as_inputs: bool | None, + operator_export_type: _C_onnx.OperatorExportTypes, + opset_version: int, +): + """Decides whether the initializers in the graph should be listed as ONNX graph inputs. + + This method encapsulates the logic to decide whether the initializers in the graph + should be listed as ONNX graph inputs (i.e., whether to choose ONNX IR v3 or v4). + If keep_initializers_as_inputs is not specified (None), then we decide whether to keep + initializers as graph inputs (val_keep_init_as_ip) based on export type. If export type + is ONNX, then do not keep initializers as input (val_keep_init_as_ip=False). For all other + export types keep initializers as input (val_keep_init_as_ip=True). + If keep_initializers_as_inputs is specified, then respect it. Unless opset version <= 8, + in which case it must be ignored because for opset version <= 8, all initializers MUST be + part of graph input (only ONNX IR v3 is allowed), i.e. val_keep_init_as_ip=True. + + Special handling is needed for opset version 8 or lower, because irrespective + of user input for keep_initializers_as_inputs, the graph must follow ONNX IR v3 + semantics, i.e. all initializers must be listed as ONNX graph input. + """ + + if opset_version < 9: + if keep_initializers_as_inputs is False: + warnings.warn( + "Setting 'keep_initializers_as_inputs=False' for opset version" + "8 or lower would lead to an invalid ONNX graph. Therefore, " + "'keep_initializers_as_inputs=False' is ignored during export." + "Exported model will have initializers as graph inputs (compliant " + " to ONNX IR v3)." + ) + return True # i.e. True == initializers are part of graph input (ONNX IR v3) + val_keep_init_as_ip = ( + True if keep_initializers_as_inputs is None else keep_initializers_as_inputs + ) + if ( + keep_initializers_as_inputs is None + and operator_export_type is _C_onnx.OperatorExportTypes.ONNX + ): + val_keep_init_as_ip = False + return val_keep_init_as_ip + + +def _decide_add_node_names(add_node_names, operator_export_type): + return _resolve_args_by_export_type( + "add_node_names", add_node_names, operator_export_type + ) + + +def _decide_constant_folding(do_constant_folding, operator_export_type, training): + do_constant_folding = _resolve_args_by_export_type( + "do_constant_folding", do_constant_folding, operator_export_type + ) + if do_constant_folding and ( + training is not None and training is not _C_onnx.TrainingMode.EVAL + ): + warnings.warn( + "It is recommended that constant folding be turned off ('do_constant_folding=False') " + "when exporting the model in training-amenable mode, i.e. with 'training=TrainingMode.TRAIN' " + "or 'training=TrainingMode.PRESERVE' (when model is in training mode). Otherwise, some " + "learnable model parameters may not translate correctly in the exported ONNX model " + "because constant folding mutates model parameters. Please consider " + "turning off constant folding or setting the training=TrainingMode.EVAL." + ) + return do_constant_folding + + +def _signature(model) -> inspect.Signature: + should_be_callable = getattr(model, "forward", model) + if callable(should_be_callable): + return inspect.signature(should_be_callable) + raise ValueError("model has no forward method and is not callable") + + +def _decide_input_format(model, args): + try: + sig = _signature(model) + except ValueError as e: + warnings.warn(f"{e}, skipping _decide_input_format") + return args + try: + ordered_list_keys = list(sig.parameters.keys()) + if ordered_list_keys[0] == "self": + ordered_list_keys = ordered_list_keys[1:] + args_dict: dict = {} + if isinstance(args, list): + args_list = args + elif isinstance(args, tuple): + args_list = list(args) + else: + args_list = [args] + if isinstance(args_list[-1], dict): + args_dict = args_list[-1] + args_list = args_list[:-1] + n_nonkeyword = len(args_list) + for optional_arg in ordered_list_keys[n_nonkeyword:]: + if optional_arg in args_dict: + args_list.append(args_dict[optional_arg]) + # Check if this arg has a default value + else: + param = sig.parameters[optional_arg] + if param.default != param.empty: + args_list.append(param.default) + args = args_list if isinstance(args, list) else tuple(args_list) + # Cases of models with no input args + except IndexError: + warnings.warn("No input args, skipping _decide_input_format") + except Exception as e: + warnings.warn(f"Skipping _decide_input_format\n {e.args[0]}") + return args + + +def _trace(func, args, operator_export_type, return_outs=False): + # Special case for common case of passing a single Tensor + if isinstance(args, torch.Tensor): + args = (args,) + + trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph( + func, + args, + strict=False, + _force_outplace=False, + _return_inputs_states=True, + ) + warn_on_static_input_change(inputs_states) + + trace_graph = _optimize_graph(trace_graph, operator_export_type, params_dict={}) + if return_outs: + return trace_graph, torch_out + return trace_graph + + +def _trace_and_get_graph_from_model(model, args): + # A basic sanity check: make sure the state_dict keys are the same + # before and after running the model. Fail fast! + orig_state_dict_keys = torch.jit._unique_state_dict(model).keys() + + # Disable Autocast cache because it replaces kernel's weight and bias + # by (undesired) constants. + # No perf impact for when there are reused weights since https://github.com/pytorch/pytorch/pull/85665 + prev_autocast_cache_enabled = torch.is_autocast_cache_enabled() + torch.set_autocast_cache_enabled(False) + trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph( + model, + args, + strict=False, + _force_outplace=False, + _return_inputs_states=True, + ) + torch.set_autocast_cache_enabled(prev_autocast_cache_enabled) + + warn_on_static_input_change(inputs_states) + + if orig_state_dict_keys != torch.jit._unique_state_dict(model).keys(): + raise RuntimeError( + "state_dict changed after running the tracer; " + "something weird is happening in your model!" + ) + + return trace_graph, torch_out + + +def _get_param_count_list(method_graph, args_params): + param_count_list = [] + for input_, arg_params_ in zip(method_graph.inputs(), args_params): + if "PackedParams" in str(input_.type()): + in_vars, _ = torch.jit._flatten(arg_params_) + param_count_list.append(len(in_vars)) + else: + param_count_list.append(arg_params_ is not None) + + return param_count_list + + +def _check_flatten_did_not_remove(original, jit_flattened): + """torch.jit._flatten removes None. Check if it did so in this case.""" + + def flatten(x): + if isinstance(x, (list, tuple)): + for inner in x: + yield from flatten(inner) + elif isinstance(x, dict): + for inner in x.values(): + yield from flatten(inner) + else: + yield x + + flattened_with_none = list(flatten(original)) + num_none = len(flattened_with_none) - len(jit_flattened) + assert num_none >= 0 + if num_none: + raise ValueError( + f"args contained {num_none} None's after flattening. " + "When exporting a ScriptModule or ScriptFunction, no args may " + "be None because that breaks type propagation." + ) + + +def _create_jit_graph( + model: torch.nn.Module | torch.jit.ScriptFunction, args: Sequence[Any] +) -> tuple[_C.Graph, list[_C.IValue], Any | None, _C.ScriptModule | None]: + if isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)): + flattened_args = tuple(torch.jit._flatten(tuple(args))[0]) + _check_flatten_did_not_remove(args, flattened_args) + torch_out = None + + if isinstance(model, torch.jit.ScriptModule): + try: + graph = model.forward.graph # type: ignore[attr-defined] + except AttributeError as e: + raise RuntimeError("'forward' method must be a script method") from e + _C._jit_pass_onnx_function_substitution(graph) + freezed_module = _C._freeze_module( + cast(_C.ScriptModule, model._c), preserveParameters=True + ) + module, params = _C._jit_onnx_list_model_parameters(freezed_module) + method_graph = module._get_method("forward").graph + args_params = tuple(args) + tuple(params) + param_count_list = _get_param_count_list(method_graph, args_params) + in_vars, _ = torch.jit._flatten(args_params) + graph = _C._propagate_and_assign_input_shapes( + method_graph, tuple(in_vars), param_count_list, False, False + ) + return graph, params, torch_out, module + + # torch.jit.ScriptFunction + params = [] + graph = model.graph + _C._jit_pass_onnx_function_substitution(graph) + param_count_list = _get_param_count_list(graph, args) + graph = _C._propagate_and_assign_input_shapes( + graph, flattened_args, param_count_list, False, False + ) + return graph, params, torch_out, None + + graph, torch_out = _trace_and_get_graph_from_model(model, args) + _C._jit_pass_onnx_lint(graph) + state_dict = torch.jit._unique_state_dict(model) + params = list(state_dict.values()) + graph_inputs = list(graph.inputs()) + user_input_num = len(graph_inputs) - len(state_dict) + param_names = list(state_dict.keys()) + for i, inp in enumerate(graph_inputs): + if i >= user_input_num: + inp.setDebugName(param_names[i - user_input_num]) + _C._jit_pass_onnx_function_substitution(graph) + return graph, params, torch_out, None + + +def _get_named_param_dict(graph, params): + input_and_param_names = [val.debugName() for val in graph.inputs()] + param_names = input_and_param_names[len(input_and_param_names) - len(params) :] + _params_dict = dict(zip(param_names, params)) + return _params_dict + + +def _get_example_outputs(model, args): + input_args = copy.deepcopy(args) + input_kwargs = {} + if input_args and isinstance(input_args[-1], dict): + input_kwargs = input_args[-1] + input_args = input_args[:-1] + + example_outputs = model(*input_args, **input_kwargs) + if isinstance(example_outputs, list): + example_outputs = [example_outputs] + elif not isinstance(example_outputs, tuple): + example_outputs = (example_outputs,) + + return example_outputs + + +_qtype_vtype_map = { + torch.quint8: torch.uint8, + torch.qint8: torch.int8, + torch.qint32: torch.int32, + torch.quint4x2: torch.int8, +} + + +def unpack_quantized_tensor(value, cast_onnx_accepted=True): + if isinstance(value, torch.Tensor) and value.dtype in _qtype_vtype_map: + q_value_dequantize = value.dequantize() + q_scale = ( + torch.tensor(value.q_scale(), dtype=torch.double) + if cast_onnx_accepted + else torch.tensor(value.q_scale(), dtype=torch.float32) + ) + q_zero_point = ( + torch.tensor(value.q_zero_point(), dtype=torch.int64) + if cast_onnx_accepted + else torch.tensor(value.q_zero_point(), dtype=_qtype_vtype_map[value.dtype]) + ) + q_value = q_value_dequantize / q_scale + q_zero_point + q_value = q_value.to(dtype=_qtype_vtype_map[value.dtype]) + return q_value, q_scale, q_zero_point + else: + return (value,) + + +def _pre_trace_quant_model(model, args): + r"""Returns `torch.jit.trace(model, args)` if model is quantized. Otherwise do nothing and return + original model. + + This is due to https://github.com/pytorch/pytorch/issues/75761. + """ + if any( + hasattr(m, "_packed_params") for m in getattr(model, "modules", list)() + ) or any(getattr(arg, "is_quantized", False) for arg in args): + return torch.jit.trace(model, args) + return model + + +def _model_to_graph( + model, + args, + verbose=False, + input_names=None, + output_names=None, + operator_export_type=_C_onnx.OperatorExportTypes.ONNX, + do_constant_folding=True, + _disable_torch_constant_prop=False, + fixed_batch_size=False, + training=_C_onnx.TrainingMode.EVAL, + dynamic_axes=None, +) -> tuple[ + _C.Graph, + dict[str, torch.Tensor], + torch.Tensor + | tuple[torch.Tensor, ...] + | list[torch.Tensor] + | dict[str, torch.Tensor] + | Any + | None, +]: + """Converts model into an ONNX graph. + + Returns: + graph: A TorchScript IR Graph with ONNX nodes. + params_dict: Dict from input param name to param value. + torch_out: The output tensors resulting from the trace of ``model``. + If ``model`` is a :class:`torch.jit.ScriptModule` or :class:`torch.jit.ScriptFunction`, + this will be None, since we are not doing any tracing. + """ + # TODO: can we simplify this to always return a tuple of Tensor or None? + + # Special case for common case of passing a single Tensor + if isinstance(args, (torch.Tensor, int, float, bool)): + args = (args,) + + model = _pre_trace_quant_model(model, args) + graph, params, torch_out, module = _create_jit_graph(model, args) + params_dict = _get_named_param_dict(graph, params) + + try: + graph = _optimize_graph( + graph, + operator_export_type, + _disable_torch_constant_prop=_disable_torch_constant_prop, + fixed_batch_size=fixed_batch_size, + params_dict=params_dict, + dynamic_axes=dynamic_axes, + input_names=input_names, + module=module, + ) + except Exception: + _C._jit_onnx_log("Torch IR graph at exception: ", graph) + raise + + is_script = isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)) + if is_script: + example_outputs = _get_example_outputs(model, args) + example_outputs_final = () + for example_output in example_outputs: + example_outputs_final += unpack_quantized_tensor(example_output) + out_vars, desc = torch.jit._flatten(example_outputs_final) + _C._jit_pass_onnx_assign_output_shape( + graph, + out_vars, + desc, + GLOBALS.onnx_shape_inference, + is_script, + GLOBALS.export_onnx_opset_version, + ) + + # NB: ONNX requires complete information about output types, which might be + # erased by some optimizations, so we need to set it explicitly again. + else: + if not isinstance(torch_out, (list, tuple)): + output_wrapped = [torch_out] + else: + output_wrapped = torch_out # type: ignore[assignment] + + output_tensors, out_desc = torch.jit._flatten(tuple(output_wrapped)) + # assign_output_shape pass is not compatible with quantized outputs. + # Quantized outputs are flattened to 3 values in ONNX, while packed as + # single value in PyTorch. + if not any(getattr(out, "is_quantized", False) for out in output_tensors): + _C._jit_pass_onnx_assign_output_shape( + graph, + output_tensors, + out_desc, + GLOBALS.onnx_shape_inference, + is_script, + GLOBALS.export_onnx_opset_version, + ) + + _set_input_and_output_names(graph, input_names, output_names) + params_dict = _get_named_param_dict(graph, params) + + if ( + do_constant_folding + and GLOBALS.export_onnx_opset_version + >= _constants.ONNX_CONSTANT_FOLDING_MIN_OPSET + ): + if training is None or training == _C_onnx.TrainingMode.EVAL: + params_dict = _C._jit_pass_onnx_eval_peephole(graph, params_dict) + + params_dict = _C._jit_pass_onnx_constant_fold( + graph, params_dict, GLOBALS.export_onnx_opset_version + ) + _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) + + if GLOBALS.onnx_shape_inference: + try: + _C._jit_pass_onnx_graph_shape_type_inference( + graph, params_dict, GLOBALS.export_onnx_opset_version + ) + except RuntimeError: + # NOTE: shape type inference error should not stop the export process + # https://github.com/pytorch/pytorch/issues/132205 + pass + + params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict) + + # For ONNX opset < 9, constants only have three data types: float16, float, double. + # In this pass transform constants of other data types to float/double + cast operator. + if GLOBALS.export_onnx_opset_version < 9: + _C._jit_pass_onnx_cast_all_constant_to_floating(graph) + + params_dict = _C._jit_pass_filter_non_tensor_arguments(params_dict) + _C._jit_decay_packed_param_input_types(graph) + + # If output names lack a proper name and are identified only by their unique + # give them a legible name for debugging purposes + _apply_friendly_debug_names(graph, params_dict) + + return graph, params_dict, torch_out + + +@deprecated( + "Unconvertible ops are not definitive. Please remove usage of this function" +) +def unconvertible_ops( + model, + args, + training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL, + opset_version: int | None = None, +) -> tuple[_C.Graph, list[str]]: + """Returns an approximated list of all ops that are yet supported by :mod:`torch.onnx`. + + .. deprecated:: 2.5 + Unconvertible ops are not definitive. Please remove usage of this function. + + The list is approximated because some ops may be removed during the conversion + process and don't need to be converted. Some other ops may have partial support + that will fail conversion with particular inputs. Please open a Github Issue + for op support requests. + + Args: + model: Same as the `model` parameter in :func:`torch.onnx.export`. + args: Same as the `args` parameter in :func:`torch.onnx.export`. + training: Same as the `training` parameter in :func:`torch.onnx.export`. + opset_version: Same as the `opset_version` parameter in :func:`torch.onnx.export`. + + Returns: + The JIT graph and a list of unconvertible ops in the format of "domain::op". + """ + + opset_version = opset_version or _constants.ONNX_DEFAULT_OPSET + GLOBALS.export_onnx_opset_version = opset_version + + try: + with exporter_context(model, training, verbose=False): + # Create a mostly clean JIT graph that contains the plain aten and + # other ops we can check with the symbolic registry. + # NOTE: We don't want to actually convert any ops to ONNX or run any + # symbolic functions because there is a higher chance that a pass + # fails or an unconvertible op messes up the graph during ONNX conversion. + # This way we can always generate a list just by looking at the names + # of the ops in the graph. + args = _decide_input_format(model, args) + model = _pre_trace_quant_model(model, args) + graph, _, _, module = _create_jit_graph(model, args) + _C._jit_pass_inline(graph) + _C._jit_pass_onnx_remove_inplace_ops_for_onnx(graph, module) + _C._jit_pass_erase_number_types(graph) + _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) + except Exception as e: + raise errors.OnnxExporterError( + "Failed to discover unconvertible ops because of errors during the JIT graph " + "generation process." + ) from e + + unsupported_ops = [] + for node in graph.nodes(): + domain_op = node.kind() + if domain_op.startswith(("onnx::", "prim::")): + # We consider onnx and prim ops as supported ops, even though some "prim" + # ops are not implemented as symbolic functions, because they may be + # eliminated in the conversion passes. Users may still see errors caused + # by prim ops even though they don't show up in the list. + continue + if not registration.registry.is_registered_op( + domain_op.rstrip("_"), opset_version + ): + # We consider all registered ops supported, even though some of them are + # only partially supported, because there is not yet a good way to check + # if an op is fully supported. + # TODO(justinchuby): Create a way to check if an op is fully supported. + unsupported_ops.append(domain_op) + return graph, unsupported_ops + + +def _setup_trace_module_map( + model: torch.nn.Module | torch.jit.ScriptModule, + export_modules_as_functions: bool | Collection[type[torch.nn.Module]], +) -> set[str]: + def __register_attribute_hook(): + attr_name = "_onnx_attrs" + + def _track_module_attributes_forward_pre_hook(module, input): + setattr(module, attr_name, _get_module_attributes(module)) + + def _track_module_attributes_forward_hook(module, input, output): + tracing_state = _C._get_tracing_state() + if not tracing_state: + return + + graph = tracing_state.graph() + onnx_attrs = {} + if hasattr(module, attr_name): + onnx_attrs = getattr(module, attr_name) + delattr(module, attr_name) + + _C._jit_pass_onnx_track_scope_attributes(graph, onnx_attrs) + + for m in model.modules(): + m.register_forward_hook(_track_module_attributes_forward_hook) + m.register_forward_pre_hook(_track_module_attributes_forward_pre_hook) + + def _unqualified_variable_name(qualified_name: str) -> str: + """ + Parse qualified variable name and return the unqualified version. + + Pure numeric atoms are considered inadequate, so this function will look past them, + and start from the first non-numeric atom. + + Example: + >>> _unqualified_variable_name("__main__.Foo.bar") + 'bar' + >>> _unqualified_variable_name("__main__.Foo.bar.0") + 'bar.0' + """ + name_atoms = qualified_name.split(".") + for i, atom in reversed(list(enumerate(name_atoms))): + if not atom.isnumeric(): + return ".".join(name_atoms[i:]) + return qualified_name + + trace_module_map = { + _m: torch._C._jit_onnx_create_full_scope_name( + torch.typename(type(_m)), _unqualified_variable_name(_n) + ) + for _n, _m in model.named_modules() + } + torch.jit._trace._trace_module_map = trace_module_map + if isinstance(export_modules_as_functions, bool) and export_modules_as_functions: + module_typenames = {torch.typename(type(module)) for module in trace_module_map} + elif isinstance(export_modules_as_functions, set) and export_modules_as_functions: + + def _find_typename(v): + if isinstance(v, type): + return torch.typename(v) + else: + raise RuntimeError( + "Only type of the `nn.Module` should be " + "passed in the set for argument `export_modules_as_functions`. " + f"Got `{type(v).__name__}`." + ) + + module_typenames = {_find_typename(v) for v in export_modules_as_functions} + else: + module_typenames = set() + + if module_typenames: + __register_attribute_hook() + + return module_typenames + + +def _reset_trace_module_map(): + torch.jit._trace._trace_module_map = None + _C._jit_pass_onnx_clear_scope_records() + + +def _get_module_attributes(module): + annotations = typing.get_type_hints(type(module)) + base_m_annotations = typing.get_type_hints(torch.nn.Module) + [annotations.pop(k, None) for k in base_m_annotations] + # Check whether module attributes can be accessed. Some classes + # define attributes but don't provide access to them in their + # constructor. + # + # For example, torch.nn.Embedding has the `freeze` variable and its + # type specified in the class but the attribute is not created in the + # constructor. In other words, there is no `self.freeze = ` + # in the constructor. + # + # Reference: https://github.com/pytorch/pytorch/blob/92de1d322223fb5584e384971b32c46b93bc2f4b/torch/nn/modules/sparse.py#L120 + attrs = {} + for k in annotations: + try: + attrs[k] = getattr(module, k) + except AttributeError: + _C._jit_onnx_log(f"Skipping module attribute '{k}'") + continue + return attrs + + +def _export( + model, + args, + f, + export_params=True, + verbose=False, + training=_C_onnx.TrainingMode.EVAL, + input_names=None, + output_names=None, + operator_export_type=_C_onnx.OperatorExportTypes.ONNX, + export_type=None, + opset_version=None, + do_constant_folding=True, + dynamic_axes=None, + keep_initializers_as_inputs=None, + fixed_batch_size=False, + custom_opsets=None, + add_node_names=True, + onnx_shape_inference=True, + export_modules_as_functions: Any = False, + autograd_inlining=True, +): + assert GLOBALS.in_onnx_export is False + + if isinstance(model, torch.nn.DataParallel): + raise ValueError( + "torch.nn.DataParallel is not supported by ONNX " + "exporter, please use 'attribute' module to " + "unwrap model from torch.nn.DataParallel. Try " + "torch.onnx.export(model.module, ...)" + ) + + GLOBALS.onnx_shape_inference = onnx_shape_inference + + if opset_version is None: + opset_version = _constants.ONNX_DEFAULT_OPSET + + if opset_version > _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET: + warnings.warn( + f"Exporting to ONNX opset version {opset_version} is not supported. " + f"by 'torch.onnx.export()'. " + f"The highest opset version supported is {_constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET}. " + f"To use a newer opset version, consider 'torch.onnx.export(..., dynamo=True)'. ", + category=errors.OnnxExporterWarning, + ) + + if export_modules_as_functions and opset_version < 15: + raise ValueError( + "`export_modules_as_functions` is not supported for `opset_version` < 15." + "This is because `opset_version` < 15 implies IR version < 8, which means " + "no local function support. " + ) + if not operator_export_type: + operator_export_type = _C_onnx.OperatorExportTypes.ONNX + + # By default, training=TrainingMode.EVAL, + # which is good because running a model in training mode could result in + # internal buffers getting updated, dropout getting applied, etc. + # If you really know what you're doing, you can turn + # training=TrainingMode.TRAINING or training=TrainingMode.PRESERVE, + # (to preserve whatever the original training mode was.) + GLOBALS.export_onnx_opset_version = opset_version + GLOBALS.operator_export_type = operator_export_type + + try: + GLOBALS.in_onnx_export = True + _autograd_inlining_previous = GLOBALS.autograd_inlining + GLOBALS.autograd_inlining = autograd_inlining + + module_typenames_to_export_as_functions: set[str] = set() + if isinstance(model, (torch.nn.Module, torch.jit.ScriptModule)): + module_typenames_to_export_as_functions = _setup_trace_module_map( + model, export_modules_as_functions + ) + + with exporter_context(model, training, verbose): + val_keep_init_as_ip = _decide_keep_init_as_input( + keep_initializers_as_inputs, + operator_export_type, + opset_version, + ) + val_add_node_names = _decide_add_node_names( + add_node_names, operator_export_type + ) + val_do_constant_folding = _decide_constant_folding( + do_constant_folding, operator_export_type, training + ) + # Normally f can be a file-like object, but for large models, the external data format requires a + # valid `model_file_location`. Code in export.cpp will enforce this. + if isinstance(f, str): + model_file_location = f + else: + model_file_location = "" + args = _decide_input_format(model, args) + if dynamic_axes is None: + dynamic_axes = {} + _validate_dynamic_axes(dynamic_axes, model, input_names, output_names) + + graph, params_dict, torch_out = _model_to_graph( + model, + args, + verbose, + input_names, + output_names, + operator_export_type, + val_do_constant_folding, + fixed_batch_size=fixed_batch_size, + training=training, + dynamic_axes=dynamic_axes, + ) + + if custom_opsets is None: + custom_opsets = {} + + _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) + node_attr_to_name = {} # type: ignore[var-annotated] + if module_typenames_to_export_as_functions: + # NOTE: cannot call DCE after this pass. DCE will remove function definition nodes. + node_attr_to_name = _C._jit_pass_onnx_function_extraction( + graph, + module_typenames_to_export_as_functions, + list(params_dict.keys()), + ) + + if keep_initializers_as_inputs is not True: + params_dict = _C._jit_pass_onnx_deduplicate_initializers( # type: ignore[assignment] + graph, + params_dict, # type: ignore[arg-type] + getattr(model, "training", False), # type: ignore[arg-type] + ) + _C._jit_pass_onnx_assign_scoped_names_for_node_and_value(graph) + defer_weight_export = False + if export_params: + ( + proto, + export_map, + _val_use_external_data_format, + _node_names, + ) = graph._export_onnx( # type: ignore[attr-defined] + params_dict, + opset_version, + dynamic_axes, + defer_weight_export, + operator_export_type, + not verbose, + val_keep_init_as_ip, + custom_opsets, + val_add_node_names, + model_file_location, + node_attr_to_name, + ) + else: + ( + proto, + export_map, + _, + _, + ) = graph._export_onnx( # type: ignore[attr-defined] + {}, + opset_version, + dynamic_axes, + defer_weight_export, + operator_export_type, + not verbose, + val_keep_init_as_ip, + custom_opsets, + val_add_node_names, + model_file_location, + node_attr_to_name, + ) + # insert function_proto into model_proto. + proto = onnx_proto_utils._add_onnxscript_fn( + proto, + custom_opsets, + ) + if verbose: + _C._jit_onnx_log("Exported graph: ", graph) + onnx_proto_utils._export_file(proto, f, export_map) + finally: + assert GLOBALS.in_onnx_export + GLOBALS.in_onnx_export = False + GLOBALS.autograd_inlining = _autograd_inlining_previous + _reset_trace_module_map() + + return torch_out + + +def _apply_friendly_debug_names(graph, params): + for n in graph.nodes(): + for v in n.inputs(): + old_name = v.debugName() + if old_name != str(v.unique()): + continue + new_name = f"{n.kind()}_{v.unique()}" + v.setDebugName(new_name) + if old_name in params: + params[new_name] = params.pop(old_name) + + +def _set_input_and_output_names(graph, input_names, output_names): + def set_names(node_list, name_list, descriptor): + if name_list is None: + return + if len(name_list) > len(node_list): + raise RuntimeError( + f"number of {descriptor} names provided ({len(name_list)}) " + f"exceeded number of {descriptor}s ({len(node_list)})" + ) + + # Mark if the output node DebugName is set before. + output_node_set = set() + for i, (name, node) in enumerate(zip(name_list, node_list)): + # Duplicated output node, insert onnx::Identity to avoid setting the same DebugName after setDebugName(). + if descriptor == "output": + if node in output_node_set: + identity_node = graph.create("onnx::Identity") + identity_node.insertAfter(node.node()) + identity_node.addInput(node) + identity_node.output().setType(node.type()) + graph.return_node().replaceInput(i, identity_node.output()) + node = identity_node.output() + output_node_set.add(node) + + if node.debugName() != name: + node.setDebugName(name) + + set_names(list(graph.inputs()), input_names, "input") + set_names(list(graph.outputs()), output_names, "output") + + +def _run_symbolic_method(g, op_name, symbolic_fn, args): + r""" + This trampoline function gets invoked for every symbolic method + call from C++. + """ + try: + graph_context = jit_utils.GraphContext( + graph=g, + block=g.block(), + opset=GLOBALS.export_onnx_opset_version, + original_node=None, # type: ignore[arg-type] + params_dict=_params_dict, + env={}, + values_in_env=set(), + new_nodes=[], + ) + return symbolic_fn(graph_context, *args) + except TypeError as e: + # Handle the specific case where we didn't successfully dispatch + # to symbolic_fn. Otherwise, the backtrace will have the clues + # you need. + e.args = (f"{e.args[0]} (occurred when translating {op_name})",) + raise + + +def _add_block(node: _C.Node) -> _C.Block: + return node.addBlock() + + +def _add_input_to_block(block: _C.Block): + return block.addInputToBlock() # type: ignore[attr-defined] + + +def _add_output_to_block(block: _C.Block, value: _C.Value) -> int: + return block.registerOutput(value) + + +def _should_aten_fallback( + name: str, opset_version: int, operator_export_type: _C_onnx.OperatorExportTypes +): + # For all builds, if domain=="aten" and operator_export_type==ONNX_ATEN, + # an aten::ATen operator is created regardless of symbolics existence + + is_exportable_aten_op = registration.registry.is_registered_op(name, opset_version) + is_onnx_aten_export = operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN + is_aten_fallback_export = ( + operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK + ) + + if not name.startswith("aten::"): + return False + + if is_onnx_aten_export or (is_aten_fallback_export and not is_exportable_aten_op): + return True + + return False + + +def _get_aten_op_overload_name(n: _C.Node) -> str: + # Returns `overload_name` attribute to ATen ops on non-Caffe2 builds + schema = n.schema() + if not schema.startswith("aten::"): + return "" + return _C.parse_schema(schema).overload_name + + +def _run_symbolic_function( + graph: _C.Graph, + block: _C.Block, + node: _C.Node, + inputs: Any, + env: dict[_C.Value, _C.Value], + values_in_env: set[_C.Value], + new_nodes: list[_C.Node], + operator_export_type=_C_onnx.OperatorExportTypes.ONNX, +) -> _C.Value | Sequence[_C.Value | None] | None: + """Runs a symbolic function. + + The function is used in C++ to export the node to ONNX. + + Returns: + A single or a tuple of Values. + None when the node gets cloned as is into the new graph. + """ + + opset_version = GLOBALS.export_onnx_opset_version + + # See Note [Export inplace] + node_kind = node.kind() + if node_kind.endswith("_"): + # Treat relu_ -> relu; add_ -> add etc. + ns_op_name = node_kind[:-1] + else: + ns_op_name = node_kind + + namespace, op_name = jit_utils.parse_node_kind(ns_op_name) + + graph_context = jit_utils.GraphContext( + graph=graph, + block=block, + opset=opset_version, + original_node=node, + params_dict=_params_dict, + env=env, + values_in_env=values_in_env, + new_nodes=new_nodes, + ) + + # Direct ATen export requested + if _should_aten_fallback(ns_op_name, opset_version, operator_export_type): + attrs = { + k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k) + for k in node.attributeNames() + } + outputs = node.outputsSize() + attrs["outputs"] = outputs + return graph_context.aten_op( + op_name, + *inputs, + overload_name=_get_aten_op_overload_name(node), + **attrs, + ) + + try: + domain = namespace + symbolic_function_name = f"{domain}::{op_name}" + + symbolic_function_group = registration.registry.get_function_group( + symbolic_function_name + ) + if symbolic_function_group is not None: + symbolic_fn = symbolic_function_group.get(opset_version) + if symbolic_fn is not None: + # TODO Wrap almost identical attrs assignment or comment the difference. + attrs = { + k: symbolic_helper._node_get(node, k) for k in node.attributeNames() + } + return symbolic_fn(graph_context, *inputs, **attrs) + + attrs = { + k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k) + for k in node.attributeNames() + } + if namespace == "onnx": + # Clone node to trigger ONNX shape inference + return graph_context.op( + op_name, *inputs, **attrs, outputs=node.outputsSize() + ) # type: ignore[attr-defined] + + raise errors.UnsupportedOperatorError( + symbolic_function_name, + opset_version, + symbolic_function_group.get_min_supported() + if symbolic_function_group + else None, + ) + + except RuntimeError: + if operator_export_type == _C_onnx.OperatorExportTypes.ONNX_FALLTHROUGH: + return None + elif operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: + # Emit ATen op for non-Caffe2 builds when `operator_export_type==ONNX_ATEN_FALLBACK` + attrs = { + k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k) + for k in node.attributeNames() + } + return graph_context.aten_op( + op_name, + *inputs, + overload_name=_get_aten_op_overload_name(node), + **attrs, + ) + raise + except TypeError as e: + # Handle the specific case where we didn't successfully dispatch. + # Otherwise, the backtrace will have the clues you need. + e.args = (f"{e.args[0]} \n(Occurred when translating {op_name}).",) + raise + + +def _verify_custom_op_name(symbolic_name: str): + if not re.match(r"^[a-zA-Z0-9-_]+::[a-zA-Z-_]+[a-zA-Z0-9-_]*$", symbolic_name): + raise errors.OnnxExporterError( + f"Failed to register operator {symbolic_name}. " + "The symbolic name must match the format domain::name, " + "and should start with a letter and contain only " + "alphanumerical characters" + ) + + ns, _ = jit_utils.parse_node_kind(symbolic_name) + if ns == "onnx": + raise ValueError( + f"Failed to register operator {symbolic_name}. {ns} domain cannot be modified." + ) + + +def register_custom_op_symbolic( + symbolic_name: str, + symbolic_fn: Callable, + opset_version: int, +): + """Registers a symbolic function for a custom operator. + + When the user registers symbolic for custom/contrib ops, + it is highly recommended to add shape inference for that operator via setType API, + otherwise the exported graph may have incorrect shape inference in some extreme cases. + An example of setType is `test_aten_embedding_2` in `test_operators.py`. + + See "Custom Operators" in the module documentation for an example usage. + + Args: + symbolic_name (str): The name of the custom operator in "::" + format. + symbolic_fn (Callable): A function that takes in the ONNX graph and + the input arguments to the current operator, and returns new + operator nodes to add to the graph. + opset_version (int): The ONNX opset version in which to register. + """ + if symbolic_name.startswith("::"): + symbolic_name = f"aten{symbolic_name}" + + _verify_custom_op_name(symbolic_name) + + registration.custom_onnx_symbolic(symbolic_name, opset_version)(symbolic_fn) + + +def unregister_custom_op_symbolic(symbolic_name: str, opset_version: int): + """Unregisters ``symbolic_name``. + + See "Custom Operators" in the module documentation for an example usage. + + Args: + symbolic_name (str): The name of the custom operator in "::" + format. + opset_version (int): The ONNX opset version in which to unregister. + """ + if symbolic_name.startswith("::"): + symbolic_name = f"aten{symbolic_name}" + + _verify_custom_op_name(symbolic_name) + + registration.registry.unregister(symbolic_name, opset_version) + + +def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names): + """Ensures dynamic axes argument is follows the expected format.""" + if len(dynamic_axes) == 0: + return + + if hasattr(model, "graph"): + # Extracting set of valid input/output names that shall be used for dynamic_axes + if (input_names is None) or len(input_names) == 0: + input_names = [x.debugName() for x in model.graph.inputs()] + if (output_names is None) or len(output_names) == 0: + output_names = [y.debugName() for y in model.graph.outputs()] + + valid_names = set((input_names or []) + (output_names or [])) + + # If dynamic axes are provided as a list rather than dictionary, they should + # first get converted to a dictionary in expected format. If desired axes names + # are not provided for dynamic axes, automatic names shall be generated for + # provided dynamic axes of specified input/output + for key, value in dynamic_axes.items(): + if key not in valid_names: + warnings.warn( + f"Provided key {key} for dynamic axes is not a valid input/output name" + ) + if isinstance(value, list): + warnings.warn( + "No names were found for specified dynamic axes of provided input." + f"Automatically generated names will be applied to each dynamic axes of input {key}" + ) + + value_dict = {} + for i, x in enumerate(value): + if not isinstance(x, int): + raise ValueError( + "The type of axis index is expected to be an integer" + ) + if x in value_dict: + warnings.warn( + f"Duplicate dynamic axis index {x} was provided for input {key}." + ) + else: + value_dict[x] = str(key) + "_dynamic_axes_" + str(i + 1) + dynamic_axes[key] = value_dict + + +def model_signature(model: torch.nn.Module | Callable) -> inspect.Signature: + return inspect.signature( + model.forward if isinstance(model, torch.nn.Module) else model + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/onnx/verification.py b/torch/onnx/verification.py index 70d901acb47a9..eb428dad8c2ab 100644 --- a/torch/onnx/verification.py +++ b/torch/onnx/verification.py @@ -1,12 +1,1886 @@ +<<<<<<< HEAD """A set of tools to verify the correctness of ONNX models.""" __all__ = ["VerificationInfo", "verify_onnx_program"] +======= +# mypy: allow-untyped-defs +"""The ONNX verification module provides a set of tools to verify the correctness of ONNX models.""" + +from __future__ import annotations + + +__all__ = [ + "OnnxBackend", + "VerificationOptions", + "verify", + "check_export_model_diff", + "VerificationInfo", + "verify_onnx_program", + "GraphInfo", + "GraphInfoPrettyPrinter", + "OnnxTestCaseRepro", + "find_mismatch", + "verify_aten_graph", +] + +import contextlib +import copy +import dataclasses +import datetime +import difflib +import enum +import functools +import io +import itertools +import os +import tempfile +import typing_extensions +import warnings +from collections.abc import Collection, Mapping, Sequence +from typing import Any, Callable, Union + +import numpy as np +import numpy.typing as npt + +import torch +import torch._C._onnx as _C_onnx +from torch import _C +from torch.onnx import _constants, _experimental, utils +from torch.onnx._globals import GLOBALS +from torch.onnx._internal import onnx_proto_utils +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.onnx._internal.exporter._verification import ( VerificationInfo, verify_onnx_program, ) +<<<<<<< HEAD VerificationInfo.__module__ = "torch.onnx.verification" verify_onnx_program.__module__ = "torch.onnx.verification" +======= +from torch.types import Number + + +# TODO: Update deprecation messages to recommend the new classes + +VerificationInfo.__module__ = "torch.onnx.verification" +verify_onnx_program.__module__ = "torch.onnx.verification" + +# Everything below are deprecated ############################################## + +_ORT_PROVIDERS = ("CPUExecutionProvider",) + +_NumericType = Union[Number, torch.Tensor, np.ndarray] +_ModelType = Union[torch.nn.Module, torch.jit.ScriptModule] +_InputArgsType = Union[torch.Tensor, tuple[Any, ...]] +_InputKwargsType = Mapping[str, Any] +_OutputsType = Union[Sequence[_NumericType], Sequence] + + +class OnnxBackend(enum.Enum): + """Enum class for ONNX backend used for export verification. + + .. deprecated:: 2.7 + Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned + ``ONNXProgram`` to test the ONNX model. + """ + + REFERENCE = "ONNXReferenceEvaluator" + ONNX_RUNTIME_CPU = "CPUExecutionProvider" + ONNX_RUNTIME_CUDA = "CUDAExecutionProvider" + + +@dataclasses.dataclass +class VerificationOptions: + """Options for ONNX export verification. + + .. deprecated:: 2.7 + Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned + ``ONNXProgram`` to test the ONNX model. + + Attributes: + flatten: If True, unpack nested list/tuple/dict inputs into a flattened list of + Tensors for ONNX. Set this to False if nested structures are to be preserved + for ONNX, which is usually the case with exporting ScriptModules. Default True. + ignore_none: Whether to ignore None type in torch output, which is usually the + case with tracing. Set this to False, if torch output should keep None type, + which is usually the case with exporting ScriptModules. Default to True. + check_shape: Whether to check the shapes between PyTorch and ONNX Runtime outputs + are exactly the same. Set this to False to allow output shape broadcasting. + Default to True. + check_dtype: Whether to check the dtypes between PyTorch and ONNX Runtime outputs + are consistent. Default to True. + backend: ONNX backend for verification. Default to OnnxBackend.ONNX_RUNTIME_CPU. + rtol: relative tolerance in comparison between ONNX and PyTorch outputs. + atol: absolute tolerance in comparison between ONNX and PyTorch outputs. + remained_onnx_input_idx: If provided, only the specified inputs will be passed + to the ONNX model. Supply a list when there are unused inputs in the model. + Since unused inputs will be removed in the exported ONNX model, supplying + all inputs will cause an error on unexpected inputs. This parameter tells + the verifier which inputs to pass into the ONNX model. + acceptable_error_percentage: acceptable percentage of element mismatches in comparison. + It should be a float of value between 0.0 and 1.0. + """ + + flatten: bool = True + ignore_none: bool = True + check_shape: bool = True + check_dtype: bool = True + backend: OnnxBackend = OnnxBackend.ONNX_RUNTIME_CPU + rtol: float = 1e-3 + atol: float = 1e-7 + remained_onnx_input_idx: Sequence[int] | None = None + acceptable_error_percentage: float | None = None + + +def _flatten_tuples(elem): + flattened = [] + for t in elem: + if isinstance(t, tuple): + flattened.extend(_flatten_tuples(t)) + else: + flattened.append(t) + return flattened + + +# TODO(justinchuby): Add type checking by narrowing down the return type when input is None +def _to_numpy(elem) -> list | npt.NDArray: + if isinstance(elem, torch.Tensor): + if elem.requires_grad: + return elem.detach().cpu().numpy() + else: + return elem.cpu().numpy() + elif isinstance(elem, (list, tuple)): + return [_to_numpy(inp) for inp in elem] + elif isinstance(elem, (bool, int, float)): + return np.array(elem) + elif isinstance(elem, dict): + flattened = [] + for k in elem: + flattened.extend([_to_numpy(k), _to_numpy(elem[k])]) + return flattened + return elem + + +def _inline_flatten_list(inputs, res_list) -> list: + for i in inputs: + res_list.append(i) if not isinstance( + i, (list, tuple) + ) else _inline_flatten_list(i, res_list) + return res_list + + +def _unpack_to_numpy(values, cast_onnx_accepted=True) -> list: + value_unpacked = [] + for value in values: + value_unpacked.extend( + utils.unpack_quantized_tensor(value, cast_onnx_accepted=cast_onnx_accepted) + ) + return [_to_numpy(v) for v in value_unpacked] + + +def _run_onnx(onnx_session, inputs) -> _OutputsType: + kw_inputs = {} + if inputs and isinstance(inputs[-1], dict): + kw_inputs = inputs[-1] + inputs = inputs[:-1] + inputs = _unpack_to_numpy(_flatten_tuples(inputs)) + ort_inputs = {} + for input_name, input in kw_inputs.items(): + ort_inputs[input_name] = _to_numpy(input) + inputs = _to_numpy(inputs) + if hasattr(onnx_session, "get_inputs"): + # onnxruntime.InferenceSession + input_names = [i.name for i in onnx_session.get_inputs()] + elif hasattr(onnx_session, "input_names"): + # onnx.reference.ReferenceEvaluator + input_names = onnx_session.input_names + else: + raise ValueError(f"Unknown ONNX backend type: {type(onnx_session)}.") + + for i, input in enumerate(inputs): + if i == len(input_names) or input_names[i] in ort_inputs: + raise ValueError( + f"got too many positional inputs. inputs: {inputs}. kw_inputs: {kw_inputs}. " + f"input names: {input_names}." + ) + ort_inputs[input_names[i]] = input + onnx_outs = onnx_session.run(None, ort_inputs) + return onnx_outs + + +def _ort_session( + model: str | io.BytesIO, ort_providers: Sequence[str] = _ORT_PROVIDERS +): + try: + import onnxruntime # type: ignore[import] + except ImportError as e: + raise ImportError("onnxruntime is required for export verification.") from e + + if ort_providers is None: + ort_providers = _ORT_PROVIDERS + + session_options = onnxruntime.SessionOptions() + # suppress ort warnings. + # 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2. + session_options.log_severity_level = 3 + ort_session = onnxruntime.InferenceSession( + model if isinstance(model, str) else model.getvalue(), + session_options, + providers=ort_providers, + ) + return ort_session + + +def _onnx_reference_evaluator_session(model: str | io.BytesIO): + try: + import onnx + from onnx import reference as onnx_reference # type: ignore[attr-defined] + except ImportError as exc: + raise ImportError("onnx >= 1.13 is required for reference evaluator.") from exc + + proto = ( + onnx.load(model) # type: ignore[attr-defined] + if isinstance(model, str) + else onnx.load_model_from_string(model.getvalue()) # type: ignore[attr-defined] + ) + onnx_session = onnx_reference.ReferenceEvaluator(proto) + return onnx_session + + +def _onnx_backend_session(model: str | io.BytesIO, backend: OnnxBackend): + if backend == OnnxBackend.REFERENCE: + onnx_session = _onnx_reference_evaluator_session(model) + elif backend in {OnnxBackend.ONNX_RUNTIME_CPU, OnnxBackend.ONNX_RUNTIME_CUDA}: + onnx_session = _ort_session(model, (backend.value,)) + else: + raise ValueError(f"Unsupported backend: {backend}") + return onnx_session + + +def _compare_onnx_pytorch_outputs_in_np( + onnx_outs: _OutputsType, + pt_outs: _OutputsType, + options: VerificationOptions, +): + assert len(onnx_outs) == len(pt_outs), ( + f"Number of outputs differ ONNX runtime: ({len(onnx_outs)}) PyTorch: ({len(pt_outs)})" + ) + acceptable_error_percentage = options.acceptable_error_percentage + if acceptable_error_percentage and ( + acceptable_error_percentage > 1.0 or acceptable_error_percentage < 0.0 + ): + raise ValueError( + "If set, acceptable_error_percentage should be between 0.0 and 1.0" + ) + + for ort_out, pt_out in zip(onnx_outs, pt_outs): + try: + # TODO: Remove `check_shape` option once every shape inconsistent issue is addressed. + if not options.check_shape: + # Allow different but broadcastable output shapes. + ort_out, pt_out = np.broadcast_arrays(ort_out, pt_out) + torch.testing.assert_close( + ort_out, + pt_out, + rtol=options.rtol, + atol=options.atol, + check_dtype=options.check_dtype, + equal_nan=True, + ) + except AssertionError as e: + if acceptable_error_percentage: + error_percentage = 1 - np.sum( + np.isclose(ort_out, pt_out, rtol=options.rtol, atol=options.atol) + ) / np.prod(ort_out.shape) + if error_percentage <= acceptable_error_percentage: + warnings.warn( + f"Suppressed AssertionError:\n{e}.\n" + f"Error percentage {error_percentage} " + f"within acceptable range {acceptable_error_percentage}." + ) + continue + if ort_out.dtype == np.uint8 or ort_out.dtype == np.int8: + warnings.warn("ONNX output is quantized") + if pt_out.dtype == np.uint8 or pt_out.dtype == np.int8: + warnings.warn("PyTorch output is quantized") + raise + + +def _compare_onnx_pytorch_outputs( + onnx_outs: _OutputsType, + pt_outs: Any, + options: VerificationOptions, +): + """ + Compare ONNX and PyTorch outputs. + + Args: + onnx_outs: outputs from ONNX backend. + pt_outs: outputs from PyTorch. + options: options for verification. + + Raises: + AssertionError: if outputs from ONNX model and PyTorch model are not + equal up to specified precision. + ValueError: if arguments provided are invalid. + """ + if options.ignore_none: + # torch.jit._flatten filters None type + pt_outs, _ = torch.jit._flatten(pt_outs) + else: + pt_outs = _inline_flatten_list([pt_outs], []) + pt_outs_np = _unpack_to_numpy(pt_outs, cast_onnx_accepted=False) + onnx_outs = _inline_flatten_list(onnx_outs, []) + _compare_onnx_pytorch_outputs_in_np(onnx_outs, pt_outs_np, options) + + +def _prepare_input_for_pytorch(args, kwargs): + """Prepare input for PyTorch model execution. + + Any future changes/formatting to the input before dispatching to the PyTorch + model should be made in this function. + + Args: + args: positional arguments for PyTorch model forward method. + kwargs: keyword arguments for PyTorch model forward method. + + Returns: + args: positional arguments for PyTorch model forward method. + kwargs: keyword arguments for PyTorch model forward method. + """ + if isinstance(args, (torch.Tensor, dict)): + args = (args,) + # In-place operators will update input tensor data as well. + # Thus inputs are replicated before every forward call. + args = copy.deepcopy(args) + if kwargs: + kwargs = copy.deepcopy(kwargs) + else: + kwargs = {} + return args, kwargs + + +def _prepare_input_for_export(args, kwargs): + """Prepare input for ONNX model export. + + Any future changes/formatting to the input before dispatching to the + :func:`torch.onnx.export` api should be made in this function. + + Args: + args: positional arguments for PyTorch model forward method. + kwargs: keyword arguments for PyTorch model forward method. + + Returns: + onnx_inputs: positional arguments for ONNX model export, as `args` in + :func:`torch.onnx.export`. + """ + args, kwargs = _prepare_input_for_pytorch(args, kwargs) + if not kwargs and len(args) > 0 and isinstance(args[-1], dict): + onnx_inputs = args + ({},) + elif kwargs: + onnx_inputs = args + (kwargs,) + else: + onnx_inputs = args + return onnx_inputs + + +def _prepare_input_for_onnx( + args, kwargs, remained_onnx_input_idx: Sequence[int] | None, flatten: bool +): + """Prepare input for ONNX model execution in ONNX backend. + + Any future changes/formatting to the input before dispatching to the ONNX backend + run should be made in this function. + + Args: + args: positional arguments for PyTorch model forward method. + kwargs: keyword arguments for PyTorch model forward method. + remained_onnx_input_idx: indices of inputs to be used for ONNX model execution. + flatten: whether to flatten the input before dispatching to the ONNX model execution. + + Returns: + onnx_inputs: positional arguments for ONNX model execution in ONNX backend. + """ + onnx_inputs = _prepare_input_for_export(args, kwargs) + if flatten: + onnx_inputs, _ = torch.jit._flatten(onnx_inputs) + elif onnx_inputs and onnx_inputs[-1] == {}: + # Handle empty kwargs (normally removed by flatten). + onnx_inputs = onnx_inputs[:-1] + if remained_onnx_input_idx is not None: + return [onnx_inputs[i] for i in remained_onnx_input_idx] + else: + return onnx_inputs + + +def _try_clone_model(model): + """Used for preserving original model in case forward mutates model states.""" + try: + return copy.deepcopy(model) + except Exception: + warnings.warn( + "Failed to clone model. Model state might be mutated during verification." + ) + return model + + +def _compare_onnx_pytorch_model( + pt_model: _ModelType, + onnx_model_f: str | io.BytesIO, + input_args: _InputArgsType, + input_kwargs: _InputKwargsType | None, + additional_test_inputs: Sequence[_InputArgsType] | None, + options: VerificationOptions, +): + """Compare outputs from ONNX model runs with outputs from PyTorch model runs. + + Args: + pt_model: PyTorch model. + onnx_model_f: ONNX model file path or file-like object. + input_args: positional arguments for PyTorch model forward method. + input_kwargs: keyword arguments for PyTorch model forward method. + additional_test_inputs: additional positional arguments for PyTorch model + forward method. + options: options for verification. + + Raises: + AssertionError: if outputs from ONNX model and PyTorch model are not + equal up to specified precision. + """ + onnx_session = _onnx_backend_session(onnx_model_f, options.backend) + + def compare_onnx_pytorch_model_with_input(input_args, input_kwargs): + pt_args, pt_kwargs = _prepare_input_for_pytorch(input_args, input_kwargs) + # TODO: remove this and treat mutating model separately. See #77679 + pt_model_copy = _try_clone_model(pt_model) + pt_outs = pt_model_copy(*pt_args, **pt_kwargs) + + onnx_inputs = _prepare_input_for_onnx( + input_args, input_kwargs, options.remained_onnx_input_idx, options.flatten + ) + + onnx_outs = _run_onnx(onnx_session, onnx_inputs) + + _compare_onnx_pytorch_outputs( + onnx_outs=onnx_outs, + pt_outs=pt_outs, + options=options, + ) + + compare_onnx_pytorch_model_with_input(input_args, input_kwargs) + + if additional_test_inputs: + for test_input_args in additional_test_inputs: + compare_onnx_pytorch_model_with_input(test_input_args, {}) + + +class _GraphDiff: + """A class to represent the difference between two graphs.""" + + def __init__(self, graph_a: _C.Graph, graph_b: _C.Graph): + """Construct a _GraphDiff object. + + Args: + graph_a (_C.Graph): First graph to compare. + graph_b (_C.Graph): Second graph to compare. + """ + self.graph_a = graph_a + self.graph_b = graph_b + + def __str__(self): + """See function :func:`diff_report`.""" + return self.diff_report() + + def _indent(self, lines: str) -> str: + return "\n".join(["\t" + line for line in lines.splitlines()]) + + def diff_report(self) -> str: + """Return a string representation of the graph difference. + + The report shows the first pair of nodes that diverges. It also shows the source + location of the pair of nodes. + + Returns: + graph_diff_report (str): A string representation of the graph difference. + """ + graph_a = self.graph_a + graph_b = self.graph_b + + graph_a_str = str(graph_a) + graph_b_str = str(graph_b) + + if graph_a_str == graph_b_str: + return "" + + graph_diff = difflib.ndiff( + graph_a_str.splitlines(True), graph_b_str.splitlines(True) + ) + graph_diff_report = ["Graph diff:", self._indent("".join(graph_diff))] + + for node_a, node_b in itertools.zip_longest(graph_a.nodes(), graph_b.nodes()): + if str(node_a) != str(node_b): + graph_diff_report.append("First diverging operator:") + node_diff = difflib.ndiff( + str(node_a).splitlines(True), str(node_b).splitlines(True) + ) + source_printout = ["node diff:", self._indent("".join(node_diff))] + + stack_a = node_a.sourceRange() if node_a else None + if stack_a: + source_printout.extend( + ["Former source location:", self._indent(str(stack_a))] + ) + stack_b = node_b.sourceRange() if node_b else None + if stack_b: + source_printout.extend( + ["Latter source location:", self._indent(str(stack_b))] + ) + + graph_diff_report.extend(source_printout) + + break + + return "\n".join(graph_diff_report) + + +def _check_graph_diff( + model: torch.nn.Module | torch.jit.ScriptModule, + test_input_groups: Sequence[tuple[tuple[Any, ...], Mapping[str, Any]]], + export_options: _experimental.ExportOptions, + model_to_graph_func: Callable[ + [ + torch.nn.Module, + tuple[Any, ...], + Mapping[str, Any], + _experimental.ExportOptions, + ], + _C.Graph, + ], +) -> str: + """Check if graph produced by `model_to_graph_func` is the same across `test_input_groups`. + + Args: + model: See :func:`check_export_model_diff`. + test_input_groups: See :func:`check_export_model_diff`. + export_options: See :func:`check_export_model_diff`. + model_to_graph_func: A function to convert a PyTorch model to a JIT IR graph. + + Returns: + graph_diff_report (str): A string representation of the graph difference. + """ + if len(test_input_groups) < 2: + raise ValueError("Need at least two groups of test inputs to compare.") + + ref_jit_graph = None + for args, kwargs in test_input_groups: + jit_graph = model_to_graph_func(model, args, kwargs, export_options) + if ref_jit_graph is None: + ref_jit_graph = jit_graph + continue + + graph_diff_report = _GraphDiff(ref_jit_graph, jit_graph).diff_report() + if graph_diff_report: + return graph_diff_report + return "" + + +def _traced_graph_from_model( + model: torch.nn.Module | torch.jit.ScriptModule, + args: tuple[Any, ...], + kwargs: Mapping[str, Any], + export_options: _experimental.ExportOptions, +) -> _C.Graph: + """As part of the ONNX export steps, create a traced JIT graph from a PyTorch model. + + Args: + model: See :func:`check_export_model_diff`. + args: See :func:`check_export_model_diff`. + kwargs: See :func:`check_export_model_diff`. + export_options: See :func:`check_export_model_diff`. + + Returns: + jit_graph (_C.Graph): A traced JIT graph. + """ + training = export_options.training + verbose = export_options.verbose + + with utils.exporter_context(model, training, verbose): + export_inputs = _prepare_input_for_export(args, kwargs) + model = utils._pre_trace_quant_model(model, export_inputs) + jit_graph, _, _, _ = utils._create_jit_graph(model, export_inputs) + return jit_graph + + +def _onnx_graph_from_model( + model: torch.nn.Module | torch.jit.ScriptModule, + args: tuple[Any, ...], + kwargs: Mapping[str, Any], + export_options: _experimental.ExportOptions, +) -> _C.Graph: + """As part of the ONNX export steps, export an ONNX JIT graph from a PyTorch model. + + Args: + model: See :func:`check_export_model_diff`. + args: See :func:`check_export_model_diff`. + kwargs: See :func:`check_export_model_diff`. + export_options: See :func:`check_export_model_diff`. + + Returns: + onnx_graph (_C.Graph): An ONNX JIT graph. + """ + # TODO: refactor utils.py to remove duplicated code of context setup. See #78834 + opset_version = export_options.opset_version + operator_export_type = export_options.operator_export_type + export_modules_as_functions = export_options.export_modules_as_functions + training = export_options.training + verbose = export_options.verbose + dynamic_axes = export_options.dynamic_axes + input_names = export_options.input_names + output_names = export_options.output_names + + if opset_version is None: + opset_version = _constants.ONNX_DEFAULT_OPSET + + utils._setup_trace_module_map(model, export_modules_as_functions) + + if not operator_export_type: + operator_export_type = _C_onnx.OperatorExportTypes.ONNX + + GLOBALS.export_onnx_opset_version = opset_version + GLOBALS.operator_export_type = operator_export_type + + with utils.exporter_context(model, training, verbose): + do_constant_folding = utils._decide_constant_folding( + export_options.do_constant_folding, operator_export_type, training + ) + + if dynamic_axes is None: + dynamic_axes = {} + utils._validate_dynamic_axes(dynamic_axes, model, input_names, output_names) + + export_inputs = _prepare_input_for_export(args, kwargs) + export_inputs = utils._decide_input_format(model, export_inputs) + onnx_graph, _, _ = utils._model_to_graph( + model, + export_inputs, + verbose, + input_names, + output_names, + operator_export_type, + do_constant_folding, + training=training, + dynamic_axes=dynamic_axes, + ) + + return onnx_graph + + +def _onnx_graph_from_aten_graph( + graph: torch.Graph, + export_options: _experimental.ExportOptions, + params_dict: dict[str, Any] | None = None, +) -> tuple[torch.Graph, dict[str, Any]]: + if params_dict is None: + params_dict = {} + operator_export_type = export_options.operator_export_type + dynamic_axes = export_options.dynamic_axes or {} + input_names = export_options.input_names + training = export_options.training + do_constant_folding = export_options.do_constant_folding + opset_version = export_options.opset_version or _constants.ONNX_DEFAULT_OPSET + + GLOBALS.export_onnx_opset_version = opset_version + GLOBALS.operator_export_type = operator_export_type + + do_constant_folding = utils._decide_constant_folding( + do_constant_folding, operator_export_type, training + ) + + # TODO: Below is doing aten graph to onnx. It should be abstracted as a + # function in torch/onnx/utils.py. + graph = graph.copy() + graph = utils._optimize_graph( + graph, + operator_export_type, + params_dict=params_dict, + dynamic_axes=dynamic_axes, + input_names=input_names, + ) + + if training is None or training == _C_onnx.TrainingMode.EVAL: + params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict) + + if ( + do_constant_folding + and opset_version >= _constants.ONNX_CONSTANT_FOLDING_MIN_OPSET + ): + params_dict = _C._jit_pass_onnx_constant_fold(graph, params_dict, opset_version) + _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) + + if GLOBALS.onnx_shape_inference: + _C._jit_pass_onnx_graph_shape_type_inference(graph, params_dict, opset_version) + + params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict) + + # For ONNX opset < 9, constants only have three data types: float16, float, double. + # In this pass transform constants of other data types to float/double + cast operator. + if opset_version < 9: + _C._jit_pass_onnx_cast_all_constant_to_floating(graph) + + params_dict = _C._jit_pass_filter_non_tensor_arguments(params_dict) + _C._jit_decay_packed_param_input_types(graph) + + _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) + + if export_options.verbose: + print("ONNX graph: ", graph) + + return graph, params_dict + + +def _onnx_proto_from_onnx_graph( + onnx_graph: torch.Graph, + export_options: _experimental.ExportOptions, + params_dict: dict[str, Any], +) -> tuple[bytes, Mapping[str, bytes]]: + opset_version = export_options.opset_version or _constants.ONNX_DEFAULT_OPSET + dynamic_axes = export_options.dynamic_axes or {} + operator_export_type = export_options.operator_export_type + val_keep_init_as_ip = utils._decide_keep_init_as_input( + export_options.keep_initializers_as_inputs, + operator_export_type, + opset_version, + ) + val_add_node_names = utils._decide_add_node_names(True, operator_export_type) + custom_opsets = export_options.custom_opsets or {} + + proto, export_map, _, _ = onnx_graph._export_onnx( # type: ignore[attr-defined] + params_dict, + opset_version, + dynamic_axes, + False, + operator_export_type, + not export_options.verbose, + val_keep_init_as_ip, + custom_opsets, + val_add_node_names, + "", + {}, + ) + + return proto, export_map + + +def check_export_model_diff( + model: torch.nn.Module | torch.jit.ScriptModule, + test_input_groups: Sequence[tuple[tuple[Any, ...], Mapping[str, Any]]], + export_options: _experimental.ExportOptions | None = None, +) -> str: + """Verify exported model discrepancy between different groups of inputs. + + A graph is exported for each group of inputs. The exported graphs are then compared + to each other, and discrepancies of first pair of nodes are reported. This function + first checks the jit graph. If no discrepancies were found, it then checks the onnx + graph. + + Unless otherwise specified, the jit/ONNX graph is expected to be the same, regardless + of the inputs used for exporting. A discrepancy implies the graph exported is + not accurate when run on other groups of inputs, which will typically results in + runtime errors or mismatching output. + + Args: + model (torch.nn.Module or torch.jit.ScriptModule): The model to be exported. + test_input_groups (Sequence[Tuple[Tuple[Any, ...], Mapping[str, Any]]]): A sequence + of input groups to be used to export the model. Each input group is a pair of + (args, kwargs). + export_options (_experimental.ExportOptions, optional): An _experimental.ExportOptions + object that controls the export behavior. + + Returns: + str: A string containing the diff of the exported models. + """ + export_options = ( + _experimental.ExportOptions() if export_options is None else export_options + ) + + jit_diff_report = _check_graph_diff( + model, test_input_groups, export_options, _traced_graph_from_model + ) + if jit_diff_report: + return jit_diff_report + + return _check_graph_diff( + model, test_input_groups, export_options, _onnx_graph_from_model + ) + + +@typing_extensions.deprecated( + "torch.onnx.verification.* is deprecated. Consider using torch.onnx.export(..., dynamo=True) " + "and use ONNXProgram to test the ONNX model", + category=None, +) +def verify( + model: _ModelType, + input_args: _InputArgsType, + input_kwargs: _InputKwargsType | None = None, + do_constant_folding: bool = True, + dynamic_axes: Mapping[str, Mapping[int, str] | Mapping[str, Sequence[int]]] + | None = None, + input_names: Sequence[str] | None = None, + output_names: Sequence[str] | None = None, + training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL, + opset_version: int | None = None, + keep_initializers_as_inputs: bool = True, + verbose: bool = False, + fixed_batch_size: bool = False, + use_external_data: bool = False, + additional_test_inputs: Sequence[_InputArgsType] | None = None, + options: VerificationOptions | None = None, +): + """Verify model export to ONNX against original PyTorch model. + + .. deprecated:: 2.7 + Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned + ``ONNXProgram`` to test the ONNX model. + + Args: + model: See :func:`torch.onnx.export`. + input_args: See :func:`torch.onnx.export`. + input_kwargs: See :func:`torch.onnx.export`. + do_constant_folding: See :func:`torch.onnx.export`. + dynamic_axes: See :func:`torch.onnx.export`. + input_names: See :func:`torch.onnx.export`. + output_names: See :func:`torch.onnx.export`. + training: See :func:`torch.onnx.export`. + opset_version: See :func:`torch.onnx.export`. + keep_initializers_as_inputs: See :func:`torch.onnx.export`. + verbose: See :func:`torch.onnx.export`. + fixed_batch_size: Legacy argument, used only by rnn test cases. + use_external_data: Explicitly specify whether to export the model with external data. + additional_test_inputs: List of tuples. Each tuple is a group of + input arguments to test. Currently only ``*args`` are supported. + options: A VerificationOptions object that controls the verification behavior. + + Raises: + AssertionError: if outputs from ONNX model and PyTorch model are not + equal up to specified precision. + ValueError: if arguments provided are invalid. + """ + if options is None: + options = VerificationOptions() + + if training == torch.onnx.TrainingMode.TRAINING: + model.train() + elif training == torch.onnx.TrainingMode.EVAL: + model.eval() + with torch.no_grad(), contextlib.ExitStack() as stack: + model_f: str | io.BytesIO = io.BytesIO() + if use_external_data: + tmpdir_path = stack.enter_context(tempfile.TemporaryDirectory()) + model_f = os.path.join(tmpdir_path, "model.onnx") + + inputs_for_export = _prepare_input_for_export(input_args, input_kwargs) + + # TODO(#77679): remove this and treat mutating model separately. + model_copy = _try_clone_model(model) + utils._export( + model, + inputs_for_export, + model_f, + opset_version=opset_version, + do_constant_folding=do_constant_folding, + keep_initializers_as_inputs=keep_initializers_as_inputs, + dynamic_axes=dynamic_axes, + input_names=input_names, + output_names=output_names, + fixed_batch_size=fixed_batch_size, + training=training, + verbose=verbose, + ) + + _compare_onnx_pytorch_model( + pt_model=model_copy, + onnx_model_f=model_f, + input_args=input_args, + input_kwargs=input_kwargs, + additional_test_inputs=additional_test_inputs, + options=options, + ) + + +@typing_extensions.deprecated( + "torch.onnx.verification.* is deprecated. Consider using torch.onnx.export(..., dynamo=True) " + "and use ONNXProgram to test the ONNX model" +) +def verify_aten_graph( + graph: torch.Graph, + input_args: tuple[Any, ...], + export_options: _experimental.ExportOptions, + params_dict: dict[str, Any] | None = None, + verification_options: VerificationOptions | None = None, +) -> tuple[AssertionError | None, torch.Graph, _OutputsType, _OutputsType]: + """Verify aten graph export to ONNX against original PyTorch model. + + .. deprecated:: 2.7 + Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned + ``ONNXProgram`` to test the ONNX model. + """ + if verification_options is None: + verification_options = VerificationOptions() + if params_dict is None: + params_dict = {} + + original_jit_graph = graph + graph = graph.copy() + + # Execute aten graph and get reference torch jit outputs. + graph_inputs = list(graph.inputs()) + jit_inputs = tuple([arg for arg in input_args if arg is not None]) + weights = [params_dict[v.debugName()] for v in graph_inputs[len(jit_inputs) :]] + assert all(w is not None for w in weights) + # TODO: Only copy the argument if mutation is detected in Graph. + jit_inputs = copy.deepcopy(jit_inputs) + jit_input_and_parameters = jit_inputs + tuple(weights) + jit_outs = torch._C._jit_interpret_graph(graph, jit_input_and_parameters) # type: ignore[attr-defined] + if not isinstance(jit_outs, (list, tuple)): + jit_outs = [jit_outs] + + # Convert aten graph to onnx graph. + graph, onnx_params_dict = _onnx_graph_from_aten_graph( + graph, export_options, params_dict + ) + + proto, export_map = _onnx_proto_from_onnx_graph( + graph, export_options, onnx_params_dict + ) + model_f: str | io.BytesIO = io.BytesIO() + onnx_proto_utils._export_file(proto, model_f, export_map) + + # NOTE: Verification is unstable. Try catch to emit information for debugging. + try: + # NOTE: Input might be dce'ed, so we need to remove those from the input args. + new_input_names = {v.debugName() for v in graph.inputs()} + new_input_args = [] + for v, arg in zip(original_jit_graph.inputs(), input_args): + if v.debugName() in new_input_names: + new_input_args.append(arg) + input_args = tuple(new_input_args) + + onnx_inputs = _prepare_input_for_onnx( + input_args, + {}, + verification_options.remained_onnx_input_idx, + verification_options.flatten, + ) + + onnx_session = _onnx_backend_session(model_f, verification_options.backend) + onnx_outs = _run_onnx(onnx_session, onnx_inputs) + del onnx_session # To free device memory + + try: + _compare_onnx_pytorch_outputs( + onnx_outs=onnx_outs, + pt_outs=jit_outs, + options=verification_options, + ) + except AssertionError as e: + return e, graph, jit_outs, onnx_outs + + return None, graph, jit_outs, onnx_outs + + except Exception as e: + print("Unexpected error during verification.") + print("jit graph: ", original_jit_graph) + print("onnx graph: ", graph) + raise e + + +class GraphInfoPrettyPrinter: + graph_info: GraphInfo | None + upper_printer: GraphInfoPrettyPrinter | None + lower_printer: GraphInfoPrettyPrinter | None + + graph_str_lambdas: Mapping[int, str] + connector_str_lambdas: Mapping[int, str] + children_str_lambdas: Mapping[int, str] + + def __init__(self, graph_info: GraphInfo | None): + self.graph_info = graph_info + if ( + graph_info is not None + and graph_info.upper_graph_info is not None + and graph_info.lower_graph_info is not None + ): + self.upper_printer = GraphInfoPrettyPrinter(graph_info.upper_graph_info) + self.lower_printer = GraphInfoPrettyPrinter(graph_info.lower_graph_info) + else: + self.upper_printer = None + self.lower_printer = None + + def _total_rows(self) -> int: + if self.graph_info is None: + return 1 + if self.upper_printer and self.lower_printer: + return ( + self.upper_printer._total_rows() + self.lower_printer._total_rows() + 1 + ) + return 2 # Two lines: node count + id. + + def _node_count_segment_str(self) -> str: + if self.graph_info is None: + return "..." + node_count = self.graph_info.essential_node_count() + has_mismatch = self.graph_info.has_mismatch() + error_node_kind = ( + f"({self.graph_info.essential_node_kinds().pop()})" + if node_count == 1 and has_mismatch + else "" + ) + + return f"{node_count} {'X' if has_mismatch else chr(0x2713)} {error_node_kind}" + + def _graph_id_segment_str(self) -> str: + if self.graph_info is None: + return "" + return f"id: {self.graph_info.id}" + + def _max_segment_columns(self) -> int: + return max( + map(len, (self._node_count_segment_str(), self._graph_id_segment_str())) + ) + + def _graph_segment_str_at_line(self, line: int) -> str: + """Get the string representation of the graph segment at the given line.""" + if line == 0: + result_str = self._node_count_segment_str() + result_str += " " * (self._max_segment_columns() - len(result_str)) + return result_str + if line == 1: + result_str = self._graph_id_segment_str() + result_str += " " * (self._max_segment_columns() - len(result_str)) + return result_str + if 0 <= line < self._total_rows(): + return " " * self._max_segment_columns() + return "" + + def _connector_segment_str_at_line(self, line: int) -> str: + """Get the connector segment string at the given line.""" + if self.upper_printer is None and self.lower_printer is None: + return "" + upper_total_rows = self.upper_printer._total_rows() if self.upper_printer else 1 + lower_total_rows = self.lower_printer._total_rows() if self.lower_printer else 1 + if line == 0: + return " __" + elif line < upper_total_rows + 1: + return " | " + elif line == upper_total_rows + 1: + return " |__" + elif line < upper_total_rows + lower_total_rows + 1: + return " " + return "" + + def _children_str_at_line(self, line: int) -> str: + """Get the string representation of the children at the given line. + + Recursively calls `_str_at_line` on children nodes. + """ + if self.upper_printer is None and self.lower_printer is None: + return "" + upper_total_rows = self.upper_printer._total_rows() if self.upper_printer else 1 + lower_total_rows = self.lower_printer._total_rows() if self.lower_printer else 1 + if 0 <= line < upper_total_rows: + return ( + self.upper_printer._str_at_line(line) if self.upper_printer else "..." + ) + elif upper_total_rows < line < upper_total_rows + lower_total_rows + 1: + return ( + self.lower_printer._str_at_line(line - upper_total_rows - 1) + if self.lower_printer + else "..." + ) + return "" + + def _str_at_line(self, line: int) -> str: + """Get the string representation of the graph at the given line.""" + return ( + self._graph_segment_str_at_line(line) + + self._connector_segment_str_at_line(line) + + self._children_str_at_line(line) + ) + + def pretty_print(self): + if self.graph_info is None: + print(None) + return + # Print tree. + print(" Tree: ".center(80, "=")) + total_rows = self._total_rows() + for line in range(total_rows): + print(self._str_at_line(line).rstrip()) + if self.graph_info.has_mismatch(): + # Summarize leaf subgraphs with mismatch. + print(" Mismatch leaf subgraphs: ".center(80, "=")) + print( + [ + graph_info.id + for graph_info in self.graph_info.all_mismatch_leaf_graph_info() + ] + ) + # Summarize node kinds with mismatch. + mismatch_node_kinds: dict[str, int] = {} + for graph_info in self.graph_info.all_mismatch_leaf_graph_info(): + node_kinds = graph_info.essential_node_kinds() + if len(node_kinds) == 1: + node_kind = node_kinds.pop() + mismatch_node_kinds[node_kind] = ( + mismatch_node_kinds.get(node_kind, 0) + 1 + ) + print(" Mismatch node kinds: ".center(80, "=")) + print(mismatch_node_kinds) + else: + print(" No mismatch found. ".center(80, "=")) + + +class OnnxTestCaseRepro: + def __init__(self, repro_dir): + self.repro_dir = repro_dir + self.proto, self.inputs, self.outputs = onnx_proto_utils.load_test_case( + repro_dir + ) + + @classmethod + def create_test_case_repro( + cls, proto: bytes, inputs, outputs, dir: str, name: str | None = None + ): + """Create a repro under "{dir}/test_{name}" for an ONNX test case. + + The test case contains the model and the inputs/outputs data. The directory + structure is as follows: + + dir + \u251c\u2500\u2500 test_ + \u2502 \u251c\u2500\u2500 model.onnx + \u2502 \u2514\u2500\u2500 test_data_set_0 + \u2502 \u251c\u2500\u2500 input_0.pb + \u2502 \u251c\u2500\u2500 input_1.pb + \u2502 \u251c\u2500\u2500 output_0.pb + \u2502 \u2514\u2500\u2500 output_1.pb + + Args: + proto: ONNX model proto. + inputs: Inputs to the model. + outputs: Outputs of the model. + dir: Directory to save the repro. + name: Name of the test case. If not specified, a name based on current time + will be generated. + Returns: + Path to the repro. + """ + if name is None: + name = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f") + return onnx_proto_utils.export_as_test_case( + proto, + _to_numpy(inputs), + _to_numpy(outputs), + name, + dir, + ) + + def validate(self, options: VerificationOptions): + """Run the ONNX test case with options.backend, and compare with the expected outputs. + + Args: + options: Options for validation. + + Raise: + AssertionError: if outputs from options.backend and expected outputs are not + equal up to specified precision. + """ + onnx_session = _onnx_backend_session(io.BytesIO(self.proto), options.backend) + run_outputs = onnx_session.run(None, self.inputs) + if hasattr(onnx_session, "get_outputs"): + output_names = [o.name for o in onnx_session.get_outputs()] + elif hasattr(onnx_session, "output_names"): + output_names = onnx_session.output_names + else: + raise ValueError(f"Unknown onnx session type: {type(onnx_session)}") + expected_outs = [self.outputs[name] for name in output_names] + _compare_onnx_pytorch_outputs_in_np(run_outputs, expected_outs, options) + + +@typing_extensions.deprecated( + "torch.onnx.verification.* is deprecated. Consider using torch.onnx.export(..., dynamo=True) " + "and use ONNXProgram to test the ONNX model" +) +@dataclasses.dataclass +class GraphInfo: + """GraphInfo contains validation information of a TorchScript graph and its converted ONNX graph. + + .. deprecated:: 2.7 + Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned + ``ONNXProgram`` to test the ONNX model. + """ + + graph: torch.Graph + input_args: tuple[Any, ...] + params_dict: dict[str, Any] + export_options: _experimental.ExportOptions = dataclasses.field( + default_factory=_experimental.ExportOptions + ) + mismatch_error: AssertionError | None = dataclasses.field(default=None, init=False) + pt_outs: Sequence[_NumericType] | None = dataclasses.field(default=None, init=False) + upper_graph_info: GraphInfo | None = dataclasses.field(default=None, init=False) + lower_graph_info: GraphInfo | None = dataclasses.field(default=None, init=False) + id: str = dataclasses.field(default="") + _onnx_graph: torch.Graph | None = dataclasses.field(init=False, default=None) + + _EXCLUDED_NODE_KINDS: frozenset[str] = frozenset( + {"prim::Constant", "prim::ListConstruct", "aten::ScalarImplicit"} + ) + + def clear(self): + """Clear states and results of previous verification.""" + self.mismatch_error = None + self.pt_outs = None + self._onnx_graph = None + self.upper_graph_info = None + self.lower_graph_info = None + + def pretty_print_tree(self): + """Pretty print `GraphInfo` tree. + + Each node represents a subgraph, showing the number of nodes in the subgraph and + a check mark if the subgraph has output mismatch between torch and ONNX. + + The id of the subgraph is shown under the node. The `GraphInfo` object for any + subgraph can be retrieved by calling `graph_info.find_partition(id)`. + + Example:: + + ==================================== Tree: ===================================== + 5 X __2 X __1 \u2713 + id: | id: 0 | id: 00 + | | + | |__1 X (aten::relu) + | id: 01 + | + |__3 X __1 \u2713 + id: 1 | id: 10 + | + |__2 X __1 X (aten::relu) + id: 11 | id: 110 + | + |__1 \u2713 + id: 111 + =========================== Mismatch leaf subgraphs: =========================== + ['01', '110'] + ============================= Mismatch node kinds: ============================= + {'aten::relu': 2} + + """ + GraphInfoPrettyPrinter(self).pretty_print() + + def pretty_print_mismatch(self, graph: bool = False): + """Pretty print details of the mismatch between torch and ONNX. + + Args: + graph: If True, print the ATen JIT graph and ONNX graph. + """ + print(f" Mismatch info for graph partition {self.id}: ".center(80, "=")) + if graph: + print(" ATen JIT graph ".center(80, "=")) + # TODO: A more compact graph printer. + # * Drop stride, grad, device information. + # * Show source location on a separate line. + print(self.graph) + if self._onnx_graph is not None: + print(" ONNX graph ".center(80, "=")) + print(self._onnx_graph) + if self.has_mismatch(): + print(" Mismatch error ".center(80, "=")) + print(self.mismatch_error) + else: + print(" No mismatch ".center(80, "=")) + + def has_mismatch(self) -> bool: + """Return True if the subgraph has output mismatch between torch and ONNX.""" + return self.mismatch_error is not None + + def essential_node_count(self) -> int: + """Return the number of nodes in the subgraph excluding those in `_EXCLUDED_NODE_KINDS`.""" + return sum( + 1 for n in self.graph.nodes() if n.kind() not in self._EXCLUDED_NODE_KINDS + ) + + def essential_node_kinds(self) -> set[str]: + """Return the set of node kinds in the subgraph excluding those in `_EXCLUDED_NODE_KINDS`.""" + return { + n.kind() + for n in self.graph.nodes() + if n.kind() not in self._EXCLUDED_NODE_KINDS + } + + def all_mismatch_leaf_graph_info(self) -> list[GraphInfo]: + """Return a list of all leaf `GraphInfo` objects that have mismatch.""" + if not self.has_mismatch(): + return [] + + no_mismatch_children = ( + self.upper_graph_info is None or not self.upper_graph_info.has_mismatch() + ) and ( + self.lower_graph_info is None or not self.lower_graph_info.has_mismatch() + ) + + if no_mismatch_children: + return [self] + + results = [] + if self.upper_graph_info is not None: + results += self.upper_graph_info.all_mismatch_leaf_graph_info() + if self.lower_graph_info is not None: + results += self.lower_graph_info.all_mismatch_leaf_graph_info() + + return results + + def find_partition(self, id: str) -> GraphInfo | None: + """Find the `GraphInfo` object with the given id.""" + if id == self.id: + return self + current_length = len(self.id) + if len(id) > current_length: + if id[current_length] == "0" and self.upper_graph_info is not None: + return self.upper_graph_info.find_partition(id) + elif id[current_length] == "1" and self.lower_graph_info is not None: + return self.lower_graph_info.find_partition(id) + return None + + def export_repro( + self, repro_dir: str | None = None, name: str | None = None + ) -> str: + """Export the subgraph to ONNX along with the input/output data for repro. + + The repro directory will contain the following files:: + + dir + \u251c\u2500\u2500 test_ + \u2502 \u251c\u2500\u2500 model.onnx + \u2502 \u2514\u2500\u2500 test_data_set_0 + \u2502 \u251c\u2500\u2500 input_0.pb + \u2502 \u251c\u2500\u2500 input_1.pb + \u2502 \u251c\u2500\u2500 output_0.pb + \u2502 \u2514\u2500\u2500 output_1.pb + + Args: + repro_dir: The directory to export the repro files to. Defaults to current + working directory if None. + name: An optional name for the test case folder: "test_{name}". + + Returns: + The path to the exported repro directory. + """ + + if repro_dir is None: + repro_dir = os.getcwd() + repro_dir = os.path.join(repro_dir, "onnx_debug") + + onnx_graph, onnx_params_dict = _onnx_graph_from_aten_graph( + self.graph, self.export_options, self.params_dict + ) + + proto, _ = _onnx_proto_from_onnx_graph( + onnx_graph, self.export_options, onnx_params_dict + ) + return OnnxTestCaseRepro.create_test_case_repro( + proto, self.input_args, self.pt_outs, repro_dir, name + ) + + def _graph_partition_pivot(self) -> int: + """Find the pivot index to partition the graph. + + The pivot is the node that splits the graph into two parts. Each part should + have the similar amount of nodes, excluding non essential ops, defined in + `_EXCLUDED_NODE_KINDS`, such as `prim::Constant`. + If the graph has an odd number of nodes, the upper part will have one more node. + If the graph does not have any node that can be partitioned, return -1. + + Returns: + The index of the pivot node. + """ + included_node_indices = [ + i + for i, n in enumerate(self.graph.nodes()) + if n.kind() not in self._EXCLUDED_NODE_KINDS + ] + half_idx = len(included_node_indices) // 2 - 1 + if half_idx >= 0 and len(included_node_indices) > half_idx: + return included_node_indices[half_idx] + 1 + return -1 + + def _partition_upper_graph(self) -> torch.Graph: + pivot = self._graph_partition_pivot() + if pivot == -1: + return torch.Graph() + graph = self.graph.copy() # Copy to not mutate parent graph. + original_outputs = list(graph.outputs()) + + def _process_bridge_value_for_upper( + new_outputs: list[torch.Value], bridge_value: torch.Value + ) -> torch.Value: + # Add bridge values as upper graph outputs. + new_outputs.append(bridge_value) + return bridge_value + + new_outputs: list[torch.Value] = [] + process_bridge_value_for_upper = functools.partial( + _process_bridge_value_for_upper, new_outputs + ) + _, dropped_nodes, complete_upper_nodes_set, _ = self._partition_nodes( + graph, pivot, process_bridge_value_for_upper + ) + + for _ in enumerate(original_outputs): + graph.eraseOutput(0) + for output in new_outputs: + graph.registerOutput(output) + + for node in reversed(dropped_nodes): + node.destroy() + + for i, input in reversed(list(enumerate(list(graph.inputs())))): + if ( + not _has_uses_by_nodes(input, complete_upper_nodes_set) + and input not in new_outputs + ): + try: + graph.eraseInput(i) + except RuntimeError as e: + print(input, graph) + raise e + + return graph + + def _partition_lower_graph(self) -> torch.Graph: + pivot = self._graph_partition_pivot() + if pivot == -1: + return torch.Graph() + graph = self.graph.copy() # Copy to not mutate parent graph. + original_outputs = list(graph.outputs()) + original_inputs = list(graph.inputs()) + + def _process_bridge_value_for_lower( + graph: torch.Graph, bridge_value: torch.Value + ) -> torch.Value: + # Add bridge values as lower graph inputs. + new_input = graph.addInput() + bridge_value.replaceAllUsesWith(new_input) + new_input.copyMetadata(bridge_value) + return new_input + + process_bridge_value_for_lower = functools.partial( + _process_bridge_value_for_lower, graph + ) + + upper_nodes, lower_nodes, _, complete_lower_nodes_set = self._partition_nodes( + graph, pivot, process_bridge_value_for_lower + ) + + new_outputs = [ + output for output in original_outputs if _produced_by(output, lower_nodes) + ] + for _ in enumerate(original_outputs): + graph.eraseOutput(0) + for output in new_outputs: + graph.registerOutput(output) + + for input in original_inputs: + if _has_uses_by_nodes(input, complete_lower_nodes_set): + new_input = graph.addInput() + input.replaceAllUsesWith(new_input) + new_input.copyMetadata(input) + + for node in reversed(upper_nodes): + if node not in complete_lower_nodes_set: + try: + node.destroy() + except RuntimeError as e: + print(node, graph) + raise e + + for _ in original_inputs: + graph.eraseInput(0) + + return graph + + def _partition_node( + self, + node: torch.Node, + complete_upper_nodes_set: set[torch.Node], + complete_lower_nodes_set: set[torch.Node], + original_graph_outputs: set[torch.Value], + covered_bridge_values: set[torch.Value], + process_bridge_value: Callable[[torch.Value], torch.Value], + ): + if node in complete_lower_nodes_set: + return + + if ( + _node_has_uses_by(node, complete_lower_nodes_set) + and node.kind() in self._EXCLUDED_NODE_KINDS + ): + complete_lower_nodes_set.update(_all_nodes([node])) + for input in node.inputs(): + if input in covered_bridge_values: + continue + self._partition_node( + input.node(), + complete_upper_nodes_set, + complete_lower_nodes_set, + original_graph_outputs, + covered_bridge_values, + process_bridge_value, + ) + else: + for output in node.outputs(): + if output in covered_bridge_values: + continue + if ( + _has_uses_by_nodes(output, complete_lower_nodes_set) + or output in original_graph_outputs + ): + covered_bridge_values.add(process_bridge_value(output)) + + def _partition_nodes( + self, + graph: torch.Graph, + pivot: int, + process_bridge_value: Callable[[torch.Value], torch.Value], + ) -> tuple[list[torch.Node], list[torch.Node], set[torch.Node], set[torch.Node]]: + nodes = list(graph.nodes()) + upper_nodes = nodes[:pivot] + lower_nodes = nodes[pivot:] + # `upper_nodes` and `complete_upper_nodes_set` differs in that the latter + # recursively contains nodes in subblock of `upper_nodes`. + # The same applies for `lower_nodes` and `complete_lower_nodes_set`. + # With addition that `complete_lower_nodes_set` will include nodes that + # are determined to be copied from `upper_nodes` to `lower_nodes`. + complete_upper_nodes_set = _all_nodes(upper_nodes) + complete_lower_nodes_set = _all_nodes(lower_nodes) + original_graph_outputs = set(graph.outputs()) + # Bridge values are values produced from upper graph, and consumed + # by lower graph. These values need to be become upper graph outputs + # and lower graph inputs, to bridge the interaction. + # Start with all graph inputs marked as covered. If any graph input is + # needed by lower graph, just keep it in lower graph inputs later. + covered_bridge_values = set(graph.inputs()) + for node in upper_nodes: + self._partition_node( + node, + complete_upper_nodes_set, + complete_lower_nodes_set, + original_graph_outputs, + covered_bridge_values, + process_bridge_value, + ) + return ( + upper_nodes, + lower_nodes, + complete_upper_nodes_set, + complete_lower_nodes_set, + ) + + def _bridge_kwargs(self): + pt_outs = self.pt_outs + graph_outputs = list(self.graph.outputs()) + assert pt_outs is not None + assert len(graph_outputs) == len(pt_outs), ( + f"{len(graph_outputs)} vs {len(pt_outs)}\nGraph: {self.graph}" + ) + return {v.debugName(): o for v, o in zip(graph_outputs, pt_outs)} + + def _args_and_params_for_partition_graph( + self, + graph: torch.Graph, + bridge_kwargs: Mapping[str, _NumericType | Sequence[_NumericType]], + full_kwargs: Mapping[str, torch.Tensor], + full_params: Mapping[str, torch.Tensor], + ): + input_names = [input.debugName() for input in graph.inputs()] + args = tuple(bridge_kwargs[k] for k in input_names if k in bridge_kwargs) + args += tuple(full_kwargs[k] for k in input_names if k in full_kwargs) + params = {k: full_params[k] for k in input_names if k in full_params} + assert len(args) + len(params) == len(input_names), ( + f"{len(args)} + {len(params)} vs {len(input_names)}: {input_names}" + ) + return args, params + + def verify_export( + self, options: VerificationOptions + ) -> tuple[AssertionError | None, torch.Graph, _OutputsType, _OutputsType]: + """ + Verify the export from TorchScript IR graph to ONNX. + + Export the TorchScript IR graph to ONNX, with the inputs, parameters and export + options recorded in this object. Then verify the exported ONNX graph against + the original TorchScript IR graph under the provided verification options. + + Args: + options: The verification options. + + Returns: + error: The AssertionError raised during the verification. Returns None if no + error is raised. + onnx_graph: The exported ONNX graph in TorchScript IR format. + onnx_outs: The outputs from running exported ONNX model under the onnx + backend in `options`. + pt_outs: The outputs from running the TorchScript IR graph. + """ + return verify_aten_graph( + self.graph, + input_args=self.input_args, + params_dict=self.params_dict, + export_options=self.export_options, + verification_options=options, + ) + + def find_mismatch( + self, + options: VerificationOptions | None = None, + ): + """ + Find all mismatches between the TorchScript IR graph and the exported onnx model. + + Binary searches the model graph to find the minimal subgraph that exhibits the + mismatch. A `GraphInfo` object is created for each subgraph, recording the test + inputs and export options, as well as the validation results. + + Args: + options: The verification options. + """ + self.clear() + + if options is None: + options = VerificationOptions() + + if self.export_options.verbose: + print(self.graph) + + if len(list(self.graph.outputs())) == 0: + return + + assert len(self.input_args) + len(self.params_dict) == len( + list(self.graph.inputs()) + ), ( + f"Number of graph inputs({len(list(self.graph.inputs()))}) does not match " + f"the provided tensor arguments({len(self.input_args)} + {len(self.params_dict)})." + ) + + self.mismatch_error, self._onnx_graph, self.pt_outs, _ = self.verify_export( + options + ) + + if self.mismatch_error is None: + # No mismatch found in graph. + return + + if self.essential_node_count() <= 1: + # Reached leaf node, no more partitioning. + return + + full_kwargs = { + k.debugName(): v for k, v in zip(self.graph.inputs(), self.input_args) + } + full_params = self.params_dict + + upper_graph = self._partition_upper_graph() + upper_args, upper_params = self._args_and_params_for_partition_graph( + upper_graph, {}, full_kwargs, full_params + ) + self.upper_graph_info = GraphInfo( + upper_graph, + upper_args, + upper_params, + self.export_options, + id=self.id + "0", + ) + + self.upper_graph_info.find_mismatch(options) + + bridge_kwargs = self.upper_graph_info._bridge_kwargs() + lower_graph = self._partition_lower_graph() + lower_args, lower_params = self._args_and_params_for_partition_graph( + lower_graph, bridge_kwargs, full_kwargs, full_params + ) + self.lower_graph_info = GraphInfo( + lower_graph, + lower_args, + lower_params, + self.export_options, + id=self.id + "1", + ) + + self.lower_graph_info.find_mismatch(options) + + +def _all_nodes(nodes: Collection[torch.Node]) -> set[torch.Node]: + all_nodes = set(nodes) + for n in nodes: + for b in n.blocks(): + all_nodes.update(_all_nodes(list(b.nodes()))) + return all_nodes + + +def _has_uses_by_nodes(value: torch.Value, nodes: Collection[torch.Node]) -> bool: + return any(use.user in nodes for use in value.uses()) + + +def _node_has_uses_by(node: torch.Node, nodes: Collection[torch.Node]) -> bool: + for output in node.outputs(): + if _has_uses_by_nodes(output, nodes): + return True + return False + + +def _produced_by(value: torch.Value, nodes: Collection[torch.Node]) -> bool: + return value.node() in nodes + + +@typing_extensions.deprecated( + "torch.onnx.verification.* is deprecated. Consider using torch.onnx.export(..., dynamo=True) " + "and use ONNXProgram to test the ONNX model" +) +def find_mismatch( + model: torch.nn.Module | torch.jit.ScriptModule, + input_args: tuple[Any, ...], + do_constant_folding: bool = True, + training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL, + opset_version: int | None = None, + keep_initializers_as_inputs: bool = True, + verbose: bool = False, + options: VerificationOptions | None = None, +) -> GraphInfo: + r"""Find all mismatches between the original model and the exported model. + + .. deprecated:: 2.7 + Consider using ``torch.onnx.export(..., dynamo=True)`` and use the returned + ``ONNXProgram`` to test the ONNX model. + + Experimental. The API is subject to change. + + This tool helps debug the mismatch between the original PyTorch model and exported + ONNX model. It binary searches the model graph to find the minimal subgraph that + exhibits the mismatch. + + Args: + model: The model to be exported. + input_args: The input arguments to the model. + do_constant_folding: Same as `do_constant_folding` in :func:`torch.onnx.export`. + training: Same as `training` in :func:`torch.onnx.export`. + opset_version: Same as `opset_version` in :func:`torch.onnx.export`. + keep_initializers_as_inputs: Same as `keep_initializers_as_inputs` in :func:`torch.onnx.export`. + verbose: Same as `verbose` in :func:`torch.onnx.export`. + options: The options for the mismatch verification. + + Returns: + A GraphInfo object that contains the mismatch information. + + Example:: + + >>> import torch + >>> import torch.onnx.verification + >>> torch.manual_seed(0) + >>> opset_version = 15 + >>> # Define a custom symbolic function for aten::relu. + >>> # The custom symbolic function is incorrect, which will result in mismatches. + >>> def incorrect_relu_symbolic_function(g, self): + ... return self + >>> torch.onnx.register_custom_op_symbolic( + ... "aten::relu", + ... incorrect_relu_symbolic_function, + ... opset_version=opset_version, + ... ) + >>> class Model(torch.nn.Module): + ... def __init__(self) -> None: + ... super().__init__() + ... self.layers = torch.nn.Sequential( + ... torch.nn.Linear(3, 4), + ... torch.nn.ReLU(), + ... torch.nn.Linear(4, 5), + ... torch.nn.ReLU(), + ... torch.nn.Linear(5, 6), + ... ) + ... def forward(self, x): + ... return self.layers(x) + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX) + >>> graph_info = torch.onnx.verification.find_mismatch( + ... Model(), + ... (torch.randn(2, 3),), + ... opset_version=opset_version, + ... ) + ===================== Mismatch info for graph partition : ====================== + ================================ Mismatch error ================================ + Tensor-likes are not close! + Mismatched elements: 12 / 12 (100.0%) + Greatest absolute difference: 0.2328854203224182 at index (1, 2) (up to 1e-07 allowed) + Greatest relative difference: 0.699536174352349 at index (1, 3) (up to 0.001 allowed) + ==================================== Tree: ===================================== + 5 X __2 X __1 \u2713 + id: | id: 0 | id: 00 + | | + | |__1 X (aten::relu) + | id: 01 + | + |__3 X __1 \u2713 + id: 1 | id: 10 + | + |__2 X __1 X (aten::relu) + id: 11 | id: 110 + | + |__1 \u2713 + id: 111 + =========================== Mismatch leaf subgraphs: =========================== + ['01', '110'] + ============================= Mismatch node kinds: ============================= + {'aten::relu': 2} + + """ + if options is None: + options = VerificationOptions() + if opset_version is None: + opset_version = _constants.ONNX_DEFAULT_OPSET + """From aten graph, do binary search on graph partition to find operator export discrepancy.""" + # TODO: Copied from utils.py `export` until `_optimize_graph`. + if training == torch.onnx.TrainingMode.TRAINING: + model.train() + elif training == torch.onnx.TrainingMode.EVAL: + model.eval() + with torch.no_grad(): + inputs_for_export = _prepare_input_for_export(input_args, {}) + args = utils._decide_input_format(model, inputs_for_export) + + model = utils._pre_trace_quant_model(model, args) + graph, params, _torch_out, _module = utils._create_jit_graph(model, args) + params_dict = utils._get_named_param_dict(graph, params) + + utils._apply_friendly_debug_names(graph, params_dict) + + graph_info = GraphInfo( + graph, + input_args, + params_dict, + _experimental.ExportOptions( + do_constant_folding=do_constant_folding, + training=training, + opset_version=opset_version, + keep_initializers_as_inputs=keep_initializers_as_inputs, + verbose=verbose, + ), + ) + graph_info.find_mismatch(options) + graph_info.pretty_print_mismatch() + graph_info.pretty_print_tree() + + return graph_info +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/optim/__init__.py b/torch/optim/__init__.py index 1060a6287a8e6..00981e061754c 100644 --- a/torch/optim/__init__.py +++ b/torch/optim/__init__.py @@ -8,7 +8,10 @@ from torch.optim import lr_scheduler as lr_scheduler, swa_utils as swa_utils from torch.optim._adafactor import Adafactor as Adafactor +<<<<<<< HEAD from torch.optim._muon import Muon as Muon +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.optim.adadelta import Adadelta as Adadelta from torch.optim.adagrad import Adagrad as Adagrad from torch.optim.adam import Adam as Adam @@ -26,7 +29,10 @@ Adafactor.__module__ = "torch.optim" +<<<<<<< HEAD Muon.__module__ = "torch.optim" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) del adadelta # type: ignore[name-defined] # noqa: F821 @@ -54,7 +60,10 @@ "ASGD", "LBFGS", "lr_scheduler", +<<<<<<< HEAD "Muon", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "NAdam", "Optimizer", "RAdam", diff --git a/torch/optim/_adafactor.py b/torch/optim/_adafactor.py index fc6225d52272b..bd2a3988f40dc 100644 --- a/torch/optim/_adafactor.py +++ b/torch/optim/_adafactor.py @@ -47,6 +47,7 @@ def __init__( raise ValueError(f"Clipping threshold d should be >= 1 but is: {d}") if not 0.0 <= weight_decay: raise ValueError(f"weight_decay should be >= 0 but is: {weight_decay}") +<<<<<<< HEAD defaults = { "lr": lr, "beta2_decay": beta2_decay, @@ -56,6 +57,17 @@ def __init__( "foreach": foreach, "maximize": maximize, } +======= + defaults = dict( + lr=lr, + beta2_decay=beta2_decay, + eps=eps, + d=d, + weight_decay=weight_decay, + foreach=foreach, + maximize=maximize, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__(params, defaults) def __setstate__(self, state): @@ -315,9 +327,12 @@ def step(self, closure=None): &\hspace{5mm}U_t \leftarrow \frac{G_t}{\sqrt{\widehat{V}_t}} \\ \end{aligned} +<<<<<<< HEAD You may note that Noam Shazeer and Mitchell Stern describe using the sum of squared gradients, while this implementation uses the mean instead. This choice is mathematically equivalent and allows for greater numerical stability for large sums. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) .. _Adafactor\: Adaptive Learning Rates with Sublinear Memory Cost: https://arxiv.org/pdf/1804.04235 diff --git a/torch/optim/adadelta.py b/torch/optim/adadelta.py index 49c1dd0df7713..1aa4739a2f69c 100644 --- a/torch/optim/adadelta.py +++ b/torch/optim/adadelta.py @@ -50,6 +50,7 @@ def __init__( if not 0.0 <= weight_decay: raise ValueError(f"Invalid weight_decay value: {weight_decay}") +<<<<<<< HEAD defaults = { "lr": lr, "rho": rho, @@ -60,6 +61,18 @@ def __init__( "foreach": foreach, "differentiable": differentiable, } +======= + defaults = dict( + lr=lr, + rho=rho, + eps=eps, + weight_decay=weight_decay, + maximize=maximize, + capturable=capturable, + foreach=foreach, + differentiable=differentiable, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__(params, defaults) def __setstate__(self, state): @@ -372,7 +385,11 @@ def _multi_tensor_adadelta( device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment] if weight_decay != 0: +<<<<<<< HEAD # Reuse the intermediate memory (device_grads) already allocated for maximize +======= + # Re-use the intermediate memory (device_grads) already allocated for maximize +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if maximize: torch._foreach_add_(device_grads, device_params, alpha=weight_decay) else: diff --git a/torch/optim/adagrad.py b/torch/optim/adagrad.py index 00b3c9c28774f..d4e5836619eae 100644 --- a/torch/optim/adagrad.py +++ b/torch/optim/adagrad.py @@ -54,6 +54,7 @@ def __init__( if not 0.0 <= eps: raise ValueError(f"Invalid epsilon value: {eps}") +<<<<<<< HEAD defaults = { "lr": lr, "lr_decay": lr_decay, @@ -65,6 +66,19 @@ def __init__( "differentiable": differentiable, "fused": fused, } +======= + defaults = dict( + lr=lr, + lr_decay=lr_decay, + eps=eps, + weight_decay=weight_decay, + initial_accumulator_value=initial_accumulator_value, + foreach=foreach, + maximize=maximize, + differentiable=differentiable, + fused=fused, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__(params, defaults) if fused: @@ -117,7 +131,10 @@ def __setstate__(self, state): ) def share_memory(self): +<<<<<<< HEAD """Calls tensor.share_memory_() on the state sum tensors.""" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for group in self.param_groups: for p in group["params"]: state = self.state[p] @@ -467,7 +484,11 @@ def _multi_tensor_adagrad( torch._foreach_add_(device_state_steps, 1) if weight_decay != 0: +<<<<<<< HEAD # Reuse the intermediate memory (device_grads) already allocated for maximize +======= + # Re-use the intermediate memory (device_grads) already allocated for maximize +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if maximize: torch._foreach_add_(device_grads, device_params, alpha=weight_decay) else: @@ -485,7 +506,11 @@ def _multi_tensor_adagrad( torch._foreach_add_(std, eps) if weight_decay != 0 or maximize: +<<<<<<< HEAD # Again, reuse the intermediate memory (device_grads) already allocated +======= + # Again, re-use the intermediate memory (device_grads) already allocated +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch._foreach_mul_(device_grads, minus_clr) numerator = device_grads else: diff --git a/torch/optim/adam.py b/torch/optim/adam.py index 8bbccfb0bc117..f91cc3f7ce372 100644 --- a/torch/optim/adam.py +++ b/torch/optim/adam.py @@ -85,6 +85,7 @@ def __init__( if betas[1].numel() != 1: raise ValueError("Tensor betas[1] must be 1-element") +<<<<<<< HEAD defaults = { "lr": lr, "betas": betas, @@ -98,6 +99,21 @@ def __init__( "fused": fused, "decoupled_weight_decay": decoupled_weight_decay, } +======= + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad, + maximize=maximize, + foreach=foreach, + capturable=capturable, + differentiable=differentiable, + fused=fused, + decoupled_weight_decay=decoupled_weight_decay, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__(params, defaults) if fused: @@ -459,11 +475,17 @@ def _single_tensor_adam( # expavg.lerp(grad^2, 1-beta2) exp_avg_sq.lerp_(torch.square(grad), weight=1 - beta2) else: +<<<<<<< HEAD exp_avg_sq.mul_(beta2).addcmul_( grad, grad, value=cast(float, 1 - beta2) ) else: exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # type: ignore[arg-type] +======= + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + else: + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if capturable or differentiable: step = step_t @@ -534,7 +556,11 @@ def _single_tensor_adam( else: denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) +<<<<<<< HEAD param.addcdiv_(exp_avg, denom, value=-step_size) # type: ignore[arg-type] +======= + param.addcdiv_(exp_avg, denom, value=-step_size) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Lastly, switch back to complex view if amsgrad and torch.is_complex(params[i]): @@ -677,7 +703,11 @@ def _multi_tensor_adam( # Perform stepweight decay torch._foreach_mul_(device_params, 1 - lr * weight_decay) else: +<<<<<<< HEAD # Reuse the intermediate memory (device_grads) already allocated for maximize +======= + # Re-use the intermediate memory (device_grads) already allocated for maximize +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if maximize: torch._foreach_add_(device_grads, device_params, alpha=weight_decay) else: @@ -688,9 +718,13 @@ def _multi_tensor_adam( # Decay the first and second moment running average coefficient # Use device beta1 if beta1 is a tensor to ensure all # tensors are on the same device +<<<<<<< HEAD torch._foreach_lerp_( device_exp_avgs, device_grads, cast(float, 1 - device_beta1) ) +======= + torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - device_beta1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch._foreach_mul_(device_exp_avg_sqs, beta2) diff --git a/torch/optim/adamax.py b/torch/optim/adamax.py index 7c58aa3dda6f2..48a776bc8c646 100644 --- a/torch/optim/adamax.py +++ b/torch/optim/adamax.py @@ -53,6 +53,7 @@ def __init__( if not 0.0 <= weight_decay: raise ValueError(f"Invalid weight_decay value: {weight_decay}") +<<<<<<< HEAD defaults = { "lr": lr, "betas": betas, @@ -63,6 +64,18 @@ def __init__( "differentiable": differentiable, "capturable": capturable, } +======= + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + foreach=foreach, + maximize=maximize, + differentiable=differentiable, + capturable=capturable, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__(params, defaults) def __setstate__(self, state): @@ -376,7 +389,11 @@ def _multi_tensor_adamax( if weight_decay != 0: if maximize: +<<<<<<< HEAD # Reuse the intermediate memory (grouped_grads) already allocated for maximize +======= + # Re-use the intermediate memory (grouped_grads) already allocated for maximize +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay) else: grouped_grads = torch._foreach_add( # type: ignore[assignment] diff --git a/torch/optim/asgd.py b/torch/optim/asgd.py index aff201520adb7..54d8c4824d6af 100644 --- a/torch/optim/asgd.py +++ b/torch/optim/asgd.py @@ -47,6 +47,7 @@ def __init__( if not 0.0 <= weight_decay: raise ValueError(f"Invalid weight_decay value: {weight_decay}") +<<<<<<< HEAD defaults = { "lr": lr, "lambd": lambd, @@ -58,6 +59,19 @@ def __init__( "differentiable": differentiable, "capturable": capturable, } +======= + defaults = dict( + lr=lr, + lambd=lambd, + alpha=alpha, + t0=t0, + weight_decay=weight_decay, + foreach=foreach, + maximize=maximize, + differentiable=differentiable, + capturable=capturable, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__(params, defaults) def __setstate__(self, state): diff --git a/torch/optim/lbfgs.py b/torch/optim/lbfgs.py index 674aaaf268835..d7c2bce62a21d 100644 --- a/torch/optim/lbfgs.py +++ b/torch/optim/lbfgs.py @@ -231,6 +231,7 @@ def __init__( raise ValueError(f"Invalid learning rate: {lr}") if max_eval is None: max_eval = max_iter * 5 // 4 +<<<<<<< HEAD defaults = { "lr": lr, "max_iter": max_iter, @@ -240,6 +241,17 @@ def __init__( "history_size": history_size, "line_search_fn": line_search_fn, } +======= + defaults = dict( + lr=lr, + max_iter=max_iter, + max_eval=max_eval, + tolerance_grad=tolerance_grad, + tolerance_change=tolerance_change, + history_size=history_size, + line_search_fn=line_search_fn, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__(params, defaults) if len(self.param_groups) != 1: @@ -454,8 +466,12 @@ def obj_func(x, t, d): # the reason we do this: in a stochastic setting, # no use to re-evaluate that function here with torch.enable_grad(): +<<<<<<< HEAD loss = closure() loss = float(loss) +======= + loss = float(closure()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) flat_grad = self._gather_flat_grad() opt_cond = flat_grad.abs().max() <= tolerance_grad ls_func_evals = 1 diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py index 8703719dabc72..4ac3f0e7c7668 100644 --- a/torch/optim/lr_scheduler.py +++ b/torch/optim/lr_scheduler.py @@ -200,16 +200,24 @@ def step(self, epoch: Optional[int] = None) -> None: ) self._step_count += 1 +<<<<<<< HEAD if epoch is not None: warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) self._update_lr(epoch) def _update_lr(self, epoch: Optional[int] = None): +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with _enable_get_lr_call(self): if epoch is None: self.last_epoch += 1 values = self.get_lr() else: +<<<<<<< HEAD +======= + warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.last_epoch = epoch if hasattr(self, "_get_closed_form_lr"): values = cast(list[float], self._get_closed_form_lr()) @@ -916,7 +924,11 @@ def step(self) -> None: # type: ignore[override] idx = bisect_right(self._milestones, self.last_epoch) scheduler = self._schedulers[idx] if idx > 0 and self._milestones[idx - 1] == self.last_epoch: +<<<<<<< HEAD scheduler._update_lr(0) +======= + scheduler.step(0) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: scheduler.step() @@ -1347,7 +1359,11 @@ def step(self, metrics: SupportsFloat, epoch=None) -> None: # type: ignore[over warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) self.last_epoch = epoch +<<<<<<< HEAD if self._is_better(current, self.best): +======= + if self.is_better(current, self.best): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.best = current self.num_bad_epochs = 0 else: @@ -1389,7 +1405,11 @@ def _reduce_lr(self, epoch): def in_cooldown(self): # noqa: D102 return self.cooldown_counter > 0 +<<<<<<< HEAD def _is_better(self, a, best): # noqa: D102 +======= + def is_better(self, a, best): # noqa: D102 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.mode == "min" and self.threshold_mode == "rel": rel_epsilon = 1.0 - self.threshold return a < best * rel_epsilon @@ -1689,6 +1709,7 @@ def get_lr(self) -> list[float]: @override def state_dict(self) -> dict[str, Any]: # noqa: D102 +<<<<<<< HEAD """Return the state of the scheduler as a :class:`dict`. It contains an entry for every variable in self.__dict__ which @@ -1698,6 +1719,8 @@ def state_dict(self) -> dict[str, Any]: # noqa: D102 When saving or loading the scheduler, please make sure to also save or load the state of the optimizer. """ +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) state = super().state_dict() # We are dropping the `_scale_fn_ref` attribute because it is a # `weakref.WeakMethod` and can't be pickled. diff --git a/torch/optim/nadam.py b/torch/optim/nadam.py index 2adb5866ad07b..cb214495c1738 100644 --- a/torch/optim/nadam.py +++ b/torch/optim/nadam.py @@ -59,6 +59,7 @@ def __init__( raise ValueError(f"Invalid weight_decay value: {weight_decay}") if not 0.0 <= momentum_decay: raise ValueError(f"Invalid momentum_decay value: {momentum_decay}") +<<<<<<< HEAD defaults = { "lr": lr, "betas": betas, @@ -71,6 +72,20 @@ def __init__( "capturable": capturable, "differentiable": differentiable, } +======= + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + momentum_decay=momentum_decay, + decoupled_weight_decay=decoupled_weight_decay, + maximize=maximize, + foreach=foreach, + capturable=capturable, + differentiable=differentiable, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__(params, defaults) def __setstate__(self, state): # noqa: D105 @@ -371,9 +386,13 @@ def _single_tensor_nadam( grad, denom, value=(-lr * (1.0 - mu) / (1.0 - _get_value(mu_product))) ) param.addcdiv_( +<<<<<<< HEAD exp_avg, denom, value=cast(float, (-lr * mu_next) / (1.0 - mu_product_next)), +======= + exp_avg, denom, value=(-lr * mu_next) / (1.0 - mu_product_next) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @@ -462,7 +481,11 @@ def _multi_tensor_nadam( # Perform stepweight decay torch._foreach_mul_(grouped_params, 1 - lr * weight_decay) else: +<<<<<<< HEAD # Reuse the intermediate memory (grouped_grads) already allocated for maximize +======= + # Re-use the intermediate memory (grouped_grads) already allocated for maximize +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if maximize: torch._foreach_add_( grouped_grads, grouped_params, alpha=weight_decay diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py index 2ef6c48f4efab..d5fdfb21c9ff2 100644 --- a/torch/optim/optimizer.py +++ b/torch/optim/optimizer.py @@ -28,10 +28,16 @@ Args: TypeAlias = tuple[Any, ...] Kwargs: TypeAlias = dict[str, Any] StateDict: TypeAlias = dict[str, Any] +<<<<<<< HEAD DeviceDict: TypeAlias = dict[Optional[torch.device], torch.Tensor] DeviceDtypeDict: TypeAlias = dict[ Optional[tuple[torch.device, torch.dtype]], torch.Tensor ] +======= +DeviceDict = dict[Optional[torch.device], torch.Tensor] +DeviceDtypeDict = dict[Optional[tuple[torch.device, torch.dtype]], torch.Tensor] + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GlobalOptimizerPreHook: TypeAlias = Callable[ ["Optimizer", Args, Kwargs], Optional[tuple[Args, Kwargs]] @@ -999,6 +1005,7 @@ def zero_grad(self, set_to_none: bool = True) -> None: r"""Reset the gradients of all optimized :class:`torch.Tensor` s. Args: +<<<<<<< HEAD set_to_none (bool, optional): Instead of setting to zero, set the grads to None. Default: ``True`` This will in general have lower memory footprint, and can modestly improve performance. @@ -1011,6 +1018,18 @@ def zero_grad(self, set_to_none: bool = True) -> None: 3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None (in one case it does the step with a gradient of 0 and in the other it skips the step altogether). +======= + set_to_none (bool): instead of setting to zero, set the grads to None. + This will in general have lower memory footprint, and can modestly improve performance. + However, it changes certain behaviors. For example: + 1. When the user tries to access a gradient and perform manual ops on it, + a None attribute or a Tensor full of 0s will behave differently. + 2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s + are guaranteed to be None for params that did not receive a gradient. + 3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None + (in one case it does the step with a gradient of 0 and in the other it skips + the step altogether). +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ foreach = self.defaults.get("foreach", False) or self.defaults.get( "fused", False diff --git a/torch/optim/radam.py b/torch/optim/radam.py index bf5bc9102ce23..0d3114beffd12 100644 --- a/torch/optim/radam.py +++ b/torch/optim/radam.py @@ -56,6 +56,7 @@ def __init__( if not 0.0 <= weight_decay: raise ValueError(f"Invalid weight_decay value: {weight_decay}") +<<<<<<< HEAD defaults = { "lr": lr, "betas": betas, @@ -67,6 +68,19 @@ def __init__( "decoupled_weight_decay": decoupled_weight_decay, "differentiable": differentiable, } +======= + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + maximize=maximize, + foreach=foreach, + capturable=capturable, + decoupled_weight_decay=decoupled_weight_decay, + differentiable=differentiable, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__(params, defaults) def __setstate__(self, state): # noqa: D105 @@ -461,7 +475,11 @@ def _multi_tensor_radam( if decoupled_weight_decay: torch._foreach_mul_(grouped_params, 1 - lr * weight_decay) else: +<<<<<<< HEAD # Reuse the intermediate memory (grouped_grads) already allocated for maximize +======= + # Re-use the intermediate memory (grouped_grads) already allocated for maximize +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if maximize: torch._foreach_add_( grouped_grads, grouped_params, alpha=weight_decay diff --git a/torch/optim/rmsprop.py b/torch/optim/rmsprop.py index 7dd0ba2c048ff..b3f3e6128d2b4 100644 --- a/torch/optim/rmsprop.py +++ b/torch/optim/rmsprop.py @@ -55,6 +55,7 @@ def __init__( if not 0.0 <= alpha: raise ValueError(f"Invalid alpha value: {alpha}") +<<<<<<< HEAD defaults = { "lr": lr, "momentum": momentum, @@ -67,6 +68,20 @@ def __init__( "maximize": maximize, "differentiable": differentiable, } +======= + defaults = dict( + lr=lr, + momentum=momentum, + alpha=alpha, + eps=eps, + centered=centered, + weight_decay=weight_decay, + capturable=capturable, + foreach=foreach, + maximize=maximize, + differentiable=differentiable, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__(params, defaults) def __setstate__(self, state): # noqa: D105 @@ -420,7 +435,11 @@ def _multi_tensor_rmsprop( torch._foreach_add_(grouped_state_steps, 1) if weight_decay != 0: +<<<<<<< HEAD # Reuse the intermediate memory (grouped_grads) already allocated for maximize +======= + # Re-use the intermediate memory (grouped_grads) already allocated for maximize +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if maximize: torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay) else: diff --git a/torch/optim/rprop.py b/torch/optim/rprop.py index c46fc24bfa98c..4f56e9e93256c 100644 --- a/torch/optim/rprop.py +++ b/torch/optim/rprop.py @@ -47,6 +47,7 @@ def __init__( if not 0.0 < etas[0] < 1.0 < etas[1]: raise ValueError(f"Invalid eta values: {etas[0]}, {etas[1]}") +<<<<<<< HEAD defaults = { "lr": lr, "etas": etas, @@ -56,6 +57,17 @@ def __init__( "differentiable": differentiable, "capturable": capturable, } +======= + defaults = dict( + lr=lr, + etas=etas, + step_sizes=step_sizes, + foreach=foreach, + maximize=maximize, + differentiable=differentiable, + capturable=capturable, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__(params, defaults) def __setstate__(self, state): # noqa: D105 @@ -200,9 +212,15 @@ def step(self, closure=None): For further details regarding the algorithm we refer to the paper `A Direct Adaptive Method for Faster Backpropagation Learning: The RPROP Algorithm +<<<<<<< HEAD `_.""" # codespell:ignore + rf""" +======= + `_. + """ + + rf""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Args: {_params_doc} lr (float, optional): learning rate (default: 1e-2) diff --git a/torch/optim/sgd.py b/torch/optim/sgd.py index 4fafecbd31bdd..a674bda280e6f 100644 --- a/torch/optim/sgd.py +++ b/torch/optim/sgd.py @@ -49,6 +49,7 @@ def __init__( if weight_decay < 0.0: raise ValueError(f"Invalid weight_decay value: {weight_decay}") +<<<<<<< HEAD defaults = { "lr": lr, "momentum": momentum, @@ -60,6 +61,19 @@ def __init__( "differentiable": differentiable, "fused": fused, } +======= + defaults = dict( + lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov, + maximize=maximize, + foreach=foreach, + differentiable=differentiable, + fused=fused, + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if nesterov and (momentum <= 0 or dampening != 0): raise ValueError("Nesterov momentum requires a momentum and zero dampening") super().__init__(params, defaults) @@ -355,7 +369,11 @@ def _single_tensor_sgd( buf = momentum_buffer_list[i] if buf is None: +<<<<<<< HEAD buf = grad.detach().clone() +======= + buf = torch.clone(grad).detach() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) momentum_buffer_list[i] = buf else: buf.mul_(momentum).add_(grad, alpha=1 - dampening) @@ -417,7 +435,11 @@ def _multi_tensor_sgd( device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment] if weight_decay != 0: +<<<<<<< HEAD # Reuse the intermediate memory (device_grads) already allocated for maximize +======= + # Re-use the intermediate memory (device_grads) already allocated for maximize +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if maximize: torch._foreach_add_(device_grads, device_params, alpha=weight_decay) else: @@ -445,7 +467,11 @@ def _multi_tensor_sgd( if device_momentum_buffer_list[i] is None: buf = device_momentum_buffer_list[i] = momentum_buffer_list[ indices[i] +<<<<<<< HEAD ] = device_grads[i].detach().clone() +======= + ] = torch.clone(device_grads[i]).detach() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: buf = cast(Tensor, device_momentum_buffer_list[i]) buf.mul_(momentum).add_(device_grads[i], alpha=1 - dampening) diff --git a/torch/optim/sparse_adam.py b/torch/optim/sparse_adam.py index cbcd9d2797335..014b5e17cbe1d 100644 --- a/torch/optim/sparse_adam.py +++ b/torch/optim/sparse_adam.py @@ -31,12 +31,16 @@ def __init__( if not 0.0 <= betas[1] < 1.0: raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") +<<<<<<< HEAD defaults = { "lr": lr, "betas": betas, "eps": eps, "maximize": maximize, } +======= + defaults = dict(lr=lr, betas=betas, eps=eps, maximize=maximize) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__(params, defaults) sparse_params = [] diff --git a/torch/optim/swa_utils.py b/torch/optim/swa_utils.py index da4f005820c68..1a50929eb2dac 100644 --- a/torch/optim/swa_utils.py +++ b/torch/optim/swa_utils.py @@ -6,7 +6,11 @@ import warnings from collections.abc import Iterable from copy import deepcopy +<<<<<<< HEAD from typing import Any, Callable, cast, Literal, Optional, Union +======= +from typing import Any, Callable, Literal, Optional, Union +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch from torch import Tensor @@ -69,9 +73,13 @@ def swa_update( averaged_param_list[0] ): torch._foreach_lerp_( +<<<<<<< HEAD averaged_param_list, current_param_list, cast(float, 1 / (num_averaged + 1)), +======= + averaged_param_list, current_param_list, 1 / (num_averaged + 1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) else: diffs = torch._foreach_sub(current_param_list, averaged_param_list) @@ -259,12 +267,19 @@ def update_parameters(self, model: Module): ) self_param_detached: list[Optional[Tensor]] = [] model_param_detached: list[Optional[Tensor]] = [] +<<<<<<< HEAD copy_param = bool(self.n_averaged == 0) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for p_averaged, p_model in zip(self_param, model_param): p_model_ = p_model.detach().to(p_averaged.device) self_param_detached.append(p_averaged.detach()) model_param_detached.append(p_model_) +<<<<<<< HEAD if copy_param: +======= + if self.n_averaged == 0: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) p_averaged.detach().copy_(p_model_) if self.n_averaged > 0: diff --git a/torch/overrides.py b/torch/overrides.py index c8fd7c6a22899..b86c886689f43 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -362,7 +362,10 @@ def get_ignored_functions() -> set[Callable]: Tensor._view_func, Tensor._view_func_unsafe, Tensor._rev_view_func_unsafe, +<<<<<<< HEAD Tensor._make_dtensor, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Tensor._make_wrapper_subclass, Tensor._python_dispatch.__get__, Tensor._has_symbolic_sizes_strides.__get__, @@ -610,8 +613,13 @@ def get_testing_overrides() -> dict[Callable, Callable]: torch.fused_moving_avg_obs_fake_quant: ( lambda x, observer_on, fake_quant_on, averaging_const, running_min, running_max, scale, zero_point, quant_min, quant_max, ch_axis, per_row_fake_quant=False, symmetric_quant=False: -1 # noqa: B950 ), +<<<<<<< HEAD torch.fbgemm_linear_fp16_weight: lambda input, packed_weight, bias, output: -1, torch.fbgemm_linear_fp16_weight_fp32_activation: lambda input, packed_weight, bias, output: -1, +======= + torch.fbgemm_linear_fp16_weight: lambda input, packed_weight, bias: -1, + torch.fbgemm_linear_fp16_weight_fp32_activation: lambda input, packed_weight, bias: -1, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.fbgemm_linear_int8_weight: lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1, # noqa: B950 torch.fbgemm_linear_int8_weight_fp32_activation: ( lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1 @@ -676,7 +684,10 @@ def get_testing_overrides() -> dict[Callable, Callable]: torch.gt: lambda input, other, out=None: -1, torch.greater: lambda input, other, out=None: -1, torch.hardshrink: lambda input, lambd=0.5: -1, +<<<<<<< HEAD torch.hash_tensor: lambda input, dim=None, keepdim=False, mode=0, out=None: -1, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.heaviside: lambda input, values, out=None: -1, torch.hinge_embedding_loss: lambda input, target, margin=1.0, size_average=None, reduce=None, reduction="mean": -1, # noqa: B950 torch.histc: lambda input, bins=100, min=0, max=0, out=None: -1, @@ -822,7 +833,10 @@ def get_testing_overrides() -> dict[Callable, Callable]: torch._native_batch_norm_legit: lambda input, weight, bias, training, momentum, eps: -1, torch.native_dropout: lambda input, p, train: -1, torch.native_layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1, +<<<<<<< HEAD torch._fused_rms_norm: lambda input, normalized_shape, weight=None, eps=1e-05: -1, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.native_group_norm: lambda input, weight, bias, N, C, HxW, group, eps: -1, torch.native_norm: lambda input, p=2, dim=None, keepdim=False, dtype=None: -1, torch.native_channel_shuffle: lambda input, groups: -1, @@ -1514,7 +1528,11 @@ def get_testing_overrides() -> dict[Callable, Callable]: Tensor.view: lambda self, shape: -1, Tensor.view_as: lambda self, other: -1, Tensor.zero_: lambda self: -1, +<<<<<<< HEAD Tensor.__dlpack__: lambda self, stream=None, max_version=None, dl_device=None, copy=None: -1, +======= + Tensor.__dlpack__: lambda self, stream=None: -1, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Tensor.__dlpack_device__: lambda self: -1, torch.linalg.lstsq: lambda self, b, cond=None, driver=None: -1, } # fmt: skip @@ -2006,7 +2024,11 @@ def is_tensor_like(inp): class TorchFunctionMode: """ A ``TorchFunctionMode`` allows you to override the meaning of all +<<<<<<< HEAD ``__torch_function__`` overridable functions within a dynamic scope, +======= + ``__torch_function__`` overrideable functions within a dynamic scope, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) without having to actually create a tensor subclass or manually monkey-patch functions in the PyTorch API. Some common situations where you should use a mode: diff --git a/torch/package/_mangling.py b/torch/package/_mangling.py index 08b0560f79322..9a962b383f165 100644 --- a/torch/package/_mangling.py +++ b/torch/package/_mangling.py @@ -2,7 +2,10 @@ """Import mangling. See mangling.md for details. """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import re diff --git a/torch/package/_stdlib.py b/torch/package/_stdlib.py index 57a51ac41cfd9..a871242d5abb6 100644 --- a/torch/package/_stdlib.py +++ b/torch/package/_stdlib.py @@ -17,12 +17,20 @@ def is_stdlib_module(module: str) -> bool: def _get_stdlib_modules(): +<<<<<<< HEAD if sys.version_info.major == 3: # noqa: UP036 +======= + if sys.version_info.major == 3: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if sys.version_info.minor == 9: return stdlib3_9 if sys.version_info.minor >= 10: # noqa: YTT204 return sys.stdlib_module_names # type: ignore[attr-defined] +<<<<<<< HEAD elif sys.version_info.major > 3: # noqa: UP036 +======= + elif sys.version_info.major > 3: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return sys.stdlib_module_names # type: ignore[attr-defined] raise RuntimeError(f"Unsupported Python version: {sys.version_info}") diff --git a/torch/package/importer.py b/torch/package/importer.py index 8cfc1e336a454..f90bde04ac93c 100644 --- a/torch/package/importer.py +++ b/torch/package/importer.py @@ -1,6 +1,9 @@ # mypy: allow-untyped-defs import importlib +<<<<<<< HEAD import logging +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from abc import ABC, abstractmethod from pickle import ( # type: ignore[attr-defined] _getattribute, @@ -14,7 +17,10 @@ __all__ = ["ObjNotFoundError", "ObjMismatchError", "Importer", "OrderedImporter"] +<<<<<<< HEAD log = logging.getLogger(__name__) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class ObjNotFoundError(Exception): @@ -206,6 +212,7 @@ def _is_torchpackage_dummy(self, module): return True return module.__file__ is None +<<<<<<< HEAD def get_name(self, obj: Any, name: Optional[str] = None) -> tuple[str, str]: for importer in self._importers: try: @@ -220,6 +227,8 @@ def get_name(self, obj: Any, name: Optional[str] = None) -> tuple[str, str]: f"Could not find obj {obj} and name {name} in any of the importers {self._importers}" ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def import_module(self, module_name: str) -> ModuleType: last_err = None for importer in self._importers: diff --git a/torch/package/package_exporter.py b/torch/package/package_exporter.py index 6118e8ce80964..5898df9c11c8b 100644 --- a/torch/package/package_exporter.py +++ b/torch/package/package_exporter.py @@ -605,9 +605,15 @@ def save_pickle( dependencies (bool, optional): If ``True``, we scan the source for dependencies. """ +<<<<<<< HEAD assert (pickle_protocol == 4) or (pickle_protocol == 3), ( "torch.package only supports pickle protocols 3 and 4" ) +======= + assert (pickle_protocol == 4) or ( + pickle_protocol == 3 + ), "torch.package only supports pickle protocols 3 and 4" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) filename = self._filename(package, resource) # Write the pickle data for `obj` diff --git a/torch/package/package_importer.py b/torch/package/package_importer.py index 7291227e42ae2..594516650988e 100644 --- a/torch/package/package_importer.py +++ b/torch/package/package_importer.py @@ -423,12 +423,16 @@ def _load_module(self, name: str, parent: str): module.__dict__.setdefault(old_name, new_name) return module +<<<<<<< HEAD return self._make_module( name, cur.source_file, # type: ignore[attr-defined] isinstance(cur, _PackageNode), parent, ) +======= + return self._make_module(name, cur.source_file, isinstance(cur, _PackageNode), parent) # type: ignore[attr-defined] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _compile_source(self, fullpath: str, mangled_filename: str): source = self.zip_reader.get_record(fullpath) diff --git a/torch/profiler/__init__.py b/torch/profiler/__init__.py index 153d4560e2641..ac4e0347daf4f 100644 --- a/torch/profiler/__init__.py +++ b/torch/profiler/__init__.py @@ -1,3 +1,7 @@ +<<<<<<< HEAD +======= +# mypy: allow-untyped-defs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r""" PyTorch Profiler is a tool that allows the collection of performance metrics during training and inference. Profiler's context manager API can be used to better understand what model operators are the most expensive, @@ -7,16 +11,24 @@ An earlier version of the API in :mod:`torch.autograd` module is considered legacy and will be deprecated. """ +<<<<<<< HEAD import os from typing import Any from typing_extensions import TypeVarTuple, Unpack +======= +import os +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch._C._autograd import _supported_activities, DeviceType, kineto_available from torch._C._profiler import _ExperimentalConfig, ProfilerActivity, RecordScope from torch._environment import is_fbcode from torch.autograd.profiler import KinetoStepTracker, record_function +<<<<<<< HEAD from torch.optim.optimizer import Optimizer, register_optimizer_step_post_hook +======= +from torch.optim.optimizer import register_optimizer_step_post_hook +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from .profiler import ( _KinetoProfile, @@ -45,12 +57,16 @@ from . import itt +<<<<<<< HEAD _Ts = TypeVarTuple("_Ts") def _optimizer_post_hook( optimizer: Optimizer, args: tuple[Unpack[_Ts]], kwargs: dict[str, Any] ) -> None: +======= +def _optimizer_post_hook(optimizer, args, kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) KinetoStepTracker.increment_step("Optimizer") diff --git a/torch/profiler/_memory_profiler.py b/torch/profiler/_memory_profiler.py index d9f3a917c1525..4c584574bdfd4 100644 --- a/torch/profiler/_memory_profiler.py +++ b/torch/profiler/_memory_profiler.py @@ -239,12 +239,19 @@ def inputs_are_mutable(cls, t: _ExtraFields_TorchOp) -> tuple[Optional[bool], .. def match_schemas(cls, t: _ExtraFields_TorchOp) -> tuple[FunctionSchema, ...]: signature = tuple( # Tensor +<<<<<<< HEAD TensorKey.from_tensor(i) if isinstance(i, _TensorMetadata) # # TensorList else [TensorKey.from_tensor(j) for j in i] if isinstance(i, list) +======= + TensorKey.from_tensor(i) if isinstance(i, _TensorMetadata) + # + # TensorList + else [TensorKey.from_tensor(j) for j in i] if isinstance(i, list) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # # Scalar and uncaptured inputs. else i @@ -516,7 +523,11 @@ def __init__(self, op_tree: OpTree) -> None: def flow_nodes(self) -> tuple[DataFlowNode, ...]: return tuple(self._flow_nodes) +<<<<<<< HEAD def validate(self) -> None: +======= + def validate(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Check that each (Tensor, version) pair has a unique creation node outputs: set[tuple[TensorKey, int]] = set() for node in self.flow_nodes: @@ -966,7 +977,11 @@ def _set_optimizer_state(self) -> None: if key is not None: self._categories.set_by_id(key, Category.OPTIMIZER_STATE) +<<<<<<< HEAD def _set_autograd_detail(self) -> None: +======= + def _set_autograd_detail(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) prior = {None, Category.AUTOGRAD_DETAIL} for node in self._data_flow_graph.flow_nodes: if RecordScope.BACKWARD_FUNCTION in get_scopes(node._event): @@ -978,7 +993,11 @@ def _set_autograd_detail(self) -> None: class MemoryProfileTimeline: +<<<<<<< HEAD def __init__(self, memory_profile) -> None: +======= + def __init__(self, memory_profile): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """The minimum representation of the memory profile timeline includes the memory timeline and categories. The timeline consists of [timestamp, action, (TensorKey, version), numbytes] @@ -1001,7 +1020,11 @@ def _coalesce_timeline(self, device_str): times: list[int] = [] sizes: list[list[int]] = [] +<<<<<<< HEAD def update(key, version, delta) -> None: +======= + def update(key, version, delta): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) category = ( self.categories.get(key, version) if isinstance(key, TensorKey) diff --git a/torch/profiler/_pattern_matcher.py b/torch/profiler/_pattern_matcher.py index cee47f28eb04a..5c3c7d6e94664 100644 --- a/torch/profiler/_pattern_matcher.py +++ b/torch/profiler/_pattern_matcher.py @@ -26,7 +26,11 @@ class Pattern: In subclass, define description and skip property. """ +<<<<<<< HEAD def __init__(self, prof: profile, should_benchmark: bool = False) -> None: +======= + def __init__(self, prof: profile, should_benchmark: bool = False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.prof = prof self.should_benchmark = should_benchmark self.name = "Please specify a name for pattern" @@ -39,7 +43,11 @@ def __init__(self, prof: profile, should_benchmark: bool = False) -> None: self.tid_root.setdefault(event.start_tid, []).append(event) @property +<<<<<<< HEAD def skip(self) -> bool: +======= + def skip(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return False def report(self, event: _ProfilerEvent): @@ -66,8 +74,13 @@ def summary(self, events: list[_ProfilerEvent]): ) return default_summary +<<<<<<< HEAD def benchmark_summary(self, events: list[_ProfilerEvent]) -> str: def format_time(time_ns: int) -> str: +======= + def benchmark_summary(self, events: list[_ProfilerEvent]): + def format_time(time_ns: int): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) unit_lst = ["ns", "us", "ms"] for unit in unit_lst: if time_ns < 1000: @@ -135,9 +148,13 @@ def go_up_until(self, event: _ProfilerEvent, predicate): class NamePattern(Pattern): +<<<<<<< HEAD def __init__( self, prof: profile, name: str, should_benchmark: bool = False ) -> None: +======= + def __init__(self, prof: profile, name: str, should_benchmark: bool = False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__(prof, should_benchmark) self.description = f"Matched Name Event: {name}" self.name = name @@ -163,7 +180,11 @@ class ExtraCUDACopyPattern(Pattern): If at any step we failed, it is not a match. """ +<<<<<<< HEAD def __init__(self, prof: profile, should_benchmark: bool = False) -> None: +======= + def __init__(self, prof: profile, should_benchmark: bool = False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__(prof, should_benchmark) self.name = "Extra CUDA Copy Pattern" self.description = "Filled a CPU tensor and immediately moved it to GPU. Please initialize it on GPU." @@ -176,7 +197,11 @@ def __init__(self, prof: profile, should_benchmark: bool = False) -> None: } @property +<<<<<<< HEAD def skip(self) -> bool: +======= + def skip(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return not self.prof.with_stack or not self.prof.record_shapes def match(self, event): @@ -250,7 +275,11 @@ class ForLoopIndexingPattern(Pattern): We also keep a dictionary to avoid duplicate match in the for loop. """ +<<<<<<< HEAD def __init__(self, prof: profile, should_benchmark: bool = False) -> None: +======= + def __init__(self, prof: profile, should_benchmark: bool = False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__(prof, should_benchmark) self.name = "For Loop Indexing Pattern" self.description = "For loop indexing detected. Vectorization recommended." @@ -273,7 +302,11 @@ def match(self, event: _ProfilerEvent): return False # Custom event list matching +<<<<<<< HEAD def same_ops(list1, list2) -> bool: +======= + def same_ops(list1, list2): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if len(list1) != len(list2): return False for op1, op2 in zip(list1, list2): @@ -297,7 +330,11 @@ def same_ops(list1, list2) -> bool: class FP32MatMulPattern(Pattern): +<<<<<<< HEAD def __init__(self, prof: profile, should_benchmark: bool = False) -> None: +======= + def __init__(self, prof: profile, should_benchmark: bool = False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__(prof, should_benchmark) self.name = "FP32 MatMul Pattern" self.description = ( @@ -312,6 +349,7 @@ def skip(self): has_tf32 = False else: # Anything less than sm_80 is not Ampere which doesn't support TF32 +<<<<<<< HEAD has_tf32 = all( int(re.sub("sm_|compute_", "", arch)) >= 80 for arch in torch.cuda.get_arch_list() @@ -319,6 +357,12 @@ def skip(self): return has_tf32 is False or super().skip or not self.prof.record_shapes def match(self, event: _ProfilerEvent) -> bool: +======= + has_tf32 = all(int(arch[3:]) >= 80 for arch in torch.cuda.get_arch_list()) + return has_tf32 is False or super().skip or not self.prof.record_shapes + + def match(self, event: _ProfilerEvent): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # If we saw this pattern once, we don't need to match it again if event.tag != _EventType.TorchOp: return False @@ -367,7 +411,11 @@ class OptimizerSingleTensorPattern(Pattern): String match """ +<<<<<<< HEAD def __init__(self, prof: profile, should_benchmark: bool = False) -> None: +======= + def __init__(self, prof: profile, should_benchmark: bool = False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__(prof, should_benchmark) self.name = "Optimizer Single Tensor Pattern" self.optimizers_with_foreach = ["adam", "sgd", "adamw"] @@ -377,7 +425,11 @@ def __init__(self, prof: profile, should_benchmark: bool = False) -> None: ) self.url = "" +<<<<<<< HEAD def match(self, event: _ProfilerEvent) -> bool: +======= + def match(self, event: _ProfilerEvent): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for optimizer in self.optimizers_with_foreach: if event.name.endswith(f"_single_tensor_{optimizer}"): return True @@ -402,7 +454,11 @@ class SynchronizedDataLoaderPattern(Pattern): """ +<<<<<<< HEAD def __init__(self, prof: profile, should_benchmark: bool = False) -> None: +======= + def __init__(self, prof: profile, should_benchmark: bool = False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__(prof, should_benchmark) self.name = "Synchronized DataLoader Pattern" self.description = ( @@ -414,7 +470,11 @@ def __init__(self, prof: profile, should_benchmark: bool = False) -> None: "#enable-async-data-loading-and-augmentation" ) +<<<<<<< HEAD def match(self, event: _ProfilerEvent) -> bool: +======= + def match(self, event: _ProfilerEvent): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def is_dataloader_function(name: str, function_name: str): return name.startswith( os.path.join("torch", "utils", "data", "dataloader.py") @@ -461,7 +521,11 @@ class GradNotSetToNonePattern(Pattern): String match """ +<<<<<<< HEAD def __init__(self, prof: profile, should_benchmark: bool = False) -> None: +======= + def __init__(self, prof: profile, should_benchmark: bool = False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__(prof, should_benchmark) self.name = "Gradient Set To Zero Instead of None Pattern" self.description = ( @@ -473,7 +537,11 @@ def __init__(self, prof: profile, should_benchmark: bool = False) -> None: "#disable-gradient-calculation-for-validation-or-inference" ) +<<<<<<< HEAD def match(self, event: _ProfilerEvent) -> bool: +======= + def match(self, event: _ProfilerEvent): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not event.name.endswith(": zero_grad"): return False if not event.children: @@ -502,7 +570,11 @@ class Conv2dBiasFollowedByBatchNorm2dPattern(Pattern): String match """ +<<<<<<< HEAD def __init__(self, prof: profile, should_benchmark: bool = False) -> None: +======= + def __init__(self, prof: profile, should_benchmark: bool = False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__(prof, should_benchmark) self.name = "Enabling Bias in Conv2d Followed By BatchNorm Pattern" self.description = "Detected bias enabled in Conv2d that is followed by BatchNorm2d. Please set 'bias=False' in Conv2d." @@ -533,17 +605,28 @@ def match(self, event: _ProfilerEvent): class MatMulDimInFP16Pattern(Pattern): +<<<<<<< HEAD def __init__(self, prof: profile, should_benchmark: bool = False) -> None: +======= + def __init__(self, prof: profile, should_benchmark: bool = False): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__(prof, should_benchmark) self.name = "Matrix Multiplication Dimension Not Aligned Pattern" self.description = "Detected matmul with dimension not aligned. Please use matmul with aligned dimension." self.url = "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#use-mixed-precision-and-amp" @property +<<<<<<< HEAD def skip(self) -> bool: return not self.prof.with_stack or not self.prof.record_shapes def match(self, event: _ProfilerEvent) -> bool: +======= + def skip(self): + return not self.prof.with_stack or not self.prof.record_shapes + + def match(self, event: _ProfilerEvent): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def mutiple_of(shapes, multiple): return all(dim % multiple == 0 for shape in shapes for dim in shape[-2:]) @@ -586,7 +669,11 @@ def closest_multiple(shapes, multiple): return shapes_factor_map +<<<<<<< HEAD def source_code_location(event: Optional[_ProfilerEvent]) -> str: +======= +def source_code_location(event: Optional[_ProfilerEvent]): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) while event: if event.tag == _EventType.PyCall or event.tag == _EventType.PyCCall: assert isinstance( @@ -613,7 +700,11 @@ def report_all_anti_patterns( should_benchmark: bool = False, print_enable: bool = True, json_report_dir: Optional[str] = None, +<<<<<<< HEAD ) -> None: +======= +): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) report_dict: dict = {} anti_patterns = [ ExtraCUDACopyPattern(prof, should_benchmark), diff --git a/torch/profiler/_utils.py b/torch/profiler/_utils.py index 5b631ef743c6e..abbf3ee325c0e 100644 --- a/torch/profiler/_utils.py +++ b/torch/profiler/_utils.py @@ -52,7 +52,11 @@ class Interval: class EventKey: +<<<<<<< HEAD def __init__(self, event) -> None: +======= + def __init__(self, event): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.event = event def __hash__(self): @@ -61,7 +65,11 @@ def __hash__(self): def __eq__(self, other): return self.event.id == other.event.id +<<<<<<< HEAD def __repr__(self) -> str: +======= + def __repr__(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f"{self.event.name}" def intervals_overlap(self, intervals: list[Interval]): @@ -98,7 +106,11 @@ def intervals_overlap(self, intervals: list[Interval]): class BasicEvaluation: +<<<<<<< HEAD def __init__(self, prof: profile) -> None: +======= + def __init__(self, prof: profile): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.profile = prof self.metrics: dict[EventKey, EventMetrics] = {} self.compute_self_time() @@ -110,7 +122,11 @@ def __init__(self, prof: profile) -> None: self.queue_depth_list = self.compute_queue_depth() self.compute_idle_time() +<<<<<<< HEAD def compute_self_time(self) -> None: +======= + def compute_self_time(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Computes event's self time(total time - time in child ops). """ @@ -124,9 +140,15 @@ def compute_self_time(self) -> None: for child_event in curr_event.children: self_time -= child_event.duration_time_ns stack.append(child_event) +<<<<<<< HEAD assert EventKey(curr_event) not in self.metrics, ( f"Duplicate id: {curr_event.id}, {curr_event.name}" ) +======= + assert ( + EventKey(curr_event) not in self.metrics + ), f"Duplicate id: {curr_event.id}, {curr_event.name}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.metrics[EventKey(curr_event)] = EventMetrics(self_time_ns=self_time) self.metrics[ EventKey(curr_event) @@ -142,6 +164,7 @@ def compute_queue_depth(self): cuda_event_list = self.profile.kineto_results.events() def is_cuda_launch_kernel(e): +<<<<<<< HEAD """Check if the event is a CUDA launch kernel.""" launch_patterns = { "cudaLaunchKernel", # Standard CUDA @@ -165,6 +188,14 @@ def is_cuda_kernel(e): exclude_patterns = {"mem", "cpy", "alloc", "free"} return not any(pattern in name for pattern in exclude_patterns) +======= + # TODO: find a better way to identify cudaLaunchKernel + return e.name == "cudaLaunchKernel" + + def is_cuda_kernel(e): + # TODO: find a better way to identify CUDA Kernel + return e.device_type() == DeviceType.CUDA and "mem" not in e.name.lower() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cuda_launch_events = sorted( (e for e in cuda_event_list if is_cuda_launch_kernel(e)), @@ -227,7 +258,12 @@ def new_old_event_comparator(event): while ( current_kernel_index < len(cuda_kernel_events) +<<<<<<< HEAD and (cuda_kernel_events[current_kernel_index].start_ns()) <= start_time # type: ignore[possibly-undefined] +======= + and (cuda_kernel_events[current_kernel_index].start_ns()) + <= start_time # type: ignore[possibly-undefined] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): current_kernel_index += 1 current_queue_depth = spawned_kernel_index - current_kernel_index + 1 @@ -242,7 +278,11 @@ def new_old_event_comparator(event): return queue_depth_list +<<<<<<< HEAD def compute_idle_time(self) -> None: +======= + def compute_idle_time(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Computes idle time of the profile. """ @@ -351,11 +391,19 @@ def get_optimizable_events(self, length: int = 1, print_enable: bool = True): output += "\n".join( [ +<<<<<<< HEAD f"""{"-" * 80} Event: {event} Source code location: {source_code_location(event.event)} Percentage idle time: {self.metrics[event].fraction_idle_time * 100:.2f}% {"-" * 80}""" +======= + f"""{'-' * 80} +Event: {event} +Source code location: {source_code_location(event.event)} +Percentage idle time: {self.metrics[event].fraction_idle_time * 100:.2f}% +{'-' * 80}""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for event in event_list ] ) @@ -394,7 +442,11 @@ def source_code_location(event): # https://github.com/pytorch/pytorch/issues/75504 # TODO(dberard) - deprecate / remove workaround for CUDA >= 12, when # we stop supporting older CUDA versions. +<<<<<<< HEAD def _init_for_cuda_graphs() -> None: +======= +def _init_for_cuda_graphs(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.autograd.profiler import profile with profile(): diff --git a/torch/profiler/itt.py b/torch/profiler/itt.py index 7b1a6eac0f0bc..764c1b0ed447b 100644 --- a/torch/profiler/itt.py +++ b/torch/profiler/itt.py @@ -1,6 +1,9 @@ # mypy: allow-untyped-defs from contextlib import contextmanager +<<<<<<< HEAD from typing import NoReturn +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: @@ -9,13 +12,21 @@ class _ITTStub: @staticmethod +<<<<<<< HEAD def _fail(*args, **kwargs) -> NoReturn: +======= + def _fail(*args, **kwargs): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise RuntimeError( "ITT functions not installed. Are you sure you have a ITT build?" ) @staticmethod +<<<<<<< HEAD def is_available() -> bool: +======= + def is_available(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return False rangePush = _fail diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index 573541799bbe6..14f137639ff54 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -145,7 +145,11 @@ def __init__( execution_trace_observer: Optional[_ITraceObserver] = None, acc_events: bool = False, custom_trace_id_callback: Optional[Callable[[], str]] = None, +<<<<<<< HEAD ) -> None: +======= + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.activities = set(activities) if activities else supported_activities() self.record_shapes = record_shapes self.with_flops = with_flops @@ -174,6 +178,7 @@ def __init__( # user-defined metadata to be amended to the trace self.preset_metadata: dict[str, str] = {} +<<<<<<< HEAD def start(self) -> None: self.prepare_trace() self.start_trace() @@ -182,6 +187,16 @@ def stop(self) -> None: self.stop_trace() def prepare_trace(self) -> None: +======= + def start(self): + self.prepare_trace() + self.start_trace() + + def stop(self): + self.stop_trace() + + def prepare_trace(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if hasattr(torch, "_inductor"): import torch._inductor.config as inductor_config @@ -202,7 +217,11 @@ def prepare_trace(self) -> None: ) self.profiler._prepare_trace() +<<<<<<< HEAD def start_trace(self) -> None: +======= + def start_trace(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.execution_trace_observer: self.execution_trace_observer.start() assert self.profiler is not None @@ -248,7 +267,11 @@ def start_trace(self) -> None: for k, v in self.preset_metadata.items(): self.add_metadata_json(k, v) +<<<<<<< HEAD def stop_trace(self) -> None: +======= + def stop_trace(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.execution_trace_observer: self.execution_trace_observer.stop() assert self.profiler is not None @@ -284,7 +307,11 @@ def export_stacks(self, path: str, metric: str = "self_cpu_time_total"): def toggle_collection_dynamic( self, enable: bool, activities: Iterable[ProfilerActivity] +<<<<<<< HEAD ) -> None: +======= + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Toggle collection of activities on/off at any point of collection. Currently supports toggling Torch Ops (CPU) and CUDA activity supported in Kineto @@ -341,7 +368,11 @@ def events(self): assert self.profiler return self.profiler.function_events +<<<<<<< HEAD def add_metadata(self, key: str, value: str) -> None: +======= + def add_metadata(self, key: str, value: str): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Adds a user defined metadata with a string key and a string value into the trace file @@ -349,14 +380,22 @@ def add_metadata(self, key: str, value: str) -> None: wrapped_value = '"' + value.replace('"', '\\"') + '"' torch.autograd._add_metadata_json(key, wrapped_value) +<<<<<<< HEAD def add_metadata_json(self, key: str, value: str) -> None: +======= + def add_metadata_json(self, key: str, value: str): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Adds a user defined metadata with a string key and a valid json value into the trace file """ torch.autograd._add_metadata_json(key, value) +<<<<<<< HEAD def preset_metadata_json(self, key: str, value: str) -> None: +======= + def preset_metadata_json(self, key: str, value: str): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Preset a user defined metadata when the profiler is not started and added into the trace file later. @@ -527,6 +566,10 @@ def tensorboard_trace_handler( ``worker_name`` should be unique for each worker in distributed scenario, it will be set to '[hostname]_[pid]' by default. """ +<<<<<<< HEAD +======= + import os +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import socket import time @@ -623,7 +666,12 @@ class profile(_KinetoProfile): ] ) as p: code_to_profile() +<<<<<<< HEAD print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) +======= + print(p.key_averages().table( + sort_by="self_cuda_time_total", row_limit=-1)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Using the profiler's ``schedule``, ``on_trace_ready`` and ``step`` functions: @@ -633,17 +681,28 @@ class profile(_KinetoProfile): # on different iterations of the training loop; # trace_handler is called every time a new trace becomes available def trace_handler(prof): +<<<<<<< HEAD print( prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1) ) # prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step_num) + ".json") +======= + print(prof.key_averages().table( + sort_by="self_cuda_time_total", row_limit=-1)) + # prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step_num) + ".json") + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with torch.profiler.profile( activities=[ torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA, ], +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # In this example with wait=1, warmup=1, active=2, repeat=1, # profiler will skip the first step/iteration, # start warming up on the second, record @@ -651,6 +710,7 @@ def trace_handler(prof): # after which the trace will become available # and on_trace_ready (when set) is called; # the cycle repeats starting with the next step +<<<<<<< HEAD schedule=torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=1), on_trace_ready=trace_handler, # on_trace_ready=torch.profiler.tensorboard_trace_handler('./log') @@ -660,6 +720,22 @@ def trace_handler(prof): code_iteration_to_profile(iter) # send a signal to the profiler that the next iteration has started p.step() +======= + + schedule=torch.profiler.schedule( + wait=1, + warmup=1, + active=2, + repeat=1), + on_trace_ready=trace_handler + # on_trace_ready=torch.profiler.tensorboard_trace_handler('./log') + # used when outputting for tensorboard + ) as p: + for iter in range(N): + code_iteration_to_profile(iter) + # send a signal to the profiler that the next iteration has started + p.step() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) The following sample shows how to setup up an Execution Trace Observer (`execution_trace_observer`) @@ -696,7 +772,11 @@ def __init__( # deprecated: use_cuda: Optional[bool] = None, custom_trace_id_callback: Optional[Callable[[], str]] = None, +<<<<<<< HEAD ) -> None: +======= + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) activities_set = set(activities) if activities else supported_activities() if use_cuda is not None: warn( @@ -812,7 +892,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): if self.execution_trace_observer: self.execution_trace_observer.cleanup() +<<<<<<< HEAD def start(self) -> None: +======= + def start(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._transit_action(ProfilerAction.NONE, self.current_action) if self.record_steps: self.step_rec_fn = prof.record_function( @@ -820,12 +904,20 @@ def start(self) -> None: ) self.step_rec_fn.__enter__() +<<<<<<< HEAD def stop(self) -> None: +======= + def stop(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.record_steps and self.step_rec_fn: self.step_rec_fn.__exit__(None, None, None) self._transit_action(self.current_action, None) +<<<<<<< HEAD def step(self) -> None: +======= + def step(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Signals the profiler that the next profiling step has started. """ @@ -847,7 +939,11 @@ def step(self) -> None: ) self.step_rec_fn.__enter__() +<<<<<<< HEAD def set_custom_trace_id_callback(self, callback) -> None: +======= + def set_custom_trace_id_callback(self, callback): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Sets a callback to be called when a new trace ID is generated. """ @@ -861,11 +957,19 @@ def get_trace_id(self): return None return self.profiler.trace_id +<<<<<<< HEAD def _trace_ready(self) -> None: if self.on_trace_ready: self.on_trace_ready(self) def _transit_action(self, prev_action, current_action) -> None: +======= + def _trace_ready(self): + if self.on_trace_ready: + self.on_trace_ready(self) + + def _transit_action(self, prev_action, current_action): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) action_list = self.action_map.get((prev_action, current_action)) if action_list: for action in action_list: @@ -903,7 +1007,11 @@ def __init__(self) -> None: self.output_file_path: str = "" self.output_file_path_observer: str = "" +<<<<<<< HEAD def __del__(self) -> None: +======= + def __del__(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Calls unregister_callback() to make sure to finalize outputs. """ @@ -1015,7 +1123,11 @@ def get_resources_dir_for_et_path( return None return resource_dir +<<<<<<< HEAD def unregister_callback(self) -> None: +======= + def unregister_callback(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Removes ET observer from record function callbacks. """ @@ -1081,7 +1193,11 @@ def is_running(self): """ return self._execution_trace_running +<<<<<<< HEAD def start(self) -> None: +======= + def start(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Starts to capture. """ @@ -1090,7 +1206,11 @@ def start(self) -> None: self._execution_trace_running = True self._record_pg_config() +<<<<<<< HEAD def stop(self) -> None: +======= + def stop(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Stops to capture. """ @@ -1098,7 +1218,11 @@ def stop(self) -> None: _disable_execution_trace_observer() self._execution_trace_running = False +<<<<<<< HEAD def cleanup(self) -> None: +======= + def cleanup(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ Calls unregister_callback() to make sure to finalize outputs. """ diff --git a/torch/quantization/fuser_method_mappings.py b/torch/quantization/fuser_method_mappings.py index 5a68fbf02015f..42fe593ce3f27 100644 --- a/torch/quantization/fuser_method_mappings.py +++ b/torch/quantization/fuser_method_mappings.py @@ -6,7 +6,10 @@ `torch/ao/quantization/fuser_method_mappings.py`, while adding an import statement here. """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.ao.quantization.fuser_method_mappings import ( _DEFAULT_OP_LIST_TO_FUSER_METHOD, fuse_conv_bn, diff --git a/torch/quantization/fx/_equalize.py b/torch/quantization/fx/_equalize.py index d6b8611d4a769..4aa11e55aa66f 100644 --- a/torch/quantization/fx/_equalize.py +++ b/torch/quantization/fx/_equalize.py @@ -6,7 +6,10 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.ao.quantization.fx._equalize import ( _convert_equalization_ref, _InputEqualizationObserver, diff --git a/torch/quantization/fx/convert.py b/torch/quantization/fx/convert.py index 30a661da41e5e..ba6cdc9caefe2 100644 --- a/torch/quantization/fx/convert.py +++ b/torch/quantization/fx/convert.py @@ -6,5 +6,8 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.ao.quantization.fx.convert import convert diff --git a/torch/quantization/fx/fuse.py b/torch/quantization/fx/fuse.py index 22ad750e9f878..3c11a8b22c653 100644 --- a/torch/quantization/fx/fuse.py +++ b/torch/quantization/fx/fuse.py @@ -6,5 +6,8 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.ao.quantization.fx.fuse import fuse diff --git a/torch/quantization/fx/fusion_patterns.py b/torch/quantization/fx/fusion_patterns.py index 982d919655f36..bc25e504da796 100644 --- a/torch/quantization/fx/fusion_patterns.py +++ b/torch/quantization/fx/fusion_patterns.py @@ -6,5 +6,8 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.ao.quantization.fx.fuse_handler import DefaultFuseHandler, FuseHandler diff --git a/torch/quantization/fx/graph_module.py b/torch/quantization/fx/graph_module.py index 74b63903d7400..f44f79ca0a068 100644 --- a/torch/quantization/fx/graph_module.py +++ b/torch/quantization/fx/graph_module.py @@ -6,7 +6,10 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.ao.quantization.fx.graph_module import ( _is_observed_module, _is_observed_standalone_module, diff --git a/torch/quantization/fx/match_utils.py b/torch/quantization/fx/match_utils.py index 8585a21ad445d..12dadb093a119 100644 --- a/torch/quantization/fx/match_utils.py +++ b/torch/quantization/fx/match_utils.py @@ -6,7 +6,10 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.ao.quantization.fx.match_utils import ( _find_matches, _is_match, diff --git a/torch/quantization/fx/pattern_utils.py b/torch/quantization/fx/pattern_utils.py index fa601d1eb619c..64b4c39931370 100644 --- a/torch/quantization/fx/pattern_utils.py +++ b/torch/quantization/fx/pattern_utils.py @@ -6,7 +6,10 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.ao.quantization.fx.pattern_utils import ( _register_fusion_pattern, _register_quant_pattern, diff --git a/torch/quantization/fx/prepare.py b/torch/quantization/fx/prepare.py index a6007ef242af5..6a54eaa62d203 100644 --- a/torch/quantization/fx/prepare.py +++ b/torch/quantization/fx/prepare.py @@ -6,5 +6,8 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.ao.quantization.fx.prepare import prepare diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py index 89f8d4406e912..5229bae9d74aa 100644 --- a/torch/quantization/fx/quantization_patterns.py +++ b/torch/quantization/fx/quantization_patterns.py @@ -6,7 +6,10 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.ao.quantization.fx.quantize_handler import ( BatchNormQuantizeHandler, BinaryOpQuantizeHandler, diff --git a/torch/quantization/fx/quantization_types.py b/torch/quantization/fx/quantization_types.py index 0820ea057078e..06502e3ce3b82 100644 --- a/torch/quantization/fx/quantization_types.py +++ b/torch/quantization/fx/quantization_types.py @@ -6,5 +6,8 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.ao.quantization.utils import Pattern, QuantizerCls diff --git a/torch/quantization/fx/utils.py b/torch/quantization/fx/utils.py index e45c82b8fb6f2..bd72032a279b9 100644 --- a/torch/quantization/fx/utils.py +++ b/torch/quantization/fx/utils.py @@ -6,7 +6,10 @@ appropriate files under `torch/ao/quantization/fx/`, while adding an import statement here. """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.ao.quantization.fx.utils import ( all_node_args_have_no_tensors, assert_and_get_unique_device, diff --git a/torch/quantization/observer.py b/torch/quantization/observer.py index 2163e2717b069..b3c663a854ab4 100644 --- a/torch/quantization/observer.py +++ b/torch/quantization/observer.py @@ -6,7 +6,10 @@ `torch/ao/quantization/observer.py`, while adding an import statement here. """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.ao.quantization.observer import ( _is_activation_post_process, _is_per_channel_script_obs_instance, diff --git a/torch/quantization/qconfig.py b/torch/quantization/qconfig.py index a02ff7d6f7388..2cd8cda2bdabd 100644 --- a/torch/quantization/qconfig.py +++ b/torch/quantization/qconfig.py @@ -6,7 +6,10 @@ `torch/ao/quantization/qconfig.py`, while adding an import statement here. """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.ao.quantization.qconfig import ( _add_module_to_qconfig_obs_ctr, _assert_valid_qconfig, diff --git a/torch/quantization/quantization_mappings.py b/torch/quantization/quantization_mappings.py index faa24d391d31a..fe143f13f08b9 100644 --- a/torch/quantization/quantization_mappings.py +++ b/torch/quantization/quantization_mappings.py @@ -6,7 +6,10 @@ `torch/ao/quantization/quantization_mappings.py`, while adding an import statement here. """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.ao.quantization.quantization_mappings import ( _get_special_act_post_process, _has_special_act_post_process, diff --git a/torch/serialization.py b/torch/serialization.py index a6eb314fc1a82..a669a14a58b79 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -189,7 +189,11 @@ def set_crc32_options(compute_crc32: bool): able to load the file. Args: +<<<<<<< HEAD compute_crc32 (bool): set crc32 computation flag +======= + compute_crc32 (bool): set crc32 compuation flag +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ from torch.utils.serialization import config @@ -1066,7 +1070,11 @@ def persistent_id(obj: Any) -> Optional[tuple]: # tensor]`, where `tensor.storage()` is the same as `storage`, and # `tensor.element_size() > 1`. Let's say that `tensor.dtype == # torch.float`. The storage will be serialized with element size +<<<<<<< HEAD # of 1, since we're choosing to serialize the first occurrence of +======= + # of 1, since we're choosing to serialize the first occurance of +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # a duplicate storage. Since this legacy serialization format saves # the numel of the storage, rather than nbytes directly, we'll be # effectively saving nbytes in this case. We'll be able to load it @@ -1109,6 +1117,7 @@ def persistent_id(obj: Any) -> Optional[tuple]: return res return None +<<<<<<< HEAD sys_info = { "protocol_version": PROTOCOL_VERSION, "little_endian": sys.byteorder == "little", @@ -1118,6 +1127,17 @@ def persistent_id(obj: Any) -> Optional[tuple]: "long": LONG_SIZE, }, } +======= + sys_info = dict( + protocol_version=PROTOCOL_VERSION, + little_endian=sys.byteorder == "little", + type_sizes=dict( + short=SHORT_SIZE, + int=INT_SIZE, + long=LONG_SIZE, + ), + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol) pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol) @@ -1147,7 +1167,11 @@ def _save( pickle_protocol, _disable_byteorder_record, ): +<<<<<<< HEAD serialized_storages: dict[str, torch.storage.UntypedStorage] = {} +======= + serialized_storages = {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) id_map: dict[int, str] = {} # Since loading storages that view the same data with different dtypes is @@ -1328,11 +1352,19 @@ def load( loading only tensors, primitive types, dictionaries and any types added via :func:`torch.serialization.add_safe_globals`. See :ref:`weights-only` for more details. +<<<<<<< HEAD mmap: Indicates whether the file should be mapped rather than loading all the storages into memory. Typically, tensor storages in the file will first be moved from disk to CPU memory, after which they are moved to the location that they were tagged with when saving, or specified by ``map_location``. This second step is a no-op if the final location is CPU. When the ``mmap`` flag is set, instead of copying the tensor storages from disk to CPU memory in the first step, ``f`` is mapped, which means tensor storages +======= + mmap: Indicates whether the file should be mmaped rather than loading all the storages into memory. + Typically, tensor storages in the file will first be moved from disk to CPU memory, after which they + are moved to the location that they were tagged with when saving, or specified by ``map_location``. This + second step is a no-op if the final location is CPU. When the ``mmap`` flag is set, instead of copying the + tensor storages from disk to CPU memory in the first step, ``f`` is mmaped, which means tensor storages +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) will be lazily loaded when their data is accessed. pickle_load_args: (Python 3 only) optional keyword arguments passed over to :func:`pickle_module.load` and :func:`pickle_module.Unpickler`, e.g., @@ -1426,7 +1458,11 @@ def _get_wo_message(message: str) -> str: "Please file an issue with the following so that we can make " "`weights_only=True` compatible with your use case: WeightsUnpickler error: " ) +<<<<<<< HEAD updated_message += "\n\n" + message +======= + updated_message += message +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return updated_message + DOCS_MESSAGE weights_only_not_set = weights_only is None @@ -1988,7 +2024,11 @@ def _get_offset(key, name, numel): # for a given key. offsets[name] = storage_offset +<<<<<<< HEAD # Increment current_offset to offset where next zipfile header starts +======= + # Increment current_offset of offset where next zipfile header starts +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) current_offset = storage_offset + numel # add size of data descriptor after payload if numel > 0: @@ -2004,10 +2044,14 @@ def load_tensor(dtype, numel, key, location): if torch._guards.detect_fake_mode(None) is not None: nbytes = numel * torch._utils._element_size(dtype) storage = torch.UntypedStorage(nbytes, device="meta") +<<<<<<< HEAD if can_calculate_storage_offsets: storage._checkpoint_offset = _get_offset(key, name, numel) else: storage._checkpoint_offset = zip_file.get_record_offset(name) +======= + storage._checkpoint_offset = zip_file.get_record_offset(name) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) elif _serialization_tls.skip_data: nbytes = numel * torch._utils._element_size(dtype) storage = torch.UntypedStorage(nbytes) diff --git a/torch/signal/windows/windows.py b/torch/signal/windows/windows.py index e68c202f03e8a..09bb8259fbf21 100644 --- a/torch/signal/windows/windows.py +++ b/torch/signal/windows/windows.py @@ -128,7 +128,13 @@ def _window_function_checks( >>> # Generates a periodic exponential window and decay factor equal to .5 >>> torch.signal.windows.exponential(10, sym=False,tau=.5) tensor([4.5400e-05, 3.3546e-04, 2.4788e-03, 1.8316e-02, 1.3534e-01, 1.0000e+00, 1.3534e-01, 1.8316e-02, 2.4788e-03, 3.3546e-04]) +<<<<<<< HEAD """.format(**window_common_args), +======= + """.format( + **window_common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def exponential( M: int, @@ -450,7 +456,13 @@ def kaiser( >>> # Generates a periodic Hamming window. >>> torch.signal.windows.hamming(10, sym=False) tensor([0.0800, 0.1679, 0.3979, 0.6821, 0.9121, 1.0000, 0.9121, 0.6821, 0.3979, 0.1679]) +<<<<<<< HEAD """.format(**window_common_args), +======= +""".format( + **window_common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def hamming( M: int, @@ -504,7 +516,13 @@ def hamming( >>> # Generates a periodic Hann window. >>> torch.signal.windows.hann(10, sym=False) tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955]) +<<<<<<< HEAD """.format(**window_common_args), +======= +""".format( + **window_common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def hann( M: int, @@ -558,7 +576,13 @@ def hann( >>> # Generates a periodic Blackman window. >>> torch.signal.windows.blackman(5, sym=False) tensor([-1.4901e-08, 2.0077e-01, 8.4923e-01, 8.4923e-01, 2.0077e-01]) +<<<<<<< HEAD """.format(**window_common_args), +======= +""".format( + **window_common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def blackman( M: int, @@ -619,7 +643,13 @@ def blackman( >>> # Generates a periodic Bartlett window. >>> torch.signal.windows.bartlett(10, sym=False) tensor([0.0000, 0.2000, 0.4000, 0.6000, 0.8000, 1.0000, 0.8000, 0.6000, 0.4000, 0.2000]) +<<<<<<< HEAD """.format(**window_common_args), +======= +""".format( + **window_common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def bartlett( M: int, @@ -694,7 +724,13 @@ def bartlett( >>> # Generates a periodic general cosine window with 2 coefficients. >>> torch.signal.windows.general_cosine(10, a=[0.5, 1 - 0.5], sym=False) tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955]) +<<<<<<< HEAD """.format(**window_common_args), +======= +""".format( + **window_common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def general_cosine( M, @@ -787,7 +823,13 @@ def general_cosine( >>> # Generates a periodic Hann window with the general Hamming window. >>> torch.signal.windows.general_hamming(10, alpha=0.5, sym=False) tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955]) +<<<<<<< HEAD """.format(**window_common_args), +======= +""".format( + **window_common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def general_hamming( M, @@ -852,7 +894,13 @@ def general_hamming( >>> # Generates a periodic Nuttall window. >>> torch.signal.windows.general_hamming(5, sym=False) tensor([3.6280e-04, 1.1052e-01, 7.9826e-01, 7.9826e-01, 1.1052e-01]) +<<<<<<< HEAD """.format(**window_common_args), +======= +""".format( + **window_common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) def nuttall( M: int, diff --git a/torch/sparse/__init__.py b/torch/sparse/__init__.py index 31299314a85f1..369be2bb70bca 100644 --- a/torch/sparse/__init__.py +++ b/torch/sparse/__init__.py @@ -559,11 +559,15 @@ def as_sparse_gradcheck(gradcheck): For example: >>> gradcheck = torch.sparse.as_sparse_gradcheck(torch.autograd.gradcheck) +<<<<<<< HEAD >>> x = ( ... torch.tensor([[0, 1], [2, 3]], dtype=torch.float64) ... .to_sparse_coo() ... .requires_grad_(True) ... ) +======= + >>> x = torch.tensor([[0, 1], [2, 3]], dtype=torch.float64).to_sparse_coo().requires_grad_(True) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> gradcheck(lambda x: x.to_sparse_csr(), x) True """ @@ -602,10 +606,14 @@ def convert_to_strided_representation(args): and obj.requires_grad and obj.layout in sparse_layouts ): +<<<<<<< HEAD d = { "layout": obj.layout, "shape": obj.shape, } +======= + d = dict(layout=obj.layout, shape=obj.shape) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not masked: # Materialize unspecified elements with zero values batch_dim = obj.ndim - obj.dense_dim() - obj.sparse_dim() @@ -671,7 +679,11 @@ def restore_from_strided_representation(args): ) else: raise NotImplementedError( +<<<<<<< HEAD f"conversion of {d['layout']} strided representation to tensor" +======= + f'conversion of {d["layout"]} strided representation to tensor' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) new_args.append(a) return tuple(new_args) diff --git a/torch/sparse/_triton_ops.py b/torch/sparse/_triton_ops.py index ea36264d8f822..59cfaed526fed 100644 --- a/torch/sparse/_triton_ops.py +++ b/torch/sparse/_triton_ops.py @@ -296,11 +296,19 @@ def scatter_mm(blocks, others, indices_data, *, accumulators=None): for b in range(nbatches): for i, r in enumerate(r_offsets): r0, r1 = divmod(r, N) +<<<<<<< HEAD acc = accumulators[b, r0 : r0 + Ms, r1 : r1 + Ns] for g in range(c_indices[i], c_indices[i + 1]): p = p_offsets[g] q0, q1 = divmod(q_offsets[g], N) acc += blocks[p] @ others[b, q0 : q0 + Ks, q1 : q1 + Ns] +======= + acc = accumulators[b, r0:r0 + Ms, r1:r1 + Ns] + for g in range(c_indices[i], c_indices[i+1]): + p = p_offsets[g] + q0, q1 = divmod(q_offsets[g], N) + acc += blocks[p] @ others[b, q0:q0 + Ks, q1:q1 + Ns] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) where ``Ns = N // meta['SPLIT_N']``, and ``M`` and ``K`` are integer multiples of ``Ms`` and ``Ks``, respectively. @@ -320,11 +328,19 @@ def scatter_mm(blocks, others, indices_data, *, accumulators=None): n = (r % N) // Ns r0, r1 = divmod(r, N) c0, c1 = c_indices[m], c_indices[m + 1] +<<<<<<< HEAD acc = accumulators[b, r0 : r0 + Ms, r1 : r1 + Ns] for i, p in enumerate(range(c0, c1)): q = q_offsets[n * c1 + (SPLIT_N - n) * c0 + i] q0, q1 = divmod(q, N) acc += blocks[p] @ others[b, q0 : q0 + Ks, q1 : q1 + Ns] +======= + acc = accumulators[b, r0:r0 + Ms, r1:r1 + Ns] + for i, p in enumerate(range(c0, c1)): + q = q_offsets[n * c1 + (SPLIT_N - n) * c0 + i] + q0, q1 = divmod(q, N) + acc += blocks[p] @ others[b, q0:q0 + Ks, q1:q1 + Ns] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) where ``Ns = N // meta['SPLIT_N']``, and ``M`` and ``K`` are integer multiples of ``Ms`` and ``Ks``, respectively. diff --git a/torch/sparse/_triton_ops_meta.py b/torch/sparse/_triton_ops_meta.py index 89245246395a9..ce4c059bd2424 100644 --- a/torch/sparse/_triton_ops_meta.py +++ b/torch/sparse/_triton_ops_meta.py @@ -60,7 +60,11 @@ the pytorch development tree:: cd /path/to/pytorch +<<<<<<< HEAD python -m pip install --no-build-isolation -v -e . +======= + python setup.py develop +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) python torch/sparse/_triton_ops_meta.py This will compute the optimal kernel parameters for the GPU device @@ -97,7 +101,10 @@ kernel parameters for addmm-based operations. """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __all__ = ["get_meta", "tune_bsr_dense_addmm", "tune__int_bsr_dense_addmm"] import inspect @@ -433,9 +440,15 @@ def from_key(key, parameters): def create_blocked_tensor(B, M, N, blocksize, sparsity, dtype, device): +<<<<<<< HEAD assert sparsity <= 1.0 and sparsity >= 0.0, ( "sparsity should be a value between 0 and 1" ) +======= + assert ( + sparsity <= 1.0 and sparsity >= 0.0 + ), "sparsity should be a value between 0 and 1" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert M % blocksize[0] == 0 assert N % blocksize[1] == 0 shape = (B, M // blocksize[0], N // blocksize[1])[int(B == 0) :] diff --git a/torch/sparse/semi_structured.py b/torch/sparse/semi_structured.py index b225eaabb3206..bd219b63071c0 100644 --- a/torch/sparse/semi_structured.py +++ b/torch/sparse/semi_structured.py @@ -465,6 +465,7 @@ def prune_dense_static_sort( The equivalent PyTorch code to create the same five outputs from the dense tensor can be found below: ``` from torch.sparse import SparseSemiStructuredTensorCUTLASS +<<<<<<< HEAD from torch.sparse._semi_structured_conversions import ( _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask, @@ -485,6 +486,16 @@ def prune_dense_static_sort( meta_t_cutlass, bitmask, ) +======= + from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask + + pruned = _sparse_semi_structured_tile(dense) + packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned) + packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(pruned.t().contiguous()) + bitmask = _compute_compressed_swizzled_bitmask(pruned) + + SparseSemiStructuredTensorCUTLASS(dense.shape, packed_cutlass, meta_cutlass, packed_t_cutlass, meta_t_cutlass, bitmask) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ``` """ # We can either pack to the CUTLASS or cuSPARSELt representation, depending on the use_cutlass flag. @@ -595,19 +606,27 @@ def prune_dense_static_sort( The equivalent PyTorch code to create the same three outputs from the dense tensor can be found below: ``` from torch.sparse import SparseSemiStructuredTensorCUSPARSELT +<<<<<<< HEAD from torch.sparse._semi_structured_conversions import ( _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask, ) +======= + from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pruned = _sparse_semi_structured_tile(dense) packed_cusparselt = torch._cslt_compress(pruned) packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous()) bitmask = _compute_compressed_swizzled_bitmask(pruned) +<<<<<<< HEAD SparseSemiStructuredTensorCUSPARSELT( dense.shape, packed_cutlass, None, packed_t_cutlass, None, bitmask ) +======= + SparseSemiStructuredTensorCUSPARSELT(dense.shape, packed_cutlass, None, packed_t_cutlass, None, bitmask) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ``` """ ( diff --git a/torch/special/__init__.py b/torch/special/__init__.py index dbc9314ad2087..f66bf82bcfeee 100644 --- a/torch/special/__init__.py +++ b/torch/special/__init__.py @@ -134,7 +134,13 @@ >>> torch.special.digamma(a) tensor([-0.5772, -1.9635]) +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) gammaln = _add_docstr( @@ -160,7 +166,13 @@ >>> torch.special.gammaln(a) tensor([ 0.5724, 0.0000, -0.1208]) +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) polygamma = _add_docstr( @@ -196,7 +208,13 @@ tensor([ 6.4939, 97.4091]) >>> torch.special.polygamma(4, a) tensor([ -24.8863, -771.4742]) +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) erf = _add_docstr( @@ -220,7 +238,13 @@ >>> torch.special.erf(torch.tensor([0, -1., 10.])) tensor([ 0.0000, -0.8427, 1.0000]) +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) erfc = _add_docstr( @@ -245,7 +269,13 @@ >>> torch.special.erfc(torch.tensor([0, -1., 10.])) tensor([ 1.0000, 1.8427, 0.0000]) +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) erfcx = _add_docstr( @@ -273,7 +303,13 @@ >>> torch.special.erfcx(torch.tensor([0, -1., 10.])) tensor([ 1.0000, 5.0090, 0.0561]) +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) erfinv = _add_docstr( @@ -299,7 +335,13 @@ >>> torch.special.erfinv(torch.tensor([0, 0.5, -1.])) tensor([ 0.0000, 0.4769, -inf]) +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) logit = _add_docstr( @@ -337,7 +379,13 @@ tensor([0.2796, 0.9331, 0.6486, 0.1523, 0.6516]) >>> torch.special.logit(a, eps=1e-6) tensor([-0.9466, 2.6352, 0.6131, -1.7169, 0.6261]) +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) logsumexp = _add_docstr( @@ -346,7 +394,13 @@ logsumexp(input, dim, keepdim=False, *, out=None) Alias for :func:`torch.logsumexp`. +<<<<<<< HEAD """.format(**multi_dim_common), +======= +""".format( + **multi_dim_common + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) expit = _add_docstr( @@ -373,7 +427,13 @@ tensor([ 0.9213, 1.0887, -0.8858, -1.7683]) >>> torch.special.expit(t) tensor([ 0.7153, 0.7481, 0.2920, 0.1458]) +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) exp2 = _add_docstr( @@ -398,7 +458,13 @@ >>> torch.special.exp2(torch.tensor([0, math.log2(2.), 3, 4])) tensor([ 1., 2., 8., 16.]) +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) expm1 = _add_docstr( @@ -426,7 +492,13 @@ >>> torch.special.expm1(torch.tensor([0, math.log(2.)])) tensor([ 0., 1.]) +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) xlog1py = _add_docstr( @@ -471,7 +543,13 @@ tensor([1.6094, 3.2189, 4.8283]) >>> torch.special.xlog1py(2, y) tensor([2.7726, 2.1972, 1.3863]) +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) xlogy = _add_docstr( @@ -516,7 +594,13 @@ tensor([1.3863, 2.7726, 4.1589]) >>> torch.special.xlogy(2, y) tensor([2.1972, 1.3863, 0.0000]) +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) i0 = _add_docstr( @@ -542,7 +626,13 @@ >>> torch.i0(torch.arange(5, dtype=torch.float32)) tensor([ 1.0000, 1.2661, 2.2796, 4.8808, 11.3019]) +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) i0e = _add_docstr( @@ -567,7 +657,13 @@ >>> torch.special.i0e(torch.arange(5, dtype=torch.float32)) tensor([1.0000, 0.4658, 0.3085, 0.2430, 0.2070]) +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) i1 = _add_docstr( @@ -592,7 +688,13 @@ >>> torch.special.i1(torch.arange(5, dtype=torch.float32)) tensor([0.0000, 0.5652, 1.5906, 3.9534, 9.7595]) +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) i1e = _add_docstr( @@ -618,7 +720,13 @@ >>> torch.special.i1e(torch.arange(5, dtype=torch.float32)) tensor([0.0000, 0.2079, 0.2153, 0.1968, 0.1788]) +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) ndtr = _add_docstr( @@ -643,7 +751,13 @@ >>> torch.special.ndtr(torch.tensor([-3., -2, -1, 0, 1, 2, 3])) tensor([0.0013, 0.0228, 0.1587, 0.5000, 0.8413, 0.9772, 0.9987]) +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) ndtri = _add_docstr( @@ -671,7 +785,13 @@ >>> torch.special.ndtri(torch.tensor([0, 0.25, 0.5, 0.75, 1])) tensor([ -inf, -0.6745, 0.0000, 0.6745, inf]) +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) log_ndtr = _add_docstr( @@ -696,7 +816,13 @@ >>> torch.special.log_ndtr(torch.tensor([-3., -2, -1, 0, 1, 2, 3])) tensor([-6.6077 -3.7832 -1.841 -0.6931 -0.1728 -0.023 -0.0014]) +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) log1p = _add_docstr( @@ -737,7 +863,13 @@ tensor([ 0.2252, -0.2948, 1.0267, -1.1566]) >>> torch.special.sinc(t) tensor([ 0.9186, 0.8631, -0.0259, -0.1300]) +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) round = _add_docstr( @@ -842,7 +974,13 @@ tensor([1.6449, 0.0823]) >>> torch.special.zeta(2, torch.tensor([1., 2.])) tensor([1.6449, 0.6449]) +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) multigammaln = _add_docstr( @@ -879,7 +1017,13 @@ >>> torch.special.multigammaln(a, 2) tensor([[0.3928, 0.4007, 0.7586], [1.0311, 0.3901, 0.5049]]) +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) gammainc = _add_docstr( @@ -928,7 +1072,13 @@ >>> b = torch.special.gammainc(a1, a2) + torch.special.gammaincc(a1, a2) tensor([1., 1., 1.]) +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) gammaincc = _add_docstr( @@ -976,7 +1126,13 @@ >>> b = torch.special.gammainc(a1, a2) + torch.special.gammaincc(a1, a2) tensor([1., 1., 1.]) +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) airy_ai = _add_docstr( @@ -993,7 +1149,13 @@ Keyword args: {out} +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) bessel_j0 = _add_docstr( @@ -1010,7 +1172,13 @@ Keyword args: {out} +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) bessel_j1 = _add_docstr( @@ -1027,7 +1195,13 @@ Keyword args: {out} +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) bessel_y0 = _add_docstr( @@ -1044,7 +1218,13 @@ Keyword args: {out} +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) bessel_y1 = _add_docstr( @@ -1061,7 +1241,13 @@ Keyword args: {out} +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) chebyshev_polynomial_t = _add_docstr( @@ -1092,7 +1278,13 @@ Keyword args: {out} +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) chebyshev_polynomial_u = _add_docstr( @@ -1124,7 +1316,13 @@ Keyword args: {out} +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) chebyshev_polynomial_v = _add_docstr( @@ -1142,7 +1340,13 @@ Keyword args: {out} +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) chebyshev_polynomial_w = _add_docstr( @@ -1160,7 +1364,13 @@ Keyword args: {out} +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) hermite_polynomial_h = _add_docstr( @@ -1186,7 +1396,13 @@ Keyword args: {out} +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) hermite_polynomial_he = _add_docstr( @@ -1212,7 +1428,13 @@ Keyword args: {out} +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) laguerre_polynomial_l = _add_docstr( @@ -1238,7 +1460,13 @@ Keyword args: {out} +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) legendre_polynomial_p = _add_docstr( @@ -1264,7 +1492,13 @@ Keyword args: {out} +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) modified_bessel_i0 = _add_docstr( @@ -1281,7 +1515,13 @@ Keyword args: {out} +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) modified_bessel_i1 = _add_docstr( @@ -1298,7 +1538,13 @@ Keyword args: {out} +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) modified_bessel_k0 = _add_docstr( @@ -1315,7 +1561,13 @@ Keyword args: {out} +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) modified_bessel_k1 = _add_docstr( @@ -1332,7 +1584,13 @@ Keyword args: {out} +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) scaled_modified_bessel_k0 = _add_docstr( @@ -1349,7 +1607,13 @@ Keyword args: {out} +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) scaled_modified_bessel_k1 = _add_docstr( @@ -1366,7 +1630,13 @@ Keyword args: {out} +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) shifted_chebyshev_polynomial_t = _add_docstr( @@ -1384,7 +1654,13 @@ Keyword args: {out} +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) shifted_chebyshev_polynomial_u = _add_docstr( @@ -1402,7 +1678,13 @@ Keyword args: {out} +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) shifted_chebyshev_polynomial_v = _add_docstr( @@ -1420,7 +1702,13 @@ Keyword args: {out} +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) shifted_chebyshev_polynomial_w = _add_docstr( @@ -1438,7 +1726,13 @@ Keyword args: {out} +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) spherical_bessel_j0 = _add_docstr( @@ -1455,5 +1749,11 @@ Keyword args: {out} +<<<<<<< HEAD """.format(**common_args), +======= +""".format( + **common_args + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) diff --git a/torch/testing/_comparison.py b/torch/testing/_comparison.py index eff07c413deb4..bac85111bbcd0 100644 --- a/torch/testing/_comparison.py +++ b/torch/testing/_comparison.py @@ -1538,9 +1538,13 @@ def assert_close( >>> expected = torch.tensor([1.0, 2.0, 3.0]) >>> actual = torch.tensor([1.0, 4.0, 5.0]) >>> # The default error message can be overwritten. +<<<<<<< HEAD >>> torch.testing.assert_close( ... actual, expected, msg="Argh, the tensors are not close!" ... ) +======= + >>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Traceback (most recent call last): ... AssertionError: Argh, the tensors are not close! diff --git a/torch/testing/_creation.py b/torch/testing/_creation.py index 23d80d6ceae4f..bbb52d2e59117 100644 --- a/torch/testing/_creation.py +++ b/torch/testing/_creation.py @@ -115,11 +115,19 @@ def make_tensor( >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) >>> from torch.testing import make_tensor >>> # Creates a float tensor with values in [-1, 1) +<<<<<<< HEAD >>> make_tensor((3,), device="cpu", dtype=torch.float32, low=-1, high=1) >>> # xdoctest: +SKIP tensor([ 0.1205, 0.2282, -0.6380]) >>> # Creates a bool tensor on CUDA >>> make_tensor((2, 2), device="cuda", dtype=torch.bool) +======= + >>> make_tensor((3,), device='cpu', dtype=torch.float32, low=-1, high=1) + >>> # xdoctest: +SKIP + tensor([ 0.1205, 0.2282, -0.6380]) + >>> # Creates a bool tensor on CUDA + >>> make_tensor((2, 2), device='cuda', dtype=torch.bool) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tensor([[False, False], [False, True]], device='cuda:0') """ diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py index 846d2b407684c..4ad4b4a0cbb4e 100644 --- a/torch/testing/_internal/common_cuda.py +++ b/torch/testing/_internal/common_cuda.py @@ -40,8 +40,11 @@ and torch.cuda.get_device_capability()[1] > 0) IS_JETSON = LazyVal(lambda: torch.cuda.is_available() and (torch.cuda.get_device_capability() in [(7, 2), (8, 7)] or IS_THOR)) IS_SM89 = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() == (8, 9)) +<<<<<<< HEAD IS_SM90 = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0)) IS_SM100 = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() == (10, 0)) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def evaluate_gfx_arch_within(arch_list): if not torch.cuda.is_available(): @@ -108,6 +111,7 @@ def evaluate_platform_supports_fp8(): return SM90OrLater or torch.cuda.get_device_capability() == (8, 9) return False +<<<<<<< HEAD def evaluate_platform_supports_fp8_grouped_gemm(): if torch.cuda.is_available(): if torch.version.hip: @@ -122,6 +126,11 @@ def evaluate_platform_supports_fp8_grouped_gemm(): return False def evaluate_platform_supports_mx_gemm(): +======= +PLATFORM_SUPPORTS_FP8: bool = LazyVal(lambda: evaluate_platform_supports_fp8()) + +def _platform_supports_mx_gemm(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if torch.cuda.is_available(): if torch.version.hip: if ROCM_VERSION >= (7, 0): @@ -130,6 +139,7 @@ def evaluate_platform_supports_mx_gemm(): return SM100OrLater return False +<<<<<<< HEAD def evaluate_platform_supports_mxfp8_grouped_gemm(): if torch.cuda.is_available() and not torch.version.hip: built_with_fbgemm_genai = "USE_FBGEMM_GENAI" in torch.__config__.show() @@ -141,6 +151,9 @@ def evaluate_platform_supports_mxfp8_grouped_gemm(): PLATFORM_SUPPORTS_FP8_GROUPED_GEMM: bool = LazyVal(lambda: evaluate_platform_supports_fp8_grouped_gemm()) PLATFORM_SUPPORTS_MX_GEMM: bool = LazyVal(lambda: TEST_CUDA and SM100OrLater) PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM: bool = LazyVal(lambda: evaluate_platform_supports_mxfp8_grouped_gemm()) +======= +PLATFORM_SUPPORTS_MX_GEMM: bool = LazyVal(lambda: _platform_supports_mx_gemm()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if TEST_NUMBA: try: @@ -181,6 +194,12 @@ def tf32_off(): @contextlib.contextmanager def tf32_on(self, tf32_precision=1e-5): +<<<<<<< HEAD +======= + if torch.version.hip: + hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None) + os.environ["HIPBLASLT_ALLOW_TF32"] = "1" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32 old_precision = self.precision try: @@ -189,6 +208,14 @@ def tf32_on(self, tf32_precision=1e-5): with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=True): yield finally: +<<<<<<< HEAD +======= + if torch.version.hip: + if hip_allow_tf32 is not None: + os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32 + else: + del os.environ["HIPBLASLT_ALLOW_TF32"] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul self.precision = old_precision @@ -238,7 +265,11 @@ def tf32_enabled(): # if device is specified, it will check if device is cuda # if dtype is specified, it will check if dtype is float32 or complex64 # tf32 and fp32 are different only when all the three checks pass +<<<<<<< HEAD def tf32_on_and_off(tf32_precision=1e-5, *, only_if=True): +======= +def tf32_on_and_off(tf32_precision=1e-5, only_if=True): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def with_tf32_disabled(self, function_call): with tf32_off(): function_call() @@ -298,7 +329,11 @@ def _get_torch_rocm_version(): if not TEST_WITH_ROCM or torch.version.hip is None: return (0, 0) rocm_version = str(torch.version.hip) +<<<<<<< HEAD rocm_version = rocm_version.split("-", maxsplit=1)[0] # ignore git sha +======= + rocm_version = rocm_version.split("-")[0] # ignore git sha +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return tuple(int(x) for x in rocm_version.split(".")) def _check_cusparse_generic_available(): @@ -311,7 +346,11 @@ def _check_hipsparse_generic_available(): return False rocm_version = str(torch.version.hip) +<<<<<<< HEAD rocm_version = rocm_version.split("-", maxsplit=1)[0] # ignore git sha +======= + rocm_version = rocm_version.split("-")[0] # ignore git sha +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) rocm_version_tuple = tuple(int(x) for x in rocm_version.split(".")) return not (rocm_version_tuple is None or rocm_version_tuple < (5, 1)) diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py index 8971eca1bb24e..d15d32405b8eb 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -721,9 +721,15 @@ def filter_desired_device_types(device_type_test_bases, except_for=None, only_fo intersect = set(except_for if except_for else []) & set( only_for if only_for else [] ) +<<<<<<< HEAD assert not intersect, ( f"device ({intersect}) appeared in both except_for and only_for" ) +======= + assert ( + not intersect + ), f"device ({intersect}) appeared in both except_for and only_for" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Replace your privateuse1 backend name with 'privateuse1' if is_privateuse1_backend_available(): @@ -1277,6 +1283,7 @@ def __init__(self, dep, reason): def _has_sufficient_memory(device, size): +<<<<<<< HEAD device_ = torch.device(device) device_type = device_.type if device_type in ["cuda", "xpu"]: @@ -1310,6 +1317,28 @@ def _has_sufficient_memory(device, size): raise unittest.SkipTest("TODO: Memory availability checks for XLA?") if device_type != "cpu": +======= + if torch.device(device).type == "cuda": + if not torch.cuda.is_available(): + return False + gc.collect() + torch.cuda.empty_cache() + # torch.cuda.mem_get_info, aka cudaMemGetInfo, returns a tuple of (free memory, total memory) of a GPU + if device == "cuda": + device = "cuda:0" + return ( + torch.cuda.memory.mem_get_info(device)[0] + * torch.cuda.memory.get_per_process_memory_fraction(device) + ) >= size + + if device == "xla": + raise unittest.SkipTest("TODO: Memory availability checks for XLA?") + + if device == "xpu": + raise unittest.SkipTest("TODO: Memory availability checks for Intel GPU?") + + if device != "cpu": +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise unittest.SkipTest("Unknown device type") # CPU @@ -1355,6 +1384,10 @@ def dep_fn(self, *args, **kwargs): # an additional array of the same size as the input. if inductor and torch._inductor.config.cpp_wrapper and _device != "cpu": size_bytes *= 2 +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not _has_sufficient_memory(_device, size_bytes): raise unittest.SkipTest(f"Insufficient {_device} memory") @@ -1419,9 +1452,15 @@ def __init__(self, num_required_devices): self.num_required_devices = num_required_devices def __call__(self, fn): +<<<<<<< HEAD assert not hasattr(fn, "num_required_devices"), ( f"deviceCountAtLeast redefinition for {fn.__name__}" ) +======= + assert not hasattr( + fn, "num_required_devices" + ), f"deviceCountAtLeast redefinition for {fn.__name__}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fn.num_required_devices = self.num_required_devices @wraps(fn) @@ -1486,6 +1525,7 @@ def only_fn(self, *args, **kwargs): # self.precision *2, max(1, self.precision)). class precisionOverride: def __init__(self, d): +<<<<<<< HEAD assert isinstance(d, dict), ( "precisionOverride not given a dtype : precision dict!" ) @@ -1493,6 +1533,15 @@ def __init__(self, d): assert isinstance(dtype, torch.dtype), ( f"precisionOverride given unknown dtype {dtype}" ) +======= + assert isinstance( + d, dict + ), "precisionOverride not given a dtype : precision dict!" + for dtype in d.keys(): + assert isinstance( + dtype, torch.dtype + ), f"precisionOverride given unknown dtype {dtype}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.d = d @@ -1525,12 +1574,21 @@ class toleranceOverride: def __init__(self, d): assert isinstance(d, dict), "toleranceOverride not given a dtype : tol dict!" for dtype, prec in d.items(): +<<<<<<< HEAD assert isinstance(dtype, torch.dtype), ( f"toleranceOverride given unknown dtype {dtype}" ) assert isinstance(prec, tol), ( "toleranceOverride not given a dtype : tol dict!" ) +======= + assert isinstance( + dtype, torch.dtype + ), f"toleranceOverride given unknown dtype {dtype}" + assert isinstance( + prec, tol + ), "toleranceOverride not given a dtype : tol dict!" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.d = d @@ -1558,6 +1616,7 @@ def __init__(self, *args, device_type="all"): "all dtype variants must be. " f"Received non-list non-tuple dtype {str(arg)}" ) +<<<<<<< HEAD assert all(isinstance(dtype, torch.dtype) for dtype in arg), ( f"Unknown dtype in {str(arg)}" ) @@ -1565,6 +1624,15 @@ def __init__(self, *args, device_type="all"): assert all(isinstance(arg, torch.dtype) for arg in args), ( f"Unknown dtype in {str(args)}" ) +======= + assert all( + isinstance(dtype, torch.dtype) for dtype in arg + ), f"Unknown dtype in {str(arg)}" + else: + assert all( + isinstance(arg, torch.dtype) for arg in args + ), f"Unknown dtype in {str(args)}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.args = args self.device_type = device_type @@ -1589,12 +1657,15 @@ def __init__(self, *args): super().__init__(*args, device_type="cuda") +<<<<<<< HEAD # Overrides specified dtypes on Intel GPU. class dtypesIfXPU(dtypes): def __init__(self, *args): super().__init__(*args, device_type="xpu") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class dtypesIfMPS(dtypes): def __init__(self, *args): super().__init__(*args, device_type="mps") @@ -1978,18 +2049,27 @@ def get_all_device_types() -> list[str]: and torch.cpu._is_avx2_supported() and os.getenv("ATEN_CPU_CAPABILITY") != "default" ) +<<<<<<< HEAD IS_FLEX_ATTENTION_XPU_PLATFORM_SUPPORTED = ( torch.xpu.is_available() and torch.utils._triton.has_triton() ) flex_attention_supported_platform = unittest.skipUnless( IS_FLEX_ATTENTION_XPU_PLATFORM_SUPPORTED or IS_FLEX_ATTENTION_CPU_PLATFORM_SUPPORTED +======= +flex_attention_supported_platform = unittest.skipUnless( + IS_FLEX_ATTENTION_CPU_PLATFORM_SUPPORTED +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) or ( torch.cuda.is_available() and torch.utils._triton.has_triton() and torch.cuda.get_device_capability() >= (8, 0) ), +<<<<<<< HEAD "Requires CUDA and Triton, Intel GPU and triton, or CPU with avx2 and later", +======= + "Requires CUDA and Triton, or CPU with avx2 and later", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if torch.version.hip and "gfx94" in torch.cuda.get_device_properties(0).gcnArchName: e4m3_type = torch.float8_e4m3fnuz diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index c1f75697fe889..5ac5d97f154ba 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -96,10 +96,17 @@ class TestSkip(NamedTuple): class DistTestCases: # Backends that do not support a specific collective skip_collective = {} +<<<<<<< HEAD skip_collective["allgather_coalesced"] = {"nccl", "mpi", "ucc", "xccl"} skip_collective["reduce"] = set() skip_collective["sendrecv anysource"] = {"nccl", "ucc", "xccl"} skip_collective["cpu barrier"] = {"nccl", "ucc", "xccl"} +======= + skip_collective["allgather_coalesced"] = {"nccl", "mpi", "ucc"} + skip_collective["reduce"] = set() + skip_collective["sendrecv anysource"] = {"nccl", "ucc"} + skip_collective["cpu barrier"] = {"nccl", "ucc"} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Sets showing that something is implemented backend_feature = {} @@ -253,9 +260,15 @@ def verify_ddp_error_logged(model_DDP, err_substr): if err_substr.find("\nException raised from ") == -1 else err_substr.split("\nException raised from ")[0] ) +<<<<<<< HEAD assert actual in logging_err, ( f"Did not find expected {actual} in ddp logging data error: {logging_err}" ) +======= + assert ( + actual in logging_err + ), f"Did not find expected {actual} in ddp logging data error: {logging_err}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def with_nccl_blocking_wait(func): @@ -294,9 +307,15 @@ def wrapper(*args, **kwargs): finally: # restore old values. if cached_nccl_async_error_handling is not None: +<<<<<<< HEAD os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = ( cached_nccl_async_error_handling ) +======= + os.environ[ + "TORCH_NCCL_ASYNC_ERROR_HANDLING" + ] = cached_nccl_async_error_handling +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if cached_nccl_blocking_wait is not None: os.environ["TORCH_NCCL_BLOCKING_WAIT"] = cached_nccl_blocking_wait @@ -338,6 +357,7 @@ def requires_gloo(): def requires_nccl_version(version, msg): +<<<<<<< HEAD if TEST_CUDA: if not c10d.is_nccl_available(): return skip_but_pass_in_sandcastle( @@ -358,6 +378,17 @@ def wrapper(*args, **kwargs): return wrapper return decorator +======= + if not c10d.is_nccl_available(): + return skip_but_pass_in_sandcastle( + "c10d was not compiled with the NCCL backend", + ) + else: + return skip_but_pass_in_sandcastle_if( + torch.cuda.nccl.version() < version, + f"Requires NCCL version greater than or equal to: {version}, found: {torch.cuda.nccl.version()}, reason: {msg}", + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def requires_nccl(): @@ -446,10 +477,16 @@ def sm_is_or_higher_than(device: torch.device, major: int, minor: int) -> bool: Returns True if the device's compute capability is (major, minor) or higher. Error out if the device is not a CUDA device. Returns False if device is a RoCM device. +<<<<<<< HEAD Returns True if device is a non-CUDA device. """ if device.type != "cuda": return True +======= + """ + if device.type != "cuda": + raise ValueError("sm_is_or_later() is only supported for CUDA devices") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if torch.version.hip is not None: # ROCm devices may have different compute capability codes @@ -824,7 +861,11 @@ def run_test(self, test_name: str, parent_pipe) -> None: sys.exit(TEST_SKIPS["generic"].exit_code) except Exception: logger.error( +<<<<<<< HEAD "Caught exception: \n%s exiting process %s with exit code: %s", +======= + "Caught exception: \n%s exiting " "process %s with exit code: %s", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) traceback.format_exc(), self.rank, MultiProcessTestCase.TEST_ERROR_EXIT_CODE, @@ -1161,7 +1202,11 @@ def worker(rank, world_pg, store): ) try: callback() +<<<<<<< HEAD except BaseException as ex: # noqa: B036 +======= + except BaseException as ex: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Exceptions are handled in MultiThreadedTestCase MultiThreadedTestCase.exception_queue.put((rank, sys.exc_info())) ProcessLocalGroup.exception_handle( @@ -1322,7 +1367,11 @@ def run_test_with_threaded_pg(self, test_name, rank, world_size): try: getattr(self, test_name)() +<<<<<<< HEAD except BaseException as ex: # noqa: B036 +======= + except BaseException as ex: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.exception_queue.put((rank, sys.exc_info())) ProcessLocalGroup.exception_handle( ex @@ -1468,12 +1517,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @contextmanager def _dynamo_dist_per_rank_init( +<<<<<<< HEAD rank, world_size, backend=None, init_pg=True, fake_pg=False +======= + rank, world_size, backend="nccl", init_pg=True, fake_pg=False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): # To avoid multiple inheritance from _dynamo.test_case.TestCase and MultiProcessTestCase, # Just manually implement the most important part of the dynamo behavior to reset/clear. if not fake_pg: torch.accelerator.set_device_index(rank) +<<<<<<< HEAD device_type = ( acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" @@ -1481,6 +1535,8 @@ def _dynamo_dist_per_rank_init( if backend is None: backend = c10d.get_default_backend_for_device(device_type) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "6789" if init_pg: @@ -1527,12 +1583,18 @@ def setUpClass(cls): ) ) cls.rank = 0 +<<<<<<< HEAD device = torch.accelerator.current_accelerator().type cls.device = f"{device}:{cls.rank}" cls.device_ids = None if device in cls.device else [cls.rank] c10d.init_process_group( c10d.get_default_backend_for_device(device), rank=cls.rank, world_size=1 ) +======= + cls.device = f"cuda:{cls.rank}" + cls.device_ids = None if "cuda" in cls.device else [cls.rank] + c10d.init_process_group("nccl", rank=cls.rank, world_size=1) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @classmethod def tearDownClass(cls): @@ -1568,7 +1630,11 @@ def _run( self.run_test(test_name, parent_pipe) +<<<<<<< HEAD class MultiProcContinuousTest(TestCase): +======= +class MultiProcContinousTest(TestCase): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Class variables: MAIN_PROCESS_RANK = -1 # number of test processes @@ -1611,11 +1677,16 @@ def opts(cls, high_priority_stream=False): @classmethod def _init_pg(cls, rank, world_size, rdvz_file): assert rdvz_file is not None +<<<<<<< HEAD # rank should be local_rank for tests running on <= 8gpus which is how all these tests are designed # and we expect LOCAL_RANK set by torchrun. Setting it lets init_device_mesh set the device without # issuing a warning os.environ["LOCAL_RANK"] = str(rank) store = c10d.FileStore(rdvz_file, world_size) +======= + store = c10d.FileStore(rdvz_file, world_size) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # create nccl processgroup with opts c10d.init_process_group( backend=cls.backend_str(), @@ -1630,7 +1701,11 @@ def _init_pg(cls, rank, world_size, rdvz_file): @classmethod def _run_test_given_id(cls, test_id: str, **kwargs) -> None: # self.id() == e.g. '__main__.TestDistributed.TestAdditive.test_get_rank' +<<<<<<< HEAD test_name = test_id.rsplit(".", maxsplit=1)[-1] +======= + test_name = test_id.split(".")[-1] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Get the test function from the test class self = cls(test_name) self.rank = cls.rank @@ -1641,7 +1716,10 @@ def _run_test_given_id(cls, test_id: str, **kwargs) -> None: @classmethod def _worker_loop(cls, rank, world_size, rdvz_file, task_queue, completion_queue): +<<<<<<< HEAD raised_exception = False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Sub tests are going to access these values, check first assert 0 <= rank < world_size # set class variables for the test class @@ -1652,7 +1730,11 @@ def _worker_loop(cls, rank, world_size, rdvz_file, task_queue, completion_queue) cls._init_pg(rank, world_size, rdvz_file) # End of bootstrap +<<<<<<< HEAD logger.debug("Setup complete") +======= + logger.info("Setup complete") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Loop forever, waiting for a test name to run while True: @@ -1666,6 +1748,7 @@ def _worker_loop(cls, rank, world_size, rdvz_file, task_queue, completion_queue) try: cls._run_test_given_id(test_id) completion_queue.put(test_id) +<<<<<<< HEAD except BaseException as ex: # noqa: B036 raised_exception = True # Send the exception and stack trace back to the dispatcher @@ -1684,6 +1767,15 @@ def _worker_loop(cls, rank, world_size, rdvz_file, task_queue, completion_queue) # Only call this on a clean exit path if not raised_exception: c10d.destroy_process_group() +======= + except BaseException as ex: + # Send the exception back to the dispatcher + completion_queue.put(ex) + + # Termination + logger.info("Terminating ...") + c10d.destroy_process_group() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @classmethod def _spawn_processes(cls, world_size) -> None: @@ -1714,7 +1806,13 @@ def _spawn_processes(cls, world_size) -> None: cls.processes.append(process) cls.task_queues.append(task_queue) cls.completion_queues.append(completion_queue) +<<<<<<< HEAD logger.debug("Started process %s with pid %s", rank, process.pid) # noqa: UP031 +======= + logger.info( + "Started process %s with pid %s", rank, process.pid + ) # noqa: UP031 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @classmethod def setUpClass(cls): diff --git a/torch/testing/_internal/common_dtype.py b/torch/testing/_internal/common_dtype.py index 474bb689f0ad9..a8b31ebcff601 100644 --- a/torch/testing/_internal/common_dtype.py +++ b/torch/testing/_internal/common_dtype.py @@ -121,6 +121,7 @@ def all_types_and_half(): return _all_types_and_half +<<<<<<< HEAD _all_mps_types = ( _dispatch_dtypes({torch.float, torch.half, torch.bfloat16}) + _integral_types ) @@ -134,6 +135,8 @@ def all_mps_types_and(*dtypes): return _all_mps_types + _validate_dtypes(*dtypes) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _float8_types = _dispatch_dtypes( ( torch.float8_e4m3fn, diff --git a/torch/testing/_internal/common_fsdp.py b/torch/testing/_internal/common_fsdp.py index c7274fddd6d3b..82a3f2e2f866a 100644 --- a/torch/testing/_internal/common_fsdp.py +++ b/torch/testing/_internal/common_fsdp.py @@ -6,7 +6,10 @@ import re import sys import time +<<<<<<< HEAD import unittest +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import warnings from abc import ABC, abstractmethod from contextlib import nullcontext @@ -1123,7 +1126,10 @@ def check_sharded_parity( cls.assertEqual(sharded_param.grad.to_local(), sharded_ref_grad.to_local()) +<<<<<<< HEAD @unittest.skipIf(TEST_XPU, "not-support-multithread") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class FSDPTestMultiThread(MultiThreadedTestCase): @property def world_size(self): @@ -1189,8 +1195,11 @@ def _run(cls, rank, test_name, file_name, pipe, **kwargs): # type: ignore[overr fake_pg = kwargs.get("fake_pg", False) print(f"dist init r={self.rank}, world={self.world_size}") +<<<<<<< HEAD if torch.accelerator.device_count() < self.world_size: sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Specify gloo backend to make 'init_process_group()' succeed, # Actual tests will be skipped if there is no enough GPUs. @@ -1287,10 +1296,17 @@ def _train_for_several_steps( loss = sharded_grad_scaler.scale(loss) if not mixed_precision and not use_pure_fp16: +<<<<<<< HEAD assert loss.dtype == torch.float32, ( "loss data type should be float32, as the original \ parameter data type is float32." ) +======= + assert ( + loss.dtype == torch.float32 + ), "loss data type should be float32, as the original \ + parameter data type is float32." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: if use_pure_fp16: self.assertEqual(loss.dtype, torch.float16) @@ -1356,9 +1372,15 @@ def _test_fsdp_parity( wrapper should provide data parallel semantics. If ``None``, then the callable defaults to the DDP constructor. """ +<<<<<<< HEAD assert fsdp_init_mode != FSDPInitMode.NO_FSDP, ( "Expects an FSDP init mode that wraps with FSDP" ) +======= + assert ( + fsdp_init_mode != FSDPInitMode.NO_FSDP + ), "Expects an FSDP init mode that wraps with FSDP" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if init_kwargs is None: init_kwargs = {} lr = 1e-2 diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 4c2c3e023031f..1523b45e70c06 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -1161,8 +1161,13 @@ def make_arg_conj(size): def sample_inputs_addmm(op_info, device, dtype, requires_grad, **kwargs): +<<<<<<< HEAD alpha_val = kwargs.get('alpha', 2 + 3j if dtype.is_complex else 0.6 if dtype.is_floating_point else 2) beta_val = kwargs.get('beta', 1 + 2j if dtype.is_complex else 0.2 if dtype.is_floating_point else 3) +======= + alpha_val = kwargs.get('alpha', 2 + 3j if dtype.is_complex else 0.6) + beta_val = kwargs.get('beta', 1 + 2j if dtype.is_complex else 0.2) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) tests_list = [ ((2, 3), (2, 2), (2, 3), False), ((3, 3), (3, 3), (3, 3), False), @@ -2456,7 +2461,11 @@ def error_inputs_cat(op_info, device, **kwargs): # error inputs for empty tensors yield ErrorInput(SampleInput([], kwargs={'dim': 1}), +<<<<<<< HEAD error_regex='non-empty list of Tensors', error_type=ValueError) +======= + error_regex='non-empty list of Tensors') +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # error inputs for different sizes yield ErrorInput(SampleInput([make_arg((S, S, L, L)), make_arg((S, 0, L - 1, L))], kwargs={'dim': 1}), @@ -2509,7 +2518,11 @@ def error_inputs_cat(op_info, device, **kwargs): error_regex='zero-dimensional.*cannot be concatenated') # error inputs for different dtype of out tensors +<<<<<<< HEAD d = make_tensor((2, 3), device=device, dtype=torch.double if not device.startswith("mps") else torch.float16) +======= + d = make_tensor((2, 3), device=device, dtype=torch.double) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) x = make_tensor((2, 3), device=device, dtype=torch.float32) yield ErrorInput(SampleInput(x, kwargs={'out': d}), error_type=TypeError, error_regex='invalid combination of arguments') @@ -6004,7 +6017,10 @@ def sample_inputs_repeat_interleave(op_info, device, dtype, requires_grad, **kwa yield SampleInput(make_input((2, 3, 4)), repeats=2) yield SampleInput(make_input((2, 3, 4)), repeats=2, dim=1) yield SampleInput(make_input((2, 3, 4)), repeats=torch.arange(3, device=device), dim=1) +<<<<<<< HEAD yield SampleInput(make_input((4, 1)), repeats=torch.arange(4, device=device), dim=0, output_size=6) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def sample_inputs_stft(op_info, device, dtype, requires_grad, **kwargs): @@ -6526,8 +6542,11 @@ def sample_inputs_view_as_real(op_info, device, dtype, requires_grad, **kwargs): def error_inputs_complex(op_info, device, is_ref=False, **kwargs): make_arg = partial(make_tensor, dtype=torch.float32, device=device) +<<<<<<< HEAD other_dtype = torch.float16 if device.startswith("mps") else torch.float64 other_dtype_name = "Half" if device.startswith("mps") else "Double" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if is_ref: error_float = "Expected both inputs to be Half, Float or Double tensors but got torch.float32 and torch.int32" @@ -6535,16 +6554,28 @@ def error_inputs_complex(op_info, device, is_ref=False, **kwargs): error_out = "Expected out tensor to have dtype torch.complex128 but got torch.complex64 instead" else: error_float = "Expected both inputs to be Half, Float or Double tensors but got Float and Int" +<<<<<<< HEAD error_dtype = f"Expected object of scalar type Float but got scalar type {other_dtype_name} for second argument" error_out = f"Expected object of scalar type Complex{other_dtype_name} but got scalar type ComplexFloat for argument 'out'" +======= + error_dtype = "Expected object of scalar type Float but got scalar type Double for second argument" + error_out = "Expected object of scalar type ComplexDouble but got scalar type ComplexFloat for argument 'out'" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) yield ErrorInput(SampleInput(make_arg(M, S), make_arg(M, S, dtype=torch.int)), error_type=RuntimeError, error_regex=error_float) +<<<<<<< HEAD yield ErrorInput(SampleInput(make_arg(M, S), make_arg(M, S, dtype=other_dtype)), error_type=RuntimeError, error_regex=error_dtype) yield ErrorInput(SampleInput(make_arg(M, S, dtype=other_dtype), make_arg(M, S, dtype=other_dtype), +======= + yield ErrorInput(SampleInput(make_arg(M, S), make_arg(M, S, dtype=torch.float64)), + error_type=RuntimeError, error_regex=error_dtype) + + yield ErrorInput(SampleInput(make_arg(M, S, dtype=torch.float64), make_arg(M, S, dtype=torch.float64), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out=make_arg(M, S, dtype=torch.complex64)), error_type=RuntimeError, error_regex=error_out) @@ -8359,6 +8390,7 @@ def sample_inputs_grid_sampler_2d(op_info, device, dtype, requires_grad, **kwarg align_corners, ) +<<<<<<< HEAD def sample_inputs_grid_sampler_3d(op_info, device, dtype, requires_grad, **kwargs): _make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=-1, high=1) @@ -8389,6 +8421,8 @@ def sample_inputs_grid_sampler_3d(op_info, device, dtype, requires_grad, **kwarg align_corners, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def sample_inputs_cosine_embedding_loss(op_info, device, dtype, requires_grad, **kwargs): make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) @@ -11628,6 +11662,7 @@ def reference_searchsorted(sorted_sequence, boundary, out_int32=False, right=Fal split_ret = [i.astype(np.int32) for i in split_ret] if out_int32 else split_ret return np.stack(split_ret).reshape(orig_shape) +<<<<<<< HEAD def reference_hash_tensor(tensor, dim=(), keepdim=False, mode=0): assert mode == 0, "Only mode=0 (xor_sum) is supported right now" @@ -11648,6 +11683,8 @@ def reference_hash_tensor(tensor, dim=(), keepdim=False, mode=0): return result +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def loss_reference_reduction_wrapper(fn): def wrapper(input, target, *, size_average=None, reduce=None, reduction="mean", **other_kwargs): if size_average is not None or reduce is not None: @@ -11880,11 +11917,16 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): safe_val=2)), BinaryUfuncInfo('add', # NumPy has no builtin reference for the alpha kwarg, but it is easy enough to emulate +<<<<<<< HEAD ref=lambda input, other, *, alpha=1: ( np.add(input, other) if alpha == 1 else np.add(input, np.multiply(alpha, other)) ), +======= + ref=lambda input, other, *, alpha=1: np.add(input, other) if alpha == 1 \ + else np.add(input, np.multiply(alpha, other)), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), @@ -12362,10 +12404,13 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): DecorateInfo( toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}), 'TestCommon', 'test_variant_consistency_eager', device_type='cuda'), +<<<<<<< HEAD # Higher differences starting with Zen3 or Alder Lake DecorateInfo( toleranceOverride({torch.complex64: tol(atol=4e-05, rtol=4e-06)}), 'TestDecomp', 'test_quick', device_type='cpu'), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DecorateInfo( toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}), 'TestMathBits', 'test_conj_view', device_type='cuda'), @@ -19687,7 +19732,19 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): supports_gradgrad=True, supports_out=True, check_batched_grad=False, +<<<<<<< HEAD ), +======= + skips=( + # Expected __torch_dispatch__ for aten::unbind_copy.int_out to return None + # but it returned something else instead. + DecorateInfo( + unittest.expectedFailure, + 'TestProxyTensorOpInfo', + 'test_make_fx_symbolic_exhaustive_out' + ), + )), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) OpInfo('vstack', aliases=('row_stack',), dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), @@ -20229,7 +20286,12 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): ), OpInfo('logcumsumexp', dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half), +<<<<<<< HEAD backward_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half), +======= + backward_dtypes=floating_and_complex_types_and(torch.bfloat16), + backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) supports_forward_ad=True, supports_fwgrad_bwgrad=True, skips=( @@ -20538,11 +20600,16 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): 'jiterator_binary', op=torch.cuda.jiterator._create_jit_fn( "template T binary(T x, T y, T alpha) { return x + alpha * y; }", alpha=1), +<<<<<<< HEAD ref=lambda input, other, *, alpha=1: ( np.add(input, other) if alpha == 1 else np.add(input, np.multiply(alpha, other)) ), +======= + ref=lambda input, other, *, alpha=1: np.add(input, other) if alpha == 1 \ + else np.add(input, np.multiply(alpha, other)), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool), sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=2, alpha=-3.14), supports_out=False, @@ -21069,6 +21136,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): DecorateInfo(slowTest, 'TestDecomp', 'test_comprehensive', dtypes=(torch.float32, torch.float64), active_if=IS_WINDOWS), ),), +<<<<<<< HEAD # TODO: Remove grid_sampler_3d tests once `nn.functional.grid_sample` has # MPS support for all cases. OpInfo( @@ -21085,6 +21153,8 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): DecorateInfo(unittest.skip('Skipped!'), device_type='xpu'), DecorateInfo(unittest.skip('Skipped!'), device_type='meta'), ),), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) OpInfo( "argwhere", ref=np.argwhere, @@ -21441,6 +21511,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): "TestConsistency", "test_output_match", device_type="mps"), ), ), +<<<<<<< HEAD ReductionOpInfo( 'hash_tensor', result_dtype=torch.uint64, @@ -21461,6 +21532,8 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'), ) ), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) OpInfo( "nn.functional.ctc_loss", dtypes=floating_types(), diff --git a/torch/testing/_internal/common_mkldnn.py b/torch/testing/_internal/common_mkldnn.py index 44da60a5ad1fe..af055f856f953 100644 --- a/torch/testing/_internal/common_mkldnn.py +++ b/torch/testing/_internal/common_mkldnn.py @@ -7,6 +7,12 @@ import torch +<<<<<<< HEAD +======= +# Test whether hardware BF32 math mode enabled. It is enabled only on: +# - MKLDNN is available +# - BF16 is supported by MKLDNN +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def bf32_is_not_fp32(): if not torch.backends.mkldnn.is_available(): return False @@ -15,6 +21,7 @@ def bf32_is_not_fp32(): return True +<<<<<<< HEAD def tf32_is_not_fp32(): if not torch.backends.mkldnn.is_available(): return False @@ -75,16 +82,51 @@ def tf32_on(self, tf32_precision=1e-5): def reduced_f32_on_and_off(bf32_precision=1e-2, tf32_precision=1e-5): def with_reduced_f32_disabled(self, function_call): with reduced_f32_off(): +======= +@contextlib.contextmanager +def bf32_off(): + old_matmul_precision = torch.get_float32_matmul_precision() + try: + torch.set_float32_matmul_precision("highest") + yield + finally: + torch.set_float32_matmul_precision(old_matmul_precision) + + +@contextlib.contextmanager +def bf32_on(self, bf32_precision=1e-5): + old_matmul_precision = torch.get_float32_matmul_precision() + old_precision = self.precision + try: + torch.set_float32_matmul_precision("medium") + self.precision = bf32_precision + yield + finally: + torch.set_float32_matmul_precision(old_matmul_precision) + self.precision = old_precision + + +# This is a wrapper that wraps a test to run this test twice, one with +# allow_bf32=True, another with allow_bf32=False. When running with +# allow_bf32=True, it will use reduced precision as specified by the +# argument +def bf32_on_and_off(bf32_precision=1e-5): + def with_bf32_disabled(self, function_call): + with bf32_off(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) function_call() def with_bf32_enabled(self, function_call): with bf32_on(self, bf32_precision): function_call() +<<<<<<< HEAD def with_tf32_enabled(self, function_call): with tf32_on(self, tf32_precision): function_call() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def wrapper(f): params = inspect.signature(f).parameters arg_names = tuple(params.keys()) @@ -92,11 +134,16 @@ def wrapper(f): @functools.wraps(f) def wrapped(*args, **kwargs): kwargs.update(zip(arg_names, args)) +<<<<<<< HEAD cond = True +======= + cond = bf32_is_not_fp32() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if "device" in kwargs: cond = cond and (torch.device(kwargs["device"]).type == "cpu") if "dtype" in kwargs: cond = cond and (kwargs["dtype"] == torch.float) +<<<<<<< HEAD bf32_cond = cond and bf32_is_not_fp32() tf32_cond = cond and tf32_is_not_fp32() if bf32_cond or tf32_cond: @@ -105,6 +152,11 @@ def wrapped(*args, **kwargs): with_bf32_enabled(kwargs["self"], lambda: f(**kwargs)) if tf32_cond: with_tf32_enabled(kwargs["self"], lambda: f(**kwargs)) +======= + if cond: + with_bf32_disabled(kwargs["self"], lambda: f(**kwargs)) + with_bf32_enabled(kwargs["self"], lambda: f(**kwargs)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: f(**kwargs) diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py index edb897b6f99a5..87842d0c32321 100644 --- a/torch/testing/_internal/common_modules.py +++ b/torch/testing/_internal/common_modules.py @@ -3424,8 +3424,13 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad skips=( # No channels_last support for AvgPool1d as it does not take 4D inputs DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'), +<<<<<<< HEAD # backward not supported on MPS backend DecorateInfo(skipMPS, 'TestModule', 'test_non_contiguous_tensors'),) +======= + # not supported on MPS backend + DecorateInfo(skipMPS),) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), ModuleInfo(torch.nn.BatchNorm1d, train_and_eval_differ=True, @@ -3864,6 +3869,12 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad ModuleInfo(torch.nn.MaxPool3d, module_inputs_func=module_inputs_torch_nn_MaxPool3d, gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, +<<<<<<< HEAD +======= + skips=( + # not supported on MPS backend + DecorateInfo(skipIfMPS, device_type='mps'),) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), ModuleInfo(torch.nn.KLDivLoss, module_inputs_func=module_inputs_torch_nn_KLDivLoss, @@ -4064,6 +4075,17 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad ), ModuleInfo(torch.nn.LocalResponseNorm, module_inputs_func=module_inputs_torch_nn_LocalResponseNorm, +<<<<<<< HEAD +======= + skips=( + # uses avg_pool3d which is not supported on MPS backend + DecorateInfo(expectedFailureMPS, 'TestModule', 'test_memory_format'), + DecorateInfo(expectedFailureMPS, 'TestModule', 'test_non_contiguous_tensors'), + DecorateInfo(expectedFailureMPS, 'TestModule', 'test_forward'), + DecorateInfo(expectedFailureMPS, 'TestModule', 'test_if_train_and_eval_modes_differ'), + DecorateInfo(expectedFailureMPS, 'TestModule', 'test_non_contiguous'), + DecorateInfo(expectedFailureMPS, 'TestModule', 'test_save_load'),) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), ModuleInfo(torch.nn.LayerNorm, module_inputs_func=module_inputs_torch_nn_LayerNorm, diff --git a/torch/testing/_internal/common_mps.py b/torch/testing/_internal/common_mps.py index ea07fd3c05143..8318955d96dba 100644 --- a/torch/testing/_internal/common_mps.py +++ b/torch/testing/_internal/common_mps.py @@ -12,9 +12,14 @@ def mps_ops_modifier( ops: Sequence[OpInfo], +<<<<<<< HEAD device_type: str = "mps", xfail_exclusion: Optional[list[str]] = None, sparse: bool = False, +======= + device_type: Optional[str] = None, + xfail_exclusion: Optional[list[str]] = None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Sequence[OpInfo]: if xfail_exclusion is None: xfail_exclusion = [] @@ -26,7 +31,10 @@ def mps_ops_modifier( "__rsub__", "__getitem__", "_unsafe_masked_index", +<<<<<<< HEAD "_unsafe_masked_index_put_accumulate", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "abs", "add", "alias_copy", @@ -38,7 +46,10 @@ def mps_ops_modifier( "as_strided_copy", "as_strided_scatter", "asin", +<<<<<<< HEAD "asinh", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "acos", "atan", "broadcast_tensors", @@ -76,10 +87,15 @@ def mps_ops_modifier( "H", "hsplit", "imag", +<<<<<<< HEAD "index_add", "index_copy", "index_select", "index_put", +======= + "index_copy", + "index_select", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "isfinite", "isinf", "isreal", @@ -183,6 +199,12 @@ def mps_ops_modifier( "zero_", "zeros", "zeros_like", +<<<<<<< HEAD +======= + } + + AFTER_MACOS_14_0_SUPPORTED_COMPLEX_OPS = { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "__rdiv__", "__rmatmul__", "_chunk_cat", @@ -276,6 +298,11 @@ def mps_ops_modifier( "roll", "rot90", "short", +<<<<<<< HEAD +======= + "sinh", + "sqrt", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "square", "stack", "stft", @@ -289,6 +316,88 @@ def mps_ops_modifier( "where", "byte", } +<<<<<<< HEAD +======= + # Those ops worked on MacOS12, but broken on MacOS13, see https://github.com/pytorch/pytorch/issues/85758 + MACOS_BEFORE_13_3_XFAILLIST = { + # Failures due to precision issues (due to fast-math). These has been fixed in MacOS 13.3+ + "cdist": [torch.float32], + # CPU Error: cpu not giving nan for x/0.0 + "atan2": [ + torch.bool, + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + torch.int8, + ], + # test blow pass on macOS 12 as it falls back to cpu + # Argsort case using duplicate indices (undefined behaviour): + # - CPU output: tensor([2546, 6917, 3181, ..., 7128, 5133, 30], device='cpu') + # - MPS output: tensor([2546, 6917, 3181, ..., 7128, 30, 5133], device='mps:0') + # Elements from index 30 and 5133 are both equal. + # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour. + "argsort": [torch.float16, torch.int8, torch.uint8, torch.bool], + # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices. + # The values of the sorted tensor match the CPU, + # but in case of the returned indices this results in undefined behaviour. + "sort": [torch.int8, torch.uint8, torch.bool, torch.float16], + # Unsupported dtypes + "cumsum": [torch.int64], + "cumprod": [torch.int64], + "cumulative_trapezoid": [torch.int64], + "masked.cumsum": [torch.int64], + "masked.cumprod": [torch.int64], + "linalg.vander": [torch.int64], + # Fail with `Expected 1.0 but got nan.` for empty tensors + # Caused by sample input at index 23: SampleInput( + # input=Tensor[size=(), device="mps:0", dtype=torch.float32], + # args=(0), + # kwargs={'mask': 'Tensor[size=(), device="mps:0", dtype=torch.bool]'}, + # broadcasts_input=False, name='') + "masked.softmin": [torch.float32, torch.float16], + "masked.softmax": [torch.float32, torch.float16], + "masked.log_softmax": [torch.float32, torch.float16], + } + + MACOS_AFTER_13_1_XFAILLIST = { + # before macOS 13.2 it falls back to cpu and pass the forward pass + "grid_sampler_2d": [ + torch.float32, + torch.float16, + torch.bfloat16, + ], # Unsupported Border padding mode + } + + MACOS_13_3_XFAILLIST = { + # Failure due to precision issue for fp16 + # on both cpu and mps there are test cases that might produce inf result + # 'nn.functional.pairwise_distance': [torch.float16], + # test blow pass on macOS 12 as it falls back to cpu + # Argsort case using duplicate indices (undefined behaviour): + # - CPU output: tensor([2546, 6917, 3181, ..., 7128, 5133, 30], device='cpu') + # - MPS output: tensor([2546, 6917, 3181, ..., 7128, 30, 5133], device='mps:0') + # Elements from index 30 and 5133 are both equal. + # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour. + "argsort": [ + torch.float16, + torch.int8, + torch.uint8, + torch.bool, + torch.bfloat16, + ], + # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices. + # The values of the sorted tensor match the CPU, + # but in case of the returned indices this results in undefined behaviour. + "sort": [ + torch.int8, + torch.uint8, + torch.bool, + torch.float16, + torch.bfloat16, + ], + } +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MACOS_BEFORE_14_4_XFAILLIST = { # These ops work fine in 14.4 but fail in 14.2 or 13.x @@ -296,7 +405,11 @@ def mps_ops_modifier( } # Those ops are not expected to work +<<<<<<< HEAD UNIMPLEMENTED_XFAILLIST: dict[str, Optional[list]] = { +======= + UNIMPLEMENTED_XFAILLIST = { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Failures due to lack of op implementation on MPS backend "logspace": None, "logspacetensor_overload": None, @@ -311,13 +424,23 @@ def mps_ops_modifier( "gcd": None, "geqrf": None, "nn.functional.grid_sample": None, # Unsupported Border padding mode +<<<<<<< HEAD "hash_tensor": None, "heaviside": None, +======= + "heaviside": None, + "igamma": None, + "igammac": None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "index_reduceprod": None, "index_reducemean": None, "index_reduceamax": None, "index_reduceamin": None, +<<<<<<< HEAD # "kthvalue": None, +======= + "kthvalue": None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "lcm": None, "linalg.cond": None, "linalg.eigh": None, @@ -336,10 +459,18 @@ def mps_ops_modifier( "linalg.qr": None, "linalg.svdvals": None, "linalg.vecdot": None, +<<<<<<< HEAD +======= + "logcumsumexp": None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "lu_solve": None, "masked.median": None, "matrix_exp": None, "mode": None, +<<<<<<< HEAD +======= + "native_dropout_backward": None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "normnuc": None, "nn.functional.fractional_max_pool2d": None, "nn.functional.fractional_max_pool3d": None, @@ -347,8 +478,21 @@ def mps_ops_modifier( "nn.functional.adaptive_max_pool3d": None, "nn.functional.interpolatearea": None, "nn.functional.interpolatebicubic": [torch.uint8], +<<<<<<< HEAD "nn.functional.ctc_loss": None, "nn.functional.embedding_bag": None, +======= + "nn.functional.max_unpool1dgrad": None, + "nn.functional.max_unpool2dgrad": None, + "nn.functional.max_unpool3dgrad": None, + "nn.functional.avg_pool3d": None, + "nn.functional.ctc_loss": None, + "nn.functional.embedding_bag": None, + "nn.functional.max_pool3d": None, + "nn.functional.max_unpool1d": None, + "nn.functional.max_unpool2d": None, + "nn.functional.max_unpool3d": None, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "nn.functional.multi_margin_loss": None, "nn.functional.multilabel_margin_loss": None, "nn.functional.pdist": None, @@ -377,7 +521,10 @@ def mps_ops_modifier( "special.airy_ai": None, "special.erfcx": None, "special.laguerre_polynomial_l": None, +<<<<<<< HEAD "special.legendre_polynomial_p": None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "special.log_ndtr": None, "special.ndtri": None, "svd_lowrank": None, @@ -418,7 +565,13 @@ def mps_ops_modifier( torch.float16, ], # Unsupported dtypes +<<<<<<< HEAD + "histc": [torch.float16, torch.bfloat16], +======= + "dot": [torch.int64] if MACOS_VERSION < 14.0 else [], "histc": [torch.float16, torch.bfloat16], + "index_add": [torch.int64], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # GEMM on MPS is not supported for integral types "nn.functional.linear": [ torch.int16, @@ -427,9 +580,25 @@ def mps_ops_modifier( torch.uint8, torch.int8, ], +<<<<<<< HEAD "addbmm": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], "baddbmm": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], "mat": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], +======= + "addmmdecomposed": [ + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + torch.int8, + ], + "addbmm": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + "addmm": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + "baddbmm": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + "mat": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + "matmul": [torch.int64] if MACOS_VERSION < 14.0 else [], + "__rmatmul__": [torch.int64] if MACOS_VERSION < 14.0 else [], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # returned output on CPU is float64 "bincount": [ torch.int16, @@ -438,6 +607,7 @@ def mps_ops_modifier( torch.uint8, torch.int8, ], +<<<<<<< HEAD } UNIMPLEMENTED_XFAILLIST_SPARSE: dict[str, Optional[list]] = { "logspace": None, @@ -446,6 +616,56 @@ def mps_ops_modifier( "linalg.eigvals": None, "put": None, } +======= + # round not working properly for float16 and bfloat16 + "round": [torch.float16, torch.bfloat16], + "rounddecimals_0": [torch.bfloat16], + # atomic operations not supported + "_unsafe_masked_index_put_accumulate": [ + torch.int8, + torch.uint8, + torch.int16, + torch.int64, + ], + } + + if MACOS_VERSION < 14.0: + # FFT and BFloat16 support was added in MacOS 14 + UNIMPLEMENTED_XFAILLIST.update( + { + "bfloat16": None, + "fft.fft": None, + "fft.fft2": None, + "fft.fftn": None, + "fft.hfft": None, + "fft.hfft2": None, + "fft.hfftn": None, + "fft.ifft": None, + "fft.ifft2": None, + "fft.ifftn": None, + "fft.ihfft": None, + "fft.ihfft2": None, + "fft.ihfftn": None, + "fft.irfft": None, + "fft.irfft2": None, + "fft.irfftn": None, + "fft.rfft": None, + "fft.rfft2": None, + "fft.rfftn": None, + "stft": None, + # Error in TestConsistencyCPU.test_output_match_isin_cpu fails for integers, + # not reproducible in later OS. Added assert to op if used in < 14.0 + "isin": [ + torch.int64, + torch.int32, + torch.int16, + torch.uint8, + torch.int8, + ], + "nn.functional.max_pool2d": [torch.uint8], + } + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if MACOS_VERSION < 15.0: UNIMPLEMENTED_XFAILLIST.update( @@ -454,10 +674,15 @@ def mps_ops_modifier( "nanquantile": None, } ) +<<<<<<< HEAD if sparse: UNIMPLEMENTED_XFAILLIST.update(UNIMPLEMENTED_XFAILLIST_SPARSE) UNDEFINED_XFAILLIST: dict[str, Optional[list]] = { +======= + + UNDEFINED_XFAILLIST = { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Top 60 operators # topk fails with duplicate indices "topk": [ @@ -504,6 +729,15 @@ def mps_ops_modifier( torch.float16, torch.bfloat16, ], +<<<<<<< HEAD +======= + "index_put": [ + torch.uint8, + torch.int8, + torch.int16, + torch.int64, + ], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # zero to negative integer powers are undefined "__rpow__": [torch.int8, torch.int16, torch.int32, torch.int64], "resize_": [torch.float16, torch.float32, torch.bfloat16], @@ -534,12 +768,17 @@ def mps_ops_modifier( ], } +<<<<<<< HEAD ON_MPS_XFAILLIST: dict[str, Optional[list]] = { +======= + ON_MPS_XFAILLIST = { +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Failures due to lack of implementation of downstream functions on MPS backend # TODO: remove these once downstream function 'aten::_linalg_svd.U' have been implemented "linalg.matrix_rank": None, # Exception: Caused by `torch.arange(-8.001, -4.0, dtype=torch.uint8, device="mps")` "arange": [torch.uint8], +<<<<<<< HEAD # before macOS 13.2 it falls back to cpu and pass the forward pass "grid_sampler_2d": [ torch.float32, @@ -572,6 +811,8 @@ def mps_ops_modifier( torch.float16, torch.bfloat16, ], +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } EMPTY_OPS_SKIPLIST = { @@ -593,10 +834,13 @@ def mps_ops_modifier( # Unsupported # This doesn't work on M1, but is partially working on M2 with the exception of torch.float16 "nn.functional.conv3d": None, +<<<<<<< HEAD # The CPU impl of grid_sampler_3d does not use opmath_t, so it has a # large amount of error compared with the MPS impl for half # precision types. So we have to skip these for now. "grid_sampler_3d": [torch.float16, torch.bfloat16], +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } def addDecorator(op: OpInfo, d: DecorateInfo) -> None: @@ -607,6 +851,7 @@ def addDecorator(op: OpInfo, d: DecorateInfo) -> None: for op in ops: key = op.name + op.variant_test_name +<<<<<<< HEAD addDecorator( op, DecorateInfo( @@ -631,6 +876,8 @@ def addDecorator(op: OpInfo, d: DecorateInfo) -> None: ], ), ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if key in EMPTY_OPS_SKIPLIST: addDecorator( op, @@ -667,8 +914,53 @@ def addDecorator(op: OpInfo, d: DecorateInfo) -> None: ), ) +<<<<<<< HEAD # If ops is not supported for complex types, expect it to fail if key not in SUPPORTED_COMPLEX_OPS: +======= + if ( + key in MACOS_BEFORE_13_3_XFAILLIST + and key not in xfail_exclusion + and (torch.backends.mps.is_macos13_or_newer() and MACOS_VERSION < 13.3) + ): + addDecorator( + op, + DecorateInfo( + unittest.expectedFailure, + dtypes=MACOS_BEFORE_13_3_XFAILLIST[key], + ), + ) + + if ( + key in MACOS_AFTER_13_1_XFAILLIST + and key not in xfail_exclusion + and torch.backends.mps.is_macos13_or_newer(2) + ): + addDecorator( + op, + DecorateInfo( + unittest.expectedFailure, dtypes=MACOS_AFTER_13_1_XFAILLIST[key] + ), + ) + + if ( + key in MACOS_13_3_XFAILLIST + and key not in xfail_exclusion + and (MACOS_VERSION >= 13.3) + ): + addDecorator( + op, + DecorateInfo( + unittest.expectedFailure, dtypes=MACOS_13_3_XFAILLIST[key] + ), + ) + + # If ops is not supported for complex types, expect it to fail + if key not in SUPPORTED_COMPLEX_OPS and ( + key not in AFTER_MACOS_14_0_SUPPORTED_COMPLEX_OPS + or MACOS_VERSION < 14.0 + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) addDecorator( op, DecorateInfo( @@ -691,10 +983,14 @@ def mps_ops_grad_modifier(ops: Sequence[OpInfo]) -> Sequence[OpInfo]: "scalar_tensor": [torch.float16, torch.float32], "cdist": [torch.float32], "masked.scatter": [torch.float16, torch.float32], +<<<<<<< HEAD "grid_sampler_3d": None, "index_fill": [torch.float16, torch.float32], # missing `aten::_unique`. "igamma": None, # currently not supported for any device "igammac": None, # currently not supported for any device +======= + "index_fill": [torch.float16, torch.float32], # missing `aten::_unique`. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "linalg.solve": [torch.float16, torch.float32], # missing `aten::lu_solve`. "linalg.solve_ex": [ torch.float16, @@ -715,11 +1011,14 @@ def mps_ops_grad_modifier(ops: Sequence[OpInfo]) -> Sequence[OpInfo]: "special.i1e": [torch.float16], # "i1e_backward" not implemented for 'Half' # Correctness issues "atanh": [torch.float32], +<<<<<<< HEAD # Same issue as `argsort` and `sort` with duplicate elements (undefined behaviour). # Forward pass is passing since `msort` doesn't return the indices, just the values, which match the CPU. # On the backward pass for `sort` both are used (values and indices), thus resulting in a issmatch between CPU and MPS. # Running `msort` with stable `sort` passes. "msort": [torch.float16], +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Random output "exponential": [torch.float16, torch.float32], # CPU errors @@ -756,10 +1055,34 @@ def mps_ops_grad_modifier(ops: Sequence[OpInfo]) -> Sequence[OpInfo]: "signal.windows.kaiser": [torch.float32], "signal.windows.nuttall": [torch.float32], "eye": [torch.float16, torch.float32], +<<<<<<< HEAD +======= + # round not working properly for float16 + "round": [torch.float16], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # topk fails with duplicate indices "topk": [torch.float16], } +<<<<<<< HEAD +======= + MACOS_BEFORE_13_3_XFAILLIST_GRAD = { + # Failures due to precision issues (may be fast-math). These has been fixed in MacOS 14 + "masked.softmin": [torch.float32, torch.float16], + "masked.softmax": [torch.float32, torch.float16], + "masked.log_softmax": [torch.float32, torch.float16], + "atanh": [torch.float16], + "triangular_solve": [torch.float32], + # Unsupported Border padding mode, forward pass success as fallback to cpu + "grid_sampler_2d": [torch.float32, torch.float16, torch.bfloat16], + # Same issue as `argsort` and `sort` with duplicate elements (undefined behaviour). + # Forward pass is passing since `msort` doesn't return the indices, just the values, which match the CPU. + # On the backward pass for `sort` both are used (values and indices), thus resulting in a issmatch between CPU and MPS. + # Running `msort` with stable `sort` passes. + "msort": [torch.float16], + } + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) SKIPLIST_GRAD = { "nn.functional.pairwise_distance": [torch.float16], # failed assertion `destination datatype must be fp32' @@ -771,6 +1094,17 @@ def mps_ops_grad_modifier(ops: Sequence[OpInfo]) -> Sequence[OpInfo]: "nn.functional.conv_transpose3d": [torch.float16], } +<<<<<<< HEAD +======= + MACOS_13_3_XFAILLIST_GRAD = { + # Same issue as `argsort` and `sort` with duplicate elements (undefined behaviour). + # Forward pass is passing since `msort` doesn't return the indices, just the values, which match the CPU. + # On the backward pass for `sort` both are used (values and indices), thus resulting in a issmatch between CPU and MPS. + # Running `msort` with stable `sort` passes. + "msort": [torch.float16], + } + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ON_MPS_XFAILLIST = { # Failures due to lack of implementation of downstream functions on MPS backend # TODO: remove these once downstream function 'aten::_linalg_svd.U' have been implemented @@ -801,6 +1135,27 @@ def addDecorator(op: OpInfo, d: DecorateInfo) -> None: ), ) +<<<<<<< HEAD +======= + if key in MACOS_BEFORE_13_3_XFAILLIST_GRAD and ( + torch.backends.mps.is_macos13_or_newer() and MACOS_VERSION < 13.3 + ): + addDecorator( + op, + DecorateInfo( + unittest.expectedFailure, + dtypes=MACOS_BEFORE_13_3_XFAILLIST_GRAD[key], + ), + ) + + if key in MACOS_13_3_XFAILLIST_GRAD and (MACOS_VERSION >= 13.3): + addDecorator( + op, + DecorateInfo( + unittest.expectedFailure, dtypes=MACOS_13_3_XFAILLIST_GRAD[key] + ), + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ops def mps_ops_error_inputs_modifier(ops: Sequence[OpInfo]) -> Sequence[OpInfo]: @@ -815,6 +1170,11 @@ def mps_ops_error_inputs_modifier(ops: Sequence[OpInfo]) -> Sequence[OpInfo]: "clamp_min", "masked_scatter", # unsupported float64 dtype +<<<<<<< HEAD +======= + "cat", + "complex", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "multinomial", "nn.functional.conv1d", "nn.functional.conv2d", @@ -828,6 +1188,11 @@ def mps_ops_error_inputs_modifier(ops: Sequence[OpInfo]) -> Sequence[OpInfo]: "aminmax", # memory overlapping checks "index_select", +<<<<<<< HEAD +======= + # unimplemented + "logcumsumexp", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } def addDecorator(op: OpInfo, d: DecorateInfo) -> None: @@ -839,6 +1204,7 @@ def addDecorator(op: OpInfo, d: DecorateInfo) -> None: addDecorator(op, DecorateInfo(unittest.expectedFailure)) return ops +<<<<<<< HEAD else: def mps_ops_modifier( @@ -848,3 +1214,5 @@ def mps_ops_modifier( sparse: bool = False, ) -> Sequence[OpInfo]: return ops +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py index 1e4368380bb59..92cb7772e8102 100644 --- a/torch/testing/_internal/common_optimizers.py +++ b/torch/testing/_internal/common_optimizers.py @@ -20,7 +20,10 @@ AdamW, ASGD, LBFGS, +<<<<<<< HEAD Muon, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) NAdam, Optimizer, RAdam, @@ -246,9 +249,14 @@ def test_wrapper(*args, **kwargs): # Helper function for generating error inputs for all optimizers, used below. def get_error_inputs_for_all_optims(device, dtype): if _get_device_type(device) == "cpu": +<<<<<<< HEAD # Creating 2D parameters for compatibility with Muon. sample_param = Parameter(torch.randn(1, 1, device=device, dtype=dtype)) sample_param2 = Parameter(torch.randn(1, 1, device=device, dtype=dtype)) +======= + sample_param = Parameter(torch.randn(1, device=device, dtype=dtype)) + sample_param2 = Parameter(torch.randn(1, device=device, dtype=dtype)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return [ ErrorOptimizerInput( OptimizerInput( @@ -835,6 +843,7 @@ def optim_error_inputs_func_lbfgs(device, dtype): return error_inputs +<<<<<<< HEAD def optim_inputs_func_muon(device, dtype=None): return [ OptimizerInput(params=None, kwargs={}, desc="default"), @@ -910,6 +919,8 @@ def optim_error_inputs_func_muon(device, dtype): return error_inputs +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def optim_inputs_func_nadam(device, dtype=None): cuda_supported_configs = [ OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"), @@ -1345,9 +1356,15 @@ def _get_optim_inputs_including_global_cliquey_kwargs( trivial. That said, we sometimes want to test for all possible configs on an optimizer including all supported flags, so this helper returns all optim inputs. """ +<<<<<<< HEAD assert all(x in ["foreach", "fused", "differentiable"] for x in skip), ( "skip must be a subset of ['foreach', 'fused', 'differentiable']" ) +======= + assert all( + x in ["foreach", "fused", "differentiable"] for x in skip + ), "skip must be a subset of ['foreach', 'fused', 'differentiable']" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) optim_inputs = optim_info.optim_inputs_func(device) @@ -1947,6 +1964,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( ), ), OptimizerInfo( +<<<<<<< HEAD Muon, optim_inputs_func=optim_inputs_func_muon, optim_error_inputs_func=optim_error_inputs_func_muon, @@ -1968,6 +1986,8 @@ def _get_optim_inputs_including_global_cliquey_kwargs( ), ), OptimizerInfo( +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) NAdam, optim_inputs_func=optim_inputs_func_nadam, optim_error_inputs_func=optim_error_inputs_func_nadam, diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py index 600848b80a7e8..fdd171b81fa94 100644 --- a/torch/testing/_internal/common_quantization.py +++ b/torch/testing/_internal/common_quantization.py @@ -611,7 +611,11 @@ def _group_quantize_tensor_symmetric(w, n_bit=4, groupsize=32): def _dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): +<<<<<<< HEAD # source: https://github.com/meta-pytorch/gpt-fast/blob/main/quantize.py +======= + # source: https://github.com/pytorch-labs/gpt-fast/blob/main/quantize.py +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # default setup for affine quantization of activations x_dtype = x.dtype x = x.float() @@ -3394,7 +3398,11 @@ def get_default_quantizer(is_qat, is_dynamic, inputs): maybe_no_grad = contextlib.nullcontext() if is_qat else torch.no_grad() with maybe_no_grad: +<<<<<<< HEAD export_model = export_for_training(mod, inputs, strict=True).module(check_guards=False) +======= + export_model = export_for_training(mod, inputs, strict=True).module() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) quantizer = ( quantizer if quantizer diff --git a/torch/testing/_internal/common_quantized.py b/torch/testing/_internal/common_quantized.py index 0dc9d4cb3db72..c4075999a3a4d 100644 --- a/torch/testing/_internal/common_quantized.py +++ b/torch/testing/_internal/common_quantized.py @@ -10,6 +10,10 @@ from torch.testing._internal.common_utils import TEST_WITH_TSAN, IS_PPC, IS_MACOS, IS_WINDOWS supported_qengines = torch.backends.quantized.supported_engines +<<<<<<< HEAD +======= +supported_qengines.remove('none') +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Note: We currently do not run QNNPACK tests on WINDOWS and MACOS as it is flaky. Issue #29326 # QNNPACK is not supported on PPC if 'qnnpack' in supported_qengines and any([IS_PPC, TEST_WITH_TSAN, IS_MACOS, IS_WINDOWS]): @@ -479,6 +483,7 @@ def to_blocked(input_matrix) -> torch.Tensor: rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) return rearranged.flatten() +<<<<<<< HEAD # This function is extracted from https://github.com/pytorch/ao/blob/v0.12.0/torchao/prototype/mx_formats/mx_tensor.py#L142 def to_mxfp8( @@ -586,3 +591,5 @@ def generate_jagged_offs(E, M, multiple_of=16, dtype=torch.int32, device="cuda") selected_values, _ = torch.sort(selected_values) return selected_values.to(dtype).to(device) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 38dc910f595e4..41b5de86d1c66 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -49,6 +49,10 @@ from typing import ( Any, Callable, +<<<<<<< HEAD +======= + Dict, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Optional, TypeVar, Union, @@ -101,9 +105,14 @@ except ImportError: has_pytest = False +<<<<<<< HEAD MI350_ARCH = ("gfx950",) MI300_ARCH = ("gfx942",) MI200_ARCH = ("gfx90a") +======= + +MI300_ARCH = ("gfx940", "gfx941", "gfx942") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) NAVI_ARCH = ("gfx1030", "gfx1100", "gfx1101", "gfx1200", "gfx1201") NAVI3_ARCH = ("gfx1100", "gfx1101") NAVI4_ARCH = ("gfx1200", "gfx1201") @@ -308,7 +317,11 @@ def maybe_load_json(filename): if os.getenv("DISABLED_TESTS_FILE", ""): disabled_tests_dict = maybe_load_json(os.getenv("DISABLED_TESTS_FILE", "")) +<<<<<<< HEAD NATIVE_DEVICES = ('cpu', 'cuda', 'xpu', 'meta', 'mps', torch._C._get_privateuse1_backend_name()) +======= +NATIVE_DEVICES = ('cpu', 'cuda', 'xpu', 'meta', torch._C._get_privateuse1_backend_name()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # used for managing devices testing for torch profiler UTs # for now cpu, cuda and xpu are added for testing torch profiler UTs @@ -340,10 +353,16 @@ def extract_test_fn() -> Optional[Callable]: self_val = frame.f_locals["self"] if isinstance(self_val, unittest.TestCase): test_id = self_val.id() +<<<<<<< HEAD *_, cls_name, test_name = test_id.rsplit('.', 2) if cls_name == type(self_val).__name__ and test_name.startswith("test"): test_fn = getattr(self_val, test_name).__func__ return test_fn +======= + test_name = test_id.split('.')[2] + test_fn = getattr(self_val, test_name).__func__ + return test_fn +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) except Exception: pass return None @@ -1708,10 +1727,13 @@ def serialTest(condition=True): """ Decorator for running tests serially. Requires pytest """ +<<<<<<< HEAD # If one apply decorator directly condition will be callable # And test will essentially be essentially skipped, which is undesirable assert type(condition) is bool +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def decorator(fn): if has_pytest and condition: return pytest.mark.serial(fn) @@ -1931,6 +1953,7 @@ def wrapper(*args, **kwargs): return dec_fn(func) return dec_fn +<<<<<<< HEAD def getRocmArchName(device_index: int = 0): return torch.cuda.get_device_properties(device_index).gcnArchName @@ -1938,13 +1961,23 @@ def isRocmArchAnyOf(arch: tuple[str, ...]): rocmArch = getRocmArchName() return any(x in rocmArch for x in arch) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def skipIfRocmArch(arch: tuple[str, ...]): def dec_fn(fn): @wraps(fn) def wrap_fn(self, *args, **kwargs): +<<<<<<< HEAD if TEST_WITH_ROCM and isRocmArchAnyOf(arch): reason = f"skipIfRocm: test skipped on {arch}" raise unittest.SkipTest(reason) +======= + if TEST_WITH_ROCM: + prop = torch.cuda.get_device_properties(0) + if prop.gcnArchName.split(":")[0] in arch: + reason = f"skipIfRocm: test skipped on {arch}" + raise unittest.SkipTest(reason) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return fn(self, *args, **kwargs) return wrap_fn return dec_fn @@ -1962,9 +1995,17 @@ def runOnRocmArch(arch: tuple[str, ...]): def dec_fn(fn): @wraps(fn) def wrap_fn(self, *args, **kwargs): +<<<<<<< HEAD if TEST_WITH_ROCM and not isRocmArchAnyOf(arch): reason = f"skipIfRocm: test only runs on {arch}" raise unittest.SkipTest(reason) +======= + if TEST_WITH_ROCM: + prop = torch.cuda.get_device_properties(0) + if prop.gcnArchName.split(":")[0] not in arch: + reason = f"skipIfRocm: test only runs on {arch}" + raise unittest.SkipTest(reason) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return fn(self, *args, **kwargs) return wrap_fn return dec_fn @@ -1972,6 +2013,7 @@ def wrap_fn(self, *args, **kwargs): def xfailIfS390X(func): return unittest.expectedFailure(func) if IS_S390X else func +<<<<<<< HEAD def xfailIf(condition): def wrapper(func): if condition: @@ -1980,6 +2022,8 @@ def wrapper(func): return func return wrapper +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def skipIfXpu(func=None, *, msg="test doesn't currently work on the XPU stack"): def dec_fn(fn): reason = f"skipIfXpu: {msg}" @@ -2024,18 +2068,27 @@ def wrapper(*args, **kwargs): fn(*args, **kwargs) return wrapper +<<<<<<< HEAD def getRocmVersion() -> tuple[int, int]: from torch.testing._internal.common_cuda import _get_torch_rocm_version rocm_version = _get_torch_rocm_version() return (rocm_version[0], rocm_version[1]) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Skips a test on CUDA if ROCm is available and its version is lower than requested. def skipIfRocmVersionLessThan(version=None): def dec_fn(fn): @wraps(fn) def wrap_fn(self, *args, **kwargs): if TEST_WITH_ROCM: +<<<<<<< HEAD rocm_version_tuple = getRocmVersion() +======= + rocm_version = str(torch.version.hip) + rocm_version = rocm_version.split("-")[0] # ignore git sha + rocm_version_tuple = tuple(int(x) for x in rocm_version.split(".")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if rocm_version_tuple is None or version is None or rocm_version_tuple < tuple(version): reason = f"ROCm {rocm_version_tuple} is available but {version} required" raise unittest.SkipTest(reason) @@ -2043,6 +2096,26 @@ def wrap_fn(self, *args, **kwargs): return wrap_fn return dec_fn +<<<<<<< HEAD +======= +def skipIfRocmVersionAndArch(version=None, arch=None): + def dec_fn(fn): + @wraps(fn) + def wrap_fn(self, *args, **kwargs): + if TEST_WITH_ROCM: + rocm_version = str(torch.version.hip) + rocm_version = rocm_version.split("-")[0] # ignore git sha + rocm_version_tuple = tuple(int(x) for x in rocm_version.split(".")) + if rocm_version_tuple is None or version is None or rocm_version_tuple < tuple(version): + prop = torch.cuda.get_device_properties(0) + if prop.gcnArchName.split(":")[0] in arch: + reason = f"ROCm {version} and {arch} combination not supported" + raise unittest.SkipTest(reason) + return fn(self, *args, **kwargs) + return wrap_fn + return dec_fn + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def skipIfNotMiopenSuggestNHWC(fn): @wraps(fn) def wrapper(*args, **kwargs): @@ -2067,6 +2140,7 @@ def wrapper(*args, **kwargs): return dec_fn(func) return dec_fn +<<<<<<< HEAD def requires_cuda_p2p_access(): cuda_p2p_access_available = ( torch.cuda.is_available() @@ -2087,6 +2161,8 @@ def requires_cuda_p2p_access(): "cuda p2p access is not available", ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Reverts the linalg backend back to default to make sure potential failures in one # test do not affect other tests def setLinalgBackendsToDefaultFinally(fn): @@ -2601,18 +2677,32 @@ def __exit__(self, exc_type, exc_value, traceback): msg = ("CUDA caching allocator reports a memory leak not " # type: ignore[possibly-undefined] f"verified by the driver API in {self.name}! " f"Caching allocator allocated memory was {self.caching_allocator_befores[i]} " +<<<<<<< HEAD f"and is now reported as {caching_allocator_mem_allocated} " # type: ignore[possibly-undefined] f"on device {i}. " f"CUDA driver allocated memory was {self.driver_befores[i]} and is now {driver_mem_allocated}.") # type: ignore[possibly-undefined] warnings.warn(msg) elif caching_allocator_discrepancy and driver_discrepancy: # type: ignore[possibly-undefined] +======= + f"and is now reported as {caching_allocator_mem_allocated} " + f"on device {i}. " + f"CUDA driver allocated memory was {self.driver_befores[i]} and is now {driver_mem_allocated}.") + warnings.warn(msg) + elif caching_allocator_discrepancy and driver_discrepancy: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # A caching allocator discrepancy validated by the driver API is a # failure (except on ROCm, see below) msg = (f"CUDA driver API confirmed a leak in {self.name}! " # type: ignore[possibly-undefined] f"Caching allocator allocated memory was {self.caching_allocator_befores[i]} " +<<<<<<< HEAD f"and is now reported as {caching_allocator_mem_allocated} " # type: ignore[possibly-undefined] f"on device {i}. " f"CUDA driver allocated memory was {self.driver_befores[i]} and is now {driver_mem_allocated}.") # type: ignore[possibly-undefined] +======= + f"and is now reported as {caching_allocator_mem_allocated} " + f"on device {i}. " + f"CUDA driver allocated memory was {self.driver_befores[i]} and is now {driver_mem_allocated}.") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise RuntimeError(msg) @@ -3342,7 +3432,11 @@ def expect_failure(f, file_name): def wrapper(*args, **kwargs): try: f(*args, **kwargs) +<<<<<<< HEAD except BaseException as e: # noqa: B036 +======= + except BaseException as e: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.skipTest(e) raise RuntimeError(f"Unexpected success, please remove `{file_name}`") return wrapper @@ -3364,7 +3458,11 @@ def ignore_failure(f, file_name): def wrapper(*args, **kwargs): try: f(*args, **kwargs) +<<<<<<< HEAD except BaseException as e: # noqa: B036 +======= + except BaseException as e: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.skipTest(e) method = getattr(self, self._testMethodName) if getattr(method, "__unittest_expecting_failure__", False): @@ -3394,8 +3492,11 @@ def wrapper(*args, **kwargs): if strict_mode or should_reset_dynamo: torch._dynamo.reset() +<<<<<<< HEAD elif torch._dynamo.config.compiled_autograd: torch._dynamo.compiled_autograd.reset() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Early terminate test if necessary. If using pytest, use the -x flag instead if using_unittest and self._should_stop_test_suite(): @@ -5760,6 +5861,7 @@ def load_inline(*args, **kwargs): return cpp_extension.load_inline(*args, **kwargs) return func(*args, load_inline=load_inline, **kwargs) +<<<<<<< HEAD return wrapper def recover_orig_fp32_precision(fn): @@ -5795,3 +5897,30 @@ def wrap_fn(self, *args, **kwargs): raise unittest.SkipTest("Python version mismatch") return wrap_fn return dec_fn +======= + + return wrapper + +# Decorator to patch multiple test class members for the duration of the subtest +def patch_test_members(updates: Dict[str, Any]): + def decorator(test_func): + @wraps(test_func) + def wrapper(self, *args, **kwargs): + # Store the original values of the specified members + original_values = {member: getattr(self, member) for member in updates} + + # Update the members before running the subtest + for member, value in updates.items(): + setattr(self, member, value) + + # Run the test function, allowing subtests to run + try: + return test_func(self, *args, **kwargs) + finally: + # Restore the original values of the specified members after the subtest finishes + for member, original_value in original_values.items(): + setattr(self, member, original_value) + + return wrapper + return decorator +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py b/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py index 60c744ac1a84c..6315d263a4a1c 100644 --- a/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py +++ b/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py @@ -22,7 +22,11 @@ def world_size(self): return TEST_GPU_NUM def init_pg(self, backend="nccl"): +<<<<<<< HEAD if backend not in ["nccl", "gloo", "mpi", "hccl"]: +======= + if backend not in ["nccl", "gloo", "mpi"]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise RuntimeError(f"Backend {backend} not supported!") dist.init_process_group( diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py index e25e08fbf5090..c6d35590e9418 100644 --- a/torch/testing/_internal/distributed/_tensor/common_dtensor.py +++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py @@ -17,8 +17,11 @@ from torch.distributed.tensor import ( DeviceMesh, distribute_tensor, +<<<<<<< HEAD DTensor, init_device_mesh, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Placement, Replicate, Shard, @@ -31,7 +34,10 @@ SequenceParallel, ) from torch.testing._internal.common_distributed import ( +<<<<<<< HEAD MultiProcContinuousTest, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) MultiProcessTestCase, MultiThreadedTestCase, run_subtests, @@ -42,8 +48,11 @@ from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec +<<<<<<< HEAD DEVICE_COUNT: int +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if TEST_CUDA: DEVICE_TYPE = "cuda" PG_BACKEND = "nccl" @@ -337,6 +346,7 @@ def skip_unless_torch_gpu(method: T) -> T: return cast(T, skip_if_lt_x_gpu(NUM_DEVICES)(method)) +<<<<<<< HEAD class DTensorContinuousTestBase(MultiProcContinuousTest): @classmethod def device_type(cls) -> str: @@ -352,6 +362,8 @@ def backend_str(cls) -> str: return backend +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class DTensorTestBase(MultiProcessTestCase): @property def world_size(self) -> int: @@ -371,6 +383,7 @@ def backend(self) -> str: return backend def build_device_mesh(self) -> DeviceMesh: +<<<<<<< HEAD return init_device_mesh(self.device_type, (self.world_size,)) def init_pg(self, eager_init, backend: Optional[str] = None) -> None: @@ -381,12 +394,22 @@ def init_pg(self, eager_init, backend: Optional[str] = None) -> None: backend = self.backend if backend not in [ +======= + return DeviceMesh(self.device_type, list(range(self.world_size))) + + def init_pg(self, eager_init) -> None: + if "nccl" in self.backend and torch.cuda.device_count() < self.world_size: + sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) + + if self.backend not in [ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "nccl", "gloo", "mpi", "cpu:gloo,cuda:nccl", "hccl", "xccl", +<<<<<<< HEAD "fake", ]: raise RuntimeError(f"Backend {backend} not supported!") @@ -396,17 +419,32 @@ def init_pg(self, eager_init, backend: Optional[str] = None) -> None: # set device for nccl pg for collectives # TODO: if users want to enable testing across hosts, we may need # to change this part. +======= + ]: + raise RuntimeError(f"Backend {self.backend} not supported!") + + device_id = None + if "nccl" in self.backend or "xccl" in self.backend: + # set device for nccl pg for collectives +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.accelerator.set_device_index(self.rank) # we only need to set device_id for nccl backend with eager init device_id = ( torch.device(f"{self.device_type}:{self.rank}") if eager_init else None ) +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # For nccl backend, bind the device to the process if device_id is not None # so the nccl communicator is immediately formed and we can use `ncclCommSplit` # for form subgroup to avoid unnecesssary overhead. dist.init_process_group( +<<<<<<< HEAD backend=backend, +======= + backend=self.backend, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) world_size=self.world_size, rank=self.rank, # pyre-ignore[16] init_method=f"file://{self.file_name}", # pyre-ignore[16] @@ -423,6 +461,7 @@ def destroy_pg(self, device_id: Optional[int] = None) -> None: device_id = ( torch.cuda.current_device() if self.device_type == "cuda" else self.rank ) +<<<<<<< HEAD if self.device_type == "cpu" and torch._C._get_accelerator().type != "cpu": # NOTE: when `device_id` is not None, barrier() will choose the accelerator @@ -435,12 +474,16 @@ def destroy_pg(self, device_id: Optional[int] = None) -> None: else: dist.barrier(device_ids=[device_id]) +======= + dist.barrier(device_ids=[device_id]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) dist.destroy_process_group() def setUp(self) -> None: super().setUp() self._spawn_processes() +<<<<<<< HEAD def _test_op_on_dtensor(self, op_call, *args, **kwargs) -> None: """ This function checks ``op_call(dtensor).full_tensor() == op_call(dtensor.full_tensor())``. @@ -467,6 +510,8 @@ def _test_op_on_dtensor(self, op_call, *args, **kwargs) -> None: d_out_full_tensor_flattened = [dt.full_tensor() for dt in d_out_flattened] self.assertEqual(out_flattened, d_out_full_tensor_flattened) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # pyre-ignore[2]: def _test_op(self, mesh: DeviceMesh, op_call, *args, **kwargs) -> None: out = op_call(*args, **kwargs) @@ -485,6 +530,7 @@ def run_subtests(self, *args, **kwargs): # wrapper to initialize comms (processgroup) +<<<<<<< HEAD def with_comms( eager_init: Union[TestFunc, bool] = False, backend: Optional[str] = None ) -> TestFunc: @@ -496,6 +542,15 @@ def wrapper( **kwargs: dict[str, Any], # type: ignore[misc] ) -> None: self.init_pg(eager_init, backend) +======= +def with_comms(eager_init: Union[TestFunc, bool] = False) -> TestFunc: + def decorator(func, eager_init: bool = False): + @wraps(func) # pyre-ignore[6] + def wrapper( + self, *args: tuple[object], **kwargs: dict[str, Any] # type: ignore[misc] + ) -> None: + self.init_pg(eager_init) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: func(self, *args, **kwargs) # type: ignore[misc] @@ -510,7 +565,11 @@ def wrapper( return ( decorator(func=eager_init) if callable(eager_init) +<<<<<<< HEAD else partial(decorator, eager_init=eager_init, backend=backend) +======= + else partial(decorator, eager_init=eager_init) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) @@ -524,7 +583,11 @@ def device_type(self) -> str: return DEVICE_TYPE def build_device_mesh(self): +<<<<<<< HEAD return init_device_mesh(self.device_type, (self.world_size,)) +======= + return DeviceMesh(self.device_type, list(range(self.world_size))) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def setUp(self) -> None: super().setUp() diff --git a/torch/testing/_internal/distributed/checkpoint_utils.py b/torch/testing/_internal/distributed/checkpoint_utils.py index 07b05140e36e6..a76060372324e 100644 --- a/torch/testing/_internal/distributed/checkpoint_utils.py +++ b/torch/testing/_internal/distributed/checkpoint_utils.py @@ -3,7 +3,10 @@ # Copyright (c) Meta Platforms, Inc. and affiliates import io +<<<<<<< HEAD import logging +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import os import shutil import tempfile @@ -158,6 +161,7 @@ def wrapper(self, *args: tuple[object], **kwargs: dict[str, Any]) -> None: shutil.rmtree(self.temp_dir, ignore_errors=True) return wrapper +<<<<<<< HEAD def with_checkpoint_logging( @@ -191,3 +195,5 @@ def wrapper(self, *args: tuple[object], **kwargs: dict[str, Any]) -> None: target_logger.setLevel(original_level) return wrapper +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py b/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py index 61c21be3ca075..fe41b2fc1b1f5 100644 --- a/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py +++ b/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py @@ -253,11 +253,15 @@ def train_batch( else: input_batches = batches +<<<<<<< HEAD with ( self.hybrid_module.join() if simulate_uneven_inputs else contextlib.nullcontext() ): +======= + with self.hybrid_module.join() if simulate_uneven_inputs else contextlib.nullcontext(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for b in input_batches: with dist_autograd.context() as context_id: output = self.hybrid_module.forward(b) @@ -265,7 +269,12 @@ def train_batch( dist_autograd.backward(context_id, [loss]) grads_dict = dist_autograd.get_gradients(context_id) gLogger.info( +<<<<<<< HEAD "Loss is %s for mini batch: %s. Grads dict has %s entries: %s", +======= + "Loss is %s for mini batch: %s. " + "Grads dict has %s entries: %s", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) loss, mini_batch, len(grads_dict), diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 024fd47285ae8..618db39194f69 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -459,6 +459,7 @@ def require_world_size(world_size): return lambda func: func +<<<<<<< HEAD def require_exact_world_size(world_size): if int(os.environ["WORLD_SIZE"]) != world_size: return skip_but_pass_in_sandcastle( @@ -467,6 +468,8 @@ def require_exact_world_size(world_size): return lambda func: func +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @contextmanager def _lock(): TEMP_DIR = os.environ["TEMP_DIR"] @@ -929,7 +932,12 @@ def test_new_subgroups(self): BACKEND not in DistTestCases.backend_feature["subgroup"], f"The {BACKEND} backend does not support creating subgroups on CUDA devices", ) +<<<<<<< HEAD @require_exact_world_size(4) +======= + @require_world_size(4) + @skip_if_lt_x_gpu(4) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def test_new_subgroups_with_group_param(self): # Initialize global test environment self._init_global_test() @@ -974,10 +982,16 @@ def test_new_subgroups_group_size_exceeds_world_size(self): @require_world_size(4) @skip_if_lt_x_gpu(4) def test_new_subgroups_world_size_not_divisible_by_group_size(self): +<<<<<<< HEAD expected_msg = f"The world size ({dist.get_world_size()}) must be divisible by 'group_size=3'" with self.assertRaisesRegex( ValueError, re.escape(expected_msg), +======= + with self.assertRaisesRegex( + ValueError, + re.escape("The world size (4) must be divisible by 'group_size=3'"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): dist.new_subgroups(3) diff --git a/torch/testing/_internal/distributed/fake_pg.py b/torch/testing/_internal/distributed/fake_pg.py index 0a2814c246459..c5fa2f5464988 100644 --- a/torch/testing/_internal/distributed/fake_pg.py +++ b/torch/testing/_internal/distributed/fake_pg.py @@ -11,7 +11,11 @@ class FakeStore(dist.Store): """ +<<<<<<< HEAD def _create_fake_pg(common_opts, backend_opts): +======= +def _create_fake_pg(prefix_store, rank, world_size, timeout): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ A fake process group (not related to FakeTensor) is a process group which doesn't actually do any communication, it just hallucinates some @@ -22,6 +26,7 @@ def _create_fake_pg(common_opts, backend_opts): for every collective. It should be used as a convenient tool when playing with distributed but don't care about the actual data. """ +<<<<<<< HEAD return FakeProcessGroup( common_opts.group_rank, common_opts.group_size, backend_opts ) @@ -30,3 +35,9 @@ def _create_fake_pg(common_opts, backend_opts): dist.Backend.register_backend( "fake", _create_fake_pg, extended_api=True, devices=["cpu", "cuda", "hpu"] ) +======= + return FakeProcessGroup(rank, world_size) + + +dist.Backend.register_backend("fake", _create_fake_pg, devices=["cpu", "cuda", "hpu"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py index 4ec964092b391..d82b8acfd2f4b 100644 --- a/torch/testing/_internal/distributed/rpc/rpc_test.py +++ b/torch/testing/_internal/distributed/rpc/rpc_test.py @@ -3560,7 +3560,11 @@ def test_custom_exception_throw_during_reconstruction(self): print(f"Got msg {msg}") self.assertTrue("Original exception on remote side was" in msg) self.assertTrue("CustomException" in msg) +<<<<<<< HEAD except BaseException as e: # noqa: B036 +======= + except BaseException as e: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) raise RuntimeError(f"Failure - expected RuntimeError, got {e}") from e finally: self.assertTrue(exc_caught) diff --git a/torch/testing/_internal/dynamo_test_failures.py b/torch/testing/_internal/dynamo_test_failures.py index cdc69b7920cf0..c0b3a785ef43c 100644 --- a/torch/testing/_internal/dynamo_test_failures.py +++ b/torch/testing/_internal/dynamo_test_failures.py @@ -1,3 +1,4 @@ +<<<<<<< HEAD """ This file contains the list of tests that are known to fail under Dynamo @@ -28,6 +29,37 @@ def find_test_dir() -> Optional[str]: +======= +# mypy: allow-untyped-defs +import logging +import os +import sys + + +# NOTE: [dynamo_test_failures.py] +# +# We generate xFailIfTorchDynamo* for all tests in `dynamo_expected_failures` +# We generate skipIfTorchDynamo* for all tests in `dynamo_skips` +# We generate runWithoutCompiledAutograd for all tests in `compiled_autograd_skips` +# +# For an easier-than-manual way of generating and updating these lists, +# see scripts/compile_tests/update_failures.py +# +# If you're adding a new test, and it's failing PYTORCH_TEST_WITH_DYNAMO=1, +# either add the appropriate decorators to your test or add skips for them +# via test/dynamo_skips and test/dynamo_expected_failures. +# +# *These are not exactly unittest.expectedFailure and unittest.skip. We'll +# always execute the test and then suppress the signal, if necessary. +# If your tests crashes, or is slow, please use @skipIfTorchDynamo instead. +# +# The expected failure and skip files are located in test/dynamo_skips and +# test/dynamo_expected_failures. They're individual files rather than a list so +# git will merge changes easier. + + +def find_test_dir(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Find the path to the dynamo expected failure and skip files. from os.path import abspath, basename, dirname, exists, join, normpath diff --git a/torch/testing/_internal/hop_db.py b/torch/testing/_internal/hop_db.py index 2a0883408892f..348b78237735c 100644 --- a/torch/testing/_internal/hop_db.py +++ b/torch/testing/_internal/hop_db.py @@ -202,6 +202,7 @@ def body_fn(iter_t, x): return torch._higher_order_ops.while_loop(cond_fn, body_fn, (iter_t, x)) +<<<<<<< HEAD def simple_while_loop_stack_output(iter_t, x): def cond_fn(iter_t, x): return iter_t > 0 @@ -211,6 +212,8 @@ def body_fn(iter_t, x): return torch._higher_order_ops.while_loop_stack_output(cond_fn, body_fn, (iter_t, x), tuple()) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def sample_inputs_scan(opinfo, device, dtype, requires_grad, **kwargs): make_arg = functools.partial( @@ -384,6 +387,7 @@ def fn(x): supports_autograd=False, ), OpInfo( +<<<<<<< HEAD name="while_loop_stack_output", variant_test_name="simple", op=simple_while_loop_stack_output, @@ -397,6 +401,8 @@ def fn(x): supports_autograd=False, ), OpInfo( +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) name="auto_functionalize", variant_test_name="simple", op=simple_auto_functionalize, diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py index 661181243250c..abedd2abb54d8 100644 --- a/torch/testing/_internal/inductor_utils.py +++ b/torch/testing/_internal/inductor_utils.py @@ -13,11 +13,17 @@ from torch.fx.experimental.proxy_tensor import make_fx from torch._inductor.graph import GraphLowering from torch._inductor.compile_fx import shape_env_from_inputs +<<<<<<< HEAD from torch._inductor.utils import OrderedSet from torch._inductor.codecache import CppCodeCache from torch._inductor.custom_graph_pass import CustomGraphModulePass from torch._inductor.codegen.common import ( get_custom_backend_config_for_device, +======= +from torch._inductor.codecache import CppCodeCache +from torch._inductor.custom_graph_pass import CustomGraphModulePass +from torch._inductor.codegen.common import ( +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) get_custom_backend_pass_for_device, get_scheduling_for_device, get_wrapper_codegen_for_device, @@ -29,7 +35,10 @@ from torch._inductor.utils import GPU_TYPES, get_gpu_type, is_gpu from torch.utils._helion import has_helion from torch.utils._triton import has_triton +<<<<<<< HEAD from torch.utils._config_module import ConfigModule +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.testing._internal.common_device_type import ( get_desired_device_type_test_bases, ) @@ -70,6 +79,7 @@ def test_cpu(): TRITON_HAS_CPU = False +<<<<<<< HEAD HAS_CUDA_AND_TRITON = torch.cuda.is_available() and HAS_TRITON HAS_XPU_AND_TRITON = torch.xpu.is_available() and HAS_TRITON @@ -77,6 +87,15 @@ def test_cpu(): HAS_MPS = torch.mps.is_available() HAS_GPU = HAS_CUDA_AND_TRITON or HAS_XPU_AND_TRITON +======= +HAS_CUDA = torch.cuda.is_available() and HAS_TRITON + +HAS_XPU = torch.xpu.is_available() and HAS_TRITON + +HAS_MPS = torch.mps.is_available() + +HAS_GPU = HAS_CUDA or HAS_XPU +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GPU_TYPE = get_gpu_type() @@ -164,16 +183,28 @@ def inner(fn): skipCPUIf = functools.partial(skipDeviceIf, device="cpu") IS_A100 = LazyVal( +<<<<<<< HEAD lambda: HAS_CUDA_AND_TRITON +======= + lambda: HAS_CUDA +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) and get_gpu_shared_memory() == 166912 ) IS_H100 = LazyVal( +<<<<<<< HEAD lambda: HAS_CUDA_AND_TRITON and get_gpu_shared_memory() == 232448 ) IS_BIG_GPU = LazyVal(lambda: HAS_CUDA_AND_TRITON and is_big_gpu()) +======= + lambda: HAS_CUDA + and get_gpu_shared_memory() == 232448 +) + +IS_BIG_GPU = LazyVal(lambda: HAS_CUDA and is_big_gpu()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def dummy_graph() -> GraphLowering: """ @@ -307,6 +338,7 @@ def _quantize_rowwise(x: Tensor, float8_dtype: torch.dtype): inverse_scale = scale.reciprocal() return x_fp8, inverse_scale +<<<<<<< HEAD class MockGraphHandler(GraphLowering): """Minimal mock graph handler for testing virtualized context.""" @@ -325,12 +357,18 @@ def get_dtype(self, buffer_name: str) -> torch.dtype: # noqa: ARG002 """Return default dtype for any buffer (for testing).""" return torch.float32 +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @contextlib.contextmanager def patch_inductor_backend( device: str, python_wrapper_codegen: PythonWrapperCodegen = None, +<<<<<<< HEAD custom_pass: CustomGraphModulePass = None, custom_backend_config: ConfigModule = None +======= + custom_pass: CustomGraphModulePass = None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): """ Patch the inductor backend for a specific device. @@ -342,9 +380,13 @@ def patch_inductor_backend( original_scheduling = get_scheduling_for_device(device) original_python_wrapper = get_wrapper_codegen_for_device(device, False) original_cpp_wrapper = get_wrapper_codegen_for_device(device, True) +<<<<<<< HEAD original_fx_wrapper = get_wrapper_codegen_for_device(device, fx_wrapper=True) original_custom_pass = get_custom_backend_pass_for_device(device) original_custom_backend_config = get_custom_backend_config_for_device(device) +======= + original_custom_pass = get_custom_backend_pass_for_device(device) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: # Register modified backend for the device @@ -353,9 +395,13 @@ def patch_inductor_backend( original_scheduling, python_wrapper_codegen if python_wrapper_codegen is not None else original_python_wrapper, original_cpp_wrapper, +<<<<<<< HEAD original_fx_wrapper, custom_pass if custom_pass is not None else original_custom_pass, custom_backend_config if custom_backend_config is not None else original_custom_backend_config +======= + custom_pass if custom_pass is not None else original_custom_pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) yield finally: @@ -365,7 +411,11 @@ def patch_inductor_backend( original_scheduling, original_python_wrapper, original_cpp_wrapper, +<<<<<<< HEAD original_fx_wrapper, original_custom_pass, original_custom_backend_config +======= + original_custom_pass +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) diff --git a/torch/testing/_internal/opinfo/core.py b/torch/testing/_internal/opinfo/core.py index 97dee3c7c0f4e..06058315273db 100644 --- a/torch/testing/_internal/opinfo/core.py +++ b/torch/testing/_internal/opinfo/core.py @@ -162,7 +162,13 @@ def __init__( # Allow calling either as SampleInput(input, args=args, kwargs=kwargs), or as # SampleInput(input, *args, **kwargs) but not to mix the two forms if args is not None or kwargs is not None: +<<<<<<< HEAD assert not var_args and not var_kwargs, """ +======= + assert ( + not var_args and not var_kwargs + ), """ +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) A SampleInput can be constructed "naturally" with *args and **kwargs or by explicitly setting the "args" and "kwargs" parameters, but the two methods of construction cannot be mixed!""" @@ -224,7 +230,11 @@ def _repr_helper(self, formatter): f"name={repr(self.name)}", ] +<<<<<<< HEAD return f"SampleInput({', '.join(a for a in arguments if a is not None)})" +======= + return f'SampleInput({", ".join(a for a in arguments if a is not None)})' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __repr__(self): return self._repr_helper(lambda x: x) @@ -1599,11 +1609,21 @@ def __post_init__(self): # returns a string identifier of the rule type @abstractmethod +<<<<<<< HEAD def type(self) -> str: ... # returns an appropriate context that handles the xfail, skips, etc. @abstractmethod def get_context(self, test_case): ... +======= + def type(self) -> str: + ... + + # returns an appropriate context that handles the xfail, skips, etc. + @abstractmethod + def get_context(self, test_case): + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # useful for specifying xfails @@ -1787,10 +1807,15 @@ def __init__( # kwargs to use when calling the op. This is required for operators that # have other required parameters besides the input tensor. generate_args_kwargs: Callable = lambda t, dim=None, keepdim=False: ( +<<<<<<< HEAD yield ( (), {}, ) +======= + yield (), + {}, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), # Options from the OpInfo base class **kwargs, @@ -2474,9 +2499,15 @@ def __init__( self.supports_one_python_scalar = True if self.supports_one_python_scalar: +<<<<<<< HEAD assert supports_rhs_python_scalar, ( "Can't support lhs and rhs Python scalars but not rhs scalars!" ) +======= + assert ( + supports_rhs_python_scalar + ), "Can't support lhs and rhs Python scalars but not rhs scalars!" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # The following functions and classes are for testing elementwise unary operators. diff --git a/torch/testing/_internal/opinfo/definitions/_masked.py b/torch/testing/_internal/opinfo/definitions/_masked.py index c5d08073803bb..1dbec437c9170 100644 --- a/torch/testing/_internal/opinfo/definitions/_masked.py +++ b/torch/testing/_internal/opinfo/definitions/_masked.py @@ -102,9 +102,14 @@ def sample_inputs_masked_reduction(op_info, device, dtype, requires_grad, **kwar for mask in _generate_masked_op_mask( sample_input.input.shape, device, **kwargs ): +<<<<<<< HEAD sample_input_args, sample_input_kwargs = ( sample_input.args, dict(mask=mask, **sample_input.kwargs), +======= + sample_input_args, sample_input_kwargs = sample_input.args, dict( + mask=mask, **sample_input.kwargs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) yield SampleInput( sample_input.input.detach().requires_grad_(requires_grad), @@ -225,9 +230,14 @@ def sample_inputs_masked_norm(op_info, device, dtype, requires_grad, **kwargs): op_info, device, dtype, requires_grad, **kwargs ): sample_input_args, sample_input_kwargs = ( +<<<<<<< HEAD (ord,) + sample_input.args, sample_input.kwargs.copy(), ) +======= + ord, + ) + sample_input.args, sample_input.kwargs.copy() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) yield SampleInput( sample_input.input.clone().requires_grad_(requires_grad), args=sample_input_args, @@ -278,9 +288,14 @@ def masked_samples(): for mask in _generate_masked_op_mask( sample_input.input.shape, device, **kwargs ): +<<<<<<< HEAD sample_input_args, sample_input_kwargs = ( sample_input.args, dict(mask=mask, **sample_input.kwargs), +======= + sample_input_args, sample_input_kwargs = sample_input.args, dict( + mask=mask, **sample_input.kwargs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) yield SampleInput( sample_input.input.detach().requires_grad_(requires_grad), @@ -367,9 +382,14 @@ def sample_inputs_masked_cumops(op_info, device, dtype, requires_grad, **kwargs) ): if type(mask) != torch.Tensor: continue +<<<<<<< HEAD sample_input_args, sample_input_kwargs = ( sample_input.args, dict(mask=mask, **sample_input.kwargs), +======= + sample_input_args, sample_input_kwargs = sample_input.args, dict( + mask=mask, **sample_input.kwargs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if "keepdim" in sample_input_kwargs: sample_input_kwargs.pop("keepdim") diff --git a/torch/testing/_internal/opinfo/definitions/nested.py b/torch/testing/_internal/opinfo/definitions/nested.py index 2f58ad2d7fb89..0c9ca5eaa0ad2 100644 --- a/torch/testing/_internal/opinfo/definitions/nested.py +++ b/torch/testing/_internal/opinfo/definitions/nested.py @@ -107,7 +107,10 @@ def get_dim_argnames(self) -> tuple[Optional[str], Optional[str]]: "flatten": ExtraOpData(is_view=True, dim_args=[["start_dim", "end_dim"]]), "flip": ExtraOpData(dim_args=[["dims..."]]), "gather": ExtraOpData(dim_args=[["dim"]]), +<<<<<<< HEAD "hash_tensor": ExtraOpData(dim_args=[["dim..."]]), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "imag": ExtraOpData(is_view=True), "index_add": ExtraOpData(dim_args=[["dim"]]), "index_copy": ExtraOpData(dim_args=[["dim"]]), diff --git a/torch/testing/_internal/opinfo/definitions/special.py b/torch/testing/_internal/opinfo/definitions/special.py index 1418685e88323..ae1ae8543fc5d 100644 --- a/torch/testing/_internal/opinfo/definitions/special.py +++ b/torch/testing/_internal/opinfo/definitions/special.py @@ -394,8 +394,16 @@ def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): skips=( DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), +<<<<<<< HEAD # Greatest absolute difference: nan DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), +======= + DecorateInfo( + unittest.skip("testing takes an unreasonably long time, #79528"), + "TestCommon", + "test_compare_cpu", + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), supports_one_python_scalar=True, supports_autograd=False, @@ -407,8 +415,16 @@ def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): skips=( DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), +<<<<<<< HEAD # Greatest absolute difference: nan DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), +======= + DecorateInfo( + unittest.skip("testing takes an unreasonably long time, #79528"), + "TestCommon", + "test_compare_cpu", + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), supports_one_python_scalar=True, supports_autograd=False, @@ -418,10 +434,20 @@ def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): dtypes=all_types_and(torch.bool), promotes_int_to_float=True, skips=( +<<<<<<< HEAD DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), # Greatest absolute difference: nan DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), +======= + DecorateInfo( + unittest.skip( + "Skipping - testing takes an unreasonably long time, #79528" + ) + ), + DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), + DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), supports_one_python_scalar=True, supports_autograd=False, @@ -431,10 +457,20 @@ def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): dtypes=all_types_and(torch.bool), promotes_int_to_float=True, skips=( +<<<<<<< HEAD DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), # Greatest absolute difference: nan DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), +======= + DecorateInfo( + unittest.skip( + "Skipping - testing takes an unreasonably long time, #79528" + ) + ), + DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), + DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), supports_one_python_scalar=True, supports_autograd=False, @@ -459,8 +495,16 @@ def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): skips=( DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), +<<<<<<< HEAD # Greatest absolute difference: inf DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), +======= + DecorateInfo( + unittest.skip("testing takes an unreasonably long time, #79528"), + "TestCommon", + "test_compare_cpu", + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), supports_one_python_scalar=True, supports_autograd=False, @@ -472,8 +516,16 @@ def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): skips=( DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), +<<<<<<< HEAD # Greatest absolute difference: nan DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), +======= + DecorateInfo( + unittest.skip("testing takes an unreasonably long time, #79528"), + "TestCommon", + "test_compare_cpu", + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), supports_one_python_scalar=True, supports_autograd=False, @@ -483,10 +535,25 @@ def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): dtypes=all_types_and(torch.bool), promotes_int_to_float=True, skips=( +<<<<<<< HEAD DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), # Greatest absolute difference: nan DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), +======= + DecorateInfo( + unittest.skip( + "Skipping - testing takes an unreasonably long time, #79528" + ) + ), + DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), + DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), + DecorateInfo( + unittest.skip("testing takes an unreasonably long time, #79528"), + "TestCommon", + "test_compare_cpu", + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), supports_one_python_scalar=True, supports_autograd=False, @@ -580,10 +647,25 @@ def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): dtypes=all_types_and(torch.bool), promotes_int_to_float=True, skips=( +<<<<<<< HEAD DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), # Greatest absolute difference: nan DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), +======= + DecorateInfo( + unittest.skip( + "Skipping - testing takes an unreasonably long time, #79528" + ) + ), + DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), + DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), + DecorateInfo( + unittest.skip("testing takes an unreasonably long time, #79528"), + "TestCommon", + "test_compare_cpu", + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), supports_one_python_scalar=True, supports_autograd=False, @@ -593,10 +675,25 @@ def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): dtypes=all_types_and(torch.bool), promotes_int_to_float=True, skips=( +<<<<<<< HEAD DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), # Greatest absolute difference: nan DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), +======= + DecorateInfo( + unittest.skip( + "Skipping - testing takes an unreasonably long time, #79528" + ) + ), + DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), + DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), + DecorateInfo( + unittest.skip("testing takes an unreasonably long time, #79528"), + "TestCommon", + "test_compare_cpu", + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), supports_one_python_scalar=True, supports_autograd=False, @@ -606,10 +703,25 @@ def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): dtypes=all_types_and(torch.bool), promotes_int_to_float=True, skips=( +<<<<<<< HEAD DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), # Greatest absolute difference: nan DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), +======= + DecorateInfo( + unittest.skip( + "Skipping - testing takes an unreasonably long time, #79528" + ) + ), + DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), + DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), + DecorateInfo( + unittest.skip("testing takes an unreasonably long time, #79528"), + "TestCommon", + "test_compare_cpu", + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), supports_one_python_scalar=True, supports_autograd=False, @@ -619,10 +731,25 @@ def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): dtypes=all_types_and(torch.bool), promotes_int_to_float=True, skips=( +<<<<<<< HEAD DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), # Greatest absolute difference: nan DecorateInfo(unittest.expectedFailure, "TestCommon", "test_compare_cpu"), +======= + DecorateInfo( + unittest.skip( + "Skipping - testing takes an unreasonably long time, #79528" + ) + ), + DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), + DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), + DecorateInfo( + unittest.skip("testing takes an unreasonably long time, #79528"), + "TestCommon", + "test_compare_cpu", + ), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), supports_one_python_scalar=True, supports_autograd=False, diff --git a/torch/testing/_internal/optests/autograd_registration.py b/torch/testing/_internal/optests/autograd_registration.py index ae5ae34059eaa..b50135e39009b 100644 --- a/torch/testing/_internal/optests/autograd_registration.py +++ b/torch/testing/_internal/optests/autograd_registration.py @@ -83,17 +83,27 @@ def autograd_registration_check(op, args, kwargs): # Determine which AutogradBACKEND key to check all_device_types = {arg.device.type for arg in all_tensors} +<<<<<<< HEAD if not all_device_types.issubset(["cpu", "cuda", "xpu"]): # Don't want to support other keys yet raise NotImplementedError( f"autograd_registration_check: NYI devices other than CPU/CUDA/XPU, got {all_device_types}" +======= + if not all_device_types.issubset(["cpu", "cuda"]): + # Don't want to support other keys yet + raise NotImplementedError( + f"autograd_registration_check: NYI devices other than CPU/CUDA, got {all_device_types}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if "cuda" in all_device_types: key = "AutogradCUDA" elif "cpu" in all_device_types: key = "AutogradCPU" +<<<<<<< HEAD elif "xpu" in all_device_types: key = "AutogradXPU" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if torch._C._dispatch_has_kernel_for_dispatch_key(op.name(), key): return @@ -129,6 +139,10 @@ def not_an_input_and_requires_grad(tensor): f"which may lead to silently incorrect results. If your operator consists " f"of regular PyTorch operations, consider not using an operator at all " f"or registering your operator as CompositeImplicitAutograd. If you have " +<<<<<<< HEAD f"an autograd.Function registered to a backend (CPU/CUDA/XPU) key, the correct " +======= + f"an autograd.Function registered to a backend (CPU/CUDA) key, the correct " +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) f"location for it is the Autograd key." ) diff --git a/torch/testing/_internal/torchbind_impls.py b/torch/testing/_internal/torchbind_impls.py index e5162ba0d6cb6..d8ab9aa39e3c9 100644 --- a/torch/testing/_internal/torchbind_impls.py +++ b/torch/testing/_internal/torchbind_impls.py @@ -46,6 +46,7 @@ def fake_queue_pop(tq): def fake_queue_push(tq, x): return tq.push(x) +<<<<<<< HEAD torch.library.register_autocast( "_TorchScriptTesting::queue_push", "cpu", torch.float32 ) @@ -60,6 +61,8 @@ def fake_queue_push(tq, x): "_TorchScriptTesting::queue_pop", "cuda", torch.float32 ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @torch.library.register_fake("_TorchScriptTesting::queue_size") def fake_queue_size(tq): return tq.size() diff --git a/torch/testing/_internal/triton_utils.py b/torch/testing/_internal/triton_utils.py index 4edaf86dd1d71..0d547d93c2b66 100644 --- a/torch/testing/_internal/triton_utils.py +++ b/torch/testing/_internal/triton_utils.py @@ -2,6 +2,7 @@ import unittest +<<<<<<< HEAD from torch.testing._internal.inductor_utils import ( HAS_CUDA_AND_TRITON, HAS_GPU, @@ -16,12 +17,20 @@ requires_gpu_and_triton = unittest.skipUnless( HAS_XPU_AND_TRITON or HAS_CUDA_AND_TRITON, "requires gpu and triton" ) +======= +from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_GPU +from torch.utils._triton import has_triton + + +requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) requires_gpu = unittest.skipUnless(HAS_GPU, "requires gpu") if has_triton(): import triton from triton import language as tl +<<<<<<< HEAD import torch def _get_strange_configs() -> list[triton.Config]: @@ -79,6 +88,8 @@ def _get_strange_configs() -> list[triton.Config]: ] return configs +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Define here so that multiple tests can take advantage of it @triton.jit def add_kernel( @@ -994,6 +1005,7 @@ def kernel_inline_asm_single_quotes( ) tl.store(out_ptr + offsets, cos_pow, mask=offsets < numel) +<<<<<<< HEAD @triton.jit def add_kernel_with_boolean_param( in_ptr0, @@ -1015,6 +1027,8 @@ def add_kernel_with_boolean_param( output = x tl.store(out_ptr + offsets, output, mask=mask) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # support the old (experimental) and new (tensor_descriptor) APIs def create_tensor_descriptor_shim( tensor, block_sizes: list[int], new_api: bool = True diff --git a/torch/utils/__init__.py b/torch/utils/__init__.py index 1c3ec15790063..68baac806d79b 100644 --- a/torch/utils/__init__.py +++ b/torch/utils/__init__.py @@ -29,7 +29,17 @@ def set_module(obj, mod): obj.__module__ = mod +<<<<<<< HEAD cmake_prefix_path = _osp.join(_osp.dirname(_osp.dirname(__file__)), "share", "cmake") +======= +if torch._running_with_deploy(): + # not valid inside torch_deploy interpreter, no paths exists for frozen modules + cmake_prefix_path = None +else: + cmake_prefix_path = _osp.join( + _osp.dirname(_osp.dirname(__file__)), "share", "cmake" + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def swap_tensors(t1, t2): diff --git a/torch/utils/_config_module.py b/torch/utils/_config_module.py index 811b45fd1d697..b573b0f3509b0 100644 --- a/torch/utils/_config_module.py +++ b/torch/utils/_config_module.py @@ -52,17 +52,29 @@ class _Config(Generic[T]): alias: If set, the directly use the value of the alias. env_name_force: If set, this environment variable has precedence over everything after this. +<<<<<<< HEAD If multiple env variables are given, the precedence order is from +======= + If multiple env variables are given, the precendence order is from +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) left to right. user_override: If a user sets a value (i.e. foo.bar=True), that has precedence over everything after this. env_name_default: If set, this environment variable will override everything after this. +<<<<<<< HEAD If multiple env variables are given, the precedence order is from left to right. justknob: If this pytorch installation supports justknobs, that will override defaults, but will not override the user_override precedence. default: This value is the lowest precedence, and will be used if nothing is +======= + If multiple env variables are given, the precendence order is from + left to right. + justknob: If this pytorch installation supports justknobs, that will + override defaults, but will not override the user_override precendence. + default: This value is the lowest precendance, and will be used if nothing is +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) set. Environment Variables: @@ -112,7 +124,11 @@ def __init__( @staticmethod def string_or_list_of_string_to_list( +<<<<<<< HEAD val: Optional[Union[str, list[str]]], +======= + val: Optional[Union[str, list[str]]] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Optional[list[str]]: if val is None: return None @@ -135,7 +151,12 @@ def Config( env_name_force: Optional[Union[str, list[str]]] = None, value_type: Optional[type] = None, alias: Optional[str] = None, +<<<<<<< HEAD ) -> T: ... +======= + ) -> T: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: @@ -322,9 +343,15 @@ def __init__(self, config: _Config): # Ensure justknobs and envvars are allowlisted types if self.justknob is not None and self.default is not None: +<<<<<<< HEAD assert isinstance(self.default, bool), ( f"justknobs only support booleans, {self.default} is not a boolean" ) +======= + assert isinstance( + self.default, bool + ), f"justknobs only support booleans, {self.default} is not a boolean" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.value_type is not None and ( config.env_name_default is not None or config.env_name_force is not None ): @@ -333,9 +360,13 @@ def __init__(self, config: _Config): str, Optional[bool], Optional[str], +<<<<<<< HEAD ), ( f"envvar configs only support (optional) booleans or strings, {self.value_type} is neither" ) +======= + ), f"envvar configs only support (optional) booleans or strings, {self.value_type} is neither" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class ConfigModule(ModuleType): diff --git a/torch/utils/_contextlib.py b/torch/utils/_contextlib.py index 8db27efa270a0..335c8fa5c0e6a 100644 --- a/torch/utils/_contextlib.py +++ b/torch/utils/_contextlib.py @@ -48,7 +48,11 @@ def generator_context(*args, **kwargs): gen.close() raise +<<<<<<< HEAD except BaseException: # noqa: B036 +======= + except BaseException: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Propagate the exception thrown at us by the caller with ctx_factory(): response = gen.throw(*sys.exc_info()) diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index efe140f10f014..dd7e7e1e23cd7 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -282,9 +282,15 @@ def tree_is_leaf( False >>> tree_is_leaf((1, 2, 3), is_leaf=lambda x: isinstance(x, tuple)) True +<<<<<<< HEAD >>> tree_is_leaf({"a": 1, "b": 2, "c": 3}) False >>> tree_is_leaf({"a": 1, "b": 2, "c": None}) +======= + >>> tree_is_leaf({'a': 1, 'b': 2, 'c': 3}) + False + >>> tree_is_leaf({'a': 1, 'b': 2, 'c': None}) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) False Args: @@ -586,6 +592,7 @@ def tree_map_( # These specializations help with type inference on the lambda passed to this # function @overload +<<<<<<< HEAD def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]: ... @@ -597,10 +604,25 @@ def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]] def map_only( type_or_types_or_pred: Type3[T, S, U], / ) -> MapOnlyFn[Fn3[T, S, U, Any]]: ... +======= +def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]: + ... + + +@overload +def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]: + ... + + +@overload +def map_only(type_or_types_or_pred: Type3[T, S, U], /) -> MapOnlyFn[Fn3[T, S, U, Any]]: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This specialization is needed for the implementations below that call @overload +<<<<<<< HEAD def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]: ... @@ -608,6 +630,15 @@ def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]: ... def map_only( type_or_types_or_pred: Callable[[Any], bool], / ) -> MapOnlyFn[FnAny[Any]]: ... +======= +def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]: + ... + + +@overload +def map_only(type_or_types_or_pred: Callable[[Any], bool], /) -> MapOnlyFn[FnAny[Any]]: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def map_only( @@ -663,7 +694,12 @@ def tree_map_only( func: Fn[T, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, +<<<<<<< HEAD ) -> PyTree: ... +======= +) -> PyTree: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @overload @@ -673,7 +709,12 @@ def tree_map_only( func: Fn2[T, S, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, +<<<<<<< HEAD ) -> PyTree: ... +======= +) -> PyTree: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @overload @@ -683,7 +724,12 @@ def tree_map_only( func: Fn3[T, S, U, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, +<<<<<<< HEAD ) -> PyTree: ... +======= +) -> PyTree: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @overload @@ -693,7 +739,12 @@ def tree_map_only( func: FnAny[Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, +<<<<<<< HEAD ) -> PyTree: ... +======= +) -> PyTree: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @overload @@ -703,7 +754,12 @@ def tree_map_only( func: FnAny[Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, +<<<<<<< HEAD ) -> PyTree: ... +======= +) -> PyTree: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def tree_map_only( @@ -723,7 +779,12 @@ def tree_map_only_( func: Fn[T, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, +<<<<<<< HEAD ) -> PyTree: ... +======= +) -> PyTree: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @overload @@ -733,7 +794,12 @@ def tree_map_only_( func: Fn2[T, S, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, +<<<<<<< HEAD ) -> PyTree: ... +======= +) -> PyTree: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @overload @@ -743,7 +809,12 @@ def tree_map_only_( func: Fn3[T, S, U, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, +<<<<<<< HEAD ) -> PyTree: ... +======= +) -> PyTree: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @overload @@ -753,7 +824,12 @@ def tree_map_only_( func: FnAny[Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, +<<<<<<< HEAD ) -> PyTree: ... +======= +) -> PyTree: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @overload @@ -763,7 +839,12 @@ def tree_map_only_( func: FnAny[Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, +<<<<<<< HEAD ) -> PyTree: ... +======= +) -> PyTree: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def tree_map_only_( @@ -801,7 +882,12 @@ def tree_all_only( pred: Fn[T, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, +<<<<<<< HEAD ) -> bool: ... +======= +) -> bool: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @overload @@ -811,7 +897,12 @@ def tree_all_only( pred: Fn2[T, S, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, +<<<<<<< HEAD ) -> bool: ... +======= +) -> bool: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @overload @@ -821,7 +912,12 @@ def tree_all_only( pred: Fn3[T, S, U, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, +<<<<<<< HEAD ) -> bool: ... +======= +) -> bool: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def tree_all_only( @@ -842,7 +938,12 @@ def tree_any_only( pred: Fn[T, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, +<<<<<<< HEAD ) -> bool: ... +======= +) -> bool: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @overload @@ -852,7 +953,12 @@ def tree_any_only( pred: Fn2[T, S, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, +<<<<<<< HEAD ) -> bool: ... +======= +) -> bool: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @overload @@ -862,7 +968,12 @@ def tree_any_only( pred: Fn3[T, S, U, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, +<<<<<<< HEAD ) -> bool: ... +======= +) -> bool: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def tree_any_only( @@ -994,7 +1105,11 @@ def __instancecheck__(self, instance: object) -> bool: return _is_pytreespec_instance(instance) and instance.is_leaf() +<<<<<<< HEAD class LeafSpec(TreeSpec, metaclass=LeafSpecMeta): # type: ignore[misc,final] +======= +class LeafSpec(TreeSpec, metaclass=LeafSpecMeta): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __new__(cls) -> "LeafSpec": return optree.treespec_leaf(none_is_leaf=True) # type: ignore[return-value] diff --git a/torch/utils/_device.py b/torch/utils/_device.py index de3ee4a9e3447..f62b458befca4 100644 --- a/torch/utils/_device.py +++ b/torch/utils/_device.py @@ -82,7 +82,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): CURRENT_DEVICE = self.old_device cur_stack = [] # Invariant: there should only be one DeviceContext on the stack at any time +<<<<<<< HEAD # (At the bottom), pop all modes until we hit the bottom, assert it's a DeviceContext +======= + # (At the bottom), pop all mdoes until we hit the bottom, assert it's a DeviceContext +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # or else someone else has popped it! for _ in range(_len_torch_function_stack() - 1): mode = _pop_mode() diff --git a/torch/utils/_foreach_utils.py b/torch/utils/_foreach_utils.py index e3a2070f2d4d6..dc0bf83de6cbd 100644 --- a/torch/utils/_foreach_utils.py +++ b/torch/utils/_foreach_utils.py @@ -8,7 +8,11 @@ def _get_foreach_kernels_supported_devices() -> list[str]: r"""Return the device type list that supports foreach kernels.""" +<<<<<<< HEAD return ["cuda", "xpu", "mtia", torch._C._get_privateuse1_backend_name()] +======= + return ["cuda", "xpu", torch._C._get_privateuse1_backend_name()] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _get_fused_kernels_supported_devices() -> list[str]: @@ -19,7 +23,10 @@ def _get_fused_kernels_supported_devices() -> list[str]: "xpu", "hpu", "cpu", +<<<<<<< HEAD "mtia", +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch._C._get_privateuse1_backend_name(), ] diff --git a/torch/utils/_freeze.py b/torch/utils/_freeze.py new file mode 100644 index 0000000000000..8696065adb9f9 --- /dev/null +++ b/torch/utils/_freeze.py @@ -0,0 +1,292 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +""" +Freeze Python packages. + + + + +Freezing makes it possible to ship arbitrary Python modules as part of a C++ +library. The Python source of the module is compiled to bytecode and written +to `.c` files, to be imported by Python's built-in FrozenImporter. + +In a normal Python installation, FrozenImporter is only used to bootstrap the +initialization of the import machinery. Python's importers are defined in +Python (see `_bootstrap.py` and `_bootstrap_external.py`) but need to be +retrieved before any importers are available. Freezing the module bytecode +resolves this circular dependency. + +This script will freeze the Python standard library. It produces two things: +- Bytecode files: A set of `.c` that define C variables containing Python bytecode. +- Main file: A `main.c` file listing all of these modules in the right form to be + consumed by FrozenImporter. + +The library that wishes to these modules make them available to the local +Python instance by extending `PyImport_FrozenModules` appropriately (see +https://docs.python.org/3/c-api/import.html#c.PyImport_FrozenModules). +""" + +import argparse +import functools +import itertools +import marshal +import os +import types +from dataclasses import dataclass +from pathlib import Path + + +PATH_MARKER = "" +MAIN_INCLUDES = """#include + +""" + +MAIN_PREFIX_TEMPLATE = """ +// Compiled standard library modules. These should be appended to the existing +// `PyImport_FrozenModules` that ships with CPython. +struct _frozen {}[] = {{ +""" + +FAKE_PREFIX = MAIN_PREFIX_TEMPLATE.format("_PyImport_FrozenModules") + +MAIN_SUFFIX = """\ + {0, 0, 0} /* sentinel */ +}; +""" + +# Exclude some standard library modules to: +# 1. Slim down the final frozen lib. +# 2. Remove functionality we don't want to support. +DENY_LIST = [ + # Interface to unix databases + "dbm", + # ncurses bindings (terminal interfaces) + "curses", + # Tcl/Tk GUI + "tkinter", + "tkinter", + # Tests for the standard library + "test", + "tests", + "idle_test", + "__phello__.foo.py", + # importlib frozen modules. These are already baked into CPython. + "_bootstrap.py", + "_bootstrap_external.py", +] + +NUM_BYTECODE_FILES = 5 + + +def indent_msg(fn): + @functools.wraps(fn) + def wrapper(*args, **kwargs): + args[0].indent += 1 + ret = fn(*args, **kwargs) + args[0].indent -= 1 + return ret + + return wrapper + + +@dataclass +class FrozenModule: + # The fully qualified module name, e.g. 'foo.bar.baz' + module_name: str + # The name of the C variable that holds the bytecode, e.g. 'M_foo__bar__baz' + c_name: str + # The size of the C variable. Negative if this module is a package. + size: int + # The frozen bytecode + bytecode: bytes + + +class Freezer: + def __init__(self, verbose: bool): + self.frozen_modules: list[FrozenModule] = [] + self.indent: int = 0 + self.verbose: bool = verbose + + def msg(self, path: Path, code: str): + if not self.verbose: + return + # P: package dir + # F: python file + # S: skipped (not a package dir) + # X: skipped (deny-listed) + # N: skipped (not a python file) + print(" " * self.indent, end="") + print(f"{code} {path}") + + def write_bytecode(self, install_root): + """ + Write the `.c` files containing the frozen bytecode. + + Shared frozen modules evenly across the files. + """ + bytecode_file_names = [f"bytecode_{i}.c" for i in range(NUM_BYTECODE_FILES)] + bytecode_files = [ + open(os.path.join(install_root, name), "w") for name in bytecode_file_names + ] + it = itertools.cycle(bytecode_files) + for m in self.frozen_modules: + self.write_frozen(m, next(it)) + + for f in bytecode_files: + f.close() + + def write_main(self, install_root, oss, symbol_name): + """Write the `main.c` file containing a table enumerating all the frozen modules.""" + with open(os.path.join(install_root, "main.c"), "w") as outfp: + outfp.write(MAIN_INCLUDES) + for m in self.frozen_modules: + outfp.write(f"extern unsigned char {m.c_name}[];\n") + + outfp.write(MAIN_PREFIX_TEMPLATE.format(symbol_name)) + for m in self.frozen_modules: + outfp.write(f'\t{{"{m.module_name}", {m.c_name}, {m.size}}},\n') + outfp.write(MAIN_SUFFIX) + if oss: + outfp.write(FAKE_PREFIX) + outfp.write(MAIN_SUFFIX) + + def write_frozen(self, m: FrozenModule, outfp): + """Write a single frozen module's bytecode out to a C variable.""" + outfp.write(f"unsigned char {m.c_name}[] = {{") + for i in range(0, len(m.bytecode), 16): + outfp.write("\n\t") + for c in bytes(m.bytecode[i : i + 16]): + outfp.write(f"{c:d},") + outfp.write("\n};\n") + + def compile_path(self, path: Path, top_package_path: Path): + """Entry point for compiling a Path object.""" + if path.is_dir(): + self.compile_package(path, top_package_path) + else: + self.compile_file(path, top_package_path) + + @indent_msg + def compile_package(self, path: Path, top_package_path: Path): + """Compile all the files within a Python package dir.""" + assert path.is_dir() + if path.name in DENY_LIST: + self.msg(path, "X") + return + + # Python packages are directories that have __init__.py in them. + is_package_dir = any(child.name == "__init__.py" for child in path.iterdir()) + if not is_package_dir: + self.msg(path, "S") + return + + self.msg(path, "P") + # Recursively compile all children in this dir + for child in path.iterdir(): + self.compile_path(child, top_package_path) + + def get_module_qualname(self, file_path: Path, top_package_path: Path) -> list[str]: + # `path` looks like 'Lib/foo/bar/baz.py' + + # chop off 'Lib/' to get something that represents a Python module hierarchy. + # e.g. 'foo/bar/baz.py', which maps to 'foo.bar.baz' + normalized_path = file_path.relative_to(top_package_path.parent) + + if normalized_path.name == "__init__.py": + # Special handling for `__init__.py`. In this case, this file + # specifies that the containing directory should be treated as a package. + # For 'foo/bar/baz/__init__.py': + # - The module name is 'baz' + module_basename = normalized_path.parent.name + # - The parent is foo.bar (need to shave off the 'baz') + module_parent = normalized_path.parent.parent.parts + else: + module_basename = normalized_path.stem + module_parent = normalized_path.parent.parts + return list(module_parent) + [module_basename] + + def compile_string(self, file_content: str) -> types.CodeType: + # instead of passing in the real build time path to 'compile', we + # pass in a marker instead. This prevents the build time path being + # leaked to runtime. That path may not be available at runtime. + # Setting the path to a mark make sure it's a hard error rather + # than a flaky error when inspect module tries to retrieve python source + # code during torchscripting. + path_marker = PATH_MARKER + return compile(file_content, path_marker, "exec") + + @indent_msg + def compile_file(self, path: Path, top_package_path: Path): + """ + Compile a Python source file to frozen bytecode. + + Append the result to `self.frozen_modules`. + """ + assert path.is_file() + if path.suffix != ".py": + self.msg(path, "N") + return + + if path.name in DENY_LIST: + self.msg(path, "X") + return + + self.msg(path, "F") + module_qualname = self.get_module_qualname(path, top_package_path) + module_mangled_name = "__".join(module_qualname) + c_name = "M_" + module_mangled_name + + with open(path) as src_file: + co = self.compile_string(src_file.read()) + + bytecode = marshal.dumps(co) + size = len(bytecode) + if path.name == "__init__.py": + # Python packages are signified by negative size. + size = -size + self.frozen_modules.append( + FrozenModule(".".join(module_qualname), c_name, size, bytecode) + ) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Compile py source") + parser.add_argument("paths", nargs="*", help="Paths to freeze.") + parser.add_argument("--verbose", action="store_true", help="Print debug logs") + parser.add_argument( + "--install-dir", "--install_dir", help="Root directory for all output files" + ) + parser.add_argument( + "--oss", + action="store_true", + help="If it's OSS build, add a fake _PyImport_FrozenModules", + ) + parser.add_argument( + "--symbol-name", + "--symbol_name", + help="The name of the frozen module array symbol to generate", + default="_PyImport_FrozenModules_torch", + ) + + args = parser.parse_args() + + f = Freezer(args.verbose) + + for p in args.paths: + path = Path(p) + if path.is_dir() and not Path.exists(path / "__init__.py"): + # this 'top level path p' is a standard directory containing modules, + # not a module itself + # each 'mod' could be a dir containing __init__.py or .py file + # NB: sorted to make sure this is deterministic + for mod in sorted(path.glob("*")): + f.compile_path(mod, mod) + else: + f.compile_path(path, path) + + f.write_bytecode(args.install_dir) + f.write_main(args.install_dir, args.oss, args.symbol_name) + + +if __name__ == "__main__": + main() # pragma: no cover diff --git a/torch/utils/_functools.py b/torch/utils/_functools.py index 0b555ffc27f96..0186a5cd3e0ad 100644 --- a/torch/utils/_functools.py +++ b/torch/utils/_functools.py @@ -12,7 +12,11 @@ def cache_method( +<<<<<<< HEAD f: Callable[Concatenate[_C, _P], _T], +======= + f: Callable[Concatenate[_C, _P], _T] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Callable[Concatenate[_C, _P], _T]: """ Like `@functools.cache` but for methods. diff --git a/torch/utils/_import_utils.py b/torch/utils/_import_utils.py index 240f92acacb9d..803e7dea4899e 100644 --- a/torch/utils/_import_utils.py +++ b/torch/utils/_import_utils.py @@ -3,6 +3,11 @@ from types import ModuleType from typing import Optional +<<<<<<< HEAD +======= +import torch + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _check_module_exists(name: str) -> bool: r"""Returns if a top-level module with :attr:`name` exists *without** @@ -20,7 +25,15 @@ def _check_module_exists(name: str) -> bool: @functools.lru_cache def dill_available() -> bool: +<<<<<<< HEAD return _check_module_exists("dill") +======= + return ( + _check_module_exists("dill") + # dill fails to import under torchdeploy + and not torch._running_with_deploy() + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @functools.lru_cache diff --git a/torch/utils/_ordered_set.py b/torch/utils/_ordered_set.py index b2a69fc0ff340..dead84aba177e 100644 --- a/torch/utils/_ordered_set.py +++ b/torch/utils/_ordered_set.py @@ -1,7 +1,10 @@ from __future__ import annotations from collections.abc import ( +<<<<<<< HEAD Hashable, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Iterable, Iterator, MutableSet, @@ -11,8 +14,13 @@ from typing import Any, cast, Optional, TypeVar +<<<<<<< HEAD T = TypeVar("T", bound=Hashable) T_co = TypeVar("T_co", bound=Hashable, covariant=True) +======= +T = TypeVar("T") +T_co = TypeVar("T_co", covariant=True) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __all__ = ["OrderedSet"] diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py index 5441468eb3b5f..1f21aa34c2155 100644 --- a/torch/utils/_python_dispatch.py +++ b/torch/utils/_python_dispatch.py @@ -1,11 +1,18 @@ # mypy: allow-untyped-defs import contextlib +<<<<<<< HEAD import functools +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import warnings from collections import deque from collections.abc import Sequence from dataclasses import dataclass +<<<<<<< HEAD from typing import Optional, overload, Protocol, Union +======= +from typing import Any, Optional, overload, Protocol, Union +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from typing_extensions import TypeIs import torch @@ -28,8 +35,11 @@ _is_in_torch_dispatch_mode = False _is_in_non_infra_torch_dispatch_mode = False +<<<<<<< HEAD # If inside any mode that has ignore_compile_internals() = False _is_in_any_mode_without_ignore_compile_internals = False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def is_in_torch_dispatch_mode(include_infra_modes=True) -> bool: @@ -40,10 +50,13 @@ def is_in_torch_dispatch_mode(include_infra_modes=True) -> bool: ) +<<<<<<< HEAD def is_in_any_mode_without_ignore_compile_internals() -> bool: return _is_in_any_mode_without_ignore_compile_internals +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class TorchDispatchMode: """ A ``TorchDispatchMode`` allows you to override the meaning of all @@ -75,12 +88,15 @@ class TorchDispatchMode: API self-referential (beware of infinite loops, in this case!) """ +<<<<<<< HEAD # - When False, custom torch dispatch mode will error out explicitly when a hop # is called under the mode. # - When True, custom torch dispatch mode's __torch_dispatch__ will be triggered. # Mode authors can implement how the mode interacts with higher order operators. supports_higher_order_operators = False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __init__(self, _dispatch_key=None): if _dispatch_key is not None: assert isinstance(_dispatch_key, torch._C.DispatchKey) @@ -88,9 +104,12 @@ def __init__(self, _dispatch_key=None): self.old_dispatch_mode_flags: deque[bool] = deque() self.old_non_infra_dispatch_mode_flags: deque[bool] = deque() +<<<<<<< HEAD self.old_without_ignore_compile_internals_dispatch_mode_flags: deque[bool] = ( deque() ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _lazy_init_old_dispatch_mode_flags(self): if not hasattr(self, "old_dispatch_mode_flags"): @@ -99,6 +118,7 @@ def _lazy_init_old_dispatch_mode_flags(self): if not hasattr(self, "old_non_infra_dispatch_mode_flags"): self.old_non_infra_dispatch_mode_flags: deque[bool] = deque() # type: ignore[no-redef] +<<<<<<< HEAD if not hasattr( self, "old_without_ignore_compile_internals_dispatch_mode_flags" ): @@ -106,14 +126,19 @@ def _lazy_init_old_dispatch_mode_flags(self): bool ] = deque() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __torch_dispatch__(self, func, types, args=(), kwargs=None): raise NotImplementedError def __enter__(self): global _is_in_torch_dispatch_mode global _is_in_non_infra_torch_dispatch_mode +<<<<<<< HEAD global _is_in_any_mode_without_ignore_compile_internals +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Previously, there wasn't any state in this class' constructor # super calls were added to existing modes, but for any new modes # this will replicate the previous behavior of not strictly needing @@ -127,6 +152,7 @@ def __enter__(self): _is_in_non_infra_torch_dispatch_mode = ( _is_in_non_infra_torch_dispatch_mode or not self.is_infra_mode() ) +<<<<<<< HEAD self.old_without_ignore_compile_internals_dispatch_mode_flags.append( _is_in_any_mode_without_ignore_compile_internals ) @@ -134,6 +160,8 @@ def __enter__(self): _is_in_any_mode_without_ignore_compile_internals or not self.ignore_compile_internals() ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _push_mode(self) return self @@ -149,10 +177,13 @@ def __exit__(self, exc_type, exc_val, exc_tb): _is_in_non_infra_torch_dispatch_mode = ( self.old_non_infra_dispatch_mode_flags.pop() ) +<<<<<<< HEAD global _is_in_any_mode_without_ignore_compile_internals _is_in_any_mode_without_ignore_compile_internals = ( self.old_without_ignore_compile_internals_dispatch_mode_flags.pop() ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _pop_mode(mb_dk_or_mode_key) @classmethod @@ -167,6 +198,7 @@ def push(cls, *args, **kwargs): def is_infra_mode(cls): return False +<<<<<<< HEAD @classmethod def ignore_compile_internals(cls): """Ignore operators that are compiled via torch.compile. @@ -199,6 +231,8 @@ def f(x): return True return False +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _get_current_dispatch_mode(): stack_len = _len_torch_dispatch_stack() @@ -364,12 +398,22 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): # Subtypes which have __tensor_flatten__ and __tensor_unflatten__. class TensorWithFlatten(Protocol): +<<<<<<< HEAD def __tensor_flatten__(self) -> tuple[Sequence[str], object]: ... +======= + def __tensor_flatten__(self) -> tuple[Sequence[str], object]: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @staticmethod def __tensor_unflatten__( inner_tensors: int, flatten_spec: int, outer_size: int, outer_stride: int +<<<<<<< HEAD ) -> torch.Tensor: ... +======= + ) -> torch.Tensor: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # It would be really nice to be able to say that the return of # is_traceable_wrapper_subclass() is Intersection[torch.Tensor, @@ -378,6 +422,7 @@ def __tensor_unflatten__( shape: torch._C.Size @overload +<<<<<<< HEAD def stride(self, dim: None = None) -> tuple[int, ...]: ... @overload @@ -392,6 +437,28 @@ def size(self, dim: int) -> int: ... def storage_offset(self) -> int: ... def dim(self) -> int: ... +======= + def stride(self, dim: None = None) -> tuple[int, ...]: + ... + + @overload + def stride(self, dim: int) -> int: + ... + + @overload + def size(self, dim: None = None) -> tuple[int, ...]: + ... + + @overload + def size(self, dim: int) -> int: + ... + + def storage_offset(self) -> int: + ... + + def dim(self) -> int: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @overload def to( @@ -401,7 +468,12 @@ def to( copy: bool = False, *, memory_format: Optional[torch.memory_format] = None, +<<<<<<< HEAD ) -> torch.Tensor: ... +======= + ) -> torch.Tensor: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @overload def to( @@ -412,7 +484,12 @@ def to( copy: bool = False, *, memory_format: Optional[torch.memory_format] = None, +<<<<<<< HEAD ) -> torch.Tensor: ... +======= + ) -> torch.Tensor: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @overload def to( @@ -422,7 +499,12 @@ def to( copy: bool = False, *, memory_format: Optional[torch.memory_format] = None, +<<<<<<< HEAD ) -> torch.Tensor: ... +======= + ) -> torch.Tensor: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def is_traceable_wrapper_subclass(t: object) -> TypeIs[TensorWithFlatten]: @@ -455,7 +537,11 @@ def is_traceable_wrapper_subclass(t: object) -> TypeIs[TensorWithFlatten]: that require the stride info to be constructed. In most cases, this arg can be safely ignored. """ +<<<<<<< HEAD is_subclass = isinstance(t, torch.Tensor) and type(t) is not torch.Tensor +======= + is_subclass = isinstance(t, torch.Tensor) and type(t) != torch.Tensor +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ( is_subclass and hasattr(t, "__tensor_flatten__") @@ -467,7 +553,11 @@ def is_traceable_wrapper_subclass_type(t: type) -> TypeIs[type[TensorWithFlatten """Same as above, but takes a type argument instead of an instance.""" return ( issubclass(t, torch.Tensor) +<<<<<<< HEAD and t is not torch.Tensor +======= + and t != torch.Tensor +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) and hasattr(t, "__tensor_flatten__") and hasattr(t, "__tensor_unflatten__") ) @@ -536,6 +626,7 @@ def alias_non_inplace_storage(arg, ret): # in theory if a subclass that needs this API wants to sometimes return # plain tensors, we could remove the assert and just not perform the aliasing, # but it seems safer to learn more about this case first. +<<<<<<< HEAD # # Performance note: This is all just to assert that the argument and result # types match, checking that is cheaper than is_traceable_wrapper_subclass_type, @@ -546,6 +637,9 @@ def alias_non_inplace_storage(arg, ret): is_traceable_wrapper_subclass_type(arg_type) or is_traceable_wrapper_subclass_type(ret_type) ): +======= + if is_traceable_wrapper_subclass(arg) or is_traceable_wrapper_subclass(ret): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ret_list = ret if isinstance(ret, list) else [ret] for r in ret_list: assert type(arg) == type( @@ -570,12 +664,26 @@ def alias_non_inplace_storage(arg, ret): assert isinstance(ret, torch.Tensor), f"type: {type(ret)}" torch._functionalize_unsafe_set(ret, arg) +<<<<<<< HEAD for arg_idx, schema_arg in enumerate(schema_info.args): for return_idx, schema_out in enumerate(schema_info.outs): is_read_only_alias_match = ( schema_arg.alias_set & schema_out.alias_set ) and not schema_arg.is_write if is_read_only_alias_match: +======= + def is_read_only_alias_match(arg, ret): + shared_aliases = arg.alias_set & ret.alias_set + return len(shared_aliases) > 0 and not arg.is_write + + num_args = len(func._schema.arguments) + num_returns = len(func._schema.returns) + for arg_idx in range(num_args): + for return_idx in range(num_returns): + if is_read_only_alias_match( + schema_info.args[arg_idx], schema_info.outs[return_idx] + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) alias_non_inplace_storage(args[arg_idx], outs[return_idx]) @@ -594,6 +702,7 @@ class SchemaInfo: args: list[AliasInfo] outs: list[AliasInfo] +<<<<<<< HEAD # NOTE[SchemaInfo int_tags]: This has nothing to do with aliasing, but we take # advantage of our existing caching of data for each OpOverload to paper over an # efficiency problem with pybind11::enum_ (which currently is used to implement @@ -601,12 +710,23 @@ class SchemaInfo: # each element must be converted to int with the __int__ method, which incurs a lot # of overhead. Converting to int once and caching removes this per-op overhead. int_tags: list[int] +======= + +# Can't import torch._ops.OpOverload due to circular reference +parsed_schema_map: dict[Any, SchemaInfo] = {} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Given an OpOverload, returns schema information on it. # This is cached for efficiency, since it can involve running torchgen +<<<<<<< HEAD @functools.cache def get_alias_info(func) -> SchemaInfo: +======= +def get_alias_info(func) -> SchemaInfo: + if func in parsed_schema_map: + return parsed_schema_map[func] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # For ATen ops: use torchgen (since torchscript parser doesn't handle alias annotations # properly for some ops that output tensorlists) if func.namespace == "aten": @@ -668,6 +788,7 @@ def get_alias_info(func) -> SchemaInfo: ) for a in func._schema.returns ] +<<<<<<< HEAD schema_info = SchemaInfo( args=arg_schemas, outs=out_schemas, int_tags=[int(x) for x in func.tags] ) @@ -678,6 +799,13 @@ def get_alias_info(func) -> SchemaInfo: _TORCH_TAG_INPLACE_VIEW_INT = int(torch.Tag.inplace_view) # type: ignore[call-overload] +======= + schema_info = SchemaInfo(args=arg_schemas, outs=out_schemas) + parsed_schema_map[func] = schema_info + return schema_info + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def return_and_correct_aliasing(func, args, kwargs, out): """ This function should be used by wrapper tensor ``__torch_dispatch__`` subclasses @@ -699,6 +827,7 @@ def return_and_correct_aliasing(func, args, kwargs, out): schema_info = get_alias_info(func) def get_write_alias(x): +<<<<<<< HEAD alias_set = x.alias_set if not alias_set or not x.is_write: return None @@ -707,6 +836,16 @@ def get_write_alias(x): # timeit says next(iter(alias_set)) is faster than list(alias_set)[0] even for # set of size 1 on Python 3.13. return next(iter(alias_set)) +======= + if len(x.alias_set) == 0: + return None + alias_set = list(x.alias_set) + # torchscript allows for complicated alias sets, but our dispatcher ops only really involve simple aliasing + assert len(alias_set) == 1 + if x.is_write: + return alias_set[0] + return None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_arg_from_alias(output_alias, schema_info, args, kwargs): new_args, new_kwargs = torch.fx.operator_schemas.normalize_function( # type: ignore[misc] @@ -732,8 +871,12 @@ def get_arg_from_alias(output_alias, schema_info, args, kwargs): # For inplace_view ops in particular, we'll try hard to make sure that the wrapper subclass's # metadata is set correctly. +<<<<<<< HEAD # See NOTE[SchemaInfo int_tags] above. if _TORCH_TAG_INPLACE_VIEW_INT in schema_info.int_tags: +======= + if torch.Tag.inplace_view in func.tags: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # no_dispatch() to make sure that we secretly change the metadata on the wrapper, # but don't end up dispatching the op anywhere else. mutated_args = [ @@ -762,6 +905,7 @@ def get_arg_from_alias(output_alias, schema_info, args, kwargs): # Next: we need to make sure to return inputs directly, if the output is a mutable alias (e.g. add_()). +<<<<<<< HEAD # Compute write aliases once instead of repeatedly. schema_info_outs_write_aliases = [get_write_alias(r) for r in schema_info.outs] # simple case: none of our outputs have mutable aliases, so we can return the output as-is @@ -775,13 +919,37 @@ def get_arg_from_alias(output_alias, schema_info, args, kwargs): if len(schema_info_outs_write_aliases) == 1: return get_arg_from_alias( schema_info_outs_write_aliases[0], schema_info, args, kwargs +======= + # simple case: none of our outputs have mutable aliases, so we can return the output as-is + if not any(get_write_alias(r) is not None for r in schema_info.outs): + return out + + # simplifying assumption: we don't have **any** ops with return types like "-> (Tensor(a!), Tensor)" + if not all(get_write_alias(r) is not None for r in schema_info.outs): + raise RuntimeError("Unsupported schema: " + str(func._schema)) + + if len(func._schema.returns) == 1: + return get_arg_from_alias( + get_write_alias(schema_info.outs[0]), schema_info, args, kwargs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # In the multi-return case, all aten ops return a tuple / list, so cast accordingly. outs_to_return = type(out)( [ +<<<<<<< HEAD (get_arg_from_alias(write_alias, schema_info, args, kwargs)) for write_alias in schema_info_outs_write_aliases +======= + ( + get_arg_from_alias( + get_write_alias(schema_info.outs[i]), schema_info, args, kwargs + ) + if get_write_alias(r) is not None + else o + ) + for ((i, r), o) in zip(enumerate(schema_info.outs), out) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] ) return outs_to_return diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 773e9f00e3d15..66d033f531efd 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -99,6 +99,7 @@ class KeyEntry(Protocol): +<<<<<<< HEAD def __hash__(self) -> int: ... def __eq__(self, other: object) -> bool: ... @@ -106,6 +107,19 @@ def __eq__(self, other: object) -> bool: ... def __str__(self) -> str: ... def get(self, parent: Any) -> Any: ... +======= + def __hash__(self) -> int: + ... + + def __eq__(self, other: object) -> bool: + ... + + def __str__(self) -> str: + ... + + def get(self, parent: Any) -> Any: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class EnumEncoder(json.JSONEncoder): @@ -370,7 +384,11 @@ def _unflatten_fn(values: Iterable[Any], context: Context) -> Any: def _flatten_fn_with_keys(obj: Any) -> tuple[list[Any], Context]: flattened, (flat_names, _none_names) = _flatten_fn(obj) # type: ignore[misc] +<<<<<<< HEAD return [(GetAttrKey(k), v) for k, v in zip(flat_names, flattened)], flat_names +======= + return [(MappingKey(k), v) for k, v in zip(flat_names, flattened)], flat_names +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _private_register_pytree_node( cls, @@ -753,7 +771,11 @@ def _tuple_flatten(d: tuple[T, ...]) -> tuple[list[T], Context]: def _tuple_flatten_with_keys( +<<<<<<< HEAD d: tuple[T, ...], +======= + d: tuple[T, ...] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> tuple[list[tuple[KeyEntry, T]], Context]: values, context = _tuple_flatten(d) return [(SequenceKey(i), v) for i, v in enumerate(values)], context @@ -781,7 +803,11 @@ def _dict_flatten(d: dict[Any, T]) -> tuple[list[T], Context]: def _dict_flatten_with_keys( +<<<<<<< HEAD d: dict[Any, T], +======= + d: dict[Any, T] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> tuple[list[tuple[KeyEntry, T]], Context]: values, context = _dict_flatten(d) return [(MappingKey(k), v) for k, v in zip(context, values)], context @@ -845,7 +871,11 @@ def _ordereddict_flatten(d: OrderedDict[Any, T]) -> tuple[list[T], Context]: def _ordereddict_flatten_with_keys( +<<<<<<< HEAD d: OrderedDict[Any, T], +======= + d: OrderedDict[Any, T] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> tuple[list[tuple[KeyEntry, T]], Context]: values, context = _ordereddict_flatten(d) return [(MappingKey(k), v) for k, v in zip(context, values)], context @@ -868,7 +898,11 @@ def _defaultdict_flatten(d: defaultdict[Any, T]) -> tuple[list[T], Context]: def _defaultdict_flatten_with_keys( +<<<<<<< HEAD d: defaultdict[Any, T], +======= + d: defaultdict[Any, T] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> tuple[list[tuple[KeyEntry, T]], Context]: values, context = _defaultdict_flatten(d) _, dict_context = context @@ -1031,9 +1065,15 @@ def tree_is_leaf( False >>> tree_is_leaf((1, 2, 3), is_leaf=lambda x: isinstance(x, tuple)) True +<<<<<<< HEAD >>> tree_is_leaf({"a": 1, "b": 2, "c": 3}) False >>> tree_is_leaf({"a": 1, "b": 2, "c": None}) +======= + >>> tree_is_leaf({'a': 1, 'b': 2, 'c': 3}) + False + >>> tree_is_leaf({'a': 1, 'b': 2, 'c': None}) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) False """ if is_leaf is not None and is_leaf(tree): @@ -1342,9 +1382,15 @@ def tree_map( See also :func:`tree_map_`. +<<<<<<< HEAD >>> tree_map(lambda x: x + 1, {"x": 7, "y": (42, 64)}) {'x': 8, 'y': (43, 65)} >>> tree_map(lambda x: x is None, {"x": 7, "y": (42, 64), "z": None}) +======= + >>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)}) + {'x': 8, 'y': (43, 65)} + >>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None}) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) {'x': False, 'y': (False, False), 'z': True} If multiple inputs are given, the structure of the tree is taken from the first input; @@ -1428,6 +1474,7 @@ def tree_map_( # These specializations help with type inference on the lambda passed to this # function @overload +<<<<<<< HEAD def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]: ... @@ -1439,10 +1486,25 @@ def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]] def map_only( type_or_types_or_pred: Type3[T, S, U], / ) -> MapOnlyFn[Fn3[T, S, U, Any]]: ... +======= +def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]: + ... + + +@overload +def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]: + ... + + +@overload +def map_only(type_or_types_or_pred: Type3[T, S, U], /) -> MapOnlyFn[Fn3[T, S, U, Any]]: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This specialization is needed for the implementations below that call @overload +<<<<<<< HEAD def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]: ... @@ -1450,6 +1512,15 @@ def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]: ... def map_only( type_or_types_or_pred: Callable[[Any], bool], / ) -> MapOnlyFn[FnAny[Any]]: ... +======= +def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]: + ... + + +@overload +def map_only(type_or_types_or_pred: Callable[[Any], bool], /) -> MapOnlyFn[FnAny[Any]]: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def map_only( @@ -1505,7 +1576,12 @@ def tree_map_only( func: Fn[T, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, +<<<<<<< HEAD ) -> PyTree: ... +======= +) -> PyTree: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @overload @@ -1515,7 +1591,12 @@ def tree_map_only( func: Fn2[T, S, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, +<<<<<<< HEAD ) -> PyTree: ... +======= +) -> PyTree: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @overload @@ -1525,7 +1606,12 @@ def tree_map_only( func: Fn3[T, S, U, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, +<<<<<<< HEAD ) -> PyTree: ... +======= +) -> PyTree: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @overload @@ -1535,7 +1621,12 @@ def tree_map_only( func: FnAny[Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, +<<<<<<< HEAD ) -> PyTree: ... +======= +) -> PyTree: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @overload @@ -1545,7 +1636,12 @@ def tree_map_only( func: FnAny[Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, +<<<<<<< HEAD ) -> PyTree: ... +======= +) -> PyTree: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def tree_map_only( @@ -1565,7 +1661,12 @@ def tree_map_only_( func: Fn[T, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, +<<<<<<< HEAD ) -> PyTree: ... +======= +) -> PyTree: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @overload @@ -1575,7 +1676,12 @@ def tree_map_only_( func: Fn2[T, S, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, +<<<<<<< HEAD ) -> PyTree: ... +======= +) -> PyTree: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @overload @@ -1585,7 +1691,12 @@ def tree_map_only_( func: Fn3[T, S, U, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, +<<<<<<< HEAD ) -> PyTree: ... +======= +) -> PyTree: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @overload @@ -1595,7 +1706,12 @@ def tree_map_only_( func: FnAny[Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, +<<<<<<< HEAD ) -> PyTree: ... +======= +) -> PyTree: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @overload @@ -1605,7 +1721,12 @@ def tree_map_only_( func: FnAny[Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, +<<<<<<< HEAD ) -> PyTree: ... +======= +) -> PyTree: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def tree_map_only_( @@ -1643,7 +1764,12 @@ def tree_all_only( pred: Fn[T, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, +<<<<<<< HEAD ) -> bool: ... +======= +) -> bool: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @overload @@ -1653,7 +1779,12 @@ def tree_all_only( pred: Fn2[T, S, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, +<<<<<<< HEAD ) -> bool: ... +======= +) -> bool: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @overload @@ -1663,7 +1794,12 @@ def tree_all_only( pred: Fn3[T, S, U, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, +<<<<<<< HEAD ) -> bool: ... +======= +) -> bool: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def tree_all_only( @@ -1684,7 +1820,12 @@ def tree_any_only( pred: Fn[T, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, +<<<<<<< HEAD ) -> bool: ... +======= +) -> bool: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @overload @@ -1694,7 +1835,12 @@ def tree_any_only( pred: Fn2[T, S, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, +<<<<<<< HEAD ) -> bool: ... +======= +) -> bool: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @overload @@ -1704,7 +1850,12 @@ def tree_any_only( pred: Fn3[T, S, U, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, +<<<<<<< HEAD ) -> bool: ... +======= +) -> bool: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def tree_any_only( @@ -1841,7 +1992,11 @@ def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec: if json_schema["type"] not in SERIALIZED_TYPE_TO_PYTHON_TYPE: raise NotImplementedError( +<<<<<<< HEAD f"Deserializing {json_schema['type']} in pytree is not registered.", +======= + f'Deserializing {json_schema["type"]} in pytree is not registered.', +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[json_schema["type"]] diff --git a/torch/utils/_strobelight/cli_function_profiler.py b/torch/utils/_strobelight/cli_function_profiler.py index 9b94a7b7a484b..adf999b2a7482 100644 --- a/torch/utils/_strobelight/cli_function_profiler.py +++ b/torch/utils/_strobelight/cli_function_profiler.py @@ -58,7 +58,11 @@ class StrobelightCLIFunctionProfiler: StrobelightCLIFunctionProfiler can be used to profile a python function and generate a strobelight link with the results. It works on meta servers but +<<<<<<< HEAD does not requires an fbcode target. +======= + does not requries an fbcode target. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) When stop_at_error is false(default), error during profiling does not prevent the work function from running. @@ -301,7 +305,11 @@ def strobelight( profiler = StrobelightCLIFunctionProfiler(**kwargs) def strobelight_inner( +<<<<<<< HEAD work_function: Callable[_P, _R], +======= + work_function: Callable[_P, _R] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Callable[_P, Optional[_R]]: @functools.wraps(work_function) def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]: diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index 836cffc2144b2..2a9bb1f40d3f8 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -3,8 +3,12 @@ import math import operator import sys +<<<<<<< HEAD from collections.abc import Callable from typing import Optional, SupportsFloat, TYPE_CHECKING, TypeVar, Union +======= +from typing import Callable, Optional, SupportsFloat, TYPE_CHECKING, TypeVar, Union +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from typing_extensions import TypeVarTuple, Unpack import sympy @@ -99,7 +103,11 @@ def _is_symbols_binary_summation(expr: sympy.Expr) -> bool: def _keep_float( +<<<<<<< HEAD f: Callable[[Unpack[_Ts]], _T], +======= + f: Callable[[Unpack[_Ts]], _T] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) -> Callable[[Unpack[_Ts]], Union[_T, sympy.Float]]: @functools.wraps(f) def inner(*args: Unpack[_Ts]) -> Union[_T, sympy.Float]: @@ -456,12 +464,15 @@ def _eval_is_nonnegative(self) -> Optional[bool]: def _eval_is_nonpositive(self) -> Optional[bool]: return True if self.args[1].is_negative else None # type: ignore[attr-defined] +<<<<<<< HEAD def _ccode(self, printer): p = printer.parenthesize(self.args[0], PRECEDENCE["Atom"] - 0.5) q = printer.parenthesize(self.args[1], PRECEDENCE["Atom"] - 0.5) abs_q = str(q) if self.args[1].is_positive else f"abs({q})" return f"({p} % {q}) < 0 ? {p} % {q} + {abs_q} : {p} % {q}" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Generic modulus: only defined on non-negative arguments class Mod(sympy.Function): @@ -927,12 +938,19 @@ def _find_localzeros(cls, values, **options): _eval_is_algebraic = lambda s: _torf(i.is_algebraic for i in s.args) # noqa: E731 _eval_is_antihermitian = lambda s: _torf( # noqa: E731 +<<<<<<< HEAD i.is_antihermitian for i in s.args # noqa: E731 ) # noqa: E731 _eval_is_commutative = lambda s: _torf( # noqa: E731 i.is_commutative for i in s.args # noqa: E731 +======= + i.is_antihermitian for i in s.args # noqa: E731 + ) # noqa: E731 + _eval_is_commutative = lambda s: _torf( # noqa: E731 + i.is_commutative for i in s.args # noqa: E731 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # noqa: E731 _eval_is_complex = lambda s: _torf(i.is_complex for i in s.args) # noqa: E731 _eval_is_composite = lambda s: _torf(i.is_composite for i in s.args) # noqa: E731 @@ -946,12 +964,19 @@ def _find_localzeros(cls, values, **options): _eval_is_negative = lambda s: _torf(i.is_negative for i in s.args) # noqa: E731 _eval_is_noninteger = lambda s: _torf(i.is_noninteger for i in s.args) # noqa: E731 _eval_is_nonnegative = lambda s: _torf( # noqa: E731 +<<<<<<< HEAD i.is_nonnegative for i in s.args # noqa: E731 ) # noqa: E731 _eval_is_nonpositive = lambda s: _torf( # noqa: E731 i.is_nonpositive for i in s.args # noqa: E731 +======= + i.is_nonnegative for i in s.args # noqa: E731 + ) # noqa: E731 + _eval_is_nonpositive = lambda s: _torf( # noqa: E731 + i.is_nonpositive for i in s.args # noqa: E731 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # noqa: E731 _eval_is_nonzero = lambda s: _torf(i.is_nonzero for i in s.args) # noqa: E731 _eval_is_odd = lambda s: _torf(i.is_odd for i in s.args) # noqa: E731 @@ -961,12 +986,19 @@ def _find_localzeros(cls, values, **options): _eval_is_rational = lambda s: _torf(i.is_rational for i in s.args) # noqa: E731 _eval_is_real = lambda s: _torf(i.is_real for i in s.args) # noqa: E731 _eval_is_extended_real = lambda s: _torf( # noqa: E731 +<<<<<<< HEAD i.is_extended_real for i in s.args # noqa: E731 ) # noqa: E731 _eval_is_transcendental = lambda s: _torf( # noqa: E731 i.is_transcendental for i in s.args # noqa: E731 +======= + i.is_extended_real for i in s.args # noqa: E731 + ) # noqa: E731 + _eval_is_transcendental = lambda s: _torf( # noqa: E731 + i.is_transcendental for i in s.args # noqa: E731 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # noqa: E731 _eval_is_zero = lambda s: _torf(i.is_zero for i in s.args) # noqa: E731 @@ -1068,7 +1100,11 @@ def eval(cls, base, exp): # base is assumed to be nonnegative, thereby prevent complex numbers from +<<<<<<< HEAD # occurring +======= +# occuring +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class FloatPow(sympy.Function): is_real = True @@ -1193,8 +1229,12 @@ def eval(cls, *args): # When all strides are integral, we can sort, and the size for the # largest stride doesn't matter and can be arbitrarily symbolic s_sizes, s_strides = zip( +<<<<<<< HEAD *sorted(zip(sizes, strides, strict=False), key=operator.itemgetter(1)), strict=False, +======= + *sorted(zip(sizes, strides), key=operator.itemgetter(1)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # Put something arbitrary in the max size spot, it'll be ignored if all(isinstance(a, sympy.Integer) for a in s_sizes[:-1]): @@ -1329,7 +1369,11 @@ class OpaqueUnaryFn(sympy.Function): constant propagation. This helps avoid performing transformations that are valid for real numbers but are invalid for floating point; in particular, while we are willing to make optimizations that change +<<<<<<< HEAD numerics for Tensor compute, we are NOT willing to make optimizations +======= + numerics for Tensor compute, we are NOT willing to make optimziations +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) that change numerics for size compute. """ @@ -1413,10 +1457,14 @@ def eval(cls, a, b): return sympy.Integer(getattr(operator, real_op_name)(int(a), int(b))) return None +<<<<<<< HEAD nm = "BitwiseFn_" + name BitwiseFn.__name__ = nm BitwiseFn.__qualname__ = nm +======= + BitwiseFn.__name__ = "BitwiseFn_" + name +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return BitwiseFn diff --git a/torch/utils/_sympy/printers.py b/torch/utils/_sympy/printers.py index acfcc596bd49c..862c9c65a5e0b 100644 --- a/torch/utils/_sympy/printers.py +++ b/torch/utils/_sympy/printers.py @@ -20,9 +20,12 @@ class ExprPrinter(StrPrinter): def _print_Mul(self, expr: sympy.Expr) -> str: return self.stringify(expr.args, "*", precedence(expr)) +<<<<<<< HEAD def _print_Not(self, expr: sympy.Expr) -> str: return f"not ({self._print(expr.args[0])})" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _print_Add(self, expr: sympy.Expr, order: Optional[str] = None) -> str: return self.stringify(expr.args, " + ", precedence(expr)) diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index e02e049cc36dd..474a293c5a28c 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -144,14 +144,24 @@ def __init__( self: ValueRanges[sympy.Expr], lower: ExprIn, upper: ExprIn, +<<<<<<< HEAD ) -> None: ... +======= + ) -> None: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @overload def __init__( # type: ignore[misc] self: ValueRanges[SympyBoolean], lower: BoolIn, upper: BoolIn, +<<<<<<< HEAD ) -> None: ... +======= + ) -> None: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __init__(self, lower: AllIn, upper: AllIn) -> None: lower = simple_sympify(lower) @@ -238,13 +248,23 @@ def tighten(self, other) -> ValueRanges: def __and__( self: ValueRanges[sympy.Expr], other: ValueRanges[sympy.Expr], +<<<<<<< HEAD ) -> ValueRanges[sympy.Expr]: ... +======= + ) -> ValueRanges[sympy.Expr]: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @overload def __and__( # type: ignore[misc] self: ValueRanges[SympyBoolean], other: ValueRanges[SympyBoolean], +<<<<<<< HEAD ) -> ValueRanges[SympyBoolean]: ... +======= + ) -> ValueRanges[SympyBoolean]: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __and__(self: AllVR, other: AllVR) -> AllVR: if other in (ValueRanges.unknown(), ValueRanges.unknown_int()): @@ -268,13 +288,23 @@ def __and__(self: AllVR, other: AllVR) -> AllVR: def __or__( self: ValueRanges[sympy.Expr], other: ValueRanges[sympy.Expr], +<<<<<<< HEAD ) -> ValueRanges[sympy.Expr]: ... +======= + ) -> ValueRanges[sympy.Expr]: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @overload def __or__( # type: ignore[misc] self: ValueRanges[SympyBoolean], other: ValueRanges[SympyBoolean], +<<<<<<< HEAD ) -> ValueRanges[SympyBoolean]: ... +======= + ) -> ValueRanges[SympyBoolean]: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __or__(self: AllVR, other: AllVR) -> AllVR: if ValueRanges.unknown() in (self, other): @@ -337,7 +367,12 @@ def increasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR: @overload @staticmethod +<<<<<<< HEAD def decreasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR: ... +======= + def decreasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @overload @staticmethod @@ -377,7 +412,12 @@ def coordinatewise_increasing_map( x: Union[ExprIn, ExprVR], y: Union[ExprIn, ExprVR], fn: ExprFn2, +<<<<<<< HEAD ) -> ExprVR: ... +======= + ) -> ExprVR: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @overload @staticmethod @@ -385,7 +425,12 @@ def coordinatewise_increasing_map( # type: ignore[misc] x: Union[BoolIn, BoolVR], y: Union[BoolIn, BoolVR], fn: BoolFn2, +<<<<<<< HEAD ) -> BoolVR: ... +======= + ) -> BoolVR: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @staticmethod def coordinatewise_increasing_map( @@ -911,7 +956,11 @@ def expr_cond_pair(a, b): return (a, b) # piecewise function can be used to convert a SymBool to SymInt: +<<<<<<< HEAD # int_expr = Piecewise((1, bool_expr), (0, True)), it evaluates to 1 when sym_bool is True and 0 otherwise. +======= + # int_expr = Piecewise((1, bool_expr), (0, True)), it evalutes to 1 when sym_bool is True and 0 otherwise. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # # ranges is a sequence of (expr_range, condition_range) pairs. The range pair is constructed in expr_cond_pair. # The ValueRange of Piecewise is just the union of all expr ranges whose condition expr can be True. @@ -1019,7 +1068,11 @@ def bound_sympy( # If there's a tracing context, augment available constrained ranges. context = torch._guards.TracingContext.try_get() +<<<<<<< HEAD if context and context.fake_mode and context.fake_mode.shape_env: +======= + if context and context.fake_mode.shape_env: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if ranges: ranges = {**context.fake_mode.shape_env.var_to_range, **ranges} else: diff --git a/torch/utils/_triton.py b/torch/utils/_triton.py index 7d545e8221643..1e856dfc62169 100644 --- a/torch/utils/_triton.py +++ b/torch/utils/_triton.py @@ -6,11 +6,21 @@ @functools.cache def has_triton_package() -> bool: try: +<<<<<<< HEAD import triton # noqa: F401 return True except ImportError: return False +======= + from triton.compiler.compiler import triton_key + + return triton_key is not None + except ImportError: + return False + except RuntimeError: + return False +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @functools.cache @@ -82,7 +92,11 @@ def has_triton_tma_device() -> bool: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) and not torch.version.hip +<<<<<<< HEAD ) or torch.xpu.is_available(): +======= + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # old API try: from triton.language.extra.cuda import ( # noqa: F401 @@ -114,7 +128,11 @@ def has_triton_stable_tma_api() -> bool: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) and not torch.version.hip +<<<<<<< HEAD ) or torch.xpu.is_available(): +======= + ): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: from triton.language import make_tensor_descriptor # noqa: F401 @@ -146,7 +164,10 @@ def _return_true(device_interface: Any) -> bool: "cuda": cuda_extra_check, "xpu": _return_true, "cpu": cpu_extra_check, +<<<<<<< HEAD "mtia": _return_true, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } def is_device_compatible_with_triton() -> bool: @@ -170,7 +191,11 @@ def triton_backend() -> Any: @functools.cache def triton_hash_with_backend() -> str: +<<<<<<< HEAD from torch._inductor.runtime.triton_compat import triton_key +======= + from triton.compiler.compiler import triton_key +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) backend = triton_backend() key = f"{triton_key()}-{backend.hash()}" diff --git a/torch/utils/backend_registration.py b/torch/utils/backend_registration.py index 5a83aede8d468..504f3b2a250c6 100644 --- a/torch/utils/backend_registration.py +++ b/torch/utils/backend_registration.py @@ -106,7 +106,11 @@ def _get_current_device_index(): elif isinstance(device, str): device = torch.device(device) +<<<<<<< HEAD # variable device can only be torch.device type or int type +======= + # variable devcie can only be torch.device type or int type +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if isinstance(device, torch.device): if device.type != custom_backend_name: raise RuntimeError(f"Invalid device, must be {custom_backend_name} device") @@ -426,9 +430,15 @@ def func_name(*args, **kwargs): it is marked as private. It is a convenience function for backend implementers to more easily call the hooks into their backend extensions. """ +<<<<<<< HEAD assert isinstance(func_name, str), ( f"func_name must be `str`, but got `{type(func_name)}`." ) +======= + assert isinstance( + func_name, str + ), f"func_name must be `str`, but got `{type(func_name)}`." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) backend_name = _get_privateuse1_backend_name() custom_device_mod = getattr(torch, backend_name, None) # type: ignore[arg-type] function = getattr(custom_device_mod, func_name, None) # type: ignore[arg-type] diff --git a/torch/utils/benchmark/README.md b/torch/utils/benchmark/README.md index 6fa025e51d37c..b0521e1db16f5 100644 --- a/torch/utils/benchmark/README.md +++ b/torch/utils/benchmark/README.md @@ -25,7 +25,11 @@ into two broad categories: * `Timer` implements the `blocked_autorange` function which is a mixture of `timeit.Timer.repeat` and `timeit.Timer.autorange`. This function +<<<<<<< HEAD selects an appropriate number and runs for a roughly fixed amount of time +======= + selects and appropriate number and runs for a roughly fixed amount of time +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) (like `autorange`), but is less wasteful than `autorange` which discards ~75% of measurements. It runs many times, similar to `repeat`, and returns a `Measurement` containing all of the run results. @@ -46,7 +50,11 @@ table will be generated per unique label. may be logically equivalent differ in implementation. Assigning separate sub_labels will result in a row per sub_label. If a sublabel is not provided, `stmt` is used instead. Statistics (such as computing the fastest +<<<<<<< HEAD implementation) use all sub_labels. +======= +implementation) are use all sub_labels. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) * `description`: This describes the inputs. For instance, `stmt=torch.add(x, y)` can be run over several values of `x` and `y`. Each pair should be given its diff --git a/torch/utils/benchmark/utils/compare.py b/torch/utils/benchmark/utils/compare.py index d1df2987ea6c7..3c0c33decb4d6 100644 --- a/torch/utils/benchmark/utils/compare.py +++ b/torch/utils/benchmark/utils/compare.py @@ -280,7 +280,11 @@ class Compare: https://pytorch.org/tutorials/recipes/recipes/benchmark.html Args: +<<<<<<< HEAD results: List of Measurement to display. +======= + results: List of Measurment to display. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """ def __init__(self, results: list[common.Measurement]): self._results: list[common.Measurement] = [] diff --git a/torch/utils/benchmark/utils/compile.py b/torch/utils/benchmark/utils/compile.py index cee9c8d7f7174..9cc563dc9ff51 100644 --- a/torch/utils/benchmark/utils/compile.py +++ b/torch/utils/benchmark/utils/compile.py @@ -127,7 +127,11 @@ def bench_all( This is a simple utility that can be used to benchmark torch.compile In particular it ensures that your GPU is setup to use tensor cores if it supports its It also tries out all the main backends and prints a table of results so you can easily compare them all +<<<<<<< HEAD Many of the backendds have their own optional dependencies so please pip install them separately +======= + Many of the backendds have their own optional dependencies so please pip install them seperately +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) You will get one table for inference and another for training If you'd like to leverage this utility for training make sure to pass in a torch.optim.Optimizer diff --git a/torch/utils/benchmark/utils/timer.py b/torch/utils/benchmark/utils/timer.py index 1889f6756e70f..74b393ccdfd7a 100644 --- a/torch/utils/benchmark/utils/timer.py +++ b/torch/utils/benchmark/utils/timer.py @@ -13,9 +13,27 @@ __all__ = ["Timer", "timer", "Language"] +<<<<<<< HEAD if torch.accelerator.is_available(): def timer() -> float: torch.accelerator.synchronize() +======= +if torch.backends.cuda.is_built() and torch.cuda.is_available(): # type: ignore[no-untyped-call] + def timer() -> float: + torch.cuda.synchronize() + return timeit.default_timer() +elif torch.xpu.is_available(): + def timer() -> float: + torch.xpu.synchronize() + return timeit.default_timer() +elif torch._C._get_privateuse1_backend_name() != "privateuseone": + privateuse1_device_handler = getattr(torch, torch._C._get_privateuse1_backend_name(), None) \ + if torch._C._get_privateuse1_backend_name() != "cpu" else None + + def timer() -> float: + if privateuse1_device_handler: + privateuse1_device_handler.synchronize() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return timeit.default_timer() else: timer = timeit.default_timer @@ -37,12 +55,21 @@ def __init__( ) -> None: if timer is not timeit.default_timer: raise NotImplementedError( +<<<<<<< HEAD "PyTorch was built with accelerators and an accelerator is present; however " "Timer does not yet support accelerator measurements. If your " "code is CPU only, pass `timer=timeit.default_timer` to the " "Timer's constructor to indicate this. (Note that this will " "produce incorrect results if an accelerator is in fact used, as " "Timer will not synchronize the accelerator.)" +======= + "PyTorch was built with CUDA and a GPU is present; however " + "Timer does not yet support GPU measurements. If your " + "code is CPU only, pass `timer=timeit.default_timer` to the " + "Timer's constructor to indicate this. (Note that this will " + "produce incorrect results if the GPU is in fact used, as " + "Timer will not synchronize CUDA.)" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) if globals: @@ -76,7 +103,11 @@ class Timer: 1) Runtime aware: Timer will perform warmups (important as some elements of PyTorch are lazily initialized), set threadpool size so that comparisons are +<<<<<<< HEAD apples-to-apples, and synchronize asynchronous accelerator functions when +======= + apples-to-apples, and synchronize asynchronous CUDA functions when +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) necessary. 2) Focus on replicates: @@ -119,8 +150,13 @@ class Timer: timer: Callable which returns the current time. If PyTorch was built +<<<<<<< HEAD without accelerators or there is no accelerator present, this defaults to `timeit.default_timer`; otherwise it will synchronize accelerators before +======= + without CUDA or there is no GPU present, this defaults to + `timeit.default_timer`; otherwise it will synchronize CUDA before +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) measuring the time. globals: @@ -347,7 +383,11 @@ def blocked_autorange( 2) A large block size better amortizes the cost of `timer` invocation, and results in a less biased measurement. This is +<<<<<<< HEAD important because accelerator synchronization time is non-trivial +======= + important because CUDA synchronization time is non-trivial +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) (order single to low double digit microseconds) and would otherwise bias the measurement. diff --git a/torch/utils/bundled_inputs.py b/torch/utils/bundled_inputs.py index 6209fc8ee8745..29eb7626f41f0 100644 --- a/torch/utils/bundled_inputs.py +++ b/torch/utils/bundled_inputs.py @@ -116,7 +116,11 @@ def bundle_inputs( ) # The above cloning function returns a torch._C.scriptmodule and we need a torch.jit.scriptmodule. +<<<<<<< HEAD # Fortunately there is a function in _recursive that does exactly that conversion. +======= + # Fortunately theres a function in _recursive that does exactly that conversion. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cloned_module = wrap_cpp_module(clone) if isinstance(inputs, dict): assert isinstance(info, dict) or info is None diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py index 30d2fc106f5ff..056337f04342b 100644 --- a/torch/utils/checkpoint.py +++ b/torch/utils/checkpoint.py @@ -347,7 +347,10 @@ def checkpoint( context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn, determinism_check: str = _DEFAULT_DETERMINISM_MODE, debug: bool = False, +<<<<<<< HEAD early_stop: bool = True, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) **kwargs ): r"""Checkpoint a model or part of the model. @@ -374,7 +377,11 @@ def checkpoint( .. warning:: The ``use_reentrant`` parameter should be passed explicitly. In version +<<<<<<< HEAD 2.9 we will raise an exception if ``use_reentrant`` is not passed. +======= + 2.4 we will raise an exception if ``use_reentrant`` is not passed. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) If you are using the ``use_reentrant=True`` variant, please refer to the note below for important considerations and potential limitations. @@ -426,9 +433,12 @@ def checkpoint( passed as the tuple. For example, in LSTM, if user passes ``(activation, hidden)``, :attr:`function` should correctly use the first input as ``activation`` and the second input as ``hidden`` +<<<<<<< HEAD args: tuple containing inputs to the :attr:`function` Keyword args: +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) preserve_rng_state(bool, optional): Omit stashing and restoring the RNG state during each checkpoint. Note that under torch.compile, this flag doesn't take effect and we always preserve RNG state. @@ -436,7 +446,11 @@ def checkpoint( use_reentrant(bool): specify whether to use the activation checkpoint variant that requires reentrant autograd. This parameter should be passed +<<<<<<< HEAD explicitly. In version 2.9 we will raise an exception if +======= + explicitly. In version 2.5 we will raise an exception if +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ``use_reentrant`` is not passed. If ``use_reentrant=False``, ``checkpoint`` will use an implementation that does not require reentrant autograd. This allows ``checkpoint`` to support additional @@ -459,11 +473,15 @@ def checkpoint( a trace of the operators ran during the original forward computation as well as the recomputation. This argument is only supported if ``use_reentrant=False``. +<<<<<<< HEAD early_stop(bool, optional): If ``True``, non-reentrant checkpoint stops recomputation as soon as it has computed all needed Tensors. This argument is ignored if ``use_reentrant=True``. Can be overridden globally using :func:`set_checkpoint_early_stop` context manager. Default: ``True``. +======= + args: tuple containing inputs to the :attr:`function` +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Returns: Output of running :attr:`function` on :attr:`*args` @@ -471,8 +489,13 @@ def checkpoint( if use_reentrant is None: warnings.warn( "torch.utils.checkpoint: the use_reentrant parameter should be " +<<<<<<< HEAD "passed explicitly. Starting in PyTorch 2.9, calling checkpoint " "without use_reentrant will raise an exception. use_reentrant=False is " +======= + "passed explicitly. In version 2.5 we will raise an exception " + "if use_reentrant is not passed. use_reentrant=False is " +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "recommended, but if you need to preserve the current default " "behavior, you can pass use_reentrant=True. Refer to docs for more " "details on the differences between the two variants.", @@ -496,7 +519,11 @@ def checkpoint( return CheckpointFunction.apply(function, preserve, *args) else: gen = _checkpoint_without_reentrant_generator( +<<<<<<< HEAD function, preserve, context_fn, determinism_check, debug, early_stop, *args, **kwargs +======= + function, preserve, context_fn, determinism_check, debug, *args, **kwargs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) # Runs pre-forward logic next(gen) @@ -519,7 +546,11 @@ def checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwar .. warning:: The ``use_reentrant`` parameter should be passed explicitly. In version +<<<<<<< HEAD 2.9 we will raise an exception if ``use_reentrant`` is not passed. +======= + 2.4 we will raise an exception if ``use_reentrant`` is not passed. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) If you are using the ``use_reentrant=True` variant, please see :func:`~torch.utils.checkpoint.checkpoint` for the important considerations and limitations of this variant. It is @@ -560,7 +591,11 @@ def checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwar warnings.warn( "torch.utils.checkpoint.checkpoint_sequential: the use_reentrant " "parameter should be passed explicitly. " +<<<<<<< HEAD "In version 2.9 we will raise an exception if use_reentrant " +======= + "In version 2.5 we will raise an exception if use_reentrant " +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "is not passed. use_reentrant=False is " "recommended, but if you need to preserve the current default " "behavior, you can pass use_reentrant=True. Refer to docs for more " @@ -739,7 +774,11 @@ def _internal_assert(cond): # by holder=None. We skip over them. We still save x at (4) (since its holder # is still alive.) +<<<<<<< HEAD _enable_checkpoint_early_stop: Optional[bool] = None +======= +_enable_checkpoint_early_stop = True +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @contextlib.contextmanager @@ -866,7 +905,11 @@ def check_recomputed_tensors_match(self, gid): if not len(self.weak_holders) == self.recomp_counter[gid]: # 2. During recompute, fewer tensors were saved # +<<<<<<< HEAD # We know that every time we save something do original forward +======= + # We know that everytime we save something do original forward +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # we append to weak_holder, and every time we save a tensor # during recompute we increment recompute_counter. raise CheckpointError( @@ -1278,7 +1321,11 @@ class CheckpointPolicy(enum.Enum): def _policy_from_bool(b): +<<<<<<< HEAD # For backward compatibility +======= + # For backward compatability +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return CheckpointPolicy.MUST_SAVE if b else CheckpointPolicy.PREFER_RECOMPUTE @@ -1456,7 +1503,10 @@ def _checkpoint_without_reentrant_generator( context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn, determinism_check: str = _DEFAULT_DETERMINISM_MODE, debug: bool = False, +<<<<<<< HEAD early_stop: bool = True, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) *args, **kwargs ): @@ -1484,10 +1534,13 @@ def _checkpoint_without_reentrant_generator( debug(bool, optional): If ``True``, error messages will also include a trace of the operators ran during the original forward computation as well as the recomputation. +<<<<<<< HEAD early_stop(bool, optional): If ``True``, non-reentrant checkpoint stops recomputation as soon as it has computed all needed Tensors. Can be overridden globally using :func:`set_checkpoint_early_stop` context manager. Default: ``True``. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) *args: Arguments to pass in to the given ``function``. **kwargs: Keyword arguments to pass into the given ``function``. """ @@ -1556,7 +1609,11 @@ def recompute_fn(*inputs): new_frame = _CheckpointFrame( recompute_fn, +<<<<<<< HEAD _enable_checkpoint_early_stop if _enable_checkpoint_early_stop is not None else early_stop, +======= + _enable_checkpoint_early_stop, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) unpack_error_cb, metadata_fn ) diff --git a/torch/utils/collect_env.py b/torch/utils/collect_env.py index c6473220bc00a..9dd32e17419b4 100644 --- a/torch/utils/collect_env.py +++ b/torch/utils/collect_env.py @@ -6,22 +6,34 @@ import datetime import json import locale +<<<<<<< HEAD import os import re import subprocess import sys from collections import namedtuple from typing import cast as _cast +======= +import re +import subprocess +import sys +import os +from collections import namedtuple +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) try: import torch +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) TORCH_AVAILABLE = True except (ImportError, NameError, AttributeError, OSError): TORCH_AVAILABLE = False # System Environment Information +<<<<<<< HEAD SystemEnv = namedtuple( "SystemEnv", [ @@ -53,6 +65,35 @@ "cpu_info", ], ) +======= +SystemEnv = namedtuple('SystemEnv', [ + 'torch_version', + 'is_debug_build', + 'cuda_compiled_version', + 'gcc_version', + 'clang_version', + 'cmake_version', + 'os', + 'libc_version', + 'python_version', + 'python_platform', + 'is_cuda_available', + 'cuda_runtime_version', + 'cuda_module_loading', + 'nvidia_driver_version', + 'nvidia_gpu_models', + 'cudnn_version', + 'pip_version', # 'pip' or 'pip3' + 'pip_packages', + 'conda_packages', + 'hip_compiled_version', + 'hip_runtime_version', + 'miopen_runtime_version', + 'caching_allocator_config', + 'is_xnnpack_available', + 'cpu_info', +]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) COMMON_PATTERNS = [ "torch", @@ -79,6 +120,7 @@ "nvtx", ] +<<<<<<< HEAD ONEAPI_PATTERNS = [ "dpcpp-cpp-rt", "intel-cmplr-lib-rt", @@ -103,6 +145,8 @@ "tcmlib", ] +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) CONDA_PATTERNS = [ "cudatoolkit", "soumith", @@ -120,6 +164,7 @@ def run(command): """Return (return-code, stdout, stderr).""" shell = True if type(command) is str else False +<<<<<<< HEAD p = subprocess.Popen( command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell ) @@ -127,6 +172,14 @@ def run(command): rc = p.returncode if get_platform() == "win32": enc = "oem" +======= + p = subprocess.Popen(command, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, shell=shell) + raw_output, raw_err = p.communicate() + rc = p.returncode + if get_platform() == 'win32': + enc = 'oem' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: enc = locale.getpreferredencoding() output = raw_output.decode(enc) @@ -152,19 +205,31 @@ def run_and_parse_first_match(run_lambda, command, regex): return None return match.group(1) +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def run_and_return_first_line(run_lambda, command): """Run command using run_lambda and returns first line if output is not empty.""" rc, out, _ = run_lambda(command) if rc != 0: return None +<<<<<<< HEAD return out.split("\n")[0] +======= + return out.split('\n')[0] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_conda_packages(run_lambda, patterns=None): if patterns is None: +<<<<<<< HEAD patterns = CONDA_PATTERNS + COMMON_PATTERNS + NVIDIA_PATTERNS + ONEAPI_PATTERNS conda = os.environ.get("CONDA_EXE", "conda") +======= + patterns = CONDA_PATTERNS + COMMON_PATTERNS + NVIDIA_PATTERNS + conda = os.environ.get('CONDA_EXE', 'conda') +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) out = run_and_read_all(run_lambda, "{} list".format(conda)) if out is None: return out @@ -172,6 +237,7 @@ def get_conda_packages(run_lambda, patterns=None): return "\n".join( line for line in out.splitlines() +<<<<<<< HEAD if not line.startswith("#") and any(name in line for name in patterns) ) @@ -206,6 +272,34 @@ def get_gpu_info(run_lambda): and hasattr(torch.version, "hip") and torch.version.hip is not None ): +======= + if not line.startswith("#") + and any(name in line for name in patterns) + ) + +def get_gcc_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'gcc --version', r'gcc (.*)') + +def get_clang_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'clang --version', r'clang version (.*)') + + +def get_cmake_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'cmake --version', r'cmake (.*)') + + +def get_nvidia_driver_version(run_lambda): + if get_platform() == 'darwin': + cmd = 'kextstat | grep -i cuda' + return run_and_parse_first_match(run_lambda, cmd, + r'com[.]nvidia[.]CUDA [(](.*?)[)]') + smi = get_nvidia_smi() + return run_and_parse_first_match(run_lambda, smi, r'Driver Version: (.*?) ') + + +def get_gpu_info(run_lambda): + if get_platform() == 'darwin' or (TORCH_AVAILABLE and hasattr(torch.version, 'hip') and torch.version.hip is not None): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if TORCH_AVAILABLE and torch.cuda.is_available(): if torch.version.hip is not None: prop = torch.cuda.get_device_properties(0) @@ -218,6 +312,7 @@ def get_gpu_info(run_lambda): return torch.cuda.get_device_name(None) + gcnArch return None smi = get_nvidia_smi() +<<<<<<< HEAD uuid_regex = re.compile(r" \(UUID: .+?\)") rc, out, _ = run_lambda(smi + " -L") if rc != 0: @@ -228,32 +323,65 @@ def get_gpu_info(run_lambda): def get_running_cuda_version(run_lambda): return run_and_parse_first_match(run_lambda, "nvcc --version", r"release .+ V(.*)") +======= + uuid_regex = re.compile(r' \(UUID: .+?\)') + rc, out, _ = run_lambda(smi + ' -L') + if rc != 0: + return None + # Anonymize GPUs by removing their UUID + return re.sub(uuid_regex, '', out) + + +def get_running_cuda_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'nvcc --version', r'release .+ V(.*)') +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_cudnn_version(run_lambda): """Return a list of libcudnn.so; it's hard to tell which one is being used.""" +<<<<<<< HEAD if get_platform() == "win32": system_root = os.environ.get("SYSTEMROOT", "C:\\Windows") cuda_path = os.environ.get("CUDA_PATH", "%CUDA_PATH%") where_cmd = os.path.join(system_root, "System32", "where") cudnn_cmd = '{} /R "{}\\bin" cudnn*.dll'.format(where_cmd, cuda_path) elif get_platform() == "darwin": +======= + if get_platform() == 'win32': + system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') + cuda_path = os.environ.get('CUDA_PATH', "%CUDA_PATH%") + where_cmd = os.path.join(system_root, 'System32', 'where') + cudnn_cmd = '{} /R "{}\\bin" cudnn*.dll'.format(where_cmd, cuda_path) + elif get_platform() == 'darwin': +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # CUDA libraries and drivers can be found in /usr/local/cuda/. See # https://docs.nvidia.com/cuda/archive/9.0/cuda-installation-guide-mac-os-x/index.html#installation # https://docs.nvidia.com/deeplearning/cudnn/installation/latest/ # Use CUDNN_LIBRARY when cudnn library is installed elsewhere. +<<<<<<< HEAD cudnn_cmd = "ls /usr/local/cuda/lib/libcudnn*" +======= + cudnn_cmd = 'ls /usr/local/cuda/lib/libcudnn*' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d" " -f1 | rev' rc, out, _ = run_lambda(cudnn_cmd) # find will return 1 if there are permission errors or if not found if len(out) == 0 or (rc != 1 and rc != 0): +<<<<<<< HEAD l = os.environ.get("CUDNN_LIBRARY") +======= + l = os.environ.get('CUDNN_LIBRARY') +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if l is not None and os.path.isfile(l): return os.path.realpath(l) return None files_set = set() +<<<<<<< HEAD for fn in out.split("\n"): +======= + for fn in out.split('\n'): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fn = os.path.realpath(fn) # eliminate symbolic links if os.path.isfile(fn): files_set.add(fn) @@ -263,12 +391,18 @@ def get_cudnn_version(run_lambda): files = sorted(files_set) if len(files) == 1: return files[0] +<<<<<<< HEAD result = "\n".join(files) return "Probably one of the following:\n{}".format(result) +======= + result = '\n'.join(files) + return 'Probably one of the following:\n{}'.format(result) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_nvidia_smi(): # Note: nvidia-smi is currently available only on Windows and Linux +<<<<<<< HEAD smi = "nvidia-smi" if get_platform() == "win32": system_root = os.environ.get("SYSTEMROOT", "C:\\Windows") @@ -277,6 +411,14 @@ def get_nvidia_smi(): program_files_root, "NVIDIA Corporation", "NVSMI", smi ) new_path = os.path.join(system_root, "System32", smi) +======= + smi = 'nvidia-smi' + if get_platform() == 'win32': + system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') + program_files_root = os.environ.get('PROGRAMFILES', 'C:\\Program Files') + legacy_path = os.path.join(program_files_root, 'NVIDIA Corporation', 'NVSMI', smi) + new_path = os.path.join(system_root, 'System32', smi) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) smis = [new_path, legacy_path] for candidate_smi in smis: if os.path.exists(candidate_smi): @@ -285,6 +427,7 @@ def get_nvidia_smi(): return smi +<<<<<<< HEAD def _detect_linux_pkg_manager(): if get_platform() != "linux": return "N/A" @@ -433,6 +576,8 @@ def get_intel_gpu_detected(run_lambda): return "\n".join(devices) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # example outputs of CPU infos # * linux # Architecture: x86_64 @@ -508,12 +653,20 @@ def get_intel_gpu_detected(run_lambda): # ProcessorType=3 # Revision=27142 +<<<<<<< HEAD def get_cpu_info(run_lambda): rc, out, err = 0, "", "" if get_platform() == "linux": rc, out, err = run_lambda("lscpu") elif get_platform() == "win32": +======= +def get_cpu_info(run_lambda): + rc, out, err = 0, '', '' + if get_platform() == 'linux': + rc, out, err = run_lambda('lscpu') + elif get_platform() == 'win32': +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) rc, out, err = run_lambda( 'powershell.exe "gwmi -Class Win32_Processor | Select-Object -Property Name,Manufacturer,Family,\ Architecture,ProcessorType,DeviceID,CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision\ @@ -533,9 +686,15 @@ def get_cpu_info(run_lambda): lst.append(out) lst.append(str(e)) out = "\n".join(lst) +<<<<<<< HEAD elif get_platform() == "darwin": rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string") cpu_info = "None" +======= + elif get_platform() == 'darwin': + rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string") + cpu_info = 'None' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if rc == 0: cpu_info = out else: @@ -544,6 +703,7 @@ def get_cpu_info(run_lambda): def get_platform(): +<<<<<<< HEAD if sys.platform.startswith("linux"): return "linux" elif sys.platform.startswith("win32"): @@ -552,12 +712,26 @@ def get_platform(): return "cygwin" elif sys.platform.startswith("darwin"): return "darwin" +======= + if sys.platform.startswith('linux'): + return 'linux' + elif sys.platform.startswith('win32'): + return 'win32' + elif sys.platform.startswith('cygwin'): + return 'cygwin' + elif sys.platform.startswith('darwin'): + return 'darwin' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: return sys.platform def get_mac_version(run_lambda): +<<<<<<< HEAD return run_and_parse_first_match(run_lambda, "sw_vers -productVersion", r"(.*)") +======= + return run_and_parse_first_match(run_lambda, 'sw_vers -productVersion', r'(.*)') +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_windows_version(run_lambda): @@ -575,6 +749,7 @@ def get_windows_version(run_lambda): def get_lsb_version(run_lambda): +<<<<<<< HEAD return run_and_parse_first_match( run_lambda, "lsb_release -a", r"Description:\t(.*)" ) @@ -584,10 +759,19 @@ def check_release_file(run_lambda): return run_and_parse_first_match( run_lambda, "cat /etc/*-release", r'PRETTY_NAME="(.*)"' ) +======= + return run_and_parse_first_match(run_lambda, 'lsb_release -a', r'Description:\t(.*)') + + +def check_release_file(run_lambda): + return run_and_parse_first_match(run_lambda, 'cat /etc/*-release', + r'PRETTY_NAME="(.*)"') +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_os(run_lambda): from platform import machine +<<<<<<< HEAD platform = get_platform() @@ -605,13 +789,37 @@ def get_os(run_lambda): desc = get_lsb_version(run_lambda) if desc is not None: return "{} ({})".format(desc, machine()) +======= + platform = get_platform() + + if platform == 'win32' or platform == 'cygwin': + return get_windows_version(run_lambda) + + if platform == 'darwin': + version = get_mac_version(run_lambda) + if version is None: + return None + return 'macOS {} ({})'.format(version, machine()) + + if platform == 'linux': + # Ubuntu/Debian based + desc = get_lsb_version(run_lambda) + if desc is not None: + return '{} ({})'.format(desc, machine()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Try reading /etc/*-release desc = check_release_file(run_lambda) if desc is not None: +<<<<<<< HEAD return "{} ({})".format(desc, machine()) return "{} ({})".format(platform, machine()) +======= + return '{} ({})'.format(desc, machine()) + + return '{} ({})'.format(platform, machine()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Unknown platform return platform @@ -619,21 +827,31 @@ def get_os(run_lambda): def get_python_platform(): import platform +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return platform.platform() def get_libc_version(): import platform +<<<<<<< HEAD if get_platform() != "linux": return "N/A" return "-".join(platform.libc_ver()) +======= + if get_platform() != 'linux': + return 'N/A' + return '-'.join(platform.libc_ver()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_pip_packages(run_lambda, patterns=None): """Return `pip list` output. Note: will also find conda-installed pytorch and numpy packages.""" if patterns is None: +<<<<<<< HEAD patterns = PIP_PATTERNS + COMMON_PATTERNS + NVIDIA_PATTERNS + ONEAPI_PATTERNS pip_version = "pip3" if sys.version_info.major == 3 else "pip" @@ -649,22 +867,49 @@ def get_pip_packages(run_lambda, patterns=None): filtered_out = "\n".join( line for line in out.splitlines() if any(name in line for name in patterns) +======= + patterns = PIP_PATTERNS + COMMON_PATTERNS + NVIDIA_PATTERNS + + pip_version = 'pip3' if sys.version_info.major == 3 else 'pip' + + os.environ['PIP_DISABLE_PIP_VERSION_CHECK'] = '1' + # People generally have pip as `pip` or `pip3` + # But here it is invoked as `python -mpip` + out = run_and_read_all(run_lambda, [sys.executable, '-mpip', 'list', '--format=freeze']) + if out is None: + return pip_version, out + + filtered_out = '\n'.join( + line + for line in out.splitlines() + if any(name in line for name in patterns) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) return pip_version, filtered_out def get_cachingallocator_config(): +<<<<<<< HEAD ca_config = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "") if not ca_config: ca_config = os.environ.get("PYTORCH_HIP_ALLOC_CONF", "") +======= + ca_config = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', '') + if not ca_config: + ca_config = os.environ.get('PYTORCH_HIP_ALLOC_CONF', '') +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return ca_config def get_cuda_module_loading_config(): if TORCH_AVAILABLE and torch.cuda.is_available(): torch.cuda.init() +<<<<<<< HEAD config = os.environ.get("CUDA_MODULE_LOADING", "") +======= + config = os.environ.get('CUDA_MODULE_LOADING', '') +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return config else: return "N/A" @@ -673,12 +918,18 @@ def get_cuda_module_loading_config(): def is_xnnpack_available(): if TORCH_AVAILABLE: import torch.backends.xnnpack +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return str(torch.backends.xnnpack.enabled) # type: ignore[attr-defined] else: return "N/A" +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_env_info(): """ Collects environment information to aid in debugging. @@ -692,7 +943,11 @@ def get_env_info(): Caching allocator config, XNNPACK availability and CPU information. Returns: +<<<<<<< HEAD SystemEnv (namedtuple): A tuple containing various environment details +======= + SystemEnv (namedtuple): A tuple containining various environment details +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) and system information. """ run_lambda = run @@ -703,6 +958,7 @@ def get_env_info(): debug_mode_str = str(torch.version.debug) cuda_available_str = str(torch.cuda.is_available()) cuda_version_str = torch.version.cuda +<<<<<<< HEAD xpu_available_str = str(torch.xpu.is_available()) if torch.xpu.is_available(): xpu_available_str = ( @@ -730,6 +986,23 @@ def get_version_or_na(cfg, prefix): else: version_str = debug_mode_str = cuda_available_str = cuda_version_str = xpu_available_str = "N/A" # type: ignore[assignment] hip_compiled_version = hip_runtime_version = miopen_runtime_version = "N/A" +======= + if not hasattr(torch.version, 'hip') or torch.version.hip is None: # cuda version + hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A' + else: # HIP version + def get_version_or_na(cfg, prefix): + _lst = [s.rsplit(None, 1)[-1] for s in cfg if prefix in s] + return _lst[0] if _lst else 'N/A' + + cfg = torch._C._show_config().split('\n') + hip_runtime_version = get_version_or_na(cfg, 'HIP Runtime') + miopen_runtime_version = get_version_or_na(cfg, 'MIOpen') + cuda_version_str = 'N/A' + hip_compiled_version = torch.version.hip + else: + version_str = debug_mode_str = cuda_available_str = cuda_version_str = 'N/A' + hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sys_version = sys.version.replace("\n", " ") @@ -738,9 +1011,13 @@ def get_version_or_na(cfg, prefix): return SystemEnv( torch_version=version_str, is_debug_build=debug_mode_str, +<<<<<<< HEAD python_version="{} ({}-bit runtime)".format( sys_version, sys.maxsize.bit_length() + 1 ), +======= + python_version='{} ({}-bit runtime)'.format(sys_version, sys.maxsize.bit_length() + 1), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) python_platform=get_python_platform(), is_cuda_available=cuda_available_str, cuda_compiled_version=cuda_version_str, @@ -749,7 +1026,10 @@ def get_version_or_na(cfg, prefix): nvidia_gpu_models=get_gpu_info(run_lambda), nvidia_driver_version=get_nvidia_driver_version(run_lambda), cudnn_version=get_cudnn_version(run_lambda), +<<<<<<< HEAD is_xpu_available=xpu_available_str, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) hip_compiled_version=hip_compiled_version, hip_runtime_version=hip_runtime_version, miopen_runtime_version=miopen_runtime_version, @@ -766,7 +1046,10 @@ def get_version_or_na(cfg, prefix): cpu_info=get_cpu_info(run_lambda), ) +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) env_info_fmt = """ PyTorch version: {torch_version} Is debug build: {is_debug_build} @@ -787,7 +1070,10 @@ def get_version_or_na(cfg, prefix): GPU models and configuration: {nvidia_gpu_models} Nvidia driver version: {nvidia_driver_version} cuDNN version: {cudnn_version} +<<<<<<< HEAD Is XPU available: {is_xpu_available} +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) HIP runtime version: {hip_runtime_version} MIOpen runtime version: {miopen_runtime_version} Is XNNPACK available: {is_xnnpack_available} @@ -802,14 +1088,22 @@ def get_version_or_na(cfg, prefix): def pretty_str(envinfo): +<<<<<<< HEAD def replace_nones(dct, replacement="Could not collect"): +======= + def replace_nones(dct, replacement='Could not collect'): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for key in dct.keys(): if dct[key] is not None: continue dct[key] = replacement return dct +<<<<<<< HEAD def replace_bools(dct, true="Yes", false="No"): +======= + def replace_bools(dct, true='Yes', false='No'): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for key in dct.keys(): if dct[key] is True: dct[key] = true @@ -817,25 +1111,40 @@ def replace_bools(dct, true="Yes", false="No"): dct[key] = false return dct +<<<<<<< HEAD def prepend(text, tag="[prepend]"): lines = text.split("\n") updated_lines = [tag + line for line in lines] return "\n".join(updated_lines) def replace_if_empty(text, replacement="No relevant packages"): +======= + def prepend(text, tag='[prepend]'): + lines = text.split('\n') + updated_lines = [tag + line for line in lines] + return '\n'.join(updated_lines) + + def replace_if_empty(text, replacement='No relevant packages'): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if text is not None and len(text) == 0: return replacement return text def maybe_start_on_next_line(string): # If `string` is multiline, prepend a \n to it. +<<<<<<< HEAD if string is not None and len(string.split("\n")) > 1: return "\n{}\n".format(string) +======= + if string is not None and len(string.split('\n')) > 1: + return '\n{}\n'.format(string) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return string mutable_dict = envinfo._asdict() # If nvidia_gpu_models is multiline, start on the next line +<<<<<<< HEAD mutable_dict["nvidia_gpu_models"] = maybe_start_on_next_line( envinfo.nvidia_gpu_models ) @@ -859,6 +1168,25 @@ def maybe_start_on_next_line(string): mutable_dict[field] = "No CUDA" if envinfo.cuda_compiled_version is None: mutable_dict["cuda_compiled_version"] = "None" +======= + mutable_dict['nvidia_gpu_models'] = \ + maybe_start_on_next_line(envinfo.nvidia_gpu_models) + + # If the machine doesn't have CUDA, report some fields as 'No CUDA' + dynamic_cuda_fields = [ + 'cuda_runtime_version', + 'nvidia_gpu_models', + 'nvidia_driver_version', + ] + all_cuda_fields = dynamic_cuda_fields + ['cudnn_version'] + all_dynamic_cuda_fields_missing = all( + mutable_dict[field] is None for field in dynamic_cuda_fields) + if TORCH_AVAILABLE and not torch.cuda.is_available() and all_dynamic_cuda_fields_missing: + for field in all_cuda_fields: + mutable_dict[field] = 'No CUDA' + if envinfo.cuda_compiled_version is None: + mutable_dict['cuda_compiled_version'] = 'None' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Replace True with Yes, False with No mutable_dict = replace_bools(mutable_dict) @@ -867,6 +1195,7 @@ def maybe_start_on_next_line(string): mutable_dict = replace_nones(mutable_dict) # If either of these are '', replace with 'No relevant packages' +<<<<<<< HEAD mutable_dict["pip_packages"] = replace_if_empty(mutable_dict["pip_packages"]) mutable_dict["conda_packages"] = replace_if_empty(mutable_dict["conda_packages"]) @@ -881,6 +1210,20 @@ def maybe_start_on_next_line(string): mutable_dict["conda_packages"], "[conda] " ) mutable_dict["cpu_info"] = envinfo.cpu_info +======= + mutable_dict['pip_packages'] = replace_if_empty(mutable_dict['pip_packages']) + mutable_dict['conda_packages'] = replace_if_empty(mutable_dict['conda_packages']) + + # Tag conda and pip packages with a prefix + # If they were previously None, they'll show up as ie '[conda] Could not collect' + if mutable_dict['pip_packages']: + mutable_dict['pip_packages'] = prepend(mutable_dict['pip_packages'], + '[{}] '.format(envinfo.pip_version)) + if mutable_dict['conda_packages']: + mutable_dict['conda_packages'] = prepend(mutable_dict['conda_packages'], + '[conda] ') + mutable_dict['cpu_info'] = envinfo.cpu_info +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return env_info_fmt.format(**mutable_dict) @@ -904,6 +1247,7 @@ def main(): output = get_pretty_env_info() print(output) +<<<<<<< HEAD if ( TORCH_AVAILABLE and hasattr(torch, "utils") @@ -929,4 +1273,20 @@ def main(): if __name__ == "__main__": +======= + if TORCH_AVAILABLE and hasattr(torch, 'utils') and hasattr(torch.utils, '_crash_handler'): + minidump_dir = torch.utils._crash_handler.DEFAULT_MINIDUMP_DIR + if sys.platform == "linux" and os.path.exists(minidump_dir): + dumps = [os.path.join(minidump_dir, dump) for dump in os.listdir(minidump_dir)] + latest = max(dumps, key=os.path.getctime) + ctime = os.path.getctime(latest) + creation_time = datetime.datetime.fromtimestamp(ctime).strftime('%Y-%m-%d %H:%M:%S') + msg = "\n*** Detected a minidump at {} created on {}, ".format(latest, creation_time) + \ + "if this is related to your bug please include it when you file a report ***" + print(msg, file=sys.stderr) + + + +if __name__ == '__main__': +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) main() diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index 902d2fe6ce0f5..734ab467f9ebf 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -22,8 +22,14 @@ import torch._appdirs from .file_baton import FileBaton from ._cpp_extension_versioner import ExtensionVersioner +<<<<<<< HEAD from typing import Optional, Union from typing_extensions import deprecated +======= +from .hipify import hipify_python +from .hipify.hipify_python import GeneratedFileCleaner +from typing import Optional, Union +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.torch_version import TorchVersion, Version from setuptools.command.build_ext import build_ext @@ -274,6 +280,7 @@ def _join_sycl_home(*paths) -> str: '-DHIP_ENABLE_WARP_SYNC_BUILTINS=1' ] +<<<<<<< HEAD if IS_WINDOWS: # Compatibility flags, similar to those set in cmake/Dependencies.cmake. COMMON_HIPCC_FLAGS.append('-fms-extensions') @@ -290,6 +297,8 @@ def _get_icpx_version() -> str: assert len(version) == 3, "Failed to parse DPC++ compiler version" # Aligning version format with what torch.version.xpu() returns return f"{version[0]}{version[1]:02}{version[2]:02}" +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _get_sycl_arch_list(): @@ -318,10 +327,17 @@ def _append_sycl_targets_if_missing(cflags): cflags.append('-fsycl-targets=spir64') def _get_sycl_device_flags(cflags): +<<<<<<< HEAD # We need last occurrence of -fsycl-targets as it will be the one taking effect. # So searching in reversed list. flags = [f for f in reversed(cflags) if f.startswith('-fsycl-targets=')] assert flags, "bug: -fsycl-targets should have been amended to cflags" +======= + # We need last occurence of -fsycl-targets as it will be the one taking effect. + # So searching in reversed list. + flags = [f for f in reversed(cflags) if f.startswith('-fsycl-targets=')] + assert flags, "bug: -fsycl-targets should have been ammended to cflags" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) arch_list = _get_sycl_arch_list() if arch_list != '': @@ -469,18 +485,32 @@ def get_compiler_abi_compatibility_and_version(compiler) -> tuple[bool, TorchVer try: if IS_LINUX: minimum_required_version = MINIMUM_GCC_VERSION +<<<<<<< HEAD compiler_info = subprocess.check_output([compiler, '-dumpfullversion', '-dumpversion']) else: minimum_required_version = MINIMUM_MSVC_VERSION compiler_info = subprocess.check_output(compiler, stderr=subprocess.STDOUT) match = re.search(r'(\d+)\.(\d+)\.(\d+)', compiler_info.decode(*SUBPROCESS_DECODE_ARGS).strip()) version = ['0', '0', '0'] if match is None else list(match.groups()) +======= + versionstr = subprocess.check_output([compiler, '-dumpfullversion', '-dumpversion']) + version = versionstr.decode(*SUBPROCESS_DECODE_ARGS).strip().split('.') + else: + minimum_required_version = MINIMUM_MSVC_VERSION + compiler_info = subprocess.check_output(compiler, stderr=subprocess.STDOUT) + match = re.search(r'(\d+)\.(\d+)\.(\d+)', compiler_info.decode(*SUBPROCESS_DECODE_ARGS).strip()) + version = ['0', '0', '0'] if match is None else list(match.groups()) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) except Exception: _, error, _ = sys.exc_info() logger.warning('Error checking compiler version for %s: %s', compiler, error) return (False, TorchVersion('0.0.0')) +<<<<<<< HEAD # convert alphanumeric string to numeric string +======= + # convert alpha-numeric string to numeric string +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # amdclang++ returns str like 0.0.0git, others return 0.0.0 numeric_version = [re.sub(r'\D', '', v) for v in version] @@ -552,7 +582,10 @@ def _check_cuda_version(compiler_name: str, compiler_version: TorchVersion) -> N f'Please make sure to use an adequate version of {compiler_name} ({version_bound_str}).' ) +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Specify Visual Studio C runtime library for hipcc def _set_hipcc_runtime_lib(is_standalone, debug): if is_standalone: @@ -688,6 +721,18 @@ def build_extensions(self) -> None: # min supported CPython version. # See https://docs.python.org/3/c-api/stable.html#c.Py_LIMITED_API self._add_compile_flag(extension, f'-DPy_LIMITED_API={min_supported_cpython}') +<<<<<<< HEAD +======= + else: + # pybind11 is not CPython API stable so don't add these flags used when + # compiling pybind11 when pybind11 is not even used. otherwise, the build + # logs are confusing. + # See note [Pybind11 ABI constants] + for name in ["COMPILER_TYPE", "STDLIB", "BUILD_ABI"]: + val = getattr(torch._C, f"_PYBIND11_{name}") + if val is not None and not IS_WINDOWS: + self._add_compile_flag(extension, f'-DPYBIND11_{name}="{val}"') +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._define_torch_extension_name(extension) if 'nvcc_dlink' in extension.extra_compile_args: @@ -776,7 +821,11 @@ def unix_wrap_ninja_compile(sources, r"""Compiles sources by outputting a ninja file and running it.""" # NB: I copied some lines from self.compiler (which is an instance # of distutils.UnixCCompiler). See the following link. +<<<<<<< HEAD # https://github.com/python/cpython/blob/f03a8f8d5001963ad5b5b28dbd95497e9cc15596/Lib/distutils/ccompiler.py#L564-L567 # codespell:ignore +======= + # https://github.com/python/cpython/blob/f03a8f8d5001963ad5b5b28dbd95497e9cc15596/Lib/distutils/ccompiler.py#L564-L567 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This can be fragile, but a lot of other repos also do this # (see https://github.com/search?q=_setup_compile&type=Code) # so it is probably OK; we'll also get CI signal if/when @@ -848,11 +897,15 @@ def unix_wrap_ninja_compile(sources, host_cflags = extra_cc_cflags + common_cflags + post_cflags append_std17_if_no_std_present(host_cflags) # escaping quoted arguments to pass them thru SYCL compiler +<<<<<<< HEAD icpx_version = _get_icpx_version() if int(icpx_version) >= 20250200: host_cflags = [item.replace('"', '\\"') for item in host_cflags] else: host_cflags = [item.replace('"', '\\\\"') for item in host_cflags] +======= + host_cflags = [item.replace('"', '\\\\"') for item in host_cflags] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) host_cflags = ' '.join(host_cflags) # Note the order: shlex.quote sycl_flags first, _wrap_sycl_host_flags # second. Reason is that sycl host flags are quoted, space containing @@ -909,7 +962,11 @@ def spawn(cmd): if m ] +<<<<<<< HEAD obj_regex = re.compile('/Fo(.*)') # codespell:ignore +======= + obj_regex = re.compile('/Fo(.*)') +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) obj_list = [ m.group(1) for m in (obj_regex.match(elem) for elem in cmd) if m @@ -1067,7 +1124,11 @@ def win_wrap_ninja_compile(sources, # Return *all* object filenames, not just the ones we just built. return objects # Monkey-patch the _compile or compile method. +<<<<<<< HEAD # https://github.com/python/cpython/blob/dc0284ee8f7a270b6005467f26d8e5773d76e959/Lib/distutils/ccompiler.py#L511 # codespell:ignore +======= + # https://github.com/python/cpython/blob/dc0284ee8f7a270b6005467f26d8e5773d76e959/Lib/distutils/ccompiler.py#L511 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if self.compiler.compiler_type == 'msvc': if self.use_ninja: self.compiler.compile = win_wrap_ninja_compile @@ -1367,7 +1428,10 @@ def CUDAExtension(name, sources, *args, **kwargs): include_dirs = kwargs.get('include_dirs', []) if IS_HIP_EXTENSION: +<<<<<<< HEAD from .hipify import hipify_python +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) build_dir = os.getcwd() hipify_result = hipify_python.hipify( project_directory=build_dir, @@ -1705,9 +1769,31 @@ def load(name, is_standalone, keep_intermediates=keep_intermediates) +<<<<<<< HEAD @deprecated("PyBind11 ABI handling is internal to PyBind11; this will be removed after PyTorch 2.9.0") def _get_pybind11_abi_build_flags() -> list[str]: return [] +======= +def _get_pybind11_abi_build_flags(): + # Note [Pybind11 ABI constants] + # + # Pybind11 before 2.4 used to build an ABI strings using the following pattern: + # f"__pybind11_internals_v{PYBIND11_INTERNALS_VERSION}{PYBIND11_INTERNALS_KIND}{PYBIND11_BUILD_TYPE}__" + # Since 2.4 compier type, stdlib and build abi parameters are also encoded like this: + # f"__pybind11_internals_v{PYBIND11_INTERNALS_VERSION}{PYBIND11_INTERNALS_KIND}{PYBIND11_COMPILER_TYPE}{PYBIND11_STDLIB}{PYBIND11_BUILD_ABI}{PYBIND11_BUILD_TYPE}__" + # + # This was done in order to further narrow down the chances of compiler ABI incompatibility + # that can cause a hard to debug segfaults. + # For PyTorch extensions we want to relax those restrictions and pass compiler, stdlib and abi properties + # captured during PyTorch native library compilation in torch/csrc/Module.cpp + + abi_cflags = [] + for pname in ["COMPILER_TYPE", "STDLIB", "BUILD_ABI"]: + pval = getattr(torch._C, f"_PYBIND11_{pname}") + if pval is not None and not IS_WINDOWS: + abi_cflags.append(f'-DPYBIND11_{pname}=\\"{pval}\\"') + return abi_cflags +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def check_compiler_is_gcc(compiler): if not IS_LINUX: @@ -1739,7 +1825,11 @@ def _check_and_build_extension_h_precompiler_headers( is_standalone=False): r''' Precompiled Headers(PCH) can pre-build the same headers and reduce build time for pytorch load_inline modules. +<<<<<<< HEAD GCC official manual: https://gcc.gnu.org/onlinedocs/gcc-4.0.4/gcc/Precompiled-Headers.html +======= + GCC offical manual: https://gcc.gnu.org/onlinedocs/gcc-4.0.4/gcc/Precompiled-Headers.html +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) PCH only works when built pch file(header.h.gch) and build target have the same build parameters. So, We need add a signature file to record PCH file parameters. If the build parameters(signature) changed, it should rebuild PCH file. @@ -1838,6 +1928,10 @@ def build_precompile_header(pch_cmd): common_cflags += ['-DTORCH_API_INCLUDE_EXTENSION_H'] common_cflags += ['-std=c++17', '-fPIC'] +<<<<<<< HEAD +======= + common_cflags += [f"{x}" for x in _get_pybind11_abi_build_flags()] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) common_cflags_str = listToString(common_cflags) pch_cmd = format_precompiler_header_cmd(compiler, head_file, head_file_pch, common_cflags_str, torch_include_dirs_str, extra_cflags_str, extra_include_paths_str) @@ -2108,8 +2202,11 @@ def _jit_compile(name, if baton.try_acquire(): try: if version != old_version: +<<<<<<< HEAD from .hipify import hipify_python from .hipify.hipify_python import GeneratedFileCleaner +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) with GeneratedFileCleaner(keep_intermediates=keep_intermediates) as clean_ctx: if IS_HIP_EXTENSION and (with_cuda or with_cudnn): hipify_result = hipify_python.hipify( @@ -2400,13 +2497,21 @@ def _get_cuda_arch_flags(cflags: Optional[list[str]] = None) -> list[str]: ('Ampere', '8.0;8.6+PTX'), ('Ada', '8.9+PTX'), ('Hopper', '9.0+PTX'), +<<<<<<< HEAD ('Blackwell+Tegra', '11.0'), +======= + ('Blackwell+Tegra', '10.1'), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ('Blackwell', '10.0;10.3;12.0;12.1+PTX'), ]) supported_arches = ['3.5', '3.7', '5.0', '5.2', '5.3', '6.0', '6.1', '6.2', '7.0', '7.2', '7.5', '8.0', '8.6', '8.7', '8.9', '9.0', '9.0a', +<<<<<<< HEAD '10.0', '10.0a', '11.0', '11.0a', '10.3', '10.3a', '12.0', +======= + '10.0', '10.0a', '10.1', '10.1a', '10.3', '10.3a', '12.0', +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) '12.0a', '12.1', '12.1a'] valid_arch_strings = supported_arches + [s + "+PTX" for s in supported_arches] @@ -2416,8 +2521,16 @@ def _get_cuda_arch_flags(cflags: Optional[list[str]] = None) -> list[str]: # See cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake _arch_list = os.environ.get('TORCH_CUDA_ARCH_LIST', None) +<<<<<<< HEAD # If not given or set as native, determine what's best for the GPU / CUDA version that can be found if not _arch_list or _arch_list == "native": +======= + # If not given, determine what's best for the GPU / CUDA version that can be found + if not _arch_list: + logger.warning( + "TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. \n" + "If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) arch_list = [] # the assumption is that the extension should run on any of the currently visible cards, # which could be of different types - therefore all archs for visible cards should be included @@ -2436,6 +2549,7 @@ def _get_cuda_arch_flags(cflags: Optional[list[str]] = None) -> list[str]: arch_list.append(arch) arch_list = sorted(arch_list) arch_list[-1] += '+PTX' +<<<<<<< HEAD if not _arch_list: # Only log on rank 0 in distributed settings to avoid spam @@ -2445,12 +2559,19 @@ def _get_cuda_arch_flags(cflags: Optional[list[str]] = None) -> list[str]: "TORCH_CUDA_ARCH_LIST is not set, using TORCH_CUDA_ARCH_LIST='%s' " "for visible GPU architectures. Set os.environ['TORCH_CUDA_ARCH_LIST'] to override.", arch_list_str) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) else: # Deal with lists that are ' ' separated (only deal with ';' after) _arch_list = _arch_list.replace(' ', ';') # Expand named arches +<<<<<<< HEAD for named_arch, archival in named_arches.items(): _arch_list = _arch_list.replace(named_arch, archival) +======= + for named_arch, archval in named_arches.items(): + _arch_list = _arch_list.replace(named_arch, archval) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) arch_list = _arch_list.split(';') @@ -2585,7 +2706,11 @@ def _run_ninja_build(build_directory: str, verbose: bool, error_prefix: str) -> # subprocess.run assumes that sys.__stdout__ has not been modified and # attempts to write to it by default. However, when we call _run_ninja_build # from ahead-of-time cpp extensions, the following happens: +<<<<<<< HEAD # 1) If the stdout encoding is not utf-8, setuptools detaches __stdout__. +======= + # 1) If the stdout encoding is not utf-8, setuptools detachs __stdout__. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # https://github.com/pypa/setuptools/blob/7e97def47723303fafabe48b22168bbc11bb4821/setuptools/dist.py#L1110 # (it probably shouldn't do this) # 2) subprocess.run (on POSIX, with no stdout override) relies on @@ -2679,6 +2804,11 @@ def _write_ninja_file_to_build_library(path, common_cflags.append(f'-DTORCH_EXTENSION_NAME={name}') common_cflags.append('-DTORCH_API_INCLUDE_EXTENSION_H') +<<<<<<< HEAD +======= + common_cflags += [f"{x}" for x in _get_pybind11_abi_build_flags()] + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Windows does not understand `-isystem` and quotes flags later. if IS_WINDOWS: common_cflags += [f'-I{include}' for include in user_includes + system_includes] @@ -2698,7 +2828,11 @@ def _write_ninja_file_to_build_library(path, cuda_flags += _get_rocm_arch_flags(cuda_flags) cuda_flags += extra_cuda_cflags elif with_cuda: +<<<<<<< HEAD cuda_flags = common_cflags + COMMON_NVCC_FLAGS + _get_cuda_arch_flags(extra_cuda_cflags) +======= + cuda_flags = common_cflags + COMMON_NVCC_FLAGS + _get_cuda_arch_flags() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if IS_WINDOWS: for flag in COMMON_MSVC_FLAGS: cuda_flags = ['-Xcompiler', flag] + cuda_flags @@ -2725,9 +2859,13 @@ def _write_ninja_file_to_build_library(path, _append_sycl_std_if_no_std_present(sycl_cflags) host_cflags = cflags # escaping quoted arguments to pass them thru SYCL compiler +<<<<<<< HEAD icpx_version = _get_icpx_version() if int(icpx_version) < 20250200: host_cflags = [item.replace('\\"', '\\\\"') for item in host_cflags] +======= + host_cflags = [item.replace('\\"', '\\\\"') for item in host_cflags] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) host_cflags = ' '.join(host_cflags) sycl_cflags += _wrap_sycl_host_flags(host_cflags) sycl_dlink_post_cflags = _SYCL_DLINK_FLAGS.copy() @@ -2874,9 +3012,13 @@ def sanitize_flags(flags): if IS_WINDOWS: compiler_name = "$cxx" if IS_HIP_EXTENSION else "cl" compile_rule.append( +<<<<<<< HEAD f' command = {compiler_name} ' '/showIncludes $cflags -c $in /Fo$out $post_cflags' # codespell:ignore ) +======= + f' command = {compiler_name} /showIncludes $cflags -c $in /Fo$out $post_cflags') +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if not IS_HIP_EXTENSION: compile_rule.append(' deps = msvc') else: diff --git a/torch/utils/data/_utils/__init__.py b/torch/utils/data/_utils/__init__.py index 44111ef697b71..881380a114621 100644 --- a/torch/utils/data/_utils/__init__.py +++ b/torch/utils/data/_utils/__init__.py @@ -1,3 +1,7 @@ +<<<<<<< HEAD +======= +# mypy: allow-untyped-defs +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r"""Utility classes & functions for data loading. Code in this folder is mostly used by ../dataloder.py. A lot of multiprocessing is used in data loading, which only supports running @@ -42,7 +46,11 @@ HAS_NUMPY = False +<<<<<<< HEAD def _set_python_exit_flag() -> None: +======= +def _set_python_exit_flag(): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) global python_exit_status python_exit_status = True diff --git a/torch/utils/data/_utils/collate.py b/torch/utils/data/_utils/collate.py index 3b291b1e60a4c..2356058f08642 100644 --- a/torch/utils/data/_utils/collate.py +++ b/torch/utils/data/_utils/collate.py @@ -44,7 +44,11 @@ def default_convert(data): >>> default_convert(np.array([0, 1])) tensor([0, 1]) >>> # Example with NamedTuple +<<<<<<< HEAD >>> Point = namedtuple("Point", ["x", "y"]) +======= + >>> Point = namedtuple('Point', ['x', 'y']) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> default_convert(Point(0, 0)) Point(x=0, y=0) >>> default_convert(Point(np.array(0), np.array(0))) @@ -366,6 +370,7 @@ def default_collate(batch): >>> default_collate([0, 1, 2, 3]) tensor([0, 1, 2, 3]) >>> # Example with a batch of `str`s: +<<<<<<< HEAD >>> default_collate(["a", "b", "c"]) ['a', 'b', 'c'] >>> # Example with `Map` inside the batch: @@ -373,6 +378,15 @@ def default_collate(batch): {'A': tensor([ 0, 100]), 'B': tensor([ 1, 100])} >>> # Example with `NamedTuple` inside the batch: >>> Point = namedtuple("Point", ["x", "y"]) +======= + >>> default_collate(['a', 'b', 'c']) + ['a', 'b', 'c'] + >>> # Example with `Map` inside the batch: + >>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}]) + {'A': tensor([ 0, 100]), 'B': tensor([ 1, 100])} + >>> # Example with `NamedTuple` inside the batch: + >>> Point = namedtuple('Point', ['x', 'y']) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> default_collate([Point(0, 0), Point(1, 1)]) Point(x=tensor([0, 1]), y=tensor([0, 1])) >>> # Example with `Tuple` inside the batch: diff --git a/torch/utils/data/_utils/pin_memory.py b/torch/utils/data/_utils/pin_memory.py index b53c7aef9596f..3433247ad7f75 100644 --- a/torch/utils/data/_utils/pin_memory.py +++ b/torch/utils/data/_utils/pin_memory.py @@ -21,7 +21,20 @@ def _pin_memory_loop(in_queue, out_queue, device_id, done_event, device): torch.set_num_threads(1) torch.multiprocessing._set_thread_name("pt_data_pin") +<<<<<<< HEAD torch.accelerator.set_device_index(device_id) +======= + + if device == "cuda": + torch.cuda.set_device(device_id) + elif device == "xpu": + torch.xpu.set_device(device_id) # type: ignore[attr-defined] + elif device == torch._C._get_privateuse1_backend_name(): + custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name()) + custom_device_mod.set_device(device_id) + elif device is None: + torch.accelerator.set_device_index(device_id) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def do_one_step(): try: @@ -69,9 +82,13 @@ def pin_memory(data, device=None): ) return clone else: +<<<<<<< HEAD return type(data)( {k: pin_memory(sample, device) for k, sample in data.items()} ) # type: ignore[call-arg] +======= + return type(data)({k: pin_memory(sample, device) for k, sample in data.items()}) # type: ignore[call-arg] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) except TypeError: # The mapping type may not support `copy()` / `update(mapping)` # or `__init__(iterable)`. diff --git a/torch/utils/data/_utils/worker.py b/torch/utils/data/_utils/worker.py index 97c7243e78ef7..cc053363f3c05 100644 --- a/torch/utils/data/_utils/worker.py +++ b/torch/utils/data/_utils/worker.py @@ -1,5 +1,9 @@ # mypy: allow-untyped-defs +<<<<<<< HEAD r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers. +======= +r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) These **needs** to be in global scope since Py2 doesn't support serializing static methods. diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py index 991b4f00eb85e..f2f1b3f562910 100644 --- a/torch/utils/data/dataloader.py +++ b/torch/utils/data/dataloader.py @@ -5,7 +5,10 @@ functions to be run in multiprocessing. E.g., the data loading worker loop is in `./_utils/worker.py`. """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from __future__ import annotations import functools @@ -191,8 +194,14 @@ class DataLoader(Generic[_T_co]): persistent_workers (bool, optional): If ``True``, the data loader will not shut down the worker processes after a dataset has been consumed once. This allows to maintain the workers `Dataset` instances alive. (default: ``False``) +<<<<<<< HEAD pin_memory_device (str, optional): Deprecated, the current :ref:`accelerator` will be used as the device if ``pin_memory=True``. +======= + pin_memory_device (str, optional): the device to :attr:`pin_memory` on if ``pin_memory`` is + ``True``. If not given, the current :ref:`accelerator` will be the + default. This argument is discouraged and subject to deprecated. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) in_order (bool, optional): If ``False``, the data loader will not enforce that batches are returned in a first-in, first-out order. Only applies when ``num_workers > 0``. (default: ``True``) @@ -480,7 +489,11 @@ def __setattr__(self, attr, val): def __iter__(self) -> _BaseDataLoaderIter: # When using a single worker the returned iterator should be +<<<<<<< HEAD # created every time to avoid resetting its state +======= + # created everytime to avoid resetting its state +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # However, in the case of a multiple workers iterator # the iterator is only created once in the lifetime of the # DataLoader object so that workers can be reused @@ -556,10 +569,17 @@ def check_worker_number_rationality(self): # necessary. # # +<<<<<<< HEAD # [Note] Please note that this function respects `cpuset` only when os.sched_getaffinity is # available (available in most of Linux system, but not OSX and Windows). # When os.sched_getaffinity is not available, os.cpu_count() is called instead, but # it doesn't respect cpuset. +======= + # [Note] Please note that this function repects `cpuset` only when os.sched_getaffinity is + # available (available in most of Linux system, but not OSX and Windows). + # When os.sched_getaffinity is not available, os.cpu_count() is called instead, but + # it doesn't repect cpuset. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # We don't take threading into account since each worker process is single threaded # at this time. # @@ -654,6 +674,7 @@ def __init__(self, loader: DataLoader) -> None: ws, rank = _get_distributed_settings() self._world_size = ws self._rank = rank +<<<<<<< HEAD if loader.pin_memory and loader.pin_memory_device: warnings.warn( @@ -691,6 +712,47 @@ def __init__(self, loader: DataLoader) -> None: ) warnings.warn(warn_msg) +======= + # If pin_memory_device not set, default behaviour is current accelerator. + # If pin_memory_device is set but pin_memory is not set, the default + # behaviour false. + if len(loader.pin_memory_device) == 0: + if loader.pin_memory and not torch.accelerator.is_available(): + warn_msg = ( + "'pin_memory' argument is set as true but no accelerator is found, " + "then device pinned memory won't be used." + ) + warnings.warn(warn_msg) + + self._pin_memory = loader.pin_memory and torch.accelerator.is_available() + self._pin_memory_device = None + # Currently, pin_memory would raise error on the MPS backend (see + # https://github.com/pytorch/pytorch/issues/86060), so forcibly + # disable pin_memory on MPS. Remove this restriction once pinned + # memory allocation for MPS is fixed. + if ( + self._pin_memory + and (acc := torch.accelerator.current_accelerator()) is not None + and acc.type == "mps" + ): + self._pin_memory = False + warn_msg = ( + "'pin_memory' argument is set as true but not supported on MPS now, " + "then device pinned memory won't be used." + ) + warnings.warn(warn_msg) + else: + if not loader.pin_memory: + warn_msg = ( + "'pin_memory_device' is set but 'pin_memory' argument is not set, " + "then device pinned memory won't be used." + "please set 'pin_memory' to true, if you need to use the device pin memory" + ) + warnings.warn(warn_msg) + + self._pin_memory = loader.pin_memory + self._pin_memory_device = loader.pin_memory_device +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._timeout = loader.timeout self._collate_fn = loader.collate_fn self._sampler_iter = iter(self._index_sampler) @@ -884,7 +946,11 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter): # 2. A similar issue araises when a `DataLoader` is used in a subprocess. # When a process ends, it shuts the all its daemonic children # down with a SIGTERM (instead of joining them without a timeout). +<<<<<<< HEAD # Similarly for threads, but by a different mechanism. This fact, +======= + # Simiarly for threads, but by a different mechanism. This fact, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # together with a few implementation details of multiprocessing, forces # us to make workers daemonic. All of our problems arise when a # DataLoader is used in a subprocess, and are caused by multiprocessing @@ -1015,7 +1081,11 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter): # `cancel_join_thread` on that queue if its `IterableDataset` iterator # happens to exhaust coincidentally, which is out of the control of the # main process). Thus, since we will exit `pin_memory_thread` before the +<<<<<<< HEAD # workers (see below), two separate events are used. +======= + # workers (see below), two separete events are used. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # # NOTE: In short, the protocol is that the main process will set these # `done_event`s and then the corresponding processes/threads a `None`, @@ -1176,13 +1246,32 @@ def __init__(self, loader): # Queue is not type-annotated self._data_queue = queue.Queue() # type: ignore[var-annotated] +<<<<<<< HEAD current_device_id = torch.accelerator.current_device_index() +======= + current_device = -1 + if self._pin_memory_device == "cuda": + current_device = torch.cuda.current_device() + elif self._pin_memory_device == "xpu": + current_device = torch.xpu.current_device() + elif self._pin_memory_device == torch._C._get_privateuse1_backend_name(): + custom_device_mod = getattr( + torch, torch._C._get_privateuse1_backend_name() + ) + current_device = custom_device_mod.current_device() + elif self._pin_memory_device is None: + current_device = torch.accelerator.current_device_index() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) pin_memory_thread = threading.Thread( target=_utils.pin_memory._pin_memory_loop, args=( self._worker_result_queue, self._data_queue, +<<<<<<< HEAD current_device_id, +======= + current_device, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self._pin_memory_thread_done_event, self._pin_memory_device, ), @@ -1209,10 +1298,14 @@ def __init__(self, loader): atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w) # .pid can be None only before process is spawned (not the case, so ignore) +<<<<<<< HEAD _utils.signal_handling._set_worker_pids( id(self), tuple(w.pid for w in self._workers), # type: ignore[misc] ) +======= + _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore[misc] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) _utils.signal_handling._set_SIGCHLD_handler() self._worker_pids_set = True self._reset(loader, first_iter=True) @@ -1561,7 +1654,11 @@ def _mark_worker_as_unavailable(self, worker_id, shutdown=False): # (2) since we don't join, the worker may still raise error, and we # prefer capturing those, rather than ignoring them, even though they # are raised after the worker has finished its job. +<<<<<<< HEAD # Joining is deferred to `_shutdown_workers`, which it is called when +======= + # Joinning is deferred to `_shutdown_workers`, which it is called when +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # all workers finish their jobs (e.g., `IterableDataset` replicas) or # when this iterator is garbage collected. diff --git a/torch/utils/data/datapipes/README.md b/torch/utils/data/datapipes/README.md index e8776bc39b87b..f0e37f53f9190 100644 --- a/torch/utils/data/datapipes/README.md +++ b/torch/utils/data/datapipes/README.md @@ -51,7 +51,11 @@ Note that `__len__` method is optional for `IterDataPipe`. Like `CSVParserIterDataPipe` in the [Using DataPipe sector](#using-datapipe), `__len__` is not implemented because the size of each file streams is unknown for us before loading it. Besides, in some special cases, `__len__` method can be provided, but it would either return an integer length or raise Error depending on the arguments of DataPipe. +<<<<<<< HEAD And, the Error is required to be `TypeError` to support Python's built-in functions like `list(dp)`. +======= +And, the Error is required to be `TypeError` to support Python's build-in functions like `list(dp)`. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Please check NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] for detailed reason in PyTorch. ### Registering DataPipe with functional API diff --git a/torch/utils/data/datapipes/_decorator.py b/torch/utils/data/datapipes/_decorator.py index 0833f8fdf759b..d9b90339be71c 100644 --- a/torch/utils/data/datapipes/_decorator.py +++ b/torch/utils/data/datapipes/_decorator.py @@ -109,7 +109,12 @@ def __call__(self, *args, **kwargs): # Decorate with a functional argument if not ( +<<<<<<< HEAD isinstance(args[0], type) and issubclass(args[0], IterDataPipe) # type: ignore[arg-type] +======= + isinstance(args[0], type) + and issubclass(args[0], IterDataPipe) # type: ignore[arg-type] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ): raise TypeError( f"Only `IterDataPipe` can be decorated, but {args[0].__name__} is found" diff --git a/torch/utils/data/datapipes/_typing.py b/torch/utils/data/datapipes/_typing.py index d3ae5b4e18f4c..111bfd1c33bc3 100644 --- a/torch/utils/data/datapipes/_typing.py +++ b/torch/utils/data/datapipes/_typing.py @@ -138,7 +138,11 @@ def _issubtype_with_constraints(variant, constraints, recursive=True): # - TypeVar[TypeVar[...]] # So, variant and each constraint may be a TypeVar or a Union. # In these cases, all of inner types from the variant are required to be +<<<<<<< HEAD # extracted and verified as a subtype of any constraint. And, all of +======= + # extraced and verified as a subtype of any constraint. And, all of +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # inner types from any constraint being a TypeVar or a Union are # also required to be extracted and verified if the variant belongs to # any of them. diff --git a/torch/utils/data/datapipes/dataframe/dataframes.py b/torch/utils/data/datapipes/dataframe/dataframes.py index d697cb6ebc5c2..57a26668f6635 100644 --- a/torch/utils/data/datapipes/dataframe/dataframes.py +++ b/torch/utils/data/datapipes/dataframe/dataframes.py @@ -51,7 +51,11 @@ def __iter__(self): yield self.output_var.apply_ops(item) +<<<<<<< HEAD # TODO(VitalyFedyunin): Extract this list from the DFIterDataPipe registered functions +======= +# TODO(VitalyFedyunin): Extract this list from the DFIterDataPipe registred functions +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DATAPIPES_OPS = [ "_dataframes_as_tuples", "groupby", @@ -201,7 +205,11 @@ class CaptureLikeMock: def __init__(self, name): import unittest.mock as mock +<<<<<<< HEAD # TODO(VitalyFedyunin): Do not use private function here, copy own implementation instead. +======= + # TODO(VitalyFedyunin): Do not use provate function here, copy own implementation instead. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) get_target, attribute = mock._get_target(name) # type: ignore[attr-defined] self.get_target = get_target self.attribute = attribute diff --git a/torch/utils/data/datapipes/datapipe.py b/torch/utils/data/datapipes/datapipe.py index 506f642c411db..41c60a05afd80 100644 --- a/torch/utils/data/datapipes/datapipe.py +++ b/torch/utils/data/datapipes/datapipe.py @@ -99,9 +99,13 @@ class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta): >>> from torchdata.datapipes.iter import IterableWrapper, Mapper >>> dp = IterableWrapper(range(10)) >>> map_dp_1 = Mapper(dp, lambda x: x + 1) # Using class constructor +<<<<<<< HEAD >>> map_dp_2 = dp.map( ... lambda x: x + 1 ... ) # Using functional form (recommended) +======= + >>> map_dp_2 = dp.map(lambda x: x + 1) # Using functional form (recommended) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> list(map_dp_1) [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] >>> list(map_dp_2) @@ -116,9 +120,13 @@ class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta): >>> list(it1) [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] >>> it1 = iter(source_dp) +<<<<<<< HEAD >>> it2 = iter( ... source_dp ... ) # The creation of a new iterator invalidates `it1` +======= + >>> it2 = iter(source_dp) # The creation of a new iterator invalidates `it1` +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> next(it2) 0 >>> next(it1) # Further usage of `it1` will raise a `RunTimeError` diff --git a/torch/utils/data/datapipes/iter/callable.py b/torch/utils/data/datapipes/iter/callable.py index 41c6bb362af2b..8a194df5876c3 100644 --- a/torch/utils/data/datapipes/iter/callable.py +++ b/torch/utils/data/datapipes/iter/callable.py @@ -4,7 +4,10 @@ from collections.abc import Iterator, Sized from typing import Any, Callable, Optional, TypeVar, Union +<<<<<<< HEAD import torch +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.utils.data._utils.collate import default_collate from torch.utils.data.datapipes._decorator import functional_datapipe from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper @@ -55,8 +58,12 @@ class MapperIterDataPipe(IterDataPipe[_T_co]): >>> def add_one(x): ... return x + 1 >>> dp = IterableWrapper(range(10)) +<<<<<<< HEAD >>> # Invocation via functional form is preferred ... map_dp_1 = dp.map(add_one) +======= + >>> map_dp_1 = dp.map(add_one) # Invocation via functional form is preferred +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> list(map_dp_1) [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] >>> # We discourage the usage of `lambda` functions as they are not serializable with `pickle` @@ -76,7 +83,10 @@ def __init__( input_col=None, output_col=None, ) -> None: +<<<<<<< HEAD torch._C._log_api_usage_once("python.data_pipes.map") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__() self.datapipe = datapipe @@ -150,7 +160,11 @@ def _collate_helper(conversion, item): for name in conversion.keys(): if name not in columns_name: +<<<<<<< HEAD raise RuntimeError("Conversion keys mismatch") +======= + raise RuntimeError("Conversion keys missmatch") +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for name in columns_name: if name in conversion: @@ -203,7 +217,11 @@ class CollatorIterDataPipe(MapperIterDataPipe): >>> class MyIterDataPipe(torch.utils.data.IterDataPipe): ... def __init__(self, start, end): ... super(MyIterDataPipe).__init__() +<<<<<<< HEAD ... assert end > start, "this example only works with end >= start" +======= + ... assert end > start, "this example code only works with end >= start" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ... self.start = start ... self.end = end ... @@ -212,11 +230,19 @@ class CollatorIterDataPipe(MapperIterDataPipe): ... ... def __len__(self): ... return self.end - self.start +<<<<<<< HEAD +======= + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> ds = MyIterDataPipe(start=3, end=7) >>> print(list(ds)) [3, 4, 5, 6] >>> def collate_fn(batch): ... return torch.tensor(batch, dtype=torch.float) +<<<<<<< HEAD +======= + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> collated_ds = CollateIterDataPipe(ds, collate_fn=collate_fn) >>> print(list(collated_ds)) [tensor(3.), tensor(4.), tensor(5.), tensor(6.)] diff --git a/torch/utils/data/datapipes/iter/combinatorics.py b/torch/utils/data/datapipes/iter/combinatorics.py index f92edd6b7b39c..5992a3a68319e 100644 --- a/torch/utils/data/datapipes/iter/combinatorics.py +++ b/torch/utils/data/datapipes/iter/combinatorics.py @@ -38,17 +38,27 @@ def __init__( sampler_args: Optional[tuple] = None, sampler_kwargs: Optional[dict] = None, ) -> None: +<<<<<<< HEAD assert isinstance(datapipe, Sized), ( "Sampler class requires input datapipe implemented `__len__`" ) +======= + assert isinstance( + datapipe, Sized + ), "Sampler class requires input datapipe implemented `__len__`" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) super().__init__() self.datapipe = datapipe self.sampler_args = () if sampler_args is None else sampler_args self.sampler_kwargs = {} if sampler_kwargs is None else sampler_kwargs # https://github.com/python/mypy/pull/9629 will solve +<<<<<<< HEAD self.sampler = sampler( *self.sampler_args, data_source=self.datapipe, **self.sampler_kwargs ) # type: ignore[misc] +======= + self.sampler = sampler(*self.sampler_args, data_source=self.datapipe, **self.sampler_kwargs) # type: ignore[misc] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __iter__(self) -> Iterator[_T_co]: return iter(self.sampler) diff --git a/torch/utils/data/datapipes/iter/combining.py b/torch/utils/data/datapipes/iter/combining.py index 8c6abc5062105..ec97bd2e12d01 100644 --- a/torch/utils/data/datapipes/iter/combining.py +++ b/torch/utils/data/datapipes/iter/combining.py @@ -116,6 +116,7 @@ class _ContainerTemplate(ABC): r"""Abstract class for container ``DataPipes``. The followings are three required methods.""" @abstractmethod +<<<<<<< HEAD def get_next_element_by_instance(self, instance_id: int): ... @abstractmethod @@ -123,6 +124,18 @@ def is_every_instance_exhausted(self) -> bool: ... @abstractmethod def reset(self) -> None: ... +======= + def get_next_element_by_instance(self, instance_id: int): + ... + + @abstractmethod + def is_every_instance_exhausted(self) -> bool: + ... + + @abstractmethod + def reset(self) -> None: + ... +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) @abstractmethod def get_length_by_instance(self, instance_id: int): @@ -400,9 +413,13 @@ class DemultiplexerIterDataPipe(IterDataPipe): >>> # It can also filter out any element that gets `None` from the `classifier_fn` >>> def odd_or_even_no_zero(n): ... return n % 2 if n != 0 else None +<<<<<<< HEAD >>> dp1, dp2 = source_dp.demux( ... num_instances=2, classifier_fn=odd_or_even_no_zero, drop_none=True ... ) +======= + >>> dp1, dp2 = source_dp.demux(num_instances=2, classifier_fn=odd_or_even_no_zero, drop_none=True) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> list(dp1) [2, 4] >>> list(dp2) @@ -427,9 +444,13 @@ def __new__( # When num_instances == 1, demux can be replaced by filter, # but keep it as Demultiplexer for the sake of consistency # like throwing Error when classification result is out of o range +<<<<<<< HEAD container = _DemultiplexerIterDataPipe( datapipe, num_instances, classifier_fn, drop_none, buffer_size ) # type: ignore[abstract] +======= + container = _DemultiplexerIterDataPipe(datapipe, num_instances, classifier_fn, drop_none, buffer_size) # type: ignore[abstract] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return [_ChildDataPipe(container, i) for i in range(num_instances)] @@ -603,18 +624,28 @@ class MultiplexerIterDataPipe(IterDataPipe): Example: >>> # xdoctest: +REQUIRES(module:torchdata) >>> from torchdata.datapipes.iter import IterableWrapper +<<<<<<< HEAD >>> dp1, dp2, dp3 = ( ... IterableWrapper(range(3)), ... IterableWrapper(range(10, 15)), ... IterableWrapper(range(20, 25)), ... ) +======= + >>> dp1, dp2, dp3 = IterableWrapper(range(3)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> list(dp1.mux(dp2, dp3)) [0, 10, 20, 1, 11, 21, 2, 12, 22] """ def __init__(self, *datapipes): self.datapipes = datapipes +<<<<<<< HEAD self.buffer: list = [] # Store values to be yielded only when every iterator provides one +======= + self.buffer: list = ( + [] + ) # Store values to be yielded only when every iterator provides one +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __iter__(self): iterators = [iter(x) for x in self.datapipes] @@ -673,11 +704,15 @@ class ZipperIterDataPipe(IterDataPipe[tuple[_T_co]]): Example: >>> # xdoctest: +REQUIRES(module:torchdata) >>> from torchdata.datapipes.iter import IterableWrapper +<<<<<<< HEAD >>> dp1, dp2, dp3 = ( ... IterableWrapper(range(5)), ... IterableWrapper(range(10, 15)), ... IterableWrapper(range(20, 25)), ... ) +======= + >>> dp1, dp2, dp3 = IterableWrapper(range(5)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> list(dp1.zip(dp2, dp3)) [(0, 10, 20), (1, 11, 21), (2, 12, 22), (3, 13, 23), (4, 14, 24)] """ diff --git a/torch/utils/data/datapipes/iter/fileopener.py b/torch/utils/data/datapipes/iter/fileopener.py index 3025b809e12df..a737d548d1e05 100644 --- a/torch/utils/data/datapipes/iter/fileopener.py +++ b/torch/utils/data/datapipes/iter/fileopener.py @@ -33,12 +33,17 @@ class FileOpenerIterDataPipe(IterDataPipe[tuple[str, IOBase]]): Example: >>> # xdoctest: +SKIP +<<<<<<< HEAD >>> from torchdata.datapipes.iter import ( ... FileLister, ... FileOpener, ... StreamReader, ... ) >>> dp = FileLister(root=".").filter(lambda fname: fname.endswith(".txt")) +======= + >>> from torchdata.datapipes.iter import FileLister, FileOpener, StreamReader + >>> dp = FileLister(root=".").filter(lambda fname: fname.endswith('.txt')) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> dp = FileOpener(dp) >>> dp = StreamReader(dp) >>> list(dp) diff --git a/torch/utils/data/datapipes/iter/grouping.py b/torch/utils/data/datapipes/iter/grouping.py index 055d9c28b09be..5d5d8d57ea358 100644 --- a/torch/utils/data/datapipes/iter/grouping.py +++ b/torch/utils/data/datapipes/iter/grouping.py @@ -182,9 +182,13 @@ class GrouperIterDataPipe(IterDataPipe[DataChunk]): >>> from torchdata.datapipes.iter import IterableWrapper >>> def group_fn(file): ... return os.path.basename(file).split(".")[0] +<<<<<<< HEAD >>> source_dp = IterableWrapper( ... ["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"] ... ) +======= + >>> source_dp = IterableWrapper(["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> dp0 = source_dp.groupby(group_key_fn=group_fn) >>> list(dp0) [['a.png', 'a.json', 'a.jpg'], ['b.png', 'b.json'], ['c.json']] @@ -193,12 +197,16 @@ class GrouperIterDataPipe(IterDataPipe[DataChunk]): >>> list(dp1) [['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']] >>> # Scenario where `buffer` is full, and group 'a' needs to be yielded since its size > `guaranteed_group_size` +<<<<<<< HEAD >>> dp2 = source_dp.groupby( ... group_key_fn=group_fn, ... buffer_size=3, ... group_size=3, ... guaranteed_group_size=2, ... ) +======= + >>> dp2 = source_dp.groupby(group_key_fn=group_fn, buffer_size=3, group_size=3, guaranteed_group_size=2) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> list(dp2) [['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']] """ diff --git a/torch/utils/data/datapipes/iter/utils.py b/torch/utils/data/datapipes/iter/utils.py index f90b426be129a..1efa4b32ec026 100644 --- a/torch/utils/data/datapipes/iter/utils.py +++ b/torch/utils/data/datapipes/iter/utils.py @@ -1,17 +1,30 @@ +<<<<<<< HEAD import copy import warnings from collections.abc import Iterable, Iterator, Sized from typing import TypeVar +======= +# mypy: allow-untyped-defs +import copy +import warnings +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.utils.data.datapipes.datapipe import IterDataPipe +<<<<<<< HEAD _T = TypeVar("_T") __all__ = ["IterableWrapperIterDataPipe"] class IterableWrapperIterDataPipe(IterDataPipe[_T]): +======= +__all__ = ["IterableWrapperIterDataPipe"] + + +class IterableWrapperIterDataPipe(IterDataPipe): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r""" Wraps an iterable object to create an IterDataPipe. @@ -33,11 +46,19 @@ class IterableWrapperIterDataPipe(IterDataPipe[_T]): [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] """ +<<<<<<< HEAD def __init__(self, iterable: Iterable[_T], deepcopy: bool = True) -> None: self.iterable = iterable self.deepcopy = deepcopy def __iter__(self) -> Iterator[_T]: +======= + def __init__(self, iterable, deepcopy=True): + self.iterable = iterable + self.deepcopy = deepcopy + + def __iter__(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) source_data = self.iterable if self.deepcopy: try: @@ -53,7 +74,12 @@ def __iter__(self) -> Iterator[_T]: ) yield from source_data +<<<<<<< HEAD def __len__(self) -> int: if isinstance(self.iterable, Sized): return len(self.iterable) raise TypeError(f"{type(self).__name__} instance doesn't have valid length") +======= + def __len__(self): + return len(self.iterable) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/utils/data/datapipes/map/utils.py b/torch/utils/data/datapipes/map/utils.py index e1290df323724..2ce3390e1f037 100644 --- a/torch/utils/data/datapipes/map/utils.py +++ b/torch/utils/data/datapipes/map/utils.py @@ -1,17 +1,30 @@ +<<<<<<< HEAD import copy import warnings from collections.abc import Mapping, Sequence from typing import Any, TypeVar, Union +======= +# mypy: allow-untyped-defs +import copy +import warnings +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torch.utils.data.datapipes.datapipe import MapDataPipe +<<<<<<< HEAD _T = TypeVar("_T") __all__ = ["SequenceWrapperMapDataPipe"] class SequenceWrapperMapDataPipe(MapDataPipe[_T]): +======= +__all__ = ["SequenceWrapperMapDataPipe"] + + +class SequenceWrapperMapDataPipe(MapDataPipe): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) r""" Wraps a sequence object into a MapDataPipe. @@ -31,6 +44,7 @@ class SequenceWrapperMapDataPipe(MapDataPipe[_T]): >>> dp = SequenceWrapper(range(10)) >>> list(dp) [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] +<<<<<<< HEAD >>> dp = SequenceWrapper({"a": 100, "b": 200, "c": 300, "d": 400}) >>> dp["a"] 100 @@ -41,6 +55,14 @@ class SequenceWrapperMapDataPipe(MapDataPipe[_T]): def __init__( self, sequence: Union[Sequence[_T], Mapping[Any, _T]], deepcopy: bool = True ) -> None: +======= + >>> dp = SequenceWrapper({'a': 100, 'b': 200, 'c': 300, 'd': 400}) + >>> dp['a'] + 100 + """ + + def __init__(self, sequence, deepcopy=True): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if deepcopy: try: self.sequence = copy.deepcopy(sequence) @@ -53,8 +75,15 @@ def __init__( else: self.sequence = sequence +<<<<<<< HEAD def __getitem__(self, index: int) -> _T: return self.sequence[index] def __len__(self) -> int: +======= + def __getitem__(self, index): + return self.sequence[index] + + def __len__(self): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return len(self.sequence) diff --git a/torch/utils/data/datapipes/utils/decoder.py b/torch/utils/data/datapipes/utils/decoder.py index 9db7309bdc525..41ad58096025c 100644 --- a/torch/utils/data/datapipes/utils/decoder.py +++ b/torch/utils/data/datapipes/utils/decoder.py @@ -45,8 +45,13 @@ def basichandlers(extension: str, data): Example: >>> import pickle +<<<<<<< HEAD >>> data = pickle.dumps("some data") >>> new_data = basichandlers("pickle", data) +======= + >>> data = pickle.dumps('some data') + >>> new_data = basichandlers('pickle', data) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) >>> new_data some data @@ -169,9 +174,15 @@ class ImageHandler: """ def __init__(self, imagespec): +<<<<<<< HEAD assert imagespec in list(imagespecs.keys()), ( f"unknown image specification: {imagespec}" ) +======= + assert imagespec in list( + imagespecs.keys() + ), f"unknown image specification: {imagespec}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.imagespec = imagespec.lower() def __call__(self, extension, data): @@ -205,18 +216,30 @@ def __call__(self, extension, data): return img elif atype == "numpy": result = np.asarray(img) +<<<<<<< HEAD assert result.dtype == np.uint8, ( f"numpy image array should be type uint8, but got {result.dtype}" ) +======= + assert ( + result.dtype == np.uint8 + ), f"numpy image array should be type uint8, but got {result.dtype}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if etype == "uint8": return result else: return result.astype("f") / 255.0 elif atype == "torch": result = np.asarray(img) +<<<<<<< HEAD assert result.dtype == np.uint8, ( f"numpy image array should be type uint8, but got {result.dtype}" ) +======= + assert ( + result.dtype == np.uint8 + ), f"numpy image array should be type uint8, but got {result.dtype}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if etype == "uint8": result = np.array(result.transpose(2, 0, 1)) diff --git a/torch/utils/data/dataset.py b/torch/utils/data/dataset.py index e8164e015a668..125f343f427b6 100644 --- a/torch/utils/data/dataset.py +++ b/torch/utils/data/dataset.py @@ -96,7 +96,11 @@ class IterableDataset(Dataset[_T_co], Iterable[_T_co]): >>> class MyIterableDataset(torch.utils.data.IterableDataset): ... def __init__(self, start, end): ... super(MyIterableDataset).__init__() +<<<<<<< HEAD ... assert end > start, "this example only works with end >= start" +======= + ... assert end > start, "this example code only works with end >= start" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ... self.start = start ... self.end = end ... @@ -138,7 +142,11 @@ class IterableDataset(Dataset[_T_co], Iterable[_T_co]): >>> class MyIterableDataset(torch.utils.data.IterableDataset): ... def __init__(self, start, end): ... super(MyIterableDataset).__init__() +<<<<<<< HEAD ... assert end > start, "this example only works with end >= start" +======= + ... assert end > start, "this example code only works with end >= start" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ... self.start = start ... self.end = end ... @@ -198,9 +206,15 @@ class TensorDataset(Dataset[tuple[Tensor, ...]]): tensors: tuple[Tensor, ...] def __init__(self, *tensors: Tensor) -> None: +<<<<<<< HEAD assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors), ( "Size mismatch between tensors" ) +======= + assert all( + tensors[0].size(0) == tensor.size(0) for tensor in tensors + ), "Size mismatch between tensors" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.tensors = tensors def __getitem__(self, index): @@ -222,7 +236,11 @@ class StackDataset(Dataset[_T_stack]): >>> tuple_stack = StackDataset(images, texts) >>> tuple_stack[0] == (images[0], texts[0]) >>> dict_stack = StackDataset(image=images, text=texts) +<<<<<<< HEAD >>> dict_stack[0] == {"image": images[0], "text": texts[0]} +======= + >>> dict_stack[0] == {'image': images[0], 'text': texts[0]} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Args: *args (Dataset): Datasets for stacking returned as tuple. @@ -323,9 +341,15 @@ def __init__(self, datasets: Iterable[Dataset]) -> None: self.datasets = list(datasets) assert len(self.datasets) > 0, "datasets should not be an empty iterable" # type: ignore[arg-type] for d in self.datasets: +<<<<<<< HEAD assert not isinstance(d, IterableDataset), ( "ConcatDataset does not support IterableDataset" ) +======= + assert not isinstance( + d, IterableDataset + ), "ConcatDataset does not support IterableDataset" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) self.cumulative_sizes = self.cumsum(self.datasets) def __len__(self): @@ -371,17 +395,29 @@ def __init__(self, datasets: Iterable[Dataset]) -> None: def __iter__(self): for d in self.datasets: +<<<<<<< HEAD assert isinstance(d, IterableDataset), ( "ChainDataset only supports IterableDataset" ) +======= + assert isinstance( + d, IterableDataset + ), "ChainDataset only supports IterableDataset" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) yield from d def __len__(self): total = 0 for d in self.datasets: +<<<<<<< HEAD assert isinstance(d, IterableDataset), ( "ChainDataset only supports IterableDataset" ) +======= + assert isinstance( + d, IterableDataset + ), "ChainDataset only supports IterableDataset" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) total += len(d) # type: ignore[arg-type] return total diff --git a/torch/utils/data/sampler.py b/torch/utils/data/sampler.py index 6c2e6dcaf2f45..2443dccaee126 100644 --- a/torch/utils/data/sampler.py +++ b/torch/utils/data/sampler.py @@ -6,12 +6,15 @@ import torch +<<<<<<< HEAD # Note: For benchmarking changes to samplers, see: # /benchmarks/data/samplers_bench.py # This benchmark compares the performance of different sampler implementations # and can be used to evaluate the impact of optimizations. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __all__ = [ "BatchSampler", "RandomSampler", @@ -236,6 +239,7 @@ class WeightedRandomSampler(Sampler[int]): Example: >>> # xdoctest: +IGNORE_WANT("non-deterministic") +<<<<<<< HEAD >>> list( ... WeightedRandomSampler( ... [0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True @@ -247,6 +251,11 @@ class WeightedRandomSampler(Sampler[int]): ... [0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False ... ) ... ) +======= + >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True)) + [4, 4, 1, 4, 5] + >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) [0, 1, 4, 3, 2] """ @@ -306,6 +315,7 @@ class BatchSampler(Sampler[list[int]]): its size would be less than ``batch_size`` Example: +<<<<<<< HEAD >>> list( ... BatchSampler( ... SequentialSampler(range(10)), batch_size=3, drop_last=False @@ -315,6 +325,11 @@ class BatchSampler(Sampler[list[int]]): >>> list( ... BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True) ... ) +======= + >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)) + [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] + >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) [[0, 1, 2], [3, 4, 5], [6, 7, 8]] """ @@ -344,6 +359,10 @@ def __init__( self.drop_last = drop_last def __iter__(self) -> Iterator[list[int]]: +<<<<<<< HEAD +======= + # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) sampler_iter = iter(self.sampler) if self.drop_last: # Create multiple references to the same iterator diff --git a/torch/utils/data/standard_pipes.ipynb b/torch/utils/data/standard_pipes.ipynb index c40058bca7699..8df58c14c4a03 100644 --- a/torch/utils/data/standard_pipes.ipynb +++ b/torch/utils/data/standard_pipes.ipynb @@ -753,7 +753,11 @@ "\n", "Arguments:\n", " - `group_key_fn`\n", +<<<<<<< HEAD " - `group_size` - yield resulted group as soon as `group_size` elements accumulated\n", +======= + " - `group_size` - yeild resulted group as soon as `group_size` elements accumulated\n", +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) " - `guaranteed_group_size:int = None`\n", " - `unbatch_level:int = 0` if specified calls `unbatch(unbatch_level=unbatch_level)` on source datapipe before batching (see `unbatch`)\n", "\n", @@ -962,7 +966,11 @@ "cell_type": "markdown", "metadata": {}, "source": [ +<<<<<<< HEAD "This behaviour becomes noticeable when data is bigger than buffer and some groups getting evicted before gathering all potential items" +======= + "This behaviour becomes noticable when data is bigger than buffer and some groups getting evicted before gathering all potential items" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] }, { diff --git a/torch/utils/data/typing.ipynb b/torch/utils/data/typing.ipynb index 1b1aa8c9da72f..3d7725e96b1ee 100644 --- a/torch/utils/data/typing.ipynb +++ b/torch/utils/data/typing.ipynb @@ -399,7 +399,11 @@ "\n", "Note: This decorator is only allowed to be attached to `__iter__` for now. It can be extended into `__getitem__` and further `nonblocking` functions.\n", "\n", +<<<<<<< HEAD "`runtime_validation_disabled` is a context manager to turn off the type validation during runtime. It's useful for DataLoader to disable the runtime validation after the first epoch is finished for better performance. Note: the runtime validation is enabled by default." +======= + "`runtime_validation_disabled` is a context manager to turn off the type validaiton during runtime. It's useful for DataLoader to disable the runtime validaiton after the first epoch is finished for better performance. Note: the runtime validation is enabled by default." +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] }, { @@ -684,7 +688,11 @@ "cell_type": "markdown", "metadata": {}, "source": [ +<<<<<<< HEAD "- Compatible with context manager to disable validation" +======= + "- Compatible with context mangager to disable validation" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] }, { diff --git a/torch/utils/dlpack.py b/torch/utils/dlpack.py index e7aeae1ba3c81..7ed07f3f13d71 100644 --- a/torch/utils/dlpack.py +++ b/torch/utils/dlpack.py @@ -1,14 +1,24 @@ +<<<<<<< HEAD from typing import Any, Optional +======= +from typing import Any +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import torch import enum +<<<<<<< HEAD from torch._C import _to_dlpack as to_dlpack from torch.types import Device as _Device +======= +from torch._C import _from_dlpack +from torch._C import _to_dlpack as to_dlpack +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) __all__ = [ "DLDeviceType", "from_dlpack", +<<<<<<< HEAD ] class DLDeviceType(enum.IntEnum): @@ -16,11 +26,23 @@ class DLDeviceType(enum.IntEnum): kDLCPU = 1, kDLCUDA = 2, kDLCUDAHost = 3, +======= + "to_dlpack", +] + + +class DLDeviceType(enum.IntEnum): + # Enums as in DLPack specification (aten/src/ATen/dlpack.h) + kDLCPU = 1, + kDLGPU = 2, + kDLCPUPinned = 3, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) kDLOpenCL = 4, kDLVulkan = 7, kDLMetal = 8, kDLVPI = 9, kDLROCM = 10, +<<<<<<< HEAD kDLROCMHost = 11, kDLExtDev = 12, kDLCUDAManaged = 13, @@ -28,6 +50,10 @@ class DLDeviceType(enum.IntEnum): kDLWebGPU = 15, kDLHexagon = 16, kDLMAIA = 17, +======= + kDLExtDev = 12, + kDLOneAPI = 14, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch._C._add_docstr(to_dlpack, r"""to_dlpack(tensor) -> PyCapsule @@ -55,12 +81,16 @@ class DLDeviceType(enum.IntEnum): # TODO: add a typing.Protocol to be able to tell Mypy that only objects with # __dlpack__ and __dlpack_device__ methods are accepted. +<<<<<<< HEAD def from_dlpack( ext_tensor: Any, *, device: Optional[_Device] = None, copy: Optional[bool] = None ) -> 'torch.Tensor': +======= +def from_dlpack(ext_tensor: Any) -> 'torch.Tensor': +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """from_dlpack(ext_tensor) -> Tensor Converts a tensor from an external library into a ``torch.Tensor``. @@ -82,6 +112,7 @@ def from_dlpack( an opaque ``PyCapsule`` instance, typically produced by a ``to_dlpack`` function or method. +<<<<<<< HEAD device (torch.device or str or None): An optional PyTorch device specifying where to place the new tensor. If None (default), the new tensor will be on the same device as ``ext_tensor``. @@ -89,6 +120,8 @@ def from_dlpack( copy (bool or None): An optional boolean indicating whether or not to copy ``self``. If None, PyTorch will copy only if necessary. +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Examples:: >>> import torch.utils.dlpack @@ -119,6 +152,7 @@ def from_dlpack( """ if hasattr(ext_tensor, '__dlpack__'): +<<<<<<< HEAD # Only populate kwargs if any of the optional arguments are, in fact, not None. Otherwise, # leave them out, since we might end up falling back to no-extra-kwargs __dlpack__ call. kwargs: dict[str, Any] = {} @@ -143,11 +177,19 @@ def from_dlpack( # stream if ext_device[0] in (DLDeviceType.kDLCUDA, DLDeviceType.kDLROCM): stream = torch.cuda.current_stream(f'cuda:{ext_device[1]}') +======= + device = ext_tensor.__dlpack_device__() + # device is either CUDA or ROCm, we need to pass the current + # stream + if device[0] in (DLDeviceType.kDLGPU, DLDeviceType.kDLROCM): + stream = torch.cuda.current_stream(f'cuda:{device[1]}') +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # cuda_stream is the pointer to the stream and it is a public # attribute, but it is not documented # The array API specify that the default legacy stream must be passed # with a value of 1 for CUDA # https://data-apis.org/array-api/latest/API_specification/array_object.html?dlpack-self-stream-none#dlpack-self-stream-none +<<<<<<< HEAD is_cuda = ext_device[0] == DLDeviceType.kDLCUDA # Since pytorch is not using PTDS by default, lets directly pass # the legacy stream @@ -170,3 +212,16 @@ def from_dlpack( # Old versions just call the converter dlpack = ext_tensor return torch._C._from_dlpack(dlpack) +======= + is_cuda = device[0] == DLDeviceType.kDLGPU + # Since pytorch is not using PTDS by default, lets directly pass + # the legacy stream + stream_ptr = 1 if is_cuda and stream.cuda_stream == 0 else stream.cuda_stream + dlpack = ext_tensor.__dlpack__(stream=stream_ptr) + else: + dlpack = ext_tensor.__dlpack__() + else: + # Old versions just call the converter + dlpack = ext_tensor + return _from_dlpack(dlpack) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torch/utils/flop_counter.py b/torch/utils/flop_counter.py index b8d4e878b7f08..ee3b2b81ff267 100644 --- a/torch/utils/flop_counter.py +++ b/torch/utils/flop_counter.py @@ -127,6 +127,10 @@ def conv_flop_count( Returns: int: the number of flops """ +<<<<<<< HEAD +======= + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) batch_size = x_shape[0] conv_shape = (x_shape if transposed else out_shape)[2:] c_out, c_in, *filter_size = w_shape @@ -145,7 +149,11 @@ def conv_flop_count( flop = prod(conv_shape) * prod(filter_size) * batch_size * c_out * c_in * 2 return flop +<<<<<<< HEAD @register_flop_formula([aten.convolution, aten._convolution, aten.cudnn_convolution, aten._slow_conv2d_forward]) +======= +@register_flop_formula([aten.convolution, aten._convolution]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def conv_flop(x_shape, w_shape, _bias, _stride, _padding, _dilation, transposed, *args, out_shape=None, **kwargs) -> int: """Count flops for convolution.""" return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed) @@ -560,8 +568,11 @@ def _efficient_attention_backward_flop( aten._scaled_mm: _scaled_mm_flop, aten.convolution: conv_flop, aten._convolution: conv_flop, +<<<<<<< HEAD aten.cudnn_convolution: conv_flop, aten._slow_conv2d_forward: conv_flop, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) aten.convolution_backward: conv_backward_flop, aten._scaled_dot_product_efficient_attention: sdpa_flop, aten._scaled_dot_product_flash_attention: sdpa_flop, @@ -755,6 +766,7 @@ def _count_flops(self, func_packet, out, args, kwargs): return out +<<<<<<< HEAD class _FlopCounterMode(TorchDispatchMode): supports_higher_order_operators = True @@ -830,12 +842,23 @@ def _handle_higher_order_ops(self, func, types, args, kwargs): # output with the same structure. return true_out +======= + +class _FlopCounterMode(TorchDispatchMode): + def __init__(self, counter: FlopCounterMode): + self.counter = counter + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def __torch_dispatch__(self, func, types, args=(), kwargs=None): kwargs = kwargs if kwargs else {} # Skip ops from non-standard dispatch_sizes_strides_policy such as NJT +<<<<<<< HEAD if func in {torch.ops.aten.sym_is_contiguous.default, torch.ops.aten.is_contiguous.default, +======= + if func in {torch.ops.aten.is_contiguous.default, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.ops.aten.is_contiguous.memory_format, torch.ops.aten.is_strides_like_format.default, torch.ops.aten.is_non_overlapping_and_dense.default, @@ -852,9 +875,12 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): return NotImplemented +<<<<<<< HEAD if isinstance(func, torch._ops.HigherOrderOperator): return self._handle_higher_order_ops(func, types, args, kwargs) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # If we don't have func in flop_registry, see if it can decompose if func not in self.counter.flop_registry and func is not torch.ops.prim.device.default: with self: diff --git a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py index 12291db1704c2..4ddcd3ceba16d 100644 --- a/torch/utils/hipify/cuda_to_hip_mappings.py +++ b/torch/utils/hipify/cuda_to_hip_mappings.py @@ -179,9 +179,12 @@ ), ("CUlimit", ("hipLimit_t", CONV_TYPE, API_DRIVER)), ("CUlimit_enum", ("hipLimit_t", CONV_TYPE, API_DRIVER)), +<<<<<<< HEAD ("CUmemAccessDesc", ("hipMemAccessDesc", CONV_TYPE, API_DRIVER)), ("CUmemAccessDesc_st", ("hipMemAccessDesc", CONV_TYPE, API_DRIVER)), ("CUmemAccessDesc_v1", ("hipMemAccessDesc", CONV_TYPE, API_DRIVER)), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ( "CUmemAttach_flags", ("hipMemAttachFlags_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), @@ -190,6 +193,7 @@ "CUmemAttach_flags_enum", ("hipMemAttachFlags_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), ), +<<<<<<< HEAD ("CUmemAllocationGranularity_flags", ("hipMemAllocationGranularity_flags", CONV_TYPE, API_DRIVER)), ("CUmemAllocationGranularity_flags_enum", ("hipMemAllocationGranularity_flags", CONV_TYPE, API_DRIVER)), ("CUmemAllocationHandleType", ("hipMemAllocationHandleType", CONV_TYPE, API_DRIVER)), @@ -222,6 +226,8 @@ ("CUmem_advise_enum", ("hipMemoryAdvise", CONV_TYPE, API_DRIVER)), ("CUmem_range_attribute_enum", ("hipMemRangeAttribute", CONV_TYPE, API_DRIVER)), ("CUmemoryPool", ("hipMemPool_t", CONV_TYPE, API_DRIVER)), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ("CUmemorytype", ("hipMemType_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), ("CUmemorytype_enum", ("hipMemType_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), ("CUresourcetype", ("hipResourceType", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED)), @@ -578,7 +584,10 @@ ("curandState", ("hiprandState_t", CONV_TYPE, API_RAND)), ("CUuuid", ("hipUUID", CONV_TYPE, API_RUNTIME)), ("cudaGraph_t", ("hipGraph_t", CONV_TYPE, API_RAND)), +<<<<<<< HEAD ("cudaGraphNode_t", ("hipGraphNode_t", CONV_TYPE, API_RAND)), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ("cudaGraphExec_t", ("hipGraphExec_t", CONV_TYPE, API_RAND)), ("__nv_bfloat16", ("__hip_bfloat16", CONV_TYPE, API_RUNTIME)), ("__nv_bfloat162", ("__hip_bfloat162", CONV_TYPE, API_RUNTIME)), @@ -607,12 +616,18 @@ "channel_descriptor.h", ("hip/channel_descriptor.h", CONV_INCLUDE, API_RUNTIME), ), +<<<<<<< HEAD ('include "device_functions.h', ('include "hip/device_functions.h', CONV_INCLUDE, API_RUNTIME)), ('include >>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ("cuComplex.h", ("hip/hip_complex.h", CONV_INCLUDE, API_RUNTIME)), ("cuda_fp16.h", ("hip/hip_fp16.h", CONV_INCLUDE, API_RUNTIME)), ("cuda_bf16.h", ("hip/hip_bf16.h", CONV_INCLUDE, API_RUNTIME)), @@ -669,7 +684,10 @@ ("cub/device/device_scan.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), ("cub/device/device_select.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), ("nvtx3/nvtx3.hpp", ("roctracer/roctx.h", CONV_INCLUDE, API_ROCTX)), +<<<<<<< HEAD ("nvToolsExt.h", ("roctracer/roctx.h", CONV_INCLUDE, API_ROCTX)), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ("nvml.h", ("rocm_smi/rocm_smi.h", CONV_INCLUDE, API_ROCMSMI)), ] ) @@ -2589,6 +2607,7 @@ "CU_MEMORYTYPE_UNIFIED", ("hipMemTypeUnified", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), ), +<<<<<<< HEAD ("CU_MEMHOSTREGISTER_READ_ONLY", ("hipHostRegisterReadOnly", CONV_TYPE, API_DRIVER)), ("CU_MEMPOOL_ATTR_RELEASE_THRESHOLD", ("hipMemPoolAttrReleaseThreshold", CONV_TYPE, API_DRIVER)), ("CU_MEMPOOL_ATTR_RESERVED_MEM_CURRENT", ("hipMemPoolAttrReservedMemCurrent", CONV_TYPE, API_DRIVER)), @@ -2621,6 +2640,8 @@ ("CU_MEM_LOCATION_TYPE_INVALID", ("hipMemLocationTypeInvalid", CONV_TYPE, API_DRIVER)), ("CU_MEM_OPERATION_TYPE_MAP", ("hipMemOperationTypeMap", CONV_TYPE, API_DRIVER)), ("CU_MEM_OPERATION_TYPE_UNMAP", ("hipMemOperationTypeUnmap", CONV_TYPE, API_DRIVER)), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ( "CU_RESOURCE_TYPE_ARRAY", ("hipResourceTypeArray", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), @@ -3251,6 +3272,7 @@ "cuMemGetAddressRange_v2", ("hipMemGetAddressRange", CONV_MEM, API_DRIVER), ), +<<<<<<< HEAD ("cuArray3DCreate_v2", ("hipArray3DCreate", CONV_MEM, API_DRIVER)), ("cuArray3DGetDescriptor_v2", ("hipArray3DGetDescriptor", CONV_MEM, API_DRIVER)), ("cuArrayGetDescriptor_v2", ("hipArrayGetDescriptor", CONV_MEM, API_DRIVER)), @@ -3308,6 +3330,8 @@ ("cuMemPoolSetAccess", ("hipMemPoolSetAccess", CONV_MEM, API_DRIVER)), ("cuMemPoolSetAttribute", ("hipMemPoolSetAttribute", CONV_MEM, API_DRIVER)), ("cuMemPoolTrimTo", ("hipMemPoolTrimTo", CONV_MEM, API_DRIVER)), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ( "cuPointerGetAttributes", ("hipPointerGetAttributes", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), @@ -4196,6 +4220,7 @@ ("cudaMemPoolAttrUsedMemCurrent", ("hipMemPoolAttrUsedMemCurrent", CONV_MEM, API_RUNTIME)), ("cudaMemPoolAttrUsedMemHigh", ("hipMemPoolAttrUsedMemHigh", CONV_MEM, API_RUNTIME)), ("cudaMemPoolGetAttribute", ("hipMemPoolGetAttribute", CONV_MEM, API_RUNTIME)), +<<<<<<< HEAD ( "cudaMemPoolReuseAllowInternalDependencies", ("hipMemPoolReuseAllowInternalDependencies", CONV_MEM, API_RUNTIME) @@ -4205,6 +4230,11 @@ "cudaMemPoolReuseFollowEventDependencies", ("hipMemPoolReuseFollowEventDependencies", CONV_MEM, API_RUNTIME) ), +======= + ("cudaMemPoolReuseAllowInternalDependencies", ("hipMemPoolReuseAllowInternalDependencies", CONV_MEM, API_RUNTIME)), + ("cudaMemPoolReuseAllowOpportunistic", ("hipMemPoolReuseAllowOpportunistic", CONV_MEM, API_RUNTIME)), + ("cudaMemPoolReuseFollowEventDependencies", ("hipMemPoolReuseFollowEventDependencies", CONV_MEM, API_RUNTIME)), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ("cudaMemPoolSetAccess", ("hipMemPoolSetAccess", CONV_MEM, API_RUNTIME)), ("cudaMemPoolSetAttribute", ("hipMemPoolSetAttribute", CONV_MEM, API_RUNTIME)), ("cudaMemPoolTrimTo", ("hipMemPoolTrimTo", CONV_MEM, API_RUNTIME)), @@ -4235,7 +4265,10 @@ ), ("cudaMallocHost", ("hipHostMalloc", CONV_MEM, API_RUNTIME)), ("cudaMallocArray", ("hipMallocArray", CONV_MEM, API_RUNTIME)), +<<<<<<< HEAD ("cudaMallocAsync", ("hipMallocAsync", CONV_MEM, API_RUNTIME)), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ("cudaMalloc", ("hipMalloc", CONV_MEM, API_RUNTIME)), ("cudaMalloc3D", ("hipMalloc3D", CONV_MEM, API_RUNTIME)), ("cudaMalloc3DArray", ("hipMalloc3DArray", CONV_MEM, API_RUNTIME)), @@ -4250,22 +4283,31 @@ ("cudaMallocPitch", ("hipMallocPitch", CONV_MEM, API_RUNTIME)), ("cudaFreeHost", ("hipHostFree", CONV_MEM, API_RUNTIME)), ("cudaFreeArray", ("hipFreeArray", CONV_MEM, API_RUNTIME)), +<<<<<<< HEAD ("cudaFreeAsync", ("hipFreeAsync", CONV_MEM, API_RUNTIME)), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ("cudaFree", ("hipFree", CONV_MEM, API_RUNTIME)), ("cudaHostRegister", ("hipHostRegister", CONV_MEM, API_RUNTIME)), ("cudaHostUnregister", ("hipHostUnregister", CONV_MEM, API_RUNTIME)), ("cudaHostAlloc", ("hipHostMalloc", CONV_MEM, API_RUNTIME)), ("cudaMemoryTypeHost", ("hipMemoryTypeHost", CONV_MEM, API_RUNTIME)), ("cudaMemoryTypeDevice", ("hipMemoryTypeDevice", CONV_MEM, API_RUNTIME)), +<<<<<<< HEAD ("cudaMemoryTypeUnregistered", ("hipMemoryTypeUnregistered", CONV_MEM, API_RUNTIME)), ("cudaMemoryTypeManaged", ("hipMemoryTypeManaged", CONV_MEM, API_RUNTIME)), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ("make_cudaExtent", ("make_hipExtent", CONV_MEM, API_RUNTIME)), ("make_cudaPitchedPtr", ("make_hipPitchedPtr", CONV_MEM, API_RUNTIME)), ("make_cudaPos", ("make_hipPos", CONV_MEM, API_RUNTIME)), ("cudaHostAllocDefault", ("hipHostMallocDefault", CONV_MEM, API_RUNTIME)), ("cudaHostAllocPortable", ("hipHostMallocPortable", CONV_MEM, API_RUNTIME)), ("cudaHostAllocMapped", ("hipHostMallocMapped", CONV_MEM, API_RUNTIME)), +<<<<<<< HEAD ("cudaHostNodeParams", ("hipHostNodeParams", CONV_MEM, API_RUNTIME)), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ( "cudaHostAllocWriteCombined", ("hipHostMallocWriteCombined", CONV_MEM, API_RUNTIME), @@ -4326,13 +4368,17 @@ ("cudaStreamGetCaptureInfo_v2", ("hipStreamGetCaptureInfo_v2", CONV_TYPE, API_RUNTIME)), ("cudaStreamCaptureStatus", ("hipStreamCaptureStatus", CONV_TYPE, API_RUNTIME)), ("cudaStreamCaptureStatusActive", ("hipStreamCaptureStatusActive", CONV_TYPE, API_RUNTIME)), +<<<<<<< HEAD ("cudaStreamCaptureStatusNone", ("hipStreamCaptureStatusNone", CONV_TYPE, API_RUNTIME)), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ("cudaStreamCaptureMode", ("hipStreamCaptureMode", CONV_TYPE, API_RUNTIME)), ("cudaStreamCaptureModeGlobal", ("hipStreamCaptureModeGlobal", CONV_TYPE, API_RUNTIME)), ("cudaStreamCaptureModeRelaxed", ("hipStreamCaptureModeRelaxed", CONV_TYPE, API_RUNTIME)), ("cudaStreamCaptureModeThreadLocal", ("hipStreamCaptureModeThreadLocal", CONV_TYPE, API_RUNTIME)), ("cudaStreamBeginCapture", ("hipStreamBeginCapture", CONV_TYPE, API_RUNTIME)), ("cudaStreamEndCapture", ("hipStreamEndCapture", CONV_TYPE, API_RUNTIME)), +<<<<<<< HEAD ("cudaStreamSetCaptureDependencies", ("hipStreamSetCaptureDependencies", CONV_STREAM, API_RUNTIME)), ("cudaStreamUpdateCaptureDependencies", ("hipStreamUpdateCaptureDependencies", CONV_STREAM, API_RUNTIME)), ("cudaGraphInstantiate", ("hipGraphInstantiate", CONV_TYPE, API_RUNTIME)), @@ -4341,6 +4387,11 @@ "cudaGraphInstantiateFlagAutoFreeOnLaunch", ("hipGraphInstantiateFlagAutoFreeOnLaunch", CONV_TYPE, API_RUNTIME) ), +======= + ("cudaGraphInstantiate", ("hipGraphInstantiate", CONV_TYPE, API_RUNTIME)), + ("cudaGraphInstantiateWithFlags", ("hipGraphInstantiateWithFlags", CONV_TYPE, API_RUNTIME)), + ("cudaGraphInstantiateFlagAutoFreeOnLaunch", ("hipGraphInstantiateFlagAutoFreeOnLaunch", CONV_TYPE, API_RUNTIME)), +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ("cudaGraphDestroy", ("hipGraphDestroy", CONV_TYPE, API_RUNTIME)), ("cudaGraphExecDestroy", ("hipGraphExecDestroy", CONV_TYPE, API_RUNTIME)), ("cudaGraphLaunch", ("hipGraphLaunch", CONV_TYPE, API_RUNTIME)), @@ -4349,6 +4400,7 @@ ("cudaGraphDebugDotFlagsVerbose", ("hipGraphDebugDotFlagsVerbose", CONV_NUMERIC_LITERAL, API_RUNTIME)), ("cudaGraphRetainUserObject", ("hipGraphRetainUserObject", CONV_TYPE, API_RUNTIME)), ("cudaGraphUserObjectMove", ("hipGraphUserObjectMove", CONV_TYPE, API_RUNTIME)), +<<<<<<< HEAD ("cudaDeviceGetGraphMemAttribute", ("hipDeviceGetGraphMemAttribute", CONV_TYPE, API_RUNTIME)), ("cudaDeviceGraphMemTrim", ("hipDeviceGraphMemTrim", CONV_TYPE, API_RUNTIME)), ("cudaDeviceSetGraphMemAttribute", ("hipDeviceSetGraphMemAttribute", CONV_TYPE, API_RUNTIME)), @@ -4552,6 +4604,8 @@ ("cudaGraphNodeTypeMemcpy", ("hipGraphNodeTypeMemcpy", CONV_TYPE, API_RUNTIME)), ("cudaGraphNodeTypeMemset", ("hipGraphNodeTypeMemset", CONV_TYPE, API_RUNTIME)), ("cudaGraphNodeTypeWaitEvent", ("hipGraphNodeTypeWaitEvent", CONV_TYPE, API_RUNTIME)), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ("cudaUserObject_t", ("hipUserObject_t", CONV_TYPE, API_RUNTIME)), ("cudaUserObjectCreate", ("hipUserObjectCreate", CONV_TYPE, API_RUNTIME)), ("cudaUserObjectNoDestructorSync", ("hipUserObjectNoDestructorSync", CONV_TYPE, API_RUNTIME)), @@ -9385,7 +9439,11 @@ ] ) +<<<<<<< HEAD # We must treat very carefully here. Blanket conversions like are done +======= +# We must tread very carefully here. Blanket conversions like are done +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # in CAFFE2_SPECIFIC_MAPPINGS are not presently supported on PyTorch, # because a regex for CUDA will also match a filename like CUDAGuard.h, # but the HIPIFY script doesn't presently move the file and so the substitution @@ -9455,7 +9513,10 @@ ("C10_CUDA_KERNEL_LAUNCH_CHECK", ("C10_HIP_KERNEL_LAUNCH_CHECK", API_C10)), ("CUDAKernelLaunchRegistry", ("HIPKernelLaunchRegistry", API_C10)), ("c10::cuda::get_cuda_check_suffix", ("c10::hip::get_hip_check_suffix", API_C10)), +<<<<<<< HEAD ("c10::cuda::get_cuda_error_help", ("c10::hip::get_hip_error_help", API_C10)), +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] ) diff --git a/torch/utils/hipify/hipify_python.py b/torch/utils/hipify/hipify_python.py index 0e816020635be..6be5baa5edc67 100755 --- a/torch/utils/hipify/hipify_python.py +++ b/torch/utils/hipify/hipify_python.py @@ -465,7 +465,11 @@ def find_closure_group(input_string, start, group): def find_bracket_group(input_string, start): +<<<<<<< HEAD """Finds the first balanced parentheses.""" +======= + """Finds the first balanced parantheses.""" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return find_closure_group(input_string, start, group=["{", "}"]) diff --git a/torch/utils/model_dump/__init__.py b/torch/utils/model_dump/__init__.py index 7d6a6890e4cea..ebdc698baddf9 100644 --- a/torch/utils/model_dump/__init__.py +++ b/torch/utils/model_dump/__init__.py @@ -213,6 +213,7 @@ def get_model_info( path_prefix = prefix elif prefix != path_prefix: raise Exception(f"Mismatched prefixes: {path_prefix} != {prefix}") # noqa: TRY002 +<<<<<<< HEAD zip_files.append( { "filename": zi.filename, @@ -221,6 +222,15 @@ def get_model_info( "file_size": zi.file_size, } ) +======= + zip_files.append(dict( + filename=zi.filename, + compression=zi.compress_type, + compressed_size=zi.compress_size, + file_size=zi.file_size, + )) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) assert path_prefix is not None version = zf.read(path_prefix + "/version").decode("utf-8").strip() @@ -233,14 +243,24 @@ def get_pickle(name): model_data = get_pickle("data") constants = get_pickle("constants") +<<<<<<< HEAD # Intern strings that are likely to be reused. # Pickle automatically detects shared structure, # so reused strings are stored efficiently. +======= + # Intern strings that are likely to be re-used. + # Pickle automatically detects shared structure, + # so re-used strings are stored efficiently. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # However, JSON has no way of representing this, # so we have to do it manually. interned_strings : dict[str, int] = {} +<<<<<<< HEAD def intern(s): +======= + def ist(s): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if s not in interned_strings: interned_strings[s] = len(interned_strings) return interned_strings[s] @@ -294,7 +314,11 @@ def parse_new_format(line): s_start = 0 s_end = 0 text = raw_code[start:end] +<<<<<<< HEAD code_parts.append([text.decode("utf-8"), intern(s_file), s_line, intern(s_text), s_start, s_end]) +======= + code_parts.append([text.decode("utf-8"), ist(s_file), s_line, ist(s_text), s_start, s_end]) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) code_files[zi.filename] = code_parts extra_files_json_pattern = re.compile(re.escape(path_prefix) + "/extra/.*\\.json") @@ -333,6 +357,7 @@ def parse_new_format(line): continue extra_pickles[zi.filename] = contents +<<<<<<< HEAD return { "model": { "title": title, @@ -347,6 +372,20 @@ def parse_new_format(line): "extra_pickles": extra_pickles, } } +======= + return {"model": dict( + title=title, + file_size=file_size, + version=version, + zip_files=zip_files, + interned_strings=list(interned_strings), + code_files=code_files, + model_data=model_data, + constants=constants, + extra_files_jsons=extra_files_jsons, + extra_pickles=extra_pickles, + )} +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_inline_skeleton(): diff --git a/torch/utils/module_tracker.py b/torch/utils/module_tracker.py index 4c7dec0481522..1e088abe7725e 100644 --- a/torch/utils/module_tracker.py +++ b/torch/utils/module_tracker.py @@ -49,7 +49,10 @@ class ModuleTracker: def my_linear(m1, m2, bias): print(f"Current modules: {tracker.parents}") return torch.mm(m1, m2.t()) + bias +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) torch.nn.functional.linear = my_linear mod(torch.rand(2, 2)) diff --git a/torch/utils/tensorboard/_proto_graph.py b/torch/utils/tensorboard/_proto_graph.py index c4e234dff6ba0..6864abf959072 100644 --- a/torch/utils/tensorboard/_proto_graph.py +++ b/torch/utils/tensorboard/_proto_graph.py @@ -1,13 +1,22 @@ +<<<<<<< HEAD import torch from typing import Optional, Union from collections.abc import Sequence +======= +# mypy: allow-untyped-defs +from typing import Optional +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from tensorboard.compat.proto.node_def_pb2 import NodeDef from tensorboard.compat.proto.attr_value_pb2 import AttrValue from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto +<<<<<<< HEAD def attr_value_proto(dtype: object, shape: Optional[Sequence[int]], s: Optional[str]) -> dict[str, AttrValue]: +======= +def attr_value_proto(dtype, shape, s): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Create a dict of objects matching a NodeDef's attr field. Follows https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/attr_value.proto @@ -23,7 +32,11 @@ def attr_value_proto(dtype: object, shape: Optional[Sequence[int]], s: Optional[ return attr +<<<<<<< HEAD def tensor_shape_proto(outputsize: Sequence[int]) -> TensorShapeProto: +======= +def tensor_shape_proto(outputsize): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Create an object matching a tensor_shape field. Follows https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/tensor_shape.proto . @@ -32,6 +45,7 @@ def tensor_shape_proto(outputsize: Sequence[int]) -> TensorShapeProto: def node_proto( +<<<<<<< HEAD name: str, op: str = "UnSpecified", input: Optional[Union[list[str], str]] = None, @@ -40,6 +54,16 @@ def node_proto( outputsize: Optional[Sequence[int]] = None, attributes: str = "", ) -> NodeDef: +======= + name, + op="UnSpecified", + input=None, + dtype=None, + shape: Optional[tuple] = None, + outputsize=None, + attributes="", +): +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) """Create an object matching a NodeDef. Follows https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/node_def.proto . diff --git a/torch/utils/tensorboard/_utils.py b/torch/utils/tensorboard/_utils.py index f0ad185d968f5..78339fb29c9b1 100644 --- a/torch/utils/tensorboard/_utils.py +++ b/torch/utils/tensorboard/_utils.py @@ -45,7 +45,11 @@ def _prepare_video(V): Convesrion is done from [batchsize, time(frame), channel(color), height, width] (5D tensor) to [time(frame), new_width, new_height, channel] (4D tensor). +<<<<<<< HEAD A batch of images are spread to a grid, which forms a frame. +======= + A batch of images are spreaded to a grid, which forms a frame. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) e.g. Video with batchsize 16 will have a 4x4 grid. """ b, t, c, h, w = V.shape diff --git a/torch/utils/viz/MemoryViz.js b/torch/utils/viz/MemoryViz.js index 45ab7a8ded984..1aead42518cbf 100644 --- a/torch/utils/viz/MemoryViz.js +++ b/torch/utils/viz/MemoryViz.js @@ -1724,6 +1724,7 @@ body.on('drop', () => { selection_to_div[''] = body .append('div') .text( +<<<<<<< HEAD 'Drag and drop or select a file to load a local snapshot. No data from the snapshot is uploaded.', ); @@ -1746,6 +1747,11 @@ const fileInput = body.append('input') selected_change(); // refresh the UI }); +======= + 'Drag and drop a file to load a local snapshot. No data from the snapshot is uploaded.', + ); + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) let next_unique_n = 1; function add_snapshot(name, loader) { if (name in snapshot_to_loader) { diff --git a/torch/utils/viz/_cycles.py b/torch/utils/viz/_cycles.py index 79d8e8b8b1715..39c9805a8702a 100644 --- a/torch/utils/viz/_cycles.py +++ b/torch/utils/viz/_cycles.py @@ -70,7 +70,11 @@ def remove(): gc.callbacks.remove(gc_callback) return remove +<<<<<<< HEAD # Function to visualize cycles adapted from refcycle: +======= +# Function to visualize cycles adapated from refcycle: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Copyright 2013 Mark Dickinson # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -336,7 +340,11 @@ def object_context(obj): def to_dot(nodes): lines = ["digraph GraphName {", "node [shape=rect];", 'rankdir=LR;'] for i, n in enumerate(nodes): +<<<<<<< HEAD lines.append(f'{i} [label={escape(n.label)}, color={"red" if n.root else "black"}];') +======= + lines.append(f'{i} [label={escape(n.label)}, color={ "red" if n.root else "black"}];') +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) for i, f in enumerate(nodes): for label, j in f.referrents: @@ -482,7 +490,11 @@ def warn_tensor_cycles(): Install a warning that reports whenever a cycle that is holding CUDA memory is observed. The warning produces an .html file that visualizes the cycle, +<<<<<<< HEAD and links it to the stack frame that allocated the CUDA tensor. +======= + and links it to the stack frame that allocted the CUDA tensor. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) Reference cycles are freed by the cycle collector rather than being cleaned up when the objects in the cycle first become unreachable. If a cycle points to a tensor, diff --git a/torch/utils/weak.py b/torch/utils/weak.py index 9c7218cb2ad3b..e6b4120a9a45d 100644 --- a/torch/utils/weak.py +++ b/torch/utils/weak.py @@ -3,6 +3,11 @@ import collections.abc as _collections_abc import weakref +<<<<<<< HEAD +======= + +from _weakrefset import _IterationGuard # type: ignore[attr-defined] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from collections.abc import Mapping, MutableMapping from weakref import ref @@ -20,6 +25,7 @@ ] +<<<<<<< HEAD # TODO: make weakref properly thread safe following # https://github.com/python/cpython/pull/125325 class _IterationGuard: @@ -47,6 +53,8 @@ def __exit__(self, e, t, b): w._commit_removals() +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # This file defines a variant of WeakKeyDictionary that overrides the hashing # behavior of the key to use object identity, rather than the builtin # __eq__/__hash__ functions. This is useful for Tensor weak keys, as their diff --git a/torch/xpu/__init__.py b/torch/xpu/__init__.py index 79aae38a31685..9fc5e0bf316ba 100644 --- a/torch/xpu/__init__.py +++ b/torch/xpu/__init__.py @@ -6,7 +6,10 @@ This package is lazily initialized, so you can always import it, and use :func:`is_available()` to determine if your system supports XPU. """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) import threading import traceback from functools import lru_cache @@ -236,6 +239,7 @@ def get_device_capability(device: Optional[_device_t] = None) -> dict[str, Any]: Dict[str, Any]: the xpu capability dictionary of the device """ props = get_device_properties(device) +<<<<<<< HEAD # Only keep attributes that are safe for dictionary serialization. serializable_types = (int, float, bool, str, type(None), list, tuple, dict) return { @@ -243,6 +247,17 @@ def get_device_capability(device: Optional[_device_t] = None) -> dict[str, Any]: for key in dir(props) if not key.startswith("__") and isinstance((value := getattr(props, key)), serializable_types) +======= + # pybind service attributes are no longer needed and their presence breaks + # the further logic related to the serialization of the created dictionary. + # In particular it filters out `` + # to fix Triton tests. + # This field appears after updating pybind to 2.13.6. + return { + prop: getattr(props, prop) + for prop in dir(props) + if not prop.startswith(("__", "_pybind11_")) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } @@ -291,7 +306,10 @@ class StreamContext: ``None``. .. note:: Streams are per-device. """ +<<<<<<< HEAD +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) cur_stream: Optional["torch.xpu.Stream"] def __init__(self, stream: Optional["torch.xpu.Stream"]): @@ -438,7 +456,11 @@ def get_gencode_flags() -> str: arch_list = get_arch_list() if len(arch_list) == 0: return "" +<<<<<<< HEAD return f"-device {','.join(arch for arch in arch_list)}" +======= + return f'-device {",".join(arch for arch in arch_list)}' +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def _get_generator(device: torch.device) -> torch._C.Generator: diff --git a/torchgen/aoti/fallback_ops.py b/torchgen/aoti/fallback_ops.py index 611400d271d96..79d8127746364 100644 --- a/torchgen/aoti/fallback_ops.py +++ b/torchgen/aoti/fallback_ops.py @@ -39,12 +39,18 @@ "aten._flash_attention_forward.default": {}, "aten._fused_moving_avg_obs_fq_helper_functional.default": {}, "aten._fused_moving_avg_obs_fq_helper.default": {}, +<<<<<<< HEAD "aten._fused_rms_norm.default": {}, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "aten._histogramdd_from_bin_cts.default": {}, "aten._int_mm.out": {}, "aten._pdist_backward.default": {}, "aten._pdist_forward.default": {}, +<<<<<<< HEAD "aten._scaled_dot_product_attention_math_for_mps.default": {}, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) "aten._scaled_dot_product_cudnn_attention_backward.default": {}, "aten._scaled_dot_product_cudnn_attention.default": {}, "aten._scaled_dot_product_efficient_attention_backward.default": {}, @@ -175,6 +181,7 @@ "aten.view.dtype": {}, "aten._weight_int4pack_mm_with_scales_and_zeros.default": {}, } +<<<<<<< HEAD # `python torchgen/gen.py --update-aoti-c-shim` will automatically generate # c_shim_aten.{h/cpp} based on the list below. @@ -189,3 +196,5 @@ "aten.new_empty.default": {}, "aten.new_zeros.default": {}, } +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/torchgen/api/functionalization.py b/torchgen/api/functionalization.py index f4b46b5f14760..a4fd6c3dcd234 100644 --- a/torchgen/api/functionalization.py +++ b/torchgen/api/functionalization.py @@ -23,6 +23,7 @@ # This file describes the translation of JIT schema to API's used +<<<<<<< HEAD # when creating `ViewMeta` specializations that are used by the functionalization pass. # These API's mostly follow the dispatcher API, with one difference: # - While the forward function just directly calls into the at::_ops API @@ -30,6 +31,22 @@ # is responsible for generating both the call-site, and the declarations # (which are implemented manually in the at::functionalization::impl namespace). +======= +# when creating view lambdas that are used by the functionalization pass. +# There are two types of lambdas: forward lambdas and reverse lambdas. +# These API's mostly follow the dispatcher API, with a few quirks: +# - The lambda capture has to convert reference types to value types +# - While the forward lambda just directly calls into the at::_ops API +# (following the dispatcher convention), the logic here for the reverse lambda +# is responsible for generating both the call-site, and the declarations +# (which are implemented manually in the at::functionalization::impl namespace). + +# The lambdas generated for each view op in the functionalization pass are of the form +# [capture_arguments](outer_arguments) -> returns_type { +# return name(inner_arguments); +# } + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Define some specific lambda input arguments. base_binding = Binding( name="base", @@ -39,6 +56,7 @@ ), default=None, ) +<<<<<<< HEAD has_symbolic_inputs_binding = Binding( name="has_symbolic_inputs", @@ -51,6 +69,8 @@ ), default=None, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) mutated_view_binding = Binding( name="mutated_view", nctype=NamedCType(name="mutated_view", type=ConstRefCType(BaseCType(tensorT))), @@ -59,11 +79,19 @@ ), default=None, ) +<<<<<<< HEAD out_index_binding = Binding( name="out_index", nctype=NamedCType(name="out_index", type=BaseCType(longT)), argument=Argument( name="out_index", type=BaseType(BaseTy.int), default=None, annotation=None +======= +mutated_view_idx_binding = Binding( + name="mutated_view_idx", + nctype=NamedCType(name="mutated_view_idx", type=BaseCType(longT)), + argument=Argument( + name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ), default=None, ) @@ -91,6 +119,7 @@ ) +<<<<<<< HEAD # Name of the `ViewMeta` specialization class created. def classname(func: FunctionSchema, with_namespace: bool = False) -> str: namespace = "at::functionalization::" if with_namespace else "" @@ -98,6 +127,10 @@ def classname(func: FunctionSchema, with_namespace: bool = False) -> str: # Name of the operation called inside the `forward`/`reverse` implementations. +======= +# The lambda capture itself doesn't have a name. +# The name returned here corresponds to the name of the inner function called by the lambda. +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def name( g: NativeFunctionsViewGroup, *, @@ -134,6 +167,27 @@ def reverse_name(f: NativeFunction, include_namespace: bool) -> str: return f"{api_name}_inverse" +<<<<<<< HEAD +======= +def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> list[Binding]: + # capture arguments include all arguments except `self`. + # Importantly, they don't include any C++ reference types (or else we'll get a dangling reference in the capture), + # So any reference types (IntArrayRef) need to be converted to value types (vector) + args = func.arguments.flat_all + assert args[0].type == BaseType(BaseTy.Tensor) + non_self_args = args[1:] + non_self_value_bindings = [ + dispatcher.argument(a, remove_non_owning_ref_types=True) for a in non_self_args + ] + + all_bindings = [ + inverse_return_mode_binding if is_reverse else reapply_views_binding + ] + all_bindings.extend(non_self_value_bindings) + return all_bindings + + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def returns_type(func: FunctionSchema) -> CType: # Assertion: all view ops return tensor-like outputs assert len(func.returns) >= 1 @@ -144,6 +198,7 @@ def returns_type(func: FunctionSchema) -> CType: return BaseCType(tensorT) +<<<<<<< HEAD # Checks whether `func` might return more than one value. def is_multi_output(func: FunctionSchema) -> bool: return len(func.returns) > 1 or ( @@ -187,6 +242,26 @@ def attributes(func: FunctionSchema, owning: bool = True) -> list[Binding]: def op_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]: +======= +def outer_arguments(*, is_reverse: bool) -> list[Binding]: + if is_reverse: + return [base_binding, mutated_view_binding, mutated_view_idx_binding] + else: + return [base_binding, mutated_view_idx_binding] + + +def inner_call_index(func: FunctionSchema) -> Binding | None: + # For view ops that return multiple tensors (like `split`), we generate a separate lambda for each output. + # When we replay a view op that returns multiple tensors, we need to index into the output appropriately + if len(func.returns) > 1 or ( + len(func.returns) == 1 and func.returns[0].type.is_list_like() + ): + return mutated_view_idx_binding + return None + + +def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) args = func.arguments.flat_all assert args[0].type == BaseType(BaseTy.Tensor) non_self_args = args[1:] @@ -200,12 +275,21 @@ def op_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]: # the reverse lambda does the same, but with an additional "mutated_view" arg # additionally, we have a calling convention: for view ops that return multiple tensor outputs # their corresponding view_inverse function takes in an additional index argument. +<<<<<<< HEAD if is_multi_output(func): +======= + index_binding = inner_call_index(func) + if index_binding is not None: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return [ base_binding, mutated_view_binding, inverse_return_mode_binding, +<<<<<<< HEAD out_index_binding, +======= + index_binding, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] + non_self_bindings else: return [ diff --git a/torchgen/api/types/signatures.py b/torchgen/api/types/signatures.py index d4a47536dd1ff..9a784d1f95b14 100644 --- a/torchgen/api/types/signatures.py +++ b/torchgen/api/types/signatures.py @@ -300,12 +300,91 @@ def decl(self) -> str: return_type = functionalization.returns_type(self.g.view.func) decls = [ a.decl() +<<<<<<< HEAD for a in functionalization.op_arguments(self.g.view.func, is_reverse=True) +======= + for a in functionalization.inner_arguments( + self.g.view.func, is_reverse=True + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ] return f"static {return_type.cpp_type()} {self.name()}({', '.join(decls)});" @dataclass(frozen=True) +<<<<<<< HEAD +======= +class FunctionalizationLambda: + g: NativeFunctionsViewGroup + + # are we generating the forward lambda or the reverse lambda? + is_reverse: bool + + def captures(self) -> list[Expr]: + # The lambda lives inside of a kernel following the dispatcher API, so its outer context is the dispatcher arguments + # We also need to read the "reapply views" TLS at the time that the functionalization kernel was executed, + # and plumb it into the lambda. + outer_ctx = dispatcher.arguments(self.g.view.func) + [ + functionalization.reapply_views_binding, + functionalization.inverse_return_mode_binding, + ] + capture_bindings = functionalization.capture_arguments( + self.g.view.func, is_reverse=self.is_reverse + ) + # allow_expensive_conversions is set because we want to convert + # some reference types (IntArrayRef) to value types (vector). + capture_exprs = translate.translate( + outer_ctx, capture_bindings, method=False, allow_expensive_conversions=True + ) + return capture_exprs + + def decl(self) -> str: + return_type = functionalization.returns_type(self.g.view.func) + capture_str = ", ".join( + f"{val.type.name} = {val.expr}" for val in self.captures() + ) + decls = [ + a.decl() + for a in functionalization.outer_arguments(is_reverse=self.is_reverse) + ] + return f"[{capture_str}]({', '.join(decls)}) -> {return_type.cpp_type()}" + + def inner_call(self, *, reapply_views: bool | None = None) -> str: + inner_call_name = functionalization.name( + self.g, + is_reverse=self.is_reverse, + include_namespace=True, + reapply_views=reapply_views, + ) + + arg_ctx = functionalization.outer_arguments(is_reverse=self.is_reverse) + capture_ctx = functionalization.capture_arguments( + self.g.view.func, is_reverse=self.is_reverse + ) + full_ctx = arg_ctx + capture_ctx + + assert self.g.view_copy is not None + call_bindings = functionalization.inner_arguments( + self.g.view_copy.func, is_reverse=self.is_reverse + ) + maybe_index = functionalization.inner_call_index(self.g.view_copy.func) + call_exprs = [ + e.expr for e in translate.translate(full_ctx, call_bindings, method=False) + ] + if not self.is_reverse and maybe_index is not None: + return f"{inner_call_name}({', '.join(call_exprs)})[{maybe_index.name}];" + else: + return f"{inner_call_name}({', '.join(call_exprs)});" + + @staticmethod + def from_func( + g: NativeFunctionsViewGroup, *, is_reverse: bool + ) -> FunctionalizationLambda: + return FunctionalizationLambda(g, is_reverse) + + +@dataclass(frozen=True) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) class StructuredImplSignature: g: NativeFunctionsGroup name: str diff --git a/torchgen/api/types/types.py b/torchgen/api/types/types.py index 41c05653fffdf..50950e7e2f3f3 100644 --- a/torchgen/api/types/types.py +++ b/torchgen/api/types/types.py @@ -79,7 +79,10 @@ typeAndSizeT = BaseCppType("torch::autograd::generated", "TypeAndSize") tensorGeometryT = BaseCppType("at", "TensorGeometry") SymIntT = BaseCppType("c10", "SymInt") +<<<<<<< HEAD SymBoolT = BaseCppType("c10", "SymBool") +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) symIntArrayRefT = BaseCppType("c10", "SymIntArrayRef") # Types representing template parameters. Technically, we probably shouldn't @@ -126,7 +129,10 @@ BaseTy.Storage: storageT, BaseTy.Stream: streamT, BaseTy.SymInt: SymIntT, +<<<<<<< HEAD BaseTy.SymBool: SymBoolT, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } # CTypes encode C++ type structure as needed for translation. diff --git a/torchgen/gen.py b/torchgen/gen.py index 7bbdd4a7a741f..561c7d67fdbce 100644 --- a/torchgen/gen.py +++ b/torchgen/gen.py @@ -43,8 +43,11 @@ gen_functionalization_definition, gen_functionalization_registration, gen_functionalization_view_inverse_declaration, +<<<<<<< HEAD gen_functionalization_view_meta_classes_decl, gen_functionalization_view_meta_classes_impl, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) GenCompositeViewCopyKernel, ) from torchgen.gen_vmap_plumbing import gen_all_vmap_plumbing @@ -97,7 +100,10 @@ if TYPE_CHECKING: from collections.abc import Sequence +<<<<<<< HEAD from typing import Optional +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) T = TypeVar("T") @@ -2218,7 +2224,11 @@ def gen_source_files( per_operator_headers: bool, skip_dispatcher_op_registration: bool, update_aoti_c_shim: bool, +<<<<<<< HEAD aoti_backends: set[Optional[DispatchKey]], +======= + aoti_backends: set[DispatchKey], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) extend_aoti_c_shim: bool, ) -> None: extra_cuda_headers = """\ @@ -2495,6 +2505,7 @@ def key_func( }, ) +<<<<<<< HEAD def gen_op_headers( g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup, ) -> list[str]: @@ -2537,6 +2548,50 @@ def gen_op_headers( def functionalization_env_callable( g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup, ) -> dict[str, list[str]]: +======= + def functionalization_env_callable( + g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup, + ) -> dict[str, list[str]]: + def gen_op_headers( + g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup, + ) -> list[str]: + if isinstance(g, NativeFunctionsViewGroup): + # view ops always get a functionalization kernel + headers = [ + f"#include ", + f"#include ", + ] + if g.view_copy is not None: + headers += [ + f"#include ", + f"#include ", + ] + return headers + elif isinstance(g, NativeFunctionsGroup): + headers = [ + f"#include ", + f"#include ", + f"#include ", + f"#include ", + ] + if g.inplace is not None: + headers += [ + f"#include ", + f"#include ", + ] + if g.mutable is not None: + headers += [ + f"#include ", + f"#include ", + ] + return headers + else: + return [ + f"#include ", + f"#include ", + ] + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return { "ops_headers": gen_op_headers(g), "func_definitions": gen_functionalization_definition( @@ -2602,6 +2657,7 @@ def functionalization_env_callable( }, ) +<<<<<<< HEAD cpu_fm.write( "ViewMetaClasses.h", lambda: { @@ -2627,6 +2683,8 @@ def functionalization_env_callable( }, ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # Note [view_copy NativeFunctions] # Every view operator in native_functions.yaml that is not CompositeImplicitAutograd # needs to have a corresponding non-aliasing {view}_copy variant. @@ -2868,14 +2926,18 @@ def main() -> None: aoti_backends = { DispatchKey.CPU, DispatchKey.CUDA, +<<<<<<< HEAD # None will generate the aten shim based on aten_shimified_ops # which does not bypass the dispatcher None, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) } # TODO: stop generating CUDA kernels for non-CUDA builds ignore_keys = set() +<<<<<<< HEAD MPS_KEYS = {DispatchKey.MPS, DispatchKey.SparseMPS, DispatchKey.SparseCsrMPS} if options.mps or options.update_aoti_c_shim: functions_keys.update(MPS_KEYS) @@ -2883,6 +2945,16 @@ def main() -> None: else: ignore_keys.update(MPS_KEYS) dispatch_keys[:] = [k for k in dispatch_keys if k not in MPS_KEYS] +======= + if options.mps or options.update_aoti_c_shim: + functions_keys.add(DispatchKey.MPS) + aoti_backends.add(DispatchKey.MPS) + else: + ignore_keys.add(DispatchKey.MPS) + + if DispatchKey.MPS in dispatch_keys: + del dispatch_keys[dispatch_keys.index(DispatchKey.MPS)] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if options.xpu or options.update_aoti_c_shim: functions_keys.add(DispatchKey.XPU) diff --git a/torchgen/gen_aoti_c_shim.py b/torchgen/gen_aoti_c_shim.py index 65161200256e5..6767218f1bca0 100644 --- a/torchgen/gen_aoti_c_shim.py +++ b/torchgen/gen_aoti_c_shim.py @@ -6,7 +6,11 @@ from dataclasses import dataclass from typing import TYPE_CHECKING +<<<<<<< HEAD from torchgen.aoti.fallback_ops import aten_shimified_ops, inductor_fallback_ops +======= +from torchgen.aoti.fallback_ops import inductor_fallback_ops +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torchgen.api.types import DispatcherSignature from torchgen.api.types.signatures import CppSignature, CppSignatureGroup from torchgen.context import method_with_native_function @@ -24,14 +28,20 @@ OperatorName, OptionalType, Type, +<<<<<<< HEAD Variant, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) ) from torchgen.utils import FileManager, mapMaybe if TYPE_CHECKING: from collections.abc import Sequence +<<<<<<< HEAD from typing import Optional +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) base_type_to_c_type = { @@ -251,6 +261,7 @@ def convert_return(typ: BaseType, val: str) -> str: ret_pointer_can_be_null = False unambiguous_name = schema.name.unambiguous_name() +<<<<<<< HEAD for name in ( "_functional_sym_constrain_range", "_scaled_dot_product_cudnn_attention", @@ -264,6 +275,15 @@ def convert_return(typ: BaseType, val: str) -> str: "grid_sampler_3d_backward", "linear_backward", ): +======= + for name in [ + "_scaled_dot_product_flash_attention", + "_scaled_dot_product_efficient_attention", + "_scaled_dot_product_cudnn_attention", + "_scaled_dot_product_fused_attention_overrideable", + "convolution_backward", + ]: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if name in unambiguous_name: ret_pointer_can_be_null = True break @@ -393,6 +413,7 @@ def gen_static_dispatch_backend_call_signature( def gen_static_dispatch_backend_call( f: NativeFunction, +<<<<<<< HEAD backend_index: Optional[BackendIndex] = None, ) -> str: sig = DispatcherSignature.from_schema(f.func) @@ -416,20 +437,34 @@ def gen_static_dispatch_backend_call( return f"at::{cpp_sig.name()}" else: return f"at::{backend_index.dispatch_key.lower()}::{cpp_sig.name()}" +======= + backend_index: BackendIndex, +) -> str: + sig = DispatcherSignature.from_schema(f.func) + cpp_sig = gen_static_dispatch_backend_call_signature(sig, f) + return f"at::{backend_index.dispatch_key.lower()}::{cpp_sig.name()}" +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_backend_index_for_aoti( func: NativeFunction, func_group_mapping: dict[OperatorName, NativeFunctionsGroup], +<<<<<<< HEAD dispatch_key: Optional[DispatchKey], +======= + dispatch_key: DispatchKey, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) backend_indices: dict[DispatchKey, BackendIndex], extend_aoti_c_shim: bool, ) -> BackendIndex | None: backend_index = None +<<<<<<< HEAD if dispatch_key is None: return backend_index +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if backend_indices[dispatch_key].has_kernel(func) or ( func.structured_delegate is not None and func.structured_delegate in func_group_mapping @@ -463,19 +498,31 @@ def get_backend_index_for_aoti( def get_header_for_aoti( func: NativeFunction, func_group_mapping: dict[OperatorName, NativeFunctionsGroup], +<<<<<<< HEAD dispatch_key: Optional[DispatchKey], +======= + dispatch_key: DispatchKey, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) backend_indices: dict[DispatchKey, BackendIndex], extend_aoti_c_shim: bool, ) -> str | None: backend_index = get_backend_index_for_aoti( func, func_group_mapping, dispatch_key, backend_indices, extend_aoti_c_shim ) +<<<<<<< HEAD if backend_index is None: if dispatch_key is None: return f"#include " return None return f"#include " +======= + return ( + None + if backend_index is None + else f"#include " + ) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def get_fallback_op_name(func: NativeFunction) -> str: @@ -490,7 +537,11 @@ def gen_c_shim( func: NativeFunction, version_info: dict[str, list[str]], func_group_mapping: dict[OperatorName, NativeFunctionsGroup], +<<<<<<< HEAD dispatch_key: Optional[DispatchKey], +======= + dispatch_key: DispatchKey, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) backend_indices: dict[DispatchKey, BackendIndex], header: bool, extend_aoti_c_shim: bool, @@ -498,11 +549,19 @@ def gen_c_shim( backend_index = get_backend_index_for_aoti( func, func_group_mapping, dispatch_key, backend_indices, extend_aoti_c_shim ) +<<<<<<< HEAD if backend_index is None and dispatch_key is not None: return None schema = func.func device = "aten" if dispatch_key is None else dispatch_key.lower() +======= + if backend_index is None: + return None + + schema = func.func + device = dispatch_key.lower() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) backend_call = gen_static_dispatch_backend_call( func, backend_index, @@ -528,7 +587,11 @@ def gen_c_shim( class ShimGenerator: inductor_fallback_ops: dict[str, dict[str, list[str]]] func_group_mapping: dict[OperatorName, NativeFunctionsGroup] +<<<<<<< HEAD dispatch_key: Optional[DispatchKey] +======= + dispatch_key: DispatchKey +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) backend_indices: dict[DispatchKey, BackendIndex] header: bool # True to generate .h and False to generate .cpp extend_aoti_c_shim: bool @@ -555,7 +618,11 @@ def gen_aoti_c_shim( native_functions: Sequence[NativeFunction], inductor_fallback_ops: dict[str, dict[str, list[str]]], func_group_mapping: dict[OperatorName, NativeFunctionsGroup], +<<<<<<< HEAD dispatch_key: Optional[DispatchKey], +======= + dispatch_key: DispatchKey, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) backend_indices: dict[DispatchKey, BackendIndex], header: bool, extend_aoti_c_shim: bool, @@ -576,6 +643,7 @@ def gen_aoti_c_shim( ) ) ) +<<<<<<< HEAD device = "aten" if dispatch_key is None else dispatch_key.lower() include_device_functions = ( "#include " @@ -589,6 +657,9 @@ def gen_aoti_c_shim( if dispatch_key is None else "" ) +======= + device = dispatch_key.lower() +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) warning = """ // WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND. @@ -597,7 +668,10 @@ def gen_aoti_c_shim( if header: return ( warning +<<<<<<< HEAD + aten_warning +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) + textwrap.dedent(""" #pragma once @@ -620,14 +694,21 @@ def gen_aoti_c_shim( else: return ( warning +<<<<<<< HEAD + aten_warning +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) + textwrap.dedent(f""" #include #include #ifndef AT_PER_OPERATOR_HEADERS +<<<<<<< HEAD {include_device_functions} +======= + #include +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) #include #include #include @@ -646,7 +727,11 @@ def gen_aoti_c_shim( def gen_aoti_c_shim_files( aoti_fm: FileManager, +<<<<<<< HEAD aoti_backends: set[Optional[DispatchKey]], +======= + aoti_backends: set[DispatchKey], +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) native_functions: Sequence[NativeFunction], backend_indices: dict[DispatchKey, BackendIndex], structured_native_functions: Sequence[NativeFunctionsGroup], @@ -662,6 +747,7 @@ def gen_aoti_c_shim_files( break for dispatch_key in aoti_backends: +<<<<<<< HEAD # Use aten_shimified_ops for the aten backend, inductor_fallback_ops for others fallback_ops_dict = ( aten_shimified_ops if dispatch_key is None else inductor_fallback_ops @@ -670,11 +756,18 @@ def gen_aoti_c_shim_files( for func in native_functions: op_name = get_fallback_op_name(func) if op_name in fallback_ops_dict: +======= + fallbacks = {} + for func in native_functions: + op_name = get_fallback_op_name(func) + if op_name in inductor_fallback_ops: +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) fallbacks[op_name] = func fallback_native_functions = tuple( value for _, value in sorted(fallbacks.items()) ) +<<<<<<< HEAD # Use "aten" as the device name when dispatch_key is Generic device_name = "aten" if dispatch_key is None else dispatch_key.lower() @@ -683,6 +776,13 @@ def gen_aoti_c_shim_files( new_header = gen_aoti_c_shim( fallback_native_functions, fallback_ops_dict, +======= + # header files were checked in for ABI-compatiblilty checking + header_file_name = f"c_shim_{dispatch_key.lower()}.h" + new_header = gen_aoti_c_shim( + fallback_native_functions, + inductor_fallback_ops, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) structured_func_group_dict, dispatch_key, backend_indices, @@ -750,6 +850,7 @@ def headers_for_aoti() -> str: headers.append(header) return "\n".join(sorted(set(headers))) +<<<<<<< HEAD extra_headers = ( extra_cuda_headers if dispatch_key is not None and is_cuda_dispatch_key(dispatch_key) @@ -761,6 +862,15 @@ def headers_for_aoti() -> str: lambda: gen_aoti_c_shim( fallback_native_functions, fallback_ops_dict, +======= + extra_headers = extra_cuda_headers if is_cuda_dispatch_key(dispatch_key) else "" + + aoti_fm.write( + f"c_shim_{dispatch_key.lower()}.cpp", + lambda: gen_aoti_c_shim( + fallback_native_functions, + inductor_fallback_ops, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) structured_func_group_dict, dispatch_key, backend_indices, diff --git a/torchgen/gen_functionalization_type.py b/torchgen/gen_functionalization_type.py index f47985837eac6..773ae9cbd944c 100644 --- a/torchgen/gen_functionalization_type.py +++ b/torchgen/gen_functionalization_type.py @@ -1,15 +1,25 @@ from __future__ import annotations from dataclasses import dataclass +<<<<<<< HEAD from typing import Callable, Optional, TYPE_CHECKING from torchgen.api import cpp, dispatcher, functionalization +======= +from typing import Callable, TYPE_CHECKING + +from torchgen.api import cpp, dispatcher +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) from torchgen.api.translate import translate from torchgen.api.types import ( BaseCType, Binding, CType, DispatcherSignature, +<<<<<<< HEAD +======= + FunctionalizationLambda, +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) iTensorListRefT, NativeSignature, OptionalCType, @@ -47,7 +57,11 @@ MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT, OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY, ) +<<<<<<< HEAD from torchgen.utils import concatMap, dataclass_repr, FileManager +======= +from torchgen.utils import dataclass_repr +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) if TYPE_CHECKING: @@ -364,8 +378,11 @@ def emit_view_functionalization_body( with native_function_manager(f): call_sig = DispatcherSignature.from_schema(g.view_copy.func) +<<<<<<< HEAD spec = ViewMetaSpecialization(g, f=f) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # the "view_copy" op name that the functionalization kernels need to call api_name = g.view_copy.func.name.unambiguous_name() # Sometimes the functionalization pass needs to no-op (e.g. if it was passed non-functional tensors) @@ -386,6 +403,12 @@ def emit_view_functionalization_body( for e in translate(unwrapped_args_ctx, call_sig.arguments(), method=False) ] +<<<<<<< HEAD +======= + forward_lambda = FunctionalizationLambda.from_func(g, is_reverse=False) + reverse_lambda = FunctionalizationLambda.from_func(g, is_reverse=True) + +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) # The meta API call should use the same arguments, but convert all tensors to meta tensors first. meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig) meta_call_args = [ @@ -413,7 +436,23 @@ def emit_view_functionalization_body( : at::functionalization::InverseReturnMode::NeverView ); {symbolic_inputs_check} +<<<<<<< HEAD auto view_meta = {spec.new()}; +======= + at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta( + {forward_lambda.decl()} {{ + if (reapply_views) {{ + return {forward_lambda.inner_call(reapply_views=True)} + }} else {{ + return {forward_lambda.inner_call(reapply_views=False)} + }} + }}, + {reverse_lambda.decl()} {{ + return {reverse_lambda.inner_call()} + }}, + /*has_symbolic_inputs=*/{symbolic_inputs_varname} + ); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto compute_reference_meta = {view_tensor_name}.key_set().has_backend(c10::BackendComponent::XLABit) || {view_tensor_name}.key_set().has_backend(c10::BackendComponent::LazyBit); @@ -441,6 +480,10 @@ def emit_view_functionalization_body( """ else: +<<<<<<< HEAD +======= + is_multi_output_view = isinstance(f.func.returns[0].type, ListType) +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) return f""" {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{ {unwrap_tensor_args_str} @@ -474,7 +517,25 @@ def emit_view_functionalization_body( }} }} {symbolic_inputs_check} +<<<<<<< HEAD auto view_meta = {spec.new()}; +======= + at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta( + {forward_lambda.decl()} {{ + if (reapply_views) {{ + return {forward_lambda.inner_call(reapply_views=True)} + }} else {{ + return {forward_lambda.inner_call(reapply_views=False)} + }} + }}, + {reverse_lambda.decl()} {{ + return {reverse_lambda.inner_call()} + }}, + /*has_symbolic_inputs=*/{symbolic_inputs_varname}, + /*is_multi_output=*/{str(is_multi_output_view).lower()}, + /*is_as_strided=*/{str(str(f.func.name) == "as_strided").lower()} + ); +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, {view_tensor_name}, view_meta); // See Note [Propagating strides in the functionalization pass] if (compute_reference_meta && !disable_meta_reference()) {{ @@ -742,6 +803,7 @@ def emit_decl_helper(g: NativeFunctionsViewGroup) -> str | None: return emit_decl_helper(g) +<<<<<<< HEAD # Helper class for generating `ViewMeta` specializations. @dataclass class ViewMetaSpecialization: @@ -1037,6 +1099,8 @@ def gen_functionalization_view_meta_classes( ) +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) def gen_functionalization_registration( selector: SelectiveBuilder, g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup, diff --git a/torchgen/model.py b/torchgen/model.py index eb3a80dffe6a0..52be2d6563fbf 100644 --- a/torchgen/model.py +++ b/torchgen/model.py @@ -288,8 +288,11 @@ def codegen_per_backend_entries() -> str: DispatchKey.SparseCsrXPU, DispatchKey.SparseCUDA, DispatchKey.SparseCsrCUDA, +<<<<<<< HEAD DispatchKey.SparseMPS, DispatchKey.SparseCsrMPS, +======= +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) DispatchKey.QuantizedCPU, DispatchKey.QuantizedCUDA, DispatchKey.CompositeImplicitAutograd, diff --git a/torchgen/selective_build/operator.py b/torchgen/selective_build/operator.py index 8047f033e3d2b..9007b400332d5 100644 --- a/torchgen/selective_build/operator.py +++ b/torchgen/selective_build/operator.py @@ -168,4 +168,8 @@ def merge_operator_dicts( def strip_operator_overload_name(op_name: str) -> str: +<<<<<<< HEAD return op_name.split(".", maxsplit=1)[0] +======= + return op_name.split(".")[0] +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)) diff --git a/version.txt b/version.txt index c8e38b614057b..fb7c19039d5e7 100644 --- a/version.txt +++ b/version.txt @@ -1 +1,5 @@ +<<<<<<< HEAD 2.9.0 +======= +2.8.0 +>>>>>>> 5729657180 ([ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791))